first commit

This commit is contained in:
2023-04-12 15:58:25 +08:00
commit 59b3faa171
30 changed files with 6257 additions and 0 deletions

456
schema/blueprint.go Normal file
View File

@@ -0,0 +1,456 @@
package schema
import (
"strings"
"github.com/samber/lo"
"git.fsdpf.net/go/db"
)
type Blueprint struct {
table string // the table the blueprint describes.
columns []*ColumnDefinition // columns that should be added to the table
commands []*Command //
temporary bool // Whether to make the table temporary.
charset string // The default character set that should be used for the table.
collation string // The collation that should be used for the table.
engine string // The engine that should be used for the table.
}
type Command struct {
Type string
CommandOptions
}
type CommandOptions struct {
Index string
Columns []string
Algorithm string
To string
From string
}
func NewBlueprint(table string) *Blueprint {
return &Blueprint{table: table, charset: "utf8mb4", collation: "utf8mb4_general_ci"}
}
// 字符串
func (this *Blueprint) Char(column string, length int) *ColumnDefinition {
if length == 0 {
length = 255
}
return this.addColumn("char", column, &ColumnOptions{Length: length})
}
// 可变长度字符串
func (this *Blueprint) String(column string, length int) *ColumnDefinition {
if length == 0 {
length = 255
}
return this.addColumn("string", column, &ColumnOptions{Length: length})
}
// 文本
func (this *Blueprint) Text(column string) *ColumnDefinition {
return this.addColumn("text", column, nil)
}
// 整型
func (this *Blueprint) Integer(column string, params ...bool) *ColumnDefinition {
autoIncrement := false
unsigned := false
if len(params) > 0 {
autoIncrement = params[0]
}
if len(params) > 1 {
unsigned = params[1]
}
return this.addColumn("integer", column, &ColumnOptions{autoIncrement: autoIncrement, unsigned: unsigned})
}
// 迷你整型 1 byte
func (this *Blueprint) TinyInteger(column string, params ...bool) *ColumnDefinition {
autoIncrement := false
unsigned := false
if len(params) > 0 {
autoIncrement = params[0]
}
if len(params) > 1 {
unsigned = params[1]
}
return this.addColumn("tinyInteger", column, &ColumnOptions{autoIncrement: autoIncrement, unsigned: unsigned})
}
// 小整型 2 byte
func (this *Blueprint) SmallInteger(column string, params ...bool) *ColumnDefinition {
autoIncrement := false
unsigned := false
if len(params) > 0 {
autoIncrement = params[0]
}
if len(params) > 1 {
unsigned = params[1]
}
return this.addColumn("smallInteger", column, &ColumnOptions{autoIncrement: autoIncrement, unsigned: unsigned})
}
// 大整型 2 byte
func (this *Blueprint) BigInteger(column string, params ...bool) *ColumnDefinition {
autoIncrement := false
unsigned := false
if len(params) > 0 {
autoIncrement = params[0]
}
if len(params) > 1 {
unsigned = params[1]
}
return this.addColumn("bigInteger", column, &ColumnOptions{autoIncrement: autoIncrement, unsigned: unsigned})
}
// 无符号整型
func (this *Blueprint) UnsignedInteger(column string, params ...bool) *ColumnDefinition {
autoIncrement := false
if len(params) > 0 {
autoIncrement = params[0]
}
return this.Integer(column, autoIncrement, true)
}
// 无符号迷你整型 1 byte
func (this *Blueprint) UnsignedTinyInteger(column string, params ...bool) *ColumnDefinition {
autoIncrement := false
if len(params) > 0 {
autoIncrement = params[0]
}
return this.TinyInteger(column, autoIncrement, true)
}
// 无符号小整型 2 byte
func (this *Blueprint) UnsignedSmallInteger(column string, params ...bool) *ColumnDefinition {
autoIncrement := false
if len(params) > 0 {
autoIncrement = params[0]
}
return this.SmallInteger(column, autoIncrement, true)
}
// 无符号大整型
func (this *Blueprint) UnsignedBigInteger(column string, params ...bool) *ColumnDefinition {
autoIncrement := false
if len(params) > 0 {
autoIncrement = params[0]
}
return this.BigInteger(column, autoIncrement, true)
}
// 精确小数
func (this *Blueprint) Decimal(column string, total, places int) *ColumnDefinition {
if total == 0 {
total = 8
}
if places == 0 {
places = 2
}
return this.addColumn("decimal", column, &ColumnOptions{Total: total, Places: places})
}
// 无符号精确销售
func (this *Blueprint) UnsignedDecimal(column string, total, places int) *ColumnDefinition {
if total == 0 {
total = 8
}
if places == 0 {
places = 2
}
return this.addColumn("decimal", column, &ColumnOptions{Total: total, Places: places, unsigned: true})
}
// 布尔值
func (this *Blueprint) Boolean(column string) *ColumnDefinition {
return this.addColumn("boolean", column, nil)
}
// 枚举类型
func (this *Blueprint) Enum(column string, allowed []string) *ColumnDefinition {
return this.addColumn("enum", column, &ColumnOptions{Allowed: allowed})
}
// JSON
func (this *Blueprint) Json(column string) *ColumnDefinition {
return this.addColumn("json", column, nil)
}
// 日期类型
func (this *Blueprint) Date(column string) *ColumnDefinition {
return this.addColumn("date", column, nil)
}
// 日期时间类型
func (this *Blueprint) DateTime(column string, precision ...int) *ColumnDefinition {
if len(precision) > 0 {
return this.addColumn("datetime", column, &ColumnOptions{Precision: precision[0]})
}
return this.addColumn("datetime", column, nil)
}
// 时间类型
func (this *Blueprint) Time(column string, precision ...int) *ColumnDefinition {
if len(precision) > 0 {
return this.addColumn("time", column, &ColumnOptions{Precision: precision[0]})
}
return this.addColumn("time", column, nil)
}
// 时间戳
func (this *Blueprint) Timestamp(column string, precision ...int) *ColumnDefinition {
if len(precision) > 0 {
return this.addColumn("timestamp", column, &ColumnOptions{Precision: precision[0]})
}
return this.addColumn("timestamp", column, nil)
}
// 年
func (this *Blueprint) Year(column string) *ColumnDefinition {
return this.addColumn("year", column, nil)
}
// 二进制数据
func (this *Blueprint) Binary(column string) *ColumnDefinition {
return this.addColumn("binary", column, nil)
}
// UUID
func (this *Blueprint) Uuid(column string) *ColumnDefinition {
return this.addColumn("uuid", column, nil)
}
// 自增字段
func (this *Blueprint) Increments(column string) *ColumnDefinition {
return this.UnsignedInteger(column, true)
}
// 自增Big字段
func (this *Blueprint) BigIncrements(column string) *ColumnDefinition {
return this.UnsignedBigInteger(column, true)
}
// 添加主键
func (this *Blueprint) Primary(columns ...string) *Command {
return this.addCommand("primary", CommandOptions{Index: this.generateIndexName("pk", columns), Columns: columns})
}
// 唯一键
func (this *Blueprint) Unique(columns ...string) *Command {
return this.addCommand("unique", CommandOptions{Index: this.generateIndexName("unique", columns), Columns: columns})
}
// 普通索引
func (this *Blueprint) Index(columns ...string) *Command {
return this.addCommand("index", CommandOptions{Index: this.generateIndexName("index", columns), Columns: columns})
}
// 空间索引
func (this *Blueprint) SpatialIndex(columns ...string) *Command {
return this.addCommand("spatialIndex", CommandOptions{Index: this.generateIndexName("spatial_index", columns), Columns: columns})
}
// 删除列
func (this *Blueprint) DropColumn(columns ...string) *Command {
return this.addCommand("dropColumn", CommandOptions{Columns: columns})
}
// 创建表
func (this *Blueprint) Create() *Command {
return this.addCommand("create", CommandOptions{})
}
// 设置临时表标记
func (this *Blueprint) Temporary() {
this.temporary = true
}
// 设置表字符集
func (this *Blueprint) Charset(charset string) {
this.charset = charset
}
// 修改表名
func (this *Blueprint) Rename(to string) *Command {
return this.addCommand("rename", CommandOptions{To: to})
}
// 删除表
func (this *Blueprint) Drop() *Command {
return this.addCommand("drop", CommandOptions{})
}
// 删除表, 先判断再删除
func (this *Blueprint) DropIfExists() *Command {
return this.addCommand("dropIfExists", CommandOptions{})
}
// 执行SQL查询
func (this *Blueprint) Build(conn *db.Connection, grammar IGrammar) error {
_, err := conn.Transaction(func(tx *db.Transaction) (result any, err error) {
for _, sql := range this.ToSql(conn, grammar) {
if _, err := tx.Statement(sql, []any{}); err != nil {
tx.Rollback()
return nil, err
}
}
tx.Commit()
return
})
return err
}
func (this *Blueprint) ToSql(conn *db.Connection, grammar IGrammar) (statements []string) {
this.addImpliedCommands(grammar)
for _, cmd := range this.commands {
switch cmd.Type {
case "create":
statements = append(statements, grammar.CompileCreate(this, cmd, conn))
case "add":
statements = append(statements, grammar.CompileAdd(this, cmd, conn)...)
case "change":
statements = append(statements, grammar.CompileChange(this, cmd, conn)...)
case "primary":
// statements = append(statements, grammar.CompilePrimary(this, cmd, conn))
case "unique":
statements = append(statements, grammar.CompileUnique(this, cmd, conn))
case "index":
statements = append(statements, grammar.CompileIndex(this, cmd, conn))
case "spatialIndex":
statements = append(statements, grammar.CompileSpatialIndex(this, cmd, conn))
case "drop":
statements = append(statements, grammar.CompileDrop(this, cmd, conn))
case "dropIfExists":
statements = append(statements, grammar.CompileDropIfExists(this, cmd, conn))
case "dropColumn":
statements = append(statements, grammar.CompileDropColumn(this, cmd, conn)...)
case "dropPrimary":
// statements = append(statements, grammar.CompileDropPrimary(this, cmd, conn))
case "dropUnique":
statements = append(statements, grammar.CompileDropUnique(this, cmd, conn))
case "dropIndex":
statements = append(statements, grammar.CompileDropIndex(this, cmd, conn))
case "dropSpatialIndex":
statements = append(statements, grammar.CompileDropSpatialIndex(this, cmd, conn))
case "rename":
statements = append(statements, grammar.CompileRename(this, cmd, conn))
case "renameIndex":
statements = append(statements, grammar.CompileRenameIndex(this, cmd, conn))
case "dropAllTables":
case "dropAllViews":
case "getAllTables":
case "getAllViews":
}
}
return statements
}
// 判断是否是创建表
func (this *Blueprint) creating() bool {
return lo.SomeBy(this.commands, func(item *Command) bool {
if item.Type == "create" {
return true
}
return false
})
}
func (this *Blueprint) addColumn(typ, name string, options *ColumnOptions) (definition *ColumnDefinition) {
definition = &ColumnDefinition{Type: typ, Name: name}
if options != nil {
if options.Length > 0 {
definition.Length = options.Length
}
if options.autoIncrement {
definition.autoIncrement = true
}
if options.unsigned {
definition.unsigned = true
}
if options.Total > 0 {
definition.Total = options.Total
}
if options.Places > 0 {
definition.Places = options.Places
}
if len(options.Allowed) > 0 {
definition.Allowed = options.Allowed
}
if options.Precision > 0 {
definition.Precision = options.Precision
}
if options.change {
definition.change = true
}
}
this.columns = append(this.columns, definition)
return definition
}
func (this *Blueprint) addImpliedCommands(grammar IGrammar) {
if !this.creating() {
if len(this.getAddedColumns()) > 0 {
this.commands = append([]*Command{this.createCommand("add", CommandOptions{})}, this.commands...)
}
if len(this.getChangedColumns()) > 0 {
this.commands = append([]*Command{this.createCommand("change", CommandOptions{})}, this.commands...)
}
}
this.addFluentIndexes()
}
// 添加索引字段
func (this *Blueprint) addFluentIndexes() {
for _, column := range this.columns {
if column.primary {
this.Primary(column.Name)
continue
} else if column.unique {
this.Unique(column.Name)
continue
} else if column.index {
this.Index(column.Name)
continue
} else if column.spatialIndex {
this.SpatialIndex(column.Name)
continue
}
}
}
func (this *Blueprint) addCommand(name string, options CommandOptions) (command *Command) {
command = this.createCommand(name, options)
this.commands = append(this.commands, command)
return command
}
func (this *Blueprint) createCommand(name string, options CommandOptions) *Command {
return &Command{Type: name, CommandOptions: options}
}
// 生成索引名称
func (this *Blueprint) generateIndexName(typ string, columns []string) string {
return strings.ToLower(typ + "_" + strings.Join(columns, "_"))
}
func (this *Blueprint) getAddedColumns() []*ColumnDefinition {
return lo.Filter(this.columns, func(item *ColumnDefinition, _ int) bool {
return !item.change
})
}
func (this *Blueprint) getChangedColumns() []*ColumnDefinition {
return lo.Filter(this.columns, func(item *ColumnDefinition, _ int) bool {
return item.change
})
}

216
schema/blueprint_test.go Normal file
View File

@@ -0,0 +1,216 @@
package schema
import (
"git.fsdpf.net/go/db"
"testing"
)
// func TestMysqlToSql(t *testing.T) {
// bp := NewBlueprint("users")
// // bp.BigIncrements("id") // 等同于自增 UNSIGNED BIGINT主键
// // bp.BigInteger("votes").Default("0") // 等同于 BIGINT 类型列
// // bp.Binary("data") // 等同于 BLOB 类型列
// // bp.Boolean("confirmed") // 等同于 BOOLEAN 类型列
// // bp.Char("name", 4) // 等同于 CHAR 类型列
// // bp.Date("created_at") // 等同于 DATE 类型列
// // bp.DateTime("created_at") // 等同于 DATETIME 类型列
// // bp.Decimal("amount", 5, 2) // 等同于 DECIMAL 类型列,带精度和范围
// // bp.Enum("level", []string{"easy", "hard"}) // 等同于 ENUM 类型列
// bp.Increments("id").Change("pid") // 等同于自增 UNSIGNED INTEGER (主键)类型列
// // bp.Integer("votes").AutoIncrement() // 等同于 INTEGER 类型列
// bp.Json("options").Charset("uft8").Change() // 等同于 JSON 类型列
// // bp.SmallInteger("votes").Comment("等同于 SMALLINT 类型列") // 等同于 SMALLINT 类型列
// // bp.String("name", 100).Nullable() // 等同于 VARCHAR 类型列,带一个可选长度参数
// // bp.Text("description").After("column") // 等同于 TEXT 类型列
// // bp.Time("sunrise").First() // 等同于 TIME 类型列
// bp.Timestamp("added_on").UseCurrent().Default(db.Raw("ON UPDATE CURRENT_TIMESTAMP")) // 等同于 TIMESTAMP 类型列
// // bp.TinyInteger("numbers") // 等同于 TINYINT 类型列
// // bp.UnsignedBigInteger("votes") // 等同于无符号的 BIGINT 类型列
// // bp.UnsignedDecimal("amount", 8, 2) // 等同于 UNSIGNED DECIMAL 类型列,带有总位数和精度
// // bp.UnsignedInteger("votes") // 等同于无符号的 INTEGER 类型列
// // bp.UnsignedSmallInteger("votes") // 等同于无符号的 SMALLINT 类型列
// // bp.UnsignedTinyInteger("votes") // 等同于无符号的 TINYINT 类型列
// // bp.Uuid("id").Default("00000000-0000-0000-0000-000000000000") // 等同于 UUID 类型列
// // bp.Year("birth_year")
// bp.DropColumn("name", "name_a") // 等同于 YEAR 类型列
// t.Log(bp.ToSql(mysql.Connection, &MysqlGrammar{}))
// }
// func TestSqliteToSql(t *testing.T) {
// bp := NewBlueprint("users")
// bp.charset = "utf8mb4"
// bp.collation = "utf8mb4_general_ci"
// // bp.BigIncrements("id").Comment("ID")
// // bp.Boolean("enabled").Default("1").Comment("是否有效")
// bp.Char("created_user", 36).Default("22000000-0000-0000-0000-000000000001").Comment("创建者").Change()
// bp.Char("owned_user", 36).Default("11000000-0000-0000-0000-000000000000").Comment("拥有者").Change()
// // bp.Timestamp("created_at").UseCurrent().Comment("创建时间")
// // bp.Timestamp("updated_at").Default(db.Raw("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP")).Comment("更新时间")
// bp.DateTime("deleted_at").Nullable().Comment("删除时间")
// bp.String("name_a", 50).Default("").Comment("姓名").Change("c_name")
// bp.DropColumn("name", "name_a")
// t.Log(bp.ToSql(sqlite3.Connection, &Sqlite3Grammar{}))
// }
// Mysql 创建表
func TestMysqlCreateTableToSql(t *testing.T) {
table := NewBlueprint("users")
table.Create()
table.Charset("utf8mb4")
table.BigIncrements("id").Comment("ID")
table.Boolean("enabled").Default("1").Comment("是否有效")
table.Char("created_user", 36).Default("00000000-0000-0000-0000-000000000000").Comment("创建者")
table.Char("owned_user", 36).Default("00000000-0000-0000-0000-000000000000").Comment("拥有者")
table.Timestamp("created_at").UseCurrent().Comment("创建时间")
table.Timestamp("updated_at").UseCurrent().Default(db.Raw("ON UPDATE CURRENT_TIMESTAMP")).Comment("更新时间")
table.DateTime("deleted_at").Nullable().Comment("删除时间")
for _, sql := range table.ToSql(mysql.Connection, &MysqlGrammar{}) {
t.Logf("-- %s -- %s \n", "mysql", sql)
}
}
// SQLite 创建表
func TestSqliteCreateTableToSql(t *testing.T) {
table := NewBlueprint("users")
table.Create()
table.Charset("utf8mb4")
table.BigIncrements("id").Comment("ID")
table.Boolean("enabled").Default("1").Comment("是否有效")
table.Char("created_user", 36).Default("00000000-0000-0000-0000-000000000000").Comment("创建者")
table.Char("owned_user", 36).Default("00000000-0000-0000-0000-000000000000").Comment("拥有者")
table.Timestamp("created_at").UseCurrent().Comment("创建时间")
table.Timestamp("updated_at").UseCurrent().Default(db.Raw("ON UPDATE CURRENT_TIMESTAMP")).Comment("更新时间")
table.DateTime("deleted_at").Nullable().Comment("删除时间")
for _, sql := range table.ToSql(sqlite3.Connection, &Sqlite3Grammar{}) {
t.Logf("-- %s -- %s \n", "sqlite", sql)
}
}
//Mysql 修改表, 添加字段
func TestMysqlAddColumnToSql(t *testing.T) {
table := NewBlueprint(test_table)
table.String("name", 50).Nullable().Comment("用户名")
table.TinyInteger("age").Nullable().Comment("年龄")
for _, sql := range table.ToSql(mysql.Connection, &MysqlGrammar{}) {
t.Logf("-- %s -- %s \n", "mysql", sql)
}
}
// Sqlite 修改表, 添加字段
func TestSqliteAddColumnToSql(t *testing.T) {
table := NewBlueprint(test_table)
table.String("name", 50).Nullable().Comment("用户名")
table.TinyInteger("age").Nullable().Comment("年龄")
for _, sql := range table.ToSql(sqlite3.Connection, &Sqlite3Grammar{}) {
t.Logf("-- %s -- %s \n", "sqlite", sql)
}
}
// Mysql 修改表, 添加/编辑字段
func TestMysqlChangeColumnToSql(t *testing.T) {
table := NewBlueprint(test_table)
table.String("name", 100).Nullable().Comment("用户名").Change()
table.SmallInteger("age").Nullable().Comment("年龄").Change()
table.String("nickname", 100).Nullable().Comment("昵称")
for _, sql := range table.ToSql(mysql.Connection, &MysqlGrammar{}) {
t.Logf("-- %s -- %s \n", "mysql", sql)
}
}
// Sqlite 修改表, 添加/编辑字段
func TestSqliteChangeColumnToSql(t *testing.T) {
table := NewBlueprint(test_table)
table.String("name", 100).Nullable().Comment("用户名").Change()
table.SmallInteger("age").Nullable().Comment("年龄").Change()
table.String("nickname", 100).Nullable().Comment("昵称")
for _, sql := range table.ToSql(sqlite3.Connection, &Sqlite3Grammar{}) {
t.Logf("-- %s -- %s \n", "sqlite", sql)
}
}
// Mysql 修改表, 添加/编辑/删除字段
func TestMysqlDropColumnToSql(t *testing.T) {
table := NewBlueprint(test_table)
table.String("username", 100).Default("").Comment("--用户名--").Change("name")
table.Boolean("is_vip").Default(1).Comment("是否VIP")
table.DropColumn("age")
for _, sql := range table.ToSql(mysql.Connection, &MysqlGrammar{}) {
t.Logf("-- %s -- %s \n", "mysql", sql)
}
}
// Sqlite 修改表, 添加/编辑/删除字段
func TestSqliteDropColumnToSql(t *testing.T) {
table := NewBlueprint(test_table)
table.String("username", 100).Default("").Comment("--用户名--").Change("name")
table.Boolean("is_vip").Default(1).Comment("是否VIP")
table.DropColumn("age")
for _, sql := range table.ToSql(sqlite3.Connection, &Sqlite3Grammar{}) {
t.Logf("-- %s -- %s \n", "sqlite", sql)
}
}
// Mysql 重命名表
func TestMysqlRenameToSql(t *testing.T) {
table := NewBlueprint(test_table)
table.Rename(test_table + "_alias")
for _, sql := range table.ToSql(mysql.Connection, &MysqlGrammar{}) {
t.Logf("-- %s -- %s \n", "mysql", sql)
}
}
// Sqlite 重命名表
func TestSqliteRenameToSql(t *testing.T) {
table := NewBlueprint(test_table)
table.Rename(test_table + "_alias")
for _, sql := range table.ToSql(sqlite3.Connection, &Sqlite3Grammar{}) {
t.Logf("-- %s -- %s \n", "sqlite", sql)
}
}
// Mysql 删除表
func TestMysqlDropTableToSql(t *testing.T) {
table := NewBlueprint(test_table + "_alias")
table.Drop()
for _, sql := range table.ToSql(mysql.Connection, &MysqlGrammar{}) {
t.Logf("-- %s -- %s \n", "mysql", sql)
}
}
// Sqlite 删除表
func TestSqliteDropTableToSql(t *testing.T) {
table := NewBlueprint(test_table + "_alias")
table.Drop()
for _, sql := range table.ToSql(sqlite3.Connection, &Sqlite3Grammar{}) {
t.Logf("-- %s -- %s \n", "sqlite", sql)
}
}

128
schema/builder.go Normal file
View File

@@ -0,0 +1,128 @@
package schema
import (
"github.com/samber/lo"
"git.fsdpf.net/go/db"
)
type Builder struct {
Connection *db.Connection // The database connection instance
Grammar IGrammar // The schema grammar instance
}
func NewSchema(conn *db.Connection) *Builder {
var grammar IGrammar
switch conn.GetConfig().Driver {
case "mysql":
grammar = &MysqlGrammar{}
case "sqlite3":
grammar = &Sqlite3Grammar{}
}
return &Builder{
Connection: conn,
Grammar: grammar,
}
}
// 判断数据表是否存在
func (this *Builder) HasTable(table string) (bool, error) {
query := this.Grammar.CompileTableExists()
dbName := this.Connection.GetConfig().Database
count := 0
if _, err := this.Connection.Select(query, []any{table, dbName}, &count); err != nil || count == 0 {
return false, err
}
return true, nil
}
// 判断表字段是个存在
func (this *Builder) HasColumns(table string, columns ...string) (bool, error) {
if tColumns, err := this.GetColumnListing(table); err != nil {
return false, err
} else {
for _, col := range columns {
if !lo.Contains(tColumns, col) {
return false, nil
}
}
}
return true, nil
}
// 判断数据库列是否存在
func (this *Builder) GetColumnListing(table string) (columns []string, err error) {
query := this.Grammar.CompileColumnListing(table)
dbName := this.Connection.GetConfig().Database
items := []struct {
ID int `db:"id"`
Name string `db:"name"`
}{}
bindings := []any{table, dbName}
if this.Connection.GetConfig().Driver == "sqlite3" {
bindings = nil
}
if _, err := this.Connection.Select(query, bindings, &items); err != nil {
return nil, err
}
for _, item := range items {
columns = append(columns, item.Name)
}
return
}
// 修改表
func (this *Builder) Table(table string, cb func(*Blueprint)) error {
bp := this.createBlueprint(table)
cb(bp)
return this.Build(bp)
}
// 创建表
func (this *Builder) Create(table string, cb func(*Blueprint)) error {
bp := this.createBlueprint(table)
bp.Create()
cb(bp)
return this.Build(bp)
}
// 修改表名
func (this *Builder) Rename(from, to string) error {
bp := this.createBlueprint(from)
bp.Rename(to)
return this.Build(bp)
}
// 删除表
func (this *Builder) Drop(table string) error {
bp := this.createBlueprint(table)
bp.Drop()
return this.Build(bp)
}
// 删除表, 先判断再删除
func (this *Builder) DropIfExists(table string) error {
bp := this.createBlueprint(table)
bp.DropIfExists()
return this.Build(bp)
}
func (b *Builder) createBlueprint(table string) *Blueprint {
return NewBlueprint(table)
}
// Build execute the blueprint to build / modify the table
func (this *Builder) Build(bp *Blueprint) error {
return bp.Build(this.Connection, this.Grammar)
}

195
schema/builder_test.go Normal file
View File

@@ -0,0 +1,195 @@
package schema
import (
"git.fsdpf.net/go/db"
"testing"
)
const test_table = "test_users"
var mysql *Builder
var sqlite3 *Builder
func init() {
odbc := db.Open(map[string]db.DBConfig{
"default": {
Driver: "mysql",
Host: "localhost",
Port: "3366",
Database: "demo",
Username: "demo",
Password: "ded86bf25d661bb723f3898b2440dd678382e2dd",
Charset: "utf8mb4",
MultiStatements: true,
// ParseTime: true,
},
"sqlite3": {
Driver: "sqlite3",
File: "test.db3",
},
})
mysql = NewSchema(odbc.Connection("default"))
sqlite3 = NewSchema(odbc.Connection("sqlite3"))
}
func TestHasTable(t *testing.T) {
t.Log(mysql.HasTable("users"))
t.Log(sqlite3.HasTable("users"))
}
func TestHasColumns(t *testing.T) {
t.Log("mysql")
t.Log(mysql.HasColumns("users", "id"))
t.Log("sqlite3")
t.Log(sqlite3.HasColumns("users", "id"))
}
func TestGetColumnListing(t *testing.T) {
t.Log("mysql")
t.Log(mysql.GetColumnListing("users"))
t.Log("sqlite3")
t.Log(sqlite3.GetColumnListing("users"))
}
// Mysql 创建表
func TestMysqlCreateTable(t *testing.T) {
if err := mysql.Create(test_table, func(table *Blueprint) {
table.Charset("utf8mb4")
table.BigIncrements("id").Comment("ID")
table.Boolean("enabled").Default("1").Comment("是否有效")
table.Char("created_user", 36).Default("00000000-0000-0000-0000-000000000000").Comment("创建者")
table.Char("owned_user", 36).Default("00000000-0000-0000-0000-000000000000").Comment("拥有者")
table.Timestamp("created_at").UseCurrent().Comment("创建时间")
table.Timestamp("updated_at").UseCurrent().Default(db.Raw("ON UPDATE CURRENT_TIMESTAMP")).Comment("更新时间")
table.DateTime("deleted_at").Nullable().Comment("删除时间")
}); err != nil {
t.Error(err)
}
}
// Sqlite 创建表
func TestSqliteCreateTable(t *testing.T) {
if err := sqlite3.Create(test_table, func(table *Blueprint) {
table.Charset("utf8mb4")
table.Increments("id").Comment("ID")
table.Boolean("enabled").Default("1").Comment("是否有效")
table.Char("created_user", 36).Default("00000000-0000-0000-0000-000000000000").Comment("创建者")
table.Char("owned_user", 36).Default("00000000-0000-0000-0000-000000000000").Comment("拥有者")
table.Timestamp("created_at").UseCurrent().Comment("创建时间")
table.Timestamp("updated_at").UseCurrent().Default(db.Raw("ON UPDATE CURRENT_TIMESTAMP")).Comment("更新时间")
table.DateTime("deleted_at").Nullable().Comment("删除时间")
}); err != nil {
t.Error(err)
}
}
// Mysql 修改表, 添加字段
func TestMysqlAddColumn(t *testing.T) {
if err := mysql.Table(test_table, func(table *Blueprint) {
table.String("name", 50).Nullable().Comment("用户名")
table.TinyInteger("age").Nullable().Comment("年龄")
}); err != nil {
t.Error(err)
}
}
// Sqlite 修改表, 添加字段
func TestSqliteAddColumn(t *testing.T) {
if err := sqlite3.Table(test_table, func(table *Blueprint) {
table.String("name", 50).Nullable().Comment("用户名")
table.TinyInteger("age").Nullable().Comment("年龄")
}); err != nil {
t.Error(err)
}
}
// Mysql 修改表, 添加/编辑字段
func TestMysqlChangeColumn(t *testing.T) {
if err := mysql.Table(test_table, func(table *Blueprint) {
table.String("name", 100).Nullable().Comment("用户名").Change("username")
table.SmallInteger("age").Nullable().Comment("年龄").Change()
table.String("nickname", 100).Nullable().Comment("昵称")
}); err != nil {
t.Error(err)
}
}
// Sqlite 修改表, 添加/编辑字段
func TestSqliteChangeColumn(t *testing.T) {
if err := sqlite3.Table(test_table, func(table *Blueprint) {
table.String("name", 100).Nullable().Comment("用户名").Change("username")
table.SmallInteger("age").Nullable().Comment("年龄").Change()
table.String("nickname", 100).Nullable().Comment("昵称")
}); err != nil {
t.Error(err)
}
}
// Mysql 修改表, 添加/编辑/删除字段
func TestMysqlDropColumn(t *testing.T) {
if err := mysql.Table(test_table, func(table *Blueprint) {
table.String("username", 100).Default("").Comment("--用户名--").Change("name")
table.Boolean("is_vip").Default(1).Comment("是否VIP")
table.DropColumn("age")
}); err != nil {
t.Error(err)
}
}
// Sqlite 修改表, 添加/编辑/删除字段
func TestSqliteDropColumn(t *testing.T) {
if err := sqlite3.Table(test_table, func(table *Blueprint) {
table.String("username", 100).Default("").Comment("--用户名--").Change("name")
table.Boolean("is_vip").Default(1).Comment("是否VIP")
table.DropColumn("age")
}); err != nil {
t.Error(err)
}
}
// Mysql重命名表
func TestMysqlRename(t *testing.T) {
if err := mysql.Rename(test_table, test_table+"_alias"); err != nil {
t.Error(err)
}
}
// SQLite 重命名表
func TestSqliteRename(t *testing.T) {
if err := sqlite3.Rename(test_table, test_table+"_alias"); err != nil {
t.Error(err)
}
}
// Mysql 删除表
func TestMysqlDropTable(t *testing.T) {
if err := mysql.Drop(test_table + "_alias"); err != nil {
t.Error(err)
}
}
// Sqlite 删除表
func TestSqliteDropTable(t *testing.T) {
if err := sqlite3.Drop(test_table + "_alias"); err != nil {
t.Error(err)
}
}
// Mysql 删除表
func TestMysqlDropTableIfExists(t *testing.T) {
if err := mysql.DropIfExists(test_table + "_alias"); err != nil {
t.Error(err)
}
}
// Sqlite 删除表
func TestSqliteDropTableIfExists(t *testing.T) {
if err := sqlite3.DropIfExists(test_table + "_alias"); err != nil {
t.Error(err)
}
}

115
schema/column_definition.go Normal file
View File

@@ -0,0 +1,115 @@
package schema
type ColumnDefinition struct {
Type string // kind of column, string / int
Name string // name of column
ColumnOptions
}
type ColumnOptions struct {
Length int // 字段长度
Allowed []string // enum 选项
Precision int // 日期时间精度
Total int // 小数位数
Places int // 小数精度
rename string // 重命名
useCurrent bool // CURRENT TIMESTAMP
after string // Place the column "after" another column (MySQL)
always bool // Used as a modifier for generatedAs() (PostgreSQL)
autoIncrement bool // Set INTEGER columns as auto-increment (primary key)
change bool // Change the column
charset string // Specify a character set for the column (MySQL)
collation string // Specify a collation for the column (MySQL/PostgreSQL/SQL Server)
comment string // Add a comment to the column (MySQL)
def any // Specify a "default" value for the column
first bool // Place the column "first" in the table (MySQL)
nullable bool // Allow NULL values to be inserted into the column
storedAs string // Create a stored generated column (MySQL)
unsigned bool // Set the INTEGER column as UNSIGNED (MySQL)
virtualAs string // Create a virtual generated column (MySQL)
unique bool // Add a unique index
primary bool // Add a primary index
index bool // Add an index
spatialIndex bool // Add a spatial index
}
// VirtualAs Create a virtual generated column (MySQL)
func (c *ColumnDefinition) VirtualAs(as string) *ColumnDefinition {
c.virtualAs = as
return c
}
// StoredAs Create a stored generated column (MySQL)
func (c *ColumnDefinition) StoredAs(as string) *ColumnDefinition {
c.storedAs = as
return c
}
// Unsigned Set INTEGER columns as UNSIGNED (MySQL)
func (c *ColumnDefinition) Unsigned() *ColumnDefinition {
c.unsigned = true
return c
}
// First Place the column "first" in the table (MySQL)
func (c *ColumnDefinition) First() *ColumnDefinition {
c.first = true
return c
}
// Default Specify a "default" value for the column
func (c *ColumnDefinition) Default(def any) *ColumnDefinition {
c.def = def
return c
}
// Comment Add a comment to a column (MySQL/PostgreSQL)
func (c *ColumnDefinition) Comment(comm string) *ColumnDefinition {
c.comment = comm
return c
}
// Collaction Specify a collation for the column (MySQL/PostgreSQL/SQL Server)
func (c *ColumnDefinition) Collaction(coll string) *ColumnDefinition {
c.collation = coll
return c
}
// Charset Specify a character set for the column (MySQL)
func (c *ColumnDefinition) Charset(chars string) *ColumnDefinition {
c.charset = chars
return c
}
// AutoIncrement set INTEGER columns as auto-increment (primary key)
func (c *ColumnDefinition) AutoIncrement() *ColumnDefinition {
c.autoIncrement = true
return c
}
// After place the column "after" another column (MySQL)
func (c *ColumnDefinition) After(column string) *ColumnDefinition {
c.after = column
return c
}
// Nullable makes a column nullable
func (c *ColumnDefinition) Nullable() *ColumnDefinition {
c.nullable = true
return c
}
// 修改字段, 默认新增
func (c *ColumnDefinition) Change(param ...string) *ColumnDefinition {
if len(param) > 0 && param[0] != "" {
c.rename = param[0]
}
c.change = true
return c
}
// 时间戳
func (c *ColumnDefinition) UseCurrent() *ColumnDefinition {
c.useCurrent = true
return c
}

44
schema/grammar.go Normal file
View File

@@ -0,0 +1,44 @@
package schema
import (
"git.fsdpf.net/go/db"
)
type IGrammar interface {
// 判断表是否存在
CompileTableExists() string
// 字段列
CompileColumnListing(table string) string
// 创建表
CompileCreate(bp *Blueprint, command *Command, conn *db.Connection) string
// 添加字段
CompileAdd(bp *Blueprint, command *Command, conn *db.Connection) []string
// 修改字段
CompileChange(bp *Blueprint, command *Command, conn *db.Connection) []string
// 添加主键
// CompilePrimary(bp *Blueprint, command *Command, conn *db.Connection) string
// 添加唯一键
CompileUnique(bp *Blueprint, command *Command, conn *db.Connection) string
// 添加普通索引
CompileIndex(bp *Blueprint, command *Command, conn *db.Connection) string
// 添加空间索引
CompileSpatialIndex(bp *Blueprint, command *Command, conn *db.Connection) string
// 删除表
CompileDrop(bp *Blueprint, command *Command, conn *db.Connection) string
// 删除表, 先判断再删除
CompileDropIfExists(bp *Blueprint, command *Command, conn *db.Connection) string
// 删除表, 删除列
CompileDropColumn(bp *Blueprint, command *Command, conn *db.Connection) []string
// 删除主键
// CompileDropPrimary(bp *Blueprint, command *Command, conn *db.Connection) string
// 删除唯一键
CompileDropUnique(bp *Blueprint, command *Command, conn *db.Connection) string
// 删除唯普通索引
CompileDropIndex(bp *Blueprint, command *Command, conn *db.Connection) string
// 删除唯空间索引
CompileDropSpatialIndex(bp *Blueprint, command *Command, conn *db.Connection) string
// 表重命名
CompileRename(bp *Blueprint, command *Command, conn *db.Connection) string
// 索引重命名
CompileRenameIndex(bp *Blueprint, command *Command, conn *db.Connection) string
}

338
schema/mysql_grammar.go Normal file
View File

@@ -0,0 +1,338 @@
package schema
import (
"fmt"
"strings"
"github.com/samber/lo"
"git.fsdpf.net/go/db"
)
type MysqlGrammar struct {
}
var mysqlDefaultModifiers = []string{
"Unsigned", "Charset", "Collate", "VirtualAs", "StoredAs",
"Nullable", "Default", "Increment", "Comment", "After", "First",
}
var mysqlSerials = []string{
"bigInteger", "integer", "mediumInteger", "smallInteger", "tinyInteger",
}
func (this MysqlGrammar) CompileTableExists() string {
return "select count(*) from information_schema.tables where table_name = ? and table_schema = ? and table_type = 'BASE TABLE'"
}
func (this MysqlGrammar) CompileColumnListing(table string) string {
return "select ordinal_position as `id`, column_name as `name` from information_schema.columns where table_name = ? and table_schema = ? order by ordinal_position"
}
// 创建表
func (this MysqlGrammar) CompileCreate(bp *Blueprint, command *Command, conn *db.Connection) string {
sql := this.compileCreateTable(bp, command, conn)
sql = this.compileCreateEncoding(sql, bp, conn)
sql = this.compileCreateEngine(sql, bp, conn)
return sql
}
// 创建表结构
func (this MysqlGrammar) compileCreateTable(bp *Blueprint, command *Command, conn *db.Connection) string {
temporary := "create"
if bp.temporary {
temporary = "create temporary"
}
return fmt.Sprintf("%s table %s (%s)", temporary, bp.table, strings.Join(this.getAddedColumns(bp), ", "))
}
// 创建表 字符集
func (this MysqlGrammar) compileCreateEncoding(sql string, bp *Blueprint, conn *db.Connection) string {
if bp.charset != "" {
sql = sql + " default charset=" + bp.charset
} else if conn.GetConfig().Charset != "" {
sql = sql + " default charset=" + conn.GetConfig().Charset
} else {
sql = sql + " default charset=utf8mb4"
}
if bp.collation != "" {
sql = sql + " collate=" + bp.collation
} else {
sql = sql + " collate=utf8mb4_general_ci"
}
return sql
}
// 创建表存储方式
func (this MysqlGrammar) compileCreateEngine(sql string, bp *Blueprint, conn *db.Connection) string {
if bp.engine != "" {
sql = sql + " engine = " + bp.engine
}
return sql
}
// 添加字段
func (this MysqlGrammar) CompileAdd(bp *Blueprint, command *Command, conn *db.Connection) []string {
columns := PrefixArray("add", this.getAddedColumns(bp))
return []string{"alter table `" + bp.table + "` " + strings.Join(columns, ", ")}
}
// 修改字段
func (this MysqlGrammar) CompileChange(bp *Blueprint, command *Command, conn *db.Connection) []string {
columns := PrefixArray("change", this.getChangedColumns(bp))
return []string{"alter table `" + bp.table + "` " + strings.Join(columns, ", ")}
}
// 添加主键
func (this MysqlGrammar) CompilePrimary(bp *Blueprint, command *Command, conn *db.Connection) string {
return this.compileKey(bp, command, "primary key")
}
// 添加唯一键
func (this MysqlGrammar) CompileUnique(bp *Blueprint, command *Command, conn *db.Connection) string {
return this.compileKey(bp, command, "unique")
}
// 添加普通索引
func (this MysqlGrammar) CompileIndex(bp *Blueprint, command *Command, conn *db.Connection) string {
return this.compileKey(bp, command, "index")
}
// 添加空间索引
func (this MysqlGrammar) CompileSpatialIndex(bp *Blueprint, command *Command, conn *db.Connection) string {
return this.compileKey(bp, command, "spatial index")
}
// 添加索引
func (this MysqlGrammar) compileKey(bp *Blueprint, command *Command, typ string) string {
algorithm := ""
if command.Algorithm != "" {
algorithm = " using " + command.Algorithm
}
columns := lo.Map(command.Columns, func(column string, _ int) string {
return "`" + column + "`"
})
return fmt.Sprintf("alter table %s add %s %s%s(%s)", bp.table, typ, command.Index, algorithm, strings.Join(columns, ","))
}
// 删除表
func (this MysqlGrammar) CompileDrop(bp *Blueprint, command *Command, conn *db.Connection) string {
return "drop table `" + bp.table + "`"
}
// 删除表, 先判断再删除
func (this MysqlGrammar) CompileDropIfExists(bp *Blueprint, command *Command, conn *db.Connection) string {
return "drop table if exists `" + bp.table + "`"
}
// 删除列
func (this MysqlGrammar) CompileDropColumn(bp *Blueprint, command *Command, conn *db.Connection) []string {
columns := lo.Map(command.Columns, func(column string, _ int) string {
return "drop `" + column + "`"
})
return []string{"alter table `" + bp.table + "` " + strings.Join(columns, ", ")}
}
// 删除主键
func (this MysqlGrammar) CompileDropPrimary(bp *Blueprint, command *Command, conn *db.Connection) string {
return "alter table `" + bp.table + "` drop primary key"
}
// 删除索引
func (this MysqlGrammar) CompileDropIndex(bp *Blueprint, command *Command, conn *db.Connection) string {
return "alter table `" + bp.table + "` drop index `" + command.Index + "`"
}
// 删除唯一键
func (this MysqlGrammar) CompileDropUnique(bp *Blueprint, command *Command, conn *db.Connection) string {
return this.CompileDropIndex(bp, command, conn)
}
// 删除空间索引
func (this MysqlGrammar) CompileDropSpatialIndex(bp *Blueprint, command *Command, conn *db.Connection) string {
return this.CompileDropIndex(bp, command, conn)
}
// 表重命名
func (this MysqlGrammar) CompileRename(bp *Blueprint, command *Command, conn *db.Connection) string {
return "rename table `" + bp.table + "` to `" + command.To + "`"
}
// 索引重命名
func (this MysqlGrammar) CompileRenameIndex(bp *Blueprint, command *Command, conn *db.Connection) string {
return fmt.Sprintf("alter table `%s` rename index `%s` to `%s`", bp.table, command.From, command.To)
}
// 删除所有表
func (this MysqlGrammar) CompileDropAllTables(tables ...string) string {
tables = lo.Map(tables, func(table string, _ int) string {
return "`" + table + "`"
})
return "drop table " + strings.Join(tables, ", ")
}
// 删除所有视图
func (this MysqlGrammar) CompileDropAllViews(views ...string) string {
views = lo.Map(views, func(view string, _ int) string {
return "`" + view + "`"
})
return "drop view " + strings.Join(views, ", ")
}
// 查询所有表
// func (this MysqlGrammar) CompileGetAllTables(tables ...string) string {
// }
// 查询所有视图
// func (this MysqlGrammar) CompileGetAllViews(views ...string) string {}
func (this MysqlGrammar) addModifiers(sql string, bp *Blueprint, column *ColumnDefinition) string {
for _, modifier := range mysqlDefaultModifiers {
sql = sql + this.GetColumnModifier(modifier, bp, column)
}
return sql
}
func (this MysqlGrammar) GetColumnType(column *ColumnDefinition) string {
switch column.Type {
case "char":
return fmt.Sprintf("char(%d)", column.Length)
case "string":
return fmt.Sprintf("varchar(%d)", column.Length)
case "text":
return "text"
case "integer":
return "int"
case "bigInteger":
return "bigint"
case "tinyInteger":
return "tinyint"
case "smallInteger":
return "smallint"
case "decimal":
return fmt.Sprintf("decimal(%d,%d)", column.Total, column.Places)
case "boolean":
return "tinyint(1)"
case "enum":
return fmt.Sprintf("enum(%s)", QuoteString(column.Allowed))
case "json":
return "json"
case "date":
return "date"
case "binary":
return "binary"
case "datetime":
columnType := "datetime"
if column.Precision > 0 {
columnType = fmt.Sprintf("datetime(%d)", column.Precision)
}
return columnType
case "time":
columnType := "time"
if column.Precision > 0 {
columnType = fmt.Sprintf("time(%d)", column.Precision)
}
return columnType
case "timestamp":
columnType := "timestamp"
if column.Precision > 0 {
columnType = fmt.Sprintf("timestamp(%d)", column.Precision)
}
return columnType
case "year":
return "year"
case "uuid":
return "char(36)"
}
panic("不支持的数据类型: " + column.Type)
}
func (this MysqlGrammar) GetColumnModifier(modifier string, bp *Blueprint, column *ColumnDefinition) string {
switch modifier {
case "Unsigned":
if column.unsigned {
return " unsigned"
}
case "Charset":
if column.charset != "" {
return fmt.Sprintf(" character set %s", column.charset)
}
case "Collate":
if column.collation != "" {
return fmt.Sprintf(" collate '%s'", column.collation)
}
case "VirtualAs":
if column.virtualAs != "" {
return fmt.Sprintf(" as (%s)", column.virtualAs)
}
case "StoredAs":
if column.storedAs != "" {
return fmt.Sprintf(" as (%s) stored", column.storedAs)
}
case "Nullable":
if column.virtualAs == "" && column.storedAs == "" {
if column.nullable {
return " null"
}
return " not null"
}
case "Default":
switch v := column.def.(type) {
case db.Expression:
if column.useCurrent {
return fmt.Sprintf(" default %s %s", "CURRENT_TIMESTAMP", v)
}
return fmt.Sprintf(" default %s", v)
case nil:
if column.useCurrent {
return fmt.Sprintf(" default %s", "CURRENT_TIMESTAMP")
}
default:
return fmt.Sprintf(" default '%v'", v)
}
case "Increment":
if column.autoIncrement && lo.Contains(mysqlSerials, column.Type) {
return " auto_increment primary key"
}
case "Comment":
if column.comment != "" {
return fmt.Sprintf(" comment '%s'", db.Addslashes(column.comment))
}
case "After":
if column.after != "" {
return fmt.Sprintf(" after `%s`", column.after)
}
case "First":
if column.first {
return " first"
}
}
return ""
}
// 获取新增表结构字段
func (this MysqlGrammar) getAddedColumns(bp *Blueprint) (columns []string) {
for _, column := range bp.getAddedColumns() {
sql := "`" + column.Name + "` " + this.GetColumnType(column)
columns = append(columns, this.addModifiers(sql, bp, column))
}
return
}
// 获取修改表结构字段
func (this MysqlGrammar) getChangedColumns(bp *Blueprint) (columns []string) {
for _, column := range bp.getChangedColumns() {
name := column.Name
rename := column.Name
if column.rename != "" {
rename = column.rename
}
sql := fmt.Sprintf("`%s` `%s` ", name, rename) + this.GetColumnType(column)
columns = append(columns, this.addModifiers(sql, bp, column))
}
return
}

351
schema/sqlite3_grammar.go Normal file
View File

@@ -0,0 +1,351 @@
package schema
import (
"fmt"
"regexp"
"strings"
"github.com/samber/lo"
"git.fsdpf.net/go/db"
)
type Sqlite3Grammar struct {
}
var sqlite3DefaultModifiers = []string{"VirtualAs", "StoredAs", "Nullable", "Default", "Increment"}
var sqlite3Serials = []string{"bigInteger", "integer", "mediumInteger", "smallInteger", "tinyInteger"}
func (this Sqlite3Grammar) CompileTableExists() string {
// 兼容其他数据库查询
// 其他数据库查看都是同时传入, 数据库名 和 数据表名
return "select count(*) from sqlite_master where type = 'table' and name = ?||?"
}
func (this Sqlite3Grammar) CompileColumnListing(table string) string {
// 兼容其他数据库查询
// 其他数据库查看都是同时传入, 数据库名 和 数据表名
return "pragma table_info(" + table + ")"
}
// 创建表
func (this Sqlite3Grammar) CompileCreate(bp *Blueprint, command *Command, conn *db.Connection) string {
temporary := "create"
if bp.temporary {
temporary = "create temporary"
}
return fmt.Sprintf("%s table %s (%s)", temporary, bp.table, strings.Join(this.getAddedColumns(bp), ", "))
}
// 添加字段
func (this Sqlite3Grammar) CompileAdd(bp *Blueprint, command *Command, conn *db.Connection) []string {
columns := PrefixArray("add column", this.getAddedColumns(bp))
return lo.Map(columns, func(column string, _ int) string {
return "alter table `" + bp.table + "` " + column
})
}
// 修改字段
func (this Sqlite3Grammar) CompileChange(bp *Blueprint, command *Command, conn *db.Connection) (sqls []string) {
// 创建表的sql
cTableSql := this.getCreateTable(bp, conn)
// 旧表字段
oldColumns := []string{}
if _, err := conn.Select("select '`'||name||'`' from pragma_table_info(?)", []any{bp.table}, &oldColumns); err != nil {
panic(fmt.Errorf("获取 SQLite 建表字段失败, %s", err))
}
// 数据同步
insert_sql := fmt.Sprintf("insert into `"+bp.table+"` (%s)", strings.Join(oldColumns, ", "))
select_sql := fmt.Sprintf("select %s from `"+bp.table+"_old`", strings.Join(oldColumns, ", "))
for _, column := range bp.getChangedColumns() {
name := column.Name
rename := column.rename
// 旧字段(查询) -> 新字段(插入)
if column.rename != "" {
select_sql = strings.ReplaceAll(select_sql, name, name+"` as `"+rename)
insert_sql = strings.ReplaceAll(insert_sql, name, rename)
} else {
rename = name
}
// 修改表字段
sql := "`" + rename + "` " + this.GetColumnType(column)
sql = this.addModifiers(sql, bp, column)
reg := regexp.MustCompile(fmt.Sprintf("(?iUm)^(?P<prefix>\\s*)(?P<content>`%s`.*)(?P<end>(,\\s|\\s)?)$", name))
cTableSql = reg.ReplaceAllString(cTableSql, fmt.Sprintf("${prefix}%s${end}", sql))
}
// 开启事物
// sqls = append(sqls, "BEGIN TRANSACTION")
// 1. 重命名已存在的表
sqls = append(sqls, fmt.Sprintf("alter table `%s` rename to `%s_old`", bp.table, bp.table))
// 2. 创建新结构表
sqls = append(sqls, cTableSql)
// 3. 复制数据
sqls = append(sqls, insert_sql+" "+select_sql)
// 4. 删除旧表
sqls = append(sqls, fmt.Sprintf("drop table `%s_old`", bp.table))
// 提交事物
// sqls = append(sqls, "COMMIT")
// @todo 复制索引
return sqls
}
// 创建唯一索引
func (this Sqlite3Grammar) CompileUnique(bp *Blueprint, command *Command, conn *db.Connection) string {
columns := lo.Map(command.Columns, func(column string, _ int) string {
return "`" + column + "`"
})
return fmt.Sprintf("create unique index `%s` on `%s` (%s)", command.Index, bp.table, strings.Join(columns, ", "))
}
// 创建普通索引
func (this Sqlite3Grammar) CompileIndex(bp *Blueprint, command *Command, conn *db.Connection) string {
columns := lo.Map(command.Columns, func(column string, _ int) string {
return "`" + column + "`"
})
return fmt.Sprintf("create index `%s` on `%s` (%s)", command.Index, bp.table, strings.Join(columns, ", "))
}
// 创建空间索引
func (this Sqlite3Grammar) CompileSpatialIndex(bp *Blueprint, command *Command, conn *db.Connection) string {
panic("Sqlite 不支持 SpatialIndex")
}
// 删除表
func (this Sqlite3Grammar) CompileDrop(bp *Blueprint, command *Command, conn *db.Connection) string {
return "drop table `" + bp.table + "`"
}
// 删除表, 先判断再删除
func (this Sqlite3Grammar) CompileDropIfExists(bp *Blueprint, command *Command, conn *db.Connection) string {
return "drop table if exists `" + bp.table + "`"
}
// 删除列
func (this Sqlite3Grammar) CompileDropColumn(bp *Blueprint, command *Command, conn *db.Connection) (sqls []string) {
// 创建表的sql, 排除不要的字段
cTableSql := this.getCreateTable(bp, conn, command.Columns...)
// 旧表字段
oldColumns := []string{}
if _, err := conn.Select(
"select '`'||name||'`' from pragma_table_info(?) where `name` not in("+strings.Trim(strings.Repeat("?,", len(command.Columns)), ",")+")",
append([]any{bp.table}, lo.ToAnySlice(command.Columns)...),
&oldColumns); err != nil {
panic(fmt.Errorf("获取 SQLite 建表字段失败, %s", err))
}
copy_data_sql := fmt.Sprintf("insert into `%s` select %s from `%s_old`", bp.table, strings.Join(oldColumns, ", "), bp.table)
for _, column := range bp.getChangedColumns() {
name := column.Name
rename := column.rename
if rename != "" {
copy_data_sql = strings.ReplaceAll(copy_data_sql, "`"+name+"`", "`"+rename+"`")
}
// 修改表字段
sql := "`" + rename + "` " + this.GetColumnType(column)
sql = this.addModifiers(sql, bp, column)
reg := regexp.MustCompile(fmt.Sprintf("(?iUm)^(?P<prefix>\\s*)(?P<content>`%s`.*)(?P<end>(,\\s|\\s)?)$", name))
cTableSql = reg.ReplaceAllString(cTableSql, fmt.Sprintf("${prefix}%s${end}", sql))
}
// 开启事物
// sqls = append(sqls, "BEGIN TRANSACTION")
// 1. 重命名已存在的表
sqls = append(sqls, fmt.Sprintf("alter table `%s` rename to `%s_old`", bp.table, bp.table))
// 2. 创建新结构表
sqls = append(sqls, cTableSql)
// 3. 复制数据
sqls = append(sqls, copy_data_sql)
// 4. 删除旧表
sqls = append(sqls, fmt.Sprintf("drop table `%s_old`", bp.table))
// 提交事物
// sqls = append(sqls, "COMMIT")
// @todo 复制索引
return sqls
}
// 删除唯一索引
func (this Sqlite3Grammar) CompileDropUnique(bp *Blueprint, command *Command, conn *db.Connection) string {
return this.CompileDropIndex(bp, command, conn)
}
// 删除索引
func (this Sqlite3Grammar) CompileDropIndex(bp *Blueprint, command *Command, conn *db.Connection) string {
return "drop index `" + command.Index + "`"
}
// 删除空间索引
func (this Sqlite3Grammar) CompileDropSpatialIndex(bp *Blueprint, command *Command, conn *db.Connection) string {
return this.CompileSpatialIndex(bp, command, conn)
}
func (this Sqlite3Grammar) CompileRename(bp *Blueprint, command *Command, conn *db.Connection) string {
return "alter table `" + bp.table + "` rename to `" + command.To + "`"
}
func (this Sqlite3Grammar) CompileRenameIndex(bp *Blueprint, command *Command, conn *db.Connection) string {
// 先删除索引, 后创建索引
return "alter table `" + bp.table + "` rename to `" + command.To + "`"
}
// 删除所有表
func (this Sqlite3Grammar) CompileDropAllTables() string {
return "delete from sqlite_master where type in ('table', 'index', 'trigger')"
}
// 删除所有视图
func (this Sqlite3Grammar) CompileDropAllViews() string {
return "delete from sqlite_master where type in ('view')"
}
func (this Sqlite3Grammar) addModifiers(sql string, bp *Blueprint, column *ColumnDefinition) string {
for _, modifier := range mysqlDefaultModifiers {
sql = sql + this.GetColumnModifier(modifier, bp, column)
}
return sql
}
func (this Sqlite3Grammar) GetColumnModifier(modifier string, bp *Blueprint, column *ColumnDefinition) string {
switch modifier {
case "VirtualAs":
if column.virtualAs != "" {
return fmt.Sprintf(" as (%s)", column.virtualAs)
}
case "StoredAs":
if column.storedAs != "" {
return fmt.Sprintf(" as (%s) stored", column.storedAs)
}
case "Nullable":
if !column.nullable {
return " not null"
}
case "Default":
if column.useCurrent {
return ""
}
switch v := column.def.(type) {
case db.Expression:
return fmt.Sprintf(" default %s", v)
case nil:
default:
return fmt.Sprintf(" default '%v'", v)
}
case "Increment":
if column.autoIncrement && lo.Contains(mysqlSerials, column.Type) {
return " primary key autoincrement"
}
}
return ""
}
func (this Sqlite3Grammar) GetColumnType(column *ColumnDefinition) string {
switch column.Type {
case "char", "string":
return fmt.Sprintf("varchar(%d)", column.Length)
case "text":
return "text"
case "year", "integer":
return "integer"
case "bigInteger":
return "bigint"
case "tinyInteger":
return "tinyint"
case "smallInteger":
return "smallint"
case "decimal":
return fmt.Sprintf("numeric(%d,%d)", column.Total, column.Places)
case "boolean":
return "tinyint(1)"
case "enum":
return fmt.Sprintf(`varchar check ("%s" in (%s))`, column.Name, QuoteString(column.Allowed))
case "json":
return "json"
case "date":
return "date"
case "binary":
return "blob"
case "datetime", "timestamp":
columnType := column.Type
if column.useCurrent {
columnType = columnType + " default CURRENT_TIMESTAMP"
}
return columnType
case "time":
return "time"
case "uuid":
return "varchar(36)"
}
panic("不支持的数据类型: " + column.Type)
}
// 获取新增表结构字段
func (this Sqlite3Grammar) getAddedColumns(bp *Blueprint) (columns []string) {
for _, column := range bp.getAddedColumns() {
sql := "`" + column.Name + "` " + this.GetColumnType(column)
columns = append(columns, this.addModifiers(sql, bp, column))
}
return
}
// 获取修改表结构字段
func (this Sqlite3Grammar) getChangedColumns(bp *Blueprint) map[string]string {
columns := map[string]string{}
for _, column := range bp.getChangedColumns() {
sql := "`" + column.Name + "` " + this.GetColumnType(column)
columns[column.Name] = this.addModifiers(sql, bp, column)
}
return columns
}
// 数据库表字段
func (this Sqlite3Grammar) getCreateTable(bp *Blueprint, conn *db.Connection, without ...string) (sql string) {
dbColumns := []struct {
Name string `db:"name"`
Type string `db:"type"`
Default db.NullString `db:"dflt_value"`
NotNull bool `db:"notnull"`
Pk bool `db:"pk"`
}{}
if _, err := conn.Select("select * from pragma_table_info(?) order by cid", []any{bp.table}, &dbColumns); err != nil {
panic(fmt.Errorf("获取 SQLite 建表字段失败, %s", err))
}
cols := []string{}
for _, column := range dbColumns {
if lo.Contains(without, column.Name) {
continue
}
col := " `" + column.Name + "`"
col += " " + column.Type
if column.NotNull {
col += " not null"
}
if column.Pk {
col += " primary key autoincrement"
}
if column.Default != "" {
col += " default " + string(column.Default)
}
cols = append(cols, col)
}
return fmt.Sprintf("CREATE TABLE `%s` (\n%s\n)", bp.table, strings.Join(cols, ", \n"))
}

25
schema/util.go Normal file
View File

@@ -0,0 +1,25 @@
package schema
import (
"github.com/samber/lo"
"strings"
)
func PrefixArray(prefix string, values []string) (items []string) {
for _, value := range values {
items = append(items, prefix+" "+value)
}
return items
}
func QuoteString(value any) string {
switch v := value.(type) {
case []string:
return strings.Join(lo.Map(v, func(item string, _ int) string {
return "'" + item + "'"
}), ", ")
case string:
return "'" + v + "'"
}
return ""
}