pax_global_header00006660000000000000000000000064145557721640014532gustar00rootroot0000000000000052 comment=281d4fcf60e2f037fa09d62be2543bb4d7511f2b pgvector-0.6.0/000077500000000000000000000000001455577216400133665ustar00rootroot00000000000000pgvector-0.6.0/.dockerignore000066400000000000000000000001211455577216400160340ustar00rootroot00000000000000/.git/ /dist/ /results/ /tmp_check/ /sql/vector--?.?.?.sql regression.* *.o *.so pgvector-0.6.0/.editorconfig000066400000000000000000000001241455577216400160400ustar00rootroot00000000000000root = true [*.{c,h,pl,pm,sql}] indent_style = tab indent_size = tab tab_width = 4 pgvector-0.6.0/.github/000077500000000000000000000000001455577216400147265ustar00rootroot00000000000000pgvector-0.6.0/.github/workflows/000077500000000000000000000000001455577216400167635ustar00rootroot00000000000000pgvector-0.6.0/.github/workflows/build.yml000066400000000000000000000067211455577216400206130ustar00rootroot00000000000000name: build on: [push, pull_request] jobs: ubuntu: runs-on: ${{ matrix.os }} if: ${{ !startsWith(github.ref_name, 'mac') && !startsWith(github.ref_name, 'windows') }} strategy: fail-fast: false matrix: include: - postgres: 17 os: ubuntu-22.04 - postgres: 16 os: ubuntu-22.04 - postgres: 15 os: ubuntu-22.04 - postgres: 14 os: ubuntu-22.04 - postgres: 13 os: ubuntu-20.04 - postgres: 12 os: ubuntu-20.04 steps: - uses: actions/checkout@v4 - uses: ankane/setup-postgres@v1 with: postgres-version: ${{ matrix.postgres }} dev-files: true - run: make env: PG_CFLAGS: -Wall -Wextra -Werror -Wno-unused-parameter -Wno-sign-compare - run: | export PG_CONFIG=`which pg_config` sudo --preserve-env=PG_CONFIG make install - run: make installcheck - if: ${{ failure() }} run: cat regression.diffs - run: | sudo apt-get update sudo apt-get install libipc-run-perl - run: make prove_installcheck mac: runs-on: macos-latest if: ${{ !startsWith(github.ref_name, 'windows') }} steps: - uses: actions/checkout@v4 - uses: ankane/setup-postgres@v1 with: 
postgres-version: 14 - run: make env: PG_CFLAGS: -Wall -Wextra -Werror -Wno-unused-parameter - run: make install - run: make installcheck - if: ${{ failure() }} run: cat regression.diffs - run: | brew install cpanm cpanm --notest IPC::Run wget -q https://github.com/postgres/postgres/archive/refs/tags/REL_14_10.tar.gz tar xf REL_14_10.tar.gz - run: make prove_installcheck PROVE_FLAGS="-I ./postgres-REL_14_10/src/test/perl" PERL5LIB="/Users/runner/perl5/lib/perl5" - run: make clean && /usr/local/opt/llvm@15/bin/scan-build --status-bugs make PG_CFLAGS="-DUSE_ASSERT_CHECKING" windows: runs-on: windows-latest if: ${{ !startsWith(github.ref_name, 'mac') }} steps: - uses: actions/checkout@v4 - uses: ankane/setup-postgres@v1 with: postgres-version: 14 - run: | call "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvars64.bat" && ^ nmake /NOLOGO /F Makefile.win && ^ nmake /NOLOGO /F Makefile.win install && ^ nmake /NOLOGO /F Makefile.win installcheck && ^ nmake /NOLOGO /F Makefile.win clean && ^ nmake /NOLOGO /F Makefile.win uninstall shell: cmd i386: if: ${{ !startsWith(github.ref_name, 'mac') && !startsWith(github.ref_name, 'windows') }} runs-on: ubuntu-latest container: image: debian:12 options: --platform linux/386 steps: - run: apt-get update && apt-get install -y build-essential git libipc-run-perl postgresql-15 postgresql-server-dev-15 sudo - run: service postgresql start - run: | git clone https://github.com/${{ github.repository }}.git pgvector cd pgvector git fetch origin ${{ github.ref }} git reset --hard FETCH_HEAD make make install chown -R postgres . 
sudo -u postgres make installcheck sudo -u postgres make prove_installcheck env: PG_CFLAGS: -Wall -Wextra -Werror -Wno-unused-parameter -Wno-sign-compare pgvector-0.6.0/.gitignore000066400000000000000000000001651455577216400153600ustar00rootroot00000000000000/dist/ /log/ /results/ /tmp_check/ /sql/vector--?.?.?.sql regression.* *.o *.so *.bc *.dll *.dylib *.obj *.lib *.exp pgvector-0.6.0/CHANGELOG.md000066400000000000000000000104631455577216400152030ustar00rootroot00000000000000## 0.6.0 (2024-01-29) If upgrading with Postgres 12 or Docker, see [these notes](https://github.com/pgvector/pgvector#060). - Changed storage for vector from `extended` to `external` - Added support for parallel index builds for HNSW - Added validation for GUC parameters - Improved performance of HNSW - Reduced memory usage for HNSW index builds - Reduced WAL generation for HNSW index builds - Fixed error with logical replication - Fixed `invalid memory alloc request size` error with HNSW index builds - Moved Docker image to `pgvector` org - Added Docker tags for each supported version of Postgres - Dropped support for Postgres 11 ## 0.5.1 (2023-10-10) - Improved performance of HNSW index builds - Added check for MVCC-compliant snapshot for index scans ## 0.5.0 (2023-08-28) - Added HNSW index type - Added support for parallel index builds for IVFFlat - Added `l1_distance` function - Added element-wise multiplication for vectors - Added `sum` aggregate - Improved performance of distance functions - Fixed out of range results for cosine distance - Fixed results for NULL and NaN distances for IVFFlat ## 0.4.4 (2023-06-12) - Improved error message for malformed vector literal - Fixed segmentation fault with text input - Fixed consecutive delimiters with text input ## 0.4.3 (2023-06-10) - Improved cost estimation - Improved support for spaces with text input - Fixed infinite and NaN values with binary input - Fixed infinite values with vector addition and subtraction - Fixed infinite values with 
list centers - Fixed compilation error when `float8` is pass by reference - Fixed compilation error on PowerPC - Fixed segmentation fault with index creation on i386 ## 0.4.2 (2023-05-13) - Added notice when index created with little data - Fixed dimensions check for some direct function calls - Fixed installation error with Postgres 12.0-12.2 ## 0.4.1 (2023-03-21) - Improved performance of cosine distance - Fixed index scan count ## 0.4.0 (2023-01-11) If upgrading with Postgres < 13, see [this note](https://github.com/pgvector/pgvector/blob/v0.4.0/README.md#040). - Changed text representation for vector elements to match `real` - Changed storage for vector from `plain` to `extended` - Increased max dimensions for vector from 1024 to 16000 - Increased max dimensions for index from 1024 to 2000 - Improved accuracy of text parsing for certain inputs - Added `avg` aggregate for vector - Added experimental support for Windows - Dropped support for Postgres 10 ## 0.3.2 (2022-11-22) - Fixed `invalid memory alloc request size` error ## 0.3.1 (2022-11-02) If upgrading from 0.2.7 or 0.3.0, [recreate](https://github.com/pgvector/pgvector/blob/v0.3.1/README.md#031) all `ivfflat` indexes after upgrading to ensure all data is indexed. 
- Fixed issue with inserts silently corrupting `ivfflat` indexes (introduced in 0.2.7) - Fixed segmentation fault with index creation when lists > 6500 ## 0.3.0 (2022-10-15) - Added support for Postgres 15 - Dropped support for Postgres 9.6 ## 0.2.7 (2022-07-31) - Fixed `unexpected data beyond EOF` error ## 0.2.6 (2022-05-22) - Improved performance of index creation for Postgres < 12 ## 0.2.5 (2022-02-11) - Reduced memory usage during index creation - Fixed index creation exceeding `maintenance_work_mem` - Fixed error with index creation when lists > 1600 ## 0.2.4 (2022-02-06) - Added support for parallel vacuum - Fixed issue with index not reusing space ## 0.2.3 (2022-01-30) - Added indexing progress for Postgres 12+ - Improved interrupt handling during index creation ## 0.2.2 (2022-01-15) - Fixed compilation error on Mac ARM ## 0.2.1 (2022-01-02) - Fixed `operator is not unique` error ## 0.2.0 (2021-10-03) - Added support for Postgres 14 ## 0.1.8 (2021-09-07) - Added cast for `vector` to `real[]` ## 0.1.7 (2021-06-13) - Added cast for `numeric[]` to `vector` ## 0.1.6 (2021-06-09) - Fixed segmentation fault with `COUNT` ## 0.1.5 (2021-05-25) - Reduced memory usage during index creation ## 0.1.4 (2021-05-09) - Fixed kmeans for inner product - Fixed multiple definition error with GCC 10 ## 0.1.3 (2021-05-06) - Added Dockerfile - Fixed version ## 0.1.2 (2021-04-26) - Vectorized distance calculations - Improved cost estimation ## 0.1.1 (2021-04-25) - Added binary representation for `COPY` - Marked functions as `PARALLEL SAFE` ## 0.1.0 (2021-04-20) - First release pgvector-0.6.0/Dockerfile000066400000000000000000000011201455577216400153520ustar00rootroot00000000000000ARG PG_MAJOR=16 FROM postgres:$PG_MAJOR ARG PG_MAJOR COPY . 
/tmp/pgvector RUN apt-get update && \ apt-mark hold locales && \ apt-get install -y --no-install-recommends build-essential postgresql-server-dev-$PG_MAJOR && \ cd /tmp/pgvector && \ make clean && \ make OPTFLAGS="" && \ make install && \ mkdir /usr/share/doc/pgvector && \ cp LICENSE README.md /usr/share/doc/pgvector && \ rm -r /tmp/pgvector && \ apt-get remove -y build-essential postgresql-server-dev-$PG_MAJOR && \ apt-get autoremove -y && \ apt-mark unhold locales && \ rm -rf /var/lib/apt/lists/* pgvector-0.6.0/LICENSE000066400000000000000000000021201455577216400143660ustar00rootroot00000000000000Portions Copyright (c) 1996-2023, PostgreSQL Global Development Group Portions Copyright (c) 1994, The Regents of the University of California Permission to use, copy, modify, and distribute this software and its documentation for any purpose, without fee, and without a written agreement is hereby granted, provided that the above copyright notice and this paragraph and the following two paragraphs appear in all copies. IN NO EVENT SHALL THE UNIVERSITY OF CALIFORNIA BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING LOST PROFITS, ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF THE UNIVERSITY OF CALIFORNIA HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. THE UNIVERSITY OF CALIFORNIA SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE PROVIDED HEREUNDER IS ON AN "AS IS" BASIS, AND THE UNIVERSITY OF CALIFORNIA HAS NO OBLIGATIONS TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS. 
pgvector-0.6.0/META.json000066400000000000000000000021211455577216400150030ustar00rootroot00000000000000{ "name": "vector", "abstract": "Open-source vector similarity search for Postgres", "description": "Supports L2 distance, inner product, and cosine distance", "version": "0.6.0", "maintainer": [ "Andrew Kane " ], "license": { "PostgreSQL": "http://www.postgresql.org/about/licence" }, "prereqs": { "runtime": { "requires": { "PostgreSQL": "12.0.0" } } }, "provides": { "vector": { "file": "sql/vector.sql", "docfile": "README.md", "version": "0.6.0", "abstract": "Open-source vector similarity search for Postgres" } }, "resources": { "homepage": "https://github.com/pgvector/pgvector", "bugtracker": { "web": "https://github.com/pgvector/pgvector/issues" }, "repository": { "url": "https://github.com/pgvector/pgvector.git", "web": "https://github.com/pgvector/pgvector", "type": "git" } }, "generated_by": "Andrew Kane", "meta-spec": { "version": "1.0.0", "url": "http://pgxn.org/meta/spec.txt" }, "tags": [ "vectors", "datatype", "nearest neighbor search", "approximate nearest neighbors" ] } pgvector-0.6.0/Makefile000066400000000000000000000050571455577216400150350ustar00rootroot00000000000000EXTENSION = vector EXTVERSION = 0.6.0 MODULE_big = vector DATA = $(wildcard sql/*--*.sql) OBJS = src/hnsw.o src/hnswbuild.o src/hnswinsert.o src/hnswscan.o src/hnswutils.o src/hnswvacuum.o src/ivfbuild.o src/ivfflat.o src/ivfinsert.o src/ivfkmeans.o src/ivfscan.o src/ivfutils.o src/ivfvacuum.o src/vector.o HEADERS = src/vector.h TESTS = $(wildcard test/sql/*.sql) REGRESS = $(patsubst test/sql/%.sql,%,$(TESTS)) REGRESS_OPTS = --inputdir=test --load-extension=$(EXTENSION) OPTFLAGS = -march=native # Mac ARM doesn't support -march=native ifeq ($(shell uname -s), Darwin) ifeq ($(shell uname -p), arm) # no difference with -march=armv8.5-a OPTFLAGS = endif endif # PowerPC doesn't support -march=native ifneq ($(filter ppc64%, $(shell uname -m)), ) OPTFLAGS = endif # For auto-vectorization: # 
- GCC (needs -ftree-vectorize OR -O3) - https://gcc.gnu.org/projects/tree-ssa/vectorization.html # - Clang (could use pragma instead) - https://llvm.org/docs/Vectorizers.html PG_CFLAGS += $(OPTFLAGS) -ftree-vectorize -fassociative-math -fno-signed-zeros -fno-trapping-math # Debug GCC auto-vectorization # PG_CFLAGS += -fopt-info-vec # Debug Clang auto-vectorization # PG_CFLAGS += -Rpass=loop-vectorize -Rpass-analysis=loop-vectorize all: sql/$(EXTENSION)--$(EXTVERSION).sql sql/$(EXTENSION)--$(EXTVERSION).sql: sql/$(EXTENSION).sql cp $< $@ EXTRA_CLEAN = sql/$(EXTENSION)--$(EXTVERSION).sql PG_CONFIG ?= pg_config PGXS := $(shell $(PG_CONFIG) --pgxs) include $(PGXS) # for Mac ifeq ($(PROVE),) PROVE = prove endif # for Postgres 15 PROVE_FLAGS += -I ./test/perl prove_installcheck: rm -rf $(CURDIR)/tmp_check cd $(srcdir) && TESTDIR='$(CURDIR)' PATH="$(bindir):$$PATH" PGPORT='6$(DEF_PGPORT)' PG_REGRESS='$(top_builddir)/src/test/regress/pg_regress' $(PROVE) $(PG_PROVE_FLAGS) $(PROVE_FLAGS) $(if $(PROVE_TESTS),$(PROVE_TESTS),test/t/*.pl) .PHONY: dist dist: mkdir -p dist git archive --format zip --prefix=$(EXTENSION)-$(EXTVERSION)/ --output dist/$(EXTENSION)-$(EXTVERSION).zip master # for Docker PG_MAJOR ?= 16 .PHONY: docker docker: docker build --pull --no-cache --build-arg PG_MAJOR=$(PG_MAJOR) -t pgvector/pgvector:pg$(PG_MAJOR) . docker build --build-arg PG_MAJOR=$(PG_MAJOR) -t pgvector/pgvector:$(EXTVERSION)-pg$(PG_MAJOR) . .PHONY: docker-release docker-release: docker buildx build --push --pull --no-cache --platform linux/amd64,linux/arm64 --build-arg PG_MAJOR=$(PG_MAJOR) -t pgvector/pgvector:pg$(PG_MAJOR) . docker buildx build --push --platform linux/amd64,linux/arm64 --build-arg PG_MAJOR=$(PG_MAJOR) -t pgvector/pgvector:$(EXTVERSION)-pg$(PG_MAJOR) . 
pgvector-0.6.0/Makefile.win000066400000000000000000000050071455577216400156240ustar00rootroot00000000000000EXTENSION = vector EXTVERSION = 0.6.0 OBJS = src\hnsw.obj src\hnswbuild.obj src\hnswinsert.obj src\hnswscan.obj src\hnswutils.obj src\hnswvacuum.obj src\ivfbuild.obj src\ivfflat.obj src\ivfinsert.obj src\ivfkmeans.obj src\ivfscan.obj src\ivfutils.obj src\ivfvacuum.obj src\vector.obj HEADERS = src\vector.h REGRESS = btree cast copy functions input ivfflat_cosine ivfflat_ip ivfflat_l2 ivfflat_options ivfflat_unlogged REGRESS_OPTS = --inputdir=test --load-extension=$(EXTENSION) # For /arch flags # https://learn.microsoft.com/en-us/cpp/build/reference/arch-minimum-cpu-architecture OPTFLAGS = # For auto-vectorization: # - MSVC (needs /O2 /fp:fast) - https://learn.microsoft.com/en-us/cpp/parallel/auto-parallelization-and-auto-vectorization?#auto-vectorizer PG_CFLAGS = $(PG_CFLAGS) $(OPTFLAGS) /O2 /fp:fast # Debug MSVC auto-vectorization # https://learn.microsoft.com/en-us/cpp/error-messages/tool-errors/vectorizer-and-parallelizer-messages # PG_CFLAGS = $(PG_CFLAGS) /Qvec-report:2 all: sql\$(EXTENSION)--$(EXTVERSION).sql sql\$(EXTENSION)--$(EXTVERSION).sql: sql\$(EXTENSION).sql copy sql\$(EXTENSION).sql $@ # TODO use pg_config !ifndef PGROOT !error PGROOT is not set !endif BINDIR = $(PGROOT)\bin INCLUDEDIR = $(PGROOT)\include INCLUDEDIR_SERVER = $(PGROOT)\include\server LIBDIR = $(PGROOT)\lib PKGLIBDIR = $(PGROOT)\lib SHAREDIR = $(PGROOT)\share CFLAGS = /nologo /I"$(INCLUDEDIR_SERVER)\port\win32_msvc" /I"$(INCLUDEDIR_SERVER)\port\win32" /I"$(INCLUDEDIR_SERVER)" /I"$(INCLUDEDIR)" CFLAGS = $(CFLAGS) $(PG_CFLAGS) SHLIB = $(EXTENSION).dll LIBS = "$(LIBDIR)\postgres.lib" .c.obj: $(CC) $(CFLAGS) /c $< /Fo$@ $(SHLIB): $(OBJS) $(CC) $(CFLAGS) $(OBJS) $(LIBS) /link /DLL /OUT:$(SHLIB) all: $(SHLIB) install: copy $(SHLIB) "$(PKGLIBDIR)" copy $(EXTENSION).control "$(SHAREDIR)\extension" copy sql\$(EXTENSION)--*.sql "$(SHAREDIR)\extension" mkdir 
"$(INCLUDEDIR_SERVER)\extension\$(EXTENSION)" for %f in ($(HEADERS)) do copy %f "$(INCLUDEDIR_SERVER)\extension\$(EXTENSION)" installcheck: "$(BINDIR)\pg_regress" --bindir="$(BINDIR)" $(REGRESS_OPTS) $(REGRESS) uninstall: del /f "$(PKGLIBDIR)\$(SHLIB)" del /f "$(SHAREDIR)\extension\$(EXTENSION).control" del /f "$(SHAREDIR)\extension\$(EXTENSION)--*.sql" del /f "$(INCLUDEDIR_SERVER)\extension\$(EXTENSION)\*.h" rmdir "$(INCLUDEDIR_SERVER)\extension\$(EXTENSION)" clean: del /f $(SHLIB) $(EXTENSION).lib $(EXTENSION).exp del /f $(OBJS) del /f sql\$(EXTENSION)--$(EXTVERSION).sql del /f /s /q results regression.diffs regression.out tmp_check tmp_check_iso log output_iso pgvector-0.6.0/README.md000066400000000000000000000602421455577216400146510ustar00rootroot00000000000000# pgvector Open-source vector similarity search for Postgres Store your vectors with the rest of your data. Supports: - exact and approximate nearest neighbor search - L2 distance, inner product, and cosine distance - any [language](#languages) with a Postgres client Plus [ACID](https://en.wikipedia.org/wiki/ACID) compliance, point-in-time recovery, JOINs, and all of the other [great features](https://www.postgresql.org/about/) of Postgres [![Build Status](https://github.com/pgvector/pgvector/actions/workflows/build.yml/badge.svg)](https://github.com/pgvector/pgvector/actions) ## Installation ### Linux and Mac Compile and install the extension (supports Postgres 12+) ```sh cd /tmp git clone --branch v0.6.0 https://github.com/pgvector/pgvector.git cd pgvector make make install # may need sudo ``` See the [installation notes](#installation-notes) if you run into issues You can also install it with [Docker](#docker), [Homebrew](#homebrew), [PGXN](#pgxn), [APT](#apt), [Yum](#yum), or [conda-forge](#conda-forge), and it comes preinstalled with [Postgres.app](#postgresapp) and many [hosted providers](#hosted-postgres). 
There are also instructions for [GitHub Actions](https://github.com/pgvector/setup-pgvector). ### Windows Ensure [C++ support in Visual Studio](https://learn.microsoft.com/en-us/cpp/build/building-on-the-command-line?view=msvc-170#download-and-install-the-tools) is installed, and run: ```cmd call "C:\Program Files\Microsoft Visual Studio\2022\Community\VC\Auxiliary\Build\vcvars64.bat" ``` Note: The exact path will vary depending on your Visual Studio version and edition Then use `nmake` to build: ```cmd set "PGROOT=C:\Program Files\PostgreSQL\16" git clone --branch v0.6.0 https://github.com/pgvector/pgvector.git cd pgvector nmake /F Makefile.win nmake /F Makefile.win install ``` You can also install it with [Docker](#docker) or [conda-forge](#conda-forge). ## Getting Started Enable the extension (do this once in each database where you want to use it) ```tsql CREATE EXTENSION vector; ``` Create a vector column with 3 dimensions ```sql CREATE TABLE items (id bigserial PRIMARY KEY, embedding vector(3)); ``` Insert vectors ```sql INSERT INTO items (embedding) VALUES ('[1,2,3]'), ('[4,5,6]'); ``` Get the nearest neighbors by L2 distance ```sql SELECT * FROM items ORDER BY embedding <-> '[3,1,2]' LIMIT 5; ``` Also supports inner product (`<#>`) and cosine distance (`<=>`) Note: `<#>` returns the negative inner product since Postgres only supports `ASC` order index scans on operators ## Storing Create a new table with a vector column ```sql CREATE TABLE items (id bigserial PRIMARY KEY, embedding vector(3)); ``` Or add a vector column to an existing table ```sql ALTER TABLE items ADD COLUMN embedding vector(3); ``` Insert vectors ```sql INSERT INTO items (embedding) VALUES ('[1,2,3]'), ('[4,5,6]'); ``` Upsert vectors ```sql INSERT INTO items (id, embedding) VALUES (1, '[1,2,3]'), (2, '[4,5,6]') ON CONFLICT (id) DO UPDATE SET embedding = EXCLUDED.embedding; ``` Update vectors ```sql UPDATE items SET embedding = '[1,2,3]' WHERE id = 1; ``` Delete vectors ```sql DELETE FROM 
items WHERE id = 1; ``` ## Querying Get the nearest neighbors to a vector ```sql SELECT * FROM items ORDER BY embedding <-> '[3,1,2]' LIMIT 5; ``` Get the nearest neighbors to a row ```sql SELECT * FROM items WHERE id != 1 ORDER BY embedding <-> (SELECT embedding FROM items WHERE id = 1) LIMIT 5; ``` Get rows within a certain distance ```sql SELECT * FROM items WHERE embedding <-> '[3,1,2]' < 5; ``` Note: Combine with `ORDER BY` and `LIMIT` to use an index #### Distances Get the distance ```sql SELECT embedding <-> '[3,1,2]' AS distance FROM items; ``` For inner product, multiply by -1 (since `<#>` returns the negative inner product) ```tsql SELECT (embedding <#> '[3,1,2]') * -1 AS inner_product FROM items; ``` For cosine similarity, use 1 - cosine distance ```sql SELECT 1 - (embedding <=> '[3,1,2]') AS cosine_similarity FROM items; ``` #### Aggregates Average vectors ```sql SELECT AVG(embedding) FROM items; ``` Average groups of vectors ```sql SELECT category_id, AVG(embedding) FROM items GROUP BY category_id; ``` ## Indexing By default, pgvector performs exact nearest neighbor search, which provides perfect recall. You can add an index to use approximate nearest neighbor search, which trades some recall for speed. Unlike typical indexes, you will see different results for queries after adding an approximate index. Supported index types are: - [HNSW](#hnsw) - added in 0.5.0 - [IVFFlat](#ivfflat) ## HNSW An HNSW index creates a multilayer graph. It has better query performance than IVFFlat (in terms of speed-recall tradeoff), but has slower build times and uses more memory. Also, an index can be created without any data in the table since there isn’t a training step like IVFFlat. Add an index for each distance function you want to use. 
L2 distance ```sql CREATE INDEX ON items USING hnsw (embedding vector_l2_ops); ``` Inner product ```sql CREATE INDEX ON items USING hnsw (embedding vector_ip_ops); ``` Cosine distance ```sql CREATE INDEX ON items USING hnsw (embedding vector_cosine_ops); ``` Vectors with up to 2,000 dimensions can be indexed. ### Index Options Specify HNSW parameters - `m` - the max number of connections per layer (16 by default) - `ef_construction` - the size of the dynamic candidate list for constructing the graph (64 by default) ```sql CREATE INDEX ON items USING hnsw (embedding vector_l2_ops) WITH (m = 16, ef_construction = 64); ``` A higher value of `ef_construction` provides better recall at the cost of index build time / insert speed. ### Query Options Specify the size of the dynamic candidate list for search (40 by default) ```sql SET hnsw.ef_search = 100; ``` A higher value provides better recall at the cost of speed. Use `SET LOCAL` inside a transaction to set it for a single query ```sql BEGIN; SET LOCAL hnsw.ef_search = 100; SELECT ... COMMIT; ``` ### Index Build Time Indexes build significantly faster when the graph fits into `maintenance_work_mem` ```sql SET maintenance_work_mem = '8GB'; ``` A notice is shown when the graph no longer fits ```text NOTICE: hnsw graph no longer fits into maintenance_work_mem after 100000 tuples DETAIL: Building will take significantly more time. HINT: Increase maintenance_work_mem to speed up builds. 
``` Note: Do not set `maintenance_work_mem` so high that it exhausts the memory on the server Starting with 0.6.0, you can also speed up index creation by increasing the number of parallel workers (2 by default) ```sql SET max_parallel_maintenance_workers = 7; -- plus leader ``` For a large number of workers, you may also need to increase `max_parallel_workers` (8 by default) ### Indexing Progress Check [indexing progress](https://www.postgresql.org/docs/current/progress-reporting.html#CREATE-INDEX-PROGRESS-REPORTING) with Postgres 12+ ```sql SELECT phase, round(100.0 * blocks_done / nullif(blocks_total, 0), 1) AS "%" FROM pg_stat_progress_create_index; ``` The phases for HNSW are: 1. `initializing` 2. `loading tuples` ## IVFFlat An IVFFlat index divides vectors into lists, and then searches a subset of those lists that are closest to the query vector. It has faster build times and uses less memory than HNSW, but has lower query performance (in terms of speed-recall tradeoff). Three keys to achieving good recall are: 1. Create the index *after* the table has some data 2. Choose an appropriate number of lists - a good place to start is `rows / 1000` for up to 1M rows and `sqrt(rows)` for over 1M rows 3. When querying, specify an appropriate number of [probes](#query-options) (higher is better for recall, lower is better for speed) - a good place to start is `sqrt(lists)` Add an index for each distance function you want to use. L2 distance ```sql CREATE INDEX ON items USING ivfflat (embedding vector_l2_ops) WITH (lists = 100); ``` Inner product ```sql CREATE INDEX ON items USING ivfflat (embedding vector_ip_ops) WITH (lists = 100); ``` Cosine distance ```sql CREATE INDEX ON items USING ivfflat (embedding vector_cosine_ops) WITH (lists = 100); ``` Vectors with up to 2,000 dimensions can be indexed. 
### Query Options Specify the number of probes (1 by default) ```sql SET ivfflat.probes = 10; ``` A higher value provides better recall at the cost of speed, and it can be set to the number of lists for exact nearest neighbor search (at which point the planner won’t use the index) Use `SET LOCAL` inside a transaction to set it for a single query ```sql BEGIN; SET LOCAL ivfflat.probes = 10; SELECT ... COMMIT; ``` ### Index Build Time Speed up index creation on large tables by increasing the number of parallel workers (2 by default) ```sql SET max_parallel_maintenance_workers = 7; -- plus leader ``` For a large number of workers, you may also need to increase `max_parallel_workers` (8 by default) ### Indexing Progress Check [indexing progress](https://www.postgresql.org/docs/current/progress-reporting.html#CREATE-INDEX-PROGRESS-REPORTING) with Postgres 12+ ```sql SELECT phase, round(100.0 * tuples_done / nullif(tuples_total, 0), 1) AS "%" FROM pg_stat_progress_create_index; ``` The phases for IVFFlat are: 1. `initializing` 2. `performing k-means` 3. `assigning tuples` 4. 
`loading tuples` Note: `%` is only populated during the `loading tuples` phase ## Filtering There are a few ways to index nearest neighbor queries with a `WHERE` clause ```sql SELECT * FROM items WHERE category_id = 123 ORDER BY embedding <-> '[3,1,2]' LIMIT 5; ``` Create an index on one [or more](https://www.postgresql.org/docs/current/indexes-multicolumn.html) of the `WHERE` columns for exact search ```sql CREATE INDEX ON items (category_id); ``` Or a [partial index](https://www.postgresql.org/docs/current/indexes-partial.html) on the vector column for approximate search ```sql CREATE INDEX ON items USING hnsw (embedding vector_l2_ops) WHERE (category_id = 123); ``` Use [partitioning](https://www.postgresql.org/docs/current/ddl-partitioning.html) for approximate search on many different values of the `WHERE` columns ```sql CREATE TABLE items (embedding vector(3), category_id int) PARTITION BY LIST(category_id); ``` ## Hybrid Search Use together with Postgres [full-text search](https://www.postgresql.org/docs/current/textsearch-intro.html) for hybrid search. ```sql SELECT id, content FROM items, plainto_tsquery('hello search') query WHERE textsearch @@ query ORDER BY ts_rank_cd(textsearch, query) DESC LIMIT 5; ``` You can use [Reciprocal Rank Fusion](https://github.com/pgvector/pgvector-python/blob/master/examples/hybrid_search_rrf.py) or a [cross-encoder](https://github.com/pgvector/pgvector-python/blob/master/examples/hybrid_search.py) to combine results. ## Performance Use `EXPLAIN ANALYZE` to debug performance. ```sql EXPLAIN ANALYZE SELECT * FROM items ORDER BY embedding <-> '[3,1,2]' LIMIT 5; ``` ### Exact Search To speed up queries without an index, increase `max_parallel_workers_per_gather`. ```sql SET max_parallel_workers_per_gather = 4; ``` If vectors are normalized to length 1 (like [OpenAI embeddings](https://platform.openai.com/docs/guides/embeddings/which-distance-function-should-i-use)), use inner product for best performance. 
```tsql SELECT * FROM items ORDER BY embedding <#> '[3,1,2]' LIMIT 5; ``` ### Approximate Search To speed up queries with an IVFFlat index, increase the number of inverted lists (at the expense of recall). ```sql CREATE INDEX ON items USING ivfflat (embedding vector_l2_ops) WITH (lists = 1000); ``` ## Languages Use pgvector from any language with a Postgres client. You can even generate and store vectors in one language and query them in another. Language | Libraries / Examples --- | --- C | [pgvector-c](https://github.com/pgvector/pgvector-c) C++ | [pgvector-cpp](https://github.com/pgvector/pgvector-cpp) C#, F#, Visual Basic | [pgvector-dotnet](https://github.com/pgvector/pgvector-dotnet) Crystal | [pgvector-crystal](https://github.com/pgvector/pgvector-crystal) Dart | [pgvector-dart](https://github.com/pgvector/pgvector-dart) Elixir | [pgvector-elixir](https://github.com/pgvector/pgvector-elixir) Go | [pgvector-go](https://github.com/pgvector/pgvector-go) Haskell | [pgvector-haskell](https://github.com/pgvector/pgvector-haskell) Java, Kotlin, Groovy, Scala | [pgvector-java](https://github.com/pgvector/pgvector-java) JavaScript, TypeScript | [pgvector-node](https://github.com/pgvector/pgvector-node) Julia | [pgvector-julia](https://github.com/pgvector/pgvector-julia) Lisp | [pgvector-lisp](https://github.com/pgvector/pgvector-lisp) Lua | [pgvector-lua](https://github.com/pgvector/pgvector-lua) Nim | [pgvector-nim](https://github.com/pgvector/pgvector-nim) OCaml | [pgvector-ocaml](https://github.com/pgvector/pgvector-ocaml) Perl | [pgvector-perl](https://github.com/pgvector/pgvector-perl) PHP | [pgvector-php](https://github.com/pgvector/pgvector-php) Python | [pgvector-python](https://github.com/pgvector/pgvector-python) R | [pgvector-r](https://github.com/pgvector/pgvector-r) Ruby | [pgvector-ruby](https://github.com/pgvector/pgvector-ruby), [Neighbor](https://github.com/ankane/neighbor) Rust | [pgvector-rust](https://github.com/pgvector/pgvector-rust) Swift | 
[pgvector-swift](https://github.com/pgvector/pgvector-swift) Zig | [pgvector-zig](https://github.com/pgvector/pgvector-zig) ## Frequently Asked Questions #### How many vectors can be stored in a single table? A non-partitioned table has a limit of 32 TB by default in Postgres. A partitioned table can have thousands of partitions of that size. #### Is replication supported? Yes, pgvector uses the write-ahead log (WAL), which allows for replication and point-in-time recovery. #### What if I want to index vectors with more than 2,000 dimensions? You’ll need to use [dimensionality reduction](https://en.wikipedia.org/wiki/Dimensionality_reduction) at the moment. #### Can I store vectors with different dimensions in the same column? You can use `vector` as the type (instead of `vector(3)`). ```sql CREATE TABLE embeddings (model_id bigint, item_id bigint, embedding vector, PRIMARY KEY (model_id, item_id)); ``` However, you can only create indexes on rows with the same number of dimensions (using [expression](https://www.postgresql.org/docs/current/indexes-expressional.html) and [partial](https://www.postgresql.org/docs/current/indexes-partial.html) indexing): ```sql CREATE INDEX ON embeddings USING hnsw ((embedding::vector(3)) vector_l2_ops) WHERE (model_id = 123); ``` and query with: ```sql SELECT * FROM embeddings WHERE model_id = 123 ORDER BY embedding::vector(3) <-> '[3,1,2]' LIMIT 5; ``` #### Can I store vectors with more precision? You can use the `double precision[]` or `numeric[]` type to store vectors with more precision. ```sql CREATE TABLE items (id bigserial PRIMARY KEY, embedding double precision[]); -- use {} instead of [] for Postgres arrays INSERT INTO items (embedding) VALUES ('{1,2,3}'), ('{4,5,6}'); ``` Optionally, add a [check constraint](https://www.postgresql.org/docs/current/ddl-constraints.html) to ensure data can be converted to the `vector` type and has the expected dimensions. 
```sql ALTER TABLE items ADD CHECK (vector_dims(embedding::vector) = 3); ``` Use [expression indexing](https://www.postgresql.org/docs/current/indexes-expressional.html) to index (at a lower precision): ```sql CREATE INDEX ON items USING hnsw ((embedding::vector(3)) vector_l2_ops); ``` and query with: ```sql SELECT * FROM items ORDER BY embedding::vector(3) <-> '[3,1,2]' LIMIT 5; ``` #### Do indexes need to fit into memory? No, but like other index types, you’ll likely see better performance if they do. You can get the size of an index with: ```sql SELECT pg_size_pretty(pg_relation_size('index_name')); ``` ## Troubleshooting #### Why isn’t a query using an index? The cost estimation in pgvector < 0.4.3 does not always work well with the planner. You can encourage the planner to use an index for a query with: ```sql BEGIN; SET LOCAL enable_seqscan = off; SELECT ... COMMIT; ``` Also, if the table is small, a table scan may be faster. #### Why isn’t a query using a parallel table scan? The planner doesn’t consider [out-of-line storage](https://www.postgresql.org/docs/current/storage-toast.html) in cost estimates, which can make a serial scan look cheaper. You can reduce the cost of a parallel scan for a query with: ```sql BEGIN; SET LOCAL min_parallel_table_scan_size = 1; SET LOCAL parallel_setup_cost = 1; SELECT ... COMMIT; ``` or choose to store vectors inline: ```sql ALTER TABLE items ALTER COLUMN embedding SET STORAGE PLAIN; ``` #### Why are there fewer results for a query after adding an IVFFlat index? The index was likely created with too little data for the number of lists. Drop the index until the table has more data. ```sql DROP INDEX index_name; ``` ## Reference ### Vector Type Each vector takes `4 * dimensions + 8` bytes of storage. Each element is a single precision floating-point number (like the `real` type in Postgres), and all elements must be finite (no `NaN`, `Infinity` or `-Infinity`). Vectors can have up to 16,000 dimensions.
### Vector Operators Operator | Description | Added --- | --- | --- \+ | element-wise addition | \- | element-wise subtraction | \* | element-wise multiplication | 0.5.0 <-> | Euclidean distance | <#> | negative inner product | <=> | cosine distance | ### Vector Functions Function | Description | Added --- | --- | --- cosine_distance(vector, vector) → double precision | cosine distance | inner_product(vector, vector) → double precision | inner product | l2_distance(vector, vector) → double precision | Euclidean distance | l1_distance(vector, vector) → double precision | taxicab distance | 0.5.0 vector_dims(vector) → integer | number of dimensions | vector_norm(vector) → double precision | Euclidean norm | ### Aggregate Functions Function | Description | Added --- | --- | --- avg(vector) → vector | average | sum(vector) → vector | sum | 0.5.0 ## Installation Notes ### Postgres Location If your machine has multiple Postgres installations, specify the path to [pg_config](https://www.postgresql.org/docs/current/app-pgconfig.html) with: ```sh export PG_CONFIG=/Library/PostgreSQL/16/bin/pg_config ``` Then re-run the installation instructions (run `make clean` before `make` if needed). If `sudo` is needed for `make install`, use: ```sh sudo --preserve-env=PG_CONFIG make install ``` A few common paths on Mac are: - EDB installer - `/Library/PostgreSQL/16/bin/pg_config` - Homebrew (arm64) - `/opt/homebrew/opt/postgresql@16/bin/pg_config` - Homebrew (x86-64) - `/usr/local/opt/postgresql@16/bin/pg_config` Note: Replace `16` with your Postgres server version ### Missing Header If compilation fails with `fatal error: postgres.h: No such file or directory`, make sure Postgres development files are installed on the server. 
For Ubuntu and Debian, use: ```sh sudo apt install postgresql-server-dev-16 ``` Note: Replace `16` with your Postgres server version ### Missing SDK If compilation fails and the output includes `warning: no such sysroot directory` on Mac, reinstall Xcode Command Line Tools. ## Additional Installation Methods ### Docker Get the [Docker image](https://hub.docker.com/r/pgvector/pgvector) with: ```sh docker pull pgvector/pgvector:pg16 ``` This adds pgvector to the [Postgres image](https://hub.docker.com/_/postgres) (run it the same way). You can also build the image manually: ```sh git clone --branch v0.6.0 https://github.com/pgvector/pgvector.git cd pgvector docker build --build-arg PG_MAJOR=16 -t myuser/pgvector . ``` ### Homebrew With Homebrew Postgres, you can use: ```sh brew install pgvector ``` Note: This only adds it to the `postgresql@14` formula ### PGXN Install from the [PostgreSQL Extension Network](https://pgxn.org/dist/vector) with: ```sh pgxn install vector ``` ### APT Debian and Ubuntu packages are available from the [PostgreSQL APT Repository](https://wiki.postgresql.org/wiki/Apt). Follow the [setup instructions](https://wiki.postgresql.org/wiki/Apt#Quickstart) and run: ```sh sudo apt install postgresql-16-pgvector ``` Note: Replace `16` with your Postgres server version ### Yum RPM packages are available from the [PostgreSQL Yum Repository](https://yum.postgresql.org/). 
Follow the [setup instructions](https://www.postgresql.org/download/linux/redhat/) for your distribution and run: ```sh sudo yum install pgvector_16 # or sudo dnf install pgvector_16 ``` Note: Replace `16` with your Postgres server version ### conda-forge With Conda Postgres, install from [conda-forge](https://anaconda.org/conda-forge/pgvector) with: ```sh conda install -c conda-forge pgvector ``` This method is [community-maintained](https://github.com/conda-forge/pgvector-feedstock) by [@mmcauliffe](https://github.com/mmcauliffe) ### Postgres.app Download the [latest release](https://postgresapp.com/downloads.html) with Postgres 15+. ## Hosted Postgres pgvector is available on [these providers](https://github.com/pgvector/pgvector/issues/54). ## Upgrading [Install](#installation) the latest version (use the same method as the original installation). Then in each database you want to upgrade, run: ```sql ALTER EXTENSION vector UPDATE; ``` You can check the version in the current database with: ```sql SELECT extversion FROM pg_extension WHERE extname = 'vector'; ``` ## Upgrade Notes ### 0.6.0 #### Postgres 12 If upgrading with Postgres 12, remove this line from `sql/vector--0.5.1--0.6.0.sql`: ```sql ALTER TYPE vector SET (STORAGE = external); ``` Then run `make install` and `ALTER EXTENSION vector UPDATE;`. #### Docker The Docker image is now published in the `pgvector` org, and there are tags for each supported version of Postgres (rather than a `latest` tag). ```sh docker pull pgvector/pgvector:pg16 # or docker pull pgvector/pgvector:0.6.0-pg16 ``` Also, if you’ve increased `maintenance_work_mem`, make sure `--shm-size` is at least that size to avoid an error with parallel HNSW index builds. ```sh docker run --shm-size=1g ... 
``` ## Thanks Thanks to: - [PASE: PostgreSQL Ultra-High-Dimensional Approximate Nearest Neighbor Search Extension](https://dl.acm.org/doi/pdf/10.1145/3318464.3386131) - [Faiss: A Library for Efficient Similarity Search and Clustering of Dense Vectors](https://github.com/facebookresearch/faiss) - [Using the Triangle Inequality to Accelerate k-means](https://cdn.aaai.org/ICML/2003/ICML03-022.pdf) - [k-means++: The Advantage of Careful Seeding](https://theory.stanford.edu/~sergei/papers/kMeansPP-soda.pdf) - [Concept Decompositions for Large Sparse Text Data using Clustering](https://www.cs.utexas.edu/users/inderjit/public_papers/concept_mlj.pdf) - [Efficient and Robust Approximate Nearest Neighbor Search using Hierarchical Navigable Small World Graphs](https://arxiv.org/ftp/arxiv/papers/1603/1603.09320.pdf) ## History View the [changelog](https://github.com/pgvector/pgvector/blob/master/CHANGELOG.md) ## Contributing Everyone is encouraged to help improve this project. Here are a few ways you can help: - [Report bugs](https://github.com/pgvector/pgvector/issues) - Fix bugs and [submit pull requests](https://github.com/pgvector/pgvector/pulls) - Write, clarify, or fix documentation - Suggest or add new features To get started with development: ```sh git clone https://github.com/pgvector/pgvector.git cd pgvector make make install ``` To run all tests: ```sh make installcheck # regression tests make prove_installcheck # TAP tests ``` To run single tests: ```sh make installcheck REGRESS=functions # regression test make prove_installcheck PROVE_TESTS=test/t/001_ivfflat_wal.pl # TAP test ``` To enable benchmarking: ```sh make clean && PG_CFLAGS="-DIVFFLAT_BENCH" make && make install ``` To show memory usage: ```sh make clean && PG_CFLAGS="-DHNSW_MEMORY -DIVFFLAT_MEMORY" make && make install ``` To enable assertions: ```sh make clean && PG_CFLAGS="-DUSE_ASSERT_CHECKING" make && make install ``` To get k-means metrics: ```sh make clean && PG_CFLAGS="-DIVFFLAT_KMEANS_DEBUG" 
make && make install ``` Resources for contributors - [Extension Building Infrastructure](https://www.postgresql.org/docs/current/extend-pgxs.html) - [Index Access Method Interface Definition](https://www.postgresql.org/docs/current/indexam.html) - [Generic WAL Records](https://www.postgresql.org/docs/current/generic-wal.html) pgvector-0.6.0/sql/000077500000000000000000000000001455577216400141655ustar00rootroot00000000000000pgvector-0.6.0/sql/vector--0.1.0--0.1.1.sql000066400000000000000000000037771455577216400174450ustar00rootroot00000000000000-- complain if script is sourced in psql, rather than via CREATE EXTENSION \echo Use "ALTER EXTENSION vector UPDATE TO '0.1.1'" to load this file. \quit CREATE FUNCTION vector_recv(internal, oid, integer) RETURNS vector AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT; CREATE FUNCTION vector_send(vector) RETURNS bytea AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT; ALTER TYPE vector SET ( RECEIVE = vector_recv, SEND = vector_send ); -- functions ALTER FUNCTION vector_in(cstring, oid, integer) PARALLEL SAFE; ALTER FUNCTION vector_out(vector) PARALLEL SAFE; ALTER FUNCTION vector_typmod_in(cstring[]) PARALLEL SAFE; ALTER FUNCTION vector_recv(internal, oid, integer) PARALLEL SAFE; ALTER FUNCTION vector_send(vector) PARALLEL SAFE; ALTER FUNCTION l2_distance(vector, vector) PARALLEL SAFE; ALTER FUNCTION inner_product(vector, vector) PARALLEL SAFE; ALTER FUNCTION cosine_distance(vector, vector) PARALLEL SAFE; ALTER FUNCTION vector_dims(vector) PARALLEL SAFE; ALTER FUNCTION vector_norm(vector) PARALLEL SAFE; ALTER FUNCTION vector_add(vector, vector) PARALLEL SAFE; ALTER FUNCTION vector_sub(vector, vector) PARALLEL SAFE; ALTER FUNCTION vector_lt(vector, vector) PARALLEL SAFE; ALTER FUNCTION vector_le(vector, vector) PARALLEL SAFE; ALTER FUNCTION vector_eq(vector, vector) PARALLEL SAFE; ALTER FUNCTION vector_ne(vector, vector) PARALLEL SAFE; ALTER FUNCTION vector_ge(vector, vector) PARALLEL SAFE; ALTER FUNCTION vector_gt(vector, 
vector) PARALLEL SAFE; ALTER FUNCTION vector_cmp(vector, vector) PARALLEL SAFE; ALTER FUNCTION vector_l2_squared_distance(vector, vector) PARALLEL SAFE; ALTER FUNCTION vector_negative_inner_product(vector, vector) PARALLEL SAFE; ALTER FUNCTION vector_spherical_distance(vector, vector) PARALLEL SAFE; ALTER FUNCTION vector(vector, integer, boolean) PARALLEL SAFE; ALTER FUNCTION array_to_vector(integer[], integer, boolean) PARALLEL SAFE; ALTER FUNCTION array_to_vector(real[], integer, boolean) PARALLEL SAFE; ALTER FUNCTION array_to_vector(double precision[], integer, boolean) PARALLEL SAFE; pgvector-0.6.0/sql/vector--0.1.1--0.1.3.sql000066400000000000000000000002311455577216400174260ustar00rootroot00000000000000-- complain if script is sourced in psql, rather than via CREATE EXTENSION \echo Use "ALTER EXTENSION vector UPDATE TO '0.1.3'" to load this file. \quit pgvector-0.6.0/sql/vector--0.1.3--0.1.4.sql000066400000000000000000000002311455577216400174310ustar00rootroot00000000000000-- complain if script is sourced in psql, rather than via CREATE EXTENSION \echo Use "ALTER EXTENSION vector UPDATE TO '0.1.4'" to load this file. \quit pgvector-0.6.0/sql/vector--0.1.4--0.1.5.sql000066400000000000000000000002311455577216400174330ustar00rootroot00000000000000-- complain if script is sourced in psql, rather than via CREATE EXTENSION \echo Use "ALTER EXTENSION vector UPDATE TO '0.1.5'" to load this file. \quit pgvector-0.6.0/sql/vector--0.1.5--0.1.6.sql000066400000000000000000000002311455577216400174350ustar00rootroot00000000000000-- complain if script is sourced in psql, rather than via CREATE EXTENSION \echo Use "ALTER EXTENSION vector UPDATE TO '0.1.6'" to load this file. \quit pgvector-0.6.0/sql/vector--0.1.6--0.1.7.sql000066400000000000000000000006231455577216400174440ustar00rootroot00000000000000-- complain if script is sourced in psql, rather than via CREATE EXTENSION \echo Use "ALTER EXTENSION vector UPDATE TO '0.1.7'" to load this file. 
\quit CREATE FUNCTION array_to_vector(numeric[], integer, boolean) RETURNS vector AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; CREATE CAST (numeric[] AS vector) WITH FUNCTION array_to_vector(numeric[], integer, boolean) AS IMPLICIT; pgvector-0.6.0/sql/vector--0.1.7--0.1.8.sql000066400000000000000000000006141455577216400174460ustar00rootroot00000000000000-- complain if script is sourced in psql, rather than via CREATE EXTENSION \echo Use "ALTER EXTENSION vector UPDATE TO '0.1.8'" to load this file. \quit CREATE FUNCTION vector_to_float4(vector, integer, boolean) RETURNS real[] AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; CREATE CAST (vector AS real[]) WITH FUNCTION vector_to_float4(vector, integer, boolean) AS IMPLICIT; pgvector-0.6.0/sql/vector--0.1.8--0.2.0.sql000066400000000000000000000002311455577216400174330ustar00rootroot00000000000000-- complain if script is sourced in psql, rather than via CREATE EXTENSION \echo Use "ALTER EXTENSION vector UPDATE TO '0.2.0'" to load this file. \quit pgvector-0.6.0/sql/vector--0.2.0--0.2.1.sql000066400000000000000000000013501455577216400174300ustar00rootroot00000000000000-- complain if script is sourced in psql, rather than via CREATE EXTENSION \echo Use "ALTER EXTENSION vector UPDATE TO '0.2.1'" to load this file. 
\quit DROP CAST (integer[] AS vector); DROP CAST (real[] AS vector); DROP CAST (double precision[] AS vector); DROP CAST (numeric[] AS vector); CREATE CAST (integer[] AS vector) WITH FUNCTION array_to_vector(integer[], integer, boolean) AS ASSIGNMENT; CREATE CAST (real[] AS vector) WITH FUNCTION array_to_vector(real[], integer, boolean) AS ASSIGNMENT; CREATE CAST (double precision[] AS vector) WITH FUNCTION array_to_vector(double precision[], integer, boolean) AS ASSIGNMENT; CREATE CAST (numeric[] AS vector) WITH FUNCTION array_to_vector(numeric[], integer, boolean) AS ASSIGNMENT; pgvector-0.6.0/sql/vector--0.2.1--0.2.2.sql000066400000000000000000000002311455577216400174270ustar00rootroot00000000000000-- complain if script is sourced in psql, rather than via CREATE EXTENSION \echo Use "ALTER EXTENSION vector UPDATE TO '0.2.2'" to load this file. \quit pgvector-0.6.0/sql/vector--0.2.2--0.2.3.sql000066400000000000000000000002311455577216400174310ustar00rootroot00000000000000-- complain if script is sourced in psql, rather than via CREATE EXTENSION \echo Use "ALTER EXTENSION vector UPDATE TO '0.2.3'" to load this file. \quit pgvector-0.6.0/sql/vector--0.2.3--0.2.4.sql000066400000000000000000000002311455577216400174330ustar00rootroot00000000000000-- complain if script is sourced in psql, rather than via CREATE EXTENSION \echo Use "ALTER EXTENSION vector UPDATE TO '0.2.4'" to load this file. \quit pgvector-0.6.0/sql/vector--0.2.4--0.2.5.sql000066400000000000000000000002311455577216400174350ustar00rootroot00000000000000-- complain if script is sourced in psql, rather than via CREATE EXTENSION \echo Use "ALTER EXTENSION vector UPDATE TO '0.2.5'" to load this file. \quit pgvector-0.6.0/sql/vector--0.2.5--0.2.6.sql000066400000000000000000000002311455577216400174370ustar00rootroot00000000000000-- complain if script is sourced in psql, rather than via CREATE EXTENSION \echo Use "ALTER EXTENSION vector UPDATE TO '0.2.6'" to load this file. 
\quit pgvector-0.6.0/sql/vector--0.2.6--0.2.7.sql000066400000000000000000000002311455577216400174410ustar00rootroot00000000000000-- complain if script is sourced in psql, rather than via CREATE EXTENSION \echo Use "ALTER EXTENSION vector UPDATE TO '0.2.7'" to load this file. \quit pgvector-0.6.0/sql/vector--0.2.7--0.3.0.sql000066400000000000000000000002311455577216400174340ustar00rootroot00000000000000-- complain if script is sourced in psql, rather than via CREATE EXTENSION \echo Use "ALTER EXTENSION vector UPDATE TO '0.3.0'" to load this file. \quit pgvector-0.6.0/sql/vector--0.3.0--0.3.1.sql000066400000000000000000000002311455577216400174270ustar00rootroot00000000000000-- complain if script is sourced in psql, rather than via CREATE EXTENSION \echo Use "ALTER EXTENSION vector UPDATE TO '0.3.1'" to load this file. \quit pgvector-0.6.0/sql/vector--0.3.1--0.3.2.sql000066400000000000000000000002311455577216400174310ustar00rootroot00000000000000-- complain if script is sourced in psql, rather than via CREATE EXTENSION \echo Use "ALTER EXTENSION vector UPDATE TO '0.3.2'" to load this file. \quit pgvector-0.6.0/sql/vector--0.3.2--0.4.0.sql000066400000000000000000000015401455577216400174350ustar00rootroot00000000000000-- complain if script is sourced in psql, rather than via CREATE EXTENSION \echo Use "ALTER EXTENSION vector UPDATE TO '0.4.0'" to load this file. 
\quit -- remove this single line for Postgres < 13 ALTER TYPE vector SET (STORAGE = extended); CREATE FUNCTION vector_accum(double precision[], vector) RETURNS double precision[] AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; CREATE FUNCTION vector_avg(double precision[]) RETURNS vector AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; CREATE FUNCTION vector_combine(double precision[], double precision[]) RETURNS double precision[] AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; CREATE AGGREGATE avg(vector) ( SFUNC = vector_accum, STYPE = double precision[], FINALFUNC = vector_avg, COMBINEFUNC = vector_combine, INITCOND = '{0}', PARALLEL = SAFE ); pgvector-0.6.0/sql/vector--0.4.0--0.4.1.sql000066400000000000000000000002311455577216400174310ustar00rootroot00000000000000-- complain if script is sourced in psql, rather than via CREATE EXTENSION \echo Use "ALTER EXTENSION vector UPDATE TO '0.4.1'" to load this file. \quit pgvector-0.6.0/sql/vector--0.4.1--0.4.2.sql000066400000000000000000000002311455577216400174330ustar00rootroot00000000000000-- complain if script is sourced in psql, rather than via CREATE EXTENSION \echo Use "ALTER EXTENSION vector UPDATE TO '0.4.2'" to load this file. \quit pgvector-0.6.0/sql/vector--0.4.2--0.4.3.sql000066400000000000000000000002311455577216400174350ustar00rootroot00000000000000-- complain if script is sourced in psql, rather than via CREATE EXTENSION \echo Use "ALTER EXTENSION vector UPDATE TO '0.4.3'" to load this file. \quit pgvector-0.6.0/sql/vector--0.4.3--0.4.4.sql000066400000000000000000000002311455577216400174370ustar00rootroot00000000000000-- complain if script is sourced in psql, rather than via CREATE EXTENSION \echo Use "ALTER EXTENSION vector UPDATE TO '0.4.4'" to load this file. 
\quit pgvector-0.6.0/sql/vector--0.4.4--0.5.0.sql000066400000000000000000000026221455577216400174430ustar00rootroot00000000000000-- complain if script is sourced in psql, rather than via CREATE EXTENSION \echo Use "ALTER EXTENSION vector UPDATE TO '0.5.0'" to load this file. \quit CREATE FUNCTION l1_distance(vector, vector) RETURNS float8 AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; CREATE FUNCTION vector_mul(vector, vector) RETURNS vector AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; CREATE OPERATOR * ( LEFTARG = vector, RIGHTARG = vector, PROCEDURE = vector_mul, COMMUTATOR = * ); CREATE AGGREGATE sum(vector) ( SFUNC = vector_add, STYPE = vector, COMBINEFUNC = vector_add, PARALLEL = SAFE ); CREATE FUNCTION hnswhandler(internal) RETURNS index_am_handler AS 'MODULE_PATHNAME' LANGUAGE C; CREATE ACCESS METHOD hnsw TYPE INDEX HANDLER hnswhandler; COMMENT ON ACCESS METHOD hnsw IS 'hnsw index access method'; CREATE OPERATOR CLASS vector_l2_ops FOR TYPE vector USING hnsw AS OPERATOR 1 <-> (vector, vector) FOR ORDER BY float_ops, FUNCTION 1 vector_l2_squared_distance(vector, vector); CREATE OPERATOR CLASS vector_ip_ops FOR TYPE vector USING hnsw AS OPERATOR 1 <#> (vector, vector) FOR ORDER BY float_ops, FUNCTION 1 vector_negative_inner_product(vector, vector); CREATE OPERATOR CLASS vector_cosine_ops FOR TYPE vector USING hnsw AS OPERATOR 1 <=> (vector, vector) FOR ORDER BY float_ops, FUNCTION 1 vector_negative_inner_product(vector, vector), FUNCTION 2 vector_norm(vector); pgvector-0.6.0/sql/vector--0.5.0--0.5.1.sql000066400000000000000000000002311455577216400174330ustar00rootroot00000000000000-- complain if script is sourced in psql, rather than via CREATE EXTENSION \echo Use "ALTER EXTENSION vector UPDATE TO '0.5.1'" to load this file. 
\quit pgvector-0.6.0/sql/vector--0.5.1--0.6.0.sql000066400000000000000000000003631455577216400174420ustar00rootroot00000000000000-- complain if script is sourced in psql, rather than via CREATE EXTENSION \echo Use "ALTER EXTENSION vector UPDATE TO '0.6.0'" to load this file. \quit -- remove this single line for Postgres < 13 ALTER TYPE vector SET (STORAGE = external); pgvector-0.6.0/sql/vector.sql000066400000000000000000000225241455577216400162150ustar00rootroot00000000000000-- complain if script is sourced in psql, rather than via CREATE EXTENSION \echo Use "CREATE EXTENSION vector" to load this file. \quit -- type CREATE TYPE vector; CREATE FUNCTION vector_in(cstring, oid, integer) RETURNS vector AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; CREATE FUNCTION vector_out(vector) RETURNS cstring AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; CREATE FUNCTION vector_typmod_in(cstring[]) RETURNS integer AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; CREATE FUNCTION vector_recv(internal, oid, integer) RETURNS vector AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; CREATE FUNCTION vector_send(vector) RETURNS bytea AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; CREATE TYPE vector ( INPUT = vector_in, OUTPUT = vector_out, TYPMOD_IN = vector_typmod_in, RECEIVE = vector_recv, SEND = vector_send, STORAGE = external ); -- functions CREATE FUNCTION l2_distance(vector, vector) RETURNS float8 AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; CREATE FUNCTION inner_product(vector, vector) RETURNS float8 AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; CREATE FUNCTION cosine_distance(vector, vector) RETURNS float8 AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; CREATE FUNCTION l1_distance(vector, vector) RETURNS float8 AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; CREATE FUNCTION vector_dims(vector) RETURNS integer AS 'MODULE_PATHNAME' 
LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; CREATE FUNCTION vector_norm(vector) RETURNS float8 AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; CREATE FUNCTION vector_add(vector, vector) RETURNS vector AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; CREATE FUNCTION vector_sub(vector, vector) RETURNS vector AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; CREATE FUNCTION vector_mul(vector, vector) RETURNS vector AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; -- private functions CREATE FUNCTION vector_lt(vector, vector) RETURNS bool AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; CREATE FUNCTION vector_le(vector, vector) RETURNS bool AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; CREATE FUNCTION vector_eq(vector, vector) RETURNS bool AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; CREATE FUNCTION vector_ne(vector, vector) RETURNS bool AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; CREATE FUNCTION vector_ge(vector, vector) RETURNS bool AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; CREATE FUNCTION vector_gt(vector, vector) RETURNS bool AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; CREATE FUNCTION vector_cmp(vector, vector) RETURNS int4 AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; CREATE FUNCTION vector_l2_squared_distance(vector, vector) RETURNS float8 AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; CREATE FUNCTION vector_negative_inner_product(vector, vector) RETURNS float8 AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; CREATE FUNCTION vector_spherical_distance(vector, vector) RETURNS float8 AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; CREATE FUNCTION vector_accum(double precision[], vector) RETURNS double precision[] AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; CREATE FUNCTION vector_avg(double precision[]) RETURNS 
vector AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; CREATE FUNCTION vector_combine(double precision[], double precision[]) RETURNS double precision[] AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; -- aggregates CREATE AGGREGATE avg(vector) ( SFUNC = vector_accum, STYPE = double precision[], FINALFUNC = vector_avg, COMBINEFUNC = vector_combine, INITCOND = '{0}', PARALLEL = SAFE ); CREATE AGGREGATE sum(vector) ( SFUNC = vector_add, STYPE = vector, COMBINEFUNC = vector_add, PARALLEL = SAFE ); -- cast functions CREATE FUNCTION vector(vector, integer, boolean) RETURNS vector AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; CREATE FUNCTION array_to_vector(integer[], integer, boolean) RETURNS vector AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; CREATE FUNCTION array_to_vector(real[], integer, boolean) RETURNS vector AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; CREATE FUNCTION array_to_vector(double precision[], integer, boolean) RETURNS vector AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; CREATE FUNCTION array_to_vector(numeric[], integer, boolean) RETURNS vector AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; CREATE FUNCTION vector_to_float4(vector, integer, boolean) RETURNS real[] AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; -- casts CREATE CAST (vector AS vector) WITH FUNCTION vector(vector, integer, boolean) AS IMPLICIT; CREATE CAST (vector AS real[]) WITH FUNCTION vector_to_float4(vector, integer, boolean) AS IMPLICIT; CREATE CAST (integer[] AS vector) WITH FUNCTION array_to_vector(integer[], integer, boolean) AS ASSIGNMENT; CREATE CAST (real[] AS vector) WITH FUNCTION array_to_vector(real[], integer, boolean) AS ASSIGNMENT; CREATE CAST (double precision[] AS vector) WITH FUNCTION array_to_vector(double precision[], integer, boolean) AS ASSIGNMENT; CREATE CAST (numeric[] AS vector) WITH FUNCTION array_to_vector(numeric[], 
integer, boolean) AS ASSIGNMENT; -- operators CREATE OPERATOR <-> ( LEFTARG = vector, RIGHTARG = vector, PROCEDURE = l2_distance, COMMUTATOR = '<->' ); CREATE OPERATOR <#> ( LEFTARG = vector, RIGHTARG = vector, PROCEDURE = vector_negative_inner_product, COMMUTATOR = '<#>' ); CREATE OPERATOR <=> ( LEFTARG = vector, RIGHTARG = vector, PROCEDURE = cosine_distance, COMMUTATOR = '<=>' ); CREATE OPERATOR + ( LEFTARG = vector, RIGHTARG = vector, PROCEDURE = vector_add, COMMUTATOR = + ); CREATE OPERATOR - ( LEFTARG = vector, RIGHTARG = vector, PROCEDURE = vector_sub, COMMUTATOR = - ); CREATE OPERATOR * ( LEFTARG = vector, RIGHTARG = vector, PROCEDURE = vector_mul, COMMUTATOR = * ); CREATE OPERATOR < ( LEFTARG = vector, RIGHTARG = vector, PROCEDURE = vector_lt, COMMUTATOR = > , NEGATOR = >= , RESTRICT = scalarltsel, JOIN = scalarltjoinsel ); -- should use scalarlesel and scalarlejoinsel, but not supported in Postgres < 11 CREATE OPERATOR <= ( LEFTARG = vector, RIGHTARG = vector, PROCEDURE = vector_le, COMMUTATOR = >= , NEGATOR = > , RESTRICT = scalarltsel, JOIN = scalarltjoinsel ); CREATE OPERATOR = ( LEFTARG = vector, RIGHTARG = vector, PROCEDURE = vector_eq, COMMUTATOR = = , NEGATOR = <> , RESTRICT = eqsel, JOIN = eqjoinsel ); CREATE OPERATOR <> ( LEFTARG = vector, RIGHTARG = vector, PROCEDURE = vector_ne, COMMUTATOR = <> , NEGATOR = = , RESTRICT = eqsel, JOIN = eqjoinsel ); -- should use scalargesel and scalargejoinsel, but not supported in Postgres < 11 CREATE OPERATOR >= ( LEFTARG = vector, RIGHTARG = vector, PROCEDURE = vector_ge, COMMUTATOR = <= , NEGATOR = < , RESTRICT = scalargtsel, JOIN = scalargtjoinsel ); CREATE OPERATOR > ( LEFTARG = vector, RIGHTARG = vector, PROCEDURE = vector_gt, COMMUTATOR = < , NEGATOR = <= , RESTRICT = scalargtsel, JOIN = scalargtjoinsel ); -- access methods CREATE FUNCTION ivfflathandler(internal) RETURNS index_am_handler AS 'MODULE_PATHNAME' LANGUAGE C; CREATE ACCESS METHOD ivfflat TYPE INDEX HANDLER ivfflathandler; COMMENT ON ACCESS 
METHOD ivfflat IS 'ivfflat index access method'; CREATE FUNCTION hnswhandler(internal) RETURNS index_am_handler AS 'MODULE_PATHNAME' LANGUAGE C; CREATE ACCESS METHOD hnsw TYPE INDEX HANDLER hnswhandler; COMMENT ON ACCESS METHOD hnsw IS 'hnsw index access method'; -- opclasses CREATE OPERATOR CLASS vector_ops DEFAULT FOR TYPE vector USING btree AS OPERATOR 1 < , OPERATOR 2 <= , OPERATOR 3 = , OPERATOR 4 >= , OPERATOR 5 > , FUNCTION 1 vector_cmp(vector, vector); CREATE OPERATOR CLASS vector_l2_ops DEFAULT FOR TYPE vector USING ivfflat AS OPERATOR 1 <-> (vector, vector) FOR ORDER BY float_ops, FUNCTION 1 vector_l2_squared_distance(vector, vector), FUNCTION 3 l2_distance(vector, vector); CREATE OPERATOR CLASS vector_ip_ops FOR TYPE vector USING ivfflat AS OPERATOR 1 <#> (vector, vector) FOR ORDER BY float_ops, FUNCTION 1 vector_negative_inner_product(vector, vector), FUNCTION 3 vector_spherical_distance(vector, vector), FUNCTION 4 vector_norm(vector); CREATE OPERATOR CLASS vector_cosine_ops FOR TYPE vector USING ivfflat AS OPERATOR 1 <=> (vector, vector) FOR ORDER BY float_ops, FUNCTION 1 vector_negative_inner_product(vector, vector), FUNCTION 2 vector_norm(vector), FUNCTION 3 vector_spherical_distance(vector, vector), FUNCTION 4 vector_norm(vector); CREATE OPERATOR CLASS vector_l2_ops FOR TYPE vector USING hnsw AS OPERATOR 1 <-> (vector, vector) FOR ORDER BY float_ops, FUNCTION 1 vector_l2_squared_distance(vector, vector); CREATE OPERATOR CLASS vector_ip_ops FOR TYPE vector USING hnsw AS OPERATOR 1 <#> (vector, vector) FOR ORDER BY float_ops, FUNCTION 1 vector_negative_inner_product(vector, vector); CREATE OPERATOR CLASS vector_cosine_ops FOR TYPE vector USING hnsw AS OPERATOR 1 <=> (vector, vector) FOR ORDER BY float_ops, FUNCTION 1 vector_negative_inner_product(vector, vector), FUNCTION 2 vector_norm(vector); 
pgvector-0.6.0/src/000077500000000000000000000000001455577216400141555ustar00rootroot00000000000000pgvector-0.6.0/src/hnsw.c000066400000000000000000000153601455577216400153050ustar00rootroot00000000000000#include "postgres.h" #include #include #include "access/amapi.h" #include "access/reloptions.h" #include "commands/progress.h" #include "commands/vacuum.h" #include "hnsw.h" #include "utils/guc.h" #include "utils/selfuncs.h" #if PG_VERSION_NUM < 150000 #define MarkGUCPrefixReserved(x) EmitWarningsOnPlaceholders(x) #endif int hnsw_ef_search; int hnsw_lock_tranche_id; static relopt_kind hnsw_relopt_kind; /* * Assign a tranche ID for our LWLocks. This only needs to be done by one * backend, as the tranche ID is remembered in shared memory. * * This shared memory area is very small, so we just allocate it from the * "slop" that PostgreSQL reserves for small allocations like this. If * this grows bigger, we should use a shmem_request_hook and * RequestAddinShmemSpace() to pre-reserve space for this. 
*/ static void HnswInitLockTranche(void) { int *tranche_ids; bool found; LWLockAcquire(AddinShmemInitLock, LW_EXCLUSIVE); tranche_ids = ShmemInitStruct("hnsw LWLock ids", sizeof(int) * 1, &found); if (!found) tranche_ids[0] = LWLockNewTrancheId(); hnsw_lock_tranche_id = tranche_ids[0]; LWLockRelease(AddinShmemInitLock); /* Per-backend registration of the tranche ID */ LWLockRegisterTranche(hnsw_lock_tranche_id, "HnswBuild"); } /* * Initialize index options and variables */ void HnswInit(void) { HnswInitLockTranche(); hnsw_relopt_kind = add_reloption_kind(); add_int_reloption(hnsw_relopt_kind, "m", "Max number of connections", HNSW_DEFAULT_M, HNSW_MIN_M, HNSW_MAX_M #if PG_VERSION_NUM >= 130000 ,AccessExclusiveLock #endif ); add_int_reloption(hnsw_relopt_kind, "ef_construction", "Size of the dynamic candidate list for construction", HNSW_DEFAULT_EF_CONSTRUCTION, HNSW_MIN_EF_CONSTRUCTION, HNSW_MAX_EF_CONSTRUCTION #if PG_VERSION_NUM >= 130000 ,AccessExclusiveLock #endif ); DefineCustomIntVariable("hnsw.ef_search", "Sets the size of the dynamic candidate list for search", "Valid range is 1..1000.", &hnsw_ef_search, HNSW_DEFAULT_EF_SEARCH, HNSW_MIN_EF_SEARCH, HNSW_MAX_EF_SEARCH, PGC_USERSET, 0, NULL, NULL, NULL); MarkGUCPrefixReserved("hnsw"); } /* * Get the name of index build phase */ static char * hnswbuildphasename(int64 phasenum) { switch (phasenum) { case PROGRESS_CREATEIDX_SUBPHASE_INITIALIZE: return "initializing"; case PROGRESS_HNSW_PHASE_LOAD: return "loading tuples"; default: return NULL; } } /* * Estimate the cost of an index scan */ static void hnswcostestimate(PlannerInfo *root, IndexPath *path, double loop_count, Cost *indexStartupCost, Cost *indexTotalCost, Selectivity *indexSelectivity, double *indexCorrelation, double *indexPages) { GenericCosts costs; int m; int entryLevel; Relation index; /* Never use index without order */ if (path->indexorderbys == NULL) { *indexStartupCost = DBL_MAX; *indexTotalCost = DBL_MAX; *indexSelectivity = 0; 
*indexCorrelation = 0; *indexPages = 0; return; } MemSet(&costs, 0, sizeof(costs)); index = index_open(path->indexinfo->indexoid, NoLock); HnswGetMetaPageInfo(index, &m, NULL); index_close(index, NoLock); /* Approximate entry level */ entryLevel = (int) -log(1.0 / path->indexinfo->tuples) * HnswGetMl(m); /* TODO Improve estimate of visited tuples (currently underestimates) */ /* Account for number of tuples (or entry level), m, and ef_search */ costs.numIndexTuples = (entryLevel + 2) * m; genericcostestimate(root, path, loop_count, &costs); /* Use total cost since most work happens before first tuple is returned */ *indexStartupCost = costs.indexTotalCost; *indexTotalCost = costs.indexTotalCost; *indexSelectivity = costs.indexSelectivity; *indexCorrelation = costs.indexCorrelation; *indexPages = costs.numIndexPages; } /* * Parse and validate the reloptions */ static bytea * hnswoptions(Datum reloptions, bool validate) { static const relopt_parse_elt tab[] = { {"m", RELOPT_TYPE_INT, offsetof(HnswOptions, m)}, {"ef_construction", RELOPT_TYPE_INT, offsetof(HnswOptions, efConstruction)}, }; #if PG_VERSION_NUM >= 130000 return (bytea *) build_reloptions(reloptions, validate, hnsw_relopt_kind, sizeof(HnswOptions), tab, lengthof(tab)); #else relopt_value *options; int numoptions; HnswOptions *rdopts; options = parseRelOptions(reloptions, validate, hnsw_relopt_kind, &numoptions); rdopts = allocateReloptStruct(sizeof(HnswOptions), options, numoptions); fillRelOptions((void *) rdopts, sizeof(HnswOptions), options, numoptions, validate, tab, lengthof(tab)); return (bytea *) rdopts; #endif } /* * Validate catalog entries for the specified operator class */ static bool hnswvalidate(Oid opclassoid) { return true; } /* * Define index handler * * See https://www.postgresql.org/docs/current/index-api.html */ PGDLLEXPORT PG_FUNCTION_INFO_V1(hnswhandler); Datum hnswhandler(PG_FUNCTION_ARGS) { IndexAmRoutine *amroutine = makeNode(IndexAmRoutine); amroutine->amstrategies = 0; 
amroutine->amsupport = 2; #if PG_VERSION_NUM >= 130000 amroutine->amoptsprocnum = 0; #endif amroutine->amcanorder = false; amroutine->amcanorderbyop = true; amroutine->amcanbackward = false; /* can change direction mid-scan */ amroutine->amcanunique = false; amroutine->amcanmulticol = false; amroutine->amoptionalkey = true; amroutine->amsearcharray = false; amroutine->amsearchnulls = false; amroutine->amstorage = false; amroutine->amclusterable = false; amroutine->ampredlocks = false; amroutine->amcanparallel = false; amroutine->amcaninclude = false; #if PG_VERSION_NUM >= 130000 amroutine->amusemaintenanceworkmem = false; /* not used during VACUUM */ amroutine->amparallelvacuumoptions = VACUUM_OPTION_PARALLEL_BULKDEL; #endif amroutine->amkeytype = InvalidOid; /* Interface functions */ amroutine->ambuild = hnswbuild; amroutine->ambuildempty = hnswbuildempty; amroutine->aminsert = hnswinsert; amroutine->ambulkdelete = hnswbulkdelete; amroutine->amvacuumcleanup = hnswvacuumcleanup; amroutine->amcanreturn = NULL; amroutine->amcostestimate = hnswcostestimate; amroutine->amoptions = hnswoptions; amroutine->amproperty = NULL; /* TODO AMPROP_DISTANCE_ORDERABLE */ amroutine->ambuildphasename = hnswbuildphasename; amroutine->amvalidate = hnswvalidate; #if PG_VERSION_NUM >= 140000 amroutine->amadjustmembers = NULL; #endif amroutine->ambeginscan = hnswbeginscan; amroutine->amrescan = hnswrescan; amroutine->amgettuple = hnswgettuple; amroutine->amgetbitmap = NULL; amroutine->amendscan = hnswendscan; amroutine->ammarkpos = NULL; amroutine->amrestrpos = NULL; /* Interface functions to support parallel index scans */ amroutine->amestimateparallelscan = NULL; amroutine->aminitparallelscan = NULL; amroutine->amparallelrescan = NULL; PG_RETURN_POINTER(amroutine); } pgvector-0.6.0/src/hnsw.h000066400000000000000000000322161455577216400153110ustar00rootroot00000000000000#ifndef HNSW_H #define HNSW_H #include "postgres.h" #include "access/genam.h" #include "access/parallel.h" #include 
"lib/pairingheap.h" #include "nodes/execnodes.h" #include "port.h" /* for random() */ #include "utils/relptr.h" #include "utils/sampling.h" #include "vector.h" #if PG_VERSION_NUM < 120000 #error "Requires PostgreSQL 12+" #endif #define HNSW_MAX_DIM 2000 /* Support functions */ #define HNSW_DISTANCE_PROC 1 #define HNSW_NORM_PROC 2 #define HNSW_VERSION 1 #define HNSW_MAGIC_NUMBER 0xA953A953 #define HNSW_PAGE_ID 0xFF90 /* Preserved page numbers */ #define HNSW_METAPAGE_BLKNO 0 #define HNSW_HEAD_BLKNO 1 /* first element page */ /* Must correspond to page numbers since page lock is used */ #define HNSW_UPDATE_LOCK 0 #define HNSW_SCAN_LOCK 1 /* HNSW parameters */ #define HNSW_DEFAULT_M 16 #define HNSW_MIN_M 2 #define HNSW_MAX_M 100 #define HNSW_DEFAULT_EF_CONSTRUCTION 64 #define HNSW_MIN_EF_CONSTRUCTION 4 #define HNSW_MAX_EF_CONSTRUCTION 1000 #define HNSW_DEFAULT_EF_SEARCH 40 #define HNSW_MIN_EF_SEARCH 1 #define HNSW_MAX_EF_SEARCH 1000 /* Tuple types */ #define HNSW_ELEMENT_TUPLE_TYPE 1 #define HNSW_NEIGHBOR_TUPLE_TYPE 2 /* Make graph robust against non-HOT updates */ #define HNSW_HEAPTIDS 10 #define HNSW_UPDATE_ENTRY_GREATER 1 #define HNSW_UPDATE_ENTRY_ALWAYS 2 /* Build phases */ /* PROGRESS_CREATEIDX_SUBPHASE_INITIALIZE is 1 */ #define PROGRESS_HNSW_PHASE_LOAD 2 #define HNSW_MAX_SIZE (BLCKSZ - MAXALIGN(SizeOfPageHeaderData) - MAXALIGN(sizeof(HnswPageOpaqueData)) - sizeof(ItemIdData)) #define HNSW_TUPLE_ALLOC_SIZE BLCKSZ #define HNSW_ELEMENT_TUPLE_SIZE(size) MAXALIGN(offsetof(HnswElementTupleData, data) + (size)) #define HNSW_NEIGHBOR_TUPLE_SIZE(level, m) MAXALIGN(offsetof(HnswNeighborTupleData, indextids) + ((level) + 2) * (m) * sizeof(ItemPointerData)) #define HNSW_NEIGHBOR_ARRAY_SIZE(lm) (offsetof(HnswNeighborArray, items) + sizeof(HnswCandidate) * (lm)) #define HnswPageGetOpaque(page) ((HnswPageOpaque) PageGetSpecialPointer(page)) #define HnswPageGetMeta(page) ((HnswMetaPageData *) PageGetContents(page)) #if PG_VERSION_NUM >= 150000 #define RandomDouble() 
pg_prng_double(&pg_global_prng_state) #define SeedRandom(seed) pg_prng_seed(&pg_global_prng_state, seed) #else #define RandomDouble() (((double) random()) / MAX_RANDOM_VALUE) #define SeedRandom(seed) srandom(seed) #endif #if PG_VERSION_NUM < 130000 #define list_delete_last(list) list_truncate(list, list_length(list) - 1) #define list_sort(list, cmp) list_qsort(list, cmp) #endif #define HnswIsElementTuple(tup) ((tup)->type == HNSW_ELEMENT_TUPLE_TYPE) #define HnswIsNeighborTuple(tup) ((tup)->type == HNSW_NEIGHBOR_TUPLE_TYPE) /* 2 * M connections for ground layer */ #define HnswGetLayerM(m, layer) (layer == 0 ? (m) * 2 : (m)) /* Optimal ML from paper */ #define HnswGetMl(m) (1 / log(m)) /* Ensure fits on page and in uint8 */ #define HnswGetMaxLevel(m) Min(((BLCKSZ - MAXALIGN(SizeOfPageHeaderData) - MAXALIGN(sizeof(HnswPageOpaqueData)) - offsetof(HnswNeighborTupleData, indextids) - sizeof(ItemIdData)) / (sizeof(ItemPointerData)) / (m)) - 2, 255) #define HnswGetValue(base, element) PointerGetDatum(HnswPtrAccess(base, (element)->value)) #if PG_VERSION_NUM < 140005 #define relptr_offset(rp) ((rp).relptr_off - 1) #endif /* Pointer macros */ #define HnswPtrAccess(base, hp) ((base) == NULL ? (hp).ptr : relptr_access(base, (hp).relptr)) #define HnswPtrStore(base, hp, value) ((base) == NULL ? (void) ((hp).ptr = (value)) : (void) relptr_store(base, (hp).relptr, value)) #define HnswPtrIsNull(base, hp) ((base) == NULL ? (hp).ptr == NULL : relptr_is_null((hp).relptr)) #define HnswPtrEqual(base, hp1, hp2) ((base) == NULL ? 
(hp1).ptr == (hp2).ptr : relptr_offset((hp1).relptr) == relptr_offset((hp2).relptr)) /* For code paths dedicated to each type */ #define HnswPtrPointer(hp) (hp).ptr #define HnswPtrOffset(hp) relptr_offset((hp).relptr) /* Variables */ extern int hnsw_ef_search; extern int hnsw_lock_tranche_id; typedef struct HnswElementData HnswElementData; typedef struct HnswNeighborArray HnswNeighborArray; #define HnswPtrDeclare(type, relptrtype, ptrtype) \ relptr_declare(type, relptrtype); \ typedef union { type *ptr; relptrtype relptr; } ptrtype; /* Pointers that can be absolute or relative */ /* Use char for DatumPtr so works with Pointer */ HnswPtrDeclare(HnswElementData, HnswElementRelptr, HnswElementPtr); HnswPtrDeclare(HnswNeighborArray, HnswNeighborArrayRelptr, HnswNeighborArrayPtr); HnswPtrDeclare(HnswNeighborArrayPtr, HnswNeighborsRelptr, HnswNeighborsPtr); HnswPtrDeclare(char, DatumRelptr, DatumPtr); typedef struct HnswElementData { HnswElementPtr next; ItemPointerData heaptids[HNSW_HEAPTIDS]; uint8 heaptidsLength; uint8 level; uint8 deleted; uint32 hash; HnswNeighborsPtr neighbors; BlockNumber blkno; OffsetNumber offno; OffsetNumber neighborOffno; BlockNumber neighborPage; DatumPtr value; LWLock lock; } HnswElementData; typedef HnswElementData * HnswElement; typedef struct HnswCandidate { HnswElementPtr element; float distance; bool closer; } HnswCandidate; typedef struct HnswNeighborArray { int length; bool closerSet; HnswCandidate items[FLEXIBLE_ARRAY_MEMBER]; } HnswNeighborArray; typedef struct HnswPairingHeapNode { pairingheap_node ph_node; HnswCandidate *inner; } HnswPairingHeapNode; /* HNSW index options */ typedef struct HnswOptions { int32 vl_len_; /* varlena header (do not touch directly!) 
*/ int m; /* number of connections */ int efConstruction; /* size of dynamic candidate list */ } HnswOptions; typedef struct HnswGraph { /* Graph state */ slock_t lock; HnswElementPtr head; double indtuples; /* Entry state */ LWLock entryLock; HnswElementPtr entryPoint; /* Allocations state */ LWLock allocatorLock; long memoryUsed; long memoryTotal; /* Flushed state */ LWLock flushLock; bool flushed; } HnswGraph; typedef struct HnswShared { /* Immutable state */ Oid heaprelid; Oid indexrelid; bool isconcurrent; /* Worker progress */ ConditionVariable workersdonecv; /* Mutex for mutable state */ slock_t mutex; /* Mutable state */ int nparticipantsdone; double reltuples; HnswGraph graphData; } HnswShared; #define ParallelTableScanFromHnswShared(shared) \ (ParallelTableScanDesc) ((char *) (shared) + BUFFERALIGN(sizeof(HnswShared))) typedef struct HnswLeader { ParallelContext *pcxt; int nparticipanttuplesorts; HnswShared *hnswshared; Snapshot snapshot; char *hnswarea; } HnswLeader; typedef struct HnswAllocator { void *(*alloc) (Size size, void *state); void *state; } HnswAllocator; typedef struct HnswBuildState { /* Info */ Relation heap; Relation index; IndexInfo *indexInfo; ForkNumber forkNum; /* Settings */ int dimensions; int m; int efConstruction; /* Statistics */ double indtuples; double reltuples; /* Support functions */ FmgrInfo *procinfo; FmgrInfo *normprocinfo; Oid collation; /* Variables */ HnswGraph graphData; HnswGraph *graph; double ml; int maxLevel; Vector *normvec; /* Memory */ MemoryContext graphCtx; MemoryContext tmpCtx; HnswAllocator allocator; /* Parallel builds */ HnswLeader *hnswleader; HnswShared *hnswshared; char *hnswarea; } HnswBuildState; typedef struct HnswMetaPageData { uint32 magicNumber; uint32 version; uint32 dimensions; uint16 m; uint16 efConstruction; BlockNumber entryBlkno; OffsetNumber entryOffno; int16 entryLevel; BlockNumber insertPage; } HnswMetaPageData; typedef HnswMetaPageData * HnswMetaPage; typedef struct HnswPageOpaqueData { 
BlockNumber nextblkno; uint16 unused; uint16 page_id; /* for identification of HNSW indexes */ } HnswPageOpaqueData; typedef HnswPageOpaqueData * HnswPageOpaque; typedef struct HnswElementTupleData { uint8 type; uint8 level; uint8 deleted; uint8 unused; ItemPointerData heaptids[HNSW_HEAPTIDS]; ItemPointerData neighbortid; uint16 unused2; Vector data; } HnswElementTupleData; typedef HnswElementTupleData * HnswElementTuple; typedef struct HnswNeighborTupleData { uint8 type; uint8 unused; uint16 count; ItemPointerData indextids[FLEXIBLE_ARRAY_MEMBER]; } HnswNeighborTupleData; typedef HnswNeighborTupleData * HnswNeighborTuple; typedef struct HnswScanOpaqueData { bool first; List *w; MemoryContext tmpCtx; /* Support functions */ FmgrInfo *procinfo; FmgrInfo *normprocinfo; Oid collation; } HnswScanOpaqueData; typedef HnswScanOpaqueData * HnswScanOpaque; typedef struct HnswVacuumState { /* Info */ Relation index; IndexBulkDeleteResult *stats; IndexBulkDeleteCallback callback; void *callback_state; /* Settings */ int m; int efConstruction; /* Support functions */ FmgrInfo *procinfo; Oid collation; /* Variables */ struct tidhash_hash *deleted; BufferAccessStrategy bas; HnswNeighborTuple ntup; HnswElementData highestPoint; /* Memory */ MemoryContext tmpCtx; } HnswVacuumState; /* Methods */ int HnswGetM(Relation index); int HnswGetEfConstruction(Relation index); FmgrInfo *HnswOptionalProcInfo(Relation index, uint16 procnum); bool HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, Vector * result); Buffer HnswNewBuffer(Relation index, ForkNumber forkNum); void HnswInitPage(Buffer buf, Page page); void HnswInit(void); List *HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, FmgrInfo *procinfo, Oid collation, int m, bool inserting, HnswElement skipElement); HnswElement HnswGetEntryPoint(Relation index); void HnswGetMetaPageInfo(Relation index, int *m, HnswElement * entryPoint); void *HnswAlloc(HnswAllocator * allocator, Size size); 
HnswElement HnswInitElement(char *base, ItemPointer tid, int m, double ml, int maxLevel, HnswAllocator * alloc); HnswElement HnswInitElementFromBlock(BlockNumber blkno, OffsetNumber offno); void HnswFindElementNeighbors(char *base, HnswElement element, HnswElement entryPoint, Relation index, FmgrInfo *procinfo, Oid collation, int m, int efConstruction, bool existing); HnswCandidate *HnswEntryCandidate(char *base, HnswElement em, Datum q, Relation rel, FmgrInfo *procinfo, Oid collation, bool loadVec); void HnswUpdateMetaPage(Relation index, int updateEntry, HnswElement entryPoint, BlockNumber insertPage, ForkNumber forkNum, bool building); void HnswSetNeighborTuple(char *base, HnswNeighborTuple ntup, HnswElement e, int m); void HnswAddHeapTid(HnswElement element, ItemPointer heaptid); void HnswInitNeighbors(char *base, HnswElement element, int m, HnswAllocator * alloc); bool HnswInsertTupleOnDisk(Relation index, Datum value, Datum *values, bool *isnull, ItemPointer heap_tid, bool building); void HnswUpdateNeighborsOnDisk(Relation index, FmgrInfo *procinfo, Oid collation, HnswElement e, int m, bool checkExisting, bool building); void HnswLoadElementFromTuple(HnswElement element, HnswElementTuple etup, bool loadHeaptids, bool loadVec); void HnswLoadElement(HnswElement element, float *distance, Datum *q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadVec); void HnswSetElementTuple(char *base, HnswElementTuple etup, HnswElement element); void HnswUpdateConnection(char *base, HnswElement element, HnswCandidate * hc, int lm, int lc, int *updateIdx, Relation index, FmgrInfo *procinfo, Oid collation); void HnswLoadNeighbors(HnswElement element, Relation index, int m); PGDLLEXPORT void HnswParallelBuildMain(dsm_segment *seg, shm_toc *toc); /* Index access methods */ IndexBuildResult *hnswbuild(Relation heap, Relation index, IndexInfo *indexInfo); void hnswbuildempty(Relation index); bool hnswinsert(Relation index, Datum *values, bool *isnull, ItemPointer 
heap_tid, Relation heap, IndexUniqueCheck checkUnique #if PG_VERSION_NUM >= 140000 ,bool indexUnchanged #endif ,IndexInfo *indexInfo ); IndexBulkDeleteResult *hnswbulkdelete(IndexVacuumInfo *info, IndexBulkDeleteResult *stats, IndexBulkDeleteCallback callback, void *callback_state); IndexBulkDeleteResult *hnswvacuumcleanup(IndexVacuumInfo *info, IndexBulkDeleteResult *stats); IndexScanDesc hnswbeginscan(Relation index, int nkeys, int norderbys); void hnswrescan(IndexScanDesc scan, ScanKey keys, int nkeys, ScanKey orderbys, int norderbys); bool hnswgettuple(IndexScanDesc scan, ScanDirection dir); void hnswendscan(IndexScanDesc scan); static inline HnswNeighborArray * HnswGetNeighbors(char *base, HnswElement element, int lc) { HnswNeighborArrayPtr *neighborList = HnswPtrAccess(base, element->neighbors); Assert(element->level >= lc); return HnswPtrAccess(base, neighborList[lc]); } /* Hash tables */ typedef struct TidHashEntry { ItemPointerData tid; char status; } TidHashEntry; #define SH_PREFIX tidhash #define SH_ELEMENT_TYPE TidHashEntry #define SH_KEY_TYPE ItemPointerData #define SH_SCOPE extern #define SH_DECLARE #include "lib/simplehash.h" typedef struct PointerHashEntry { uintptr_t ptr; char status; } PointerHashEntry; #define SH_PREFIX pointerhash #define SH_ELEMENT_TYPE PointerHashEntry #define SH_KEY_TYPE uintptr_t #define SH_SCOPE extern #define SH_DECLARE #include "lib/simplehash.h" typedef struct OffsetHashEntry { Size offset; char status; } OffsetHashEntry; #define SH_PREFIX offsethash #define SH_ELEMENT_TYPE OffsetHashEntry #define SH_KEY_TYPE Size #define SH_SCOPE extern #define SH_DECLARE #include "lib/simplehash.h" #endif pgvector-0.6.0/src/hnswbuild.c000066400000000000000000000755751455577216400163430ustar00rootroot00000000000000/* * The HNSW build happens in two phases: * * 1. In-memory phase * * In this first phase, the graph is held completely in memory. 
When the graph * is fully built, or we run out of memory reserved for the build (determined * by maintenance_work_mem), we materialize the graph to disk (see * FlushPages()), and switch to the on-disk phase. * * In a parallel build, a large contiguous chunk of shared memory is allocated * to hold the graph. Each worker process has its own HnswBuildState struct in * private memory, which contains information that doesn't change throughout * the build, and pointers to the shared structs in shared memory. The shared * memory area is mapped to a different address in each worker process, and * 'HnswBuildState.hnswarea' points to the beginning of the shared area in the * worker process's address space. All pointers used in the graph are * "relative pointers", stored as an offset from 'hnswarea'. * * Each element is protected by an LWLock. It must be held when reading or * modifying the element's neighbors or 'heaptids'. * * In a non-parallel build, the graph is held in backend-private memory. All * the elements are allocated in a dedicated memory context, 'graphCtx', and * the pointers used in the graph are regular pointers. * * 2. On-disk phase * * In the on-disk phase, the index is built by inserting each vector to the * index one by one, just like on INSERT. The only difference is that we don't * WAL-log the individual inserts. If the graph fit completely in memory and * was fully built in the in-memory phase, the on-disk phase is skipped. * * After we have finished building the graph, we perform one more scan through * the index and write all the pages to the WAL. 
*/ #include "postgres.h" #include #include "access/parallel.h" #include "access/table.h" #include "access/tableam.h" #include "access/xact.h" #include "access/xloginsert.h" #include "catalog/index.h" #include "commands/progress.h" #include "hnsw.h" #include "miscadmin.h" #include "optimizer/optimizer.h" #include "storage/bufmgr.h" #include "tcop/tcopprot.h" #include "utils/datum.h" #include "utils/memutils.h" #if PG_VERSION_NUM >= 140000 #include "utils/backend_progress.h" #else #include "pgstat.h" #endif #if PG_VERSION_NUM >= 130000 #define CALLBACK_ITEM_POINTER ItemPointer tid #else #define CALLBACK_ITEM_POINTER HeapTuple hup #endif #if PG_VERSION_NUM >= 140000 #include "utils/backend_status.h" #include "utils/wait_event.h" #endif #define PARALLEL_KEY_HNSW_SHARED UINT64CONST(0xA000000000000001) #define PARALLEL_KEY_HNSW_AREA UINT64CONST(0xA000000000000002) #define PARALLEL_KEY_QUERY_TEXT UINT64CONST(0xA000000000000003) #if PG_VERSION_NUM < 130000 #define GENERATIONCHUNK_RAWSIZE (SIZEOF_SIZE_T + SIZEOF_VOID_P * 2) #endif /* * Create the metapage */ static void CreateMetaPage(HnswBuildState * buildstate) { Relation index = buildstate->index; ForkNumber forkNum = buildstate->forkNum; Buffer buf; Page page; HnswMetaPage metap; buf = HnswNewBuffer(index, forkNum); page = BufferGetPage(buf); HnswInitPage(buf, page); /* Set metapage data */ metap = HnswPageGetMeta(page); metap->magicNumber = HNSW_MAGIC_NUMBER; metap->version = HNSW_VERSION; metap->dimensions = buildstate->dimensions; metap->m = buildstate->m; metap->efConstruction = buildstate->efConstruction; metap->entryBlkno = InvalidBlockNumber; metap->entryOffno = InvalidOffsetNumber; metap->entryLevel = -1; metap->insertPage = InvalidBlockNumber; ((PageHeader) page)->pd_lower = ((char *) metap + sizeof(HnswMetaPageData)) - (char *) page; MarkBufferDirty(buf); UnlockReleaseBuffer(buf); } /* * Add a new page */ static void HnswBuildAppendPage(Relation index, Buffer *buf, Page *page, ForkNumber forkNum) { /* Add a 
new page */ Buffer newbuf = HnswNewBuffer(index, forkNum); /* Update previous page */ HnswPageGetOpaque(*page)->nextblkno = BufferGetBlockNumber(newbuf); /* Commit */ MarkBufferDirty(*buf); UnlockReleaseBuffer(*buf); /* Can take a while, so ensure we can interrupt */ /* Needs to be called when no buffer locks are held */ LockBuffer(newbuf, BUFFER_LOCK_UNLOCK); CHECK_FOR_INTERRUPTS(); LockBuffer(newbuf, BUFFER_LOCK_EXCLUSIVE); /* Prepare new page */ *buf = newbuf; *page = BufferGetPage(*buf); HnswInitPage(*buf, *page); } /* * Create graph pages */ static void CreateGraphPages(HnswBuildState * buildstate) { Relation index = buildstate->index; ForkNumber forkNum = buildstate->forkNum; Size maxSize; HnswElementTuple etup; HnswNeighborTuple ntup; BlockNumber insertPage; HnswElement entryPoint; Buffer buf; Page page; HnswElementPtr iter = buildstate->graph->head; char *base = buildstate->hnswarea; /* Calculate sizes */ maxSize = HNSW_MAX_SIZE; /* Allocate once */ etup = palloc0(HNSW_TUPLE_ALLOC_SIZE); ntup = palloc0(HNSW_TUPLE_ALLOC_SIZE); /* Prepare first page */ buf = HnswNewBuffer(index, forkNum); page = BufferGetPage(buf); HnswInitPage(buf, page); while (!HnswPtrIsNull(base, iter)) { HnswElement element = HnswPtrAccess(base, iter); Size etupSize; Size ntupSize; Size combinedSize; void *valuePtr = HnswPtrAccess(base, element->value); /* Update iterator */ iter = element->next; /* Zero memory for each element */ MemSet(etup, 0, HNSW_TUPLE_ALLOC_SIZE); /* Calculate sizes */ etupSize = HNSW_ELEMENT_TUPLE_SIZE(VARSIZE_ANY(valuePtr)); ntupSize = HNSW_NEIGHBOR_TUPLE_SIZE(element->level, buildstate->m); combinedSize = etupSize + ntupSize + sizeof(ItemIdData); /* Initial size check */ if (etupSize > HNSW_TUPLE_ALLOC_SIZE) elog(ERROR, "index tuple too large"); HnswSetElementTuple(base, etup, element); /* Keep element and neighbors on the same page if possible */ if (PageGetFreeSpace(page) < etupSize || (combinedSize <= maxSize && PageGetFreeSpace(page) < combinedSize)) 
HnswBuildAppendPage(index, &buf, &page, forkNum); /* Calculate offsets */ element->blkno = BufferGetBlockNumber(buf); element->offno = OffsetNumberNext(PageGetMaxOffsetNumber(page)); if (combinedSize <= maxSize) { element->neighborPage = element->blkno; element->neighborOffno = OffsetNumberNext(element->offno); } else { element->neighborPage = element->blkno + 1; element->neighborOffno = FirstOffsetNumber; } ItemPointerSet(&etup->neighbortid, element->neighborPage, element->neighborOffno); /* Add element */ if (PageAddItem(page, (Item) etup, etupSize, InvalidOffsetNumber, false, false) != element->offno) elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index)); /* Add new page if needed */ if (PageGetFreeSpace(page) < ntupSize) HnswBuildAppendPage(index, &buf, &page, forkNum); /* Add placeholder for neighbors */ if (PageAddItem(page, (Item) ntup, ntupSize, InvalidOffsetNumber, false, false) != element->neighborOffno) elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index)); } insertPage = BufferGetBlockNumber(buf); /* Commit */ MarkBufferDirty(buf); UnlockReleaseBuffer(buf); entryPoint = HnswPtrAccess(base, buildstate->graph->entryPoint); HnswUpdateMetaPage(index, HNSW_UPDATE_ENTRY_ALWAYS, entryPoint, insertPage, forkNum, true); pfree(etup); pfree(ntup); } /* * Write neighbor tuples */ static void WriteNeighborTuples(HnswBuildState * buildstate) { Relation index = buildstate->index; ForkNumber forkNum = buildstate->forkNum; int m = buildstate->m; HnswElementPtr iter = buildstate->graph->head; char *base = buildstate->hnswarea; HnswNeighborTuple ntup; /* Allocate once */ ntup = palloc0(HNSW_TUPLE_ALLOC_SIZE); while (!HnswPtrIsNull(base, iter)) { HnswElement element = HnswPtrAccess(base, iter); Buffer buf; Page page; Size ntupSize = HNSW_NEIGHBOR_TUPLE_SIZE(element->level, m); /* Update iterator */ iter = element->next; /* Zero memory for each element */ MemSet(ntup, 0, HNSW_TUPLE_ALLOC_SIZE); /* Can take a while, so 
ensure we can interrupt */ /* Needs to be called when no buffer locks are held */ CHECK_FOR_INTERRUPTS(); buf = ReadBufferExtended(index, forkNum, element->neighborPage, RBM_NORMAL, NULL); LockBuffer(buf, BUFFER_LOCK_EXCLUSIVE); page = BufferGetPage(buf); HnswSetNeighborTuple(base, ntup, element, m); if (!PageIndexTupleOverwrite(page, element->neighborOffno, (Item) ntup, ntupSize)) elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index)); /* Commit */ MarkBufferDirty(buf); UnlockReleaseBuffer(buf); } pfree(ntup); } /* * Flush pages */ static void FlushPages(HnswBuildState * buildstate) { #ifdef HNSW_MEMORY elog(INFO, "memory: %zu MB", buildstate->graph->memoryUsed / (1024 * 1024)); #endif CreateMetaPage(buildstate); CreateGraphPages(buildstate); WriteNeighborTuples(buildstate); buildstate->graph->flushed = true; MemoryContextReset(buildstate->graphCtx); } /* * Add a heap TID to an existing element */ static bool AddDuplicateInMemory(HnswElement element, HnswElement dup) { LWLockAcquire(&dup->lock, LW_EXCLUSIVE); if (dup->heaptidsLength == HNSW_HEAPTIDS) { LWLockRelease(&dup->lock); return false; } HnswAddHeapTid(dup, &element->heaptids[0]); LWLockRelease(&dup->lock); return true; } /* * Find duplicate element */ static bool FindDuplicateInMemory(char *base, HnswElement element) { HnswNeighborArray *neighbors = HnswGetNeighbors(base, element, 0); Datum value = HnswGetValue(base, element); for (int i = 0; i < neighbors->length; i++) { HnswCandidate *neighbor = &neighbors->items[i]; HnswElement neighborElement = HnswPtrAccess(base, neighbor->element); Datum neighborValue = HnswGetValue(base, neighborElement); /* Exit early since ordered by distance */ if (!datumIsEqual(value, neighborValue, false, -1)) return false; /* Check for space */ if (AddDuplicateInMemory(element, neighborElement)) return true; } return false; } /* * Add to element list */ static void AddElementInMemory(char *base, HnswGraph * graph, HnswElement element) { 
SpinLockAcquire(&graph->lock);
	/* Push the element onto the front of the list */
	element->next = graph->head;
	HnswPtrStore(base, graph->head, element);
	SpinLockRelease(&graph->lock);
}

/*
 * Update neighbors
 *
 * For every layer from e's level down to 0, visits each neighbor of e and
 * calls HnswUpdateConnection for it while holding that neighbor's lock
 * exclusively.
 */
static void
UpdateNeighborsInMemory(char *base, FmgrInfo *procinfo, Oid collation, HnswElement e, int m)
{
	for (int lc = e->level; lc >= 0; lc--)
	{
		int			lm = HnswGetLayerM(m, lc);
		HnswNeighborArray *neighbors = HnswGetNeighbors(base, e, lc);

		for (int i = 0; i < neighbors->length; i++)
		{
			HnswCandidate *hc = &neighbors->items[i];
			HnswElement neighborElement = HnswPtrAccess(base, hc->element);

			/* Keep scan-build happy on Mac x86-64 */
			Assert(neighborElement);

			/* Use element for lock instead of hc since hc can be replaced */
			LWLockAcquire(&neighborElement->lock, LW_EXCLUSIVE);
			HnswUpdateConnection(base, e, hc, lm, lc, NULL, NULL, procinfo, collation);
			LWLockRelease(&neighborElement->lock);
		}
	}
}

/*
 * Update graph in memory
 *
 * If the element duplicates an existing neighbor, its heap TID is attached
 * there and nothing else happens. Otherwise the element is added to the
 * list, its neighbors are updated, and it becomes the entry point when there
 * is none yet or its level is higher. efConstruction is not used in this
 * function.
 */
static void
UpdateGraphInMemory(FmgrInfo *procinfo, Oid collation, HnswElement element, int m, int efConstruction, HnswElement entryPoint, HnswBuildState * buildstate)
{
	HnswGraph  *graph = buildstate->graph;
	char	   *base = buildstate->hnswarea;

	/* Look for duplicate */
	if (FindDuplicateInMemory(base, element))
		return;

	/* Add element */
	AddElementInMemory(base, graph, element);

	/* Update neighbors */
	UpdateNeighborsInMemory(base, procinfo, collation, element, m);

	/* Update entry point if needed (already have lock) */
	if (entryPoint == NULL || element->level > entryPoint->level)
		HnswPtrStore(base, graph->entryPoint, element);
}

/*
 * Insert tuple in memory
 *
 * Holds the graph's entry lock for the whole insert: shared in the common
 * case, exclusive when the element may become the new entry point.
 */
static void
InsertTupleInMemory(HnswBuildState * buildstate, HnswElement element)
{
	FmgrInfo   *procinfo = buildstate->procinfo;
	Oid			collation = buildstate->collation;
	HnswGraph  *graph = buildstate->graph;
	HnswElement entryPoint;
	LWLock	   *entryLock = &graph->entryLock;
	int			efConstruction = buildstate->efConstruction;
	int			m = buildstate->m;
	char	   *base = buildstate->hnswarea;

	/* Get entry point */
	LWLockAcquire(entryLock, LW_SHARED);
	entryPoint = HnswPtrAccess(base, graph->entryPoint);

	/* Prevent concurrent inserts when likely updating entry point */
	if (entryPoint == NULL || element->level > entryPoint->level)
	{
		/* Release shared lock */
		LWLockRelease(entryLock);

		/* Get exclusive lock */
		LWLockAcquire(entryLock, LW_EXCLUSIVE);

		/* Get latest entry point after lock is acquired */
		entryPoint = HnswPtrAccess(base, graph->entryPoint);
	}

	/* Find neighbors for element */
	HnswFindElementNeighbors(base, element, entryPoint, NULL, procinfo, collation, m, efConstruction, false);

	/* Update graph in memory */
	UpdateGraphInMemory(procinfo, collation, element, m, efConstruction, entryPoint, buildstate);

	/* Release entry lock */
	LWLockRelease(entryLock);
}

/*
 * Insert tuple
 *
 * Returns false when HnswNormValue rejects the value. Once the graph has
 * been flushed, or when this call finds the memory budget exhausted (and
 * flushes if nobody else has), the insert is delegated to
 * HnswInsertTupleOnDisk and its result is returned. Otherwise the element is
 * allocated, inserted into the in-memory graph, and true is returned.
 *
 * The flush lock is held shared across an in-memory insert so the graph
 * cannot be flushed underneath it.
 */
static bool
InsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heaptid, HnswBuildState * buildstate)
{
	HnswGraph  *graph = buildstate->graph;
	HnswElement element;
	HnswAllocator *allocator = &buildstate->allocator;
	Size		valueSize;
	Pointer		valuePtr;
	LWLock	   *flushLock = &graph->flushLock;
	char	   *base = buildstate->hnswarea;

	/* Detoast once for all calls */
	Datum		value = PointerGetDatum(PG_DETOAST_DATUM(values[0]));

	/* Normalize if needed */
	if (buildstate->normprocinfo != NULL)
	{
		if (!HnswNormValue(buildstate->normprocinfo, buildstate->collation, &value, buildstate->normvec))
			return false;
	}

	/* Get datum size */
	valueSize = VARSIZE_ANY(DatumGetPointer(value));

	/* Ensure graph not flushed when inserting */
	LWLockAcquire(flushLock, LW_SHARED);

	/* Are we in the on-disk phase? */
	if (graph->flushed)
	{
		LWLockRelease(flushLock);

		return HnswInsertTupleOnDisk(index, value, values, isnull, heaptid, true);
	}

	/*
	 * In a parallel build, the HnswElement is allocated from the shared
	 * memory area, so we need to coordinate with other processes.
	 */
	LWLockAcquire(&graph->allocatorLock, LW_EXCLUSIVE);

	/*
	 * Check that we have enough memory available for the new element now that
	 * we have the allocator lock, and flush pages if needed.
	 */
	if (graph->memoryUsed >= graph->memoryTotal)
	{
		LWLockRelease(&graph->allocatorLock);

		/* Trade the shared flush lock for an exclusive one before flushing */
		LWLockRelease(flushLock);
		LWLockAcquire(flushLock, LW_EXCLUSIVE);

		/* Flush only if it has not already been done */
		if (!graph->flushed)
		{
			ereport(NOTICE,
					(errmsg("hnsw graph no longer fits into maintenance_work_mem after " INT64_FORMAT " tuples", (int64) graph->indtuples),
					 errdetail("Building will take significantly more time."),
					 errhint("Increase maintenance_work_mem to speed up builds.")));

			FlushPages(buildstate);
		}

		LWLockRelease(flushLock);

		return HnswInsertTupleOnDisk(index, value, values, isnull, heaptid, true);
	}

	/* Ok, we can proceed to allocate the element */
	element = HnswInitElement(base, heaptid, buildstate->m, buildstate->ml, buildstate->maxLevel, allocator);
	valuePtr = HnswAlloc(allocator, valueSize);

	/*
	 * We have now allocated the space needed for the element, so we don't
	 * need the allocator lock anymore. Release it and initialize the rest of
	 * the element.
	 */
	LWLockRelease(&graph->allocatorLock);

	/* Copy the datum */
	memcpy(valuePtr, DatumGetPointer(value), valueSize);
	HnswPtrStore(base, element->value, valuePtr);

	/* Create a lock for the element */
	LWLockInitialize(&element->lock, hnsw_lock_tranche_id);

	/* Insert tuple */
	InsertTupleInMemory(buildstate, element);

	/* Release flush lock */
	LWLockRelease(flushLock);

	return true;
}

/*
 * Callback for table_index_build_scan
 *
 * Skips tuples whose first index column is null. Runs the insert in tmpCtx,
 * which is reset afterwards, and bumps the build progress counter for every
 * tuple InsertTuple reports as inserted.
 */
static void
BuildCallback(Relation index, CALLBACK_ITEM_POINTER, Datum *values,
			  bool *isnull, bool tupleIsAlive, void *state)
{
	HnswBuildState *buildstate = (HnswBuildState *) state;
	HnswGraph  *graph = buildstate->graph;
	MemoryContext oldCtx;

#if PG_VERSION_NUM < 130000
	/* Before PostgreSQL 13 the callback receives the heap tuple, not a TID */
	ItemPointer tid = &hup->t_self;
#endif

	/* Skip nulls */
	if (isnull[0])
		return;

	/* Use memory context */
	oldCtx = MemoryContextSwitchTo(buildstate->tmpCtx);

	/* Insert tuple */
	if (InsertTuple(index, values, isnull, tid, buildstate))
	{
		/* Update progress */
		SpinLockAcquire(&graph->lock);
		pgstat_progress_update_param(PROGRESS_CREATEIDX_TUPLES_DONE, ++graph->indtuples);
		SpinLockRelease(&graph->lock);
	}

	/* Reset memory context */
	MemoryContextSwitchTo(oldCtx);
	MemoryContextReset(buildstate->tmpCtx);
}

/*
 * Initialize the graph
 *
 * Starts with no head and no entry point, zero memory used, the given
 * memory budget, and freshly initialized locks.
 */
static void
InitGraph(HnswGraph * graph, char *base, long memoryTotal)
{
	HnswPtrStore(base, graph->head, (HnswElement) NULL);
	HnswPtrStore(base, graph->entryPoint, (HnswElement) NULL);
	graph->memoryUsed = 0;
	graph->memoryTotal = memoryTotal;
	graph->flushed = false;
	graph->indtuples = 0;
	SpinLockInit(&graph->lock);
	LWLockInitialize(&graph->entryLock, hnsw_lock_tranche_id);
	LWLockInitialize(&graph->allocatorLock, hnsw_lock_tranche_id);
	LWLockInitialize(&graph->flushLock, hnsw_lock_tranche_id);
}

/*
 * Initialize an allocator
 *
 * Stores the allocation callback and the state pointer handed back to it.
 */
static void
InitAllocator(HnswAllocator * allocator, void *(*alloc) (Size size, void *state), void *state)
{
	allocator->alloc = alloc;
	allocator->state = state;
}

/*
 * Memory context allocator
 *
 * Allocates from graphCtx and refreshes graphData.memoryUsed. state must be
 * the HnswBuildState.
 */
static void *
HnswMemoryContextAlloc(Size size, void *state)
{
	HnswBuildState *buildstate = (HnswBuildState *) state;
	void	   *chunk = MemoryContextAlloc(buildstate->graphCtx, size);

#if PG_VERSION_NUM >= 130000
	/* Ask the context how much it has allocated in total */
	buildstate->graphData.memoryUsed = MemoryContextMemAllocated(buildstate->graphCtx, false);
#else
	/* No context-level accounting here, so count the aligned request size */
	buildstate->graphData.memoryUsed += MAXALIGN(size);
#endif

	return chunk;
}

/*
 * Shared memory allocator
 *
 * Returns the next unused address in hnswarea and advances memoryUsed by the
 * aligned size. No bounds check is done here; InsertTuple compares
 * memoryUsed against memoryTotal before allocating.
 */
static void *
HnswSharedMemoryAlloc(Size size, void *state)
{
	HnswBuildState *buildstate = (HnswBuildState *) state;
	void	   *chunk = buildstate->hnswarea + buildstate->graph->memoryUsed;

	buildstate->graph->memoryUsed += MAXALIGN(size);
	return chunk;
}

/*
 * Initialize the build state
 */
static void
InitBuildState(HnswBuildState * buildstate, Relation heap, Relation index, IndexInfo *indexInfo, ForkNumber forkNum)
{
	buildstate->heap = heap;
	buildstate->index = index;
	buildstate->indexInfo = indexInfo;
	buildstate->forkNum = forkNum;

	buildstate->m = HnswGetM(index);
	buildstate->efConstruction = HnswGetEfConstruction(index);
	buildstate->dimensions = TupleDescAttr(index->rd_att, 0)->atttypmod;

	/*
Require column to have dimensions to be indexed */ if (buildstate->dimensions < 0) elog(ERROR, "column does not have dimensions"); if (buildstate->dimensions > HNSW_MAX_DIM) elog(ERROR, "column cannot have more than %d dimensions for hnsw index", HNSW_MAX_DIM); if (buildstate->efConstruction < 2 * buildstate->m) elog(ERROR, "ef_construction must be greater than or equal to 2 * m"); buildstate->reltuples = 0; buildstate->indtuples = 0; /* Get support functions */ buildstate->procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC); buildstate->normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC); buildstate->collation = index->rd_indcollation[0]; InitGraph(&buildstate->graphData, NULL, maintenance_work_mem * 1024L); buildstate->graph = &buildstate->graphData; buildstate->ml = HnswGetMl(buildstate->m); buildstate->maxLevel = HnswGetMaxLevel(buildstate->m); /* Reuse for each tuple */ buildstate->normvec = InitVector(buildstate->dimensions); buildstate->graphCtx = GenerationContextCreate(CurrentMemoryContext, "Hnsw build graph context", #if PG_VERSION_NUM >= 150000 1024 * 1024, 1024 * 1024, #endif 1024 * 1024); buildstate->tmpCtx = AllocSetContextCreate(CurrentMemoryContext, "Hnsw build temporary context", ALLOCSET_DEFAULT_SIZES); InitAllocator(&buildstate->allocator, &HnswMemoryContextAlloc, buildstate); buildstate->hnswleader = NULL; buildstate->hnswshared = NULL; buildstate->hnswarea = NULL; } /* * Free resources */ static void FreeBuildState(HnswBuildState * buildstate) { pfree(buildstate->normvec); MemoryContextDelete(buildstate->graphCtx); MemoryContextDelete(buildstate->tmpCtx); } /* * Within leader, wait for end of heap scan */ static double ParallelHeapScan(HnswBuildState * buildstate) { HnswShared *hnswshared = buildstate->hnswleader->hnswshared; int nparticipanttuplesorts; double reltuples; nparticipanttuplesorts = buildstate->hnswleader->nparticipanttuplesorts; for (;;) { SpinLockAcquire(&hnswshared->mutex); if (hnswshared->nparticipantsdone == 
nparticipanttuplesorts) { buildstate->graph = &hnswshared->graphData; buildstate->hnswarea = buildstate->hnswleader->hnswarea; reltuples = hnswshared->reltuples; SpinLockRelease(&hnswshared->mutex); break; } SpinLockRelease(&hnswshared->mutex); ConditionVariableSleep(&hnswshared->workersdonecv, WAIT_EVENT_PARALLEL_CREATE_INDEX_SCAN); } ConditionVariableCancelSleep(); return reltuples; } /* * Perform a worker's portion of a parallel insert */ static void HnswParallelScanAndInsert(Relation heapRel, Relation indexRel, HnswShared * hnswshared, char *hnswarea, bool progress) { HnswBuildState buildstate; TableScanDesc scan; double reltuples; IndexInfo *indexInfo; /* Join parallel scan */ indexInfo = BuildIndexInfo(indexRel); indexInfo->ii_Concurrent = hnswshared->isconcurrent; InitBuildState(&buildstate, heapRel, indexRel, indexInfo, MAIN_FORKNUM); buildstate.graph = &hnswshared->graphData; buildstate.hnswarea = hnswarea; InitAllocator(&buildstate.allocator, &HnswSharedMemoryAlloc, &buildstate); scan = table_beginscan_parallel(heapRel, ParallelTableScanFromHnswShared(hnswshared)); reltuples = table_index_build_scan(heapRel, indexRel, indexInfo, true, progress, BuildCallback, (void *) &buildstate, scan); /* Record statistics */ SpinLockAcquire(&hnswshared->mutex); hnswshared->nparticipantsdone++; hnswshared->reltuples += reltuples; SpinLockRelease(&hnswshared->mutex); /* Log statistics */ if (progress) ereport(DEBUG1, (errmsg("leader processed " INT64_FORMAT " tuples", (int64) reltuples))); else ereport(DEBUG1, (errmsg("worker processed " INT64_FORMAT " tuples", (int64) reltuples))); /* Notify leader */ ConditionVariableSignal(&hnswshared->workersdonecv); FreeBuildState(&buildstate); } /* * Perform work within a launched parallel process */ void HnswParallelBuildMain(dsm_segment *seg, shm_toc *toc) { char *sharedquery; HnswShared *hnswshared; char *hnswarea; Relation heapRel; Relation indexRel; LOCKMODE heapLockmode; LOCKMODE indexLockmode; /* Set debug_query_string for 
individual workers first */ sharedquery = shm_toc_lookup(toc, PARALLEL_KEY_QUERY_TEXT, true); debug_query_string = sharedquery; /* Report the query string from leader */ pgstat_report_activity(STATE_RUNNING, debug_query_string); /* Look up shared state */ hnswshared = shm_toc_lookup(toc, PARALLEL_KEY_HNSW_SHARED, false); /* Open relations using lock modes known to be obtained by index.c */ if (!hnswshared->isconcurrent) { heapLockmode = ShareLock; indexLockmode = AccessExclusiveLock; } else { heapLockmode = ShareUpdateExclusiveLock; indexLockmode = RowExclusiveLock; } /* Open relations within worker */ heapRel = table_open(hnswshared->heaprelid, heapLockmode); indexRel = index_open(hnswshared->indexrelid, indexLockmode); hnswarea = shm_toc_lookup(toc, PARALLEL_KEY_HNSW_AREA, false); /* Perform inserts */ HnswParallelScanAndInsert(heapRel, indexRel, hnswshared, hnswarea, false); /* Close relations within worker */ index_close(indexRel, indexLockmode); table_close(heapRel, heapLockmode); } /* * End parallel build */ static void HnswEndParallel(HnswLeader * hnswleader) { /* Shutdown worker processes */ WaitForParallelWorkersToFinish(hnswleader->pcxt); /* Free last reference to MVCC snapshot, if one was used */ if (IsMVCCSnapshot(hnswleader->snapshot)) UnregisterSnapshot(hnswleader->snapshot); DestroyParallelContext(hnswleader->pcxt); ExitParallelMode(); } /* * Return size of shared memory required for parallel index build */ static Size ParallelEstimateShared(Relation heap, Snapshot snapshot) { return add_size(BUFFERALIGN(sizeof(HnswShared)), table_parallelscan_estimate(heap, snapshot)); } /* * Within leader, participate as a parallel worker */ static void HnswLeaderParticipateAsWorker(HnswBuildState * buildstate) { HnswLeader *hnswleader = buildstate->hnswleader; /* Perform work common to all participants */ HnswParallelScanAndInsert(buildstate->heap, buildstate->index, hnswleader->hnswshared, hnswleader->hnswarea, true); } /* * Begin parallel build */ static void 
HnswBeginParallel(HnswBuildState * buildstate, bool isconcurrent, int request) { ParallelContext *pcxt; Snapshot snapshot; Size esthnswshared; Size esthnswarea; Size estother; HnswShared *hnswshared; char *hnswarea; HnswLeader *hnswleader = (HnswLeader *) palloc0(sizeof(HnswLeader)); bool leaderparticipates = true; int querylen; #ifdef DISABLE_LEADER_PARTICIPATION leaderparticipates = false; #endif /* Enter parallel mode and create context */ EnterParallelMode(); Assert(request > 0); pcxt = CreateParallelContext("vector", "HnswParallelBuildMain", request); /* Get snapshot for table scan */ if (!isconcurrent) snapshot = SnapshotAny; else snapshot = RegisterSnapshot(GetTransactionSnapshot()); /* Estimate size of workspaces */ esthnswshared = ParallelEstimateShared(buildstate->heap, snapshot); shm_toc_estimate_chunk(&pcxt->estimator, esthnswshared); /* Leave space for other objects in shared memory */ /* Docker has a default limit of 64 MB for shm_size */ /* which happens to be the default value of maintenance_work_mem */ esthnswarea = maintenance_work_mem * 1024L; estother = 3 * 1024 * 1024; if (esthnswarea > estother) esthnswarea -= estother; shm_toc_estimate_chunk(&pcxt->estimator, esthnswarea); shm_toc_estimate_keys(&pcxt->estimator, 2); /* Finally, estimate PARALLEL_KEY_QUERY_TEXT space */ if (debug_query_string) { querylen = strlen(debug_query_string); shm_toc_estimate_chunk(&pcxt->estimator, querylen + 1); shm_toc_estimate_keys(&pcxt->estimator, 1); } else querylen = 0; /* keep compiler quiet */ /* Everyone's had a chance to ask for space, so now create the DSM */ InitializeParallelDSM(pcxt); /* If no DSM segment was available, back out (do serial build) */ if (pcxt->seg == NULL) { if (IsMVCCSnapshot(snapshot)) UnregisterSnapshot(snapshot); DestroyParallelContext(pcxt); ExitParallelMode(); return; } /* Store shared build state, for which we reserved space */ hnswshared = (HnswShared *) shm_toc_allocate(pcxt->toc, esthnswshared); /* Initialize immutable state */ 
hnswshared->heaprelid = RelationGetRelid(buildstate->heap); hnswshared->indexrelid = RelationGetRelid(buildstate->index); hnswshared->isconcurrent = isconcurrent; ConditionVariableInit(&hnswshared->workersdonecv); SpinLockInit(&hnswshared->mutex); /* Initialize mutable state */ hnswshared->nparticipantsdone = 0; hnswshared->reltuples = 0; table_parallelscan_initialize(buildstate->heap, ParallelTableScanFromHnswShared(hnswshared), snapshot); hnswarea = (char *) shm_toc_allocate(pcxt->toc, esthnswarea); /* Report less than allocated so never fails */ InitGraph(&hnswshared->graphData, hnswarea, esthnswarea - 1024 * 1024); shm_toc_insert(pcxt->toc, PARALLEL_KEY_HNSW_SHARED, hnswshared); shm_toc_insert(pcxt->toc, PARALLEL_KEY_HNSW_AREA, hnswarea); /* Store query string for workers */ if (debug_query_string) { char *sharedquery; sharedquery = (char *) shm_toc_allocate(pcxt->toc, querylen + 1); memcpy(sharedquery, debug_query_string, querylen + 1); shm_toc_insert(pcxt->toc, PARALLEL_KEY_QUERY_TEXT, sharedquery); } /* Launch workers, saving status for leader/caller */ LaunchParallelWorkers(pcxt); hnswleader->pcxt = pcxt; hnswleader->nparticipanttuplesorts = pcxt->nworkers_launched; if (leaderparticipates) hnswleader->nparticipanttuplesorts++; hnswleader->hnswshared = hnswshared; hnswleader->snapshot = snapshot; hnswleader->hnswarea = hnswarea; /* If no workers were successfully launched, back out (do serial build) */ if (pcxt->nworkers_launched == 0) { HnswEndParallel(hnswleader); return; } /* Log participants */ ereport(DEBUG1, (errmsg("using %d parallel workers", pcxt->nworkers_launched))); /* Save leader state now that it's clear build will be parallel */ buildstate->hnswleader = hnswleader; /* Join heap scan ourselves */ if (leaderparticipates) HnswLeaderParticipateAsWorker(buildstate); /* Wait for all launched workers */ WaitForParallelWorkersToAttach(pcxt); } /* * Compute parallel workers */ static int ComputeParallelWorkers(Relation heap, Relation index) { int 
parallel_workers; /* Make sure it's safe to use parallel workers */ parallel_workers = plan_create_index_workers(RelationGetRelid(heap), RelationGetRelid(index)); if (parallel_workers == 0) return 0; /* Use parallel_workers storage parameter on table if set */ parallel_workers = RelationGetParallelWorkers(heap, -1); if (parallel_workers != -1) return Min(parallel_workers, max_parallel_maintenance_workers); return max_parallel_maintenance_workers; } /* * Build graph */ static void BuildGraph(HnswBuildState * buildstate, ForkNumber forkNum) { int parallel_workers = 0; pgstat_progress_update_param(PROGRESS_CREATEIDX_SUBPHASE, PROGRESS_HNSW_PHASE_LOAD); /* Calculate parallel workers */ if (buildstate->heap != NULL) parallel_workers = ComputeParallelWorkers(buildstate->heap, buildstate->index); /* Attempt to launch parallel worker scan when required */ if (parallel_workers > 0) HnswBeginParallel(buildstate, buildstate->indexInfo->ii_Concurrent, parallel_workers); /* Add tuples to graph */ if (buildstate->heap != NULL) { if (buildstate->hnswleader) buildstate->reltuples = ParallelHeapScan(buildstate); else buildstate->reltuples = table_index_build_scan(buildstate->heap, buildstate->index, buildstate->indexInfo, true, true, BuildCallback, (void *) buildstate, NULL); buildstate->indtuples = buildstate->graph->indtuples; } /* Flush pages */ if (!buildstate->graph->flushed) FlushPages(buildstate); /* End parallel build */ if (buildstate->hnswleader) HnswEndParallel(buildstate->hnswleader); } /* * Build the index */ static void BuildIndex(Relation heap, Relation index, IndexInfo *indexInfo, HnswBuildState * buildstate, ForkNumber forkNum) { #ifdef HNSW_MEMORY SeedRandom(42); #endif InitBuildState(buildstate, heap, index, indexInfo, forkNum); BuildGraph(buildstate, forkNum); if (RelationNeedsWAL(index)) log_newpage_range(index, forkNum, 0, RelationGetNumberOfBlocks(index), true); FreeBuildState(buildstate); } /* * Build the index for a logged table */ IndexBuildResult * 
hnswbuild(Relation heap, Relation index, IndexInfo *indexInfo) { IndexBuildResult *result; HnswBuildState buildstate; BuildIndex(heap, index, indexInfo, &buildstate, MAIN_FORKNUM); result = (IndexBuildResult *) palloc(sizeof(IndexBuildResult)); result->heap_tuples = buildstate.reltuples; result->index_tuples = buildstate.indtuples; return result; } /* * Build the index for an unlogged table */ void hnswbuildempty(Relation index) { IndexInfo *indexInfo = BuildIndexInfo(index); HnswBuildState buildstate; BuildIndex(NULL, index, indexInfo, &buildstate, INIT_FORKNUM); } pgvector-0.6.0/src/hnswinsert.c000066400000000000000000000410261455577216400165300ustar00rootroot00000000000000#include "postgres.h" #include #include "access/generic_xlog.h" #include "hnsw.h" #include "storage/bufmgr.h" #include "storage/lmgr.h" #include "utils/datum.h" #include "utils/memutils.h" /* * Get the insert page */ static BlockNumber GetInsertPage(Relation index) { Buffer buf; Page page; HnswMetaPage metap; BlockNumber insertPage; buf = ReadBuffer(index, HNSW_METAPAGE_BLKNO); LockBuffer(buf, BUFFER_LOCK_SHARE); page = BufferGetPage(buf); metap = HnswPageGetMeta(page); insertPage = metap->insertPage; UnlockReleaseBuffer(buf); return insertPage; } /* * Check for a free offset */ static bool HnswFreeOffset(Relation index, Buffer buf, Page page, HnswElement element, Size ntupSize, Buffer *nbuf, Page *npage, OffsetNumber *freeOffno, OffsetNumber *freeNeighborOffno, BlockNumber *newInsertPage) { OffsetNumber offno; OffsetNumber maxoffno = PageGetMaxOffsetNumber(page); for (offno = FirstOffsetNumber; offno <= maxoffno; offno = OffsetNumberNext(offno)) { HnswElementTuple etup = (HnswElementTuple) PageGetItem(page, PageGetItemId(page, offno)); /* Skip neighbor tuples */ if (!HnswIsElementTuple(etup)) continue; if (etup->deleted) { BlockNumber elementPage = BufferGetBlockNumber(buf); BlockNumber neighborPage = ItemPointerGetBlockNumber(&etup->neighbortid); OffsetNumber neighborOffno = 
ItemPointerGetOffsetNumber(&etup->neighbortid); ItemId itemid; if (!BlockNumberIsValid(*newInsertPage)) *newInsertPage = elementPage; if (neighborPage == elementPage) { *nbuf = buf; *npage = page; } else { *nbuf = ReadBuffer(index, neighborPage); LockBuffer(*nbuf, BUFFER_LOCK_EXCLUSIVE); /* Skip WAL for now */ *npage = BufferGetPage(*nbuf); } itemid = PageGetItemId(*npage, neighborOffno); /* Check for space on neighbor tuple page */ if (PageGetFreeSpace(*npage) + ItemIdGetLength(itemid) - sizeof(ItemIdData) >= ntupSize) { *freeOffno = offno; *freeNeighborOffno = neighborOffno; return true; } else if (*nbuf != buf) UnlockReleaseBuffer(*nbuf); } } return false; } /* * Add a new page */ static void HnswInsertAppendPage(Relation index, Buffer *nbuf, Page *npage, GenericXLogState *state, Page page, bool building) { /* Add a new page */ LockRelationForExtension(index, ExclusiveLock); *nbuf = HnswNewBuffer(index, MAIN_FORKNUM); UnlockRelationForExtension(index, ExclusiveLock); /* Init new page */ if (building) *npage = BufferGetPage(*nbuf); else *npage = GenericXLogRegisterBuffer(state, *nbuf, GENERIC_XLOG_FULL_IMAGE); HnswInitPage(*nbuf, *npage); /* Update previous buffer */ HnswPageGetOpaque(page)->nextblkno = BufferGetBlockNumber(*nbuf); } /* * Add to element and neighbor pages */ static void AddElementOnDisk(Relation index, HnswElement e, int m, BlockNumber insertPage, BlockNumber *updatedInsertPage, bool building) { Buffer buf; Page page; GenericXLogState *state; Size etupSize; Size ntupSize; Size combinedSize; Size maxSize; Size minCombinedSize; HnswElementTuple etup; BlockNumber currentPage = insertPage; HnswNeighborTuple ntup; Buffer nbuf; Page npage; OffsetNumber freeOffno = InvalidOffsetNumber; OffsetNumber freeNeighborOffno = InvalidOffsetNumber; BlockNumber newInsertPage = InvalidBlockNumber; char *base = NULL; /* Calculate sizes */ etupSize = HNSW_ELEMENT_TUPLE_SIZE(VARSIZE_ANY(HnswPtrAccess(base, e->value))); ntupSize = HNSW_NEIGHBOR_TUPLE_SIZE(e->level, m); 
combinedSize = etupSize + ntupSize + sizeof(ItemIdData); maxSize = HNSW_MAX_SIZE; minCombinedSize = etupSize + HNSW_NEIGHBOR_TUPLE_SIZE(0, m) + sizeof(ItemIdData); /* Prepare element tuple */ etup = palloc0(etupSize); HnswSetElementTuple(base, etup, e); /* Prepare neighbor tuple */ ntup = palloc0(ntupSize); HnswSetNeighborTuple(base, ntup, e, m); /* Find a page (or two if needed) to insert the tuples */ for (;;) { buf = ReadBuffer(index, currentPage); LockBuffer(buf, BUFFER_LOCK_EXCLUSIVE); if (building) { state = NULL; page = BufferGetPage(buf); } else { state = GenericXLogStart(index); page = GenericXLogRegisterBuffer(state, buf, 0); } /* Keep track of first page where element at level 0 can fit */ if (!BlockNumberIsValid(newInsertPage) && PageGetFreeSpace(page) >= minCombinedSize) newInsertPage = currentPage; /* First, try the fastest path */ /* Space for both tuples on the current page */ /* This can split existing tuples in rare cases */ if (PageGetFreeSpace(page) >= combinedSize) { nbuf = buf; npage = page; break; } /* Next, try space from a deleted element */ if (HnswFreeOffset(index, buf, page, e, ntupSize, &nbuf, &npage, &freeOffno, &freeNeighborOffno, &newInsertPage)) { if (nbuf != buf) { if (building) npage = BufferGetPage(nbuf); else npage = GenericXLogRegisterBuffer(state, nbuf, 0); } break; } /* Finally, try space for element only if last page */ /* Skip if both tuples can fit on the same page */ if (combinedSize > maxSize && PageGetFreeSpace(page) >= etupSize && !BlockNumberIsValid(HnswPageGetOpaque(page)->nextblkno)) { HnswInsertAppendPage(index, &nbuf, &npage, state, page, building); break; } currentPage = HnswPageGetOpaque(page)->nextblkno; if (BlockNumberIsValid(currentPage)) { /* Move to next page */ if (!building) GenericXLogAbort(state); UnlockReleaseBuffer(buf); } else { Buffer newbuf; Page newpage; HnswInsertAppendPage(index, &newbuf, &newpage, state, page, building); /* Commit */ if (building) MarkBufferDirty(buf); else 
GenericXLogFinish(state); /* Unlock previous buffer */ UnlockReleaseBuffer(buf); /* Prepare new buffer */ buf = newbuf; if (building) { state = NULL; page = BufferGetPage(buf); } else { state = GenericXLogStart(index); page = GenericXLogRegisterBuffer(state, buf, 0); } /* Create new page for neighbors if needed */ if (PageGetFreeSpace(page) < combinedSize) HnswInsertAppendPage(index, &nbuf, &npage, state, page, building); else { nbuf = buf; npage = page; } break; } } e->blkno = BufferGetBlockNumber(buf); e->neighborPage = BufferGetBlockNumber(nbuf); /* Added tuple to new page if newInsertPage is not set */ /* So can set to neighbor page instead of element page */ if (!BlockNumberIsValid(newInsertPage)) newInsertPage = e->neighborPage; if (OffsetNumberIsValid(freeOffno)) { e->offno = freeOffno; e->neighborOffno = freeNeighborOffno; } else { e->offno = OffsetNumberNext(PageGetMaxOffsetNumber(page)); if (nbuf == buf) e->neighborOffno = OffsetNumberNext(e->offno); else e->neighborOffno = FirstOffsetNumber; } ItemPointerSet(&etup->neighbortid, e->neighborPage, e->neighborOffno); /* Add element and neighbors */ if (OffsetNumberIsValid(freeOffno)) { if (!PageIndexTupleOverwrite(page, e->offno, (Item) etup, etupSize)) elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index)); if (!PageIndexTupleOverwrite(npage, e->neighborOffno, (Item) ntup, ntupSize)) elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index)); } else { if (PageAddItem(page, (Item) etup, etupSize, InvalidOffsetNumber, false, false) != e->offno) elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index)); if (PageAddItem(npage, (Item) ntup, ntupSize, InvalidOffsetNumber, false, false) != e->neighborOffno) elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index)); } /* Commit */ if (building) { MarkBufferDirty(buf); if (nbuf != buf) MarkBufferDirty(nbuf); } else GenericXLogFinish(state); UnlockReleaseBuffer(buf); if 
(nbuf != buf) UnlockReleaseBuffer(nbuf); /* Update the insert page */ if (BlockNumberIsValid(newInsertPage) && newInsertPage != insertPage) *updatedInsertPage = newInsertPage; } /* * Check if connection already exists */ static bool ConnectionExists(HnswElement e, HnswNeighborTuple ntup, int startIdx, int lm) { for (int i = 0; i < lm; i++) { ItemPointer indextid = &ntup->indextids[startIdx + i]; if (!ItemPointerIsValid(indextid)) break; if (ItemPointerGetBlockNumber(indextid) == e->blkno && ItemPointerGetOffsetNumber(indextid) == e->offno) return true; } return false; } /* * Update neighbors */ void HnswUpdateNeighborsOnDisk(Relation index, FmgrInfo *procinfo, Oid collation, HnswElement e, int m, bool checkExisting, bool building) { char *base = NULL; for (int lc = e->level; lc >= 0; lc--) { int lm = HnswGetLayerM(m, lc); HnswNeighborArray *neighbors = HnswGetNeighbors(base, e, lc); for (int i = 0; i < neighbors->length; i++) { HnswCandidate *hc = &neighbors->items[i]; Buffer buf; Page page; GenericXLogState *state; HnswNeighborTuple ntup; int idx = -1; int startIdx; HnswElement neighborElement = HnswPtrAccess(base, hc->element); OffsetNumber offno = neighborElement->neighborOffno; /* Get latest neighbors since they may have changed */ /* Do not lock yet since selecting neighbors can take time */ HnswLoadNeighbors(neighborElement, index, m); /* * Could improve performance for vacuuming by checking neighbors * against list of elements being deleted to find index. It's * important to exclude already deleted elements for this since * they can be replaced at any time. 
*/ /* Select neighbors */ HnswUpdateConnection(NULL, e, hc, lm, lc, &idx, index, procinfo, collation); /* New element was not selected as a neighbor */ if (idx == -1) continue; /* Register page */ buf = ReadBuffer(index, neighborElement->neighborPage); LockBuffer(buf, BUFFER_LOCK_EXCLUSIVE); if (building) { state = NULL; page = BufferGetPage(buf); } else { state = GenericXLogStart(index); page = GenericXLogRegisterBuffer(state, buf, 0); } /* Get tuple */ ntup = (HnswNeighborTuple) PageGetItem(page, PageGetItemId(page, offno)); /* Calculate index for update */ startIdx = (neighborElement->level - lc) * m; /* Check for existing connection */ if (checkExisting && ConnectionExists(e, ntup, startIdx, lm)) idx = -1; else if (idx == -2) { /* Find free offset if still exists */ /* TODO Retry updating connections if not */ for (int j = 0; j < lm; j++) { if (!ItemPointerIsValid(&ntup->indextids[startIdx + j])) { idx = startIdx + j; break; } } } else idx += startIdx; /* Make robust to issues */ if (idx >= 0 && idx < ntup->count) { ItemPointer indextid = &ntup->indextids[idx]; /* Update neighbor on the buffer */ ItemPointerSet(indextid, e->blkno, e->offno); /* Commit */ if (building) MarkBufferDirty(buf); else GenericXLogFinish(state); } else if (!building) GenericXLogAbort(state); UnlockReleaseBuffer(buf); } } } /* * Add a heap TID to an existing element */ static bool AddDuplicateOnDisk(Relation index, HnswElement element, HnswElement dup, bool building) { Buffer buf; Page page; GenericXLogState *state; HnswElementTuple etup; int i; /* Read page */ buf = ReadBuffer(index, dup->blkno); LockBuffer(buf, BUFFER_LOCK_EXCLUSIVE); if (building) { state = NULL; page = BufferGetPage(buf); } else { state = GenericXLogStart(index); page = GenericXLogRegisterBuffer(state, buf, 0); } /* Find space */ etup = (HnswElementTuple) PageGetItem(page, PageGetItemId(page, dup->offno)); for (i = 0; i < HNSW_HEAPTIDS; i++) { if (!ItemPointerIsValid(&etup->heaptids[i])) break; } /* Either being 
deleted or we lost our chance to another backend */ if (i == 0 || i == HNSW_HEAPTIDS) { if (!building) GenericXLogAbort(state); UnlockReleaseBuffer(buf); return false; } /* Add heap TID, modifying the tuple on the page directly */ etup->heaptids[i] = element->heaptids[0]; /* Commit */ if (building) MarkBufferDirty(buf); else GenericXLogFinish(state); UnlockReleaseBuffer(buf); return true; } /* * Find duplicate element */ static bool FindDuplicateOnDisk(Relation index, HnswElement element, bool building) { char *base = NULL; HnswNeighborArray *neighbors = HnswGetNeighbors(base, element, 0); Datum value = HnswGetValue(base, element); for (int i = 0; i < neighbors->length; i++) { HnswCandidate *neighbor = &neighbors->items[i]; HnswElement neighborElement = HnswPtrAccess(base, neighbor->element); Datum neighborValue = HnswGetValue(base, neighborElement); /* Exit early since ordered by distance */ if (!datumIsEqual(value, neighborValue, false, -1)) return false; if (AddDuplicateOnDisk(index, element, neighborElement, building)) return true; } return false; } /* * Update graph on disk */ static void UpdateGraphOnDisk(Relation index, FmgrInfo *procinfo, Oid collation, HnswElement element, int m, int efConstruction, HnswElement entryPoint, bool building) { BlockNumber newInsertPage = InvalidBlockNumber; /* Look for duplicate */ if (FindDuplicateOnDisk(index, element, building)) return; /* Add element */ AddElementOnDisk(index, element, m, GetInsertPage(index), &newInsertPage, building); /* Update insert page if needed */ if (BlockNumberIsValid(newInsertPage)) HnswUpdateMetaPage(index, 0, NULL, newInsertPage, MAIN_FORKNUM, building); /* Update neighbors */ HnswUpdateNeighborsOnDisk(index, procinfo, collation, element, m, false, building); /* Update entry point if needed */ if (entryPoint == NULL || element->level > entryPoint->level) HnswUpdateMetaPage(index, HNSW_UPDATE_ENTRY_GREATER, element, InvalidBlockNumber, MAIN_FORKNUM, building); } /* * Insert a tuple into the 
index */ bool HnswInsertTupleOnDisk(Relation index, Datum value, Datum *values, bool *isnull, ItemPointer heap_tid, bool building) { HnswElement entryPoint; HnswElement element; int m; int efConstruction = HnswGetEfConstruction(index); FmgrInfo *procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC); Oid collation = index->rd_indcollation[0]; LOCKMODE lockmode = ShareLock; char *base = NULL; /* * Get a shared lock. This allows vacuum to ensure no in-flight inserts * before repairing graph. Use a page lock so it does not interfere with * buffer lock (or reads when vacuuming). */ LockPage(index, HNSW_UPDATE_LOCK, lockmode); /* Get m and entry point */ HnswGetMetaPageInfo(index, &m, &entryPoint); /* Create an element */ element = HnswInitElement(base, heap_tid, m, HnswGetMl(m), HnswGetMaxLevel(m), NULL); HnswPtrStore(base, element->value, DatumGetPointer(value)); /* Prevent concurrent inserts when likely updating entry point */ if (entryPoint == NULL || element->level > entryPoint->level) { /* Release shared lock */ UnlockPage(index, HNSW_UPDATE_LOCK, lockmode); /* Get exclusive lock */ lockmode = ExclusiveLock; LockPage(index, HNSW_UPDATE_LOCK, lockmode); /* Get latest entry point after lock is acquired */ entryPoint = HnswGetEntryPoint(index); } /* Find neighbors for element */ HnswFindElementNeighbors(base, element, entryPoint, index, procinfo, collation, m, efConstruction, false); /* Update graph on disk */ UpdateGraphOnDisk(index, procinfo, collation, element, m, efConstruction, entryPoint, building); /* Release lock */ UnlockPage(index, HNSW_UPDATE_LOCK, lockmode); return true; } /* * Insert a tuple into the index */ static void HnswInsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_tid) { Datum value; FmgrInfo *normprocinfo; Oid collation = index->rd_indcollation[0]; /* Detoast once for all calls */ value = PointerGetDatum(PG_DETOAST_DATUM(values[0])); /* Normalize if needed */ normprocinfo = HnswOptionalProcInfo(index, 
HNSW_NORM_PROC); if (normprocinfo != NULL) { if (!HnswNormValue(normprocinfo, collation, &value, NULL)) return; } HnswInsertTupleOnDisk(index, value, values, isnull, heap_tid, false); } /* * Insert a tuple into the index */ bool hnswinsert(Relation index, Datum *values, bool *isnull, ItemPointer heap_tid, Relation heap, IndexUniqueCheck checkUnique #if PG_VERSION_NUM >= 140000 ,bool indexUnchanged #endif ,IndexInfo *indexInfo ) { MemoryContext oldCtx; MemoryContext insertCtx; /* Skip nulls */ if (isnull[0]) return false; /* Create memory context */ insertCtx = AllocSetContextCreate(CurrentMemoryContext, "Hnsw insert temporary context", ALLOCSET_DEFAULT_SIZES); oldCtx = MemoryContextSwitchTo(insertCtx); /* Insert tuple */ HnswInsertTuple(index, values, isnull, heap_tid); /* Delete memory context */ MemoryContextSwitchTo(oldCtx); MemoryContextDelete(insertCtx); return false; } pgvector-0.6.0/src/hnswscan.c000066400000000000000000000120601455577216400161440ustar00rootroot00000000000000#include "postgres.h" #include "access/relscan.h" #include "hnsw.h" #include "pgstat.h" #include "storage/bufmgr.h" #include "storage/lmgr.h" #include "utils/memutils.h" /* * Algorithm 5 from paper */ static List * GetScanItems(IndexScanDesc scan, Datum q) { HnswScanOpaque so = (HnswScanOpaque) scan->opaque; Relation index = scan->indexRelation; FmgrInfo *procinfo = so->procinfo; Oid collation = so->collation; List *ep; List *w; int m; HnswElement entryPoint; char *base = NULL; /* Get m and entry point */ HnswGetMetaPageInfo(index, &m, &entryPoint); if (entryPoint == NULL) return NIL; ep = list_make1(HnswEntryCandidate(base, entryPoint, q, index, procinfo, collation, false)); for (int lc = entryPoint->level; lc >= 1; lc--) { w = HnswSearchLayer(base, q, ep, 1, lc, index, procinfo, collation, m, false, NULL); ep = w; } return HnswSearchLayer(base, q, ep, hnsw_ef_search, 0, index, procinfo, collation, m, false, NULL); } /* * Get dimensions from metapage */ static int GetDimensions(Relation 
index) { Buffer buf; Page page; HnswMetaPage metap; int dimensions; buf = ReadBuffer(index, HNSW_METAPAGE_BLKNO); LockBuffer(buf, BUFFER_LOCK_SHARE); page = BufferGetPage(buf); metap = HnswPageGetMeta(page); dimensions = metap->dimensions; UnlockReleaseBuffer(buf); return dimensions; } /* * Get scan value */ static Datum GetScanValue(IndexScanDesc scan) { HnswScanOpaque so = (HnswScanOpaque) scan->opaque; Datum value; if (scan->orderByData->sk_flags & SK_ISNULL) value = PointerGetDatum(InitVector(GetDimensions(scan->indexRelation))); else { value = scan->orderByData->sk_argument; /* Value should not be compressed or toasted */ Assert(!VARATT_IS_COMPRESSED(DatumGetPointer(value))); Assert(!VARATT_IS_EXTENDED(DatumGetPointer(value))); /* Fine if normalization fails */ if (so->normprocinfo != NULL) HnswNormValue(so->normprocinfo, so->collation, &value, NULL); } return value; } /* * Prepare for an index scan */ IndexScanDesc hnswbeginscan(Relation index, int nkeys, int norderbys) { IndexScanDesc scan; HnswScanOpaque so; scan = RelationGetIndexScan(index, nkeys, norderbys); so = (HnswScanOpaque) palloc(sizeof(HnswScanOpaqueData)); so->first = true; so->tmpCtx = AllocSetContextCreate(CurrentMemoryContext, "Hnsw scan temporary context", ALLOCSET_DEFAULT_SIZES); /* Set support functions */ so->procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC); so->normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC); so->collation = index->rd_indcollation[0]; scan->opaque = so; return scan; } /* * Start or restart an index scan */ void hnswrescan(IndexScanDesc scan, ScanKey keys, int nkeys, ScanKey orderbys, int norderbys) { HnswScanOpaque so = (HnswScanOpaque) scan->opaque; so->first = true; MemoryContextReset(so->tmpCtx); if (keys && scan->numberOfKeys > 0) memmove(scan->keyData, keys, scan->numberOfKeys * sizeof(ScanKeyData)); if (orderbys && scan->numberOfOrderBys > 0) memmove(scan->orderByData, orderbys, scan->numberOfOrderBys * sizeof(ScanKeyData)); } /* * Fetch the 
next tuple in the given scan */ bool hnswgettuple(IndexScanDesc scan, ScanDirection dir) { HnswScanOpaque so = (HnswScanOpaque) scan->opaque; MemoryContext oldCtx = MemoryContextSwitchTo(so->tmpCtx); /* * Index can be used to scan backward, but Postgres doesn't support * backward scan on operators */ Assert(ScanDirectionIsForward(dir)); if (so->first) { Datum value; /* Count index scan for stats */ pgstat_count_index_scan(scan->indexRelation); /* Safety check */ if (scan->orderByData == NULL) elog(ERROR, "cannot scan hnsw index without order"); /* Requires MVCC-compliant snapshot as not able to maintain a pin */ /* https://www.postgresql.org/docs/current/index-locking.html */ if (!IsMVCCSnapshot(scan->xs_snapshot)) elog(ERROR, "non-MVCC snapshots are not supported with hnsw"); /* Get scan value */ value = GetScanValue(scan); /* * Get a shared lock. This allows vacuum to ensure no in-flight scans * before marking tuples as deleted. */ LockPage(scan->indexRelation, HNSW_SCAN_LOCK, ShareLock); so->w = GetScanItems(scan, value); /* Release shared lock */ UnlockPage(scan->indexRelation, HNSW_SCAN_LOCK, ShareLock); so->first = false; } while (list_length(so->w) > 0) { char *base = NULL; HnswCandidate *hc = llast(so->w); HnswElement element = HnswPtrAccess(base, hc->element); ItemPointer heaptid; /* Move to next element if no valid heap TIDs */ if (element->heaptidsLength == 0) { so->w = list_delete_last(so->w); continue; } heaptid = &element->heaptids[--element->heaptidsLength]; MemoryContextSwitchTo(oldCtx); scan->xs_heaptid = *heaptid; scan->xs_recheck = false; scan->xs_recheckorderby = false; return true; } MemoryContextSwitchTo(oldCtx); return false; } /* * End a scan and release resources */ void hnswendscan(IndexScanDesc scan) { HnswScanOpaque so = (HnswScanOpaque) scan->opaque; MemoryContextDelete(so->tmpCtx); pfree(so); scan->opaque = NULL; } pgvector-0.6.0/src/hnswutils.c000066400000000000000000000711111455577216400163620ustar00rootroot00000000000000#include 
"postgres.h" #include #include "access/generic_xlog.h" #include "hnsw.h" #include "lib/pairingheap.h" #include "storage/bufmgr.h" #include "utils/datum.h" #include "utils/memdebug.h" #include "utils/rel.h" #include "vector.h" #if PG_VERSION_NUM >= 130000 #include "common/hashfn.h" #else #include "utils/hashutils.h" #endif #if PG_VERSION_NUM < 170000 static inline uint64 murmurhash64(uint64 data) { uint64 h = data; h ^= h >> 33; h *= 0xff51afd7ed558ccd; h ^= h >> 33; h *= 0xc4ceb9fe1a85ec53; h ^= h >> 33; return h; } #endif /* TID hash table */ static uint32 hash_tid(ItemPointerData tid) { union { uint64 i; ItemPointerData tid; } x; /* Initialize unused bytes */ x.i = 0; x.tid = tid; return murmurhash64(x.i); } #define SH_PREFIX tidhash #define SH_ELEMENT_TYPE TidHashEntry #define SH_KEY_TYPE ItemPointerData #define SH_KEY tid #define SH_HASH_KEY(tb, key) hash_tid(key) #define SH_EQUAL(tb, a, b) ItemPointerEquals(&a, &b) #define SH_SCOPE extern #define SH_DEFINE #include "lib/simplehash.h" /* Pointer hash table */ static uint32 hash_pointer(uintptr_t ptr) { #if SIZEOF_VOID_P == 8 return murmurhash64((uint64) ptr); #else return murmurhash32((uint32) ptr); #endif } #define SH_PREFIX pointerhash #define SH_ELEMENT_TYPE PointerHashEntry #define SH_KEY_TYPE uintptr_t #define SH_KEY ptr #define SH_HASH_KEY(tb, key) hash_pointer(key) #define SH_EQUAL(tb, a, b) (a == b) #define SH_SCOPE extern #define SH_DEFINE #include "lib/simplehash.h" /* Offset hash table */ static uint32 hash_offset(Size offset) { #if SIZEOF_SIZE_T == 8 return murmurhash64((uint64) offset); #else return murmurhash32((uint32) offset); #endif } #define SH_PREFIX offsethash #define SH_ELEMENT_TYPE OffsetHashEntry #define SH_KEY_TYPE Size #define SH_KEY offset #define SH_HASH_KEY(tb, key) hash_offset(key) #define SH_EQUAL(tb, a, b) (a == b) #define SH_SCOPE extern #define SH_DEFINE #include "lib/simplehash.h" typedef union { pointerhash_hash *pointers; offsethash_hash *offsets; tidhash_hash *tids; } 
visited_hash; /* * Get the max number of connections in an upper layer for each element in the index */ int HnswGetM(Relation index) { HnswOptions *opts = (HnswOptions *) index->rd_options; if (opts) return opts->m; return HNSW_DEFAULT_M; } /* * Get the size of the dynamic candidate list in the index */ int HnswGetEfConstruction(Relation index) { HnswOptions *opts = (HnswOptions *) index->rd_options; if (opts) return opts->efConstruction; return HNSW_DEFAULT_EF_CONSTRUCTION; } /* * Get proc */ FmgrInfo * HnswOptionalProcInfo(Relation index, uint16 procnum) { if (!OidIsValid(index_getprocid(index, 1, procnum))) return NULL; return index_getprocinfo(index, 1, procnum); } /* * Divide by the norm * * Returns false if value should not be indexed * * The caller needs to free the pointer stored in value * if it's different than the original value */ bool HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, Vector * result) { double norm = DatumGetFloat8(FunctionCall1Coll(procinfo, collation, *value)); if (norm > 0) { Vector *v = DatumGetVector(*value); if (result == NULL) result = InitVector(v->dim); for (int i = 0; i < v->dim; i++) result->x[i] = v->x[i] / norm; *value = PointerGetDatum(result); return true; } return false; } /* * New buffer */ Buffer HnswNewBuffer(Relation index, ForkNumber forkNum) { Buffer buf = ReadBufferExtended(index, forkNum, P_NEW, RBM_NORMAL, NULL); LockBuffer(buf, BUFFER_LOCK_EXCLUSIVE); return buf; } /* * Init page */ void HnswInitPage(Buffer buf, Page page) { PageInit(page, BufferGetPageSize(buf), sizeof(HnswPageOpaqueData)); HnswPageGetOpaque(page)->nextblkno = InvalidBlockNumber; HnswPageGetOpaque(page)->page_id = HNSW_PAGE_ID; } /* * Allocate a neighbor array */ static HnswNeighborArray * HnswInitNeighborArray(int lm, HnswAllocator * allocator) { HnswNeighborArray *a = HnswAlloc(allocator, HNSW_NEIGHBOR_ARRAY_SIZE(lm)); a->length = 0; a->closerSet = false; return a; } /* * Allocate neighbors */ void HnswInitNeighbors(char *base, 
HnswElement element, int m, HnswAllocator * allocator) { int level = element->level; HnswNeighborArrayPtr *neighborList = (HnswNeighborArrayPtr *) HnswAlloc(allocator, sizeof(HnswNeighborArrayPtr) * (level + 1)); HnswPtrStore(base, element->neighbors, neighborList); for (int lc = 0; lc <= level; lc++) HnswPtrStore(base, neighborList[lc], HnswInitNeighborArray(HnswGetLayerM(m, lc), allocator)); } /* * Allocate memory from the allocator */ void * HnswAlloc(HnswAllocator * allocator, Size size) { if (allocator) return (*(allocator)->alloc) (size, (allocator)->state); return palloc(size); } /* * Allocate an element */ HnswElement HnswInitElement(char *base, ItemPointer heaptid, int m, double ml, int maxLevel, HnswAllocator * allocator) { HnswElement element = HnswAlloc(allocator, sizeof(HnswElementData)); int level = (int) (-log(RandomDouble()) * ml); /* Cap level */ if (level > maxLevel) level = maxLevel; element->heaptidsLength = 0; HnswAddHeapTid(element, heaptid); element->level = level; element->deleted = 0; HnswInitNeighbors(base, element, m, allocator); HnswPtrStore(base, element->value, (Pointer) NULL); return element; } /* * Add a heap TID to an element */ void HnswAddHeapTid(HnswElement element, ItemPointer heaptid) { element->heaptids[element->heaptidsLength++] = *heaptid; } /* * Allocate an element from block and offset numbers */ HnswElement HnswInitElementFromBlock(BlockNumber blkno, OffsetNumber offno) { HnswElement element = palloc(sizeof(HnswElementData)); char *base = NULL; element->blkno = blkno; element->offno = offno; HnswPtrStore(base, element->neighbors, (HnswNeighborArrayPtr *) NULL); HnswPtrStore(base, element->value, (Pointer) NULL); return element; } /* * Get the metapage info */ void HnswGetMetaPageInfo(Relation index, int *m, HnswElement * entryPoint) { Buffer buf; Page page; HnswMetaPage metap; buf = ReadBuffer(index, HNSW_METAPAGE_BLKNO); LockBuffer(buf, BUFFER_LOCK_SHARE); page = BufferGetPage(buf); metap = HnswPageGetMeta(page); if (m 
!= NULL) *m = metap->m; if (entryPoint != NULL) { if (BlockNumberIsValid(metap->entryBlkno)) { *entryPoint = HnswInitElementFromBlock(metap->entryBlkno, metap->entryOffno); (*entryPoint)->level = metap->entryLevel; } else *entryPoint = NULL; } UnlockReleaseBuffer(buf); } /* * Get the entry point */ HnswElement HnswGetEntryPoint(Relation index) { HnswElement entryPoint; HnswGetMetaPageInfo(index, NULL, &entryPoint); return entryPoint; } /* * Update the metapage info */ static void HnswUpdateMetaPageInfo(Page page, int updateEntry, HnswElement entryPoint, BlockNumber insertPage) { HnswMetaPage metap = HnswPageGetMeta(page); if (updateEntry) { if (entryPoint == NULL) { metap->entryBlkno = InvalidBlockNumber; metap->entryOffno = InvalidOffsetNumber; metap->entryLevel = -1; } else if (entryPoint->level > metap->entryLevel || updateEntry == HNSW_UPDATE_ENTRY_ALWAYS) { metap->entryBlkno = entryPoint->blkno; metap->entryOffno = entryPoint->offno; metap->entryLevel = entryPoint->level; } } if (BlockNumberIsValid(insertPage)) metap->insertPage = insertPage; } /* * Update the metapage */ void HnswUpdateMetaPage(Relation index, int updateEntry, HnswElement entryPoint, BlockNumber insertPage, ForkNumber forkNum, bool building) { Buffer buf; Page page; GenericXLogState *state; buf = ReadBufferExtended(index, forkNum, HNSW_METAPAGE_BLKNO, RBM_NORMAL, NULL); LockBuffer(buf, BUFFER_LOCK_EXCLUSIVE); if (building) { state = NULL; page = BufferGetPage(buf); } else { state = GenericXLogStart(index); page = GenericXLogRegisterBuffer(state, buf, 0); } HnswUpdateMetaPageInfo(page, updateEntry, entryPoint, insertPage); if (building) MarkBufferDirty(buf); else GenericXLogFinish(state); UnlockReleaseBuffer(buf); } /* * Set element tuple, except for neighbor info */ void HnswSetElementTuple(char *base, HnswElementTuple etup, HnswElement element) { Pointer valuePtr = HnswPtrAccess(base, element->value); etup->type = HNSW_ELEMENT_TUPLE_TYPE; etup->level = element->level; etup->deleted = 0; for 
(int i = 0; i < HNSW_HEAPTIDS; i++) { if (i < element->heaptidsLength) etup->heaptids[i] = element->heaptids[i]; else ItemPointerSetInvalid(&etup->heaptids[i]); } memcpy(&etup->data, valuePtr, VARSIZE_ANY(valuePtr)); } /* * Set neighbor tuple */ void HnswSetNeighborTuple(char *base, HnswNeighborTuple ntup, HnswElement e, int m) { int idx = 0; ntup->type = HNSW_NEIGHBOR_TUPLE_TYPE; for (int lc = e->level; lc >= 0; lc--) { HnswNeighborArray *neighbors = HnswGetNeighbors(base, e, lc); int lm = HnswGetLayerM(m, lc); for (int i = 0; i < lm; i++) { ItemPointer indextid = &ntup->indextids[idx++]; if (i < neighbors->length) { HnswCandidate *hc = &neighbors->items[i]; HnswElement hce = HnswPtrAccess(base, hc->element); ItemPointerSet(indextid, hce->blkno, hce->offno); } else ItemPointerSetInvalid(indextid); } } ntup->count = idx; } /* * Load neighbors from page */ static void LoadNeighborsFromPage(HnswElement element, Relation index, Page page, int m) { char *base = NULL; HnswNeighborTuple ntup = (HnswNeighborTuple) PageGetItem(page, PageGetItemId(page, element->neighborOffno)); int neighborCount = (element->level + 2) * m; Assert(HnswIsNeighborTuple(ntup)); HnswInitNeighbors(base, element, m, NULL); /* Ensure expected neighbors */ if (ntup->count != neighborCount) return; for (int i = 0; i < neighborCount; i++) { HnswElement e; int level; HnswCandidate *hc; ItemPointer indextid; HnswNeighborArray *neighbors; indextid = &ntup->indextids[i]; if (!ItemPointerIsValid(indextid)) continue; e = HnswInitElementFromBlock(ItemPointerGetBlockNumber(indextid), ItemPointerGetOffsetNumber(indextid)); /* Calculate level based on offset */ level = element->level - i / m; if (level < 0) level = 0; neighbors = HnswGetNeighbors(base, element, level); hc = &neighbors->items[neighbors->length++]; HnswPtrStore(base, hc->element, e); } } /* * Load neighbors */ void HnswLoadNeighbors(HnswElement element, Relation index, int m) { Buffer buf; Page page; buf = ReadBuffer(index, 
element->neighborPage); LockBuffer(buf, BUFFER_LOCK_SHARE); page = BufferGetPage(buf); LoadNeighborsFromPage(element, index, page, m); UnlockReleaseBuffer(buf); } /* * Load an element from a tuple */ void HnswLoadElementFromTuple(HnswElement element, HnswElementTuple etup, bool loadHeaptids, bool loadVec) { element->level = etup->level; element->deleted = etup->deleted; element->neighborPage = ItemPointerGetBlockNumber(&etup->neighbortid); element->neighborOffno = ItemPointerGetOffsetNumber(&etup->neighbortid); element->heaptidsLength = 0; if (loadHeaptids) { for (int i = 0; i < HNSW_HEAPTIDS; i++) { /* Can stop at first invalid */ if (!ItemPointerIsValid(&etup->heaptids[i])) break; HnswAddHeapTid(element, &etup->heaptids[i]); } } if (loadVec) { char *base = NULL; Datum value = datumCopy(PointerGetDatum(&etup->data), false, -1); HnswPtrStore(base, element->value, DatumGetPointer(value)); } } /* * Load an element and optionally get its distance from q */ void HnswLoadElement(HnswElement element, float *distance, Datum *q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadVec) { Buffer buf; Page page; HnswElementTuple etup; /* Read vector */ buf = ReadBuffer(index, element->blkno); LockBuffer(buf, BUFFER_LOCK_SHARE); page = BufferGetPage(buf); etup = (HnswElementTuple) PageGetItem(page, PageGetItemId(page, element->offno)); Assert(HnswIsElementTuple(etup)); /* Load element */ HnswLoadElementFromTuple(element, etup, true, loadVec); /* Calculate distance */ if (distance != NULL) *distance = (float) DatumGetFloat8(FunctionCall2Coll(procinfo, collation, *q, PointerGetDatum(&etup->data))); UnlockReleaseBuffer(buf); } /* * Get the distance for a candidate */ static float GetCandidateDistance(char *base, HnswCandidate * hc, Datum q, FmgrInfo *procinfo, Oid collation) { HnswElement hce = HnswPtrAccess(base, hc->element); Datum value = HnswGetValue(base, hce); return DatumGetFloat8(FunctionCall2Coll(procinfo, collation, q, value)); } /* * Create a candidate for the 
entry point */ HnswCandidate * HnswEntryCandidate(char *base, HnswElement entryPoint, Datum q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadVec) { HnswCandidate *hc = palloc(sizeof(HnswCandidate)); HnswPtrStore(base, hc->element, entryPoint); if (index == NULL) hc->distance = GetCandidateDistance(base, hc, q, procinfo, collation); else HnswLoadElement(entryPoint, &hc->distance, &q, index, procinfo, collation, loadVec); return hc; } /* * Compare candidate distances */ static int CompareNearestCandidates(const pairingheap_node *a, const pairingheap_node *b, void *arg) { if (((const HnswPairingHeapNode *) a)->inner->distance < ((const HnswPairingHeapNode *) b)->inner->distance) return 1; if (((const HnswPairingHeapNode *) a)->inner->distance > ((const HnswPairingHeapNode *) b)->inner->distance) return -1; return 0; } /* * Compare candidate distances */ static int CompareFurthestCandidates(const pairingheap_node *a, const pairingheap_node *b, void *arg) { if (((const HnswPairingHeapNode *) a)->inner->distance < ((const HnswPairingHeapNode *) b)->inner->distance) return -1; if (((const HnswPairingHeapNode *) a)->inner->distance > ((const HnswPairingHeapNode *) b)->inner->distance) return 1; return 0; } /* * Create a pairing heap node for a candidate */ static HnswPairingHeapNode * CreatePairingHeapNode(HnswCandidate * c) { HnswPairingHeapNode *node = palloc(sizeof(HnswPairingHeapNode)); node->inner = c; return node; } /* * Init visited */ static inline void InitVisited(char *base, visited_hash * v, Relation index, int ef, int m) { if (index != NULL) v->tids = tidhash_create(CurrentMemoryContext, ef * m * 2, NULL); else if (base != NULL) v->offsets = offsethash_create(CurrentMemoryContext, ef * m * 2, NULL); else v->pointers = pointerhash_create(CurrentMemoryContext, ef * m * 2, NULL); } /* * Add to visited */ static inline void AddToVisited(char *base, visited_hash * v, HnswCandidate * hc, Relation index, bool *found) { if (index != NULL) { HnswElement 
element = HnswPtrAccess(base, hc->element); ItemPointerData indextid; ItemPointerSet(&indextid, element->blkno, element->offno); tidhash_insert(v->tids, indextid, found); } else if (base != NULL) { #if PG_VERSION_NUM >= 130000 HnswElement element = HnswPtrAccess(base, hc->element); offsethash_insert_hash(v->offsets, HnswPtrOffset(hc->element), element->hash, found); #else offsethash_insert(v->offsets, HnswPtrOffset(hc->element), found); #endif } else { #if PG_VERSION_NUM >= 130000 HnswElement element = HnswPtrAccess(base, hc->element); pointerhash_insert_hash(v->pointers, (uintptr_t) HnswPtrPointer(hc->element), element->hash, found); #else pointerhash_insert(v->pointers, (uintptr_t) HnswPtrPointer(hc->element), found); #endif } } /* * Count element towards ef */ static inline bool CountElement(char *base, HnswElement skipElement, HnswCandidate * hc) { HnswElement e; if (skipElement == NULL) return true; /* Ensure does not access heaptidsLength during in-memory build */ pg_memory_barrier(); e = HnswPtrAccess(base, hc->element); return e->heaptidsLength != 0; } /* * Algorithm 2 from paper */ List * HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, FmgrInfo *procinfo, Oid collation, int m, bool inserting, HnswElement skipElement) { List *w = NIL; pairingheap *C = pairingheap_allocate(CompareNearestCandidates, NULL); pairingheap *W = pairingheap_allocate(CompareFurthestCandidates, NULL); int wlen = 0; visited_hash v; ListCell *lc2; HnswNeighborArray *neighborhoodData = NULL; Size neighborhoodSize; InitVisited(base, &v, index, ef, m); /* Create local memory for neighborhood if needed */ if (index == NULL) { neighborhoodSize = HNSW_NEIGHBOR_ARRAY_SIZE(HnswGetLayerM(m, lc)); neighborhoodData = palloc(neighborhoodSize); } /* Add entry points to v, C, and W */ foreach(lc2, ep) { HnswCandidate *hc = (HnswCandidate *) lfirst(lc2); bool found; AddToVisited(base, &v, hc, index, &found); pairingheap_add(C, &(CreatePairingHeapNode(hc)->ph_node)); 
pairingheap_add(W, &(CreatePairingHeapNode(hc)->ph_node)); /* * Do not count elements being deleted towards ef when vacuuming. It * would be ideal to do this for inserts as well, but this could * affect insert performance. */ if (CountElement(base, skipElement, hc)) wlen++; } while (!pairingheap_is_empty(C)) { HnswNeighborArray *neighborhood; HnswCandidate *c = ((HnswPairingHeapNode *) pairingheap_remove_first(C))->inner; HnswCandidate *f = ((HnswPairingHeapNode *) pairingheap_first(W))->inner; HnswElement cElement; if (c->distance > f->distance) break; cElement = HnswPtrAccess(base, c->element); if (HnswPtrIsNull(base, cElement->neighbors)) HnswLoadNeighbors(cElement, index, m); /* Get the neighborhood at layer lc */ neighborhood = HnswGetNeighbors(base, cElement, lc); /* Copy neighborhood to local memory if needed */ if (index == NULL) { LWLockAcquire(&cElement->lock, LW_SHARED); memcpy(neighborhoodData, neighborhood, neighborhoodSize); LWLockRelease(&cElement->lock); neighborhood = neighborhoodData; } for (int i = 0; i < neighborhood->length; i++) { HnswCandidate *e = &neighborhood->items[i]; bool visited; AddToVisited(base, &v, e, index, &visited); if (!visited) { float eDistance; HnswElement eElement = HnswPtrAccess(base, e->element); f = ((HnswPairingHeapNode *) pairingheap_first(W))->inner; if (index == NULL) eDistance = GetCandidateDistance(base, e, q, procinfo, collation); else HnswLoadElement(eElement, &eDistance, &q, index, procinfo, collation, inserting); Assert(!eElement->deleted); /* Make robust to issues */ if (eElement->level < lc) continue; if (eDistance < f->distance || wlen < ef) { /* Copy e */ HnswCandidate *ec = palloc(sizeof(HnswCandidate)); HnswPtrStore(base, ec->element, eElement); ec->distance = eDistance; pairingheap_add(C, &(CreatePairingHeapNode(ec)->ph_node)); pairingheap_add(W, &(CreatePairingHeapNode(ec)->ph_node)); /* * Do not count elements being deleted towards ef when * vacuuming. 
It would be ideal to do this for inserts as * well, but this could affect insert performance. */ if (CountElement(base, skipElement, e)) { wlen++; /* No need to decrement wlen */ if (wlen > ef) pairingheap_remove_first(W); } } } } } /* Add each element of W to w */ while (!pairingheap_is_empty(W)) { HnswCandidate *hc = ((HnswPairingHeapNode *) pairingheap_remove_first(W))->inner; w = lappend(w, hc); } return w; } /* * Compare candidate distances with pointer tie-breaker */ static int #if PG_VERSION_NUM >= 130000 CompareCandidateDistances(const ListCell *a, const ListCell *b) #else CompareCandidateDistances(const void *a, const void *b) #endif { HnswCandidate *hca = lfirst((ListCell *) a); HnswCandidate *hcb = lfirst((ListCell *) b); if (hca->distance < hcb->distance) return 1; if (hca->distance > hcb->distance) return -1; if (HnswPtrPointer(hca->element) < HnswPtrPointer(hcb->element)) return 1; if (HnswPtrPointer(hca->element) > HnswPtrPointer(hcb->element)) return -1; return 0; } /* * Compare candidate distances with offset tie-breaker */ static int #if PG_VERSION_NUM >= 130000 CompareCandidateDistancesOffset(const ListCell *a, const ListCell *b) #else CompareCandidateDistancesOffset(const void *a, const void *b) #endif { HnswCandidate *hca = lfirst((ListCell *) a); HnswCandidate *hcb = lfirst((ListCell *) b); if (hca->distance < hcb->distance) return 1; if (hca->distance > hcb->distance) return -1; if (HnswPtrOffset(hca->element) < HnswPtrOffset(hcb->element)) return 1; if (HnswPtrOffset(hca->element) > HnswPtrOffset(hcb->element)) return -1; return 0; } /* * Calculate the distance between elements */ static float HnswGetDistance(char *base, HnswElement a, HnswElement b, FmgrInfo *procinfo, Oid collation) { Datum aValue = HnswGetValue(base, a); Datum bValue = HnswGetValue(base, b); return DatumGetFloat8(FunctionCall2Coll(procinfo, collation, aValue, bValue)); } /* * Check if an element is closer to q than any element from R */ static bool CheckElementCloser(char 
*base, HnswCandidate * e, List *r, FmgrInfo *procinfo, Oid collation) { HnswElement eElement = HnswPtrAccess(base, e->element); ListCell *lc2; foreach(lc2, r) { HnswCandidate *ri = lfirst(lc2); HnswElement riElement = HnswPtrAccess(base, ri->element); float distance = HnswGetDistance(base, eElement, riElement, procinfo, collation); if (distance <= e->distance) return false; } return true; } /* * Algorithm 4 from paper */ static List * SelectNeighbors(char *base, List *c, int lm, int lc, FmgrInfo *procinfo, Oid collation, HnswElement e2, HnswCandidate * newCandidate, HnswCandidate * *pruned, bool sortCandidates) { List *r = NIL; List *w = list_copy(c); pairingheap *wd; HnswNeighborArray *neighbors = HnswGetNeighbors(base, e2, lc); bool mustCalculate = !neighbors->closerSet; List *added = NIL; bool removedAny = false; if (list_length(w) <= lm) return w; wd = pairingheap_allocate(CompareNearestCandidates, NULL); /* Ensure order of candidates is deterministic for closer caching */ if (sortCandidates) { if (base == NULL) list_sort(w, CompareCandidateDistances); else list_sort(w, CompareCandidateDistancesOffset); } while (list_length(w) > 0 && list_length(r) < lm) { /* Assumes w is already ordered desc */ HnswCandidate *e = llast(w); w = list_delete_last(w); /* Use previous state of r and wd to skip work when possible */ if (mustCalculate) e->closer = CheckElementCloser(base, e, r, procinfo, collation); else if (list_length(added) > 0) { /* Keep Valgrind happy for in-memory, parallel builds */ if (base != NULL) VALGRIND_MAKE_MEM_DEFINED(&e->closer, 1); /* * If the current candidate was closer, we only need to compare it * with the other candidates that we have added. */ if (e->closer) { e->closer = CheckElementCloser(base, e, added, procinfo, collation); if (!e->closer) removedAny = true; } else { /* * If we have removed any candidates from closer, a candidate * that was not closer earlier might now be. 
*/ if (removedAny) { e->closer = CheckElementCloser(base, e, r, procinfo, collation); if (e->closer) added = lappend(added, e); } } } else if (e == newCandidate) { e->closer = CheckElementCloser(base, e, r, procinfo, collation); if (e->closer) added = lappend(added, e); } /* Keep Valgrind happy for in-memory, parallel builds */ if (base != NULL) VALGRIND_MAKE_MEM_DEFINED(&e->closer, 1); if (e->closer) r = lappend(r, e); else pairingheap_add(wd, &(CreatePairingHeapNode(e)->ph_node)); } /* Cached value can only be used in future if sorted deterministically */ neighbors->closerSet = sortCandidates; /* Keep pruned connections */ while (!pairingheap_is_empty(wd) && list_length(r) < lm) r = lappend(r, ((HnswPairingHeapNode *) pairingheap_remove_first(wd))->inner); /* Return pruned for update connections */ if (pruned != NULL) { if (!pairingheap_is_empty(wd)) *pruned = ((HnswPairingHeapNode *) pairingheap_first(wd))->inner; else *pruned = linitial(w); } return r; } /* * Add connections */ static void AddConnections(char *base, HnswElement element, List *neighbors, int lc) { ListCell *lc2; HnswNeighborArray *a = HnswGetNeighbors(base, element, lc); foreach(lc2, neighbors) a->items[a->length++] = *((HnswCandidate *) lfirst(lc2)); } /* * Update connections */ void HnswUpdateConnection(char *base, HnswElement element, HnswCandidate * hc, int lm, int lc, int *updateIdx, Relation index, FmgrInfo *procinfo, Oid collation) { HnswElement hce = HnswPtrAccess(base, hc->element); HnswNeighborArray *currentNeighbors = HnswGetNeighbors(base, hce, lc); HnswCandidate hc2; HnswPtrStore(base, hc2.element, element); hc2.distance = hc->distance; if (currentNeighbors->length < lm) { currentNeighbors->items[currentNeighbors->length++] = hc2; /* Track update */ if (updateIdx != NULL) *updateIdx = -2; } else { /* Shrink connections */ HnswCandidate *pruned = NULL; /* Load elements on insert */ if (index != NULL) { Datum q = HnswGetValue(base, hce); for (int i = 0; i < currentNeighbors->length; 
i++) { HnswCandidate *hc3 = ¤tNeighbors->items[i]; HnswElement hc3Element = HnswPtrAccess(base, hc3->element); if (HnswPtrIsNull(base, hc3Element->value)) HnswLoadElement(hc3Element, &hc3->distance, &q, index, procinfo, collation, true); else hc3->distance = GetCandidateDistance(base, hc3, q, procinfo, collation); /* Prune element if being deleted */ if (hc3Element->heaptidsLength == 0) { pruned = ¤tNeighbors->items[i]; break; } } } if (pruned == NULL) { List *c = NIL; /* Add candidates */ for (int i = 0; i < currentNeighbors->length; i++) c = lappend(c, ¤tNeighbors->items[i]); c = lappend(c, &hc2); SelectNeighbors(base, c, lm, lc, procinfo, collation, hce, &hc2, &pruned, true); /* Should not happen */ if (pruned == NULL) return; } /* Find and replace the pruned element */ for (int i = 0; i < currentNeighbors->length; i++) { if (HnswPtrEqual(base, currentNeighbors->items[i].element, pruned->element)) { currentNeighbors->items[i] = hc2; /* Track update */ if (updateIdx != NULL) *updateIdx = i; break; } } } } /* * Remove elements being deleted or skipped */ static List * RemoveElements(char *base, List *w, HnswElement skipElement) { ListCell *lc2; List *w2 = NIL; /* Ensure does not access heaptidsLength during in-memory build */ pg_memory_barrier(); foreach(lc2, w) { HnswCandidate *hc = (HnswCandidate *) lfirst(lc2); HnswElement hce = HnswPtrAccess(base, hc->element); /* Skip self for vacuuming update */ if (skipElement != NULL && hce->blkno == skipElement->blkno && hce->offno == skipElement->offno) continue; if (hce->heaptidsLength != 0) w2 = lappend(w2, hc); } return w2; } #if PG_VERSION_NUM >= 130000 /* * Precompute hash */ static void PrecomputeHash(char *base, HnswElement element) { HnswElementPtr ptr; HnswPtrStore(base, ptr, element); if (base == NULL) element->hash = hash_pointer((uintptr_t) HnswPtrPointer(ptr)); else element->hash = hash_offset(HnswPtrOffset(ptr)); } #endif /* * Algorithm 1 from paper */ void HnswFindElementNeighbors(char *base, HnswElement 
element, HnswElement entryPoint, Relation index, FmgrInfo *procinfo, Oid collation, int m, int efConstruction, bool existing) { List *ep; List *w; int level = element->level; int entryLevel; Datum q = HnswGetValue(base, element); HnswElement skipElement = existing ? element : NULL; #if PG_VERSION_NUM >= 130000 /* Precompute hash */ if (index == NULL) PrecomputeHash(base, element); #endif /* No neighbors if no entry point */ if (entryPoint == NULL) return; /* Get entry point and level */ ep = list_make1(HnswEntryCandidate(base, entryPoint, q, index, procinfo, collation, true)); entryLevel = entryPoint->level; /* 1st phase: greedy search to insert level */ for (int lc = entryLevel; lc >= level + 1; lc--) { w = HnswSearchLayer(base, q, ep, 1, lc, index, procinfo, collation, m, true, skipElement); ep = w; } if (level > entryLevel) level = entryLevel; /* Add one for existing element */ if (existing) efConstruction++; /* 2nd phase */ for (int lc = level; lc >= 0; lc--) { int lm = HnswGetLayerM(m, lc); List *neighbors; List *lw; w = HnswSearchLayer(base, q, ep, efConstruction, lc, index, procinfo, collation, m, true, skipElement); /* Elements being deleted or skipped can help with search */ /* but should be removed before selecting neighbors */ if (index != NULL) lw = RemoveElements(base, w, skipElement); else lw = w; /* * Candidates are sorted, but not deterministically. Could set * sortCandidates to true for in-memory builds to enable closer * caching, but there does not seem to be a difference in performance. 
*/ neighbors = SelectNeighbors(base, lw, lm, lc, procinfo, collation, element, NULL, NULL, false); AddConnections(base, element, neighbors, lc); ep = w; } } pgvector-0.6.0/src/hnswvacuum.c000066400000000000000000000416711455577216400165320ustar00rootroot00000000000000#include "postgres.h" #include #include "access/generic_xlog.h" #include "commands/vacuum.h" #include "hnsw.h" #include "storage/bufmgr.h" #include "storage/lmgr.h" #include "utils/memutils.h" /* * Check if deleted list contains an index TID */ static bool DeletedContains(tidhash_hash * deleted, ItemPointer indextid) { return tidhash_lookup(deleted, *indextid) != NULL; } /* * Remove deleted heap TIDs * * OK to remove for entry point, since always considered for searches and inserts */ static void RemoveHeapTids(HnswVacuumState * vacuumstate) { BlockNumber blkno = HNSW_HEAD_BLKNO; HnswElement highestPoint = &vacuumstate->highestPoint; Relation index = vacuumstate->index; BufferAccessStrategy bas = vacuumstate->bas; HnswElement entryPoint = HnswGetEntryPoint(vacuumstate->index); IndexBulkDeleteResult *stats = vacuumstate->stats; /* Store separately since highestPoint.level is uint8 */ int highestLevel = -1; /* Initialize highest point */ highestPoint->blkno = InvalidBlockNumber; highestPoint->offno = InvalidOffsetNumber; while (BlockNumberIsValid(blkno)) { Buffer buf; Page page; GenericXLogState *state; OffsetNumber offno; OffsetNumber maxoffno; bool updated = false; vacuum_delay_point(); buf = ReadBufferExtended(index, MAIN_FORKNUM, blkno, RBM_NORMAL, bas); LockBuffer(buf, BUFFER_LOCK_EXCLUSIVE); state = GenericXLogStart(index); page = GenericXLogRegisterBuffer(state, buf, 0); maxoffno = PageGetMaxOffsetNumber(page); /* Iterate over nodes */ for (offno = FirstOffsetNumber; offno <= maxoffno; offno = OffsetNumberNext(offno)) { HnswElementTuple etup = (HnswElementTuple) PageGetItem(page, PageGetItemId(page, offno)); int idx = 0; bool itemUpdated = false; /* Skip neighbor tuples */ if 
(!HnswIsElementTuple(etup)) continue; if (ItemPointerIsValid(&etup->heaptids[0])) { for (int i = 0; i < HNSW_HEAPTIDS; i++) { /* Stop at first unused */ if (!ItemPointerIsValid(&etup->heaptids[i])) break; if (vacuumstate->callback(&etup->heaptids[i], vacuumstate->callback_state)) { itemUpdated = true; stats->tuples_removed++; } else { /* Move to front of list */ etup->heaptids[idx++] = etup->heaptids[i]; stats->num_index_tuples++; } } if (itemUpdated) { /* Mark rest as invalid */ for (int i = idx; i < HNSW_HEAPTIDS; i++) ItemPointerSetInvalid(&etup->heaptids[i]); updated = true; } } if (!ItemPointerIsValid(&etup->heaptids[0])) { ItemPointerData ip; bool found; /* Add to deleted list */ ItemPointerSet(&ip, blkno, offno); tidhash_insert(vacuumstate->deleted, ip, &found); Assert(!found); } else if (etup->level > highestLevel && !(entryPoint != NULL && blkno == entryPoint->blkno && offno == entryPoint->offno)) { /* Keep track of highest non-entry point */ highestPoint->blkno = blkno; highestPoint->offno = offno; highestPoint->level = etup->level; highestLevel = etup->level; } } blkno = HnswPageGetOpaque(page)->nextblkno; if (updated) GenericXLogFinish(state); else GenericXLogAbort(state); UnlockReleaseBuffer(buf); } } /* * Check for deleted neighbors */ static bool NeedsUpdated(HnswVacuumState * vacuumstate, HnswElement element) { Relation index = vacuumstate->index; BufferAccessStrategy bas = vacuumstate->bas; Buffer buf; Page page; HnswNeighborTuple ntup; bool needsUpdated = false; buf = ReadBufferExtended(index, MAIN_FORKNUM, element->neighborPage, RBM_NORMAL, bas); LockBuffer(buf, BUFFER_LOCK_SHARE); page = BufferGetPage(buf); ntup = (HnswNeighborTuple) PageGetItem(page, PageGetItemId(page, element->neighborOffno)); Assert(HnswIsNeighborTuple(ntup)); /* Check neighbors */ for (int i = 0; i < ntup->count; i++) { ItemPointer indextid = &ntup->indextids[i]; if (!ItemPointerIsValid(indextid)) continue; /* Check if in deleted list */ if 
(DeletedContains(vacuumstate->deleted, indextid)) { needsUpdated = true; break; } } /* Also update if layer 0 is not full */ /* This could indicate too many candidates being deleted during insert */ if (!needsUpdated) needsUpdated = !ItemPointerIsValid(&ntup->indextids[ntup->count - 1]); UnlockReleaseBuffer(buf); return needsUpdated; } /* * Repair graph for a single element */ static void RepairGraphElement(HnswVacuumState * vacuumstate, HnswElement element, HnswElement entryPoint) { Relation index = vacuumstate->index; Buffer buf; Page page; GenericXLogState *state; int m = vacuumstate->m; int efConstruction = vacuumstate->efConstruction; FmgrInfo *procinfo = vacuumstate->procinfo; Oid collation = vacuumstate->collation; BufferAccessStrategy bas = vacuumstate->bas; HnswNeighborTuple ntup = vacuumstate->ntup; Size ntupSize = HNSW_NEIGHBOR_TUPLE_SIZE(element->level, m); char *base = NULL; /* Skip if element is entry point */ if (entryPoint != NULL && element->blkno == entryPoint->blkno && element->offno == entryPoint->offno) return; /* Init fields */ HnswInitNeighbors(base, element, m, NULL); element->heaptidsLength = 0; /* Find neighbors for element, skipping itself */ HnswFindElementNeighbors(base, element, entryPoint, index, procinfo, collation, m, efConstruction, true); /* Zero memory for each element */ MemSet(ntup, 0, HNSW_TUPLE_ALLOC_SIZE); /* Update neighbor tuple */ /* Do this before getting page to minimize locking */ HnswSetNeighborTuple(base, ntup, element, m); /* Get neighbor page */ buf = ReadBufferExtended(index, MAIN_FORKNUM, element->neighborPage, RBM_NORMAL, bas); LockBuffer(buf, BUFFER_LOCK_EXCLUSIVE); state = GenericXLogStart(index); page = GenericXLogRegisterBuffer(state, buf, 0); /* Overwrite tuple */ if (!PageIndexTupleOverwrite(page, element->neighborOffno, (Item) ntup, ntupSize)) elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index)); /* Commit */ GenericXLogFinish(state); UnlockReleaseBuffer(buf); /* Update 
neighbors */ HnswUpdateNeighborsOnDisk(index, procinfo, collation, element, m, true, false); } /* * Repair graph entry point */ static void RepairGraphEntryPoint(HnswVacuumState * vacuumstate) { Relation index = vacuumstate->index; HnswElement highestPoint = &vacuumstate->highestPoint; HnswElement entryPoint; MemoryContext oldCtx = MemoryContextSwitchTo(vacuumstate->tmpCtx); if (!BlockNumberIsValid(highestPoint->blkno)) highestPoint = NULL; /* * Repair graph for highest non-entry point. Highest point may be outdated * due to inserts that happen during and after RemoveHeapTids. */ if (highestPoint != NULL) { /* Get a shared lock */ LockPage(index, HNSW_UPDATE_LOCK, ShareLock); /* Load element */ HnswLoadElement(highestPoint, NULL, NULL, index, vacuumstate->procinfo, vacuumstate->collation, true); /* Repair if needed */ if (NeedsUpdated(vacuumstate, highestPoint)) RepairGraphElement(vacuumstate, highestPoint, HnswGetEntryPoint(index)); /* Release lock */ UnlockPage(index, HNSW_UPDATE_LOCK, ShareLock); } /* Prevent concurrent inserts when possibly updating entry point */ LockPage(index, HNSW_UPDATE_LOCK, ExclusiveLock); /* Get latest entry point */ entryPoint = HnswGetEntryPoint(index); if (entryPoint != NULL) { ItemPointerData epData; ItemPointerSet(&epData, entryPoint->blkno, entryPoint->offno); if (DeletedContains(vacuumstate->deleted, &epData)) { /* * Replace the entry point with the highest point. If highest * point is outdated and empty, the entry point will be empty * until an element is repaired. */ HnswUpdateMetaPage(index, HNSW_UPDATE_ENTRY_ALWAYS, highestPoint, InvalidBlockNumber, MAIN_FORKNUM, false); } else { /* * Repair the entry point with the highest point. If highest point * is outdated, this can remove connections at higher levels in * the graph until they are repaired, but this should be fine. 
*/ HnswLoadElement(entryPoint, NULL, NULL, index, vacuumstate->procinfo, vacuumstate->collation, true); if (NeedsUpdated(vacuumstate, entryPoint)) { /* Reset neighbors from previous update */ if (highestPoint != NULL) HnswPtrStore((char *) NULL, highestPoint->neighbors, (HnswNeighborArrayPtr *) NULL); RepairGraphElement(vacuumstate, entryPoint, highestPoint); } } } /* Release lock */ UnlockPage(index, HNSW_UPDATE_LOCK, ExclusiveLock); /* Reset memory context */ MemoryContextSwitchTo(oldCtx); MemoryContextReset(vacuumstate->tmpCtx); } /* * Repair graph for all elements */ static void RepairGraph(HnswVacuumState * vacuumstate) { Relation index = vacuumstate->index; BufferAccessStrategy bas = vacuumstate->bas; BlockNumber blkno = HNSW_HEAD_BLKNO; /* * Wait for inserts to complete. Inserts before this point may have * neighbors about to be deleted. Inserts after this point will not. */ LockPage(index, HNSW_UPDATE_LOCK, ExclusiveLock); UnlockPage(index, HNSW_UPDATE_LOCK, ExclusiveLock); /* Repair entry point first */ RepairGraphEntryPoint(vacuumstate); while (BlockNumberIsValid(blkno)) { Buffer buf; Page page; OffsetNumber offno; OffsetNumber maxoffno; List *elements = NIL; ListCell *lc2; MemoryContext oldCtx; vacuum_delay_point(); oldCtx = MemoryContextSwitchTo(vacuumstate->tmpCtx); buf = ReadBufferExtended(index, MAIN_FORKNUM, blkno, RBM_NORMAL, bas); LockBuffer(buf, BUFFER_LOCK_SHARE); page = BufferGetPage(buf); maxoffno = PageGetMaxOffsetNumber(page); /* Load items into memory to minimize locking */ for (offno = FirstOffsetNumber; offno <= maxoffno; offno = OffsetNumberNext(offno)) { HnswElementTuple etup = (HnswElementTuple) PageGetItem(page, PageGetItemId(page, offno)); HnswElement element; /* Skip neighbor tuples */ if (!HnswIsElementTuple(etup)) continue; /* Skip updating neighbors if being deleted */ if (!ItemPointerIsValid(&etup->heaptids[0])) continue; /* Create an element */ element = HnswInitElementFromBlock(blkno, offno); HnswLoadElementFromTuple(element, 
etup, false, true); elements = lappend(elements, element); } blkno = HnswPageGetOpaque(page)->nextblkno; UnlockReleaseBuffer(buf); /* Update neighbor pages */ foreach(lc2, elements) { HnswElement element = (HnswElement) lfirst(lc2); HnswElement entryPoint; LOCKMODE lockmode = ShareLock; /* Check if any neighbors point to deleted values */ if (!NeedsUpdated(vacuumstate, element)) continue; /* Get a shared lock */ LockPage(index, HNSW_UPDATE_LOCK, lockmode); /* Refresh entry point for each element */ entryPoint = HnswGetEntryPoint(index); /* Prevent concurrent inserts when likely updating entry point */ if (entryPoint == NULL || element->level > entryPoint->level) { /* Release shared lock */ UnlockPage(index, HNSW_UPDATE_LOCK, lockmode); /* Get exclusive lock */ lockmode = ExclusiveLock; LockPage(index, HNSW_UPDATE_LOCK, lockmode); /* Get latest entry point after lock is acquired */ entryPoint = HnswGetEntryPoint(index); } /* Repair connections */ RepairGraphElement(vacuumstate, element, entryPoint); /* * Update metapage if needed. Should only happen if entry point * was replaced and highest point was outdated. */ if (entryPoint == NULL || element->level > entryPoint->level) HnswUpdateMetaPage(index, HNSW_UPDATE_ENTRY_GREATER, element, InvalidBlockNumber, MAIN_FORKNUM, false); /* Release lock */ UnlockPage(index, HNSW_UPDATE_LOCK, lockmode); } /* Reset memory context */ MemoryContextSwitchTo(oldCtx); MemoryContextReset(vacuumstate->tmpCtx); } } /* * Mark items as deleted */ static void MarkDeleted(HnswVacuumState * vacuumstate) { BlockNumber blkno = HNSW_HEAD_BLKNO; BlockNumber insertPage = InvalidBlockNumber; Relation index = vacuumstate->index; BufferAccessStrategy bas = vacuumstate->bas; /* * Wait for index scans to complete. Scans before this point may contain * tuples about to be deleted. Scans after this point will not, since the * graph has been repaired. 
*/ LockPage(index, HNSW_SCAN_LOCK, ExclusiveLock); UnlockPage(index, HNSW_SCAN_LOCK, ExclusiveLock); while (BlockNumberIsValid(blkno)) { Buffer buf; Page page; GenericXLogState *state; OffsetNumber offno; OffsetNumber maxoffno; vacuum_delay_point(); buf = ReadBufferExtended(index, MAIN_FORKNUM, blkno, RBM_NORMAL, bas); /* * ambulkdelete cannot delete entries from pages that are pinned by * other backends * * https://www.postgresql.org/docs/current/index-locking.html */ LockBufferForCleanup(buf); state = GenericXLogStart(index); page = GenericXLogRegisterBuffer(state, buf, 0); maxoffno = PageGetMaxOffsetNumber(page); /* Update element and neighbors together */ for (offno = FirstOffsetNumber; offno <= maxoffno; offno = OffsetNumberNext(offno)) { HnswElementTuple etup = (HnswElementTuple) PageGetItem(page, PageGetItemId(page, offno)); HnswNeighborTuple ntup; Buffer nbuf; Page npage; BlockNumber neighborPage; OffsetNumber neighborOffno; /* Skip neighbor tuples */ if (!HnswIsElementTuple(etup)) continue; /* Skip deleted tuples */ if (etup->deleted) { /* Set to first free page */ if (!BlockNumberIsValid(insertPage)) insertPage = blkno; continue; } /* Skip live tuples */ if (ItemPointerIsValid(&etup->heaptids[0])) continue; /* Get neighbor page */ neighborPage = ItemPointerGetBlockNumber(&etup->neighbortid); neighborOffno = ItemPointerGetOffsetNumber(&etup->neighbortid); if (neighborPage == blkno) { nbuf = buf; npage = page; } else { nbuf = ReadBufferExtended(index, MAIN_FORKNUM, neighborPage, RBM_NORMAL, bas); LockBuffer(nbuf, BUFFER_LOCK_EXCLUSIVE); npage = GenericXLogRegisterBuffer(state, nbuf, 0); } ntup = (HnswNeighborTuple) PageGetItem(npage, PageGetItemId(npage, neighborOffno)); /* Overwrite element */ etup->deleted = 1; MemSet(&etup->data, 0, VARSIZE_ANY(&etup->data)); /* Overwrite neighbors */ for (int i = 0; i < ntup->count; i++) ItemPointerSetInvalid(&ntup->indextids[i]); /* * We modified the tuples in place, no need to call * PageIndexTupleOverwrite */ /* 
Commit */ GenericXLogFinish(state); if (nbuf != buf) UnlockReleaseBuffer(nbuf); /* Set to first free page */ if (!BlockNumberIsValid(insertPage)) insertPage = blkno; /* Prepare new xlog */ state = GenericXLogStart(index); page = GenericXLogRegisterBuffer(state, buf, 0); } blkno = HnswPageGetOpaque(page)->nextblkno; GenericXLogAbort(state); UnlockReleaseBuffer(buf); } /* Update insert page last, after everything has been marked as deleted */ HnswUpdateMetaPage(index, 0, NULL, insertPage, MAIN_FORKNUM, false); } /* * Initialize the vacuum state */ static void InitVacuumState(HnswVacuumState * vacuumstate, IndexVacuumInfo *info, IndexBulkDeleteResult *stats, IndexBulkDeleteCallback callback, void *callback_state) { Relation index = info->index; if (stats == NULL) stats = (IndexBulkDeleteResult *) palloc0(sizeof(IndexBulkDeleteResult)); vacuumstate->index = index; vacuumstate->stats = stats; vacuumstate->callback = callback; vacuumstate->callback_state = callback_state; vacuumstate->efConstruction = HnswGetEfConstruction(index); vacuumstate->bas = GetAccessStrategy(BAS_BULKREAD); vacuumstate->procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC); vacuumstate->collation = index->rd_indcollation[0]; vacuumstate->ntup = palloc0(HNSW_TUPLE_ALLOC_SIZE); vacuumstate->tmpCtx = AllocSetContextCreate(CurrentMemoryContext, "Hnsw vacuum temporary context", ALLOCSET_DEFAULT_SIZES); /* Get m from metapage */ HnswGetMetaPageInfo(index, &vacuumstate->m, NULL); /* Create hash table */ vacuumstate->deleted = tidhash_create(CurrentMemoryContext, 256, NULL); } /* * Free resources */ static void FreeVacuumState(HnswVacuumState * vacuumstate) { tidhash_destroy(vacuumstate->deleted); FreeAccessStrategy(vacuumstate->bas); pfree(vacuumstate->ntup); MemoryContextDelete(vacuumstate->tmpCtx); } /* * Bulk delete tuples from the index */ IndexBulkDeleteResult * hnswbulkdelete(IndexVacuumInfo *info, IndexBulkDeleteResult *stats, IndexBulkDeleteCallback callback, void *callback_state) { 
HnswVacuumState vacuumstate; InitVacuumState(&vacuumstate, info, stats, callback, callback_state); /* Pass 1: Remove heap TIDs */ RemoveHeapTids(&vacuumstate); /* Pass 2: Repair graph */ RepairGraph(&vacuumstate); /* Pass 3: Mark as deleted */ MarkDeleted(&vacuumstate); FreeVacuumState(&vacuumstate); return vacuumstate.stats; } /* * Clean up after a VACUUM operation */ IndexBulkDeleteResult * hnswvacuumcleanup(IndexVacuumInfo *info, IndexBulkDeleteResult *stats) { Relation rel = info->index; if (info->analyze_only) return stats; /* stats is NULL if ambulkdelete not called */ /* OK to return NULL if index not changed */ if (stats == NULL) return NULL; stats->num_pages = RelationGetNumberOfBlocks(rel); return stats; } pgvector-0.6.0/src/ivfbuild.c000066400000000000000000000714251455577216400161360ustar00rootroot00000000000000#include "postgres.h" #include #include "access/table.h" #include "access/tableam.h" #include "access/parallel.h" #include "access/xact.h" #include "catalog/index.h" #include "catalog/pg_operator_d.h" #include "catalog/pg_type_d.h" #include "commands/progress.h" #include "ivfflat.h" #include "miscadmin.h" #include "optimizer/optimizer.h" #include "storage/bufmgr.h" #include "tcop/tcopprot.h" #include "utils/memutils.h" #if PG_VERSION_NUM >= 140000 #include "utils/backend_progress.h" #else #include "pgstat.h" #endif #if PG_VERSION_NUM >= 130000 #define CALLBACK_ITEM_POINTER ItemPointer tid #else #define CALLBACK_ITEM_POINTER HeapTuple hup #endif #if PG_VERSION_NUM >= 140000 #include "utils/backend_status.h" #include "utils/wait_event.h" #endif #define PARALLEL_KEY_IVFFLAT_SHARED UINT64CONST(0xA000000000000001) #define PARALLEL_KEY_TUPLESORT UINT64CONST(0xA000000000000002) #define PARALLEL_KEY_IVFFLAT_CENTERS UINT64CONST(0xA000000000000003) #define PARALLEL_KEY_QUERY_TEXT UINT64CONST(0xA000000000000004) /* * Add sample */ static void AddSample(Datum *values, IvfflatBuildState * buildstate) { VectorArray samples = buildstate->samples; int 
targsamples = samples->maxlen; /* Detoast once for all calls */ Datum value = PointerGetDatum(PG_DETOAST_DATUM(values[0])); /* * Normalize with KMEANS_NORM_PROC since spherical distance function * expects unit vectors */ if (buildstate->kmeansnormprocinfo != NULL) { if (!IvfflatNormValue(buildstate->kmeansnormprocinfo, buildstate->collation, &value, buildstate->normvec)) return; } if (samples->length < targsamples) { VectorArraySet(samples, samples->length, DatumGetVector(value)); samples->length++; } else { if (buildstate->rowstoskip < 0) buildstate->rowstoskip = reservoir_get_next_S(&buildstate->rstate, samples->length, targsamples); if (buildstate->rowstoskip <= 0) { #if PG_VERSION_NUM >= 150000 int k = (int) (targsamples * sampler_random_fract(&buildstate->rstate.randstate)); #else int k = (int) (targsamples * sampler_random_fract(buildstate->rstate.randstate)); #endif Assert(k >= 0 && k < targsamples); VectorArraySet(samples, k, DatumGetVector(value)); } buildstate->rowstoskip -= 1; } } /* * Callback for sampling */ static void SampleCallback(Relation index, CALLBACK_ITEM_POINTER, Datum *values, bool *isnull, bool tupleIsAlive, void *state) { IvfflatBuildState *buildstate = (IvfflatBuildState *) state; MemoryContext oldCtx; /* Skip nulls */ if (isnull[0]) return; /* Use memory context since detoast can allocate */ oldCtx = MemoryContextSwitchTo(buildstate->tmpCtx); /* Add sample */ AddSample(values, state); /* Reset memory context */ MemoryContextSwitchTo(oldCtx); MemoryContextReset(buildstate->tmpCtx); } /* * Sample rows with same logic as ANALYZE */ static void SampleRows(IvfflatBuildState * buildstate) { int targsamples = buildstate->samples->maxlen; BlockNumber totalblocks = RelationGetNumberOfBlocks(buildstate->heap); buildstate->rowstoskip = -1; BlockSampler_Init(&buildstate->bs, totalblocks, targsamples, RandomInt()); reservoir_init_selection_state(&buildstate->rstate, targsamples); while (BlockSampler_HasMore(&buildstate->bs)) { BlockNumber targblock = 
BlockSampler_Next(&buildstate->bs); table_index_build_range_scan(buildstate->heap, buildstate->index, buildstate->indexInfo, false, true, false, targblock, 1, SampleCallback, (void *) buildstate, NULL); } } /* * Add tuple to sort */ static void AddTupleToSort(Relation index, ItemPointer tid, Datum *values, IvfflatBuildState * buildstate) { double distance; double minDistance = DBL_MAX; int closestCenter = 0; VectorArray centers = buildstate->centers; TupleTableSlot *slot = buildstate->slot; /* Detoast once for all calls */ Datum value = PointerGetDatum(PG_DETOAST_DATUM(values[0])); /* Normalize if needed */ if (buildstate->normprocinfo != NULL) { if (!IvfflatNormValue(buildstate->normprocinfo, buildstate->collation, &value, buildstate->normvec)) return; } /* Find the list that minimizes the distance */ for (int i = 0; i < centers->length; i++) { distance = DatumGetFloat8(FunctionCall2Coll(buildstate->procinfo, buildstate->collation, value, PointerGetDatum(VectorArrayGet(centers, i)))); if (distance < minDistance) { minDistance = distance; closestCenter = i; } } #ifdef IVFFLAT_KMEANS_DEBUG buildstate->inertia += minDistance; buildstate->listSums[closestCenter] += minDistance; buildstate->listCounts[closestCenter]++; #endif /* Create a virtual tuple */ ExecClearTuple(slot); slot->tts_values[0] = Int32GetDatum(closestCenter); slot->tts_isnull[0] = false; slot->tts_values[1] = PointerGetDatum(tid); slot->tts_isnull[1] = false; slot->tts_values[2] = value; slot->tts_isnull[2] = false; ExecStoreVirtualTuple(slot); /* * Add tuple to sort * * tuplesort_puttupleslot comment: Input data is always copied; the caller * need not save it. 
*/ tuplesort_puttupleslot(buildstate->sortstate, slot); buildstate->indtuples++; } /* * Callback for table_index_build_scan */ static void BuildCallback(Relation index, CALLBACK_ITEM_POINTER, Datum *values, bool *isnull, bool tupleIsAlive, void *state) { IvfflatBuildState *buildstate = (IvfflatBuildState *) state; MemoryContext oldCtx; #if PG_VERSION_NUM < 130000 ItemPointer tid = &hup->t_self; #endif /* Skip nulls */ if (isnull[0]) return; /* Use memory context since detoast can allocate */ oldCtx = MemoryContextSwitchTo(buildstate->tmpCtx); /* Add tuple to sort */ AddTupleToSort(index, tid, values, buildstate); /* Reset memory context */ MemoryContextSwitchTo(oldCtx); MemoryContextReset(buildstate->tmpCtx); } /* * Get index tuple from sort state */ static inline void GetNextTuple(Tuplesortstate *sortstate, TupleDesc tupdesc, TupleTableSlot *slot, IndexTuple *itup, int *list) { Datum value; bool isnull; if (tuplesort_gettupleslot(sortstate, true, false, slot, NULL)) { *list = DatumGetInt32(slot_getattr(slot, 1, &isnull)); value = slot_getattr(slot, 3, &isnull); /* Form the index tuple */ *itup = index_form_tuple(tupdesc, &value, &isnull); (*itup)->t_tid = *((ItemPointer) DatumGetPointer(slot_getattr(slot, 2, &isnull))); } else *list = -1; } /* * Create initial entry pages */ static void InsertTuples(Relation index, IvfflatBuildState * buildstate, ForkNumber forkNum) { int list; IndexTuple itup = NULL; /* silence compiler warning */ int64 inserted = 0; TupleTableSlot *slot = MakeSingleTupleTableSlot(buildstate->tupdesc, &TTSOpsMinimalTuple); TupleDesc tupdesc = RelationGetDescr(index); pgstat_progress_update_param(PROGRESS_CREATEIDX_SUBPHASE, PROGRESS_IVFFLAT_PHASE_LOAD); pgstat_progress_update_param(PROGRESS_CREATEIDX_TUPLES_TOTAL, buildstate->indtuples); GetNextTuple(buildstate->sortstate, tupdesc, slot, &itup, &list); for (int i = 0; i < buildstate->centers->length; i++) { Buffer buf; Page page; GenericXLogState *state; BlockNumber startPage; BlockNumber 
insertPage; /* Can take a while, so ensure we can interrupt */ /* Needs to be called when no buffer locks are held */ CHECK_FOR_INTERRUPTS(); buf = IvfflatNewBuffer(index, forkNum); IvfflatInitRegisterPage(index, &buf, &page, &state); startPage = BufferGetBlockNumber(buf); /* Get all tuples for list */ while (list == i) { /* Check for free space */ Size itemsz = MAXALIGN(IndexTupleSize(itup)); if (PageGetFreeSpace(page) < itemsz) IvfflatAppendPage(index, &buf, &page, &state, forkNum); /* Add the item */ if (PageAddItem(page, (Item) itup, itemsz, InvalidOffsetNumber, false, false) == InvalidOffsetNumber) elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index)); pfree(itup); pgstat_progress_update_param(PROGRESS_CREATEIDX_TUPLES_DONE, ++inserted); GetNextTuple(buildstate->sortstate, tupdesc, slot, &itup, &list); } insertPage = BufferGetBlockNumber(buf); IvfflatCommitBuffer(buf, state); /* Set the start and insert pages */ IvfflatUpdateList(index, buildstate->listInfo[i], insertPage, InvalidBlockNumber, startPage, forkNum); } } /* * Initialize the build state */ static void InitBuildState(IvfflatBuildState * buildstate, Relation heap, Relation index, IndexInfo *indexInfo) { buildstate->heap = heap; buildstate->index = index; buildstate->indexInfo = indexInfo; buildstate->lists = IvfflatGetLists(index); buildstate->dimensions = TupleDescAttr(index->rd_att, 0)->atttypmod; /* Require column to have dimensions to be indexed */ if (buildstate->dimensions < 0) elog(ERROR, "column does not have dimensions"); if (buildstate->dimensions > IVFFLAT_MAX_DIM) elog(ERROR, "column cannot have more than %d dimensions for ivfflat index", IVFFLAT_MAX_DIM); buildstate->reltuples = 0; buildstate->indtuples = 0; /* Get support functions */ buildstate->procinfo = index_getprocinfo(index, 1, IVFFLAT_DISTANCE_PROC); buildstate->normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORM_PROC); buildstate->kmeansnormprocinfo = IvfflatOptionalProcInfo(index, 
IVFFLAT_KMEANS_NORM_PROC); buildstate->collation = index->rd_indcollation[0]; /* Require more than one dimension for spherical k-means */ if (buildstate->kmeansnormprocinfo != NULL && buildstate->dimensions == 1) elog(ERROR, "dimensions must be greater than one for this opclass"); /* Create tuple description for sorting */ buildstate->tupdesc = CreateTemplateTupleDesc(3); TupleDescInitEntry(buildstate->tupdesc, (AttrNumber) 1, "list", INT4OID, -1, 0); TupleDescInitEntry(buildstate->tupdesc, (AttrNumber) 2, "tid", TIDOID, -1, 0); TupleDescInitEntry(buildstate->tupdesc, (AttrNumber) 3, "vector", RelationGetDescr(index)->attrs[0].atttypid, -1, 0); buildstate->slot = MakeSingleTupleTableSlot(buildstate->tupdesc, &TTSOpsVirtual); buildstate->centers = VectorArrayInit(buildstate->lists, buildstate->dimensions); buildstate->listInfo = palloc(sizeof(ListInfo) * buildstate->lists); /* Reuse for each tuple */ buildstate->normvec = InitVector(buildstate->dimensions); buildstate->tmpCtx = AllocSetContextCreate(CurrentMemoryContext, "Ivfflat build temporary context", ALLOCSET_DEFAULT_SIZES); #ifdef IVFFLAT_KMEANS_DEBUG buildstate->inertia = 0; buildstate->listSums = palloc0(sizeof(double) * buildstate->lists); buildstate->listCounts = palloc0(sizeof(int) * buildstate->lists); #endif buildstate->ivfleader = NULL; } /* * Free resources */ static void FreeBuildState(IvfflatBuildState * buildstate) { VectorArrayFree(buildstate->centers); pfree(buildstate->listInfo); pfree(buildstate->normvec); #ifdef IVFFLAT_KMEANS_DEBUG pfree(buildstate->listSums); pfree(buildstate->listCounts); #endif MemoryContextDelete(buildstate->tmpCtx); } /* * Compute centers */ static void ComputeCenters(IvfflatBuildState * buildstate) { int numSamples; pgstat_progress_update_param(PROGRESS_CREATEIDX_SUBPHASE, PROGRESS_IVFFLAT_PHASE_KMEANS); /* Target 50 samples per list, with at least 10000 samples */ /* The number of samples has a large effect on index build time */ numSamples = buildstate->lists * 50; if 
(numSamples < 10000) numSamples = 10000; /* Skip samples for unlogged table */ if (buildstate->heap == NULL) numSamples = 1; /* Sample rows */ /* TODO Ensure within maintenance_work_mem */ buildstate->samples = VectorArrayInit(numSamples, buildstate->dimensions); if (buildstate->heap != NULL) { SampleRows(buildstate); if (buildstate->samples->length < buildstate->lists) { ereport(NOTICE, (errmsg("ivfflat index created with little data"), errdetail("This will cause low recall."), errhint("Drop the index until the table has more data."))); } } /* Calculate centers */ IvfflatBench("k-means", IvfflatKmeans(buildstate->index, buildstate->samples, buildstate->centers)); /* Free samples before we allocate more memory */ VectorArrayFree(buildstate->samples); } /* * Create the metapage */ static void CreateMetaPage(Relation index, int dimensions, int lists, ForkNumber forkNum) { Buffer buf; Page page; GenericXLogState *state; IvfflatMetaPage metap; buf = IvfflatNewBuffer(index, forkNum); IvfflatInitRegisterPage(index, &buf, &page, &state); /* Set metapage data */ metap = IvfflatPageGetMeta(page); metap->magicNumber = IVFFLAT_MAGIC_NUMBER; metap->version = IVFFLAT_VERSION; metap->dimensions = dimensions; metap->lists = lists; ((PageHeader) page)->pd_lower = ((char *) metap + sizeof(IvfflatMetaPageData)) - (char *) page; IvfflatCommitBuffer(buf, state); } /* * Create list pages */ static void CreateListPages(Relation index, VectorArray centers, int dimensions, int lists, ForkNumber forkNum, ListInfo * *listInfo) { Buffer buf; Page page; GenericXLogState *state; Size listSize; IvfflatList list; listSize = MAXALIGN(IVFFLAT_LIST_SIZE(dimensions)); list = palloc0(listSize); buf = IvfflatNewBuffer(index, forkNum); IvfflatInitRegisterPage(index, &buf, &page, &state); for (int i = 0; i < lists; i++) { OffsetNumber offno; /* Load list */ list->startPage = InvalidBlockNumber; list->insertPage = InvalidBlockNumber; memcpy(&list->center, VectorArrayGet(centers, i), 
VECTOR_SIZE(dimensions)); /* Ensure free space */ if (PageGetFreeSpace(page) < listSize) IvfflatAppendPage(index, &buf, &page, &state, forkNum); /* Add the item */ offno = PageAddItem(page, (Item) list, listSize, InvalidOffsetNumber, false, false); if (offno == InvalidOffsetNumber) elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index)); /* Save location info */ (*listInfo)[i].blkno = BufferGetBlockNumber(buf); (*listInfo)[i].offno = offno; } IvfflatCommitBuffer(buf, state); pfree(list); } #ifdef IVFFLAT_KMEANS_DEBUG /* * Print k-means metrics */ static void PrintKmeansMetrics(IvfflatBuildState * buildstate) { elog(INFO, "inertia: %.3e", buildstate->inertia); /* Calculate Davies-Bouldin index */ if (buildstate->lists > 1 && !buildstate->ivfleader) { double db = 0.0; /* Calculate average distance */ for (int i = 0; i < buildstate->lists; i++) { if (buildstate->listCounts[i] > 0) buildstate->listSums[i] /= buildstate->listCounts[i]; } for (int i = 0; i < buildstate->lists; i++) { double max = 0.0; double distance; for (int j = 0; j < buildstate->lists; j++) { if (j == i) continue; distance = DatumGetFloat8(FunctionCall2Coll(buildstate->procinfo, buildstate->collation, PointerGetDatum(VectorArrayGet(buildstate->centers, i)), PointerGetDatum(VectorArrayGet(buildstate->centers, j)))); distance = (buildstate->listSums[i] + buildstate->listSums[j]) / distance; if (distance > max) max = distance; } db += max; } db /= buildstate->lists; elog(INFO, "davies-bouldin: %.3f", db); } } #endif /* * Within leader, wait for end of heap scan */ static double ParallelHeapScan(IvfflatBuildState * buildstate) { IvfflatShared *ivfshared = buildstate->ivfleader->ivfshared; int nparticipanttuplesorts; double reltuples; nparticipanttuplesorts = buildstate->ivfleader->nparticipanttuplesorts; for (;;) { SpinLockAcquire(&ivfshared->mutex); if (ivfshared->nparticipantsdone == nparticipanttuplesorts) { buildstate->indtuples = ivfshared->indtuples; reltuples = 
ivfshared->reltuples; #ifdef IVFFLAT_KMEANS_DEBUG buildstate->inertia = ivfshared->inertia; #endif SpinLockRelease(&ivfshared->mutex); break; } SpinLockRelease(&ivfshared->mutex); ConditionVariableSleep(&ivfshared->workersdonecv, WAIT_EVENT_PARALLEL_CREATE_INDEX_SCAN); } ConditionVariableCancelSleep(); return reltuples; } /* * Perform a worker's portion of a parallel sort */ static void IvfflatParallelScanAndSort(IvfflatSpool * ivfspool, IvfflatShared * ivfshared, Sharedsort *sharedsort, Vector * ivfcenters, int sortmem, bool progress) { SortCoordinate coordinate; IvfflatBuildState buildstate; TableScanDesc scan; double reltuples; IndexInfo *indexInfo; /* Sort options, which must match AssignTuples */ AttrNumber attNums[] = {1}; Oid sortOperators[] = {Int4LessOperator}; Oid sortCollations[] = {InvalidOid}; bool nullsFirstFlags[] = {false}; /* Initialize local tuplesort coordination state */ coordinate = palloc0(sizeof(SortCoordinateData)); coordinate->isWorker = true; coordinate->nParticipants = -1; coordinate->sharedsort = sharedsort; /* Join parallel scan */ indexInfo = BuildIndexInfo(ivfspool->index); indexInfo->ii_Concurrent = ivfshared->isconcurrent; InitBuildState(&buildstate, ivfspool->heap, ivfspool->index, indexInfo); memcpy(buildstate.centers->items, ivfcenters, VECTOR_SIZE(buildstate.centers->dim) * buildstate.centers->maxlen); buildstate.centers->length = buildstate.centers->maxlen; ivfspool->sortstate = tuplesort_begin_heap(buildstate.tupdesc, 1, attNums, sortOperators, sortCollations, nullsFirstFlags, sortmem, coordinate, false); buildstate.sortstate = ivfspool->sortstate; scan = table_beginscan_parallel(ivfspool->heap, ParallelTableScanFromIvfflatShared(ivfshared)); reltuples = table_index_build_scan(ivfspool->heap, ivfspool->index, indexInfo, true, progress, BuildCallback, (void *) &buildstate, scan); /* Execute this worker's part of the sort */ tuplesort_performsort(ivfspool->sortstate); /* Record statistics */ SpinLockAcquire(&ivfshared->mutex); 
ivfshared->nparticipantsdone++; ivfshared->reltuples += reltuples; ivfshared->indtuples += buildstate.indtuples; #ifdef IVFFLAT_KMEANS_DEBUG ivfshared->inertia += buildstate.inertia; #endif SpinLockRelease(&ivfshared->mutex); /* Log statistics */ if (progress) ereport(DEBUG1, (errmsg("leader processed " INT64_FORMAT " tuples", (int64) reltuples))); else ereport(DEBUG1, (errmsg("worker processed " INT64_FORMAT " tuples", (int64) reltuples))); /* Notify leader */ ConditionVariableSignal(&ivfshared->workersdonecv); /* We can end tuplesorts immediately */ tuplesort_end(ivfspool->sortstate); FreeBuildState(&buildstate); } /* * Perform work within a launched parallel process */ void IvfflatParallelBuildMain(dsm_segment *seg, shm_toc *toc) { char *sharedquery; IvfflatSpool *ivfspool; IvfflatShared *ivfshared; Sharedsort *sharedsort; Vector *ivfcenters; Relation heapRel; Relation indexRel; LOCKMODE heapLockmode; LOCKMODE indexLockmode; int sortmem; /* Set debug_query_string for individual workers first */ sharedquery = shm_toc_lookup(toc, PARALLEL_KEY_QUERY_TEXT, true); debug_query_string = sharedquery; /* Report the query string from leader */ pgstat_report_activity(STATE_RUNNING, debug_query_string); /* Look up shared state */ ivfshared = shm_toc_lookup(toc, PARALLEL_KEY_IVFFLAT_SHARED, false); /* Open relations using lock modes known to be obtained by index.c */ if (!ivfshared->isconcurrent) { heapLockmode = ShareLock; indexLockmode = AccessExclusiveLock; } else { heapLockmode = ShareUpdateExclusiveLock; indexLockmode = RowExclusiveLock; } /* Open relations within worker */ heapRel = table_open(ivfshared->heaprelid, heapLockmode); indexRel = index_open(ivfshared->indexrelid, indexLockmode); /* Initialize worker's own spool */ ivfspool = (IvfflatSpool *) palloc0(sizeof(IvfflatSpool)); ivfspool->heap = heapRel; ivfspool->index = indexRel; /* Look up shared state private to tuplesort.c */ sharedsort = shm_toc_lookup(toc, PARALLEL_KEY_TUPLESORT, false); 
tuplesort_attach_shared(sharedsort, seg); ivfcenters = shm_toc_lookup(toc, PARALLEL_KEY_IVFFLAT_CENTERS, false); /* Perform sorting */ sortmem = maintenance_work_mem / ivfshared->scantuplesortstates; IvfflatParallelScanAndSort(ivfspool, ivfshared, sharedsort, ivfcenters, sortmem, false); /* Close relations within worker */ index_close(indexRel, indexLockmode); table_close(heapRel, heapLockmode); } /* * End parallel build */ static void IvfflatEndParallel(IvfflatLeader * ivfleader) { /* Shutdown worker processes */ WaitForParallelWorkersToFinish(ivfleader->pcxt); /* Free last reference to MVCC snapshot, if one was used */ if (IsMVCCSnapshot(ivfleader->snapshot)) UnregisterSnapshot(ivfleader->snapshot); DestroyParallelContext(ivfleader->pcxt); ExitParallelMode(); } /* * Return size of shared memory required for parallel index build */ static Size ParallelEstimateShared(Relation heap, Snapshot snapshot) { return add_size(BUFFERALIGN(sizeof(IvfflatShared)), table_parallelscan_estimate(heap, snapshot)); } /* * Within leader, participate as a parallel worker */ static void IvfflatLeaderParticipateAsWorker(IvfflatBuildState * buildstate) { IvfflatLeader *ivfleader = buildstate->ivfleader; IvfflatSpool *leaderworker; int sortmem; /* Allocate memory and initialize private spool */ leaderworker = (IvfflatSpool *) palloc0(sizeof(IvfflatSpool)); leaderworker->heap = buildstate->heap; leaderworker->index = buildstate->index; /* Perform work common to all participants */ sortmem = maintenance_work_mem / ivfleader->nparticipanttuplesorts; IvfflatParallelScanAndSort(leaderworker, ivfleader->ivfshared, ivfleader->sharedsort, ivfleader->ivfcenters, sortmem, true); } /* * Begin parallel build */ static void IvfflatBeginParallel(IvfflatBuildState * buildstate, bool isconcurrent, int request) { ParallelContext *pcxt; int scantuplesortstates; Snapshot snapshot; Size estivfshared; Size estsort; Size estcenters; IvfflatShared *ivfshared; Sharedsort *sharedsort; Vector *ivfcenters; 
IvfflatLeader *ivfleader = (IvfflatLeader *) palloc0(sizeof(IvfflatLeader)); bool leaderparticipates = true; int querylen; #ifdef DISABLE_LEADER_PARTICIPATION leaderparticipates = false; #endif /* Enter parallel mode and create context */ EnterParallelMode(); Assert(request > 0); pcxt = CreateParallelContext("vector", "IvfflatParallelBuildMain", request); scantuplesortstates = leaderparticipates ? request + 1 : request; /* Get snapshot for table scan */ if (!isconcurrent) snapshot = SnapshotAny; else snapshot = RegisterSnapshot(GetTransactionSnapshot()); /* Estimate size of workspaces */ estivfshared = ParallelEstimateShared(buildstate->heap, snapshot); shm_toc_estimate_chunk(&pcxt->estimator, estivfshared); estsort = tuplesort_estimate_shared(scantuplesortstates); shm_toc_estimate_chunk(&pcxt->estimator, estsort); estcenters = VECTOR_SIZE(buildstate->dimensions) * buildstate->lists; shm_toc_estimate_chunk(&pcxt->estimator, estcenters); shm_toc_estimate_keys(&pcxt->estimator, 3); /* Finally, estimate PARALLEL_KEY_QUERY_TEXT space */ if (debug_query_string) { querylen = strlen(debug_query_string); shm_toc_estimate_chunk(&pcxt->estimator, querylen + 1); shm_toc_estimate_keys(&pcxt->estimator, 1); } else querylen = 0; /* keep compiler quiet */ /* Everyone's had a chance to ask for space, so now create the DSM */ InitializeParallelDSM(pcxt); /* If no DSM segment was available, back out (do serial build) */ if (pcxt->seg == NULL) { if (IsMVCCSnapshot(snapshot)) UnregisterSnapshot(snapshot); DestroyParallelContext(pcxt); ExitParallelMode(); return; } /* Store shared build state, for which we reserved space */ ivfshared = (IvfflatShared *) shm_toc_allocate(pcxt->toc, estivfshared); /* Initialize immutable state */ ivfshared->heaprelid = RelationGetRelid(buildstate->heap); ivfshared->indexrelid = RelationGetRelid(buildstate->index); ivfshared->isconcurrent = isconcurrent; ivfshared->scantuplesortstates = scantuplesortstates; 
ConditionVariableInit(&ivfshared->workersdonecv); SpinLockInit(&ivfshared->mutex); /* Initialize mutable state */ ivfshared->nparticipantsdone = 0; ivfshared->reltuples = 0; ivfshared->indtuples = 0; #ifdef IVFFLAT_KMEANS_DEBUG ivfshared->inertia = 0; #endif table_parallelscan_initialize(buildstate->heap, ParallelTableScanFromIvfflatShared(ivfshared), snapshot); /* Store shared tuplesort-private state, for which we reserved space */ sharedsort = (Sharedsort *) shm_toc_allocate(pcxt->toc, estsort); tuplesort_initialize_shared(sharedsort, scantuplesortstates, pcxt->seg); ivfcenters = (Vector *) shm_toc_allocate(pcxt->toc, estcenters); memcpy(ivfcenters, buildstate->centers->items, estcenters); shm_toc_insert(pcxt->toc, PARALLEL_KEY_IVFFLAT_SHARED, ivfshared); shm_toc_insert(pcxt->toc, PARALLEL_KEY_TUPLESORT, sharedsort); shm_toc_insert(pcxt->toc, PARALLEL_KEY_IVFFLAT_CENTERS, ivfcenters); /* Store query string for workers */ if (debug_query_string) { char *sharedquery; sharedquery = (char *) shm_toc_allocate(pcxt->toc, querylen + 1); memcpy(sharedquery, debug_query_string, querylen + 1); shm_toc_insert(pcxt->toc, PARALLEL_KEY_QUERY_TEXT, sharedquery); } /* Launch workers, saving status for leader/caller */ LaunchParallelWorkers(pcxt); ivfleader->pcxt = pcxt; ivfleader->nparticipanttuplesorts = pcxt->nworkers_launched; if (leaderparticipates) ivfleader->nparticipanttuplesorts++; ivfleader->ivfshared = ivfshared; ivfleader->sharedsort = sharedsort; ivfleader->snapshot = snapshot; ivfleader->ivfcenters = ivfcenters; /* If no workers were successfully launched, back out (do serial build) */ if (pcxt->nworkers_launched == 0) { IvfflatEndParallel(ivfleader); return; } /* Log participants */ ereport(DEBUG1, (errmsg("using %d parallel workers", pcxt->nworkers_launched))); /* Save leader state now that it's clear build will be parallel */ buildstate->ivfleader = ivfleader; /* Join heap scan ourselves */ if (leaderparticipates) IvfflatLeaderParticipateAsWorker(buildstate); /* 
Wait for all launched workers */ WaitForParallelWorkersToAttach(pcxt); } /* * Scan table for tuples to index */ static void AssignTuples(IvfflatBuildState * buildstate) { int parallel_workers = 0; SortCoordinate coordinate = NULL; /* Sort options, which must match IvfflatParallelScanAndSort */ AttrNumber attNums[] = {1}; Oid sortOperators[] = {Int4LessOperator}; Oid sortCollations[] = {InvalidOid}; bool nullsFirstFlags[] = {false}; pgstat_progress_update_param(PROGRESS_CREATEIDX_SUBPHASE, PROGRESS_IVFFLAT_PHASE_ASSIGN); /* Calculate parallel workers */ if (buildstate->heap != NULL) parallel_workers = plan_create_index_workers(RelationGetRelid(buildstate->heap), RelationGetRelid(buildstate->index)); /* Attempt to launch parallel worker scan when required */ if (parallel_workers > 0) IvfflatBeginParallel(buildstate, buildstate->indexInfo->ii_Concurrent, parallel_workers); /* Set up coordination state if at least one worker launched */ if (buildstate->ivfleader) { coordinate = (SortCoordinate) palloc0(sizeof(SortCoordinateData)); coordinate->isWorker = false; coordinate->nParticipants = buildstate->ivfleader->nparticipanttuplesorts; coordinate->sharedsort = buildstate->ivfleader->sharedsort; } /* Begin serial/leader tuplesort */ buildstate->sortstate = tuplesort_begin_heap(buildstate->tupdesc, 1, attNums, sortOperators, sortCollations, nullsFirstFlags, maintenance_work_mem, coordinate, false); /* Add tuples to sort */ if (buildstate->heap != NULL) { if (buildstate->ivfleader) buildstate->reltuples = ParallelHeapScan(buildstate); else buildstate->reltuples = table_index_build_scan(buildstate->heap, buildstate->index, buildstate->indexInfo, true, true, BuildCallback, (void *) buildstate, NULL); #ifdef IVFFLAT_KMEANS_DEBUG PrintKmeansMetrics(buildstate); #endif } } /* * Create entry pages */ static void CreateEntryPages(IvfflatBuildState * buildstate, ForkNumber forkNum) { /* Assign */ IvfflatBench("assign tuples", AssignTuples(buildstate)); /* Sort */ IvfflatBench("sort 
tuples", tuplesort_performsort(buildstate->sortstate)); /* Load */ IvfflatBench("load tuples", InsertTuples(buildstate->index, buildstate, forkNum)); /* End sort */ tuplesort_end(buildstate->sortstate); /* End parallel build */ if (buildstate->ivfleader) IvfflatEndParallel(buildstate->ivfleader); } /* * Build the index */ static void BuildIndex(Relation heap, Relation index, IndexInfo *indexInfo, IvfflatBuildState * buildstate, ForkNumber forkNum) { InitBuildState(buildstate, heap, index, indexInfo); ComputeCenters(buildstate); /* Create pages */ CreateMetaPage(index, buildstate->dimensions, buildstate->lists, forkNum); CreateListPages(index, buildstate->centers, buildstate->dimensions, buildstate->lists, forkNum, &buildstate->listInfo); CreateEntryPages(buildstate, forkNum); FreeBuildState(buildstate); } /* * Build the index for a logged table */ IndexBuildResult * ivfflatbuild(Relation heap, Relation index, IndexInfo *indexInfo) { IndexBuildResult *result; IvfflatBuildState buildstate; BuildIndex(heap, index, indexInfo, &buildstate, MAIN_FORKNUM); result = (IndexBuildResult *) palloc(sizeof(IndexBuildResult)); result->heap_tuples = buildstate.reltuples; result->index_tuples = buildstate.indtuples; return result; } /* * Build the index for an unlogged table */ void ivfflatbuildempty(Relation index) { IndexInfo *indexInfo = BuildIndexInfo(index); IvfflatBuildState buildstate; BuildIndex(NULL, index, indexInfo, &buildstate, INIT_FORKNUM); } pgvector-0.6.0/src/ivfflat.c000066400000000000000000000153631455577216400157640ustar00rootroot00000000000000#include "postgres.h" #include #include "access/amapi.h" #include "access/reloptions.h" #include "commands/progress.h" #include "commands/vacuum.h" #include "ivfflat.h" #include "utils/guc.h" #include "utils/selfuncs.h" #include "utils/spccache.h" #if PG_VERSION_NUM < 150000 #define MarkGUCPrefixReserved(x) EmitWarningsOnPlaceholders(x) #endif int ivfflat_probes; static relopt_kind ivfflat_relopt_kind; /* * Initialize index 
options and variables */ void IvfflatInit(void) { ivfflat_relopt_kind = add_reloption_kind(); add_int_reloption(ivfflat_relopt_kind, "lists", "Number of inverted lists", IVFFLAT_DEFAULT_LISTS, IVFFLAT_MIN_LISTS, IVFFLAT_MAX_LISTS #if PG_VERSION_NUM >= 130000 ,AccessExclusiveLock #endif ); DefineCustomIntVariable("ivfflat.probes", "Sets the number of probes", "Valid range is 1..lists.", &ivfflat_probes, IVFFLAT_DEFAULT_PROBES, IVFFLAT_MIN_LISTS, IVFFLAT_MAX_LISTS, PGC_USERSET, 0, NULL, NULL, NULL); MarkGUCPrefixReserved("ivfflat"); } /* * Get the name of index build phase */ static char * ivfflatbuildphasename(int64 phasenum) { switch (phasenum) { case PROGRESS_CREATEIDX_SUBPHASE_INITIALIZE: return "initializing"; case PROGRESS_IVFFLAT_PHASE_KMEANS: return "performing k-means"; case PROGRESS_IVFFLAT_PHASE_ASSIGN: return "assigning tuples"; case PROGRESS_IVFFLAT_PHASE_LOAD: return "loading tuples"; default: return NULL; } } /* * Estimate the cost of an index scan */ static void ivfflatcostestimate(PlannerInfo *root, IndexPath *path, double loop_count, Cost *indexStartupCost, Cost *indexTotalCost, Selectivity *indexSelectivity, double *indexCorrelation, double *indexPages) { GenericCosts costs; int lists; double ratio; double spc_seq_page_cost; Relation index; /* Never use index without order */ if (path->indexorderbys == NULL) { *indexStartupCost = DBL_MAX; *indexTotalCost = DBL_MAX; *indexSelectivity = 0; *indexCorrelation = 0; *indexPages = 0; return; } MemSet(&costs, 0, sizeof(costs)); index = index_open(path->indexinfo->indexoid, NoLock); IvfflatGetMetaPageInfo(index, &lists, NULL); index_close(index, NoLock); /* Get the ratio of lists that we need to visit */ ratio = ((double) ivfflat_probes) / lists; if (ratio > 1.0) ratio = 1.0; /* * This gives us the subset of tuples to visit. This value is passed into * the generic cost estimator to determine the number of pages to visit * during the index scan. 
*/ costs.numIndexTuples = path->indexinfo->tuples * ratio; genericcostestimate(root, path, loop_count, &costs); get_tablespace_page_costs(path->indexinfo->reltablespace, NULL, &spc_seq_page_cost); /* Adjust cost if needed since TOAST not included in seq scan cost */ if (costs.numIndexPages > path->indexinfo->rel->pages && ratio < 0.5) { /* Change all page cost from random to sequential */ costs.indexTotalCost -= costs.numIndexPages * (costs.spc_random_page_cost - spc_seq_page_cost); /* Remove cost of extra pages */ costs.indexTotalCost -= (costs.numIndexPages - path->indexinfo->rel->pages) * spc_seq_page_cost; } else { /* Change some page cost from random to sequential */ costs.indexTotalCost -= 0.5 * costs.numIndexPages * (costs.spc_random_page_cost - spc_seq_page_cost); } /* * If the list selectivity is lower than what is returned from the generic * cost estimator, use that. */ if (ratio < costs.indexSelectivity) costs.indexSelectivity = ratio; /* Use total cost since most work happens before first tuple is returned */ *indexStartupCost = costs.indexTotalCost; *indexTotalCost = costs.indexTotalCost; *indexSelectivity = costs.indexSelectivity; *indexCorrelation = costs.indexCorrelation; *indexPages = costs.numIndexPages; } /* * Parse and validate the reloptions */ static bytea * ivfflatoptions(Datum reloptions, bool validate) { static const relopt_parse_elt tab[] = { {"lists", RELOPT_TYPE_INT, offsetof(IvfflatOptions, lists)}, }; #if PG_VERSION_NUM >= 130000 return (bytea *) build_reloptions(reloptions, validate, ivfflat_relopt_kind, sizeof(IvfflatOptions), tab, lengthof(tab)); #else relopt_value *options; int numoptions; IvfflatOptions *rdopts; options = parseRelOptions(reloptions, validate, ivfflat_relopt_kind, &numoptions); rdopts = allocateReloptStruct(sizeof(IvfflatOptions), options, numoptions); fillRelOptions((void *) rdopts, sizeof(IvfflatOptions), options, numoptions, validate, tab, lengthof(tab)); return (bytea *) rdopts; #endif } /* * Validate catalog 
entries for the specified operator class */ static bool ivfflatvalidate(Oid opclassoid) { return true; } /* * Define index handler * * See https://www.postgresql.org/docs/current/index-api.html */ PGDLLEXPORT PG_FUNCTION_INFO_V1(ivfflathandler); Datum ivfflathandler(PG_FUNCTION_ARGS) { IndexAmRoutine *amroutine = makeNode(IndexAmRoutine); amroutine->amstrategies = 0; amroutine->amsupport = 4; #if PG_VERSION_NUM >= 130000 amroutine->amoptsprocnum = 0; #endif amroutine->amcanorder = false; amroutine->amcanorderbyop = true; amroutine->amcanbackward = false; /* can change direction mid-scan */ amroutine->amcanunique = false; amroutine->amcanmulticol = false; amroutine->amoptionalkey = true; amroutine->amsearcharray = false; amroutine->amsearchnulls = false; amroutine->amstorage = false; amroutine->amclusterable = false; amroutine->ampredlocks = false; amroutine->amcanparallel = false; amroutine->amcaninclude = false; #if PG_VERSION_NUM >= 130000 amroutine->amusemaintenanceworkmem = false; /* not used during VACUUM */ amroutine->amparallelvacuumoptions = VACUUM_OPTION_PARALLEL_BULKDEL; #endif amroutine->amkeytype = InvalidOid; /* Interface functions */ amroutine->ambuild = ivfflatbuild; amroutine->ambuildempty = ivfflatbuildempty; amroutine->aminsert = ivfflatinsert; amroutine->ambulkdelete = ivfflatbulkdelete; amroutine->amvacuumcleanup = ivfflatvacuumcleanup; amroutine->amcanreturn = NULL; /* tuple not included in heapsort */ amroutine->amcostestimate = ivfflatcostestimate; amroutine->amoptions = ivfflatoptions; amroutine->amproperty = NULL; /* TODO AMPROP_DISTANCE_ORDERABLE */ amroutine->ambuildphasename = ivfflatbuildphasename; amroutine->amvalidate = ivfflatvalidate; #if PG_VERSION_NUM >= 140000 amroutine->amadjustmembers = NULL; #endif amroutine->ambeginscan = ivfflatbeginscan; amroutine->amrescan = ivfflatrescan; amroutine->amgettuple = ivfflatgettuple; amroutine->amgetbitmap = NULL; amroutine->amendscan = ivfflatendscan; amroutine->ammarkpos = NULL; 
amroutine->amrestrpos = NULL; /* Interface functions to support parallel index scans */ amroutine->amestimateparallelscan = NULL; amroutine->aminitparallelscan = NULL; amroutine->amparallelrescan = NULL; PG_RETURN_POINTER(amroutine); } pgvector-0.6.0/src/ivfflat.h000066400000000000000000000173221455577216400157660ustar00rootroot00000000000000#ifndef IVFFLAT_H #define IVFFLAT_H #include "postgres.h" #include "access/genam.h" #include "access/generic_xlog.h" #include "access/parallel.h" #include "lib/pairingheap.h" #include "nodes/execnodes.h" #include "port.h" /* for random() */ #include "utils/sampling.h" #include "utils/tuplesort.h" #include "vector.h" #if PG_VERSION_NUM >= 150000 #include "common/pg_prng.h" #endif #ifdef IVFFLAT_BENCH #include "portability/instr_time.h" #endif #define IVFFLAT_MAX_DIM 2000 /* Support functions */ #define IVFFLAT_DISTANCE_PROC 1 #define IVFFLAT_NORM_PROC 2 #define IVFFLAT_KMEANS_DISTANCE_PROC 3 #define IVFFLAT_KMEANS_NORM_PROC 4 #define IVFFLAT_VERSION 1 #define IVFFLAT_MAGIC_NUMBER 0x14FF1A7 #define IVFFLAT_PAGE_ID 0xFF84 /* Preserved page numbers */ #define IVFFLAT_METAPAGE_BLKNO 0 #define IVFFLAT_HEAD_BLKNO 1 /* first list page */ /* IVFFlat parameters */ #define IVFFLAT_DEFAULT_LISTS 100 #define IVFFLAT_MIN_LISTS 1 #define IVFFLAT_MAX_LISTS 32768 #define IVFFLAT_DEFAULT_PROBES 1 /* Build phases */ /* PROGRESS_CREATEIDX_SUBPHASE_INITIALIZE is 1 */ #define PROGRESS_IVFFLAT_PHASE_KMEANS 2 #define PROGRESS_IVFFLAT_PHASE_ASSIGN 3 #define PROGRESS_IVFFLAT_PHASE_LOAD 4 #define IVFFLAT_LIST_SIZE(_dim) (offsetof(IvfflatListData, center) + VECTOR_SIZE(_dim)) #define IvfflatPageGetOpaque(page) ((IvfflatPageOpaque) PageGetSpecialPointer(page)) #define IvfflatPageGetMeta(page) ((IvfflatMetaPageData *) PageGetContents(page)) #ifdef IVFFLAT_BENCH #define IvfflatBench(name, code) \ do { \ instr_time start; \ instr_time duration; \ INSTR_TIME_SET_CURRENT(start); \ (code); \ INSTR_TIME_SET_CURRENT(duration); \ INSTR_TIME_SUBTRACT(duration, 
start); \ elog(INFO, "%s: %.3f ms", name, INSTR_TIME_GET_MILLISEC(duration)); \ } while (0) #else #define IvfflatBench(name, code) (code) #endif #if PG_VERSION_NUM >= 150000 #define RandomDouble() pg_prng_double(&pg_global_prng_state) #define RandomInt() pg_prng_uint32(&pg_global_prng_state) #else #define RandomDouble() (((double) random()) / MAX_RANDOM_VALUE) #define RandomInt() random() #endif /* Variables */ extern int ivfflat_probes; typedef struct VectorArrayData { int length; int maxlen; int dim; Vector *items; } VectorArrayData; typedef VectorArrayData * VectorArray; typedef struct ListInfo { BlockNumber blkno; OffsetNumber offno; } ListInfo; /* IVFFlat index options */ typedef struct IvfflatOptions { int32 vl_len_; /* varlena header (do not touch directly!) */ int lists; /* number of lists */ } IvfflatOptions; typedef struct IvfflatSpool { Tuplesortstate *sortstate; Relation heap; Relation index; } IvfflatSpool; typedef struct IvfflatShared { /* Immutable state */ Oid heaprelid; Oid indexrelid; bool isconcurrent; int scantuplesortstates; /* Worker progress */ ConditionVariable workersdonecv; /* Mutex for mutable state */ slock_t mutex; /* Mutable state */ int nparticipantsdone; double reltuples; double indtuples; #ifdef IVFFLAT_KMEANS_DEBUG double inertia; #endif } IvfflatShared; #define ParallelTableScanFromIvfflatShared(shared) \ (ParallelTableScanDesc) ((char *) (shared) + BUFFERALIGN(sizeof(IvfflatShared))) typedef struct IvfflatLeader { ParallelContext *pcxt; int nparticipanttuplesorts; IvfflatShared *ivfshared; Sharedsort *sharedsort; Snapshot snapshot; Vector *ivfcenters; } IvfflatLeader; typedef struct IvfflatBuildState { /* Info */ Relation heap; Relation index; IndexInfo *indexInfo; /* Settings */ int dimensions; int lists; /* Statistics */ double indtuples; double reltuples; /* Support functions */ FmgrInfo *procinfo; FmgrInfo *normprocinfo; FmgrInfo *kmeansnormprocinfo; Oid collation; /* Variables */ VectorArray samples; VectorArray centers; 
ListInfo *listInfo; Vector *normvec; #ifdef IVFFLAT_KMEANS_DEBUG double inertia; double *listSums; int *listCounts; #endif /* Sampling */ BlockSamplerData bs; ReservoirStateData rstate; int rowstoskip; /* Sorting */ Tuplesortstate *sortstate; TupleDesc tupdesc; TupleTableSlot *slot; /* Memory */ MemoryContext tmpCtx; /* Parallel builds */ IvfflatLeader *ivfleader; } IvfflatBuildState; typedef struct IvfflatMetaPageData { uint32 magicNumber; uint32 version; uint16 dimensions; uint16 lists; } IvfflatMetaPageData; typedef IvfflatMetaPageData * IvfflatMetaPage; typedef struct IvfflatPageOpaqueData { BlockNumber nextblkno; uint16 unused; uint16 page_id; /* for identification of IVFFlat indexes */ } IvfflatPageOpaqueData; typedef IvfflatPageOpaqueData * IvfflatPageOpaque; typedef struct IvfflatListData { BlockNumber startPage; BlockNumber insertPage; Vector center; } IvfflatListData; typedef IvfflatListData * IvfflatList; typedef struct IvfflatScanList { pairingheap_node ph_node; BlockNumber startPage; double distance; } IvfflatScanList; typedef struct IvfflatScanOpaqueData { int probes; int dimensions; bool first; /* Sorting */ Tuplesortstate *sortstate; TupleDesc tupdesc; TupleTableSlot *slot; bool isnull; /* Support functions */ FmgrInfo *procinfo; FmgrInfo *normprocinfo; Oid collation; /* Lists */ pairingheap *listQueue; IvfflatScanList lists[FLEXIBLE_ARRAY_MEMBER]; /* must come last */ } IvfflatScanOpaqueData; typedef IvfflatScanOpaqueData * IvfflatScanOpaque; #define VECTOR_ARRAY_SIZE(_length, _dim) (sizeof(VectorArrayData) + (_length) * VECTOR_SIZE(_dim)) #define VECTOR_ARRAY_OFFSET(_arr, _offset) ((char*) (_arr)->items + (_offset) * VECTOR_SIZE((_arr)->dim)) #define VectorArrayGet(_arr, _offset) ((Vector *) VECTOR_ARRAY_OFFSET(_arr, _offset)) #define VectorArraySet(_arr, _offset, _val) memcpy(VECTOR_ARRAY_OFFSET(_arr, _offset), _val, VECTOR_SIZE((_arr)->dim)) /* Methods */ VectorArray VectorArrayInit(int maxlen, int dimensions); void VectorArrayFree(VectorArray 
arr); void PrintVectorArray(char *msg, VectorArray arr); void IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers); FmgrInfo *IvfflatOptionalProcInfo(Relation index, uint16 procnum); bool IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, Vector * result); int IvfflatGetLists(Relation index); void IvfflatGetMetaPageInfo(Relation index, int *lists, int *dimensions); void IvfflatUpdateList(Relation index, ListInfo listInfo, BlockNumber insertPage, BlockNumber originalInsertPage, BlockNumber startPage, ForkNumber forkNum); void IvfflatCommitBuffer(Buffer buf, GenericXLogState *state); void IvfflatAppendPage(Relation index, Buffer *buf, Page *page, GenericXLogState **state, ForkNumber forkNum); Buffer IvfflatNewBuffer(Relation index, ForkNumber forkNum); void IvfflatInitPage(Buffer buf, Page page); void IvfflatInitRegisterPage(Relation index, Buffer *buf, Page *page, GenericXLogState **state); void IvfflatInit(void); PGDLLEXPORT void IvfflatParallelBuildMain(dsm_segment *seg, shm_toc *toc); /* Index access methods */ IndexBuildResult *ivfflatbuild(Relation heap, Relation index, IndexInfo *indexInfo); void ivfflatbuildempty(Relation index); bool ivfflatinsert(Relation index, Datum *values, bool *isnull, ItemPointer heap_tid, Relation heap, IndexUniqueCheck checkUnique #if PG_VERSION_NUM >= 140000 ,bool indexUnchanged #endif ,IndexInfo *indexInfo ); IndexBulkDeleteResult *ivfflatbulkdelete(IndexVacuumInfo *info, IndexBulkDeleteResult *stats, IndexBulkDeleteCallback callback, void *callback_state); IndexBulkDeleteResult *ivfflatvacuumcleanup(IndexVacuumInfo *info, IndexBulkDeleteResult *stats); IndexScanDesc ivfflatbeginscan(Relation index, int nkeys, int norderbys); void ivfflatrescan(IndexScanDesc scan, ScanKey keys, int nkeys, ScanKey orderbys, int norderbys); bool ivfflatgettuple(IndexScanDesc scan, ScanDirection dir); void ivfflatendscan(IndexScanDesc scan); #endif 
pgvector-0.6.0/src/ivfinsert.c000066400000000000000000000122361455577216400163360ustar00rootroot00000000000000#include "postgres.h" #include #include "access/generic_xlog.h" #include "ivfflat.h" #include "storage/bufmgr.h" #include "storage/lmgr.h" #include "utils/memutils.h" /* * Find the list that minimizes the distance function */ static void FindInsertPage(Relation index, Datum *values, BlockNumber *insertPage, ListInfo * listInfo) { double minDistance = DBL_MAX; BlockNumber nextblkno = IVFFLAT_HEAD_BLKNO; FmgrInfo *procinfo; Oid collation; /* Avoid compiler warning */ listInfo->blkno = nextblkno; listInfo->offno = FirstOffsetNumber; procinfo = index_getprocinfo(index, 1, IVFFLAT_DISTANCE_PROC); collation = index->rd_indcollation[0]; /* Search all list pages */ while (BlockNumberIsValid(nextblkno)) { Buffer cbuf; Page cpage; OffsetNumber maxoffno; cbuf = ReadBuffer(index, nextblkno); LockBuffer(cbuf, BUFFER_LOCK_SHARE); cpage = BufferGetPage(cbuf); maxoffno = PageGetMaxOffsetNumber(cpage); for (OffsetNumber offno = FirstOffsetNumber; offno <= maxoffno; offno = OffsetNumberNext(offno)) { IvfflatList list; double distance; list = (IvfflatList) PageGetItem(cpage, PageGetItemId(cpage, offno)); distance = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, values[0], PointerGetDatum(&list->center))); if (distance < minDistance || !BlockNumberIsValid(*insertPage)) { *insertPage = list->insertPage; listInfo->blkno = nextblkno; listInfo->offno = offno; minDistance = distance; } } nextblkno = IvfflatPageGetOpaque(cpage)->nextblkno; UnlockReleaseBuffer(cbuf); } } /* * Insert a tuple into the index */ static void InsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_tid, Relation heapRel) { IndexTuple itup; Datum value; FmgrInfo *normprocinfo; Buffer buf; Page page; GenericXLogState *state; Size itemsz; BlockNumber insertPage = InvalidBlockNumber; ListInfo listInfo; BlockNumber originalInsertPage; /* Detoast once for all calls */ value = 
PointerGetDatum(PG_DETOAST_DATUM(values[0])); /* Normalize if needed */ normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORM_PROC); if (normprocinfo != NULL) { if (!IvfflatNormValue(normprocinfo, index->rd_indcollation[0], &value, NULL)) return; } /* Find the insert page - sets the page and list info */ FindInsertPage(index, values, &insertPage, &listInfo); Assert(BlockNumberIsValid(insertPage)); originalInsertPage = insertPage; /* Form tuple */ itup = index_form_tuple(RelationGetDescr(index), &value, isnull); itup->t_tid = *heap_tid; /* Get tuple size */ itemsz = MAXALIGN(IndexTupleSize(itup)); Assert(itemsz <= BLCKSZ - MAXALIGN(SizeOfPageHeaderData) - MAXALIGN(sizeof(IvfflatPageOpaqueData)) - sizeof(ItemIdData)); /* Find a page to insert the item */ for (;;) { buf = ReadBuffer(index, insertPage); LockBuffer(buf, BUFFER_LOCK_EXCLUSIVE); state = GenericXLogStart(index); page = GenericXLogRegisterBuffer(state, buf, 0); if (PageGetFreeSpace(page) >= itemsz) break; insertPage = IvfflatPageGetOpaque(page)->nextblkno; if (BlockNumberIsValid(insertPage)) { /* Move to next page */ GenericXLogAbort(state); UnlockReleaseBuffer(buf); } else { Buffer newbuf; Page newpage; /* Add a new page */ LockRelationForExtension(index, ExclusiveLock); newbuf = IvfflatNewBuffer(index, MAIN_FORKNUM); UnlockRelationForExtension(index, ExclusiveLock); /* Init new page */ newpage = GenericXLogRegisterBuffer(state, newbuf, GENERIC_XLOG_FULL_IMAGE); IvfflatInitPage(newbuf, newpage); /* Update insert page */ insertPage = BufferGetBlockNumber(newbuf); /* Update previous buffer */ IvfflatPageGetOpaque(page)->nextblkno = insertPage; /* Commit */ GenericXLogFinish(state); /* Unlock previous buffer */ UnlockReleaseBuffer(buf); /* Prepare new buffer */ state = GenericXLogStart(index); buf = newbuf; page = GenericXLogRegisterBuffer(state, buf, 0); break; } } /* Add to next offset */ if (PageAddItem(page, (Item) itup, itemsz, InvalidOffsetNumber, false, false) == InvalidOffsetNumber) elog(ERROR, 
"failed to add index item to \"%s\"", RelationGetRelationName(index)); IvfflatCommitBuffer(buf, state); /* Update the insert page */ if (insertPage != originalInsertPage) IvfflatUpdateList(index, listInfo, insertPage, originalInsertPage, InvalidBlockNumber, MAIN_FORKNUM); } /* * Insert a tuple into the index */ bool ivfflatinsert(Relation index, Datum *values, bool *isnull, ItemPointer heap_tid, Relation heap, IndexUniqueCheck checkUnique #if PG_VERSION_NUM >= 140000 ,bool indexUnchanged #endif ,IndexInfo *indexInfo ) { MemoryContext oldCtx; MemoryContext insertCtx; /* Skip nulls */ if (isnull[0]) return false; /* * Use memory context since detoast, IvfflatNormValue, and * index_form_tuple can allocate */ insertCtx = AllocSetContextCreate(CurrentMemoryContext, "Ivfflat insert temporary context", ALLOCSET_DEFAULT_SIZES); oldCtx = MemoryContextSwitchTo(insertCtx); /* Insert tuple */ InsertTuple(index, values, isnull, heap_tid, heap); /* Delete memory context */ MemoryContextSwitchTo(oldCtx); MemoryContextDelete(insertCtx); return false; } pgvector-0.6.0/src/ivfkmeans.c000066400000000000000000000333551455577216400163150ustar00rootroot00000000000000#include "postgres.h" #include #include #include "ivfflat.h" #include "miscadmin.h" #ifdef IVFFLAT_MEMORY #include "utils/memutils.h" #endif /* * Initialize with kmeans++ * * https://theory.stanford.edu/~sergei/papers/kMeansPP-soda.pdf */ static void InitCenters(Relation index, VectorArray samples, VectorArray centers, float *lowerBound) { FmgrInfo *procinfo; Oid collation; int64 j; float *weight = palloc(samples->length * sizeof(float)); int numCenters = centers->maxlen; int numSamples = samples->length; procinfo = index_getprocinfo(index, 1, IVFFLAT_KMEANS_DISTANCE_PROC); collation = index->rd_indcollation[0]; /* Choose an initial center uniformly at random */ VectorArraySet(centers, 0, VectorArrayGet(samples, RandomInt() % samples->length)); centers->length++; for (j = 0; j < numSamples; j++) weight[j] = FLT_MAX; for (int 
i = 0; i < numCenters; i++) { double sum; double choice; CHECK_FOR_INTERRUPTS(); sum = 0.0; for (j = 0; j < numSamples; j++) { Vector *vec = VectorArrayGet(samples, j); double distance; /* Only need to compute distance for new center */ /* TODO Use triangle inequality to reduce distance calculations */ distance = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(vec), PointerGetDatum(VectorArrayGet(centers, i)))); /* Set lower bound */ lowerBound[j * numCenters + i] = distance; /* Use distance squared for weighted probability distribution */ distance *= distance; if (distance < weight[j]) weight[j] = distance; sum += weight[j]; } /* Only compute lower bound on last iteration */ if (i + 1 == numCenters) break; /* Choose new center using weighted probability distribution. */ choice = sum * RandomDouble(); for (j = 0; j < numSamples - 1; j++) { choice -= weight[j]; if (choice <= 0) break; } VectorArraySet(centers, i + 1, VectorArrayGet(samples, j)); centers->length++; } pfree(weight); } /* * Apply norm to vector */ static inline void ApplyNorm(FmgrInfo *normprocinfo, Oid collation, Vector * vec) { double norm = DatumGetFloat8(FunctionCall1Coll(normprocinfo, collation, PointerGetDatum(vec))); /* TODO Handle zero norm */ if (norm > 0) { for (int i = 0; i < vec->dim; i++) vec->x[i] /= norm; } } /* * Compare vectors */ static int CompareVectors(const void *a, const void *b) { return vector_cmp_internal((Vector *) a, (Vector *) b); } /* * Quick approach if we have little data */ static void QuickCenters(Relation index, VectorArray samples, VectorArray centers) { int dimensions = centers->dim; Oid collation = index->rd_indcollation[0]; FmgrInfo *normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_KMEANS_NORM_PROC); /* Copy existing vectors while avoiding duplicates */ if (samples->length > 0) { qsort(samples->items, samples->length, VECTOR_SIZE(samples->dim), CompareVectors); for (int i = 0; i < samples->length; i++) { Vector *vec = 
VectorArrayGet(samples, i); if (i == 0 || CompareVectors(vec, VectorArrayGet(samples, i - 1)) != 0) { VectorArraySet(centers, centers->length, vec); centers->length++; } } } /* Fill remaining with random data */ while (centers->length < centers->maxlen) { Vector *vec = VectorArrayGet(centers, centers->length); SET_VARSIZE(vec, VECTOR_SIZE(dimensions)); vec->dim = dimensions; for (int j = 0; j < dimensions; j++) vec->x[j] = RandomDouble(); /* Normalize if needed (only needed for random centers) */ if (normprocinfo != NULL) ApplyNorm(normprocinfo, collation, vec); centers->length++; } } #ifdef IVFFLAT_MEMORY /* * Show memory usage */ static void ShowMemoryUsage(Size estimatedSize) { #if PG_VERSION_NUM >= 130000 elog(INFO, "total memory: %zu MB", MemoryContextMemAllocated(CurrentMemoryContext, true) / (1024 * 1024)); #else MemoryContextStats(CurrentMemoryContext); #endif elog(INFO, "estimated memory: %zu MB", estimatedSize / (1024 * 1024)); } #endif /* * Use Elkan for performance. This requires distance function to satisfy triangle inequality. 
* * We use L2 distance for L2 (not L2 squared like index scan) * and angular distance for inner product and cosine distance * * https://www.aaai.org/Papers/ICML/2003/ICML03-022.pdf */ static void ElkanKmeans(Relation index, VectorArray samples, VectorArray centers) { FmgrInfo *procinfo; FmgrInfo *normprocinfo; Oid collation; Vector *vec; Vector *newCenter; int64 j; int64 k; int dimensions = centers->dim; int numCenters = centers->maxlen; int numSamples = samples->length; VectorArray newCenters; int *centerCounts; int *closestCenters; float *lowerBound; float *upperBound; float *s; float *halfcdist; float *newcdist; /* Calculate allocation sizes */ Size samplesSize = VECTOR_ARRAY_SIZE(samples->maxlen, samples->dim); Size centersSize = VECTOR_ARRAY_SIZE(centers->maxlen, centers->dim); Size newCentersSize = VECTOR_ARRAY_SIZE(numCenters, dimensions); Size centerCountsSize = sizeof(int) * numCenters; Size closestCentersSize = sizeof(int) * numSamples; Size lowerBoundSize = sizeof(float) * numSamples * numCenters; Size upperBoundSize = sizeof(float) * numSamples; Size sSize = sizeof(float) * numCenters; Size halfcdistSize = sizeof(float) * numCenters * numCenters; Size newcdistSize = sizeof(float) * numCenters; /* Calculate total size */ Size totalSize = samplesSize + centersSize + newCentersSize + centerCountsSize + closestCentersSize + lowerBoundSize + upperBoundSize + sSize + halfcdistSize + newcdistSize; /* Check memory requirements */ /* Add one to error message to ceil */ if (totalSize > (Size) maintenance_work_mem * 1024L) ereport(ERROR, (errcode(ERRCODE_PROGRAM_LIMIT_EXCEEDED), errmsg("memory required is %zu MB, maintenance_work_mem is %d MB", totalSize / (1024 * 1024) + 1, maintenance_work_mem / 1024))); /* Ensure indexing does not overflow */ if (numCenters * numCenters > INT_MAX) elog(ERROR, "Indexing overflow detected. 
Please report a bug."); /* Set support functions */ procinfo = index_getprocinfo(index, 1, IVFFLAT_KMEANS_DISTANCE_PROC); normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_KMEANS_NORM_PROC); collation = index->rd_indcollation[0]; /* Allocate space */ /* Use float instead of double to save memory */ centerCounts = palloc(centerCountsSize); closestCenters = palloc(closestCentersSize); lowerBound = palloc_extended(lowerBoundSize, MCXT_ALLOC_HUGE); upperBound = palloc(upperBoundSize); s = palloc(sSize); halfcdist = palloc_extended(halfcdistSize, MCXT_ALLOC_HUGE); newcdist = palloc(newcdistSize); newCenters = VectorArrayInit(numCenters, dimensions); for (j = 0; j < numCenters; j++) { vec = VectorArrayGet(newCenters, j); SET_VARSIZE(vec, VECTOR_SIZE(dimensions)); vec->dim = dimensions; } #ifdef IVFFLAT_MEMORY ShowMemoryUsage(totalSize); #endif /* Pick initial centers */ InitCenters(index, samples, centers, lowerBound); /* Assign each x to its closest initial center c(x) = argmin d(x,c) */ for (j = 0; j < numSamples; j++) { float minDistance = FLT_MAX; int closestCenter = 0; /* Find closest center */ for (k = 0; k < numCenters; k++) { /* TODO Use Lemma 1 in k-means++ initialization */ float distance = lowerBound[j * numCenters + k]; if (distance < minDistance) { minDistance = distance; closestCenter = k; } } upperBound[j] = minDistance; closestCenters[j] = closestCenter; } /* Give 500 iterations to converge */ for (int iteration = 0; iteration < 500; iteration++) { int changes = 0; bool rjreset; /* Can take a while, so ensure we can interrupt */ CHECK_FOR_INTERRUPTS(); /* Step 1: For all centers, compute distance */ for (j = 0; j < numCenters; j++) { vec = VectorArrayGet(centers, j); for (k = j + 1; k < numCenters; k++) { float distance = 0.5 * DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(vec), PointerGetDatum(VectorArrayGet(centers, k)))); halfcdist[j * numCenters + k] = distance; halfcdist[k * numCenters + j] = distance; } } /* For all 
centers c, compute s(c) */ for (j = 0; j < numCenters; j++) { float minDistance = FLT_MAX; for (k = 0; k < numCenters; k++) { float distance; if (j == k) continue; distance = halfcdist[j * numCenters + k]; if (distance < minDistance) minDistance = distance; } s[j] = minDistance; } rjreset = iteration != 0; for (j = 0; j < numSamples; j++) { bool rj; /* Step 2: Identify all points x such that u(x) <= s(c(x)) */ if (upperBound[j] <= s[closestCenters[j]]) continue; rj = rjreset; for (k = 0; k < numCenters; k++) { float dxcx; /* Step 3: For all remaining points x and centers c */ if (k == closestCenters[j]) continue; if (upperBound[j] <= lowerBound[j * numCenters + k]) continue; if (upperBound[j] <= halfcdist[closestCenters[j] * numCenters + k]) continue; vec = VectorArrayGet(samples, j); /* Step 3a */ if (rj) { dxcx = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(vec), PointerGetDatum(VectorArrayGet(centers, closestCenters[j])))); /* d(x,c(x)) computed, which is a form of d(x,c) */ lowerBound[j * numCenters + closestCenters[j]] = dxcx; upperBound[j] = dxcx; rj = false; } else dxcx = upperBound[j]; /* Step 3b */ if (dxcx > lowerBound[j * numCenters + k] || dxcx > halfcdist[closestCenters[j] * numCenters + k]) { float dxc = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(vec), PointerGetDatum(VectorArrayGet(centers, k)))); /* d(x,c) calculated */ lowerBound[j * numCenters + k] = dxc; if (dxc < dxcx) { closestCenters[j] = k; /* c(x) changed */ upperBound[j] = dxc; changes++; } } } } /* Step 4: For each center c, let m(c) be mean of all points assigned */ for (j = 0; j < numCenters; j++) { vec = VectorArrayGet(newCenters, j); for (k = 0; k < dimensions; k++) vec->x[k] = 0.0; centerCounts[j] = 0; } for (j = 0; j < numSamples; j++) { int closestCenter; vec = VectorArrayGet(samples, j); closestCenter = closestCenters[j]; /* Increment sum and count of closest center */ newCenter = VectorArrayGet(newCenters, closestCenter); for (k 
= 0; k < dimensions; k++) newCenter->x[k] += vec->x[k]; centerCounts[closestCenter] += 1; } for (j = 0; j < numCenters; j++) { vec = VectorArrayGet(newCenters, j); if (centerCounts[j] > 0) { /* Double avoids overflow, but requires more memory */ /* TODO Update bounds */ for (k = 0; k < dimensions; k++) { if (isinf(vec->x[k])) vec->x[k] = vec->x[k] > 0 ? FLT_MAX : -FLT_MAX; } for (k = 0; k < dimensions; k++) vec->x[k] /= centerCounts[j]; } else { /* TODO Handle empty centers properly */ for (k = 0; k < dimensions; k++) vec->x[k] = RandomDouble(); } /* Normalize if needed */ if (normprocinfo != NULL) ApplyNorm(normprocinfo, collation, vec); } /* Step 5 */ for (j = 0; j < numCenters; j++) newcdist[j] = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(VectorArrayGet(centers, j)), PointerGetDatum(VectorArrayGet(newCenters, j)))); for (j = 0; j < numSamples; j++) { for (k = 0; k < numCenters; k++) { float distance = lowerBound[j * numCenters + k] - newcdist[k]; if (distance < 0) distance = 0; lowerBound[j * numCenters + k] = distance; } } /* Step 6 */ /* We reset r(x) before Step 3 in the next iteration */ for (j = 0; j < numSamples; j++) upperBound[j] += newcdist[closestCenters[j]]; /* Step 7 */ for (j = 0; j < numCenters; j++) VectorArraySet(centers, j, VectorArrayGet(newCenters, j)); if (changes == 0 && iteration != 0) break; } VectorArrayFree(newCenters); pfree(centerCounts); pfree(closestCenters); pfree(lowerBound); pfree(upperBound); pfree(s); pfree(halfcdist); pfree(newcdist); } /* * Detect issues with centers */ static void CheckCenters(Relation index, VectorArray centers) { FmgrInfo *normprocinfo; if (centers->length != centers->maxlen) elog(ERROR, "Not enough centers. Please report a bug."); /* Ensure no NaN or infinite values */ for (int i = 0; i < centers->length; i++) { Vector *vec = VectorArrayGet(centers, i); for (int j = 0; j < vec->dim; j++) { if (isnan(vec->x[j])) elog(ERROR, "NaN detected. 
Please report a bug."); if (isinf(vec->x[j])) elog(ERROR, "Infinite value detected. Please report a bug."); } } /* Ensure no duplicate centers */ /* Fine to sort in-place */ qsort(centers->items, centers->length, VECTOR_SIZE(centers->dim), CompareVectors); for (int i = 1; i < centers->length; i++) { if (CompareVectors(VectorArrayGet(centers, i), VectorArrayGet(centers, i - 1)) == 0) elog(ERROR, "Duplicate centers detected. Please report a bug."); } /* Ensure no zero vectors for cosine distance */ /* Check NORM_PROC instead of KMEANS_NORM_PROC */ normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORM_PROC); if (normprocinfo != NULL) { Oid collation = index->rd_indcollation[0]; for (int i = 0; i < centers->length; i++) { double norm = DatumGetFloat8(FunctionCall1Coll(normprocinfo, collation, PointerGetDatum(VectorArrayGet(centers, i)))); if (norm == 0) elog(ERROR, "Zero norm detected. Please report a bug."); } } } /* * Perform naive k-means centering * We use spherical k-means for inner product and cosine */ void IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers) { if (samples->length <= centers->maxlen) QuickCenters(index, samples, centers); else ElkanKmeans(index, samples, centers); CheckCenters(index, centers); } pgvector-0.6.0/src/ivfscan.c000066400000000000000000000213311455577216400157520ustar00rootroot00000000000000#include "postgres.h" #include #include "access/relscan.h" #include "catalog/pg_operator_d.h" #include "catalog/pg_type_d.h" #include "lib/pairingheap.h" #include "ivfflat.h" #include "miscadmin.h" #include "pgstat.h" #include "storage/bufmgr.h" /* * Compare list distances */ static int CompareLists(const pairingheap_node *a, const pairingheap_node *b, void *arg) { if (((const IvfflatScanList *) a)->distance > ((const IvfflatScanList *) b)->distance) return 1; if (((const IvfflatScanList *) a)->distance < ((const IvfflatScanList *) b)->distance) return -1; return 0; } /* * Get lists and sort by distance */ static void 
GetScanLists(IndexScanDesc scan, Datum value) { IvfflatScanOpaque so = (IvfflatScanOpaque) scan->opaque; BlockNumber nextblkno = IVFFLAT_HEAD_BLKNO; int listCount = 0; double maxDistance = DBL_MAX; /* Search all list pages */ while (BlockNumberIsValid(nextblkno)) { Buffer cbuf; Page cpage; OffsetNumber maxoffno; cbuf = ReadBuffer(scan->indexRelation, nextblkno); LockBuffer(cbuf, BUFFER_LOCK_SHARE); cpage = BufferGetPage(cbuf); maxoffno = PageGetMaxOffsetNumber(cpage); for (OffsetNumber offno = FirstOffsetNumber; offno <= maxoffno; offno = OffsetNumberNext(offno)) { IvfflatList list = (IvfflatList) PageGetItem(cpage, PageGetItemId(cpage, offno)); double distance; /* Use procinfo from the index instead of scan key for performance */ distance = DatumGetFloat8(FunctionCall2Coll(so->procinfo, so->collation, PointerGetDatum(&list->center), value)); if (listCount < so->probes) { IvfflatScanList *scanlist; scanlist = &so->lists[listCount]; scanlist->startPage = list->startPage; scanlist->distance = distance; listCount++; /* Add to heap */ pairingheap_add(so->listQueue, &scanlist->ph_node); /* Calculate max distance */ if (listCount == so->probes) maxDistance = ((IvfflatScanList *) pairingheap_first(so->listQueue))->distance; } else if (distance < maxDistance) { IvfflatScanList *scanlist; /* Remove */ scanlist = (IvfflatScanList *) pairingheap_remove_first(so->listQueue); /* Reuse */ scanlist->startPage = list->startPage; scanlist->distance = distance; pairingheap_add(so->listQueue, &scanlist->ph_node); /* Update max distance */ maxDistance = ((IvfflatScanList *) pairingheap_first(so->listQueue))->distance; } } nextblkno = IvfflatPageGetOpaque(cpage)->nextblkno; UnlockReleaseBuffer(cbuf); } } /* * Get items */ static void GetScanItems(IndexScanDesc scan, Datum value) { IvfflatScanOpaque so = (IvfflatScanOpaque) scan->opaque; TupleDesc tupdesc = RelationGetDescr(scan->indexRelation); double tuples = 0; TupleTableSlot *slot = MakeSingleTupleTableSlot(so->tupdesc, 
&TTSOpsVirtual); /* * Reuse same set of shared buffers for scan * * See postgres/src/backend/storage/buffer/README for description */ BufferAccessStrategy bas = GetAccessStrategy(BAS_BULKREAD); /* Search closest probes lists */ while (!pairingheap_is_empty(so->listQueue)) { BlockNumber searchPage = ((IvfflatScanList *) pairingheap_remove_first(so->listQueue))->startPage; /* Search all entry pages for list */ while (BlockNumberIsValid(searchPage)) { Buffer buf; Page page; OffsetNumber maxoffno; buf = ReadBufferExtended(scan->indexRelation, MAIN_FORKNUM, searchPage, RBM_NORMAL, bas); LockBuffer(buf, BUFFER_LOCK_SHARE); page = BufferGetPage(buf); maxoffno = PageGetMaxOffsetNumber(page); for (OffsetNumber offno = FirstOffsetNumber; offno <= maxoffno; offno = OffsetNumberNext(offno)) { IndexTuple itup; Datum datum; bool isnull; ItemId itemid = PageGetItemId(page, offno); itup = (IndexTuple) PageGetItem(page, itemid); datum = index_getattr(itup, 1, tupdesc, &isnull); /* * Add virtual tuple * * Use procinfo from the index instead of scan key for * performance */ ExecClearTuple(slot); slot->tts_values[0] = FunctionCall2Coll(so->procinfo, so->collation, datum, value); slot->tts_isnull[0] = false; slot->tts_values[1] = PointerGetDatum(&itup->t_tid); slot->tts_isnull[1] = false; ExecStoreVirtualTuple(slot); tuplesort_puttupleslot(so->sortstate, slot); tuples++; } searchPage = IvfflatPageGetOpaque(page)->nextblkno; UnlockReleaseBuffer(buf); } } FreeAccessStrategy(bas); if (tuples < 100) ereport(DEBUG1, (errmsg("index scan found few tuples"), errdetail("Index may have been created with little data."), errhint("Recreate the index and possibly decrease lists."))); tuplesort_performsort(so->sortstate); } /* * Prepare for an index scan */ IndexScanDesc ivfflatbeginscan(Relation index, int nkeys, int norderbys) { IndexScanDesc scan; IvfflatScanOpaque so; int lists; int dimensions; AttrNumber attNums[] = {1}; Oid sortOperators[] = {Float8LessOperator}; Oid sortCollations[] = 
{InvalidOid}; bool nullsFirstFlags[] = {false}; int probes = ivfflat_probes; scan = RelationGetIndexScan(index, nkeys, norderbys); /* Get lists and dimensions from metapage */ IvfflatGetMetaPageInfo(index, &lists, &dimensions); if (probes > lists) probes = lists; so = (IvfflatScanOpaque) palloc(offsetof(IvfflatScanOpaqueData, lists) + probes * sizeof(IvfflatScanList)); so->first = true; so->probes = probes; so->dimensions = dimensions; /* Set support functions */ so->procinfo = index_getprocinfo(index, 1, IVFFLAT_DISTANCE_PROC); so->normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORM_PROC); so->collation = index->rd_indcollation[0]; /* Create tuple description for sorting */ so->tupdesc = CreateTemplateTupleDesc(2); TupleDescInitEntry(so->tupdesc, (AttrNumber) 1, "distance", FLOAT8OID, -1, 0); TupleDescInitEntry(so->tupdesc, (AttrNumber) 2, "heaptid", TIDOID, -1, 0); /* Prep sort */ so->sortstate = tuplesort_begin_heap(so->tupdesc, 1, attNums, sortOperators, sortCollations, nullsFirstFlags, work_mem, NULL, false); so->slot = MakeSingleTupleTableSlot(so->tupdesc, &TTSOpsMinimalTuple); so->listQueue = pairingheap_allocate(CompareLists, scan); scan->opaque = so; return scan; } /* * Start or restart an index scan */ void ivfflatrescan(IndexScanDesc scan, ScanKey keys, int nkeys, ScanKey orderbys, int norderbys) { IvfflatScanOpaque so = (IvfflatScanOpaque) scan->opaque; #if PG_VERSION_NUM >= 130000 if (!so->first) tuplesort_reset(so->sortstate); #endif so->first = true; pairingheap_reset(so->listQueue); if (keys && scan->numberOfKeys > 0) memmove(scan->keyData, keys, scan->numberOfKeys * sizeof(ScanKeyData)); if (orderbys && scan->numberOfOrderBys > 0) memmove(scan->orderByData, orderbys, scan->numberOfOrderBys * sizeof(ScanKeyData)); } /* * Fetch the next tuple in the given scan */ bool ivfflatgettuple(IndexScanDesc scan, ScanDirection dir) { IvfflatScanOpaque so = (IvfflatScanOpaque) scan->opaque; /* * Index can be used to scan backward, but Postgres doesn't 
support * backward scan on operators */ Assert(ScanDirectionIsForward(dir)); if (so->first) { Datum value; /* Count index scan for stats */ pgstat_count_index_scan(scan->indexRelation); /* Safety check */ if (scan->orderByData == NULL) elog(ERROR, "cannot scan ivfflat index without order"); /* Requires MVCC-compliant snapshot as not able to pin during sorting */ /* https://www.postgresql.org/docs/current/index-locking.html */ if (!IsMVCCSnapshot(scan->xs_snapshot)) elog(ERROR, "non-MVCC snapshots are not supported with ivfflat"); if (scan->orderByData->sk_flags & SK_ISNULL) value = PointerGetDatum(InitVector(so->dimensions)); else { value = scan->orderByData->sk_argument; /* Value should not be compressed or toasted */ Assert(!VARATT_IS_COMPRESSED(DatumGetPointer(value))); Assert(!VARATT_IS_EXTENDED(DatumGetPointer(value))); /* Fine if normalization fails */ if (so->normprocinfo != NULL) IvfflatNormValue(so->normprocinfo, so->collation, &value, NULL); } IvfflatBench("GetScanLists", GetScanLists(scan, value)); IvfflatBench("GetScanItems", GetScanItems(scan, value)); so->first = false; /* Clean up if we allocated a new value */ if (value != scan->orderByData->sk_argument) pfree(DatumGetPointer(value)); } if (tuplesort_gettupleslot(so->sortstate, true, false, so->slot, NULL)) { ItemPointer heaptid = (ItemPointer) DatumGetPointer(slot_getattr(so->slot, 2, &so->isnull)); scan->xs_heaptid = *heaptid; scan->xs_recheck = false; scan->xs_recheckorderby = false; return true; } return false; } /* * End a scan and release resources */ void ivfflatendscan(IndexScanDesc scan) { IvfflatScanOpaque so = (IvfflatScanOpaque) scan->opaque; pairingheap_free(so->listQueue); tuplesort_end(so->sortstate); pfree(so); scan->opaque = NULL; } pgvector-0.6.0/src/ivfutils.c000066400000000000000000000116731455577216400161760ustar00rootroot00000000000000#include "postgres.h" #include "access/generic_xlog.h" #include "ivfflat.h" #include "storage/bufmgr.h" #include "vector.h" /* * Allocate a 
vector array */
/*
 * Returns a palloc'd VectorArray with length 0 and room for maxlen vectors
 * of the given dimensions. The item storage is zero-filled and may exceed
 * the normal palloc size limit (MCXT_ALLOC_HUGE).
 */
VectorArray
VectorArrayInit(int maxlen, int dimensions)
{
	VectorArray res = palloc(sizeof(VectorArrayData));

	res->length = 0;
	res->maxlen = maxlen;
	res->dim = dimensions;
	res->items = palloc_extended(maxlen * VECTOR_SIZE(dimensions), MCXT_ALLOC_ZERO | MCXT_ALLOC_HUGE);
	return res;
}

/*
 * Free a vector array
 *
 * Releases both the item storage and the array header.
 */
void
VectorArrayFree(VectorArray arr)
{
	pfree(arr->items);
	pfree(arr);
}

/*
 * Print vector array - useful for debugging
 *
 * Prints each of the first arr->length vectors with PrintVector, using msg
 * as the label.
 */
void
PrintVectorArray(char *msg, VectorArray arr)
{
	for (int i = 0; i < arr->length; i++)
		PrintVector(msg, VectorArrayGet(arr, i));
}

/*
 * Get the number of lists in the index
 *
 * Falls back to IVFFLAT_DEFAULT_LISTS when the index has no reloptions.
 */
int
IvfflatGetLists(Relation index)
{
	IvfflatOptions *opts = (IvfflatOptions *) index->rd_options;

	if (opts)
		return opts->lists;

	return IVFFLAT_DEFAULT_LISTS;
}

/*
 * Get proc
 *
 * Returns the support function procnum for the first index column, or NULL
 * when no such support function is defined.
 */
FmgrInfo *
IvfflatOptionalProcInfo(Relation index, uint16 procnum)
{
	if (!OidIsValid(index_getprocid(index, 1, procnum)))
		return NULL;

	return index_getprocinfo(index, 1, procnum);
}

/*
 * Divide by the norm
 *
 * Returns false if value should not be indexed
 *
 * The caller needs to free the pointer stored in value
 * if it's different than the original value
 *
 * procinfo  - norm function, called with *value
 * result    - destination vector, or NULL to allocate one with InitVector
 *
 * When the norm is not greater than zero, *value is left untouched.
 */
bool
IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, Vector * result)
{
	double		norm = DatumGetFloat8(FunctionCall1Coll(procinfo, collation, *value));

	if (norm > 0)
	{
		Vector	   *v = DatumGetVector(*value);

		if (result == NULL)
			result = InitVector(v->dim);

		for (int i = 0; i < v->dim; i++)
			result->x[i] = v->x[i] / norm;

		*value = PointerGetDatum(result);

		return true;
	}

	return false;
}

/*
 * New buffer
 *
 * Extends the given fork by one page (P_NEW) and returns the buffer with an
 * exclusive lock held.
 */
Buffer
IvfflatNewBuffer(Relation index, ForkNumber forkNum)
{
	Buffer		buf = ReadBufferExtended(index, forkNum, P_NEW, RBM_NORMAL, NULL);

	LockBuffer(buf, BUFFER_LOCK_EXCLUSIVE);
	return buf;
}

/*
 * Init page
 *
 * Initializes the page with an IvfflatPageOpaqueData special area, no next
 * block, and the ivfflat page id.
 */
void
IvfflatInitPage(Buffer buf, Page page)
{
	PageInit(page, BufferGetPageSize(buf), sizeof(IvfflatPageOpaqueData));
	IvfflatPageGetOpaque(page)->nextblkno = InvalidBlockNumber;
	IvfflatPageGetOpaque(page)->page_id = IVFFLAT_PAGE_ID;
}

/*
 * Init and register page
 *
 * Starts a generic WAL record, registers *buf as a full image, and
 * initializes the registered page. *state and *page are outputs.
 */
void
IvfflatInitRegisterPage(Relation index, Buffer *buf, Page *page, GenericXLogState **state)
{
	*state = GenericXLogStart(index);
	*page = GenericXLogRegisterBuffer(*state, *buf, GENERIC_XLOG_FULL_IMAGE);
	IvfflatInitPage(*buf, *page);
}

/*
 * Commit buffer
 *
 * Finishes the generic WAL record, then unlocks and releases the buffer.
 */
void
IvfflatCommitBuffer(Buffer buf, GenericXLogState *state)
{
	GenericXLogFinish(state);
	UnlockReleaseBuffer(buf);
}

/*
 * Add a new page
 *
 * The order is very important!!
 *
 * The new page is linked from the current one and both are written under
 * the caller's WAL record before the old buffer is released. On return
 * *buf, *page and *state refer to the new page: it is still exclusively
 * locked and registered in a freshly started WAL record.
 */
void
IvfflatAppendPage(Relation index, Buffer *buf, Page *page, GenericXLogState **state, ForkNumber forkNum)
{
	/* Get new buffer */
	Buffer		newbuf = IvfflatNewBuffer(index, forkNum);
	Page		newpage = GenericXLogRegisterBuffer(*state, newbuf, GENERIC_XLOG_FULL_IMAGE);

	/* Update the previous buffer */
	IvfflatPageGetOpaque(*page)->nextblkno = BufferGetBlockNumber(newbuf);

	/* Init new page */
	IvfflatInitPage(newbuf, newpage);

	/* Commit */
	GenericXLogFinish(*state);

	/* Unlock */
	UnlockReleaseBuffer(*buf);

	/* Start a new WAL record covering only the new page */
	*state = GenericXLogStart(index);
	*page = GenericXLogRegisterBuffer(*state, newbuf, GENERIC_XLOG_FULL_IMAGE);
	*buf = newbuf;
}

/*
 * Get the metapage info
 *
 * Reads lists and, when dimensions is not NULL, dimensions from the
 * metapage under a share lock.
 */
void
IvfflatGetMetaPageInfo(Relation index, int *lists, int *dimensions)
{
	Buffer		buf;
	Page		page;
	IvfflatMetaPage metap;

	buf = ReadBuffer(index, IVFFLAT_METAPAGE_BLKNO);
	LockBuffer(buf, BUFFER_LOCK_SHARE);
	page = BufferGetPage(buf);
	metap = IvfflatPageGetMeta(page);

	*lists = metap->lists;

	if (dimensions != NULL)
		*dimensions = metap->dimensions;

	UnlockReleaseBuffer(buf);
}

/*
 * Update the start or insert page of a list
 *
 * listInfo identifies the list by the block and offset of its entry on the
 * list pages; the page is modified under an exclusive lock.
 */
void
IvfflatUpdateList(Relation index, ListInfo listInfo,
				  BlockNumber insertPage, BlockNumber originalInsertPage,
				  BlockNumber startPage, ForkNumber forkNum)
{
	Buffer		buf;
	Page		page;
	GenericXLogState *state;
	IvfflatList list;
	bool		changed = false;

	buf = ReadBufferExtended(index, forkNum, listInfo.blkno, RBM_NORMAL, NULL);
	LockBuffer(buf, BUFFER_LOCK_EXCLUSIVE);
	state =
GenericXLogStart(index); page = GenericXLogRegisterBuffer(state, buf, 0); list = (IvfflatList) PageGetItem(page, PageGetItemId(page, listInfo.offno)); if (BlockNumberIsValid(insertPage) && insertPage != list->insertPage) { /* Skip update if insert page is lower than original insert page */ /* This is needed to prevent insert from overwriting vacuum */ if (!BlockNumberIsValid(originalInsertPage) || insertPage >= originalInsertPage) { list->insertPage = insertPage; changed = true; } } if (BlockNumberIsValid(startPage) && startPage != list->startPage) { list->startPage = startPage; changed = true; } /* Only commit if changed */ if (changed) IvfflatCommitBuffer(buf, state); else { GenericXLogAbort(state); UnlockReleaseBuffer(buf); } } pgvector-0.6.0/src/ivfvacuum.c000066400000000000000000000075461455577216400163420ustar00rootroot00000000000000#include "postgres.h" #include "access/generic_xlog.h" #include "commands/vacuum.h" #include "ivfflat.h" #include "storage/bufmgr.h" /* * Bulk delete tuples from the index */ IndexBulkDeleteResult * ivfflatbulkdelete(IndexVacuumInfo *info, IndexBulkDeleteResult *stats, IndexBulkDeleteCallback callback, void *callback_state) { Relation index = info->index; BlockNumber blkno = IVFFLAT_HEAD_BLKNO; BufferAccessStrategy bas = GetAccessStrategy(BAS_BULKREAD); if (stats == NULL) stats = (IndexBulkDeleteResult *) palloc0(sizeof(IndexBulkDeleteResult)); /* Iterate over list pages */ while (BlockNumberIsValid(blkno)) { Buffer cbuf; Page cpage; OffsetNumber coffno; OffsetNumber cmaxoffno; BlockNumber startPages[MaxOffsetNumber]; ListInfo listInfo; cbuf = ReadBuffer(index, blkno); LockBuffer(cbuf, BUFFER_LOCK_SHARE); cpage = BufferGetPage(cbuf); cmaxoffno = PageGetMaxOffsetNumber(cpage); /* Iterate over lists */ for (coffno = FirstOffsetNumber; coffno <= cmaxoffno; coffno = OffsetNumberNext(coffno)) { IvfflatList list = (IvfflatList) PageGetItem(cpage, PageGetItemId(cpage, coffno)); startPages[coffno - FirstOffsetNumber] = list->startPage; } 
listInfo.blkno = blkno; blkno = IvfflatPageGetOpaque(cpage)->nextblkno; UnlockReleaseBuffer(cbuf); for (coffno = FirstOffsetNumber; coffno <= cmaxoffno; coffno = OffsetNumberNext(coffno)) { BlockNumber searchPage = startPages[coffno - FirstOffsetNumber]; BlockNumber insertPage = InvalidBlockNumber; /* Iterate over entry pages */ while (BlockNumberIsValid(searchPage)) { Buffer buf; Page page; GenericXLogState *state; OffsetNumber offno; OffsetNumber maxoffno; OffsetNumber deletable[MaxOffsetNumber]; int ndeletable; vacuum_delay_point(); buf = ReadBufferExtended(index, MAIN_FORKNUM, searchPage, RBM_NORMAL, bas); /* * ambulkdelete cannot delete entries from pages that are * pinned by other backends * * https://www.postgresql.org/docs/current/index-locking.html */ LockBufferForCleanup(buf); state = GenericXLogStart(index); page = GenericXLogRegisterBuffer(state, buf, 0); maxoffno = PageGetMaxOffsetNumber(page); ndeletable = 0; /* Find deleted tuples */ for (offno = FirstOffsetNumber; offno <= maxoffno; offno = OffsetNumberNext(offno)) { IndexTuple itup = (IndexTuple) PageGetItem(page, PageGetItemId(page, offno)); ItemPointer htup = &(itup->t_tid); if (callback(htup, callback_state)) { deletable[ndeletable++] = offno; stats->tuples_removed++; } else stats->num_index_tuples++; } /* Set to first free page */ /* Must be set before searchPage is updated */ if (!BlockNumberIsValid(insertPage) && ndeletable > 0) insertPage = searchPage; searchPage = IvfflatPageGetOpaque(page)->nextblkno; if (ndeletable > 0) { /* Delete tuples */ PageIndexMultiDelete(page, deletable, ndeletable); GenericXLogFinish(state); } else GenericXLogAbort(state); UnlockReleaseBuffer(buf); } /* * Update after all tuples deleted. * * We don't add or delete items from lists pages, so offset won't * change. 
*/ if (BlockNumberIsValid(insertPage)) { listInfo.offno = coffno; IvfflatUpdateList(index, listInfo, insertPage, InvalidBlockNumber, InvalidBlockNumber, MAIN_FORKNUM); } } } FreeAccessStrategy(bas); return stats; } /* * Clean up after a VACUUM operation */ IndexBulkDeleteResult * ivfflatvacuumcleanup(IndexVacuumInfo *info, IndexBulkDeleteResult *stats) { Relation rel = info->index; if (info->analyze_only) return stats; /* stats is NULL if ambulkdelete not called */ /* OK to return NULL if index not changed */ if (stats == NULL) return NULL; stats->num_pages = RelationGetNumberOfBlocks(rel); return stats; } pgvector-0.6.0/src/vector.c000066400000000000000000000553111455577216400156300ustar00rootroot00000000000000#include "postgres.h" #include #include "catalog/pg_type.h" #include "common/shortest_dec.h" #include "fmgr.h" #include "hnsw.h" #include "ivfflat.h" #include "lib/stringinfo.h" #include "libpq/pqformat.h" #include "port.h" /* for strtof() */ #include "utils/array.h" #include "utils/builtins.h" #include "utils/float.h" #include "utils/lsyscache.h" #include "utils/numeric.h" #include "vector.h" #if PG_VERSION_NUM >= 160000 #include "varatt.h" #endif #if PG_VERSION_NUM < 130000 #define TYPALIGN_DOUBLE 'd' #define TYPALIGN_INT 'i' #endif #define STATE_DIMS(x) (ARR_DIMS(x)[0] - 1) #define CreateStateDatums(dim) palloc(sizeof(Datum) * (dim + 1)) PG_MODULE_MAGIC; /* * Initialize index options and variables */ PGDLLEXPORT void _PG_init(void); void _PG_init(void) { HnswInit(); IvfflatInit(); } /* * Ensure same dimensions */ static inline void CheckDims(Vector * a, Vector * b) { if (a->dim != b->dim) ereport(ERROR, (errcode(ERRCODE_DATA_EXCEPTION), errmsg("different vector dimensions %d and %d", a->dim, b->dim))); } /* * Ensure expected dimensions */ static inline void CheckExpectedDim(int32 typmod, int dim) { if (typmod != -1 && typmod != dim) ereport(ERROR, (errcode(ERRCODE_DATA_EXCEPTION), errmsg("expected %d dimensions, not %d", typmod, dim))); } /* * Ensure 
valid dimensions */ static inline void CheckDim(int dim) { if (dim < 1) ereport(ERROR, (errcode(ERRCODE_DATA_EXCEPTION), errmsg("vector must have at least 1 dimension"))); if (dim > VECTOR_MAX_DIM) ereport(ERROR, (errcode(ERRCODE_PROGRAM_LIMIT_EXCEEDED), errmsg("vector cannot have more than %d dimensions", VECTOR_MAX_DIM))); } /* * Ensure finite element */ static inline void CheckElement(float value) { if (isnan(value)) ereport(ERROR, (errcode(ERRCODE_DATA_EXCEPTION), errmsg("NaN not allowed in vector"))); if (isinf(value)) ereport(ERROR, (errcode(ERRCODE_DATA_EXCEPTION), errmsg("infinite value not allowed in vector"))); } /* * Allocate and initialize a new vector */ Vector * InitVector(int dim) { Vector *result; int size; size = VECTOR_SIZE(dim); result = (Vector *) palloc0(size); SET_VARSIZE(result, size); result->dim = dim; return result; } /* * Check for whitespace, since array_isspace() is static */ static inline bool vector_isspace(char ch) { if (ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r' || ch == '\v' || ch == '\f') return true; return false; } /* * Check state array */ static float8 * CheckStateArray(ArrayType *statearray, const char *caller) { if (ARR_NDIM(statearray) != 1 || ARR_DIMS(statearray)[0] < 1 || ARR_HASNULL(statearray) || ARR_ELEMTYPE(statearray) != FLOAT8OID) elog(ERROR, "%s: expected state array", caller); return (float8 *) ARR_DATA_PTR(statearray); } #if PG_VERSION_NUM < 120003 static pg_noinline void float_overflow_error(void) { ereport(ERROR, (errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE), errmsg("value out of range: overflow"))); } static pg_noinline void float_underflow_error(void) { ereport(ERROR, (errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE), errmsg("value out of range: underflow"))); } #endif /* * Convert textual representation to internal representation */ PGDLLEXPORT PG_FUNCTION_INFO_V1(vector_in); Datum vector_in(PG_FUNCTION_ARGS) { char *lit = PG_GETARG_CSTRING(0); int32 typmod = PG_GETARG_INT32(2); float x[VECTOR_MAX_DIM]; 
int dim = 0; char *pt; char *stringEnd; Vector *result; char *litcopy = pstrdup(lit); char *str = litcopy; while (vector_isspace(*str)) str++; if (*str != '[') ereport(ERROR, (errcode(ERRCODE_INVALID_TEXT_REPRESENTATION), errmsg("malformed vector literal: \"%s\"", lit), errdetail("Vector contents must start with \"[\"."))); str++; pt = strtok(str, ","); stringEnd = pt; while (pt != NULL && *stringEnd != ']') { if (dim == VECTOR_MAX_DIM) ereport(ERROR, (errcode(ERRCODE_PROGRAM_LIMIT_EXCEEDED), errmsg("vector cannot have more than %d dimensions", VECTOR_MAX_DIM))); while (vector_isspace(*pt)) pt++; /* Check for empty string like float4in */ if (*pt == '\0') ereport(ERROR, (errcode(ERRCODE_INVALID_TEXT_REPRESENTATION), errmsg("invalid input syntax for type vector: \"%s\"", lit))); /* Use strtof like float4in to avoid a double-rounding problem */ x[dim] = strtof(pt, &stringEnd); CheckElement(x[dim]); dim++; if (stringEnd == pt) ereport(ERROR, (errcode(ERRCODE_INVALID_TEXT_REPRESENTATION), errmsg("invalid input syntax for type vector: \"%s\"", lit))); while (vector_isspace(*stringEnd)) stringEnd++; if (*stringEnd != '\0' && *stringEnd != ']') ereport(ERROR, (errcode(ERRCODE_INVALID_TEXT_REPRESENTATION), errmsg("invalid input syntax for type vector: \"%s\"", lit))); pt = strtok(NULL, ","); } if (stringEnd == NULL || *stringEnd != ']') ereport(ERROR, (errcode(ERRCODE_INVALID_TEXT_REPRESENTATION), errmsg("malformed vector literal: \"%s\"", lit), errdetail("Unexpected end of input."))); stringEnd++; /* Only whitespace is allowed after the closing brace */ while (vector_isspace(*stringEnd)) stringEnd++; if (*stringEnd != '\0') ereport(ERROR, (errcode(ERRCODE_INVALID_TEXT_REPRESENTATION), errmsg("malformed vector literal: \"%s\"", lit), errdetail("Junk after closing right brace."))); /* Ensure no consecutive delimiters since strtok skips */ for (pt = lit + 1; *pt != '\0'; pt++) { if (pt[-1] == ',' && *pt == ',') ereport(ERROR, (errcode(ERRCODE_INVALID_TEXT_REPRESENTATION), 
errmsg("malformed vector literal: \"%s\"", lit))); } if (dim < 1) ereport(ERROR, (errcode(ERRCODE_DATA_EXCEPTION), errmsg("vector must have at least 1 dimension"))); pfree(litcopy); CheckExpectedDim(typmod, dim); result = InitVector(dim); for (int i = 0; i < dim; i++) result->x[i] = x[i]; PG_RETURN_POINTER(result); } /* * Convert internal representation to textual representation */ PGDLLEXPORT PG_FUNCTION_INFO_V1(vector_out); Datum vector_out(PG_FUNCTION_ARGS) { Vector *vector = PG_GETARG_VECTOR_P(0); int dim = vector->dim; char *buf; char *ptr; int n; /* * Need: * * dim * (FLOAT_SHORTEST_DECIMAL_LEN - 1) bytes for * float_to_shortest_decimal_bufn * * dim - 1 bytes for separator * * 3 bytes for [, ], and \0 */ buf = (char *) palloc(FLOAT_SHORTEST_DECIMAL_LEN * dim + 2); ptr = buf; *ptr = '['; ptr++; for (int i = 0; i < dim; i++) { if (i > 0) { *ptr = ','; ptr++; } n = float_to_shortest_decimal_bufn(vector->x[i], ptr); ptr += n; } *ptr = ']'; ptr++; *ptr = '\0'; PG_FREE_IF_COPY(vector, 0); PG_RETURN_CSTRING(buf); } /* * Print vector - useful for debugging */ void PrintVector(char *msg, Vector * vector) { char *out = DatumGetPointer(DirectFunctionCall1(vector_out, PointerGetDatum(vector))); elog(INFO, "%s = %s", msg, out); pfree(out); } /* * Convert type modifier */ PGDLLEXPORT PG_FUNCTION_INFO_V1(vector_typmod_in); Datum vector_typmod_in(PG_FUNCTION_ARGS) { ArrayType *ta = PG_GETARG_ARRAYTYPE_P(0); int32 *tl; int n; tl = ArrayGetIntegerTypmods(ta, &n); if (n != 1) ereport(ERROR, (errcode(ERRCODE_INVALID_PARAMETER_VALUE), errmsg("invalid type modifier"))); if (*tl < 1) ereport(ERROR, (errcode(ERRCODE_INVALID_PARAMETER_VALUE), errmsg("dimensions for type vector must be at least 1"))); if (*tl > VECTOR_MAX_DIM) ereport(ERROR, (errcode(ERRCODE_INVALID_PARAMETER_VALUE), errmsg("dimensions for type vector cannot exceed %d", VECTOR_MAX_DIM))); PG_RETURN_INT32(*tl); } /* * Convert external binary representation to internal representation */ PGDLLEXPORT 
PG_FUNCTION_INFO_V1(vector_recv);
/*
 * Wire format read here: int16 dim, int16 unused (must be 0), then dim
 * float4 elements. Every element is checked for NaN/infinity.
 */
Datum
vector_recv(PG_FUNCTION_ARGS)
{
	StringInfo	buf = (StringInfo) PG_GETARG_POINTER(0);
	int32		typmod = PG_GETARG_INT32(2);
	Vector	   *result;
	int16		dim;
	int16		unused;

	dim = pq_getmsgint(buf, sizeof(int16));
	unused = pq_getmsgint(buf, sizeof(int16));

	CheckDim(dim);
	CheckExpectedDim(typmod, dim);

	if (unused != 0)
		ereport(ERROR,
				(errcode(ERRCODE_DATA_EXCEPTION),
				 errmsg("expected unused to be 0, not %d", unused)));

	result = InitVector(dim);
	for (int i = 0; i < dim; i++)
	{
		result->x[i] = pq_getmsgfloat4(buf);
		CheckElement(result->x[i]);
	}

	PG_RETURN_POINTER(result);
}

/*
 * Convert internal representation to the external binary representation
 *
 * Writes dim and unused as int16, followed by the float4 elements.
 */
PGDLLEXPORT PG_FUNCTION_INFO_V1(vector_send);
Datum
vector_send(PG_FUNCTION_ARGS)
{
	Vector	   *vec = PG_GETARG_VECTOR_P(0);
	StringInfoData buf;

	pq_begintypsend(&buf);
	pq_sendint(&buf, vec->dim, sizeof(int16));
	pq_sendint(&buf, vec->unused, sizeof(int16));
	for (int i = 0; i < vec->dim; i++)
		pq_sendfloat4(&buf, vec->x[i]);

	PG_RETURN_BYTEA_P(pq_endtypsend(&buf));
}

/*
 * Convert vector to vector
 * This is needed to check the type modifier
 *
 * Returns the input unchanged after checking its dimension against typmod.
 */
PGDLLEXPORT PG_FUNCTION_INFO_V1(vector);
Datum
vector(PG_FUNCTION_ARGS)
{
	Vector	   *vec = PG_GETARG_VECTOR_P(0);
	int32		typmod = PG_GETARG_INT32(1);

	CheckExpectedDim(typmod, vec->dim);

	PG_RETURN_POINTER(vec);
}

/*
 * Convert array to vector
 *
 * Accepts int4, float8, float4 and numeric arrays with at most one
 * dimension and no nulls. Raises ERROR for other element types, for an
 * invalid or unexpected number of elements, and for elements that are NaN
 * or infinite after conversion to float.
 */
PGDLLEXPORT PG_FUNCTION_INFO_V1(array_to_vector);
Datum
array_to_vector(PG_FUNCTION_ARGS)
{
	ArrayType  *array = PG_GETARG_ARRAYTYPE_P(0);
	int32		typmod = PG_GETARG_INT32(1);
	Vector	   *result;
	int16		typlen;
	bool		typbyval;
	char		typalign;
	Datum	   *elemsp;
	int			nelemsp;

	if (ARR_NDIM(array) > 1)
		ereport(ERROR,
				(errcode(ERRCODE_DATA_EXCEPTION),
				 errmsg("array must be 1-D")));

	if (ARR_HASNULL(array) && array_contains_nulls(array))
		ereport(ERROR,
				(errcode(ERRCODE_NULL_VALUE_NOT_ALLOWED),
				 errmsg("array must not contain nulls")));

	get_typlenbyvalalign(ARR_ELEMTYPE(array), &typlen, &typbyval, &typalign);
	deconstruct_array(array, ARR_ELEMTYPE(array), typlen, typbyval, typalign, &elemsp, NULL, &nelemsp);

	CheckDim(nelemsp);
	CheckExpectedDim(typmod, nelemsp);

	result = InitVector(nelemsp);

	/* Convert each element to float according to the array's element type */
	if (ARR_ELEMTYPE(array) == INT4OID)
	{
		for (int i = 0; i < nelemsp; i++)
			result->x[i] = DatumGetInt32(elemsp[i]);
	}
	else if (ARR_ELEMTYPE(array) == FLOAT8OID)
	{
		for (int i = 0; i < nelemsp; i++)
			result->x[i] = DatumGetFloat8(elemsp[i]);
	}
	else if (ARR_ELEMTYPE(array) == FLOAT4OID)
	{
		for (int i = 0; i < nelemsp; i++)
			result->x[i] = DatumGetFloat4(elemsp[i]);
	}
	else if (ARR_ELEMTYPE(array) == NUMERICOID)
	{
		for (int i = 0; i < nelemsp; i++)
			result->x[i] = DatumGetFloat4(DirectFunctionCall1(numeric_float4, elemsp[i]));
	}
	else
	{
		ereport(ERROR,
				(errcode(ERRCODE_DATA_EXCEPTION),
				 errmsg("unsupported array type")));
	}

	/*
	 * Free allocation from deconstruct_array. Do not free individual elements
	 * when pass-by-reference since they point to original array.
	 */
	pfree(elemsp);

	/* Check elements */
	for (int i = 0; i < result->dim; i++)
		CheckElement(result->x[i]);

	PG_RETURN_POINTER(result);
}

/*
 * Convert vector to float4[]
 *
 * Builds a one-dimensional float4 array with vec->dim elements.
 */
PGDLLEXPORT PG_FUNCTION_INFO_V1(vector_to_float4);
Datum
vector_to_float4(PG_FUNCTION_ARGS)
{
	Vector	   *vec = PG_GETARG_VECTOR_P(0);
	Datum	   *datums;
	ArrayType  *result;

	datums = (Datum *) palloc(sizeof(Datum) * vec->dim);

	for (int i = 0; i < vec->dim; i++)
		datums[i] = Float4GetDatum(vec->x[i]);

	/* Use TYPALIGN_INT for float4 */
	result = construct_array(datums, vec->dim, FLOAT4OID, sizeof(float4), true, TYPALIGN_INT);

	pfree(datums);

	PG_RETURN_POINTER(result);
}

/*
 * Get the L2 distance between vectors
 *
 * Raises ERROR when the dimensions differ. The sum of squared differences
 * is accumulated in single precision; only the square root is taken in
 * double precision.
 */
PGDLLEXPORT PG_FUNCTION_INFO_V1(l2_distance);
Datum
l2_distance(PG_FUNCTION_ARGS)
{
	Vector	   *a = PG_GETARG_VECTOR_P(0);
	Vector	   *b = PG_GETARG_VECTOR_P(1);
	float	   *ax = a->x;
	float	   *bx = b->x;
	float		distance = 0.0;
	float		diff;

	CheckDims(a, b);

	/* Auto-vectorized */
	for (int i = 0; i < a->dim; i++)
	{
		diff = ax[i] - bx[i];
		distance += diff * diff;
	}

	PG_RETURN_FLOAT8(sqrt((double) distance));
}

/*
 * Get the L2 squared distance between
vectors * This saves a sqrt calculation */ PGDLLEXPORT PG_FUNCTION_INFO_V1(vector_l2_squared_distance); Datum vector_l2_squared_distance(PG_FUNCTION_ARGS) { Vector *a = PG_GETARG_VECTOR_P(0); Vector *b = PG_GETARG_VECTOR_P(1); float *ax = a->x; float *bx = b->x; float distance = 0.0; float diff; CheckDims(a, b); /* Auto-vectorized */ for (int i = 0; i < a->dim; i++) { diff = ax[i] - bx[i]; distance += diff * diff; } PG_RETURN_FLOAT8((double) distance); } /* * Get the inner product of two vectors */ PGDLLEXPORT PG_FUNCTION_INFO_V1(inner_product); Datum inner_product(PG_FUNCTION_ARGS) { Vector *a = PG_GETARG_VECTOR_P(0); Vector *b = PG_GETARG_VECTOR_P(1); float *ax = a->x; float *bx = b->x; float distance = 0.0; CheckDims(a, b); /* Auto-vectorized */ for (int i = 0; i < a->dim; i++) distance += ax[i] * bx[i]; PG_RETURN_FLOAT8((double) distance); } /* * Get the negative inner product of two vectors */ PGDLLEXPORT PG_FUNCTION_INFO_V1(vector_negative_inner_product); Datum vector_negative_inner_product(PG_FUNCTION_ARGS) { Vector *a = PG_GETARG_VECTOR_P(0); Vector *b = PG_GETARG_VECTOR_P(1); float *ax = a->x; float *bx = b->x; float distance = 0.0; CheckDims(a, b); /* Auto-vectorized */ for (int i = 0; i < a->dim; i++) distance += ax[i] * bx[i]; PG_RETURN_FLOAT8((double) distance * -1); } /* * Get the cosine distance between two vectors */ PGDLLEXPORT PG_FUNCTION_INFO_V1(cosine_distance); Datum cosine_distance(PG_FUNCTION_ARGS) { Vector *a = PG_GETARG_VECTOR_P(0); Vector *b = PG_GETARG_VECTOR_P(1); float *ax = a->x; float *bx = b->x; float distance = 0.0; float norma = 0.0; float normb = 0.0; double similarity; CheckDims(a, b); /* Auto-vectorized */ for (int i = 0; i < a->dim; i++) { distance += ax[i] * bx[i]; norma += ax[i] * ax[i]; normb += bx[i] * bx[i]; } /* Use sqrt(a * b) over sqrt(a) * sqrt(b) */ similarity = (double) distance / sqrt((double) norma * (double) normb); #ifdef _MSC_VER /* /fp:fast may not propagate NaN */ if (isnan(similarity)) PG_RETURN_FLOAT8(NAN); 
#endif /* Keep in range */ if (similarity > 1) similarity = 1.0; else if (similarity < -1) similarity = -1.0; PG_RETURN_FLOAT8(1.0 - similarity); } /* * Get the distance for spherical k-means * Currently uses angular distance since needs to satisfy triangle inequality * Assumes inputs are unit vectors (skips norm) */ PGDLLEXPORT PG_FUNCTION_INFO_V1(vector_spherical_distance); Datum vector_spherical_distance(PG_FUNCTION_ARGS) { Vector *a = PG_GETARG_VECTOR_P(0); Vector *b = PG_GETARG_VECTOR_P(1); float *ax = a->x; float *bx = b->x; float dp = 0.0; double distance; CheckDims(a, b); /* Auto-vectorized */ for (int i = 0; i < a->dim; i++) dp += ax[i] * bx[i]; distance = (double) dp; /* Prevent NaN with acos with loss of precision */ if (distance > 1) distance = 1; else if (distance < -1) distance = -1; PG_RETURN_FLOAT8(acos(distance) / M_PI); } /* * Get the L1 distance between two vectors */ PGDLLEXPORT PG_FUNCTION_INFO_V1(l1_distance); Datum l1_distance(PG_FUNCTION_ARGS) { Vector *a = PG_GETARG_VECTOR_P(0); Vector *b = PG_GETARG_VECTOR_P(1); float *ax = a->x; float *bx = b->x; float distance = 0.0; CheckDims(a, b); /* Auto-vectorized */ for (int i = 0; i < a->dim; i++) distance += fabsf(ax[i] - bx[i]); PG_RETURN_FLOAT8((double) distance); } /* * Get the dimensions of a vector */ PGDLLEXPORT PG_FUNCTION_INFO_V1(vector_dims); Datum vector_dims(PG_FUNCTION_ARGS) { Vector *a = PG_GETARG_VECTOR_P(0); PG_RETURN_INT32(a->dim); } /* * Get the L2 norm of a vector */ PGDLLEXPORT PG_FUNCTION_INFO_V1(vector_norm); Datum vector_norm(PG_FUNCTION_ARGS) { Vector *a = PG_GETARG_VECTOR_P(0); float *ax = a->x; double norm = 0.0; /* Auto-vectorized */ for (int i = 0; i < a->dim; i++) norm += (double) ax[i] * (double) ax[i]; PG_RETURN_FLOAT8(sqrt(norm)); } /* * Add vectors */ PGDLLEXPORT PG_FUNCTION_INFO_V1(vector_add); Datum vector_add(PG_FUNCTION_ARGS) { Vector *a = PG_GETARG_VECTOR_P(0); Vector *b = PG_GETARG_VECTOR_P(1); float *ax = a->x; float *bx = b->x; Vector *result; float *rx; 
CheckDims(a, b); result = InitVector(a->dim); rx = result->x; /* Auto-vectorized */ for (int i = 0, imax = a->dim; i < imax; i++) rx[i] = ax[i] + bx[i]; /* Check for overflow */ for (int i = 0, imax = a->dim; i < imax; i++) { if (isinf(rx[i])) float_overflow_error(); } PG_RETURN_POINTER(result); } /* * Subtract vectors */ PGDLLEXPORT PG_FUNCTION_INFO_V1(vector_sub); Datum vector_sub(PG_FUNCTION_ARGS) { Vector *a = PG_GETARG_VECTOR_P(0); Vector *b = PG_GETARG_VECTOR_P(1); float *ax = a->x; float *bx = b->x; Vector *result; float *rx; CheckDims(a, b); result = InitVector(a->dim); rx = result->x; /* Auto-vectorized */ for (int i = 0, imax = a->dim; i < imax; i++) rx[i] = ax[i] - bx[i]; /* Check for overflow */ for (int i = 0, imax = a->dim; i < imax; i++) { if (isinf(rx[i])) float_overflow_error(); } PG_RETURN_POINTER(result); } /* * Multiply vectors */ PGDLLEXPORT PG_FUNCTION_INFO_V1(vector_mul); Datum vector_mul(PG_FUNCTION_ARGS) { Vector *a = PG_GETARG_VECTOR_P(0); Vector *b = PG_GETARG_VECTOR_P(1); float *ax = a->x; float *bx = b->x; Vector *result; float *rx; CheckDims(a, b); result = InitVector(a->dim); rx = result->x; /* Auto-vectorized */ for (int i = 0, imax = a->dim; i < imax; i++) rx[i] = ax[i] * bx[i]; /* Check for overflow and underflow */ for (int i = 0, imax = a->dim; i < imax; i++) { if (isinf(rx[i])) float_overflow_error(); if (rx[i] == 0 && !(ax[i] == 0 || bx[i] == 0)) float_underflow_error(); } PG_RETURN_POINTER(result); } /* * Internal helper to compare vectors */ int vector_cmp_internal(Vector * a, Vector * b) { CheckDims(a, b); for (int i = 0; i < a->dim; i++) { if (a->x[i] < b->x[i]) return -1; if (a->x[i] > b->x[i]) return 1; } return 0; } /* * Less than */ PGDLLEXPORT PG_FUNCTION_INFO_V1(vector_lt); Datum vector_lt(PG_FUNCTION_ARGS) { Vector *a = PG_GETARG_VECTOR_P(0); Vector *b = PG_GETARG_VECTOR_P(1); PG_RETURN_BOOL(vector_cmp_internal(a, b) < 0); } /* * Less than or equal */ PGDLLEXPORT PG_FUNCTION_INFO_V1(vector_le); Datum 
vector_le(PG_FUNCTION_ARGS) { Vector *a = PG_GETARG_VECTOR_P(0); Vector *b = PG_GETARG_VECTOR_P(1); PG_RETURN_BOOL(vector_cmp_internal(a, b) <= 0); } /* * Equal */ PGDLLEXPORT PG_FUNCTION_INFO_V1(vector_eq); Datum vector_eq(PG_FUNCTION_ARGS) { Vector *a = PG_GETARG_VECTOR_P(0); Vector *b = PG_GETARG_VECTOR_P(1); PG_RETURN_BOOL(vector_cmp_internal(a, b) == 0); } /* * Not equal */ PGDLLEXPORT PG_FUNCTION_INFO_V1(vector_ne); Datum vector_ne(PG_FUNCTION_ARGS) { Vector *a = PG_GETARG_VECTOR_P(0); Vector *b = PG_GETARG_VECTOR_P(1); PG_RETURN_BOOL(vector_cmp_internal(a, b) != 0); } /* * Greater than or equal */ PGDLLEXPORT PG_FUNCTION_INFO_V1(vector_ge); Datum vector_ge(PG_FUNCTION_ARGS) { Vector *a = PG_GETARG_VECTOR_P(0); Vector *b = PG_GETARG_VECTOR_P(1); PG_RETURN_BOOL(vector_cmp_internal(a, b) >= 0); } /* * Greater than */ PGDLLEXPORT PG_FUNCTION_INFO_V1(vector_gt); Datum vector_gt(PG_FUNCTION_ARGS) { Vector *a = PG_GETARG_VECTOR_P(0); Vector *b = PG_GETARG_VECTOR_P(1); PG_RETURN_BOOL(vector_cmp_internal(a, b) > 0); } /* * Compare vectors */ PGDLLEXPORT PG_FUNCTION_INFO_V1(vector_cmp); Datum vector_cmp(PG_FUNCTION_ARGS) { Vector *a = PG_GETARG_VECTOR_P(0); Vector *b = PG_GETARG_VECTOR_P(1); PG_RETURN_INT32(vector_cmp_internal(a, b)); } /* * Accumulate vectors */ PGDLLEXPORT PG_FUNCTION_INFO_V1(vector_accum); Datum vector_accum(PG_FUNCTION_ARGS) { ArrayType *statearray = PG_GETARG_ARRAYTYPE_P(0); Vector *newval = PG_GETARG_VECTOR_P(1); float8 *statevalues; int16 dim; bool newarr; float8 n; Datum *statedatums; float *x = newval->x; ArrayType *result; /* Check array before using */ statevalues = CheckStateArray(statearray, "vector_accum"); dim = STATE_DIMS(statearray); newarr = dim == 0; if (newarr) dim = newval->dim; else CheckExpectedDim(dim, newval->dim); n = statevalues[0] + 1.0; statedatums = CreateStateDatums(dim); statedatums[0] = Float8GetDatum(n); if (newarr) { for (int i = 0; i < dim; i++) statedatums[i + 1] = Float8GetDatum((double) x[i]); } else { for (int i 
= 0; i < dim; i++) { double v = statevalues[i + 1] + x[i]; /* Check for overflow */ if (isinf(v)) float_overflow_error(); statedatums[i + 1] = Float8GetDatum(v); } } /* Use float8 array like float4_accum */ result = construct_array(statedatums, dim + 1, FLOAT8OID, sizeof(float8), FLOAT8PASSBYVAL, TYPALIGN_DOUBLE); pfree(statedatums); PG_RETURN_ARRAYTYPE_P(result); } /* * Combine vectors */ PGDLLEXPORT PG_FUNCTION_INFO_V1(vector_combine); Datum vector_combine(PG_FUNCTION_ARGS) { ArrayType *statearray1 = PG_GETARG_ARRAYTYPE_P(0); ArrayType *statearray2 = PG_GETARG_ARRAYTYPE_P(1); float8 *statevalues1; float8 *statevalues2; float8 n; float8 n1; float8 n2; int16 dim; Datum *statedatums; ArrayType *result; /* Check arrays before using */ statevalues1 = CheckStateArray(statearray1, "vector_combine"); statevalues2 = CheckStateArray(statearray2, "vector_combine"); n1 = statevalues1[0]; n2 = statevalues2[0]; if (n1 == 0.0) { n = n2; dim = STATE_DIMS(statearray2); statedatums = CreateStateDatums(dim); for (int i = 1; i <= dim; i++) statedatums[i] = Float8GetDatum(statevalues2[i]); } else if (n2 == 0.0) { n = n1; dim = STATE_DIMS(statearray1); statedatums = CreateStateDatums(dim); for (int i = 1; i <= dim; i++) statedatums[i] = Float8GetDatum(statevalues1[i]); } else { n = n1 + n2; dim = STATE_DIMS(statearray1); CheckExpectedDim(dim, STATE_DIMS(statearray2)); statedatums = CreateStateDatums(dim); for (int i = 1; i <= dim; i++) { double v = statevalues1[i] + statevalues2[i]; /* Check for overflow */ if (isinf(v)) float_overflow_error(); statedatums[i] = Float8GetDatum(v); } } statedatums[0] = Float8GetDatum(n); result = construct_array(statedatums, dim + 1, FLOAT8OID, sizeof(float8), FLOAT8PASSBYVAL, TYPALIGN_DOUBLE); pfree(statedatums); PG_RETURN_ARRAYTYPE_P(result); } /* * Average vectors */ PGDLLEXPORT PG_FUNCTION_INFO_V1(vector_avg); Datum vector_avg(PG_FUNCTION_ARGS) { ArrayType *statearray = PG_GETARG_ARRAYTYPE_P(0); float8 *statevalues; float8 n; uint16 dim; Vector 
*result; /* Check array before using */ statevalues = CheckStateArray(statearray, "vector_avg"); n = statevalues[0]; /* SQL defines AVG of no values to be NULL */ if (n == 0.0) PG_RETURN_NULL(); /* Create vector */ dim = STATE_DIMS(statearray); CheckDim(dim); result = InitVector(dim); for (int i = 0; i < dim; i++) { result->x[i] = statevalues[i + 1] / n; CheckElement(result->x[i]); } PG_RETURN_POINTER(result); } pgvector-0.6.0/src/vector.h000066400000000000000000000012051455577216400156260ustar00rootroot00000000000000#ifndef VECTOR_H #define VECTOR_H #define VECTOR_MAX_DIM 16000 #define VECTOR_SIZE(_dim) (offsetof(Vector, x) + sizeof(float)*(_dim)) #define DatumGetVector(x) ((Vector *) PG_DETOAST_DATUM(x)) #define PG_GETARG_VECTOR_P(x) DatumGetVector(PG_GETARG_DATUM(x)) #define PG_RETURN_VECTOR_P(x) PG_RETURN_POINTER(x) typedef struct Vector { int32 vl_len_; /* varlena header (do not touch directly!) */ int16 dim; /* number of dimensions */ int16 unused; float x[FLEXIBLE_ARRAY_MEMBER]; } Vector; Vector *InitVector(int dim); void PrintVector(char *msg, Vector * vector); int vector_cmp_internal(Vector * a, Vector * b); #endif pgvector-0.6.0/test/000077500000000000000000000000001455577216400143455ustar00rootroot00000000000000pgvector-0.6.0/test/expected/000077500000000000000000000000001455577216400161465ustar00rootroot00000000000000pgvector-0.6.0/test/expected/btree.out000066400000000000000000000005041455577216400177770ustar00rootroot00000000000000SET enable_seqscan = off; CREATE TABLE t (val vector(3)); INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); CREATE INDEX ON t (val); SELECT * FROM t WHERE val = '[1,2,3]'; val --------- [1,2,3] (1 row) SELECT * FROM t ORDER BY val LIMIT 1; val --------- [0,0,0] (1 row) DROP TABLE t; pgvector-0.6.0/test/expected/cast.out000066400000000000000000000023061455577216400176320ustar00rootroot00000000000000SELECT ARRAY[1,2,3]::vector; array --------- [1,2,3] (1 row) SELECT ARRAY[1.0,2.0,3.0]::vector; array 
--------- [1,2,3] (1 row) SELECT ARRAY[1,2,3]::float4[]::vector; array --------- [1,2,3] (1 row) SELECT ARRAY[1,2,3]::float8[]::vector; array --------- [1,2,3] (1 row) SELECT ARRAY[1,2,3]::numeric[]::vector; array --------- [1,2,3] (1 row) SELECT '{NULL}'::real[]::vector; ERROR: array must not contain nulls SELECT '{NaN}'::real[]::vector; ERROR: NaN not allowed in vector SELECT '{Infinity}'::real[]::vector; ERROR: infinite value not allowed in vector SELECT '{-Infinity}'::real[]::vector; ERROR: infinite value not allowed in vector SELECT '{}'::real[]::vector; ERROR: vector must have at least 1 dimension SELECT '{{1}}'::real[]::vector; ERROR: array must be 1-D SELECT '[1,2,3]'::vector::real[]; float4 --------- {1,2,3} (1 row) SELECT array_agg(n)::vector FROM generate_series(1, 16001) n; ERROR: vector cannot have more than 16000 dimensions SELECT array_to_vector(array_agg(n), 16001, false) FROM generate_series(1, 16001) n; ERROR: vector cannot have more than 16000 dimensions -- ensure no error SELECT ARRAY[1,2,3] = ARRAY[1,2,3]; ?column? ---------- t (1 row) pgvector-0.6.0/test/expected/copy.out000066400000000000000000000005531455577216400176540ustar00rootroot00000000000000CREATE TABLE t (val vector(3)); INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); CREATE TABLE t2 (val vector(3)); \copy t TO 'results/data.bin' WITH (FORMAT binary) \copy t2 FROM 'results/data.bin' WITH (FORMAT binary) SELECT * FROM t2 ORDER BY val; val --------- [0,0,0] [1,1,1] [1,2,3] (4 rows) DROP TABLE t; DROP TABLE t2; pgvector-0.6.0/test/expected/functions.out000066400000000000000000000076011455577216400207130ustar00rootroot00000000000000SELECT '[1,2,3]'::vector + '[4,5,6]'; ?column? ---------- [5,7,9] (1 row) SELECT '[3e38]'::vector + '[3e38]'; ERROR: value out of range: overflow SELECT '[1,2,3]'::vector - '[4,5,6]'; ?column? 
------------ [-3,-3,-3] (1 row) SELECT '[-3e38]'::vector - '[3e38]'; ERROR: value out of range: overflow SELECT '[1,2,3]'::vector * '[4,5,6]'; ?column? ----------- [4,10,18] (1 row) SELECT '[1e37]'::vector * '[1e37]'; ERROR: value out of range: overflow SELECT '[1e-37]'::vector * '[1e-37]'; ERROR: value out of range: underflow SELECT vector_dims('[1,2,3]'); vector_dims ------------- 3 (1 row) SELECT round(vector_norm('[1,1]')::numeric, 5); round --------- 1.41421 (1 row) SELECT vector_norm('[3,4]'); vector_norm ------------- 5 (1 row) SELECT vector_norm('[0,1]'); vector_norm ------------- 1 (1 row) SELECT vector_norm('[3e37,4e37]')::real; vector_norm ------------- 5e+37 (1 row) SELECT l2_distance('[0,0]', '[3,4]'); l2_distance ------------- 5 (1 row) SELECT l2_distance('[0,0]', '[0,1]'); l2_distance ------------- 1 (1 row) SELECT l2_distance('[1,2]', '[3]'); ERROR: different vector dimensions 2 and 1 SELECT l2_distance('[3e38]', '[-3e38]'); l2_distance ------------- Infinity (1 row) SELECT inner_product('[1,2]', '[3,4]'); inner_product --------------- 11 (1 row) SELECT inner_product('[1,2]', '[3]'); ERROR: different vector dimensions 2 and 1 SELECT inner_product('[3e38]', '[3e38]'); inner_product --------------- Infinity (1 row) SELECT cosine_distance('[1,2]', '[2,4]'); cosine_distance ----------------- 0 (1 row) SELECT cosine_distance('[1,2]', '[0,0]'); cosine_distance ----------------- NaN (1 row) SELECT cosine_distance('[1,1]', '[1,1]'); cosine_distance ----------------- 0 (1 row) SELECT cosine_distance('[1,0]', '[0,2]'); cosine_distance ----------------- 1 (1 row) SELECT cosine_distance('[1,1]', '[-1,-1]'); cosine_distance ----------------- 2 (1 row) SELECT cosine_distance('[1,2]', '[3]'); ERROR: different vector dimensions 2 and 1 SELECT cosine_distance('[1,1]', '[1.1,1.1]'); cosine_distance ----------------- 0 (1 row) SELECT cosine_distance('[1,1]', '[-1.1,-1.1]'); cosine_distance ----------------- 2 (1 row) SELECT cosine_distance('[3e38]', '[3e38]'); 
cosine_distance ----------------- NaN (1 row) SELECT l1_distance('[0,0]', '[3,4]'); l1_distance ------------- 7 (1 row) SELECT l1_distance('[0,0]', '[0,1]'); l1_distance ------------- 1 (1 row) SELECT l1_distance('[1,2]', '[3]'); ERROR: different vector dimensions 2 and 1 SELECT l1_distance('[3e38]', '[-3e38]'); l1_distance ------------- Infinity (1 row) SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]']) v; avg ----------- [2,3.5,5] (1 row) SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]', NULL]) v; avg ----------- [2,3.5,5] (1 row) SELECT avg(v) FROM unnest(ARRAY[]::vector[]) v; avg ----- (1 row) SELECT avg(v) FROM unnest(ARRAY['[1,2]'::vector, '[3]']) v; ERROR: expected 2 dimensions, not 1 SELECT avg(v) FROM unnest(ARRAY['[3e38]'::vector, '[3e38]']) v; avg --------- [3e+38] (1 row) SELECT vector_avg(array_agg(n)) FROM generate_series(1, 16002) n; ERROR: vector cannot have more than 16000 dimensions SELECT sum(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]']) v; sum ---------- [4,7,10] (1 row) SELECT sum(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]', NULL]) v; sum ---------- [4,7,10] (1 row) SELECT sum(v) FROM unnest(ARRAY[]::vector[]) v; sum ----- (1 row) SELECT sum(v) FROM unnest(ARRAY['[1,2]'::vector, '[3]']) v; ERROR: different vector dimensions 2 and 1 SELECT sum(v) FROM unnest(ARRAY['[3e38]'::vector, '[3e38]']) v; ERROR: value out of range: overflow pgvector-0.6.0/test/expected/hnsw_cosine.out000066400000000000000000000010551455577216400212170ustar00rootroot00000000000000SET enable_seqscan = off; CREATE TABLE t (val vector(3)); INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); CREATE INDEX ON t USING hnsw (val vector_cosine_ops); INSERT INTO t (val) VALUES ('[1,2,4]'); SELECT * FROM t ORDER BY val <=> '[3,3,3]'; val --------- [1,1,1] [1,2,3] [1,2,4] (3 rows) SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <=> '[0,0,0]') t2; count ------- 3 (1 row) SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <=> 
(SELECT NULL::vector)) t2; count ------- 3 (1 row) DROP TABLE t; pgvector-0.6.0/test/expected/hnsw_ip.out000066400000000000000000000007141455577216400203500ustar00rootroot00000000000000SET enable_seqscan = off; CREATE TABLE t (val vector(3)); INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); CREATE INDEX ON t USING hnsw (val vector_ip_ops); INSERT INTO t (val) VALUES ('[1,2,4]'); SELECT * FROM t ORDER BY val <#> '[3,3,3]'; val --------- [1,2,4] [1,2,3] [1,1,1] [0,0,0] (4 rows) SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <#> (SELECT NULL::vector)) t2; count ------- 4 (1 row) DROP TABLE t; pgvector-0.6.0/test/expected/hnsw_l2.out000066400000000000000000000011321455577216400202500ustar00rootroot00000000000000SET enable_seqscan = off; CREATE TABLE t (val vector(3)); INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); CREATE INDEX ON t USING hnsw (val vector_l2_ops); INSERT INTO t (val) VALUES ('[1,2,4]'); SELECT * FROM t ORDER BY val <-> '[3,3,3]'; val --------- [1,2,3] [1,2,4] [1,1,1] [0,0,0] (4 rows) SELECT * FROM t ORDER BY val <-> (SELECT NULL::vector); val --------- [0,0,0] [1,1,1] [1,2,3] [1,2,4] (4 rows) SELECT COUNT(*) FROM t; count ------- 5 (1 row) TRUNCATE t; SELECT * FROM t ORDER BY val <-> '[3,3,3]'; val ----- (0 rows) DROP TABLE t; pgvector-0.6.0/test/expected/hnsw_options.out000066400000000000000000000022251455577216400214320ustar00rootroot00000000000000CREATE TABLE t (val vector(3)); CREATE INDEX ON t USING hnsw (val vector_l2_ops) WITH (m = 1); ERROR: value 1 out of bounds for option "m" DETAIL: Valid values are between "2" and "100". CREATE INDEX ON t USING hnsw (val vector_l2_ops) WITH (m = 101); ERROR: value 101 out of bounds for option "m" DETAIL: Valid values are between "2" and "100". CREATE INDEX ON t USING hnsw (val vector_l2_ops) WITH (ef_construction = 3); ERROR: value 3 out of bounds for option "ef_construction" DETAIL: Valid values are between "4" and "1000". 
CREATE INDEX ON t USING hnsw (val vector_l2_ops) WITH (ef_construction = 1001); ERROR: value 1001 out of bounds for option "ef_construction" DETAIL: Valid values are between "4" and "1000". CREATE INDEX ON t USING hnsw (val vector_l2_ops) WITH (m = 16, ef_construction = 31); ERROR: ef_construction must be greater than or equal to 2 * m SHOW hnsw.ef_search; hnsw.ef_search ---------------- 40 (1 row) SET hnsw.ef_search = 0; ERROR: 0 is outside the valid range for parameter "hnsw.ef_search" (1 .. 1000) SET hnsw.ef_search = 1001; ERROR: 1001 is outside the valid range for parameter "hnsw.ef_search" (1 .. 1000) DROP TABLE t; pgvector-0.6.0/test/expected/hnsw_unlogged.out000066400000000000000000000004621455577216400215440ustar00rootroot00000000000000SET enable_seqscan = off; CREATE UNLOGGED TABLE t (val vector(3)); INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); CREATE INDEX ON t USING hnsw (val vector_l2_ops); SELECT * FROM t ORDER BY val <-> '[3,3,3]'; val --------- [1,2,3] [1,1,1] [0,0,0] (3 rows) DROP TABLE t; pgvector-0.6.0/test/expected/input.out000066400000000000000000000057731455577216400200520ustar00rootroot00000000000000SELECT '[1,2,3]'::vector; vector --------- [1,2,3] (1 row) SELECT '[-1,-2,-3]'::vector; vector ------------ [-1,-2,-3] (1 row) SELECT '[1.,2.,3.]'::vector; vector --------- [1,2,3] (1 row) SELECT ' [ 1, 2 , 3 ] '::vector; vector --------- [1,2,3] (1 row) SELECT '[1.23456]'::vector; vector ----------- [1.23456] (1 row) SELECT '[hello,1]'::vector; ERROR: invalid input syntax for type vector: "[hello,1]" LINE 1: SELECT '[hello,1]'::vector; ^ SELECT '[NaN,1]'::vector; ERROR: NaN not allowed in vector LINE 1: SELECT '[NaN,1]'::vector; ^ SELECT '[Infinity,1]'::vector; ERROR: infinite value not allowed in vector LINE 1: SELECT '[Infinity,1]'::vector; ^ SELECT '[-Infinity,1]'::vector; ERROR: infinite value not allowed in vector LINE 1: SELECT '[-Infinity,1]'::vector; ^ SELECT '[1.5e38,-1.5e38]'::vector; vector 
-------------------- [1.5e+38,-1.5e+38] (1 row) SELECT '[1.5e+38,-1.5e+38]'::vector; vector -------------------- [1.5e+38,-1.5e+38] (1 row) SELECT '[1.5e-38,-1.5e-38]'::vector; vector -------------------- [1.5e-38,-1.5e-38] (1 row) SELECT '[4e38,1]'::vector; ERROR: infinite value not allowed in vector LINE 1: SELECT '[4e38,1]'::vector; ^ SELECT '[1,2,3'::vector; ERROR: malformed vector literal: "[1,2,3" LINE 1: SELECT '[1,2,3'::vector; ^ DETAIL: Unexpected end of input. SELECT '[1,2,3]9'::vector; ERROR: malformed vector literal: "[1,2,3]9" LINE 1: SELECT '[1,2,3]9'::vector; ^ DETAIL: Junk after closing right brace. SELECT '1,2,3'::vector; ERROR: malformed vector literal: "1,2,3" LINE 1: SELECT '1,2,3'::vector; ^ DETAIL: Vector contents must start with "[". SELECT ''::vector; ERROR: malformed vector literal: "" LINE 1: SELECT ''::vector; ^ DETAIL: Vector contents must start with "[". SELECT '['::vector; ERROR: malformed vector literal: "[" LINE 1: SELECT '['::vector; ^ DETAIL: Unexpected end of input. SELECT '[,'::vector; ERROR: malformed vector literal: "[," LINE 1: SELECT '[,'::vector; ^ DETAIL: Unexpected end of input. 
SELECT '[]'::vector; ERROR: vector must have at least 1 dimension LINE 1: SELECT '[]'::vector; ^ SELECT '[1,]'::vector; ERROR: invalid input syntax for type vector: "[1,]" LINE 1: SELECT '[1,]'::vector; ^ SELECT '[1a]'::vector; ERROR: invalid input syntax for type vector: "[1a]" LINE 1: SELECT '[1a]'::vector; ^ SELECT '[1,,3]'::vector; ERROR: malformed vector literal: "[1,,3]" LINE 1: SELECT '[1,,3]'::vector; ^ SELECT '[1, ,3]'::vector; ERROR: invalid input syntax for type vector: "[1, ,3]" LINE 1: SELECT '[1, ,3]'::vector; ^ SELECT '[1,2,3]'::vector(2); ERROR: expected 2 dimensions, not 3 SELECT unnest('{"[1,2,3]", "[4,5,6]"}'::vector[]); unnest --------- [1,2,3] [4,5,6] (2 rows) SELECT '{"[1,2,3]"}'::vector(2)[]; ERROR: expected 2 dimensions, not 3 pgvector-0.6.0/test/expected/ivfflat_cosine.out000066400000000000000000000011011455577216400216630ustar00rootroot00000000000000SET enable_seqscan = off; CREATE TABLE t (val vector(3)); INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); CREATE INDEX ON t USING ivfflat (val vector_cosine_ops) WITH (lists = 1); INSERT INTO t (val) VALUES ('[1,2,4]'); SELECT * FROM t ORDER BY val <=> '[3,3,3]'; val --------- [1,1,1] [1,2,3] [1,2,4] (3 rows) SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <=> '[0,0,0]') t2; count ------- 3 (1 row) SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <=> (SELECT NULL::vector)) t2; count ------- 3 (1 row) DROP TABLE t; pgvector-0.6.0/test/expected/ivfflat_ip.out000066400000000000000000000007401455577216400210230ustar00rootroot00000000000000SET enable_seqscan = off; CREATE TABLE t (val vector(3)); INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); CREATE INDEX ON t USING ivfflat (val vector_ip_ops) WITH (lists = 1); INSERT INTO t (val) VALUES ('[1,2,4]'); SELECT * FROM t ORDER BY val <#> '[3,3,3]'; val --------- [1,2,4] [1,2,3] [1,1,1] [0,0,0] (4 rows) SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <#> (SELECT NULL::vector)) t2; count ------- 
4 (1 row) DROP TABLE t; pgvector-0.6.0/test/expected/ivfflat_l2.out000066400000000000000000000013701455577216400207300ustar00rootroot00000000000000SET enable_seqscan = off; CREATE TABLE t (val vector(3)); INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); CREATE INDEX ON t USING ivfflat (val vector_l2_ops) WITH (lists = 1); INSERT INTO t (val) VALUES ('[1,2,4]'); SELECT * FROM t ORDER BY val <-> '[3,3,3]'; val --------- [1,2,3] [1,2,4] [1,1,1] [0,0,0] (4 rows) SELECT * FROM t ORDER BY val <-> (SELECT NULL::vector); val --------- [0,0,0] [1,1,1] [1,2,3] [1,2,4] (4 rows) SELECT COUNT(*) FROM t; count ------- 5 (1 row) TRUNCATE t; NOTICE: ivfflat index created with little data DETAIL: This will cause low recall. HINT: Drop the index until the table has more data. SELECT * FROM t ORDER BY val <-> '[3,3,3]'; val ----- (0 rows) DROP TABLE t; pgvector-0.6.0/test/expected/ivfflat_options.out000066400000000000000000000007151455577216400221100ustar00rootroot00000000000000CREATE TABLE t (val vector(3)); CREATE INDEX ON t USING ivfflat (val vector_l2_ops) WITH (lists = 0); ERROR: value 0 out of bounds for option "lists" DETAIL: Valid values are between "1" and "32768". CREATE INDEX ON t USING ivfflat (val vector_l2_ops) WITH (lists = 32769); ERROR: value 32769 out of bounds for option "lists" DETAIL: Valid values are between "1" and "32768". 
SHOW ivfflat.probes; ivfflat.probes ---------------- 1 (1 row) DROP TABLE t; pgvector-0.6.0/test/expected/ivfflat_unlogged.out000066400000000000000000000005061455577216400222170ustar00rootroot00000000000000SET enable_seqscan = off; CREATE UNLOGGED TABLE t (val vector(3)); INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); CREATE INDEX ON t USING ivfflat (val vector_l2_ops) WITH (lists = 1); SELECT * FROM t ORDER BY val <-> '[3,3,3]'; val --------- [1,2,3] [1,1,1] [0,0,0] (3 rows) DROP TABLE t; pgvector-0.6.0/test/perl/000077500000000000000000000000001455577216400153075ustar00rootroot00000000000000pgvector-0.6.0/test/perl/PostgresNode.pm000066400000000000000000000001451455577216400202610ustar00rootroot00000000000000use PostgreSQL::Test::Cluster; sub get_new_node { return PostgreSQL::Test::Cluster->new(@_); } 1; pgvector-0.6.0/test/perl/TestLib.pm000066400000000000000000000000411455577216400172060ustar00rootroot00000000000000use PostgreSQL::Test::Utils; 1; pgvector-0.6.0/test/sql/000077500000000000000000000000001455577216400151445ustar00rootroot00000000000000pgvector-0.6.0/test/sql/btree.sql000066400000000000000000000003731455577216400167710ustar00rootroot00000000000000SET enable_seqscan = off; CREATE TABLE t (val vector(3)); INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); CREATE INDEX ON t (val); SELECT * FROM t WHERE val = '[1,2,3]'; SELECT * FROM t ORDER BY val LIMIT 1; DROP TABLE t; pgvector-0.6.0/test/sql/cast.sql000066400000000000000000000011541455577216400166200ustar00rootroot00000000000000SELECT ARRAY[1,2,3]::vector; SELECT ARRAY[1.0,2.0,3.0]::vector; SELECT ARRAY[1,2,3]::float4[]::vector; SELECT ARRAY[1,2,3]::float8[]::vector; SELECT ARRAY[1,2,3]::numeric[]::vector; SELECT '{NULL}'::real[]::vector; SELECT '{NaN}'::real[]::vector; SELECT '{Infinity}'::real[]::vector; SELECT '{-Infinity}'::real[]::vector; SELECT '{}'::real[]::vector; SELECT '{{1}}'::real[]::vector; SELECT '[1,2,3]'::vector::real[]; SELECT 
array_agg(n)::vector FROM generate_series(1, 16001) n; SELECT array_to_vector(array_agg(n), 16001, false) FROM generate_series(1, 16001) n; -- ensure no error SELECT ARRAY[1,2,3] = ARRAY[1,2,3]; pgvector-0.6.0/test/sql/copy.sql000066400000000000000000000004641455577216400166430ustar00rootroot00000000000000CREATE TABLE t (val vector(3)); INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); CREATE TABLE t2 (val vector(3)); \copy t TO 'results/data.bin' WITH (FORMAT binary) \copy t2 FROM 'results/data.bin' WITH (FORMAT binary) SELECT * FROM t2 ORDER BY val; DROP TABLE t; DROP TABLE t2; pgvector-0.6.0/test/sql/functions.sql000066400000000000000000000036351455577216400177040ustar00rootroot00000000000000SELECT '[1,2,3]'::vector + '[4,5,6]'; SELECT '[3e38]'::vector + '[3e38]'; SELECT '[1,2,3]'::vector - '[4,5,6]'; SELECT '[-3e38]'::vector - '[3e38]'; SELECT '[1,2,3]'::vector * '[4,5,6]'; SELECT '[1e37]'::vector * '[1e37]'; SELECT '[1e-37]'::vector * '[1e-37]'; SELECT vector_dims('[1,2,3]'); SELECT round(vector_norm('[1,1]')::numeric, 5); SELECT vector_norm('[3,4]'); SELECT vector_norm('[0,1]'); SELECT vector_norm('[3e37,4e37]')::real; SELECT l2_distance('[0,0]', '[3,4]'); SELECT l2_distance('[0,0]', '[0,1]'); SELECT l2_distance('[1,2]', '[3]'); SELECT l2_distance('[3e38]', '[-3e38]'); SELECT inner_product('[1,2]', '[3,4]'); SELECT inner_product('[1,2]', '[3]'); SELECT inner_product('[3e38]', '[3e38]'); SELECT cosine_distance('[1,2]', '[2,4]'); SELECT cosine_distance('[1,2]', '[0,0]'); SELECT cosine_distance('[1,1]', '[1,1]'); SELECT cosine_distance('[1,0]', '[0,2]'); SELECT cosine_distance('[1,1]', '[-1,-1]'); SELECT cosine_distance('[1,2]', '[3]'); SELECT cosine_distance('[1,1]', '[1.1,1.1]'); SELECT cosine_distance('[1,1]', '[-1.1,-1.1]'); SELECT cosine_distance('[3e38]', '[3e38]'); SELECT l1_distance('[0,0]', '[3,4]'); SELECT l1_distance('[0,0]', '[0,1]'); SELECT l1_distance('[1,2]', '[3]'); SELECT l1_distance('[3e38]', '[-3e38]'); SELECT avg(v) 
FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]']) v; SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]', NULL]) v; SELECT avg(v) FROM unnest(ARRAY[]::vector[]) v; SELECT avg(v) FROM unnest(ARRAY['[1,2]'::vector, '[3]']) v; SELECT avg(v) FROM unnest(ARRAY['[3e38]'::vector, '[3e38]']) v; SELECT vector_avg(array_agg(n)) FROM generate_series(1, 16002) n; SELECT sum(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]']) v; SELECT sum(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]', NULL]) v; SELECT sum(v) FROM unnest(ARRAY[]::vector[]) v; SELECT sum(v) FROM unnest(ARRAY['[1,2]'::vector, '[3]']) v; SELECT sum(v) FROM unnest(ARRAY['[3e38]'::vector, '[3e38]']) v; pgvector-0.6.0/test/sql/hnsw_cosine.sql000066400000000000000000000006701455577216400202070ustar00rootroot00000000000000SET enable_seqscan = off; CREATE TABLE t (val vector(3)); INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); CREATE INDEX ON t USING hnsw (val vector_cosine_ops); INSERT INTO t (val) VALUES ('[1,2,4]'); SELECT * FROM t ORDER BY val <=> '[3,3,3]'; SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <=> '[0,0,0]') t2; SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <=> (SELECT NULL::vector)) t2; DROP TABLE t; pgvector-0.6.0/test/sql/hnsw_ip.sql000066400000000000000000000005561455577216400173420ustar00rootroot00000000000000SET enable_seqscan = off; CREATE TABLE t (val vector(3)); INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); CREATE INDEX ON t USING hnsw (val vector_ip_ops); INSERT INTO t (val) VALUES ('[1,2,4]'); SELECT * FROM t ORDER BY val <#> '[3,3,3]'; SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <#> (SELECT NULL::vector)) t2; DROP TABLE t; pgvector-0.6.0/test/sql/hnsw_l2.sql000066400000000000000000000006451455577216400172460ustar00rootroot00000000000000SET enable_seqscan = off; CREATE TABLE t (val vector(3)); INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); CREATE INDEX ON t USING hnsw (val 
vector_l2_ops); INSERT INTO t (val) VALUES ('[1,2,4]'); SELECT * FROM t ORDER BY val <-> '[3,3,3]'; SELECT * FROM t ORDER BY val <-> (SELECT NULL::vector); SELECT COUNT(*) FROM t; TRUNCATE t; SELECT * FROM t ORDER BY val <-> '[3,3,3]'; DROP TABLE t; pgvector-0.6.0/test/sql/hnsw_options.sql000066400000000000000000000007541455577216400204250ustar00rootroot00000000000000CREATE TABLE t (val vector(3)); CREATE INDEX ON t USING hnsw (val vector_l2_ops) WITH (m = 1); CREATE INDEX ON t USING hnsw (val vector_l2_ops) WITH (m = 101); CREATE INDEX ON t USING hnsw (val vector_l2_ops) WITH (ef_construction = 3); CREATE INDEX ON t USING hnsw (val vector_l2_ops) WITH (ef_construction = 1001); CREATE INDEX ON t USING hnsw (val vector_l2_ops) WITH (m = 16, ef_construction = 31); SHOW hnsw.ef_search; SET hnsw.ef_search = 0; SET hnsw.ef_search = 1001; DROP TABLE t; pgvector-0.6.0/test/sql/hnsw_unlogged.sql000066400000000000000000000003741455577216400205340ustar00rootroot00000000000000SET enable_seqscan = off; CREATE UNLOGGED TABLE t (val vector(3)); INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); CREATE INDEX ON t USING hnsw (val vector_l2_ops); SELECT * FROM t ORDER BY val <-> '[3,3,3]'; DROP TABLE t; pgvector-0.6.0/test/sql/input.sql000066400000000000000000000014051455577216400170240ustar00rootroot00000000000000SELECT '[1,2,3]'::vector; SELECT '[-1,-2,-3]'::vector; SELECT '[1.,2.,3.]'::vector; SELECT ' [ 1, 2 , 3 ] '::vector; SELECT '[1.23456]'::vector; SELECT '[hello,1]'::vector; SELECT '[NaN,1]'::vector; SELECT '[Infinity,1]'::vector; SELECT '[-Infinity,1]'::vector; SELECT '[1.5e38,-1.5e38]'::vector; SELECT '[1.5e+38,-1.5e+38]'::vector; SELECT '[1.5e-38,-1.5e-38]'::vector; SELECT '[4e38,1]'::vector; SELECT '[1,2,3'::vector; SELECT '[1,2,3]9'::vector; SELECT '1,2,3'::vector; SELECT ''::vector; SELECT '['::vector; SELECT '[,'::vector; SELECT '[]'::vector; SELECT '[1,]'::vector; SELECT '[1a]'::vector; SELECT '[1,,3]'::vector; SELECT '[1, ,3]'::vector; 
SELECT '[1,2,3]'::vector(2); SELECT unnest('{"[1,2,3]", "[4,5,6]"}'::vector[]); SELECT '{"[1,2,3]"}'::vector(2)[]; pgvector-0.6.0/test/sql/ivfflat_cosine.sql000066400000000000000000000007141455577216400206620ustar00rootroot00000000000000SET enable_seqscan = off; CREATE TABLE t (val vector(3)); INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); CREATE INDEX ON t USING ivfflat (val vector_cosine_ops) WITH (lists = 1); INSERT INTO t (val) VALUES ('[1,2,4]'); SELECT * FROM t ORDER BY val <=> '[3,3,3]'; SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <=> '[0,0,0]') t2; SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <=> (SELECT NULL::vector)) t2; DROP TABLE t; pgvector-0.6.0/test/sql/ivfflat_ip.sql000066400000000000000000000006021455577216400200060ustar00rootroot00000000000000SET enable_seqscan = off; CREATE TABLE t (val vector(3)); INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); CREATE INDEX ON t USING ivfflat (val vector_ip_ops) WITH (lists = 1); INSERT INTO t (val) VALUES ('[1,2,4]'); SELECT * FROM t ORDER BY val <#> '[3,3,3]'; SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <#> (SELECT NULL::vector)) t2; DROP TABLE t; pgvector-0.6.0/test/sql/ivfflat_l2.sql000066400000000000000000000006711455577216400177210ustar00rootroot00000000000000SET enable_seqscan = off; CREATE TABLE t (val vector(3)); INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); CREATE INDEX ON t USING ivfflat (val vector_l2_ops) WITH (lists = 1); INSERT INTO t (val) VALUES ('[1,2,4]'); SELECT * FROM t ORDER BY val <-> '[3,3,3]'; SELECT * FROM t ORDER BY val <-> (SELECT NULL::vector); SELECT COUNT(*) FROM t; TRUNCATE t; SELECT * FROM t ORDER BY val <-> '[3,3,3]'; DROP TABLE t; pgvector-0.6.0/test/sql/ivfflat_options.sql000066400000000000000000000003251455577216400210730ustar00rootroot00000000000000CREATE TABLE t (val vector(3)); CREATE INDEX ON t USING ivfflat (val vector_l2_ops) WITH (lists = 0); CREATE INDEX ON t USING 
ivfflat (val vector_l2_ops) WITH (lists = 32769); SHOW ivfflat.probes; DROP TABLE t; pgvector-0.6.0/test/sql/ivfflat_unlogged.sql000066400000000000000000000004201455577216400212000ustar00rootroot00000000000000SET enable_seqscan = off; CREATE UNLOGGED TABLE t (val vector(3)); INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); CREATE INDEX ON t USING ivfflat (val vector_l2_ops) WITH (lists = 1); SELECT * FROM t ORDER BY val <-> '[3,3,3]'; DROP TABLE t; pgvector-0.6.0/test/t/000077500000000000000000000000001455577216400146105ustar00rootroot00000000000000pgvector-0.6.0/test/t/001_ivfflat_wal.pl000066400000000000000000000057261455577216400200350ustar00rootroot00000000000000# Based on postgres/contrib/bloom/t/001_wal.pl # Test generic xlog record work for ivfflat index replication. use strict; use warnings; use PostgresNode; use TestLib; use Test::More; my $dim = 32; my $node_primary; my $node_replica; # Run few queries on both primary and replica and check their results match. sub test_index_replay { my ($test_name) = @_; # Wait for replica to catch up my $applname = $node_replica->name; my $caughtup_query = "SELECT pg_current_wal_lsn() <= replay_lsn FROM pg_stat_replication WHERE application_name = '$applname';"; $node_primary->poll_query_until('postgres', $caughtup_query) or die "Timed out while waiting for replica 1 to catch up"; my @r = (); for (1 .. $dim) { push(@r, rand()); } my $sql = join(",", @r); my $queries = qq( SET enable_seqscan = off; SELECT * FROM tst ORDER BY v <-> '[$sql]' LIMIT 10; ); # Run test queries and compare their result my $primary_result = $node_primary->safe_psql("postgres", $queries); my $replica_result = $node_replica->safe_psql("postgres", $queries); is($primary_result, $replica_result, "$test_name: query result matches"); return; } # Use ARRAY[random(), random(), random(), ...] 
over # SELECT array_agg(random()) FROM generate_series(1, $dim) # to generate different values for each row my $array_sql = join(",", ('random()') x $dim); # Initialize primary node $node_primary = get_new_node('primary'); $node_primary->init(allows_streaming => 1); if ($dim > 32) { # TODO use wal_keep_segments for Postgres < 13 $node_primary->append_conf('postgresql.conf', qq(wal_keep_size = 1GB)); } if ($dim > 1500) { $node_primary->append_conf('postgresql.conf', qq(maintenance_work_mem = 128MB)); } $node_primary->start; my $backup_name = 'my_backup'; # Take backup $node_primary->backup($backup_name); # Create streaming replica linking to primary $node_replica = get_new_node('replica'); $node_replica->init_from_backup($node_primary, $backup_name, has_streaming => 1); $node_replica->start; # Create ivfflat index on primary $node_primary->safe_psql("postgres", "CREATE EXTENSION vector;"); $node_primary->safe_psql("postgres", "CREATE TABLE tst (i int4, v vector($dim));"); $node_primary->safe_psql("postgres", "INSERT INTO tst SELECT i % 10, ARRAY[$array_sql] FROM generate_series(1, 100000) i;" ); $node_primary->safe_psql("postgres", "CREATE INDEX ON tst USING ivfflat (v vector_l2_ops);"); # Test that queries give same result test_index_replay('initial'); # Run 10 cycles of table modification. Run test queries after each modification. for my $i (1 .. 
10) { $node_primary->safe_psql("postgres", "DELETE FROM tst WHERE i = $i;"); test_index_replay("delete $i"); $node_primary->safe_psql("postgres", "VACUUM tst;"); test_index_replay("vacuum $i"); my ($start, $end) = (100001 + ($i - 1) * 10000, 100000 + $i * 10000); $node_primary->safe_psql("postgres", "INSERT INTO tst SELECT i % 10, ARRAY[$array_sql] FROM generate_series($start, $end) i;" ); test_index_replay("insert $i"); } done_testing(); pgvector-0.6.0/test/t/002_ivfflat_vacuum.pl000066400000000000000000000022261455577216400205430ustar00rootroot00000000000000use strict; use warnings; use PostgresNode; use TestLib; use Test::More; my $dim = 3; my @r = (); for (1 .. $dim) { my $v = int(rand(1000)) + 1; push(@r, "i % $v"); } my $array_sql = join(", ", @r); # Initialize node my $node = get_new_node('node'); $node->init; $node->start; # Create table and index $node->safe_psql("postgres", "CREATE EXTENSION vector;"); $node->safe_psql("postgres", "CREATE TABLE tst (i int4, v vector($dim));"); $node->safe_psql("postgres", "INSERT INTO tst SELECT i % 10, ARRAY[$array_sql] FROM generate_series(1, 100000) i;" ); $node->safe_psql("postgres", "CREATE INDEX ON tst USING ivfflat (v vector_l2_ops);"); # Get size my $size = $node->safe_psql("postgres", "SELECT pg_total_relation_size('tst_v_idx');"); # Delete all, vacuum, and insert same data $node->safe_psql("postgres", "DELETE FROM tst;"); $node->safe_psql("postgres", "VACUUM tst;"); $node->safe_psql("postgres", "INSERT INTO tst SELECT i % 10, ARRAY[$array_sql] FROM generate_series(1, 100000) i;" ); # Check size my $new_size = $node->safe_psql("postgres", "SELECT pg_total_relation_size('tst_v_idx');"); is($size, $new_size, "size does not change"); done_testing(); pgvector-0.6.0/test/t/003_ivfflat_build_recall.pl000066400000000000000000000057151455577216400216730ustar00rootroot00000000000000use strict; use warnings; use PostgresNode; use TestLib; use Test::More; my $node; my @queries = (); my @expected; my $limit = 20; sub 
test_recall { my ($probes, $min, $operator) = @_; my $correct = 0; my $total = 0; my $explain = $node->safe_psql("postgres", qq( SET enable_seqscan = off; SET ivfflat.probes = $probes; EXPLAIN ANALYZE SELECT i FROM tst ORDER BY v $operator '$queries[0]' LIMIT $limit; )); like($explain, qr/Index Scan using idx on tst/); for my $i (0 .. $#queries) { my $actual = $node->safe_psql("postgres", qq( SET enable_seqscan = off; SET ivfflat.probes = $probes; SELECT i FROM tst ORDER BY v $operator '$queries[$i]' LIMIT $limit; )); my @actual_ids = split("\n", $actual); my %actual_set = map { $_ => 1 } @actual_ids; my @expected_ids = split("\n", $expected[$i]); foreach (@expected_ids) { if (exists($actual_set{$_})) { $correct++; } $total++; } } cmp_ok($correct / $total, ">=", $min, $operator); } # Initialize node $node = get_new_node('node'); $node->init; $node->start; # Create table $node->safe_psql("postgres", "CREATE EXTENSION vector;"); $node->safe_psql("postgres", "CREATE TABLE tst (i int4, v vector(3));"); $node->safe_psql("postgres", "INSERT INTO tst SELECT i, ARRAY[random(), random(), random()] FROM generate_series(1, 100000) i;" ); # Generate queries for (1 .. 20) { my $r1 = rand(); my $r2 = rand(); my $r3 = rand(); push(@queries, "[$r1,$r2,$r3]"); } # Check each index type my @operators = ("<->", "<#>", "<=>"); my @opclasses = ("vector_l2_ops", "vector_ip_ops", "vector_cosine_ops"); for my $i (0 .. 
$#operators) { my $operator = $operators[$i]; my $opclass = $opclasses[$i]; # Get exact results @expected = (); foreach (@queries) { my $res = $node->safe_psql("postgres", "SELECT i FROM tst ORDER BY v $operator '$_' LIMIT $limit;"); push(@expected, $res); } # Build index serially $node->safe_psql("postgres", qq( SET max_parallel_maintenance_workers = 0; CREATE INDEX idx ON tst USING ivfflat (v $opclass); )); # Test approximate results if ($operator ne "<#>") { # TODO Fix test (uniform random vectors all have similar inner product) test_recall(1, 0.71, $operator); test_recall(10, 0.95, $operator); } # Account for equal distances test_recall(100, 0.9925, $operator); $node->safe_psql("postgres", "DROP INDEX idx;"); # Build index in parallel my ($ret, $stdout, $stderr) = $node->psql("postgres", qq( SET client_min_messages = DEBUG; SET min_parallel_table_scan_size = 1; CREATE INDEX idx ON tst USING ivfflat (v $opclass); )); is($ret, 0, $stderr); like($stderr, qr/using \d+ parallel workers/); # Test approximate results if ($operator ne "<#>") { # TODO Fix test (uniform random vectors all have similar inner product) test_recall(1, 0.71, $operator); test_recall(10, 0.95, $operator); } # Account for equal distances test_recall(100, 0.9925, $operator); $node->safe_psql("postgres", "DROP INDEX idx;"); } done_testing(); pgvector-0.6.0/test/t/004_ivfflat_centers.pl000066400000000000000000000015311455577216400207060ustar00rootroot00000000000000use strict; use warnings; use PostgresNode; use TestLib; use Test::More; # Initialize node my $node = get_new_node('node'); $node->init; $node->start; # Create table $node->safe_psql("postgres", "CREATE EXTENSION vector;"); $node->safe_psql("postgres", "CREATE TABLE tst (i int4, v vector(3));"); $node->safe_psql("postgres", "INSERT INTO tst SELECT i, '[1,2,3]' FROM generate_series(1, 10) i;" ); sub test_centers { my ($lists, $min) = @_; my ($ret, $stdout, $stderr) = $node->psql("postgres", "CREATE INDEX ON tst USING ivfflat (v 
vector_l2_ops) WITH (lists = $lists);"); is($ret, 0, $stderr); } # Test no error for duplicate centers test_centers(5); test_centers(10); $node->safe_psql("postgres", "INSERT INTO tst SELECT i, '[4,5,6]' FROM generate_series(1, 10) i;" ); # Test no error for duplicate centers test_centers(10); done_testing(); pgvector-0.6.0/test/t/005_ivfflat_query_recall.pl000066400000000000000000000021131455577216400217300ustar00rootroot00000000000000use strict; use warnings; use PostgresNode; use TestLib; use Test::More; # Initialize node my $node = get_new_node('node'); $node->init; $node->start; # Create table $node->safe_psql("postgres", "CREATE EXTENSION vector;"); $node->safe_psql("postgres", "CREATE TABLE tst (i int4 primary key, v vector(3));"); $node->safe_psql("postgres", "INSERT INTO tst SELECT i, ARRAY[random(), random(), random()] FROM generate_series(1, 100000) i;" ); # Check each index type my @operators = ("<->", "<#>", "<=>"); my @opclasses = ("vector_l2_ops", "vector_ip_ops", "vector_cosine_ops"); for my $i (0 .. $#operators) { my $operator = $operators[$i]; my $opclass = $opclasses[$i]; # Add index $node->safe_psql("postgres", "CREATE INDEX ON tst USING ivfflat (v $opclass);"); # Test 100% recall for (1 .. 
20) { my $id = int(rand() * 100000); my $query = $node->safe_psql("postgres", "SELECT v FROM tst WHERE i = $id;"); my $res = $node->safe_psql("postgres", qq( SET enable_seqscan = off; SELECT v FROM tst ORDER BY v <-> '$query' LIMIT 1; )); is($res, $query); } } done_testing(); pgvector-0.6.0/test/t/006_ivfflat_lists.pl000066400000000000000000000020621455577216400204030ustar00rootroot00000000000000use strict; use warnings; use PostgresNode; use TestLib; use Test::More; # Initialize node my $node = get_new_node('node'); $node->init; $node->start; # Create table $node->safe_psql("postgres", "CREATE EXTENSION vector;"); $node->safe_psql("postgres", "CREATE TABLE tst (v vector(3));"); $node->safe_psql("postgres", "INSERT INTO tst SELECT ARRAY[random(), random(), random()] FROM generate_series(1, 100000) i;" ); $node->safe_psql("postgres", "CREATE INDEX lists50 ON tst USING ivfflat (v vector_l2_ops) WITH (lists = 50);"); $node->safe_psql("postgres", "CREATE INDEX lists100 ON tst USING ivfflat (v vector_l2_ops) WITH (lists = 100);"); # Test prefers more lists my $res = $node->safe_psql("postgres", "EXPLAIN SELECT v FROM tst ORDER BY v <-> '[0.5,0.5,0.5]' LIMIT 10;"); like($res, qr/lists100/); unlike($res, qr/lists50/); # Test errors with too much memory my ($ret, $stdout, $stderr) = $node->psql("postgres", "CREATE INDEX lists10000 ON tst USING ivfflat (v vector_l2_ops) WITH (lists = 10000);" ); like($stderr, qr/memory required is/); done_testing(); pgvector-0.6.0/test/t/007_ivfflat_inserts.pl000066400000000000000000000027051455577216400207410ustar00rootroot00000000000000use strict; use warnings; use PostgresNode; use TestLib; use Test::More; my $dim = 768; my $array_sql = join(",", ('random()') x $dim); # Initialize node my $node = get_new_node('node'); $node->init; $node->start; # Create table and index $node->safe_psql("postgres", "CREATE EXTENSION vector;"); $node->safe_psql("postgres", "CREATE TABLE tst (v vector($dim));"); $node->safe_psql("postgres", "INSERT INTO tst 
SELECT ARRAY[$array_sql] FROM generate_series(1, 10000) i;" ); $node->safe_psql("postgres", "CREATE INDEX ON tst USING ivfflat (v vector_l2_ops);"); $node->pgbench( "--no-vacuum --client=5 --transactions=100", 0, [qr{actually processed}], [qr{^$}], "concurrent INSERTs", { "007_ivfflat_inserts" => "INSERT INTO tst SELECT ARRAY[$array_sql] FROM generate_series(1, 10) i;" } ); sub idx_scan { # Stats do not update instantaneously # https://www.postgresql.org/docs/current/monitoring-stats.html#MONITORING-STATS-VIEWS sleep(1); $node->safe_psql("postgres", "SELECT idx_scan FROM pg_stat_user_indexes WHERE indexrelid = 'tst_v_idx'::regclass;"); } my $expected = 10000 + 5 * 100 * 10; my $count = $node->safe_psql("postgres", "SELECT COUNT(*) FROM tst;"); is($count, $expected); is(idx_scan(), 0); $count = $node->safe_psql("postgres", qq( SET enable_seqscan = off; SET ivfflat.probes = 100; SELECT COUNT(*) FROM (SELECT v FROM tst ORDER BY v <-> (SELECT v FROM tst LIMIT 1)) t; )); is($count, $expected); is(idx_scan(), 1); done_testing(); pgvector-0.6.0/test/t/008_aggregates.pl000066400000000000000000000024351455577216400176510ustar00rootroot00000000000000use strict; use warnings; use PostgresNode; use TestLib; use Test::More; # Initialize node my $node = get_new_node('node'); $node->init; $node->start; # Create table $node->safe_psql("postgres", "CREATE EXTENSION vector;"); $node->safe_psql("postgres", "CREATE TABLE tst (r1 real, r2 real, r3 real, v vector(3));"); $node->safe_psql("postgres", qq( INSERT INTO tst SELECT r1, r2, r3, ARRAY[r1, r2, r3] FROM ( SELECT random() + 1.01 AS r1, random() + 2.01 AS r2, random() + 3.01 AS r3 FROM generate_series(1, 1000000) t ) i; )); sub test_aggregate { my ($agg) = @_; # Test value my $res = $node->safe_psql("postgres", "SELECT $agg(v) FROM tst;"); like($res, qr/\[1\.5/); like($res, qr/,2\.5/); like($res, qr/,3\.5/); # Test matches real for avg # Cannot test sum since sum(real) varies between calls if ($agg eq 'avg') { my $r1 = 
$node->safe_psql("postgres", "SELECT $agg(r1)::float4 FROM tst;"); my $r2 = $node->safe_psql("postgres", "SELECT $agg(r2)::float4 FROM tst;"); my $r3 = $node->safe_psql("postgres", "SELECT $agg(r3)::float4 FROM tst;"); is($res, "[$r1,$r2,$r3]"); } # Test explain my $explain = $node->safe_psql("postgres", "EXPLAIN SELECT $agg(v) FROM tst;"); like($explain, qr/Partial Aggregate/); } test_aggregate('avg'); test_aggregate('sum'); done_testing(); pgvector-0.6.0/test/t/009_storage.pl000066400000000000000000000017561455577216400172120ustar00rootroot00000000000000use strict; use warnings; use PostgresNode; use TestLib; use Test::More; my $dim = 1024; # Initialize node my $node = get_new_node('node'); $node->init; $node->start; # Create table $node->safe_psql("postgres", "CREATE EXTENSION vector;"); $node->safe_psql("postgres", "CREATE TABLE tst (v1 vector(1024), v2 vector(1024), v3 vector(1024));"); # Test insert succeeds $node->safe_psql("postgres", "INSERT INTO tst SELECT array_agg(n), array_agg(n), array_agg(n) FROM generate_series(1, $dim) n" ); # Change storage to PLAIN $node->safe_psql("postgres", "ALTER TABLE tst ALTER COLUMN v1 SET STORAGE PLAIN"); $node->safe_psql("postgres", "ALTER TABLE tst ALTER COLUMN v2 SET STORAGE PLAIN"); $node->safe_psql("postgres", "ALTER TABLE tst ALTER COLUMN v3 SET STORAGE PLAIN"); # Test insert fails my ($ret, $stdout, $stderr) = $node->psql("postgres", "INSERT INTO tst SELECT array_agg(n), array_agg(n), array_agg(n) FROM generate_series(1, $dim) n" ); like($stderr, qr/row is too big/); done_testing(); pgvector-0.6.0/test/t/010_hnsw_wal.pl000066400000000000000000000057031455577216400173540ustar00rootroot00000000000000# Based on postgres/contrib/bloom/t/001_wal.pl # Test generic xlog record work for hnsw index replication. use strict; use warnings; use PostgresNode; use TestLib; use Test::More; my $dim = 32; my $node_primary; my $node_replica; # Run few queries on both primary and replica and check their results match. 
sub test_index_replay { my ($test_name) = @_; # Wait for replica to catch up my $applname = $node_replica->name; my $caughtup_query = "SELECT pg_current_wal_lsn() <= replay_lsn FROM pg_stat_replication WHERE application_name = '$applname';"; $node_primary->poll_query_until('postgres', $caughtup_query) or die "Timed out while waiting for replica 1 to catch up"; my @r = (); for (1 .. $dim) { push(@r, rand()); } my $sql = join(",", @r); my $queries = qq( SET enable_seqscan = off; SELECT * FROM tst ORDER BY v <-> '[$sql]' LIMIT 10; ); # Run test queries and compare their result my $primary_result = $node_primary->safe_psql("postgres", $queries); my $replica_result = $node_replica->safe_psql("postgres", $queries); is($primary_result, $replica_result, "$test_name: query result matches"); return; } # Use ARRAY[random(), random(), random(), ...] over # SELECT array_agg(random()) FROM generate_series(1, $dim) # to generate different values for each row my $array_sql = join(",", ('random()') x $dim); # Initialize primary node $node_primary = get_new_node('primary'); $node_primary->init(allows_streaming => 1); if ($dim > 32) { # TODO use wal_keep_segments for Postgres < 13 $node_primary->append_conf('postgresql.conf', qq(wal_keep_size = 1GB)); } if ($dim > 1500) { $node_primary->append_conf('postgresql.conf', qq(maintenance_work_mem = 128MB)); } $node_primary->start; my $backup_name = 'my_backup'; # Take backup $node_primary->backup($backup_name); # Create streaming replica linking to primary $node_replica = get_new_node('replica'); $node_replica->init_from_backup($node_primary, $backup_name, has_streaming => 1); $node_replica->start; # Create hnsw index on primary $node_primary->safe_psql("postgres", "CREATE EXTENSION vector;"); $node_primary->safe_psql("postgres", "CREATE TABLE tst (i int4, v vector($dim));"); $node_primary->safe_psql("postgres", "INSERT INTO tst SELECT i % 10, ARRAY[$array_sql] FROM generate_series(1, 1000) i;" ); $node_primary->safe_psql("postgres", 
"CREATE INDEX ON tst USING hnsw (v vector_l2_ops);"); # Test that queries give same result test_index_replay('initial'); # Run 10 cycles of table modification. Run test queries after each modification. for my $i (1 .. 10) { $node_primary->safe_psql("postgres", "DELETE FROM tst WHERE i = $i;"); test_index_replay("delete $i"); $node_primary->safe_psql("postgres", "VACUUM tst;"); test_index_replay("vacuum $i"); my ($start, $end) = (1001 + ($i - 1) * 100, 1000 + $i * 100); $node_primary->safe_psql("postgres", "INSERT INTO tst SELECT i % 10, ARRAY[$array_sql] FROM generate_series($start, $end) i;" ); test_index_replay("insert $i"); } done_testing(); pgvector-0.6.0/test/t/011_hnsw_vacuum.pl000066400000000000000000000027421455577216400200720ustar00rootroot00000000000000use strict; use warnings; use PostgresNode; use TestLib; use Test::More; my $dim = 3; my @r = (); for (1 .. $dim) { my $v = int(rand(1000)) + 1; push(@r, "i % $v"); } my $array_sql = join(", ", @r); # Initialize node my $node = get_new_node('node'); $node->init; $node->start; # Create table and index $node->safe_psql("postgres", "CREATE EXTENSION vector;"); $node->safe_psql("postgres", "CREATE TABLE tst (i int4, v vector($dim));"); $node->safe_psql("postgres", "INSERT INTO tst SELECT i, ARRAY[$array_sql] FROM generate_series(1, 10000) i;" ); $node->safe_psql("postgres", "CREATE INDEX ON tst USING hnsw (v vector_l2_ops);"); # Get size my $size = $node->safe_psql("postgres", "SELECT pg_total_relation_size('tst_v_idx');"); # Delete all, vacuum, and insert same data $node->safe_psql("postgres", "DELETE FROM tst;"); $node->safe_psql("postgres", "VACUUM tst;"); $node->safe_psql("postgres", "INSERT INTO tst SELECT i, ARRAY[$array_sql] FROM generate_series(1, 10000) i;" ); # Check size # May increase some due to different levels my $new_size = $node->safe_psql("postgres", "SELECT pg_total_relation_size('tst_v_idx');"); cmp_ok($new_size, "<=", $size * 1.02, "size does not increase too much"); # Delete all but one 
$node->safe_psql("postgres", "DELETE FROM tst WHERE i != 123;"); $node->safe_psql("postgres", "VACUUM tst;"); my $res = $node->safe_psql("postgres", qq( SET enable_seqscan = off; SELECT i FROM tst ORDER BY v <-> '[0,0,0]' LIMIT 10; )); is($res, 123); done_testing(); pgvector-0.6.0/test/t/012_hnsw_build_recall.pl000066400000000000000000000060741455577216400212160ustar00rootroot00000000000000use strict; use warnings; use PostgresNode; use TestLib; use Test::More; my $node; my @queries = (); my @expected; my $limit = 20; sub test_recall { my ($min, $operator) = @_; my $correct = 0; my $total = 0; my $explain = $node->safe_psql("postgres", qq( SET enable_seqscan = off; EXPLAIN ANALYZE SELECT i FROM tst ORDER BY v $operator '$queries[0]' LIMIT $limit; )); like($explain, qr/Index Scan/); for my $i (0 .. $#queries) { my $actual = $node->safe_psql("postgres", qq( SET enable_seqscan = off; SELECT i FROM tst ORDER BY v $operator '$queries[$i]' LIMIT $limit; )); my @actual_ids = split("\n", $actual); my %actual_set = map { $_ => 1 } @actual_ids; my @expected_ids = split("\n", $expected[$i]); foreach (@expected_ids) { if (exists($actual_set{$_})) { $correct++; } $total++; } } cmp_ok($correct / $total, ">=", $min, $operator); } # Initialize node $node = get_new_node('node'); $node->init; $node->start; # Create table $node->safe_psql("postgres", "CREATE EXTENSION vector;"); $node->safe_psql("postgres", "CREATE TABLE tst (i int4, v vector(3));"); $node->safe_psql("postgres", "INSERT INTO tst SELECT i, ARRAY[random(), random(), random()] FROM generate_series(1, 10000) i;" ); # Generate queries for (1 .. 20) { my $r1 = rand(); my $r2 = rand(); my $r3 = rand(); push(@queries, "[$r1,$r2,$r3]"); } # Check each index type my @operators = ("<->", "<#>", "<=>"); my @opclasses = ("vector_l2_ops", "vector_ip_ops", "vector_cosine_ops"); for my $i (0 .. 
$#operators) { my $operator = $operators[$i]; my $opclass = $opclasses[$i]; # Get exact results @expected = (); foreach (@queries) { my $res = $node->safe_psql("postgres", "SELECT i FROM tst ORDER BY v $operator '$_' LIMIT $limit;"); push(@expected, $res); } # Build index serially $node->safe_psql("postgres", qq( SET max_parallel_maintenance_workers = 0; CREATE INDEX idx ON tst USING hnsw (v $opclass); )); # Test approximate results my $min = $operator eq "<#>" ? 0.80 : 0.99; test_recall($min, $operator); $node->safe_psql("postgres", "DROP INDEX idx;"); # Build index in parallel in memory my ($ret, $stdout, $stderr) = $node->psql("postgres", qq( SET client_min_messages = DEBUG; SET min_parallel_table_scan_size = 1; CREATE INDEX idx ON tst USING hnsw (v $opclass); )); is($ret, 0, $stderr); like($stderr, qr/using \d+ parallel workers/); # Test approximate results test_recall($min, $operator); $node->safe_psql("postgres", "DROP INDEX idx;"); # Build index in parallel on disk # Set parallel_workers on table to use workers with low maintenance_work_mem ($ret, $stdout, $stderr) = $node->psql("postgres", qq( ALTER TABLE tst SET (parallel_workers = 2); SET client_min_messages = DEBUG; SET maintenance_work_mem = '4MB'; CREATE INDEX idx ON tst USING hnsw (v $opclass); ALTER TABLE tst RESET (parallel_workers); )); is($ret, 0, $stderr); like($stderr, qr/using \d+ parallel workers/); like($stderr, qr/hnsw graph no longer fits into maintenance_work_mem/); $node->safe_psql("postgres", "DROP INDEX idx;"); } done_testing(); pgvector-0.6.0/test/t/013_hnsw_insert_recall.pl000066400000000000000000000043631455577216400214230ustar00rootroot00000000000000use strict; use warnings; use PostgresNode; use TestLib; use Test::More; my $node; my @queries = (); my @expected; my $limit = 20; sub test_recall { my ($min, $operator) = @_; my $correct = 0; my $total = 0; my $explain = $node->safe_psql("postgres", qq( SET enable_seqscan = off; EXPLAIN ANALYZE SELECT i FROM tst ORDER BY v $operator 
'$queries[0]' LIMIT $limit; ));
	# $explain must mention an index scan
	like($explain, qr/Index Scan/);

	# Run every query with sequential scans disabled and count how many of the
	# ids stored in $expected[$i] also appear among the returned ids.
	for my $i (0 .. $#queries)
	{
		my $actual = $node->safe_psql("postgres", qq( SET enable_seqscan = off; SELECT i FROM tst ORDER BY v $operator '$queries[$i]' LIMIT $limit; ));
		my @actual_ids = split("\n", $actual);
		# Set of returned ids, used for membership checks below
		my %actual_set = map { $_ => 1 } @actual_ids;

		my @expected_ids = split("\n", $expected[$i]);
		foreach (@expected_ids)
		{
			if (exists($actual_set{$_}))
			{
				$correct++;
			}
			$total++;
		}
	}

	# Fraction of expected ids that were found must be at least $min;
	# the operator string doubles as the test name.
	cmp_ok($correct / $total, ">=", $min, $operator);
}

# Initialize node
$node = get_new_node('node');
$node->init;
$node->start;

# Create table
$node->safe_psql("postgres", "CREATE EXTENSION vector;");
$node->safe_psql("postgres", "CREATE TABLE tst (i serial, v vector(3));");

# Generate queries
# (20 three-component vectors built from rand(), formatted as "[a,b,c]")
for (1 .. 20)
{
	my $r1 = rand();
	my $r2 = rand();
	my $r3 = rand();
	push(@queries, "[$r1,$r2,$r3]");
}

# Check each index type
# @operators and @opclasses are parallel arrays: entry $i of one is used
# together with entry $i of the other.
my @operators = ("<->", "<#>", "<=>");
my @opclasses = ("vector_l2_ops", "vector_ip_ops", "vector_cosine_ops");

for my $i (0 .. $#operators)
{
	my $operator = $operators[$i];
	my $opclass = $opclasses[$i];

	# Add index
	$node->safe_psql("postgres", "CREATE INDEX idx ON tst USING hnsw (v $opclass);");

	# Use concurrent inserts
	# NOTE(review): the positional arguments look like pgbench options,
	# expected exit status, stdout patterns, stderr patterns, test name and
	# script files - confirm against PostgresNode::pgbench.
	$node->pgbench(
		"--no-vacuum --client=10 --transactions=1000",
		0,
		[qr{actually processed}],
		[qr{^$}],
		"concurrent INSERTs",
		{
			"013_hnsw_insert_recall_$opclass" => "INSERT INTO tst (v) VALUES (ARRAY[random(), random(), random()]);"
		}
	);

	# Get exact results
	# (index scans are disabled so these results serve as the reference set)
	@expected = ();
	foreach (@queries)
	{
		my $res = $node->safe_psql("postgres", qq( SET enable_indexscan = off; SELECT i FROM tst ORDER BY v $operator '$_' LIMIT $limit; ));
		push(@expected, $res);
	}

	# The "<#>" operator gets a lower recall threshold than the other two
	my $min = $operator eq "<#>" ? 0.80 : 0.99;
	test_recall($min, $operator);

	# Start the next operator from an empty table without an index
	$node->safe_psql("postgres", "DROP INDEX idx;");
	$node->safe_psql("postgres", "TRUNCATE tst;");
}

done_testing();
pgvector-0.6.0/test/t/014_hnsw_inserts.pl000066400000000000000000000035551455577216400202670ustar00rootroot00000000000000use strict;
use warnings;
use PostgresNode;
use TestLib;
use Test::More;

# Ensures elements and neighbors on both same and different pages
my $dim = 1900;
# Comma-separated list of $dim "random()" calls, spliced into ARRAY[...] below
my $array_sql = join(",", ('random()') x $dim);

# Initialize node
my $node = get_new_node('node');
$node->init;
$node->start;

# Create table and index
$node->safe_psql("postgres", "CREATE EXTENSION vector;");
$node->safe_psql("postgres", "CREATE TABLE tst (v vector($dim));");
$node->safe_psql("postgres", "CREATE INDEX ON tst USING hnsw (v vector_l2_ops);");

# Returns the idx_scan value reported by pg_stat_user_indexes for the index
# tst_v_idx. The safe_psql call is the last statement, so its result is the
# sub's return value. Takes no arguments.
sub idx_scan
{
	# Stats do not update instantaneously
	# https://www.postgresql.org/docs/current/monitoring-stats.html#MONITORING-STATS-VIEWS
	sleep(1);
	$node->safe_psql("postgres", "SELECT idx_scan FROM pg_stat_user_indexes WHERE indexrelid = 'tst_v_idx'::regclass;");
}

# 20 rounds: pgbench runs a single-row INSERT with --client=10 and
# --transactions=1, then an ordered query with sequential scans disabled is
# expected to return exactly 10 rows. The table is emptied after each round.
for my $i (1 .. 20)
{
	$node->pgbench(
		"--no-vacuum --client=10 --transactions=1",
		0,
		[qr{actually processed}],
		[qr{^$}],
		"concurrent INSERTs",
		{
			"014_hnsw_inserts_$i" => "INSERT INTO tst VALUES (ARRAY[$array_sql]);"
		}
	);

	my $count = $node->safe_psql("postgres", qq( SET enable_seqscan = off; SELECT COUNT(*) FROM (SELECT v FROM tst ORDER BY v <-> (SELECT v FROM tst LIMIT 1)) t; ));
	is($count, 10);

	$node->safe_psql("postgres", "TRUNCATE tst;");
}

# Larger run: --client=20, --transactions=5, 10 rows per INSERT.
# NOTE(review): that is 1000 rows if --transactions counts per client -
# confirm against the pgbench documentation.
$node->pgbench(
	"--no-vacuum --client=20 --transactions=5",
	0,
	[qr{actually processed}],
	[qr{^$}],
	"concurrent INSERTs",
	{
		"014_hnsw_inserts" => "INSERT INTO tst SELECT ARRAY[$array_sql] FROM generate_series(1, 10) i;"
	}
);

my $count = $node->safe_psql("postgres", qq( SET enable_seqscan = off; SET hnsw.ef_search = 1000; SELECT COUNT(*) FROM (SELECT v FROM tst ORDER BY v <-> (SELECT v FROM tst LIMIT 1)) t; ));
# Elements may lose all incoming connections with the HNSW algorithm
# Vacuuming can fix this if one of the element's neighbors is deleted
cmp_ok($count, ">=", 997);

# Presumably one index scan per loop round (20) plus the query above - verify
is(idx_scan(), 21);

done_testing();
pgvector-0.6.0/test/t/015_hnsw_duplicates.pl000066400000000000000000000022061455577216400207260ustar00rootroot00000000000000use strict;
use warnings;
use PostgresNode;
use TestLib;
use Test::More;

# Initialize node
my $node = get_new_node('node');
$node->init;
$node->start;

# Create table
$node->safe_psql("postgres", "CREATE EXTENSION vector;");
$node->safe_psql("postgres", "CREATE TABLE tst (v vector(3));");

# Inserts 20 copies of the vector [1,1,1], one INSERT statement per row.
sub insert_vectors
{
	for my $i (1 .. 20)
	{
		$node->safe_psql("postgres", "INSERT INTO tst VALUES ('[1,1,1]');");
	}
}

# Counts the rows returned by an ORDER BY v <-> '[1,1,1]' query with
# sequential scans disabled and hnsw.ef_search = 1, and expects exactly 10.
# NOTE(review): 10 rather than 20 presumably reflects a per-element limit on
# stored duplicates - confirm against the HNSW implementation.
sub test_duplicates
{
	my $res = $node->safe_psql("postgres", qq( SET enable_seqscan = off; SET hnsw.ef_search = 1; SELECT COUNT(*) FROM (SELECT * FROM tst ORDER BY v <-> '[1,1,1]') t; ));
	is($res, 10);
}

# Test duplicates with build
# (rows are inserted first, then the index is created over them)
insert_vectors();
$node->safe_psql("postgres", "CREATE INDEX idx ON tst USING hnsw (v vector_l2_ops);");
test_duplicates();

# Reset
$node->safe_psql("postgres", "TRUNCATE tst;");

# Test duplicates with inserts
# (the index already exists, so rows go through the insert path)
insert_vectors();
test_duplicates();

# Test fallback path for inserts
# Only the pgbench exit status and output patterns are checked here; no
# further assertion follows.
$node->pgbench(
	"--no-vacuum --client=5 --transactions=100",
	0,
	[qr{actually processed}],
	[qr{^$}],
	"concurrent INSERTs",
	{
		"015_hnsw_duplicates" => "INSERT INTO tst VALUES ('[1,1,1]');"
	}
);

done_testing();
pgvector-0.6.0/test/t/016_hnsw_vacuum_recall.pl000066400000000000000000000042301455577216400214130ustar00rootroot00000000000000use strict;
use warnings;
use PostgresNode;
use TestLib;
use Test::More;

# File-level state shared between the script body and test_recall
my $node;
my @queries = ();
my @expected;
my $limit = 20;

# Checks index recall for the queries in @queries against @expected.
#   $min       - minimum acceptable ratio of found ids to expected ids
#   $ef_search - value assigned to hnsw.ef_search before each query
#   $test_name - name passed to cmp_ok for the recall assertion
# Also asserts that the EXPLAIN ANALYZE output for the first query matches
# /Index Scan/.
sub test_recall
{
	my ($min, $ef_search, $test_name) = @_;
	my $correct = 0;
	my $total = 0;

	my $explain = $node->safe_psql("postgres", qq( SET enable_seqscan = off; SET hnsw.ef_search = $ef_search; EXPLAIN ANALYZE SELECT i FROM tst ORDER BY v <-> '$queries[0]' LIMIT $limit; ));
	like($explain, qr/Index Scan/);

	for my $i (0 .. $#queries)
	{
		my $actual = $node->safe_psql("postgres", qq( SET enable_seqscan = off; SET hnsw.ef_search = $ef_search; SELECT i FROM tst ORDER BY v <-> '$queries[$i]' LIMIT $limit; ));
		my @actual_ids = split("\n", $actual);
		# Set of returned ids, used for membership checks below
		my %actual_set = map { $_ => 1 } @actual_ids;

		my @expected_ids = split("\n", $expected[$i]);
		foreach (@expected_ids)
		{
			if (exists($actual_set{$_}))
			{
				$correct++;
			}
			$total++;
		}
	}

	cmp_ok($correct / $total, ">=", $min, $test_name);
}

# Initialize node
$node = get_new_node('node');
$node->init;
$node->start;

# Create table
# Autovacuum is disabled on the table; VACUUM is run explicitly further down.
$node->safe_psql("postgres", "CREATE EXTENSION vector;");
$node->safe_psql("postgres", "CREATE TABLE tst (i int4, v vector(3));");
$node->safe_psql("postgres", "ALTER TABLE tst SET (autovacuum_enabled = false);");
$node->safe_psql("postgres",
	"INSERT INTO tst SELECT i, ARRAY[random(), random(), random()] FROM generate_series(1, 10000) i;"
);

# Add index
$node->safe_psql("postgres", "CREATE INDEX ON tst USING hnsw (v vector_l2_ops) WITH (m = 4, ef_construction = 8);");

# Delete data
# (rows with i > 2500 out of 1..10000, i.e. 7500 of the 10000 rows)
$node->safe_psql("postgres", "DELETE FROM tst WHERE i > 2500;");

# Generate queries
for (1 .. 20)
{
	my $r1 = rand();
	my $r2 = rand();
	my $r3 = rand();
	push(@queries, "[$r1,$r2,$r3]");
}

# Get exact results
# (index scans are disabled so these results serve as the reference set)
@expected = ();
foreach (@queries)
{
	my $res = $node->safe_psql("postgres", qq( SET enable_indexscan = off; SELECT i FROM tst ORDER BY v <-> '$_' LIMIT $limit; ));
	push(@expected, $res);
}

# Before VACUUM: a low threshold with ef_search = $limit, a high one with
# ef_search = 100.
test_recall(0.20, $limit, "before vacuum");
test_recall(0.95, 100, "before vacuum");

# TODO Test concurrent inserts with vacuum
$node->safe_psql("postgres", "VACUUM tst;");

# After VACUUM the high threshold is required with ef_search = $limit
test_recall(0.95, $limit, "after vacuum");

done_testing();
pgvector-0.6.0/test/t/017_ivfflat_insert_recall.pl000066400000000000000000000050651455577216400221030ustar00rootroot00000000000000use strict;
use warnings;
use PostgresNode;
use TestLib;
use Test::More;

# File-level state shared between the script body and test_recall
my $node;
my @queries = ();
my @expected;
my $limit = 20;

# Checks index recall for the queries in @queries against @expected.
#   $probes   - value assigned to ivfflat.probes before each query
#   $min      - minimum acceptable ratio of found ids to expected ids
#   $operator - distance operator used in ORDER BY; also the test name
# Also asserts that the EXPLAIN ANALYZE output for the first query matches
# /Index Scan using idx on tst/.
sub test_recall
{
	my ($probes, $min, $operator) = @_;
	my $correct = 0;
	my $total = 0;

	my $explain = $node->safe_psql("postgres", qq( SET enable_seqscan = off; SET ivfflat.probes = $probes; EXPLAIN ANALYZE SELECT i FROM tst ORDER BY v $operator '$queries[0]' LIMIT $limit; ));
	like($explain, qr/Index Scan using idx on tst/);

	for my $i (0 .. $#queries)
	{
		my $actual = $node->safe_psql("postgres", qq( SET enable_seqscan = off; SET ivfflat.probes = $probes; SELECT i FROM tst ORDER BY v $operator '$queries[$i]' LIMIT $limit; ));
		my @actual_ids = split("\n", $actual);
		# Set of returned ids, used for membership checks below
		my %actual_set = map { $_ => 1 } @actual_ids;

		my @expected_ids = split("\n", $expected[$i]);
		foreach (@expected_ids)
		{
			if (exists($actual_set{$_}))
			{
				$correct++;
			}
			$total++;
		}
	}

	cmp_ok($correct / $total, ">=", $min, $operator);
}

# Initialize node
$node = get_new_node('node');
$node->init;
$node->start;

# Create table
$node->safe_psql("postgres", "CREATE EXTENSION vector;");
$node->safe_psql("postgres", "CREATE TABLE tst (i serial, v vector(3));");

# Generate queries
for (1 .. 20)
{
	my $r1 = rand();
	my $r2 = rand();
	my $r3 = rand();
	push(@queries, "[$r1,$r2,$r3]");
}

# Check each index type
# @operators and @opclasses are parallel arrays: entry $i of one is used
# together with entry $i of the other.
my @operators = ("<->", "<#>", "<=>");
my @opclasses = ("vector_l2_ops", "vector_ip_ops", "vector_cosine_ops");

for my $i (0 .. $#operators)
{
	my $operator = $operators[$i];
	my $opclass = $opclasses[$i];

	# Add index
	# (created while the table is empty; all rows arrive via the inserts below)
	$node->safe_psql("postgres", "CREATE INDEX idx ON tst USING ivfflat (v $opclass);");

	# Use concurrent inserts
	# Each transaction inserts 10 rows via generate_series(1, 10)
	$node->pgbench(
		"--no-vacuum --client=10 --transactions=1000",
		0,
		[qr{actually processed}],
		[qr{^$}],
		"concurrent INSERTs",
		{
			"017_ivfflat_insert_recall_$opclass" => "INSERT INTO tst (v) SELECT ARRAY[random(), random(), random()] FROM generate_series(1, 10) i;"
		}
	);

	# Get exact results
	# (index scans are disabled so these results serve as the reference set)
	@expected = ();
	foreach (@queries)
	{
		my $res = $node->safe_psql("postgres", qq( SET enable_indexscan = off; SELECT i FROM tst ORDER BY v $operator '$_' LIMIT $limit; ));
		push(@expected, $res);
	}

	# Test approximate results
	# The low-probe checks are skipped for "<#>"; only the probes = 100 check
	# below runs for that operator.
	if ($operator ne "<#>")
	{
		# TODO Fix test (uniform random vectors all have similar inner product)
		test_recall(1, 0.71, $operator);
		test_recall(10, 0.95, $operator);
	}

	# Account for equal distances
	test_recall(100, 0.9925, $operator);

	# Start the next operator from an empty table without an index
	$node->safe_psql("postgres", "DROP INDEX idx;");
	$node->safe_psql("postgres", "TRUNCATE tst;");
}

done_testing();
pgvector-0.6.0/test/t/018_hnsw_filtering.pl000066400000000000000000000072421455577216400205640ustar00rootroot00000000000000use strict;
use warnings;
use PostgresNode;
use TestLib;
use Test::More;

# $dim: vector dimensions; $nc: modulus for column c (c = i % $nc);
# $limit: LIMIT used by the queries in this file
my $dim = 3;
my $nc = 50;
my $limit = 20;
# Comma-separated list of $dim "random()" calls, spliced into ARRAY[...] below
my $array_sql = join(",", ('random()') x $dim);

# Initialize node
my $node = get_new_node('node');
$node->init;
$node->start;

# Create table and index
$node->safe_psql("postgres", "CREATE EXTENSION vector;");
$node->safe_psql("postgres", "CREATE TABLE tst (i int4, v vector($dim), c int4, t text);");
$node->safe_psql("postgres",
	"INSERT INTO tst SELECT i, ARRAY[$array_sql], i % $nc, 'test ' || i FROM generate_series(1, 10000) i;"
);
# Build the HNSW index named idx on column v, then run ANALYZE on the table
# before any plan is inspected.
$node->safe_psql("postgres", "CREATE INDEX idx ON tst USING hnsw (v vector_l2_ops);");
$node->safe_psql("postgres", "ANALYZE tst;");

# Every check below inspects EXPLAIN ANALYZE output only: it asserts which
# scan type (and which index name) appears in the plan, not the rows returned.
# Where a TODO precedes an assertion, the TODO text states a different outcome
# from the one asserted - presumably the assertion records current behavior.

# Generate query
my @r = ();
for (1 .. $dim)
{
	push(@r, rand());
}
my $query = "[" . join(",", @r) . "]";
# Random integer filter value derived from $nc (column c holds i % $nc)
my $c = int(rand() * $nc);

# Test attribute filtering
my $explain = $node->safe_psql("postgres", qq( EXPLAIN ANALYZE SELECT i FROM tst WHERE c = $c ORDER BY v <-> '$query' LIMIT $limit; ));
# TODO Do not use index
like($explain, qr/Index Scan using idx/);

# Test attribute filtering with few rows removed
$explain = $node->safe_psql("postgres", qq( EXPLAIN ANALYZE SELECT i FROM tst WHERE c != $c ORDER BY v <-> '$query' LIMIT $limit; ));
like($explain, qr/Index Scan using idx/);

# Test attribute filtering with few rows removed comparison
$explain = $node->safe_psql("postgres", qq( EXPLAIN ANALYZE SELECT i FROM tst WHERE c >= 1 ORDER BY v <-> '$query' LIMIT $limit; ));
like($explain, qr/Index Scan using idx/);

# Test attribute filtering with many rows removed comparison
$explain = $node->safe_psql("postgres", qq( EXPLAIN ANALYZE SELECT i FROM tst WHERE c < 1 ORDER BY v <-> '$query' LIMIT $limit; ));
# TODO Do not use index
like($explain, qr/Index Scan using idx/);

# Test attribute filtering with few rows removed like
# (every row's t value starts with 'test ', so this pattern matches all rows)
$explain = $node->safe_psql("postgres", qq( EXPLAIN ANALYZE SELECT i FROM tst WHERE t LIKE '%%test%%' ORDER BY v <-> '$query' LIMIT $limit; ));
like($explain, qr/Index Scan using idx/);

# Test attribute filtering with many rows removed like
# (no inserted t value contains 'other')
$explain = $node->safe_psql("postgres", qq( EXPLAIN ANALYZE SELECT i FROM tst WHERE t LIKE '%%other%%' ORDER BY v <-> '$query' LIMIT $limit; ));
like($explain, qr/Seq Scan/);

# Test distance filtering
$explain = $node->safe_psql("postgres", qq( EXPLAIN ANALYZE SELECT i FROM tst WHERE v <-> '$query' < 1 ORDER BY v <-> '$query' LIMIT $limit; ));
like($explain, qr/Index Scan using idx/);

# Test distance filtering greater than distance
$explain = $node->safe_psql("postgres", qq( EXPLAIN ANALYZE SELECT i FROM tst WHERE v <-> '$query' > 1 ORDER BY v <-> '$query' LIMIT $limit; ));
# TODO Do not use index
like($explain, qr/Index Scan using idx/);

# Test distance filtering without order
$explain = $node->safe_psql("postgres", qq( EXPLAIN ANALYZE SELECT i FROM tst WHERE v <-> '$query' < 1; ));
like($explain, qr/Seq Scan/);

# Test distance filtering without limit
$explain = $node->safe_psql("postgres", qq( EXPLAIN ANALYZE SELECT i FROM tst WHERE v <-> '$query' < 1 ORDER BY v <-> '$query'; ));
like($explain, qr/Seq Scan/);

# Test attribute index
# (a plain index on c is added; the plan is still expected to use idx)
$node->safe_psql("postgres", "CREATE INDEX attribute_idx ON tst (c);");
$explain = $node->safe_psql("postgres", qq( EXPLAIN ANALYZE SELECT i FROM tst WHERE c = $c ORDER BY v <-> '$query' LIMIT $limit; ));
# TODO Use attribute index
like($explain, qr/Index Scan using idx/);

# Test partial index
# (an HNSW index restricted to c = $c; the plan is expected to switch to it)
$node->safe_psql("postgres", "CREATE INDEX partial_idx ON tst USING hnsw (v vector_l2_ops) WHERE (c = $c);");
$explain = $node->safe_psql("postgres", qq( EXPLAIN ANALYZE SELECT i FROM tst WHERE c = $c ORDER BY v <-> '$query' LIMIT $limit; ));
like($explain, qr/Index Scan using partial_idx/);

done_testing();
pgvector-0.6.0/test/t/019_ivfflat_filtering.pl000066400000000000000000000074011455577216400212360ustar00rootroot00000000000000use strict;
use warnings;
use PostgresNode;
use TestLib;
use Test::More;

# $dim: vector dimensions; $nc: modulus for column c (c = i % $nc);
# $limit: LIMIT used by the queries in this file
my $dim = 3;
my $nc = 50;
my $limit = 20;
# Comma-separated list of $dim "random()" calls, spliced into ARRAY[...] below
my $array_sql = join(",", ('random()') x $dim);

# Initialize node
my $node = get_new_node('node');
$node->init;
$node->start;

# Create table and index
# (the IVFFlat index is created after the rows are inserted, then ANALYZE runs)
$node->safe_psql("postgres", "CREATE EXTENSION vector;");
$node->safe_psql("postgres", "CREATE TABLE tst (i int4, v vector($dim), c int4, t text);");
$node->safe_psql("postgres",
	"INSERT INTO tst SELECT i, ARRAY[$array_sql], i % $nc, 'test ' || i FROM generate_series(1, 10000) i;"
);
$node->safe_psql("postgres", "CREATE INDEX idx ON tst USING ivfflat (v vector_l2_ops) WITH (lists = 100);");
$node->safe_psql("postgres", "ANALYZE tst;");

# Every check below inspects EXPLAIN ANALYZE output only: it asserts which
# scan type (and which index name) appears in the plan, not the rows returned.
# Where a TODO precedes an assertion, the TODO text states a different outcome
# from the one asserted - presumably the assertion records current behavior.

# Generate query
my @r = ();
for (1 .. $dim)
{
	push(@r, rand());
}
my $query = "[" . join(",", @r) . "]";
# Random integer filter value derived from $nc (column c holds i % $nc)
my $c = int(rand() * $nc);

# Test attribute filtering
my $explain = $node->safe_psql("postgres", qq( EXPLAIN ANALYZE SELECT i FROM tst WHERE c = $c ORDER BY v <-> '$query' LIMIT $limit; ));
# TODO Do not use index
like($explain, qr/Index Scan using idx/);

# Test attribute filtering with few rows removed
$explain = $node->safe_psql("postgres", qq( EXPLAIN ANALYZE SELECT i FROM tst WHERE c != $c ORDER BY v <-> '$query' LIMIT $limit; ));
like($explain, qr/Index Scan using idx/);

# Test attribute filtering with few rows removed comparison
$explain = $node->safe_psql("postgres", qq( EXPLAIN ANALYZE SELECT i FROM tst WHERE c >= 1 ORDER BY v <-> '$query' LIMIT $limit; ));
like($explain, qr/Index Scan using idx/);

# Test attribute filtering with many rows removed comparison
$explain = $node->safe_psql("postgres", qq( EXPLAIN ANALYZE SELECT i FROM tst WHERE c < 1 ORDER BY v <-> '$query' LIMIT $limit; ));
# TODO Do not use index
like($explain, qr/Index Scan using idx/);

# Test attribute filtering with few rows removed like
# (every row's t value starts with 'test ', so this pattern matches all rows)
$explain = $node->safe_psql("postgres", qq( EXPLAIN ANALYZE SELECT i FROM tst WHERE t LIKE '%%test%%' ORDER BY v <-> '$query' LIMIT $limit; ));
like($explain, qr/Index Scan using idx/);

# Test attribute filtering with many rows removed like
# (no inserted t value contains 'other')
$explain = $node->safe_psql("postgres", qq( EXPLAIN ANALYZE SELECT i FROM tst WHERE t LIKE '%%other%%' ORDER BY v <-> '$query' LIMIT $limit; ));
like($explain, qr/Seq Scan/);

# Test distance filtering
$explain = $node->safe_psql("postgres", qq( EXPLAIN ANALYZE SELECT i FROM tst WHERE v <-> '$query' < 1 ORDER BY v <-> '$query' LIMIT $limit; ));
like($explain, qr/Index Scan using idx/);

# Test distance filtering greater than distance
$explain = $node->safe_psql("postgres", qq( EXPLAIN ANALYZE SELECT i FROM tst WHERE v <-> '$query' > 1 ORDER BY v <-> '$query' LIMIT $limit; ));
# TODO Do not use index
like($explain, qr/Index Scan using idx/);

# Test distance filtering without order
$explain = $node->safe_psql("postgres", qq( EXPLAIN ANALYZE SELECT i FROM tst WHERE v <-> '$query' < 1; ));
like($explain, qr/Seq Scan/);

# Test distance filtering without limit
$explain = $node->safe_psql("postgres", qq( EXPLAIN ANALYZE SELECT i FROM tst WHERE v <-> '$query' < 1 ORDER BY v <-> '$query'; ));
# TODO Do not use index
like($explain, qr/Index Scan using idx/);

# Test attribute index
# (a plain index on c is added; the plan is still expected to use idx)
$node->safe_psql("postgres", "CREATE INDEX attribute_idx ON tst (c);");
$explain = $node->safe_psql("postgres", qq( EXPLAIN ANALYZE SELECT i FROM tst WHERE c = $c ORDER BY v <-> '$query' LIMIT $limit; ));
# TODO Use attribute index
like($explain, qr/Index Scan using idx/);

# Test partial index
# (an IVFFlat index restricted to c = $c is added; the assertion still expects
# the full index idx to be chosen)
$node->safe_psql("postgres", "CREATE INDEX partial_idx ON tst USING ivfflat (v vector_l2_ops) WITH (lists = 5) WHERE (c = $c);");
$explain = $node->safe_psql("postgres", qq( EXPLAIN ANALYZE SELECT i FROM tst WHERE c = $c ORDER BY v <-> '$query' LIMIT $limit; ));
# TODO Use partial index
like($explain, qr/Index Scan using idx/);

done_testing();
pgvector-0.6.0/vector.control000066400000000000000000000002211455577216400162650ustar00rootroot00000000000000comment = 'vector data type and ivfflat and hnsw access methods'
default_version = '0.6.0'
module_pathname = '$libdir/vector'
relocatable = true