golang-pq-dev/0000755000014500017510000000000012163564140013006 5ustar michaelstaffgolang-pq-dev/user_windows.go0000644000014500017510000000152112163564140016064 0ustar michaelstaff// Package pq is a pure Go Postgres driver for the database/sql package. package pq import ( "path/filepath" "syscall" ) // Perform Windows user name lookup identically to libpq. // // The PostgreSQL code makes use of the legacy Win32 function // GetUserName, and that function has not been imported into stock Go. // GetUserNameEx is available though, the difference being that a // wider range of names are available. To get the output to be the // same as GetUserName, only the base (or last) component of the // result is returned. func userCurrent() (string, error) { pw_name := make([]uint16, 128) pwname_size := uint32(len(pw_name)) - 1 err := syscall.GetUserNameEx(syscall.NameSamCompatible, &pw_name[0], &pwname_size) if err != nil { return "", err } s := syscall.UTF16ToString(pw_name) u := filepath.Base(s) return u, nil } golang-pq-dev/user_posix.go0000644000014500017510000000042512163564140015536 0ustar michaelstaff// Package pq is a pure Go Postgres driver for the database/sql package. // +build darwin freebsd linux netbsd openbsd package pq import "os/user" func userCurrent() (string, error) { u, err := user.Current() if err != nil { return "", err } return u.Username, nil } golang-pq-dev/url_test.go0000644000014500017510000000224212163564140015176 0ustar michaelstaffpackage pq import ( "testing" ) func TestSimpleParseURL(t *testing.T) { expected := "host=hostname.remote" str, err := ParseURL("postgres://hostname.remote") if err != nil { t.Fatal(err) } if str != expected { t.Fatalf("unexpected result from ParseURL:\n+ %v\n- %v", str, expected) } } func TestFullParseURL(t *testing.T) { expected := "dbname=database host=hostname.remote password=secret port=1234 user=username" str, err := ParseURL("postgres://username:secret@hostname.remote:1234/database") if err != nil { t.Fatal(err) } if str != expected { t.Fatalf("unexpected result from ParseURL:\n+ %s\n- %s", str, expected) } } func TestInvalidProtocolParseURL(t *testing.T) { _, err := ParseURL("http://hostname.remote") switch err { case nil: t.Fatal("Expected an error from parsing invalid protocol") default: msg := "invalid connection protocol: http" if err.Error() != msg { t.Fatalf("Unexpected error message:\n+ %s\n- %s", err.Error(), msg) } } } func TestMinimalURL(t *testing.T) { cs, err := ParseURL("postgres://") if err != nil { t.Fatal(err) } if cs != "" { t.Fatalf("expected blank connection string, got: %q", cs) } } golang-pq-dev/url.go0000644000014500017510000000233212163564140014137 0ustar michaelstaffpackage pq import ( "fmt" nurl "net/url" "sort" "strings" ) // ParseURL converts url to a connection string for driver.Open. // Example: // // "postgres://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full" // // converts to: // // "user=bob password=secret host=1.2.3.4 port=5432 dbname=mydb sslmode=verify-full" // // A minimal example: // // "postgres://" // // This will be blank, causing driver.Open to use all of the defaults func ParseURL(url string) (string, error) { u, err := nurl.Parse(url) if err != nil { return "", err } if u.Scheme != "postgres" { return "", fmt.Errorf("invalid connection protocol: %s", u.Scheme) } var kvs []string accrue := func(k, v string) { if v != "" { kvs = append(kvs, k+"="+v) } } if u.User != nil { v := u.User.Username() accrue("user", v) v, _ = u.User.Password() accrue("password", v) } i := strings.Index(u.Host, ":") if i < 0 { accrue("host", u.Host) } else { accrue("host", u.Host[:i]) accrue("port", u.Host[i+1:]) } if u.Path != "" { accrue("dbname", u.Path[1:]) } q := u.Query() for k, _ := range q { accrue(k, q.Get(k)) } sort.Strings(kvs) // Makes testing easier (not a performance concern) return strings.Join(kvs, " "), nil } golang-pq-dev/oid/0000755000014500017510000000000012163564140013561 5ustar michaelstaffgolang-pq-dev/oid/types.go0000644000014500017510000001166612163564140015266 0ustar michaelstaffpackage oid // Generated via massaging this catalog query: // // SELECT 'T_' || typname || ' = ' || oid // FROM pg_type WHERE oid < 10000 // ORDER BY oid; // // This should probably be done one per release. Postgres does not // re-appropriate the system OID space below 10000 as a general rule. type Oid uint32 const ( T_bool Oid = 16 T_bytea = 17 T_char = 18 T_name = 19 T_int8 = 20 T_int2 = 21 T_int2vector = 22 T_int4 = 23 T_regproc = 24 T_text = 25 T_oid = 26 T_tid = 27 T_xid = 28 T_cid = 29 T_oidvector = 30 T_pg_type = 71 T_pg_attribute = 75 T_pg_proc = 81 T_pg_class = 83 T_json = 114 T_xml = 142 T__xml = 143 T_pg_node_tree = 194 T__json = 199 T_smgr = 210 T_point = 600 T_lseg = 601 T_path = 602 T_box = 603 T_polygon = 604 T_line = 628 T__line = 629 T_cidr = 650 T__cidr = 651 T_float4 = 700 T_float8 = 701 T_abstime = 702 T_reltime = 703 T_tinterval = 704 T_unknown = 705 T_circle = 718 T__circle = 719 T_money = 790 T__money = 791 T_macaddr = 829 T_inet = 869 T__bool = 1000 T__bytea = 1001 T__char = 1002 T__name = 1003 T__int2 = 1005 T__int2vector = 1006 T__int4 = 1007 T__regproc = 1008 T__text = 1009 T__tid = 1010 T__xid = 1011 T__cid = 1012 T__oidvector = 1013 T__bpchar = 1014 T__varchar = 1015 T__int8 = 1016 T__point = 1017 T__lseg = 1018 T__path = 1019 T__box = 1020 T__float4 = 1021 T__float8 = 1022 T__abstime = 1023 T__reltime = 1024 T__tinterval = 1025 T__polygon = 1027 T__oid = 1028 T_aclitem = 1033 T__aclitem = 1034 T__macaddr = 1040 T__inet = 1041 T_bpchar = 1042 T_varchar = 1043 T_date = 1082 T_time = 1083 T_timestamp = 1114 T__timestamp = 1115 T__date = 1182 T__time = 1183 T_timestamptz = 1184 T__timestamptz = 1185 T_interval = 1186 T__interval = 1187 T__numeric = 1231 T_pg_database = 1248 T__cstring = 1263 T_timetz = 1266 T__timetz = 1270 T_bit = 1560 T__bit = 1561 T_varbit = 1562 T__varbit = 1563 T_numeric = 1700 T_refcursor = 1790 T__refcursor = 2201 T_regprocedure = 2202 T_regoper = 2203 T_regoperator = 2204 T_regclass = 2205 T_regtype = 2206 T__regprocedure = 2207 T__regoper = 2208 T__regoperator = 2209 T__regclass = 2210 T__regtype = 2211 T_record = 2249 T_cstring = 2275 T_any = 2276 T_anyarray = 2277 T_void = 2278 T_trigger = 2279 T_language_handler = 2280 T_internal = 2281 T_opaque = 2282 T_anyelement = 2283 T__record = 2287 T_anynonarray = 2776 T_pg_authid = 2842 T_pg_auth_members = 2843 T__txid_snapshot = 2949 T_uuid = 2950 T__uuid = 2951 T_txid_snapshot = 2970 T_fdw_handler = 3115 T_anyenum = 3500 T_tsvector = 3614 T_tsquery = 3615 T_gtsvector = 3642 T__tsvector = 3643 T__gtsvector = 3644 T__tsquery = 3645 T_regconfig = 3734 T__regconfig = 3735 T_regdictionary = 3769 T__regdictionary = 3770 T_anyrange = 3831 T_int4range = 3904 T__int4range = 3905 T_numrange = 3906 T__numrange = 3907 T_tsrange = 3908 T__tsrange = 3909 T_tstzrange = 3910 T__tstzrange = 3911 T_daterange = 3912 T__daterange = 3913 T_int8range = 3926 T__int8range = 3927 ) golang-pq-dev/error.go0000644000014500017510000000333512163564140014472 0ustar michaelstaffpackage pq import ( "database/sql/driver" "fmt" "io" "net" "runtime" ) const ( Efatal = "FATAL" Epanic = "PANIC" Ewarning = "WARNING" Enotice = "NOTICE" Edebug = "DEBUG" Einfo = "INFO" Elog = "LOG" ) type Error error type PGError interface { Error() string Fatal() bool Get(k byte) (v string) } type pgError struct { c map[byte]string } func parseError(r *readBuf) *pgError { err := &pgError{make(map[byte]string)} for t := r.byte(); t != 0; t = r.byte() { err.c[t] = r.string() } return err } func (err *pgError) Get(k byte) (v string) { v, _ = err.c[k] return } func (err *pgError) Fatal() bool { return err.Get('S') == Efatal } func (err *pgError) Error() string { var s string for k, v := range err.c { s += fmt.Sprintf(" %c:%q", k, v) } return "pq: " + s[1:] } func errorf(s string, args ...interface{}) { panic(Error(fmt.Errorf("pq: %s", fmt.Sprintf(s, args...)))) } type SimplePGError struct { pgError } func (err *SimplePGError) Error() string { return "pq: " + err.Get('M') } func errRecoverWithPGReason(err *error) { e := recover() switch v := e.(type) { case nil: // Do nothing case *pgError: // Return a SimplePGError in place *err = &SimplePGError{*v} default: // Otherwise re-panic panic(e) } } func errRecover(err *error) { e := recover() switch v := e.(type) { case nil: // Do nothing case runtime.Error: panic(v) case *pgError: if v.Fatal() { *err = driver.ErrBadConn } else { *err = v } case *net.OpError: *err = driver.ErrBadConn case error: if v == io.EOF || v.(error).Error() == "remote error: handshake failure" { *err = driver.ErrBadConn } else { *err = v } default: panic(fmt.Sprintf("unknown error: %#v", e)) } } golang-pq-dev/encode_test.go0000644000014500017510000000653312163564140015640 0ustar michaelstaffpackage pq import ( "fmt" "testing" "time" ) func TestScanTimestamp(t *testing.T) { var nt NullTime tn := time.Now() (&nt).Scan(tn) if !nt.Valid { t.Errorf("Expected Valid=false") } if nt.Time != tn { t.Errorf("Time value mismatch") } } func TestScanNilTimestamp(t *testing.T) { var nt NullTime (&nt).Scan(nil) if nt.Valid { t.Errorf("Expected Valid=false") } } func TestTimestampWithTimeZone(t *testing.T) { db := openTestConn(t) defer db.Close() tx, err := db.Begin() if err != nil { t.Fatal(err) } defer tx.Rollback() _, err = tx.Exec("create temp table test (t timestamp with time zone)") if err != nil { t.Fatal(err) } // try several different locations, all included in Go's zoneinfo.zip for _, locName := range []string{ "UTC", "America/Chicago", "America/New_York", "Australia/Darwin", "Australia/Perth", } { loc, err := time.LoadLocation(locName) if err != nil { t.Logf("Could not load time zone %s - skipping", locName) continue } // Postgres timestamps have a resolution of 1 microsecond, so don't // use the full range of the Nanosecond argument refTime := time.Date(2012, 11, 6, 10, 23, 42, 123456000, loc) _, err = tx.Exec("insert into test(t) values($1)", refTime) if err != nil { t.Fatal(err) } for _, pgTimeZone := range []string{"US/Eastern", "Australia/Darwin"} { // Switch Postgres's timezone to test different output timestamp formats _, err = tx.Exec(fmt.Sprintf("set time zone '%s'", pgTimeZone)) if err != nil { t.Fatal(err) } var gotTime time.Time row := tx.QueryRow("select t from test") err = row.Scan(&gotTime) if err != nil { t.Fatal(err) } if !refTime.Equal(gotTime) { t.Errorf("timestamps not equal: %s != %s", refTime, gotTime) } } _, err = tx.Exec("delete from test") if err != nil { t.Fatal(err) } } } func TestTimestampWithOutTimezone(t *testing.T) { db := openTestConn(t) defer db.Close() test := func(ts, pgts string) { r, err := db.Query("SELECT $1::timestamp", pgts) if err != nil { t.Fatalf("Could not run query: %v", err) } n := r.Next() if n != true { t.Fatal("Expected at least one row") } var result time.Time err = r.Scan(&result) if err != nil { t.Fatalf("Did not expect error scanning row: %v", err) } expected, err := time.Parse(time.RFC3339, ts) if err != nil { t.Fatalf("Could not parse test time literal: %v", err) } if !result.Equal(expected) { t.Fatalf("Expected time to match %v: got mismatch %v", expected, result) } n = r.Next() if n != false { t.Fatal("Expected only one row") } } test("2000-01-01T00:00:00Z", "2000-01-01T00:00:00") // Test higher precision time test("2013-01-04T20:14:58.80033Z", "2013-01-04 20:14:58.80033") } func TestStringWithNul(t *testing.T) { db := openTestConn(t) defer db.Close() hello0world := string("hello\x00world") _, err := db.Query("SELECT $1::text", &hello0world) if err == nil { t.Fatal("Postgres accepts a string with nul in it; " + "injection attacks may be plausible") } } func TestByteToText(t *testing.T) { db := openTestConn(t) defer db.Close() b := []byte("hello world") row := db.QueryRow("SELECT $1::text", b) var result []byte err := row.Scan(&result) if err != nil { t.Fatal(err) } if string(result) != string(b) { t.Fatalf("expected %v but got %v", b, result) } } golang-pq-dev/encode.go0000644000014500017510000000474512163564140014604 0ustar michaelstaffpackage pq import ( "database/sql/driver" "encoding/hex" "fmt" "github.com/lib/pq/oid" "strconv" "time" ) func encode(x interface{}, pgtypOid oid.Oid) []byte { switch v := x.(type) { case int64: return []byte(fmt.Sprintf("%d", v)) case float32, float64: return []byte(fmt.Sprintf("%f", v)) case []byte: if pgtypOid == oid.T_bytea { return []byte(fmt.Sprintf("\\x%x", v)) } return v case string: if pgtypOid == oid.T_bytea { return []byte(fmt.Sprintf("\\x%x", v)) } return []byte(v) case bool: return []byte(fmt.Sprintf("%t", v)) case time.Time: return []byte(v.Format(time.RFC3339Nano)) default: errorf("encode: unknown type for %T", v) } panic("not reached") } func decode(s []byte, typ oid.Oid) interface{} { switch typ { case oid.T_bytea: s = s[2:] // trim off "\\x" d := make([]byte, hex.DecodedLen(len(s))) _, err := hex.Decode(d, s) if err != nil { errorf("%s", err) } return d case oid.T_timestamptz: return mustParse("2006-01-02 15:04:05-07", typ, s) case oid.T_timestamp: return mustParse("2006-01-02 15:04:05", typ, s) case oid.T_time: return mustParse("15:04:05", typ, s) case oid.T_timetz: return mustParse("15:04:05-07", typ, s) case oid.T_date: return mustParse("2006-01-02", typ, s) case oid.T_bool: return s[0] == 't' case oid.T_int8, oid.T_int2, oid.T_int4: i, err := strconv.ParseInt(string(s), 10, 64) if err != nil { errorf("%s", err) } return i case oid.T_float4, oid.T_float8: bits := 64 if typ == oid.T_float4 { bits = 32 } f, err := strconv.ParseFloat(string(s), bits) if err != nil { errorf("%s", err) } return f } return s } func mustParse(f string, typ oid.Oid, s []byte) time.Time { str := string(s) // Special case until time.Parse bug is fixed: // http://code.google.com/p/go/issues/detail?id=3487 if str[len(str)-2] == '.' { str += "0" } // check for a 30-minute-offset timezone if (typ == oid.T_timestamptz || typ == oid.T_timetz) && str[len(str)-3] == ':' { f += ":00" } t, err := time.Parse(f, str) if err != nil { errorf("decode: %s", err) } return t } type NullTime struct { Time time.Time Valid bool // Valid is true if Time is not NULL } // Scan implements the Scanner interface. func (nt *NullTime) Scan(value interface{}) error { nt.Time, nt.Valid = value.(time.Time) return nil } // Value implements the driver Valuer interface. func (nt NullTime) Value() (driver.Value, error) { if !nt.Valid { return nil, nil } return nt.Time, nil } golang-pq-dev/conn_test.go0000644000014500017510000002220512163564140015332 0ustar michaelstaffpackage pq import ( "database/sql" "database/sql/driver" "io" "os" "reflect" "testing" "time" ) type Fatalistic interface { Fatal(args ...interface{}) } func openTestConn(t Fatalistic) *sql.DB { datname := os.Getenv("PGDATABASE") sslmode := os.Getenv("PGSSLMODE") if datname == "" { os.Setenv("PGDATABASE", "pqgotest") } if sslmode == "" { os.Setenv("PGSSLMODE", "disable") } conn, err := sql.Open("postgres", "") if err != nil { t.Fatal(err) } return conn } func TestExec(t *testing.T) { db := openTestConn(t) defer db.Close() _, err := db.Exec("CREATE TEMP TABLE temp (a int)") if err != nil { t.Fatal(err) } r, err := db.Exec("INSERT INTO temp VALUES (1)") if err != nil { t.Fatal(err) } if n, _ := r.RowsAffected(); n != 1 { t.Fatalf("expected 1 row affected, not %d", n) } r, err = db.Exec("INSERT INTO temp VALUES ($1), ($2), ($3)", 1, 2, 3) if err != nil { t.Fatal(err) } if n, _ := r.RowsAffected(); n != 3 { t.Fatalf("expected 3 rows affected, not %d", n) } r, err = db.Exec("SELECT g FROM generate_series(1, 2) g") if err != nil { t.Fatal(err) } if n, _ := r.RowsAffected(); n != 2 { t.Fatalf("expected 2 rows affected, not %d", n) } r, err = db.Exec("SELECT g FROM generate_series(1, $1) g", 3) if err != nil { t.Fatal(err) } if n, _ := r.RowsAffected(); n != 3 { t.Fatalf("expected 3 rows affected, not %d", n) } } func TestStatment(t *testing.T) { db := openTestConn(t) defer db.Close() st, err := db.Prepare("SELECT 1") if err != nil { t.Fatal(err) } st1, err := db.Prepare("SELECT 2") if err != nil { t.Fatal(err) } r, err := st.Query() if err != nil { t.Fatal(err) } defer r.Close() if !r.Next() { t.Fatal("expected row") } var i int err = r.Scan(&i) if err != nil { t.Fatal(err) } if i != 1 { t.Fatalf("expected 1, got %d", i) } // st1 r1, err := st1.Query() if err != nil { t.Fatal(err) } defer r1.Close() if !r1.Next() { if r.Err() != nil { t.Fatal(r1.Err()) } t.Fatal("expected row") } err = r1.Scan(&i) if err != nil { t.Fatal(err) } if i != 2 { t.Fatalf("expected 2, got %d", i) } } func TestRowsCloseBeforeDone(t *testing.T) { db := openTestConn(t) defer db.Close() r, err := db.Query("SELECT 1") if err != nil { t.Fatal(err) } err = r.Close() if err != nil { t.Fatal(err) } if r.Next() { t.Fatal("unexpected row") } if r.Err() != nil { t.Fatal(r.Err()) } } func TestEncodeDecode(t *testing.T) { db := openTestConn(t) defer db.Close() q := ` SELECT '\x000102'::bytea, 'foobar'::text, NULL::integer, '2000-1-1 01:02:03.04-7'::timestamptz, 0::boolean, 123, 3.14::float8 WHERE '\x000102'::bytea = $1 AND 'foobar'::text = $2 AND $3::integer is NULL ` // AND '2000-1-1 12:00:00.000000-7'::timestamp = $3 exp1 := []byte{0, 1, 2} exp2 := "foobar" r, err := db.Query(q, exp1, exp2, nil) if err != nil { t.Fatal(err) } defer r.Close() if !r.Next() { if r.Err() != nil { t.Fatal(r.Err()) } t.Fatal("expected row") } var got1 []byte var got2 string var got3 = sql.NullInt64{Valid: true} var got4 time.Time var got5, got6, got7 interface{} err = r.Scan(&got1, &got2, &got3, &got4, &got5, &got6, &got7) if err != nil { t.Fatal(err) } if !reflect.DeepEqual(exp1, got1) { t.Errorf("expected %q byte: %q", exp1, got1) } if !reflect.DeepEqual(exp2, got2) { t.Errorf("expected %q byte: %q", exp2, got2) } if got3.Valid { t.Fatal("expected invalid") } if got4.Year() != 2000 { t.Fatal("wrong year") } if got5 != false { t.Fatalf("expected false, got %q", got5) } if got6 != int64(123) { t.Fatalf("expected 123, got %d", got6) } if got7 != float64(3.14) { t.Fatalf("expected 3.14, got %f", got7) } } func TestNoData(t *testing.T) { db := openTestConn(t) defer db.Close() st, err := db.Prepare("SELECT 1 WHERE true = false") if err != nil { t.Fatal(err) } defer st.Close() r, err := st.Query() if err != nil { t.Fatal(err) } defer r.Close() if r.Next() { if r.Err() != nil { t.Fatal(r.Err()) } t.Fatal("unexpected row") } } func TestPGError(t *testing.T) { // Don't use the normal connection setup, this is intended to // blow up in the startup packet from a non-existent user. db, err := sql.Open("postgres", "user=thisuserreallydoesntexist") if err != nil { t.Fatal(err) } defer db.Close() _, err = db.Begin() if err == nil { t.Fatal("expected error") } if err, ok := err.(PGError); !ok { t.Fatalf("expected a PGError, got: %v", err) } } func TestBadConn(t *testing.T) { var err error func() { defer errRecover(&err) panic(io.EOF) }() if err != driver.ErrBadConn { t.Fatalf("expected driver.ErrBadConn, got: %#v", err) } func() { defer errRecover(&err) e := &pgError{c: make(map[byte]string)} e.c['S'] = Efatal panic(e) }() if err != driver.ErrBadConn { t.Fatalf("expected driver.ErrBadConn, got: %#v", err) } } func TestErrorOnExec(t *testing.T) { db := openTestConn(t) defer db.Close() sql := "DO $$BEGIN RAISE unique_violation USING MESSAGE='foo'; END; $$;" _, err := db.Exec(sql) _, ok := err.(PGError) if !ok { t.Fatalf("expected PGError, was: %#v", err) } _, err = db.Exec("SELECT 1 WHERE true = false") // returns no rows if err != nil { t.Fatal(err) } } func TestErrorOnQuery(t *testing.T) { db := openTestConn(t) defer db.Close() sql := "DO $$BEGIN RAISE unique_violation USING MESSAGE='foo'; END; $$;" r, err := db.Query(sql) if err != nil { t.Fatal(err) } defer r.Close() if r.Next() { t.Fatal("unexpected row, want error") } _, ok := r.Err().(PGError) if !ok { t.Fatalf("expected PGError, was: %#v", r.Err()) } r, err = db.Query("SELECT 1 WHERE true = false") // returns no rows if err != nil { t.Fatal(err) } if r.Next() { t.Fatal("unexpected row") } } func TestBindError(t *testing.T) { db := openTestConn(t) defer db.Close() _, err := db.Exec("create temp table test (i integer)") if err != nil { t.Fatal(err) } _, err = db.Query("select * from test where i=$1", "hhh") if err == nil { t.Fatal("expected an error") } // Should not get error here r, err := db.Query("select * from test where i=$1", 1) if err != nil { t.Fatal(err) } defer r.Close() } func TestParseEnviron(t *testing.T) { expected := map[string]string{"dbname": "hello", "user": "goodbye"} results := parseEnviron([]string{"PGDATABASE=hello", "PGUSER=goodbye"}) if !reflect.DeepEqual(expected, results) { t.Fatalf("Expected: %#v Got: %#v", expected, results) } } func TestExecerInterface(t *testing.T) { // Gin up a straw man private struct just for the type check cn := &conn{c: nil} var cni interface{} = cn _, ok := cni.(driver.Execer) if !ok { t.Fatal("Driver doesn't implement Execer") } } func TestNullAfterNonNull(t *testing.T) { db := openTestConn(t) defer db.Close() r, err := db.Query("SELECT 9::integer UNION SELECT NULL::integer") if err != nil { t.Fatal(err) } var n sql.NullInt64 if !r.Next() { if r.Err() != nil { t.Fatal(err) } t.Fatal("expected row") } if err := r.Scan(&n); err != nil { t.Fatal(err) } if n.Int64 != 9 { t.Fatalf("expected 2, not %d", n.Int64) } if !r.Next() { if r.Err() != nil { t.Fatal(err) } t.Fatal("expected row") } if err := r.Scan(&n); err != nil { t.Fatal(err) } if n.Valid { t.Fatal("expected n to be invalid") } if n.Int64 != 0 { t.Fatalf("expected n to 2, not %d", n.Int64) } } // Stress test the performance of parsing results from the wire. func BenchmarkResultParsing(b *testing.B) { b.StopTimer() db := openTestConn(b) defer db.Close() _, err := db.Exec("BEGIN") if err != nil { b.Fatal(err) } b.StartTimer() for i := 0; i < b.N; i++ { res, err := db.Query("SELECT generate_series(1, 50000)") if err != nil { b.Fatal(err) } res.Close() } } func Test64BitErrorChecking(t *testing.T) { defer func() { if err := recover(); err != nil { t.Fatal("panic due to 0xFFFFFFFF != -1 " + "when int is 64 bits") } }() db := openTestConn(t) defer db.Close() r, err := db.Query(`SELECT * FROM (VALUES (0::integer, NULL::text), (1, 'test string')) AS t;`) if err != nil { t.Fatal(err) } defer r.Close() for r.Next() { } } // Open transaction, issue INSERT query inside transaction, rollback // transaction, issue SELECT query to same db used to create the tx. No rows // should be returned. func TestRollback(t *testing.T) { db := openTestConn(t) defer db.Close() _, err := db.Exec("CREATE TEMP TABLE temp (a int)") if err != nil { t.Fatal(err) } sqlInsert := "INSERT INTO temp VALUES (1)" sqlSelect := "SELECT * FROM temp" tx, err := db.Begin() if err != nil { t.Fatal(err) } _, err = tx.Query(sqlInsert) if err != nil { t.Fatal(err) } err = tx.Rollback() if err != nil { t.Fatal(err) } r, err := db.Query(sqlSelect) if err != nil { t.Fatal(err) } // Next() returns false if query returned no rows. if r.Next() { t.Fatal("Transaction rollback failed") } } func TestConnTrailingSpace(t *testing.T) { o := make(Values) expected := Values{"dbname": "hello", "user": "goodbye"} parseOpts("dbname=hello user=goodbye ", o) if !reflect.DeepEqual(expected, o) { t.Fatalf("Expected: %#v Got: %#v", expected, o) } } golang-pq-dev/conn.go0000644000014500017510000002770612163564140014306 0ustar michaelstaff// Package pq is a pure Go Postgres driver for the database/sql package. package pq import ( "bufio" "crypto/md5" "crypto/tls" "database/sql" "database/sql/driver" "encoding/binary" "errors" "fmt" "github.com/lib/pq/oid" "io" "net" "os" "path" "strconv" "strings" ) var ( ErrSSLNotSupported = errors.New("pq: SSL is not enabled on the server") ErrNotSupported = errors.New("pq: invalid command") ) type drv struct{} func (d *drv) Open(name string) (driver.Conn, error) { return Open(name) } func init() { sql.Register("postgres", &drv{}) } type conn struct { c net.Conn buf *bufio.Reader namei int } func Open(name string) (_ driver.Conn, err error) { defer errRecover(&err) defer errRecoverWithPGReason(&err) o := make(Values) // A number of defaults are applied here, in this order: // // * Very low precedence defaults applied in every situation // * Environment variables // * Explicitly passed connection information o.Set("host", "localhost") o.Set("port", "5432") for k, v := range parseEnviron(os.Environ()) { o.Set(k, v) } parseOpts(name, o) // If a user is not provided by any other means, the last // resort is to use the current operating system provided user // name. if o.Get("user") == "" { u, err := userCurrent() if err != nil { return nil, err } else { o.Set("user", u) } } c, err := net.Dial(network(o)) if err != nil { return nil, err } cn := &conn{c: c} cn.ssl(o) cn.buf = bufio.NewReader(cn.c) cn.startup(o) return cn, nil } func network(o Values) (string, string) { host := o.Get("host") if strings.HasPrefix(host, "/") { sockPath := path.Join(host, ".s.PGSQL."+o.Get("port")) return "unix", sockPath } return "tcp", host + ":" + o.Get("port") } type Values map[string]string func (vs Values) Set(k, v string) { vs[k] = v } func (vs Values) Get(k string) (v string) { v, _ = vs[k] return } func parseOpts(name string, o Values) { if len(name) == 0 { return } name = strings.TrimSpace(name) ps := strings.Split(name, " ") for _, p := range ps { kv := strings.Split(p, "=") if len(kv) < 2 { errorf("invalid option: %q", p) } o.Set(kv[0], kv[1]) } } func (cn *conn) Begin() (driver.Tx, error) { _, err := cn.Exec("BEGIN", nil) if err != nil { return nil, err } return cn, err } func (cn *conn) Commit() error { _, err := cn.Exec("COMMIT", nil) return err } func (cn *conn) Rollback() error { _, err := cn.Exec("ROLLBACK", nil) return err } func (cn *conn) gname() string { cn.namei++ return strconv.FormatInt(int64(cn.namei), 10) } func (cn *conn) simpleQuery(q string) (res driver.Result, err error) { defer errRecover(&err) b := newWriteBuf('Q') b.string(q) cn.send(b) for { t, r := cn.recv1() switch t { case 'C': res = parseComplete(r.string()) case 'Z': // done return case 'E': err = parseError(r) case 'T', 'N', 'S', 'D': // ignore default: errorf("unknown response for simple query: %q", t) } } panic("not reached") } func (cn *conn) prepareTo(q, stmtName string) (_ driver.Stmt, err error) { defer errRecover(&err) st := &stmt{cn: cn, name: stmtName, query: q} b := newWriteBuf('P') b.string(st.name) b.string(q) b.int16(0) cn.send(b) b = newWriteBuf('D') b.byte('S') b.string(st.name) cn.send(b) cn.send(newWriteBuf('S')) for { t, r := cn.recv1() switch t { case '1', '2', 'N': case 't': st.nparams = int(r.int16()) st.paramTyps = make([]oid.Oid, st.nparams, st.nparams) for i := 0; i < st.nparams; i += 1 { st.paramTyps[i] = r.oid() } case 'T': n := r.int16() st.cols = make([]string, n) st.rowTyps = make([]oid.Oid, n) for i := range st.cols { st.cols[i] = r.string() r.next(6) st.rowTyps[i] = r.oid() r.next(8) } case 'n': // no data case 'Z': return st, err case 'E': err = parseError(r) case 'C': // command complete return st, err default: errorf("unexpected describe rows response: %q", t) } } panic("not reached") } func (cn *conn) Prepare(q string) (driver.Stmt, error) { return cn.prepareTo(q, cn.gname()) } func (cn *conn) Close() (err error) { defer errRecover(&err) cn.send(newWriteBuf('X')) return cn.c.Close() } // Implement the optional "Execer" interface for one-shot queries func (cn *conn) Exec(query string, args []driver.Value) (_ driver.Result, err error) { defer errRecover(&err) // Check to see if we can use the "simpleQuery" interface, which is // *much* faster than going through prepare/exec if len(args) == 0 { return cn.simpleQuery(query) } // Use the unnamed statement to defer planning until bind // time, or else value-based selectivity estimates cannot be // used. st, err := cn.prepareTo(query, "") if err != nil { panic(err) } r, err := st.Exec(args) if err != nil { panic(err) } return r, err } // Assumes len(*m) is > 5 func (cn *conn) send(m *writeBuf) { b := (*m)[1:] binary.BigEndian.PutUint32(b, uint32(len(b))) if (*m)[0] == 0 { *m = b } _, err := cn.c.Write(*m) if err != nil { panic(err) } } func (cn *conn) recv() (t byte, r *readBuf) { for { t, r = cn.recv1() switch t { case 'E': panic(parseError(r)) case 'N': // ignore default: return } } panic("not reached") } func (cn *conn) recv1() (byte, *readBuf) { x := make([]byte, 5) _, err := io.ReadFull(cn.buf, x) if err != nil { panic(err) } b := readBuf(x[1:]) y := make([]byte, b.int32()-4) _, err = io.ReadFull(cn.buf, y) if err != nil { panic(err) } return x[0], (*readBuf)(&y) } func (cn *conn) ssl(o Values) { tlsConf := tls.Config{} switch mode := o.Get("sslmode"); mode { case "require", "": tlsConf.InsecureSkipVerify = true case "verify-full": // fall out case "disable": return default: errorf(`unsupported sslmode %q; only "require" (default), "verify-full", and "disable" supported`, mode) } w := newWriteBuf(0) w.int32(80877103) cn.send(w) b := make([]byte, 1) _, err := io.ReadFull(cn.c, b) if err != nil { panic(err) } if b[0] != 'S' { panic(ErrSSLNotSupported) } cn.c = tls.Client(cn.c, &tlsConf) } func (cn *conn) startup(o Values) { w := newWriteBuf(0) w.int32(196608) w.string("user") w.string(o.Get("user")) w.string("database") w.string(o.Get("dbname")) w.string("") cn.send(w) for { t, r := cn.recv() switch t { case 'K', 'S': case 'R': cn.auth(r, o) case 'Z': return default: errorf("unknown response for startup: %q", t) } } } func (cn *conn) auth(r *readBuf, o Values) { switch code := r.int32(); code { case 0: // OK case 3: w := newWriteBuf('p') w.string(o.Get("password")) cn.send(w) t, r := cn.recv() if t != 'R' { errorf("unexpected password response: %q", t) } if r.int32() != 0 { errorf("unexpected authentication response: %q", t) } case 5: s := string(r.next(4)) w := newWriteBuf('p') w.string("md5" + md5s(md5s(o.Get("password")+o.Get("user"))+s)) cn.send(w) t, r := cn.recv() if t != 'R' { errorf("unexpected password response: %q", t) } if r.int32() != 0 { errorf("unexpected authentication response: %q", t) } default: errorf("unknown authentication response: %d", code) } } type stmt struct { cn *conn name string query string cols []string nparams int rowTyps []oid.Oid paramTyps []oid.Oid closed bool } func (st *stmt) Close() (err error) { if st.closed { return nil } defer errRecover(&err) w := newWriteBuf('C') w.byte('S') w.string(st.name) st.cn.send(w) st.cn.send(newWriteBuf('S')) t, _ := st.cn.recv() if t != '3' { errorf("unexpected close response: %q", t) } st.closed = true t, _ = st.cn.recv() if t != 'Z' { errorf("expected ready for query, but got: %q", t) } return nil } func (st *stmt) Query(v []driver.Value) (_ driver.Rows, err error) { defer errRecover(&err) st.exec(v) return &rows{st: st}, nil } func (st *stmt) Exec(v []driver.Value) (res driver.Result, err error) { defer errRecover(&err) if len(v) == 0 { return st.cn.simpleQuery(st.query) } st.exec(v) for { t, r := st.cn.recv1() switch t { case 'E': err = parseError(r) case 'C': res = parseComplete(r.string()) case 'Z': // done return case 'T', 'N', 'S', 'D': // Ignore default: errorf("unknown exec response: %q", t) } } panic("not reached") } func (st *stmt) exec(v []driver.Value) { w := newWriteBuf('B') w.string("") w.string(st.name) w.int16(0) w.int16(len(v)) for i, x := range v { if x == nil { w.int32(-1) } else { b := encode(x, st.paramTyps[i]) w.int32(len(b)) w.bytes(b) } } w.int16(0) st.cn.send(w) w = newWriteBuf('E') w.string("") w.int32(0) st.cn.send(w) st.cn.send(newWriteBuf('S')) var err error for { t, r := st.cn.recv1() switch t { case 'E': err = parseError(r) case '2': if err != nil { panic(err) } return case 'Z': if err != nil { panic(err) } return case 'N': // ignore default: errorf("unexpected bind response: %q", t) } } } func (st *stmt) NumInput() int { return st.nparams } type result int64 func (i result) RowsAffected() (int64, error) { return int64(i), nil } func (i result) LastInsertId() (int64, error) { return 0, ErrNotSupported } func parseComplete(s string) driver.Result { parts := strings.Split(s, " ") n, _ := strconv.ParseInt(parts[len(parts)-1], 10, 64) return result(n) } type rows struct { st *stmt done bool } func (rs *rows) Close() error { for { err := rs.Next(nil) switch err { case nil: case io.EOF: return nil default: return err } } panic("not reached") } func (rs *rows) Columns() []string { return rs.st.cols } func (rs *rows) Next(dest []driver.Value) (err error) { if rs.done { return io.EOF } defer errRecover(&err) for { t, r := rs.st.cn.recv1() switch t { case 'E': err = parseError(r) case 'C', 'S', 'N': continue case 'Z': rs.done = true if err != nil { return err } return io.EOF case 'D': n := r.int16() for i := 0; i < len(dest) && i < n; i++ { l := r.int32() if l == -1 { dest[i] = nil continue } dest[i] = decode(r.next(l), rs.st.rowTyps[i]) } return default: errorf("unexpected message after execute: %q", t) } } panic("not reached") } func md5s(s string) string { h := md5.New() h.Write([]byte(s)) return fmt.Sprintf("%x", h.Sum(nil)) } // parseEnviron tries to mimic some of libpq's environment handling // // To ease testing, it does not directly reference os.Environ, but is // designed to accept its output. // // Environment-set connection information is intended to have a higher // precedence than a library default but lower than any explicitly // passed information (such as in the URL or connection string). func parseEnviron(env []string) (out map[string]string) { out = make(map[string]string) for _, v := range env { parts := strings.SplitN(v, "=", 2) accrue := func(keyname string) { out[keyname] = parts[1] } // The order of these is the same as is seen in the // PostgreSQL 9.1 manual, with omissions briefly // noted. switch parts[0] { case "PGHOST": accrue("host") case "PGHOSTADDR": accrue("hostaddr") case "PGPORT": accrue("port") case "PGDATABASE": accrue("dbname") case "PGUSER": accrue("user") case "PGPASSWORD": accrue("password") // skip PGPASSFILE, PGSERVICE, PGSERVICEFILE, // PGREALM case "PGOPTIONS": accrue("options") case "PGAPPNAME": accrue("application_name") case "PGSSLMODE": accrue("sslmode") case "PGREQUIRESSL": accrue("requiressl") case "PGSSLCERT": accrue("sslcert") case "PGSSLKEY": accrue("sslkey") case "PGSSLROOTCERT": accrue("sslrootcert") case "PGSSLCRL": accrue("sslcrl") case "PGREQUIREPEER": accrue("requirepeer") case "PGKRBSRVNAME": accrue("krbsrvname") case "PGGSSLIB": accrue("gsslib") case "PGCONNECT_TIMEOUT": accrue("connect_timeout") case "PGCLIENTENCODING": accrue("client_encoding") // skip PGDATESTYLE, PGTZ, PGGEQO, PGSYSCONFDIR, // PGLOCALEDIR } } return out } golang-pq-dev/buf.go0000644000014500017510000000246012163564140014113 0ustar michaelstaffpackage pq import ( "bytes" "encoding/binary" "github.com/lib/pq/oid" ) type readBuf []byte func (b *readBuf) int32() (n int) { n = int(int32(binary.BigEndian.Uint32(*b))) *b = (*b)[4:] return } func (b *readBuf) oid() (n oid.Oid) { n = oid.Oid(binary.BigEndian.Uint32(*b)) *b = (*b)[4:] return } func (b *readBuf) int16() (n int) { n = int(binary.BigEndian.Uint16(*b)) *b = (*b)[2:] return } var stringTerm = []byte{0} func (b *readBuf) string() string { i := bytes.Index(*b, stringTerm) if i < 0 { errorf("invalid message format; expected string terminator") } s := (*b)[:i] *b = (*b)[i+1:] return string(s) } func (b *readBuf) next(n int) (v []byte) { v = (*b)[:n] *b = (*b)[n:] return } func (b *readBuf) byte() byte { return b.next(1)[0] } type writeBuf []byte func newWriteBuf(c byte) *writeBuf { b := make(writeBuf, 5) b[0] = c return &b } func (b *writeBuf) int32(n int) { x := make([]byte, 4) binary.BigEndian.PutUint32(x, uint32(n)) *b = append(*b, x...) } func (b *writeBuf) int16(n int) { x := make([]byte, 2) binary.BigEndian.PutUint16(x, uint16(n)) *b = append(*b, x...) } func (b *writeBuf) string(s string) { *b = append(*b, (s + "\000")...) } func (b *writeBuf) byte(c byte) { *b = append(*b, c) } func (b *writeBuf) bytes(v []byte) { *b = append(*b, v...) } golang-pq-dev/README.md0000644000014500017510000000551112163564140014267 0ustar michaelstaff# pq - A pure Go postgres driver for Go's database/sql package ## Install go get github.com/lib/pq ## Docs ## Use package main import ( _ "github.com/lib/pq" "database/sql" ) func main() { db, err := sql.Open("postgres", "user=pqgotest dbname=pqgotest sslmode=verify-full") // ... } **Connection String Parameters** These are a subset of the libpq connection parameters. In addition, a number of the [environment variables](http://www.postgresql.org/docs/9.1/static/libpq-envars.html) supported by libpq are also supported. Just like libpq, these have lower precedence than explicitly provided connection parameters. See http://www.postgresql.org/docs/9.1/static/libpq-connect.html. * `dbname` - The name of the database to connect to * `user` - The user to sign in as * `password` - The user's password * `host` - The host to connect to. Values that start with `/` are for unix domain sockets. (default is `localhost`) * `port` - The port to bind to. (default is `5432`) * `sslmode` - Whether or not to use SSL (default is `require`, this is not the default for libpq) Valid values are: * `disable` - No SSL * `require` - Always SSL (skip verification) * `verify-full` - Always SSL (require verification) See http://golang.org/pkg/database/sql to learn how to use with `pq` through the `database/sql` package. ## Tests `go test` is used for testing. A running PostgreSQL server is required, with the ability to log in. The default database to connect to test with is "pqgotest," but it can be overridden using environment variables. Example: PGHOST=/var/run/postgresql go test pq ## Features * SSL * Handles bad connections for `database/sql` * Scan `time.Time` correctly (i.e. `timestamp[tz]`, `time[tz]`, `date`) * Scan binary blobs correctly (i.e. `bytea`) * pq.ParseURL for converting urls to connection strings for sql.Open. * Many libpq compatible environment variables * Unix socket support ## Future / Things you can help with * Notifications: `LISTEN`/`NOTIFY` * `hstore` sugar (i.e. handling hstore in `rows.Scan`) ## Thank you (alphabetical) Some of these contributors are from the original library `bmizerany/pq.go` whose code still exists in here. * Andy Balholm (andybalholm) * Ben Berkert (benburkert) * Bill Mill (llimllib) * Bjørn Madsen (aeons) * Blake Gentry (bgentry) * Brad Fitzpatrick (bradfitz) * Chris Walsh (cwds) * Daniel Farina (fdr) * Everyone at The Go Team * Ewan Chou (coocood) * Federico Romero (federomero) * Gary Burd (garyburd) * Heroku (heroku) * Jason McVetta (jmcvetta) * Joakim Sernbrant (serbaut) * John Gallagher (jgallagher) * Kamil Kisiel (kisielk) * Keith Rarick (kr) * Maciek Sakrejda (deafbybeheading) * Marc Brinkmann (mbr) * Martin Olsen (martinolsen) * Mike Lewis (mikelikespie) * Ryan Smith (ryandotsmith) * Samuel Stauffer (samuel) * notedit (notedit) golang-pq-dev/LICENSE.md0000644000014500017510000000212612163564140014413 0ustar michaelstaffCopyright (c) 2011-2013, 'pq' Contributors Portions Copyright (C) 2011 Blake Mizerany Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.