[fix] 处理查询结果字段类型为 slice,struct时报错
This commit is contained in:
parent
cd33e09d69
commit
bf3032103e
78
scan.go
78
scan.go
@ -8,7 +8,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type RowScan struct {
|
type RowScan struct {
|
||||||
Count int
|
count int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (RowScan) LastInsertId() (int64, error) {
|
func (RowScan) LastInsertId() (int64, error) {
|
||||||
@ -16,7 +16,7 @@ func (RowScan) LastInsertId() (int64, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (v RowScan) RowsAffected() (int64, error) {
|
func (v RowScan) RowsAffected() (int64, error) {
|
||||||
return int64(v.Count), nil
|
return int64(v.count), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// 入口
|
// 入口
|
||||||
@ -46,7 +46,7 @@ func ScanResult(rows *sql.Rows, dest any) (result RowScan) {
|
|||||||
return sMap(rows, dest)
|
return sMap(rows, dest)
|
||||||
} else {
|
} else {
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
result.Count++
|
result.count++
|
||||||
err := rows.Scan(dest)
|
err := rows.Scan(dest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err.Error())
|
panic(err.Error())
|
||||||
@ -88,8 +88,12 @@ func sStructSlice(rows *sql.Rows, dest any) (result RowScan) {
|
|||||||
keys[kFiled] = i
|
keys[kFiled] = i
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 提取 json 字段, json 字段, 不能直接赋值,
|
||||||
|
// 先赋值给 uint8, 再通过 json.Unmarshal 赋值给字段
|
||||||
|
var jFields map[int]reflect.Value
|
||||||
|
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
result.Count++
|
result.count++
|
||||||
var v, vp reflect.Value
|
var v, vp reflect.Value
|
||||||
if itemIsPtr {
|
if itemIsPtr {
|
||||||
vp = reflect.New(sliceItem.Elem())
|
vp = reflect.New(sliceItem.Elem())
|
||||||
@ -101,10 +105,18 @@ func sStructSlice(rows *sql.Rows, dest any) (result RowScan) {
|
|||||||
|
|
||||||
scanArgs := make([]any, len(columns))
|
scanArgs := make([]any, len(columns))
|
||||||
|
|
||||||
|
jFields = map[int]reflect.Value{}
|
||||||
|
|
||||||
// 初始化 map 字段
|
// 初始化 map 字段
|
||||||
for i, k := range columns {
|
for i, k := range columns {
|
||||||
if f, ok := keys[k]; ok {
|
if f, ok := keys[k]; ok {
|
||||||
|
switch v.Field(f).Kind() {
|
||||||
|
case reflect.Slice, reflect.Struct, reflect.Map:
|
||||||
|
scanArgs[i] = &[]uint8{}
|
||||||
|
jFields[i] = v.Field(f)
|
||||||
|
default:
|
||||||
scanArgs[i] = v.Field(f).Addr().Interface()
|
scanArgs[i] = v.Field(f).Addr().Interface()
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
scanArgs[i] = new(any)
|
scanArgs[i] = new(any)
|
||||||
}
|
}
|
||||||
@ -114,6 +126,22 @@ func sStructSlice(rows *sql.Rows, dest any) (result RowScan) {
|
|||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// json 字符串 反序列化
|
||||||
|
for i, rv := range jFields {
|
||||||
|
v := rv.Interface()
|
||||||
|
data := *scanArgs[i].(*[]uint8)
|
||||||
|
|
||||||
|
if len(data) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(data, &v); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rv.Set(reflect.ValueOf(v))
|
||||||
|
}
|
||||||
|
|
||||||
if itemIsPtr {
|
if itemIsPtr {
|
||||||
realDest.Set(reflect.Append(realDest, vp))
|
realDest.Set(reflect.Append(realDest, vp))
|
||||||
} else {
|
} else {
|
||||||
@ -141,14 +169,26 @@ func sStruct(rows *sql.Rows, dest any) (result RowScan) {
|
|||||||
keys[kFiled] = i
|
keys[kFiled] = i
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 提取 json 字段, json 字段, 不能直接赋值,
|
||||||
|
// 先赋值给 uint8, 再通过 json.Unmarshal 赋值给字段
|
||||||
|
var jFields map[int]reflect.Value
|
||||||
|
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
if result.Count == 0 {
|
if result.count == 0 {
|
||||||
scanArgs := make([]any, len(columns))
|
scanArgs := make([]any, len(columns))
|
||||||
|
|
||||||
|
jFields = map[int]reflect.Value{}
|
||||||
|
|
||||||
// 初始化 map 字段
|
// 初始化 map 字段
|
||||||
for i, k := range columns {
|
for i, k := range columns {
|
||||||
if f, ok := keys[k]; ok {
|
if f, ok := keys[k]; ok {
|
||||||
|
switch v.Field(f).Kind() {
|
||||||
|
case reflect.Slice, reflect.Struct, reflect.Map:
|
||||||
|
scanArgs[i] = &[]uint8{}
|
||||||
|
jFields[i] = v.Field(f)
|
||||||
|
default:
|
||||||
scanArgs[i] = v.Field(f).Addr().Interface()
|
scanArgs[i] = v.Field(f).Addr().Interface()
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
scanArgs[i] = new(any)
|
scanArgs[i] = new(any)
|
||||||
}
|
}
|
||||||
@ -158,6 +198,22 @@ func sStruct(rows *sql.Rows, dest any) (result RowScan) {
|
|||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// json 字符串 反序列化
|
||||||
|
for i, rv := range jFields {
|
||||||
|
v := rv.Interface()
|
||||||
|
data := *scanArgs[i].(*[]uint8)
|
||||||
|
|
||||||
|
if len(data) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(data, &v); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rv.Set(reflect.ValueOf(v))
|
||||||
|
}
|
||||||
|
|
||||||
if realDest.Kind() == reflect.Ptr {
|
if realDest.Kind() == reflect.Ptr {
|
||||||
realDest.Set(vp)
|
realDest.Set(vp)
|
||||||
} else {
|
} else {
|
||||||
@ -165,7 +221,7 @@ func sStruct(rows *sql.Rows, dest any) (result RowScan) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
result.Count++
|
result.count++
|
||||||
}
|
}
|
||||||
|
|
||||||
return result
|
return result
|
||||||
@ -180,7 +236,7 @@ func sMapSlice(rows *sql.Rows, dest any) (result RowScan) {
|
|||||||
scanArgs := make([]any, len(columns))
|
scanArgs := make([]any, len(columns))
|
||||||
element := make(map[string]any)
|
element := make(map[string]any)
|
||||||
|
|
||||||
result.Count++
|
result.count++
|
||||||
|
|
||||||
// 初始化 map 字段
|
// 初始化 map 字段
|
||||||
for i := 0; i < len(columns); i++ {
|
for i := 0; i < len(columns); i++ {
|
||||||
@ -197,7 +253,7 @@ func sMapSlice(rows *sql.Rows, dest any) (result RowScan) {
|
|||||||
switch v := (*scanArgs[i].(*any)).(type) {
|
switch v := (*scanArgs[i].(*any)).(type) {
|
||||||
case []uint8:
|
case []uint8:
|
||||||
data := string(v)
|
data := string(v)
|
||||||
if len(data) > 2 && (data[0:1] == "{" || data[0:1] == "[") {
|
if len(v) > 2 && (data[0:1] == "{" || data[0:1] == "[") {
|
||||||
jData := new(any)
|
jData := new(any)
|
||||||
if err := json.Unmarshal(v, jData); err == nil {
|
if err := json.Unmarshal(v, jData); err == nil {
|
||||||
element[column] = *jData
|
element[column] = *jData
|
||||||
@ -223,7 +279,7 @@ func sMap(rows *sql.Rows, dest any) (result RowScan) {
|
|||||||
realDest := reflect.Indirect(reflect.ValueOf(dest))
|
realDest := reflect.Indirect(reflect.ValueOf(dest))
|
||||||
|
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
if result.Count == 0 {
|
if result.count == 0 {
|
||||||
scanArgs := make([]interface{}, len(columns))
|
scanArgs := make([]interface{}, len(columns))
|
||||||
// 初始化 map 字段
|
// 初始化 map 字段
|
||||||
for i := 0; i < len(columns); i++ {
|
for i := 0; i < len(columns); i++ {
|
||||||
@ -255,7 +311,7 @@ func sMap(rows *sql.Rows, dest any) (result RowScan) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
result.Count++
|
result.count++
|
||||||
}
|
}
|
||||||
|
|
||||||
return result
|
return result
|
||||||
@ -267,7 +323,7 @@ func sValues(rows *sql.Rows, dest any) (result RowScan) {
|
|||||||
realDest := reflect.Indirect(reflect.ValueOf(dest))
|
realDest := reflect.Indirect(reflect.ValueOf(dest))
|
||||||
scanArgs := make([]interface{}, len(columns))
|
scanArgs := make([]interface{}, len(columns))
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
result.Count++
|
result.count++
|
||||||
scanArgs[0] = reflect.New(realDest.Type().Elem()).Interface()
|
scanArgs[0] = reflect.New(realDest.Type().Elem()).Interface()
|
||||||
err := rows.Scan(scanArgs...)
|
err := rows.Scan(scanArgs...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
45
scan_test.go
45
scan_test.go
@ -4,11 +4,11 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
var oDB *DB
|
var myConn *Connection
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
oDB = Open(map[string]DBConfig{
|
database := Open(map[string]DBConfig{
|
||||||
"default": {
|
"mysql": {
|
||||||
Driver: "mysql",
|
Driver: "mysql",
|
||||||
Host: "localhost",
|
Host: "localhost",
|
||||||
Port: "3366",
|
Port: "3366",
|
||||||
@ -20,12 +20,14 @@ func init() {
|
|||||||
// ParseTime: true,
|
// ParseTime: true,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
|
myConn = database.Connection("mysql")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestScanMapSlice(t *testing.T) {
|
func TestScanMapSlice(t *testing.T) {
|
||||||
dest := []map[string]any{}
|
dest := []map[string]any{}
|
||||||
|
|
||||||
oDB.Select("select * from fsm", []any{}, &dest)
|
myConn.Select("select * from fsm", []any{}, &dest)
|
||||||
|
|
||||||
t.Log(dest)
|
t.Log(dest)
|
||||||
}
|
}
|
||||||
@ -34,7 +36,7 @@ func TestScanValues(t *testing.T) {
|
|||||||
|
|
||||||
dest := []int{}
|
dest := []int{}
|
||||||
|
|
||||||
oDB.Select("select id from fsm", []any{}, &dest)
|
myConn.Select("select 1 as id", []any{}, &dest)
|
||||||
|
|
||||||
t.Log(dest)
|
t.Log(dest)
|
||||||
}
|
}
|
||||||
@ -43,19 +45,31 @@ func TestScanMap(t *testing.T) {
|
|||||||
|
|
||||||
dest := map[string]any{}
|
dest := map[string]any{}
|
||||||
|
|
||||||
oDB.Select("select * from fsm where id = 3", []any{}, &dest)
|
myConn.Select("select 0 as id, 'test' as code", []any{}, &dest)
|
||||||
|
|
||||||
t.Log(dest)
|
t.Log(dest)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestScanStruct(t *testing.T) {
|
func TestScanStruct(t *testing.T) {
|
||||||
|
|
||||||
dest := struct {
|
dest := struct {
|
||||||
Id int `db:"id"`
|
Id int `db:"id"`
|
||||||
Code string `db:"code"`
|
Code string `db:"code"`
|
||||||
}{}
|
}{}
|
||||||
|
|
||||||
oDB.Select("select * from fsm", []any{}, &dest)
|
myConn.Select(`select 0 as id, 'test' as code`, []any{}, &dest)
|
||||||
|
|
||||||
|
t.Log(dest)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestScanJsonFieldStruct(t *testing.T) {
|
||||||
|
|
||||||
|
dest := struct {
|
||||||
|
Id int `db:"id"`
|
||||||
|
Code string `db:"code"`
|
||||||
|
Map map[string]any `db:"map"`
|
||||||
|
}{}
|
||||||
|
|
||||||
|
myConn.Select(`select 0 as id, 'test' as code, '{"a":1}' as map`, []any{}, &dest)
|
||||||
|
|
||||||
t.Log(dest)
|
t.Log(dest)
|
||||||
}
|
}
|
||||||
@ -67,7 +81,20 @@ func TestScanStructSlice(t *testing.T) {
|
|||||||
Code string `db:"code"`
|
Code string `db:"code"`
|
||||||
}{}
|
}{}
|
||||||
|
|
||||||
oDB.Select("select * from fsm", []any{}, &dest)
|
myConn.Select("select * from ((select 0 as id, 'test' as code) union (select 1 as id, 'test1' as code)) as t", []any{}, &dest)
|
||||||
|
|
||||||
|
t.Log(dest)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestScanJsonFieldStructSlice(t *testing.T) {
|
||||||
|
|
||||||
|
dest := []struct {
|
||||||
|
Id int `db:"id"`
|
||||||
|
Code string `db:"code"`
|
||||||
|
Map map[string]any `db:"map"`
|
||||||
|
}{}
|
||||||
|
|
||||||
|
myConn.Select(`select * from ((select 0 as id, 'test' as code, '{"a":1}' as map) union (select 1 as id, 'test1' as code, '{"a":2}' as map)) as t`, []any{}, &dest)
|
||||||
|
|
||||||
t.Log(dest)
|
t.Log(dest)
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user