pax_global_header00006660000000000000000000000064143717234520014522gustar00rootroot0000000000000052 comment=754c5737f13c951bb6aca7c029bc0ccbee6656c5 pgconn-1.14.0/000077500000000000000000000000001437172345200130715ustar00rootroot00000000000000pgconn-1.14.0/.github/000077500000000000000000000000001437172345200144315ustar00rootroot00000000000000pgconn-1.14.0/.github/workflows/000077500000000000000000000000001437172345200164665ustar00rootroot00000000000000pgconn-1.14.0/.github/workflows/ci.yml000066400000000000000000000100721437172345200176040ustar00rootroot00000000000000name: CI on: push: branches: [ master ] pull_request: branches: [ master ] jobs: test: name: Test runs-on: ubuntu-18.04 strategy: matrix: go-version: [1.15, 1.16] pg-version: [9.6, 10, 11, 12, 13, cockroachdb] include: - pg-version: 9.6 pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test - pg-version: 10 pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test - pg-version: 11 pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test 
pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test - pg-version: 12 pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test - pg-version: 13 pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test - pg-version: cockroachdb pgx-test-conn-string: "postgresql://root@127.0.0.1:26257/pgx_test?sslmode=disable&experimental_enable_temp_tables=on" steps: - name: Set up Go 1.x uses: actions/setup-go@v2 with: go-version: ${{ matrix.go-version }} - name: Check out code into the Go module directory uses: actions/checkout@v2 - name: Setup database server for testing run: ci/setup_test.bash env: PGVERSION: ${{ matrix.pg-version }} - name: Test run: go test -v -race ./... 
env: PGX_TEST_CONN_STRING: ${{ matrix.pgx-test-conn-string }} PGX_TEST_UNIX_SOCKET_CONN_STRING: ${{ matrix.pgx-test-unix-socket-conn-string }} PGX_TEST_TCP_CONN_STRING: ${{ matrix.pgx-test-tcp-conn-string }} PGX_TEST_TLS_CONN_STRING: ${{ matrix.pgx-test-tls-conn-string }} PGX_TEST_MD5_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-md5-password-conn-string }} PGX_TEST_PLAIN_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-plain-password-conn-string }} pgconn-1.14.0/.gitignore000066400000000000000000000000271437172345200150600ustar00rootroot00000000000000.envrc vendor/ .vscode pgconn-1.14.0/CHANGELOG.md000066400000000000000000000120601437172345200147010ustar00rootroot00000000000000# 1.14.0 (February 11, 2023) * Fix: each connection attempt to new node gets own timeout (Nathan Giardina) * Set SNI for SSL connections (Stas Kelvich) * Fix: CopyFrom I/O race (Tommy Reilly) * Minor dependency upgrades # 1.13.0 (August 6, 2022) * Add sslpassword support (Eric McCormack and yun.xu) * Add prefer-standby target_session_attrs support (sergey.bashilov) * Fix GSS ErrorResponse handling (Oliver Tan) # 1.12.1 (May 7, 2022) * Fix: setting krbspn and krbsrvname in connection string (sireax) * Add support for Unix sockets on Windows (Eno Compton) * Stop ignoring ErrorResponse during SCRAM auth (Rafi Shamim) # 1.12.0 (April 21, 2022) * Add pluggable GSSAPI support (Oliver Tan) * Fix: Consider any "0A000" error a possible cached plan changed error due to locale * Better match psql fallback behavior with multiple hosts # 1.11.0 (February 7, 2022) * Support port in ip from LookupFunc to override config (James Hartig) * Fix TLS connection timeout (Blake Embrey) * Add support for read-only, primary, standby, prefer-standby target_session_attributes (Oscar) * Fix connect when receiving NoticeResponse # 1.10.1 (November 20, 2021) * Close without waiting for response (Kei Kamikawa) * Save waiting for network round-trip in CopyFrom (Rueian) * Fix concurrency issue with ContextWatcher * LRU.Get always 
checks context for cancellation / expiration (Georges Varouchas) # 1.10.0 (July 24, 2021) * net.Timeout errors are no longer returned when a query is canceled via context. A wrapped context error is returned. # 1.9.0 (July 10, 2021) * pgconn.Timeout only is true for errors originating in pgconn (Michael Darr) * Add defaults for sslcert, sslkey, and sslrootcert (Joshua Brindle) * Solve issue with 'sslmode=verify-full' when there are multiple hosts (mgoddard) * Fix default host when parsing URL without host but with port * Allow dbname query parameter in URL conn string * Update underlying dependencies # 1.8.1 (March 25, 2021) * Better connection string sanitization (ip.novikov) * Use proper pgpass location on Windows (Moshe Katz) * Use errors instead of golang.org/x/xerrors * Resume fallback on server error in Connect (Andrey Borodin) # 1.8.0 (December 3, 2020) * Add StatementErrored method to stmtcache.Cache. This allows the cache to purge invalidated prepared statements. (Ethan Pailes) # 1.7.2 (November 3, 2020) * Fix data value slices into work buffer with capacities larger than length. 
# 1.7.1 (October 31, 2020) * Do not asyncClose after receiving FATAL error from PostgreSQL server # 1.7.0 (September 26, 2020) * Exec(Params|Prepared) return ResultReader with FieldDescriptions loaded * Add ReceiveResults (Sebastiaan Mannem) * Fix parsing DSN connection with bad backslash * Add PgConn.CleanupDone so connection pools can determine when async close is complete # 1.6.4 (July 29, 2020) * Fix deadlock on error after CommandComplete but before ReadyForQuery * Fix panic on parsing DSN with trailing '=' # 1.6.3 (July 22, 2020) * Fix error message after AppendCertsFromPEM failure (vahid-sohrabloo) # 1.6.2 (July 14, 2020) * Update pgservicefile library # 1.6.1 (June 27, 2020) * Update golang.org/x/crypto to latest * Update golang.org/x/text to 0.3.3 * Fix error handling for bad PGSERVICE definition * Redact passwords in ParseConfig errors (Lukas Vogel) # 1.6.0 (June 6, 2020) * Fix panic when closing conn during cancellable query * Fix behavior of sslmode=require with sslrootcert present (Petr Jediný) * Fix field descriptions available after command concluded (Tobias Salzmann) * Support connect_timeout (georgysavva) * Handle IPv6 in connection URLs (Lukas Vogel) * Fix ValidateConnect with cancelable context * Improve CopyFrom performance * Add Config.Copy (georgysavva) # 1.5.0 (March 30, 2020) * Update golang.org/x/crypto for security fix * Implement "verify-ca" SSL mode (Greg Curtis) # 1.4.0 (March 7, 2020) * Fix ExecParams and ExecPrepared handling of empty query. * Support reading config from PostgreSQL service files. # 1.3.2 (February 14, 2020) * Update chunkreader to v2.0.1 for optimized default buffer size. # 1.3.1 (February 5, 2020) * Fix CopyFrom deadlock when multiple NoticeResponse received during copy # 1.3.0 (January 23, 2020) * Add Hijack and Construct. * Update pgproto3 to v2.0.1. # 1.2.1 (January 13, 2020) * Fix data race in context cancellation introduced in v1.2.0. 
# 1.2.0 (January 11, 2020) ## Features * Add Insert(), Update(), Delete(), and Select() statement type query methods to CommandTag. * Add PgError.SQLState method. This could be used for compatibility with other drivers and databases. ## Performance * Improve performance when context.Background() is used. (bakape) * CommandTag.RowsAffected is faster and does not allocate. ## Fixes * Try to cancel any in-progress query when a conn is closed by ctx cancel. * Handle NoticeResponse during CopyFrom. * Ignore errors sending Terminate message while closing connection. This mimics the behavior of libpq PGfinish. # 1.1.0 (October 12, 2019) * Add PgConn.IsBusy() method. # 1.0.1 (September 19, 2019) * Fix statement cache not properly cleaning discarded statements. pgconn-1.14.0/LICENSE000066400000000000000000000020661437172345200141020ustar00rootroot00000000000000Copyright (c) 2019-2021 Jack Christensen MIT License Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
pgconn-1.14.0/README.md000066400000000000000000000043141437172345200143520ustar00rootroot00000000000000[![](https://godoc.org/github.com/jackc/pgconn?status.svg)](https://godoc.org/github.com/jackc/pgconn) ![CI](https://github.com/jackc/pgconn/workflows/CI/badge.svg) --- This version is used with pgx `v4`. In pgx `v5` it is part of the https://github.com/jackc/pgx repository. --- # pgconn Package pgconn is a low-level PostgreSQL database driver. It operates at nearly the same level as the C library libpq. It is primarily intended to serve as the foundation for higher level libraries such as https://github.com/jackc/pgx. Applications should handle normal queries with a higher level library and only use pgconn directly when required for low-level access to PostgreSQL functionality. ## Example Usage ```go pgConn, err := pgconn.Connect(context.Background(), os.Getenv("DATABASE_URL")) if err != nil { log.Fatalln("pgconn failed to connect:", err) } defer pgConn.Close(context.Background()) result := pgConn.ExecParams(context.Background(), "SELECT email FROM users WHERE id=$1", [][]byte{[]byte("123")}, nil, nil, nil) for result.NextRow() { fmt.Println("User 123 has email:", string(result.Values()[0])) } _, err = result.Close() if err != nil { log.Fatalln("failed reading result:", err) } ``` ## Testing The pgconn tests require a PostgreSQL database. It will connect to the database specified in the `PGX_TEST_CONN_STRING` environment variable. The `PGX_TEST_CONN_STRING` environment variable can be a URL or DSN. In addition, the standard `PG*` environment variables will be respected. Consider using [direnv](https://github.com/direnv/direnv) to simplify environment variable handling. ### Example Test Environment Connect to your PostgreSQL server and run: ``` create database pgx_test; ``` Now you can run the tests: ```bash PGX_TEST_CONN_STRING="host=/var/run/postgresql dbname=pgx_test" go test ./... 
``` ### Connection and Authentication Tests Pgconn supports multiple connection types and means of authentication. These tests are optional. They will only run if the appropriate environment variable is set. Run `go test -v | grep SKIP` to see if any tests are being skipped. Most developers will not need to enable these tests. See `ci/setup_test.bash` for an example set up if you need change authentication code. pgconn-1.14.0/auth_scram.go000066400000000000000000000170401437172345200155500ustar00rootroot00000000000000// SCRAM-SHA-256 authentication // // Resources: // https://tools.ietf.org/html/rfc5802 // https://tools.ietf.org/html/rfc8265 // https://www.postgresql.org/docs/current/sasl-authentication.html // // Inspiration drawn from other implementations: // https://github.com/lib/pq/pull/608 // https://github.com/lib/pq/pull/788 // https://github.com/lib/pq/pull/833 package pgconn import ( "bytes" "crypto/hmac" "crypto/rand" "crypto/sha256" "encoding/base64" "errors" "fmt" "strconv" "github.com/jackc/pgproto3/v2" "golang.org/x/crypto/pbkdf2" "golang.org/x/text/secure/precis" ) const clientNonceLen = 18 // Perform SCRAM authentication. func (c *PgConn) scramAuth(serverAuthMechanisms []string) error { sc, err := newScramClient(serverAuthMechanisms, c.config.Password) if err != nil { return err } // Send client-first-message in a SASLInitialResponse saslInitialResponse := &pgproto3.SASLInitialResponse{ AuthMechanism: "SCRAM-SHA-256", Data: sc.clientFirstMessage(), } _, err = c.conn.Write(saslInitialResponse.Encode(nil)) if err != nil { return err } // Receive server-first-message payload in a AuthenticationSASLContinue. 
saslContinue, err := c.rxSASLContinue() if err != nil { return err } err = sc.recvServerFirstMessage(saslContinue.Data) if err != nil { return err } // Send client-final-message in a SASLResponse saslResponse := &pgproto3.SASLResponse{ Data: []byte(sc.clientFinalMessage()), } _, err = c.conn.Write(saslResponse.Encode(nil)) if err != nil { return err } // Receive server-final-message payload in a AuthenticationSASLFinal. saslFinal, err := c.rxSASLFinal() if err != nil { return err } return sc.recvServerFinalMessage(saslFinal.Data) } func (c *PgConn) rxSASLContinue() (*pgproto3.AuthenticationSASLContinue, error) { msg, err := c.receiveMessage() if err != nil { return nil, err } switch m := msg.(type) { case *pgproto3.AuthenticationSASLContinue: return m, nil case *pgproto3.ErrorResponse: return nil, ErrorResponseToPgError(m) } return nil, fmt.Errorf("expected AuthenticationSASLContinue message but received unexpected message %T", msg) } func (c *PgConn) rxSASLFinal() (*pgproto3.AuthenticationSASLFinal, error) { msg, err := c.receiveMessage() if err != nil { return nil, err } switch m := msg.(type) { case *pgproto3.AuthenticationSASLFinal: return m, nil case *pgproto3.ErrorResponse: return nil, ErrorResponseToPgError(m) } return nil, fmt.Errorf("expected AuthenticationSASLFinal message but received unexpected message %T", msg) } type scramClient struct { serverAuthMechanisms []string password []byte clientNonce []byte clientFirstMessageBare []byte serverFirstMessage []byte clientAndServerNonce []byte salt []byte iterations int saltedPassword []byte authMessage []byte } func newScramClient(serverAuthMechanisms []string, password string) (*scramClient, error) { sc := &scramClient{ serverAuthMechanisms: serverAuthMechanisms, } // Ensure server supports SCRAM-SHA-256 hasScramSHA256 := false for _, mech := range sc.serverAuthMechanisms { if mech == "SCRAM-SHA-256" { hasScramSHA256 = true break } } if !hasScramSHA256 { return nil, errors.New("server does not support 
SCRAM-SHA-256") } // precis.OpaqueString is equivalent to SASLprep for password. var err error sc.password, err = precis.OpaqueString.Bytes([]byte(password)) if err != nil { // PostgreSQL allows passwords invalid according to SCRAM / SASLprep. sc.password = []byte(password) } buf := make([]byte, clientNonceLen) _, err = rand.Read(buf) if err != nil { return nil, err } sc.clientNonce = make([]byte, base64.RawStdEncoding.EncodedLen(len(buf))) base64.RawStdEncoding.Encode(sc.clientNonce, buf) return sc, nil } func (sc *scramClient) clientFirstMessage() []byte { sc.clientFirstMessageBare = []byte(fmt.Sprintf("n=,r=%s", sc.clientNonce)) return []byte(fmt.Sprintf("n,,%s", sc.clientFirstMessageBare)) } func (sc *scramClient) recvServerFirstMessage(serverFirstMessage []byte) error { sc.serverFirstMessage = serverFirstMessage buf := serverFirstMessage if !bytes.HasPrefix(buf, []byte("r=")) { return errors.New("invalid SCRAM server-first-message received from server: did not include r=") } buf = buf[2:] idx := bytes.IndexByte(buf, ',') if idx == -1 { return errors.New("invalid SCRAM server-first-message received from server: did not include s=") } sc.clientAndServerNonce = buf[:idx] buf = buf[idx+1:] if !bytes.HasPrefix(buf, []byte("s=")) { return errors.New("invalid SCRAM server-first-message received from server: did not include s=") } buf = buf[2:] idx = bytes.IndexByte(buf, ',') if idx == -1 { return errors.New("invalid SCRAM server-first-message received from server: did not include i=") } saltStr := buf[:idx] buf = buf[idx+1:] if !bytes.HasPrefix(buf, []byte("i=")) { return errors.New("invalid SCRAM server-first-message received from server: did not include i=") } buf = buf[2:] iterationsStr := buf var err error sc.salt, err = base64.StdEncoding.DecodeString(string(saltStr)) if err != nil { return fmt.Errorf("invalid SCRAM salt received from server: %w", err) } sc.iterations, err = strconv.Atoi(string(iterationsStr)) if err != nil || sc.iterations <= 0 { return 
fmt.Errorf("invalid SCRAM iteration count received from server: %w", err) } if !bytes.HasPrefix(sc.clientAndServerNonce, sc.clientNonce) { return errors.New("invalid SCRAM nonce: did not start with client nonce") } if len(sc.clientAndServerNonce) <= len(sc.clientNonce) { return errors.New("invalid SCRAM nonce: did not include server nonce") } return nil } func (sc *scramClient) clientFinalMessage() string { clientFinalMessageWithoutProof := []byte(fmt.Sprintf("c=biws,r=%s", sc.clientAndServerNonce)) sc.saltedPassword = pbkdf2.Key([]byte(sc.password), sc.salt, sc.iterations, 32, sha256.New) sc.authMessage = bytes.Join([][]byte{sc.clientFirstMessageBare, sc.serverFirstMessage, clientFinalMessageWithoutProof}, []byte(",")) clientProof := computeClientProof(sc.saltedPassword, sc.authMessage) return fmt.Sprintf("%s,p=%s", clientFinalMessageWithoutProof, clientProof) } func (sc *scramClient) recvServerFinalMessage(serverFinalMessage []byte) error { if !bytes.HasPrefix(serverFinalMessage, []byte("v=")) { return errors.New("invalid SCRAM server-final-message received from server") } serverSignature := serverFinalMessage[2:] if !hmac.Equal(serverSignature, computeServerSignature(sc.saltedPassword, sc.authMessage)) { return errors.New("invalid SCRAM ServerSignature received from server") } return nil } func computeHMAC(key, msg []byte) []byte { mac := hmac.New(sha256.New, key) mac.Write(msg) return mac.Sum(nil) } func computeClientProof(saltedPassword, authMessage []byte) []byte { clientKey := computeHMAC(saltedPassword, []byte("Client Key")) storedKey := sha256.Sum256(clientKey) clientSignature := computeHMAC(storedKey[:], authMessage) clientProof := make([]byte, len(clientSignature)) for i := 0; i < len(clientSignature); i++ { clientProof[i] = clientKey[i] ^ clientSignature[i] } buf := make([]byte, base64.StdEncoding.EncodedLen(len(clientProof))) base64.StdEncoding.Encode(buf, clientProof) return buf } func computeServerSignature(saltedPassword []byte, authMessage []byte) 
[]byte { serverKey := computeHMAC(saltedPassword, []byte("Server Key")) serverSignature := computeHMAC(serverKey, authMessage) buf := make([]byte, base64.StdEncoding.EncodedLen(len(serverSignature))) base64.StdEncoding.Encode(buf, serverSignature) return buf } pgconn-1.14.0/benchmark_test.go000066400000000000000000000165751437172345200164270ustar00rootroot00000000000000package pgconn_test import ( "bytes" "context" "os" "strings" "testing" "github.com/jackc/pgconn" "github.com/stretchr/testify/require" ) func BenchmarkConnect(b *testing.B) { benchmarks := []struct { name string env string }{ {"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"}, {"TCP", "PGX_TEST_TCP_CONN_STRING"}, } for _, bm := range benchmarks { bm := bm b.Run(bm.name, func(b *testing.B) { connString := os.Getenv(bm.env) if connString == "" { b.Skipf("Skipping due to missing environment variable %v", bm.env) } for i := 0; i < b.N; i++ { conn, err := pgconn.Connect(context.Background(), connString) require.Nil(b, err) err = conn.Close(context.Background()) require.Nil(b, err) } }) } } func BenchmarkExec(b *testing.B) { expectedValues := [][]byte{[]byte("hello"), []byte("42"), []byte("2019-01-01")} benchmarks := []struct { name string ctx context.Context }{ // Using an empty context other than context.Background() to compare // performance {"background context", context.Background()}, {"empty context", context.TODO()}, } for _, bm := range benchmarks { bm := bm b.Run(bm.name, func(b *testing.B) { conn, err := pgconn.Connect(bm.ctx, os.Getenv("PGX_TEST_CONN_STRING")) require.Nil(b, err) defer closeConn(b, conn) b.ResetTimer() for i := 0; i < b.N; i++ { mrr := conn.Exec(bm.ctx, "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date") for mrr.NextResult() { rr := mrr.ResultReader() rowCount := 0 for rr.NextRow() { rowCount++ if len(rr.Values()) != len(expectedValues) { b.Fatalf("unexpected number of values: %d", len(rr.Values())) } for i := range rr.Values() { if !bytes.Equal(rr.Values()[i], 
expectedValues[i]) { b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i]) } } } _, err = rr.Close() if err != nil { b.Fatal(err) } if rowCount != 1 { b.Fatalf("unexpected rowCount: %d", rowCount) } } err := mrr.Close() if err != nil { b.Fatal(err) } } }) } } func BenchmarkExecPossibleToCancel(b *testing.B) { conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.Nil(b, err) defer closeConn(b, conn) expectedValues := [][]byte{[]byte("hello"), []byte("42"), []byte("2019-01-01")} b.ResetTimer() ctx, cancel := context.WithCancel(context.Background()) defer cancel() for i := 0; i < b.N; i++ { mrr := conn.Exec(ctx, "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date") for mrr.NextResult() { rr := mrr.ResultReader() rowCount := 0 for rr.NextRow() { rowCount++ if len(rr.Values()) != len(expectedValues) { b.Fatalf("unexpected number of values: %d", len(rr.Values())) } for i := range rr.Values() { if !bytes.Equal(rr.Values()[i], expectedValues[i]) { b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i]) } } } _, err = rr.Close() if err != nil { b.Fatal(err) } if rowCount != 1 { b.Fatalf("unexpected rowCount: %d", rowCount) } } err := mrr.Close() if err != nil { b.Fatal(err) } } } func BenchmarkExecPrepared(b *testing.B) { expectedValues := [][]byte{[]byte("hello"), []byte("42"), []byte("2019-01-01")} benchmarks := []struct { name string ctx context.Context }{ // Using an empty context other than context.Background() to compare // performance {"background context", context.Background()}, {"empty context", context.TODO()}, } for _, bm := range benchmarks { bm := bm b.Run(bm.name, func(b *testing.B) { conn, err := pgconn.Connect(bm.ctx, os.Getenv("PGX_TEST_CONN_STRING")) require.Nil(b, err) defer closeConn(b, conn) _, err = conn.Prepare(bm.ctx, "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil) require.Nil(b, err) b.ResetTimer() for i := 0; i < b.N; i++ { rr := 
conn.ExecPrepared(bm.ctx, "ps1", nil, nil, nil) rowCount := 0 for rr.NextRow() { rowCount++ if len(rr.Values()) != len(expectedValues) { b.Fatalf("unexpected number of values: %d", len(rr.Values())) } for i := range rr.Values() { if !bytes.Equal(rr.Values()[i], expectedValues[i]) { b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i]) } } } _, err = rr.Close() if err != nil { b.Fatal(err) } if rowCount != 1 { b.Fatalf("unexpected rowCount: %d", rowCount) } } }) } } func BenchmarkExecPreparedPossibleToCancel(b *testing.B) { conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.Nil(b, err) defer closeConn(b, conn) ctx, cancel := context.WithCancel(context.Background()) defer cancel() _, err = conn.Prepare(ctx, "ps1", "select 'hello'::text as a, 42::int4 as b, '2019-01-01'::date", nil) require.Nil(b, err) expectedValues := [][]byte{[]byte("hello"), []byte("42"), []byte("2019-01-01")} b.ResetTimer() for i := 0; i < b.N; i++ { rr := conn.ExecPrepared(ctx, "ps1", nil, nil, nil) rowCount := 0 for rr.NextRow() { rowCount += 1 if len(rr.Values()) != len(expectedValues) { b.Fatalf("unexpected number of values: %d", len(rr.Values())) } for i := range rr.Values() { if !bytes.Equal(rr.Values()[i], expectedValues[i]) { b.Fatalf("unexpected values: %s %s", rr.Values()[i], expectedValues[i]) } } } _, err = rr.Close() if err != nil { b.Fatal(err) } if rowCount != 1 { b.Fatalf("unexpected rowCount: %d", rowCount) } } } // func BenchmarkChanToSetDeadlinePossibleToCancel(b *testing.B) { // conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) // require.Nil(b, err) // defer closeConn(b, conn) // ctx, cancel := context.WithCancel(context.Background()) // defer cancel() // b.ResetTimer() // for i := 0; i < b.N; i++ { // conn.ChanToSetDeadline().Watch(ctx) // conn.ChanToSetDeadline().Ignore() // } // } func BenchmarkCommandTagRowsAffected(b *testing.B) { benchmarks := []struct { commandTag string 
rowsAffected int64 }{ {"UPDATE 1", 1}, {"UPDATE 123456789", 123456789}, {"INSERT 0 1", 1}, {"INSERT 0 123456789", 123456789}, } for _, bm := range benchmarks { ct := pgconn.CommandTag(bm.commandTag) b.Run(bm.commandTag, func(b *testing.B) { var n int64 for i := 0; i < b.N; i++ { n = ct.RowsAffected() } if n != bm.rowsAffected { b.Errorf("expected %d got %d", bm.rowsAffected, n) } }) } } func BenchmarkCommandTagTypeFromString(b *testing.B) { ct := pgconn.CommandTag("UPDATE 1") var update bool for i := 0; i < b.N; i++ { update = strings.HasPrefix(ct.String(), "UPDATE") } if !update { b.Error("expected update") } } func BenchmarkCommandTagInsert(b *testing.B) { benchmarks := []struct { commandTag string is bool }{ {"INSERT 1", true}, {"INSERT 1234567890", true}, {"UPDATE 1", false}, {"UPDATE 1234567890", false}, {"DELETE 1", false}, {"DELETE 1234567890", false}, {"SELECT 1", false}, {"SELECT 1234567890", false}, {"UNKNOWN 1234567890", false}, } for _, bm := range benchmarks { ct := pgconn.CommandTag(bm.commandTag) b.Run(bm.commandTag, func(b *testing.B) { var is bool for i := 0; i < b.N; i++ { is = ct.Insert() } if is != bm.is { b.Errorf("expected %v got %v", bm.is, is) } }) } } pgconn-1.14.0/ci/000077500000000000000000000000001437172345200134645ustar00rootroot00000000000000pgconn-1.14.0/ci/script.bash000077500000000000000000000002601437172345200156300ustar00rootroot00000000000000#!/usr/bin/env bash set -eux if [ "${PGVERSION-}" != "" ] then go test -v -race ./... 
elif [ "${CRATEVERSION-}" != "" ] then go test -v -race -run 'TestCrateDBConnect' fi pgconn-1.14.0/ci/setup_test.bash000077500000000000000000000064361437172345200165360ustar00rootroot00000000000000#!/usr/bin/env bash set -eux if [[ "${PGVERSION-}" =~ ^[0-9.]+$ ]] then sudo apt-get remove -y --purge postgresql libpq-dev libpq5 postgresql-client-common postgresql-common sudo rm -rf /var/lib/postgresql wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc | sudo apt-key add - sudo sh -c "echo deb http://apt.postgresql.org/pub/repos/apt/ $(lsb_release -cs)-pgdg main $PGVERSION >> /etc/apt/sources.list.d/postgresql.list" sudo apt-get update -qq sudo apt-get -y -o Dpkg::Options::=--force-confdef -o Dpkg::Options::="--force-confnew" install postgresql-$PGVERSION postgresql-server-dev-$PGVERSION postgresql-contrib-$PGVERSION sudo chmod 777 /etc/postgresql/$PGVERSION/main/pg_hba.conf echo "local all postgres trust" > /etc/postgresql/$PGVERSION/main/pg_hba.conf echo "local all all trust" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf echo "host all pgx_md5 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf echo "host all pgx_pw 127.0.0.1/32 password" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf echo "hostssl all pgx_ssl 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf echo "host replication pgx_replication 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf echo "host pgx_test pgx_replication 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf sudo chmod 777 /etc/postgresql/$PGVERSION/main/postgresql.conf if $(dpkg --compare-versions $PGVERSION ge 9.6) ; then echo "wal_level='logical'" >> /etc/postgresql/$PGVERSION/main/postgresql.conf echo "max_wal_senders=5" >> /etc/postgresql/$PGVERSION/main/postgresql.conf echo "max_replication_slots=5" >> /etc/postgresql/$PGVERSION/main/postgresql.conf fi sudo /etc/init.d/postgresql restart # The tricky test user, below, has to actually exist so that it can be used 
in a test # of aclitem formatting. It turns out aclitems cannot contain non-existing users/roles. psql -U postgres -c 'create database pgx_test' psql -U postgres pgx_test -c 'create extension hstore' psql -U postgres pgx_test -c 'create domain uint64 as numeric(20,0)' psql -U postgres -c "create user pgx_ssl SUPERUSER PASSWORD 'secret'" psql -U postgres -c "create user pgx_md5 SUPERUSER PASSWORD 'secret'" psql -U postgres -c "create user pgx_pw SUPERUSER PASSWORD 'secret'" psql -U postgres -c "create user `whoami`" psql -U postgres -c "create user pgx_replication with replication password 'secret'" psql -U postgres -c "create user \" tricky, ' } \"\" \\ test user \" superuser password 'secret'" fi if [[ "${PGVERSION-}" =~ ^cockroach ]] then wget -qO- https://binaries.cockroachdb.com/cockroach-v22.1.8.linux-amd64.tgz | tar xvz sudo mv cockroach-v22.1.8.linux-amd64/cockroach /usr/local/bin/ cockroach start-single-node --insecure --background --listen-addr=localhost cockroach sql --insecure -e 'create database pgx_test' fi if [ "${CRATEVERSION-}" != "" ] then docker run \ -p "6543:5432" \ -d \ crate:"$CRATEVERSION" \ crate \ -Cnetwork.host=0.0.0.0 \ -Ctransport.host=localhost \ -Clicense.enterprise=false fi pgconn-1.14.0/config.go000066400000000000000000000673551437172345200147050ustar00rootroot00000000000000package pgconn import ( "context" "crypto/tls" "crypto/x509" "encoding/pem" "errors" "fmt" "io" "io/ioutil" "math" "net" "net/url" "os" "path/filepath" "strconv" "strings" "time" "github.com/jackc/chunkreader/v2" "github.com/jackc/pgpassfile" "github.com/jackc/pgproto3/v2" "github.com/jackc/pgservicefile" ) type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error type GetSSLPasswordFunc func(ctx context.Context) string // Config is the settings used to establish a connection to a PostgreSQL server. It must be created by ParseConfig. 
A // manually initialized Config will cause ConnectConfig to panic. type Config struct { Host string // host (e.g. localhost) or absolute path to unix domain socket directory (e.g. /private/tmp) Port uint16 Database string User string Password string TLSConfig *tls.Config // nil disables TLS ConnectTimeout time.Duration DialFunc DialFunc // e.g. net.Dialer.DialContext LookupFunc LookupFunc // e.g. net.Resolver.LookupHost BuildFrontend BuildFrontendFunc RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) KerberosSrvName string KerberosSpn string Fallbacks []*FallbackConfig // ValidateConnect is called during a connection attempt after a successful authentication with the PostgreSQL server. // It can be used to validate that the server is acceptable. If this returns an error the connection is closed and the next // fallback config is tried. This allows implementing high availability behavior such as libpq does with target_session_attrs. ValidateConnect ValidateConnectFunc // AfterConnect is called after ValidateConnect. It can be used to set up the connection (e.g. Set session variables // or prepare statements). If this returns an error the connection attempt fails. AfterConnect AfterConnectFunc // OnNotice is a callback function called when a notice response is received. OnNotice NoticeHandler // OnNotification is a callback function called when a notification from the LISTEN/NOTIFY system is received. OnNotification NotificationHandler createdByParseConfig bool // Used to enforce created by ParseConfig rule. } // ParseConfigOptions contains options that control how a config is built such as getsslpassword. type ParseConfigOptions struct { // GetSSLPassword gets the password to decrypt a SSL client certificate. This is analogous to the the libpq function // PQsetSSLKeyPassHook_OpenSSL. 
GetSSLPassword GetSSLPasswordFunc } // Copy returns a deep copy of the config that is safe to use and modify. // The only exception is the TLSConfig field: // according to the tls.Config docs it must not be modified after creation. func (c *Config) Copy() *Config { newConf := new(Config) *newConf = *c if newConf.TLSConfig != nil { newConf.TLSConfig = c.TLSConfig.Clone() } if newConf.RuntimeParams != nil { newConf.RuntimeParams = make(map[string]string, len(c.RuntimeParams)) for k, v := range c.RuntimeParams { newConf.RuntimeParams[k] = v } } if newConf.Fallbacks != nil { newConf.Fallbacks = make([]*FallbackConfig, len(c.Fallbacks)) for i, fallback := range c.Fallbacks { newFallback := new(FallbackConfig) *newFallback = *fallback if newFallback.TLSConfig != nil { newFallback.TLSConfig = fallback.TLSConfig.Clone() } newConf.Fallbacks[i] = newFallback } } return newConf } // FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a // network connection. It is used for TLS fallback such as sslmode=prefer and high availability (HA) connections. type FallbackConfig struct { Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp) Port uint16 TLSConfig *tls.Config // nil disables TLS } // isAbsolutePath checks if the provided value is an absolute path either // beginning with a forward slash (as on Linux-based systems) or with a capital // letter A-Z followed by a colon and a backslash, e.g., "C:\", (as on Windows). func isAbsolutePath(path string) bool { isWindowsPath := func(p string) bool { if len(p) < 3 { return false } drive := p[0] colon := p[1] backslash := p[2] if drive >= 'A' && drive <= 'Z' && colon == ':' && backslash == '\\' { return true } return false } return strings.HasPrefix(path, "/") || isWindowsPath(path) } // NetworkAddress converts a PostgreSQL host and port into network and address suitable for use with // net.Dial. 
func NetworkAddress(host string, port uint16) (network, address string) { if isAbsolutePath(host) { network = "unix" address = filepath.Join(host, ".s.PGSQL.") + strconv.FormatInt(int64(port), 10) } else { network = "tcp" address = net.JoinHostPort(host, strconv.Itoa(int(port))) } return network, address } // ParseConfig builds a *Config from connString with similar behavior to the PostgreSQL standard C library libpq. It // uses the same defaults as libpq (e.g. port=5432) and understands most PG* environment variables. ParseConfig closely // matches the parsing behavior of libpq. connString may either be in URL format or keyword = value format (DSN style). // See https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING for details. connString also may be // empty to only read from the environment. If a password is not supplied it will attempt to read the .pgpass file. // // # Example DSN // user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca // // # Example URL // postgres://jack:secret@pg.example.com:5432/mydb?sslmode=verify-ca // // The returned *Config may be modified. However, it is strongly recommended that any configuration that can be done // through the connection string be done there. In particular the fields Host, Port, TLSConfig, and Fallbacks can be // interdependent (e.g. TLSConfig needs knowledge of the host to validate the server certificate). These fields should // not be modified individually. They should all be modified or all left unchanged. // // ParseConfig supports specifying multiple hosts in similar manner to libpq. Host and port may include comma separated // values that will be tried in order. This can be used as part of a high availability system. See // https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS for more information. 
// // # Example URL // postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb // // ParseConfig currently recognizes the following environment variable and their parameter key word equivalents passed // via database URL or DSN: // // PGHOST // PGPORT // PGDATABASE // PGUSER // PGPASSWORD // PGPASSFILE // PGSERVICE // PGSERVICEFILE // PGSSLMODE // PGSSLCERT // PGSSLKEY // PGSSLROOTCERT // PGSSLPASSWORD // PGAPPNAME // PGCONNECT_TIMEOUT // PGTARGETSESSIONATTRS // // See http://www.postgresql.org/docs/11/static/libpq-envars.html for details on the meaning of environment variables. // // See https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-PARAMKEYWORDS for parameter key word names. They are // usually but not always the environment variable name downcased and without the "PG" prefix. // // Important Security Notes: // // ParseConfig tries to match libpq behavior with regard to PGSSLMODE. This includes defaulting to "prefer" behavior if // not set. // // See http://www.postgresql.org/docs/11/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION for details on what level of // security each sslmode provides. // // The sslmode "prefer" (the default), sslmode "allow", and multiple hosts are implemented via the Fallbacks field of // the Config struct. If TLSConfig is manually changed it will not affect the fallbacks. For example, in the case of // sslmode "prefer" this means it will first try the main Config settings which use TLS, then it will try the fallback // which does not use TLS. This can lead to an unexpected unencrypted connection if the main TLS config is manually // changed later but the unencrypted fallback is present. Ensure there are no stale fallbacks when manually setting // TLSConfig. // // Other known differences with libpq: // // When multiple hosts are specified, libpq allows them to have different passwords set via the .pgpass file. pgconn // does not. 
// // In addition, ParseConfig accepts the following options: // // min_read_buffer_size // The minimum size of the internal read buffer. Default 8192. // servicefile // libpq only reads servicefile from the PGSERVICEFILE environment variable. ParseConfig accepts servicefile as a // part of the connection string. func ParseConfig(connString string) (*Config, error) { var parseConfigOptions ParseConfigOptions return ParseConfigWithOptions(connString, parseConfigOptions) } // ParseConfigWithOptions builds a *Config from connString and options with similar behavior to the PostgreSQL standard // C library libpq. options contains settings that cannot be specified in a connString such as providing a function to // get the SSL password. func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Config, error) { defaultSettings := defaultSettings() envSettings := parseEnvSettings() connStringSettings := make(map[string]string) if connString != "" { var err error // connString may be a database URL or a DSN if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") { connStringSettings, err = parseURLSettings(connString) if err != nil { return nil, &parseConfigError{connString: connString, msg: "failed to parse as URL", err: err} } } else { connStringSettings, err = parseDSNSettings(connString) if err != nil { return nil, &parseConfigError{connString: connString, msg: "failed to parse as DSN", err: err} } } } settings := mergeSettings(defaultSettings, envSettings, connStringSettings) if service, present := settings["service"]; present { serviceSettings, err := parseServiceSettings(settings["servicefile"], service) if err != nil { return nil, &parseConfigError{connString: connString, msg: "failed to read service", err: err} } settings = mergeSettings(defaultSettings, envSettings, serviceSettings, connStringSettings) } minReadBufferSize, err := strconv.ParseInt(settings["min_read_buffer_size"], 10, 32) if err != nil { 
return nil, &parseConfigError{connString: connString, msg: "cannot parse min_read_buffer_size", err: err} } config := &Config{ createdByParseConfig: true, Database: settings["database"], User: settings["user"], Password: settings["password"], RuntimeParams: make(map[string]string), BuildFrontend: makeDefaultBuildFrontendFunc(int(minReadBufferSize)), } if connectTimeoutSetting, present := settings["connect_timeout"]; present { connectTimeout, err := parseConnectTimeoutSetting(connectTimeoutSetting) if err != nil { return nil, &parseConfigError{connString: connString, msg: "invalid connect_timeout", err: err} } config.ConnectTimeout = connectTimeout config.DialFunc = makeConnectTimeoutDialFunc(connectTimeout) } else { defaultDialer := makeDefaultDialer() config.DialFunc = defaultDialer.DialContext } config.LookupFunc = makeDefaultResolver().LookupHost notRuntimeParams := map[string]struct{}{ "host": {}, "port": {}, "database": {}, "user": {}, "password": {}, "passfile": {}, "connect_timeout": {}, "sslmode": {}, "sslkey": {}, "sslcert": {}, "sslrootcert": {}, "sslpassword": {}, "sslsni": {}, "krbspn": {}, "krbsrvname": {}, "target_session_attrs": {}, "min_read_buffer_size": {}, "service": {}, "servicefile": {}, } // Adding kerberos configuration if _, present := settings["krbsrvname"]; present { config.KerberosSrvName = settings["krbsrvname"] } if _, present := settings["krbspn"]; present { config.KerberosSpn = settings["krbspn"] } for k, v := range settings { if _, present := notRuntimeParams[k]; present { continue } config.RuntimeParams[k] = v } fallbacks := []*FallbackConfig{} hosts := strings.Split(settings["host"], ",") ports := strings.Split(settings["port"], ",") for i, host := range hosts { var portStr string if i < len(ports) { portStr = ports[i] } else { portStr = ports[0] } port, err := parsePort(portStr) if err != nil { return nil, &parseConfigError{connString: connString, msg: "invalid port", err: err} } var tlsConfigs []*tls.Config // Ignore TLS settings 
if Unix domain socket like libpq if network, _ := NetworkAddress(host, port); network == "unix" { tlsConfigs = append(tlsConfigs, nil) } else { var err error tlsConfigs, err = configTLS(settings, host, options) if err != nil { return nil, &parseConfigError{connString: connString, msg: "failed to configure TLS", err: err} } } for _, tlsConfig := range tlsConfigs { fallbacks = append(fallbacks, &FallbackConfig{ Host: host, Port: port, TLSConfig: tlsConfig, }) } } config.Host = fallbacks[0].Host config.Port = fallbacks[0].Port config.TLSConfig = fallbacks[0].TLSConfig config.Fallbacks = fallbacks[1:] passfile, err := pgpassfile.ReadPassfile(settings["passfile"]) if err == nil { if config.Password == "" { host := config.Host if network, _ := NetworkAddress(config.Host, config.Port); network == "unix" { host = "localhost" } config.Password = passfile.FindPassword(host, strconv.Itoa(int(config.Port)), config.Database, config.User) } } switch tsa := settings["target_session_attrs"]; tsa { case "read-write": config.ValidateConnect = ValidateConnectTargetSessionAttrsReadWrite case "read-only": config.ValidateConnect = ValidateConnectTargetSessionAttrsReadOnly case "primary": config.ValidateConnect = ValidateConnectTargetSessionAttrsPrimary case "standby": config.ValidateConnect = ValidateConnectTargetSessionAttrsStandby case "prefer-standby": config.ValidateConnect = ValidateConnectTargetSessionAttrsPreferStandby case "any": // do nothing default: return nil, &parseConfigError{connString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", tsa)} } return config, nil } func mergeSettings(settingSets ...map[string]string) map[string]string { settings := make(map[string]string) for _, s2 := range settingSets { for k, v := range s2 { settings[k] = v } } return settings } func parseEnvSettings() map[string]string { settings := make(map[string]string) nameMap := map[string]string{ "PGHOST": "host", "PGPORT": "port", "PGDATABASE": "database", "PGUSER": "user", 
"PGPASSWORD": "password", "PGPASSFILE": "passfile", "PGAPPNAME": "application_name", "PGCONNECT_TIMEOUT": "connect_timeout", "PGSSLMODE": "sslmode", "PGSSLKEY": "sslkey", "PGSSLCERT": "sslcert", "PGSSLSNI": "sslsni", "PGSSLROOTCERT": "sslrootcert", "PGSSLPASSWORD": "sslpassword", "PGTARGETSESSIONATTRS": "target_session_attrs", "PGSERVICE": "service", "PGSERVICEFILE": "servicefile", } for envname, realname := range nameMap { value := os.Getenv(envname) if value != "" { settings[realname] = value } } return settings } func parseURLSettings(connString string) (map[string]string, error) { settings := make(map[string]string) url, err := url.Parse(connString) if err != nil { return nil, err } if url.User != nil { settings["user"] = url.User.Username() if password, present := url.User.Password(); present { settings["password"] = password } } // Handle multiple host:port's in url.Host by splitting them into host,host,host and port,port,port. var hosts []string var ports []string for _, host := range strings.Split(url.Host, ",") { if host == "" { continue } if isIPOnly(host) { hosts = append(hosts, strings.Trim(host, "[]")) continue } h, p, err := net.SplitHostPort(host) if err != nil { return nil, fmt.Errorf("failed to split host:port in '%s', err: %w", host, err) } if h != "" { hosts = append(hosts, h) } if p != "" { ports = append(ports, p) } } if len(hosts) > 0 { settings["host"] = strings.Join(hosts, ",") } if len(ports) > 0 { settings["port"] = strings.Join(ports, ",") } database := strings.TrimLeft(url.Path, "/") if database != "" { settings["database"] = database } nameMap := map[string]string{ "dbname": "database", } for k, v := range url.Query() { if k2, present := nameMap[k]; present { k = k2 } settings[k] = v[0] } return settings, nil } func isIPOnly(host string) bool { return net.ParseIP(strings.Trim(host, "[]")) != nil || !strings.Contains(host, ":") } var asciiSpace = [256]uint8{'\t': 1, '\n': 1, '\v': 1, '\f': 1, '\r': 1, ' ': 1} func parseDSNSettings(s 
string) (map[string]string, error) { settings := make(map[string]string) nameMap := map[string]string{ "dbname": "database", } for len(s) > 0 { var key, val string eqIdx := strings.IndexRune(s, '=') if eqIdx < 0 { return nil, errors.New("invalid dsn") } key = strings.Trim(s[:eqIdx], " \t\n\r\v\f") s = strings.TrimLeft(s[eqIdx+1:], " \t\n\r\v\f") if len(s) == 0 { } else if s[0] != '\'' { end := 0 for ; end < len(s); end++ { if asciiSpace[s[end]] == 1 { break } if s[end] == '\\' { end++ if end == len(s) { return nil, errors.New("invalid backslash") } } } val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1) if end == len(s) { s = "" } else { s = s[end+1:] } } else { // quoted string s = s[1:] end := 0 for ; end < len(s); end++ { if s[end] == '\'' { break } if s[end] == '\\' { end++ } } if end == len(s) { return nil, errors.New("unterminated quoted string in connection info string") } val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1) if end == len(s) { s = "" } else { s = s[end+1:] } } if k, ok := nameMap[key]; ok { key = k } if key == "" { return nil, errors.New("invalid dsn") } settings[key] = val } return settings, nil } func parseServiceSettings(servicefilePath, serviceName string) (map[string]string, error) { servicefile, err := pgservicefile.ReadServicefile(servicefilePath) if err != nil { return nil, fmt.Errorf("failed to read service file: %v", servicefilePath) } service, err := servicefile.GetService(serviceName) if err != nil { return nil, fmt.Errorf("unable to find service: %v", serviceName) } nameMap := map[string]string{ "dbname": "database", } settings := make(map[string]string, len(service.Settings)) for k, v := range service.Settings { if k2, present := nameMap[k]; present { k = k2 } settings[k] = v } return settings, nil } // configTLS uses libpq's TLS parameters to construct []*tls.Config. 
It is // necessary to allow returning multiple TLS configs as sslmode "allow" and // "prefer" allow fallback. func configTLS(settings map[string]string, thisHost string, parseConfigOptions ParseConfigOptions) ([]*tls.Config, error) { host := thisHost sslmode := settings["sslmode"] sslrootcert := settings["sslrootcert"] sslcert := settings["sslcert"] sslkey := settings["sslkey"] sslpassword := settings["sslpassword"] sslsni := settings["sslsni"] // Match libpq default behavior if sslmode == "" { sslmode = "prefer" } if sslsni == "" { sslsni = "1" } tlsConfig := &tls.Config{} switch sslmode { case "disable": return []*tls.Config{nil}, nil case "allow", "prefer": tlsConfig.InsecureSkipVerify = true case "require": // According to PostgreSQL documentation, if a root CA file exists, // the behavior of sslmode=require should be the same as that of verify-ca // // See https://www.postgresql.org/docs/12/libpq-ssl.html if sslrootcert != "" { goto nextCase } tlsConfig.InsecureSkipVerify = true break nextCase: fallthrough case "verify-ca": // Don't perform the default certificate verification because it // will verify the hostname. Instead, verify the server's // certificate chain ourselves in VerifyPeerCertificate and // ignore the server name. This emulates libpq's verify-ca // behavior. // // See https://github.com/golang/go/issues/21971#issuecomment-332693931 // and https://pkg.go.dev/crypto/tls?tab=doc#example-Config-VerifyPeerCertificate // for more info. tlsConfig.InsecureSkipVerify = true tlsConfig.VerifyPeerCertificate = func(certificates [][]byte, _ [][]*x509.Certificate) error { certs := make([]*x509.Certificate, len(certificates)) for i, asn1Data := range certificates { cert, err := x509.ParseCertificate(asn1Data) if err != nil { return errors.New("failed to parse certificate from server: " + err.Error()) } certs[i] = cert } // Leave DNSName empty to skip hostname verification. 
opts := x509.VerifyOptions{ Roots: tlsConfig.RootCAs, Intermediates: x509.NewCertPool(), } // Skip the first cert because it's the leaf. All others // are intermediates. for _, cert := range certs[1:] { opts.Intermediates.AddCert(cert) } _, err := certs[0].Verify(opts) return err } case "verify-full": tlsConfig.ServerName = host default: return nil, errors.New("sslmode is invalid") } if sslrootcert != "" { caCertPool := x509.NewCertPool() caPath := sslrootcert caCert, err := ioutil.ReadFile(caPath) if err != nil { return nil, fmt.Errorf("unable to read CA file: %w", err) } if !caCertPool.AppendCertsFromPEM(caCert) { return nil, errors.New("unable to add CA to cert pool") } tlsConfig.RootCAs = caCertPool tlsConfig.ClientCAs = caCertPool } if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") { return nil, errors.New(`both "sslcert" and "sslkey" are required`) } if sslcert != "" && sslkey != "" { buf, err := ioutil.ReadFile(sslkey) if err != nil { return nil, fmt.Errorf("unable to read sslkey: %w", err) } block, _ := pem.Decode(buf) var pemKey []byte var decryptedKey []byte var decryptedError error // If PEM is encrypted, attempt to decrypt using pass phrase if x509.IsEncryptedPEMBlock(block) { // Attempt decryption with pass phrase // NOTE: only supports RSA (PKCS#1) if sslpassword != "" { decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword)) } //if sslpassword not provided or has decryption error when use it //try to find sslpassword with callback function if sslpassword == "" || decryptedError != nil { if parseConfigOptions.GetSSLPassword != nil { sslpassword = parseConfigOptions.GetSSLPassword(context.Background()) } if sslpassword == "" { return nil, fmt.Errorf("unable to find sslpassword") } } decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword)) // Should we also provide warning for PKCS#1 needed? 
if decryptedError != nil { return nil, fmt.Errorf("unable to decrypt key: %w", err) } pemBytes := pem.Block{ Type: "RSA PRIVATE KEY", Bytes: decryptedKey, } pemKey = pem.EncodeToMemory(&pemBytes) } else { pemKey = pem.EncodeToMemory(block) } certfile, err := ioutil.ReadFile(sslcert) if err != nil { return nil, fmt.Errorf("unable to read cert: %w", err) } cert, err := tls.X509KeyPair(certfile, pemKey) if err != nil { return nil, fmt.Errorf("unable to load cert: %w", err) } tlsConfig.Certificates = []tls.Certificate{cert} } // Set Server Name Indication (SNI), if enabled by connection parameters. // Per RFC 6066, do not set it if the host is a literal IP address (IPv4 // or IPv6). if sslsni == "1" && net.ParseIP(host) == nil { tlsConfig.ServerName = host } switch sslmode { case "allow": return []*tls.Config{nil, tlsConfig}, nil case "prefer": return []*tls.Config{tlsConfig, nil}, nil case "require", "verify-ca", "verify-full": return []*tls.Config{tlsConfig}, nil default: panic("BUG: bad sslmode should already have been caught") } } func parsePort(s string) (uint16, error) { port, err := strconv.ParseUint(s, 10, 16) if err != nil { return 0, err } if port < 1 || port > math.MaxUint16 { return 0, errors.New("outside range") } return uint16(port), nil } func makeDefaultDialer() *net.Dialer { return &net.Dialer{KeepAlive: 5 * time.Minute} } func makeDefaultResolver() *net.Resolver { return net.DefaultResolver } func makeDefaultBuildFrontendFunc(minBufferLen int) BuildFrontendFunc { return func(r io.Reader, w io.Writer) Frontend { cr, err := chunkreader.NewConfig(r, chunkreader.Config{MinBufLen: minBufferLen}) if err != nil { panic(fmt.Sprintf("BUG: chunkreader.NewConfig failed: %v", err)) } frontend := pgproto3.NewFrontend(cr, w) return frontend } } func parseConnectTimeoutSetting(s string) (time.Duration, error) { timeout, err := strconv.ParseInt(s, 10, 64) if err != nil { return 0, err } if timeout < 0 { return 0, errors.New("negative timeout") } return 
time.Duration(timeout) * time.Second, nil } func makeConnectTimeoutDialFunc(timeout time.Duration) DialFunc { d := makeDefaultDialer() d.Timeout = timeout return d.DialContext } // ValidateConnectTargetSessionAttrsReadWrite is an ValidateConnectFunc that implements libpq compatible // target_session_attrs=read-write. func ValidateConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgConn) error { result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read() if result.Err != nil { return result.Err } if string(result.Rows[0][0]) == "on" { return errors.New("read only connection") } return nil } // ValidateConnectTargetSessionAttrsReadOnly is an ValidateConnectFunc that implements libpq compatible // target_session_attrs=read-only. func ValidateConnectTargetSessionAttrsReadOnly(ctx context.Context, pgConn *PgConn) error { result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read() if result.Err != nil { return result.Err } if string(result.Rows[0][0]) != "on" { return errors.New("connection is not read only") } return nil } // ValidateConnectTargetSessionAttrsStandby is an ValidateConnectFunc that implements libpq compatible // target_session_attrs=standby. func ValidateConnectTargetSessionAttrsStandby(ctx context.Context, pgConn *PgConn) error { result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read() if result.Err != nil { return result.Err } if string(result.Rows[0][0]) != "t" { return errors.New("server is not in hot standby mode") } return nil } // ValidateConnectTargetSessionAttrsPrimary is an ValidateConnectFunc that implements libpq compatible // target_session_attrs=primary. 
func ValidateConnectTargetSessionAttrsPrimary(ctx context.Context, pgConn *PgConn) error { result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read() if result.Err != nil { return result.Err } if string(result.Rows[0][0]) == "t" { return errors.New("server is in standby mode") } return nil } // ValidateConnectTargetSessionAttrsPreferStandby is an ValidateConnectFunc that implements libpq compatible // target_session_attrs=prefer-standby. func ValidateConnectTargetSessionAttrsPreferStandby(ctx context.Context, pgConn *PgConn) error { result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read() if result.Err != nil { return result.Err } if string(result.Rows[0][0]) != "t" { return &NotPreferredError{err: errors.New("server is not in hot standby mode")} } return nil } pgconn-1.14.0/config_test.go000066400000000000000000000767541437172345200157470ustar00rootroot00000000000000package pgconn_test import ( "context" "crypto/tls" "fmt" "io/ioutil" "os" "os/user" "runtime" "strings" "testing" "time" "github.com/jackc/pgconn" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestParseConfig(t *testing.T) { t.Parallel() var osUserName string osUser, err := user.Current() if err == nil { // Windows gives us the username here as `DOMAIN\user` or `LOCALPCNAME\user`, // but the libpq default is just the `user` portion, so we strip off the first part. 
if runtime.GOOS == "windows" && strings.Contains(osUser.Username, "\\") { osUserName = osUser.Username[strings.LastIndex(osUser.Username, "\\")+1:] } else { osUserName = osUser.Username } } config, err := pgconn.ParseConfig("") require.NoError(t, err) defaultHost := config.Host tests := []struct { name string connString string config *pgconn.Config }{ // Test all sslmodes { name: "sslmode not set (prefer)", connString: "postgres://jack:secret@localhost:5432/mydb", config: &pgconn.Config{ User: "jack", Password: "secret", Host: "localhost", Port: 5432, Database: "mydb", TLSConfig: &tls.Config{ InsecureSkipVerify: true, ServerName: "localhost", }, RuntimeParams: map[string]string{}, Fallbacks: []*pgconn.FallbackConfig{ &pgconn.FallbackConfig{ Host: "localhost", Port: 5432, TLSConfig: nil, }, }, }, }, { name: "sslmode disable", connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable", config: &pgconn.Config{ User: "jack", Password: "secret", Host: "localhost", Port: 5432, Database: "mydb", TLSConfig: nil, RuntimeParams: map[string]string{}, }, }, { name: "sslmode allow", connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=allow", config: &pgconn.Config{ User: "jack", Password: "secret", Host: "localhost", Port: 5432, Database: "mydb", TLSConfig: nil, RuntimeParams: map[string]string{}, Fallbacks: []*pgconn.FallbackConfig{ &pgconn.FallbackConfig{ Host: "localhost", Port: 5432, TLSConfig: &tls.Config{ InsecureSkipVerify: true, ServerName: "localhost", }, }, }, }, }, { name: "sslmode prefer", connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=prefer", config: &pgconn.Config{ User: "jack", Password: "secret", Host: "localhost", Port: 5432, Database: "mydb", TLSConfig: &tls.Config{ InsecureSkipVerify: true, ServerName: "localhost", }, RuntimeParams: map[string]string{}, Fallbacks: []*pgconn.FallbackConfig{ &pgconn.FallbackConfig{ Host: "localhost", Port: 5432, TLSConfig: nil, }, }, }, }, { name: "sslmode require", connString: 
"postgres://jack:secret@localhost:5432/mydb?sslmode=require", config: &pgconn.Config{ User: "jack", Password: "secret", Host: "localhost", Port: 5432, Database: "mydb", TLSConfig: &tls.Config{ InsecureSkipVerify: true, ServerName: "localhost", }, RuntimeParams: map[string]string{}, }, }, { name: "sslmode verify-ca", connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=verify-ca", config: &pgconn.Config{ User: "jack", Password: "secret", Host: "localhost", Port: 5432, Database: "mydb", TLSConfig: &tls.Config{ InsecureSkipVerify: true, ServerName: "localhost", }, RuntimeParams: map[string]string{}, }, }, { name: "sslmode verify-full", connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=verify-full", config: &pgconn.Config{ User: "jack", Password: "secret", Host: "localhost", Port: 5432, Database: "mydb", TLSConfig: &tls.Config{ServerName: "localhost"}, RuntimeParams: map[string]string{}, }, }, { name: "database url everything", connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&application_name=pgxtest&search_path=myschema&connect_timeout=5", config: &pgconn.Config{ User: "jack", Password: "secret", Host: "localhost", Port: 5432, Database: "mydb", TLSConfig: nil, ConnectTimeout: 5 * time.Second, RuntimeParams: map[string]string{ "application_name": "pgxtest", "search_path": "myschema", }, }, }, { name: "database url missing password", connString: "postgres://jack@localhost:5432/mydb?sslmode=disable", config: &pgconn.Config{ User: "jack", Host: "localhost", Port: 5432, Database: "mydb", TLSConfig: nil, RuntimeParams: map[string]string{}, }, }, { name: "database url missing user and password", connString: "postgres://localhost:5432/mydb?sslmode=disable", config: &pgconn.Config{ User: osUserName, Host: "localhost", Port: 5432, Database: "mydb", TLSConfig: nil, RuntimeParams: map[string]string{}, }, }, { name: "database url missing port", connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable", config: 
&pgconn.Config{ User: "jack", Password: "secret", Host: "localhost", Port: 5432, Database: "mydb", TLSConfig: nil, RuntimeParams: map[string]string{}, }, }, { name: "database url unix domain socket host", connString: "postgres:///foo?host=/tmp", config: &pgconn.Config{ User: osUserName, Host: "/tmp", Port: 5432, Database: "foo", TLSConfig: nil, RuntimeParams: map[string]string{}, }, }, { name: "database url unix domain socket host on windows", connString: "postgres:///foo?host=C:\\tmp", config: &pgconn.Config{ User: osUserName, Host: "C:\\tmp", Port: 5432, Database: "foo", TLSConfig: nil, RuntimeParams: map[string]string{}, }, }, { name: "database url dbname", connString: "postgres://localhost/?dbname=foo&sslmode=disable", config: &pgconn.Config{ User: osUserName, Host: "localhost", Port: 5432, Database: "foo", TLSConfig: nil, RuntimeParams: map[string]string{}, }, }, { name: "database url postgresql protocol", connString: "postgresql://jack@localhost:5432/mydb?sslmode=disable", config: &pgconn.Config{ User: "jack", Host: "localhost", Port: 5432, Database: "mydb", TLSConfig: nil, RuntimeParams: map[string]string{}, }, }, { name: "database url IPv4 with port", connString: "postgresql://jack@127.0.0.1:5433/mydb?sslmode=disable", config: &pgconn.Config{ User: "jack", Host: "127.0.0.1", Port: 5433, Database: "mydb", TLSConfig: nil, RuntimeParams: map[string]string{}, }, }, { name: "database url IPv6 with port", connString: "postgresql://jack@[2001:db8::1]:5433/mydb?sslmode=disable", config: &pgconn.Config{ User: "jack", Host: "2001:db8::1", Port: 5433, Database: "mydb", TLSConfig: nil, RuntimeParams: map[string]string{}, }, }, { name: "database url IPv6 no port", connString: "postgresql://jack@[2001:db8::1]/mydb?sslmode=disable", config: &pgconn.Config{ User: "jack", Host: "2001:db8::1", Port: 5432, Database: "mydb", TLSConfig: nil, RuntimeParams: map[string]string{}, }, }, { name: "DSN everything", connString: "user=jack password=secret host=localhost port=5432 
dbname=mydb sslmode=disable application_name=pgxtest search_path=myschema connect_timeout=5", config: &pgconn.Config{ User: "jack", Password: "secret", Host: "localhost", Port: 5432, Database: "mydb", TLSConfig: nil, ConnectTimeout: 5 * time.Second, RuntimeParams: map[string]string{ "application_name": "pgxtest", "search_path": "myschema", }, }, }, { name: "DSN with escaped single quote", connString: "user=jack\\'s password=secret host=localhost port=5432 dbname=mydb sslmode=disable", config: &pgconn.Config{ User: "jack's", Password: "secret", Host: "localhost", Port: 5432, Database: "mydb", TLSConfig: nil, RuntimeParams: map[string]string{}, }, }, { name: "DSN with escaped backslash", connString: "user=jack password=sooper\\\\secret host=localhost port=5432 dbname=mydb sslmode=disable", config: &pgconn.Config{ User: "jack", Password: "sooper\\secret", Host: "localhost", Port: 5432, Database: "mydb", TLSConfig: nil, RuntimeParams: map[string]string{}, }, }, { name: "DSN with single quoted values", connString: "user='jack' host='localhost' dbname='mydb' sslmode='disable'", config: &pgconn.Config{ User: "jack", Host: "localhost", Port: 5432, Database: "mydb", TLSConfig: nil, RuntimeParams: map[string]string{}, }, }, { name: "DSN with single quoted value with escaped single quote", connString: "user='jack\\'s' host='localhost' dbname='mydb' sslmode='disable'", config: &pgconn.Config{ User: "jack's", Host: "localhost", Port: 5432, Database: "mydb", TLSConfig: nil, RuntimeParams: map[string]string{}, }, }, { name: "DSN with empty single quoted value", connString: "user='jack' password='' host='localhost' dbname='mydb' sslmode='disable'", config: &pgconn.Config{ User: "jack", Host: "localhost", Port: 5432, Database: "mydb", TLSConfig: nil, RuntimeParams: map[string]string{}, }, }, { name: "DSN with space between key and value", connString: "user = 'jack' password = '' host = 'localhost' dbname = 'mydb' sslmode='disable'", config: &pgconn.Config{ User: "jack", Host: 
"localhost", Port: 5432, Database: "mydb", TLSConfig: nil, RuntimeParams: map[string]string{}, }, }, { name: "URL multiple hosts", connString: "postgres://jack:secret@foo,bar,baz/mydb?sslmode=disable", config: &pgconn.Config{ User: "jack", Password: "secret", Host: "foo", Port: 5432, Database: "mydb", TLSConfig: nil, RuntimeParams: map[string]string{}, Fallbacks: []*pgconn.FallbackConfig{ &pgconn.FallbackConfig{ Host: "bar", Port: 5432, TLSConfig: nil, }, &pgconn.FallbackConfig{ Host: "baz", Port: 5432, TLSConfig: nil, }, }, }, }, { name: "URL multiple hosts and ports", connString: "postgres://jack:secret@foo:1,bar:2,baz:3/mydb?sslmode=disable", config: &pgconn.Config{ User: "jack", Password: "secret", Host: "foo", Port: 1, Database: "mydb", TLSConfig: nil, RuntimeParams: map[string]string{}, Fallbacks: []*pgconn.FallbackConfig{ &pgconn.FallbackConfig{ Host: "bar", Port: 2, TLSConfig: nil, }, &pgconn.FallbackConfig{ Host: "baz", Port: 3, TLSConfig: nil, }, }, }, }, // https://github.com/jackc/pgconn/issues/72 { name: "URL without host but with port still uses default host", connString: "postgres://jack:secret@:1/mydb?sslmode=disable", config: &pgconn.Config{ User: "jack", Password: "secret", Host: defaultHost, Port: 1, Database: "mydb", TLSConfig: nil, RuntimeParams: map[string]string{}, }, }, { name: "DSN multiple hosts one port", connString: "user=jack password=secret host=foo,bar,baz port=5432 dbname=mydb sslmode=disable", config: &pgconn.Config{ User: "jack", Password: "secret", Host: "foo", Port: 5432, Database: "mydb", TLSConfig: nil, RuntimeParams: map[string]string{}, Fallbacks: []*pgconn.FallbackConfig{ &pgconn.FallbackConfig{ Host: "bar", Port: 5432, TLSConfig: nil, }, &pgconn.FallbackConfig{ Host: "baz", Port: 5432, TLSConfig: nil, }, }, }, }, { name: "DSN multiple hosts multiple ports", connString: "user=jack password=secret host=foo,bar,baz port=1,2,3 dbname=mydb sslmode=disable", config: &pgconn.Config{ User: "jack", Password: "secret", Host: "foo", 
Port: 1, Database: "mydb", TLSConfig: nil, RuntimeParams: map[string]string{}, Fallbacks: []*pgconn.FallbackConfig{ &pgconn.FallbackConfig{ Host: "bar", Port: 2, TLSConfig: nil, }, &pgconn.FallbackConfig{ Host: "baz", Port: 3, TLSConfig: nil, }, }, }, }, { name: "multiple hosts and fallback tsl", connString: "user=jack password=secret host=foo,bar,baz dbname=mydb sslmode=prefer", config: &pgconn.Config{ User: "jack", Password: "secret", Host: "foo", Port: 5432, Database: "mydb", TLSConfig: &tls.Config{ InsecureSkipVerify: true, ServerName: "foo", }, RuntimeParams: map[string]string{}, Fallbacks: []*pgconn.FallbackConfig{ &pgconn.FallbackConfig{ Host: "foo", Port: 5432, TLSConfig: nil, }, &pgconn.FallbackConfig{ Host: "bar", Port: 5432, TLSConfig: &tls.Config{ InsecureSkipVerify: true, ServerName: "bar", }}, &pgconn.FallbackConfig{ Host: "bar", Port: 5432, TLSConfig: nil, }, &pgconn.FallbackConfig{ Host: "baz", Port: 5432, TLSConfig: &tls.Config{ InsecureSkipVerify: true, ServerName: "baz", }}, &pgconn.FallbackConfig{ Host: "baz", Port: 5432, TLSConfig: nil, }, }, }, }, { name: "target_session_attrs read-write", connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=read-write", config: &pgconn.Config{ User: "jack", Password: "secret", Host: "localhost", Port: 5432, Database: "mydb", TLSConfig: nil, RuntimeParams: map[string]string{}, ValidateConnect: pgconn.ValidateConnectTargetSessionAttrsReadWrite, }, }, { name: "target_session_attrs read-only", connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=read-only", config: &pgconn.Config{ User: "jack", Password: "secret", Host: "localhost", Port: 5432, Database: "mydb", TLSConfig: nil, RuntimeParams: map[string]string{}, ValidateConnect: pgconn.ValidateConnectTargetSessionAttrsReadOnly, }, }, { name: "target_session_attrs primary", connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=primary", config: 
&pgconn.Config{ User: "jack", Password: "secret", Host: "localhost", Port: 5432, Database: "mydb", TLSConfig: nil, RuntimeParams: map[string]string{}, ValidateConnect: pgconn.ValidateConnectTargetSessionAttrsPrimary, }, }, { name: "target_session_attrs standby", connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=standby", config: &pgconn.Config{ User: "jack", Password: "secret", Host: "localhost", Port: 5432, Database: "mydb", TLSConfig: nil, RuntimeParams: map[string]string{}, ValidateConnect: pgconn.ValidateConnectTargetSessionAttrsStandby, }, }, { name: "target_session_attrs prefer-standby", connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=prefer-standby", config: &pgconn.Config{ User: "jack", Password: "secret", Host: "localhost", Port: 5432, Database: "mydb", TLSConfig: nil, RuntimeParams: map[string]string{}, ValidateConnect: pgconn.ValidateConnectTargetSessionAttrsPreferStandby, }, }, { name: "target_session_attrs any", connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=any", config: &pgconn.Config{ User: "jack", Password: "secret", Host: "localhost", Port: 5432, Database: "mydb", TLSConfig: nil, RuntimeParams: map[string]string{}, }, }, { name: "target_session_attrs not set (any)", connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable", config: &pgconn.Config{ User: "jack", Password: "secret", Host: "localhost", Port: 5432, Database: "mydb", TLSConfig: nil, RuntimeParams: map[string]string{}, }, }, { name: "SNI is set by default", connString: "postgres://jack:secret@sni.test:5432/mydb?sslmode=require", config: &pgconn.Config{ User: "jack", Password: "secret", Host: "sni.test", Port: 5432, Database: "mydb", TLSConfig: &tls.Config{ InsecureSkipVerify: true, ServerName: "sni.test", }, RuntimeParams: map[string]string{}, }, }, { name: "SNI is not set for IPv4", connString: 
"postgres://jack:secret@1.1.1.1:5432/mydb?sslmode=require", config: &pgconn.Config{ User: "jack", Password: "secret", Host: "1.1.1.1", Port: 5432, Database: "mydb", TLSConfig: &tls.Config{ InsecureSkipVerify: true, }, RuntimeParams: map[string]string{}, }, }, { name: "SNI is not set for IPv6", connString: "postgres://jack:secret@[::1]:5432/mydb?sslmode=require", config: &pgconn.Config{ User: "jack", Password: "secret", Host: "::1", Port: 5432, Database: "mydb", TLSConfig: &tls.Config{ InsecureSkipVerify: true, }, RuntimeParams: map[string]string{}, }, }, { name: "SNI is not set when disabled (URL-style)", connString: "postgres://jack:secret@sni.test:5432/mydb?sslmode=require&sslsni=0", config: &pgconn.Config{ User: "jack", Password: "secret", Host: "sni.test", Port: 5432, Database: "mydb", TLSConfig: &tls.Config{ InsecureSkipVerify: true, }, RuntimeParams: map[string]string{}, }, }, { name: "SNI is not set when disabled (key/value style)", connString: "user=jack password=secret host=sni.test dbname=mydb sslmode=require sslsni=0", config: &pgconn.Config{ User: "jack", Password: "secret", Host: "sni.test", Port: 5432, Database: "mydb", TLSConfig: &tls.Config{ InsecureSkipVerify: true, }, RuntimeParams: map[string]string{}, }, }, } for i, tt := range tests { config, err := pgconn.ParseConfig(tt.connString) if !assert.Nilf(t, err, "Test %d (%s)", i, tt.name) { continue } assertConfigsEqual(t, tt.config, config, fmt.Sprintf("Test %d (%s)", i, tt.name)) } } // https://github.com/jackc/pgconn/issues/47 func TestParseConfigDSNWithTrailingEmptyEqualDoesNotPanic(t *testing.T) { _, err := pgconn.ParseConfig("host= user= password= port= database=") require.NoError(t, err) } func TestParseConfigDSNLeadingEqual(t *testing.T) { _, err := pgconn.ParseConfig("= user=jack") require.Error(t, err) } // https://github.com/jackc/pgconn/issues/49 func TestParseConfigDSNTrailingBackslash(t *testing.T) { _, err := pgconn.ParseConfig(`x=x\`) require.Error(t, err) assert.Contains(t, 
err.Error(), "invalid backslash") } func TestConfigCopyReturnsEqualConfig(t *testing.T) { connString := "postgres://jack:secret@localhost:5432/mydb?application_name=pgxtest&search_path=myschema&connect_timeout=5" original, err := pgconn.ParseConfig(connString) require.NoError(t, err) copied := original.Copy() assertConfigsEqual(t, original, copied, "Test Config.Copy() returns equal config") } func TestConfigCopyOriginalConfigDidNotChange(t *testing.T) { connString := "postgres://jack:secret@localhost:5432/mydb?application_name=pgxtest&search_path=myschema&connect_timeout=5&sslmode=prefer" original, err := pgconn.ParseConfig(connString) require.NoError(t, err) copied := original.Copy() assertConfigsEqual(t, original, copied, "Test Config.Copy() returns equal config") copied.Port = uint16(5433) copied.RuntimeParams["foo"] = "bar" copied.Fallbacks[0].Port = uint16(5433) assert.Equal(t, uint16(5432), original.Port) assert.Equal(t, "", original.RuntimeParams["foo"]) assert.Equal(t, uint16(5432), original.Fallbacks[0].Port) } func TestConfigCopyCanBeUsedToConnect(t *testing.T) { connString := os.Getenv("PGX_TEST_CONN_STRING") original, err := pgconn.ParseConfig(connString) require.NoError(t, err) copied := original.Copy() assert.NotPanics(t, func() { _, err = pgconn.ConnectConfig(context.Background(), copied) }) assert.NoError(t, err) } func TestNetworkAddress(t *testing.T) { tests := []struct { name string host string wantNet string }{ { name: "Default Unix socket address", host: "/var/run/postgresql", wantNet: "unix", }, { name: "Windows Unix socket address (standard drive name)", host: "C:\\tmp", wantNet: "unix", }, { name: "Windows Unix socket address (first drive name)", host: "A:\\tmp", wantNet: "unix", }, { name: "Windows Unix socket address (last drive name)", host: "Z:\\tmp", wantNet: "unix", }, { name: "Assume TCP for unknown formats", host: "a/tmp", wantNet: "tcp", }, { name: "loopback interface", host: "localhost", wantNet: "tcp", }, { name: "IP address", 
host: "127.0.0.1", wantNet: "tcp", }, } for i, tt := range tests { gotNet, _ := pgconn.NetworkAddress(tt.host, 5432) assert.Equalf(t, tt.wantNet, gotNet, "Test %d (%s)", i, tt.name) } } func assertConfigsEqual(t *testing.T, expected, actual *pgconn.Config, testName string) { if !assert.NotNil(t, expected) { return } if !assert.NotNil(t, actual) { return } assert.Equalf(t, expected.Host, actual.Host, "%s - Host", testName) assert.Equalf(t, expected.Database, actual.Database, "%s - Database", testName) assert.Equalf(t, expected.Port, actual.Port, "%s - Port", testName) assert.Equalf(t, expected.User, actual.User, "%s - User", testName) assert.Equalf(t, expected.Password, actual.Password, "%s - Password", testName) assert.Equalf(t, expected.ConnectTimeout, actual.ConnectTimeout, "%s - ConnectTimeout", testName) assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName) // Can't test function equality, so just test that they are set or not. assert.Equalf(t, expected.ValidateConnect == nil, actual.ValidateConnect == nil, "%s - ValidateConnect", testName) assert.Equalf(t, expected.AfterConnect == nil, actual.AfterConnect == nil, "%s - AfterConnect", testName) if assert.Equalf(t, expected.TLSConfig == nil, actual.TLSConfig == nil, "%s - TLSConfig", testName) { if expected.TLSConfig != nil { assert.Equalf(t, expected.TLSConfig.InsecureSkipVerify, actual.TLSConfig.InsecureSkipVerify, "%s - TLSConfig InsecureSkipVerify", testName) assert.Equalf(t, expected.TLSConfig.ServerName, actual.TLSConfig.ServerName, "%s - TLSConfig ServerName", testName) } } if assert.Equalf(t, len(expected.Fallbacks), len(actual.Fallbacks), "%s - Fallbacks", testName) { for i := range expected.Fallbacks { assert.Equalf(t, expected.Fallbacks[i].Host, actual.Fallbacks[i].Host, "%s - Fallback %d - Host", testName, i) assert.Equalf(t, expected.Fallbacks[i].Port, actual.Fallbacks[i].Port, "%s - Fallback %d - Port", testName, i) if assert.Equalf(t, 
expected.Fallbacks[i].TLSConfig == nil, actual.Fallbacks[i].TLSConfig == nil, "%s - Fallback %d - TLSConfig", testName, i) { if expected.Fallbacks[i].TLSConfig != nil { assert.Equalf(t, expected.Fallbacks[i].TLSConfig.InsecureSkipVerify, actual.Fallbacks[i].TLSConfig.InsecureSkipVerify, "%s - Fallback %d - TLSConfig InsecureSkipVerify", testName) assert.Equalf(t, expected.Fallbacks[i].TLSConfig.ServerName, actual.Fallbacks[i].TLSConfig.ServerName, "%s - Fallback %d - TLSConfig ServerName", testName) } } } } } func TestParseConfigEnvLibpq(t *testing.T) { var osUserName string osUser, err := user.Current() if err == nil { // Windows gives us the username here as `DOMAIN\user` or `LOCALPCNAME\user`, // but the libpq default is just the `user` portion, so we strip off the first part. if runtime.GOOS == "windows" && strings.Contains(osUser.Username, "\\") { osUserName = osUser.Username[strings.LastIndex(osUser.Username, "\\")+1:] } else { osUserName = osUser.Username } } pgEnvvars := []string{"PGHOST", "PGPORT", "PGDATABASE", "PGUSER", "PGPASSWORD", "PGAPPNAME", "PGSSLMODE", "PGCONNECT_TIMEOUT", "PGSSLSNI"} savedEnv := make(map[string]string) for _, n := range pgEnvvars { savedEnv[n] = os.Getenv(n) } defer func() { for k, v := range savedEnv { err := os.Setenv(k, v) if err != nil { t.Fatalf("Unable to restore environment: %v", err) } } }() tests := []struct { name string envvars map[string]string config *pgconn.Config }{ { // not testing no environment at all as that would use default host and that can vary. 
name: "PGHOST only", envvars: map[string]string{"PGHOST": "123.123.123.123"}, config: &pgconn.Config{ User: osUserName, Host: "123.123.123.123", Port: 5432, TLSConfig: &tls.Config{ InsecureSkipVerify: true, }, RuntimeParams: map[string]string{}, Fallbacks: []*pgconn.FallbackConfig{ &pgconn.FallbackConfig{ Host: "123.123.123.123", Port: 5432, TLSConfig: nil, }, }, }, }, { name: "All non-TLS environment", envvars: map[string]string{ "PGHOST": "123.123.123.123", "PGPORT": "7777", "PGDATABASE": "foo", "PGUSER": "bar", "PGPASSWORD": "baz", "PGCONNECT_TIMEOUT": "10", "PGSSLMODE": "disable", "PGAPPNAME": "pgxtest", }, config: &pgconn.Config{ Host: "123.123.123.123", Port: 7777, Database: "foo", User: "bar", Password: "baz", ConnectTimeout: 10 * time.Second, TLSConfig: nil, RuntimeParams: map[string]string{"application_name": "pgxtest"}, }, }, { name: "SNI can be disabled via environment variable", envvars: map[string]string{ "PGHOST": "test.foo", "PGSSLMODE": "require", "PGSSLSNI": "0", }, config: &pgconn.Config{ User: osUserName, Host: "test.foo", Port: 5432, TLSConfig: &tls.Config{ InsecureSkipVerify: true, }, RuntimeParams: map[string]string{}, }, }, } for i, tt := range tests { for _, n := range pgEnvvars { err := os.Unsetenv(n) require.NoError(t, err) } for k, v := range tt.envvars { err := os.Setenv(k, v) require.NoError(t, err) } config, err := pgconn.ParseConfig("") if !assert.Nilf(t, err, "Test %d (%s)", i, tt.name) { continue } assertConfigsEqual(t, tt.config, config, fmt.Sprintf("Test %d (%s)", i, tt.name)) } } func TestParseConfigReadsPgPassfile(t *testing.T) { t.Parallel() tf, err := ioutil.TempFile("", "") require.NoError(t, err) defer tf.Close() defer os.Remove(tf.Name()) _, err = tf.Write([]byte("test1:5432:curlydb:curly:nyuknyuknyuk")) require.NoError(t, err) connString := fmt.Sprintf("postgres://curly@test1:5432/curlydb?sslmode=disable&passfile=%s", tf.Name()) expected := &pgconn.Config{ User: "curly", Password: "nyuknyuknyuk", Host: "test1", Port: 5432, 
Database: "curlydb", TLSConfig: nil, RuntimeParams: map[string]string{}, } actual, err := pgconn.ParseConfig(connString) assert.NoError(t, err) assertConfigsEqual(t, expected, actual, "passfile") } func TestParseConfigReadsPgServiceFile(t *testing.T) { t.Parallel() tf, err := ioutil.TempFile("", "") require.NoError(t, err) defer tf.Close() defer os.Remove(tf.Name()) _, err = tf.Write([]byte(` [abc] host=abc.example.com port=9999 dbname=abcdb user=abcuser [def] host = def.example.com dbname = defdb user = defuser application_name = spaced string `)) require.NoError(t, err) tests := []struct { name string connString string config *pgconn.Config }{ { name: "abc", connString: fmt.Sprintf("postgres:///?servicefile=%s&service=%s", tf.Name(), "abc"), config: &pgconn.Config{ Host: "abc.example.com", Database: "abcdb", User: "abcuser", Port: 9999, TLSConfig: &tls.Config{ InsecureSkipVerify: true, ServerName: "abc.example.com", }, RuntimeParams: map[string]string{}, Fallbacks: []*pgconn.FallbackConfig{ &pgconn.FallbackConfig{ Host: "abc.example.com", Port: 9999, TLSConfig: nil, }, }, }, }, { name: "def", connString: fmt.Sprintf("postgres:///?servicefile=%s&service=%s", tf.Name(), "def"), config: &pgconn.Config{ Host: "def.example.com", Port: 5432, Database: "defdb", User: "defuser", TLSConfig: &tls.Config{ InsecureSkipVerify: true, ServerName: "def.example.com", }, RuntimeParams: map[string]string{"application_name": "spaced string"}, Fallbacks: []*pgconn.FallbackConfig{ &pgconn.FallbackConfig{ Host: "def.example.com", Port: 5432, TLSConfig: nil, }, }, }, }, { name: "conn string has precedence", connString: fmt.Sprintf("postgres://other.example.com:7777/?servicefile=%s&service=%s&sslmode=disable", tf.Name(), "abc"), config: &pgconn.Config{ Host: "other.example.com", Database: "abcdb", User: "abcuser", Port: 7777, TLSConfig: nil, RuntimeParams: map[string]string{}, }, }, } for i, tt := range tests { config, err := pgconn.ParseConfig(tt.connString) if !assert.NoErrorf(t, err, 
"Test %d (%s)", i, tt.name) { continue } assertConfigsEqual(t, tt.config, config, fmt.Sprintf("Test %d (%s)", i, tt.name)) } } func TestParseConfigExtractsMinReadBufferSize(t *testing.T) { t.Parallel() config, err := pgconn.ParseConfig("min_read_buffer_size=0") require.NoError(t, err) _, present := config.RuntimeParams["min_read_buffer_size"] require.False(t, present) // The buffer size is internal so there isn't much that can be done to test it other than see that the runtime param // was removed. } pgconn-1.14.0/defaults.go000066400000000000000000000036271437172345200152370ustar00rootroot00000000000000//go:build !windows // +build !windows package pgconn import ( "os" "os/user" "path/filepath" ) func defaultSettings() map[string]string { settings := make(map[string]string) settings["host"] = defaultHost() settings["port"] = "5432" // Default to the OS user name. Purposely ignoring err getting user name from // OS. The client application will simply have to specify the user in that // case (which they typically will be doing anyway). user, err := user.Current() if err == nil { settings["user"] = user.Username settings["passfile"] = filepath.Join(user.HomeDir, ".pgpass") settings["servicefile"] = filepath.Join(user.HomeDir, ".pg_service.conf") sslcert := filepath.Join(user.HomeDir, ".postgresql", "postgresql.crt") sslkey := filepath.Join(user.HomeDir, ".postgresql", "postgresql.key") if _, err := os.Stat(sslcert); err == nil { if _, err := os.Stat(sslkey); err == nil { // Both the cert and key must be present to use them, or do not use either settings["sslcert"] = sslcert settings["sslkey"] = sslkey } } sslrootcert := filepath.Join(user.HomeDir, ".postgresql", "root.crt") if _, err := os.Stat(sslrootcert); err == nil { settings["sslrootcert"] = sslrootcert } } settings["target_session_attrs"] = "any" settings["min_read_buffer_size"] = "8192" return settings } // defaultHost attempts to mimic libpq's default host. 
libpq uses the default unix socket location on *nix and localhost // on Windows. The default socket location is compiled into libpq. Since pgx does not have access to that default it // checks the existence of common locations. func defaultHost() string { candidatePaths := []string{ "/var/run/postgresql", // Debian "/private/tmp", // OSX - homebrew "/tmp", // standard PostgreSQL } for _, path := range candidatePaths { if _, err := os.Stat(path); err == nil { return path } } return "localhost" } pgconn-1.14.0/defaults_windows.go000066400000000000000000000036661437172345200170140ustar00rootroot00000000000000package pgconn import ( "os" "os/user" "path/filepath" "strings" ) func defaultSettings() map[string]string { settings := make(map[string]string) settings["host"] = defaultHost() settings["port"] = "5432" // Default to the OS user name. Purposely ignoring err getting user name from // OS. The client application will simply have to specify the user in that // case (which they typically will be doing anyway). user, err := user.Current() appData := os.Getenv("APPDATA") if err == nil { // Windows gives us the username here as `DOMAIN\user` or `LOCALPCNAME\user`, // but the libpq default is just the `user` portion, so we strip off the first part. 
username := user.Username if strings.Contains(username, "\\") { username = username[strings.LastIndex(username, "\\")+1:] } settings["user"] = username settings["passfile"] = filepath.Join(appData, "postgresql", "pgpass.conf") settings["servicefile"] = filepath.Join(user.HomeDir, ".pg_service.conf") sslcert := filepath.Join(appData, "postgresql", "postgresql.crt") sslkey := filepath.Join(appData, "postgresql", "postgresql.key") if _, err := os.Stat(sslcert); err == nil { if _, err := os.Stat(sslkey); err == nil { // Both the cert and key must be present to use them, or do not use either settings["sslcert"] = sslcert settings["sslkey"] = sslkey } } sslrootcert := filepath.Join(appData, "postgresql", "root.crt") if _, err := os.Stat(sslrootcert); err == nil { settings["sslrootcert"] = sslrootcert } } settings["target_session_attrs"] = "any" settings["min_read_buffer_size"] = "8192" return settings } // defaultHost attempts to mimic libpq's default host. libpq uses the default unix socket location on *nix and localhost // on Windows. The default socket location is compiled into libpq. Since pgx does not have access to that default it // checks the existence of common locations. func defaultHost() string { return "localhost" } pgconn-1.14.0/doc.go000066400000000000000000000022531437172345200141670ustar00rootroot00000000000000// Package pgconn is a low-level PostgreSQL database driver. /* pgconn provides lower level access to a PostgreSQL connection than a database/sql or pgx connection. It operates at nearly the same level is the C library libpq. Establishing a Connection Use Connect to establish a connection. It accepts a connection string in URL or DSN and will read the environment for libpq style environment variables. Executing a Query ExecParams and ExecPrepared execute a single query. They return readers that iterate over each row. The Read method reads all rows into memory. 
Executing Multiple Queries in a Single Round Trip Exec and ExecBatch can execute multiple queries in a single round trip. They return readers that iterate over each query result. The ReadAll method reads all query results into memory. Context Support All potentially blocking operations take a context.Context. If a context is canceled while the method is in progress the method immediately returns. In most circumstances, this will close the underlying connection. The CancelRequest method may be used to request the PostgreSQL server cancel an in-progress query without forcing the client to abort. */ package pgconn pgconn-1.14.0/errors.go000066400000000000000000000131311437172345200147330ustar00rootroot00000000000000package pgconn import ( "context" "errors" "fmt" "net" "net/url" "regexp" "strings" ) // SafeToRetry checks if the err is guaranteed to have occurred before sending any data to the server. func SafeToRetry(err error) bool { if e, ok := err.(interface{ SafeToRetry() bool }); ok { return e.SafeToRetry() } return false } // Timeout checks if err was was caused by a timeout. To be specific, it is true if err was caused within pgconn by a // context.Canceled, context.DeadlineExceeded or an implementer of net.Error where Timeout() is true. func Timeout(err error) bool { var timeoutErr *errTimeout return errors.As(err, &timeoutErr) } // PgError represents an error reported by the PostgreSQL server. See // http://www.postgresql.org/docs/11/static/protocol-error-fields.html for // detailed field description. type PgError struct { Severity string Code string Message string Detail string Hint string Position int32 InternalPosition int32 InternalQuery string Where string SchemaName string TableName string ColumnName string DataTypeName string ConstraintName string File string Line int32 Routine string } func (pe *PgError) Error() string { return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")" } // SQLState returns the SQLState of the error. 
func (pe *PgError) SQLState() string { return pe.Code } type connectError struct { config *Config msg string err error } func (e *connectError) Error() string { sb := &strings.Builder{} fmt.Fprintf(sb, "failed to connect to `host=%s user=%s database=%s`: %s", e.config.Host, e.config.User, e.config.Database, e.msg) if e.err != nil { fmt.Fprintf(sb, " (%s)", e.err.Error()) } return sb.String() } func (e *connectError) Unwrap() error { return e.err } type connLockError struct { status string } func (e *connLockError) SafeToRetry() bool { return true // a lock failure by definition happens before the connection is used. } func (e *connLockError) Error() string { return e.status } type parseConfigError struct { connString string msg string err error } func (e *parseConfigError) Error() string { connString := redactPW(e.connString) if e.err == nil { return fmt.Sprintf("cannot parse `%s`: %s", connString, e.msg) } return fmt.Sprintf("cannot parse `%s`: %s (%s)", connString, e.msg, e.err.Error()) } func (e *parseConfigError) Unwrap() error { return e.err } // preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() == // true. Otherwise returns err. func preferContextOverNetTimeoutError(ctx context.Context, err error) error { if err, ok := err.(net.Error); ok && err.Timeout() && ctx.Err() != nil { return &errTimeout{err: ctx.Err()} } return err } type pgconnError struct { msg string err error safeToRetry bool } func (e *pgconnError) Error() string { if e.msg == "" { return e.err.Error() } if e.err == nil { return e.msg } return fmt.Sprintf("%s: %s", e.msg, e.err.Error()) } func (e *pgconnError) SafeToRetry() bool { return e.safeToRetry } func (e *pgconnError) Unwrap() error { return e.err } // errTimeout occurs when an error was caused by a timeout. Specifically, it wraps an error which is // context.Canceled, context.DeadlineExceeded, or an implementer of net.Error where Timeout() is true. 
type errTimeout struct { err error } func (e *errTimeout) Error() string { return fmt.Sprintf("timeout: %s", e.err.Error()) } func (e *errTimeout) SafeToRetry() bool { return SafeToRetry(e.err) } func (e *errTimeout) Unwrap() error { return e.err } type contextAlreadyDoneError struct { err error } func (e *contextAlreadyDoneError) Error() string { return fmt.Sprintf("context already done: %s", e.err.Error()) } func (e *contextAlreadyDoneError) SafeToRetry() bool { return true } func (e *contextAlreadyDoneError) Unwrap() error { return e.err } // newContextAlreadyDoneError double-wraps a context error in `contextAlreadyDoneError` and `errTimeout`. func newContextAlreadyDoneError(ctx context.Context) (err error) { return &errTimeout{&contextAlreadyDoneError{err: ctx.Err()}} } type writeError struct { err error safeToRetry bool } func (e *writeError) Error() string { return fmt.Sprintf("write failed: %s", e.err.Error()) } func (e *writeError) SafeToRetry() bool { return e.safeToRetry } func (e *writeError) Unwrap() error { return e.err } func redactPW(connString string) string { if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") { if u, err := url.Parse(connString); err == nil { return redactURL(u) } } quotedDSN := regexp.MustCompile(`password='[^']*'`) connString = quotedDSN.ReplaceAllLiteralString(connString, "password=xxxxx") plainDSN := regexp.MustCompile(`password=[^ ]*`) connString = plainDSN.ReplaceAllLiteralString(connString, "password=xxxxx") brokenURL := regexp.MustCompile(`:[^:@]+?@`) connString = brokenURL.ReplaceAllLiteralString(connString, ":xxxxxx@") return connString } func redactURL(u *url.URL) string { if u == nil { return "" } if _, pwSet := u.User.Password(); pwSet { u.User = url.UserPassword(u.User.Username(), "xxxxx") } return u.String() } type NotPreferredError struct { err error safeToRetry bool } func (e *NotPreferredError) Error() string { return fmt.Sprintf("standby server not found: %s", 
e.err.Error()) } func (e *NotPreferredError) SafeToRetry() bool { return e.safeToRetry } func (e *NotPreferredError) Unwrap() error { return e.err } pgconn-1.14.0/errors_test.go000066400000000000000000000032321437172345200157730ustar00rootroot00000000000000package pgconn_test import ( "testing" "github.com/jackc/pgconn" "github.com/stretchr/testify/assert" ) func TestConfigError(t *testing.T) { tests := []struct { name string err error expectedMsg string }{ { name: "url with password", err: pgconn.NewParseConfigError("postgresql://foo:password@host", "msg", nil), expectedMsg: "cannot parse `postgresql://foo:xxxxx@host`: msg", }, { name: "dsn with password unquoted", err: pgconn.NewParseConfigError("host=host password=password user=user", "msg", nil), expectedMsg: "cannot parse `host=host password=xxxxx user=user`: msg", }, { name: "dsn with password quoted", err: pgconn.NewParseConfigError("host=host password='pass word' user=user", "msg", nil), expectedMsg: "cannot parse `host=host password=xxxxx user=user`: msg", }, { name: "weird url", err: pgconn.NewParseConfigError("postgresql://foo::pasword@host:1:", "msg", nil), expectedMsg: "cannot parse `postgresql://foo:xxxxx@host:1:`: msg", }, { name: "weird url with slash in password", err: pgconn.NewParseConfigError("postgres://user:pass/word@host:5432/db_name", "msg", nil), expectedMsg: "cannot parse `postgres://user:xxxxxx@host:5432/db_name`: msg", }, { name: "url without password", err: pgconn.NewParseConfigError("postgresql://other@host/db", "msg", nil), expectedMsg: "cannot parse `postgresql://other@host/db`: msg", }, } for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() assert.EqualError(t, tt.err, tt.expectedMsg) }) } } pgconn-1.14.0/export_test.go000066400000000000000000000003461437172345200160030ustar00rootroot00000000000000// File export_test exports some methods for better testing. 
package pgconn func NewParseConfigError(conn, msg string, err error) error { return &parseConfigError{ connString: conn, msg: msg, err: err, } } pgconn-1.14.0/frontend_test.go000066400000000000000000000031451437172345200163010ustar00rootroot00000000000000package pgconn_test import ( "context" "io" "os" "testing" "github.com/jackc/pgconn" "github.com/jackc/pgproto3/v2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) // frontendWrapper allows to hijack a regular frontend, and inject a specific response type frontendWrapper struct { front pgconn.Frontend msg pgproto3.BackendMessage } // frontendWrapper implements the pgconn.Frontend interface var _ pgconn.Frontend = (*frontendWrapper)(nil) func (f *frontendWrapper) Receive() (pgproto3.BackendMessage, error) { if f.msg != nil { return f.msg, nil } return f.front.Receive() } func TestFrontendFatalErrExec(t *testing.T) { t.Parallel() config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) buildFrontend := config.BuildFrontend var front *frontendWrapper config.BuildFrontend = func(r io.Reader, w io.Writer) pgconn.Frontend { wrapped := buildFrontend(r, w) front = &frontendWrapper{wrapped, nil} return front } conn, err := pgconn.ConnectConfig(context.Background(), config) require.NoError(t, err) require.NotNil(t, conn) require.NotNil(t, front) // set frontend to return a "FATAL" message on next call front.msg = &pgproto3.ErrorResponse{Severity: "FATAL", Message: "unit testing fatal error"} _, err = conn.Exec(context.Background(), "SELECT 1").ReadAll() assert.Error(t, err) err = conn.Close(context.Background()) assert.NoError(t, err) select { case <-conn.CleanupDone(): t.Log("ok, CleanupDone() is not blocking") default: assert.Fail(t, "connection closed but CleanupDone() still blocking") } } pgconn-1.14.0/go.mod000066400000000000000000000006351437172345200142030ustar00rootroot00000000000000module github.com/jackc/pgconn go 1.12 require ( 
github.com/jackc/chunkreader/v2 v2.0.1 github.com/jackc/pgio v1.0.0 github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65 github.com/jackc/pgpassfile v1.0.0 github.com/jackc/pgproto3/v2 v2.3.2 github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a github.com/stretchr/testify v1.8.1 golang.org/x/crypto v0.6.0 golang.org/x/text v0.7.0 ) pgconn-1.14.0/go.sum000066400000000000000000000406531437172345200142340ustar00rootroot00000000000000github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8= github.com/jackc/chunkreader/v2 v2.0.1/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3/go.mod h1:jkELnwuX+w9qN5YIfX0fl88Ehu4XC3keFuOJJk9pcnA= github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb/go.mod h1:lLjNuW/+OfW9/pnVKPazfWOgNfH2aPem8YQ7ilXGvJE= github.com/jackc/pgconn v0.0.0-20190831204454-2fabfa3c18b7/go.mod 
h1:ZJKsE/KZfsUgOEh9hBm+xYTstcNHg7UPMVJqRfQxq4s= github.com/jackc/pgconn v1.8.0/go.mod h1:1C2Pb36bGIP9QHGBYCjnyhqu7Rv3sGshaQUvmfGIB/o= github.com/jackc/pgconn v1.9.0/go.mod h1:YctiPyvzfU11JFxoXokUOOKQXQmDMoJL9vJzHH8/2JY= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd/go.mod h1:hrBW0Enj2AZTNpt/7Y5rr2xe/9Mn757Wtb2xeBzPv2c= github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65 h1:DadwsjnMwFjfWc9y5Wi/+Zz7xoE5ALHsRQlOctkOiHc= github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65/go.mod h1:5R2h2EEX+qri8jOWMbJCtaPWkrrNc7OHwsp2TCqp7ak= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A= github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= github.com/jackc/pgproto3/v2 v2.0.6/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgproto3/v2 v2.1.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgproto3/v2 v2.2.1-0.20220412121321-175856ffd3c8 h1:KxsCQec+1iwJXtxnbbS/dY0EJ6rJEUlFsrJUnL5A2XI= github.com/jackc/pgproto3/v2 
v2.2.1-0.20220412121321-175856ffd3c8/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgproto3/v2 v2.3.0 h1:brH0pCGBDkBW07HWlN/oSBXrmo3WB0UvZd1pIuDcL8Y= github.com/jackc/pgproto3/v2 v2.3.0/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgproto3/v2 v2.3.1 h1:nwj7qwf0S+Q7ISFfBndqeLwSwxs+4DPsbRFjECT1Y4Y= github.com/jackc/pgproto3/v2 v2.3.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgproto3/v2 v2.3.2 h1:7eY55bdBeCz1F2fTzSz69QC+pG46jYq9/jtSPiJ5nn0= github.com/jackc/pgproto3/v2 v2.3.2/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc= github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw= github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y= github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9WuGR0JG/JseM9irFbnEPbuWV2EELPNuM= github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc= github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod 
h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU= github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc= github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= 
github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= github.com/stretchr/objx v0.4.0 h1:M2gUjqZET1qApGOWNSnZ49BAIMX4F/1plDv3+l31EJ4= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= 
go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20201203163018-be400aefbc4c/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 h1:/UOmuWzQfxxo9UtlXMwuQU8CMgg1eZXqTRwkSQJWKOI= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa h1:zuSxTR4o9y82ebqCUJYNGJbGPo6sKVl54f/TVDObg1c= golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.6.0 h1:qfktjS5LUO+fFKeJXZ+ikTRijMmljikvG68fpMMruSc= golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod 
h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys 
v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.4.0 h1:BrVqGRd7+k1DiOgtnFvAkoQEWQvBc25ouMJM6429SFg= golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190823170909-c4a336ef6a2f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools 
v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:aPpfJ7XW+gOuirDoZ8gHhLh3kZ1B08FtV2bbmy7Jv3s= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= pgconn-1.14.0/helper_test.go000066400000000000000000000017641437172345200157460ustar00rootroot00000000000000package pgconn_test import ( "context" "testing" "time" "github.com/jackc/pgconn" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func closeConn(t testing.TB, conn *pgconn.PgConn) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() require.NoError(t, conn.Close(ctx)) select { case <-conn.CleanupDone(): case <-time.After(5 * 
time.Second): t.Fatal("Connection cleanup exceeded maximum time") } } // Do a simple query to ensure the connection is still usable func ensureConnValid(t *testing.T, pgConn *pgconn.PgConn) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) result := pgConn.ExecParams(ctx, "select generate_series(1,$1)", [][]byte{[]byte("3")}, nil, nil, nil).Read() cancel() require.Nil(t, result.Err) assert.Equal(t, 3, len(result.Rows)) assert.Equal(t, "1", string(result.Rows[0][0])) assert.Equal(t, "2", string(result.Rows[1][0])) assert.Equal(t, "3", string(result.Rows[2][0])) } pgconn-1.14.0/internal/000077500000000000000000000000001437172345200147055ustar00rootroot00000000000000pgconn-1.14.0/internal/ctxwatch/000077500000000000000000000000001437172345200165325ustar00rootroot00000000000000pgconn-1.14.0/internal/ctxwatch/context_watcher.go000066400000000000000000000034671437172345200222740ustar00rootroot00000000000000package ctxwatch import ( "context" "sync" ) // ContextWatcher watches a context and performs an action when the context is canceled. It can watch one context at a // time. type ContextWatcher struct { onCancel func() onUnwatchAfterCancel func() unwatchChan chan struct{} lock sync.Mutex watchInProgress bool onCancelWasCalled bool } // NewContextWatcher returns a ContextWatcher. onCancel will be called when a watched context is canceled. // OnUnwatchAfterCancel will be called when Unwatch is called and the watched context had already been canceled and // onCancel called. func NewContextWatcher(onCancel func(), onUnwatchAfterCancel func()) *ContextWatcher { cw := &ContextWatcher{ onCancel: onCancel, onUnwatchAfterCancel: onUnwatchAfterCancel, unwatchChan: make(chan struct{}), } return cw } // Watch starts watching ctx. If ctx is canceled then the onCancel function passed to NewContextWatcher will be called. 
func (cw *ContextWatcher) Watch(ctx context.Context) { cw.lock.Lock() defer cw.lock.Unlock() if cw.watchInProgress { panic("Watch already in progress") } cw.onCancelWasCalled = false if ctx.Done() != nil { cw.watchInProgress = true go func() { select { case <-ctx.Done(): cw.onCancel() cw.onCancelWasCalled = true <-cw.unwatchChan case <-cw.unwatchChan: } }() } else { cw.watchInProgress = false } } // Unwatch stops watching the previously watched context. If the onCancel function passed to NewContextWatcher was // called then onUnwatchAfterCancel will also be called. func (cw *ContextWatcher) Unwatch() { cw.lock.Lock() defer cw.lock.Unlock() if cw.watchInProgress { cw.unwatchChan <- struct{}{} if cw.onCancelWasCalled { cw.onUnwatchAfterCancel() } cw.watchInProgress = false } } pgconn-1.14.0/internal/ctxwatch/context_watcher_test.go000066400000000000000000000076031437172345200233270ustar00rootroot00000000000000package ctxwatch_test import ( "context" "sync/atomic" "testing" "time" "github.com/jackc/pgconn/internal/ctxwatch" "github.com/stretchr/testify/require" ) func TestContextWatcherContextCancelled(t *testing.T) { canceledChan := make(chan struct{}) cleanupCalled := false cw := ctxwatch.NewContextWatcher(func() { canceledChan <- struct{}{} }, func() { cleanupCalled = true }) ctx, cancel := context.WithCancel(context.Background()) cw.Watch(ctx) cancel() select { case <-canceledChan: case <-time.NewTimer(time.Second).C: t.Fatal("Timed out waiting for cancel func to be called") } cw.Unwatch() require.True(t, cleanupCalled, "Cleanup func was not called") } func TestContextWatcherUnwatchdBeforeContextCancelled(t *testing.T) { cw := ctxwatch.NewContextWatcher(func() { t.Error("cancel func should not have been called") }, func() { t.Error("cleanup func should not have been called") }) ctx, cancel := context.WithCancel(context.Background()) cw.Watch(ctx) cw.Unwatch() cancel() } func TestContextWatcherMultipleWatchPanics(t *testing.T) { cw := 
ctxwatch.NewContextWatcher(func() {}, func() {}) ctx, cancel := context.WithCancel(context.Background()) defer cancel() cw.Watch(ctx) ctx2, cancel2 := context.WithCancel(context.Background()) defer cancel2() require.Panics(t, func() { cw.Watch(ctx2) }, "Expected panic when Watch called multiple times") } func TestContextWatcherUnwatchWhenNotWatchingIsSafe(t *testing.T) { cw := ctxwatch.NewContextWatcher(func() {}, func() {}) cw.Unwatch() // unwatch when not / never watching ctx, cancel := context.WithCancel(context.Background()) defer cancel() cw.Watch(ctx) cw.Unwatch() cw.Unwatch() // double unwatch } func TestContextWatcherUnwatchIsConcurrencySafe(t *testing.T) { cw := ctxwatch.NewContextWatcher(func() {}, func() {}) ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() cw.Watch(ctx) go cw.Unwatch() go cw.Unwatch() <-ctx.Done() } func TestContextWatcherStress(t *testing.T) { var cancelFuncCalls int64 var cleanupFuncCalls int64 cw := ctxwatch.NewContextWatcher(func() { atomic.AddInt64(&cancelFuncCalls, 1) }, func() { atomic.AddInt64(&cleanupFuncCalls, 1) }) cycleCount := 100000 for i := 0; i < cycleCount; i++ { ctx, cancel := context.WithCancel(context.Background()) cw.Watch(ctx) if i%2 == 0 { cancel() } // Without time.Sleep, cw.Unwatch will almost always run before the cancel func which means cancel will never happen. This gives us a better mix. 
if i%3 == 0 { time.Sleep(time.Nanosecond) } cw.Unwatch() if i%2 == 1 { cancel() } } actualCancelFuncCalls := atomic.LoadInt64(&cancelFuncCalls) actualCleanupFuncCalls := atomic.LoadInt64(&cleanupFuncCalls) if actualCancelFuncCalls == 0 { t.Fatal("actualCancelFuncCalls == 0") } maxCancelFuncCalls := int64(cycleCount) / 2 if actualCancelFuncCalls > maxCancelFuncCalls { t.Errorf("cancel func calls should be no more than %d but was %d", actualCancelFuncCalls, maxCancelFuncCalls) } if actualCancelFuncCalls != actualCleanupFuncCalls { t.Errorf("cancel func calls (%d) should be equal to cleanup func calls (%d) but was not", actualCancelFuncCalls, actualCleanupFuncCalls) } } func BenchmarkContextWatcherUncancellable(b *testing.B) { cw := ctxwatch.NewContextWatcher(func() {}, func() {}) for i := 0; i < b.N; i++ { cw.Watch(context.Background()) cw.Unwatch() } } func BenchmarkContextWatcherCancelled(b *testing.B) { cw := ctxwatch.NewContextWatcher(func() {}, func() {}) for i := 0; i < b.N; i++ { ctx, cancel := context.WithCancel(context.Background()) cw.Watch(ctx) cancel() cw.Unwatch() } } func BenchmarkContextWatcherCancellable(b *testing.B) { cw := ctxwatch.NewContextWatcher(func() {}, func() {}) ctx, cancel := context.WithCancel(context.Background()) defer cancel() for i := 0; i < b.N; i++ { cw.Watch(ctx) cw.Unwatch() } } pgconn-1.14.0/krb5.go000066400000000000000000000044131437172345200142650ustar00rootroot00000000000000package pgconn import ( "errors" "fmt" "github.com/jackc/pgproto3/v2" ) // NewGSSFunc creates a GSS authentication provider, for use with // RegisterGSSProvider. type NewGSSFunc func() (GSS, error) var newGSS NewGSSFunc // RegisterGSSProvider registers a GSS authentication provider. 
For example, if // you need to use Kerberos to authenticate with your server, add this to your // main package: // // import "github.com/otan/gopgkrb5" // // func init() { // pgconn.RegisterGSSProvider(func() (pgconn.GSS, error) { return gopgkrb5.NewGSS() }) // } func RegisterGSSProvider(newGSSArg NewGSSFunc) { newGSS = newGSSArg } // GSS provides GSSAPI authentication (e.g., Kerberos). type GSS interface { GetInitToken(host string, service string) ([]byte, error) GetInitTokenFromSPN(spn string) ([]byte, error) Continue(inToken []byte) (done bool, outToken []byte, err error) } func (c *PgConn) gssAuth() error { if newGSS == nil { return errors.New("kerberos error: no GSSAPI provider registered, see https://github.com/otan/gopgkrb5") } cli, err := newGSS() if err != nil { return err } var nextData []byte if c.config.KerberosSpn != "" { // Use the supplied SPN if provided. nextData, err = cli.GetInitTokenFromSPN(c.config.KerberosSpn) } else { // Allow the kerberos service name to be overridden service := "postgres" if c.config.KerberosSrvName != "" { service = c.config.KerberosSrvName } nextData, err = cli.GetInitToken(c.config.Host, service) } if err != nil { return err } for { gssResponse := &pgproto3.GSSResponse{ Data: nextData, } _, err = c.conn.Write(gssResponse.Encode(nil)) if err != nil { return err } resp, err := c.rxGSSContinue() if err != nil { return err } var done bool done, nextData, err = cli.Continue(resp.Data) if err != nil { return err } if done { break } } return nil } func (c *PgConn) rxGSSContinue() (*pgproto3.AuthenticationGSSContinue, error) { msg, err := c.receiveMessage() if err != nil { return nil, err } switch m := msg.(type) { case *pgproto3.AuthenticationGSSContinue: return m, nil case *pgproto3.ErrorResponse: return nil, ErrorResponseToPgError(m) } return nil, fmt.Errorf("expected AuthenticationGSSContinue message but received unexpected message %T", msg) } 
pgconn-1.14.0/pgconn.go000066400000000000000000001511601437172345200147100ustar00rootroot00000000000000package pgconn import ( "context" "crypto/md5" "crypto/tls" "encoding/binary" "encoding/hex" "errors" "fmt" "io" "math" "net" "strconv" "strings" "sync" "time" "github.com/jackc/pgconn/internal/ctxwatch" "github.com/jackc/pgio" "github.com/jackc/pgproto3/v2" ) const ( connStatusUninitialized = iota connStatusConnecting connStatusClosed connStatusIdle connStatusBusy ) const wbufLen = 1024 // Notice represents a notice response message reported by the PostgreSQL server. Be aware that this is distinct from // LISTEN/NOTIFY notification. type Notice PgError // Notification is a message received from the PostgreSQL LISTEN/NOTIFY system type Notification struct { PID uint32 // backend pid that sent the notification Channel string // channel from which notification was received Payload string } // DialFunc is a function that can be used to connect to a PostgreSQL server. type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error) // LookupFunc is a function that can be used to lookup IPs addrs from host. Optionally an ip:port combination can be // returned in order to override the connection string's port. type LookupFunc func(ctx context.Context, host string) (addrs []string, err error) // BuildFrontendFunc is a function that can be used to create Frontend implementation for connection. type BuildFrontendFunc func(r io.Reader, w io.Writer) Frontend // NoticeHandler is a function that can handle notices received from the PostgreSQL server. Notices can be received at // any time, usually during handling of a query response. The *PgConn is provided so the handler is aware of the origin // of the notice, but it must not invoke any query method. Be aware that this is distinct from LISTEN/NOTIFY // notification. type NoticeHandler func(*PgConn, *Notice) // NotificationHandler is a function that can handle notifications received from the PostgreSQL server. 
Notifications // can be received at any time, usually during handling of a query response. The *PgConn is provided so the handler is // aware of the origin of the notice, but it must not invoke any query method. Be aware that this is distinct from a // notice event. type NotificationHandler func(*PgConn, *Notification) // Frontend used to receive messages from backend. type Frontend interface { Receive() (pgproto3.BackendMessage, error) } // PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage. type PgConn struct { conn net.Conn // the underlying TCP or unix domain socket connection pid uint32 // backend pid secretKey uint32 // key to use to send a cancel query message to the server parameterStatuses map[string]string // parameters that have been reported by the server txStatus byte frontend Frontend config *Config status byte // One of connStatus* constants bufferingReceive bool bufferingReceiveMux sync.Mutex bufferingReceiveMsg pgproto3.BackendMessage bufferingReceiveErr error peekedMsg pgproto3.BackendMessage // Reusable / preallocated resources wbuf []byte // write buffer resultReader ResultReader multiResultReader MultiResultReader contextWatcher *ctxwatch.ContextWatcher cleanupDone chan struct{} } // Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format) // to provide configuration. See documentation for ParseConfig for details. ctx can be used to cancel a connect attempt. func Connect(ctx context.Context, connString string) (*PgConn, error) { config, err := ParseConfig(connString) if err != nil { return nil, err } return ConnectConfig(ctx, config) } // Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format) // and ParseConfigOptions to provide additional configuration. See documentation for ParseConfig for details. ctx can be // used to cancel a connect attempt. 
func ConnectWithOptions(ctx context.Context, connString string, parseConfigOptions ParseConfigOptions) (*PgConn, error) { config, err := ParseConfigWithOptions(connString, parseConfigOptions) if err != nil { return nil, err } return ConnectConfig(ctx, config) } // Connect establishes a connection to a PostgreSQL server using config. config must have been constructed with // ParseConfig. ctx can be used to cancel a connect attempt. // // If config.Fallbacks are present they will sequentially be tried in case of error establishing network connection. An // authentication error will terminate the chain of attempts (like libpq: // https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS) and be returned as the error. Otherwise, // if all attempts fail the last error is returned. func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err error) { // Default values are set in ParseConfig. Enforce initial creation by ParseConfig rather than setting defaults from // zero values. if !config.createdByParseConfig { panic("config must be created by ParseConfig") } // Simplify usage by treating primary config and fallbacks the same. fallbackConfigs := []*FallbackConfig{ { Host: config.Host, Port: config.Port, TLSConfig: config.TLSConfig, }, } fallbackConfigs = append(fallbackConfigs, config.Fallbacks...) ctx := octx fallbackConfigs, err = expandWithIPs(ctx, config.LookupFunc, fallbackConfigs) if err != nil { return nil, &connectError{config: config, msg: "hostname resolving error", err: err} } if len(fallbackConfigs) == 0 { return nil, &connectError{config: config, msg: "hostname resolving error", err: errors.New("ip addr wasn't found")} } foundBestServer := false var fallbackConfig *FallbackConfig for _, fc := range fallbackConfigs { // ConnectTimeout restricts the whole connection process. 
if config.ConnectTimeout != 0 { var cancel context.CancelFunc ctx, cancel = context.WithTimeout(octx, config.ConnectTimeout) defer cancel() } else { ctx = octx } pgConn, err = connect(ctx, config, fc, false) if err == nil { foundBestServer = true break } else if pgerr, ok := err.(*PgError); ok { err = &connectError{config: config, msg: "server error", err: pgerr} const ERRCODE_INVALID_PASSWORD = "28P01" // wrong password const ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION = "28000" // wrong password or bad pg_hba.conf settings const ERRCODE_INVALID_CATALOG_NAME = "3D000" // db does not exist const ERRCODE_INSUFFICIENT_PRIVILEGE = "42501" // missing connect privilege if pgerr.Code == ERRCODE_INVALID_PASSWORD || pgerr.Code == ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION || pgerr.Code == ERRCODE_INVALID_CATALOG_NAME || pgerr.Code == ERRCODE_INSUFFICIENT_PRIVILEGE { break } } else if cerr, ok := err.(*connectError); ok { if _, ok := cerr.err.(*NotPreferredError); ok { fallbackConfig = fc } } } if !foundBestServer && fallbackConfig != nil { pgConn, err = connect(ctx, config, fallbackConfig, true) if pgerr, ok := err.(*PgError); ok { err = &connectError{config: config, msg: "server error", err: pgerr} } } if err != nil { return nil, err // no need to wrap in connectError because it will already be wrapped in all cases except PgError } if config.AfterConnect != nil { err := config.AfterConnect(ctx, pgConn) if err != nil { pgConn.conn.Close() return nil, &connectError{config: config, msg: "AfterConnect error", err: err} } } return pgConn, nil } func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*FallbackConfig) ([]*FallbackConfig, error) { var configs []*FallbackConfig for _, fb := range fallbacks { // skip resolve for unix sockets if isAbsolutePath(fb.Host) { configs = append(configs, &FallbackConfig{ Host: fb.Host, Port: fb.Port, TLSConfig: fb.TLSConfig, }) continue } ips, err := lookupFn(ctx, fb.Host) if err != nil { return nil, err } for _, ip := 
range ips { splitIP, splitPort, err := net.SplitHostPort(ip) if err == nil { port, err := strconv.ParseUint(splitPort, 10, 16) if err != nil { return nil, fmt.Errorf("error parsing port (%s) from lookup: %w", splitPort, err) } configs = append(configs, &FallbackConfig{ Host: splitIP, Port: uint16(port), TLSConfig: fb.TLSConfig, }) } else { configs = append(configs, &FallbackConfig{ Host: ip, Port: fb.Port, TLSConfig: fb.TLSConfig, }) } } } return configs, nil } func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig, ignoreNotPreferredErr bool) (*PgConn, error) { pgConn := new(PgConn) pgConn.config = config pgConn.wbuf = make([]byte, 0, wbufLen) pgConn.cleanupDone = make(chan struct{}) var err error network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port) netConn, err := config.DialFunc(ctx, network, address) if err != nil { var netErr net.Error if errors.As(err, &netErr) && netErr.Timeout() { err = &errTimeout{err: err} } return nil, &connectError{config: config, msg: "dial error", err: err} } pgConn.conn = netConn pgConn.contextWatcher = newContextWatcher(netConn) pgConn.contextWatcher.Watch(ctx) if fallbackConfig.TLSConfig != nil { tlsConn, err := startTLS(netConn, fallbackConfig.TLSConfig) pgConn.contextWatcher.Unwatch() // Always unwatch `netConn` after TLS. 
if err != nil { netConn.Close() return nil, &connectError{config: config, msg: "tls error", err: err} } pgConn.conn = tlsConn pgConn.contextWatcher = newContextWatcher(tlsConn) pgConn.contextWatcher.Watch(ctx) } defer pgConn.contextWatcher.Unwatch() pgConn.parameterStatuses = make(map[string]string) pgConn.status = connStatusConnecting pgConn.frontend = config.BuildFrontend(pgConn.conn, pgConn.conn) startupMsg := pgproto3.StartupMessage{ ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: make(map[string]string), } // Copy default run-time params for k, v := range config.RuntimeParams { startupMsg.Parameters[k] = v } startupMsg.Parameters["user"] = config.User if config.Database != "" { startupMsg.Parameters["database"] = config.Database } if _, err := pgConn.conn.Write(startupMsg.Encode(pgConn.wbuf)); err != nil { pgConn.conn.Close() return nil, &connectError{config: config, msg: "failed to write startup message", err: err} } for { msg, err := pgConn.receiveMessage() if err != nil { pgConn.conn.Close() if err, ok := err.(*PgError); ok { return nil, err } return nil, &connectError{config: config, msg: "failed to receive message", err: preferContextOverNetTimeoutError(ctx, err)} } switch msg := msg.(type) { case *pgproto3.BackendKeyData: pgConn.pid = msg.ProcessID pgConn.secretKey = msg.SecretKey case *pgproto3.AuthenticationOk: case *pgproto3.AuthenticationCleartextPassword: err = pgConn.txPasswordMessage(pgConn.config.Password) if err != nil { pgConn.conn.Close() return nil, &connectError{config: config, msg: "failed to write password message", err: err} } case *pgproto3.AuthenticationMD5Password: digestedPassword := "md5" + hexMD5(hexMD5(pgConn.config.Password+pgConn.config.User)+string(msg.Salt[:])) err = pgConn.txPasswordMessage(digestedPassword) if err != nil { pgConn.conn.Close() return nil, &connectError{config: config, msg: "failed to write password message", err: err} } case *pgproto3.AuthenticationSASL: err = pgConn.scramAuth(msg.AuthMechanisms) 
if err != nil { pgConn.conn.Close() return nil, &connectError{config: config, msg: "failed SASL auth", err: err} } case *pgproto3.AuthenticationGSS: err = pgConn.gssAuth() if err != nil { pgConn.conn.Close() return nil, &connectError{config: config, msg: "failed GSS auth", err: err} } case *pgproto3.ReadyForQuery: pgConn.status = connStatusIdle if config.ValidateConnect != nil { // ValidateConnect may execute commands that cause the context to be watched again. Unwatch first to avoid // the watch already in progress panic. This is that last thing done by this method so there is no need to // restart the watch after ValidateConnect returns. // // See https://github.com/jackc/pgconn/issues/40. pgConn.contextWatcher.Unwatch() err := config.ValidateConnect(ctx, pgConn) if err != nil { if _, ok := err.(*NotPreferredError); ignoreNotPreferredErr && ok { return pgConn, nil } pgConn.conn.Close() return nil, &connectError{config: config, msg: "ValidateConnect failed", err: err} } } return pgConn, nil case *pgproto3.ParameterStatus, *pgproto3.NoticeResponse: // handled by ReceiveMessage case *pgproto3.ErrorResponse: pgConn.conn.Close() return nil, ErrorResponseToPgError(msg) default: pgConn.conn.Close() return nil, &connectError{config: config, msg: "received unexpected message", err: err} } } } func newContextWatcher(conn net.Conn) *ctxwatch.ContextWatcher { return ctxwatch.NewContextWatcher( func() { conn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, func() { conn.SetDeadline(time.Time{}) }, ) } func startTLS(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) { err := binary.Write(conn, binary.BigEndian, []int32{8, 80877103}) if err != nil { return nil, err } response := make([]byte, 1) if _, err = io.ReadFull(conn, response); err != nil { return nil, err } if response[0] != 'S' { return nil, errors.New("server refused TLS connection") } return tls.Client(conn, tlsConfig), nil } func (pgConn *PgConn) txPasswordMessage(password string) (err error) { msg := 
&pgproto3.PasswordMessage{Password: password} _, err = pgConn.conn.Write(msg.Encode(pgConn.wbuf)) return err } func hexMD5(s string) string { hash := md5.New() io.WriteString(hash, s) return hex.EncodeToString(hash.Sum(nil)) } func (pgConn *PgConn) signalMessage() chan struct{} { if pgConn.bufferingReceive { panic("BUG: signalMessage when already in progress") } pgConn.bufferingReceive = true pgConn.bufferingReceiveMux.Lock() ch := make(chan struct{}) go func() { pgConn.bufferingReceiveMsg, pgConn.bufferingReceiveErr = pgConn.frontend.Receive() pgConn.bufferingReceiveMux.Unlock() close(ch) }() return ch } // SendBytes sends buf to the PostgreSQL server. It must only be used when the connection is not busy. e.g. It is as // error to call SendBytes while reading the result of a query. // // This is a very low level method that requires deep understanding of the PostgreSQL wire protocol to use correctly. // See https://www.postgresql.org/docs/current/protocol.html. func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error { if err := pgConn.lock(); err != nil { return err } defer pgConn.unlock() if ctx != context.Background() { select { case <-ctx.Done(): return newContextAlreadyDoneError(ctx) default: } pgConn.contextWatcher.Watch(ctx) defer pgConn.contextWatcher.Unwatch() } n, err := pgConn.conn.Write(buf) if err != nil { pgConn.asyncClose() return &writeError{err: err, safeToRetry: n == 0} } return nil } // ReceiveMessage receives one wire protocol message from the PostgreSQL server. It must only be used when the // connection is not busy. e.g. It is an error to call ReceiveMessage while reading the result of a query. The messages // are still handled by the core pgconn message handling system so receiving a NotificationResponse will still trigger // the OnNotification callback. // // This is a very low level method that requires deep understanding of the PostgreSQL wire protocol to use correctly. 
// See https://www.postgresql.org/docs/current/protocol.html.
func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessage, error) {
	if err := pgConn.lock(); err != nil {
		return nil, err
	}
	defer pgConn.unlock()

	if ctx != context.Background() {
		select {
		case <-ctx.Done():
			return nil, newContextAlreadyDoneError(ctx)
		default:
		}
		pgConn.contextWatcher.Watch(ctx)
		defer pgConn.contextWatcher.Unwatch()
	}

	msg, err := pgConn.receiveMessage()
	if err != nil {
		// Any receive failure is wrapped as a pgconnError marked safeToRetry; a context cancellation is reported in
		// preference to the resulting net timeout.
		err = &pgconnError{
			msg:         "receive message failed",
			err:         preferContextOverNetTimeoutError(ctx, err),
			safeToRetry: true}
	}
	return msg, err
}

// peekMessage peeks at the next message without setting up context cancellation.
//
// The message is stored in pgConn.peekedMsg, so repeated calls return the same message until receiveMessage consumes
// it. If signalMessage started a background receive, its result is used instead of reading again. A background result
// that is a net timeout is discarded and the read is retried in the foreground. Any non-timeout receive error closes
// the connection asynchronously.
func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) {
	if pgConn.peekedMsg != nil {
		return pgConn.peekedMsg, nil
	}

	var msg pgproto3.BackendMessage
	var err error
	if pgConn.bufferingReceive {
		// Blocks until the goroutine started by signalMessage has stored its result and released the mutex.
		pgConn.bufferingReceiveMux.Lock()
		msg = pgConn.bufferingReceiveMsg
		err = pgConn.bufferingReceiveErr
		pgConn.bufferingReceiveMux.Unlock()
		pgConn.bufferingReceive = false

		// If a timeout error happened in the background try the read again.
		var netErr net.Error
		if errors.As(err, &netErr) && netErr.Timeout() {
			msg, err = pgConn.frontend.Receive()
		}
	} else {
		msg, err = pgConn.frontend.Receive()
	}

	if err != nil {
		// Close on anything other than timeout error - everything else is fatal
		var netErr net.Error
		isNetErr := errors.As(err, &netErr)
		if !(isNetErr && netErr.Timeout()) {
			pgConn.asyncClose()
		}

		return nil, err
	}

	pgConn.peekedMsg = msg
	return msg, nil
}

// receiveMessage receives a message without setting up context cancellation
//
// It consumes the message returned by peekMessage and applies connection-level side effects:
//   - ReadyForQuery updates the transaction status.
//   - ParameterStatus is recorded in parameterStatuses.
//   - A FATAL ErrorResponse marks the connection closed, closes it, and is returned as a *PgError error.
//   - NoticeResponse and NotificationResponse are passed to the configured OnNotice / OnNotification callbacks.
func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) {
	msg, err := pgConn.peekMessage()
	if err != nil {
		// Close on anything other than timeout error - everything else is fatal
		var netErr net.Error
		isNetErr := errors.As(err, &netErr)
		if !(isNetErr && netErr.Timeout()) {
			pgConn.asyncClose()
		}

		return nil, err
	}
	// The peeked message is now consumed.
	pgConn.peekedMsg = nil

	switch msg := msg.(type) {
	case *pgproto3.ReadyForQuery:
		pgConn.txStatus = msg.TxStatus
	case *pgproto3.ParameterStatus:
		pgConn.parameterStatuses[msg.Name] = msg.Value
	case *pgproto3.ErrorResponse:
		if msg.Severity == "FATAL" {
			pgConn.status = connStatusClosed
			pgConn.conn.Close() // Ignore error as the connection is already broken and there is already an error to return.
			close(pgConn.cleanupDone)
			return nil, ErrorResponseToPgError(msg)
		}
	case *pgproto3.NoticeResponse:
		if pgConn.config.OnNotice != nil {
			pgConn.config.OnNotice(pgConn, noticeResponseToNotice(msg))
		}
	case *pgproto3.NotificationResponse:
		if pgConn.config.OnNotification != nil {
			pgConn.config.OnNotification(pgConn, &Notification{PID: msg.PID, Channel: msg.Channel, Payload: msg.Payload})
		}
	}

	return msg, nil
}

// Conn returns the underlying net.Conn. When TLS was negotiated during connect, this is the TLS client connection
// rather than the raw socket.
func (pgConn *PgConn) Conn() net.Conn {
	return pgConn.conn
}

// PID returns the backend PID.
func (pgConn *PgConn) PID() uint32 {
	return pgConn.pid
}

// TxStatus returns the current TxStatus as reported by the server in the ReadyForQuery message.
// // Possible return values: // 'I' - idle / not in transaction // 'T' - in a transaction // 'E' - in a failed transaction // // See https://www.postgresql.org/docs/current/protocol-message-formats.html. func (pgConn *PgConn) TxStatus() byte { return pgConn.txStatus } // SecretKey returns the backend secret key used to send a cancel query message to the server. func (pgConn *PgConn) SecretKey() uint32 { return pgConn.secretKey } // Close closes a connection. It is safe to call Close on a already closed connection. Close attempts a clean close by // sending the exit message to PostgreSQL. However, this could block so ctx is available to limit the time to wait. The // underlying net.Conn.Close() will always be called regardless of any other errors. func (pgConn *PgConn) Close(ctx context.Context) error { if pgConn.status == connStatusClosed { return nil } pgConn.status = connStatusClosed defer close(pgConn.cleanupDone) defer pgConn.conn.Close() if ctx != context.Background() { // Close may be called while a cancellable query is in progress. This will most often be triggered by panic when // a defer closes the connection (possibly indirectly via a transaction or a connection pool). Unwatch to end any // previous watch. It is safe to Unwatch regardless of whether a watch is already is progress. // // See https://github.com/jackc/pgconn/issues/29 pgConn.contextWatcher.Unwatch() pgConn.contextWatcher.Watch(ctx) defer pgConn.contextWatcher.Unwatch() } // Ignore any errors sending Terminate message and waiting for server to close connection. // This mimics the behavior of libpq PQfinish. It calls closePGconn which calls sendTerminateConn which purposefully // ignores errors. // // See https://github.com/jackc/pgx/issues/637 pgConn.conn.Write([]byte{'X', 0, 0, 0, 4}) return pgConn.conn.Close() } // asyncClose marks the connection as closed and asynchronously sends a cancel query message and closes the underlying // connection. 
func (pgConn *PgConn) asyncClose() { if pgConn.status == connStatusClosed { return } pgConn.status = connStatusClosed go func() { defer close(pgConn.cleanupDone) defer pgConn.conn.Close() deadline := time.Now().Add(time.Second * 15) ctx, cancel := context.WithDeadline(context.Background(), deadline) defer cancel() pgConn.CancelRequest(ctx) pgConn.conn.SetDeadline(deadline) pgConn.conn.Write([]byte{'X', 0, 0, 0, 4}) }() } // CleanupDone returns a channel that will be closed after all underlying resources have been cleaned up. A closed // connection is no longer usable, but underlying resources, in particular the net.Conn, may not have finished closing // yet. This is because certain errors such as a context cancellation require that the interrupted function call return // immediately, but the error may also cause the connection to be closed. In these cases the underlying resources are // closed asynchronously. // // This is only likely to be useful to connection pools. It gives them a way avoid establishing a new connection while // an old connection is still being cleaned up and thereby exceeding the maximum pool size. func (pgConn *PgConn) CleanupDone() chan (struct{}) { return pgConn.cleanupDone } // IsClosed reports if the connection has been closed. // // CleanupDone() can be used to determine if all cleanup has been completed. func (pgConn *PgConn) IsClosed() bool { return pgConn.status < connStatusIdle } // IsBusy reports if the connection is busy. func (pgConn *PgConn) IsBusy() bool { return pgConn.status == connStatusBusy } // lock locks the connection. func (pgConn *PgConn) lock() error { switch pgConn.status { case connStatusBusy: return &connLockError{status: "conn busy"} // This only should be possible in case of an application bug. 
case connStatusClosed: return &connLockError{status: "conn closed"} case connStatusUninitialized: return &connLockError{status: "conn uninitialized"} } pgConn.status = connStatusBusy return nil } func (pgConn *PgConn) unlock() { switch pgConn.status { case connStatusBusy: pgConn.status = connStatusIdle case connStatusClosed: default: panic("BUG: cannot unlock unlocked connection") // This should only be possible if there is a bug in this package. } } // ParameterStatus returns the value of a parameter reported by the server (e.g. // server_version). Returns an empty string for unknown parameters. func (pgConn *PgConn) ParameterStatus(key string) string { return pgConn.parameterStatuses[key] } // CommandTag is the result of an Exec function type CommandTag []byte // RowsAffected returns the number of rows affected. If the CommandTag was not // for a row affecting command (e.g. "CREATE TABLE") then it returns 0. func (ct CommandTag) RowsAffected() int64 { // Find last non-digit idx := -1 for i := len(ct) - 1; i >= 0; i-- { if ct[i] >= '0' && ct[i] <= '9' { idx = i } else { break } } if idx == -1 { return 0 } var n int64 for _, b := range ct[idx:] { n = n*10 + int64(b-'0') } return n } func (ct CommandTag) String() string { return string(ct) } // Insert is true if the command tag starts with "INSERT". func (ct CommandTag) Insert() bool { return len(ct) >= 6 && ct[0] == 'I' && ct[1] == 'N' && ct[2] == 'S' && ct[3] == 'E' && ct[4] == 'R' && ct[5] == 'T' } // Update is true if the command tag starts with "UPDATE". func (ct CommandTag) Update() bool { return len(ct) >= 6 && ct[0] == 'U' && ct[1] == 'P' && ct[2] == 'D' && ct[3] == 'A' && ct[4] == 'T' && ct[5] == 'E' } // Delete is true if the command tag starts with "DELETE". func (ct CommandTag) Delete() bool { return len(ct) >= 6 && ct[0] == 'D' && ct[1] == 'E' && ct[2] == 'L' && ct[3] == 'E' && ct[4] == 'T' && ct[5] == 'E' } // Select is true if the command tag starts with "SELECT". 
func (ct CommandTag) Select() bool { return len(ct) >= 6 && ct[0] == 'S' && ct[1] == 'E' && ct[2] == 'L' && ct[3] == 'E' && ct[4] == 'C' && ct[5] == 'T' } type StatementDescription struct { Name string SQL string ParamOIDs []uint32 Fields []pgproto3.FieldDescription } // Prepare creates a prepared statement. If the name is empty, the anonymous prepared statement will be used. This // allows Prepare to also to describe statements without creating a server-side prepared statement. func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*StatementDescription, error) { if err := pgConn.lock(); err != nil { return nil, err } defer pgConn.unlock() if ctx != context.Background() { select { case <-ctx.Done(): return nil, newContextAlreadyDoneError(ctx) default: } pgConn.contextWatcher.Watch(ctx) defer pgConn.contextWatcher.Unwatch() } buf := pgConn.wbuf buf = (&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}).Encode(buf) buf = (&pgproto3.Describe{ObjectType: 'S', Name: name}).Encode(buf) buf = (&pgproto3.Sync{}).Encode(buf) n, err := pgConn.conn.Write(buf) if err != nil { pgConn.asyncClose() return nil, &writeError{err: err, safeToRetry: n == 0} } psd := &StatementDescription{Name: name, SQL: sql} var parseErr error readloop: for { msg, err := pgConn.receiveMessage() if err != nil { pgConn.asyncClose() return nil, preferContextOverNetTimeoutError(ctx, err) } switch msg := msg.(type) { case *pgproto3.ParameterDescription: psd.ParamOIDs = make([]uint32, len(msg.ParameterOIDs)) copy(psd.ParamOIDs, msg.ParameterOIDs) case *pgproto3.RowDescription: psd.Fields = make([]pgproto3.FieldDescription, len(msg.Fields)) copy(psd.Fields, msg.Fields) case *pgproto3.ErrorResponse: parseErr = ErrorResponseToPgError(msg) case *pgproto3.ReadyForQuery: break readloop } } if parseErr != nil { return nil, parseErr } return psd, nil } // ErrorResponseToPgError converts a wire protocol error message to a *PgError. 
func ErrorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError { return &PgError{ Severity: msg.Severity, Code: string(msg.Code), Message: string(msg.Message), Detail: string(msg.Detail), Hint: msg.Hint, Position: msg.Position, InternalPosition: msg.InternalPosition, InternalQuery: string(msg.InternalQuery), Where: string(msg.Where), SchemaName: string(msg.SchemaName), TableName: string(msg.TableName), ColumnName: string(msg.ColumnName), DataTypeName: string(msg.DataTypeName), ConstraintName: msg.ConstraintName, File: string(msg.File), Line: msg.Line, Routine: string(msg.Routine), } } func noticeResponseToNotice(msg *pgproto3.NoticeResponse) *Notice { pgerr := ErrorResponseToPgError((*pgproto3.ErrorResponse)(msg)) return (*Notice)(pgerr) } // CancelRequest sends a cancel request to the PostgreSQL server. It returns an error if unable to deliver the cancel // request, but lack of an error does not ensure that the query was canceled. As specified in the documentation, there // is no way to be sure a query was canceled. See https://www.postgresql.org/docs/11/protocol-flow.html#id-1.10.5.7.9 func (pgConn *PgConn) CancelRequest(ctx context.Context) error { // Open a cancellation request to the same server. The address is taken from the net.Conn directly instead of reusing // the connection config. This is important in high availability configurations where fallback connections may be // specified or DNS may be used to load balance. 
serverAddr := pgConn.conn.RemoteAddr() cancelConn, err := pgConn.config.DialFunc(ctx, serverAddr.Network(), serverAddr.String()) if err != nil { return err } defer cancelConn.Close() if ctx != context.Background() { contextWatcher := ctxwatch.NewContextWatcher( func() { cancelConn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, func() { cancelConn.SetDeadline(time.Time{}) }, ) contextWatcher.Watch(ctx) defer contextWatcher.Unwatch() } buf := make([]byte, 16) binary.BigEndian.PutUint32(buf[0:4], 16) binary.BigEndian.PutUint32(buf[4:8], 80877102) binary.BigEndian.PutUint32(buf[8:12], uint32(pgConn.pid)) binary.BigEndian.PutUint32(buf[12:16], uint32(pgConn.secretKey)) _, err = cancelConn.Write(buf) if err != nil { return err } _, err = cancelConn.Read(buf) if err != io.EOF { return err } return nil } // WaitForNotification waits for a LISTON/NOTIFY message to be received. It returns an error if a notification was not // received. func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { if err := pgConn.lock(); err != nil { return err } defer pgConn.unlock() if ctx != context.Background() { select { case <-ctx.Done(): return newContextAlreadyDoneError(ctx) default: } pgConn.contextWatcher.Watch(ctx) defer pgConn.contextWatcher.Unwatch() } for { msg, err := pgConn.receiveMessage() if err != nil { return preferContextOverNetTimeoutError(ctx, err) } switch msg.(type) { case *pgproto3.NotificationResponse: return nil } } } // Exec executes SQL via the PostgreSQL simple query protocol. SQL may contain multiple queries. Execution is // implicitly wrapped in a transaction unless a transaction is already in progress or SQL contains transaction control // statements. // // Prefer ExecParams unless executing arbitrary SQL that may contain multiple queries. 
func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { if err := pgConn.lock(); err != nil { return &MultiResultReader{ closed: true, err: err, } } pgConn.multiResultReader = MultiResultReader{ pgConn: pgConn, ctx: ctx, } multiResult := &pgConn.multiResultReader if ctx != context.Background() { select { case <-ctx.Done(): multiResult.closed = true multiResult.err = newContextAlreadyDoneError(ctx) pgConn.unlock() return multiResult default: } pgConn.contextWatcher.Watch(ctx) } buf := pgConn.wbuf buf = (&pgproto3.Query{String: sql}).Encode(buf) n, err := pgConn.conn.Write(buf) if err != nil { pgConn.asyncClose() pgConn.contextWatcher.Unwatch() multiResult.closed = true multiResult.err = &writeError{err: err, safeToRetry: n == 0} pgConn.unlock() return multiResult } return multiResult } // ReceiveResults reads the result that might be returned by Postgres after a SendBytes // (e.a. after sending a CopyDone in a copy-both situation). // // This is a very low level method that requires deep understanding of the PostgreSQL wire protocol to use correctly. // See https://www.postgresql.org/docs/current/protocol.html. func (pgConn *PgConn) ReceiveResults(ctx context.Context) *MultiResultReader { if err := pgConn.lock(); err != nil { return &MultiResultReader{ closed: true, err: err, } } pgConn.multiResultReader = MultiResultReader{ pgConn: pgConn, ctx: ctx, } multiResult := &pgConn.multiResultReader if ctx != context.Background() { select { case <-ctx.Done(): multiResult.closed = true multiResult.err = newContextAlreadyDoneError(ctx) pgConn.unlock() return multiResult default: } pgConn.contextWatcher.Watch(ctx) } return multiResult } // ExecParams executes a command via the PostgreSQL extended query protocol. // // sql is a SQL command string. It may only contain one query. Parameter substitution is positional using $1, $2, $3, // etc. // // paramValues are the parameter values. It must be encoded in the format given by paramFormats. 
// // paramOIDs is a slice of data type OIDs for paramValues. If paramOIDs is nil, the server will infer the data type for // all parameters. Any paramOID element that is 0 that will cause the server to infer the data type for that parameter. // ExecParams will panic if len(paramOIDs) is not 0, 1, or len(paramValues). // // paramFormats is a slice of format codes determining for each paramValue column whether it is encoded in text or // binary format. If paramFormats is nil all params are text format. ExecParams will panic if // len(paramFormats) is not 0, 1, or len(paramValues). // // resultFormats is a slice of format codes determining for each result column whether it is encoded in text or // binary format. If resultFormats is nil all results will be in text format. // // ResultReader must be closed before PgConn can be used again. func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) *ResultReader { result := pgConn.execExtendedPrefix(ctx, paramValues) if result.closed { return result } buf := pgConn.wbuf buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(buf) buf = (&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf) pgConn.execExtendedSuffix(buf, result) return result } // ExecPrepared enqueues the execution of a prepared statement via the PostgreSQL extended query protocol. // // paramValues are the parameter values. It must be encoded in the format given by paramFormats. // // paramFormats is a slice of format codes determining for each paramValue column whether it is encoded in text or // binary format. If paramFormats is nil all params are text format. ExecPrepared will panic if // len(paramFormats) is not 0, 1, or len(paramValues). // // resultFormats is a slice of format codes determining for each result column whether it is encoded in text or // binary format. 
If resultFormats is nil all results will be in text format. // // ResultReader must be closed before PgConn can be used again. func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) *ResultReader { result := pgConn.execExtendedPrefix(ctx, paramValues) if result.closed { return result } buf := pgConn.wbuf buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf) pgConn.execExtendedSuffix(buf, result) return result } func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]byte) *ResultReader { pgConn.resultReader = ResultReader{ pgConn: pgConn, ctx: ctx, } result := &pgConn.resultReader if err := pgConn.lock(); err != nil { result.concludeCommand(nil, err) result.closed = true return result } if len(paramValues) > math.MaxUint16 { result.concludeCommand(nil, fmt.Errorf("extended protocol limited to %v parameters", math.MaxUint16)) result.closed = true pgConn.unlock() return result } if ctx != context.Background() { select { case <-ctx.Done(): result.concludeCommand(nil, newContextAlreadyDoneError(ctx)) result.closed = true pgConn.unlock() return result default: } pgConn.contextWatcher.Watch(ctx) } return result } func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) { buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(buf) buf = (&pgproto3.Execute{}).Encode(buf) buf = (&pgproto3.Sync{}).Encode(buf) n, err := pgConn.conn.Write(buf) if err != nil { pgConn.asyncClose() result.concludeCommand(nil, &writeError{err: err, safeToRetry: n == 0}) pgConn.contextWatcher.Unwatch() result.closed = true pgConn.unlock() return } result.readUntilRowDescription() } // CopyTo executes the copy command sql and copies the results to w. 
func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (CommandTag, error) { if err := pgConn.lock(); err != nil { return nil, err } if ctx != context.Background() { select { case <-ctx.Done(): pgConn.unlock() return nil, newContextAlreadyDoneError(ctx) default: } pgConn.contextWatcher.Watch(ctx) defer pgConn.contextWatcher.Unwatch() } // Send copy to command buf := pgConn.wbuf buf = (&pgproto3.Query{String: sql}).Encode(buf) n, err := pgConn.conn.Write(buf) if err != nil { pgConn.asyncClose() pgConn.unlock() return nil, &writeError{err: err, safeToRetry: n == 0} } // Read results var commandTag CommandTag var pgErr error for { msg, err := pgConn.receiveMessage() if err != nil { pgConn.asyncClose() return nil, preferContextOverNetTimeoutError(ctx, err) } switch msg := msg.(type) { case *pgproto3.CopyDone: case *pgproto3.CopyData: _, err := w.Write(msg.Data) if err != nil { pgConn.asyncClose() return nil, err } case *pgproto3.ReadyForQuery: pgConn.unlock() return commandTag, pgErr case *pgproto3.CommandComplete: commandTag = CommandTag(msg.CommandTag) case *pgproto3.ErrorResponse: pgErr = ErrorResponseToPgError(msg) } } } // CopyFrom executes the copy command sql and copies all of r to the PostgreSQL server. // // Note: context cancellation will only interrupt operations on the underlying PostgreSQL network connection. Reads on r // could still block. 
//
// r is read on a separate goroutine. Each read becomes one CopyData message carrying at most 65531 bytes of payload.
// When r returns io.EOF the copy is finished with CopyDone. Any other error from reading r or writing to the connection
// aborts it with CopyFail, whose message is that error's text. An ErrorResponse from the server stops sending early;
// that error is returned once the server reports ReadyForQuery.
func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (CommandTag, error) {
	if err := pgConn.lock(); err != nil {
		return nil, err
	}
	defer pgConn.unlock()

	if ctx != context.Background() {
		select {
		case <-ctx.Done():
			return nil, newContextAlreadyDoneError(ctx)
		default:
		}
		pgConn.contextWatcher.Watch(ctx)
		defer pgConn.contextWatcher.Unwatch()
	}

	// Send the copy from command.
	buf := pgConn.wbuf
	buf = (&pgproto3.Query{String: sql}).Encode(buf)

	n, err := pgConn.conn.Write(buf)
	if err != nil {
		pgConn.asyncClose()
		return nil, &writeError{err: err, safeToRetry: n == 0}
	}

	// Send copy data.
	// abortCopyChan tells the copy goroutine to stop after its current message. copyErrChan carries the goroutine's
	// terminal error (io.EOF on a complete read of r) and is buffered so the goroutine never blocks sending it.
	// signalMessageChan is signaled when a server message has been received in the background.
	abortCopyChan := make(chan struct{})
	copyErrChan := make(chan error, 1)
	signalMessageChan := pgConn.signalMessage()
	var wg sync.WaitGroup
	wg.Add(1)

	go func() {
		defer wg.Done()
		// CopyData framing: type byte 'd', then a 4-byte length at offset sp (1) that counts itself plus the payload,
		// then the payload. r reads straight into the payload area starting at offset 5 to avoid an extra copy.
		buf := make([]byte, 0, 65536)
		buf = append(buf, 'd')
		sp := len(buf)

		for {
			n, readErr := r.Read(buf[5:cap(buf)])
			if n > 0 {
				buf = buf[0 : n+5]
				pgio.SetInt32(buf[sp:], int32(n+4))

				_, writeErr := pgConn.conn.Write(buf)
				if writeErr != nil {
					// Write errors are always fatal, but we can't use asyncClose because we are in a different goroutine.
					pgConn.conn.Close()

					copyErrChan <- writeErr
					return
				}
			}
			// Data returned along with an error is sent above before the error is reported.
			if readErr != nil {
				copyErrChan <- readErr
				return
			}

			// Stop early once the main goroutine has given up on the copy (e.g. the server sent an error).
			select {
			case <-abortCopyChan:
				return
			default:
			}
		}
	}()

	// Wait until either the copy goroutine reports its terminal error or the server sends an ErrorResponse. Any other
	// server message is discarded and a new background receive is started.
	var pgErr error
	var copyErr error
	for copyErr == nil && pgErr == nil {
		select {
		case copyErr = <-copyErrChan:
		case <-signalMessageChan:
			msg, err := pgConn.receiveMessage()
			if err != nil {
				pgConn.asyncClose()
				return nil, preferContextOverNetTimeoutError(ctx, err)
			}

			switch msg := msg.(type) {
			case *pgproto3.ErrorResponse:
				pgErr = ErrorResponseToPgError(msg)
			default:
				signalMessageChan = pgConn.signalMessage()
			}
		}
	}
	close(abortCopyChan)
	// Make sure io goroutine finishes before writing.
	wg.Wait()

	// Reuse the connection write buffer; the goroutine above used its own buffer.
	// Finish with CopyDone when r was fully read or the server already reported an error; otherwise abort the copy
	// with CopyFail carrying the error text (copyErr is non-nil here because the loop above exited without pgErr).
	buf = buf[:0]
	if copyErr == io.EOF || pgErr != nil {
		copyDone := &pgproto3.CopyDone{}
		buf = copyDone.Encode(buf)
	} else {
		copyFail := &pgproto3.CopyFail{Message: copyErr.Error()}
		buf = copyFail.Encode(buf)
	}
	_, err = pgConn.conn.Write(buf)
	if err != nil {
		pgConn.asyncClose()
		return nil, err
	}

	// Read results until ReadyForQuery, keeping the command tag and the most recent server error.
	var commandTag CommandTag
	for {
		msg, err := pgConn.receiveMessage()
		if err != nil {
			pgConn.asyncClose()
			return nil, preferContextOverNetTimeoutError(ctx, err)
		}

		switch msg := msg.(type) {
		case *pgproto3.ReadyForQuery:
			return commandTag, pgErr
		case *pgproto3.CommandComplete:
			commandTag = CommandTag(msg.CommandTag)
		case *pgproto3.ErrorResponse:
			pgErr = ErrorResponseToPgError(msg)
		}
	}
}

// MultiResultReader is a reader for a command that could return multiple results such as Exec or ExecBatch.
type MultiResultReader struct {
	pgConn *PgConn
	ctx    context.Context // used to prefer context errors over network timeout errors

	rr *ResultReader // current result, set by NextResult

	closed bool  // true once ReadyForQuery or a receive error has been seen
	err    error // first error from the server or the connection
}

// ReadAll reads all available results. Calling ReadAll is mutually exclusive with all other MultiResultReader methods.
func (mrr *MultiResultReader) ReadAll() ([]*Result, error) {
	var results []*Result

	for mrr.NextResult() {
		results = append(results, mrr.ResultReader().Read())
	}
	err := mrr.Close()

	return results, err
}

// receiveMessage reads the next message from the connection and updates the reader's state.
//
// On a receive error it stops watching the context, records the (context-preferred) error, marks the reader closed,
// and closes the connection asynchronously. On ReadyForQuery it stops watching the context, marks the reader closed,
// and unlocks the connection. An ErrorResponse is recorded in mrr.err but the message is still returned.
func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) {
	msg, err := mrr.pgConn.receiveMessage()

	if err != nil {
		mrr.pgConn.contextWatcher.Unwatch()
		mrr.err = preferContextOverNetTimeoutError(mrr.ctx, err)
		mrr.closed = true
		mrr.pgConn.asyncClose()
		return nil, mrr.err
	}

	switch msg := msg.(type) {
	case *pgproto3.ReadyForQuery:
		mrr.pgConn.contextWatcher.Unwatch()
		mrr.closed = true
		mrr.pgConn.unlock()
	case *pgproto3.ErrorResponse:
		mrr.err = ErrorResponseToPgError(msg)
	}

	return msg, nil
}

// NextResult advances the MultiResultReader to the next result and returns true if a result is available.
func (mrr *MultiResultReader) NextResult() bool { for !mrr.closed && mrr.err == nil { msg, err := mrr.receiveMessage() if err != nil { return false } switch msg := msg.(type) { case *pgproto3.RowDescription: mrr.pgConn.resultReader = ResultReader{ pgConn: mrr.pgConn, multiResultReader: mrr, ctx: mrr.ctx, fieldDescriptions: msg.Fields, } mrr.rr = &mrr.pgConn.resultReader return true case *pgproto3.CommandComplete: mrr.pgConn.resultReader = ResultReader{ commandTag: CommandTag(msg.CommandTag), commandConcluded: true, closed: true, } mrr.rr = &mrr.pgConn.resultReader return true case *pgproto3.EmptyQueryResponse: return false } } return false } // ResultReader returns the current ResultReader. func (mrr *MultiResultReader) ResultReader() *ResultReader { return mrr.rr } // Close closes the MultiResultReader and returns the first error that occurred during the MultiResultReader's use. func (mrr *MultiResultReader) Close() error { for !mrr.closed { _, err := mrr.receiveMessage() if err != nil { return mrr.err } } return mrr.err } // ResultReader is a reader for the result of a single query. type ResultReader struct { pgConn *PgConn multiResultReader *MultiResultReader ctx context.Context fieldDescriptions []pgproto3.FieldDescription rowValues [][]byte commandTag CommandTag commandConcluded bool closed bool err error } // Result is the saved query response that is returned by calling Read on a ResultReader. type Result struct { FieldDescriptions []pgproto3.FieldDescription Rows [][][]byte CommandTag CommandTag Err error } // Read saves the query response to a Result. 
func (rr *ResultReader) Read() *Result { br := &Result{} for rr.NextRow() { if br.FieldDescriptions == nil { br.FieldDescriptions = make([]pgproto3.FieldDescription, len(rr.FieldDescriptions())) copy(br.FieldDescriptions, rr.FieldDescriptions()) } row := make([][]byte, len(rr.Values())) copy(row, rr.Values()) br.Rows = append(br.Rows, row) } br.CommandTag, br.Err = rr.Close() return br } // NextRow advances the ResultReader to the next row and returns true if a row is available. func (rr *ResultReader) NextRow() bool { for !rr.commandConcluded { msg, err := rr.receiveMessage() if err != nil { return false } switch msg := msg.(type) { case *pgproto3.DataRow: rr.rowValues = msg.Values return true } } return false } // FieldDescriptions returns the field descriptions for the current result set. The returned slice is only valid until // the ResultReader is closed. func (rr *ResultReader) FieldDescriptions() []pgproto3.FieldDescription { return rr.fieldDescriptions } // Values returns the current row data. NextRow must have been previously been called. The returned [][]byte is only // valid until the next NextRow call or the ResultReader is closed. However, the underlying byte data is safe to // retain a reference to and mutate. func (rr *ResultReader) Values() [][]byte { return rr.rowValues } // Close consumes any remaining result data and returns the command tag or // error. func (rr *ResultReader) Close() (CommandTag, error) { if rr.closed { return rr.commandTag, rr.err } rr.closed = true for !rr.commandConcluded { _, err := rr.receiveMessage() if err != nil { return nil, rr.err } } if rr.multiResultReader == nil { for { msg, err := rr.receiveMessage() if err != nil { return nil, rr.err } switch msg := msg.(type) { // Detect a deferred constraint violation where the ErrorResponse is sent after CommandComplete. 
case *pgproto3.ErrorResponse: rr.err = ErrorResponseToPgError(msg) case *pgproto3.ReadyForQuery: rr.pgConn.contextWatcher.Unwatch() rr.pgConn.unlock() return rr.commandTag, rr.err } } } return rr.commandTag, rr.err } // readUntilRowDescription ensures the ResultReader's fieldDescriptions are loaded. It does not return an error as any // error will be stored in the ResultReader. func (rr *ResultReader) readUntilRowDescription() { for !rr.commandConcluded { // Peek before receive to avoid consuming a DataRow if the result set does not include a RowDescription method. // This should never happen under normal pgconn usage, but it is possible if SendBytes and ReceiveResults are // manually used to construct a query that does not issue a describe statement. msg, _ := rr.pgConn.peekMessage() if _, ok := msg.(*pgproto3.DataRow); ok { return } // Consume the message msg, _ = rr.receiveMessage() if _, ok := msg.(*pgproto3.RowDescription); ok { return } } } func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error) { if rr.multiResultReader == nil { msg, err = rr.pgConn.receiveMessage() } else { msg, err = rr.multiResultReader.receiveMessage() } if err != nil { err = preferContextOverNetTimeoutError(rr.ctx, err) rr.concludeCommand(nil, err) rr.pgConn.contextWatcher.Unwatch() rr.closed = true if rr.multiResultReader == nil { rr.pgConn.asyncClose() } return nil, rr.err } switch msg := msg.(type) { case *pgproto3.RowDescription: rr.fieldDescriptions = msg.Fields case *pgproto3.CommandComplete: rr.concludeCommand(CommandTag(msg.CommandTag), nil) case *pgproto3.EmptyQueryResponse: rr.concludeCommand(nil, nil) case *pgproto3.ErrorResponse: rr.concludeCommand(nil, ErrorResponseToPgError(msg)) } return msg, nil } func (rr *ResultReader) concludeCommand(commandTag CommandTag, err error) { // Keep the first error that is recorded. 
Store the error before checking if the command is already concluded to // allow for receiving an error after CommandComplete but before ReadyForQuery. if err != nil && rr.err == nil { rr.err = err } if rr.commandConcluded { return } rr.commandTag = commandTag rr.rowValues = nil rr.commandConcluded = true } // Batch is a collection of queries that can be sent to the PostgreSQL server in a single round-trip. type Batch struct { buf []byte } // ExecParams appends an ExecParams command to the batch. See PgConn.ExecParams for parameter descriptions. func (batch *Batch) ExecParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) { batch.buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(batch.buf) batch.ExecPrepared("", paramValues, paramFormats, resultFormats) } // ExecPrepared appends an ExecPrepared e command to the batch. See PgConn.ExecPrepared for parameter descriptions. func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) { batch.buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(batch.buf) batch.buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(batch.buf) batch.buf = (&pgproto3.Execute{}).Encode(batch.buf) } // ExecBatch executes all the queries in batch in a single round-trip. Execution is implicitly transactional unless a // transaction is already in progress or SQL contains transaction control statements. 
func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultReader { if err := pgConn.lock(); err != nil { return &MultiResultReader{ closed: true, err: err, } } pgConn.multiResultReader = MultiResultReader{ pgConn: pgConn, ctx: ctx, } multiResult := &pgConn.multiResultReader if ctx != context.Background() { select { case <-ctx.Done(): multiResult.closed = true multiResult.err = newContextAlreadyDoneError(ctx) pgConn.unlock() return multiResult default: } pgConn.contextWatcher.Watch(ctx) } batch.buf = (&pgproto3.Sync{}).Encode(batch.buf) // A large batch can deadlock without concurrent reading and writing. If the Write fails the underlying net.Conn is // closed. This is all that can be done without introducing a race condition or adding a concurrent safe communication // channel to relay the error back. The practical effect of this is that the underlying Write error is not reported. // The error the code reading the batch results receives will be a closed connection error. // // See https://github.com/jackc/pgx/issues/374. go func() { _, err := pgConn.conn.Write(batch.buf) if err != nil { pgConn.conn.Close() } }() return multiResult } // EscapeString escapes a string such that it can safely be interpolated into a SQL command string. It does not include // the surrounding single quotes. // // The current implementation requires that standard_conforming_strings=on and client_encoding="UTF8". If these // conditions are not met an error will be returned. It is possible these restrictions will be lifted in the future. 
func (pgConn *PgConn) EscapeString(s string) (string, error) { if pgConn.ParameterStatus("standard_conforming_strings") != "on" { return "", errors.New("EscapeString must be run with standard_conforming_strings=on") } if pgConn.ParameterStatus("client_encoding") != "UTF8" { return "", errors.New("EscapeString must be run with client_encoding=UTF8") } return strings.Replace(s, "'", "''", -1), nil } // HijackedConn is the result of hijacking a connection. // // Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning // compatibility. type HijackedConn struct { Conn net.Conn // the underlying TCP or unix domain socket connection PID uint32 // backend pid SecretKey uint32 // key to use to send a cancel query message to the server ParameterStatuses map[string]string // parameters that have been reported by the server TxStatus byte Frontend Frontend Config *Config } // Hijack extracts the internal connection data. pgConn must be in an idle state. pgConn is unusable after hijacking. // Hijacking is typically only useful when using pgconn to establish a connection, but taking complete control of the // raw connection after that (e.g. a load balancer or proxy). // // Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning // compatibility. func (pgConn *PgConn) Hijack() (*HijackedConn, error) { if err := pgConn.lock(); err != nil { return nil, err } pgConn.status = connStatusClosed return &HijackedConn{ Conn: pgConn.conn, PID: pgConn.pid, SecretKey: pgConn.secretKey, ParameterStatuses: pgConn.parameterStatuses, TxStatus: pgConn.txStatus, Frontend: pgConn.frontend, Config: pgConn.config, }, nil } // Construct created a PgConn from an already established connection to a PostgreSQL server. This is the inverse of // PgConn.Hijack. The connection must be in an idle state. 
//
// The HijackedConn fields are used as-is. The returned PgConn is marked idle and gets a fresh write buffer, a new
// cleanupDone channel, and a new context watcher for hc.Conn.
//
// Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning
// compatibility.
func Construct(hc *HijackedConn) (*PgConn, error) {
	pgConn := &PgConn{
		conn:              hc.Conn,
		pid:               hc.PID,
		secretKey:         hc.SecretKey,
		parameterStatuses: hc.ParameterStatuses,
		txStatus:          hc.TxStatus,
		frontend:          hc.Frontend,
		config:            hc.Config,

		status: connStatusIdle,

		wbuf:        make([]byte, 0, wbufLen),
		cleanupDone: make(chan struct{}),
	}

	pgConn.contextWatcher = newContextWatcher(pgConn.conn)

	return pgConn, nil
}
pgconn-1.14.0/pgconn_stress_test.go000066400000000000000000000050021437172345200173430ustar00rootroot00000000000000package pgconn_test

import (
	"context"
	"math/rand"
	"os"
	"runtime"
	"strconv"
	"testing"

	"github.com/jackc/pgconn"
	"github.com/stretchr/testify/require"
)

// TestConnStress runs randomly chosen Exec, ExecParams, and batch operations on a single connection. It runs 10000 of
// them, multiplied by PGX_TEST_STRESS_FACTOR when that is set. It then checks that goroutines were not orphaned.
func TestConnStress(t *testing.T) {
	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	actionCount := 10000
	if s := os.Getenv("PGX_TEST_STRESS_FACTOR"); s != "" {
		stressFactor, err := strconv.ParseInt(s, 10, 64)
		require.Nil(t, err, "Failed to parse PGX_TEST_STRESS_FACTOR")
		actionCount *= int(stressFactor)
	}

	setupStressDB(t, pgConn)

	actions := []struct {
		name string
		fn   func(*pgconn.PgConn) error
	}{
		{"Exec Select", stressExecSelect},
		{"ExecParams Select", stressExecParamsSelect},
		{"Batch", stressBatch},
	}

	for i := 0; i < actionCount; i++ {
		action := actions[rand.Intn(len(actions))]
		err := action.fn(pgConn)
		require.Nilf(t, err, "%d: %s", i, action.name)
	}

	// Each call with a context starts a goroutine. Ensure they are cleaned up when context is not canceled.
	numGoroutine := runtime.NumGoroutine()
	require.Truef(t, numGoroutine < 1000, "goroutines appear to be orphaned: %d in process", numGoroutine)
}

// setupStressDB creates the temporary widgets table used by the stress actions and inserts three rows into it.
func setupStressDB(t *testing.T, pgConn *pgconn.PgConn) {
	_, err := pgConn.Exec(context.Background(), ` create temporary table widgets( id serial primary key, name varchar not null, description text, creation_time timestamptz default now() ); insert into widgets(name, description) values ('Foo', 'bar'), ('baz', 'Something really long Something really long Something really long Something really long Something really long'), ('a', 'b')`).ReadAll()
	require.NoError(t, err)
}

// stressExecSelect selects every widget with Exec under a cancelable context.
func stressExecSelect(pgConn *pgconn.PgConn) error {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	_, err := pgConn.Exec(ctx, "select * from widgets").ReadAll()
	return err
}

// stressExecParamsSelect selects widgets with id < 10 with ExecParams under a cancelable context.
func stressExecParamsSelect(pgConn *pgconn.PgConn) error {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	result := pgConn.ExecParams(ctx, "select * from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil).Read()
	return result.Err
}

// stressBatch sends two widget selects as a single batch under a cancelable context.
func stressBatch(pgConn *pgconn.PgConn) error {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	batch := &pgconn.Batch{}

	batch.ExecParams("select * from widgets", nil, nil, nil, nil)
	batch.ExecParams("select * from widgets where id < $1", [][]byte{[]byte("10")}, nil, nil, nil)
	_, err := pgConn.ExecBatch(ctx, batch).ReadAll()
	return err
}
pgconn-1.14.0/pgconn_test.go000066400000000000000000002066461437172345200157610ustar00rootroot00000000000000package pgconn_test

import (
	"bytes"
	"compress/gzip"
	"context"
	"crypto/tls"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"math"
	"net"
	"os"
	"strconv"
	"strings"
	"testing"
	"time"

	"github.com/jackc/pgconn"
	"github.com/jackc/pgmock"
	"github.com/jackc/pgproto3/v2"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestConnect checks that pgconn.Connect succeeds for each connection string named in the table. A case is skipped
// when its environment variable is not set.
func TestConnect(t *testing.T) {
	tests := []struct {
		name string
		env  string
	}{
		{"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"},
		{"TCP", "PGX_TEST_TCP_CONN_STRING"},
		{"Plain password", "PGX_TEST_PLAIN_PASSWORD_CONN_STRING"},
		{"MD5 password", "PGX_TEST_MD5_PASSWORD_CONN_STRING"},
		{"SCRAM password", "PGX_TEST_SCRAM_PASSWORD_CONN_STRING"},
	}

	for _, tt := range tests {
		tt := tt // capture the range variable for the subtest closure
		t.Run(tt.name, func(t *testing.T) {
			connString := os.Getenv(tt.env)
			if connString == "" {
				t.Skipf("Skipping due to missing environment variable %v", tt.env)
			}

			conn, err := pgconn.Connect(context.Background(), connString)
			require.NoError(t, err)

			closeConn(t, conn)
		})
	}
}

// TestConnectWithOptions is TestConnect using pgconn.ConnectWithOptions with a GetSSLPassword callback set.
func TestConnectWithOptions(t *testing.T) {
	tests := []struct {
		name string
		env  string
	}{
		{"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"},
		{"TCP", "PGX_TEST_TCP_CONN_STRING"},
		{"Plain password", "PGX_TEST_PLAIN_PASSWORD_CONN_STRING"},
		{"MD5 password", "PGX_TEST_MD5_PASSWORD_CONN_STRING"},
		{"SCRAM password", "PGX_TEST_SCRAM_PASSWORD_CONN_STRING"},
	}

	for _, tt := range tests {
		tt := tt // capture the range variable for the subtest closure
		t.Run(tt.name, func(t *testing.T) {
			connString := os.Getenv(tt.env)
			if connString == "" {
				t.Skipf("Skipping due to missing environment variable %v", tt.env)
			}
			var sslOptions pgconn.ParseConfigOptions
			sslOptions.GetSSLPassword = GetSSLPassword
			conn, err := pgconn.ConnectWithOptions(context.Background(), connString, sslOptions)
			require.NoError(t, err)

			closeConn(t, conn)
		})
	}
}

// TestConnectTLS is separate from other connect tests because it has an additional test to ensure it really is a secure
// connection.
func TestConnectTLS(t *testing.T) { t.Parallel() connString := os.Getenv("PGX_TEST_TLS_CONN_STRING") if connString == "" { t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TLS_CONN_STRING") } var conn *pgconn.PgConn var err error var sslOptions pgconn.ParseConfigOptions sslOptions.GetSSLPassword = GetSSLPassword config, err := pgconn.ParseConfigWithOptions(connString, sslOptions) require.Nil(t, err) conn, err = pgconn.ConnectConfig(context.Background(), config) require.NoError(t, err) if _, ok := conn.Conn().(*tls.Conn); !ok { t.Error("not a TLS connection") } closeConn(t, conn) } type pgmockWaitStep time.Duration func (s pgmockWaitStep) Step(*pgproto3.Backend) error { time.Sleep(time.Duration(s)) return nil } func TestConnectTimeout(t *testing.T) { t.Parallel() tests := []struct { name string connect func(connStr string) error }{ { name: "via context that times out", connect: func(connStr string) error { ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*50) defer cancel() _, err := pgconn.Connect(ctx, connStr) return err }, }, { name: "via config ConnectTimeout", connect: func(connStr string) error { conf, err := pgconn.ParseConfig(connStr) require.NoError(t, err) conf.ConnectTimeout = time.Microsecond * 50 _, err = pgconn.ConnectConfig(context.Background(), conf) return err }, }, } for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() script := &pgmock.Script{ Steps: []pgmock.Step{ pgmock.ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}), pgmock.SendMessage(&pgproto3.AuthenticationOk{}), pgmockWaitStep(time.Millisecond * 500), pgmock.SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}), pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}), }, } ln, err := net.Listen("tcp", "127.0.0.1:") require.NoError(t, err) defer ln.Close() serverErrChan := make(chan error, 1) go func() { defer 
close(serverErrChan) conn, err := ln.Accept() if err != nil { serverErrChan <- err return } defer conn.Close() err = conn.SetDeadline(time.Now().Add(time.Millisecond * 450)) if err != nil { serverErrChan <- err return } err = script.Run(pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn)) if err != nil { serverErrChan <- err return } }() parts := strings.Split(ln.Addr().String(), ":") host := parts[0] port := parts[1] connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port) tooLate := time.Now().Add(time.Millisecond * 500) err = tt.connect(connStr) require.True(t, pgconn.Timeout(err), err) require.True(t, time.Now().Before(tooLate)) }) } } func TestConnectTimeoutStuckOnTLSHandshake(t *testing.T) { t.Parallel() tests := []struct { name string connect func(connStr string) error }{ { name: "via context that times out", connect: func(connStr string) error { ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*10) defer cancel() _, err := pgconn.Connect(ctx, connStr) return err }, }, { name: "via config ConnectTimeout", connect: func(connStr string) error { conf, err := pgconn.ParseConfig(connStr) require.NoError(t, err) conf.ConnectTimeout = time.Millisecond * 10 _, err = pgconn.ConnectConfig(context.Background(), conf) return err }, }, } for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() ln, err := net.Listen("tcp", "127.0.0.1:") require.NoError(t, err) defer ln.Close() serverErrChan := make(chan error) defer close(serverErrChan) go func() { conn, err := ln.Accept() if err != nil { serverErrChan <- err return } defer conn.Close() var buf []byte _, err = conn.Read(buf) if err != nil { serverErrChan <- err return } // Sleeping to hang the TLS handshake. 
time.Sleep(time.Minute) }() parts := strings.Split(ln.Addr().String(), ":") host := parts[0] port := parts[1] connStr := fmt.Sprintf("host=%s port=%s", host, port) errChan := make(chan error) go func() { err := tt.connect(connStr) errChan <- err }() select { case err = <-errChan: require.True(t, pgconn.Timeout(err), err) case err = <-serverErrChan: t.Fatalf("server failed with error: %s", err) case <-time.After(time.Millisecond * 100): t.Fatal("exceeded connection timeout without erroring out") } }) } } func TestConnectInvalidUser(t *testing.T) { t.Parallel() connString := os.Getenv("PGX_TEST_TCP_CONN_STRING") if connString == "" { t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING") } config, err := pgconn.ParseConfig(connString) require.NoError(t, err) config.User = "pgxinvalidusertest" _, err = pgconn.ConnectConfig(context.Background(), config) require.Error(t, err) pgErr, ok := errors.Unwrap(err).(*pgconn.PgError) if !ok { t.Fatalf("Expected to receive a wrapped PgError, instead received: %v", err) } if pgErr.Code != "28000" && pgErr.Code != "28P01" { t.Fatalf("Expected to receive a PgError with code 28000 or 28P01, instead received: %v", pgErr) } } func TestConnectWithConnectionRefused(t *testing.T) { t.Parallel() // Presumably nothing is listening on 127.0.0.1:1 conn, err := pgconn.Connect(context.Background(), "host=127.0.0.1 port=1") if err == nil { conn.Close(context.Background()) t.Fatal("Expected error establishing connection to bad port") } } func TestConnectCustomDialer(t *testing.T) { t.Parallel() config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) dialed := false config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { dialed = true return net.Dial(network, address) } conn, err := pgconn.ConnectConfig(context.Background(), config) require.NoError(t, err) require.True(t, dialed) closeConn(t, conn) } func TestConnectCustomLookup(t *testing.T) { 
t.Parallel() connString := os.Getenv("PGX_TEST_TCP_CONN_STRING") if connString == "" { t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING") } config, err := pgconn.ParseConfig(connString) require.NoError(t, err) looked := false config.LookupFunc = func(ctx context.Context, host string) (addrs []string, err error) { looked = true return net.LookupHost(host) } conn, err := pgconn.ConnectConfig(context.Background(), config) require.NoError(t, err) require.True(t, looked) closeConn(t, conn) } func TestConnectCustomLookupWithPort(t *testing.T) { t.Parallel() connString := os.Getenv("PGX_TEST_TCP_CONN_STRING") if connString == "" { t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING") } config, err := pgconn.ParseConfig(connString) require.NoError(t, err) origPort := config.Port // Chnage the config an invalid port so it will fail if used config.Port = 0 looked := false config.LookupFunc = func(ctx context.Context, host string) ([]string, error) { looked = true addrs, err := net.LookupHost(host) if err != nil { return nil, err } for i := range addrs { addrs[i] = net.JoinHostPort(addrs[i], strconv.FormatUint(uint64(origPort), 10)) } return addrs, nil } conn, err := pgconn.ConnectConfig(context.Background(), config) require.NoError(t, err) require.True(t, looked) closeConn(t, conn) } func TestConnectWithRuntimeParams(t *testing.T) { t.Parallel() config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) config.RuntimeParams = map[string]string{ "application_name": "pgxtest", "search_path": "myschema", } conn, err := pgconn.ConnectConfig(context.Background(), config) require.NoError(t, err) defer closeConn(t, conn) result := conn.ExecParams(context.Background(), "show application_name", nil, nil, nil, nil).Read() require.Nil(t, result.Err) assert.Equal(t, 1, len(result.Rows)) assert.Equal(t, "pgxtest", string(result.Rows[0][0])) result = conn.ExecParams(context.Background(), 
"show search_path", nil, nil, nil, nil).Read() require.Nil(t, result.Err) assert.Equal(t, 1, len(result.Rows)) assert.Equal(t, "myschema", string(result.Rows[0][0])) } func TestConnectWithFallback(t *testing.T) { t.Parallel() config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) // Prepend current primary config to fallbacks config.Fallbacks = append([]*pgconn.FallbackConfig{ &pgconn.FallbackConfig{ Host: config.Host, Port: config.Port, TLSConfig: config.TLSConfig, }, }, config.Fallbacks...) // Make primary config bad config.Host = "localhost" config.Port = 1 // presumably nothing listening here // Prepend bad first fallback config.Fallbacks = append([]*pgconn.FallbackConfig{ &pgconn.FallbackConfig{ Host: "localhost", Port: 1, TLSConfig: config.TLSConfig, }, }, config.Fallbacks...) conn, err := pgconn.ConnectConfig(context.Background(), config) require.NoError(t, err) closeConn(t, conn) } func TestConnectWithValidateConnect(t *testing.T) { t.Parallel() config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) dialCount := 0 config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { dialCount++ return net.Dial(network, address) } acceptConnCount := 0 config.ValidateConnect = func(ctx context.Context, conn *pgconn.PgConn) error { acceptConnCount++ if acceptConnCount < 2 { return errors.New("reject first conn") } return nil } // Append current primary config to fallbacks config.Fallbacks = append(config.Fallbacks, &pgconn.FallbackConfig{ Host: config.Host, Port: config.Port, TLSConfig: config.TLSConfig, }) // Repeat fallbacks config.Fallbacks = append(config.Fallbacks, config.Fallbacks...) 
conn, err := pgconn.ConnectConfig(context.Background(), config) require.NoError(t, err) closeConn(t, conn) assert.True(t, dialCount > 1) assert.True(t, acceptConnCount > 1) }

// TestConnectWithValidateConnectTargetSessionAttrsReadWrite checks that ValidateConnectTargetSessionAttrsReadWrite
// makes ConnectConfig fail when the session starts with default_transaction_read_only=on.
func TestConnectWithValidateConnectTargetSessionAttrsReadWrite(t *testing.T) {
	t.Parallel()

	config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING"))
	require.NoError(t, err)

	config.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsReadWrite
	config.RuntimeParams["default_transaction_read_only"] = "on"

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	conn, err := pgconn.ConnectConfig(ctx, config)
	if !assert.NotNil(t, err) {
		// Connecting unexpectedly succeeded; release the connection.
		conn.Close(ctx)
	}
}

// TestConnectWithAfterConnect checks that settings applied by Config.AfterConnect are visible on the returned connection.
func TestConnectWithAfterConnect(t *testing.T) {
	t.Parallel()

	config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING"))
	require.NoError(t, err)

	config.AfterConnect = func(ctx context.Context, conn *pgconn.PgConn) error {
		_, err := conn.Exec(ctx, "set search_path to foobar;").ReadAll()
		return err
	}

	conn, err := pgconn.ConnectConfig(context.Background(), config)
	require.NoError(t, err)
	// Register cleanup right away so the connection is closed even if a later require stops the test.
	defer closeConn(t, conn)

	results, err := conn.Exec(context.Background(), "show search_path;").ReadAll()
	require.NoError(t, err)
	// Guard the indexing below so an unexpected result shape fails the test instead of panicking.
	require.Len(t, results, 1)
	require.Len(t, results[0].Rows, 1)

	assert.Equal(t, []byte("foobar"), results[0].Rows[0][0])
}

// TestConnectConfigRequiresConfigFromParseConfig checks that ConnectConfig panics when given a Config
// that was not built by ParseConfig.
func TestConnectConfigRequiresConfigFromParseConfig(t *testing.T) {
	t.Parallel()

	config := &pgconn.Config{}

	require.PanicsWithValue(t, "config must be created by ParseConfig", func() { pgconn.ConnectConfig(context.Background(), config) })
}

// TestConnPrepareSyntaxError checks that preparing invalid SQL returns an error and no statement description,
// and that the connection stays usable.
func TestConnPrepareSyntaxError(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	psd, err := pgConn.Prepare(context.Background(), "ps1", "SYNTAX ERROR", nil)
	require.Nil(t, psd)
	require.NotNil(t, err)

	ensureConnValid(t, pgConn)
}

// TestConnPrepareContextPrecanceled checks that Prepare with an already-canceled context fails with a
// retryable error matching context.Canceled.
func TestConnPrepareContextPrecanceled(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	ctx, cancel := context.WithCancel(context.Background())
	cancel()

	psd, err := pgConn.Prepare(ctx, "ps1", "select 1", nil)
	assert.Nil(t, psd)
	assert.Error(t, err)
	assert.True(t, errors.Is(err, context.Canceled))
	assert.True(t, pgconn.SafeToRetry(err))

	ensureConnValid(t, pgConn)
}

// TestConnExec checks a single-statement Exec: one result with the expected command tag and row.
func TestConnExec(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	results, err := pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll()
	assert.NoError(t, err)

	// require (not assert) so a missing result or row fails cleanly instead of panicking on the index.
	require.Len(t, results, 1)
	assert.Nil(t, results[0].Err)
	assert.Equal(t, "SELECT 1", string(results[0].CommandTag))
	require.Len(t, results[0].Rows, 1)
	assert.Equal(t, "Hello, world", string(results[0].Rows[0][0]))

	ensureConnValid(t, pgConn)
}

// TestConnExecEmpty checks that Exec of an empty statement (";") produces no results and no error.
func TestConnExecEmpty(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	multiResult := pgConn.Exec(context.Background(), ";")

	resultCount := 0
	for multiResult.NextResult() {
		resultCount++
		multiResult.ResultReader().Close()
	}
	assert.Equal(t, 0, resultCount)
	err = multiResult.Close()
	assert.NoError(t, err)

	ensureConnValid(t, pgConn)
}

// TestConnExecMultipleQueries checks that Exec of two statements returns one result per statement, in order.
func TestConnExecMultipleQueries(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	results, err := pgConn.Exec(context.Background(), "select 'Hello, world'; select 1").ReadAll()
	assert.NoError(t, err)

	// require (not assert) so a short result set fails cleanly instead of panicking on the indexes below.
	require.Len(t, results, 2)

	assert.Nil(t, results[0].Err)
	assert.Equal(t, "SELECT 1", string(results[0].CommandTag))
	require.Len(t, results[0].Rows, 1)
	assert.Equal(t, "Hello, world", string(results[0].Rows[0][0]))

	assert.Nil(t, results[1].Err)
	assert.Equal(t, "SELECT 1", string(results[1].CommandTag))
	require.Len(t, results[1].Rows, 1)
	assert.Equal(t, "1", string(results[1].Rows[0][0]))

	ensureConnValid(t, pgConn)
}

// TestConnExecMultipleQueriesEagerFieldDescriptions checks that each result's field descriptions are
// available as soon as NextResult returns true, before the result is closed.
func TestConnExecMultipleQueriesEagerFieldDescriptions(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	mrr := pgConn.Exec(context.Background(), "select 'Hello, world' as msg; select 1 as num")

	require.True(t, mrr.NextResult())
	require.Len(t, mrr.ResultReader().FieldDescriptions(), 1)
	assert.Equal(t, []byte("msg"), mrr.ResultReader().FieldDescriptions()[0].Name)
	_, err = mrr.ResultReader().Close()
	require.NoError(t, err)

	require.True(t, mrr.NextResult())
	require.Len(t, mrr.ResultReader().FieldDescriptions(), 1)
	assert.Equal(t, []byte("num"), mrr.ResultReader().FieldDescriptions()[0].Name)
	_, err = mrr.ResultReader().Close()
	require.NoError(t, err)

	require.False(t, mrr.NextResult())

	require.NoError(t, mrr.Close())

	ensureConnValid(t, pgConn)
}

// TestConnExecMultipleQueriesError checks that a failing statement in a multi-statement Exec returns a
// divide by zero PgError (SQLSTATE 22012) together with the results received before it.
func TestConnExecMultipleQueriesError(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	results, err := pgConn.Exec(context.Background(), "select 1; select 1/0; select 1").ReadAll()
	require.NotNil(t, err)
	// errors.As (rather than a bare type assertion) also matches a wrapped *pgconn.PgError.
	var pgErr *pgconn.PgError
	if errors.As(err, &pgErr) {
		assert.Equal(t, "22012", pgErr.Code)
	} else {
		t.Errorf("unexpected error: %v", err)
	}

	if pgConn.ParameterStatus("crdb_version") != "" {
		// CockroachDB starts the second query result set and then sends the divide by zero error.
		require.Len(t, results, 2)
		assert.Len(t, results[0].Rows, 1)
		assert.Equal(t, "1", string(results[0].Rows[0][0]))
		assert.Len(t, results[1].Rows, 0)
	} else {
		// PostgreSQL sends the divide by zero and never sends the second query result set.
		require.Len(t, results, 1)
		assert.Len(t, results[0].Rows, 1)
		assert.Equal(t, "1", string(results[0].Rows[0][0]))
	}

	ensureConnValid(t, pgConn)
}

// TestConnExecDeferredError checks that violating a deferrable, initially deferred unique constraint is
// reported from Exec as a PgError with code 23505.
func TestConnExecDeferredError(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	if pgConn.ParameterStatus("crdb_version") != "" {
		t.Skip("Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)")
	}

	setupSQL := `create temporary table t ( id text primary key, n int not null, unique (n) deferrable initially deferred ); insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`

	_, err = pgConn.Exec(context.Background(), setupSQL).ReadAll()
	// require: the rest of the test is meaningless if the fixture could not be created.
	require.NoError(t, err)

	_, err = pgConn.Exec(context.Background(), `update t set n=n+1 where id='b' returning *`).ReadAll()
	require.NotNil(t, err)

	var pgErr *pgconn.PgError
	require.True(t, errors.As(err, &pgErr))
	require.Equal(t, "23505", pgErr.Code)

	ensureConnValid(t, pgConn)
}

// TestConnExecContextCanceled checks that a deadline reached mid-query yields a timeout error matching
// context.DeadlineExceeded, closes the connection, and completes cleanup within five seconds.
func TestConnExecContextCanceled(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()
	multiResult := pgConn.Exec(ctx, "select 'Hello, world', pg_sleep(1)")

	for multiResult.NextResult() {
	}
	err = multiResult.Close()
	assert.True(t, pgconn.Timeout(err))
	assert.ErrorIs(t, err, context.DeadlineExceeded)
	assert.True(t, pgConn.IsClosed())
	select {
	case <-pgConn.CleanupDone():
	case <-time.After(5 * time.Second):
		t.Fatal("Connection cleanup exceeded maximum time")
	}
}

// TestConnExecContextPrecanceled checks that Exec with an already-canceled context fails with a retryable
// error and leaves the connection usable.
func TestConnExecContextPrecanceled(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	ctx, cancel := context.WithCancel(context.Background())
	cancel()
	_, err = pgConn.Exec(ctx, "select 'Hello, world'").ReadAll()
	assert.Error(t, err)
	assert.True(t, errors.Is(err, context.Canceled))
	assert.True(t, pgconn.SafeToRetry(err))

	ensureConnValid(t, pgConn)
}

// TestConnExecParams checks a parameterized query: field descriptions, row values, and command tag.
func TestConnExecParams(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	result := pgConn.ExecParams(context.Background(), "select $1::text as msg", [][]byte{[]byte("Hello, world")}, nil, nil, nil)
	require.Len(t, result.FieldDescriptions(), 1)
	assert.Equal(t, []byte("msg"), result.FieldDescriptions()[0].Name)

	rowCount := 0
	for result.NextRow() {
		rowCount += 1
		assert.Equal(t, "Hello, world", string(result.Values()[0]))
	}
	assert.Equal(t, 1, rowCount)
	commandTag, err := result.Close()
	assert.Equal(t, "SELECT 1", string(commandTag))
	assert.NoError(t, err)

	ensureConnValid(t, pgConn)
}

// TestConnExecParamsDeferredError checks that violating a deferrable, initially deferred unique constraint
// is reported from ExecParams as a PgError with code 23505.
func TestConnExecParamsDeferredError(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	if pgConn.ParameterStatus("crdb_version") != "" {
		t.Skip("Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)")
	}

	setupSQL := `create temporary table t ( id text primary key, n int not null, unique (n) deferrable initially deferred ); insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`

	_, err = pgConn.Exec(context.Background(), setupSQL).ReadAll()
	// require: the rest of the test is meaningless if the fixture could not be created.
	require.NoError(t, err)

	result := pgConn.ExecParams(context.Background(), `update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil).Read()
	require.NotNil(t, result.Err)
	var pgErr *pgconn.PgError
	require.True(t, errors.As(result.Err, &pgErr))
	require.Equal(t, "23505", pgErr.Code)

	ensureConnValid(t, pgConn)
}

// TestConnExecParamsMaxNumberOfParams checks that ExecParams accepts exactly math.MaxUint16 (65535) parameters.
func TestConnExecParamsMaxNumberOfParams(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	paramCount := math.MaxUint16
	params := make([]string, 0, paramCount)
	args := make([][]byte, 0, paramCount)
	for i := 0; i < paramCount; i++ {
		params = append(params, fmt.Sprintf("($%d::text)", i+1))
		args = append(args, []byte(strconv.Itoa(i)))
	}
	sql := "values" + strings.Join(params, ", ")

	result := pgConn.ExecParams(context.Background(), sql, args, nil, nil, nil).Read()
	require.NoError(t, result.Err)
	require.Len(t, result.Rows, paramCount)

	ensureConnValid(t, pgConn)
}

// TestConnExecParamsTooManyParams checks that ExecParams rejects math.MaxUint16+1 parameters with
// "extended protocol limited to 65535 parameters" and leaves the connection usable.
func TestConnExecParamsTooManyParams(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	paramCount := math.MaxUint16 + 1
	params := make([]string, 0, paramCount)
	args := make([][]byte, 0, paramCount)
	for i := 0; i < paramCount; i++ {
		params = append(params, fmt.Sprintf("($%d::text)", i+1))
		args = append(args, []byte(strconv.Itoa(i)))
	}
	sql := "values" + strings.Join(params, ", ")

	result := pgConn.ExecParams(context.Background(), sql, args, nil, nil, nil).Read()
	require.Error(t, result.Err)
	require.Equal(t, "extended protocol limited to 65535 parameters", result.Err.Error())

	ensureConnValid(t, pgConn)
}

// TestConnExecParamsCanceled checks that a deadline reached mid-query yields no rows, a nil command tag,
// a timeout error matching context.DeadlineExceeded, and a closed connection whose cleanup finishes promptly.
func TestConnExecParamsCanceled(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()
	result := pgConn.ExecParams(ctx, "select current_database(), pg_sleep(1)", nil, nil, nil, nil)
	rowCount := 0
	for result.NextRow() {
		rowCount += 1
	}
	assert.Equal(t, 0, rowCount)
	commandTag, err := result.Close()
	assert.Equal(t, pgconn.CommandTag(nil), commandTag)
	assert.True(t, pgconn.Timeout(err))
	assert.ErrorIs(t, err, context.DeadlineExceeded)

	assert.True(t, pgConn.IsClosed())
	select {
	case <-pgConn.CleanupDone():
	case <-time.After(5 * time.Second):
		t.Fatal("Connection cleanup exceeded maximum time")
	}
}

// TestConnExecParamsPrecanceled checks that ExecParams with an already-canceled context fails with a
// retryable error and leaves the connection usable.
func TestConnExecParamsPrecanceled(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	ctx, cancel := context.WithCancel(context.Background())
	cancel()
	result := pgConn.ExecParams(ctx, "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil).Read()
	require.Error(t, result.Err)
	assert.True(t, errors.Is(result.Err, context.Canceled))
	assert.True(t, pgconn.SafeToRetry(result.Err))

	ensureConnValid(t, pgConn)
}

// TestConnExecParamsEmptySQL checks that ExecParams with empty SQL returns a nil command tag, no rows, and no error.
func TestConnExecParamsEmptySQL(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
	defer cancel()

	pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	result := pgConn.ExecParams(ctx, "", nil, nil, nil, nil).Read()
	assert.Nil(t, result.CommandTag)
	assert.Len(t, result.Rows, 0)
	assert.NoError(t, result.Err)

	ensureConnValid(t, pgConn)
}

// TestResultReaderValuesHaveSameCapacityAsLength checks that each value returned by ResultReader.Values
// has a capacity equal to its length.
//
// https://github.com/jackc/pgx/issues/859
func TestResultReaderValuesHaveSameCapacityAsLength(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	result := pgConn.ExecParams(context.Background(), "select $1::text as msg", [][]byte{[]byte("Hello, world")}, nil, nil, nil)
	require.Len(t, result.FieldDescriptions(), 1)
	assert.Equal(t, []byte("msg"), result.FieldDescriptions()[0].Name)

	rowCount := 0
	for result.NextRow() {
		rowCount += 1
		assert.Equal(t, "Hello, world", string(result.Values()[0]))
		assert.Equal(t, len(result.Values()[0]), cap(result.Values()[0]))
	}
	assert.Equal(t, 1, rowCount)
	commandTag, err := result.Close()
	assert.Equal(t, "SELECT 1", string(commandTag))
	assert.NoError(t, err)

	ensureConnValid(t, pgConn)
}

// TestConnExecPrepared checks Prepare's statement description and a subsequent ExecPrepared's fields,
// row values, and command tag.
func TestConnExecPrepared(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	psd, err := pgConn.Prepare(context.Background(), "ps1", "select $1::text as msg", nil)
	require.NoError(t, err)
	require.NotNil(t, psd)
	assert.Len(t, psd.ParamOIDs, 1)
	assert.Len(t, psd.Fields, 1)

	result := pgConn.ExecPrepared(context.Background(), "ps1", [][]byte{[]byte("Hello, world")}, nil, nil)
	require.Len(t, result.FieldDescriptions(), 1)
	assert.Equal(t, []byte("msg"), result.FieldDescriptions()[0].Name)

	rowCount := 0
	for result.NextRow() {
		rowCount += 1
		assert.Equal(t, "Hello, world", string(result.Values()[0]))
	}
	assert.Equal(t, 1, rowCount)
	commandTag, err := result.Close()
	assert.Equal(t, "SELECT 1", string(commandTag))
	assert.NoError(t, err)

	ensureConnValid(t, pgConn)
}

// TestConnExecPreparedMaxNumberOfParams checks that a statement with exactly math.MaxUint16 (65535)
// parameters can be prepared and executed.
func TestConnExecPreparedMaxNumberOfParams(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	paramCount := math.MaxUint16
	params := make([]string, 0, paramCount)
	args := make([][]byte, 0, paramCount)
	for i := 0; i < paramCount; i++ {
		params = append(params, fmt.Sprintf("($%d::text)", i+1))
		args = append(args, []byte(strconv.Itoa(i)))
	}
	sql := "values" + strings.Join(params, ", ")

	psd, err := pgConn.Prepare(context.Background(), "ps1", sql, nil)
	require.NoError(t, err)
	require.NotNil(t, psd)
	assert.Len(t, psd.ParamOIDs, paramCount)
	assert.Len(t, psd.Fields, 1)

	result := pgConn.ExecPrepared(context.Background(), "ps1", args, nil, nil).Read()
	require.NoError(t, result.Err)
	require.Len(t, result.Rows, paramCount)

	ensureConnValid(t, pgConn)
}

// TestConnExecPreparedTooManyParams checks the outcome for math.MaxUint16+1 parameters, which differs
// between CockroachDB and PostgreSQL as described in each branch.
func TestConnExecPreparedTooManyParams(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	paramCount := math.MaxUint16 + 1
	params := make([]string, 0, paramCount)
	args := make([][]byte, 0, paramCount)
	for i := 0; i < paramCount; i++ {
		params = append(params, fmt.Sprintf("($%d::text)", i+1))
		args = append(args, []byte(strconv.Itoa(i)))
	}
	sql := "values" + strings.Join(params, ", ")

	psd, err := pgConn.Prepare(context.Background(), "ps1", sql, nil)
	if pgConn.ParameterStatus("crdb_version") != "" {
		// CockroachDB rejects preparing a statement with more than 65535 parameters.
		require.EqualError(t, err, "ERROR: more than 65535 arguments to prepared statement: 65536 (SQLSTATE 08P01)")
	} else {
		// PostgreSQL accepts preparing a statement with more than 65535 parameters and only fails when executing it through the extended protocol.
		require.NoError(t, err)
		require.NotNil(t, psd)
		assert.Len(t, psd.ParamOIDs, paramCount)
		assert.Len(t, psd.Fields, 1)

		result := pgConn.ExecPrepared(context.Background(), "ps1", args, nil, nil).Read()
		require.EqualError(t, result.Err, "extended protocol limited to 65535 parameters")
	}

	ensureConnValid(t, pgConn)
}

// TestConnExecPreparedCanceled checks that a deadline reached mid-query yields no rows, a nil command tag,
// a timeout error matching context.DeadlineExceeded, and a closed connection whose cleanup finishes promptly.
func TestConnExecPreparedCanceled(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	_, err = pgConn.Prepare(context.Background(), "ps1", "select current_database(), pg_sleep(1)", nil)
	require.NoError(t, err)

	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()
	result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil)
	rowCount := 0
	for result.NextRow() {
		rowCount += 1
	}
	assert.Equal(t, 0, rowCount)
	commandTag, err := result.Close()
	assert.Equal(t, pgconn.CommandTag(nil), commandTag)
	assert.True(t, pgconn.Timeout(err))
	// Same expectation as TestConnExecParamsCanceled: the timeout error must unwrap to the context error.
	assert.ErrorIs(t, err, context.DeadlineExceeded)
	assert.True(t, pgConn.IsClosed())
	select {
	case <-pgConn.CleanupDone():
	case <-time.After(5 * time.Second):
		t.Fatal("Connection cleanup exceeded maximum time")
	}
}

// TestConnExecPreparedPrecanceled checks that ExecPrepared with an already-canceled context fails with a
// retryable error and leaves the connection usable.
func TestConnExecPreparedPrecanceled(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	_, err = pgConn.Prepare(context.Background(), "ps1", "select current_database(), pg_sleep(1)", nil)
	require.NoError(t, err)

	ctx, cancel := context.WithCancel(context.Background())
	cancel()
	result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Read()
	require.Error(t, result.Err)
	assert.True(t, errors.Is(result.Err, context.Canceled))
	assert.True(t, pgconn.SafeToRetry(result.Err))

	ensureConnValid(t, pgConn)
}

// TestConnExecPreparedEmptySQL checks that executing a prepared empty statement returns a nil command tag,
// no rows, and no error.
func TestConnExecPreparedEmptySQL(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
	defer cancel()

	pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	_, err = pgConn.Prepare(ctx, "ps1", "", nil)
	require.NoError(t, err)

	result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Read()
	assert.Nil(t, result.CommandTag)
	assert.Len(t, result.Rows, 0)
	assert.NoError(t, result.Err)

	ensureConnValid(t, pgConn)
}

// TestConnExecBatch checks that a batch mixing ExecParams and ExecPrepared returns one result per queued
// statement, in order.
func TestConnExecBatch(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	_, err = pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil)
	require.NoError(t, err)

	batch := &pgconn.Batch{}

	batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil)
	batch.ExecPrepared("ps1", [][]byte{[]byte("ExecPrepared 1")}, nil, nil)
	batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 2")}, nil, nil, nil)

	results, err := pgConn.ExecBatch(context.Background(), batch).ReadAll()
	require.NoError(t, err)
	require.Len(t, results, 3)

	require.Len(t, results[0].Rows, 1)
	require.Equal(t, "ExecParams 1", string(results[0].Rows[0][0]))
	assert.Equal(t, "SELECT 1", string(results[0].CommandTag))

	require.Len(t, results[1].Rows, 1)
	require.Equal(t, "ExecPrepared 1", string(results[1].Rows[0][0]))
	assert.Equal(t, "SELECT 1", string(results[1].CommandTag))

	require.Len(t, results[2].Rows, 1)
	require.Equal(t, "ExecParams 2", string(results[2].Rows[0][0]))
	assert.Equal(t, "SELECT 1", string(results[2].CommandTag))

	// Like every sibling test, confirm the connection is still usable after the batch.
	ensureConnValid(t, pgConn)
}

// TestConnExecBatchDeferredError checks that violating a deferrable, initially deferred unique constraint
// inside a batch is reported as a PgError with code 23505.
func TestConnExecBatchDeferredError(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	if pgConn.ParameterStatus("crdb_version") != "" {
		t.Skip("Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)")
	}

	setupSQL := `create temporary table t ( id text primary key, n int not null, unique (n) deferrable initially deferred ); insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`

	_, err = pgConn.Exec(context.Background(), setupSQL).ReadAll()
	require.NoError(t, err)

	batch := &pgconn.Batch{}

	batch.ExecParams(`update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil)
	_, err = pgConn.ExecBatch(context.Background(), batch).ReadAll()
	require.NotNil(t, err)
	var pgErr *pgconn.PgError
	require.True(t, errors.As(err, &pgErr))
	require.Equal(t, "23505", pgErr.Code)

	ensureConnValid(t, pgConn)
}

// TestConnExecBatchPrecanceled checks that ExecBatch with an already-canceled context fails with a
// retryable error and leaves the connection usable.
func TestConnExecBatchPrecanceled(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	_, err = pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil)
	require.NoError(t, err)

	batch := &pgconn.Batch{}

	batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil)
	batch.ExecPrepared("ps1", [][]byte{[]byte("ExecPrepared 1")}, nil, nil)
	batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 2")}, nil, nil, nil)

	ctx, cancel := context.WithCancel(context.Background())
	cancel()
	_, err = pgConn.ExecBatch(ctx, batch).ReadAll()
	require.Error(t, err)
	assert.True(t, errors.Is(err, context.Canceled))
	assert.True(t, pgconn.SafeToRetry(err))

	ensureConnValid(t, pgConn)
}

// TestConnExecBatchHuge checks that a very large batch completes and returns every result.
// Without concurrent reading and writing large batches can deadlock.
//
// See https://github.com/jackc/pgx/issues/374.
func TestConnExecBatchHuge(t *testing.T) { if testing.Short() { t.Skip("skipping test in short mode.") } t.Parallel() pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) batch := &pgconn.Batch{} queryCount := 100000 args := make([]string, queryCount) for i := range args { args[i] = strconv.Itoa(i) batch.ExecParams("select $1::text", [][]byte{[]byte(args[i])}, nil, nil, nil) } results, err := pgConn.ExecBatch(context.Background(), batch).ReadAll() require.NoError(t, err) require.Len(t, results, queryCount) for i := range args { require.Len(t, results[i].Rows, 1) require.Equal(t, args[i], string(results[i].Rows[0][0])) assert.Equal(t, "SELECT 1", string(results[i].CommandTag)) } } func TestConnExecBatchImplicitTransaction(t *testing.T) { t.Parallel() pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) if pgConn.ParameterStatus("crdb_version") != "" { t.Skip("Skipping due to known server issue: (https://github.com/cockroachdb/cockroach/issues/44803)") } _, err = pgConn.Exec(context.Background(), "create temporary table t(id int)").ReadAll() require.NoError(t, err) batch := &pgconn.Batch{} batch.ExecParams("insert into t(id) values(1)", nil, nil, nil, nil) batch.ExecParams("insert into t(id) values(2)", nil, nil, nil, nil) batch.ExecParams("insert into t(id) values(3)", nil, nil, nil, nil) batch.ExecParams("select 1/0", nil, nil, nil, nil) _, err = pgConn.ExecBatch(context.Background(), batch).ReadAll() require.Error(t, err) result := pgConn.ExecParams(context.Background(), "select count(*) from t", nil, nil, nil, nil).Read() require.Equal(t, "0", string(result.Rows[0][0])) } func TestConnLocking(t *testing.T) { t.Parallel() pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) mrr := pgConn.Exec(context.Background(), 
"select 'Hello, world'") _, err = pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() assert.Error(t, err) assert.Equal(t, "conn busy", err.Error()) assert.True(t, pgconn.SafeToRetry(err)) results, err := mrr.ReadAll() assert.NoError(t, err) assert.Len(t, results, 1) assert.Nil(t, results[0].Err) assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) assert.Len(t, results[0].Rows, 1) assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) ensureConnValid(t, pgConn) } func TestCommandTag(t *testing.T) { t.Parallel() var tests = []struct { commandTag pgconn.CommandTag rowsAffected int64 isInsert bool isUpdate bool isDelete bool isSelect bool }{ {commandTag: pgconn.CommandTag("INSERT 0 5"), rowsAffected: 5, isInsert: true}, {commandTag: pgconn.CommandTag("UPDATE 0"), rowsAffected: 0, isUpdate: true}, {commandTag: pgconn.CommandTag("UPDATE 1"), rowsAffected: 1, isUpdate: true}, {commandTag: pgconn.CommandTag("DELETE 0"), rowsAffected: 0, isDelete: true}, {commandTag: pgconn.CommandTag("DELETE 1"), rowsAffected: 1, isDelete: true}, {commandTag: pgconn.CommandTag("DELETE 1234567890"), rowsAffected: 1234567890, isDelete: true}, {commandTag: pgconn.CommandTag("SELECT 1"), rowsAffected: 1, isSelect: true}, {commandTag: pgconn.CommandTag("SELECT 99999999999"), rowsAffected: 99999999999, isSelect: true}, {commandTag: pgconn.CommandTag("CREATE TABLE"), rowsAffected: 0}, {commandTag: pgconn.CommandTag("ALTER TABLE"), rowsAffected: 0}, {commandTag: pgconn.CommandTag("DROP TABLE"), rowsAffected: 0}, } for i, tt := range tests { ct := tt.commandTag assert.Equalf(t, tt.rowsAffected, ct.RowsAffected(), "%d. %v", i, tt.commandTag) assert.Equalf(t, tt.isInsert, ct.Insert(), "%d. %v", i, tt.commandTag) assert.Equalf(t, tt.isUpdate, ct.Update(), "%d. %v", i, tt.commandTag) assert.Equalf(t, tt.isDelete, ct.Delete(), "%d. %v", i, tt.commandTag) assert.Equalf(t, tt.isSelect, ct.Select(), "%d. 
%v", i, tt.commandTag) } } func TestConnOnNotice(t *testing.T) { t.Parallel() config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) var msg string config.OnNotice = func(c *pgconn.PgConn, notice *pgconn.Notice) { msg = notice.Message } config.RuntimeParams["client_min_messages"] = "notice" // Ensure we only get the message we expect. pgConn, err := pgconn.ConnectConfig(context.Background(), config) require.NoError(t, err) defer closeConn(t, pgConn) if pgConn.ParameterStatus("crdb_version") != "" { t.Skip("Server does not support PL/PGSQL (https://github.com/cockroachdb/cockroach/issues/17511)") } multiResult := pgConn.Exec(context.Background(), `do $$ begin raise notice 'hello, world'; end$$;`) err = multiResult.Close() require.NoError(t, err) assert.Equal(t, "hello, world", msg) ensureConnValid(t, pgConn) } func TestConnOnNotification(t *testing.T) { t.Parallel() config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) var msg string config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) { msg = n.Payload } pgConn, err := pgconn.ConnectConfig(context.Background(), config) require.NoError(t, err) defer closeConn(t, pgConn) if pgConn.ParameterStatus("crdb_version") != "" { t.Skip("Server does not support LISTEN / NOTIFY (https://github.com/cockroachdb/cockroach/issues/41522)") } _, err = pgConn.Exec(context.Background(), "listen foo").ReadAll() require.NoError(t, err) notifier, err := pgconn.ConnectConfig(context.Background(), config) require.NoError(t, err) defer closeConn(t, notifier) _, err = notifier.Exec(context.Background(), "notify foo, 'bar'").ReadAll() require.NoError(t, err) _, err = pgConn.Exec(context.Background(), "select 1").ReadAll() require.NoError(t, err) assert.Equal(t, "bar", msg) ensureConnValid(t, pgConn) } func TestConnWaitForNotification(t *testing.T) { t.Parallel() config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, 
err) var msg string config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) { msg = n.Payload } pgConn, err := pgconn.ConnectConfig(context.Background(), config) require.NoError(t, err) defer closeConn(t, pgConn) if pgConn.ParameterStatus("crdb_version") != "" { t.Skip("Server does not support LISTEN / NOTIFY (https://github.com/cockroachdb/cockroach/issues/41522)") } _, err = pgConn.Exec(context.Background(), "listen foo").ReadAll() require.NoError(t, err) notifier, err := pgconn.ConnectConfig(context.Background(), config) require.NoError(t, err) defer closeConn(t, notifier) _, err = notifier.Exec(context.Background(), "notify foo, 'bar'").ReadAll() require.NoError(t, err) err = pgConn.WaitForNotification(context.Background()) require.NoError(t, err) assert.Equal(t, "bar", msg) ensureConnValid(t, pgConn) } func TestConnWaitForNotificationPrecanceled(t *testing.T) { t.Parallel() config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) pgConn, err := pgconn.ConnectConfig(context.Background(), config) require.NoError(t, err) defer closeConn(t, pgConn) ctx, cancel := context.WithCancel(context.Background()) cancel() err = pgConn.WaitForNotification(ctx) require.ErrorIs(t, err, context.Canceled) ensureConnValid(t, pgConn) } func TestConnWaitForNotificationTimeout(t *testing.T) { t.Parallel() config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) pgConn, err := pgconn.ConnectConfig(context.Background(), config) require.NoError(t, err) defer closeConn(t, pgConn) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) err = pgConn.WaitForNotification(ctx) cancel() assert.True(t, pgconn.Timeout(err)) assert.ErrorIs(t, err, context.DeadlineExceeded) ensureConnValid(t, pgConn) } func TestConnCopyToSmall(t *testing.T) { t.Parallel() pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) 
if pgConn.ParameterStatus("crdb_version") != "" { t.Skip("Server does support COPY TO") } _, err = pgConn.Exec(context.Background(), `create temporary table foo( a int2, b int4, c int8, d varchar, e text, f date, g json )`).ReadAll() require.NoError(t, err) _, err = pgConn.Exec(context.Background(), `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}')`).ReadAll() require.NoError(t, err) _, err = pgConn.Exec(context.Background(), `insert into foo values (null, null, null, null, null, null, null)`).ReadAll() require.NoError(t, err) inputBytes := []byte("0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\n" + "\\N\t\\N\t\\N\t\\N\t\\N\t\\N\t\\N\n") outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes))) res, err := pgConn.CopyTo(context.Background(), outputWriter, "copy foo to stdout") require.NoError(t, err) assert.Equal(t, int64(2), res.RowsAffected()) assert.Equal(t, inputBytes, outputWriter.Bytes()) ensureConnValid(t, pgConn) } func TestConnCopyToLarge(t *testing.T) { t.Parallel() pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) if pgConn.ParameterStatus("crdb_version") != "" { t.Skip("Server does support COPY TO") } _, err = pgConn.Exec(context.Background(), `create temporary table foo( a int2, b int4, c int8, d varchar, e text, f date, g json, h bytea )`).ReadAll() require.NoError(t, err) inputBytes := make([]byte, 0) for i := 0; i < 1000; i++ { _, err = pgConn.Exec(context.Background(), `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}', 'oooo')`).ReadAll() require.NoError(t, err) inputBytes = append(inputBytes, "0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\t\\\\x6f6f6f6f\n"...) 
} outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes))) res, err := pgConn.CopyTo(context.Background(), outputWriter, "copy foo to stdout") require.NoError(t, err) assert.Equal(t, int64(1000), res.RowsAffected()) assert.Equal(t, inputBytes, outputWriter.Bytes()) ensureConnValid(t, pgConn) } func TestConnCopyToQueryError(t *testing.T) { t.Parallel() pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) outputWriter := bytes.NewBuffer(make([]byte, 0)) res, err := pgConn.CopyTo(context.Background(), outputWriter, "cropy foo to stdout") require.Error(t, err) assert.IsType(t, &pgconn.PgError{}, err) assert.Equal(t, int64(0), res.RowsAffected()) ensureConnValid(t, pgConn) } func TestConnCopyToCanceled(t *testing.T) { t.Parallel() pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) if pgConn.ParameterStatus("crdb_version") != "" { t.Skip("Server does not support query cancellation (https://github.com/cockroachdb/cockroach/issues/41335)") } outputWriter := &bytes.Buffer{} ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select *, pg_sleep(0.01) from generate_series(1,1000)) to stdout") assert.Error(t, err) assert.Equal(t, pgconn.CommandTag(nil), res) assert.True(t, pgConn.IsClosed()) select { case <-pgConn.CleanupDone(): case <-time.After(5 * time.Second): t.Fatal("Connection cleanup exceeded maximum time") } } func TestConnCopyToPrecanceled(t *testing.T) { t.Parallel() pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) outputWriter := &bytes.Buffer{} ctx, cancel := context.WithCancel(context.Background()) cancel() res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select * from generate_series(1,1000)) to 
stdout") require.Error(t, err) assert.True(t, errors.Is(err, context.Canceled)) assert.True(t, pgconn.SafeToRetry(err)) assert.Equal(t, pgconn.CommandTag(nil), res) ensureConnValid(t, pgConn) } func TestConnCopyFrom(t *testing.T) { t.Parallel() pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) if pgConn.ParameterStatus("crdb_version") != "" { t.Skip("Server does not fully support COPY FROM (https://www.cockroachlabs.com/docs/v20.2/copy-from.html)") } _, err = pgConn.Exec(context.Background(), `create temporary table foo( a int4, b varchar )`).ReadAll() require.NoError(t, err) srcBuf := &bytes.Buffer{} inputRows := [][][]byte{} for i := 0; i < 1000; i++ { a := strconv.Itoa(i) b := "foo " + a + " bar" inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)}) _, err = srcBuf.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) require.NoError(t, err) } ct, err := pgConn.CopyFrom(context.Background(), srcBuf, "COPY foo FROM STDIN WITH (FORMAT csv)") require.NoError(t, err) assert.Equal(t, int64(len(inputRows)), ct.RowsAffected()) result := pgConn.ExecParams(context.Background(), "select * from foo", nil, nil, nil, nil).Read() require.NoError(t, result.Err) assert.Equal(t, inputRows, result.Rows) ensureConnValid(t, pgConn) } func TestConnCopyFromCanceled(t *testing.T) { t.Parallel() pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) if pgConn.ParameterStatus("crdb_version") != "" { t.Skip("Server does not support query cancellation (https://github.com/cockroachdb/cockroach/issues/41335)") } _, err = pgConn.Exec(context.Background(), `create temporary table foo( a int4, b varchar )`).ReadAll() require.NoError(t, err) r, w := io.Pipe() go func() { for i := 0; i < 1000000; i++ { a := strconv.Itoa(i) b := "foo " + a + " bar" _, err := w.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) if err != nil { 
return } time.Sleep(time.Microsecond) } }() ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)") cancel() assert.Equal(t, int64(0), ct.RowsAffected()) assert.Error(t, err) assert.True(t, pgConn.IsClosed()) select { case <-pgConn.CleanupDone(): case <-time.After(5 * time.Second): t.Fatal("Connection cleanup exceeded maximum time") } } func TestConnCopyFromPrecanceled(t *testing.T) { t.Parallel() pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) _, err = pgConn.Exec(context.Background(), `create temporary table foo( a int4, b varchar )`).ReadAll() require.NoError(t, err) r, w := io.Pipe() go func() { for i := 0; i < 1000000; i++ { a := strconv.Itoa(i) b := "foo " + a + " bar" _, err := w.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) if err != nil { return } time.Sleep(time.Microsecond) } }() ctx, cancel := context.WithCancel(context.Background()) cancel() ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)") require.Error(t, err) assert.True(t, errors.Is(err, context.Canceled)) assert.True(t, pgconn.SafeToRetry(err)) assert.Equal(t, pgconn.CommandTag(nil), ct) ensureConnValid(t, pgConn) } func TestConnCopyFromGzipReader(t *testing.T) { t.Parallel() pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) if pgConn.ParameterStatus("crdb_version") != "" { t.Skip("Server does not fully support COPY FROM (https://www.cockroachlabs.com/docs/v20.2/copy-from.html)") } _, err = pgConn.Exec(context.Background(), `create temporary table foo( a int4, b varchar )`).ReadAll() require.NoError(t, err) f, err := ioutil.TempFile("", "*") require.NoError(t, err) gw := gzip.NewWriter(f) inputRows := [][][]byte{} for i := 0; i < 1000; i++ { a := strconv.Itoa(i) b := "foo " + a + " bar" inputRows = 
append(inputRows, [][]byte{[]byte(a), []byte(b)}) _, err = gw.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) require.NoError(t, err) } err = gw.Close() require.NoError(t, err) _, err = f.Seek(0, 0) require.NoError(t, err) gr, err := gzip.NewReader(f) require.NoError(t, err) ct, err := pgConn.CopyFrom(context.Background(), gr, "COPY foo FROM STDIN WITH (FORMAT csv)") require.NoError(t, err) assert.Equal(t, int64(len(inputRows)), ct.RowsAffected()) err = gr.Close() require.NoError(t, err) err = f.Close() require.NoError(t, err) err = os.Remove(f.Name()) require.NoError(t, err) result := pgConn.ExecParams(context.Background(), "select * from foo", nil, nil, nil, nil).Read() require.NoError(t, result.Err) assert.Equal(t, inputRows, result.Rows) ensureConnValid(t, pgConn) } func TestConnCopyFromQuerySyntaxError(t *testing.T) { t.Parallel() pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) _, err = pgConn.Exec(context.Background(), `create temporary table foo( a int4, b varchar )`).ReadAll() require.NoError(t, err) srcBuf := &bytes.Buffer{} res, err := pgConn.CopyFrom(context.Background(), srcBuf, "cropy foo to stdout") require.Error(t, err) assert.IsType(t, &pgconn.PgError{}, err) assert.Equal(t, int64(0), res.RowsAffected()) ensureConnValid(t, pgConn) } func TestConnCopyFromQueryNoTableError(t *testing.T) { t.Parallel() pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) srcBuf := &bytes.Buffer{} res, err := pgConn.CopyFrom(context.Background(), srcBuf, "copy foo to stdout") require.Error(t, err) assert.IsType(t, &pgconn.PgError{}, err) assert.Equal(t, int64(0), res.RowsAffected()) ensureConnValid(t, pgConn) } // https://github.com/jackc/pgconn/issues/21 func TestConnCopyFromNoticeResponseReceivedMidStream(t *testing.T) { t.Parallel() ctx := context.Background() pgConn, err := pgconn.Connect(ctx, 
os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) if pgConn.ParameterStatus("crdb_version") != "" { t.Skip("Server does not support triggers (https://github.com/cockroachdb/cockroach/issues/28296)") } _, err = pgConn.Exec(ctx, `create temporary table sentences( t text, ts tsvector )`).ReadAll() require.NoError(t, err) _, err = pgConn.Exec(ctx, `create function pg_temp.sentences_trigger() returns trigger as $$ begin new.ts := to_tsvector(new.t); return new; end $$ language plpgsql;`).ReadAll() require.NoError(t, err) _, err = pgConn.Exec(ctx, `create trigger sentences_update before insert on sentences for each row execute procedure pg_temp.sentences_trigger();`).ReadAll() require.NoError(t, err) longString := make([]byte, 10001) for i := range longString { longString[i] = 'x' } buf := &bytes.Buffer{} for i := 0; i < 1000; i++ { buf.Write([]byte(fmt.Sprintf("%s\n", string(longString)))) } _, err = pgConn.CopyFrom(ctx, buf, "COPY sentences(t) FROM STDIN WITH (FORMAT csv)") require.NoError(t, err) } func TestConnEscapeString(t *testing.T) { t.Parallel() pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) tests := []struct { in string out string }{ {in: "", out: ""}, {in: "42", out: "42"}, {in: "'", out: "''"}, {in: "hi'there", out: "hi''there"}, {in: "'hi there'", out: "''hi there''"}, } for i, tt := range tests { value, err := pgConn.EscapeString(tt.in) if assert.NoErrorf(t, err, "%d.", i) { assert.Equalf(t, tt.out, value, "%d.", i) } } ensureConnValid(t, pgConn) } func TestConnCancelRequest(t *testing.T) { t.Parallel() pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) if pgConn.ParameterStatus("crdb_version") != "" { t.Skip("Server does not support query cancellation (https://github.com/cockroachdb/cockroach/issues/41335)") } multiResult := 
pgConn.Exec(context.Background(), "select 'Hello, world', pg_sleep(2)") // This test flickers without the Sleep. It appears that since Exec only sends the query and returns without awaiting a // response that the CancelRequest can race it and be received before the query is running and cancellable. So wait a // few milliseconds. time.Sleep(50 * time.Millisecond) err = pgConn.CancelRequest(context.Background()) require.NoError(t, err) for multiResult.NextResult() { } err = multiResult.Close() require.IsType(t, &pgconn.PgError{}, err) require.Equal(t, "57014", err.(*pgconn.PgError).Code) ensureConnValid(t, pgConn) } // https://github.com/jackc/pgx/issues/659 func TestConnContextCanceledCancelsRunningQueryOnServer(t *testing.T) { t.Parallel() pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, pgConn) pid := pgConn.PID() ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() multiResult := pgConn.Exec(ctx, "select 'Hello, world', pg_sleep(30)") for multiResult.NextResult() { } err = multiResult.Close() assert.True(t, pgconn.Timeout(err)) assert.True(t, pgConn.IsClosed()) select { case <-pgConn.CleanupDone(): case <-time.After(5 * time.Second): t.Fatal("Connection cleanup exceeded maximum time") } otherConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer closeConn(t, otherConn) ctx, cancel = context.WithTimeout(context.Background(), time.Second*5) defer cancel() for { result := otherConn.ExecParams(ctx, `select 1 from pg_stat_activity where pid=$1`, [][]byte{[]byte(strconv.FormatInt(int64(pid), 10))}, nil, nil, nil, ).Read() require.NoError(t, result.Err) if len(result.Rows) == 0 { break } } } func TestConnSendBytesAndReceiveMessage(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() config, err := 
pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) config.RuntimeParams["client_min_messages"] = "notice" // Ensure we only get the messages we expect. pgConn, err := pgconn.ConnectConfig(context.Background(), config) require.NoError(t, err) defer closeConn(t, pgConn) queryMsg := pgproto3.Query{String: "select 42"} buf := queryMsg.Encode(nil) err = pgConn.SendBytes(ctx, buf) require.NoError(t, err) msg, err := pgConn.ReceiveMessage(ctx) require.NoError(t, err) _, ok := msg.(*pgproto3.RowDescription) require.True(t, ok) msg, err = pgConn.ReceiveMessage(ctx) require.NoError(t, err) _, ok = msg.(*pgproto3.DataRow) require.True(t, ok) msg, err = pgConn.ReceiveMessage(ctx) require.NoError(t, err) _, ok = msg.(*pgproto3.CommandComplete) require.True(t, ok) msg, err = pgConn.ReceiveMessage(ctx) require.NoError(t, err) _, ok = msg.(*pgproto3.ReadyForQuery) require.True(t, ok) ensureConnValid(t, pgConn) } func TestHijackAndConstruct(t *testing.T) { t.Parallel() origConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) hc, err := origConn.Hijack() require.NoError(t, err) _, err = origConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() require.Error(t, err) newConn, err := pgconn.Construct(hc) require.NoError(t, err) defer closeConn(t, newConn) results, err := newConn.Exec(context.Background(), "select 'Hello, world'").ReadAll() assert.NoError(t, err) assert.Len(t, results, 1) assert.Nil(t, results[0].Err) assert.Equal(t, "SELECT 1", string(results[0].CommandTag)) assert.Len(t, results[0].Rows, 1) assert.Equal(t, "Hello, world", string(results[0].Rows[0][0])) ensureConnValid(t, newConn) } func TestConnCloseWhileCancellableQueryInProgress(t *testing.T) { t.Parallel() pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) ctx, _ := context.WithCancel(context.Background()) pgConn.Exec(ctx, "select n from generate_series(1,10) 
n") closeCtx, _ := context.WithCancel(context.Background()) pgConn.Close(closeCtx) select { case <-pgConn.CleanupDone(): case <-time.After(5 * time.Second): t.Fatal("Connection cleanup exceeded maximum time") } } // https://github.com/jackc/pgx/issues/800 func TestFatalErrorReceivedAfterCommandComplete(t *testing.T) { t.Parallel() steps := pgmock.AcceptUnauthenticatedConnRequestSteps() steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Parse{})) steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Bind{})) steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Describe{})) steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Execute{})) steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Sync{})) steps = append(steps, pgmock.SendMessage(&pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{ {Name: []byte("mock")}, }})) steps = append(steps, pgmock.SendMessage(&pgproto3.CommandComplete{CommandTag: []byte("SELECT 0")})) steps = append(steps, pgmock.SendMessage(&pgproto3.ErrorResponse{Severity: "FATAL", Code: "57P01"})) script := &pgmock.Script{Steps: steps} ln, err := net.Listen("tcp", "127.0.0.1:") require.NoError(t, err) defer ln.Close() serverErrChan := make(chan error, 1) go func() { defer close(serverErrChan) conn, err := ln.Accept() if err != nil { serverErrChan <- err return } defer conn.Close() err = conn.SetDeadline(time.Now().Add(5 * time.Second)) if err != nil { serverErrChan <- err return } err = script.Run(pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn)) if err != nil { serverErrChan <- err return } }() parts := strings.Split(ln.Addr().String(), ":") host := parts[0] port := parts[1] connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() conn, err := pgconn.Connect(ctx, connStr) require.NoError(t, err) rr := conn.ExecParams(ctx, "mocked...", nil, nil, nil, nil) for rr.NextRow() { } _, err = rr.Close() require.Error(t, err) 
} func Example() { pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING")) if err != nil { log.Fatalln(err) } defer pgConn.Close(context.Background()) result := pgConn.ExecParams(context.Background(), "select generate_series(1,3)", nil, nil, nil, nil).Read() if result.Err != nil { log.Fatalln(result.Err) } for _, row := range result.Rows { fmt.Println(string(row[0])) } fmt.Println(result.CommandTag) // Output: // 1 // 2 // 3 // SELECT 3 } func GetSSLPassword(ctx context.Context) string { connString := os.Getenv("PGX_SSL_PASSWORD") return connString } var rsaCertPEM = `-----BEGIN CERTIFICATE----- MIIDCTCCAfGgAwIBAgIUQDlN1g1bzxIJ8KWkayNcQY5gzMEwDQYJKoZIhvcNAQEL BQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTIyMDgxNTIxNDgyNloXDTIzMDgx NTIxNDgyNlowFDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEF AAOCAQ8AMIIBCgKCAQEA0vOppiT8zE+076acRORzD5JVbRYKMK3XlWLVrHua4+ct Rm54WyP+3XsYU4JGGGKgb8E+u2UosGJYcSM+b+U1/5XPTcpuumS+pCiD9WP++A39 tsukYwR7m65cgpiI4dlLEZI3EWpAW+Bb3230KiYW4sAmQ0Ih4PrN+oPvzcs86F4d 9Y03CqVUxRKLBLaClZQAg8qz2Pawwj1FKKjDX7u2fRVR0wgOugpCMOBJMcCgz9pp 0HSa4x3KZDHEZY7Pah5XwWrCfAEfRWsSTGcNaoN8gSxGFM1JOEJa8SAuPGjFcYIv MmVWdw0FXCgYlSDL02fzLE0uyvXBDibzSqOk770JhQIDAQABo1MwUTAdBgNVHQ4E FgQUiJ8JLENJ+2k1Xl4o6y2Lc/qHTh0wHwYDVR0jBBgwFoAUiJ8JLENJ+2k1Xl4o 6y2Lc/qHTh0wDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAwjn2 gnNAhFvh58VqLIjU6ftvn6rhz5B9dg2+XyY8sskLhhkO1nL9339BVZsRt+eI3a7I 81GNIm9qHVM3MUAcQv3SZy+0UPVUT8DNH2LwHT3CHnYTBP8U+8n8TDNGSTMUhIBB Rx+6KwODpwLdI79VGT3IkbU9bZwuepB9I9nM5t/tt5kS4gHmJFlO0aLJFCTO4Scf hp/WLPv4XQUH+I3cPfaJRxz2j0Kc8iOzMhFmvl1XOGByjX6X33LnOzY/LVeTSGyS VgC32BGtnMwuy5XZYgFAeUx9HKy4tG4OH2Ux6uPF/WAhsug6PXSjV7BK6wYT5i27 MlascjupnaptKX/wMA== -----END CERTIFICATE----- ` var rsaKeyPEM = testingKey(`-----BEGIN TESTING KEY----- MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQDS86mmJPzMT7Tv ppxE5HMPklVtFgowrdeVYtWse5rj5y1GbnhbI/7dexhTgkYYYqBvwT67ZSiwYlhx Iz5v5TX/lc9Nym66ZL6kKIP1Y/74Df22y6RjBHubrlyCmIjh2UsRkjcRakBb4Fvf 
bfQqJhbiwCZDQiHg+s36g+/NyzzoXh31jTcKpVTFEosEtoKVlACDyrPY9rDCPUUo qMNfu7Z9FVHTCA66CkIw4EkxwKDP2mnQdJrjHcpkMcRljs9qHlfBasJ8AR9FaxJM Zw1qg3yBLEYUzUk4QlrxIC48aMVxgi8yZVZ3DQVcKBiVIMvTZ/MsTS7K9cEOJvNK o6TvvQmFAgMBAAECggEAKzTK54Ol33bn2TnnwdiElIjlRE2CUswYXrl6iDRc2hbs WAOiVRB/T/+5UMla7/2rXJhY7+rdNZs/ABU24ZYxxCJ77jPrD/Q4c8j0lhsgCtBa ycjV543wf0dsHTd+ubtWu8eVzdRUUD0YtB+CJevdPh4a+CWgaMMV0xyYzi61T+Yv Z7Uc3awIAiT4Kw9JRmJiTnyMJg5vZqW3BBAX4ZIvS/54ipwEU+9sWLcuH7WmCR0B QCTqS6hfJDLm//dGC89Iyno57zfYuiT3PYCWH5crr/DH3LqnwlNaOGSBkhkXuIL+ QvOaUMe2i0pjqxDrkBx05V554vyy9jEvK7i330HL4QKBgQDUJmouEr0+o7EMBApC CPPu58K04qY5t9aGciG/pOurN42PF99yNZ1CnynH6DbcnzSl8rjc6Y65tzTlWods bjwVfcmcokG7sPcivJvVjrjKpSQhL8xdZwSAjcqjN4yoJ/+ghm9w+SRmZr6oCQZ3 1jREfJKT+PGiWTEjYcExPWUD2QKBgQD+jdgq4c3tFavU8Hjnlf75xbStr5qu+fp2 SGLRRbX+msQwVbl2ZM9AJLoX9MTCl7D9zaI3ONhheMmfJ77lDTa3VMFtr3NevGA6 MxbiCEfRtQpNkJnsqCixLckx3bskj5+IF9BWzw7y7nOzdhoWVFv/+TltTm3RB51G McdlmmVjjQKBgQDSFAw2/YV6vtu2O1XxGC591/Bd8MaMBziev+wde3GHhaZfGVPC I8dLTpMwCwowpFKdNeLLl1gnHX161I+f1vUWjw4TVjVjaBUBx+VEr2Tb/nXtiwiD QV0a883CnGJjreAblKRMKdpasMmBWhaWmn39h6Iad3zHuCzJjaaiXNpn2QKBgQCf k1Q8LanmQnuh1c41f7aD5gjKCRezMUpt9BrejhD1NxheJJ9LNQ8nat6uPedLBcUS lmJms+AR2qKqf0QQWyQ98YgAtshgTz8TvQtPT1mWgSOgVFHqJdC8obNK63FyDgc4 TZVxlgQNDqbBjfv0m5XA9f+mIlB9hYR2iKYzb4K30QKBgQC+LEJYZh00zsXttGHr 5wU1RzbgDIEsNuu+nZ4MxsaCik8ILNRHNXdeQbnADKuo6ATfhdmDIQMVZLG8Mivi UwnwLd1GhizvqvLHa3ULnFphRyMGFxaLGV48axTT2ADoMX67ILrIY/yjycLqRZ3T z3w+CgS20UrbLIR1YXfqUXge1g== -----END TESTING KEY----- `) func testingKey(s string) string { return strings.ReplaceAll(s, "TESTING KEY", "PRIVATE KEY") } func TestSNISupport(t *testing.T) { t.Parallel() tests := []struct { name string sni_param string sni_set bool }{ { name: "SNI is passed by default", sni_param: "", sni_set: true, }, { name: "SNI is passed when asked for", sni_param: "sslsni=1", sni_set: true, }, { name: "SNI is not passed when disabled", sni_param: "sslsni=0", sni_set: false, }, } for _, tt := range tests { tt := tt t.Run(tt.name, func(t 
*testing.T) { t.Parallel() ln, err := net.Listen("tcp", "127.0.0.1:") require.NoError(t, err) defer ln.Close() serverErrChan := make(chan error, 1) serverSNINameChan := make(chan string, 1) defer close(serverErrChan) defer close(serverSNINameChan) go func() { var sniHost string conn, err := ln.Accept() if err != nil { serverErrChan <- err return } defer conn.Close() err = conn.SetDeadline(time.Now().Add(5 * time.Second)) if err != nil { serverErrChan <- err return } backend := pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn) startupMessage, err := backend.ReceiveStartupMessage() if err != nil { serverErrChan <- err return } switch startupMessage.(type) { case *pgproto3.SSLRequest: _, err = conn.Write([]byte("S")) if err != nil { serverErrChan <- err return } default: serverErrChan <- fmt.Errorf("unexpected startup message: %#v", startupMessage) return } cert, err := tls.X509KeyPair([]byte(rsaCertPEM), []byte(rsaKeyPEM)) if err != nil { serverErrChan <- err return } srv := tls.Server(conn, &tls.Config{ Certificates: []tls.Certificate{cert}, GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) { sniHost = argHello.ServerName return nil, nil }, }) defer srv.Close() if err := srv.Handshake(); err != nil { serverErrChan <- fmt.Errorf("handshake: %v", err) return } srv.Write((&pgproto3.AuthenticationOk{}).Encode(nil)) srv.Write((&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}).Encode(nil)) srv.Write((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(nil)) serverSNINameChan <- sniHost }() port := strings.Split(ln.Addr().String(), ":")[1] connStr := fmt.Sprintf("sslmode=require host=localhost port=%s %s", port, tt.sni_param) _, err = pgconn.Connect(context.Background(), connStr) select { case sniHost := <-serverSNINameChan: if tt.sni_set { require.Equal(t, sniHost, "localhost") } else { require.Equal(t, sniHost, "") } case err = <-serverErrChan: t.Fatalf("server failed with error: %+v", err) case <-time.After(time.Millisecond * 100): 
t.Fatal("exceeded connection timeout without erroring out") } }) } } type delayedReader struct { r io.Reader } func (d delayedReader) Read(p []byte) (int, error) { // W/o sleep test passes, with sleep it fails. time.Sleep(time.Millisecond) return d.r.Read(p) } func TestCopyFrom(t *testing.T) { connString := os.Getenv("PGX_TEST_CONN_STRING") if connString == "" { t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_CONN_STRING") } config, err := pgconn.ParseConfig(connString) require.NoError(t, err) pgConn, err := pgconn.ConnectConfig(context.Background(), config) require.NoError(t, err) if pgConn.ParameterStatus("crdb_version") != "" { t.Skip("Server does support COPY FROM") } setupSQL := `create temporary table t ( id text primary key, n int not null );` _, err = pgConn.Exec(context.Background(), setupSQL).ReadAll() assert.NoError(t, err) r1 := delayedReader{r: strings.NewReader(`id 0\n`)} // Generate an error with a bogus COPY command _, err = pgConn.CopyFrom(context.Background(), r1, "COPY nosuchtable FROM STDIN ") assert.Error(t, err) r2 := delayedReader{r: strings.NewReader(`id 0\n`)} _, err = pgConn.CopyFrom(context.Background(), r2, "COPY t FROM STDIN") assert.NoError(t, err) } pgconn-1.14.0/stmtcache/000077500000000000000000000000001437172345200150445ustar00rootroot00000000000000pgconn-1.14.0/stmtcache/lru.go000066400000000000000000000100631437172345200161750ustar00rootroot00000000000000package stmtcache import ( "container/list" "context" "fmt" "sync/atomic" "github.com/jackc/pgconn" ) var lruCount uint64 // LRU implements Cache with a Least Recently Used (LRU) cache. type LRU struct { conn *pgconn.PgConn mode int cap int prepareCount int m map[string]*list.Element l *list.List psNamePrefix string stmtsToClear []string } // NewLRU creates a new LRU. mode is either ModePrepare or ModeDescribe. cap is the maximum size of the cache. 
func NewLRU(conn *pgconn.PgConn, mode int, cap int) *LRU { mustBeValidMode(mode) mustBeValidCap(cap) n := atomic.AddUint64(&lruCount, 1) return &LRU{ conn: conn, mode: mode, cap: cap, m: make(map[string]*list.Element), l: list.New(), psNamePrefix: fmt.Sprintf("lrupsc_%d", n), } } // Get returns the prepared statement description for sql preparing or describing the sql on the server as needed. func (c *LRU) Get(ctx context.Context, sql string) (*pgconn.StatementDescription, error) { if ctx != context.Background() { select { case <-ctx.Done(): return nil, ctx.Err() default: } } // flush an outstanding bad statements txStatus := c.conn.TxStatus() if (txStatus == 'I' || txStatus == 'T') && len(c.stmtsToClear) > 0 { for _, stmt := range c.stmtsToClear { err := c.clearStmt(ctx, stmt) if err != nil { return nil, err } } } if el, ok := c.m[sql]; ok { c.l.MoveToFront(el) return el.Value.(*pgconn.StatementDescription), nil } if c.l.Len() == c.cap { err := c.removeOldest(ctx) if err != nil { return nil, err } } psd, err := c.prepare(ctx, sql) if err != nil { return nil, err } el := c.l.PushFront(psd) c.m[sql] = el return psd, nil } // Clear removes all entries in the cache. Any prepared statements will be deallocated from the PostgreSQL session. func (c *LRU) Clear(ctx context.Context) error { for c.l.Len() > 0 { err := c.removeOldest(ctx) if err != nil { return err } } return nil } func (c *LRU) StatementErrored(sql string, err error) { pgErr, ok := err.(*pgconn.PgError) if !ok { return } // https://github.com/jackc/pgx/issues/1162 // // We used to look for the message "cached plan must not change result type". However, that message can be localized. // Unfortunately, error code "0A000" - "FEATURE NOT SUPPORTED" is used for many different errors and the only way to // tell the difference is by the message. But all that happens is we clear a statement that we otherwise wouldn't // have so it should be safe. 
possibleInvalidCachedPlanError := pgErr.Code == "0A000" if possibleInvalidCachedPlanError { c.stmtsToClear = append(c.stmtsToClear, sql) } } func (c *LRU) clearStmt(ctx context.Context, sql string) error { elem, inMap := c.m[sql] if !inMap { // The statement probably fell off the back of the list. In that case, we've // ensured that it isn't in the cache, so we can declare victory. return nil } c.l.Remove(elem) psd := elem.Value.(*pgconn.StatementDescription) delete(c.m, psd.SQL) if c.mode == ModePrepare { return c.conn.Exec(ctx, fmt.Sprintf("deallocate %s", psd.Name)).Close() } return nil } // Len returns the number of cached prepared statement descriptions. func (c *LRU) Len() int { return c.l.Len() } // Cap returns the maximum number of cached prepared statement descriptions. func (c *LRU) Cap() int { return c.cap } // Mode returns the mode of the cache (ModePrepare or ModeDescribe) func (c *LRU) Mode() int { return c.mode } func (c *LRU) prepare(ctx context.Context, sql string) (*pgconn.StatementDescription, error) { var name string if c.mode == ModePrepare { name = fmt.Sprintf("%s_%d", c.psNamePrefix, c.prepareCount) c.prepareCount += 1 } return c.conn.Prepare(ctx, name, sql, nil) } func (c *LRU) removeOldest(ctx context.Context) error { oldest := c.l.Back() c.l.Remove(oldest) psd := oldest.Value.(*pgconn.StatementDescription) delete(c.m, psd.SQL) if c.mode == ModePrepare { return c.conn.Exec(ctx, fmt.Sprintf("deallocate %s", psd.Name)).Close() } return nil } pgconn-1.14.0/stmtcache/lru_test.go000066400000000000000000000213661437172345200172440ustar00rootroot00000000000000package stmtcache_test import ( "context" "fmt" "math/rand" "os" "regexp" "testing" "time" "github.com/jackc/pgconn" "github.com/jackc/pgconn/stmtcache" "github.com/stretchr/testify/require" ) func TestLRUModePrepare(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() conn, err := pgconn.Connect(ctx, 
os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer conn.Close(ctx) cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 2) require.EqualValues(t, 0, cache.Len()) require.EqualValues(t, 2, cache.Cap()) require.EqualValues(t, stmtcache.ModePrepare, cache.Mode()) psd, err := cache.Get(ctx, "select 1") require.NoError(t, err) require.NotNil(t, psd) require.EqualValues(t, 1, cache.Len()) require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn)) psd, err = cache.Get(ctx, "select 1") require.NoError(t, err) require.NotNil(t, psd) require.EqualValues(t, 1, cache.Len()) require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn)) psd, err = cache.Get(ctx, "select 2") require.NoError(t, err) require.NotNil(t, psd) require.EqualValues(t, 2, cache.Len()) require.ElementsMatch(t, []string{"select 1", "select 2"}, fetchServerStatements(t, ctx, conn)) psd, err = cache.Get(ctx, "select 3") require.NoError(t, err) require.NotNil(t, psd) require.EqualValues(t, 2, cache.Len()) require.ElementsMatch(t, []string{"select 2", "select 3"}, fetchServerStatements(t, ctx, conn)) err = cache.Clear(ctx) require.NoError(t, err) require.EqualValues(t, 0, cache.Len()) require.Empty(t, fetchServerStatements(t, ctx, conn)) } func TestLRUStmtInvalidation(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer conn.Close(ctx) // we construct a fake error because its not super straightforward to actually call // a prepared statement from the LRU cache without the helper routines which live // in pgx proper. 
fakeInvalidCachePlanError := &pgconn.PgError{ Severity: "ERROR", Code: "0A000", Message: "cached plan must not change result type", } cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 2) // // outside of a transaction, we eagerly flush the statement // _, err = cache.Get(ctx, "select 1") require.NoError(t, err) require.EqualValues(t, 1, cache.Len()) require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn)) cache.StatementErrored("select 1", fakeInvalidCachePlanError) _, err = cache.Get(ctx, "select 2") require.NoError(t, err) require.EqualValues(t, 1, cache.Len()) require.ElementsMatch(t, []string{"select 2"}, fetchServerStatements(t, ctx, conn)) err = cache.Clear(ctx) require.NoError(t, err) // // within an errored transaction, we defer the flush to after the first get // that happens after the transaction is rolled back // _, err = cache.Get(ctx, "select 1") require.NoError(t, err) require.EqualValues(t, 1, cache.Len()) require.ElementsMatch(t, []string{"select 1"}, fetchServerStatements(t, ctx, conn)) res := conn.Exec(ctx, "begin") require.NoError(t, res.Close()) require.Equal(t, byte('T'), conn.TxStatus()) res = conn.Exec(ctx, "selec") require.Error(t, res.Close()) require.Equal(t, byte('E'), conn.TxStatus()) cache.StatementErrored("select 1", fakeInvalidCachePlanError) require.EqualValues(t, 1, cache.Len()) res = conn.Exec(ctx, "rollback") require.NoError(t, res.Close()) _, err = cache.Get(ctx, "select 2") require.EqualValues(t, 1, cache.Len()) require.ElementsMatch(t, []string{"select 2"}, fetchServerStatements(t, ctx, conn)) } func TestLRUStmtInvalidationIntegration(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer conn.Close(ctx) cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 2) result := conn.ExecParams(ctx, "create temporary table stmtcache_table (a 
text)", nil, nil, nil, nil).Read() require.NoError(t, result.Err) sql := "select * from stmtcache_table" sd1, err := cache.Get(ctx, sql) require.NoError(t, err) result = conn.ExecPrepared(ctx, sd1.Name, nil, nil, nil).Read() require.NoError(t, result.Err) result = conn.ExecParams(ctx, "alter table stmtcache_table add column b text", nil, nil, nil, nil).Read() require.NoError(t, result.Err) result = conn.ExecPrepared(ctx, sd1.Name, nil, nil, nil).Read() require.EqualError(t, result.Err, "ERROR: cached plan must not change result type (SQLSTATE 0A000)") cache.StatementErrored(sql, result.Err) sd2, err := cache.Get(ctx, sql) require.NoError(t, err) require.NotEqual(t, sd1.Name, sd2.Name) result = conn.ExecPrepared(ctx, sd2.Name, nil, nil, nil).Read() require.NoError(t, result.Err) } func TestLRUModePrepareStress(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer conn.Close(ctx) cache := stmtcache.NewLRU(conn, stmtcache.ModePrepare, 8) require.EqualValues(t, 0, cache.Len()) require.EqualValues(t, 8, cache.Cap()) require.EqualValues(t, stmtcache.ModePrepare, cache.Mode()) for i := 0; i < 1000; i++ { psd, err := cache.Get(ctx, fmt.Sprintf("select %d", rand.Intn(50))) require.NoError(t, err) require.NotNil(t, psd) result := conn.ExecPrepared(ctx, psd.Name, nil, nil, nil).Read() require.NoError(t, result.Err) } } func TestLRUModeDescribe(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer conn.Close(ctx) cache := stmtcache.NewLRU(conn, stmtcache.ModeDescribe, 2) require.EqualValues(t, 0, cache.Len()) require.EqualValues(t, 2, cache.Cap()) require.EqualValues(t, stmtcache.ModeDescribe, cache.Mode()) psd, err := cache.Get(ctx, "select 1") 
require.NoError(t, err) require.NotNil(t, psd) require.EqualValues(t, 1, cache.Len()) require.Empty(t, fetchServerStatements(t, ctx, conn)) psd, err = cache.Get(ctx, "select 1") require.NoError(t, err) require.NotNil(t, psd) require.EqualValues(t, 1, cache.Len()) require.Empty(t, fetchServerStatements(t, ctx, conn)) psd, err = cache.Get(ctx, "select 2") require.NoError(t, err) require.NotNil(t, psd) require.EqualValues(t, 2, cache.Len()) require.Empty(t, fetchServerStatements(t, ctx, conn)) psd, err = cache.Get(ctx, "select 3") require.NoError(t, err) require.NotNil(t, psd) require.EqualValues(t, 2, cache.Len()) require.Empty(t, fetchServerStatements(t, ctx, conn)) err = cache.Clear(ctx) require.NoError(t, err) require.EqualValues(t, 0, cache.Len()) require.Empty(t, fetchServerStatements(t, ctx, conn)) } func TestLRUContext(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING")) require.NoError(t, err) defer conn.Close(ctx) cache := stmtcache.NewLRU(conn, stmtcache.ModeDescribe, 2) // test 1 : getting a value for the first time with a cancelled context returns an error ctx1, cancel1 := context.WithCancel(ctx) cancel1() desc, err := cache.Get(ctx1, "SELECT 1") require.Error(t, err) require.Nil(t, desc) // test 2 : when querying for the 2nd time a cached value, if the context is canceled return an error ctx2, cancel2 := context.WithCancel(ctx) desc, err = cache.Get(ctx2, "SELECT 2") require.NoError(t, err) require.NotNil(t, desc) cancel2() desc, err = cache.Get(ctx2, "SELECT 2") require.Error(t, err) require.Nil(t, desc) } func fetchServerStatements(t testing.TB, ctx context.Context, conn *pgconn.PgConn) []string { result := conn.ExecParams(ctx, `select statement from pg_prepared_statements`, nil, nil, nil, nil).Read() require.NoError(t, result.Err) var statements []string for _, r := range result.Rows { statement := string(r[0]) if 
conn.ParameterStatus("crdb_version") != "" { if statement == "PREPARE AS select statement from pg_prepared_statements" { // CockroachDB includes the currently running unnamed prepared statement while PostgreSQL does not. Ignore it. continue } // CockroachDB includes the "PREPARE ... AS" text in the statement even if it was prepared through the extended // protocol will PostgreSQL does not. Normalize the statement. re := regexp.MustCompile(`^PREPARE lrupsc[0-9_]+ AS `) statement = re.ReplaceAllString(statement, "") } statements = append(statements, statement) } return statements } pgconn-1.14.0/stmtcache/stmtcache.go000066400000000000000000000036161437172345200173540ustar00rootroot00000000000000// Package stmtcache is a cache that can be used to implement lazy prepared statements. package stmtcache import ( "context" "github.com/jackc/pgconn" ) const ( ModePrepare = iota // Cache should prepare named statements. ModeDescribe // Cache should prepare the anonymous prepared statement to only fetch the description of the statement. ) // Cache prepares and caches prepared statement descriptions. type Cache interface { // Get returns the prepared statement description for sql preparing or describing the sql on the server as needed. Get(ctx context.Context, sql string) (*pgconn.StatementDescription, error) // Clear removes all entries in the cache. Any prepared statements will be deallocated from the PostgreSQL session. Clear(ctx context.Context) error // StatementErrored informs the cache that the given statement resulted in an error when it // was last used against the database. In some cases, this will cause the cache to maer that // statement as bad. The bad statement will instead be flushed during the next call to Get // that occurs outside of a failed transaction. StatementErrored(sql string, err error) // Len returns the number of cached prepared statement descriptions. Len() int // Cap returns the maximum number of cached prepared statement descriptions. 
Cap() int // Mode returns the mode of the cache (ModePrepare or ModeDescribe) Mode() int } // New returns the preferred cache implementation for mode and cap. mode is either ModePrepare or ModeDescribe. cap is // the maximum size of the cache. func New(conn *pgconn.PgConn, mode int, cap int) Cache { mustBeValidMode(mode) mustBeValidCap(cap) return NewLRU(conn, mode, cap) } func mustBeValidMode(mode int) { if mode != ModePrepare && mode != ModeDescribe { panic("mode must be ModePrepare or ModeDescribe") } } func mustBeValidCap(cap int) { if cap < 1 { panic("cache must have cap of >= 1") } }