pax_global_header00006660000000000000000000000064143772577320014533gustar00rootroot0000000000000052 comment=13468eb321a0bacbb0cbdfddb1ea1d62e68e0652 pgx-4.18.1/000077500000000000000000000000001437725773200124245ustar00rootroot00000000000000pgx-4.18.1/.github/000077500000000000000000000000001437725773200137645ustar00rootroot00000000000000pgx-4.18.1/.github/ISSUE_TEMPLATE/000077500000000000000000000000001437725773200161475ustar00rootroot00000000000000pgx-4.18.1/.github/ISSUE_TEMPLATE/bug_report.md000066400000000000000000000021741437725773200206450ustar00rootroot00000000000000--- name: Bug report about: Create a report to help us improve title: '' labels: bug assignees: '' --- **Describe the bug** A clear and concise description of what the bug is. **To Reproduce** Steps to reproduce the behavior: If possible, please provide runnable example such as: ```go package main import ( "context" "log" "os" "github.com/jackc/pgx/v4" ) func main() { conn, err := pgx.Connect(context.Background(), os.Getenv("DATABASE_URL")) if err != nil { log.Fatal(err) } defer conn.Close(context.Background()) // Your code here... } ``` **Expected behavior** A clear and concise description of what you expected to happen. **Actual behavior** A clear and concise description of what actually happened. **Version** - Go: `$ go version` -> [e.g. go version go1.18.3 darwin/amd64] - PostgreSQL: `$ psql --no-psqlrc --tuples-only -c 'select version()'` -> [e.g. PostgreSQL 14.4 on x86_64-apple-darwin21.5.0, compiled by Apple clang version 13.1.6 (clang-1316.0.21.2.5), 64-bit] - pgx: `$ grep 'github.com/jackc/pgx/v[0-9]' go.mod` -> [e.g. v4.16.1] **Additional context** Add any other context about the problem here. pgx-4.18.1/.github/ISSUE_TEMPLATE/feature_request.md000066400000000000000000000011231437725773200216710ustar00rootroot00000000000000--- name: Feature request about: Suggest an idea for this project title: '' labels: '' assignees: '' --- **Is your feature request related to a problem? 
Please describe.** A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] **Describe the solution you'd like** A clear and concise description of what you want to happen. **Describe alternatives you've considered** A clear and concise description of any alternative solutions or features you've considered. **Additional context** Add any other context or screenshots about the feature request here. pgx-4.18.1/.github/ISSUE_TEMPLATE/other-issues.md000066400000000000000000000003501437725773200211210ustar00rootroot00000000000000--- name: Other issues about: Any issue that is not a bug or a feature request title: '' labels: '' assignees: '' --- Please describe the issue in detail. If this is a question about how to use pgx please use discussions instead. pgx-4.18.1/.github/workflows/000077500000000000000000000000001437725773200160215ustar00rootroot00000000000000pgx-4.18.1/.github/workflows/ci.yml000066400000000000000000000026201437725773200171370ustar00rootroot00000000000000name: CI on: push: branches: [ master ] pull_request: branches: [ master ] jobs: test: name: Test runs-on: ubuntu-20.04 strategy: matrix: go-version: [1.16, 1.17] pg-version: [10, 11, 12, 13, 14, cockroachdb] include: - pg-version: 10 pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test - pg-version: 11 pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test - pg-version: 12 pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test - pg-version: 13 pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test - pg-version: 14 pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test - pg-version: cockroachdb pgx-test-database: "postgresql://root@127.0.0.1:26257/pgx_test?sslmode=disable&experimental_enable_temp_tables=on" steps: - name: Set up Go 1.x uses: actions/setup-go@v2 with: go-version: ${{ matrix.go-version }} - name: Check out code into the Go module directory uses: actions/checkout@v2 - name: Setup database server for 
testing run: ci/setup_test.bash env: PGVERSION: ${{ matrix.pg-version }} - name: Test run: go test -race ./... env: PGX_TEST_DATABASE: ${{ matrix.pgx-test-database }} pgx-4.18.1/.gitignore000066400000000000000000000004041437725773200144120ustar00rootroot00000000000000# Compiled Object files, Static and Dynamic libs (Shared Objects) *.o *.a *.so # Folders _obj _test # Architecture specific extensions/prefixes *.[568vq] [568vq].out *.cgo1.go *.cgo2.c _cgo_defun.c _cgo_gotypes.go _cgo_export.* _testmain.go *.exe .envrc pgx-4.18.1/CHANGELOG.md000066400000000000000000000247611437725773200142470ustar00rootroot00000000000000# 4.18.1 (February 27, 2023) * Fix: Support pgx v4 and v5 stdlib in same program (Tomáš Procházka) # 4.18.0 (February 11, 2023) * Upgrade pgconn to v1.14.0 * Upgrade pgproto3 to v2.3.2 * Upgrade pgtype to v1.14.0 * Fix query sanitizer when query text contains Unicode replacement character * Fix context with value in BeforeConnect (David Harju) * Support pgx v4 and v5 stdlib in same program (Vitalii Solodilov) # 4.17.2 (September 3, 2022) * Fix panic when logging batch error (Tom Möller) # 4.17.1 (August 27, 2022) * Upgrade puddle to v1.3.0 - fixes context failing to cancel Acquire when acquire is creating resource which was introduced in v4.17.0 (James Hartig) * Fix atomic alignment on 32-bit platforms # 4.17.0 (August 6, 2022) * Upgrade pgconn to v1.13.0 * Upgrade pgproto3 to v2.3.1 * Upgrade pgtype to v1.12.0 * Allow background pool connections to continue even if cause is canceled (James Hartig) * Add LoggerFunc (Gabor Szabad) * pgxpool: health check should avoid going below minConns (James Hartig) * Add pgxpool.Conn.Hijack() * Logging improvements (Stepan Rabotkin) # 4.16.1 (May 7, 2022) * Upgrade pgconn to v1.12.1 * Fix explicitly prepared statements with describe statement cache mode # 4.16.0 (April 21, 2022) * Upgrade pgconn to v1.12.0 * Upgrade pgproto3 to v2.3.0 * Upgrade pgtype to v1.11.0 * Fix: Do not panic when context cancelled while 
getting statement from cache. * Fix: Less memory pinning from old Rows. * Fix: Support '\r' line ending when sanitizing SQL comment. * Add pluggable GSSAPI support (Oliver Tan) # 4.15.0 (February 7, 2022) * Upgrade to pgconn v1.11.0 * Upgrade to pgtype v1.10.0 * Upgrade puddle to v1.2.1 * Make BatchResults.Close safe to be called multiple times # 4.14.1 (November 28, 2021) * Upgrade pgtype to v1.9.1 (fixes unintentional change to timestamp binary decoding) * Start pgxpool background health check after initial connections # 4.14.0 (November 20, 2021) * Upgrade pgconn to v1.10.1 * Upgrade pgproto3 to v2.2.0 * Upgrade pgtype to v1.9.0 * Upgrade puddle to v1.2.0 * Add QueryFunc to BatchResults * Add context options to zerologadapter (Thomas Frössman) * Add zerologadapter.NewContextLogger (urso) * Eager initialize minpoolsize on connect (Daniel) * Unpin memory used by large queries immediately after use # 4.13.0 (July 24, 2021) * Trimmed pseudo-dependencies in Go modules from other packages tests * Upgrade pgconn -- context cancellation no longer will return a net.Error * Support time durations for simple protocol (Michael Darr) # 4.12.0 (July 10, 2021) * ResetSession hook is called before a connection is reused from pool for another query (Dmytro Haranzha) * stdlib: Add RandomizeHostOrderFunc (dkinder) * stdlib: add OptionBeforeConnect (dkinder) * stdlib: Do not reuse ConnConfig strings (Andrew Kimball) * stdlib: implement Conn.ResetSession (Jonathan Amsterdam) * Upgrade pgconn to v1.9.0 * Upgrade pgtype to v1.8.0 # 4.11.0 (March 25, 2021) * Add BeforeConnect callback to pgxpool.Config (Robert Froehlich) * Add Ping method to pgxpool.Conn (davidsbond) * Added a kitlog level log adapter (Fabrice Aneche) * Make ScanArgError public to allow identification of offending column (Pau Sanchez) * Add *pgxpool.AcquireFunc * Add BeginFunc and BeginTxFunc * Add prefer_simple_protocol to connection string * Add logging on CopyFrom (Patrick Hemmer) * Add comment support when 
sanitizing SQL queries (Rusakow Andrew) * Do not panic on double close of pgxpool.Pool (Matt Schultz) * Avoid panic on SendBatch on closed Tx (Matt Schultz) * Update pgconn to v1.8.1 * Update pgtype to v1.7.0 # 4.10.1 (December 19, 2020) * Fix panic on Query error with nil stmtcache. # 4.10.0 (December 3, 2020) * Add CopyFromSlice to simplify CopyFrom usage (Egon Elbre) * Remove broken prepared statements from stmtcache (Ethan Pailes) * stdlib: consider any Ping error as fatal * Update puddle to v1.1.3 - this fixes an issue where concurrent Acquires can hang when a connection cannot be established * Update pgtype to v1.6.2 # 4.9.2 (November 3, 2020) The underlying library updates fix an issue where appending to a scanned slice could corrupt other data. * Update pgconn to v1.7.2 * Update pgproto3 to v2.0.6 # 4.9.1 (October 31, 2020) * Update pgconn to v1.7.1 * Update pgtype to v1.6.1 * Fix SendBatch of all prepared statements with statement cache disabled # 4.9.0 (September 26, 2020) * pgxpool now waits for connection cleanup to finish before making room in pool for another connection. This prevents temporarily exceeding max pool size. * Fix when scanning a column to nil to skip it on the first row but scanning it to a real value on a subsequent row. * Fix prefer simple protocol with prepared statements. (Jinzhu) * Fix FieldDescriptions not being available on Rows before calling Next the first time. * Various minor fixes in updated versions of pgconn, pgtype, and puddle. 
# 4.8.1 (July 29, 2020) * Update pgconn to v1.6.4 * Fix deadlock on error after CommandComplete but before ReadyForQuery * Fix panic on parsing DSN with trailing '=' # 4.8.0 (July 22, 2020) * All argument types supported by native pgx should now also work through database/sql * Update pgconn to v1.6.3 * Update pgtype to v1.4.2 # 4.7.2 (July 14, 2020) * Improve performance of Columns() (zikaeroh) * Fix fatal Commit() failure not being considered fatal * Update pgconn to v1.6.2 * Update pgtype to v1.4.1 # 4.7.1 (June 29, 2020) * Fix stdlib decoding error with certain order and combination of fields # 4.7.0 (June 27, 2020) * Update pgtype to v1.4.0 * Update pgconn to v1.6.1 * Update puddle to v1.1.1 * Fix context propagation with Tx commit and Rollback (georgysavva) * Add lazy connect option to pgxpool (georgysavva) * Fix connection leak if pgxpool.BeginTx() fail (Jean-Baptiste Bronisz) * Add native Go slice support for strings and numbers to simple protocol * stdlib add default timeouts for Conn.Close() and Stmt.Close() (georgysavva) * Assorted performance improvements especially with large result sets * Fix close pool on not lazy connect failure (Yegor Myskin) * Add Config copy (georgysavva) * Support SendBatch with Simple Protocol (Jordan Lewis) * Better error logging on rows close (Igor V. 
Kozinov) * Expose stdlib.Conn.Conn() to enable database/sql.Conn.Raw() * Improve unknown type support for database/sql * Fix transaction commit failure closing connection # 4.6.0 (March 30, 2020) * stdlib: Bail early if preloading rows.Next() results in rows.Err() (Bas van Beek) * Sanitize time to microsecond accuracy (Andrew Nicoll) * Update pgtype to v1.3.0 * Update pgconn to v1.5.0 * Update golang.org/x/crypto for security fix * Implement "verify-ca" SSL mode # 4.5.0 (March 7, 2020) * Update to pgconn v1.4.0 * Fixes QueryRow with empty SQL * Adds PostgreSQL service file support * Add Len() to *pgx.Batch (WGH) * Better logging for individual batch items (Ben Bader) # 4.4.1 (February 14, 2020) * Update pgconn to v1.3.2 - better default read buffer size * Fix race in CopyFrom # 4.4.0 (February 5, 2020) * Update puddle to v1.1.0 - fixes possible deadlock when acquire is cancelled * Update pgconn to v1.3.1 - fixes CopyFrom deadlock when multiple NoticeResponse received during copy * Update pgtype to v1.2.0 * Add MaxConnIdleTime to pgxpool (Patrick Ellul) * Add MinConns to pgxpool (Patrick Ellul) * Fix: stdlib.ReleaseConn closes connections left in invalid state # 4.3.0 (January 23, 2020) * Fix Rows.Values panic when unable to decode * Add Rows.Values support for unknown types * Add DriverContext support for stdlib (Alex Gaynor) * Update pgproto3 to v2.0.1 to never return an io.EOF as it would be misinterpreted by database/sql. Instead return io.UnexpectedEOF. # 4.2.1 (January 13, 2020) * Update pgconn to v1.2.1 (fixes context cancellation data race introduced in v1.2.0)) # 4.2.0 (January 11, 2020) * Update pgconn to v1.2.0. * Update pgtype to v1.1.0. * Return error instead of panic when wrong number of arguments passed to Exec. (malstoun) * Fix large objects functionality when PreferSimpleProtocol = true. * Restore GetDefaultDriver which existed in v3. (Johan Brandhorst) * Add RegisterConnConfig to stdlib which replaces the removed RegisterDriverConfig from v3. 
# 4.1.2 (October 22, 2019) * Fix dbSavepoint.Begin recursive self call * Upgrade pgtype to v1.0.2 - fix scan pointer to pointer # 4.1.1 (October 21, 2019) * Fix pgxpool Rows.CommandTag() infinite loop / typo # 4.1.0 (October 12, 2019) ## Potentially Breaking Changes Technically, two changes are breaking changes, but in practice these are extremely unlikely to break existing code. * Conn.Begin and Conn.BeginTx return a Tx interface instead of the internal dbTx struct. This is necessary for the Conn.Begin method to signature as other methods that begin a transaction. * Add Conn() to Tx interface. This is necessary to allow code using a Tx to access the *Conn (and pgconn.PgConn) on which the Tx is executing. ## Fixes * Releasing a busy connection closes the connection instead of returning an unusable connection to the pool * Do not mutate config.Config.OnNotification in connect # 4.0.1 (September 19, 2019) * Fix statement cache cleanup. * Corrected daterange OID. * Fix Tx when committing or rolling back multiple times in certain cases. * Improve documentation. # 4.0.0 (September 14, 2019) v4 is a major release with many significant changes some of which are breaking changes. The most significant are included below. * Simplified establishing a connection with a connection string. * All potentially blocking operations now require a context.Context. The non-context aware functions have been removed. * OIDs are hard-coded for known types. This saves the query on connection. * Context cancellations while network activity is in progress is now always fatal. Previously, it was sometimes recoverable. This led to increased complexity in pgx itself and in application code. * Go modules are required. * Errors are now implemented in the Go 1.13 style. * `Rows` and `Tx` are now interfaces. * The connection pool as been decoupled from pgx and is now a separate, included package (github.com/jackc/pgx/v4/pgxpool). 
* pgtype has been spun off to a separate package (github.com/jackc/pgtype). * pgproto3 has been spun off to a separate package (github.com/jackc/pgproto3/v2). * Logical replication support has been spun off to a separate package (github.com/jackc/pglogrepl). * Lower level PostgreSQL functionality is now implemented in a separate package (github.com/jackc/pgconn). * Tests are now configured with environment variables. * Conn has an automatic statement cache by default. * Batch interface has been simplified. * QueryArgs has been removed. pgx-4.18.1/LICENSE000066400000000000000000000020661437725773200134350ustar00rootroot00000000000000Copyright (c) 2013-2021 Jack Christensen MIT License Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
pgx-4.18.1/README.md000066400000000000000000000207711437725773200137120ustar00rootroot00000000000000[![](https://godoc.org/github.com/jackc/pgx?status.svg)](https://pkg.go.dev/github.com/jackc/pgx/v4) [![Build Status](https://travis-ci.org/jackc/pgx.svg)](https://travis-ci.org/jackc/pgx) --- This is the previous stable `v4` release. `v5` been released. --- # pgx - PostgreSQL Driver and Toolkit pgx is a pure Go driver and toolkit for PostgreSQL. pgx aims to be low-level, fast, and performant, while also enabling PostgreSQL-specific features that the standard `database/sql` package does not allow for. The driver component of pgx can be used alongside the standard `database/sql` package. The toolkit component is a related set of packages that implement PostgreSQL functionality such as parsing the wire protocol and type mapping between PostgreSQL and Go. These underlying packages can be used to implement alternative drivers, proxies, load balancers, logical replication clients, etc. The current release of `pgx v4` requires Go modules. To use the previous version, checkout and vendor the `v3` branch. ## Example Usage ```go package main import ( "context" "fmt" "os" "github.com/jackc/pgx/v4" ) func main() { // urlExample := "postgres://username:password@localhost:5432/database_name" conn, err := pgx.Connect(context.Background(), os.Getenv("DATABASE_URL")) if err != nil { fmt.Fprintf(os.Stderr, "Unable to connect to database: %v\n", err) os.Exit(1) } defer conn.Close(context.Background()) var name string var weight int64 err = conn.QueryRow(context.Background(), "select name, weight from widgets where id=$1", 42).Scan(&name, &weight) if err != nil { fmt.Fprintf(os.Stderr, "QueryRow failed: %v\n", err) os.Exit(1) } fmt.Println(name, weight) } ``` See the [getting started guide](https://github.com/jackc/pgx/wiki/Getting-started-with-pgx) for more information. ## Choosing Between the pgx and database/sql Interfaces It is recommended to use the pgx interface if: 1. 
The application only targets PostgreSQL. 2. No other libraries that require `database/sql` are in use. The pgx interface is faster and exposes more features. The `database/sql` interface only allows the underlying driver to return or receive the following types: `int64`, `float64`, `bool`, `[]byte`, `string`, `time.Time`, or `nil`. Handling other types requires implementing the `database/sql.Scanner` and the `database/sql/driver/driver.Valuer` interfaces which require transmission of values in text format. The binary format can be substantially faster, which is what the pgx interface uses. ## Features pgx supports many features beyond what is available through `database/sql`: * Support for approximately 70 different PostgreSQL types * Automatic statement preparation and caching * Batch queries * Single-round trip query mode * Full TLS connection control * Binary format support for custom types (allows for much quicker encoding/decoding) * COPY protocol support for faster bulk data loads * Extendable logging support including built-in support for `log15adapter`, [`logrus`](https://github.com/sirupsen/logrus), [`zap`](https://github.com/uber-go/zap), and [`zerolog`](https://github.com/rs/zerolog) * Connection pool with after-connect hook for arbitrary connection setup * Listen / notify * Conversion of PostgreSQL arrays to Go slice mappings for integers, floats, and strings * Hstore support * JSON and JSONB support * Maps `inet` and `cidr` PostgreSQL types to `net.IPNet` and `net.IP` * Large object support * NULL mapping to Null* struct or pointer to pointer * Supports `database/sql.Scanner` and `database/sql/driver.Valuer` interfaces for custom types * Notice response handling * Simulated nested transactions with savepoints ## Performance There are three areas in particular where pgx can provide a significant performance advantage over the standard `database/sql` interface and other drivers: 1. 
PostgreSQL specific types - Types such as arrays can be parsed much quicker because pgx uses the binary format. 2. Automatic statement preparation and caching - pgx will prepare and cache statements by default. This can provide an significant free improvement to code that does not explicitly use prepared statements. Under certain workloads, it can perform nearly 3x the number of queries per second. 3. Batched queries - Multiple queries can be batched together to minimize network round trips. ## Testing pgx tests naturally require a PostgreSQL database. It will connect to the database specified in the `PGX_TEST_DATABASE` environment variable. The `PGX_TEST_DATABASE` environment variable can either be a URL or DSN. In addition, the standard `PG*` environment variables will be respected. Consider using [direnv](https://github.com/direnv/direnv) to simplify environment variable handling. ### Example Test Environment Connect to your PostgreSQL server and run: ``` create database pgx_test; ``` Connect to the newly-created database and run: ``` create domain uint64 as numeric(20,0); ``` Now, you can run the tests: ``` PGX_TEST_DATABASE="host=/var/run/postgresql database=pgx_test" go test ./... ``` In addition, there are tests specific for PgBouncer that will be executed if `PGX_TEST_PGBOUNCER_CONN_STRING` is set. ## Supported Go and PostgreSQL Versions pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.16 and higher and PostgreSQL 10 and higher. pgx also is tested against the latest version of [CockroachDB](https://www.cockroachlabs.com/product/). ## Version Policy pgx follows semantic versioning for the documented public API on stable releases. `v4` is the latest stable major version. 
## PGX Family Libraries pgx is the head of a family of PostgreSQL libraries. Many of these can be used independently. Many can also be accessed from pgx for lower-level control. ### [github.com/jackc/pgconn](https://github.com/jackc/pgconn) `pgconn` is a lower-level PostgreSQL database driver that operates at nearly the same level as the C library `libpq`. ### [github.com/jackc/pgx/v4/pgxpool](https://github.com/jackc/pgx/tree/master/pgxpool) `pgxpool` is a connection pool for pgx. pgx is entirely decoupled from its default pool implementation. This means that pgx can be used with a different pool or without any pool at all. ### [github.com/jackc/pgx/v4/stdlib](https://github.com/jackc/pgx/tree/master/stdlib) This is a `database/sql` compatibility layer for pgx. pgx can be used as a normal `database/sql` driver, but at any time, the native interface can be acquired for more performance or PostgreSQL specific functionality. ### [github.com/jackc/pgtype](https://github.com/jackc/pgtype) Over 70 PostgreSQL types are supported including `uuid`, `hstore`, `json`, `bytea`, `numeric`, `interval`, `inet`, and arrays. These types support `database/sql` interfaces and are usable outside of pgx. They are fully tested in pgx and pq. They also support a higher performance interface when used with the pgx driver. ### [github.com/jackc/pgproto3](https://github.com/jackc/pgproto3) pgproto3 provides standalone encoding and decoding of the PostgreSQL v3 wire protocol. This is useful for implementing very low level PostgreSQL tooling. ### [github.com/jackc/pglogrepl](https://github.com/jackc/pglogrepl) pglogrepl provides functionality to act as a client for PostgreSQL logical replication. ### [github.com/jackc/pgmock](https://github.com/jackc/pgmock) pgmock offers the ability to create a server that mocks the PostgreSQL wire protocol. This is used internally to test pgx by purposely inducing unusual errors. 
pgproto3 and pgmock together provide most of the foundational tooling required to implement a PostgreSQL proxy or MitM (such as for a custom connection pooler). ### [github.com/jackc/tern](https://github.com/jackc/tern) tern is a stand-alone SQL migration system. ### [github.com/jackc/pgerrcode](https://github.com/jackc/pgerrcode) pgerrcode contains constants for the PostgreSQL error codes. ## 3rd Party Libraries with PGX Support ### [github.com/georgysavva/scany](https://github.com/georgysavva/scany) Library for scanning data from a database into Go structs and more. ### [https://github.com/otan/gopgkrb5](https://github.com/otan/gopgkrb5) Adds GSSAPI / Kerberos authentication support. ### [https://github.com/vgarvardt/pgx-google-uuid](https://github.com/vgarvardt/pgx-google-uuid) Adds support for [`github.com/google/uuid`](https://github.com/google/uuid). pgx-4.18.1/batch.go000066400000000000000000000131721437725773200140400ustar00rootroot00000000000000package pgx import ( "context" "errors" "fmt" "github.com/jackc/pgconn" ) type batchItem struct { query string arguments []interface{} } // Batch queries are a way of bundling multiple queries together to avoid // unnecessary network round trips. type Batch struct { items []*batchItem } // Queue queues a query to batch b. query can be an SQL query or the name of a prepared statement. func (b *Batch) Queue(query string, arguments ...interface{}) { b.items = append(b.items, &batchItem{ query: query, arguments: arguments, }) } // Len returns number of queries that have been queued so far. func (b *Batch) Len() int { return len(b.items) } type BatchResults interface { // Exec reads the results from the next query in the batch as if the query has been sent with Conn.Exec. Exec() (pgconn.CommandTag, error) // Query reads the results from the next query in the batch as if the query has been sent with Conn.Query. 
Query() (Rows, error) // QueryRow reads the results from the next query in the batch as if the query has been sent with Conn.QueryRow. QueryRow() Row // QueryFunc reads the results from the next query in the batch as if the query has been sent with Conn.QueryFunc. QueryFunc(scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) // Close closes the batch operation. This must be called before the underlying connection can be used again. Any error // that occurred during a batch operation may have made it impossible to resyncronize the connection with the server. // In this case the underlying connection will have been closed. Close is safe to call multiple times. Close() error } type batchResults struct { ctx context.Context conn *Conn mrr *pgconn.MultiResultReader err error b *Batch ix int closed bool } // Exec reads the results from the next query in the batch as if the query has been sent with Exec. func (br *batchResults) Exec() (pgconn.CommandTag, error) { if br.err != nil { return nil, br.err } if br.closed { return nil, fmt.Errorf("batch already closed") } query, arguments, _ := br.nextQueryAndArgs() if !br.mrr.NextResult() { err := br.mrr.Close() if err == nil { err = errors.New("no result") } if br.conn.shouldLog(LogLevelError) { br.conn.log(br.ctx, LogLevelError, "BatchResult.Exec", map[string]interface{}{ "sql": query, "args": logQueryArgs(arguments), "err": err, }) } return nil, err } commandTag, err := br.mrr.ResultReader().Close() if err != nil { if br.conn.shouldLog(LogLevelError) { br.conn.log(br.ctx, LogLevelError, "BatchResult.Exec", map[string]interface{}{ "sql": query, "args": logQueryArgs(arguments), "err": err, }) } } else if br.conn.shouldLog(LogLevelInfo) { br.conn.log(br.ctx, LogLevelInfo, "BatchResult.Exec", map[string]interface{}{ "sql": query, "args": logQueryArgs(arguments), "commandTag": commandTag, }) } return commandTag, err } // Query reads the results from the next query in the batch as if the query has been sent 
with Query. func (br *batchResults) Query() (Rows, error) { query, arguments, ok := br.nextQueryAndArgs() if !ok { query = "batch query" } if br.err != nil { return &connRows{err: br.err, closed: true}, br.err } if br.closed { alreadyClosedErr := fmt.Errorf("batch already closed") return &connRows{err: alreadyClosedErr, closed: true}, alreadyClosedErr } rows := br.conn.getRows(br.ctx, query, arguments) if !br.mrr.NextResult() { rows.err = br.mrr.Close() if rows.err == nil { rows.err = errors.New("no result") } rows.closed = true if br.conn.shouldLog(LogLevelError) { br.conn.log(br.ctx, LogLevelError, "BatchResult.Query", map[string]interface{}{ "sql": query, "args": logQueryArgs(arguments), "err": rows.err, }) } return rows, rows.err } rows.resultReader = br.mrr.ResultReader() return rows, nil } // QueryFunc reads the results from the next query in the batch as if the query has been sent with Conn.QueryFunc. func (br *batchResults) QueryFunc(scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) { if br.closed { return nil, fmt.Errorf("batch already closed") } rows, err := br.Query() if err != nil { return nil, err } defer rows.Close() for rows.Next() { err = rows.Scan(scans...) if err != nil { return nil, err } err = f(rows) if err != nil { return nil, err } } if err := rows.Err(); err != nil { return nil, err } return rows.CommandTag(), nil } // QueryRow reads the results from the next query in the batch as if the query has been sent with QueryRow. func (br *batchResults) QueryRow() Row { rows, _ := br.Query() return (*connRow)(rows.(*connRows)) } // Close closes the batch operation. Any error that occurred during a batch operation may have made it impossible to // resyncronize the connection with the server. In this case the underlying connection will have been closed. 
func (br *batchResults) Close() error { if br.err != nil { return br.err } if br.closed { return nil } br.closed = true // log any queries that haven't yet been logged by Exec or Query for { query, args, ok := br.nextQueryAndArgs() if !ok { break } if br.conn.shouldLog(LogLevelInfo) { br.conn.log(br.ctx, LogLevelInfo, "BatchResult.Close", map[string]interface{}{ "sql": query, "args": logQueryArgs(args), }) } } return br.mrr.Close() } func (br *batchResults) nextQueryAndArgs() (query string, args []interface{}, ok bool) { if br.b != nil && br.ix < len(br.b.items) { bi := br.b.items[br.ix] query = bi.query args = bi.arguments ok = true br.ix++ } return } pgx-4.18.1/batch_test.go000066400000000000000000000461671437725773200151110ustar00rootroot00000000000000package pgx_test import ( "context" "errors" "os" "testing" "github.com/jackc/pgconn" "github.com/jackc/pgconn/stmtcache" "github.com/jackc/pgx/v4" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestConnSendBatch(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) skipCockroachDB(t, conn, "Server serial type is incompatible with test") sql := `create temporary table ledger( id serial primary key, description varchar not null, amount int not null );` mustExec(t, conn, sql) batch := &pgx.Batch{} batch.Queue("insert into ledger(description, amount) values($1, $2)", "q1", 1) batch.Queue("insert into ledger(description, amount) values($1, $2)", "q2", 2) batch.Queue("insert into ledger(description, amount) values($1, $2)", "q3", 3) batch.Queue("select id, description, amount from ledger order by id") batch.Queue("select id, description, amount from ledger order by id") batch.Queue("select * from ledger where false") batch.Queue("select sum(amount) from ledger") br := conn.SendBatch(context.Background(), batch) ct, err := br.Exec() if err != nil { t.Error(err) } if ct.RowsAffected() != 1 { t.Errorf("ct.RowsAffected() => %v, 
want %v", ct.RowsAffected(), 1) } ct, err = br.Exec() if err != nil { t.Error(err) } if ct.RowsAffected() != 1 { t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1) } ct, err = br.Exec() if err != nil { t.Error(err) } if ct.RowsAffected() != 1 { t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1) } selectFromLedgerExpectedRows := []struct { id int32 description string amount int32 }{ {1, "q1", 1}, {2, "q2", 2}, {3, "q3", 3}, } rows, err := br.Query() if err != nil { t.Error(err) } var id int32 var description string var amount int32 rowCount := 0 for rows.Next() { if rowCount >= len(selectFromLedgerExpectedRows) { t.Fatalf("got too many rows: %d", rowCount) } if err := rows.Scan(&id, &description, &amount); err != nil { t.Fatalf("row %d: %v", rowCount, err) } if id != selectFromLedgerExpectedRows[rowCount].id { t.Errorf("id => %v, want %v", id, selectFromLedgerExpectedRows[rowCount].id) } if description != selectFromLedgerExpectedRows[rowCount].description { t.Errorf("description => %v, want %v", description, selectFromLedgerExpectedRows[rowCount].description) } if amount != selectFromLedgerExpectedRows[rowCount].amount { t.Errorf("amount => %v, want %v", amount, selectFromLedgerExpectedRows[rowCount].amount) } rowCount++ } if rows.Err() != nil { t.Fatal(rows.Err()) } rowCount = 0 _, err = br.QueryFunc([]interface{}{&id, &description, &amount}, func(pgx.QueryFuncRow) error { if id != selectFromLedgerExpectedRows[rowCount].id { t.Errorf("id => %v, want %v", id, selectFromLedgerExpectedRows[rowCount].id) } if description != selectFromLedgerExpectedRows[rowCount].description { t.Errorf("description => %v, want %v", description, selectFromLedgerExpectedRows[rowCount].description) } if amount != selectFromLedgerExpectedRows[rowCount].amount { t.Errorf("amount => %v, want %v", amount, selectFromLedgerExpectedRows[rowCount].amount) } rowCount++ return nil }) if err != nil { t.Error(err) } err = br.QueryRow().Scan(&id, &description, &amount) 
if !errors.Is(err, pgx.ErrNoRows) { t.Errorf("expected pgx.ErrNoRows but got: %v", err) } err = br.QueryRow().Scan(&amount) if err != nil { t.Error(err) } if amount != 6 { t.Errorf("amount => %v, want %v", amount, 6) } err = br.Close() if err != nil { t.Fatal(err) } ensureConnValid(t, conn) } func TestConnSendBatchMany(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) sql := `create temporary table ledger( id serial primary key, description varchar not null, amount int not null );` mustExec(t, conn, sql) batch := &pgx.Batch{} numInserts := 1000 for i := 0; i < numInserts; i++ { batch.Queue("insert into ledger(description, amount) values($1, $2)", "q1", 1) } batch.Queue("select count(*) from ledger") br := conn.SendBatch(context.Background(), batch) for i := 0; i < numInserts; i++ { ct, err := br.Exec() assert.NoError(t, err) assert.EqualValues(t, 1, ct.RowsAffected()) } var actualInserts int err := br.QueryRow().Scan(&actualInserts) assert.NoError(t, err) assert.EqualValues(t, numInserts, actualInserts) err = br.Close() require.NoError(t, err) ensureConnValid(t, conn) } func TestConnSendBatchWithPreparedStatement(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) skipCockroachDB(t, conn, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)") _, err := conn.Prepare(context.Background(), "ps1", "select n from generate_series(0,$1::int) n") if err != nil { t.Fatal(err) } batch := &pgx.Batch{} queryCount := 3 for i := 0; i < queryCount; i++ { batch.Queue("ps1", 5) } br := conn.SendBatch(context.Background(), batch) for i := 0; i < queryCount; i++ { rows, err := br.Query() if err != nil { t.Fatal(err) } for k := 0; rows.Next(); k++ { var n int if err := rows.Scan(&n); err != nil { t.Fatal(err) } if n != k { t.Fatalf("n => %v, want %v", n, k) } } if rows.Err() != nil { t.Fatal(rows.Err()) } } err 
= br.Close() if err != nil { t.Fatal(err) } ensureConnValid(t, conn) } // https://github.com/jackc/pgx/issues/856 func TestConnSendBatchWithPreparedStatementAndStatementCacheDisabled(t *testing.T) { t.Parallel() config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) config.BuildStatementCache = nil conn := mustConnect(t, config) defer closeConn(t, conn) skipCockroachDB(t, conn, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)") _, err = conn.Prepare(context.Background(), "ps1", "select n from generate_series(0,$1::int) n") if err != nil { t.Fatal(err) } batch := &pgx.Batch{} queryCount := 3 for i := 0; i < queryCount; i++ { batch.Queue("ps1", 5) } br := conn.SendBatch(context.Background(), batch) for i := 0; i < queryCount; i++ { rows, err := br.Query() if err != nil { t.Fatal(err) } for k := 0; rows.Next(); k++ { var n int if err := rows.Scan(&n); err != nil { t.Fatal(err) } if n != k { t.Fatalf("n => %v, want %v", n, k) } } if rows.Err() != nil { t.Fatal(rows.Err()) } } err = br.Close() if err != nil { t.Fatal(err) } ensureConnValid(t, conn) } func TestConnSendBatchCloseRowsPartiallyRead(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) batch := &pgx.Batch{} batch.Queue("select n from generate_series(0,5) n") batch.Queue("select n from generate_series(0,5) n") br := conn.SendBatch(context.Background(), batch) rows, err := br.Query() if err != nil { t.Error(err) } for i := 0; i < 3; i++ { if !rows.Next() { t.Error("expected a row to be available") } var n int if err := rows.Scan(&n); err != nil { t.Error(err) } if n != i { t.Errorf("n => %v, want %v", n, i) } } rows.Close() rows, err = br.Query() if err != nil { t.Error(err) } for i := 0; rows.Next(); i++ { var n int if err := rows.Scan(&n); err != nil { t.Error(err) } if n != i { t.Errorf("n => %v, want %v", n, i) } } if rows.Err() != nil { t.Error(rows.Err()) } err 
= br.Close() if err != nil { t.Fatal(err) } ensureConnValid(t, conn) } func TestConnSendBatchQueryError(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) batch := &pgx.Batch{} batch.Queue("select n from generate_series(0,5) n where 100/(5-n) > 0") batch.Queue("select n from generate_series(0,5) n") br := conn.SendBatch(context.Background(), batch) rows, err := br.Query() if err != nil { t.Error(err) } for i := 0; rows.Next(); i++ { var n int if err := rows.Scan(&n); err != nil { t.Error(err) } if n != i { t.Errorf("n => %v, want %v", n, i) } } if pgErr, ok := rows.Err().(*pgconn.PgError); !(ok && pgErr.Code == "22012") { t.Errorf("rows.Err() => %v, want error code %v", rows.Err(), 22012) } err = br.Close() if pgErr, ok := err.(*pgconn.PgError); !(ok && pgErr.Code == "22012") { t.Errorf("rows.Err() => %v, want error code %v", err, 22012) } ensureConnValid(t, conn) } func TestConnSendBatchQuerySyntaxError(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) batch := &pgx.Batch{} batch.Queue("select 1 1") br := conn.SendBatch(context.Background(), batch) var n int32 err := br.QueryRow().Scan(&n) if pgErr, ok := err.(*pgconn.PgError); !(ok && pgErr.Code == "42601") { t.Errorf("rows.Err() => %v, want error code %v", err, 42601) } err = br.Close() if err == nil { t.Error("Expected error") } ensureConnValid(t, conn) } func TestConnSendBatchQueryRowInsert(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) sql := `create temporary table ledger( id serial primary key, description varchar not null, amount int not null );` mustExec(t, conn, sql) batch := &pgx.Batch{} batch.Queue("select 1") batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)", "q1", 1) br := conn.SendBatch(context.Background(), batch) var value int err := br.QueryRow().Scan(&value) if err != nil { 
t.Error(err) } ct, err := br.Exec() if err != nil { t.Error(err) } if ct.RowsAffected() != 2 { t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 2) } br.Close() ensureConnValid(t, conn) } func TestConnSendBatchQueryPartialReadInsert(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) sql := `create temporary table ledger( id serial primary key, description varchar not null, amount int not null );` mustExec(t, conn, sql) batch := &pgx.Batch{} batch.Queue("select 1 union all select 2 union all select 3") batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)", "q1", 1) br := conn.SendBatch(context.Background(), batch) rows, err := br.Query() if err != nil { t.Error(err) } rows.Close() ct, err := br.Exec() if err != nil { t.Error(err) } if ct.RowsAffected() != 2 { t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 2) } br.Close() ensureConnValid(t, conn) } func TestTxSendBatch(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) sql := `create temporary table ledger1( id serial primary key, description varchar not null );` mustExec(t, conn, sql) sql = `create temporary table ledger2( id int primary key, amount int not null );` mustExec(t, conn, sql) tx, _ := conn.Begin(context.Background()) batch := &pgx.Batch{} batch.Queue("insert into ledger1(description) values($1) returning id", "q1") br := tx.SendBatch(context.Background(), batch) var id int err := br.QueryRow().Scan(&id) if err != nil { t.Error(err) } br.Close() batch = &pgx.Batch{} batch.Queue("insert into ledger2(id,amount) values($1, $2)", id, 2) batch.Queue("select amount from ledger2 where id = $1", id) br = tx.SendBatch(context.Background(), batch) ct, err := br.Exec() if err != nil { t.Error(err) } if ct.RowsAffected() != 1 { t.Errorf("ct.RowsAffected() => %v, want %v", ct.RowsAffected(), 1) } var amount int err = 
br.QueryRow().Scan(&amount) if err != nil { t.Error(err) } br.Close() tx.Commit(context.Background()) var count int conn.QueryRow(context.Background(), "select count(1) from ledger1 where id = $1", id).Scan(&count) if count != 1 { t.Errorf("count => %v, want %v", count, 1) } err = br.Close() if err != nil { t.Fatal(err) } ensureConnValid(t, conn) } func TestTxSendBatchRollback(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) sql := `create temporary table ledger1( id serial primary key, description varchar not null );` mustExec(t, conn, sql) tx, _ := conn.Begin(context.Background()) batch := &pgx.Batch{} batch.Queue("insert into ledger1(description) values($1) returning id", "q1") br := tx.SendBatch(context.Background(), batch) var id int err := br.QueryRow().Scan(&id) if err != nil { t.Error(err) } br.Close() tx.Rollback(context.Background()) row := conn.QueryRow(context.Background(), "select count(1) from ledger1 where id = $1", id) var count int row.Scan(&count) if count != 0 { t.Errorf("count => %v, want %v", count, 0) } ensureConnValid(t, conn) } func TestConnBeginBatchDeferredError(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) skipCockroachDB(t, conn, "Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)") mustExec(t, conn, `create temporary table t ( id text primary key, n int not null, unique (n) deferrable initially deferred ); insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`) batch := &pgx.Batch{} batch.Queue(`update t set n=n+1 where id='b' returning *`) br := conn.SendBatch(context.Background(), batch) rows, err := br.Query() if err != nil { t.Error(err) } for rows.Next() { var id string var n int32 err = rows.Scan(&id, &n) if err != nil { t.Fatal(err) } } err = br.Close() if err == nil { t.Fatal("expected error 23505 but got none") } if err, ok := 
err.(*pgconn.PgError); !ok || err.Code != "23505" { t.Fatalf("expected error 23505, got %v", err) } ensureConnValid(t, conn) } func TestConnSendBatchNoStatementCache(t *testing.T) { config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) config.BuildStatementCache = nil conn := mustConnect(t, config) defer closeConn(t, conn) testConnSendBatch(t, conn, 3) } func TestConnSendBatchPrepareStatementCache(t *testing.T) { config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) config.BuildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache { return stmtcache.New(conn, stmtcache.ModePrepare, 32) } conn := mustConnect(t, config) defer closeConn(t, conn) testConnSendBatch(t, conn, 3) } func TestConnSendBatchDescribeStatementCache(t *testing.T) { config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) config.BuildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache { return stmtcache.New(conn, stmtcache.ModeDescribe, 32) } conn := mustConnect(t, config) defer closeConn(t, conn) testConnSendBatch(t, conn, 3) } func testConnSendBatch(t *testing.T, conn *pgx.Conn, queryCount int) { batch := &pgx.Batch{} for j := 0; j < queryCount; j++ { batch.Queue("select n from generate_series(0,5) n") } br := conn.SendBatch(context.Background(), batch) for j := 0; j < queryCount; j++ { rows, err := br.Query() require.NoError(t, err) for k := 0; rows.Next(); k++ { var n int err := rows.Scan(&n) require.NoError(t, err) require.Equal(t, k, n) } require.NoError(t, rows.Err()) } err := br.Close() require.NoError(t, err) } func TestLogBatchStatementsOnExec(t *testing.T) { l1 := &testLogger{} config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) config.Logger = l1 conn := mustConnect(t, config) defer closeConn(t, conn) l1.logs = l1.logs[0:0] // Clear logs written when establishing connection batch := &pgx.Batch{} batch.Queue("create table foo (id bigint)") batch.Queue("drop table foo") br := conn.SendBatch(context.Background(), batch) _, err := br.Exec() if err != 
nil { t.Fatalf("Unexpected error creating table: %v", err) } _, err = br.Exec() if err != nil { t.Fatalf("Unexpected error dropping table: %v", err) } if len(l1.logs) != 3 { t.Fatalf("Expected two log entries but got %d", len(l1.logs)) } if l1.logs[0].msg != "SendBatch" { t.Errorf("Expected first log message to be 'SendBatch' but was '%s'", l1.logs[0].msg) } if l1.logs[1].msg != "BatchResult.Exec" { t.Errorf("Expected first log message to be 'BatchResult.Exec' but was '%s'", l1.logs[0].msg) } if l1.logs[1].data["sql"] != "create table foo (id bigint)" { t.Errorf("Expected the first query to be 'create table foo (id bigint)' but was '%s'", l1.logs[0].data["sql"]) } if l1.logs[2].msg != "BatchResult.Exec" { t.Errorf("Expected second log message to be 'BatchResult.Exec' but was '%s", l1.logs[1].msg) } if l1.logs[2].data["sql"] != "drop table foo" { t.Errorf("Expected the second query to be 'drop table foo' but was '%s'", l1.logs[1].data["sql"]) } } func TestLogBatchStatementsOnBatchResultClose(t *testing.T) { l1 := &testLogger{} config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) config.Logger = l1 conn := mustConnect(t, config) defer closeConn(t, conn) l1.logs = l1.logs[0:0] // Clear logs written when establishing connection batch := &pgx.Batch{} batch.Queue("select generate_series(1,$1)", 100) batch.Queue("select 1 = 1;") br := conn.SendBatch(context.Background(), batch) if err := br.Close(); err != nil { t.Fatalf("Unexpected batch error: %v", err) } if len(l1.logs) != 3 { t.Fatalf("Expected 2 log statements but found %d", len(l1.logs)) } if l1.logs[0].msg != "SendBatch" { t.Errorf("Expected first log message to be 'SendBatch' but was '%s'", l1.logs[0].msg) } if l1.logs[1].msg != "BatchResult.Close" { t.Errorf("Expected first log statement to be 'BatchResult.Close' but was '%s'", l1.logs[0].msg) } if l1.logs[1].data["sql"] != "select generate_series(1,$1)" { t.Errorf("Expected first query to be 'select generate_series(1,$1)' but was '%s'", 
l1.logs[0].data["sql"]) } if l1.logs[2].msg != "BatchResult.Close" { t.Errorf("Expected second log statement to be 'BatchResult.Close' but was %s", l1.logs[1].msg) } if l1.logs[2].data["sql"] != "select 1 = 1;" { t.Errorf("Expected second query to be 'select 1 = 1;' but was '%s'", l1.logs[1].data["sql"]) } } func TestSendBatchSimpleProtocol(t *testing.T) { t.Parallel() config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) config.PreferSimpleProtocol = true ctx, cancelFunc := context.WithCancel(context.Background()) defer cancelFunc() conn := mustConnect(t, config) defer closeConn(t, conn) var batch pgx.Batch batch.Queue("SELECT 1::int") batch.Queue("SELECT 2::int; SELECT $1::int", 3) results := conn.SendBatch(ctx, &batch) rows, err := results.Query() assert.NoError(t, err) assert.True(t, rows.Next()) values, err := rows.Values() assert.NoError(t, err) assert.EqualValues(t, 1, values[0]) assert.False(t, rows.Next()) rows, err = results.Query() assert.NoError(t, err) assert.True(t, rows.Next()) values, err = rows.Values() assert.NoError(t, err) assert.EqualValues(t, 2, values[0]) assert.False(t, rows.Next()) rows, err = results.Query() assert.NoError(t, err) assert.True(t, rows.Next()) values, err = rows.Values() assert.NoError(t, err) assert.EqualValues(t, 3, values[0]) assert.False(t, rows.Next()) } pgx-4.18.1/bench_test.go000066400000000000000000001023661437725773200151010ustar00rootroot00000000000000package pgx_test import ( "bytes" "context" "fmt" "io" "net" "os" "strconv" "strings" "testing" "time" "github.com/jackc/pgconn" "github.com/jackc/pgconn/stmtcache" "github.com/jackc/pgtype" "github.com/jackc/pgx/v4" "github.com/stretchr/testify/require" ) func BenchmarkMinimalUnpreparedSelectWithoutStatementCache(b *testing.B) { config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) config.BuildStatementCache = nil conn := mustConnect(b, config) defer closeConn(b, conn) var n int64 b.ResetTimer() for i := 0; i < b.N; i++ { err := 
conn.QueryRow(context.Background(), "select $1::int8", i).Scan(&n) if err != nil { b.Fatal(err) } if n != int64(i) { b.Fatalf("expected %d, got %d", i, n) } } } func BenchmarkMinimalUnpreparedSelectWithStatementCacheModeDescribe(b *testing.B) { config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) config.BuildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache { return stmtcache.New(conn, stmtcache.ModeDescribe, 32) } conn := mustConnect(b, config) defer closeConn(b, conn) var n int64 b.ResetTimer() for i := 0; i < b.N; i++ { err := conn.QueryRow(context.Background(), "select $1::int8", i).Scan(&n) if err != nil { b.Fatal(err) } if n != int64(i) { b.Fatalf("expected %d, got %d", i, n) } } } func BenchmarkMinimalUnpreparedSelectWithStatementCacheModePrepare(b *testing.B) { config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) config.BuildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache { return stmtcache.New(conn, stmtcache.ModePrepare, 32) } conn := mustConnect(b, config) defer closeConn(b, conn) var n int64 b.ResetTimer() for i := 0; i < b.N; i++ { err := conn.QueryRow(context.Background(), "select $1::int8", i).Scan(&n) if err != nil { b.Fatal(err) } if n != int64(i) { b.Fatalf("expected %d, got %d", i, n) } } } func BenchmarkMinimalPreparedSelect(b *testing.B) { conn := mustConnect(b, mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))) defer closeConn(b, conn) _, err := conn.Prepare(context.Background(), "ps1", "select $1::int8") if err != nil { b.Fatal(err) } var n int64 b.ResetTimer() for i := 0; i < b.N; i++ { err = conn.QueryRow(context.Background(), "ps1", i).Scan(&n) if err != nil { b.Fatal(err) } if n != int64(i) { b.Fatalf("expected %d, got %d", i, n) } } } func BenchmarkMinimalPgConnPreparedSelect(b *testing.B) { conn := mustConnect(b, mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))) defer closeConn(b, conn) pgConn := conn.PgConn() _, err := pgConn.Prepare(context.Background(), "ps1", "select $1::int8", nil) if err != 
nil { b.Fatal(err) } encodedBytes := make([]byte, 8) b.ResetTimer() for i := 0; i < b.N; i++ { rr := pgConn.ExecPrepared(context.Background(), "ps1", [][]byte{encodedBytes}, []int16{1}, []int16{1}) if err != nil { b.Fatal(err) } for rr.NextRow() { for i := range rr.Values() { if bytes.Compare(rr.Values()[0], encodedBytes) != 0 { b.Fatalf("unexpected values: %s %s", rr.Values()[i], encodedBytes) } } } _, err = rr.Close() if err != nil { b.Fatal(err) } } } func BenchmarkPointerPointerWithNullValues(b *testing.B) { conn := mustConnect(b, mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))) defer closeConn(b, conn) _, err := conn.Prepare(context.Background(), "selectNulls", "select 1::int4, 'johnsmith', null::text, null::text, null::text, null::date, null::timestamptz") if err != nil { b.Fatal(err) } b.ResetTimer() for i := 0; i < b.N; i++ { var record struct { id int32 userName string email *string name *string sex *string birthDate *time.Time lastLoginTime *time.Time } err = conn.QueryRow(context.Background(), "selectNulls").Scan( &record.id, &record.userName, &record.email, &record.name, &record.sex, &record.birthDate, &record.lastLoginTime, ) if err != nil { b.Fatal(err) } // These checks both ensure that the correct data was returned // and provide a benchmark of accessing the returned values. 
if record.id != 1 { b.Fatalf("bad value for id: %v", record.id) } if record.userName != "johnsmith" { b.Fatalf("bad value for userName: %v", record.userName) } if record.email != nil { b.Fatalf("bad value for email: %v", record.email) } if record.name != nil { b.Fatalf("bad value for name: %v", record.name) } if record.sex != nil { b.Fatalf("bad value for sex: %v", record.sex) } if record.birthDate != nil { b.Fatalf("bad value for birthDate: %v", record.birthDate) } if record.lastLoginTime != nil { b.Fatalf("bad value for lastLoginTime: %v", record.lastLoginTime) } } } func BenchmarkPointerPointerWithPresentValues(b *testing.B) { conn := mustConnect(b, mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))) defer closeConn(b, conn) _, err := conn.Prepare(context.Background(), "selectNulls", "select 1::int4, 'johnsmith', 'johnsmith@example.com', 'John Smith', 'male', '1970-01-01'::date, '2015-01-01 00:00:00'::timestamptz") if err != nil { b.Fatal(err) } b.ResetTimer() for i := 0; i < b.N; i++ { var record struct { id int32 userName string email *string name *string sex *string birthDate *time.Time lastLoginTime *time.Time } err = conn.QueryRow(context.Background(), "selectNulls").Scan( &record.id, &record.userName, &record.email, &record.name, &record.sex, &record.birthDate, &record.lastLoginTime, ) if err != nil { b.Fatal(err) } // These checks both ensure that the correct data was returned // and provide a benchmark of accessing the returned values. 
if record.id != 1 { b.Fatalf("bad value for id: %v", record.id) } if record.userName != "johnsmith" { b.Fatalf("bad value for userName: %v", record.userName) } if record.email == nil || *record.email != "johnsmith@example.com" { b.Fatalf("bad value for email: %v", record.email) } if record.name == nil || *record.name != "John Smith" { b.Fatalf("bad value for name: %v", record.name) } if record.sex == nil || *record.sex != "male" { b.Fatalf("bad value for sex: %v", record.sex) } if record.birthDate == nil || *record.birthDate != time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC) { b.Fatalf("bad value for birthDate: %v", record.birthDate) } if record.lastLoginTime == nil || *record.lastLoginTime != time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local) { b.Fatalf("bad value for lastLoginTime: %v", record.lastLoginTime) } } } func BenchmarkSelectWithoutLogging(b *testing.B) { conn := mustConnect(b, mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))) defer closeConn(b, conn) benchmarkSelectWithLog(b, conn) } type discardLogger struct{} func (dl discardLogger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) { } func BenchmarkSelectWithLoggingTraceDiscard(b *testing.B) { var logger discardLogger config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) config.Logger = logger config.LogLevel = pgx.LogLevelTrace conn := mustConnect(b, config) defer closeConn(b, conn) benchmarkSelectWithLog(b, conn) } func BenchmarkSelectWithLoggingDebugWithDiscard(b *testing.B) { var logger discardLogger config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) config.Logger = logger config.LogLevel = pgx.LogLevelDebug conn := mustConnect(b, config) defer closeConn(b, conn) benchmarkSelectWithLog(b, conn) } func BenchmarkSelectWithLoggingInfoWithDiscard(b *testing.B) { var logger discardLogger config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) config.Logger = logger config.LogLevel = pgx.LogLevelInfo conn := mustConnect(b, config) defer closeConn(b, 
conn) benchmarkSelectWithLog(b, conn) } func BenchmarkSelectWithLoggingErrorWithDiscard(b *testing.B) { var logger discardLogger config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) config.Logger = logger config.LogLevel = pgx.LogLevelError conn := mustConnect(b, config) defer closeConn(b, conn) benchmarkSelectWithLog(b, conn) } func benchmarkSelectWithLog(b *testing.B, conn *pgx.Conn) { _, err := conn.Prepare(context.Background(), "test", "select 1::int4, 'johnsmith', 'johnsmith@example.com', 'John Smith', 'male', '1970-01-01'::date, '2015-01-01 00:00:00'::timestamptz") if err != nil { b.Fatal(err) } b.ResetTimer() for i := 0; i < b.N; i++ { var record struct { id int32 userName string email string name string sex string birthDate time.Time lastLoginTime time.Time } err = conn.QueryRow(context.Background(), "test").Scan( &record.id, &record.userName, &record.email, &record.name, &record.sex, &record.birthDate, &record.lastLoginTime, ) if err != nil { b.Fatal(err) } // These checks both ensure that the correct data was returned // and provide a benchmark of accessing the returned values. 
if record.id != 1 { b.Fatalf("bad value for id: %v", record.id) } if record.userName != "johnsmith" { b.Fatalf("bad value for userName: %v", record.userName) } if record.email != "johnsmith@example.com" { b.Fatalf("bad value for email: %v", record.email) } if record.name != "John Smith" { b.Fatalf("bad value for name: %v", record.name) } if record.sex != "male" { b.Fatalf("bad value for sex: %v", record.sex) } if record.birthDate != time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC) { b.Fatalf("bad value for birthDate: %v", record.birthDate) } if record.lastLoginTime != time.Date(2015, 1, 1, 0, 0, 0, 0, time.Local) { b.Fatalf("bad value for lastLoginTime: %v", record.lastLoginTime) } } } const benchmarkWriteTableCreateSQL = `drop table if exists t; create table t( varchar_1 varchar not null, varchar_2 varchar not null, varchar_null_1 varchar, date_1 date not null, date_null_1 date, int4_1 int4 not null, int4_2 int4 not null, int4_null_1 int4, tstz_1 timestamptz not null, tstz_2 timestamptz, bool_1 bool not null, bool_2 bool not null, bool_3 bool not null ); ` const benchmarkWriteTableInsertSQL = `insert into t( varchar_1, varchar_2, varchar_null_1, date_1, date_null_1, int4_1, int4_2, int4_null_1, tstz_1, tstz_2, bool_1, bool_2, bool_3 ) values ( $1::varchar, $2::varchar, $3::varchar, $4::date, $5::date, $6::int4, $7::int4, $8::int4, $9::timestamptz, $10::timestamptz, $11::bool, $12::bool, $13::bool )` type benchmarkWriteTableCopyFromSrc struct { count int idx int row []interface{} } func (s *benchmarkWriteTableCopyFromSrc) Next() bool { s.idx++ return s.idx < s.count } func (s *benchmarkWriteTableCopyFromSrc) Values() ([]interface{}, error) { return s.row, nil } func (s *benchmarkWriteTableCopyFromSrc) Err() error { return nil } func newBenchmarkWriteTableCopyFromSrc(count int) pgx.CopyFromSource { return &benchmarkWriteTableCopyFromSrc{ count: count, row: []interface{}{ "varchar_1", "varchar_2", &pgtype.Text{Status: pgtype.Null}, time.Date(2000, 1, 1, 0, 0, 0, 0, 
time.Local), &pgtype.Date{Status: pgtype.Null}, 1, 2, &pgtype.Int4{Status: pgtype.Null}, time.Date(2001, 1, 1, 0, 0, 0, 0, time.Local), time.Date(2002, 1, 1, 0, 0, 0, 0, time.Local), true, false, true, }, } } func benchmarkWriteNRowsViaInsert(b *testing.B, n int) { conn := mustConnect(b, mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))) defer closeConn(b, conn) mustExec(b, conn, benchmarkWriteTableCreateSQL) _, err := conn.Prepare(context.Background(), "insert_t", benchmarkWriteTableInsertSQL) if err != nil { b.Fatal(err) } b.ResetTimer() for i := 0; i < b.N; i++ { src := newBenchmarkWriteTableCopyFromSrc(n) tx, err := conn.Begin(context.Background()) if err != nil { b.Fatal(err) } for src.Next() { values, _ := src.Values() if _, err = tx.Exec(context.Background(), "insert_t", values...); err != nil { b.Fatalf("Exec unexpectedly failed with: %v", err) } } err = tx.Commit(context.Background()) if err != nil { b.Fatal(err) } } } type queryArgs []interface{} func (qa *queryArgs) Append(v interface{}) string { *qa = append(*qa, v) return "$" + strconv.Itoa(len(*qa)) } // note this function is only used for benchmarks -- it doesn't escape tableName // or columnNames func multiInsert(conn *pgx.Conn, tableName string, columnNames []string, rowSrc pgx.CopyFromSource) (int, error) { maxRowsPerInsert := 65535 / len(columnNames) rowsThisInsert := 0 rowCount := 0 sqlBuf := &bytes.Buffer{} args := make(queryArgs, 0) resetQuery := func() { sqlBuf.Reset() fmt.Fprintf(sqlBuf, "insert into %s(%s) values", tableName, strings.Join(columnNames, ", ")) args = args[0:0] rowsThisInsert = 0 } resetQuery() tx, err := conn.Begin(context.Background()) if err != nil { return 0, err } defer tx.Rollback(context.Background()) for rowSrc.Next() { if rowsThisInsert > 0 { sqlBuf.WriteByte(',') } sqlBuf.WriteByte('(') values, err := rowSrc.Values() if err != nil { return 0, err } for i, val := range values { if i > 0 { sqlBuf.WriteByte(',') } sqlBuf.WriteString(args.Append(val)) } 
sqlBuf.WriteByte(')') rowsThisInsert++ if rowsThisInsert == maxRowsPerInsert { _, err := tx.Exec(context.Background(), sqlBuf.String(), args...) if err != nil { return 0, err } rowCount += rowsThisInsert resetQuery() } } if rowsThisInsert > 0 { _, err := tx.Exec(context.Background(), sqlBuf.String(), args...) if err != nil { return 0, err } rowCount += rowsThisInsert } if err := tx.Commit(context.Background()); err != nil { return 0, nil } return rowCount, nil } func benchmarkWriteNRowsViaMultiInsert(b *testing.B, n int) { conn := mustConnect(b, mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))) defer closeConn(b, conn) mustExec(b, conn, benchmarkWriteTableCreateSQL) _, err := conn.Prepare(context.Background(), "insert_t", benchmarkWriteTableInsertSQL) if err != nil { b.Fatal(err) } b.ResetTimer() for i := 0; i < b.N; i++ { src := newBenchmarkWriteTableCopyFromSrc(n) _, err := multiInsert(conn, "t", []string{"varchar_1", "varchar_2", "varchar_null_1", "date_1", "date_null_1", "int4_1", "int4_2", "int4_null_1", "tstz_1", "tstz_2", "bool_1", "bool_2", "bool_3"}, src) if err != nil { b.Fatal(err) } } } func benchmarkWriteNRowsViaCopy(b *testing.B, n int) { conn := mustConnect(b, mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE"))) defer closeConn(b, conn) mustExec(b, conn, benchmarkWriteTableCreateSQL) b.ResetTimer() for i := 0; i < b.N; i++ { src := newBenchmarkWriteTableCopyFromSrc(n) _, err := conn.CopyFrom(context.Background(), pgx.Identifier{"t"}, []string{"varchar_1", "varchar_2", "varchar_null_1", "date_1", "date_null_1", "int4_1", "int4_2", "int4_null_1", "tstz_1", "tstz_2", "bool_1", "bool_2", "bool_3"}, src) if err != nil { b.Fatal(err) } } } func BenchmarkWrite5RowsViaInsert(b *testing.B) { benchmarkWriteNRowsViaInsert(b, 5) } func BenchmarkWrite5RowsViaMultiInsert(b *testing.B) { benchmarkWriteNRowsViaMultiInsert(b, 5) } func BenchmarkWrite5RowsViaCopy(b *testing.B) { benchmarkWriteNRowsViaCopy(b, 5) } func BenchmarkWrite10RowsViaInsert(b *testing.B) { 
benchmarkWriteNRowsViaInsert(b, 10) } func BenchmarkWrite10RowsViaMultiInsert(b *testing.B) { benchmarkWriteNRowsViaMultiInsert(b, 10) } func BenchmarkWrite10RowsViaCopy(b *testing.B) { benchmarkWriteNRowsViaCopy(b, 10) } func BenchmarkWrite100RowsViaInsert(b *testing.B) { benchmarkWriteNRowsViaInsert(b, 100) } func BenchmarkWrite100RowsViaMultiInsert(b *testing.B) { benchmarkWriteNRowsViaMultiInsert(b, 100) } func BenchmarkWrite100RowsViaCopy(b *testing.B) { benchmarkWriteNRowsViaCopy(b, 100) } func BenchmarkWrite1000RowsViaInsert(b *testing.B) { benchmarkWriteNRowsViaInsert(b, 1000) } func BenchmarkWrite1000RowsViaMultiInsert(b *testing.B) { benchmarkWriteNRowsViaMultiInsert(b, 1000) } func BenchmarkWrite1000RowsViaCopy(b *testing.B) { benchmarkWriteNRowsViaCopy(b, 1000) } func BenchmarkWrite10000RowsViaInsert(b *testing.B) { benchmarkWriteNRowsViaInsert(b, 10000) } func BenchmarkWrite10000RowsViaMultiInsert(b *testing.B) { benchmarkWriteNRowsViaMultiInsert(b, 10000) } func BenchmarkWrite10000RowsViaCopy(b *testing.B) { benchmarkWriteNRowsViaCopy(b, 10000) } func BenchmarkMultipleQueriesNonBatchNoStatementCache(b *testing.B) { config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) config.BuildStatementCache = nil conn := mustConnect(b, config) defer closeConn(b, conn) benchmarkMultipleQueriesNonBatch(b, conn, 3) } func BenchmarkMultipleQueriesNonBatchPrepareStatementCache(b *testing.B) { config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) config.BuildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache { return stmtcache.New(conn, stmtcache.ModePrepare, 32) } conn := mustConnect(b, config) defer closeConn(b, conn) benchmarkMultipleQueriesNonBatch(b, conn, 3) } func BenchmarkMultipleQueriesNonBatchDescribeStatementCache(b *testing.B) { config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) config.BuildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache { return stmtcache.New(conn, stmtcache.ModeDescribe, 32) } conn := 
mustConnect(b, config) defer closeConn(b, conn) benchmarkMultipleQueriesNonBatch(b, conn, 3) } func benchmarkMultipleQueriesNonBatch(b *testing.B, conn *pgx.Conn, queryCount int) { b.ResetTimer() for i := 0; i < b.N; i++ { for j := 0; j < queryCount; j++ { rows, err := conn.Query(context.Background(), "select n from generate_series(0, 5) n") if err != nil { b.Fatal(err) } for k := 0; rows.Next(); k++ { var n int if err := rows.Scan(&n); err != nil { b.Fatal(err) } if n != k { b.Fatalf("n => %v, want %v", n, k) } } if rows.Err() != nil { b.Fatal(rows.Err()) } } } } func BenchmarkMultipleQueriesBatchNoStatementCache(b *testing.B) { config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) config.BuildStatementCache = nil conn := mustConnect(b, config) defer closeConn(b, conn) benchmarkMultipleQueriesBatch(b, conn, 3) } func BenchmarkMultipleQueriesBatchPrepareStatementCache(b *testing.B) { config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) config.BuildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache { return stmtcache.New(conn, stmtcache.ModePrepare, 32) } conn := mustConnect(b, config) defer closeConn(b, conn) benchmarkMultipleQueriesBatch(b, conn, 3) } func BenchmarkMultipleQueriesBatchDescribeStatementCache(b *testing.B) { config := mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")) config.BuildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache { return stmtcache.New(conn, stmtcache.ModeDescribe, 32) } conn := mustConnect(b, config) defer closeConn(b, conn) benchmarkMultipleQueriesBatch(b, conn, 3) } func benchmarkMultipleQueriesBatch(b *testing.B, conn *pgx.Conn, queryCount int) { b.ResetTimer() for i := 0; i < b.N; i++ { batch := &pgx.Batch{} for j := 0; j < queryCount; j++ { batch.Queue("select n from generate_series(0,5) n") } br := conn.SendBatch(context.Background(), batch) for j := 0; j < queryCount; j++ { rows, err := br.Query() if err != nil { b.Fatal(err) } for k := 0; rows.Next(); k++ { var n int if err := rows.Scan(&n); 
err != nil { b.Fatal(err) } if n != k { b.Fatalf("n => %v, want %v", n, k) } } if rows.Err() != nil { b.Fatal(rows.Err()) } } err := br.Close() if err != nil { b.Fatal(err) } } } func BenchmarkSelectManyUnknownEnum(b *testing.B) { conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(b, conn) ctx := context.Background() tx, err := conn.Begin(ctx) require.NoError(b, err) defer tx.Rollback(ctx) _, err = tx.Exec(context.Background(), "drop type if exists color;") require.NoError(b, err) _, err = tx.Exec(ctx, `create type color as enum ('blue', 'green', 'orange')`) require.NoError(b, err) b.ResetTimer() var x, y, z string for i := 0; i < b.N; i++ { rows, err := conn.Query(ctx, "select 'blue'::color, 'green'::color, 'orange'::color from generate_series(1,10)") if err != nil { b.Fatal(err) } for rows.Next() { err = rows.Scan(&x, &y, &z) if err != nil { b.Fatal(err) } if x != "blue" { b.Fatal("unexpected result") } if y != "green" { b.Fatal("unexpected result") } if z != "orange" { b.Fatal("unexpected result") } } if rows.Err() != nil { b.Fatal(rows.Err()) } } } func BenchmarkSelectManyRegisteredEnum(b *testing.B) { conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(b, conn) ctx := context.Background() tx, err := conn.Begin(ctx) require.NoError(b, err) defer tx.Rollback(ctx) _, err = tx.Exec(context.Background(), "drop type if exists color;") require.NoError(b, err) _, err = tx.Exec(ctx, `create type color as enum ('blue', 'green', 'orange')`) require.NoError(b, err) var oid uint32 err = conn.QueryRow(context.Background(), "select oid from pg_type where typname=$1;", "color").Scan(&oid) require.NoError(b, err) et := pgtype.NewEnumType("color", []string{"blue", "green", "orange"}) conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: et, Name: "color", OID: oid}) b.ResetTimer() var x, y, z string for i := 0; i < b.N; i++ { rows, err := conn.Query(ctx, "select 'blue'::color, 'green'::color, 'orange'::color from 
generate_series(1,10)") if err != nil { b.Fatal(err) } for rows.Next() { err = rows.Scan(&x, &y, &z) if err != nil { b.Fatal(err) } if x != "blue" { b.Fatal("unexpected result") } if y != "green" { b.Fatal("unexpected result") } if z != "orange" { b.Fatal("unexpected result") } } if rows.Err() != nil { b.Fatal(rows.Err()) } } } func getSelectRowsCounts(b *testing.B) []int64 { var rowCounts []int64 { s := os.Getenv("PGX_BENCH_SELECT_ROWS_COUNTS") if s != "" { for _, p := range strings.Split(s, " ") { n, err := strconv.ParseInt(p, 10, 64) if err != nil { b.Fatalf("Bad PGX_BENCH_SELECT_ROWS_COUNTS value: %v", err) } rowCounts = append(rowCounts, n) } } } if len(rowCounts) == 0 { rowCounts = []int64{1, 10, 100, 1000} } return rowCounts } type BenchRowSimple struct { ID int32 FirstName string LastName string Sex string BirthDate time.Time Weight int32 Height int32 UpdateTime time.Time } func BenchmarkSelectRowsScanSimple(b *testing.B) { conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(b, conn) rowCounts := getSelectRowsCounts(b) for _, rowCount := range rowCounts { b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) { br := &BenchRowSimple{} for i := 0; i < b.N; i++ { rows, err := conn.Query(context.Background(), "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", rowCount) if err != nil { b.Fatal(err) } for rows.Next() { rows.Scan(&br.ID, &br.FirstName, &br.LastName, &br.Sex, &br.BirthDate, &br.Weight, &br.Height, &br.UpdateTime) } if rows.Err() != nil { b.Fatal(rows.Err()) } } }) } } type BenchRowStringBytes struct { ID int32 FirstName []byte LastName []byte Sex []byte BirthDate time.Time Weight int32 Height int32 UpdateTime time.Time } func BenchmarkSelectRowsScanStringBytes(b *testing.B) { conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(b, conn) rowCounts := getSelectRowsCounts(b) for _, rowCount := 
range rowCounts { b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) { br := &BenchRowStringBytes{} for i := 0; i < b.N; i++ { rows, err := conn.Query(context.Background(), "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", rowCount) if err != nil { b.Fatal(err) } for rows.Next() { rows.Scan(&br.ID, &br.FirstName, &br.LastName, &br.Sex, &br.BirthDate, &br.Weight, &br.Height, &br.UpdateTime) } if rows.Err() != nil { b.Fatal(rows.Err()) } } }) } } type BenchRowDecoder struct { ID pgtype.Int4 FirstName pgtype.Text LastName pgtype.Text Sex pgtype.Text BirthDate pgtype.Date Weight pgtype.Int4 Height pgtype.Int4 UpdateTime pgtype.Timestamptz } func BenchmarkSelectRowsScanDecoder(b *testing.B) { conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(b, conn) rowCounts := getSelectRowsCounts(b) for _, rowCount := range rowCounts { b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) { formats := []struct { name string code int16 }{ {"text", pgx.TextFormatCode}, {"binary", pgx.BinaryFormatCode}, } for _, format := range formats { b.Run(format.name, func(b *testing.B) { br := &BenchRowDecoder{} for i := 0; i < b.N; i++ { rows, err := conn.Query( context.Background(), "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", pgx.QueryResultFormats{format.code}, rowCount, ) if err != nil { b.Fatal(err) } for rows.Next() { rows.Scan(&br.ID, &br.FirstName, &br.LastName, &br.Sex, &br.BirthDate, &br.Weight, &br.Height, &br.UpdateTime) } if rows.Err() != nil { b.Fatal(rows.Err()) } } }) } }) } } func BenchmarkSelectRowsExplicitDecoding(b *testing.B) { conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(b, conn) rowCounts := getSelectRowsCounts(b) for _, rowCount := range rowCounts { b.Run(fmt.Sprintf("%d rows", rowCount), 
func(b *testing.B) { br := &BenchRowDecoder{} for i := 0; i < b.N; i++ { rows, err := conn.Query(context.Background(), "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", rowCount) if err != nil { b.Fatal(err) } for rows.Next() { rawValues := rows.RawValues() err = br.ID.DecodeBinary(conn.ConnInfo(), rawValues[0]) if err != nil { b.Fatal(err) } err = br.FirstName.DecodeText(conn.ConnInfo(), rawValues[1]) if err != nil { b.Fatal(err) } err = br.LastName.DecodeText(conn.ConnInfo(), rawValues[2]) if err != nil { b.Fatal(err) } err = br.Sex.DecodeText(conn.ConnInfo(), rawValues[3]) if err != nil { b.Fatal(err) } err = br.BirthDate.DecodeBinary(conn.ConnInfo(), rawValues[4]) if err != nil { b.Fatal(err) } err = br.Weight.DecodeBinary(conn.ConnInfo(), rawValues[5]) if err != nil { b.Fatal(err) } err = br.Height.DecodeBinary(conn.ConnInfo(), rawValues[6]) if err != nil { b.Fatal(err) } err = br.UpdateTime.DecodeBinary(conn.ConnInfo(), rawValues[7]) if err != nil { b.Fatal(err) } } if rows.Err() != nil { b.Fatal(rows.Err()) } } }) } } func BenchmarkSelectRowsPgConnExecText(b *testing.B) { conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(b, conn) rowCounts := getSelectRowsCounts(b) for _, rowCount := range rowCounts { b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) { for i := 0; i < b.N; i++ { mrr := conn.PgConn().Exec(context.Background(), fmt.Sprintf("select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + %d) n", rowCount)) for mrr.NextResult() { rr := mrr.ResultReader() for rr.NextRow() { rr.Values() } } err := mrr.Close() if err != nil { b.Fatal(err) } } }) } } func BenchmarkSelectRowsPgConnExecParams(b *testing.B) { conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(b, conn) rowCounts := getSelectRowsCounts(b) for 
_, rowCount := range rowCounts { b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) { formats := []struct { name string code int16 }{ {"text", pgx.TextFormatCode}, {"binary - mostly", pgx.BinaryFormatCode}, } for _, format := range formats { b.Run(format.name, func(b *testing.B) { for i := 0; i < b.N; i++ { rr := conn.PgConn().ExecParams( context.Background(), "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", [][]byte{[]byte(strconv.FormatInt(rowCount, 10))}, nil, nil, []int16{format.code, pgx.TextFormatCode, pgx.TextFormatCode, pgx.TextFormatCode, format.code, format.code, format.code, format.code}, ) for rr.NextRow() { rr.Values() } _, err := rr.Close() if err != nil { b.Fatal(err) } } }) } }) } } func BenchmarkSelectRowsPgConnExecPrepared(b *testing.B) { conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(b, conn) rowCounts := getSelectRowsCounts(b) _, err := conn.PgConn().Prepare(context.Background(), "ps1", "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", nil) if err != nil { b.Fatal(err) } for _, rowCount := range rowCounts { b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) { formats := []struct { name string code int16 }{ {"text", pgx.TextFormatCode}, {"binary - mostly", pgx.BinaryFormatCode}, } for _, format := range formats { b.Run(format.name, func(b *testing.B) { for i := 0; i < b.N; i++ { rr := conn.PgConn().ExecPrepared( context.Background(), "ps1", [][]byte{[]byte(strconv.FormatInt(rowCount, 10))}, nil, []int16{format.code, pgx.TextFormatCode, pgx.TextFormatCode, pgx.TextFormatCode, format.code, format.code, format.code, format.code}, ) for rr.NextRow() { rr.Values() } _, err := rr.Close() if err != nil { b.Fatal(err) } } }) } }) } } type queryRecorder struct { conn net.Conn writeBuf []byte readCount 
int
}

// Read delegates to the wrapped connection and counts the total bytes read,
// so the benchmark later knows exactly how many response bytes to expect.
func (qr *queryRecorder) Read(b []byte) (n int, err error) {
	n, err = qr.conn.Read(b)
	qr.readCount += n
	return n, err
}

// Write records every outgoing byte (the wire-format query) before
// forwarding it, so the benchmark can replay the exact request later.
func (qr *queryRecorder) Write(b []byte) (n int, err error) {
	qr.writeBuf = append(qr.writeBuf, b...)
	return qr.conn.Write(b)
}

// The remaining methods simply delegate to the underlying net.Conn so that
// queryRecorder satisfies the net.Conn interface.

func (qr *queryRecorder) Close() error {
	return qr.conn.Close()
}

func (qr *queryRecorder) LocalAddr() net.Addr {
	return qr.conn.LocalAddr()
}

func (qr *queryRecorder) RemoteAddr() net.Addr {
	return qr.conn.RemoteAddr()
}

func (qr *queryRecorder) SetDeadline(t time.Time) error {
	return qr.conn.SetDeadline(t)
}

func (qr *queryRecorder) SetReadDeadline(t time.Time) error {
	return qr.conn.SetReadDeadline(t)
}

func (qr *queryRecorder) SetWriteDeadline(t time.Time) error {
	return qr.conn.SetWriteDeadline(t)
}

// BenchmarkSelectRowsRawPrepared hijacks a pgconn connection and inserts a queryRecorder. It then executes the query
// once. The benchmark is simply sending the exact query bytes over the wire to the server and reading the expected
// number of bytes back. It does nothing else. This should be the theoretical maximum performance a Go application
// could achieve.
func BenchmarkSelectRowsRawPrepared(b *testing.B) { rowCounts := getSelectRowsCounts(b) for _, rowCount := range rowCounts { b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) { formats := []struct { name string code int16 }{ {"text", pgx.TextFormatCode}, {"binary - mostly", pgx.BinaryFormatCode}, } for _, format := range formats { b.Run(format.name, func(b *testing.B) { conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE")).PgConn() defer conn.Close(context.Background()) _, err := conn.Prepare(context.Background(), "ps1", "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", nil) if err != nil { b.Fatal(err) } hijackedConn, err := conn.Hijack() require.NoError(b, err) qr := &queryRecorder{ conn: hijackedConn.Conn, } hijackedConn.Conn = qr hijackedConn.Frontend = hijackedConn.Config.BuildFrontend(qr, qr) conn, err = pgconn.Construct(hijackedConn) require.NoError(b, err) { rr := conn.ExecPrepared( context.Background(), "ps1", [][]byte{[]byte(strconv.FormatInt(rowCount, 10))}, nil, []int16{format.code, pgx.TextFormatCode, pgx.TextFormatCode, pgx.TextFormatCode, format.code, format.code, format.code, format.code}, ) _, err := rr.Close() require.NoError(b, err) } buf := make([]byte, qr.readCount) b.ResetTimer() for i := 0; i < b.N; i++ { _, err := qr.conn.Write(qr.writeBuf) if err != nil { b.Fatal(err) } _, err = io.ReadFull(qr.conn, buf) if err != nil { b.Fatal(err) } } }) } }) } } pgx-4.18.1/ci/000077500000000000000000000000001437725773200130175ustar00rootroot00000000000000pgx-4.18.1/ci/setup_test.bash000077500000000000000000000045001437725773200160570ustar00rootroot00000000000000#!/usr/bin/env bash set -eux if [[ "${PGVERSION-}" =~ ^[0-9.]+$ ]] then sudo apt-get remove -y --purge postgresql libpq-dev libpq5 postgresql-client-common postgresql-common sudo rm -rf /var/lib/postgresql wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc | 
sudo apt-key add - sudo sh -c "echo deb http://apt.postgresql.org/pub/repos/apt/ $(lsb_release -cs)-pgdg main $PGVERSION >> /etc/apt/sources.list.d/postgresql.list" sudo apt-get update -qq sudo apt-get -y -o Dpkg::Options::=--force-confdef -o Dpkg::Options::="--force-confnew" install postgresql-$PGVERSION postgresql-server-dev-$PGVERSION postgresql-contrib-$PGVERSION sudo chmod 777 /etc/postgresql/$PGVERSION/main/pg_hba.conf echo "local all postgres trust" > /etc/postgresql/$PGVERSION/main/pg_hba.conf echo "local all all trust" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf echo "host all pgx_md5 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf sudo chmod 777 /etc/postgresql/$PGVERSION/main/postgresql.conf if $(dpkg --compare-versions $PGVERSION ge 9.6) ; then echo "wal_level='logical'" >> /etc/postgresql/$PGVERSION/main/postgresql.conf echo "max_wal_senders=5" >> /etc/postgresql/$PGVERSION/main/postgresql.conf echo "max_replication_slots=5" >> /etc/postgresql/$PGVERSION/main/postgresql.conf fi sudo /etc/init.d/postgresql restart psql -U postgres -c 'create database pgx_test' psql -U postgres pgx_test -c 'create extension hstore' psql -U postgres pgx_test -c 'create domain uint64 as numeric(20,0)' psql -U postgres -c "create user pgx_md5 SUPERUSER PASSWORD 'secret'" psql -U postgres -c "create user `whoami`" fi if [[ "${PGVERSION-}" =~ ^cockroach ]] then wget -qO- https://binaries.cockroachdb.com/cockroach-v20.2.5.linux-amd64.tgz | tar xvz sudo mv cockroach-v20.2.5.linux-amd64/cockroach /usr/local/bin/ cockroach start-single-node --insecure --background --listen-addr=localhost cockroach sql --insecure -e 'create database pgx_test' fi if [ "${CRATEVERSION-}" != "" ] then docker run \ -p "6543:5432" \ -d \ crate:"$CRATEVERSION" \ crate \ -Cnetwork.host=0.0.0.0 \ -Ctransport.host=localhost \ -Clicense.enterprise=false fi pgx-4.18.1/conn.go000066400000000000000000000647031437725773200137220ustar00rootroot00000000000000package pgx import ( "context" 
"errors" "fmt" "strconv" "strings" "time" "github.com/jackc/pgconn" "github.com/jackc/pgconn/stmtcache" "github.com/jackc/pgproto3/v2" "github.com/jackc/pgtype" "github.com/jackc/pgx/v4/internal/sanitize" ) // ConnConfig contains all the options used to establish a connection. It must be created by ParseConfig and // then it can be modified. A manually initialized ConnConfig will cause ConnectConfig to panic. type ConnConfig struct { pgconn.Config Logger Logger LogLevel LogLevel // Original connection string that was parsed into config. connString string // BuildStatementCache creates the stmtcache.Cache implementation for connections created with this config. Set // to nil to disable automatic prepared statements. BuildStatementCache BuildStatementCacheFunc // PreferSimpleProtocol disables implicit prepared statement usage. By default pgx automatically uses the extended // protocol. This can improve performance due to being able to use the binary format. It also does not rely on client // side parameter sanitization. However, it does incur two round-trips per query (unless using a prepared statement) // and may be incompatible proxies such as PGBouncer. Setting PreferSimpleProtocol causes the simple protocol to be // used by default. The same functionality can be controlled on a per query basis by setting // QueryExOptions.SimpleProtocol. PreferSimpleProtocol bool createdByParseConfig bool // Used to enforce created by ParseConfig rule. } // Copy returns a deep copy of the config that is safe to use and modify. // The only exception is the tls.Config: // according to the tls.Config docs it must not be modified after creation. func (cc *ConnConfig) Copy() *ConnConfig { newConfig := new(ConnConfig) *newConfig = *cc newConfig.Config = *newConfig.Config.Copy() return newConfig } // ConnString returns the connection string as parsed by pgx.ParseConfig into pgx.ConnConfig. 
func (cc *ConnConfig) ConnString() string { return cc.connString } // BuildStatementCacheFunc is a function that can be used to create a stmtcache.Cache implementation for connection. type BuildStatementCacheFunc func(conn *pgconn.PgConn) stmtcache.Cache // Conn is a PostgreSQL connection handle. It is not safe for concurrent usage. Use a connection pool to manage access // to multiple database connections from multiple goroutines. type Conn struct { pgConn *pgconn.PgConn config *ConnConfig // config used when establishing this connection preparedStatements map[string]*pgconn.StatementDescription stmtcache stmtcache.Cache logger Logger logLevel LogLevel notifications []*pgconn.Notification doneChan chan struct{} closedChan chan error connInfo *pgtype.ConnInfo wbuf []byte eqb extendedQueryBuilder } // Identifier a PostgreSQL identifier or name. Identifiers can be composed of // multiple parts such as ["schema", "table"] or ["table", "column"]. type Identifier []string // Sanitize returns a sanitized string safe for SQL interpolation. func (ident Identifier) Sanitize() string { parts := make([]string, len(ident)) for i := range ident { s := strings.ReplaceAll(ident[i], string([]byte{0}), "") parts[i] = `"` + strings.ReplaceAll(s, `"`, `""`) + `"` } return strings.Join(parts, ".") } // ErrNoRows occurs when rows are expected but none are returned. var ErrNoRows = errors.New("no rows in result set") // ErrInvalidLogLevel occurs on attempt to set an invalid log level. var ErrInvalidLogLevel = errors.New("invalid log level") // Connect establishes a connection with a PostgreSQL server with a connection string. See // pgconn.Connect for details. func Connect(ctx context.Context, connString string) (*Conn, error) { connConfig, err := ParseConfig(connString) if err != nil { return nil, err } return connect(ctx, connConfig) } // ConnectConfig establishes a connection with a PostgreSQL server with a configuration struct. // connConfig must have been created by ParseConfig. 
func ConnectConfig(ctx context.Context, connConfig *ConnConfig) (*Conn, error) { return connect(ctx, connConfig) } // ParseConfig creates a ConnConfig from a connection string. ParseConfig handles all options that pgconn.ParseConfig // does. In addition, it accepts the following options: // // statement_cache_capacity // The maximum size of the automatic statement cache. Set to 0 to disable automatic statement caching. Default: 512. // // statement_cache_mode // Possible values: "prepare" and "describe". "prepare" will create prepared statements on the PostgreSQL server. // "describe" will use the anonymous prepared statement to describe a statement without creating a statement on the // server. "describe" is primarily useful when the environment does not allow prepared statements such as when // running a connection pooler like PgBouncer. Default: "prepare" // // prefer_simple_protocol // Possible values: "true" and "false". Use the simple protocol instead of extended protocol. Default: false func ParseConfig(connString string) (*ConnConfig, error) { config, err := pgconn.ParseConfig(connString) if err != nil { return nil, err } var buildStatementCache BuildStatementCacheFunc statementCacheCapacity := 512 statementCacheMode := stmtcache.ModePrepare if s, ok := config.RuntimeParams["statement_cache_capacity"]; ok { delete(config.RuntimeParams, "statement_cache_capacity") n, err := strconv.ParseInt(s, 10, 32) if err != nil { return nil, fmt.Errorf("cannot parse statement_cache_capacity: %w", err) } statementCacheCapacity = int(n) } if s, ok := config.RuntimeParams["statement_cache_mode"]; ok { delete(config.RuntimeParams, "statement_cache_mode") switch s { case "prepare": statementCacheMode = stmtcache.ModePrepare case "describe": statementCacheMode = stmtcache.ModeDescribe default: return nil, fmt.Errorf("invalid statement_cache_mod: %s", s) } } if statementCacheCapacity > 0 { buildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache { return 
stmtcache.New(conn, statementCacheMode, statementCacheCapacity) } } preferSimpleProtocol := false if s, ok := config.RuntimeParams["prefer_simple_protocol"]; ok { delete(config.RuntimeParams, "prefer_simple_protocol") if b, err := strconv.ParseBool(s); err == nil { preferSimpleProtocol = b } else { return nil, fmt.Errorf("invalid prefer_simple_protocol: %v", err) } } connConfig := &ConnConfig{ Config: *config, createdByParseConfig: true, LogLevel: LogLevelInfo, BuildStatementCache: buildStatementCache, PreferSimpleProtocol: preferSimpleProtocol, connString: connString, } return connConfig, nil } func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) { // Default values are set in ParseConfig. Enforce initial creation by ParseConfig rather than setting defaults from // zero values. if !config.createdByParseConfig { panic("config must be created by ParseConfig") } originalConfig := config // This isn't really a deep copy. But it is enough to avoid the config.Config.OnNotification mutation from affecting // other connections with the same config. See https://github.com/jackc/pgx/issues/618. { configCopy := *config config = &configCopy } c = &Conn{ config: originalConfig, connInfo: pgtype.NewConnInfo(), logLevel: config.LogLevel, logger: config.Logger, } // Only install pgx notification system if no other callback handler is present. 
if config.Config.OnNotification == nil { config.Config.OnNotification = c.bufferNotifications } else { if c.shouldLog(LogLevelDebug) { c.log(ctx, LogLevelDebug, "pgx notification handler disabled by application supplied OnNotification", map[string]interface{}{"host": config.Config.Host}) } } if c.shouldLog(LogLevelInfo) { c.log(ctx, LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{"host": config.Config.Host}) } c.pgConn, err = pgconn.ConnectConfig(ctx, &config.Config) if err != nil { if c.shouldLog(LogLevelError) { c.log(ctx, LogLevelError, "connect failed", map[string]interface{}{"err": err}) } return nil, err } c.preparedStatements = make(map[string]*pgconn.StatementDescription) c.doneChan = make(chan struct{}) c.closedChan = make(chan error) c.wbuf = make([]byte, 0, 1024) if c.config.BuildStatementCache != nil { c.stmtcache = c.config.BuildStatementCache(c.pgConn) } // Replication connections can't execute the queries to // populate the c.PgTypes and c.pgsqlAfInet if _, ok := config.Config.RuntimeParams["replication"]; ok { return c, nil } return c, nil } // Close closes a connection. It is safe to call Close on a already closed // connection. func (c *Conn) Close(ctx context.Context) error { if c.IsClosed() { return nil } err := c.pgConn.Close(ctx) if c.shouldLog(LogLevelInfo) { c.log(ctx, LogLevelInfo, "closed connection", nil) } return err } // Prepare creates a prepared statement with name and sql. sql can contain placeholders // for bound parameters. These placeholders are referenced positional as $1, $2, etc. // // Prepare is idempotent; i.e. it is safe to call Prepare multiple times with the same // name and sql arguments. This allows a code path to Prepare and Query/Exec without // concern for if the statement has already been prepared. 
func (c *Conn) Prepare(ctx context.Context, name, sql string) (sd *pgconn.StatementDescription, err error) {
	// Idempotency: if a statement with this name and identical SQL is already
	// prepared, reuse it without a server round trip.
	if name != "" {
		var ok bool
		if sd, ok = c.preparedStatements[name]; ok && sd.SQL == sql {
			return sd, nil
		}
	}

	if c.shouldLog(LogLevelError) {
		defer func() {
			if err != nil {
				c.log(ctx, LogLevelError, "Prepare failed", map[string]interface{}{"err": err, "name": name, "sql": sql})
			}
		}()
	}

	sd, err = c.pgConn.Prepare(ctx, name, sql, nil)
	if err != nil {
		return nil, err
	}

	// Only named statements are cached; the anonymous statement ("") is
	// transient by nature.
	if name != "" {
		c.preparedStatements[name] = sd
	}

	return sd, nil
}

// Deallocate releases a prepared statement, removing it from the local cache
// and deallocating it on the server.
func (c *Conn) Deallocate(ctx context.Context, name string) error {
	delete(c.preparedStatements, name)
	_, err := c.pgConn.Exec(ctx, "deallocate "+quoteIdentifier(name)).ReadAll()
	return err
}

// bufferNotifications is the pgconn OnNotification callback installed by
// connect; it queues notifications for retrieval by WaitForNotification.
func (c *Conn) bufferNotifications(_ *pgconn.PgConn, n *pgconn.Notification) {
	c.notifications = append(c.notifications, n)
}

// WaitForNotification waits for a PostgreSQL notification. It wraps the underlying pgconn notification system in a
// slightly more convenient form.
func (c *Conn) WaitForNotification(ctx context.Context) (*pgconn.Notification, error) {
	var n *pgconn.Notification

	// Return already received notification immediately
	if len(c.notifications) > 0 {
		n = c.notifications[0]
		c.notifications = c.notifications[1:]
		return n, nil
	}

	err := c.pgConn.WaitForNotification(ctx)
	// A notification may have arrived even when err is non-nil; drain one if
	// present and return it alongside the error.
	if len(c.notifications) > 0 {
		n = c.notifications[0]
		c.notifications = c.notifications[1:]
	}
	return n, err
}

// IsClosed reports if the connection has been closed.
func (c *Conn) IsClosed() bool {
	return c.pgConn.IsClosed()
}

// die hard-closes the underlying connection using an already-cancelled
// context so the close does not block.
func (c *Conn) die(err error) {
	if c.IsClosed() {
		return
	}

	ctx, cancel := context.WithCancel(context.Background())
	cancel() // force immediate hard cancel
	c.pgConn.Close(ctx)
}

// shouldLog reports whether a message at lvl would be emitted given the
// configured logger and log level.
func (c *Conn) shouldLog(lvl LogLevel) bool {
	return c.logger != nil && c.logLevel >= lvl
}

// log emits a structured log message, attaching the backend PID when known.
func (c *Conn) log(ctx context.Context, lvl LogLevel, msg string, data map[string]interface{}) {
	if data == nil {
		data = map[string]interface{}{}
	}
	if c.pgConn != nil && c.pgConn.PID() != 0 {
		data["pid"] = c.pgConn.PID()
	}

	c.logger.Log(ctx, lvl, msg, data)
}

// quoteIdentifier double-quotes s for safe use as a SQL identifier,
// escaping embedded double quotes by doubling them.
func quoteIdentifier(s string) string {
	return `"` + strings.ReplaceAll(s, `"`, `""`) + `"`
}

// Ping executes an empty sql statement against the *Conn
// If the sql returns without error, the database Ping is considered successful, otherwise, the error is returned.
func (c *Conn) Ping(ctx context.Context) error {
	_, err := c.Exec(ctx, ";")
	return err
}

// PgConn returns the underlying *pgconn.PgConn. This is an escape hatch method that allows lower level access to the
// PostgreSQL connection than pgx exposes.
//
// It is strongly recommended that the connection be idle (no in-progress queries) before the underlying *pgconn.PgConn
// is used and the connection must be returned to the same state before any *pgx.Conn methods are again used.
func (c *Conn) PgConn() *pgconn.PgConn { return c.pgConn }

// StatementCache returns the statement cache used for this connection.
func (c *Conn) StatementCache() stmtcache.Cache { return c.stmtcache }

// ConnInfo returns the connection info used for this connection.
func (c *Conn) ConnInfo() *pgtype.ConnInfo { return c.connInfo }

// Config returns a copy of config that was used to establish this connection.
func (c *Conn) Config() *ConnConfig { return c.config.Copy() }

// Exec executes sql. sql can be either a prepared statement name or an SQL string. arguments should be referenced
// positionally from the sql string as $1, $2, etc.
func (c *Conn) Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) {
	startTime := time.Now()

	commandTag, err := c.exec(ctx, sql, arguments...)
	if err != nil {
		if c.shouldLog(LogLevelError) {
			endTime := time.Now()
			c.log(ctx, LogLevelError, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "err": err, "time": endTime.Sub(startTime)})
		}
		return commandTag, err
	}

	if c.shouldLog(LogLevelInfo) {
		endTime := time.Now()
		c.log(ctx, LogLevelInfo, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "time": endTime.Sub(startTime), "commandTag": commandTag})
	}

	return commandTag, err
}

// exec chooses the execution strategy for a single statement: already-prepared
// statement, simple protocol, statement cache, or a one-off anonymous prepare,
// in that order.
func (c *Conn) exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) {
	simpleProtocol := c.config.PreferSimpleProtocol

	// Peel option arguments (currently only QuerySimpleProtocol) off the front.
optionLoop:
	for len(arguments) > 0 {
		switch arg := arguments[0].(type) {
		case QuerySimpleProtocol:
			simpleProtocol = bool(arg)
			arguments = arguments[1:]
		default:
			break optionLoop
		}
	}

	// sql may be the name of an already-prepared statement.
	if sd, ok := c.preparedStatements[sql]; ok {
		return c.execPrepared(ctx, sd, arguments)
	}

	if simpleProtocol {
		return c.execSimpleProtocol(ctx, sql, arguments)
	}

	// With no arguments there is nothing to bind, so the simple protocol is used.
	if len(arguments) == 0 {
		return c.execSimpleProtocol(ctx, sql, arguments)
	}

	if c.stmtcache != nil {
		sd, err := c.stmtcache.Get(ctx, sql)
		if err != nil {
			return nil, err
		}

		// In describe mode the cached description is replayed via ExecParams
		// (full SQL resent) rather than by prepared-statement name.
		if c.stmtcache.Mode() == stmtcache.ModeDescribe {
			return c.execParams(ctx, sd, arguments)
		}
		return c.execPrepared(ctx, sd, arguments)
	}

	// No cache configured: prepare anonymously for this execution only.
	sd, err := c.Prepare(ctx, "", sql)
	if err != nil {
		return nil, err
	}
	return c.execPrepared(ctx, sd, arguments)
}

// execSimpleProtocol runs sql via the simple query protocol, first interpolating
// any arguments into the SQL text via sanitizeForSimpleQuery.
func (c *Conn) execSimpleProtocol(ctx context.Context, sql string, arguments []interface{}) (commandTag pgconn.CommandTag, err error) {
	if len(arguments) > 0 {
		sql, err = c.sanitizeForSimpleQuery(sql, arguments...)
		if err != nil {
			return nil, err
		}
	}

	mrr := c.pgConn.Exec(ctx, sql)
	// Drain every result; the command tag of the last statement wins.
	for mrr.NextResult() {
		commandTag, err = mrr.ResultReader().Close()
	}
	err = mrr.Close()
	return commandTag, err
}

// execParamsAndPreparedPrefix performs the work shared by execParams and
// execPrepared: argument count validation and encoding of parameter values and
// result formats into c.eqb.
func (c *Conn) execParamsAndPreparedPrefix(sd *pgconn.StatementDescription, arguments []interface{}) error {
	if len(sd.ParamOIDs) != len(arguments) {
		return fmt.Errorf("expected %d arguments, got %d", len(sd.ParamOIDs), len(arguments))
	}

	c.eqb.Reset()

	args, err := convertDriverValuers(arguments)
	if err != nil {
		return err
	}

	for i := range args {
		err = c.eqb.AppendParam(c.connInfo, sd.ParamOIDs[i], args[i])
		if err != nil {
			return err
		}
	}

	for i := range sd.Fields {
		c.eqb.AppendResultFormat(c.ConnInfo().ResultFormatCodeForOID(sd.Fields[i].DataTypeOID))
	}

	return nil
}

// execParams executes sd by resending its full SQL with bound parameters (the
// path used when the statement was described but not prepared server-side).
func (c *Conn) execParams(ctx context.Context, sd *pgconn.StatementDescription, arguments []interface{}) (pgconn.CommandTag, error) {
	err := c.execParamsAndPreparedPrefix(sd, arguments)
	if err != nil {
		return nil, err
	}

	result := c.pgConn.ExecParams(ctx, sd.SQL, c.eqb.paramValues, sd.ParamOIDs, c.eqb.paramFormats, c.eqb.resultFormats).Read()
	c.eqb.Reset() // Allow c.eqb internal memory to be GC'ed as soon as possible.
	return result.CommandTag, result.Err
}

// execPrepared executes the server-side prepared statement sd with bound parameters.
func (c *Conn) execPrepared(ctx context.Context, sd *pgconn.StatementDescription, arguments []interface{}) (pgconn.CommandTag, error) {
	err := c.execParamsAndPreparedPrefix(sd, arguments)
	if err != nil {
		return nil, err
	}

	result := c.pgConn.ExecPrepared(ctx, sd.Name, c.eqb.paramValues, c.eqb.paramFormats, c.eqb.resultFormats).Read()
	c.eqb.Reset() // Allow c.eqb internal memory to be GC'ed as soon as possible.
	return result.CommandTag, result.Err
}

// getRows constructs a connRows describing a query that is about to be executed
// on this connection. The result reader is attached later by the caller.
func (c *Conn) getRows(ctx context.Context, sql string, args []interface{}) *connRows {
	r := &connRows{}

	r.ctx = ctx
	r.logger = c
	r.connInfo = c.connInfo
	r.startTime = time.Now()
	r.sql = sql
	r.args = args
	r.conn = c

	return r
}

// QuerySimpleProtocol controls whether the simple or extended protocol is used to send the query.
type QuerySimpleProtocol bool

// QueryResultFormats controls the result format (text=0, binary=1) of a query by result column position.
type QueryResultFormats []int16

// QueryResultFormatsByOID controls the result format (text=0, binary=1) of a query by the result column OID.
type QueryResultFormatsByOID map[uint32]int16

// Query sends a query to the server and returns a Rows to read the results. Only errors encountered sending the query
// and initializing Rows will be returned. Err() on the returned Rows must be checked after the Rows is closed to
// determine if the query executed successfully.
//
// The returned Rows must be closed before the connection can be used again. It is safe to attempt to read from the
// returned Rows even if an error is returned. The error will be available in rows.Err() after rows are closed. It
// is allowed to ignore the error returned from Query and handle it in Rows.
//
// Err() on the returned Rows must be checked after the Rows is closed to determine if the query executed successfully
// as some errors can only be detected by reading the entire response. e.g. A divide by zero error on the last row.
//
// For extra control over how the query is executed, the types QuerySimpleProtocol, QueryResultFormats, and
// QueryResultFormatsByOID may be used as the first args to control exactly how the query is executed. This is rarely
// needed. See the documentation for those types for details.
func (c *Conn) Query(ctx context.Context, sql string, args ...interface{}) (Rows, error) {
	var resultFormats QueryResultFormats
	var resultFormatsByOID QueryResultFormatsByOID
	simpleProtocol := c.config.PreferSimpleProtocol

	// Peel option arguments (QueryResultFormats, QueryResultFormatsByOID,
	// QuerySimpleProtocol) off the front of args.
optionLoop:
	for len(args) > 0 {
		switch arg := args[0].(type) {
		case QueryResultFormats:
			resultFormats = arg
			args = args[1:]
		case QueryResultFormatsByOID:
			resultFormatsByOID = arg
			args = args[1:]
		case QuerySimpleProtocol:
			simpleProtocol = bool(arg)
			args = args[1:]
		default:
			break optionLoop
		}
	}

	rows := c.getRows(ctx, sql, args)

	var err error
	// sql may name an already-prepared statement, in which case the extended
	// protocol is used even when simpleProtocol is set.
	sd, ok := c.preparedStatements[sql]

	if simpleProtocol && !ok {
		// Simple protocol path: interpolate args into the SQL text.
		sql, err = c.sanitizeForSimpleQuery(sql, args...)
		if err != nil {
			rows.fatal(err)
			return rows, err
		}

		mrr := c.pgConn.Exec(ctx, sql)
		if mrr.NextResult() {
			rows.resultReader = mrr.ResultReader()
			rows.multiResultReader = mrr
		} else {
			err = mrr.Close()
			rows.fatal(err)
			return rows, err
		}

		return rows, nil
	}

	c.eqb.Reset()

	// No prepared statement yet: get one from the cache, or prepare anonymously.
	if !ok {
		if c.stmtcache != nil {
			sd, err = c.stmtcache.Get(ctx, sql)
			if err != nil {
				rows.fatal(err)
				return rows, rows.err
			}
		} else {
			sd, err = c.pgConn.Prepare(ctx, "", sql, nil)
			if err != nil {
				rows.fatal(err)
				return rows, rows.err
			}
		}
	}

	if len(sd.ParamOIDs) != len(args) {
		rows.fatal(fmt.Errorf("expected %d arguments, got %d", len(sd.ParamOIDs), len(args)))
		return rows, rows.err
	}

	rows.sql = sd.SQL

	args, err = convertDriverValuers(args)
	if err != nil {
		rows.fatal(err)
		return rows, rows.err
	}

	// Encode each argument using the parameter OIDs from the statement description.
	for i := range args {
		err = c.eqb.AppendParam(c.connInfo, sd.ParamOIDs[i], args[i])
		if err != nil {
			rows.fatal(err)
			return rows, rows.err
		}
	}

	// Explicit per-OID result formats take precedence; missing OIDs map to the
	// zero value (text format).
	if resultFormatsByOID != nil {
		resultFormats = make([]int16, len(sd.Fields))
		for i := range resultFormats {
			resultFormats[i] = resultFormatsByOID[uint32(sd.Fields[i].DataTypeOID)]
		}
	}

	// Otherwise fall back to the connection's default format for each column OID.
	if resultFormats == nil {
		for i := range sd.Fields {
			c.eqb.AppendResultFormat(c.ConnInfo().ResultFormatCodeForOID(sd.Fields[i].DataTypeOID))
		}

		resultFormats = c.eqb.resultFormats
	}

	// Describe-mode cached statements are executed by resending the full SQL;
	// everything else executes by prepared-statement name.
	if c.stmtcache != nil && c.stmtcache.Mode() == stmtcache.ModeDescribe && !ok {
		rows.resultReader = c.pgConn.ExecParams(ctx, sql, c.eqb.paramValues, sd.ParamOIDs, c.eqb.paramFormats, resultFormats)
	} else {
		rows.resultReader = c.pgConn.ExecPrepared(ctx, sd.Name, c.eqb.paramValues, c.eqb.paramFormats, resultFormats)
	}

	c.eqb.Reset() // Allow c.eqb internal memory to be GC'ed as soon as possible.

	return rows, rows.err
}

// QueryRow is a convenience wrapper over Query. Any error that occurs while
// querying is deferred until calling Scan on the returned Row. That Row will
// error with ErrNoRows if no rows are returned.
func (c *Conn) QueryRow(ctx context.Context, sql string, args ...interface{}) Row {
	rows, _ := c.Query(ctx, sql, args...)
	return (*connRow)(rows.(*connRows))
}

// QueryFuncRow is the argument to the QueryFunc callback function.
//
// QueryFuncRow is an interface instead of a struct to allow tests to mock QueryFunc. However, adding a method to an
// interface is technically a breaking change. Because of this the QueryFuncRow interface is partially excluded from
// semantic version requirements. Methods will not be removed or changed, but new methods may be added.
type QueryFuncRow interface {
	FieldDescriptions() []pgproto3.FieldDescription

	// RawValues returns the unparsed bytes of the row values. The returned [][]byte is only valid during the current
	// function call. However, the underlying byte data is safe to retain a reference to and mutate.
	RawValues() [][]byte
}

// QueryFunc executes sql with args. For each row returned by the query the values will be scanned into the elements of
// scans and f will be called. If any row fails to scan or f returns an error the query will be aborted and the error
// will be returned.
func (c *Conn) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) {
	rows, err := c.Query(ctx, sql, args...)
if err != nil { return nil, err } defer rows.Close() for rows.Next() { err = rows.Scan(scans...) if err != nil { return nil, err } err = f(rows) if err != nil { return nil, err } } if err := rows.Err(); err != nil { return nil, err } return rows.CommandTag(), nil } // SendBatch sends all queued queries to the server at once. All queries are run in an implicit transaction unless // explicit transaction control statements are executed. The returned BatchResults must be closed before the connection // is used again. func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults { startTime := time.Now() simpleProtocol := c.config.PreferSimpleProtocol var sb strings.Builder if simpleProtocol { for i, bi := range b.items { if i > 0 { sb.WriteByte(';') } sql, err := c.sanitizeForSimpleQuery(bi.query, bi.arguments...) if err != nil { return &batchResults{ctx: ctx, conn: c, err: err} } sb.WriteString(sql) } mrr := c.pgConn.Exec(ctx, sb.String()) return &batchResults{ ctx: ctx, conn: c, mrr: mrr, b: b, ix: 0, } } distinctUnpreparedQueries := map[string]struct{}{} for _, bi := range b.items { if _, ok := c.preparedStatements[bi.query]; ok { continue } distinctUnpreparedQueries[bi.query] = struct{}{} } var stmtCache stmtcache.Cache if len(distinctUnpreparedQueries) > 0 { if c.stmtcache != nil && c.stmtcache.Cap() >= len(distinctUnpreparedQueries) { stmtCache = c.stmtcache } else { stmtCache = stmtcache.New(c.pgConn, stmtcache.ModeDescribe, len(distinctUnpreparedQueries)) } for sql, _ := range distinctUnpreparedQueries { _, err := stmtCache.Get(ctx, sql) if err != nil { return &batchResults{ctx: ctx, conn: c, err: err} } } } batch := &pgconn.Batch{} for _, bi := range b.items { c.eqb.Reset() sd := c.preparedStatements[bi.query] if sd == nil { var err error sd, err = stmtCache.Get(ctx, bi.query) if err != nil { return c.logBatchResults(ctx, startTime, &batchResults{ctx: ctx, conn: c, err: err}) } } if len(sd.ParamOIDs) != len(bi.arguments) { return c.logBatchResults(ctx, 
startTime, &batchResults{ctx: ctx, conn: c, err: fmt.Errorf("mismatched param and argument count")}) } args, err := convertDriverValuers(bi.arguments) if err != nil { return c.logBatchResults(ctx, startTime, &batchResults{ctx: ctx, conn: c, err: err}) } for i := range args { err = c.eqb.AppendParam(c.connInfo, sd.ParamOIDs[i], args[i]) if err != nil { return c.logBatchResults(ctx, startTime, &batchResults{ctx: ctx, conn: c, err: err}) } } for i := range sd.Fields { c.eqb.AppendResultFormat(c.ConnInfo().ResultFormatCodeForOID(sd.Fields[i].DataTypeOID)) } if sd.Name == "" { batch.ExecParams(bi.query, c.eqb.paramValues, sd.ParamOIDs, c.eqb.paramFormats, c.eqb.resultFormats) } else { batch.ExecPrepared(sd.Name, c.eqb.paramValues, c.eqb.paramFormats, c.eqb.resultFormats) } } c.eqb.Reset() // Allow c.eqb internal memory to be GC'ed as soon as possible. mrr := c.pgConn.ExecBatch(ctx, batch) return c.logBatchResults(ctx, startTime, &batchResults{ ctx: ctx, conn: c, mrr: mrr, b: b, ix: 0, }) } func (c *Conn) logBatchResults(ctx context.Context, startTime time.Time, results *batchResults) BatchResults { if results.err != nil { if c.shouldLog(LogLevelError) { endTime := time.Now() c.log(ctx, LogLevelError, "SendBatch", map[string]interface{}{"err": results.err, "time": endTime.Sub(startTime)}) } return results } if c.shouldLog(LogLevelInfo) { endTime := time.Now() c.log(ctx, LogLevelInfo, "SendBatch", map[string]interface{}{"batchLen": results.b.Len(), "time": endTime.Sub(startTime)}) } return results } func (c *Conn) sanitizeForSimpleQuery(sql string, args ...interface{}) (string, error) { if c.pgConn.ParameterStatus("standard_conforming_strings") != "on" { return "", errors.New("simple protocol queries must be run with standard_conforming_strings=on") } if c.pgConn.ParameterStatus("client_encoding") != "UTF8" { return "", errors.New("simple protocol queries must be run with client_encoding=UTF8") } var err error valueArgs := make([]interface{}, len(args)) for i, a := range 
args { valueArgs[i], err = convertSimpleArgument(c.connInfo, a) if err != nil { return "", err } } return sanitize.SanitizeSQL(sql, valueArgs...) } pgx-4.18.1/conn_test.go000066400000000000000000000777731437725773200147740ustar00rootroot00000000000000package pgx_test import ( "bytes" "context" "log" "os" "strings" "sync" "testing" "time" "github.com/jackc/pgconn" "github.com/jackc/pgconn/stmtcache" "github.com/jackc/pgtype" "github.com/jackc/pgx/v4" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestCrateDBConnect(t *testing.T) { t.Parallel() connString := os.Getenv("PGX_TEST_CRATEDB_CONN_STRING") if connString == "" { t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_CRATEDB_CONN_STRING") } conn, err := pgx.Connect(context.Background(), connString) require.Nil(t, err) defer closeConn(t, conn) assert.Equal(t, connString, conn.Config().ConnString()) var result int err = conn.QueryRow(context.Background(), "select 1 +1").Scan(&result) if err != nil { t.Fatalf("QueryRow Scan unexpectedly failed: %v", err) } if result != 2 { t.Errorf("bad result: %d", result) } } func TestConnect(t *testing.T) { t.Parallel() connString := os.Getenv("PGX_TEST_DATABASE") config := mustParseConfig(t, connString) conn, err := pgx.ConnectConfig(context.Background(), config) if err != nil { t.Fatalf("Unable to establish connection: %v", err) } assertConfigsEqual(t, config, conn.Config(), "Conn.Config() returns original config") var currentDB string err = conn.QueryRow(context.Background(), "select current_database()").Scan(¤tDB) if err != nil { t.Fatalf("QueryRow Scan unexpectedly failed: %v", err) } if currentDB != config.Config.Database { t.Errorf("Did not connect to specified database (%v)", config.Config.Database) } var user string err = conn.QueryRow(context.Background(), "select current_user").Scan(&user) if err != nil { t.Fatalf("QueryRow Scan unexpectedly failed: %v", err) } if user != config.Config.User { t.Errorf("Did not 
connect as specified user (%v)", config.Config.User) } err = conn.Close(context.Background()) if err != nil { t.Fatal("Unable to close connection") } } func TestConnectWithPreferSimpleProtocol(t *testing.T) { t.Parallel() connConfig := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) connConfig.PreferSimpleProtocol = true conn := mustConnect(t, connConfig) defer closeConn(t, conn) // If simple protocol is used we should be able to correctly scan the result // into a pgtype.Text as the integer will have been encoded in text. var s pgtype.Text err := conn.QueryRow(context.Background(), "select $1::int4", 42).Scan(&s) if err != nil { t.Fatal(err) } if s.Get() != "42" { t.Fatalf(`expected "42", got %v`, s) } ensureConnValid(t, conn) } func TestConnectConfigRequiresConnConfigFromParseConfig(t *testing.T) { config := &pgx.ConnConfig{} require.PanicsWithValue(t, "config must be created by ParseConfig", func() { pgx.ConnectConfig(context.Background(), config) }) } func TestConfigContainsConnStr(t *testing.T) { connStr := os.Getenv("PGX_TEST_DATABASE") config, err := pgx.ParseConfig(connStr) require.NoError(t, err) assert.Equal(t, connStr, config.ConnString()) } func TestConfigCopyReturnsEqualConfig(t *testing.T) { connString := "postgres://jack:secret@localhost:5432/mydb?application_name=pgxtest&search_path=myschema&connect_timeout=5" original, err := pgx.ParseConfig(connString) require.NoError(t, err) copied := original.Copy() assertConfigsEqual(t, original, copied, t.Name()) } func TestConfigCopyCanBeUsedToConnect(t *testing.T) { connString := os.Getenv("PGX_TEST_DATABASE") original, err := pgx.ParseConfig(connString) require.NoError(t, err) copied := original.Copy() assert.NotPanics(t, func() { _, err = pgx.ConnectConfig(context.Background(), copied) }) assert.NoError(t, err) } func TestParseConfigExtractsStatementCacheOptions(t *testing.T) { t.Parallel() config, err := pgx.ParseConfig("statement_cache_capacity=0") require.NoError(t, err) require.Nil(t, 
config.BuildStatementCache)

	// A positive capacity enables the cache; with no mode given the builder
	// produces a prepare-mode cache of that capacity.
	config, err = pgx.ParseConfig("statement_cache_capacity=42")
	require.NoError(t, err)
	require.NotNil(t, config.BuildStatementCache)
	c := config.BuildStatementCache(nil)
	require.NotNil(t, c)
	require.Equal(t, 42, c.Cap())
	require.Equal(t, stmtcache.ModePrepare, c.Mode())

	// statement_cache_mode=prepare is accepted explicitly.
	config, err = pgx.ParseConfig("statement_cache_capacity=42 statement_cache_mode=prepare")
	require.NoError(t, err)
	require.NotNil(t, config.BuildStatementCache)
	c = config.BuildStatementCache(nil)
	require.NotNil(t, c)
	require.Equal(t, 42, c.Cap())
	require.Equal(t, stmtcache.ModePrepare, c.Mode())

	// statement_cache_mode=describe selects describe mode.
	config, err = pgx.ParseConfig("statement_cache_capacity=42 statement_cache_mode=describe")
	require.NoError(t, err)
	require.NotNil(t, config.BuildStatementCache)
	c = config.BuildStatementCache(nil)
	require.NotNil(t, c)
	require.Equal(t, 42, c.Cap())
	require.Equal(t, stmtcache.ModeDescribe, c.Mode())
}

// TestParseConfigExtractsPreferSimpleProtocol verifies prefer_simple_protocol is
// parsed into ConnConfig.PreferSimpleProtocol and removed from RuntimeParams.
func TestParseConfigExtractsPreferSimpleProtocol(t *testing.T) {
	t.Parallel()

	for _, tt := range []struct {
		connString           string
		preferSimpleProtocol bool
	}{
		{"", false},
		{"prefer_simple_protocol=false", false},
		{"prefer_simple_protocol=0", false},
		{"prefer_simple_protocol=true", true},
		{"prefer_simple_protocol=1", true},
	} {
		config, err := pgx.ParseConfig(tt.connString)
		require.NoError(t, err)
		require.Equalf(t, tt.preferSimpleProtocol, config.PreferSimpleProtocol, "connString: `%s`", tt.connString)
		require.Empty(t, config.RuntimeParams["prefer_simple_protocol"])
	}
}

// TestExec exercises Conn.Exec command tags under both protocols: DDL, parameterized
// insert, multi-statement strings, long SQL, and a no-op statement.
func TestExec(t *testing.T) {
	t.Parallel()

	testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
		if results := mustExec(t, conn, "create temporary table foo(id integer primary key);"); string(results) != "CREATE TABLE" {
			t.Error("Unexpected results from Exec")
		}

		// Accept parameters
		if results := mustExec(t, conn, "insert into foo(id) values($1)", 1); string(results) != "INSERT 0 1" {
			t.Errorf("Unexpected results from Exec: %v", results)
		}

		if results := mustExec(t, conn, "drop table foo;");
string(results) != "DROP TABLE" { t.Error("Unexpected results from Exec") } // Multiple statements can be executed -- last command tag is returned if results := mustExec(t, conn, "create temporary table foo(id serial primary key); drop table foo;"); string(results) != "DROP TABLE" { t.Error("Unexpected results from Exec") } // Can execute longer SQL strings than sharedBufferSize if results := mustExec(t, conn, strings.Repeat("select 42; ", 1000)); string(results) != "SELECT 1" { t.Errorf("Unexpected results from Exec: %v", results) } // Exec no-op which does not return a command tag if results := mustExec(t, conn, "--;"); string(results) != "" { t.Errorf("Unexpected results from Exec: %v", results) } }) } func TestExecFailure(t *testing.T) { t.Parallel() testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { if _, err := conn.Exec(context.Background(), "selct;"); err == nil { t.Fatal("Expected SQL syntax error") } rows, _ := conn.Query(context.Background(), "select 1") rows.Close() if rows.Err() != nil { t.Fatalf("Exec failure appears to have broken connection: %v", rows.Err()) } }) } func TestExecFailureWithArguments(t *testing.T) { t.Parallel() testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { _, err := conn.Exec(context.Background(), "selct $1;", 1) if err == nil { t.Fatal("Expected SQL syntax error") } assert.False(t, pgconn.SafeToRetry(err)) _, err = conn.Exec(context.Background(), "select $1::varchar(1);", "1", "2") require.Error(t, err) }) } func TestExecContextWithoutCancelation(t *testing.T) { t.Parallel() testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { ctx, cancelFunc := context.WithCancel(context.Background()) defer cancelFunc() commandTag, err := conn.Exec(ctx, "create temporary table foo(id integer primary key);") if err != nil { t.Fatal(err) } if string(commandTag) != "CREATE TABLE" { t.Fatalf("Unexpected results from Exec: %v", commandTag) } assert.False(t, 
pgconn.SafeToRetry(err)) }) } func TestExecContextFailureWithoutCancelation(t *testing.T) { t.Parallel() testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { ctx, cancelFunc := context.WithCancel(context.Background()) defer cancelFunc() _, err := conn.Exec(ctx, "selct;") if err == nil { t.Fatal("Expected SQL syntax error") } assert.False(t, pgconn.SafeToRetry(err)) rows, _ := conn.Query(context.Background(), "select 1") rows.Close() if rows.Err() != nil { t.Fatalf("ExecEx failure appears to have broken connection: %v", rows.Err()) } assert.False(t, pgconn.SafeToRetry(err)) }) } func TestExecContextFailureWithoutCancelationWithArguments(t *testing.T) { t.Parallel() testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { ctx, cancelFunc := context.WithCancel(context.Background()) defer cancelFunc() _, err := conn.Exec(ctx, "selct $1;", 1) if err == nil { t.Fatal("Expected SQL syntax error") } assert.False(t, pgconn.SafeToRetry(err)) }) } func TestExecFailureCloseBefore(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) closeConn(t, conn) _, err := conn.Exec(context.Background(), "select 1") require.Error(t, err) assert.True(t, pgconn.SafeToRetry(err)) } func TestExecStatementCacheModes(t *testing.T) { t.Parallel() config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) tests := []struct { name string buildStatementCache pgx.BuildStatementCacheFunc }{ { name: "disabled", buildStatementCache: nil, }, { name: "prepare", buildStatementCache: func(conn *pgconn.PgConn) stmtcache.Cache { return stmtcache.New(conn, stmtcache.ModePrepare, 32) }, }, { name: "describe", buildStatementCache: func(conn *pgconn.PgConn) stmtcache.Cache { return stmtcache.New(conn, stmtcache.ModeDescribe, 32) }, }, } for _, tt := range tests { func() { config.BuildStatementCache = tt.buildStatementCache conn := mustConnect(t, config) defer closeConn(t, conn) commandTag, err := conn.Exec(context.Background(), 
"select 1") assert.NoError(t, err, tt.name) assert.Equal(t, "SELECT 1", string(commandTag), tt.name) commandTag, err = conn.Exec(context.Background(), "select 1 union all select 1") assert.NoError(t, err, tt.name) assert.Equal(t, "SELECT 2", string(commandTag), tt.name) commandTag, err = conn.Exec(context.Background(), "select 1") assert.NoError(t, err, tt.name) assert.Equal(t, "SELECT 1", string(commandTag), tt.name) ensureConnValid(t, conn) }() } } func TestExecPerQuerySimpleProtocol(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) ctx, cancelFunc := context.WithCancel(context.Background()) defer cancelFunc() commandTag, err := conn.Exec(ctx, "create temporary table foo(name varchar primary key);") if err != nil { t.Fatal(err) } if string(commandTag) != "CREATE TABLE" { t.Fatalf("Unexpected results from Exec: %v", commandTag) } commandTag, err = conn.Exec(ctx, "insert into foo(name) values($1);", pgx.QuerySimpleProtocol(true), "bar'; drop table foo;--", ) if err != nil { t.Fatal(err) } if string(commandTag) != "INSERT 0 1" { t.Fatalf("Unexpected results from Exec: %v", commandTag) } } func TestPrepare(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) _, err := conn.Prepare(context.Background(), "test", "select $1::varchar") if err != nil { t.Errorf("Unable to prepare statement: %v", err) return } var s string err = conn.QueryRow(context.Background(), "test", "hello").Scan(&s) if err != nil { t.Errorf("Executing prepared statement failed: %v", err) } if s != "hello" { t.Errorf("Prepared statement did not return expected value: %v", s) } err = conn.Deallocate(context.Background(), "test") if err != nil { t.Errorf("conn.Deallocate failed: %v", err) } // Create another prepared statement to ensure Deallocate left the connection // in a working state and that we can reuse the prepared statement name. 
_, err = conn.Prepare(context.Background(), "test", "select $1::integer") if err != nil { t.Errorf("Unable to prepare statement: %v", err) return } var n int32 err = conn.QueryRow(context.Background(), "test", int32(1)).Scan(&n) if err != nil { t.Errorf("Executing prepared statement failed: %v", err) } if n != 1 { t.Errorf("Prepared statement did not return expected value: %v", s) } err = conn.Deallocate(context.Background(), "test") if err != nil { t.Errorf("conn.Deallocate failed: %v", err) } } func TestPrepareBadSQLFailure(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) if _, err := conn.Prepare(context.Background(), "badSQL", "select foo"); err == nil { t.Fatal("Prepare should have failed with syntax error") } ensureConnValid(t, conn) } func TestPrepareIdempotency(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) for i := 0; i < 2; i++ { _, err := conn.Prepare(context.Background(), "test", "select 42::integer") if err != nil { t.Fatalf("%d. Unable to prepare statement: %v", i, err) } var n int32 err = conn.QueryRow(context.Background(), "test").Scan(&n) if err != nil { t.Errorf("%d. Executing prepared statement failed: %v", i, err) } if n != int32(42) { t.Errorf("%d. 
Prepared statement did not return expected value: %v", i, n) } } _, err := conn.Prepare(context.Background(), "test", "select 'fail'::varchar") if err == nil { t.Fatalf("Prepare statement with same name but different SQL should have failed but it didn't") return } } func TestPrepareStatementCacheModes(t *testing.T) { t.Parallel() config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) tests := []struct { name string buildStatementCache pgx.BuildStatementCacheFunc }{ { name: "disabled", buildStatementCache: nil, }, { name: "prepare", buildStatementCache: func(conn *pgconn.PgConn) stmtcache.Cache { return stmtcache.New(conn, stmtcache.ModePrepare, 32) }, }, { name: "describe", buildStatementCache: func(conn *pgconn.PgConn) stmtcache.Cache { return stmtcache.New(conn, stmtcache.ModeDescribe, 32) }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { config.BuildStatementCache = tt.buildStatementCache conn := mustConnect(t, config) defer closeConn(t, conn) _, err := conn.Prepare(context.Background(), "test", "select $1::text") require.NoError(t, err) var s string err = conn.QueryRow(context.Background(), "test", "hello").Scan(&s) require.NoError(t, err) require.Equal(t, "hello", s) }) } } func TestListenNotify(t *testing.T) { t.Parallel() listener := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, listener) if listener.PgConn().ParameterStatus("crdb_version") != "" { t.Skip("Server does not support LISTEN / NOTIFY (https://github.com/cockroachdb/cockroach/issues/41522)") } mustExec(t, listener, "listen chat") notifier := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, notifier) mustExec(t, notifier, "notify chat") // when notification is waiting on the socket to be read notification, err := listener.WaitForNotification(context.Background()) require.NoError(t, err) assert.Equal(t, "chat", notification.Channel) // when notification has already been read during previous query mustExec(t, notifier, 
"notify chat") rows, _ := listener.Query(context.Background(), "select 1") rows.Close() require.NoError(t, rows.Err()) ctx, cancelFn := context.WithCancel(context.Background()) cancelFn() notification, err = listener.WaitForNotification(ctx) require.NoError(t, err) assert.Equal(t, "chat", notification.Channel) // when timeout occurs ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) defer cancel() notification, err = listener.WaitForNotification(ctx) assert.True(t, pgconn.Timeout(err)) // listener can listen again after a timeout mustExec(t, notifier, "notify chat") notification, err = listener.WaitForNotification(context.Background()) require.NoError(t, err) assert.Equal(t, "chat", notification.Channel) } func TestListenNotifyWhileBusyIsSafe(t *testing.T) { t.Parallel() func() { conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) skipCockroachDB(t, conn, "Server does not support LISTEN / NOTIFY (https://github.com/cockroachdb/cockroach/issues/41522)") }() listenerDone := make(chan bool) notifierDone := make(chan bool) go func() { conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) defer func() { listenerDone <- true }() mustExec(t, conn, "listen busysafe") for i := 0; i < 5000; i++ { var sum int32 var rowCount int32 rows, err := conn.Query(context.Background(), "select generate_series(1,$1)", 100) if err != nil { t.Errorf("conn.Query failed: %v", err) return } for rows.Next() { var n int32 if err := rows.Scan(&n); err != nil { t.Errorf("Row scan failed: %v", err) return } sum += n rowCount++ } if rows.Err() != nil { t.Errorf("conn.Query failed: %v", err) return } if sum != 5050 { t.Errorf("Wrong rows sum: %v", sum) return } if rowCount != 100 { t.Errorf("Wrong number of rows: %v", rowCount) return } time.Sleep(1 * time.Microsecond) } }() go func() { conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) defer func() { notifierDone <- true }() 
for i := 0; i < 100000; i++ { mustExec(t, conn, "notify busysafe, 'hello'") time.Sleep(1 * time.Microsecond) } }() <-listenerDone <-notifierDone } func TestListenNotifySelfNotification(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) skipCockroachDB(t, conn, "Server does not support LISTEN / NOTIFY (https://github.com/cockroachdb/cockroach/issues/41522)") mustExec(t, conn, "listen self") // Notify self and WaitForNotification immediately mustExec(t, conn, "notify self") ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() notification, err := conn.WaitForNotification(ctx) require.NoError(t, err) assert.Equal(t, "self", notification.Channel) // Notify self and do something else before WaitForNotification mustExec(t, conn, "notify self") rows, _ := conn.Query(context.Background(), "select 1") rows.Close() if rows.Err() != nil { t.Fatalf("Unexpected error on Query: %v", rows.Err()) } ctx, cncl := context.WithTimeout(context.Background(), time.Second) defer cncl() notification, err = conn.WaitForNotification(ctx) require.NoError(t, err) assert.Equal(t, "self", notification.Channel) } func TestFatalRxError(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) skipCockroachDB(t, conn, "Server does not support pg_terminate_backend() (https://github.com/cockroachdb/cockroach/issues/35897)") var wg sync.WaitGroup wg.Add(1) go func() { defer wg.Done() var n int32 var s string err := conn.QueryRow(context.Background(), "select 1::int4, pg_sleep(10)::varchar").Scan(&n, &s) if pgErr, ok := err.(*pgconn.PgError); ok && pgErr.Severity == "FATAL" { } else { t.Errorf("Expected QueryRow Scan to return fatal PgError, but instead received %v", err) return } }() otherConn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer otherConn.Close(context.Background()) if _, err := otherConn.Exec(context.Background(), "select 
pg_terminate_backend($1)", conn.PgConn().PID()); err != nil { t.Fatalf("Unable to kill backend PostgreSQL process: %v", err) } wg.Wait() if !conn.IsClosed() { t.Fatal("Connection should be closed") } } func TestFatalTxError(t *testing.T) { t.Parallel() // Run timing sensitive test many times for i := 0; i < 50; i++ { func() { conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) skipCockroachDB(t, conn, "Server does not support pg_terminate_backend() (https://github.com/cockroachdb/cockroach/issues/35897)") otherConn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer otherConn.Close(context.Background()) _, err := otherConn.Exec(context.Background(), "select pg_terminate_backend($1)", conn.PgConn().PID()) if err != nil { t.Fatalf("Unable to kill backend PostgreSQL process: %v", err) } err = conn.QueryRow(context.Background(), "select 1").Scan(nil) if err == nil { t.Fatal("Expected error but none occurred") } if !conn.IsClosed() { t.Fatalf("Connection should be closed but isn't. 
Previous Query err: %v", err) } }() } } func TestInsertBoolArray(t *testing.T) { t.Parallel() testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { if results := mustExec(t, conn, "create temporary table foo(spice bool[]);"); string(results) != "CREATE TABLE" { t.Error("Unexpected results from Exec") } // Accept parameters if results := mustExec(t, conn, "insert into foo(spice) values($1)", []bool{true, false, true}); string(results) != "INSERT 0 1" { t.Errorf("Unexpected results from Exec: %v", results) } }) } func TestInsertTimestampArray(t *testing.T) { t.Parallel() testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { if results := mustExec(t, conn, "create temporary table foo(spice timestamp[]);"); string(results) != "CREATE TABLE" { t.Error("Unexpected results from Exec") } // Accept parameters if results := mustExec(t, conn, "insert into foo(spice) values($1)", []time.Time{time.Unix(1419143667, 0), time.Unix(1419143672, 0)}); string(results) != "INSERT 0 1" { t.Errorf("Unexpected results from Exec: %v", results) } }) } type testLog struct { lvl pgx.LogLevel msg string data map[string]interface{} } type testLogger struct { logs []testLog } func (l *testLogger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) { data["ctxdata"] = ctx.Value("ctxdata") l.logs = append(l.logs, testLog{lvl: level, msg: msg, data: data}) } func TestLogPassesContext(t *testing.T) { t.Parallel() l1 := &testLogger{} config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) config.Logger = l1 conn := mustConnect(t, config) defer closeConn(t, conn) l1.logs = l1.logs[0:0] // Clear logs written when establishing connection ctx := context.WithValue(context.Background(), "ctxdata", "foo") if _, err := conn.Exec(ctx, ";"); err != nil { t.Fatal(err) } if len(l1.logs) != 1 { t.Fatal("Expected logger to be called once, but it wasn't") } if l1.logs[0].data["ctxdata"] != "foo" { t.Fatal("Expected context data 
to be passed to logger, but it wasn't") } } func TestLoggerFunc(t *testing.T) { t.Parallel() const testMsg = "foo" buf := bytes.Buffer{} logger := log.New(&buf, "", 0) createAdapterFn := func(logger *log.Logger) pgx.LoggerFunc { return func(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) { logger.Printf("%s", testMsg) } } config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) config.Logger = createAdapterFn(logger) conn := mustConnect(t, config) defer closeConn(t, conn) buf.Reset() // Clear logs written when establishing connection if _, err := conn.Exec(context.TODO(), ";"); err != nil { t.Fatal(err) } if strings.TrimSpace(buf.String()) != testMsg { t.Errorf("Expected logger function to return '%s', but it was '%s'", testMsg, buf.String()) } } func TestIdentifierSanitize(t *testing.T) { t.Parallel() tests := []struct { ident pgx.Identifier expected string }{ { ident: pgx.Identifier{`foo`}, expected: `"foo"`, }, { ident: pgx.Identifier{`select`}, expected: `"select"`, }, { ident: pgx.Identifier{`foo`, `bar`}, expected: `"foo"."bar"`, }, { ident: pgx.Identifier{`you should " not do this`}, expected: `"you should "" not do this"`, }, { ident: pgx.Identifier{`you should " not do this`, `please don't`}, expected: `"you should "" not do this"."please don't"`, }, { ident: pgx.Identifier{`you should ` + string([]byte{0}) + `not do this`}, expected: `"you should not do this"`, }, } for i, tt := range tests { qval := tt.ident.Sanitize() if qval != tt.expected { t.Errorf("%d. 
Expected Sanitize %v to return %v but it was %v", i, tt.ident, tt.expected, qval) } } } func TestConnInitConnInfo(t *testing.T) { conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) // spot check that the standard postgres type names aren't qualified nameOIDs := map[string]uint32{ "_int8": pgtype.Int8ArrayOID, "int8": pgtype.Int8OID, "json": pgtype.JSONOID, "text": pgtype.TextOID, } for name, oid := range nameOIDs { dtByName, ok := conn.ConnInfo().DataTypeForName(name) if !ok { t.Fatalf("Expected type named %v to be present", name) } dtByOID, ok := conn.ConnInfo().DataTypeForOID(oid) if !ok { t.Fatalf("Expected type OID %v to be present", oid) } if dtByName != dtByOID { t.Fatalf("Expected type named %v to be the same as type OID %v", name, oid) } } ensureConnValid(t, conn) } func TestUnregisteredTypeUsableAsStringArgumentAndBaseResult(t *testing.T) { testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { skipCockroachDB(t, conn, "Server does support domain types (https://github.com/cockroachdb/cockroach/issues/27796)") var n uint64 err := conn.QueryRow(context.Background(), "select $1::uint64", "42").Scan(&n) if err != nil { t.Fatal(err) } if n != 42 { t.Fatalf("Expected n to be 42, but was %v", n) } }) } func TestDomainType(t *testing.T) { testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { skipCockroachDB(t, conn, "Server does support domain types (https://github.com/cockroachdb/cockroach/issues/27796)") var n uint64 // Domain type uint64 is a PostgreSQL domain of underlying type numeric. err := conn.QueryRow(context.Background(), "select $1::uint64", uint64(24)).Scan(&n) require.NoError(t, err) // A string can be used. But a string cannot be the result because the describe result from the PostgreSQL server gives // the underlying type of numeric. 
err = conn.QueryRow(context.Background(), "select $1::uint64", "42").Scan(&n) if err != nil { t.Fatal(err) } if n != 42 { t.Fatalf("Expected n to be 42, but was %v", n) } var uint64OID uint32 err = conn.QueryRow(context.Background(), "select t.oid from pg_type t where t.typname='uint64';").Scan(&uint64OID) if err != nil { t.Fatalf("did not find uint64 OID, %v", err) } conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: &pgtype.Numeric{}, Name: "uint64", OID: uint64OID}) // String is still an acceptable argument after registration err = conn.QueryRow(context.Background(), "select $1::uint64", "7").Scan(&n) if err != nil { t.Fatal(err) } if n != 7 { t.Fatalf("Expected n to be 7, but was %v", n) } // But a uint64 is acceptable err = conn.QueryRow(context.Background(), "select $1::uint64", uint64(24)).Scan(&n) if err != nil { t.Fatal(err) } if n != 24 { t.Fatalf("Expected n to be 24, but was %v", n) } }) } func TestStmtCacheInvalidationConn(t *testing.T) { ctx := context.Background() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) // create a table and fill it with some data _, err := conn.Exec(ctx, ` DROP TABLE IF EXISTS drop_cols; CREATE TABLE drop_cols ( id SERIAL PRIMARY KEY NOT NULL, f1 int NOT NULL, f2 int NOT NULL ); `) require.NoError(t, err) _, err = conn.Exec(ctx, "INSERT INTO drop_cols (f1, f2) VALUES (1, 2)") require.NoError(t, err) getSQL := "SELECT * FROM drop_cols WHERE id = $1" // This query will populate the statement cache. We don't care about the result. rows, err := conn.Query(ctx, getSQL, 1) require.NoError(t, err) rows.Close() // Now, change the schema of the table out from under the statement, making it invalid. _, err = conn.Exec(ctx, "ALTER TABLE drop_cols DROP COLUMN f1") require.NoError(t, err) // We must get an error the first time we try to re-execute a bad statement. // It is up to the application to determine if it wants to try again. 
We punt to // the application because there is no clear recovery path in the case of failed transactions // or batch operations and because automatic retry is tricky and we don't want to get // it wrong at such an importaint layer of the stack. rows, err = conn.Query(ctx, getSQL, 1) require.NoError(t, err) rows.Next() nextErr := rows.Err() rows.Close() for _, err := range []error{nextErr, rows.Err()} { if err == nil { t.Fatal("expected InvalidCachedStatementPlanError: no error") } if !strings.Contains(err.Error(), "cached plan must not change result type") { t.Fatalf("expected InvalidCachedStatementPlanError, got: %s", err.Error()) } } // On retry, the statement should have been flushed from the cache. rows, err = conn.Query(ctx, getSQL, 1) require.NoError(t, err) rows.Next() err = rows.Err() require.NoError(t, err) rows.Close() require.NoError(t, rows.Err()) ensureConnValid(t, conn) } func TestStmtCacheInvalidationTx(t *testing.T) { ctx := context.Background() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) // create a table and fill it with some data _, err := conn.Exec(ctx, ` DROP TABLE IF EXISTS drop_cols; CREATE TABLE drop_cols ( id SERIAL PRIMARY KEY NOT NULL, f1 int NOT NULL, f2 int NOT NULL ); `) require.NoError(t, err) _, err = conn.Exec(ctx, "INSERT INTO drop_cols (f1, f2) VALUES (1, 2)") require.NoError(t, err) tx, err := conn.Begin(ctx) require.NoError(t, err) getSQL := "SELECT * FROM drop_cols WHERE id = $1" // This query will populate the statement cache. We don't care about the result. rows, err := tx.Query(ctx, getSQL, 1) require.NoError(t, err) rows.Close() // Now, change the schema of the table out from under the statement, making it invalid. _, err = tx.Exec(ctx, "ALTER TABLE drop_cols DROP COLUMN f1") require.NoError(t, err) // We must get an error the first time we try to re-execute a bad statement. // It is up to the application to determine if it wants to try again. 
We punt to // the application because there is no clear recovery path in the case of failed transactions // or batch operations and because automatic retry is tricky and we don't want to get // it wrong at such an importaint layer of the stack. rows, err = tx.Query(ctx, getSQL, 1) require.NoError(t, err) rows.Next() nextErr := rows.Err() rows.Close() for _, err := range []error{nextErr, rows.Err()} { if err == nil { t.Fatal("expected InvalidCachedStatementPlanError: no error") } if !strings.Contains(err.Error(), "cached plan must not change result type") { t.Fatalf("expected InvalidCachedStatementPlanError, got: %s", err.Error()) } } rows, err = tx.Query(ctx, getSQL, 1) require.NoError(t, err) // error does not pop up immediately rows.Next() err = rows.Err() // Retries within the same transaction are errors (really anything except a rollbakc // will be an error in this transaction). require.Error(t, err) rows.Close() err = tx.Rollback(ctx) require.NoError(t, err) // once we've rolled back, retries will work rows, err = conn.Query(ctx, getSQL, 1) require.NoError(t, err) rows.Next() err = rows.Err() require.NoError(t, err) rows.Close() ensureConnValid(t, conn) } func TestInsertDurationInterval(t *testing.T) { testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { _, err := conn.Exec(context.Background(), "create temporary table t(duration INTERVAL(0) NOT NULL)") require.NoError(t, err) result, err := conn.Exec(context.Background(), "insert into t(duration) values($1)", time.Minute) require.NoError(t, err) n := result.RowsAffected() require.EqualValues(t, 1, n) }) } pgx-4.18.1/copy_from.go000066400000000000000000000121741437725773200147550ustar00rootroot00000000000000package pgx import ( "bytes" "context" "fmt" "io" "time" "github.com/jackc/pgconn" "github.com/jackc/pgio" ) // CopyFromRows returns a CopyFromSource interface over the provided rows slice // making it usable by *Conn.CopyFrom. 
func CopyFromRows(rows [][]interface{}) CopyFromSource { return ©FromRows{rows: rows, idx: -1} } type copyFromRows struct { rows [][]interface{} idx int } func (ctr *copyFromRows) Next() bool { ctr.idx++ return ctr.idx < len(ctr.rows) } func (ctr *copyFromRows) Values() ([]interface{}, error) { return ctr.rows[ctr.idx], nil } func (ctr *copyFromRows) Err() error { return nil } // CopyFromSlice returns a CopyFromSource interface over a dynamic func // making it usable by *Conn.CopyFrom. func CopyFromSlice(length int, next func(int) ([]interface{}, error)) CopyFromSource { return ©FromSlice{next: next, idx: -1, len: length} } type copyFromSlice struct { next func(int) ([]interface{}, error) idx int len int err error } func (cts *copyFromSlice) Next() bool { cts.idx++ return cts.idx < cts.len } func (cts *copyFromSlice) Values() ([]interface{}, error) { values, err := cts.next(cts.idx) if err != nil { cts.err = err } return values, err } func (cts *copyFromSlice) Err() error { return cts.err } // CopyFromSource is the interface used by *Conn.CopyFrom as the source for copy data. type CopyFromSource interface { // Next returns true if there is another row and makes the next row data // available to Values(). When there are no more rows available or an error // has occurred it returns false. Next() bool // Values returns the values for the current row. Values() ([]interface{}, error) // Err returns any error that has been encountered by the CopyFromSource. If // this is not nil *Conn.CopyFrom will abort the copy. 
Err() error } type copyFrom struct { conn *Conn tableName Identifier columnNames []string rowSrc CopyFromSource readerErrChan chan error } func (ct *copyFrom) run(ctx context.Context) (int64, error) { quotedTableName := ct.tableName.Sanitize() cbuf := &bytes.Buffer{} for i, cn := range ct.columnNames { if i != 0 { cbuf.WriteString(", ") } cbuf.WriteString(quoteIdentifier(cn)) } quotedColumnNames := cbuf.String() sd, err := ct.conn.Prepare(ctx, "", fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName)) if err != nil { return 0, err } r, w := io.Pipe() doneChan := make(chan struct{}) go func() { defer close(doneChan) // Purposely NOT using defer w.Close(). See https://github.com/golang/go/issues/24283. buf := ct.conn.wbuf buf = append(buf, "PGCOPY\n\377\r\n\000"...) buf = pgio.AppendInt32(buf, 0) buf = pgio.AppendInt32(buf, 0) moreRows := true for moreRows { var err error moreRows, buf, err = ct.buildCopyBuf(buf, sd) if err != nil { w.CloseWithError(err) return } if ct.rowSrc.Err() != nil { w.CloseWithError(ct.rowSrc.Err()) return } if len(buf) > 0 { _, err = w.Write(buf) if err != nil { w.Close() return } } buf = buf[:0] } w.Close() }() startTime := time.Now() commandTag, err := ct.conn.pgConn.CopyFrom(ctx, r, fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames)) r.Close() <-doneChan rowsAffected := commandTag.RowsAffected() endTime := time.Now() if err == nil { if ct.conn.shouldLog(LogLevelInfo) { ct.conn.log(ctx, LogLevelInfo, "CopyFrom", map[string]interface{}{"tableName": ct.tableName, "columnNames": ct.columnNames, "time": endTime.Sub(startTime), "rowCount": rowsAffected}) } } else if ct.conn.shouldLog(LogLevelError) { ct.conn.log(ctx, LogLevelError, "CopyFrom", map[string]interface{}{"err": err, "tableName": ct.tableName, "columnNames": ct.columnNames, "time": endTime.Sub(startTime)}) } return rowsAffected, err } func (ct *copyFrom) buildCopyBuf(buf []byte, sd *pgconn.StatementDescription) (bool, []byte, error) { 
for ct.rowSrc.Next() { values, err := ct.rowSrc.Values() if err != nil { return false, nil, err } if len(values) != len(ct.columnNames) { return false, nil, fmt.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values)) } buf = pgio.AppendInt16(buf, int16(len(ct.columnNames))) for i, val := range values { buf, err = encodePreparedStatementArgument(ct.conn.connInfo, buf, sd.Fields[i].DataTypeOID, val) if err != nil { return false, nil, err } } if len(buf) > 65536 { return true, buf, nil } } return false, buf, nil } // CopyFrom uses the PostgreSQL copy protocol to perform bulk data insertion. // It returns the number of rows copied and an error. // // CopyFrom requires all values use the binary format. Almost all types // implemented by pgx use the binary format by default. Types implementing // Encoder can only be used if they encode to the binary format. func (c *Conn) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error) { ct := ©From{ conn: c, tableName: tableName, columnNames: columnNames, rowSrc: rowSrc, readerErrChan: make(chan error), } return ct.run(ctx) } pgx-4.18.1/copy_from_test.go000066400000000000000000000353531437725773200160200ustar00rootroot00000000000000package pgx_test import ( "context" "fmt" "os" "reflect" "testing" "time" "github.com/jackc/pgconn" "github.com/jackc/pgx/v4" "github.com/stretchr/testify/require" ) func TestConnCopyFromSmall(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) mustExec(t, conn, `create temporary table foo( a int2, b int4, c int8, d varchar, e text, f date, g timestamptz )`) tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local) inputRows := [][]interface{}{ {int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime}, {nil, nil, nil, nil, nil, nil, nil}, } copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, 
[]string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyFromRows(inputRows)) if err != nil { t.Errorf("Unexpected error for CopyFrom: %v", err) } if int(copyCount) != len(inputRows) { t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount) } rows, err := conn.Query(context.Background(), "select * from foo") if err != nil { t.Errorf("Unexpected error for Query: %v", err) } var outputRows [][]interface{} for rows.Next() { row, err := rows.Values() if err != nil { t.Errorf("Unexpected error for rows.Values(): %v", err) } outputRows = append(outputRows, row) } if rows.Err() != nil { t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) } if !reflect.DeepEqual(inputRows, outputRows) { t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows) } ensureConnValid(t, conn) } func TestConnCopyFromSliceSmall(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) mustExec(t, conn, `create temporary table foo( a int2, b int4, c int8, d varchar, e text, f date, g timestamptz )`) tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local) inputRows := [][]interface{}{ {int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime}, {nil, nil, nil, nil, nil, nil, nil}, } copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyFromSlice(len(inputRows), func(i int) ([]interface{}, error) { return inputRows[i], nil })) if err != nil { t.Errorf("Unexpected error for CopyFrom: %v", err) } if int(copyCount) != len(inputRows) { t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount) } rows, err := conn.Query(context.Background(), "select * from foo") if err != nil { t.Errorf("Unexpected error for Query: %v", err) } var outputRows [][]interface{} for rows.Next() { row, err := rows.Values() if err != nil { 
t.Errorf("Unexpected error for rows.Values(): %v", err) } outputRows = append(outputRows, row) } if rows.Err() != nil { t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) } if !reflect.DeepEqual(inputRows, outputRows) { t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows) } ensureConnValid(t, conn) } func TestConnCopyFromLarge(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) skipCockroachDB(t, conn, "Skipping due to known server issue: (https://github.com/cockroachdb/cockroach/issues/52722)") mustExec(t, conn, `create temporary table foo( a int2, b int4, c int8, d varchar, e text, f date, g timestamptz, h bytea )`) tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local) inputRows := [][]interface{}{} for i := 0; i < 10000; i++ { inputRows = append(inputRows, []interface{}{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime, []byte{111, 111, 111, 111}}) } copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g", "h"}, pgx.CopyFromRows(inputRows)) if err != nil { t.Errorf("Unexpected error for CopyFrom: %v", err) } if int(copyCount) != len(inputRows) { t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount) } rows, err := conn.Query(context.Background(), "select * from foo") if err != nil { t.Errorf("Unexpected error for Query: %v", err) } var outputRows [][]interface{} for rows.Next() { row, err := rows.Values() if err != nil { t.Errorf("Unexpected error for rows.Values(): %v", err) } outputRows = append(outputRows, row) } if rows.Err() != nil { t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) } if !reflect.DeepEqual(inputRows, outputRows) { t.Errorf("Input rows and output rows do not equal") } ensureConnValid(t, conn) } func TestConnCopyFromEnum(t *testing.T) { t.Parallel() conn := mustConnectString(t, 
os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) ctx := context.Background() tx, err := conn.Begin(ctx) require.NoError(t, err) defer tx.Rollback(ctx) _, err = tx.Exec(ctx, `drop type if exists color`) require.NoError(t, err) _, err = tx.Exec(ctx, `drop type if exists fruit`) require.NoError(t, err) _, err = tx.Exec(ctx, `create type color as enum ('blue', 'green', 'orange')`) require.NoError(t, err) _, err = tx.Exec(ctx, `create type fruit as enum ('apple', 'orange', 'grape')`) require.NoError(t, err) _, err = tx.Exec(ctx, `create table foo( a text, b color, c fruit, d color, e fruit, f text )`) require.NoError(t, err) inputRows := [][]interface{}{ {"abc", "blue", "grape", "orange", "orange", "def"}, {nil, nil, nil, nil, nil, nil}, } copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f"}, pgx.CopyFromRows(inputRows)) require.NoError(t, err) require.EqualValues(t, len(inputRows), copyCount) rows, err := conn.Query(ctx, "select * from foo") require.NoError(t, err) var outputRows [][]interface{} for rows.Next() { row, err := rows.Values() require.NoError(t, err) outputRows = append(outputRows, row) } require.NoError(t, rows.Err()) if !reflect.DeepEqual(inputRows, outputRows) { t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows) } ensureConnValid(t, conn) } func TestConnCopyFromJSON(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) for _, typeName := range []string{"json", "jsonb"} { if _, ok := conn.ConnInfo().DataTypeForName(typeName); !ok { return // No JSON/JSONB type -- must be running against old PostgreSQL } } mustExec(t, conn, `create temporary table foo( a json, b jsonb )`) inputRows := [][]interface{}{ {map[string]interface{}{"foo": "bar"}, map[string]interface{}{"bar": "quz"}}, {nil, nil}, } copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b"}, 
pgx.CopyFromRows(inputRows)) if err != nil { t.Errorf("Unexpected error for CopyFrom: %v", err) } if int(copyCount) != len(inputRows) { t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount) } rows, err := conn.Query(context.Background(), "select * from foo") if err != nil { t.Errorf("Unexpected error for Query: %v", err) } var outputRows [][]interface{} for rows.Next() { row, err := rows.Values() if err != nil { t.Errorf("Unexpected error for rows.Values(): %v", err) } outputRows = append(outputRows, row) } if rows.Err() != nil { t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) } if !reflect.DeepEqual(inputRows, outputRows) { t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows) } ensureConnValid(t, conn) } type clientFailSource struct { count int err error } func (cfs *clientFailSource) Next() bool { cfs.count++ return cfs.count < 100 } func (cfs *clientFailSource) Values() ([]interface{}, error) { if cfs.count == 3 { cfs.err = fmt.Errorf("client error") return nil, cfs.err } return []interface{}{make([]byte, 100000)}, nil } func (cfs *clientFailSource) Err() error { return cfs.err } func TestConnCopyFromFailServerSideMidway(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) mustExec(t, conn, `create temporary table foo( a int4, b varchar not null )`) inputRows := [][]interface{}{ {int32(1), "abc"}, {int32(2), nil}, // this row should trigger a failure {int32(3), "def"}, } copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b"}, pgx.CopyFromRows(inputRows)) if err == nil { t.Errorf("Expected CopyFrom return error, but it did not") } if _, ok := err.(*pgconn.PgError); !ok { t.Errorf("Expected CopyFrom return pgx.PgError, but instead it returned: %v", err) } if copyCount != 0 { t.Errorf("Expected CopyFrom to return 0 copied rows, but got %d", copyCount) } rows, err := 
conn.Query(context.Background(), "select * from foo") if err != nil { t.Errorf("Unexpected error for Query: %v", err) } var outputRows [][]interface{} for rows.Next() { row, err := rows.Values() if err != nil { t.Errorf("Unexpected error for rows.Values(): %v", err) } outputRows = append(outputRows, row) } if rows.Err() != nil { t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) } if len(outputRows) != 0 { t.Errorf("Expected 0 rows, but got %v", outputRows) } mustExec(t, conn, "truncate foo") ensureConnValid(t, conn) } type failSource struct { count int } func (fs *failSource) Next() bool { time.Sleep(time.Millisecond * 100) fs.count++ return fs.count < 100 } func (fs *failSource) Values() ([]interface{}, error) { if fs.count == 3 { return []interface{}{nil}, nil } return []interface{}{make([]byte, 100000)}, nil } func (fs *failSource) Err() error { return nil } func TestConnCopyFromFailServerSideMidwayAbortsWithoutWaiting(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) mustExec(t, conn, `create temporary table foo( a bytea not null )`) startTime := time.Now() copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a"}, &failSource{}) if err == nil { t.Errorf("Expected CopyFrom return error, but it did not") } if _, ok := err.(*pgconn.PgError); !ok { t.Errorf("Expected CopyFrom return pgx.PgError, but instead it returned: %v", err) } if copyCount != 0 { t.Errorf("Expected CopyFrom to return 0 copied rows, but got %d", copyCount) } endTime := time.Now() copyTime := endTime.Sub(startTime) if copyTime > time.Second { t.Errorf("Failing CopyFrom shouldn't have taken so long: %v", copyTime) } rows, err := conn.Query(context.Background(), "select * from foo") if err != nil { t.Errorf("Unexpected error for Query: %v", err) } var outputRows [][]interface{} for rows.Next() { row, err := rows.Values() if err != nil { t.Errorf("Unexpected error for rows.Values(): %v", 
err) } outputRows = append(outputRows, row) } if rows.Err() != nil { t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) } if len(outputRows) != 0 { t.Errorf("Expected 0 rows, but got %v", outputRows) } ensureConnValid(t, conn) } type slowFailRaceSource struct { count int } func (fs *slowFailRaceSource) Next() bool { time.Sleep(time.Millisecond) fs.count++ return fs.count < 1000 } func (fs *slowFailRaceSource) Values() ([]interface{}, error) { if fs.count == 500 { return []interface{}{nil, nil}, nil } return []interface{}{1, make([]byte, 1000)}, nil } func (fs *slowFailRaceSource) Err() error { return nil } func TestConnCopyFromSlowFailRace(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) mustExec(t, conn, `create temporary table foo( a int not null, b bytea not null )`) copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b"}, &slowFailRaceSource{}) if err == nil { t.Errorf("Expected CopyFrom return error, but it did not") } if _, ok := err.(*pgconn.PgError); !ok { t.Errorf("Expected CopyFrom return pgx.PgError, but instead it returned: %v", err) } if copyCount != 0 { t.Errorf("Expected CopyFrom to return 0 copied rows, but got %d", copyCount) } ensureConnValid(t, conn) } func TestConnCopyFromCopyFromSourceErrorMidway(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) mustExec(t, conn, `create temporary table foo( a bytea not null )`) copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a"}, &clientFailSource{}) if err == nil { t.Errorf("Expected CopyFrom return error, but it did not") } if copyCount != 0 { t.Errorf("Expected CopyFrom to return 0 copied rows, but got %d", copyCount) } rows, err := conn.Query(context.Background(), "select * from foo") if err != nil { t.Errorf("Unexpected error for Query: %v", err) } var outputRows [][]interface{} for 
rows.Next() { row, err := rows.Values() if err != nil { t.Errorf("Unexpected error for rows.Values(): %v", err) } outputRows = append(outputRows, row) } if rows.Err() != nil { t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) } if len(outputRows) != 0 { t.Errorf("Expected 0 rows, but got %v", len(outputRows)) } ensureConnValid(t, conn) } type clientFinalErrSource struct { count int } func (cfs *clientFinalErrSource) Next() bool { cfs.count++ return cfs.count < 5 } func (cfs *clientFinalErrSource) Values() ([]interface{}, error) { return []interface{}{make([]byte, 100000)}, nil } func (cfs *clientFinalErrSource) Err() error { return fmt.Errorf("final error") } func TestConnCopyFromCopyFromSourceErrorEnd(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) mustExec(t, conn, `create temporary table foo( a bytea not null )`) copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a"}, &clientFinalErrSource{}) if err == nil { t.Errorf("Expected CopyFrom return error, but it did not") } if copyCount != 0 { t.Errorf("Expected CopyFrom to return 0 copied rows, but got %d", copyCount) } rows, err := conn.Query(context.Background(), "select * from foo") if err != nil { t.Errorf("Unexpected error for Query: %v", err) } var outputRows [][]interface{} for rows.Next() { row, err := rows.Values() if err != nil { t.Errorf("Unexpected error for rows.Values(): %v", err) } outputRows = append(outputRows, row) } if rows.Err() != nil { t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) } if len(outputRows) != 0 { t.Errorf("Expected 0 rows, but got %v", outputRows) } ensureConnValid(t, conn) } pgx-4.18.1/doc.go000066400000000000000000000266241437725773200135320ustar00rootroot00000000000000// Package pgx is a PostgreSQL database driver. /* pgx provides lower level access to PostgreSQL than the standard database/sql. 
It remains as similar to the database/sql interface as possible while providing better speed and access to PostgreSQL specific features. Import github.com/jackc/pgx/v4/stdlib to use pgx as a database/sql compatible driver. Establishing a Connection The primary way of establishing a connection is with `pgx.Connect`. conn, err := pgx.Connect(context.Background(), os.Getenv("DATABASE_URL")) The database connection string can be in URL or DSN format. Both PostgreSQL settings and pgx settings can be specified here. In addition, a config struct can be created by `ParseConfig` and modified before establishing the connection with `ConnectConfig`. config, err := pgx.ParseConfig(os.Getenv("DATABASE_URL")) if err != nil { // ... } config.Logger = log15adapter.NewLogger(log.New("module", "pgx")) conn, err := pgx.ConnectConfig(context.Background(), config) Connection Pool `*pgx.Conn` represents a single connection to the database and is not concurrency safe. Use sub-package pgxpool for a concurrency safe connection pool. Query Interface pgx implements Query and Scan in the familiar database/sql style. var sum int32 // Send the query to the server. The returned rows MUST be closed // before conn can be used again. rows, err := conn.Query(context.Background(), "select generate_series(1,$1)", 10) if err != nil { return err } // rows.Close is called by rows.Next when all rows are read // or an error occurs in Next or Scan. So it may optionally be // omitted if nothing in the rows.Next loop can panic. It is // safe to close rows multiple times. defer rows.Close() // Iterate through the result set for rows.Next() { var n int32 err = rows.Scan(&n) if err != nil { return err } sum += n } // Any errors encountered by rows.Next or rows.Scan will be returned here if rows.Err() != nil { return rows.Err() } // No errors found - do something with sum pgx also implements QueryRow in the same style as database/sql. 
var name string var weight int64 err := conn.QueryRow(context.Background(), "select name, weight from widgets where id=$1", 42).Scan(&name, &weight) if err != nil { return err } Use Exec to execute a query that does not return a result set. commandTag, err := conn.Exec(context.Background(), "delete from widgets where id=$1", 42) if err != nil { return err } if commandTag.RowsAffected() != 1 { return errors.New("No row found to delete") } QueryFunc can be used to execute a callback function for every row. This is often easier to use than Query. var sum, n int32 _, err = conn.QueryFunc( context.Background(), "select generate_series(1,$1)", []interface{}{10}, []interface{}{&n}, func(pgx.QueryFuncRow) error { sum += n return nil }, ) if err != nil { return err } Base Type Mapping pgx maps between all common base types directly between Go and PostgreSQL. In particular: Go PostgreSQL ----------------------- string varchar text // Integers are automatically be converted to any other integer type if // it can be done without overflow or underflow. int8 int16 smallint int32 int int64 bigint int uint8 uint16 uint32 uint64 uint // Floats are strict and do not automatically convert like integers. float32 float4 float64 float8 time.Time date timestamp timestamptz []byte bytea Null Mapping pgx can map nulls in two ways. The first is package pgtype provides types that have a data field and a status field. They work in a similar fashion to database/sql. The second is to use a pointer to a pointer. var foo pgtype.Varchar var bar *string err := conn.QueryRow("select foo, bar from widgets where id=$1", 42).Scan(&foo, &bar) if err != nil { return err } Array Mapping pgx maps between int16, int32, int64, float32, float64, and string Go slices and the equivalent PostgreSQL array type. Go slices of native types do not support nulls, so if a PostgreSQL array that contains a null is read into a native Go slice an error will occur. 
The pgtype package includes many more array types for PostgreSQL types that do not directly map to native Go types. JSON and JSONB Mapping pgx includes built-in support to marshal and unmarshal between Go types and the PostgreSQL JSON and JSONB. Inet and CIDR Mapping pgx encodes from net.IPNet to and from inet and cidr PostgreSQL types. In addition, as a convenience pgx will encode from a net.IP; it will assume a /32 netmask for IPv4 and a /128 for IPv6. Custom Type Support pgx includes support for the common data types like integers, floats, strings, dates, and times that have direct mappings between Go and SQL. In addition, pgx uses the github.com/jackc/pgtype library to support more types. See documention for that library for instructions on how to implement custom types. See example_custom_type_test.go for an example of a custom type for the PostgreSQL point type. pgx also includes support for custom types implementing the database/sql.Scanner and database/sql/driver.Valuer interfaces. If pgx does cannot natively encode a type and that type is a renamed type (e.g. type MyTime time.Time) pgx will attempt to encode the underlying type. While this is usually desired behavior it can produce surprising behavior if one the underlying type and the renamed type each implement database/sql interfaces and the other implements pgx interfaces. It is recommended that this situation be avoided by implementing pgx interfaces on the renamed type. Composite types and row values Row values and composite types are represented as pgtype.Record (https://pkg.go.dev/github.com/jackc/pgtype?tab=doc#Record). It is possible to get values of your custom type by implementing DecodeBinary interface. Decoding into pgtype.Record first can simplify process by avoiding dealing with raw protocol directly. 
For example: type MyType struct { a int // NULL will cause decoding error b *string // there can be NULL in this position in SQL } func (t *MyType) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { r := pgtype.Record{ Fields: []pgtype.Value{&pgtype.Int4{}, &pgtype.Text{}}, } if err := r.DecodeBinary(ci, src); err != nil { return err } if r.Status != pgtype.Present { return errors.New("BUG: decoding should not be called on NULL value") } a := r.Fields[0].(*pgtype.Int4) b := r.Fields[1].(*pgtype.Text) // type compatibility is checked by AssignTo // only lossless assignments will succeed if err := a.AssignTo(&t.a); err != nil { return err } // AssignTo also deals with null value handling if err := b.AssignTo(&t.b); err != nil { return err } return nil } result := MyType{} err := conn.QueryRow(context.Background(), "select row(1, 'foo'::text)", pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(&r) Raw Bytes Mapping []byte passed as arguments to Query, QueryRow, and Exec are passed unmodified to PostgreSQL. Transactions Transactions are started by calling Begin. tx, err := conn.Begin(context.Background()) if err != nil { return err } // Rollback is safe to call even if the tx is already closed, so if // the tx commits successfully, this is a no-op defer tx.Rollback(context.Background()) _, err = tx.Exec(context.Background(), "insert into foo(id) values (1)") if err != nil { return err } err = tx.Commit(context.Background()) if err != nil { return err } The Tx returned from Begin also implements the Begin method. This can be used to implement pseudo nested transactions. These are internally implemented with savepoints. Use BeginTx to control the transaction mode. BeginFunc and BeginTxFunc are variants that begin a transaction, execute a function, and commit or rollback the transaction depending on the return value of the function. These can be simpler and less error prone to use. 
err = conn.BeginFunc(context.Background(), func(tx pgx.Tx) error { _, err := tx.Exec(context.Background(), "insert into foo(id) values (1)") return err }) if err != nil { return err } Prepared Statements Prepared statements can be manually created with the Prepare method. However, this is rarely necessary because pgx includes an automatic statement cache by default. Queries run through the normal Query, QueryRow, and Exec functions are automatically prepared on first execution and the prepared statement is reused on subsequent executions. See ParseConfig for information on how to customize or disable the statement cache. Copy Protocol Use CopyFrom to efficiently insert multiple rows at a time using the PostgreSQL copy protocol. CopyFrom accepts a CopyFromSource interface. If the data is already in a [][]interface{} use CopyFromRows to wrap it in a CopyFromSource interface. Or implement CopyFromSource to avoid buffering the entire data set in memory. rows := [][]interface{}{ {"John", "Smith", int32(36)}, {"Jane", "Doe", int32(29)}, } copyCount, err := conn.CopyFrom( context.Background(), pgx.Identifier{"people"}, []string{"first_name", "last_name", "age"}, pgx.CopyFromRows(rows), ) When you already have a typed array using CopyFromSlice can be more convenient. rows := []User{ {"John", "Smith", 36}, {"Jane", "Doe", 29}, } copyCount, err := conn.CopyFrom( context.Background(), pgx.Identifier{"people"}, []string{"first_name", "last_name", "age"}, pgx.CopyFromSlice(len(rows), func(i int) ([]interface{}, error) { return []interface{}{rows[i].FirstName, rows[i].LastName, rows[i].Age}, nil }), ) CopyFrom can be faster than an insert with as few as 5 rows. Listen and Notify pgx can listen to the PostgreSQL notification system with the `Conn.WaitForNotification` method. It blocks until a notification is received or the context is canceled. 
_, err := conn.Exec(context.Background(), "listen channelname") if err != nil { return nil } if notification, err := conn.WaitForNotification(context.Background()); err != nil { // do something with notification } Logging pgx defines a simple logger interface. Connections optionally accept a logger that satisfies this interface. Set LogLevel to control logging verbosity. Adapters for github.com/inconshreveable/log15, github.com/sirupsen/logrus, go.uber.org/zap, github.com/rs/zerolog, and the testing log are provided in the log directory. Lower Level PostgreSQL Functionality pgx is implemented on top of github.com/jackc/pgconn a lower level PostgreSQL driver. The Conn.PgConn() method can be used to access this lower layer. PgBouncer pgx is compatible with PgBouncer in two modes. One is when the connection has a statement cache in "describe" mode. The other is when the connection is using the simple protocol. This can be set with the PreferSimpleProtocol config option. */ package pgx pgx-4.18.1/example_custom_type_test.go000066400000000000000000000045741437725773200201120ustar00rootroot00000000000000package pgx_test import ( "context" "fmt" "os" "regexp" "strconv" "github.com/jackc/pgtype" "github.com/jackc/pgx/v4" ) var pointRegexp *regexp.Regexp = regexp.MustCompile(`^\((.*),(.*)\)$`) // Point represents a point that may be null. 
type Point struct { X, Y float64 // Coordinates of point Status pgtype.Status } func (dst *Point) Set(src interface{}) error { return fmt.Errorf("cannot convert %v to Point", src) } func (dst *Point) Get() interface{} { switch dst.Status { case pgtype.Present: return dst case pgtype.Null: return nil default: return dst.Status } } func (src *Point) AssignTo(dst interface{}) error { return fmt.Errorf("cannot assign %v to %T", src, dst) } func (dst *Point) DecodeText(ci *pgtype.ConnInfo, src []byte) error { if src == nil { *dst = Point{Status: pgtype.Null} return nil } s := string(src) match := pointRegexp.FindStringSubmatch(s) if match == nil { return fmt.Errorf("Received invalid point: %v", s) } x, err := strconv.ParseFloat(match[1], 64) if err != nil { return fmt.Errorf("Received invalid point: %v", s) } y, err := strconv.ParseFloat(match[2], 64) if err != nil { return fmt.Errorf("Received invalid point: %v", s) } *dst = Point{X: x, Y: y, Status: pgtype.Present} return nil } func (src *Point) String() string { if src.Status == pgtype.Null { return "null point" } return fmt.Sprintf("%.1f, %.1f", src.X, src.Y) } func Example_CustomType() { conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) if err != nil { fmt.Printf("Unable to establish connection: %v", err) return } defer conn.Close(context.Background()) if conn.PgConn().ParameterStatus("crdb_version") != "" { // Skip test / example when running on CockroachDB which doesn't support the point type. Since an example can't be // skipped fake success instead. 
fmt.Println("null point") fmt.Println("1.5, 2.5") return } // Override registered handler for point conn.ConnInfo().RegisterDataType(pgtype.DataType{ Value: &Point{}, Name: "point", OID: 600, }) p := &Point{} err = conn.QueryRow(context.Background(), "select null::point").Scan(p) if err != nil { fmt.Println(err) return } fmt.Println(p) err = conn.QueryRow(context.Background(), "select point(1.5,2.5)").Scan(p) if err != nil { fmt.Println(err) return } fmt.Println(p) // Output: // null point // 1.5, 2.5 } pgx-4.18.1/example_json_test.go000066400000000000000000000011511437725773200164740ustar00rootroot00000000000000package pgx_test import ( "context" "fmt" "os" "github.com/jackc/pgx/v4" ) func Example_JSON() { conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) if err != nil { fmt.Printf("Unable to establish connection: %v", err) return } type person struct { Name string `json:"name"` Age int `json:"age"` } input := person{ Name: "John", Age: 42, } var output person err = conn.QueryRow(context.Background(), "select $1::json", input).Scan(&output) if err != nil { fmt.Println(err) return } fmt.Println(output.Name, output.Age) // Output: // John 42 } pgx-4.18.1/examples/000077500000000000000000000000001437725773200142425ustar00rootroot00000000000000pgx-4.18.1/examples/README.md000066400000000000000000000005721437725773200155250ustar00rootroot00000000000000# Examples * chat is a command line chat program using listen/notify. * todo is a command line todo list that demonstrates basic CRUD actions. * url_shortener contains a simple example of using pgx in a web context. * [Tern](https://github.com/jackc/tern) is a migration tool that uses pgx. * [The Pithy Reader](https://github.com/jackc/tpr) is a RSS aggregator that uses pgx. 
pgx-4.18.1/examples/chat/000077500000000000000000000000001437725773200151615ustar00rootroot00000000000000pgx-4.18.1/examples/chat/README.md000066400000000000000000000010471437725773200164420ustar00rootroot00000000000000# Description This is a sample chat program implemented using PostgreSQL's listen/notify functionality with pgx. Start multiple instances of this program connected to the same database to chat between them. ## Connection configuration The database connection is configured via DATABASE_URL and standard PostgreSQL environment variables (PGHOST, PGUSER, etc.) You can either export them then run chat: export PGHOST=/private/tmp ./chat Or you can prefix the chat execution with the environment variables: PGHOST=/private/tmp ./chat pgx-4.18.1/examples/chat/main.go000066400000000000000000000030361437725773200164360ustar00rootroot00000000000000package main import ( "bufio" "context" "fmt" "os" "github.com/jackc/pgx/v4/pgxpool" ) var pool *pgxpool.Pool func main() { var err error pool, err = pgxpool.Connect(context.Background(), os.Getenv("DATABASE_URL")) if err != nil { fmt.Fprintln(os.Stderr, "Unable to connect to database:", err) os.Exit(1) } go listen() fmt.Println(`Type a message and press enter. This message should appear in any other chat instances connected to the same database. 
Type "exit" to quit.`) scanner := bufio.NewScanner(os.Stdin) for scanner.Scan() { msg := scanner.Text() if msg == "exit" { os.Exit(0) } _, err = pool.Exec(context.Background(), "select pg_notify('chat', $1)", msg) if err != nil { fmt.Fprintln(os.Stderr, "Error sending notification:", err) os.Exit(1) } } if err := scanner.Err(); err != nil { fmt.Fprintln(os.Stderr, "Error scanning from stdin:", err) os.Exit(1) } } func listen() { conn, err := pool.Acquire(context.Background()) if err != nil { fmt.Fprintln(os.Stderr, "Error acquiring connection:", err) os.Exit(1) } defer conn.Release() _, err = conn.Exec(context.Background(), "listen chat") if err != nil { fmt.Fprintln(os.Stderr, "Error listening to chat channel:", err) os.Exit(1) } for { notification, err := conn.Conn().WaitForNotification(context.Background()) if err != nil { fmt.Fprintln(os.Stderr, "Error waiting for notification:", err) os.Exit(1) } fmt.Println("PID:", notification.PID, "Channel:", notification.Channel, "Payload:", notification.Payload) } } pgx-4.18.1/examples/todo/000077500000000000000000000000001437725773200152075ustar00rootroot00000000000000pgx-4.18.1/examples/todo/README.md000066400000000000000000000035141437725773200164710ustar00rootroot00000000000000# Description This is a sample todo list implemented using pgx as the connector to a PostgreSQL data store. # Usage Create a PostgreSQL database and run structure.sql into it to create the necessary data schema. Example: createdb todo psql todo < structure.sql Build todo: go build ## Connection configuration The database connection is configured via DATABASE_URL and standard PostgreSQL environment variables (PGHOST, PGUSER, etc.) 
You can either export them then run todo: export PGDATABASE=todo ./todo list Or you can prefix the todo execution with the environment variables: PGDATABASE=todo ./todo list ## Add a todo item ./todo add 'Learn go' ## List tasks ./todo list ## Update a task ./todo update 1 'Learn more go' ## Delete a task ./todo remove 1 # Example Setup and Execution jack@hk-47~/dev/go/src/github.com/jackc/pgx/examples/todo$ createdb todo jack@hk-47~/dev/go/src/github.com/jackc/pgx/examples/todo$ psql todo < structure.sql Expanded display is used automatically. Timing is on. CREATE TABLE Time: 6.363 ms jack@hk-47~/dev/go/src/github.com/jackc/pgx/examples/todo$ go build jack@hk-47~/dev/go/src/github.com/jackc/pgx/examples/todo$ export PGDATABASE=todo jack@hk-47~/dev/go/src/github.com/jackc/pgx/examples/todo$ ./todo list jack@hk-47~/dev/go/src/github.com/jackc/pgx/examples/todo$ ./todo add 'Learn Go' jack@hk-47~/dev/go/src/github.com/jackc/pgx/examples/todo$ ./todo list 1. Learn Go jack@hk-47~/dev/go/src/github.com/jackc/pgx/examples/todo$ ./todo update 1 'Learn more Go' jack@hk-47~/dev/go/src/github.com/jackc/pgx/examples/todo$ ./todo list 1. 
Learn more Go jack@hk-47~/dev/go/src/github.com/jackc/pgx/examples/todo$ ./todo remove 1 jack@hk-47~/dev/go/src/github.com/jackc/pgx/examples/todo$ ./todo list pgx-4.18.1/examples/todo/main.go000066400000000000000000000044331437725773200164660ustar00rootroot00000000000000package main import ( "context" "fmt" "os" "strconv" "github.com/jackc/pgx/v4" ) var conn *pgx.Conn func main() { var err error conn, err = pgx.Connect(context.Background(), os.Getenv("DATABASE_URL")) if err != nil { fmt.Fprintf(os.Stderr, "Unable to connection to database: %v\n", err) os.Exit(1) } if len(os.Args) == 1 { printHelp() os.Exit(0) } switch os.Args[1] { case "list": err = listTasks() if err != nil { fmt.Fprintf(os.Stderr, "Unable to list tasks: %v\n", err) os.Exit(1) } case "add": err = addTask(os.Args[2]) if err != nil { fmt.Fprintf(os.Stderr, "Unable to add task: %v\n", err) os.Exit(1) } case "update": n, err := strconv.ParseInt(os.Args[2], 10, 32) if err != nil { fmt.Fprintf(os.Stderr, "Unable convert task_num into int32: %v\n", err) os.Exit(1) } err = updateTask(int32(n), os.Args[3]) if err != nil { fmt.Fprintf(os.Stderr, "Unable to update task: %v\n", err) os.Exit(1) } case "remove": n, err := strconv.ParseInt(os.Args[2], 10, 32) if err != nil { fmt.Fprintf(os.Stderr, "Unable convert task_num into int32: %v\n", err) os.Exit(1) } err = removeTask(int32(n)) if err != nil { fmt.Fprintf(os.Stderr, "Unable to remove task: %v\n", err) os.Exit(1) } default: fmt.Fprintln(os.Stderr, "Invalid command") printHelp() os.Exit(1) } } func listTasks() error { rows, _ := conn.Query(context.Background(), "select * from tasks") for rows.Next() { var id int32 var description string err := rows.Scan(&id, &description) if err != nil { return err } fmt.Printf("%d. 
%s\n", id, description) } return rows.Err() } func addTask(description string) error { _, err := conn.Exec(context.Background(), "insert into tasks(description) values($1)", description) return err } func updateTask(itemNum int32, description string) error { _, err := conn.Exec(context.Background(), "update tasks set description=$1 where id=$2", description, itemNum) return err } func removeTask(itemNum int32) error { _, err := conn.Exec(context.Background(), "delete from tasks where id=$1", itemNum) return err } func printHelp() { fmt.Print(`Todo pgx demo Usage: todo list todo add task todo update task_num item todo remove task_num Example: todo add 'Learn Go' todo list `) } pgx-4.18.1/examples/todo/structure.sql000066400000000000000000000001151437725773200177650ustar00rootroot00000000000000create table tasks ( id serial primary key, description text not null ); pgx-4.18.1/examples/url_shortener/000077500000000000000000000000001437725773200171355ustar00rootroot00000000000000pgx-4.18.1/examples/url_shortener/README.md000066400000000000000000000012021437725773200204070ustar00rootroot00000000000000# Description This is a sample REST URL shortener service implemented using pgx as the connector to a PostgreSQL data store. # Usage Create a PostgreSQL database and run structure.sql into it to create the necessary data schema. 
Configure the database connection with `DATABASE_URL` or standard PostgreSQL (`PG*`) environment variables or Run main.go: ``` go run main.go ``` ## Create or Update a Shortened URL ``` curl -X PUT -d 'http://www.google.com' http://localhost:8080/google ``` ## Get a Shortened URL ``` curl http://localhost:8080/google ``` ## Delete a Shortened URL ``` curl -X DELETE http://localhost:8080/google ``` pgx-4.18.1/examples/url_shortener/main.go000066400000000000000000000046521437725773200204170ustar00rootroot00000000000000package main import ( "context" "io/ioutil" "net/http" "os" "github.com/jackc/pgx/v4" "github.com/jackc/pgx/v4/log/log15adapter" "github.com/jackc/pgx/v4/pgxpool" log "gopkg.in/inconshreveable/log15.v2" ) var db *pgxpool.Pool func getUrlHandler(w http.ResponseWriter, req *http.Request) { var url string err := db.QueryRow(context.Background(), "select url from shortened_urls where id=$1", req.URL.Path).Scan(&url) switch err { case nil: http.Redirect(w, req, url, http.StatusSeeOther) case pgx.ErrNoRows: http.NotFound(w, req) default: http.Error(w, "Internal server error", http.StatusInternalServerError) } } func putUrlHandler(w http.ResponseWriter, req *http.Request) { id := req.URL.Path var url string if body, err := ioutil.ReadAll(req.Body); err == nil { url = string(body) } else { http.Error(w, "Internal server error", http.StatusInternalServerError) return } if _, err := db.Exec(context.Background(), `insert into shortened_urls(id, url) values ($1, $2) on conflict (id) do update set url=excluded.url`, id, url); err == nil { w.WriteHeader(http.StatusOK) } else { http.Error(w, "Internal server error", http.StatusInternalServerError) } } func deleteUrlHandler(w http.ResponseWriter, req *http.Request) { if _, err := db.Exec(context.Background(), "delete from shortened_urls where id=$1", req.URL.Path); err == nil { w.WriteHeader(http.StatusOK) } else { http.Error(w, "Internal server error", http.StatusInternalServerError) } } func urlHandler(w 
http.ResponseWriter, req *http.Request) { switch req.Method { case "GET": getUrlHandler(w, req) case "PUT": putUrlHandler(w, req) case "DELETE": deleteUrlHandler(w, req) default: w.Header().Add("Allow", "GET, PUT, DELETE") w.WriteHeader(http.StatusMethodNotAllowed) } } func main() { logger := log15adapter.NewLogger(log.New("module", "pgx")) poolConfig, err := pgxpool.ParseConfig(os.Getenv("DATABASE_URL")) if err != nil { log.Crit("Unable to parse DATABASE_URL", "error", err) os.Exit(1) } poolConfig.ConnConfig.Logger = logger db, err = pgxpool.ConnectConfig(context.Background(), poolConfig) if err != nil { log.Crit("Unable to create connection pool", "error", err) os.Exit(1) } http.HandleFunc("/", urlHandler) log.Info("Starting URL shortener on localhost:8080") err = http.ListenAndServe("localhost:8080", nil) if err != nil { log.Crit("Unable to start web server", "error", err) os.Exit(1) } } pgx-4.18.1/examples/url_shortener/structure.sql000066400000000000000000000001131437725773200217110ustar00rootroot00000000000000create table shortened_urls ( id text primary key, url text not null );pgx-4.18.1/extended_query_builder.go000066400000000000000000000075171437725773200175200ustar00rootroot00000000000000package pgx import ( "database/sql/driver" "fmt" "reflect" "github.com/jackc/pgtype" ) type extendedQueryBuilder struct { paramValues [][]byte paramValueBytes []byte paramFormats []int16 resultFormats []int16 } func (eqb *extendedQueryBuilder) AppendParam(ci *pgtype.ConnInfo, oid uint32, arg interface{}) error { f := chooseParameterFormatCode(ci, oid, arg) eqb.paramFormats = append(eqb.paramFormats, f) v, err := eqb.encodeExtendedParamValue(ci, oid, f, arg) if err != nil { return err } eqb.paramValues = append(eqb.paramValues, v) return nil } func (eqb *extendedQueryBuilder) AppendResultFormat(f int16) { eqb.resultFormats = append(eqb.resultFormats, f) } // Reset readies eqb to build another query. 
func (eqb *extendedQueryBuilder) Reset() { eqb.paramValues = eqb.paramValues[0:0] eqb.paramValueBytes = eqb.paramValueBytes[0:0] eqb.paramFormats = eqb.paramFormats[0:0] eqb.resultFormats = eqb.resultFormats[0:0] if cap(eqb.paramValues) > 64 { eqb.paramValues = make([][]byte, 0, 64) } if cap(eqb.paramValueBytes) > 256 { eqb.paramValueBytes = make([]byte, 0, 256) } if cap(eqb.paramFormats) > 64 { eqb.paramFormats = make([]int16, 0, 64) } if cap(eqb.resultFormats) > 64 { eqb.resultFormats = make([]int16, 0, 64) } } func (eqb *extendedQueryBuilder) encodeExtendedParamValue(ci *pgtype.ConnInfo, oid uint32, formatCode int16, arg interface{}) ([]byte, error) { if arg == nil { return nil, nil } refVal := reflect.ValueOf(arg) argIsPtr := refVal.Kind() == reflect.Ptr if argIsPtr && refVal.IsNil() { return nil, nil } if eqb.paramValueBytes == nil { eqb.paramValueBytes = make([]byte, 0, 128) } var err error var buf []byte pos := len(eqb.paramValueBytes) if arg, ok := arg.(string); ok { return []byte(arg), nil } if formatCode == TextFormatCode { if arg, ok := arg.(pgtype.TextEncoder); ok { buf, err = arg.EncodeText(ci, eqb.paramValueBytes) if err != nil { return nil, err } if buf == nil { return nil, nil } eqb.paramValueBytes = buf return eqb.paramValueBytes[pos:], nil } } else if formatCode == BinaryFormatCode { if arg, ok := arg.(pgtype.BinaryEncoder); ok { buf, err = arg.EncodeBinary(ci, eqb.paramValueBytes) if err != nil { return nil, err } if buf == nil { return nil, nil } eqb.paramValueBytes = buf return eqb.paramValueBytes[pos:], nil } } if argIsPtr { // We have already checked that arg is not pointing to nil, // so it is safe to dereference here. 
arg = refVal.Elem().Interface() return eqb.encodeExtendedParamValue(ci, oid, formatCode, arg) } if dt, ok := ci.DataTypeForOID(oid); ok { value := dt.Value err := value.Set(arg) if err != nil { { if arg, ok := arg.(driver.Valuer); ok { v, err := callValuerValue(arg) if err != nil { return nil, err } return eqb.encodeExtendedParamValue(ci, oid, formatCode, v) } } return nil, err } return eqb.encodeExtendedParamValue(ci, oid, formatCode, value) } // There is no data type registered for the destination OID, but maybe there is data type registered for the arg // type. If so use it's text encoder (if available). if dt, ok := ci.DataTypeForValue(arg); ok { value := dt.Value if textEncoder, ok := value.(pgtype.TextEncoder); ok { err := value.Set(arg) if err != nil { return nil, err } buf, err = textEncoder.EncodeText(ci, eqb.paramValueBytes) if err != nil { return nil, err } if buf == nil { return nil, nil } eqb.paramValueBytes = buf return eqb.paramValueBytes[pos:], nil } } if strippedArg, ok := stripNamedType(&refVal); ok { return eqb.encodeExtendedParamValue(ci, oid, formatCode, strippedArg) } return nil, SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg)) } pgx-4.18.1/go.mod000066400000000000000000000011321437725773200135270ustar00rootroot00000000000000module github.com/jackc/pgx/v4 go 1.13 require ( github.com/Masterminds/semver/v3 v3.1.1 github.com/cockroachdb/apd v1.1.0 github.com/go-kit/log v0.1.0 github.com/gofrs/uuid v4.0.0+incompatible github.com/jackc/pgconn v1.14.0 github.com/jackc/pgio v1.0.0 github.com/jackc/pgproto3/v2 v2.3.2 github.com/jackc/pgtype v1.14.0 github.com/jackc/puddle v1.3.0 github.com/rs/zerolog v1.15.0 github.com/shopspring/decimal v1.2.0 github.com/sirupsen/logrus v1.4.2 github.com/stretchr/testify v1.8.1 go.uber.org/zap v1.13.0 gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec ) 
pgx-4.18.1/go.sum000066400000000000000000000540361437725773200135670ustar00rootroot00000000000000github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/Masterminds/semver/v3 v3.1.1 h1:hLg3sBzpNErnxhQtUy/mmLR2I9foDujNK030IGemrRc= github.com/Masterminds/semver/v3 v3.1.1/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0cBrbBpGY/8hQs= github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I= github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/go-kit/log v0.1.0 h1:DGJh0Sm43HbOeYDNnVZFl8BvcYVvjD5bqYJvp0REbwQ= github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= github.com/go-logfmt/logfmt v0.5.0 h1:TrB8swr/68K7m9CcGut2g3UOihhbcbiMAYiuTXdEih4= github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw= github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= 
github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8= github.com/jackc/chunkreader/v2 v2.0.1/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3/go.mod h1:jkELnwuX+w9qN5YIfX0fl88Ehu4XC3keFuOJJk9pcnA= github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb/go.mod h1:lLjNuW/+OfW9/pnVKPazfWOgNfH2aPem8YQ7ilXGvJE= github.com/jackc/pgconn v0.0.0-20190831204454-2fabfa3c18b7/go.mod h1:ZJKsE/KZfsUgOEh9hBm+xYTstcNHg7UPMVJqRfQxq4s= github.com/jackc/pgconn v1.8.0/go.mod h1:1C2Pb36bGIP9QHGBYCjnyhqu7Rv3sGshaQUvmfGIB/o= github.com/jackc/pgconn v1.9.0/go.mod h1:YctiPyvzfU11JFxoXokUOOKQXQmDMoJL9vJzHH8/2JY= github.com/jackc/pgconn v1.9.1-0.20210724152538-d89c8390a530/go.mod h1:4z2w8XhRbP1hYxkpTuBjTS3ne3J48K83+u0zoyvg2pI= github.com/jackc/pgconn v1.13.0 h1:3L1XMNV2Zvca/8BYhzcRFS70Lr0WlDg16Di6SFGAbys= github.com/jackc/pgconn v1.13.0/go.mod h1:AnowpAqO4CMIIJNZl2VJp+KrkAZciAkhEl0W0JIobpI= github.com/jackc/pgconn v1.14.0 h1:vrbA9Ud87g6JdFWkHTJXppVce58qPIdP7N8y0Ml/A7Q= github.com/jackc/pgconn v1.14.0/go.mod h1:9mBNlny0UvkgJdCDvdVHYSjI+8tD2rnKK69Wz8ti++E= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd/go.mod h1:hrBW0Enj2AZTNpt/7Y5rr2xe/9Mn757Wtb2xeBzPv2c= github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65 h1:DadwsjnMwFjfWc9y5Wi/+Zz7xoE5ALHsRQlOctkOiHc= github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65/go.mod 
h1:5R2h2EEX+qri8jOWMbJCtaPWkrrNc7OHwsp2TCqp7ak= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A= github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= github.com/jackc/pgproto3/v2 v2.0.6/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgproto3/v2 v2.1.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgproto3/v2 v2.3.1 h1:nwj7qwf0S+Q7ISFfBndqeLwSwxs+4DPsbRFjECT1Y4Y= github.com/jackc/pgproto3/v2 v2.3.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgproto3/v2 v2.3.2 h1:7eY55bdBeCz1F2fTzSz69QC+pG46jYq9/jtSPiJ5nn0= github.com/jackc/pgproto3/v2 v2.3.2/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= 
github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc= github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw= github.com/jackc/pgtype v1.8.1-0.20210724151600-32e20a603178/go.mod h1:C516IlIV9NKqfsMCXTdChteoXmwgUceqaLfjg2e3NlM= github.com/jackc/pgtype v1.12.0 h1:Dlq8Qvcch7kiehm8wPGIW0W3KsCCHJnRacKW0UM8n5w= github.com/jackc/pgtype v1.12.0/go.mod h1:LUMuVrfsFfdKGLw+AFFVv6KtHOFMwRgDDzBt76IqCA4= github.com/jackc/pgtype v1.14.0 h1:y+xUdabmyMkJLyApYuPj38mW+aAIqCe5uuBB51rH3Vw= github.com/jackc/pgtype v1.14.0/go.mod h1:LUMuVrfsFfdKGLw+AFFVv6KtHOFMwRgDDzBt76IqCA4= github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y= github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9WuGR0JG/JseM9irFbnEPbuWV2EELPNuM= github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc= github.com/jackc/pgx/v4 v4.12.1-0.20210724153913-640aa07df17c/go.mod h1:1QD0+tgSXP7iUjYm9C1NxKhny7lq6ee99u/z+IHFcgs= github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v1.2.1 h1:gI8os0wpRXFd4FiAY2dWiqRK037tjj3t7rKFeO4X5iw= github.com/jackc/puddle v1.2.1/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v1.3.0 h1:eHK/5clGOatcjX3oWGBO/MpxpbHzSwud5EWTSCI+MX0= github.com/jackc/puddle v1.3.0/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= 
github.com/konsorten/go-windows-terminal-sequences v1.0.2 h1:DB17ag19krx9CFsz4o3enTrPXyIXCl+2iCXH/aMAp9s= github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.10.2 h1:AqzbZs4ZoCBp+GtejcpCpcxM3zlSMx29dXbUSeVtJb8= github.com/lib/pq v1.10.2/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= github.com/mattn/go-colorable v0.1.6 h1:6Su7aK7lXmJ/U79bYtBjLNaha4Fs1Rg9plHpcH+vvnE= github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHXY= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod 
h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU= github.com/rs/zerolog v1.15.0 h1:uPRuwkWF4J6fGsJ2R0Gn2jB1EQiav9k3S6CSdygQJXY= github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc= github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= github.com/shopspring/decimal v1.2.0 h1:abSATXmQEYyShuxI4/vyW3tV1MrKAJzCZ/0zLUXYbsQ= github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= go.uber.org/atomic v1.6.0 h1:Ezj3JGmsOnG1MoRWQkPBsKLe9DwWD9QeXzTRzzldNVk= go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= go.uber.org/multierr v1.3.0/go.mod h1:VgVr7evmIr6uPjLBxg28wmKNXyqE9akIJ5XnfpiKl+4= go.uber.org/multierr v1.5.0 h1:KCa4XfM8CWFCpxXRGok+Q0SS/0XBhMDbHHGABQLvD2A= go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKYU= go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee h1:0mgffUl7nfd+FpvXMVz4IDEaUSmT1ysygQC7qYo7sG4= go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9Ejo0C68/HhF8uaILCdgjnY+goOA= go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= go.uber.org/zap v1.13.0 h1:nR6NoDBgAf67s68NhaXbsojM+2gxp3S1hWkHDl27pVU= go.uber.org/zap v1.13.0/go.mod h1:zwrFLgMcdUuIBviXEYEH1YKNaOBnKXsx2IPda5bBwHM= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= 
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20201203163018-be400aefbc4c/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa h1:zuSxTR4o9y82ebqCUJYNGJbGPo6sKVl54f/TVDObg1c= golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.6.0 h1:qfktjS5LUO+fFKeJXZ+ikTRijMmljikvG68fpMMruSc= golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58= golang.org/x/lint v0.0.0-20190930215403-16217165b5de h1:5hukYrvBGR8/eNkX5mdUezrA6JiaEZDtJb9Ei+1LlBs= golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod 
h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys 
v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1 h1:SrN+KX8Art/Sf4HNj6Zcz06G7VEz+7w9tdXTPOZ7+l4= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 
golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= golang.org/x/tools v0.0.0-20190823170909-c4a336ef6a2f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200103221440-774c71fcf114 h1:DnSr2mCsxyCE6ZgIkmcWUQY2R5cH/6wL7eIxEmQOMSE= golang.org/x/tools v0.0.0-20200103221440-774c71fcf114/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.1.12 h1:VveCTK38A2rkS8ZqFY25HIDFscX5X9OoEhJd3quQmXU= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod 
h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec h1:RlWgLqCMMIYYEVcAR5MDsuHlVkaIPDAF+5Dehzg8L5A= gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:aPpfJ7XW+gOuirDoZ8gHhLh3kZ1B08FtV2bbmy7Jv3s= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= honnef.co/go/tools v0.0.1-2019.2.3 h1:3JgtbtFHMiCmsznwGVTUWbgGov+pVqnlf1dEJTNAXeM= honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= pgx-4.18.1/go_stdlib.go000066400000000000000000000047461437725773200147340ustar00rootroot00000000000000package pgx import ( "database/sql/driver" "reflect" ) // This file contains code copied from the Go standard library due to the // required function not being public. // Copyright (c) 2009 The Go Authors. All rights reserved. // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions are // met: // * Redistributions of source code must retain the above copyright // notice, this list of conditions and the following disclaimer. // * Redistributions in binary form must reproduce the above // copyright notice, this list of conditions and the following disclaimer // in the documentation and/or other materials provided with the // distribution. // * Neither the name of Google Inc. nor the names of its // contributors may be used to endorse or promote products derived from // this software without specific prior written permission. 
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // From database/sql/convert.go var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem() // callValuerValue returns vr.Value(), with one exception: // If vr.Value is an auto-generated method on a pointer type and the // pointer is nil, it would panic at runtime in the panicwrap // method. Treat it like nil instead. // Issue 8415. // // This is so people can implement driver.Value on value types and // still use nil pointers to those types to mean nil/NULL, just like // string/*string. // // This function is mirrored in the database/sql/driver package. 
func callValuerValue(vr driver.Valuer) (v driver.Value, err error) { if rv := reflect.ValueOf(vr); rv.Kind() == reflect.Ptr && rv.IsNil() && rv.Type().Elem().Implements(valuerReflectType) { return nil, nil } return vr.Value() } pgx-4.18.1/helper_test.go000066400000000000000000000132721437725773200152760ustar00rootroot00000000000000package pgx_test import ( "context" "os" "testing" "github.com/stretchr/testify/assert" "github.com/jackc/pgconn" "github.com/jackc/pgx/v4" "github.com/stretchr/testify/require" ) func testWithAndWithoutPreferSimpleProtocol(t *testing.T, f func(t *testing.T, conn *pgx.Conn)) { t.Run("SimpleProto", func(t *testing.T) { config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) config.PreferSimpleProtocol = true conn, err := pgx.ConnectConfig(context.Background(), config) require.NoError(t, err) defer func() { err := conn.Close(context.Background()) require.NoError(t, err) }() f(t, conn) ensureConnValid(t, conn) }, ) t.Run("DefaultProto", func(t *testing.T) { config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) conn, err := pgx.ConnectConfig(context.Background(), config) require.NoError(t, err) defer func() { err := conn.Close(context.Background()) require.NoError(t, err) }() f(t, conn) ensureConnValid(t, conn) }, ) } func mustConnectString(t testing.TB, connString string) *pgx.Conn { conn, err := pgx.Connect(context.Background(), connString) if err != nil { t.Fatalf("Unable to establish connection: %v", err) } return conn } func mustParseConfig(t testing.TB, connString string) *pgx.ConnConfig { config, err := pgx.ParseConfig(connString) require.Nil(t, err) return config } func mustConnect(t testing.TB, config *pgx.ConnConfig) *pgx.Conn { conn, err := pgx.ConnectConfig(context.Background(), config) if err != nil { t.Fatalf("Unable to establish connection: %v", err) } return conn } func closeConn(t testing.TB, conn *pgx.Conn) { err := conn.Close(context.Background()) if err != nil 
{ t.Fatalf("conn.Close unexpectedly failed: %v", err) } } func mustExec(t testing.TB, conn *pgx.Conn, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag) { var err error if commandTag, err = conn.Exec(context.Background(), sql, arguments...); err != nil { t.Fatalf("Exec unexpectedly failed with %v: %v", sql, err) } return } // Do a simple query to ensure the connection is still usable func ensureConnValid(t *testing.T, conn *pgx.Conn) { var sum, rowCount int32 rows, err := conn.Query(context.Background(), "select generate_series(1,$1)", 10) if err != nil { t.Fatalf("conn.Query failed: %v", err) } defer rows.Close() for rows.Next() { var n int32 rows.Scan(&n) sum += n rowCount++ } if rows.Err() != nil { t.Fatalf("conn.Query failed: %v", err) } if rowCount != 10 { t.Error("Select called onDataRow wrong number of times") } if sum != 55 { t.Error("Wrong values returned") } } func assertConfigsEqual(t *testing.T, expected, actual *pgx.ConnConfig, testName string) { if !assert.NotNil(t, expected) { return } if !assert.NotNil(t, actual) { return } assert.Equalf(t, expected.Logger, actual.Logger, "%s - Logger", testName) assert.Equalf(t, expected.LogLevel, actual.LogLevel, "%s - LogLevel", testName) assert.Equalf(t, expected.ConnString(), actual.ConnString(), "%s - ConnString", testName) // Can't test function equality, so just test that they are set or not. 
assert.Equalf(t, expected.BuildStatementCache == nil, actual.BuildStatementCache == nil, "%s - BuildStatementCache", testName) assert.Equalf(t, expected.PreferSimpleProtocol, actual.PreferSimpleProtocol, "%s - PreferSimpleProtocol", testName) assert.Equalf(t, expected.Host, actual.Host, "%s - Host", testName) assert.Equalf(t, expected.Database, actual.Database, "%s - Database", testName) assert.Equalf(t, expected.Port, actual.Port, "%s - Port", testName) assert.Equalf(t, expected.User, actual.User, "%s - User", testName) assert.Equalf(t, expected.Password, actual.Password, "%s - Password", testName) assert.Equalf(t, expected.ConnectTimeout, actual.ConnectTimeout, "%s - ConnectTimeout", testName) assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName) // Can't test function equality, so just test that they are set or not. assert.Equalf(t, expected.ValidateConnect == nil, actual.ValidateConnect == nil, "%s - ValidateConnect", testName) assert.Equalf(t, expected.AfterConnect == nil, actual.AfterConnect == nil, "%s - AfterConnect", testName) if assert.Equalf(t, expected.TLSConfig == nil, actual.TLSConfig == nil, "%s - TLSConfig", testName) { if expected.TLSConfig != nil { assert.Equalf(t, expected.TLSConfig.InsecureSkipVerify, actual.TLSConfig.InsecureSkipVerify, "%s - TLSConfig InsecureSkipVerify", testName) assert.Equalf(t, expected.TLSConfig.ServerName, actual.TLSConfig.ServerName, "%s - TLSConfig ServerName", testName) } } if assert.Equalf(t, len(expected.Fallbacks), len(actual.Fallbacks), "%s - Fallbacks", testName) { for i := range expected.Fallbacks { assert.Equalf(t, expected.Fallbacks[i].Host, actual.Fallbacks[i].Host, "%s - Fallback %d - Host", testName, i) assert.Equalf(t, expected.Fallbacks[i].Port, actual.Fallbacks[i].Port, "%s - Fallback %d - Port", testName, i) if assert.Equalf(t, expected.Fallbacks[i].TLSConfig == nil, actual.Fallbacks[i].TLSConfig == nil, "%s - Fallback %d - TLSConfig", testName, i) { if 
expected.Fallbacks[i].TLSConfig != nil { assert.Equalf(t, expected.Fallbacks[i].TLSConfig.InsecureSkipVerify, actual.Fallbacks[i].TLSConfig.InsecureSkipVerify, "%s - Fallback %d - TLSConfig InsecureSkipVerify", testName) assert.Equalf(t, expected.Fallbacks[i].TLSConfig.ServerName, actual.Fallbacks[i].TLSConfig.ServerName, "%s - Fallback %d - TLSConfig ServerName", testName) } } } } } func skipCockroachDB(t testing.TB, conn *pgx.Conn, msg string) { if conn.PgConn().ParameterStatus("crdb_version") != "" { t.Skip(msg) } } pgx-4.18.1/internal/000077500000000000000000000000001437725773200142405ustar00rootroot00000000000000pgx-4.18.1/internal/sanitize/000077500000000000000000000000001437725773200160665ustar00rootroot00000000000000pgx-4.18.1/internal/sanitize/sanitize.go000066400000000000000000000152361437725773200202520ustar00rootroot00000000000000package sanitize import ( "bytes" "encoding/hex" "fmt" "strconv" "strings" "time" "unicode/utf8" ) // Part is either a string or an int. A string is raw SQL. An int is a // argument placeholder. type Part interface{} type Query struct { Parts []Part } // utf.DecodeRune returns the utf8.RuneError for errors. But that is actually rune U+FFFD -- the unicode replacement // character. utf8.RuneError is not an error if it is also width 3. 
// // https://github.com/jackc/pgx/issues/1380 const replacementcharacterwidth = 3 func (q *Query) Sanitize(args ...interface{}) (string, error) { argUse := make([]bool, len(args)) buf := &bytes.Buffer{} for _, part := range q.Parts { var str string switch part := part.(type) { case string: str = part case int: argIdx := part - 1 if argIdx >= len(args) { return "", fmt.Errorf("insufficient arguments") } arg := args[argIdx] switch arg := arg.(type) { case nil: str = "null" case int64: str = strconv.FormatInt(arg, 10) case float64: str = strconv.FormatFloat(arg, 'f', -1, 64) case bool: str = strconv.FormatBool(arg) case []byte: str = QuoteBytes(arg) case string: str = QuoteString(arg) case time.Time: str = arg.Truncate(time.Microsecond).Format("'2006-01-02 15:04:05.999999999Z07:00:00'") default: return "", fmt.Errorf("invalid arg type: %T", arg) } argUse[argIdx] = true default: return "", fmt.Errorf("invalid Part type: %T", part) } buf.WriteString(str) } for i, used := range argUse { if !used { return "", fmt.Errorf("unused argument: %d", i) } } return buf.String(), nil } func NewQuery(sql string) (*Query, error) { l := &sqlLexer{ src: sql, stateFn: rawState, } for l.stateFn != nil { l.stateFn = l.stateFn(l) } query := &Query{Parts: l.parts} return query, nil } func QuoteString(str string) string { return "'" + strings.ReplaceAll(str, "'", "''") + "'" } func QuoteBytes(buf []byte) string { return `'\x` + hex.EncodeToString(buf) + "'" } type sqlLexer struct { src string start int pos int nested int // multiline comment nesting level. 
stateFn stateFn parts []Part } type stateFn func(*sqlLexer) stateFn func rawState(l *sqlLexer) stateFn { for { r, width := utf8.DecodeRuneInString(l.src[l.pos:]) l.pos += width switch r { case 'e', 'E': nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) if nextRune == '\'' { l.pos += width return escapeStringState } case '\'': return singleQuoteState case '"': return doubleQuoteState case '$': nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:]) if '0' <= nextRune && nextRune <= '9' { if l.pos-l.start > 0 { l.parts = append(l.parts, l.src[l.start:l.pos-width]) } l.start = l.pos return placeholderState } case '-': nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) if nextRune == '-' { l.pos += width return oneLineCommentState } case '/': nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) if nextRune == '*' { l.pos += width return multilineCommentState } case utf8.RuneError: if width != replacementcharacterwidth { if l.pos-l.start > 0 { l.parts = append(l.parts, l.src[l.start:l.pos]) l.start = l.pos } return nil } } } } func singleQuoteState(l *sqlLexer) stateFn { for { r, width := utf8.DecodeRuneInString(l.src[l.pos:]) l.pos += width switch r { case '\'': nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) if nextRune != '\'' { return rawState } l.pos += width case utf8.RuneError: if width != replacementcharacterwidth { if l.pos-l.start > 0 { l.parts = append(l.parts, l.src[l.start:l.pos]) l.start = l.pos } return nil } } } } func doubleQuoteState(l *sqlLexer) stateFn { for { r, width := utf8.DecodeRuneInString(l.src[l.pos:]) l.pos += width switch r { case '"': nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) if nextRune != '"' { return rawState } l.pos += width case utf8.RuneError: if width != replacementcharacterwidth { if l.pos-l.start > 0 { l.parts = append(l.parts, l.src[l.start:l.pos]) l.start = l.pos } return nil } } } } // placeholderState consumes a placeholder value. The $ must have already has // already been consumed. 
The first rune must be a digit. func placeholderState(l *sqlLexer) stateFn { num := 0 for { r, width := utf8.DecodeRuneInString(l.src[l.pos:]) l.pos += width if '0' <= r && r <= '9' { num *= 10 num += int(r - '0') } else { l.parts = append(l.parts, num) l.pos -= width l.start = l.pos return rawState } } } func escapeStringState(l *sqlLexer) stateFn { for { r, width := utf8.DecodeRuneInString(l.src[l.pos:]) l.pos += width switch r { case '\\': _, width = utf8.DecodeRuneInString(l.src[l.pos:]) l.pos += width case '\'': nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) if nextRune != '\'' { return rawState } l.pos += width case utf8.RuneError: if width != replacementcharacterwidth { if l.pos-l.start > 0 { l.parts = append(l.parts, l.src[l.start:l.pos]) l.start = l.pos } return nil } } } } func oneLineCommentState(l *sqlLexer) stateFn { for { r, width := utf8.DecodeRuneInString(l.src[l.pos:]) l.pos += width switch r { case '\\': _, width = utf8.DecodeRuneInString(l.src[l.pos:]) l.pos += width case '\n', '\r': return rawState case utf8.RuneError: if width != replacementcharacterwidth { if l.pos-l.start > 0 { l.parts = append(l.parts, l.src[l.start:l.pos]) l.start = l.pos } return nil } } } } func multilineCommentState(l *sqlLexer) stateFn { for { r, width := utf8.DecodeRuneInString(l.src[l.pos:]) l.pos += width switch r { case '/': nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) if nextRune == '*' { l.pos += width l.nested++ } case '*': nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) if nextRune != '/' { continue } l.pos += width if l.nested == 0 { return rawState } l.nested-- case utf8.RuneError: if width != replacementcharacterwidth { if l.pos-l.start > 0 { l.parts = append(l.parts, l.src[l.start:l.pos]) l.start = l.pos } return nil } } } } // SanitizeSQL replaces placeholder values with args. It quotes and escapes args // as necessary. This function is only safe when standard_conforming_strings is // on. 
func SanitizeSQL(sql string, args ...interface{}) (string, error) { query, err := NewQuery(sql) if err != nil { return "", err } return query.Sanitize(args...) } pgx-4.18.1/internal/sanitize/sanitize_test.go000066400000000000000000000146701437725773200213120ustar00rootroot00000000000000package sanitize_test import ( "testing" "time" "github.com/jackc/pgx/v4/internal/sanitize" ) func TestNewQuery(t *testing.T) { successTests := []struct { sql string expected sanitize.Query }{ { sql: "select 42", expected: sanitize.Query{Parts: []sanitize.Part{"select 42"}}, }, { sql: "select $1", expected: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, }, { sql: "select 'quoted $42', $1", expected: sanitize.Query{Parts: []sanitize.Part{"select 'quoted $42', ", 1}}, }, { sql: `select "doubled quoted $42", $1`, expected: sanitize.Query{Parts: []sanitize.Part{`select "doubled quoted $42", `, 1}}, }, { sql: "select 'foo''bar', $1", expected: sanitize.Query{Parts: []sanitize.Part{"select 'foo''bar', ", 1}}, }, { sql: `select "foo""bar", $1`, expected: sanitize.Query{Parts: []sanitize.Part{`select "foo""bar", `, 1}}, }, { sql: "select '''', $1", expected: sanitize.Query{Parts: []sanitize.Part{"select '''', ", 1}}, }, { sql: `select """", $1`, expected: sanitize.Query{Parts: []sanitize.Part{`select """", `, 1}}, }, { sql: "select $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11", expected: sanitize.Query{Parts: []sanitize.Part{"select ", 1, ", ", 2, ", ", 3, ", ", 4, ", ", 5, ", ", 6, ", ", 7, ", ", 8, ", ", 9, ", ", 10, ", ", 11}}, }, { sql: `select "adsf""$1""adsf", $1, 'foo''$$12bar', $2, '$3'`, expected: sanitize.Query{Parts: []sanitize.Part{`select "adsf""$1""adsf", `, 1, `, 'foo''$$12bar', `, 2, `, '$3'`}}, }, { sql: `select E'escape string\' $42', $1`, expected: sanitize.Query{Parts: []sanitize.Part{`select E'escape string\' $42', `, 1}}, }, { sql: `select e'escape string\' $42', $1`, expected: sanitize.Query{Parts: []sanitize.Part{`select e'escape string\' $42', `, 1}}, }, { 
sql: `select /* a baby's toy */ 'barbie', $1`, expected: sanitize.Query{Parts: []sanitize.Part{`select /* a baby's toy */ 'barbie', `, 1}}, }, { sql: `select /* *_* */ $1`, expected: sanitize.Query{Parts: []sanitize.Part{`select /* *_* */ `, 1}}, }, { sql: `select 42 /* /* /* 42 */ */ */, $1`, expected: sanitize.Query{Parts: []sanitize.Part{`select 42 /* /* /* 42 */ */ */, `, 1}}, }, { sql: "select -- a baby's toy\n'barbie', $1", expected: sanitize.Query{Parts: []sanitize.Part{"select -- a baby's toy\n'barbie', ", 1}}, }, { sql: "select 42 -- is a Deep Thought's favorite number", expected: sanitize.Query{Parts: []sanitize.Part{"select 42 -- is a Deep Thought's favorite number"}}, }, { sql: "select 42, -- \\nis a Deep Thought's favorite number\n$1", expected: sanitize.Query{Parts: []sanitize.Part{"select 42, -- \\nis a Deep Thought's favorite number\n", 1}}, }, { sql: "select 42, -- \\nis a Deep Thought's favorite number\r$1", expected: sanitize.Query{Parts: []sanitize.Part{"select 42, -- \\nis a Deep Thought's favorite number\r", 1}}, }, { // https://github.com/jackc/pgx/issues/1380 sql: "select 'hello w�rld'", expected: sanitize.Query{Parts: []sanitize.Part{"select 'hello w�rld'"}}, }, { // Unterminated quoted string sql: "select 'hello world", expected: sanitize.Query{Parts: []sanitize.Part{"select 'hello world"}}, }, } for i, tt := range successTests { query, err := sanitize.NewQuery(tt.sql) if err != nil { t.Errorf("%d. %v", i, err) } if len(query.Parts) == len(tt.expected.Parts) { for j := range query.Parts { if query.Parts[j] != tt.expected.Parts[j] { t.Errorf("%d. expected part %d to be %v but it was %v", i, j, tt.expected.Parts[j], query.Parts[j]) } } } else { t.Errorf("%d. 
expected query parts to be %v but it was %v", i, tt.expected.Parts, query.Parts) } } } func TestQuerySanitize(t *testing.T) { successfulTests := []struct { query sanitize.Query args []interface{} expected string }{ { query: sanitize.Query{Parts: []sanitize.Part{"select 42"}}, args: []interface{}{}, expected: `select 42`, }, { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, args: []interface{}{int64(42)}, expected: `select 42`, }, { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, args: []interface{}{float64(1.23)}, expected: `select 1.23`, }, { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, args: []interface{}{true}, expected: `select true`, }, { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, args: []interface{}{[]byte{0, 1, 2, 3, 255}}, expected: `select '\x00010203ff'`, }, { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, args: []interface{}{nil}, expected: `select null`, }, { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, args: []interface{}{"foobar"}, expected: `select 'foobar'`, }, { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, args: []interface{}{"foo'bar"}, expected: `select 'foo''bar'`, }, { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, args: []interface{}{`foo\'bar`}, expected: `select 'foo\''bar'`, }, { query: sanitize.Query{Parts: []sanitize.Part{"insert ", 1}}, args: []interface{}{time.Date(2020, time.March, 1, 23, 59, 59, 999999999, time.UTC)}, expected: `insert '2020-03-01 23:59:59.999999Z'`, }, } for i, tt := range successfulTests { actual, err := tt.query.Sanitize(tt.args...) if err != nil { t.Errorf("%d. %v", i, err) continue } if tt.expected != actual { t.Errorf("%d. 
expected %s, but got %s", i, tt.expected, actual) } } errorTests := []struct { query sanitize.Query args []interface{} expected string }{ { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1, ", ", 2}}, args: []interface{}{int64(42)}, expected: `insufficient arguments`, }, { query: sanitize.Query{Parts: []sanitize.Part{"select 'foo'"}}, args: []interface{}{int64(42)}, expected: `unused argument: 0`, }, { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, args: []interface{}{42}, expected: `invalid arg type: int`, }, } for i, tt := range errorTests { _, err := tt.query.Sanitize(tt.args...) if err == nil || err.Error() != tt.expected { t.Errorf("%d. expected error %v, got %v", i, tt.expected, err) } } } pgx-4.18.1/large_objects.go000066400000000000000000000064431437725773200155650ustar00rootroot00000000000000package pgx import ( "context" "errors" "io" ) // LargeObjects is a structure used to access the large objects API. It is only valid within the transaction where it // was created. // // For more details see: http://www.postgresql.org/docs/current/static/largeobjects.html type LargeObjects struct { tx Tx } type LargeObjectMode int32 const ( LargeObjectModeWrite LargeObjectMode = 0x20000 LargeObjectModeRead LargeObjectMode = 0x40000 ) // Create creates a new large object. If oid is zero, the server assigns an unused OID. func (o *LargeObjects) Create(ctx context.Context, oid uint32) (uint32, error) { err := o.tx.QueryRow(ctx, "select lo_create($1)", oid).Scan(&oid) return oid, err } // Open opens an existing large object with the given mode. ctx will also be used for all operations on the opened large // object. func (o *LargeObjects) Open(ctx context.Context, oid uint32, mode LargeObjectMode) (*LargeObject, error) { var fd int32 err := o.tx.QueryRow(ctx, "select lo_open($1, $2)", oid, mode).Scan(&fd) if err != nil { return nil, err } return &LargeObject{fd: fd, tx: o.tx, ctx: ctx}, nil } // Unlink removes a large object from the database. 
func (o *LargeObjects) Unlink(ctx context.Context, oid uint32) error { var result int32 err := o.tx.QueryRow(ctx, "select lo_unlink($1)", oid).Scan(&result) if err != nil { return err } if result != 1 { return errors.New("failed to remove large object") } return nil } // A LargeObject is a large object stored on the server. It is only valid within the transaction that it was initialized // in. It uses the context it was initialized with for all operations. It implements these interfaces: // // io.Writer // io.Reader // io.Seeker // io.Closer type LargeObject struct { ctx context.Context tx Tx fd int32 } // Write writes p to the large object and returns the number of bytes written and an error if not all of p was written. func (o *LargeObject) Write(p []byte) (int, error) { var n int err := o.tx.QueryRow(o.ctx, "select lowrite($1, $2)", o.fd, p).Scan(&n) if err != nil { return n, err } if n < 0 { return 0, errors.New("failed to write to large object") } return n, nil } // Read reads up to len(p) bytes into p returning the number of bytes read. func (o *LargeObject) Read(p []byte) (int, error) { var res []byte err := o.tx.QueryRow(o.ctx, "select loread($1, $2)", o.fd, len(p)).Scan(&res) copy(p, res) if err != nil { return len(res), err } if len(res) < len(p) { err = io.EOF } return len(res), err } // Seek moves the current location pointer to the new location specified by offset. func (o *LargeObject) Seek(offset int64, whence int) (n int64, err error) { err = o.tx.QueryRow(o.ctx, "select lo_lseek64($1, $2, $3)", o.fd, offset, whence).Scan(&n) return n, err } // Tell returns the current read or write location of the large object descriptor. func (o *LargeObject) Tell() (n int64, err error) { err = o.tx.QueryRow(o.ctx, "select lo_tell64($1)", o.fd).Scan(&n) return n, err } // Truncate the large object to size. 
func (o *LargeObject) Truncate(size int64) (err error) { _, err = o.tx.Exec(o.ctx, "select lo_truncate64($1, $2)", o.fd, size) return err } // Close the large object descriptor. func (o *LargeObject) Close() error { _, err := o.tx.Exec(o.ctx, "select lo_close($1)", o.fd) return err } pgx-4.18.1/large_objects_test.go000066400000000000000000000126701437725773200166230ustar00rootroot00000000000000package pgx_test import ( "context" "io" "os" "testing" "time" "github.com/jackc/pgconn" "github.com/jackc/pgx/v4" ) func TestLargeObjects(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() conn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) if err != nil { t.Fatal(err) } skipCockroachDB(t, conn, "Server does support large objects") tx, err := conn.Begin(ctx) if err != nil { t.Fatal(err) } testLargeObjects(t, ctx, tx) } func TestLargeObjectsPreferSimpleProtocol(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) if err != nil { t.Fatal(err) } config.PreferSimpleProtocol = true conn, err := pgx.ConnectConfig(ctx, config) if err != nil { t.Fatal(err) } skipCockroachDB(t, conn, "Server does support large objects") tx, err := conn.Begin(ctx) if err != nil { t.Fatal(err) } testLargeObjects(t, ctx, tx) } func testLargeObjects(t *testing.T, ctx context.Context, tx pgx.Tx) { lo := tx.LargeObjects() id, err := lo.Create(ctx, 0) if err != nil { t.Fatal(err) } obj, err := lo.Open(ctx, id, pgx.LargeObjectModeRead|pgx.LargeObjectModeWrite) if err != nil { t.Fatal(err) } n, err := obj.Write([]byte("testing")) if err != nil { t.Fatal(err) } if n != 7 { t.Errorf("Expected n to be 7, got %d", n) } pos, err := obj.Seek(1, 0) if err != nil { t.Fatal(err) } if pos != 1 { t.Errorf("Expected pos to be 1, got %d", pos) } res := make([]byte, 6) n, err = obj.Read(res) if err != nil { t.Fatal(err) } if 
string(res) != "esting" { t.Errorf(`Expected res to be "esting", got %q`, res) } if n != 6 { t.Errorf("Expected n to be 6, got %d", n) } n, err = obj.Read(res) if err != io.EOF { t.Error("Expected io.EOF, go nil") } if n != 0 { t.Errorf("Expected n to be 0, got %d", n) } pos, err = obj.Tell() if err != nil { t.Fatal(err) } if pos != 7 { t.Errorf("Expected pos to be 7, got %d", pos) } err = obj.Truncate(1) if err != nil { t.Fatal(err) } pos, err = obj.Seek(-1, 2) if err != nil { t.Fatal(err) } if pos != 0 { t.Errorf("Expected pos to be 0, got %d", pos) } res = make([]byte, 2) n, err = obj.Read(res) if err != io.EOF { t.Errorf("Expected err to be io.EOF, got %v", err) } if n != 1 { t.Errorf("Expected n to be 1, got %d", n) } if res[0] != 't' { t.Errorf("Expected res[0] to be 't', got %v", res[0]) } err = obj.Close() if err != nil { t.Fatal(err) } err = lo.Unlink(ctx, id) if err != nil { t.Fatal(err) } _, err = lo.Open(ctx, id, pgx.LargeObjectModeRead) if e, ok := err.(*pgconn.PgError); !ok || e.Code != "42704" { t.Errorf("Expected undefined_object error (42704), got %#v", err) } } func TestLargeObjectsMultipleTransactions(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() conn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) if err != nil { t.Fatal(err) } skipCockroachDB(t, conn, "Server does support large objects") tx, err := conn.Begin(ctx) if err != nil { t.Fatal(err) } lo := tx.LargeObjects() id, err := lo.Create(ctx, 0) if err != nil { t.Fatal(err) } obj, err := lo.Open(ctx, id, pgx.LargeObjectModeWrite) if err != nil { t.Fatal(err) } n, err := obj.Write([]byte("testing")) if err != nil { t.Fatal(err) } if n != 7 { t.Errorf("Expected n to be 7, got %d", n) } // Commit the first transaction err = tx.Commit(ctx) if err != nil { t.Fatal(err) } // IMPORTANT: Use the same connection for another query query := `select n from generate_series(1,10) n` rows, err := conn.Query(ctx, query) if err != 
nil { t.Fatal(err) } rows.Close() // Start a new transaction tx2, err := conn.Begin(ctx) if err != nil { t.Fatal(err) } lo2 := tx2.LargeObjects() // Reopen the large object in the new transaction obj2, err := lo2.Open(ctx, id, pgx.LargeObjectModeRead|pgx.LargeObjectModeWrite) if err != nil { t.Fatal(err) } pos, err := obj2.Seek(1, 0) if err != nil { t.Fatal(err) } if pos != 1 { t.Errorf("Expected pos to be 1, got %d", pos) } res := make([]byte, 6) n, err = obj2.Read(res) if err != nil { t.Fatal(err) } if string(res) != "esting" { t.Errorf(`Expected res to be "esting", got %q`, res) } if n != 6 { t.Errorf("Expected n to be 6, got %d", n) } n, err = obj2.Read(res) if err != io.EOF { t.Error("Expected io.EOF, go nil") } if n != 0 { t.Errorf("Expected n to be 0, got %d", n) } pos, err = obj2.Tell() if err != nil { t.Fatal(err) } if pos != 7 { t.Errorf("Expected pos to be 7, got %d", pos) } err = obj2.Truncate(1) if err != nil { t.Fatal(err) } pos, err = obj2.Seek(-1, 2) if err != nil { t.Fatal(err) } if pos != 0 { t.Errorf("Expected pos to be 0, got %d", pos) } res = make([]byte, 2) n, err = obj2.Read(res) if err != io.EOF { t.Errorf("Expected err to be io.EOF, got %v", err) } if n != 1 { t.Errorf("Expected n to be 1, got %d", n) } if res[0] != 't' { t.Errorf("Expected res[0] to be 't', got %v", res[0]) } err = obj2.Close() if err != nil { t.Fatal(err) } err = lo2.Unlink(ctx, id) if err != nil { t.Fatal(err) } _, err = lo2.Open(ctx, id, pgx.LargeObjectModeRead) if e, ok := err.(*pgconn.PgError); !ok || e.Code != "42704" { t.Errorf("Expected undefined_object error (42704), got %#v", err) } } pgx-4.18.1/log/000077500000000000000000000000001437725773200132055ustar00rootroot00000000000000pgx-4.18.1/log/kitlogadapter/000077500000000000000000000000001437725773200160375ustar00rootroot00000000000000pgx-4.18.1/log/kitlogadapter/adapter.go000066400000000000000000000015211437725773200200050ustar00rootroot00000000000000package kitlogadapter import ( "context" 
"github.com/go-kit/log" kitlevel "github.com/go-kit/log/level" "github.com/jackc/pgx/v4" ) type Logger struct { l log.Logger } func NewLogger(l log.Logger) *Logger { return &Logger{l: l} } func (l *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) { logger := l.l for k, v := range data { logger = log.With(logger, k, v) } switch level { case pgx.LogLevelTrace: logger.Log("PGX_LOG_LEVEL", level, "msg", msg) case pgx.LogLevelDebug: kitlevel.Debug(logger).Log("msg", msg) case pgx.LogLevelInfo: kitlevel.Info(logger).Log("msg", msg) case pgx.LogLevelWarn: kitlevel.Warn(logger).Log("msg", msg) case pgx.LogLevelError: kitlevel.Error(logger).Log("msg", msg) default: logger.Log("INVALID_PGX_LOG_LEVEL", level, "error", msg) } } pgx-4.18.1/log/log15adapter/000077500000000000000000000000001437725773200154755ustar00rootroot00000000000000pgx-4.18.1/log/log15adapter/adapter.go000066400000000000000000000023231437725773200174440ustar00rootroot00000000000000// Package log15adapter provides a logger that writes to a github.com/inconshreveable/log15.Logger // log. package log15adapter import ( "context" "github.com/jackc/pgx/v4" ) // Log15Logger interface defines the subset of // github.com/inconshreveable/log15.Logger that this adapter uses. type Log15Logger interface { Debug(msg string, ctx ...interface{}) Info(msg string, ctx ...interface{}) Warn(msg string, ctx ...interface{}) Error(msg string, ctx ...interface{}) Crit(msg string, ctx ...interface{}) } type Logger struct { l Log15Logger } func NewLogger(l Log15Logger) *Logger { return &Logger{l: l} } func (l *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) { logArgs := make([]interface{}, 0, len(data)) for k, v := range data { logArgs = append(logArgs, k, v) } switch level { case pgx.LogLevelTrace: l.l.Debug(msg, append(logArgs, "PGX_LOG_LEVEL", level)...) case pgx.LogLevelDebug: l.l.Debug(msg, logArgs...) 
case pgx.LogLevelInfo: l.l.Info(msg, logArgs...) case pgx.LogLevelWarn: l.l.Warn(msg, logArgs...) case pgx.LogLevelError: l.l.Error(msg, logArgs...) default: l.l.Error(msg, append(logArgs, "INVALID_PGX_LOG_LEVEL", level)...) } } pgx-4.18.1/log/logrusadapter/000077500000000000000000000000001437725773200160615ustar00rootroot00000000000000pgx-4.18.1/log/logrusadapter/adapter.go000066400000000000000000000015671437725773200200410ustar00rootroot00000000000000// Package logrusadapter provides a logger that writes to a github.com/sirupsen/logrus.Logger // log. package logrusadapter import ( "context" "github.com/jackc/pgx/v4" "github.com/sirupsen/logrus" ) type Logger struct { l logrus.FieldLogger } func NewLogger(l logrus.FieldLogger) *Logger { return &Logger{l: l} } func (l *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) { var logger logrus.FieldLogger if data != nil { logger = l.l.WithFields(data) } else { logger = l.l } switch level { case pgx.LogLevelTrace: logger.WithField("PGX_LOG_LEVEL", level).Debug(msg) case pgx.LogLevelDebug: logger.Debug(msg) case pgx.LogLevelInfo: logger.Info(msg) case pgx.LogLevelWarn: logger.Warn(msg) case pgx.LogLevelError: logger.Error(msg) default: logger.WithField("INVALID_PGX_LOG_LEVEL", level).Error(msg) } } pgx-4.18.1/log/testingadapter/000077500000000000000000000000001437725773200162235ustar00rootroot00000000000000pgx-4.18.1/log/testingadapter/adapter.go000066400000000000000000000013321437725773200201710ustar00rootroot00000000000000// Package testingadapter provides a logger that writes to a test or benchmark // log. package testingadapter import ( "context" "fmt" "github.com/jackc/pgx/v4" ) // TestingLogger interface defines the subset of testing.TB methods used by this // adapter. 
type TestingLogger interface { Log(args ...interface{}) } type Logger struct { l TestingLogger } func NewLogger(l TestingLogger) *Logger { return &Logger{l: l} } func (l *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) { logArgs := make([]interface{}, 0, 2+len(data)) logArgs = append(logArgs, level, msg) for k, v := range data { logArgs = append(logArgs, fmt.Sprintf("%s=%v", k, v)) } l.l.Log(logArgs...) } pgx-4.18.1/log/zapadapter/000077500000000000000000000000001437725773200153405ustar00rootroot00000000000000pgx-4.18.1/log/zapadapter/adapter.go000066400000000000000000000020021437725773200173010ustar00rootroot00000000000000// Package zapadapter provides a logger that writes to a go.uber.org/zap.Logger. package zapadapter import ( "context" "github.com/jackc/pgx/v4" "go.uber.org/zap" "go.uber.org/zap/zapcore" ) type Logger struct { logger *zap.Logger } func NewLogger(logger *zap.Logger) *Logger { return &Logger{logger: logger.WithOptions(zap.AddCallerSkip(1))} } func (pl *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) { fields := make([]zapcore.Field, len(data)) i := 0 for k, v := range data { fields[i] = zap.Any(k, v) i++ } switch level { case pgx.LogLevelTrace: pl.logger.Debug(msg, append(fields, zap.Stringer("PGX_LOG_LEVEL", level))...) case pgx.LogLevelDebug: pl.logger.Debug(msg, fields...) case pgx.LogLevelInfo: pl.logger.Info(msg, fields...) case pgx.LogLevelWarn: pl.logger.Warn(msg, fields...) case pgx.LogLevelError: pl.logger.Error(msg, fields...) default: pl.logger.Error(msg, append(fields, zap.Stringer("PGX_LOG_LEVEL", level))...) } } pgx-4.18.1/log/zerologadapter/000077500000000000000000000000001437725773200162275ustar00rootroot00000000000000pgx-4.18.1/log/zerologadapter/adapter.go000066400000000000000000000046721437725773200202070ustar00rootroot00000000000000// Package zerologadapter provides a logger that writes to a github.com/rs/zerolog. 
package zerologadapter import ( "context" "github.com/jackc/pgx/v4" "github.com/rs/zerolog" ) type Logger struct { logger zerolog.Logger withFunc func(context.Context, zerolog.Context) zerolog.Context fromContext bool skipModule bool } // option options for configuring the logger when creating a new logger. type option func(logger *Logger) // WithContextFunc adds possibility to get request scoped values from the // ctx.Context before logging lines. func WithContextFunc(withFunc func(context.Context, zerolog.Context) zerolog.Context) option { return func(logger *Logger) { logger.withFunc = withFunc } } // WithoutPGXModule disables adding module:pgx to the default logger context. func WithoutPGXModule() option { return func(logger *Logger) { logger.skipModule = true } } // NewLogger accepts a zerolog.Logger as input and returns a new custom pgx // logging facade as output. func NewLogger(logger zerolog.Logger, options ...option) *Logger { l := Logger{ logger: logger, } l.init(options) return &l } // NewContextLogger creates logger that extracts the zerolog.Logger from the // context.Context by using `zerolog.Ctx`. The zerolog.DefaultContextLogger will // be used if no logger is associated with the context. 
func NewContextLogger(options ...option) *Logger { l := Logger{ fromContext: true, } l.init(options) return &l } func (pl *Logger) init(options []option) { for _, opt := range options { opt(pl) } if !pl.skipModule { pl.logger = pl.logger.With().Str("module", "pgx").Logger() } } func (pl *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) { var zlevel zerolog.Level switch level { case pgx.LogLevelNone: zlevel = zerolog.NoLevel case pgx.LogLevelError: zlevel = zerolog.ErrorLevel case pgx.LogLevelWarn: zlevel = zerolog.WarnLevel case pgx.LogLevelInfo: zlevel = zerolog.InfoLevel case pgx.LogLevelDebug: zlevel = zerolog.DebugLevel default: zlevel = zerolog.DebugLevel } var zctx zerolog.Context if pl.fromContext { logger := zerolog.Ctx(ctx) zctx = logger.With() } else { zctx = pl.logger.With() } if pl.withFunc != nil { zctx = pl.withFunc(ctx, zctx) } pgxlog := zctx.Logger() event := pgxlog.WithLevel(zlevel) if event.Enabled() { if pl.fromContext && !pl.skipModule { event.Str("module", "pgx") } event.Fields(data).Msg(msg) } } pgx-4.18.1/log/zerologadapter/adapter_test.go000066400000000000000000000051201437725773200212330ustar00rootroot00000000000000package zerologadapter_test import ( "bytes" "context" "testing" "github.com/jackc/pgx/v4" "github.com/jackc/pgx/v4/log/zerologadapter" "github.com/rs/zerolog" ) func TestLogger(t *testing.T) { t.Run("default", func(t *testing.T) { var buf bytes.Buffer zlogger := zerolog.New(&buf) logger := zerologadapter.NewLogger(zlogger) logger.Log(context.Background(), pgx.LogLevelInfo, "hello", map[string]interface{}{"one": "two"}) const want = `{"level":"info","module":"pgx","one":"two","message":"hello"} ` got := buf.String() if got != want { t.Errorf("%s != %s", got, want) } }) t.Run("disable pgx module", func(t *testing.T) { var buf bytes.Buffer zlogger := zerolog.New(&buf) logger := zerologadapter.NewLogger(zlogger, zerologadapter.WithoutPGXModule()) logger.Log(context.Background(), 
pgx.LogLevelInfo, "hello", nil) const want = `{"level":"info","message":"hello"} ` got := buf.String() if got != want { t.Errorf("%s != %s", got, want) } }) t.Run("from context", func(t *testing.T) { var buf bytes.Buffer zlogger := zerolog.New(&buf) ctx := zlogger.WithContext(context.Background()) logger := zerologadapter.NewContextLogger() logger.Log(ctx, pgx.LogLevelInfo, "hello", map[string]interface{}{"one": "two"}) const want = `{"level":"info","module":"pgx","one":"two","message":"hello"} ` got := buf.String() if got != want { t.Log(got) t.Log(want) t.Errorf("%s != %s", got, want) } }) var buf bytes.Buffer type key string var ck key zlogger := zerolog.New(&buf) logger := zerologadapter.NewLogger(zlogger, zerologadapter.WithContextFunc(func(ctx context.Context, logWith zerolog.Context) zerolog.Context { // You can use zerolog.hlog.IDFromCtx(ctx) or even // zerolog.log.Ctx(ctx) to fetch the whole logger instance from the // context if you want. id, ok := ctx.Value(ck).(string) if ok { logWith = logWith.Str("req_id", id) } return logWith }), ) t.Run("no request id", func(t *testing.T) { buf.Reset() ctx := context.Background() logger.Log(ctx, pgx.LogLevelInfo, "hello", nil) const want = `{"level":"info","module":"pgx","message":"hello"} ` got := buf.String() if got != want { t.Errorf("%s != %s", got, want) } }) t.Run("with request id", func(t *testing.T) { buf.Reset() ctx := context.WithValue(context.Background(), ck, "1") logger.Log(ctx, pgx.LogLevelInfo, "hello", map[string]interface{}{"two": "2"}) const want = `{"level":"info","module":"pgx","req_id":"1","two":"2","message":"hello"} ` got := buf.String() if got != want { t.Errorf("%s != %s", got, want) } }) } pgx-4.18.1/logger.go000066400000000000000000000044631437725773200142410ustar00rootroot00000000000000package pgx import ( "context" "encoding/hex" "errors" "fmt" ) // The values for log levels are chosen such that the zero value means that no // log level was specified. 
const ( LogLevelTrace = 6 LogLevelDebug = 5 LogLevelInfo = 4 LogLevelWarn = 3 LogLevelError = 2 LogLevelNone = 1 ) // LogLevel represents the pgx logging level. See LogLevel* constants for // possible values. type LogLevel int func (ll LogLevel) String() string { switch ll { case LogLevelTrace: return "trace" case LogLevelDebug: return "debug" case LogLevelInfo: return "info" case LogLevelWarn: return "warn" case LogLevelError: return "error" case LogLevelNone: return "none" default: return fmt.Sprintf("invalid level %d", ll) } } // Logger is the interface used to get logging from pgx internals. type Logger interface { // Log a message at the given level with data key/value pairs. data may be nil. Log(ctx context.Context, level LogLevel, msg string, data map[string]interface{}) } // LoggerFunc is a wrapper around a function to satisfy the pgx.Logger interface type LoggerFunc func(ctx context.Context, level LogLevel, msg string, data map[string]interface{}) // Log delegates the logging request to the wrapped function func (f LoggerFunc) Log(ctx context.Context, level LogLevel, msg string, data map[string]interface{}) { f(ctx, level, msg, data) } // LogLevelFromString converts log level string to constant // // Valid levels: // // trace // debug // info // warn // error // none func LogLevelFromString(s string) (LogLevel, error) { switch s { case "trace": return LogLevelTrace, nil case "debug": return LogLevelDebug, nil case "info": return LogLevelInfo, nil case "warn": return LogLevelWarn, nil case "error": return LogLevelError, nil case "none": return LogLevelNone, nil default: return 0, errors.New("invalid log level") } } func logQueryArgs(args []interface{}) []interface{} { logArgs := make([]interface{}, 0, len(args)) for _, a := range args { switch v := a.(type) { case []byte: if len(v) < 64 { a = hex.EncodeToString(v) } else { a = fmt.Sprintf("%x (truncated %d bytes)", v[:64], len(v)-64) } case string: if len(v) > 64 { a = fmt.Sprintf("%s (truncated %d bytes)", 
v[:64], len(v)-64) } } logArgs = append(logArgs, a) } return logArgs } pgx-4.18.1/messages.go000066400000000000000000000006221437725773200145620ustar00rootroot00000000000000package pgx import ( "database/sql/driver" "github.com/jackc/pgtype" ) func convertDriverValuers(args []interface{}) ([]interface{}, error) { for i, arg := range args { switch arg := arg.(type) { case pgtype.BinaryEncoder: case pgtype.TextEncoder: case driver.Valuer: v, err := callValuerValue(arg) if err != nil { return nil, err } args[i] = v } } return args, nil } pgx-4.18.1/pgbouncer_test.go000066400000000000000000000042521437725773200160010ustar00rootroot00000000000000package pgx_test import ( "context" "os" "testing" "github.com/jackc/pgconn" "github.com/jackc/pgconn/stmtcache" "github.com/jackc/pgx/v4" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestPgbouncerStatementCacheDescribe(t *testing.T) { connString := os.Getenv("PGX_TEST_PGBOUNCER_CONN_STRING") if connString == "" { t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_PGBOUNCER_CONN_STRING") } config := mustParseConfig(t, connString) config.BuildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache { return stmtcache.New(conn, stmtcache.ModeDescribe, 1024) } testPgbouncer(t, config, 10, 100) } func TestPgbouncerSimpleProtocol(t *testing.T) { connString := os.Getenv("PGX_TEST_PGBOUNCER_CONN_STRING") if connString == "" { t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_PGBOUNCER_CONN_STRING") } config := mustParseConfig(t, connString) config.BuildStatementCache = nil config.PreferSimpleProtocol = true testPgbouncer(t, config, 10, 100) } func testPgbouncer(t *testing.T, config *pgx.ConnConfig, workers, iterations int) { doneChan := make(chan struct{}) for i := 0; i < workers; i++ { go func() { defer func() { doneChan <- struct{}{} }() conn, err := pgx.ConnectConfig(context.Background(), config) require.Nil(t, err) defer closeConn(t, conn) for i := 0; i < 
iterations; i++ { var i32 int32 var i64 int64 var f32 float32 var s string var s2 string err = conn.QueryRow(context.Background(), "select 1::int4, 2::int8, 3::float4, 'hi'::text").Scan(&i32, &i64, &f32, &s) require.NoError(t, err) assert.Equal(t, int32(1), i32) assert.Equal(t, int64(2), i64) assert.Equal(t, float32(3), f32) assert.Equal(t, "hi", s) err = conn.QueryRow(context.Background(), "select 1::int8, 2::float4, 'bye'::text, 4::int4, 'whatever'::text").Scan(&i64, &f32, &s, &i32, &s2) require.NoError(t, err) assert.Equal(t, int64(1), i64) assert.Equal(t, float32(2), f32) assert.Equal(t, "bye", s) assert.Equal(t, int32(4), i32) assert.Equal(t, "whatever", s2) } }() } for i := 0; i < workers; i++ { <-doneChan } } pgx-4.18.1/pgxpool/000077500000000000000000000000001437725773200141145ustar00rootroot00000000000000pgx-4.18.1/pgxpool/batch_results.go000066400000000000000000000022401437725773200173030ustar00rootroot00000000000000package pgxpool import ( "github.com/jackc/pgconn" "github.com/jackc/pgx/v4" ) type errBatchResults struct { err error } func (br errBatchResults) Exec() (pgconn.CommandTag, error) { return nil, br.err } func (br errBatchResults) Query() (pgx.Rows, error) { return errRows{err: br.err}, br.err } func (br errBatchResults) QueryFunc(scans []interface{}, f func(pgx.QueryFuncRow) error) (pgconn.CommandTag, error) { return nil, br.err } func (br errBatchResults) QueryRow() pgx.Row { return errRow{err: br.err} } func (br errBatchResults) Close() error { return br.err } type poolBatchResults struct { br pgx.BatchResults c *Conn } func (br *poolBatchResults) Exec() (pgconn.CommandTag, error) { return br.br.Exec() } func (br *poolBatchResults) Query() (pgx.Rows, error) { return br.br.Query() } func (br *poolBatchResults) QueryFunc(scans []interface{}, f func(pgx.QueryFuncRow) error) (pgconn.CommandTag, error) { return br.br.QueryFunc(scans, f) } func (br *poolBatchResults) QueryRow() pgx.Row { return br.br.QueryRow() } func (br *poolBatchResults) 
Close() error { err := br.br.Close() if br.c != nil { br.c.Release() br.c = nil } return err } pgx-4.18.1/pgxpool/bench_test.go000066400000000000000000000034171437725773200165660ustar00rootroot00000000000000package pgxpool_test import ( "context" "os" "testing" "github.com/jackc/pgx/v4" "github.com/jackc/pgx/v4/pgxpool" "github.com/stretchr/testify/require" ) func BenchmarkAcquireAndRelease(b *testing.B) { pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(b, err) defer pool.Close() b.ResetTimer() for i := 0; i < b.N; i++ { c, err := pool.Acquire(context.Background()) if err != nil { b.Fatal(err) } c.Release() } } func BenchmarkMinimalPreparedSelectBaseline(b *testing.B) { config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) require.NoError(b, err) config.AfterConnect = func(ctx context.Context, c *pgx.Conn) error { _, err := c.Prepare(ctx, "ps1", "select $1::int8") return err } db, err := pgxpool.ConnectConfig(context.Background(), config) require.NoError(b, err) conn, err := db.Acquire(context.Background()) require.NoError(b, err) defer conn.Release() var n int64 b.ResetTimer() for i := 0; i < b.N; i++ { err = conn.QueryRow(context.Background(), "ps1", i).Scan(&n) if err != nil { b.Fatal(err) } if n != int64(i) { b.Fatalf("expected %d, got %d", i, n) } } } func BenchmarkMinimalPreparedSelect(b *testing.B) { config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) require.NoError(b, err) config.AfterConnect = func(ctx context.Context, c *pgx.Conn) error { _, err := c.Prepare(ctx, "ps1", "select $1::int8") return err } db, err := pgxpool.ConnectConfig(context.Background(), config) require.NoError(b, err) var n int64 b.ResetTimer() for i := 0; i < b.N; i++ { err = db.QueryRow(context.Background(), "ps1", i).Scan(&n) if err != nil { b.Fatal(err) } if n != int64(i) { b.Fatalf("expected %d, got %d", i, n) } } } 
pgx-4.18.1/pgxpool/common_test.go000066400000000000000000000170351437725773200170000ustar00rootroot00000000000000package pgxpool_test import ( "context" "testing" "time" "github.com/jackc/pgx/v4/pgxpool" "github.com/jackc/pgconn" "github.com/jackc/pgx/v4" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) // Conn.Release is an asynchronous process that returns immediately. There is no signal when the actual work is // completed. To test something that relies on the actual work for Conn.Release being completed we must simply wait. // This function wraps the sleep so there is more meaning for the callers. func waitForReleaseToComplete() { time.Sleep(500 * time.Millisecond) } type execer interface { Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) } func testExec(t *testing.T, db execer) { results, err := db.Exec(context.Background(), "set time zone 'America/Chicago'") require.NoError(t, err) assert.EqualValues(t, "SET", results) } type queryer interface { Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) } func testQuery(t *testing.T, db queryer) { var sum, rowCount int32 rows, err := db.Query(context.Background(), "select generate_series(1,$1)", 10) require.NoError(t, err) for rows.Next() { var n int32 rows.Scan(&n) sum += n rowCount++ } assert.NoError(t, rows.Err()) assert.Equal(t, int32(10), rowCount) assert.Equal(t, int32(55), sum) } type queryRower interface { QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row } func testQueryRow(t *testing.T, db queryRower) { var what, who string err := db.QueryRow(context.Background(), "select 'hello', $1::text", "world").Scan(&what, &who) assert.NoError(t, err) assert.Equal(t, "hello", what) assert.Equal(t, "world", who) } type sendBatcher interface { SendBatch(context.Context, *pgx.Batch) pgx.BatchResults } func testSendBatch(t *testing.T, db sendBatcher) { batch := &pgx.Batch{} batch.Queue("select 1") 
batch.Queue("select 2") br := db.SendBatch(context.Background(), batch) var err error var n int32 err = br.QueryRow().Scan(&n) assert.NoError(t, err) assert.EqualValues(t, 1, n) err = br.QueryRow().Scan(&n) assert.NoError(t, err) assert.EqualValues(t, 2, n) err = br.Close() assert.NoError(t, err) } type copyFromer interface { CopyFrom(context.Context, pgx.Identifier, []string, pgx.CopyFromSource) (int64, error) } func testCopyFrom(t *testing.T, db interface { execer queryer copyFromer }) { _, err := db.Exec(context.Background(), `create temporary table foo(a int2, b int4, c int8, d varchar, e text, f date, g timestamptz)`) require.NoError(t, err) tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local) inputRows := [][]interface{}{ {int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime}, {nil, nil, nil, nil, nil, nil, nil}, } copyCount, err := db.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyFromRows(inputRows)) assert.NoError(t, err) assert.EqualValues(t, len(inputRows), copyCount) rows, err := db.Query(context.Background(), "select * from foo") assert.NoError(t, err) var outputRows [][]interface{} for rows.Next() { row, err := rows.Values() if err != nil { t.Errorf("Unexpected error for rows.Values(): %v", err) } outputRows = append(outputRows, row) } assert.NoError(t, rows.Err()) assert.Equal(t, inputRows, outputRows) } func assertConfigsEqual(t *testing.T, expected, actual *pgxpool.Config, testName string) { if !assert.NotNil(t, expected) { return } if !assert.NotNil(t, actual) { return } assert.Equalf(t, expected.ConnString(), actual.ConnString(), "%s - ConnString", testName) // Can't test function equality, so just test that they are set or not. 
assert.Equalf(t, expected.AfterConnect == nil, actual.AfterConnect == nil, "%s - AfterConnect", testName) assert.Equalf(t, expected.BeforeAcquire == nil, actual.BeforeAcquire == nil, "%s - BeforeAcquire", testName) assert.Equalf(t, expected.AfterRelease == nil, actual.AfterRelease == nil, "%s - AfterRelease", testName) assert.Equalf(t, expected.MaxConnLifetime, actual.MaxConnLifetime, "%s - MaxConnLifetime", testName) assert.Equalf(t, expected.MaxConnIdleTime, actual.MaxConnIdleTime, "%s - MaxConnIdleTime", testName) assert.Equalf(t, expected.MaxConns, actual.MaxConns, "%s - MaxConns", testName) assert.Equalf(t, expected.MinConns, actual.MinConns, "%s - MinConns", testName) assert.Equalf(t, expected.HealthCheckPeriod, actual.HealthCheckPeriod, "%s - HealthCheckPeriod", testName) assert.Equalf(t, expected.LazyConnect, actual.LazyConnect, "%s - LazyConnect", testName) assertConnConfigsEqual(t, expected.ConnConfig, actual.ConnConfig, testName) } func assertConnConfigsEqual(t *testing.T, expected, actual *pgx.ConnConfig, testName string) { if !assert.NotNil(t, expected) { return } if !assert.NotNil(t, actual) { return } assert.Equalf(t, expected.Logger, actual.Logger, "%s - Logger", testName) assert.Equalf(t, expected.LogLevel, actual.LogLevel, "%s - LogLevel", testName) assert.Equalf(t, expected.ConnString(), actual.ConnString(), "%s - ConnString", testName) // Can't test function equality, so just test that they are set or not. 
assert.Equalf(t, expected.BuildStatementCache == nil, actual.BuildStatementCache == nil, "%s - BuildStatementCache", testName) assert.Equalf(t, expected.PreferSimpleProtocol, actual.PreferSimpleProtocol, "%s - PreferSimpleProtocol", testName) assert.Equalf(t, expected.Host, actual.Host, "%s - Host", testName) assert.Equalf(t, expected.Database, actual.Database, "%s - Database", testName) assert.Equalf(t, expected.Port, actual.Port, "%s - Port", testName) assert.Equalf(t, expected.User, actual.User, "%s - User", testName) assert.Equalf(t, expected.Password, actual.Password, "%s - Password", testName) assert.Equalf(t, expected.ConnectTimeout, actual.ConnectTimeout, "%s - ConnectTimeout", testName) assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName) // Can't test function equality, so just test that they are set or not. assert.Equalf(t, expected.ValidateConnect == nil, actual.ValidateConnect == nil, "%s - ValidateConnect", testName) assert.Equalf(t, expected.AfterConnect == nil, actual.AfterConnect == nil, "%s - AfterConnect", testName) if assert.Equalf(t, expected.TLSConfig == nil, actual.TLSConfig == nil, "%s - TLSConfig", testName) { if expected.TLSConfig != nil { assert.Equalf(t, expected.TLSConfig.InsecureSkipVerify, actual.TLSConfig.InsecureSkipVerify, "%s - TLSConfig InsecureSkipVerify", testName) assert.Equalf(t, expected.TLSConfig.ServerName, actual.TLSConfig.ServerName, "%s - TLSConfig ServerName", testName) } } if assert.Equalf(t, len(expected.Fallbacks), len(actual.Fallbacks), "%s - Fallbacks", testName) { for i := range expected.Fallbacks { assert.Equalf(t, expected.Fallbacks[i].Host, actual.Fallbacks[i].Host, "%s - Fallback %d - Host", testName, i) assert.Equalf(t, expected.Fallbacks[i].Port, actual.Fallbacks[i].Port, "%s - Fallback %d - Port", testName, i) if assert.Equalf(t, expected.Fallbacks[i].TLSConfig == nil, actual.Fallbacks[i].TLSConfig == nil, "%s - Fallback %d - TLSConfig", testName, i) { if 
expected.Fallbacks[i].TLSConfig != nil { assert.Equalf(t, expected.Fallbacks[i].TLSConfig.InsecureSkipVerify, actual.Fallbacks[i].TLSConfig.InsecureSkipVerify, "%s - Fallback %d - TLSConfig InsecureSkipVerify", testName) assert.Equalf(t, expected.Fallbacks[i].TLSConfig.ServerName, actual.Fallbacks[i].TLSConfig.ServerName, "%s - Fallback %d - TLSConfig ServerName", testName) } } } } } pgx-4.18.1/pgxpool/conn.go000066400000000000000000000100451437725773200154000ustar00rootroot00000000000000package pgxpool import ( "context" "sync/atomic" "github.com/jackc/pgconn" "github.com/jackc/pgx/v4" "github.com/jackc/puddle" ) // Conn is an acquired *pgx.Conn from a Pool. type Conn struct { res *puddle.Resource p *Pool } // Release returns c to the pool it was acquired from. Once Release has been called, other methods must not be called. // However, it is safe to call Release multiple times. Subsequent calls after the first will be ignored. func (c *Conn) Release() { if c.res == nil { return } conn := c.Conn() res := c.res c.res = nil if conn.IsClosed() || conn.PgConn().IsBusy() || conn.PgConn().TxStatus() != 'I' { res.Destroy() // Signal to the health check to run since we just destroyed a connections // and we might be below minConns now c.p.triggerHealthCheck() return } // If the pool is consistently being used, we might never get to check the // lifetime of a connection since we only check idle connections in checkConnsHealth // so we also check the lifetime here and force a health check if c.p.isExpired(res) { atomic.AddInt64(&c.p.lifetimeDestroyCount, 1) res.Destroy() // Signal to the health check to run since we just destroyed a connections // and we might be below minConns now c.p.triggerHealthCheck() return } if c.p.afterRelease == nil { res.Release() return } go func() { if c.p.afterRelease(conn) { res.Release() } else { res.Destroy() // Signal to the health check to run since we just destroyed a connections // and we might be below minConns now 
c.p.triggerHealthCheck() } }() } // Hijack assumes ownership of the connection from the pool. Caller is responsible for closing the connection. Hijack // will panic if called on an already released or hijacked connection. func (c *Conn) Hijack() *pgx.Conn { if c.res == nil { panic("cannot hijack already released or hijacked connection") } conn := c.Conn() res := c.res c.res = nil res.Hijack() return conn } func (c *Conn) Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) { return c.Conn().Exec(ctx, sql, arguments...) } func (c *Conn) Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) { return c.Conn().Query(ctx, sql, args...) } func (c *Conn) QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row { return c.Conn().QueryRow(ctx, sql, args...) } func (c *Conn) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(pgx.QueryFuncRow) error) (pgconn.CommandTag, error) { return c.Conn().QueryFunc(ctx, sql, args, scans, f) } func (c *Conn) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults { return c.Conn().SendBatch(ctx, b) } func (c *Conn) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) { return c.Conn().CopyFrom(ctx, tableName, columnNames, rowSrc) } // Begin starts a transaction block from the *Conn without explicitly setting a transaction mode (see BeginTx with TxOptions if transaction mode is required). func (c *Conn) Begin(ctx context.Context) (pgx.Tx, error) { return c.Conn().Begin(ctx) } // BeginTx starts a transaction block from the *Conn with txOptions determining the transaction mode. 
func (c *Conn) BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, error) { return c.Conn().BeginTx(ctx, txOptions) } func (c *Conn) BeginFunc(ctx context.Context, f func(pgx.Tx) error) error { return c.Conn().BeginFunc(ctx, f) } func (c *Conn) BeginTxFunc(ctx context.Context, txOptions pgx.TxOptions, f func(pgx.Tx) error) error { return c.Conn().BeginTxFunc(ctx, txOptions, f) } func (c *Conn) Ping(ctx context.Context) error { return c.Conn().Ping(ctx) } func (c *Conn) Conn() *pgx.Conn { return c.connResource().conn } func (c *Conn) connResource() *connResource { return c.res.Value().(*connResource) } func (c *Conn) getPoolRow(r pgx.Row) *poolRow { return c.connResource().getPoolRow(c, r) } func (c *Conn) getPoolRows(r pgx.Rows) *poolRows { return c.connResource().getPoolRows(c, r) } pgx-4.18.1/pgxpool/conn_test.go000066400000000000000000000031111437725773200164330ustar00rootroot00000000000000package pgxpool_test import ( "context" "os" "testing" "github.com/jackc/pgx/v4/pgxpool" "github.com/stretchr/testify/require" ) func TestConnExec(t *testing.T) { t.Parallel() pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() c, err := pool.Acquire(context.Background()) require.NoError(t, err) defer c.Release() testExec(t, c) } func TestConnQuery(t *testing.T) { t.Parallel() pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() c, err := pool.Acquire(context.Background()) require.NoError(t, err) defer c.Release() testQuery(t, c) } func TestConnQueryRow(t *testing.T) { t.Parallel() pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() c, err := pool.Acquire(context.Background()) require.NoError(t, err) defer c.Release() testQueryRow(t, c) } func TestConnSendBatch(t *testing.T) { t.Parallel() pool, err := pgxpool.Connect(context.Background(), 
os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() c, err := pool.Acquire(context.Background()) require.NoError(t, err) defer c.Release() testSendBatch(t, c) } func TestConnCopyFrom(t *testing.T) { t.Parallel() pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() c, err := pool.Acquire(context.Background()) require.NoError(t, err) defer c.Release() testCopyFrom(t, c) } pgx-4.18.1/pgxpool/doc.go000066400000000000000000000016251437725773200152140ustar00rootroot00000000000000// Package pgxpool is a concurrency-safe connection pool for pgx. /* pgxpool implements a nearly identical interface to pgx connections. Establishing a Connection The primary way of establishing a connection is with `pgxpool.Connect`. pool, err := pgxpool.Connect(context.Background(), os.Getenv("DATABASE_URL")) The database connection string can be in URL or DSN format. PostgreSQL settings, pgx settings, and pool settings can be specified here. In addition, a config struct can be created by `ParseConfig` and modified before establishing the connection with `ConnectConfig`. config, err := pgxpool.ParseConfig(os.Getenv("DATABASE_URL")) if err != nil { // ... 
} config.AfterConnect = func(ctx context.Context, conn *pgx.Conn) error { // do something with every new connection } pool, err := pgxpool.ConnectConfig(context.Background(), config) */ package pgxpool pgx-4.18.1/pgxpool/pool.go000066400000000000000000000554331437725773200154260ustar00rootroot00000000000000package pgxpool import ( "context" "fmt" "math/rand" "runtime" "strconv" "sync" "sync/atomic" "time" "github.com/jackc/pgconn" "github.com/jackc/pgx/v4" "github.com/jackc/puddle" ) var defaultMaxConns = int32(4) var defaultMinConns = int32(0) var defaultMaxConnLifetime = time.Hour var defaultMaxConnIdleTime = time.Minute * 30 var defaultHealthCheckPeriod = time.Minute type connResource struct { conn *pgx.Conn conns []Conn poolRows []poolRow poolRowss []poolRows } func (cr *connResource) getConn(p *Pool, res *puddle.Resource) *Conn { if len(cr.conns) == 0 { cr.conns = make([]Conn, 128) } c := &cr.conns[len(cr.conns)-1] cr.conns = cr.conns[0 : len(cr.conns)-1] c.res = res c.p = p return c } func (cr *connResource) getPoolRow(c *Conn, r pgx.Row) *poolRow { if len(cr.poolRows) == 0 { cr.poolRows = make([]poolRow, 128) } pr := &cr.poolRows[len(cr.poolRows)-1] cr.poolRows = cr.poolRows[0 : len(cr.poolRows)-1] pr.c = c pr.r = r return pr } func (cr *connResource) getPoolRows(c *Conn, r pgx.Rows) *poolRows { if len(cr.poolRowss) == 0 { cr.poolRowss = make([]poolRows, 128) } pr := &cr.poolRowss[len(cr.poolRowss)-1] cr.poolRowss = cr.poolRowss[0 : len(cr.poolRowss)-1] pr.c = c pr.r = r return pr } // detachedCtx wraps a context and will never be canceled, regardless of if // the wrapped one is cancelled. The Err() method will never return any errors. type detachedCtx struct { context.Context } func (detachedCtx) Done() <-chan struct{} { return nil } func (detachedCtx) Deadline() (time.Time, bool) { return time.Time{}, false } func (detachedCtx) Err() error { return nil } // Pool allows for connection reuse. 
type Pool struct { // 64 bit fields accessed with atomics must be at beginning of struct to guarantee alignment for certain 32-bit // architectures. See BUGS section of https://pkg.go.dev/sync/atomic and https://github.com/jackc/pgx/issues/1288. newConnsCount int64 lifetimeDestroyCount int64 idleDestroyCount int64 p *puddle.Pool config *Config beforeConnect func(context.Context, *pgx.ConnConfig) error afterConnect func(context.Context, *pgx.Conn) error beforeAcquire func(context.Context, *pgx.Conn) bool afterRelease func(*pgx.Conn) bool minConns int32 maxConns int32 maxConnLifetime time.Duration maxConnLifetimeJitter time.Duration maxConnIdleTime time.Duration healthCheckPeriod time.Duration healthCheckChan chan struct{} closeOnce sync.Once closeChan chan struct{} } // Config is the configuration struct for creating a pool. It must be created by ParseConfig and then it can be // modified. A manually initialized ConnConfig will cause ConnectConfig to panic. type Config struct { ConnConfig *pgx.ConnConfig // BeforeConnect is called before a new connection is made. It is passed a copy of the underlying pgx.ConnConfig and // will not impact any existing open connections. BeforeConnect func(context.Context, *pgx.ConnConfig) error // AfterConnect is called after a connection is established, but before it is added to the pool. AfterConnect func(context.Context, *pgx.Conn) error // BeforeAcquire is called before a connection is acquired from the pool. It must return true to allow the // acquision or false to indicate that the connection should be destroyed and a different connection should be // acquired. BeforeAcquire func(context.Context, *pgx.Conn) bool // AfterRelease is called after a connection is released, but before it is returned to the pool. It must return true to // return the connection to the pool or false to destroy the connection. 
AfterRelease func(*pgx.Conn) bool // MaxConnLifetime is the duration since creation after which a connection will be automatically closed. MaxConnLifetime time.Duration // MaxConnLifetimeJitter is the duration after MaxConnLifetime to randomly decide to close a connection. // This helps prevent all connections from being closed at the exact same time, starving the pool. MaxConnLifetimeJitter time.Duration // MaxConnIdleTime is the duration after which an idle connection will be automatically closed by the health check. MaxConnIdleTime time.Duration // MaxConns is the maximum size of the pool. The default is the greater of 4 or runtime.NumCPU(). MaxConns int32 // MinConns is the minimum size of the pool. After connection closes, the pool might dip below MinConns. A low // number of MinConns might mean the pool is empty after MaxConnLifetime until the health check has a chance // to create new connections. MinConns int32 // HealthCheckPeriod is the duration between checks of the health of idle connections. HealthCheckPeriod time.Duration // If set to true, pool doesn't do any I/O operation on initialization. // And connects to the server only when the pool starts to be used. // The default is false. LazyConnect bool createdByParseConfig bool // Used to enforce created by ParseConfig rule. } // Copy returns a deep copy of the config that is safe to use and modify. // The only exception is the tls.Config: // according to the tls.Config docs it must not be modified after creation. func (c *Config) Copy() *Config { newConfig := new(Config) *newConfig = *c newConfig.ConnConfig = c.ConnConfig.Copy() return newConfig } // ConnString returns the connection string as parsed by pgxpool.ParseConfig into pgxpool.Config. func (c *Config) ConnString() string { return c.ConnConfig.ConnString() } // Connect creates a new Pool and immediately establishes one connection. ctx can be used to cancel this initial // connection. See ParseConfig for information on connString format. 
func Connect(ctx context.Context, connString string) (*Pool, error) { config, err := ParseConfig(connString) if err != nil { return nil, err } return ConnectConfig(ctx, config) } // ConnectConfig creates a new Pool and immediately establishes one connection. ctx can be used to cancel this initial // connection. config must have been created by ParseConfig. func ConnectConfig(ctx context.Context, config *Config) (*Pool, error) { // Default values are set in ParseConfig. Enforce initial creation by ParseConfig rather than setting defaults from // zero values. if !config.createdByParseConfig { panic("config must be created by ParseConfig") } p := &Pool{ config: config, beforeConnect: config.BeforeConnect, afterConnect: config.AfterConnect, beforeAcquire: config.BeforeAcquire, afterRelease: config.AfterRelease, minConns: config.MinConns, maxConns: config.MaxConns, maxConnLifetime: config.MaxConnLifetime, maxConnLifetimeJitter: config.MaxConnLifetimeJitter, maxConnIdleTime: config.MaxConnIdleTime, healthCheckPeriod: config.HealthCheckPeriod, healthCheckChan: make(chan struct{}, 1), closeChan: make(chan struct{}), } p.p = puddle.NewPool( func(ctx context.Context) (interface{}, error) { // we ignore cancellation on the original context because its either from // the health check or its from a query and we don't want to cancel creating // a connection just because the original query was cancelled since that // could end up stampeding the server // this will keep any Values in the original context and will just ignore // cancellation // see https://github.com/jackc/pgx/issues/1259 ctx = detachedCtx{ctx} connConfig := p.config.ConnConfig.Copy() // But we do want to ensure that a connect won't hang forever. 
if connConfig.ConnectTimeout <= 0 { connConfig.ConnectTimeout = 2 * time.Minute } if p.beforeConnect != nil { if err := p.beforeConnect(ctx, connConfig); err != nil { return nil, err } } conn, err := pgx.ConnectConfig(ctx, connConfig) if err != nil { return nil, err } if p.afterConnect != nil { err = p.afterConnect(ctx, conn) if err != nil { conn.Close(ctx) return nil, err } } cr := &connResource{ conn: conn, conns: make([]Conn, 64), poolRows: make([]poolRow, 64), poolRowss: make([]poolRows, 64), } return cr, nil }, func(value interface{}) { ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) conn := value.(*connResource).conn conn.Close(ctx) select { case <-conn.PgConn().CleanupDone(): case <-ctx.Done(): } cancel() }, config.MaxConns, ) if !config.LazyConnect { if err := p.checkMinConnsWithContext(ctx); err != nil { // Couldn't create resources for minpool size. Close unhealthy pool. p.Close() return nil, err } // Initially establish one connection res, err := p.p.Acquire(ctx) if err != nil { p.Close() return nil, err } res.Release() } go p.backgroundHealthCheck() return p, nil } // ParseConfig builds a Config from connString. It parses connString with the same behavior as pgx.ParseConfig with the // addition of the following variables: // // pool_max_conns: integer greater than 0 // pool_min_conns: integer 0 or greater // pool_max_conn_lifetime: duration string // pool_max_conn_idle_time: duration string // pool_health_check_period: duration string // pool_max_conn_lifetime_jitter: duration string // // See Config for definitions of these arguments. 
//
// # Example DSN
// user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca pool_max_conns=10
//
// # Example URL
// postgres://jack:secret@pg.example.com:5432/mydb?sslmode=verify-ca&pool_max_conns=10
func ParseConfig(connString string) (*Config, error) {
	connConfig, err := pgx.ParseConfig(connString)
	if err != nil {
		return nil, err
	}

	config := &Config{
		ConnConfig:           connConfig,
		createdByParseConfig: true,
	}

	// Each pool_* option is removed from RuntimeParams after parsing so it is not
	// sent to the server as a run-time parameter.
	if s, ok := config.ConnConfig.Config.RuntimeParams["pool_max_conns"]; ok {
		delete(connConfig.Config.RuntimeParams, "pool_max_conns")
		n, err := strconv.ParseInt(s, 10, 32)
		if err != nil {
			return nil, fmt.Errorf("cannot parse pool_max_conns: %w", err)
		}
		if n < 1 {
			return nil, fmt.Errorf("pool_max_conns too small: %d", n)
		}
		config.MaxConns = int32(n)
	} else {
		// Default: at least defaultMaxConns, but never fewer than the CPU count.
		config.MaxConns = defaultMaxConns
		if numCPU := int32(runtime.NumCPU()); numCPU > config.MaxConns {
			config.MaxConns = numCPU
		}
	}

	if s, ok := config.ConnConfig.Config.RuntimeParams["pool_min_conns"]; ok {
		delete(connConfig.Config.RuntimeParams, "pool_min_conns")
		n, err := strconv.ParseInt(s, 10, 32)
		if err != nil {
			return nil, fmt.Errorf("cannot parse pool_min_conns: %w", err)
		}
		config.MinConns = int32(n)
	} else {
		config.MinConns = defaultMinConns
	}

	if s, ok := config.ConnConfig.Config.RuntimeParams["pool_max_conn_lifetime"]; ok {
		delete(connConfig.Config.RuntimeParams, "pool_max_conn_lifetime")
		d, err := time.ParseDuration(s)
		if err != nil {
			return nil, fmt.Errorf("invalid pool_max_conn_lifetime: %w", err)
		}
		config.MaxConnLifetime = d
	} else {
		config.MaxConnLifetime = defaultMaxConnLifetime
	}

	if s, ok := config.ConnConfig.Config.RuntimeParams["pool_max_conn_idle_time"]; ok {
		delete(connConfig.Config.RuntimeParams, "pool_max_conn_idle_time")
		d, err := time.ParseDuration(s)
		if err != nil {
			return nil, fmt.Errorf("invalid pool_max_conn_idle_time: %w", err)
		}
		config.MaxConnIdleTime = d
	} else {
		config.MaxConnIdleTime = defaultMaxConnIdleTime
	}

	if s, ok := config.ConnConfig.Config.RuntimeParams["pool_health_check_period"]; ok {
		delete(connConfig.Config.RuntimeParams, "pool_health_check_period")
		d, err := time.ParseDuration(s)
		if err != nil {
			return nil, fmt.Errorf("invalid pool_health_check_period: %w", err)
		}
		config.HealthCheckPeriod = d
	} else {
		config.HealthCheckPeriod = defaultHealthCheckPeriod
	}

	// Note: jitter has no default; it stays at the zero value when unset.
	if s, ok := config.ConnConfig.Config.RuntimeParams["pool_max_conn_lifetime_jitter"]; ok {
		delete(connConfig.Config.RuntimeParams, "pool_max_conn_lifetime_jitter")
		d, err := time.ParseDuration(s)
		if err != nil {
			return nil, fmt.Errorf("invalid pool_max_conn_lifetime_jitter: %w", err)
		}
		config.MaxConnLifetimeJitter = d
	}

	return config, nil
}

// Close closes all connections in the pool and rejects future Acquire calls. Blocks until all connections are returned
// to pool and closed.
func (p *Pool) Close() {
	// closeOnce makes Close idempotent; closeChan stops backgroundHealthCheck.
	p.closeOnce.Do(func() {
		close(p.closeChan)
		p.p.Close()
	})
}

// isExpired reports whether res has outlived maxConnLifetime plus a random
// per-check jitter of up to maxConnLifetimeJitter.
func (p *Pool) isExpired(res *puddle.Resource) bool {
	now := time.Now()
	// Small optimization to avoid rand. If it's over lifetime AND jitter, immediately
	// return true.
	if now.Sub(res.CreationTime()) > p.maxConnLifetime+p.maxConnLifetimeJitter {
		return true
	}
	if p.maxConnLifetimeJitter == 0 {
		return false
	}
	jitterSecs := rand.Float64() * p.maxConnLifetimeJitter.Seconds()
	return now.Sub(res.CreationTime()) > p.maxConnLifetime+(time.Duration(jitterSecs)*time.Second)
}

// triggerHealthCheck asynchronously requests a health check; the send is
// non-blocking so concurrent triggers collapse into one pending signal.
func (p *Pool) triggerHealthCheck() {
	go func() {
		// Destroy is asynchronous so we give it time to actually remove itself from
		// the pool otherwise we might try to check the pool size too soon
		time.Sleep(500 * time.Millisecond)
		select {
		case p.healthCheckChan <- struct{}{}:
		default:
		}
	}()
}

// backgroundHealthCheck runs checkHealth on every healthCheckPeriod tick and on
// explicit triggers, until the pool is closed.
func (p *Pool) backgroundHealthCheck() {
	ticker := time.NewTicker(p.healthCheckPeriod)
	defer ticker.Stop()
	for {
		select {
		case <-p.closeChan:
			return
		case <-p.healthCheckChan:
			p.checkHealth()
		case <-ticker.C:
			p.checkHealth()
		}
	}
}

// checkHealth repeatedly tops the pool up to minConns and destroys unhealthy
// connections until a pass destroys nothing (or the pool closes).
func (p *Pool) checkHealth() {
	for {
		// If checkMinConns failed we don't destroy any connections since we couldn't
		// even get to minConns
		if err := p.checkMinConns(); err != nil {
			// Should we log this error somewhere?
			break
		}
		if !p.checkConnsHealth() {
			// Since we didn't destroy any connections we can stop looping
			break
		}
		// Technically Destroy is asynchronous but 500ms should be enough for it to
		// remove it from the underlying pool
		select {
		case <-p.closeChan:
			return
		case <-time.After(500 * time.Millisecond):
		}
	}
}

// checkConnsHealth will check all idle connections, destroy a connection if
// it's idle or too old, and returns true if any were destroyed
func (p *Pool) checkConnsHealth() bool {
	var destroyed bool
	totalConns := p.Stat().TotalConns()
	resources := p.p.AcquireAllIdle()
	for _, res := range resources {
		// We're okay going under minConns if the lifetime is up
		if p.isExpired(res) && totalConns >= p.minConns {
			atomic.AddInt64(&p.lifetimeDestroyCount, 1)
			res.Destroy()
			destroyed = true
			// Since Destroy is async we manually decrement totalConns.
			totalConns--
		} else if res.IdleDuration() > p.maxConnIdleTime && totalConns > p.minConns {
			atomic.AddInt64(&p.idleDestroyCount, 1)
			res.Destroy()
			destroyed = true
			// Since Destroy is async we manually decrement totalConns.
			totalConns--
		} else {
			// Healthy: return it to the idle set without touching statistics.
			res.ReleaseUnused()
		}
	}
	return destroyed
}

// checkMinConnsWithContext creates enough new idle connections to reach minConns.
func (p *Pool) checkMinConnsWithContext(ctx context.Context) error {
	// TotalConns can include ones that are being destroyed but we should have
	// sleep(500ms) around all of the destroys to help prevent that from throwing
	// off this check
	toCreate := p.minConns - p.Stat().TotalConns()
	if toCreate > 0 {
		return p.createIdleResources(ctx, int(toCreate))
	}
	return nil
}

// checkMinConns is checkMinConnsWithContext with no cancellation (used by the
// background health check).
func (p *Pool) checkMinConns() error {
	return p.checkMinConnsWithContext(context.Background())
}

// createIdleResources concurrently creates targetResources connections,
// cancelling the remainder after the first failure and returning that error.
func (p *Pool) createIdleResources(parentCtx context.Context, targetResources int) error {
	ctx, cancel := context.WithCancel(parentCtx)
	defer cancel()

	errs := make(chan error, targetResources)

	for i := 0; i < targetResources; i++ {
		go func() {
			atomic.AddInt64(&p.newConnsCount, 1)
			err := p.p.CreateResource(ctx)
			errs <- err
		}()
	}

	// Drain every result so no goroutine is left blocked on the channel.
	var firstError error
	for i := 0; i < targetResources; i++ {
		err := <-errs
		if err != nil && firstError == nil {
			cancel()
			firstError = err
		}
	}

	return firstError
}

// Acquire returns a connection (*Conn) from the Pool
func (p *Pool) Acquire(ctx context.Context) (*Conn, error) {
	// Loop: connections rejected by BeforeAcquire are destroyed and another is tried.
	for {
		res, err := p.p.Acquire(ctx)
		if err != nil {
			return nil, err
		}

		cr := res.Value().(*connResource)
		if p.beforeAcquire == nil || p.beforeAcquire(ctx, cr.conn) {
			return cr.getConn(p, res), nil
		}

		res.Destroy()
	}
}

// AcquireFunc acquires a *Conn and calls f with that *Conn. ctx will only affect the Acquire. It has no effect on the
// call of f. The return value is either an error acquiring the *Conn or the return value of f. The *Conn is
// automatically released after the call of f.
func (p *Pool) AcquireFunc(ctx context.Context, f func(*Conn) error) error { conn, err := p.Acquire(ctx) if err != nil { return err } defer conn.Release() return f(conn) } // AcquireAllIdle atomically acquires all currently idle connections. Its intended use is for health check and // keep-alive functionality. It does not update pool statistics. func (p *Pool) AcquireAllIdle(ctx context.Context) []*Conn { resources := p.p.AcquireAllIdle() conns := make([]*Conn, 0, len(resources)) for _, res := range resources { cr := res.Value().(*connResource) if p.beforeAcquire == nil || p.beforeAcquire(ctx, cr.conn) { conns = append(conns, cr.getConn(p, res)) } else { res.Destroy() } } return conns } // Config returns a copy of config that was used to initialize this pool. func (p *Pool) Config() *Config { return p.config.Copy() } // Stat returns a pgxpool.Stat struct with a snapshot of Pool statistics. func (p *Pool) Stat() *Stat { return &Stat{ s: p.p.Stat(), newConnsCount: atomic.LoadInt64(&p.newConnsCount), lifetimeDestroyCount: atomic.LoadInt64(&p.lifetimeDestroyCount), idleDestroyCount: atomic.LoadInt64(&p.idleDestroyCount), } } // Exec acquires a connection from the Pool and executes the given SQL. // SQL can be either a prepared statement name or an SQL string. // Arguments should be referenced positionally from the SQL string as $1, $2, etc. // The acquired connection is returned to the pool when the Exec function returns. func (p *Pool) Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) { c, err := p.Acquire(ctx) if err != nil { return nil, err } defer c.Release() return c.Exec(ctx, sql, arguments...) } // Query acquires a connection and executes a query that returns pgx.Rows. // Arguments should be referenced positionally from the SQL string as $1, $2, etc. // See pgx.Rows documentation to close the returned Rows and return the acquired connection to the Pool. 
// // If there is an error, the returned pgx.Rows will be returned in an error state. // If preferred, ignore the error returned from Query and handle errors using the returned pgx.Rows. // // For extra control over how the query is executed, the types QuerySimpleProtocol, QueryResultFormats, and // QueryResultFormatsByOID may be used as the first args to control exactly how the query is executed. This is rarely // needed. See the documentation for those types for details. func (p *Pool) Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) { c, err := p.Acquire(ctx) if err != nil { return errRows{err: err}, err } rows, err := c.Query(ctx, sql, args...) if err != nil { c.Release() return errRows{err: err}, err } return c.getPoolRows(rows), nil } // QueryRow acquires a connection and executes a query that is expected // to return at most one row (pgx.Row). Errors are deferred until pgx.Row's // Scan method is called. If the query selects no rows, pgx.Row's Scan will // return ErrNoRows. Otherwise, pgx.Row's Scan scans the first selected row // and discards the rest. The acquired connection is returned to the Pool when // pgx.Row's Scan method is called. // // Arguments should be referenced positionally from the SQL string as $1, $2, etc. // // For extra control over how the query is executed, the types QuerySimpleProtocol, QueryResultFormats, and // QueryResultFormatsByOID may be used as the first args to control exactly how the query is executed. This is rarely // needed. See the documentation for those types for details. func (p *Pool) QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row { c, err := p.Acquire(ctx) if err != nil { return errRow{err: err} } row := c.QueryRow(ctx, sql, args...) 
return c.getPoolRow(row) } func (p *Pool) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(pgx.QueryFuncRow) error) (pgconn.CommandTag, error) { c, err := p.Acquire(ctx) if err != nil { return nil, err } defer c.Release() return c.QueryFunc(ctx, sql, args, scans, f) } func (p *Pool) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults { c, err := p.Acquire(ctx) if err != nil { return errBatchResults{err: err} } br := c.SendBatch(ctx, b) return &poolBatchResults{br: br, c: c} } // Begin acquires a connection from the Pool and starts a transaction. Unlike database/sql, the context only affects the begin command. i.e. there is no // auto-rollback on context cancellation. Begin initiates a transaction block without explicitly setting a transaction mode for the block (see BeginTx with TxOptions if transaction mode is required). // *pgxpool.Tx is returned, which implements the pgx.Tx interface. // Commit or Rollback must be called on the returned transaction to finalize the transaction block. func (p *Pool) Begin(ctx context.Context) (pgx.Tx, error) { return p.BeginTx(ctx, pgx.TxOptions{}) } // BeginTx acquires a connection from the Pool and starts a transaction with pgx.TxOptions determining the transaction mode. // Unlike database/sql, the context only affects the begin command. i.e. there is no auto-rollback on context cancellation. // *pgxpool.Tx is returned, which implements the pgx.Tx interface. // Commit or Rollback must be called on the returned transaction to finalize the transaction block. 
func (p *Pool) BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, error) { c, err := p.Acquire(ctx) if err != nil { return nil, err } t, err := c.BeginTx(ctx, txOptions) if err != nil { c.Release() return nil, err } return &Tx{t: t, c: c}, nil } func (p *Pool) BeginFunc(ctx context.Context, f func(pgx.Tx) error) error { return p.BeginTxFunc(ctx, pgx.TxOptions{}, f) } func (p *Pool) BeginTxFunc(ctx context.Context, txOptions pgx.TxOptions, f func(pgx.Tx) error) error { c, err := p.Acquire(ctx) if err != nil { return err } defer c.Release() return c.BeginTxFunc(ctx, txOptions, f) } func (p *Pool) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) { c, err := p.Acquire(ctx) if err != nil { return 0, err } defer c.Release() return c.Conn().CopyFrom(ctx, tableName, columnNames, rowSrc) } // Ping acquires a connection from the Pool and executes an empty sql statement against it. // If the sql returns without error, the database Ping is considered successful, otherwise, the error is returned. 
func (p *Pool) Ping(ctx context.Context) error { c, err := p.Acquire(ctx) if err != nil { return err } defer c.Release() return c.Ping(ctx) } pgx-4.18.1/pgxpool/pool_test.go000066400000000000000000000706071437725773200164650ustar00rootroot00000000000000package pgxpool_test import ( "context" "errors" "fmt" "os" "sync/atomic" "testing" "time" "github.com/jackc/pgx/v4" "github.com/jackc/pgx/v4/pgxpool" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestConnect(t *testing.T) { t.Parallel() connString := os.Getenv("PGX_TEST_DATABASE") pool, err := pgxpool.Connect(context.Background(), connString) require.NoError(t, err) assert.Equal(t, connString, pool.Config().ConnString()) pool.Close() } func TestConnectConfig(t *testing.T) { t.Parallel() connString := os.Getenv("PGX_TEST_DATABASE") config, err := pgxpool.ParseConfig(connString) require.NoError(t, err) pool, err := pgxpool.ConnectConfig(context.Background(), config) require.NoError(t, err) assertConfigsEqual(t, config, pool.Config(), "Pool.Config() returns original config") pool.Close() } func TestParseConfigExtractsPoolArguments(t *testing.T) { t.Parallel() config, err := pgxpool.ParseConfig("pool_max_conns=42 pool_min_conns=1") assert.NoError(t, err) assert.EqualValues(t, 42, config.MaxConns) assert.EqualValues(t, 1, config.MinConns) assert.NotContains(t, config.ConnConfig.Config.RuntimeParams, "pool_max_conns") assert.NotContains(t, config.ConnConfig.Config.RuntimeParams, "pool_min_conns") } func TestConnectCancel(t *testing.T) { t.Parallel() ctx, cancel := context.WithCancel(context.Background()) cancel() pool, err := pgxpool.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) assert.Nil(t, pool) assert.Equal(t, context.Canceled, err) } func TestLazyConnect(t *testing.T) { t.Parallel() config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) assert.NoError(t, err) config.LazyConnect = true ctx, cancel := context.WithCancel(context.Background()) cancel() pool, err := 
pgxpool.ConnectConfig(ctx, config) assert.NoError(t, err) _, err = pool.Exec(ctx, "SELECT 1") assert.Equal(t, context.Canceled, err) } func TestBeforeConnectWithContextWithValueAndOneMinConn(t *testing.T) { t.Parallel() config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) assert.NoError(t, err) config.MinConns = 1 config.BeforeConnect = func(ctx context.Context, config *pgx.ConnConfig) error { val := ctx.Value("key") if val == nil { return errors.New("no value found with key 'key'") } return nil } ctx := context.WithValue(context.Background(), "key", "value") _, err = pgxpool.ConnectConfig(ctx, config) assert.NoError(t, err) } func TestConstructorIgnoresContext(t *testing.T) { t.Parallel() config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) assert.NoError(t, err) config.LazyConnect = true var cancel func() config.BeforeConnect = func(context.Context, *pgx.ConnConfig) error { // cancel the query's context before we actually Dial to ensure the Dial's // context isn't cancelled cancel() return nil } pool, err := pgxpool.ConnectConfig(context.Background(), config) require.NoError(t, err) assert.EqualValues(t, 0, pool.Stat().TotalConns()) var ctx context.Context ctx, cancel = context.WithCancel(context.Background()) defer cancel() _, err = pool.Exec(ctx, "SELECT 1") assert.ErrorIs(t, err, context.Canceled) assert.EqualValues(t, 1, pool.Stat().TotalConns()) } func TestConnectConfigRequiresConnConfigFromParseConfig(t *testing.T) { t.Parallel() config := &pgxpool.Config{} require.PanicsWithValue(t, "config must be created by ParseConfig", func() { pgxpool.ConnectConfig(context.Background(), config) }) } func TestConfigCopyReturnsEqualConfig(t *testing.T) { connString := "postgres://jack:secret@localhost:5432/mydb?application_name=pgxtest&search_path=myschema&connect_timeout=5" original, err := pgxpool.ParseConfig(connString) require.NoError(t, err) copied := original.Copy() assertConfigsEqual(t, original, copied, t.Name()) } func 
TestConfigCopyCanBeUsedToConnect(t *testing.T) { connString := os.Getenv("PGX_TEST_DATABASE") original, err := pgxpool.ParseConfig(connString) require.NoError(t, err) copied := original.Copy() assert.NotPanics(t, func() { _, err = pgxpool.ConnectConfig(context.Background(), copied) }) assert.NoError(t, err) } func TestPoolAcquireAndConnRelease(t *testing.T) { t.Parallel() pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() c, err := pool.Acquire(context.Background()) require.NoError(t, err) c.Release() } func TestPoolAcquireAndConnHijack(t *testing.T) { t.Parallel() ctx := context.Background() pool, err := pgxpool.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() c, err := pool.Acquire(ctx) require.NoError(t, err) connsBeforeHijack := pool.Stat().TotalConns() conn := c.Hijack() defer conn.Close(ctx) connsAfterHijack := pool.Stat().TotalConns() require.Equal(t, connsBeforeHijack-1, connsAfterHijack) var n int32 err = conn.QueryRow(ctx, `select 1`).Scan(&n) require.NoError(t, err) require.Equal(t, int32(1), n) } func TestPoolAcquireFunc(t *testing.T) { t.Parallel() pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() var n int32 err = pool.AcquireFunc(context.Background(), func(c *pgxpool.Conn) error { return c.QueryRow(context.Background(), "select 1").Scan(&n) }) require.NoError(t, err) require.EqualValues(t, 1, n) } func TestPoolAcquireFuncReturnsFnError(t *testing.T) { t.Parallel() pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() err = pool.AcquireFunc(context.Background(), func(c *pgxpool.Conn) error { return fmt.Errorf("some error") }) require.EqualError(t, err, "some error") } func TestPoolBeforeConnect(t *testing.T) { t.Parallel() config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) 
require.NoError(t, err) config.BeforeConnect = func(ctx context.Context, cfg *pgx.ConnConfig) error { cfg.Config.RuntimeParams["application_name"] = "pgx" return nil } db, err := pgxpool.ConnectConfig(context.Background(), config) require.NoError(t, err) defer db.Close() var str string err = db.QueryRow(context.Background(), "SHOW application_name").Scan(&str) require.NoError(t, err) assert.EqualValues(t, "pgx", str) } func TestPoolAfterConnect(t *testing.T) { t.Parallel() config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) config.AfterConnect = func(ctx context.Context, c *pgx.Conn) error { _, err := c.Prepare(ctx, "ps1", "select 1") return err } db, err := pgxpool.ConnectConfig(context.Background(), config) require.NoError(t, err) defer db.Close() var n int32 err = db.QueryRow(context.Background(), "ps1").Scan(&n) require.NoError(t, err) assert.EqualValues(t, 1, n) } func TestPoolBeforeAcquire(t *testing.T) { t.Parallel() config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) acquireAttempts := 0 config.BeforeAcquire = func(ctx context.Context, c *pgx.Conn) bool { acquireAttempts++ return acquireAttempts%2 == 0 } db, err := pgxpool.ConnectConfig(context.Background(), config) require.NoError(t, err) defer db.Close() conns := make([]*pgxpool.Conn, 4) for i := range conns { conns[i], err = db.Acquire(context.Background()) assert.NoError(t, err) } for _, c := range conns { c.Release() } waitForReleaseToComplete() assert.EqualValues(t, 8, acquireAttempts) conns = db.AcquireAllIdle(context.Background()) assert.Len(t, conns, 2) for _, c := range conns { c.Release() } waitForReleaseToComplete() assert.EqualValues(t, 12, acquireAttempts) } func TestPoolAfterRelease(t *testing.T) { t.Parallel() func() { pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() err = pool.AcquireFunc(context.Background(), func(conn *pgxpool.Conn) 
error { if conn.Conn().PgConn().ParameterStatus("crdb_version") != "" { t.Skip("Server does not support backend PID") } return nil }) require.NoError(t, err) }() config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) afterReleaseCount := 0 config.AfterRelease = func(c *pgx.Conn) bool { afterReleaseCount++ return afterReleaseCount%2 == 1 } db, err := pgxpool.ConnectConfig(context.Background(), config) require.NoError(t, err) defer db.Close() connPIDs := map[uint32]struct{}{} for i := 0; i < 10; i++ { conn, err := db.Acquire(context.Background()) assert.NoError(t, err) connPIDs[conn.Conn().PgConn().PID()] = struct{}{} conn.Release() waitForReleaseToComplete() } assert.EqualValues(t, 5, len(connPIDs)) } func TestPoolAcquireAllIdle(t *testing.T) { t.Parallel() db, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer db.Close() conns := db.AcquireAllIdle(context.Background()) assert.Len(t, conns, 1) for _, c := range conns { c.Release() } waitForReleaseToComplete() conns = make([]*pgxpool.Conn, 3) for i := range conns { conns[i], err = db.Acquire(context.Background()) assert.NoError(t, err) } for _, c := range conns { if c != nil { c.Release() } } waitForReleaseToComplete() conns = db.AcquireAllIdle(context.Background()) assert.Len(t, conns, 3) for _, c := range conns { c.Release() } } func TestConnReleaseChecksMaxConnLifetime(t *testing.T) { t.Parallel() config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) config.MaxConnLifetime = 250 * time.Millisecond db, err := pgxpool.ConnectConfig(context.Background(), config) require.NoError(t, err) defer db.Close() c, err := db.Acquire(context.Background()) require.NoError(t, err) time.Sleep(config.MaxConnLifetime) c.Release() waitForReleaseToComplete() stats := db.Stat() assert.EqualValues(t, 0, stats.TotalConns()) } func TestConnReleaseClosesBusyConn(t *testing.T) { t.Parallel() db, err := 
pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer db.Close() c, err := db.Acquire(context.Background()) require.NoError(t, err) _, err = c.Query(context.Background(), "select generate_series(1,10)") require.NoError(t, err) c.Release() waitForReleaseToComplete() // wait for the connection to actually be destroyed for i := 0; i < 1000; i++ { if db.Stat().TotalConns() == 0 { break } time.Sleep(time.Millisecond) } stats := db.Stat() assert.EqualValues(t, 0, stats.TotalConns()) } func TestPoolBackgroundChecksMaxConnLifetime(t *testing.T) { t.Parallel() config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) config.MaxConnLifetime = 100 * time.Millisecond config.HealthCheckPeriod = 100 * time.Millisecond db, err := pgxpool.ConnectConfig(context.Background(), config) require.NoError(t, err) defer db.Close() c, err := db.Acquire(context.Background()) require.NoError(t, err) c.Release() time.Sleep(config.MaxConnLifetime + 500*time.Millisecond) stats := db.Stat() assert.EqualValues(t, 0, stats.TotalConns()) assert.EqualValues(t, 0, stats.MaxIdleDestroyCount()) assert.EqualValues(t, 1, stats.MaxLifetimeDestroyCount()) } func TestPoolBackgroundChecksMaxConnIdleTime(t *testing.T) { t.Parallel() config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) config.MaxConnLifetime = 1 * time.Minute config.MaxConnIdleTime = 100 * time.Millisecond config.HealthCheckPeriod = 150 * time.Millisecond db, err := pgxpool.ConnectConfig(context.Background(), config) require.NoError(t, err) defer db.Close() c, err := db.Acquire(context.Background()) require.NoError(t, err) c.Release() time.Sleep(config.HealthCheckPeriod) for i := 0; i < 1000; i++ { if db.Stat().TotalConns() == 0 { break } time.Sleep(time.Millisecond) } stats := db.Stat() assert.EqualValues(t, 0, stats.TotalConns()) assert.EqualValues(t, 1, stats.MaxIdleDestroyCount()) assert.EqualValues(t, 0, 
stats.MaxLifetimeDestroyCount()) } func TestPoolBackgroundChecksMinConns(t *testing.T) { config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) config.HealthCheckPeriod = 100 * time.Millisecond config.MinConns = 2 db, err := pgxpool.ConnectConfig(context.Background(), config) require.NoError(t, err) defer db.Close() time.Sleep(config.HealthCheckPeriod + 500*time.Millisecond) stats := db.Stat() assert.EqualValues(t, 2, stats.TotalConns()) assert.EqualValues(t, 0, stats.MaxLifetimeDestroyCount()) assert.EqualValues(t, 2, stats.NewConnsCount()) c, err := db.Acquire(context.Background()) require.NoError(t, err) err = c.Conn().Close(context.Background()) require.NoError(t, err) c.Release() time.Sleep(config.HealthCheckPeriod + 500*time.Millisecond) stats = db.Stat() assert.EqualValues(t, 2, stats.TotalConns()) assert.EqualValues(t, 0, stats.MaxIdleDestroyCount()) assert.EqualValues(t, 3, stats.NewConnsCount()) } func TestPoolExec(t *testing.T) { t.Parallel() pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() testExec(t, pool) } func TestPoolQuery(t *testing.T) { t.Parallel() pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() // Test common usage testQuery(t, pool) waitForReleaseToComplete() // Test expected pool behavior rows, err := pool.Query(context.Background(), "select generate_series(1,$1)", 10) require.NoError(t, err) stats := pool.Stat() assert.EqualValues(t, 1, stats.AcquiredConns()) assert.EqualValues(t, 1, stats.TotalConns()) rows.Close() assert.NoError(t, rows.Err()) waitForReleaseToComplete() stats = pool.Stat() assert.EqualValues(t, 0, stats.AcquiredConns()) assert.EqualValues(t, 1, stats.TotalConns()) } func TestPoolQueryRow(t *testing.T) { t.Parallel() pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer 
pool.Close() testQueryRow(t, pool) waitForReleaseToComplete() stats := pool.Stat() assert.EqualValues(t, 0, stats.AcquiredConns()) assert.EqualValues(t, 1, stats.TotalConns()) } // https://github.com/jackc/pgx/issues/677 func TestPoolQueryRowErrNoRows(t *testing.T) { t.Parallel() pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() err = pool.QueryRow(context.Background(), "select n from generate_series(1,10) n where n=0").Scan(nil) require.Equal(t, pgx.ErrNoRows, err) } func TestPoolSendBatch(t *testing.T) { t.Parallel() pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() testSendBatch(t, pool) waitForReleaseToComplete() stats := pool.Stat() assert.EqualValues(t, 0, stats.AcquiredConns()) assert.EqualValues(t, 1, stats.TotalConns()) } func TestPoolCopyFrom(t *testing.T) { // Not able to use testCopyFrom because it relies on temporary tables and the pool may run subsequent calls under // different connections. 
t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() pool, err := pgxpool.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() _, err = pool.Exec(ctx, `drop table if exists poolcopyfromtest`) require.NoError(t, err) _, err = pool.Exec(ctx, `create table poolcopyfromtest(a int2, b int4, c int8, d varchar, e text, f date, g timestamptz)`) require.NoError(t, err) defer pool.Exec(ctx, `drop table poolcopyfromtest`) tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local) inputRows := [][]interface{}{ {int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime}, {nil, nil, nil, nil, nil, nil, nil}, } copyCount, err := pool.CopyFrom(ctx, pgx.Identifier{"poolcopyfromtest"}, []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyFromRows(inputRows)) assert.NoError(t, err) assert.EqualValues(t, len(inputRows), copyCount) rows, err := pool.Query(ctx, "select * from poolcopyfromtest") assert.NoError(t, err) var outputRows [][]interface{} for rows.Next() { row, err := rows.Values() if err != nil { t.Errorf("Unexpected error for rows.Values(): %v", err) } outputRows = append(outputRows, row) } assert.NoError(t, rows.Err()) assert.Equal(t, inputRows, outputRows) } func TestConnReleaseClosesConnInFailedTransaction(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() pool, err := pgxpool.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() err = pool.AcquireFunc(ctx, func(conn *pgxpool.Conn) error { if conn.Conn().PgConn().ParameterStatus("crdb_version") != "" { t.Skip("Server does not support backend PID") } return nil }) require.NoError(t, err) c, err := pool.Acquire(ctx) require.NoError(t, err) pid := c.Conn().PgConn().PID() assert.Equal(t, byte('I'), c.Conn().PgConn().TxStatus()) _, err = c.Exec(ctx, "begin") assert.NoError(t, err) assert.Equal(t, byte('T'), 
c.Conn().PgConn().TxStatus()) _, err = c.Exec(ctx, "selct") assert.Error(t, err) assert.Equal(t, byte('E'), c.Conn().PgConn().TxStatus()) c.Release() waitForReleaseToComplete() c, err = pool.Acquire(ctx) require.NoError(t, err) assert.NotEqual(t, pid, c.Conn().PgConn().PID()) assert.Equal(t, byte('I'), c.Conn().PgConn().TxStatus()) c.Release() } func TestConnReleaseClosesConnInTransaction(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() pool, err := pgxpool.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() err = pool.AcquireFunc(ctx, func(conn *pgxpool.Conn) error { if conn.Conn().PgConn().ParameterStatus("crdb_version") != "" { t.Skip("Server does not support backend PID") } return nil }) require.NoError(t, err) c, err := pool.Acquire(ctx) require.NoError(t, err) pid := c.Conn().PgConn().PID() assert.Equal(t, byte('I'), c.Conn().PgConn().TxStatus()) _, err = c.Exec(ctx, "begin") assert.NoError(t, err) assert.Equal(t, byte('T'), c.Conn().PgConn().TxStatus()) c.Release() waitForReleaseToComplete() c, err = pool.Acquire(ctx) require.NoError(t, err) assert.NotEqual(t, pid, c.Conn().PgConn().PID()) assert.Equal(t, byte('I'), c.Conn().PgConn().TxStatus()) c.Release() } func TestConnReleaseDestroysClosedConn(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() pool, err := pgxpool.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() c, err := pool.Acquire(ctx) require.NoError(t, err) err = c.Conn().Close(ctx) require.NoError(t, err) assert.EqualValues(t, 1, pool.Stat().TotalConns()) c.Release() waitForReleaseToComplete() // wait for the connection to actually be destroyed for i := 0; i < 1000; i++ { if pool.Stat().TotalConns() == 0 { break } time.Sleep(time.Millisecond) } assert.EqualValues(t, 0, pool.Stat().TotalConns()) } func TestConnPoolQueryConcurrentLoad(t 
*testing.T) { t.Parallel() pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() n := 100 done := make(chan bool) for i := 0; i < n; i++ { go func() { defer func() { done <- true }() testQuery(t, pool) testQueryRow(t, pool) }() } for i := 0; i < n; i++ { <-done } } func TestConnReleaseWhenBeginFail(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() db, err := pgxpool.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer db.Close() tx, err := db.BeginTx(ctx, pgx.TxOptions{ IsoLevel: pgx.TxIsoLevel("foo"), }) assert.Error(t, err) if !assert.Zero(t, tx) { err := tx.Rollback(ctx) assert.NoError(t, err) } for i := 0; i < 1000; i++ { if db.Stat().TotalConns() == 0 { break } time.Sleep(time.Millisecond) } assert.EqualValues(t, 0, db.Stat().TotalConns()) } func TestTxBeginFuncNestedTransactionCommit(t *testing.T) { db, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer db.Close() createSql := ` drop table if exists pgxpooltx; create temporary table pgxpooltx( id integer, unique (id) ); ` _, err = db.Exec(context.Background(), createSql) require.NoError(t, err) defer func() { db.Exec(context.Background(), "drop table pgxpooltx") }() err = db.BeginFunc(context.Background(), func(db pgx.Tx) error { _, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (1)") require.NoError(t, err) err = db.BeginFunc(context.Background(), func(db pgx.Tx) error { _, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (2)") require.NoError(t, err) err = db.BeginFunc(context.Background(), func(db pgx.Tx) error { _, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (3)") require.NoError(t, err) return nil }) require.NoError(t, err) return nil }) require.NoError(t, err) return nil }) require.NoError(t, err) var n int64 err = 
db.QueryRow(context.Background(), "select count(*) from pgxpooltx").Scan(&n) require.NoError(t, err) require.EqualValues(t, 3, n) } func TestTxBeginFuncNestedTransactionRollback(t *testing.T) { db, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer db.Close() createSql := ` drop table if exists pgxpooltx; create temporary table pgxpooltx( id integer, unique (id) ); ` _, err = db.Exec(context.Background(), createSql) require.NoError(t, err) defer func() { db.Exec(context.Background(), "drop table pgxpooltx") }() err = db.BeginFunc(context.Background(), func(db pgx.Tx) error { _, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (1)") require.NoError(t, err) err = db.BeginFunc(context.Background(), func(db pgx.Tx) error { _, err := db.Exec(context.Background(), "insert into pgxpooltx(id) values (2)") require.NoError(t, err) return errors.New("do a rollback") }) require.EqualError(t, err, "do a rollback") _, err = db.Exec(context.Background(), "insert into pgxpooltx(id) values (3)") require.NoError(t, err) return nil }) require.NoError(t, err) var n int64 err = db.QueryRow(context.Background(), "select count(*) from pgxpooltx").Scan(&n) require.NoError(t, err) require.EqualValues(t, 2, n) } func TestIdempotentPoolClose(t *testing.T) { pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) // Close the open pool. require.NotPanics(t, func() { pool.Close() }) // Close the already closed pool. 
require.NotPanics(t, func() { pool.Close() }) } func TestConnectCreatesMinPool(t *testing.T) { t.Parallel() config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) config.MinConns = int32(12) config.MaxConns = int32(15) config.LazyConnect = false acquireAttempts := int64(0) connectAttempts := int64(0) config.BeforeAcquire = func(ctx context.Context, conn *pgx.Conn) bool { atomic.AddInt64(&acquireAttempts, 1) return true } config.BeforeConnect = func(ctx context.Context, cfg *pgx.ConnConfig) error { atomic.AddInt64(&connectAttempts, 1) return nil } pool, err := pgxpool.ConnectConfig(context.Background(), config) require.NoError(t, err) defer pool.Close() stat := pool.Stat() require.Equal(t, int32(12), stat.IdleConns()) require.Equal(t, int64(1), stat.AcquireCount()) require.Equal(t, int32(12), stat.TotalConns()) require.Equal(t, int64(0), acquireAttempts) require.Equal(t, int64(12), connectAttempts) } func TestConnectSkipMinPoolWithLazy(t *testing.T) { t.Parallel() config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) config.MinConns = int32(12) config.MaxConns = int32(15) config.LazyConnect = true acquireAttempts := int64(0) connectAttempts := int64(0) config.BeforeAcquire = func(ctx context.Context, conn *pgx.Conn) bool { atomic.AddInt64(&acquireAttempts, 1) return true } config.BeforeConnect = func(ctx context.Context, cfg *pgx.ConnConfig) error { atomic.AddInt64(&connectAttempts, 1) return nil } pool, err := pgxpool.ConnectConfig(context.Background(), config) require.NoError(t, err) defer pool.Close() stat := pool.Stat() require.Equal(t, int32(0), stat.IdleConns()) require.Equal(t, int64(0), stat.AcquireCount()) require.Equal(t, int32(0), stat.TotalConns()) require.Equal(t, int64(0), acquireAttempts) require.Equal(t, int64(0), connectAttempts) } func TestConnectMinPoolZero(t *testing.T) { t.Parallel() config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) 
config.MinConns = int32(0) config.MaxConns = int32(15) config.LazyConnect = false acquireAttempts := int64(0) connectAttempts := int64(0) config.BeforeAcquire = func(ctx context.Context, conn *pgx.Conn) bool { atomic.AddInt64(&acquireAttempts, 1) return true } config.BeforeConnect = func(ctx context.Context, cfg *pgx.ConnConfig) error { atomic.AddInt64(&connectAttempts, 1) return nil } pool, err := pgxpool.ConnectConfig(context.Background(), config) require.NoError(t, err) defer pool.Close() stat := pool.Stat() require.Equal(t, int32(1), stat.IdleConns()) require.Equal(t, int64(1), stat.AcquireCount()) require.Equal(t, int32(1), stat.TotalConns()) require.Equal(t, int64(0), acquireAttempts) require.Equal(t, int64(1), connectAttempts) } func TestCreateMinPoolClosesConnectionsOnError(t *testing.T) { t.Parallel() config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) config.MinConns = int32(12) config.MaxConns = int32(15) config.LazyConnect = false acquireAttempts := int64(0) madeConnections := int64(0) conns := make(chan *pgx.Conn, 15) config.BeforeAcquire = func(ctx context.Context, conn *pgx.Conn) bool { atomic.AddInt64(&acquireAttempts, 1) return true } config.AfterConnect = func(ctx context.Context, conn *pgx.Conn) error { conns <- conn atomic.AddInt64(&madeConnections, 1) mc := atomic.LoadInt64(&madeConnections) if mc == 10 { return errors.New("mock error") } return nil } pool, err := pgxpool.ConnectConfig(context.Background(), config) require.Error(t, err) require.Nil(t, pool) close(conns) for conn := range conns { require.True(t, conn.IsClosed()) } require.Equal(t, int64(0), acquireAttempts) require.True(t, madeConnections >= 10, "Expected %d got %d", 10, madeConnections) } func TestCreateMinPoolReturnsFirstError(t *testing.T) { t.Parallel() config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) config.MinConns = int32(12) config.MaxConns = int32(15) config.LazyConnect = false 
acquireAttempts := int64(0) connectAttempts := int64(0) mockErr := errors.New("mock connect error") config.BeforeAcquire = func(ctx context.Context, conn *pgx.Conn) bool { atomic.AddInt64(&acquireAttempts, 1) return true } config.BeforeConnect = func(ctx context.Context, cfg *pgx.ConnConfig) error { atomic.AddInt64(&connectAttempts, 1) ca := atomic.LoadInt64(&connectAttempts) if ca >= 5 { return mockErr } return nil } pool, err := pgxpool.ConnectConfig(context.Background(), config) require.Nil(t, pool) require.Error(t, err) require.True(t, connectAttempts >= 5, "Expected %d got %d", 5, connectAttempts) require.ErrorIs(t, err, mockErr) } func TestPoolSendBatchBatchCloseTwice(t *testing.T) { t.Parallel() pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() errChan := make(chan error) testCount := 5000 for i := 0; i < testCount; i++ { go func() { batch := &pgx.Batch{} batch.Queue("select 1") batch.Queue("select 2") br := pool.SendBatch(context.Background(), batch) defer br.Close() var err error var n int32 err = br.QueryRow().Scan(&n) if err != nil { errChan <- err return } if n != 1 { errChan <- fmt.Errorf("expected 1 got %v", n) return } err = br.QueryRow().Scan(&n) if err != nil { errChan <- err return } if n != 2 { errChan <- fmt.Errorf("expected 2 got %v", n) return } err = br.Close() errChan <- err }() } for i := 0; i < testCount; i++ { err := <-errChan assert.NoError(t, err) } } pgx-4.18.1/pgxpool/rows.go000066400000000000000000000040361437725773200154400ustar00rootroot00000000000000package pgxpool import ( "github.com/jackc/pgconn" "github.com/jackc/pgproto3/v2" "github.com/jackc/pgx/v4" ) type errRows struct { err error } func (errRows) Close() {} func (e errRows) Err() error { return e.err } func (errRows) CommandTag() pgconn.CommandTag { return nil } func (errRows) FieldDescriptions() []pgproto3.FieldDescription { return nil } func (errRows) Next() bool { return false } func (e 
errRows) Scan(dest ...interface{}) error { return e.err } func (e errRows) Values() ([]interface{}, error) { return nil, e.err } func (e errRows) RawValues() [][]byte { return nil } type errRow struct { err error } func (e errRow) Scan(dest ...interface{}) error { return e.err } type poolRows struct { r pgx.Rows c *Conn err error } func (rows *poolRows) Close() { rows.r.Close() if rows.c != nil { rows.c.Release() rows.c = nil } } func (rows *poolRows) Err() error { if rows.err != nil { return rows.err } return rows.r.Err() } func (rows *poolRows) CommandTag() pgconn.CommandTag { return rows.r.CommandTag() } func (rows *poolRows) FieldDescriptions() []pgproto3.FieldDescription { return rows.r.FieldDescriptions() } func (rows *poolRows) Next() bool { if rows.err != nil { return false } n := rows.r.Next() if !n { rows.Close() } return n } func (rows *poolRows) Scan(dest ...interface{}) error { err := rows.r.Scan(dest...) if err != nil { rows.Close() } return err } func (rows *poolRows) Values() ([]interface{}, error) { values, err := rows.r.Values() if err != nil { rows.Close() } return values, err } func (rows *poolRows) RawValues() [][]byte { return rows.r.RawValues() } type poolRow struct { r pgx.Row c *Conn err error } func (row *poolRow) Scan(dest ...interface{}) error { if row.err != nil { return row.err } err := row.r.Scan(dest...) if row.c != nil { row.c.Release() } return err } pgx-4.18.1/pgxpool/stat.go000066400000000000000000000045021437725773200154170ustar00rootroot00000000000000package pgxpool import ( "time" "github.com/jackc/puddle" ) // Stat is a snapshot of Pool statistics. type Stat struct { s *puddle.Stat newConnsCount int64 lifetimeDestroyCount int64 idleDestroyCount int64 } // AcquireCount returns the cumulative count of successful acquires from the pool. func (s *Stat) AcquireCount() int64 { return s.s.AcquireCount() } // AcquireDuration returns the total duration of all successful acquires from // the pool. 
func (s *Stat) AcquireDuration() time.Duration { return s.s.AcquireDuration() } // AcquiredConns returns the number of currently acquired connections in the pool. func (s *Stat) AcquiredConns() int32 { return s.s.AcquiredResources() } // CanceledAcquireCount returns the cumulative count of acquires from the pool // that were canceled by a context. func (s *Stat) CanceledAcquireCount() int64 { return s.s.CanceledAcquireCount() } // ConstructingConns returns the number of conns with construction in progress in // the pool. func (s *Stat) ConstructingConns() int32 { return s.s.ConstructingResources() } // EmptyAcquireCount returns the cumulative count of successful acquires from the pool // that waited for a resource to be released or constructed because the pool was // empty. func (s *Stat) EmptyAcquireCount() int64 { return s.s.EmptyAcquireCount() } // IdleConns returns the number of currently idle conns in the pool. func (s *Stat) IdleConns() int32 { return s.s.IdleResources() } // MaxConns returns the maximum size of the pool. func (s *Stat) MaxConns() int32 { return s.s.MaxResources() } // TotalConns returns the total number of resources currently in the pool. // The value is the sum of ConstructingConns, AcquiredConns, and // IdleConns. func (s *Stat) TotalConns() int32 { return s.s.TotalResources() } // NewConnsCount returns the cumulative count of new connections opened. func (s *Stat) NewConnsCount() int64 { return s.newConnsCount } // MaxLifetimeDestroyCount returns the cumulative count of connections destroyed // because they exceeded MaxConnLifetime. func (s *Stat) MaxLifetimeDestroyCount() int64 { return s.lifetimeDestroyCount } // MaxIdleDestroyCount returns the cumulative count of connections destroyed because // they exceeded MaxConnIdleTime. 
func (s *Stat) MaxIdleDestroyCount() int64 { return s.idleDestroyCount } pgx-4.18.1/pgxpool/tx.go000066400000000000000000000060521437725773200151010ustar00rootroot00000000000000package pgxpool import ( "context" "github.com/jackc/pgconn" "github.com/jackc/pgx/v4" ) // Tx represents a database transaction acquired from a Pool. type Tx struct { t pgx.Tx c *Conn } // Begin starts a pseudo nested transaction implemented with a savepoint. func (tx *Tx) Begin(ctx context.Context) (pgx.Tx, error) { return tx.t.Begin(ctx) } func (tx *Tx) BeginFunc(ctx context.Context, f func(pgx.Tx) error) error { return tx.t.BeginFunc(ctx, f) } // Commit commits the transaction and returns the associated connection back to the Pool. Commit will return ErrTxClosed // if the Tx is already closed, but is otherwise safe to call multiple times. If the commit fails with a rollback status // (e.g. the transaction was already in a broken state) then ErrTxCommitRollback will be returned. func (tx *Tx) Commit(ctx context.Context) error { err := tx.t.Commit(ctx) if tx.c != nil { tx.c.Release() tx.c = nil } return err } // Rollback rolls back the transaction and returns the associated connection back to the Pool. Rollback will return ErrTxClosed // if the Tx is already closed, but is otherwise safe to call multiple times. Hence, defer tx.Rollback() is safe even if // tx.Commit() will be called first in a non-error condition. func (tx *Tx) Rollback(ctx context.Context) error { err := tx.t.Rollback(ctx) if tx.c != nil { tx.c.Release() tx.c = nil } return err } func (tx *Tx) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) { return tx.t.CopyFrom(ctx, tableName, columnNames, rowSrc) } func (tx *Tx) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults { return tx.t.SendBatch(ctx, b) } func (tx *Tx) LargeObjects() pgx.LargeObjects { return tx.t.LargeObjects() } // Prepare creates a prepared statement with name and sql. 
If the name is empty, // an anonymous prepared statement will be used. sql can contain placeholders // for bound parameters. These placeholders are referenced positionally as $1, $2, etc. // // Prepare is idempotent; i.e. it is safe to call Prepare multiple times with the same // name and sql arguments. This allows a code path to Prepare and Query/Exec without // needing to first check whether the statement has already been prepared. func (tx *Tx) Prepare(ctx context.Context, name, sql string) (*pgconn.StatementDescription, error) { return tx.t.Prepare(ctx, name, sql) } func (tx *Tx) Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) { return tx.t.Exec(ctx, sql, arguments...) } func (tx *Tx) Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) { return tx.t.Query(ctx, sql, args...) } func (tx *Tx) QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row { return tx.t.QueryRow(ctx, sql, args...) } func (tx *Tx) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(pgx.QueryFuncRow) error) (pgconn.CommandTag, error) { return tx.t.QueryFunc(ctx, sql, args, scans, f) } func (tx *Tx) Conn() *pgx.Conn { return tx.t.Conn() } pgx-4.18.1/pgxpool/tx_test.go000066400000000000000000000032551437725773200161420ustar00rootroot00000000000000package pgxpool_test import ( "context" "os" "testing" "github.com/jackc/pgx/v4/pgxpool" "github.com/stretchr/testify/require" ) func TestTxExec(t *testing.T) { t.Parallel() pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() tx, err := pool.Begin(context.Background()) require.NoError(t, err) defer tx.Rollback(context.Background()) testExec(t, tx) } func TestTxQuery(t *testing.T) { t.Parallel() pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() tx, err := 
pool.Begin(context.Background()) require.NoError(t, err) defer tx.Rollback(context.Background()) testQuery(t, tx) } func TestTxQueryRow(t *testing.T) { t.Parallel() pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() tx, err := pool.Begin(context.Background()) require.NoError(t, err) defer tx.Rollback(context.Background()) testQueryRow(t, tx) } func TestTxSendBatch(t *testing.T) { t.Parallel() pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() tx, err := pool.Begin(context.Background()) require.NoError(t, err) defer tx.Rollback(context.Background()) testSendBatch(t, tx) } func TestTxCopyFrom(t *testing.T) { t.Parallel() pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) defer pool.Close() tx, err := pool.Begin(context.Background()) require.NoError(t, err) defer tx.Rollback(context.Background()) testCopyFrom(t, tx) } pgx-4.18.1/query_test.go000066400000000000000000001536001437725773200151640ustar00rootroot00000000000000package pgx_test import ( "bytes" "context" "database/sql" "errors" "fmt" "os" "reflect" "strconv" "strings" "testing" "time" "github.com/cockroachdb/apd" "github.com/gofrs/uuid" "github.com/jackc/pgconn" "github.com/jackc/pgconn/stmtcache" "github.com/jackc/pgtype" "github.com/jackc/pgx/v4" "github.com/shopspring/decimal" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestConnQueryScan(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) var sum, rowCount int32 rows, err := conn.Query(context.Background(), "select generate_series(1,$1)", 10) if err != nil { t.Fatalf("conn.Query failed: %v", err) } defer rows.Close() for rows.Next() { var n int32 rows.Scan(&n) sum += n rowCount++ } if rows.Err() != nil { t.Fatalf("conn.Query failed: %v", err) } 
assert.Equal(t, "SELECT 10", string(rows.CommandTag())) if rowCount != 10 { t.Error("Select called onDataRow wrong number of times") } if sum != 55 { t.Error("Wrong values returned") } } func TestConnQueryRowsFieldDescriptionsBeforeNext(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) rows, err := conn.Query(context.Background(), "select 'hello' as msg") require.NoError(t, err) defer rows.Close() require.Len(t, rows.FieldDescriptions(), 1) assert.Equal(t, []byte("msg"), rows.FieldDescriptions()[0].Name) } func TestConnQueryWithoutResultSetCommandTag(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) rows, err := conn.Query(context.Background(), "create temporary table t (id serial);") assert.NoError(t, err) rows.Close() assert.NoError(t, rows.Err()) assert.Equal(t, "CREATE TABLE", string(rows.CommandTag())) } func TestConnQueryScanWithManyColumns(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) columnCount := 1000 sql := "select " for i := 0; i < columnCount; i++ { if i > 0 { sql += "," } sql += fmt.Sprintf(" %d", i) } sql += " from generate_series(1,5)" dest := make([]int, columnCount) var rowCount int rows, err := conn.Query(context.Background(), sql) if err != nil { t.Fatalf("conn.Query failed: %v", err) } defer rows.Close() for rows.Next() { destPtrs := make([]interface{}, columnCount) for i := range destPtrs { destPtrs[i] = &dest[i] } if err := rows.Scan(destPtrs...); err != nil { t.Fatalf("rows.Scan failed: %v", err) } rowCount++ for i := range dest { if dest[i] != i { t.Errorf("dest[%d] => %d, want %d", i, dest[i], i) } } } if rows.Err() != nil { t.Fatalf("conn.Query failed: %v", err) } if rowCount != 5 { t.Errorf("rowCount => %d, want %d", rowCount, 5) } } func TestConnQueryValues(t *testing.T) { t.Parallel() conn := mustConnectString(t, 
os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) var rowCount int32 rows, err := conn.Query(context.Background(), "select 'foo'::text, 'bar'::varchar, n, null, n from generate_series(1,$1) n", 10) if err != nil { t.Fatalf("conn.Query failed: %v", err) } defer rows.Close() for rows.Next() { rowCount++ values, err := rows.Values() require.NoError(t, err) require.Len(t, values, 5) assert.Equal(t, "foo", values[0]) assert.Equal(t, "bar", values[1]) assert.EqualValues(t, rowCount, values[2]) assert.Nil(t, values[3]) assert.EqualValues(t, rowCount, values[4]) } if rows.Err() != nil { t.Fatalf("conn.Query failed: %v", err) } if rowCount != 10 { t.Error("Select called onDataRow wrong number of times") } } // https://github.com/jackc/pgx/issues/666 func TestConnQueryValuesWhenUnableToDecode(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) // Note that this relies on pgtype.Record not supporting the text protocol. This seems safe as it is impossible to // decode the text protocol because unlike the binary protocol there is no way to determine the OIDs of the elements. 
rows, err := conn.Query(context.Background(), "select (array[1::oid], null)", pgx.QueryResultFormats{pgx.TextFormatCode}) require.NoError(t, err) defer rows.Close() require.True(t, rows.Next()) values, err := rows.Values() require.NoError(t, err) require.Equal(t, "({1},)", values[0]) } func TestConnQueryValuesWithUnknownOID(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) tx, err := conn.Begin(ctx) require.NoError(t, err) defer tx.Rollback(ctx) _, err = tx.Exec(ctx, "create type fruit as enum('orange', 'apple', 'pear')") require.NoError(t, err) rows, err := conn.Query(context.Background(), "select 'orange'::fruit") require.NoError(t, err) defer rows.Close() require.True(t, rows.Next()) values, err := rows.Values() require.NoError(t, err) require.Equal(t, "orange", values[0]) } // https://github.com/jackc/pgx/issues/478 func TestConnQueryReadRowMultipleTimes(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) var rowCount int32 rows, err := conn.Query(context.Background(), "select 'foo'::text, 'bar'::varchar, n, null, n from generate_series(1,$1) n", 10) require.NoError(t, err) defer rows.Close() for rows.Next() { rowCount++ for i := 0; i < 2; i++ { values, err := rows.Values() require.NoError(t, err) require.Len(t, values, 5) require.Equal(t, "foo", values[0]) require.Equal(t, "bar", values[1]) require.EqualValues(t, rowCount, values[2]) require.Nil(t, values[3]) require.EqualValues(t, rowCount, values[4]) var a, b string var c int32 var d pgtype.Unknown var e int32 err = rows.Scan(&a, &b, &c, &d, &e) require.NoError(t, err) require.Equal(t, "foo", a) require.Equal(t, "bar", b) require.Equal(t, rowCount, c) require.Equal(t, pgtype.Null, d.Status) require.Equal(t, rowCount, e) } } require.NoError(t, rows.Err()) require.Equal(t, int32(10), rowCount) } // 
https://github.com/jackc/pgx/issues/386 func TestConnQueryValuesWithMultipleComplexColumnsOfSameType(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) expected0 := &pgtype.Int8Array{ Elements: []pgtype.Int8{ {Int: 1, Status: pgtype.Present}, {Int: 2, Status: pgtype.Present}, {Int: 3, Status: pgtype.Present}, }, Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}}, Status: pgtype.Present, } expected1 := &pgtype.Int8Array{ Elements: []pgtype.Int8{ {Int: 4, Status: pgtype.Present}, {Int: 5, Status: pgtype.Present}, {Int: 6, Status: pgtype.Present}, }, Dimensions: []pgtype.ArrayDimension{{Length: 3, LowerBound: 1}}, Status: pgtype.Present, } var rowCount int32 rows, err := conn.Query(context.Background(), "select '{1,2,3}'::bigint[], '{4,5,6}'::bigint[] from generate_series(1,$1) n", 10) if err != nil { t.Fatalf("conn.Query failed: %v", err) } defer rows.Close() for rows.Next() { rowCount++ values, err := rows.Values() if err != nil { t.Fatalf("rows.Values failed: %v", err) } if len(values) != 2 { t.Errorf("Expected rows.Values to return 2 values, but it returned %d", len(values)) } if !reflect.DeepEqual(values[0], *expected0) { t.Errorf(`Expected values[0] to be %v, but it was %v`, *expected0, values[0]) } if !reflect.DeepEqual(values[1], *expected1) { t.Errorf(`Expected values[1] to be %v, but it was %v`, *expected1, values[1]) } } if rows.Err() != nil { t.Fatalf("conn.Query failed: %v", err) } if rowCount != 10 { t.Error("Select called onDataRow wrong number of times") } } // https://github.com/jackc/pgx/issues/228 func TestRowsScanDoesNotAllowScanningBinaryFormatValuesIntoString(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) var s string err := conn.QueryRow(context.Background(), "select 1").Scan(&s) if err == nil || !(strings.Contains(err.Error(), "cannot decode binary value into string") || strings.Contains(err.Error(), 
"cannot assign")) { t.Fatalf("Expected Scan to fail to encode binary value into string but: %v", err) } ensureConnValid(t, conn) } func TestConnQueryRawValues(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) var rowCount int32 rows, err := conn.Query( context.Background(), "select 'foo'::text, 'bar'::varchar, n, null, n from generate_series(1,$1) n", pgx.QuerySimpleProtocol(true), 10, ) require.NoError(t, err) defer rows.Close() for rows.Next() { rowCount++ rawValues := rows.RawValues() assert.Len(t, rawValues, 5) assert.Equal(t, "foo", string(rawValues[0])) assert.Equal(t, "bar", string(rawValues[1])) assert.Equal(t, strconv.FormatInt(int64(rowCount), 10), string(rawValues[2])) assert.Nil(t, rawValues[3]) assert.Equal(t, strconv.FormatInt(int64(rowCount), 10), string(rawValues[4])) } require.NoError(t, rows.Err()) assert.EqualValues(t, 10, rowCount) } // Test that a connection stays valid when query results are closed early func TestConnQueryCloseEarly(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) // Immediately close query without reading any rows rows, err := conn.Query(context.Background(), "select generate_series(1,$1)", 10) if err != nil { t.Fatalf("conn.Query failed: %v", err) } rows.Close() ensureConnValid(t, conn) // Read partial response then close rows, err = conn.Query(context.Background(), "select generate_series(1,$1)", 10) if err != nil { t.Fatalf("conn.Query failed: %v", err) } ok := rows.Next() if !ok { t.Fatal("rows.Next terminated early") } var n int32 rows.Scan(&n) if n != 1 { t.Fatalf("Expected 1 from first row, but got %v", n) } rows.Close() ensureConnValid(t, conn) } func TestConnQueryCloseEarlyWithErrorOnWire(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) rows, err := conn.Query(context.Background(), "select 1/(10-n) from 
generate_series(1,10) n") if err != nil { t.Fatalf("conn.Query failed: %v", err) } assert.False(t, pgconn.SafeToRetry(err)) rows.Close() ensureConnValid(t, conn) } // Test that a connection stays valid when query results read incorrectly func TestConnQueryReadWrongTypeError(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) // Read a single value incorrectly rows, err := conn.Query(context.Background(), "select generate_series(1,$1)", 10) if err != nil { t.Fatalf("conn.Query failed: %v", err) } rowsRead := 0 for rows.Next() { var t time.Time rows.Scan(&t) rowsRead++ } if rowsRead != 1 { t.Fatalf("Expected error to cause only 1 row to be read, but %d were read", rowsRead) } if rows.Err() == nil { t.Fatal("Expected Rows to have an error after an improper read but it didn't") } if rows.Err().Error() != "can't scan into dest[0]: Can't convert OID 23 to time.Time" && !strings.Contains(rows.Err().Error(), "cannot assign") { t.Fatalf("Expected different Rows.Err(): %v", rows.Err()) } ensureConnValid(t, conn) } // Test that a connection stays valid when query results read incorrectly func TestConnQueryReadTooManyValues(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) // Read too many values rows, err := conn.Query(context.Background(), "select generate_series(1,$1)", 10) if err != nil { t.Fatalf("conn.Query failed: %v", err) } rowsRead := 0 for rows.Next() { var n, m int32 rows.Scan(&n, &m) rowsRead++ } if rowsRead != 1 { t.Fatalf("Expected error to cause only 1 row to be read, but %d were read", rowsRead) } if rows.Err() == nil { t.Fatal("Expected Rows to have an error after an improper read but it didn't") } ensureConnValid(t, conn) } func TestConnQueryScanIgnoreColumn(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) rows, err := conn.Query(context.Background(), "select 1::int8, 
2::int8, 3::int8") if err != nil { t.Fatalf("conn.Query failed: %v", err) } ok := rows.Next() if !ok { t.Fatal("rows.Next terminated early") } var n, m int64 err = rows.Scan(&n, nil, &m) if err != nil { t.Fatalf("rows.Scan failed: %v", err) } rows.Close() if n != 1 { t.Errorf("Expected n to equal 1, but it was %d", n) } if m != 3 { t.Errorf("Expected n to equal 3, but it was %d", m) } ensureConnValid(t, conn) } // https://github.com/jackc/pgx/issues/570 func TestConnQueryDeferredError(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) skipCockroachDB(t, conn, "Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)") mustExec(t, conn, `create temporary table t ( id text primary key, n int not null, unique (n) deferrable initially deferred ); insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`) rows, err := conn.Query(context.Background(), `update t set n=n+1 where id='b' returning *`) if err != nil { t.Fatal(err) } defer rows.Close() for rows.Next() { var id string var n int32 err = rows.Scan(&id, &n) if err != nil { t.Fatal(err) } } if rows.Err() == nil { t.Fatal("expected error 23505 but got none") } if err, ok := rows.Err().(*pgconn.PgError); !ok || err.Code != "23505" { t.Fatalf("expected error 23505, got %v", err) } ensureConnValid(t, conn) } func TestConnQueryErrorWhileReturningRows(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) skipCockroachDB(t, conn, "Server uses numeric instead of int") for i := 0; i < 100; i++ { func() { sql := `select 42 / (random() * 20)::integer from generate_series(1,100000)` rows, err := conn.Query(context.Background(), sql) if err != nil { t.Fatal(err) } defer rows.Close() for rows.Next() { var n int32 if err := rows.Scan(&n); err != nil { t.Fatalf("Row scan failed: %v", err) } } if _, ok := rows.Err().(*pgconn.PgError); !ok { t.Fatalf("Expected 
pgx.PgError, got %v", rows.Err()) } ensureConnValid(t, conn) }() } } func TestQueryEncodeError(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) rows, err := conn.Query(context.Background(), "select $1::integer", "wrong") if err != nil { t.Errorf("conn.Query failure: %v", err) } assert.False(t, pgconn.SafeToRetry(err)) defer rows.Close() rows.Next() if rows.Err() == nil { t.Error("Expected rows.Err() to return error, but it didn't") } if conn.PgConn().ParameterStatus("crdb_version") != "" { if !strings.Contains(rows.Err().Error(), "SQLSTATE 08P01") { // CockroachDB returns protocol_violation instead of invalid_text_representation t.Error("Expected rows.Err() to return different error:", rows.Err()) } } else { if !strings.Contains(rows.Err().Error(), "SQLSTATE 22P02") { t.Error("Expected rows.Err() to return different error:", rows.Err()) } } } func TestQueryRowCoreTypes(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) type allTypes struct { s string f32 float32 f64 float64 b bool t time.Time oid uint32 } var actual, zero allTypes tests := []struct { sql string queryArgs []interface{} scanArgs []interface{} expected allTypes }{ {"select $1::text", []interface{}{"Jack"}, []interface{}{&actual.s}, allTypes{s: "Jack"}}, {"select $1::float4", []interface{}{float32(1.23)}, []interface{}{&actual.f32}, allTypes{f32: 1.23}}, {"select $1::float8", []interface{}{float64(1.23)}, []interface{}{&actual.f64}, allTypes{f64: 1.23}}, {"select $1::bool", []interface{}{true}, []interface{}{&actual.b}, allTypes{b: true}}, {"select $1::timestamptz", []interface{}{time.Unix(123, 5000)}, []interface{}{&actual.t}, allTypes{t: time.Unix(123, 5000)}}, {"select $1::timestamp", []interface{}{time.Date(2010, 1, 2, 3, 4, 5, 0, time.UTC)}, []interface{}{&actual.t}, allTypes{t: time.Date(2010, 1, 2, 3, 4, 5, 0, time.UTC)}}, {"select $1::date", 
[]interface{}{time.Date(1987, 1, 2, 0, 0, 0, 0, time.UTC)}, []interface{}{&actual.t}, allTypes{t: time.Date(1987, 1, 2, 0, 0, 0, 0, time.UTC)}}, {"select $1::oid", []interface{}{uint32(42)}, []interface{}{&actual.oid}, allTypes{oid: 42}}, } for i, tt := range tests { actual = zero err := conn.QueryRow(context.Background(), tt.sql, tt.queryArgs...).Scan(tt.scanArgs...) if err != nil { t.Errorf("%d. Unexpected failure: %v (sql -> %v, queryArgs -> %v)", i, err, tt.sql, tt.queryArgs) } if actual.s != tt.expected.s || actual.f32 != tt.expected.f32 || actual.b != tt.expected.b || !actual.t.Equal(tt.expected.t) || actual.oid != tt.expected.oid { t.Errorf("%d. Expected %v, got %v (sql -> %v, queryArgs -> %v)", i, tt.expected, actual, tt.sql, tt.queryArgs) } ensureConnValid(t, conn) // Check that Scan errors when a core type is null err = conn.QueryRow(context.Background(), tt.sql, nil).Scan(tt.scanArgs...) if err == nil { t.Errorf("%d. Expected null to cause error, but it didn't (sql -> %v)", i, tt.sql) } ensureConnValid(t, conn) } } func TestQueryRowCoreIntegerEncoding(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) type allTypes struct { ui uint ui8 uint8 ui16 uint16 ui32 uint32 ui64 uint64 i int i8 int8 i16 int16 i32 int32 i64 int64 } var actual, zero allTypes successfulEncodeTests := []struct { sql string queryArg interface{} scanArg interface{} expected allTypes }{ // Check any integer type where value is within int2 range can be encoded {"select $1::int2", int(42), &actual.i16, allTypes{i16: 42}}, {"select $1::int2", int8(42), &actual.i16, allTypes{i16: 42}}, {"select $1::int2", int16(42), &actual.i16, allTypes{i16: 42}}, {"select $1::int2", int32(42), &actual.i16, allTypes{i16: 42}}, {"select $1::int2", int64(42), &actual.i16, allTypes{i16: 42}}, {"select $1::int2", uint(42), &actual.i16, allTypes{i16: 42}}, {"select $1::int2", uint8(42), &actual.i16, allTypes{i16: 42}}, {"select $1::int2", 
uint16(42), &actual.i16, allTypes{i16: 42}}, {"select $1::int2", uint32(42), &actual.i16, allTypes{i16: 42}}, {"select $1::int2", uint64(42), &actual.i16, allTypes{i16: 42}}, // Check any integer type where value is within int4 range can be encoded {"select $1::int4", int(42), &actual.i32, allTypes{i32: 42}}, {"select $1::int4", int8(42), &actual.i32, allTypes{i32: 42}}, {"select $1::int4", int16(42), &actual.i32, allTypes{i32: 42}}, {"select $1::int4", int32(42), &actual.i32, allTypes{i32: 42}}, {"select $1::int4", int64(42), &actual.i32, allTypes{i32: 42}}, {"select $1::int4", uint(42), &actual.i32, allTypes{i32: 42}}, {"select $1::int4", uint8(42), &actual.i32, allTypes{i32: 42}}, {"select $1::int4", uint16(42), &actual.i32, allTypes{i32: 42}}, {"select $1::int4", uint32(42), &actual.i32, allTypes{i32: 42}}, {"select $1::int4", uint64(42), &actual.i32, allTypes{i32: 42}}, // Check any integer type where value is within int8 range can be encoded {"select $1::int8", int(42), &actual.i64, allTypes{i64: 42}}, {"select $1::int8", int8(42), &actual.i64, allTypes{i64: 42}}, {"select $1::int8", int16(42), &actual.i64, allTypes{i64: 42}}, {"select $1::int8", int32(42), &actual.i64, allTypes{i64: 42}}, {"select $1::int8", int64(42), &actual.i64, allTypes{i64: 42}}, {"select $1::int8", uint(42), &actual.i64, allTypes{i64: 42}}, {"select $1::int8", uint8(42), &actual.i64, allTypes{i64: 42}}, {"select $1::int8", uint16(42), &actual.i64, allTypes{i64: 42}}, {"select $1::int8", uint32(42), &actual.i64, allTypes{i64: 42}}, {"select $1::int8", uint64(42), &actual.i64, allTypes{i64: 42}}, } for i, tt := range successfulEncodeTests { actual = zero err := conn.QueryRow(context.Background(), tt.sql, tt.queryArg).Scan(tt.scanArg) if err != nil { t.Errorf("%d. Unexpected failure: %v (sql -> %v, queryArg -> %v)", i, err, tt.sql, tt.queryArg) continue } if actual != tt.expected { t.Errorf("%d. 
Expected %v, got %v (sql -> %v, queryArg -> %v)", i, tt.expected, actual, tt.sql, tt.queryArg) } ensureConnValid(t, conn) } failedEncodeTests := []struct { sql string queryArg interface{} }{ // Check any integer type where value is outside pg:int2 range cannot be encoded {"select $1::int2", int(32769)}, {"select $1::int2", int32(32769)}, {"select $1::int2", int32(32769)}, {"select $1::int2", int64(32769)}, {"select $1::int2", uint(32769)}, {"select $1::int2", uint16(32769)}, {"select $1::int2", uint32(32769)}, {"select $1::int2", uint64(32769)}, // Check any integer type where value is outside pg:int4 range cannot be encoded {"select $1::int4", int64(2147483649)}, {"select $1::int4", uint32(2147483649)}, {"select $1::int4", uint64(2147483649)}, // Check any integer type where value is outside pg:int8 range cannot be encoded {"select $1::int8", uint64(9223372036854775809)}, } for i, tt := range failedEncodeTests { err := conn.QueryRow(context.Background(), tt.sql, tt.queryArg).Scan(nil) if err == nil { t.Errorf("%d. Expected failure to encode, but unexpectedly succeeded: %v (sql -> %v, queryArg -> %v)", i, err, tt.sql, tt.queryArg) } else if !strings.Contains(err.Error(), "is greater than") { t.Errorf("%d. 
Expected failure to encode, but got: %v (sql -> %v, queryArg -> %v)", i, err, tt.sql, tt.queryArg) } ensureConnValid(t, conn) } } func TestQueryRowCoreIntegerDecoding(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) type allTypes struct { ui uint ui8 uint8 ui16 uint16 ui32 uint32 ui64 uint64 i int i8 int8 i16 int16 i32 int32 i64 int64 } var actual, zero allTypes successfulDecodeTests := []struct { sql string scanArg interface{} expected allTypes }{ // Check any integer type where value is within Go:int range can be decoded {"select 42::int2", &actual.i, allTypes{i: 42}}, {"select 42::int4", &actual.i, allTypes{i: 42}}, {"select 42::int8", &actual.i, allTypes{i: 42}}, {"select -42::int2", &actual.i, allTypes{i: -42}}, {"select -42::int4", &actual.i, allTypes{i: -42}}, {"select -42::int8", &actual.i, allTypes{i: -42}}, // Check any integer type where value is within Go:int8 range can be decoded {"select 42::int2", &actual.i8, allTypes{i8: 42}}, {"select 42::int4", &actual.i8, allTypes{i8: 42}}, {"select 42::int8", &actual.i8, allTypes{i8: 42}}, {"select -42::int2", &actual.i8, allTypes{i8: -42}}, {"select -42::int4", &actual.i8, allTypes{i8: -42}}, {"select -42::int8", &actual.i8, allTypes{i8: -42}}, // Check any integer type where value is within Go:int16 range can be decoded {"select 42::int2", &actual.i16, allTypes{i16: 42}}, {"select 42::int4", &actual.i16, allTypes{i16: 42}}, {"select 42::int8", &actual.i16, allTypes{i16: 42}}, {"select -42::int2", &actual.i16, allTypes{i16: -42}}, {"select -42::int4", &actual.i16, allTypes{i16: -42}}, {"select -42::int8", &actual.i16, allTypes{i16: -42}}, // Check any integer type where value is within Go:int32 range can be decoded {"select 42::int2", &actual.i32, allTypes{i32: 42}}, {"select 42::int4", &actual.i32, allTypes{i32: 42}}, {"select 42::int8", &actual.i32, allTypes{i32: 42}}, {"select -42::int2", &actual.i32, allTypes{i32: -42}}, {"select -42::int4", 
&actual.i32, allTypes{i32: -42}}, {"select -42::int8", &actual.i32, allTypes{i32: -42}}, // Check any integer type where value is within Go:int64 range can be decoded {"select 42::int2", &actual.i64, allTypes{i64: 42}}, {"select 42::int4", &actual.i64, allTypes{i64: 42}}, {"select 42::int8", &actual.i64, allTypes{i64: 42}}, {"select -42::int2", &actual.i64, allTypes{i64: -42}}, {"select -42::int4", &actual.i64, allTypes{i64: -42}}, {"select -42::int8", &actual.i64, allTypes{i64: -42}}, // Check any integer type where value is within Go:uint range can be decoded {"select 128::int2", &actual.ui, allTypes{ui: 128}}, {"select 128::int4", &actual.ui, allTypes{ui: 128}}, {"select 128::int8", &actual.ui, allTypes{ui: 128}}, // Check any integer type where value is within Go:uint8 range can be decoded {"select 128::int2", &actual.ui8, allTypes{ui8: 128}}, {"select 128::int4", &actual.ui8, allTypes{ui8: 128}}, {"select 128::int8", &actual.ui8, allTypes{ui8: 128}}, // Check any integer type where value is within Go:uint16 range can be decoded {"select 42::int2", &actual.ui16, allTypes{ui16: 42}}, {"select 32768::int4", &actual.ui16, allTypes{ui16: 32768}}, {"select 32768::int8", &actual.ui16, allTypes{ui16: 32768}}, // Check any integer type where value is within Go:uint32 range can be decoded {"select 42::int2", &actual.ui32, allTypes{ui32: 42}}, {"select 42::int4", &actual.ui32, allTypes{ui32: 42}}, {"select 2147483648::int8", &actual.ui32, allTypes{ui32: 2147483648}}, // Check any integer type where value is within Go:uint64 range can be decoded {"select 42::int2", &actual.ui64, allTypes{ui64: 42}}, {"select 42::int4", &actual.ui64, allTypes{ui64: 42}}, {"select 42::int8", &actual.ui64, allTypes{ui64: 42}}, } for i, tt := range successfulDecodeTests { actual = zero err := conn.QueryRow(context.Background(), tt.sql).Scan(tt.scanArg) if err != nil { t.Errorf("%d. Unexpected failure: %v (sql -> %v)", i, err, tt.sql) continue } if actual != tt.expected { t.Errorf("%d. 
Expected %v, got %v (sql -> %v)", i, tt.expected, actual, tt.sql) } ensureConnValid(t, conn) } failedDecodeTests := []struct { sql string scanArg interface{} expectedErr string }{ // Check any integer type where value is outside Go:int8 range cannot be decoded {"select 128::int2", &actual.i8, "is greater than"}, {"select 128::int4", &actual.i8, "is greater than"}, {"select 128::int8", &actual.i8, "is greater than"}, {"select -129::int2", &actual.i8, "is less than"}, {"select -129::int4", &actual.i8, "is less than"}, {"select -129::int8", &actual.i8, "is less than"}, // Check any integer type where value is outside Go:int16 range cannot be decoded {"select 32768::int4", &actual.i16, "is greater than"}, {"select 32768::int8", &actual.i16, "is greater than"}, {"select -32769::int4", &actual.i16, "is less than"}, {"select -32769::int8", &actual.i16, "is less than"}, // Check any integer type where value is outside Go:int32 range cannot be decoded {"select 2147483648::int8", &actual.i32, "is greater than"}, {"select -2147483649::int8", &actual.i32, "is less than"}, // Check any integer type where value is outside Go:uint range cannot be decoded {"select -1::int2", &actual.ui, "is less than"}, {"select -1::int4", &actual.ui, "is less than"}, {"select -1::int8", &actual.ui, "is less than"}, // Check any integer type where value is outside Go:uint8 range cannot be decoded {"select 256::int2", &actual.ui8, "is greater than"}, {"select 256::int4", &actual.ui8, "is greater than"}, {"select 256::int8", &actual.ui8, "is greater than"}, {"select -1::int2", &actual.ui8, "is less than"}, {"select -1::int4", &actual.ui8, "is less than"}, {"select -1::int8", &actual.ui8, "is less than"}, // Check any integer type where value is outside Go:uint16 cannot be decoded {"select 65536::int4", &actual.ui16, "is greater than"}, {"select 65536::int8", &actual.ui16, "is greater than"}, {"select -1::int2", &actual.ui16, "is less than"}, {"select -1::int4", &actual.ui16, "is less than"}, 
{"select -1::int8", &actual.ui16, "is less than"}, // Check any integer type where value is outside Go:uint32 range cannot be decoded {"select 4294967296::int8", &actual.ui32, "is greater than"}, {"select -1::int2", &actual.ui32, "is less than"}, {"select -1::int4", &actual.ui32, "is less than"}, {"select -1::int8", &actual.ui32, "is less than"}, // Check any integer type where value is outside Go:uint64 range cannot be decoded {"select -1::int2", &actual.ui64, "is less than"}, {"select -1::int4", &actual.ui64, "is less than"}, {"select -1::int8", &actual.ui64, "is less than"}, } for i, tt := range failedDecodeTests { err := conn.QueryRow(context.Background(), tt.sql).Scan(tt.scanArg) if err == nil { t.Errorf("%d. Expected failure to decode, but unexpectedly succeeded: %v (sql -> %v)", i, err, tt.sql) } else if !strings.Contains(err.Error(), tt.expectedErr) { t.Errorf("%d. Expected failure to decode, but got: %v (sql -> %v)", i, err, tt.sql) } ensureConnValid(t, conn) } } func TestQueryRowCoreByteSlice(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) tests := []struct { sql string queryArg interface{} expected []byte }{ {"select $1::text", "Jack", []byte("Jack")}, {"select $1::text", []byte("Jack"), []byte("Jack")}, {"select $1::varchar", []byte("Jack"), []byte("Jack")}, {"select $1::bytea", []byte{0, 15, 255, 17}, []byte{0, 15, 255, 17}}, } for i, tt := range tests { var actual []byte err := conn.QueryRow(context.Background(), tt.sql, tt.queryArg).Scan(&actual) if err != nil { t.Errorf("%d. Unexpected failure: %v (sql -> %v)", i, err, tt.sql) } if !bytes.Equal(actual, tt.expected) { t.Errorf("%d. 
Expected %v, got %v (sql -> %v)", i, tt.expected, actual, tt.sql) } ensureConnValid(t, conn) } } func TestQueryRowErrors(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) type allTypes struct { i16 int16 i int s string } var actual, zero allTypes tests := []struct { sql string queryArgs []interface{} scanArgs []interface{} err string }{ // {"select $1::badtype", []interface{}{"Jack"}, []interface{}{&actual.i16}, `type "badtype" does not exist`}, // {"SYNTAX ERROR", []interface{}{}, []interface{}{&actual.i16}, "SQLSTATE 42601"}, {"select $1::text", []interface{}{"Jack"}, []interface{}{&actual.i16}, "unable to assign"}, // {"select $1::point", []interface{}{int(705)}, []interface{}{&actual.s}, "cannot convert 705 to Point"}, } for i, tt := range tests { actual = zero err := conn.QueryRow(context.Background(), tt.sql, tt.queryArgs...).Scan(tt.scanArgs...) if err == nil { t.Errorf("%d. Unexpected success (sql -> %v, queryArgs -> %v)", i, tt.sql, tt.queryArgs) } if err != nil && !strings.Contains(err.Error(), tt.err) { t.Errorf("%d. 
Expected error to contain %s, but got %v (sql -> %v, queryArgs -> %v)", i, tt.err, err, tt.sql, tt.queryArgs) } ensureConnValid(t, conn) } } func TestQueryRowNoResults(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) var n int32 err := conn.QueryRow(context.Background(), "select 1 where 1=0").Scan(&n) if err != pgx.ErrNoRows { t.Errorf("Expected pgx.ErrNoRows, got %v", err) } ensureConnValid(t, conn) } func TestQueryRowEmptyQuery(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() var n int32 err := conn.QueryRow(ctx, "").Scan(&n) require.Error(t, err) require.False(t, pgconn.Timeout(err)) ensureConnValid(t, conn) } func TestReadingValueAfterEmptyArray(t *testing.T) { conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) var a []string var b int32 err := conn.QueryRow(context.Background(), "select '{}'::text[], 42::integer").Scan(&a, &b) if err != nil { t.Fatalf("conn.QueryRow failed: %v", err) } if len(a) != 0 { t.Errorf("Expected 'a' to have length 0, but it was: %d", len(a)) } if b != 42 { t.Errorf("Expected 'b' to 42, but it was: %d", b) } } func TestReadingNullByteArray(t *testing.T) { conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) var a []byte err := conn.QueryRow(context.Background(), "select null::text").Scan(&a) if err != nil { t.Fatalf("conn.QueryRow failed: %v", err) } if a != nil { t.Errorf("Expected 'a' to be nil, but it was: %v", a) } } func TestReadingNullByteArrays(t *testing.T) { conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) rows, err := conn.Query(context.Background(), "select null::text union all select null::text") if err != nil { t.Fatalf("conn.Query failed: %v", err) } count := 0 for rows.Next() { count++ var a []byte if 
err := rows.Scan(&a); err != nil { t.Fatalf("failed to scan row: %v", err) } if a != nil { t.Errorf("Expected 'a' to be nil, but it was: %v", a) } } if count != 2 { t.Errorf("Expected to read 2 rows, read: %d", count) } } // Use github.com/shopspring/decimal as real-world database/sql custom type // to test against. func TestConnQueryDatabaseSQLScanner(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) var num decimal.Decimal err := conn.QueryRow(context.Background(), "select '1234.567'::decimal").Scan(&num) if err != nil { t.Fatalf("Scan failed: %v", err) } expected, err := decimal.NewFromString("1234.567") if err != nil { t.Fatal(err) } if !num.Equals(expected) { t.Errorf("Expected num to be %v, but it was %v", expected, num) } ensureConnValid(t, conn) } // Use github.com/shopspring/decimal as real-world database/sql custom type // to test against. func TestConnQueryDatabaseSQLDriverValuer(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) expected, err := decimal.NewFromString("1234.567") if err != nil { t.Fatal(err) } var num decimal.Decimal err = conn.QueryRow(context.Background(), "select $1::decimal", &expected).Scan(&num) if err != nil { t.Fatalf("Scan failed: %v", err) } if !num.Equals(expected) { t.Errorf("Expected num to be %v, but it was %v", expected, num) } ensureConnValid(t, conn) } // https://github.com/jackc/pgx/issues/339 func TestConnQueryDatabaseSQLDriverValuerWithAutoGeneratedPointerReceiver(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) mustExec(t, conn, "create temporary table t(n numeric)") var d *apd.Decimal commandTag, err := conn.Exec(context.Background(), `insert into t(n) values($1)`, d) if err != nil { t.Fatal(err) } if string(commandTag) != "INSERT 0 1" { t.Fatalf("want %s, got %s", "INSERT 0 1", commandTag) } ensureConnValid(t, conn) } func 
TestConnQueryDatabaseSQLDriverValuerWithBinaryPgTypeThatAcceptsSameType(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) expected, err := uuid.FromString("6ba7b810-9dad-11d1-80b4-00c04fd430c8") if err != nil { t.Fatal(err) } var u2 uuid.UUID err = conn.QueryRow(context.Background(), "select $1::uuid", expected).Scan(&u2) if err != nil { t.Fatalf("Scan failed: %v", err) } if expected != u2 { t.Errorf("Expected u2 to be %v, but it was %v", expected, u2) } ensureConnValid(t, conn) } func TestConnQueryDatabaseSQLNullX(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) type row struct { boolValid sql.NullBool boolNull sql.NullBool int64Valid sql.NullInt64 int64Null sql.NullInt64 float64Valid sql.NullFloat64 float64Null sql.NullFloat64 stringValid sql.NullString stringNull sql.NullString } expected := row{ boolValid: sql.NullBool{Bool: true, Valid: true}, int64Valid: sql.NullInt64{Int64: 123, Valid: true}, float64Valid: sql.NullFloat64{Float64: 3.14, Valid: true}, stringValid: sql.NullString{String: "pgx", Valid: true}, } var actual row err := conn.QueryRow( context.Background(), "select $1::bool, $2::bool, $3::int8, $4::int8, $5::float8, $6::float8, $7::text, $8::text", expected.boolValid, expected.boolNull, expected.int64Valid, expected.int64Null, expected.float64Valid, expected.float64Null, expected.stringValid, expected.stringNull, ).Scan( &actual.boolValid, &actual.boolNull, &actual.int64Valid, &actual.int64Null, &actual.float64Valid, &actual.float64Null, &actual.stringValid, &actual.stringNull, ) if err != nil { t.Fatalf("Scan failed: %v", err) } if expected != actual { t.Errorf("Expected %v, but got %v", expected, actual) } ensureConnValid(t, conn) } func TestQueryContextSuccess(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) ctx, cancelFunc := 
context.WithCancel(context.Background()) defer cancelFunc() rows, err := conn.Query(ctx, "select 42::integer") if err != nil { t.Fatal(err) } var result, rowCount int for rows.Next() { err = rows.Scan(&result) if err != nil { t.Fatal(err) } rowCount++ } if rows.Err() != nil { t.Fatal(rows.Err()) } if rowCount != 1 { t.Fatalf("Expected 1 row, got %d", rowCount) } if result != 42 { t.Fatalf("Expected result 42, got %d", result) } ensureConnValid(t, conn) } func TestQueryContextErrorWhileReceivingRows(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) skipCockroachDB(t, conn, "Server uses numeric instead of int") ctx, cancelFunc := context.WithCancel(context.Background()) defer cancelFunc() rows, err := conn.Query(ctx, "select 10/(10-n) from generate_series(1, 100) n") if err != nil { t.Fatal(err) } var result, rowCount int for rows.Next() { err = rows.Scan(&result) if err != nil { t.Fatal(err) } rowCount++ } if rows.Err() == nil || rows.Err().Error() != "ERROR: division by zero (SQLSTATE 22012)" { t.Fatalf("Expected division by zero error, but got %v", rows.Err()) } if rowCount != 9 { t.Fatalf("Expected 9 rows, got %d", rowCount) } if result != 10 { t.Fatalf("Expected result 10, got %d", result) } ensureConnValid(t, conn) } func TestQueryRowContextSuccess(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) ctx, cancelFunc := context.WithCancel(context.Background()) defer cancelFunc() var result int err := conn.QueryRow(ctx, "select 42::integer").Scan(&result) if err != nil { t.Fatal(err) } if result != 42 { t.Fatalf("Expected result 42, got %d", result) } ensureConnValid(t, conn) } func TestQueryRowContextErrorWhileReceivingRow(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) ctx, cancelFunc := context.WithCancel(context.Background()) defer cancelFunc() var result int err := 
conn.QueryRow(ctx, "select 10/0").Scan(&result) if err == nil || err.Error() != "ERROR: division by zero (SQLSTATE 22012)" { t.Fatalf("Expected division by zero error, but got %v", err) } ensureConnValid(t, conn) } func TestQueryCloseBefore(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) closeConn(t, conn) _, err := conn.Query(context.Background(), "select 1") require.Error(t, err) assert.True(t, pgconn.SafeToRetry(err)) } func TestScanRow(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) resultReader := conn.PgConn().ExecParams(context.Background(), "select generate_series(1,$1)", [][]byte{[]byte("10")}, nil, nil, nil) var sum, rowCount int32 for resultReader.NextRow() { var n int32 err := pgx.ScanRow(conn.ConnInfo(), resultReader.FieldDescriptions(), resultReader.Values(), &n) assert.NoError(t, err) sum += n rowCount++ } _, err := resultReader.Close() require.NoError(t, err) assert.EqualValues(t, 10, rowCount) assert.EqualValues(t, 55, sum) } func TestConnSimpleProtocol(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) // Test all supported low-level types { expected := int64(42) var actual int64 err := conn.QueryRow( context.Background(), "select $1::int8", pgx.QuerySimpleProtocol(true), expected, ).Scan(&actual) if err != nil { t.Error(err) } if expected != actual { t.Errorf("expected %v got %v", expected, actual) } } { expected := float64(1.23) var actual float64 err := conn.QueryRow( context.Background(), "select $1::float8", pgx.QuerySimpleProtocol(true), expected, ).Scan(&actual) if err != nil { t.Error(err) } if expected != actual { t.Errorf("expected %v got %v", expected, actual) } } { expected := true var actual bool err := conn.QueryRow( context.Background(), "select $1", pgx.QuerySimpleProtocol(true), expected, ).Scan(&actual) if err != nil { t.Error(err) } if expected != actual { 
t.Errorf("expected %v got %v", expected, actual) } } { expected := []byte{0, 1, 20, 35, 64, 80, 120, 3, 255, 240, 128, 95} var actual []byte err := conn.QueryRow( context.Background(), "select $1::bytea", pgx.QuerySimpleProtocol(true), expected, ).Scan(&actual) if err != nil { t.Error(err) } if bytes.Compare(actual, expected) != 0 { t.Errorf("expected %v got %v", expected, actual) } } { expected := "test" var actual string err := conn.QueryRow( context.Background(), "select $1::text", pgx.QuerySimpleProtocol(true), expected, ).Scan(&actual) if err != nil { t.Error(err) } if expected != actual { t.Errorf("expected %v got %v", expected, actual) } } { tests := []struct { expected []string }{ {[]string(nil)}, {[]string{}}, {[]string{"test", "foo", "bar"}}, {[]string{`foo'bar"\baz;quz`, `foo'bar"\baz;quz`}}, } for i, tt := range tests { var actual []string err := conn.QueryRow( context.Background(), "select $1::text[]", pgx.QuerySimpleProtocol(true), tt.expected, ).Scan(&actual) assert.NoErrorf(t, err, "%d", i) assert.Equalf(t, tt.expected, actual, "%d", i) } } { tests := []struct { expected []int16 }{ {[]int16(nil)}, {[]int16{}}, {[]int16{1, 2, 3}}, } for i, tt := range tests { var actual []int16 err := conn.QueryRow( context.Background(), "select $1::smallint[]", pgx.QuerySimpleProtocol(true), tt.expected, ).Scan(&actual) assert.NoErrorf(t, err, "%d", i) assert.Equalf(t, tt.expected, actual, "%d", i) } } { tests := []struct { expected []int32 }{ {[]int32(nil)}, {[]int32{}}, {[]int32{1, 2, 3}}, } for i, tt := range tests { var actual []int32 err := conn.QueryRow( context.Background(), "select $1::int[]", pgx.QuerySimpleProtocol(true), tt.expected, ).Scan(&actual) assert.NoErrorf(t, err, "%d", i) assert.Equalf(t, tt.expected, actual, "%d", i) } } { tests := []struct { expected []int64 }{ {[]int64(nil)}, {[]int64{}}, {[]int64{1, 2, 3}}, } for i, tt := range tests { var actual []int64 err := conn.QueryRow( context.Background(), "select $1::bigint[]", 
pgx.QuerySimpleProtocol(true), tt.expected, ).Scan(&actual) assert.NoErrorf(t, err, "%d", i) assert.Equalf(t, tt.expected, actual, "%d", i) } } { tests := []struct { expected []int }{ {[]int(nil)}, {[]int{}}, {[]int{1, 2, 3}}, } for i, tt := range tests { var actual []int err := conn.QueryRow( context.Background(), "select $1::bigint[]", pgx.QuerySimpleProtocol(true), tt.expected, ).Scan(&actual) assert.NoErrorf(t, err, "%d", i) assert.Equalf(t, tt.expected, actual, "%d", i) } } { tests := []struct { expected []uint16 }{ {[]uint16(nil)}, {[]uint16{}}, {[]uint16{1, 2, 3}}, } for i, tt := range tests { var actual []uint16 err := conn.QueryRow( context.Background(), "select $1::smallint[]", pgx.QuerySimpleProtocol(true), tt.expected, ).Scan(&actual) assert.NoErrorf(t, err, "%d", i) assert.Equalf(t, tt.expected, actual, "%d", i) } } { tests := []struct { expected []uint32 }{ {[]uint32(nil)}, {[]uint32{}}, {[]uint32{1, 2, 3}}, } for i, tt := range tests { var actual []uint32 err := conn.QueryRow( context.Background(), "select $1::bigint[]", pgx.QuerySimpleProtocol(true), tt.expected, ).Scan(&actual) assert.NoErrorf(t, err, "%d", i) assert.Equalf(t, tt.expected, actual, "%d", i) } } { tests := []struct { expected []uint64 }{ {[]uint64(nil)}, {[]uint64{}}, {[]uint64{1, 2, 3}}, } for i, tt := range tests { var actual []uint64 err := conn.QueryRow( context.Background(), "select $1::bigint[]", pgx.QuerySimpleProtocol(true), tt.expected, ).Scan(&actual) assert.NoErrorf(t, err, "%d", i) assert.Equalf(t, tt.expected, actual, "%d", i) } } { tests := []struct { expected []uint }{ {[]uint(nil)}, {[]uint{}}, {[]uint{1, 2, 3}}, } for i, tt := range tests { var actual []uint err := conn.QueryRow( context.Background(), "select $1::bigint[]", pgx.QuerySimpleProtocol(true), tt.expected, ).Scan(&actual) assert.NoErrorf(t, err, "%d", i) assert.Equalf(t, tt.expected, actual, "%d", i) } } { tests := []struct { expected []float32 }{ {[]float32(nil)}, {[]float32{}}, {[]float32{1, 2, 3}}, } 
for i, tt := range tests { var actual []float32 err := conn.QueryRow( context.Background(), "select $1::float4[]", pgx.QuerySimpleProtocol(true), tt.expected, ).Scan(&actual) assert.NoErrorf(t, err, "%d", i) assert.Equalf(t, tt.expected, actual, "%d", i) } } { tests := []struct { expected []float64 }{ {[]float64(nil)}, {[]float64{}}, {[]float64{1, 2, 3}}, } for i, tt := range tests { var actual []float64 err := conn.QueryRow( context.Background(), "select $1::float8[]", pgx.QuerySimpleProtocol(true), tt.expected, ).Scan(&actual) assert.NoErrorf(t, err, "%d", i) assert.Equalf(t, tt.expected, actual, "%d", i) } } // Test high-level type { if conn.PgConn().ParameterStatus("crdb_version") == "" { // CockroachDB doesn't support circle type. expected := pgtype.Circle{P: pgtype.Vec2{1, 2}, R: 1.5, Status: pgtype.Present} actual := expected err := conn.QueryRow( context.Background(), "select $1::circle", pgx.QuerySimpleProtocol(true), &expected, ).Scan(&actual) if err != nil { t.Error(err) } if expected != actual { t.Errorf("expected %v got %v", expected, actual) } } } // Test multiple args in single query { expectedInt64 := int64(234423) expectedFloat64 := float64(-0.2312) expectedBool := true expectedBytes := []byte{255, 0, 23, 16, 87, 45, 9, 23, 45, 223} expectedString := "test" var actualInt64 int64 var actualFloat64 float64 var actualBool bool var actualBytes []byte var actualString string err := conn.QueryRow( context.Background(), "select $1::int8, $2::float8, $3, $4::bytea, $5::text", pgx.QuerySimpleProtocol(true), expectedInt64, expectedFloat64, expectedBool, expectedBytes, expectedString, ).Scan(&actualInt64, &actualFloat64, &actualBool, &actualBytes, &actualString) if err != nil { t.Error(err) } if expectedInt64 != actualInt64 { t.Errorf("expected %v got %v", expectedInt64, actualInt64) } if expectedFloat64 != actualFloat64 { t.Errorf("expected %v got %v", expectedFloat64, actualFloat64) } if expectedBool != actualBool { t.Errorf("expected %v got %v", 
expectedBool, actualBool) } if bytes.Compare(expectedBytes, actualBytes) != 0 { t.Errorf("expected %v got %v", expectedBytes, actualBytes) } if expectedString != actualString { t.Errorf("expected %v got %v", expectedString, actualString) } } // Test dangerous cases { expected := "foo';drop table users;" var actual string err := conn.QueryRow( context.Background(), "select $1", pgx.QuerySimpleProtocol(true), expected, ).Scan(&actual) if err != nil { t.Error(err) } if expected != actual { t.Errorf("expected %v got %v", expected, actual) } } ensureConnValid(t, conn) } func TestConnSimpleProtocolRefusesNonUTF8ClientEncoding(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) skipCockroachDB(t, conn, "Server does not support changing client_encoding (https://www.cockroachlabs.com/docs/stable/set-vars.html)") mustExec(t, conn, "set client_encoding to 'SQL_ASCII'") var expected string err := conn.QueryRow( context.Background(), "select $1", pgx.QuerySimpleProtocol(true), "test", ).Scan(&expected) if err == nil { t.Error("expected error when client_encoding not UTF8, but no error occurred") } ensureConnValid(t, conn) } func TestConnSimpleProtocolRefusesNonStandardConformingStrings(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) skipCockroachDB(t, conn, "Server does not support standard_conforming_strings = off (https://github.com/cockroachdb/cockroach/issues/36215)") mustExec(t, conn, "set standard_conforming_strings to off") var expected string err := conn.QueryRow( context.Background(), "select $1", pgx.QuerySimpleProtocol(true), `\'; drop table users; --`, ).Scan(&expected) if err == nil { t.Error("expected error when standard_conforming_strings is off, but no error occurred") } ensureConnValid(t, conn) } func TestQueryStatementCacheModes(t *testing.T) { t.Parallel() config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) tests := 
[]struct { name string buildStatementCache pgx.BuildStatementCacheFunc }{ { name: "disabled", buildStatementCache: nil, }, { name: "prepare", buildStatementCache: func(conn *pgconn.PgConn) stmtcache.Cache { return stmtcache.New(conn, stmtcache.ModePrepare, 32) }, }, { name: "describe", buildStatementCache: func(conn *pgconn.PgConn) stmtcache.Cache { return stmtcache.New(conn, stmtcache.ModeDescribe, 32) }, }, } for _, tt := range tests { func() { config.BuildStatementCache = tt.buildStatementCache conn := mustConnect(t, config) defer closeConn(t, conn) var n int err := conn.QueryRow(context.Background(), "select 1").Scan(&n) assert.NoError(t, err, tt.name) assert.Equal(t, 1, n, tt.name) err = conn.QueryRow(context.Background(), "select 2").Scan(&n) assert.NoError(t, err, tt.name) assert.Equal(t, 2, n, tt.name) err = conn.QueryRow(context.Background(), "select 1").Scan(&n) assert.NoError(t, err, tt.name) assert.Equal(t, 1, n, tt.name) ensureConnValid(t, conn) }() } } // https://github.com/jackc/pgx/issues/895 func TestQueryErrorWithNilStatementCacheMode(t *testing.T) { t.Parallel() config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) config.BuildStatementCache = nil conn := mustConnect(t, config) defer closeConn(t, conn) _, err := conn.Exec(context.Background(), "create temporary table t_unq(id text primary key);") require.NoError(t, err) _, err = conn.Exec(context.Background(), "insert into t_unq (id) values ($1)", "abc") require.NoError(t, err) rows, err := conn.Query(context.Background(), "insert into t_unq (id) values ($1)", "abc") require.NoError(t, err) rows.Close() err = rows.Err() require.Error(t, err) var pgErr *pgconn.PgError if errors.As(err, &pgErr) { assert.Equal(t, "23505", pgErr.Code) } else { t.Errorf("err is not a *pgconn.PgError: %T", err) } ensureConnValid(t, conn) } func TestConnQueryFunc(t *testing.T) { t.Parallel() testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { var actualResults []interface{} var a, b int 
ct, err := conn.QueryFunc( context.Background(), "select n, n * 2 from generate_series(1, $1) n", []interface{}{3}, []interface{}{&a, &b}, func(pgx.QueryFuncRow) error { actualResults = append(actualResults, []interface{}{a, b}) return nil }, ) require.NoError(t, err) expectedResults := []interface{}{ []interface{}{1, 2}, []interface{}{2, 4}, []interface{}{3, 6}, } require.Equal(t, expectedResults, actualResults) require.EqualValues(t, 3, ct.RowsAffected()) }) } func TestConnQueryFuncScanError(t *testing.T) { t.Parallel() testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { var actualResults []interface{} var a, b int ct, err := conn.QueryFunc( context.Background(), "select 'foo', 'bar' from generate_series(1, $1) n", []interface{}{3}, []interface{}{&a, &b}, func(pgx.QueryFuncRow) error { actualResults = append(actualResults, []interface{}{a, b}) return nil }, ) require.EqualError(t, err, "can't scan into dest[0]: unable to assign to *int") require.Nil(t, ct) }) } func TestConnQueryFuncAbort(t *testing.T) { t.Parallel() testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { var a, b int ct, err := conn.QueryFunc( context.Background(), "select n, n * 2 from generate_series(1, $1) n", []interface{}{3}, []interface{}{&a, &b}, func(pgx.QueryFuncRow) error { return errors.New("abort") }, ) require.EqualError(t, err, "abort") require.Nil(t, ct) }) } func ExampleConn_QueryFunc() { conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) if err != nil { fmt.Printf("Unable to establish connection: %v", err) return } var a, b int _, err = conn.QueryFunc( context.Background(), "select n, n * 2 from generate_series(1, $1) n", []interface{}{3}, []interface{}{&a, &b}, func(pgx.QueryFuncRow) error { fmt.Printf("%v, %v\n", a, b) return nil }, ) if err != nil { fmt.Printf("QueryFunc error: %v", err) return } // Output: // 1, 2 // 2, 4 // 3, 6 } 
pgx-4.18.1/rows.go000066400000000000000000000230101437725773200137410ustar00rootroot00000000000000package pgx import ( "context" "errors" "fmt" "time" "github.com/jackc/pgconn" "github.com/jackc/pgproto3/v2" "github.com/jackc/pgtype" ) // Rows is the result set returned from *Conn.Query. Rows must be closed before // the *Conn can be used again. Rows are closed by explicitly calling Close(), // calling Next() until it returns false, or when a fatal error occurs. // // Once a Rows is closed the only methods that may be called are Close(), Err(), and CommandTag(). // // Rows is an interface instead of a struct to allow tests to mock Query. However, // adding a method to an interface is technically a breaking change. Because of this // the Rows interface is partially excluded from semantic version requirements. // Methods will not be removed or changed, but new methods may be added. type Rows interface { // Close closes the rows, making the connection ready for use again. It is safe // to call Close after rows is already closed. Close() // Err returns any error that occurred while reading. Err() error // CommandTag returns the command tag from this query. It is only available after Rows is closed. CommandTag() pgconn.CommandTag FieldDescriptions() []pgproto3.FieldDescription // Next prepares the next row for reading. It returns true if there is another // row and false if no more rows are available. It automatically closes rows // when all rows are read. Next() bool // Scan reads the values from the current row into dest values positionally. // dest can include pointers to core types, values implementing the Scanner // interface, and nil. nil will skip the value entirely. It is an error to // call Scan without first calling Next() and checking that it returned true. Scan(dest ...interface{}) error // Values returns the decoded row values. As with Scan(), it is an error to // call Values without first calling Next() and checking that it returned // true. 
Values() ([]interface{}, error) // RawValues returns the unparsed bytes of the row values. The returned [][]byte is only valid until the next Next // call or the Rows is closed. However, the underlying byte data is safe to retain a reference to and mutate. RawValues() [][]byte } // Row is a convenience wrapper over Rows that is returned by QueryRow. // // Row is an interface instead of a struct to allow tests to mock QueryRow. However, // adding a method to an interface is technically a breaking change. Because of this // the Row interface is partially excluded from semantic version requirements. // Methods will not be removed or changed, but new methods may be added. type Row interface { // Scan works the same as Rows. with the following exceptions. If no // rows were found it returns ErrNoRows. If multiple rows are returned it // ignores all but the first. Scan(dest ...interface{}) error } // connRow implements the Row interface for Conn.QueryRow. type connRow connRows func (r *connRow) Scan(dest ...interface{}) (err error) { rows := (*connRows)(r) if rows.Err() != nil { return rows.Err() } if !rows.Next() { if rows.Err() == nil { return ErrNoRows } return rows.Err() } rows.Scan(dest...) rows.Close() return rows.Err() } type rowLog interface { shouldLog(lvl LogLevel) bool log(ctx context.Context, lvl LogLevel, msg string, data map[string]interface{}) } // connRows implements the Rows interface for Conn.Query. 
type connRows struct { ctx context.Context logger rowLog connInfo *pgtype.ConnInfo values [][]byte rowCount int err error commandTag pgconn.CommandTag startTime time.Time sql string args []interface{} closed bool conn *Conn resultReader *pgconn.ResultReader multiResultReader *pgconn.MultiResultReader scanPlans []pgtype.ScanPlan } func (rows *connRows) FieldDescriptions() []pgproto3.FieldDescription { return rows.resultReader.FieldDescriptions() } func (rows *connRows) Close() { if rows.closed { return } rows.closed = true if rows.resultReader != nil { var closeErr error rows.commandTag, closeErr = rows.resultReader.Close() if rows.err == nil { rows.err = closeErr } } if rows.multiResultReader != nil { closeErr := rows.multiResultReader.Close() if rows.err == nil { rows.err = closeErr } } if rows.logger != nil { endTime := time.Now() if rows.err == nil { if rows.logger.shouldLog(LogLevelInfo) { rows.logger.log(rows.ctx, LogLevelInfo, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args), "time": endTime.Sub(rows.startTime), "rowCount": rows.rowCount}) } } else { if rows.logger.shouldLog(LogLevelError) { rows.logger.log(rows.ctx, LogLevelError, "Query", map[string]interface{}{"err": rows.err, "sql": rows.sql, "time": endTime.Sub(rows.startTime), "args": logQueryArgs(rows.args)}) } if rows.err != nil && rows.conn.stmtcache != nil { rows.conn.stmtcache.StatementErrored(rows.sql, rows.err) } } } } func (rows *connRows) CommandTag() pgconn.CommandTag { return rows.commandTag } func (rows *connRows) Err() error { return rows.err } // fatal signals an error occurred after the query was sent to the server. It // closes the rows automatically. 
func (rows *connRows) fatal(err error) { if rows.err != nil { return } rows.err = err rows.Close() } func (rows *connRows) Next() bool { if rows.closed { return false } if rows.resultReader.NextRow() { rows.rowCount++ rows.values = rows.resultReader.Values() return true } else { rows.Close() return false } } func (rows *connRows) Scan(dest ...interface{}) error { ci := rows.connInfo fieldDescriptions := rows.FieldDescriptions() values := rows.values if len(fieldDescriptions) != len(values) { err := fmt.Errorf("number of field descriptions must equal number of values, got %d and %d", len(fieldDescriptions), len(values)) rows.fatal(err) return err } if len(fieldDescriptions) != len(dest) { err := fmt.Errorf("number of field descriptions must equal number of destinations, got %d and %d", len(fieldDescriptions), len(dest)) rows.fatal(err) return err } if rows.scanPlans == nil { rows.scanPlans = make([]pgtype.ScanPlan, len(values)) for i := range dest { rows.scanPlans[i] = ci.PlanScan(fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, dest[i]) } } for i, dst := range dest { if dst == nil { continue } err := rows.scanPlans[i].Scan(ci, fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, values[i], dst) if err != nil { err = ScanArgError{ColumnIndex: i, Err: err} rows.fatal(err) return err } } return nil } func (rows *connRows) Values() ([]interface{}, error) { if rows.closed { return nil, errors.New("rows is closed") } values := make([]interface{}, 0, len(rows.FieldDescriptions())) for i := range rows.FieldDescriptions() { buf := rows.values[i] fd := &rows.FieldDescriptions()[i] if buf == nil { values = append(values, nil) continue } if dt, ok := rows.connInfo.DataTypeForOID(fd.DataTypeOID); ok { value := dt.Value switch fd.Format { case TextFormatCode: decoder, ok := value.(pgtype.TextDecoder) if !ok { decoder = &pgtype.GenericText{} } err := decoder.DecodeText(rows.connInfo, buf) if err != nil { rows.fatal(err) } values = append(values, 
decoder.(pgtype.Value).Get()) case BinaryFormatCode: decoder, ok := value.(pgtype.BinaryDecoder) if !ok { decoder = &pgtype.GenericBinary{} } err := decoder.DecodeBinary(rows.connInfo, buf) if err != nil { rows.fatal(err) } values = append(values, value.Get()) default: rows.fatal(errors.New("Unknown format code")) } } else { switch fd.Format { case TextFormatCode: decoder := &pgtype.GenericText{} err := decoder.DecodeText(rows.connInfo, buf) if err != nil { rows.fatal(err) } values = append(values, decoder.Get()) case BinaryFormatCode: decoder := &pgtype.GenericBinary{} err := decoder.DecodeBinary(rows.connInfo, buf) if err != nil { rows.fatal(err) } values = append(values, decoder.Get()) default: rows.fatal(errors.New("Unknown format code")) } } if rows.Err() != nil { return nil, rows.Err() } } return values, rows.Err() } func (rows *connRows) RawValues() [][]byte { return rows.values } type ScanArgError struct { ColumnIndex int Err error } func (e ScanArgError) Error() string { return fmt.Sprintf("can't scan into dest[%d]: %v", e.ColumnIndex, e.Err) } func (e ScanArgError) Unwrap() error { return e.Err } // ScanRow decodes raw row data into dest. It can be used to scan rows read from the lower level pgconn interface. // // connInfo - OID to Go type mapping. 
// fieldDescriptions - OID and format of values // values - the raw data as returned from the PostgreSQL server // dest - the destination that values will be decoded into func ScanRow(connInfo *pgtype.ConnInfo, fieldDescriptions []pgproto3.FieldDescription, values [][]byte, dest ...interface{}) error { if len(fieldDescriptions) != len(values) { return fmt.Errorf("number of field descriptions must equal number of values, got %d and %d", len(fieldDescriptions), len(values)) } if len(fieldDescriptions) != len(dest) { return fmt.Errorf("number of field descriptions must equal number of destinations, got %d and %d", len(fieldDescriptions), len(dest)) } for i, d := range dest { if d == nil { continue } err := connInfo.Scan(fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, values[i], d) if err != nil { return ScanArgError{ColumnIndex: i, Err: err} } } return nil } pgx-4.18.1/stdlib/000077500000000000000000000000001437725773200137055ustar00rootroot00000000000000pgx-4.18.1/stdlib/bench_test.go000066400000000000000000000046271437725773200163630ustar00rootroot00000000000000package stdlib_test import ( "database/sql" "fmt" "os" "strconv" "strings" "testing" "time" ) func getSelectRowsCounts(b *testing.B) []int64 { var rowCounts []int64 { s := os.Getenv("PGX_BENCH_SELECT_ROWS_COUNTS") if s != "" { for _, p := range strings.Split(s, " ") { n, err := strconv.ParseInt(p, 10, 64) if err != nil { b.Fatalf("Bad PGX_BENCH_SELECT_ROWS_COUNTS value: %v", err) } rowCounts = append(rowCounts, n) } } } if len(rowCounts) == 0 { rowCounts = []int64{1, 10, 100, 1000} } return rowCounts } type BenchRowSimple struct { ID int32 FirstName string LastName string Sex string BirthDate time.Time Weight int32 Height int32 UpdateTime time.Time } func BenchmarkSelectRowsScanSimple(b *testing.B) { db := openDB(b) defer closeDB(b, db) rowCounts := getSelectRowsCounts(b) for _, rowCount := range rowCounts { b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) { br := &BenchRowSimple{} 
for i := 0; i < b.N; i++ { rows, err := db.Query("select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(1, $1) n", rowCount) if err != nil { b.Fatal(err) } for rows.Next() { rows.Scan(&br.ID, &br.FirstName, &br.LastName, &br.Sex, &br.BirthDate, &br.Weight, &br.Height, &br.UpdateTime) } if rows.Err() != nil { b.Fatal(rows.Err()) } } }) } } type BenchRowNull struct { ID sql.NullInt32 FirstName sql.NullString LastName sql.NullString Sex sql.NullString BirthDate sql.NullTime Weight sql.NullInt32 Height sql.NullInt32 UpdateTime sql.NullTime } func BenchmarkSelectRowsScanNull(b *testing.B) { db := openDB(b) defer closeDB(b, db) rowCounts := getSelectRowsCounts(b) for _, rowCount := range rowCounts { b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) { br := &BenchRowSimple{} for i := 0; i < b.N; i++ { rows, err := db.Query("select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(100000, 100000 + $1) n", rowCount) if err != nil { b.Fatal(err) } for rows.Next() { rows.Scan(&br.ID, &br.FirstName, &br.LastName, &br.Sex, &br.BirthDate, &br.Weight, &br.Height, &br.UpdateTime) } if rows.Err() != nil { b.Fatal(rows.Err()) } } }) } } pgx-4.18.1/stdlib/sql.go000066400000000000000000000574641437725773200150530ustar00rootroot00000000000000// Package stdlib is the compatibility layer from pgx to database/sql. // // A database/sql connection can be established through sql.Open. // // db, err := sql.Open("pgx", "postgres://pgx_md5:secret@localhost:5432/pgx_test?sslmode=disable") // if err != nil { // return err // } // // Or from a DSN string. // // db, err := sql.Open("pgx", "user=postgres password=secret host=localhost port=5432 database=pgx_test sslmode=disable") // if err != nil { // return err // } // // Or a pgx.ConnConfig can be used to set configuration not accessible via connection string. 
In this case the // pgx.ConnConfig must first be registered with the driver. This registration returns a connection string which is used // with sql.Open. // // connConfig, _ := pgx.ParseConfig(os.Getenv("DATABASE_URL")) // connConfig.Logger = myLogger // connStr := stdlib.RegisterConnConfig(connConfig) // db, _ := sql.Open("pgx", connStr) // // pgx uses standard PostgreSQL positional parameters in queries. e.g. $1, $2. // It does not support named parameters. // // db.QueryRow("select * from users where id=$1", userID) // // In Go 1.13 and above (*sql.Conn) Raw() can be used to get a *pgx.Conn from the standard // database/sql.DB connection pool. This allows operations that use pgx specific functionality. // // // Given db is a *sql.DB // conn, err := db.Conn(context.Background()) // if err != nil { // // handle error from acquiring connection from DB pool // } // // err = conn.Raw(func(driverConn interface{}) error { // conn := driverConn.(*stdlib.Conn).Conn() // conn is a *pgx.Conn // // Do pgx specific stuff with conn // conn.CopyFrom(...) // return nil // }) // if err != nil { // // handle error that occurred while using *pgx.Conn // } package stdlib import ( "context" "database/sql" "database/sql/driver" "errors" "fmt" "io" "math" "math/rand" "reflect" "strconv" "strings" "sync" "time" "github.com/jackc/pgconn" "github.com/jackc/pgtype" "github.com/jackc/pgx/v4" ) // Only intrinsic types should be binary format with database/sql. var databaseSQLResultFormats pgx.QueryResultFormatsByOID var pgxDriver *Driver type ctxKey int var ctxKeyFakeTx ctxKey = 0 var ErrNotPgx = errors.New("not pgx *sql.DB") func init() { pgxDriver = &Driver{ configs: make(map[string]*pgx.ConnConfig), } fakeTxConns = make(map[*pgx.Conn]*sql.Tx) // if pgx driver was already registered by different pgx major version then we // skip registration under the default name. 
if !contains(sql.Drivers(), "pgx") { sql.Register("pgx", pgxDriver) } sql.Register("pgx/v4", pgxDriver) databaseSQLResultFormats = pgx.QueryResultFormatsByOID{ pgtype.BoolOID: 1, pgtype.ByteaOID: 1, pgtype.CIDOID: 1, pgtype.DateOID: 1, pgtype.Float4OID: 1, pgtype.Float8OID: 1, pgtype.Int2OID: 1, pgtype.Int4OID: 1, pgtype.Int8OID: 1, pgtype.OIDOID: 1, pgtype.TimestampOID: 1, pgtype.TimestamptzOID: 1, pgtype.XIDOID: 1, } } // TODO replace by slices.Contains when experimental package will be merged to stdlib // https://pkg.go.dev/golang.org/x/exp/slices#Contains func contains(list []string, y string) bool { for _, x := range list { if x == y { return true } } return false } var ( fakeTxMutex sync.Mutex fakeTxConns map[*pgx.Conn]*sql.Tx ) // OptionOpenDB options for configuring the driver when opening a new db pool. type OptionOpenDB func(*connector) // OptionBeforeConnect provides a callback for before connect. It is passed a shallow copy of the ConnConfig that will // be used to connect, so only its immediate members should be modified. func OptionBeforeConnect(bc func(context.Context, *pgx.ConnConfig) error) OptionOpenDB { return func(dc *connector) { dc.BeforeConnect = bc } } // OptionAfterConnect provides a callback for after connect. func OptionAfterConnect(ac func(context.Context, *pgx.Conn) error) OptionOpenDB { return func(dc *connector) { dc.AfterConnect = ac } } // OptionResetSession provides a callback that can be used to add custom logic prior to executing a query on the // connection if the connection has been used before. // If ResetSessionFunc returns ErrBadConn error the connection will be discarded. func OptionResetSession(rs func(context.Context, *pgx.Conn) error) OptionOpenDB { return func(dc *connector) { dc.ResetSession = rs } } // RandomizeHostOrderFunc is a BeforeConnect hook that randomizes the host order in the provided connConfig, so that a // new host becomes primary each time. 
This is useful to distribute connections for multi-master databases like // CockroachDB. If you use this you likely should set https://golang.org/pkg/database/sql/#DB.SetConnMaxLifetime as well // to ensure that connections are periodically rebalanced across your nodes. func RandomizeHostOrderFunc(ctx context.Context, connConfig *pgx.ConnConfig) error { if len(connConfig.Fallbacks) == 0 { return nil } newFallbacks := append([]*pgconn.FallbackConfig{&pgconn.FallbackConfig{ Host: connConfig.Host, Port: connConfig.Port, TLSConfig: connConfig.TLSConfig, }}, connConfig.Fallbacks...) rand.Shuffle(len(newFallbacks), func(i, j int) { newFallbacks[i], newFallbacks[j] = newFallbacks[j], newFallbacks[i] }) // Use the one that sorted last as the primary and keep the rest as the fallbacks newPrimary := newFallbacks[len(newFallbacks)-1] connConfig.Host = newPrimary.Host connConfig.Port = newPrimary.Port connConfig.TLSConfig = newPrimary.TLSConfig connConfig.Fallbacks = newFallbacks[:len(newFallbacks)-1] return nil } func GetConnector(config pgx.ConnConfig, opts ...OptionOpenDB) driver.Connector { c := connector{ ConnConfig: config, BeforeConnect: func(context.Context, *pgx.ConnConfig) error { return nil }, // noop before connect by default AfterConnect: func(context.Context, *pgx.Conn) error { return nil }, // noop after connect by default ResetSession: func(context.Context, *pgx.Conn) error { return nil }, // noop reset session by default driver: pgxDriver, } for _, opt := range opts { opt(&c) } return c } func OpenDB(config pgx.ConnConfig, opts ...OptionOpenDB) *sql.DB { c := GetConnector(config, opts...) 
return sql.OpenDB(c) } type connector struct { pgx.ConnConfig BeforeConnect func(context.Context, *pgx.ConnConfig) error // function to call before creation of every new connection AfterConnect func(context.Context, *pgx.Conn) error // function to call after creation of every new connection ResetSession func(context.Context, *pgx.Conn) error // function is called before a connection is reused driver *Driver } // Connect implement driver.Connector interface func (c connector) Connect(ctx context.Context) (driver.Conn, error) { var ( err error conn *pgx.Conn ) // Create a shallow copy of the config, so that BeforeConnect can safely modify it connConfig := c.ConnConfig if err = c.BeforeConnect(ctx, &connConfig); err != nil { return nil, err } if conn, err = pgx.ConnectConfig(ctx, &connConfig); err != nil { return nil, err } if err = c.AfterConnect(ctx, conn); err != nil { return nil, err } return &Conn{conn: conn, driver: c.driver, connConfig: connConfig, resetSessionFunc: c.ResetSession}, nil } // Driver implement driver.Connector interface func (c connector) Driver() driver.Driver { return c.driver } // GetDefaultDriver returns the driver initialized in the init function // and used when the pgx driver is registered. 
func GetDefaultDriver() driver.Driver { return pgxDriver } type Driver struct { configMutex sync.Mutex configs map[string]*pgx.ConnConfig sequence int } func (d *Driver) Open(name string) (driver.Conn, error) { ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) // Ensure eventual timeout defer cancel() connector, err := d.OpenConnector(name) if err != nil { return nil, err } return connector.Connect(ctx) } func (d *Driver) OpenConnector(name string) (driver.Connector, error) { return &driverConnector{driver: d, name: name}, nil } func (d *Driver) registerConnConfig(c *pgx.ConnConfig) string { d.configMutex.Lock() connStr := fmt.Sprintf("registeredConnConfig%d", d.sequence) d.sequence++ d.configs[connStr] = c d.configMutex.Unlock() return connStr } func (d *Driver) unregisterConnConfig(connStr string) { d.configMutex.Lock() delete(d.configs, connStr) d.configMutex.Unlock() } type driverConnector struct { driver *Driver name string } func (dc *driverConnector) Connect(ctx context.Context) (driver.Conn, error) { var connConfig *pgx.ConnConfig dc.driver.configMutex.Lock() connConfig = dc.driver.configs[dc.name] dc.driver.configMutex.Unlock() if connConfig == nil { var err error connConfig, err = pgx.ParseConfig(dc.name) if err != nil { return nil, err } } conn, err := pgx.ConnectConfig(ctx, connConfig) if err != nil { return nil, err } c := &Conn{ conn: conn, driver: dc.driver, connConfig: *connConfig, resetSessionFunc: func(context.Context, *pgx.Conn) error { return nil }, } return c, nil } func (dc *driverConnector) Driver() driver.Driver { return dc.driver } // RegisterConnConfig registers a ConnConfig and returns the connection string to use with Open. func RegisterConnConfig(c *pgx.ConnConfig) string { return pgxDriver.registerConnConfig(c) } // UnregisterConnConfig removes the ConnConfig registration for connStr. 
func UnregisterConnConfig(connStr string) { pgxDriver.unregisterConnConfig(connStr) } type Conn struct { conn *pgx.Conn psCount int64 // Counter used for creating unique prepared statement names driver *Driver connConfig pgx.ConnConfig resetSessionFunc func(context.Context, *pgx.Conn) error // Function is called before a connection is reused } // Conn returns the underlying *pgx.Conn func (c *Conn) Conn() *pgx.Conn { return c.conn } func (c *Conn) Prepare(query string) (driver.Stmt, error) { return c.PrepareContext(context.Background(), query) } func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { if c.conn.IsClosed() { return nil, driver.ErrBadConn } name := fmt.Sprintf("pgx_%d", c.psCount) c.psCount++ sd, err := c.conn.Prepare(ctx, name, query) if err != nil { return nil, err } return &Stmt{sd: sd, conn: c}, nil } func (c *Conn) Close() error { ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() return c.conn.Close(ctx) } func (c *Conn) Begin() (driver.Tx, error) { return c.BeginTx(context.Background(), driver.TxOptions{}) } func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { if c.conn.IsClosed() { return nil, driver.ErrBadConn } if pconn, ok := ctx.Value(ctxKeyFakeTx).(**pgx.Conn); ok { *pconn = c.conn return fakeTx{}, nil } var pgxOpts pgx.TxOptions switch sql.IsolationLevel(opts.Isolation) { case sql.LevelDefault: case sql.LevelReadUncommitted: pgxOpts.IsoLevel = pgx.ReadUncommitted case sql.LevelReadCommitted: pgxOpts.IsoLevel = pgx.ReadCommitted case sql.LevelRepeatableRead, sql.LevelSnapshot: pgxOpts.IsoLevel = pgx.RepeatableRead case sql.LevelSerializable: pgxOpts.IsoLevel = pgx.Serializable default: return nil, fmt.Errorf("unsupported isolation: %v", opts.Isolation) } if opts.ReadOnly { pgxOpts.AccessMode = pgx.ReadOnly } tx, err := c.conn.BeginTx(ctx, pgxOpts) if err != nil { return nil, err } return wrapTx{ctx: ctx, tx: tx}, nil } func (c *Conn) 
ExecContext(ctx context.Context, query string, argsV []driver.NamedValue) (driver.Result, error) { if c.conn.IsClosed() { return nil, driver.ErrBadConn } args := namedValueToInterface(argsV) commandTag, err := c.conn.Exec(ctx, query, args...) // if we got a network error before we had a chance to send the query, retry if err != nil { if pgconn.SafeToRetry(err) { return nil, driver.ErrBadConn } } return driver.RowsAffected(commandTag.RowsAffected()), err } func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.NamedValue) (driver.Rows, error) { if c.conn.IsClosed() { return nil, driver.ErrBadConn } args := []interface{}{databaseSQLResultFormats} args = append(args, namedValueToInterface(argsV)...) rows, err := c.conn.Query(ctx, query, args...) if err != nil { if pgconn.SafeToRetry(err) { return nil, driver.ErrBadConn } return nil, err } // Preload first row because otherwise we won't know what columns are available when database/sql asks. more := rows.Next() if err = rows.Err(); err != nil { rows.Close() return nil, err } return &Rows{conn: c, rows: rows, skipNext: true, skipNextMore: more}, nil } func (c *Conn) Ping(ctx context.Context) error { if c.conn.IsClosed() { return driver.ErrBadConn } err := c.conn.Ping(ctx) if err != nil { // A Ping failure implies some sort of fatal state. The connection is almost certainly already closed by the // failure, but manually close it just to be sure. c.Close() return driver.ErrBadConn } return nil } func (c *Conn) CheckNamedValue(*driver.NamedValue) error { // Underlying pgx supports sql.Scanner and driver.Valuer interfaces natively. So everything can be passed through directly. 
return nil } func (c *Conn) ResetSession(ctx context.Context) error { if c.conn.IsClosed() { return driver.ErrBadConn } return c.resetSessionFunc(ctx, c.conn) } type Stmt struct { sd *pgconn.StatementDescription conn *Conn } func (s *Stmt) Close() error { ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() return s.conn.conn.Deallocate(ctx, s.sd.Name) } func (s *Stmt) NumInput() int { return len(s.sd.ParamOIDs) } func (s *Stmt) Exec(argsV []driver.Value) (driver.Result, error) { return nil, errors.New("Stmt.Exec deprecated and not implemented") } func (s *Stmt) ExecContext(ctx context.Context, argsV []driver.NamedValue) (driver.Result, error) { return s.conn.ExecContext(ctx, s.sd.Name, argsV) } func (s *Stmt) Query(argsV []driver.Value) (driver.Rows, error) { return nil, errors.New("Stmt.Query deprecated and not implemented") } func (s *Stmt) QueryContext(ctx context.Context, argsV []driver.NamedValue) (driver.Rows, error) { return s.conn.QueryContext(ctx, s.sd.Name, argsV) } type rowValueFunc func(src []byte) (driver.Value, error) type Rows struct { conn *Conn rows pgx.Rows valueFuncs []rowValueFunc skipNext bool skipNextMore bool columnNames []string } func (r *Rows) Columns() []string { if r.columnNames == nil { fields := r.rows.FieldDescriptions() r.columnNames = make([]string, len(fields)) for i, fd := range fields { r.columnNames[i] = string(fd.Name) } } return r.columnNames } // ColumnTypeDatabaseTypeName returns the database system type name. If the name is unknown the OID is returned. func (r *Rows) ColumnTypeDatabaseTypeName(index int) string { if dt, ok := r.conn.conn.ConnInfo().DataTypeForOID(r.rows.FieldDescriptions()[index].DataTypeOID); ok { return strings.ToUpper(dt.Name) } return strconv.FormatInt(int64(r.rows.FieldDescriptions()[index].DataTypeOID), 10) } const varHeaderSize = 4 // ColumnTypeLength returns the length of the column type if the column is a // variable length type. 
If the column is not a variable length type ok
// should return false.
func (r *Rows) ColumnTypeLength(index int) (int64, bool) {
	fd := r.rows.FieldDescriptions()[index]

	switch fd.DataTypeOID {
	case pgtype.TextOID, pgtype.ByteaOID:
		// text and bytea have no declared length limit; report "unbounded".
		return math.MaxInt64, true
	// NOTE(review): BPCharArrayOID looks suspicious here — the length type
	// modifier applies to bpchar itself, so BPCharOID may have been intended.
	// Confirm against upstream before changing.
	case pgtype.VarcharOID, pgtype.BPCharArrayOID:
		// TypeModifier stores declared length + the 4-byte varlena header.
		return int64(fd.TypeModifier - varHeaderSize), true
	default:
		return 0, false
	}
}

// ColumnTypePrecisionScale should return the precision and scale for decimal
// types. If not applicable, ok should be false.
func (r *Rows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) {
	fd := r.rows.FieldDescriptions()[index]

	switch fd.DataTypeOID {
	case pgtype.NumericOID:
		// numeric packs precision and scale into the type modifier:
		// precision in the high 16 bits, scale in the low 16 bits,
		// after subtracting the varlena header offset.
		mod := fd.TypeModifier - varHeaderSize
		precision = int64((mod >> 16) & 0xffff)
		scale = int64(mod & 0xffff)
		return precision, scale, true
	default:
		return 0, 0, false
	}
}

// ColumnTypeScanType returns the value type that can be used to scan types into.
func (r *Rows) ColumnTypeScanType(index int) reflect.Type {
	fd := r.rows.FieldDescriptions()[index]

	switch fd.DataTypeOID {
	case pgtype.Float8OID:
		return reflect.TypeOf(float64(0))
	case pgtype.Float4OID:
		return reflect.TypeOf(float32(0))
	case pgtype.Int8OID:
		return reflect.TypeOf(int64(0))
	case pgtype.Int4OID:
		return reflect.TypeOf(int32(0))
	case pgtype.Int2OID:
		return reflect.TypeOf(int16(0))
	case pgtype.BoolOID:
		return reflect.TypeOf(false)
	case pgtype.NumericOID:
		// numeric is surfaced as float64 to database/sql; precision may be lost.
		return reflect.TypeOf(float64(0))
	case pgtype.DateOID, pgtype.TimestampOID, pgtype.TimestamptzOID:
		return reflect.TypeOf(time.Time{})
	case pgtype.ByteaOID:
		return reflect.TypeOf([]byte(nil))
	default:
		// Unknown types are passed through as their text representation.
		return reflect.TypeOf("")
	}
}

// Close closes the underlying pgx rows and reports any error encountered
// during iteration.
func (r *Rows) Close() error {
	r.rows.Close()
	return r.rows.Err()
}

// Next implements driver.Rows. On the first call it lazily builds one
// conversion closure per column (cached in r.valueFuncs) that decodes the raw
// wire bytes into a driver.Value; subsequent calls reuse those closures.
func (r *Rows) Next(dest []driver.Value) error {
	ci := r.conn.conn.ConnInfo()
	fieldDescriptions := r.rows.FieldDescriptions()

	if r.valueFuncs == nil {
		r.valueFuncs = make([]rowValueFunc, len(fieldDescriptions))

		for i, fd := range fieldDescriptions {
			dataTypeOID := fd.DataTypeOID
			format := fd.Format

			// Each case captures a per-column destination variable d; the
			// closure scans into it and converts to a driver.Value. The same
			// d is reused on every row for that column.
			switch fd.DataTypeOID {
			case pgtype.BoolOID:
				var d bool
				scanPlan := ci.PlanScan(dataTypeOID, format, &d)
				r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
					err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
					return d, err
				}
			case pgtype.ByteaOID:
				var d []byte
				scanPlan := ci.PlanScan(dataTypeOID, format, &d)
				r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
					err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
					return d, err
				}
			case pgtype.CIDOID:
				var d pgtype.CID
				scanPlan := ci.PlanScan(dataTypeOID, format, &d)
				r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
					err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
					if err != nil {
						return nil, err
					}
					return d.Value()
				}
			case pgtype.DateOID:
				var d pgtype.Date
				scanPlan := ci.PlanScan(dataTypeOID, format, &d)
				r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
					err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
					if err != nil {
						return nil, err
					}
					return d.Value()
				}
			case pgtype.Float4OID:
				var d float32
				scanPlan := ci.PlanScan(dataTypeOID, format, &d)
				r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
					err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
					// database/sql only supports float64, so widen.
					return float64(d), err
				}
			case pgtype.Float8OID:
				var d float64
				scanPlan := ci.PlanScan(dataTypeOID, format, &d)
				r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
					err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
					return d, err
				}
			case pgtype.Int2OID:
				var d int16
				scanPlan := ci.PlanScan(dataTypeOID, format, &d)
				r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
					err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
					// database/sql only supports int64, so widen.
					return int64(d), err
				}
			case pgtype.Int4OID:
				var d int32
				scanPlan := ci.PlanScan(dataTypeOID, format, &d)
				r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
					err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
					return int64(d), err
				}
			case pgtype.Int8OID:
				var d int64
				scanPlan := ci.PlanScan(dataTypeOID, format, &d)
				r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
					err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
					return d, err
				}
			case pgtype.JSONOID:
				var d pgtype.JSON
				scanPlan := ci.PlanScan(dataTypeOID, format, &d)
				r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
					err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
					if err != nil {
						return nil, err
					}
					return d.Value()
				}
			case pgtype.JSONBOID:
				var d pgtype.JSONB
				scanPlan := ci.PlanScan(dataTypeOID, format, &d)
				r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
					err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
					if err != nil {
						return nil, err
					}
					return d.Value()
				}
			case pgtype.OIDOID:
				var d pgtype.OIDValue
				scanPlan := ci.PlanScan(dataTypeOID, format, &d)
				r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
					err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
					if err != nil {
						return nil, err
					}
					return d.Value()
				}
			case pgtype.TimestampOID:
				var d pgtype.Timestamp
				scanPlan := ci.PlanScan(dataTypeOID, format, &d)
				r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
					err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
					if err != nil {
						return nil, err
					}
					return d.Value()
				}
			case pgtype.TimestamptzOID:
				var d pgtype.Timestamptz
				scanPlan := ci.PlanScan(dataTypeOID, format, &d)
				r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
					err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
					if err != nil {
						return nil, err
					}
					return d.Value()
				}
			case pgtype.XIDOID:
				var d pgtype.XID
				scanPlan := ci.PlanScan(dataTypeOID, format, &d)
				r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
					err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
					if err != nil {
						return nil, err
					}
					return d.Value()
				}
			default:
				// Any other type is decoded to its string representation.
				var d string
				scanPlan := ci.PlanScan(dataTypeOID, format, &d)
				r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
					err := scanPlan.Scan(ci, dataTypeOID, format, src, &d)
					return d, err
				}
			}
		}
	}

	var more bool
	if r.skipNext {
		// A row was already fetched ahead of time elsewhere (presumably to
		// peek at result-set state — TODO confirm against the caller);
		// consume the stored answer instead of advancing again.
		more = r.skipNextMore
		r.skipNext = false
	} else {
		more = r.rows.Next()
	}
	if !more {
		// driver.Rows contract: io.EOF signals normal end of rows.
		if r.rows.Err() == nil {
			return io.EOF
		} else {
			return r.rows.Err()
		}
	}

	for i, rv := range r.rows.RawValues() {
		if rv != nil {
			var err error
			dest[i], err = r.valueFuncs[i](rv)
			if err != nil {
				return fmt.Errorf("convert field %d failed: %v", i, err)
			}
		} else {
			// SQL NULL maps to a nil driver.Value.
			dest[i] = nil
		}
	}

	return nil
}

// valueToInterface converts a slice of driver.Value into a slice of
// interface{}, preserving nils, for passing to pgx query methods.
func valueToInterface(argsV []driver.Value) []interface{} {
	args := make([]interface{}, 0, len(argsV))
	for _, v := range argsV {
		if v != nil {
			args = append(args, v.(interface{}))
		} else {
			args = append(args, nil)
		}
	}
	return args
}

// namedValueToInterface converts a slice of driver.NamedValue into a slice of
// interface{}, preserving nils, for passing to pgx query methods.
func namedValueToInterface(argsV []driver.NamedValue) []interface{} {
	args := make([]interface{}, 0, len(argsV))
	for _, v := range argsV {
		if v.Value != nil {
			args = append(args, v.Value.(interface{}))
		} else {
			args = append(args, nil)
		}
	}
	return args
}

// wrapTx adapts a pgx.Tx (which requires a context on Commit/Rollback) to the
// context-free driver.Tx interface by carrying the context it was begun with.
type wrapTx struct {
	ctx context.Context
	tx  pgx.Tx
}

func (wtx wrapTx) Commit() error { return wtx.tx.Commit(wtx.ctx) }

func (wtx wrapTx) Rollback() error { return wtx.tx.Rollback(wtx.ctx) }

// fakeTx is a no-op driver.Tx used by the AcquireConn mechanism: the "fake"
// transaction exists only to pin a connection out of the pool.
type fakeTx struct{}

func (fakeTx) Commit() error { return nil }

func (fakeTx) Rollback() error { return nil }

// AcquireConn acquires a *pgx.Conn from database/sql connection pool. It must be released with ReleaseConn.
//
// In Go 1.13 this functionality has been incorporated into the standard library in the db.Conn.Raw() method.
func AcquireConn(db *sql.DB) (*pgx.Conn, error) {
	var conn *pgx.Conn
	// Smuggle a **pgx.Conn through the context; the driver's BeginTx detects
	// ctxKeyFakeTx, stores its underlying connection there, and starts a fake
	// (no-op) transaction so the pool keeps the connection checked out.
	ctx := context.WithValue(context.Background(), ctxKeyFakeTx, &conn)
	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		return nil, err
	}
	if conn == nil {
		// The driver never populated conn, so db is not backed by pgx.
		tx.Rollback()
		return nil, ErrNotPgx
	}

	fakeTxMutex.Lock()
	fakeTxConns[conn] = tx
	fakeTxMutex.Unlock()

	return conn, nil
}

// ReleaseConn releases a *pgx.Conn acquired with AcquireConn.
func ReleaseConn(db *sql.DB, conn *pgx.Conn) error { var tx *sql.Tx var ok bool if conn.PgConn().IsBusy() || conn.PgConn().TxStatus() != 'I' { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() conn.Close(ctx) } fakeTxMutex.Lock() tx, ok = fakeTxConns[conn] if ok { delete(fakeTxConns, conn) fakeTxMutex.Unlock() } else { fakeTxMutex.Unlock() return fmt.Errorf("can't release conn that is not acquired") } return tx.Rollback() } pgx-4.18.1/stdlib/sql_test.go000066400000000000000000000750171437725773200161040ustar00rootroot00000000000000package stdlib_test import ( "bytes" "context" "database/sql" "encoding/json" "math" "os" "reflect" "regexp" "testing" "time" "github.com/Masterminds/semver/v3" "github.com/jackc/pgconn" "github.com/jackc/pgx/v4" "github.com/jackc/pgx/v4/stdlib" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func openDB(t testing.TB) *sql.DB { config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) return stdlib.OpenDB(*config) } func closeDB(t testing.TB, db *sql.DB) { err := db.Close() require.NoError(t, err) } func skipCockroachDB(t testing.TB, db *sql.DB, msg string) { conn, err := db.Conn(context.Background()) require.NoError(t, err) defer conn.Close() err = conn.Raw(func(driverConn interface{}) error { conn := driverConn.(*stdlib.Conn).Conn() if conn.PgConn().ParameterStatus("crdb_version") != "" { t.Skip(msg) } return nil }) require.NoError(t, err) } func skipPostgreSQLVersion(t testing.TB, db *sql.DB, constraintStr, msg string) { conn, err := db.Conn(context.Background()) require.NoError(t, err) defer conn.Close() err = conn.Raw(func(driverConn interface{}) error { conn := driverConn.(*stdlib.Conn).Conn() serverVersionStr := conn.PgConn().ParameterStatus("server_version") serverVersionStr = regexp.MustCompile(`^[0-9.]+`).FindString(serverVersionStr) // if not PostgreSQL do nothing if serverVersionStr == "" { return nil } serverVersion, err := 
semver.NewVersion(serverVersionStr) if err != nil { return err } c, err := semver.NewConstraint(constraintStr) if err != nil { return err } if c.Check(serverVersion) { t.Skip(msg) } return nil }) require.NoError(t, err) } func testWithAndWithoutPreferSimpleProtocol(t *testing.T, f func(t *testing.T, db *sql.DB)) { t.Run("SimpleProto", func(t *testing.T) { config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) config.PreferSimpleProtocol = true db := stdlib.OpenDB(*config) defer func() { err := db.Close() require.NoError(t, err) }() f(t, db) ensureDBValid(t, db) }, ) t.Run("DefaultProto", func(t *testing.T) { config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) db := stdlib.OpenDB(*config) defer func() { err := db.Close() require.NoError(t, err) }() f(t, db) ensureDBValid(t, db) }, ) } // Do a simple query to ensure the DB is still usable. This is of less use in stdlib as the connection pool should // cover an broken connections. 
func ensureDBValid(t testing.TB, db *sql.DB) { var sum, rowCount int32 rows, err := db.Query("select generate_series(1,$1)", 10) require.NoError(t, err) defer rows.Close() for rows.Next() { var n int32 rows.Scan(&n) sum += n rowCount++ } require.NoError(t, rows.Err()) if rowCount != 10 { t.Error("Select called onDataRow wrong number of times") } if sum != 55 { t.Error("Wrong values returned") } } type preparer interface { Prepare(query string) (*sql.Stmt, error) } func prepareStmt(t *testing.T, p preparer, sql string) *sql.Stmt { stmt, err := p.Prepare(sql) require.NoError(t, err) return stmt } func closeStmt(t *testing.T, stmt *sql.Stmt) { err := stmt.Close() require.NoError(t, err) } func TestSQLOpen(t *testing.T) { tests := []struct { driverName string }{ {driverName: "pgx"}, {driverName: "pgx/v4"}, } for _, tt := range tests { tt := tt t.Run(tt.driverName, func(t *testing.T) { db, err := sql.Open(tt.driverName, os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) closeDB(t, db) }) } } func TestNormalLifeCycle(t *testing.T) { db := openDB(t) defer closeDB(t, db) skipCockroachDB(t, db, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)") stmt := prepareStmt(t, db, "select 'foo', n from generate_series($1::int, $2::int) n") defer closeStmt(t, stmt) rows, err := stmt.Query(int32(1), int32(10)) require.NoError(t, err) rowCount := int64(0) for rows.Next() { rowCount++ var s string var n int64 err := rows.Scan(&s, &n) require.NoError(t, err) if s != "foo" { t.Errorf(`Expected "foo", received "%v"`, s) } if n != rowCount { t.Errorf("Expected %d, received %d", rowCount, n) } } require.NoError(t, rows.Err()) require.EqualValues(t, 10, rowCount) err = rows.Close() require.NoError(t, err) ensureDBValid(t, db) } func TestStmtExec(t *testing.T) { db := openDB(t) defer closeDB(t, db) tx, err := db.Begin() require.NoError(t, err) createStmt := prepareStmt(t, tx, "create temporary table t(a varchar not null)") _, err = 
createStmt.Exec() require.NoError(t, err) closeStmt(t, createStmt) insertStmt := prepareStmt(t, tx, "insert into t values($1::text)") result, err := insertStmt.Exec("foo") require.NoError(t, err) n, err := result.RowsAffected() require.NoError(t, err) require.EqualValues(t, 1, n) closeStmt(t, insertStmt) ensureDBValid(t, db) } func TestQueryCloseRowsEarly(t *testing.T) { db := openDB(t) defer closeDB(t, db) skipCockroachDB(t, db, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)") stmt := prepareStmt(t, db, "select 'foo', n from generate_series($1::int, $2::int) n") defer closeStmt(t, stmt) rows, err := stmt.Query(int32(1), int32(10)) require.NoError(t, err) // Close rows immediately without having read them err = rows.Close() require.NoError(t, err) // Run the query again to ensure the connection and statement are still ok rows, err = stmt.Query(int32(1), int32(10)) require.NoError(t, err) rowCount := int64(0) for rows.Next() { rowCount++ var s string var n int64 err := rows.Scan(&s, &n) require.NoError(t, err) if s != "foo" { t.Errorf(`Expected "foo", received "%v"`, s) } if n != rowCount { t.Errorf("Expected %d, received %d", rowCount, n) } } require.NoError(t, rows.Err()) require.EqualValues(t, 10, rowCount) err = rows.Close() require.NoError(t, err) ensureDBValid(t, db) } func TestConnExec(t *testing.T) { testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { _, err := db.Exec("create temporary table t(a varchar not null)") require.NoError(t, err) result, err := db.Exec("insert into t values('hey')") require.NoError(t, err) n, err := result.RowsAffected() require.NoError(t, err) require.EqualValues(t, 1, n) }) } func TestConnQuery(t *testing.T) { testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { skipCockroachDB(t, db, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)") rows, err := db.Query("select 'foo', n from 
generate_series($1::int, $2::int) n", int32(1), int32(10)) require.NoError(t, err) rowCount := int64(0) for rows.Next() { rowCount++ var s string var n int64 err := rows.Scan(&s, &n) require.NoError(t, err) if s != "foo" { t.Errorf(`Expected "foo", received "%v"`, s) } if n != rowCount { t.Errorf("Expected %d, received %d", rowCount, n) } } require.NoError(t, rows.Err()) require.EqualValues(t, 10, rowCount) err = rows.Close() require.NoError(t, err) }) } // https://github.com/jackc/pgx/issues/781 func TestConnQueryDifferentScanPlansIssue781(t *testing.T) { testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { var s string var b bool rows, err := db.Query("select true, 'foo'") require.NoError(t, err) require.True(t, rows.Next()) require.NoError(t, rows.Scan(&b, &s)) assert.Equal(t, true, b) assert.Equal(t, "foo", s) }) } func TestConnQueryNull(t *testing.T) { testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { rows, err := db.Query("select $1::int", nil) require.NoError(t, err) rowCount := int64(0) for rows.Next() { rowCount++ var n sql.NullInt64 err := rows.Scan(&n) require.NoError(t, err) if n.Valid != false { t.Errorf("Expected n to be null, but it was %v", n) } } require.NoError(t, rows.Err()) require.EqualValues(t, 1, rowCount) err = rows.Close() require.NoError(t, err) }) } func TestConnQueryRowByteSlice(t *testing.T) { testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { expected := []byte{222, 173, 190, 239} var actual []byte err := db.QueryRow(`select E'\\xdeadbeef'::bytea`).Scan(&actual) require.NoError(t, err) require.EqualValues(t, expected, actual) }) } func TestConnQueryFailure(t *testing.T) { testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { _, err := db.Query("select 'foo") require.Error(t, err) require.IsType(t, new(pgconn.PgError), err) }) } func TestConnSimpleSlicePassThrough(t *testing.T) { testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db 
*sql.DB) { skipCockroachDB(t, db, "Server does not support cardinality function") var n int64 err := db.QueryRow("select cardinality($1::text[])", []string{"a", "b", "c"}).Scan(&n) require.NoError(t, err) assert.EqualValues(t, 3, n) }) } // Test type that pgx would handle natively in binary, but since it is not a // database/sql native type should be passed through as a string func TestConnQueryRowPgxBinary(t *testing.T) { testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { sql := "select $1::int4[]" expected := "{1,2,3}" var actual string err := db.QueryRow(sql, expected).Scan(&actual) require.NoError(t, err) require.EqualValues(t, expected, actual) }) } func TestConnQueryRowUnknownType(t *testing.T) { testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { skipCockroachDB(t, db, "Server does not support point type") sql := "select $1::point" expected := "(1,2)" var actual string err := db.QueryRow(sql, expected).Scan(&actual) require.NoError(t, err) require.EqualValues(t, expected, actual) }) } func TestConnQueryJSONIntoByteSlice(t *testing.T) { testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { _, err := db.Exec(` create temporary table docs( body json not null ); insert into docs(body) values('{"foo": "bar"}'); `) require.NoError(t, err) sql := `select * from docs` expected := []byte(`{"foo": "bar"}`) var actual []byte err = db.QueryRow(sql).Scan(&actual) if err != nil { t.Errorf("Unexpected failure: %v (sql -> %v)", err, sql) } if bytes.Compare(actual, expected) != 0 { t.Errorf(`Expected "%v", got "%v" (sql -> %v)`, string(expected), string(actual), sql) } _, err = db.Exec(`drop table docs`) require.NoError(t, err) }) } func TestConnExecInsertByteSliceIntoJSON(t *testing.T) { // Not testing with simple protocol because there is no way for that to work. A []byte will be considered binary data // that needs to escape. No way to know whether the destination is really a text compatible or a bytea. 
db := openDB(t) defer closeDB(t, db) _, err := db.Exec(` create temporary table docs( body json not null ); `) require.NoError(t, err) expected := []byte(`{"foo": "bar"}`) _, err = db.Exec(`insert into docs(body) values($1)`, expected) require.NoError(t, err) var actual []byte err = db.QueryRow(`select body from docs`).Scan(&actual) require.NoError(t, err) if bytes.Compare(actual, expected) != 0 { t.Errorf(`Expected "%v", got "%v"`, string(expected), string(actual)) } _, err = db.Exec(`drop table docs`) require.NoError(t, err) } func TestTransactionLifeCycle(t *testing.T) { testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { _, err := db.Exec("create temporary table t(a varchar not null)") require.NoError(t, err) tx, err := db.Begin() require.NoError(t, err) _, err = tx.Exec("insert into t values('hi')") require.NoError(t, err) err = tx.Rollback() require.NoError(t, err) var n int64 err = db.QueryRow("select count(*) from t").Scan(&n) require.NoError(t, err) require.EqualValues(t, 0, n) tx, err = db.Begin() require.NoError(t, err) _, err = tx.Exec("insert into t values('hi')") require.NoError(t, err) err = tx.Commit() require.NoError(t, err) err = db.QueryRow("select count(*) from t").Scan(&n) require.NoError(t, err) require.EqualValues(t, 1, n) }) } func TestConnBeginTxIsolation(t *testing.T) { testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { skipCockroachDB(t, db, "Server always uses serializable isolation level") var defaultIsoLevel string err := db.QueryRow("show transaction_isolation").Scan(&defaultIsoLevel) require.NoError(t, err) supportedTests := []struct { sqlIso sql.IsolationLevel pgIso string }{ {sqlIso: sql.LevelDefault, pgIso: defaultIsoLevel}, {sqlIso: sql.LevelReadUncommitted, pgIso: "read uncommitted"}, {sqlIso: sql.LevelReadCommitted, pgIso: "read committed"}, {sqlIso: sql.LevelRepeatableRead, pgIso: "repeatable read"}, {sqlIso: sql.LevelSnapshot, pgIso: "repeatable read"}, {sqlIso: 
sql.LevelSerializable, pgIso: "serializable"}, } for i, tt := range supportedTests { func() { tx, err := db.BeginTx(context.Background(), &sql.TxOptions{Isolation: tt.sqlIso}) if err != nil { t.Errorf("%d. BeginTx failed: %v", i, err) return } defer tx.Rollback() var pgIso string err = tx.QueryRow("show transaction_isolation").Scan(&pgIso) if err != nil { t.Errorf("%d. QueryRow failed: %v", i, err) } if pgIso != tt.pgIso { t.Errorf("%d. pgIso => %s, want %s", i, pgIso, tt.pgIso) } }() } unsupportedTests := []struct { sqlIso sql.IsolationLevel }{ {sqlIso: sql.LevelWriteCommitted}, {sqlIso: sql.LevelLinearizable}, } for i, tt := range unsupportedTests { tx, err := db.BeginTx(context.Background(), &sql.TxOptions{Isolation: tt.sqlIso}) if err == nil { t.Errorf("%d. BeginTx should have failed", i) tx.Rollback() } } }) } func TestConnBeginTxReadOnly(t *testing.T) { testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { tx, err := db.BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true}) require.NoError(t, err) defer tx.Rollback() var pgReadOnly string err = tx.QueryRow("show transaction_read_only").Scan(&pgReadOnly) if err != nil { t.Errorf("QueryRow failed: %v", err) } if pgReadOnly != "on" { t.Errorf("pgReadOnly => %s, want %s", pgReadOnly, "on") } }) } func TestBeginTxContextCancel(t *testing.T) { testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { _, err := db.Exec("drop table if exists t") require.NoError(t, err) ctx, cancelFn := context.WithCancel(context.Background()) tx, err := db.BeginTx(ctx, nil) require.NoError(t, err) _, err = tx.Exec("create table t(id serial)") require.NoError(t, err) cancelFn() err = tx.Commit() if err != context.Canceled && err != sql.ErrTxDone { t.Fatalf("err => %v, want %v or %v", err, context.Canceled, sql.ErrTxDone) } var n int err = db.QueryRow("select count(*) from t").Scan(&n) if pgErr, ok := err.(*pgconn.PgError); !ok || pgErr.Code != "42P01" { t.Fatalf(`err => %v, want 
PgError{Code: "42P01"}`, err) } }) } func TestAcquireConn(t *testing.T) { testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { var conns []*pgx.Conn for i := 1; i < 6; i++ { conn, err := stdlib.AcquireConn(db) if err != nil { t.Errorf("%d. AcquireConn failed: %v", i, err) continue } var n int32 err = conn.QueryRow(context.Background(), "select 1").Scan(&n) if err != nil { t.Errorf("%d. QueryRow failed: %v", i, err) } if n != 1 { t.Errorf("%d. n => %d, want %d", i, n, 1) } stats := db.Stats() if stats.OpenConnections != i { t.Errorf("%d. stats.OpenConnections => %d, want %d", i, stats.OpenConnections, i) } conns = append(conns, conn) } for i, conn := range conns { if err := stdlib.ReleaseConn(db, conn); err != nil { t.Errorf("%d. stdlib.ReleaseConn failed: %v", i, err) } } }) } func TestConnRaw(t *testing.T) { testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { conn, err := db.Conn(context.Background()) require.NoError(t, err) var n int err = conn.Raw(func(driverConn interface{}) error { conn := driverConn.(*stdlib.Conn).Conn() return conn.QueryRow(context.Background(), "select 42").Scan(&n) }) require.NoError(t, err) assert.EqualValues(t, 42, n) }) } // https://github.com/jackc/pgx/issues/673 func TestReleaseConnWithTxInProgress(t *testing.T) { testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { skipCockroachDB(t, db, "Server does not support backend PID") c1, err := stdlib.AcquireConn(db) require.NoError(t, err) _, err = c1.Exec(context.Background(), "begin") require.NoError(t, err) c1PID := c1.PgConn().PID() err = stdlib.ReleaseConn(db, c1) require.NoError(t, err) c2, err := stdlib.AcquireConn(db) require.NoError(t, err) c2PID := c2.PgConn().PID() err = stdlib.ReleaseConn(db, c2) require.NoError(t, err) require.NotEqual(t, c1PID, c2PID) // Releasing a conn with a tx in progress should close the connection stats := db.Stats() require.Equal(t, 1, stats.OpenConnections) }) } func 
TestConnPingContextSuccess(t *testing.T) { testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { err := db.PingContext(context.Background()) require.NoError(t, err) }) } func TestConnPrepareContextSuccess(t *testing.T) { testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { stmt, err := db.PrepareContext(context.Background(), "select now()") require.NoError(t, err) err = stmt.Close() require.NoError(t, err) }) } func TestConnExecContextSuccess(t *testing.T) { testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { _, err := db.ExecContext(context.Background(), "create temporary table exec_context_test(id serial primary key)") require.NoError(t, err) }) } func TestConnExecContextFailureRetry(t *testing.T) { testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { // We get a connection, immediately close it, and then get it back; // DB.Conn along with Conn.ResetSession does the retry for us. { conn, err := stdlib.AcquireConn(db) require.NoError(t, err) conn.Close(context.Background()) stdlib.ReleaseConn(db, conn) } conn, err := db.Conn(context.Background()) require.NoError(t, err) _, err = conn.ExecContext(context.Background(), "select 1") require.NoError(t, err) }) } func TestConnQueryContextSuccess(t *testing.T) { testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { rows, err := db.QueryContext(context.Background(), "select * from generate_series(1,10) n") require.NoError(t, err) for rows.Next() { var n int64 err := rows.Scan(&n) require.NoError(t, err) } require.NoError(t, rows.Err()) }) } func TestConnQueryContextFailureRetry(t *testing.T) { testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { // We get a connection, immediately close it, and then get it back; // DB.Conn along with Conn.ResetSession does the retry for us. 
{ conn, err := stdlib.AcquireConn(db) require.NoError(t, err) conn.Close(context.Background()) stdlib.ReleaseConn(db, conn) } conn, err := db.Conn(context.Background()) require.NoError(t, err) _, err = conn.QueryContext(context.Background(), "select 1") require.NoError(t, err) }) } func TestRowsColumnTypeDatabaseTypeName(t *testing.T) { testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { rows, err := db.Query("select 42::bigint") require.NoError(t, err) columnTypes, err := rows.ColumnTypes() require.NoError(t, err) require.Len(t, columnTypes, 1) if columnTypes[0].DatabaseTypeName() != "INT8" { t.Errorf("columnTypes[0].DatabaseTypeName() => %v, want %v", columnTypes[0].DatabaseTypeName(), "INT8") } err = rows.Close() require.NoError(t, err) }) } func TestStmtExecContextSuccess(t *testing.T) { db := openDB(t) defer closeDB(t, db) _, err := db.Exec("create temporary table t(id int primary key)") require.NoError(t, err) stmt, err := db.Prepare("insert into t(id) values ($1::int4)") require.NoError(t, err) defer stmt.Close() _, err = stmt.ExecContext(context.Background(), 42) require.NoError(t, err) ensureDBValid(t, db) } func TestStmtExecContextCancel(t *testing.T) { db := openDB(t) defer closeDB(t, db) _, err := db.Exec("create temporary table t(id int primary key)") require.NoError(t, err) stmt, err := db.Prepare("insert into t(id) select $1::int4 from pg_sleep(5)") require.NoError(t, err) defer stmt.Close() ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() _, err = stmt.ExecContext(ctx, 42) if !pgconn.Timeout(err) { t.Errorf("expected timeout error, got %v", err) } ensureDBValid(t, db) } func TestStmtQueryContextSuccess(t *testing.T) { db := openDB(t) defer closeDB(t, db) skipCockroachDB(t, db, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)") stmt, err := db.Prepare("select * from generate_series(1,$1::int4) n") require.NoError(t, err) defer 
stmt.Close() rows, err := stmt.QueryContext(context.Background(), 5) require.NoError(t, err) for rows.Next() { var n int64 if err := rows.Scan(&n); err != nil { t.Error(err) } } if rows.Err() != nil { t.Error(rows.Err()) } ensureDBValid(t, db) } func TestRowsColumnTypes(t *testing.T) { testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { columnTypesTests := []struct { Name string TypeName string Length struct { Len int64 OK bool } DecimalSize struct { Precision int64 Scale int64 OK bool } ScanType reflect.Type }{ { Name: "a", TypeName: "INT8", Length: struct { Len int64 OK bool }{ Len: 0, OK: false, }, DecimalSize: struct { Precision int64 Scale int64 OK bool }{ Precision: 0, Scale: 0, OK: false, }, ScanType: reflect.TypeOf(int64(0)), }, { Name: "bar", TypeName: "TEXT", Length: struct { Len int64 OK bool }{ Len: math.MaxInt64, OK: true, }, DecimalSize: struct { Precision int64 Scale int64 OK bool }{ Precision: 0, Scale: 0, OK: false, }, ScanType: reflect.TypeOf(""), }, { Name: "dec", TypeName: "NUMERIC", Length: struct { Len int64 OK bool }{ Len: 0, OK: false, }, DecimalSize: struct { Precision int64 Scale int64 OK bool }{ Precision: 9, Scale: 2, OK: true, }, ScanType: reflect.TypeOf(float64(0)), }, { Name: "d", TypeName: "1266", Length: struct { Len int64 OK bool }{ Len: 0, OK: false, }, DecimalSize: struct { Precision int64 Scale int64 OK bool }{ Precision: 0, Scale: 0, OK: false, }, ScanType: reflect.TypeOf(""), }, } rows, err := db.Query("SELECT 1::bigint AS a, text 'bar' AS bar, 1.28::numeric(9, 2) AS dec, '12:00:00'::timetz as d") require.NoError(t, err) columns, err := rows.ColumnTypes() require.NoError(t, err) assert.Len(t, columns, 4) for i, tt := range columnTypesTests { c := columns[i] if c.Name() != tt.Name { t.Errorf("(%d) got: %s, want: %s", i, c.Name(), tt.Name) } if c.DatabaseTypeName() != tt.TypeName { t.Errorf("(%d) got: %s, want: %s", i, c.DatabaseTypeName(), tt.TypeName) } l, ok := c.Length() if l != tt.Length.Len { 
t.Errorf("(%d) got: %d, want: %d", i, l, tt.Length.Len) } if ok != tt.Length.OK { t.Errorf("(%d) got: %t, want: %t", i, ok, tt.Length.OK) } p, s, ok := c.DecimalSize() if p != tt.DecimalSize.Precision { t.Errorf("(%d) got: %d, want: %d", i, p, tt.DecimalSize.Precision) } if s != tt.DecimalSize.Scale { t.Errorf("(%d) got: %d, want: %d", i, s, tt.DecimalSize.Scale) } if ok != tt.DecimalSize.OK { t.Errorf("(%d) got: %t, want: %t", i, ok, tt.DecimalSize.OK) } if c.ScanType() != tt.ScanType { t.Errorf("(%d) got: %v, want: %v", i, c.ScanType(), tt.ScanType) } } }) } func TestQueryLifeCycle(t *testing.T) { testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { skipCockroachDB(t, db, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)") rows, err := db.Query("SELECT 'foo', n FROM generate_series($1::int, $2::int) n WHERE 3 = $3", 1, 10, 3) require.NoError(t, err) rowCount := int64(0) for rows.Next() { rowCount++ var ( s string n int64 ) err := rows.Scan(&s, &n) require.NoError(t, err) if s != "foo" { t.Errorf(`Expected "foo", received "%v"`, s) } if n != rowCount { t.Errorf("Expected %d, received %d", rowCount, n) } } require.NoError(t, rows.Err()) err = rows.Close() require.NoError(t, err) rows, err = db.Query("select 1 where false") require.NoError(t, err) rowCount = int64(0) for rows.Next() { rowCount++ } require.NoError(t, rows.Err()) require.EqualValues(t, 0, rowCount) err = rows.Close() require.NoError(t, err) }) } // https://github.com/jackc/pgx/issues/409 func TestScanJSONIntoJSONRawMessage(t *testing.T) { testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { var msg json.RawMessage err := db.QueryRow("select '{}'::json").Scan(&msg) require.NoError(t, err) require.EqualValues(t, []byte("{}"), []byte(msg)) }) } type testLog struct { lvl pgx.LogLevel msg string data map[string]interface{} } type testLogger struct { logs []testLog } func (l *testLogger) Log(ctx context.Context, 
lvl pgx.LogLevel, msg string, data map[string]interface{}) { l.logs = append(l.logs, testLog{lvl: lvl, msg: msg, data: data}) } func TestRegisterConnConfig(t *testing.T) { connConfig, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) logger := &testLogger{} connConfig.Logger = logger // Issue 947: Register and unregister a ConnConfig and ensure that the // returned connection string is not reused. connStr := stdlib.RegisterConnConfig(connConfig) require.Equal(t, "registeredConnConfig0", connStr) stdlib.UnregisterConnConfig(connStr) connStr = stdlib.RegisterConnConfig(connConfig) defer stdlib.UnregisterConnConfig(connStr) require.Equal(t, "registeredConnConfig1", connStr) db, err := sql.Open("pgx", connStr) require.NoError(t, err) defer closeDB(t, db) var n int64 err = db.QueryRow("select 1").Scan(&n) require.NoError(t, err) l := logger.logs[len(logger.logs)-1] assert.Equal(t, "Query", l.msg) assert.Equal(t, "select 1", l.data["sql"]) } // https://github.com/jackc/pgx/issues/958 func TestConnQueryRowConstraintErrors(t *testing.T) { testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { skipPostgreSQLVersion(t, db, "< 11", "Test requires PG 11+") skipCockroachDB(t, db, "Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)") _, err := db.Exec(`create temporary table defer_test ( id text primary key, n int not null, unique (n), unique (n) deferrable initially deferred )`) require.NoError(t, err) _, err = db.Exec(`drop function if exists test_trigger cascade`) require.NoError(t, err) _, err = db.Exec(`create function test_trigger() returns trigger language plpgsql as $$ begin if new.n = 4 then raise exception 'n cant be 4!'; end if; return new; end$$`) require.NoError(t, err) _, err = db.Exec(`create constraint trigger test after insert or update on defer_test deferrable initially deferred for each row execute function test_trigger()`) require.NoError(t, err) _, err = 
db.Exec(`insert into defer_test (id, n) values ('a', 1), ('b', 2), ('c', 3)`)
		require.NoError(t, err)

		var id string
		// The deferred constraint trigger fires at commit; QueryRow must surface that error.
		err = db.QueryRow(`insert into defer_test (id, n) values ('e', 4) returning id`).Scan(&id)
		assert.Error(t, err)
	})
}

// TestOptionBeforeAfterConnect verifies that the OptionBeforeConnect and
// OptionAfterConnect hooks each fire once per physical connection established
// by the database/sql pool.
func TestOptionBeforeAfterConnect(t *testing.T) {
	config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)

	var beforeConnConfigs []*pgx.ConnConfig
	var afterConns []*pgx.Conn
	db := stdlib.OpenDB(*config,
		stdlib.OptionBeforeConnect(func(ctx context.Context, connConfig *pgx.ConnConfig) error {
			beforeConnConfigs = append(beforeConnConfigs, connConfig)
			return nil
		}),
		stdlib.OptionAfterConnect(func(ctx context.Context, conn *pgx.Conn) error {
			afterConns = append(afterConns, conn)
			return nil
		}))
	defer closeDB(t, db)

	// Force it to close and reopen a new connection after each query
	db.SetMaxIdleConns(0)

	_, err = db.Exec("select 1")
	require.NoError(t, err)

	_, err = db.Exec("select 1")
	require.NoError(t, err)

	require.Len(t, beforeConnConfigs, 2)
	require.Len(t, afterConns, 2)

	// Note: BeforeConnect creates a shallow copy, so the config contents will be the same but we want to ensure they
	// are different objects, so can't use require.NotEqual
	require.False(t, config == beforeConnConfigs[0])
	require.False(t, beforeConnConfigs[0] == beforeConnConfigs[1])
}

// TestRandomizeHostOrderFunc checks that RandomizeHostOrderFunc eventually
// places every configured host first, and that every randomized configuration
// still contains all hosts (primary plus fallbacks).
func TestRandomizeHostOrderFunc(t *testing.T) {
	config, err := pgx.ParseConfig("postgres://host1,host2,host3")
	require.NoError(t, err)

	// Test that at some point we connect to all 3 hosts
	hostsNotSeenYet := map[string]struct{}{
		"host1": struct{}{},
		"host2": struct{}{},
		"host3": struct{}{},
	}

	// If we don't succeed within this many iterations, something is certainly wrong
	for i := 0; i < 100000; i++ {
		connCopy := *config
		stdlib.RandomizeHostOrderFunc(context.Background(), &connCopy)

		delete(hostsNotSeenYet, connCopy.Host)
		if len(hostsNotSeenYet) == 0 {
			return
		}

	hostCheckLoop:
		for _, h := range []string{"host1", "host2", "host3"} {
			if connCopy.Host == h {
				continue
			}
			for
_, f := range connCopy.Fallbacks { if f.Host == h { continue hostCheckLoop } } require.Failf(t, "got configuration from RandomizeHostOrderFunc that did not have all the hosts", "%+v", connCopy) } } require.Fail(t, "did not get all hosts as primaries after many randomizations") } func TestResetSessionHookCalled(t *testing.T) { var mockCalled bool connConfig, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) db := stdlib.OpenDB(*connConfig, stdlib.OptionResetSession(func(ctx context.Context, conn *pgx.Conn) error { mockCalled = true return nil })) defer closeDB(t, db) err = db.Ping() require.NoError(t, err) err = db.Ping() require.NoError(t, err) require.True(t, mockCalled) } pgx-4.18.1/tx.go000066400000000000000000000342401437725773200134110ustar00rootroot00000000000000package pgx import ( "bytes" "context" "errors" "fmt" "strconv" "github.com/jackc/pgconn" ) // TxIsoLevel is the transaction isolation level (serializable, repeatable read, read committed or read uncommitted) type TxIsoLevel string // Transaction isolation levels const ( Serializable TxIsoLevel = "serializable" RepeatableRead TxIsoLevel = "repeatable read" ReadCommitted TxIsoLevel = "read committed" ReadUncommitted TxIsoLevel = "read uncommitted" ) // TxAccessMode is the transaction access mode (read write or read only) type TxAccessMode string // Transaction access modes const ( ReadWrite TxAccessMode = "read write" ReadOnly TxAccessMode = "read only" ) // TxDeferrableMode is the transaction deferrable mode (deferrable or not deferrable) type TxDeferrableMode string // Transaction deferrable modes const ( Deferrable TxDeferrableMode = "deferrable" NotDeferrable TxDeferrableMode = "not deferrable" ) // TxOptions are transaction modes within a transaction block type TxOptions struct { IsoLevel TxIsoLevel AccessMode TxAccessMode DeferrableMode TxDeferrableMode } var emptyTxOptions TxOptions func (txOptions TxOptions) beginSQL() string { if txOptions == emptyTxOptions { 
return "begin" } buf := &bytes.Buffer{} buf.WriteString("begin") if txOptions.IsoLevel != "" { fmt.Fprintf(buf, " isolation level %s", txOptions.IsoLevel) } if txOptions.AccessMode != "" { fmt.Fprintf(buf, " %s", txOptions.AccessMode) } if txOptions.DeferrableMode != "" { fmt.Fprintf(buf, " %s", txOptions.DeferrableMode) } return buf.String() } var ErrTxClosed = errors.New("tx is closed") // ErrTxCommitRollback occurs when an error has occurred in a transaction and // Commit() is called. PostgreSQL accepts COMMIT on aborted transactions, but // it is treated as ROLLBACK. var ErrTxCommitRollback = errors.New("commit unexpectedly resulted in rollback") // Begin starts a transaction. Unlike database/sql, the context only affects the begin command. i.e. there is no // auto-rollback on context cancellation. func (c *Conn) Begin(ctx context.Context) (Tx, error) { return c.BeginTx(ctx, TxOptions{}) } // BeginTx starts a transaction with txOptions determining the transaction mode. Unlike database/sql, the context only // affects the begin command. i.e. there is no auto-rollback on context cancellation. func (c *Conn) BeginTx(ctx context.Context, txOptions TxOptions) (Tx, error) { _, err := c.Exec(ctx, txOptions.beginSQL()) if err != nil { // begin should never fail unless there is an underlying connection issue or // a context timeout. In either case, the connection is possibly broken. c.die(errors.New("failed to begin transaction")) return nil, err } return &dbTx{conn: c}, nil } // BeginFunc starts a transaction and calls f. If f does not return an error the transaction is committed. If f returns // an error the transaction is rolled back. The context will be used when executing the transaction control statements // (BEGIN, ROLLBACK, and COMMIT) but does not otherwise affect the execution of f. 
func (c *Conn) BeginFunc(ctx context.Context, f func(Tx) error) (err error) { return c.BeginTxFunc(ctx, TxOptions{}, f) } // BeginTxFunc starts a transaction with txOptions determining the transaction mode and calls f. If f does not return // an error the transaction is committed. If f returns an error the transaction is rolled back. The context will be // used when executing the transaction control statements (BEGIN, ROLLBACK, and COMMIT) but does not otherwise affect // the execution of f. func (c *Conn) BeginTxFunc(ctx context.Context, txOptions TxOptions, f func(Tx) error) (err error) { var tx Tx tx, err = c.BeginTx(ctx, txOptions) if err != nil { return err } defer func() { rollbackErr := tx.Rollback(ctx) if rollbackErr != nil && !errors.Is(rollbackErr, ErrTxClosed) { err = rollbackErr } }() fErr := f(tx) if fErr != nil { _ = tx.Rollback(ctx) // ignore rollback error as there is already an error to return return fErr } return tx.Commit(ctx) } // Tx represents a database transaction. // // Tx is an interface instead of a struct to enable connection pools to be implemented without relying on internal pgx // state, to support pseudo-nested transactions with savepoints, and to allow tests to mock transactions. However, // adding a method to an interface is technically a breaking change. If new methods are added to Conn it may be // desirable to add them to Tx as well. Because of this the Tx interface is partially excluded from semantic version // requirements. Methods will not be removed or changed, but new methods may be added. type Tx interface { // Begin starts a pseudo nested transaction. Begin(ctx context.Context) (Tx, error) // BeginFunc starts a pseudo nested transaction and executes f. If f does not return an err the pseudo nested // transaction will be committed. If it does then it will be rolled back. 
	BeginFunc(ctx context.Context, f func(Tx) error) (err error)
	// Commit commits the transaction if this is a real transaction or releases the savepoint if this is a pseudo nested
	// transaction. Commit will return ErrTxClosed if the Tx is already closed, but is otherwise safe to call multiple
	// times. If the commit fails with a rollback status (e.g. the transaction was already in a broken state) then
	// ErrTxCommitRollback will be returned.
	Commit(ctx context.Context) error
	// Rollback rolls back the transaction if this is a real transaction or rolls back to the savepoint if this is a
	// pseudo nested transaction. Rollback will return ErrTxClosed if the Tx is already closed, but is otherwise safe to
	// call multiple times. Hence, a defer tx.Rollback() is safe even if tx.Commit() will be called first in a non-error
	// condition. Any other failure of a real transaction will result in the connection being closed.
	Rollback(ctx context.Context) error

	CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error)
	SendBatch(ctx context.Context, b *Batch) BatchResults
	LargeObjects() LargeObjects

	Prepare(ctx context.Context, name, sql string) (*pgconn.StatementDescription, error)

	Exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error)
	Query(ctx context.Context, sql string, args ...interface{}) (Rows, error)
	QueryRow(ctx context.Context, sql string, args ...interface{}) Row
	QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error)

	// Conn returns the underlying *Conn on which this transaction is executing.
	Conn() *Conn
}

// dbTx represents a database transaction.
//
// All dbTx methods return ErrTxClosed if Commit or Rollback has already been
// called on the dbTx.
type dbTx struct {
	conn         *Conn // connection the transaction runs on
	err          error
	savepointNum int64 // monotonically increasing counter used to name savepoints
	closed       bool  // set by Commit/Rollback; guards against reuse
}

// Begin starts a pseudo nested transaction implemented with a savepoint.
func (tx *dbTx) Begin(ctx context.Context) (Tx, error) {
	if tx.closed {
		return nil, ErrTxClosed
	}

	// Each nested level gets a unique savepoint name (sp_1, sp_2, ...).
	tx.savepointNum++
	_, err := tx.conn.Exec(ctx, "savepoint sp_"+strconv.FormatInt(tx.savepointNum, 10))
	if err != nil {
		return nil, err
	}

	return &dbSimulatedNestedTx{tx: tx, savepointNum: tx.savepointNum}, nil
}

// BeginFunc starts a pseudo nested transaction (savepoint) and calls f. If f
// returns nil the savepoint is released (committed); if f returns an error the
// transaction is rolled back to the savepoint and that error is returned.
func (tx *dbTx) BeginFunc(ctx context.Context, f func(Tx) error) (err error) {
	if tx.closed {
		return ErrTxClosed
	}

	var savepoint Tx
	savepoint, err = tx.Begin(ctx)
	if err != nil {
		return err
	}
	// Safety net: if f panics or Commit below fails, roll back to the
	// savepoint. ErrTxClosed here just means the savepoint already completed.
	defer func() {
		rollbackErr := savepoint.Rollback(ctx)
		if rollbackErr != nil && !errors.Is(rollbackErr, ErrTxClosed) {
			err = rollbackErr
		}
	}()

	fErr := f(savepoint)
	if fErr != nil {
		_ = savepoint.Rollback(ctx) // ignore rollback error as there is already an error to return
		return fErr
	}

	return savepoint.Commit(ctx)
}

// Commit commits the transaction.
func (tx *dbTx) Commit(ctx context.Context) error {
	if tx.closed {
		return ErrTxClosed
	}

	commandTag, err := tx.conn.Exec(ctx, "commit")
	// The transaction is considered closed regardless of the commit outcome.
	tx.closed = true
	if err != nil {
		// 'I' is the protocol's idle status; if the connection is still inside
		// a transaction block after a failed commit its state is suspect, so
		// close it.
		if tx.conn.PgConn().TxStatus() != 'I' {
			_ = tx.conn.Close(ctx) // already have error to return
		}
		return err
	}
	// PostgreSQL accepts COMMIT on an aborted transaction but reports ROLLBACK.
	if string(commandTag) == "ROLLBACK" {
		return ErrTxCommitRollback
	}

	return nil
}

// Rollback rolls back the transaction. Rollback will return ErrTxClosed if the
// Tx is already closed, but is otherwise safe to call multiple times. Hence, a
// defer tx.Rollback() is safe even if tx.Commit() will be called first in a
// non-error condition.
func (tx *dbTx) Rollback(ctx context.Context) error {
	if tx.closed {
		return ErrTxClosed
	}

	_, err := tx.conn.Exec(ctx, "rollback")
	// The transaction is closed even if the rollback command failed.
	tx.closed = true
	if err != nil {
		// A rollback failure leaves the connection in an undefined state
		tx.conn.die(fmt.Errorf("rollback failed: %w", err))
		return err
	}

	return nil
}

// Exec delegates to the underlying *Conn
// NOTE(review): unlike Prepare/Query below, Exec performs no tx.closed guard,
// so an Exec on an already committed/rolled-back dbTx runs directly on the
// connection outside any transaction — confirm whether this is intentional.
func (tx *dbTx) Exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) {
	return tx.conn.Exec(ctx, sql, arguments...)
}

// Prepare delegates to the underlying *Conn
func (tx *dbTx) Prepare(ctx context.Context, name, sql string) (*pgconn.StatementDescription, error) {
	if tx.closed {
		return nil, ErrTxClosed
	}

	return tx.conn.Prepare(ctx, name, sql)
}

// Query delegates to the underlying *Conn
func (tx *dbTx) Query(ctx context.Context, sql string, args ...interface{}) (Rows, error) {
	if tx.closed {
		// Because checking for errors can be deferred to the *Rows, build one with the error
		err := ErrTxClosed
		return &connRows{closed: true, err: err}, err
	}

	return tx.conn.Query(ctx, sql, args...)
}

// QueryRow delegates to the underlying *Conn
func (tx *dbTx) QueryRow(ctx context.Context, sql string, args ...interface{}) Row {
	// Query only fails here via the closed guard above; the error is carried
	// inside the returned rows and surfaced by Row.Scan.
	rows, _ := tx.Query(ctx, sql, args...)
	return (*connRow)(rows.(*connRows))
}

// QueryFunc delegates to the underlying *Conn.
func (tx *dbTx) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) { if tx.closed { return nil, ErrTxClosed } return tx.conn.QueryFunc(ctx, sql, args, scans, f) } // CopyFrom delegates to the underlying *Conn func (tx *dbTx) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error) { if tx.closed { return 0, ErrTxClosed } return tx.conn.CopyFrom(ctx, tableName, columnNames, rowSrc) } // SendBatch delegates to the underlying *Conn func (tx *dbTx) SendBatch(ctx context.Context, b *Batch) BatchResults { if tx.closed { return &batchResults{err: ErrTxClosed} } return tx.conn.SendBatch(ctx, b) } // LargeObjects returns a LargeObjects instance for the transaction. func (tx *dbTx) LargeObjects() LargeObjects { return LargeObjects{tx: tx} } func (tx *dbTx) Conn() *Conn { return tx.conn } // dbSimulatedNestedTx represents a simulated nested transaction implemented by a savepoint. type dbSimulatedNestedTx struct { tx Tx savepointNum int64 closed bool } // Begin starts a pseudo nested transaction implemented with a savepoint. func (sp *dbSimulatedNestedTx) Begin(ctx context.Context) (Tx, error) { if sp.closed { return nil, ErrTxClosed } return sp.tx.Begin(ctx) } func (sp *dbSimulatedNestedTx) BeginFunc(ctx context.Context, f func(Tx) error) (err error) { if sp.closed { return ErrTxClosed } return sp.tx.BeginFunc(ctx, f) } // Commit releases the savepoint essentially committing the pseudo nested transaction. func (sp *dbSimulatedNestedTx) Commit(ctx context.Context) error { if sp.closed { return ErrTxClosed } _, err := sp.Exec(ctx, "release savepoint sp_"+strconv.FormatInt(sp.savepointNum, 10)) sp.closed = true return err } // Rollback rolls back to the savepoint essentially rolling back the pseudo nested transaction. Rollback will return // ErrTxClosed if the dbSavepoint is already closed, but is otherwise safe to call multiple times. 
Hence, a defer sp.Rollback() // is safe even if sp.Commit() will be called first in a non-error condition. func (sp *dbSimulatedNestedTx) Rollback(ctx context.Context) error { if sp.closed { return ErrTxClosed } _, err := sp.Exec(ctx, "rollback to savepoint sp_"+strconv.FormatInt(sp.savepointNum, 10)) sp.closed = true return err } // Exec delegates to the underlying Tx func (sp *dbSimulatedNestedTx) Exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) { if sp.closed { return nil, ErrTxClosed } return sp.tx.Exec(ctx, sql, arguments...) } // Prepare delegates to the underlying Tx func (sp *dbSimulatedNestedTx) Prepare(ctx context.Context, name, sql string) (*pgconn.StatementDescription, error) { if sp.closed { return nil, ErrTxClosed } return sp.tx.Prepare(ctx, name, sql) } // Query delegates to the underlying Tx func (sp *dbSimulatedNestedTx) Query(ctx context.Context, sql string, args ...interface{}) (Rows, error) { if sp.closed { // Because checking for errors can be deferred to the *Rows, build one with the error err := ErrTxClosed return &connRows{closed: true, err: err}, err } return sp.tx.Query(ctx, sql, args...) } // QueryRow delegates to the underlying Tx func (sp *dbSimulatedNestedTx) QueryRow(ctx context.Context, sql string, args ...interface{}) Row { rows, _ := sp.Query(ctx, sql, args...) return (*connRow)(rows.(*connRows)) } // QueryFunc delegates to the underlying Tx. 
func (sp *dbSimulatedNestedTx) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) { if sp.closed { return nil, ErrTxClosed } return sp.tx.QueryFunc(ctx, sql, args, scans, f) } // CopyFrom delegates to the underlying *Conn func (sp *dbSimulatedNestedTx) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error) { if sp.closed { return 0, ErrTxClosed } return sp.tx.CopyFrom(ctx, tableName, columnNames, rowSrc) } // SendBatch delegates to the underlying *Conn func (sp *dbSimulatedNestedTx) SendBatch(ctx context.Context, b *Batch) BatchResults { if sp.closed { return &batchResults{err: ErrTxClosed} } return sp.tx.SendBatch(ctx, b) } func (sp *dbSimulatedNestedTx) LargeObjects() LargeObjects { return LargeObjects{tx: sp} } func (sp *dbSimulatedNestedTx) Conn() *Conn { return sp.tx.Conn() } pgx-4.18.1/tx_test.go000066400000000000000000000357331437725773200144600ustar00rootroot00000000000000package pgx_test import ( "context" "errors" "os" "testing" "time" "github.com/jackc/pgconn" "github.com/jackc/pgx/v4" "github.com/stretchr/testify/require" ) func TestTransactionSuccessfulCommit(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) createSql := ` create temporary table foo( id integer, unique (id) ); ` if _, err := conn.Exec(context.Background(), createSql); err != nil { t.Fatalf("Failed to create table: %v", err) } tx, err := conn.Begin(context.Background()) if err != nil { t.Fatalf("conn.Begin failed: %v", err) } _, err = tx.Exec(context.Background(), "insert into foo(id) values (1)") if err != nil { t.Fatalf("tx.Exec failed: %v", err) } err = tx.Commit(context.Background()) if err != nil { t.Fatalf("tx.Commit failed: %v", err) } var n int64 err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n) if err != nil { t.Fatalf("QueryRow Scan failed: 
%v", err) } if n != 1 { t.Fatalf("Did not receive correct number of rows: %v", n) } } func TestTxCommitWhenTxBroken(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) createSql := ` create temporary table foo( id integer, unique (id) ); ` if _, err := conn.Exec(context.Background(), createSql); err != nil { t.Fatalf("Failed to create table: %v", err) } tx, err := conn.Begin(context.Background()) if err != nil { t.Fatalf("conn.Begin failed: %v", err) } if _, err := tx.Exec(context.Background(), "insert into foo(id) values (1)"); err != nil { t.Fatalf("tx.Exec failed: %v", err) } // Purposely break transaction if _, err := tx.Exec(context.Background(), "syntax error"); err == nil { t.Fatal("Unexpected success") } err = tx.Commit(context.Background()) if err != pgx.ErrTxCommitRollback { t.Fatalf("Expected error %v, got %v", pgx.ErrTxCommitRollback, err) } var n int64 err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n) if err != nil { t.Fatalf("QueryRow Scan failed: %v", err) } if n != 0 { t.Fatalf("Did not receive correct number of rows: %v", n) } } func TestTxCommitWhenDeferredConstraintFailure(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) skipCockroachDB(t, conn, "Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)") createSql := ` create temporary table foo( id integer, unique (id) initially deferred ); ` if _, err := conn.Exec(context.Background(), createSql); err != nil { t.Fatalf("Failed to create table: %v", err) } tx, err := conn.Begin(context.Background()) if err != nil { t.Fatalf("conn.Begin failed: %v", err) } if _, err := tx.Exec(context.Background(), "insert into foo(id) values (1)"); err != nil { t.Fatalf("tx.Exec failed: %v", err) } if _, err := tx.Exec(context.Background(), "insert into foo(id) values (1)"); err != nil { t.Fatalf("tx.Exec failed: 
%v", err) } err = tx.Commit(context.Background()) if pgErr, ok := err.(*pgconn.PgError); !ok || pgErr.Code != "23505" { t.Fatalf("Expected unique constraint violation 23505, got %#v", err) } var n int64 err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n) if err != nil { t.Fatalf("QueryRow Scan failed: %v", err) } if n != 0 { t.Fatalf("Did not receive correct number of rows: %v", n) } } func TestTxCommitSerializationFailure(t *testing.T) { t.Parallel() c1 := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, c1) if c1.PgConn().ParameterStatus("crdb_version") != "" { t.Skip("Skipping due to known server issue: (https://github.com/cockroachdb/cockroach/issues/60754)") } c2 := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, c2) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() c1.Exec(ctx, `drop table if exists tx_serializable_sums`) _, err := c1.Exec(ctx, `create table tx_serializable_sums(num integer);`) if err != nil { t.Fatalf("Unable to create temporary table: %v", err) } defer c1.Exec(ctx, `drop table tx_serializable_sums`) tx1, err := c1.BeginTx(ctx, pgx.TxOptions{IsoLevel: pgx.Serializable}) if err != nil { t.Fatalf("Begin failed: %v", err) } defer tx1.Rollback(ctx) tx2, err := c2.BeginTx(ctx, pgx.TxOptions{IsoLevel: pgx.Serializable}) if err != nil { t.Fatalf("Begin failed: %v", err) } defer tx2.Rollback(ctx) _, err = tx1.Exec(ctx, `insert into tx_serializable_sums(num) select sum(num)::int from tx_serializable_sums`) if err != nil { t.Fatalf("Exec failed: %v", err) } _, err = tx2.Exec(ctx, `insert into tx_serializable_sums(num) select sum(num)::int from tx_serializable_sums`) if err != nil { t.Fatalf("Exec failed: %v", err) } err = tx1.Commit(ctx) if err != nil { t.Fatalf("Commit failed: %v", err) } err = tx2.Commit(ctx) if pgErr, ok := err.(*pgconn.PgError); !ok || pgErr.Code != "40001" { t.Fatalf("Expected serialization error 40001, got %#v", err) } 
ensureConnValid(t, c1) ensureConnValid(t, c2) } func TestTransactionSuccessfulRollback(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) createSql := ` create temporary table foo( id integer, unique (id) ); ` if _, err := conn.Exec(context.Background(), createSql); err != nil { t.Fatalf("Failed to create table: %v", err) } tx, err := conn.Begin(context.Background()) if err != nil { t.Fatalf("conn.Begin failed: %v", err) } _, err = tx.Exec(context.Background(), "insert into foo(id) values (1)") if err != nil { t.Fatalf("tx.Exec failed: %v", err) } err = tx.Rollback(context.Background()) if err != nil { t.Fatalf("tx.Rollback failed: %v", err) } var n int64 err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n) if err != nil { t.Fatalf("QueryRow Scan failed: %v", err) } if n != 0 { t.Fatalf("Did not receive correct number of rows: %v", n) } } func TestTransactionRollbackFailsClosesConnection(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) ctx, cancel := context.WithCancel(context.Background()) tx, err := conn.Begin(ctx) require.NoError(t, err) cancel() err = tx.Rollback(ctx) require.Error(t, err) require.True(t, conn.IsClosed()) } func TestBeginIsoLevels(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) skipCockroachDB(t, conn, "Server always uses SERIALIZABLE isolation (https://www.cockroachlabs.com/docs/stable/demo-serializable.html)") isoLevels := []pgx.TxIsoLevel{pgx.Serializable, pgx.RepeatableRead, pgx.ReadCommitted, pgx.ReadUncommitted} for _, iso := range isoLevels { tx, err := conn.BeginTx(context.Background(), pgx.TxOptions{IsoLevel: iso}) if err != nil { t.Fatalf("conn.Begin failed: %v", err) } var level pgx.TxIsoLevel conn.QueryRow(context.Background(), "select current_setting('transaction_isolation')").Scan(&level) if level != iso { 
t.Errorf("Expected to be in isolation level %v but was %v", iso, level) } err = tx.Rollback(context.Background()) if err != nil { t.Fatalf("tx.Rollback failed: %v", err) } } } func TestBeginFunc(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) createSql := ` create temporary table foo( id integer, unique (id) ); ` _, err := conn.Exec(context.Background(), createSql) require.NoError(t, err) err = conn.BeginFunc(context.Background(), func(tx pgx.Tx) error { _, err := tx.Exec(context.Background(), "insert into foo(id) values (1)") require.NoError(t, err) return nil }) require.NoError(t, err) var n int64 err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n) require.NoError(t, err) require.EqualValues(t, 1, n) } func TestBeginFuncRollbackOnError(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) createSql := ` create temporary table foo( id integer, unique (id) ); ` _, err := conn.Exec(context.Background(), createSql) require.NoError(t, err) err = conn.BeginFunc(context.Background(), func(tx pgx.Tx) error { _, err := tx.Exec(context.Background(), "insert into foo(id) values (1)") require.NoError(t, err) return errors.New("some error") }) require.EqualError(t, err, "some error") var n int64 err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n) require.NoError(t, err) require.EqualValues(t, 0, n) } func TestBeginReadOnly(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) tx, err := conn.BeginTx(context.Background(), pgx.TxOptions{AccessMode: pgx.ReadOnly}) if err != nil { t.Fatalf("conn.Begin failed: %v", err) } defer tx.Rollback(context.Background()) _, err = conn.Exec(context.Background(), "create table foo(id serial primary key)") if pgErr, ok := err.(*pgconn.PgError); !ok || pgErr.Code != "25006" { t.Errorf("Expected error 
SQLSTATE 25006, but got %#v", err) } } func TestTxNestedTransactionCommit(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) createSql := ` create temporary table foo( id integer, unique (id) ); ` if _, err := conn.Exec(context.Background(), createSql); err != nil { t.Fatalf("Failed to create table: %v", err) } tx, err := conn.Begin(context.Background()) if err != nil { t.Fatal(err) } _, err = tx.Exec(context.Background(), "insert into foo(id) values (1)") if err != nil { t.Fatalf("tx.Exec failed: %v", err) } nestedTx, err := tx.Begin(context.Background()) if err != nil { t.Fatal(err) } _, err = nestedTx.Exec(context.Background(), "insert into foo(id) values (2)") if err != nil { t.Fatalf("nestedTx.Exec failed: %v", err) } doubleNestedTx, err := nestedTx.Begin(context.Background()) if err != nil { t.Fatal(err) } _, err = doubleNestedTx.Exec(context.Background(), "insert into foo(id) values (3)") if err != nil { t.Fatalf("doubleNestedTx.Exec failed: %v", err) } err = doubleNestedTx.Commit(context.Background()) if err != nil { t.Fatalf("doubleNestedTx.Commit failed: %v", err) } err = nestedTx.Commit(context.Background()) if err != nil { t.Fatalf("nestedTx.Commit failed: %v", err) } err = tx.Commit(context.Background()) if err != nil { t.Fatalf("tx.Commit failed: %v", err) } var n int64 err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n) if err != nil { t.Fatalf("QueryRow Scan failed: %v", err) } if n != 3 { t.Fatalf("Did not receive correct number of rows: %v", n) } } func TestTxNestedTransactionRollback(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) createSql := ` create temporary table foo( id integer, unique (id) ); ` if _, err := conn.Exec(context.Background(), createSql); err != nil { t.Fatalf("Failed to create table: %v", err) } tx, err := conn.Begin(context.Background()) if err != nil { t.Fatal(err) } _, 
err = tx.Exec(context.Background(), "insert into foo(id) values (1)") if err != nil { t.Fatalf("tx.Exec failed: %v", err) } nestedTx, err := tx.Begin(context.Background()) if err != nil { t.Fatal(err) } _, err = nestedTx.Exec(context.Background(), "insert into foo(id) values (2)") if err != nil { t.Fatalf("nestedTx.Exec failed: %v", err) } err = nestedTx.Rollback(context.Background()) if err != nil { t.Fatalf("nestedTx.Rollback failed: %v", err) } _, err = tx.Exec(context.Background(), "insert into foo(id) values (3)") if err != nil { t.Fatalf("tx.Exec failed: %v", err) } err = tx.Commit(context.Background()) if err != nil { t.Fatalf("tx.Commit failed: %v", err) } var n int64 err = conn.QueryRow(context.Background(), "select count(*) from foo").Scan(&n) if err != nil { t.Fatalf("QueryRow Scan failed: %v", err) } if n != 2 { t.Fatalf("Did not receive correct number of rows: %v", n) } } func TestTxBeginFuncNestedTransactionCommit(t *testing.T) { t.Parallel() db := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, db) createSql := ` create temporary table foo( id integer, unique (id) ); ` _, err := db.Exec(context.Background(), createSql) require.NoError(t, err) err = db.BeginFunc(context.Background(), func(db pgx.Tx) error { _, err := db.Exec(context.Background(), "insert into foo(id) values (1)") require.NoError(t, err) err = db.BeginFunc(context.Background(), func(db pgx.Tx) error { _, err := db.Exec(context.Background(), "insert into foo(id) values (2)") require.NoError(t, err) err = db.BeginFunc(context.Background(), func(db pgx.Tx) error { _, err := db.Exec(context.Background(), "insert into foo(id) values (3)") require.NoError(t, err) return nil }) return nil }) require.NoError(t, err) return nil }) require.NoError(t, err) var n int64 err = db.QueryRow(context.Background(), "select count(*) from foo").Scan(&n) require.NoError(t, err) require.EqualValues(t, 3, n) } func TestTxBeginFuncNestedTransactionRollback(t *testing.T) { t.Parallel() 
db := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, db) createSql := ` create temporary table foo( id integer, unique (id) ); ` _, err := db.Exec(context.Background(), createSql) require.NoError(t, err) err = db.BeginFunc(context.Background(), func(db pgx.Tx) error { _, err := db.Exec(context.Background(), "insert into foo(id) values (1)") require.NoError(t, err) err = db.BeginFunc(context.Background(), func(db pgx.Tx) error { _, err := db.Exec(context.Background(), "insert into foo(id) values (2)") require.NoError(t, err) return errors.New("do a rollback") }) require.EqualError(t, err, "do a rollback") _, err = db.Exec(context.Background(), "insert into foo(id) values (3)") require.NoError(t, err) return nil }) var n int64 err = db.QueryRow(context.Background(), "select count(*) from foo").Scan(&n) require.NoError(t, err) require.EqualValues(t, 2, n) } func TestTxSendBatchClosed(t *testing.T) { t.Parallel() db := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, db) tx, err := db.Begin(context.Background()) require.NoError(t, err) defer tx.Rollback(context.Background()) err = tx.Commit(context.Background()) require.NoError(t, err) batch := &pgx.Batch{} batch.Queue("select 1") batch.Queue("select 2") batch.Queue("select 3") br := tx.SendBatch(context.Background(), batch) defer br.Close() var n int _, err = br.Exec() require.Error(t, err) err = br.QueryRow().Scan(&n) require.Error(t, err) _, err = br.Query() require.Error(t, err) } pgx-4.18.1/values.go000066400000000000000000000154451437725773200142630ustar00rootroot00000000000000package pgx import ( "database/sql/driver" "fmt" "math" "reflect" "time" "github.com/jackc/pgio" "github.com/jackc/pgtype" ) // PostgreSQL format codes const ( TextFormatCode = 0 BinaryFormatCode = 1 ) // SerializationError occurs on failure to encode or decode a value type SerializationError string func (e SerializationError) Error() string { return string(e) } func convertSimpleArgument(ci 
*pgtype.ConnInfo, arg interface{}) (interface{}, error) { if arg == nil { return nil, nil } refVal := reflect.ValueOf(arg) if refVal.Kind() == reflect.Ptr && refVal.IsNil() { return nil, nil } switch arg := arg.(type) { // https://github.com/jackc/pgx/issues/409 Changed JSON and JSONB to surface // []byte to database/sql instead of string. But that caused problems with the // simple protocol because the driver.Valuer case got taken before the // pgtype.TextEncoder case. And driver.Valuer needed to be first in the usual // case because of https://github.com/jackc/pgx/issues/339. So instead we // special case JSON and JSONB. case *pgtype.JSON: buf, err := arg.EncodeText(ci, nil) if err != nil { return nil, err } if buf == nil { return nil, nil } return string(buf), nil case *pgtype.JSONB: buf, err := arg.EncodeText(ci, nil) if err != nil { return nil, err } if buf == nil { return nil, nil } return string(buf), nil case driver.Valuer: return callValuerValue(arg) case pgtype.TextEncoder: buf, err := arg.EncodeText(ci, nil) if err != nil { return nil, err } if buf == nil { return nil, nil } return string(buf), nil case float32: return float64(arg), nil case float64: return arg, nil case bool: return arg, nil case time.Duration: return fmt.Sprintf("%d microsecond", int64(arg)/1000), nil case time.Time: return arg, nil case string: return arg, nil case []byte: return arg, nil case int8: return int64(arg), nil case int16: return int64(arg), nil case int32: return int64(arg), nil case int64: return arg, nil case int: return int64(arg), nil case uint8: return int64(arg), nil case uint16: return int64(arg), nil case uint32: return int64(arg), nil case uint64: if arg > math.MaxInt64 { return nil, fmt.Errorf("arg too big for int64: %v", arg) } return int64(arg), nil case uint: if uint64(arg) > math.MaxInt64 { return nil, fmt.Errorf("arg too big for int64: %v", arg) } return int64(arg), nil } if dt, found := ci.DataTypeForValue(arg); found { v := dt.Value err := v.Set(arg) if 
err != nil { return nil, err } buf, err := v.(pgtype.TextEncoder).EncodeText(ci, nil) if err != nil { return nil, err } if buf == nil { return nil, nil } return string(buf), nil } if refVal.Kind() == reflect.Ptr { arg = refVal.Elem().Interface() return convertSimpleArgument(ci, arg) } if strippedArg, ok := stripNamedType(&refVal); ok { return convertSimpleArgument(ci, strippedArg) } return nil, SerializationError(fmt.Sprintf("Cannot encode %T in simple protocol - %T must implement driver.Valuer, pgtype.TextEncoder, or be a native type", arg, arg)) } func encodePreparedStatementArgument(ci *pgtype.ConnInfo, buf []byte, oid uint32, arg interface{}) ([]byte, error) { if arg == nil { return pgio.AppendInt32(buf, -1), nil } switch arg := arg.(type) { case pgtype.BinaryEncoder: sp := len(buf) buf = pgio.AppendInt32(buf, -1) argBuf, err := arg.EncodeBinary(ci, buf) if err != nil { return nil, err } if argBuf != nil { buf = argBuf pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } return buf, nil case pgtype.TextEncoder: sp := len(buf) buf = pgio.AppendInt32(buf, -1) argBuf, err := arg.EncodeText(ci, buf) if err != nil { return nil, err } if argBuf != nil { buf = argBuf pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } return buf, nil case string: buf = pgio.AppendInt32(buf, int32(len(arg))) buf = append(buf, arg...) 
return buf, nil } refVal := reflect.ValueOf(arg) if refVal.Kind() == reflect.Ptr { if refVal.IsNil() { return pgio.AppendInt32(buf, -1), nil } arg = refVal.Elem().Interface() return encodePreparedStatementArgument(ci, buf, oid, arg) } if dt, ok := ci.DataTypeForOID(oid); ok { value := dt.Value err := value.Set(arg) if err != nil { { if arg, ok := arg.(driver.Valuer); ok { v, err := callValuerValue(arg) if err != nil { return nil, err } return encodePreparedStatementArgument(ci, buf, oid, v) } } return nil, err } sp := len(buf) buf = pgio.AppendInt32(buf, -1) argBuf, err := value.(pgtype.BinaryEncoder).EncodeBinary(ci, buf) if err != nil { return nil, err } if argBuf != nil { buf = argBuf pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } return buf, nil } if strippedArg, ok := stripNamedType(&refVal); ok { return encodePreparedStatementArgument(ci, buf, oid, strippedArg) } return nil, SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg)) } // chooseParameterFormatCode determines the correct format code for an // argument to a prepared statement. It defaults to TextFormatCode if no // determination can be made. 
func chooseParameterFormatCode(ci *pgtype.ConnInfo, oid uint32, arg interface{}) int16 { switch arg := arg.(type) { case pgtype.ParamFormatPreferrer: return arg.PreferredParamFormat() case pgtype.BinaryEncoder: return BinaryFormatCode case string, *string, pgtype.TextEncoder: return TextFormatCode } return ci.ParamFormatCodeForOID(oid) } func stripNamedType(val *reflect.Value) (interface{}, bool) { switch val.Kind() { case reflect.Int: convVal := int(val.Int()) return convVal, reflect.TypeOf(convVal) != val.Type() case reflect.Int8: convVal := int8(val.Int()) return convVal, reflect.TypeOf(convVal) != val.Type() case reflect.Int16: convVal := int16(val.Int()) return convVal, reflect.TypeOf(convVal) != val.Type() case reflect.Int32: convVal := int32(val.Int()) return convVal, reflect.TypeOf(convVal) != val.Type() case reflect.Int64: convVal := int64(val.Int()) return convVal, reflect.TypeOf(convVal) != val.Type() case reflect.Uint: convVal := uint(val.Uint()) return convVal, reflect.TypeOf(convVal) != val.Type() case reflect.Uint8: convVal := uint8(val.Uint()) return convVal, reflect.TypeOf(convVal) != val.Type() case reflect.Uint16: convVal := uint16(val.Uint()) return convVal, reflect.TypeOf(convVal) != val.Type() case reflect.Uint32: convVal := uint32(val.Uint()) return convVal, reflect.TypeOf(convVal) != val.Type() case reflect.Uint64: convVal := uint64(val.Uint()) return convVal, reflect.TypeOf(convVal) != val.Type() case reflect.String: convVal := val.String() return convVal, reflect.TypeOf(convVal) != val.Type() } return nil, false } pgx-4.18.1/values_test.go000066400000000000000000000747001437725773200153210ustar00rootroot00000000000000package pgx_test import ( "bytes" "context" "net" "os" "reflect" "strings" "testing" "time" "github.com/jackc/pgx/v4" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestDateTranscode(t *testing.T) { t.Parallel() testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { 
dates := []time.Time{ time.Date(1, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(1000, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(1600, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(1700, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(1800, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(1990, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2001, 1, 2, 0, 0, 0, 0, time.UTC), time.Date(2004, 2, 29, 0, 0, 0, 0, time.UTC), time.Date(2013, 7, 4, 0, 0, 0, 0, time.UTC), time.Date(2013, 12, 25, 0, 0, 0, 0, time.UTC), time.Date(2029, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2081, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2096, 2, 29, 0, 0, 0, 0, time.UTC), time.Date(2550, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(9999, 12, 31, 0, 0, 0, 0, time.UTC), } for _, actualDate := range dates { var d time.Time err := conn.QueryRow(context.Background(), "select $1::date", actualDate).Scan(&d) if err != nil { t.Fatalf("Unexpected failure on QueryRow Scan: %v", err) } if !actualDate.Equal(d) { t.Errorf("Did not transcode date successfully: %v is not %v", d, actualDate) } } }) } func TestTimestampTzTranscode(t *testing.T) { t.Parallel() testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { inputTime := time.Date(2013, 1, 2, 3, 4, 5, 6000, time.Local) var outputTime time.Time err := conn.QueryRow(context.Background(), "select $1::timestamptz", inputTime).Scan(&outputTime) if err != nil { t.Fatalf("QueryRow Scan failed: %v", err) } if !inputTime.Equal(outputTime) { t.Errorf("Did not transcode time successfully: %v is not %v", outputTime, inputTime) } }) } // TODO - move these tests to pgtype func TestJSONAndJSONBTranscode(t *testing.T) { t.Parallel() testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { for _, typename := range []string{"json", "jsonb"} { if _, ok := conn.ConnInfo().DataTypeForName(typename); !ok { continue // No JSON/JSONB type -- must 
be running against old PostgreSQL } testJSONString(t, conn, typename) testJSONStringPointer(t, conn, typename) } }) } func TestJSONAndJSONBTranscodeExtendedOnly(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) for _, typename := range []string{"json", "jsonb"} { if _, ok := conn.ConnInfo().DataTypeForName(typename); !ok { continue // No JSON/JSONB type -- must be running against old PostgreSQL } testJSONSingleLevelStringMap(t, conn, typename) testJSONNestedMap(t, conn, typename) testJSONStringArray(t, conn, typename) testJSONInt64Array(t, conn, typename) testJSONInt16ArrayFailureDueToOverflow(t, conn, typename) testJSONStruct(t, conn, typename) } } func testJSONString(t *testing.T, conn *pgx.Conn, typename string) { input := `{"key": "value"}` expectedOutput := map[string]string{"key": "value"} var output map[string]string err := conn.QueryRow(context.Background(), "select $1::"+typename, input).Scan(&output) if err != nil { t.Errorf("%s: QueryRow Scan failed: %v", typename, err) return } if !reflect.DeepEqual(expectedOutput, output) { t.Errorf("%s: Did not transcode map[string]string successfully: %v is not %v", typename, expectedOutput, output) return } } func testJSONStringPointer(t *testing.T, conn *pgx.Conn, typename string) { input := `{"key": "value"}` expectedOutput := map[string]string{"key": "value"} var output map[string]string err := conn.QueryRow(context.Background(), "select $1::"+typename, &input).Scan(&output) if err != nil { t.Errorf("%s: QueryRow Scan failed: %v", typename, err) return } if !reflect.DeepEqual(expectedOutput, output) { t.Errorf("%s: Did not transcode map[string]string successfully: %v is not %v", typename, expectedOutput, output) return } } func testJSONSingleLevelStringMap(t *testing.T, conn *pgx.Conn, typename string) { input := map[string]string{"key": "value"} var output map[string]string err := conn.QueryRow(context.Background(), "select $1::"+typename, 
input).Scan(&output) if err != nil { t.Errorf("%s: QueryRow Scan failed: %v", typename, err) return } if !reflect.DeepEqual(input, output) { t.Errorf("%s: Did not transcode map[string]string successfully: %v is not %v", typename, input, output) return } } func testJSONNestedMap(t *testing.T, conn *pgx.Conn, typename string) { input := map[string]interface{}{ "name": "Uncanny", "stats": map[string]interface{}{"hp": float64(107), "maxhp": float64(150)}, "inventory": []interface{}{"phone", "key"}, } var output map[string]interface{} err := conn.QueryRow(context.Background(), "select $1::"+typename, input).Scan(&output) if err != nil { t.Errorf("%s: QueryRow Scan failed: %v", typename, err) return } if !reflect.DeepEqual(input, output) { t.Errorf("%s: Did not transcode map[string]interface{} successfully: %v is not %v", typename, input, output) return } } func testJSONStringArray(t *testing.T, conn *pgx.Conn, typename string) { input := []string{"foo", "bar", "baz"} var output []string err := conn.QueryRow(context.Background(), "select $1::"+typename, input).Scan(&output) if err != nil { t.Errorf("%s: QueryRow Scan failed: %v", typename, err) } if !reflect.DeepEqual(input, output) { t.Errorf("%s: Did not transcode []string successfully: %v is not %v", typename, input, output) } } func testJSONInt64Array(t *testing.T, conn *pgx.Conn, typename string) { input := []int64{1, 2, 234432} var output []int64 err := conn.QueryRow(context.Background(), "select $1::"+typename, input).Scan(&output) if err != nil { t.Errorf("%s: QueryRow Scan failed: %v", typename, err) } if !reflect.DeepEqual(input, output) { t.Errorf("%s: Did not transcode []int64 successfully: %v is not %v", typename, input, output) } } func testJSONInt16ArrayFailureDueToOverflow(t *testing.T, conn *pgx.Conn, typename string) { input := []int{1, 2, 234432} var output []int16 err := conn.QueryRow(context.Background(), "select $1::"+typename, input).Scan(&output) if err == nil || err.Error() != "can't scan into 
dest[0]: json: cannot unmarshal number 234432 into Go value of type int16" { t.Errorf("%s: Expected *json.UnmarkalTypeError, but got %v", typename, err) } } func testJSONStruct(t *testing.T, conn *pgx.Conn, typename string) { type person struct { Name string `json:"name"` Age int `json:"age"` } input := person{ Name: "John", Age: 42, } var output person err := conn.QueryRow(context.Background(), "select $1::"+typename, input).Scan(&output) if err != nil { t.Errorf("%s: QueryRow Scan failed: %v", typename, err) } if !reflect.DeepEqual(input, output) { t.Errorf("%s: Did not transcode struct successfully: %v is not %v", typename, input, output) } } func mustParseCIDR(t *testing.T, s string) *net.IPNet { _, ipnet, err := net.ParseCIDR(s) if err != nil { t.Fatal(err) } return ipnet } func TestStringToNotTextTypeTranscode(t *testing.T) { t.Parallel() testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { input := "01086ee0-4963-4e35-9116-30c173a8d0bd" var output string err := conn.QueryRow(context.Background(), "select $1::uuid", input).Scan(&output) if err != nil { t.Fatal(err) } if input != output { t.Errorf("uuid: Did not transcode string successfully: %s is not %s", input, output) } err = conn.QueryRow(context.Background(), "select $1::uuid", &input).Scan(&output) if err != nil { t.Fatal(err) } if input != output { t.Errorf("uuid: Did not transcode pointer to string successfully: %s is not %s", input, output) } }) } func TestInetCIDRTranscodeIPNet(t *testing.T) { t.Parallel() testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { tests := []struct { sql string value *net.IPNet }{ {"select $1::inet", mustParseCIDR(t, "0.0.0.0/32")}, {"select $1::inet", mustParseCIDR(t, "127.0.0.1/32")}, {"select $1::inet", mustParseCIDR(t, "12.34.56.0/32")}, {"select $1::inet", mustParseCIDR(t, "192.168.1.0/24")}, {"select $1::inet", mustParseCIDR(t, "255.0.0.0/8")}, {"select $1::inet", mustParseCIDR(t, "255.255.255.255/32")}, {"select 
$1::inet", mustParseCIDR(t, "::/128")}, {"select $1::inet", mustParseCIDR(t, "::/0")}, {"select $1::inet", mustParseCIDR(t, "::1/128")}, {"select $1::inet", mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128")}, {"select $1::cidr", mustParseCIDR(t, "0.0.0.0/32")}, {"select $1::cidr", mustParseCIDR(t, "127.0.0.1/32")}, {"select $1::cidr", mustParseCIDR(t, "12.34.56.0/32")}, {"select $1::cidr", mustParseCIDR(t, "192.168.1.0/24")}, {"select $1::cidr", mustParseCIDR(t, "255.0.0.0/8")}, {"select $1::cidr", mustParseCIDR(t, "255.255.255.255/32")}, {"select $1::cidr", mustParseCIDR(t, "::/128")}, {"select $1::cidr", mustParseCIDR(t, "::/0")}, {"select $1::cidr", mustParseCIDR(t, "::1/128")}, {"select $1::cidr", mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128")}, } for i, tt := range tests { if conn.PgConn().ParameterStatus("crdb_version") != "" && strings.Contains(tt.sql, "cidr") { t.Log("Server does not support cidr type (https://github.com/cockroachdb/cockroach/issues/18846)") continue } var actual net.IPNet err := conn.QueryRow(context.Background(), tt.sql, tt.value).Scan(&actual) if err != nil { t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value) continue } if actual.String() != tt.value.String() { t.Errorf("%d. 
Expected %v, got %v (sql -> %v)", i, tt.value, actual, tt.sql) } } }) } func TestInetCIDRTranscodeIP(t *testing.T) { t.Parallel() testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { tests := []struct { sql string value net.IP }{ {"select $1::inet", net.ParseIP("0.0.0.0")}, {"select $1::inet", net.ParseIP("127.0.0.1")}, {"select $1::inet", net.ParseIP("12.34.56.0")}, {"select $1::inet", net.ParseIP("255.255.255.255")}, {"select $1::inet", net.ParseIP("::1")}, {"select $1::inet", net.ParseIP("2607:f8b0:4009:80b::200e")}, {"select $1::cidr", net.ParseIP("0.0.0.0")}, {"select $1::cidr", net.ParseIP("127.0.0.1")}, {"select $1::cidr", net.ParseIP("12.34.56.0")}, {"select $1::cidr", net.ParseIP("255.255.255.255")}, {"select $1::cidr", net.ParseIP("::1")}, {"select $1::cidr", net.ParseIP("2607:f8b0:4009:80b::200e")}, } for i, tt := range tests { if conn.PgConn().ParameterStatus("crdb_version") != "" && strings.Contains(tt.sql, "cidr") { t.Log("Server does not support cidr type (https://github.com/cockroachdb/cockroach/issues/18846)") continue } var actual net.IP err := conn.QueryRow(context.Background(), tt.sql, tt.value).Scan(&actual) if err != nil { t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value) continue } if !actual.Equal(tt.value) { t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.value, actual, tt.sql) } ensureConnValid(t, conn) } failTests := []struct { sql string value *net.IPNet }{ {"select $1::inet", mustParseCIDR(t, "192.168.1.0/24")}, {"select $1::cidr", mustParseCIDR(t, "192.168.1.0/24")}, } for i, tt := range failTests { var actual net.IP err := conn.QueryRow(context.Background(), tt.sql, tt.value).Scan(&actual) if err == nil { t.Errorf("%d. 
Expected failure but got none", i) continue } ensureConnValid(t, conn) } }) } func TestInetCIDRArrayTranscodeIPNet(t *testing.T) { t.Parallel() testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { tests := []struct { sql string value []*net.IPNet }{ { "select $1::inet[]", []*net.IPNet{ mustParseCIDR(t, "0.0.0.0/32"), mustParseCIDR(t, "127.0.0.1/32"), mustParseCIDR(t, "12.34.56.0/32"), mustParseCIDR(t, "192.168.1.0/24"), mustParseCIDR(t, "255.0.0.0/8"), mustParseCIDR(t, "255.255.255.255/32"), mustParseCIDR(t, "::/128"), mustParseCIDR(t, "::/0"), mustParseCIDR(t, "::1/128"), mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), }, }, { "select $1::cidr[]", []*net.IPNet{ mustParseCIDR(t, "0.0.0.0/32"), mustParseCIDR(t, "127.0.0.1/32"), mustParseCIDR(t, "12.34.56.0/32"), mustParseCIDR(t, "192.168.1.0/24"), mustParseCIDR(t, "255.0.0.0/8"), mustParseCIDR(t, "255.255.255.255/32"), mustParseCIDR(t, "::/128"), mustParseCIDR(t, "::/0"), mustParseCIDR(t, "::1/128"), mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), }, }, } for i, tt := range tests { if conn.PgConn().ParameterStatus("crdb_version") != "" && strings.Contains(tt.sql, "cidr") { t.Log("Server does not support cidr type (https://github.com/cockroachdb/cockroach/issues/18846)") continue } var actual []*net.IPNet err := conn.QueryRow(context.Background(), tt.sql, tt.value).Scan(&actual) if err != nil { t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value) continue } if !reflect.DeepEqual(actual, tt.value) { t.Errorf("%d. 
Expected %v, got %v (sql -> %v)", i, tt.value, actual, tt.sql) } ensureConnValid(t, conn) } }) } func TestInetCIDRArrayTranscodeIP(t *testing.T) { t.Parallel() testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { tests := []struct { sql string value []net.IP }{ { "select $1::inet[]", []net.IP{ net.ParseIP("0.0.0.0"), net.ParseIP("127.0.0.1"), net.ParseIP("12.34.56.0"), net.ParseIP("255.255.255.255"), net.ParseIP("2607:f8b0:4009:80b::200e"), }, }, { "select $1::cidr[]", []net.IP{ net.ParseIP("0.0.0.0"), net.ParseIP("127.0.0.1"), net.ParseIP("12.34.56.0"), net.ParseIP("255.255.255.255"), net.ParseIP("2607:f8b0:4009:80b::200e"), }, }, } for i, tt := range tests { if conn.PgConn().ParameterStatus("crdb_version") != "" && strings.Contains(tt.sql, "cidr") { t.Log("Server does not support cidr type (https://github.com/cockroachdb/cockroach/issues/18846)") continue } var actual []net.IP err := conn.QueryRow(context.Background(), tt.sql, tt.value).Scan(&actual) if err != nil { t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value) continue } assert.Equal(t, len(tt.value), len(actual), "%d", i) for j := range actual { assert.True(t, actual[j].Equal(tt.value[j]), "%d", i) } ensureConnValid(t, conn) } failTests := []struct { sql string value []*net.IPNet }{ { "select $1::inet[]", []*net.IPNet{ mustParseCIDR(t, "12.34.56.0/32"), mustParseCIDR(t, "192.168.1.0/24"), }, }, { "select $1::cidr[]", []*net.IPNet{ mustParseCIDR(t, "12.34.56.0/32"), mustParseCIDR(t, "192.168.1.0/24"), }, }, } for i, tt := range failTests { var actual []net.IP err := conn.QueryRow(context.Background(), tt.sql, tt.value).Scan(&actual) if err == nil { t.Errorf("%d. 
Expected failure but got none", i) continue } ensureConnValid(t, conn) } }) } func TestInetCIDRTranscodeWithJustIP(t *testing.T) { t.Parallel() testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { tests := []struct { sql string value string }{ {"select $1::inet", "0.0.0.0/32"}, {"select $1::inet", "127.0.0.1/32"}, {"select $1::inet", "12.34.56.0/32"}, {"select $1::inet", "255.255.255.255/32"}, {"select $1::inet", "::/128"}, {"select $1::inet", "2607:f8b0:4009:80b::200e/128"}, {"select $1::cidr", "0.0.0.0/32"}, {"select $1::cidr", "127.0.0.1/32"}, {"select $1::cidr", "12.34.56.0/32"}, {"select $1::cidr", "255.255.255.255/32"}, {"select $1::cidr", "::/128"}, {"select $1::cidr", "2607:f8b0:4009:80b::200e/128"}, } for i, tt := range tests { if conn.PgConn().ParameterStatus("crdb_version") != "" && strings.Contains(tt.sql, "cidr") { t.Log("Server does not support cidr type (https://github.com/cockroachdb/cockroach/issues/18846)") continue } expected := mustParseCIDR(t, tt.value) var actual net.IPNet err := conn.QueryRow(context.Background(), tt.sql, expected.IP).Scan(&actual) if err != nil { t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value) continue } if actual.String() != expected.String() { t.Errorf("%d. 
Expected %v, got %v (sql -> %v)", i, tt.value, actual, tt.sql) } ensureConnValid(t, conn) } }) } func TestArrayDecoding(t *testing.T) { t.Parallel() testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { tests := []struct { sql string query interface{} scan interface{} assert func(*testing.T, interface{}, interface{}) }{ { "select $1::bool[]", []bool{true, false, true}, &[]bool{}, func(t *testing.T, query, scan interface{}) { if !reflect.DeepEqual(query, *(scan.(*[]bool))) { t.Errorf("failed to encode bool[]") } }, }, { "select $1::smallint[]", []int16{2, 4, 484, 32767}, &[]int16{}, func(t *testing.T, query, scan interface{}) { if !reflect.DeepEqual(query, *(scan.(*[]int16))) { t.Errorf("failed to encode smallint[]") } }, }, { "select $1::smallint[]", []uint16{2, 4, 484, 32767}, &[]uint16{}, func(t *testing.T, query, scan interface{}) { if !reflect.DeepEqual(query, *(scan.(*[]uint16))) { t.Errorf("failed to encode smallint[]") } }, }, { "select $1::int[]", []int32{2, 4, 484}, &[]int32{}, func(t *testing.T, query, scan interface{}) { if !reflect.DeepEqual(query, *(scan.(*[]int32))) { t.Errorf("failed to encode int[]") } }, }, { "select $1::int[]", []uint32{2, 4, 484, 2147483647}, &[]uint32{}, func(t *testing.T, query, scan interface{}) { if !reflect.DeepEqual(query, *(scan.(*[]uint32))) { t.Errorf("failed to encode int[]") } }, }, { "select $1::bigint[]", []int64{2, 4, 484, 9223372036854775807}, &[]int64{}, func(t *testing.T, query, scan interface{}) { if !reflect.DeepEqual(query, *(scan.(*[]int64))) { t.Errorf("failed to encode bigint[]") } }, }, { "select $1::bigint[]", []uint64{2, 4, 484, 9223372036854775807}, &[]uint64{}, func(t *testing.T, query, scan interface{}) { if !reflect.DeepEqual(query, *(scan.(*[]uint64))) { t.Errorf("failed to encode bigint[]") } }, }, { "select $1::text[]", []string{"it's", "over", "9000!"}, &[]string{}, func(t *testing.T, query, scan interface{}) { if !reflect.DeepEqual(query, *(scan.(*[]string))) { 
t.Errorf("failed to encode text[]") } }, }, { "select $1::timestamptz[]", []time.Time{time.Unix(323232, 0), time.Unix(3239949334, 00)}, &[]time.Time{}, func(t *testing.T, query, scan interface{}) { queryTimeSlice := query.([]time.Time) scanTimeSlice := *(scan.(*[]time.Time)) require.Equal(t, len(queryTimeSlice), len(scanTimeSlice)) for i := range queryTimeSlice { assert.Truef(t, queryTimeSlice[i].Equal(scanTimeSlice[i]), "%d", i) } }, }, { "select $1::bytea[]", [][]byte{{0, 1, 2, 3}, {4, 5, 6, 7}}, &[][]byte{}, func(t *testing.T, query, scan interface{}) { queryBytesSliceSlice := query.([][]byte) scanBytesSliceSlice := *(scan.(*[][]byte)) if len(queryBytesSliceSlice) != len(scanBytesSliceSlice) { t.Errorf("failed to encode byte[][] to bytea[]: expected %d to equal %d", len(queryBytesSliceSlice), len(scanBytesSliceSlice)) } for i := range queryBytesSliceSlice { qb := queryBytesSliceSlice[i] sb := scanBytesSliceSlice[i] if !bytes.Equal(qb, sb) { t.Errorf("failed to encode byte[][] to bytea[]: expected %v to equal %v", qb, sb) } } }, }, } for i, tt := range tests { err := conn.QueryRow(context.Background(), tt.sql, tt.query).Scan(tt.scan) if err != nil { t.Errorf(`%d. 
error reading array: %v`, i, err) continue } tt.assert(t, tt.query, tt.scan) ensureConnValid(t, conn) } }) } func TestEmptyArrayDecoding(t *testing.T) { t.Parallel() testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { var val []string err := conn.QueryRow(context.Background(), "select array[]::text[]").Scan(&val) if err != nil { t.Errorf(`error reading array: %v`, err) } if len(val) != 0 { t.Errorf("Expected 0 values, got %d", len(val)) } var n, m int32 err = conn.QueryRow(context.Background(), "select 1::integer, array[]::text[], 42::integer").Scan(&n, &val, &m) if err != nil { t.Errorf(`error reading array: %v`, err) } if len(val) != 0 { t.Errorf("Expected 0 values, got %d", len(val)) } if n != 1 { t.Errorf("Expected n to be 1, but it was %d", n) } if m != 42 { t.Errorf("Expected n to be 42, but it was %d", n) } rows, err := conn.Query(context.Background(), "select 1::integer, array['test']::text[] union select 2::integer, array[]::text[] union select 3::integer, array['test']::text[]") if err != nil { t.Errorf(`error retrieving rows with array: %v`, err) } defer rows.Close() for rows.Next() { err = rows.Scan(&n, &val) if err != nil { t.Errorf(`error reading array: %v`, err) } } }) } func TestPointerPointer(t *testing.T) { t.Parallel() testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { skipCockroachDB(t, conn, "Server auto converts ints to bigint and test relies on exact types") type allTypes struct { s *string i16 *int16 i32 *int32 i64 *int64 f32 *float32 f64 *float64 b *bool t *time.Time } var actual, zero, expected allTypes { s := "foo" expected.s = &s i16 := int16(1) expected.i16 = &i16 i32 := int32(1) expected.i32 = &i32 i64 := int64(1) expected.i64 = &i64 f32 := float32(1.23) expected.f32 = &f32 f64 := float64(1.23) expected.f64 = &f64 b := true expected.b = &b t := time.Unix(123, 5000) expected.t = &t } tests := []struct { sql string queryArgs []interface{} scanArgs []interface{} expected allTypes }{ 
{"select $1::text", []interface{}{expected.s}, []interface{}{&actual.s}, allTypes{s: expected.s}}, {"select $1::text", []interface{}{zero.s}, []interface{}{&actual.s}, allTypes{}}, {"select $1::int2", []interface{}{expected.i16}, []interface{}{&actual.i16}, allTypes{i16: expected.i16}}, {"select $1::int2", []interface{}{zero.i16}, []interface{}{&actual.i16}, allTypes{}}, {"select $1::int4", []interface{}{expected.i32}, []interface{}{&actual.i32}, allTypes{i32: expected.i32}}, {"select $1::int4", []interface{}{zero.i32}, []interface{}{&actual.i32}, allTypes{}}, {"select $1::int8", []interface{}{expected.i64}, []interface{}{&actual.i64}, allTypes{i64: expected.i64}}, {"select $1::int8", []interface{}{zero.i64}, []interface{}{&actual.i64}, allTypes{}}, {"select $1::float4", []interface{}{expected.f32}, []interface{}{&actual.f32}, allTypes{f32: expected.f32}}, {"select $1::float4", []interface{}{zero.f32}, []interface{}{&actual.f32}, allTypes{}}, {"select $1::float8", []interface{}{expected.f64}, []interface{}{&actual.f64}, allTypes{f64: expected.f64}}, {"select $1::float8", []interface{}{zero.f64}, []interface{}{&actual.f64}, allTypes{}}, {"select $1::bool", []interface{}{expected.b}, []interface{}{&actual.b}, allTypes{b: expected.b}}, {"select $1::bool", []interface{}{zero.b}, []interface{}{&actual.b}, allTypes{}}, {"select $1::timestamptz", []interface{}{expected.t}, []interface{}{&actual.t}, allTypes{t: expected.t}}, {"select $1::timestamptz", []interface{}{zero.t}, []interface{}{&actual.t}, allTypes{}}, } for i, tt := range tests { actual = zero err := conn.QueryRow(context.Background(), tt.sql, tt.queryArgs...).Scan(tt.scanArgs...) if err != nil { t.Errorf("%d. 
Unexpected failure: %v (sql -> %v, queryArgs -> %v)", i, err, tt.sql, tt.queryArgs) } assert.Equal(t, tt.expected.s, actual.s) assert.Equal(t, tt.expected.i16, actual.i16) assert.Equal(t, tt.expected.i32, actual.i32) assert.Equal(t, tt.expected.i64, actual.i64) assert.Equal(t, tt.expected.f32, actual.f32) assert.Equal(t, tt.expected.f64, actual.f64) assert.Equal(t, tt.expected.b, actual.b) if tt.expected.t != nil || actual.t != nil { assert.True(t, tt.expected.t.Equal(*actual.t)) } ensureConnValid(t, conn) } }) } func TestPointerPointerNonZero(t *testing.T) { t.Parallel() testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { f := "foo" dest := &f err := conn.QueryRow(context.Background(), "select $1::text", nil).Scan(&dest) if err != nil { t.Errorf("Unexpected failure scanning: %v", err) } if dest != nil { t.Errorf("Expected dest to be nil, got %#v", dest) } }) } func TestEncodeTypeRename(t *testing.T) { t.Parallel() testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { type _int int inInt := _int(1) var outInt _int type _int8 int8 inInt8 := _int8(2) var outInt8 _int8 type _int16 int16 inInt16 := _int16(3) var outInt16 _int16 type _int32 int32 inInt32 := _int32(4) var outInt32 _int32 type _int64 int64 inInt64 := _int64(5) var outInt64 _int64 type _uint uint inUint := _uint(6) var outUint _uint type _uint8 uint8 inUint8 := _uint8(7) var outUint8 _uint8 type _uint16 uint16 inUint16 := _uint16(8) var outUint16 _uint16 type _uint32 uint32 inUint32 := _uint32(9) var outUint32 _uint32 type _uint64 uint64 inUint64 := _uint64(10) var outUint64 _uint64 type _string string inString := _string("foo") var outString _string err := conn.QueryRow(context.Background(), "select $1::int, $2::int, $3::int2, $4::int4, $5::int8, $6::int, $7::int, $8::int, $9::int, $10::int, $11::text", inInt, inInt8, inInt16, inInt32, inInt64, inUint, inUint8, inUint16, inUint32, inUint64, inString, ).Scan(&outInt, &outInt8, &outInt16, &outInt32, 
&outInt64, &outUint, &outUint8, &outUint16, &outUint32, &outUint64, &outString) if err != nil { t.Fatalf("Failed with type rename: %v", err) } if inInt != outInt { t.Errorf("int rename: expected %v, got %v", inInt, outInt) } if inInt8 != outInt8 { t.Errorf("int8 rename: expected %v, got %v", inInt8, outInt8) } if inInt16 != outInt16 { t.Errorf("int16 rename: expected %v, got %v", inInt16, outInt16) } if inInt32 != outInt32 { t.Errorf("int32 rename: expected %v, got %v", inInt32, outInt32) } if inInt64 != outInt64 { t.Errorf("int64 rename: expected %v, got %v", inInt64, outInt64) } if inUint != outUint { t.Errorf("uint rename: expected %v, got %v", inUint, outUint) } if inUint8 != outUint8 { t.Errorf("uint8 rename: expected %v, got %v", inUint8, outUint8) } if inUint16 != outUint16 { t.Errorf("uint16 rename: expected %v, got %v", inUint16, outUint16) } if inUint32 != outUint32 { t.Errorf("uint32 rename: expected %v, got %v", inUint32, outUint32) } if inUint64 != outUint64 { t.Errorf("uint64 rename: expected %v, got %v", inUint64, outUint64) } if inString != outString { t.Errorf("string rename: expected %v, got %v", inString, outString) } }) } func TestRowDecodeBinary(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) tests := []struct { sql string expected []interface{} }{ { "select row(1, 'cat', '2015-01-01 08:12:42-00'::timestamptz)", []interface{}{ int32(1), "cat", time.Date(2015, 1, 1, 8, 12, 42, 0, time.UTC).Local(), }, }, { "select row(100.0::float, 1.09::float)", []interface{}{ float64(100), float64(1.09), }, }, } for i, tt := range tests { var actual []interface{} err := conn.QueryRow(context.Background(), tt.sql).Scan(&actual) if err != nil { t.Errorf("%d. Unexpected failure: %v (sql -> %v)", i, err, tt.sql) continue } for j := range tt.expected { assert.EqualValuesf(t, tt.expected[j], actual[j], "%d. 
[%d]", i, j) } ensureConnValid(t, conn) } } // https://github.com/jackc/pgx/issues/810 func TestRowsScanNilThenScanValue(t *testing.T) { t.Parallel() testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { sql := `select null as a, null as b union select 1, 2 order by a nulls first ` rows, err := conn.Query(context.Background(), sql) require.NoError(t, err) require.True(t, rows.Next()) err = rows.Scan(nil, nil) require.NoError(t, err) require.True(t, rows.Next()) var a int var b int err = rows.Scan(&a, &b) require.NoError(t, err) require.EqualValues(t, 1, a) require.EqualValues(t, 2, b) rows.Close() require.NoError(t, rows.Err()) }) } func TestScanIntoByteSlice(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) // Success cases for _, tt := range []struct { name string sql string resultFormatCode int16 output []byte }{ {"int - text", "select 42", pgx.TextFormatCode, []byte("42")}, {"text - text", "select 'hi'", pgx.TextFormatCode, []byte("hi")}, {"text - binary", "select 'hi'", pgx.BinaryFormatCode, []byte("hi")}, {"json - text", "select '{}'::json", pgx.TextFormatCode, []byte("{}")}, {"json - binary", "select '{}'::json", pgx.BinaryFormatCode, []byte("{}")}, {"jsonb - text", "select '{}'::jsonb", pgx.TextFormatCode, []byte("{}")}, {"jsonb - binary", "select '{}'::jsonb", pgx.BinaryFormatCode, []byte("{}")}, } { t.Run(tt.name, func(t *testing.T) { var buf []byte err := conn.QueryRow(context.Background(), tt.sql, pgx.QueryResultFormats{tt.resultFormatCode}).Scan(&buf) require.NoError(t, err) require.Equal(t, tt.output, buf) }) } // Failure cases for _, tt := range []struct { name string sql string err string }{ {"int binary", "select 42", "can't scan into dest[0]: cannot assign 42 into *[]uint8"}, } { t.Run(tt.name, func(t *testing.T) { var buf []byte err := conn.QueryRow(context.Background(), tt.sql, pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(&buf) require.EqualError(t, 
err, tt.err) }) } }