From 59b3faa171bc671c02667070513d8db2b4a8abef Mon Sep 17 00:00:00 2001 From: what-00 Date: Wed, 12 Apr 2023 15:58:25 +0800 Subject: [PATCH] first commit --- .gitignore | 23 + README.md | 98 ++ builder.go | 2506 +++++++++++++++++++++++++++++++++++ config.go | 34 + connection.go | 162 +++ connection_factory.go | 45 + db.go | 58 + db_manager.go | 108 ++ go.mod | 11 + go.sum | 8 + gramar.go | 34 + mysql_connector.go | 103 ++ mysql_grammar.go | 569 ++++++++ raw.go | 7 + scan.go | 279 ++++ scan_test.go | 73 + schema/blueprint.go | 456 +++++++ schema/blueprint_test.go | 216 +++ schema/builder.go | 128 ++ schema/builder_test.go | 195 +++ schema/column_definition.go | 115 ++ schema/grammar.go | 44 + schema/mysql_grammar.go | 338 +++++ schema/sqlite3_grammar.go | 351 +++++ schema/util.go | 25 + sqlite3_connector.go | 50 + sqlite3_grammar.go | 5 + transaction.go | 96 ++ types.go | 43 + util.go | 77 ++ 30 files changed, 6257 insertions(+) create mode 100644 .gitignore create mode 100644 README.md create mode 100644 builder.go create mode 100644 config.go create mode 100644 connection.go create mode 100644 connection_factory.go create mode 100644 db.go create mode 100644 db_manager.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 gramar.go create mode 100644 mysql_connector.go create mode 100644 mysql_grammar.go create mode 100644 raw.go create mode 100644 scan.go create mode 100644 scan_test.go create mode 100644 schema/blueprint.go create mode 100644 schema/blueprint_test.go create mode 100644 schema/builder.go create mode 100644 schema/builder_test.go create mode 100644 schema/column_definition.go create mode 100644 schema/grammar.go create mode 100644 schema/mysql_grammar.go create mode 100644 schema/sqlite3_grammar.go create mode 100644 schema/util.go create mode 100644 sqlite3_connector.go create mode 100644 sqlite3_grammar.go create mode 100644 transaction.go create mode 100644 types.go create mode 100644 util.go diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..adf8f72 --- /dev/null +++ b/.gitignore @@ -0,0 +1,23 @@ +# ---> Go +# If you prefer the allow list template instead of the deny list, see community template: +# https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore +# +# Binaries for programs and plugins +*.exe +*.exe~ +*.dll +*.so +*.dylib + +# Test binary, built with `go test -c` +*.test + +# Output of the go coverage tool, specifically when used with LiteIDE +*.out + +# Dependency directories (remove the comment below to include it) +# vendor/ + +# Go workspace file +go.work + diff --git a/README.md b/README.md new file mode 100644 index 0000000..3806619 --- /dev/null +++ b/README.md @@ -0,0 +1,98 @@ +# WORK IN PROGRESS +# Get Started +A golang ORM Framework like Laravel's Eloquent + +Example + +```golang +type User struct { + Id int64 `goelo:"column:id;primaryKey"` + UserName sql.NullString `goelo:"column:name"` + Age int `goelo:"column:age"` + Status int `goelo:"column:status"` + Friends []UserT `goelo:"BelongsToMany:FriendsRelation"` + Address Address `goelo:"HasOne:AddressRelation"` + CreatedAt time.Time `goelo:"column:created_at,timestatmp:create"` + UpdatedAt sql.NullTime `goelo:"column:updated_at,timestatmp:update"` +} + +//Find/First/Get +var user User +DB.Table("users").Find(&user,1) +DB.Table("users").Where("name","john").First(&user) +DB.Table("users").Where("name","john").FirstOrCreate(&user) + +//Column/Aggeregate +var age int +DB.Table("users").Where("id","=",20).Value(&age,"age") +DB.Table("users").Max(&age,"age") +var salary int +DB.Table("user").Where("age",">=",30).Avg(&salary,"salary") + +//Find record to map +var m = make(map[string]interface{}) +DB.Query().From("users").Find(&m,3) +var ms []map[string]interface{} +DB.Query().From("users").Get(&ms) + +//Pagination +var users []User +DB.Table("users").Where("id", ">", 10).Where("id", "<", 28).Paginate(&users, 10, 1) + +//Chunk/ChunkById +var total int +totalP := &total +DB.Table("users").OrderBy("id").Chunk(&[]User{}, 10, func(dest interface{}) error { + us := dest.(*[]User) + for _, user := range *us { + assert.Equal(t, user.UserName, sql.NullString{ + String: fmt.Sprintf("user-%d", user.Age), + Valid: true, + }) + *totalP++ + } + return nil +}) + +var total int +totalP := &total +DB.Table("users").ChunkById(&[]User{}, 10, func(dest interface{}) error { + us := dest.(*[]User) + for _, user := range *us { + assert.Equal(t, user.UserName, sql.NullString{ + String: fmt.Sprintf("user-%d", user.Age), + Valid: true, + }) + *totalP++ + } + return nil +}) +//Query clause +DB.Where("name","john@apple.com").OrWhere("email","john@apple.com").First(&user) + +DB.Where("is_admin", 1).Where(map[string]interface{}{ + "name": "Joe", "location": "LA", +}, db.BOOLEAN_OR).Where(func(builder *db.Builder){ + builder.WhereYear("created_at", "<", 2010).WhereColumn("first_name", "last_name").OrWhereNull("activited_at") +}).ToSql() +// sql:"select `name`, `age`, `email` where `is_admin` = ? or (`name` = ? and `location` = ?) and (year(`created_at`) < ? and `first_name` = `last_name` or `activited_at` is null)" + +// Insert/Update/Delete +DB.Table("users").Insert(map[string]interface{}{ + "name": "go-eloquent", + "age": 18, + "created_at": now, +}) +DB.Table("users").Where("id", id).Update(map[string]interface{}{ + "name": "newname", + "age": 18, + "updated_at": now.Add(time.Hour * 1), +}) +DB.Table("users").Where("id", id).Delete() + +``` +More Details,visit [Docs](https://glitterlip.github.io/go-eloquent-docs/) +# Credits +[https://github.com/go-gorm/gorm](https://github.com/go-gorm/gorm) +[https://github.com/jmoiron/sqlx](https://github.com/jmoiron/sqlx) +[https://github.com/qclaogui/database](https://github.com/qclaogui/database) diff --git a/builder.go b/builder.go new file mode 100644 index 0000000..40fbffa --- /dev/null +++ b/builder.go @@ -0,0 +1,2506 @@ +package db + +import ( + "database/sql" + "errors" + "fmt" + "log" + "math" + "reflect" + "strings" + "time" +) + +var Operators = []string{ + "=", "<", ">", "<=", ">=", "<>", "!=", "<=>", + "like", "like binary", "not like", "ilike", + "&", "|", "^", "<<", ">>", + "rlike", "regexp", "not regexp", + "~", "~*", "!~", "!~*", "similar to", + "not similar to", "not ilike", "~~*", "!~~*"} + +var SelectComponents = []string{ + TYPE_AGGREGRATE, + TYPE_COLUMN, + TYPE_FROM, + TYPE_JOIN, + TYPE_WHERE, + TYPE_GROUP_BY, + TYPE_HAVING, + TYPE_ORDER, + TYPE_LIMIT, + TYPE_OFFSET, + TYPE_UNION, + TYPE_LOCK, +} + +var Bindings = map[string]struct{}{ + TYPE_SELECT: {}, + TYPE_FROM: {}, + TYPE_JOIN: {}, + TYPE_UPDATE: {}, + TYPE_WHERE: {}, + TYPE_GROUP_BY: {}, + TYPE_HAVING: {}, + TYPE_ORDER: {}, + TYPE_UNION: {}, + TYPE_UNION_ORDER: {}, + TYPE_INSERT: {}, +} +var BindingKeysInOrder = []string{TYPE_SELECT, TYPE_FROM, TYPE_JOIN, TYPE_UPDATE, TYPE_WHERE, TYPE_GROUP_BY, TYPE_HAVING, TYPE_ORDER, TYPE_UNION, TYPE_UNION_ORDER, TYPE_INSERT} + +type Builder struct { + Connection *Connection + Tx *Transaction + Grammar IGrammar + //Processor processors.IProcessor + PreSql strings.Builder + Bindings map[string][]interface{} //available options (select,from,join,where,groupBy,having,order,union,unionOrder) + FromTable interface{} + //TablePrefix string + TableAlias string + Wheres []Where + Aggregates []Aggregate + Columns []interface{} // The columns that should be returned. + IsDistinct bool // Indicates if the query returns distinct results. + DistinctColumns []string // distinct columns. + Joins []*Builder + Groups []interface{} + Havings []Having + Orders []Order + LimitNum int + OffsetNum int + //Unions []Where + //UnionLimit int + //UnionOffset int + //UnionOrders int + Components map[string]struct{} //SelectComponents + LockMode interface{} + LoggingQueries bool + Pretending bool + PreparedSql string + Dest interface{} + Pivots []string + PivotWheres []Where + OnlyColumns map[string]interface{} + ExceptColumns map[string]interface{} + JoinBuilder bool + JoinType string + JoinTable interface{} + UseWrite bool //TODO: + BeforeQueryCallBacks []func(builder *Builder, t string, data ...map[string]any) + AfterQueryCallBacks []func(builder *Builder, t string, res sql.Result, err error, data ...map[string]any) + RemovedScopes map[string]struct{} +} + +type Log struct { + SQL string + Bindings []interface{} + Result sql.Result + Time time.Duration +} + +type ScopeFunc func(builder *Builder) *Builder + +const ( + CONDITION_TYPE_BASIC = "basic" + CONDITION_TYPE_COLUMN = "column" + CONDITION_TYPE_RAW = "raw" + CONDITION_TYPE_IN = "in" + CONDITION_TYPE_NOT_IN = "not in" + CONDITION_TYPE_NULL = "null" + CONDITION_TYPE_BETWEEN = "between" + CONDITION_TYPE_BETWEEN_COLUMN = "between column" + CONDITION_TYPE_NOT_BETWEEN = "not between" + CONDITION_TYPE_DATE = "date" + CONDITION_TYPE_TIME = "time" + CONDITION_TYPE_DATETIME = "datetime" + CONDITION_TYPE_DAY = "day" + CONDITION_TYPE_MONTH = "month" + CONDITION_TYPE_YEAR = "year" + CONDITION_TYPE_CLOSURE = "closure" //todo + CONDITION_TYPE_NESTED = "nested" + CONDITION_TYPE_SUB = "subquery" + CONDITION_TYPE_EXIST = "exist" + CONDITION_TYPE_NOT_EXIST = "not exist" + CONDITION_TYPE_ROW_VALUES = "rowValues" + BOOLEAN_AND = "and" + BOOLEAN_OR = "or" + CONDITION_JOIN_NOT = "not" //todo + JOIN_TYPE_LEFT = "left" + JOIN_TYPE_RIGHT = "right" + JOIN_TYPE_INNER = "inner" + JOIN_TYPE_CROSS = "cross" + ORDER_ASC = "asc" + ORDER_DESC = "desc" + TYPE_SELECT = "select" + TYPE_FROM = "from" + TYPE_JOIN = "join" + TYPE_WHERE = "where" + TYPE_GROUP_BY = "groupBy" + TYPE_HAVING = "having" + TYPE_ORDER = "order" + TYPE_UNION = "union" + TYPE_UNION_ORDER = "unionOrder" + TYPE_COLUMN = "column" + TYPE_AGGREGRATE = "aggregrate" + TYPE_OFFSET = "offset" + TYPE_LIMIT = "limit" + TYPE_LOCK = "lock" + TYPE_INSERT = "insert" + TYPE_UPDATE = "update" + TYPE_DELETE = "delete" +) + +type Aggregate struct { + AggregateName string + AggregateColumn string +} + +type Order struct { + OrderType string + Direction string + Column string + RawSql interface{} +} +type Having struct { + HavingType string + HavingColumn string + HavingOperator string + HavingValue interface{} + HavingBoolean string + RawSql interface{} + Not bool +} +type Where struct { + Type string + Column string + Columns []string + Operator string + FirstColumn string + SecondColumn string + RawSql interface{} + Value interface{} + Values []interface{} + Boolean string + Not bool //not in,not between,not null + Query *Builder +} + +func NewBuilder(c *Connection) *Builder { + b := Builder{ + Connection: c, + Components: make(map[string]struct{}), + //Processor: processors.MysqlProcessor{}, + Bindings: make(map[string][]interface{}), + RemovedScopes: make(map[string]struct{}), + LoggingQueries: c.Config.EnableLog, + } + return &b +} +func NewTxBuilder(tx *Transaction) *Builder { + b := Builder{ + Components: make(map[string]struct{}), + Tx: tx, + Bindings: make(map[string][]interface{}), + LoggingQueries: tx.Config.EnableLog, + RemovedScopes: make(map[string]struct{}), + } + return &b +} + +/* +CloneBuilderWithTable clone builder use same connection and table +*/ +func CloneBuilderWithTable(b *Builder) *Builder { + cb := Builder{ + Connection: b.Connection, + Components: make(map[string]struct{}), + Tx: b.Tx, + Bindings: make(map[string][]interface{}), + LoggingQueries: b.LoggingQueries, + } + if b.Connection.Config.Driver == DriverMysql { + cb.Grammar = &MysqlGrammar{} + } else if b.Connection.Config.Driver == DriverSqlite3 { + cb.Grammar = &Sqlite3Grammar{} + } else { + panic("不支持的数据库类型") + } + cb.Grammar.SetTablePrefix(b.Grammar.GetTablePrefix()) + cb.Grammar.SetBuilder(&cb) + return &cb +} +func MergeBuilder(b *Builder, builder *Builder) *Builder { + cb := Builder{ + Connection: b.Connection, + Components: make(map[string]struct{}), + Tx: b.Tx, + Bindings: make(map[string][]interface{}), + LoggingQueries: b.LoggingQueries, + } + if b.Connection.Config.Driver == DriverMysql { + cb.Grammar = &MysqlGrammar{} + } else if b.Connection.Config.Driver == DriverSqlite3 { + cb.Grammar = &Sqlite3Grammar{} + } else { + panic("不支持的数据库类型") + } + cb.Grammar.SetTablePrefix(b.Grammar.GetTablePrefix()) + cb.Grammar.SetBuilder(&cb) + return &cb +} + +/* +Clone Clone the query. +*/ +func Clone(original *Builder) *Builder { + newBuilder := Builder{ + Connection: original.Connection, + Tx: original.Tx, + PreSql: strings.Builder{}, + Bindings: make(map[string][]interface{}, len(original.Bindings)), + FromTable: original.FromTable, + TableAlias: original.TableAlias, + Wheres: make([]Where, len(original.Wheres)), + Aggregates: make([]Aggregate, len(original.Aggregates)), + Columns: make([]interface{}, len(original.Columns)), + IsDistinct: original.IsDistinct, + DistinctColumns: make([]string, len(original.DistinctColumns)), + Joins: []*Builder{}, + Groups: make([]interface{}, len(original.Groups)), + Havings: make([]Having, len(original.Havings)), + Orders: make([]Order, len(original.Orders)), + LimitNum: original.LimitNum, + OffsetNum: original.OffsetNum, + Components: make(map[string]struct{}, len(original.Components)), + LockMode: original.LockMode, + LoggingQueries: original.LoggingQueries, + Pretending: original.Pretending, + PreparedSql: "", + Dest: nil, + Pivots: make([]string, len(original.Pivots)), + PivotWheres: make([]Where, len(original.PivotWheres)), + OnlyColumns: make(map[string]interface{}, len(original.OnlyColumns)), + ExceptColumns: make(map[string]interface{}, len(original.ExceptColumns)), + BeforeQueryCallBacks: make([]func(builder *Builder, t string, data ...map[string]any), len(original.BeforeQueryCallBacks)), + AfterQueryCallBacks: make([]func(builder *Builder, t string, res sql.Result, err error, data ...map[string]any), len(original.AfterQueryCallBacks)), + JoinBuilder: original.JoinBuilder, + JoinType: original.JoinType, + JoinTable: original.JoinTable, + } + for key, _ := range original.Bindings { + newBuilder.Bindings[key] = make([]interface{}, len(original.Bindings[key])) + copy(newBuilder.Bindings[key], original.Bindings[key]) + } + copy(newBuilder.Wheres, original.Wheres) + copy(newBuilder.Aggregates, original.Aggregates) + copy(newBuilder.Columns, original.Columns) + copy(newBuilder.DistinctColumns, original.DistinctColumns) + copy(newBuilder.Groups, original.Groups) + copy(newBuilder.Havings, original.Havings) + copy(newBuilder.Orders, original.Orders) + copy(newBuilder.Pivots, original.Pivots) + copy(newBuilder.PivotWheres, original.PivotWheres) + copy(newBuilder.BeforeQueryCallBacks, original.BeforeQueryCallBacks) + copy(newBuilder.AfterQueryCallBacks, original.AfterQueryCallBacks) + for key, _ := range original.Components { + newBuilder.Components[key] = original.Components[key] + } + for key, _ := range original.OnlyColumns { + newBuilder.OnlyColumns[key] = original.OnlyColumns[key] + } + for key, _ := range original.ExceptColumns { + newBuilder.ExceptColumns[key] = original.ExceptColumns[key] + } + for _, join := range original.Joins { + newBuilder.Joins = append(newBuilder.Joins, join.Clone()) //TODO: add tests, + } + if original.Connection.Config.Driver == DriverMysql { + newBuilder.Grammar = &MysqlGrammar{ + Prefix: original.Grammar.GetTablePrefix(), + Builder: &newBuilder, + } + } else if original.Connection.Config.Driver == DriverSqlite3 { + newBuilder.Grammar = &Sqlite3Grammar{ + MysqlGrammar: &MysqlGrammar{ + Prefix: original.Grammar.GetTablePrefix(), + Builder: &newBuilder, + }, + } + } else { + panic("不支持的数据库类型") + } + + return &newBuilder +} +func (b *Builder) Clone() *Builder { + return Clone(b) +} + +/*CloneWithout +CloneWithoutClone the query without the given properties. +*/ +func CloneWithout(original *Builder, without ...string) *Builder { + b := Clone(original) + b.Reset(without...) + return b +} +func (b *Builder) CloneWithout(without ...string) *Builder { + return CloneWithout(b, without...) +} + +/* +CloneWithoutBindings Clone the query without the given bindings. +*/ +func CloneWithoutBindings(original *Builder, bindings ...string) *Builder { + b := Clone(original) + for _, binding := range bindings { + b.Bindings[binding] = nil + } + return b +} +func (b *Builder) CloneWithoutBindings(bindings ...string) *Builder { + return CloneWithoutBindings(b, bindings...) +} + +// Select set the columns to be selected +func (b *Builder) Select(columns ...interface{}) *Builder { + b.Components[TYPE_COLUMN] = struct{}{} + + for i := 0; i < len(columns); i++ { + switch columnType := columns[i].(type) { + case string: + b.Columns = append(b.Columns, columnType) + case map[string]interface{}: + for as, q := range columnType { + switch q.(type) { + case func(builder *Builder): + b.SelectSub(q, as) + case *Builder: + b.SelectSub(q, as) + case Expression: + b.AddSelect(q) + case string: + b.Columns = append(b.Columns, q) + default: + panic(errors.New("unsupported type for select")) + } + } + case Expression: + b.AddSelect(columnType) + } + } + return b +} + +//SelectSub Add a subselect expression to the query. +func (b *Builder) SelectSub(query interface{}, as string) *Builder { + qStr, bindings := b.CreateSub(query) + queryStr := fmt.Sprintf("( %s ) as %s", qStr, b.Grammar.Wrap(as)) + + return b.SelectRaw(queryStr, bindings) + +} + +// AddSelect Add a new select column to the query +// 1. slice of string +// 2. map[string]{"alias"} +func (b *Builder) AddSelect(columns ...interface{}) *Builder { + b.Components[TYPE_COLUMN] = struct{}{} + + for i := 0; i < len(columns); i++ { + switch columnType := columns[i].(type) { + case string: + b.Columns = append(b.Columns, columnType) + case map[string]interface{}: + for as, q := range columnType { + b.SelectSub(q, as) + } + case Expression: + b.Columns = append(b.Columns, columnType) + } + } + return b +} + +// SelectRaw Add a new "raw" select expression to the query. +func (b *Builder) SelectRaw(expression string, bindings ...[]interface{}) *Builder { + b.AddSelect(Expression(expression)) + if len(bindings) > 0 { + b.AddBinding(bindings[0], TYPE_SELECT) + } + return b +} + +//CreateSub Creates a subquery and parse it. +func (b *Builder) CreateSub(query interface{}) (string, []interface{}) { + var builder *Builder + if bT, ok := query.(*Builder); ok { + builder = bT + } else if function, ok := query.(func(builder *Builder)); ok { + builder = CloneBuilderWithTable(b) + function(builder) + } else if str, ok := query.(string); ok { + return b.ParseSub(str) + } else { + panic("can not create sub") + } + return b.ParseSub(builder) +} + +/* +ParseSub Parse the subquery into SQL and bindings. +*/ +func (b *Builder) ParseSub(query interface{}) (string, []interface{}) { + if s, ok := query.(string); ok { + return s, []interface{}{} + } else if builder, ok := query.(*Builder); ok { + return builder.ToSql(), builder.GetBindings() + } + panic("A subquery must be a query builder instance, a Closure, or a string.") +} + +/* +ToSql Get the SQL representation of the query. +*/ +func (b *Builder) ToSql() string { + b.ApplyBeforeQueryCallBacks(TYPE_SELECT) + if len(b.PreparedSql) > 0 { + b.PreparedSql = "" + } + return b.Grammar.CompileSelect() +} + +/* +Distinct Force the query to only return distinct results. +*/ +func (b *Builder) Distinct(distinct ...string) *Builder { + b.IsDistinct = true + if len(distinct) > 0 { + b.DistinctColumns = append(b.DistinctColumns, distinct...) + } + return b +} + +/* +IsQueryable Determine if the value is a query builder instance or a Closure. +*/ +func IsQueryable(value interface{}) bool { + switch value.(type) { + case Builder, *Builder: + return true + case func(builder *Builder): + return true + default: + return false + } +} + +/* +Table Begin a fluent query against a database table. +*/ +func (b *Builder) Table(params ...string) *Builder { + if len(params) > 1 && params[1] != "" { + return b.From(params[0], params[1]) + } + + return b.From(params[0]) +} + +/* +From Set the table which the query is targeting. +*/ +func (b *Builder) From(table interface{}, params ...string) *Builder { + if IsQueryable(table) { + return b.FromSub(table, params[0]) + } + b.Components[TYPE_FROM] = struct{}{} + if len(params) == 1 { + b.TableAlias = params[0] + b.FromTable = fmt.Sprintf("%s as %s", table, params[0]) + } else { + b.FromTable = table.(string) + } + return b +} + +/* +FromSub Makes "from" fetch from a subquery. +*/ +func (b *Builder) FromSub(table interface{}, as string) *Builder { + qStr, bindings := b.CreateSub(table) + queryStr := fmt.Sprintf("(%s) as %s", qStr, b.Grammar.WrapTable(as)) + + return b.FromRaw(queryStr, bindings) +} + +/* +FromRaw Add a raw from clause to the query. +*/ +func (b *Builder) FromRaw(raw interface{}, bindings ...[]interface{}) *Builder { + var expression Expression + if str, ok := raw.(string); ok { + expression = Expression(str) + } else { + expression = raw.(Expression) + } + b.FromTable = expression + b.Components[TYPE_FROM] = struct{}{} + if len(bindings) > 0 { + b.AddBinding(bindings[0], TYPE_FROM) + } + return b +} + +/* +Join Add a join clause to the query. +*/ +func (b *Builder) Join(table string, first interface{}, params ...interface{}) *Builder { + //$table, $first, $operator = null, $second = null, $type = 'inner', $where = false + + var second string + var isWhere = false + var operator = "=" + var joinType = JOIN_TYPE_INNER + + length := len(params) + + switch length { + case 0: + if _, ok := first.(func(builder *Builder)); !ok { + panic(errors.New("arguements num mismatch")) + } + case 1: + if _, ok := first.(func(builder *Builder)); ok { + joinType = params[0].(string) + } else { + second = params[0].(string) + } + case 2: + operator = params[0].(string) + second = params[1].(string) + case 3: + operator = params[0].(string) + second = params[1].(string) + joinType = params[2].(string) + case 4: + operator = params[0].(string) + second = params[1].(string) + joinType = params[2].(string) + isWhere = params[3].(bool) + } + return b.join(table, first, operator, second, joinType, isWhere) +} + +/* +RightJoin Add a right join to the query. +*/ +func (b *Builder) RightJoin(table string, firstColumn interface{}, params ...interface{}) *Builder { + + var operator, second string + joinType := JOIN_TYPE_RIGHT + length := len(params) + switch length { + case 0: + if function, ok := firstColumn.(func(builder *Builder)); ok { + clause := NewJoin(b, joinType, table) + function(clause) + b.Components[TYPE_JOIN] = struct{}{} + b.Joins = append(b.Joins, clause) + b.AddBinding(clause.GetBindings(), TYPE_JOIN) + return b + } else { + panic(errors.New("arguements num mismatch")) + } + case 1: + operator = "=" + second = params[0].(string) + case 2: + operator = params[0].(string) + second = params[1].(string) + + } + return b.join(table, firstColumn, operator, second, joinType, false) +} + +/* +LeftJoin Add a left join to the query. +*/ +func (b *Builder) LeftJoin(table string, firstColumn interface{}, params ...interface{}) *Builder { + var operator, second string + joinType := JOIN_TYPE_LEFT + length := len(params) + switch length { + case 0: + if function, ok := firstColumn.(func(builder *Builder)); ok { + clause := NewJoin(b, joinType, table) + function(clause) + b.Components[TYPE_JOIN] = struct{}{} + b.Joins = append(b.Joins, clause) + b.AddBinding(clause.GetBindings(), TYPE_JOIN) + return b + } else { + panic(errors.New("arguements num mismatch")) + } + case 1: + operator = "=" + second = params[0].(string) + case 2: + operator = params[0].(string) + second = params[1].(string) + } + return b.join(table, firstColumn, operator, second, joinType, false) +} + +/* +LeftJoinWhere Add a "join where" clause to the query. +*/ +func (b *Builder) LeftJoinWhere(table, firstColumn, joinOperator, secondColumn string) *Builder { + return b.joinWhere(table, firstColumn, joinOperator, secondColumn, JOIN_TYPE_LEFT) +} + +/* +RightJoinWhere Add a "right join where" clause to the query. +*/ +func (b *Builder) RightJoinWhere(table, firstColumn, joinOperator, secondColumn string) *Builder { + return b.joinWhere(table, firstColumn, joinOperator, secondColumn, JOIN_TYPE_RIGHT) +} +func NewJoin(builder *Builder, joinType string, table interface{}) *Builder { + cb := CloneBuilderWithTable(builder) + cb.JoinBuilder = true + cb.JoinType = joinType + cb.JoinTable = table + + return cb +} +func (b *Builder) On(first interface{}, params ...interface{}) *Builder { + var second string + boolean := BOOLEAN_AND + operator := "=" + switch len(params) { + case 0: + if function, ok := first.(func(builder *Builder)); ok { + b.WhereNested(function, boolean) + return b + } + panic(errors.New("arguements mismatch")) + case 1: + second = params[0].(string) + case 2: + operator = params[0].(string) + second = params[1].(string) + case 3: + operator = params[0].(string) + second = params[1].(string) + boolean = params[2].(string) + } + + b.WhereColumn(first.(string), operator, second, boolean) + return b + +} +func (b *Builder) OrOn(first interface{}, params ...interface{}) *Builder { + var second string + boolean := BOOLEAN_OR + operator := "=" + switch len(params) { + case 0: + if function, ok := first.(func(builder *Builder)); ok { + b.WhereNested(function, boolean) + return b + } + panic(errors.New("arguements mismatch")) + case 1: + second = params[0].(string) + return b.On(first, operator, second, boolean) + case 2: + operator = params[0].(string) + second = params[1].(string) + return b.On(first, operator, second, boolean) + } + panic(errors.New("arguements mismatch")) +} + +/* +CrossJoin Add a "cross join" clause to the query. +*/ +func (b *Builder) CrossJoin(table string, params ...interface{}) *Builder { + var operator, first, second string + joinType := JOIN_TYPE_CROSS + length := len(params) + switch length { + case 0: + clause := NewJoin(b, joinType, table) + b.Joins = append(b.Joins, clause) + b.Components[TYPE_JOIN] = struct{}{} + + return b + case 1: + if function, ok := params[0].(func(builder *Builder)); ok { + clause := NewJoin(b, joinType, table) + function(clause) + b.Joins = append(b.Joins, clause) + b.AddBinding(clause.GetBindings(), TYPE_JOIN) + b.Components[TYPE_JOIN] = struct{}{} + + return b + } else { + panic(errors.New("cross join arguements mismatch")) + } + case 2: + first = params[0].(string) + operator = "=" + second = params[1].(string) + case 3: + first = params[0].(string) + operator = params[1].(string) + second = params[2].(string) + + } + return b.join(table, first, operator, second, joinType, false) +} + +/* +CrossJoinSub Add a subquery cross join to the query. +*/ +func (b *Builder) CrossJoinSub(query interface{}, as string) *Builder { + queryStr, bindings := b.CreateSub(query) + expr := fmt.Sprintf("(%s) as %s", queryStr, b.Grammar.WrapTable(as)) + b.AddBinding(bindings, TYPE_JOIN) + clause := NewJoin(b, JOIN_TYPE_CROSS, Raw(expr)) + b.Joins = append(b.Joins, clause) + b.Components[TYPE_JOIN] = struct{}{} + + return b +} + +/* +join Add a join clause to the query. +*/ +func (b *Builder) join(table, first, operator, second, joinType, isWhere interface{}) *Builder { + //$table, $first, $operator = null, $second = null, $type = 'inner', $where = false + b.Components[TYPE_JOIN] = struct{}{} + + if function, ok := first.(func(builder *Builder)); ok { + clause := NewJoin(b, joinType.(string), table) + clause.Grammar.SetTablePrefix(b.Grammar.GetTablePrefix()) + function(clause) + b.Joins = append(b.Joins, clause) + b.AddBinding(clause.GetBindings(), TYPE_JOIN) + return b + } + + clause := NewJoin(b, joinType.(string), table) + + if isWhere.(bool) { + clause.Where(first, operator, second) + } else { + clause.On(first, operator, second, BOOLEAN_AND) + } + b.AddBinding(clause.GetBindings(), TYPE_JOIN) + + clause.Grammar.SetTablePrefix(b.Grammar.GetTablePrefix()) + + b.Joins = append(b.Joins, clause) + + return b +} + +/* +joinWhere Add a "join where" clause to the query. +*/ +func (b *Builder) joinWhere(table, firstColumn, joinOperator, secondColumn, joinType string) *Builder { + return b.join(table, firstColumn, joinOperator, secondColumn, joinType, true) +} + +/* +JoinWhere Add a "join where" clause to the query. +*/ +func (b *Builder) JoinWhere(table, firstColumn, joinOperator, secondColumn string) *Builder { + return b.joinWhere(table, firstColumn, joinOperator, secondColumn, JOIN_TYPE_INNER) +} + +/* +JoinSub Add a subquery join clause to the query. +*/ +func (b *Builder) JoinSub(query interface{}, as string, first interface{}, params ...interface{}) *Builder { + queryStr, bindings := b.CreateSub(query) + expr := fmt.Sprintf("(%s) as %s", queryStr, b.Grammar.WrapTable(as)) + var operator string + joinType := JOIN_TYPE_INNER + var isWhere = false + var second interface{} + switch len(params) { + case 1: + operator = "=" + second = params[0] + case 2: + operator = params[0].(string) + second = params[1] + case 3: + operator = params[0].(string) + second = params[1] + joinType = params[2].(string) + case 4: + operator = params[0].(string) + second = params[1] + joinType = params[2].(string) + isWhere = params[3].(bool) + } + b.AddBinding(bindings, TYPE_JOIN) + + return b.join(Raw(expr), first, operator, second, joinType, isWhere) +} + +/* +LeftJoinSub Add a subquery left join to the query. +*/ +func (b *Builder) LeftJoinSub(query interface{}, as string, first interface{}, params ...interface{}) *Builder { + queryStr, bindings := b.CreateSub(query) + expr := fmt.Sprintf("(%s) as %s", queryStr, b.Grammar.WrapTable(as)) + var operator string + joinType := JOIN_TYPE_LEFT + var second interface{} + switch len(params) { + case 1: + operator = "=" + second = params[0] + case 2: + operator = params[0].(string) + second = params[1] + + } + b.AddBinding(bindings, TYPE_JOIN) + + return b.join(Raw(expr), first, operator, second, joinType, false) +} + +/* +RightJoinSub Add a subquery right join to the query. +*/ +func (b *Builder) RightJoinSub(query interface{}, as string, first interface{}, params ...interface{}) *Builder { + queryStr, bindings := b.CreateSub(query) + expr := fmt.Sprintf("(%s) as %s", queryStr, b.Grammar.WrapTable(as)) + var operator string + joinType := JOIN_TYPE_RIGHT + var second interface{} + switch len(params) { + case 1: + operator = "=" + second = params[0] + case 2: + operator = params[0].(string) + second = params[1] + + } + b.AddBinding(bindings, TYPE_JOIN) + + return b.join(Raw(expr), first, operator, second, joinType, false) +} + +//AddBinding Add a binding to the query. +func (b *Builder) AddBinding(value []interface{}, bindingType string) *Builder { + if _, ok := Bindings[bindingType]; !ok { + log.Panicf("invalid binding type:%s\n", bindingType) + } + var tv []interface{} + for _, v := range value { + if _, ok := v.(Expression); !ok { + tv = append(tv, v) + } + } + b.Bindings[bindingType] = append(b.Bindings[bindingType], tv...) + return b +} + +//GetBindings Get the current query value bindings in a flattened slice. +func (b *Builder) GetBindings() (res []interface{}) { + for _, key := range BindingKeysInOrder { + if bindings, ok := b.Bindings[key]; ok { + res = append(res, bindings...) + } + } + return +} + +// GetRawBindings Get the raw map of array of bindings. +func (b *Builder) GetRawBindings() map[string][]interface{} { + return b.Bindings +} + +/* +MergeBindings Merge an array of bindings into our bindings. +*/ +//func (b *Builder) MergeBindings(builder *Builder) *Builder { +// res := make(map[string][]interface{}) +// +// for i, i2 := range collection { +// +// } +// return b +//} + +/* +Where Add a basic where clause to the query. +column,operator,value, +*/ +func (b *Builder) Where(params ...interface{}) *Builder { + + //map of where conditions + if maps, ok := params[0].([][]interface{}); ok { + for _, conditions := range maps { + b.Where(conditions...) + } + return b + } + + paramsLength := len(params) + var operator string + var value interface{} + var boolean = BOOLEAN_AND + switch condition := params[0].(type) { + case func(builder *Builder): + var boolean string + if paramsLength > 1 { + boolean = params[1].(string) + } else { + boolean = BOOLEAN_AND + } + //clousure + cb := CloneBuilderWithTable(b) + condition(cb) + return b.AddNestedWhereQuery(cb, boolean) + case Where: + b.Wheres = append(b.Wheres, condition) + b.Components[TYPE_WHERE] = struct{}{} + return b + + case []Where: + b.Wheres = append(b.Wheres, condition...) + b.Components[TYPE_WHERE] = struct{}{} + return b + case Expression: + if paramsLength > 1 { + boolean = params[1].(string) + } else { + boolean = BOOLEAN_AND + } + b.Wheres = append(b.Wheres, Where{ + Type: CONDITION_TYPE_RAW, + RawSql: condition, + Boolean: boolean, + }) + b.Components[TYPE_WHERE] = struct{}{} + return b + case map[string]interface{}: + boolean = BOOLEAN_AND + if paramsLength > 1 { + boolean = params[1].(string) + } + cb := CloneBuilderWithTable(b) + for k, v := range condition { + cb.Where(k, v) + } + return b.AddNestedWhereQuery(cb, boolean) + } + switch paramsLength { + case 2: + //assume operator is = and omitted + operator = "=" + value = params[1] + case 3: + //correspond to column,operator,value + operator = params[1].(string) + value = params[2] + case 4: + //correspond to column,operator,value,boolean jointer + operator = params[1].(string) + value = params[2] + boolean = params[3].(string) + } + column := params[0].(string) + //operator might be in/not in/between/not between,in there cases we need take value as slice + if strings.Contains("in,not in,between,not between", operator) { + switch operator { + case CONDITION_TYPE_IN: + b.WhereIn(column, value, boolean) + return b + case CONDITION_TYPE_NOT_IN: + b.WhereNotIn(column, value, boolean) + return b + case CONDITION_TYPE_BETWEEN: + b.WhereBetween(column, value, boolean) + return b + case CONDITION_TYPE_NOT_BETWEEN: + b.WhereNotBetween(column, value, boolean) + return b + } + } + if f, ok := value.(func(builder *Builder)); ok { + return b.WhereSub(column, operator, f, boolean) + } + b.Wheres = append(b.Wheres, Where{ + Type: CONDITION_TYPE_BASIC, + Column: column, + Operator: operator, + Value: value, + Boolean: boolean, + }) + b.AddBinding([]interface{}{value}, TYPE_WHERE) + b.Components[TYPE_WHERE] = struct{}{} + return b +} +func (b *Builder) WherePivot(params ...interface{}) *Builder { + + column := params[0].(string) + paramsLength := len(params) + var operator string + var value interface{} + var boolean = BOOLEAN_AND + switch paramsLength { + case 2: + operator = "=" + value = params[1] + case 3: + operator = params[1].(string) + value = params[2] + case 4: + operator = params[1].(string) + value = params[2] + boolean = params[3].(string) + } + + b.PivotWheres = append(b.PivotWheres, Where{ + Type: CONDITION_TYPE_BASIC, + Column: column, + Operator: operator, + Value: value, + Boolean: boolean, + }) + return b +} + +/* +OrWhere Add an "or where" clause to the query. +*/ +func (b *Builder) OrWhere(params ...interface{}) *Builder { + paramsLength := len(params) + if clousure, ok := params[0].(func(builder *Builder)); ok { + return b.Where(clousure, BOOLEAN_OR) + } + if paramsLength == 2 { + params = []interface{}{params[0], "=", params[1], BOOLEAN_OR} + } else { + params = append(params, BOOLEAN_OR) + } + return b.Where(params...) +} + +/* +WhereColumn Add a "where" clause comparing two columns to the query. +*/ +func (b *Builder) WhereColumn(first interface{}, second ...string) *Builder { + length := len(second) + var firstColumn = first + var secondColumn, operator, boolean string + if arr, ok := first.([][]interface{}); ok { + return b.WhereNested(func(builder *Builder) { + for _, term := range arr { + var strs []string + for i := 1; i < len(term); i++ { + strs = append(strs, term[i].(string)) + } + builder.WhereColumn(term[0], strs...) + } + }) + + } + switch length { + case 1: + secondColumn = second[0] + operator = "=" + boolean = BOOLEAN_AND + case 2: + operator = second[0] + secondColumn = second[1] + boolean = BOOLEAN_AND + case 3: + operator = second[0] + secondColumn = second[1] + boolean = second[2] + default: + panic("wrong arguements in where column") + } + b.Wheres = append(b.Wheres, Where{ + Type: CONDITION_TYPE_COLUMN, + FirstColumn: firstColumn.(string), + Operator: operator, + SecondColumn: secondColumn, + Boolean: boolean, + }) + b.Components[TYPE_WHERE] = struct{}{} + + return b +} + +/* +OrWhereColumn Add an "or where" clause comparing two columns to the query. +*/ +func (b *Builder) OrWhereColumn(first string, second ...string) *Builder { + var ts = make([]string, 3, 3) + switch len(second) { + case 1: + ts = []string{"=", second[0], BOOLEAN_OR} + case 2: + ts = []string{second[0], second[1], BOOLEAN_OR} + } + return b.WhereColumn(first, ts...) +} + +/* +WhereRaw Add a raw where clause to the query. +*/ +func (b *Builder) WhereRaw(rawSql string, params ...interface{}) *Builder { + paramsLength := len(params) + var boolean string = BOOLEAN_AND + var bindings []interface{} + switch paramsLength { + case 1: + bindings = params[0].([]interface{}) + b.AddBinding(bindings, TYPE_WHERE) + case 2: + bindings = params[0].([]interface{}) + b.AddBinding(bindings, TYPE_WHERE) + boolean = params[1].(string) + } + b.Wheres = append(b.Wheres, Where{ + Type: CONDITION_TYPE_RAW, + RawSql: Raw(rawSql), + Boolean: boolean, + }) + b.Components[TYPE_WHERE] = struct{}{} + + return b + +} + +/* +OrWhereRaw Add a raw or where clause to the query. +*/ +func (b *Builder) OrWhereRaw(rawSql string, bindings ...[]interface{}) *Builder { + switch len(bindings) { + case 0: + return b.WhereRaw(rawSql, []interface{}{}, BOOLEAN_OR) + case 1: + return b.WhereRaw(rawSql, bindings[0], BOOLEAN_OR) + } + panic(errors.New("arguements mismatch")) +} + +/* +WhereIn Add a "where in" clause to the query. +column values boolean not +*/ +func (b *Builder) WhereIn(params ...interface{}) *Builder { + paramsLength := len(params) + var boolean string + not := false + if paramsLength > 2 { + boolean = params[2].(string) + } else { + boolean = BOOLEAN_AND + } + if paramsLength > 3 { + not = params[3].(bool) + } + + var values []interface{} + if IsQueryable(params[1]) { + queryStr, bindings := b.CreateSub(params[1]) + values = append(values, Raw(queryStr)) + b.AddBinding(bindings, TYPE_WHERE) + } else { + values = InterfaceToSlice(params[1]) + } + b.Wheres = append(b.Wheres, Where{ + Type: CONDITION_TYPE_IN, + Column: params[0].(string), + Values: values, + Boolean: boolean, + Not: not, + }) + b.Components[TYPE_WHERE] = struct{}{} + b.AddBinding(values, TYPE_WHERE) + return b +} + +/* +OrWhereIn Add an "or where in" clause to the query. +column values +*/ +func (b *Builder) OrWhereIn(params ...interface{}) *Builder { + params = append(params, BOOLEAN_OR, false) + return b.WhereIn(params...) +} + +//column values [ boolean ] +func (b *Builder) WhereNotIn(params ...interface{}) *Builder { + params = append(params, BOOLEAN_AND, true) + return b.WhereIn(params...) +} + +/* +OrWhereNotIn Add an "or where not in" clause to the query. +column values +*/ +func (b *Builder) OrWhereNotIn(params ...interface{}) *Builder { + params = append(params, BOOLEAN_OR, true) + return b.WhereIn(params...) +} + +/* +WhereNull Add a "where null" clause to the query. + params takes in below order: + 1. column string + 2. boolean string in [2]string{"and","or"} + 3. type string "not" +*/ +func (b *Builder) WhereNull(column interface{}, params ...interface{}) *Builder { + paramsLength := len(params) + var boolean = BOOLEAN_AND + var not = false + switch paramsLength { + case 1: + boolean = params[0].(string) + case 2: + boolean = params[0].(string) + not = params[1].(bool) + } + b.Components[TYPE_WHERE] = struct{}{} + switch columnTemp := column.(type) { + case string: + b.Wheres = append(b.Wheres, Where{ + Type: CONDITION_TYPE_NULL, + Column: columnTemp, + Boolean: boolean, + Not: not, + }) + case []interface{}: + for _, i := range columnTemp { + b.Wheres = append(b.Wheres, Where{ + Type: CONDITION_TYPE_NULL, + Column: i.(string), + Boolean: boolean, + Not: not, + }) + } + case []string: + for _, i := range columnTemp { + b.Wheres = append(b.Wheres, Where{ + Type: CONDITION_TYPE_NULL, + Column: i, + Boolean: boolean, + Not: not, + }) + } + } + + return b +} + +/* +WhereNotNull Add a "where not null" clause to the query. +*/ +func (b *Builder) WhereNotNull(column interface{}, params ...interface{}) *Builder { + paramsLength := len(params) + if paramsLength == 0 { + params = append(params, BOOLEAN_AND, true) + } else if paramsLength == 1 { + params = append(params, true) + } + return b.WhereNull(column, params...) +} + +/* +OrWhereNull Add an "or where null" clause to the query. +column not +*/ +func (b *Builder) OrWhereNull(column interface{}, params ...interface{}) *Builder { + paramsLength := len(params) + if paramsLength == 0 { + params = append(params, BOOLEAN_OR, false) + } else if paramsLength == 1 { + params = []interface{}{BOOLEAN_OR, params[0]} + } + return b.WhereNull(column, params...) +} + +/* +OrWhereNotNull Add an "or where not null" clause to the query. +*/ +func (b *Builder) OrWhereNotNull(column interface{}) *Builder { + params := []interface{}{BOOLEAN_OR, true} + return b.WhereNull(column, params...) +} + +/* +WhereBetween Add a where between statement to the query. + + params takes in below order: + 1. WhereBetween(column string,values []interface{"min","max"}) + 2. WhereBetween(column string,values []interface{"min","max"},"and/or") + 3. WhereBetween(column string,values []interface{"min","max","and/or",true/false}) + +*/ +func (b *Builder) WhereBetween(params ...interface{}) *Builder { + paramsLength := len(params) + var boolean = BOOLEAN_AND + not := false + if paramsLength > 2 { + boolean = params[2].(string) + } + var betweenType = CONDITION_TYPE_BETWEEN + if paramsLength > 3 { + not = params[3].(bool) + } + b.Components[TYPE_WHERE] = struct{}{} + tvalues := params[1].([]interface{})[0:2] + for _, tvalue := range tvalues { + if _, ok := tvalue.(Expression); !ok { + b.AddBinding([]interface{}{tvalue}, TYPE_WHERE) + } + } + + b.Wheres = append(b.Wheres, Where{ + Type: betweenType, + Column: params[0].(string), + Boolean: boolean, + Values: params[1].([]interface{})[0:2], + Not: not, + }) + return b +} +func (b *Builder) WhereNotBetween(params ...interface{}) *Builder { + if len(params) == 2 { + params = append(params, BOOLEAN_AND, true) + } + return b.WhereBetween(params...) +} + +/* +OrWhereBetween Add an or where between statement to the query. +*/ +func (b *Builder) OrWhereBetween(params ...interface{}) *Builder { + params = append(params, BOOLEAN_OR) + return b.WhereBetween(params...) +} + +/* +OrWhereNotBetween Add an or where not between statement to the query. +*/ +func (b *Builder) OrWhereNotBetween(params ...interface{}) *Builder { + params = append(params, BOOLEAN_OR, true) + + return b.WhereBetween(params...) +} + +/* +WhereBetweenColumns Add a where between statement using columns to the query. +*/ +func (b *Builder) WhereBetweenColumns(column string, values []interface{}, params ...interface{}) *Builder { + paramsLength := len(params) + var boolean = BOOLEAN_AND + var betweenType = CONDITION_TYPE_BETWEEN_COLUMN + not := false + if paramsLength > 0 { + boolean = params[0].(string) + } + if paramsLength > 1 { + not = params[1].(bool) + } + + b.Components[TYPE_WHERE] = struct{}{} + + b.Wheres = append(b.Wheres, Where{ + Type: betweenType, + Column: column, + Boolean: boolean, + Values: values, + Not: not, + }) + return b +} + +//AddTimeBasedWhere Add a time based (year, month, day, time) statement to the query. +//params order : timefuncionname column operator value boolean +//minimum : timefuncionname column value +func (b *Builder) AddTimeBasedWhere(params ...interface{}) *Builder { + paramsLength := len(params) + var timeType = params[0] + var boolean = BOOLEAN_AND + var operator string + var value interface{} + var tvalue interface{} + //timefunction column value + if paramsLength == 3 { + operator = "=" + tvalue = params[2] + } else if paramsLength > 3 { + //timefunction column operator value + operator = params[2].(string) + tvalue = params[3] + //timefunction column operator value boolean + if paramsLength > 4 && params[4].(string) != boolean { + boolean = BOOLEAN_OR + } + } else { + tvalue = params[3] + } + switch tvalue.(type) { + case string: + value = tvalue.(string) + case int: + value = tvalue.(int) + case time.Time: + switch timeType.(string) { + case CONDITION_TYPE_DATE: + value = tvalue.(time.Time).Format("2006-01-02") + case CONDITION_TYPE_MONTH: + value = tvalue.(time.Time).Format("01") + case CONDITION_TYPE_YEAR: + value = tvalue.(time.Time).Format("2006") + case CONDITION_TYPE_TIME: + value = tvalue.(time.Time).Format("15:04:05") + case CONDITION_TYPE_DAY: + value = tvalue.(time.Time).Format("02") + } + case Expression: + value = tvalue.(Expression) + } + b.Wheres = append(b.Wheres, Where{ + Type: timeType.(string), + Column: params[1].(string), + Boolean: boolean, + Value: value, + Operator: operator, + }) + b.AddBinding([]interface{}{value}, TYPE_WHERE) + b.Components[TYPE_WHERE] = struct{}{} + return b +} + +//column operator value boolean +func (b *Builder) WhereDate(params ...interface{}) *Builder { + p := append([]interface{}{CONDITION_TYPE_DATE}, params...) + return b.AddTimeBasedWhere(p...) +} +func (b *Builder) WhereTime(params ...interface{}) *Builder { + p := append([]interface{}{CONDITION_TYPE_TIME}, params...) + return b.AddTimeBasedWhere(p...) +} +func (b *Builder) WhereDay(params ...interface{}) *Builder { + p := append([]interface{}{CONDITION_TYPE_DAY}, params...) + return b.AddTimeBasedWhere(p...) +} +func (b *Builder) WhereMonth(params ...interface{}) *Builder { + p := append([]interface{}{CONDITION_TYPE_MONTH}, params...) + return b.AddTimeBasedWhere(p...) +} +func (b *Builder) WhereYear(params ...interface{}) *Builder { + p := append([]interface{}{CONDITION_TYPE_YEAR}, params...) + return b.AddTimeBasedWhere(p...) +} + +/* +WhereNested Add a nested where statement to the query. +*/ +func (b *Builder) WhereNested(params ...interface{}) *Builder { + paramsLength := len(params) + if paramsLength == 1 { + params = append(params, BOOLEAN_AND) + } + cb := CloneBuilderWithTable(b) + switch params[0].(type) { + case Where: + cb.Wheres = append(cb.Wheres, params[0].(Where)) + case []Where: + cb.Wheres = append(cb.Wheres, params[0].([]Where)...) + case [][]interface{}: + tp := params[0].([][]interface{}) + for i := 0; i < len(tp); i++ { + cb.Where(tp[i]...) + } + case []interface{}: + cb.Where(params[0].([]interface{})) + case func(builder *Builder): + var boolean string + if paramsLength > 1 { + boolean = params[1].(string) + } else { + boolean = BOOLEAN_AND + } + closure := params[0].(func(builder *Builder)) + closure(cb) + return b.AddNestedWhereQuery(cb, boolean) + } + b.Wheres = append(b.Wheres, Where{ + Type: CONDITION_TYPE_NESTED, + Boolean: params[1].(string), + Value: cb, + }) + b.Components[TYPE_WHERE] = struct{}{} + return b +} +func (b *Builder) WhereSub(column string, operator string, value func(builder *Builder), boolean string) *Builder { + cb := CloneBuilderWithTable(b) + value(cb) + b.Wheres = append(b.Wheres, Where{ + Type: CONDITION_TYPE_SUB, + Operator: operator, + Value: cb, + Column: column, + Boolean: boolean, + }) + b.Components[TYPE_WHERE] = struct{}{} + b.AddBinding(cb.GetBindings(), TYPE_WHERE) + return b +} + +//WhereExists Add an exists clause to the query. +// 1. WhereExists(cb,"and",false) +// 2. WhereExists(cb,"and") +// 3. WhereExists(cb) +func (b *Builder) WhereExists(cb func(builder *Builder), params ...interface{}) *Builder { + newBuilder := CloneBuilderWithTable(b) + cb(newBuilder) + boolean := BOOLEAN_AND + not := false + switch len(params) { + case 1: + boolean = params[0].(string) + case 2: + boolean = params[0].(string) + not = params[1].(bool) + } + + return b.AddWhereExistsQuery(newBuilder, boolean, not) +} + +/* +OrWhereExists Add an exists clause to the query. +*/ +func (b *Builder) OrWhereExists(cb func(builder *Builder), params ...interface{}) *Builder { + not := false + if len(params) > 0 { + not = params[0].(bool) + } + return b.WhereExists(cb, BOOLEAN_OR, not) +} + +/* +WhereNotExists Add a where not exists clause to the query. +*/ +func (b *Builder) WhereNotExists(cb func(builder *Builder), params ...interface{}) *Builder { + boolean := BOOLEAN_AND + if len(params) > 0 { + boolean = params[0].(string) + } + return b.WhereExists(cb, boolean, true) +} + +/* +OrWhereNotExists Add a where not exists clause to the query. +*/ +func (b *Builder) OrWhereNotExists(cb func(builder *Builder), params ...interface{}) *Builder { + return b.OrWhereExists(cb, true) +} + +// AddWhereExistsQuery Add an exists clause to the query. +func (b *Builder) AddWhereExistsQuery(builder *Builder, boolean string, not bool) *Builder { + var n bool + if not { + n = true + } else { + n = false + } + b.Wheres = append(b.Wheres, Where{ + Type: CONDITION_TYPE_EXIST, + Query: builder, + Boolean: boolean, + Not: n, + }) + b.Components[TYPE_WHERE] = struct{}{} + b.AddBinding(builder.GetBindings(), TYPE_WHERE) + return b +} + +/* +GroupBy Add a "group by" clause to the query. +column operator value boolean +*/ +func (b *Builder) GroupBy(columns ...interface{}) *Builder { + for _, column := range columns { + b.Groups = append(b.Groups, column) + } + b.Components[TYPE_GROUP_BY] = struct{}{} + return b +} + +/* +GroupByRaw Add a raw groupBy clause to the query. +*/ +func (b *Builder) GroupByRaw(sql string, bindings ...[]interface{}) *Builder { + b.Groups = append(b.Groups, Expression(sql)) + if len(bindings) > 0 { + b.AddBinding(bindings[0], TYPE_GROUP_BY) + } + b.Components[TYPE_GROUP_BY] = struct{}{} + + return b +} + +/* +Having Add a "having" clause to the query. +column operator value boolean +*/ +func (b *Builder) Having(params ...interface{}) *Builder { + havingBoolean := BOOLEAN_AND + havingOperator := "=" + var havingValue interface{} + var havingColumn string + length := len(params) + switch length { + case 2: + havingColumn = params[0].(string) + havingValue = params[1] + case 3: + havingColumn = params[0].(string) + havingOperator = params[1].(string) + havingValue = params[2] + case 4: + havingColumn = params[0].(string) + havingOperator = params[1].(string) + havingValue = params[2] + havingBoolean = params[3].(string) + } + + having := Having{ + HavingType: CONDITION_TYPE_BASIC, + HavingColumn: havingColumn, + HavingOperator: havingOperator, + HavingValue: havingValue, + HavingBoolean: havingBoolean, + } + b.AddBinding([]interface{}{havingValue}, TYPE_HAVING) + b.Components[TYPE_HAVING] = struct{}{} + b.Havings = append(b.Havings, having) + return b +} + +/* +HavingRaw Add a raw having clause to the query. +*/ +func (b *Builder) HavingRaw(params ...interface{}) *Builder { + length := len(params) + havingBoolean := BOOLEAN_AND + var expression Expression + switch length { + case 1: + if expr, ok := params[0].(Expression); ok { + expression = expr + } else { + expression = Expression(params[0].(string)) + } + case 2: + if expr, ok := params[0].(Expression); ok { + expression = expr + } else { + expression = Expression(params[0].(string)) + } + b.AddBinding(params[1].([]interface{}), TYPE_HAVING) + case 3: + if expr, ok := params[0].(Expression); ok { + expression = expr + } else { + expression = Expression(params[0].(string)) + } + b.AddBinding(params[1].([]interface{}), TYPE_HAVING) + havingBoolean = params[2].(string) + } + having := Having{ + HavingType: CONDITION_TYPE_RAW, + HavingValue: expression, + HavingBoolean: havingBoolean, + RawSql: expression, + } + b.Components[TYPE_HAVING] = struct{}{} + b.Havings = append(b.Havings, having) + return b +} + +/* +OrHavingRaw Add a raw having clause to the query. +*/ +func (b *Builder) OrHavingRaw(params ...interface{}) *Builder { + bindings := []interface{}{} + if len(params) == 2 { + bindings = params[1].([]interface{}) + } + return b.HavingRaw(params[0], bindings, BOOLEAN_OR) +} + +/* +OrHaving Add an "or having" clause to the query. +*/ +func (b *Builder) OrHaving(params ...interface{}) *Builder { + return b.Having(params[0], "=", params[1], BOOLEAN_OR) +} + +/* +HavingBetween Add a "having between " clause to the query. +*/ +func (b *Builder) HavingBetween(column string, params ...interface{}) *Builder { + var values []interface{} + boolean := BOOLEAN_AND + not := false + length := len(params) + switch length { + case 1: + values = params[0].([]interface{})[0:2] + case 2: + values = params[0].([]interface{})[0:2] + boolean = params[1].(string) + case 3: + values = params[0].([]interface{})[0:2] + boolean = params[1].(string) + not = params[2].(bool) + } + having := Having{ + HavingType: CONDITION_TYPE_BETWEEN, + HavingColumn: column, + HavingValue: values, + HavingBoolean: boolean, + Not: not, + } + b.Components[TYPE_HAVING] = struct{}{} + b.AddBinding(values, TYPE_HAVING) + + b.Havings = append(b.Havings, having) + return b +} + +/* +OrderBy Add an "order by" clause to the query. +*/ +func (b *Builder) OrderBy(params ...interface{}) *Builder { + var order = ORDER_ASC + if r, ok := params[0].(Expression); ok { + b.Orders = append(b.Orders, Order{ + RawSql: r, + OrderType: CONDITION_TYPE_RAW, + }) + b.Components[TYPE_ORDER] = struct{}{} + + return b + } + if len(params) > 1 { + order = params[1].(string) + } + if order != ORDER_ASC && order != ORDER_DESC { + panic(errors.New("wrong order direction: " + order)) + } + b.Orders = append(b.Orders, Order{ + Direction: order, + Column: params[0].(string), + }) + b.Components[TYPE_ORDER] = struct{}{} + + return b +} + +/* +OrderByRaw Add a raw "order by" clause to the query. +*/ +func (b *Builder) OrderByRaw(sql string, bindings []interface{}) *Builder { + b.Orders = append(b.Orders, Order{ + OrderType: CONDITION_TYPE_RAW, + RawSql: Raw(sql), + }) + b.Components[TYPE_ORDER] = struct{}{} + b.AddBinding(bindings, TYPE_ORDER) + return b +} +func (b *Builder) OrderByDesc(column string) *Builder { + return b.OrderBy(column, ORDER_DESC) +} + +/* +ReOrder Remove all existing orders and optionally add a new order. +*/ +func (b *Builder) ReOrder(params ...string) *Builder { + b.Orders = nil + b.Bindings["order"] = nil + delete(b.Components, TYPE_ORDER) + length := len(params) + if length == 1 { + b.OrderBy(InterfaceToSlice(params[0])) + } else if length == 2 { + b.OrderBy(InterfaceToSlice(params[0:2])...) + } + return b +} + +/* +Limit Set the "limit" value of the query. +*/ +func (b *Builder) Limit(n int) *Builder { + b.Components[TYPE_LIMIT] = struct{}{} + b.LimitNum = int(math.Max(0, float64(n))) + return b +} + +/* +Offset Set the "offset" value of the query. +*/ +func (b *Builder) Offset(n int) *Builder { + b.OffsetNum = int(math.Max(0, float64(n))) + b.Components[TYPE_OFFSET] = struct{}{} + return b +} + +/* +Union Add a union statement to the query. +*/ +//func (b *Builder) Union(n int) *Builder { +// b.OffsetNum = n +// return b +//} + +/* +Lock Lock the selected rows in the table for updating. +*/ +func (b *Builder) Lock(lock ...interface{}) *Builder { + if len(lock) == 0 { + b.LockMode = true + } else { + b.LockMode = lock[0] + } + b.Components[TYPE_LOCK] = struct{}{} + return b +} + +func (b *Builder) WhereMap(params map[string]interface{}) *Builder { + for key, param := range params { + b.Where(key, "=", param) + } + return b +} + +//AddNestedWhereQuery Add another query builder as a nested where to the query builder. +func (b *Builder) AddNestedWhereQuery(builder *Builder, boolean string) *Builder { + + if len(builder.Wheres) > 0 { + b.Wheres = append(b.Wheres, Where{ + Type: CONDITION_TYPE_NESTED, + Value: builder, + Boolean: boolean, + }) + b.AddBinding(builder.GetRawBindings()[TYPE_WHERE], TYPE_WHERE) + b.Components[TYPE_WHERE] = struct{}{} + } + return b +} + +/* +First Execute the query and get the first result. +*/ +func (b *Builder) First(dest interface{}, columns ...interface{}) (result sql.Result, err error) { + b.Limit(1) + return b.Get(dest, columns...) +} + +/* +RunSelect run the query as a "select" statement against the connection. +*/ +func (b *Builder) RunSelect() (result sql.Result, err error) { + result, err = b.Run(b.ToSql(), b.GetBindings(), func() (result sql.Result, err error) { + if b.Pretending { + return + } + if b.Tx != nil { + result, err = b.Tx.Select(b.Grammar.CompileSelect(), b.GetBindings(), b.Dest) + } else { + result, err = b.Connection.Select(b.Grammar.CompileSelect(), b.GetBindings(), b.Dest) + } + return + }) + + return +} + +func (b *Builder) GetConnection() IConnection { + if b.Tx != nil { + return b.Tx + } + + return b.Connection +} + +func (b *Builder) Run(query string, bindings []interface{}, callback func() (result sql.Result, err error)) (result sql.Result, err error) { + defer func() { + catchedErr := recover() + if catchedErr != nil { + switch catchedErr.(type) { + case string: + err = errors.New(catchedErr.(string)) + case error: + err = catchedErr.(error) + default: + err = errors.New("unknown panic") + } + } + }() + result, err = callback() + + return +} +func (b *Builder) WithPivot(columns ...string) *Builder { + b.Pivots = append(b.Pivots, columns...) + return b +} + +/* +Exists Determine if any rows exist for the current query. +*/ +func (b *Builder) Exists() (exists bool, err error) { + _, err = b.Connection.Select(b.Grammar.CompileExists(), b.GetBindings(), &exists) + if err != nil { + return false, err + } + return exists, nil +} + +/* +DoesntExist Determine if no rows exist for the current query. +*/ +func (b *Builder) DoesntExist() (notExists bool, err error) { + e, err := b.Exists() + if err != nil { + return false, err + } + return !e, nil +} + +/* +Aggregate Execute an aggregate function on the database. + +*/ +func (b *Builder) Aggregate(dest interface{}, fn string, column ...string) (result sql.Result, err error) { + b.Dest = dest + if column == nil { + column = append(column, "*") + } + b.Aggregates = append(b.Aggregates, Aggregate{ + AggregateName: fn, + AggregateColumn: column[0], + }) + b.Components[TYPE_AGGREGRATE] = struct{}{} + result, err = b.RunSelect() + return +} + +/* +Count Retrieve the "count" result of the query. +*/ +func (b *Builder) Count(dest interface{}, column ...string) (result sql.Result, err error) { + return b.Aggregate(dest, "count", column...) +} + +/* +Min Retrieve the minimum value of a given column. +*/ +func (b *Builder) Min(dest interface{}, column ...string) (result sql.Result, err error) { + + return b.Aggregate(dest, "min", column...) + +} + +/* +Max Retrieve the maximum value of a given column. +*/ +func (b *Builder) Max(dest interface{}, column ...string) (result sql.Result, err error) { + + return b.Aggregate(dest, "max", column...) + +} + +/* +Avg Alias for the "avg" method. +*/ +func (b *Builder) Avg(dest interface{}, column ...string) (result sql.Result, err error) { + + return b.Aggregate(dest, "avg", column...) + +} +func (b *Builder) Sum(dest interface{}, column ...string) (result sql.Result, err error) { + + return b.Aggregate(dest, "sum", column...) + +} +func (b *Builder) ForPage(page, perPage int64) *Builder { + + b.Offset(int((page - 1) * perPage)).Limit(int(perPage)) + return b +} +func (b *Builder) Only(columns ...string) *Builder { + b.OnlyColumns = make(map[string]interface{}, len(columns)) + for i := 0; i < len(columns); i++ { + b.OnlyColumns[columns[i]] = nil + } + return b +} +func (b *Builder) Except(columns ...string) *Builder { + b.ExceptColumns = make(map[string]interface{}, len(columns)) + for i := 0; i < len(columns); i++ { + b.ExceptColumns[columns[i]] = nil + } + return b +} +func (b *Builder) FileterColumn(column string) bool { + if b.OnlyColumns != nil { + if _, ok := b.OnlyColumns[column]; !ok { + return false + } + } + if b.ExceptColumns != nil { + if _, ok := b.ExceptColumns[column]; ok { + return false + } + } + return true +} + +func PrepareInsertValues(values interface{}) []map[string]interface{} { + rv := reflect.ValueOf(values) + var items []map[string]interface{} + if rv.Kind() == reflect.Ptr { + rv = reflect.Indirect(rv) + } + if rv.Kind() == reflect.Map { + items = append(items, rv.Interface().(map[string]interface{})) + } else if rv.Kind() == reflect.Slice { + eleType := rv.Type().Elem() + if eleType.Kind() == reflect.Ptr { + switch eleType.Elem().Kind() { + case reflect.Struct: + for i := 0; i < rv.Len(); i++ { + items = append(items, ExtractStruct(rv.Index(i).Elem().Interface())) + } + case reflect.Map: + for i := 0; i < rv.Len(); i++ { + items = append(items, rv.Index(i).Elem().Interface().(map[string]interface{})) + } + } + } else if eleType.Kind() == reflect.Map { + if tv, ok := values.(*[]map[string]interface{}); ok { + items = *tv + } else { + items = values.([]map[string]interface{}) + } + } else if eleType.Kind() == reflect.Struct { + for i := 0; i < rv.Len(); i++ { + items = append(items, ExtractStruct(rv.Index(i).Interface())) + } + } + } else if rv.Kind() == reflect.Struct { + items = append(items, ExtractStruct(rv.Interface())) + } + return items +} + +/* +Insert new records into the database. +*/ +func (b *Builder) Insert(values interface{}) (result sql.Result, err error) { + items := PrepareInsertValues(values) + b.ApplyBeforeQueryCallBacks(TYPE_INSERT, items...) + return b.Run(b.Grammar.CompileInsert(items), b.GetBindings(), func() (result sql.Result, err error) { + if b.Pretending { + return + } + if b.Tx != nil { + result, err = b.Tx.Insert(b.PreparedSql, b.GetBindings()) + } else { + result, err = b.Connection.Insert(b.PreparedSql, b.GetBindings()) + } + defer b.ApplyAfterQueryCallBacks(TYPE_INSERT, result, err, items...) + return + }) +} + +/* +InsertGetId Insert a new record and get the value of the primary key. +*/ +func (b *Builder) InsertGetId(values interface{}) (int64, error) { + insert, err := b.Insert(values) + if err != nil { + return 0, err + } + id, _ := insert.LastInsertId() + return id, nil +} + +/* +InsertOrIgnore Insert a new record and get the value of the primary key. +*/ +func (b *Builder) InsertOrIgnore(values interface{}) (result sql.Result, err error) { + items := PrepareInsertValues(values) + + b.ApplyBeforeQueryCallBacks(TYPE_INSERT, items...) + return b.Run(b.Grammar.CompileInsertOrIgnore(items), b.GetBindings(), func() (result sql.Result, err error) { + if b.Pretending { + return + } + if b.Tx != nil { + result, err = b.Tx.Insert(b.PreparedSql, b.GetBindings()) + } else { + result, err = b.Connection.Insert(b.PreparedSql, b.GetBindings()) + } + defer b.ApplyAfterQueryCallBacks(TYPE_INSERT, result, err, items...) + return + }) +} +func (b *Builder) Update(v map[string]interface{}) (result sql.Result, err error) { + b.ApplyBeforeQueryCallBacks(TYPE_UPDATE, v) + return b.Run(b.Grammar.CompileUpdate(v), b.GetBindings(), func() (result sql.Result, err error) { + if b.Pretending { + return + } + if b.Tx != nil { + result, err = b.Tx.Update(b.PreparedSql, b.GetBindings()) + } else { + result, err = b.Connection.Update(b.PreparedSql, b.GetBindings()) + } + defer b.ApplyAfterQueryCallBacks(TYPE_UPDATE, result, err, v) + return + }) +} + +/* +UpdateOrInsert insert or update a record matching the attributes, and fill it with values. +*/ +func (b *Builder) UpdateOrInsert(conditions map[string]interface{}, values map[string]interface{}) (updated bool, err error) { + exist, err := b.Where(conditions).Exists() + if err != nil { + return + } + if !exist { + for k, v := range values { + conditions[k] = v + } + b.Reset(TYPE_WHERE) + _, err = b.Insert(conditions) + if err != nil { + return + } + return false, nil + } else { + _, err = b.Limit(1).Update(values) + if err != nil { + return + } + return true, nil + } +} + +/* +Decrement Decrement a column's value by a given amount. +*/ +func (b *Builder) Decrement(column string, amount int, extra ...map[string]interface{}) (result sql.Result, err error) { + + var update map[string]interface{} + wrapped := b.Grammar.Wrap(column) + + if len(extra) == 0 { + update = make(map[string]interface{}) + } else { + update = extra[0] + } + update[column] = Expression(fmt.Sprintf("%s - %d", wrapped, amount)) + + return b.Update(update) +} + +/* +Increment Increment a column's value by a given amount. +*/ +func (b *Builder) Increment(column string, amount int, extra ...map[string]interface{}) (result sql.Result, err error) { + var update map[string]interface{} + wrapped := b.Grammar.Wrap(column) + + if len(extra) == 0 { + update = make(map[string]interface{}) + } else { + update = extra[0] + } + update[column] = Expression(fmt.Sprintf("%s + %d", wrapped, amount)) + return b.Update(update) +} + +/* +InRandomOrder Put the query's results in random order. +*/ +func (b *Builder) InRandomOrder(seed string) *Builder { + return b.OrderByRaw(b.Grammar.CompileRandom(seed), []interface{}{}) +} + +/* +Delete Delete records from the database. +//TODO: Delete(1) Delete(1,2,3) Delete([]interface{}{1,2,3}) +*/ +func (b *Builder) Delete(id ...interface{}) (result sql.Result, err error) { + if len(id) > 0 { + b.Where("id", id[0]) + } + b.ApplyBeforeQueryCallBacks(TYPE_DELETE) + return b.Run(b.Grammar.CompileDelete(), b.GetBindings(), func() (result sql.Result, err error) { + if b.Pretending { + return + } + if b.Tx != nil { + result, err = b.Tx.Delete(b.PreparedSql, b.GetBindings()) + } else { + result, err = b.Connection.Delete(b.PreparedSql, b.GetBindings()) + } + defer b.ApplyAfterQueryCallBacks(TYPE_DELETE, result, err) + return + }) + +} +func (b *Builder) Raw() *sql.DB { + return b.Connection.GetDB() +} + +/* +Get Execute the query as a "select" statement. +*/ +func (b *Builder) Get(dest interface{}, columns ...interface{}) (result sql.Result, err error) { + if len(columns) > 0 { + b.Select(columns...) + } + b.Dest = dest + + result, err = b.RunSelect() + + return +} + +/* +Pluck Get a collection instance containing the values of a given column. +*/ +func (b *Builder) Pluck(dest interface{}, params interface{}) (sql.Result, error) { + return b.Get(dest, params) +} + +/* +When Apply the callback if the given "value" is truthy. + + 1. When(true,func(builder *Builder)) + 2. When(true,func(builder *Builder),func(builder *Builder)) //with default callback + +*/ +func (b *Builder) When(boolean bool, cb ...func(builder *Builder)) *Builder { + if boolean { + cb[0](b) + } else if len(cb) == 2 { + //if false and we have default callback + cb[1](b) + } + return b +} +func (b *Builder) Value(dest interface{}, column string) (sql.Result, error) { + return b.First(dest, column) +} + +/*Reset +reset bindings and components +*/ +func (b *Builder) Reset(targets ...string) *Builder { + for _, componentName := range targets { + switch componentName { + case TYPE_ORDER: + delete(b.Components, TYPE_ORDER) + delete(b.Bindings, TYPE_ORDER) + b.Orders = nil + case TYPE_LIMIT: + delete(b.Bindings, TYPE_LIMIT) + delete(b.Components, TYPE_LIMIT) + b.LimitNum = 0 + case TYPE_OFFSET: + delete(b.Bindings, TYPE_OFFSET) + delete(b.Components, TYPE_OFFSET) + b.OffsetNum = 0 + case TYPE_WHERE: + delete(b.Bindings, TYPE_WHERE) + delete(b.Components, TYPE_WHERE) + b.Wheres = nil + case TYPE_GROUP_BY: + delete(b.Bindings, TYPE_GROUP_BY) + delete(b.Components, TYPE_GROUP_BY) + b.Groups = nil + case TYPE_COLUMN: + delete(b.Bindings, TYPE_COLUMN) + delete(b.Components, TYPE_COLUMN) + b.Columns = nil + } + } + return b +} +func (b *Builder) EnableLogQuery() *Builder { + b.LoggingQueries = true + return b +} +func (b *Builder) DisableLogQuery() *Builder { + b.LoggingQueries = false + return b +} + +/* +Register a closure to be invoked before the query is executed. +*/ +func (b *Builder) Before(callback func(builder *Builder, t string, data ...map[string]interface{})) *Builder { + b.BeforeQueryCallBacks = append(b.BeforeQueryCallBacks, callback) + return b +} + +/* +Register a closure to be invoked after the query is executed. +*/ +func (b *Builder) After(callback func(builder *Builder, t string, res sql.Result, err error, data ...map[string]interface{})) *Builder { + b.AfterQueryCallBacks = append(b.AfterQueryCallBacks, callback) + return b +} + +/*ApplyBeforeQueryCallBacks +Invoke the "before query" modification callbacks. +*/ +func (b *Builder) ApplyBeforeQueryCallBacks(t string, items ...map[string]any) { + for _, cb := range b.BeforeQueryCallBacks { + cb(b, t, items...) + } + + b.BeforeQueryCallBacks = nil +} + +func (b *Builder) ApplyAfterQueryCallBacks(t string, res sql.Result, err error, items ...map[string]any) { + for _, cb := range b.AfterQueryCallBacks { + cb(b, t, res, err, items...) + } + + b.AfterQueryCallBacks = nil +} + +/* +SetAggregate Set the aggregate property without running the query. +*/ +func (b *Builder) SetAggregate(function string, column ...string) *Builder { + if len(column) == 0 { + column = append(column, "*") + } + b.Aggregates = append(b.Aggregates, Aggregate{ + AggregateName: function, + AggregateColumn: column[0], + }) + b.Components[TYPE_AGGREGRATE] = struct{}{} + if len(b.Groups) == 0 { + b.Reset(TYPE_ORDER) + } + return b +} + +/* +GetCountForPagination Get the count of the total records for the paginator. +*/ +func (b *Builder) GetCountForPagination() (total int) { + + return +} + +func (b *Builder) Chunk(dest interface{}, chunkSize int64, callback func(dest interface{}) error) (err error) { + if len(b.Orders) == 0 { + panic(errors.New("must specify an orderby clause when using this method")) + } + var page int64 = 1 + var count int64 = 0 + tempDest := reflect.New(reflect.Indirect(reflect.ValueOf(dest)).Type()).Interface() + get, err := b.ForPage(1, chunkSize).Get(tempDest) + if err != nil { + return + } + count, _ = get.RowsAffected() + for count > 0 { + err = callback(tempDest) + if err != nil { + return + } + if count != chunkSize { + break + } else { + page++ + tempDest = reflect.New(reflect.Indirect(reflect.ValueOf(dest)).Type()).Interface() + get, err = b.ForPage(page, chunkSize).Get(tempDest) + if err != nil { + return err + } + count, _ = get.RowsAffected() + } + } + return nil +} + +func (b *Builder) Pretend() *Builder { + b.Pretending = true + return b +} + +/* +Implode concatenate values of a given column as a string. +*/ +func (b *Builder) Implode(column string, glue ...string) (string, error) { + var dest []string + _, err := b.Pluck(&dest, column) + if err != nil { + return "", nil + } + var sep string + if len(glue) == 0 { + sep = "" + } else { + sep = glue[0] + } + + return strings.Join(dest, sep), nil +} + +/* +WhereRowValues Adds a where condition using row values. +*/ +func (b *Builder) WhereRowValues(columns []string, operator string, values []interface{}, params ...string) *Builder { + if len(columns) != len(values) { + panic(errors.New("argements number mismatch")) + } + var boolean = BOOLEAN_AND + if len(params) == 1 { + boolean = params[0] + } + b.Components[TYPE_WHERE] = struct{}{} + b.Wheres = append(b.Wheres, Where{ + Type: CONDITION_TYPE_ROW_VALUES, + Operator: operator, + Columns: columns, + Values: values, + Boolean: boolean, + }) + return b +} diff --git a/config.go b/config.go new file mode 100644 index 0000000..bd5998a --- /dev/null +++ b/config.go @@ -0,0 +1,34 @@ +package db + +type DBConfig struct { + Driver string + //ReadHost []string + //WriteHost []string + Host string + Port string + Database string + Username string + Password string + Charset string + Prefix string + ConnMaxLifetime int + ConnMaxIdleTime int + MaxIdleConns int + MaxOpenConns int + ParseTime bool + // mysql + Collation string + UnixSocket string + MultiStatements bool + Dsn string + // pgsql + Sslmode string + TLS string + EnableLog bool + // sqlite3 + File string + Journal string // DELETE TRUNCATE PERSIST MEMORY WAL OFF + Locking string // NORMAL EXCLUSIVE + Mode string // ro rw rwc memory + Synchronous int // 0 OFF | 1 NORMAL | 2 FULL | 3 EXTRA +} diff --git a/connection.go b/connection.go new file mode 100644 index 0000000..1deac5f --- /dev/null +++ b/connection.go @@ -0,0 +1,162 @@ +package db + +import ( + "database/sql" + "errors" +) + +type Connection struct { + RxDB *sql.DB + DB *sql.DB + Config *DBConfig + ConnectionName string +} + +type IConnection interface { + Insert(query string, bindings []interface{}) (result sql.Result, err error) + Select(query string, bindings []interface{}, dest interface{}) (result sql.Result, err error) + Update(query string, bindings []interface{}) (result sql.Result, err error) + Delete(query string, bindings []interface{}) (result sql.Result, err error) + AffectingStatement(query string, bindings []interface{}) (result sql.Result, err error) + Statement(query string, bindings []interface{}) (sql.Result, error) + Table(tableName string) *Builder +} + +type Preparer interface { + Prepare(query string) (*sql.Stmt, error) +} + +type Execer interface { + Exec(query string, args ...interface{}) (sql.Result, error) +} + +type ITransaction interface { + BeginTransaction() (*Transaction, error) + Transaction(closure TxClosure) (interface{}, error) +} + +func (c *Connection) Select(query string, bindings []interface{}, dest interface{}) (result sql.Result, err error) { + var stmt *sql.Stmt + var rows *sql.Rows + stmt, err = c.DB.Prepare(query) + if err != nil { + return + } + defer stmt.Close() + rows, err = stmt.Query(bindings...) + if err != nil { + return + } + defer rows.Close() + + return ScanResult(rows, dest), nil +} + +func (c *Connection) BeginTransaction() (*Transaction, error) { + begin, err := c.DB.Begin() + if err != nil { + return nil, errors.New(err.Error()) + } + + tx := &Transaction{ + Rx: c.RxDB, + Tx: begin, + Config: c.Config, + ConnectionName: c.ConnectionName, + } + + return tx, nil +} + +func (c *Connection) Transaction(closure TxClosure) (interface{}, error) { + begin, err := c.DB.Begin() + if err != nil { + panic(err.Error()) + } + defer func() { + if err := recover(); err != nil { + _ = begin.Rollback() + panic(err) + } else { + err = begin.Commit() + } + }() + + tx := &Transaction{ + Rx: c.RxDB, + Tx: begin, + Config: c.Config, + ConnectionName: c.ConnectionName, + } + + return closure(tx) +} + +func (c *Connection) Insert(query string, bindings []interface{}) (result sql.Result, err error) { + return c.AffectingStatement(query, bindings) +} + +func (c *Connection) Update(query string, bindings []interface{}) (result sql.Result, err error) { + return c.AffectingStatement(query, bindings) + +} + +func (c *Connection) Delete(query string, bindings []interface{}) (result sql.Result, err error) { + return c.AffectingStatement(query, bindings) +} + +func (c *Connection) AffectingStatement(query string, bindings []interface{}) (result sql.Result, err error) { + stmt, errP := c.DB.Prepare(query) + if errP != nil { + err = errP + return + } + defer stmt.Close() + result, err = stmt.Exec(bindings...) + if err != nil { + return + } + + return +} + +func (c *Connection) Table(tableName string) *Builder { + builder := NewBuilder(c) + if c.GetConfig().Driver == DriverMysql { + builder.Grammar = &MysqlGrammar{} + } else if c.Config.Driver == DriverSqlite3 { + builder.Grammar = &Sqlite3Grammar{} + } else { + panic("不支持的数据库类型") + } + builder.Grammar.SetTablePrefix(c.Config.Prefix) + builder.Grammar.SetBuilder(builder) + builder.From(tableName) + return builder +} + +func (c *Connection) Statement(query string, bindings []interface{}) (sql.Result, error) { + return c.AffectingStatement(query, bindings) +} + +func (c *Connection) GetDB() *sql.DB { + return c.DB +} + +func (c *Connection) Query() *Builder { + builder := NewBuilder(c) + if c.GetConfig().Driver == DriverMysql { + builder.Grammar = &MysqlGrammar{} + } else if c.Config.Driver == DriverSqlite3 { + builder.Grammar = &Sqlite3Grammar{} + } else { + panic("不支持的数据库类型") + } + builder.Grammar.SetBuilder(builder) + builder.Grammar.SetTablePrefix(c.Config.Prefix) + return builder +} + +func (c *Connection) GetConfig() *DBConfig { + return c.Config +} diff --git a/connection_factory.go b/connection_factory.go new file mode 100644 index 0000000..23a2341 --- /dev/null +++ b/connection_factory.go @@ -0,0 +1,45 @@ +package db + +import ( + "errors" + "fmt" +) + +const ( + DriverMysql = "mysql" + DriverSqlite3 = "sqlite3" +) + +type ConnectionFactory struct { +} + +type Connector interface { + connect(config *DBConfig) *Connection +} + +func (f ConnectionFactory) Make(config *DBConfig) *Connection { + return f.MakeConnection(config) +} + +func (f ConnectionFactory) MakeConnection(config *DBConfig) *Connection { + return f.CreateConnection(config) +} + +func (f ConnectionFactory) CreateConnection(config *DBConfig) *Connection { + switch config.Driver { + case DriverMysql: + connector := MysqlConnector{} + conn := connector.connect(config) + return conn + case DriverSqlite3: + connector := SqliteConnector{} + conn := connector.connect(config) + return conn + case "": + panic(errors.New("a driver must be specified")) + default: + panic(errors.New(fmt.Sprintf("unsupported driver:%s", config.Driver))) + } + + return nil +} diff --git a/db.go b/db.go new file mode 100644 index 0000000..217fdef --- /dev/null +++ b/db.go @@ -0,0 +1,58 @@ +package db + +import ( + "database/sql" +) + +var Engine *DB + +type DB struct { + DatabaseManager + LogFunc func(log Log) +} + +func (d *DB) SetLogger(f func(log Log)) *DB { + d.LogFunc = f + return d +} + +func Open(config map[string]DBConfig) *DB { + var configP = make(map[string]*DBConfig) + + for name := range config { + c := config[name] + configP[name] = &c + } + + db := DB{ + DatabaseManager: DatabaseManager{ + Configs: configP, + Connections: make(map[string]*Connection), + }, + } + + db.Connection("default") + + Engine = &db + + return Engine +} + +func (d *DB) AddConfig(name string, config *DBConfig) *DB { + Engine.Configs[name] = config + return d +} + +func (d *DB) GetConfigs() map[string]*DBConfig { + return Engine.Configs +} + +func (*DB) Raw(connectionName ...string) *sql.DB { + if len(connectionName) > 0 { + c := Engine.Connection(connectionName[0]) + return (*c).GetDB() + } else { + c := Engine.Connection("default") + return (*c).GetDB() + } +} diff --git a/db_manager.go b/db_manager.go new file mode 100644 index 0000000..4489854 --- /dev/null +++ b/db_manager.go @@ -0,0 +1,108 @@ +package db + +import ( + "database/sql" + "errors" + "fmt" +) + +type DatabaseManager struct { + Connections map[string]*Connection + Factory ConnectionFactory + Configs map[string]*DBConfig +} + +func (dm *DatabaseManager) Connection(connectionName string) *Connection { + connection, ok := dm.Connections[connectionName] + if !ok { + dm.Connections[connectionName] = dm.MakeConnection(connectionName) + connection, _ = dm.Connections[connectionName] + } + return connection +} + +func (dm *DatabaseManager) MakeConnection(connectionName string) *Connection { + config, ok := dm.Configs[connectionName] + + if !ok { + panic(errors.New(fmt.Sprintf("Database connection %s not configured.", connectionName))) + } + + conn := dm.Factory.Make(config) + conn.ConnectionName = connectionName + dm.Connections[connectionName] = conn + return conn +} + +func (dm *DatabaseManager) getDefaultConnection() (defaultConnectionName string) { + defaultConnectionName = "default" + return +} + +func (dm *DatabaseManager) Table(params ...string) *Builder { + defaultConn := dm.getDefaultConnection() + c := dm.Connection(defaultConn) + builder := NewBuilder(c) + if c.Config.Driver == DriverMysql { + builder.Grammar = &MysqlGrammar{} + } else if c.Config.Driver == DriverSqlite3 { + builder.Grammar = &Sqlite3Grammar{} + } else { + panic("不支持的数据库类型") + } + builder.Grammar.SetTablePrefix(dm.Configs[defaultConn].Prefix) + builder.Grammar.SetBuilder(builder) + builder.Table(params...) + return builder +} + +func (dm *DatabaseManager) Select(query string, bindings []interface{}, dest interface{}) (sql.Result, error) { + ic := dm.Connections["default"] + return (*ic).Select(query, bindings, dest) +} + +func (dm *DatabaseManager) Insert(query string, bindings []interface{}) (sql.Result, error) { + ic := dm.Connections["default"] + return (*ic).Insert(query, bindings) +} + +func (dm *DatabaseManager) Update(query string, bindings []interface{}) (sql.Result, error) { + ic := dm.Connections["default"] + return (*ic).Update(query, bindings) +} + +func (dm *DatabaseManager) Delete(query string, bindings []interface{}) (sql.Result, error) { + ic := dm.Connections["default"] + return (*ic).Delete(query, bindings) +} + +func (dm *DatabaseManager) Statement(query string, bindings []interface{}) (sql.Result, error) { + ic := dm.Connections["default"] + return (*ic).Delete(query, bindings) +} + +func (dm *DatabaseManager) Query() *Builder { + defaultConn := dm.getDefaultConnection() + c := dm.Connection(defaultConn) + builder := NewBuilder(c) + if c.Config.Driver == DriverMysql { + builder.Grammar = &MysqlGrammar{} + } else if c.Config.Driver == DriverSqlite3 { + builder.Grammar = &Sqlite3Grammar{} + } else { + panic("不支持的数据库类型") + } + builder.Grammar.SetTablePrefix(dm.Configs[defaultConn].Prefix) + builder.Grammar.SetBuilder(builder) + return builder +} + +func (dm *DatabaseManager) Transaction(closure TxClosure) (interface{}, error) { + ic := dm.Connections["default"] + return (*ic).Transaction(closure) +} + +func (dm *DatabaseManager) BeginTransaction() (*Transaction, error) { + ic := dm.Connections["default"] + return (*ic).BeginTransaction() +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..52c5e7f --- /dev/null +++ b/go.mod @@ -0,0 +1,11 @@ +module git.fsdpf.net/go/db + +go 1.18 + +require ( + github.com/go-sql-driver/mysql v1.7.0 + github.com/mattn/go-sqlite3 v1.14.16 + github.com/samber/lo v1.38.1 +) + +require golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 // indirect diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..e4a9a27 --- /dev/null +++ b/go.sum @@ -0,0 +1,8 @@ +github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc= +github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= +github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y= +github.com/mattn/go-sqlite3 v1.14.16/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= +github.com/samber/lo v1.38.1 h1:j2XEAqXKb09Am4ebOg31SpvzUTTs6EN3VfgeLUhPdXM= +github.com/samber/lo v1.38.1/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA= +golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 h1:3MTrJm4PyNL9NBqvYDSj3DHl46qQakyfqfWo4jgfaEM= +golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE= diff --git a/gramar.go b/gramar.go new file mode 100644 index 0000000..9b7138b --- /dev/null +++ b/gramar.go @@ -0,0 +1,34 @@ +package db + +type IGrammar interface { + SetTablePrefix(prefix string) + + GetTablePrefix() string + + SetBuilder(builder *Builder) + + GetBuilder() *Builder + + CompileInsert([]map[string]interface{}) string + + CompileInsertOrIgnore([]map[string]interface{}) string + + CompileDelete() string + + CompileUpdate(map[string]interface{}) string + + CompileSelect() string + + CompileExists() string + + Wrap(interface{}, ...bool) string + + WrapTable(interface{}) string + + CompileComponentWheres() string + + CompileComponentJoins() string + + CompileRandom(seed string) string + //Wrap(value string, b *query.Builder) string +} diff --git a/mysql_connector.go b/mysql_connector.go new file mode 100644 index 0000000..2a02cd6 --- /dev/null +++ b/mysql_connector.go @@ -0,0 +1,103 @@ +package db + +import ( + // "context" + "database/sql" + "fmt" + _ "github.com/go-sql-driver/mysql" + "strings" + "time" +) + +type MysqlConnector struct { +} + +func (c MysqlConnector) connect(config *DBConfig) *Connection { + /** + [username[:password]@][protocol[(address)]]/dbname[?param1=value1&...¶mN=valueN] + // user@unix(/path/to/socket)/dbname + // root:pw@unix(/tmp/mysql.sock)/myDatabase?loc=Local + // user:password@tcp(localhost:5555)/dbname?tls=skip-verify&autocommit=true + // user:password@/dbname?sql_mode=TRADITIONAL + // user:password@tcp([de:ad:be:ef::ca:fe]:80)/dbname?timeout=90s&collation=utf8mb4_unicode_ci + // id:password@tcp(your-amazonaws-uri.com:3306)/dbname + // user@cloudsql(project-id:instance-name)/dbname + // user@cloudsql(project-id:regionname:instance-name)/dbname + // user:password@tcp/dbname?charset=utf8mb4,utf8&sys_var=esc%40ped + // user:password@/dbname + // user:password@/ + */ + //TODO: Protocol loc readTimeout serverPubKey timeout + if config.MaxOpenConns == 0 { + config.MaxOpenConns = 15 + } + if config.MaxIdleConns == 0 { + config.MaxIdleConns = 5 + } + if config.ConnMaxLifetime == 0 { + config.ConnMaxLifetime = 86400 + } + if config.ConnMaxIdleTime == 0 { + config.ConnMaxIdleTime = 7200 + } + var params []string + + if len(config.Charset) > 0 { + params = append(params, "charset="+config.Charset) + } + if len(config.Collation) > 0 { + params = append(params, "collation="+config.Collation) + } + if config.MultiStatements { + params = append(params, "multiStatements=true") + } + if config.ParseTime { + params = append(params, "parseTime=true") + } + var dsn string + if len(config.Dsn) > 0 { + dsn = config.Dsn + } else { + dsn = fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?%s", config.Username, config.Password, config.Host, config.Port, config.Database, strings.Join(params, "&")) + } + db, err := sql.Open(DriverMysql, dsn) + + if err != nil { + panic(err.Error()) + } else if err := db.Ping(); err != nil { + panic(err) + } else { + // 最大链接数, 默认0, 不限制 + db.SetMaxOpenConns(config.MaxOpenConns) + // 最大空闲链接数, 默认2 + db.SetMaxIdleConns(config.MaxIdleConns) + // 表示在连接池中连接的最大生存时间, 默认0, 表示不限制。 + db.SetConnMaxLifetime(time.Duration(config.ConnMaxLifetime) * time.Second) + // 连接池中空闲连接的最大生存时间, 默认0, 表示不限制。 + db.SetConnMaxIdleTime(time.Duration(config.ConnMaxIdleTime) * time.Second) + } + + // 优化事物查询速度 + // Golang 在 MYSQL 事物中不能并发执行查询, 会出现 driver:bad connection, + // 为了在事物中执行并发查询, 单独开一个脏读查询链接 + + // 处理 driver:bad connection, 需要加锁 + // 同一事务开启多协程的同时如果有并发读, 那可能会出现 driver:bad connection 错误, + // 原因是同一事务同一时间只能有一个可以进行读操作, 读完之后需要将查询得到的Rows关闭. + if len(params) > 0 { + dsn = dsn + `&` + } + RxDB, err := sql.Open(DriverMysql, dsn+`transaction_isolation='read-uncommitted'`) + if err != nil { + panic("开启事物链接失败, " + err.Error()) + } else { + RxDB.SetMaxOpenConns(5) + RxDB.SetMaxIdleConns(5) + } + + return &Connection{ + RxDB: RxDB, + DB: db, + Config: config, + } +} diff --git a/mysql_grammar.go b/mysql_grammar.go new file mode 100644 index 0000000..4c9be7d --- /dev/null +++ b/mysql_grammar.go @@ -0,0 +1,569 @@ +package db + +import ( + "errors" + "fmt" + "strings" +) + +type MysqlGrammar struct { + Prefix string + Builder *Builder +} + +func (m *MysqlGrammar) SetTablePrefix(prefix string) { + m.Prefix = prefix +} + +func (m *MysqlGrammar) GetTablePrefix() string { + return m.Prefix +} + +func (m *MysqlGrammar) SetBuilder(builder *Builder) { + m.Builder = builder +} + +func (m *MysqlGrammar) GetBuilder() *Builder { + return m.Builder +} + +func (m *MysqlGrammar) CompileInsert(values []map[string]interface{}) string { + b := m.GetBuilder() + b.PreSql.WriteString("insert into ") + b.PreSql.WriteString(m.CompileComponentTable()) + b.PreSql.WriteString(" (") + first := values[0] + length := len(values) + var columns []string + var columnizeVars []interface{} + for key := range first { + if (b.OnlyColumns == nil && b.ExceptColumns == nil) || b.FileterColumn(key) { + columns = append(columns, key) + columnizeVars = append(columnizeVars, key) + } + } + columnLength := len(columns) + b.PreSql.WriteString(m.columnize(columnizeVars)) + b.PreSql.WriteString(") values") + + for k, v := range values { + b.PreSql.WriteString(" (") + for i, key := range columns { + b.PreSql.WriteString(m.parameter(v[key])) + b.AddBinding([]interface{}{v[key]}, TYPE_INSERT) + if i != columnLength-1 { + b.PreSql.WriteString(", ") + } else { + b.PreSql.WriteString(")") + } + } + if k != length-1 { + b.PreSql.WriteString(", ") + } + } + b.PreparedSql = b.PreSql.String() + b.PreSql.Reset() + return b.PreparedSql +} + +func (m *MysqlGrammar) CompileInsertOrIgnore(values []map[string]interface{}) string { + m.GetBuilder().PreparedSql = strings.Replace(m.CompileInsert(values), "insert", "insert ignore", 1) + return m.GetBuilder().PreparedSql +} +func (m *MysqlGrammar) CompileDelete() string { + b := m.GetBuilder() + b.PreSql.WriteString("delete from ") + b.PreSql.WriteString(m.CompileComponentTable()) + b.PreSql.WriteString(m.CompileComponentWheres()) + if len(b.Orders) > 0 { + b.PreSql.WriteString(m.CompileComponentOrders()) + } + if b.LimitNum > 0 { + b.PreSql.WriteString(m.CompileComponentLimitNum()) + } + m.GetBuilder().PreparedSql = m.GetBuilder().PreSql.String() + b.PreSql.Reset() + return m.GetBuilder().PreparedSql +} + +func (m *MysqlGrammar) CompileUpdate(value map[string]interface{}) string { + b := m.GetBuilder() + b.PreSql.WriteString("update ") + b.PreSql.WriteString(m.CompileComponentTable()) + b.PreSql.WriteString(" set ") + count := 0 + length := len(value) + for k, v := range value { + count++ + if (b.OnlyColumns == nil && b.ExceptColumns == nil) || b.FileterColumn(k) { + b.PreSql.WriteString(m.Wrap(k)) + b.PreSql.WriteString(" = ") + b.AddBinding([]interface{}{v}, TYPE_UPDATE) + if e, ok := v.(Expression); ok { + b.PreSql.WriteString(string(e)) + } else { + b.PreSql.WriteString(m.parameter(v)) + } + } + if count != length { + b.PreSql.WriteString(" , ") + } + } + b.PreSql.WriteString(m.CompileComponentWheres()) + m.GetBuilder().PreparedSql = b.PreSql.String() + b.PreSql.Reset() + return m.GetBuilder().PreparedSql +} + +func (m *MysqlGrammar) CompileSelect() string { + b := m.GetBuilder() + b.PreparedSql = "" + b.PreSql = strings.Builder{} + if _, ok := b.Components[TYPE_COLUMN]; !ok || len(b.Columns) == 0 { + b.Components[TYPE_COLUMN] = struct{}{} + b.Columns = append(b.Columns, "*") + } + for _, componentName := range SelectComponents { + if _, ok := b.Components[componentName]; ok { + b.PreSql.WriteString(m.compileComponent(componentName)) + } + } + b.PreparedSql = b.PreSql.String() + b.PreSql.Reset() + return b.PreparedSql +} + +func (m *MysqlGrammar) CompileExists() string { + sql := m.CompileSelect() + m.GetBuilder().PreparedSql = fmt.Sprintf("select exists(%s) as %s", sql, m.Wrap("exists")) + return m.GetBuilder().PreparedSql +} + +func (m *MysqlGrammar) compileComponent(componentName string) string { + switch componentName { + case TYPE_AGGREGRATE: + return m.CompileComponentAggregate() + case TYPE_COLUMN: + return m.CompileComponentColumns() + case TYPE_FROM: + return m.CompileComponentFromTable() + case TYPE_JOIN: + return m.CompileComponentJoins() + case TYPE_WHERE: + return m.CompileComponentWheres() + case TYPE_GROUP_BY: + return m.CompileComponentGroups() + case TYPE_HAVING: + return m.CompileComponentHavings() + case TYPE_ORDER: + return m.CompileComponentOrders() + case TYPE_LIMIT: + return m.CompileComponentLimitNum() + case TYPE_OFFSET: + return m.CompileComponentOffsetNum() + case "unions": + case TYPE_LOCK: + return m.CompileLock() + } + return "" +} +func (m *MysqlGrammar) CompileComponentAggregate() string { + builder := strings.Builder{} + aggregate := m.GetBuilder().Aggregates[0] + builder.WriteString("select ") + builder.WriteString(aggregate.AggregateName) + builder.WriteString("(") + if m.GetBuilder().IsDistinct && aggregate.AggregateColumn != "*" { + builder.WriteString("distinct ") + builder.WriteString(m.Wrap(aggregate.AggregateColumn)) + } else { + builder.WriteString(m.Wrap(aggregate.AggregateColumn)) + } + builder.WriteString(") as aggregate") + return builder.String() +} + +// Convert []string column names into a delimited string. +// Compile the "select *" portion of the query. +func (m *MysqlGrammar) CompileComponentColumns() string { + if len(m.GetBuilder().Aggregates) == 0 { + builder := strings.Builder{} + if m.GetBuilder().IsDistinct { + builder.WriteString("select distinct ") + } else { + builder.WriteString("select ") + } + builder.WriteString(m.columnize(m.GetBuilder().Columns)) + return builder.String() + } + + return "" + +} + +func (m *MysqlGrammar) CompileComponentFromTable() string { + builder := strings.Builder{} + builder.WriteString(" from ") + builder.WriteString(m.WrapTable(m.GetBuilder().FromTable)) + return builder.String() +} +func (m *MysqlGrammar) CompileComponentTable() string { + return m.WrapTable(m.GetBuilder().FromTable) +} +func (m *MysqlGrammar) CompileComponentJoins() string { + builder := strings.Builder{} + for _, join := range m.GetBuilder().Joins { + var tableAndNestedJoins string + if len(join.Joins) > 0 { + //nested join + tableAndNestedJoins = fmt.Sprintf("(%s%s)", m.WrapTable(join.JoinTable), join.Grammar.CompileComponentJoins()) + } else { + tableAndNestedJoins = m.WrapTable(join.JoinTable) + } + onStr := join.Grammar.CompileComponentWheres() + s := "" + if len(onStr) > 0 { + s = fmt.Sprintf(" %s join %s %s", join.JoinType, tableAndNestedJoins, strings.TrimSpace(onStr)) + } else { + s = fmt.Sprintf(" %s join %s", join.JoinType, tableAndNestedJoins) + } + builder.WriteString(s) + } + return builder.String() +} + +// db.Query("select * from `groups` where `id` in (?,?,?,?)", []interface{}{7,"8","9","10"}...) +func (m *MysqlGrammar) CompileComponentWheres() string { + if len(m.GetBuilder().Wheres) == 0 { + return "" + } + builder := strings.Builder{} + if m.GetBuilder().JoinBuilder { + builder.WriteString(" on ") + } else { + builder.WriteString(" where ") + } + for i, w := range m.GetBuilder().Wheres { + if i != 0 { + builder.WriteString(" " + w.Boolean + " ") + } + if w.Type == CONDITION_TYPE_NESTED { + builder.WriteString("(") + cloneBuilder := w.Value.(*Builder) + for j := 0; j < len(cloneBuilder.Wheres); j++ { + nestedWhere := cloneBuilder.Wheres[j] + if j != 0 { + builder.WriteString(" " + nestedWhere.Boolean + " ") + } + //when compile nested where,we need bind the generated sql and params to the current builder + g := cloneBuilder.Grammar.(*MysqlGrammar) + //this will bind params to current builder when call m.parameter() + g.SetBuilder(m.GetBuilder()) + nestedSql := strings.TrimSpace(g.CompileWhere(nestedWhere)) + nestedSql = strings.Replace(nestedSql, " where ", "", 1) + builder.WriteString(nestedSql) + } + builder.WriteString(")") + } else if w.Type == CONDITION_TYPE_SUB { + builder.WriteString(m.Wrap(w.Column)) + builder.WriteString(" " + w.Operator + " ") + builder.WriteString("(") + cb := CloneBuilderWithTable(m.GetBuilder()) + + if clousure, ok := w.Value.(func(builder *Builder)); ok { + clousure(cb) + sql := cb.Grammar.CompileSelect() + builder.WriteString(sql) + } + builder.WriteString(")") + + } else { + builder.WriteString(m.CompileWhere(w)) + } + + } + return builder.String() +} +func (m *MysqlGrammar) CompileWhere(w Where) (sql string) { + var sqlBuilder strings.Builder + switch w.Type { + case CONDITION_TYPE_BASIC: + sqlBuilder.WriteString(m.Wrap(w.Column)) + sqlBuilder.WriteString(" " + w.Operator + " ") + sqlBuilder.WriteString(m.parameter(w.Value)) + case CONDITION_TYPE_BETWEEN: + sqlBuilder.WriteString(m.Wrap(w.Column)) + if w.Not { + sqlBuilder.WriteString(" not between ") + } else { + sqlBuilder.WriteString(" between ") + } + sqlBuilder.WriteString(m.parameter(w.Values[0])) + sqlBuilder.WriteString(" and ") + sqlBuilder.WriteString(m.parameter(w.Values[1])) + case CONDITION_TYPE_BETWEEN_COLUMN: + sqlBuilder.WriteString(m.Wrap(w.Column)) + if w.Not { + sqlBuilder.WriteString(" not between ") + } else { + sqlBuilder.WriteString(" between ") + } + sqlBuilder.WriteString(m.Wrap(w.Values[0])) + sqlBuilder.WriteString(" and ") + sqlBuilder.WriteString(m.Wrap(w.Values[1])) + case CONDITION_TYPE_IN: + if len(w.Values) == 0 { + if w.Not { + sqlBuilder.WriteString("1 = 1") + } else { + sqlBuilder.WriteString("0 = 1") + } + } else { + sqlBuilder.WriteString(m.Wrap(w.Column)) + if w.Not { + sqlBuilder.WriteString(" not in (") + } else { + sqlBuilder.WriteString(" in (") + } + sqlBuilder.WriteString(m.parameter(w.Values...)) + sqlBuilder.WriteString(")") + } + + case CONDITION_TYPE_DATE, CONDITION_TYPE_TIME, CONDITION_TYPE_DAY, CONDITION_TYPE_MONTH, CONDITION_TYPE_YEAR: + sqlBuilder.WriteString(w.Type) + sqlBuilder.WriteString("(") + sqlBuilder.WriteString(m.Wrap(w.Column)) + sqlBuilder.WriteString(") ") + sqlBuilder.WriteString(w.Operator) + sqlBuilder.WriteString(" ") + sqlBuilder.WriteString(m.parameter(w.Value)) + case CONDITION_TYPE_NULL: + sqlBuilder.WriteString(m.Wrap(w.Column)) + sqlBuilder.WriteString(" is ") + if w.Not { + sqlBuilder.WriteString("not ") + } + sqlBuilder.WriteString("null") + case CONDITION_TYPE_COLUMN: + sqlBuilder.WriteString(m.Wrap(w.FirstColumn)) + sqlBuilder.WriteString(" ") + sqlBuilder.WriteString(w.Operator) + sqlBuilder.WriteString(" ") + sqlBuilder.WriteString(m.Wrap(w.SecondColumn)) + case CONDITION_TYPE_RAW: + sqlBuilder.WriteString(string(w.RawSql.(Expression))) + case CONDITION_TYPE_NESTED: + sqlBuilder.WriteString("(") + sqlBuilder.WriteString(w.Value.(*Builder).Grammar.CompileComponentWheres()) + sqlBuilder.WriteString(") ") + case CONDITION_TYPE_EXIST: + if w.Not { + sqlBuilder.WriteString("not ") + } + sqlBuilder.WriteString("exists ") + sqlBuilder.WriteString(fmt.Sprintf("(%s)", w.Query.ToSql())) + case CONDITION_TYPE_ROW_VALUES: + var columns []interface{} + for _, column := range w.Columns { + columns = append(columns, column) + } + sqlBuilder.WriteString(fmt.Sprintf("(%s) %s (%s)", m.columnize(columns), w.Operator, m.parameter(w.Values...))) + + default: + panic("where type not Found") + } + return sqlBuilder.String() + +} +func (m *MysqlGrammar) parameter(values ...interface{}) string { + var ps []string + for _, value := range values { + if expre, ok := value.(Expression); ok { + ps = append(ps, string(expre)) + } else { + ps = append(ps, "?") + } + } + return strings.Join(ps, ",") +} +func (m *MysqlGrammar) CompileComponentGroups() string { + builder := strings.Builder{} + builder.WriteString(" group by ") + builder.WriteString(m.columnize(m.GetBuilder().Groups)) + return builder.String() +} + +func (m *MysqlGrammar) CompileComponentHavings() string { + builder := strings.Builder{} + builder.WriteString(" having ") + for i, having := range m.GetBuilder().Havings { + if i != 0 { + builder.WriteString(" " + having.HavingBoolean + " ") + } + if having.HavingType == CONDITION_TYPE_BASIC { + builder.WriteString(m.Wrap(having.HavingColumn)) + builder.WriteString(" ") + builder.WriteString(having.HavingOperator) + builder.WriteString(" ") + builder.WriteString(m.parameter(having.HavingValue)) + } else if having.HavingType == CONDITION_TYPE_RAW { + builder.WriteString(string(having.RawSql.(Expression))) + } else if having.HavingType == CONDITION_TYPE_BETWEEN { + vs := having.HavingValue.([]interface{}) + builder.WriteString(m.Wrap(having.HavingColumn)) + builder.WriteString(" ") + builder.WriteString(CONDITION_TYPE_BETWEEN) + builder.WriteString(" ") + builder.WriteString(m.parameter(vs[0])) + builder.WriteString(" and ") + builder.WriteString(m.parameter(vs[1])) + } + } + return builder.String() +} + +func (m *MysqlGrammar) CompileComponentOrders() string { + builder := strings.Builder{} + builder.WriteString(" order by ") + for i, order := range m.GetBuilder().Orders { + if i != 0 { + builder.WriteString(", ") + } + if order.OrderType == CONDITION_TYPE_RAW { + builder.WriteString(string(order.RawSql.(Expression))) + continue + } + builder.WriteString(m.Wrap(order.Column)) + builder.WriteString(" ") + builder.WriteString(order.Direction) + } + return builder.String() +} + +func (m *MysqlGrammar) CompileComponentLimitNum() string { + if m.GetBuilder().LimitNum >= 0 { + builder := strings.Builder{} + builder.WriteString(" limit ") + builder.WriteString(fmt.Sprintf("%v", m.GetBuilder().LimitNum)) + return builder.String() + } + return "" + +} + +func (m *MysqlGrammar) CompileComponentOffsetNum() string { + if m.GetBuilder().OffsetNum >= 0 { + builder := strings.Builder{} + builder.WriteString(" offset ") + builder.WriteString(fmt.Sprintf("%v", m.GetBuilder().OffsetNum)) + return builder.String() + } + return "" +} +func (m *MysqlGrammar) CompileLock() string { + switch m.GetBuilder().LockMode.(type) { + case string: + return " " + m.GetBuilder().LockMode.(string) + case bool: + boolean := m.GetBuilder().LockMode.(bool) + if boolean { + return " for update" + } else { + return " lock in share mode" + } + case nil: + return " for update" + } + return "" +} +func (m *MysqlGrammar) columnize(columns []interface{}) string { + builder := strings.Builder{} + var t []string + for _, value := range columns { + if s, ok := value.(string); ok { + t = append(t, m.Wrap(s)) + } else if e, ok := value.(Expression); ok { + t = append(t, string(e)) + } + } + builder.WriteString(strings.Join(t, ", ")) + return builder.String() +} + +/* +Wrap a value in keyword identifiers. +*/ +func (m *MysqlGrammar) Wrap(value interface{}, prefixAlias ...bool) string { + prefix := false + if expr, ok := value.(Expression); ok { + return string(expr) + } + str := value.(string) + if strings.Contains(str, " as ") { + if len(prefixAlias) > 0 && prefixAlias[0] { + prefix = true + } + return m.WrapAliasedValue(str, prefix) + } + return m.WrapSegments(strings.Split(str, ".")) +} + +func (m *MysqlGrammar) WrapAliasedValue(value string, prefixAlias ...bool) string { + var result strings.Builder + separator := " as " + segments := strings.SplitN(value, separator, 2) + if len(prefixAlias) > 0 && prefixAlias[0] { + segments[1] = m.GetTablePrefix() + segments[1] + } + result.WriteString(m.Wrap(segments[0])) + result.WriteString(" as ") + result.WriteString(m.WrapValue(segments[1])) + return result.String() +} + +/* +WrapSegments Wrap the given value segments. + +user.name => "prefix_user"."name" +*/ +func (m *MysqlGrammar) WrapSegments(values []string) string { + var segments []string + paramLength := len(values) + for i, value := range values { + if paramLength > 1 && i == 0 { + segments = append(segments, m.WrapTable(value)) + } else { + segments = append(segments, m.WrapValue(value)) + } + } + return strings.Join(segments, ".") +} + +/* +WrapTable wrap a table in keyword identifiers. +*/ +func (m *MysqlGrammar) WrapTable(tableName interface{}) string { + if str, ok := tableName.(string); ok { + return m.Wrap(m.GetTablePrefix()+str, true) + } else if expr, ok := tableName.(Expression); ok { + return string(expr) + } else { + panic(errors.New("tablename type mismatch")) + } +} + +// table => `table` +// t1"t2 => `t1``t2` +func (m *MysqlGrammar) WrapValue(value string) string { + if value != "*" { + return fmt.Sprintf("`%s`", strings.ReplaceAll(value, "`", "``")) + } + return value +} + +/* +CompileRandom Compile the random statement into SQL. +*/ +func (m *MysqlGrammar) CompileRandom(seed string) string { + return fmt.Sprintf("RAND(%s)", seed) +} diff --git a/raw.go b/raw.go new file mode 100644 index 0000000..c7d35ef --- /dev/null +++ b/raw.go @@ -0,0 +1,7 @@ +package db + +type Expression string + +func Raw(expr string) Expression { + return Expression(expr) +} diff --git a/scan.go b/scan.go new file mode 100644 index 0000000..407f6b8 --- /dev/null +++ b/scan.go @@ -0,0 +1,279 @@ +package db + +import ( + "database/sql" + "encoding/json" + "errors" + "reflect" +) + +type RowScan struct { + Count int +} + +func (RowScan) LastInsertId() (int64, error) { + return 0, errors.New("no insert in select") +} + +func (v RowScan) RowsAffected() (int64, error) { + return int64(v.Count), nil +} + +// 入口 +func ScanResult(rows *sql.Rows, dest any) (result RowScan) { + realDest := reflect.Indirect(reflect.ValueOf(dest)) + + if reflect.TypeOf(dest).Kind() != reflect.Ptr { + panic("dest 请传入指针类型") + } + + if realDest.Kind() == reflect.Slice { + slice := realDest.Type() + sliceItem := slice.Elem() + if sliceItem.Kind() == reflect.Map { + return sMapSlice(rows, dest) + } else if sliceItem.Kind() == reflect.Struct { + return sStructSlice(rows, dest) + } else if sliceItem.Kind() == reflect.Ptr { + if sliceItem.Elem().Kind() == reflect.Struct { + return sStructSlice(rows, dest) + } + } + return sValues(rows, dest) + } else if realDest.Kind() == reflect.Struct { + return sStruct(rows, dest) + } else if realDest.Kind() == reflect.Map { + return sMap(rows, dest) + } else { + for rows.Next() { + result.Count++ + err := rows.Scan(dest) + if err != nil { + panic(err.Error()) + } + } + } + return result +} + +// to struct slice +func sStructSlice(rows *sql.Rows, dest any) (result RowScan) { + realDest := reflect.Indirect(reflect.ValueOf(dest)) + slice := realDest.Type() + sliceItem := slice.Elem() + itemIsPtr := sliceItem.Kind() == reflect.Ptr + keys := map[string]int{} + columns, _ := rows.Columns() + + lenField := 0 + if itemIsPtr { + lenField = sliceItem.Elem().NumField() + } else { + lenField = sliceItem.NumField() + } + + // 缓存 dest 结构 + // key is field name, value is field index + for i := 0; i < lenField; i++ { + var tField reflect.StructField + if itemIsPtr { + tField = sliceItem.Elem().Field(i) + } else { + tField = sliceItem.Field(i) + } + kFiled := tField.Tag.Get("db") + if kFiled == "" { + kFiled = tField.Name + } + keys[kFiled] = i + } + + for rows.Next() { + result.Count++ + var v, vp reflect.Value + if itemIsPtr { + vp = reflect.New(sliceItem.Elem()) + v = reflect.Indirect(vp) + } else { + vp = reflect.New(sliceItem) + v = reflect.Indirect(vp) + } + + scanArgs := make([]any, len(columns)) + + // 初始化 map 字段 + for i, k := range columns { + if f, ok := keys[k]; ok { + scanArgs[i] = v.Field(f).Addr().Interface() + } else { + scanArgs[i] = new(any) + } + } + + if err := rows.Scan(scanArgs...); err != nil { + panic(err) + } + + if itemIsPtr { + realDest.Set(reflect.Append(realDest, vp)) + } else { + realDest.Set(reflect.Append(realDest, v)) + } + } + return result +} + +// to struct +func sStruct(rows *sql.Rows, dest any) (result RowScan) { + columns, _ := rows.Columns() + realDest := reflect.Indirect(reflect.ValueOf(dest)) + vp := reflect.New(realDest.Type()) + v := reflect.Indirect(vp) + + keys := map[string]int{} + + for i := 0; i < v.NumField(); i++ { + tField := v.Type().Field(i) + kFiled := tField.Tag.Get("db") + if kFiled == "" { + kFiled = tField.Name + } + keys[kFiled] = i + } + + for rows.Next() { + if result.Count == 0 { + scanArgs := make([]any, len(columns)) + + // 初始化 map 字段 + for i, k := range columns { + if f, ok := keys[k]; ok { + scanArgs[i] = v.Field(f).Addr().Interface() + } else { + scanArgs[i] = new(any) + } + } + + if err := rows.Scan(scanArgs...); err != nil { + panic(err) + } + + if realDest.Kind() == reflect.Ptr { + realDest.Set(vp) + } else { + realDest.Set(v) + } + } + + result.Count++ + } + + return result +} + +// to map slice +func sMapSlice(rows *sql.Rows, dest any) (result RowScan) { + columns, _ := rows.Columns() + realDest := reflect.Indirect(reflect.ValueOf(dest)) + + for rows.Next() { + scanArgs := make([]any, len(columns)) + element := make(map[string]any) + + result.Count++ + + // 初始化 map 字段 + for i := 0; i < len(columns); i++ { + scanArgs[i] = new(any) + } + + // 填充 map 字段 + if err := rows.Scan(scanArgs...); err != nil { + panic(err.Error()) + } + + // 将 scan 后的 map 填充到结果集 + for i, column := range columns { + switch v := (*scanArgs[i].(*any)).(type) { + case []uint8: + data := string(v) + if len(data) > 2 && (data[0:1] == "{" || data[0:1] == "[") { + jData := new(any) + if err := json.Unmarshal(v, jData); err == nil { + element[column] = *jData + } + } else { + element[column] = data + } + default: + // int64 + element[column] = v + } + } + + realDest.Set(reflect.Append(realDest, reflect.ValueOf(element))) + } + + return result +} + +// to map +func sMap(rows *sql.Rows, dest any) (result RowScan) { + columns, _ := rows.Columns() + realDest := reflect.Indirect(reflect.ValueOf(dest)) + + for rows.Next() { + if result.Count == 0 { + scanArgs := make([]interface{}, len(columns)) + // 初始化 map 字段 + for i := 0; i < len(columns); i++ { + scanArgs[i] = new(any) + } + + if err := rows.Scan(scanArgs...); err != nil { + panic(err) + } + + for i, column := range columns { + var rv reflect.Value + switch v := (*scanArgs[i].(*any)).(type) { + case []uint8: + data := string(v) + if len(data) > 2 && (data[0:1] == "{" || data[0:1] == "[") { + jData := new(any) + if err := json.Unmarshal(v, jData); err == nil { + rv = reflect.ValueOf(jData).Elem() + } + } else { + rv = reflect.ValueOf(data) + } + default: + // int64 + rv = reflect.ValueOf(v) + } + realDest.SetMapIndex(reflect.ValueOf(column), rv) + } + } + + result.Count++ + } + + return result +} + +// pluck, 如 []int{} +func sValues(rows *sql.Rows, dest any) (result RowScan) { + columns, _ := rows.Columns() + realDest := reflect.Indirect(reflect.ValueOf(dest)) + scanArgs := make([]interface{}, len(columns)) + for rows.Next() { + result.Count++ + scanArgs[0] = reflect.New(realDest.Type().Elem()).Interface() + err := rows.Scan(scanArgs...) + if err != nil { + panic(err) + } + realDest.Set(reflect.Append(realDest, reflect.ValueOf(reflect.ValueOf(scanArgs[0]).Elem().Interface()))) + } + return result +} diff --git a/scan_test.go b/scan_test.go new file mode 100644 index 0000000..902f726 --- /dev/null +++ b/scan_test.go @@ -0,0 +1,73 @@ +package db + +import ( + "testing" +) + +var oDB *DB + +func init() { + oDB = Open(map[string]DBConfig{ + "default": { + Driver: "mysql", + Host: "localhost", + Port: "3366", + Database: "demo", + Username: "demo", + Password: "ded86bf25d661bb723f3898b2440dd678382e2dd", + Charset: "utf8mb4", + MultiStatements: true, + // ParseTime: true, + }, + }) +} + +func TestScanMapSlice(t *testing.T) { + dest := []map[string]any{} + + oDB.Select("select * from fsm", []any{}, &dest) + + t.Log(dest) +} + +func TestScanValues(t *testing.T) { + + dest := []int{} + + oDB.Select("select id from fsm", []any{}, &dest) + + t.Log(dest) +} + +func TestScanMap(t *testing.T) { + + dest := map[string]any{} + + oDB.Select("select * from fsm where id = 3", []any{}, &dest) + + t.Log(dest) +} + +func TestScanStruct(t *testing.T) { + + dest := struct { + Id int `db:"id"` + Code string `db:"code"` + }{} + + oDB.Select("select * from fsm", []any{}, &dest) + + t.Log(dest) +} + +func TestScanStructSlice(t *testing.T) { + + dest := []struct { + Id int `db:"id"` + Code string `db:"code"` + }{} + + oDB.Select("select * from fsm", []any{}, &dest) + + t.Log(dest) +} diff --git a/schema/blueprint.go b/schema/blueprint.go new file mode 100644 index 0000000..83956fc --- /dev/null +++ b/schema/blueprint.go @@ -0,0 +1,456 @@ +package schema + +import ( + "strings" + + "github.com/samber/lo" + + "git.fsdpf.net/go/db" +) + +type Blueprint struct { + table string // the table the blueprint describes. + columns []*ColumnDefinition // columns that should be added to the table + commands []*Command // + temporary bool // Whether to make the table temporary. + charset string // The default character set that should be used for the table. + collation string // The collation that should be used for the table. + engine string // The engine that should be used for the table. +} + +type Command struct { + Type string + CommandOptions +} + +type CommandOptions struct { + Index string + Columns []string + Algorithm string + To string + From string +} + +func NewBlueprint(table string) *Blueprint { + return &Blueprint{table: table, charset: "utf8mb4", collation: "utf8mb4_general_ci"} +} + +// 字符串 +func (this *Blueprint) Char(column string, length int) *ColumnDefinition { + if length == 0 { + length = 255 + } + return this.addColumn("char", column, &ColumnOptions{Length: length}) +} + +// 可变长度字符串 +func (this *Blueprint) String(column string, length int) *ColumnDefinition { + if length == 0 { + length = 255 + } + return this.addColumn("string", column, &ColumnOptions{Length: length}) +} + +// 文本 +func (this *Blueprint) Text(column string) *ColumnDefinition { + return this.addColumn("text", column, nil) +} + +// 整型 +func (this *Blueprint) Integer(column string, params ...bool) *ColumnDefinition { + autoIncrement := false + unsigned := false + if len(params) > 0 { + autoIncrement = params[0] + } + if len(params) > 1 { + unsigned = params[1] + } + return this.addColumn("integer", column, &ColumnOptions{autoIncrement: autoIncrement, unsigned: unsigned}) +} + +// 迷你整型 1 byte +func (this *Blueprint) TinyInteger(column string, params ...bool) *ColumnDefinition { + autoIncrement := false + unsigned := false + if len(params) > 0 { + autoIncrement = params[0] + } + if len(params) > 1 { + unsigned = params[1] + } + return this.addColumn("tinyInteger", column, &ColumnOptions{autoIncrement: autoIncrement, unsigned: unsigned}) +} + +// 小整型 2 byte +func (this *Blueprint) SmallInteger(column string, params ...bool) *ColumnDefinition { + autoIncrement := false + unsigned := false + if len(params) > 0 { + autoIncrement = params[0] + } + if len(params) > 1 { + unsigned = params[1] + } + return this.addColumn("smallInteger", column, &ColumnOptions{autoIncrement: autoIncrement, unsigned: unsigned}) +} + +// 大整型 2 byte +func (this *Blueprint) BigInteger(column string, params ...bool) *ColumnDefinition { + autoIncrement := false + unsigned := false + if len(params) > 0 { + autoIncrement = params[0] + } + if len(params) > 1 { + unsigned = params[1] + } + return this.addColumn("bigInteger", column, &ColumnOptions{autoIncrement: autoIncrement, unsigned: unsigned}) +} + +// 无符号整型 +func (this *Blueprint) UnsignedInteger(column string, params ...bool) *ColumnDefinition { + autoIncrement := false + if len(params) > 0 { + autoIncrement = params[0] + } + return this.Integer(column, autoIncrement, true) +} + +// 无符号迷你整型 1 byte +func (this *Blueprint) UnsignedTinyInteger(column string, params ...bool) *ColumnDefinition { + autoIncrement := false + if len(params) > 0 { + autoIncrement = params[0] + } + return this.TinyInteger(column, autoIncrement, true) +} + +// 无符号小整型 2 byte +func (this *Blueprint) UnsignedSmallInteger(column string, params ...bool) *ColumnDefinition { + autoIncrement := false + if len(params) > 0 { + autoIncrement = params[0] + } + return this.SmallInteger(column, autoIncrement, true) +} + +// 无符号大整型 +func (this *Blueprint) UnsignedBigInteger(column string, params ...bool) *ColumnDefinition { + autoIncrement := false + if len(params) > 0 { + autoIncrement = params[0] + } + return this.BigInteger(column, autoIncrement, true) +} + +// 精确小数 +func (this *Blueprint) Decimal(column string, total, places int) *ColumnDefinition { + if total == 0 { + total = 8 + } + if places == 0 { + places = 2 + } + return this.addColumn("decimal", column, &ColumnOptions{Total: total, Places: places}) +} + +// 无符号精确销售 +func (this *Blueprint) UnsignedDecimal(column string, total, places int) *ColumnDefinition { + if total == 0 { + total = 8 + } + if places == 0 { + places = 2 + } + return this.addColumn("decimal", column, &ColumnOptions{Total: total, Places: places, unsigned: true}) +} + +// 布尔值 +func (this *Blueprint) Boolean(column string) *ColumnDefinition { + return this.addColumn("boolean", column, nil) +} + +// 枚举类型 +func (this *Blueprint) Enum(column string, allowed []string) *ColumnDefinition { + return this.addColumn("enum", column, &ColumnOptions{Allowed: allowed}) +} + +// JSON +func (this *Blueprint) Json(column string) *ColumnDefinition { + return this.addColumn("json", column, nil) +} + +// 日期类型 +func (this *Blueprint) Date(column string) *ColumnDefinition { + return this.addColumn("date", column, nil) +} + +// 日期时间类型 +func (this *Blueprint) DateTime(column string, precision ...int) *ColumnDefinition { + if len(precision) > 0 { + return this.addColumn("datetime", column, &ColumnOptions{Precision: precision[0]}) + } + return this.addColumn("datetime", column, nil) +} + +// 时间类型 +func (this *Blueprint) Time(column string, precision ...int) *ColumnDefinition { + if len(precision) > 0 { + return this.addColumn("time", column, &ColumnOptions{Precision: precision[0]}) + } + return this.addColumn("time", column, nil) +} + +// 时间戳 +func (this *Blueprint) Timestamp(column string, precision ...int) *ColumnDefinition { + if len(precision) > 0 { + return this.addColumn("timestamp", column, &ColumnOptions{Precision: precision[0]}) + } + return this.addColumn("timestamp", column, nil) +} + +// 年 +func (this *Blueprint) Year(column string) *ColumnDefinition { + return this.addColumn("year", column, nil) +} + +// 二进制数据 +func (this *Blueprint) Binary(column string) *ColumnDefinition { + return this.addColumn("binary", column, nil) +} + +// UUID +func (this *Blueprint) Uuid(column string) *ColumnDefinition { + return this.addColumn("uuid", column, nil) +} + +// 自增字段 +func (this *Blueprint) Increments(column string) *ColumnDefinition { + return this.UnsignedInteger(column, true) +} + +// 自增Big字段 +func (this *Blueprint) BigIncrements(column string) *ColumnDefinition { + return this.UnsignedBigInteger(column, true) +} + +// 添加主键 +func (this *Blueprint) Primary(columns ...string) *Command { + return this.addCommand("primary", CommandOptions{Index: this.generateIndexName("pk", columns), Columns: columns}) +} + +// 唯一键 +func (this *Blueprint) Unique(columns ...string) *Command { + return this.addCommand("unique", CommandOptions{Index: this.generateIndexName("unique", columns), Columns: columns}) +} + +// 普通索引 +func (this *Blueprint) Index(columns ...string) *Command { + return this.addCommand("index", CommandOptions{Index: this.generateIndexName("index", columns), Columns: columns}) +} + +// 空间索引 +func (this *Blueprint) SpatialIndex(columns ...string) *Command { + return this.addCommand("spatialIndex", CommandOptions{Index: this.generateIndexName("spatial_index", columns), Columns: columns}) +} + +// 删除列 +func (this *Blueprint) DropColumn(columns ...string) *Command { + return this.addCommand("dropColumn", CommandOptions{Columns: columns}) +} + +// 创建表 +func (this *Blueprint) Create() *Command { + return this.addCommand("create", CommandOptions{}) +} + +// 设置临时表标记 +func (this *Blueprint) Temporary() { + this.temporary = true +} + +// 设置表字符集 +func (this *Blueprint) Charset(charset string) { + this.charset = charset +} + +// 修改表名 +func (this *Blueprint) Rename(to string) *Command { + return this.addCommand("rename", CommandOptions{To: to}) +} + +// 删除表 +func (this *Blueprint) Drop() *Command { + return this.addCommand("drop", CommandOptions{}) +} + +// 删除表, 先判断再删除 +func (this *Blueprint) DropIfExists() *Command { + return this.addCommand("dropIfExists", CommandOptions{}) +} + +// 执行SQL查询 +func (this *Blueprint) Build(conn *db.Connection, grammar IGrammar) error { + _, err := conn.Transaction(func(tx *db.Transaction) (result any, err error) { + for _, sql := range this.ToSql(conn, grammar) { + if _, err := tx.Statement(sql, []any{}); err != nil { + tx.Rollback() + return nil, err + } + } + tx.Commit() + return + }) + return err +} + +func (this *Blueprint) ToSql(conn *db.Connection, grammar IGrammar) (statements []string) { + this.addImpliedCommands(grammar) + + for _, cmd := range this.commands { + switch cmd.Type { + case "create": + statements = append(statements, grammar.CompileCreate(this, cmd, conn)) + case "add": + statements = append(statements, grammar.CompileAdd(this, cmd, conn)...) + case "change": + statements = append(statements, grammar.CompileChange(this, cmd, conn)...) + case "primary": + // statements = append(statements, grammar.CompilePrimary(this, cmd, conn)) + case "unique": + statements = append(statements, grammar.CompileUnique(this, cmd, conn)) + case "index": + statements = append(statements, grammar.CompileIndex(this, cmd, conn)) + case "spatialIndex": + statements = append(statements, grammar.CompileSpatialIndex(this, cmd, conn)) + case "drop": + statements = append(statements, grammar.CompileDrop(this, cmd, conn)) + case "dropIfExists": + statements = append(statements, grammar.CompileDropIfExists(this, cmd, conn)) + case "dropColumn": + statements = append(statements, grammar.CompileDropColumn(this, cmd, conn)...) + case "dropPrimary": + // statements = append(statements, grammar.CompileDropPrimary(this, cmd, conn)) + case "dropUnique": + statements = append(statements, grammar.CompileDropUnique(this, cmd, conn)) + case "dropIndex": + statements = append(statements, grammar.CompileDropIndex(this, cmd, conn)) + case "dropSpatialIndex": + statements = append(statements, grammar.CompileDropSpatialIndex(this, cmd, conn)) + case "rename": + statements = append(statements, grammar.CompileRename(this, cmd, conn)) + case "renameIndex": + statements = append(statements, grammar.CompileRenameIndex(this, cmd, conn)) + case "dropAllTables": + case "dropAllViews": + case "getAllTables": + case "getAllViews": + } + } + + return statements +} + +// 判断是否是创建表 +func (this *Blueprint) creating() bool { + return lo.SomeBy(this.commands, func(item *Command) bool { + if item.Type == "create" { + return true + } + return false + }) +} + +func (this *Blueprint) addColumn(typ, name string, options *ColumnOptions) (definition *ColumnDefinition) { + definition = &ColumnDefinition{Type: typ, Name: name} + + if options != nil { + if options.Length > 0 { + definition.Length = options.Length + } + if options.autoIncrement { + definition.autoIncrement = true + } + if options.unsigned { + definition.unsigned = true + } + if options.Total > 0 { + definition.Total = options.Total + } + if options.Places > 0 { + definition.Places = options.Places + } + if len(options.Allowed) > 0 { + definition.Allowed = options.Allowed + } + if options.Precision > 0 { + definition.Precision = options.Precision + } + if options.change { + definition.change = true + } + } + this.columns = append(this.columns, definition) + return definition +} + +func (this *Blueprint) addImpliedCommands(grammar IGrammar) { + if !this.creating() { + if len(this.getAddedColumns()) > 0 { + this.commands = append([]*Command{this.createCommand("add", CommandOptions{})}, this.commands...) + } + if len(this.getChangedColumns()) > 0 { + this.commands = append([]*Command{this.createCommand("change", CommandOptions{})}, this.commands...) + } + } + + this.addFluentIndexes() +} + +// 添加索引字段 +func (this *Blueprint) addFluentIndexes() { + for _, column := range this.columns { + if column.primary { + this.Primary(column.Name) + continue + } else if column.unique { + this.Unique(column.Name) + continue + } else if column.index { + this.Index(column.Name) + continue + } else if column.spatialIndex { + this.SpatialIndex(column.Name) + continue + } + + } +} + +func (this *Blueprint) addCommand(name string, options CommandOptions) (command *Command) { + command = this.createCommand(name, options) + this.commands = append(this.commands, command) + return command +} + +func (this *Blueprint) createCommand(name string, options CommandOptions) *Command { + return &Command{Type: name, CommandOptions: options} +} + +// 生成索引名称 +func (this *Blueprint) generateIndexName(typ string, columns []string) string { + return strings.ToLower(typ + "_" + strings.Join(columns, "_")) +} + +func (this *Blueprint) getAddedColumns() []*ColumnDefinition { + return lo.Filter(this.columns, func(item *ColumnDefinition, _ int) bool { + return !item.change + }) +} + +func (this *Blueprint) getChangedColumns() []*ColumnDefinition { + return lo.Filter(this.columns, func(item *ColumnDefinition, _ int) bool { + return item.change + }) +} diff --git a/schema/blueprint_test.go b/schema/blueprint_test.go new file mode 100644 index 0000000..a9458fb --- /dev/null +++ b/schema/blueprint_test.go @@ -0,0 +1,216 @@ +package schema + +import ( + "git.fsdpf.net/go/db" + "testing" +) + +// func TestMysqlToSql(t *testing.T) { +// bp := NewBlueprint("users") + +// // bp.BigIncrements("id") // 等同于自增 UNSIGNED BIGINT(主键)列 +// // bp.BigInteger("votes").Default("0") // 等同于 BIGINT 类型列 +// // bp.Binary("data") // 等同于 BLOB 类型列 +// // bp.Boolean("confirmed") // 等同于 BOOLEAN 类型列 +// // bp.Char("name", 4) // 等同于 CHAR 类型列 +// // bp.Date("created_at") // 等同于 DATE 类型列 +// // bp.DateTime("created_at") // 等同于 DATETIME 类型列 +// // bp.Decimal("amount", 5, 2) // 等同于 DECIMAL 类型列,带精度和范围 +// // bp.Enum("level", []string{"easy", "hard"}) // 等同于 ENUM 类型列 +// bp.Increments("id").Change("pid") // 等同于自增 UNSIGNED INTEGER (主键)类型列 +// // bp.Integer("votes").AutoIncrement() // 等同于 INTEGER 类型列 +// bp.Json("options").Charset("uft8").Change() // 等同于 JSON 类型列 +// // bp.SmallInteger("votes").Comment("等同于 SMALLINT 类型列") // 等同于 SMALLINT 类型列 +// // bp.String("name", 100).Nullable() // 等同于 VARCHAR 类型列,带一个可选长度参数 +// // bp.Text("description").After("column") // 等同于 TEXT 类型列 +// // bp.Time("sunrise").First() // 等同于 TIME 类型列 +// bp.Timestamp("added_on").UseCurrent().Default(db.Raw("ON UPDATE CURRENT_TIMESTAMP")) // 等同于 TIMESTAMP 类型列 +// // bp.TinyInteger("numbers") // 等同于 TINYINT 类型列 +// // bp.UnsignedBigInteger("votes") // 等同于无符号的 BIGINT 类型列 +// // bp.UnsignedDecimal("amount", 8, 2) // 等同于 UNSIGNED DECIMAL 类型列,带有总位数和精度 +// // bp.UnsignedInteger("votes") // 等同于无符号的 INTEGER 类型列 +// // bp.UnsignedSmallInteger("votes") // 等同于无符号的 SMALLINT 类型列 +// // bp.UnsignedTinyInteger("votes") // 等同于无符号的 TINYINT 类型列 +// // bp.Uuid("id").Default("00000000-0000-0000-0000-000000000000") // 等同于 UUID 类型列 +// // bp.Year("birth_year") +// bp.DropColumn("name", "name_a") // 等同于 YEAR 类型列 + +// t.Log(bp.ToSql(mysql.Connection, &MysqlGrammar{})) +// } + +// func TestSqliteToSql(t *testing.T) { +// bp := NewBlueprint("users") + +// bp.charset = "utf8mb4" +// bp.collation = "utf8mb4_general_ci" +// // bp.BigIncrements("id").Comment("ID") +// // bp.Boolean("enabled").Default("1").Comment("是否有效") +// bp.Char("created_user", 36).Default("22000000-0000-0000-0000-000000000001").Comment("创建者").Change() +// bp.Char("owned_user", 36).Default("11000000-0000-0000-0000-000000000000").Comment("拥有者").Change() +// // bp.Timestamp("created_at").UseCurrent().Comment("创建时间") +// // bp.Timestamp("updated_at").Default(db.Raw("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP")).Comment("更新时间") +// bp.DateTime("deleted_at").Nullable().Comment("删除时间") +// bp.String("name_a", 50).Default("").Comment("姓名").Change("c_name") + +// bp.DropColumn("name", "name_a") + +// t.Log(bp.ToSql(sqlite3.Connection, &Sqlite3Grammar{})) +// } + +// Mysql 创建表 +func TestMysqlCreateTableToSql(t *testing.T) { + table := NewBlueprint("users") + table.Create() + table.Charset("utf8mb4") + + table.BigIncrements("id").Comment("ID") + table.Boolean("enabled").Default("1").Comment("是否有效") + table.Char("created_user", 36).Default("00000000-0000-0000-0000-000000000000").Comment("创建者") + table.Char("owned_user", 36).Default("00000000-0000-0000-0000-000000000000").Comment("拥有者") + table.Timestamp("created_at").UseCurrent().Comment("创建时间") + table.Timestamp("updated_at").UseCurrent().Default(db.Raw("ON UPDATE CURRENT_TIMESTAMP")).Comment("更新时间") + table.DateTime("deleted_at").Nullable().Comment("删除时间") + + for _, sql := range table.ToSql(mysql.Connection, &MysqlGrammar{}) { + t.Logf("-- %s -- %s \n", "mysql", sql) + } +} + +// SQLite 创建表 +func TestSqliteCreateTableToSql(t *testing.T) { + table := NewBlueprint("users") + table.Create() + table.Charset("utf8mb4") + + table.BigIncrements("id").Comment("ID") + table.Boolean("enabled").Default("1").Comment("是否有效") + table.Char("created_user", 36).Default("00000000-0000-0000-0000-000000000000").Comment("创建者") + table.Char("owned_user", 36).Default("00000000-0000-0000-0000-000000000000").Comment("拥有者") + table.Timestamp("created_at").UseCurrent().Comment("创建时间") + table.Timestamp("updated_at").UseCurrent().Default(db.Raw("ON UPDATE CURRENT_TIMESTAMP")).Comment("更新时间") + table.DateTime("deleted_at").Nullable().Comment("删除时间") + + for _, sql := range table.ToSql(sqlite3.Connection, &Sqlite3Grammar{}) { + t.Logf("-- %s -- %s \n", "sqlite", sql) + } +} + +//Mysql 修改表, 添加字段 +func TestMysqlAddColumnToSql(t *testing.T) { + table := NewBlueprint(test_table) + + table.String("name", 50).Nullable().Comment("用户名") + table.TinyInteger("age").Nullable().Comment("年龄") + + for _, sql := range table.ToSql(mysql.Connection, &MysqlGrammar{}) { + t.Logf("-- %s -- %s \n", "mysql", sql) + } +} + +// Sqlite 修改表, 添加字段 +func TestSqliteAddColumnToSql(t *testing.T) { + table := NewBlueprint(test_table) + + table.String("name", 50).Nullable().Comment("用户名") + table.TinyInteger("age").Nullable().Comment("年龄") + + for _, sql := range table.ToSql(sqlite3.Connection, &Sqlite3Grammar{}) { + t.Logf("-- %s -- %s \n", "sqlite", sql) + } +} + +// Mysql 修改表, 添加/编辑字段 +func TestMysqlChangeColumnToSql(t *testing.T) { + table := NewBlueprint(test_table) + + table.String("name", 100).Nullable().Comment("用户名").Change() + table.SmallInteger("age").Nullable().Comment("年龄").Change() + table.String("nickname", 100).Nullable().Comment("昵称") + + for _, sql := range table.ToSql(mysql.Connection, &MysqlGrammar{}) { + t.Logf("-- %s -- %s \n", "mysql", sql) + } +} + +// Sqlite 修改表, 添加/编辑字段 +func TestSqliteChangeColumnToSql(t *testing.T) { + table := NewBlueprint(test_table) + + table.String("name", 100).Nullable().Comment("用户名").Change() + table.SmallInteger("age").Nullable().Comment("年龄").Change() + table.String("nickname", 100).Nullable().Comment("昵称") + + for _, sql := range table.ToSql(sqlite3.Connection, &Sqlite3Grammar{}) { + t.Logf("-- %s -- %s \n", "sqlite", sql) + } +} + +// Mysql 修改表, 添加/编辑/删除字段 +func TestMysqlDropColumnToSql(t *testing.T) { + table := NewBlueprint(test_table) + + table.String("username", 100).Default("").Comment("--用户名--").Change("name") + table.Boolean("is_vip").Default(1).Comment("是否VIP") + table.DropColumn("age") + + for _, sql := range table.ToSql(mysql.Connection, &MysqlGrammar{}) { + t.Logf("-- %s -- %s \n", "mysql", sql) + } +} + +// Sqlite 修改表, 添加/编辑/删除字段 +func TestSqliteDropColumnToSql(t *testing.T) { + table := NewBlueprint(test_table) + + table.String("username", 100).Default("").Comment("--用户名--").Change("name") + table.Boolean("is_vip").Default(1).Comment("是否VIP") + table.DropColumn("age") + + for _, sql := range table.ToSql(sqlite3.Connection, &Sqlite3Grammar{}) { + t.Logf("-- %s -- %s \n", "sqlite", sql) + } +} + +// Mysql 重命名表 +func TestMysqlRenameToSql(t *testing.T) { + table := NewBlueprint(test_table) + + table.Rename(test_table + "_alias") + + for _, sql := range table.ToSql(mysql.Connection, &MysqlGrammar{}) { + t.Logf("-- %s -- %s \n", "mysql", sql) + } +} + +// Sqlite 重命名表 +func TestSqliteRenameToSql(t *testing.T) { + table := NewBlueprint(test_table) + + table.Rename(test_table + "_alias") + + for _, sql := range table.ToSql(sqlite3.Connection, &Sqlite3Grammar{}) { + t.Logf("-- %s -- %s \n", "sqlite", sql) + } +} + +// Mysql 删除表 +func TestMysqlDropTableToSql(t *testing.T) { + table := NewBlueprint(test_table + "_alias") + + table.Drop() + + for _, sql := range table.ToSql(mysql.Connection, &MysqlGrammar{}) { + t.Logf("-- %s -- %s \n", "mysql", sql) + } +} + +// Sqlite 删除表 +func TestSqliteDropTableToSql(t *testing.T) { + table := NewBlueprint(test_table + "_alias") + + table.Drop() + + for _, sql := range table.ToSql(sqlite3.Connection, &Sqlite3Grammar{}) { + t.Logf("-- %s -- %s \n", "sqlite", sql) + } +} diff --git a/schema/builder.go b/schema/builder.go new file mode 100644 index 0000000..6f21f6b --- /dev/null +++ b/schema/builder.go @@ -0,0 +1,128 @@ +package schema + +import ( + "github.com/samber/lo" + + "git.fsdpf.net/go/db" +) + +type Builder struct { + Connection *db.Connection // The database connection instance + Grammar IGrammar // The schema grammar instance +} + +func NewSchema(conn *db.Connection) *Builder { + var grammar IGrammar + + switch conn.GetConfig().Driver { + case "mysql": + grammar = &MysqlGrammar{} + case "sqlite3": + grammar = &Sqlite3Grammar{} + } + + return &Builder{ + Connection: conn, + Grammar: grammar, + } +} + +// 判断数据表是否存在 +func (this *Builder) HasTable(table string) (bool, error) { + query := this.Grammar.CompileTableExists() + dbName := this.Connection.GetConfig().Database + + count := 0 + + if _, err := this.Connection.Select(query, []any{table, dbName}, &count); err != nil || count == 0 { + return false, err + } + + return true, nil +} + +// 判断表字段是个存在 +func (this *Builder) HasColumns(table string, columns ...string) (bool, error) { + if tColumns, err := this.GetColumnListing(table); err != nil { + return false, err + } else { + for _, col := range columns { + if !lo.Contains(tColumns, col) { + return false, nil + } + } + } + return true, nil +} + +// 判断数据库列是否存在 +func (this *Builder) GetColumnListing(table string) (columns []string, err error) { + query := this.Grammar.CompileColumnListing(table) + dbName := this.Connection.GetConfig().Database + + items := []struct { + ID int `db:"id"` + Name string `db:"name"` + }{} + + bindings := []any{table, dbName} + + if this.Connection.GetConfig().Driver == "sqlite3" { + bindings = nil + } + + if _, err := this.Connection.Select(query, bindings, &items); err != nil { + return nil, err + } + + for _, item := range items { + columns = append(columns, item.Name) + } + + return +} + +// 修改表 +func (this *Builder) Table(table string, cb func(*Blueprint)) error { + bp := this.createBlueprint(table) + cb(bp) + return this.Build(bp) +} + +// 创建表 +func (this *Builder) Create(table string, cb func(*Blueprint)) error { + bp := this.createBlueprint(table) + bp.Create() + cb(bp) + return this.Build(bp) +} + +// 修改表名 +func (this *Builder) Rename(from, to string) error { + bp := this.createBlueprint(from) + bp.Rename(to) + return this.Build(bp) +} + +// 删除表 +func (this *Builder) Drop(table string) error { + bp := this.createBlueprint(table) + bp.Drop() + return this.Build(bp) +} + +// 删除表, 先判断再删除 +func (this *Builder) DropIfExists(table string) error { + bp := this.createBlueprint(table) + bp.DropIfExists() + return this.Build(bp) +} + +func (b *Builder) createBlueprint(table string) *Blueprint { + return NewBlueprint(table) +} + +// Build execute the blueprint to build / modify the table +func (this *Builder) Build(bp *Blueprint) error { + return bp.Build(this.Connection, this.Grammar) +} diff --git a/schema/builder_test.go b/schema/builder_test.go new file mode 100644 index 0000000..dc2a3df --- /dev/null +++ b/schema/builder_test.go @@ -0,0 +1,195 @@ +package schema + +import ( + "git.fsdpf.net/go/db" + "testing" +) + +const test_table = "test_users" + +var mysql *Builder +var sqlite3 *Builder + +func init() { + odbc := db.Open(map[string]db.DBConfig{ + "default": { + Driver: "mysql", + Host: "localhost", + Port: "3366", + Database: "demo", + Username: "demo", + Password: "ded86bf25d661bb723f3898b2440dd678382e2dd", + Charset: "utf8mb4", + MultiStatements: true, + // ParseTime: true, + }, + "sqlite3": { + Driver: "sqlite3", + File: "test.db3", + }, + }) + + mysql = NewSchema(odbc.Connection("default")) + sqlite3 = NewSchema(odbc.Connection("sqlite3")) +} + +func TestHasTable(t *testing.T) { + t.Log(mysql.HasTable("users")) + t.Log(sqlite3.HasTable("users")) +} + +func TestHasColumns(t *testing.T) { + t.Log("mysql") + t.Log(mysql.HasColumns("users", "id")) + t.Log("sqlite3") + t.Log(sqlite3.HasColumns("users", "id")) +} + +func TestGetColumnListing(t *testing.T) { + t.Log("mysql") + t.Log(mysql.GetColumnListing("users")) + t.Log("sqlite3") + t.Log(sqlite3.GetColumnListing("users")) +} + +// Mysql 创建表 +func TestMysqlCreateTable(t *testing.T) { + if err := mysql.Create(test_table, func(table *Blueprint) { + table.Charset("utf8mb4") + + table.BigIncrements("id").Comment("ID") + table.Boolean("enabled").Default("1").Comment("是否有效") + table.Char("created_user", 36).Default("00000000-0000-0000-0000-000000000000").Comment("创建者") + table.Char("owned_user", 36).Default("00000000-0000-0000-0000-000000000000").Comment("拥有者") + table.Timestamp("created_at").UseCurrent().Comment("创建时间") + table.Timestamp("updated_at").UseCurrent().Default(db.Raw("ON UPDATE CURRENT_TIMESTAMP")).Comment("更新时间") + table.DateTime("deleted_at").Nullable().Comment("删除时间") + }); err != nil { + t.Error(err) + } +} + +// Sqlite 创建表 +func TestSqliteCreateTable(t *testing.T) { + if err := sqlite3.Create(test_table, func(table *Blueprint) { + table.Charset("utf8mb4") + + table.Increments("id").Comment("ID") + table.Boolean("enabled").Default("1").Comment("是否有效") + table.Char("created_user", 36).Default("00000000-0000-0000-0000-000000000000").Comment("创建者") + table.Char("owned_user", 36).Default("00000000-0000-0000-0000-000000000000").Comment("拥有者") + table.Timestamp("created_at").UseCurrent().Comment("创建时间") + table.Timestamp("updated_at").UseCurrent().Default(db.Raw("ON UPDATE CURRENT_TIMESTAMP")).Comment("更新时间") + table.DateTime("deleted_at").Nullable().Comment("删除时间") + }); err != nil { + t.Error(err) + } +} + +// Mysql 修改表, 添加字段 +func TestMysqlAddColumn(t *testing.T) { + if err := mysql.Table(test_table, func(table *Blueprint) { + table.String("name", 50).Nullable().Comment("用户名") + table.TinyInteger("age").Nullable().Comment("年龄") + }); err != nil { + t.Error(err) + } +} + +// Sqlite 修改表, 添加字段 +func TestSqliteAddColumn(t *testing.T) { + if err := sqlite3.Table(test_table, func(table *Blueprint) { + table.String("name", 50).Nullable().Comment("用户名") + table.TinyInteger("age").Nullable().Comment("年龄") + }); err != nil { + t.Error(err) + } +} + +// Mysql 修改表, 添加/编辑字段 +func TestMysqlChangeColumn(t *testing.T) { + if err := mysql.Table(test_table, func(table *Blueprint) { + table.String("name", 100).Nullable().Comment("用户名").Change("username") + table.SmallInteger("age").Nullable().Comment("年龄").Change() + + table.String("nickname", 100).Nullable().Comment("昵称") + }); err != nil { + t.Error(err) + } +} + +// Sqlite 修改表, 添加/编辑字段 +func TestSqliteChangeColumn(t *testing.T) { + if err := sqlite3.Table(test_table, func(table *Blueprint) { + table.String("name", 100).Nullable().Comment("用户名").Change("username") + table.SmallInteger("age").Nullable().Comment("年龄").Change() + + table.String("nickname", 100).Nullable().Comment("昵称") + }); err != nil { + t.Error(err) + } +} + +// Mysql 修改表, 添加/编辑/删除字段 +func TestMysqlDropColumn(t *testing.T) { + if err := mysql.Table(test_table, func(table *Blueprint) { + table.String("username", 100).Default("").Comment("--用户名--").Change("name") + table.Boolean("is_vip").Default(1).Comment("是否VIP") + table.DropColumn("age") + }); err != nil { + t.Error(err) + } +} + +// Sqlite 修改表, 添加/编辑/删除字段 +func TestSqliteDropColumn(t *testing.T) { + if err := sqlite3.Table(test_table, func(table *Blueprint) { + table.String("username", 100).Default("").Comment("--用户名--").Change("name") + table.Boolean("is_vip").Default(1).Comment("是否VIP") + table.DropColumn("age") + }); err != nil { + t.Error(err) + } +} + +// Mysql重命名表 +func TestMysqlRename(t *testing.T) { + if err := mysql.Rename(test_table, test_table+"_alias"); err != nil { + t.Error(err) + } +} + +// SQLite 重命名表 +func TestSqliteRename(t *testing.T) { + if err := sqlite3.Rename(test_table, test_table+"_alias"); err != nil { + t.Error(err) + } +} + +// Mysql 删除表 +func TestMysqlDropTable(t *testing.T) { + if err := mysql.Drop(test_table + "_alias"); err != nil { + t.Error(err) + } +} + +// Sqlite 删除表 +func TestSqliteDropTable(t *testing.T) { + if err := sqlite3.Drop(test_table + "_alias"); err != nil { + t.Error(err) + } +} + +// Mysql 删除表 +func TestMysqlDropTableIfExists(t *testing.T) { + if err := mysql.DropIfExists(test_table + "_alias"); err != nil { + t.Error(err) + } +} + +// Sqlite 删除表 +func TestSqliteDropTableIfExists(t *testing.T) { + if err := sqlite3.DropIfExists(test_table + "_alias"); err != nil { + t.Error(err) + } +} diff --git a/schema/column_definition.go b/schema/column_definition.go new file mode 100644 index 0000000..e80a3e9 --- /dev/null +++ b/schema/column_definition.go @@ -0,0 +1,115 @@ +package schema + +type ColumnDefinition struct { + Type string // kind of column, string / int + Name string // name of column + ColumnOptions +} + +type ColumnOptions struct { + Length int // 字段长度 + Allowed []string // enum 选项 + Precision int // 日期时间精度 + Total int // 小数位数 + Places int // 小数精度 + rename string // 重命名 + useCurrent bool // CURRENT TIMESTAMP + after string // Place the column "after" another column (MySQL) + always bool // Used as a modifier for generatedAs() (PostgreSQL) + autoIncrement bool // Set INTEGER columns as auto-increment (primary key) + change bool // Change the column + charset string // Specify a character set for the column (MySQL) + collation string // Specify a collation for the column (MySQL/PostgreSQL/SQL Server) + comment string // Add a comment to the column (MySQL) + def any // Specify a "default" value for the column + first bool // Place the column "first" in the table (MySQL) + nullable bool // Allow NULL values to be inserted into the column + storedAs string // Create a stored generated column (MySQL) + unsigned bool // Set the INTEGER column as UNSIGNED (MySQL) + virtualAs string // Create a virtual generated column (MySQL) + unique bool // Add a unique index + primary bool // Add a primary index + index bool // Add an index + spatialIndex bool // Add a spatial index +} + +// VirtualAs Create a virtual generated column (MySQL) +func (c *ColumnDefinition) VirtualAs(as string) *ColumnDefinition { + c.virtualAs = as + return c +} + +// StoredAs Create a stored generated column (MySQL) +func (c *ColumnDefinition) StoredAs(as string) *ColumnDefinition { + c.storedAs = as + return c +} + +// Unsigned Set INTEGER columns as UNSIGNED (MySQL) +func (c *ColumnDefinition) Unsigned() *ColumnDefinition { + c.unsigned = true + return c +} + +// First Place the column "first" in the table (MySQL) +func (c *ColumnDefinition) First() *ColumnDefinition { + c.first = true + return c +} + +// Default Specify a "default" value for the column +func (c *ColumnDefinition) Default(def any) *ColumnDefinition { + c.def = def + return c +} + +// Comment Add a comment to a column (MySQL/PostgreSQL) +func (c *ColumnDefinition) Comment(comm string) *ColumnDefinition { + c.comment = comm + return c +} + +// Collaction Specify a collation for the column (MySQL/PostgreSQL/SQL Server) +func (c *ColumnDefinition) Collaction(coll string) *ColumnDefinition { + c.collation = coll + return c +} + +// Charset Specify a character set for the column (MySQL) +func (c *ColumnDefinition) Charset(chars string) *ColumnDefinition { + c.charset = chars + return c +} + +// AutoIncrement set INTEGER columns as auto-increment (primary key) +func (c *ColumnDefinition) AutoIncrement() *ColumnDefinition { + c.autoIncrement = true + return c +} + +// After place the column "after" another column (MySQL) +func (c *ColumnDefinition) After(column string) *ColumnDefinition { + c.after = column + return c +} + +// Nullable makes a column nullable +func (c *ColumnDefinition) Nullable() *ColumnDefinition { + c.nullable = true + return c +} + +// 修改字段, 默认新增 +func (c *ColumnDefinition) Change(param ...string) *ColumnDefinition { + if len(param) > 0 && param[0] != "" { + c.rename = param[0] + } + c.change = true + return c +} + +// 时间戳 +func (c *ColumnDefinition) UseCurrent() *ColumnDefinition { + c.useCurrent = true + return c +} diff --git a/schema/grammar.go b/schema/grammar.go new file mode 100644 index 0000000..1d28862 --- /dev/null +++ b/schema/grammar.go @@ -0,0 +1,44 @@ +package schema + +import ( + "git.fsdpf.net/go/db" +) + +type IGrammar interface { + // 判断表是否存在 + CompileTableExists() string + // 字段列 + CompileColumnListing(table string) string + // 创建表 + CompileCreate(bp *Blueprint, command *Command, conn *db.Connection) string + // 添加字段 + CompileAdd(bp *Blueprint, command *Command, conn *db.Connection) []string + // 修改字段 + CompileChange(bp *Blueprint, command *Command, conn *db.Connection) []string + // 添加主键 + // CompilePrimary(bp *Blueprint, command *Command, conn *db.Connection) string + // 添加唯一键 + CompileUnique(bp *Blueprint, command *Command, conn *db.Connection) string + // 添加普通索引 + CompileIndex(bp *Blueprint, command *Command, conn *db.Connection) string + // 添加空间索引 + CompileSpatialIndex(bp *Blueprint, command *Command, conn *db.Connection) string + // 删除表 + CompileDrop(bp *Blueprint, command *Command, conn *db.Connection) string + // 删除表, 先判断再删除 + CompileDropIfExists(bp *Blueprint, command *Command, conn *db.Connection) string + // 删除表, 删除列 + CompileDropColumn(bp *Blueprint, command *Command, conn *db.Connection) []string + // 删除主键 + // CompileDropPrimary(bp *Blueprint, command *Command, conn *db.Connection) string + // 删除唯一键 + CompileDropUnique(bp *Blueprint, command *Command, conn *db.Connection) string + // 删除唯普通索引 + CompileDropIndex(bp *Blueprint, command *Command, conn *db.Connection) string + // 删除唯空间索引 + CompileDropSpatialIndex(bp *Blueprint, command *Command, conn *db.Connection) string + // 表重命名 + CompileRename(bp *Blueprint, command *Command, conn *db.Connection) string + // 索引重命名 + CompileRenameIndex(bp *Blueprint, command *Command, conn *db.Connection) string +} diff --git a/schema/mysql_grammar.go b/schema/mysql_grammar.go new file mode 100644 index 0000000..bfa3091 --- /dev/null +++ b/schema/mysql_grammar.go @@ -0,0 +1,338 @@ +package schema + +import ( + "fmt" + "strings" + + "github.com/samber/lo" + + "git.fsdpf.net/go/db" +) + +type MysqlGrammar struct { +} + +var mysqlDefaultModifiers = []string{ + "Unsigned", "Charset", "Collate", "VirtualAs", "StoredAs", + "Nullable", "Default", "Increment", "Comment", "After", "First", +} +var mysqlSerials = []string{ + "bigInteger", "integer", "mediumInteger", "smallInteger", "tinyInteger", +} + +func (this MysqlGrammar) CompileTableExists() string { + return "select count(*) from information_schema.tables where table_name = ? and table_schema = ? and table_type = 'BASE TABLE'" +} + +func (this MysqlGrammar) CompileColumnListing(table string) string { + return "select ordinal_position as `id`, column_name as `name` from information_schema.columns where table_name = ? and table_schema = ? order by ordinal_position" +} + +// 创建表 +func (this MysqlGrammar) CompileCreate(bp *Blueprint, command *Command, conn *db.Connection) string { + sql := this.compileCreateTable(bp, command, conn) + sql = this.compileCreateEncoding(sql, bp, conn) + sql = this.compileCreateEngine(sql, bp, conn) + return sql +} + +// 创建表结构 +func (this MysqlGrammar) compileCreateTable(bp *Blueprint, command *Command, conn *db.Connection) string { + temporary := "create" + if bp.temporary { + temporary = "create temporary" + } + return fmt.Sprintf("%s table %s (%s)", temporary, bp.table, strings.Join(this.getAddedColumns(bp), ", ")) +} + +// 创建表 字符集 +func (this MysqlGrammar) compileCreateEncoding(sql string, bp *Blueprint, conn *db.Connection) string { + if bp.charset != "" { + sql = sql + " default charset=" + bp.charset + } else if conn.GetConfig().Charset != "" { + sql = sql + " default charset=" + conn.GetConfig().Charset + } else { + sql = sql + " default charset=utf8mb4" + } + + if bp.collation != "" { + sql = sql + " collate=" + bp.collation + } else { + sql = sql + " collate=utf8mb4_general_ci" + } + + return sql +} + +// 创建表存储方式 +func (this MysqlGrammar) compileCreateEngine(sql string, bp *Blueprint, conn *db.Connection) string { + if bp.engine != "" { + sql = sql + " engine = " + bp.engine + } + return sql +} + +// 添加字段 +func (this MysqlGrammar) CompileAdd(bp *Blueprint, command *Command, conn *db.Connection) []string { + columns := PrefixArray("add", this.getAddedColumns(bp)) + return []string{"alter table `" + bp.table + "` " + strings.Join(columns, ", ")} +} + +// 修改字段 +func (this MysqlGrammar) CompileChange(bp *Blueprint, command *Command, conn *db.Connection) []string { + columns := PrefixArray("change", this.getChangedColumns(bp)) + return []string{"alter table `" + bp.table + "` " + strings.Join(columns, ", ")} +} + +// 添加主键 +func (this MysqlGrammar) CompilePrimary(bp *Blueprint, command *Command, conn *db.Connection) string { + return this.compileKey(bp, command, "primary key") +} + +// 添加唯一键 +func (this MysqlGrammar) CompileUnique(bp *Blueprint, command *Command, conn *db.Connection) string { + return this.compileKey(bp, command, "unique") +} + +// 添加普通索引 +func (this MysqlGrammar) CompileIndex(bp *Blueprint, command *Command, conn *db.Connection) string { + return this.compileKey(bp, command, "index") +} + +// 添加空间索引 +func (this MysqlGrammar) CompileSpatialIndex(bp *Blueprint, command *Command, conn *db.Connection) string { + return this.compileKey(bp, command, "spatial index") +} + +// 添加索引 +func (this MysqlGrammar) compileKey(bp *Blueprint, command *Command, typ string) string { + algorithm := "" + if command.Algorithm != "" { + algorithm = " using " + command.Algorithm + } + columns := lo.Map(command.Columns, func(column string, _ int) string { + return "`" + column + "`" + }) + return fmt.Sprintf("alter table %s add %s %s%s(%s)", bp.table, typ, command.Index, algorithm, strings.Join(columns, ",")) +} + +// 删除表 +func (this MysqlGrammar) CompileDrop(bp *Blueprint, command *Command, conn *db.Connection) string { + return "drop table `" + bp.table + "`" +} + +// 删除表, 先判断再删除 +func (this MysqlGrammar) CompileDropIfExists(bp *Blueprint, command *Command, conn *db.Connection) string { + return "drop table if exists `" + bp.table + "`" +} + +// 删除列 +func (this MysqlGrammar) CompileDropColumn(bp *Blueprint, command *Command, conn *db.Connection) []string { + columns := lo.Map(command.Columns, func(column string, _ int) string { + return "drop `" + column + "`" + }) + return []string{"alter table `" + bp.table + "` " + strings.Join(columns, ", ")} +} + +// 删除主键 +func (this MysqlGrammar) CompileDropPrimary(bp *Blueprint, command *Command, conn *db.Connection) string { + return "alter table `" + bp.table + "` drop primary key" +} + +// 删除索引 +func (this MysqlGrammar) CompileDropIndex(bp *Blueprint, command *Command, conn *db.Connection) string { + return "alter table `" + bp.table + "` drop index `" + command.Index + "`" +} + +// 删除唯一键 +func (this MysqlGrammar) CompileDropUnique(bp *Blueprint, command *Command, conn *db.Connection) string { + return this.CompileDropIndex(bp, command, conn) +} + +// 删除空间索引 +func (this MysqlGrammar) CompileDropSpatialIndex(bp *Blueprint, command *Command, conn *db.Connection) string { + return this.CompileDropIndex(bp, command, conn) +} + +// 表重命名 +func (this MysqlGrammar) CompileRename(bp *Blueprint, command *Command, conn *db.Connection) string { + return "rename table `" + bp.table + "` to `" + command.To + "`" +} + +// 索引重命名 +func (this MysqlGrammar) CompileRenameIndex(bp *Blueprint, command *Command, conn *db.Connection) string { + return fmt.Sprintf("alter table `%s` rename index `%s` to `%s`", bp.table, command.From, command.To) +} + +// 删除所有表 +func (this MysqlGrammar) CompileDropAllTables(tables ...string) string { + tables = lo.Map(tables, func(table string, _ int) string { + return "`" + table + "`" + }) + return "drop table " + strings.Join(tables, ", ") +} + +// 删除所有视图 +func (this MysqlGrammar) CompileDropAllViews(views ...string) string { + views = lo.Map(views, func(view string, _ int) string { + return "`" + view + "`" + }) + return "drop view " + strings.Join(views, ", ") +} + +// 查询所有表 +// func (this MysqlGrammar) CompileGetAllTables(tables ...string) string { + +// } + +// 查询所有视图 +// func (this MysqlGrammar) CompileGetAllViews(views ...string) string {} + +func (this MysqlGrammar) addModifiers(sql string, bp *Blueprint, column *ColumnDefinition) string { + for _, modifier := range mysqlDefaultModifiers { + sql = sql + this.GetColumnModifier(modifier, bp, column) + } + return sql +} + +func (this MysqlGrammar) GetColumnType(column *ColumnDefinition) string { + switch column.Type { + case "char": + return fmt.Sprintf("char(%d)", column.Length) + case "string": + return fmt.Sprintf("varchar(%d)", column.Length) + case "text": + return "text" + case "integer": + return "int" + case "bigInteger": + return "bigint" + case "tinyInteger": + return "tinyint" + case "smallInteger": + return "smallint" + case "decimal": + return fmt.Sprintf("decimal(%d,%d)", column.Total, column.Places) + case "boolean": + return "tinyint(1)" + case "enum": + return fmt.Sprintf("enum(%s)", QuoteString(column.Allowed)) + case "json": + return "json" + case "date": + return "date" + case "binary": + return "binary" + case "datetime": + columnType := "datetime" + if column.Precision > 0 { + columnType = fmt.Sprintf("datetime(%d)", column.Precision) + } + return columnType + case "time": + columnType := "time" + if column.Precision > 0 { + columnType = fmt.Sprintf("time(%d)", column.Precision) + } + return columnType + case "timestamp": + columnType := "timestamp" + if column.Precision > 0 { + columnType = fmt.Sprintf("timestamp(%d)", column.Precision) + } + return columnType + case "year": + return "year" + case "uuid": + return "char(36)" + } + panic("不支持的数据类型: " + column.Type) +} + +func (this MysqlGrammar) GetColumnModifier(modifier string, bp *Blueprint, column *ColumnDefinition) string { + switch modifier { + case "Unsigned": + if column.unsigned { + return " unsigned" + } + case "Charset": + if column.charset != "" { + return fmt.Sprintf(" character set %s", column.charset) + } + case "Collate": + if column.collation != "" { + return fmt.Sprintf(" collate '%s'", column.collation) + } + case "VirtualAs": + if column.virtualAs != "" { + return fmt.Sprintf(" as (%s)", column.virtualAs) + } + case "StoredAs": + if column.storedAs != "" { + return fmt.Sprintf(" as (%s) stored", column.storedAs) + } + case "Nullable": + if column.virtualAs == "" && column.storedAs == "" { + if column.nullable { + return " null" + } + return " not null" + } + case "Default": + switch v := column.def.(type) { + case db.Expression: + if column.useCurrent { + return fmt.Sprintf(" default %s %s", "CURRENT_TIMESTAMP", v) + } + return fmt.Sprintf(" default %s", v) + case nil: + if column.useCurrent { + return fmt.Sprintf(" default %s", "CURRENT_TIMESTAMP") + } + default: + return fmt.Sprintf(" default '%v'", v) + } + + case "Increment": + if column.autoIncrement && lo.Contains(mysqlSerials, column.Type) { + return " auto_increment primary key" + } + case "Comment": + if column.comment != "" { + return fmt.Sprintf(" comment '%s'", db.Addslashes(column.comment)) + } + case "After": + if column.after != "" { + return fmt.Sprintf(" after `%s`", column.after) + } + case "First": + if column.first { + return " first" + } + } + + return "" +} + +// 获取新增表结构字段 +func (this MysqlGrammar) getAddedColumns(bp *Blueprint) (columns []string) { + for _, column := range bp.getAddedColumns() { + sql := "`" + column.Name + "` " + this.GetColumnType(column) + columns = append(columns, this.addModifiers(sql, bp, column)) + } + return +} + +// 获取修改表结构字段 +func (this MysqlGrammar) getChangedColumns(bp *Blueprint) (columns []string) { + for _, column := range bp.getChangedColumns() { + name := column.Name + rename := column.Name + if column.rename != "" { + rename = column.rename + } + sql := fmt.Sprintf("`%s` `%s` ", name, rename) + this.GetColumnType(column) + columns = append(columns, this.addModifiers(sql, bp, column)) + } + return +} diff --git a/schema/sqlite3_grammar.go b/schema/sqlite3_grammar.go new file mode 100644 index 0000000..6720031 --- /dev/null +++ b/schema/sqlite3_grammar.go @@ -0,0 +1,351 @@ +package schema + +import ( + "fmt" + "regexp" + "strings" + + "github.com/samber/lo" + + "git.fsdpf.net/go/db" +) + +type Sqlite3Grammar struct { +} + +var sqlite3DefaultModifiers = []string{"VirtualAs", "StoredAs", "Nullable", "Default", "Increment"} + +var sqlite3Serials = []string{"bigInteger", "integer", "mediumInteger", "smallInteger", "tinyInteger"} + +func (this Sqlite3Grammar) CompileTableExists() string { + // 兼容其他数据库查询 + // 其他数据库查看都是同时传入, 数据库名 和 数据表名 + return "select count(*) from sqlite_master where type = 'table' and name = ?||?" +} + +func (this Sqlite3Grammar) CompileColumnListing(table string) string { + // 兼容其他数据库查询 + // 其他数据库查看都是同时传入, 数据库名 和 数据表名 + return "pragma table_info(" + table + ")" +} + +// 创建表 +func (this Sqlite3Grammar) CompileCreate(bp *Blueprint, command *Command, conn *db.Connection) string { + temporary := "create" + if bp.temporary { + temporary = "create temporary" + } + return fmt.Sprintf("%s table %s (%s)", temporary, bp.table, strings.Join(this.getAddedColumns(bp), ", ")) +} + +// 添加字段 +func (this Sqlite3Grammar) CompileAdd(bp *Blueprint, command *Command, conn *db.Connection) []string { + columns := PrefixArray("add column", this.getAddedColumns(bp)) + + return lo.Map(columns, func(column string, _ int) string { + return "alter table `" + bp.table + "` " + column + }) +} + +// 修改字段 +func (this Sqlite3Grammar) CompileChange(bp *Blueprint, command *Command, conn *db.Connection) (sqls []string) { + // 创建表的sql + cTableSql := this.getCreateTable(bp, conn) + + // 旧表字段 + oldColumns := []string{} + if _, err := conn.Select("select '`'||name||'`' from pragma_table_info(?)", []any{bp.table}, &oldColumns); err != nil { + panic(fmt.Errorf("获取 SQLite 建表字段失败, %s", err)) + } + + // 数据同步 + insert_sql := fmt.Sprintf("insert into `"+bp.table+"` (%s)", strings.Join(oldColumns, ", ")) + select_sql := fmt.Sprintf("select %s from `"+bp.table+"_old`", strings.Join(oldColumns, ", ")) + + for _, column := range bp.getChangedColumns() { + name := column.Name + rename := column.rename + + // 旧字段(查询) -> 新字段(插入) + if column.rename != "" { + select_sql = strings.ReplaceAll(select_sql, name, name+"` as `"+rename) + insert_sql = strings.ReplaceAll(insert_sql, name, rename) + } else { + rename = name + } + + // 修改表字段 + sql := "`" + rename + "` " + this.GetColumnType(column) + sql = this.addModifiers(sql, bp, column) + reg := regexp.MustCompile(fmt.Sprintf("(?iUm)^(?P\\s*)(?P`%s`.*)(?P(,\\s|\\s)?)$", name)) + cTableSql = reg.ReplaceAllString(cTableSql, fmt.Sprintf("${prefix}%s${end}", sql)) + } + + // 开启事物 + // sqls = append(sqls, "BEGIN TRANSACTION") + // 1. 重命名已存在的表 + sqls = append(sqls, fmt.Sprintf("alter table `%s` rename to `%s_old`", bp.table, bp.table)) + // 2. 创建新结构表 + sqls = append(sqls, cTableSql) + // 3. 复制数据 + sqls = append(sqls, insert_sql+" "+select_sql) + // 4. 删除旧表 + sqls = append(sqls, fmt.Sprintf("drop table `%s_old`", bp.table)) + // 提交事物 + // sqls = append(sqls, "COMMIT") + + // @todo 复制索引 + + return sqls +} + +// 创建唯一索引 +func (this Sqlite3Grammar) CompileUnique(bp *Blueprint, command *Command, conn *db.Connection) string { + columns := lo.Map(command.Columns, func(column string, _ int) string { + return "`" + column + "`" + }) + return fmt.Sprintf("create unique index `%s` on `%s` (%s)", command.Index, bp.table, strings.Join(columns, ", ")) +} + +// 创建普通索引 +func (this Sqlite3Grammar) CompileIndex(bp *Blueprint, command *Command, conn *db.Connection) string { + columns := lo.Map(command.Columns, func(column string, _ int) string { + return "`" + column + "`" + }) + return fmt.Sprintf("create index `%s` on `%s` (%s)", command.Index, bp.table, strings.Join(columns, ", ")) +} + +// 创建空间索引 +func (this Sqlite3Grammar) CompileSpatialIndex(bp *Blueprint, command *Command, conn *db.Connection) string { + panic("Sqlite 不支持 SpatialIndex") +} + +// 删除表 +func (this Sqlite3Grammar) CompileDrop(bp *Blueprint, command *Command, conn *db.Connection) string { + return "drop table `" + bp.table + "`" +} + +// 删除表, 先判断再删除 +func (this Sqlite3Grammar) CompileDropIfExists(bp *Blueprint, command *Command, conn *db.Connection) string { + return "drop table if exists `" + bp.table + "`" +} + +// 删除列 +func (this Sqlite3Grammar) CompileDropColumn(bp *Blueprint, command *Command, conn *db.Connection) (sqls []string) { + // 创建表的sql, 排除不要的字段 + cTableSql := this.getCreateTable(bp, conn, command.Columns...) + + // 旧表字段 + oldColumns := []string{} + + if _, err := conn.Select( + "select '`'||name||'`' from pragma_table_info(?) where `name` not in("+strings.Trim(strings.Repeat("?,", len(command.Columns)), ",")+")", + append([]any{bp.table}, lo.ToAnySlice(command.Columns)...), + &oldColumns); err != nil { + panic(fmt.Errorf("获取 SQLite 建表字段失败, %s", err)) + } + + copy_data_sql := fmt.Sprintf("insert into `%s` select %s from `%s_old`", bp.table, strings.Join(oldColumns, ", "), bp.table) + + for _, column := range bp.getChangedColumns() { + name := column.Name + rename := column.rename + + if rename != "" { + copy_data_sql = strings.ReplaceAll(copy_data_sql, "`"+name+"`", "`"+rename+"`") + } + + // 修改表字段 + sql := "`" + rename + "` " + this.GetColumnType(column) + sql = this.addModifiers(sql, bp, column) + reg := regexp.MustCompile(fmt.Sprintf("(?iUm)^(?P\\s*)(?P`%s`.*)(?P(,\\s|\\s)?)$", name)) + cTableSql = reg.ReplaceAllString(cTableSql, fmt.Sprintf("${prefix}%s${end}", sql)) + } + + // 开启事物 + // sqls = append(sqls, "BEGIN TRANSACTION") + // 1. 重命名已存在的表 + sqls = append(sqls, fmt.Sprintf("alter table `%s` rename to `%s_old`", bp.table, bp.table)) + // 2. 创建新结构表 + sqls = append(sqls, cTableSql) + // 3. 复制数据 + sqls = append(sqls, copy_data_sql) + // 4. 删除旧表 + sqls = append(sqls, fmt.Sprintf("drop table `%s_old`", bp.table)) + // 提交事物 + // sqls = append(sqls, "COMMIT") + + // @todo 复制索引 + + return sqls +} + +// 删除唯一索引 +func (this Sqlite3Grammar) CompileDropUnique(bp *Blueprint, command *Command, conn *db.Connection) string { + return this.CompileDropIndex(bp, command, conn) +} + +// 删除索引 +func (this Sqlite3Grammar) CompileDropIndex(bp *Blueprint, command *Command, conn *db.Connection) string { + return "drop index `" + command.Index + "`" +} + +// 删除空间索引 +func (this Sqlite3Grammar) CompileDropSpatialIndex(bp *Blueprint, command *Command, conn *db.Connection) string { + return this.CompileSpatialIndex(bp, command, conn) +} + +func (this Sqlite3Grammar) CompileRename(bp *Blueprint, command *Command, conn *db.Connection) string { + return "alter table `" + bp.table + "` rename to `" + command.To + "`" +} + +func (this Sqlite3Grammar) CompileRenameIndex(bp *Blueprint, command *Command, conn *db.Connection) string { + // 先删除索引, 后创建索引 + return "alter table `" + bp.table + "` rename to `" + command.To + "`" +} + +// 删除所有表 +func (this Sqlite3Grammar) CompileDropAllTables() string { + return "delete from sqlite_master where type in ('table', 'index', 'trigger')" +} + +// 删除所有视图 +func (this Sqlite3Grammar) CompileDropAllViews() string { + return "delete from sqlite_master where type in ('view')" +} + +func (this Sqlite3Grammar) addModifiers(sql string, bp *Blueprint, column *ColumnDefinition) string { + for _, modifier := range mysqlDefaultModifiers { + sql = sql + this.GetColumnModifier(modifier, bp, column) + } + return sql +} + +func (this Sqlite3Grammar) GetColumnModifier(modifier string, bp *Blueprint, column *ColumnDefinition) string { + switch modifier { + case "VirtualAs": + if column.virtualAs != "" { + return fmt.Sprintf(" as (%s)", column.virtualAs) + } + case "StoredAs": + if column.storedAs != "" { + return fmt.Sprintf(" as (%s) stored", column.storedAs) + } + case "Nullable": + if !column.nullable { + return " not null" + } + case "Default": + if column.useCurrent { + return "" + } + switch v := column.def.(type) { + case db.Expression: + return fmt.Sprintf(" default %s", v) + case nil: + default: + return fmt.Sprintf(" default '%v'", v) + } + case "Increment": + if column.autoIncrement && lo.Contains(mysqlSerials, column.Type) { + return " primary key autoincrement" + } + } + + return "" +} + +func (this Sqlite3Grammar) GetColumnType(column *ColumnDefinition) string { + switch column.Type { + case "char", "string": + return fmt.Sprintf("varchar(%d)", column.Length) + case "text": + return "text" + case "year", "integer": + return "integer" + case "bigInteger": + return "bigint" + case "tinyInteger": + return "tinyint" + case "smallInteger": + return "smallint" + case "decimal": + return fmt.Sprintf("numeric(%d,%d)", column.Total, column.Places) + case "boolean": + return "tinyint(1)" + case "enum": + return fmt.Sprintf(`varchar check ("%s" in (%s))`, column.Name, QuoteString(column.Allowed)) + case "json": + return "json" + case "date": + return "date" + case "binary": + return "blob" + case "datetime", "timestamp": + columnType := column.Type + if column.useCurrent { + columnType = columnType + " default CURRENT_TIMESTAMP" + } + return columnType + case "time": + return "time" + case "uuid": + return "varchar(36)" + } + panic("不支持的数据类型: " + column.Type) +} + +// 获取新增表结构字段 +func (this Sqlite3Grammar) getAddedColumns(bp *Blueprint) (columns []string) { + for _, column := range bp.getAddedColumns() { + sql := "`" + column.Name + "` " + this.GetColumnType(column) + columns = append(columns, this.addModifiers(sql, bp, column)) + } + return +} + +// 获取修改表结构字段 +func (this Sqlite3Grammar) getChangedColumns(bp *Blueprint) map[string]string { + columns := map[string]string{} + for _, column := range bp.getChangedColumns() { + sql := "`" + column.Name + "` " + this.GetColumnType(column) + columns[column.Name] = this.addModifiers(sql, bp, column) + } + return columns +} + +// 数据库表字段 +func (this Sqlite3Grammar) getCreateTable(bp *Blueprint, conn *db.Connection, without ...string) (sql string) { + dbColumns := []struct { + Name string `db:"name"` + Type string `db:"type"` + Default db.NullString `db:"dflt_value"` + NotNull bool `db:"notnull"` + Pk bool `db:"pk"` + }{} + + if _, err := conn.Select("select * from pragma_table_info(?) order by cid", []any{bp.table}, &dbColumns); err != nil { + panic(fmt.Errorf("获取 SQLite 建表字段失败, %s", err)) + } + + cols := []string{} + for _, column := range dbColumns { + if lo.Contains(without, column.Name) { + continue + } + col := " `" + column.Name + "`" + col += " " + column.Type + if column.NotNull { + col += " not null" + } + if column.Pk { + col += " primary key autoincrement" + } + if column.Default != "" { + col += " default " + string(column.Default) + } + cols = append(cols, col) + } + + return fmt.Sprintf("CREATE TABLE `%s` (\n%s\n)", bp.table, strings.Join(cols, ", \n")) +} diff --git a/schema/util.go b/schema/util.go new file mode 100644 index 0000000..0b4024e --- /dev/null +++ b/schema/util.go @@ -0,0 +1,25 @@ +package schema + +import ( + "github.com/samber/lo" + "strings" +) + +func PrefixArray(prefix string, values []string) (items []string) { + for _, value := range values { + items = append(items, prefix+" "+value) + } + return items +} + +func QuoteString(value any) string { + switch v := value.(type) { + case []string: + return strings.Join(lo.Map(v, func(item string, _ int) string { + return "'" + item + "'" + }), ", ") + case string: + return "'" + v + "'" + } + return "" +} diff --git a/sqlite3_connector.go b/sqlite3_connector.go new file mode 100644 index 0000000..a20be9b --- /dev/null +++ b/sqlite3_connector.go @@ -0,0 +1,50 @@ +package db + +import ( + "database/sql" + "fmt" + "strings" + + _ "github.com/mattn/go-sqlite3" +) + +type SqliteConnector struct { +} + +func (c SqliteConnector) connect(config *DBConfig) *Connection { + var params []string + + if len(config.Journal) > 0 { + params = append(params, "_journal="+config.Journal) + } + if len(config.Locking) > 0 { + params = append(params, "_locking="+config.Locking) + } + if len(config.Mode) > 0 { + params = append(params, "mode="+config.Mode) + } + if config.Synchronous > 0 { + params = append(params, fmt.Sprintf("_sync=%d", config.Synchronous)) + } + + var dsn string + if len(config.Dsn) > 0 { + dsn = config.Dsn + } else { + dsn = fmt.Sprintf("file:%s?%s", config.File, strings.Join(params, "&")) + } + db, err := sql.Open(DriverSqlite3, dsn) + if err != nil { + panic(err.Error()) + } + err = db.Ping() + if err != nil { + panic(err) + } + + return &Connection{ + RxDB: db, + DB: db, + Config: config, + } +} diff --git a/sqlite3_grammar.go b/sqlite3_grammar.go new file mode 100644 index 0000000..35042e1 --- /dev/null +++ b/sqlite3_grammar.go @@ -0,0 +1,5 @@ +package db + +type Sqlite3Grammar struct { + *MysqlGrammar +} diff --git a/transaction.go b/transaction.go new file mode 100644 index 0000000..b44f3d8 --- /dev/null +++ b/transaction.go @@ -0,0 +1,96 @@ +package db + +import ( + "database/sql" +) + +type Transaction struct { + Tx *sql.Tx + Rx *sql.DB + Config *DBConfig + ConnectionName string +} +type TxClosure func(tx *Transaction) (interface{}, error) + +func (t *Transaction) Table(tableName string) *Builder { + builder := NewTxBuilder(t) + if t.Config.Driver == DriverMysql { + builder.Grammar = &MysqlGrammar{} + } else if t.Config.Driver == DriverSqlite3 { + builder.Grammar = &Sqlite3Grammar{} + } else { + panic("不支持的数据库类型") + } + builder.Grammar.SetTablePrefix(t.Config.Prefix) + builder.Tx = t + builder.Grammar.SetBuilder(builder) + builder.From(tableName) + return builder +} + +func (t *Transaction) Select(query string, bindings []interface{}, dest interface{}) (result sql.Result, err error) { + var stmt *sql.Stmt + var rows *sql.Rows + stmt, err = t.Rx.Prepare(query) + if err != nil { + return + } + defer stmt.Close() + rows, err = stmt.Query(bindings...) + if err != nil { + return + } + defer rows.Close() + + return ScanResult(rows, dest), nil +} +func (t *Transaction) AffectingStatement(query string, bindings []interface{}) (result sql.Result, err error) { + stmt, errP := t.Tx.Prepare(query) + if errP != nil { + err = errP + return + } + defer stmt.Close() + result, err = stmt.Exec(bindings...) + if err != nil { + return + } + return +} +func (t *Transaction) Insert(query string, bindings []interface{}) (sql.Result, error) { + return t.AffectingStatement(query, bindings) +} + +func (t *Transaction) Update(query string, bindings []interface{}) (sql.Result, error) { + return t.AffectingStatement(query, bindings) +} + +func (t *Transaction) Delete(query string, bindings []interface{}) (sql.Result, error) { + return t.AffectingStatement(query, bindings) +} + +func (t *Transaction) Statement(query string, bindings []interface{}) (sql.Result, error) { + return t.AffectingStatement(query, bindings) +} + +func (t *Transaction) Commit() error { + return t.Tx.Commit() +} + +func (t *Transaction) Rollback() error { + return t.Tx.Rollback() +} + +func (t *Transaction) Query() *Builder { + builder := NewTxBuilder(t) + if t.Config.Driver == DriverMysql { + builder.Grammar = &MysqlGrammar{} + } else if t.Config.Driver == DriverSqlite3 { + builder.Grammar = &Sqlite3Grammar{} + } else { + panic("不支持的数据库类型") + } + builder.Grammar.SetBuilder(builder) + builder.Grammar.SetTablePrefix(t.Config.Prefix) + return builder +} diff --git a/types.go b/types.go new file mode 100644 index 0000000..552256c --- /dev/null +++ b/types.go @@ -0,0 +1,43 @@ +package db + +import ( + "database/sql/driver" + "encoding/json" + "fmt" +) + +type Strings []string + +func (c *Strings) Scan(value any) error { + if value == nil { + return nil + } + + return json.Unmarshal(value.([]byte), c) +} + +func (c Strings) Value() (driver.Value, error) { + b, err := json.Marshal(c) + return string(b), err +} + +type NullString string + +func (this *NullString) Scan(value any) error { + if value == nil { + return nil + } + + switch s := value.(type) { + case []byte: + *this = NullString(s) + default: + *this = NullString(fmt.Sprintf("%v", s)) + } + + return nil +} + +func (c NullString) Value() (driver.Value, error) { + return c, nil +} diff --git a/util.go b/util.go new file mode 100644 index 0000000..2164f37 --- /dev/null +++ b/util.go @@ -0,0 +1,77 @@ +package db + +import ( + "reflect" + "regexp" + "strings" +) + +var matchFirstCap = regexp.MustCompile("(.)([A-Z][a-z]+)") +var matchAllCap = regexp.MustCompile("([a-z0-9])([A-Z])") + +func ToSnakeCase(str string) string { + snake := matchFirstCap.ReplaceAllString(str, "${1}_${2}") + snake = matchAllCap.ReplaceAllString(snake, "${1}_${2}") + return strings.ToLower(snake) +} +func ExtractStruct(target interface{}) map[string]interface{} { + tv := reflect.Indirect(reflect.ValueOf(target)) + tt := tv.Type() + result := make(map[string]interface{}, tv.NumField()) + + for i := 0; i < tv.NumField(); i++ { + key := ToSnakeCase(tt.Field(i).Name) + result[key] = tv.Field(i).Interface() + } + + return result +} +func InterfaceToSlice(param interface{}) []interface{} { + if p, ok := param.([]interface{}); ok { + return p + } + tv := reflect.Indirect(reflect.ValueOf(param)) + var res []interface{} + if tv.Type().Kind() == reflect.Slice { + for i := 0; i < tv.Len(); i++ { + res = append(res, tv.Index(i).Interface()) + } + } else { + panic("not slice") + } + return res +} + +// addslashes() 函数返回在预定义字符之前添加反斜杠的字符串。 +// 预定义字符是: +// 单引号(') +// 双引号(") +// 反斜杠(\) +func Addslashes(str string) string { + tmpRune := []rune{} + strRune := []rune(str) + for _, ch := range strRune { + switch ch { + case []rune{'\\'}[0], []rune{'"'}[0], []rune{'\''}[0]: + tmpRune = append(tmpRune, []rune{'\\'}[0]) + tmpRune = append(tmpRune, ch) + default: + tmpRune = append(tmpRune, ch) + } + } + return string(tmpRune) +} + +// stripslashes() 函数删除由 addslashes() 函数添加的反斜杠。 +func Stripslashes(str string) string { + dstRune := []rune{} + strRune := []rune(str) + strLenth := len(strRune) + for i := 0; i < strLenth; i++ { + if strRune[i] == []rune{'\\'}[0] { + i++ + } + dstRune = append(dstRune, strRune[i]) + } + return string(dstRune) +}