db/connection.go
2023-04-12 15:58:25 +08:00

163 lines
3.8 KiB
Go

package db
import (
"database/sql"
"errors"
)
// Connection bundles the handles and configuration for one named database
// connection. The statement helpers in this file (Select, AffectingStatement,
// BeginTransaction, Transaction) all execute against DB.
type Connection struct {
	// RxDB is not used for queries in this file; it is only copied into the
	// Rx field of transactions started from this connection.
	// NOTE(review): presumably a read/replica handle — confirm.
	// DB is the handle every statement on this connection is sent to.
	RxDB, DB *sql.DB
	// Config supplies the driver name and table prefix used by Table/Query.
	Config *DBConfig
	// ConnectionName is propagated to transactions created from this connection.
	ConnectionName string
}
// IConnection is the method set of a database connection: raw statement
// helpers plus an entry point into the query builder. *Connection in this
// package implements it.
type IConnection interface {
	// Select runs query with bindings and delivers the rows into dest.
	Select(query string, bindings []interface{}, dest interface{}) (sql.Result, error)

	// Insert, Update and Delete execute a data-modifying statement.
	Insert(query string, bindings []interface{}) (sql.Result, error)
	Update(query string, bindings []interface{}) (sql.Result, error)
	Delete(query string, bindings []interface{}) (sql.Result, error)

	// AffectingStatement and Statement execute an arbitrary statement.
	AffectingStatement(query string, bindings []interface{}) (sql.Result, error)
	Statement(query string, bindings []interface{}) (sql.Result, error)

	// Table returns a query builder whose FROM target is tableName.
	Table(tableName string) *Builder
}
type (
	// Preparer is anything that can compile a statement for later execution,
	// such as *sql.DB or *sql.Tx.
	Preparer interface {
		Prepare(q string) (*sql.Stmt, error)
	}

	// Execer is anything that can run a statement directly with positional
	// arguments, such as *sql.DB or *sql.Tx.
	Execer interface {
		Exec(q string, params ...interface{}) (sql.Result, error)
	}
)
type ITransaction interface {
BeginTransaction() (*Transaction, error)
Transaction(closure TxClosure) (interface{}, error)
}
func (c *Connection) Select(query string, bindings []interface{}, dest interface{}) (result sql.Result, err error) {
var stmt *sql.Stmt
var rows *sql.Rows
stmt, err = c.DB.Prepare(query)
if err != nil {
return
}
defer stmt.Close()
rows, err = stmt.Query(bindings...)
if err != nil {
return
}
defer rows.Close()
return ScanResult(rows, dest), nil
}
// BeginTransaction starts a transaction on c.DB and wraps it in a
// *Transaction that carries this connection's RxDB, Config and
// ConnectionName. The caller is responsible for ending the returned
// transaction.
//
// A Connection without a DB handle yields an error instead of a nil-pointer
// panic inside database/sql. Begin errors are returned unchanged, so callers
// can still match them with errors.Is / errors.As.
func (c *Connection) BeginTransaction() (*Transaction, error) {
	if c.DB == nil {
		return nil, errors.New("db: connection has no *sql.DB")
	}
	sqlTx, err := c.DB.Begin()
	if err != nil {
		// Returned as-is: re-creating the error from its message would
		// discard its identity (e.g. sql.ErrConnDone) and any wrapped cause.
		return nil, err
	}
	return &Transaction{
		Rx:             c.RxDB,
		Tx:             sqlTx,
		Config:         c.Config,
		ConnectionName: c.ConnectionName,
	}, nil
}
// Transaction starts a transaction on c.DB, runs closure with a *Transaction
// wrapping it, and returns closure's result.
//
// Outcome:
//   - closure returns a nil error: the transaction is committed. If Commit
//     fails, its error is returned together with a nil result.
//   - closure returns a non-nil error: the transaction is rolled back, and
//     the closure's result and error are returned unchanged.
//   - closure panics: the transaction is rolled back and the panic re-raised.
//
// If the transaction cannot be started, the Begin error is returned instead
// of panicking.
func (c *Connection) Transaction(closure TxClosure) (result interface{}, err error) {
	sqlTx, err := c.DB.Begin()
	if err != nil {
		return nil, err
	}

	// The deferred func reads and writes the named results, so it sees what
	// closure returned and can report a Commit failure to the caller.
	defer func() {
		if p := recover(); p != nil {
			// Best effort: the panic is the failure the caller must observe,
			// so it is re-raised whatever Rollback returns.
			_ = sqlTx.Rollback()
			panic(p)
		}
		if err != nil {
			// Keep the closure's error; a Rollback error would only mask it.
			_ = sqlTx.Rollback()
			return
		}
		if err = sqlTx.Commit(); err != nil {
			result = nil
		}
	}()

	tx := &Transaction{
		Rx:             c.RxDB,
		Tx:             sqlTx,
		Config:         c.Config,
		ConnectionName: c.ConnectionName,
	}
	return closure(tx)
}
// Insert runs query through AffectingStatement unchanged; the separate name
// only documents intent at the call site.
func (c *Connection) Insert(query string, bindings []interface{}) (result sql.Result, err error) {
	result, err = c.AffectingStatement(query, bindings)
	return result, err
}

// Update runs query through AffectingStatement unchanged; the separate name
// only documents intent at the call site.
func (c *Connection) Update(query string, bindings []interface{}) (result sql.Result, err error) {
	result, err = c.AffectingStatement(query, bindings)
	return result, err
}

// Delete runs query through AffectingStatement unchanged; the separate name
// only documents intent at the call site.
func (c *Connection) Delete(query string, bindings []interface{}) (result sql.Result, err error) {
	result, err = c.AffectingStatement(query, bindings)
	return result, err
}
func (c *Connection) AffectingStatement(query string, bindings []interface{}) (result sql.Result, err error) {
stmt, errP := c.DB.Prepare(query)
if errP != nil {
err = errP
return
}
defer stmt.Close()
result, err = stmt.Exec(bindings...)
if err != nil {
return
}
return
}
// Table returns a new query builder bound to c whose FROM target is
// tableName. The grammar is selected from the configured driver (MySQL or
// SQLite3), given the configured table prefix, and linked back to the builder.
//
// It panics for any other driver; the message means "unsupported database
// type".
func (c *Connection) Table(tableName string) *Builder {
	b := NewBuilder(c)

	switch c.GetConfig().Driver {
	case DriverMysql:
		b.Grammar = &MysqlGrammar{}
	case DriverSqlite3:
		b.Grammar = &Sqlite3Grammar{}
	default:
		panic("不支持的数据库类型")
	}

	// Prefix first, then the builder back-reference, matching Query's
	// counterpart only in effect, not order — keep this order as-is.
	b.Grammar.SetTablePrefix(c.Config.Prefix)
	b.Grammar.SetBuilder(b)
	b.From(tableName)
	return b
}
// Statement executes an arbitrary statement and is equivalent to calling
// AffectingStatement with the same arguments.
func (c *Connection) Statement(query string, bindings []interface{}) (sql.Result, error) {
	res, err := c.AffectingStatement(query, bindings)
	return res, err
}
// GetDB exposes the *sql.DB handle this connection sends its statements to.
func (c *Connection) GetDB() *sql.DB {
	handle := c.DB
	return handle
}
// Query returns a new query builder bound to c with no table selected yet.
// The grammar is selected from the configured driver (MySQL or SQLite3),
// linked back to the builder, and given the configured table prefix.
//
// It panics for any other driver; the message means "unsupported database
// type".
func (c *Connection) Query() *Builder {
	qb := NewBuilder(c)

	switch c.GetConfig().Driver {
	case DriverMysql:
		qb.Grammar = &MysqlGrammar{}
	case DriverSqlite3:
		qb.Grammar = &Sqlite3Grammar{}
	default:
		panic("不支持的数据库类型")
	}

	// Builder back-reference first, then the prefix; keep this order as-is.
	qb.Grammar.SetBuilder(qb)
	qb.Grammar.SetTablePrefix(c.Config.Prefix)
	return qb
}
// GetConfig exposes the configuration (driver, table prefix, ...) this
// connection was created with.
func (c *Connection) GetConfig() *DBConfig {
	cfg := c.Config
	return cfg
}