Compare commits

..

No commits in common. "master" and "v1.0" have entirely different histories.
master ... v1.0

8 changed files with 84 additions and 239 deletions

View File

@ -323,8 +323,10 @@ func Clone(original *Builder) *Builder {
} }
} else if original.Connection.Config.Driver == DriverSqlite3 { } else if original.Connection.Config.Driver == DriverSqlite3 {
newBuilder.Grammar = &Sqlite3Grammar{ newBuilder.Grammar = &Sqlite3Grammar{
MysqlGrammar: &MysqlGrammar{
Prefix: original.Grammar.GetTablePrefix(), Prefix: original.Grammar.GetTablePrefix(),
Builder: &newBuilder, Builder: &newBuilder,
},
} }
} else { } else {
panic("不支持的数据库类型") panic("不支持的数据库类型")
@ -336,8 +338,7 @@ func (b *Builder) Clone() *Builder {
return Clone(b) return Clone(b)
} }
/* /*CloneWithout
CloneWithout
CloneWithoutClone the query without the given properties. CloneWithoutClone the query without the given properties.
*/ */
func CloneWithout(original *Builder, without ...string) *Builder { func CloneWithout(original *Builder, without ...string) *Builder {
@ -434,31 +435,29 @@ func (b *Builder) SelectRaw(expression string, bindings ...[]interface{}) *Build
//CreateSub Creates a subquery and parse it. //CreateSub Creates a subquery and parse it.
func (b *Builder) CreateSub(query interface{}) (string, []interface{}) { func (b *Builder) CreateSub(query interface{}) (string, []interface{}) {
switch v := query.(type) { var builder *Builder
case func(builder *Builder): if bT, ok := query.(*Builder); ok {
builder := CloneBuilderWithTable(b) builder = bT
v(builder) } else if function, ok := query.(func(builder *Builder)); ok {
return b.ParseSub(builder) builder = CloneBuilderWithTable(b)
case string, Expression, *Builder: function(builder)
return b.ParseSub(v) } else if str, ok := query.(string); ok {
} return b.ParseSub(str)
} else {
panic("can not create sub") panic("can not create sub")
} }
return b.ParseSub(builder)
}
/* /*
ParseSub Parse the subquery into SQL and bindings. ParseSub Parse the subquery into SQL and bindings.
*/ */
func (b *Builder) ParseSub(query interface{}) (string, []interface{}) { func (b *Builder) ParseSub(query interface{}) (string, []interface{}) {
switch v := query.(type) { if s, ok := query.(string); ok {
case string: return s, []interface{}{}
return v, []interface{}{} } else if builder, ok := query.(*Builder); ok {
case Expression: return builder.ToSql(), builder.GetBindings()
return string(v), []interface{}{}
case *Builder:
return v.ToSql(), v.GetBindings()
} }
panic("A subquery must be a query builder instance, a Closure, or a string.") panic("A subquery must be a query builder instance, a Closure, or a string.")
} }
@ -1262,7 +1261,6 @@ func (b *Builder) OrWhereNotIn(params ...interface{}) *Builder {
/* /*
WhereNull Add a "where null" clause to the query. WhereNull Add a "where null" clause to the query.
params takes in below order: params takes in below order:
1. column string 1. column string
2. boolean string in [2]string{"and","or"} 2. boolean string in [2]string{"and","or"}
@ -1353,6 +1351,7 @@ WhereBetween Add a where between statement to the query.
1. WhereBetween(column string,values []interface{"min","max"}) 1. WhereBetween(column string,values []interface{"min","max"})
2. WhereBetween(column string,values []interface{"min","max"},"and/or") 2. WhereBetween(column string,values []interface{"min","max"},"and/or")
3. WhereBetween(column string,values []interface{"min","max","and/or",true/false}) 3. WhereBetween(column string,values []interface{"min","max","and/or",true/false})
*/ */
func (b *Builder) WhereBetween(params ...interface{}) *Builder { func (b *Builder) WhereBetween(params ...interface{}) *Builder {
paramsLength := len(params) paramsLength := len(params)
@ -1995,6 +1994,7 @@ func (b *Builder) DoesntExist() (notExists bool, err error) {
/* /*
Aggregate Execute an aggregate function on the database. Aggregate Execute an aggregate function on the database.
*/ */
func (b *Builder) Aggregate(dest interface{}, fn string, column ...string) (result sql.Result, err error) { func (b *Builder) Aggregate(dest interface{}, fn string, column ...string) (result sql.Result, err error) {
b.Dest = dest b.Dest = dest
@ -2308,6 +2308,7 @@ When Apply the callback if the given "value" is truthy.
1. When(true,func(builder *Builder)) 1. When(true,func(builder *Builder))
2. When(true,func(builder *Builder),func(builder *Builder)) //with default callback 2. When(true,func(builder *Builder),func(builder *Builder)) //with default callback
*/ */
func (b *Builder) When(boolean bool, cb ...func(builder *Builder)) *Builder { func (b *Builder) When(boolean bool, cb ...func(builder *Builder)) *Builder {
if boolean { if boolean {
@ -2322,8 +2323,7 @@ func (b *Builder) Value(dest interface{}, column string) (sql.Result, error) {
return b.First(dest, column) return b.First(dest, column)
} }
/* /*Reset
Reset
reset bindings and components reset bindings and components
*/ */
func (b *Builder) Reset(targets ...string) *Builder { func (b *Builder) Reset(targets ...string) *Builder {
@ -2353,10 +2353,6 @@ func (b *Builder) Reset(targets ...string) *Builder {
delete(b.Bindings, TYPE_COLUMN) delete(b.Bindings, TYPE_COLUMN)
delete(b.Components, TYPE_COLUMN) delete(b.Components, TYPE_COLUMN)
b.Columns = nil b.Columns = nil
case TYPE_JOIN:
delete(b.Bindings, TYPE_JOIN)
delete(b.Components, TYPE_JOIN)
b.Joins = nil
} }
} }
return b return b
@ -2386,8 +2382,7 @@ func (b *Builder) After(callback func(builder *Builder, t string, res sql.Result
return b return b
} }
/* /*ApplyBeforeQueryCallBacks
ApplyBeforeQueryCallBacks
Invoke the "before query" modification callbacks. Invoke the "before query" modification callbacks.
*/ */
func (b *Builder) ApplyBeforeQueryCallBacks(t string, items ...map[string]any) { func (b *Builder) ApplyBeforeQueryCallBacks(t string, items ...map[string]any) {

View File

@ -1,10 +0,0 @@
package db
import "testing"
func TestJoinSub(t *testing.T) {
b := NewBuilder(conn)
b.Grammar = &Sqlite3Grammar{}
t.Log(b.On("Select 1 as a, 2 as b", "t1", "true").ToSql())
}

22
db.go
View File

@ -17,20 +17,24 @@ func (d *DB) SetLogger(f func(log Log)) *DB {
} }
func Open(config map[string]DBConfig) *DB { func Open(config map[string]DBConfig) *DB {
if Engine == nil { var configP = make(map[string]*DBConfig)
Engine = &DB{
DatabaseManager: DatabaseManager{
Configs: make(map[string]*DBConfig),
Connections: make(map[string]*Connection),
},
}
}
for name := range config { for name := range config {
c := config[name] c := config[name]
Engine.Configs[name] = &c configP[name] = &c
} }
db := DB{
DatabaseManager: DatabaseManager{
Configs: configP,
Connections: make(map[string]*Connection),
},
}
db.Connection("default")
Engine = &db
return Engine return Engine
} }

85
scan.go
View File

@ -4,14 +4,11 @@ import (
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt"
"reflect" "reflect"
"github.com/samber/lo"
) )
type RowScan struct { type RowScan struct {
count int Count int
} }
func (RowScan) LastInsertId() (int64, error) { func (RowScan) LastInsertId() (int64, error) {
@ -19,7 +16,7 @@ func (RowScan) LastInsertId() (int64, error) {
} }
func (v RowScan) RowsAffected() (int64, error) { func (v RowScan) RowsAffected() (int64, error) {
return int64(v.count), nil return int64(v.Count), nil
} }
// 入口 // 入口
@ -49,7 +46,7 @@ func ScanResult(rows *sql.Rows, dest any) (result RowScan) {
return sMap(rows, dest) return sMap(rows, dest)
} else { } else {
for rows.Next() { for rows.Next() {
result.count++ result.Count++
err := rows.Scan(dest) err := rows.Scan(dest)
if err != nil { if err != nil {
panic(err.Error()) panic(err.Error())
@ -91,12 +88,8 @@ func sStructSlice(rows *sql.Rows, dest any) (result RowScan) {
keys[kFiled] = i keys[kFiled] = i
} }
// 提取 json 字段, json 字段, 不能直接赋值,
// 先赋值给 uint8, 再通过 json.Unmarshal 赋值给字段
var jFields map[int]reflect.Value
for rows.Next() { for rows.Next() {
result.count++ result.Count++
var v, vp reflect.Value var v, vp reflect.Value
if itemIsPtr { if itemIsPtr {
vp = reflect.New(sliceItem.Elem()) vp = reflect.New(sliceItem.Elem())
@ -108,21 +101,10 @@ func sStructSlice(rows *sql.Rows, dest any) (result RowScan) {
scanArgs := make([]any, len(columns)) scanArgs := make([]any, len(columns))
jFields = map[int]reflect.Value{}
// 初始化 map 字段 // 初始化 map 字段
for i, k := range columns { for i, k := range columns {
if f, ok := keys[k]; ok { if f, ok := keys[k]; ok {
switch v.Field(f).Kind() {
case reflect.Slice, reflect.Struct, reflect.Map:
scanArgs[i] = &[]uint8{}
jFields[i] = v.Field(f)
case reflect.String:
scanArgs[i] = &sql.NullString{}
jFields[i] = v.Field(f)
default:
scanArgs[i] = v.Field(f).Addr().Interface() scanArgs[i] = v.Field(f).Addr().Interface()
}
} else { } else {
scanArgs[i] = new(any) scanArgs[i] = new(any)
} }
@ -132,21 +114,6 @@ func sStructSlice(rows *sql.Rows, dest any) (result RowScan) {
panic(err) panic(err)
} }
// json 字符串 反序列化
for i, rv := range jFields {
switch rv.Kind() {
case reflect.String:
rv.SetString(scanArgs[i].(*sql.NullString).String)
default:
if data := *scanArgs[i].(*[]uint8); len(data) > 0 {
v := rv.Addr().Interface()
if err := json.Unmarshal(data, &v); err != nil {
panic(fmt.Sprintf("[%s:%s:%s] => %s", lo.Must1(lo.FindKey(keys, i)), rv.Kind(), scanArgs[i], err))
}
}
}
}
if itemIsPtr { if itemIsPtr {
realDest.Set(reflect.Append(realDest, vp)) realDest.Set(reflect.Append(realDest, vp))
} else { } else {
@ -174,29 +141,14 @@ func sStruct(rows *sql.Rows, dest any) (result RowScan) {
keys[kFiled] = i keys[kFiled] = i
} }
// 提取 json 字段, json 字段, 不能直接赋值,
// 先赋值给 uint8, 再通过 json.Unmarshal 赋值给字段
var jFields map[int]reflect.Value
for rows.Next() { for rows.Next() {
if result.count == 0 { if result.Count == 0 {
scanArgs := make([]any, len(columns)) scanArgs := make([]any, len(columns))
jFields = map[int]reflect.Value{}
// 初始化 map 字段 // 初始化 map 字段
for i, k := range columns { for i, k := range columns {
if f, ok := keys[k]; ok { if f, ok := keys[k]; ok {
switch v.Field(f).Kind() {
case reflect.Slice, reflect.Struct, reflect.Map:
scanArgs[i] = &[]uint8{}
jFields[i] = v.Field(f)
case reflect.String:
scanArgs[i] = &sql.NullString{}
jFields[i] = v.Field(f)
default:
scanArgs[i] = v.Field(f).Addr().Interface() scanArgs[i] = v.Field(f).Addr().Interface()
}
} else { } else {
scanArgs[i] = new(any) scanArgs[i] = new(any)
} }
@ -206,21 +158,6 @@ func sStruct(rows *sql.Rows, dest any) (result RowScan) {
panic(err) panic(err)
} }
// json 字符串 反序列化
for i, rv := range jFields {
switch rv.Kind() {
case reflect.String:
rv.SetString(scanArgs[i].(*sql.NullString).String)
default:
if data := *scanArgs[i].(*[]uint8); len(data) > 0 {
v := rv.Addr().Interface()
if err := json.Unmarshal(data, &v); err != nil {
panic(fmt.Sprintf("[%s:%s:%s] => %s", lo.Must1(lo.FindKey(keys, i)), rv.Kind(), scanArgs[i], err))
}
}
}
}
if realDest.Kind() == reflect.Ptr { if realDest.Kind() == reflect.Ptr {
realDest.Set(vp) realDest.Set(vp)
} else { } else {
@ -228,7 +165,7 @@ func sStruct(rows *sql.Rows, dest any) (result RowScan) {
} }
} }
result.count++ result.Count++
} }
return result return result
@ -243,7 +180,7 @@ func sMapSlice(rows *sql.Rows, dest any) (result RowScan) {
scanArgs := make([]any, len(columns)) scanArgs := make([]any, len(columns))
element := make(map[string]any) element := make(map[string]any)
result.count++ result.Count++
// 初始化 map 字段 // 初始化 map 字段
for i := 0; i < len(columns); i++ { for i := 0; i < len(columns); i++ {
@ -260,7 +197,7 @@ func sMapSlice(rows *sql.Rows, dest any) (result RowScan) {
switch v := (*scanArgs[i].(*any)).(type) { switch v := (*scanArgs[i].(*any)).(type) {
case []uint8: case []uint8:
data := string(v) data := string(v)
if len(v) > 2 && (data[0:1] == "{" || data[0:1] == "[") { if len(data) > 2 && (data[0:1] == "{" || data[0:1] == "[") {
jData := new(any) jData := new(any)
if err := json.Unmarshal(v, jData); err == nil { if err := json.Unmarshal(v, jData); err == nil {
element[column] = *jData element[column] = *jData
@ -286,7 +223,7 @@ func sMap(rows *sql.Rows, dest any) (result RowScan) {
realDest := reflect.Indirect(reflect.ValueOf(dest)) realDest := reflect.Indirect(reflect.ValueOf(dest))
for rows.Next() { for rows.Next() {
if result.count == 0 { if result.Count == 0 {
scanArgs := make([]interface{}, len(columns)) scanArgs := make([]interface{}, len(columns))
// 初始化 map 字段 // 初始化 map 字段
for i := 0; i < len(columns); i++ { for i := 0; i < len(columns); i++ {
@ -318,7 +255,7 @@ func sMap(rows *sql.Rows, dest any) (result RowScan) {
} }
} }
result.count++ result.Count++
} }
return result return result
@ -330,7 +267,7 @@ func sValues(rows *sql.Rows, dest any) (result RowScan) {
realDest := reflect.Indirect(reflect.ValueOf(dest)) realDest := reflect.Indirect(reflect.ValueOf(dest))
scanArgs := make([]interface{}, len(columns)) scanArgs := make([]interface{}, len(columns))
for rows.Next() { for rows.Next() {
result.count++ result.Count++
scanArgs[0] = reflect.New(realDest.Type().Elem()).Interface() scanArgs[0] = reflect.New(realDest.Type().Elem()).Interface()
err := rows.Scan(scanArgs...) err := rows.Scan(scanArgs...)
if err != nil { if err != nil {

View File

@ -4,11 +4,11 @@ import (
"testing" "testing"
) )
var myConn *Connection var oDB *DB
func init() { func init() {
database := Open(map[string]DBConfig{ oDB = Open(map[string]DBConfig{
"mysql": { "default": {
Driver: "mysql", Driver: "mysql",
Host: "localhost", Host: "localhost",
Port: "3366", Port: "3366",
@ -20,14 +20,12 @@ func init() {
// ParseTime: true, // ParseTime: true,
}, },
}) })
myConn = database.Connection("mysql")
} }
func TestScanMapSlice(t *testing.T) { func TestScanMapSlice(t *testing.T) {
dest := []map[string]any{} dest := []map[string]any{}
myConn.Select("select * from fsm", []any{}, &dest) oDB.Select("select * from fsm", []any{}, &dest)
t.Log(dest) t.Log(dest)
} }
@ -36,7 +34,7 @@ func TestScanValues(t *testing.T) {
dest := []int{} dest := []int{}
myConn.Select("select 1 as id", []any{}, &dest) oDB.Select("select id from fsm", []any{}, &dest)
t.Log(dest) t.Log(dest)
} }
@ -45,32 +43,19 @@ func TestScanMap(t *testing.T) {
dest := map[string]any{} dest := map[string]any{}
myConn.Select("select 0 as id, 'test' as code", []any{}, &dest) oDB.Select("select * from fsm where id = 3", []any{}, &dest)
t.Log(dest) t.Log(dest)
} }
func TestScanStruct(t *testing.T) { func TestScanStruct(t *testing.T) {
dest := struct { dest := struct {
Id int `db:"id"` Id int `db:"id"`
Code string `db:"code"` Code string `db:"code"`
}{} }{}
myConn.Select(`select 0 as id, 'test' as code`, []any{}, &dest) oDB.Select("select * from fsm", []any{}, &dest)
t.Log(dest)
}
func TestScanJsonFieldStruct(t *testing.T) {
dest := struct {
Id int `db:"id"`
Code string `db:"code"`
Map map[string]any `db:"map"`
Arr []int64 `db:"arr"`
}{}
myConn.Select(`select 0 as id, NULL as code, '{"a":1}' as map, '[3, 4]' as arr`, []any{}, &dest)
t.Log(dest) t.Log(dest)
} }
@ -82,21 +67,7 @@ func TestScanStructSlice(t *testing.T) {
Code string `db:"code"` Code string `db:"code"`
}{} }{}
myConn.Select("select * from ((select 0 as id, 'test' as code) union (select 1 as id, 'test1' as code)) as t", []any{}, &dest) oDB.Select("select * from fsm", []any{}, &dest)
t.Log(dest)
}
func TestScanJsonFieldStructSlice(t *testing.T) {
dest := []struct {
Id int `db:"id"`
Code string `db:"code"`
Map map[string]any `db:"map"`
Arr []any `db:"arr"`
}{}
myConn.Select(`select * from ((select 0 as id, 'test' as code, '{"a":1}' as map, '[1, 2]' as arr) union (select 1 as id, 'test1' as code, '{"a":2}' as map, '[1, 2]' as arr)) as t`, []any{}, &dest)
t.Log(dest) t.Log(dest)
} }

View File

@ -1,3 +1,5 @@
package db package db
type Sqlite3Grammar = MysqlGrammar type Sqlite3Grammar struct {
*MysqlGrammar
}

View File

@ -1,36 +0,0 @@
package db
import (
"os"
"testing"
)
var conn *Connection
func TestMain(m *testing.M) {
database := Open(map[string]DBConfig{
"sqlite3": {
Driver: "sqlite3",
File: ":memory:",
},
})
conn = database.Connection("sqlite3")
// 执行测试
code := m.Run()
// 退出测试
os.Exit(code)
}
func TestSqliteQuery(t *testing.T) {
dest := []map[string]any{}
t.Log(conn.Select("select 1 as a, 2 as b", nil, &dest))
t.Log("result: ", dest)
}
func TestSqliteBuilder(t *testing.T) {
t.Log(conn.Query().FromSub("select 1 as a, 2 as b", "tb").ToSql())
}

18
util.go
View File

@ -75,21 +75,3 @@ func Stripslashes(str string) string {
} }
return string(dstRune) return string(dstRune)
} }
func MysqlRealEscapeString(value string) string {
var sb strings.Builder
for i := 0; i < len(value); i++ {
c := value[i]
switch c {
case '\\', 0, '\n', '\r', '\'', '"':
sb.WriteByte('\\')
sb.WriteByte(c)
case '\032':
sb.WriteByte('\\')
sb.WriteByte('Z')
default:
sb.WriteByte(c)
}
}
return sb.String()
}