209 lines
5.1 KiB
Go
209 lines
5.1 KiB
Go
package engine
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"reflect"
|
|
"strings"
|
|
|
|
"git.fsdpf.net/go/contracts"
|
|
"git.fsdpf.net/go/contracts/support"
|
|
"git.fsdpf.net/go/db"
|
|
"github.com/samber/lo"
|
|
"github.com/spf13/cast"
|
|
)
|
|
|
|
var conn *db.Connection
|
|
|
|
func init() {
|
|
database := db.Open(map[string]db.DBConfig{
|
|
"condition-engine-sqlite3": {
|
|
Driver: "sqlite3",
|
|
File: ":memory:",
|
|
},
|
|
})
|
|
conn = database.Connection("condition-engine-sqlite3")
|
|
}
|
|
|
|
type Engine[T any] struct {
|
|
code string
|
|
g contracts.GlobalParams
|
|
def func(data T, g contracts.GlobalParams) error
|
|
predicates []*EngineCase[T]
|
|
relations []string
|
|
}
|
|
|
|
func (this Engine[T]) GetCode() string {
|
|
return this.code
|
|
}
|
|
|
|
// 公共参数
|
|
func (this *Engine[T]) SetGlobalParams(g contracts.GlobalParams) *Engine[T] {
|
|
this.g = g
|
|
return this
|
|
}
|
|
|
|
func (this *Engine[T]) Case(cond contracts.Condition, cb func(data T, g contracts.GlobalParams) error) *Engine[T] {
|
|
this.predicates = append(this.predicates, &EngineCase[T]{cond, cb})
|
|
return this
|
|
}
|
|
|
|
// 基础条件
|
|
func (this *Engine[T]) Default(cb func(data T, g contracts.GlobalParams) error) *Engine[T] {
|
|
this.def = cb
|
|
return this
|
|
}
|
|
|
|
func (this Engine[T]) toField(rv reflect.Value) (string, error) {
|
|
if rv.Kind() == reflect.Ptr {
|
|
rv = reflect.Indirect(rv)
|
|
}
|
|
|
|
if rv.IsZero() {
|
|
return "NULL", nil
|
|
}
|
|
|
|
v := rv.Interface()
|
|
|
|
switch reflect.TypeOf(v).Kind() {
|
|
case reflect.Map, reflect.Array, reflect.Slice, reflect.Struct:
|
|
b, err := json.Marshal(v)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return fmt.Sprintf("CAST('%s' AS JSON1)", b), nil
|
|
case reflect.String:
|
|
return fmt.Sprintf("CAST('%s' AS TEXT)", v), nil
|
|
case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int8, reflect.Int16,
|
|
reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint8, reflect.Uint16:
|
|
return fmt.Sprintf("CAST(%d AS INTEGER)", v), nil
|
|
case reflect.Float64, reflect.Float32:
|
|
return fmt.Sprintf("CAST(%v AS REAL)", v), nil
|
|
case reflect.Bool:
|
|
return fmt.Sprintf("CAST(%v AS BOOL)", v), nil
|
|
}
|
|
return "NULL", nil
|
|
}
|
|
|
|
func (this Engine[T]) toTables(rv reflect.Value, table string) (tables map[string][]string, err error) {
|
|
rt := rv.Type()
|
|
|
|
if rv.Kind() == reflect.Ptr {
|
|
rv = reflect.Indirect(rv.Elem())
|
|
rt = rv.Type()
|
|
}
|
|
|
|
if tables == nil {
|
|
tables = map[string][]string{}
|
|
}
|
|
|
|
if _, ok := tables[table]; !ok {
|
|
tables[table] = []string{}
|
|
}
|
|
|
|
if rv.Kind() == reflect.Struct {
|
|
// 遍历结构体的字段
|
|
for i := 0; i < rt.NumField(); i++ {
|
|
if _, ok := tables[rt.Field(i).Name]; !ok && lo.Contains(this.relations, rt.Field(i).Name) {
|
|
if items, err := this.toTables(rv.Field(i).Elem(), rt.Field(i).Name); err == nil {
|
|
tables = lo.Assign(tables, items)
|
|
} else {
|
|
return nil, err
|
|
}
|
|
} else if s, err := this.toField(rv.Field(i)); err != nil {
|
|
return nil, err
|
|
} else {
|
|
tables[table] = append(tables[table], fmt.Sprintf("%s as `%s`", s, lo.Ternary(rt.Field(i).Tag.Get("db") != "", rt.Field(i).Tag.Get("db"), rt.Field(i).Name)))
|
|
}
|
|
}
|
|
} else if rv.Kind() == reflect.Map {
|
|
iter := rv.MapRange()
|
|
for iter.Next() {
|
|
if _, ok := tables[iter.Key().String()]; !ok && lo.Contains(this.relations, iter.Key().String()) {
|
|
if items, err := this.toTables(iter.Value().Elem(), iter.Key().String()); err == nil {
|
|
tables = lo.Assign(tables, items)
|
|
} else {
|
|
return nil, err
|
|
}
|
|
} else if s, err := this.toField(iter.Value()); err != nil {
|
|
return nil, err
|
|
} else {
|
|
tables[table] = append(tables[table], fmt.Sprintf("%s as `%s`", s, iter.Key().String()))
|
|
}
|
|
}
|
|
} else {
|
|
return nil, fmt.Errorf("data type not map/struct, %s", rv.Kind())
|
|
}
|
|
|
|
return tables, nil
|
|
}
|
|
|
|
func (this *Engine[T]) Execute(data T) error {
|
|
|
|
tables, err := this.toTables(reflect.ValueOf(data), this.GetCode())
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
columns, ok := tables[this.GetCode()]
|
|
|
|
if !ok {
|
|
return fmt.Errorf("data is not")
|
|
}
|
|
|
|
sql := conn.Query().FromSub(fmt.Sprintf("SELECT %s", strings.Join(append(columns, "1 as `_`"), ", ")), this.GetCode())
|
|
|
|
for table, columns := range lo.OmitByKeys(tables, []string{this.GetCode()}) {
|
|
sql.JoinSub(
|
|
fmt.Sprintf("SELECT %s", strings.Join(append(columns, "1 as `_`"), ", ")),
|
|
table,
|
|
fmt.Sprintf("%s._", this.GetCode()),
|
|
fmt.Sprintf("%s._", table),
|
|
)
|
|
}
|
|
|
|
if this.g == nil {
|
|
this.g = support.NewGlobalParam(`{}`, nil)
|
|
}
|
|
|
|
param := &EngineParam{this.g}
|
|
|
|
for i, p := range this.predicates {
|
|
sql.AddSelect(db.Raw(fmt.Sprintf("IFNULL(%s, 0) as `%d`", p.ToSql(param), i)))
|
|
}
|
|
|
|
result := map[string]int64{}
|
|
|
|
// fmt.Println("sql", sql.ToSql())
|
|
|
|
if _, err := sql.First(&result); err != nil {
|
|
return fmt.Errorf("%s => %s", err, sql.ToSql())
|
|
}
|
|
|
|
isDefault := true
|
|
for k, v := range result {
|
|
if v == 0 {
|
|
continue
|
|
}
|
|
|
|
isDefault = false
|
|
|
|
if err := this.predicates[cast.ToInt64(k)].Execute(data, this.g); err != nil {
|
|
return fmt.Errorf("case %q error, %s", k, err)
|
|
}
|
|
}
|
|
|
|
if isDefault && this.def != nil {
|
|
if err := this.def(data, this.g); err != nil {
|
|
return fmt.Errorf("case %q error, %s", "default", err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func New[T any](table string, g contracts.GlobalParams, relations ...string) *Engine[T] {
|
|
return &Engine[T]{code: table, g: g, relations: relations}
|
|
}
|