db/scan.go

343 lines
7.7 KiB
Go

package db
import (
"database/sql"
"encoding/json"
"errors"
"fmt"
"reflect"
"github.com/samber/lo"
)
// RowScan is what ScanResult returns. Its method set (LastInsertId and
// RowsAffected) matches database/sql's Result interface, so a scan can be
// reported like an Exec: RowsAffected yields the number of rows read.
type RowScan struct {
	// count is incremented once per row produced by rows.Next.
	count int
}

// LastInsertId always fails, because reading rows never inserts anything.
func (rs RowScan) LastInsertId() (int64, error) {
	return 0, errors.New("no insert in select")
}

// RowsAffected returns how many rows the scan iterated over; it never fails.
func (rs RowScan) RowsAffected() (int64, error) {
	affected := int64(rs.count)
	return affected, nil
}
// ScanResult is the entry point: it scans rows into dest and returns a
// RowScan whose RowsAffected reports how many rows were iterated.
//
// dest must be a pointer; the type it points to selects the strategy:
//   - []map               -> sMapSlice, one map per row
//   - []struct, []*struct -> sStructSlice, one element per row
//   - any other slice     -> sValues, the first column of each row
//   - struct              -> sStruct, filled from the first row
//   - map                 -> sMap, filled from the first row
//   - anything else       -> rows.Scan(dest) for every row, so the last row wins
//
// It panics with "dest 请传入指针类型" ("dest must be a pointer") when dest is
// nil or not a pointer. On the scalar path it also panics when a row fails to
// scan or rows.Err reports an iteration error.
func ScanResult(rows *sql.Rows, dest any) (result RowScan) {
	// Check nil first: reflect.TypeOf(nil) is a nil Type and calling Kind on
	// it would crash with a nil dereference instead of this message.
	if dest == nil || reflect.TypeOf(dest).Kind() != reflect.Ptr {
		panic("dest 请传入指针类型")
	}
	realDest := reflect.Indirect(reflect.ValueOf(dest))
	switch realDest.Kind() {
	case reflect.Slice:
		item := realDest.Type().Elem()
		switch {
		case item.Kind() == reflect.Map:
			return sMapSlice(rows, dest)
		case item.Kind() == reflect.Struct,
			item.Kind() == reflect.Ptr && item.Elem().Kind() == reflect.Struct:
			return sStructSlice(rows, dest)
		default:
			return sValues(rows, dest)
		}
	case reflect.Struct:
		return sStruct(rows, dest)
	case reflect.Map:
		return sMap(rows, dest)
	}

	// Scalar (or nil pointer) destination: let database/sql convert each row.
	for rows.Next() {
		result.count++
		if err := rows.Scan(dest); err != nil {
			panic(err.Error())
		}
	}
	// rows.Next returning false can also mean the driver failed mid-stream.
	if err := rows.Err(); err != nil {
		panic(err.Error())
	}
	return result
}
// sqlScannerType is the reflect.Type of the sql.Scanner interface.
var sqlScannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem()

// structColumnKeys maps column names to field indexes of the struct type t.
// A field's column name is its `db` tag, or the Go field name when the tag is
// empty; when two fields claim the same name the later field wins.
// Unexported fields are skipped: reflection cannot assign to them, so a
// matching column would otherwise panic instead of being ignored.
func structColumnKeys(t reflect.Type) map[string]int {
	keys := make(map[string]int, t.NumField())
	for i := 0; i < t.NumField(); i++ {
		field := t.Field(i)
		if field.PkgPath != "" {
			continue
		}
		name := field.Tag.Get("db")
		if name == "" {
			name = field.Name
		}
		keys[name] = i
	}
	return keys
}

// isDirectScan reports whether a field of type t can be handed to rows.Scan
// as is: types whose pointer implements sql.Scanner (sql.NullInt64, ...),
// time.Time, and byte slices (BLOBs, json.RawMessage). These struct/slice
// types must not go through the JSON fallback, which cannot parse plain
// column values such as `42` into sql.NullInt64 or an unquoted timestamp.
func isDirectScan(t reflect.Type) bool {
	if reflect.PointerTo(t).Implements(sqlScannerType) {
		return true
	}
	if t.PkgPath() == "time" && t.Name() == "Time" {
		return true
	}
	return t.Kind() == reflect.Slice && t.Elem().Kind() == reflect.Uint8
}

// scanStructRow scans the current row of rows into the addressable struct
// value v, using keys (from structColumnKeys) to match columns to fields.
//
//   - Columns without a matching field are read into a throw-away value.
//   - String fields are read through sql.NullString, so NULL becomes "".
//   - Other slice, map and struct fields (JSON columns) cannot be assigned
//     directly: they are read as raw bytes first and then decoded with
//     json.Unmarshal; empty or NULL values leave the field at its zero value.
//   - Every other field is scanned directly.
//
// It panics on scan errors and on JSON decode errors; the decode message
// names the offending column.
func scanStructRow(rows *sql.Rows, columns []string, keys map[string]int, v reflect.Value) {
	scanArgs := make([]any, len(columns))
	// column index -> field that needs post-processing after rows.Scan
	deferred := make(map[int]reflect.Value)
	for i, column := range columns {
		f, ok := keys[column]
		if !ok {
			scanArgs[i] = new(any)
			continue
		}
		field := v.Field(f)
		switch {
		case isDirectScan(field.Type()):
			scanArgs[i] = field.Addr().Interface()
		case field.Kind() == reflect.String:
			scanArgs[i] = &sql.NullString{}
			deferred[i] = field
		case field.Kind() == reflect.Slice, field.Kind() == reflect.Struct, field.Kind() == reflect.Map:
			scanArgs[i] = &[]uint8{}
			deferred[i] = field
		default:
			scanArgs[i] = field.Addr().Interface()
		}
	}
	if err := rows.Scan(scanArgs...); err != nil {
		panic(err)
	}
	for i, field := range deferred {
		if field.Kind() == reflect.String {
			field.SetString(scanArgs[i].(*sql.NullString).String)
			continue
		}
		data := *scanArgs[i].(*[]uint8)
		if len(data) == 0 {
			continue
		}
		if err := json.Unmarshal(data, field.Addr().Interface()); err != nil {
			// columns[i] is the column name; keys is indexed by field, not column.
			panic(fmt.Sprintf("[%s:%s:%s] => %s", columns[i], field.Kind(), data, err))
		}
	}
}

// sStructSlice (to struct slice) appends one element per row to the slice
// that dest points to. Elements may be structs or pointers to structs; see
// scanStructRow for how columns are matched to fields. It panics on column,
// scan, decode or iteration errors.
func sStructSlice(rows *sql.Rows, dest any) (result RowScan) {
	realDest := reflect.Indirect(reflect.ValueOf(dest))
	itemType := realDest.Type().Elem()
	itemIsPtr := itemType.Kind() == reflect.Ptr
	structType := itemType
	if itemIsPtr {
		structType = itemType.Elem()
	}
	keys := structColumnKeys(structType)
	columns := lo.Must1(rows.Columns())
	for rows.Next() {
		result.count++
		vp := reflect.New(structType)
		scanStructRow(rows, columns, keys, vp.Elem())
		if itemIsPtr {
			realDest.Set(reflect.Append(realDest, vp))
		} else {
			realDest.Set(reflect.Append(realDest, vp.Elem()))
		}
	}
	if err := rows.Err(); err != nil {
		panic(err)
	}
	return result
}

// sStruct (to struct) fills the struct that dest points to from the first
// row; later rows are only counted. dest is left untouched when there are no
// rows. If the value dest points to is itself a pointer to a struct, it is
// set to a newly allocated struct. It panics on column, scan, decode or
// iteration errors.
func sStruct(rows *sql.Rows, dest any) (result RowScan) {
	realDest := reflect.Indirect(reflect.ValueOf(dest))
	structType := realDest.Type()
	if structType.Kind() == reflect.Ptr {
		structType = structType.Elem()
	}
	keys := structColumnKeys(structType)
	columns := lo.Must1(rows.Columns())
	for rows.Next() {
		if result.count == 0 {
			vp := reflect.New(structType)
			scanStructRow(rows, columns, keys, vp.Elem())
			if realDest.Kind() == reflect.Ptr {
				realDest.Set(vp)
			} else {
				realDest.Set(vp.Elem())
			}
		}
		result.count++
	}
	if err := rows.Err(); err != nil {
		panic(err)
	}
	return result
}
// decodeColumnValue converts a value scanned into an *any into the form kept
// in result maps. A byte slice longer than two bytes that starts with '{' or
// '[' is decoded with json.Unmarshal when it parses. Any other byte slice
// becomes a string, including one that looks like JSON but fails to parse;
// such a value is kept rather than dropping the column. All other values,
// including nil for NULL, are returned unchanged.
func decodeColumnValue(raw any) any {
	b, isBytes := raw.([]uint8)
	if !isBytes {
		return raw
	}
	if len(b) > 2 && (b[0] == '{' || b[0] == '[') {
		var decoded any
		if err := json.Unmarshal(b, &decoded); err == nil {
			return decoded
		}
	}
	return string(b)
}

// scanMapRow scans the current row into a new map keyed by column name,
// passing each value through decodeColumnValue. It panics if rows.Scan fails.
func scanMapRow(rows *sql.Rows, columns []string) map[string]any {
	scanArgs := make([]any, len(columns))
	for i := range scanArgs {
		scanArgs[i] = new(any)
	}
	if err := rows.Scan(scanArgs...); err != nil {
		panic(err)
	}
	row := make(map[string]any, len(columns))
	for i, column := range columns {
		row[column] = decodeColumnValue(*scanArgs[i].(*any))
	}
	return row
}

// sMapSlice (to map slice) appends one map per row (see scanMapRow) to the
// slice that dest points to. It panics on column, scan or iteration errors.
func sMapSlice(rows *sql.Rows, dest any) (result RowScan) {
	columns, err := rows.Columns()
	if err != nil {
		panic(err)
	}
	realDest := reflect.Indirect(reflect.ValueOf(dest))
	for rows.Next() {
		result.count++
		element := scanMapRow(rows, columns)
		realDest.Set(reflect.Append(realDest, reflect.ValueOf(element)))
	}
	if err := rows.Err(); err != nil {
		panic(err)
	}
	return result
}

// sMap (to map) copies the first row into the map that dest points to,
// allocating the map first if it is nil; later rows are only counted. A NULL
// column is stored as the zero value of the map's element type (nil for
// map[string]any) instead of being removed from the map. It panics on column,
// scan or iteration errors.
func sMap(rows *sql.Rows, dest any) (result RowScan) {
	columns, err := rows.Columns()
	if err != nil {
		panic(err)
	}
	realDest := reflect.Indirect(reflect.ValueOf(dest))
	for rows.Next() {
		if result.count == 0 {
			if realDest.IsNil() {
				realDest.Set(reflect.MakeMap(realDest.Type()))
			}
			elemType := realDest.Type().Elem()
			for column, value := range scanMapRow(rows, columns) {
				rv := reflect.ValueOf(value)
				if !rv.IsValid() {
					// A zero reflect.Value would make SetMapIndex delete the key.
					rv = reflect.Zero(elemType)
				}
				realDest.SetMapIndex(reflect.ValueOf(column), rv)
			}
		}
		result.count++
	}
	if err := rows.Err(); err != nil {
		panic(err)
	}
	return result
}
// pluck, e.g. []int{}
//
// sValues appends the first column of every row to the slice that dest
// points to, letting rows.Scan convert each value to the slice's element
// type. Only single-column results are supported: the extra destinations
// stay nil, and rows.Scan rejects them. It panics on column, scan or
// iteration errors.
func sValues(rows *sql.Rows, dest any) (result RowScan) {
	columns, err := rows.Columns()
	if err != nil {
		panic(err)
	}
	realDest := reflect.Indirect(reflect.ValueOf(dest))
	itemType := realDest.Type().Elem()
	scanArgs := make([]any, len(columns))
	for rows.Next() {
		result.count++
		item := reflect.New(itemType)
		scanArgs[0] = item.Interface()
		if err := rows.Scan(scanArgs...); err != nil {
			panic(err)
		}
		// Append the element value itself. Re-boxing it via Interface() turns a
		// NULL scanned into []any into an invalid reflect.Value, which makes
		// reflect.Append panic.
		realDest.Set(reflect.Append(realDest, item.Elem()))
	}
	if err := rows.Err(); err != nil {
		panic(err)
	}
	return result
}