first commit
This commit is contained in:
commit
59b3faa171
23
.gitignore
vendored
Normal file
23
.gitignore
vendored
Normal file
@ -0,0 +1,23 @@
|
||||
# ---> Go
|
||||
# If you prefer the allow list template instead of the deny list, see community template:
|
||||
# https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore
|
||||
#
|
||||
# Binaries for programs and plugins
|
||||
*.exe
|
||||
*.exe~
|
||||
*.dll
|
||||
*.so
|
||||
*.dylib
|
||||
|
||||
# Test binary, built with `go test -c`
|
||||
*.test
|
||||
|
||||
# Output of the go coverage tool, specifically when used with LiteIDE
|
||||
*.out
|
||||
|
||||
# Dependency directories (remove the comment below to include it)
|
||||
# vendor/
|
||||
|
||||
# Go workspace file
|
||||
go.work
|
||||
|
98
README.md
Normal file
98
README.md
Normal file
@ -0,0 +1,98 @@
|
||||
# WORK IN PROGRESS
|
||||
# Get Started
|
||||
A golang ORM Framework like Laravel's Eloquent
|
||||
|
||||
Example
|
||||
|
||||
```golang
|
||||
type User struct {
|
||||
Id int64 `goelo:"column:id;primaryKey"`
|
||||
UserName sql.NullString `goelo:"column:name"`
|
||||
Age int `goelo:"column:age"`
|
||||
Status int `goelo:"column:status"`
|
||||
Friends []UserT `goelo:"BelongsToMany:FriendsRelation"`
|
||||
Address Address `goelo:"HasOne:AddressRelation"`
|
||||
CreatedAt time.Time `goelo:"column:created_at,timestatmp:create"`
|
||||
UpdatedAt sql.NullTime `goelo:"column:updated_at,timestatmp:update"`
|
||||
}
|
||||
|
||||
//Find/First/Get
|
||||
var user User
|
||||
DB.Table("users").Find(&user,1)
|
||||
DB.Table("users").Where("name","john").First(&user)
|
||||
DB.Table("users").Where("name","john").FirstOrCreate(&user)
|
||||
|
||||
//Column/Aggeregate
|
||||
var age int
|
||||
DB.Table("users").Where("id","=",20).Value(&age,"age")
|
||||
DB.Table("users").Max(&age,"age")
|
||||
var salary int
|
||||
DB.Table("user").Where("age",">=",30).Avg(&salary,"salary")
|
||||
|
||||
//Find record to map
|
||||
var m = make(map[string]interface{})
|
||||
DB.Query().From("users").Find(&m,3)
|
||||
var ms []map[string]interface{}
|
||||
DB.Query().From("users").Get(&ms)
|
||||
|
||||
//Pagination
|
||||
var users []User
|
||||
DB.Table("users").Where("id", ">", 10).Where("id", "<", 28).Paginate(&users, 10, 1)
|
||||
|
||||
//Chunk/ChunkById
|
||||
var total int
|
||||
totalP := &total
|
||||
DB.Table("users").OrderBy("id").Chunk(&[]User{}, 10, func(dest interface{}) error {
|
||||
us := dest.(*[]User)
|
||||
for _, user := range *us {
|
||||
assert.Equal(t, user.UserName, sql.NullString{
|
||||
String: fmt.Sprintf("user-%d", user.Age),
|
||||
Valid: true,
|
||||
})
|
||||
*totalP++
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
var total int
|
||||
totalP := &total
|
||||
DB.Table("users").ChunkById(&[]User{}, 10, func(dest interface{}) error {
|
||||
us := dest.(*[]User)
|
||||
for _, user := range *us {
|
||||
assert.Equal(t, user.UserName, sql.NullString{
|
||||
String: fmt.Sprintf("user-%d", user.Age),
|
||||
Valid: true,
|
||||
})
|
||||
*totalP++
|
||||
}
|
||||
return nil
|
||||
})
|
||||
//Query clause
|
||||
DB.Where("name","john@apple.com").OrWhere("email","john@apple.com").First(&user)
|
||||
|
||||
DB.Where("is_admin", 1).Where(map[string]interface{}{
|
||||
"name": "Joe", "location": "LA",
|
||||
}, db.BOOLEAN_OR).Where(func(builder *db.Builder){
|
||||
builder.WhereYear("created_at", "<", 2010).WhereColumn("first_name", "last_name").OrWhereNull("activited_at")
|
||||
}).ToSql()
|
||||
// sql:"select `name`, `age`, `email` where `is_admin` = ? or (`name` = ? and `location` = ?) and (year(`created_at`) < ? and `first_name` = `last_name` or `activited_at` is null)"
|
||||
|
||||
// Insert/Update/Delete
|
||||
DB.Table("users").Insert(map[string]interface{}{
|
||||
"name": "go-eloquent",
|
||||
"age": 18,
|
||||
"created_at": now,
|
||||
})
|
||||
DB.Table("users").Where("id", id).Update(map[string]interface{}{
|
||||
"name": "newname",
|
||||
"age": 18,
|
||||
"updated_at": now.Add(time.Hour * 1),
|
||||
})
|
||||
DB.Table("users").Where("id", id).Delete()
|
||||
|
||||
```
|
||||
More Details,visit [Docs](https://glitterlip.github.io/go-eloquent-docs/)
|
||||
# Credits
|
||||
[https://github.com/go-gorm/gorm](https://github.com/go-gorm/gorm)
|
||||
[https://github.com/jmoiron/sqlx](https://github.com/jmoiron/sqlx)
|
||||
[https://github.com/qclaogui/database](https://github.com/qclaogui/database)
|
2506
builder.go
Normal file
2506
builder.go
Normal file
File diff suppressed because it is too large
Load Diff
34
config.go
Normal file
34
config.go
Normal file
@ -0,0 +1,34 @@
|
||||
package db
|
||||
|
||||
type DBConfig struct {
|
||||
Driver string
|
||||
//ReadHost []string
|
||||
//WriteHost []string
|
||||
Host string
|
||||
Port string
|
||||
Database string
|
||||
Username string
|
||||
Password string
|
||||
Charset string
|
||||
Prefix string
|
||||
ConnMaxLifetime int
|
||||
ConnMaxIdleTime int
|
||||
MaxIdleConns int
|
||||
MaxOpenConns int
|
||||
ParseTime bool
|
||||
// mysql
|
||||
Collation string
|
||||
UnixSocket string
|
||||
MultiStatements bool
|
||||
Dsn string
|
||||
// pgsql
|
||||
Sslmode string
|
||||
TLS string
|
||||
EnableLog bool
|
||||
// sqlite3
|
||||
File string
|
||||
Journal string // DELETE TRUNCATE PERSIST MEMORY WAL OFF
|
||||
Locking string // NORMAL EXCLUSIVE
|
||||
Mode string // ro rw rwc memory
|
||||
Synchronous int // 0 OFF | 1 NORMAL | 2 FULL | 3 EXTRA
|
||||
}
|
162
connection.go
Normal file
162
connection.go
Normal file
@ -0,0 +1,162 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
)
|
||||
|
||||
type Connection struct {
|
||||
RxDB *sql.DB
|
||||
DB *sql.DB
|
||||
Config *DBConfig
|
||||
ConnectionName string
|
||||
}
|
||||
|
||||
type IConnection interface {
|
||||
Insert(query string, bindings []interface{}) (result sql.Result, err error)
|
||||
Select(query string, bindings []interface{}, dest interface{}) (result sql.Result, err error)
|
||||
Update(query string, bindings []interface{}) (result sql.Result, err error)
|
||||
Delete(query string, bindings []interface{}) (result sql.Result, err error)
|
||||
AffectingStatement(query string, bindings []interface{}) (result sql.Result, err error)
|
||||
Statement(query string, bindings []interface{}) (sql.Result, error)
|
||||
Table(tableName string) *Builder
|
||||
}
|
||||
|
||||
type Preparer interface {
|
||||
Prepare(query string) (*sql.Stmt, error)
|
||||
}
|
||||
|
||||
type Execer interface {
|
||||
Exec(query string, args ...interface{}) (sql.Result, error)
|
||||
}
|
||||
|
||||
type ITransaction interface {
|
||||
BeginTransaction() (*Transaction, error)
|
||||
Transaction(closure TxClosure) (interface{}, error)
|
||||
}
|
||||
|
||||
func (c *Connection) Select(query string, bindings []interface{}, dest interface{}) (result sql.Result, err error) {
|
||||
var stmt *sql.Stmt
|
||||
var rows *sql.Rows
|
||||
stmt, err = c.DB.Prepare(query)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer stmt.Close()
|
||||
rows, err = stmt.Query(bindings...)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return ScanResult(rows, dest), nil
|
||||
}
|
||||
|
||||
func (c *Connection) BeginTransaction() (*Transaction, error) {
|
||||
begin, err := c.DB.Begin()
|
||||
if err != nil {
|
||||
return nil, errors.New(err.Error())
|
||||
}
|
||||
|
||||
tx := &Transaction{
|
||||
Rx: c.RxDB,
|
||||
Tx: begin,
|
||||
Config: c.Config,
|
||||
ConnectionName: c.ConnectionName,
|
||||
}
|
||||
|
||||
return tx, nil
|
||||
}
|
||||
|
||||
func (c *Connection) Transaction(closure TxClosure) (interface{}, error) {
|
||||
begin, err := c.DB.Begin()
|
||||
if err != nil {
|
||||
panic(err.Error())
|
||||
}
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
_ = begin.Rollback()
|
||||
panic(err)
|
||||
} else {
|
||||
err = begin.Commit()
|
||||
}
|
||||
}()
|
||||
|
||||
tx := &Transaction{
|
||||
Rx: c.RxDB,
|
||||
Tx: begin,
|
||||
Config: c.Config,
|
||||
ConnectionName: c.ConnectionName,
|
||||
}
|
||||
|
||||
return closure(tx)
|
||||
}
|
||||
|
||||
func (c *Connection) Insert(query string, bindings []interface{}) (result sql.Result, err error) {
|
||||
return c.AffectingStatement(query, bindings)
|
||||
}
|
||||
|
||||
func (c *Connection) Update(query string, bindings []interface{}) (result sql.Result, err error) {
|
||||
return c.AffectingStatement(query, bindings)
|
||||
|
||||
}
|
||||
|
||||
func (c *Connection) Delete(query string, bindings []interface{}) (result sql.Result, err error) {
|
||||
return c.AffectingStatement(query, bindings)
|
||||
}
|
||||
|
||||
func (c *Connection) AffectingStatement(query string, bindings []interface{}) (result sql.Result, err error) {
|
||||
stmt, errP := c.DB.Prepare(query)
|
||||
if errP != nil {
|
||||
err = errP
|
||||
return
|
||||
}
|
||||
defer stmt.Close()
|
||||
result, err = stmt.Exec(bindings...)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Connection) Table(tableName string) *Builder {
|
||||
builder := NewBuilder(c)
|
||||
if c.GetConfig().Driver == DriverMysql {
|
||||
builder.Grammar = &MysqlGrammar{}
|
||||
} else if c.Config.Driver == DriverSqlite3 {
|
||||
builder.Grammar = &Sqlite3Grammar{}
|
||||
} else {
|
||||
panic("不支持的数据库类型")
|
||||
}
|
||||
builder.Grammar.SetTablePrefix(c.Config.Prefix)
|
||||
builder.Grammar.SetBuilder(builder)
|
||||
builder.From(tableName)
|
||||
return builder
|
||||
}
|
||||
|
||||
func (c *Connection) Statement(query string, bindings []interface{}) (sql.Result, error) {
|
||||
return c.AffectingStatement(query, bindings)
|
||||
}
|
||||
|
||||
func (c *Connection) GetDB() *sql.DB {
|
||||
return c.DB
|
||||
}
|
||||
|
||||
func (c *Connection) Query() *Builder {
|
||||
builder := NewBuilder(c)
|
||||
if c.GetConfig().Driver == DriverMysql {
|
||||
builder.Grammar = &MysqlGrammar{}
|
||||
} else if c.Config.Driver == DriverSqlite3 {
|
||||
builder.Grammar = &Sqlite3Grammar{}
|
||||
} else {
|
||||
panic("不支持的数据库类型")
|
||||
}
|
||||
builder.Grammar.SetBuilder(builder)
|
||||
builder.Grammar.SetTablePrefix(c.Config.Prefix)
|
||||
return builder
|
||||
}
|
||||
|
||||
func (c *Connection) GetConfig() *DBConfig {
|
||||
return c.Config
|
||||
}
|
45
connection_factory.go
Normal file
45
connection_factory.go
Normal file
@ -0,0 +1,45 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
const (
|
||||
DriverMysql = "mysql"
|
||||
DriverSqlite3 = "sqlite3"
|
||||
)
|
||||
|
||||
type ConnectionFactory struct {
|
||||
}
|
||||
|
||||
type Connector interface {
|
||||
connect(config *DBConfig) *Connection
|
||||
}
|
||||
|
||||
func (f ConnectionFactory) Make(config *DBConfig) *Connection {
|
||||
return f.MakeConnection(config)
|
||||
}
|
||||
|
||||
func (f ConnectionFactory) MakeConnection(config *DBConfig) *Connection {
|
||||
return f.CreateConnection(config)
|
||||
}
|
||||
|
||||
func (f ConnectionFactory) CreateConnection(config *DBConfig) *Connection {
|
||||
switch config.Driver {
|
||||
case DriverMysql:
|
||||
connector := MysqlConnector{}
|
||||
conn := connector.connect(config)
|
||||
return conn
|
||||
case DriverSqlite3:
|
||||
connector := SqliteConnector{}
|
||||
conn := connector.connect(config)
|
||||
return conn
|
||||
case "":
|
||||
panic(errors.New("a driver must be specified"))
|
||||
default:
|
||||
panic(errors.New(fmt.Sprintf("unsupported driver:%s", config.Driver)))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
58
db.go
Normal file
58
db.go
Normal file
@ -0,0 +1,58 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
var Engine *DB
|
||||
|
||||
type DB struct {
|
||||
DatabaseManager
|
||||
LogFunc func(log Log)
|
||||
}
|
||||
|
||||
func (d *DB) SetLogger(f func(log Log)) *DB {
|
||||
d.LogFunc = f
|
||||
return d
|
||||
}
|
||||
|
||||
func Open(config map[string]DBConfig) *DB {
|
||||
var configP = make(map[string]*DBConfig)
|
||||
|
||||
for name := range config {
|
||||
c := config[name]
|
||||
configP[name] = &c
|
||||
}
|
||||
|
||||
db := DB{
|
||||
DatabaseManager: DatabaseManager{
|
||||
Configs: configP,
|
||||
Connections: make(map[string]*Connection),
|
||||
},
|
||||
}
|
||||
|
||||
db.Connection("default")
|
||||
|
||||
Engine = &db
|
||||
|
||||
return Engine
|
||||
}
|
||||
|
||||
func (d *DB) AddConfig(name string, config *DBConfig) *DB {
|
||||
Engine.Configs[name] = config
|
||||
return d
|
||||
}
|
||||
|
||||
func (d *DB) GetConfigs() map[string]*DBConfig {
|
||||
return Engine.Configs
|
||||
}
|
||||
|
||||
func (*DB) Raw(connectionName ...string) *sql.DB {
|
||||
if len(connectionName) > 0 {
|
||||
c := Engine.Connection(connectionName[0])
|
||||
return (*c).GetDB()
|
||||
} else {
|
||||
c := Engine.Connection("default")
|
||||
return (*c).GetDB()
|
||||
}
|
||||
}
|
108
db_manager.go
Normal file
108
db_manager.go
Normal file
@ -0,0 +1,108 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type DatabaseManager struct {
|
||||
Connections map[string]*Connection
|
||||
Factory ConnectionFactory
|
||||
Configs map[string]*DBConfig
|
||||
}
|
||||
|
||||
func (dm *DatabaseManager) Connection(connectionName string) *Connection {
|
||||
connection, ok := dm.Connections[connectionName]
|
||||
if !ok {
|
||||
dm.Connections[connectionName] = dm.MakeConnection(connectionName)
|
||||
connection, _ = dm.Connections[connectionName]
|
||||
}
|
||||
return connection
|
||||
}
|
||||
|
||||
func (dm *DatabaseManager) MakeConnection(connectionName string) *Connection {
|
||||
config, ok := dm.Configs[connectionName]
|
||||
|
||||
if !ok {
|
||||
panic(errors.New(fmt.Sprintf("Database connection %s not configured.", connectionName)))
|
||||
}
|
||||
|
||||
conn := dm.Factory.Make(config)
|
||||
conn.ConnectionName = connectionName
|
||||
dm.Connections[connectionName] = conn
|
||||
return conn
|
||||
}
|
||||
|
||||
func (dm *DatabaseManager) getDefaultConnection() (defaultConnectionName string) {
|
||||
defaultConnectionName = "default"
|
||||
return
|
||||
}
|
||||
|
||||
func (dm *DatabaseManager) Table(params ...string) *Builder {
|
||||
defaultConn := dm.getDefaultConnection()
|
||||
c := dm.Connection(defaultConn)
|
||||
builder := NewBuilder(c)
|
||||
if c.Config.Driver == DriverMysql {
|
||||
builder.Grammar = &MysqlGrammar{}
|
||||
} else if c.Config.Driver == DriverSqlite3 {
|
||||
builder.Grammar = &Sqlite3Grammar{}
|
||||
} else {
|
||||
panic("不支持的数据库类型")
|
||||
}
|
||||
builder.Grammar.SetTablePrefix(dm.Configs[defaultConn].Prefix)
|
||||
builder.Grammar.SetBuilder(builder)
|
||||
builder.Table(params...)
|
||||
return builder
|
||||
}
|
||||
|
||||
func (dm *DatabaseManager) Select(query string, bindings []interface{}, dest interface{}) (sql.Result, error) {
|
||||
ic := dm.Connections["default"]
|
||||
return (*ic).Select(query, bindings, dest)
|
||||
}
|
||||
|
||||
func (dm *DatabaseManager) Insert(query string, bindings []interface{}) (sql.Result, error) {
|
||||
ic := dm.Connections["default"]
|
||||
return (*ic).Insert(query, bindings)
|
||||
}
|
||||
|
||||
func (dm *DatabaseManager) Update(query string, bindings []interface{}) (sql.Result, error) {
|
||||
ic := dm.Connections["default"]
|
||||
return (*ic).Update(query, bindings)
|
||||
}
|
||||
|
||||
func (dm *DatabaseManager) Delete(query string, bindings []interface{}) (sql.Result, error) {
|
||||
ic := dm.Connections["default"]
|
||||
return (*ic).Delete(query, bindings)
|
||||
}
|
||||
|
||||
func (dm *DatabaseManager) Statement(query string, bindings []interface{}) (sql.Result, error) {
|
||||
ic := dm.Connections["default"]
|
||||
return (*ic).Delete(query, bindings)
|
||||
}
|
||||
|
||||
func (dm *DatabaseManager) Query() *Builder {
|
||||
defaultConn := dm.getDefaultConnection()
|
||||
c := dm.Connection(defaultConn)
|
||||
builder := NewBuilder(c)
|
||||
if c.Config.Driver == DriverMysql {
|
||||
builder.Grammar = &MysqlGrammar{}
|
||||
} else if c.Config.Driver == DriverSqlite3 {
|
||||
builder.Grammar = &Sqlite3Grammar{}
|
||||
} else {
|
||||
panic("不支持的数据库类型")
|
||||
}
|
||||
builder.Grammar.SetTablePrefix(dm.Configs[defaultConn].Prefix)
|
||||
builder.Grammar.SetBuilder(builder)
|
||||
return builder
|
||||
}
|
||||
|
||||
func (dm *DatabaseManager) Transaction(closure TxClosure) (interface{}, error) {
|
||||
ic := dm.Connections["default"]
|
||||
return (*ic).Transaction(closure)
|
||||
}
|
||||
|
||||
func (dm *DatabaseManager) BeginTransaction() (*Transaction, error) {
|
||||
ic := dm.Connections["default"]
|
||||
return (*ic).BeginTransaction()
|
||||
}
|
11
go.mod
Normal file
11
go.mod
Normal file
@ -0,0 +1,11 @@
|
||||
module git.fsdpf.net/go/db
|
||||
|
||||
go 1.18
|
||||
|
||||
require (
|
||||
github.com/go-sql-driver/mysql v1.7.0
|
||||
github.com/mattn/go-sqlite3 v1.14.16
|
||||
github.com/samber/lo v1.38.1
|
||||
)
|
||||
|
||||
require golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 // indirect
|
8
go.sum
Normal file
8
go.sum
Normal file
@ -0,0 +1,8 @@
|
||||
github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc=
|
||||
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
|
||||
github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y=
|
||||
github.com/mattn/go-sqlite3 v1.14.16/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
||||
github.com/samber/lo v1.38.1 h1:j2XEAqXKb09Am4ebOg31SpvzUTTs6EN3VfgeLUhPdXM=
|
||||
github.com/samber/lo v1.38.1/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA=
|
||||
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 h1:3MTrJm4PyNL9NBqvYDSj3DHl46qQakyfqfWo4jgfaEM=
|
||||
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE=
|
34
gramar.go
Normal file
34
gramar.go
Normal file
@ -0,0 +1,34 @@
|
||||
package db
|
||||
|
||||
type IGrammar interface {
|
||||
SetTablePrefix(prefix string)
|
||||
|
||||
GetTablePrefix() string
|
||||
|
||||
SetBuilder(builder *Builder)
|
||||
|
||||
GetBuilder() *Builder
|
||||
|
||||
CompileInsert([]map[string]interface{}) string
|
||||
|
||||
CompileInsertOrIgnore([]map[string]interface{}) string
|
||||
|
||||
CompileDelete() string
|
||||
|
||||
CompileUpdate(map[string]interface{}) string
|
||||
|
||||
CompileSelect() string
|
||||
|
||||
CompileExists() string
|
||||
|
||||
Wrap(interface{}, ...bool) string
|
||||
|
||||
WrapTable(interface{}) string
|
||||
|
||||
CompileComponentWheres() string
|
||||
|
||||
CompileComponentJoins() string
|
||||
|
||||
CompileRandom(seed string) string
|
||||
//Wrap(value string, b *query.Builder) string
|
||||
}
|
103
mysql_connector.go
Normal file
103
mysql_connector.go
Normal file
@ -0,0 +1,103 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
// "context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type MysqlConnector struct {
|
||||
}
|
||||
|
||||
func (c MysqlConnector) connect(config *DBConfig) *Connection {
|
||||
/**
|
||||
[username[:password]@][protocol[(address)]]/dbname[?param1=value1&...¶mN=valueN]
|
||||
// user@unix(/path/to/socket)/dbname
|
||||
// root:pw@unix(/tmp/mysql.sock)/myDatabase?loc=Local
|
||||
// user:password@tcp(localhost:5555)/dbname?tls=skip-verify&autocommit=true
|
||||
// user:password@/dbname?sql_mode=TRADITIONAL
|
||||
// user:password@tcp([de:ad:be:ef::ca:fe]:80)/dbname?timeout=90s&collation=utf8mb4_unicode_ci
|
||||
// id:password@tcp(your-amazonaws-uri.com:3306)/dbname
|
||||
// user@cloudsql(project-id:instance-name)/dbname
|
||||
// user@cloudsql(project-id:regionname:instance-name)/dbname
|
||||
// user:password@tcp/dbname?charset=utf8mb4,utf8&sys_var=esc%40ped
|
||||
// user:password@/dbname
|
||||
// user:password@/
|
||||
*/
|
||||
//TODO: Protocol loc readTimeout serverPubKey timeout
|
||||
if config.MaxOpenConns == 0 {
|
||||
config.MaxOpenConns = 15
|
||||
}
|
||||
if config.MaxIdleConns == 0 {
|
||||
config.MaxIdleConns = 5
|
||||
}
|
||||
if config.ConnMaxLifetime == 0 {
|
||||
config.ConnMaxLifetime = 86400
|
||||
}
|
||||
if config.ConnMaxIdleTime == 0 {
|
||||
config.ConnMaxIdleTime = 7200
|
||||
}
|
||||
var params []string
|
||||
|
||||
if len(config.Charset) > 0 {
|
||||
params = append(params, "charset="+config.Charset)
|
||||
}
|
||||
if len(config.Collation) > 0 {
|
||||
params = append(params, "collation="+config.Collation)
|
||||
}
|
||||
if config.MultiStatements {
|
||||
params = append(params, "multiStatements=true")
|
||||
}
|
||||
if config.ParseTime {
|
||||
params = append(params, "parseTime=true")
|
||||
}
|
||||
var dsn string
|
||||
if len(config.Dsn) > 0 {
|
||||
dsn = config.Dsn
|
||||
} else {
|
||||
dsn = fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?%s", config.Username, config.Password, config.Host, config.Port, config.Database, strings.Join(params, "&"))
|
||||
}
|
||||
db, err := sql.Open(DriverMysql, dsn)
|
||||
|
||||
if err != nil {
|
||||
panic(err.Error())
|
||||
} else if err := db.Ping(); err != nil {
|
||||
panic(err)
|
||||
} else {
|
||||
// 最大链接数, 默认0, 不限制
|
||||
db.SetMaxOpenConns(config.MaxOpenConns)
|
||||
// 最大空闲链接数, 默认2
|
||||
db.SetMaxIdleConns(config.MaxIdleConns)
|
||||
// 表示在连接池中连接的最大生存时间, 默认0, 表示不限制。
|
||||
db.SetConnMaxLifetime(time.Duration(config.ConnMaxLifetime) * time.Second)
|
||||
// 连接池中空闲连接的最大生存时间, 默认0, 表示不限制。
|
||||
db.SetConnMaxIdleTime(time.Duration(config.ConnMaxIdleTime) * time.Second)
|
||||
}
|
||||
|
||||
// 优化事物查询速度
|
||||
// Golang 在 MYSQL 事物中不能并发执行查询, 会出现 driver:bad connection,
|
||||
// 为了在事物中执行并发查询, 单独开一个脏读查询链接
|
||||
|
||||
// 处理 driver:bad connection, 需要加锁
|
||||
// 同一事务开启多协程的同时如果有并发读, 那可能会出现 driver:bad connection 错误,
|
||||
// 原因是同一事务同一时间只能有一个可以进行读操作, 读完之后需要将查询得到的Rows关闭.
|
||||
if len(params) > 0 {
|
||||
dsn = dsn + `&`
|
||||
}
|
||||
RxDB, err := sql.Open(DriverMysql, dsn+`transaction_isolation='read-uncommitted'`)
|
||||
if err != nil {
|
||||
panic("开启事物链接失败, " + err.Error())
|
||||
} else {
|
||||
RxDB.SetMaxOpenConns(5)
|
||||
RxDB.SetMaxIdleConns(5)
|
||||
}
|
||||
|
||||
return &Connection{
|
||||
RxDB: RxDB,
|
||||
DB: db,
|
||||
Config: config,
|
||||
}
|
||||
}
|
569
mysql_grammar.go
Normal file
569
mysql_grammar.go
Normal file
@ -0,0 +1,569 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type MysqlGrammar struct {
|
||||
Prefix string
|
||||
Builder *Builder
|
||||
}
|
||||
|
||||
func (m *MysqlGrammar) SetTablePrefix(prefix string) {
|
||||
m.Prefix = prefix
|
||||
}
|
||||
|
||||
func (m *MysqlGrammar) GetTablePrefix() string {
|
||||
return m.Prefix
|
||||
}
|
||||
|
||||
func (m *MysqlGrammar) SetBuilder(builder *Builder) {
|
||||
m.Builder = builder
|
||||
}
|
||||
|
||||
func (m *MysqlGrammar) GetBuilder() *Builder {
|
||||
return m.Builder
|
||||
}
|
||||
|
||||
func (m *MysqlGrammar) CompileInsert(values []map[string]interface{}) string {
|
||||
b := m.GetBuilder()
|
||||
b.PreSql.WriteString("insert into ")
|
||||
b.PreSql.WriteString(m.CompileComponentTable())
|
||||
b.PreSql.WriteString(" (")
|
||||
first := values[0]
|
||||
length := len(values)
|
||||
var columns []string
|
||||
var columnizeVars []interface{}
|
||||
for key := range first {
|
||||
if (b.OnlyColumns == nil && b.ExceptColumns == nil) || b.FileterColumn(key) {
|
||||
columns = append(columns, key)
|
||||
columnizeVars = append(columnizeVars, key)
|
||||
}
|
||||
}
|
||||
columnLength := len(columns)
|
||||
b.PreSql.WriteString(m.columnize(columnizeVars))
|
||||
b.PreSql.WriteString(") values")
|
||||
|
||||
for k, v := range values {
|
||||
b.PreSql.WriteString(" (")
|
||||
for i, key := range columns {
|
||||
b.PreSql.WriteString(m.parameter(v[key]))
|
||||
b.AddBinding([]interface{}{v[key]}, TYPE_INSERT)
|
||||
if i != columnLength-1 {
|
||||
b.PreSql.WriteString(", ")
|
||||
} else {
|
||||
b.PreSql.WriteString(")")
|
||||
}
|
||||
}
|
||||
if k != length-1 {
|
||||
b.PreSql.WriteString(", ")
|
||||
}
|
||||
}
|
||||
b.PreparedSql = b.PreSql.String()
|
||||
b.PreSql.Reset()
|
||||
return b.PreparedSql
|
||||
}
|
||||
|
||||
func (m *MysqlGrammar) CompileInsertOrIgnore(values []map[string]interface{}) string {
|
||||
m.GetBuilder().PreparedSql = strings.Replace(m.CompileInsert(values), "insert", "insert ignore", 1)
|
||||
return m.GetBuilder().PreparedSql
|
||||
}
|
||||
func (m *MysqlGrammar) CompileDelete() string {
|
||||
b := m.GetBuilder()
|
||||
b.PreSql.WriteString("delete from ")
|
||||
b.PreSql.WriteString(m.CompileComponentTable())
|
||||
b.PreSql.WriteString(m.CompileComponentWheres())
|
||||
if len(b.Orders) > 0 {
|
||||
b.PreSql.WriteString(m.CompileComponentOrders())
|
||||
}
|
||||
if b.LimitNum > 0 {
|
||||
b.PreSql.WriteString(m.CompileComponentLimitNum())
|
||||
}
|
||||
m.GetBuilder().PreparedSql = m.GetBuilder().PreSql.String()
|
||||
b.PreSql.Reset()
|
||||
return m.GetBuilder().PreparedSql
|
||||
}
|
||||
|
||||
func (m *MysqlGrammar) CompileUpdate(value map[string]interface{}) string {
|
||||
b := m.GetBuilder()
|
||||
b.PreSql.WriteString("update ")
|
||||
b.PreSql.WriteString(m.CompileComponentTable())
|
||||
b.PreSql.WriteString(" set ")
|
||||
count := 0
|
||||
length := len(value)
|
||||
for k, v := range value {
|
||||
count++
|
||||
if (b.OnlyColumns == nil && b.ExceptColumns == nil) || b.FileterColumn(k) {
|
||||
b.PreSql.WriteString(m.Wrap(k))
|
||||
b.PreSql.WriteString(" = ")
|
||||
b.AddBinding([]interface{}{v}, TYPE_UPDATE)
|
||||
if e, ok := v.(Expression); ok {
|
||||
b.PreSql.WriteString(string(e))
|
||||
} else {
|
||||
b.PreSql.WriteString(m.parameter(v))
|
||||
}
|
||||
}
|
||||
if count != length {
|
||||
b.PreSql.WriteString(" , ")
|
||||
}
|
||||
}
|
||||
b.PreSql.WriteString(m.CompileComponentWheres())
|
||||
m.GetBuilder().PreparedSql = b.PreSql.String()
|
||||
b.PreSql.Reset()
|
||||
return m.GetBuilder().PreparedSql
|
||||
}
|
||||
|
||||
func (m *MysqlGrammar) CompileSelect() string {
|
||||
b := m.GetBuilder()
|
||||
b.PreparedSql = ""
|
||||
b.PreSql = strings.Builder{}
|
||||
if _, ok := b.Components[TYPE_COLUMN]; !ok || len(b.Columns) == 0 {
|
||||
b.Components[TYPE_COLUMN] = struct{}{}
|
||||
b.Columns = append(b.Columns, "*")
|
||||
}
|
||||
for _, componentName := range SelectComponents {
|
||||
if _, ok := b.Components[componentName]; ok {
|
||||
b.PreSql.WriteString(m.compileComponent(componentName))
|
||||
}
|
||||
}
|
||||
b.PreparedSql = b.PreSql.String()
|
||||
b.PreSql.Reset()
|
||||
return b.PreparedSql
|
||||
}
|
||||
|
||||
func (m *MysqlGrammar) CompileExists() string {
|
||||
sql := m.CompileSelect()
|
||||
m.GetBuilder().PreparedSql = fmt.Sprintf("select exists(%s) as %s", sql, m.Wrap("exists"))
|
||||
return m.GetBuilder().PreparedSql
|
||||
}
|
||||
|
||||
func (m *MysqlGrammar) compileComponent(componentName string) string {
|
||||
switch componentName {
|
||||
case TYPE_AGGREGRATE:
|
||||
return m.CompileComponentAggregate()
|
||||
case TYPE_COLUMN:
|
||||
return m.CompileComponentColumns()
|
||||
case TYPE_FROM:
|
||||
return m.CompileComponentFromTable()
|
||||
case TYPE_JOIN:
|
||||
return m.CompileComponentJoins()
|
||||
case TYPE_WHERE:
|
||||
return m.CompileComponentWheres()
|
||||
case TYPE_GROUP_BY:
|
||||
return m.CompileComponentGroups()
|
||||
case TYPE_HAVING:
|
||||
return m.CompileComponentHavings()
|
||||
case TYPE_ORDER:
|
||||
return m.CompileComponentOrders()
|
||||
case TYPE_LIMIT:
|
||||
return m.CompileComponentLimitNum()
|
||||
case TYPE_OFFSET:
|
||||
return m.CompileComponentOffsetNum()
|
||||
case "unions":
|
||||
case TYPE_LOCK:
|
||||
return m.CompileLock()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
func (m *MysqlGrammar) CompileComponentAggregate() string {
|
||||
builder := strings.Builder{}
|
||||
aggregate := m.GetBuilder().Aggregates[0]
|
||||
builder.WriteString("select ")
|
||||
builder.WriteString(aggregate.AggregateName)
|
||||
builder.WriteString("(")
|
||||
if m.GetBuilder().IsDistinct && aggregate.AggregateColumn != "*" {
|
||||
builder.WriteString("distinct ")
|
||||
builder.WriteString(m.Wrap(aggregate.AggregateColumn))
|
||||
} else {
|
||||
builder.WriteString(m.Wrap(aggregate.AggregateColumn))
|
||||
}
|
||||
builder.WriteString(") as aggregate")
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
// Convert []string column names into a delimited string.
|
||||
// Compile the "select *" portion of the query.
|
||||
func (m *MysqlGrammar) CompileComponentColumns() string {
|
||||
if len(m.GetBuilder().Aggregates) == 0 {
|
||||
builder := strings.Builder{}
|
||||
if m.GetBuilder().IsDistinct {
|
||||
builder.WriteString("select distinct ")
|
||||
} else {
|
||||
builder.WriteString("select ")
|
||||
}
|
||||
builder.WriteString(m.columnize(m.GetBuilder().Columns))
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
return ""
|
||||
|
||||
}
|
||||
|
||||
func (m *MysqlGrammar) CompileComponentFromTable() string {
|
||||
builder := strings.Builder{}
|
||||
builder.WriteString(" from ")
|
||||
builder.WriteString(m.WrapTable(m.GetBuilder().FromTable))
|
||||
return builder.String()
|
||||
}
|
||||
func (m *MysqlGrammar) CompileComponentTable() string {
|
||||
return m.WrapTable(m.GetBuilder().FromTable)
|
||||
}
|
||||
func (m *MysqlGrammar) CompileComponentJoins() string {
|
||||
builder := strings.Builder{}
|
||||
for _, join := range m.GetBuilder().Joins {
|
||||
var tableAndNestedJoins string
|
||||
if len(join.Joins) > 0 {
|
||||
//nested join
|
||||
tableAndNestedJoins = fmt.Sprintf("(%s%s)", m.WrapTable(join.JoinTable), join.Grammar.CompileComponentJoins())
|
||||
} else {
|
||||
tableAndNestedJoins = m.WrapTable(join.JoinTable)
|
||||
}
|
||||
onStr := join.Grammar.CompileComponentWheres()
|
||||
s := ""
|
||||
if len(onStr) > 0 {
|
||||
s = fmt.Sprintf(" %s join %s %s", join.JoinType, tableAndNestedJoins, strings.TrimSpace(onStr))
|
||||
} else {
|
||||
s = fmt.Sprintf(" %s join %s", join.JoinType, tableAndNestedJoins)
|
||||
}
|
||||
builder.WriteString(s)
|
||||
}
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
// db.Query("select * from `groups` where `id` in (?,?,?,?)", []interface{}{7,"8","9","10"}...)
|
||||
func (m *MysqlGrammar) CompileComponentWheres() string {
|
||||
if len(m.GetBuilder().Wheres) == 0 {
|
||||
return ""
|
||||
}
|
||||
builder := strings.Builder{}
|
||||
if m.GetBuilder().JoinBuilder {
|
||||
builder.WriteString(" on ")
|
||||
} else {
|
||||
builder.WriteString(" where ")
|
||||
}
|
||||
for i, w := range m.GetBuilder().Wheres {
|
||||
if i != 0 {
|
||||
builder.WriteString(" " + w.Boolean + " ")
|
||||
}
|
||||
if w.Type == CONDITION_TYPE_NESTED {
|
||||
builder.WriteString("(")
|
||||
cloneBuilder := w.Value.(*Builder)
|
||||
for j := 0; j < len(cloneBuilder.Wheres); j++ {
|
||||
nestedWhere := cloneBuilder.Wheres[j]
|
||||
if j != 0 {
|
||||
builder.WriteString(" " + nestedWhere.Boolean + " ")
|
||||
}
|
||||
//when compile nested where,we need bind the generated sql and params to the current builder
|
||||
g := cloneBuilder.Grammar.(*MysqlGrammar)
|
||||
//this will bind params to current builder when call m.parameter()
|
||||
g.SetBuilder(m.GetBuilder())
|
||||
nestedSql := strings.TrimSpace(g.CompileWhere(nestedWhere))
|
||||
nestedSql = strings.Replace(nestedSql, " where ", "", 1)
|
||||
builder.WriteString(nestedSql)
|
||||
}
|
||||
builder.WriteString(")")
|
||||
} else if w.Type == CONDITION_TYPE_SUB {
|
||||
builder.WriteString(m.Wrap(w.Column))
|
||||
builder.WriteString(" " + w.Operator + " ")
|
||||
builder.WriteString("(")
|
||||
cb := CloneBuilderWithTable(m.GetBuilder())
|
||||
|
||||
if clousure, ok := w.Value.(func(builder *Builder)); ok {
|
||||
clousure(cb)
|
||||
sql := cb.Grammar.CompileSelect()
|
||||
builder.WriteString(sql)
|
||||
}
|
||||
builder.WriteString(")")
|
||||
|
||||
} else {
|
||||
builder.WriteString(m.CompileWhere(w))
|
||||
}
|
||||
|
||||
}
|
||||
return builder.String()
|
||||
}
|
||||
func (m *MysqlGrammar) CompileWhere(w Where) (sql string) {
|
||||
var sqlBuilder strings.Builder
|
||||
switch w.Type {
|
||||
case CONDITION_TYPE_BASIC:
|
||||
sqlBuilder.WriteString(m.Wrap(w.Column))
|
||||
sqlBuilder.WriteString(" " + w.Operator + " ")
|
||||
sqlBuilder.WriteString(m.parameter(w.Value))
|
||||
case CONDITION_TYPE_BETWEEN:
|
||||
sqlBuilder.WriteString(m.Wrap(w.Column))
|
||||
if w.Not {
|
||||
sqlBuilder.WriteString(" not between ")
|
||||
} else {
|
||||
sqlBuilder.WriteString(" between ")
|
||||
}
|
||||
sqlBuilder.WriteString(m.parameter(w.Values[0]))
|
||||
sqlBuilder.WriteString(" and ")
|
||||
sqlBuilder.WriteString(m.parameter(w.Values[1]))
|
||||
case CONDITION_TYPE_BETWEEN_COLUMN:
|
||||
sqlBuilder.WriteString(m.Wrap(w.Column))
|
||||
if w.Not {
|
||||
sqlBuilder.WriteString(" not between ")
|
||||
} else {
|
||||
sqlBuilder.WriteString(" between ")
|
||||
}
|
||||
sqlBuilder.WriteString(m.Wrap(w.Values[0]))
|
||||
sqlBuilder.WriteString(" and ")
|
||||
sqlBuilder.WriteString(m.Wrap(w.Values[1]))
|
||||
case CONDITION_TYPE_IN:
|
||||
if len(w.Values) == 0 {
|
||||
if w.Not {
|
||||
sqlBuilder.WriteString("1 = 1")
|
||||
} else {
|
||||
sqlBuilder.WriteString("0 = 1")
|
||||
}
|
||||
} else {
|
||||
sqlBuilder.WriteString(m.Wrap(w.Column))
|
||||
if w.Not {
|
||||
sqlBuilder.WriteString(" not in (")
|
||||
} else {
|
||||
sqlBuilder.WriteString(" in (")
|
||||
}
|
||||
sqlBuilder.WriteString(m.parameter(w.Values...))
|
||||
sqlBuilder.WriteString(")")
|
||||
}
|
||||
|
||||
case CONDITION_TYPE_DATE, CONDITION_TYPE_TIME, CONDITION_TYPE_DAY, CONDITION_TYPE_MONTH, CONDITION_TYPE_YEAR:
|
||||
sqlBuilder.WriteString(w.Type)
|
||||
sqlBuilder.WriteString("(")
|
||||
sqlBuilder.WriteString(m.Wrap(w.Column))
|
||||
sqlBuilder.WriteString(") ")
|
||||
sqlBuilder.WriteString(w.Operator)
|
||||
sqlBuilder.WriteString(" ")
|
||||
sqlBuilder.WriteString(m.parameter(w.Value))
|
||||
case CONDITION_TYPE_NULL:
|
||||
sqlBuilder.WriteString(m.Wrap(w.Column))
|
||||
sqlBuilder.WriteString(" is ")
|
||||
if w.Not {
|
||||
sqlBuilder.WriteString("not ")
|
||||
}
|
||||
sqlBuilder.WriteString("null")
|
||||
case CONDITION_TYPE_COLUMN:
|
||||
sqlBuilder.WriteString(m.Wrap(w.FirstColumn))
|
||||
sqlBuilder.WriteString(" ")
|
||||
sqlBuilder.WriteString(w.Operator)
|
||||
sqlBuilder.WriteString(" ")
|
||||
sqlBuilder.WriteString(m.Wrap(w.SecondColumn))
|
||||
case CONDITION_TYPE_RAW:
|
||||
sqlBuilder.WriteString(string(w.RawSql.(Expression)))
|
||||
case CONDITION_TYPE_NESTED:
|
||||
sqlBuilder.WriteString("(")
|
||||
sqlBuilder.WriteString(w.Value.(*Builder).Grammar.CompileComponentWheres())
|
||||
sqlBuilder.WriteString(") ")
|
||||
case CONDITION_TYPE_EXIST:
|
||||
if w.Not {
|
||||
sqlBuilder.WriteString("not ")
|
||||
}
|
||||
sqlBuilder.WriteString("exists ")
|
||||
sqlBuilder.WriteString(fmt.Sprintf("(%s)", w.Query.ToSql()))
|
||||
case CONDITION_TYPE_ROW_VALUES:
|
||||
var columns []interface{}
|
||||
for _, column := range w.Columns {
|
||||
columns = append(columns, column)
|
||||
}
|
||||
sqlBuilder.WriteString(fmt.Sprintf("(%s) %s (%s)", m.columnize(columns), w.Operator, m.parameter(w.Values...)))
|
||||
|
||||
default:
|
||||
panic("where type not Found")
|
||||
}
|
||||
return sqlBuilder.String()
|
||||
|
||||
}
|
||||
func (m *MysqlGrammar) parameter(values ...interface{}) string {
|
||||
var ps []string
|
||||
for _, value := range values {
|
||||
if expre, ok := value.(Expression); ok {
|
||||
ps = append(ps, string(expre))
|
||||
} else {
|
||||
ps = append(ps, "?")
|
||||
}
|
||||
}
|
||||
return strings.Join(ps, ",")
|
||||
}
|
||||
func (m *MysqlGrammar) CompileComponentGroups() string {
|
||||
builder := strings.Builder{}
|
||||
builder.WriteString(" group by ")
|
||||
builder.WriteString(m.columnize(m.GetBuilder().Groups))
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
func (m *MysqlGrammar) CompileComponentHavings() string {
|
||||
builder := strings.Builder{}
|
||||
builder.WriteString(" having ")
|
||||
for i, having := range m.GetBuilder().Havings {
|
||||
if i != 0 {
|
||||
builder.WriteString(" " + having.HavingBoolean + " ")
|
||||
}
|
||||
if having.HavingType == CONDITION_TYPE_BASIC {
|
||||
builder.WriteString(m.Wrap(having.HavingColumn))
|
||||
builder.WriteString(" ")
|
||||
builder.WriteString(having.HavingOperator)
|
||||
builder.WriteString(" ")
|
||||
builder.WriteString(m.parameter(having.HavingValue))
|
||||
} else if having.HavingType == CONDITION_TYPE_RAW {
|
||||
builder.WriteString(string(having.RawSql.(Expression)))
|
||||
} else if having.HavingType == CONDITION_TYPE_BETWEEN {
|
||||
vs := having.HavingValue.([]interface{})
|
||||
builder.WriteString(m.Wrap(having.HavingColumn))
|
||||
builder.WriteString(" ")
|
||||
builder.WriteString(CONDITION_TYPE_BETWEEN)
|
||||
builder.WriteString(" ")
|
||||
builder.WriteString(m.parameter(vs[0]))
|
||||
builder.WriteString(" and ")
|
||||
builder.WriteString(m.parameter(vs[1]))
|
||||
}
|
||||
}
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
func (m *MysqlGrammar) CompileComponentOrders() string {
|
||||
builder := strings.Builder{}
|
||||
builder.WriteString(" order by ")
|
||||
for i, order := range m.GetBuilder().Orders {
|
||||
if i != 0 {
|
||||
builder.WriteString(", ")
|
||||
}
|
||||
if order.OrderType == CONDITION_TYPE_RAW {
|
||||
builder.WriteString(string(order.RawSql.(Expression)))
|
||||
continue
|
||||
}
|
||||
builder.WriteString(m.Wrap(order.Column))
|
||||
builder.WriteString(" ")
|
||||
builder.WriteString(order.Direction)
|
||||
}
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
func (m *MysqlGrammar) CompileComponentLimitNum() string {
|
||||
if m.GetBuilder().LimitNum >= 0 {
|
||||
builder := strings.Builder{}
|
||||
builder.WriteString(" limit ")
|
||||
builder.WriteString(fmt.Sprintf("%v", m.GetBuilder().LimitNum))
|
||||
return builder.String()
|
||||
}
|
||||
return ""
|
||||
|
||||
}
|
||||
|
||||
func (m *MysqlGrammar) CompileComponentOffsetNum() string {
|
||||
if m.GetBuilder().OffsetNum >= 0 {
|
||||
builder := strings.Builder{}
|
||||
builder.WriteString(" offset ")
|
||||
builder.WriteString(fmt.Sprintf("%v", m.GetBuilder().OffsetNum))
|
||||
return builder.String()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
func (m *MysqlGrammar) CompileLock() string {
|
||||
switch m.GetBuilder().LockMode.(type) {
|
||||
case string:
|
||||
return " " + m.GetBuilder().LockMode.(string)
|
||||
case bool:
|
||||
boolean := m.GetBuilder().LockMode.(bool)
|
||||
if boolean {
|
||||
return " for update"
|
||||
} else {
|
||||
return " lock in share mode"
|
||||
}
|
||||
case nil:
|
||||
return " for update"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
func (m *MysqlGrammar) columnize(columns []interface{}) string {
|
||||
builder := strings.Builder{}
|
||||
var t []string
|
||||
for _, value := range columns {
|
||||
if s, ok := value.(string); ok {
|
||||
t = append(t, m.Wrap(s))
|
||||
} else if e, ok := value.(Expression); ok {
|
||||
t = append(t, string(e))
|
||||
}
|
||||
}
|
||||
builder.WriteString(strings.Join(t, ", "))
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
/*
|
||||
Wrap a value in keyword identifiers.
|
||||
*/
|
||||
func (m *MysqlGrammar) Wrap(value interface{}, prefixAlias ...bool) string {
|
||||
prefix := false
|
||||
if expr, ok := value.(Expression); ok {
|
||||
return string(expr)
|
||||
}
|
||||
str := value.(string)
|
||||
if strings.Contains(str, " as ") {
|
||||
if len(prefixAlias) > 0 && prefixAlias[0] {
|
||||
prefix = true
|
||||
}
|
||||
return m.WrapAliasedValue(str, prefix)
|
||||
}
|
||||
return m.WrapSegments(strings.Split(str, "."))
|
||||
}
|
||||
|
||||
func (m *MysqlGrammar) WrapAliasedValue(value string, prefixAlias ...bool) string {
|
||||
var result strings.Builder
|
||||
separator := " as "
|
||||
segments := strings.SplitN(value, separator, 2)
|
||||
if len(prefixAlias) > 0 && prefixAlias[0] {
|
||||
segments[1] = m.GetTablePrefix() + segments[1]
|
||||
}
|
||||
result.WriteString(m.Wrap(segments[0]))
|
||||
result.WriteString(" as ")
|
||||
result.WriteString(m.WrapValue(segments[1]))
|
||||
return result.String()
|
||||
}
|
||||
|
||||
/*
|
||||
WrapSegments Wrap the given value segments.
|
||||
|
||||
user.name => "prefix_user"."name"
|
||||
*/
|
||||
func (m *MysqlGrammar) WrapSegments(values []string) string {
|
||||
var segments []string
|
||||
paramLength := len(values)
|
||||
for i, value := range values {
|
||||
if paramLength > 1 && i == 0 {
|
||||
segments = append(segments, m.WrapTable(value))
|
||||
} else {
|
||||
segments = append(segments, m.WrapValue(value))
|
||||
}
|
||||
}
|
||||
return strings.Join(segments, ".")
|
||||
}
|
||||
|
||||
/*
|
||||
WrapTable wrap a table in keyword identifiers.
|
||||
*/
|
||||
func (m *MysqlGrammar) WrapTable(tableName interface{}) string {
|
||||
if str, ok := tableName.(string); ok {
|
||||
return m.Wrap(m.GetTablePrefix()+str, true)
|
||||
} else if expr, ok := tableName.(Expression); ok {
|
||||
return string(expr)
|
||||
} else {
|
||||
panic(errors.New("tablename type mismatch"))
|
||||
}
|
||||
}
|
||||
|
||||
// table => `table`
|
||||
// t1"t2 => `t1``t2`
|
||||
func (m *MysqlGrammar) WrapValue(value string) string {
|
||||
if value != "*" {
|
||||
return fmt.Sprintf("`%s`", strings.ReplaceAll(value, "`", "``"))
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
/*
|
||||
CompileRandom Compile the random statement into SQL.
|
||||
*/
|
||||
func (m *MysqlGrammar) CompileRandom(seed string) string {
|
||||
return fmt.Sprintf("RAND(%s)", seed)
|
||||
}
|
7
raw.go
Normal file
7
raw.go
Normal file
@ -0,0 +1,7 @@
|
||||
package db
|
||||
|
||||
type Expression string
|
||||
|
||||
func Raw(expr string) Expression {
|
||||
return Expression(expr)
|
||||
}
|
279
scan.go
Normal file
279
scan.go
Normal file
@ -0,0 +1,279 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
type RowScan struct {
|
||||
Count int
|
||||
}
|
||||
|
||||
func (RowScan) LastInsertId() (int64, error) {
|
||||
return 0, errors.New("no insert in select")
|
||||
}
|
||||
|
||||
func (v RowScan) RowsAffected() (int64, error) {
|
||||
return int64(v.Count), nil
|
||||
}
|
||||
|
||||
// 入口
|
||||
func ScanResult(rows *sql.Rows, dest any) (result RowScan) {
|
||||
realDest := reflect.Indirect(reflect.ValueOf(dest))
|
||||
|
||||
if reflect.TypeOf(dest).Kind() != reflect.Ptr {
|
||||
panic("dest 请传入指针类型")
|
||||
}
|
||||
|
||||
if realDest.Kind() == reflect.Slice {
|
||||
slice := realDest.Type()
|
||||
sliceItem := slice.Elem()
|
||||
if sliceItem.Kind() == reflect.Map {
|
||||
return sMapSlice(rows, dest)
|
||||
} else if sliceItem.Kind() == reflect.Struct {
|
||||
return sStructSlice(rows, dest)
|
||||
} else if sliceItem.Kind() == reflect.Ptr {
|
||||
if sliceItem.Elem().Kind() == reflect.Struct {
|
||||
return sStructSlice(rows, dest)
|
||||
}
|
||||
}
|
||||
return sValues(rows, dest)
|
||||
} else if realDest.Kind() == reflect.Struct {
|
||||
return sStruct(rows, dest)
|
||||
} else if realDest.Kind() == reflect.Map {
|
||||
return sMap(rows, dest)
|
||||
} else {
|
||||
for rows.Next() {
|
||||
result.Count++
|
||||
err := rows.Scan(dest)
|
||||
if err != nil {
|
||||
panic(err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// to struct slice
|
||||
func sStructSlice(rows *sql.Rows, dest any) (result RowScan) {
|
||||
realDest := reflect.Indirect(reflect.ValueOf(dest))
|
||||
slice := realDest.Type()
|
||||
sliceItem := slice.Elem()
|
||||
itemIsPtr := sliceItem.Kind() == reflect.Ptr
|
||||
keys := map[string]int{}
|
||||
columns, _ := rows.Columns()
|
||||
|
||||
lenField := 0
|
||||
if itemIsPtr {
|
||||
lenField = sliceItem.Elem().NumField()
|
||||
} else {
|
||||
lenField = sliceItem.NumField()
|
||||
}
|
||||
|
||||
// 缓存 dest 结构
|
||||
// key is field name, value is field index
|
||||
for i := 0; i < lenField; i++ {
|
||||
var tField reflect.StructField
|
||||
if itemIsPtr {
|
||||
tField = sliceItem.Elem().Field(i)
|
||||
} else {
|
||||
tField = sliceItem.Field(i)
|
||||
}
|
||||
kFiled := tField.Tag.Get("db")
|
||||
if kFiled == "" {
|
||||
kFiled = tField.Name
|
||||
}
|
||||
keys[kFiled] = i
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
result.Count++
|
||||
var v, vp reflect.Value
|
||||
if itemIsPtr {
|
||||
vp = reflect.New(sliceItem.Elem())
|
||||
v = reflect.Indirect(vp)
|
||||
} else {
|
||||
vp = reflect.New(sliceItem)
|
||||
v = reflect.Indirect(vp)
|
||||
}
|
||||
|
||||
scanArgs := make([]any, len(columns))
|
||||
|
||||
// 初始化 map 字段
|
||||
for i, k := range columns {
|
||||
if f, ok := keys[k]; ok {
|
||||
scanArgs[i] = v.Field(f).Addr().Interface()
|
||||
} else {
|
||||
scanArgs[i] = new(any)
|
||||
}
|
||||
}
|
||||
|
||||
if err := rows.Scan(scanArgs...); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if itemIsPtr {
|
||||
realDest.Set(reflect.Append(realDest, vp))
|
||||
} else {
|
||||
realDest.Set(reflect.Append(realDest, v))
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// to struct
|
||||
func sStruct(rows *sql.Rows, dest any) (result RowScan) {
|
||||
columns, _ := rows.Columns()
|
||||
realDest := reflect.Indirect(reflect.ValueOf(dest))
|
||||
vp := reflect.New(realDest.Type())
|
||||
v := reflect.Indirect(vp)
|
||||
|
||||
keys := map[string]int{}
|
||||
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
tField := v.Type().Field(i)
|
||||
kFiled := tField.Tag.Get("db")
|
||||
if kFiled == "" {
|
||||
kFiled = tField.Name
|
||||
}
|
||||
keys[kFiled] = i
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
if result.Count == 0 {
|
||||
scanArgs := make([]any, len(columns))
|
||||
|
||||
// 初始化 map 字段
|
||||
for i, k := range columns {
|
||||
if f, ok := keys[k]; ok {
|
||||
scanArgs[i] = v.Field(f).Addr().Interface()
|
||||
} else {
|
||||
scanArgs[i] = new(any)
|
||||
}
|
||||
}
|
||||
|
||||
if err := rows.Scan(scanArgs...); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if realDest.Kind() == reflect.Ptr {
|
||||
realDest.Set(vp)
|
||||
} else {
|
||||
realDest.Set(v)
|
||||
}
|
||||
}
|
||||
|
||||
result.Count++
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// to map slice
|
||||
func sMapSlice(rows *sql.Rows, dest any) (result RowScan) {
|
||||
columns, _ := rows.Columns()
|
||||
realDest := reflect.Indirect(reflect.ValueOf(dest))
|
||||
|
||||
for rows.Next() {
|
||||
scanArgs := make([]any, len(columns))
|
||||
element := make(map[string]any)
|
||||
|
||||
result.Count++
|
||||
|
||||
// 初始化 map 字段
|
||||
for i := 0; i < len(columns); i++ {
|
||||
scanArgs[i] = new(any)
|
||||
}
|
||||
|
||||
// 填充 map 字段
|
||||
if err := rows.Scan(scanArgs...); err != nil {
|
||||
panic(err.Error())
|
||||
}
|
||||
|
||||
// 将 scan 后的 map 填充到结果集
|
||||
for i, column := range columns {
|
||||
switch v := (*scanArgs[i].(*any)).(type) {
|
||||
case []uint8:
|
||||
data := string(v)
|
||||
if len(data) > 2 && (data[0:1] == "{" || data[0:1] == "[") {
|
||||
jData := new(any)
|
||||
if err := json.Unmarshal(v, jData); err == nil {
|
||||
element[column] = *jData
|
||||
}
|
||||
} else {
|
||||
element[column] = data
|
||||
}
|
||||
default:
|
||||
// int64
|
||||
element[column] = v
|
||||
}
|
||||
}
|
||||
|
||||
realDest.Set(reflect.Append(realDest, reflect.ValueOf(element)))
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// to map
|
||||
func sMap(rows *sql.Rows, dest any) (result RowScan) {
|
||||
columns, _ := rows.Columns()
|
||||
realDest := reflect.Indirect(reflect.ValueOf(dest))
|
||||
|
||||
for rows.Next() {
|
||||
if result.Count == 0 {
|
||||
scanArgs := make([]interface{}, len(columns))
|
||||
// 初始化 map 字段
|
||||
for i := 0; i < len(columns); i++ {
|
||||
scanArgs[i] = new(any)
|
||||
}
|
||||
|
||||
if err := rows.Scan(scanArgs...); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
for i, column := range columns {
|
||||
var rv reflect.Value
|
||||
switch v := (*scanArgs[i].(*any)).(type) {
|
||||
case []uint8:
|
||||
data := string(v)
|
||||
if len(data) > 2 && (data[0:1] == "{" || data[0:1] == "[") {
|
||||
jData := new(any)
|
||||
if err := json.Unmarshal(v, jData); err == nil {
|
||||
rv = reflect.ValueOf(jData).Elem()
|
||||
}
|
||||
} else {
|
||||
rv = reflect.ValueOf(data)
|
||||
}
|
||||
default:
|
||||
// int64
|
||||
rv = reflect.ValueOf(v)
|
||||
}
|
||||
realDest.SetMapIndex(reflect.ValueOf(column), rv)
|
||||
}
|
||||
}
|
||||
|
||||
result.Count++
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// pluck, 如 []int{}
|
||||
func sValues(rows *sql.Rows, dest any) (result RowScan) {
|
||||
columns, _ := rows.Columns()
|
||||
realDest := reflect.Indirect(reflect.ValueOf(dest))
|
||||
scanArgs := make([]interface{}, len(columns))
|
||||
for rows.Next() {
|
||||
result.Count++
|
||||
scanArgs[0] = reflect.New(realDest.Type().Elem()).Interface()
|
||||
err := rows.Scan(scanArgs...)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
realDest.Set(reflect.Append(realDest, reflect.ValueOf(reflect.ValueOf(scanArgs[0]).Elem().Interface())))
|
||||
}
|
||||
return result
|
||||
}
|
73
scan_test.go
Normal file
73
scan_test.go
Normal file
@ -0,0 +1,73 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
var oDB *DB
|
||||
|
||||
func init() {
|
||||
oDB = Open(map[string]DBConfig{
|
||||
"default": {
|
||||
Driver: "mysql",
|
||||
Host: "localhost",
|
||||
Port: "3366",
|
||||
Database: "demo",
|
||||
Username: "demo",
|
||||
Password: "ded86bf25d661bb723f3898b2440dd678382e2dd",
|
||||
Charset: "utf8mb4",
|
||||
MultiStatements: true,
|
||||
// ParseTime: true,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func TestScanMapSlice(t *testing.T) {
|
||||
dest := []map[string]any{}
|
||||
|
||||
oDB.Select("select * from fsm", []any{}, &dest)
|
||||
|
||||
t.Log(dest)
|
||||
}
|
||||
|
||||
func TestScanValues(t *testing.T) {
|
||||
|
||||
dest := []int{}
|
||||
|
||||
oDB.Select("select id from fsm", []any{}, &dest)
|
||||
|
||||
t.Log(dest)
|
||||
}
|
||||
|
||||
func TestScanMap(t *testing.T) {
|
||||
|
||||
dest := map[string]any{}
|
||||
|
||||
oDB.Select("select * from fsm where id = 3", []any{}, &dest)
|
||||
|
||||
t.Log(dest)
|
||||
}
|
||||
|
||||
func TestScanStruct(t *testing.T) {
|
||||
|
||||
dest := struct {
|
||||
Id int `db:"id"`
|
||||
Code string `db:"code"`
|
||||
}{}
|
||||
|
||||
oDB.Select("select * from fsm", []any{}, &dest)
|
||||
|
||||
t.Log(dest)
|
||||
}
|
||||
|
||||
func TestScanStructSlice(t *testing.T) {
|
||||
|
||||
dest := []struct {
|
||||
Id int `db:"id"`
|
||||
Code string `db:"code"`
|
||||
}{}
|
||||
|
||||
oDB.Select("select * from fsm", []any{}, &dest)
|
||||
|
||||
t.Log(dest)
|
||||
}
|
456
schema/blueprint.go
Normal file
456
schema/blueprint.go
Normal file
@ -0,0 +1,456 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/samber/lo"
|
||||
|
||||
"git.fsdpf.net/go/db"
|
||||
)
|
||||
|
||||
type Blueprint struct {
|
||||
table string // the table the blueprint describes.
|
||||
columns []*ColumnDefinition // columns that should be added to the table
|
||||
commands []*Command //
|
||||
temporary bool // Whether to make the table temporary.
|
||||
charset string // The default character set that should be used for the table.
|
||||
collation string // The collation that should be used for the table.
|
||||
engine string // The engine that should be used for the table.
|
||||
}
|
||||
|
||||
type Command struct {
|
||||
Type string
|
||||
CommandOptions
|
||||
}
|
||||
|
||||
type CommandOptions struct {
|
||||
Index string
|
||||
Columns []string
|
||||
Algorithm string
|
||||
To string
|
||||
From string
|
||||
}
|
||||
|
||||
func NewBlueprint(table string) *Blueprint {
|
||||
return &Blueprint{table: table, charset: "utf8mb4", collation: "utf8mb4_general_ci"}
|
||||
}
|
||||
|
||||
// 字符串
|
||||
func (this *Blueprint) Char(column string, length int) *ColumnDefinition {
|
||||
if length == 0 {
|
||||
length = 255
|
||||
}
|
||||
return this.addColumn("char", column, &ColumnOptions{Length: length})
|
||||
}
|
||||
|
||||
// 可变长度字符串
|
||||
func (this *Blueprint) String(column string, length int) *ColumnDefinition {
|
||||
if length == 0 {
|
||||
length = 255
|
||||
}
|
||||
return this.addColumn("string", column, &ColumnOptions{Length: length})
|
||||
}
|
||||
|
||||
// 文本
|
||||
func (this *Blueprint) Text(column string) *ColumnDefinition {
|
||||
return this.addColumn("text", column, nil)
|
||||
}
|
||||
|
||||
// 整型
|
||||
func (this *Blueprint) Integer(column string, params ...bool) *ColumnDefinition {
|
||||
autoIncrement := false
|
||||
unsigned := false
|
||||
if len(params) > 0 {
|
||||
autoIncrement = params[0]
|
||||
}
|
||||
if len(params) > 1 {
|
||||
unsigned = params[1]
|
||||
}
|
||||
return this.addColumn("integer", column, &ColumnOptions{autoIncrement: autoIncrement, unsigned: unsigned})
|
||||
}
|
||||
|
||||
// 迷你整型 1 byte
|
||||
func (this *Blueprint) TinyInteger(column string, params ...bool) *ColumnDefinition {
|
||||
autoIncrement := false
|
||||
unsigned := false
|
||||
if len(params) > 0 {
|
||||
autoIncrement = params[0]
|
||||
}
|
||||
if len(params) > 1 {
|
||||
unsigned = params[1]
|
||||
}
|
||||
return this.addColumn("tinyInteger", column, &ColumnOptions{autoIncrement: autoIncrement, unsigned: unsigned})
|
||||
}
|
||||
|
||||
// 小整型 2 byte
|
||||
func (this *Blueprint) SmallInteger(column string, params ...bool) *ColumnDefinition {
|
||||
autoIncrement := false
|
||||
unsigned := false
|
||||
if len(params) > 0 {
|
||||
autoIncrement = params[0]
|
||||
}
|
||||
if len(params) > 1 {
|
||||
unsigned = params[1]
|
||||
}
|
||||
return this.addColumn("smallInteger", column, &ColumnOptions{autoIncrement: autoIncrement, unsigned: unsigned})
|
||||
}
|
||||
|
||||
// 大整型 2 byte
|
||||
func (this *Blueprint) BigInteger(column string, params ...bool) *ColumnDefinition {
|
||||
autoIncrement := false
|
||||
unsigned := false
|
||||
if len(params) > 0 {
|
||||
autoIncrement = params[0]
|
||||
}
|
||||
if len(params) > 1 {
|
||||
unsigned = params[1]
|
||||
}
|
||||
return this.addColumn("bigInteger", column, &ColumnOptions{autoIncrement: autoIncrement, unsigned: unsigned})
|
||||
}
|
||||
|
||||
// 无符号整型
|
||||
func (this *Blueprint) UnsignedInteger(column string, params ...bool) *ColumnDefinition {
|
||||
autoIncrement := false
|
||||
if len(params) > 0 {
|
||||
autoIncrement = params[0]
|
||||
}
|
||||
return this.Integer(column, autoIncrement, true)
|
||||
}
|
||||
|
||||
// 无符号迷你整型 1 byte
|
||||
func (this *Blueprint) UnsignedTinyInteger(column string, params ...bool) *ColumnDefinition {
|
||||
autoIncrement := false
|
||||
if len(params) > 0 {
|
||||
autoIncrement = params[0]
|
||||
}
|
||||
return this.TinyInteger(column, autoIncrement, true)
|
||||
}
|
||||
|
||||
// 无符号小整型 2 byte
|
||||
func (this *Blueprint) UnsignedSmallInteger(column string, params ...bool) *ColumnDefinition {
|
||||
autoIncrement := false
|
||||
if len(params) > 0 {
|
||||
autoIncrement = params[0]
|
||||
}
|
||||
return this.SmallInteger(column, autoIncrement, true)
|
||||
}
|
||||
|
||||
// 无符号大整型
|
||||
func (this *Blueprint) UnsignedBigInteger(column string, params ...bool) *ColumnDefinition {
|
||||
autoIncrement := false
|
||||
if len(params) > 0 {
|
||||
autoIncrement = params[0]
|
||||
}
|
||||
return this.BigInteger(column, autoIncrement, true)
|
||||
}
|
||||
|
||||
// 精确小数
|
||||
func (this *Blueprint) Decimal(column string, total, places int) *ColumnDefinition {
|
||||
if total == 0 {
|
||||
total = 8
|
||||
}
|
||||
if places == 0 {
|
||||
places = 2
|
||||
}
|
||||
return this.addColumn("decimal", column, &ColumnOptions{Total: total, Places: places})
|
||||
}
|
||||
|
||||
// 无符号精确销售
|
||||
func (this *Blueprint) UnsignedDecimal(column string, total, places int) *ColumnDefinition {
|
||||
if total == 0 {
|
||||
total = 8
|
||||
}
|
||||
if places == 0 {
|
||||
places = 2
|
||||
}
|
||||
return this.addColumn("decimal", column, &ColumnOptions{Total: total, Places: places, unsigned: true})
|
||||
}
|
||||
|
||||
// 布尔值
|
||||
func (this *Blueprint) Boolean(column string) *ColumnDefinition {
|
||||
return this.addColumn("boolean", column, nil)
|
||||
}
|
||||
|
||||
// 枚举类型
|
||||
func (this *Blueprint) Enum(column string, allowed []string) *ColumnDefinition {
|
||||
return this.addColumn("enum", column, &ColumnOptions{Allowed: allowed})
|
||||
}
|
||||
|
||||
// JSON
|
||||
func (this *Blueprint) Json(column string) *ColumnDefinition {
|
||||
return this.addColumn("json", column, nil)
|
||||
}
|
||||
|
||||
// 日期类型
|
||||
func (this *Blueprint) Date(column string) *ColumnDefinition {
|
||||
return this.addColumn("date", column, nil)
|
||||
}
|
||||
|
||||
// 日期时间类型
|
||||
func (this *Blueprint) DateTime(column string, precision ...int) *ColumnDefinition {
|
||||
if len(precision) > 0 {
|
||||
return this.addColumn("datetime", column, &ColumnOptions{Precision: precision[0]})
|
||||
}
|
||||
return this.addColumn("datetime", column, nil)
|
||||
}
|
||||
|
||||
// 时间类型
|
||||
func (this *Blueprint) Time(column string, precision ...int) *ColumnDefinition {
|
||||
if len(precision) > 0 {
|
||||
return this.addColumn("time", column, &ColumnOptions{Precision: precision[0]})
|
||||
}
|
||||
return this.addColumn("time", column, nil)
|
||||
}
|
||||
|
||||
// 时间戳
|
||||
func (this *Blueprint) Timestamp(column string, precision ...int) *ColumnDefinition {
|
||||
if len(precision) > 0 {
|
||||
return this.addColumn("timestamp", column, &ColumnOptions{Precision: precision[0]})
|
||||
}
|
||||
return this.addColumn("timestamp", column, nil)
|
||||
}
|
||||
|
||||
// 年
|
||||
func (this *Blueprint) Year(column string) *ColumnDefinition {
|
||||
return this.addColumn("year", column, nil)
|
||||
}
|
||||
|
||||
// 二进制数据
|
||||
func (this *Blueprint) Binary(column string) *ColumnDefinition {
|
||||
return this.addColumn("binary", column, nil)
|
||||
}
|
||||
|
||||
// UUID
|
||||
func (this *Blueprint) Uuid(column string) *ColumnDefinition {
|
||||
return this.addColumn("uuid", column, nil)
|
||||
}
|
||||
|
||||
// 自增字段
|
||||
func (this *Blueprint) Increments(column string) *ColumnDefinition {
|
||||
return this.UnsignedInteger(column, true)
|
||||
}
|
||||
|
||||
// 自增Big字段
|
||||
func (this *Blueprint) BigIncrements(column string) *ColumnDefinition {
|
||||
return this.UnsignedBigInteger(column, true)
|
||||
}
|
||||
|
||||
// 添加主键
|
||||
func (this *Blueprint) Primary(columns ...string) *Command {
|
||||
return this.addCommand("primary", CommandOptions{Index: this.generateIndexName("pk", columns), Columns: columns})
|
||||
}
|
||||
|
||||
// 唯一键
|
||||
func (this *Blueprint) Unique(columns ...string) *Command {
|
||||
return this.addCommand("unique", CommandOptions{Index: this.generateIndexName("unique", columns), Columns: columns})
|
||||
}
|
||||
|
||||
// 普通索引
|
||||
func (this *Blueprint) Index(columns ...string) *Command {
|
||||
return this.addCommand("index", CommandOptions{Index: this.generateIndexName("index", columns), Columns: columns})
|
||||
}
|
||||
|
||||
// 空间索引
|
||||
func (this *Blueprint) SpatialIndex(columns ...string) *Command {
|
||||
return this.addCommand("spatialIndex", CommandOptions{Index: this.generateIndexName("spatial_index", columns), Columns: columns})
|
||||
}
|
||||
|
||||
// 删除列
|
||||
func (this *Blueprint) DropColumn(columns ...string) *Command {
|
||||
return this.addCommand("dropColumn", CommandOptions{Columns: columns})
|
||||
}
|
||||
|
||||
// 创建表
|
||||
func (this *Blueprint) Create() *Command {
|
||||
return this.addCommand("create", CommandOptions{})
|
||||
}
|
||||
|
||||
// 设置临时表标记
|
||||
func (this *Blueprint) Temporary() {
|
||||
this.temporary = true
|
||||
}
|
||||
|
||||
// 设置表字符集
|
||||
func (this *Blueprint) Charset(charset string) {
|
||||
this.charset = charset
|
||||
}
|
||||
|
||||
// 修改表名
|
||||
func (this *Blueprint) Rename(to string) *Command {
|
||||
return this.addCommand("rename", CommandOptions{To: to})
|
||||
}
|
||||
|
||||
// 删除表
|
||||
func (this *Blueprint) Drop() *Command {
|
||||
return this.addCommand("drop", CommandOptions{})
|
||||
}
|
||||
|
||||
// 删除表, 先判断再删除
|
||||
func (this *Blueprint) DropIfExists() *Command {
|
||||
return this.addCommand("dropIfExists", CommandOptions{})
|
||||
}
|
||||
|
||||
// 执行SQL查询
|
||||
func (this *Blueprint) Build(conn *db.Connection, grammar IGrammar) error {
|
||||
_, err := conn.Transaction(func(tx *db.Transaction) (result any, err error) {
|
||||
for _, sql := range this.ToSql(conn, grammar) {
|
||||
if _, err := tx.Statement(sql, []any{}); err != nil {
|
||||
tx.Rollback()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
tx.Commit()
|
||||
return
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (this *Blueprint) ToSql(conn *db.Connection, grammar IGrammar) (statements []string) {
|
||||
this.addImpliedCommands(grammar)
|
||||
|
||||
for _, cmd := range this.commands {
|
||||
switch cmd.Type {
|
||||
case "create":
|
||||
statements = append(statements, grammar.CompileCreate(this, cmd, conn))
|
||||
case "add":
|
||||
statements = append(statements, grammar.CompileAdd(this, cmd, conn)...)
|
||||
case "change":
|
||||
statements = append(statements, grammar.CompileChange(this, cmd, conn)...)
|
||||
case "primary":
|
||||
// statements = append(statements, grammar.CompilePrimary(this, cmd, conn))
|
||||
case "unique":
|
||||
statements = append(statements, grammar.CompileUnique(this, cmd, conn))
|
||||
case "index":
|
||||
statements = append(statements, grammar.CompileIndex(this, cmd, conn))
|
||||
case "spatialIndex":
|
||||
statements = append(statements, grammar.CompileSpatialIndex(this, cmd, conn))
|
||||
case "drop":
|
||||
statements = append(statements, grammar.CompileDrop(this, cmd, conn))
|
||||
case "dropIfExists":
|
||||
statements = append(statements, grammar.CompileDropIfExists(this, cmd, conn))
|
||||
case "dropColumn":
|
||||
statements = append(statements, grammar.CompileDropColumn(this, cmd, conn)...)
|
||||
case "dropPrimary":
|
||||
// statements = append(statements, grammar.CompileDropPrimary(this, cmd, conn))
|
||||
case "dropUnique":
|
||||
statements = append(statements, grammar.CompileDropUnique(this, cmd, conn))
|
||||
case "dropIndex":
|
||||
statements = append(statements, grammar.CompileDropIndex(this, cmd, conn))
|
||||
case "dropSpatialIndex":
|
||||
statements = append(statements, grammar.CompileDropSpatialIndex(this, cmd, conn))
|
||||
case "rename":
|
||||
statements = append(statements, grammar.CompileRename(this, cmd, conn))
|
||||
case "renameIndex":
|
||||
statements = append(statements, grammar.CompileRenameIndex(this, cmd, conn))
|
||||
case "dropAllTables":
|
||||
case "dropAllViews":
|
||||
case "getAllTables":
|
||||
case "getAllViews":
|
||||
}
|
||||
}
|
||||
|
||||
return statements
|
||||
}
|
||||
|
||||
// 判断是否是创建表
|
||||
func (this *Blueprint) creating() bool {
|
||||
return lo.SomeBy(this.commands, func(item *Command) bool {
|
||||
if item.Type == "create" {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
})
|
||||
}
|
||||
|
||||
func (this *Blueprint) addColumn(typ, name string, options *ColumnOptions) (definition *ColumnDefinition) {
|
||||
definition = &ColumnDefinition{Type: typ, Name: name}
|
||||
|
||||
if options != nil {
|
||||
if options.Length > 0 {
|
||||
definition.Length = options.Length
|
||||
}
|
||||
if options.autoIncrement {
|
||||
definition.autoIncrement = true
|
||||
}
|
||||
if options.unsigned {
|
||||
definition.unsigned = true
|
||||
}
|
||||
if options.Total > 0 {
|
||||
definition.Total = options.Total
|
||||
}
|
||||
if options.Places > 0 {
|
||||
definition.Places = options.Places
|
||||
}
|
||||
if len(options.Allowed) > 0 {
|
||||
definition.Allowed = options.Allowed
|
||||
}
|
||||
if options.Precision > 0 {
|
||||
definition.Precision = options.Precision
|
||||
}
|
||||
if options.change {
|
||||
definition.change = true
|
||||
}
|
||||
}
|
||||
this.columns = append(this.columns, definition)
|
||||
return definition
|
||||
}
|
||||
|
||||
func (this *Blueprint) addImpliedCommands(grammar IGrammar) {
|
||||
if !this.creating() {
|
||||
if len(this.getAddedColumns()) > 0 {
|
||||
this.commands = append([]*Command{this.createCommand("add", CommandOptions{})}, this.commands...)
|
||||
}
|
||||
if len(this.getChangedColumns()) > 0 {
|
||||
this.commands = append([]*Command{this.createCommand("change", CommandOptions{})}, this.commands...)
|
||||
}
|
||||
}
|
||||
|
||||
this.addFluentIndexes()
|
||||
}
|
||||
|
||||
// 添加索引字段
|
||||
func (this *Blueprint) addFluentIndexes() {
|
||||
for _, column := range this.columns {
|
||||
if column.primary {
|
||||
this.Primary(column.Name)
|
||||
continue
|
||||
} else if column.unique {
|
||||
this.Unique(column.Name)
|
||||
continue
|
||||
} else if column.index {
|
||||
this.Index(column.Name)
|
||||
continue
|
||||
} else if column.spatialIndex {
|
||||
this.SpatialIndex(column.Name)
|
||||
continue
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func (this *Blueprint) addCommand(name string, options CommandOptions) (command *Command) {
|
||||
command = this.createCommand(name, options)
|
||||
this.commands = append(this.commands, command)
|
||||
return command
|
||||
}
|
||||
|
||||
func (this *Blueprint) createCommand(name string, options CommandOptions) *Command {
|
||||
return &Command{Type: name, CommandOptions: options}
|
||||
}
|
||||
|
||||
// 生成索引名称
|
||||
func (this *Blueprint) generateIndexName(typ string, columns []string) string {
|
||||
return strings.ToLower(typ + "_" + strings.Join(columns, "_"))
|
||||
}
|
||||
|
||||
func (this *Blueprint) getAddedColumns() []*ColumnDefinition {
|
||||
return lo.Filter(this.columns, func(item *ColumnDefinition, _ int) bool {
|
||||
return !item.change
|
||||
})
|
||||
}
|
||||
|
||||
func (this *Blueprint) getChangedColumns() []*ColumnDefinition {
|
||||
return lo.Filter(this.columns, func(item *ColumnDefinition, _ int) bool {
|
||||
return item.change
|
||||
})
|
||||
}
|
216
schema/blueprint_test.go
Normal file
216
schema/blueprint_test.go
Normal file
@ -0,0 +1,216 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"git.fsdpf.net/go/db"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// func TestMysqlToSql(t *testing.T) {
|
||||
// bp := NewBlueprint("users")
|
||||
|
||||
// // bp.BigIncrements("id") // 等同于自增 UNSIGNED BIGINT(主键)列
|
||||
// // bp.BigInteger("votes").Default("0") // 等同于 BIGINT 类型列
|
||||
// // bp.Binary("data") // 等同于 BLOB 类型列
|
||||
// // bp.Boolean("confirmed") // 等同于 BOOLEAN 类型列
|
||||
// // bp.Char("name", 4) // 等同于 CHAR 类型列
|
||||
// // bp.Date("created_at") // 等同于 DATE 类型列
|
||||
// // bp.DateTime("created_at") // 等同于 DATETIME 类型列
|
||||
// // bp.Decimal("amount", 5, 2) // 等同于 DECIMAL 类型列,带精度和范围
|
||||
// // bp.Enum("level", []string{"easy", "hard"}) // 等同于 ENUM 类型列
|
||||
// bp.Increments("id").Change("pid") // 等同于自增 UNSIGNED INTEGER (主键)类型列
|
||||
// // bp.Integer("votes").AutoIncrement() // 等同于 INTEGER 类型列
|
||||
// bp.Json("options").Charset("uft8").Change() // 等同于 JSON 类型列
|
||||
// // bp.SmallInteger("votes").Comment("等同于 SMALLINT 类型列") // 等同于 SMALLINT 类型列
|
||||
// // bp.String("name", 100).Nullable() // 等同于 VARCHAR 类型列,带一个可选长度参数
|
||||
// // bp.Text("description").After("column") // 等同于 TEXT 类型列
|
||||
// // bp.Time("sunrise").First() // 等同于 TIME 类型列
|
||||
// bp.Timestamp("added_on").UseCurrent().Default(db.Raw("ON UPDATE CURRENT_TIMESTAMP")) // 等同于 TIMESTAMP 类型列
|
||||
// // bp.TinyInteger("numbers") // 等同于 TINYINT 类型列
|
||||
// // bp.UnsignedBigInteger("votes") // 等同于无符号的 BIGINT 类型列
|
||||
// // bp.UnsignedDecimal("amount", 8, 2) // 等同于 UNSIGNED DECIMAL 类型列,带有总位数和精度
|
||||
// // bp.UnsignedInteger("votes") // 等同于无符号的 INTEGER 类型列
|
||||
// // bp.UnsignedSmallInteger("votes") // 等同于无符号的 SMALLINT 类型列
|
||||
// // bp.UnsignedTinyInteger("votes") // 等同于无符号的 TINYINT 类型列
|
||||
// // bp.Uuid("id").Default("00000000-0000-0000-0000-000000000000") // 等同于 UUID 类型列
|
||||
// // bp.Year("birth_year")
|
||||
// bp.DropColumn("name", "name_a") // 等同于 YEAR 类型列
|
||||
|
||||
// t.Log(bp.ToSql(mysql.Connection, &MysqlGrammar{}))
|
||||
// }
|
||||
|
||||
// func TestSqliteToSql(t *testing.T) {
|
||||
// bp := NewBlueprint("users")
|
||||
|
||||
// bp.charset = "utf8mb4"
|
||||
// bp.collation = "utf8mb4_general_ci"
|
||||
// // bp.BigIncrements("id").Comment("ID")
|
||||
// // bp.Boolean("enabled").Default("1").Comment("是否有效")
|
||||
// bp.Char("created_user", 36).Default("22000000-0000-0000-0000-000000000001").Comment("创建者").Change()
|
||||
// bp.Char("owned_user", 36).Default("11000000-0000-0000-0000-000000000000").Comment("拥有者").Change()
|
||||
// // bp.Timestamp("created_at").UseCurrent().Comment("创建时间")
|
||||
// // bp.Timestamp("updated_at").Default(db.Raw("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP")).Comment("更新时间")
|
||||
// bp.DateTime("deleted_at").Nullable().Comment("删除时间")
|
||||
// bp.String("name_a", 50).Default("").Comment("姓名").Change("c_name")
|
||||
|
||||
// bp.DropColumn("name", "name_a")
|
||||
|
||||
// t.Log(bp.ToSql(sqlite3.Connection, &Sqlite3Grammar{}))
|
||||
// }
|
||||
|
||||
// Mysql 创建表
|
||||
func TestMysqlCreateTableToSql(t *testing.T) {
|
||||
table := NewBlueprint("users")
|
||||
table.Create()
|
||||
table.Charset("utf8mb4")
|
||||
|
||||
table.BigIncrements("id").Comment("ID")
|
||||
table.Boolean("enabled").Default("1").Comment("是否有效")
|
||||
table.Char("created_user", 36).Default("00000000-0000-0000-0000-000000000000").Comment("创建者")
|
||||
table.Char("owned_user", 36).Default("00000000-0000-0000-0000-000000000000").Comment("拥有者")
|
||||
table.Timestamp("created_at").UseCurrent().Comment("创建时间")
|
||||
table.Timestamp("updated_at").UseCurrent().Default(db.Raw("ON UPDATE CURRENT_TIMESTAMP")).Comment("更新时间")
|
||||
table.DateTime("deleted_at").Nullable().Comment("删除时间")
|
||||
|
||||
for _, sql := range table.ToSql(mysql.Connection, &MysqlGrammar{}) {
|
||||
t.Logf("-- %s -- %s \n", "mysql", sql)
|
||||
}
|
||||
}
|
||||
|
||||
// SQLite 创建表
|
||||
func TestSqliteCreateTableToSql(t *testing.T) {
|
||||
table := NewBlueprint("users")
|
||||
table.Create()
|
||||
table.Charset("utf8mb4")
|
||||
|
||||
table.BigIncrements("id").Comment("ID")
|
||||
table.Boolean("enabled").Default("1").Comment("是否有效")
|
||||
table.Char("created_user", 36).Default("00000000-0000-0000-0000-000000000000").Comment("创建者")
|
||||
table.Char("owned_user", 36).Default("00000000-0000-0000-0000-000000000000").Comment("拥有者")
|
||||
table.Timestamp("created_at").UseCurrent().Comment("创建时间")
|
||||
table.Timestamp("updated_at").UseCurrent().Default(db.Raw("ON UPDATE CURRENT_TIMESTAMP")).Comment("更新时间")
|
||||
table.DateTime("deleted_at").Nullable().Comment("删除时间")
|
||||
|
||||
for _, sql := range table.ToSql(sqlite3.Connection, &Sqlite3Grammar{}) {
|
||||
t.Logf("-- %s -- %s \n", "sqlite", sql)
|
||||
}
|
||||
}
|
||||
|
||||
//Mysql 修改表, 添加字段
|
||||
func TestMysqlAddColumnToSql(t *testing.T) {
|
||||
table := NewBlueprint(test_table)
|
||||
|
||||
table.String("name", 50).Nullable().Comment("用户名")
|
||||
table.TinyInteger("age").Nullable().Comment("年龄")
|
||||
|
||||
for _, sql := range table.ToSql(mysql.Connection, &MysqlGrammar{}) {
|
||||
t.Logf("-- %s -- %s \n", "mysql", sql)
|
||||
}
|
||||
}
|
||||
|
||||
// Sqlite 修改表, 添加字段
|
||||
func TestSqliteAddColumnToSql(t *testing.T) {
|
||||
table := NewBlueprint(test_table)
|
||||
|
||||
table.String("name", 50).Nullable().Comment("用户名")
|
||||
table.TinyInteger("age").Nullable().Comment("年龄")
|
||||
|
||||
for _, sql := range table.ToSql(sqlite3.Connection, &Sqlite3Grammar{}) {
|
||||
t.Logf("-- %s -- %s \n", "sqlite", sql)
|
||||
}
|
||||
}
|
||||
|
||||
// Mysql 修改表, 添加/编辑字段
|
||||
func TestMysqlChangeColumnToSql(t *testing.T) {
|
||||
table := NewBlueprint(test_table)
|
||||
|
||||
table.String("name", 100).Nullable().Comment("用户名").Change()
|
||||
table.SmallInteger("age").Nullable().Comment("年龄").Change()
|
||||
table.String("nickname", 100).Nullable().Comment("昵称")
|
||||
|
||||
for _, sql := range table.ToSql(mysql.Connection, &MysqlGrammar{}) {
|
||||
t.Logf("-- %s -- %s \n", "mysql", sql)
|
||||
}
|
||||
}
|
||||
|
||||
// Sqlite 修改表, 添加/编辑字段
|
||||
func TestSqliteChangeColumnToSql(t *testing.T) {
|
||||
table := NewBlueprint(test_table)
|
||||
|
||||
table.String("name", 100).Nullable().Comment("用户名").Change()
|
||||
table.SmallInteger("age").Nullable().Comment("年龄").Change()
|
||||
table.String("nickname", 100).Nullable().Comment("昵称")
|
||||
|
||||
for _, sql := range table.ToSql(sqlite3.Connection, &Sqlite3Grammar{}) {
|
||||
t.Logf("-- %s -- %s \n", "sqlite", sql)
|
||||
}
|
||||
}
|
||||
|
||||
// Mysql 修改表, 添加/编辑/删除字段
|
||||
func TestMysqlDropColumnToSql(t *testing.T) {
|
||||
table := NewBlueprint(test_table)
|
||||
|
||||
table.String("username", 100).Default("").Comment("--用户名--").Change("name")
|
||||
table.Boolean("is_vip").Default(1).Comment("是否VIP")
|
||||
table.DropColumn("age")
|
||||
|
||||
for _, sql := range table.ToSql(mysql.Connection, &MysqlGrammar{}) {
|
||||
t.Logf("-- %s -- %s \n", "mysql", sql)
|
||||
}
|
||||
}
|
||||
|
||||
// Sqlite 修改表, 添加/编辑/删除字段
|
||||
func TestSqliteDropColumnToSql(t *testing.T) {
|
||||
table := NewBlueprint(test_table)
|
||||
|
||||
table.String("username", 100).Default("").Comment("--用户名--").Change("name")
|
||||
table.Boolean("is_vip").Default(1).Comment("是否VIP")
|
||||
table.DropColumn("age")
|
||||
|
||||
for _, sql := range table.ToSql(sqlite3.Connection, &Sqlite3Grammar{}) {
|
||||
t.Logf("-- %s -- %s \n", "sqlite", sql)
|
||||
}
|
||||
}
|
||||
|
||||
// Mysql 重命名表
|
||||
func TestMysqlRenameToSql(t *testing.T) {
|
||||
table := NewBlueprint(test_table)
|
||||
|
||||
table.Rename(test_table + "_alias")
|
||||
|
||||
for _, sql := range table.ToSql(mysql.Connection, &MysqlGrammar{}) {
|
||||
t.Logf("-- %s -- %s \n", "mysql", sql)
|
||||
}
|
||||
}
|
||||
|
||||
// Sqlite 重命名表
|
||||
func TestSqliteRenameToSql(t *testing.T) {
|
||||
table := NewBlueprint(test_table)
|
||||
|
||||
table.Rename(test_table + "_alias")
|
||||
|
||||
for _, sql := range table.ToSql(sqlite3.Connection, &Sqlite3Grammar{}) {
|
||||
t.Logf("-- %s -- %s \n", "sqlite", sql)
|
||||
}
|
||||
}
|
||||
|
||||
// Mysql 删除表
|
||||
func TestMysqlDropTableToSql(t *testing.T) {
|
||||
table := NewBlueprint(test_table + "_alias")
|
||||
|
||||
table.Drop()
|
||||
|
||||
for _, sql := range table.ToSql(mysql.Connection, &MysqlGrammar{}) {
|
||||
t.Logf("-- %s -- %s \n", "mysql", sql)
|
||||
}
|
||||
}
|
||||
|
||||
// Sqlite 删除表
|
||||
func TestSqliteDropTableToSql(t *testing.T) {
|
||||
table := NewBlueprint(test_table + "_alias")
|
||||
|
||||
table.Drop()
|
||||
|
||||
for _, sql := range table.ToSql(sqlite3.Connection, &Sqlite3Grammar{}) {
|
||||
t.Logf("-- %s -- %s \n", "sqlite", sql)
|
||||
}
|
||||
}
|
128
schema/builder.go
Normal file
128
schema/builder.go
Normal file
@ -0,0 +1,128 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"github.com/samber/lo"
|
||||
|
||||
"git.fsdpf.net/go/db"
|
||||
)
|
||||
|
||||
type Builder struct {
|
||||
Connection *db.Connection // The database connection instance
|
||||
Grammar IGrammar // The schema grammar instance
|
||||
}
|
||||
|
||||
func NewSchema(conn *db.Connection) *Builder {
|
||||
var grammar IGrammar
|
||||
|
||||
switch conn.GetConfig().Driver {
|
||||
case "mysql":
|
||||
grammar = &MysqlGrammar{}
|
||||
case "sqlite3":
|
||||
grammar = &Sqlite3Grammar{}
|
||||
}
|
||||
|
||||
return &Builder{
|
||||
Connection: conn,
|
||||
Grammar: grammar,
|
||||
}
|
||||
}
|
||||
|
||||
// 判断数据表是否存在
|
||||
func (this *Builder) HasTable(table string) (bool, error) {
|
||||
query := this.Grammar.CompileTableExists()
|
||||
dbName := this.Connection.GetConfig().Database
|
||||
|
||||
count := 0
|
||||
|
||||
if _, err := this.Connection.Select(query, []any{table, dbName}, &count); err != nil || count == 0 {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// 判断表字段是个存在
|
||||
func (this *Builder) HasColumns(table string, columns ...string) (bool, error) {
|
||||
if tColumns, err := this.GetColumnListing(table); err != nil {
|
||||
return false, err
|
||||
} else {
|
||||
for _, col := range columns {
|
||||
if !lo.Contains(tColumns, col) {
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// 判断数据库列是否存在
|
||||
func (this *Builder) GetColumnListing(table string) (columns []string, err error) {
|
||||
query := this.Grammar.CompileColumnListing(table)
|
||||
dbName := this.Connection.GetConfig().Database
|
||||
|
||||
items := []struct {
|
||||
ID int `db:"id"`
|
||||
Name string `db:"name"`
|
||||
}{}
|
||||
|
||||
bindings := []any{table, dbName}
|
||||
|
||||
if this.Connection.GetConfig().Driver == "sqlite3" {
|
||||
bindings = nil
|
||||
}
|
||||
|
||||
if _, err := this.Connection.Select(query, bindings, &items); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, item := range items {
|
||||
columns = append(columns, item.Name)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// 修改表
|
||||
func (this *Builder) Table(table string, cb func(*Blueprint)) error {
|
||||
bp := this.createBlueprint(table)
|
||||
cb(bp)
|
||||
return this.Build(bp)
|
||||
}
|
||||
|
||||
// 创建表
|
||||
func (this *Builder) Create(table string, cb func(*Blueprint)) error {
|
||||
bp := this.createBlueprint(table)
|
||||
bp.Create()
|
||||
cb(bp)
|
||||
return this.Build(bp)
|
||||
}
|
||||
|
||||
// 修改表名
|
||||
func (this *Builder) Rename(from, to string) error {
|
||||
bp := this.createBlueprint(from)
|
||||
bp.Rename(to)
|
||||
return this.Build(bp)
|
||||
}
|
||||
|
||||
// 删除表
|
||||
func (this *Builder) Drop(table string) error {
|
||||
bp := this.createBlueprint(table)
|
||||
bp.Drop()
|
||||
return this.Build(bp)
|
||||
}
|
||||
|
||||
// 删除表, 先判断再删除
|
||||
func (this *Builder) DropIfExists(table string) error {
|
||||
bp := this.createBlueprint(table)
|
||||
bp.DropIfExists()
|
||||
return this.Build(bp)
|
||||
}
|
||||
|
||||
func (b *Builder) createBlueprint(table string) *Blueprint {
|
||||
return NewBlueprint(table)
|
||||
}
|
||||
|
||||
// Build execute the blueprint to build / modify the table
|
||||
func (this *Builder) Build(bp *Blueprint) error {
|
||||
return bp.Build(this.Connection, this.Grammar)
|
||||
}
|
195
schema/builder_test.go
Normal file
195
schema/builder_test.go
Normal file
@ -0,0 +1,195 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"git.fsdpf.net/go/db"
|
||||
"testing"
|
||||
)
|
||||
|
||||
const test_table = "test_users"
|
||||
|
||||
var mysql *Builder
|
||||
var sqlite3 *Builder
|
||||
|
||||
func init() {
|
||||
odbc := db.Open(map[string]db.DBConfig{
|
||||
"default": {
|
||||
Driver: "mysql",
|
||||
Host: "localhost",
|
||||
Port: "3366",
|
||||
Database: "demo",
|
||||
Username: "demo",
|
||||
Password: "ded86bf25d661bb723f3898b2440dd678382e2dd",
|
||||
Charset: "utf8mb4",
|
||||
MultiStatements: true,
|
||||
// ParseTime: true,
|
||||
},
|
||||
"sqlite3": {
|
||||
Driver: "sqlite3",
|
||||
File: "test.db3",
|
||||
},
|
||||
})
|
||||
|
||||
mysql = NewSchema(odbc.Connection("default"))
|
||||
sqlite3 = NewSchema(odbc.Connection("sqlite3"))
|
||||
}
|
||||
|
||||
func TestHasTable(t *testing.T) {
|
||||
t.Log(mysql.HasTable("users"))
|
||||
t.Log(sqlite3.HasTable("users"))
|
||||
}
|
||||
|
||||
func TestHasColumns(t *testing.T) {
|
||||
t.Log("mysql")
|
||||
t.Log(mysql.HasColumns("users", "id"))
|
||||
t.Log("sqlite3")
|
||||
t.Log(sqlite3.HasColumns("users", "id"))
|
||||
}
|
||||
|
||||
func TestGetColumnListing(t *testing.T) {
|
||||
t.Log("mysql")
|
||||
t.Log(mysql.GetColumnListing("users"))
|
||||
t.Log("sqlite3")
|
||||
t.Log(sqlite3.GetColumnListing("users"))
|
||||
}
|
||||
|
||||
// Mysql 创建表
|
||||
func TestMysqlCreateTable(t *testing.T) {
|
||||
if err := mysql.Create(test_table, func(table *Blueprint) {
|
||||
table.Charset("utf8mb4")
|
||||
|
||||
table.BigIncrements("id").Comment("ID")
|
||||
table.Boolean("enabled").Default("1").Comment("是否有效")
|
||||
table.Char("created_user", 36).Default("00000000-0000-0000-0000-000000000000").Comment("创建者")
|
||||
table.Char("owned_user", 36).Default("00000000-0000-0000-0000-000000000000").Comment("拥有者")
|
||||
table.Timestamp("created_at").UseCurrent().Comment("创建时间")
|
||||
table.Timestamp("updated_at").UseCurrent().Default(db.Raw("ON UPDATE CURRENT_TIMESTAMP")).Comment("更新时间")
|
||||
table.DateTime("deleted_at").Nullable().Comment("删除时间")
|
||||
}); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Sqlite 创建表
|
||||
func TestSqliteCreateTable(t *testing.T) {
|
||||
if err := sqlite3.Create(test_table, func(table *Blueprint) {
|
||||
table.Charset("utf8mb4")
|
||||
|
||||
table.Increments("id").Comment("ID")
|
||||
table.Boolean("enabled").Default("1").Comment("是否有效")
|
||||
table.Char("created_user", 36).Default("00000000-0000-0000-0000-000000000000").Comment("创建者")
|
||||
table.Char("owned_user", 36).Default("00000000-0000-0000-0000-000000000000").Comment("拥有者")
|
||||
table.Timestamp("created_at").UseCurrent().Comment("创建时间")
|
||||
table.Timestamp("updated_at").UseCurrent().Default(db.Raw("ON UPDATE CURRENT_TIMESTAMP")).Comment("更新时间")
|
||||
table.DateTime("deleted_at").Nullable().Comment("删除时间")
|
||||
}); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Mysql 修改表, 添加字段
|
||||
func TestMysqlAddColumn(t *testing.T) {
|
||||
if err := mysql.Table(test_table, func(table *Blueprint) {
|
||||
table.String("name", 50).Nullable().Comment("用户名")
|
||||
table.TinyInteger("age").Nullable().Comment("年龄")
|
||||
}); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Sqlite 修改表, 添加字段
|
||||
func TestSqliteAddColumn(t *testing.T) {
|
||||
if err := sqlite3.Table(test_table, func(table *Blueprint) {
|
||||
table.String("name", 50).Nullable().Comment("用户名")
|
||||
table.TinyInteger("age").Nullable().Comment("年龄")
|
||||
}); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Mysql 修改表, 添加/编辑字段
|
||||
func TestMysqlChangeColumn(t *testing.T) {
|
||||
if err := mysql.Table(test_table, func(table *Blueprint) {
|
||||
table.String("name", 100).Nullable().Comment("用户名").Change("username")
|
||||
table.SmallInteger("age").Nullable().Comment("年龄").Change()
|
||||
|
||||
table.String("nickname", 100).Nullable().Comment("昵称")
|
||||
}); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Sqlite 修改表, 添加/编辑字段
|
||||
func TestSqliteChangeColumn(t *testing.T) {
|
||||
if err := sqlite3.Table(test_table, func(table *Blueprint) {
|
||||
table.String("name", 100).Nullable().Comment("用户名").Change("username")
|
||||
table.SmallInteger("age").Nullable().Comment("年龄").Change()
|
||||
|
||||
table.String("nickname", 100).Nullable().Comment("昵称")
|
||||
}); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Mysql 修改表, 添加/编辑/删除字段
|
||||
func TestMysqlDropColumn(t *testing.T) {
|
||||
if err := mysql.Table(test_table, func(table *Blueprint) {
|
||||
table.String("username", 100).Default("").Comment("--用户名--").Change("name")
|
||||
table.Boolean("is_vip").Default(1).Comment("是否VIP")
|
||||
table.DropColumn("age")
|
||||
}); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Sqlite 修改表, 添加/编辑/删除字段
|
||||
func TestSqliteDropColumn(t *testing.T) {
|
||||
if err := sqlite3.Table(test_table, func(table *Blueprint) {
|
||||
table.String("username", 100).Default("").Comment("--用户名--").Change("name")
|
||||
table.Boolean("is_vip").Default(1).Comment("是否VIP")
|
||||
table.DropColumn("age")
|
||||
}); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Mysql重命名表
|
||||
func TestMysqlRename(t *testing.T) {
|
||||
if err := mysql.Rename(test_table, test_table+"_alias"); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
// SQLite 重命名表
|
||||
func TestSqliteRename(t *testing.T) {
|
||||
if err := sqlite3.Rename(test_table, test_table+"_alias"); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Mysql 删除表
|
||||
func TestMysqlDropTable(t *testing.T) {
|
||||
if err := mysql.Drop(test_table + "_alias"); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Sqlite 删除表
|
||||
func TestSqliteDropTable(t *testing.T) {
|
||||
if err := sqlite3.Drop(test_table + "_alias"); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Mysql 删除表
|
||||
func TestMysqlDropTableIfExists(t *testing.T) {
|
||||
if err := mysql.DropIfExists(test_table + "_alias"); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Sqlite 删除表
|
||||
func TestSqliteDropTableIfExists(t *testing.T) {
|
||||
if err := sqlite3.DropIfExists(test_table + "_alias"); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
115
schema/column_definition.go
Normal file
115
schema/column_definition.go
Normal file
@ -0,0 +1,115 @@
|
||||
package schema
|
||||
|
||||
type ColumnDefinition struct {
|
||||
Type string // kind of column, string / int
|
||||
Name string // name of column
|
||||
ColumnOptions
|
||||
}
|
||||
|
||||
type ColumnOptions struct {
|
||||
Length int // 字段长度
|
||||
Allowed []string // enum 选项
|
||||
Precision int // 日期时间精度
|
||||
Total int // 小数位数
|
||||
Places int // 小数精度
|
||||
rename string // 重命名
|
||||
useCurrent bool // CURRENT TIMESTAMP
|
||||
after string // Place the column "after" another column (MySQL)
|
||||
always bool // Used as a modifier for generatedAs() (PostgreSQL)
|
||||
autoIncrement bool // Set INTEGER columns as auto-increment (primary key)
|
||||
change bool // Change the column
|
||||
charset string // Specify a character set for the column (MySQL)
|
||||
collation string // Specify a collation for the column (MySQL/PostgreSQL/SQL Server)
|
||||
comment string // Add a comment to the column (MySQL)
|
||||
def any // Specify a "default" value for the column
|
||||
first bool // Place the column "first" in the table (MySQL)
|
||||
nullable bool // Allow NULL values to be inserted into the column
|
||||
storedAs string // Create a stored generated column (MySQL)
|
||||
unsigned bool // Set the INTEGER column as UNSIGNED (MySQL)
|
||||
virtualAs string // Create a virtual generated column (MySQL)
|
||||
unique bool // Add a unique index
|
||||
primary bool // Add a primary index
|
||||
index bool // Add an index
|
||||
spatialIndex bool // Add a spatial index
|
||||
}
|
||||
|
||||
// VirtualAs Create a virtual generated column (MySQL)
|
||||
func (c *ColumnDefinition) VirtualAs(as string) *ColumnDefinition {
|
||||
c.virtualAs = as
|
||||
return c
|
||||
}
|
||||
|
||||
// StoredAs Create a stored generated column (MySQL)
|
||||
func (c *ColumnDefinition) StoredAs(as string) *ColumnDefinition {
|
||||
c.storedAs = as
|
||||
return c
|
||||
}
|
||||
|
||||
// Unsigned Set INTEGER columns as UNSIGNED (MySQL)
|
||||
func (c *ColumnDefinition) Unsigned() *ColumnDefinition {
|
||||
c.unsigned = true
|
||||
return c
|
||||
}
|
||||
|
||||
// First Place the column "first" in the table (MySQL)
|
||||
func (c *ColumnDefinition) First() *ColumnDefinition {
|
||||
c.first = true
|
||||
return c
|
||||
}
|
||||
|
||||
// Default Specify a "default" value for the column
|
||||
func (c *ColumnDefinition) Default(def any) *ColumnDefinition {
|
||||
c.def = def
|
||||
return c
|
||||
}
|
||||
|
||||
// Comment Add a comment to a column (MySQL/PostgreSQL)
|
||||
func (c *ColumnDefinition) Comment(comm string) *ColumnDefinition {
|
||||
c.comment = comm
|
||||
return c
|
||||
}
|
||||
|
||||
// Collaction Specify a collation for the column (MySQL/PostgreSQL/SQL Server)
|
||||
func (c *ColumnDefinition) Collaction(coll string) *ColumnDefinition {
|
||||
c.collation = coll
|
||||
return c
|
||||
}
|
||||
|
||||
// Charset Specify a character set for the column (MySQL)
|
||||
func (c *ColumnDefinition) Charset(chars string) *ColumnDefinition {
|
||||
c.charset = chars
|
||||
return c
|
||||
}
|
||||
|
||||
// AutoIncrement set INTEGER columns as auto-increment (primary key)
|
||||
func (c *ColumnDefinition) AutoIncrement() *ColumnDefinition {
|
||||
c.autoIncrement = true
|
||||
return c
|
||||
}
|
||||
|
||||
// After place the column "after" another column (MySQL)
|
||||
func (c *ColumnDefinition) After(column string) *ColumnDefinition {
|
||||
c.after = column
|
||||
return c
|
||||
}
|
||||
|
||||
// Nullable makes a column nullable
|
||||
func (c *ColumnDefinition) Nullable() *ColumnDefinition {
|
||||
c.nullable = true
|
||||
return c
|
||||
}
|
||||
|
||||
// 修改字段, 默认新增
|
||||
func (c *ColumnDefinition) Change(param ...string) *ColumnDefinition {
|
||||
if len(param) > 0 && param[0] != "" {
|
||||
c.rename = param[0]
|
||||
}
|
||||
c.change = true
|
||||
return c
|
||||
}
|
||||
|
||||
// 时间戳
|
||||
func (c *ColumnDefinition) UseCurrent() *ColumnDefinition {
|
||||
c.useCurrent = true
|
||||
return c
|
||||
}
|
44
schema/grammar.go
Normal file
44
schema/grammar.go
Normal file
@ -0,0 +1,44 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"git.fsdpf.net/go/db"
|
||||
)
|
||||
|
||||
type IGrammar interface {
|
||||
// 判断表是否存在
|
||||
CompileTableExists() string
|
||||
// 字段列
|
||||
CompileColumnListing(table string) string
|
||||
// 创建表
|
||||
CompileCreate(bp *Blueprint, command *Command, conn *db.Connection) string
|
||||
// 添加字段
|
||||
CompileAdd(bp *Blueprint, command *Command, conn *db.Connection) []string
|
||||
// 修改字段
|
||||
CompileChange(bp *Blueprint, command *Command, conn *db.Connection) []string
|
||||
// 添加主键
|
||||
// CompilePrimary(bp *Blueprint, command *Command, conn *db.Connection) string
|
||||
// 添加唯一键
|
||||
CompileUnique(bp *Blueprint, command *Command, conn *db.Connection) string
|
||||
// 添加普通索引
|
||||
CompileIndex(bp *Blueprint, command *Command, conn *db.Connection) string
|
||||
// 添加空间索引
|
||||
CompileSpatialIndex(bp *Blueprint, command *Command, conn *db.Connection) string
|
||||
// 删除表
|
||||
CompileDrop(bp *Blueprint, command *Command, conn *db.Connection) string
|
||||
// 删除表, 先判断再删除
|
||||
CompileDropIfExists(bp *Blueprint, command *Command, conn *db.Connection) string
|
||||
// 删除表, 删除列
|
||||
CompileDropColumn(bp *Blueprint, command *Command, conn *db.Connection) []string
|
||||
// 删除主键
|
||||
// CompileDropPrimary(bp *Blueprint, command *Command, conn *db.Connection) string
|
||||
// 删除唯一键
|
||||
CompileDropUnique(bp *Blueprint, command *Command, conn *db.Connection) string
|
||||
// 删除唯普通索引
|
||||
CompileDropIndex(bp *Blueprint, command *Command, conn *db.Connection) string
|
||||
// 删除唯空间索引
|
||||
CompileDropSpatialIndex(bp *Blueprint, command *Command, conn *db.Connection) string
|
||||
// 表重命名
|
||||
CompileRename(bp *Blueprint, command *Command, conn *db.Connection) string
|
||||
// 索引重命名
|
||||
CompileRenameIndex(bp *Blueprint, command *Command, conn *db.Connection) string
|
||||
}
|
338
schema/mysql_grammar.go
Normal file
338
schema/mysql_grammar.go
Normal file
@ -0,0 +1,338 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/samber/lo"
|
||||
|
||||
"git.fsdpf.net/go/db"
|
||||
)
|
||||
|
||||
type MysqlGrammar struct {
|
||||
}
|
||||
|
||||
var mysqlDefaultModifiers = []string{
|
||||
"Unsigned", "Charset", "Collate", "VirtualAs", "StoredAs",
|
||||
"Nullable", "Default", "Increment", "Comment", "After", "First",
|
||||
}
|
||||
var mysqlSerials = []string{
|
||||
"bigInteger", "integer", "mediumInteger", "smallInteger", "tinyInteger",
|
||||
}
|
||||
|
||||
func (this MysqlGrammar) CompileTableExists() string {
|
||||
return "select count(*) from information_schema.tables where table_name = ? and table_schema = ? and table_type = 'BASE TABLE'"
|
||||
}
|
||||
|
||||
func (this MysqlGrammar) CompileColumnListing(table string) string {
|
||||
return "select ordinal_position as `id`, column_name as `name` from information_schema.columns where table_name = ? and table_schema = ? order by ordinal_position"
|
||||
}
|
||||
|
||||
// 创建表
|
||||
func (this MysqlGrammar) CompileCreate(bp *Blueprint, command *Command, conn *db.Connection) string {
|
||||
sql := this.compileCreateTable(bp, command, conn)
|
||||
sql = this.compileCreateEncoding(sql, bp, conn)
|
||||
sql = this.compileCreateEngine(sql, bp, conn)
|
||||
return sql
|
||||
}
|
||||
|
||||
// 创建表结构
|
||||
func (this MysqlGrammar) compileCreateTable(bp *Blueprint, command *Command, conn *db.Connection) string {
|
||||
temporary := "create"
|
||||
if bp.temporary {
|
||||
temporary = "create temporary"
|
||||
}
|
||||
return fmt.Sprintf("%s table %s (%s)", temporary, bp.table, strings.Join(this.getAddedColumns(bp), ", "))
|
||||
}
|
||||
|
||||
// 创建表 字符集
|
||||
func (this MysqlGrammar) compileCreateEncoding(sql string, bp *Blueprint, conn *db.Connection) string {
|
||||
if bp.charset != "" {
|
||||
sql = sql + " default charset=" + bp.charset
|
||||
} else if conn.GetConfig().Charset != "" {
|
||||
sql = sql + " default charset=" + conn.GetConfig().Charset
|
||||
} else {
|
||||
sql = sql + " default charset=utf8mb4"
|
||||
}
|
||||
|
||||
if bp.collation != "" {
|
||||
sql = sql + " collate=" + bp.collation
|
||||
} else {
|
||||
sql = sql + " collate=utf8mb4_general_ci"
|
||||
}
|
||||
|
||||
return sql
|
||||
}
|
||||
|
||||
// 创建表存储方式
|
||||
func (this MysqlGrammar) compileCreateEngine(sql string, bp *Blueprint, conn *db.Connection) string {
|
||||
if bp.engine != "" {
|
||||
sql = sql + " engine = " + bp.engine
|
||||
}
|
||||
return sql
|
||||
}
|
||||
|
||||
// 添加字段
|
||||
func (this MysqlGrammar) CompileAdd(bp *Blueprint, command *Command, conn *db.Connection) []string {
|
||||
columns := PrefixArray("add", this.getAddedColumns(bp))
|
||||
return []string{"alter table `" + bp.table + "` " + strings.Join(columns, ", ")}
|
||||
}
|
||||
|
||||
// 修改字段
|
||||
func (this MysqlGrammar) CompileChange(bp *Blueprint, command *Command, conn *db.Connection) []string {
|
||||
columns := PrefixArray("change", this.getChangedColumns(bp))
|
||||
return []string{"alter table `" + bp.table + "` " + strings.Join(columns, ", ")}
|
||||
}
|
||||
|
||||
// 添加主键
|
||||
func (this MysqlGrammar) CompilePrimary(bp *Blueprint, command *Command, conn *db.Connection) string {
|
||||
return this.compileKey(bp, command, "primary key")
|
||||
}
|
||||
|
||||
// 添加唯一键
|
||||
func (this MysqlGrammar) CompileUnique(bp *Blueprint, command *Command, conn *db.Connection) string {
|
||||
return this.compileKey(bp, command, "unique")
|
||||
}
|
||||
|
||||
// 添加普通索引
|
||||
func (this MysqlGrammar) CompileIndex(bp *Blueprint, command *Command, conn *db.Connection) string {
|
||||
return this.compileKey(bp, command, "index")
|
||||
}
|
||||
|
||||
// 添加空间索引
|
||||
func (this MysqlGrammar) CompileSpatialIndex(bp *Blueprint, command *Command, conn *db.Connection) string {
|
||||
return this.compileKey(bp, command, "spatial index")
|
||||
}
|
||||
|
||||
// 添加索引
|
||||
func (this MysqlGrammar) compileKey(bp *Blueprint, command *Command, typ string) string {
|
||||
algorithm := ""
|
||||
if command.Algorithm != "" {
|
||||
algorithm = " using " + command.Algorithm
|
||||
}
|
||||
columns := lo.Map(command.Columns, func(column string, _ int) string {
|
||||
return "`" + column + "`"
|
||||
})
|
||||
return fmt.Sprintf("alter table %s add %s %s%s(%s)", bp.table, typ, command.Index, algorithm, strings.Join(columns, ","))
|
||||
}
|
||||
|
||||
// 删除表
|
||||
func (this MysqlGrammar) CompileDrop(bp *Blueprint, command *Command, conn *db.Connection) string {
|
||||
return "drop table `" + bp.table + "`"
|
||||
}
|
||||
|
||||
// 删除表, 先判断再删除
|
||||
func (this MysqlGrammar) CompileDropIfExists(bp *Blueprint, command *Command, conn *db.Connection) string {
|
||||
return "drop table if exists `" + bp.table + "`"
|
||||
}
|
||||
|
||||
// 删除列
|
||||
func (this MysqlGrammar) CompileDropColumn(bp *Blueprint, command *Command, conn *db.Connection) []string {
|
||||
columns := lo.Map(command.Columns, func(column string, _ int) string {
|
||||
return "drop `" + column + "`"
|
||||
})
|
||||
return []string{"alter table `" + bp.table + "` " + strings.Join(columns, ", ")}
|
||||
}
|
||||
|
||||
// 删除主键
|
||||
func (this MysqlGrammar) CompileDropPrimary(bp *Blueprint, command *Command, conn *db.Connection) string {
|
||||
return "alter table `" + bp.table + "` drop primary key"
|
||||
}
|
||||
|
||||
// 删除索引
|
||||
func (this MysqlGrammar) CompileDropIndex(bp *Blueprint, command *Command, conn *db.Connection) string {
|
||||
return "alter table `" + bp.table + "` drop index `" + command.Index + "`"
|
||||
}
|
||||
|
||||
// 删除唯一键
|
||||
func (this MysqlGrammar) CompileDropUnique(bp *Blueprint, command *Command, conn *db.Connection) string {
|
||||
return this.CompileDropIndex(bp, command, conn)
|
||||
}
|
||||
|
||||
// 删除空间索引
|
||||
func (this MysqlGrammar) CompileDropSpatialIndex(bp *Blueprint, command *Command, conn *db.Connection) string {
|
||||
return this.CompileDropIndex(bp, command, conn)
|
||||
}
|
||||
|
||||
// 表重命名
|
||||
func (this MysqlGrammar) CompileRename(bp *Blueprint, command *Command, conn *db.Connection) string {
|
||||
return "rename table `" + bp.table + "` to `" + command.To + "`"
|
||||
}
|
||||
|
||||
// 索引重命名
|
||||
func (this MysqlGrammar) CompileRenameIndex(bp *Blueprint, command *Command, conn *db.Connection) string {
|
||||
return fmt.Sprintf("alter table `%s` rename index `%s` to `%s`", bp.table, command.From, command.To)
|
||||
}
|
||||
|
||||
// 删除所有表
|
||||
func (this MysqlGrammar) CompileDropAllTables(tables ...string) string {
|
||||
tables = lo.Map(tables, func(table string, _ int) string {
|
||||
return "`" + table + "`"
|
||||
})
|
||||
return "drop table " + strings.Join(tables, ", ")
|
||||
}
|
||||
|
||||
// 删除所有视图
|
||||
func (this MysqlGrammar) CompileDropAllViews(views ...string) string {
|
||||
views = lo.Map(views, func(view string, _ int) string {
|
||||
return "`" + view + "`"
|
||||
})
|
||||
return "drop view " + strings.Join(views, ", ")
|
||||
}
|
||||
|
||||
// 查询所有表
|
||||
// func (this MysqlGrammar) CompileGetAllTables(tables ...string) string {
|
||||
|
||||
// }
|
||||
|
||||
// 查询所有视图
|
||||
// func (this MysqlGrammar) CompileGetAllViews(views ...string) string {}
|
||||
|
||||
func (this MysqlGrammar) addModifiers(sql string, bp *Blueprint, column *ColumnDefinition) string {
|
||||
for _, modifier := range mysqlDefaultModifiers {
|
||||
sql = sql + this.GetColumnModifier(modifier, bp, column)
|
||||
}
|
||||
return sql
|
||||
}
|
||||
|
||||
func (this MysqlGrammar) GetColumnType(column *ColumnDefinition) string {
|
||||
switch column.Type {
|
||||
case "char":
|
||||
return fmt.Sprintf("char(%d)", column.Length)
|
||||
case "string":
|
||||
return fmt.Sprintf("varchar(%d)", column.Length)
|
||||
case "text":
|
||||
return "text"
|
||||
case "integer":
|
||||
return "int"
|
||||
case "bigInteger":
|
||||
return "bigint"
|
||||
case "tinyInteger":
|
||||
return "tinyint"
|
||||
case "smallInteger":
|
||||
return "smallint"
|
||||
case "decimal":
|
||||
return fmt.Sprintf("decimal(%d,%d)", column.Total, column.Places)
|
||||
case "boolean":
|
||||
return "tinyint(1)"
|
||||
case "enum":
|
||||
return fmt.Sprintf("enum(%s)", QuoteString(column.Allowed))
|
||||
case "json":
|
||||
return "json"
|
||||
case "date":
|
||||
return "date"
|
||||
case "binary":
|
||||
return "binary"
|
||||
case "datetime":
|
||||
columnType := "datetime"
|
||||
if column.Precision > 0 {
|
||||
columnType = fmt.Sprintf("datetime(%d)", column.Precision)
|
||||
}
|
||||
return columnType
|
||||
case "time":
|
||||
columnType := "time"
|
||||
if column.Precision > 0 {
|
||||
columnType = fmt.Sprintf("time(%d)", column.Precision)
|
||||
}
|
||||
return columnType
|
||||
case "timestamp":
|
||||
columnType := "timestamp"
|
||||
if column.Precision > 0 {
|
||||
columnType = fmt.Sprintf("timestamp(%d)", column.Precision)
|
||||
}
|
||||
return columnType
|
||||
case "year":
|
||||
return "year"
|
||||
case "uuid":
|
||||
return "char(36)"
|
||||
}
|
||||
panic("不支持的数据类型: " + column.Type)
|
||||
}
|
||||
|
||||
func (this MysqlGrammar) GetColumnModifier(modifier string, bp *Blueprint, column *ColumnDefinition) string {
|
||||
switch modifier {
|
||||
case "Unsigned":
|
||||
if column.unsigned {
|
||||
return " unsigned"
|
||||
}
|
||||
case "Charset":
|
||||
if column.charset != "" {
|
||||
return fmt.Sprintf(" character set %s", column.charset)
|
||||
}
|
||||
case "Collate":
|
||||
if column.collation != "" {
|
||||
return fmt.Sprintf(" collate '%s'", column.collation)
|
||||
}
|
||||
case "VirtualAs":
|
||||
if column.virtualAs != "" {
|
||||
return fmt.Sprintf(" as (%s)", column.virtualAs)
|
||||
}
|
||||
case "StoredAs":
|
||||
if column.storedAs != "" {
|
||||
return fmt.Sprintf(" as (%s) stored", column.storedAs)
|
||||
}
|
||||
case "Nullable":
|
||||
if column.virtualAs == "" && column.storedAs == "" {
|
||||
if column.nullable {
|
||||
return " null"
|
||||
}
|
||||
return " not null"
|
||||
}
|
||||
case "Default":
|
||||
switch v := column.def.(type) {
|
||||
case db.Expression:
|
||||
if column.useCurrent {
|
||||
return fmt.Sprintf(" default %s %s", "CURRENT_TIMESTAMP", v)
|
||||
}
|
||||
return fmt.Sprintf(" default %s", v)
|
||||
case nil:
|
||||
if column.useCurrent {
|
||||
return fmt.Sprintf(" default %s", "CURRENT_TIMESTAMP")
|
||||
}
|
||||
default:
|
||||
return fmt.Sprintf(" default '%v'", v)
|
||||
}
|
||||
|
||||
case "Increment":
|
||||
if column.autoIncrement && lo.Contains(mysqlSerials, column.Type) {
|
||||
return " auto_increment primary key"
|
||||
}
|
||||
case "Comment":
|
||||
if column.comment != "" {
|
||||
return fmt.Sprintf(" comment '%s'", db.Addslashes(column.comment))
|
||||
}
|
||||
case "After":
|
||||
if column.after != "" {
|
||||
return fmt.Sprintf(" after `%s`", column.after)
|
||||
}
|
||||
case "First":
|
||||
if column.first {
|
||||
return " first"
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// 获取新增表结构字段
|
||||
func (this MysqlGrammar) getAddedColumns(bp *Blueprint) (columns []string) {
|
||||
for _, column := range bp.getAddedColumns() {
|
||||
sql := "`" + column.Name + "` " + this.GetColumnType(column)
|
||||
columns = append(columns, this.addModifiers(sql, bp, column))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 获取修改表结构字段
|
||||
func (this MysqlGrammar) getChangedColumns(bp *Blueprint) (columns []string) {
|
||||
for _, column := range bp.getChangedColumns() {
|
||||
name := column.Name
|
||||
rename := column.Name
|
||||
if column.rename != "" {
|
||||
rename = column.rename
|
||||
}
|
||||
sql := fmt.Sprintf("`%s` `%s` ", name, rename) + this.GetColumnType(column)
|
||||
columns = append(columns, this.addModifiers(sql, bp, column))
|
||||
}
|
||||
return
|
||||
}
|
351
schema/sqlite3_grammar.go
Normal file
351
schema/sqlite3_grammar.go
Normal file
@ -0,0 +1,351 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/samber/lo"
|
||||
|
||||
"git.fsdpf.net/go/db"
|
||||
)
|
||||
|
||||
type Sqlite3Grammar struct {
|
||||
}
|
||||
|
||||
var sqlite3DefaultModifiers = []string{"VirtualAs", "StoredAs", "Nullable", "Default", "Increment"}
|
||||
|
||||
var sqlite3Serials = []string{"bigInteger", "integer", "mediumInteger", "smallInteger", "tinyInteger"}
|
||||
|
||||
func (this Sqlite3Grammar) CompileTableExists() string {
|
||||
// 兼容其他数据库查询
|
||||
// 其他数据库查看都是同时传入, 数据库名 和 数据表名
|
||||
return "select count(*) from sqlite_master where type = 'table' and name = ?||?"
|
||||
}
|
||||
|
||||
func (this Sqlite3Grammar) CompileColumnListing(table string) string {
|
||||
// 兼容其他数据库查询
|
||||
// 其他数据库查看都是同时传入, 数据库名 和 数据表名
|
||||
return "pragma table_info(" + table + ")"
|
||||
}
|
||||
|
||||
// 创建表
|
||||
func (this Sqlite3Grammar) CompileCreate(bp *Blueprint, command *Command, conn *db.Connection) string {
|
||||
temporary := "create"
|
||||
if bp.temporary {
|
||||
temporary = "create temporary"
|
||||
}
|
||||
return fmt.Sprintf("%s table %s (%s)", temporary, bp.table, strings.Join(this.getAddedColumns(bp), ", "))
|
||||
}
|
||||
|
||||
// 添加字段
|
||||
func (this Sqlite3Grammar) CompileAdd(bp *Blueprint, command *Command, conn *db.Connection) []string {
|
||||
columns := PrefixArray("add column", this.getAddedColumns(bp))
|
||||
|
||||
return lo.Map(columns, func(column string, _ int) string {
|
||||
return "alter table `" + bp.table + "` " + column
|
||||
})
|
||||
}
|
||||
|
||||
// 修改字段
|
||||
func (this Sqlite3Grammar) CompileChange(bp *Blueprint, command *Command, conn *db.Connection) (sqls []string) {
|
||||
// 创建表的sql
|
||||
cTableSql := this.getCreateTable(bp, conn)
|
||||
|
||||
// 旧表字段
|
||||
oldColumns := []string{}
|
||||
if _, err := conn.Select("select '`'||name||'`' from pragma_table_info(?)", []any{bp.table}, &oldColumns); err != nil {
|
||||
panic(fmt.Errorf("获取 SQLite 建表字段失败, %s", err))
|
||||
}
|
||||
|
||||
// 数据同步
|
||||
insert_sql := fmt.Sprintf("insert into `"+bp.table+"` (%s)", strings.Join(oldColumns, ", "))
|
||||
select_sql := fmt.Sprintf("select %s from `"+bp.table+"_old`", strings.Join(oldColumns, ", "))
|
||||
|
||||
for _, column := range bp.getChangedColumns() {
|
||||
name := column.Name
|
||||
rename := column.rename
|
||||
|
||||
// 旧字段(查询) -> 新字段(插入)
|
||||
if column.rename != "" {
|
||||
select_sql = strings.ReplaceAll(select_sql, name, name+"` as `"+rename)
|
||||
insert_sql = strings.ReplaceAll(insert_sql, name, rename)
|
||||
} else {
|
||||
rename = name
|
||||
}
|
||||
|
||||
// 修改表字段
|
||||
sql := "`" + rename + "` " + this.GetColumnType(column)
|
||||
sql = this.addModifiers(sql, bp, column)
|
||||
reg := regexp.MustCompile(fmt.Sprintf("(?iUm)^(?P<prefix>\\s*)(?P<content>`%s`.*)(?P<end>(,\\s|\\s)?)$", name))
|
||||
cTableSql = reg.ReplaceAllString(cTableSql, fmt.Sprintf("${prefix}%s${end}", sql))
|
||||
}
|
||||
|
||||
// 开启事物
|
||||
// sqls = append(sqls, "BEGIN TRANSACTION")
|
||||
// 1. 重命名已存在的表
|
||||
sqls = append(sqls, fmt.Sprintf("alter table `%s` rename to `%s_old`", bp.table, bp.table))
|
||||
// 2. 创建新结构表
|
||||
sqls = append(sqls, cTableSql)
|
||||
// 3. 复制数据
|
||||
sqls = append(sqls, insert_sql+" "+select_sql)
|
||||
// 4. 删除旧表
|
||||
sqls = append(sqls, fmt.Sprintf("drop table `%s_old`", bp.table))
|
||||
// 提交事物
|
||||
// sqls = append(sqls, "COMMIT")
|
||||
|
||||
// @todo 复制索引
|
||||
|
||||
return sqls
|
||||
}
|
||||
|
||||
// 创建唯一索引
|
||||
func (this Sqlite3Grammar) CompileUnique(bp *Blueprint, command *Command, conn *db.Connection) string {
|
||||
columns := lo.Map(command.Columns, func(column string, _ int) string {
|
||||
return "`" + column + "`"
|
||||
})
|
||||
return fmt.Sprintf("create unique index `%s` on `%s` (%s)", command.Index, bp.table, strings.Join(columns, ", "))
|
||||
}
|
||||
|
||||
// 创建普通索引
|
||||
func (this Sqlite3Grammar) CompileIndex(bp *Blueprint, command *Command, conn *db.Connection) string {
|
||||
columns := lo.Map(command.Columns, func(column string, _ int) string {
|
||||
return "`" + column + "`"
|
||||
})
|
||||
return fmt.Sprintf("create index `%s` on `%s` (%s)", command.Index, bp.table, strings.Join(columns, ", "))
|
||||
}
|
||||
|
||||
// 创建空间索引
|
||||
func (this Sqlite3Grammar) CompileSpatialIndex(bp *Blueprint, command *Command, conn *db.Connection) string {
|
||||
panic("Sqlite 不支持 SpatialIndex")
|
||||
}
|
||||
|
||||
// 删除表
|
||||
func (this Sqlite3Grammar) CompileDrop(bp *Blueprint, command *Command, conn *db.Connection) string {
|
||||
return "drop table `" + bp.table + "`"
|
||||
}
|
||||
|
||||
// 删除表, 先判断再删除
|
||||
func (this Sqlite3Grammar) CompileDropIfExists(bp *Blueprint, command *Command, conn *db.Connection) string {
|
||||
return "drop table if exists `" + bp.table + "`"
|
||||
}
|
||||
|
||||
// 删除列
|
||||
func (this Sqlite3Grammar) CompileDropColumn(bp *Blueprint, command *Command, conn *db.Connection) (sqls []string) {
|
||||
// 创建表的sql, 排除不要的字段
|
||||
cTableSql := this.getCreateTable(bp, conn, command.Columns...)
|
||||
|
||||
// 旧表字段
|
||||
oldColumns := []string{}
|
||||
|
||||
if _, err := conn.Select(
|
||||
"select '`'||name||'`' from pragma_table_info(?) where `name` not in("+strings.Trim(strings.Repeat("?,", len(command.Columns)), ",")+")",
|
||||
append([]any{bp.table}, lo.ToAnySlice(command.Columns)...),
|
||||
&oldColumns); err != nil {
|
||||
panic(fmt.Errorf("获取 SQLite 建表字段失败, %s", err))
|
||||
}
|
||||
|
||||
copy_data_sql := fmt.Sprintf("insert into `%s` select %s from `%s_old`", bp.table, strings.Join(oldColumns, ", "), bp.table)
|
||||
|
||||
for _, column := range bp.getChangedColumns() {
|
||||
name := column.Name
|
||||
rename := column.rename
|
||||
|
||||
if rename != "" {
|
||||
copy_data_sql = strings.ReplaceAll(copy_data_sql, "`"+name+"`", "`"+rename+"`")
|
||||
}
|
||||
|
||||
// 修改表字段
|
||||
sql := "`" + rename + "` " + this.GetColumnType(column)
|
||||
sql = this.addModifiers(sql, bp, column)
|
||||
reg := regexp.MustCompile(fmt.Sprintf("(?iUm)^(?P<prefix>\\s*)(?P<content>`%s`.*)(?P<end>(,\\s|\\s)?)$", name))
|
||||
cTableSql = reg.ReplaceAllString(cTableSql, fmt.Sprintf("${prefix}%s${end}", sql))
|
||||
}
|
||||
|
||||
// 开启事物
|
||||
// sqls = append(sqls, "BEGIN TRANSACTION")
|
||||
// 1. 重命名已存在的表
|
||||
sqls = append(sqls, fmt.Sprintf("alter table `%s` rename to `%s_old`", bp.table, bp.table))
|
||||
// 2. 创建新结构表
|
||||
sqls = append(sqls, cTableSql)
|
||||
// 3. 复制数据
|
||||
sqls = append(sqls, copy_data_sql)
|
||||
// 4. 删除旧表
|
||||
sqls = append(sqls, fmt.Sprintf("drop table `%s_old`", bp.table))
|
||||
// 提交事物
|
||||
// sqls = append(sqls, "COMMIT")
|
||||
|
||||
// @todo 复制索引
|
||||
|
||||
return sqls
|
||||
}
|
||||
|
||||
// 删除唯一索引
|
||||
func (this Sqlite3Grammar) CompileDropUnique(bp *Blueprint, command *Command, conn *db.Connection) string {
|
||||
return this.CompileDropIndex(bp, command, conn)
|
||||
}
|
||||
|
||||
// 删除索引
|
||||
func (this Sqlite3Grammar) CompileDropIndex(bp *Blueprint, command *Command, conn *db.Connection) string {
|
||||
return "drop index `" + command.Index + "`"
|
||||
}
|
||||
|
||||
// 删除空间索引
|
||||
func (this Sqlite3Grammar) CompileDropSpatialIndex(bp *Blueprint, command *Command, conn *db.Connection) string {
|
||||
return this.CompileSpatialIndex(bp, command, conn)
|
||||
}
|
||||
|
||||
func (this Sqlite3Grammar) CompileRename(bp *Blueprint, command *Command, conn *db.Connection) string {
|
||||
return "alter table `" + bp.table + "` rename to `" + command.To + "`"
|
||||
}
|
||||
|
||||
func (this Sqlite3Grammar) CompileRenameIndex(bp *Blueprint, command *Command, conn *db.Connection) string {
|
||||
// 先删除索引, 后创建索引
|
||||
return "alter table `" + bp.table + "` rename to `" + command.To + "`"
|
||||
}
|
||||
|
||||
// 删除所有表
|
||||
func (this Sqlite3Grammar) CompileDropAllTables() string {
|
||||
return "delete from sqlite_master where type in ('table', 'index', 'trigger')"
|
||||
}
|
||||
|
||||
// 删除所有视图
|
||||
func (this Sqlite3Grammar) CompileDropAllViews() string {
|
||||
return "delete from sqlite_master where type in ('view')"
|
||||
}
|
||||
|
||||
func (this Sqlite3Grammar) addModifiers(sql string, bp *Blueprint, column *ColumnDefinition) string {
|
||||
for _, modifier := range mysqlDefaultModifiers {
|
||||
sql = sql + this.GetColumnModifier(modifier, bp, column)
|
||||
}
|
||||
return sql
|
||||
}
|
||||
|
||||
func (this Sqlite3Grammar) GetColumnModifier(modifier string, bp *Blueprint, column *ColumnDefinition) string {
|
||||
switch modifier {
|
||||
case "VirtualAs":
|
||||
if column.virtualAs != "" {
|
||||
return fmt.Sprintf(" as (%s)", column.virtualAs)
|
||||
}
|
||||
case "StoredAs":
|
||||
if column.storedAs != "" {
|
||||
return fmt.Sprintf(" as (%s) stored", column.storedAs)
|
||||
}
|
||||
case "Nullable":
|
||||
if !column.nullable {
|
||||
return " not null"
|
||||
}
|
||||
case "Default":
|
||||
if column.useCurrent {
|
||||
return ""
|
||||
}
|
||||
switch v := column.def.(type) {
|
||||
case db.Expression:
|
||||
return fmt.Sprintf(" default %s", v)
|
||||
case nil:
|
||||
default:
|
||||
return fmt.Sprintf(" default '%v'", v)
|
||||
}
|
||||
case "Increment":
|
||||
if column.autoIncrement && lo.Contains(mysqlSerials, column.Type) {
|
||||
return " primary key autoincrement"
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func (this Sqlite3Grammar) GetColumnType(column *ColumnDefinition) string {
|
||||
switch column.Type {
|
||||
case "char", "string":
|
||||
return fmt.Sprintf("varchar(%d)", column.Length)
|
||||
case "text":
|
||||
return "text"
|
||||
case "year", "integer":
|
||||
return "integer"
|
||||
case "bigInteger":
|
||||
return "bigint"
|
||||
case "tinyInteger":
|
||||
return "tinyint"
|
||||
case "smallInteger":
|
||||
return "smallint"
|
||||
case "decimal":
|
||||
return fmt.Sprintf("numeric(%d,%d)", column.Total, column.Places)
|
||||
case "boolean":
|
||||
return "tinyint(1)"
|
||||
case "enum":
|
||||
return fmt.Sprintf(`varchar check ("%s" in (%s))`, column.Name, QuoteString(column.Allowed))
|
||||
case "json":
|
||||
return "json"
|
||||
case "date":
|
||||
return "date"
|
||||
case "binary":
|
||||
return "blob"
|
||||
case "datetime", "timestamp":
|
||||
columnType := column.Type
|
||||
if column.useCurrent {
|
||||
columnType = columnType + " default CURRENT_TIMESTAMP"
|
||||
}
|
||||
return columnType
|
||||
case "time":
|
||||
return "time"
|
||||
case "uuid":
|
||||
return "varchar(36)"
|
||||
}
|
||||
panic("不支持的数据类型: " + column.Type)
|
||||
}
|
||||
|
||||
// 获取新增表结构字段
|
||||
func (this Sqlite3Grammar) getAddedColumns(bp *Blueprint) (columns []string) {
|
||||
for _, column := range bp.getAddedColumns() {
|
||||
sql := "`" + column.Name + "` " + this.GetColumnType(column)
|
||||
columns = append(columns, this.addModifiers(sql, bp, column))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 获取修改表结构字段
|
||||
func (this Sqlite3Grammar) getChangedColumns(bp *Blueprint) map[string]string {
|
||||
columns := map[string]string{}
|
||||
for _, column := range bp.getChangedColumns() {
|
||||
sql := "`" + column.Name + "` " + this.GetColumnType(column)
|
||||
columns[column.Name] = this.addModifiers(sql, bp, column)
|
||||
}
|
||||
return columns
|
||||
}
|
||||
|
||||
// 数据库表字段
|
||||
func (this Sqlite3Grammar) getCreateTable(bp *Blueprint, conn *db.Connection, without ...string) (sql string) {
|
||||
dbColumns := []struct {
|
||||
Name string `db:"name"`
|
||||
Type string `db:"type"`
|
||||
Default db.NullString `db:"dflt_value"`
|
||||
NotNull bool `db:"notnull"`
|
||||
Pk bool `db:"pk"`
|
||||
}{}
|
||||
|
||||
if _, err := conn.Select("select * from pragma_table_info(?) order by cid", []any{bp.table}, &dbColumns); err != nil {
|
||||
panic(fmt.Errorf("获取 SQLite 建表字段失败, %s", err))
|
||||
}
|
||||
|
||||
cols := []string{}
|
||||
for _, column := range dbColumns {
|
||||
if lo.Contains(without, column.Name) {
|
||||
continue
|
||||
}
|
||||
col := " `" + column.Name + "`"
|
||||
col += " " + column.Type
|
||||
if column.NotNull {
|
||||
col += " not null"
|
||||
}
|
||||
if column.Pk {
|
||||
col += " primary key autoincrement"
|
||||
}
|
||||
if column.Default != "" {
|
||||
col += " default " + string(column.Default)
|
||||
}
|
||||
cols = append(cols, col)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("CREATE TABLE `%s` (\n%s\n)", bp.table, strings.Join(cols, ", \n"))
|
||||
}
|
25
schema/util.go
Normal file
25
schema/util.go
Normal file
@ -0,0 +1,25 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"github.com/samber/lo"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func PrefixArray(prefix string, values []string) (items []string) {
|
||||
for _, value := range values {
|
||||
items = append(items, prefix+" "+value)
|
||||
}
|
||||
return items
|
||||
}
|
||||
|
||||
func QuoteString(value any) string {
|
||||
switch v := value.(type) {
|
||||
case []string:
|
||||
return strings.Join(lo.Map(v, func(item string, _ int) string {
|
||||
return "'" + item + "'"
|
||||
}), ", ")
|
||||
case string:
|
||||
return "'" + v + "'"
|
||||
}
|
||||
return ""
|
||||
}
|
50
sqlite3_connector.go
Normal file
50
sqlite3_connector.go
Normal file
@ -0,0 +1,50 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
type SqliteConnector struct {
|
||||
}
|
||||
|
||||
func (c SqliteConnector) connect(config *DBConfig) *Connection {
|
||||
var params []string
|
||||
|
||||
if len(config.Journal) > 0 {
|
||||
params = append(params, "_journal="+config.Journal)
|
||||
}
|
||||
if len(config.Locking) > 0 {
|
||||
params = append(params, "_locking="+config.Locking)
|
||||
}
|
||||
if len(config.Mode) > 0 {
|
||||
params = append(params, "mode="+config.Mode)
|
||||
}
|
||||
if config.Synchronous > 0 {
|
||||
params = append(params, fmt.Sprintf("_sync=%d", config.Synchronous))
|
||||
}
|
||||
|
||||
var dsn string
|
||||
if len(config.Dsn) > 0 {
|
||||
dsn = config.Dsn
|
||||
} else {
|
||||
dsn = fmt.Sprintf("file:%s?%s", config.File, strings.Join(params, "&"))
|
||||
}
|
||||
db, err := sql.Open(DriverSqlite3, dsn)
|
||||
if err != nil {
|
||||
panic(err.Error())
|
||||
}
|
||||
err = db.Ping()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return &Connection{
|
||||
RxDB: db,
|
||||
DB: db,
|
||||
Config: config,
|
||||
}
|
||||
}
|
5
sqlite3_grammar.go
Normal file
5
sqlite3_grammar.go
Normal file
@ -0,0 +1,5 @@
|
||||
package db
|
||||
|
||||
type Sqlite3Grammar struct {
|
||||
*MysqlGrammar
|
||||
}
|
96
transaction.go
Normal file
96
transaction.go
Normal file
@ -0,0 +1,96 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
type Transaction struct {
|
||||
Tx *sql.Tx
|
||||
Rx *sql.DB
|
||||
Config *DBConfig
|
||||
ConnectionName string
|
||||
}
|
||||
type TxClosure func(tx *Transaction) (interface{}, error)
|
||||
|
||||
func (t *Transaction) Table(tableName string) *Builder {
|
||||
builder := NewTxBuilder(t)
|
||||
if t.Config.Driver == DriverMysql {
|
||||
builder.Grammar = &MysqlGrammar{}
|
||||
} else if t.Config.Driver == DriverSqlite3 {
|
||||
builder.Grammar = &Sqlite3Grammar{}
|
||||
} else {
|
||||
panic("不支持的数据库类型")
|
||||
}
|
||||
builder.Grammar.SetTablePrefix(t.Config.Prefix)
|
||||
builder.Tx = t
|
||||
builder.Grammar.SetBuilder(builder)
|
||||
builder.From(tableName)
|
||||
return builder
|
||||
}
|
||||
|
||||
func (t *Transaction) Select(query string, bindings []interface{}, dest interface{}) (result sql.Result, err error) {
|
||||
var stmt *sql.Stmt
|
||||
var rows *sql.Rows
|
||||
stmt, err = t.Rx.Prepare(query)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer stmt.Close()
|
||||
rows, err = stmt.Query(bindings...)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return ScanResult(rows, dest), nil
|
||||
}
|
||||
func (t *Transaction) AffectingStatement(query string, bindings []interface{}) (result sql.Result, err error) {
|
||||
stmt, errP := t.Tx.Prepare(query)
|
||||
if errP != nil {
|
||||
err = errP
|
||||
return
|
||||
}
|
||||
defer stmt.Close()
|
||||
result, err = stmt.Exec(bindings...)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
func (t *Transaction) Insert(query string, bindings []interface{}) (sql.Result, error) {
|
||||
return t.AffectingStatement(query, bindings)
|
||||
}
|
||||
|
||||
func (t *Transaction) Update(query string, bindings []interface{}) (sql.Result, error) {
|
||||
return t.AffectingStatement(query, bindings)
|
||||
}
|
||||
|
||||
func (t *Transaction) Delete(query string, bindings []interface{}) (sql.Result, error) {
|
||||
return t.AffectingStatement(query, bindings)
|
||||
}
|
||||
|
||||
func (t *Transaction) Statement(query string, bindings []interface{}) (sql.Result, error) {
|
||||
return t.AffectingStatement(query, bindings)
|
||||
}
|
||||
|
||||
func (t *Transaction) Commit() error {
|
||||
return t.Tx.Commit()
|
||||
}
|
||||
|
||||
func (t *Transaction) Rollback() error {
|
||||
return t.Tx.Rollback()
|
||||
}
|
||||
|
||||
func (t *Transaction) Query() *Builder {
|
||||
builder := NewTxBuilder(t)
|
||||
if t.Config.Driver == DriverMysql {
|
||||
builder.Grammar = &MysqlGrammar{}
|
||||
} else if t.Config.Driver == DriverSqlite3 {
|
||||
builder.Grammar = &Sqlite3Grammar{}
|
||||
} else {
|
||||
panic("不支持的数据库类型")
|
||||
}
|
||||
builder.Grammar.SetBuilder(builder)
|
||||
builder.Grammar.SetTablePrefix(t.Config.Prefix)
|
||||
return builder
|
||||
}
|
43
types.go
Normal file
43
types.go
Normal file
@ -0,0 +1,43 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type Strings []string
|
||||
|
||||
func (c *Strings) Scan(value any) error {
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return json.Unmarshal(value.([]byte), c)
|
||||
}
|
||||
|
||||
func (c Strings) Value() (driver.Value, error) {
|
||||
b, err := json.Marshal(c)
|
||||
return string(b), err
|
||||
}
|
||||
|
||||
type NullString string
|
||||
|
||||
func (this *NullString) Scan(value any) error {
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch s := value.(type) {
|
||||
case []byte:
|
||||
*this = NullString(s)
|
||||
default:
|
||||
*this = NullString(fmt.Sprintf("%v", s))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c NullString) Value() (driver.Value, error) {
|
||||
return c, nil
|
||||
}
|
77
util.go
Normal file
77
util.go
Normal file
@ -0,0 +1,77 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var matchFirstCap = regexp.MustCompile("(.)([A-Z][a-z]+)")
|
||||
var matchAllCap = regexp.MustCompile("([a-z0-9])([A-Z])")
|
||||
|
||||
func ToSnakeCase(str string) string {
|
||||
snake := matchFirstCap.ReplaceAllString(str, "${1}_${2}")
|
||||
snake = matchAllCap.ReplaceAllString(snake, "${1}_${2}")
|
||||
return strings.ToLower(snake)
|
||||
}
|
||||
func ExtractStruct(target interface{}) map[string]interface{} {
|
||||
tv := reflect.Indirect(reflect.ValueOf(target))
|
||||
tt := tv.Type()
|
||||
result := make(map[string]interface{}, tv.NumField())
|
||||
|
||||
for i := 0; i < tv.NumField(); i++ {
|
||||
key := ToSnakeCase(tt.Field(i).Name)
|
||||
result[key] = tv.Field(i).Interface()
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
func InterfaceToSlice(param interface{}) []interface{} {
|
||||
if p, ok := param.([]interface{}); ok {
|
||||
return p
|
||||
}
|
||||
tv := reflect.Indirect(reflect.ValueOf(param))
|
||||
var res []interface{}
|
||||
if tv.Type().Kind() == reflect.Slice {
|
||||
for i := 0; i < tv.Len(); i++ {
|
||||
res = append(res, tv.Index(i).Interface())
|
||||
}
|
||||
} else {
|
||||
panic("not slice")
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
// addslashes() 函数返回在预定义字符之前添加反斜杠的字符串。
|
||||
// 预定义字符是:
|
||||
// 单引号(')
|
||||
// 双引号(")
|
||||
// 反斜杠(\)
|
||||
func Addslashes(str string) string {
|
||||
tmpRune := []rune{}
|
||||
strRune := []rune(str)
|
||||
for _, ch := range strRune {
|
||||
switch ch {
|
||||
case []rune{'\\'}[0], []rune{'"'}[0], []rune{'\''}[0]:
|
||||
tmpRune = append(tmpRune, []rune{'\\'}[0])
|
||||
tmpRune = append(tmpRune, ch)
|
||||
default:
|
||||
tmpRune = append(tmpRune, ch)
|
||||
}
|
||||
}
|
||||
return string(tmpRune)
|
||||
}
|
||||
|
||||
// stripslashes() 函数删除由 addslashes() 函数添加的反斜杠。
|
||||
func Stripslashes(str string) string {
|
||||
dstRune := []rune{}
|
||||
strRune := []rune(str)
|
||||
strLenth := len(strRune)
|
||||
for i := 0; i < strLenth; i++ {
|
||||
if strRune[i] == []rune{'\\'}[0] {
|
||||
i++
|
||||
}
|
||||
dstRune = append(dstRune, strRune[i])
|
||||
}
|
||||
return string(dstRune)
|
||||
}
|
Loading…
Reference in New Issue
Block a user