[fix] 处理查询结果字段类型为 slice,struct时报错
This commit is contained in:
		
							parent
							
								
									cd33e09d69
								
							
						
					
					
						commit
						bf3032103e
					
				
							
								
								
									
										78
									
								
								scan.go
									
									
									
									
									
								
							
							
						
						
									
										78
									
								
								scan.go
									
									
									
									
									
								
							@ -8,7 +8,7 @@ import (
 | 
				
			|||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type RowScan struct {
 | 
					type RowScan struct {
 | 
				
			||||||
	Count int
 | 
						count int
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (RowScan) LastInsertId() (int64, error) {
 | 
					func (RowScan) LastInsertId() (int64, error) {
 | 
				
			||||||
@ -16,7 +16,7 @@ func (RowScan) LastInsertId() (int64, error) {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (v RowScan) RowsAffected() (int64, error) {
 | 
					func (v RowScan) RowsAffected() (int64, error) {
 | 
				
			||||||
	return int64(v.Count), nil
 | 
						return int64(v.count), nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// 入口
 | 
					// 入口
 | 
				
			||||||
@ -46,7 +46,7 @@ func ScanResult(rows *sql.Rows, dest any) (result RowScan) {
 | 
				
			|||||||
		return sMap(rows, dest)
 | 
							return sMap(rows, dest)
 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
		for rows.Next() {
 | 
							for rows.Next() {
 | 
				
			||||||
			result.Count++
 | 
								result.count++
 | 
				
			||||||
			err := rows.Scan(dest)
 | 
								err := rows.Scan(dest)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				panic(err.Error())
 | 
									panic(err.Error())
 | 
				
			||||||
@ -88,8 +88,12 @@ func sStructSlice(rows *sql.Rows, dest any) (result RowScan) {
 | 
				
			|||||||
		keys[kFiled] = i
 | 
							keys[kFiled] = i
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 提取 json 字段, json 字段, 不能直接赋值,
 | 
				
			||||||
 | 
						// 先赋值给 uint8, 再通过 json.Unmarshal 赋值给字段
 | 
				
			||||||
 | 
						var jFields map[int]reflect.Value
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for rows.Next() {
 | 
						for rows.Next() {
 | 
				
			||||||
		result.Count++
 | 
							result.count++
 | 
				
			||||||
		var v, vp reflect.Value
 | 
							var v, vp reflect.Value
 | 
				
			||||||
		if itemIsPtr {
 | 
							if itemIsPtr {
 | 
				
			||||||
			vp = reflect.New(sliceItem.Elem())
 | 
								vp = reflect.New(sliceItem.Elem())
 | 
				
			||||||
@ -101,10 +105,18 @@ func sStructSlice(rows *sql.Rows, dest any) (result RowScan) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
		scanArgs := make([]any, len(columns))
 | 
							scanArgs := make([]any, len(columns))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							jFields = map[int]reflect.Value{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// 初始化 map 字段
 | 
							// 初始化 map 字段
 | 
				
			||||||
		for i, k := range columns {
 | 
							for i, k := range columns {
 | 
				
			||||||
			if f, ok := keys[k]; ok {
 | 
								if f, ok := keys[k]; ok {
 | 
				
			||||||
 | 
									switch v.Field(f).Kind() {
 | 
				
			||||||
 | 
									case reflect.Slice, reflect.Struct, reflect.Map:
 | 
				
			||||||
 | 
										scanArgs[i] = &[]uint8{}
 | 
				
			||||||
 | 
										jFields[i] = v.Field(f)
 | 
				
			||||||
 | 
									default:
 | 
				
			||||||
					scanArgs[i] = v.Field(f).Addr().Interface()
 | 
										scanArgs[i] = v.Field(f).Addr().Interface()
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
			} else {
 | 
								} else {
 | 
				
			||||||
				scanArgs[i] = new(any)
 | 
									scanArgs[i] = new(any)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
@ -114,6 +126,22 @@ func sStructSlice(rows *sql.Rows, dest any) (result RowScan) {
 | 
				
			|||||||
			panic(err)
 | 
								panic(err)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// json 字符串 反序列化
 | 
				
			||||||
 | 
							for i, rv := range jFields {
 | 
				
			||||||
 | 
								v := rv.Interface()
 | 
				
			||||||
 | 
								data := *scanArgs[i].(*[]uint8)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								if len(data) == 0 {
 | 
				
			||||||
 | 
									continue
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								if err := json.Unmarshal(data, &v); err != nil {
 | 
				
			||||||
 | 
									panic(err)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								rv.Set(reflect.ValueOf(v))
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if itemIsPtr {
 | 
							if itemIsPtr {
 | 
				
			||||||
			realDest.Set(reflect.Append(realDest, vp))
 | 
								realDest.Set(reflect.Append(realDest, vp))
 | 
				
			||||||
		} else {
 | 
							} else {
 | 
				
			||||||
@ -141,14 +169,26 @@ func sStruct(rows *sql.Rows, dest any) (result RowScan) {
 | 
				
			|||||||
		keys[kFiled] = i
 | 
							keys[kFiled] = i
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 提取 json 字段, json 字段, 不能直接赋值,
 | 
				
			||||||
 | 
						// 先赋值给 uint8, 再通过 json.Unmarshal 赋值给字段
 | 
				
			||||||
 | 
						var jFields map[int]reflect.Value
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for rows.Next() {
 | 
						for rows.Next() {
 | 
				
			||||||
		if result.Count == 0 {
 | 
							if result.count == 0 {
 | 
				
			||||||
			scanArgs := make([]any, len(columns))
 | 
								scanArgs := make([]any, len(columns))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								jFields = map[int]reflect.Value{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			// 初始化 map 字段
 | 
								// 初始化 map 字段
 | 
				
			||||||
			for i, k := range columns {
 | 
								for i, k := range columns {
 | 
				
			||||||
				if f, ok := keys[k]; ok {
 | 
									if f, ok := keys[k]; ok {
 | 
				
			||||||
 | 
										switch v.Field(f).Kind() {
 | 
				
			||||||
 | 
										case reflect.Slice, reflect.Struct, reflect.Map:
 | 
				
			||||||
 | 
											scanArgs[i] = &[]uint8{}
 | 
				
			||||||
 | 
											jFields[i] = v.Field(f)
 | 
				
			||||||
 | 
										default:
 | 
				
			||||||
						scanArgs[i] = v.Field(f).Addr().Interface()
 | 
											scanArgs[i] = v.Field(f).Addr().Interface()
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
				} else {
 | 
									} else {
 | 
				
			||||||
					scanArgs[i] = new(any)
 | 
										scanArgs[i] = new(any)
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
@ -158,6 +198,22 @@ func sStruct(rows *sql.Rows, dest any) (result RowScan) {
 | 
				
			|||||||
				panic(err)
 | 
									panic(err)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// json 字符串 反序列化
 | 
				
			||||||
 | 
								for i, rv := range jFields {
 | 
				
			||||||
 | 
									v := rv.Interface()
 | 
				
			||||||
 | 
									data := *scanArgs[i].(*[]uint8)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									if len(data) == 0 {
 | 
				
			||||||
 | 
										continue
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									if err := json.Unmarshal(data, &v); err != nil {
 | 
				
			||||||
 | 
										panic(err)
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									rv.Set(reflect.ValueOf(v))
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			if realDest.Kind() == reflect.Ptr {
 | 
								if realDest.Kind() == reflect.Ptr {
 | 
				
			||||||
				realDest.Set(vp)
 | 
									realDest.Set(vp)
 | 
				
			||||||
			} else {
 | 
								} else {
 | 
				
			||||||
@ -165,7 +221,7 @@ func sStruct(rows *sql.Rows, dest any) (result RowScan) {
 | 
				
			|||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		result.Count++
 | 
							result.count++
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return result
 | 
						return result
 | 
				
			||||||
@ -180,7 +236,7 @@ func sMapSlice(rows *sql.Rows, dest any) (result RowScan) {
 | 
				
			|||||||
		scanArgs := make([]any, len(columns))
 | 
							scanArgs := make([]any, len(columns))
 | 
				
			||||||
		element := make(map[string]any)
 | 
							element := make(map[string]any)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		result.Count++
 | 
							result.count++
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// 初始化 map 字段
 | 
							// 初始化 map 字段
 | 
				
			||||||
		for i := 0; i < len(columns); i++ {
 | 
							for i := 0; i < len(columns); i++ {
 | 
				
			||||||
@ -197,7 +253,7 @@ func sMapSlice(rows *sql.Rows, dest any) (result RowScan) {
 | 
				
			|||||||
			switch v := (*scanArgs[i].(*any)).(type) {
 | 
								switch v := (*scanArgs[i].(*any)).(type) {
 | 
				
			||||||
			case []uint8:
 | 
								case []uint8:
 | 
				
			||||||
				data := string(v)
 | 
									data := string(v)
 | 
				
			||||||
				if len(data) > 2 && (data[0:1] == "{" || data[0:1] == "[") {
 | 
									if len(v) > 2 && (data[0:1] == "{" || data[0:1] == "[") {
 | 
				
			||||||
					jData := new(any)
 | 
										jData := new(any)
 | 
				
			||||||
					if err := json.Unmarshal(v, jData); err == nil {
 | 
										if err := json.Unmarshal(v, jData); err == nil {
 | 
				
			||||||
						element[column] = *jData
 | 
											element[column] = *jData
 | 
				
			||||||
@ -223,7 +279,7 @@ func sMap(rows *sql.Rows, dest any) (result RowScan) {
 | 
				
			|||||||
	realDest := reflect.Indirect(reflect.ValueOf(dest))
 | 
						realDest := reflect.Indirect(reflect.ValueOf(dest))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for rows.Next() {
 | 
						for rows.Next() {
 | 
				
			||||||
		if result.Count == 0 {
 | 
							if result.count == 0 {
 | 
				
			||||||
			scanArgs := make([]interface{}, len(columns))
 | 
								scanArgs := make([]interface{}, len(columns))
 | 
				
			||||||
			// 初始化 map 字段
 | 
								// 初始化 map 字段
 | 
				
			||||||
			for i := 0; i < len(columns); i++ {
 | 
								for i := 0; i < len(columns); i++ {
 | 
				
			||||||
@ -255,7 +311,7 @@ func sMap(rows *sql.Rows, dest any) (result RowScan) {
 | 
				
			|||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		result.Count++
 | 
							result.count++
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return result
 | 
						return result
 | 
				
			||||||
@ -267,7 +323,7 @@ func sValues(rows *sql.Rows, dest any) (result RowScan) {
 | 
				
			|||||||
	realDest := reflect.Indirect(reflect.ValueOf(dest))
 | 
						realDest := reflect.Indirect(reflect.ValueOf(dest))
 | 
				
			||||||
	scanArgs := make([]interface{}, len(columns))
 | 
						scanArgs := make([]interface{}, len(columns))
 | 
				
			||||||
	for rows.Next() {
 | 
						for rows.Next() {
 | 
				
			||||||
		result.Count++
 | 
							result.count++
 | 
				
			||||||
		scanArgs[0] = reflect.New(realDest.Type().Elem()).Interface()
 | 
							scanArgs[0] = reflect.New(realDest.Type().Elem()).Interface()
 | 
				
			||||||
		err := rows.Scan(scanArgs...)
 | 
							err := rows.Scan(scanArgs...)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										45
									
								
								scan_test.go
									
									
									
									
									
								
							
							
						
						
									
										45
									
								
								scan_test.go
									
									
									
									
									
								
							@ -4,11 +4,11 @@ import (
 | 
				
			|||||||
	"testing"
 | 
						"testing"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var oDB *DB
 | 
					var myConn *Connection
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func init() {
 | 
					func init() {
 | 
				
			||||||
	oDB = Open(map[string]DBConfig{
 | 
						database := Open(map[string]DBConfig{
 | 
				
			||||||
		"default": {
 | 
							"mysql": {
 | 
				
			||||||
			Driver:          "mysql",
 | 
								Driver:          "mysql",
 | 
				
			||||||
			Host:            "localhost",
 | 
								Host:            "localhost",
 | 
				
			||||||
			Port:            "3366",
 | 
								Port:            "3366",
 | 
				
			||||||
@ -20,12 +20,14 @@ func init() {
 | 
				
			|||||||
			// ParseTime:       true,
 | 
								// ParseTime:       true,
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						myConn = database.Connection("mysql")
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestScanMapSlice(t *testing.T) {
 | 
					func TestScanMapSlice(t *testing.T) {
 | 
				
			||||||
	dest := []map[string]any{}
 | 
						dest := []map[string]any{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	oDB.Select("select * from fsm", []any{}, &dest)
 | 
						myConn.Select("select * from fsm", []any{}, &dest)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	t.Log(dest)
 | 
						t.Log(dest)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@ -34,7 +36,7 @@ func TestScanValues(t *testing.T) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	dest := []int{}
 | 
						dest := []int{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	oDB.Select("select id from fsm", []any{}, &dest)
 | 
						myConn.Select("select 1 as id", []any{}, &dest)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	t.Log(dest)
 | 
						t.Log(dest)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@ -43,19 +45,31 @@ func TestScanMap(t *testing.T) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	dest := map[string]any{}
 | 
						dest := map[string]any{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	oDB.Select("select * from fsm where id = 3", []any{}, &dest)
 | 
						myConn.Select("select 0 as id, 'test' as code", []any{}, &dest)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	t.Log(dest)
 | 
						t.Log(dest)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestScanStruct(t *testing.T) {
 | 
					func TestScanStruct(t *testing.T) {
 | 
				
			||||||
 | 
					 | 
				
			||||||
	dest := struct {
 | 
						dest := struct {
 | 
				
			||||||
		Id   int    `db:"id"`
 | 
							Id   int    `db:"id"`
 | 
				
			||||||
		Code string `db:"code"`
 | 
							Code string `db:"code"`
 | 
				
			||||||
	}{}
 | 
						}{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	oDB.Select("select * from fsm", []any{}, &dest)
 | 
						myConn.Select(`select 0 as id, 'test' as code`, []any{}, &dest)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						t.Log(dest)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestScanJsonFieldStruct(t *testing.T) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						dest := struct {
 | 
				
			||||||
 | 
							Id   int            `db:"id"`
 | 
				
			||||||
 | 
							Code string         `db:"code"`
 | 
				
			||||||
 | 
							Map  map[string]any `db:"map"`
 | 
				
			||||||
 | 
						}{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						myConn.Select(`select 0 as id, 'test' as code, '{"a":1}' as map`, []any{}, &dest)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	t.Log(dest)
 | 
						t.Log(dest)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@ -67,7 +81,20 @@ func TestScanStructSlice(t *testing.T) {
 | 
				
			|||||||
		Code string `db:"code"`
 | 
							Code string `db:"code"`
 | 
				
			||||||
	}{}
 | 
						}{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	oDB.Select("select * from fsm", []any{}, &dest)
 | 
						myConn.Select("select * from ((select 0 as id, 'test' as code) union (select 1 as id, 'test1' as code)) as t", []any{}, &dest)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						t.Log(dest)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestScanJsonFieldStructSlice(t *testing.T) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						dest := []struct {
 | 
				
			||||||
 | 
							Id   int            `db:"id"`
 | 
				
			||||||
 | 
							Code string         `db:"code"`
 | 
				
			||||||
 | 
							Map  map[string]any `db:"map"`
 | 
				
			||||||
 | 
						}{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						myConn.Select(`select * from ((select 0 as id, 'test' as code, '{"a":1}' as map) union (select 1 as id, 'test1' as code, '{"a":2}' as map)) as t`, []any{}, &dest)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	t.Log(dest)
 | 
						t.Log(dest)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
		Reference in New Issue
	
	Block a user