db/schema/mysql_grammar.go
2023-04-12 15:58:25 +08:00

339 lines
9.8 KiB
Go

package schema
import (
"fmt"
"strings"
"github.com/samber/lo"
"git.fsdpf.net/go/db"
)
// MysqlGrammar compiles schema blueprints and commands into MySQL DDL
// statements. It has no fields, so the zero value is ready to use.
type MysqlGrammar struct{}
// mysqlDefaultModifiers lists the column modifiers in the exact order in which
// addModifiers appends them to a column definition.
var mysqlDefaultModifiers = []string{
	"Unsigned",
	"Charset",
	"Collate",
	"VirtualAs",
	"StoredAs",
	"Nullable",
	"Default",
	"Increment",
	"Comment",
	"After",
	"First",
}
// mysqlSerials lists the blueprint column types for which the "Increment"
// modifier is allowed to emit "auto_increment primary key".
var mysqlSerials = []string{
	"bigInteger",
	"integer",
	"mediumInteger",
	"smallInteger",
	"tinyInteger",
}
// CompileTableExists returns the query that counts base tables matching a
// table name and a schema. The two placeholders are, in order, the table name
// and the schema name.
func (g MysqlGrammar) CompileTableExists() string {
	const query = "select count(*) from information_schema.tables" +
		" where table_name = ? and table_schema = ?" +
		" and table_type = 'BASE TABLE'"
	return query
}
// CompileColumnListing returns the query that lists a table's columns
// (ordinal position as `id`, column name as `name`) ordered by position.
// The table argument is not interpolated into the SQL; the two placeholders
// are, in order, the table name and the schema name.
func (g MysqlGrammar) CompileColumnListing(table string) string {
	const (
		selectClause = "select ordinal_position as `id`, column_name as `name` from information_schema.columns"
		whereClause  = " where table_name = ? and table_schema = ?"
		orderClause  = " order by ordinal_position"
	)
	return selectClause + whereClause + orderClause
}
// CompileCreate builds the complete "create table" statement for a blueprint:
// the column definitions first, then the charset/collation clause, and finally
// the engine clause.
func (g MysqlGrammar) CompileCreate(bp *Blueprint, command *Command, conn *db.Connection) string {
	return g.compileCreateEngine(
		g.compileCreateEncoding(g.compileCreateTable(bp, command, conn), bp, conn),
		bp,
		conn,
	)
}
// compileCreateTable builds the head of the statement,
// "create [temporary] table `name` (col, col, ...)", from the blueprint's
// added columns.
//
// The table name is wrapped in backticks, as in every other statement this
// grammar produces, so that reserved words and unusual identifiers are
// accepted by MySQL. command and conn are not used here.
func (g MysqlGrammar) compileCreateTable(bp *Blueprint, command *Command, conn *db.Connection) string {
	verb := "create"
	if bp.temporary {
		verb = "create temporary"
	}
	columns := strings.Join(g.getAddedColumns(bp), ", ")
	return fmt.Sprintf("%s table `%s` (%s)", verb, bp.table, columns)
}
// compileCreateEncoding appends the table's character set and collation
// clauses to sql and returns the result.
//
// Charset resolution order: the blueprint's charset, then the connection
// config's Charset, then "utf8mb4". The connection is only consulted when the
// blueprint does not specify a charset.
//
// Collation: the blueprint's collation is used when set. Otherwise the
// "utf8mb4_general_ci" fallback is appended only when the resolved charset is
// utf8mb4; for any other charset the collate clause is omitted so MySQL picks
// that charset's own default collation. Pairing a utf8mb4 collation with a
// different charset is rejected by MySQL.
func (g MysqlGrammar) compileCreateEncoding(sql string, bp *Blueprint, conn *db.Connection) string {
	const (
		defaultCharset   = "utf8mb4"
		defaultCollation = "utf8mb4_general_ci"
	)

	charset := defaultCharset
	switch {
	case bp.charset != "":
		charset = bp.charset
	case conn.GetConfig().Charset != "":
		charset = conn.GetConfig().Charset
	}
	sql += " default charset=" + charset

	switch {
	case bp.collation != "":
		sql += " collate=" + bp.collation
	case strings.EqualFold(charset, defaultCharset):
		// The fallback collation is only valid for the utf8mb4 charset.
		sql += " collate=" + defaultCollation
	}
	return sql
}
// compileCreateEngine appends " engine = <name>" to sql when the blueprint
// names a storage engine; otherwise sql is returned unchanged. conn is unused.
func (g MysqlGrammar) compileCreateEngine(sql string, bp *Blueprint, conn *db.Connection) string {
	if bp.engine == "" {
		return sql
	}
	return sql + " engine = " + bp.engine
}
// CompileAdd returns a single "alter table" statement that adds every column
// the blueprint marks as added. Each column definition is passed through
// PrefixArray with "add" (presumably prefixing each entry; see that helper)
// and the results are joined with ", ".
func (g MysqlGrammar) CompileAdd(bp *Blueprint, command *Command, conn *db.Connection) []string {
	clauses := strings.Join(PrefixArray("add", g.getAddedColumns(bp)), ", ")
	statement := "alter table `" + bp.table + "` " + clauses
	return []string{statement}
}
// CompileChange returns a single "alter table" statement that modifies every
// column the blueprint marks as changed. Each column definition is passed
// through PrefixArray with "change" (presumably prefixing each entry; see that
// helper) and the results are joined with ", ".
func (g MysqlGrammar) CompileChange(bp *Blueprint, command *Command, conn *db.Connection) []string {
	clauses := strings.Join(PrefixArray("change", g.getChangedColumns(bp)), ", ")
	statement := "alter table `" + bp.table + "` " + clauses
	return []string{statement}
}
// CompilePrimary returns the statement that adds a primary key, delegating to
// compileKey with the key type "primary key". conn is unused.
func (g MysqlGrammar) CompilePrimary(bp *Blueprint, command *Command, conn *db.Connection) string {
	const keyType = "primary key"
	return g.compileKey(bp, command, keyType)
}
// CompileUnique returns the statement that adds a unique key, delegating to
// compileKey with the key type "unique". conn is unused.
func (g MysqlGrammar) CompileUnique(bp *Blueprint, command *Command, conn *db.Connection) string {
	const keyType = "unique"
	return g.compileKey(bp, command, keyType)
}
// CompileIndex returns the statement that adds a plain index, delegating to
// compileKey with the key type "index". conn is unused.
func (g MysqlGrammar) CompileIndex(bp *Blueprint, command *Command, conn *db.Connection) string {
	const keyType = "index"
	return g.compileKey(bp, command, keyType)
}
// CompileSpatialIndex returns the statement that adds a spatial index,
// delegating to compileKey with the key type "spatial index". conn is unused.
func (g MysqlGrammar) CompileSpatialIndex(bp *Blueprint, command *Command, conn *db.Connection) string {
	const keyType = "spatial index"
	return g.compileKey(bp, command, keyType)
}
// compileKey builds an "alter table ... add <typ> ..." statement for the key
// or index described by command.
//
// typ is the SQL key kind ("primary key", "unique", "index", "spatial index").
// The table name, the index name (when command.Index is non-empty) and every
// column name are wrapped in backticks, matching how the drop/rename
// statements of this grammar quote identifiers. When command.Algorithm is set,
// " using <algorithm>" is inserted before the column list.
func (g MysqlGrammar) compileKey(bp *Blueprint, command *Command, typ string) string {
	// An empty index name (e.g. an unnamed primary key) must stay empty
	// rather than become an empty quoted identifier.
	index := ""
	if command.Index != "" {
		index = "`" + command.Index + "`"
	}

	algorithm := ""
	if command.Algorithm != "" {
		algorithm = " using " + command.Algorithm
	}

	columns := make([]string, 0, len(command.Columns))
	for _, column := range command.Columns {
		columns = append(columns, "`"+column+"`")
	}

	return fmt.Sprintf("alter table `%s` add %s %s%s(%s)", bp.table, typ, index, algorithm, strings.Join(columns, ","))
}
// CompileDrop returns the statement that drops the blueprint's table.
// command and conn are unused.
func (g MysqlGrammar) CompileDrop(bp *Blueprint, command *Command, conn *db.Connection) string {
	return fmt.Sprintf("drop table `%s`", bp.table)
}
// CompileDropIfExists returns the statement that drops the blueprint's table
// only if it exists. command and conn are unused.
func (g MysqlGrammar) CompileDropIfExists(bp *Blueprint, command *Command, conn *db.Connection) string {
	return fmt.Sprintf("drop table if exists `%s`", bp.table)
}
// CompileDropColumn returns a single "alter table" statement that drops every
// column named in command.Columns. Column names are wrapped in backticks and
// the drop clauses are joined with ", ". conn is unused.
func (g MysqlGrammar) CompileDropColumn(bp *Blueprint, command *Command, conn *db.Connection) []string {
	drops := lo.Map(command.Columns, func(name string, _ int) string {
		return fmt.Sprintf("drop `%s`", name)
	})
	statement := fmt.Sprintf("alter table `%s` %s", bp.table, strings.Join(drops, ", "))
	return []string{statement}
}
// CompileDropPrimary returns the statement that drops the table's primary
// key. command and conn are unused.
func (g MysqlGrammar) CompileDropPrimary(bp *Blueprint, command *Command, conn *db.Connection) string {
	return fmt.Sprintf("alter table `%s` drop primary key", bp.table)
}
// CompileDropIndex returns the statement that drops the index named by
// command.Index from the blueprint's table. conn is unused.
func (g MysqlGrammar) CompileDropIndex(bp *Blueprint, command *Command, conn *db.Connection) string {
	return fmt.Sprintf("alter table `%s` drop index `%s`", bp.table, command.Index)
}
// CompileDropUnique returns the statement that drops a unique key. It produces
// exactly what CompileDropIndex produces for the same arguments.
func (g MysqlGrammar) CompileDropUnique(bp *Blueprint, command *Command, conn *db.Connection) string {
	statement := g.CompileDropIndex(bp, command, conn)
	return statement
}
// CompileDropSpatialIndex returns the statement that drops a spatial index. It
// produces exactly what CompileDropIndex produces for the same arguments.
func (g MysqlGrammar) CompileDropSpatialIndex(bp *Blueprint, command *Command, conn *db.Connection) string {
	statement := g.CompileDropIndex(bp, command, conn)
	return statement
}
// CompileRename returns the statement that renames the blueprint's table to
// command.To. conn is unused.
func (g MysqlGrammar) CompileRename(bp *Blueprint, command *Command, conn *db.Connection) string {
	return fmt.Sprintf("rename table `%s` to `%s`", bp.table, command.To)
}
// CompileRenameIndex returns the statement that renames the index
// command.From to command.To on the blueprint's table. conn is unused.
func (g MysqlGrammar) CompileRenameIndex(bp *Blueprint, command *Command, conn *db.Connection) string {
	target := "alter table `" + bp.table + "`"
	return target + " rename index `" + command.From + "` to `" + command.To + "`"
}
// CompileDropAllTables returns one "drop table" statement covering every given
// table name. Each name is wrapped in backticks and the names are joined with
// ", ". The caller's arguments are not modified.
func (g MysqlGrammar) CompileDropAllTables(tables ...string) string {
	quoted := lo.Map(tables, func(name string, _ int) string {
		return "`" + name + "`"
	})
	list := strings.Join(quoted, ", ")
	return "drop table " + list
}
// CompileDropAllViews returns one "drop view" statement covering every given
// view name. Each name is wrapped in backticks and the names are joined with
// ", ". The caller's arguments are not modified.
func (g MysqlGrammar) CompileDropAllViews(views ...string) string {
	quoted := lo.Map(views, func(name string, _ int) string {
		return "`" + name + "`"
	})
	list := strings.Join(quoted, ", ")
	return "drop view " + list
}
// Query all tables (not implemented yet).
// func (this MysqlGrammar) CompileGetAllTables(tables ...string) string {
// }
// Query all views (not implemented yet).
// func (this MysqlGrammar) CompileGetAllViews(views ...string) string {}
// addModifiers appends to sql the fragment produced by GetColumnModifier for
// each entry of mysqlDefaultModifiers, in list order, and returns the result.
func (g MysqlGrammar) addModifiers(sql string, bp *Blueprint, column *ColumnDefinition) string {
	definition := sql
	for i := range mysqlDefaultModifiers {
		definition += g.GetColumnModifier(mysqlDefaultModifiers[i], bp, column)
	}
	return definition
}
// GetColumnType maps a blueprint column type to its MySQL column type.
//
// char/string use column.Length, decimal uses column.Total and column.Places,
// enum uses QuoteString(column.Allowed), and datetime/time/timestamp carry a
// fractional-seconds precision only when column.Precision is greater than 0.
// "mediumInteger" maps to "mediumint"; it is listed in mysqlSerials and would
// otherwise be rejected here.
//
// It panics when column.Type is not one of the supported types.
func (g MysqlGrammar) GetColumnType(column *ColumnDefinition) string {
	// withPrecision renders a temporal type, adding "(n)" only for n > 0.
	withPrecision := func(name string) string {
		if column.Precision > 0 {
			return fmt.Sprintf("%s(%d)", name, column.Precision)
		}
		return name
	}

	switch column.Type {
	case "char":
		return fmt.Sprintf("char(%d)", column.Length)
	case "string":
		return fmt.Sprintf("varchar(%d)", column.Length)
	case "text":
		return "text"
	case "integer":
		return "int"
	case "bigInteger":
		return "bigint"
	case "mediumInteger":
		return "mediumint"
	case "tinyInteger":
		return "tinyint"
	case "smallInteger":
		return "smallint"
	case "decimal":
		return fmt.Sprintf("decimal(%d,%d)", column.Total, column.Places)
	case "boolean":
		return "tinyint(1)"
	case "enum":
		return fmt.Sprintf("enum(%s)", QuoteString(column.Allowed))
	case "json":
		return "json"
	case "date":
		return "date"
	case "binary":
		return "binary"
	case "datetime":
		return withPrecision("datetime")
	case "time":
		return withPrecision("time")
	case "timestamp":
		return withPrecision("timestamp")
	case "year":
		return "year"
	case "uuid":
		return "char(36)"
	}
	// Message means "unsupported data type"; kept verbatim for callers/logs.
	panic("不支持的数据类型: " + column.Type)
}
// GetColumnModifier returns the SQL fragment (with a leading space) for one
// named modifier of column, or "" when the modifier does not apply or the name
// is unknown. bp is not used.
//
// Modifiers:
//   - "Unsigned":  " unsigned" when column.unsigned is set.
//   - "Charset":   " character set <charset>" when column.charset is set.
//   - "Collate":   " collate '<collation>'" when column.collation is set.
//   - "VirtualAs": " as (<expr>)" when column.virtualAs is set.
//   - "StoredAs":  " as (<expr>) stored" when column.storedAs is set.
//   - "Nullable":  " null" / " not null"; skipped for generated columns.
//   - "Default":   see below.
//   - "Increment": " auto_increment primary key" for auto-increment columns
//     whose type is listed in mysqlSerials.
//   - "Comment":   " comment '<text>'" with the text passed through
//     db.Addslashes.
//   - "After":     " after `<column>`" when column.after is set.
//   - "First":     " first" when column.first is set.
//
// Default values: a db.Expression is emitted verbatim (preceded by
// CURRENT_TIMESTAMP when column.useCurrent is set); a nil default yields
// CURRENT_TIMESTAMP only when column.useCurrent is set; booleans become '1' or
// '0' so they are valid for integer-backed columns; strings are passed through
// db.Addslashes (presumably escaping quotes, as for comments) so an embedded
// quote cannot terminate the literal early; any other value is formatted with
// %v inside single quotes.
func (g MysqlGrammar) GetColumnModifier(modifier string, bp *Blueprint, column *ColumnDefinition) string {
	switch modifier {
	case "Unsigned":
		if column.unsigned {
			return " unsigned"
		}
	case "Charset":
		if column.charset != "" {
			return fmt.Sprintf(" character set %s", column.charset)
		}
	case "Collate":
		if column.collation != "" {
			return fmt.Sprintf(" collate '%s'", column.collation)
		}
	case "VirtualAs":
		if column.virtualAs != "" {
			return fmt.Sprintf(" as (%s)", column.virtualAs)
		}
	case "StoredAs":
		if column.storedAs != "" {
			return fmt.Sprintf(" as (%s) stored", column.storedAs)
		}
	case "Nullable":
		// Generated columns get no explicit nullability clause.
		if column.virtualAs == "" && column.storedAs == "" {
			if column.nullable {
				return " null"
			}
			return " not null"
		}
	case "Default":
		switch v := column.def.(type) {
		case db.Expression:
			if column.useCurrent {
				return fmt.Sprintf(" default %s %s", "CURRENT_TIMESTAMP", v)
			}
			return fmt.Sprintf(" default %s", v)
		case nil:
			if column.useCurrent {
				return fmt.Sprintf(" default %s", "CURRENT_TIMESTAMP")
			}
		case bool:
			// 'true'/'false' are not valid defaults for tinyint(1) columns.
			if v {
				return " default '1'"
			}
			return " default '0'"
		case string:
			return fmt.Sprintf(" default '%s'", db.Addslashes(v))
		default:
			return fmt.Sprintf(" default '%v'", v)
		}
	case "Increment":
		if column.autoIncrement && lo.Contains(mysqlSerials, column.Type) {
			return " auto_increment primary key"
		}
	case "Comment":
		if column.comment != "" {
			return fmt.Sprintf(" comment '%s'", db.Addslashes(column.comment))
		}
	case "After":
		if column.after != "" {
			return fmt.Sprintf(" after `%s`", column.after)
		}
	case "First":
		if column.first {
			return " first"
		}
	}
	return ""
}
// getAddedColumns renders one column definition per column the blueprint
// reports as added: the backtick-quoted name, the MySQL type from
// GetColumnType, then all modifiers from addModifiers. The result is nil when
// the blueprint has no added columns.
func (g MysqlGrammar) getAddedColumns(bp *Blueprint) (columns []string) {
	for _, col := range bp.getAddedColumns() {
		definition := "`" + col.Name + "` " + g.GetColumnType(col)
		definition = g.addModifiers(definition, bp, col)
		columns = append(columns, definition)
	}
	return columns
}
// getChangedColumns renders one "`old` `new` <type><modifiers>" definition per
// column the blueprint reports as changed. The new name is column.rename when
// set, otherwise the current name is repeated. The result is nil when the
// blueprint has no changed columns.
func (g MysqlGrammar) getChangedColumns(bp *Blueprint) (columns []string) {
	for _, col := range bp.getChangedColumns() {
		target := col.rename
		if target == "" {
			target = col.Name
		}
		definition := fmt.Sprintf("`%s` `%s` %s", col.Name, target, g.GetColumnType(col))
		columns = append(columns, g.addModifiers(definition, bp, col))
	}
	return columns
}