db/schema/blueprint.go

457 lines
13 KiB
Go
Raw Permalink Normal View History

2023-04-12 15:58:25 +08:00
package schema
import (
"strings"
"github.com/samber/lo"
"git.fsdpf.net/go/db"
)
// Blueprint describes the schema changes for a single table: the column
// definitions to add or change, plus the queued commands (create, drop,
// index, rename, ...). ToSql compiles it through an IGrammar and Build
// executes the resulting statements.
type Blueprint struct {
	table     string              // Name of the table the blueprint describes.
	columns   []*ColumnDefinition // Column definitions (added or changed), in declaration order.
	commands  []*Command          // Queued commands, compiled in order by ToSql.
	temporary bool                // Whether the table is temporary; set by Temporary.

	// Table options. NewBlueprint defaults charset to "utf8mb4" and collation
	// to "utf8mb4_general_ci"; Charset overrides the charset. No setter for
	// engine is visible in this file.
	charset, collation, engine string
}
// Command is a single schema operation queued on a Blueprint, such as
// "create", "unique" or "dropColumn". Type selects the grammar compiler used
// by Blueprint.ToSql; the embedded CommandOptions carry its arguments.
type Command struct {
	Type string // Command kind, e.g. "create", "index", "rename".
	CommandOptions
}
// CommandOptions holds the arguments of a Command. Which fields are
// meaningful depends on the command type.
type CommandOptions struct {
	Index     string   // Index name, e.g. produced by generateIndexName for primary/unique/index/spatialIndex.
	Columns   []string // Columns the command applies to (index columns, or columns to drop).
	Algorithm string   // Not set in this file; presumably an index-algorithm hint for the grammar — TODO confirm.
	To        string   // Target name, e.g. the new table name for "rename".
	From      string   // Not set in this file; presumably the old name for renames such as "renameIndex" — TODO confirm.
}
// NewBlueprint returns an empty Blueprint for table, with the charset
// defaulting to "utf8mb4" and the collation to "utf8mb4_general_ci".
func NewBlueprint(table string) *Blueprint {
	bp := new(Blueprint)
	bp.table = table
	bp.charset = "utf8mb4"
	bp.collation = "utf8mb4_general_ci"
	return bp
}
// Char adds a "char" (fixed-length string) column. A zero length falls back
// to 255.
func (this *Blueprint) Char(column string, length int) *ColumnDefinition {
	opts := &ColumnOptions{Length: 255}
	if length != 0 {
		opts.Length = length
	}
	return this.addColumn("char", column, opts)
}
// String adds a "string" (variable-length string) column. A zero length falls
// back to 255.
func (this *Blueprint) String(column string, length int) *ColumnDefinition {
	opts := &ColumnOptions{Length: 255}
	if length != 0 {
		opts.Length = length
	}
	return this.addColumn("string", column, opts)
}
// Text adds a "text" column with no extra options.
func (this *Blueprint) Text(column string) *ColumnDefinition {
	var opts *ColumnOptions // no options for text columns
	return this.addColumn("text", column, opts)
}
// Integer adds an "integer" column. The optional params are, in order,
// autoIncrement and unsigned; each defaults to false and extra values are
// ignored.
func (this *Blueprint) Integer(column string, params ...bool) *ColumnDefinition {
	var opts ColumnOptions
	if n := len(params); n >= 1 {
		opts.autoIncrement = params[0]
		if n >= 2 {
			opts.unsigned = params[1]
		}
	}
	return this.addColumn("integer", column, &opts)
}
// TinyInteger adds a "tinyInteger" (1-byte integer) column. The optional
// params are, in order, autoIncrement and unsigned; each defaults to false
// and extra values are ignored.
func (this *Blueprint) TinyInteger(column string, params ...bool) *ColumnDefinition {
	var opts ColumnOptions
	if n := len(params); n >= 1 {
		opts.autoIncrement = params[0]
		if n >= 2 {
			opts.unsigned = params[1]
		}
	}
	return this.addColumn("tinyInteger", column, &opts)
}
// SmallInteger adds a "smallInteger" (2-byte integer) column. The optional
// params are, in order, autoIncrement and unsigned; each defaults to false
// and extra values are ignored.
func (this *Blueprint) SmallInteger(column string, params ...bool) *ColumnDefinition {
	var opts ColumnOptions
	if n := len(params); n >= 1 {
		opts.autoIncrement = params[0]
		if n >= 2 {
			opts.unsigned = params[1]
		}
	}
	return this.addColumn("smallInteger", column, &opts)
}
// BigInteger adds a "bigInteger" column, intended as an 8-byte integer
// (BIGINT); the concrete SQL type is chosen by the grammar. The optional
// params are, in order, autoIncrement and unsigned; each defaults to false
// and extra values are ignored.
func (this *Blueprint) BigInteger(column string, params ...bool) *ColumnDefinition {
	var opts ColumnOptions
	if n := len(params); n >= 1 {
		opts.autoIncrement = params[0]
		if n >= 2 {
			opts.unsigned = params[1]
		}
	}
	return this.addColumn("bigInteger", column, &opts)
}
// UnsignedInteger adds an unsigned "integer" column via Integer. The optional
// first param enables auto-increment (default false).
func (this *Blueprint) UnsignedInteger(column string, params ...bool) *ColumnDefinition {
	return this.Integer(column, len(params) > 0 && params[0], true)
}
// UnsignedTinyInteger adds an unsigned "tinyInteger" (1-byte) column via
// TinyInteger. The optional first param enables auto-increment (default
// false).
func (this *Blueprint) UnsignedTinyInteger(column string, params ...bool) *ColumnDefinition {
	return this.TinyInteger(column, len(params) > 0 && params[0], true)
}
// UnsignedSmallInteger adds an unsigned "smallInteger" (2-byte) column via
// SmallInteger. The optional first param enables auto-increment (default
// false).
func (this *Blueprint) UnsignedSmallInteger(column string, params ...bool) *ColumnDefinition {
	return this.SmallInteger(column, len(params) > 0 && params[0], true)
}
// UnsignedBigInteger adds an unsigned "bigInteger" column via BigInteger. The
// optional first param enables auto-increment (default false).
func (this *Blueprint) UnsignedBigInteger(column string, params ...bool) *ColumnDefinition {
	return this.BigInteger(column, len(params) > 0 && params[0], true)
}
// Decimal adds a "decimal" (exact fixed-point) column with total digits and
// places digits after the decimal point. A zero total defaults to 8 and a
// zero places defaults to 2.
// NOTE(review): because 0 is the "unset" sentinel, places=0 cannot be
// requested here — confirm whether callers need DECIMAL(n, 0).
func (this *Blueprint) Decimal(column string, total, places int) *ColumnDefinition {
	opts := &ColumnOptions{Total: 8, Places: 2}
	if total != 0 {
		opts.Total = total
	}
	if places != 0 {
		opts.Places = places
	}
	return this.addColumn("decimal", column, opts)
}
// UnsignedDecimal adds an unsigned "decimal" (exact fixed-point) column with
// total digits and places digits after the decimal point. As with Decimal, a
// zero total defaults to 8 and a zero places defaults to 2.
func (this *Blueprint) UnsignedDecimal(column string, total, places int) *ColumnDefinition {
	opts := &ColumnOptions{Total: 8, Places: 2, unsigned: true}
	if total != 0 {
		opts.Total = total
	}
	if places != 0 {
		opts.Places = places
	}
	return this.addColumn("decimal", column, opts)
}
// Boolean adds a "boolean" column with no extra options.
func (this *Blueprint) Boolean(column string) *ColumnDefinition {
	var opts *ColumnOptions // no options for boolean columns
	return this.addColumn("boolean", column, opts)
}
// Enum adds an "enum" column restricted to the allowed values.
//
// The values are copied. SQL is only compiled later, by ToSql, so keeping the
// caller's slice would let later changes to it alter the column definition.
func (this *Blueprint) Enum(column string, allowed []string) *ColumnDefinition {
	values := append([]string(nil), allowed...)
	return this.addColumn("enum", column, &ColumnOptions{Allowed: values})
}
// Json adds a "json" column with no extra options.
func (this *Blueprint) Json(column string) *ColumnDefinition {
	var opts *ColumnOptions // no options for json columns
	return this.addColumn("json", column, opts)
}
// Date adds a "date" column with no extra options.
func (this *Blueprint) Date(column string) *ColumnDefinition {
	var opts *ColumnOptions // no options for date columns
	return this.addColumn("date", column, opts)
}
// DateTime adds a "datetime" column. The optional first precision value is
// passed on as the column precision; further values are ignored.
func (this *Blueprint) DateTime(column string, precision ...int) *ColumnDefinition {
	var opts *ColumnOptions
	if len(precision) > 0 {
		opts = &ColumnOptions{Precision: precision[0]}
	}
	return this.addColumn("datetime", column, opts)
}
// Time adds a "time" column. The optional first precision value is passed on
// as the column precision; further values are ignored.
func (this *Blueprint) Time(column string, precision ...int) *ColumnDefinition {
	var opts *ColumnOptions
	if len(precision) > 0 {
		opts = &ColumnOptions{Precision: precision[0]}
	}
	return this.addColumn("time", column, opts)
}
// Timestamp adds a "timestamp" column. The optional first precision value is
// passed on as the column precision; further values are ignored.
func (this *Blueprint) Timestamp(column string, precision ...int) *ColumnDefinition {
	var opts *ColumnOptions
	if len(precision) > 0 {
		opts = &ColumnOptions{Precision: precision[0]}
	}
	return this.addColumn("timestamp", column, opts)
}
// Year adds a "year" column with no extra options.
func (this *Blueprint) Year(column string) *ColumnDefinition {
	var opts *ColumnOptions // no options for year columns
	return this.addColumn("year", column, opts)
}
// Binary adds a "binary" (binary data) column with no extra options.
func (this *Blueprint) Binary(column string) *ColumnDefinition {
	var opts *ColumnOptions // no options for binary columns
	return this.addColumn("binary", column, opts)
}
// Uuid adds a "uuid" column with no extra options.
func (this *Blueprint) Uuid(column string) *ColumnDefinition {
	var opts *ColumnOptions // no options for uuid columns
	return this.addColumn("uuid", column, opts)
}
// Increments adds an auto-incrementing, unsigned "integer" column.
func (this *Blueprint) Increments(column string) *ColumnDefinition {
	const autoIncrement, unsigned = true, true
	return this.Integer(column, autoIncrement, unsigned)
}
// BigIncrements adds an auto-incrementing, unsigned "bigInteger" column.
func (this *Blueprint) BigIncrements(column string) *ColumnDefinition {
	const autoIncrement, unsigned = true, true
	return this.BigInteger(column, autoIncrement, unsigned)
}
// Primary queues a "primary" (primary key) command over columns, named
// "pk_<columns>" (joined with "_" and lowercased).
// NOTE(review): ToSql currently has the "primary" compiler commented out, so
// this command emits no SQL there.
func (this *Blueprint) Primary(columns ...string) *Command {
	name := this.generateIndexName("pk", columns)
	return this.addCommand("primary", CommandOptions{Columns: columns, Index: name})
}
// Unique queues a "unique" index command over columns, named
// "unique_<columns>" (joined with "_" and lowercased).
func (this *Blueprint) Unique(columns ...string) *Command {
	name := this.generateIndexName("unique", columns)
	return this.addCommand("unique", CommandOptions{Columns: columns, Index: name})
}
// Index queues a plain "index" command over columns, named
// "index_<columns>" (joined with "_" and lowercased).
func (this *Blueprint) Index(columns ...string) *Command {
	name := this.generateIndexName("index", columns)
	return this.addCommand("index", CommandOptions{Columns: columns, Index: name})
}
// SpatialIndex queues a "spatialIndex" command over columns, named
// "spatial_index_<columns>" (joined with "_" and lowercased).
func (this *Blueprint) SpatialIndex(columns ...string) *Command {
	name := this.generateIndexName("spatial_index", columns)
	return this.addCommand("spatialIndex", CommandOptions{Columns: columns, Index: name})
}
// DropColumn queues a "dropColumn" command that removes the given columns.
func (this *Blueprint) DropColumn(columns ...string) *Command {
	opts := CommandOptions{Columns: columns}
	return this.addCommand("dropColumn", opts)
}
// Create queues a "create" command. While it is present, creating() reports
// true and addImpliedCommands adds no "add"/"change" commands.
func (this *Blueprint) Create() *Command {
	var opts CommandOptions
	return this.addCommand("create", opts)
}
// Temporary marks the blueprint as describing a temporary table by setting
// its temporary flag.
func (this *Blueprint) Temporary() {
	this.temporary = true
}
// Charset overrides the table character set (NewBlueprint defaults it to
// "utf8mb4").
func (this *Blueprint) Charset(charset string) {
	this.charset = charset
}
// Rename queues a "rename" command that renames the table to `to`.
func (this *Blueprint) Rename(to string) *Command {
	opts := CommandOptions{To: to}
	return this.addCommand("rename", opts)
}
// Drop queues a "drop" command that drops the table.
func (this *Blueprint) Drop() *Command {
	var opts CommandOptions
	return this.addCommand("drop", opts)
}
// DropIfExists queues a "dropIfExists" command, which drops the table only if
// it exists.
func (this *Blueprint) DropIfExists() *Command {
	var opts CommandOptions
	return this.addCommand("dropIfExists", opts)
}
// Build compiles the blueprint with ToSql and executes every resulting
// statement on conn inside conn.Transaction.
//
// On the first failing statement the callback calls tx.Rollback and returns
// that error. Otherwise it calls tx.Commit. Build returns whatever error
// conn.Transaction reports.
//
// NOTE(review): the results of tx.Rollback/tx.Commit are discarded, and the
// callback manages the transaction itself. Confirm this matches what
// db.Connection.Transaction expects; it may already commit or roll back based
// on the callback's error. Some databases (e.g. MySQL) also implicitly commit
// DDL, so the batch may not be atomic.
func (this *Blueprint) Build(conn *db.Connection, grammar IGrammar) error {
	run := func(tx *db.Transaction) (any, error) {
		statements := this.ToSql(conn, grammar)
		for i := range statements {
			if _, execErr := tx.Statement(statements[i], []any{}); execErr != nil {
				tx.Rollback()
				return nil, execErr
			}
		}
		tx.Commit()
		return nil, nil
	}

	_, err := conn.Transaction(run)
	return err
}
// ToSql compiles the blueprint's commands into SQL statements using grammar.
//
// It first calls addImpliedCommands, which mutates this.commands (it may
// prepend "add"/"change" commands and append fluent index commands). It then
// walks the commands in order and dispatches each one to the matching
// grammar.Compile* method. Command types not handled below, and those in the
// no-op case, contribute no statements.
func (this *Blueprint) ToSql(conn *db.Connection, grammar IGrammar) (statements []string) {
	this.addImpliedCommands(grammar)

	for _, command := range this.commands {
		var compiled []string

		switch command.Type {
		// Commands whose compiler may yield several statements.
		case "add":
			compiled = grammar.CompileAdd(this, command, conn)
		case "change":
			compiled = grammar.CompileChange(this, command, conn)
		case "dropColumn":
			compiled = grammar.CompileDropColumn(this, command, conn)

		// Commands that compile to exactly one statement.
		case "create":
			compiled = []string{grammar.CompileCreate(this, command, conn)}
		case "unique":
			compiled = []string{grammar.CompileUnique(this, command, conn)}
		case "index":
			compiled = []string{grammar.CompileIndex(this, command, conn)}
		case "spatialIndex":
			compiled = []string{grammar.CompileSpatialIndex(this, command, conn)}
		case "drop":
			compiled = []string{grammar.CompileDrop(this, command, conn)}
		case "dropIfExists":
			compiled = []string{grammar.CompileDropIfExists(this, command, conn)}
		case "dropUnique":
			compiled = []string{grammar.CompileDropUnique(this, command, conn)}
		case "dropIndex":
			compiled = []string{grammar.CompileDropIndex(this, command, conn)}
		case "dropSpatialIndex":
			compiled = []string{grammar.CompileDropSpatialIndex(this, command, conn)}
		case "rename":
			compiled = []string{grammar.CompileRename(this, command, conn)}
		case "renameIndex":
			compiled = []string{grammar.CompileRenameIndex(this, command, conn)}

		// Recognised but not compiled here: the primary-key compilers
		// (CompilePrimary / CompileDropPrimary) are disabled, and the
		// "all tables/views" commands have no compiler wired up.
		case "primary", "dropPrimary", "dropAllTables", "dropAllViews", "getAllTables", "getAllViews":
		}

		statements = append(statements, compiled...)
	}
	return statements
}
// creating reports whether a "create" command has been queued, i.e. whether
// the blueprint creates its table rather than altering an existing one.
func (this *Blueprint) creating() bool {
	for _, cmd := range this.commands {
		if cmd.Type == "create" {
			return true
		}
	}
	return false
}
// addColumn creates a column definition of type typ named name, registers it
// on the blueprint and returns it so callers can chain modifiers.
//
// options may be nil. Otherwise only the meaningful values are copied:
// positive sizes (Length, Total, Places, Precision), a non-empty Allowed
// list, and the autoIncrement/unsigned/change flags when they are true.
func (this *Blueprint) addColumn(typ, name string, options *ColumnOptions) (definition *ColumnDefinition) {
	col := &ColumnDefinition{Type: typ, Name: name}
	// The definition is a pointer, so registering it before the options are
	// applied is equivalent to registering it afterwards.
	this.columns = append(this.columns, col)

	if options == nil {
		return col
	}

	// Numeric sizes: zero or negative means "not specified".
	if options.Length > 0 {
		col.Length = options.Length
	}
	if options.Total > 0 {
		col.Total = options.Total
	}
	if options.Places > 0 {
		col.Places = options.Places
	}
	if options.Precision > 0 {
		col.Precision = options.Precision
	}

	if len(options.Allowed) > 0 {
		col.Allowed = options.Allowed
	}

	// Flags are only ever switched on here, never off.
	if options.autoIncrement {
		col.autoIncrement = true
	}
	if options.unsigned {
		col.unsigned = true
	}
	if options.change {
		col.change = true
	}

	return col
}
// addImpliedCommands adds the commands implied by the blueprint's state
// before it is compiled.
//
// When the table is not being created (no "create" command), registered
// columns need an "add" command and columns flagged as changed need a
// "change" command. Both are prepended, so the resulting order is
// change, add, then the user's commands; presumably this is so the columns
// exist before any index statements reference them. Fluent index flags on
// columns are then turned into commands by addFluentIndexes.
//
// Each of "add" and "change" is inserted only if it is not already queued.
// This keeps ToSql safe to call more than once on the same blueprint, for
// example a preview followed by Build, without emitting duplicate statements.
//
// grammar is currently unused; it is kept for the existing call signature.
func (this *Blueprint) addImpliedCommands(grammar IGrammar) {
	if !this.creating() {
		queued := func(typ string) bool {
			return lo.SomeBy(this.commands, func(item *Command) bool {
				return item.Type == typ
			})
		}

		if len(this.getAddedColumns()) > 0 && !queued("add") {
			this.commands = append([]*Command{this.createCommand("add", CommandOptions{})}, this.commands...)
		}
		if len(this.getChangedColumns()) > 0 && !queued("change") {
			this.commands = append([]*Command{this.createCommand("change", CommandOptions{})}, this.commands...)
		}
	}
	this.addFluentIndexes()
}
// addFluentIndexes turns the index flags set on column definitions into index
// commands on the blueprint. A column produces at most one command; when
// several flags are set, the first of primary, unique, index and spatialIndex
// wins.
//
// A command is not added when the blueprint already has a command of the same
// type for exactly that single column. Such a command may come from an
// earlier ToSql call, or from an explicit call such as Unique("col"). This
// keeps ToSql safe to call more than once and avoids queueing two identically
// named indexes, which the database would presumably reject.
func (this *Blueprint) addFluentIndexes() {
	for _, column := range this.columns {
		var typ string
		var add func(columns ...string) *Command

		switch {
		case column.primary:
			typ, add = "primary", this.Primary
		case column.unique:
			typ, add = "unique", this.Unique
		case column.index:
			typ, add = "index", this.Index
		case column.spatialIndex:
			typ, add = "spatialIndex", this.SpatialIndex
		default:
			continue
		}

		name := column.Name
		exists := lo.SomeBy(this.commands, func(item *Command) bool {
			return item.Type == typ && len(item.Columns) == 1 && item.Columns[0] == name
		})
		if !exists {
			add(name)
		}
	}
}
// addCommand builds a command of type name with options, appends it to the
// blueprint's command list and returns it.
func (this *Blueprint) addCommand(name string, options CommandOptions) (command *Command) {
	cmd := this.createCommand(name, options)
	this.commands = append(this.commands, cmd)
	return cmd
}
// createCommand returns a new command of type name carrying options. It does
// not register the command on the blueprint.
func (this *Blueprint) createCommand(name string, options CommandOptions) *Command {
	cmd := &Command{Type: name}
	cmd.CommandOptions = options
	return cmd
}
// generateIndexName builds an index name of the form "<typ>_<col1>_<col2>...",
// lowercased. The table name is not included.
func (this *Blueprint) generateIndexName(typ string, columns []string) string {
	joined := strings.Join(columns, "_")
	name := typ + "_" + joined
	return strings.ToLower(name)
}
// getAddedColumns returns the registered columns that are not flagged as
// changes, i.e. the columns to be added.
func (this *Blueprint) getAddedColumns() []*ColumnDefinition {
	isNew := func(col *ColumnDefinition, _ int) bool { return !col.change }
	return lo.Filter(this.columns, isNew)
}
// getChangedColumns returns the registered columns flagged as changes to
// existing columns.
func (this *Blueprint) getChangedColumns() []*ColumnDefinition {
	isChanged := func(col *ColumnDefinition, _ int) bool { return col.change }
	return lo.Filter(this.columns, isChanged)
}