db/schema/sqlite3_grammar.go
2023-04-12 15:58:25 +08:00

352 lines
11 KiB
Go

package schema
import (
"fmt"
"regexp"
"strings"
"github.com/samber/lo"
"git.fsdpf.net/go/db"
)
// Sqlite3Grammar compiles schema-builder blueprints into SQLite DDL
// statements. It has no fields, so its zero value is ready to use.
type Sqlite3Grammar struct{}

var (
	// sqlite3DefaultModifiers lists the column modifiers handled by
	// Sqlite3Grammar.GetColumnModifier.
	sqlite3DefaultModifiers = []string{"VirtualAs", "StoredAs", "Nullable", "Default", "Increment"}

	// sqlite3Serials lists the blueprint integer column types.
	sqlite3Serials = []string{"bigInteger", "integer", "mediumInteger", "smallInteger", "tinyInteger"}
)
// CompileTableExists returns the query that tests whether a table exists: it
// counts sqlite_master rows of type 'table' whose name equals the
// concatenation (||) of the two bound parameters. Two placeholders are used so
// callers can bind the same database-name and table-name pair they bind for
// the other grammars; presumably the first value is empty or a name prefix —
// TODO confirm against the caller.
func (this Sqlite3Grammar) CompileTableExists() string {
	const query = "select count(*) from sqlite_master where type = 'table' and name = ?||?"
	return query
}
// CompileColumnListing returns the PRAGMA statement that lists the columns of
// table. The table name is inserted verbatim (neither quoted nor bound as a
// parameter), so it must be a trusted identifier. Only the table name is
// used; no database name appears in the statement.
func (this Sqlite3Grammar) CompileColumnListing(table string) string {
	return fmt.Sprintf("pragma table_info(%s)", table)
}
// CompileCreate compiles a CREATE TABLE statement for bp.table containing
// every column the blueprint adds. A blueprint marked temporary produces
// CREATE TEMPORARY TABLE instead.
func (this Sqlite3Grammar) CompileCreate(bp *Blueprint, command *Command, conn *db.Connection) string {
	create := "create"
	if bp.temporary {
		create = "create temporary"
	}
	// The table name is backtick-quoted like in every other statement of this
	// grammar, so reserved words (e.g. `order`) work as table names here too.
	return fmt.Sprintf("%s table `%s` (%s)", create, bp.table, strings.Join(this.getAddedColumns(bp), ", "))
}
// CompileAdd compiles the columns the blueprint adds into one
// "alter table `<table>` add column …" statement per column, in blueprint
// order. The "add column" prefix is attached by PrefixArray (defined
// elsewhere in the package).
func (this Sqlite3Grammar) CompileAdd(bp *Blueprint, command *Command, conn *db.Connection) []string {
	defs := PrefixArray("add column", this.getAddedColumns(bp))
	stmts := make([]string, 0, len(defs))
	for _, def := range defs {
		stmts = append(stmts, fmt.Sprintf("alter table `%s` %s", bp.table, def))
	}
	return stmts
}
// CompileChange compiles column changes (new type/modifiers and renames) for
// bp.table. The table is rebuilt rather than altered in place:
//  1. rename the existing table to `<table>_old`;
//  2. create the table again from its current definition, with the line of
//     every changed column replaced by its new definition;
//  3. copy all rows across, inserting renamed columns under their new names;
//  4. drop `<table>_old`.
//
// The statements are returned in execution order and are not wrapped in a
// transaction; indexes on the old table are not recreated. It panics if the
// table's current columns cannot be read.
func (this Sqlite3Grammar) CompileChange(bp *Blueprint, command *Command, conn *db.Connection) (sqls []string) {
	// Create-table statement reflecting the table's current columns.
	cTableSql := this.getCreateTable(bp, conn)

	// Current column names, each wrapped in backticks.
	oldColumns := []string{}
	if _, err := conn.Select("select '`'||name||'`' from pragma_table_info(?)", []any{bp.table}, &oldColumns); err != nil {
		panic(fmt.Errorf("获取 SQLite 建表字段失败, %w", err))
	}

	// renames maps a quoted old column name to its quoted new name. Exact
	// lookups are used so a column whose name merely contains the old name
	// (e.g. `user_id` when renaming `id`), or the table name, is left alone.
	renames := map[string]string{}
	for _, column := range bp.getChangedColumns() {
		name := column.Name
		rename := column.rename
		if rename == "" {
			rename = name
		} else {
			renames["`"+name+"`"] = "`" + rename + "`"
		}

		// New definition of the column, under its (possibly new) name.
		sql := this.addModifiers("`"+rename+"` "+this.GetColumnType(column), bp, column)

		// Swap the column's line in the create statement. QuoteMeta makes the
		// name match literally; "$" is doubled so ReplaceAllString does not
		// read it as a capture-group reference.
		reg := regexp.MustCompile(fmt.Sprintf("(?iUm)^(?P<prefix>\\s*)(?P<content>`%s`.*)(?P<end>(,\\s|\\s)?)$", regexp.QuoteMeta(name)))
		cTableSql = reg.ReplaceAllString(cTableSql, "${prefix}"+strings.ReplaceAll(sql, "$", "$$")+"${end}")
	}

	// Data copy: read the old names from `<table>_old`, write the new names.
	insertCols := make([]string, len(oldColumns))
	for i, col := range oldColumns {
		if to, ok := renames[col]; ok {
			insertCols[i] = to
		} else {
			insertCols[i] = col
		}
	}
	copySql := fmt.Sprintf("insert into `%s` (%s) select %s from `%s_old`",
		bp.table, strings.Join(insertCols, ", "), strings.Join(oldColumns, ", "), bp.table)

	sqls = append(sqls,
		// 1. Move the existing table out of the way.
		fmt.Sprintf("alter table `%s` rename to `%s_old`", bp.table, bp.table),
		// 2. Create the table with the new structure.
		cTableSql,
		// 3. Copy the data.
		copySql,
		// 4. Drop the old table.
		fmt.Sprintf("drop table `%s_old`", bp.table),
	)
	// @todo copy indexes from the old table.
	return sqls
}
// CompileUnique compiles a CREATE UNIQUE INDEX statement named command.Index
// on bp.table over command.Columns, each backtick-quoted, in the given order.
func (this Sqlite3Grammar) CompileUnique(bp *Blueprint, command *Command, conn *db.Connection) string {
	quoted := make([]string, len(command.Columns))
	for i, name := range command.Columns {
		quoted[i] = fmt.Sprintf("`%s`", name)
	}
	return "create unique index `" + command.Index + "` on `" + bp.table + "` (" + strings.Join(quoted, ", ") + ")"
}
// CompileIndex compiles a CREATE INDEX statement named command.Index on
// bp.table over command.Columns, each backtick-quoted, in the given order.
func (this Sqlite3Grammar) CompileIndex(bp *Blueprint, command *Command, conn *db.Connection) string {
	cols := make([]string, 0, len(command.Columns))
	for _, c := range command.Columns {
		cols = append(cols, "`"+c+"`")
	}
	columnList := strings.Join(cols, ", ")
	return fmt.Sprintf("create index `%s` on `%s` (%s)", command.Index, bp.table, columnList)
}
// CompileSpatialIndex always panics: spatial indexes are not supported for
// SQLite by this grammar.
func (this Sqlite3Grammar) CompileSpatialIndex(bp *Blueprint, command *Command, conn *db.Connection) string {
	const unsupported = "Sqlite 不支持 SpatialIndex"
	panic(unsupported)
}
// CompileDrop compiles an unconditional DROP TABLE statement for bp.table.
func (this Sqlite3Grammar) CompileDrop(bp *Blueprint, command *Command, conn *db.Connection) string {
	return fmt.Sprintf("drop table `%s`", bp.table)
}
// CompileDropIfExists compiles a DROP TABLE IF EXISTS statement for bp.table,
// which is a no-op when the table does not exist.
func (this Sqlite3Grammar) CompileDropIfExists(bp *Blueprint, command *Command, conn *db.Connection) string {
	return fmt.Sprintf("drop table if exists `%s`", bp.table)
}
// CompileDropColumn compiles the removal of command.Columns from bp.table.
// The table is rebuilt:
//  1. rename the existing table to `<table>_old`;
//  2. create the table again from its current definition without the dropped
//     columns, redefining any columns the blueprint also changes;
//  3. copy the remaining columns' data across;
//  4. drop `<table>_old`.
//
// The statements are returned in execution order and are not wrapped in a
// transaction; indexes on the old table are not recreated. It panics if the
// table's current columns cannot be read.
func (this Sqlite3Grammar) CompileDropColumn(bp *Blueprint, command *Command, conn *db.Connection) (sqls []string) {
	// Create-table statement for the current table, minus the dropped columns.
	cTableSql := this.getCreateTable(bp, conn, command.Columns...)

	// Remaining column names (each wrapped in backticks).
	placeholders := strings.TrimSuffix(strings.Repeat("?,", len(command.Columns)), ",")
	oldColumns := []string{}
	if _, err := conn.Select(
		"select '`'||name||'`' from pragma_table_info(?) where `name` not in("+placeholders+")",
		append([]any{bp.table}, lo.ToAnySlice(command.Columns)...),
		&oldColumns); err != nil {
		panic(fmt.Errorf("获取 SQLite 建表字段失败, %w", err))
	}

	// renames maps a quoted old column name to its quoted new name.
	renames := map[string]string{}
	for _, column := range bp.getChangedColumns() {
		name := column.Name
		rename := column.rename
		if rename == "" {
			// A column that is not renamed keeps its own name; without this
			// the new definition would start with an empty identifier.
			rename = name
		} else {
			renames["`"+name+"`"] = "`" + rename + "`"
		}

		// New definition of the column, under its (possibly new) name.
		sql := this.addModifiers("`"+rename+"` "+this.GetColumnType(column), bp, column)

		// Swap the column's line in the create statement. QuoteMeta makes the
		// name match literally; "$" is doubled so ReplaceAllString does not
		// read it as a capture-group reference.
		reg := regexp.MustCompile(fmt.Sprintf("(?iUm)^(?P<prefix>\\s*)(?P<content>`%s`.*)(?P<end>(,\\s|\\s)?)$", regexp.QuoteMeta(name)))
		cTableSql = reg.ReplaceAllString(cTableSql, "${prefix}"+strings.ReplaceAll(sql, "$", "$$")+"${end}")
	}

	// Data copy: `<table>_old` still has the old names, so select those and
	// insert them under the new names.
	insertCols := make([]string, len(oldColumns))
	for i, col := range oldColumns {
		if to, ok := renames[col]; ok {
			insertCols[i] = to
		} else {
			insertCols[i] = col
		}
	}
	copySql := fmt.Sprintf("insert into `%s` (%s) select %s from `%s_old`",
		bp.table, strings.Join(insertCols, ", "), strings.Join(oldColumns, ", "), bp.table)

	sqls = append(sqls,
		// 1. Move the existing table out of the way.
		fmt.Sprintf("alter table `%s` rename to `%s_old`", bp.table, bp.table),
		// 2. Create the table with the new structure.
		cTableSql,
		// 3. Copy the data.
		copySql,
		// 4. Drop the old table.
		fmt.Sprintf("drop table `%s_old`", bp.table),
	)
	// @todo copy indexes from the old table.
	return sqls
}
// CompileDropUnique compiles the removal of a unique index. SQLite drops
// unique and ordinary indexes with the same DROP INDEX statement, so this
// produces exactly what CompileDropIndex does.
func (this Sqlite3Grammar) CompileDropUnique(bp *Blueprint, command *Command, conn *db.Connection) string {
	return fmt.Sprintf("drop index `%s`", command.Index)
}
// CompileDropIndex compiles a DROP INDEX statement for the index named
// command.Index.
func (this Sqlite3Grammar) CompileDropIndex(bp *Blueprint, command *Command, conn *db.Connection) string {
	return fmt.Sprintf("drop index `%s`", command.Index)
}
// CompileDropSpatialIndex always panics: spatial indexes are not supported
// for SQLite, so there is nothing to drop. It raises the same message as
// CompileSpatialIndex.
func (this Sqlite3Grammar) CompileDropSpatialIndex(bp *Blueprint, command *Command, conn *db.Connection) string {
	panic("Sqlite 不支持 SpatialIndex")
}
// CompileRename compiles an ALTER TABLE … RENAME TO statement that renames
// bp.table to command.To.
func (this Sqlite3Grammar) CompileRename(bp *Blueprint, command *Command, conn *db.Connection) string {
	return fmt.Sprintf("alter table `%s` rename to `%s`", bp.table, command.To)
}
// CompileRenameIndex always panics.
//
// SQLite cannot rename an index in place: the index has to be dropped and
// then created again under the new name, and that needs the index definition,
// which this method does not have. It must not fall back to
// "alter table … rename to …", because that renames the table itself rather
// than the index. Like CompileSpatialIndex, this grammar rejects the
// operation explicitly instead.
func (this Sqlite3Grammar) CompileRenameIndex(bp *Blueprint, command *Command, conn *db.Connection) string {
	panic("Sqlite 不支持 RenameIndex")
}
// CompileDropAllTables returns a statement that deletes every table, index
// and trigger entry from sqlite_master.
//
// NOTE(review): SQLite normally refuses writes to sqlite_master unless
// PRAGMA writable_schema is enabled — confirm the caller turns it on first.
func (this Sqlite3Grammar) CompileDropAllTables() string {
	const stmt = "delete from sqlite_master where type in ('table', 'index', 'trigger')"
	return stmt
}
// CompileDropAllViews returns a statement that deletes every view entry from
// sqlite_master.
//
// NOTE(review): SQLite normally refuses writes to sqlite_master unless
// PRAGMA writable_schema is enabled — confirm the caller turns it on first.
func (this Sqlite3Grammar) CompileDropAllViews() string {
	const stmt = "delete from sqlite_master where type in ('view')"
	return stmt
}
// addModifiers appends to the column definition sql the fragment that
// GetColumnModifier renders for each modifier in sqlite3DefaultModifiers, in
// that order, and returns the result.
//
// It iterates the SQLite modifier list rather than the MySQL one. That way
// the SQLite DDL does not change whenever the MySQL grammar reorders or
// extends its own modifiers.
func (this Sqlite3Grammar) addModifiers(sql string, bp *Blueprint, column *ColumnDefinition) string {
	for _, modifier := range sqlite3DefaultModifiers {
		sql += this.GetColumnModifier(modifier, bp, column)
	}
	return sql
}
// GetColumnModifier renders the SQL fragment (with a leading space) for one
// named modifier of column. It returns "" when the modifier does not apply or
// is unknown.
//
// The modifiers are:
//   - VirtualAs: " as (expr)", a virtual generated column.
//   - StoredAs: " as (expr) stored", a stored generated column.
//   - Nullable: " not null" unless the column is nullable.
//   - Default: the default value, skipped when useCurrent is set.
//     db.Expression values are emitted raw; any other non-nil value is
//     emitted as a quoted string literal.
//   - Increment: " primary key autoincrement" for auto-incrementing integer
//     columns.
func (this Sqlite3Grammar) GetColumnModifier(modifier string, bp *Blueprint, column *ColumnDefinition) string {
	switch modifier {
	case "VirtualAs":
		if column.virtualAs != "" {
			return fmt.Sprintf(" as (%s)", column.virtualAs)
		}
	case "StoredAs":
		if column.storedAs != "" {
			return fmt.Sprintf(" as (%s) stored", column.storedAs)
		}
	case "Nullable":
		if !column.nullable {
			return " not null"
		}
	case "Default":
		// For useCurrent columns, GetColumnType already emits the
		// CURRENT_TIMESTAMP default.
		if column.useCurrent {
			return ""
		}
		switch v := column.def.(type) {
		case db.Expression:
			return fmt.Sprintf(" default %s", v)
		case nil:
		default:
			// Double embedded single quotes so the literal stays well-formed.
			return " default '" + strings.ReplaceAll(fmt.Sprint(v), "'", "''") + "'"
		}
	case "Increment":
		if column.autoIncrement && lo.Contains(sqlite3Serials, column.Type) {
			return " primary key autoincrement"
		}
	}
	return ""
}
// GetColumnType maps a blueprint column type to the SQLite type used in the
// column definition. It panics on an unsupported type.
//
// Auto-incrementing integer columns of any width are declared as "integer",
// because SQLite only accepts AUTOINCREMENT on a column declared exactly
// INTEGER PRIMARY KEY. The "Increment" modifier appends
// "primary key autoincrement" to such columns. For "datetime" and
// "timestamp" columns flagged useCurrent, " default CURRENT_TIMESTAMP" is
// appended to the type.
func (this Sqlite3Grammar) GetColumnType(column *ColumnDefinition) string {
	if column.autoIncrement && lo.Contains(sqlite3Serials, column.Type) {
		return "integer"
	}
	switch column.Type {
	case "char", "string":
		return fmt.Sprintf("varchar(%d)", column.Length)
	case "text":
		return "text"
	case "year", "integer":
		return "integer"
	case "bigInteger":
		return "bigint"
	case "mediumInteger":
		// Listed in sqlite3Serials, so it must be accepted here as well.
		return "mediumint"
	case "tinyInteger":
		return "tinyint"
	case "smallInteger":
		return "smallint"
	case "decimal":
		return fmt.Sprintf("numeric(%d,%d)", column.Total, column.Places)
	case "boolean":
		return "tinyint(1)"
	case "enum":
		return fmt.Sprintf(`varchar check ("%s" in (%s))`, column.Name, QuoteString(column.Allowed))
	case "json":
		return "json"
	case "date":
		return "date"
	case "binary":
		return "blob"
	case "datetime", "timestamp":
		columnType := column.Type
		if column.useCurrent {
			columnType = columnType + " default CURRENT_TIMESTAMP"
		}
		return columnType
	case "time":
		return "time"
	case "uuid":
		return "varchar(36)"
	}
	panic("不支持的数据类型: " + column.Type)
}
// getAddedColumns renders the full definition ("`name` type modifiers…") of
// every column the blueprint adds, in blueprint order. It returns nil when no
// columns are added.
func (this Sqlite3Grammar) getAddedColumns(bp *Blueprint) (columns []string) {
	for _, col := range bp.getAddedColumns() {
		base := fmt.Sprintf("`%s` %s", col.Name, this.GetColumnType(col))
		full := this.addModifiers(base, bp, col)
		columns = append(columns, full)
	}
	return columns
}
// getChangedColumns renders the new definition of every column the blueprint
// changes, keyed by the column's current name. Each definition starts with
// the column's current name, even when the column is also being renamed.
func (this Sqlite3Grammar) getChangedColumns(bp *Blueprint) map[string]string {
	defs := make(map[string]string)
	for _, col := range bp.getChangedColumns() {
		base := fmt.Sprintf("`%s` %s", col.Name, this.GetColumnType(col))
		defs[col.Name] = this.addModifiers(base, bp, col)
	}
	return defs
}
// getCreateTable rebuilds a CREATE TABLE statement for bp.table from
// pragma_table_info. Columns are taken in cid order, and any column whose
// name is listed in without is omitted.
//
// Only each column's name, declared type, NOT NULL flag, primary-key flag and
// default value are carried over. Column definitions are emitted one per line,
// each prefixed with a space and separated by ", \n". CompileChange and
// CompileDropColumn rewrite individual lines with a regular expression, so
// this layout must be kept. It panics if the column list cannot be read.
//
// NOTE(review): pragma_table_info reports pk as the column's 1-based position
// in the primary key, so composite keys may not map cleanly onto the bool Pk
// field — confirm before relying on this for multi-column keys.
func (this Sqlite3Grammar) getCreateTable(bp *Blueprint, conn *db.Connection, without ...string) (sql string) {
	dbColumns := []struct {
		Name    string        `db:"name"`
		Type    string        `db:"type"`
		Default db.NullString `db:"dflt_value"`
		NotNull bool          `db:"notnull"`
		Pk      bool          `db:"pk"`
	}{}
	if _, err := conn.Select("select * from pragma_table_info(?) order by cid", []any{bp.table}, &dbColumns); err != nil {
		panic(fmt.Errorf("获取 SQLite 建表字段失败, %w", err))
	}

	cols := make([]string, 0, len(dbColumns))
	for _, column := range dbColumns {
		if lo.Contains(without, column.Name) {
			continue
		}
		col := " `" + column.Name + "` " + column.Type
		if column.NotNull {
			col += " not null"
		}
		if column.Pk {
			col += " primary key"
			// SQLite rejects AUTOINCREMENT unless the column is declared
			// exactly INTEGER PRIMARY KEY, so other key types (e.g. varchar)
			// must not receive it.
			if strings.EqualFold(strings.TrimSpace(column.Type), "integer") {
				col += " autoincrement"
			}
		}
		if column.Default != "" {
			col += " default " + string(column.Default)
		}
		cols = append(cols, col)
	}
	return fmt.Sprintf("CREATE TABLE `%s` (\n%s\n)", bp.table, strings.Join(cols, ", \n"))
}