339 lines
9.8 KiB
Go
339 lines
9.8 KiB
Go
package schema
|
|
|
|
import (
|
|
"fmt"
|
|
"strings"
|
|
|
|
"github.com/samber/lo"
|
|
|
|
"git.fsdpf.net/go/db"
|
|
)
|
|
|
|
// MysqlGrammar compiles schema Blueprint commands into MySQL statements.
// It has no fields; everything it needs arrives through method arguments.
type MysqlGrammar struct{}

var (
	// mysqlDefaultModifiers lists, in the order they are appended after a
	// column's type, the modifier names passed to GetColumnModifier by
	// addModifiers.
	mysqlDefaultModifiers = []string{
		"Unsigned", "Charset", "Collate", "VirtualAs", "StoredAs",
		"Nullable", "Default", "Increment", "Comment", "After", "First",
	}

	// mysqlSerials lists the column types for which an auto-increment flag
	// makes the "Increment" modifier emit " auto_increment primary key".
	mysqlSerials = []string{
		"bigInteger", "integer", "mediumInteger", "smallInteger", "tinyInteger",
	}
)
|
|
|
|
func (this MysqlGrammar) CompileTableExists() string {
|
|
return "select count(*) from information_schema.tables where table_name = ? and table_schema = ? and table_type = 'BASE TABLE'"
|
|
}
|
|
|
|
func (this MysqlGrammar) CompileColumnListing(table string) string {
|
|
return "select ordinal_position as `id`, column_name as `name` from information_schema.columns where table_name = ? and table_schema = ? order by ordinal_position"
|
|
}
|
|
|
|
// 创建表
|
|
func (this MysqlGrammar) CompileCreate(bp *Blueprint, command *Command, conn *db.Connection) string {
|
|
sql := this.compileCreateTable(bp, command, conn)
|
|
sql = this.compileCreateEncoding(sql, bp, conn)
|
|
sql = this.compileCreateEngine(sql, bp, conn)
|
|
return sql
|
|
}
|
|
|
|
// 创建表结构
|
|
func (this MysqlGrammar) compileCreateTable(bp *Blueprint, command *Command, conn *db.Connection) string {
|
|
temporary := "create"
|
|
if bp.temporary {
|
|
temporary = "create temporary"
|
|
}
|
|
return fmt.Sprintf("%s table %s (%s)", temporary, bp.table, strings.Join(this.getAddedColumns(bp), ", "))
|
|
}
|
|
|
|
// 创建表 字符集
|
|
func (this MysqlGrammar) compileCreateEncoding(sql string, bp *Blueprint, conn *db.Connection) string {
|
|
if bp.charset != "" {
|
|
sql = sql + " default charset=" + bp.charset
|
|
} else if conn.GetConfig().Charset != "" {
|
|
sql = sql + " default charset=" + conn.GetConfig().Charset
|
|
} else {
|
|
sql = sql + " default charset=utf8mb4"
|
|
}
|
|
|
|
if bp.collation != "" {
|
|
sql = sql + " collate=" + bp.collation
|
|
} else {
|
|
sql = sql + " collate=utf8mb4_general_ci"
|
|
}
|
|
|
|
return sql
|
|
}
|
|
|
|
// 创建表存储方式
|
|
func (this MysqlGrammar) compileCreateEngine(sql string, bp *Blueprint, conn *db.Connection) string {
|
|
if bp.engine != "" {
|
|
sql = sql + " engine = " + bp.engine
|
|
}
|
|
return sql
|
|
}
|
|
|
|
// 添加字段
|
|
func (this MysqlGrammar) CompileAdd(bp *Blueprint, command *Command, conn *db.Connection) []string {
|
|
columns := PrefixArray("add", this.getAddedColumns(bp))
|
|
return []string{"alter table `" + bp.table + "` " + strings.Join(columns, ", ")}
|
|
}
|
|
|
|
// 修改字段
|
|
func (this MysqlGrammar) CompileChange(bp *Blueprint, command *Command, conn *db.Connection) []string {
|
|
columns := PrefixArray("change", this.getChangedColumns(bp))
|
|
return []string{"alter table `" + bp.table + "` " + strings.Join(columns, ", ")}
|
|
}
|
|
|
|
// 添加主键
|
|
func (this MysqlGrammar) CompilePrimary(bp *Blueprint, command *Command, conn *db.Connection) string {
|
|
return this.compileKey(bp, command, "primary key")
|
|
}
|
|
|
|
// 添加唯一键
|
|
func (this MysqlGrammar) CompileUnique(bp *Blueprint, command *Command, conn *db.Connection) string {
|
|
return this.compileKey(bp, command, "unique")
|
|
}
|
|
|
|
// 添加普通索引
|
|
func (this MysqlGrammar) CompileIndex(bp *Blueprint, command *Command, conn *db.Connection) string {
|
|
return this.compileKey(bp, command, "index")
|
|
}
|
|
|
|
// 添加空间索引
|
|
func (this MysqlGrammar) CompileSpatialIndex(bp *Blueprint, command *Command, conn *db.Connection) string {
|
|
return this.compileKey(bp, command, "spatial index")
|
|
}
|
|
|
|
// 添加索引
|
|
func (this MysqlGrammar) compileKey(bp *Blueprint, command *Command, typ string) string {
|
|
algorithm := ""
|
|
if command.Algorithm != "" {
|
|
algorithm = " using " + command.Algorithm
|
|
}
|
|
columns := lo.Map(command.Columns, func(column string, _ int) string {
|
|
return "`" + column + "`"
|
|
})
|
|
return fmt.Sprintf("alter table %s add %s %s%s(%s)", bp.table, typ, command.Index, algorithm, strings.Join(columns, ","))
|
|
}
|
|
|
|
// 删除表
|
|
func (this MysqlGrammar) CompileDrop(bp *Blueprint, command *Command, conn *db.Connection) string {
|
|
return "drop table `" + bp.table + "`"
|
|
}
|
|
|
|
// 删除表, 先判断再删除
|
|
func (this MysqlGrammar) CompileDropIfExists(bp *Blueprint, command *Command, conn *db.Connection) string {
|
|
return "drop table if exists `" + bp.table + "`"
|
|
}
|
|
|
|
// 删除列
|
|
func (this MysqlGrammar) CompileDropColumn(bp *Blueprint, command *Command, conn *db.Connection) []string {
|
|
columns := lo.Map(command.Columns, func(column string, _ int) string {
|
|
return "drop `" + column + "`"
|
|
})
|
|
return []string{"alter table `" + bp.table + "` " + strings.Join(columns, ", ")}
|
|
}
|
|
|
|
// 删除主键
|
|
func (this MysqlGrammar) CompileDropPrimary(bp *Blueprint, command *Command, conn *db.Connection) string {
|
|
return "alter table `" + bp.table + "` drop primary key"
|
|
}
|
|
|
|
// 删除索引
|
|
func (this MysqlGrammar) CompileDropIndex(bp *Blueprint, command *Command, conn *db.Connection) string {
|
|
return "alter table `" + bp.table + "` drop index `" + command.Index + "`"
|
|
}
|
|
|
|
// 删除唯一键
|
|
func (this MysqlGrammar) CompileDropUnique(bp *Blueprint, command *Command, conn *db.Connection) string {
|
|
return this.CompileDropIndex(bp, command, conn)
|
|
}
|
|
|
|
// 删除空间索引
|
|
func (this MysqlGrammar) CompileDropSpatialIndex(bp *Blueprint, command *Command, conn *db.Connection) string {
|
|
return this.CompileDropIndex(bp, command, conn)
|
|
}
|
|
|
|
// 表重命名
|
|
func (this MysqlGrammar) CompileRename(bp *Blueprint, command *Command, conn *db.Connection) string {
|
|
return "rename table `" + bp.table + "` to `" + command.To + "`"
|
|
}
|
|
|
|
// 索引重命名
|
|
func (this MysqlGrammar) CompileRenameIndex(bp *Blueprint, command *Command, conn *db.Connection) string {
|
|
return fmt.Sprintf("alter table `%s` rename index `%s` to `%s`", bp.table, command.From, command.To)
|
|
}
|
|
|
|
// 删除所有表
|
|
func (this MysqlGrammar) CompileDropAllTables(tables ...string) string {
|
|
tables = lo.Map(tables, func(table string, _ int) string {
|
|
return "`" + table + "`"
|
|
})
|
|
return "drop table " + strings.Join(tables, ", ")
|
|
}
|
|
|
|
// 删除所有视图
|
|
func (this MysqlGrammar) CompileDropAllViews(views ...string) string {
|
|
views = lo.Map(views, func(view string, _ int) string {
|
|
return "`" + view + "`"
|
|
})
|
|
return "drop view " + strings.Join(views, ", ")
|
|
}
|
|
|
|
// 查询所有表
|
|
// func (this MysqlGrammar) CompileGetAllTables(tables ...string) string {
|
|
|
|
// }
|
|
|
|
// 查询所有视图
|
|
// func (this MysqlGrammar) CompileGetAllViews(views ...string) string {}
|
|
|
|
func (this MysqlGrammar) addModifiers(sql string, bp *Blueprint, column *ColumnDefinition) string {
|
|
for _, modifier := range mysqlDefaultModifiers {
|
|
sql = sql + this.GetColumnModifier(modifier, bp, column)
|
|
}
|
|
return sql
|
|
}
|
|
|
|
func (this MysqlGrammar) GetColumnType(column *ColumnDefinition) string {
|
|
switch column.Type {
|
|
case "char":
|
|
return fmt.Sprintf("char(%d)", column.Length)
|
|
case "string":
|
|
return fmt.Sprintf("varchar(%d)", column.Length)
|
|
case "text":
|
|
return "text"
|
|
case "integer":
|
|
return "int"
|
|
case "bigInteger":
|
|
return "bigint"
|
|
case "tinyInteger":
|
|
return "tinyint"
|
|
case "smallInteger":
|
|
return "smallint"
|
|
case "decimal":
|
|
return fmt.Sprintf("decimal(%d,%d)", column.Total, column.Places)
|
|
case "boolean":
|
|
return "tinyint(1)"
|
|
case "enum":
|
|
return fmt.Sprintf("enum(%s)", QuoteString(column.Allowed))
|
|
case "json":
|
|
return "json"
|
|
case "date":
|
|
return "date"
|
|
case "binary":
|
|
return "binary"
|
|
case "datetime":
|
|
columnType := "datetime"
|
|
if column.Precision > 0 {
|
|
columnType = fmt.Sprintf("datetime(%d)", column.Precision)
|
|
}
|
|
return columnType
|
|
case "time":
|
|
columnType := "time"
|
|
if column.Precision > 0 {
|
|
columnType = fmt.Sprintf("time(%d)", column.Precision)
|
|
}
|
|
return columnType
|
|
case "timestamp":
|
|
columnType := "timestamp"
|
|
if column.Precision > 0 {
|
|
columnType = fmt.Sprintf("timestamp(%d)", column.Precision)
|
|
}
|
|
return columnType
|
|
case "year":
|
|
return "year"
|
|
case "uuid":
|
|
return "char(36)"
|
|
}
|
|
panic("不支持的数据类型: " + column.Type)
|
|
}
|
|
|
|
func (this MysqlGrammar) GetColumnModifier(modifier string, bp *Blueprint, column *ColumnDefinition) string {
|
|
switch modifier {
|
|
case "Unsigned":
|
|
if column.unsigned {
|
|
return " unsigned"
|
|
}
|
|
case "Charset":
|
|
if column.charset != "" {
|
|
return fmt.Sprintf(" character set %s", column.charset)
|
|
}
|
|
case "Collate":
|
|
if column.collation != "" {
|
|
return fmt.Sprintf(" collate '%s'", column.collation)
|
|
}
|
|
case "VirtualAs":
|
|
if column.virtualAs != "" {
|
|
return fmt.Sprintf(" as (%s)", column.virtualAs)
|
|
}
|
|
case "StoredAs":
|
|
if column.storedAs != "" {
|
|
return fmt.Sprintf(" as (%s) stored", column.storedAs)
|
|
}
|
|
case "Nullable":
|
|
if column.virtualAs == "" && column.storedAs == "" {
|
|
if column.nullable {
|
|
return " null"
|
|
}
|
|
return " not null"
|
|
}
|
|
case "Default":
|
|
switch v := column.def.(type) {
|
|
case db.Expression:
|
|
if column.useCurrent {
|
|
return fmt.Sprintf(" default %s %s", "CURRENT_TIMESTAMP", v)
|
|
}
|
|
return fmt.Sprintf(" default %s", v)
|
|
case nil:
|
|
if column.useCurrent {
|
|
return fmt.Sprintf(" default %s", "CURRENT_TIMESTAMP")
|
|
}
|
|
default:
|
|
return fmt.Sprintf(" default '%v'", v)
|
|
}
|
|
|
|
case "Increment":
|
|
if column.autoIncrement && lo.Contains(mysqlSerials, column.Type) {
|
|
return " auto_increment primary key"
|
|
}
|
|
case "Comment":
|
|
if column.comment != "" {
|
|
return fmt.Sprintf(" comment '%s'", db.Addslashes(column.comment))
|
|
}
|
|
case "After":
|
|
if column.after != "" {
|
|
return fmt.Sprintf(" after `%s`", column.after)
|
|
}
|
|
case "First":
|
|
if column.first {
|
|
return " first"
|
|
}
|
|
}
|
|
|
|
return ""
|
|
}
|
|
|
|
// 获取新增表结构字段
|
|
func (this MysqlGrammar) getAddedColumns(bp *Blueprint) (columns []string) {
|
|
for _, column := range bp.getAddedColumns() {
|
|
sql := "`" + column.Name + "` " + this.GetColumnType(column)
|
|
columns = append(columns, this.addModifiers(sql, bp, column))
|
|
}
|
|
return
|
|
}
|
|
|
|
// 获取修改表结构字段
|
|
func (this MysqlGrammar) getChangedColumns(bp *Blueprint) (columns []string) {
|
|
for _, column := range bp.getChangedColumns() {
|
|
name := column.Name
|
|
rename := column.Name
|
|
if column.rename != "" {
|
|
rename = column.rename
|
|
}
|
|
sql := fmt.Sprintf("`%s` `%s` ", name, rename) + this.GetColumnType(column)
|
|
columns = append(columns, this.addModifiers(sql, bp, column))
|
|
}
|
|
return
|
|
}
|