diff --git a/scan.go b/scan.go index 407f6b8..0169913 100644 --- a/scan.go +++ b/scan.go @@ -8,7 +8,7 @@ import ( ) type RowScan struct { - Count int + count int } func (RowScan) LastInsertId() (int64, error) { @@ -16,7 +16,7 @@ func (RowScan) LastInsertId() (int64, error) { } func (v RowScan) RowsAffected() (int64, error) { - return int64(v.Count), nil + return int64(v.count), nil } // 入口 @@ -46,7 +46,7 @@ func ScanResult(rows *sql.Rows, dest any) (result RowScan) { return sMap(rows, dest) } else { for rows.Next() { - result.Count++ + result.count++ err := rows.Scan(dest) if err != nil { panic(err.Error()) @@ -88,8 +88,12 @@ func sStructSlice(rows *sql.Rows, dest any) (result RowScan) { keys[kFiled] = i } + // 提取 json 字段, json 字段, 不能直接赋值, + // 先赋值给 uint8, 再通过 json.Unmarshal 赋值给字段 + var jFields map[int]reflect.Value + for rows.Next() { - result.Count++ + result.count++ var v, vp reflect.Value if itemIsPtr { vp = reflect.New(sliceItem.Elem()) @@ -101,10 +105,18 @@ func sStructSlice(rows *sql.Rows, dest any) (result RowScan) { scanArgs := make([]any, len(columns)) + jFields = map[int]reflect.Value{} + // 初始化 map 字段 for i, k := range columns { if f, ok := keys[k]; ok { - scanArgs[i] = v.Field(f).Addr().Interface() + switch v.Field(f).Kind() { + case reflect.Slice, reflect.Struct, reflect.Map: + scanArgs[i] = &[]uint8{} + jFields[i] = v.Field(f) + default: + scanArgs[i] = v.Field(f).Addr().Interface() + } } else { scanArgs[i] = new(any) } @@ -114,6 +126,22 @@ func sStructSlice(rows *sql.Rows, dest any) (result RowScan) { panic(err) } + // json 字符串 反序列化 + for i, rv := range jFields { + v := rv.Interface() + data := *scanArgs[i].(*[]uint8) + + if len(data) == 0 { + continue + } + + if err := json.Unmarshal(data, &v); err != nil { + panic(err) + } + + rv.Set(reflect.ValueOf(v)) + } + if itemIsPtr { realDest.Set(reflect.Append(realDest, vp)) } else { @@ -141,14 +169,26 @@ func sStruct(rows *sql.Rows, dest any) (result RowScan) { keys[kFiled] = i } + // 提取 json 字段, json 字段, 不能直接赋值, + // 先赋值给 uint8, 再通过 json.Unmarshal 赋值给字段 + var jFields map[int]reflect.Value + for rows.Next() { - if result.Count == 0 { + if result.count == 0 { scanArgs := make([]any, len(columns)) + jFields = map[int]reflect.Value{} + // 初始化 map 字段 for i, k := range columns { if f, ok := keys[k]; ok { - scanArgs[i] = v.Field(f).Addr().Interface() + switch v.Field(f).Kind() { + case reflect.Slice, reflect.Struct, reflect.Map: + scanArgs[i] = &[]uint8{} + jFields[i] = v.Field(f) + default: + scanArgs[i] = v.Field(f).Addr().Interface() + } } else { scanArgs[i] = new(any) } @@ -158,6 +198,22 @@ func sStruct(rows *sql.Rows, dest any) (result RowScan) { panic(err) } + // json 字符串 反序列化 + for i, rv := range jFields { + v := rv.Interface() + data := *scanArgs[i].(*[]uint8) + + if len(data) == 0 { + continue + } + + if err := json.Unmarshal(data, &v); err != nil { + panic(err) + } + + rv.Set(reflect.ValueOf(v)) + } + if realDest.Kind() == reflect.Ptr { realDest.Set(vp) } else { @@ -165,7 +221,7 @@ func sStruct(rows *sql.Rows, dest any) (result RowScan) { } } - result.Count++ + result.count++ } return result @@ -180,7 +236,7 @@ func sMapSlice(rows *sql.Rows, dest any) (result RowScan) { scanArgs := make([]any, len(columns)) element := make(map[string]any) - result.Count++ + result.count++ // 初始化 map 字段 for i := 0; i < len(columns); i++ { @@ -197,7 +253,7 @@ func sMapSlice(rows *sql.Rows, dest any) (result RowScan) { switch v := (*scanArgs[i].(*any)).(type) { case []uint8: data := string(v) - if len(data) > 2 && (data[0:1] == "{" || data[0:1] == "[") { + if len(v) > 2 && (data[0:1] == "{" || data[0:1] == "[") { jData := new(any) if err := json.Unmarshal(v, jData); err == nil { element[column] = *jData @@ -223,7 +279,7 @@ func sMap(rows *sql.Rows, dest any) (result RowScan) { realDest := reflect.Indirect(reflect.ValueOf(dest)) for rows.Next() { - if result.Count == 0 { + if result.count == 0 { scanArgs := make([]interface{}, len(columns)) // 初始化 map 字段 for i := 0; i < len(columns); i++ { @@ -255,7 +311,7 @@ func sMap(rows *sql.Rows, dest any) (result RowScan) { } } - result.Count++ + result.count++ } return result @@ -267,7 +323,7 @@ func sValues(rows *sql.Rows, dest any) (result RowScan) { realDest := reflect.Indirect(reflect.ValueOf(dest)) scanArgs := make([]interface{}, len(columns)) for rows.Next() { - result.Count++ + result.count++ scanArgs[0] = reflect.New(realDest.Type().Elem()).Interface() err := rows.Scan(scanArgs...) if err != nil { diff --git a/scan_test.go b/scan_test.go index 902f726..ea3d0ab 100644 --- a/scan_test.go +++ b/scan_test.go @@ -4,11 +4,11 @@ import ( "testing" ) -var oDB *DB +var myConn *Connection func init() { - oDB = Open(map[string]DBConfig{ - "default": { + database := Open(map[string]DBConfig{ + "mysql": { Driver: "mysql", Host: "localhost", Port: "3366", @@ -20,12 +20,14 @@ func init() { // ParseTime: true, }, }) + + myConn = database.Connection("mysql") } func TestScanMapSlice(t *testing.T) { dest := []map[string]any{} - oDB.Select("select * from fsm", []any{}, &dest) + myConn.Select("select * from fsm", []any{}, &dest) t.Log(dest) } @@ -34,7 +36,7 @@ func TestScanValues(t *testing.T) { dest := []int{} - oDB.Select("select id from fsm", []any{}, &dest) + myConn.Select("select 1 as id", []any{}, &dest) t.Log(dest) } @@ -43,19 +45,31 @@ func TestScanMap(t *testing.T) { dest := map[string]any{} - oDB.Select("select * from fsm where id = 3", []any{}, &dest) + myConn.Select("select 0 as id, 'test' as code", []any{}, &dest) t.Log(dest) } func TestScanStruct(t *testing.T) { - dest := struct { Id int `db:"id"` Code string `db:"code"` }{} - oDB.Select("select * from fsm", []any{}, &dest) + myConn.Select(`select 0 as id, 'test' as code`, []any{}, &dest) + + t.Log(dest) +} + +func TestScanJsonFieldStruct(t *testing.T) { + + dest := struct { + Id int `db:"id"` + Code string `db:"code"` + Map map[string]any `db:"map"` + }{} + + myConn.Select(`select 0 as id, 'test' as code, '{"a":1}' as map`, []any{}, &dest) t.Log(dest) } @@ -67,7 +81,20 @@ func TestScanStructSlice(t *testing.T) { Code string `db:"code"` }{} - oDB.Select("select * from fsm", []any{}, &dest) + myConn.Select("select * from ((select 0 as id, 'test' as code) union (select 1 as id, 'test1' as code)) as t", []any{}, &dest) + + t.Log(dest) +} + +func TestScanJsonFieldStructSlice(t *testing.T) { + + dest := []struct { + Id int `db:"id"` + Code string `db:"code"` + Map map[string]any `db:"map"` + }{} + + myConn.Select(`select * from ((select 0 as id, 'test' as code, '{"a":1}' as map) union (select 1 as id, 'test1' as code, '{"a":2}' as map)) as t`, []any{}, &dest) t.Log(dest) }