diff --git a/base/res_relation.go b/base/res_relation.go index 71f0e4a..dfa31f5 100644 --- a/base/res_relation.go +++ b/base/res_relation.go @@ -22,6 +22,9 @@ type GetResRelations func(categoryUuid string) []ResRelation type GetOrmJoinsByResRelations func(root string, items []ResRelation) []contracts.Join +// 获取被关联对象的资源 +type GetResRelationResource func(r ResRelation) Resource + // 资源关联依赖 func GetJoinResDependencies(root string, items []ResRelation) (dependencies []string) { dependencies = []string{root} diff --git a/base/resource.go b/base/resource.go index ca0181f..ea27b8f 100644 --- a/base/resource.go +++ b/base/resource.go @@ -2,6 +2,7 @@ package base import ( "database/sql" + "fmt" "log" "reflect" "unicode" @@ -10,6 +11,7 @@ import ( "github.com/samber/lo" "git.fsdpf.net/go/contracts" + "git.fsdpf.net/go/contracts/helper" "git.fsdpf.net/go/contracts/res_type" "git.fsdpf.net/go/db" ) @@ -142,7 +144,7 @@ func (this Resource) GetDBDriver() string { } func (this Resource) GetAuthDBTable(u contracts.User, params ...any) *db.Builder { - this.GetRolesCondition(u) + fmt.Println(this.GetRolesCondition(u).ToSql()) return this.GetDBTable(append(params, u)...) } @@ -205,30 +207,80 @@ func (this Resource) WithRolesCondition(b *db.Builder, roles ...string) { } // 获取鉴权条件 -func (this Resource) GetRolesCondition(u contracts.User) { - // isFullRight := false - // isFullNot := false +func (this Resource) GetRolesCondition(u contracts.User) *db.Builder { + isFullRight := false + isFullNot := false - // roles := do.MustInvoke[GetResRoles](this.container)(this.GetUuid()) - // GetResRelations := do.MustInvoke[GetResRelations](this.container) + NewOrmConditionByRes := do.MustInvoke[helper.NewOrmConditionByRes](this.container) + NewOrmJoin := do.MustInvoke[helper.NewOrmJoin](this.container) + GetResRelations := do.MustInvoke[GetResRelations](this.container) + GetResource := do.MustInvoke[contracts.GetResource](this.container) + GetResConditions := do.MustInvoke[GetResConditions](this.container) - // subTables := lo.Reduce(roles, func(carry *db.Builder, item ResRole, _ int) *db.Builder { - // db := this.GetDB().Table(string(this.GetTable()), this.GetCode()).Select("`" + this.GetCode() + "`.*") + roles := do.MustInvoke[GetResRoles](this.container)(this.GetUuid()) - // joins := lo.Filter(GetResRelations(item.Uuid), func(item ResRelation, _ int) bool { - // return item.Type == "inner" || item.Type == "left" || item.Type == "right" - // }) + subTables := lo.Reduce(roles, func(carry string, item ResRole, _ int) string { + db := this.GetDB().Table(string(this.GetTable()), this.GetCode()).Select(db.Raw("`" + this.GetCode() + "`.*")) - // join := orm.NewJoin(contracts.RelationType(item.Type), item.Code, oResource, item.RelationResource, item.RelationField, item.RelationForeignKey) + joins := lo.Filter(GetResRelations(item.Uuid), func(item ResRelation, _ int) bool { + return item.Type == "inner" || item.Type == "left" || item.Type == "right" + }) - // // 关联扩展条件 - // join.SetCondition(orm.NewConditionByRes(GetResConditions(item.Uuid))) + conditions := NewOrmConditionByRes(GetResConditions(item.Uuid)) - // return carry - // }, nil) + for i := 0; i < len(joins); i++ { + oResource, ok := GetResource(joins[i].ResourceCode) + if !ok { + continue + } + join := NewOrmJoin(contracts.RelationType(joins[i].Type), oResource, joins[i].Code, joins[i].RelationResource, joins[i].RelationField, joins[i].RelationForeignKey) + // 关联扩展条件 + join.SetCondition(NewOrmConditionByRes(GetResConditions(joins[i].Uuid))) + join.Inject(db, nil) + } - // fmt.Println(subTables.ToSql()) + if len(joins) == 0 && conditions.IsEmpty() { + // 无权限, 直接跳过这个 unoin 语句 + if carry != "" { + return carry + } + // 第一个无权限除外, 避免所有用户所属角色都是无权限 + db.WhereRaw("false") + isFullNot = true + } else if len(joins) == 0 && conditions.IsNotEmpty() && conditions.IsRight() /* 1=1 的这种条件*/ { + // 只要有1个满权限, 直接返回单条语句 + isFullRight = true + return db.ToSql() + } else if conditions.IsNotEmpty() { + db.WhereRaw(string(conditions.ToSql(nil))) + // 如果前面是无权限的sql查看, 这直接返回本次查询 + if isFullNot { + isFullNot = false + return db.ToSql() + } + } + + if carry != "" { + carry += " union " + } + + carry += db.ToSql() + + return carry + }, "") + + // @todo this.GetCode 要换成 alias + + if isFullRight { + this.GetDB().Table(string(this.GetTable()), this.GetCode()) + } else if isFullNot { + this.GetDB().Table(string(this.GetTable()), this.GetCode()).WhereRaw("false") + } else if subTables != "" { + this.GetDB().Table(subTables, this.GetCode()) + } + + return this.GetDB().Table(string(this.GetTable()), this.GetCode()) } // 格式化保存数据