condition/engine/engine.go

226 lines
5.3 KiB
Go
Raw Normal View History

2023-07-20 14:55:00 +08:00
package engine
import (
"encoding/json"
"fmt"
"reflect"
"strings"
"git.fsdpf.net/go/contracts"
"git.fsdpf.net/go/contracts/support"
"git.fsdpf.net/go/db"
"github.com/samber/lo"
"github.com/spf13/cast"
)
// conn is the package-level database connection. It is assigned once in init
// and used by Engine.Execute to build and run its queries.
var conn *db.Connection
2023-07-24 13:39:12 +08:00
// defaultEngineOptions is the option set New starts from before applying the
// caller-supplied options: debug output disabled and no relation names.
var defaultEngineOptions = engineOptions{
debug: false,
relations: []string{},
}
2023-07-20 14:55:00 +08:00
// init opens the "condition-engine-sqlite3" database, configured as an
// in-memory sqlite3 file, and stores its connection in conn.
func init() {
	const name = "condition-engine-sqlite3"

	cfg := map[string]db.DBConfig{
		name: {
			Driver: "sqlite3",
			File:   ":memory:",
		},
	}
	conn = db.Open(cfg).Connection(name)
}
type Engine[T any] struct {
code string
2023-07-24 13:39:12 +08:00
opts engineOptions
2023-07-20 14:55:00 +08:00
g contracts.GlobalParams
def func(data T, g contracts.GlobalParams) error
predicates []*EngineCase[T]
}
// GetCode returns the code the engine was created with, which Execute uses
// as the name of the root table.
func (e Engine[T]) GetCode() string {
	return e.code
}
// SetGlobalParams replaces the global parameters that are passed to every
// case callback and returns the engine so calls can be chained.
func (e *Engine[T]) SetGlobalParams(g contracts.GlobalParams) *Engine[T] {
	e.g = g
	return e
}
// Case registers cb to be run when cond holds for the executed data.
// Cases are kept in registration order. The engine is returned for chaining.
func (e *Engine[T]) Case(cond contracts.Condition, cb func(data T, g contracts.GlobalParams) error) *Engine[T] {
	entry := &EngineCase[T]{cond, cb}
	e.predicates = append(e.predicates, entry)
	return e
}
// Default sets the fallback callback that Execute runs when none of the
// registered cases match. The engine is returned for chaining.
func (e *Engine[T]) Default(cb func(data T, g contracts.GlobalParams) error) *Engine[T] {
	e.def = cb
	return e
}
// toField renders rv as a SQL expression that can be used as a column value.
//
// One level of pointer is dereferenced. Invalid values, nil pointers, zero
// values (0, "", false, nil interface, ...) and values that reflect cannot
// export (unexported struct fields) are rendered as NULL. Maps, arrays,
// slices and structs are JSON-encoded; strings, integers, floats and bools
// are rendered as CAST(... AS <type>) literals. Any other kind yields NULL.
//
// Single quotes inside string and JSON payloads are doubled so the payload
// cannot terminate the surrounding SQL string literal.
//
// An error is returned only when JSON encoding fails.
func (e Engine[T]) toField(rv reflect.Value) (string, error) {
	if rv.IsValid() && rv.Kind() == reflect.Ptr {
		// Elem of a nil pointer is the zero Value, handled by the check below.
		rv = rv.Elem()
	}

	// IsZero panics on the zero Value, so validity must be tested first.
	if !rv.IsValid() || rv.IsZero() {
		return "NULL", nil
	}

	// Interface panics for values obtained from unexported struct fields.
	if !rv.CanInterface() {
		return "NULL", nil
	}

	v := rv.Interface()

	// quote escapes s for use inside a single-quoted SQL string literal.
	quote := func(s string) string {
		return strings.ReplaceAll(s, "'", "''")
	}

	switch reflect.TypeOf(v).Kind() {
	case reflect.Map, reflect.Array, reflect.Slice, reflect.Struct:
		b, err := json.Marshal(v)
		if err != nil {
			return "", err
		}
		return fmt.Sprintf("CAST('%s' AS JSON1)", quote(string(b))), nil
	case reflect.String:
		return fmt.Sprintf("CAST('%s' AS TEXT)", quote(fmt.Sprintf("%s", v))), nil
	case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int8, reflect.Int16,
		reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint8, reflect.Uint16:
		return fmt.Sprintf("CAST(%d AS INTEGER)", v), nil
	case reflect.Float64, reflect.Float32:
		return fmt.Sprintf("CAST(%v AS REAL)", v), nil
	case reflect.Bool:
		return fmt.Sprintf("CAST(%v AS BOOL)", v), nil
	}

	return "NULL", nil
}
func (this Engine[T]) toTables(rv reflect.Value, table string) (tables map[string][]string, err error) {
rt := rv.Type()
if rv.Kind() == reflect.Ptr {
rv = reflect.Indirect(rv.Elem())
rt = rv.Type()
}
if tables == nil {
tables = map[string][]string{}
}
if _, ok := tables[table]; !ok {
tables[table] = []string{}
}
if rv.Kind() == reflect.Struct {
// 遍历结构体的字段
for i := 0; i < rt.NumField(); i++ {
2023-07-24 13:39:12 +08:00
if _, ok := tables[rt.Field(i).Name]; !ok && lo.Contains(this.opts.relations, rt.Field(i).Name) {
2023-07-20 14:55:00 +08:00
if items, err := this.toTables(rv.Field(i).Elem(), rt.Field(i).Name); err == nil {
tables = lo.Assign(tables, items)
} else {
return nil, err
}
} else if s, err := this.toField(rv.Field(i)); err != nil {
return nil, err
} else {
tables[table] = append(tables[table], fmt.Sprintf("%s as `%s`", s, lo.Ternary(rt.Field(i).Tag.Get("db") != "", rt.Field(i).Tag.Get("db"), rt.Field(i).Name)))
}
}
} else if rv.Kind() == reflect.Map {
iter := rv.MapRange()
for iter.Next() {
2023-07-24 13:39:12 +08:00
if _, ok := tables[iter.Key().String()]; !ok && lo.Contains(this.opts.relations, iter.Key().String()) {
2023-07-20 14:55:00 +08:00
if items, err := this.toTables(iter.Value().Elem(), iter.Key().String()); err == nil {
tables = lo.Assign(tables, items)
} else {
return nil, err
}
} else if s, err := this.toField(iter.Value()); err != nil {
return nil, err
} else {
tables[table] = append(tables[table], fmt.Sprintf("%s as `%s`", s, iter.Key().String()))
}
}
} else {
return nil, fmt.Errorf("data type not map/struct, %s", rv.Kind())
}
return tables, nil
}
func (this *Engine[T]) Execute(data T) error {
tables, err := this.toTables(reflect.ValueOf(data), this.GetCode())
if err != nil {
return err
}
columns, ok := tables[this.GetCode()]
if !ok {
return fmt.Errorf("data is not")
}
sql := conn.Query().FromSub(fmt.Sprintf("SELECT %s", strings.Join(append(columns, "1 as `_`"), ", ")), this.GetCode())
for table, columns := range lo.OmitByKeys(tables, []string{this.GetCode()}) {
sql.JoinSub(
fmt.Sprintf("SELECT %s", strings.Join(append(columns, "1 as `_`"), ", ")),
table,
fmt.Sprintf("%s._", this.GetCode()),
fmt.Sprintf("%s._", table),
)
}
if this.g == nil {
this.g = support.NewGlobalParam(`{}`, nil)
}
param := &EngineParam{this.g}
for i, p := range this.predicates {
sql.AddSelect(db.Raw(fmt.Sprintf("IFNULL(%s, 0) as `%d`", p.ToSql(param), i)))
}
result := map[string]int64{}
2023-07-24 13:39:12 +08:00
if this.opts.debug {
fmt.Println("sql", sql.ToSql())
}
2023-07-20 14:55:00 +08:00
if _, err := sql.First(&result); err != nil {
return fmt.Errorf("%s => %s", err, sql.ToSql())
}
isDefault := true
for k, v := range result {
if v == 0 {
continue
}
isDefault = false
if err := this.predicates[cast.ToInt64(k)].Execute(data, this.g); err != nil {
return fmt.Errorf("case %q error, %s", k, err)
}
}
if isDefault && this.def != nil {
if err := this.def(data, this.g); err != nil {
return fmt.Errorf("case %q error, %s", "default", err)
}
}
return nil
}
2023-07-24 13:39:12 +08:00
func New[T any](table string, opt ...EngineOption) *Engine[T] {
opts := defaultEngineOptions
for _, o := range opt {
o.apply(&opts)
}
return &Engine[T]{
code: table,
opts: opts,
g: support.NewGlobalParam("", nil),
}
2023-07-20 14:55:00 +08:00
}