package db import ( "database/sql" "encoding/json" "errors" "reflect" ) type RowScan struct { Count int } func (RowScan) LastInsertId() (int64, error) { return 0, errors.New("no insert in select") } func (v RowScan) RowsAffected() (int64, error) { return int64(v.Count), nil } // 入口 func ScanResult(rows *sql.Rows, dest any) (result RowScan) { realDest := reflect.Indirect(reflect.ValueOf(dest)) if reflect.TypeOf(dest).Kind() != reflect.Ptr { panic("dest 请传入指针类型") } if realDest.Kind() == reflect.Slice { slice := realDest.Type() sliceItem := slice.Elem() if sliceItem.Kind() == reflect.Map { return sMapSlice(rows, dest) } else if sliceItem.Kind() == reflect.Struct { return sStructSlice(rows, dest) } else if sliceItem.Kind() == reflect.Ptr { if sliceItem.Elem().Kind() == reflect.Struct { return sStructSlice(rows, dest) } } return sValues(rows, dest) } else if realDest.Kind() == reflect.Struct { return sStruct(rows, dest) } else if realDest.Kind() == reflect.Map { return sMap(rows, dest) } else { for rows.Next() { result.Count++ err := rows.Scan(dest) if err != nil { panic(err.Error()) } } } return result } // to struct slice func sStructSlice(rows *sql.Rows, dest any) (result RowScan) { realDest := reflect.Indirect(reflect.ValueOf(dest)) slice := realDest.Type() sliceItem := slice.Elem() itemIsPtr := sliceItem.Kind() == reflect.Ptr keys := map[string]int{} columns, _ := rows.Columns() lenField := 0 if itemIsPtr { lenField = sliceItem.Elem().NumField() } else { lenField = sliceItem.NumField() } // 缓存 dest 结构 // key is field name, value is field index for i := 0; i < lenField; i++ { var tField reflect.StructField if itemIsPtr { tField = sliceItem.Elem().Field(i) } else { tField = sliceItem.Field(i) } kFiled := tField.Tag.Get("db") if kFiled == "" { kFiled = tField.Name } keys[kFiled] = i } for rows.Next() { result.Count++ var v, vp reflect.Value if itemIsPtr { vp = reflect.New(sliceItem.Elem()) v = reflect.Indirect(vp) } else { vp = reflect.New(sliceItem) v = reflect.Indirect(vp) } scanArgs := make([]any, len(columns)) // 初始化 map 字段 for i, k := range columns { if f, ok := keys[k]; ok { scanArgs[i] = v.Field(f).Addr().Interface() } else { scanArgs[i] = new(any) } } if err := rows.Scan(scanArgs...); err != nil { panic(err) } if itemIsPtr { realDest.Set(reflect.Append(realDest, vp)) } else { realDest.Set(reflect.Append(realDest, v)) } } return result } // to struct func sStruct(rows *sql.Rows, dest any) (result RowScan) { columns, _ := rows.Columns() realDest := reflect.Indirect(reflect.ValueOf(dest)) vp := reflect.New(realDest.Type()) v := reflect.Indirect(vp) keys := map[string]int{} for i := 0; i < v.NumField(); i++ { tField := v.Type().Field(i) kFiled := tField.Tag.Get("db") if kFiled == "" { kFiled = tField.Name } keys[kFiled] = i } for rows.Next() { if result.Count == 0 { scanArgs := make([]any, len(columns)) // 初始化 map 字段 for i, k := range columns { if f, ok := keys[k]; ok { scanArgs[i] = v.Field(f).Addr().Interface() } else { scanArgs[i] = new(any) } } if err := rows.Scan(scanArgs...); err != nil { panic(err) } if realDest.Kind() == reflect.Ptr { realDest.Set(vp) } else { realDest.Set(v) } } result.Count++ } return result } // to map slice func sMapSlice(rows *sql.Rows, dest any) (result RowScan) { columns, _ := rows.Columns() realDest := reflect.Indirect(reflect.ValueOf(dest)) for rows.Next() { scanArgs := make([]any, len(columns)) element := make(map[string]any) result.Count++ // 初始化 map 字段 for i := 0; i < len(columns); i++ { scanArgs[i] = new(any) } // 填充 map 字段 if err := rows.Scan(scanArgs...); err != nil { panic(err.Error()) } // 将 scan 后的 map 填充到结果集 for i, column := range columns { switch v := (*scanArgs[i].(*any)).(type) { case []uint8: data := string(v) if len(data) > 2 && (data[0:1] == "{" || data[0:1] == "[") { jData := new(any) if err := json.Unmarshal(v, jData); err == nil { element[column] = *jData } } else { element[column] = data } default: // int64 element[column] = v } } realDest.Set(reflect.Append(realDest, reflect.ValueOf(element))) } return result } // to map func sMap(rows *sql.Rows, dest any) (result RowScan) { columns, _ := rows.Columns() realDest := reflect.Indirect(reflect.ValueOf(dest)) for rows.Next() { if result.Count == 0 { scanArgs := make([]interface{}, len(columns)) // 初始化 map 字段 for i := 0; i < len(columns); i++ { scanArgs[i] = new(any) } if err := rows.Scan(scanArgs...); err != nil { panic(err) } for i, column := range columns { var rv reflect.Value switch v := (*scanArgs[i].(*any)).(type) { case []uint8: data := string(v) if len(data) > 2 && (data[0:1] == "{" || data[0:1] == "[") { jData := new(any) if err := json.Unmarshal(v, jData); err == nil { rv = reflect.ValueOf(jData).Elem() } } else { rv = reflect.ValueOf(data) } default: // int64 rv = reflect.ValueOf(v) } realDest.SetMapIndex(reflect.ValueOf(column), rv) } } result.Count++ } return result } // pluck, 如 []int{} func sValues(rows *sql.Rows, dest any) (result RowScan) { columns, _ := rows.Columns() realDest := reflect.Indirect(reflect.ValueOf(dest)) scanArgs := make([]interface{}, len(columns)) for rows.Next() { result.Count++ scanArgs[0] = reflect.New(realDest.Type().Elem()).Interface() err := rows.Scan(scanArgs...) if err != nil { panic(err) } realDest.Set(reflect.Append(realDest, reflect.ValueOf(reflect.ValueOf(scanArgs[0]).Elem().Interface()))) } return result }