pax_global_header00006660000000000000000000000064142265372130014517gustar00rootroot0000000000000052 comment=28212d434cdd87418ecd0cb81173690ce7ac6ab6 sqlx-1.3.5/000077500000000000000000000000001422653721300125145ustar00rootroot00000000000000sqlx-1.3.5/.gitignore000066400000000000000000000004171422653721300145060ustar00rootroot00000000000000# Compiled Object files, Static and Dynamic libs (Shared Objects) *.o *.a *.so # Folders _obj _test .idea # Architecture specific extensions/prefixes *.[568vq] [568vq].out *.cgo1.go *.cgo2.c _cgo_defun.c _cgo_gotypes.go _cgo_export.* _testmain.go *.exe tags environ sqlx-1.3.5/.travis.yml000066400000000000000000000011621422653721300146250ustar00rootroot00000000000000# vim: ft=yaml sw=2 ts=2 language: go # enable database services services: - mysql - postgresql # create test database before_install: - mysql -e 'CREATE DATABASE IF NOT EXISTS sqlxtest;' - psql -c 'create database sqlxtest;' -U postgres - go get github.com/mattn/goveralls - export SQLX_MYSQL_DSN="travis:@/sqlxtest?parseTime=true" - export SQLX_POSTGRES_DSN="postgres://postgres:@localhost/sqlxtest?sslmode=disable" - export SQLX_SQLITE_DSN="$HOME/sqlxtest.db" # go versions to test go: - "1.15.x" - "1.16.x" # run tests w/ coverage script: - travis_retry $GOPATH/bin/goveralls -service=travis-ci sqlx-1.3.5/LICENSE000066400000000000000000000020651422653721300135240ustar00rootroot00000000000000 Copyright (c) 2013, Jason Moiron Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. sqlx-1.3.5/README.md000066400000000000000000000205701422653721300137770ustar00rootroot00000000000000# sqlx [![Build Status](https://travis-ci.org/jmoiron/sqlx.svg?branch=master)](https://travis-ci.org/jmoiron/sqlx) [![Coverage Status](https://coveralls.io/repos/github/jmoiron/sqlx/badge.svg?branch=master)](https://coveralls.io/github/jmoiron/sqlx?branch=master) [![Godoc](http://img.shields.io/badge/godoc-reference-blue.svg?style=flat)](https://godoc.org/github.com/jmoiron/sqlx) [![license](http://img.shields.io/badge/license-MIT-red.svg?style=flat)](https://raw.githubusercontent.com/jmoiron/sqlx/master/LICENSE) sqlx is a library which provides a set of extensions on Go's standard `database/sql` library. The sqlx versions of `sql.DB`, `sql.Tx`, `sql.Stmt`, et al. all leave the underlying interfaces untouched, so that their interfaces are a superset of the standard ones. This makes it relatively painless to integrate existing codebases using database/sql with sqlx. Major additional concepts are: * Marshal rows into structs (with embedded struct support), maps, and slices * Named parameter support including prepared statements * `Get` and `Select` to go quickly from query to struct/slice In addition to the [godoc API documentation](http://godoc.org/github.com/jmoiron/sqlx), there is also some [user documentation](http://jmoiron.github.io/sqlx/) that explains how to use `database/sql` along with sqlx.
## Recent Changes 1.3.0: * `sqlx.DB.Connx(context.Context) *sqlx.Conn` * `sqlx.BindDriver(driverName, bindType)` * support for `[]map[string]interface{}` to do "batch" insertions * allocation & perf improvements for `sqlx.In` DB.Connx returns an `sqlx.Conn`, which is an `sql.Conn`-alike consistent with sqlx's wrapping of other types. `BindDriver` allows users to control the bindvars that sqlx will use for drivers, and add new drivers at runtime. This results in a very slight performance hit when resolving the driver into a bind type (~40ns per call), but it allows users to specify what bindtype their driver uses even when sqlx has not been updated to know about it by default. ### Backwards Compatibility Compatibility with the most recent two versions of Go is a requirement for any new changes. Compatibility beyond that is not guaranteed. Versioning is done with Go modules. Breaking changes (eg. removing deprecated API) will get major version number bumps. ## install go get github.com/jmoiron/sqlx ## issues Row headers can be ambiguous (`SELECT 1 AS a, 2 AS a`), and the result of `Columns()` does not fully qualify column names in queries like: ```sql SELECT a.id, a.name, b.id, b.name FROM foos AS a JOIN foos AS b ON a.parent = b.id; ``` making a struct or map destination ambiguous. Use `AS` in your queries to give columns distinct names, `rows.Scan` to scan them manually, or `SliceScan` to get a slice of results. ## usage Below is an example which shows some common use cases for sqlx. Check [sqlx_test.go](https://github.com/jmoiron/sqlx/blob/master/sqlx_test.go) for more usage. 
```go package main import ( "database/sql" "fmt" "log" _ "github.com/lib/pq" "github.com/jmoiron/sqlx" ) var schema = ` CREATE TABLE person ( first_name text, last_name text, email text ); CREATE TABLE place ( country text, city text NULL, telcode integer )` type Person struct { FirstName string `db:"first_name"` LastName string `db:"last_name"` Email string } type Place struct { Country string City sql.NullString TelCode int } func main() { // this Pings the database trying to connect // use sqlx.Open() for sql.Open() semantics db, err := sqlx.Connect("postgres", "user=foo dbname=bar sslmode=disable") if err != nil { log.Fatalln(err) } // exec the schema or fail; multi-statement Exec behavior varies between // database drivers; pq will exec them all, sqlite3 won't, ymmv db.MustExec(schema) tx := db.MustBegin() tx.MustExec("INSERT INTO person (first_name, last_name, email) VALUES ($1, $2, $3)", "Jason", "Moiron", "jmoiron@jmoiron.net") tx.MustExec("INSERT INTO person (first_name, last_name, email) VALUES ($1, $2, $3)", "John", "Doe", "johndoeDNE@gmail.net") tx.MustExec("INSERT INTO place (country, city, telcode) VALUES ($1, $2, $3)", "United States", "New York", "1") tx.MustExec("INSERT INTO place (country, telcode) VALUES ($1, $2)", "Hong Kong", "852") tx.MustExec("INSERT INTO place (country, telcode) VALUES ($1, $2)", "Singapore", "65") // Named queries can use structs, so if you have an existing struct (i.e. 
person := &Person{}) that you have populated, you can pass it in as &person tx.NamedExec("INSERT INTO person (first_name, last_name, email) VALUES (:first_name, :last_name, :email)", &Person{"Jane", "Citizen", "jane.citzen@example.com"}) tx.Commit() // Query the database, storing results in a []Person (wrapped in []interface{}) people := []Person{} db.Select(&people, "SELECT * FROM person ORDER BY first_name ASC") jason, john := people[0], people[1] fmt.Printf("%#v\n%#v", jason, john) // Person{FirstName:"Jason", LastName:"Moiron", Email:"jmoiron@jmoiron.net"} // Person{FirstName:"John", LastName:"Doe", Email:"johndoeDNE@gmail.net"} // You can also get a single result, a la QueryRow jason = Person{} err = db.Get(&jason, "SELECT * FROM person WHERE first_name=$1", "Jason") fmt.Printf("%#v\n", jason) // Person{FirstName:"Jason", LastName:"Moiron", Email:"jmoiron@jmoiron.net"} // if you have null fields and use SELECT *, you must use sql.Null* in your struct places := []Place{} err = db.Select(&places, "SELECT * FROM place ORDER BY telcode ASC") if err != nil { fmt.Println(err) return } usa, singsing, honkers := places[0], places[1], places[2] fmt.Printf("%#v\n%#v\n%#v\n", usa, singsing, honkers) // Place{Country:"United States", City:sql.NullString{String:"New York", Valid:true}, TelCode:1} // Place{Country:"Singapore", City:sql.NullString{String:"", Valid:false}, TelCode:65} // Place{Country:"Hong Kong", City:sql.NullString{String:"", Valid:false}, TelCode:852} // Loop through rows using only one struct place := Place{} rows, err := db.Queryx("SELECT * FROM place") for rows.Next() { err := rows.StructScan(&place) if err != nil { log.Fatalln(err) } fmt.Printf("%#v\n", place) } // Place{Country:"United States", City:sql.NullString{String:"New York", Valid:true}, TelCode:1} // Place{Country:"Hong Kong", City:sql.NullString{String:"", Valid:false}, TelCode:852} // Place{Country:"Singapore", City:sql.NullString{String:"", Valid:false}, TelCode:65} // Named queries, using 
`:name` as the bindvar. Automatic bindvar support // which takes into account the dbtype based on the driverName on sqlx.Open/Connect _, err = db.NamedExec(`INSERT INTO person (first_name,last_name,email) VALUES (:first,:last,:email)`, map[string]interface{}{ "first": "Bin", "last": "Smuth", "email": "bensmith@allblacks.nz", }) // Selects Mr. Smith from the database rows, err = db.NamedQuery(`SELECT * FROM person WHERE first_name=:fn`, map[string]interface{}{"fn": "Bin"}) // Named queries can also use structs. Their bind names follow the same rules // as the name -> db mapping, so struct fields are lowercased and the `db` tag // is taken into consideration. rows, err = db.NamedQuery(`SELECT * FROM person WHERE first_name=:first_name`, jason) // batch insert // batch insert with structs personStructs := []Person{ {FirstName: "Ardie", LastName: "Savea", Email: "asavea@ab.co.nz"}, {FirstName: "Sonny Bill", LastName: "Williams", Email: "sbw@ab.co.nz"}, {FirstName: "Ngani", LastName: "Laumape", Email: "nlaumape@ab.co.nz"}, } _, err = db.NamedExec(`INSERT INTO person (first_name, last_name, email) VALUES (:first_name, :last_name, :email)`, personStructs) // batch insert with maps personMaps := []map[string]interface{}{ {"first_name": "Ardie", "last_name": "Savea", "email": "asavea@ab.co.nz"}, {"first_name": "Sonny Bill", "last_name": "Williams", "email": "sbw@ab.co.nz"}, {"first_name": "Ngani", "last_name": "Laumape", "email": "nlaumape@ab.co.nz"}, } _, err = db.NamedExec(`INSERT INTO person (first_name, last_name, email) VALUES (:first_name, :last_name, :email)`, personMaps) } ``` sqlx-1.3.5/bind.go000066400000000000000000000141021422653721300137550ustar00rootroot00000000000000package sqlx import ( "bytes" "database/sql/driver" "errors" "reflect" "strconv" "strings" "sync" "github.com/jmoiron/sqlx/reflectx" ) // Bindvar types supported by Rebind, BindMap and BindStruct. 
const (
	UNKNOWN  = iota // driver not registered; queries keep the '?' bindvar
	QUESTION        // ?
	DOLLAR          // $1, $2, ...
	NAMED           // :arg1, :arg2, ... from Rebind; :name from named queries
	AT              // @p1, @p2, ...
)

// defaultBinds lists, for each bindvar style, the driver names that use it
// out of the box. init copies these into binds; other drivers can be
// registered at runtime with BindDriver.
var defaultBinds = map[int][]string{
	DOLLAR:   {"postgres", "pgx", "pq-timeouts", "cloudsqlpostgres", "ql", "nrpostgres", "cockroach"},
	QUESTION: {"mysql", "sqlite3", "nrmysql", "nrsqlite3"},
	NAMED:    {"oci8", "ora", "goracle", "godror"},
	AT:       {"sqlserver"},
}

// binds maps driverName (string) to bindType (int). A sync.Map is used so
// BindDriver may run concurrently with BindType lookups.
var binds sync.Map

func init() {
	// driverName (not "driver") avoids shadowing the database/sql/driver import.
	for bindType, drivers := range defaultBinds {
		for _, driverName := range drivers {
			BindDriver(driverName, bindType)
		}
	}
}

// BindType returns the bindtype for a given database given a drivername.
// Unregistered drivers report UNKNOWN.
func BindType(driverName string) int {
	itype, ok := binds.Load(driverName)
	if !ok {
		return UNKNOWN
	}
	return itype.(int)
}

// BindDriver sets the BindType for driverName to bindType.
func BindDriver(driverName string, bindType int) {
	binds.Store(driverName, bindType)
}

// FIXME: this should be able to be tolerant of escaped ?'s in queries without
// losing much speed, and should be, to avoid confusion.

// Rebind a query from the default bindtype (QUESTION) to the target bindtype.
//
// Every '?' in query is replaced by the target placeholder, numbered from 1
// in order of appearance: "$N" for DOLLAR, ":argN" for NAMED and "@pN" for
// AT. For QUESTION, UNKNOWN, or any other value (for example one registered
// through BindDriver that has no placeholder style here), the query is
// returned unchanged rather than having its placeholders stripped.
func Rebind(bindType int, query string) string {
	var prefix string
	switch bindType {
	case DOLLAR:
		prefix = "$"
	case NAMED:
		prefix = ":arg"
	case AT:
		prefix = "@p"
	default:
		return query
	}

	// Add space enough for a handful of params before we have to allocate.
	rqb := make([]byte, 0, len(query)+10)
	n := 0
	for i := strings.IndexByte(query, '?'); i != -1; i = strings.IndexByte(query, '?') {
		n++
		rqb = append(rqb, query[:i]...)
		rqb = append(rqb, prefix...)
		rqb = strconv.AppendInt(rqb, int64(n), 10)
		query = query[i+1:]
	}
	return string(append(rqb, query...))
}

// Experimental implementation of Rebind which uses a bytes.Buffer. The code is
// much simpler and should be more resistant to odd unicode, but it is twice as
// slow. Kept here for benchmarking purposes and to possibly replace Rebind if
// problems arise with its somewhat naive handling of unicode.
func rebindBuff(bindType int, query string) string { if bindType != DOLLAR { return query } b := make([]byte, 0, len(query)) rqb := bytes.NewBuffer(b) j := 1 for _, r := range query { if r == '?' { rqb.WriteRune('$') rqb.WriteString(strconv.Itoa(j)) j++ } else { rqb.WriteRune(r) } } return rqb.String() } func asSliceForIn(i interface{}) (v reflect.Value, ok bool) { if i == nil { return reflect.Value{}, false } v = reflect.ValueOf(i) t := reflectx.Deref(v.Type()) // Only expand slices if t.Kind() != reflect.Slice { return reflect.Value{}, false } // []byte is a driver.Value type so it should not be expanded if t == reflect.TypeOf([]byte{}) { return reflect.Value{}, false } return v, true } // In expands slice values in args, returning the modified query string // and a new arg list that can be executed by a database. The `query` should // use the `?` bindVar. The return value uses the `?` bindVar. func In(query string, args ...interface{}) (string, []interface{}, error) { // argMeta stores reflect.Value and length for slices and // the value itself for non-slice arguments type argMeta struct { v reflect.Value i interface{} length int } var flatArgsCount int var anySlices bool var stackMeta [32]argMeta var meta []argMeta if len(args) <= len(stackMeta) { meta = stackMeta[:len(args)] } else { meta = make([]argMeta, len(args)) } for i, arg := range args { if a, ok := arg.(driver.Valuer); ok { var err error arg, err = a.Value() if err != nil { return "", nil, err } } if v, ok := asSliceForIn(arg); ok { meta[i].length = v.Len() meta[i].v = v anySlices = true flatArgsCount += meta[i].length if meta[i].length == 0 { return "", nil, errors.New("empty slice passed to 'in' query") } } else { meta[i].i = arg flatArgsCount++ } } // don't do any parsing if there aren't any slices; note that this means // some errors that we might have caught below will not be returned. 
if !anySlices { return query, args, nil } newArgs := make([]interface{}, 0, flatArgsCount) var buf strings.Builder buf.Grow(len(query) + len(", ?")*flatArgsCount) var arg, offset int for i := strings.IndexByte(query[offset:], '?'); i != -1; i = strings.IndexByte(query[offset:], '?') { if arg >= len(meta) { // if an argument wasn't passed, lets return an error; this is // not actually how database/sql Exec/Query works, but since we are // creating an argument list programmatically, we want to be able // to catch these programmer errors earlier. return "", nil, errors.New("number of bindVars exceeds arguments") } argMeta := meta[arg] arg++ // not a slice, continue. // our questionmark will either be written before the next expansion // of a slice or after the loop when writing the rest of the query if argMeta.length == 0 { offset = offset + i + 1 newArgs = append(newArgs, argMeta.i) continue } // write everything up to and including our ? character buf.WriteString(query[:offset+i+1]) for si := 1; si < argMeta.length; si++ { buf.WriteString(", ?") } newArgs = appendReflectSlice(newArgs, argMeta.v, argMeta.length) // slice the query and reset the offset. this avoids some bookkeeping for // the write after the loop query = query[offset+i+1:] offset = 0 } buf.WriteString(query) if arg < len(meta) { return "", nil, errors.New("number of bindVars less than number arguments") } return buf.String(), newArgs, nil } func appendReflectSlice(args []interface{}, v reflect.Value, vlen int) []interface{} { switch val := v.Interface().(type) { case []interface{}: args = append(args, val...) 
case []int: for i := range val { args = append(args, val[i]) } case []string: for i := range val { args = append(args, val[i]) } default: for si := 0; si < vlen; si++ { args = append(args, v.Index(si).Interface()) } } return args } sqlx-1.3.5/bind_test.go000066400000000000000000000030001422653721300150070ustar00rootroot00000000000000package sqlx import ( "math/rand" "testing" ) func oldBindType(driverName string) int { switch driverName { case "postgres", "pgx", "pq-timeouts", "cloudsqlpostgres", "ql": return DOLLAR case "mysql": return QUESTION case "sqlite3": return QUESTION case "oci8", "ora", "goracle", "godror": return NAMED case "sqlserver": return AT } return UNKNOWN } /* sync.Map implementation: goos: linux goarch: amd64 pkg: github.com/jmoiron/sqlx BenchmarkBindSpeed/old-4 100000000 11.0 ns/op BenchmarkBindSpeed/new-4 24575726 50.8 ns/op async.Value map implementation: goos: linux goarch: amd64 pkg: github.com/jmoiron/sqlx BenchmarkBindSpeed/old-4 100000000 11.0 ns/op BenchmarkBindSpeed/new-4 42535839 27.5 ns/op */ func BenchmarkBindSpeed(b *testing.B) { testDrivers := []string{ "postgres", "pgx", "mysql", "sqlite3", "ora", "sqlserver", } b.Run("old", func(b *testing.B) { b.StopTimer() var seq []int for i := 0; i < b.N; i++ { seq = append(seq, rand.Intn(len(testDrivers))) } b.StartTimer() for i := 0; i < b.N; i++ { s := oldBindType(testDrivers[seq[i]]) if s == UNKNOWN { b.Error("unknown driver") } } }) b.Run("new", func(b *testing.B) { b.StopTimer() var seq []int for i := 0; i < b.N; i++ { seq = append(seq, rand.Intn(len(testDrivers))) } b.StartTimer() for i := 0; i < b.N; i++ { s := BindType(testDrivers[seq[i]]) if s == UNKNOWN { b.Error("unknown driver") } } }) } sqlx-1.3.5/doc.go000066400000000000000000000010721422653721300136100ustar00rootroot00000000000000// Package sqlx provides general purpose extensions to database/sql. 
// // It is intended to seamlessly wrap database/sql and provide convenience // methods which are useful in the development of database driven applications. // None of the underlying database/sql methods are changed. Instead all extended // behavior is implemented through new methods defined on wrapper types. // // Additions include scanning into structs, named query support, rebinding // queries for different drivers, convenient shorthands for common error handling // and more. // package sqlx sqlx-1.3.5/go.mod000066400000000000000000000002331422653721300136200ustar00rootroot00000000000000module github.com/jmoiron/sqlx go 1.10 require ( github.com/go-sql-driver/mysql v1.6.0 github.com/lib/pq v1.2.0 github.com/mattn/go-sqlite3 v1.14.6 ) sqlx-1.3.5/go.sum000066400000000000000000000007731422653721300136560ustar00rootroot00000000000000github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE= github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/lib/pq v1.2.0 h1:LXpIM/LZ5xGFhOpXAQUIMM1HdyqzVYM13zNdjCEEcA0= github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/mattn/go-sqlite3 v1.14.6 h1:dNPt6NO46WmLVt2DLNpwczCmdV5boIZ6g/tlDrlRUbg= github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= sqlx-1.3.5/named.go000066400000000000000000000342231422653721300141330ustar00rootroot00000000000000package sqlx // Named Query Support // // * BindMap - bind query bindvars to map/struct args // * NamedExec, NamedQuery - named query w/ struct or map // * NamedStmt - a pre-compiled named query which is a prepared statement // // Internal Interfaces: // // * compileNamedQuery - rebind a named query, returning a query and list of names // * bindArgs, bindMapArgs, bindAnyArgs - given a list of names, return an arglist // import ( "bytes" "database/sql" "errors" "fmt" "reflect" "regexp" "strconv" "unicode" "github.com/jmoiron/sqlx/reflectx" ) // 
NamedStmt is a prepared statement that executes named queries. Prepare it // how you would execute a NamedQuery, but pass in a struct or map when executing. type NamedStmt struct { Params []string QueryString string Stmt *Stmt } // Close closes the named statement. func (n *NamedStmt) Close() error { return n.Stmt.Close() } // Exec executes a named statement using the struct passed. // Any named placeholder parameters are replaced with fields from arg. func (n *NamedStmt) Exec(arg interface{}) (sql.Result, error) { args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper) if err != nil { return *new(sql.Result), err } return n.Stmt.Exec(args...) } // Query executes a named statement using the struct argument, returning rows. // Any named placeholder parameters are replaced with fields from arg. func (n *NamedStmt) Query(arg interface{}) (*sql.Rows, error) { args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper) if err != nil { return nil, err } return n.Stmt.Query(args...) } // QueryRow executes a named statement against the database. Because sqlx cannot // create a *sql.Row with an error condition pre-set for binding errors, sqlx // returns a *sqlx.Row instead. // Any named placeholder parameters are replaced with fields from arg. func (n *NamedStmt) QueryRow(arg interface{}) *Row { args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper) if err != nil { return &Row{err: err} } return n.Stmt.QueryRowx(args...) } // MustExec execs a NamedStmt, panicing on error // Any named placeholder parameters are replaced with fields from arg. func (n *NamedStmt) MustExec(arg interface{}) sql.Result { res, err := n.Exec(arg) if err != nil { panic(err) } return res } // Queryx using this NamedStmt // Any named placeholder parameters are replaced with fields from arg. func (n *NamedStmt) Queryx(arg interface{}) (*Rows, error) { r, err := n.Query(arg) if err != nil { return nil, err } return &Rows{Rows: r, Mapper: n.Stmt.Mapper, unsafe: isUnsafe(n)}, err } // QueryRowx this NamedStmt. 
Because of limitations with QueryRow, this is // an alias for QueryRow. // Any named placeholder parameters are replaced with fields from arg. func (n *NamedStmt) QueryRowx(arg interface{}) *Row { return n.QueryRow(arg) } // Select using this NamedStmt // Any named placeholder parameters are replaced with fields from arg. func (n *NamedStmt) Select(dest interface{}, arg interface{}) error { rows, err := n.Queryx(arg) if err != nil { return err } // if something happens here, we want to make sure the rows are Closed defer rows.Close() return scanAll(rows, dest, false) } // Get using this NamedStmt // Any named placeholder parameters are replaced with fields from arg. func (n *NamedStmt) Get(dest interface{}, arg interface{}) error { r := n.QueryRowx(arg) return r.scanAny(dest, false) } // Unsafe creates an unsafe version of the NamedStmt func (n *NamedStmt) Unsafe() *NamedStmt { r := &NamedStmt{Params: n.Params, Stmt: n.Stmt, QueryString: n.QueryString} r.Stmt.unsafe = true return r } // A union interface of preparer and binder, required to be able to prepare // named statements (as the bindtype must be determined). type namedPreparer interface { Preparer binder } func prepareNamed(p namedPreparer, query string) (*NamedStmt, error) { bindType := BindType(p.DriverName()) q, args, err := compileNamedQuery([]byte(query), bindType) if err != nil { return nil, err } stmt, err := Preparex(p, q) if err != nil { return nil, err } return &NamedStmt{ QueryString: q, Params: args, Stmt: stmt, }, nil } // convertMapStringInterface attempts to convert v to map[string]interface{}. // Unlike v.(map[string]interface{}), this function works on named types that // are convertible to map[string]interface{} as well. 
func convertMapStringInterface(v interface{}) (map[string]interface{}, bool) { var m map[string]interface{} mtype := reflect.TypeOf(m) t := reflect.TypeOf(v) if !t.ConvertibleTo(mtype) { return nil, false } return reflect.ValueOf(v).Convert(mtype).Interface().(map[string]interface{}), true } func bindAnyArgs(names []string, arg interface{}, m *reflectx.Mapper) ([]interface{}, error) { if maparg, ok := convertMapStringInterface(arg); ok { return bindMapArgs(names, maparg) } return bindArgs(names, arg, m) } // private interface to generate a list of interfaces from a given struct // type, given a list of names to pull out of the struct. Used by public // BindStruct interface. func bindArgs(names []string, arg interface{}, m *reflectx.Mapper) ([]interface{}, error) { arglist := make([]interface{}, 0, len(names)) // grab the indirected value of arg v := reflect.ValueOf(arg) for v = reflect.ValueOf(arg); v.Kind() == reflect.Ptr; { v = v.Elem() } err := m.TraversalsByNameFunc(v.Type(), names, func(i int, t []int) error { if len(t) == 0 { return fmt.Errorf("could not find name %s in %#v", names[i], arg) } val := reflectx.FieldByIndexesReadOnly(v, t) arglist = append(arglist, val.Interface()) return nil }) return arglist, err } // like bindArgs, but for maps. func bindMapArgs(names []string, arg map[string]interface{}) ([]interface{}, error) { arglist := make([]interface{}, 0, len(names)) for _, name := range names { val, ok := arg[name] if !ok { return arglist, fmt.Errorf("could not find name %s in %#v", name, arg) } arglist = append(arglist, val) } return arglist, nil } // bindStruct binds a named parameter query with fields from a struct argument. // The rules for binding field names to parameter names follow the same // conventions as for StructScan, including obeying the `db` struct tags. 
func bindStruct(bindType int, query string, arg interface{}, m *reflectx.Mapper) (string, []interface{}, error) { bound, names, err := compileNamedQuery([]byte(query), bindType) if err != nil { return "", []interface{}{}, err } arglist, err := bindAnyArgs(names, arg, m) if err != nil { return "", []interface{}{}, err } return bound, arglist, nil } var valuesReg = regexp.MustCompile(`\)\s*(?i)VALUES\s*\(`) func findMatchingClosingBracketIndex(s string) int { count := 0 for i, ch := range s { if ch == '(' { count++ } if ch == ')' { count-- if count == 0 { return i } } } return 0 } func fixBound(bound string, loop int) string { loc := valuesReg.FindStringIndex(bound) // defensive guard when "VALUES (...)" not found if len(loc) < 2 { return bound } openingBracketIndex := loc[1] - 1 index := findMatchingClosingBracketIndex(bound[openingBracketIndex:]) // defensive guard. must have closing bracket if index == 0 { return bound } closingBracketIndex := openingBracketIndex + index + 1 var buffer bytes.Buffer buffer.WriteString(bound[0:closingBracketIndex]) for i := 0; i < loop-1; i++ { buffer.WriteString(",") buffer.WriteString(bound[openingBracketIndex:closingBracketIndex]) } buffer.WriteString(bound[closingBracketIndex:]) return buffer.String() } // bindArray binds a named parameter query with fields from an array or slice of // structs argument. func bindArray(bindType int, query string, arg interface{}, m *reflectx.Mapper) (string, []interface{}, error) { // do the initial binding with QUESTION; if bindType is not question, // we can rebind it at the end. 
bound, names, err := compileNamedQuery([]byte(query), QUESTION) if err != nil { return "", []interface{}{}, err } arrayValue := reflect.ValueOf(arg) arrayLen := arrayValue.Len() if arrayLen == 0 { return "", []interface{}{}, fmt.Errorf("length of array is 0: %#v", arg) } var arglist = make([]interface{}, 0, len(names)*arrayLen) for i := 0; i < arrayLen; i++ { elemArglist, err := bindAnyArgs(names, arrayValue.Index(i).Interface(), m) if err != nil { return "", []interface{}{}, err } arglist = append(arglist, elemArglist...) } if arrayLen > 1 { bound = fixBound(bound, arrayLen) } // adjust binding type if we weren't on question if bindType != QUESTION { bound = Rebind(bindType, bound) } return bound, arglist, nil } // bindMap binds a named parameter query with a map of arguments. func bindMap(bindType int, query string, args map[string]interface{}) (string, []interface{}, error) { bound, names, err := compileNamedQuery([]byte(query), bindType) if err != nil { return "", []interface{}{}, err } arglist, err := bindMapArgs(names, args) return bound, arglist, err } // -- Compilation of Named Queries // Allow digits and letters in bind params; additionally runes are // checked against underscores, meaning that bind params can have be // alphanumeric with underscores. Mind the difference between unicode // digits and numbers, where '5' is a digit but '五' is not. var allowedBindRunes = []*unicode.RangeTable{unicode.Letter, unicode.Digit} // FIXME: this function isn't safe for unicode named params, as a failing test // can testify. This is not a regression but a failure of the original code // as well. It should be modified to range over runes in a string rather than // bytes, even though this is less convenient and slower. Hopefully the // addition of the prepared NamedStmt (which will only do this once) will make // up for the slightly slower ad-hoc NamedExec/NamedQuery. // compile a NamedQuery into an unbound query (using the '?' bindvar) and // a list of names. 
func compileNamedQuery(qs []byte, bindType int) (query string, names []string, err error) { names = make([]string, 0, 10) rebound := make([]byte, 0, len(qs)) inName := false last := len(qs) - 1 currentVar := 1 name := make([]byte, 0, 10) for i, b := range qs { // a ':' while we're in a name is an error if b == ':' { // if this is the second ':' in a '::' escape sequence, append a ':' if inName && i > 0 && qs[i-1] == ':' { rebound = append(rebound, ':') inName = false continue } else if inName { err = errors.New("unexpected `:` while reading named param at " + strconv.Itoa(i)) return query, names, err } inName = true name = []byte{} } else if inName && i > 0 && b == '=' && len(name) == 0 { rebound = append(rebound, ':', '=') inName = false continue // if we're in a name, and this is an allowed character, continue } else if inName && (unicode.IsOneOf(allowedBindRunes, rune(b)) || b == '_' || b == '.') && i != last { // append the byte to the name if we are in a name and not on the last byte name = append(name, b) // if we're in a name and it's not an allowed character, the name is done } else if inName { inName = false // if this is the final byte of the string and it is part of the name, then // make sure to add it to the name if i == last && unicode.IsOneOf(allowedBindRunes, rune(b)) { name = append(name, b) } // add the string representation to the names list names = append(names, string(name)) // add a proper bindvar for the bindType switch bindType { // oracle only supports named type bind vars even for positional case NAMED: rebound = append(rebound, ':') rebound = append(rebound, name...) 
case QUESTION, UNKNOWN: rebound = append(rebound, '?') case DOLLAR: rebound = append(rebound, '$') for _, b := range strconv.Itoa(currentVar) { rebound = append(rebound, byte(b)) } currentVar++ case AT: rebound = append(rebound, '@', 'p') for _, b := range strconv.Itoa(currentVar) { rebound = append(rebound, byte(b)) } currentVar++ } // add this byte to string unless it was not part of the name if i != last { rebound = append(rebound, b) } else if !unicode.IsOneOf(allowedBindRunes, rune(b)) { rebound = append(rebound, b) } } else { // this is a normal byte and should just go onto the rebound query rebound = append(rebound, b) } } return string(rebound), names, err } // BindNamed binds a struct or a map to a query with named parameters. // DEPRECATED: use sqlx.Named` instead of this, it may be removed in future. func BindNamed(bindType int, query string, arg interface{}) (string, []interface{}, error) { return bindNamedMapper(bindType, query, arg, mapper()) } // Named takes a query using named parameters and an argument and // returns a new query with a list of args that can be executed by // a database. The return value uses the `?` bindvar. func Named(query string, arg interface{}) (string, []interface{}, error) { return bindNamedMapper(QUESTION, query, arg, mapper()) } func bindNamedMapper(bindType int, query string, arg interface{}, m *reflectx.Mapper) (string, []interface{}, error) { t := reflect.TypeOf(arg) k := t.Kind() switch { case k == reflect.Map && t.Key().Kind() == reflect.String: m, ok := convertMapStringInterface(arg) if !ok { return "", nil, fmt.Errorf("sqlx.bindNamedMapper: unsupported map type: %T", arg) } return bindMap(bindType, query, m) case k == reflect.Array || k == reflect.Slice: return bindArray(bindType, query, arg, m) default: return bindStruct(bindType, query, arg, m) } } // NamedQuery binds a named query and then runs Query on the result using the // provided Ext (sqlx.Tx, sqlx.Db). 
It works with both structs and with // map[string]interface{} types. func NamedQuery(e Ext, query string, arg interface{}) (*Rows, error) { q, args, err := bindNamedMapper(BindType(e.DriverName()), query, arg, mapperFor(e)) if err != nil { return nil, err } return e.Queryx(q, args...) } // NamedExec uses BindStruct to get a query executable by the driver and // then runs Exec on the result. Returns an error from the binding // or the query execution itself. func NamedExec(e Ext, query string, arg interface{}) (sql.Result, error) { q, args, err := bindNamedMapper(BindType(e.DriverName()), query, arg, mapperFor(e)) if err != nil { return nil, err } return e.Exec(q, args...) } sqlx-1.3.5/named_context.go000066400000000000000000000106461422653721300157020ustar00rootroot00000000000000// +build go1.8 package sqlx import ( "context" "database/sql" ) // A union interface of contextPreparer and binder, required to be able to // prepare named statements with context (as the bindtype must be determined). type namedPreparerContext interface { PreparerContext binder } func prepareNamedContext(ctx context.Context, p namedPreparerContext, query string) (*NamedStmt, error) { bindType := BindType(p.DriverName()) q, args, err := compileNamedQuery([]byte(query), bindType) if err != nil { return nil, err } stmt, err := PreparexContext(ctx, p, q) if err != nil { return nil, err } return &NamedStmt{ QueryString: q, Params: args, Stmt: stmt, }, nil } // ExecContext executes a named statement using the struct passed. // Any named placeholder parameters are replaced with fields from arg. func (n *NamedStmt) ExecContext(ctx context.Context, arg interface{}) (sql.Result, error) { args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper) if err != nil { return *new(sql.Result), err } return n.Stmt.ExecContext(ctx, args...) } // QueryContext executes a named statement using the struct argument, returning rows. // Any named placeholder parameters are replaced with fields from arg. 
func (n *NamedStmt) QueryContext(ctx context.Context, arg interface{}) (*sql.Rows, error) { args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper) if err != nil { return nil, err } return n.Stmt.QueryContext(ctx, args...) } // QueryRowContext executes a named statement against the database. Because sqlx cannot // create a *sql.Row with an error condition pre-set for binding errors, sqlx // returns a *sqlx.Row instead. // Any named placeholder parameters are replaced with fields from arg. func (n *NamedStmt) QueryRowContext(ctx context.Context, arg interface{}) *Row { args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper) if err != nil { return &Row{err: err} } return n.Stmt.QueryRowxContext(ctx, args...) } // MustExecContext execs a NamedStmt, panicing on error // Any named placeholder parameters are replaced with fields from arg. func (n *NamedStmt) MustExecContext(ctx context.Context, arg interface{}) sql.Result { res, err := n.ExecContext(ctx, arg) if err != nil { panic(err) } return res } // QueryxContext using this NamedStmt // Any named placeholder parameters are replaced with fields from arg. func (n *NamedStmt) QueryxContext(ctx context.Context, arg interface{}) (*Rows, error) { r, err := n.QueryContext(ctx, arg) if err != nil { return nil, err } return &Rows{Rows: r, Mapper: n.Stmt.Mapper, unsafe: isUnsafe(n)}, err } // QueryRowxContext this NamedStmt. Because of limitations with QueryRow, this is // an alias for QueryRow. // Any named placeholder parameters are replaced with fields from arg. func (n *NamedStmt) QueryRowxContext(ctx context.Context, arg interface{}) *Row { return n.QueryRowContext(ctx, arg) } // SelectContext using this NamedStmt // Any named placeholder parameters are replaced with fields from arg. 
func (n *NamedStmt) SelectContext(ctx context.Context, dest interface{}, arg interface{}) error { rows, err := n.QueryxContext(ctx, arg) if err != nil { return err } // if something happens here, we want to make sure the rows are Closed defer rows.Close() return scanAll(rows, dest, false) } // GetContext using this NamedStmt // Any named placeholder parameters are replaced with fields from arg. func (n *NamedStmt) GetContext(ctx context.Context, dest interface{}, arg interface{}) error { r := n.QueryRowxContext(ctx, arg) return r.scanAny(dest, false) } // NamedQueryContext binds a named query and then runs Query on the result using the // provided Ext (sqlx.Tx, sqlx.Db). It works with both structs and with // map[string]interface{} types. func NamedQueryContext(ctx context.Context, e ExtContext, query string, arg interface{}) (*Rows, error) { q, args, err := bindNamedMapper(BindType(e.DriverName()), query, arg, mapperFor(e)) if err != nil { return nil, err } return e.QueryxContext(ctx, q, args...) } // NamedExecContext uses BindStruct to get a query executable by the driver and // then runs Exec on the result. Returns an error from the binding // or the query execution itself. func NamedExecContext(ctx context.Context, e ExtContext, query string, arg interface{}) (sql.Result, error) { q, args, err := bindNamedMapper(BindType(e.DriverName()), query, arg, mapperFor(e)) if err != nil { return nil, err } return e.ExecContext(ctx, q, args...) 
} sqlx-1.3.5/named_context_test.go000066400000000000000000000070731422653721300167410ustar00rootroot00000000000000// +build go1.8 package sqlx import ( "context" "database/sql" "testing" ) func TestNamedContextQueries(t *testing.T) { RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T, now string) { loadDefaultFixture(db, t) test := Test{t} var ns *NamedStmt var err error ctx := context.Background() // Check that invalid preparations fail ns, err = db.PrepareNamedContext(ctx, "SELECT * FROM person WHERE first_name=:first:name") if err == nil { t.Error("Expected an error with invalid prepared statement.") } ns, err = db.PrepareNamedContext(ctx, "invalid sql") if err == nil { t.Error("Expected an error with invalid prepared statement.") } // Check closing works as anticipated ns, err = db.PrepareNamedContext(ctx, "SELECT * FROM person WHERE first_name=:first_name") test.Error(err) err = ns.Close() test.Error(err) ns, err = db.PrepareNamedContext(ctx, ` SELECT first_name, last_name, email FROM person WHERE first_name=:first_name AND email=:email`) test.Error(err) // test Queryx w/ uses Query p := Person{FirstName: "Jason", LastName: "Moiron", Email: "jmoiron@jmoiron.net"} rows, err := ns.QueryxContext(ctx, p) test.Error(err) for rows.Next() { var p2 Person rows.StructScan(&p2) if p.FirstName != p2.FirstName { t.Errorf("got %s, expected %s", p.FirstName, p2.FirstName) } if p.LastName != p2.LastName { t.Errorf("got %s, expected %s", p.LastName, p2.LastName) } if p.Email != p2.Email { t.Errorf("got %s, expected %s", p.Email, p2.Email) } } // test Select people := make([]Person, 0, 5) err = ns.SelectContext(ctx, &people, p) test.Error(err) if len(people) != 1 { t.Errorf("got %d results, expected %d", len(people), 1) } if p.FirstName != people[0].FirstName { t.Errorf("got %s, expected %s", p.FirstName, people[0].FirstName) } if p.LastName != people[0].LastName { t.Errorf("got %s, expected %s", p.LastName, people[0].LastName) } if p.Email != people[0].Email { 
t.Errorf("got %s, expected %s", p.Email, people[0].Email) } // test Exec ns, err = db.PrepareNamedContext(ctx, ` INSERT INTO person (first_name, last_name, email) VALUES (:first_name, :last_name, :email)`) test.Error(err) js := Person{ FirstName: "Julien", LastName: "Savea", Email: "jsavea@ab.co.nz", } _, err = ns.ExecContext(ctx, js) test.Error(err) // Make sure we can pull him out again p2 := Person{} db.GetContext(ctx, &p2, db.Rebind("SELECT * FROM person WHERE email=?"), js.Email) if p2.Email != js.Email { t.Errorf("expected %s, got %s", js.Email, p2.Email) } // test Txn NamedStmts tx := db.MustBeginTx(ctx, nil) txns := tx.NamedStmtContext(ctx, ns) // We're going to add Steven in this txn sl := Person{ FirstName: "Steven", LastName: "Luatua", Email: "sluatua@ab.co.nz", } _, err = txns.ExecContext(ctx, sl) test.Error(err) // then rollback... tx.Rollback() // looking for Steven after a rollback should fail err = db.GetContext(ctx, &p2, db.Rebind("SELECT * FROM person WHERE email=?"), sl.Email) if err != sql.ErrNoRows { t.Errorf("expected no rows error, got %v", err) } // now do the same, but commit tx = db.MustBeginTx(ctx, nil) txns = tx.NamedStmtContext(ctx, ns) _, err = txns.ExecContext(ctx, sl) test.Error(err) tx.Commit() // looking for Steven after a Commit should succeed err = db.GetContext(ctx, &p2, db.Rebind("SELECT * FROM person WHERE email=?"), sl.Email) test.Error(err) if p2.Email != sl.Email { t.Errorf("expected %s, got %s", sl.Email, p2.Email) } }) } sqlx-1.3.5/named_test.go000066400000000000000000000311321422653721300151660ustar00rootroot00000000000000package sqlx import ( "database/sql" "fmt" "testing" ) func TestCompileQuery(t *testing.T) { table := []struct { Q, R, D, T, N string V []string }{ // basic test for named parameters, invalid char ',' terminating { Q: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last)`, R: `INSERT INTO foo (a,b,c,d) VALUES (?, ?, ?, ?)`, D: `INSERT INTO foo (a,b,c,d) VALUES ($1, $2, $3, $4)`, T: `INSERT INTO 
foo (a,b,c,d) VALUES (@p1, @p2, @p3, @p4)`, N: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last)`, V: []string{"name", "age", "first", "last"}, }, // This query tests a named parameter ending the string as well as numbers { Q: `SELECT * FROM a WHERE first_name=:name1 AND last_name=:name2`, R: `SELECT * FROM a WHERE first_name=? AND last_name=?`, D: `SELECT * FROM a WHERE first_name=$1 AND last_name=$2`, T: `SELECT * FROM a WHERE first_name=@p1 AND last_name=@p2`, N: `SELECT * FROM a WHERE first_name=:name1 AND last_name=:name2`, V: []string{"name1", "name2"}, }, { Q: `SELECT "::foo" FROM a WHERE first_name=:name1 AND last_name=:name2`, R: `SELECT ":foo" FROM a WHERE first_name=? AND last_name=?`, D: `SELECT ":foo" FROM a WHERE first_name=$1 AND last_name=$2`, T: `SELECT ":foo" FROM a WHERE first_name=@p1 AND last_name=@p2`, N: `SELECT ":foo" FROM a WHERE first_name=:name1 AND last_name=:name2`, V: []string{"name1", "name2"}, }, { Q: `SELECT 'a::b::c' || first_name, '::::ABC::_::' FROM person WHERE first_name=:first_name AND last_name=:last_name`, R: `SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=? AND last_name=?`, D: `SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=$1 AND last_name=$2`, T: `SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=@p1 AND last_name=@p2`, N: `SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=:first_name AND last_name=:last_name`, V: []string{"first_name", "last_name"}, }, { Q: `SELECT @name := "name", :age, :first, :last`, R: `SELECT @name := "name", ?, ?, ?`, D: `SELECT @name := "name", $1, $2, $3`, N: `SELECT @name := "name", :age, :first, :last`, T: `SELECT @name := "name", @p1, @p2, @p3`, V: []string{"age", "first", "last"}, }, /* This unicode awareness test sadly fails, because of our byte-wise worldview. 
* We could certainly iterate by Rune instead, though it's a great deal slower, * it's probably the RightWay(tm) { Q: `INSERT INTO foo (a,b,c,d) VALUES (:あ, :b, :キコ, :名前)`, R: `INSERT INTO foo (a,b,c,d) VALUES (?, ?, ?, ?)`, D: `INSERT INTO foo (a,b,c,d) VALUES ($1, $2, $3, $4)`, N: []string{"name", "age", "first", "last"}, }, */ } for _, test := range table { qr, names, err := compileNamedQuery([]byte(test.Q), QUESTION) if err != nil { t.Error(err) } if qr != test.R { t.Errorf("expected %s, got %s", test.R, qr) } if len(names) != len(test.V) { t.Errorf("expected %#v, got %#v", test.V, names) } else { for i, name := range names { if name != test.V[i] { t.Errorf("expected %dth name to be %s, got %s", i+1, test.V[i], name) } } } qd, _, _ := compileNamedQuery([]byte(test.Q), DOLLAR) if qd != test.D { t.Errorf("\nexpected: `%s`\ngot: `%s`", test.D, qd) } qt, _, _ := compileNamedQuery([]byte(test.Q), AT) if qt != test.T { t.Errorf("\nexpected: `%s`\ngot: `%s`", test.T, qt) } qq, _, _ := compileNamedQuery([]byte(test.Q), NAMED) if qq != test.N { t.Errorf("\nexpected: `%s`\ngot: `%s`\n(len: %d vs %d)", test.N, qq, len(test.N), len(qq)) } } } type Test struct { t *testing.T } func (t Test) Error(err error, msg ...interface{}) { t.t.Helper() if err != nil { if len(msg) == 0 { t.t.Error(err) } else { t.t.Error(msg...) } } } func (t Test) Errorf(err error, format string, args ...interface{}) { t.t.Helper() if err != nil { t.t.Errorf(format, args...) 
} } func TestEscapedColons(t *testing.T) { t.Skip("not sure it is possible to support this in general case without an SQL parser") var qs = `SELECT * FROM testtable WHERE timeposted BETWEEN (now() AT TIME ZONE 'utc') AND (now() AT TIME ZONE 'utc') - interval '01:30:00') AND name = '\'this is a test\'' and id = :id` _, _, err := compileNamedQuery([]byte(qs), DOLLAR) if err != nil { t.Error("Didn't handle colons correctly when inside a string") } } func TestNamedQueries(t *testing.T) { RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T, now string) { loadDefaultFixture(db, t) test := Test{t} var ns *NamedStmt var err error // Check that invalid preparations fail ns, err = db.PrepareNamed("SELECT * FROM person WHERE first_name=:first:name") if err == nil { t.Error("Expected an error with invalid prepared statement.") } ns, err = db.PrepareNamed("invalid sql") if err == nil { t.Error("Expected an error with invalid prepared statement.") } // Check closing works as anticipated ns, err = db.PrepareNamed("SELECT * FROM person WHERE first_name=:first_name") test.Error(err) err = ns.Close() test.Error(err) ns, err = db.PrepareNamed(` SELECT first_name, last_name, email FROM person WHERE first_name=:first_name AND email=:email`) test.Error(err) // test Queryx w/ uses Query p := Person{FirstName: "Jason", LastName: "Moiron", Email: "jmoiron@jmoiron.net"} rows, err := ns.Queryx(p) test.Error(err) for rows.Next() { var p2 Person rows.StructScan(&p2) if p.FirstName != p2.FirstName { t.Errorf("got %s, expected %s", p.FirstName, p2.FirstName) } if p.LastName != p2.LastName { t.Errorf("got %s, expected %s", p.LastName, p2.LastName) } if p.Email != p2.Email { t.Errorf("got %s, expected %s", p.Email, p2.Email) } } // test Select people := make([]Person, 0, 5) err = ns.Select(&people, p) test.Error(err) if len(people) != 1 { t.Errorf("got %d results, expected %d", len(people), 1) } if p.FirstName != people[0].FirstName { t.Errorf("got %s, expected %s", p.FirstName, 
people[0].FirstName) } if p.LastName != people[0].LastName { t.Errorf("got %s, expected %s", p.LastName, people[0].LastName) } if p.Email != people[0].Email { t.Errorf("got %s, expected %s", p.Email, people[0].Email) } // test struct batch inserts sls := []Person{ {FirstName: "Ardie", LastName: "Savea", Email: "asavea@ab.co.nz"}, {FirstName: "Sonny Bill", LastName: "Williams", Email: "sbw@ab.co.nz"}, {FirstName: "Ngani", LastName: "Laumape", Email: "nlaumape@ab.co.nz"}, } insert := fmt.Sprintf( "INSERT INTO person (first_name, last_name, email, added_at) VALUES (:first_name, :last_name, :email, %v)\n", now, ) _, err = db.NamedExec(insert, sls) test.Error(err) // test map batch inserts slsMap := []map[string]interface{}{ {"first_name": "Ardie", "last_name": "Savea", "email": "asavea@ab.co.nz"}, {"first_name": "Sonny Bill", "last_name": "Williams", "email": "sbw@ab.co.nz"}, {"first_name": "Ngani", "last_name": "Laumape", "email": "nlaumape@ab.co.nz"}, } _, err = db.NamedExec(`INSERT INTO person (first_name, last_name, email) VALUES (:first_name, :last_name, :email) ;--`, slsMap) test.Error(err) type A map[string]interface{} typedMap := []A{ {"first_name": "Ardie", "last_name": "Savea", "email": "asavea@ab.co.nz"}, {"first_name": "Sonny Bill", "last_name": "Williams", "email": "sbw@ab.co.nz"}, {"first_name": "Ngani", "last_name": "Laumape", "email": "nlaumape@ab.co.nz"}, } _, err = db.NamedExec(`INSERT INTO person (first_name, last_name, email) VALUES (:first_name, :last_name, :email) ;--`, typedMap) test.Error(err) for _, p := range sls { dest := Person{} err = db.Get(&dest, db.Rebind("SELECT * FROM person WHERE email=?"), p.Email) test.Error(err) if dest.Email != p.Email { t.Errorf("expected %s, got %s", p.Email, dest.Email) } } // test Exec ns, err = db.PrepareNamed(` INSERT INTO person (first_name, last_name, email) VALUES (:first_name, :last_name, :email)`) test.Error(err) js := Person{ FirstName: "Julien", LastName: "Savea", Email: "jsavea@ab.co.nz", } _, err = 
ns.Exec(js) test.Error(err) // Make sure we can pull him out again p2 := Person{} db.Get(&p2, db.Rebind("SELECT * FROM person WHERE email=?"), js.Email) if p2.Email != js.Email { t.Errorf("expected %s, got %s", js.Email, p2.Email) } // test Txn NamedStmts tx := db.MustBegin() txns := tx.NamedStmt(ns) // We're going to add Steven in this txn sl := Person{ FirstName: "Steven", LastName: "Luatua", Email: "sluatua@ab.co.nz", } _, err = txns.Exec(sl) test.Error(err) // then rollback... tx.Rollback() // looking for Steven after a rollback should fail err = db.Get(&p2, db.Rebind("SELECT * FROM person WHERE email=?"), sl.Email) if err != sql.ErrNoRows { t.Errorf("expected no rows error, got %v", err) } // now do the same, but commit tx = db.MustBegin() txns = tx.NamedStmt(ns) _, err = txns.Exec(sl) test.Error(err) tx.Commit() // looking for Steven after a Commit should succeed err = db.Get(&p2, db.Rebind("SELECT * FROM person WHERE email=?"), sl.Email) test.Error(err) if p2.Email != sl.Email { t.Errorf("expected %s, got %s", sl.Email, p2.Email) } }) } func TestFixBounds(t *testing.T) { table := []struct { name, query, expect string loop int }{ { name: `named syntax`, query: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last)`, expect: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last),(:name, :age, :first, :last)`, loop: 2, }, { name: `mysql syntax`, query: `INSERT INTO foo (a,b,c,d) VALUES (?, ?, ?, ?)`, expect: `INSERT INTO foo (a,b,c,d) VALUES (?, ?, ?, ?),(?, ?, ?, ?)`, loop: 2, }, { name: `named syntax w/ trailer`, query: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last) ;--`, expect: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last),(:name, :age, :first, :last) ;--`, loop: 2, }, { name: `mysql syntax w/ trailer`, query: `INSERT INTO foo (a,b,c,d) VALUES (?, ?, ?, ?) ;--`, expect: `INSERT INTO foo (a,b,c,d) VALUES (?, ?, ?, ?),(?, ?, ?, ?) 
;--`, loop: 2, }, { name: `not found test`, query: `INSERT INTO foo (a,b,c,d) (:name, :age, :first, :last)`, expect: `INSERT INTO foo (a,b,c,d) (:name, :age, :first, :last)`, loop: 2, }, { name: `found twice test`, query: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last) VALUES (:name, :age, :first, :last)`, expect: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last),(:name, :age, :first, :last) VALUES (:name, :age, :first, :last)`, loop: 2, }, { name: `nospace`, query: `INSERT INTO foo (a,b) VALUES(:a, :b)`, expect: `INSERT INTO foo (a,b) VALUES(:a, :b),(:a, :b)`, loop: 2, }, { name: `lowercase`, query: `INSERT INTO foo (a,b) values(:a, :b)`, expect: `INSERT INTO foo (a,b) values(:a, :b),(:a, :b)`, loop: 2, }, { name: `on duplicate key using VALUES`, query: `INSERT INTO foo (a,b) VALUES (:a, :b) ON DUPLICATE KEY UPDATE a=VALUES(a)`, expect: `INSERT INTO foo (a,b) VALUES (:a, :b),(:a, :b) ON DUPLICATE KEY UPDATE a=VALUES(a)`, loop: 2, }, { name: `single column`, query: `INSERT INTO foo (a) VALUES (:a)`, expect: `INSERT INTO foo (a) VALUES (:a),(:a)`, loop: 2, }, { name: `call now`, query: `INSERT INTO foo (a, b) VALUES (:a, NOW())`, expect: `INSERT INTO foo (a, b) VALUES (:a, NOW()),(:a, NOW())`, loop: 2, }, { name: `two level depth function call`, query: `INSERT INTO foo (a, b) VALUES (:a, YEAR(NOW()))`, expect: `INSERT INTO foo (a, b) VALUES (:a, YEAR(NOW())),(:a, YEAR(NOW()))`, loop: 2, }, { name: `missing closing bracket`, query: `INSERT INTO foo (a, b) VALUES (:a, YEAR(NOW())`, expect: `INSERT INTO foo (a, b) VALUES (:a, YEAR(NOW())`, loop: 2, }, { name: `table with "values" at the end`, query: `INSERT INTO table_values (a, b) VALUES (:a, :b)`, expect: `INSERT INTO table_values (a, b) VALUES (:a, :b),(:a, :b)`, loop: 2, }, { name: `multiline indented query`, query: `INSERT INTO foo ( a, b, c, d ) VALUES ( :name, :age, :first, :last )`, expect: `INSERT INTO foo ( a, b, c, d ) VALUES ( :name, :age, :first, :last ),( :name, :age, :first, 
:last )`, loop: 2, }, } for _, tc := range table { t.Run(tc.name, func(t *testing.T) { res := fixBound(tc.query, tc.loop) if res != tc.expect { t.Errorf("mismatched results") } }) } } sqlx-1.3.5/reflectx/000077500000000000000000000000001422653721300143305ustar00rootroot00000000000000sqlx-1.3.5/reflectx/README.md000066400000000000000000000012571422653721300156140ustar00rootroot00000000000000# reflectx The sqlx package has special reflect needs. In particular, it needs to: * be able to map a name to a field * understand embedded structs * understand mapping names to fields by a particular tag * user specified name -> field mapping functions These behaviors mimic the behaviors by the standard library marshallers and also the behavior of standard Go accessors. The first two are amply taken care of by `Reflect.Value.FieldByName`, and the third is addressed by `Reflect.Value.FieldByNameFunc`, but these don't quite understand struct tags in the ways that are vital to most marshallers, and they are slow. This reflectx package extends reflect to achieve these goals. sqlx-1.3.5/reflectx/reflect.go000066400000000000000000000304621422653721300163100ustar00rootroot00000000000000// Package reflectx implements extensions to the standard reflect lib suitable // for implementing marshalling and unmarshalling packages. The main Mapper type // allows for Go-compatible named attribute access, including accessing embedded // struct attributes and the ability to use functions and struct tags to // customize field names. // package reflectx import ( "reflect" "runtime" "strings" "sync" ) // A FieldInfo is metadata for a struct field. type FieldInfo struct { Index []int Path string Field reflect.StructField Zero reflect.Value Name string Options map[string]string Embedded bool Children []*FieldInfo Parent *FieldInfo } // A StructMap is an index of field metadata for a struct. 
type StructMap struct { Tree *FieldInfo Index []*FieldInfo Paths map[string]*FieldInfo Names map[string]*FieldInfo } // GetByPath returns a *FieldInfo for a given string path. func (f StructMap) GetByPath(path string) *FieldInfo { return f.Paths[path] } // GetByTraversal returns a *FieldInfo for a given integer path. It is // analogous to reflect.FieldByIndex, but using the cached traversal // rather than re-executing the reflect machinery each time. func (f StructMap) GetByTraversal(index []int) *FieldInfo { if len(index) == 0 { return nil } tree := f.Tree for _, i := range index { if i >= len(tree.Children) || tree.Children[i] == nil { return nil } tree = tree.Children[i] } return tree } // Mapper is a general purpose mapper of names to struct fields. A Mapper // behaves like most marshallers in the standard library, obeying a field tag // for name mapping but also providing a basic transform function. type Mapper struct { cache map[reflect.Type]*StructMap tagName string tagMapFunc func(string) string mapFunc func(string) string mutex sync.Mutex } // NewMapper returns a new mapper using the tagName as its struct field tag. // If tagName is the empty string, it is ignored. func NewMapper(tagName string) *Mapper { return &Mapper{ cache: make(map[reflect.Type]*StructMap), tagName: tagName, } } // NewMapperTagFunc returns a new mapper which contains a mapper for field names // AND a mapper for tag values. This is useful for tags like json which can // have values like "name,omitempty". func NewMapperTagFunc(tagName string, mapFunc, tagMapFunc func(string) string) *Mapper { return &Mapper{ cache: make(map[reflect.Type]*StructMap), tagName: tagName, mapFunc: mapFunc, tagMapFunc: tagMapFunc, } } // NewMapperFunc returns a new mapper which optionally obeys a field tag and // a struct field name mapper func given by f. 
Tags will take precedence, but // for any other field, the mapped name will be f(field.Name) func NewMapperFunc(tagName string, f func(string) string) *Mapper { return &Mapper{ cache: make(map[reflect.Type]*StructMap), tagName: tagName, mapFunc: f, } } // TypeMap returns a mapping of field strings to int slices representing // the traversal down the struct to reach the field. func (m *Mapper) TypeMap(t reflect.Type) *StructMap { m.mutex.Lock() mapping, ok := m.cache[t] if !ok { mapping = getMapping(t, m.tagName, m.mapFunc, m.tagMapFunc) m.cache[t] = mapping } m.mutex.Unlock() return mapping } // FieldMap returns the mapper's mapping of field names to reflect values. Panics // if v's Kind is not Struct, or v is not Indirectable to a struct kind. func (m *Mapper) FieldMap(v reflect.Value) map[string]reflect.Value { v = reflect.Indirect(v) mustBe(v, reflect.Struct) r := map[string]reflect.Value{} tm := m.TypeMap(v.Type()) for tagName, fi := range tm.Names { r[tagName] = FieldByIndexes(v, fi.Index) } return r } // FieldByName returns a field by its mapped name as a reflect.Value. // Panics if v's Kind is not Struct or v is not Indirectable to a struct Kind. // Returns zero Value if the name is not found. func (m *Mapper) FieldByName(v reflect.Value, name string) reflect.Value { v = reflect.Indirect(v) mustBe(v, reflect.Struct) tm := m.TypeMap(v.Type()) fi, ok := tm.Names[name] if !ok { return v } return FieldByIndexes(v, fi.Index) } // FieldsByName returns a slice of values corresponding to the slice of names // for the value. Panics if v's Kind is not Struct or v is not Indirectable // to a struct Kind. Returns zero Value for each name not found. 
func (m *Mapper) FieldsByName(v reflect.Value, names []string) []reflect.Value { v = reflect.Indirect(v) mustBe(v, reflect.Struct) tm := m.TypeMap(v.Type()) vals := make([]reflect.Value, 0, len(names)) for _, name := range names { fi, ok := tm.Names[name] if !ok { vals = append(vals, *new(reflect.Value)) } else { vals = append(vals, FieldByIndexes(v, fi.Index)) } } return vals } // TraversalsByName returns a slice of int slices which represent the struct // traversals for each mapped name. Panics if t is not a struct or Indirectable // to a struct. Returns empty int slice for each name not found. func (m *Mapper) TraversalsByName(t reflect.Type, names []string) [][]int { r := make([][]int, 0, len(names)) m.TraversalsByNameFunc(t, names, func(_ int, i []int) error { if i == nil { r = append(r, []int{}) } else { r = append(r, i) } return nil }) return r } // TraversalsByNameFunc traverses the mapped names and calls fn with the index of // each name and the struct traversal represented by that name. Panics if t is not // a struct or Indirectable to a struct. Returns the first error returned by fn or nil. func (m *Mapper) TraversalsByNameFunc(t reflect.Type, names []string, fn func(int, []int) error) error { t = Deref(t) mustBe(t, reflect.Struct) tm := m.TypeMap(t) for i, name := range names { fi, ok := tm.Names[name] if !ok { if err := fn(i, nil); err != nil { return err } } else { if err := fn(i, fi.Index); err != nil { return err } } } return nil } // FieldByIndexes returns a value for the field given by the struct traversal // for the given value. 
func FieldByIndexes(v reflect.Value, indexes []int) reflect.Value { for _, i := range indexes { v = reflect.Indirect(v).Field(i) // if this is a pointer and it's nil, allocate a new value and set it if v.Kind() == reflect.Ptr && v.IsNil() { alloc := reflect.New(Deref(v.Type())) v.Set(alloc) } if v.Kind() == reflect.Map && v.IsNil() { v.Set(reflect.MakeMap(v.Type())) } } return v } // FieldByIndexesReadOnly returns a value for a particular struct traversal, // but is not concerned with allocating nil pointers because the value is // going to be used for reading and not setting. func FieldByIndexesReadOnly(v reflect.Value, indexes []int) reflect.Value { for _, i := range indexes { v = reflect.Indirect(v).Field(i) } return v } // Deref is Indirect for reflect.Types func Deref(t reflect.Type) reflect.Type { if t.Kind() == reflect.Ptr { t = t.Elem() } return t } // -- helpers & utilities -- type kinder interface { Kind() reflect.Kind } // mustBe checks a value against a kind, panicing with a reflect.ValueError // if the kind isn't that which is required. func mustBe(v kinder, expected reflect.Kind) { if k := v.Kind(); k != expected { panic(&reflect.ValueError{Method: methodName(), Kind: k}) } } // methodName returns the caller of the function calling methodName func methodName() string { pc, _, _, _ := runtime.Caller(2) f := runtime.FuncForPC(pc) if f == nil { return "unknown method" } return f.Name() } type typeQueue struct { t reflect.Type fi *FieldInfo pp string // Parent path } // A copying append that creates a new slice each time. func apnd(is []int, i int) []int { x := make([]int, len(is)+1) copy(x, is) x[len(x)-1] = i return x } type mapf func(string) string // parseName parses the tag and the target name for the given field using // the tagName (eg 'json' for `json:"foo"` tags), mapFunc for mapping the // field's name to a target name, and tagMapFunc for mapping the tag to // a target name. 
func parseName(field reflect.StructField, tagName string, mapFunc, tagMapFunc mapf) (tag, fieldName string) { // first, set the fieldName to the field's name fieldName = field.Name // if a mapFunc is set, use that to override the fieldName if mapFunc != nil { fieldName = mapFunc(fieldName) } // if there's no tag to look for, return the field name if tagName == "" { return "", fieldName } // if this tag is not set using the normal convention in the tag, // then return the fieldname.. this check is done because according // to the reflect documentation: // If the tag does not have the conventional format, // the value returned by Get is unspecified. // which doesn't sound great. if !strings.Contains(string(field.Tag), tagName+":") { return "", fieldName } // at this point we're fairly sure that we have a tag, so lets pull it out tag = field.Tag.Get(tagName) // if we have a mapper function, call it on the whole tag // XXX: this is a change from the old version, which pulled out the name // before the tagMapFunc could be run, but I think this is the right way if tagMapFunc != nil { tag = tagMapFunc(tag) } // finally, split the options from the name parts := strings.Split(tag, ",") fieldName = parts[0] return tag, fieldName } // parseOptions parses options out of a tag string, skipping the name func parseOptions(tag string) map[string]string { parts := strings.Split(tag, ",") options := make(map[string]string, len(parts)) if len(parts) > 1 { for _, opt := range parts[1:] { // short circuit potentially expensive split op if strings.Contains(opt, "=") { kv := strings.Split(opt, "=") options[kv[0]] = kv[1] continue } options[opt] = "" } } return options } // getMapping returns a mapping for the t type, using the tagName, mapFunc and // tagMapFunc to determine the canonical names of fields. 
func getMapping(t reflect.Type, tagName string, mapFunc, tagMapFunc mapf) *StructMap { m := []*FieldInfo{} root := &FieldInfo{} queue := []typeQueue{} queue = append(queue, typeQueue{Deref(t), root, ""}) QueueLoop: for len(queue) != 0 { // pop the first item off of the queue tq := queue[0] queue = queue[1:] // ignore recursive field for p := tq.fi.Parent; p != nil; p = p.Parent { if tq.fi.Field.Type == p.Field.Type { continue QueueLoop } } nChildren := 0 if tq.t.Kind() == reflect.Struct { nChildren = tq.t.NumField() } tq.fi.Children = make([]*FieldInfo, nChildren) // iterate through all of its fields for fieldPos := 0; fieldPos < nChildren; fieldPos++ { f := tq.t.Field(fieldPos) // parse the tag and the target name using the mapping options for this field tag, name := parseName(f, tagName, mapFunc, tagMapFunc) // if the name is "-", disabled via a tag, skip it if name == "-" { continue } fi := FieldInfo{ Field: f, Name: name, Zero: reflect.New(f.Type).Elem(), Options: parseOptions(tag), } // if the path is empty this path is just the name if tq.pp == "" { fi.Path = fi.Name } else { fi.Path = tq.pp + "." 
+ fi.Name } // skip unexported fields if len(f.PkgPath) != 0 && !f.Anonymous { continue } // bfs search of anonymous embedded structs if f.Anonymous { pp := tq.pp if tag != "" { pp = fi.Path } fi.Embedded = true fi.Index = apnd(tq.fi.Index, fieldPos) nChildren := 0 ft := Deref(f.Type) if ft.Kind() == reflect.Struct { nChildren = ft.NumField() } fi.Children = make([]*FieldInfo, nChildren) queue = append(queue, typeQueue{Deref(f.Type), &fi, pp}) } else if fi.Zero.Kind() == reflect.Struct || (fi.Zero.Kind() == reflect.Ptr && fi.Zero.Type().Elem().Kind() == reflect.Struct) { fi.Index = apnd(tq.fi.Index, fieldPos) fi.Children = make([]*FieldInfo, Deref(f.Type).NumField()) queue = append(queue, typeQueue{Deref(f.Type), &fi, fi.Path}) } fi.Index = apnd(tq.fi.Index, fieldPos) fi.Parent = tq.fi tq.fi.Children[fieldPos] = &fi m = append(m, &fi) } } flds := &StructMap{Index: m, Tree: root, Paths: map[string]*FieldInfo{}, Names: map[string]*FieldInfo{}} for _, fi := range flds.Index { // check if nothing has already been pushed with the same path // sometimes you can choose to override a type using embedded struct fld, ok := flds.Paths[fi.Path] if !ok || fld.Embedded { flds.Paths[fi.Path] = fi if fi.Name != "" && !fi.Embedded { flds.Names[fi.Path] = fi } } } return flds } sqlx-1.3.5/reflectx/reflect_test.go000066400000000000000000000516411422653721300173510ustar00rootroot00000000000000package reflectx import ( "reflect" "strings" "testing" ) func ival(v reflect.Value) int { return v.Interface().(int) } func TestBasic(t *testing.T) { type Foo struct { A int B int C int } f := Foo{1, 2, 3} fv := reflect.ValueOf(f) m := NewMapperFunc("", func(s string) string { return s }) v := m.FieldByName(fv, "A") if ival(v) != f.A { t.Errorf("Expecting %d, got %d", ival(v), f.A) } v = m.FieldByName(fv, "B") if ival(v) != f.B { t.Errorf("Expecting %d, got %d", f.B, ival(v)) } v = m.FieldByName(fv, "C") if ival(v) != f.C { t.Errorf("Expecting %d, got %d", f.C, ival(v)) } } func 
TestBasicEmbedded(t *testing.T) { type Foo struct { A int } type Bar struct { Foo // `db:""` is implied for an embedded struct B int C int `db:"-"` } type Baz struct { A int Bar `db:"Bar"` } m := NewMapperFunc("db", func(s string) string { return s }) z := Baz{} z.A = 1 z.B = 2 z.C = 4 z.Bar.Foo.A = 3 zv := reflect.ValueOf(z) fields := m.TypeMap(reflect.TypeOf(z)) if len(fields.Index) != 5 { t.Errorf("Expecting 5 fields") } // for _, fi := range fields.Index { // log.Println(fi) // } v := m.FieldByName(zv, "A") if ival(v) != z.A { t.Errorf("Expecting %d, got %d", z.A, ival(v)) } v = m.FieldByName(zv, "Bar.B") if ival(v) != z.Bar.B { t.Errorf("Expecting %d, got %d", z.Bar.B, ival(v)) } v = m.FieldByName(zv, "Bar.A") if ival(v) != z.Bar.Foo.A { t.Errorf("Expecting %d, got %d", z.Bar.Foo.A, ival(v)) } v = m.FieldByName(zv, "Bar.C") if _, ok := v.Interface().(int); ok { t.Errorf("Expecting Bar.C to not exist") } fi := fields.GetByPath("Bar.C") if fi != nil { t.Errorf("Bar.C should not exist") } } func TestEmbeddedSimple(t *testing.T) { type UUID [16]byte type MyID struct { UUID } type Item struct { ID MyID } z := Item{} m := NewMapper("db") m.TypeMap(reflect.TypeOf(z)) } func TestBasicEmbeddedWithTags(t *testing.T) { type Foo struct { A int `db:"a"` } type Bar struct { Foo // `db:""` is implied for an embedded struct B int `db:"b"` } type Baz struct { A int `db:"a"` Bar // `db:""` is implied for an embedded struct } m := NewMapper("db") z := Baz{} z.A = 1 z.B = 2 z.Bar.Foo.A = 3 zv := reflect.ValueOf(z) fields := m.TypeMap(reflect.TypeOf(z)) if len(fields.Index) != 5 { t.Errorf("Expecting 5 fields") } // for _, fi := range fields.index { // log.Println(fi) // } v := m.FieldByName(zv, "a") if ival(v) != z.A { // the dominant field t.Errorf("Expecting %d, got %d", z.A, ival(v)) } v = m.FieldByName(zv, "b") if ival(v) != z.B { t.Errorf("Expecting %d, got %d", z.B, ival(v)) } } func TestBasicEmbeddedWithSameName(t *testing.T) { type Foo struct { A int `db:"a"` Foo int 
`db:"Foo"` // Same name as the embedded struct } type FooExt struct { Foo B int `db:"b"` } m := NewMapper("db") z := FooExt{} z.A = 1 z.B = 2 z.Foo.Foo = 3 zv := reflect.ValueOf(z) fields := m.TypeMap(reflect.TypeOf(z)) if len(fields.Index) != 4 { t.Errorf("Expecting 3 fields, found %d", len(fields.Index)) } v := m.FieldByName(zv, "a") if ival(v) != z.A { // the dominant field t.Errorf("Expecting %d, got %d", z.A, ival(v)) } v = m.FieldByName(zv, "b") if ival(v) != z.B { t.Errorf("Expecting %d, got %d", z.B, ival(v)) } v = m.FieldByName(zv, "Foo") if ival(v) != z.Foo.Foo { t.Errorf("Expecting %d, got %d", z.Foo.Foo, ival(v)) } } func TestFlatTags(t *testing.T) { m := NewMapper("db") type Asset struct { Title string `db:"title"` } type Post struct { Author string `db:"author,required"` Asset Asset `db:""` } // Post columns: (author title) post := Post{Author: "Joe", Asset: Asset{Title: "Hello"}} pv := reflect.ValueOf(post) v := m.FieldByName(pv, "author") if v.Interface().(string) != post.Author { t.Errorf("Expecting %s, got %s", post.Author, v.Interface().(string)) } v = m.FieldByName(pv, "title") if v.Interface().(string) != post.Asset.Title { t.Errorf("Expecting %s, got %s", post.Asset.Title, v.Interface().(string)) } } func TestNestedStruct(t *testing.T) { m := NewMapper("db") type Details struct { Active bool `db:"active"` } type Asset struct { Title string `db:"title"` Details Details `db:"details"` } type Post struct { Author string `db:"author,required"` Asset `db:"asset"` } // Post columns: (author asset.title asset.details.active) post := Post{ Author: "Joe", Asset: Asset{Title: "Hello", Details: Details{Active: true}}, } pv := reflect.ValueOf(post) v := m.FieldByName(pv, "author") if v.Interface().(string) != post.Author { t.Errorf("Expecting %s, got %s", post.Author, v.Interface().(string)) } v = m.FieldByName(pv, "title") if _, ok := v.Interface().(string); ok { t.Errorf("Expecting field to not exist") } v = m.FieldByName(pv, "asset.title") if 
v.Interface().(string) != post.Asset.Title { t.Errorf("Expecting %s, got %s", post.Asset.Title, v.Interface().(string)) } v = m.FieldByName(pv, "asset.details.active") if v.Interface().(bool) != post.Asset.Details.Active { t.Errorf("Expecting %v, got %v", post.Asset.Details.Active, v.Interface().(bool)) } } func TestInlineStruct(t *testing.T) { m := NewMapperTagFunc("db", strings.ToLower, nil) type Employee struct { Name string ID int } type Boss Employee type person struct { Employee `db:"employee"` Boss `db:"boss"` } // employees columns: (employee.name employee.id boss.name boss.id) em := person{Employee: Employee{Name: "Joe", ID: 2}, Boss: Boss{Name: "Dick", ID: 1}} ev := reflect.ValueOf(em) fields := m.TypeMap(reflect.TypeOf(em)) if len(fields.Index) != 6 { t.Errorf("Expecting 6 fields") } v := m.FieldByName(ev, "employee.name") if v.Interface().(string) != em.Employee.Name { t.Errorf("Expecting %s, got %s", em.Employee.Name, v.Interface().(string)) } v = m.FieldByName(ev, "boss.id") if ival(v) != em.Boss.ID { t.Errorf("Expecting %v, got %v", em.Boss.ID, ival(v)) } } func TestRecursiveStruct(t *testing.T) { type Person struct { Parent *Person } m := NewMapperFunc("db", strings.ToLower) var p *Person m.TypeMap(reflect.TypeOf(p)) } func TestFieldsEmbedded(t *testing.T) { m := NewMapper("db") type Person struct { Name string `db:"name,size=64"` } type Place struct { Name string `db:"name"` } type Article struct { Title string `db:"title"` } type PP struct { Person `db:"person,required"` Place `db:",someflag"` Article `db:",required"` } // PP columns: (person.name name title) pp := PP{} pp.Person.Name = "Peter" pp.Place.Name = "Toronto" pp.Article.Title = "Best city ever" fields := m.TypeMap(reflect.TypeOf(pp)) // for i, f := range fields { // log.Println(i, f) // } ppv := reflect.ValueOf(pp) v := m.FieldByName(ppv, "person.name") if v.Interface().(string) != pp.Person.Name { t.Errorf("Expecting %s, got %s", pp.Person.Name, v.Interface().(string)) } v = 
m.FieldByName(ppv, "name") if v.Interface().(string) != pp.Place.Name { t.Errorf("Expecting %s, got %s", pp.Place.Name, v.Interface().(string)) } v = m.FieldByName(ppv, "title") if v.Interface().(string) != pp.Article.Title { t.Errorf("Expecting %s, got %s", pp.Article.Title, v.Interface().(string)) } fi := fields.GetByPath("person") if _, ok := fi.Options["required"]; !ok { t.Errorf("Expecting required option to be set") } if !fi.Embedded { t.Errorf("Expecting field to be embedded") } if len(fi.Index) != 1 || fi.Index[0] != 0 { t.Errorf("Expecting index to be [0]") } fi = fields.GetByPath("person.name") if fi == nil { t.Errorf("Expecting person.name to exist") } if fi.Path != "person.name" { t.Errorf("Expecting %s, got %s", "person.name", fi.Path) } if fi.Options["size"] != "64" { t.Errorf("Expecting %s, got %s", "64", fi.Options["size"]) } fi = fields.GetByTraversal([]int{1, 0}) if fi == nil { t.Errorf("Expecting traveral to exist") } if fi.Path != "name" { t.Errorf("Expecting %s, got %s", "name", fi.Path) } fi = fields.GetByTraversal([]int{2}) if fi == nil { t.Errorf("Expecting traversal to exist") } if _, ok := fi.Options["required"]; !ok { t.Errorf("Expecting required option to be set") } trs := m.TraversalsByName(reflect.TypeOf(pp), []string{"person.name", "name", "title"}) if !reflect.DeepEqual(trs, [][]int{{0, 0}, {1, 0}, {2, 0}}) { t.Errorf("Expecting traversal: %v", trs) } } func TestPtrFields(t *testing.T) { m := NewMapperTagFunc("db", strings.ToLower, nil) type Asset struct { Title string } type Post struct { *Asset `db:"asset"` Author string } post := &Post{Author: "Joe", Asset: &Asset{Title: "Hiyo"}} pv := reflect.ValueOf(post) fields := m.TypeMap(reflect.TypeOf(post)) if len(fields.Index) != 3 { t.Errorf("Expecting 3 fields") } v := m.FieldByName(pv, "asset.title") if v.Interface().(string) != post.Asset.Title { t.Errorf("Expecting %s, got %s", post.Asset.Title, v.Interface().(string)) } v = m.FieldByName(pv, "author") if v.Interface().(string) != 
post.Author { t.Errorf("Expecting %s, got %s", post.Author, v.Interface().(string)) } } func TestNamedPtrFields(t *testing.T) { m := NewMapperTagFunc("db", strings.ToLower, nil) type User struct { Name string } type Asset struct { Title string Owner *User `db:"owner"` } type Post struct { Author string Asset1 *Asset `db:"asset1"` Asset2 *Asset `db:"asset2"` } post := &Post{Author: "Joe", Asset1: &Asset{Title: "Hiyo", Owner: &User{"Username"}}} // Let Asset2 be nil pv := reflect.ValueOf(post) fields := m.TypeMap(reflect.TypeOf(post)) if len(fields.Index) != 9 { t.Errorf("Expecting 9 fields") } v := m.FieldByName(pv, "asset1.title") if v.Interface().(string) != post.Asset1.Title { t.Errorf("Expecting %s, got %s", post.Asset1.Title, v.Interface().(string)) } v = m.FieldByName(pv, "asset1.owner.name") if v.Interface().(string) != post.Asset1.Owner.Name { t.Errorf("Expecting %s, got %s", post.Asset1.Owner.Name, v.Interface().(string)) } v = m.FieldByName(pv, "asset2.title") if v.Interface().(string) != post.Asset2.Title { t.Errorf("Expecting %s, got %s", post.Asset2.Title, v.Interface().(string)) } v = m.FieldByName(pv, "asset2.owner.name") if v.Interface().(string) != post.Asset2.Owner.Name { t.Errorf("Expecting %s, got %s", post.Asset2.Owner.Name, v.Interface().(string)) } v = m.FieldByName(pv, "author") if v.Interface().(string) != post.Author { t.Errorf("Expecting %s, got %s", post.Author, v.Interface().(string)) } } func TestFieldMap(t *testing.T) { type Foo struct { A int B int C int } f := Foo{1, 2, 3} m := NewMapperFunc("db", strings.ToLower) fm := m.FieldMap(reflect.ValueOf(f)) if len(fm) != 3 { t.Errorf("Expecting %d keys, got %d", 3, len(fm)) } if fm["a"].Interface().(int) != 1 { t.Errorf("Expecting %d, got %d", 1, ival(fm["a"])) } if fm["b"].Interface().(int) != 2 { t.Errorf("Expecting %d, got %d", 2, ival(fm["b"])) } if fm["c"].Interface().(int) != 3 { t.Errorf("Expecting %d, got %d", 3, ival(fm["c"])) } } func TestTagNameMapping(t *testing.T) { type 
Strategy struct { StrategyID string `protobuf:"bytes,1,opt,name=strategy_id" json:"strategy_id,omitempty"` StrategyName string } m := NewMapperTagFunc("json", strings.ToUpper, func(value string) string { if strings.Contains(value, ",") { return strings.Split(value, ",")[0] } return value }) strategy := Strategy{"1", "Alpah"} mapping := m.TypeMap(reflect.TypeOf(strategy)) for _, key := range []string{"strategy_id", "STRATEGYNAME"} { if fi := mapping.GetByPath(key); fi == nil { t.Errorf("Expecting to find key %s in mapping but did not.", key) } } } func TestMapping(t *testing.T) { type Person struct { ID int Name string WearsGlasses bool `db:"wears_glasses"` } m := NewMapperFunc("db", strings.ToLower) p := Person{1, "Jason", true} mapping := m.TypeMap(reflect.TypeOf(p)) for _, key := range []string{"id", "name", "wears_glasses"} { if fi := mapping.GetByPath(key); fi == nil { t.Errorf("Expecting to find key %s in mapping but did not.", key) } } type SportsPerson struct { Weight int Age int Person } s := SportsPerson{Weight: 100, Age: 30, Person: p} mapping = m.TypeMap(reflect.TypeOf(s)) for _, key := range []string{"id", "name", "wears_glasses", "weight", "age"} { if fi := mapping.GetByPath(key); fi == nil { t.Errorf("Expecting to find key %s in mapping but did not.", key) } } type RugbyPlayer struct { Position int IsIntense bool `db:"is_intense"` IsAllBlack bool `db:"-"` SportsPerson } r := RugbyPlayer{12, true, false, s} mapping = m.TypeMap(reflect.TypeOf(r)) for _, key := range []string{"id", "name", "wears_glasses", "weight", "age", "position", "is_intense"} { if fi := mapping.GetByPath(key); fi == nil { t.Errorf("Expecting to find key %s in mapping but did not.", key) } } if fi := mapping.GetByPath("isallblack"); fi != nil { t.Errorf("Expecting to ignore `IsAllBlack` field") } } func TestGetByTraversal(t *testing.T) { type C struct { C0 int C1 int } type B struct { B0 string B1 *C } type A struct { A0 int A1 B } testCases := []struct { Index []int ExpectedName 
string ExpectNil bool }{ { Index: []int{0}, ExpectedName: "A0", }, { Index: []int{1, 0}, ExpectedName: "B0", }, { Index: []int{1, 1, 1}, ExpectedName: "C1", }, { Index: []int{3, 4, 5}, ExpectNil: true, }, { Index: []int{}, ExpectNil: true, }, { Index: nil, ExpectNil: true, }, } m := NewMapperFunc("db", func(n string) string { return n }) tm := m.TypeMap(reflect.TypeOf(A{})) for i, tc := range testCases { fi := tm.GetByTraversal(tc.Index) if tc.ExpectNil { if fi != nil { t.Errorf("%d: expected nil, got %v", i, fi) } continue } if fi == nil { t.Errorf("%d: expected %s, got nil", i, tc.ExpectedName) continue } if fi.Name != tc.ExpectedName { t.Errorf("%d: expected %s, got %s", i, tc.ExpectedName, fi.Name) } } } // TestMapperMethodsByName tests Mapper methods FieldByName and TraversalsByName func TestMapperMethodsByName(t *testing.T) { type C struct { C0 string C1 int } type B struct { B0 *C `db:"B0"` B1 C `db:"B1"` B2 string `db:"B2"` } type A struct { A0 *B `db:"A0"` B `db:"A1"` A2 int a3 int } val := &A{ A0: &B{ B0: &C{C0: "0", C1: 1}, B1: C{C0: "2", C1: 3}, B2: "4", }, B: B{ B0: nil, B1: C{C0: "5", C1: 6}, B2: "7", }, A2: 8, } testCases := []struct { Name string ExpectInvalid bool ExpectedValue interface{} ExpectedIndexes []int }{ { Name: "A0.B0.C0", ExpectedValue: "0", ExpectedIndexes: []int{0, 0, 0}, }, { Name: "A0.B0.C1", ExpectedValue: 1, ExpectedIndexes: []int{0, 0, 1}, }, { Name: "A0.B1.C0", ExpectedValue: "2", ExpectedIndexes: []int{0, 1, 0}, }, { Name: "A0.B1.C1", ExpectedValue: 3, ExpectedIndexes: []int{0, 1, 1}, }, { Name: "A0.B2", ExpectedValue: "4", ExpectedIndexes: []int{0, 2}, }, { Name: "A1.B0.C0", ExpectedValue: "", ExpectedIndexes: []int{1, 0, 0}, }, { Name: "A1.B0.C1", ExpectedValue: 0, ExpectedIndexes: []int{1, 0, 1}, }, { Name: "A1.B1.C0", ExpectedValue: "5", ExpectedIndexes: []int{1, 1, 0}, }, { Name: "A1.B1.C1", ExpectedValue: 6, ExpectedIndexes: []int{1, 1, 1}, }, { Name: "A1.B2", ExpectedValue: "7", ExpectedIndexes: []int{1, 2}, }, { Name: 
"A2", ExpectedValue: 8, ExpectedIndexes: []int{2}, }, { Name: "XYZ", ExpectInvalid: true, ExpectedIndexes: []int{}, }, { Name: "a3", ExpectInvalid: true, ExpectedIndexes: []int{}, }, } // build the names array from the test cases names := make([]string, len(testCases)) for i, tc := range testCases { names[i] = tc.Name } m := NewMapperFunc("db", func(n string) string { return n }) v := reflect.ValueOf(val) values := m.FieldsByName(v, names) if len(values) != len(testCases) { t.Errorf("expected %d values, got %d", len(testCases), len(values)) t.FailNow() } indexes := m.TraversalsByName(v.Type(), names) if len(indexes) != len(testCases) { t.Errorf("expected %d traversals, got %d", len(testCases), len(indexes)) t.FailNow() } for i, val := range values { tc := testCases[i] traversal := indexes[i] if !reflect.DeepEqual(tc.ExpectedIndexes, traversal) { t.Errorf("expected %v, got %v", tc.ExpectedIndexes, traversal) t.FailNow() } val = reflect.Indirect(val) if tc.ExpectInvalid { if val.IsValid() { t.Errorf("%d: expected zero value, got %v", i, val) } continue } if !val.IsValid() { t.Errorf("%d: expected valid value, got %v", i, val) continue } actualValue := reflect.Indirect(val).Interface() if !reflect.DeepEqual(tc.ExpectedValue, actualValue) { t.Errorf("%d: expected %v, got %v", i, tc.ExpectedValue, actualValue) } } } func TestFieldByIndexes(t *testing.T) { type C struct { C0 bool C1 string C2 int C3 map[string]int } type B struct { B1 C B2 *C } type A struct { A1 B A2 *B } testCases := []struct { value interface{} indexes []int expectedValue interface{} readOnly bool }{ { value: A{ A1: B{B1: C{C0: true}}, }, indexes: []int{0, 0, 0}, expectedValue: true, readOnly: true, }, { value: A{ A2: &B{B2: &C{C1: "answer"}}, }, indexes: []int{1, 1, 1}, expectedValue: "answer", readOnly: true, }, { value: &A{}, indexes: []int{1, 1, 3}, expectedValue: map[string]int{}, }, } for i, tc := range testCases { checkResults := func(v reflect.Value) { if tc.expectedValue == nil { if 
!v.IsNil() { t.Errorf("%d: expected nil, actual %v", i, v.Interface()) } } else { if !reflect.DeepEqual(tc.expectedValue, v.Interface()) { t.Errorf("%d: expected %v, actual %v", i, tc.expectedValue, v.Interface()) } } } checkResults(FieldByIndexes(reflect.ValueOf(tc.value), tc.indexes)) if tc.readOnly { checkResults(FieldByIndexesReadOnly(reflect.ValueOf(tc.value), tc.indexes)) } } } func TestMustBe(t *testing.T) { typ := reflect.TypeOf(E1{}) mustBe(typ, reflect.Struct) defer func() { if r := recover(); r != nil { valueErr, ok := r.(*reflect.ValueError) if !ok { t.Errorf("unexpected Method: %s", valueErr.Method) t.Error("expected panic with *reflect.ValueError") return } if valueErr.Method != "github.com/jmoiron/sqlx/reflectx.TestMustBe" { } if valueErr.Kind != reflect.String { t.Errorf("unexpected Kind: %s", valueErr.Kind) } } else { t.Error("expected panic") } }() typ = reflect.TypeOf("string") mustBe(typ, reflect.Struct) t.Error("got here, didn't expect to") } type E1 struct { A int } type E2 struct { E1 B int } type E3 struct { E2 C int } type E4 struct { E3 D int } func BenchmarkFieldNameL1(b *testing.B) { e4 := E4{D: 1} for i := 0; i < b.N; i++ { v := reflect.ValueOf(e4) f := v.FieldByName("D") if f.Interface().(int) != 1 { b.Fatal("Wrong value.") } } } func BenchmarkFieldNameL4(b *testing.B) { e4 := E4{} e4.A = 1 for i := 0; i < b.N; i++ { v := reflect.ValueOf(e4) f := v.FieldByName("A") if f.Interface().(int) != 1 { b.Fatal("Wrong value.") } } } func BenchmarkFieldPosL1(b *testing.B) { e4 := E4{D: 1} for i := 0; i < b.N; i++ { v := reflect.ValueOf(e4) f := v.Field(1) if f.Interface().(int) != 1 { b.Fatal("Wrong value.") } } } func BenchmarkFieldPosL4(b *testing.B) { e4 := E4{} e4.A = 1 for i := 0; i < b.N; i++ { v := reflect.ValueOf(e4) f := v.Field(0) f = f.Field(0) f = f.Field(0) f = f.Field(0) if f.Interface().(int) != 1 { b.Fatal("Wrong value.") } } } func BenchmarkFieldByIndexL4(b *testing.B) { e4 := E4{} e4.A = 1 idx := []int{0, 0, 0, 0} for i := 0; i 
< b.N; i++ { v := reflect.ValueOf(e4) f := FieldByIndexes(v, idx) if f.Interface().(int) != 1 { b.Fatal("Wrong value.") } } } func BenchmarkTraversalsByName(b *testing.B) { type A struct { Value int } type B struct { A A } type C struct { B B } type D struct { C C } m := NewMapper("") t := reflect.TypeOf(D{}) names := []string{"C", "B", "A", "Value"} b.ResetTimer() for i := 0; i < b.N; i++ { if l := len(m.TraversalsByName(t, names)); l != len(names) { b.Errorf("expected %d values, got %d", len(names), l) } } } func BenchmarkTraversalsByNameFunc(b *testing.B) { type A struct { Z int } type B struct { A A } type C struct { B B } type D struct { C C } m := NewMapper("") t := reflect.TypeOf(D{}) names := []string{"C", "B", "A", "Z", "Y"} b.ResetTimer() for i := 0; i < b.N; i++ { var l int if err := m.TraversalsByNameFunc(t, names, func(_ int, _ []int) error { l++ return nil }); err != nil { b.Errorf("unexpected error %s", err) } if l != len(names) { b.Errorf("expected %d values, got %d", len(names), l) } } } sqlx-1.3.5/sqlx.go000066400000000000000000000752741422653721300140510ustar00rootroot00000000000000package sqlx import ( "database/sql" "database/sql/driver" "errors" "fmt" "io/ioutil" "path/filepath" "reflect" "strings" "sync" "github.com/jmoiron/sqlx/reflectx" ) // Although the NameMapper is convenient, in practice it should not // be relied on except for application code. If you are writing a library // that uses sqlx, you should be aware that the name mappings you expect // can be overridden by your user's application. // NameMapper is used to map column names to struct field names. By default, // it uses strings.ToLower to lowercase struct field names. It can be set // to whatever you want, but it is encouraged to be set before sqlx is used // as name-to-field mappings are cached after first use on a type. 
var NameMapper = strings.ToLower var origMapper = reflect.ValueOf(NameMapper) // Rather than creating on init, this is created when necessary so that // importers have time to customize the NameMapper. var mpr *reflectx.Mapper // mprMu protects mpr. var mprMu sync.Mutex // mapper returns a valid mapper using the configured NameMapper func. func mapper() *reflectx.Mapper { mprMu.Lock() defer mprMu.Unlock() if mpr == nil { mpr = reflectx.NewMapperFunc("db", NameMapper) } else if origMapper != reflect.ValueOf(NameMapper) { // if NameMapper has changed, create a new mapper mpr = reflectx.NewMapperFunc("db", NameMapper) origMapper = reflect.ValueOf(NameMapper) } return mpr } // isScannable takes the reflect.Type and the actual dest value and returns // whether or not it's Scannable. Something is scannable if: // * it is not a struct // * it implements sql.Scanner // * it has no exported fields func isScannable(t reflect.Type) bool { if reflect.PtrTo(t).Implements(_scannerInterface) { return true } if t.Kind() != reflect.Struct { return true } // it's not important that we use the right mapper for this particular object, // we're only concerned on how many exported fields this struct has return len(mapper().TypeMap(t).Index) == 0 } // ColScanner is an interface used by MapScan and SliceScan type ColScanner interface { Columns() ([]string, error) Scan(dest ...interface{}) error Err() error } // Queryer is an interface used by Get and Select type Queryer interface { Query(query string, args ...interface{}) (*sql.Rows, error) Queryx(query string, args ...interface{}) (*Rows, error) QueryRowx(query string, args ...interface{}) *Row } // Execer is an interface used by MustExec and LoadFile type Execer interface { Exec(query string, args ...interface{}) (sql.Result, error) } // Binder is an interface for something which can bind queries (Tx, DB) type binder interface { DriverName() string Rebind(string) string BindNamed(string, interface{}) (string, []interface{}, error) } // 
Ext is a union interface which can bind, query, and exec, used by // NamedQuery and NamedExec. type Ext interface { binder Queryer Execer } // Preparer is an interface used by Preparex. type Preparer interface { Prepare(query string) (*sql.Stmt, error) } // determine if any of our extensions are unsafe func isUnsafe(i interface{}) bool { switch v := i.(type) { case Row: return v.unsafe case *Row: return v.unsafe case Rows: return v.unsafe case *Rows: return v.unsafe case NamedStmt: return v.Stmt.unsafe case *NamedStmt: return v.Stmt.unsafe case Stmt: return v.unsafe case *Stmt: return v.unsafe case qStmt: return v.unsafe case *qStmt: return v.unsafe case DB: return v.unsafe case *DB: return v.unsafe case Tx: return v.unsafe case *Tx: return v.unsafe case sql.Rows, *sql.Rows: return false default: return false } } func mapperFor(i interface{}) *reflectx.Mapper { switch i := i.(type) { case DB: return i.Mapper case *DB: return i.Mapper case Tx: return i.Mapper case *Tx: return i.Mapper default: return mapper() } } var _scannerInterface = reflect.TypeOf((*sql.Scanner)(nil)).Elem() var _valuerInterface = reflect.TypeOf((*driver.Valuer)(nil)).Elem() // Row is a reimplementation of sql.Row in order to gain access to the underlying // sql.Rows.Columns() data, necessary for StructScan. type Row struct { err error unsafe bool rows *sql.Rows Mapper *reflectx.Mapper } // Scan is a fixed implementation of sql.Row.Scan, which does not discard the // underlying error from the internal rows object if it exists. func (r *Row) Scan(dest ...interface{}) error { if r.err != nil { return r.err } // TODO(bradfitz): for now we need to defensively clone all // []byte that the driver returned (not permitting // *RawBytes in Rows.Scan), since we're about to close // the Rows in our defer, when we return from this function. // the contract with the driver.Next(...) interface is that it // can return slices into read-only temporary memory that's // only valid until the next Scan/Close. 
But the TODO is that // for a lot of drivers, this copy will be unnecessary. We // should provide an optional interface for drivers to // implement to say, "don't worry, the []bytes that I return // from Next will not be modified again." (for instance, if // they were obtained from the network anyway) But for now we // don't care. defer r.rows.Close() for _, dp := range dest { if _, ok := dp.(*sql.RawBytes); ok { return errors.New("sql: RawBytes isn't allowed on Row.Scan") } } if !r.rows.Next() { if err := r.rows.Err(); err != nil { return err } return sql.ErrNoRows } err := r.rows.Scan(dest...) if err != nil { return err } // Make sure the query can be processed to completion with no errors. if err := r.rows.Close(); err != nil { return err } return nil } // Columns returns the underlying sql.Rows.Columns(), or the deferred error usually // returned by Row.Scan() func (r *Row) Columns() ([]string, error) { if r.err != nil { return []string{}, r.err } return r.rows.Columns() } // ColumnTypes returns the underlying sql.Rows.ColumnTypes(), or the deferred error func (r *Row) ColumnTypes() ([]*sql.ColumnType, error) { if r.err != nil { return []*sql.ColumnType{}, r.err } return r.rows.ColumnTypes() } // Err returns the error encountered while scanning. func (r *Row) Err() error { return r.err } // DB is a wrapper around sql.DB which keeps track of the driverName upon Open, // used mostly to automatically bind named queries using the right bindvars. type DB struct { *sql.DB driverName string unsafe bool Mapper *reflectx.Mapper } // NewDb returns a new sqlx DB wrapper for a pre-existing *sql.DB. The // driverName of the original database is required for named query support. func NewDb(db *sql.DB, driverName string) *DB { return &DB{DB: db, driverName: driverName, Mapper: mapper()} } // DriverName returns the driverName passed to the Open function for this DB. 
func (db *DB) DriverName() string { return db.driverName } // Open is the same as sql.Open, but returns an *sqlx.DB instead. func Open(driverName, dataSourceName string) (*DB, error) { db, err := sql.Open(driverName, dataSourceName) if err != nil { return nil, err } return &DB{DB: db, driverName: driverName, Mapper: mapper()}, err } // MustOpen is the same as sql.Open, but returns an *sqlx.DB instead and panics on error. func MustOpen(driverName, dataSourceName string) *DB { db, err := Open(driverName, dataSourceName) if err != nil { panic(err) } return db } // MapperFunc sets a new mapper for this db using the default sqlx struct tag // and the provided mapper function. func (db *DB) MapperFunc(mf func(string) string) { db.Mapper = reflectx.NewMapperFunc("db", mf) } // Rebind transforms a query from QUESTION to the DB driver's bindvar type. func (db *DB) Rebind(query string) string { return Rebind(BindType(db.driverName), query) } // Unsafe returns a version of DB which will silently succeed to scan when // columns in the SQL result have no fields in the destination struct. // sqlx.Stmt and sqlx.Tx which are created from this DB will inherit its // safety behavior. func (db *DB) Unsafe() *DB { return &DB{DB: db.DB, driverName: db.driverName, unsafe: true, Mapper: db.Mapper} } // BindNamed binds a query using the DB driver's bindvar type. func (db *DB) BindNamed(query string, arg interface{}) (string, []interface{}, error) { return bindNamedMapper(BindType(db.driverName), query, arg, db.Mapper) } // NamedQuery using this DB. // Any named placeholder parameters are replaced with fields from arg. func (db *DB) NamedQuery(query string, arg interface{}) (*Rows, error) { return NamedQuery(db, query, arg) } // NamedExec using this DB. // Any named placeholder parameters are replaced with fields from arg. func (db *DB) NamedExec(query string, arg interface{}) (sql.Result, error) { return NamedExec(db, query, arg) } // Select using this DB. 
// Any placeholder parameters are replaced with supplied args. func (db *DB) Select(dest interface{}, query string, args ...interface{}) error { return Select(db, dest, query, args...) } // Get using this DB. // Any placeholder parameters are replaced with supplied args. // An error is returned if the result set is empty. func (db *DB) Get(dest interface{}, query string, args ...interface{}) error { return Get(db, dest, query, args...) } // MustBegin starts a transaction, and panics on error. Returns an *sqlx.Tx instead // of an *sql.Tx. func (db *DB) MustBegin() *Tx { tx, err := db.Beginx() if err != nil { panic(err) } return tx } // Beginx begins a transaction and returns an *sqlx.Tx instead of an *sql.Tx. func (db *DB) Beginx() (*Tx, error) { tx, err := db.DB.Begin() if err != nil { return nil, err } return &Tx{Tx: tx, driverName: db.driverName, unsafe: db.unsafe, Mapper: db.Mapper}, err } // Queryx queries the database and returns an *sqlx.Rows. // Any placeholder parameters are replaced with supplied args. func (db *DB) Queryx(query string, args ...interface{}) (*Rows, error) { r, err := db.DB.Query(query, args...) if err != nil { return nil, err } return &Rows{Rows: r, unsafe: db.unsafe, Mapper: db.Mapper}, err } // QueryRowx queries the database and returns an *sqlx.Row. // Any placeholder parameters are replaced with supplied args. func (db *DB) QueryRowx(query string, args ...interface{}) *Row { rows, err := db.DB.Query(query, args...) return &Row{rows: rows, err: err, unsafe: db.unsafe, Mapper: db.Mapper} } // MustExec (panic) runs MustExec using this database. // Any placeholder parameters are replaced with supplied args. func (db *DB) MustExec(query string, args ...interface{}) sql.Result { return MustExec(db, query, args...) 
} // Preparex returns an sqlx.Stmt instead of a sql.Stmt func (db *DB) Preparex(query string) (*Stmt, error) { return Preparex(db, query) } // PrepareNamed returns an sqlx.NamedStmt func (db *DB) PrepareNamed(query string) (*NamedStmt, error) { return prepareNamed(db, query) } // Conn is a wrapper around sql.Conn with extra functionality type Conn struct { *sql.Conn driverName string unsafe bool Mapper *reflectx.Mapper } // Tx is an sqlx wrapper around sql.Tx with extra functionality type Tx struct { *sql.Tx driverName string unsafe bool Mapper *reflectx.Mapper } // DriverName returns the driverName used by the DB which began this transaction. func (tx *Tx) DriverName() string { return tx.driverName } // Rebind a query within a transaction's bindvar type. func (tx *Tx) Rebind(query string) string { return Rebind(BindType(tx.driverName), query) } // Unsafe returns a version of Tx which will silently succeed to scan when // columns in the SQL result have no fields in the destination struct. func (tx *Tx) Unsafe() *Tx { return &Tx{Tx: tx.Tx, driverName: tx.driverName, unsafe: true, Mapper: tx.Mapper} } // BindNamed binds a query within a transaction's bindvar type. func (tx *Tx) BindNamed(query string, arg interface{}) (string, []interface{}, error) { return bindNamedMapper(BindType(tx.driverName), query, arg, tx.Mapper) } // NamedQuery within a transaction. // Any named placeholder parameters are replaced with fields from arg. func (tx *Tx) NamedQuery(query string, arg interface{}) (*Rows, error) { return NamedQuery(tx, query, arg) } // NamedExec a named query within a transaction. // Any named placeholder parameters are replaced with fields from arg. func (tx *Tx) NamedExec(query string, arg interface{}) (sql.Result, error) { return NamedExec(tx, query, arg) } // Select within a transaction. // Any placeholder parameters are replaced with supplied args. 
func (tx *Tx) Select(dest interface{}, query string, args ...interface{}) error {
	return Select(tx, dest, query, args...)
}

// Queryx within a transaction. The returned *sqlx.Rows carries this
// transaction's unsafe flag and Mapper.
// Any placeholder parameters are replaced with supplied args.
func (tx *Tx) Queryx(query string, args ...interface{}) (*Rows, error) {
	r, err := tx.Tx.Query(query, args...)
	if err != nil {
		return nil, err
	}
	return &Rows{Rows: r, unsafe: tx.unsafe, Mapper: tx.Mapper}, nil
}

// QueryRowx within a transaction. Any error from the query is stored in the
// returned Row and reported when it is scanned.
// Any placeholder parameters are replaced with supplied args.
func (tx *Tx) QueryRowx(query string, args ...interface{}) *Row {
	rows, err := tx.Tx.Query(query, args...)
	return &Row{rows: rows, err: err, unsafe: tx.unsafe, Mapper: tx.Mapper}
}

// Get within a transaction.
// Any placeholder parameters are replaced with supplied args.
// An error is returned if the result set is empty.
func (tx *Tx) Get(dest interface{}, query string, args ...interface{}) error {
	return Get(tx, dest, query, args...)
}

// MustExec runs MustExec within a transaction.
// Any placeholder parameters are replaced with supplied args.
func (tx *Tx) MustExec(query string, args ...interface{}) sql.Result {
	return MustExec(tx, query, args...)
}

// Preparex a statement within a transaction.
func (tx *Tx) Preparex(query string) (*Stmt, error) {
	return Preparex(tx, query)
}

// Stmtx returns a version of the prepared statement which runs within a
// transaction. Provided stmt can be either *sql.Stmt or *sqlx.Stmt (a Stmt
// value is accepted too). Any other type, including nil, causes a panic.
//
// The returned Stmt uses this transaction's Mapper and unsafe setting, so a
// statement bound to an Unsafe() Tx scans the same way as one created with
// tx.Preparex.
func (tx *Tx) Stmtx(stmt interface{}) *Stmt {
	var s *sql.Stmt
	switch v := stmt.(type) {
	case Stmt:
		s = v.Stmt
	case *Stmt:
		s = v.Stmt
	case *sql.Stmt:
		s = v
	default:
		// reflect.TypeOf(nil) is a nil Type, which %v prints as <nil>. A nil
		// stmt therefore gets this message instead of a reflect zero-Value panic.
		panic(fmt.Sprintf("non-statement type %v passed to Stmtx", reflect.TypeOf(stmt)))
	}
	return &Stmt{Stmt: tx.Stmt(s), unsafe: tx.unsafe, Mapper: tx.Mapper}
}

// NamedStmt returns a version of the prepared statement which runs within a transaction.
func (tx *Tx) NamedStmt(stmt *NamedStmt) *NamedStmt { return &NamedStmt{ QueryString: stmt.QueryString, Params: stmt.Params, Stmt: tx.Stmtx(stmt.Stmt), } } // PrepareNamed returns an sqlx.NamedStmt func (tx *Tx) PrepareNamed(query string) (*NamedStmt, error) { return prepareNamed(tx, query) } // Stmt is an sqlx wrapper around sql.Stmt with extra functionality type Stmt struct { *sql.Stmt unsafe bool Mapper *reflectx.Mapper } // Unsafe returns a version of Stmt which will silently succeed to scan when // columns in the SQL result have no fields in the destination struct. func (s *Stmt) Unsafe() *Stmt { return &Stmt{Stmt: s.Stmt, unsafe: true, Mapper: s.Mapper} } // Select using the prepared statement. // Any placeholder parameters are replaced with supplied args. func (s *Stmt) Select(dest interface{}, args ...interface{}) error { return Select(&qStmt{s}, dest, "", args...) } // Get using the prepared statement. // Any placeholder parameters are replaced with supplied args. // An error is returned if the result set is empty. func (s *Stmt) Get(dest interface{}, args ...interface{}) error { return Get(&qStmt{s}, dest, "", args...) } // MustExec (panic) using this statement. Note that the query portion of the error // output will be blank, as Stmt does not expose its query. // Any placeholder parameters are replaced with supplied args. func (s *Stmt) MustExec(args ...interface{}) sql.Result { return MustExec(&qStmt{s}, "", args...) } // QueryRowx using this statement. // Any placeholder parameters are replaced with supplied args. func (s *Stmt) QueryRowx(args ...interface{}) *Row { qs := &qStmt{s} return qs.QueryRowx("", args...) } // Queryx using this statement. // Any placeholder parameters are replaced with supplied args. func (s *Stmt) Queryx(args ...interface{}) (*Rows, error) { qs := &qStmt{s} return qs.Queryx("", args...) 
} // qStmt is an unexposed wrapper which lets you use a Stmt as a Queryer & Execer by // implementing those interfaces and ignoring the `query` argument. type qStmt struct{ *Stmt } func (q *qStmt) Query(query string, args ...interface{}) (*sql.Rows, error) { return q.Stmt.Query(args...) } func (q *qStmt) Queryx(query string, args ...interface{}) (*Rows, error) { r, err := q.Stmt.Query(args...) if err != nil { return nil, err } return &Rows{Rows: r, unsafe: q.Stmt.unsafe, Mapper: q.Stmt.Mapper}, err } func (q *qStmt) QueryRowx(query string, args ...interface{}) *Row { rows, err := q.Stmt.Query(args...) return &Row{rows: rows, err: err, unsafe: q.Stmt.unsafe, Mapper: q.Stmt.Mapper} } func (q *qStmt) Exec(query string, args ...interface{}) (sql.Result, error) { return q.Stmt.Exec(args...) } // Rows is a wrapper around sql.Rows which caches costly reflect operations // during a looped StructScan type Rows struct { *sql.Rows unsafe bool Mapper *reflectx.Mapper // these fields cache memory use for a rows during iteration w/ structScan started bool fields [][]int values []interface{} } // SliceScan using this Rows. func (r *Rows) SliceScan() ([]interface{}, error) { return SliceScan(r) } // MapScan using this Rows. func (r *Rows) MapScan(dest map[string]interface{}) error { return MapScan(r, dest) } // StructScan is like sql.Rows.Scan, but scans a single Row into a single Struct. // Use this and iterate over Rows manually when the memory load of Select() might be // prohibitive. *Rows.StructScan caches the reflect work of matching up column // positions to fields to avoid that overhead per scan, which means it is not safe // to run StructScan on the same Rows instance with different struct types. 
func (r *Rows) StructScan(dest interface{}) error { v := reflect.ValueOf(dest) if v.Kind() != reflect.Ptr { return errors.New("must pass a pointer, not a value, to StructScan destination") } v = v.Elem() if !r.started { columns, err := r.Columns() if err != nil { return err } m := r.Mapper r.fields = m.TraversalsByName(v.Type(), columns) // if we are not unsafe and are missing fields, return an error if f, err := missingFields(r.fields); err != nil && !r.unsafe { return fmt.Errorf("missing destination name %s in %T", columns[f], dest) } r.values = make([]interface{}, len(columns)) r.started = true } err := fieldsByTraversal(v, r.fields, r.values, true) if err != nil { return err } // scan into the struct field pointers and append to our results err = r.Scan(r.values...) if err != nil { return err } return r.Err() } // Connect to a database and verify with a ping. func Connect(driverName, dataSourceName string) (*DB, error) { db, err := Open(driverName, dataSourceName) if err != nil { return nil, err } err = db.Ping() if err != nil { db.Close() return nil, err } return db, nil } // MustConnect connects to a database and panics on error. func MustConnect(driverName, dataSourceName string) *DB { db, err := Connect(driverName, dataSourceName) if err != nil { panic(err) } return db } // Preparex prepares a statement. func Preparex(p Preparer, query string) (*Stmt, error) { s, err := p.Prepare(query) if err != nil { return nil, err } return &Stmt{Stmt: s, unsafe: isUnsafe(p), Mapper: mapperFor(p)}, err } // Select executes a query using the provided Queryer, and StructScans each row // into dest, which must be a slice. If the slice elements are scannable, then // the result set must have only one column. Otherwise, StructScan is used. // The *sql.Rows are closed automatically. // Any placeholder parameters are replaced with supplied args. func Select(q Queryer, dest interface{}, query string, args ...interface{}) error { rows, err := q.Queryx(query, args...) 
if err != nil { return err } // if something happens here, we want to make sure the rows are Closed defer rows.Close() return scanAll(rows, dest, false) } // Get does a QueryRow using the provided Queryer, and scans the resulting row // to dest. If dest is scannable, the result must only have one column. Otherwise, // StructScan is used. Get will return sql.ErrNoRows like row.Scan would. // Any placeholder parameters are replaced with supplied args. // An error is returned if the result set is empty. func Get(q Queryer, dest interface{}, query string, args ...interface{}) error { r := q.QueryRowx(query, args...) return r.scanAny(dest, false) } // LoadFile exec's every statement in a file (as a single call to Exec). // LoadFile may return a nil *sql.Result if errors are encountered locating or // reading the file at path. LoadFile reads the entire file into memory, so it // is not suitable for loading large data dumps, but can be useful for initializing // schemas or loading indexes. // // FIXME: this does not really work with multi-statement files for mattn/go-sqlite3 // or the go-mysql-driver/mysql drivers; pq seems to be an exception here. Detecting // this by requiring something with DriverName() and then attempting to split the // queries will be difficult to get right, and its current driver-specific behavior // is deemed at least not complex in its incorrectness. func LoadFile(e Execer, path string) (*sql.Result, error) { realpath, err := filepath.Abs(path) if err != nil { return nil, err } contents, err := ioutil.ReadFile(realpath) if err != nil { return nil, err } res, err := e.Exec(string(contents)) return &res, err } // MustExec execs the query using e and panics if there was an error. // Any placeholder parameters are replaced with supplied args. func MustExec(e Execer, query string, args ...interface{}) sql.Result { res, err := e.Exec(query, args...) if err != nil { panic(err) } return res } // SliceScan using this Rows. 
func (r *Row) SliceScan() ([]interface{}, error) { return SliceScan(r) } // MapScan using this Rows. func (r *Row) MapScan(dest map[string]interface{}) error { return MapScan(r, dest) } func (r *Row) scanAny(dest interface{}, structOnly bool) error { if r.err != nil { return r.err } if r.rows == nil { r.err = sql.ErrNoRows return r.err } defer r.rows.Close() v := reflect.ValueOf(dest) if v.Kind() != reflect.Ptr { return errors.New("must pass a pointer, not a value, to StructScan destination") } if v.IsNil() { return errors.New("nil pointer passed to StructScan destination") } base := reflectx.Deref(v.Type()) scannable := isScannable(base) if structOnly && scannable { return structOnlyError(base) } columns, err := r.Columns() if err != nil { return err } if scannable && len(columns) > 1 { return fmt.Errorf("scannable dest type %s with >1 columns (%d) in result", base.Kind(), len(columns)) } if scannable { return r.Scan(dest) } m := r.Mapper fields := m.TraversalsByName(v.Type(), columns) // if we are not unsafe and are missing fields, return an error if f, err := missingFields(fields); err != nil && !r.unsafe { return fmt.Errorf("missing destination name %s in %T", columns[f], dest) } values := make([]interface{}, len(columns)) err = fieldsByTraversal(v, fields, values, true) if err != nil { return err } // scan into the struct field pointers and append to our results return r.Scan(values...) } // StructScan a single Row into dest. func (r *Row) StructScan(dest interface{}) error { return r.scanAny(dest, true) } // SliceScan a row, returning a []interface{} with values similar to MapScan. // This function is primarily intended for use where the number of columns // is not known. Because you can pass an []interface{} directly to Scan, // it's recommended that you do that as it will not have to allocate new // slices per row. func SliceScan(r ColScanner) ([]interface{}, error) { // ignore r.started, since we needn't use reflect for anything. 
columns, err := r.Columns() if err != nil { return []interface{}{}, err } values := make([]interface{}, len(columns)) for i := range values { values[i] = new(interface{}) } err = r.Scan(values...) if err != nil { return values, err } for i := range columns { values[i] = *(values[i].(*interface{})) } return values, r.Err() } // MapScan scans a single Row into the dest map[string]interface{}. // Use this to get results for SQL that might not be under your control // (for instance, if you're building an interface for an SQL server that // executes SQL from input). Please do not use this as a primary interface! // This will modify the map sent to it in place, so reuse the same map with // care. Columns which occur more than once in the result will overwrite // each other! func MapScan(r ColScanner, dest map[string]interface{}) error { // ignore r.started, since we needn't use reflect for anything. columns, err := r.Columns() if err != nil { return err } values := make([]interface{}, len(columns)) for i := range values { values[i] = new(interface{}) } err = r.Scan(values...) if err != nil { return err } for i, column := range columns { dest[column] = *(values[i].(*interface{})) } return r.Err() } type rowsi interface { Close() error Columns() ([]string, error) Err() error Next() bool Scan(...interface{}) error } // structOnlyError returns an error appropriate for type when a non-scannable // struct is expected but something else is given func structOnlyError(t reflect.Type) error { isStruct := t.Kind() == reflect.Struct isScanner := reflect.PtrTo(t).Implements(_scannerInterface) if !isStruct { return fmt.Errorf("expected %s but got %s", reflect.Struct, t.Kind()) } if isScanner { return fmt.Errorf("structscan expects a struct dest but the provided struct type %s implements scanner", t.Name()) } return fmt.Errorf("expected a struct, but struct %s has no exported fields", t.Name()) } // scanAll scans all rows into a destination, which must be a slice of any // type. 
It resets the slice length to zero before appending each element to // the slice. If the destination slice type is a Struct, then StructScan will // be used on each row. If the destination is some other kind of base type, // then each row must only have one column which can scan into that type. This // allows you to do something like: // // rows, _ := db.Query("select id from people;") // var ids []int // scanAll(rows, &ids, false) // // and ids will be a list of the id results. I realize that this is a desirable // interface to expose to users, but for now it will only be exposed via changes // to `Get` and `Select`. The reason that this has been implemented like this is // this is the only way to not duplicate reflect work in the new API while // maintaining backwards compatibility. func scanAll(rows rowsi, dest interface{}, structOnly bool) error { var v, vp reflect.Value value := reflect.ValueOf(dest) // json.Unmarshal returns errors for these if value.Kind() != reflect.Ptr { return errors.New("must pass a pointer, not a value, to StructScan destination") } if value.IsNil() { return errors.New("nil pointer passed to StructScan destination") } direct := reflect.Indirect(value) slice, err := baseType(value.Type(), reflect.Slice) if err != nil { return err } direct.SetLen(0) isPtr := slice.Elem().Kind() == reflect.Ptr base := reflectx.Deref(slice.Elem()) scannable := isScannable(base) if structOnly && scannable { return structOnlyError(base) } columns, err := rows.Columns() if err != nil { return err } // if it's a base type make sure it only has 1 column; if not return an error if scannable && len(columns) > 1 { return fmt.Errorf("non-struct dest type %s with >1 columns (%d)", base.Kind(), len(columns)) } if !scannable { var values []interface{} var m *reflectx.Mapper switch rows.(type) { case *Rows: m = rows.(*Rows).Mapper default: m = mapper() } fields := m.TraversalsByName(base, columns) // if we are not unsafe and are missing fields, return an error if f, err 
:= missingFields(fields); err != nil && !isUnsafe(rows) { return fmt.Errorf("missing destination name %s in %T", columns[f], dest) } values = make([]interface{}, len(columns)) for rows.Next() { // create a new struct type (which returns PtrTo) and indirect it vp = reflect.New(base) v = reflect.Indirect(vp) err = fieldsByTraversal(v, fields, values, true) if err != nil { return err } // scan into the struct field pointers and append to our results err = rows.Scan(values...) if err != nil { return err } if isPtr { direct.Set(reflect.Append(direct, vp)) } else { direct.Set(reflect.Append(direct, v)) } } } else { for rows.Next() { vp = reflect.New(base) err = rows.Scan(vp.Interface()) if err != nil { return err } // append if isPtr { direct.Set(reflect.Append(direct, vp)) } else { direct.Set(reflect.Append(direct, reflect.Indirect(vp))) } } } return rows.Err() } // FIXME: StructScan was the very first bit of API in sqlx, and now unfortunately // it doesn't really feel like it's named properly. There is an incongruency // between this and the way that StructScan (which might better be ScanStruct // anyway) works on a rows object. // StructScan all rows from an sql.Rows or an sqlx.Rows into the dest slice. // StructScan will scan in the entire rows result, so if you do not want to // allocate structs for the entire result, use Queryx and see sqlx.Rows.StructScan. // If rows is sqlx.Rows, it will use its mapper, otherwise it will use the default. func StructScan(rows rowsi, dest interface{}) error { return scanAll(rows, dest, true) } // reflect helpers func baseType(t reflect.Type, expected reflect.Kind) (reflect.Type, error) { t = reflectx.Deref(t) if t.Kind() != expected { return nil, fmt.Errorf("expected %s but got %s", expected, t.Kind()) } return t, nil } // fieldsByName fills a values interface with fields from the passed value based // on the traversals in int. If ptrs is true, return addresses instead of values. 
// We write this instead of using FieldsByName to save allocations and map lookups // when iterating over many rows. Empty traversals will get an interface pointer. // Because of the necessity of requesting ptrs or values, it's considered a bit too // specialized for inclusion in reflectx itself. func fieldsByTraversal(v reflect.Value, traversals [][]int, values []interface{}, ptrs bool) error { v = reflect.Indirect(v) if v.Kind() != reflect.Struct { return errors.New("argument not a struct") } for i, traversal := range traversals { if len(traversal) == 0 { values[i] = new(interface{}) continue } f := reflectx.FieldByIndexes(v, traversal) if ptrs { values[i] = f.Addr().Interface() } else { values[i] = f.Interface() } } return nil } func missingFields(transversals [][]int) (field int, err error) { for i, t := range transversals { if len(t) == 0 { return i, errors.New("missing field") } } return 0, nil } sqlx-1.3.5/sqlx_context.go000066400000000000000000000373561422653721300156140ustar00rootroot00000000000000// +build go1.8 package sqlx import ( "context" "database/sql" "fmt" "io/ioutil" "path/filepath" "reflect" ) // ConnectContext to a database and verify with a ping. func ConnectContext(ctx context.Context, driverName, dataSourceName string) (*DB, error) { db, err := Open(driverName, dataSourceName) if err != nil { return db, err } err = db.PingContext(ctx) return db, err } // QueryerContext is an interface used by GetContext and SelectContext type QueryerContext interface { QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) QueryxContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *Row } // PreparerContext is an interface used by PreparexContext. 
type PreparerContext interface { PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) } // ExecerContext is an interface used by MustExecContext and LoadFileContext type ExecerContext interface { ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) } // ExtContext is a union interface which can bind, query, and exec, with Context // used by NamedQueryContext and NamedExecContext. type ExtContext interface { binder QueryerContext ExecerContext } // SelectContext executes a query using the provided Queryer, and StructScans // each row into dest, which must be a slice. If the slice elements are // scannable, then the result set must have only one column. Otherwise, // StructScan is used. The *sql.Rows are closed automatically. // Any placeholder parameters are replaced with supplied args. func SelectContext(ctx context.Context, q QueryerContext, dest interface{}, query string, args ...interface{}) error { rows, err := q.QueryxContext(ctx, query, args...) if err != nil { return err } // if something happens here, we want to make sure the rows are Closed defer rows.Close() return scanAll(rows, dest, false) } // PreparexContext prepares a statement. // // The provided context is used for the preparation of the statement, not for // the execution of the statement. func PreparexContext(ctx context.Context, p PreparerContext, query string) (*Stmt, error) { s, err := p.PrepareContext(ctx, query) if err != nil { return nil, err } return &Stmt{Stmt: s, unsafe: isUnsafe(p), Mapper: mapperFor(p)}, err } // GetContext does a QueryRow using the provided Queryer, and scans the // resulting row to dest. If dest is scannable, the result must only have one // column. Otherwise, StructScan is used. Get will return sql.ErrNoRows like // row.Scan would. Any placeholder parameters are replaced with supplied args. // An error is returned if the result set is empty. 
func GetContext(ctx context.Context, q QueryerContext, dest interface{}, query string, args ...interface{}) error { r := q.QueryRowxContext(ctx, query, args...) return r.scanAny(dest, false) } // LoadFileContext exec's every statement in a file (as a single call to Exec). // LoadFileContext may return a nil *sql.Result if errors are encountered // locating or reading the file at path. LoadFile reads the entire file into // memory, so it is not suitable for loading large data dumps, but can be useful // for initializing schemas or loading indexes. // // FIXME: this does not really work with multi-statement files for mattn/go-sqlite3 // or the go-mysql-driver/mysql drivers; pq seems to be an exception here. Detecting // this by requiring something with DriverName() and then attempting to split the // queries will be difficult to get right, and its current driver-specific behavior // is deemed at least not complex in its incorrectness. func LoadFileContext(ctx context.Context, e ExecerContext, path string) (*sql.Result, error) { realpath, err := filepath.Abs(path) if err != nil { return nil, err } contents, err := ioutil.ReadFile(realpath) if err != nil { return nil, err } res, err := e.ExecContext(ctx, string(contents)) return &res, err } // MustExecContext execs the query using e and panics if there was an error. // Any placeholder parameters are replaced with supplied args. func MustExecContext(ctx context.Context, e ExecerContext, query string, args ...interface{}) sql.Result { res, err := e.ExecContext(ctx, query, args...) if err != nil { panic(err) } return res } // PrepareNamedContext returns an sqlx.NamedStmt func (db *DB) PrepareNamedContext(ctx context.Context, query string) (*NamedStmt, error) { return prepareNamedContext(ctx, db, query) } // NamedQueryContext using this DB. // Any named placeholder parameters are replaced with fields from arg. 
func (db *DB) NamedQueryContext(ctx context.Context, query string, arg interface{}) (*Rows, error) { return NamedQueryContext(ctx, db, query, arg) } // NamedExecContext using this DB. // Any named placeholder parameters are replaced with fields from arg. func (db *DB) NamedExecContext(ctx context.Context, query string, arg interface{}) (sql.Result, error) { return NamedExecContext(ctx, db, query, arg) } // SelectContext using this DB. // Any placeholder parameters are replaced with supplied args. func (db *DB) SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error { return SelectContext(ctx, db, dest, query, args...) } // GetContext using this DB. // Any placeholder parameters are replaced with supplied args. // An error is returned if the result set is empty. func (db *DB) GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error { return GetContext(ctx, db, dest, query, args...) } // PreparexContext returns an sqlx.Stmt instead of a sql.Stmt. // // The provided context is used for the preparation of the statement, not for // the execution of the statement. func (db *DB) PreparexContext(ctx context.Context, query string) (*Stmt, error) { return PreparexContext(ctx, db, query) } // QueryxContext queries the database and returns an *sqlx.Rows. // Any placeholder parameters are replaced with supplied args. func (db *DB) QueryxContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { r, err := db.DB.QueryContext(ctx, query, args...) if err != nil { return nil, err } return &Rows{Rows: r, unsafe: db.unsafe, Mapper: db.Mapper}, err } // QueryRowxContext queries the database and returns an *sqlx.Row. // Any placeholder parameters are replaced with supplied args. func (db *DB) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *Row { rows, err := db.DB.QueryContext(ctx, query, args...) 
return &Row{rows: rows, err: err, unsafe: db.unsafe, Mapper: db.Mapper} } // MustBeginTx starts a transaction, and panics on error. Returns an *sqlx.Tx instead // of an *sql.Tx. // // The provided context is used until the transaction is committed or rolled // back. If the context is canceled, the sql package will roll back the // transaction. Tx.Commit will return an error if the context provided to // MustBeginContext is canceled. func (db *DB) MustBeginTx(ctx context.Context, opts *sql.TxOptions) *Tx { tx, err := db.BeginTxx(ctx, opts) if err != nil { panic(err) } return tx } // MustExecContext (panic) runs MustExec using this database. // Any placeholder parameters are replaced with supplied args. func (db *DB) MustExecContext(ctx context.Context, query string, args ...interface{}) sql.Result { return MustExecContext(ctx, db, query, args...) } // BeginTxx begins a transaction and returns an *sqlx.Tx instead of an // *sql.Tx. // // The provided context is used until the transaction is committed or rolled // back. If the context is canceled, the sql package will roll back the // transaction. Tx.Commit will return an error if the context provided to // BeginxContext is canceled. func (db *DB) BeginTxx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) { tx, err := db.DB.BeginTx(ctx, opts) if err != nil { return nil, err } return &Tx{Tx: tx, driverName: db.driverName, unsafe: db.unsafe, Mapper: db.Mapper}, err } // Connx returns an *sqlx.Conn instead of an *sql.Conn. func (db *DB) Connx(ctx context.Context) (*Conn, error) { conn, err := db.DB.Conn(ctx) if err != nil { return nil, err } return &Conn{Conn: conn, driverName: db.driverName, unsafe: db.unsafe, Mapper: db.Mapper}, nil } // BeginTxx begins a transaction and returns an *sqlx.Tx instead of an // *sql.Tx. // // The provided context is used until the transaction is committed or rolled // back. If the context is canceled, the sql package will roll back the // transaction. 
Tx.Commit will return an error if the context provided to // BeginxContext is canceled. func (c *Conn) BeginTxx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) { tx, err := c.Conn.BeginTx(ctx, opts) if err != nil { return nil, err } return &Tx{Tx: tx, driverName: c.driverName, unsafe: c.unsafe, Mapper: c.Mapper}, err } // SelectContext using this Conn. // Any placeholder parameters are replaced with supplied args. func (c *Conn) SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error { return SelectContext(ctx, c, dest, query, args...) } // GetContext using this Conn. // Any placeholder parameters are replaced with supplied args. // An error is returned if the result set is empty. func (c *Conn) GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error { return GetContext(ctx, c, dest, query, args...) } // PreparexContext returns an sqlx.Stmt instead of a sql.Stmt. // // The provided context is used for the preparation of the statement, not for // the execution of the statement. func (c *Conn) PreparexContext(ctx context.Context, query string) (*Stmt, error) { return PreparexContext(ctx, c, query) } // QueryxContext queries the database and returns an *sqlx.Rows. // Any placeholder parameters are replaced with supplied args. func (c *Conn) QueryxContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { r, err := c.Conn.QueryContext(ctx, query, args...) if err != nil { return nil, err } return &Rows{Rows: r, unsafe: c.unsafe, Mapper: c.Mapper}, err } // QueryRowxContext queries the database and returns an *sqlx.Row. // Any placeholder parameters are replaced with supplied args. func (c *Conn) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *Row { rows, err := c.Conn.QueryContext(ctx, query, args...) return &Row{rows: rows, err: err, unsafe: c.unsafe, Mapper: c.Mapper} } // Rebind a query within a Conn's bindvar type. 
func (c *Conn) Rebind(query string) string { return Rebind(BindType(c.driverName), query) } // StmtxContext returns a version of the prepared statement which runs within a // transaction. Provided stmt can be either *sql.Stmt or *sqlx.Stmt. func (tx *Tx) StmtxContext(ctx context.Context, stmt interface{}) *Stmt { var s *sql.Stmt switch v := stmt.(type) { case Stmt: s = v.Stmt case *Stmt: s = v.Stmt case *sql.Stmt: s = v default: panic(fmt.Sprintf("non-statement type %v passed to Stmtx", reflect.ValueOf(stmt).Type())) } return &Stmt{Stmt: tx.StmtContext(ctx, s), Mapper: tx.Mapper} } // NamedStmtContext returns a version of the prepared statement which runs // within a transaction. func (tx *Tx) NamedStmtContext(ctx context.Context, stmt *NamedStmt) *NamedStmt { return &NamedStmt{ QueryString: stmt.QueryString, Params: stmt.Params, Stmt: tx.StmtxContext(ctx, stmt.Stmt), } } // PreparexContext returns an sqlx.Stmt instead of a sql.Stmt. // // The provided context is used for the preparation of the statement, not for // the execution of the statement. func (tx *Tx) PreparexContext(ctx context.Context, query string) (*Stmt, error) { return PreparexContext(ctx, tx, query) } // PrepareNamedContext returns an sqlx.NamedStmt func (tx *Tx) PrepareNamedContext(ctx context.Context, query string) (*NamedStmt, error) { return prepareNamedContext(ctx, tx, query) } // MustExecContext runs MustExecContext within a transaction. // Any placeholder parameters are replaced with supplied args. func (tx *Tx) MustExecContext(ctx context.Context, query string, args ...interface{}) sql.Result { return MustExecContext(ctx, tx, query, args...) } // QueryxContext within a transaction and context. // Any placeholder parameters are replaced with supplied args. func (tx *Tx) QueryxContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { r, err := tx.Tx.QueryContext(ctx, query, args...) 
if err != nil { return nil, err } return &Rows{Rows: r, unsafe: tx.unsafe, Mapper: tx.Mapper}, err } // SelectContext within a transaction and context. // Any placeholder parameters are replaced with supplied args. func (tx *Tx) SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error { return SelectContext(ctx, tx, dest, query, args...) } // GetContext within a transaction and context. // Any placeholder parameters are replaced with supplied args. // An error is returned if the result set is empty. func (tx *Tx) GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error { return GetContext(ctx, tx, dest, query, args...) } // QueryRowxContext within a transaction and context. // Any placeholder parameters are replaced with supplied args. func (tx *Tx) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *Row { rows, err := tx.Tx.QueryContext(ctx, query, args...) return &Row{rows: rows, err: err, unsafe: tx.unsafe, Mapper: tx.Mapper} } // NamedExecContext using this Tx. // Any named placeholder parameters are replaced with fields from arg. func (tx *Tx) NamedExecContext(ctx context.Context, query string, arg interface{}) (sql.Result, error) { return NamedExecContext(ctx, tx, query, arg) } // SelectContext using the prepared statement. // Any placeholder parameters are replaced with supplied args. func (s *Stmt) SelectContext(ctx context.Context, dest interface{}, args ...interface{}) error { return SelectContext(ctx, &qStmt{s}, dest, "", args...) } // GetContext using the prepared statement. // Any placeholder parameters are replaced with supplied args. // An error is returned if the result set is empty. func (s *Stmt) GetContext(ctx context.Context, dest interface{}, args ...interface{}) error { return GetContext(ctx, &qStmt{s}, dest, "", args...) } // MustExecContext (panic) using this statement. 
Note that the query portion of // the error output will be blank, as Stmt does not expose its query. // Any placeholder parameters are replaced with supplied args. func (s *Stmt) MustExecContext(ctx context.Context, args ...interface{}) sql.Result { return MustExecContext(ctx, &qStmt{s}, "", args...) } // QueryRowxContext using this statement. // Any placeholder parameters are replaced with supplied args. func (s *Stmt) QueryRowxContext(ctx context.Context, args ...interface{}) *Row { qs := &qStmt{s} return qs.QueryRowxContext(ctx, "", args...) } // QueryxContext using this statement. // Any placeholder parameters are replaced with supplied args. func (s *Stmt) QueryxContext(ctx context.Context, args ...interface{}) (*Rows, error) { qs := &qStmt{s} return qs.QueryxContext(ctx, "", args...) } func (q *qStmt) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { return q.Stmt.QueryContext(ctx, args...) } func (q *qStmt) QueryxContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { r, err := q.Stmt.QueryContext(ctx, args...) if err != nil { return nil, err } return &Rows{Rows: r, unsafe: q.Stmt.unsafe, Mapper: q.Stmt.Mapper}, err } func (q *qStmt) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *Row { rows, err := q.Stmt.QueryContext(ctx, args...) return &Row{rows: rows, err: err, unsafe: q.Stmt.unsafe, Mapper: q.Stmt.Mapper} } func (q *qStmt) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { return q.Stmt.ExecContext(ctx, args...) } sqlx-1.3.5/sqlx_context_test.go000066400000000000000000001142511422653721300166410ustar00rootroot00000000000000// +build go1.8 // The following environment variables, if set, will be used: // // * SQLX_SQLITE_DSN // * SQLX_POSTGRES_DSN // * SQLX_MYSQL_DSN // // Set any of these variables to 'skip' to skip them. 
Note that for MySQL, // the string '?parseTime=True' will be appended to the DSN if it's not there // already. // package sqlx import ( "context" "database/sql" "encoding/json" "fmt" "log" "strings" "testing" "time" _ "github.com/go-sql-driver/mysql" "github.com/jmoiron/sqlx/reflectx" _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" ) func MultiExecContext(ctx context.Context, e ExecerContext, query string) { stmts := strings.Split(query, ";\n") if len(strings.Trim(stmts[len(stmts)-1], " \n\t\r")) == 0 { stmts = stmts[:len(stmts)-1] } for _, s := range stmts { _, err := e.ExecContext(ctx, s) if err != nil { fmt.Println(err, s) } } } func RunWithSchemaContext(ctx context.Context, schema Schema, t *testing.T, test func(ctx context.Context, db *DB, t *testing.T)) { runner := func(ctx context.Context, db *DB, t *testing.T, create, drop, now string) { defer func() { MultiExecContext(ctx, db, drop) }() MultiExecContext(ctx, db, create) test(ctx, db, t) } if TestPostgres { create, drop, now := schema.Postgres() runner(ctx, pgdb, t, create, drop, now) } if TestSqlite { create, drop, now := schema.Sqlite3() runner(ctx, sldb, t, create, drop, now) } if TestMysql { create, drop, now := schema.MySQL() runner(ctx, mysqldb, t, create, drop, now) } } func loadDefaultFixtureContext(ctx context.Context, db *DB, t *testing.T) { tx := db.MustBeginTx(ctx, nil) tx.MustExecContext(ctx, tx.Rebind("INSERT INTO person (first_name, last_name, email) VALUES (?, ?, ?)"), "Jason", "Moiron", "jmoiron@jmoiron.net") tx.MustExecContext(ctx, tx.Rebind("INSERT INTO person (first_name, last_name, email) VALUES (?, ?, ?)"), "John", "Doe", "johndoeDNE@gmail.net") tx.MustExecContext(ctx, tx.Rebind("INSERT INTO place (country, city, telcode) VALUES (?, ?, ?)"), "United States", "New York", "1") tx.MustExecContext(ctx, tx.Rebind("INSERT INTO place (country, telcode) VALUES (?, ?)"), "Hong Kong", "852") tx.MustExecContext(ctx, tx.Rebind("INSERT INTO place (country, telcode) VALUES (?, ?)"), 
"Singapore", "65") if db.DriverName() == "mysql" { tx.MustExecContext(ctx, tx.Rebind("INSERT INTO capplace (`COUNTRY`, `TELCODE`) VALUES (?, ?)"), "Sarf Efrica", "27") } else { tx.MustExecContext(ctx, tx.Rebind("INSERT INTO capplace (\"COUNTRY\", \"TELCODE\") VALUES (?, ?)"), "Sarf Efrica", "27") } tx.MustExecContext(ctx, tx.Rebind("INSERT INTO employees (name, id) VALUES (?, ?)"), "Peter", "4444") tx.MustExecContext(ctx, tx.Rebind("INSERT INTO employees (name, id, boss_id) VALUES (?, ?, ?)"), "Joe", "1", "4444") tx.MustExecContext(ctx, tx.Rebind("INSERT INTO employees (name, id, boss_id) VALUES (?, ?, ?)"), "Martin", "2", "4444") tx.Commit() } // Test a new backwards compatible feature, that missing scan destinations // will silently scan into sql.RawText rather than failing/panicing func TestMissingNamesContextContext(t *testing.T) { RunWithSchemaContext(context.Background(), defaultSchema, t, func(ctx context.Context, db *DB, t *testing.T) { loadDefaultFixtureContext(ctx, db, t) type PersonPlus struct { FirstName string `db:"first_name"` LastName string `db:"last_name"` Email string //AddedAt time.Time `db:"added_at"` } // test Select first pps := []PersonPlus{} // pps lacks added_at destination err := db.SelectContext(ctx, &pps, "SELECT * FROM person") if err == nil { t.Error("Expected missing name from Select to fail, but it did not.") } // test Get pp := PersonPlus{} err = db.GetContext(ctx, &pp, "SELECT * FROM person LIMIT 1") if err == nil { t.Error("Expected missing name Get to fail, but it did not.") } // test naked StructScan pps = []PersonPlus{} rows, err := db.QueryContext(ctx, "SELECT * FROM person LIMIT 1") if err != nil { t.Fatal(err) } rows.Next() err = StructScan(rows, &pps) if err == nil { t.Error("Expected missing name in StructScan to fail, but it did not.") } rows.Close() // now try various things with unsafe set. 
db = db.Unsafe() pps = []PersonPlus{} err = db.SelectContext(ctx, &pps, "SELECT * FROM person") if err != nil { t.Error(err) } // test Get pp = PersonPlus{} err = db.GetContext(ctx, &pp, "SELECT * FROM person LIMIT 1") if err != nil { t.Error(err) } // test naked StructScan pps = []PersonPlus{} rowsx, err := db.QueryxContext(ctx, "SELECT * FROM person LIMIT 1") if err != nil { t.Fatal(err) } rowsx.Next() err = StructScan(rowsx, &pps) if err != nil { t.Error(err) } rowsx.Close() // test Named stmt if !isUnsafe(db) { t.Error("Expected db to be unsafe, but it isn't") } nstmt, err := db.PrepareNamedContext(ctx, `SELECT * FROM person WHERE first_name != :name`) if err != nil { t.Fatal(err) } // its internal stmt should be marked unsafe if !nstmt.Stmt.unsafe { t.Error("expected NamedStmt to be unsafe but its underlying stmt did not inherit safety") } pps = []PersonPlus{} err = nstmt.SelectContext(ctx, &pps, map[string]interface{}{"name": "Jason"}) if err != nil { t.Fatal(err) } if len(pps) != 1 { t.Errorf("Expected 1 person back, got %d", len(pps)) } // test it with a safe db db.unsafe = false if isUnsafe(db) { t.Error("expected db to be safe but it isn't") } nstmt, err = db.PrepareNamedContext(ctx, `SELECT * FROM person WHERE first_name != :name`) if err != nil { t.Fatal(err) } // it should be safe if isUnsafe(nstmt) { t.Error("NamedStmt did not inherit safety") } nstmt.Unsafe() if !isUnsafe(nstmt) { t.Error("expected newly unsafed NamedStmt to be unsafe") } pps = []PersonPlus{} err = nstmt.SelectContext(ctx, &pps, map[string]interface{}{"name": "Jason"}) if err != nil { t.Fatal(err) } if len(pps) != 1 { t.Errorf("Expected 1 person back, got %d", len(pps)) } }) } func TestEmbeddedStructsContextContext(t *testing.T) { type Loop1 struct{ Person } type Loop2 struct{ Loop1 } type Loop3 struct{ Loop2 } RunWithSchemaContext(context.Background(), defaultSchema, t, func(ctx context.Context, db *DB, t *testing.T) { loadDefaultFixtureContext(ctx, db, t) peopleAndPlaces := 
[]PersonPlace{} err := db.SelectContext( ctx, &peopleAndPlaces, `SELECT person.*, place.* FROM person natural join place`) if err != nil { t.Fatal(err) } for _, pp := range peopleAndPlaces { if len(pp.Person.FirstName) == 0 { t.Errorf("Expected non zero lengthed first name.") } if len(pp.Place.Country) == 0 { t.Errorf("Expected non zero lengthed country.") } } // test embedded structs with StructScan rows, err := db.QueryxContext( ctx, `SELECT person.*, place.* FROM person natural join place`) if err != nil { t.Error(err) } perp := PersonPlace{} rows.Next() err = rows.StructScan(&perp) if err != nil { t.Error(err) } if len(perp.Person.FirstName) == 0 { t.Errorf("Expected non zero lengthed first name.") } if len(perp.Place.Country) == 0 { t.Errorf("Expected non zero lengthed country.") } rows.Close() // test the same for embedded pointer structs peopleAndPlacesPtrs := []PersonPlacePtr{} err = db.SelectContext( ctx, &peopleAndPlacesPtrs, `SELECT person.*, place.* FROM person natural join place`) if err != nil { t.Fatal(err) } for _, pp := range peopleAndPlacesPtrs { if len(pp.Person.FirstName) == 0 { t.Errorf("Expected non zero lengthed first name.") } if len(pp.Place.Country) == 0 { t.Errorf("Expected non zero lengthed country.") } } // test "deep nesting" l3s := []Loop3{} err = db.SelectContext(ctx, &l3s, `select * from person`) if err != nil { t.Fatal(err) } for _, l3 := range l3s { if len(l3.Loop2.Loop1.Person.FirstName) == 0 { t.Errorf("Expected non zero lengthed first name.") } } // test "embed conflicts" ec := []EmbedConflict{} err = db.SelectContext(ctx, &ec, `select * from person`) // I'm torn between erroring here or having some kind of working behavior // in order to allow for more flexibility in destination structs if err != nil { t.Errorf("Was not expecting an error on embed conflicts.") } }) } func TestJoinQueryContext(t *testing.T) { type Employee struct { Name string ID int64 // BossID is an id into the employee table BossID sql.NullInt64 
`db:"boss_id"` } type Boss Employee RunWithSchemaContext(context.Background(), defaultSchema, t, func(ctx context.Context, db *DB, t *testing.T) { loadDefaultFixtureContext(ctx, db, t) var employees []struct { Employee Boss `db:"boss"` } err := db.SelectContext(ctx, &employees, `SELECT employees.*, boss.id "boss.id", boss.name "boss.name" FROM employees JOIN employees AS boss ON employees.boss_id = boss.id`) if err != nil { t.Fatal(err) } for _, em := range employees { if len(em.Employee.Name) == 0 { t.Errorf("Expected non zero lengthed name.") } if em.Employee.BossID.Int64 != em.Boss.ID { t.Errorf("Expected boss ids to match") } } }) } func TestJoinQueryNamedPointerStructsContext(t *testing.T) { type Employee struct { Name string ID int64 // BossID is an id into the employee table BossID sql.NullInt64 `db:"boss_id"` } type Boss Employee RunWithSchemaContext(context.Background(), defaultSchema, t, func(ctx context.Context, db *DB, t *testing.T) { loadDefaultFixtureContext(ctx, db, t) var employees []struct { Emp1 *Employee `db:"emp1"` Emp2 *Employee `db:"emp2"` *Boss `db:"boss"` } err := db.SelectContext(ctx, &employees, `SELECT emp.name "emp1.name", emp.id "emp1.id", emp.boss_id "emp1.boss_id", emp.name "emp2.name", emp.id "emp2.id", emp.boss_id "emp2.boss_id", boss.id "boss.id", boss.name "boss.name" FROM employees AS emp JOIN employees AS boss ON emp.boss_id = boss.id `) if err != nil { t.Fatal(err) } for _, em := range employees { if len(em.Emp1.Name) == 0 || len(em.Emp2.Name) == 0 { t.Errorf("Expected non zero lengthed name.") } if em.Emp1.BossID.Int64 != em.Boss.ID || em.Emp2.BossID.Int64 != em.Boss.ID { t.Errorf("Expected boss ids to match") } } }) } func TestSelectSliceMapTimeContext(t *testing.T) { RunWithSchemaContext(context.Background(), defaultSchema, t, func(ctx context.Context, db *DB, t *testing.T) { loadDefaultFixtureContext(ctx, db, t) rows, err := db.QueryxContext(ctx, "SELECT * FROM person") if err != nil { t.Fatal(err) } for rows.Next() { _, 
err := rows.SliceScan() if err != nil { t.Error(err) } } rows, err = db.QueryxContext(ctx, "SELECT * FROM person") if err != nil { t.Fatal(err) } for rows.Next() { m := map[string]interface{}{} err := rows.MapScan(m) if err != nil { t.Error(err) } } }) } func TestNilReceiverContext(t *testing.T) { RunWithSchemaContext(context.Background(), defaultSchema, t, func(ctx context.Context, db *DB, t *testing.T) { loadDefaultFixtureContext(ctx, db, t) var p *Person err := db.GetContext(ctx, p, "SELECT * FROM person LIMIT 1") if err == nil { t.Error("Expected error when getting into nil struct ptr.") } var pp *[]Person err = db.SelectContext(ctx, pp, "SELECT * FROM person") if err == nil { t.Error("Expected an error when selecting into nil slice ptr.") } }) } func TestNamedQueryContext(t *testing.T) { var schema = Schema{ create: ` CREATE TABLE place ( id integer PRIMARY KEY, name text NULL ); CREATE TABLE person ( first_name text NULL, last_name text NULL, email text NULL ); CREATE TABLE placeperson ( first_name text NULL, last_name text NULL, email text NULL, place_id integer NULL ); CREATE TABLE jsperson ( "FIRST" text NULL, last_name text NULL, "EMAIL" text NULL );`, drop: ` drop table person; drop table jsperson; drop table place; drop table placeperson; `, } RunWithSchemaContext(context.Background(), schema, t, func(ctx context.Context, db *DB, t *testing.T) { type Person struct { FirstName sql.NullString `db:"first_name"` LastName sql.NullString `db:"last_name"` Email sql.NullString } p := Person{ FirstName: sql.NullString{String: "ben", Valid: true}, LastName: sql.NullString{String: "doe", Valid: true}, Email: sql.NullString{String: "ben@doe.com", Valid: true}, } q1 := `INSERT INTO person (first_name, last_name, email) VALUES (:first_name, :last_name, :email)` _, err := db.NamedExecContext(ctx, q1, p) if err != nil { log.Fatal(err) } p2 := &Person{} rows, err := db.NamedQueryContext(ctx, "SELECT * FROM person WHERE first_name=:first_name", p) if err != nil { 
log.Fatal(err) } for rows.Next() { err = rows.StructScan(p2) if err != nil { t.Error(err) } if p2.FirstName.String != "ben" { t.Error("Expected first name of `ben`, got " + p2.FirstName.String) } if p2.LastName.String != "doe" { t.Error("Expected first name of `doe`, got " + p2.LastName.String) } } // these are tests for #73; they verify that named queries work if you've // changed the db mapper. This code checks both NamedQuery "ad-hoc" style // queries and NamedStmt queries, which use different code paths internally. old := *db.Mapper type JSONPerson struct { FirstName sql.NullString `json:"FIRST"` LastName sql.NullString `json:"last_name"` Email sql.NullString } jp := JSONPerson{ FirstName: sql.NullString{String: "ben", Valid: true}, LastName: sql.NullString{String: "smith", Valid: true}, Email: sql.NullString{String: "ben@smith.com", Valid: true}, } db.Mapper = reflectx.NewMapperFunc("json", strings.ToUpper) // prepare queries for case sensitivity to test our ToUpper function. // postgres and sqlite accept "", but mysql uses ``; since Go's multi-line // strings are `` we use "" by default and swap out for MySQL pdb := func(s string, db *DB) string { if db.DriverName() == "mysql" { return strings.Replace(s, `"`, "`", -1) } return s } q1 = `INSERT INTO jsperson ("FIRST", last_name, "EMAIL") VALUES (:FIRST, :last_name, :EMAIL)` _, err = db.NamedExecContext(ctx, pdb(q1, db), jp) if err != nil { t.Fatal(err, db.DriverName()) } // Checks that a person pulled out of the db matches the one we put in check := func(t *testing.T, rows *Rows) { jp = JSONPerson{} for rows.Next() { err = rows.StructScan(&jp) if err != nil { t.Error(err) } if jp.FirstName.String != "ben" { t.Errorf("Expected first name of `ben`, got `%s` (%s) ", jp.FirstName.String, db.DriverName()) } if jp.LastName.String != "smith" { t.Errorf("Expected LastName of `smith`, got `%s` (%s)", jp.LastName.String, db.DriverName()) } if jp.Email.String != "ben@smith.com" { t.Errorf("Expected first name of `doe`, 
got `%s` (%s)", jp.Email.String, db.DriverName()) } } } ns, err := db.PrepareNamed(pdb(` SELECT * FROM jsperson WHERE "FIRST"=:FIRST AND last_name=:last_name AND "EMAIL"=:EMAIL `, db)) if err != nil { t.Fatal(err) } rows, err = ns.QueryxContext(ctx, jp) if err != nil { t.Fatal(err) } check(t, rows) // Check exactly the same thing, but with db.NamedQuery, which does not go // through the PrepareNamed/NamedStmt path. rows, err = db.NamedQueryContext(ctx, pdb(` SELECT * FROM jsperson WHERE "FIRST"=:FIRST AND last_name=:last_name AND "EMAIL"=:EMAIL `, db), jp) if err != nil { t.Fatal(err) } check(t, rows) db.Mapper = &old // Test nested structs type Place struct { ID int `db:"id"` Name sql.NullString `db:"name"` } type PlacePerson struct { FirstName sql.NullString `db:"first_name"` LastName sql.NullString `db:"last_name"` Email sql.NullString Place Place `db:"place"` } pl := Place{ Name: sql.NullString{String: "myplace", Valid: true}, } pp := PlacePerson{ FirstName: sql.NullString{String: "ben", Valid: true}, LastName: sql.NullString{String: "doe", Valid: true}, Email: sql.NullString{String: "ben@doe.com", Valid: true}, } q2 := `INSERT INTO place (id, name) VALUES (1, :name)` _, err = db.NamedExecContext(ctx, q2, pl) if err != nil { log.Fatal(err) } id := 1 pp.Place.ID = id q3 := `INSERT INTO placeperson (first_name, last_name, email, place_id) VALUES (:first_name, :last_name, :email, :place.id)` _, err = db.NamedExecContext(ctx, q3, pp) if err != nil { log.Fatal(err) } pp2 := &PlacePerson{} rows, err = db.NamedQueryContext(ctx, ` SELECT first_name, last_name, email, place.id AS "place.id", place.name AS "place.name" FROM placeperson INNER JOIN place ON place.id = placeperson.place_id WHERE place.id=:place.id`, pp) if err != nil { log.Fatal(err) } for rows.Next() { err = rows.StructScan(pp2) if err != nil { t.Error(err) } if pp2.FirstName.String != "ben" { t.Error("Expected first name of `ben`, got " + pp2.FirstName.String) } if pp2.LastName.String != "doe" { 
t.Error("Expected first name of `doe`, got " + pp2.LastName.String) } if pp2.Place.Name.String != "myplace" { t.Error("Expected place name of `myplace`, got " + pp2.Place.Name.String) } if pp2.Place.ID != pp.Place.ID { t.Errorf("Expected place name of %v, got %v", pp.Place.ID, pp2.Place.ID) } } }) } func TestNilInsertsContext(t *testing.T) { var schema = Schema{ create: ` CREATE TABLE tt ( id integer, value text NULL DEFAULT NULL );`, drop: "drop table tt;", } RunWithSchemaContext(context.Background(), schema, t, func(ctx context.Context, db *DB, t *testing.T) { type TT struct { ID int Value *string } var v, v2 TT r := db.Rebind db.MustExecContext(ctx, r(`INSERT INTO tt (id) VALUES (1)`)) db.GetContext(ctx, &v, r(`SELECT * FROM tt`)) if v.ID != 1 { t.Errorf("Expecting id of 1, got %v", v.ID) } if v.Value != nil { t.Errorf("Expecting NULL to map to nil, got %s", *v.Value) } v.ID = 2 // NOTE: this incidentally uncovered a bug which was that named queries with // pointer destinations would not work if the passed value here was not addressable, // as reflectx.FieldByIndexes attempts to allocate nil pointer receivers for // writing. This was fixed by creating & using the reflectx.FieldByIndexesReadOnly // function. This next line is important as it provides the only coverage for this. 
db.NamedExecContext(ctx, `INSERT INTO tt (id, value) VALUES (:id, :value)`, v) db.GetContext(ctx, &v2, r(`SELECT * FROM tt WHERE id=2`)) if v.ID != v2.ID { t.Errorf("%v != %v", v.ID, v2.ID) } if v2.Value != nil { t.Errorf("Expecting NULL to map to nil, got %s", *v.Value) } }) } func TestScanErrorContext(t *testing.T) { var schema = Schema{ create: ` CREATE TABLE kv ( k text, v integer );`, drop: `drop table kv;`, } RunWithSchemaContext(context.Background(), schema, t, func(ctx context.Context, db *DB, t *testing.T) { type WrongTypes struct { K int V string } _, err := db.Exec(db.Rebind("INSERT INTO kv (k, v) VALUES (?, ?)"), "hi", 1) if err != nil { t.Error(err) } rows, err := db.QueryxContext(ctx, "SELECT * FROM kv") if err != nil { t.Error(err) } for rows.Next() { var wt WrongTypes err := rows.StructScan(&wt) if err == nil { t.Errorf("%s: Scanning wrong types into keys should have errored.", db.DriverName()) } } }) } // FIXME: this function is kinda big but it slows things down to be constantly // loading and reloading the schema.. 
func TestUsageContext(t *testing.T) { RunWithSchemaContext(context.Background(), defaultSchema, t, func(ctx context.Context, db *DB, t *testing.T) { loadDefaultFixtureContext(ctx, db, t) slicemembers := []SliceMember{} err := db.SelectContext(ctx, &slicemembers, "SELECT * FROM place ORDER BY telcode ASC") if err != nil { t.Fatal(err) } people := []Person{} err = db.SelectContext(ctx, &people, "SELECT * FROM person ORDER BY first_name ASC") if err != nil { t.Fatal(err) } jason, john := people[0], people[1] if jason.FirstName != "Jason" { t.Errorf("Expecting FirstName of Jason, got %s", jason.FirstName) } if jason.LastName != "Moiron" { t.Errorf("Expecting LastName of Moiron, got %s", jason.LastName) } if jason.Email != "jmoiron@jmoiron.net" { t.Errorf("Expecting Email of jmoiron@jmoiron.net, got %s", jason.Email) } if john.FirstName != "John" || john.LastName != "Doe" || john.Email != "johndoeDNE@gmail.net" { t.Errorf("John Doe's person record not what expected: Got %v\n", john) } jason = Person{} err = db.GetContext(ctx, &jason, db.Rebind("SELECT * FROM person WHERE first_name=?"), "Jason") if err != nil { t.Fatal(err) } if jason.FirstName != "Jason" { t.Errorf("Expecting to get back Jason, but got %v\n", jason.FirstName) } err = db.GetContext(ctx, &jason, db.Rebind("SELECT * FROM person WHERE first_name=?"), "Foobar") if err == nil { t.Errorf("Expecting an error, got nil\n") } if err != sql.ErrNoRows { t.Errorf("Expected sql.ErrNoRows, got %v\n", err) } // The following tests check statement reuse, which was actually a problem // due to copying being done when creating Stmt's which was eventually removed stmt1, err := db.PreparexContext(ctx, db.Rebind("SELECT * FROM person WHERE first_name=?")) if err != nil { t.Fatal(err) } jason = Person{} row := stmt1.QueryRowx("DoesNotExist") row.Scan(&jason) row = stmt1.QueryRowx("DoesNotExist") row.Scan(&jason) err = stmt1.GetContext(ctx, &jason, "DoesNotExist User") if err == nil { t.Error("Expected an error") } err = 
stmt1.GetContext(ctx, &jason, "DoesNotExist User 2") if err == nil { t.Fatal(err) } stmt2, err := db.PreparexContext(ctx, db.Rebind("SELECT * FROM person WHERE first_name=?")) if err != nil { t.Fatal(err) } jason = Person{} tx, err := db.Beginx() if err != nil { t.Fatal(err) } tstmt2 := tx.Stmtx(stmt2) row2 := tstmt2.QueryRowx("Jason") err = row2.StructScan(&jason) if err != nil { t.Error(err) } tx.Commit() places := []*Place{} err = db.SelectContext(ctx, &places, "SELECT telcode FROM place ORDER BY telcode ASC") if err != nil { t.Fatal(err) } usa, singsing, honkers := places[0], places[1], places[2] if usa.TelCode != 1 || honkers.TelCode != 852 || singsing.TelCode != 65 { t.Errorf("Expected integer telcodes to work, got %#v", places) } placesptr := []PlacePtr{} err = db.SelectContext(ctx, &placesptr, "SELECT * FROM place ORDER BY telcode ASC") if err != nil { t.Error(err) } //fmt.Printf("%#v\n%#v\n%#v\n", placesptr[0], placesptr[1], placesptr[2]) // if you have null fields and use SELECT *, you must use sql.Null* in your struct // this test also verifies that you can use either a []Struct{} or a []*Struct{} places2 := []Place{} err = db.SelectContext(ctx, &places2, "SELECT * FROM place ORDER BY telcode ASC") if err != nil { t.Fatal(err) } usa, singsing, honkers = &places2[0], &places2[1], &places2[2] // this should return a type error that &p is not a pointer to a struct slice p := Place{} err = db.SelectContext(ctx, &p, "SELECT * FROM place ORDER BY telcode ASC") if err == nil { t.Errorf("Expected an error, argument to select should be a pointer to a struct slice") } // this should be an error pl := []Place{} err = db.SelectContext(ctx, pl, "SELECT * FROM place ORDER BY telcode ASC") if err == nil { t.Errorf("Expected an error, argument to select should be a pointer to a struct slice, not a slice.") } if usa.TelCode != 1 || honkers.TelCode != 852 || singsing.TelCode != 65 { t.Errorf("Expected integer telcodes to work, got %#v", places) } stmt, err := 
db.PreparexContext(ctx, db.Rebind("SELECT country, telcode FROM place WHERE telcode > ? ORDER BY telcode ASC")) if err != nil { t.Error(err) } places = []*Place{} err = stmt.SelectContext(ctx, &places, 10) if len(places) != 2 { t.Error("Expected 2 places, got 0.") } if err != nil { t.Fatal(err) } singsing, honkers = places[0], places[1] if singsing.TelCode != 65 || honkers.TelCode != 852 { t.Errorf("Expected the right telcodes, got %#v", places) } rows, err := db.QueryxContext(ctx, "SELECT * FROM place") if err != nil { t.Fatal(err) } place := Place{} for rows.Next() { err = rows.StructScan(&place) if err != nil { t.Fatal(err) } } rows, err = db.QueryxContext(ctx, "SELECT * FROM place") if err != nil { t.Fatal(err) } m := map[string]interface{}{} for rows.Next() { err = rows.MapScan(m) if err != nil { t.Fatal(err) } _, ok := m["country"] if !ok { t.Errorf("Expected key `country` in map but could not find it (%#v)\n", m) } } rows, err = db.QueryxContext(ctx, "SELECT * FROM place") if err != nil { t.Fatal(err) } for rows.Next() { s, err := rows.SliceScan() if err != nil { t.Error(err) } if len(s) != 3 { t.Errorf("Expected 3 columns in result, got %d\n", len(s)) } } // test advanced querying // test that NamedExec works with a map as well as a struct _, err = db.NamedExecContext(ctx, "INSERT INTO person (first_name, last_name, email) VALUES (:first, :last, :email)", map[string]interface{}{ "first": "Bin", "last": "Smuth", "email": "bensmith@allblacks.nz", }) if err != nil { t.Fatal(err) } // ensure that if the named param happens right at the end it still works // ensure that NamedQuery works with a map[string]interface{} rows, err = db.NamedQueryContext(ctx, "SELECT * FROM person WHERE first_name=:first", map[string]interface{}{"first": "Bin"}) if err != nil { t.Fatal(err) } ben := &Person{} for rows.Next() { err = rows.StructScan(ben) if err != nil { t.Fatal(err) } if ben.FirstName != "Bin" { t.Fatal("Expected first name of `Bin`, got " + ben.FirstName) } if 
ben.LastName != "Smuth" { t.Fatal("Expected first name of `Smuth`, got " + ben.LastName) } } ben.FirstName = "Ben" ben.LastName = "Smith" ben.Email = "binsmuth@allblacks.nz" // Insert via a named query using the struct _, err = db.NamedExecContext(ctx, "INSERT INTO person (first_name, last_name, email) VALUES (:first_name, :last_name, :email)", ben) if err != nil { t.Fatal(err) } rows, err = db.NamedQueryContext(ctx, "SELECT * FROM person WHERE first_name=:first_name", ben) if err != nil { t.Fatal(err) } for rows.Next() { err = rows.StructScan(ben) if err != nil { t.Fatal(err) } if ben.FirstName != "Ben" { t.Fatal("Expected first name of `Ben`, got " + ben.FirstName) } if ben.LastName != "Smith" { t.Fatal("Expected first name of `Smith`, got " + ben.LastName) } } // ensure that Get does not panic on emppty result set person := &Person{} err = db.GetContext(ctx, person, "SELECT * FROM person WHERE first_name=$1", "does-not-exist") if err == nil { t.Fatal("Should have got an error for Get on non-existent row.") } // lets test prepared statements some more stmt, err = db.PreparexContext(ctx, db.Rebind("SELECT * FROM person WHERE first_name=?")) if err != nil { t.Fatal(err) } rows, err = stmt.QueryxContext(ctx, "Ben") if err != nil { t.Fatal(err) } for rows.Next() { err = rows.StructScan(ben) if err != nil { t.Fatal(err) } if ben.FirstName != "Ben" { t.Fatal("Expected first name of `Ben`, got " + ben.FirstName) } if ben.LastName != "Smith" { t.Fatal("Expected first name of `Smith`, got " + ben.LastName) } } john = Person{} stmt, err = db.PreparexContext(ctx, db.Rebind("SELECT * FROM person WHERE first_name=?")) if err != nil { t.Error(err) } err = stmt.GetContext(ctx, &john, "John") if err != nil { t.Error(err) } // test name mapping // THIS USED TO WORK BUT WILL NO LONGER WORK. 
db.MapperFunc(strings.ToUpper) rsa := CPlace{} err = db.GetContext(ctx, &rsa, "SELECT * FROM capplace;") if err != nil { t.Error(err, "in db:", db.DriverName()) } db.MapperFunc(strings.ToLower) // create a copy and change the mapper, then verify the copy behaves // differently from the original. dbCopy := NewDb(db.DB, db.DriverName()) dbCopy.MapperFunc(strings.ToUpper) err = dbCopy.GetContext(ctx, &rsa, "SELECT * FROM capplace;") if err != nil { fmt.Println(db.DriverName()) t.Error(err) } err = db.GetContext(ctx, &rsa, "SELECT * FROM cappplace;") if err == nil { t.Error("Expected no error, got ", err) } // test base type slices var sdest []string rows, err = db.QueryxContext(ctx, "SELECT email FROM person ORDER BY email ASC;") if err != nil { t.Error(err) } err = scanAll(rows, &sdest, false) if err != nil { t.Error(err) } // test Get with base types var count int err = db.GetContext(ctx, &count, "SELECT count(*) FROM person;") if err != nil { t.Error(err) } if count != len(sdest) { t.Errorf("Expected %d == %d (count(*) vs len(SELECT ..)", count, len(sdest)) } // test Get and Select with time.Time, #84 var addedAt time.Time err = db.GetContext(ctx, &addedAt, "SELECT added_at FROM person LIMIT 1;") if err != nil { t.Error(err) } var addedAts []time.Time err = db.SelectContext(ctx, &addedAts, "SELECT added_at FROM person;") if err != nil { t.Error(err) } // test it on a double pointer var pcount *int err = db.GetContext(ctx, &pcount, "SELECT count(*) FROM person;") if err != nil { t.Error(err) } if *pcount != count { t.Errorf("expected %d = %d", *pcount, count) } // test Select... 
sdest = []string{}
		err = db.SelectContext(ctx, &sdest, "SELECT first_name FROM person ORDER BY first_name ASC;")
		if err != nil {
			t.Error(err)
		}

		expected := []string{"Ben", "Bin", "Jason", "John"}
		for i, got := range sdest {
			if got != expected[i] {
				t.Errorf("Expected %d result to be %s, but got %s", i, expected[i], got)
			}
		}

		var nsdest []sql.NullString

		err = db.SelectContext(ctx, &nsdest, "SELECT city FROM place ORDER BY city ASC")
		if err != nil {
			t.Error(err)
		}
		for _, val := range nsdest {
			if val.Valid && val.String != "New York" {
				t.Errorf("expected single valid result to be `New York`, but got %s", val.String)
			}
		}
	})
}

// tests that sqlx will not panic when the wrong driver is passed because
// of an automatic nil dereference in sqlx.Open(), which was fixed.
func TestDoNotPanicOnConnectContext(t *testing.T) {
	_, err := ConnectContext(context.Background(), "bogus", "hehe")
	if err == nil {
		t.Errorf("Should return error when using bogus driverName")
	}
}

// TestEmbeddedMapsContext inserts Message values through NamedExecContext and
// checks that the row count matches and that Properties is populated when a
// row is read back with GetContext.
func TestEmbeddedMapsContext(t *testing.T) {
	var schema = Schema{
		create: `
			CREATE TABLE message (
				string text,
				properties text
			);`,
		drop: `drop table message;`,
	}

	RunWithSchemaContext(context.Background(), schema, t, func(ctx context.Context, db *DB, t *testing.T) {
		messages := []Message{
			{"Hello, World", PropertyMap{"one": "1", "two": "2"}},
			{"Thanks, Joy", PropertyMap{"pull": "request"}},
		}
		q1 := `INSERT INTO message (string, properties) VALUES (:string, :properties);`
		for _, m := range messages {
			_, err := db.NamedExecContext(ctx, q1, m)
			if err != nil {
				t.Fatal(err)
			}
		}
		var count int
		err := db.GetContext(ctx, &count, "SELECT count(*) FROM message")
		if err != nil {
			t.Fatal(err)
		}
		if count != len(messages) {
			t.Fatalf("Expected %d messages in DB, found %d", len(messages), count)
		}

		var m Message
		err = db.GetContext(ctx, &m, "SELECT * FROM message LIMIT 1;")
		if err != nil {
			t.Fatal(err)
		}
		if m.Properties == nil {
			t.Fatal("Expected m.Properties to not be nil, but it was.")
		}
	})
}

// TestIssue197Context is unconditionally skipped; it documents a database/sql
// issue with named []byte types rather than exercising sqlx behavior.
func TestIssue197Context(t *testing.T) {
	// this test actually tests for a bug in database/sql:
	//   https://github.com/golang/go/issues/13905
	// this potentially makes _any_ named type that is an alias for []byte
	// unsafe to use in a lot of different ways (basically, unsafe to hold
	// onto after loading from the database).
	t.Skip()
	type mybyte []byte
	type Var struct{ Raw json.RawMessage }
	type Var2 struct{ Raw []byte }
	type Var3 struct{ Raw mybyte }
	RunWithSchemaContext(context.Background(), defaultSchema, t, func(ctx context.Context, db *DB, t *testing.T) {
		var err error
		var v, q Var
		if err = db.GetContext(ctx, &v, `SELECT '{"a": "b"}' AS raw`); err != nil {
			t.Fatal(err)
		}
		if err = db.GetContext(ctx, &q, `SELECT 'null' AS raw`); err != nil {
			t.Fatal(err)
		}

		var v2, q2 Var2
		if err = db.GetContext(ctx, &v2, `SELECT '{"a": "b"}' AS raw`); err != nil {
			t.Fatal(err)
		}
		if err = db.GetContext(ctx, &q2, `SELECT 'null' AS raw`); err != nil {
			t.Fatal(err)
		}

		var v3, q3 Var3
		if err = db.QueryRowContext(ctx, `SELECT '{"a": "b"}' AS raw`).Scan(&v3.Raw); err != nil {
			t.Fatal(err)
		}
		if err = db.QueryRowContext(ctx, `SELECT '{"c": "d"}' AS raw`).Scan(&q3.Raw); err != nil {
			t.Fatal(err)
		}
		t.Fail()
	})
}

// TestInContext checks In's expansion of slice arguments into bindvars, the
// cases where it must fail, and finally runs an expanded IN query end to end.
func TestInContext(t *testing.T) {
	// some quite normal situations
	type tr struct {
		q    string
		args []interface{}
		c    int
	}
	tests := []tr{
		{"SELECT * FROM foo WHERE x = ? AND v in (?) AND y = ?",
			[]interface{}{"foo", []int{0, 5, 7, 2, 9}, "bar"},
			7},
		{"SELECT * FROM foo WHERE x in (?)",
			[]interface{}{[]int{1, 2, 3, 4, 5, 6, 7, 8}},
			8},
	}
	for _, test := range tests {
		q, a, err := In(test.q, test.args...)
		if err != nil {
			t.Error(err)
		}
		if len(a) != test.c {
			t.Errorf("Expected %d args, but got %d (%+v)", test.c, len(a), a)
		}
		if strings.Count(q, "?") != test.c {
			t.Errorf("Expected %d bindVars, got %d", test.c, strings.Count(q, "?"))
		}
	}

	// too many bindVars, but no slices, so short circuits parsing
	// i'm not sure if this is the right behavior;  this query/arg combo
	// might not work, but we shouldn't parse if we don't need to
	{
		orig := "SELECT * FROM foo WHERE x = ? AND y = ?"
		q, a, err := In(orig, "foo", "bar", "baz")
		if err != nil {
			t.Error(err)
		}
		if len(a) != 3 {
			t.Errorf("Expected 3 args, but got %d (%+v)", len(a), a)
		}
		if q != orig {
			t.Error("Expected unchanged query.")
		}
	}

	tests = []tr{
		// too many bindvars;  slice present so should return error during parse
		{"SELECT * FROM foo WHERE x = ? and y = ?",
			[]interface{}{"foo", []int{1, 2, 3}, "bar"},
			0},
		// empty slice, should return error before parse
		{"SELECT * FROM foo WHERE x = ?",
			[]interface{}{[]int{}},
			0},
		// too *few* bindvars, should return an error
		{"SELECT * FROM foo WHERE x = ? AND y in (?)",
			[]interface{}{[]int{1, 2, 3}},
			0},
	}
	for _, test := range tests {
		_, _, err := In(test.q, test.args...)
		if err == nil {
			t.Error("Expected an error, but got nil.")
		}
	}
	RunWithSchemaContext(context.Background(), defaultSchema, t, func(ctx context.Context, db *DB, t *testing.T) {
		loadDefaultFixtureContext(ctx, db, t)
		//tx.MustExecContext(ctx, tx.Rebind("INSERT INTO place (country, city, telcode) VALUES (?, ?, ?)"), "United States", "New York", "1")
		//tx.MustExecContext(ctx, tx.Rebind("INSERT INTO place (country, telcode) VALUES (?, ?)"), "Hong Kong", "852")
		//tx.MustExecContext(ctx, tx.Rebind("INSERT INTO place (country, telcode) VALUES (?, ?)"), "Singapore", "65")
		telcodes := []int{852, 65}
		q := "SELECT * FROM place WHERE telcode IN(?) ORDER BY telcode"
		query, args, err := In(q, telcodes)
		if err != nil {
			// Nothing below is meaningful without a valid expanded query.
			t.Fatal(err)
		}
		query = db.Rebind(query)
		places := []Place{}
		err = db.SelectContext(ctx, &places, query, args...)
		if err != nil {
			t.Error(err)
		}
		if len(places) != 2 {
			t.Fatalf("Expecting 2 results, got %d", len(places))
		}
		if places[0].TelCode != 65 {
			t.Errorf("Expecting singapore first, but got %#v", places[0])
		}
		if places[1].TelCode != 852 {
			t.Errorf("Expecting hong kong second, but got %#v", places[1])
		}
	})
}

// TestEmbeddedLiteralsContext checks that GetContext fills a *string field
// both in a plain struct and in a struct that also has an inline anonymous
// struct field.
func TestEmbeddedLiteralsContext(t *testing.T) {
	var schema = Schema{
		create: `
			CREATE TABLE x (
				k text
			);`,
		drop: `drop table x;`,
	}

	RunWithSchemaContext(context.Background(), schema, t, func(ctx context.Context, db *DB, t *testing.T) {
		type t1 struct {
			K *string
		}
		type t2 struct {
			Inline struct {
				F string
			}
			K *string
		}

		db.MustExecContext(ctx, db.Rebind("INSERT INTO x (k) VALUES (?), (?), (?);"), "one", "two", "three")

		target := t1{}
		err := db.GetContext(ctx, &target, db.Rebind("SELECT * FROM x WHERE k=?"), "one")
		if err != nil {
			// K is still nil after a failed Get; stop before dereferencing it.
			t.Fatal(err)
		}
		switch {
		case target.K == nil:
			t.Error("Expected target.K to be `one`, got nil")
		case *target.K != "one":
			t.Errorf("Expected target.K to be `one`, got `%s`", *target.K)
		}

		target2 := t2{}
		err = db.GetContext(ctx, &target2, db.Rebind("SELECT * FROM x WHERE k=?"), "one")
		if err != nil {
			t.Fatal(err)
		}
		switch {
		case target2.K == nil:
			t.Error("Expected target2.K to be `one`, got nil")
		case *target2.K != "one":
			t.Errorf("Expected target2.K to be `one`, got `%v`", *target2.K)
		}
	})
}

// TestConn exercises sqlx's *Conn wrapper: Exec, Select, Get, a prepared
// statement rebound into a transaction begun on the same conn, and Queryx.
func TestConn(t *testing.T) {
	var schema = Schema{
		create: `
			CREATE TABLE tt_conn (
				id integer,
				value text NULL DEFAULT NULL
			);`,
		drop: "drop table tt_conn;",
	}

	RunWithSchemaContext(context.Background(), schema, t, func(ctx context.Context, db *DB, t *testing.T) {
		conn, err := db.Connx(ctx)
		if err != nil {
			t.Fatal(err)
		}
		// Deferred only after the error check: on failure conn is nil and
		// conn.Close() would panic. Defers run LIFO, so the rows/tx/stmt
		// cleanups registered below all run before this Close; database/sql's
		// Conn.Close waits for in-flight transactions and result sets.
		defer conn.Close()

		_, err = conn.ExecContext(ctx, conn.Rebind(`INSERT INTO tt_conn (id, value) VALUES (?, ?), (?, ?)`), 1, "a", 2, "b")
		if err != nil {
			t.Fatal(err)
		}

		type s struct {
			ID    int    `db:"id"`
			Value string `db:"value"`
		}

		v := []s{}

		err = conn.SelectContext(ctx, &v, "SELECT * FROM tt_conn ORDER BY id ASC")
		if err != nil {
			t.Fatal(err)
		}
		// Two rows were inserted above; guard the index before using it.
		if len(v) != 2 {
			t.Fatalf("Expecting 2 rows, got %d", len(v))
		}
		if v[0].ID != 1 {
			t.Errorf("Expecting ID of 1, got %d", v[0].ID)
		}

		v1 := s{}
		err = conn.GetContext(ctx, &v1, conn.Rebind("SELECT * FROM tt_conn WHERE id=?"), 1)
		if err != nil {
			t.Fatal(err)
		}
		if v1.ID != 1 {
			t.Errorf("Expecting to get back 1, but got %v\n", v1.ID)
		}

		stmt, err := conn.PreparexContext(ctx, conn.Rebind("SELECT * FROM tt_conn WHERE id=?"))
		if err != nil {
			t.Fatal(err)
		}
		defer stmt.Close()

		v1 = s{}
		tx, err := conn.BeginTxx(ctx, nil)
		if err != nil {
			t.Fatal(err)
		}
		// Ensure the transaction ends even if a check below aborts the test;
		// after a successful Commit this just returns sql.ErrTxDone.
		defer tx.Rollback()

		tstmt := tx.Stmtx(stmt)
		row := tstmt.QueryRowx(1)
		err = row.StructScan(&v1)
		if err != nil {
			t.Error(err)
		}
		if err = tx.Commit(); err != nil {
			t.Fatal(err)
		}
		if v1.ID != 1 {
			t.Errorf("Expecting to get back 1, but got %v\n", v1.ID)
		}

		rows, err := conn.QueryxContext(ctx, "SELECT * FROM tt_conn")
		if err != nil {
			t.Fatal(err)
		}
		defer rows.Close()

		for rows.Next() {
			err = rows.StructScan(&v1)
			if err != nil {
				t.Fatal(err)
			}
		}
		if err = rows.Err(); err != nil {
			t.Fatal(err)
		}
	})
}
sqlx-1.3.5/sqlx_test.go000066400000000000000000001367171422653721300151100ustar00rootroot00000000000000// The following environment variables, if set, will be used:
//
//	* SQLX_SQLITE_DSN
//	* SQLX_POSTGRES_DSN
//	* SQLX_MYSQL_DSN
//
// Set any of these variables to 'skip' to skip them.  Note that for MySQL,
// the parameter 'parseTime=true' will be appended to the DSN if it's not
// there already.
// package sqlx import ( "database/sql" "database/sql/driver" "encoding/json" "fmt" "log" "os" "reflect" "strings" "testing" "time" _ "github.com/go-sql-driver/mysql" "github.com/jmoiron/sqlx/reflectx" _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" ) /* compile time checks that Db, Tx, Stmt (qStmt) implement expected interfaces */ var _, _ Ext = &DB{}, &Tx{} var _, _ ColScanner = &Row{}, &Rows{} var _ Queryer = &qStmt{} var _ Execer = &qStmt{} var TestPostgres = true var TestSqlite = true var TestMysql = true var sldb *DB var pgdb *DB var mysqldb *DB var active = []*DB{} func init() { ConnectAll() } func ConnectAll() { var err error pgdsn := os.Getenv("SQLX_POSTGRES_DSN") mydsn := os.Getenv("SQLX_MYSQL_DSN") sqdsn := os.Getenv("SQLX_SQLITE_DSN") TestPostgres = pgdsn != "skip" TestMysql = mydsn != "skip" TestSqlite = sqdsn != "skip" if !strings.Contains(mydsn, "parseTime=true") { mydsn += "?parseTime=true" } if TestPostgres { pgdb, err = Connect("postgres", pgdsn) if err != nil { fmt.Printf("Disabling PG tests:\n %v\n", err) TestPostgres = false } } else { fmt.Println("Disabling Postgres tests.") } if TestMysql { mysqldb, err = Connect("mysql", mydsn) if err != nil { fmt.Printf("Disabling MySQL tests:\n %v", err) TestMysql = false } } else { fmt.Println("Disabling MySQL tests.") } if TestSqlite { sldb, err = Connect("sqlite3", sqdsn) if err != nil { fmt.Printf("Disabling SQLite:\n %v", err) TestSqlite = false } } else { fmt.Println("Disabling SQLite tests.") } } type Schema struct { create string drop string } func (s Schema) Postgres() (string, string, string) { return s.create, s.drop, `now()` } func (s Schema) MySQL() (string, string, string) { return strings.Replace(s.create, `"`, "`", -1), s.drop, `now()` } func (s Schema) Sqlite3() (string, string, string) { return strings.Replace(s.create, `now()`, `CURRENT_TIMESTAMP`, -1), s.drop, `CURRENT_TIMESTAMP` } var defaultSchema = Schema{ create: ` CREATE TABLE person ( first_name text, last_name text, email 
text, added_at timestamp default now() ); CREATE TABLE place ( country text, city text NULL, telcode integer ); CREATE TABLE capplace ( "COUNTRY" text, "CITY" text NULL, "TELCODE" integer ); CREATE TABLE nullperson ( first_name text NULL, last_name text NULL, email text NULL ); CREATE TABLE employees ( name text, id integer, boss_id integer ); `, drop: ` drop table person; drop table place; drop table capplace; drop table nullperson; drop table employees; `, } type Person struct { FirstName string `db:"first_name"` LastName string `db:"last_name"` Email string AddedAt time.Time `db:"added_at"` } type Person2 struct { FirstName sql.NullString `db:"first_name"` LastName sql.NullString `db:"last_name"` Email sql.NullString } type Place struct { Country string City sql.NullString TelCode int } type PlacePtr struct { Country string City *string TelCode int } type PersonPlace struct { Person Place } type PersonPlacePtr struct { *Person *Place } type EmbedConflict struct { FirstName string `db:"first_name"` Person } type SliceMember struct { Country string City sql.NullString TelCode int People []Person `db:"-"` Addresses []Place `db:"-"` } // Note that because of field map caching, we need a new type here // if we've used Place already somewhere in sqlx type CPlace Place func MultiExec(e Execer, query string) { stmts := strings.Split(query, ";\n") if len(strings.Trim(stmts[len(stmts)-1], " \n\t\r")) == 0 { stmts = stmts[:len(stmts)-1] } for _, s := range stmts { _, err := e.Exec(s) if err != nil { fmt.Println(err, s) } } } func RunWithSchema(schema Schema, t *testing.T, test func(db *DB, t *testing.T, now string)) { runner := func(db *DB, t *testing.T, create, drop, now string) { defer func() { MultiExec(db, drop) }() MultiExec(db, create) test(db, t, now) } if TestPostgres { create, drop, now := schema.Postgres() runner(pgdb, t, create, drop, now) } if TestSqlite { create, drop, now := schema.Sqlite3() runner(sldb, t, create, drop, now) } if TestMysql { create, drop, 
now := schema.MySQL() runner(mysqldb, t, create, drop, now) } } func loadDefaultFixture(db *DB, t *testing.T) { tx := db.MustBegin() tx.MustExec(tx.Rebind("INSERT INTO person (first_name, last_name, email) VALUES (?, ?, ?)"), "Jason", "Moiron", "jmoiron@jmoiron.net") tx.MustExec(tx.Rebind("INSERT INTO person (first_name, last_name, email) VALUES (?, ?, ?)"), "John", "Doe", "johndoeDNE@gmail.net") tx.MustExec(tx.Rebind("INSERT INTO place (country, city, telcode) VALUES (?, ?, ?)"), "United States", "New York", "1") tx.MustExec(tx.Rebind("INSERT INTO place (country, telcode) VALUES (?, ?)"), "Hong Kong", "852") tx.MustExec(tx.Rebind("INSERT INTO place (country, telcode) VALUES (?, ?)"), "Singapore", "65") if db.DriverName() == "mysql" { tx.MustExec(tx.Rebind("INSERT INTO capplace (`COUNTRY`, `TELCODE`) VALUES (?, ?)"), "Sarf Efrica", "27") } else { tx.MustExec(tx.Rebind("INSERT INTO capplace (\"COUNTRY\", \"TELCODE\") VALUES (?, ?)"), "Sarf Efrica", "27") } tx.MustExec(tx.Rebind("INSERT INTO employees (name, id) VALUES (?, ?)"), "Peter", "4444") tx.MustExec(tx.Rebind("INSERT INTO employees (name, id, boss_id) VALUES (?, ?, ?)"), "Joe", "1", "4444") tx.MustExec(tx.Rebind("INSERT INTO employees (name, id, boss_id) VALUES (?, ?, ?)"), "Martin", "2", "4444") tx.Commit() } // Test a new backwards compatible feature, that missing scan destinations // will silently scan into sql.RawText rather than failing/panicing func TestMissingNames(t *testing.T) { RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T, now string) { loadDefaultFixture(db, t) type PersonPlus struct { FirstName string `db:"first_name"` LastName string `db:"last_name"` Email string //AddedAt time.Time `db:"added_at"` } // test Select first pps := []PersonPlus{} // pps lacks added_at destination err := db.Select(&pps, "SELECT * FROM person") if err == nil { t.Error("Expected missing name from Select to fail, but it did not.") } // test Get pp := PersonPlus{} err = db.Get(&pp, "SELECT * FROM person LIMIT 
1") if err == nil { t.Error("Expected missing name Get to fail, but it did not.") } // test naked StructScan pps = []PersonPlus{} rows, err := db.Query("SELECT * FROM person LIMIT 1") if err != nil { t.Fatal(err) } rows.Next() err = StructScan(rows, &pps) if err == nil { t.Error("Expected missing name in StructScan to fail, but it did not.") } rows.Close() // now try various things with unsafe set. db = db.Unsafe() pps = []PersonPlus{} err = db.Select(&pps, "SELECT * FROM person") if err != nil { t.Error(err) } // test Get pp = PersonPlus{} err = db.Get(&pp, "SELECT * FROM person LIMIT 1") if err != nil { t.Error(err) } // test naked StructScan pps = []PersonPlus{} rowsx, err := db.Queryx("SELECT * FROM person LIMIT 1") if err != nil { t.Fatal(err) } rowsx.Next() err = StructScan(rowsx, &pps) if err != nil { t.Error(err) } rowsx.Close() // test Named stmt if !isUnsafe(db) { t.Error("Expected db to be unsafe, but it isn't") } nstmt, err := db.PrepareNamed(`SELECT * FROM person WHERE first_name != :name`) if err != nil { t.Fatal(err) } // its internal stmt should be marked unsafe if !nstmt.Stmt.unsafe { t.Error("expected NamedStmt to be unsafe but its underlying stmt did not inherit safety") } pps = []PersonPlus{} err = nstmt.Select(&pps, map[string]interface{}{"name": "Jason"}) if err != nil { t.Fatal(err) } if len(pps) != 1 { t.Errorf("Expected 1 person back, got %d", len(pps)) } // test it with a safe db db.unsafe = false if isUnsafe(db) { t.Error("expected db to be safe but it isn't") } nstmt, err = db.PrepareNamed(`SELECT * FROM person WHERE first_name != :name`) if err != nil { t.Fatal(err) } // it should be safe if isUnsafe(nstmt) { t.Error("NamedStmt did not inherit safety") } nstmt.Unsafe() if !isUnsafe(nstmt) { t.Error("expected newly unsafed NamedStmt to be unsafe") } pps = []PersonPlus{} err = nstmt.Select(&pps, map[string]interface{}{"name": "Jason"}) if err != nil { t.Fatal(err) } if len(pps) != 1 { t.Errorf("Expected 1 person back, got %d", len(pps)) } 
}) } func TestEmbeddedStructs(t *testing.T) { type Loop1 struct{ Person } type Loop2 struct{ Loop1 } type Loop3 struct{ Loop2 } RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T, now string) { loadDefaultFixture(db, t) peopleAndPlaces := []PersonPlace{} err := db.Select( &peopleAndPlaces, `SELECT person.*, place.* FROM person natural join place`) if err != nil { t.Fatal(err) } for _, pp := range peopleAndPlaces { if len(pp.Person.FirstName) == 0 { t.Errorf("Expected non zero lengthed first name.") } if len(pp.Place.Country) == 0 { t.Errorf("Expected non zero lengthed country.") } } // test embedded structs with StructScan rows, err := db.Queryx( `SELECT person.*, place.* FROM person natural join place`) if err != nil { t.Error(err) } perp := PersonPlace{} rows.Next() err = rows.StructScan(&perp) if err != nil { t.Error(err) } if len(perp.Person.FirstName) == 0 { t.Errorf("Expected non zero lengthed first name.") } if len(perp.Place.Country) == 0 { t.Errorf("Expected non zero lengthed country.") } rows.Close() // test the same for embedded pointer structs peopleAndPlacesPtrs := []PersonPlacePtr{} err = db.Select( &peopleAndPlacesPtrs, `SELECT person.*, place.* FROM person natural join place`) if err != nil { t.Fatal(err) } for _, pp := range peopleAndPlacesPtrs { if len(pp.Person.FirstName) == 0 { t.Errorf("Expected non zero lengthed first name.") } if len(pp.Place.Country) == 0 { t.Errorf("Expected non zero lengthed country.") } } // test "deep nesting" l3s := []Loop3{} err = db.Select(&l3s, `select * from person`) if err != nil { t.Fatal(err) } for _, l3 := range l3s { if len(l3.Loop2.Loop1.Person.FirstName) == 0 { t.Errorf("Expected non zero lengthed first name.") } } // test "embed conflicts" ec := []EmbedConflict{} err = db.Select(&ec, `select * from person`) // I'm torn between erroring here or having some kind of working behavior // in order to allow for more flexibility in destination structs if err != nil { t.Errorf("Was not expecting an error on 
embed conflicts.") } }) } func TestJoinQuery(t *testing.T) { type Employee struct { Name string ID int64 // BossID is an id into the employee table BossID sql.NullInt64 `db:"boss_id"` } type Boss Employee RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T, now string) { loadDefaultFixture(db, t) var employees []struct { Employee Boss `db:"boss"` } err := db.Select( &employees, `SELECT employees.*, boss.id "boss.id", boss.name "boss.name" FROM employees JOIN employees AS boss ON employees.boss_id = boss.id`) if err != nil { t.Fatal(err) } for _, em := range employees { if len(em.Employee.Name) == 0 { t.Errorf("Expected non zero lengthed name.") } if em.Employee.BossID.Int64 != em.Boss.ID { t.Errorf("Expected boss ids to match") } } }) } func TestJoinQueryNamedPointerStructs(t *testing.T) { type Employee struct { Name string ID int64 // BossID is an id into the employee table BossID sql.NullInt64 `db:"boss_id"` } type Boss Employee RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T, now string) { loadDefaultFixture(db, t) var employees []struct { Emp1 *Employee `db:"emp1"` Emp2 *Employee `db:"emp2"` *Boss `db:"boss"` } err := db.Select( &employees, `SELECT emp.name "emp1.name", emp.id "emp1.id", emp.boss_id "emp1.boss_id", emp.name "emp2.name", emp.id "emp2.id", emp.boss_id "emp2.boss_id", boss.id "boss.id", boss.name "boss.name" FROM employees AS emp JOIN employees AS boss ON emp.boss_id = boss.id `) if err != nil { t.Fatal(err) } for _, em := range employees { if len(em.Emp1.Name) == 0 || len(em.Emp2.Name) == 0 { t.Errorf("Expected non zero lengthed name.") } if em.Emp1.BossID.Int64 != em.Boss.ID || em.Emp2.BossID.Int64 != em.Boss.ID { t.Errorf("Expected boss ids to match") } } }) } func TestSelectSliceMapTime(t *testing.T) { RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T, now string) { loadDefaultFixture(db, t) rows, err := db.Queryx("SELECT * FROM person") if err != nil { t.Fatal(err) } for rows.Next() { _, err := rows.SliceScan() if err 
!= nil { t.Error(err) } } rows, err = db.Queryx("SELECT * FROM person") if err != nil { t.Fatal(err) } for rows.Next() { m := map[string]interface{}{} err := rows.MapScan(m) if err != nil { t.Error(err) } } }) } func TestNilReceiver(t *testing.T) { RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T, now string) { loadDefaultFixture(db, t) var p *Person err := db.Get(p, "SELECT * FROM person LIMIT 1") if err == nil { t.Error("Expected error when getting into nil struct ptr.") } var pp *[]Person err = db.Select(pp, "SELECT * FROM person") if err == nil { t.Error("Expected an error when selecting into nil slice ptr.") } }) } func TestNamedQuery(t *testing.T) { var schema = Schema{ create: ` CREATE TABLE place ( id integer PRIMARY KEY, name text NULL ); CREATE TABLE person ( first_name text NULL, last_name text NULL, email text NULL ); CREATE TABLE placeperson ( first_name text NULL, last_name text NULL, email text NULL, place_id integer NULL ); CREATE TABLE jsperson ( "FIRST" text NULL, last_name text NULL, "EMAIL" text NULL );`, drop: ` drop table person; drop table jsperson; drop table place; drop table placeperson; `, } RunWithSchema(schema, t, func(db *DB, t *testing.T, now string) { type Person struct { FirstName sql.NullString `db:"first_name"` LastName sql.NullString `db:"last_name"` Email sql.NullString } p := Person{ FirstName: sql.NullString{String: "ben", Valid: true}, LastName: sql.NullString{String: "doe", Valid: true}, Email: sql.NullString{String: "ben@doe.com", Valid: true}, } q1 := `INSERT INTO person (first_name, last_name, email) VALUES (:first_name, :last_name, :email)` _, err := db.NamedExec(q1, p) if err != nil { log.Fatal(err) } p2 := &Person{} rows, err := db.NamedQuery("SELECT * FROM person WHERE first_name=:first_name", p) if err != nil { log.Fatal(err) } for rows.Next() { err = rows.StructScan(p2) if err != nil { t.Error(err) } if p2.FirstName.String != "ben" { t.Error("Expected first name of `ben`, got " + p2.FirstName.String) } if 
p2.LastName.String != "doe" { t.Error("Expected first name of `doe`, got " + p2.LastName.String) } } // these are tests for #73; they verify that named queries work if you've // changed the db mapper. This code checks both NamedQuery "ad-hoc" style // queries and NamedStmt queries, which use different code paths internally. old := *db.Mapper type JSONPerson struct { FirstName sql.NullString `json:"FIRST"` LastName sql.NullString `json:"last_name"` Email sql.NullString } jp := JSONPerson{ FirstName: sql.NullString{String: "ben", Valid: true}, LastName: sql.NullString{String: "smith", Valid: true}, Email: sql.NullString{String: "ben@smith.com", Valid: true}, } db.Mapper = reflectx.NewMapperFunc("json", strings.ToUpper) // prepare queries for case sensitivity to test our ToUpper function. // postgres and sqlite accept "", but mysql uses ``; since Go's multi-line // strings are `` we use "" by default and swap out for MySQL pdb := func(s string, db *DB) string { if db.DriverName() == "mysql" { return strings.Replace(s, `"`, "`", -1) } return s } q1 = `INSERT INTO jsperson ("FIRST", last_name, "EMAIL") VALUES (:FIRST, :last_name, :EMAIL)` _, err = db.NamedExec(pdb(q1, db), jp) if err != nil { t.Fatal(err, db.DriverName()) } // Checks that a person pulled out of the db matches the one we put in check := func(t *testing.T, rows *Rows) { jp = JSONPerson{} for rows.Next() { err = rows.StructScan(&jp) if err != nil { t.Error(err) } if jp.FirstName.String != "ben" { t.Errorf("Expected first name of `ben`, got `%s` (%s) ", jp.FirstName.String, db.DriverName()) } if jp.LastName.String != "smith" { t.Errorf("Expected LastName of `smith`, got `%s` (%s)", jp.LastName.String, db.DriverName()) } if jp.Email.String != "ben@smith.com" { t.Errorf("Expected first name of `doe`, got `%s` (%s)", jp.Email.String, db.DriverName()) } } } ns, err := db.PrepareNamed(pdb(` SELECT * FROM jsperson WHERE "FIRST"=:FIRST AND last_name=:last_name AND "EMAIL"=:EMAIL `, db)) if err != nil { 
t.Fatal(err) } rows, err = ns.Queryx(jp) if err != nil { t.Fatal(err) } check(t, rows) // Check exactly the same thing, but with db.NamedQuery, which does not go // through the PrepareNamed/NamedStmt path. rows, err = db.NamedQuery(pdb(` SELECT * FROM jsperson WHERE "FIRST"=:FIRST AND last_name=:last_name AND "EMAIL"=:EMAIL `, db), jp) if err != nil { t.Fatal(err) } check(t, rows) db.Mapper = &old // Test nested structs type Place struct { ID int `db:"id"` Name sql.NullString `db:"name"` } type PlacePerson struct { FirstName sql.NullString `db:"first_name"` LastName sql.NullString `db:"last_name"` Email sql.NullString Place Place `db:"place"` } pl := Place{ Name: sql.NullString{String: "myplace", Valid: true}, } pp := PlacePerson{ FirstName: sql.NullString{String: "ben", Valid: true}, LastName: sql.NullString{String: "doe", Valid: true}, Email: sql.NullString{String: "ben@doe.com", Valid: true}, } q2 := `INSERT INTO place (id, name) VALUES (1, :name)` _, err = db.NamedExec(q2, pl) if err != nil { log.Fatal(err) } id := 1 pp.Place.ID = id q3 := `INSERT INTO placeperson (first_name, last_name, email, place_id) VALUES (:first_name, :last_name, :email, :place.id)` _, err = db.NamedExec(q3, pp) if err != nil { log.Fatal(err) } pp2 := &PlacePerson{} rows, err = db.NamedQuery(` SELECT first_name, last_name, email, place.id AS "place.id", place.name AS "place.name" FROM placeperson INNER JOIN place ON place.id = placeperson.place_id WHERE place.id=:place.id`, pp) if err != nil { log.Fatal(err) } for rows.Next() { err = rows.StructScan(pp2) if err != nil { t.Error(err) } if pp2.FirstName.String != "ben" { t.Error("Expected first name of `ben`, got " + pp2.FirstName.String) } if pp2.LastName.String != "doe" { t.Error("Expected first name of `doe`, got " + pp2.LastName.String) } if pp2.Place.Name.String != "myplace" { t.Error("Expected place name of `myplace`, got " + pp2.Place.Name.String) } if pp2.Place.ID != pp.Place.ID { t.Errorf("Expected place name of %v, got %v", 
pp.Place.ID, pp2.Place.ID) } } }) } func TestNilInserts(t *testing.T) { var schema = Schema{ create: ` CREATE TABLE tt ( id integer, value text NULL DEFAULT NULL );`, drop: "drop table tt;", } RunWithSchema(schema, t, func(db *DB, t *testing.T, now string) { type TT struct { ID int Value *string } var v, v2 TT r := db.Rebind db.MustExec(r(`INSERT INTO tt (id) VALUES (1)`)) db.Get(&v, r(`SELECT * FROM tt`)) if v.ID != 1 { t.Errorf("Expecting id of 1, got %v", v.ID) } if v.Value != nil { t.Errorf("Expecting NULL to map to nil, got %s", *v.Value) } v.ID = 2 // NOTE: this incidentally uncovered a bug which was that named queries with // pointer destinations would not work if the passed value here was not addressable, // as reflectx.FieldByIndexes attempts to allocate nil pointer receivers for // writing. This was fixed by creating & using the reflectx.FieldByIndexesReadOnly // function. This next line is important as it provides the only coverage for this. db.NamedExec(`INSERT INTO tt (id, value) VALUES (:id, :value)`, v) db.Get(&v2, r(`SELECT * FROM tt WHERE id=2`)) if v.ID != v2.ID { t.Errorf("%v != %v", v.ID, v2.ID) } if v2.Value != nil { t.Errorf("Expecting NULL to map to nil, got %s", *v.Value) } }) } func TestScanError(t *testing.T) { var schema = Schema{ create: ` CREATE TABLE kv ( k text, v integer );`, drop: `drop table kv;`, } RunWithSchema(schema, t, func(db *DB, t *testing.T, now string) { type WrongTypes struct { K int V string } _, err := db.Exec(db.Rebind("INSERT INTO kv (k, v) VALUES (?, ?)"), "hi", 1) if err != nil { t.Error(err) } rows, err := db.Queryx("SELECT * FROM kv") if err != nil { t.Error(err) } for rows.Next() { var wt WrongTypes err := rows.StructScan(&wt) if err == nil { t.Errorf("%s: Scanning wrong types into keys should have errored.", db.DriverName()) } } }) } func TestMultiInsert(t *testing.T) { RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T, now string) { loadDefaultFixture(db, t) q := db.Rebind(`INSERT INTO employees (name, 
id) VALUES (?, ?), (?, ?);`) db.MustExec(q, "Name1", 400, "name2", 500, ) }) } // FIXME: this function is kinda big but it slows things down to be constantly // loading and reloading the schema.. func TestUsage(t *testing.T) { RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T, now string) { loadDefaultFixture(db, t) slicemembers := []SliceMember{} err := db.Select(&slicemembers, "SELECT * FROM place ORDER BY telcode ASC") if err != nil { t.Fatal(err) } people := []Person{} err = db.Select(&people, "SELECT * FROM person ORDER BY first_name ASC") if err != nil { t.Fatal(err) } jason, john := people[0], people[1] if jason.FirstName != "Jason" { t.Errorf("Expecting FirstName of Jason, got %s", jason.FirstName) } if jason.LastName != "Moiron" { t.Errorf("Expecting LastName of Moiron, got %s", jason.LastName) } if jason.Email != "jmoiron@jmoiron.net" { t.Errorf("Expecting Email of jmoiron@jmoiron.net, got %s", jason.Email) } if john.FirstName != "John" || john.LastName != "Doe" || john.Email != "johndoeDNE@gmail.net" { t.Errorf("John Doe's person record not what expected: Got %v\n", john) } jason = Person{} err = db.Get(&jason, db.Rebind("SELECT * FROM person WHERE first_name=?"), "Jason") if err != nil { t.Fatal(err) } if jason.FirstName != "Jason" { t.Errorf("Expecting to get back Jason, but got %v\n", jason.FirstName) } err = db.Get(&jason, db.Rebind("SELECT * FROM person WHERE first_name=?"), "Foobar") if err == nil { t.Errorf("Expecting an error, got nil\n") } if err != sql.ErrNoRows { t.Errorf("Expected sql.ErrNoRows, got %v\n", err) } // The following tests check statement reuse, which was actually a problem // due to copying being done when creating Stmt's which was eventually removed stmt1, err := db.Preparex(db.Rebind("SELECT * FROM person WHERE first_name=?")) if err != nil { t.Fatal(err) } jason = Person{} row := stmt1.QueryRowx("DoesNotExist") row.Scan(&jason) row = stmt1.QueryRowx("DoesNotExist") row.Scan(&jason) err = stmt1.Get(&jason, 
"DoesNotExist User") if err == nil { t.Error("Expected an error") } err = stmt1.Get(&jason, "DoesNotExist User 2") if err == nil { t.Fatal(err) } stmt2, err := db.Preparex(db.Rebind("SELECT * FROM person WHERE first_name=?")) if err != nil { t.Fatal(err) } jason = Person{} tx, err := db.Beginx() if err != nil { t.Fatal(err) } tstmt2 := tx.Stmtx(stmt2) row2 := tstmt2.QueryRowx("Jason") err = row2.StructScan(&jason) if err != nil { t.Error(err) } tx.Commit() places := []*Place{} err = db.Select(&places, "SELECT telcode FROM place ORDER BY telcode ASC") if err != nil { t.Fatal(err) } usa, singsing, honkers := places[0], places[1], places[2] if usa.TelCode != 1 || honkers.TelCode != 852 || singsing.TelCode != 65 { t.Errorf("Expected integer telcodes to work, got %#v", places) } placesptr := []PlacePtr{} err = db.Select(&placesptr, "SELECT * FROM place ORDER BY telcode ASC") if err != nil { t.Error(err) } //fmt.Printf("%#v\n%#v\n%#v\n", placesptr[0], placesptr[1], placesptr[2]) // if you have null fields and use SELECT *, you must use sql.Null* in your struct // this test also verifies that you can use either a []Struct{} or a []*Struct{} places2 := []Place{} err = db.Select(&places2, "SELECT * FROM place ORDER BY telcode ASC") if err != nil { t.Fatal(err) } usa, singsing, honkers = &places2[0], &places2[1], &places2[2] // this should return a type error that &p is not a pointer to a struct slice p := Place{} err = db.Select(&p, "SELECT * FROM place ORDER BY telcode ASC") if err == nil { t.Errorf("Expected an error, argument to select should be a pointer to a struct slice") } // this should be an error pl := []Place{} err = db.Select(pl, "SELECT * FROM place ORDER BY telcode ASC") if err == nil { t.Errorf("Expected an error, argument to select should be a pointer to a struct slice, not a slice.") } if usa.TelCode != 1 || honkers.TelCode != 852 || singsing.TelCode != 65 { t.Errorf("Expected integer telcodes to work, got %#v", places) } stmt, err := 
db.Preparex(db.Rebind("SELECT country, telcode FROM place WHERE telcode > ? ORDER BY telcode ASC")) if err != nil { t.Error(err) } places = []*Place{} err = stmt.Select(&places, 10) if len(places) != 2 { t.Error("Expected 2 places, got 0.") } if err != nil { t.Fatal(err) } singsing, honkers = places[0], places[1] if singsing.TelCode != 65 || honkers.TelCode != 852 { t.Errorf("Expected the right telcodes, got %#v", places) } rows, err := db.Queryx("SELECT * FROM place") if err != nil { t.Fatal(err) } place := Place{} for rows.Next() { err = rows.StructScan(&place) if err != nil { t.Fatal(err) } } rows, err = db.Queryx("SELECT * FROM place") if err != nil { t.Fatal(err) } m := map[string]interface{}{} for rows.Next() { err = rows.MapScan(m) if err != nil { t.Fatal(err) } _, ok := m["country"] if !ok { t.Errorf("Expected key `country` in map but could not find it (%#v)\n", m) } } rows, err = db.Queryx("SELECT * FROM place") if err != nil { t.Fatal(err) } for rows.Next() { s, err := rows.SliceScan() if err != nil { t.Error(err) } if len(s) != 3 { t.Errorf("Expected 3 columns in result, got %d\n", len(s)) } } // test advanced querying // test that NamedExec works with a map as well as a struct _, err = db.NamedExec("INSERT INTO person (first_name, last_name, email) VALUES (:first, :last, :email)", map[string]interface{}{ "first": "Bin", "last": "Smuth", "email": "bensmith@allblacks.nz", }) if err != nil { t.Fatal(err) } // ensure that if the named param happens right at the end it still works // ensure that NamedQuery works with a map[string]interface{} rows, err = db.NamedQuery("SELECT * FROM person WHERE first_name=:first", map[string]interface{}{"first": "Bin"}) if err != nil { t.Fatal(err) } ben := &Person{} for rows.Next() { err = rows.StructScan(ben) if err != nil { t.Fatal(err) } if ben.FirstName != "Bin" { t.Fatal("Expected first name of `Bin`, got " + ben.FirstName) } if ben.LastName != "Smuth" { t.Fatal("Expected first name of `Smuth`, got " + ben.LastName) } 
} ben.FirstName = "Ben" ben.LastName = "Smith" ben.Email = "binsmuth@allblacks.nz" // Insert via a named query using the struct _, err = db.NamedExec("INSERT INTO person (first_name, last_name, email) VALUES (:first_name, :last_name, :email)", ben) if err != nil { t.Fatal(err) } rows, err = db.NamedQuery("SELECT * FROM person WHERE first_name=:first_name", ben) if err != nil { t.Fatal(err) } for rows.Next() { err = rows.StructScan(ben) if err != nil { t.Fatal(err) } if ben.FirstName != "Ben" { t.Fatal("Expected first name of `Ben`, got " + ben.FirstName) } if ben.LastName != "Smith" { t.Fatal("Expected first name of `Smith`, got " + ben.LastName) } } // ensure that Get does not panic on emppty result set person := &Person{} err = db.Get(person, "SELECT * FROM person WHERE first_name=$1", "does-not-exist") if err == nil { t.Fatal("Should have got an error for Get on non-existent row.") } // lets test prepared statements some more stmt, err = db.Preparex(db.Rebind("SELECT * FROM person WHERE first_name=?")) if err != nil { t.Fatal(err) } rows, err = stmt.Queryx("Ben") if err != nil { t.Fatal(err) } for rows.Next() { err = rows.StructScan(ben) if err != nil { t.Fatal(err) } if ben.FirstName != "Ben" { t.Fatal("Expected first name of `Ben`, got " + ben.FirstName) } if ben.LastName != "Smith" { t.Fatal("Expected first name of `Smith`, got " + ben.LastName) } } john = Person{} stmt, err = db.Preparex(db.Rebind("SELECT * FROM person WHERE first_name=?")) if err != nil { t.Error(err) } err = stmt.Get(&john, "John") if err != nil { t.Error(err) } // test name mapping // THIS USED TO WORK BUT WILL NO LONGER WORK. db.MapperFunc(strings.ToUpper) rsa := CPlace{} err = db.Get(&rsa, "SELECT * FROM capplace;") if err != nil { t.Error(err, "in db:", db.DriverName()) } db.MapperFunc(strings.ToLower) // create a copy and change the mapper, then verify the copy behaves // differently from the original. 
dbCopy := NewDb(db.DB, db.DriverName()) dbCopy.MapperFunc(strings.ToUpper) err = dbCopy.Get(&rsa, "SELECT * FROM capplace;") if err != nil { fmt.Println(db.DriverName()) t.Error(err) } err = db.Get(&rsa, "SELECT * FROM cappplace;") if err == nil { t.Error("Expected no error, got ", err) } // test base type slices var sdest []string rows, err = db.Queryx("SELECT email FROM person ORDER BY email ASC;") if err != nil { t.Error(err) } err = scanAll(rows, &sdest, false) if err != nil { t.Error(err) } // test Get with base types var count int err = db.Get(&count, "SELECT count(*) FROM person;") if err != nil { t.Error(err) } if count != len(sdest) { t.Errorf("Expected %d == %d (count(*) vs len(SELECT ..)", count, len(sdest)) } // test Get and Select with time.Time, #84 var addedAt time.Time err = db.Get(&addedAt, "SELECT added_at FROM person LIMIT 1;") if err != nil { t.Error(err) } var addedAts []time.Time err = db.Select(&addedAts, "SELECT added_at FROM person;") if err != nil { t.Error(err) } // test it on a double pointer var pcount *int err = db.Get(&pcount, "SELECT count(*) FROM person;") if err != nil { t.Error(err) } if *pcount != count { t.Errorf("expected %d = %d", *pcount, count) } // test Select... 
sdest = []string{} err = db.Select(&sdest, "SELECT first_name FROM person ORDER BY first_name ASC;") if err != nil { t.Error(err) } expected := []string{"Ben", "Bin", "Jason", "John"} for i, got := range sdest { if got != expected[i] { t.Errorf("Expected %d result to be %s, but got %s", i, expected[i], got) } } var nsdest []sql.NullString err = db.Select(&nsdest, "SELECT city FROM place ORDER BY city ASC") if err != nil { t.Error(err) } for _, val := range nsdest { if val.Valid && val.String != "New York" { t.Errorf("expected single valid result to be `New York`, but got %s", val.String) } } }) } type Product struct { ProductID int } // tests that sqlx will not panic when the wrong driver is passed because // of an automatic nil dereference in sqlx.Open(), which was fixed. func TestDoNotPanicOnConnect(t *testing.T) { db, err := Connect("bogus", "hehe") if err == nil { t.Errorf("Should return error when using bogus driverName") } if db != nil { t.Errorf("Should not return the db on a connect failure") } } func TestRebind(t *testing.T) { q1 := `INSERT INTO foo (a, b, c, d, e, f, g, h, i) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)` q2 := `INSERT INTO foo (a, b, c) VALUES (?, ?, "foo"), ("Hi", ?, ?)` s1 := Rebind(DOLLAR, q1) s2 := Rebind(DOLLAR, q2) if s1 != `INSERT INTO foo (a, b, c, d, e, f, g, h, i) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)` { t.Errorf("q1 failed") } if s2 != `INSERT INTO foo (a, b, c) VALUES ($1, $2, "foo"), ("Hi", $3, $4)` { t.Errorf("q2 failed") } s1 = Rebind(AT, q1) s2 = Rebind(AT, q2) if s1 != `INSERT INTO foo (a, b, c, d, e, f, g, h, i) VALUES (@p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10)` { t.Errorf("q1 failed") } if s2 != `INSERT INTO foo (a, b, c) VALUES (@p1, @p2, "foo"), ("Hi", @p3, @p4)` { t.Errorf("q2 failed") } s1 = Rebind(NAMED, q1) s2 = Rebind(NAMED, q2) ex1 := `INSERT INTO foo (a, b, c, d, e, f, g, h, i) VALUES ` + `(:arg1, :arg2, :arg3, :arg4, :arg5, :arg6, :arg7, :arg8, :arg9, :arg10)` if s1 != ex1 { t.Error("q1 failed on 
Named params") } ex2 := `INSERT INTO foo (a, b, c) VALUES (:arg1, :arg2, "foo"), ("Hi", :arg3, :arg4)` if s2 != ex2 { t.Error("q2 failed on Named params") } } func TestBindMap(t *testing.T) { // Test that it works.. q1 := `INSERT INTO foo (a, b, c, d) VALUES (:name, :age, :first, :last)` am := map[string]interface{}{ "name": "Jason Moiron", "age": 30, "first": "Jason", "last": "Moiron", } bq, args, _ := bindMap(QUESTION, q1, am) expect := `INSERT INTO foo (a, b, c, d) VALUES (?, ?, ?, ?)` if bq != expect { t.Errorf("Interpolation of query failed: got `%v`, expected `%v`\n", bq, expect) } if args[0].(string) != "Jason Moiron" { t.Errorf("Expected `Jason Moiron`, got %v\n", args[0]) } if args[1].(int) != 30 { t.Errorf("Expected 30, got %v\n", args[1]) } if args[2].(string) != "Jason" { t.Errorf("Expected Jason, got %v\n", args[2]) } if args[3].(string) != "Moiron" { t.Errorf("Expected Moiron, got %v\n", args[3]) } } // Test for #117, embedded nil maps type Message struct { Text string `db:"string"` Properties PropertyMap `db:"properties"` // Stored as JSON in the database } type PropertyMap map[string]string // Implement driver.Valuer and sql.Scanner interfaces on PropertyMap func (p PropertyMap) Value() (driver.Value, error) { if len(p) == 0 { return nil, nil } return json.Marshal(p) } func (p PropertyMap) Scan(src interface{}) error { v := reflect.ValueOf(src) if !v.IsValid() || v.CanAddr() && v.IsNil() { return nil } switch ts := src.(type) { case []byte: return json.Unmarshal(ts, &p) case string: return json.Unmarshal([]byte(ts), &p) default: return fmt.Errorf("Could not not decode type %T -> %T", src, p) } } func TestEmbeddedMaps(t *testing.T) { var schema = Schema{ create: ` CREATE TABLE message ( string text, properties text );`, drop: `drop table message;`, } RunWithSchema(schema, t, func(db *DB, t *testing.T, now string) { messages := []Message{ {"Hello, World", PropertyMap{"one": "1", "two": "2"}}, {"Thanks, Joy", PropertyMap{"pull": "request"}}, } q1 := 
`INSERT INTO message (string, properties) VALUES (:string, :properties);` for _, m := range messages { _, err := db.NamedExec(q1, m) if err != nil { t.Fatal(err) } } var count int err := db.Get(&count, "SELECT count(*) FROM message") if err != nil { t.Fatal(err) } if count != len(messages) { t.Fatalf("Expected %d messages in DB, found %d", len(messages), count) } var m Message err = db.Get(&m, "SELECT * FROM message LIMIT 1;") if err != nil { t.Fatal(err) } if m.Properties == nil { t.Fatal("Expected m.Properties to not be nil, but it was.") } }) } func TestIssue197(t *testing.T) { // this test actually tests for a bug in database/sql: // https://github.com/golang/go/issues/13905 // this potentially makes _any_ named type that is an alias for []byte // unsafe to use in a lot of different ways (basically, unsafe to hold // onto after loading from the database). t.Skip() type mybyte []byte type Var struct{ Raw json.RawMessage } type Var2 struct{ Raw []byte } type Var3 struct{ Raw mybyte } RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T, now string) { var err error var v, q Var if err = db.Get(&v, `SELECT '{"a": "b"}' AS raw`); err != nil { t.Fatal(err) } if err = db.Get(&q, `SELECT 'null' AS raw`); err != nil { t.Fatal(err) } var v2, q2 Var2 if err = db.Get(&v2, `SELECT '{"a": "b"}' AS raw`); err != nil { t.Fatal(err) } if err = db.Get(&q2, `SELECT 'null' AS raw`); err != nil { t.Fatal(err) } var v3, q3 Var3 if err = db.QueryRow(`SELECT '{"a": "b"}' AS raw`).Scan(&v3.Raw); err != nil { t.Fatal(err) } if err = db.QueryRow(`SELECT '{"c": "d"}' AS raw`).Scan(&q3.Raw); err != nil { t.Fatal(err) } t.Fail() }) } func TestIn(t *testing.T) { // some quite normal situations type tr struct { q string args []interface{} c int } tests := []tr{ {"SELECT * FROM foo WHERE x = ? AND v in (?) AND y = ?", []interface{}{"foo", []int{0, 5, 7, 2, 9}, "bar"}, 7}, {"SELECT * FROM foo WHERE x in (?)", []interface{}{[]int{1, 2, 3, 4, 5, 6, 7, 8}}, 8}, {"SELECT * FROM foo WHERE x = ? 
AND y in (?)", []interface{}{[]byte("foo"), []int{0, 5, 3}}, 4}, {"SELECT * FROM foo WHERE x = ? AND y IN (?)", []interface{}{sql.NullString{Valid: false}, []string{"a", "b"}}, 3}, } for _, test := range tests { q, a, err := In(test.q, test.args...) if err != nil { t.Error(err) } if len(a) != test.c { t.Errorf("Expected %d args, but got %d (%+v)", test.c, len(a), a) } if strings.Count(q, "?") != test.c { t.Errorf("Expected %d bindVars, got %d", test.c, strings.Count(q, "?")) } } // too many bindVars, but no slices, so short circuits parsing // i'm not sure if this is the right behavior; this query/arg combo // might not work, but we shouldn't parse if we don't need to { orig := "SELECT * FROM foo WHERE x = ? AND y = ?" q, a, err := In(orig, "foo", "bar", "baz") if err != nil { t.Error(err) } if len(a) != 3 { t.Errorf("Expected 3 args, but got %d (%+v)", len(a), a) } if q != orig { t.Error("Expected unchanged query.") } } tests = []tr{ // too many bindvars; slice present so should return error during parse {"SELECT * FROM foo WHERE x = ? and y = ?", []interface{}{"foo", []int{1, 2, 3}, "bar"}, 0}, // empty slice, should return error before parse {"SELECT * FROM foo WHERE x = ?", []interface{}{[]int{}}, 0}, // too *few* bindvars, should return an error {"SELECT * FROM foo WHERE x = ? AND y in (?)", []interface{}{[]int{1, 2, 3}}, 0}, } for _, test := range tests { _, _, err := In(test.q, test.args...) if err == nil { t.Error("Expected an error, but got nil.") } } RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T, now string) { loadDefaultFixture(db, t) //tx.MustExec(tx.Rebind("INSERT INTO place (country, city, telcode) VALUES (?, ?, ?)"), "United States", "New York", "1") //tx.MustExec(tx.Rebind("INSERT INTO place (country, telcode) VALUES (?, ?)"), "Hong Kong", "852") //tx.MustExec(tx.Rebind("INSERT INTO place (country, telcode) VALUES (?, ?)"), "Singapore", "65") telcodes := []int{852, 65} q := "SELECT * FROM place WHERE telcode IN(?) 
ORDER BY telcode" query, args, err := In(q, telcodes) if err != nil { t.Error(err) } query = db.Rebind(query) places := []Place{} err = db.Select(&places, query, args...) if err != nil { t.Error(err) } if len(places) != 2 { t.Fatalf("Expecting 2 results, got %d", len(places)) } if places[0].TelCode != 65 { t.Errorf("Expecting singapore first, but got %#v", places[0]) } if places[1].TelCode != 852 { t.Errorf("Expecting hong kong second, but got %#v", places[1]) } }) } func TestBindStruct(t *testing.T) { var err error q1 := `INSERT INTO foo (a, b, c, d) VALUES (:name, :age, :first, :last)` type tt struct { Name string Age int First string Last string } type tt2 struct { Field1 string `db:"field_1"` Field2 string `db:"field_2"` } type tt3 struct { tt2 Name string } am := tt{"Jason Moiron", 30, "Jason", "Moiron"} bq, args, _ := bindStruct(QUESTION, q1, am, mapper()) expect := `INSERT INTO foo (a, b, c, d) VALUES (?, ?, ?, ?)` if bq != expect { t.Errorf("Interpolation of query failed: got `%v`, expected `%v`\n", bq, expect) } if args[0].(string) != "Jason Moiron" { t.Errorf("Expected `Jason Moiron`, got %v\n", args[0]) } if args[1].(int) != 30 { t.Errorf("Expected 30, got %v\n", args[1]) } if args[2].(string) != "Jason" { t.Errorf("Expected Jason, got %v\n", args[2]) } if args[3].(string) != "Moiron" { t.Errorf("Expected Moiron, got %v\n", args[3]) } am2 := tt2{"Hello", "World"} bq, args, _ = bindStruct(QUESTION, "INSERT INTO foo (a, b) VALUES (:field_2, :field_1)", am2, mapper()) expect = `INSERT INTO foo (a, b) VALUES (?, ?)` if bq != expect { t.Errorf("Interpolation of query failed: got `%v`, expected `%v`\n", bq, expect) } if args[0].(string) != "World" { t.Errorf("Expected 'World', got %s\n", args[0].(string)) } if args[1].(string) != "Hello" { t.Errorf("Expected 'Hello', got %s\n", args[1].(string)) } am3 := tt3{Name: "Hello!"} am3.Field1 = "Hello" am3.Field2 = "World" bq, args, err = bindStruct(QUESTION, "INSERT INTO foo (a, b, c) VALUES (:name, :field_1, 
:field_2)", am3, mapper()) if err != nil { t.Fatal(err) } expect = `INSERT INTO foo (a, b, c) VALUES (?, ?, ?)` if bq != expect { t.Errorf("Interpolation of query failed: got `%v`, expected `%v`\n", bq, expect) } if args[0].(string) != "Hello!" { t.Errorf("Expected 'Hello!', got %s\n", args[0].(string)) } if args[1].(string) != "Hello" { t.Errorf("Expected 'Hello', got %s\n", args[1].(string)) } if args[2].(string) != "World" { t.Errorf("Expected 'World', got %s\n", args[0].(string)) } } func TestEmbeddedLiterals(t *testing.T) { var schema = Schema{ create: ` CREATE TABLE x ( k text );`, drop: `drop table x;`, } RunWithSchema(schema, t, func(db *DB, t *testing.T, now string) { type t1 struct { K *string } type t2 struct { Inline struct { F string } K *string } db.MustExec(db.Rebind("INSERT INTO x (k) VALUES (?), (?), (?);"), "one", "two", "three") target := t1{} err := db.Get(&target, db.Rebind("SELECT * FROM x WHERE k=?"), "one") if err != nil { t.Error(err) } if *target.K != "one" { t.Error("Expected target.K to be `one`, got ", target.K) } target2 := t2{} err = db.Get(&target2, db.Rebind("SELECT * FROM x WHERE k=?"), "one") if err != nil { t.Error(err) } if *target2.K != "one" { t.Errorf("Expected target2.K to be `one`, got `%v`", target2.K) } }) } func BenchmarkBindStruct(b *testing.B) { b.StopTimer() q1 := `INSERT INTO foo (a, b, c, d) VALUES (:name, :age, :first, :last)` type t struct { Name string Age int First string Last string } am := t{"Jason Moiron", 30, "Jason", "Moiron"} b.StartTimer() for i := 0; i < b.N; i++ { bindStruct(DOLLAR, q1, am, mapper()) } } func TestBindNamedMapper(t *testing.T) { type A map[string]interface{} m := reflectx.NewMapperFunc("db", NameMapper) query, args, err := bindNamedMapper(DOLLAR, `select :x`, A{ "x": "X!", }, m) if err != nil { t.Fatal(err) } got := fmt.Sprintf("%s %s", query, args) want := `select $1 [X!]` if got != want { t.Errorf("\ngot: %q\nwant: %q", got, want) } _, _, err = bindNamedMapper(DOLLAR, `select :x`, 
map[string]string{ "x": "X!", }, m) if err == nil { t.Fatal("err is nil") } if !strings.Contains(err.Error(), "unsupported map type") { t.Errorf("wrong error: %s", err) } } func BenchmarkBindMap(b *testing.B) { b.StopTimer() q1 := `INSERT INTO foo (a, b, c, d) VALUES (:name, :age, :first, :last)` am := map[string]interface{}{ "name": "Jason Moiron", "age": 30, "first": "Jason", "last": "Moiron", } b.StartTimer() for i := 0; i < b.N; i++ { bindMap(DOLLAR, q1, am) } } func BenchmarkIn(b *testing.B) { q := `SELECT * FROM foo WHERE x = ? AND v in (?) AND y = ?` for i := 0; i < b.N; i++ { _, _, _ = In(q, []interface{}{"foo", []int{0, 5, 7, 2, 9}, "bar"}...) } } func BenchmarkIn1k(b *testing.B) { q := `SELECT * FROM foo WHERE x = ? AND v in (?) AND y = ?` var vals [1000]interface{} for i := 0; i < b.N; i++ { _, _, _ = In(q, []interface{}{"foo", vals[:], "bar"}...) } } func BenchmarkIn1kInt(b *testing.B) { q := `SELECT * FROM foo WHERE x = ? AND v in (?) AND y = ?` var vals [1000]int for i := 0; i < b.N; i++ { _, _, _ = In(q, []interface{}{"foo", vals[:], "bar"}...) } } func BenchmarkIn1kString(b *testing.B) { q := `SELECT * FROM foo WHERE x = ? AND v in (?) AND y = ?` var vals [1000]string for i := 0; i < b.N; i++ { _, _, _ = In(q, []interface{}{"foo", vals[:], "bar"}...) 
} } func BenchmarkRebind(b *testing.B) { b.StopTimer() q1 := `INSERT INTO foo (a, b, c, d, e, f, g, h, i) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)` q2 := `INSERT INTO foo (a, b, c) VALUES (?, ?, "foo"), ("Hi", ?, ?)` b.StartTimer() for i := 0; i < b.N; i++ { Rebind(DOLLAR, q1) Rebind(DOLLAR, q2) } } func BenchmarkRebindBuffer(b *testing.B) { b.StopTimer() q1 := `INSERT INTO foo (a, b, c, d, e, f, g, h, i) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)` q2 := `INSERT INTO foo (a, b, c) VALUES (?, ?, "foo"), ("Hi", ?, ?)` b.StartTimer() for i := 0; i < b.N; i++ { rebindBuff(DOLLAR, q1) rebindBuff(DOLLAR, q2) } } func TestIn130Regression(t *testing.T) { t.Run("[]interface{}{}", func(t *testing.T) { q, args, err := In("SELECT * FROM people WHERE name IN (?)", []interface{}{[]string{"gopher"}}...) if err != nil { t.Fatal(err) } if q != "SELECT * FROM people WHERE name IN (?)" { t.Errorf("got=%v", q) } t.Log(args) for _, a := range args { switch a.(type) { case string: t.Log("ok: string", a) case *string: t.Error("ng: string pointer", a, *a.(*string)) } } }) t.Run("[]string{}", func(t *testing.T) { q, args, err := In("SELECT * FROM people WHERE name IN (?)", []string{"gopher"}) if err != nil { t.Fatal(err) } if q != "SELECT * FROM people WHERE name IN (?)" { t.Errorf("got=%v", q) } t.Log(args) for _, a := range args { switch a.(type) { case string: t.Log("ok: string", a) case *string: t.Error("ng: string pointer", a, *a.(*string)) } } }) } func TestSelectReset(t *testing.T) { RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T, now string) { loadDefaultFixture(db, t) filledDest := []string{"a", "b", "c"} err := db.Select(&filledDest, "SELECT first_name FROM person ORDER BY first_name ASC;") if err != nil { t.Fatal(err) } if len(filledDest) != 2 { t.Errorf("Expected 2 first names, got %d.", len(filledDest)) } expected := []string{"Jason", "John"} for i, got := range filledDest { if got != expected[i] { t.Errorf("Expected %d result to be %s, but got %s.", i, expected[i], got) } } var 
emptyDest []string err = db.Select(&emptyDest, "SELECT first_name FROM person WHERE first_name = 'Jack';") if err != nil { t.Fatal(err) } // Verify that selecting 0 rows into a nil target didn't create a // non-nil slice. if emptyDest != nil { t.Error("Expected emptyDest to be nil") } }) } sqlx-1.3.5/types/000077500000000000000000000000001422653721300136605ustar00rootroot00000000000000sqlx-1.3.5/types/README.md000066400000000000000000000002661422653721300151430ustar00rootroot00000000000000# types The types package provides some useful types which implement the `sql.Scanner` and `driver.Valuer` interfaces, suitable for use as scan and value targets with database/sql. sqlx-1.3.5/types/types.go000066400000000000000000000102501422653721300153510ustar00rootroot00000000000000package types import ( "bytes" "compress/gzip" "database/sql/driver" "encoding/json" "errors" "io/ioutil" ) // GzippedText is a []byte which transparently gzips data being submitted to // a database and ungzips data being Scanned from a database. type GzippedText []byte // Value implements the driver.Valuer interface, gzipping the raw value of // this GzippedText. func (g GzippedText) Value() (driver.Value, error) { b := make([]byte, 0, len(g)) buf := bytes.NewBuffer(b) w := gzip.NewWriter(buf) w.Write(g) w.Close() return buf.Bytes(), nil } // Scan implements the sql.Scanner interface, ungzipping the value coming off // the wire and storing the raw result in the GzippedText. func (g *GzippedText) Scan(src interface{}) error { var source []byte switch src := src.(type) { case string: source = []byte(src) case []byte: source = src default: return errors.New("Incompatible type for GzippedText") } reader, err := gzip.NewReader(bytes.NewReader(source)) if err != nil { return err } defer reader.Close() b, err := ioutil.ReadAll(reader) if err != nil { return err } *g = GzippedText(b) return nil } // JSONText is a json.RawMessage, which is a []byte underneath. 
// Value() validates the json format in the source, and returns an error if
// the json is not valid. Scan does no validation. JSONText additionally
// implements `Unmarshal`, which unmarshals the json within to an interface{}.
type JSONText json.RawMessage

// emptyJSON is the "{}" document substituted for empty JSONText values.
// It is package-level shared state: its backing array must never be handed
// to code that may write into it (e.g. via append(x[:0], ...)), so places
// that need to store or return "{}" use copyEmptyJSON instead.
var emptyJSON = JSONText("{}")

// copyEmptyJSON returns a freshly allocated "{}" that shares no storage with
// emptyJSON, so the caller may modify or append into it freely.
func copyEmptyJSON() JSONText {
	return append(JSONText(nil), emptyJSON...)
}

// MarshalJSON returns j as the JSON encoding of j. An empty j encodes as
// "{}"; that result is a private copy, so modifying it cannot affect other
// JSONText values.
func (j JSONText) MarshalJSON() ([]byte, error) {
	if len(j) == 0 {
		return copyEmptyJSON(), nil
	}
	return j, nil
}

// UnmarshalJSON sets *j to a copy of data, reusing *j's backing array when
// it is large enough. It returns an error if j is a nil pointer.
func (j *JSONText) UnmarshalJSON(data []byte) error {
	if j == nil {
		return errors.New("JSONText: UnmarshalJSON on nil pointer")
	}
	*j = append((*j)[0:0], data...)
	return nil
}

// Value returns j as a value. This does a validating unmarshal into another
// RawMessage. If j is invalid json, it returns an error. An empty j is
// reported as "{}".
func (j JSONText) Value() (driver.Value, error) {
	var m json.RawMessage
	// Unmarshal runs on this method's copy of j and replaces an empty j with
	// "{}", which is why an empty JSONText is written as "{}".
	var err = j.Unmarshal(&m)
	if err != nil {
		return []byte{}, err
	}
	return []byte(j), nil
}

// Scan stores the src in *j. No validation is done.
//
// src may be a string, a []byte (an empty []byte is stored as "{}"), or nil
// (a SQL NULL), which leaves *j empty. Any other type returns an error.
// The bytes are always copied, so src itself is not retained.
func (j *JSONText) Scan(src interface{}) error {
	var source []byte
	switch t := src.(type) {
	case string:
		source = []byte(t)
	case []byte:
		if len(t) == 0 {
			source = emptyJSON
		} else {
			source = t
		}
	case nil:
		// NULL leaves *j empty but non-nil. A fresh zero-capacity slice is
		// used instead of emptyJSON[:0]: the latter shares emptyJSON's
		// backing array, so a later in-place append (Scan, UnmarshalJSON)
		// would overwrite the shared "{}".
		*j = JSONText{}
		return nil
	default:
		return errors.New("Incompatible type for JSONText")
	}
	// Reuse *j's storage when possible; append copies source, so a
	// driver-owned src buffer is never retained.
	*j = append((*j)[0:0], source...)
	return nil
}

// Unmarshal unmarshals the json in j to v, as in json.Unmarshal. An empty j
// is first replaced with a private copy of "{}" and decoded as such.
func (j *JSONText) Unmarshal(v interface{}) error {
	if len(*j) == 0 {
		*j = copyEmptyJSON()
	}
	return json.Unmarshal([]byte(*j), v)
}

// String supports pretty printing for JSONText types.
func (j JSONText) String() string {
	return string(j)
}

// NullJSONText represents a JSONText that may be null.
// NullJSONText implements the scanner interface so
// it can be used as a scan destination, similar to NullString.
type NullJSONText struct {
	JSONText
	Valid bool // Valid is true if JSONText is not NULL
}

// Scan implements the Scanner interface.
func (n *NullJSONText) Scan(value interface{}) error { if value == nil { n.JSONText, n.Valid = emptyJSON, false return nil } n.Valid = true return n.JSONText.Scan(value) } // Value implements the driver Valuer interface. func (n NullJSONText) Value() (driver.Value, error) { if !n.Valid { return nil, nil } return n.JSONText.Value() } // BitBool is an implementation of a bool for the MySQL type BIT(1). // This type allows you to avoid wasting an entire byte for MySQL's boolean type TINYINT. type BitBool bool // Value implements the driver.Valuer interface, // and turns the BitBool into a bitfield (BIT(1)) for MySQL storage. func (b BitBool) Value() (driver.Value, error) { if b { return []byte{1}, nil } return []byte{0}, nil } // Scan implements the sql.Scanner interface, // and turns the bitfield incoming from MySQL into a BitBool func (b *BitBool) Scan(src interface{}) error { v, ok := src.([]byte) if !ok { return errors.New("bad []byte type assertion") } *b = v[0] == 1 return nil } sqlx-1.3.5/types/types_test.go000066400000000000000000000050141422653721300164120ustar00rootroot00000000000000package types import "testing" func TestGzipText(t *testing.T) { g := GzippedText("Hello, world") v, err := g.Value() if err != nil { t.Errorf("Was not expecting an error") } err = (&g).Scan(v) if err != nil { t.Errorf("Was not expecting an error") } if string(g) != "Hello, world" { t.Errorf("Was expecting the string we sent in (Hello World), got %s", string(g)) } } func TestJSONText(t *testing.T) { j := JSONText(`{"foo": 1, "bar": 2}`) v, err := j.Value() if err != nil { t.Errorf("Was not expecting an error") } err = (&j).Scan(v) if err != nil { t.Errorf("Was not expecting an error") } m := map[string]interface{}{} j.Unmarshal(&m) if m["foo"].(float64) != 1 || m["bar"].(float64) != 2 { t.Errorf("Expected valid json but got some garbage instead? 
%#v", m) } j = JSONText(`{"foo": 1, invalid, false}`) v, err = j.Value() if err == nil { t.Errorf("Was expecting invalid json to fail!") } j = JSONText("") v, err = j.Value() if err != nil { t.Errorf("Was not expecting an error") } err = (&j).Scan(v) if err != nil { t.Errorf("Was not expecting an error") } j = JSONText(nil) v, err = j.Value() if err != nil { t.Errorf("Was not expecting an error") } err = (&j).Scan(v) if err != nil { t.Errorf("Was not expecting an error") } } func TestNullJSONText(t *testing.T) { j := NullJSONText{} err := j.Scan(`{"foo": 1, "bar": 2}`) if err != nil { t.Errorf("Was not expecting an error") } v, err := j.Value() if err != nil { t.Errorf("Was not expecting an error") } err = (&j).Scan(v) if err != nil { t.Errorf("Was not expecting an error") } m := map[string]interface{}{} j.Unmarshal(&m) if m["foo"].(float64) != 1 || m["bar"].(float64) != 2 { t.Errorf("Expected valid json but got some garbage instead? %#v", m) } j = NullJSONText{} err = j.Scan(nil) if err != nil { t.Errorf("Was not expecting an error") } if j.Valid != false { t.Errorf("Expected valid to be false, but got true") } } func TestBitBool(t *testing.T) { // Test true value var b BitBool = true v, err := b.Value() if err != nil { t.Errorf("Cannot return error") } err = (&b).Scan(v) if err != nil { t.Errorf("Was not expecting an error") } if !b { t.Errorf("Was expecting the bool we sent in (true), got %v", b) } // Test false value b = false v, err = b.Value() if err != nil { t.Errorf("Cannot return error") } err = (&b).Scan(v) if err != nil { t.Errorf("Was not expecting an error") } if b { t.Errorf("Was expecting the bool we sent in (false), got %v", b) } }