[fix] 处理查询结果字段类型为 slice,struct时报错

This commit is contained in:
what 2023-06-21 10:03:51 +08:00
parent cd33e09d69
commit bf3032103e
2 changed files with 105 additions and 22 deletions

78
scan.go
View File

@ -8,7 +8,7 @@ import (
) )
type RowScan struct { type RowScan struct {
Count int count int
} }
func (RowScan) LastInsertId() (int64, error) { func (RowScan) LastInsertId() (int64, error) {
@ -16,7 +16,7 @@ func (RowScan) LastInsertId() (int64, error) {
} }
func (v RowScan) RowsAffected() (int64, error) { func (v RowScan) RowsAffected() (int64, error) {
return int64(v.Count), nil return int64(v.count), nil
} }
// 入口 // 入口
@ -46,7 +46,7 @@ func ScanResult(rows *sql.Rows, dest any) (result RowScan) {
return sMap(rows, dest) return sMap(rows, dest)
} else { } else {
for rows.Next() { for rows.Next() {
result.Count++ result.count++
err := rows.Scan(dest) err := rows.Scan(dest)
if err != nil { if err != nil {
panic(err.Error()) panic(err.Error())
@ -88,8 +88,12 @@ func sStructSlice(rows *sql.Rows, dest any) (result RowScan) {
keys[kFiled] = i keys[kFiled] = i
} }
// 提取 json 字段, json 字段, 不能直接赋值,
// 先赋值给 uint8, 再通过 json.Unmarshal 赋值给字段
var jFields map[int]reflect.Value
for rows.Next() { for rows.Next() {
result.Count++ result.count++
var v, vp reflect.Value var v, vp reflect.Value
if itemIsPtr { if itemIsPtr {
vp = reflect.New(sliceItem.Elem()) vp = reflect.New(sliceItem.Elem())
@ -101,10 +105,18 @@ func sStructSlice(rows *sql.Rows, dest any) (result RowScan) {
scanArgs := make([]any, len(columns)) scanArgs := make([]any, len(columns))
jFields = map[int]reflect.Value{}
// 初始化 map 字段 // 初始化 map 字段
for i, k := range columns { for i, k := range columns {
if f, ok := keys[k]; ok { if f, ok := keys[k]; ok {
switch v.Field(f).Kind() {
case reflect.Slice, reflect.Struct, reflect.Map:
scanArgs[i] = &[]uint8{}
jFields[i] = v.Field(f)
default:
scanArgs[i] = v.Field(f).Addr().Interface() scanArgs[i] = v.Field(f).Addr().Interface()
}
} else { } else {
scanArgs[i] = new(any) scanArgs[i] = new(any)
} }
@ -114,6 +126,22 @@ func sStructSlice(rows *sql.Rows, dest any) (result RowScan) {
panic(err) panic(err)
} }
// json 字符串 反序列化
for i, rv := range jFields {
v := rv.Interface()
data := *scanArgs[i].(*[]uint8)
if len(data) == 0 {
continue
}
if err := json.Unmarshal(data, &v); err != nil {
panic(err)
}
rv.Set(reflect.ValueOf(v))
}
if itemIsPtr { if itemIsPtr {
realDest.Set(reflect.Append(realDest, vp)) realDest.Set(reflect.Append(realDest, vp))
} else { } else {
@ -141,14 +169,26 @@ func sStruct(rows *sql.Rows, dest any) (result RowScan) {
keys[kFiled] = i keys[kFiled] = i
} }
// 提取 json 字段, json 字段, 不能直接赋值,
// 先赋值给 uint8, 再通过 json.Unmarshal 赋值给字段
var jFields map[int]reflect.Value
for rows.Next() { for rows.Next() {
if result.Count == 0 { if result.count == 0 {
scanArgs := make([]any, len(columns)) scanArgs := make([]any, len(columns))
jFields = map[int]reflect.Value{}
// 初始化 map 字段 // 初始化 map 字段
for i, k := range columns { for i, k := range columns {
if f, ok := keys[k]; ok { if f, ok := keys[k]; ok {
switch v.Field(f).Kind() {
case reflect.Slice, reflect.Struct, reflect.Map:
scanArgs[i] = &[]uint8{}
jFields[i] = v.Field(f)
default:
scanArgs[i] = v.Field(f).Addr().Interface() scanArgs[i] = v.Field(f).Addr().Interface()
}
} else { } else {
scanArgs[i] = new(any) scanArgs[i] = new(any)
} }
@ -158,6 +198,22 @@ func sStruct(rows *sql.Rows, dest any) (result RowScan) {
panic(err) panic(err)
} }
// json 字符串 反序列化
for i, rv := range jFields {
v := rv.Interface()
data := *scanArgs[i].(*[]uint8)
if len(data) == 0 {
continue
}
if err := json.Unmarshal(data, &v); err != nil {
panic(err)
}
rv.Set(reflect.ValueOf(v))
}
if realDest.Kind() == reflect.Ptr { if realDest.Kind() == reflect.Ptr {
realDest.Set(vp) realDest.Set(vp)
} else { } else {
@ -165,7 +221,7 @@ func sStruct(rows *sql.Rows, dest any) (result RowScan) {
} }
} }
result.Count++ result.count++
} }
return result return result
@ -180,7 +236,7 @@ func sMapSlice(rows *sql.Rows, dest any) (result RowScan) {
scanArgs := make([]any, len(columns)) scanArgs := make([]any, len(columns))
element := make(map[string]any) element := make(map[string]any)
result.Count++ result.count++
// 初始化 map 字段 // 初始化 map 字段
for i := 0; i < len(columns); i++ { for i := 0; i < len(columns); i++ {
@ -197,7 +253,7 @@ func sMapSlice(rows *sql.Rows, dest any) (result RowScan) {
switch v := (*scanArgs[i].(*any)).(type) { switch v := (*scanArgs[i].(*any)).(type) {
case []uint8: case []uint8:
data := string(v) data := string(v)
if len(data) > 2 && (data[0:1] == "{" || data[0:1] == "[") { if len(v) > 2 && (data[0:1] == "{" || data[0:1] == "[") {
jData := new(any) jData := new(any)
if err := json.Unmarshal(v, jData); err == nil { if err := json.Unmarshal(v, jData); err == nil {
element[column] = *jData element[column] = *jData
@ -223,7 +279,7 @@ func sMap(rows *sql.Rows, dest any) (result RowScan) {
realDest := reflect.Indirect(reflect.ValueOf(dest)) realDest := reflect.Indirect(reflect.ValueOf(dest))
for rows.Next() { for rows.Next() {
if result.Count == 0 { if result.count == 0 {
scanArgs := make([]interface{}, len(columns)) scanArgs := make([]interface{}, len(columns))
// 初始化 map 字段 // 初始化 map 字段
for i := 0; i < len(columns); i++ { for i := 0; i < len(columns); i++ {
@ -255,7 +311,7 @@ func sMap(rows *sql.Rows, dest any) (result RowScan) {
} }
} }
result.Count++ result.count++
} }
return result return result
@ -267,7 +323,7 @@ func sValues(rows *sql.Rows, dest any) (result RowScan) {
realDest := reflect.Indirect(reflect.ValueOf(dest)) realDest := reflect.Indirect(reflect.ValueOf(dest))
scanArgs := make([]interface{}, len(columns)) scanArgs := make([]interface{}, len(columns))
for rows.Next() { for rows.Next() {
result.Count++ result.count++
scanArgs[0] = reflect.New(realDest.Type().Elem()).Interface() scanArgs[0] = reflect.New(realDest.Type().Elem()).Interface()
err := rows.Scan(scanArgs...) err := rows.Scan(scanArgs...)
if err != nil { if err != nil {

View File

@ -4,11 +4,11 @@ import (
"testing" "testing"
) )
var oDB *DB var myConn *Connection
func init() { func init() {
oDB = Open(map[string]DBConfig{ database := Open(map[string]DBConfig{
"default": { "mysql": {
Driver: "mysql", Driver: "mysql",
Host: "localhost", Host: "localhost",
Port: "3366", Port: "3366",
@ -20,12 +20,14 @@ func init() {
// ParseTime: true, // ParseTime: true,
}, },
}) })
myConn = database.Connection("mysql")
} }
func TestScanMapSlice(t *testing.T) { func TestScanMapSlice(t *testing.T) {
dest := []map[string]any{} dest := []map[string]any{}
oDB.Select("select * from fsm", []any{}, &dest) myConn.Select("select * from fsm", []any{}, &dest)
t.Log(dest) t.Log(dest)
} }
@ -34,7 +36,7 @@ func TestScanValues(t *testing.T) {
dest := []int{} dest := []int{}
oDB.Select("select id from fsm", []any{}, &dest) myConn.Select("select 1 as id", []any{}, &dest)
t.Log(dest) t.Log(dest)
} }
@ -43,19 +45,31 @@ func TestScanMap(t *testing.T) {
dest := map[string]any{} dest := map[string]any{}
oDB.Select("select * from fsm where id = 3", []any{}, &dest) myConn.Select("select 0 as id, 'test' as code", []any{}, &dest)
t.Log(dest) t.Log(dest)
} }
func TestScanStruct(t *testing.T) { func TestScanStruct(t *testing.T) {
dest := struct { dest := struct {
Id int `db:"id"` Id int `db:"id"`
Code string `db:"code"` Code string `db:"code"`
}{} }{}
oDB.Select("select * from fsm", []any{}, &dest) myConn.Select(`select 0 as id, 'test' as code`, []any{}, &dest)
t.Log(dest)
}
func TestScanJsonFieldStruct(t *testing.T) {
dest := struct {
Id int `db:"id"`
Code string `db:"code"`
Map map[string]any `db:"map"`
}{}
myConn.Select(`select 0 as id, 'test' as code, '{"a":1}' as map`, []any{}, &dest)
t.Log(dest) t.Log(dest)
} }
@ -67,7 +81,20 @@ func TestScanStructSlice(t *testing.T) {
Code string `db:"code"` Code string `db:"code"`
}{} }{}
oDB.Select("select * from fsm", []any{}, &dest) myConn.Select("select * from ((select 0 as id, 'test' as code) union (select 1 as id, 'test1' as code)) as t", []any{}, &dest)
t.Log(dest)
}
func TestScanJsonFieldStructSlice(t *testing.T) {
dest := []struct {
Id int `db:"id"`
Code string `db:"code"`
Map map[string]any `db:"map"`
}{}
myConn.Select(`select * from ((select 0 as id, 'test' as code, '{"a":1}' as map) union (select 1 as id, 'test1' as code, '{"a":2}' as map)) as t`, []any{}, &dest)
t.Log(dest) t.Log(dest)
} }