package schema

import (
	"fmt"
	"strings"

	"git.fsdpf.net/go/db"
	"github.com/samber/lo"
)

// MysqlGrammar compiles schema blueprints and commands into MySQL DDL
// statements. It carries no state.
type MysqlGrammar struct{}

// mysqlDefaultModifiers lists the column modifiers in the order in which
// addModifiers appends them to a column definition.
var mysqlDefaultModifiers = []string{
	"Unsigned",
	"Charset",
	"Collate",
	"VirtualAs",
	"StoredAs",
	"Nullable",
	"Default",
	"Increment",
	"Comment",
	"After",
	"First",
}

// mysqlSerials lists the column types for which the "Increment" modifier
// emits "auto_increment primary key".
var mysqlSerials = []string{
	"bigInteger",
	"integer",
	"mediumInteger",
	"smallInteger",
	"tinyInteger",
}

// CompileTableExists returns a parameterized query that counts base tables
// matching a table name (first placeholder) and schema (second placeholder).
func (g MysqlGrammar) CompileTableExists() string {
	return "select count(*) from information_schema.tables where table_name = ? and table_schema = ? and table_type = 'BASE TABLE'"
}

// CompileColumnListing returns a parameterized query that lists column
// positions and names for a table name (first placeholder) and schema
// (second placeholder), ordered by ordinal position. The table argument is
// not used when building the statement.
func (g MysqlGrammar) CompileColumnListing(table string) string {
	return "select ordinal_position as `id`, column_name as `name` from information_schema.columns where table_name = ? and table_schema = ? order by ordinal_position"
}

// CompileCreate builds a complete "create table" statement: the column list,
// followed by the charset/collation options and the storage engine option.
func (g MysqlGrammar) CompileCreate(bp *Blueprint, command *Command, conn *db.Connection) string {
	sql := g.compileCreateTable(bp, command, conn)
	sql = g.compileCreateEncoding(sql, bp, conn)
	sql = g.compileCreateEngine(sql, bp, conn)
	return sql
}

// compileCreateTable builds the "create [temporary] table `name` (columns)"
// part of the statement. The table name is wrapped in backticks, like in the
// other statements compiled by this grammar, so reserved words are accepted.
func (g MysqlGrammar) compileCreateTable(bp *Blueprint, command *Command, conn *db.Connection) string {
	create := "create"
	if bp.temporary {
		create = "create temporary"
	}
	return fmt.Sprintf("%s table `%s` (%s)", create, bp.table, strings.Join(g.getAddedColumns(bp), ", "))
}

// compileCreateEncoding appends the table charset and collation options.
//
// The charset is taken from the blueprint, then from the connection
// configuration, and finally defaults to utf8mb4. The collation is taken from
// the blueprint; the utf8mb4_general_ci default is only applied when the
// resolved charset is utf8mb4, because pairing it with any other charset makes
// the statement invalid. Without a collate clause MySQL uses the charset's own
// default collation.
func (g MysqlGrammar) compileCreateEncoding(sql string, bp *Blueprint, conn *db.Connection) string {
	charset := bp.charset
	if charset == "" {
		// The connection is only consulted when the blueprint sets no charset.
		charset = conn.GetConfig().Charset
	}
	if charset == "" {
		charset = "utf8mb4"
	}
	sql += " default charset=" + charset

	switch {
	case bp.collation != "":
		sql += " collate=" + bp.collation
	case strings.EqualFold(charset, "utf8mb4"):
		sql += " collate=utf8mb4_general_ci"
	}
	return sql
}

// compileCreateEngine appends the storage engine option when the blueprint
// specifies one; otherwise sql is returned unchanged.
func (g MysqlGrammar) compileCreateEngine(sql string, bp *Blueprint, conn *db.Connection) string {
	if bp.engine != "" {
		sql += " engine = " + bp.engine
	}
	return sql
}

// CompileAdd builds a single "alter table" statement that adds every new
// column of the blueprint.
func (g MysqlGrammar) CompileAdd(bp *Blueprint, command *Command, conn *db.Connection) []string {
	// NOTE(review): PrefixArray is defined elsewhere; presumably it prefixes
	// each column definition with the given keyword — confirm.
	columns := PrefixArray("add", g.getAddedColumns(bp))
	return []string{"alter table `" + bp.table + "` " + strings.Join(columns, ", ")}
}

// CompileChange builds a single "alter table" statement that changes every
// modified column of the blueprint.
func (g MysqlGrammar) CompileChange(bp *Blueprint, command *Command, conn *db.Connection) []string {
	// NOTE(review): PrefixArray is defined elsewhere; presumably it prefixes
	// each column definition with the given keyword — confirm.
	columns := PrefixArray("change", g.getChangedColumns(bp))
	return []string{"alter table `" + bp.table + "` " + strings.Join(columns, ", ")}
}

// CompilePrimary builds the statement that adds a primary key.
func (g MysqlGrammar) CompilePrimary(bp *Blueprint, command *Command, conn *db.Connection) string {
	return g.compileKey(bp, command, "primary key")
}

// CompileUnique builds the statement that adds a unique key.
func (g MysqlGrammar) CompileUnique(bp *Blueprint, command *Command, conn *db.Connection) string {
	return g.compileKey(bp, command, "unique")
}

// CompileIndex builds the statement that adds a plain index.
func (g MysqlGrammar) CompileIndex(bp *Blueprint, command *Command, conn *db.Connection) string {
	return g.compileKey(bp, command, "index")
}

// CompileSpatialIndex builds the statement that adds a spatial index.
func (g MysqlGrammar) CompileSpatialIndex(bp *Blueprint, command *Command, conn *db.Connection) string {
	return g.compileKey(bp, command, "spatial index")
}

// compileKey builds an "alter table ... add <typ>" statement for the index
// described by command. typ is the key kind, e.g. "unique" or "primary key".
//
// Table, index and column names are wrapped in backticks. The index name is
// only quoted when it is non-empty, so that an empty name does not turn into
// an invalid empty identifier.
func (g MysqlGrammar) compileKey(bp *Blueprint, command *Command, typ string) string {
	algorithm := ""
	if command.Algorithm != "" {
		algorithm = " using " + command.Algorithm
	}

	index := ""
	if command.Index != "" {
		index = "`" + command.Index + "`"
	}

	columns := lo.Map(command.Columns, func(column string, _ int) string {
		return "`" + column + "`"
	})

	return fmt.Sprintf("alter table `%s` add %s %s%s(%s)", bp.table, typ, index, algorithm, strings.Join(columns, ","))
}

// CompileDrop builds the statement that drops the blueprint's table.
func (g MysqlGrammar) CompileDrop(bp *Blueprint, command *Command, conn *db.Connection) string {
	return "drop table `" + bp.table + "`"
}

// CompileDropIfExists builds the statement that drops the blueprint's table
// only if it exists.
func (g MysqlGrammar) CompileDropIfExists(bp *Blueprint, command *Command, conn *db.Connection) string {
	return "drop table if exists `" + bp.table + "`"
}

// CompileDropColumn builds a single "alter table" statement that drops every
// column named in command.Columns.
func (g MysqlGrammar) CompileDropColumn(bp *Blueprint, command *Command, conn *db.Connection) []string {
	columns := lo.Map(command.Columns, func(column string, _ int) string {
		return "drop `" + column + "`"
	})
	return []string{"alter table `" + bp.table + "` " + strings.Join(columns, ", ")}
}

// CompileDropPrimary builds the statement that drops the primary key.
func (g MysqlGrammar) CompileDropPrimary(bp *Blueprint, command *Command, conn *db.Connection) string {
	return "alter table `" + bp.table + "` drop primary key"
}

// CompileDropIndex builds the statement that drops the index named by
// command.Index.
func (g MysqlGrammar) CompileDropIndex(bp *Blueprint, command *Command, conn *db.Connection) string {
	return "alter table `" + bp.table + "` drop index `" + command.Index + "`"
}

// CompileDropUnique builds the statement that drops a unique key; it produces
// the same statement as CompileDropIndex.
func (g MysqlGrammar) CompileDropUnique(bp *Blueprint, command *Command, conn *db.Connection) string {
	return g.CompileDropIndex(bp, command, conn)
}

// CompileDropSpatialIndex builds the statement that drops a spatial index; it
// produces the same statement as CompileDropIndex.
func (g MysqlGrammar) CompileDropSpatialIndex(bp *Blueprint, command *Command, conn *db.Connection) string {
	return g.CompileDropIndex(bp, command, conn)
}

// CompileRename builds the statement that renames the blueprint's table to
// command.To.
func (g MysqlGrammar) CompileRename(bp *Blueprint, command *Command, conn *db.Connection) string {
	return "rename table `" + bp.table + "` to `" + command.To + "`"
}

// CompileRenameIndex builds the statement that renames the index command.From
// to command.To.
func (g MysqlGrammar) CompileRenameIndex(bp *Blueprint, command *Command, conn *db.Connection) string {
	return fmt.Sprintf("alter table `%s` rename index `%s` to `%s`", bp.table, command.From, command.To)
}

// CompileDropAllTables builds one "drop table" statement covering all given
// tables. With no tables the result is just "drop table " with nothing after
// it, so callers should pass at least one name.
func (g MysqlGrammar) CompileDropAllTables(tables ...string) string {
	tables = lo.Map(tables, func(table string, _ int) string {
		return "`" + table + "`"
	})
	return "drop table " + strings.Join(tables, ", ")
}

// CompileDropAllViews builds one "drop view" statement covering all given
// views. With no views the result is just "drop view " with nothing after it,
// so callers should pass at least one name.
func (g MysqlGrammar) CompileDropAllViews(views ...string) string {
	views = lo.Map(views, func(view string, _ int) string {
		return "`" + view + "`"
	})
	return "drop view " + strings.Join(views, ", ")
}

// List all tables (not implemented).
// func (g MysqlGrammar) CompileGetAllTables(tables ...string) string {
// }

// List all views (not implemented).
// func (g MysqlGrammar) CompileGetAllViews(views ...string) string {}

// addModifiers appends every modifier from mysqlDefaultModifiers, in order,
// to the column definition in sql. Modifiers that do not apply contribute an
// empty string.
func (g MysqlGrammar) addModifiers(sql string, bp *Blueprint, column *ColumnDefinition) string {
	for _, modifier := range mysqlDefaultModifiers {
		sql += g.GetColumnModifier(modifier, bp, column)
	}
	return sql
}

// GetColumnType maps the blueprint column type to its MySQL type declaration.
// It panics when column.Type is not one of the supported types.
func (g MysqlGrammar) GetColumnType(column *ColumnDefinition) string {
	switch column.Type {
	case "char":
		return fmt.Sprintf("char(%d)", column.Length)
	case "string":
		return fmt.Sprintf("varchar(%d)", column.Length)
	case "text":
		return "text"
	case "integer":
		return "int"
	case "bigInteger":
		return "bigint"
	case "mediumInteger":
		// Listed in mysqlSerials, so it must be a valid column type as well.
		return "mediumint"
	case "tinyInteger":
		return "tinyint"
	case "smallInteger":
		return "smallint"
	case "decimal":
		return fmt.Sprintf("decimal(%d,%d)", column.Total, column.Places)
	case "boolean":
		return "tinyint(1)"
	case "enum":
		// NOTE(review): QuoteString is defined elsewhere; presumably it
		// renders the allowed values as a quoted, comma-separated list — confirm.
		return fmt.Sprintf("enum(%s)", QuoteString(column.Allowed))
	case "json":
		return "json"
	case "date":
		return "date"
	case "binary":
		return "binary"
	case "datetime", "time", "timestamp":
		// The MySQL type name equals the blueprint type name; a fractional
		// seconds precision is only emitted when it is positive.
		if column.Precision > 0 {
			return fmt.Sprintf("%s(%d)", column.Type, column.Precision)
		}
		return column.Type
	case "year":
		return "year"
	case "uuid":
		return "char(36)"
	}
	panic("不支持的数据类型: " + column.Type)
}

// GetColumnModifier returns the SQL fragment (with a leading space) for one
// named modifier of column, or an empty string when the modifier is unknown or
// does not apply to the column.
func (g MysqlGrammar) GetColumnModifier(modifier string, bp *Blueprint, column *ColumnDefinition) string {
	switch modifier {
	case "Unsigned":
		if column.unsigned {
			return " unsigned"
		}
	case "Charset":
		if column.charset != "" {
			return fmt.Sprintf(" character set %s", column.charset)
		}
	case "Collate":
		if column.collation != "" {
			return fmt.Sprintf(" collate '%s'", column.collation)
		}
	case "VirtualAs":
		if column.virtualAs != "" {
			return fmt.Sprintf(" as (%s)", column.virtualAs)
		}
	case "StoredAs":
		if column.storedAs != "" {
			return fmt.Sprintf(" as (%s) stored", column.storedAs)
		}
	case "Nullable":
		// Generated (virtual/stored) columns get no nullability clause.
		if column.virtualAs == "" && column.storedAs == "" {
			if column.nullable {
				return " null"
			}
			return " not null"
		}
	case "Default":
		switch v := column.def.(type) {
		case db.Expression:
			// Raw expressions are emitted verbatim, without quoting.
			if column.useCurrent {
				return fmt.Sprintf(" default %s %s", "CURRENT_TIMESTAMP", v)
			}
			return fmt.Sprintf(" default %s", v)
		case nil:
			if column.useCurrent {
				return fmt.Sprintf(" default %s", "CURRENT_TIMESTAMP")
			}
		default:
			// Escape the value like the column comment is escaped, so quotes
			// inside the default cannot terminate the string literal.
			return fmt.Sprintf(" default '%s'", db.Addslashes(fmt.Sprint(v)))
		}
	case "Increment":
		if column.autoIncrement && lo.Contains(mysqlSerials, column.Type) {
			return " auto_increment primary key"
		}
	case "Comment":
		if column.comment != "" {
			return fmt.Sprintf(" comment '%s'", db.Addslashes(column.comment))
		}
	case "After":
		if column.after != "" {
			return fmt.Sprintf(" after `%s`", column.after)
		}
	case "First":
		if column.first {
			return " first"
		}
	}
	return ""
}

// getAddedColumns returns one full column definition (name, type and
// modifiers) for every column the blueprint adds.
func (g MysqlGrammar) getAddedColumns(bp *Blueprint) []string {
	var columns []string
	for _, column := range bp.getAddedColumns() {
		sql := "`" + column.Name + "` " + g.GetColumnType(column)
		columns = append(columns, g.addModifiers(sql, bp, column))
	}
	return columns
}

// getChangedColumns returns one "`old` `new` type modifiers" definition for
// every column the blueprint changes. The new name equals the old name unless
// the column carries a rename target.
func (g MysqlGrammar) getChangedColumns(bp *Blueprint) []string {
	var columns []string
	for _, column := range bp.getChangedColumns() {
		rename := column.Name
		if column.rename != "" {
			rename = column.rename
		}
		sql := fmt.Sprintf("`%s` `%s` ", column.Name, rename) + g.GetColumnType(column)
		columns = append(columns, g.addModifiers(sql, bp, column))
	}
	return columns
}