package db import ( "database/sql" "encoding/json" "errors" "fmt" "reflect" "github.com/samber/lo" ) type RowScan struct { count int } func (RowScan) LastInsertId() (int64, error) { return 0, errors.New("no insert in select") } func (v RowScan) RowsAffected() (int64, error) { return int64(v.count), nil } // 入口 func ScanResult(rows *sql.Rows, dest any) (result RowScan) { realDest := reflect.Indirect(reflect.ValueOf(dest)) if reflect.TypeOf(dest).Kind() != reflect.Ptr { panic("dest 请传入指针类型") } if realDest.Kind() == reflect.Slice { slice := realDest.Type() sliceItem := slice.Elem() if sliceItem.Kind() == reflect.Map { return sMapSlice(rows, dest) } else if sliceItem.Kind() == reflect.Struct { return sStructSlice(rows, dest) } else if sliceItem.Kind() == reflect.Ptr { if sliceItem.Elem().Kind() == reflect.Struct { return sStructSlice(rows, dest) } } return sValues(rows, dest) } else if realDest.Kind() == reflect.Struct { return sStruct(rows, dest) } else if realDest.Kind() == reflect.Map { return sMap(rows, dest) } else { for rows.Next() { result.count++ err := rows.Scan(dest) if err != nil { panic(err.Error()) } } } return result } // to struct slice func sStructSlice(rows *sql.Rows, dest any) (result RowScan) { realDest := reflect.Indirect(reflect.ValueOf(dest)) slice := realDest.Type() sliceItem := slice.Elem() itemIsPtr := sliceItem.Kind() == reflect.Ptr keys := map[string]int{} columns, _ := rows.Columns() lenField := 0 if itemIsPtr { lenField = sliceItem.Elem().NumField() } else { lenField = sliceItem.NumField() } // 缓存 dest 结构 // key is field name, value is field index for i := 0; i < lenField; i++ { var tField reflect.StructField if itemIsPtr { tField = sliceItem.Elem().Field(i) } else { tField = sliceItem.Field(i) } kFiled := tField.Tag.Get("db") if kFiled == "" { kFiled = tField.Name } keys[kFiled] = i } // 提取 json 字段, json 字段, 不能直接赋值, // 先赋值给 uint8, 再通过 json.Unmarshal 赋值给字段 var jFields map[int]reflect.Value for rows.Next() { result.count++ var v, vp reflect.Value if itemIsPtr { vp = reflect.New(sliceItem.Elem()) v = reflect.Indirect(vp) } else { vp = reflect.New(sliceItem) v = reflect.Indirect(vp) } scanArgs := make([]any, len(columns)) jFields = map[int]reflect.Value{} // 初始化 map 字段 for i, k := range columns { if f, ok := keys[k]; ok { switch v.Field(f).Kind() { case reflect.Slice, reflect.Struct, reflect.Map: scanArgs[i] = &[]uint8{} jFields[i] = v.Field(f) case reflect.String: scanArgs[i] = &sql.NullString{} jFields[i] = v.Field(f) default: scanArgs[i] = v.Field(f).Addr().Interface() } } else { scanArgs[i] = new(any) } } if err := rows.Scan(scanArgs...); err != nil { panic(err) } // json 字符串 反序列化 for i, rv := range jFields { switch rv.Kind() { case reflect.String: rv.SetString(scanArgs[i].(*sql.NullString).String) default: if data := *scanArgs[i].(*[]uint8); len(data) > 0 { v := rv.Addr().Interface() if err := json.Unmarshal(data, &v); err != nil { panic(fmt.Sprintf("[%s:%s:%s] => %s", lo.Must1(lo.FindKey(keys, i)), rv.Kind(), data, err)) } } } } if itemIsPtr { realDest.Set(reflect.Append(realDest, vp)) } else { realDest.Set(reflect.Append(realDest, v)) } } return result } // to struct func sStruct(rows *sql.Rows, dest any) (result RowScan) { columns, _ := rows.Columns() realDest := reflect.Indirect(reflect.ValueOf(dest)) vp := reflect.New(realDest.Type()) v := reflect.Indirect(vp) keys := map[string]int{} for i := 0; i < v.NumField(); i++ { tField := v.Type().Field(i) kFiled := tField.Tag.Get("db") if kFiled == "" { kFiled = tField.Name } keys[kFiled] = i } // 提取 json 字段, json 字段, 不能直接赋值, // 先赋值给 uint8, 再通过 json.Unmarshal 赋值给字段 var jFields map[int]reflect.Value for rows.Next() { if result.count == 0 { scanArgs := make([]any, len(columns)) jFields = map[int]reflect.Value{} // 初始化 map 字段 for i, k := range columns { if f, ok := keys[k]; ok { switch v.Field(f).Kind() { case reflect.Slice, reflect.Struct, reflect.Map: scanArgs[i] = &[]uint8{} jFields[i] = v.Field(f) case reflect.String: scanArgs[i] = &sql.NullString{} jFields[i] = v.Field(f) default: scanArgs[i] = v.Field(f).Addr().Interface() } } else { scanArgs[i] = new(any) } } if err := rows.Scan(scanArgs...); err != nil { panic(err) } // json 字符串 反序列化 for i, rv := range jFields { switch rv.Kind() { case reflect.String: rv.SetString(scanArgs[i].(*sql.NullString).String) default: if data := *scanArgs[i].(*[]uint8); len(data) > 0 { v := rv.Addr().Interface() if err := json.Unmarshal(data, &v); err != nil { panic(fmt.Sprintf("[%s:%s:%s] => %s", lo.Must1(lo.FindKey(keys, i)), rv.Kind(), data, err)) } } } } if realDest.Kind() == reflect.Ptr { realDest.Set(vp) } else { realDest.Set(v) } } result.count++ } return result } // to map slice func sMapSlice(rows *sql.Rows, dest any) (result RowScan) { columns, _ := rows.Columns() realDest := reflect.Indirect(reflect.ValueOf(dest)) for rows.Next() { scanArgs := make([]any, len(columns)) element := make(map[string]any) result.count++ // 初始化 map 字段 for i := 0; i < len(columns); i++ { scanArgs[i] = new(any) } // 填充 map 字段 if err := rows.Scan(scanArgs...); err != nil { panic(err.Error()) } // 将 scan 后的 map 填充到结果集 for i, column := range columns { switch v := (*scanArgs[i].(*any)).(type) { case []uint8: data := string(v) if len(v) > 2 && (data[0:1] == "{" || data[0:1] == "[") { jData := new(any) if err := json.Unmarshal(v, jData); err == nil { element[column] = *jData } } else { element[column] = data } default: // int64 element[column] = v } } realDest.Set(reflect.Append(realDest, reflect.ValueOf(element))) } return result } // to map func sMap(rows *sql.Rows, dest any) (result RowScan) { columns, _ := rows.Columns() realDest := reflect.Indirect(reflect.ValueOf(dest)) for rows.Next() { if result.count == 0 { scanArgs := make([]interface{}, len(columns)) // 初始化 map 字段 for i := 0; i < len(columns); i++ { scanArgs[i] = new(any) } if err := rows.Scan(scanArgs...); err != nil { panic(err) } for i, column := range columns { var rv reflect.Value switch v := (*scanArgs[i].(*any)).(type) { case []uint8: data := string(v) if len(data) > 2 && (data[0:1] == "{" || data[0:1] == "[") { jData := new(any) if err := json.Unmarshal(v, jData); err == nil { rv = reflect.ValueOf(jData).Elem() } } else { rv = reflect.ValueOf(data) } default: // int64 rv = reflect.ValueOf(v) } realDest.SetMapIndex(reflect.ValueOf(column), rv) } } result.count++ } return result } // pluck, 如 []int{} func sValues(rows *sql.Rows, dest any) (result RowScan) { columns, _ := rows.Columns() realDest := reflect.Indirect(reflect.ValueOf(dest)) scanArgs := make([]interface{}, len(columns)) for rows.Next() { result.count++ scanArgs[0] = reflect.New(realDest.Type().Elem()).Interface() err := rows.Scan(scanArgs...) if err != nil { panic(err) } realDest.Set(reflect.Append(realDest, reflect.ValueOf(reflect.ValueOf(scanArgs[0]).Elem().Interface()))) } return result }