341 lines
7.6 KiB
Go
341 lines
7.6 KiB
Go
package db
|
|
|
|
import (
|
|
"database/sql"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"reflect"
|
|
)
|
|
|
|
// RowScan reports the outcome of scanning a SELECT result set. It records how
// many rows the scan helper iterated over and exposes that number through the
// same method set as sql.Result.
type RowScan struct {
	count int // number of rows visited by the scan helper
}

// Compile-time check that RowScan keeps satisfying sql.Result.
var _ sql.Result = RowScan{}

// LastInsertId always returns 0 and a non-nil error: a scan result never
// carries an insert id.
func (RowScan) LastInsertId() (int64, error) {
	return 0, errors.New("no insert in select")
}

// RowsAffected returns the number of rows counted during the scan. The
// returned error is always nil.
func (r RowScan) RowsAffected() (int64, error) {
	return int64(r.count), nil
}
|
|
|
|
// 入口
|
|
func ScanResult(rows *sql.Rows, dest any) (result RowScan) {
|
|
realDest := reflect.Indirect(reflect.ValueOf(dest))
|
|
|
|
if reflect.TypeOf(dest).Kind() != reflect.Ptr {
|
|
panic("dest 请传入指针类型")
|
|
}
|
|
|
|
if realDest.Kind() == reflect.Slice {
|
|
slice := realDest.Type()
|
|
sliceItem := slice.Elem()
|
|
if sliceItem.Kind() == reflect.Map {
|
|
return sMapSlice(rows, dest)
|
|
} else if sliceItem.Kind() == reflect.Struct {
|
|
return sStructSlice(rows, dest)
|
|
} else if sliceItem.Kind() == reflect.Ptr {
|
|
if sliceItem.Elem().Kind() == reflect.Struct {
|
|
return sStructSlice(rows, dest)
|
|
}
|
|
}
|
|
return sValues(rows, dest)
|
|
} else if realDest.Kind() == reflect.Struct {
|
|
return sStruct(rows, dest)
|
|
} else if realDest.Kind() == reflect.Map {
|
|
return sMap(rows, dest)
|
|
} else {
|
|
for rows.Next() {
|
|
result.count++
|
|
err := rows.Scan(dest)
|
|
if err != nil {
|
|
panic(err.Error())
|
|
}
|
|
}
|
|
}
|
|
return result
|
|
}
|
|
|
|
// to struct slice
|
|
func sStructSlice(rows *sql.Rows, dest any) (result RowScan) {
|
|
realDest := reflect.Indirect(reflect.ValueOf(dest))
|
|
slice := realDest.Type()
|
|
sliceItem := slice.Elem()
|
|
itemIsPtr := sliceItem.Kind() == reflect.Ptr
|
|
keys := map[string]int{}
|
|
columns, _ := rows.Columns()
|
|
|
|
lenField := 0
|
|
if itemIsPtr {
|
|
lenField = sliceItem.Elem().NumField()
|
|
} else {
|
|
lenField = sliceItem.NumField()
|
|
}
|
|
|
|
// 缓存 dest 结构
|
|
// key is field name, value is field index
|
|
for i := 0; i < lenField; i++ {
|
|
var tField reflect.StructField
|
|
if itemIsPtr {
|
|
tField = sliceItem.Elem().Field(i)
|
|
} else {
|
|
tField = sliceItem.Field(i)
|
|
}
|
|
kFiled := tField.Tag.Get("db")
|
|
if kFiled == "" {
|
|
kFiled = tField.Name
|
|
}
|
|
keys[kFiled] = i
|
|
}
|
|
|
|
// 提取 json 字段, json 字段, 不能直接赋值,
|
|
// 先赋值给 uint8, 再通过 json.Unmarshal 赋值给字段
|
|
var jFields map[int]reflect.Value
|
|
|
|
for rows.Next() {
|
|
result.count++
|
|
var v, vp reflect.Value
|
|
if itemIsPtr {
|
|
vp = reflect.New(sliceItem.Elem())
|
|
v = reflect.Indirect(vp)
|
|
} else {
|
|
vp = reflect.New(sliceItem)
|
|
v = reflect.Indirect(vp)
|
|
}
|
|
|
|
scanArgs := make([]any, len(columns))
|
|
|
|
jFields = map[int]reflect.Value{}
|
|
|
|
// 初始化 map 字段
|
|
for i, k := range columns {
|
|
if f, ok := keys[k]; ok {
|
|
switch v.Field(f).Kind() {
|
|
case reflect.Slice, reflect.Struct, reflect.Map:
|
|
scanArgs[i] = &[]uint8{}
|
|
jFields[i] = v.Field(f)
|
|
case reflect.String:
|
|
scanArgs[i] = &sql.NullString{}
|
|
jFields[i] = v.Field(f)
|
|
default:
|
|
scanArgs[i] = v.Field(f).Addr().Interface()
|
|
}
|
|
} else {
|
|
scanArgs[i] = new(any)
|
|
}
|
|
}
|
|
|
|
if err := rows.Scan(scanArgs...); err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
// json 字符串 反序列化
|
|
for i, rv := range jFields {
|
|
switch rv.Kind() {
|
|
case reflect.String:
|
|
rv.SetString(scanArgs[i].(*sql.NullString).String)
|
|
default:
|
|
if data := *scanArgs[i].(*[]uint8); len(data) > 0 {
|
|
v := rv.Addr().Interface()
|
|
if err := json.Unmarshal(data, &v); err != nil {
|
|
panic(fmt.Sprintf("[%s:%v] => %s", rv.Kind(), data, err))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if itemIsPtr {
|
|
realDest.Set(reflect.Append(realDest, vp))
|
|
} else {
|
|
realDest.Set(reflect.Append(realDest, v))
|
|
}
|
|
}
|
|
return result
|
|
}
|
|
|
|
// to struct
|
|
func sStruct(rows *sql.Rows, dest any) (result RowScan) {
|
|
columns, _ := rows.Columns()
|
|
realDest := reflect.Indirect(reflect.ValueOf(dest))
|
|
vp := reflect.New(realDest.Type())
|
|
v := reflect.Indirect(vp)
|
|
|
|
keys := map[string]int{}
|
|
|
|
for i := 0; i < v.NumField(); i++ {
|
|
tField := v.Type().Field(i)
|
|
kFiled := tField.Tag.Get("db")
|
|
if kFiled == "" {
|
|
kFiled = tField.Name
|
|
}
|
|
keys[kFiled] = i
|
|
}
|
|
|
|
// 提取 json 字段, json 字段, 不能直接赋值,
|
|
// 先赋值给 uint8, 再通过 json.Unmarshal 赋值给字段
|
|
var jFields map[int]reflect.Value
|
|
|
|
for rows.Next() {
|
|
if result.count == 0 {
|
|
scanArgs := make([]any, len(columns))
|
|
|
|
jFields = map[int]reflect.Value{}
|
|
|
|
// 初始化 map 字段
|
|
for i, k := range columns {
|
|
if f, ok := keys[k]; ok {
|
|
switch v.Field(f).Kind() {
|
|
case reflect.Slice, reflect.Struct, reflect.Map:
|
|
scanArgs[i] = &[]uint8{}
|
|
jFields[i] = v.Field(f)
|
|
case reflect.String:
|
|
scanArgs[i] = &sql.NullString{}
|
|
jFields[i] = v.Field(f)
|
|
default:
|
|
scanArgs[i] = v.Field(f).Addr().Interface()
|
|
}
|
|
} else {
|
|
scanArgs[i] = new(any)
|
|
}
|
|
}
|
|
|
|
if err := rows.Scan(scanArgs...); err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
// json 字符串 反序列化
|
|
for i, rv := range jFields {
|
|
switch rv.Kind() {
|
|
case reflect.String:
|
|
rv.SetString(scanArgs[i].(*sql.NullString).String)
|
|
default:
|
|
if data := *scanArgs[i].(*[]uint8); len(data) > 0 {
|
|
v := rv.Addr().Interface()
|
|
if err := json.Unmarshal(data, &v); err != nil {
|
|
panic(fmt.Sprintf("[%s:%v] => %s", rv.Kind(), data, err))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if realDest.Kind() == reflect.Ptr {
|
|
realDest.Set(vp)
|
|
} else {
|
|
realDest.Set(v)
|
|
}
|
|
}
|
|
|
|
result.count++
|
|
}
|
|
|
|
return result
|
|
}
|
|
|
|
// to map slice
|
|
func sMapSlice(rows *sql.Rows, dest any) (result RowScan) {
|
|
columns, _ := rows.Columns()
|
|
realDest := reflect.Indirect(reflect.ValueOf(dest))
|
|
|
|
for rows.Next() {
|
|
scanArgs := make([]any, len(columns))
|
|
element := make(map[string]any)
|
|
|
|
result.count++
|
|
|
|
// 初始化 map 字段
|
|
for i := 0; i < len(columns); i++ {
|
|
scanArgs[i] = new(any)
|
|
}
|
|
|
|
// 填充 map 字段
|
|
if err := rows.Scan(scanArgs...); err != nil {
|
|
panic(err.Error())
|
|
}
|
|
|
|
// 将 scan 后的 map 填充到结果集
|
|
for i, column := range columns {
|
|
switch v := (*scanArgs[i].(*any)).(type) {
|
|
case []uint8:
|
|
data := string(v)
|
|
if len(v) > 2 && (data[0:1] == "{" || data[0:1] == "[") {
|
|
jData := new(any)
|
|
if err := json.Unmarshal(v, jData); err == nil {
|
|
element[column] = *jData
|
|
}
|
|
} else {
|
|
element[column] = data
|
|
}
|
|
default:
|
|
// int64
|
|
element[column] = v
|
|
}
|
|
}
|
|
|
|
realDest.Set(reflect.Append(realDest, reflect.ValueOf(element)))
|
|
}
|
|
|
|
return result
|
|
}
|
|
|
|
// to map
|
|
func sMap(rows *sql.Rows, dest any) (result RowScan) {
|
|
columns, _ := rows.Columns()
|
|
realDest := reflect.Indirect(reflect.ValueOf(dest))
|
|
|
|
for rows.Next() {
|
|
if result.count == 0 {
|
|
scanArgs := make([]interface{}, len(columns))
|
|
// 初始化 map 字段
|
|
for i := 0; i < len(columns); i++ {
|
|
scanArgs[i] = new(any)
|
|
}
|
|
|
|
if err := rows.Scan(scanArgs...); err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
for i, column := range columns {
|
|
var rv reflect.Value
|
|
switch v := (*scanArgs[i].(*any)).(type) {
|
|
case []uint8:
|
|
data := string(v)
|
|
if len(data) > 2 && (data[0:1] == "{" || data[0:1] == "[") {
|
|
jData := new(any)
|
|
if err := json.Unmarshal(v, jData); err == nil {
|
|
rv = reflect.ValueOf(jData).Elem()
|
|
}
|
|
} else {
|
|
rv = reflect.ValueOf(data)
|
|
}
|
|
default:
|
|
// int64
|
|
rv = reflect.ValueOf(v)
|
|
}
|
|
realDest.SetMapIndex(reflect.ValueOf(column), rv)
|
|
}
|
|
}
|
|
|
|
result.count++
|
|
}
|
|
|
|
return result
|
|
}
|
|
|
|
// pluck, 如 []int{}
|
|
func sValues(rows *sql.Rows, dest any) (result RowScan) {
|
|
columns, _ := rows.Columns()
|
|
realDest := reflect.Indirect(reflect.ValueOf(dest))
|
|
scanArgs := make([]interface{}, len(columns))
|
|
for rows.Next() {
|
|
result.count++
|
|
scanArgs[0] = reflect.New(realDest.Type().Elem()).Interface()
|
|
err := rows.Scan(scanArgs...)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
realDest.Set(reflect.Append(realDest, reflect.ValueOf(reflect.ValueOf(scanArgs[0]).Elem().Interface())))
|
|
}
|
|
return result
|
|
}
|