pax_global_header00006660000000000000000000000064130565030710014512gustar00rootroot0000000000000052 comment=2246060a2a43f8282ad53295d56d780dbc930b7f pg-5.3.3/000077500000000000000000000000001305650307100121305ustar00rootroot00000000000000pg-5.3.3/.travis.yml000066400000000000000000000023051305650307100142410ustar00rootroot00000000000000sudo: required language: go go: - 1.7 - 1.8 - tip matrix: allow_failures: - go: tip before_install: - sudo /etc/init.d/postgresql stop - sudo apt-get -y remove --purge postgresql-9.1 - sudo apt-get -y remove --purge postgresql-9.2 - sudo apt-get -y remove --purge postgresql-9.3 - sudo apt-get -y remove --purge postgresql-9.4 - sudo apt-get -y remove --purge postgresql-9.5 - sudo apt-get -y autoremove - sudo apt-key adv --keyserver keys.gnupg.net --recv-keys 7FCC7D46ACCC4CF8 - sudo sh -c 'echo "deb http://apt.postgresql.org/pub/repos/apt/ precise-pgdg main 9.6" >> /etc/apt/sources.list.d/postgresql.list' - sudo apt-get update - sudo apt-get -y install postgresql-9.6 - sudo sh -c 'echo "local all postgres trust" > /etc/postgresql/9.6/main/pg_hba.conf' - sudo sh -c 'echo -n "host all all 0.0.0.0/0 trust" >> /etc/postgresql/9.6/main/pg_hba.conf' - sudo /etc/init.d/postgresql restart - sudo -u postgres psql -c "CREATE EXTENSION hstore" install: - go get github.com/jinzhu/inflection - go get gopkg.in/check.v1 - go get github.com/onsi/ginkgo - go get github.com/onsi/gomega - mkdir -p $HOME/gopath/src/gopkg.in - mv `pwd` $HOME/gopath/src/gopkg.in/pg.v5 pg-5.3.3/CHANGELOG.md000066400000000000000000000023031305650307100137370ustar00rootroot00000000000000# Changelog ## v5 - All fields are nullable by default. `,null` tag is replaced with `,notnull`. - `Result.Affected` renamed to `Result.RowsAffected`. - Added `Result.RowsReturned`. - `Create` renamed to `Insert`, `BeforeCreate` to `BeforeInsert`, `AfterCreate` to `AfterInsert`. - Indexed placeholders support, e.g. `db.Exec("SELECT ?0 + ?0", 1)`. 
- Named placeholders are evaluated when the query is executed. - Added Update and Delete hooks. - Order reworked to quote column names. OrderExpr added to bypass Order quoting restrictions. - Group reworked to quote column names. GroupExpr added to bypass Group quoting restrictions. ## v4 - `Options.Host` and `Options.Port` merged into `Options.Addr`. - Added `Options.MaxRetries`. Now queries are not retried by default. - `LoadInto` renamed to `Scan`, `ColumnLoader` renamed to `ColumnScanner`, `LoadColumn` renamed to `ScanColumn`, `NewRecord() interface{}` changed to `NewModel() ColumnScanner`, `AppendQuery(dst []byte) []byte` changed to `AppendValue(dst []byte, quote bool) ([]byte, error)`. - Structs, maps and slices are marshalled to JSON by default. - Added support for scanning slices, e.g. scanning `[]int`. - Added object relational mapping. pg-5.3.3/LICENSE000066400000000000000000000024341305650307100131400ustar00rootroot00000000000000Copyright (c) 2013 github.com/go-pg/pg Authors. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. pg-5.3.3/Makefile000066400000000000000000000000701305650307100135650ustar00rootroot00000000000000all: go test ./... go test ./... -short -race go vet pg-5.3.3/README.md000066400000000000000000000121341305650307100134100ustar00rootroot00000000000000# PostgreSQL client for Golang [![Build Status](https://travis-ci.org/go-pg/pg.svg)](https://travis-ci.org/go-pg/pg) ## Features: - Basic types: integers, floats, string, bool, time.Time. - sql.NullBool, sql.NullString, sql.NullInt64, sql.NullFloat64 and [pg.NullTime](http://godoc.org/gopkg.in/pg.v5#NullTime). - [sql.Scanner](http://golang.org/pkg/database/sql/#Scanner) and [sql/driver.Valuer](http://golang.org/pkg/database/sql/driver/#Valuer) interfaces. - Structs, maps and arrays are marshalled as JSON by default. - PostgreSQL multidimensional Arrays using [array tag](https://godoc.org/gopkg.in/pg.v5#example-DB-Model-PostgresArrayStructTag) and [Array wrapper](https://godoc.org/gopkg.in/pg.v5#example-Array). - Hstore using [hstore tag](https://godoc.org/gopkg.in/pg.v5#example-DB-Model-HstoreStructTag) and [Hstore wrapper](https://godoc.org/gopkg.in/pg.v5#example-Hstore). - All struct fields are nullable by default and zero values (empty string, 0, zero time) are marshalled as SQL `NULL`. The `` `sql:",notnull"` `` tag is used to reverse this behaviour. - [Transactions](http://godoc.org/gopkg.in/pg.v5#example-DB-Begin). - [Prepared statements](http://godoc.org/gopkg.in/pg.v5#example-DB-Prepare).
- [Notifications](http://godoc.org/gopkg.in/pg.v5#example-Listener) using `LISTEN` and `NOTIFY`. - [Copying data](http://godoc.org/gopkg.in/pg.v5#example-DB-CopyFrom) using `COPY FROM` and `COPY TO`. - [Timeouts](http://godoc.org/gopkg.in/pg.v5#Options). - Automatic connection pooling. - Query retries on network errors (see `Options.MaxRetries`). - Working with models using [ORM](https://godoc.org/gopkg.in/pg.v5#example-DB-Model) and [SQL](https://godoc.org/gopkg.in/pg.v5#example-DB-Query). - Scanning variables using [ORM](https://godoc.org/gopkg.in/pg.v5#example-DB-Select-SomeColumnsIntoVars) and [SQL](https://godoc.org/gopkg.in/pg.v5#example-Scan). - [SelectOrInsert](https://godoc.org/gopkg.in/pg.v5#example-DB-Insert-SelectOrInsert) using on-conflict. - [INSERT ... ON CONFLICT DO UPDATE](https://godoc.org/gopkg.in/pg.v5#example-DB-Insert-OnConflictDoUpdate) using ORM. - Common table expressions using [WITH](https://godoc.org/gopkg.in/pg.v5#example-DB-Select-With) and [WrapWith](https://godoc.org/gopkg.in/pg.v5#example-DB-Select-WrapWith). - [CountEstimate](https://godoc.org/gopkg.in/pg.v5#example-DB-Model-CountEstimate) using `EXPLAIN` to get [estimated number of matching rows](https://wiki.postgresql.org/wiki/Count_estimate). - [HasOne](https://godoc.org/gopkg.in/pg.v5#example-DB-Model-HasOne), [BelongsTo](https://godoc.org/gopkg.in/pg.v5#example-DB-Model-BelongsTo), [HasMany](https://godoc.org/gopkg.in/pg.v5#example-DB-Model-HasMany) and [ManyToMany](https://godoc.org/gopkg.in/pg.v5#example-DB-Model-ManyToMany). - [Creating tables from structs](https://godoc.org/gopkg.in/pg.v5#example-DB-CreateTable). - [Migrations](https://github.com/go-pg/migrations). - [Sharding](https://github.com/go-pg/sharding).
## Get Started - [Wiki](https://github.com/go-pg/pg/wiki) - [API docs](http://godoc.org/gopkg.in/pg.v5) - [Examples](http://godoc.org/gopkg.in/pg.v5#pkg-examples) ## Look & Feel ```go package pg_test import ( "fmt" "gopkg.in/pg.v5" "gopkg.in/pg.v5/orm" ) type User struct { Id int64 Name string Emails []string } func (u User) String() string { return fmt.Sprintf("User<%d %s %v>", u.Id, u.Name, u.Emails) } type Story struct { Id int64 Title string AuthorId int64 Author *User } func (s Story) String() string { return fmt.Sprintf("Story<%d %s %s>", s.Id, s.Title, s.Author) } func ExampleDB_Model() { db := pg.Connect(&pg.Options{ User: "postgres", }) err := createSchema(db) if err != nil { panic(err) } user1 := &User{ Name: "admin", Emails: []string{"admin1@admin", "admin2@admin"}, } err = db.Insert(user1) if err != nil { panic(err) } err = db.Insert(&User{ Name: "root", Emails: []string{"root1@root", "root2@root"}, }) if err != nil { panic(err) } story1 := &Story{ Title: "Cool story", AuthorId: user1.Id, } err = db.Insert(story1) if err != nil { panic(err) } // Select user by primary key. user := User{Id: user1.Id} err = db.Select(&user) if err != nil { panic(err) } // Select all users. var users []User err = db.Model(&users).Select() if err != nil { panic(err) } // Select story and associated author in one query. var story Story err = db.Model(&story). Column("story.*", "Author"). Where("story.id = ?", story1.Id). 
Select() if err != nil { panic(err) } fmt.Println(user) fmt.Println(users) fmt.Println(story) // Output: User<1 admin [admin1@admin admin2@admin]> // [User<1 admin [admin1@admin admin2@admin]> User<2 root [root1@root root2@root]>] // Story<1 Cool story User<1 admin [admin1@admin admin2@admin]>> } func createSchema(db *pg.DB) error { for _, model := range []interface{}{&User{}, &Story{}} { err := db.CreateTable(model, &orm.CreateTableOptions{ Temp: true, }) if err != nil { return err } } return nil } ``` pg-5.3.3/bench_test.go000066400000000000000000000225751305650307100146100ustar00rootroot00000000000000package pg_test import ( "fmt" "math/rand" "strconv" "sync" "testing" "time" "gopkg.in/pg.v5" "gopkg.in/pg.v5/orm" ) func benchmarkDB() *pg.DB { return pg.Connect(&pg.Options{ User: "postgres", Database: "postgres", DialTimeout: 30 * time.Second, ReadTimeout: 10 * time.Second, WriteTimeout: 10 * time.Second, PoolSize: 10, PoolTimeout: 30 * time.Second, }) } func BenchmarkQueryRowsGopgDiscard(b *testing.B) { seedDB() db := benchmarkDB() defer db.Close() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { for pb.Next() { _, err := db.Query(pg.Discard, `SELECT * FROM records LIMIT 100`) if err != nil { b.Fatal(err) } } }) } func BenchmarkQueryRowsGopgOptimized(b *testing.B) { seedDB() db := benchmarkDB() defer db.Close() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { for pb.Next() { var rs OptRecords _, err := db.Query(&rs, `SELECT * FROM records LIMIT 100`) if err != nil { b.Fatal(err) } if len(rs.C) != 100 { b.Fatalf("got %d, wanted 100", len(rs.C)) } } }) } func BenchmarkQueryRowsGopgReflect(b *testing.B) { seedDB() db := benchmarkDB() defer db.Close() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { for pb.Next() { var rs []Record _, err := db.Query(&rs, `SELECT * FROM records LIMIT 100`) if err != nil { b.Fatal(err) } if len(rs) != 100 { b.Fatalf("got %d, wanted 100", len(rs)) } } }) } func BenchmarkQueryRowsGopgORM(b *testing.B) { seedDB() db := 
benchmarkDB() defer db.Close() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { for pb.Next() { var rs []Record err := db.Model(&rs).Limit(100).Select() if err != nil { b.Fatal(err) } if len(rs) != 100 { b.Fatalf("got %d, wanted 100", len(rs)) } } }) } func BenchmarkModelHasOneGopg(b *testing.B) { seedDB() db := benchmarkDB() defer db.Close() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { for pb.Next() { var books []Book err := db.Model(&books).Column("book.*", "Author").Limit(100).Select() if err != nil { b.Fatal(err) } if len(books) != 100 { b.Fatalf("got %d, wanted 100", len(books)) } } }) } func BenchmarkModelHasManyGopg(b *testing.B) { seedDB() db := benchmarkDB() defer db.Close() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { for pb.Next() { var books []Book err := db.Model(&books).Column("book.*", "Translations").Limit(100).Select() if err != nil { b.Fatal(err) } if len(books) != 100 { b.Fatalf("got %d, wanted 100", len(books)) } for _, book := range books { if len(book.Translations) != 10 { b.Fatalf("got %d, wanted 10", len(book.Translations)) } } } }) } func BenchmarkModelHasMany2ManyGopg(b *testing.B) { seedDB() db := benchmarkDB() defer db.Close() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { for pb.Next() { var books []Book err := db.Model(&books). Column("book.*", "Genres"). Limit(100). 
Select() if err != nil { b.Fatal(err) } if len(books) != 100 { b.Fatalf("got %d, wanted 100", len(books)) } for _, book := range books { if len(book.Genres) != 10 { b.Fatalf("got %d, wanted 10", len(book.Genres)) } } } }) } func BenchmarkQueryRow(b *testing.B) { db := benchmarkDB() defer db.Close() b.ResetTimer() for i := 0; i < b.N; i++ { var dst numLoader _, err := db.QueryOne(&dst, `SELECT ?::bigint AS num`, 1) if err != nil { b.Fatal(err) } if dst.Num != 1 { b.Fatalf("got %d, wanted 1", dst.Num) } } } func BenchmarkQueryRowStmt(b *testing.B) { db := benchmarkDB() defer db.Close() stmt, err := db.Prepare(`SELECT $1::bigint AS num`) if err != nil { b.Fatal(err) } defer stmt.Close() b.ResetTimer() for i := 0; i < b.N; i++ { var dst numLoader _, err := stmt.QueryOne(&dst, 1) if err != nil { b.Fatal(err) } if dst.Num != 1 { b.Fatalf("got %d, wanted 1", dst.Num) } } } func BenchmarkQueryRowScan(b *testing.B) { db := benchmarkDB() defer db.Close() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { for pb.Next() { var n int64 _, err := db.QueryOne(pg.Scan(&n), `SELECT ? 
AS num`, 1) if err != nil { b.Fatal(err) } if n != 1 { b.Fatalf("got %d, wanted 1", n) } } }) } func BenchmarkQueryRowStmtScan(b *testing.B) { db := benchmarkDB() defer db.Close() stmt, err := db.Prepare(`SELECT $1::bigint AS num`) if err != nil { b.Fatal(err) } defer stmt.Close() b.ResetTimer() for i := 0; i < b.N; i++ { var n int64 _, err := stmt.QueryOne(pg.Scan(&n), 1) if err != nil { b.Fatal(err) } if n != 1 { b.Fatalf("got %d, wanted 1", n) } } } func BenchmarkExec(b *testing.B) { db := benchmarkDB() defer db.Close() qs := []string{ `DROP TABLE IF EXISTS exec_test`, `CREATE TABLE exec_test(id bigint, name varchar(500))`, } for _, q := range qs { _, err := db.Exec(q) if err != nil { b.Fatal(err) } } b.ResetTimer() b.RunParallel(func(pb *testing.PB) { for pb.Next() { _, err := db.Exec(`INSERT INTO exec_test (id, name) VALUES (?, ?)`, 1, "hello world") if err != nil { b.Fatal(err) } } }) } func BenchmarkExecWithError(b *testing.B) { db := benchmarkDB() defer db.Close() qs := []string{ `DROP TABLE IF EXISTS exec_with_error_test`, `CREATE TABLE exec_with_error_test(id bigint PRIMARY KEY, name varchar(500))`, } for _, q := range qs { _, err := db.Exec(q) if err != nil { b.Fatal(err) } } _, err := db.Exec(` INSERT INTO exec_with_error_test(id, name) VALUES(?, ?) 
`, 1, "hello world") if err != nil { b.Fatal(err) } b.ResetTimer() b.RunParallel(func(pb *testing.PB) { for pb.Next() { _, err := db.Exec(`INSERT INTO exec_with_error_test(id) VALUES(?)`, 1) if err == nil { b.Fatalf("got nil error, expected integrity violation") } else if pgErr, ok := err.(pg.Error); !ok || !pgErr.IntegrityViolation() { b.Fatalf("got %s, expected integrity violation", err) } } }) } func BenchmarkExecStmt(b *testing.B) { db := benchmarkDB() defer db.Close() _, err := db.Exec(`CREATE TEMP TABLE statement_exec(id bigint, name varchar(500))`) if err != nil { b.Fatal(err) } stmt, err := db.Prepare(`INSERT INTO statement_exec (id, name) VALUES ($1, $2)`) if err != nil { b.Fatal(err) } defer stmt.Close() b.ResetTimer() for i := 0; i < b.N; i++ { _, err := stmt.Exec(1, "hello world") if err != nil { b.Fatal(err) } } } var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") func randSeq(n int) string { b := make([]rune, n) for i := range b { b[i] = letters[rand.Intn(len(letters))] } return string(b) } type Record struct { Num1, Num2, Num3 int64 Str1, Str2, Str3 string } func (r *Record) GetNum1() int64 { return r.Num1 } func (r *Record) GetNum2() int64 { return r.Num2 } func (r *Record) GetNum3() int64 { return r.Num3 } func (r *Record) GetStr1() string { return r.Str1 } func (r *Record) GetStr2() string { return r.Str2 } func (r *Record) GetStr3() string { return r.Str3 } type OptRecord struct { Num1, Num2, Num3 int64 Str1, Str2, Str3 string } var _ orm.ColumnScanner = (*OptRecord)(nil) func (r *OptRecord) ScanColumn(colIdx int, colName string, b []byte) error { var err error switch colName { case "num1": r.Num1, err = strconv.ParseInt(string(b), 10, 64) case "num2": r.Num2, err = strconv.ParseInt(string(b), 10, 64) case "num3": r.Num3, err = strconv.ParseInt(string(b), 10, 64) case "str1": r.Str1 = string(b) case "str2": r.Str2 = string(b) case "str3": r.Str3 = string(b) default: return fmt.Errorf("unknown column: %q", colName) } 
return err } type OptRecords struct { C []OptRecord } func (rs *OptRecords) NewModel() orm.ColumnScanner { rs.C = append(rs.C, OptRecord{}) return &rs.C[len(rs.C)-1] } func (OptRecords) AddModel(_ orm.ColumnScanner) error { return nil } func (OptRecords) AfterSelect(_ orm.DB) error { return nil } var seedDBOnce sync.Once func seedDB() { seedDBOnce.Do(func() { if err := _seedDB(); err != nil { panic(err) } }) } func _seedDB() error { db := benchmarkDB() defer db.Close() _, err := db.Exec(`DROP TABLE IF EXISTS records`) if err != nil { return err } _, err = db.Exec(` CREATE TABLE records( num1 serial, num2 serial, num3 serial, str1 text, str2 text, str3 text ) `) if err != nil { return err } for i := 0; i < 1000; i++ { _, err := db.Exec(` INSERT INTO records (str1, str2, str3) VALUES (?, ?, ?) `, randSeq(100), randSeq(200), randSeq(300)) if err != nil { return err } } err = createTestSchema(db) if err != nil { return err } for i := 1; i < 100; i++ { genre := Genre{ Id: i, Name: fmt.Sprintf("genre %d", i), } err = db.Insert(&genre) if err != nil { return err } author := Author{ ID: i, Name: fmt.Sprintf("author %d", i), } err = db.Insert(&author) if err != nil { return err } } for i := 1; i <= 1000; i++ { err = db.Insert(&Book{ Id: i, Title: fmt.Sprintf("book %d", i), AuthorID: rand.Intn(99) + 1, CreatedAt: time.Now(), }) if err != nil { return err } for j := 1; j <= 10; j++ { err = db.Insert(&BookGenre{ BookId: i, GenreId: rand.Intn(99) + 1, }) if err != nil { return err } err = db.Insert(&Translation{ BookId: i, Lang: fmt.Sprintf("%d", j), }) if err != nil { return err } } } return nil } pg-5.3.3/conv_test.go000066400000000000000000000416301305650307100144670ustar00rootroot00000000000000package pg_test import ( "database/sql" "database/sql/driver" "encoding/json" "fmt" "math" "reflect" "testing" "time" "gopkg.in/pg.v5" "gopkg.in/pg.v5/orm" "gopkg.in/pg.v5/types" ) type JSONMap map[string]interface{} func (m *JSONMap) Scan(b interface{}) error { if b == nil { *m = nil 
return nil } return json.Unmarshal(b.([]byte), m) } func (m JSONMap) Value() (driver.Value, error) { b, err := json.Marshal(m) if err != nil { return nil, err } return string(b), nil } type ( StringSlice []string IntSlice []int Int64Slice []int64 Float64Slice []float64 ) type Struct struct { Foo string } type conversionTest struct { i int src, dst, wanted interface{} pgtype string wanterr string wantnil bool wantzero bool } func unwrap(v interface{}) interface{} { if arr, ok := v.(*types.Array); ok { return arr.Value() } if hstore, ok := v.(*types.Hstore); ok { return hstore.Value() } return v } func deref(vi interface{}) interface{} { v := reflect.ValueOf(vi) for v.Kind() == reflect.Ptr { v = v.Elem() } if v.IsValid() { return v.Interface() } return nil } func zero(v interface{}) interface{} { return reflect.Zero(reflect.ValueOf(v).Elem().Type()).Interface() } func (test *conversionTest) String() string { return fmt.Sprintf("#%d src=%#v dst=%#v", test.i, test.src, test.dst) } func (test *conversionTest) Assert(t *testing.T, err error) { if test.wanterr != "" { if err == nil || err.Error() != test.wanterr { t.Fatalf("got error %q, wanted %q (%s)", err, test.wanterr, test) } return } if err != nil { t.Fatalf("got error %q, wanted nil (%s)", err, test) } dst := reflect.Indirect(reflect.ValueOf(unwrap(test.dst))).Interface() if test.wantnil { dstValue := reflect.ValueOf(dst) if !dstValue.IsValid() { return } if dstValue.IsNil() { return } t.Fatalf("got %#v, wanted nil (%s)", dst, test) return } // Remove any intermediate pointers to compare values. 
dst = deref(unwrap(dst)) src := deref(unwrap(test.src)) if test.wantzero { dstValue := reflect.ValueOf(dst) switch dstValue.Kind() { case reflect.Slice, reflect.Map: if dstValue.IsNil() { t.Fatalf("got nil, wanted zero value") } if dstValue.Len() != 0 { t.Fatalf("got %d items, wanted 0", dstValue.Len()) } default: zero := zero(test.dst) if dst != zero { t.Fatalf("%#v != %#v (%s)", dst, zero, test) } } return } if dstTime, ok := dst.(time.Time); ok { srcTime := src.(time.Time) if dstTime.Unix() != srcTime.Unix() { t.Fatalf("%#v != %#v", dstTime, srcTime) } return } if dstTimes, ok := dst.([]time.Time); ok { srcTimes := src.([]time.Time) for i, dstTime := range dstTimes { srcTime := srcTimes[i] if dstTime.Unix() != srcTime.Unix() { t.Fatalf("%#v != %#v", dstTime, srcTime) } } return } wanted := test.wanted if wanted == nil { wanted = src } if !reflect.DeepEqual(dst, wanted) { t.Fatalf("%#v != %#v (%s)", dst, wanted, test) } } func conversionTests() []conversionTest { return []conversionTest{ {src: nil, dst: nil, wanterr: "pg: Scan(nil)"}, {src: nil, dst: new(uintptr), wanterr: "pg: Scan(unsupported uintptr)"}, {src: nil, dst: true, pgtype: "bool", wanterr: "pg: Scan(non-pointer bool)"}, {src: nil, dst: new(*bool), pgtype: "bool", wantnil: true}, {src: nil, dst: new(bool), pgtype: "bool", wantzero: true}, {src: true, dst: new(bool), pgtype: "bool"}, {src: true, dst: new(*bool), pgtype: "bool"}, {src: 1, dst: new(bool), wanted: true}, {src: nil, dst: "", pgtype: "text", wanterr: "pg: Scan(non-pointer string)"}, {src: nil, dst: new(string), pgtype: "text", wantzero: true}, {src: nil, dst: new(*string), pgtype: "text", wantnil: true}, {src: "hello world", dst: new(string), pgtype: "text"}, {src: "hello world", dst: new(*string), pgtype: "text"}, {src: "'\"\000", dst: new(string), wanted: `'"`, pgtype: "text"}, {src: nil, dst: []byte(nil), pgtype: "bytea", wanterr: "pg: Scan(non-pointer []uint8)"}, {src: nil, dst: new([]byte), pgtype: "bytea", wantnil: true}, {src: 
[]byte("hello world\000"), dst: new([]byte), pgtype: "bytea"}, {src: []byte{}, dst: new([]byte), pgtype: "bytea", wantzero: true}, {src: nil, dst: int8(0), pgtype: "smallint", wanterr: "pg: Scan(non-pointer int8)"}, {src: nil, dst: new(int8), pgtype: "smallint", wantzero: true}, {src: int8(math.MaxInt8), dst: new(int8), pgtype: "smallint"}, {src: int8(math.MaxInt8), dst: new(*int8), pgtype: "smallint"}, {src: int8(math.MinInt8), dst: new(int8), pgtype: "smallint"}, {src: nil, dst: int16(0), pgtype: "smallint", wanterr: "pg: Scan(non-pointer int16)"}, {src: nil, dst: new(int16), pgtype: "smallint", wantzero: true}, {src: int16(math.MaxInt16), dst: new(int16), pgtype: "smallint"}, {src: int16(math.MaxInt16), dst: new(*int16), pgtype: "smallint"}, {src: int16(math.MinInt16), dst: new(int16), pgtype: "smallint"}, {src: nil, dst: int32(0), pgtype: "int", wanterr: "pg: Scan(non-pointer int32)"}, {src: nil, dst: new(int32), pgtype: "int", wantzero: true}, {src: int32(math.MaxInt32), dst: new(int32), pgtype: "int"}, {src: int32(math.MaxInt32), dst: new(*int32), pgtype: "int"}, {src: int32(math.MinInt32), dst: new(int32), pgtype: "int"}, {src: nil, dst: int64(0), pgtype: "bigint", wanterr: "pg: Scan(non-pointer int64)"}, {src: nil, dst: new(int64), pgtype: "bigint", wantzero: true}, {src: int64(math.MaxInt64), dst: new(int64), pgtype: "bigint"}, {src: int64(math.MaxInt64), dst: new(*int64), pgtype: "bigint"}, {src: int64(math.MinInt64), dst: new(int64), pgtype: "bigint"}, {src: nil, dst: int(0), pgtype: "bigint", wanterr: "pg: Scan(non-pointer int)"}, {src: nil, dst: new(int), pgtype: "bigint", wantzero: true}, {src: int(math.MaxInt64), dst: new(int), pgtype: "bigint"}, {src: int(math.MaxInt64), dst: new(*int), pgtype: "bigint"}, {src: int(math.MinInt32), dst: new(int), pgtype: "bigint"}, {src: nil, dst: uint8(0), pgtype: "smallint", wanterr: "pg: Scan(non-pointer uint8)"}, {src: nil, dst: new(uint8), pgtype: "smallint", wantzero: true}, {src: uint8(math.MaxUint8), dst: 
new(uint8), pgtype: "smallint"}, {src: uint8(math.MaxUint8), dst: new(*uint8), pgtype: "smallint"}, {src: nil, dst: uint16(0), pgtype: "smallint", wanterr: "pg: Scan(non-pointer uint16)"}, {src: nil, dst: new(uint16), pgtype: "smallint", wantzero: true}, {src: uint16(math.MaxUint16), dst: new(uint16), pgtype: "int"}, {src: uint16(math.MaxUint16), dst: new(*uint16), pgtype: "int"}, {src: nil, dst: uint32(0), pgtype: "bigint", wanterr: "pg: Scan(non-pointer uint32)"}, {src: nil, dst: new(uint32), pgtype: "bigint", wantzero: true}, {src: uint32(math.MaxUint32), dst: new(uint32), pgtype: "bigint"}, {src: uint32(math.MaxUint32), dst: new(*uint32), pgtype: "bigint"}, {src: nil, dst: uint64(0), pgtype: "bigint", wanterr: "pg: Scan(non-pointer uint64)"}, {src: nil, dst: new(uint64), pgtype: "bigint", wantzero: true}, {src: uint64(math.MaxUint64), dst: new(uint64)}, {src: uint64(math.MaxUint64), dst: new(*uint64)}, {src: uint64(math.MaxUint32), dst: new(uint64), pgtype: "bigint"}, {src: nil, dst: uint(0), pgtype: "smallint", wanterr: "pg: Scan(non-pointer uint)"}, {src: nil, dst: new(uint), pgtype: "bigint", wantzero: true}, {src: uint(math.MaxUint64), dst: new(uint)}, {src: uint(math.MaxUint64), dst: new(*uint)}, {src: uint(math.MaxUint32), dst: new(uint), pgtype: "bigint"}, {src: nil, dst: float32(0), pgtype: "decimal", wanterr: "pg: Scan(non-pointer float32)"}, {src: nil, dst: new(float32), pgtype: "decimal", wantzero: true}, {src: float32(math.MaxFloat32), dst: new(float32), pgtype: "decimal"}, {src: float32(math.MaxFloat32), dst: new(*float32), pgtype: "decimal"}, {src: float32(math.SmallestNonzeroFloat32), dst: new(float32), pgtype: "decimal"}, {src: nil, dst: float64(0), pgtype: "decimal", wanterr: "pg: Scan(non-pointer float64)"}, {src: nil, dst: new(float64), pgtype: "decimal", wantzero: true}, {src: float64(math.MaxFloat64), dst: new(float64), pgtype: "decimal"}, {src: float64(math.MaxFloat64), dst: new(*float64), pgtype: "decimal"}, {src: 
float64(math.SmallestNonzeroFloat64), dst: new(float64), pgtype: "decimal"}, {src: nil, dst: []int(nil), pgtype: "jsonb", wanterr: "pg: Scan(non-pointer []int)"}, {src: nil, dst: new([]int), pgtype: "jsonb", wantnil: true}, {src: []int(nil), dst: new([]int), pgtype: "jsonb", wantnil: true}, {src: []int{}, dst: new([]int), pgtype: "jsonb", wantzero: true}, {src: []int{1, 2, 3}, dst: new([]int), pgtype: "jsonb"}, {src: IntSlice{1, 2, 3}, dst: new(IntSlice), pgtype: "jsonb"}, {src: nil, dst: pg.Array([]int(nil)), pgtype: "int[]", wanterr: "pg: Scan(non-pointer []int)"}, {src: pg.Array([]int(nil)), dst: pg.Array(new([]int)), pgtype: "int[]", wantnil: true}, {src: pg.Array([]int{}), dst: pg.Array(new([]int)), pgtype: "int[]"}, {src: pg.Array([]int{1, 2, 3}), dst: pg.Array(new([]int)), pgtype: "int[]"}, {src: nil, dst: pg.Array([]int64(nil)), pgtype: "bigint[]", wanterr: "pg: Scan(non-pointer []int64)"}, {src: nil, dst: pg.Array(new([]int64)), pgtype: "bigint[]", wantnil: true}, {src: pg.Array([]int64(nil)), dst: pg.Array(new([]int64)), pgtype: "bigint[]", wantnil: true}, {src: pg.Array([]int64{}), dst: pg.Array(new([]int64)), pgtype: "bigint[]"}, {src: pg.Array([]int64{1, 2, 3}), dst: pg.Array(new([]int64)), pgtype: "bigint[]"}, {src: nil, dst: pg.Array([]float64(nil)), pgtype: "decimal[]", wanterr: "pg: Scan(non-pointer []float64)"}, {src: nil, dst: pg.Array(new([]float64)), pgtype: "decimal[]", wantnil: true}, {src: pg.Array([]float64(nil)), dst: pg.Array(new([]float64)), pgtype: "decimal[]", wantnil: true}, {src: pg.Array([]float64{}), dst: pg.Array(new([]float64)), pgtype: "decimal[]"}, {src: pg.Array([]float64{1.1, 2.22, 3.333}), dst: pg.Array(new([]float64)), pgtype: "decimal[]"}, {src: nil, dst: pg.Array([]string(nil)), pgtype: "text[]", wanterr: "pg: Scan(non-pointer []string)"}, {src: nil, dst: pg.Array(new([]string)), pgtype: "text[]", wantnil: true}, {src: pg.Array([]string(nil)), dst: pg.Array(new([]string)), pgtype: "text[]", wantnil: true}, {src: 
pg.Array([]string{}), dst: pg.Array(new([]string)), pgtype: "text[]"}, {src: pg.Array([]string{"one", "two", "three"}), dst: pg.Array(new([]string)), pgtype: "text[]"}, {src: pg.Array([]string{`'"{}`}), dst: pg.Array(new([]string)), pgtype: "text[]"}, {src: nil, dst: pg.Array([][]string(nil)), pgtype: "text[][]", wanterr: "pg: Scan(non-pointer [][]string)"}, {src: nil, dst: pg.Array(new([][]string)), pgtype: "text[][]", wantnil: true}, {src: pg.Array([][]string(nil)), dst: pg.Array(new([]string)), pgtype: "text[][]", wantnil: true}, {src: pg.Array([][]string{}), dst: pg.Array(new([][]string)), pgtype: "text[][]"}, {src: pg.Array([][]string{{"one", "two"}, {"three", "four"}}), dst: pg.Array(new([][]string)), pgtype: "text[][]"}, {src: pg.Array([][]string{{`'"\{}`}}), dst: pg.Array(new([][]string)), pgtype: "text[][]"}, {src: nil, dst: pg.Hstore(map[string]string(nil)), pgtype: "hstore", wanterr: "pg: Scan(non-pointer map[string]string)"}, {src: nil, dst: pg.Hstore(new(map[string]string)), pgtype: "hstore", wantnil: true}, {src: pg.Hstore(map[string]string(nil)), dst: pg.Hstore(new(map[string]string)), pgtype: "hstore", wantnil: true}, {src: pg.Hstore(map[string]string{}), dst: pg.Hstore(new(map[string]string)), pgtype: "hstore"}, {src: pg.Hstore(map[string]string{"foo": "bar"}), dst: pg.Hstore(new(map[string]string)), pgtype: "hstore"}, {src: pg.Hstore(map[string]string{`'"\{}=>`: `'"\{}=>`}), dst: pg.Hstore(new(map[string]string)), pgtype: "hstore"}, {src: nil, dst: sql.NullBool{}, pgtype: "bool", wanterr: "pg: Scan(non-pointer sql.NullBool)"}, {src: nil, dst: new(*sql.NullBool), pgtype: "bool", wantnil: true}, {src: nil, dst: new(sql.NullBool), pgtype: "bool", wanted: sql.NullBool{}}, {src: &sql.NullBool{}, dst: new(sql.NullBool), pgtype: "bool"}, {src: &sql.NullBool{Valid: true}, dst: new(sql.NullBool), pgtype: "bool"}, {src: &sql.NullBool{Valid: true, Bool: true}, dst: new(sql.NullBool), pgtype: "bool"}, {src: &sql.NullString{}, dst: new(sql.NullString), pgtype: 
"text"}, {src: &sql.NullString{Valid: true}, dst: new(sql.NullString), pgtype: "text"}, {src: &sql.NullString{Valid: true, String: "foo"}, dst: new(sql.NullString), pgtype: "text"}, {src: &sql.NullInt64{}, dst: new(sql.NullInt64), pgtype: "bigint"}, {src: &sql.NullInt64{Valid: true}, dst: new(sql.NullInt64), pgtype: "bigint"}, {src: &sql.NullInt64{Valid: true, Int64: math.MaxInt64}, dst: new(sql.NullInt64), pgtype: "bigint"}, {src: &sql.NullFloat64{}, dst: new(sql.NullFloat64), pgtype: "decimal"}, {src: &sql.NullFloat64{Valid: true}, dst: new(sql.NullFloat64), pgtype: "decimal"}, {src: &sql.NullFloat64{Valid: true, Float64: math.MaxFloat64}, dst: new(sql.NullFloat64), pgtype: "decimal"}, {src: nil, dst: customStrSlice{}, wanterr: "pg: Scan(non-pointer pg_test.customStrSlice)"}, {src: nil, dst: new(customStrSlice), wantnil: true}, {src: nil, dst: new(*customStrSlice), wantnil: true}, {src: customStrSlice{}, dst: new(customStrSlice), wantzero: true}, {src: customStrSlice{"one", "two"}, dst: new(customStrSlice)}, {src: nil, dst: time.Time{}, pgtype: "timestamp", wanterr: "pg: Scan(non-pointer time.Time)"}, {src: nil, dst: new(time.Time), pgtype: "timestamp", wantzero: true}, {src: nil, dst: new(*time.Time), pgtype: "timestamp", wantnil: true}, {src: time.Now(), dst: new(time.Time), pgtype: "timestamp"}, {src: time.Now(), dst: new(*time.Time), pgtype: "timestamp"}, {src: time.Now().UTC(), dst: new(time.Time), pgtype: "timestamp"}, {src: time.Time{}, dst: new(time.Time), pgtype: "timestamp"}, {src: nil, dst: new(time.Time), pgtype: "timestamptz", wantzero: true}, {src: nil, dst: new(*time.Time), pgtype: "timestamptz", wantnil: true}, {src: time.Now(), dst: new(time.Time), pgtype: "timestamptz"}, {src: time.Now(), dst: new(*time.Time), pgtype: "timestamptz"}, {src: time.Now().UTC(), dst: new(time.Time), pgtype: "timestamptz"}, {src: time.Time{}, dst: new(time.Time), pgtype: "timestamptz"}, {src: nil, dst: pg.Array([]time.Time(nil)), pgtype: "timestamptz[]", wanterr: "pg: 
Scan(non-pointer []time.Time)"}, {src: nil, dst: pg.Array(new([]time.Time)), pgtype: "timestamptz[]", wantnil: true}, {src: pg.Array([]time.Time(nil)), dst: pg.Array(new([]time.Time)), pgtype: "timestamptz[]", wantnil: true}, {src: pg.Array([]time.Time{}), dst: pg.Array(new([]time.Time)), pgtype: "timestamptz[]"}, {src: pg.Array([]time.Time{time.Now(), time.Now(), time.Now()}), dst: pg.Array(new([]time.Time)), pgtype: "timestamptz[]"}, {src: nil, dst: pg.Ints{}, wanterr: "pg: Scan(non-pointer pg.Ints)"}, {src: 1, dst: new(pg.Ints), wanted: pg.Ints{1}}, {src: nil, dst: pg.Strings{}, wanterr: "pg: Scan(non-pointer pg.Strings)"}, {src: "hello", dst: new(pg.Strings), wanted: pg.Strings{"hello"}}, {src: nil, dst: pg.IntSet{}, wanterr: "pg: Scan(non-pointer pg.IntSet)"}, {src: 1, dst: new(pg.IntSet), wanted: pg.IntSet{1: struct{}{}}}, {src: nil, dst: JSONMap{}, pgtype: "json", wanterr: "pg: Scan(non-pointer pg_test.JSONMap)"}, {src: nil, dst: new(JSONMap), pgtype: "json", wantnil: true}, {src: nil, dst: new(*JSONMap), pgtype: "json", wantnil: true}, {src: JSONMap{}, dst: new(JSONMap), pgtype: "json"}, {src: JSONMap{}, dst: new(*JSONMap), pgtype: "json"}, {src: JSONMap{"foo": "bar"}, dst: new(JSONMap), pgtype: "json"}, {src: `{"foo": "bar"}`, dst: new(JSONMap), pgtype: "json", wanted: JSONMap{"foo": "bar"}}, {src: nil, dst: Struct{}, pgtype: "json", wanterr: "pg: Scan(non-pointer pg_test.Struct)"}, {src: nil, dst: new(*Struct), pgtype: "json", wantnil: true}, {src: nil, dst: new(Struct), pgtype: "json", wantzero: true}, {src: Struct{}, dst: new(Struct), pgtype: "json"}, {src: Struct{Foo: "bar"}, dst: new(Struct), pgtype: "json"}, {src: `{"foo": "bar"}`, dst: new(Struct), wanted: Struct{Foo: "bar"}}, } } func TestConversion(t *testing.T) { db := pg.Connect(pgOptions()) for i, test := range conversionTests() { test.i = i var scanner orm.ColumnScanner if v, ok := test.dst.(orm.ColumnScanner); ok { scanner = v } else { scanner = pg.Scan(test.dst) } _, err := 
db.QueryOne(scanner, "SELECT (?) AS dst", test.src) test.Assert(t, err) } for i, test := range conversionTests() { test.i = i var scanner orm.ColumnScanner if v, ok := test.dst.(orm.ColumnScanner); ok { scanner = v } else { scanner = pg.Scan(test.dst) } err := db.Model().ColumnExpr("(?) AS dst", test.src).Select(scanner) test.Assert(t, err) } for i, test := range conversionTests() { test.i = i if test.pgtype == "" { continue } stmt, err := db.Prepare(fmt.Sprintf("SELECT ($1::%s) AS dst", test.pgtype)) if err != nil { t.Fatal(err) } var scanner orm.ColumnScanner if v, ok := test.dst.(orm.ColumnScanner); ok { scanner = v } else { scanner = pg.Scan(test.dst) } _, err = stmt.QueryOne(scanner, test.src) test.Assert(t, err) if err := stmt.Close(); err != nil { t.Fatal(err) } } } pg-5.3.3/db.go000066400000000000000000000215401305650307100130460ustar00rootroot00000000000000package pg import ( "fmt" "io" "time" "gopkg.in/pg.v5/internal" "gopkg.in/pg.v5/internal/pool" "gopkg.in/pg.v5/orm" "gopkg.in/pg.v5/types" ) // Connect connects to a database using provided options. // // The returned DB is safe for concurrent use by multiple goroutines // and maintains its own connection pool. func Connect(opt *Options) *DB { opt.init() return &DB{ opt: opt, pool: newConnPool(opt), } } // DB is a database handle representing a pool of zero or more // underlying connections. It's safe for concurrent use by multiple // goroutines. type DB struct { opt *Options pool *pool.ConnPool fmter orm.Formatter } var _ orm.DB = (*DB)(nil) func (db *DB) String() string { return fmt.Sprintf("DB", db.opt.Addr, db.fmter) } // Options returns read-only Options that were used to connect to the DB. func (db *DB) Options() *Options { return db.opt } // WithTimeout returns a DB that uses d as the read/write timeout. 
func (db *DB) WithTimeout(d time.Duration) *DB { newopt := *db.opt newopt.ReadTimeout = d newopt.WriteTimeout = d return &DB{ opt: &newopt, pool: db.pool, fmter: db.fmter, } } // WithParam returns a DB that replaces the param with the value in queries. func (db *DB) WithParam(param string, value interface{}) *DB { return &DB{ opt: db.opt, pool: db.pool, fmter: db.fmter.WithParam(param, value), } } func (db *DB) conn() (*pool.Conn, error) { cn, _, err := db.pool.Get() if err != nil { return nil, err } cn.SetReadWriteTimeout(db.opt.ReadTimeout, db.opt.WriteTimeout) if cn.InitedAt.IsZero() { if err := db.initConn(cn); err != nil { _ = db.pool.Remove(cn, err) return nil, err } cn.InitedAt = time.Now() } return cn, nil } func (db *DB) initConn(cn *pool.Conn) error { if db.opt.TLSConfig != nil { if err := enableSSL(cn, db.opt.TLSConfig); err != nil { return err } } err := startup(cn, db.opt.User, db.opt.Password, db.opt.Database) if err != nil { return err } return nil } func (db *DB) freeConn(cn *pool.Conn, err error) error { if !isBadConn(err, false) { return db.pool.Put(cn) } return db.pool.Remove(cn, err) } func (db *DB) shouldRetry(err error) bool { if err == nil { return false } if pgerr, ok := err.(Error); ok { switch pgerr.Field('C') { case "40001": // serialization_failure return true case "55000": // attempted to delete invisible tuple return true case "57014": // statement_timeout return db.opt.RetryStatementTimeout default: return false } } return isNetworkError(err) } // Close closes the database client, releasing any open resources. // // It is rare to Close a DB, as the DB handle is meant to be // long-lived and shared between many goroutines. func (db *DB) Close() error { st := db.pool.Stats() if st.TotalConns != st.FreeConns { internal.Logf( "connection leaking detected: total_conns=%d free_conns=%d", st.TotalConns, st.FreeConns, ) } return db.pool.Close() } // Exec executes a query ignoring returned rows. 
The params are for any // placeholders in the query. func (db *DB) Exec(query interface{}, params ...interface{}) (res *types.Result, err error) { for i := 0; ; i++ { var cn *pool.Conn cn, err = db.conn() if err != nil { return nil, err } res, err = db.simpleQuery(cn, query, params...) db.freeConn(cn, err) if i >= db.opt.MaxRetries { break } if !db.shouldRetry(err) { break } time.Sleep(internal.RetryBackoff << uint(i)) } return res, err } // ExecOne acts like Exec, but query must affect only one row. It // returns ErrNoRows error when query returns zero rows or // ErrMultiRows when query returns multiple rows. func (db *DB) ExecOne(query interface{}, params ...interface{}) (*types.Result, error) { res, err := db.Exec(query, params...) if err != nil { return nil, err } if err := internal.AssertOneRow(res.RowsAffected()); err != nil { return nil, err } return res, nil } // Query executes a query that returns rows, typically a SELECT. // The params are for any placeholders in the query. func (db *DB) Query(model, query interface{}, params ...interface{}) (res *types.Result, err error) { var mod orm.Model for i := 0; i < 3; i++ { var cn *pool.Conn cn, err = db.conn() if err != nil { return nil, err } res, mod, err = db.simpleQueryData(cn, model, query, params...) db.freeConn(cn, err) if i >= db.opt.MaxRetries { break } if !db.shouldRetry(err) { break } time.Sleep(internal.RetryBackoff << uint(i)) } if err != nil { return nil, err } if res.RowsReturned() > 0 && mod != nil { if err = mod.AfterQuery(db); err != nil { return res, err } } return res, nil } // QueryOne acts like Query, but query must return only one row. It // returns ErrNoRows error when query returns zero rows or // ErrMultiRows when query returns multiple rows. func (db *DB) QueryOne(model, query interface{}, params ...interface{}) (*types.Result, error) { mod, err := orm.NewModel(model) if err != nil { return nil, err } res, err := db.Query(mod, query, params...) 
if err != nil { return nil, err } if err := internal.AssertOneRow(res.RowsAffected()); err != nil { return nil, err } return res, nil } // Listen listens for notifications sent with NOTIFY command. func (db *DB) Listen(channels ...string) *Listener { ln := &Listener{ db: db, } _ = ln.Listen(channels...) return ln } // CopyFrom copies data from reader to a table. func (db *DB) CopyFrom(reader io.Reader, query interface{}, params ...interface{}) (*types.Result, error) { cn, err := db.conn() if err != nil { return nil, err } res, err := db.copyFrom(cn, reader, query, params...) db.freeConn(cn, err) return res, err } // CopyTo copies data from a table to writer. func (db *DB) CopyTo(writer io.Writer, query interface{}, params ...interface{}) (*types.Result, error) { cn, err := db.conn() if err != nil { return nil, err } if err := writeQueryMsg(cn.Wr, db, query, params...); err != nil { db.pool.Put(cn) return nil, err } if err := cn.FlushWriter(); err != nil { db.freeConn(cn, err) return nil, err } if err := readCopyOutResponse(cn); err != nil { db.freeConn(cn, err) return nil, err } res, err := readCopyData(cn, writer) if err != nil { db.freeConn(cn, err) return nil, err } db.pool.Put(cn) return res, nil } // Model returns new query for the model. func (db *DB) Model(model ...interface{}) *orm.Query { return orm.NewQuery(db, model...) } // Select selects the model by primary key. func (db *DB) Select(model interface{}) error { return orm.Select(db, model) } // Insert inserts the model updating primary keys if they are empty. func (db *DB) Insert(model ...interface{}) error { return orm.Insert(db, model...) } // Update updates the model by primary key. func (db *DB) Update(model interface{}) error { return orm.Update(db, model) } // Delete deletes the model by primary key. func (db *DB) Delete(model interface{}) error { return orm.Delete(db, model) } // CreateTable creates table for the model. It recognizes following field tags: // - notnull - sets NOT NULL constraint. 
// - unique - sets UNIQUE constraint. func (db *DB) CreateTable(model interface{}, opt *orm.CreateTableOptions) error { _, err := orm.CreateTable(db, model, opt) return err } func (db *DB) FormatQuery(dst []byte, query string, params ...interface{}) []byte { return db.fmter.Append(dst, query, params...) } func (db *DB) cancelRequest(processId, secretKey int32) error { cn, err := db.pool.NewConn() if err != nil { return err } writeCancelRequestMsg(cn.Wr, processId, secretKey) if err = cn.FlushWriter(); err != nil { return err } cn.Close() return nil } func (db *DB) simpleQuery( cn *pool.Conn, query interface{}, params ...interface{}, ) (*types.Result, error) { if err := writeQueryMsg(cn.Wr, db, query, params...); err != nil { return nil, err } if err := cn.FlushWriter(); err != nil { return nil, err } return readSimpleQuery(cn) } func (db *DB) simpleQueryData( cn *pool.Conn, model, query interface{}, params ...interface{}, ) (*types.Result, orm.Model, error) { if err := writeQueryMsg(cn.Wr, db, query, params...); err != nil { return nil, nil, err } if err := cn.FlushWriter(); err != nil { return nil, nil, err } return readSimpleQueryData(cn, model) } func (db *DB) copyFrom(cn *pool.Conn, r io.Reader, query interface{}, params ...interface{}) (*types.Result, error) { if err := writeQueryMsg(cn.Wr, db, query, params...); err != nil { return nil, err } if err := cn.FlushWriter(); err != nil { return nil, err } if err := readCopyInResponse(cn); err != nil { return nil, err } for { if _, err := writeCopyData(cn.Wr, r); err != nil { if err == io.EOF { break } return nil, err } if err := cn.FlushWriter(); err != nil { return nil, err } } writeCopyDone(cn.Wr) if err := cn.FlushWriter(); err != nil { return nil, err } return readReadyForQuery(cn) } pg-5.3.3/db_test.go000066400000000000000000001007651305650307100141140ustar00rootroot00000000000000package pg_test import ( "bytes" "crypto/tls" "database/sql" "fmt" "net" "testing" "time" "gopkg.in/pg.v5" "gopkg.in/pg.v5/orm" . 
"github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) func init() { //pg.SetLogger(log.New(os.Stderr, "pg: ", log.LstdFlags)) } func TestGinkgo(t *testing.T) { RegisterFailHandler(Fail) RunSpecs(t, "pg") } func pgOptions() *pg.Options { return &pg.Options{ User: "postgres", Database: "postgres", TLSConfig: &tls.Config{ InsecureSkipVerify: true, }, DialTimeout: 30 * time.Second, ReadTimeout: 10 * time.Second, WriteTimeout: 10 * time.Second, PoolSize: 10, PoolTimeout: 30 * time.Second, IdleTimeout: 10 * time.Second, MaxAge: 10 * time.Second, IdleCheckFrequency: 100 * time.Millisecond, } } func TestDBString(t *testing.T) { db := pg.Connect(pgOptions()) wanted := `DB` if db.String() != wanted { t.Fatalf("got %q, wanted %q", db.String(), wanted) } db = db.WithParam("param1", "value1").WithParam("param2", 2) wanted = `DB` if db.String() != wanted { t.Fatalf("got %q, wanted %q", db.String(), wanted) } } var _ = Describe("Time", func() { var tests = []struct { str string wanted time.Time }{ {"0001-01-01 00:00:00+00", time.Time{}}, {"0000-01-01 00:00:00+00", time.Date(0, time.January, 1, 0, 0, 0, 0, time.UTC)}, {"2001-02-03", time.Date(2001, time.February, 3, 0, 0, 0, 0, time.UTC)}, {"2001-02-03 04:05:06", time.Date(2001, time.February, 3, 4, 5, 6, 0, time.Local)}, {"2001-02-03 04:05:06.000001", time.Date(2001, time.February, 3, 4, 5, 6, 1000, time.Local)}, {"2001-02-03 04:05:06.00001", time.Date(2001, time.February, 3, 4, 5, 6, 10000, time.Local)}, {"2001-02-03 04:05:06.0001", time.Date(2001, time.February, 3, 4, 5, 6, 100000, time.Local)}, {"2001-02-03 04:05:06.001", time.Date(2001, time.February, 3, 4, 5, 6, 1000000, time.Local)}, {"2001-02-03 04:05:06.01", time.Date(2001, time.February, 3, 4, 5, 6, 10000000, time.Local)}, {"2001-02-03 04:05:06.1", time.Date(2001, time.February, 3, 4, 5, 6, 100000000, time.Local)}, {"2001-02-03 04:05:06.12", time.Date(2001, time.February, 3, 4, 5, 6, 120000000, time.Local)}, {"2001-02-03 04:05:06.123", time.Date(2001, time.February, 3, 
4, 5, 6, 123000000, time.Local)}, {"2001-02-03 04:05:06.1234", time.Date(2001, time.February, 3, 4, 5, 6, 123400000, time.Local)}, {"2001-02-03 04:05:06.12345", time.Date(2001, time.February, 3, 4, 5, 6, 123450000, time.Local)}, {"2001-02-03 04:05:06.123456", time.Date(2001, time.February, 3, 4, 5, 6, 123456000, time.Local)}, {"2001-02-03 04:05:06.123-07", time.Date(2001, time.February, 3, 4, 5, 6, 123000000, time.FixedZone("", -7*60*60))}, {"2001-02-03 04:05:06-07", time.Date(2001, time.February, 3, 4, 5, 6, 0, time.FixedZone("", -7*60*60))}, {"2001-02-03 04:05:06-07:42", time.Date(2001, time.February, 3, 4, 5, 6, 0, time.FixedZone("", -(7*60*60+42*60)))}, {"2001-02-03 04:05:06-07:30:09", time.Date(2001, time.February, 3, 4, 5, 6, 0, time.FixedZone("", -(7*60*60+30*60+9)))}, {"2001-02-03 04:05:06+07", time.Date(2001, time.February, 3, 4, 5, 6, 0, time.FixedZone("", 7*60*60))}, } var db *pg.DB BeforeEach(func() { db = pg.Connect(pgOptions()) }) AfterEach(func() { Expect(db.Close()).NotTo(HaveOccurred()) }) It("is formatted correctly", func() { for i, test := range tests { var tm time.Time _, err := db.QueryOne(pg.Scan(&tm), "SELECT ?", test.wanted) Expect(err).NotTo(HaveOccurred()) Expect(tm.Unix()).To(Equal(test.wanted.Unix()), "#%d str=%q wanted=%q", i, test.str, test.wanted) } }) It("is parsed correctly", func() { for i, test := range tests { var tm time.Time _, err := db.QueryOne(pg.Scan(&tm), "SELECT ?", test.str) Expect(err).NotTo(HaveOccurred()) Expect(tm.Unix()).To(Equal(test.wanted.Unix()), "#%d str=%q wanted=%q", i, test.str, test.wanted) } }) }) var _ = Describe("slice model", func() { type value struct { Id int } var db *pg.DB BeforeEach(func() { db = pg.Connect(pgOptions()) }) AfterEach(func() { Expect(db.Close()).NotTo(HaveOccurred()) }) It("does not error when there are no rows", func() { var ints []int _, err := db.Query(&ints, "SELECT generate_series(1, 0)") Expect(err).NotTo(HaveOccurred()) Expect(ints).To(BeZero()) }) It("does not error when 
there are no rows", func() { var slice []value _, err := db.Query(&slice, "SELECT generate_series(1, 0)") Expect(err).NotTo(HaveOccurred()) Expect(slice).To(BeZero()) }) It("does not error when there are no rows", func() { var slice []*value _, err := db.Query(&slice, "SELECT generate_series(1, 0)") Expect(err).NotTo(HaveOccurred()) Expect(slice).To(BeZero()) }) It("supports slice of structs", func() { var slice []value _, err := db.Query(&slice, `SELECT generate_series(1, 3) AS id`) Expect(err).NotTo(HaveOccurred()) Expect(slice).To(Equal([]value{{1}, {2}, {3}})) }) It("supports slice of pointers", func() { var slice []*value _, err := db.Query(&slice, `SELECT generate_series(1, 3) AS id`) Expect(err).NotTo(HaveOccurred()) Expect(slice).To(Equal([]*value{{1}, {2}, {3}})) }) It("supports Ints", func() { var ints pg.Ints _, err := db.Query(&ints, `SELECT generate_series(1, 3)`) Expect(err).NotTo(HaveOccurred()) Expect(ints).To(Equal(pg.Ints{1, 2, 3})) }) It("supports slice of ints", func() { var ints []int _, err := db.Query(&ints, `SELECT generate_series(1, 3)`) Expect(err).NotTo(HaveOccurred()) Expect(ints).To(Equal([]int{1, 2, 3})) }) It("supports slice of time.Time", func() { var times []time.Time _, err := db.Query(×, ` WITH data (time) AS (VALUES (clock_timestamp()), (clock_timestamp())) SELECT time FROM data `) Expect(err).NotTo(HaveOccurred()) Expect(times).To(HaveLen(2)) }) It("resets slice", func() { ints := []int{1, 2, 3} _, err := db.Query(&ints, `SELECT 1`) Expect(err).NotTo(HaveOccurred()) Expect(ints).To(Equal([]int{1})) }) It("resets slice when there are no results", func() { ints := []int{1, 2, 3} _, err := db.Query(&ints, `SELECT 1 WHERE FALSE`) Expect(err).NotTo(HaveOccurred()) Expect(ints).To(BeEmpty()) }) }) var _ = Describe("read/write timeout", func() { var db *pg.DB BeforeEach(func() { opt := pgOptions() opt.ReadTimeout = time.Millisecond db = pg.Connect(opt) }) AfterEach(func() { Expect(db.Close()).NotTo(HaveOccurred()) }) It("slow query 
timeouts", func() { _, err := db.Exec(`SELECT pg_sleep(1)`) Expect(err.(net.Error).Timeout()).To(BeTrue()) }) Context("WithTimeout", func() { It("slow query passes", func() { _, err := db.WithTimeout(time.Minute).Exec(`SELECT pg_sleep(1)`) Expect(err).NotTo(HaveOccurred()) }) }) }) var _ = Describe("CopyFrom/CopyTo", func() { const n = 1000000 var db *pg.DB BeforeEach(func() { db = pg.Connect(pgOptions()) qs := []string{ "CREATE TEMP TABLE copy_from(n int)", "CREATE TEMP TABLE copy_to(n int)", fmt.Sprintf("INSERT INTO copy_from SELECT generate_series(1, %d)", n), } for _, q := range qs { _, err := db.Exec(q) Expect(err).NotTo(HaveOccurred()) } }) AfterEach(func() { err := db.Close() Expect(err).NotTo(HaveOccurred()) }) It("copies data from a table and to a table", func() { var buf bytes.Buffer res, err := db.CopyTo(&buf, "COPY copy_from TO STDOUT") Expect(err).NotTo(HaveOccurred()) Expect(res.RowsAffected()).To(Equal(n)) res, err = db.CopyFrom(&buf, "COPY copy_to FROM STDIN") Expect(err).NotTo(HaveOccurred()) Expect(res.RowsAffected()).To(Equal(n)) var count int _, err = db.QueryOne(pg.Scan(&count), "SELECT count(*) FROM copy_to") Expect(err).NotTo(HaveOccurred()) Expect(count).To(Equal(n)) st := db.Pool().Stats() Expect(st.Requests).To(Equal(uint32(6))) Expect(st.Hits).To(Equal(uint32(5))) Expect(st.Timeouts).To(Equal(uint32(0))) Expect(st.TotalConns).To(Equal(uint32(1))) Expect(st.FreeConns).To(Equal(uint32(1))) }) It("copies corrupted data to a table", func() { buf := bytes.NewBufferString("corrupted data") res, err := db.CopyFrom(buf, "COPY copy_to FROM STDIN") Expect(err).To(MatchError(`ERROR #22P02 invalid input syntax for integer: "corrupted data" (addr="127.0.0.1:5432")`)) Expect(res).To(BeNil()) st := db.Pool().Stats() Expect(st.Requests).To(Equal(uint32(4))) Expect(st.Hits).To(Equal(uint32(3))) Expect(st.Timeouts).To(Equal(uint32(0))) Expect(st.TotalConns).To(Equal(uint32(1))) Expect(st.FreeConns).To(Equal(uint32(1))) }) }) var _ = 
Describe("CountEstimate", func() { var db *pg.DB BeforeEach(func() { db = pg.Connect(pgOptions()) }) It("works", func() { count, err := db.Model(). TableExpr("generate_series(1, 10)"). CountEstimate(1000) Expect(err).NotTo(HaveOccurred()) Expect(count).To(Equal(10)) }) It("works when there are no results", func() { count, err := db.Model(). TableExpr("generate_series(1, 0)"). CountEstimate(1000) Expect(err).NotTo(HaveOccurred()) Expect(count).To(Equal(0)) }) It("works with GROUP", func() { count, err := db.Model(). TableExpr("generate_series(1, 10)"). Group("generate_series"). CountEstimate(1000) Expect(err).NotTo(HaveOccurred()) Expect(count).To(Equal(10)) }) It("works with GROUP when there are no results", func() { count, err := db.Model(). TableExpr("generate_series(1, 0)"). Group("generate_series"). CountEstimate(1000) Expect(err).NotTo(HaveOccurred()) Expect(count).To(Equal(0)) }) }) var _ = Describe("DB nulls", func() { var db *pg.DB BeforeEach(func() { db = pg.Connect(pgOptions()) _, err := db.Exec("CREATE TEMP TABLE tests (id int, value int)") Expect(err).To(BeNil()) }) AfterEach(func() { err := db.Close() Expect(err).NotTo(HaveOccurred()) }) Describe("sql.NullInt64", func() { type Test struct { Id int Value sql.NullInt64 } It("inserts null value", func() { ins := Test{ Id: 1, } err := db.Insert(&ins) Expect(err).NotTo(HaveOccurred()) sel := Test{ Id: 1, } err = db.Select(&sel) Expect(err).NotTo(HaveOccurred()) Expect(sel.Value.Valid).To(BeFalse()) }) It("inserts non-null value", func() { ins := Test{ Id: 1, Value: sql.NullInt64{ Int64: 2, Valid: true, }, } err := db.Insert(&ins) Expect(err).NotTo(HaveOccurred()) sel := Test{ Id: 1, } err = db.Select(&sel) Expect(err).NotTo(HaveOccurred()) Expect(sel.Value.Valid).To(BeTrue()) Expect(sel.Value.Int64).To(Equal(int64(2))) }) }) Context("nil ptr", func() { type Test struct { Id int Value *int } It("inserts null value", func() { ins := Test{ Id: 1, } err := db.Insert(&ins) Expect(err).NotTo(HaveOccurred()) sel 
:= Test{ Id: 1, } err = db.Select(&sel) Expect(err).NotTo(HaveOccurred()) Expect(sel.Value).To(BeNil()) }) It("inserts non-null value", func() { value := 2 ins := Test{ Id: 1, Value: &value, } err := db.Insert(&ins) Expect(err).NotTo(HaveOccurred()) sel := Test{ Id: 1, } err = db.Select(&sel) Expect(err).NotTo(HaveOccurred()) Expect(sel.Value).NotTo(BeNil()) Expect(*sel.Value).To(Equal(2)) }) }) }) var _ = Describe("DB.Select", func() { var db *pg.DB BeforeEach(func() { db = pg.Connect(pgOptions()) qs := []string{ `CREATE TEMP TABLE tests (col bytea)`, fmt.Sprintf(`INSERT INTO tests VALUES ('\x%x')`, []byte("bytes")), } for _, q := range qs { _, err := db.Exec(q) Expect(err).NotTo(HaveOccurred()) } }) AfterEach(func() { err := db.Close() Expect(err).NotTo(HaveOccurred()) }) It("selects bytea", func() { var col []byte err := db.Model().Table("tests").Column("col").Select(pg.Scan(&col)) Expect(err).NotTo(HaveOccurred()) }) }) var _ = Describe("DB.Insert", func() { var db *pg.DB BeforeEach(func() { db = pg.Connect(pgOptions()) }) AfterEach(func() { err := db.Close() Expect(err).NotTo(HaveOccurred()) }) It("returns an error on nil", func() { err := db.Insert(nil) Expect(err).To(MatchError("pg: Model(nil)")) }) It("returns an errors if value is not settable", func() { err := db.Insert(1) Expect(err).To(MatchError("pg: Model(non-pointer int)")) }) It("returns an errors if value is not supported", func() { var v int err := db.Insert(&v) Expect(err).To(MatchError("pg: Model(unsupported int)")) }) }) var _ = Describe("DB.Update", func() { var db *pg.DB BeforeEach(func() { db = pg.Connect(pgOptions()) }) AfterEach(func() { err := db.Close() Expect(err).NotTo(HaveOccurred()) }) It("returns an error on nil", func() { err := db.Update(nil) Expect(err).To(MatchError("pg: Model(nil)")) }) It("returns an error if there are no pks", func() { type Test struct{} var test Test err := db.Update(&test) Expect(err).To(MatchError(`model=Test does not have primary keys`)) }) }) var _ = 
Describe("DB.Delete", func() { var db *pg.DB BeforeEach(func() { db = pg.Connect(pgOptions()) }) AfterEach(func() { err := db.Close() Expect(err).NotTo(HaveOccurred()) }) It("returns an error on nil", func() { err := db.Delete(nil) Expect(err).To(MatchError("pg: Model(nil)")) }) It("returns an error if there are no pks", func() { type Test struct{} var test Test err := db.Delete(&test) Expect(err).To(MatchError(`model=Test does not have primary keys`)) }) }) var _ = Describe("errors", func() { var db *pg.DB BeforeEach(func() { db = pg.Connect(pgOptions()) }) AfterEach(func() { err := db.Close() Expect(err).NotTo(HaveOccurred()) }) It("unknown column error", func() { type Test struct { Col1 int } var test Test _, err := db.QueryOne(&test, "SELECT 1 AS col1, 2 AS col2") Expect(err).To(MatchError("pg: can't find column=col2 in model=Test")) Expect(test.Col1).To(Equal(1)) }) It("Scan error", func() { var n1 int _, err := db.QueryOne(pg.Scan(&n1), "SELECT 1, 2") Expect(err).To(MatchError("pg: no Scan value for column index=1 name=?column?")) Expect(n1).To(Equal(1)) }) }) type Genre struct { // tableName is an optional field that specifies custom table name and alias. // By default go-pg generates table name and alias from struct name. 
tableName struct{} `sql:"genres,alias:genre"` // default values are the same Id int // Id is automatically detected as primary key Name string Rating int `sql:"-"` // - is used to ignore field Books []Book `pg:",many2many:book_genres"` // many to many relation ParentId int Subgenres []Genre `pg:",fk:Parent"` // fk specifies prefix for foreign key (ParentId) } func (g Genre) String() string { return fmt.Sprintf("Genre", g.Id, g.Name) } type Author struct { ID int // both "Id" and "ID" are detected as primary key Name string Books []*Book // has many relation } func (a Author) String() string { return fmt.Sprintf("Author", a.ID, a.Name) } type BookGenre struct { tableName struct{} `sql:",alias:bg"` // custom table alias BookId int `sql:",pk"` // pk tag is used to mark field as primary key GenreId int `sql:",pk"` Genre_Rating int // belongs to and is copied to Genre model } type Book struct { Id int Title string AuthorID int Author *Author // has one relation EditorID int Editor *Author // has one relation CreatedAt time.Time Genres []Genre `pg:",many2many:book_genres" gorm:"many2many:book_genres;"` // many to many relation Translations []Translation // has many relation Comments []Comment `pg:",polymorphic:Trackable"` // has many polymorphic relation } func (b Book) String() string { return fmt.Sprintf("Book", b.Id, b.Title) } func (b *Book) BeforeInsert(db orm.DB) error { if b.CreatedAt.IsZero() { b.CreatedAt = time.Now() } return nil } // BookWithCommentCount is like Book model, but has additional CommentCount // field that is used to select data into it. The use of `pg:",override"` tag // is essential here and it overrides internal model properties such as table name. 
type BookWithCommentCount struct { Book `pg:",override"` CommentCount int } type Translation struct { tableName struct{} `sql:",alias:tr"` // custom table alias Id int BookId int Book *Book // belongs to relation Lang string Comments []Comment `pg:",polymorphic:Trackable"` // has many polymorphic relation } type Comment struct { TrackableId int // Book.Id or Translation.Id TrackableType string // "Book" or "Translation" Text string } func createTestSchema(db *pg.DB) error { sql := []string{ `DROP TABLE IF EXISTS comments`, `DROP TABLE IF EXISTS translations`, `DROP TABLE IF EXISTS authors`, `DROP TABLE IF EXISTS books`, `DROP TABLE IF EXISTS genres`, `DROP TABLE IF EXISTS book_genres`, } for _, q := range sql { _, err := db.Exec(q) if err != nil { return err } } tables := []interface{}{ &Author{}, &Book{}, &Genre{}, &BookGenre{}, &Translation{}, &Comment{}, } for _, table := range tables { err := db.CreateTable(table, nil) if err != nil { return err } } _, err := db.Exec(`CREATE UNIQUE INDEX authors_name ON authors (name)`) return err } var _ = Describe("ORM", func() { var db *pg.DB BeforeEach(func() { db = pg.Connect(pgOptions()) err := createTestSchema(db) Expect(err).NotTo(HaveOccurred()) genres := []Genre{{ Id: 1, Name: "genre 1", }, { Id: 2, Name: "genre 2", }, { Id: 3, Name: "subgenre 1", ParentId: 1, }, { Id: 4, Name: "subgenre 2", ParentId: 1, }} err = db.Insert(&genres) Expect(err).NotTo(HaveOccurred()) Expect(genres).To(HaveLen(4)) authors := []Author{{ ID: 10, Name: "author 1", }, { ID: 11, Name: "author 2", }, Author{ ID: 12, Name: "author 3", }} err = db.Insert(&authors) Expect(err).NotTo(HaveOccurred()) Expect(authors).To(HaveLen(3)) books := []Book{{ Id: 100, Title: "book 1", AuthorID: 10, EditorID: 11, }, { Id: 101, Title: "book 2", AuthorID: 10, EditorID: 12, }, Book{ Id: 102, Title: "book 3", AuthorID: 11, EditorID: 11, }} err = db.Insert(&books) Expect(err).NotTo(HaveOccurred()) Expect(books).To(HaveLen(3)) for _, book := range books { 
Expect(book.CreatedAt).To(BeTemporally("~", time.Now(), time.Second)) } bookGenres := []BookGenre{{ BookId: 100, GenreId: 1, Genre_Rating: 999, }, { BookId: 100, GenreId: 2, Genre_Rating: 9999, }, { BookId: 101, GenreId: 1, Genre_Rating: 99999, }} err = db.Insert(&bookGenres) Expect(err).NotTo(HaveOccurred()) Expect(bookGenres).To(HaveLen(3)) translations := []Translation{{ Id: 1000, BookId: 100, Lang: "ru", }, { Id: 1001, BookId: 100, Lang: "md", }, { Id: 1002, BookId: 101, Lang: "ua", }} err = db.Insert(&translations) Expect(err).NotTo(HaveOccurred()) Expect(translations).To(HaveLen(3)) comments := []Comment{{ TrackableId: 100, TrackableType: "Book", Text: "comment1", }, { TrackableId: 100, TrackableType: "Book", Text: "comment2", }, { TrackableId: 1000, TrackableType: "Translation", Text: "comment3", }} err = db.Insert(&comments) Expect(err).NotTo(HaveOccurred()) Expect(comments).To(HaveLen(3)) }) Describe("struct model", func() { It("supports HasOne, HasMany, HasMany2Many, Polymorphic, HasMany -> Polymorphic", func() { var book Book err := db.Model(&book). Column("book.id", "Author", "Editor", "Genres", "Comments", "Translations", "Translations.Comments"). First() Expect(err).NotTo(HaveOccurred()) Expect(book).To(Equal(Book{ Id: 100, Title: "", Author: &Author{ID: 10, Name: "author 1", Books: nil}, Editor: &Author{ID: 11, Name: "author 2", Books: nil}, CreatedAt: time.Time{}, Genres: []Genre{ {Id: 1, Name: "genre 1", Rating: 999}, {Id: 2, Name: "genre 2", Rating: 9999}, }, Translations: []Translation{{ Id: 1000, BookId: 100, Lang: "ru", Comments: []Comment{ {TrackableId: 1000, TrackableType: "Translation", Text: "comment3"}, }, }, { Id: 1001, BookId: 100, Lang: "md", Comments: nil, }}, Comments: []Comment{ {TrackableId: 100, TrackableType: "Book", Text: "comment1"}, {TrackableId: 100, TrackableType: "Book", Text: "comment2"}, }, })) }) It("supports HasMany -> HasOne, HasMany -> HasMany", func() { var author Author err := db.Model(&author). 
Column( "author.*", "Books.id", "Books.author_id", "Books.editor_id", "Books.Author", "Books.Editor", "Books.Translations", ). First() Expect(err).NotTo(HaveOccurred()) Expect(author).To(Equal(Author{ ID: 10, Name: "author 1", Books: []*Book{{ Id: 100, Title: "", AuthorID: 10, Author: &Author{ID: 10, Name: "author 1", Books: nil}, EditorID: 11, Editor: &Author{ID: 11, Name: "author 2", Books: nil}, CreatedAt: time.Time{}, Genres: nil, Translations: []Translation{ {Id: 1000, BookId: 100, Book: nil, Lang: "ru", Comments: nil}, {Id: 1001, BookId: 100, Book: nil, Lang: "md", Comments: nil}, }, }, { Id: 101, Title: "", AuthorID: 10, Author: &Author{ID: 10, Name: "author 1", Books: nil}, EditorID: 12, Editor: &Author{ID: 12, Name: "author 3", Books: nil}, CreatedAt: time.Time{}, Genres: nil, Translations: []Translation{ {Id: 1002, BookId: 101, Book: nil, Lang: "ua", Comments: nil}, }, }}, })) }) It("supports HasMany -> HasMany -> HasMany", func() { var genre Genre err := db.Model(&genre). Column("genre.*", "Books.id", "Books.Translations"). First() Expect(err).NotTo(HaveOccurred()) Expect(genre).To(Equal(Genre{ Id: 1, Name: "genre 1", Rating: 0, Books: []Book{{ Id: 100, Title: "", AuthorID: 0, Author: nil, EditorID: 0, Editor: nil, CreatedAt: time.Time{}, Genres: nil, Translations: []Translation{ {Id: 1000, BookId: 100, Book: nil, Lang: "ru", Comments: nil}, {Id: 1001, BookId: 100, Book: nil, Lang: "md", Comments: nil}, }, Comments: nil, }, { Id: 101, Title: "", AuthorID: 0, Author: nil, EditorID: 0, Editor: nil, CreatedAt: time.Time{}, Genres: nil, Translations: []Translation{ {Id: 1002, BookId: 101, Book: nil, Lang: "ua", Comments: nil}, }, Comments: nil, }}, ParentId: 0, Subgenres: nil, })) }) It("supports HasOne -> HasOne", func() { var translation Translation err := db.Model(&translation). Column("tr.*", "Book.id", "Book.Author", "Book.Editor"). 
First() Expect(err).NotTo(HaveOccurred()) Expect(translation).To(Equal(Translation{ Id: 1000, BookId: 100, Book: &Book{ Id: 100, Author: &Author{ID: 10, Name: "author 1"}, Editor: &Author{ID: 11, Name: "author 2"}, }, Lang: "ru", })) }) It("works when there are no results", func() { var book Book err := db.Model(&book). Column("book.*", "Author", "Genres", "Comments"). Where("1 = 2"). Select() Expect(err).To(Equal(pg.ErrNoRows)) }) It("supports overriding", func() { var book BookWithCommentCount err := db.Model(&book). Column("book.id", "Author", "Genres"). ColumnExpr(`(SELECT COUNT(*) FROM comments WHERE trackable_type = 'Book' AND trackable_id = book.id) AS comment_count`). First() Expect(err).NotTo(HaveOccurred()) Expect(book).To(Equal(BookWithCommentCount{ Book: Book{ Id: 100, Author: &Author{ID: 10, Name: "author 1"}, Genres: []Genre{ {Id: 1, Name: "genre 1", Rating: 999}, {Id: 2, Name: "genre 2", Rating: 9999}, }, }, CommentCount: 2, })) }) }) Describe("slice model", func() { It("supports HasOne, HasMany, HasMany2Many", func() { var books []Book err := db.Model(&books). Column("book.id", "Author", "Editor", "Translations", "Genres"). OrderExpr("book.id ASC"). 
Select() Expect(err).NotTo(HaveOccurred()) Expect(books).To(Equal([]Book{{ Id: 100, Title: "", AuthorID: 0, Author: &Author{ID: 10, Name: "author 1", Books: nil}, EditorID: 0, Editor: &Author{ID: 11, Name: "author 2", Books: nil}, CreatedAt: time.Time{}, Genres: []Genre{ {Id: 1, Name: "genre 1", Rating: 999, Books: nil, ParentId: 0, Subgenres: nil}, {Id: 2, Name: "genre 2", Rating: 9999, Books: nil, ParentId: 0, Subgenres: nil}, }, Translations: []Translation{ {Id: 1000, BookId: 100, Book: nil, Lang: "ru", Comments: nil}, {Id: 1001, BookId: 100, Book: nil, Lang: "md", Comments: nil}, }, Comments: nil, }, { Id: 101, Title: "", AuthorID: 0, Author: &Author{ID: 10, Name: "author 1", Books: nil}, EditorID: 0, Editor: &Author{ID: 12, Name: "author 3", Books: nil}, CreatedAt: time.Time{}, Genres: []Genre{ {Id: 1, Name: "genre 1", Rating: 99999, Books: nil, ParentId: 0, Subgenres: nil}, }, Translations: []Translation{ {Id: 1002, BookId: 101, Book: nil, Lang: "ua", Comments: nil}, }, Comments: nil, }, { Id: 102, Title: "", AuthorID: 0, Author: &Author{ID: 11, Name: "author 2", Books: nil}, EditorID: 0, Editor: &Author{ID: 11, Name: "author 2", Books: nil}, CreatedAt: time.Time{}, Genres: nil, Translations: nil, Comments: nil, }})) }) It("supports HasMany2Many, HasMany2Many -> HasMany", func() { var genres []Genre err := db.Model(&genres). Column("genre.*", "Subgenres", "Books.id", "Books.Translations"). Where("genre.parent_id IS NULL"). OrderExpr("genre.id"). 
Select() Expect(err).NotTo(HaveOccurred()) Expect(genres).To(Equal([]Genre{{ Id: 1, Name: "genre 1", Rating: 0, Books: []Book{{ Id: 100, Title: "", AuthorID: 0, Author: nil, EditorID: 0, Editor: nil, CreatedAt: time.Time{}, Genres: nil, Translations: []Translation{ {Id: 1000, BookId: 100, Book: nil, Lang: "ru", Comments: nil}, {Id: 1001, BookId: 100, Book: nil, Lang: "md", Comments: nil}, }, Comments: nil, }, { Id: 101, Title: "", AuthorID: 0, Author: nil, EditorID: 0, Editor: nil, CreatedAt: time.Time{}, Genres: nil, Translations: []Translation{ {Id: 1002, BookId: 101, Book: nil, Lang: "ua", Comments: nil}, }, Comments: nil, }}, ParentId: 0, Subgenres: []Genre{ {Id: 3, Name: "subgenre 1", Rating: 0, Books: nil, ParentId: 1, Subgenres: nil}, {Id: 4, Name: "subgenre 2", Rating: 0, Books: nil, ParentId: 1, Subgenres: nil}, }, }, { Id: 2, Name: "genre 2", Rating: 0, Books: []Book{{ Id: 100, Title: "", AuthorID: 0, Author: nil, EditorID: 0, Editor: nil, CreatedAt: time.Time{}, Genres: nil, Translations: []Translation{ {Id: 1000, BookId: 100, Book: nil, Lang: "ru", Comments: nil}, {Id: 1001, BookId: 100, Book: nil, Lang: "md", Comments: nil}, }, Comments: nil, }}, ParentId: 0, Subgenres: nil, }, })) }) It("supports HasOne -> HasOne", func() { var translations []Translation err := db.Model(&translations). Column("tr.*", "Book.id", "Book.Author", "Book.Editor"). 
Select() Expect(err).NotTo(HaveOccurred()) Expect(translations).To(Equal([]Translation{{ Id: 1000, BookId: 100, Book: &Book{ Id: 100, Author: &Author{ID: 10, Name: "author 1"}, Editor: &Author{ID: 11, Name: "author 2"}, }, Lang: "ru", }, { Id: 1001, BookId: 100, Book: &Book{ Id: 100, Author: &Author{ID: 10, Name: "author 1"}, Editor: &Author{ID: 11, Name: "author 2"}, }, Lang: "md", }, { Id: 1002, BookId: 101, Book: &Book{ Id: 101, Author: &Author{ID: 10, Name: "author 1", Books: nil}, Editor: &Author{ID: 12, Name: "author 3", Books: nil}, }, Lang: "ua", }})) }) It("works when there are no results", func() { var books []Book err := db.Model(&books). Column("book.*", "Author", "Genres", "Comments"). Where("1 = 2"). Select() Expect(err).NotTo(HaveOccurred()) Expect(books).To(BeNil()) }) It("supports overriding", func() { var books []BookWithCommentCount err := db.Model(&books). Column("book.id", "Author", "Genres"). ColumnExpr(`(SELECT COUNT(*) FROM comments WHERE trackable_type = 'Book' AND trackable_id = book.id) AS comment_count`). OrderExpr("id ASC"). Select() Expect(err).NotTo(HaveOccurred()) Expect(books).To(Equal([]BookWithCommentCount{{ Book: Book{ Id: 100, Author: &Author{ID: 10, Name: "author 1", Books: nil}, Genres: []Genre{ {Id: 1, Name: "genre 1", Rating: 999}, {Id: 2, Name: "genre 2", Rating: 9999}, }, }, CommentCount: 2, }, { Book: Book{ Id: 101, Author: &Author{ID: 10, Name: "author 1", Books: nil}, Genres: []Genre{ {Id: 1, Name: "genre 1", Rating: 99999}, }, }, CommentCount: 0, }, { Book: Book{ Id: 102, Author: &Author{ID: 11, Name: "author 2", Books: nil}, }, CommentCount: 0, }})) }) }) Describe("slice of ptrs model", func() { It("supports HasOne, HasMany, HasMany2Many", func() { var books []*Book err := db.Model(&books). Column("book.id", "Author", "Editor", "Translations", "Genres"). OrderExpr("book.id ASC"). 
Select() Expect(err).NotTo(HaveOccurred()) Expect(books).To(HaveLen(3)) }) It("supports HasMany2Many, HasMany2Many -> HasMany", func() { var genres []*Genre err := db.Model(&genres). Column("genre.*", "Subgenres", "Books.id", "Books.Translations"). Where("genre.parent_id IS NULL"). OrderExpr("genre.id"). Select() Expect(err).NotTo(HaveOccurred()) Expect(genres).To(HaveLen(2)) }) It("supports HasOne -> HasOne", func() { var translations []*Translation err := db.Model(&translations). Column("tr.*", "Book.id", "Book.Author", "Book.Editor"). Select() Expect(err).NotTo(HaveOccurred()) Expect(translations).To(HaveLen(3)) }) It("works when there are no results", func() { var books []*Book err := db.Model(&books). Column("book.*", "Author", "Genres", "Comments"). Where("1 = 2"). Select() Expect(err).NotTo(HaveOccurred()) Expect(books).To(BeNil()) }) It("supports overriding", func() { var books []*BookWithCommentCount err := db.Model(&books). Column("book.id", "Author"). ColumnExpr(`(SELECT COUNT(*) FROM comments WHERE trackable_type = 'Book' AND trackable_id = book.id) AS comment_count`). OrderExpr("id ASC"). Select() Expect(err).NotTo(HaveOccurred()) Expect(books).To(HaveLen(3)) }) }) It("filters by HasOne", func() { var books []Book err := db.Model(&books). Column("book.id", "Author._"). Where("author.id = 10"). OrderExpr("book.id ASC"). Select() Expect(err).NotTo(HaveOccurred()) Expect(books).To(Equal([]Book{{ Id: 100, Title: "", AuthorID: 0, Author: nil, EditorID: 0, Editor: nil, CreatedAt: time.Time{}, Genres: nil, Translations: nil, Comments: nil, }, { Id: 101, Title: "", AuthorID: 0, Author: nil, EditorID: 0, Editor: nil, CreatedAt: time.Time{}, Genres: nil, Translations: nil, Comments: nil, }})) }) It("supports filtering HasMany", func() { var book Book err := db.Model(&book). Column("book.id", "Translations"). Relation("Translations", func(q *orm.Query) (*orm.Query, error) { return q.Where("lang = 'ru'"), nil }). 
First() Expect(err).NotTo(HaveOccurred()) Expect(book).To(Equal(Book{ Id: 100, Translations: []Translation{ {Id: 1000, BookId: 100, Lang: "ru"}, }, })) }) It("supports filtering HasMany2Many", func() { var book Book err := db.Model(&book). Column("book.id", "Genres"). Relation("Genres", func(q *orm.Query) (*orm.Query, error) { return q.Where("genre__rating > 999"), nil }). First() Expect(err).NotTo(HaveOccurred()) Expect(book).To(Equal(Book{ Id: 100, Genres: []Genre{ {Id: 2, Name: "genre 2", Rating: 9999}, }, })) }) }) pg-5.3.3/doc.go000066400000000000000000000001101305650307100132140ustar00rootroot00000000000000/* Package gopkg.in/pg.v5 implements a PostgreSQL client. */ package pg pg-5.3.3/error.go000066400000000000000000000021021305650307100136030ustar00rootroot00000000000000package pg import ( "io" "net" "gopkg.in/pg.v5/internal" ) var ( ErrNoRows = internal.ErrNoRows ErrMultiRows = internal.ErrMultiRows errSSLNotSupported = internal.Errorf("pg: SSL is not enabled on the server") errClosed = internal.Errorf("pg: database is closed") errTxDone = internal.Errorf("pg: transaction has already been committed or rolled back") errStmtClosed = internal.Errorf("pg: statement is closed") errListenerClosed = internal.Errorf("pg: listener is closed") ) type Error interface { Field(byte) string IntegrityViolation() bool } var _ Error = (*internal.PGError)(nil) func isBadConn(err error, allowTimeout bool) bool { if err == nil { return false } if _, ok := err.(internal.Error); ok { return false } if pgErr, ok := err.(Error); ok && pgErr.Field('S') != "FATAL" { return false } if allowTimeout { if netErr, ok := err.(net.Error); ok && netErr.Timeout() { return false } } return true } func isNetworkError(err error) bool { if err == io.EOF { return true } _, ok := err.(net.Error) return ok } pg-5.3.3/example_array_test.go000066400000000000000000000020661305650307100163530ustar00rootroot00000000000000package pg_test import ( "fmt" "gopkg.in/pg.v5" ) func 
ExampleDB_Model_postgresArrayStructTag() { type Item struct { Id int64 Emails []string `pg:",array"` // marshalled as PostgreSQL array Numbers [][]int `pg:",array"` // marshalled as PostgreSQL array } _, err := db.Exec(`CREATE TEMP TABLE items (id serial, emails text[], numbers int[][])`) if err != nil { panic(err) } defer db.Exec("DROP TABLE items") item1 := Item{ Id: 1, Emails: []string{"one@example.com", "two@example.com"}, Numbers: [][]int{{1, 2}, {3, 4}}, } if err := db.Insert(&item1); err != nil { panic(err) } var item Item err = db.Model(&item).Where("id = ?", 1).Select() if err != nil { panic(err) } fmt.Println(item) // Output: {1 [one@example.com two@example.com] [[1 2] [3 4]]} } func ExampleArray() { src := []string{"one@example.com", "two@example.com"} var dst []string _, err := db.QueryOne(pg.Scan(pg.Array(&dst)), `SELECT ?`, pg.Array(src)) if err != nil { panic(err) } fmt.Println(dst) // Output: [one@example.com two@example.com] } pg-5.3.3/example_hstore_test.go000066400000000000000000000015771305650307100165470ustar00rootroot00000000000000package pg_test import ( "fmt" "gopkg.in/pg.v5" ) func ExampleDB_Model_hstoreStructTag() { type Item struct { Id int64 Attrs map[string]string `pg:",hstore"` // marshalled as PostgreSQL hstore } _, err := db.Exec(`CREATE TEMP TABLE items (id serial, attrs hstore)`) if err != nil { panic(err) } defer db.Exec("DROP TABLE items") item1 := Item{ Id: 1, Attrs: map[string]string{"hello": "world"}, } if err := db.Insert(&item1); err != nil { panic(err) } var item Item err = db.Model(&item).Where("id = ?", 1).Select() if err != nil { panic(err) } fmt.Println(item) // Output: {1 map[hello:world]} } func ExampleHstore() { src := map[string]string{"hello": "world"} var dst map[string]string _, err := db.QueryOne(pg.Scan(pg.Hstore(&dst)), `SELECT ?`, pg.Hstore(src)) if err != nil { panic(err) } fmt.Println(dst) // Output: map[hello:world] } 
pg-5.3.3/example_model_test.go000066400000000000000000000372121305650307100163360ustar00rootroot00000000000000package pg_test import ( "database/sql" "fmt" "time" "gopkg.in/pg.v5" "gopkg.in/pg.v5/orm" ) func modelDB() *pg.DB { db := pg.Connect(&pg.Options{ User: "postgres", }) err := createTestSchema(db) if err != nil { panic(err) } err = db.Insert(&Author{ Name: "author 1", }) books := []Book{{ Title: "book 1", AuthorID: 1, EditorID: 11, }, { Title: "book 2", AuthorID: 1, EditorID: 12, }, { Title: "book 3", AuthorID: 11, EditorID: 11, CreatedAt: time.Now(), }} err = db.Insert(&books) if err != nil { panic(err) } for i := 0; i < 2; i++ { genre := Genre{ Name: fmt.Sprintf("genre %d", i+1), } err = db.Insert(&genre) if err != nil { panic(err) } err = db.Insert(&BookGenre{ BookId: 1, GenreId: genre.Id, }) if err != nil { panic(err) } } // For CountEstimate. _, err = db.Exec("VACUUM") if err != nil { panic(err) } return db } func ExampleDB_Insert() { db := modelDB() book := Book{ Title: "new book", AuthorID: 1, } err := db.Insert(&book) if err != nil { panic(err) } fmt.Println(book) // Output: Book err = db.Delete(&book) if err != nil { panic(err) } } func ExampleDB_Insert_bulkInsert() { db := modelDB() book1 := Book{ Title: "new book 1", } book2 := Book{ Title: "new book 2", } err := db.Insert(&book1, &book2) if err != nil { panic(err) } fmt.Println(book1, book2) // Output: Book Book for _, book := range []*Book{&book1, &book2} { err := db.Delete(book) if err != nil { panic(err) } } } func ExampleDB_Insert_bulkInsert2() { db := modelDB() books := []Book{{ Title: "new book 1", }, { Title: "new book 2", }} err := db.Insert(&books) if err != nil { panic(err) } fmt.Println(books) // Output: [Book Book] for i := range books { err := db.Delete(&books[i]) if err != nil { panic(err) } } } func ExampleDB_Insert_onConflictDoNothing() { db := modelDB() book := Book{ Id: 100, Title: "book 100", } for i := 0; i < 2; i++ { res, err := db.Model(&book).OnConflict("DO 
NOTHING").Insert() if err != nil { panic(err) } if res.RowsAffected() > 0 { fmt.Println("created") } else { fmt.Println("did nothing") } } err := db.Delete(&book) if err != nil { panic(err) } // Output: created // did nothing } func ExampleDB_Insert_onConflictDoUpdate() { db := modelDB() var book *Book for i := 0; i < 2; i++ { book = &Book{ Id: 100, Title: fmt.Sprintf("title version #%d", i), } _, err := db.Model(book). OnConflict("(id) DO UPDATE"). Set("title = ?title"). Insert() if err != nil { panic(err) } err = db.Select(book) if err != nil { panic(err) } fmt.Println(book) } err := db.Delete(book) if err != nil { panic(err) } // Output: Book // Book } func ExampleDB_Insert_selectOrInsert() { db := modelDB() author := Author{ Name: "R. Scott Bakker", } created, err := db.Model(&author). Column("id"). Where("name = ?name"). OnConflict("DO NOTHING"). // OnConflict is optional Returning("id"). SelectOrInsert() if err != nil { panic(err) } fmt.Println(created, author) // Output: true Author } func ExampleDB_Select() { db := modelDB() book := Book{ Id: 1, } err := db.Select(&book) if err != nil { panic(err) } fmt.Println(book) // Output: Book } func ExampleDB_Select_firstRow() { db := modelDB() var firstBook Book err := db.Model(&firstBook).First() if err != nil { panic(err) } fmt.Println(firstBook) // Output: Book } func ExampleDB_Select_lastRow() { db := modelDB() var lastBook Book err := db.Model(&lastBook).Last() if err != nil { panic(err) } fmt.Println(lastBook) // Output: Book } func ExampleDB_Select_allColumns() { db := modelDB() var book Book err := db.Model(&book).Column("book.*").First() if err != nil { panic(err) } fmt.Println(book, book.AuthorID) // Output: Book 1 } func ExampleDB_Select_someColumns() { db := modelDB() var book Book err := db.Model(&book). Column("book.id", "book.title"). OrderExpr("book.id ASC"). Limit(1). 
Select() if err != nil { panic(err) } fmt.Println(book) // Output: Book } func ExampleDB_Select_someColumnsIntoVars() { db := modelDB() var id int var title string err := db.Model(&Book{}). Column("book.id", "book.title"). OrderExpr("book.id ASC"). Limit(1). Select(&id, &title) if err != nil { panic(err) } fmt.Println(id, title) // Output: 1 book 1 } func ExampleDB_Select_whereIn() { db := modelDB() var books []Book err := db.Model(&books).WhereIn("id IN (?)", 1, 2).Select() if err != nil { panic(err) } fmt.Println(books) // Output: [Book Book] } func ExampleDB_Select_sqlExpression() { db := modelDB() var ids []int err := db.Model(&Book{}). ColumnExpr("array_agg(book.id)"). Select(pg.Array(&ids)) if err != nil { panic(err) } fmt.Println(ids) // Output: [1 2 3] } func ExampleDB_Select_groupBy() { db := modelDB() var res []struct { AuthorId int BookCount int } err := db.Model(&Book{}). Column("author_id"). ColumnExpr("count(*) AS book_count"). Group("author_id"). OrderExpr("book_count DESC"). Select(&res) if err != nil { panic(err) } fmt.Println("len", len(res)) fmt.Printf("author %d has %d books\n", res[0].AuthorId, res[0].BookCount) fmt.Printf("author %d has %d books\n", res[1].AuthorId, res[1].BookCount) // Output: len 2 // author 1 has 2 books // author 11 has 1 books } func ExampleDB_Select_with() { authorBooks := db.Model(&Book{}).Where("author_id = ?", 1) var books []Book err := db.Model(). With("author_books", authorBooks). Table("author_books"). Select(&books) if err != nil { panic(err) } fmt.Println(books) // Output: [Book Book] } func ExampleDB_Select_wrapWith() { // WITH author_books AS ( // SELECT * books WHERE author_id = 1 // ) // SELECT * FROM author_books var books []Book err := db.Model(&books). Where("author_id = ?", 1). WrapWith("author_books"). Table("author_books"). 
Select(&books) if err != nil { panic(err) } fmt.Println(books) // Output: [Book Book] } func ExampleDB_Select_applyFunc() { db := modelDB() var authorId int var editorId int filter := func(q *orm.Query) (*orm.Query, error) { if authorId != 0 { q = q.Where("author_id = ?", authorId) } if editorId != 0 { q = q.Where("editor_id = ?", editorId) } return q, nil } var books []Book authorId = 1 err := db.Model(&books). Apply(filter). Select() if err != nil { panic(err) } fmt.Println(books) // Output: [Book Book] } func ExampleDB_Model_count() { db := modelDB() count, err := db.Model(&Book{}).Count() if err != nil { panic(err) } fmt.Println(count) // Output: 3 } func ExampleDB_Model_countEstimate() { db := modelDB() count, err := db.Model(&Book{}).CountEstimate(0) if err != nil { panic(err) } fmt.Println(count) // Output: 3 } func ExampleDB_Model_selectAndCount() { db := modelDB() var books []Book count, err := db.Model(&books).OrderExpr("id ASC").Limit(2).SelectAndCount() if err != nil { panic(err) } fmt.Println(count) fmt.Println(books) // Output: 3 // [Book Book] } func ExampleDB_Model_nullEmptyValue() { type Example struct { Hello string } var str sql.NullString _, err := db.QueryOne(pg.Scan(&str), "SELECT ?hello", &Example{Hello: ""}) if err != nil { panic(err) } fmt.Println(str.Valid) // Output: false } func ExampleDB_Model_hasOne() { type Profile struct { Id int Lang string } // User has one profile. 
type User struct { Id int Name string ProfileId int Profile *Profile } db := connect() defer db.Close() qs := []string{ "CREATE TEMP TABLE users (id int, name text, profile_id int)", "CREATE TEMP TABLE profiles (id int, lang text)", "INSERT INTO users VALUES (1, 'user 1', 1), (2, 'user 2', 2)", "INSERT INTO profiles VALUES (1, 'en'), (2, 'ru')", } for _, q := range qs { _, err := db.Exec(q) if err != nil { panic(err) } } // Select users joining their profiles with following query: // // SELECT // "user".*, // "profile"."id" AS "profile__id", // "profile"."lang" AS "profile__lang", // "profile"."user_id" AS "profile__user_id" // FROM "users" AS "user" // LEFT JOIN "profiles" AS "profile" ON "profile"."user_id" = "user"."id" var users []User err := db.Model(&users). Column("user.*", "Profile"). Select() if err != nil { panic(err) } fmt.Println(len(users), "results") fmt.Println(users[0].Id, users[0].Name, users[0].Profile) fmt.Println(users[1].Id, users[1].Name, users[1].Profile) // Output: 2 results // 1 user 1 &{1 en} // 2 user 2 &{2 ru} } func ExampleDB_Model_belongsTo() { // Profile belongs to User. type Profile struct { Id int Lang string UserId int } type User struct { Id int Name string Profile *Profile } db := connect() defer db.Close() qs := []string{ "CREATE TEMP TABLE users (id int, name text)", "CREATE TEMP TABLE profiles (id int, lang text, user_id int)", "INSERT INTO users VALUES (1, 'user 1'), (2, 'user 2')", "INSERT INTO profiles VALUES (1, 'en', 1), (2, 'ru', 2)", } for _, q := range qs { _, err := db.Exec(q) if err != nil { panic(err) } } // Select users joining their profiles with following query: // // SELECT // "user".*, // "profile"."id" AS "profile__id", // "profile"."lang" AS "profile__lang" // FROM "users" AS "user" // LEFT JOIN "profiles" AS "profile" ON "profile"."id" = "user"."profile_id" var users []User err := db.Model(&users). Column("user.*", "Profile"). 
Select() if err != nil { panic(err) } fmt.Println(len(users), "results") fmt.Println(users[0].Id, users[0].Name, users[0].Profile) fmt.Println(users[1].Id, users[1].Name, users[1].Profile) // Output: 2 results // 1 user 1 &{1 en 1} // 2 user 2 &{2 ru 2} } func ExampleDB_Model_hasMany() { type Profile struct { Id int Lang string Active bool UserId int } // User has many profiles. type User struct { Id int Name string Profiles []*Profile } db := connect() defer db.Close() qs := []string{ "CREATE TEMP TABLE users (id int, name text)", "CREATE TEMP TABLE profiles (id int, lang text, active bool, user_id int)", "INSERT INTO users VALUES (1, 'user 1')", "INSERT INTO profiles VALUES (1, 'en', TRUE, 1), (2, 'ru', TRUE, 1), (3, 'md', FALSE, 1)", } for _, q := range qs { _, err := db.Exec(q) if err != nil { panic(err) } } // Select user and all his active profiles with following queries: // // SELECT "user".* FROM "users" AS "user" ORDER BY "user"."id" LIMIT 1 // // SELECT "profile".* FROM "profiles" AS "profile" // WHERE (active IS TRUE) AND (("profile"."user_id") IN ((1))) var user User err := db.Model(&user). Column("user.*", "Profiles"). Relation("Profiles", func(q *orm.Query) (*orm.Query, error) { return q.Where("active IS TRUE"), nil }). 
First() if err != nil { panic(err) } fmt.Println(user.Id, user.Name, user.Profiles[0], user.Profiles[1]) // Output: 1 user 1 &{1 en true 1} &{2 ru true 1} } func ExampleDB_Model_hasManySelf() { type Item struct { Id int Items []Item `pg:",fk:Parent"` ParentId int } db := connect() defer db.Close() qs := []string{ "CREATE TEMP TABLE items (id int, parent_id int)", "INSERT INTO items VALUES (1, NULL), (2, 1), (3, 1)", } for _, q := range qs { _, err := db.Exec(q) if err != nil { panic(err) } } // Select item and all subitems with following queries: // // SELECT "item".* FROM "items" AS "item" ORDER BY "item"."id" LIMIT 1 // // SELECT "item".* FROM "items" AS "item" WHERE (("item"."parent_id") IN ((1))) var item Item err := db.Model(&item).Column("item.*", "Items").First() if err != nil { panic(err) } fmt.Println("Item", item.Id) fmt.Println("Subitems", item.Items[0].Id, item.Items[1].Id) // Output: Item 1 // Subitems 2 3 } func ExampleDB_Model_manyToMany() { type Item struct { Id int Items []Item `pg:",many2many:item_to_items,joinFK:Sub"` } db := connect() defer db.Close() qs := []string{ "CREATE TEMP TABLE items (id int)", "CREATE TEMP TABLE item_to_items (item_id int, sub_id int)", "INSERT INTO items VALUES (1), (2), (3)", "INSERT INTO item_to_items VALUES (1, 2), (1, 3)", } for _, q := range qs { _, err := db.Exec(q) if err != nil { panic(err) } } // Select item and all subitems with following queries: // // SELECT "item".* FROM "items" AS "item" ORDER BY "item"."id" LIMIT 1 // // SELECT * FROM "items" AS "item" // JOIN "item_to_items" ON ("item_to_items"."item_id") IN ((1)) // WHERE ("item"."id" = "item_to_items"."sub_id") var item Item err := db.Model(&item).Column("item.*", "Items").First() if err != nil { panic(err) } fmt.Println("Item", item.Id) fmt.Println("Subitems", item.Items[0].Id, item.Items[1].Id) // Output: Item 1 // Subitems 2 3 } func ExampleDB_Update() { db := modelDB() err := db.Update(&Book{ Id: 1, Title: "updated book 1", }) if err != nil { 
panic(err) } var book Book err = db.Model(&book).Where("id = ?", 1).Select() if err != nil { panic(err) } fmt.Println(book) // Output: Book } func ExampleDB_Update_someColumns() { db := modelDB() book := Book{ Id: 1, Title: "updated book 1", // only this column is going to be updated AuthorID: 2, } _, err := db.Model(&book).Column("title").Returning("*").Update() if err != nil { panic(err) } fmt.Println(book, book.AuthorID) // Output: Book 1 } func ExampleDB_Update_someColumns2() { db := modelDB() book := Book{ Id: 1, Title: "updated book 1", AuthorID: 2, // this column will not be updated } _, err := db.Model(&book).Set("title = ?title").Returning("*").Update() if err != nil { panic(err) } fmt.Println(book, book.AuthorID) // Output: Book 1 } func ExampleDB_Update_setValues() { db := modelDB() var book Book _, err := db.Model(&book). Set("title = concat(?, title, ?)", "prefix ", " suffix"). Where("id = ?", 1). Returning("*"). Update() if err != nil { panic(err) } fmt.Println(book) // Output: Book } func ExampleDB_Delete() { db := modelDB() book := Book{ Title: "title 1", AuthorID: 1, } err := db.Insert(&book) if err != nil { panic(err) } err = db.Delete(&book) if err != nil { panic(err) } err = db.Select(&book) fmt.Println(err) // Output: pg: no rows in result set } func ExampleDB_Delete_multipleRows() { db := modelDB() ids := pg.In([]int{1, 2, 3}) res, err := db.Model(&Book{}).Where("id IN (?)", ids).Delete() if err != nil { panic(err) } fmt.Println("deleted", res.RowsAffected()) count, err := db.Model(&Book{}).Count() if err != nil { panic(err) } fmt.Println("left", count) // Output: deleted 3 // left 0 } func ExampleQ() { db := modelDB() cond := fmt.Sprintf("id = %d", 1) var book Book err := db.Model(&book).Where("?", pg.Q(cond)).Select() if err != nil { panic(err) } fmt.Println(book) // Output: Book } func ExampleF() { db := modelDB() var book Book err := db.Model(&book).Where("? 
= 1", pg.F("id")).Select() if err != nil { panic(err) } fmt.Println(book) // Output: Book } pg-5.3.3/example_placeholders_test.go000066400000000000000000000023141305650307100176760ustar00rootroot00000000000000package pg_test import ( "fmt" "gopkg.in/pg.v5" ) type Params struct { X int Y int } func (p *Params) Sum() int { return p.X + p.Y } // go-pg recognizes placeholders (`?`) in queries and replaces them // with parameters when queries are executed. Parameters are escaped // before replacing according to PostgreSQL rules. Specifically: // - all parameters are properly quoted against SQL injections; // - null byte is removed; // - JSON/JSONB gets `\u0000` escaped as `\\u0000`. func Example_placeholders() { var num int // Simple params. _, err := db.Query(pg.Scan(&num), "SELECT ?", 42) if err != nil { panic(err) } fmt.Println("simple:", num) // Indexed params. _, err = db.Query(pg.Scan(&num), "SELECT ?0 + ?0", 1) if err != nil { panic(err) } fmt.Println("indexed:", num) // Named params. params := &Params{ X: 1, Y: 1, } _, err = db.Query(pg.Scan(&num), "SELECT ?x + ?y + ?Sum", params) if err != nil { panic(err) } fmt.Println("named:", num) // Global params. 
_, err = db.WithParam("z", 1).Query(pg.Scan(&num), "SELECT ?x + ?y + ?z", params) if err != nil { panic(err) } fmt.Println("global:", num) // Output: simple: 42 // indexed: 2 // named: 4 // global: 3 } pg-5.3.3/example_test.go000066400000000000000000000113621305650307100151540ustar00rootroot00000000000000package pg_test import ( "bytes" "fmt" "strings" "time" "gopkg.in/pg.v5" "gopkg.in/pg.v5/orm" ) var db *pg.DB func init() { db = connect() } func connect() *pg.DB { return pg.Connect(pgOptions()) } func ExampleConnect() { db := pg.Connect(&pg.Options{ User: "postgres", Password: "", Database: "postgres", }) var n int _, err := db.QueryOne(pg.Scan(&n), "SELECT 1") if err != nil { panic(err) } fmt.Println(n) err = db.Close() if err != nil { panic(err) } // Output: 1 } func ExampleDB_QueryOne() { var user struct { Name string } res, err := db.QueryOne(&user, ` WITH users (name) AS (VALUES (?)) SELECT * FROM users `, "admin") if err != nil { panic(err) } fmt.Println(res.RowsAffected()) fmt.Println(user) // Output: 1 // {admin} } func ExampleDB_QueryOne_returning_id() { _, err := db.Exec(`CREATE TEMP TABLE users(id serial, name varchar(500))`) if err != nil { panic(err) } var user struct { Id int32 Name string } user.Name = "admin" _, err = db.QueryOne(&user, ` INSERT INTO users (name) VALUES (?name) RETURNING id `, &user) if err != nil { panic(err) } fmt.Println(user) // Output: {1 admin} } func ExampleDB_Exec() { res, err := db.Exec(`CREATE TEMP TABLE test()`) fmt.Println(res.RowsAffected(), err) // Output: -1 } func ExampleListener() { ln := db.Listen("mychan") defer ln.Close() ch := ln.Channel() go func() { _, err := db.Exec("NOTIFY mychan, ?", "hello world") if err != nil { panic(err) } }() notif := <-ch fmt.Println(notif) // Output: &{mychan hello world} } func txExample() *pg.DB { db := pg.Connect(&pg.Options{ User: "postgres", }) queries := []string{ `DROP TABLE IF EXISTS tx_test`, `CREATE TABLE tx_test(counter int)`, `INSERT INTO tx_test (counter) VALUES (0)`, 
} for _, q := range queries { _, err := db.Exec(q) if err != nil { panic(err) } } return db } func ExampleDB_Begin() { db := txExample() tx, err := db.Begin() if err != nil { panic(err) } var counter int _, err = tx.QueryOne(pg.Scan(&counter), `SELECT counter FROM tx_test`) if err != nil { tx.Rollback() panic(err) } counter++ _, err = tx.Exec(`UPDATE tx_test SET counter = ?`, counter) if err != nil { tx.Rollback() panic(err) } err = tx.Commit() if err != nil { panic(err) } fmt.Println(counter) // Output: 1 } func ExampleDB_RunInTransaction() { db := txExample() var counter int // Transaction is automatically rollbacked on error. err := db.RunInTransaction(func(tx *pg.Tx) error { _, err := tx.QueryOne(pg.Scan(&counter), `SELECT counter FROM tx_test`) if err != nil { return err } counter++ _, err = tx.Exec(`UPDATE tx_test SET counter = ?`, counter) return err }) if err != nil { panic(err) } fmt.Println(counter) // Output: 1 } func ExampleDB_Prepare() { stmt, err := db.Prepare(`SELECT $1::text, $2::text`) if err != nil { panic(err) } var s1, s2 string _, err = stmt.QueryOne(pg.Scan(&s1, &s2), "foo", "bar") if err != nil { panic(err) } fmt.Println(s1, s2) // Output: foo bar } func ExampleDB_CreateTable() { type Model struct { Id int Name string } err := db.CreateTable(&Model{}, &orm.CreateTableOptions{ Temp: true, // create temp table }) if err != nil { panic(err) } var info []struct { ColumnName string DataType string } _, err = db.Query(&info, ` SELECT column_name, data_type FROM information_schema.columns WHERE table_name = 'models' `) if err != nil { panic(err) } fmt.Println(info) // Output: [{id bigint} {name text}] } func ExampleInts() { var nums pg.Ints _, err := db.Query(&nums, `SELECT generate_series(0, 10)`) fmt.Println(nums, err) // Output: [0 1 2 3 4 5 6 7 8 9 10] } func ExampleStrings() { var strs pg.Strings _, err := db.Query(&strs, ` WITH users AS (VALUES ('foo'), ('bar')) SELECT * FROM users `) fmt.Println(strs, err) // Output: [foo bar] } func 
ExampleDB_CopyFrom() { _, err := db.Exec(`CREATE TEMP TABLE words(word text, len int)`) if err != nil { panic(err) } r := strings.NewReader("hello,5\nfoo,3\n") _, err = db.CopyFrom(r, `COPY words FROM STDIN WITH CSV`) if err != nil { panic(err) } var buf bytes.Buffer _, err = db.CopyTo(&buf, `COPY words TO STDOUT WITH CSV`) if err != nil { panic(err) } fmt.Println(buf.String()) // Output: hello,5 // foo,3 } func ExampleDB_WithTimeout() { var count int // Use bigger timeout since this query is known to be slow. _, err := db.WithTimeout(time.Minute).QueryOne(pg.Scan(&count), ` SELECT count(*) FROM big_table `) if err != nil { panic(err) } } func ExampleScan() { var s1, s2 string _, err := db.QueryOne(pg.Scan(&s1, &s2), `SELECT ?, ?`, "foo", "bar") fmt.Println(s1, s2, err) // Output: foo bar } pg-5.3.3/exampledb_model_test.go000066400000000000000000000035531305650307100166450ustar00rootroot00000000000000package pg_test import ( "fmt" "gopkg.in/pg.v5" "gopkg.in/pg.v5/orm" ) type User struct { Id int64 Name string Emails []string } func (u User) String() string { return fmt.Sprintf("User<%d %s %v>", u.Id, u.Name, u.Emails) } type Story struct { Id int64 Title string AuthorId int64 Author *User } func (s Story) String() string { return fmt.Sprintf("Story<%d %s %s>", s.Id, s.Title, s.Author) } func ExampleDB_Model() { db := pg.Connect(&pg.Options{ User: "postgres", }) err := createSchema(db) if err != nil { panic(err) } user1 := &User{ Name: "admin", Emails: []string{"admin1@admin", "admin2@admin"}, } err = db.Insert(user1) if err != nil { panic(err) } err = db.Insert(&User{ Name: "root", Emails: []string{"root1@root", "root2@root"}, }) if err != nil { panic(err) } story1 := &Story{ Title: "Cool story", AuthorId: user1.Id, } err = db.Insert(story1) if err != nil { panic(err) } // Select user by primary key. user := User{Id: user1.Id} err = db.Select(&user) if err != nil { panic(err) } // Select all users. 
var users []User err = db.Model(&users).Select() if err != nil { panic(err) } // Select story and associated author in one query. var story Story err = db.Model(&story). Column("story.*", "Author"). Where("story.id = ?", story1.Id). Select() if err != nil { panic(err) } fmt.Println(user) fmt.Println(users) fmt.Println(story) // Output: User<1 admin [admin1@admin admin2@admin]> // [User<1 admin [admin1@admin admin2@admin]> User<2 root [root1@root root2@root]>] // Story<1 Cool story User<1 admin [admin1@admin admin2@admin]>> } func createSchema(db *pg.DB) error { for _, model := range []interface{}{&User{}, &Story{}} { err := db.CreateTable(model, &orm.CreateTableOptions{ Temp: true, }) if err != nil { return err } } return nil } pg-5.3.3/exampledb_query_test.go000066400000000000000000000043001305650307100167010ustar00rootroot00000000000000package pg_test import ( "fmt" "gopkg.in/pg.v5" ) func CreateUser(db *pg.DB, user *User) error { _, err := db.QueryOne(user, ` INSERT INTO users (name, emails) VALUES (?name, ?emails) RETURNING id `, user) return err } func GetUser(db *pg.DB, id int64) (*User, error) { var user User _, err := db.QueryOne(&user, `SELECT * FROM users WHERE id = ?`, id) return &user, err } func GetUsers(db *pg.DB) ([]User, error) { var users []User _, err := db.Query(&users, `SELECT * FROM users`) return users, err } func GetUsersByIds(db *pg.DB, ids []int64) ([]User, error) { var users []User _, err := db.Query(&users, `SELECT * FROM users WHERE id IN (?)`, pg.In(ids)) return users, err } func CreateStory(db *pg.DB, story *Story) error { _, err := db.QueryOne(story, ` INSERT INTO stories (title, author_id) VALUES (?title, ?author_id) RETURNING id `, story) return err } // GetStory returns story with associated author. func GetStory(db *pg.DB, id int64) (*Story, error) { var story Story _, err := db.QueryOne(&story, ` SELECT s.*, u.id AS author__id, u.name AS author__name, u.emails AS author__emails FROM stories AS s, users AS u WHERE s.id = ? 
AND u.id = s.author_id `, id) return &story, err } func ExampleDB_Query() { db := pg.Connect(&pg.Options{ User: "postgres", }) err := createSchema(db) if err != nil { panic(err) } user1 := &User{ Name: "admin", Emails: []string{"admin1@admin", "admin2@admin"}, } err = CreateUser(db, user1) if err != nil { panic(err) } err = CreateUser(db, &User{ Name: "root", Emails: []string{"root1@root", "root2@root"}, }) if err != nil { panic(err) } story1 := &Story{ Title: "Cool story", AuthorId: user1.Id, } err = CreateStory(db, story1) user, err := GetUser(db, user1.Id) if err != nil { panic(err) } users, err := GetUsers(db) if err != nil { panic(err) } story, err := GetStory(db, story1.Id) if err != nil { panic(err) } fmt.Println(user) fmt.Println(users) fmt.Println(story) // Output: User<1 admin [admin1@admin admin2@admin]> // [User<1 admin [admin1@admin admin2@admin]> User<2 root [root1@root root2@root]>] // Story<1 Cool story User<1 admin [admin1@admin admin2@admin]>> } pg-5.3.3/export_test.go000066400000000000000000000002541305650307100150400ustar00rootroot00000000000000package pg import "gopkg.in/pg.v5/internal/pool" func (db *DB) Pool() *pool.ConnPool { return db.pool } func (ln *Listener) CurrentConn() *pool.Conn { return ln._cn } pg-5.3.3/hook_test.go000066400000000000000000000062551305650307100144660ustar00rootroot00000000000000package pg_test import ( "gopkg.in/pg.v5" "gopkg.in/pg.v5/orm" . "github.com/onsi/ginkgo" . 
"github.com/onsi/gomega" ) type HookTest struct { Id int Value string afterQuery int afterSelect int beforeInsert int afterInsert int beforeUpdate int afterUpdate int beforeDelete int afterDelete int } func (t *HookTest) AfterQuery(db orm.DB) error { t.afterQuery++ return nil } func (t *HookTest) AfterSelect(db orm.DB) error { t.afterSelect++ return nil } func (t *HookTest) BeforeInsert(db orm.DB) error { t.beforeInsert++ return nil } func (t *HookTest) AfterInsert(db orm.DB) error { t.afterInsert++ return nil } func (t *HookTest) BeforeUpdate(db orm.DB) error { t.beforeUpdate++ return nil } func (t *HookTest) AfterUpdate(db orm.DB) error { t.afterUpdate++ return nil } func (t *HookTest) BeforeDelete(db orm.DB) error { t.beforeDelete++ return nil } func (t *HookTest) AfterDelete(db orm.DB) error { t.afterDelete++ return nil } var _ = Describe("HookTest", func() { var db *pg.DB BeforeEach(func() { db = pg.Connect(pgOptions()) qs := []string{ "CREATE TEMP TABLE hook_tests (id int, value text)", "INSERT INTO hook_tests VALUES (1, '')", } for _, q := range qs { _, err := db.Exec(q) Expect(err).NotTo(HaveOccurred()) } }) AfterEach(func() { Expect(db.Close()).NotTo(HaveOccurred()) }) It("calls AfterQuery for struct", func() { var hook HookTest _, err := db.QueryOne(&hook, "SELECT 1 AS id") Expect(err).NotTo(HaveOccurred()) Expect(hook.afterQuery).To(Equal(1)) Expect(hook.afterSelect).To(Equal(0)) }) It("calls AfterQuery and AfterSelect for struct model", func() { var hook HookTest err := db.Model(&hook).Select() Expect(err).NotTo(HaveOccurred()) Expect(hook.afterQuery).To(Equal(1)) Expect(hook.afterSelect).To(Equal(1)) }) It("calls AfterQuery for slice", func() { var hooks []HookTest _, err := db.Query(&hooks, "SELECT 1 AS id") Expect(err).NotTo(HaveOccurred()) Expect(hooks).To(HaveLen(1)) Expect(hooks[0].afterQuery).To(Equal(1)) Expect(hooks[0].afterSelect).To(Equal(0)) }) It("calls AfterQuery and AfterSelect for slice model", func() { var hooks []HookTest err := 
db.Model(&hooks).Select() Expect(err).NotTo(HaveOccurred()) Expect(hooks).To(HaveLen(1)) Expect(hooks[0].afterQuery).To(Equal(1)) Expect(hooks[0].afterSelect).To(Equal(1)) }) It("calls BeforeInsert and AfterInsert", func() { hook := &HookTest{ Id: 1, Value: "value", } err := db.Insert(&hook) Expect(err).NotTo(HaveOccurred()) Expect(hook.afterQuery).To(Equal(0)) Expect(hook.beforeInsert).To(Equal(1)) Expect(hook.afterInsert).To(Equal(1)) }) It("calls BeforeUpdate and AfterUpdate", func() { hook := &HookTest{ Id: 1, } err := db.Update(&hook) Expect(err).NotTo(HaveOccurred()) Expect(hook.afterQuery).To(Equal(0)) Expect(hook.beforeUpdate).To(Equal(1)) Expect(hook.afterUpdate).To(Equal(1)) }) It("calls BeforeDelete and AfterDelete", func() { hook := &HookTest{ Id: 1, } err := db.Delete(&hook) Expect(err).NotTo(HaveOccurred()) Expect(hook.afterQuery).To(Equal(0)) Expect(hook.beforeDelete).To(Equal(1)) Expect(hook.afterDelete).To(Equal(1)) }) }) pg-5.3.3/internal/000077500000000000000000000000001305650307100137445ustar00rootroot00000000000000pg-5.3.3/internal/error.go000066400000000000000000000017401305650307100154260ustar00rootroot00000000000000package internal import "fmt" var ( ErrNoRows = Errorf("pg: no rows in result set") ErrMultiRows = Errorf("pg: multiple rows in result set") ) type Error struct { s string } func Errorf(s string, args ...interface{}) Error { return Error{s: fmt.Sprintf(s, args...)} } func (err Error) Error() string { return err.s } type PGError struct { m map[byte]string } func NewPGError(m map[byte]string) PGError { return PGError{ m: m, } } func (err PGError) Field(k byte) string { return err.m[k] } func (err PGError) IntegrityViolation() bool { switch err.Field('C') { case "23000", "23001", "23502", "23503", "23505", "23514", "23P01": return true default: return false } } func (err PGError) Error() string { return fmt.Sprintf( "%s #%s %s (addr=%q)", err.Field('S'), err.Field('C'), err.Field('M'), err.Field('a'), ) } func AssertOneRow(l int) 
error { switch { case l == 0: return ErrNoRows case l > 1: return ErrMultiRows default: return nil } } pg-5.3.3/internal/internal.go000066400000000000000000000001151305650307100161040ustar00rootroot00000000000000package internal import "time" const RetryBackoff = 250 * time.Millisecond pg-5.3.3/internal/log.go000066400000000000000000000013211305650307100150510ustar00rootroot00000000000000package internal import ( "fmt" "log" "path/filepath" "runtime" "strings" ) var ( Logger *log.Logger QueryLogger *log.Logger ) func Logf(s string, args ...interface{}) { if Logger == nil { return } Logger.Output(2, fmt.Sprintf(s, args...)) } func LogQuery(query string) { if QueryLogger == nil { return } file, line := fileLine(2) QueryLogger.Printf("%s:%d: %s", file, line, strings.TrimRight(query, "\t\n")) } const packageName = "gopkg.in/pg.v5" func fileLine(depth int) (string, int) { for i := depth; ; i++ { _, file, line, ok := runtime.Caller(i) if !ok { break } if strings.Contains(file, packageName) { continue } return filepath.Base(file), line } return "", 0 } pg-5.3.3/internal/parser/000077500000000000000000000000001305650307100152405ustar00rootroot00000000000000pg-5.3.3/internal/parser/array_parser.go000066400000000000000000000022441305650307100202630ustar00rootroot00000000000000package parser import ( "bytes" "fmt" ) type ArrayParser struct { *Parser stickyErr error } func NewArrayParser(b []byte) *ArrayParser { var err error if len(b) < 2 || b[0] != '{' || b[len(b)-1] != '}' { err = fmt.Errorf("pg: can't parse array: %s", string(b)) } else { b = b[1 : len(b)-1] } return &ArrayParser{ Parser: New(b), stickyErr: err, } } func (p *ArrayParser) NextElem() ([]byte, error) { if p.stickyErr != nil { return nil, p.stickyErr } switch c := p.Peek(); c { case '"': p.Advance() b := p.readSubstring() p.Skip(',') return b, nil case '{': b := p.readElem() if b != nil { b = append(b, '}') } p.Skip(',') return b, nil default: b, _ := p.ReadSep(',') if bytes.Equal(b, pgNull) { b = nil } return 
b, nil } } func (p *ArrayParser) readElem() []byte { var b []byte for p.Valid() { c := p.Read() switch c { case '"': b = append(b, '"') for { bb, ok := p.ReadSep('"') b = append(b, bb...) stop := len(b) > 0 && b[len(b)-1] != '\\' if ok { b = append(b, '"') } if stop { break } } case '}': return b default: b = append(b, c) } } return b } pg-5.3.3/internal/parser/array_parser_test.go000066400000000000000000000021051305650307100213160ustar00rootroot00000000000000package parser_test import ( "testing" "gopkg.in/pg.v5/internal/parser" ) var arrayTests = []struct { s string els []string }{ {`{"\\"}`, []string{`\`}}, {`{"''"}`, []string{`'`}}, {`{{"''\"{}"}}`, []string{`{"''\"{}"}`}}, {`{"''\"{}"}`, []string{`'"{}`}}, {"{1,2}", []string{"1", "2"}}, {"{1,NULL}", []string{"1", ""}}, {`{"1","2"}`, []string{"1", "2"}}, {`{"{1}","{2}"}`, []string{"{1}", "{2}"}}, {"{{1,2},{3}}", []string{"{1,2}", "{3}"}}, } func TestArrayParser(t *testing.T) { for testi, test := range arrayTests { p := parser.NewArrayParser([]byte(test.s)) var got []string for p.Valid() { b, err := p.NextElem() if err != nil { t.Fatal(err) } got = append(got, string(b)) } if len(got) != len(test.els) { t.Fatalf( "#%d got %d elements, wanted %d (got=%#v wanted=%#v)", testi, len(got), len(test.els), got, test.els, ) } for i, el := range got { if el != test.els[i] { t.Fatalf( "#%d el #%d does not match: %q != %q (got=%#v wanted=%#v)", testi, i, el, test.els[i], got, test.els, ) } } } } pg-5.3.3/internal/parser/hstore_parser.go000066400000000000000000000013231305650307100204460ustar00rootroot00000000000000package parser import "fmt" type HstoreParser struct { *Parser } func NewHstoreParser(b []byte) *HstoreParser { return &HstoreParser{ Parser: New(b), } } func (p *HstoreParser) NextKey() ([]byte, error) { if p.Skip(',') { p.Skip(' ') } if !p.Skip('"') { return nil, fmt.Errorf("pg: can't parse hstore key: %q", p.Bytes()) } key := p.readSubstring() if !(p.Skip('=') && p.Skip('>')) { return nil, fmt.Errorf("pg: can't 
parse hstore key: %q", p.Bytes()) } return key, nil } func (p *HstoreParser) NextValue() ([]byte, error) { if !p.Skip('"') { return nil, fmt.Errorf("pg: can't parse hstore value: %q", p.Bytes()) } value := p.readSubstring() p.SkipBytes([]byte(", ")) return value, nil } pg-5.3.3/internal/parser/hstore_parser_test.go000066400000000000000000000022311305650307100215040ustar00rootroot00000000000000package parser_test import ( "testing" "gopkg.in/pg.v5/internal/parser" ) var hstoreTests = []struct { s string m map[string]string }{ {`""=>""`, map[string]string{"": ""}}, {`"k''k"=>"k''k"`, map[string]string{"k'k": "k'k"}}, {`"k\"k"=>"k\"k"`, map[string]string{`k"k`: `k"k`}}, {`"k\k"=>"k\k"`, map[string]string{`k\k`: `k\k`}}, {`"foo"=>"bar"`, map[string]string{"foo": "bar"}}, {`"foo"=>"bar","k"=>"v"`, map[string]string{"foo": "bar", "k": "v"}}, } func TestHstoreParser(t *testing.T) { for testi, test := range hstoreTests { p := parser.NewHstoreParser([]byte(test.s)) got := make(map[string]string) for p.Valid() { key, err := p.NextKey() if err != nil { t.Fatal(err) } value, err := p.NextValue() if err != nil { t.Fatal(err) } got[string(key)] = string(value) } if len(got) != len(test.m) { t.Fatalf( "#%d got %d elements, wanted %d (got=%#v wanted=%#v)", testi, len(got), len(test.m), got, test.m, ) } for k, v := range got { if v != test.m[k] { t.Fatalf( "#%d el %q does not match: %q != %q (got=%#v wanted=%#v)", testi, k, v, test.m[k], got, test.m, ) } } } } pg-5.3.3/internal/parser/parser.go000066400000000000000000000040571305650307100170710ustar00rootroot00000000000000package parser import ( "bytes" "strconv" "gopkg.in/pg.v5/internal" ) type Parser struct { b []byte } func New(b []byte) *Parser { return &Parser{b: b} } func NewString(s string) *Parser { return &Parser{b: internal.StringToBytes(s)} } func (p *Parser) Bytes() []byte { return p.b } func (p *Parser) Valid() bool { return len(p.b) > 0 } func (p *Parser) Read() byte { c := p.b[0] p.Skip(c) return c } func (p *Parser) 
Peek() byte { if p.Valid() { return p.b[0] } return 0 } func (p *Parser) Advance() { p.b = p.b[1:] } func (p *Parser) Skip(c byte) bool { if p.Peek() == c { p.Advance() return true } return false } func (p *Parser) SkipBytes(b []byte) bool { if len(b) > len(p.b) { return false } if !bytes.Equal(p.b[:len(b)], b) { return false } p.b = p.b[len(b):] return true } func (p *Parser) ReadSep(c byte) ([]byte, bool) { ind := bytes.IndexByte(p.b, c) if ind == -1 { b := p.b p.b = p.b[len(p.b):] return b, false } b := p.b[:ind] p.b = p.b[ind+1:] return b, true } func (p *Parser) ReadIdentifier() (s string, numeric bool) { pos := len(p.b) numeric = true for i, ch := range p.b { if isNum(ch) { continue } if isAlpha(ch) || ch == '_' { numeric = false continue } pos = i break } if pos <= 0 { return "", false } b := p.b[:pos] p.b = p.b[pos:] return internal.BytesToString(b), numeric } func (p *Parser) ReadNumber() int { end := len(p.b) for i, ch := range p.b { if !isNum(ch) { end = i break } } if end <= 0 { return 0 } n, _ := strconv.Atoi(string(p.b[:end])) p.b = p.b[end:] return n } func (p *Parser) readSubstring() []byte { var b []byte for p.Valid() { c := p.Read() switch c { case '\\': switch p.Peek() { case '\\': b = append(b, '\\') p.Advance() case '"': b = append(b, '"') p.Advance() default: b = append(b, c) } case '\'': switch p.Peek() { case '\'': b = append(b, '\'') p.Skip(c) default: b = append(b, c) } case '"': return b default: b = append(b, c) } } return b } pg-5.3.3/internal/parser/util.go000066400000000000000000000003741305650307100165500ustar00rootroot00000000000000package parser var pgNull = []byte("NULL") func isNum(c byte) bool { return c >= '0' && c <= '9' } func isAlpha(c byte) bool { return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') } func isAlnum(c byte) bool { return isAlpha(c) || isNum(c) } 
pg-5.3.3/internal/pool/000077500000000000000000000000001305650307100147155ustar00rootroot00000000000000pg-5.3.3/internal/pool/bench_test.go000066400000000000000000000031601305650307100173620ustar00rootroot00000000000000package pool_test import ( "errors" "testing" "time" "gopkg.in/pg.v5/internal/pool" ) func benchmarkPoolGetPut(b *testing.B, poolSize int) { connPool := pool.NewConnPool(&pool.Options{ Dial: dummyDialer, PoolSize: poolSize, PoolTimeout: time.Second, IdleTimeout: time.Hour, IdleCheckFrequency: time.Hour, }) b.ResetTimer() b.RunParallel(func(pb *testing.PB) { for pb.Next() { cn, _, err := connPool.Get() if err != nil { b.Fatal(err) } if err = connPool.Put(cn); err != nil { b.Fatal(err) } } }) } func BenchmarkPoolGetPut10Conns(b *testing.B) { benchmarkPoolGetPut(b, 10) } func BenchmarkPoolGetPut100Conns(b *testing.B) { benchmarkPoolGetPut(b, 100) } func BenchmarkPoolGetPut1000Conns(b *testing.B) { benchmarkPoolGetPut(b, 1000) } func benchmarkPoolGetRemove(b *testing.B, poolSize int) { connPool := pool.NewConnPool(&pool.Options{ Dial: dummyDialer, PoolSize: poolSize, PoolTimeout: time.Second, IdleTimeout: time.Hour, IdleCheckFrequency: time.Hour, }) removeReason := errors.New("benchmark") b.ResetTimer() b.RunParallel(func(pb *testing.PB) { for pb.Next() { cn, _, err := connPool.Get() if err != nil { b.Fatal(err) } if err := connPool.Remove(cn, removeReason); err != nil { b.Fatal(err) } } }) } func BenchmarkPoolGetRemove10Conns(b *testing.B) { benchmarkPoolGetRemove(b, 10) } func BenchmarkPoolGetRemove100Conns(b *testing.B) { benchmarkPoolGetRemove(b, 100) } func BenchmarkPoolGetRemove1000Conns(b *testing.B) { benchmarkPoolGetRemove(b, 1000) } pg-5.3.3/internal/pool/conn.go000066400000000000000000000034401305650307100162020ustar00rootroot00000000000000package pool import ( "bufio" "encoding/hex" "fmt" "io" "net" "strconv" "time" ) var noDeadline = time.Time{} type Conn struct { netConn net.Conn buf []byte // read buffer Rd *bufio.Reader Columns [][]byte 
Wr *WriteBuffer InitedAt time.Time UsedAt time.Time ProcessId int32 SecretKey int32 _lastId int64 } func NewConn(netConn net.Conn) *Conn { cn := &Conn{ buf: make([]byte, 0, 512), Rd: bufio.NewReader(netConn), Wr: NewWriteBuffer(), UsedAt: time.Now(), } cn.SetNetConn(netConn) return cn } func (cn *Conn) RemoteAddr() net.Addr { return cn.netConn.RemoteAddr() } func (cn *Conn) SetNetConn(netConn net.Conn) { cn.netConn = netConn cn.Rd.Reset(netConn) } func (cn *Conn) NetConn() net.Conn { return cn.netConn } func (cn *Conn) NextId() string { cn._lastId++ return strconv.FormatInt(cn._lastId, 10) } func (cn *Conn) SetReadWriteTimeout(rt, wt time.Duration) { cn.UsedAt = time.Now() if rt > 0 { cn.netConn.SetReadDeadline(cn.UsedAt.Add(rt)) } else { cn.netConn.SetReadDeadline(noDeadline) } if wt > 0 { cn.netConn.SetWriteDeadline(cn.UsedAt.Add(wt)) } else { cn.netConn.SetWriteDeadline(noDeadline) } } func (cn *Conn) ReadN(n int) ([]byte, error) { if d := n - cap(cn.buf); d > 0 { cn.buf = cn.buf[:cap(cn.buf)] cn.buf = append(cn.buf, make([]byte, d)...) } else { cn.buf = cn.buf[:n] } _, err := io.ReadFull(cn.Rd, cn.buf) return cn.buf, err } func (cn *Conn) FlushWriter() error { _, err := cn.netConn.Write(cn.Wr.Bytes) cn.Wr.Reset() return err } func (cn *Conn) Close() error { return cn.netConn.Close() } func (cn *Conn) CheckHealth() error { if cn.Rd.Buffered() != 0 { b, _ := cn.Rd.Peek(cn.Rd.Buffered()) err := fmt.Errorf("connection has unread data:\n%s", hex.Dump(b)) return err } return nil } pg-5.3.3/internal/pool/main_test.go000066400000000000000000000010141305650307100172230ustar00rootroot00000000000000package pool_test import ( "net" "sync" "testing" . "github.com/onsi/ginkgo" . 
"github.com/onsi/gomega" ) func TestGinkgoSuite(t *testing.T) { RegisterFailHandler(Fail) RunSpecs(t, "pool") } func perform(n int, cbs ...func(int)) { var wg sync.WaitGroup for _, cb := range cbs { for i := 0; i < n; i++ { wg.Add(1) go func(cb func(int), i int) { defer GinkgoRecover() defer wg.Done() cb(i) }(cb, i) } } wg.Wait() } func dummyDialer() (net.Conn, error) { return &net.TCPConn{}, nil } pg-5.3.3/internal/pool/pool.go000066400000000000000000000156351305650307100162270ustar00rootroot00000000000000package pool import ( "errors" "net" "sync" "sync/atomic" "time" "gopkg.in/pg.v5/internal" ) var ( ErrClosed = errors.New("pg: database is closed") ErrPoolTimeout = errors.New("pg: connection pool timeout") errConnStale = errors.New("pg: connection is stale") ) var timers = sync.Pool{ New: func() interface{} { t := time.NewTimer(time.Hour) t.Stop() return t }, } // Stats contains pool state information and accumulated stats. type Stats struct { Requests uint32 // number of times a connection was requested by the pool Hits uint32 // number of times free connection was found in the pool Timeouts uint32 // number of times a wait timeout occurred TotalConns uint32 // the number of total connections in the pool FreeConns uint32 // the number of free connections in the pool } type Pooler interface { Get() (*Conn, bool, error) Put(*Conn) error Remove(*Conn, error) error Len() int FreeLen() int Stats() *Stats Close() error Closed() bool } type Options struct { Dial func() (net.Conn, error) OnClose func(*Conn) error PoolSize int PoolTimeout time.Duration IdleTimeout time.Duration IdleCheckFrequency time.Duration MaxAge time.Duration } type ConnPool struct { opt *Options queue chan struct{} connsMu sync.Mutex conns []*Conn freeConnsMu sync.Mutex freeConns []*Conn stats Stats _closed int32 // atomic } var _ Pooler = (*ConnPool)(nil) func NewConnPool(opt *Options) *ConnPool { p := &ConnPool{ opt: opt, queue: make(chan struct{}, opt.PoolSize), conns: make([]*Conn, 0, 
opt.PoolSize), freeConns: make([]*Conn, 0, opt.PoolSize), } if opt.IdleTimeout > 0 && opt.IdleCheckFrequency > 0 { go p.reaper(opt.IdleCheckFrequency) } return p } func (p *ConnPool) dial() (net.Conn, error) { cn, err := p.opt.Dial() if err != nil { return nil, err } return cn, nil } func (p *ConnPool) NewConn() (*Conn, error) { netConn, err := p.dial() if err != nil { return nil, err } return NewConn(netConn), nil } func (p *ConnPool) isStaleConn(cn *Conn) bool { if p.opt.IdleTimeout == 0 && p.opt.MaxAge == 0 { return false } now := time.Now() if p.opt.IdleTimeout > 0 && now.Sub(cn.UsedAt) >= p.opt.IdleTimeout { return true } if p.opt.MaxAge > 0 && now.Sub(cn.InitedAt) >= p.opt.MaxAge { return true } return false } func (p *ConnPool) PopFree() *Conn { timer := timers.Get().(*time.Timer) timer.Reset(p.opt.PoolTimeout) select { case p.queue <- struct{}{}: if !timer.Stop() { <-timer.C } timers.Put(timer) case <-timer.C: timers.Put(timer) atomic.AddUint32(&p.stats.Timeouts, 1) return nil } p.freeConnsMu.Lock() cn := p.popFree() p.freeConnsMu.Unlock() if cn == nil { <-p.queue } return cn } func (p *ConnPool) popFree() *Conn { if len(p.freeConns) == 0 { return nil } idx := len(p.freeConns) - 1 cn := p.freeConns[idx] p.freeConns = p.freeConns[:idx] return cn } // Get returns existed connection from the pool or creates a new one. 
func (p *ConnPool) Get() (*Conn, bool, error) { if p.Closed() { return nil, false, ErrClosed } atomic.AddUint32(&p.stats.Requests, 1) timer := timers.Get().(*time.Timer) timer.Reset(p.opt.PoolTimeout) select { case p.queue <- struct{}{}: if !timer.Stop() { <-timer.C } timers.Put(timer) case <-timer.C: timers.Put(timer) atomic.AddUint32(&p.stats.Timeouts, 1) return nil, false, ErrPoolTimeout } for { p.freeConnsMu.Lock() cn := p.popFree() p.freeConnsMu.Unlock() if cn == nil { break } if p.isStaleConn(cn) { p.remove(cn, errConnStale) continue } atomic.AddUint32(&p.stats.Hits, 1) return cn, false, nil } newcn, err := p.NewConn() if err != nil { <-p.queue return nil, false, err } p.connsMu.Lock() p.conns = append(p.conns, newcn) p.connsMu.Unlock() return newcn, true, nil } func (p *ConnPool) Put(cn *Conn) error { if e := cn.CheckHealth(); e != nil { internal.Logf(e.Error()) return p.Remove(cn, e) } p.freeConnsMu.Lock() p.freeConns = append(p.freeConns, cn) p.freeConnsMu.Unlock() <-p.queue return nil } func (p *ConnPool) Remove(cn *Conn, reason error) error { p.remove(cn, reason) <-p.queue return nil } func (p *ConnPool) remove(cn *Conn, reason error) { _ = p.closeConn(cn, reason) p.connsMu.Lock() for i, c := range p.conns { if c == cn { p.conns = append(p.conns[:i], p.conns[i+1:]...) break } } p.connsMu.Unlock() } // Len returns total number of connections. func (p *ConnPool) Len() int { p.connsMu.Lock() l := len(p.conns) p.connsMu.Unlock() return l } // FreeLen returns number of free connections. 
func (p *ConnPool) FreeLen() int { p.freeConnsMu.Lock() l := len(p.freeConns) p.freeConnsMu.Unlock() return l } func (p *ConnPool) Stats() *Stats { return &Stats{ Requests: atomic.LoadUint32(&p.stats.Requests), Hits: atomic.LoadUint32(&p.stats.Hits), Timeouts: atomic.LoadUint32(&p.stats.Timeouts), TotalConns: uint32(p.Len()), FreeConns: uint32(p.FreeLen()), } } func (p *ConnPool) Closed() bool { return atomic.LoadInt32(&p._closed) == 1 } func (p *ConnPool) Close() (retErr error) { if !atomic.CompareAndSwapInt32(&p._closed, 0, 1) { return ErrClosed } p.connsMu.Lock() // Close all connections. for _, cn := range p.conns { if cn == nil { continue } if err := p.closeConn(cn, ErrClosed); err != nil && retErr == nil { retErr = err } } p.conns = nil p.connsMu.Unlock() p.freeConnsMu.Lock() p.freeConns = nil p.freeConnsMu.Unlock() return retErr } func (p *ConnPool) closeConn(cn *Conn, reason error) error { if p.opt.OnClose != nil { _ = p.opt.OnClose(cn) } return cn.Close() } func (p *ConnPool) reapStaleConn() bool { if len(p.freeConns) == 0 { return false } cn := p.freeConns[0] if !p.isStaleConn(cn) { return false } p.remove(cn, errConnStale) p.freeConns = append(p.freeConns[:0], p.freeConns[1:]...) 
return true } func (p *ConnPool) ReapStaleConns() (int, error) { var n int for { p.queue <- struct{}{} p.freeConnsMu.Lock() reaped := p.reapStaleConn() p.freeConnsMu.Unlock() <-p.queue if reaped { n++ } else { break } } return n, nil } func (p *ConnPool) reaper(frequency time.Duration) { ticker := time.NewTicker(frequency) defer ticker.Stop() for _ = range ticker.C { if p.Closed() { break } n, err := p.ReapStaleConns() if err != nil { internal.Logf("ReapStaleConns failed: %s", err) continue } s := p.Stats() internal.Logf( "reaper: removed %d stale conns (TotalConns=%d FreeConns=%d Requests=%d Hits=%d Timeouts=%d)", n, s.TotalConns, s.FreeConns, s.Requests, s.Hits, s.Timeouts, ) } } //------------------------------------------------------------------------------ var idleCheckFrequency atomic.Value func SetIdleCheckFrequency(d time.Duration) { idleCheckFrequency.Store(d) } func getIdleCheckFrequency() time.Duration { v := idleCheckFrequency.Load() if v == nil { return time.Minute } return v.(time.Duration) } pg-5.3.3/internal/pool/pool_test.go000066400000000000000000000142511305650307100172570ustar00rootroot00000000000000package pool_test import ( "errors" "testing" "time" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" "gopkg.in/pg.v5/internal/pool" ) var _ = Describe("ConnPool", func() { var connPool *pool.ConnPool BeforeEach(func() { connPool = pool.NewConnPool(&pool.Options{ Dial: dummyDialer, PoolSize: 10, PoolTimeout: time.Hour, IdleTimeout: time.Millisecond, IdleCheckFrequency: time.Millisecond, }) }) AfterEach(func() { connPool.Close() }) It("should unblock client when conn is removed", func() { // Reserve one connection. cn, _, err := connPool.Get() Expect(err).NotTo(HaveOccurred()) // Reserve all other connections. 
var cns []*pool.Conn for i := 0; i < 9; i++ { cn, _, err := connPool.Get() Expect(err).NotTo(HaveOccurred()) cns = append(cns, cn) } started := make(chan bool, 1) done := make(chan bool, 1) go func() { defer GinkgoRecover() started <- true _, _, err := connPool.Get() Expect(err).NotTo(HaveOccurred()) done <- true err = connPool.Put(cn) Expect(err).NotTo(HaveOccurred()) }() <-started // Check that Get is blocked. select { case <-done: Fail("Get is not blocked") default: // ok } err = connPool.Remove(cn, errors.New("test")) Expect(err).NotTo(HaveOccurred()) // Check that Ping is unblocked. select { case <-done: // ok case <-time.After(time.Second): Fail("Get is not unblocked") } for _, cn := range cns { err = connPool.Put(cn) Expect(err).NotTo(HaveOccurred()) } }) }) var _ = Describe("conns reaper", func() { const idleTimeout = time.Minute const maxAge = time.Hour var connPool *pool.ConnPool var conns, staleConns, closedConns []*pool.Conn assert := func(typ string) { BeforeEach(func() { closedConns = nil connPool = pool.NewConnPool(&pool.Options{ Dial: dummyDialer, PoolSize: 10, PoolTimeout: time.Second, IdleTimeout: idleTimeout, MaxAge: maxAge, IdleCheckFrequency: time.Hour, OnClose: func(cn *pool.Conn) error { closedConns = append(closedConns, cn) return nil }, }) conns = nil // add stale connections staleConns = nil for i := 0; i < 3; i++ { cn, _, err := connPool.Get() Expect(err).NotTo(HaveOccurred()) switch typ { case "idle": cn.UsedAt = time.Now().Add(-2 * idleTimeout) case "aged": cn.InitedAt = time.Now().Add(-2 * maxAge) } conns = append(conns, cn) staleConns = append(staleConns, cn) } // add fresh connections for i := 0; i < 3; i++ { cn, _, err := connPool.Get() Expect(err).NotTo(HaveOccurred()) conns = append(conns, cn) } for _, cn := range conns { if cn.InitedAt.IsZero() { cn.InitedAt = time.Now() } Expect(connPool.Put(cn)).NotTo(HaveOccurred()) } Expect(connPool.Len()).To(Equal(6)) Expect(connPool.FreeLen()).To(Equal(6)) n, err := 
connPool.ReapStaleConns() Expect(err).NotTo(HaveOccurred()) Expect(n).To(Equal(3)) }) AfterEach(func() { _ = connPool.Close() Expect(connPool.Len()).To(Equal(0)) Expect(connPool.FreeLen()).To(Equal(0)) Expect(len(closedConns)).To(Equal(len(conns))) Expect(closedConns).To(ConsistOf(conns)) }) It("reaps stale connections", func() { Expect(connPool.Len()).To(Equal(3)) Expect(connPool.FreeLen()).To(Equal(3)) }) It("does not reap fresh connections", func() { n, err := connPool.ReapStaleConns() Expect(err).NotTo(HaveOccurred()) Expect(n).To(Equal(0)) }) It("stale connections are closed", func() { Expect(len(closedConns)).To(Equal(len(staleConns))) Expect(closedConns).To(ConsistOf(staleConns)) }) It("pool is functional", func() { for j := 0; j < 3; j++ { var freeCns []*pool.Conn for i := 0; i < 3; i++ { cn, _, err := connPool.Get() Expect(err).NotTo(HaveOccurred()) Expect(cn).NotTo(BeNil()) freeCns = append(freeCns, cn) } Expect(connPool.Len()).To(Equal(3)) Expect(connPool.FreeLen()).To(Equal(0)) cn, _, err := connPool.Get() Expect(err).NotTo(HaveOccurred()) Expect(cn).NotTo(BeNil()) conns = append(conns, cn) Expect(connPool.Len()).To(Equal(4)) Expect(connPool.FreeLen()).To(Equal(0)) err = connPool.Remove(cn, errors.New("test")) Expect(err).NotTo(HaveOccurred()) Expect(connPool.Len()).To(Equal(3)) Expect(connPool.FreeLen()).To(Equal(0)) for _, cn := range freeCns { err := connPool.Put(cn) Expect(err).NotTo(HaveOccurred()) } Expect(connPool.Len()).To(Equal(3)) Expect(connPool.FreeLen()).To(Equal(3)) } }) } assert("idle") assert("aged") }) var _ = Describe("race", func() { var connPool *pool.ConnPool var C, N int BeforeEach(func() { C, N = 10, 1000 if testing.Short() { C = 4 N = 100 } }) AfterEach(func() { connPool.Close() }) It("does not happen on Get, Put, and Remove", func() { connPool = pool.NewConnPool(&pool.Options{ Dial: dummyDialer, PoolSize: 10, PoolTimeout: time.Minute, IdleTimeout: time.Millisecond, IdleCheckFrequency: time.Millisecond, }) perform(C, func(id int) 
{ for i := 0; i < N; i++ { cn, _, err := connPool.Get() Expect(err).NotTo(HaveOccurred()) if err == nil { Expect(connPool.Put(cn)).NotTo(HaveOccurred()) } } }, func(id int) { for i := 0; i < N; i++ { cn, _, err := connPool.Get() Expect(err).NotTo(HaveOccurred()) if err == nil { Expect(connPool.Remove(cn, errors.New("test"))).NotTo(HaveOccurred()) } } }) }) It("does not happen on Get and PopFree", func() { connPool = pool.NewConnPool(&pool.Options{ Dial: dummyDialer, PoolSize: 10, PoolTimeout: time.Minute, IdleTimeout: time.Second, IdleCheckFrequency: time.Millisecond, }) perform(C, func(id int) { for i := 0; i < N; i++ { cn, _, err := connPool.Get() Expect(err).NotTo(HaveOccurred()) if err == nil { Expect(connPool.Put(cn)).NotTo(HaveOccurred()) } cn = connPool.PopFree() if cn != nil { Expect(connPool.Put(cn)).NotTo(HaveOccurred()) } } }) }) }) pg-5.3.3/internal/pool/write_buffer.go000066400000000000000000000037661305650307100177430ustar00rootroot00000000000000package pool import ( "encoding/binary" "io" ) type WriteBuffer struct { Bytes []byte msgStart, paramStart int } func NewWriteBuffer() *WriteBuffer { return &WriteBuffer{ Bytes: make([]byte, 0, 4096), } } func (buf *WriteBuffer) StartMessage(c byte) { if c == 0 { buf.msgStart = len(buf.Bytes) buf.Bytes = append(buf.Bytes, 0, 0, 0, 0) } else { buf.msgStart = len(buf.Bytes) + 1 buf.Bytes = append(buf.Bytes, c, 0, 0, 0, 0) } } func (buf *WriteBuffer) FinishMessage() { binary.BigEndian.PutUint32( buf.Bytes[buf.msgStart:], uint32(len(buf.Bytes)-buf.msgStart)) } func (buf *WriteBuffer) StartParam() { buf.paramStart = len(buf.Bytes) buf.Bytes = append(buf.Bytes, 0, 0, 0, 0) } func (buf *WriteBuffer) FinishParam() { binary.BigEndian.PutUint32( buf.Bytes[buf.paramStart:], uint32(len(buf.Bytes)-buf.paramStart-4)) } var nullParamLength = int32(-1) func (buf *WriteBuffer) FinishNullParam() { binary.BigEndian.PutUint32( buf.Bytes[buf.paramStart:], uint32(nullParamLength)) } func (buf *WriteBuffer) Write(b []byte) (int, 
error) { buf.Bytes = append(buf.Bytes, b...) return len(b), nil } func (buf *WriteBuffer) WriteInt16(num int16) { buf.Bytes = append(buf.Bytes, 0, 0) binary.BigEndian.PutUint16(buf.Bytes[len(buf.Bytes)-2:], uint16(num)) } func (buf *WriteBuffer) WriteInt32(num int32) { buf.Bytes = append(buf.Bytes, 0, 0, 0, 0) binary.BigEndian.PutUint32(buf.Bytes[len(buf.Bytes)-4:], uint32(num)) } func (buf *WriteBuffer) WriteString(s string) { buf.Bytes = append(buf.Bytes, s...) buf.Bytes = append(buf.Bytes, 0) } func (buf *WriteBuffer) WriteBytes(b []byte) { buf.Bytes = append(buf.Bytes, b...) buf.Bytes = append(buf.Bytes, 0) } func (buf *WriteBuffer) WriteByte(c byte) { buf.Bytes = append(buf.Bytes, c) } func (buf *WriteBuffer) Reset() { buf.Bytes = buf.Bytes[:0] } func (buf *WriteBuffer) ReadFrom(r io.Reader) (int64, error) { n, err := r.Read(buf.Bytes[len(buf.Bytes):cap(buf.Bytes)]) buf.Bytes = buf.Bytes[:len(buf.Bytes)+int(n)] return int64(n), err } pg-5.3.3/internal/safe.go000066400000000000000000000002341305650307100152100ustar00rootroot00000000000000// +build appengine package internal func BytesToString(b []byte) string { return string(b) } func StringToBytes(s string) []byte { return []byte(s) } pg-5.3.3/internal/underscore.go000066400000000000000000000021771305650307100164530ustar00rootroot00000000000000package internal func isUpper(c byte) bool { return c >= 'A' && c <= 'Z' } func isLower(c byte) bool { return !isUpper(c) } func toUpper(c byte) byte { return c - 32 } func toLower(c byte) byte { return c + 32 } // Underscore converts "CamelCasedString" to "camel_cased_string". 
func Underscore(s string) string { r := make([]byte, 0, len(s)) for i := 0; i < len(s); i++ { c := s[i] if isUpper(c) { if i > 0 && i+1 < len(s) && (isLower(s[i-1]) || isLower(s[i+1])) { r = append(r, '_', toLower(c)) } else { r = append(r, toLower(c)) } } else { r = append(r, c) } } return string(r) } func ToUpper(s string) string { if isUpperString(s) { return s } b := make([]byte, len(s)) for i := range b { c := s[i] if c >= 'a' && c <= 'z' { c -= 'a' - 'A' } b[i] = c } return string(b) } func isUpperString(s string) bool { for i := 0; i < len(s); i++ { c := s[i] if c >= 'a' && c <= 'z' { return false } } return true } func ToExported(s string) string { if len(s) == 0 { return s } if c := s[0]; isLower(c) { b := []byte(s) b[0] = toUpper(c) return string(b) } return s } pg-5.3.3/internal/underscore_test.go000066400000000000000000000006711305650307100175070ustar00rootroot00000000000000package internal_test import ( "testing" "gopkg.in/pg.v5/internal" ) func TestUnderscore(t *testing.T) { tests := []struct { s, wanted string }{ {"Megacolumn", "megacolumn"}, {"MegaColumn", "mega_column"}, {"MegaColumn_Id", "mega_column__id"}, {"MegaColumn_id", "mega_column_id"}, } for _, v := range tests { if got := internal.Underscore(v.s); got != v.wanted { t.Errorf("got %q, wanted %q", got, v.wanted) } } } pg-5.3.3/internal/unsafe.go000066400000000000000000000006601305650307100155560ustar00rootroot00000000000000// +build !appengine package internal import ( "reflect" "unsafe" ) func BytesToString(b []byte) string { bh := (*reflect.SliceHeader)(unsafe.Pointer(&b)) sh := reflect.StringHeader{bh.Data, bh.Len} return *(*string)(unsafe.Pointer(&sh)) } func StringToBytes(s string) []byte { sh := (*reflect.StringHeader)(unsafe.Pointer(&s)) bh := reflect.SliceHeader{sh.Data, sh.Len, sh.Len} return *(*[]byte)(unsafe.Pointer(&bh)) } pg-5.3.3/internal/util.go000066400000000000000000000006461305650307100152560ustar00rootroot00000000000000package internal import "reflect" func SliceNextElem(v 
reflect.Value) reflect.Value { if v.Len() < v.Cap() { v.Set(v.Slice(0, v.Len()+1)) return v.Index(v.Len() - 1) } elemType := v.Type().Elem() if elemType.Kind() == reflect.Ptr { elem := reflect.New(elemType.Elem()) v.Set(reflect.Append(v, elem)) return elem.Elem() } v.Set(reflect.Append(v, reflect.Zero(elemType))) return v.Index(v.Len() - 1) } pg-5.3.3/listener.go000066400000000000000000000067121305650307100143120ustar00rootroot00000000000000package pg import ( "sync" "time" "gopkg.in/pg.v5/internal" "gopkg.in/pg.v5/internal/pool" ) // A notification received with LISTEN command. type Notification struct { Channel string Payload string } // Listener listens for notifications sent with NOTIFY command. // It's NOT safe for concurrent use by multiple goroutines // except the Channel API. type Listener struct { db *DB channels []string mu sync.Mutex _cn *pool.Conn closed bool } func (ln *Listener) conn(readTimeout time.Duration) (*pool.Conn, error) { ln.mu.Lock() defer ln.mu.Unlock() if ln.closed { return nil, errListenerClosed } if ln._cn == nil { cn, err := ln.db.conn() if err != nil { return nil, err } ln._cn = cn if len(ln.channels) > 0 { if err := ln.listen(cn, ln.channels...); err != nil { return nil, err } } } ln._cn.SetReadWriteTimeout(readTimeout, ln.db.opt.WriteTimeout) return ln._cn, nil } // Channel returns a channel for concurrently receiving notifications. // The channel is closed with Listener. func (ln *Listener) Channel() <-chan *Notification { ch := make(chan *Notification, 100) go func() { for { channel, payload, err := ln.ReceiveTimeout(5 * time.Second) if err != nil { if err == errListenerClosed { break } continue } ch <- &Notification{channel, payload} } close(ch) }() return ch } // Listen starts listening for notifications on channels. 
func (ln *Listener) Listen(channels ...string) error { cn, err := ln.conn(ln.db.opt.ReadTimeout) if err != nil { return err } if err := ln.listen(cn, channels...); err != nil { if err != nil { ln.freeConn(err) } return err } ln.channels = appendIfNotExists(ln.channels, channels...) return nil } func (ln *Listener) listen(cn *pool.Conn, channels ...string) error { for _, channel := range channels { if err := writeQueryMsg(cn.Wr, ln.db, "LISTEN ?", F(channel)); err != nil { return err } } return cn.FlushWriter() } // Receive indefinitely waits for a notification. func (ln *Listener) Receive() (channel string, payload string, err error) { return ln.ReceiveTimeout(0) } // ReceiveTimeout waits for a notification until timeout is reached. func (ln *Listener) ReceiveTimeout(timeout time.Duration) (channel, payload string, err error) { channel, payload, err = ln.receiveTimeout(timeout) if err != nil { ln.freeConn(err) } return channel, payload, err } func (ln *Listener) receiveTimeout(readTimeout time.Duration) (channel, payload string, err error) { cn, err := ln.conn(readTimeout) if err != nil { return "", "", err } return readNotification(cn) } func (ln *Listener) freeConn(err error) (retErr error) { if !isBadConn(err, true) { return nil } return ln.closeConn(err) } func (ln *Listener) closeConn(reason error) error { var firstErr error ln.mu.Lock() if ln._cn != nil { if !ln.closed { internal.Logf("pg: discarding bad listener connection: %s", reason) } firstErr = ln.db.pool.Remove(ln._cn, reason) ln._cn = nil } ln.mu.Unlock() return firstErr } // Close closes the listener, releasing any open resources. 
func (ln *Listener) Close() error { ln.mu.Lock() closed := ln.closed ln.closed = true ln.mu.Unlock() if closed { return errListenerClosed } return ln.closeConn(errListenerClosed) } func appendIfNotExists(ss []string, es ...string) []string { loop: for _, e := range es { for _, s := range ss { if s == e { continue loop } } ss = append(ss, e) } return ss } pg-5.3.3/listener_test.go000066400000000000000000000070001305650307100153400ustar00rootroot00000000000000package pg_test import ( "net" "time" "gopkg.in/pg.v5" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) var _ = Context("Listener", func() { var db *pg.DB var ln *pg.Listener BeforeEach(func() { opt := pgOptions() opt.PoolSize = 2 opt.PoolTimeout = time.Second db = pg.Connect(opt) ln = db.Listen("test_channel") }) var _ = AfterEach(func() { _ = ln.Close() err := db.Close() Expect(err).NotTo(HaveOccurred()) }) It("reuses connection", func() { for i := 0; i < 100; i++ { _, _, err := ln.ReceiveTimeout(time.Nanosecond) Expect(err).To(HaveOccurred()) Expect(err.Error()).To(MatchRegexp(".+ i/o timeout")) } st := db.Pool().Stats() Expect(st.Requests).To(Equal(uint32(1))) Expect(st.Hits).To(Equal(uint32(0))) Expect(st.Timeouts).To(Equal(uint32(0))) Expect(st.TotalConns).To(Equal(uint32(1))) Expect(st.FreeConns).To(Equal(uint32(0))) }) It("listens for notifications", func() { wait := make(chan struct{}, 2) go func() { defer GinkgoRecover() wait <- struct{}{} channel, payload, err := ln.Receive() Expect(err).NotTo(HaveOccurred()) Expect(channel).To(Equal("test_channel")) Expect(payload).To(Equal("")) wait <- struct{}{} }() select { case <-wait: // ok case <-time.After(3 * time.Second): Fail("timeout") } _, err := db.Exec("NOTIFY test_channel") Expect(err).NotTo(HaveOccurred()) select { case <-wait: // ok case <-time.After(3 * time.Second): Fail("timeout") } }) It("is aborted when DB is closed", func() { wait := make(chan struct{}, 2) go func() { defer GinkgoRecover() wait <- struct{}{} _, _, err := ln.Receive() 
Expect(err.Error()).To(SatisfyAny( Equal("EOF"), MatchRegexp(`use of closed (file or )?network connection$`), )) wait <- struct{}{} }() select { case <-wait: // ok case <-time.After(3 * time.Second): Fail("timeout") } select { case <-wait: Fail("Receive is not blocked") case <-time.After(time.Second): // ok } Expect(ln.Close()).To(BeNil()) select { case <-wait: // ok case <-time.After(3 * time.Second): Fail("timeout") } }) It("returns an error on timeout", func() { channel, payload, err := ln.ReceiveTimeout(time.Second) Expect(err.(net.Error).Timeout()).To(BeTrue()) Expect(channel).To(Equal("")) Expect(payload).To(Equal("")) }) It("reconnects on listen error", func() { cn := ln.CurrentConn() Expect(cn).NotTo(BeNil()) cn.SetNetConn(&badConn{}) err := ln.Listen("test_channel2") Expect(err).Should(MatchError("bad connection")) err = ln.Listen("test_channel2") Expect(err).NotTo(HaveOccurred()) }) It("reconnects on receive error", func() { cn := ln.CurrentConn() Expect(cn).NotTo(BeNil()) cn.SetNetConn(&badConn{}) _, _, err := ln.ReceiveTimeout(time.Second) Expect(err).Should(MatchError("bad connection")) _, _, err = ln.ReceiveTimeout(time.Second) Expect(err.(net.Error).Timeout()).To(BeTrue()) wait := make(chan struct{}, 2) go func() { defer GinkgoRecover() wait <- struct{}{} _, _, err := ln.Receive() Expect(err).NotTo(HaveOccurred()) wait <- struct{}{} }() select { case <-wait: // ok case <-time.After(3 * time.Second): Fail("timeout") } select { case <-wait: Fail("Receive is not blocked") case <-time.After(time.Second): // ok } _, err = db.Exec("NOTIFY test_channel") Expect(err).NotTo(HaveOccurred()) select { case <-wait: // ok case <-time.After(3 * time.Second): Fail("timeout") } }) }) pg-5.3.3/loader_test.go000066400000000000000000000066761305650307100150030ustar00rootroot00000000000000package pg_test import ( "errors" "gopkg.in/pg.v5" "gopkg.in/pg.v5/orm" . 
"gopkg.in/check.v1" ) type LoaderTest struct { db *pg.DB } var _ = Suite(&LoaderTest{}) func (t *LoaderTest) SetUpTest(c *C) { t.db = pg.Connect(pgOptions()) } func (t *LoaderTest) TearDownTest(c *C) { c.Assert(t.db.Close(), IsNil) } type numLoader struct { Num int } type embeddedLoader struct { *numLoader Num2 int } type multipleLoader struct { One struct { Num int } Num int } func (t *LoaderTest) TestQuery(c *C) { var dst numLoader _, err := t.db.Query(&dst, "SELECT 1 AS num") c.Assert(err, IsNil) c.Assert(dst.Num, Equals, 1) } func (t *LoaderTest) TestQueryNull(c *C) { var dst numLoader _, err := t.db.Query(&dst, "SELECT NULL AS num") c.Assert(err, IsNil) c.Assert(dst.Num, Equals, 0) } func (t *LoaderTest) TestQueryEmbeddedStruct(c *C) { src := &embeddedLoader{ numLoader: &numLoader{ Num: 1, }, Num2: 2, } dst := &embeddedLoader{ numLoader: &numLoader{}, } _, err := t.db.QueryOne(dst, "SELECT ?num AS num, ?num2 as num2", src) c.Assert(err, IsNil) c.Assert(dst, DeepEquals, src) } func (t *LoaderTest) TestQueryNestedStructs(c *C) { src := &multipleLoader{} src.One.Num = 1 src.Num = 2 dst := &multipleLoader{} _, err := t.db.QueryOne(dst, `SELECT ?one__num AS one__num, ?num as num`, src) c.Assert(err, IsNil) c.Assert(dst, DeepEquals, src) } func (t *LoaderTest) TestQueryStmt(c *C) { stmt, err := t.db.Prepare("SELECT 1 AS num") c.Assert(err, IsNil) defer stmt.Close() dst := &numLoader{} _, err = stmt.Query(dst) c.Assert(err, IsNil) c.Assert(dst.Num, Equals, 1) } func (t *LoaderTest) TestQueryInts(c *C) { var ids pg.Ints _, err := t.db.Query(&ids, "SELECT s.num AS num FROM generate_series(0, 10) AS s(num)") c.Assert(err, IsNil) c.Assert(ids, DeepEquals, pg.Ints{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) } func (t *LoaderTest) TestQueryInts2(c *C) { var ints pg.Ints _, err := t.db.Query(&ints, "SELECT * FROM generate_series(1, 1000000)") c.Assert(err, IsNil) c.Assert(ints, HasLen, 1000000) } func (t *LoaderTest) TestQueryStrings(c *C) { var strings pg.Strings _, err := 
t.db.Query(&strings, "SELECT 'hello'") c.Assert(err, IsNil) c.Assert(strings, DeepEquals, pg.Strings{"hello"}) } type errLoader string var _ orm.Model = errLoader("") func (errLoader) Reset() error { return nil } func (m errLoader) NewModel() orm.ColumnScanner { return m } func (errLoader) AddModel(_ orm.ColumnScanner) error { return nil } func (errLoader) AfterQuery(_ orm.DB) error { return nil } func (errLoader) AfterSelect(_ orm.DB) error { return nil } func (errLoader) BeforeInsert(_ orm.DB) error { return nil } func (errLoader) AfterInsert(_ orm.DB) error { return nil } func (errLoader) BeforeUpdate(_ orm.DB) error { return nil } func (errLoader) AfterUpdate(_ orm.DB) error { return nil } func (errLoader) BeforeDelete(_ orm.DB) error { return nil } func (errLoader) AfterDelete(_ orm.DB) error { return nil } func (m errLoader) ScanColumn(int, string, []byte) error { return errors.New(string(m)) } func (t *LoaderTest) TestLoaderError(c *C) { tx, err := t.db.Begin() c.Assert(err, IsNil) defer tx.Rollback() loader := errLoader("my error") _, err = tx.QueryOne(loader, "SELECT 1, 2") c.Assert(err.Error(), Equals, "my error") // Verify that client is still functional. var n1, n2 int _, err = tx.QueryOne(pg.Scan(&n1, &n2), "SELECT 1, 2") c.Assert(err, IsNil) c.Assert(n1, Equals, 1) c.Assert(n2, Equals, 2) } pg-5.3.3/main_test.go000066400000000000000000000114311305650307100144420ustar00rootroot00000000000000package pg_test import ( "bytes" "database/sql/driver" "net" "strings" "sync" "sync/atomic" "testing" "time" "gopkg.in/pg.v5" . "github.com/onsi/ginkgo" . 
"gopkg.in/check.v1" ) func TestUnixSocket(t *testing.T) { opt := pgOptions() opt.Network = "unix" opt.Addr = "/var/run/postgresql/.s.PGSQL.5432" opt.TLSConfig = nil db := pg.Connect(opt) defer db.Close() _, err := db.Exec("SELECT 'test_unix_socket'") if err != nil { t.Fatal(err) } } func TestGocheck(t *testing.T) { TestingT(t) } var _ = Suite(&DBTest{}) type DBTest struct { db *pg.DB } func (t *DBTest) SetUpTest(c *C) { t.db = pg.Connect(pgOptions()) } func (t *DBTest) TearDownTest(c *C) { c.Assert(t.db.Close(), IsNil) } func (t *DBTest) TestQueryZeroRows(c *C) { res, err := t.db.Query(pg.Discard, "SELECT 1 WHERE 1 != 1") c.Assert(err, IsNil) c.Assert(res.RowsAffected(), Equals, 0) } func (t *DBTest) TestQueryOneErrNoRows(c *C) { _, err := t.db.QueryOne(pg.Discard, "SELECT 1 WHERE 1 != 1") c.Assert(err, Equals, pg.ErrNoRows) } func (t *DBTest) TestQueryOneErrMultiRows(c *C) { _, err := t.db.QueryOne(pg.Discard, "SELECT generate_series(0, 1)") c.Assert(err, Equals, pg.ErrMultiRows) } func (t *DBTest) TestExecOne(c *C) { res, err := t.db.ExecOne("SELECT 'test_exec_one'") c.Assert(err, IsNil) c.Assert(res.RowsAffected(), Equals, 1) } func (t *DBTest) TestExecOneErrNoRows(c *C) { _, err := t.db.ExecOne("SELECT 1 WHERE 1 != 1") c.Assert(err, Equals, pg.ErrNoRows) } func (t *DBTest) TestExecOneErrMultiRows(c *C) { _, err := t.db.ExecOne("SELECT generate_series(0, 1)") c.Assert(err, Equals, pg.ErrMultiRows) } func (t *DBTest) TestScan(c *C) { var dst int _, err := t.db.QueryOne(pg.Scan(&dst), "SELECT 1") c.Assert(err, IsNil) c.Assert(dst, Equals, 1) } func (t *DBTest) TestExec(c *C) { res, err := t.db.Exec("CREATE TEMP TABLE test(id serial PRIMARY KEY)") c.Assert(err, IsNil) c.Assert(res.RowsAffected(), Equals, -1) res, err = t.db.Exec("INSERT INTO test VALUES (1)") c.Assert(err, IsNil) c.Assert(res.RowsAffected(), Equals, 1) } func (t *DBTest) TestStatementExec(c *C) { res, err := t.db.Exec("CREATE TEMP TABLE test(id serial PRIMARY KEY)") c.Assert(err, IsNil) 
c.Assert(res.RowsAffected(), Equals, -1) stmt, err := t.db.Prepare("INSERT INTO test VALUES($1)") c.Assert(err, IsNil) defer stmt.Close() res, err = stmt.Exec(1) c.Assert(err, IsNil) c.Assert(res.RowsAffected(), Equals, 1) } func (t *DBTest) TestLargeWriteRead(c *C) { src := bytes.Repeat([]byte{0x1}, 1e6) var dst []byte _, err := t.db.QueryOne(pg.Scan(&dst), "SELECT ?", src) c.Assert(err, IsNil) c.Assert(dst, DeepEquals, src) } func (t *DBTest) TestIntegrityError(c *C) { _, err := t.db.Exec("DO $$BEGIN RAISE unique_violation USING MESSAGE='foo'; END$$;") c.Assert(err.(pg.Error).IntegrityViolation(), Equals, true) } type customStrSlice []string func (s customStrSlice) Value() (driver.Value, error) { return strings.Join(s, "\n"), nil } func (s *customStrSlice) Scan(v interface{}) error { if v == nil { *s = nil return nil } b := v.([]byte) if len(b) == 0 { *s = []string{} return nil } *s = strings.Split(string(b), "\n") return nil } func (t *DBTest) TestScannerValueOnStruct(c *C) { src := customStrSlice{"foo", "bar"} dst := struct{ Dst customStrSlice }{} _, err := t.db.QueryOne(&dst, "SELECT ? 
AS dst", src) c.Assert(err, IsNil) c.Assert(dst.Dst, DeepEquals, src) } //------------------------------------------------------------------------------ type badConnError string func (e badConnError) Error() string { return string(e) } func (e badConnError) Timeout() bool { return false } func (e badConnError) Temporary() bool { return false } type badConn struct { net.TCPConn readDelay, writeDelay time.Duration readErr, writeErr error } var _ net.Conn = &badConn{} func (cn *badConn) Read([]byte) (int, error) { if cn.readDelay != 0 { time.Sleep(cn.readDelay) } if cn.readErr != nil { return 0, cn.readErr } return 0, badConnError("bad connection") } func (cn *badConn) Write([]byte) (int, error) { if cn.writeDelay != 0 { time.Sleep(cn.writeDelay) } if cn.writeErr != nil { return 0, cn.writeErr } return 0, badConnError("bad connection") } func perform(n int, cbs ...func(int)) { var wg sync.WaitGroup for _, cb := range cbs { for i := 0; i < n; i++ { wg.Add(1) go func(cb func(int), i int) { defer GinkgoRecover() defer wg.Done() cb(i) }(cb, i) } } wg.Wait() } func eventually(fn func() error, timeout time.Duration) (err error) { done := make(chan struct{}) var exit int32 go func() { for atomic.LoadInt32(&exit) == 0 { err = fn() if err == nil { close(done) return } time.Sleep(timeout / 100) } }() select { case <-done: return nil case <-time.After(timeout): atomic.StoreInt32(&exit, 1) return err } } pg-5.3.3/messages.go000066400000000000000000000533601305650307100142750ustar00rootroot00000000000000package pg import ( "bufio" "crypto/md5" "crypto/tls" "encoding/binary" "encoding/hex" "fmt" "io" "gopkg.in/pg.v5/internal" "gopkg.in/pg.v5/internal/pool" "gopkg.in/pg.v5/orm" "gopkg.in/pg.v5/types" ) const ( commandCompleteMsg = 'C' errorResponseMsg = 'E' noticeResponseMsg = 'N' parameterStatusMsg = 'S' authenticationOKMsg = 'R' backendKeyDataMsg = 'K' noDataMsg = 'n' passwordMessageMsg = 'p' terminateMsg = 'X' notificationResponseMsg = 'A' describeMsg = 'D' 
parameterDescriptionMsg = 't' queryMsg = 'Q' readyForQueryMsg = 'Z' emptyQueryResponseMsg = 'I' rowDescriptionMsg = 'T' dataRowMsg = 'D' parseMsg = 'P' parseCompleteMsg = '1' bindMsg = 'B' bindCompleteMsg = '2' executeMsg = 'E' syncMsg = 'S' flushMsg = 'H' closeMsg = 'C' closeCompleteMsg = '3' copyInResponseMsg = 'G' copyOutResponseMsg = 'H' copyDataMsg = 'd' copyDoneMsg = 'c' ) func startup(cn *pool.Conn, user, password, database string) error { writeStartupMsg(cn.Wr, user, database) if err := cn.FlushWriter(); err != nil { return err } for { c, msgLen, err := readMessageType(cn) if err != nil { return err } switch c { case backendKeyDataMsg: processId, err := readInt32(cn) if err != nil { return err } secretKey, err := readInt32(cn) if err != nil { return err } cn.ProcessId = processId cn.SecretKey = secretKey case parameterStatusMsg: if err := logParameterStatus(cn, msgLen); err != nil { return err } case authenticationOKMsg: if err := authenticate(cn, user, password); err != nil { return err } case readyForQueryMsg: _, err := cn.ReadN(msgLen) return err case errorResponseMsg: e, err := readError(cn) if err != nil { return err } return e default: return fmt.Errorf("pg: unknown startup message response: %q", c) } } } func enableSSL(cn *pool.Conn, tlsConf *tls.Config) error { writeSSLMsg(cn.Wr) if err := cn.FlushWriter(); err != nil { return err } c, err := cn.Rd.ReadByte() if err != nil { return err } if c != 'S' { return errSSLNotSupported } cn.SetNetConn(tls.Client(cn.NetConn(), tlsConf)) return nil } func authenticate(cn *pool.Conn, user, password string) error { num, err := readInt32(cn) if err != nil { return err } switch num { case 0: return nil case 3: writePasswordMsg(cn.Wr, password) if err := cn.FlushWriter(); err != nil { return err } c, _, err := readMessageType(cn) if err != nil { return err } switch c { case authenticationOKMsg: num, err := readInt32(cn) if err != nil { return err } if num != 0 { return fmt.Errorf("pg: unexpected authentication 
code: %d", num) } return nil case errorResponseMsg: e, err := readError(cn) if err != nil { return err } return e default: return fmt.Errorf("pg: unknown password message response: %q", c) } case 5: b, err := cn.ReadN(4) if err != nil { return err } secret := "md5" + md5s(md5s(password+user)+string(b)) writePasswordMsg(cn.Wr, secret) if err := cn.FlushWriter(); err != nil { return err } c, _, err := readMessageType(cn) if err != nil { return err } switch c { case authenticationOKMsg: num, err := readInt32(cn) if err != nil { return err } if num != 0 { return fmt.Errorf("pg: unexpected authentication code: %d", num) } return nil case errorResponseMsg: e, err := readError(cn) if err != nil { return err } return e default: return fmt.Errorf("pg: unknown password message response: %q", c) } default: return fmt.Errorf("pg: unknown authentication message response: %d", num) } } func md5s(s string) string { h := md5.New() h.Write([]byte(s)) return hex.EncodeToString(h.Sum(nil)) } func writeStartupMsg(buf *pool.WriteBuffer, user, database string) { buf.StartMessage(0) buf.WriteInt32(196608) buf.WriteString("user") buf.WriteString(user) buf.WriteString("database") buf.WriteString(database) buf.WriteString("") buf.FinishMessage() } func writeSSLMsg(buf *pool.WriteBuffer) { buf.StartMessage(0) buf.WriteInt32(80877103) buf.FinishMessage() } func writePasswordMsg(buf *pool.WriteBuffer, password string) { buf.StartMessage(passwordMessageMsg) buf.WriteString(password) buf.FinishMessage() } func writeFlushMsg(buf *pool.WriteBuffer) { buf.StartMessage(flushMsg) buf.FinishMessage() } func writeCancelRequestMsg(buf *pool.WriteBuffer, processId, secretKey int32) { buf.StartMessage(0) buf.WriteInt32(80877102) buf.WriteInt32(processId) buf.WriteInt32(secretKey) buf.FinishMessage() } func writeQueryMsg(buf *pool.WriteBuffer, fmter orm.QueryFormatter, query interface{}, params ...interface{}) error { buf.StartMessage(queryMsg) bytes, err := appendQuery(buf.Bytes, fmter, query, params...) 
if err != nil { buf.Reset() return err } if internal.QueryLogger != nil { internal.LogQuery(string(bytes[5:])) } buf.Bytes = bytes buf.WriteByte(0x0) buf.FinishMessage() return nil } func appendQuery(dst []byte, fmter orm.QueryFormatter, query interface{}, params ...interface{}) ([]byte, error) { switch query := query.(type) { case orm.QueryAppender: return query.AppendQuery(dst, params...) case string: return fmter.FormatQuery(dst, query, params...), nil default: return nil, fmt.Errorf("pg: can't append %T", query) } } func writeSyncMsg(buf *pool.WriteBuffer) { buf.StartMessage(syncMsg) buf.FinishMessage() } func writeParseDescribeSyncMsg(buf *pool.WriteBuffer, name, q string) { buf.StartMessage(parseMsg) buf.WriteString(name) buf.WriteString(q) buf.WriteInt16(0) buf.FinishMessage() buf.StartMessage(describeMsg) buf.WriteByte('S') buf.WriteString(name) buf.FinishMessage() writeSyncMsg(buf) } func readParseDescribeSync(cn *pool.Conn) ([][]byte, error) { var columns [][]byte for { c, msgLen, err := readMessageType(cn) if err != nil { return nil, err } switch c { case parseCompleteMsg: _, err := cn.ReadN(msgLen) if err != nil { return nil, err } case rowDescriptionMsg: // Response to the DESCRIBE message. columns, err = readRowDescription(cn, nil) if err != nil { return nil, err } case parameterDescriptionMsg: // Response to the DESCRIBE message. _, err := cn.ReadN(msgLen) if err != nil { return nil, err } case noDataMsg: // Response to the DESCRIBE message. 
_, err := cn.ReadN(msgLen) if err != nil { return nil, err } case readyForQueryMsg: _, err := cn.ReadN(msgLen) return columns, err case errorResponseMsg: e, err := readError(cn) if err != nil { return nil, err } return nil, e case noticeResponseMsg: if err := logNotice(cn, msgLen); err != nil { return nil, err } case parameterStatusMsg: if err := logParameterStatus(cn, msgLen); err != nil { return nil, err } default: return nil, fmt.Errorf("pg: readParseDescribeSync: unexpected message %#x", c) } } } // Writes BIND, EXECUTE and SYNC messages. func writeBindExecuteMsg(buf *pool.WriteBuffer, name string, params ...interface{}) error { const paramLenWidth = 4 buf.StartMessage(bindMsg) buf.WriteString("") buf.WriteString(name) buf.WriteInt16(0) buf.WriteInt16(int16(len(params))) for _, param := range params { buf.StartParam() bytes := types.Append(buf.Bytes, param, 0) if bytes != nil { buf.Bytes = bytes buf.FinishParam() } else { buf.FinishNullParam() } } buf.WriteInt16(0) buf.FinishMessage() buf.StartMessage(executeMsg) buf.WriteString("") buf.WriteInt32(0) buf.FinishMessage() writeSyncMsg(buf) return nil } func readBindMsg(cn *pool.Conn) error { for { c, msgLen, err := readMessageType(cn) if err != nil { return err } switch c { case bindCompleteMsg: _, err := cn.ReadN(msgLen) if err != nil { return err } case readyForQueryMsg: // This is response to the SYNC message. 
_, err := cn.ReadN(msgLen) return err case errorResponseMsg: e, err := readError(cn) if err != nil { return err } return e case noticeResponseMsg: if err := logNotice(cn, msgLen); err != nil { return err } case parameterStatusMsg: if err := logParameterStatus(cn, msgLen); err != nil { return err } default: return fmt.Errorf("pg: readBindMsg: unexpected message %#x", c) } } } func writeCloseMsg(buf *pool.WriteBuffer, name string) { buf.StartMessage(closeMsg) buf.WriteByte('S') buf.WriteString(name) buf.FinishMessage() } func readCloseCompleteMsg(cn *pool.Conn) error { for { c, msgLen, err := readMessageType(cn) if err != nil { return err } switch c { case closeCompleteMsg: _, err := cn.ReadN(msgLen) return err case errorResponseMsg: e, err := readError(cn) if err != nil { return err } return e case noticeResponseMsg: if err := logNotice(cn, msgLen); err != nil { return err } case parameterStatusMsg: if err := logParameterStatus(cn, msgLen); err != nil { return err } default: return fmt.Errorf("pg: readCloseCompleteMsg: unexpected message %#x", c) } } } func readSimpleQuery(cn *pool.Conn) (res *types.Result, retErr error) { setErr := func(err error) { if retErr == nil { retErr = err } } var rows int for { c, msgLen, err := readMessageType(cn) if err != nil { return nil, err } switch c { case commandCompleteMsg: b, err := cn.ReadN(msgLen) if err != nil { return nil, err } res = types.NewResult(b, rows) case readyForQueryMsg: _, err := cn.ReadN(msgLen) if err != nil { return nil, err } return res, retErr case rowDescriptionMsg: _, err := cn.ReadN(msgLen) if err != nil { return nil, err } case dataRowMsg: _, err := cn.ReadN(msgLen) if err != nil { return nil, err } rows++ case errorResponseMsg: e, err := readError(cn) if err != nil { return nil, err } setErr(e) case noticeResponseMsg: if err := logNotice(cn, msgLen); err != nil { return nil, err } case parameterStatusMsg: if err := logParameterStatus(cn, msgLen); err != nil { return nil, err } default: return nil, 
fmt.Errorf("pg: readSimpleQuery: unexpected message %#x", c) } } } func readExtQuery(cn *pool.Conn) (res *types.Result, retErr error) { setErr := func(err error) { if retErr == nil { retErr = err } } var rows int for { c, msgLen, err := readMessageType(cn) if err != nil { return nil, err } switch c { case bindCompleteMsg: _, err := cn.ReadN(msgLen) if err != nil { return nil, err } case dataRowMsg: _, err := cn.ReadN(msgLen) if err != nil { return nil, err } rows++ case commandCompleteMsg: // Response to the EXECUTE message. b, err := cn.ReadN(msgLen) if err != nil { return nil, err } res = types.NewResult(b, rows) case readyForQueryMsg: // Response to the SYNC message. _, err := cn.ReadN(msgLen) if err != nil { return nil, err } return res, retErr case errorResponseMsg: e, err := readError(cn) if err != nil { return nil, err } setErr(e) case noticeResponseMsg: if err := logNotice(cn, msgLen); err != nil { return nil, err } case parameterStatusMsg: if err := logParameterStatus(cn, msgLen); err != nil { return nil, err } default: return nil, fmt.Errorf("pg: readExtQuery: unexpected message %#x", c) } } } func readRowDescription(cn *pool.Conn, columns [][]byte) ([][]byte, error) { colNum, err := readInt16(cn) if err != nil { return nil, err } columns = setByteSliceLen(columns, int(colNum)) for i := 0; i < int(colNum); i++ { columns[i], err = readBytes(cn, columns[i][:0]) if err != nil { return nil, err } if _, err := cn.ReadN(18); err != nil { return nil, err } } return columns, nil } func setByteSliceLen(b [][]byte, n int) [][]byte { if n <= cap(b) { return b[:n] } b = b[:cap(b)] b = append(b, make([][]byte, n-cap(b))...) 
return b } func readDataRow(cn *pool.Conn, scanner orm.ColumnScanner, columns [][]byte) (retErr error) { setErr := func(err error) { if retErr == nil { retErr = err } } colNum, err := readInt16(cn) if err != nil { return err } for colIdx := int16(0); colIdx < colNum; colIdx++ { l, err := readInt32(cn) if err != nil { return err } var b []byte if l != -1 { // NULL b, err = cn.ReadN(int(l)) if err != nil { return err } } column := internal.BytesToString(columns[colIdx]) if err := scanner.ScanColumn(int(colIdx), column, b); err != nil { setErr(err) } } return retErr } func newModel(mod interface{}) (orm.Model, error) { m, ok := mod.(orm.Model) if ok { return m, m.Reset() } m, err := orm.NewModel(mod) if err != nil { return nil, err } return m, m.Reset() } func readSimpleQueryData( cn *pool.Conn, mod interface{}, ) (res *types.Result, model orm.Model, retErr error) { setErr := func(err error) { if retErr == nil { retErr = err } } var rows int for { c, msgLen, err := readMessageType(cn) if err != nil { return nil, nil, err } switch c { case rowDescriptionMsg: cn.Columns, err = readRowDescription(cn, cn.Columns[:0]) if err != nil { return nil, nil, err } if model == nil { var err error model, err = newModel(mod) if err != nil { setErr(err) model = Discard } } case dataRowMsg: m := model.NewModel() if err := readDataRow(cn, m, cn.Columns); err != nil { setErr(err) } else { if err := model.AddModel(m); err != nil { setErr(err) } } rows++ case commandCompleteMsg: b, err := cn.ReadN(msgLen) if err != nil { return nil, nil, err } res = types.NewResult(b, rows) case readyForQueryMsg: _, err := cn.ReadN(msgLen) if err != nil { return nil, nil, err } return res, model, retErr case errorResponseMsg: e, err := readError(cn) if err != nil { return nil, nil, err } setErr(e) case noticeResponseMsg: if err := logNotice(cn, msgLen); err != nil { return nil, nil, err } case parameterStatusMsg: if err := logParameterStatus(cn, msgLen); err != nil { return nil, nil, err } default: return 
nil, nil, fmt.Errorf("pg: readSimpleQueryData: unexpected message %#x", c) } } } func readExtQueryData( cn *pool.Conn, mod interface{}, columns [][]byte, ) (res *types.Result, model orm.Model, retErr error) { setErr := func(err error) { if retErr == nil { retErr = err } } var rows int for { c, msgLen, err := readMessageType(cn) if err != nil { return nil, nil, err } switch c { case bindCompleteMsg: _, err := cn.ReadN(msgLen) if err != nil { return nil, nil, err } case dataRowMsg: if model == nil { var err error model, err = newModel(mod) if err != nil { setErr(err) model = Discard } } m := model.NewModel() if err := readDataRow(cn, m, columns); err != nil { setErr(err) } else { if err := model.AddModel(m); err != nil { setErr(err) } } rows++ case commandCompleteMsg: // Response to the EXECUTE message. b, err := cn.ReadN(msgLen) if err != nil { return nil, nil, err } res = types.NewResult(b, rows) case readyForQueryMsg: // Response to the SYNC message. _, err := cn.ReadN(msgLen) if err != nil { return nil, nil, err } return res, model, retErr case errorResponseMsg: e, err := readError(cn) if err != nil { return nil, nil, err } setErr(e) case noticeResponseMsg: if err := logNotice(cn, msgLen); err != nil { return nil, nil, err } case parameterStatusMsg: if err := logParameterStatus(cn, msgLen); err != nil { return nil, nil, err } default: return nil, nil, fmt.Errorf("pg: readExtQueryData: unexpected message %#x", c) } } } func readCopyInResponse(cn *pool.Conn) error { for { c, msgLen, err := readMessageType(cn) if err != nil { return err } switch c { case copyInResponseMsg: _, err := cn.ReadN(msgLen) return err case errorResponseMsg: e, err := readError(cn) if err != nil { return err } return e case noticeResponseMsg: if err := logNotice(cn, msgLen); err != nil { return err } case parameterStatusMsg: if err := logParameterStatus(cn, msgLen); err != nil { return err } default: return fmt.Errorf("pg: readCopyInResponse: unexpected message %#x", c) } } } func 
readCopyOutResponse(cn *pool.Conn) error { for { c, msgLen, err := readMessageType(cn) if err != nil { return err } switch c { case copyOutResponseMsg: _, err := cn.ReadN(msgLen) return err case errorResponseMsg: e, err := readError(cn) if err != nil { return err } return e case noticeResponseMsg: if err := logNotice(cn, msgLen); err != nil { return err } case parameterStatusMsg: if err := logParameterStatus(cn, msgLen); err != nil { return err } default: return fmt.Errorf("pg: readCopyOutResponse: unexpected message %#x", c) } } } func readCopyData(cn *pool.Conn, w io.Writer) (*types.Result, error) { var res *types.Result for { c, msgLen, err := readMessageType(cn) if err != nil { return nil, err } switch c { case copyDataMsg: b, err := cn.ReadN(msgLen) if err != nil { return nil, err } _, err = w.Write(b) if err != nil { return nil, err } case copyDoneMsg: _, err := cn.ReadN(msgLen) if err != nil { return nil, err } case commandCompleteMsg: b, err := cn.ReadN(msgLen) if err != nil { return nil, err } res = types.NewResult(b, 0) case readyForQueryMsg: _, err := cn.ReadN(msgLen) return res, err case errorResponseMsg: e, err := readError(cn) if err != nil { return nil, err } return nil, e case noticeResponseMsg: if err := logNotice(cn, msgLen); err != nil { return nil, err } case parameterStatusMsg: if err := logParameterStatus(cn, msgLen); err != nil { return nil, err } default: return nil, fmt.Errorf("pg: readCopyData: unexpected message %#x", c) } } } func writeCopyData(buf *pool.WriteBuffer, r io.Reader) (int64, error) { buf.StartMessage(copyDataMsg) n, err := buf.ReadFrom(r) buf.FinishMessage() return n, err } func writeCopyDone(buf *pool.WriteBuffer) { buf.StartMessage(copyDoneMsg) buf.FinishMessage() } func readReadyForQuery(cn *pool.Conn) (res *types.Result, retErr error) { for { c, msgLen, err := readMessageType(cn) if err != nil { return nil, err } switch c { case commandCompleteMsg: b, err := cn.ReadN(msgLen) if err != nil { return nil, err } res = 
types.NewResult(b, 0) case readyForQueryMsg: _, err := cn.ReadN(msgLen) if err != nil { return nil, err } return res, retErr case errorResponseMsg: e, err := readError(cn) if err != nil { return nil, err } retErr = e case noticeResponseMsg: if err := logNotice(cn, msgLen); err != nil { return nil, err } case parameterStatusMsg: if err := logParameterStatus(cn, msgLen); err != nil { return nil, err } default: return nil, fmt.Errorf("pg: readReadyForQueryOrError: unexpected message %#x", c) } } } func readNotification(cn *pool.Conn) (channel, payload string, err error) { for { c, msgLen, err := readMessageType(cn) if err != nil { return "", "", err } switch c { case commandCompleteMsg: _, err := cn.ReadN(msgLen) if err != nil { return "", "", err } case readyForQueryMsg: _, err := cn.ReadN(msgLen) if err != nil { return "", "", err } case errorResponseMsg: e, err := readError(cn) if err != nil { return "", "", err } return "", "", e case noticeResponseMsg: if err := logNotice(cn, msgLen); err != nil { return "", "", err } case notificationResponseMsg: _, err := readInt32(cn) if err != nil { return "", "", err } channel, err = readString(cn) if err != nil { return "", "", err } payload, err = readString(cn) if err != nil { return "", "", err } return channel, payload, nil default: return "", "", fmt.Errorf("pg: unexpected message %q", c) } } } var terminateMessage = []byte{terminateMsg, 0, 0, 0, 4} func terminateConn(cn *pool.Conn) error { // Don't use cn.Buf because it is racy with user code. 
_, err := cn.NetConn().Write(terminateMessage) return err } //------------------------------------------------------------------------------ func readInt16(cn *pool.Conn) (int16, error) { b, err := cn.ReadN(2) if err != nil { return 0, err } return int16(binary.BigEndian.Uint16(b)), nil } func readInt32(cn *pool.Conn) (int32, error) { b, err := cn.ReadN(4) if err != nil { return 0, err } return int32(binary.BigEndian.Uint32(b)), nil } func readString(cn *pool.Conn) (string, error) { s, err := cn.Rd.ReadString(0) if err != nil { return "", err } return s[:len(s)-1], nil } func readBytes(cn *pool.Conn, b []byte) ([]byte, error) { for { line, err := cn.Rd.ReadSlice(0) if err != nil && err != bufio.ErrBufferFull { return nil, err } b = append(b, line...) if err == nil { break } } return b[:len(b)-1], nil } func readError(cn *pool.Conn) (error, error) { m := map[byte]string{ 'a': cn.RemoteAddr().String(), } for { c, err := cn.Rd.ReadByte() if err != nil { return nil, err } if c == 0 { break } s, err := readString(cn) if err != nil { return nil, err } m[c] = s } return internal.NewPGError(m), nil } func readMessageType(cn *pool.Conn) (byte, int, error) { c, err := cn.Rd.ReadByte() if err != nil { return 0, 0, err } l, err := readInt32(cn) if err != nil { return 0, 0, err } return c, int(l) - 4, nil } func logNotice(cn *pool.Conn, msgLen int) error { _, err := cn.ReadN(msgLen) return err } func logParameterStatus(cn *pool.Conn, msgLen int) error { _, err := cn.ReadN(msgLen) return err } pg-5.3.3/options.go000066400000000000000000000114751305650307100141620ustar00rootroot00000000000000package pg import ( "crypto/tls" "errors" "fmt" "net" "net/url" "strings" "time" "gopkg.in/pg.v5/internal/pool" ) // Database connection options. type Options struct { // Network type, either tcp or unix. // Default is tcp. Network string // TCP host:port or Unix socket depending on Network. 
Addr string // Dialer creates new network connection and has priority over // Network and Addr options. Dialer func(network, addr string) (net.Conn, error) User string Password string Database string // TLS config for secure connections. TLSConfig *tls.Config // Maximum number of retries before giving up. // Default is to not retry failed queries. MaxRetries int // Whether to retry queries cancelled because of statement_timeout. RetryStatementTimeout bool // Dial timeout for establishing new connections. // Default is 5 seconds. DialTimeout time.Duration // Timeout for socket reads. If reached, commands will fail // with a timeout instead of blocking. ReadTimeout time.Duration // Timeout for socket writes. If reached, commands will fail // with a timeout instead of blocking. WriteTimeout time.Duration // Maximum number of socket connections. // Default is 20 connections. PoolSize int // Time for which client waits for free connection if all // connections are busy before returning an error. // Default is 5 seconds. PoolTimeout time.Duration // Time after which client closes idle connections. // Default is to not close idle connections. IdleTimeout time.Duration // Connection age at which client retires (closes) the connection. // Primarily useful with proxies like HAProxy. // Default is to not close aged connections. MaxAge time.Duration // Frequency of idle checks. // Default is 1 minute. IdleCheckFrequency time.Duration // When true Tx does not issue BEGIN, COMMIT, or ROLLBACK. // Also underlying database connection is immediately returned to the pool. // This is primarily useful for running your database tests in one big // transaction, because PostgreSQL does not support nested transactions. 
DisableTransaction bool } func (opt *Options) init() { if opt.Network == "" { opt.Network = "tcp" } if opt.Addr == "" { switch opt.Network { case "tcp": opt.Addr = "localhost:5432" case "unix": opt.Addr = "/var/run/postgresql/.s.PGSQL.5432" } } if opt.PoolSize == 0 { opt.PoolSize = 20 } if opt.PoolTimeout == 0 { if opt.ReadTimeout != 0 { opt.PoolTimeout = opt.ReadTimeout + time.Second } else { opt.PoolTimeout = 30 * time.Second } } if opt.DialTimeout == 0 { opt.DialTimeout = 5 * time.Second } if opt.IdleCheckFrequency == 0 { opt.IdleCheckFrequency = time.Minute } } // ParseURL parses an URL into options that can be used to connect to PostgreSQL. func ParseURL(sURL string) (*Options, error) { parsedUrl, err := url.Parse(sURL) if err != nil { return nil, err } // scheme if parsedUrl.Scheme != "postgres" { return nil, errors.New("pg: invalid scheme: " + parsedUrl.Scheme) } // host and port options := &Options{ Addr: parsedUrl.Host, } if !strings.Contains(options.Addr, ":") { options.Addr = options.Addr + ":5432" } // username and password if parsedUrl.User != nil { options.User = parsedUrl.User.Username() if password, ok := parsedUrl.User.Password(); ok { options.Password = password } } if options.User == "" { options.User = "postgres" } // database if len(strings.Trim(parsedUrl.Path, "/")) > 0 { options.Database = parsedUrl.Path[1:] } else { return nil, errors.New("pg: database name not provided") } // ssl mode query, err := url.ParseQuery(parsedUrl.RawQuery) if err != nil { return nil, err } if sslMode, ok := query["sslmode"]; ok && len(sslMode) > 0 { switch sslMode[0] { case "allow": fallthrough case "prefer": options.TLSConfig = &tls.Config{InsecureSkipVerify: true} case "disable": options.TLSConfig = nil default: return nil, errors.New(fmt.Sprintf("pg: sslmode '%v' is not supported", sslMode[0])) } } else { options.TLSConfig = &tls.Config{InsecureSkipVerify: true} } delete(query, "sslmode") if len(query) > 0 { return nil, errors.New("pg: options other than 
'sslmode' are not supported") } return options, nil } func (opt *Options) getDialer() func() (net.Conn, error) { if opt.Dialer != nil { return func() (net.Conn, error) { return opt.Dialer(opt.Network, opt.Addr) } } return func() (net.Conn, error) { return net.DialTimeout(opt.Network, opt.Addr, opt.DialTimeout) } } func newConnPool(opt *Options) *pool.ConnPool { return pool.NewConnPool(&pool.Options{ Dial: opt.getDialer(), PoolSize: opt.PoolSize, PoolTimeout: opt.PoolTimeout, IdleTimeout: opt.IdleTimeout, IdleCheckFrequency: opt.IdleCheckFrequency, OnClose: func(cn *pool.Conn) error { return terminateConn(cn) }, }) } pg-5.3.3/options_test.go000066400000000000000000000074221305650307100152160ustar00rootroot00000000000000// +build go1.7 package pg import ( "errors" "testing" ) func TestParseURL(t *testing.T) { cases := []struct { url string addr string user string password string database string tls bool err error }{ { "postgres://vasya:pupkin@somewhere.at.amazonaws.com:5432/postgres", "somewhere.at.amazonaws.com:5432", "vasya", "pupkin", "postgres", true, nil, }, { "postgres://vasya:pupkin@somewhere.at.amazonaws.com:5432/postgres?sslmode=allow", "somewhere.at.amazonaws.com:5432", "vasya", "pupkin", "postgres", true, nil, }, { "postgres://vasya:pupkin@somewhere.at.amazonaws.com:5432/postgres?sslmode=prefer", "somewhere.at.amazonaws.com:5432", "vasya", "pupkin", "postgres", true, nil, }, { "postgres://vasya:pupkin@somewhere.at.amazonaws.com:5432/postgres?sslmode=require", "somewhere.at.amazonaws.com:5432", "vasya", "pupkin", "postgres", true, errors.New("pg: sslmode 'require' is not supported"), }, { "postgres://vasya:pupkin@somewhere.at.amazonaws.com:5432/postgres?sslmode=verify-ca", "somewhere.at.amazonaws.com:5432", "vasya", "pupkin", "postgres", true, errors.New("pg: sslmode 'verify-ca' is not supported"), }, { "postgres://vasya:pupkin@somewhere.at.amazonaws.com:5432/postgres?sslmode=verify-full", "somewhere.at.amazonaws.com:5432", "vasya", "pupkin", "postgres", 
true, errors.New("pg: sslmode 'verify-full' is not supported"), }, { "postgres://vasya:pupkin@somewhere.at.amazonaws.com:5432/postgres?sslmode=disable", "somewhere.at.amazonaws.com:5432", "vasya", "pupkin", "postgres", false, nil, }, { "postgres://vasya:pupkin@somewhere.at.amazonaws.com:5432/", "somewhere.at.amazonaws.com:5432", "vasya", "pupkin", "", true, errors.New("pg: database name not provided"), }, { "postgres://vasya:pupkin@somewhere.at.amazonaws.com/postgres", "somewhere.at.amazonaws.com:5432", "vasya", "pupkin", "postgres", true, nil, }, { "postgres://vasya:pupkin@somewhere.at.amazonaws.com:5432/postgres?abc=123", "somewhere.at.amazonaws.com:5432", "vasya", "pupkin", "postgres", true, errors.New("pg: options other than 'sslmode' are not supported"), }, { "postgres://vasya@somewhere.at.amazonaws.com:5432/postgres", "somewhere.at.amazonaws.com:5432", "vasya", "", "postgres", true, nil, }, { "postgres://somewhere.at.amazonaws.com:5432/postgres", "somewhere.at.amazonaws.com:5432", "postgres", "", "postgres", true, nil, }, { "http://google.com/test", "google.com:5432", "postgres", "", "test", true, errors.New("pg: invalid scheme: http"), }, } for _, c := range cases { t.Run(c.url, func(t *testing.T) { o, err := ParseURL(c.url) if c.err == nil && err != nil { t.Fatalf("unexpected error: '%q'", err) return } if c.err != nil && err != nil { if c.err.Error() != err.Error() { t.Fatalf("expected error %q, want %q", err, c.err) } return } if c.err != nil && err == nil { t.Errorf("expected error %q, got nothing", c.err) } if o.Addr != c.addr { t.Errorf("addr: got %q, want %q", o.Addr, c.addr) } if o.User != c.user { t.Errorf("user: got %q, want %q", o.User, c.user) } if o.Password != c.password { t.Errorf("password: got %q, want %q", o.Password, c.password) } if o.Database != c.database { t.Errorf("database: got %q, want %q", o.Database, c.database) } if c.tls { if o.TLSConfig == nil { t.Error("got nil TLSConfig, expected a TLSConfig") } else if 
!o.TLSConfig.InsecureSkipVerify { t.Error("must set InsecureSkipVerify to true in TLSConfig, got false") } } }) } } pg-5.3.3/orm/000077500000000000000000000000001305650307100127255ustar00rootroot00000000000000pg-5.3.3/orm/count_estimate.go000066400000000000000000000047531305650307100163100ustar00rootroot00000000000000package orm import ( "fmt" "sync" "gopkg.in/pg.v5/internal" ) // Placeholder that is replaced with count(*). const placeholder = `'_go_pg_placeholder'` // https://wiki.postgresql.org/wiki/Count_estimate var pgCountEstimateFunc = fmt.Sprintf(` CREATE OR REPLACE FUNCTION _go_pg_count_estimate_v2(query text, threshold int) RETURNS int AS $$ DECLARE rec record; nrows int; BEGIN FOR rec IN EXECUTE 'EXPLAIN ' || query LOOP nrows := substring(rec."QUERY PLAN" FROM ' rows=(\d+)'); EXIT WHEN nrows IS NOT NULL; END LOOP; -- Return the estimation if there are too many rows. IF nrows > threshold THEN RETURN nrows; END IF; -- Otherwise execute real count query. query := replace(query, 'SELECT '%s'', 'SELECT count(*)'); EXECUTE query INTO nrows; IF nrows IS NULL THEN nrows := 0; END IF; RETURN nrows; END; $$ LANGUAGE plpgsql; `, placeholder) // CountEstimate uses EXPLAIN to get estimated number of rows matching the query. // If that number is bigger than the threshold it returns the estimation. // Otherwise it executes another query using count aggregate function and // returns the result. 
// // Based on https://wiki.postgresql.org/wiki/Count_estimate func (q *Query) CountEstimate(threshold int) (int, error) { if q.stickyErr != nil { return 0, q.stickyErr } query, err := q.countQuery().countSelectQuery(placeholder).AppendQuery(nil) if err != nil { return 0, err } for i := 0; i < 3; i++ { var count int _, err = q.db.QueryOne( Scan(&count), "SELECT _go_pg_count_estimate_v2(?, ?)", string(query), threshold, ) if err != nil { if pgerr, ok := err.(internal.PGError); ok && pgerr.Field('C') == "42883" { // undefined_function if err := q.createCountEstimateFunc(); err != nil { return 0, err } continue } } return count, err } return 0, err } func (q *Query) createCountEstimateFunc() error { _, err := q.db.Exec(pgCountEstimateFunc) return err } // SelectAndCountEstimate runs Select and CountEstimate in two goroutines, // waits for them to finish and returns the result. func (q *Query) SelectAndCountEstimate(threshold int, values ...interface{}) (count int, err error) { if q.stickyErr != nil { return 0, q.stickyErr } var wg sync.WaitGroup wg.Add(2) go func() { defer wg.Done() if e := q.Select(values...); e != nil { err = e } }() go func() { defer wg.Done() var e error count, e = q.CountEstimate(threshold) if e != nil { err = e } }() wg.Wait() return count, err } pg-5.3.3/orm/create_table.go000066400000000000000000000030201305650307100156610ustar00rootroot00000000000000package orm import ( "fmt" "reflect" "gopkg.in/pg.v5/types" ) type CreateTableOptions struct { Temp bool } func CreateTable(db DB, model interface{}, opt *CreateTableOptions) (*types.Result, error) { return db.Exec(createTableQuery{model: model, opt: opt}) } type createTableQuery struct { model interface{} opt *CreateTableOptions } func (c createTableQuery) AppendQuery(b []byte, params ...interface{}) ([]byte, error) { typ := reflect.TypeOf(c.model) switch typ.Kind() { case reflect.Ptr: typ = typ.Elem() } if typ.Kind() != reflect.Struct { return nil, fmt.Errorf("pg: Model(unsupported %s)", typ) } 
table := Tables.Get(typ) b = append(b, "CREATE "...) if c.opt != nil && c.opt.Temp { b = append(b, "TEMP "...) } b = append(b, "TABLE "...) b = append(b, table.Name...) b = append(b, " ("...) for i, field := range table.Fields { b = append(b, field.SQLName...) b = append(b, " "...) b = append(b, field.SQLType...) if field.Has(NotNullFlag) { b = append(b, " NOT NULL"...) } if field.Has(UniqueFlag) { b = append(b, " UNIQUE"...) } if i != len(table.Fields)-1 { b = append(b, ", "...) } } b = appendPKConstraint(b, table.PKs) b = append(b, ")"...) return b, nil } func appendPKConstraint(b []byte, primaryKeys []*Field) []byte { if len(primaryKeys) == 0 { return b } b = append(b, ", PRIMARY KEY ("...) for i, pk := range primaryKeys { b = append(b, pk.SQLName...) if i != len(primaryKeys)-1 { b = append(b, ", "...) } } b = append(b, ")"...) return b } pg-5.3.3/orm/create_table_test.go000066400000000000000000000032371305650307100167320ustar00rootroot00000000000000package orm import ( "database/sql" "time" . "github.com/onsi/ginkgo" . 
"github.com/onsi/gomega" ) type CreateTableModel struct { Id int Int8 int8 Uint8 uint8 Int16 int16 Uint16 uint16 Int32 int32 Uint32 uint32 Int64 int64 Uint64 uint64 Float32 float32 Float64 float64 String string Varchar string `sql:",type:varchar(500)"` Time time.Time NotNull int `sql:",notnull"` Unique int `sql:",unique"` NullBool sql.NullBool NullFloat64 sql.NullFloat64 NullInt64 sql.NullInt64 NullString sql.NullString Slice []int Map map[int]int Struct struct{} } type CreateTableWithoutPKModel struct { String string } var _ = Describe("CreateTable", func() { It("creates new table", func() { b, err := createTableQuery{model: CreateTableModel{}}.AppendQuery(nil) Expect(err).NotTo(HaveOccurred()) Expect(string(b)).To(Equal(`CREATE TABLE "create_table_models" (id bigserial, int8 smallint, uint8 smallint, int16 smallint, uint16 integer, int32 integer, uint32 bigint, int64 bigint, uint64 decimal, float32 real, float64 double precision, string text, varchar varchar(500), time timestamptz, not_null bigint NOT NULL, unique bigint UNIQUE, null_bool boolean, null_float64 double precision, null_int64 bigint, null_string text, slice jsonb, map jsonb, struct jsonb, PRIMARY KEY (id))`)) }) It("creates new table without primary key", func() { b, err := createTableQuery{model: CreateTableWithoutPKModel{}}.AppendQuery(nil) Expect(err).NotTo(HaveOccurred()) Expect(string(b)).To(Equal(`CREATE TABLE "create_table_without_pk_models" (string text)`)) }) }) pg-5.3.3/orm/delete.go000066400000000000000000000014321305650307100145160ustar00rootroot00000000000000package orm import "gopkg.in/pg.v5/internal" func Delete(db DB, model interface{}) error { res, err := NewQuery(db, model).Delete() if err != nil { return err } return internal.AssertOneRow(res.RowsAffected()) } type deleteQuery struct { *Query } var _ QueryAppender = (*deleteQuery)(nil) func (q deleteQuery) AppendQuery(b []byte, params ...interface{}) ([]byte, error) { var err error if len(q.with) > 0 { b, err = q.appendWith(b, "") 
if err != nil { return nil, err } } b = append(b, "DELETE FROM "...) b = q.appendFirstTable(b) if q.hasOtherTables() { b = append(b, " USING "...) b = q.appendOtherTables(b) } b, err = q.mustAppendWhere(b) if err != nil { return nil, err } if len(q.returning) > 0 { b = q.appendReturning(b) } return b, nil } pg-5.3.3/orm/delete_test.go000066400000000000000000000011321305650307100155520ustar00rootroot00000000000000package orm import ( . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) type DeleteTest struct{} var _ = Describe("Delete", func() { It("supports WITH", func() { q := NewQuery(nil, &DeleteTest{}). WrapWith("wrapper"). Model(&DeleteTest{}). Table("wrapper"). Where("delete_test.id = wrapper.id") b, err := deleteQuery{q}.AppendQuery(nil) Expect(err).NotTo(HaveOccurred()) Expect(string(b)).To(Equal(`WITH "wrapper" AS (SELECT FROM "delete_tests" AS "delete_test") DELETE FROM "delete_tests" AS "delete_test" USING "wrapper" WHERE (delete_test.id = wrapper.id)`)) }) }) pg-5.3.3/orm/field.go000066400000000000000000000032051305650307100143370ustar00rootroot00000000000000package orm import ( "reflect" "gopkg.in/pg.v5/types" ) const ( PrimaryKeyFlag = 1 << iota ForeignKeyFlag NotNullFlag UniqueFlag ) type Field struct { Type reflect.Type GoName string // struct field name, e.g. Id ColName types.Q SQLName string // SQL name, .e.g. 
id SQLType string Index []int flags uint8 append types.AppenderFunc scan types.ScannerFunc isEmpty func(reflect.Value) bool } func (f *Field) Copy() *Field { copy := *f copy.Index = copy.Index[:len(f.Index):len(f.Index)] return © } func (f *Field) Has(flag uint8) bool { return f.flags&flag != 0 } func (f *Field) Value(strct reflect.Value) reflect.Value { return strct.FieldByIndex(f.Index) } func (f *Field) IsEmpty(strct reflect.Value) bool { fv := f.Value(strct) return f.isEmpty(fv) } func (f *Field) OmitEmpty(strct reflect.Value) bool { return !f.Has(NotNullFlag) && f.isEmpty(f.Value(strct)) } func (f *Field) AppendValue(b []byte, strct reflect.Value, quote int) []byte { fv := f.Value(strct) if !f.Has(NotNullFlag) && f.isEmpty(fv) { return types.AppendNull(b, quote) } return f.append(b, fv, quote) } func (f *Field) ScanValue(strct reflect.Value, b []byte) error { fv := fieldByIndex(strct, f.Index) return f.scan(fv, b) } type Method struct { Index int flags int8 appender func([]byte, reflect.Value, int) []byte } func (m *Method) Has(flag int8) bool { return m.flags&flag != 0 } func (m *Method) Value(strct reflect.Value) reflect.Value { return strct.Method(m.Index).Call(nil)[0] } func (m *Method) AppendValue(dst []byte, strct reflect.Value, quote int) []byte { mv := m.Value(strct) return m.appender(dst, mv, quote) } pg-5.3.3/orm/format.go000066400000000000000000000115501305650307100145460ustar00rootroot00000000000000package orm import ( "bytes" "fmt" "sort" "strconv" "strings" "gopkg.in/pg.v5/internal/parser" "gopkg.in/pg.v5/types" ) type FormatAppender interface { AppendFormat([]byte, QueryFormatter) []byte } type sepFormatAppender interface { FormatAppender AppendSep([]byte) []byte } //------------------------------------------------------------------------------ type queryParamsAppender struct { query string params []interface{} } var _ FormatAppender = (*queryParamsAppender)(nil) func Q(query string, params ...interface{}) FormatAppender { return 
queryParamsAppender{query, params} } func (q queryParamsAppender) AppendFormat(b []byte, f QueryFormatter) []byte { return f.FormatQuery(b, q.query, q.params...) } //------------------------------------------------------------------------------ type whereAppender struct { conj string query string params []interface{} } var _ FormatAppender = (*whereAppender)(nil) func (q whereAppender) AppendSep(b []byte) []byte { return append(b, q.conj...) } func (q whereAppender) AppendFormat(b []byte, f QueryFormatter) []byte { b = append(b, '(') b = f.FormatQuery(b, q.query, q.params...) b = append(b, ')') return b } //------------------------------------------------------------------------------ type fieldAppender struct { field string } var _ FormatAppender = (*fieldAppender)(nil) func (a fieldAppender) AppendFormat(b []byte, f QueryFormatter) []byte { return types.AppendField(b, a.field, 1) } //------------------------------------------------------------------------------ type Formatter struct { namedParams map[string]interface{} } func (f Formatter) String() string { if len(f.namedParams) == 0 { return "" } var keys []string for k, _ := range f.namedParams { keys = append(keys, k) } sort.Strings(keys) var ss []string for _, k := range keys { ss = append(ss, fmt.Sprintf("%s=%v", k, f.namedParams[k])) } return " " + strings.Join(ss, " ") } func (f Formatter) Copy() Formatter { var cp Formatter for param, value := range f.namedParams { cp.SetParam(param, value) } return cp } func (f *Formatter) SetParam(param string, value interface{}) { if f.namedParams == nil { f.namedParams = make(map[string]interface{}) } f.namedParams[param] = value } func (f *Formatter) WithParam(param string, value interface{}) Formatter { cp := f.Copy() cp.SetParam(param, value) return cp } func (f Formatter) Append(dst []byte, src string, params ...interface{}) []byte { if (params == nil && f.namedParams == nil) || strings.IndexByte(src, '?') == -1 { return append(dst, src...) 
} return f.append(dst, parser.NewString(src), params) } func (f Formatter) AppendBytes(dst, src []byte, params ...interface{}) []byte { if (params == nil && f.namedParams == nil) || bytes.IndexByte(src, '?') == -1 { return append(dst, src...) } return f.append(dst, parser.New(src), params) } func (f Formatter) FormatQuery(dst []byte, query string, params ...interface{}) []byte { return f.Append(dst, query, params...) } func (f Formatter) append(dst []byte, p *parser.Parser, params []interface{}) []byte { var paramsIndex int var namedParams *tableParams var namedParamsInit bool var model tableModel if len(params) > 0 { var ok bool model, ok = params[len(params)-1].(tableModel) if ok { params = params[:len(params)-1] } } for p.Valid() { b, ok := p.ReadSep('?') if !ok { dst = append(dst, b...) continue } if len(b) > 0 && b[len(b)-1] == '\\' { dst = append(dst, b[:len(b)-1]...) dst = append(dst, '?') continue } dst = append(dst, b...) if id, numeric := p.ReadIdentifier(); id != "" { if numeric { idx, err := strconv.Atoi(id) if err != nil { goto restore_param } if idx >= len(params) { goto restore_param } dst = f.appendParam(dst, params[idx]) continue } if f.namedParams != nil { if param, ok := f.namedParams[id]; ok { dst = f.appendParam(dst, param) continue } } if !namedParamsInit && len(params) > 0 { namedParams, ok = newTableParams(params[len(params)-1]) if ok { params = params[:len(params)-1] } namedParamsInit = true } if namedParams != nil { dst, ok = namedParams.AppendParam(dst, id) if ok { continue } } if model != nil { dst, ok = model.AppendParam(dst, id) if ok { continue } } restore_param: dst = append(dst, '?') dst = append(dst, id...) 
continue } if paramsIndex >= len(params) { dst = append(dst, '?') continue } param := params[paramsIndex] paramsIndex++ if fa, ok := param.(FormatAppender); ok { dst = fa.AppendFormat(dst, f) } else { dst = types.Append(dst, param, 1) } } return dst } func (f Formatter) appendParam(b []byte, param interface{}) []byte { if fa, ok := param.(FormatAppender); ok { return fa.AppendFormat(b, f) } return types.Append(b, param, 1) } pg-5.3.3/orm/format_test.go000066400000000000000000000115571305650307100156140ustar00rootroot00000000000000package orm_test import ( "database/sql/driver" "errors" "fmt" "math" "testing" "gopkg.in/pg.v5/orm" "gopkg.in/pg.v5/types" ) type ValuerError string func (e ValuerError) Value() (driver.Value, error) { return nil, errors.New(string(e)) } type StructFormatter struct { tableName struct{} `sql:"my_name,alias:my_alias" json:"-"` String string NotNull string `sql:",notnull"` Iface interface{} } func (StructFormatter) Method() string { return "method_value" } func (StructFormatter) MethodParam() types.Q { return types.Q("?string") } func (StructFormatter) MethodWithArgs(string) string { return "method_value" } func (StructFormatter) MethodWithCompositeReturn() (string, string) { return "method_value1", "method_value2" } type EmbeddedStructFormatter struct { *StructFormatter } func (EmbeddedStructFormatter) Method2() string { return "method_value2" } type params []interface{} type paramsMap map[string]interface{} type formatTest struct { q string params params paramsMap paramsMap wanted string wanterr string } var ( structv = &StructFormatter{ String: "string_value", Iface: "iface_value", } embeddedStructv = &EmbeddedStructFormatter{structv} ) var formatTests = []formatTest{ {q: "?", params: params{ValuerError("error")}, wanted: "?!(error)"}, {q: "?", wanted: "?"}, {q: "? ? 
?", params: params{"foo", "bar"}, wanted: "'foo' 'bar' ?"}, {q: "?0 ?1", params: params{"foo", "bar"}, wanted: "'foo' 'bar'"}, {q: "?0 ?1 ?2", params: params{"foo", "bar"}, wanted: "'foo' 'bar' ?2"}, {q: "?0 ?1 ?0", params: params{"foo", "bar"}, wanted: "'foo' 'bar' 'foo'"}, {q: "one ?foo two", wanted: "one ?foo two"}, {q: "one ?foo two", params: params{structv}, wanted: "one ?foo two"}, {q: "one ?MethodWithArgs two", params: params{structv}, wanted: "one ?MethodWithArgs two"}, {q: "one ?MethodWithCompositeReturn two", params: params{structv}, wanted: "one ?MethodWithCompositeReturn two"}, {q: "?", params: params{uint64(math.MaxUint64)}, wanted: "18446744073709551615"}, {q: "?", params: params{orm.Q("query")}, wanted: "query"}, {q: "?", params: params{types.F("field")}, wanted: `"field"`}, {q: "?", params: params{structv}, wanted: `'{"String":"string_value","NotNull":"","Iface":"iface_value"}'`}, {q: `\? ?`, params: params{1}, wanted: "? 1"}, {q: `?`, params: params{types.Q(`\?`)}, wanted: `\?`}, {q: `?`, params: params{types.Q(`\\?`)}, wanted: `\\?`}, {q: `?`, params: params{types.Q(`\?param`)}, wanted: `\?param`}, {q: "?string", params: params{structv}, wanted: `'string_value'`}, {q: "?iface", params: params{structv}, wanted: `'iface_value'`}, {q: "?string", params: params{&StructFormatter{}}, wanted: `NULL`}, {q: "? 
?string ?", params: params{"one", "two", structv}, wanted: "'one' 'string_value' 'two'"}, {q: "?string ?Method", params: params{structv}, wanted: "'string_value' 'method_value'"}, {q: "?string ?Method ?Method2", params: params{embeddedStructv}, wanted: "'string_value' 'method_value' 'method_value2'"}, {q: "?string", params: params{structv}, paramsMap: paramsMap{"string": "my_value"}, wanted: "'my_value'"}, {q: "?", params: params{types.Q("?string")}, paramsMap: paramsMap{"string": "my_value"}, wanted: "?string"}, {q: "?", params: params{types.F("?string")}, paramsMap: paramsMap{"string": types.Q("my_value")}, wanted: `"?string"`}, {q: "?", params: params{orm.Q("?string")}, paramsMap: paramsMap{"string": "my_value"}, wanted: "'my_value'"}, {q: "?MethodParam", params: params{structv}, paramsMap: paramsMap{"string": "my_value"}, wanted: "?string"}, } func TestFormatQuery(t *testing.T) { for _, test := range formatTests { var f orm.Formatter for k, v := range test.paramsMap { f.SetParam(k, v) } got := f.Append(nil, test.q, test.params...) 
if string(got) != test.wanted { t.Fatalf( "got %q, wanted %q (q=%q params=%v paramsMap=%v)", got, test.wanted, test.q, test.params, test.paramsMap, ) } } } func BenchmarkFormatQueryWithoutParams(b *testing.B) { for i := 0; i < b.N; i++ { _ = orm.Q("SELECT * FROM my_table WHERE id = 1") } } func BenchmarkFormatQuery1Param(b *testing.B) { for i := 0; i < b.N; i++ { _ = orm.Q("SELECT * FROM my_table WHERE id = ?", 1) } } func BenchmarkFormatQuery10Params(b *testing.B) { for i := 0; i < b.N; i++ { _ = orm.Q( "SELECT * FROM my_table WHERE id IN (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, ) } } func BenchmarkFormatQuerySprintf(b *testing.B) { for i := 0; i < b.N; i++ { _ = fmt.Sprintf("SELECT * FROM my_table WHERE id = %d", 1) } } func BenchmarkFormatQueryStructParam(b *testing.B) { param := StructFormatter{ String: "1", } for i := 0; i < b.N; i++ { _ = orm.Q("SELECT * FROM my_table WHERE id = ?string", param) } } func BenchmarkFormatQueryStructMethod(b *testing.B) { param := StructFormatter{} for i := 0; i < b.N; i++ { _ = orm.Q("SELECT * FROM my_table WHERE id = ?Method", ¶m) } } pg-5.3.3/orm/hook.go000066400000000000000000000101601305650307100142120ustar00rootroot00000000000000package orm import "reflect" const ( AfterQueryHookFlag = 1 << iota AfterSelectHookFlag BeforeInsertHookFlag AfterInsertHookFlag BeforeUpdateHookFlag AfterUpdateHookFlag BeforeDeleteHookFlag AfterDeleteHookFlag ) type hookStubs struct{} func (hookStubs) AfterQuery(_ DB) error { return nil } func (hookStubs) AfterSelect(_ DB) error { return nil } func (hookStubs) BeforeInsert(_ DB) error { return nil } func (hookStubs) AfterInsert(_ DB) error { return nil } func (hookStubs) BeforeUpdate(_ DB) error { return nil } func (hookStubs) AfterUpdate(_ DB) error { return nil } func (hookStubs) BeforeDelete(_ DB) error { return nil } func (hookStubs) AfterDelete(_ DB) error { return nil } func callHookSlice(slice reflect.Value, ptr bool, db DB, hook func(reflect.Value, DB) error) error 
{ var firstErr error for i := 0; i < slice.Len(); i++ { var err error if ptr { err = hook(slice.Index(i), db) } else { err = hook(slice.Index(i).Addr(), db) } if err != nil && firstErr == nil { firstErr = err } } return firstErr } type afterQueryHook interface { AfterQuery(db DB) error } var afterQueryHookType = reflect.TypeOf((*afterQueryHook)(nil)).Elem() func callAfterQueryHook(v reflect.Value, db DB) error { return v.Interface().(afterQueryHook).AfterQuery(db) } func callAfterQueryHookSlice(slice reflect.Value, ptr bool, db DB) error { return callHookSlice(slice, ptr, db, callAfterQueryHook) } type afterSelectHook interface { AfterSelect(db DB) error } var afterSelectHookType = reflect.TypeOf((*afterSelectHook)(nil)).Elem() func callAfterSelectHook(v reflect.Value, db DB) error { return v.Interface().(afterSelectHook).AfterSelect(db) } func callAfterSelectHookSlice(slice reflect.Value, ptr bool, db DB) error { return callHookSlice(slice, ptr, db, callAfterSelectHook) } type beforeInsertHook interface { BeforeInsert(db DB) error } var beforeInsertHookType = reflect.TypeOf((*beforeInsertHook)(nil)).Elem() func callBeforeInsertHook(v reflect.Value, db DB) error { return v.Interface().(beforeInsertHook).BeforeInsert(db) } func callBeforeInsertHookSlice(slice reflect.Value, ptr bool, db DB) error { return callHookSlice(slice, ptr, db, callBeforeInsertHook) } type afterInsertHook interface { AfterInsert(db DB) error } var afterInsertHookType = reflect.TypeOf((*afterInsertHook)(nil)).Elem() func callAfterInsertHook(v reflect.Value, db DB) error { return v.Interface().(afterInsertHook).AfterInsert(db) } func callAfterInsertHookSlice(slice reflect.Value, ptr bool, db DB) error { return callHookSlice(slice, ptr, db, callAfterInsertHook) } type beforeUpdateHook interface { BeforeUpdate(db DB) error } var beforeUpdateHookType = reflect.TypeOf((*beforeUpdateHook)(nil)).Elem() func callBeforeUpdateHook(v reflect.Value, db DB) error { return 
v.Interface().(beforeUpdateHook).BeforeUpdate(db) } func callBeforeUpdateHookSlice(slice reflect.Value, ptr bool, db DB) error { return callHookSlice(slice, ptr, db, callBeforeUpdateHook) } type afterUpdateHook interface { AfterUpdate(db DB) error } var afterUpdateHookType = reflect.TypeOf((*afterUpdateHook)(nil)).Elem() func callAfterUpdateHook(v reflect.Value, db DB) error { return v.Interface().(afterUpdateHook).AfterUpdate(db) } func callAfterUpdateHookSlice(slice reflect.Value, ptr bool, db DB) error { return callHookSlice(slice, ptr, db, callAfterUpdateHook) } type beforeDeleteHook interface { BeforeDelete(db DB) error } var beforeDeleteHookType = reflect.TypeOf((*beforeDeleteHook)(nil)).Elem() func callBeforeDeleteHook(v reflect.Value, db DB) error { return v.Interface().(beforeDeleteHook).BeforeDelete(db) } func callBeforeDeleteHookSlice(slice reflect.Value, ptr bool, db DB) error { return callHookSlice(slice, ptr, db, callBeforeDeleteHook) } type afterDeleteHook interface { AfterDelete(db DB) error } var afterDeleteHookType = reflect.TypeOf((*afterDeleteHook)(nil)).Elem() func callAfterDeleteHook(v reflect.Value, db DB) error { return v.Interface().(afterDeleteHook).AfterDelete(db) } func callAfterDeleteHookSlice(slice reflect.Value, ptr bool, db DB) error { return callHookSlice(slice, ptr, db, callAfterDeleteHook) } pg-5.3.3/orm/insert.go000066400000000000000000000045431305650307100145660ustar00rootroot00000000000000package orm import ( "bytes" "errors" "reflect" ) func Insert(db DB, v ...interface{}) error { _, err := NewQuery(db, v...).Insert() return err } type insertQuery struct { *Query returningFields []*Field } var _ QueryAppender = (*insertQuery)(nil) func (q insertQuery) AppendQuery(b []byte, params ...interface{}) ([]byte, error) { if q.model == nil { return nil, errors.New("pg: Model(nil)") } table := q.model.Table() value := q.model.Value() b = append(b, "INSERT INTO "...) 
if q.onConflict != nil { b = q.appendTableNameWithAlias(b) } else { b = q.appendTableName(b) } b = append(b, " ("...) start := len(b) for _, f := range table.Fields { b = append(b, f.ColName...) b = append(b, ", "...) } if len(b) > start { b = b[:len(b)-2] } b = append(b, ") VALUES ("...) if value.Kind() == reflect.Struct { b = q.appendValues(b, table.Fields, value) } else { for i := 0; i < value.Len(); i++ { el := value.Index(i) if el.Kind() == reflect.Interface { el = el.Elem() } b = q.appendValues(b, table.Fields, reflect.Indirect(el)) if i != value.Len()-1 { b = append(b, "), ("...) } } } b = append(b, ')') if q.onConflict != nil { b = append(b, " ON CONFLICT "...) b = q.onConflict.AppendFormat(b, q) if onConflictDoUpdate(b) { if len(q.set) > 0 { b = q.appendSet(b) } if len(q.where) > 0 { b = q.appendWhere(b) } } } if len(q.returning) > 0 { b = q.appendReturning(b) } else if len(q.returningFields) > 0 { b = q.appendReturningFields(b, q.returningFields) } return b, nil } func onConflictDoUpdate(b []byte) bool { return bytes.HasSuffix(b, []byte(" DO UPDATE")) } func (q *insertQuery) appendValues(b []byte, fields []*Field, v reflect.Value) []byte { for i, f := range fields { if i > 0 { b = append(b, ", "...) } if f.OmitEmpty(v) { b = append(b, "DEFAULT"...) q.addReturningField(f) } else { b = f.AppendValue(b, v, 1) } } return b } func (ins *insertQuery) addReturningField(field *Field) { for _, f := range ins.returningFields { if f == field { return } } ins.returningFields = append(ins.returningFields, field) } func (insertQuery) appendReturningFields(b []byte, fields []*Field) []byte { b = append(b, " RETURNING "...) for i, f := range fields { if i > 0 { b = append(b, ", "...) } b = append(b, f.ColName...) } return b } pg-5.3.3/orm/insert_test.go000066400000000000000000000055451305650307100156300ustar00rootroot00000000000000package orm import ( "gopkg.in/pg.v5/types" . "github.com/onsi/ginkgo" . 
"github.com/onsi/gomega" ) type InsertTest struct{} type EmbeddingTest struct { tableName struct{} `sql:"name"` Id int Field int } type EmbeddedInsertTest struct { tableName struct{} `sql:"my_name"` EmbeddingTest Field2 int } type OverrideInsertTest struct { EmbeddingTest `pg:",override"` Field2 int } type InsertNullTest struct { F1 int F2 int `sql:",notnull"` F3 int `sql:",pk"` F4 int `sql:",pk,notnull"` } type InsertQTest struct { Geo types.Q } var _ = Describe("Insert", func() { It("supports ON CONFLICT DO UPDATE", func() { q := NewQuery(nil, &InsertTest{}). OnConflict("(unq1) DO UPDATE"). Set("count1 = count1 + 1"). Where("cond1 IS TRUE") b, err := insertQuery{Query: q}.AppendQuery(nil) Expect(err).NotTo(HaveOccurred()) Expect(string(b)).To(Equal(`INSERT INTO "insert_tests" AS "insert_test" () VALUES () ON CONFLICT (unq1) DO UPDATE SET count1 = count1 + 1 WHERE (cond1 IS TRUE)`)) }) It("supports ON CONFLICT DO NOTHING", func() { q := NewQuery(nil, &InsertTest{}). OnConflict("(unq1) DO NOTHING"). Set("count1 = count1 + 1"). 
Where("cond1 IS TRUE") b, err := insertQuery{Query: q}.AppendQuery(nil) Expect(err).NotTo(HaveOccurred()) Expect(string(b)).To(Equal(`INSERT INTO "insert_tests" AS "insert_test" () VALUES () ON CONFLICT (unq1) DO NOTHING`)) }) It("supports custom table name on embedded struct", func() { q := NewQuery(nil, &EmbeddedInsertTest{}) b, err := insertQuery{Query: q}.AppendQuery(nil) Expect(err).NotTo(HaveOccurred()) Expect(string(b)).To(Equal(`INSERT INTO my_name ("id", "field", "field2") VALUES (DEFAULT, DEFAULT, DEFAULT) RETURNING "id", "field", "field2"`)) }) It("supports override table name with embedded struct", func() { q := NewQuery(nil, &OverrideInsertTest{}) b, err := insertQuery{Query: q}.AppendQuery(nil) Expect(err).NotTo(HaveOccurred()) Expect(string(b)).To(Equal(`INSERT INTO name ("id", "field", "field2") VALUES (DEFAULT, DEFAULT, DEFAULT) RETURNING "id", "field", "field2"`)) }) It("supports notnull", func() { q := NewQuery(nil, &InsertNullTest{}) b, err := insertQuery{Query: q}.AppendQuery(nil) Expect(err).NotTo(HaveOccurred()) Expect(string(b)).To(Equal(`INSERT INTO "insert_null_tests" ("f1", "f2", "f3", "f4") VALUES (DEFAULT, 0, DEFAULT, 0) RETURNING "f1", "f3"`)) }) It("inserts types.Q", func() { q := NewQuery(nil, &InsertQTest{ Geo: types.Q("ST_GeomFromText('POLYGON((75.150000 29.530000, 77.000000 29.000000, 77.600000 29.500000, 75.150000 29.530000))')"), }) b, err := insertQuery{Query: q}.AppendQuery(nil) Expect(err).NotTo(HaveOccurred()) Expect(string(b)).To(Equal(`INSERT INTO "insert_q_tests" ("geo") VALUES (ST_GeomFromText('POLYGON((75.150000 29.530000, 77.000000 29.000000, 77.600000 29.500000, 75.150000 29.530000))'))`)) }) }) pg-5.3.3/orm/join.go000066400000000000000000000123711305650307100142170ustar00rootroot00000000000000package orm import "gopkg.in/pg.v5/types" type join struct { Parent *join BaseModel tableModel JoinModel tableModel Rel *Relation ApplyQuery func(*Query) (*Query, error) Columns []string } func (j *join) Select(db DB) error { 
switch j.Rel.Type { case HasManyRelation: return j.selectMany(db) case Many2ManyRelation: return j.selectM2M(db) } panic("not reached") } func (j *join) selectMany(db DB) error { q, err := j.manyQuery(db) if err != nil { return err } err = q.Select() if err != nil { return err } return nil } func (j *join) manyQuery(db DB) (*Query, error) { root := j.JoinModel.Root() index := j.JoinModel.ParentIndex() manyModel := newManyModel(j) q := NewQuery(db, manyModel) if j.ApplyQuery != nil { var err error q, err = j.ApplyQuery(q) if err != nil { return nil, err } } q.columns = append(q.columns, hasManyColumnsAppender{j}) baseTable := j.BaseModel.Table() cols := columns(j.JoinModel.Table().Alias, "", j.Rel.FKs) vals := values(root, index, baseTable.PKs) q = q.Where(`(?) IN (?)`, types.Q(cols), types.Q(vals)) if j.Rel.Polymorphic { q = q.Where( `? IN (?, ?)`, types.F(j.Rel.BasePrefix+"type"), baseTable.ModelName, baseTable.TypeName, ) } return q, nil } func (j *join) selectM2M(db DB) error { q, err := j.m2mQuery(db) if err != nil { return err } err = q.Select() if err != nil { return err } return nil } func (j *join) m2mQuery(db DB) (*Query, error) { index := j.JoinModel.ParentIndex() baseTable := j.BaseModel.Table() m2mCols := columns(j.Rel.M2MTableName, j.Rel.BasePrefix, baseTable.PKs) m2mVals := values(j.BaseModel.Root(), index, baseTable.PKs) m2mModel := newM2MModel(j) q := NewQuery(db, m2mModel) if j.ApplyQuery != nil { var err error q, err = j.ApplyQuery(q) if err != nil { return nil, err } } q.columns = append(q.columns, hasManyColumnsAppender{j}) q = q.Join( "JOIN ? ON (?) IN (?)", j.Rel.M2MTableName, types.Q(m2mCols), types.Q(m2mVals), ) joinAlias := j.JoinModel.Table().Alias for _, pk := range j.JoinModel.Table().PKs { q = q.Where( "?.? 
= ?.?", joinAlias, pk.ColName, j.Rel.M2MTableName, types.F(j.Rel.JoinPrefix+pk.SQLName), ) } return q, nil } func (j *join) hasParent() bool { if j.Parent != nil { switch j.Parent.Rel.Type { case HasOneRelation, BelongsToRelation: return true } } return false } func (j *join) appendAlias(b []byte) []byte { b = append(b, '"') b = appendAlias(b, j, true) b = append(b, '"') return b } func (j *join) appendAliasColumn(b []byte, column string) []byte { b = append(b, '"') b = appendAlias(b, j, true) b = append(b, "__"...) b = types.AppendField(b, column, 2) b = append(b, '"') return b } func (j *join) appendBaseAlias(b []byte) []byte { if j.hasParent() { b = append(b, '"') b = appendAlias(b, j.Parent, true) b = append(b, '"') return b } return append(b, j.BaseModel.Table().Alias...) } func appendAlias(b []byte, j *join, topLevel bool) []byte { if j.hasParent() { b = appendAlias(b, j.Parent, topLevel) topLevel = false } if !topLevel { b = append(b, "__"...) } b = append(b, j.Rel.Field.SQLName...) return b } func (j *join) appendHasOneColumns(b []byte) []byte { if j.Columns == nil { for i, f := range j.JoinModel.Table().Fields { if i > 0 { b = append(b, ", "...) } b = j.appendAlias(b) b = append(b, '.') b = append(b, f.ColName...) b = append(b, " AS "...) b = j.appendAliasColumn(b, f.SQLName) } return b } for i, column := range j.Columns { if i > 0 { b = append(b, ", "...) } b = j.appendAlias(b) b = append(b, '.') b = types.AppendField(b, column, 1) b = append(b, " AS "...) b = j.appendAliasColumn(b, column) } return b } func (j *join) appendHasOneJoin(b []byte) []byte { b = append(b, "LEFT JOIN "...) b = append(b, j.JoinModel.Table().Name...) b = append(b, " AS "...) b = j.appendAlias(b) b = append(b, " ON "...) if j.Rel.Type == HasOneRelation { joinTable := j.Rel.JoinTable for i, fk := range j.Rel.FKs { if i > 0 { b = append(b, " AND "...) } b = j.appendAlias(b) b = append(b, '.') b = append(b, joinTable.PKs[i].ColName...) b = append(b, " = "...) 
b = j.appendBaseAlias(b) b = append(b, '.') b = append(b, fk.ColName...) } } else { baseTable := j.BaseModel.Table() for i, fk := range j.Rel.FKs { if i > 0 { b = append(b, " AND "...) } b = j.appendAlias(b) b = append(b, '.') b = append(b, fk.ColName...) b = append(b, " = "...) b = j.appendBaseAlias(b) b = append(b, '.') b = append(b, baseTable.PKs[i].ColName...) } } return b } type hasManyColumnsAppender struct { *join } func (q hasManyColumnsAppender) AppendFormat(b []byte, f QueryFormatter) []byte { if q.Rel.M2MTableName != "" { b = append(b, q.Rel.M2MTableName...) b = append(b, ".*, "...) } joinTable := q.JoinModel.Table() if q.Columns == nil { for i, f := range joinTable.Fields { if i > 0 { b = append(b, ", "...) } b = append(b, joinTable.Alias...) b = append(b, '.') b = append(b, f.ColName...) } return b } for i, column := range q.Columns { if i > 0 { b = append(b, ", "...) } b = append(b, joinTable.Alias...) b = append(b, '.') b = types.AppendField(b, column, 1) } return b } pg-5.3.3/orm/join_test.go000066400000000000000000000031251305650307100152530ustar00rootroot00000000000000package orm import ( . "github.com/onsi/ginkgo" . 
"github.com/onsi/gomega" ) type JoinTest struct { tableName struct{} `sql:"JoinTest,alias:JoinTest"` Id int HasOne *HasOne HasOneId int BelongsTo *BelongsTo } type HasOne struct { tableName struct{} `sql:"HasOne,alias:HasOne"` Id int HasOne *HasOne HasOneId int } type BelongsTo struct { tableName struct{} `sql:"BelongsTo,alias:BelongsTo"` Id int JoinTestId int } var _ = Describe("Select", func() { It("supports has one", func() { q := NewQuery(nil, &JoinTest{}).Relation("HasOne.HasOne", nil) b, err := selectQuery{Query: q}.AppendQuery(nil) Expect(err).NotTo(HaveOccurred()) Expect(string(b)).To(Equal(`SELECT JoinTest."id", JoinTest."has_one_id", "has_one"."id" AS "has_one__id", "has_one"."has_one_id" AS "has_one__has_one_id", "has_one__has_one"."id" AS "has_one__has_one__id", "has_one__has_one"."has_one_id" AS "has_one__has_one__has_one_id" FROM JoinTest AS JoinTest LEFT JOIN HasOne AS "has_one" ON "has_one"."id" = JoinTest."has_one_id" LEFT JOIN HasOne AS "has_one__has_one" ON "has_one__has_one"."id" = "has_one"."has_one_id"`)) }) It("supports belongs to", func() { q := NewQuery(nil, &JoinTest{}).Relation("BelongsTo", nil) b, err := selectQuery{Query: q}.AppendQuery(nil) Expect(err).NotTo(HaveOccurred()) Expect(string(b)).To(Equal(`SELECT JoinTest."id", JoinTest."has_one_id", "belongs_to"."id" AS "belongs_to__id", "belongs_to"."join_test_id" AS "belongs_to__join_test_id" FROM JoinTest AS JoinTest LEFT JOIN BelongsTo AS "belongs_to" ON "belongs_to"."join_test_id" = JoinTest."id"`)) }) }) pg-5.3.3/orm/kinds.go000066400000000000000000000025371305650307100143730ustar00rootroot00000000000000package orm import "reflect" var isZeroerType = reflect.TypeOf((*isZeroer)(nil)).Elem() type isZeroer interface { IsZero() bool } func isEmptyFunc(typ reflect.Type) func(reflect.Value) bool { if typ.Implements(isZeroerType) { return isEmptyZero } switch typ.Kind() { case reflect.Array, reflect.Map, reflect.Slice, reflect.String: return isEmptyLen case reflect.Bool: return isEmptyBool 
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: return isEmptyInt case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: return isEmptyUint case reflect.Float32, reflect.Float64: return isEmptyFloat case reflect.Interface, reflect.Ptr: return isEmptyNil } return isEmptyFalse } func isEmptyLen(v reflect.Value) bool { return v.Len() == 0 } func isEmptyNil(v reflect.Value) bool { return v.IsNil() } func isEmptyBool(v reflect.Value) bool { return !v.Bool() } func isEmptyInt(v reflect.Value) bool { return v.Int() == 0 } func isEmptyUint(v reflect.Value) bool { return v.Uint() == 0 } func isEmptyFloat(v reflect.Value) bool { return v.Float() == 0 } func isEmptyZero(v reflect.Value) bool { switch v.Kind() { case reflect.Ptr: if v.IsNil() { return true } } return v.Interface().(isZeroer).IsZero() } func isEmptyFalse(v reflect.Value) bool { return false } pg-5.3.3/orm/main_test.go000066400000000000000000000002631305650307100152400ustar00rootroot00000000000000package orm_test import ( "testing" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) func TestGinkgo(t *testing.T) { RegisterFailHandler(Fail) RunSpecs(t, "ORM") } pg-5.3.3/orm/model.go000066400000000000000000000031471305650307100143610ustar00rootroot00000000000000package orm import ( "database/sql" "errors" "fmt" "reflect" "time" "gopkg.in/pg.v5/types" ) var timeType = reflect.TypeOf((*time.Time)(nil)).Elem() type useQueryOne interface { useQueryOne() bool } type Model interface { ColumnScanner // Reset resets model state. Reset() error // NewModel returns ColumnScanner that is used to scan columns // from the current row. NewModel() ColumnScanner // AddModel adds ColumnScanner to the Collection. 
AddModel(ColumnScanner) error AfterQuery(DB) error AfterSelect(DB) error BeforeInsert(DB) error AfterInsert(DB) error BeforeUpdate(DB) error AfterUpdate(DB) error BeforeDelete(DB) error AfterDelete(DB) error } func NewModel(values ...interface{}) (Model, error) { if len(values) > 1 { return Scan(values...), nil } v0 := values[0] switch v0 := v0.(type) { case Model: return v0, nil case sql.Scanner: return Scan(v0), nil } v := reflect.ValueOf(v0) if !v.IsValid() { return nil, errors.New("pg: Model(nil)") } if v.Kind() != reflect.Ptr { return nil, fmt.Errorf("pg: Model(non-pointer %T)", v0) } v = v.Elem() switch v.Kind() { case reflect.Struct: return newStructTableModel(v) case reflect.Slice: typ := v.Type() structType := indirectType(typ.Elem()) if structType.Kind() == reflect.Struct && structType != timeType { m := sliceTableModel{ structTableModel: structTableModel{ table: Tables.Get(structType), root: v, }, slice: v, } m.init(typ) return &m, nil } else { return &sliceModel{ slice: v, scan: types.Scanner(structType), }, nil } } return Scan(v0), nil } pg-5.3.3/orm/model_discard.go000066400000000000000000000005151305650307100160460ustar00rootroot00000000000000package orm type Discard struct { hookStubs } var _ Model = (*Discard)(nil) func (Discard) Reset() error { return nil } func (d Discard) NewModel() ColumnScanner { return d } func (Discard) AddModel(_ ColumnScanner) error { return nil } func (Discard) ScanColumn(colIdx int, colName string, b []byte) error { return nil } pg-5.3.3/orm/model_scan.go000066400000000000000000000013121305650307100153550ustar00rootroot00000000000000package orm import ( "fmt" "gopkg.in/pg.v5/types" ) type valuesModel struct { hookStubs values []interface{} } var _ Model = valuesModel{} func Scan(values ...interface{}) valuesModel { return valuesModel{ values: values, } } func (valuesModel) useQueryOne() bool { return true } func (valuesModel) Reset() error { return nil } func (m valuesModel) NewModel() ColumnScanner { return m } func 
(valuesModel) AddModel(_ ColumnScanner) error { return nil } func (m valuesModel) ScanColumn(colIdx int, colName string, b []byte) error { if colIdx >= len(m.values) { return fmt.Errorf("pg: no Scan value for column index=%d name=%s", colIdx, colName) } return types.Scan(m.values[colIdx], b) } pg-5.3.3/orm/model_slice.go000066400000000000000000000011231305650307100155300ustar00rootroot00000000000000package orm import ( "reflect" "gopkg.in/pg.v5/internal" ) type sliceModel struct { hookStubs slice reflect.Value scan func(reflect.Value, []byte) error } var _ Model = (*sliceModel)(nil) func (m *sliceModel) Reset() error { if m.slice.IsValid() && m.slice.Len() > 0 { m.slice.Set(m.slice.Slice(0, 0)) } return nil } func (m *sliceModel) NewModel() ColumnScanner { return m } func (sliceModel) AddModel(_ ColumnScanner) error { return nil } func (m *sliceModel) ScanColumn(colIdx int, _ string, b []byte) error { v := internal.SliceNextElem(m.slice) return m.scan(v, b) } pg-5.3.3/orm/model_table.go000066400000000000000000000040651305650307100155300ustar00rootroot00000000000000package orm import ( "errors" "fmt" "reflect" ) type tableModel interface { Model Table() *Table Relation() *Relation AppendParam([]byte, string) ([]byte, bool) Join(string, func(*Query) (*Query, error)) (bool, *join) GetJoin(string) *join GetJoins() []join AddJoin(join) *join Root() reflect.Value Index() []int ParentIndex() []int Bind(reflect.Value) Value() reflect.Value scanColumn(int, string, []byte) (bool, error) } func newTableModel(v interface{}) (tableModel, error) { switch v := v.(type) { case tableModel: return v, nil case reflect.Value: return newTableModelValue(v) default: vv := reflect.ValueOf(v) if !vv.IsValid() { return nil, errors.New("pg: Model(nil)") } if vv.Kind() != reflect.Ptr { return nil, fmt.Errorf("pg: Model(non-pointer %T)", v) } return newTableModelValue(vv.Elem()) } } func newTableModelValue(v reflect.Value) (tableModel, error) { if !v.IsValid() { return nil, errors.New("pg: 
Model(nil)") } v = reflect.Indirect(v) switch v.Kind() { case reflect.Struct: return newStructTableModel(v) case reflect.Slice: structType := sliceElemType(v) if structType.Kind() == reflect.Struct { m := sliceTableModel{ structTableModel: structTableModel{ table: Tables.Get(structType), root: v, }, slice: v, } m.init(v.Type()) return &m, nil } } return nil, fmt.Errorf("pg: Model(unsupported %s)", v.Type()) } func newTableModelIndex(root reflect.Value, index []int, rel *Relation) (tableModel, error) { typ := typeByIndex(root.Type(), index) if typ.Kind() == reflect.Struct { return &structTableModel{ table: Tables.Get(typ), rel: rel, root: root, index: index, }, nil } if typ.Kind() == reflect.Slice { structType := indirectType(typ.Elem()) if structType.Kind() == reflect.Struct { m := sliceTableModel{ structTableModel: structTableModel{ table: Tables.Get(structType), rel: rel, root: root, index: index, }, } m.init(typ) return &m, nil } } return nil, fmt.Errorf("pg: NewModel(%s)", typ) } pg-5.3.3/orm/model_table_m2m.go000066400000000000000000000043711305650307100163030ustar00rootroot00000000000000package orm import ( "fmt" "reflect" ) type m2mModel struct { *sliceTableModel baseTable *Table rel *Relation buf []byte dstValues map[string][]reflect.Value columns map[string]string } var _ tableModel = (*m2mModel)(nil) func newM2MModel(join *join) *m2mModel { baseTable := join.BaseModel.Table() joinModel := join.JoinModel.(*sliceTableModel) dstValues := dstValues(joinModel, baseTable.PKs) m := &m2mModel{ sliceTableModel: joinModel, baseTable: baseTable, rel: join.Rel, dstValues: dstValues, columns: make(map[string]string), } if !m.sliceOfPtr { m.strct = reflect.New(m.table.Type).Elem() } return m } func (m *m2mModel) NewModel() ColumnScanner { if m.sliceOfPtr { m.strct = reflect.New(m.table.Type).Elem() } else { m.strct.Set(m.table.zeroStruct) } m.structTableModel.NewModel() return m } func (m *m2mModel) AddModel(model ColumnScanner) error { m.buf = modelIdMap(m.buf[:0], 
m.columns, m.baseTable.ModelName+"_", m.baseTable.PKs) dstValues, ok := m.dstValues[string(m.buf)] if !ok { return fmt.Errorf("pg: can't find dst value for model id=%q", m.buf) } for _, v := range dstValues { if m.sliceOfPtr { v.Set(reflect.Append(v, m.strct.Addr())) } else { v.Set(reflect.Append(v, m.strct)) } } return nil } func (m *m2mModel) AfterQuery(db DB) error { if !m.rel.JoinTable.Has(AfterQueryHookFlag) { return nil } var retErr error for _, slices := range m.dstValues { for _, slice := range slices { err := callAfterQueryHookSlice(slice, m.sliceOfPtr, db) if err != nil && retErr == nil { retErr = err } } } return retErr } func (m *m2mModel) AfterSelect(db DB) error { return nil } func (m *m2mModel) BeforeInsert(db DB) error { return nil } func (m *m2mModel) AfterInsert(db DB) error { return nil } func (m *m2mModel) BeforeUpdate(db DB) error { return nil } func (m *m2mModel) AfterUpdate(db DB) error { return nil } func (m *m2mModel) BeforeDelete(db DB) error { return nil } func (m *m2mModel) AfterDelete(db DB) error { return nil } func (m *m2mModel) ScanColumn(colIdx int, colName string, b []byte) error { ok, err := m.sliceTableModel.scanColumn(colIdx, colName, b) if ok { return err } m.columns[colName] = string(b) return nil } pg-5.3.3/orm/model_table_many.go000066400000000000000000000035631305650307100165560ustar00rootroot00000000000000package orm import ( "fmt" "reflect" ) type manyModel struct { *sliceTableModel rel *Relation buf []byte dstValues map[string][]reflect.Value } var _ tableModel = (*manyModel)(nil) func newManyModel(j *join) *manyModel { joinModel := j.JoinModel.(*sliceTableModel) dstValues := dstValues(joinModel, j.BaseModel.Table().PKs) m := manyModel{ sliceTableModel: joinModel, rel: j.Rel, dstValues: dstValues, } if !m.sliceOfPtr { m.strct = reflect.New(m.table.Type).Elem() } return &m } func (m *manyModel) NewModel() ColumnScanner { if m.sliceOfPtr { m.strct = reflect.New(m.table.Type).Elem() } else { m.strct.Set(m.table.zeroStruct) 
} m.structTableModel.NewModel() return m } func (m *manyModel) AddModel(model ColumnScanner) error { m.buf = modelId(m.buf[:0], m.strct, m.rel.FKs) dstValues, ok := m.dstValues[string(m.buf)] if !ok { return fmt.Errorf("pg: can't find dst value for model id=%q", m.buf) } for _, v := range dstValues { if m.sliceOfPtr { v.Set(reflect.Append(v, m.strct.Addr())) } else { v.Set(reflect.Append(v, m.strct)) } } return nil } func (m *manyModel) AfterQuery(db DB) error { if !m.rel.JoinTable.Has(AfterQueryHookFlag) { return nil } var retErr error for _, slices := range m.dstValues { for _, slice := range slices { err := callAfterQueryHookSlice(slice, m.sliceOfPtr, db) if err != nil && retErr == nil { retErr = err } } } return retErr } func (m *manyModel) AfterSelect(db DB) error { return nil } func (m *manyModel) BeforeInsert(db DB) error { return nil } func (m *manyModel) AfterInsert(db DB) error { return nil } func (m *manyModel) BeforeUpdate(db DB) error { return nil } func (m *manyModel) AfterUpdate(db DB) error { return nil } func (m *manyModel) BeforeDelete(db DB) error { return nil } func (m *manyModel) AfterDelete(db DB) error { return nil } pg-5.3.3/orm/model_table_slice.go000066400000000000000000000054731305650307100167130ustar00rootroot00000000000000package orm import "reflect" type sliceTableModel struct { structTableModel slice reflect.Value sliceOfPtr bool } var _ tableModel = (*sliceTableModel)(nil) func (m *sliceTableModel) init(sliceType reflect.Type) { switch sliceType.Elem().Kind() { case reflect.Ptr, reflect.Interface: m.sliceOfPtr = true } } func (sliceTableModel) useQueryOne() {} func (m *sliceTableModel) Join(name string, apply func(*Query) (*Query, error)) (bool, *join) { return m.join(m.Value(), name, apply) } func (m *sliceTableModel) Bind(bind reflect.Value) { m.slice = bind.Field(m.index[len(m.index)-1]) } func (m *sliceTableModel) Value() reflect.Value { return m.slice } func (m *sliceTableModel) Reset() error { if m.slice.IsValid() && 
m.slice.Len() > 0 { m.slice.Set(m.slice.Slice(0, 0)) } return nil } func (m *sliceTableModel) NewModel() ColumnScanner { m.strct = m.nextElem() m.bindChildren() return m } func (m *sliceTableModel) AfterQuery(db DB) error { if !m.table.Has(AfterQueryHookFlag) { return nil } return callAfterQueryHookSlice(m.slice, m.sliceOfPtr, db) } func (m *sliceTableModel) AfterSelect(db DB) error { if !m.table.Has(AfterSelectHookFlag) { return nil } return callAfterSelectHookSlice(m.slice, m.sliceOfPtr, db) } func (m *sliceTableModel) BeforeInsert(db DB) error { if !m.table.Has(BeforeInsertHookFlag) { return nil } return callBeforeInsertHookSlice(m.slice, m.sliceOfPtr, db) } func (m *sliceTableModel) AfterInsert(db DB) error { if !m.table.Has(AfterInsertHookFlag) { return nil } return callAfterInsertHookSlice(m.slice, m.sliceOfPtr, db) } func (m *sliceTableModel) BeforeUpdate(db DB) error { if !m.table.Has(BeforeUpdateHookFlag) { return nil } return callBeforeUpdateHookSlice(m.slice, m.sliceOfPtr, db) } func (m *sliceTableModel) AfterUpdate(db DB) error { if !m.table.Has(AfterUpdateHookFlag) { return nil } return callAfterUpdateHookSlice(m.slice, m.sliceOfPtr, db) } func (m *sliceTableModel) BeforeDelete(db DB) error { if !m.table.Has(BeforeDeleteHookFlag) { return nil } return callBeforeDeleteHookSlice(m.slice, m.sliceOfPtr, db) } func (m *sliceTableModel) AfterDelete(db DB) error { if !m.table.Has(AfterDeleteHookFlag) { return nil } return callAfterDeleteHookSlice(m.slice, m.sliceOfPtr, db) } func (m *sliceTableModel) nextElem() reflect.Value { if m.slice.Len() < m.slice.Cap() { m.slice.Set(m.slice.Slice(0, m.slice.Len()+1)) elem := m.slice.Index(m.slice.Len() - 1) if m.sliceOfPtr { if elem.IsNil() { elem.Set(reflect.New(elem.Type().Elem())) } return elem.Elem() } else { return elem } } if m.sliceOfPtr { elem := reflect.New(m.table.Type) m.slice.Set(reflect.Append(m.slice, elem)) return elem.Elem() } else { m.slice.Set(reflect.Append(m.slice, m.table.zeroStruct)) return 
m.slice.Index(m.slice.Len() - 1) } } pg-5.3.3/orm/model_table_struct.go000066400000000000000000000146751305650307100171440ustar00rootroot00000000000000package orm import ( "errors" "fmt" "reflect" "strings" ) type structTableModel struct { table *Table rel *Relation joins []join root reflect.Value index []int strct reflect.Value } var _ tableModel = (*structTableModel)(nil) func newStructTableModel(v interface{}) (*structTableModel, error) { switch v := v.(type) { case *structTableModel: return v, nil case reflect.Value: return newStructTableModelValue(v) default: return newStructTableModelValue(reflect.ValueOf(v)) } } func newStructTableModelValue(v reflect.Value) (*structTableModel, error) { if !v.IsValid() { return nil, errors.New("pg: Model(nil)") } v = reflect.Indirect(v) if v.Kind() != reflect.Struct { return nil, fmt.Errorf("pg: Model(unsupported %s)", v.Type()) } return &structTableModel{ table: Tables.Get(v.Type()), root: v, strct: v, }, nil } func (structTableModel) useQueryOne() bool { return true } func (m *structTableModel) Table() *Table { return m.table } func (m *structTableModel) Relation() *Relation { return m.rel } func (m *structTableModel) AppendParam(dst []byte, name string) ([]byte, bool) { dst, ok := m.table.AppendParam(dst, m.strct, name) if ok { return dst, true } switch name { case "TableAlias": dst = append(dst, m.table.Alias...) 
return dst, true } return dst, false } func (m *structTableModel) Root() reflect.Value { return m.root } func (m *structTableModel) Index() []int { return m.index } func (m *structTableModel) ParentIndex() []int { return m.index[:len(m.index)-len(m.rel.Field.Index)] } func (m *structTableModel) Value() reflect.Value { return m.strct } func (m *structTableModel) Bind(bind reflect.Value) { m.strct = bind.FieldByIndex(m.rel.Field.Index) } func (m *structTableModel) initStruct(bindChildren bool) { if m.strct.Kind() == reflect.Interface { m.strct = m.strct.Elem() } if m.strct.Kind() == reflect.Ptr { if m.strct.IsNil() { m.strct.Set(reflect.New(m.strct.Type().Elem())) m.strct = m.strct.Elem() bindChildren = true } else { m.strct = m.strct.Elem() } } if bindChildren { m.bindChildren() } } func (m *structTableModel) bindChildren() { for i := range m.joins { j := &m.joins[i] switch j.Rel.Type { case HasOneRelation, BelongsToRelation: j.JoinModel.Bind(m.strct) } } } func (structTableModel) Reset() error { return nil } func (m *structTableModel) NewModel() ColumnScanner { m.initStruct(true) return m } func (m *structTableModel) AddModel(_ ColumnScanner) error { return nil } func (m *structTableModel) AfterQuery(db DB) error { if !m.table.Has(AfterQueryHookFlag) { return nil } return callAfterQueryHook(m.strct.Addr(), db) } func (m *structTableModel) AfterSelect(db DB) error { if !m.table.Has(AfterSelectHookFlag) { return nil } return callAfterSelectHook(m.strct.Addr(), db) } func (m *structTableModel) BeforeInsert(db DB) error { if !m.table.Has(BeforeInsertHookFlag) { return nil } return callBeforeInsertHook(m.strct.Addr(), db) } func (m *structTableModel) AfterInsert(db DB) error { if !m.table.Has(AfterInsertHookFlag) { return nil } return callAfterInsertHook(m.strct.Addr(), db) } func (m *structTableModel) BeforeUpdate(db DB) error { if !m.table.Has(BeforeUpdateHookFlag) { return nil } return callBeforeUpdateHook(m.strct.Addr(), db) } func (m *structTableModel) 
AfterUpdate(db DB) error { if !m.table.Has(AfterUpdateHookFlag) { return nil } return callAfterUpdateHook(m.strct.Addr(), db) } func (m *structTableModel) BeforeDelete(db DB) error { if !m.table.Has(BeforeDeleteHookFlag) { return nil } return callBeforeDeleteHook(m.strct.Addr(), db) } func (m *structTableModel) AfterDelete(db DB) error { if !m.table.Has(AfterDeleteHookFlag) { return nil } return callAfterDeleteHook(m.strct.Addr(), db) } func (m *structTableModel) ScanColumn(colIdx int, colName string, b []byte) error { ok, err := m.scanColumn(colIdx, colName, b) if ok { return err } return fmt.Errorf("pg: can't find column=%s in model=%s", colName, m.table.Type.Name()) } func (m *structTableModel) scanColumn(colIdx int, colName string, b []byte) (bool, error) { joinName, fieldName := splitColumn(colName) if joinName != "" { if join := m.GetJoin(joinName); join != nil { return join.JoinModel.scanColumn(colIdx, fieldName, b) } if m.table.ModelName == joinName { return m.scanColumn(colIdx, fieldName, b) } } field, ok := m.table.FieldsMap[colName] if !ok { return false, nil } m.initStruct(false) return true, field.ScanValue(m.strct, b) } func (m *structTableModel) GetJoin(name string) *join { for i := range m.joins { j := &m.joins[i] if j.Rel.Field.GoName == name || j.Rel.Field.SQLName == name { return j } } return nil } func (m *structTableModel) GetJoins() []join { return m.joins } func (m *structTableModel) AddJoin(j join) *join { m.joins = append(m.joins, j) return &m.joins[len(m.joins)-1] } func (m *structTableModel) Join(name string, apply func(*Query) (*Query, error)) (bool, *join) { return m.join(m.Value(), name, apply) } func (m *structTableModel) join( bind reflect.Value, name string, apply func(*Query) (*Query, error), ) (bool, *join) { path := strings.Split(name, ".") index := make([]int, 0, len(path)) currJoin := join{ BaseModel: m, JoinModel: m, } var created bool var lastJoin *join var hasColumnName bool for _, name := range path { rel, ok := 
currJoin.JoinModel.Table().Relations[name] if !ok { hasColumnName = true break } currJoin.Rel = rel index = append(index, rel.Field.Index...) if j := currJoin.JoinModel.GetJoin(name); j != nil { currJoin.BaseModel = j.BaseModel currJoin.JoinModel = j.JoinModel created = false lastJoin = j } else { model, err := newTableModelIndex(bind, index, rel) if err != nil { return false, nil } currJoin.Parent = lastJoin currJoin.BaseModel = currJoin.JoinModel currJoin.JoinModel = model created = true lastJoin = currJoin.BaseModel.AddJoin(currJoin) } } // No joins with such name. if lastJoin == nil { return false, nil } if apply != nil { lastJoin.ApplyQuery = apply } if hasColumnName { column := path[len(path)-1] if column == "_" { if lastJoin.Columns == nil { lastJoin.Columns = make([]string, 0) } } else { lastJoin.Columns = append(lastJoin.Columns, column) } } return created, lastJoin } func splitColumn(s string) (string, string) { ind := strings.Index(s, "__") if ind == -1 { return "", s } return s[:ind], s[ind+2:] } pg-5.3.3/orm/orm.go000066400000000000000000000021131305650307100140460ustar00rootroot00000000000000package orm import "gopkg.in/pg.v5/types" // ColumnScanner is used to scan column values. type ColumnScanner interface { // Scan assigns a column value from a row. // // An error should be returned if the value can not be stored // without loss of information. ScanColumn(colIdx int, colName string, b []byte) error } type QueryAppender interface { AppendQuery(dst []byte, params ...interface{}) ([]byte, error) } type QueryFormatter interface { FormatQuery(dst []byte, query string, params ...interface{}) []byte } // DB is a common interface for pg.DB and pg.Tx types. 
type DB interface { Model(model ...interface{}) *Query Select(model interface{}) error Insert(model ...interface{}) error Update(model interface{}) error Delete(model interface{}) error Exec(query interface{}, params ...interface{}) (*types.Result, error) ExecOne(query interface{}, params ...interface{}) (*types.Result, error) Query(coll, query interface{}, params ...interface{}) (*types.Result, error) QueryOne(model, query interface{}, params ...interface{}) (*types.Result, error) QueryFormatter } pg-5.3.3/orm/query.go000066400000000000000000000361721305650307100144320ustar00rootroot00000000000000package orm import ( "errors" "fmt" "strings" "sync" "time" "gopkg.in/pg.v5/internal" "gopkg.in/pg.v5/types" ) type withQuery struct { name string query *Query } type Query struct { db DB stickyErr error model tableModel ignoreModel bool with []withQuery tables []FormatAppender columns []FormatAppender set []FormatAppender where []sepFormatAppender joins []FormatAppender group []FormatAppender having []queryParamsAppender order []FormatAppender onConflict FormatAppender returning []queryParamsAppender limit int offset int } func NewQuery(db DB, model ...interface{}) *Query { return (&Query{}).DB(db).Model(model...) } // New returns new zero Query binded to the current db and model. func (q *Query) New() *Query { return &Query{ db: q.db, model: q.model, ignoreModel: true, } } // Copy returns copy of the Query. 
func (q *Query) Copy() *Query { copy := &Query{ db: q.db, stickyErr: q.stickyErr, model: q.model, ignoreModel: q.ignoreModel, tables: q.tables[:], columns: q.columns[:], set: q.set[:], where: q.where[:], joins: q.joins[:], group: q.group[:], having: q.having[:], order: q.order[:], onConflict: q.onConflict, returning: q.returning[:], limit: q.limit, offset: q.offset, } for _, with := range q.with { copy = copy.With(with.name, with.query.Copy()) } return copy } func (q *Query) err(err error) *Query { if q.stickyErr == nil { q.stickyErr = err } return q } func (q *Query) DB(db DB) *Query { q.db = db for _, with := range q.with { with.query.db = db } return q } func (q *Query) Model(model ...interface{}) *Query { var err error switch l := len(model); { case l == 0: q.model = nil case l == 1: q.model, err = newTableModel(model[0]) case l > 1: q.model, err = newTableModel(&model) } if err != nil { q = q.err(err) } if q.ignoreModel { q.ignoreModel = false } return q } // With adds subq as common table expression with the given name. func (q *Query) With(name string, subq *Query) *Query { q.with = append(q.with, withQuery{name, subq}) return q } // WrapWith creates new Query and adds to it current query as // common table expression with the given name. func (q *Query) WrapWith(name string) *Query { wrapper := q.New() wrapper.with = q.with q.with = nil wrapper = wrapper.With(name, q) return wrapper } func (q *Query) Table(tables ...string) *Query { for _, table := range tables { q.tables = append(q.tables, fieldAppender{table}) } return q } func (q *Query) TableExpr(expr string, params ...interface{}) *Query { q.tables = append(q.tables, queryParamsAppender{expr, params}) return q } // Column adds column to the Query quoting it according to PostgreSQL rules. // ColumnExpr can be used to bypass quoting restriction. 
func (q *Query) Column(columns ...string) *Query { for _, column := range columns { if column == "_" { if q.columns == nil { q.columns = make([]FormatAppender, 0) } continue } if q.model != nil { if _, j := q.model.Join(column, nil); j != nil { continue } } q.columns = append(q.columns, fieldAppender{column}) } return q } // ColumnExpr adds column expression to the Query. func (q *Query) ColumnExpr(expr string, params ...interface{}) *Query { q.columns = append(q.columns, queryParamsAppender{expr, params}) return q } func (q *Query) getFields() []string { var fields []string for _, col := range q.columns { if f, ok := col.(fieldAppender); ok { fields = append(fields, f.field) } } return fields } func (q *Query) Relation(name string, apply func(*Query) (*Query, error)) *Query { if _, j := q.model.Join(name, apply); j == nil { return q.err(fmt.Errorf( "model=%s does not have relation=%s", q.model.Table().Type.Name(), name, )) } return q } func (q *Query) Set(set string, params ...interface{}) *Query { q.set = append(q.set, queryParamsAppender{set, params}) return q } func (q *Query) Where(where string, params ...interface{}) *Query { q.where = append(q.where, &whereAppender{"AND", where, params}) return q } func (q *Query) WhereOr(where string, params ...interface{}) *Query { q.where = append(q.where, &whereAppender{"OR", where, params}) return q } // WhereIn is a shortcut for Where and pg.In to work with IN operator: // // WhereIn("id IN (?)", 1, 2, 3) func (q *Query) WhereIn(where string, params ...interface{}) *Query { return q.Where(where, types.In(params)) } func (q *Query) Join(join string, params ...interface{}) *Query { q.joins = append(q.joins, queryParamsAppender{join, params}) return q } func (q *Query) Group(columns ...string) *Query { for _, column := range columns { q.group = append(q.group, fieldAppender{column}) } return q } func (q *Query) GroupExpr(group string, params ...interface{}) *Query { q.group = append(q.group, queryParamsAppender{group, 
params}) return q } func (q *Query) Having(having string, params ...interface{}) *Query { q.having = append(q.having, queryParamsAppender{having, params}) return q } // Order adds sort order to the Query quoting column name. // OrderExpr can be used to bypass quoting restriction. func (q *Query) Order(orders ...string) *Query { loop: for _, order := range orders { ind := strings.LastIndex(order, " ") if ind != -1 { field := order[:ind] sort := order[ind+1:] switch internal.ToUpper(sort) { case "ASC", "DESC": q.order = append(q.order, queryParamsAppender{ query: "? ?", params: []interface{}{types.F(field), types.Q(sort)}, }) continue loop } } q.order = append(q.order, fieldAppender{order}) continue } return q } // Order adds sort order to the Query. func (q *Query) OrderExpr(order string, params ...interface{}) *Query { q.order = append(q.order, queryParamsAppender{order, params}) return q } func (q *Query) Limit(n int) *Query { q.limit = n return q } func (q *Query) Offset(n int) *Query { q.offset = n return q } func (q *Query) OnConflict(s string, params ...interface{}) *Query { q.onConflict = queryParamsAppender{s, params} return q } func (q *Query) Returning(s string, params ...interface{}) *Query { q.returning = append(q.returning, queryParamsAppender{s, params}) return q } // Apply calls the fn passing the Query as an argument. func (q *Query) Apply(fn func(*Query) (*Query, error)) *Query { qq, err := fn(q) if err != nil { q.err(err) return q } return qq } // Count returns number of rows matching the query using count aggregate function. 
func (q *Query) Count() (int, error) { if q.stickyErr != nil { return 0, q.stickyErr } var count int _, err := q.db.QueryOne( Scan(&count), q.countQuery().countSelectQuery("count(*)"), q.model, ) return count, err } func (q *Query) countQuery() *Query { if len(q.group) > 0 { return q.Copy().WrapWith("wrapper").Table("wrapper") } return q } func (q *Query) countSelectQuery(column string) selectQuery { return selectQuery{ Query: q, count: column, } } // First selects the first row. func (q *Query) First() error { b := columns(q.model.Table().Alias, "", q.model.Table().PKs) return q.OrderExpr(string(b)).Limit(1).Select() } // Last selects the last row. func (q *Query) Last() error { b := columns(q.model.Table().Alias, "", q.model.Table().PKs) b = append(b, " DESC"...) return q.OrderExpr(string(b)).Limit(1).Select() } // Select selects the model. func (q *Query) Select(values ...interface{}) error { if q.stickyErr != nil { return q.stickyErr } model, err := q.newModel(values...) if err != nil { return err } res, err := q.query(model, selectQuery{Query: q}) if err != nil { return err } if res.RowsReturned() > 0 { if q.model != nil { if err := selectJoins(q.db, q.model.GetJoins()); err != nil { return err } } if err := model.AfterSelect(q.db); err != nil { return err } } return nil } func (q *Query) newModel(values ...interface{}) (Model, error) { if len(values) > 0 { return NewModel(values...) } return q.model, nil } func (q *Query) query(model Model, query interface{}) (*types.Result, error) { if _, ok := model.(useQueryOne); ok { return q.db.QueryOne(model, query, q.model) } return q.db.Query(model, query, q.model) } // SelectAndCount runs Select and Count in two goroutines, // waits for them to finish and returns the result. 
func (q *Query) SelectAndCount(values ...interface{}) (count int, err error) { if q.stickyErr != nil { return 0, q.stickyErr } var wg sync.WaitGroup wg.Add(2) go func() { defer wg.Done() if e := q.Select(values...); e != nil { err = e } }() go func() { defer wg.Done() var e error count, e = q.Count() if e != nil { err = e } }() wg.Wait() return count, err } func (q *Query) forEachHasOneJoin(fn func(*join)) { if q.model == nil { return } q._forEachHasOneJoin(fn, q.model.GetJoins()) } func (q *Query) _forEachHasOneJoin(fn func(*join), joins []join) { for i := range joins { j := &joins[i] switch j.Rel.Type { case HasOneRelation, BelongsToRelation: fn(j) q._forEachHasOneJoin(fn, j.JoinModel.GetJoins()) } } } func selectJoins(db DB, joins []join) error { var err error for i := range joins { j := &joins[i] if j.Rel.Type == HasOneRelation || j.Rel.Type == BelongsToRelation { err = selectJoins(db, j.JoinModel.GetJoins()) } else { err = j.Select(db) } if err != nil { return err } } return nil } // Insert inserts the model. func (q *Query) Insert(values ...interface{}) (*types.Result, error) { if q.stickyErr != nil { return nil, q.stickyErr } model, err := q.newModel(values...) if err != nil { return nil, err } if q.model != nil { if err := q.model.BeforeInsert(q.db); err != nil { return nil, err } } res, err := q.db.Query(model, insertQuery{Query: q}, q.model) if err != nil { return nil, err } if q.model != nil { if err := q.model.AfterInsert(q.db); err != nil { return nil, err } } return res, nil } // SelectOrInsert selects the model inserting one if it does not exist. func (q *Query) SelectOrInsert(values ...interface{}) (inserted bool, err error) { if q.stickyErr != nil { return false, q.stickyErr } var insertErr error for i := 0; i < 5; i++ { if i >= 2 { time.Sleep(internal.RetryBackoff << uint(i-2)) } err := q.Select(values...) if err == nil { return false, nil } if err != internal.ErrNoRows { return false, err } res, err := q.Insert(values...) 
if err != nil { insertErr = err if pgErr, ok := err.(internal.PGError); ok { if pgErr.IntegrityViolation() { continue } if pgErr.Field('C') == "55000" { // Retry on "#55000 attempted to delete invisible tuple". continue } } return false, err } if res.RowsAffected() == 1 { return true, nil } } err = fmt.Errorf( "pg: SelectOrInsert: select returns no rows (insert fails with err=%q)", insertErr, ) return false, err } // Update updates the model. func (q *Query) Update(values ...interface{}) (*types.Result, error) { if q.stickyErr != nil { return nil, q.stickyErr } model, err := q.newModel(values...) if err != nil { return nil, err } if q.model != nil { if err := q.model.BeforeUpdate(q.db); err != nil { return nil, err } } res, err := q.db.Query(model, updateQuery{q}, q.model) if err != nil { return nil, err } if q.model != nil { if err := q.model.AfterUpdate(q.db); err != nil { return nil, err } } return res, nil } // Delete deletes the model. func (q *Query) Delete() (*types.Result, error) { if q.stickyErr != nil { return nil, q.stickyErr } if q.model != nil { if err := q.model.BeforeDelete(q.db); err != nil { return nil, err } } res, err := q.db.Query(q.model, deleteQuery{q}, q.model) if err != nil { return nil, err } if q.model != nil { if err := q.model.AfterDelete(q.db); err != nil { return nil, err } } return res, nil } func (q *Query) FormatQuery(dst []byte, query string, params ...interface{}) []byte { params = append(params, q.model) if q.db != nil { return q.db.FormatQuery(dst, query, params...) } return Formatter{}.Append(dst, query, params...) } func (q *Query) hasModel() bool { return !q.ignoreModel && q.model != nil } func (q *Query) hasTables() bool { return q.hasModel() || len(q.tables) > 0 } func (q *Query) appendTableName(b []byte) []byte { return q.FormatQuery(b, string(q.model.Table().Name)) } func (q *Query) appendTableNameWithAlias(b []byte) []byte { b = q.appendTableName(b) b = append(b, " AS "...) b = append(b, q.model.Table().Alias...) 
return b } func (q *Query) appendTables(b []byte) []byte { if q.hasModel() { b = q.appendTableNameWithAlias(b) if len(q.tables) > 0 { b = append(b, ", "...) } } for i, f := range q.tables { if i > 0 { b = append(b, ", "...) } b = f.AppendFormat(b, q) } return b } func (q *Query) appendFirstTable(b []byte) []byte { if q.hasModel() { return q.appendTableNameWithAlias(b) } if len(q.tables) > 0 { b = q.tables[0].AppendFormat(b, q) } return b } func (q *Query) hasOtherTables() bool { if q.hasModel() { return len(q.tables) > 0 } return len(q.tables) > 1 } func (q *Query) appendOtherTables(b []byte) []byte { tables := q.tables if !q.hasModel() { tables = tables[1:] } for i, f := range tables { if i > 0 { b = append(b, ", "...) } b = f.AppendFormat(b, q) } return b } func (q *Query) mustAppendWhere(b []byte) ([]byte, error) { if len(q.where) > 0 { b = q.appendWhere(b) return b, nil } if q.model == nil { return nil, errors.New("pg: Model(nil)") } if err := q.model.Table().checkPKs(); err != nil { return nil, err } b = append(b, " WHERE "...) return wherePKQuery{q}.AppendFormat(b, nil), nil } func (q *Query) appendWhere(b []byte) []byte { b = append(b, " WHERE "...) for i, f := range q.where { if i > 0 { b = append(b, ' ') b = f.AppendSep(b) b = append(b, ' ') } b = f.AppendFormat(b, q) } return b } func (q *Query) appendSet(b []byte) []byte { b = append(b, " SET "...) for i, f := range q.set { if i > 0 { b = append(b, ", "...) } b = f.AppendFormat(b, q) } return b } func (q *Query) appendReturning(b []byte) []byte { b = append(b, " RETURNING "...) for i, f := range q.returning { if i > 0 { b = append(b, ", "...) } b = f.AppendFormat(b, q) } return b } func (q *Query) appendWith(b []byte, count string) ([]byte, error) { var err error b = append(b, "WITH "...) for i, with := range q.with { if i > 0 { b = append(b, ", "...) } b = types.AppendField(b, with.name, 1) b = append(b, " AS ("...) 
if count != "" { b, err = with.query.countSelectQuery("*").AppendQuery(b) } else { b, err = selectQuery{Query: with.query}.AppendQuery(b) } if err != nil { return nil, err } b = append(b, ')') } b = append(b, ' ') return b, nil } //------------------------------------------------------------------------------ type wherePKQuery struct { *Query } func (wherePKQuery) AppendSep(b []byte) []byte { return append(b, "AND"...) } func (q wherePKQuery) AppendFormat(b []byte, f QueryFormatter) []byte { table := q.model.Table() return appendColumnAndValue(b, q.model.Value(), table, table.PKs) } pg-5.3.3/orm/query_test.go000066400000000000000000000011231305650307100154550ustar00rootroot00000000000000package orm_test import ( "testing" "unsafe" "gopkg.in/pg.v5/orm" ) func TestQuerySize(t *testing.T) { size := int(unsafe.Sizeof(orm.Query{})) wanted := 328 if size != wanted { t.Fatalf("got %d, wanted %d", size, wanted) } } type FormatModel struct { Foo string } func TestQueryFormatQuery(t *testing.T) { q := orm.NewQuery(nil, &FormatModel{"bar"}) params := &struct { Foo string }{ "not_bar", } b := q.FormatQuery(nil, "?foo ?TableAlias", params) wanted := `'not_bar' "format_model"` if string(b) != wanted { t.Fatalf("got %q, wanted %q", string(b), wanted) } } pg-5.3.3/orm/relation.go000066400000000000000000000005051305650307100150710ustar00rootroot00000000000000package orm import "gopkg.in/pg.v5/types" const ( HasOneRelation = 1 << iota BelongsToRelation HasManyRelation Many2ManyRelation ) type Relation struct { Type int Polymorphic bool Field *Field JoinTable *Table FKs []*Field M2MTableName types.Q BasePrefix string JoinPrefix string } pg-5.3.3/orm/select.go000066400000000000000000000051441305650307100145370ustar00rootroot00000000000000package orm import "strconv" func Select(db DB, model interface{}) error { q := NewQuery(db, model) if err := q.model.Table().checkPKs(); err != nil { return err } q.where = append(q.where, wherePKQuery{q}) return q.Select() } type selectQuery struct 
{ *Query count string } var _ QueryAppender = (*selectQuery)(nil) func (q selectQuery) AppendQuery(b []byte, params ...interface{}) ([]byte, error) { var err error if len(q.with) > 0 { b, err = q.appendWith(b, q.count) if err != nil { return nil, err } } b = append(b, "SELECT "...) if q.count != "" && q.count != "*" { b = append(b, q.count...) } else { b = q.appendColumns(b) } if q.hasTables() { b = append(b, " FROM "...) b = q.appendTables(b) } q.forEachHasOneJoin(func(j *join) { b = append(b, ' ') b = j.appendHasOneJoin(b) }) if len(q.joins) > 0 { for _, f := range q.joins { b = append(b, ' ') b = f.AppendFormat(b, q) } } if len(q.where) > 0 { b = q.appendWhere(b) } if len(q.group) > 0 { b = append(b, " GROUP BY "...) for i, f := range q.group { if i > 0 { b = append(b, ", "...) } b = f.AppendFormat(b, q) } } if len(q.having) > 0 { b = append(b, " HAVING "...) for i, f := range q.having { if i > 0 { b = append(b, " AND "...) } b = append(b, '(') b = f.AppendFormat(b, q) b = append(b, ')') } } if q.count == "" { if len(q.order) > 0 { b = append(b, " ORDER BY "...) for i, f := range q.order { if i > 0 { b = append(b, ", "...) } b = f.AppendFormat(b, q) } } if q.limit != 0 { b = append(b, " LIMIT "...) b = strconv.AppendInt(b, int64(q.limit), 10) } if q.offset != 0 { b = append(b, " OFFSET "...) b = strconv.AppendInt(b, int64(q.offset), 10) } } return b, nil } func (q selectQuery) appendColumns(b []byte) []byte { start := len(b) if q.columns != nil { b = q.appendQueryColumns(b) } else if q.hasModel() { b = q.appendModelColumns(b) } else { b = append(b, '*') } q.forEachHasOneJoin(func(j *join) { if len(b) != start { b = append(b, ", "...) start = len(b) } b = j.appendHasOneColumns(b) if len(b) == start { b = b[:len(b)-2] } }) return b } func (q selectQuery) appendQueryColumns(b []byte) []byte { for i, f := range q.columns { if i > 0 { b = append(b, ", "...) 
} b = f.AppendFormat(b, q) } return b } func (q selectQuery) appendModelColumns(b []byte) []byte { for i, f := range q.model.Table().Fields { if i > 0 { b = append(b, ", "...) } b = append(b, q.model.Table().Alias...) b = append(b, '.') b = append(b, f.ColName...) } return b } pg-5.3.3/orm/select_test.go000066400000000000000000000136561305650307100156050ustar00rootroot00000000000000package orm import ( . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) type SelectModel struct { Id int Name string HasOne *HasOneModel HasOneId int HasMany []HasManyModel } type HasOneModel struct { Id int } type HasManyModel struct { Id int SelectModelId int } var _ = Describe("Select", func() { It("works without db", func() { q := NewQuery(nil).Where("hello = ?", "world") b, err := selectQuery{Query: q}.AppendQuery(nil) Expect(err).NotTo(HaveOccurred()) Expect(string(b)).To(Equal("SELECT * WHERE (hello = 'world')")) }) It("specifies all columns", func() { q := NewQuery(nil, &SelectModel{}) b, err := selectQuery{Query: q}.AppendQuery(nil) Expect(err).NotTo(HaveOccurred()) Expect(string(b)).To(Equal(`SELECT "select_model"."id", "select_model"."name", "select_model"."has_one_id" FROM "select_models" AS "select_model"`)) }) It("omits columns in main query", func() { q := NewQuery(nil, &SelectModel{}).Column("_", "HasOne") b, err := selectQuery{Query: q}.AppendQuery(nil) Expect(err).NotTo(HaveOccurred()) Expect(string(b)).To(Equal(`SELECT "has_one"."id" AS "has_one__id" FROM "select_models" AS "select_model" LEFT JOIN "has_one_models" AS "has_one" ON "has_one"."id" = "select_model"."has_one_id"`)) }) It("omits columns in join query", func() { q := NewQuery(nil, &SelectModel{}).Column("HasOne._") b, err := selectQuery{Query: q}.AppendQuery(nil) Expect(err).NotTo(HaveOccurred()) Expect(string(b)).To(Equal(`SELECT "select_model"."id", "select_model"."name", "select_model"."has_one_id" FROM "select_models" AS "select_model" LEFT JOIN "has_one_models" AS "has_one" ON "has_one"."id" = 
"select_model"."has_one_id"`)) }) It("specifies all columns for has one", func() { q := NewQuery(nil, &SelectModel{Id: 1}).Column("HasOne") b, err := selectQuery{Query: q}.AppendQuery(nil) Expect(err).NotTo(HaveOccurred()) Expect(string(b)).To(Equal(`SELECT "select_model"."id", "select_model"."name", "select_model"."has_one_id", "has_one"."id" AS "has_one__id" FROM "select_models" AS "select_model" LEFT JOIN "has_one_models" AS "has_one" ON "has_one"."id" = "select_model"."has_one_id"`)) }) It("specifies all columns for has many", func() { q := NewQuery(nil, &SelectModel{Id: 1}).Column("HasMany") q, err := q.model.GetJoin("HasMany").manyQuery(nil) Expect(err).NotTo(HaveOccurred()) b, err := selectQuery{Query: q}.AppendQuery(nil) Expect(err).NotTo(HaveOccurred()) Expect(string(b)).To(Equal(`SELECT "has_many_model"."id", "has_many_model"."select_model_id" FROM "has_many_models" AS "has_many_model" WHERE (("has_many_model"."select_model_id") IN ((1)))`)) }) It("supports multiple groups", func() { q := NewQuery(nil).Group("one").Group("two") b, err := selectQuery{Query: q}.AppendQuery(nil) Expect(err).NotTo(HaveOccurred()) Expect(string(b)).To(Equal(`SELECT * GROUP BY "one", "two"`)) }) It("WhereOr", func() { q := NewQuery(nil).Where("1 = 1").WhereOr("1 = 2") b, err := selectQuery{Query: q}.AppendQuery(nil) Expect(err).NotTo(HaveOccurred()) Expect(string(b)).To(Equal(`SELECT * WHERE (1 = 1) OR (1 = 2)`)) }) }) var _ = Describe("Count", func() { It("removes LIMIT, OFFSET, and ORDER", func() { q := NewQuery(nil).Order("order").Limit(1).Offset(2) b, err := q.countSelectQuery("count(*)").AppendQuery(nil) Expect(err).NotTo(HaveOccurred()) Expect(string(b)).To(Equal(`SELECT count(*)`)) }) It("removes LIMIT, OFFSET, and ORDER from CTE", func() { q := NewQuery(nil). Column("col1", "col2"). Order("order"). Limit(1). Offset(2). WrapWith("wrapper"). 
Table("wrapper") b, err := q.countSelectQuery("count(*)").AppendQuery(nil) Expect(err).NotTo(HaveOccurred()) Expect(string(b)).To(Equal(`WITH "wrapper" AS (SELECT "col1", "col2") SELECT count(*) FROM "wrapper"`)) }) It("uses CTE when query contains GROUP BY", func() { q := NewQuery(nil).Group("one") b, err := q.countQuery().countSelectQuery("count(*)").AppendQuery(nil) Expect(err).NotTo(HaveOccurred()) Expect(string(b)).To(Equal(`WITH "wrapper" AS (SELECT * GROUP BY "one") SELECT count(*) FROM "wrapper"`)) }) It("includes has one joins", func() { q := NewQuery(nil, &SelectModel{Id: 1}).Column("HasOne") b, err := q.countSelectQuery("count(*)").AppendQuery(nil) Expect(err).NotTo(HaveOccurred()) Expect(string(b)).To(Equal(`SELECT count(*) FROM "select_models" AS "select_model" LEFT JOIN "has_one_models" AS "has_one" ON "has_one"."id" = "select_model"."has_one_id"`)) }) }) var _ = Describe("With", func() { It("WrapWith wraps query in CTE", func() { q := NewQuery(nil, &SelectModel{}). Where("cond1"). WrapWith("wrapper"). Table("wrapper"). 
Where("cond2") b, err := selectQuery{Query: q}.AppendQuery(nil) Expect(err).NotTo(HaveOccurred()) Expect(string(b)).To(Equal(`WITH "wrapper" AS (SELECT "select_model"."id", "select_model"."name", "select_model"."has_one_id" FROM "select_models" AS "select_model" WHERE (cond1)) SELECT * FROM "wrapper" WHERE (cond2)`)) }) It("generates nested CTE", func() { q1 := NewQuery(nil).Table("q1") q2 := NewQuery(nil).With("q1", q1).Table("q2", "q1") q3 := NewQuery(nil).With("q2", q2).Table("q3", "q2") b, err := selectQuery{Query: q3}.AppendQuery(nil) Expect(err).NotTo(HaveOccurred()) Expect(string(b)).To(Equal(`WITH "q2" AS (WITH "q1" AS (SELECT * FROM "q1") SELECT * FROM "q2", "q1") SELECT * FROM "q3", "q2"`)) }) }) type orderTest struct { order string query string } var _ = Describe("Select Order", func() { orderTests := []orderTest{ {"id", `"id"`}, {"id asc", `"id" asc`}, {"id desc", `"id" desc`}, {"id ASC", `"id" ASC`}, {"id DESC", `"id" DESC`}, {"id ASC NULLS FIRST", `"id ASC NULLS FIRST"`}, } It("sets order", func() { for _, test := range orderTests { q := NewQuery(nil).Order(test.order) b, err := selectQuery{Query: q}.AppendQuery(nil) Expect(err).NotTo(HaveOccurred()) Expect(string(b)).To(Equal(`SELECT * ORDER BY ` + test.query)) } }) }) pg-5.3.3/orm/table.go000066400000000000000000000233601305650307100143470ustar00rootroot00000000000000package orm import ( "database/sql" "fmt" "reflect" "strings" "gopkg.in/pg.v5/internal" "gopkg.in/pg.v5/types" "github.com/jinzhu/inflection" ) var nullBool = reflect.TypeOf((*sql.NullBool)(nil)).Elem() var nullFloat = reflect.TypeOf((*sql.NullFloat64)(nil)).Elem() var nullInt = reflect.TypeOf((*sql.NullInt64)(nil)).Elem() var nullString = reflect.TypeOf((*sql.NullString)(nil)).Elem() type Table struct { Type reflect.Type zeroStruct reflect.Value TypeName string Name types.Q Alias types.Q ModelName string PKs []*Field Fields []*Field FieldsMap map[string]*Field Methods map[string]*Method Relations map[string]*Relation flags int16 } func 
(t *Table) Has(flag int16) bool { if t == nil { return false } return t.flags&flag != 0 } func (t *Table) HasField(field string) bool { _, err := t.GetField(field) return err == nil } func (t *Table) checkPKs() error { if len(t.PKs) == 0 { return fmt.Errorf("model=%s does not have primary keys", t.Type.Name()) } return nil } func (t *Table) AddField(field *Field) { t.Fields = append(t.Fields, field) t.FieldsMap[field.SQLName] = field } func (t *Table) GetField(fieldName string) (*Field, error) { field, ok := t.FieldsMap[fieldName] if !ok { return nil, fmt.Errorf("can't find column=%s in table=%s", fieldName, t.Name) } return field, nil } func (t *Table) AppendParam(dst []byte, strct reflect.Value, name string) ([]byte, bool) { if field, ok := t.FieldsMap[name]; ok { dst = field.AppendValue(dst, strct, 1) return dst, true } if method, ok := t.Methods[name]; ok { dst = method.AppendValue(dst, strct.Addr(), 1) return dst, true } return dst, false } func (t *Table) addRelation(rel *Relation) { if t.Relations == nil { t.Relations = make(map[string]*Relation) } t.Relations[rel.Field.GoName] = rel } func newTable(typ reflect.Type) *Table { table, ok := Tables.tables[typ] if ok { return table } table, ok = Tables.inFlight[typ] if ok { return table } modelName := internal.Underscore(typ.Name()) table = &Table{ Type: typ, zeroStruct: reflect.Zero(typ), TypeName: internal.ToExported(typ.Name()), Name: types.Q(types.AppendField(nil, inflection.Plural(modelName), 1)), Alias: types.Q(types.AppendField(nil, modelName, 1)), ModelName: modelName, Fields: make([]*Field, 0, typ.NumField()), FieldsMap: make(map[string]*Field, typ.NumField()), } Tables.inFlight[typ] = table table.addFields(typ, nil) typ = reflect.PtrTo(typ) if typ.Implements(afterQueryHookType) { table.flags |= AfterQueryHookFlag } if typ.Implements(afterSelectHookType) { table.flags |= AfterSelectHookFlag } if typ.Implements(beforeInsertHookType) { table.flags |= BeforeInsertHookFlag } if 
typ.Implements(afterInsertHookType) { table.flags |= AfterInsertHookFlag } if typ.Implements(beforeUpdateHookType) { table.flags |= BeforeUpdateHookFlag } if typ.Implements(afterUpdateHookType) { table.flags |= AfterUpdateHookFlag } if typ.Implements(beforeDeleteHookType) { table.flags |= BeforeDeleteHookFlag } if typ.Implements(afterDeleteHookType) { table.flags |= AfterDeleteHookFlag } if table.Methods == nil { table.Methods = make(map[string]*Method) } for i := 0; i < typ.NumMethod(); i++ { m := typ.Method(i) if m.PkgPath != "" { continue } if m.Type.NumIn() > 1 { continue } if m.Type.NumOut() != 1 { continue } retType := m.Type.Out(0) method := Method{ Index: m.Index, appender: types.Appender(retType), } table.Methods[m.Name] = &method } Tables.tables[typ] = table delete(Tables.inFlight, typ) return table } func (t *Table) addFields(typ reflect.Type, index []int) { for i := 0; i < typ.NumField(); i++ { f := typ.Field(i) if f.Anonymous { embeddedTable := newTable(indirectType(f.Type)) _, pgOpt := parseTag(f.Tag.Get("pg")) if _, ok := pgOpt.Get("override"); ok { t.TypeName = embeddedTable.TypeName t.Name = embeddedTable.Name t.Alias = embeddedTable.Alias t.ModelName = embeddedTable.ModelName } t.addFields(embeddedTable.Type, append(index, f.Index...)) continue } field := t.newField(f, index) if field != nil { t.AddField(field) } } } func (t *Table) getField(name string) *Field { for _, f := range t.Fields { if f.GoName == name { return f } } f, ok := t.Type.FieldByName(name) if !ok { return nil } return t.newField(f, nil) } func (t *Table) newField(f reflect.StructField, index []int) *Field { sqlName, sqlOpt := parseTag(f.Tag.Get("sql")) switch f.Name { case "tableName", "TableName": if index != nil { return nil } if sqlName != "" { t.Name = types.Q(sqlName) } if alias, ok := sqlOpt.Get("alias:"); ok { t.Alias = types.Q(alias) } return nil } if f.PkgPath != "" { return nil } skip := sqlName == "-" if skip || sqlName == "" { sqlName = internal.Underscore(f.Name) } 
if field, ok := t.FieldsMap[sqlName]; ok { return field } _, pgOpt := parseTag(f.Tag.Get("pg")) var appender types.AppenderFunc var scanner types.ScannerFunc if _, ok := pgOpt.Get("array"); ok { appender = types.ArrayAppender(f.Type) scanner = types.ArrayScanner(f.Type) } else if _, ok := pgOpt.Get("hstore"); ok { appender = types.HstoreAppender(f.Type) scanner = types.HstoreScanner(f.Type) } else { appender = types.Appender(f.Type) scanner = types.Scanner(f.Type) } field := Field{ Type: indirectType(f.Type), GoName: f.Name, SQLName: sqlName, ColName: types.Q(types.AppendField(nil, sqlName, 1)), Index: append(index, f.Index...), append: appender, scan: scanner, isEmpty: isEmptyFunc(f.Type), } if _, ok := sqlOpt.Get("notnull"); ok { field.flags |= NotNullFlag } if _, ok := sqlOpt.Get("unique"); ok { field.flags |= UniqueFlag } if len(t.PKs) == 0 && (field.SQLName == "id" || field.SQLName == "uuid") { field.flags |= PrimaryKeyFlag t.PKs = append(t.PKs, &field) } else if _, ok := sqlOpt.Get("pk"); ok { field.flags |= PrimaryKeyFlag t.PKs = append(t.PKs, &field) } else if strings.HasSuffix(string(field.SQLName), "_id") { field.flags |= ForeignKeyFlag } field.SQLType = sqlType(&field, sqlOpt) if !skip && types.IsSQLScanner(f.Type) { return &field } switch field.Type.Kind() { case reflect.Slice: elemType := indirectType(field.Type.Elem()) if elemType.Kind() != reflect.Struct { break } joinTable := newTable(elemType) basePrefix := t.TypeName if s, ok := pgOpt.Get("fk:"); ok { basePrefix = s } if m2mTable, _ := pgOpt.Get("many2many:"); m2mTable != "" { joinPrefix := joinTable.TypeName if s, ok := pgOpt.Get("joinFK:"); ok { joinPrefix = s } t.addRelation(&Relation{ Type: Many2ManyRelation, Field: &field, JoinTable: joinTable, M2MTableName: types.Q(m2mTable), BasePrefix: internal.Underscore(basePrefix + "_"), JoinPrefix: internal.Underscore(joinPrefix + "_"), }) return nil } var polymorphic bool if s, ok := pgOpt.Get("polymorphic:"); ok { polymorphic = true basePrefix = s } 
fks := foreignKeys(t, joinTable, basePrefix) if len(fks) > 0 { t.addRelation(&Relation{ Type: HasManyRelation, Polymorphic: polymorphic, Field: &field, FKs: fks, JoinTable: joinTable, BasePrefix: internal.Underscore(basePrefix + "_"), }) return nil } case reflect.Struct: joinTable := newTable(field.Type) if len(joinTable.Fields) == 0 { break } for _, ff := range joinTable.FieldsMap { ff = ff.Copy() ff.SQLName = field.SQLName + "__" + ff.SQLName ff.ColName = types.Q(types.AppendField(nil, ff.SQLName, 1)) ff.Index = append(field.Index, ff.Index...) t.FieldsMap[ff.SQLName] = ff } if t.detectHasOne(&field, joinTable) || t.detectBelongsToOne(&field, joinTable) { t.FieldsMap[field.SQLName] = &field return nil } } if skip { t.FieldsMap[field.SQLName] = &field return nil } return &field } func sqlType(field *Field, sqlOpt tagOptions) string { if v, ok := sqlOpt.Get("type:"); ok { return v } switch field.Type { case timeType: return "timestamptz" case nullBool: return "boolean" case nullFloat: return "double precision" case nullInt: return "bigint" case nullString: return "text" } switch field.Type.Kind() { case reflect.Int8, reflect.Uint8, reflect.Int16: if field.Has(PrimaryKeyFlag) { return "smallserial" } return "smallint" case reflect.Uint16, reflect.Int32: if field.Has(PrimaryKeyFlag) { return "serial" } return "integer" case reflect.Uint32, reflect.Int64, reflect.Int: if field.Has(PrimaryKeyFlag) { return "bigserial" } return "bigint" case reflect.Uint, reflect.Uint64: return "decimal" case reflect.Float32: return "real" case reflect.Float64: return "double precision" case reflect.Bool: return "boolean" case reflect.String: return "text" case reflect.Map, reflect.Slice, reflect.Struct: return "jsonb" default: return field.Type.Kind().String() } } func foreignKeys(base, join *Table, prefix string) []*Field { var fks []*Field for _, pk := range base.PKs { fkName := prefix + pk.GoName if fk := join.getField(fkName); fk != nil { fks = append(fks, fk) } } return fks } func 
(t *Table) detectHasOne(field *Field, joinTable *Table) bool { fks := foreignKeys(joinTable, t, field.GoName) if len(fks) > 0 { t.addRelation(&Relation{ Type: HasOneRelation, Field: field, FKs: fks, JoinTable: joinTable, }) return true } return false } func (t *Table) detectBelongsToOne(field *Field, joinTable *Table) bool { fks := foreignKeys(t, joinTable, t.TypeName) if len(fks) > 0 { t.addRelation(&Relation{ Type: BelongsToRelation, Field: field, FKs: fks, JoinTable: joinTable, }) return true } return false } pg-5.3.3/orm/table_params.go000066400000000000000000000010011305650307100156760ustar00rootroot00000000000000package orm import "reflect" type tableParams struct { table *Table strct reflect.Value } func newTableParams(strct interface{}) (*tableParams, bool) { v := reflect.ValueOf(strct) if !v.IsValid() { return nil, false } v = reflect.Indirect(v) if v.Kind() != reflect.Struct { return nil, false } return &tableParams{ table: Tables.Get(v.Type()), strct: v, }, true } func (m tableParams) AppendParam(dst []byte, name string) ([]byte, bool) { return m.table.AppendParam(dst, m.strct, name) } pg-5.3.3/orm/table_test.go000066400000000000000000000053731305650307100154120ustar00rootroot00000000000000package orm_test import ( "reflect" "gopkg.in/pg.v5/orm" . "github.com/onsi/ginkgo" . 
"github.com/onsi/gomega" ) type A struct { Id int } func (A) Method() int { return 10 } type B struct { A } var _ = Describe("embedded Model", func() { var strct reflect.Value var table *orm.Table BeforeEach(func() { strct = reflect.ValueOf(B{A: A{Id: 1}}) table = orm.Tables.Get(strct.Type()) }) It("has fields", func() { Expect(table.Fields).To(HaveLen(1)) Expect(table.FieldsMap).To(HaveLen(1)) id, ok := table.FieldsMap["id"] Expect(ok).To(BeTrue()) Expect(id.GoName).To(Equal("Id")) Expect(id.SQLName).To(Equal("id")) Expect(string(id.ColName)).To(Equal(`"id"`)) Expect(id.Has(orm.PrimaryKeyFlag)).To(BeTrue()) Expect(string(id.AppendValue(nil, strct, 1))).To(Equal("1")) Expect(table.PKs).To(HaveLen(1)) Expect(table.PKs[0]).To(Equal(id)) }) It("has methods", func() { Expect(table.Methods).To(HaveLen(1)) m, ok := table.Methods["Method"] Expect(ok).To(BeTrue()) Expect(m.Index).To(Equal(0)) Expect(string(m.AppendValue(nil, strct, 1))).To(Equal("10")) }) }) type C struct { Name int `sql:",pk"` Id int UUID int } var _ = Describe("primary key annotation", func() { var table *orm.Table BeforeEach(func() { strct := reflect.ValueOf(C{}) table = orm.Tables.Get(strct.Type()) }) It("has precedence over auto-detection", func() { Expect(table.PKs).To(HaveLen(1)) Expect(table.PKs[0].GoName).To(Equal("Name")) }) }) type D struct { UUID int } var _ = Describe("uuid field", func() { var table *orm.Table BeforeEach(func() { strct := reflect.ValueOf(D{}) table = orm.Tables.Get(strct.Type()) }) It("is detected as primary key", func() { Expect(table.PKs).To(HaveLen(1)) Expect(table.PKs[0].GoName).To(Equal("UUID")) }) }) type E struct { Id int StructField struct { Foo string Bar string } } var _ = Describe("struct field", func() { var table *orm.Table BeforeEach(func() { strct := reflect.ValueOf(E{}) table = orm.Tables.Get(strct.Type()) }) It("is present in the list", func() { Expect(table.Fields).To(HaveLen(2)) _, ok := table.FieldsMap["struct_field"] Expect(ok).To(BeTrue()) }) }) type f 
struct { Id int G *g } type g struct { Id int FId int F *f } var _ = Describe("unexported types", func() { It("work with belongs to relation", func() { strct := reflect.ValueOf(f{}) table := orm.Tables.Get(strct.Type()) rel, ok := table.Relations["G"] Expect(ok).To(BeTrue()) Expect(rel.Type).To(Equal(orm.BelongsToRelation)) }) It("work with has one relation", func() { strct := reflect.ValueOf(g{}) table := orm.Tables.Get(strct.Type()) rel, ok := table.Relations["F"] Expect(ok).To(BeTrue()) Expect(rel.Type).To(Equal(orm.HasOneRelation)) }) }) pg-5.3.3/orm/tables.go000066400000000000000000000011741305650307100145310ustar00rootroot00000000000000package orm import ( "fmt" "reflect" "sync" ) var Tables = newTables() type tables struct { inFlight map[reflect.Type]*Table tables map[reflect.Type]*Table mu sync.RWMutex } func newTables() *tables { return &tables{ inFlight: make(map[reflect.Type]*Table), tables: make(map[reflect.Type]*Table), } } func (t *tables) Get(typ reflect.Type) *Table { if typ.Kind() != reflect.Struct { panic(fmt.Errorf("got %s, wanted %s", typ.Kind(), reflect.Struct)) } t.mu.RLock() table, ok := t.tables[typ] t.mu.RUnlock() if ok { return table } t.mu.Lock() table = newTable(typ) t.mu.Unlock() return table } pg-5.3.3/orm/tag.go000066400000000000000000000010661305650307100140320ustar00rootroot00000000000000package orm import ( "bytes" "strings" ) type tagOptions string func (o tagOptions) Get(name string) (string, bool) { s := string(o) for len(s) > 0 { var next string idx := strings.IndexByte(s, ',') if idx >= 0 { s, next = s[:idx], s[idx+1:] } if strings.HasPrefix(s, name) { return s[len(name):], true } s = next } return "", false } func parseTag(tagStr string) (string, tagOptions) { tag := []byte(tagStr) if idx := bytes.IndexByte(tag, ','); idx != -1 { return string(tag[:idx]), tagOptions(tag[idx+1:]) } return tagStr, "" } pg-5.3.3/orm/update.go000066400000000000000000000033641305650307100145440ustar00rootroot00000000000000package orm import ( 
"errors" "gopkg.in/pg.v5/internal" ) func Update(db DB, model interface{}) error { res, err := NewQuery(db, model).Update() if err != nil { return err } return internal.AssertOneRow(res.RowsAffected()) } type updateQuery struct { *Query } var _ QueryAppender = (*updateQuery)(nil) func (q updateQuery) AppendQuery(b []byte, params ...interface{}) ([]byte, error) { var err error if len(q.with) > 0 { b, err = q.appendWith(b, "") if err != nil { return nil, err } } b = append(b, "UPDATE "...) b = q.appendFirstTable(b) b, err = q.mustAppendSet(b) if err != nil { return nil, err } if q.hasOtherTables() { b = append(b, " FROM "...) b = q.appendOtherTables(b) } b, err = q.mustAppendWhere(b) if err != nil { return nil, err } if len(q.returning) > 0 { b = q.appendReturning(b) } return b, nil } func (q updateQuery) mustAppendSet(b []byte) ([]byte, error) { if len(q.set) > 0 { b = q.appendSet(b) return b, nil } if q.model == nil { return nil, errors.New("pg: Model(nil)") } b = append(b, " SET "...) table := q.model.Table() strct := q.model.Value() if fields := q.getFields(); len(fields) > 0 { for i, fieldName := range fields { field, err := table.GetField(fieldName) if err != nil { return nil, err } if i > 0 { b = append(b, ", "...) } b = append(b, field.ColName...) b = append(b, " = "...) b = field.AppendValue(b, strct, 1) } return b, nil } start := len(b) for _, field := range table.Fields { if field.Has(PrimaryKeyFlag) { continue } b = append(b, field.ColName...) b = append(b, " = "...) b = field.AppendValue(b, strct, 1) b = append(b, ", "...) } if len(b) > start { b = b[:len(b)-2] } return b, nil } pg-5.3.3/orm/update_test.go000066400000000000000000000011311305650307100155710ustar00rootroot00000000000000package orm import ( . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) type UpdateTest struct{} var _ = Describe("Update", func() { It("supports WITH", func() { q := NewQuery(nil, &UpdateTest{}). WrapWith("wrapper"). Model(&UpdateTest{}). Table("wrapper"). 
Where("update_test.id = wrapper.id") b, err := updateQuery{q}.AppendQuery(nil) Expect(err).NotTo(HaveOccurred()) Expect(string(b)).To(Equal(`WITH "wrapper" AS (SELECT FROM "update_tests" AS "update_test") UPDATE "update_tests" AS "update_test" SET FROM "wrapper" WHERE (update_test.id = wrapper.id)`)) }) }) pg-5.3.3/orm/url_values.go000066400000000000000000000060261305650307100154410ustar00rootroot00000000000000package orm import ( "fmt" "net/url" "strconv" "strings" "gopkg.in/pg.v5/types" ) func URLValues(urlValues url.Values) func(*Query) (*Query, error) { return func(q *Query) (*Query, error) { for fieldName, values := range urlValues { var operation string if i := strings.Index(fieldName, "__"); i != -1 { fieldName, operation = fieldName[:i], fieldName[i+2:] } if q.model.Table().HasField(fieldName) { q = addOperator(q, fieldName, operation, values) } } return setOrder(q, urlValues), nil } } func addOperator(q *Query, fieldName, operator string, values []string) *Query { switch operator { case "gt": q = forEachValue(q, fieldName, values, "? > ?") case "gte": q = forEachValue(q, fieldName, values, "? >= ?") case "lt": q = forEachValue(q, fieldName, values, "? < ?") case "lte": q = forEachValue(q, fieldName, values, "? <= ?") case "ieq": q = forEachValue(q, fieldName, values, "? ILIKE ?") case "match": q = forEachValue(q, fieldName, values, "? SIMILAR TO ?") case "exclude": q = forAllValues(q, fieldName, values, "? != ?", "? NOT IN (?)") case "", "include": q = forAllValues(q, fieldName, values, "? = ?", "? 
IN (?)") } return q } func forEachValue(q *Query, fieldName string, values []string, queryTemplate string) *Query { for _, value := range values { q = q.Where(queryTemplate, types.F(fieldName), value) } return q } func forAllValues(q *Query, fieldName string, values []string, queryTemplate, queryArrayTemplate string) *Query { if len(values) > 1 { q = q.Where(queryArrayTemplate, types.F(fieldName), types.In(values)) } else { q = q.Where(queryTemplate, types.F(fieldName), values[0]) } return q } func setOrder(q *Query, urlValues url.Values) *Query { for _, order := range urlValues["order"] { if order != "" { q = q.Order(order) } } return q } // Pager sets LIMIT and OFFSET from the URL values: // - ?limit=10 - sets q.Limit(10), max limit is 1000. // - ?page=5 - sets q.Offset((page - 1) * limit), max offset is 1000000. func Pager(urlValues url.Values, defaultLimit int) func(*Query) (*Query, error) { return func(q *Query) (*Query, error) { const maxLimit = 1000 const maxOffset = 1e6 limit, err := intParam(urlValues, "limit") if err != nil { return nil, err } if limit < 1 { limit = defaultLimit } else if limit > maxLimit { return nil, fmt.Errorf("limit=%d is bigger than %d", limit, maxLimit) } if limit > 0 { q = q.Limit(limit) } page, err := intParam(urlValues, "page") if err != nil { return nil, err } if page > 0 { offset := (page - 1) * limit if offset > maxOffset { return nil, fmt.Errorf("offset=%d can't bigger than %d", offset, maxOffset) } q = q.Offset(offset) } return q, nil } } func intParam(urlValues url.Values, paramName string) (int, error) { values, ok := urlValues[paramName] if !ok { return 0, nil } value, err := strconv.Atoi(values[0]) if err != nil { return 0, fmt.Errorf("param=%s value=%s is invalid: %s", paramName, values[0], err) } return value, nil } pg-5.3.3/orm/url_values_test.go000066400000000000000000000063661305650307100165070ustar00rootroot00000000000000package orm import ( "net/http" . "github.com/onsi/ginkgo" . 
"github.com/onsi/gomega" ) type URLValuesModel struct { Id int Name string } type urlValuesTest struct { url string query string } var _ = Describe("URLValues", func() { query := `SELECT "url_values_model"."id", "url_values_model"."name" FROM "url_values_models" AS "url_values_model"` urlValuesTests := []urlValuesTest{ { url: "http://localhost:8000/test?id__gt=1", query: query + ` WHERE ("id" > '1')`, }, { url: "http://localhost:8000/test?name__gte=Michael", query: query + ` WHERE ("name" >= 'Michael')`, }, { url: "http://localhost:8000/test?id__lt=10", query: query + ` WHERE ("id" < '10')`, }, { url: "http://localhost:8000/test?name__lte=Peter", query: query + ` WHERE ("name" <= 'Peter')`, }, { url: "http://localhost:8000/test?name__exclude=Peter", query: query + ` WHERE ("name" != 'Peter')`, }, { url: "http://localhost:8000/test?name__exclude=Mike&name__exclude=Peter", query: query + ` WHERE ("name" NOT IN ('Mike','Peter'))`, }, { url: "http://localhost:8000/test?name=Mike", query: query + ` WHERE ("name" = 'Mike')`, }, { url: "http://localhost:8000/test?name__ieq=mik_", query: query + ` WHERE ("name" ILIKE 'mik_')`, }, { url: "http://localhost:8000/test?name__match=(m|p).*", query: query + ` WHERE ("name" SIMILAR TO '(m|p).*')`, }, { url: "http://localhost:8000/test?name__include=Peter&name__include=Mike", query: query + ` WHERE ("name" IN ('Peter','Mike'))`, }, { url: "http://localhost:8000/test?name=Mike&name=Peter", query: query + ` WHERE ("name" IN ('Mike','Peter'))`, }, { url: "http://localhost:8000/test?order=name DESC", query: query + ` ORDER BY "name" DESC`, }, { url: "http://localhost:8000/test?order=id ASC&order=name DESC", query: query + ` ORDER BY "id" ASC, "name" DESC`, }, { url: "http://localhost:8000/test?invalid_field=1", query: query, }, } It("adds conditions to the query", func() { for _, urlValuesTest := range urlValuesTests { req, _ := http.NewRequest("GET", urlValuesTest.url, nil) q := NewQuery(nil, &URLValuesModel{}) q = 
q.Apply(URLValues(req.URL.Query())) b, err := selectQuery{Query: q}.AppendQuery(nil) Expect(err).NotTo(HaveOccurred()) Expect(string(b)).To(Equal(urlValuesTest.query)) } }) }) var _ = Describe("Pager", func() { query := `SELECT "url_values_model"."id", "url_values_model"."name" FROM "url_values_models" AS "url_values_model"` urlValuesTests := []urlValuesTest{ { url: "http://localhost:8000/test?limit=10", query: query + " LIMIT 10", }, { url: "http://localhost:8000/test?page=5", query: query + ` LIMIT 100 OFFSET 400`, }, { url: "http://localhost:8000/test?page=5&limit=20", query: query + ` LIMIT 20 OFFSET 80`, }, } It("adds limit and offset to the query", func() { for _, urlValuesTest := range urlValuesTests { req, _ := http.NewRequest("GET", urlValuesTest.url, nil) q := NewQuery(nil, &URLValuesModel{}) q = q.Apply(Pager(req.URL.Query(), 100)) b, err := selectQuery{Query: q}.AppendQuery(nil) Expect(err).NotTo(HaveOccurred()) Expect(string(b)).To(Equal(urlValuesTest.query)) } }) }) pg-5.3.3/orm/util.go000066400000000000000000000060271305650307100142360ustar00rootroot00000000000000package orm import ( "reflect" "gopkg.in/pg.v5/types" ) func indirectType(t reflect.Type) reflect.Type { if t.Kind() == reflect.Ptr { t = t.Elem() } return t } func sliceElemType(v reflect.Value) reflect.Type { elemType := v.Type().Elem() if elemType.Kind() == reflect.Interface && v.Len() > 0 { return reflect.Indirect(v.Index(0).Elem()).Type() } else { return indirectType(elemType) } } func typeByIndex(t reflect.Type, index []int) reflect.Type { for _, x := range index { switch t.Kind() { case reflect.Ptr: t = t.Elem() case reflect.Slice: t = indirectType(t.Elem()) } t = t.Field(x).Type } return indirectType(t) } func fieldByIndex(v reflect.Value, index []int) reflect.Value { for i, x := range index { if i > 0 { v = indirectNew(v) } v = v.Field(x) } return v } func indirectNew(v reflect.Value) reflect.Value { if v.Kind() == reflect.Ptr { if v.IsNil() { v.Set(reflect.New(v.Type().Elem())) } v 
= v.Elem() } return v } func columns(table types.Q, prefix string, fields []*Field) []byte { var b []byte for i, f := range fields { if i > 0 { b = append(b, ", "...) } if len(table) > 0 { b = append(b, table...) b = append(b, '.') } b = types.AppendField(b, prefix+f.SQLName, 1) } return b } func walk(v reflect.Value, index []int, fn func(reflect.Value)) { v = reflect.Indirect(v) switch v.Kind() { case reflect.Slice: for i := 0; i < v.Len(); i++ { visitField(v.Index(i), index, fn) } default: visitField(v, index, fn) } } func visitField(v reflect.Value, index []int, fn func(reflect.Value)) { v = reflect.Indirect(v) if len(index) > 0 { v = v.Field(index[0]) walk(v, index[1:], fn) } else { fn(v) } } func values(v reflect.Value, index []int, fields []*Field) []byte { var b []byte walk(v, index, func(v reflect.Value) { b = append(b, '(') for i, field := range fields { if i > 0 { b = append(b, ", "...) } b = field.AppendValue(b, v, 1) } b = append(b, "), "...) }) if len(b) > 0 { b = b[:len(b)-2] // trim ", " } return b } func dstValues(model tableModel, fields []*Field) map[string][]reflect.Value { mp := make(map[string][]reflect.Value) var id []byte walk(model.Root(), model.ParentIndex(), func(v reflect.Value) { id = modelId(id[:0], v, fields) mp[string(id)] = append(mp[string(id)], v.FieldByIndex(model.Relation().Field.Index)) }) return mp } func appendColumnAndValue(b []byte, v reflect.Value, table *Table, fields []*Field) []byte { for i, f := range fields { if i > 0 { b = append(b, " AND "...) } b = append(b, table.Alias...) b = append(b, '.') b = append(b, f.ColName...) b = append(b, " = "...) b = f.AppendValue(b, v, 1) } return b } func modelId(b []byte, v reflect.Value, fields []*Field) []byte { for _, f := range fields { b = f.AppendValue(b, v, 0) b = append(b, ',') } return b } func modelIdMap(b []byte, m map[string]string, prefix string, fields []*Field) []byte { for _, f := range fields { b = append(b, m[prefix+f.SQLName]...) 
b = append(b, ',') } return b } pg-5.3.3/pg.go000066400000000000000000000130011305650307100130600ustar00rootroot00000000000000package pg // import "gopkg.in/pg.v5" import ( "log" "strconv" "gopkg.in/pg.v5/internal" "gopkg.in/pg.v5/orm" "gopkg.in/pg.v5/types" ) // Discard is used with Query and QueryOne to discard rows. var Discard orm.Discard // Model returns new query for the optional model. func Model(model ...interface{}) *orm.Query { return orm.NewQuery(nil, model...) } // Scan returns ColumnScanner that copies the columns in the // row into the values. func Scan(values ...interface{}) orm.ColumnScanner { return orm.Scan(values...) } // Q replaces any placeholders found in the query. func Q(query string, params ...interface{}) orm.FormatAppender { return orm.Q(query, params...) } // F quotes a SQL identifier such as a table or column name replacing any // placeholders found in the field. func F(field string) types.ValueAppender { return types.F(field) } // In accepts a slice and returns a wrapper that can be used with PostgreSQL // IN operator: // // Where("id IN (?)", pg.In([]int{1, 2, 3})) func In(slice interface{}) types.ValueAppender { return types.In(slice) } // Array accepts a slice and returns a wrapper for working with PostgreSQL // array data type. // // For struct fields you can use array tag: // // Emails []string `pg:",array"` func Array(v interface{}) *types.Array { return types.NewArray(v) } // Hstore accepts a map and returns a wrapper for working with hstore data type. // Supported map types are: // - map[string]string // // For struct fields you can use hstore tag: // // Attrs map[string]string `pg:",hstore"` func Hstore(v interface{}) *types.Hstore { return types.NewHstore(v) } func SetLogger(logger *log.Logger) { internal.Logger = logger } // SetQueryLogger sets a logger that will be used to log generated queries. 
func SetQueryLogger(logger *log.Logger) { internal.QueryLogger = logger } //------------------------------------------------------------------------------ type Strings []string var _ orm.Model = (*Strings)(nil) var _ types.ValueAppender = (*Strings)(nil) func (strings *Strings) Reset() error { if s := *strings; len(s) > 0 { *strings = s[:0] } return nil } func (strings *Strings) NewModel() orm.ColumnScanner { return strings } func (Strings) AddModel(_ orm.ColumnScanner) error { return nil } func (Strings) AfterQuery(_ orm.DB) error { return nil } func (Strings) AfterSelect(_ orm.DB) error { return nil } func (Strings) BeforeInsert(_ orm.DB) error { return nil } func (Strings) AfterInsert(_ orm.DB) error { return nil } func (Strings) BeforeUpdate(_ orm.DB) error { return nil } func (Strings) AfterUpdate(_ orm.DB) error { return nil } func (Strings) BeforeDelete(_ orm.DB) error { return nil } func (Strings) AfterDelete(_ orm.DB) error { return nil } func (strings *Strings) ScanColumn(colIdx int, _ string, b []byte) error { *strings = append(*strings, string(b)) return nil } func (strings Strings) AppendValue(dst []byte, quote int) ([]byte, error) { if len(strings) <= 0 { return dst, nil } for _, s := range strings { dst = types.AppendString(dst, s, 1) dst = append(dst, ',') } dst = dst[:len(dst)-1] return dst, nil } //------------------------------------------------------------------------------ type Ints []int64 var _ orm.Model = (*Ints)(nil) var _ types.ValueAppender = (*Ints)(nil) func (ints *Ints) Reset() error { if s := *ints; len(s) > 0 { *ints = s[:0] } return nil } func (ints *Ints) NewModel() orm.ColumnScanner { return ints } func (Ints) AddModel(_ orm.ColumnScanner) error { return nil } func (Ints) AfterQuery(_ orm.DB) error { return nil } func (Ints) AfterSelect(_ orm.DB) error { return nil } func (Ints) BeforeInsert(_ orm.DB) error { return nil } func (Ints) AfterInsert(_ orm.DB) error { return nil } func (Ints) BeforeUpdate(_ orm.DB) error { return nil } 
func (Ints) AfterUpdate(_ orm.DB) error { return nil } func (Ints) BeforeDelete(_ orm.DB) error { return nil } func (Ints) AfterDelete(_ orm.DB) error { return nil } func (ints *Ints) ScanColumn(colIdx int, colName string, b []byte) error { n, err := strconv.ParseInt(internal.BytesToString(b), 10, 64) if err != nil { return err } *ints = append(*ints, n) return nil } func (ints Ints) AppendValue(dst []byte, quote int) ([]byte, error) { if len(ints) <= 0 { return dst, nil } for _, v := range ints { dst = strconv.AppendInt(dst, v, 10) dst = append(dst, ',') } dst = dst[:len(dst)-1] return dst, nil } //------------------------------------------------------------------------------ type IntSet map[int64]struct{} var _ orm.Model = (*IntSet)(nil) func (set *IntSet) Reset() error { if len(*set) > 0 { *set = make(map[int64]struct{}) } return nil } func (set *IntSet) NewModel() orm.ColumnScanner { return set } func (IntSet) AddModel(_ orm.ColumnScanner) error { return nil } func (IntSet) AfterQuery(_ orm.DB) error { return nil } func (IntSet) AfterSelect(_ orm.DB) error { return nil } func (IntSet) BeforeInsert(_ orm.DB) error { return nil } func (IntSet) AfterInsert(_ orm.DB) error { return nil } func (IntSet) BeforeUpdate(_ orm.DB) error { return nil } func (IntSet) AfterUpdate(_ orm.DB) error { return nil } func (IntSet) BeforeDelete(_ orm.DB) error { return nil } func (IntSet) AfterDelete(_ orm.DB) error { return nil } func (setptr *IntSet) ScanColumn(colIdx int, colName string, b []byte) error { set := *setptr if set == nil { *setptr = make(IntSet) set = *setptr } n, err := strconv.ParseInt(internal.BytesToString(b), 10, 64) if err != nil { return err } set[n] = struct{}{} return nil } pg-5.3.3/pool_test.go000066400000000000000000000065701305650307100144770ustar00rootroot00000000000000package pg_test import ( "time" "gopkg.in/pg.v5" . 
"gopkg.in/check.v1" ) var _ = Suite(&PoolTest{}) type PoolTest struct { db *pg.DB } func (t *PoolTest) SetUpTest(c *C) { opt := pgOptions() opt.IdleTimeout = time.Second t.db = pg.Connect(opt) } func (t *PoolTest) TearDownTest(c *C) { _ = t.db.Close() } func (t *PoolTest) TestPoolReusesConnection(c *C) { for i := 0; i < 100; i++ { _, err := t.db.Exec("SELECT 'test_pool_reuses_connection'") c.Assert(err, IsNil) } c.Assert(t.db.Pool().Len(), Equals, 1) c.Assert(t.db.Pool().FreeLen(), Equals, 1) } func (t *PoolTest) TestPoolMaxSize(c *C) { N := 1000 perform(N, func(int) { _, err := t.db.Exec("SELECT 'test_pool_max_size'") c.Assert(err, IsNil) }) c.Assert(t.db.Pool().Len(), Equals, 10) c.Assert(t.db.Pool().FreeLen(), Equals, 10) } func (t *PoolTest) TestCloseClosesAllConnections(c *C) { ln := t.db.Listen("test_channel") wait := make(chan struct{}, 2) go func() { wait <- struct{}{} _, _, err := ln.Receive() c.Assert(err, ErrorMatches, `^(.*use of closed (file or )?network connection|EOF)$`) wait <- struct{}{} }() select { case <-wait: // ok case <-time.After(3 * time.Second): c.Fatal("timeout") } c.Assert(t.db.Close(), IsNil) select { case <-wait: // ok case <-time.After(3 * time.Second): c.Fatal("timeout") } c.Assert(t.db.Pool().Len(), Equals, 0) c.Assert(t.db.Pool().FreeLen(), Equals, 0) } func (t *PoolTest) TestClosedDB(c *C) { c.Assert(t.db.Close(), IsNil) c.Assert(t.db.Pool().Len(), Equals, 0) c.Assert(t.db.Pool().FreeLen(), Equals, 0) err := t.db.Close() c.Assert(err, Not(IsNil)) c.Assert(err.Error(), Equals, "pg: database is closed") _, err = t.db.Exec("SELECT 'test_closed_db'") c.Assert(err, Not(IsNil)) c.Assert(err.Error(), Equals, "pg: database is closed") } func (t *PoolTest) TestClosedListener(c *C) { ln := t.db.Listen("test_channel") c.Assert(t.db.Pool().Len(), Equals, 1) c.Assert(t.db.Pool().FreeLen(), Equals, 0) c.Assert(ln.Close(), IsNil) c.Assert(t.db.Pool().Len(), Equals, 0) c.Assert(t.db.Pool().FreeLen(), Equals, 0) err := ln.Close() c.Assert(err, 
Not(IsNil)) c.Assert(err.Error(), Equals, "pg: listener is closed") _, _, err = ln.ReceiveTimeout(time.Second) c.Assert(err, Not(IsNil)) c.Assert(err.Error(), Equals, "pg: listener is closed") } func (t *PoolTest) TestClosedTx(c *C) { tx, err := t.db.Begin() c.Assert(err, IsNil) c.Assert(t.db.Pool().Len(), Equals, 1) c.Assert(t.db.Pool().FreeLen(), Equals, 0) c.Assert(tx.Rollback(), IsNil) c.Assert(t.db.Pool().Len(), Equals, 1) c.Assert(t.db.Pool().FreeLen(), Equals, 1) err = tx.Rollback() c.Assert(err, Not(IsNil)) c.Assert(err.Error(), Equals, "pg: transaction has already been committed or rolled back") _, err = tx.Exec("SELECT 'test_closed_tx'") c.Assert(err, Not(IsNil)) c.Assert(err.Error(), Equals, "pg: transaction has already been committed or rolled back") } func (t *PoolTest) TestClosedStmt(c *C) { stmt, err := t.db.Prepare("SELECT $1::int") c.Assert(err, IsNil) c.Assert(t.db.Pool().Len(), Equals, 1) c.Assert(t.db.Pool().FreeLen(), Equals, 0) c.Assert(stmt.Close(), IsNil) c.Assert(t.db.Pool().Len(), Equals, 1) c.Assert(t.db.Pool().FreeLen(), Equals, 1) err = stmt.Close() c.Assert(err, Not(IsNil)) c.Assert(err.Error(), Equals, "pg: statement is closed") _, err = stmt.Exec(1) c.Assert(err.Error(), Equals, "pg: statement is closed") } pg-5.3.3/race_test.go000066400000000000000000000033251305650307100144330ustar00rootroot00000000000000package pg_test import ( "testing" "gopkg.in/pg.v5" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) var _ = Describe("DB race", func() { var db *pg.DB var C, N int BeforeEach(func() { db = pg.Connect(pgOptions()) err := createTestSchema(db) Expect(err).NotTo(HaveOccurred()) C, N = 10, 1000 if testing.Short() { C = 4 N = 100 } }) AfterEach(func() { err := db.Close() Expect(err).NotTo(HaveOccurred()) }) It("SelectOrInsert with OnConflict is race free", func() { perform(C, func(id int) { a := &Author{ Name: "R. Scott Bakker", } for i := 0; i < N; i++ { a.ID = 0 _, err := db.Model(a). Column("id"). Where("name = ?name"). 
OnConflict("DO NOTHING"). Returning("id"). SelectOrInsert(&a.ID) Expect(err).NotTo(HaveOccurred()) Expect(a.ID).NotTo(BeZero()) if i%(N/C) == 0 { err := db.Delete(a) if err != pg.ErrNoRows { Expect(err).NotTo(HaveOccurred()) } } } }) count, err := db.Model(&Author{}).Count() Expect(err).NotTo(HaveOccurred()) Expect(count).To(Equal(1)) }) It("SelectOrInsert without OnConflict is race free", func() { perform(C, func(id int) { a := &Author{ Name: "R. Scott Bakker", } for i := 0; i < N; i++ { a.ID = 0 _, err := db.Model(a). Column("id"). Where("name = ?name"). Returning("id"). SelectOrInsert(&a.ID) Expect(err).NotTo(HaveOccurred()) Expect(a.ID).NotTo(BeZero()) if i%(N/C) == 0 { err := db.Delete(a) if err != pg.ErrNoRows { Expect(err).NotTo(HaveOccurred()) } } } }) count, err := db.Model(&Author{}).Count() Expect(err).NotTo(HaveOccurred()) Expect(count).To(Equal(1)) }) }) pg-5.3.3/stmt.go000066400000000000000000000117011305650307100134460ustar00rootroot00000000000000package pg import ( "sync" "time" "gopkg.in/pg.v5/internal" "gopkg.in/pg.v5/internal/pool" "gopkg.in/pg.v5/orm" "gopkg.in/pg.v5/types" ) // Stmt is a prepared statement. Stmt is safe for concurrent use by // multiple goroutines. type Stmt struct { db *DB mu sync.Mutex _cn *pool.Conn inTx bool q string name string columns [][]byte stickyErr error } // Prepare creates a prepared statement for later queries or // executions. Multiple queries or executions may be run concurrently // from the returned statement. 
func (db *DB) Prepare(q string) (*Stmt, error) { cn, err := db.conn() if err != nil { return nil, err } return prepare(db, cn, q) } func (stmt *Stmt) conn() (*pool.Conn, error) { if stmt._cn == nil { if stmt.stickyErr != nil { return nil, stmt.stickyErr } return nil, errStmtClosed } stmt._cn.SetReadWriteTimeout(stmt.db.opt.ReadTimeout, stmt.db.opt.WriteTimeout) return stmt._cn, nil } func (stmt *Stmt) exec(params ...interface{}) (*types.Result, error) { stmt.mu.Lock() defer stmt.mu.Unlock() cn, err := stmt.conn() if err != nil { return nil, err } return extQuery(cn, stmt.name, params...) } // Exec executes a prepared statement with the given parameters. func (stmt *Stmt) Exec(params ...interface{}) (res *types.Result, err error) { for i := 0; i < 3; i++ { res, err = stmt.exec(params...) if i >= stmt.db.opt.MaxRetries { break } if !stmt.db.shouldRetry(err) { break } time.Sleep(internal.RetryBackoff << uint(i)) } if err != nil { stmt.setErr(err) } return } // ExecOne acts like Exec, but query must affect only one row. It // returns ErrNoRows error when query returns zero rows or // ErrMultiRows when query returns multiple rows. func (stmt *Stmt) ExecOne(params ...interface{}) (*types.Result, error) { res, err := stmt.Exec(params...) if err != nil { return nil, err } if err := internal.AssertOneRow(res.RowsAffected()); err != nil { return nil, err } return res, nil } func (stmt *Stmt) query(model interface{}, params ...interface{}) (*types.Result, error) { stmt.mu.Lock() defer stmt.mu.Unlock() cn, err := stmt.conn() if err != nil { return nil, err } res, mod, err := extQueryData(cn, stmt.name, model, stmt.columns, params...) if err != nil { return nil, err } if res.RowsReturned() > 0 && mod != nil { if err = mod.AfterQuery(stmt.db); err != nil { return res, err } } return res, nil } // Query executes a prepared query statement with the given parameters. 
func (stmt *Stmt) Query(model interface{}, params ...interface{}) (res *types.Result, err error) { for i := 0; i < 3; i++ { res, err = stmt.query(model, params...) if i >= stmt.db.opt.MaxRetries { break } if !stmt.db.shouldRetry(err) { break } time.Sleep(internal.RetryBackoff << uint(i)) } if err != nil { stmt.setErr(err) } return } // QueryOne acts like Query, but query must return only one row. It // returns ErrNoRows error when query returns zero rows or // ErrMultiRows when query returns multiple rows. func (stmt *Stmt) QueryOne(model interface{}, params ...interface{}) (*types.Result, error) { mod, err := orm.NewModel(model) if err != nil { return nil, err } res, err := stmt.Query(mod, params...) if err != nil { return nil, err } if err := internal.AssertOneRow(res.RowsAffected()); err != nil { return nil, err } return res, nil } func (stmt *Stmt) setErr(e error) { if stmt.stickyErr == nil { stmt.stickyErr = e } } // Close closes the statement. func (stmt *Stmt) Close() error { stmt.mu.Lock() defer stmt.mu.Unlock() if stmt._cn == nil { return errStmtClosed } err := closeStmt(stmt._cn, stmt.name) if !stmt.inTx { _ = stmt.db.freeConn(stmt._cn, err) } stmt._cn = nil return err } func prepare(db *DB, cn *pool.Conn, q string) (*Stmt, error) { name := cn.NextId() writeParseDescribeSyncMsg(cn.Wr, name, q) if err := cn.FlushWriter(); err != nil { db.freeConn(cn, err) return nil, err } columns, err := readParseDescribeSync(cn) if err != nil { db.freeConn(cn, err) return nil, err } stmt := &Stmt{ db: db, _cn: cn, q: q, name: name, columns: columns, } return stmt, nil } func extQuery(cn *pool.Conn, name string, params ...interface{}) (*types.Result, error) { if err := writeBindExecuteMsg(cn.Wr, name, params...); err != nil { return nil, err } if err := cn.FlushWriter(); err != nil { return nil, err } return readExtQuery(cn) } func extQueryData( cn *pool.Conn, name string, model interface{}, columns [][]byte, params ...interface{}, ) (*types.Result, orm.Model, error) { if 
err := writeBindExecuteMsg(cn.Wr, name, params...); err != nil { return nil, nil, err } if err := cn.FlushWriter(); err != nil { return nil, nil, err } return readExtQueryData(cn, model, columns) } func closeStmt(cn *pool.Conn, name string) error { writeCloseMsg(cn.Wr, name) writeFlushMsg(cn.Wr) if err := cn.FlushWriter(); err != nil { return err } return readCloseCompleteMsg(cn) } pg-5.3.3/time.go000066400000000000000000000021251305650307100134150ustar00rootroot00000000000000package pg import ( "bytes" "database/sql" "encoding/json" "time" "gopkg.in/pg.v5/types" ) var jsonNull = []byte("null") // NullTime is a time.Time wrapper that marshals zero time as JSON null and // PostgreSQL NULL. type NullTime struct { time.Time } var _ json.Marshaler = (*NullTime)(nil) var _ json.Unmarshaler = (*NullTime)(nil) var _ sql.Scanner = (*NullTime)(nil) var _ types.ValueAppender = (*NullTime)(nil) func (tm NullTime) MarshalJSON() ([]byte, error) { if tm.IsZero() { return jsonNull, nil } return tm.Time.MarshalJSON() } func (tm *NullTime) UnmarshalJSON(b []byte) error { if bytes.Equal(b, jsonNull) { tm.Time = time.Time{} return nil } return tm.Time.UnmarshalJSON(b) } func (tm NullTime) AppendValue(b []byte, quote int) ([]byte, error) { if tm.IsZero() { return types.AppendNull(b, quote), nil } return types.AppendTime(b, tm.Time, quote), nil } func (tm *NullTime) Scan(b interface{}) error { if b == nil { tm.Time = time.Time{} return nil } newtm, err := types.ParseTime(b.([]byte)) if err != nil { return err } tm.Time = newtm return nil } pg-5.3.3/tx.go000066400000000000000000000152211305650307100131130ustar00rootroot00000000000000package pg import ( "io" "gopkg.in/pg.v5/internal" "gopkg.in/pg.v5/internal/pool" "gopkg.in/pg.v5/orm" "gopkg.in/pg.v5/types" ) // Tx is an in-progress database transaction. // // A transaction must end with a call to Commit or Rollback. // // After a call to Commit or Rollback, all operations on the transaction fail // with ErrTxDone. 
// // The statements prepared for a transaction by calling the transaction's // Prepare or Stmt methods are closed by the call to Commit or Rollback. type Tx struct { db *DB cn *pool.Conn stmts []*Stmt } var _ orm.DB = (*Tx)(nil) // Begin starts a transaction. Most callers should use RunInTransaction instead. func (db *DB) Begin() (*Tx, error) { tx := &Tx{ db: db, } if !db.opt.DisableTransaction { cn, err := db.conn() if err != nil { return nil, err } tx.cn = cn } if err := tx.begin(); err != nil { return nil, err } return tx, nil } // RunInTransaction runs a function in a transaction. If function // returns an error transaction is rollbacked, otherwise transaction // is committed. func (db *DB) RunInTransaction(fn func(*Tx) error) error { tx, err := db.Begin() if err != nil { return err } return tx.RunInTransaction(fn) } // Begin returns the transaction. func (tx *Tx) Begin() (*Tx, error) { return tx, nil } // RunInTransaction runs a function in the transaction. If function // returns an error transaction is rollbacked, otherwise transaction // is committed. func (tx *Tx) RunInTransaction(fn func(*Tx) error) error { if err := fn(tx); err != nil { tx.Rollback() return err } return tx.Commit() } func (tx *Tx) conn() (*pool.Conn, error) { var cn *pool.Conn if tx.db.opt.DisableTransaction { var err error cn, err = tx.db.conn() if err != nil { return nil, err } } else { cn = tx.cn if cn == nil { return nil, errTxDone } } cn.SetReadWriteTimeout(tx.db.opt.ReadTimeout, tx.db.opt.WriteTimeout) return cn, nil } func (tx *Tx) freeConn(cn *pool.Conn, err error) { if tx.db.opt.DisableTransaction { _ = tx.db.freeConn(cn, err) } } // Stmt returns a transaction-specific prepared statement from an existing statement. func (tx *Tx) Stmt(stmt *Stmt) *Stmt { stmt, err := tx.Prepare(stmt.q) if err != nil { return &Stmt{stickyErr: err} } return stmt } // Prepare creates a prepared statement for use within a transaction. 
// // The returned statement operates within the transaction and can no longer // be used once the transaction has been committed or rolled back. // // To use an existing prepared statement on this transaction, see Tx.Stmt. func (tx *Tx) Prepare(q string) (*Stmt, error) { cn, err := tx.conn() if err != nil { return nil, err } stmt, err := prepare(tx.db, cn, q) tx.freeConn(cn, err) if err != nil { return nil, err } stmt.inTx = true tx.stmts = append(tx.stmts, stmt) return stmt, nil } // Exec executes a query with the given parameters in a transaction. func (tx *Tx) Exec(query interface{}, params ...interface{}) (*types.Result, error) { cn, err := tx.conn() if err != nil { return nil, err } res, err := tx.db.simpleQuery(cn, query, params...) tx.freeConn(cn, err) return res, err } // ExecOne acts like Exec, but query must affect only one row. It // returns ErrNoRows error when query returns zero rows or // ErrMultiRows when query returns multiple rows. func (tx *Tx) ExecOne(query interface{}, params ...interface{}) (*types.Result, error) { res, err := tx.Exec(query, params...) if err != nil { return nil, err } if err := internal.AssertOneRow(res.RowsAffected()); err != nil { return nil, err } return res, nil } // Query executes a query with the given parameters in a transaction. func (tx *Tx) Query(model interface{}, query interface{}, params ...interface{}) (*types.Result, error) { cn, err := tx.conn() if err != nil { return nil, err } res, mod, err := tx.db.simpleQueryData(cn, model, query, params...) tx.freeConn(cn, err) if err != nil { return nil, err } if res.RowsReturned() > 0 && mod != nil { if err = mod.AfterQuery(tx); err != nil { return res, err } } return res, err } // QueryOne acts like Query, but query must return only one row. It // returns ErrNoRows error when query returns zero rows or // ErrMultiRows when query returns multiple rows. 
func (tx *Tx) QueryOne(model interface{}, query interface{}, params ...interface{}) (*types.Result, error) { mod, err := orm.NewModel(model) if err != nil { return nil, err } res, err := tx.Query(mod, query, params...) if err != nil { return nil, err } if err := internal.AssertOneRow(res.RowsAffected()); err != nil { return nil, err } return res, nil } // Model returns new query for the model. func (tx *Tx) Model(model ...interface{}) *orm.Query { return orm.NewQuery(tx, model...) } // Select selects the model by primary key. func (tx *Tx) Select(model interface{}) error { return orm.Select(tx, model) } // Insert inserts the model updating primary keys if they are empty. func (tx *Tx) Insert(model ...interface{}) error { return orm.Insert(tx, model...) } // Update updates the model by primary key. func (tx *Tx) Update(model interface{}) error { return orm.Update(tx, model) } // Delete deletes the model by primary key. func (tx *Tx) Delete(model interface{}) error { return orm.Delete(tx, model) } // CreateTable creates table for the model. It recognizes following field tags: // - notnull - sets NOT NULL constraint. // - unique - sets UNIQUE constraint. func (tx *Tx) CreateTable(model interface{}, opt *orm.CreateTableOptions) error { _, err := orm.CreateTable(tx, model, opt) return err } func (tx *Tx) FormatQuery(dst []byte, query string, params ...interface{}) []byte { return tx.db.FormatQuery(dst, query, params...) } func (tx *Tx) begin() error { if tx.db.opt.DisableTransaction { return nil } _, err := tx.Exec("BEGIN") return err } // Commit commits the transaction. func (tx *Tx) Commit() error { if tx.db.opt.DisableTransaction { return nil } _, err := tx.Exec("COMMIT") tx.close(err) return err } // Rollback aborts the transaction. 
func (tx *Tx) Rollback() error { if tx.db.opt.DisableTransaction { return nil } _, err := tx.Exec("ROLLBACK") tx.close(err) return err } func (tx *Tx) close(lastErr error) error { if tx.cn == nil { return errTxDone } for _, stmt := range tx.stmts { _ = stmt.Close() } tx.stmts = nil err := tx.db.freeConn(tx.cn, lastErr) tx.cn = nil return err } // CopyFrom copies data from reader to a table. func (tx *Tx) CopyFrom(r io.Reader, query string, params ...interface{}) (*types.Result, error) { cn, err := tx.conn() if err != nil { return nil, err } res, err := tx.db.copyFrom(cn, r, query, params...) tx.freeConn(cn, err) return res, err } pg-5.3.3/tx_test.go000066400000000000000000000064101305650307100141520ustar00rootroot00000000000000package pg_test import ( "strings" "gopkg.in/pg.v5" . "gopkg.in/check.v1" ) var _ = Suite(&TxTest{}) type TxTest struct { db *pg.DB } func (t *TxTest) SetUpTest(c *C) { t.db = pg.Connect(pgOptions()) } func (t *TxTest) TearDownTest(c *C) { c.Assert(t.db.Close(), IsNil) } func (t *TxTest) TestMultiPrepare(c *C) { tx, err := t.db.Begin() c.Assert(err, IsNil) stmt1, err := tx.Prepare(`SELECT 'test_multi_prepare_tx1'`) c.Assert(err, IsNil) stmt2, err := tx.Prepare(`SELECT 'test_multi_prepare_tx2'`) c.Assert(err, IsNil) var s1 string _, err = stmt1.QueryOne(pg.Scan(&s1)) c.Assert(err, IsNil) c.Assert(s1, Equals, "test_multi_prepare_tx1") var s2 string _, err = stmt2.QueryOne(pg.Scan(&s2)) c.Assert(err, IsNil) c.Assert(s2, Equals, "test_multi_prepare_tx2") c.Assert(tx.Rollback(), IsNil) } func (t *TxTest) TestCopyFromInTransaction(c *C) { data := "hello\t5\nworld\t5\nfoo\t3\nbar\t3\n" _, err := t.db.Exec("DROP TABLE IF EXISTS test_copy_from") c.Assert(err, IsNil) _, err = t.db.Exec("CREATE TABLE test_copy_from(word text, len int)") c.Assert(err, IsNil) tx1, err := t.db.Begin() c.Assert(err, IsNil) tx2, err := t.db.Begin() c.Assert(err, IsNil) r := strings.NewReader(data) res, err := tx1.CopyFrom(r, "COPY test_copy_from FROM STDIN") c.Assert(err, 
IsNil) c.Assert(res.RowsAffected(), Equals, 4) var count int _, err = tx1.QueryOne(pg.Scan(&count), "SELECT COUNT(*) FROM test_copy_from") c.Assert(err, IsNil) c.Assert(count, Equals, 4) _, err = tx2.QueryOne(pg.Scan(&count), "SELECT COUNT(*) FROM test_copy_from") c.Assert(err, IsNil) c.Assert(count, Equals, 0) c.Assert(tx1.Commit(), IsNil) _, err = tx2.QueryOne(pg.Scan(&count), "SELECT COUNT(*) FROM test_copy_from") c.Assert(err, IsNil) c.Assert(count, Equals, 4) // assuming READ COMMITTED c.Assert(tx2.Rollback(), IsNil) _, err = t.db.Exec("DROP TABLE IF EXISTS test_copy_from") c.Assert(err, IsNil) } func (t *TxTest) TestCopyFromInTransactionWithErrors(c *C) { // too many fields on second line data := "hello\t5\nworld\t5\t6\t8\t9\nfoo\t3\nbar\t3\n" _, err := t.db.Exec("DROP TABLE IF EXISTS test_copy_from") c.Assert(err, IsNil) _, err = t.db.Exec("CREATE TABLE test_copy_from(word text, len int)") c.Assert(err, IsNil) _, err = t.db.Exec("INSERT INTO test_copy_from VALUES ('xxx', 3)") c.Assert(err, IsNil) tx1, err := t.db.Begin() c.Assert(err, IsNil) tx2, err := t.db.Begin() c.Assert(err, IsNil) _, err = tx1.Exec("INSERT INTO test_copy_from VALUES ('yyy', 3)") c.Assert(err, IsNil) r := strings.NewReader(data) _, err = tx1.CopyFrom(r, "COPY test_copy_from FROM STDIN") c.Assert(err, Not(IsNil)) var count int _, err = tx1.QueryOne(pg.Scan(&count), "SELECT COUNT(*) FROM test_copy_from") c.Assert(err, Not(IsNil)) // transaction has errors, cannot proceed _, err = tx2.QueryOne(pg.Scan(&count), "SELECT COUNT(*) FROM test_copy_from") c.Assert(err, IsNil) c.Assert(count, Equals, 1) c.Assert(tx1.Commit(), IsNil) // actually ROLLBACK happens here _, err = tx2.QueryOne(pg.Scan(&count), "SELECT COUNT(*) FROM test_copy_from") c.Assert(err, IsNil) c.Assert(count, Equals, 1) // other transaction was rolled back so it's not 2 and not 6 c.Assert(tx2.Rollback(), IsNil) _, err = t.db.Exec("DROP TABLE IF EXISTS test_copy_from") c.Assert(err, IsNil) } 
pg-5.3.3/types/000077500000000000000000000000001305650307100132745ustar00rootroot00000000000000pg-5.3.3/types/append.go000066400000000000000000000065601305650307100151010ustar00rootroot00000000000000package types import ( "database/sql/driver" "encoding/hex" "reflect" "strconv" "time" ) func Append(b []byte, v interface{}, quote int) []byte { switch v := v.(type) { case nil: return AppendNull(b, quote) case bool: return appendBool(b, v) case int8: return strconv.AppendInt(b, int64(v), 10) case int16: return strconv.AppendInt(b, int64(v), 10) case int32: return strconv.AppendInt(b, int64(v), 10) case int64: return strconv.AppendInt(b, int64(v), 10) case int: return strconv.AppendInt(b, int64(v), 10) case uint8: return strconv.AppendUint(b, uint64(v), 10) case uint16: return strconv.AppendUint(b, uint64(v), 10) case uint32: return strconv.AppendUint(b, uint64(v), 10) case uint64: return strconv.AppendUint(b, v, 10) case uint: return strconv.AppendUint(b, uint64(v), 10) case float32: return appendFloat(b, float64(v)) case float64: return appendFloat(b, v) case string: return AppendString(b, v, quote) case time.Time: return AppendTime(b, v, quote) case []byte: return appendBytes(b, v, quote) case ValueAppender: return appendAppender(b, v, quote) case driver.Valuer: return appendDriverValuer(b, v, quote) default: return appendValue(b, reflect.ValueOf(v), quote) } } func AppendError(b []byte, err error) []byte { b = append(b, "?!("...) b = append(b, err.Error()...) b = append(b, ')') return b } func AppendNull(b []byte, quote int) []byte { if quote == 1 { return append(b, "NULL"...) } else { return nil } } func appendBool(dst []byte, v bool) []byte { if v { return append(dst, "TRUE"...) } return append(dst, "FALSE"...) 
} func appendFloat(dst []byte, v float64) []byte { return strconv.AppendFloat(dst, v, 'f', -1, 64) } func AppendString(b []byte, s string, quote int) []byte { if quote == 2 { b = append(b, '"') } else if quote == 1 { b = append(b, '\'') } for i := 0; i < len(s); i++ { c := s[i] if c == '\000' { continue } if quote >= 1 { if c == '\'' { b = append(b, '\'', '\'') continue } } if quote == 2 { if c == '"' { b = append(b, '\\', '"') continue } if c == '\\' { b = append(b, '\\', '\\') continue } } b = append(b, c) } if quote >= 2 { b = append(b, '"') } else if quote == 1 { b = append(b, '\'') } return b } func appendBytes(b []byte, bytes []byte, quote int) []byte { if bytes == nil { return AppendNull(b, quote) } if quote == 1 { b = append(b, '\'') } tmp := make([]byte, hex.EncodedLen(len(bytes))) hex.Encode(tmp, bytes) b = append(b, "\\x"...) b = append(b, tmp...) if quote == 1 { b = append(b, '\'') } return b } func AppendStringStringMap(b []byte, m map[string]string, quote int) []byte { if m == nil { return AppendNull(b, quote) } if quote == 1 { b = append(b, '\'') } for key, value := range m { b = AppendString(b, key, 2) b = append(b, '=', '>') b = AppendString(b, value, 2) b = append(b, ',') } if len(m) > 0 { b = b[:len(b)-1] // Strip trailing comma. 
} if quote == 1 { b = append(b, '\'') } return b } func appendDriverValuer(b []byte, v driver.Valuer, quote int) []byte { value, err := v.Value() if err != nil { return AppendError(b, err) } return Append(b, value, quote) } func appendAppender(b []byte, v ValueAppender, quote int) []byte { bb, err := v.AppendValue(b, quote) if err != nil { return AppendError(b, err) } return bb } pg-5.3.3/types/append_array.go000066400000000000000000000071021305650307100162700ustar00rootroot00000000000000package types import ( "reflect" "strconv" ) var stringType = reflect.TypeOf((*string)(nil)).Elem() var sliceStringType = reflect.TypeOf([]string(nil)) var intType = reflect.TypeOf((*int)(nil)).Elem() var sliceIntType = reflect.TypeOf([]int(nil)) var int64Type = reflect.TypeOf((*int64)(nil)).Elem() var sliceInt64Type = reflect.TypeOf([]int64(nil)) var float64Type = reflect.TypeOf((*float64)(nil)).Elem() var sliceFloat64Type = reflect.TypeOf([]float64(nil)) func ArrayAppender(typ reflect.Type) AppenderFunc { elemType := typ.Elem() switch elemType { case stringType: return appendSliceStringValue case intType: return appendSliceIntValue case int64Type: return appendSliceInt64Value case float64Type: return appendSliceFloat64Value } appendElem := appender(elemType, true) return func(b []byte, v reflect.Value, quote int) []byte { if v.IsNil() { return AppendNull(b, quote) } if quote == 1 { b = append(b, '\'') } b = append(b, '{') for i := 0; i < v.Len(); i++ { elem := v.Index(i) b = appendElem(b, elem, 2) b = append(b, ',') } if v.Len() > 0 { b[len(b)-1] = '}' // Replace trailing comma. 
} else { b = append(b, '}') } if quote == 1 { b = append(b, '\'') } return b } } func appendSliceStringValue(b []byte, v reflect.Value, quote int) []byte { ss := v.Convert(sliceStringType).Interface().([]string) return appendSliceString(b, ss, quote) } func appendSliceString(b []byte, ss []string, quote int) []byte { if ss == nil { return AppendNull(b, quote) } if quote == 1 { b = append(b, '\'') } b = append(b, '{') for _, s := range ss { b = AppendString(b, s, 2) b = append(b, ',') } if len(ss) > 0 { b[len(b)-1] = '}' // Replace trailing comma. } else { b = append(b, '}') } if quote == 1 { b = append(b, '\'') } return b } func appendSliceIntValue(b []byte, v reflect.Value, quote int) []byte { ints := v.Convert(sliceIntType).Interface().([]int) return appendSliceInt(b, ints, quote) } func appendSliceInt(b []byte, ints []int, quote int) []byte { if ints == nil { return AppendNull(b, quote) } if quote == 1 { b = append(b, '\'') } b = append(b, '{') for _, n := range ints { b = strconv.AppendInt(b, int64(n), 10) b = append(b, ',') } if len(ints) > 0 { b[len(b)-1] = '}' // Replace trailing comma. } else { b = append(b, '}') } if quote == 1 { b = append(b, '\'') } return b } func appendSliceInt64Value(b []byte, v reflect.Value, quote int) []byte { ints := v.Convert(sliceInt64Type).Interface().([]int64) return appendSliceInt64(b, ints, quote) } func appendSliceInt64(b []byte, ints []int64, quote int) []byte { if ints == nil { return AppendNull(b, quote) } if quote == 1 { b = append(b, '\'') } b = append(b, '{') for _, n := range ints { b = strconv.AppendInt(b, n, 10) b = append(b, ',') } if len(ints) > 0 { b[len(b)-1] = '}' // Replace trailing comma. 
} else { b = append(b, '}') } if quote == 1 { b = append(b, '\'') } return b } func appendSliceFloat64Value(b []byte, v reflect.Value, quote int) []byte { floats := v.Convert(sliceFloat64Type).Interface().([]float64) return appendSliceFloat64(b, floats, quote) } func appendSliceFloat64(b []byte, floats []float64, quote int) []byte { if floats == nil { return AppendNull(b, quote) } if quote == 1 { b = append(b, '\'') } b = append(b, '{') for _, n := range floats { b = appendFloat(b, n) b = append(b, ',') } if len(floats) > 0 { b[len(b)-1] = '}' // Replace trailing comma. } else { b = append(b, '}') } if quote == 1 { b = append(b, '\'') } return b } pg-5.3.3/types/append_field.go000066400000000000000000000016661305650307100162460ustar00rootroot00000000000000package types import "gopkg.in/pg.v5/internal/parser" func AppendField(b []byte, field string, quote int) []byte { return appendField(b, parser.NewString(field), quote) } func AppendFieldBytes(b []byte, field []byte, quote int) []byte { return appendField(b, parser.New(field), quote) } func appendField(b []byte, p *parser.Parser, quote int) []byte { var quoted bool for p.Valid() { c := p.Read() switch c { case '*': if !quoted { b = append(b, '*') continue } case '.': if quoted && quote == 1 { b = append(b, '"') quoted = false } b = append(b, '.') if p.Skip('*') { b = append(b, '*') } else if quote == 1 { b = append(b, '"') quoted = true } continue } if !quoted && quote == 1 { b = append(b, '"') quoted = true } if c == '"' { b = append(b, '"', '"') } else { b = append(b, c) } } if quote == 1 && quoted { b = append(b, '"') } return b } pg-5.3.3/types/append_field_test.go000066400000000000000000000012221305650307100172710ustar00rootroot00000000000000package types_test import ( "testing" "gopkg.in/pg.v5/types" ) var appendFieldTests = []struct { field string wanted string }{ {"", ""}, {"id", `"id"`}, {"table.id", `"table"."id"`}, {"*", "*"}, {"table.*", `"table".*`}, {"id AS pk", `"id AS pk"`}, {"table.id AS 
table__id", `"table"."id AS table__id"`}, {"?shard", `"?shard"`}, {"?shard.id", `"?shard"."id"`}, {`"`, `""""`}, {`'`, `"'"`}, } func TestAppendField(t *testing.T) { for _, test := range appendFieldTests { got := types.AppendField(nil, test.field, 1) if string(got) != test.wanted { t.Errorf("got %q, wanted %q (field=%q)", got, test.wanted, test.field) } } } pg-5.3.3/types/append_hstore.go000066400000000000000000000020201305650307100164500ustar00rootroot00000000000000package types import ( "fmt" "reflect" ) var mapStringStringType = reflect.TypeOf(map[string]string(nil)) func HstoreAppender(typ reflect.Type) AppenderFunc { if typ.Key() == stringType && typ.Elem() == stringType { return appendMapStringStringValue } return func(b []byte, v reflect.Value, quote int) []byte { err := fmt.Errorf("pg.Hstore(unsupported %s)", v.Type()) return AppendError(b, err) } } func appendMapStringString(b []byte, m map[string]string, quote int) []byte { if m == nil { return AppendNull(b, quote) } if quote == 1 { b = append(b, '\'') } for key, value := range m { b = AppendString(b, key, 2) b = append(b, '=', '>') b = AppendString(b, value, 2) b = append(b, ',') } if len(m) > 0 { b = b[:len(b)-1] // Strip trailing comma. } if quote == 1 { b = append(b, '\'') } return b } func appendMapStringStringValue(b []byte, v reflect.Value, quote int) []byte { m := v.Convert(mapStringStringType).Interface().(map[string]string) return appendMapStringString(b, m, quote) } pg-5.3.3/types/append_jsonb.go000066400000000000000000000011621305650307100162650ustar00rootroot00000000000000package types import "gopkg.in/pg.v5/internal/parser" func AppendJSONB(b, jsonb []byte, quote int) []byte { if quote == 1 { b = append(b, '\'') } p := parser.New(jsonb) for p.Valid() { c := p.Read() switch c { case '\'': if quote == 1 { b = append(b, '\'', '\'') } else { b = append(b, '\'') } case '\000': continue case '\\': if p.SkipBytes([]byte("u0000")) { b = append(b, "\\\\u0000"...) 
} else { b = append(b, '\\') if p.Valid() { b = append(b, p.Read()) } } default: b = append(b, c) } } if quote == 1 { b = append(b, '\'') } return b } pg-5.3.3/types/append_jsonb_test.go000066400000000000000000000014131305650307100173230ustar00rootroot00000000000000package types_test import ( "bytes" "encoding/json" "testing" "gopkg.in/pg.v5/types" ) var jsonbTests = []struct { s, wanted string }{ {`\u0000`, `\\u0000`}, {`\\u0000`, `\\u0000`}, {`\\\u0000`, `\\\\u0000`}, {`foo \u0000 bar`, `foo \\u0000 bar`}, {`\u0001`, `\u0001`}, {`\\u0001`, `\\u0001`}, } func TestAppendJSONB(t *testing.T) { for _, test := range jsonbTests { got := types.AppendJSONB(nil, []byte(test.s), 0) if !bytes.Equal(got, []byte(test.wanted)) { t.Errorf("got %q, wanted %q", got, test.wanted) } } } func BenchmarkAppendJSONB(b *testing.B) { bytes, err := json.Marshal(jsonbTests) if err != nil { b.Fatal(err) } buf := make([]byte, 1024) b.ResetTimer() for i := 0; i < b.N; i++ { _ = types.AppendJSONB(buf[:0], bytes, 1) } } pg-5.3.3/types/append_value.go000066400000000000000000000075751305650307100163040ustar00rootroot00000000000000package types import ( "database/sql/driver" "encoding/json" "reflect" "strconv" "time" ) var appenderType = reflect.TypeOf((*ValueAppender)(nil)).Elem() type AppenderFunc func([]byte, reflect.Value, int) []byte var valueAppenders []AppenderFunc func init() { valueAppenders = []AppenderFunc{ reflect.Bool: appendBoolValue, reflect.Int: appendIntValue, reflect.Int8: appendIntValue, reflect.Int16: appendIntValue, reflect.Int32: appendIntValue, reflect.Int64: appendIntValue, reflect.Uint: appendUintValue, reflect.Uint8: appendUintValue, reflect.Uint16: appendUintValue, reflect.Uint32: appendUintValue, reflect.Uint64: appendUintValue, reflect.Uintptr: nil, reflect.Float32: appendFloatValue, reflect.Float64: appendFloatValue, reflect.Complex64: nil, reflect.Complex128: nil, reflect.Array: nil, reflect.Chan: nil, reflect.Func: nil, reflect.Interface: appendIfaceValue, 
reflect.Map: appendJSONValue, reflect.Ptr: nil, reflect.Slice: appendJSONValue, reflect.String: appendStringValue, reflect.Struct: appendStructValue, reflect.UnsafePointer: nil, } } func Appender(typ reflect.Type) AppenderFunc { return appender(typ, false) } func appender(typ reflect.Type, pgArray bool) AppenderFunc { if typ == timeType { return appendTimeValue } if typ.Implements(appenderType) { return appendAppenderValue } if typ.Implements(driverValuerType) { return appendDriverValuerValue } kind := typ.Kind() switch kind { case reflect.Ptr: return ptrAppenderFunc(typ) case reflect.Slice: if typ.Elem().Kind() == reflect.Uint8 { return appendBytesValue } if pgArray { return ArrayAppender(typ) } } return valueAppenders[kind] } func ptrAppenderFunc(typ reflect.Type) AppenderFunc { appender := Appender(typ.Elem()) return func(b []byte, v reflect.Value, quote int) []byte { if v.IsNil() { return AppendNull(b, quote) } return appender(b, v.Elem(), quote) } } func appendValue(b []byte, v reflect.Value, quote int) []byte { if v.Kind() == reflect.Ptr { if v.IsNil() { return AppendNull(b, quote) } return appendValue(b, v.Elem(), quote) } appender := Appender(v.Type()) return appender(b, v, quote) } func appendIfaceValue(b []byte, v reflect.Value, quote int) []byte { return Append(b, v.Interface(), quote) } func appendBoolValue(b []byte, v reflect.Value, _ int) []byte { return appendBool(b, v.Bool()) } func appendIntValue(b []byte, v reflect.Value, _ int) []byte { return strconv.AppendInt(b, v.Int(), 10) } func appendUintValue(b []byte, v reflect.Value, _ int) []byte { return strconv.AppendUint(b, v.Uint(), 10) } func appendFloatValue(b []byte, v reflect.Value, _ int) []byte { return appendFloat(b, v.Float()) } func appendBytesValue(b []byte, v reflect.Value, quote int) []byte { return appendBytes(b, v.Bytes(), quote) } func appendStringValue(b []byte, v reflect.Value, quote int) []byte { return AppendString(b, v.String(), quote) } func appendStructValue(b []byte, v 
reflect.Value, quote int) []byte { if v.Type() == timeType { return appendTimeValue(b, v, quote) } return appendJSONValue(b, v, quote) } func appendJSONValue(b []byte, v reflect.Value, quote int) []byte { bytes, err := json.Marshal(v.Interface()) if err != nil { return AppendError(b, err) } return AppendJSONB(b, bytes, quote) } func appendTimeValue(b []byte, v reflect.Value, quote int) []byte { tm := v.Interface().(time.Time) return AppendTime(b, tm, quote) } func appendAppenderValue(b []byte, v reflect.Value, quote int) []byte { return appendAppender(b, v.Interface().(ValueAppender), quote) } func appendDriverValuerValue(b []byte, v reflect.Value, quote int) []byte { return appendDriverValuer(b, v.Interface().(driver.Valuer), quote) } pg-5.3.3/types/array.go000066400000000000000000000015751305650307100147510ustar00rootroot00000000000000package types import ( "database/sql" "fmt" "reflect" ) type Array struct { v reflect.Value append AppenderFunc scan ScannerFunc } var _ ValueAppender = (*Array)(nil) var _ sql.Scanner = (*Array)(nil) func NewArray(vi interface{}) *Array { v := reflect.ValueOf(vi) if !v.IsValid() { panic(fmt.Errorf("pg.Array(nil)")) } v = reflect.Indirect(v) if v.Kind() != reflect.Slice { panic(fmt.Errorf("pg.Array(unsupported %s)", v.Type())) } return &Array{ v: v, append: ArrayAppender(v.Type()), scan: ArrayScanner(v.Type()), } } func (a *Array) Value() interface{} { if a.v.IsValid() { return a.v.Interface() } return nil } func (a *Array) AppendValue(b []byte, quote int) ([]byte, error) { b = a.append(b, a.v, quote) return b, nil } func (a *Array) Scan(b interface{}) error { if b == nil { return a.scan(a.v, nil) } return a.scan(a.v, b.([]byte)) } pg-5.3.3/types/hstore.go000066400000000000000000000016101305650307100151250ustar00rootroot00000000000000package types import ( "database/sql" "fmt" "reflect" ) type Hstore struct { v reflect.Value append AppenderFunc scan ScannerFunc } var _ ValueAppender = (*Hstore)(nil) var _ sql.Scanner = 
(*Hstore)(nil) func NewHstore(vi interface{}) *Hstore { v := reflect.ValueOf(vi) if !v.IsValid() { panic(fmt.Errorf("pg.Hstore(nil)")) } v = reflect.Indirect(v) if v.Kind() != reflect.Map { panic(fmt.Errorf("pg.Hstore(unsupported %s)", v.Type())) } return &Hstore{ v: v, append: HstoreAppender(v.Type()), scan: HstoreScanner(v.Type()), } } func (h *Hstore) Value() interface{} { if h.v.IsValid() { return h.v.Interface() } return nil } func (h *Hstore) AppendValue(b []byte, quote int) ([]byte, error) { b = h.append(b, h.v, quote) return b, nil } func (h *Hstore) Scan(b interface{}) error { if b == nil { return h.scan(h.v, nil) } return h.scan(h.v, b.([]byte)) } pg-5.3.3/types/in_op.go000066400000000000000000000012321305650307100147250ustar00rootroot00000000000000package types import ( "fmt" "reflect" ) type InOp struct { slice reflect.Value append AppenderFunc } var _ ValueAppender = (*InOp)(nil) func In(slice interface{}) *InOp { v := reflect.ValueOf(slice) if !v.IsValid() { panic(fmt.Errorf("pg.In(nil)")) } if v.Kind() != reflect.Slice { panic(fmt.Errorf("pg.In(unsupported %s)", v.Type())) } return &InOp{ slice: v, append: Appender(v.Type().Elem()), } } func (in *InOp) AppendValue(b []byte, quote int) ([]byte, error) { for i := 0; i < in.slice.Len(); i++ { b = in.append(b, in.slice.Index(i), quote) b = append(b, ',') } if in.slice.Len() > 0 { b = b[:len(b)-1] } return b, nil } pg-5.3.3/types/interface.go000066400000000000000000000012131305650307100155600ustar00rootroot00000000000000package types type ValueAppender interface { AppendValue(b []byte, quote int) ([]byte, error) } //------------------------------------------------------------------------------ // Q represents safe SQL query. type Q string var _ ValueAppender = Q("") func (q Q) AppendValue(dst []byte, quote int) ([]byte, error) { return append(dst, q...), nil } //------------------------------------------------------------------------------ // F represents a SQL field, e.g. table or column name. 
type F string var _ ValueAppender = F("") func (f F) AppendValue(dst []byte, quote int) ([]byte, error) { return AppendField(dst, string(f), quote), nil } pg-5.3.3/types/result.go000066400000000000000000000015601305650307100151430ustar00rootroot00000000000000package types import ( "bytes" "strconv" "gopkg.in/pg.v5/internal" ) // A Result summarizes an executed SQL command. type Result struct { affected int returned int } func NewResult(b []byte, returned int) *Result { res := Result{ affected: -1, returned: returned, } ind := bytes.LastIndexByte(b, ' ') if ind == -1 { return &res } s := internal.BytesToString(b[ind+1 : len(b)-1]) affected, err := strconv.Atoi(s) if err == nil { res.affected = affected } return &res } // RowsAffected returns the number of rows affected by SELECT, INSERT, UPDATE, // or DELETE queries. It returns -1 when query can't possibly affect any rows, // e.g. in case of CREATE or SHOW queries. func (r Result) RowsAffected() int { return r.affected } // RowsReturned returns the number of rows returned by the query. 
func (r Result) RowsReturned() int { return r.returned } pg-5.3.3/types/scan.go000066400000000000000000000026471305650307100145600ustar00rootroot00000000000000package types import ( "database/sql" "encoding/hex" "reflect" "strconv" "time" "gopkg.in/pg.v5/internal" ) func Scan(v interface{}, b []byte) error { switch v := v.(type) { case *string: *v = string(b) return nil case *[]byte: if b == nil { *v = nil return nil } var err error *v, err = scanBytes(b) return err case *int: if b == nil { *v = 0 return nil } var err error *v, err = strconv.Atoi(internal.BytesToString(b)) return err case *int64: if b == nil { *v = 0 return nil } var err error *v, err = strconv.ParseInt(internal.BytesToString(b), 10, 64) return err case *time.Time: if b == nil { *v = time.Time{} return nil } var err error *v, err = ParseTime(b) return err } vv := reflect.ValueOf(v) if !vv.IsValid() { return internal.Errorf("pg: Scan(nil)") } if vv.Kind() != reflect.Ptr { return internal.Errorf("pg: Scan(non-pointer %T)", v) } vv = vv.Elem() if !vv.IsValid() { return internal.Errorf("pg: Scan(non-pointer %T)", v) } return ScanValue(vv, b) } func scanSQLScanner(scanner sql.Scanner, b []byte) error { if b == nil { return scanner.Scan(nil) } return scanner.Scan(b) } func scanBytes(b []byte) ([]byte, error) { if len(b) < 2 { return nil, internal.Errorf("pg: can't parse bytes: %q", b) } b = b[2:] // Trim off "\\x". 
tmp := make([]byte, hex.DecodedLen(len(b))) _, err := hex.Decode(tmp, b) return tmp, err } pg-5.3.3/types/scan_array.go000066400000000000000000000070531305650307100157520ustar00rootroot00000000000000package types import ( "reflect" "strconv" "gopkg.in/pg.v5/internal" "gopkg.in/pg.v5/internal/parser" ) func ArrayScanner(typ reflect.Type) ScannerFunc { elemType := typ.Elem() switch elemType { case stringType: return scanSliceStringValue case intType: return scanSliceIntValue case int64Type: return scanSliceInt64Value case float64Type: return scanSliceFloat64Value } scanElem := scanner(elemType, true) return func(v reflect.Value, b []byte) error { if !v.CanSet() { return internal.Errorf("pg: Scan(non-pointer %s)", v.Type()) } if b == nil { if !v.IsNil() { v.Set(reflect.New(v.Type())) } return nil } if v.IsNil() { v.Set(reflect.MakeSlice(v.Type(), 0, 0)) } p := parser.NewArrayParser(b) for p.Valid() { elem, err := p.NextElem() if err != nil { return err } elemValue := internal.SliceNextElem(v) if err := scanElem(elemValue, elem); err != nil { return err } } return nil } } func scanSliceStringValue(v reflect.Value, b []byte) error { if !v.CanSet() { return internal.Errorf("pg: Scan(non-pointer %s)", v.Type()) } strings, err := decodeSliceString(b) if err != nil { return err } v.Set(reflect.ValueOf(strings)) return nil } func decodeSliceString(b []byte) ([]string, error) { if b == nil { return nil, nil } p := parser.NewArrayParser(b) s := make([]string, 0) for p.Valid() { elem, err := p.NextElem() if err != nil { return nil, err } s = append(s, string(elem)) } return s, nil } func scanSliceIntValue(v reflect.Value, b []byte) error { if !v.CanSet() { return internal.Errorf("pg: Scan(non-pointer %s)", v.Type()) } ints, err := decodeSliceInt(b) if err != nil { return err } v.Set(reflect.ValueOf(ints)) return nil } func decodeSliceInt(b []byte) ([]int, error) { if b == nil { return nil, nil } p := parser.NewArrayParser(b) slice := make([]int, 0) for p.Valid() { elem, err := 
p.NextElem() if err != nil { return nil, err } if elem == nil { slice = append(slice, 0) continue } n, err := strconv.Atoi(string(elem)) if err != nil { return nil, err } slice = append(slice, n) } return slice, nil } func scanSliceInt64Value(v reflect.Value, b []byte) error { if !v.CanSet() { return internal.Errorf("pg: Scan(non-pointer %s)", v.Type()) } ints, err := decodeSliceInt64(b) if err != nil { return err } v.Set(reflect.ValueOf(ints)) return nil } func decodeSliceInt64(b []byte) ([]int64, error) { if b == nil { return nil, nil } p := parser.NewArrayParser(b) slice := make([]int64, 0) for p.Valid() { elem, err := p.NextElem() if err != nil { return nil, err } if elem == nil { slice = append(slice, 0) continue } n, err := strconv.ParseInt(internal.BytesToString(elem), 10, 64) if err != nil { return nil, err } slice = append(slice, n) } return slice, nil } func scanSliceFloat64Value(v reflect.Value, b []byte) error { if !v.CanSet() { return internal.Errorf("pg: Scan(non-pointer %s)", v.Type()) } floats, err := decodeSliceFloat64(b) if err != nil { return err } v.Set(reflect.ValueOf(floats)) return nil } func decodeSliceFloat64(b []byte) ([]float64, error) { if b == nil { return nil, nil } p := parser.NewArrayParser(b) slice := make([]float64, 0) for p.Valid() { elem, err := p.NextElem() if err != nil { return nil, err } if elem == nil { slice = append(slice, 0) continue } n, err := strconv.ParseFloat(internal.BytesToString(elem), 64) if err != nil { return nil, err } slice = append(slice, n) } return slice, nil } pg-5.3.3/types/scan_hstore.go000066400000000000000000000022031305650307100161300ustar00rootroot00000000000000package types import ( "fmt" "reflect" "gopkg.in/pg.v5/internal" "gopkg.in/pg.v5/internal/parser" ) func HstoreScanner(typ reflect.Type) ScannerFunc { if typ.Key() == stringType && typ.Elem() == stringType { return scanMapStringStringValue } return func(v reflect.Value, b []byte) error { return fmt.Errorf("pg.Hstore(unsupported %s)", 
v.Type()) } } func scanMapStringString(b []byte) (map[string]string, error) { if b == nil { return nil, nil } p := parser.NewHstoreParser(b) m := make(map[string]string) for p.Valid() { key, err := p.NextKey() if err != nil { return nil, err } if key == nil { return nil, fmt.Errorf("pg: unexpected NULL: %q", b) } value, err := p.NextValue() if err != nil { return nil, err } if value == nil { return nil, fmt.Errorf("pg: unexpected NULL: %q", b) } m[string(key)] = string(value) } return m, nil } func scanMapStringStringValue(v reflect.Value, b []byte) error { if !v.CanSet() { return internal.Errorf("pg: Scan(non-pointer %s)", v.Type()) } m, err := scanMapStringString(b) if err != nil { return err } v.Set(reflect.ValueOf(m)) return nil } pg-5.3.3/types/scan_value.go000066400000000000000000000136041305650307100157470ustar00rootroot00000000000000package types import ( "database/sql" "database/sql/driver" "encoding/json" "reflect" "strconv" "time" "gopkg.in/pg.v5/internal" ) var ( scannerType = reflect.TypeOf(new(sql.Scanner)).Elem() driverValuerType = reflect.TypeOf(new(driver.Valuer)).Elem() ) var ( timeType = reflect.TypeOf((*time.Time)(nil)).Elem() ) type ScannerFunc func(reflect.Value, []byte) error var valueScanners []ScannerFunc func init() { valueScanners = []ScannerFunc{ reflect.Bool: scanBoolValue, reflect.Int: scanIntValue, reflect.Int8: scanIntValue, reflect.Int16: scanIntValue, reflect.Int32: scanIntValue, reflect.Int64: scanIntValue, reflect.Uint: scanUintValue, reflect.Uint8: scanUintValue, reflect.Uint16: scanUintValue, reflect.Uint32: scanUintValue, reflect.Uint64: scanUintValue, reflect.Uintptr: nil, reflect.Float32: scanFloatValue, reflect.Float64: scanFloatValue, reflect.Complex64: nil, reflect.Complex128: nil, reflect.Array: nil, reflect.Chan: nil, reflect.Func: nil, reflect.Interface: scanIfaceValue, reflect.Map: scanJSONValue, reflect.Ptr: nil, reflect.Slice: scanJSONValue, reflect.String: scanStringValue, reflect.Struct: scanJSONValue, 
reflect.UnsafePointer: nil, } } func Scanner(typ reflect.Type) ScannerFunc { return scanner(typ, false) } func scanner(typ reflect.Type, pgArray bool) ScannerFunc { if typ == timeType { return scanTimeValue } if typ.Implements(scannerType) { return scanSQLScannerValue } if reflect.PtrTo(typ).Implements(scannerType) { return scanSQLScannerAddrValue } kind := typ.Kind() switch kind { case reflect.Ptr: return ptrScannerFunc(typ) case reflect.Slice: if typ.Elem().Kind() == reflect.Uint8 { return scanBytesValue } if pgArray { return ArrayScanner(typ) } } return valueScanners[kind] } func ptrScannerFunc(typ reflect.Type) ScannerFunc { scanner := Scanner(typ.Elem()) return func(v reflect.Value, b []byte) error { if scanner == nil { return internal.Errorf("pg: Scan(unsupported %s)", v.Type()) } if b == nil { if v.IsNil() { return nil } if !v.CanSet() { return internal.Errorf("pg: Scan(non-pointer %s)", v.Type()) } v.Set(reflect.Zero(v.Type())) return nil } if v.IsNil() { if !v.CanSet() { return internal.Errorf("pg: Scan(non-pointer %s)", v.Type()) } v.Set(reflect.New(v.Type().Elem())) } return scanner(v.Elem(), b) } } func scanIfaceValue(v reflect.Value, b []byte) error { if v.IsNil() { return scanJSONValue(v, b) } return ScanValue(v.Elem(), b) } func IsSQLScanner(typ reflect.Type) bool { if typ.Implements(scannerType) { return true } if reflect.PtrTo(typ).Implements(scannerType) { return true } return false } func ScanValue(v reflect.Value, b []byte) error { if !v.IsValid() { return internal.Errorf("pg: Scan(nil)") } scanner := Scanner(v.Type()) if scanner != nil { return scanner(v, b) } if v.Kind() == reflect.Interface { return internal.Errorf("pg: Scan(nil)") } return internal.Errorf("pg: Scan(unsupported %s)", v.Type()) } func scanBoolValue(v reflect.Value, b []byte) error { if !v.CanSet() { return internal.Errorf("pg: Scan(non-pointer %s)", v.Type()) } if b == nil { v.SetBool(false) return nil } v.SetBool(len(b) == 1 && (b[0] == 't' || b[0] == '1')) return nil } func 
scanIntValue(v reflect.Value, b []byte) error { if !v.CanSet() { return internal.Errorf("pg: Scan(non-pointer %s)", v.Type()) } if b == nil { v.SetInt(0) return nil } n, err := strconv.ParseInt(internal.BytesToString(b), 10, 64) if err != nil { return err } v.SetInt(n) return nil } func scanUintValue(v reflect.Value, b []byte) error { if !v.CanSet() { return internal.Errorf("pg: Scan(non-pointer %s)", v.Type()) } if b == nil { v.SetUint(0) return nil } n, err := strconv.ParseUint(internal.BytesToString(b), 10, 64) if err != nil { return err } v.SetUint(n) return nil } func scanFloatValue(v reflect.Value, b []byte) error { if !v.CanSet() { return internal.Errorf("pg: Scan(non-pointer %s)", v.Type()) } if b == nil { v.SetFloat(0) return nil } n, err := strconv.ParseFloat(internal.BytesToString(b), 64) if err != nil { return err } v.SetFloat(n) return nil } func scanStringValue(v reflect.Value, b []byte) error { if !v.CanSet() { return internal.Errorf("pg: Scan(non-pointer %s)", v.Type()) } v.SetString(string(b)) return nil } func scanJSONValue(v reflect.Value, b []byte) error { if !v.CanSet() { return internal.Errorf("pg: Scan(non-pointer %s)", v.Type()) } if b == nil { v.Set(reflect.New(v.Type()).Elem()) return nil } return json.Unmarshal(b, v.Addr().Interface()) } var zeroTimeValue = reflect.ValueOf(time.Time{}) func scanTimeValue(v reflect.Value, b []byte) error { if !v.CanSet() { return internal.Errorf("pg: Scan(non-pointer %s)", v.Type()) } if b == nil { v.Set(zeroTimeValue) return nil } tm, err := ParseTime(b) if err != nil { return err } v.Set(reflect.ValueOf(tm)) return nil } func scanBytesValue(v reflect.Value, b []byte) error { if !v.CanSet() { return internal.Errorf("pg: Scan(non-pointer %s)", v.Type()) } if b == nil { v.SetBytes(nil) return nil } bs, err := scanBytes(b) if err != nil { return err } v.SetBytes(bs) return nil } func scanSQLScannerValue(v reflect.Value, b []byte) error { if b == nil { if v.IsNil() { return nil } return 
scanSQLScanner(v.Interface().(sql.Scanner), nil) } if v.IsNil() { v.Set(reflect.New(v.Type().Elem())) } return scanSQLScanner(v.Interface().(sql.Scanner), b) } func scanSQLScannerAddrValue(v reflect.Value, b []byte) error { if !v.CanAddr() { return internal.Errorf("pg: Scan(non-pointer %s)", v.Type()) } return scanSQLScanner(v.Addr().Interface().(sql.Scanner), b) } pg-5.3.3/types/time.go000066400000000000000000000022021305650307100145550ustar00rootroot00000000000000package types import "time" const ( dateFormat = "2006-01-02" timeFormat = "15:04:05.999999999" timestampFormat = "2006-01-02 15:04:05.999999999" timestamptzFormat = "2006-01-02 15:04:05.999999999-07:00:00" timestamptzFormat2 = "2006-01-02 15:04:05.999999999-07:00" timestamptzFormat3 = "2006-01-02 15:04:05.999999999-07" ) func ParseTime(b []byte) (time.Time, error) { switch l := len(b); { case l <= len(dateFormat): return time.Parse(dateFormat, string(b)) case l <= len(timeFormat): return time.Parse(timeFormat, string(b)) default: if c := b[len(b)-9]; c == '+' || c == '-' { return time.Parse(timestamptzFormat, string(b)) } if c := b[len(b)-6]; c == '+' || c == '-' { return time.Parse(timestamptzFormat2, string(b)) } if c := b[len(b)-3]; c == '+' || c == '-' { return time.Parse(timestamptzFormat3, string(b)) } return time.ParseInLocation(timestampFormat, string(b), time.Local) } } func AppendTime(b []byte, tm time.Time, quote int) []byte { if quote == 1 { b = append(b, '\'') } b = tm.AppendFormat(b, timestamptzFormat) if quote == 1 { b = append(b, '\'') } return b } pg-5.3.3/types/time_test.go000066400000000000000000000003011305650307100156120ustar00rootroot00000000000000package types_test import ( "testing" "gopkg.in/pg.v5/types" ) func BenchmarkParseTime(b *testing.B) { for i := 0; i < b.N; i++ { types.ParseTime([]byte("2001-02-03 04:05:06+07")) } }