// Package engine evaluates conditional cases against in-memory data by
// projecting the data into SQL sub-selects and running one query on a
// private in-memory SQLite connection.
package engine

import (
	"encoding/json"
	"fmt"
	"reflect"
	"strings"

	"git.fsdpf.net/go/contracts"
	"git.fsdpf.net/go/contracts/support"
	"git.fsdpf.net/go/db"
	"github.com/samber/lo"
	"github.com/spf13/cast"
)

// conn is the connection every Engine uses to evaluate its conditions.
var conn *db.Connection

// init opens the in-memory sqlite3 database that backs conn.
func init() {
	database := db.Open(map[string]db.DBConfig{
		"condition-engine-sqlite3": {
			Driver: "sqlite3",
			File:   ":memory:",
		},
	})

	conn = database.Connection("condition-engine-sqlite3")
}

// Engine matches a piece of data against a list of conditional cases and
// runs the callback of every case whose condition holds, falling back to an
// optional default callback when none of them match.
type Engine[T any] struct {
	// code is the name of the root table the data is exposed under.
	code string
	// g holds the global parameters handed to conditions and callbacks.
	g contracts.GlobalParams
	// def is the fallback callback; it may be nil.
	def func(data T, g contracts.GlobalParams) error
	// predicates are the registered cases, in registration order.
	predicates []*EngineCase[T]
	// relations lists the field/key names that are expanded into their own
	// joined tables instead of being serialized as a single column.
	relations []string
}

// GetCode returns the root table name given to New.
func (this Engine[T]) GetCode() string {
	return this.code
}

// SetGlobalParams replaces the global parameters and returns the engine for
// chaining.
func (this *Engine[T]) SetGlobalParams(g contracts.GlobalParams) *Engine[T] {
	this.g = g
	return this
}

// Case registers a condition together with the callback to run when the
// condition matches, and returns the engine for chaining.
func (this *Engine[T]) Case(cond contracts.Condition, cb func(data T, g contracts.GlobalParams) error) *Engine[T] {
	this.predicates = append(this.predicates, &EngineCase[T]{cond, cb})
	return this
}

// Default sets the callback that runs when no registered case matches, and
// returns the engine for chaining.
func (this *Engine[T]) Default(cb func(data T, g contracts.GlobalParams) error) *Engine[T] {
	this.def = cb
	return this
}

// escapeSQLString doubles single quotes so that s can be embedded between
// single quotes in a SQL string literal.
func escapeSQLString(s string) string {
	return strings.ReplaceAll(s, "'", "''")
}

// toField renders a single value as a SQL expression of the form
// CAST(<literal> AS <type>).
//
// One level of pointer is dereferenced first. Nil pointers, zero values,
// values that cannot be read through reflection (unexported struct fields)
// and unsupported kinds are all rendered as NULL. An error is returned only
// when JSON encoding of a map/array/slice/struct fails.
func (this Engine[T]) toField(rv reflect.Value) (string, error) {
	if rv.Kind() == reflect.Ptr {
		// Elem of a nil pointer yields the invalid Value handled below.
		rv = rv.Elem()
	}

	// IsZero must not be called on an invalid Value, and Interface must not
	// be called on a value obtained from an unexported field; both panic.
	if !rv.IsValid() || rv.IsZero() || !rv.CanInterface() {
		return "NULL", nil
	}

	v := rv.Interface()

	// Switch on the dynamic type so that values stored in an interface
	// (e.g. map[string]any entries) are handled by their concrete kind.
	switch reflect.TypeOf(v).Kind() {
	case reflect.Map, reflect.Array, reflect.Slice, reflect.Struct:
		b, err := json.Marshal(v)
		if err != nil {
			return "", err
		}
		return fmt.Sprintf("CAST('%s' AS JSON1)", escapeSQLString(string(b))), nil
	case reflect.String:
		return fmt.Sprintf("CAST('%s' AS TEXT)", escapeSQLString(fmt.Sprintf("%s", v))), nil
	case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int8, reflect.Int16,
		reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint8, reflect.Uint16:
		return fmt.Sprintf("CAST(%d AS INTEGER)", v), nil
	case reflect.Float64, reflect.Float32:
		return fmt.Sprintf("CAST(%v AS REAL)", v), nil
	case reflect.Bool:
		return fmt.Sprintf("CAST(%v AS BOOL)", v), nil
	}

	return "NULL", nil
}

// toTables flattens rv, which must be a struct or a map (optionally behind
// pointers or an interface), into a set of column lists keyed by table name.
//
// Fields/keys whose name appears in this.relations are expanded recursively
// into a table of their own; every other field/key becomes one
// "<expr> as `<name>`" column of table. Struct fields use their `db` tag as
// the column name when it is set. An error is returned when rv is nil or is
// neither a struct nor a map.
func (this Engine[T]) toTables(rv reflect.Value, table string) (map[string][]string, error) {
	// Unwrap any number of pointer/interface layers; a nil layer leaves an
	// invalid Value, which is reported as an error rather than panicking.
	for rv.IsValid() && (rv.Kind() == reflect.Ptr || rv.Kind() == reflect.Interface) {
		rv = rv.Elem()
	}
	if !rv.IsValid() {
		return nil, fmt.Errorf("table %q: data is nil", table)
	}

	tables := map[string][]string{table: {}}

	// add records one field/key, either as a relation table or as a column
	// of the current table.
	add := func(name, alias string, value reflect.Value) error {
		if _, seen := tables[name]; !seen && lo.Contains(this.relations, name) {
			items, err := this.toTables(value, name)
			if err != nil {
				return err
			}
			tables = lo.Assign(tables, items)
			return nil
		}

		expr, err := this.toField(value)
		if err != nil {
			return err
		}
		tables[table] = append(tables[table], fmt.Sprintf("%s as `%s`", expr, alias))
		return nil
	}

	switch rv.Kind() {
	case reflect.Struct:
		// Walk the struct's fields.
		rt := rv.Type()
		for i := 0; i < rt.NumField(); i++ {
			field := rt.Field(i)
			alias := field.Name
			if tag := field.Tag.Get("db"); tag != "" {
				alias = tag
			}
			if err := add(field.Name, alias, rv.Field(i)); err != nil {
				return nil, err
			}
		}
	case reflect.Map:
		iter := rv.MapRange()
		for iter.Next() {
			key := iter.Key().String()
			if err := add(key, key, iter.Value()); err != nil {
				return nil, err
			}
		}
	default:
		return nil, fmt.Errorf("data type not map/struct, %s", rv.Kind())
	}

	return tables, nil
}

// Execute evaluates every registered case against data and runs the callback
// of each matching case in registration order. When no case matches and a
// default callback is set, the default callback runs instead.
//
// The data is exposed as a sub-select named after the engine's code; each
// relation becomes an additional joined sub-select. Every sub-select carries
// a constant `_` column that is used as the join key. Each case contributes
// one selected column, named after its index, holding its condition result.
//
// It returns the first error produced by data conversion, by the query, or
// by a callback. If the engine has no global parameters yet, an empty set is
// created and stored on the engine.
func (this *Engine[T]) Execute(data T) error {
	code := this.GetCode()

	tables, err := this.toTables(reflect.ValueOf(data), code)
	if err != nil {
		return err
	}

	columns, ok := tables[code]
	if !ok {
		return fmt.Errorf("data is not")
	}

	if this.g == nil {
		this.g = support.NewGlobalParam(`{}`, nil)
	}

	matched := false

	// Without cases there is nothing to evaluate, so the query is skipped.
	if len(this.predicates) > 0 {
		// subSelect builds "SELECT <columns>, 1 as `_`" on a fresh slice so
		// the column lists held in tables are never written through.
		subSelect := func(columns []string) string {
			cols := make([]string, 0, len(columns)+1)
			cols = append(cols, columns...)
			cols = append(cols, "1 as `_`")
			return fmt.Sprintf("SELECT %s", strings.Join(cols, ", "))
		}

		sql := conn.Query().FromSub(subSelect(columns), code)

		for table, columns := range lo.OmitByKeys(tables, []string{code}) {
			sql.JoinSub(
				subSelect(columns),
				table,
				fmt.Sprintf("%s._", code),
				fmt.Sprintf("%s._", table),
			)
		}

		param := &EngineParam{this.g}

		for i, p := range this.predicates {
			sql.AddSelect(db.Raw(fmt.Sprintf("IFNULL(%s, 0) as `%d`", p.ToSql(param), i)))
		}

		result := map[string]int64{}

		if _, err := sql.First(&result); err != nil {
			return fmt.Errorf("%w => %s", err, sql.ToSql())
		}

		// Look results up by case index instead of ranging over the map:
		// map iteration order is random, and keys that are not case indexes
		// must never be used to index predicates.
		for i, p := range this.predicates {
			key := cast.ToString(i)
			if result[key] == 0 {
				continue
			}

			matched = true

			if err := p.Execute(data, this.g); err != nil {
				return fmt.Errorf("case %q error, %w", key, err)
			}
		}
	}

	if !matched && this.def != nil {
		if err := this.def(data, this.g); err != nil {
			return fmt.Errorf("case %q error, %w", "default", err)
		}
	}

	return nil
}

// New creates an Engine whose data is exposed under the given table name.
// g supplies the global parameters (it may be nil), and relations names the
// fields/keys of the data that are expanded into separate joined tables.
func New[T any](table string, g contracts.GlobalParams, relations ...string) *Engine[T] {
	return &Engine[T]{code: table, g: g, relations: relations}
}