package schema import ( "fmt" "regexp" "strings" "github.com/samber/lo" "git.fsdpf.net/go/db" ) type Sqlite3Grammar struct { } var sqlite3DefaultModifiers = []string{"VirtualAs", "StoredAs", "Nullable", "Default", "Increment"} var sqlite3Serials = []string{"bigInteger", "integer", "mediumInteger", "smallInteger", "tinyInteger"} func (this Sqlite3Grammar) CompileTableExists() string { // 兼容其他数据库查询 // 其他数据库查看都是同时传入, 数据库名 和 数据表名 return "select count(*) from sqlite_master where type = 'table' and name = ?||?" } func (this Sqlite3Grammar) CompileColumnListing(table string) string { // 兼容其他数据库查询 // 其他数据库查看都是同时传入, 数据库名 和 数据表名 return "pragma table_info(" + table + ")" } // 创建表 func (this Sqlite3Grammar) CompileCreate(bp *Blueprint, command *Command, conn *db.Connection) string { temporary := "create" if bp.temporary { temporary = "create temporary" } return fmt.Sprintf("%s table %s (%s)", temporary, bp.table, strings.Join(this.getAddedColumns(bp), ", ")) } // 添加字段 func (this Sqlite3Grammar) CompileAdd(bp *Blueprint, command *Command, conn *db.Connection) []string { columns := PrefixArray("add column", this.getAddedColumns(bp)) return lo.Map(columns, func(column string, _ int) string { return "alter table `" + bp.table + "` " + column }) } // 修改字段 func (this Sqlite3Grammar) CompileChange(bp *Blueprint, command *Command, conn *db.Connection) (sqls []string) { // 创建表的sql cTableSql := this.getCreateTable(bp, conn) // 旧表字段 oldColumns := []string{} if _, err := conn.Select("select '`'||name||'`' from pragma_table_info(?)", []any{bp.table}, &oldColumns); err != nil { panic(fmt.Errorf("获取 SQLite 建表字段失败, %s", err)) } // 数据同步 insert_sql := fmt.Sprintf("insert into `"+bp.table+"` (%s)", strings.Join(oldColumns, ", ")) select_sql := fmt.Sprintf("select %s from `"+bp.table+"_old`", strings.Join(oldColumns, ", ")) for _, column := range bp.getChangedColumns() { name := column.Name rename := column.rename // 旧字段(查询) -> 新字段(插入) if column.rename != "" { select_sql = strings.ReplaceAll(select_sql, name, name+"` as `"+rename) insert_sql = strings.ReplaceAll(insert_sql, name, rename) } else { rename = name } // 修改表字段 sql := "`" + rename + "` " + this.GetColumnType(column) sql = this.addModifiers(sql, bp, column) reg := regexp.MustCompile(fmt.Sprintf("(?iUm)^(?P\\s*)(?P`%s`.*)(?P(,\\s|\\s)?)$", name)) cTableSql = reg.ReplaceAllString(cTableSql, fmt.Sprintf("${prefix}%s${end}", sql)) } // 开启事物 // sqls = append(sqls, "BEGIN TRANSACTION") // 1. 重命名已存在的表 sqls = append(sqls, fmt.Sprintf("alter table `%s` rename to `%s_old`", bp.table, bp.table)) // 2. 创建新结构表 sqls = append(sqls, cTableSql) // 3. 复制数据 sqls = append(sqls, insert_sql+" "+select_sql) // 4. 删除旧表 sqls = append(sqls, fmt.Sprintf("drop table `%s_old`", bp.table)) // 提交事物 // sqls = append(sqls, "COMMIT") // @todo 复制索引 return sqls } // 创建唯一索引 func (this Sqlite3Grammar) CompileUnique(bp *Blueprint, command *Command, conn *db.Connection) string { columns := lo.Map(command.Columns, func(column string, _ int) string { return "`" + column + "`" }) return fmt.Sprintf("create unique index `%s` on `%s` (%s)", command.Index, bp.table, strings.Join(columns, ", ")) } // 创建普通索引 func (this Sqlite3Grammar) CompileIndex(bp *Blueprint, command *Command, conn *db.Connection) string { columns := lo.Map(command.Columns, func(column string, _ int) string { return "`" + column + "`" }) return fmt.Sprintf("create index `%s` on `%s` (%s)", command.Index, bp.table, strings.Join(columns, ", ")) } // 创建空间索引 func (this Sqlite3Grammar) CompileSpatialIndex(bp *Blueprint, command *Command, conn *db.Connection) string { panic("Sqlite 不支持 SpatialIndex") } // 删除表 func (this Sqlite3Grammar) CompileDrop(bp *Blueprint, command *Command, conn *db.Connection) string { return "drop table `" + bp.table + "`" } // 删除表, 先判断再删除 func (this Sqlite3Grammar) CompileDropIfExists(bp *Blueprint, command *Command, conn *db.Connection) string { return "drop table if exists `" + bp.table + "`" } // 删除列 func (this Sqlite3Grammar) CompileDropColumn(bp *Blueprint, command *Command, conn *db.Connection) (sqls []string) { // 创建表的sql, 排除不要的字段 cTableSql := this.getCreateTable(bp, conn, command.Columns...) // 旧表字段 oldColumns := []string{} if _, err := conn.Select( "select '`'||name||'`' from pragma_table_info(?) where `name` not in("+strings.Trim(strings.Repeat("?,", len(command.Columns)), ",")+")", append([]any{bp.table}, lo.ToAnySlice(command.Columns)...), &oldColumns); err != nil { panic(fmt.Errorf("获取 SQLite 建表字段失败, %s", err)) } copy_data_sql := fmt.Sprintf("insert into `%s` select %s from `%s_old`", bp.table, strings.Join(oldColumns, ", "), bp.table) for _, column := range bp.getChangedColumns() { name := column.Name rename := column.rename if rename != "" { copy_data_sql = strings.ReplaceAll(copy_data_sql, "`"+name+"`", "`"+rename+"`") } // 修改表字段 sql := "`" + rename + "` " + this.GetColumnType(column) sql = this.addModifiers(sql, bp, column) reg := regexp.MustCompile(fmt.Sprintf("(?iUm)^(?P\\s*)(?P`%s`.*)(?P(,\\s|\\s)?)$", name)) cTableSql = reg.ReplaceAllString(cTableSql, fmt.Sprintf("${prefix}%s${end}", sql)) } // 开启事物 // sqls = append(sqls, "BEGIN TRANSACTION") // 1. 重命名已存在的表 sqls = append(sqls, fmt.Sprintf("alter table `%s` rename to `%s_old`", bp.table, bp.table)) // 2. 创建新结构表 sqls = append(sqls, cTableSql) // 3. 复制数据 sqls = append(sqls, copy_data_sql) // 4. 删除旧表 sqls = append(sqls, fmt.Sprintf("drop table `%s_old`", bp.table)) // 提交事物 // sqls = append(sqls, "COMMIT") // @todo 复制索引 return sqls } // 删除唯一索引 func (this Sqlite3Grammar) CompileDropUnique(bp *Blueprint, command *Command, conn *db.Connection) string { return this.CompileDropIndex(bp, command, conn) } // 删除索引 func (this Sqlite3Grammar) CompileDropIndex(bp *Blueprint, command *Command, conn *db.Connection) string { return "drop index `" + command.Index + "`" } // 删除空间索引 func (this Sqlite3Grammar) CompileDropSpatialIndex(bp *Blueprint, command *Command, conn *db.Connection) string { return this.CompileSpatialIndex(bp, command, conn) } func (this Sqlite3Grammar) CompileRename(bp *Blueprint, command *Command, conn *db.Connection) string { return "alter table `" + bp.table + "` rename to `" + command.To + "`" } func (this Sqlite3Grammar) CompileRenameIndex(bp *Blueprint, command *Command, conn *db.Connection) string { // 先删除索引, 后创建索引 return "alter table `" + bp.table + "` rename to `" + command.To + "`" } // 删除所有表 func (this Sqlite3Grammar) CompileDropAllTables() string { return "delete from sqlite_master where type in ('table', 'index', 'trigger')" } // 删除所有视图 func (this Sqlite3Grammar) CompileDropAllViews() string { return "delete from sqlite_master where type in ('view')" } func (this Sqlite3Grammar) addModifiers(sql string, bp *Blueprint, column *ColumnDefinition) string { for _, modifier := range mysqlDefaultModifiers { sql = sql + this.GetColumnModifier(modifier, bp, column) } return sql } func (this Sqlite3Grammar) GetColumnModifier(modifier string, bp *Blueprint, column *ColumnDefinition) string { switch modifier { case "VirtualAs": if column.virtualAs != "" { return fmt.Sprintf(" as (%s)", column.virtualAs) } case "StoredAs": if column.storedAs != "" { return fmt.Sprintf(" as (%s) stored", column.storedAs) } case "Nullable": if !column.nullable { return " not null" } case "Default": if column.useCurrent { return "" } switch v := column.def.(type) { case db.Expression: return fmt.Sprintf(" default %s", v) case nil: default: return fmt.Sprintf(" default '%v'", v) } case "Increment": if column.autoIncrement && lo.Contains(mysqlSerials, column.Type) { return " primary key autoincrement" } } return "" } func (this Sqlite3Grammar) GetColumnType(column *ColumnDefinition) string { switch column.Type { case "char", "string": return fmt.Sprintf("varchar(%d)", column.Length) case "text": return "text" case "year", "integer": return "integer" case "bigInteger": return "bigint" case "tinyInteger": return "tinyint" case "smallInteger": return "smallint" case "decimal": return fmt.Sprintf("numeric(%d,%d)", column.Total, column.Places) case "boolean": return "tinyint(1)" case "enum": return fmt.Sprintf(`varchar check ("%s" in (%s))`, column.Name, QuoteString(column.Allowed)) case "json": return "json" case "date": return "date" case "binary": return "blob" case "datetime", "timestamp": columnType := column.Type if column.useCurrent { columnType = columnType + " default CURRENT_TIMESTAMP" } return columnType case "time": return "time" case "uuid": return "varchar(36)" } panic("不支持的数据类型: " + column.Type) } // 获取新增表结构字段 func (this Sqlite3Grammar) getAddedColumns(bp *Blueprint) (columns []string) { for _, column := range bp.getAddedColumns() { sql := "`" + column.Name + "` " + this.GetColumnType(column) columns = append(columns, this.addModifiers(sql, bp, column)) } return } // 获取修改表结构字段 func (this Sqlite3Grammar) getChangedColumns(bp *Blueprint) map[string]string { columns := map[string]string{} for _, column := range bp.getChangedColumns() { sql := "`" + column.Name + "` " + this.GetColumnType(column) columns[column.Name] = this.addModifiers(sql, bp, column) } return columns } // 数据库表字段 func (this Sqlite3Grammar) getCreateTable(bp *Blueprint, conn *db.Connection, without ...string) (sql string) { dbColumns := []struct { Name string `db:"name"` Type string `db:"type"` Default db.NullString `db:"dflt_value"` NotNull bool `db:"notnull"` Pk bool `db:"pk"` }{} if _, err := conn.Select("select * from pragma_table_info(?) order by cid", []any{bp.table}, &dbColumns); err != nil { panic(fmt.Errorf("获取 SQLite 建表字段失败, %s", err)) } cols := []string{} for _, column := range dbColumns { if lo.Contains(without, column.Name) { continue } col := " `" + column.Name + "`" col += " " + column.Type if column.NotNull { col += " not null" } if column.Pk { col += " primary key autoincrement" } if column.Default != "" { col += " default " + string(column.Default) } cols = append(cols, col) } return fmt.Sprintf("CREATE TABLE `%s` (\n%s\n)", bp.table, strings.Join(cols, ", \n")) }