352 lines
11 KiB
Go
352 lines
11 KiB
Go
|
package schema
|
||
|
|
||
|
import (
|
||
|
"fmt"
|
||
|
"regexp"
|
||
|
"strings"
|
||
|
|
||
|
"github.com/samber/lo"
|
||
|
|
||
|
"git.fsdpf.net/go/db"
|
||
|
)
|
||
|
|
||
|
// Sqlite3Grammar compiles schema Blueprint commands into SQLite DDL.
// It holds no state; every method uses a value receiver.
type Sqlite3Grammar struct{}
|
||
|
|
||
|
var sqlite3DefaultModifiers = []string{"VirtualAs", "StoredAs", "Nullable", "Default", "Increment"}
|
||
|
|
||
|
var sqlite3Serials = []string{"bigInteger", "integer", "mediumInteger", "smallInteger", "tinyInteger"}
|
||
|
|
||
|
func (this Sqlite3Grammar) CompileTableExists() string {
|
||
|
// 兼容其他数据库查询
|
||
|
// 其他数据库查看都是同时传入, 数据库名 和 数据表名
|
||
|
return "select count(*) from sqlite_master where type = 'table' and name = ?||?"
|
||
|
}
|
||
|
|
||
|
func (this Sqlite3Grammar) CompileColumnListing(table string) string {
|
||
|
// 兼容其他数据库查询
|
||
|
// 其他数据库查看都是同时传入, 数据库名 和 数据表名
|
||
|
return "pragma table_info(" + table + ")"
|
||
|
}
|
||
|
|
||
|
// 创建表
|
||
|
func (this Sqlite3Grammar) CompileCreate(bp *Blueprint, command *Command, conn *db.Connection) string {
|
||
|
temporary := "create"
|
||
|
if bp.temporary {
|
||
|
temporary = "create temporary"
|
||
|
}
|
||
|
return fmt.Sprintf("%s table %s (%s)", temporary, bp.table, strings.Join(this.getAddedColumns(bp), ", "))
|
||
|
}
|
||
|
|
||
|
// 添加字段
|
||
|
func (this Sqlite3Grammar) CompileAdd(bp *Blueprint, command *Command, conn *db.Connection) []string {
|
||
|
columns := PrefixArray("add column", this.getAddedColumns(bp))
|
||
|
|
||
|
return lo.Map(columns, func(column string, _ int) string {
|
||
|
return "alter table `" + bp.table + "` " + column
|
||
|
})
|
||
|
}
|
||
|
|
||
|
// 修改字段
|
||
|
func (this Sqlite3Grammar) CompileChange(bp *Blueprint, command *Command, conn *db.Connection) (sqls []string) {
|
||
|
// 创建表的sql
|
||
|
cTableSql := this.getCreateTable(bp, conn)
|
||
|
|
||
|
// 旧表字段
|
||
|
oldColumns := []string{}
|
||
|
if _, err := conn.Select("select '`'||name||'`' from pragma_table_info(?)", []any{bp.table}, &oldColumns); err != nil {
|
||
|
panic(fmt.Errorf("获取 SQLite 建表字段失败, %s", err))
|
||
|
}
|
||
|
|
||
|
// 数据同步
|
||
|
insert_sql := fmt.Sprintf("insert into `"+bp.table+"` (%s)", strings.Join(oldColumns, ", "))
|
||
|
select_sql := fmt.Sprintf("select %s from `"+bp.table+"_old`", strings.Join(oldColumns, ", "))
|
||
|
|
||
|
for _, column := range bp.getChangedColumns() {
|
||
|
name := column.Name
|
||
|
rename := column.rename
|
||
|
|
||
|
// 旧字段(查询) -> 新字段(插入)
|
||
|
if column.rename != "" {
|
||
|
select_sql = strings.ReplaceAll(select_sql, name, name+"` as `"+rename)
|
||
|
insert_sql = strings.ReplaceAll(insert_sql, name, rename)
|
||
|
} else {
|
||
|
rename = name
|
||
|
}
|
||
|
|
||
|
// 修改表字段
|
||
|
sql := "`" + rename + "` " + this.GetColumnType(column)
|
||
|
sql = this.addModifiers(sql, bp, column)
|
||
|
reg := regexp.MustCompile(fmt.Sprintf("(?iUm)^(?P<prefix>\\s*)(?P<content>`%s`.*)(?P<end>(,\\s|\\s)?)$", name))
|
||
|
cTableSql = reg.ReplaceAllString(cTableSql, fmt.Sprintf("${prefix}%s${end}", sql))
|
||
|
}
|
||
|
|
||
|
// 开启事物
|
||
|
// sqls = append(sqls, "BEGIN TRANSACTION")
|
||
|
// 1. 重命名已存在的表
|
||
|
sqls = append(sqls, fmt.Sprintf("alter table `%s` rename to `%s_old`", bp.table, bp.table))
|
||
|
// 2. 创建新结构表
|
||
|
sqls = append(sqls, cTableSql)
|
||
|
// 3. 复制数据
|
||
|
sqls = append(sqls, insert_sql+" "+select_sql)
|
||
|
// 4. 删除旧表
|
||
|
sqls = append(sqls, fmt.Sprintf("drop table `%s_old`", bp.table))
|
||
|
// 提交事物
|
||
|
// sqls = append(sqls, "COMMIT")
|
||
|
|
||
|
// @todo 复制索引
|
||
|
|
||
|
return sqls
|
||
|
}
|
||
|
|
||
|
// 创建唯一索引
|
||
|
func (this Sqlite3Grammar) CompileUnique(bp *Blueprint, command *Command, conn *db.Connection) string {
|
||
|
columns := lo.Map(command.Columns, func(column string, _ int) string {
|
||
|
return "`" + column + "`"
|
||
|
})
|
||
|
return fmt.Sprintf("create unique index `%s` on `%s` (%s)", command.Index, bp.table, strings.Join(columns, ", "))
|
||
|
}
|
||
|
|
||
|
// 创建普通索引
|
||
|
func (this Sqlite3Grammar) CompileIndex(bp *Blueprint, command *Command, conn *db.Connection) string {
|
||
|
columns := lo.Map(command.Columns, func(column string, _ int) string {
|
||
|
return "`" + column + "`"
|
||
|
})
|
||
|
return fmt.Sprintf("create index `%s` on `%s` (%s)", command.Index, bp.table, strings.Join(columns, ", "))
|
||
|
}
|
||
|
|
||
|
// 创建空间索引
|
||
|
func (this Sqlite3Grammar) CompileSpatialIndex(bp *Blueprint, command *Command, conn *db.Connection) string {
|
||
|
panic("Sqlite 不支持 SpatialIndex")
|
||
|
}
|
||
|
|
||
|
// 删除表
|
||
|
func (this Sqlite3Grammar) CompileDrop(bp *Blueprint, command *Command, conn *db.Connection) string {
|
||
|
return "drop table `" + bp.table + "`"
|
||
|
}
|
||
|
|
||
|
// 删除表, 先判断再删除
|
||
|
func (this Sqlite3Grammar) CompileDropIfExists(bp *Blueprint, command *Command, conn *db.Connection) string {
|
||
|
return "drop table if exists `" + bp.table + "`"
|
||
|
}
|
||
|
|
||
|
// 删除列
|
||
|
func (this Sqlite3Grammar) CompileDropColumn(bp *Blueprint, command *Command, conn *db.Connection) (sqls []string) {
|
||
|
// 创建表的sql, 排除不要的字段
|
||
|
cTableSql := this.getCreateTable(bp, conn, command.Columns...)
|
||
|
|
||
|
// 旧表字段
|
||
|
oldColumns := []string{}
|
||
|
|
||
|
if _, err := conn.Select(
|
||
|
"select '`'||name||'`' from pragma_table_info(?) where `name` not in("+strings.Trim(strings.Repeat("?,", len(command.Columns)), ",")+")",
|
||
|
append([]any{bp.table}, lo.ToAnySlice(command.Columns)...),
|
||
|
&oldColumns); err != nil {
|
||
|
panic(fmt.Errorf("获取 SQLite 建表字段失败, %s", err))
|
||
|
}
|
||
|
|
||
|
copy_data_sql := fmt.Sprintf("insert into `%s` select %s from `%s_old`", bp.table, strings.Join(oldColumns, ", "), bp.table)
|
||
|
|
||
|
for _, column := range bp.getChangedColumns() {
|
||
|
name := column.Name
|
||
|
rename := column.rename
|
||
|
|
||
|
if rename != "" {
|
||
|
copy_data_sql = strings.ReplaceAll(copy_data_sql, "`"+name+"`", "`"+rename+"`")
|
||
|
}
|
||
|
|
||
|
// 修改表字段
|
||
|
sql := "`" + rename + "` " + this.GetColumnType(column)
|
||
|
sql = this.addModifiers(sql, bp, column)
|
||
|
reg := regexp.MustCompile(fmt.Sprintf("(?iUm)^(?P<prefix>\\s*)(?P<content>`%s`.*)(?P<end>(,\\s|\\s)?)$", name))
|
||
|
cTableSql = reg.ReplaceAllString(cTableSql, fmt.Sprintf("${prefix}%s${end}", sql))
|
||
|
}
|
||
|
|
||
|
// 开启事物
|
||
|
// sqls = append(sqls, "BEGIN TRANSACTION")
|
||
|
// 1. 重命名已存在的表
|
||
|
sqls = append(sqls, fmt.Sprintf("alter table `%s` rename to `%s_old`", bp.table, bp.table))
|
||
|
// 2. 创建新结构表
|
||
|
sqls = append(sqls, cTableSql)
|
||
|
// 3. 复制数据
|
||
|
sqls = append(sqls, copy_data_sql)
|
||
|
// 4. 删除旧表
|
||
|
sqls = append(sqls, fmt.Sprintf("drop table `%s_old`", bp.table))
|
||
|
// 提交事物
|
||
|
// sqls = append(sqls, "COMMIT")
|
||
|
|
||
|
// @todo 复制索引
|
||
|
|
||
|
return sqls
|
||
|
}
|
||
|
|
||
|
// 删除唯一索引
|
||
|
func (this Sqlite3Grammar) CompileDropUnique(bp *Blueprint, command *Command, conn *db.Connection) string {
|
||
|
return this.CompileDropIndex(bp, command, conn)
|
||
|
}
|
||
|
|
||
|
// 删除索引
|
||
|
func (this Sqlite3Grammar) CompileDropIndex(bp *Blueprint, command *Command, conn *db.Connection) string {
|
||
|
return "drop index `" + command.Index + "`"
|
||
|
}
|
||
|
|
||
|
// 删除空间索引
|
||
|
func (this Sqlite3Grammar) CompileDropSpatialIndex(bp *Blueprint, command *Command, conn *db.Connection) string {
|
||
|
return this.CompileSpatialIndex(bp, command, conn)
|
||
|
}
|
||
|
|
||
|
func (this Sqlite3Grammar) CompileRename(bp *Blueprint, command *Command, conn *db.Connection) string {
|
||
|
return "alter table `" + bp.table + "` rename to `" + command.To + "`"
|
||
|
}
|
||
|
|
||
|
func (this Sqlite3Grammar) CompileRenameIndex(bp *Blueprint, command *Command, conn *db.Connection) string {
|
||
|
// 先删除索引, 后创建索引
|
||
|
return "alter table `" + bp.table + "` rename to `" + command.To + "`"
|
||
|
}
|
||
|
|
||
|
// 删除所有表
|
||
|
func (this Sqlite3Grammar) CompileDropAllTables() string {
|
||
|
return "delete from sqlite_master where type in ('table', 'index', 'trigger')"
|
||
|
}
|
||
|
|
||
|
// 删除所有视图
|
||
|
func (this Sqlite3Grammar) CompileDropAllViews() string {
|
||
|
return "delete from sqlite_master where type in ('view')"
|
||
|
}
|
||
|
|
||
|
func (this Sqlite3Grammar) addModifiers(sql string, bp *Blueprint, column *ColumnDefinition) string {
|
||
|
for _, modifier := range mysqlDefaultModifiers {
|
||
|
sql = sql + this.GetColumnModifier(modifier, bp, column)
|
||
|
}
|
||
|
return sql
|
||
|
}
|
||
|
|
||
|
func (this Sqlite3Grammar) GetColumnModifier(modifier string, bp *Blueprint, column *ColumnDefinition) string {
|
||
|
switch modifier {
|
||
|
case "VirtualAs":
|
||
|
if column.virtualAs != "" {
|
||
|
return fmt.Sprintf(" as (%s)", column.virtualAs)
|
||
|
}
|
||
|
case "StoredAs":
|
||
|
if column.storedAs != "" {
|
||
|
return fmt.Sprintf(" as (%s) stored", column.storedAs)
|
||
|
}
|
||
|
case "Nullable":
|
||
|
if !column.nullable {
|
||
|
return " not null"
|
||
|
}
|
||
|
case "Default":
|
||
|
if column.useCurrent {
|
||
|
return ""
|
||
|
}
|
||
|
switch v := column.def.(type) {
|
||
|
case db.Expression:
|
||
|
return fmt.Sprintf(" default %s", v)
|
||
|
case nil:
|
||
|
default:
|
||
|
return fmt.Sprintf(" default '%v'", v)
|
||
|
}
|
||
|
case "Increment":
|
||
|
if column.autoIncrement && lo.Contains(mysqlSerials, column.Type) {
|
||
|
return " primary key autoincrement"
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return ""
|
||
|
}
|
||
|
|
||
|
func (this Sqlite3Grammar) GetColumnType(column *ColumnDefinition) string {
|
||
|
switch column.Type {
|
||
|
case "char", "string":
|
||
|
return fmt.Sprintf("varchar(%d)", column.Length)
|
||
|
case "text":
|
||
|
return "text"
|
||
|
case "year", "integer":
|
||
|
return "integer"
|
||
|
case "bigInteger":
|
||
|
return "bigint"
|
||
|
case "tinyInteger":
|
||
|
return "tinyint"
|
||
|
case "smallInteger":
|
||
|
return "smallint"
|
||
|
case "decimal":
|
||
|
return fmt.Sprintf("numeric(%d,%d)", column.Total, column.Places)
|
||
|
case "boolean":
|
||
|
return "tinyint(1)"
|
||
|
case "enum":
|
||
|
return fmt.Sprintf(`varchar check ("%s" in (%s))`, column.Name, QuoteString(column.Allowed))
|
||
|
case "json":
|
||
|
return "json"
|
||
|
case "date":
|
||
|
return "date"
|
||
|
case "binary":
|
||
|
return "blob"
|
||
|
case "datetime", "timestamp":
|
||
|
columnType := column.Type
|
||
|
if column.useCurrent {
|
||
|
columnType = columnType + " default CURRENT_TIMESTAMP"
|
||
|
}
|
||
|
return columnType
|
||
|
case "time":
|
||
|
return "time"
|
||
|
case "uuid":
|
||
|
return "varchar(36)"
|
||
|
}
|
||
|
panic("不支持的数据类型: " + column.Type)
|
||
|
}
|
||
|
|
||
|
// 获取新增表结构字段
|
||
|
func (this Sqlite3Grammar) getAddedColumns(bp *Blueprint) (columns []string) {
|
||
|
for _, column := range bp.getAddedColumns() {
|
||
|
sql := "`" + column.Name + "` " + this.GetColumnType(column)
|
||
|
columns = append(columns, this.addModifiers(sql, bp, column))
|
||
|
}
|
||
|
return
|
||
|
}
|
||
|
|
||
|
// 获取修改表结构字段
|
||
|
func (this Sqlite3Grammar) getChangedColumns(bp *Blueprint) map[string]string {
|
||
|
columns := map[string]string{}
|
||
|
for _, column := range bp.getChangedColumns() {
|
||
|
sql := "`" + column.Name + "` " + this.GetColumnType(column)
|
||
|
columns[column.Name] = this.addModifiers(sql, bp, column)
|
||
|
}
|
||
|
return columns
|
||
|
}
|
||
|
|
||
|
// 数据库表字段
|
||
|
func (this Sqlite3Grammar) getCreateTable(bp *Blueprint, conn *db.Connection, without ...string) (sql string) {
|
||
|
dbColumns := []struct {
|
||
|
Name string `db:"name"`
|
||
|
Type string `db:"type"`
|
||
|
Default db.NullString `db:"dflt_value"`
|
||
|
NotNull bool `db:"notnull"`
|
||
|
Pk bool `db:"pk"`
|
||
|
}{}
|
||
|
|
||
|
if _, err := conn.Select("select * from pragma_table_info(?) order by cid", []any{bp.table}, &dbColumns); err != nil {
|
||
|
panic(fmt.Errorf("获取 SQLite 建表字段失败, %s", err))
|
||
|
}
|
||
|
|
||
|
cols := []string{}
|
||
|
for _, column := range dbColumns {
|
||
|
if lo.Contains(without, column.Name) {
|
||
|
continue
|
||
|
}
|
||
|
col := " `" + column.Name + "`"
|
||
|
col += " " + column.Type
|
||
|
if column.NotNull {
|
||
|
col += " not null"
|
||
|
}
|
||
|
if column.Pk {
|
||
|
col += " primary key autoincrement"
|
||
|
}
|
||
|
if column.Default != "" {
|
||
|
col += " default " + string(column.Default)
|
||
|
}
|
||
|
cols = append(cols, col)
|
||
|
}
|
||
|
|
||
|
return fmt.Sprintf("CREATE TABLE `%s` (\n%s\n)", bp.table, strings.Join(cols, ", \n"))
|
||
|
}
|