diff --git a/README.md b/README.md index da55878e..c8c43894 100644 --- a/README.md +++ b/README.md @@ -2,11 +2,9 @@ Xorm is a simple and powerful ORM for Go. -[![CircleCI](https://circleci.com/gh/go-xorm/xorm/tree/master.svg?style=svg)](https://circleci.com/gh/go-xorm/xorm/tree/master) [![Gitter](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/go-xorm/xorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge) - -# Notice - -The last master version is not backwards compatible. You should use `engine.ShowSQL()` and `engine.Logger().SetLevel()` instead of `engine.ShowSQL = `, `engine.ShowInfo = ` and so on. +[![CircleCI](https://circleci.com/gh/go-xorm/xorm.svg?style=shield)](https://circleci.com/gh/go-xorm/xorm) [![codecov](https://codecov.io/gh/go-xorm/xorm/branch/master/graph/badge.svg)](https://codecov.io/gh/go-xorm/xorm) +[![](https://goreportcard.com/badge/github.com/go-xorm/xorm)](https://goreportcard.com/report/github.com/go-xorm/xorm) +[![Join the chat at https://img.shields.io/discord/323460943201959939.svg](https://img.shields.io/discord/323460943201959939.svg)](https://discord.gg/HuR2CF3) # Features @@ -36,7 +34,7 @@ Drivers for Go's sql package which currently support database/sql includes: * Mysql: [github.com/go-sql-driver/mysql](https://github.com/go-sql-driver/mysql) -* MyMysql: [github.com/ziutek/mymysql/godrv](https://github.com/ziutek/mymysql/godrv) +* MyMysql: [github.com/ziutek/mymysql/godrv](https://github.com/ziutek/mymysql/tree/master/godrv) * Postgres: [github.com/lib/pq](https://github.com/lib/pq) @@ -50,6 +48,14 @@ Drivers for Go's sql package which currently support database/sql includes: # Changelog +* **v0.6.3** + * merge tests to main project + * add `Exist` function + * add `SumInt` function + * Mysql now support read and create column comment. + * fix time related bugs. + * fix some other bugs. + * **v0.6.2** * refactor tag parse methods * add Scan features to Get @@ -62,22 +68,6 @@ methods can use `builder.Cond` as parameter * add Sum, SumInt, SumInt64 and NotIn methods * some bugs fixed -* **v0.5.0** - * logging interface changed - * some bugs fixed - -* **v0.4.5** - * many bugs fixed - * extends support unlimited deepth - * Delete Limit support - -* **v0.4.4** - * ql database expriment support - * tidb database expriment support - * sql.NullString and etc. field support - * select ForUpdate support - * many bugs fixed - [More changes ...](https://github.com/go-xorm/manual-en-US/tree/master/chapter-16) # Installation @@ -124,7 +114,7 @@ results, err := engine.Query("select * from user") results, err := engine.QueryString("select * from user") ``` -* `Execute` runs a SQL string, it returns `affetcted` and `error` +* `Execute` runs a SQL string, it returns `affected` and `error` ```Go affected, err := engine.Exec("update user set age = ? where name = ?", age, name) @@ -166,6 +156,25 @@ has, err := engine.Where("id = ?", id).Cols(cols...).Get(&valuesSlice) // SELECT col1, col2, col3 FROM user WHERE id = ? ``` +* Check if one record exist on table + +```Go +has, err := testEngine.Exist(new(RecordExist)) +// SELECT * FROM record_exist LIMIT 1 +has, err = testEngine.Exist(&RecordExist{ + Name: "test1", + }) +// SELECT * FROM record_exist WHERE name = ? LIMIT 1 +has, err = testEngine.Where("name = ?", "test1").Exist(&RecordExist{}) +// SELECT * FROM record_exist WHERE name = ? LIMIT 1 +has, err = testEngine.SQL("select * from record_exist where name = ?", "test1").Exist() +// select * from record_exist where name = ? +has, err = testEngine.Table("record_exist").Exist() +// SELECT * FROM record_exist LIMIT 1 +has, err = testEngine.Table("record_exist").Where("name = ?", "test1").Exist() +// SELECT * FROM record_exist WHERE name = ? LIMIT 1 +``` + * Query multiple records from database, also you can use join and extends ```Go @@ -258,13 +267,21 @@ err := engine.Where(builder.NotIn("a", 1, 2).And(builder.In("b", "c", "d", "e")) # Cases +* [studygolang](http://studygolang.com/) - [github.com/studygolang/studygolang](https://github.com/studygolang/studygolang) + +* [Gitea](http://gitea.io) - [github.com/go-gitea/gitea](http://github.com/go-gitea/gitea) + +* [Gogs](http://try.gogits.org) - [github.com/gogits/gogs](http://github.com/gogits/gogs) + +* [grafana](https://grafana.com/) - [github.com/grafana/grafana](http://github.com/grafana/grafana) + * [github.com/m3ng9i/qreader](https://github.com/m3ng9i/qreader) * [Wego](http://github.com/go-tango/wego) * [Docker.cn](https://docker.cn/) -* [Gogs](http://try.gogits.org) - [github.com/gogits/gogs](http://github.com/gogits/gogs) +* [Xorm Adapter](https://github.com/casbin/xorm-adapter) for [Casbin](https://github.com/casbin/casbin) - [github.com/casbin/xorm-adapter](https://github.com/casbin/xorm-adapter) * [Gorevel](http://gorevel.cn/) - [github.com/goofcc/gorevel](http://github.com/goofcc/gorevel) diff --git a/README_CN.md b/README_CN.md index 40f5f600..cb2c1799 100644 --- a/README_CN.md +++ b/README_CN.md @@ -4,11 +4,9 @@ xorm是一个简单而强大的Go语言ORM库. 通过它可以使数据库操作非常简便。 -[![CircleCI](https://circleci.com/gh/go-xorm/xorm/tree/master.svg?style=svg)](https://circleci.com/gh/go-xorm/xorm/tree/master) [![Gitter](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/go-xorm/xorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge) - -# 注意 - -最新的版本有不兼容的更新,您必须使用 `engine.ShowSQL()` 和 `engine.Logger().SetLevel()` 来替代 `engine.ShowSQL = `, `engine.ShowInfo = ` 等等。 +[![CircleCI](https://circleci.com/gh/go-xorm/xorm.svg?style=shield)](https://circleci.com/gh/go-xorm/xorm) [![codecov](https://codecov.io/gh/go-xorm/xorm/branch/master/graph/badge.svg)](https://codecov.io/gh/go-xorm/xorm) +[![](https://goreportcard.com/badge/github.com/go-xorm/xorm)](https://goreportcard.com/report/github.com/go-xorm/xorm) +[![Join the chat at https://img.shields.io/discord/323460943201959939.svg](https://img.shields.io/discord/323460943201959939.svg)](https://discord.gg/HuR2CF3) ## 特性 @@ -54,9 +52,18 @@ xorm是一个简单而强大的Go语言ORM库. 通过它可以使数据库操作 ## 更新日志 +* **v0.6.3** + * 合并单元测试到主工程 + * 新增`Exist`方法 + * 新增`SumInt`方法 + * Mysql新增读取和创建字段注释支持 + * 新增`SetConnMaxLifetime`方法 + * 修正了时间相关的Bug + * 修复了一些其它Bug + * **v0.6.2** * 重构Tag解析方式 - * Get方法新增类似Sacn的特性 + * Get方法新增类似Scan的特性 * 新增 QueryString 方法 * **v0.6.0** @@ -70,18 +77,6 @@ xorm是一个简单而强大的Go语言ORM库. 通过它可以使数据库操作 * logging接口进行不兼容改变 * Bug修正 -* **v0.4.5** - * bug修正 - * extends 支持无限级 - * Delete Limit 支持 - -* **v0.4.4** - * Tidb 数据库支持 - * QL 试验性支持 - * sql.NullString支持 - * ForUpdate 支持 - * bug修正 - [更多更新日志...](https://github.com/go-xorm/manual-zh-CN/tree/master/chapter-16) ## 安装 @@ -170,6 +165,25 @@ has, err := engine.Where("id = ?", id).Cols(cols...).Get(&valuesSlice) // SELECT col1, col2, col3 FROM user WHERE id = ? ``` +* 检测记录是否存在 + +```Go +has, err := testEngine.Exist(new(RecordExist)) +// SELECT * FROM record_exist LIMIT 1 +has, err = testEngine.Exist(&RecordExist{ + Name: "test1", + }) +// SELECT * FROM record_exist WHERE name = ? LIMIT 1 +has, err = testEngine.Where("name = ?", "test1").Exist(&RecordExist{}) +// SELECT * FROM record_exist WHERE name = ? LIMIT 1 +has, err = testEngine.SQL("select * from record_exist where name = ?", "test1").Exist() +// select * from record_exist where name = ? +has, err = testEngine.Table("record_exist").Exist() +// SELECT * FROM record_exist LIMIT 1 +has, err = testEngine.Table("record_exist").Where("name = ?", "test1").Exist() +// SELECT * FROM record_exist WHERE name = ? LIMIT 1 +``` + * 查询多条记录,当然可以使用Join和extends来组合使用 ```Go @@ -261,13 +275,21 @@ err := engine.Where(builder.NotIn("a", 1, 2).And(builder.In("b", "c", "d", "e")) # 案例 +* [Go语言中文网](http://studygolang.com/) - [github.com/studygolang/studygolang](https://github.com/studygolang/studygolang) + +* [Gitea](http://gitea.io) - [github.com/go-gitea/gitea](http://github.com/go-gitea/gitea) + +* [Gogs](http://try.gogits.org) - [github.com/gogits/gogs](http://github.com/gogits/gogs) + +* [grafana](https://grafana.com/) - [github.com/grafana/grafana](http://github.com/grafana/grafana) + * [github.com/m3ng9i/qreader](https://github.com/m3ng9i/qreader) * [Wego](http://github.com/go-tango/wego) * [Docker.cn](https://docker.cn/) -* [Gogs](http://try.gogits.org) - [github.com/gogits/gogs](http://github.com/gogits/gogs) +* [Xorm Adapter](https://github.com/casbin/xorm-adapter) for [Casbin](https://github.com/casbin/casbin) - [github.com/casbin/xorm-adapter](https://github.com/casbin/xorm-adapter) * [Gowalker](http://gowalker.org) - [github.com/Unknwon/gowalker](http://github.com/Unknwon/gowalker) diff --git a/lru_cacher.go b/cache_lru.go similarity index 90% rename from lru_cacher.go rename to cache_lru.go index 4a745043..c9672ceb 100644 --- a/lru_cacher.go +++ b/cache_lru.go @@ -15,13 +15,12 @@ import ( // LRUCacher implments cache object facilities type LRUCacher struct { - idList *list.List - sqlList *list.List - idIndex map[string]map[string]*list.Element - sqlIndex map[string]map[string]*list.Element - store core.CacheStore - mutex sync.Mutex - // maxSize int + idList *list.List + sqlList *list.List + idIndex map[string]map[string]*list.Element + sqlIndex map[string]map[string]*list.Element + store core.CacheStore + mutex sync.Mutex MaxElementSize int Expired time.Duration GcInterval time.Duration @@ -54,8 +53,6 @@ func (m *LRUCacher) RunGC() { // GC check ids lit and sql list to remove all element expired func (m *LRUCacher) GC() { - //fmt.Println("begin gc ...") - //defer fmt.Println("end gc ...") m.mutex.Lock() defer m.mutex.Unlock() var removedNum int @@ -64,12 +61,10 @@ func (m *LRUCacher) GC() { time.Now().Sub(e.Value.(*idNode).lastVisit) > m.Expired { removedNum++ next := e.Next() - //fmt.Println("removing ...", e.Value) node := e.Value.(*idNode) m.delBean(node.tbName, node.id) e = next } else { - //fmt.Printf("removing %d cache nodes ..., left %d\n", removedNum, m.idList.Len()) break } } @@ -80,12 +75,10 @@ func (m *LRUCacher) GC() { time.Now().Sub(e.Value.(*sqlNode).lastVisit) > m.Expired { removedNum++ next := e.Next() - //fmt.Println("removing ...", e.Value) node := e.Value.(*sqlNode) m.delIds(node.tbName, node.sql) e = next } else { - //fmt.Printf("removing %d cache nodes ..., left %d\n", removedNum, m.sqlList.Len()) break } } @@ -116,7 +109,6 @@ func (m *LRUCacher) GetIds(tableName, sql string) interface{} { } m.delIds(tableName, sql) - return nil } @@ -134,7 +126,6 @@ func (m *LRUCacher) GetBean(tableName string, id string) interface{} { // if expired, remove the node and return nil if time.Now().Sub(lastTime) > m.Expired { m.delBean(tableName, id) - //m.clearIds(tableName) return nil } m.idList.MoveToBack(el) @@ -148,7 +139,6 @@ func (m *LRUCacher) GetBean(tableName string, id string) interface{} { // store bean is not exist, then remove memory's index m.delBean(tableName, id) - //m.clearIds(tableName) return nil } @@ -166,8 +156,8 @@ func (m *LRUCacher) clearIds(tableName string) { // ClearIds clears all sql-ids mapping on table tableName from cache func (m *LRUCacher) ClearIds(tableName string) { m.mutex.Lock() - defer m.mutex.Unlock() m.clearIds(tableName) + m.mutex.Unlock() } func (m *LRUCacher) clearBeans(tableName string) { @@ -184,14 +174,13 @@ func (m *LRUCacher) clearBeans(tableName string) { // ClearBeans clears all beans in some table func (m *LRUCacher) ClearBeans(tableName string) { m.mutex.Lock() - defer m.mutex.Unlock() m.clearBeans(tableName) + m.mutex.Unlock() } // PutIds pus ids into table func (m *LRUCacher) PutIds(tableName, sql string, ids interface{}) { m.mutex.Lock() - defer m.mutex.Unlock() if _, ok := m.sqlIndex[tableName]; !ok { m.sqlIndex[tableName] = make(map[string]*list.Element) } @@ -207,12 +196,12 @@ func (m *LRUCacher) PutIds(tableName, sql string, ids interface{}) { node := e.Value.(*sqlNode) m.delIds(node.tbName, node.sql) } + m.mutex.Unlock() } // PutBean puts beans into table func (m *LRUCacher) PutBean(tableName string, id string, obj interface{}) { m.mutex.Lock() - defer m.mutex.Unlock() var el *list.Element var ok bool @@ -229,6 +218,7 @@ func (m *LRUCacher) PutBean(tableName string, id string, obj interface{}) { node := e.Value.(*idNode) m.delBean(node.tbName, node.id) } + m.mutex.Unlock() } func (m *LRUCacher) delIds(tableName, sql string) { @@ -244,8 +234,8 @@ func (m *LRUCacher) delIds(tableName, sql string) { // DelIds deletes ids func (m *LRUCacher) DelIds(tableName, sql string) { m.mutex.Lock() - defer m.mutex.Unlock() m.delIds(tableName, sql) + m.mutex.Unlock() } func (m *LRUCacher) delBean(tableName string, id string) { @@ -261,8 +251,8 @@ func (m *LRUCacher) delBean(tableName string, id string) { // DelBean deletes beans in some table func (m *LRUCacher) DelBean(tableName string, id string) { m.mutex.Lock() - defer m.mutex.Unlock() m.delBean(tableName, id) + m.mutex.Unlock() } type idNode struct { diff --git a/cache_lru_test.go b/cache_lru_test.go new file mode 100644 index 00000000..28854474 --- /dev/null +++ b/cache_lru_test.go @@ -0,0 +1,52 @@ +// Copyright 2015 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import ( + "testing" + + "github.com/go-xorm/core" + "github.com/stretchr/testify/assert" +) + +func TestLRUCache(t *testing.T) { + type CacheObject1 struct { + Id int64 + } + + store := NewMemoryStore() + cacher := NewLRUCacher(store, 10000) + + tableName := "cache_object1" + pks := []core.PK{ + {1}, + {2}, + } + + for _, pk := range pks { + sid, err := pk.ToString() + assert.NoError(t, err) + + cacher.PutIds(tableName, "select * from cache_object1", sid) + ids := cacher.GetIds(tableName, "select * from cache_object1") + assert.EqualValues(t, sid, ids) + + cacher.ClearIds(tableName) + ids2 := cacher.GetIds(tableName, "select * from cache_object1") + assert.Nil(t, ids2) + + obj2 := cacher.GetBean(tableName, sid) + assert.Nil(t, obj2) + + var obj = new(CacheObject1) + cacher.PutBean(tableName, sid, obj) + obj3 := cacher.GetBean(tableName, sid) + assert.EqualValues(t, obj, obj3) + + cacher.DelBean(tableName, sid) + obj4 := cacher.GetBean(tableName, sid) + assert.Nil(t, obj4) + } +} diff --git a/memory_store.go b/cache_memory_store.go similarity index 100% rename from memory_store.go rename to cache_memory_store.go diff --git a/cache_memory_store_test.go b/cache_memory_store_test.go new file mode 100644 index 00000000..fc27ae32 --- /dev/null +++ b/cache_memory_store_test.go @@ -0,0 +1,37 @@ +// Copyright 2015 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestMemoryStore(t *testing.T) { + store := NewMemoryStore() + var kvs = map[string]interface{}{ + "a": "b", + } + for k, v := range kvs { + assert.NoError(t, store.Put(k, v)) + } + + for k, v := range kvs { + val, err := store.Get(k) + assert.NoError(t, err) + assert.EqualValues(t, v, val) + } + + for k, _ := range kvs { + err := store.Del(k) + assert.NoError(t, err) + } + + for k, _ := range kvs { + _, err := store.Get(k) + assert.EqualValues(t, ErrNotExist, err) + } +} diff --git a/cache_test.go b/cache_test.go new file mode 100644 index 00000000..5f138f24 --- /dev/null +++ b/cache_test.go @@ -0,0 +1,179 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestCacheFind(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type MailBox struct { + Id int64 `xorm:"pk"` + Username string + Password string + } + + oldCacher := testEngine.Cacher + cacher := NewLRUCacher2(NewMemoryStore(), time.Hour, 10000) + testEngine.SetDefaultCacher(cacher) + + assert.NoError(t, testEngine.Sync2(new(MailBox))) + + var inserts = []*MailBox{ + { + Id: 0, + Username: "user1", + Password: "pass1", + }, + { + Id: 1, + Username: "user2", + Password: "pass2", + }, + } + _, err := testEngine.Insert(inserts[0], inserts[1]) + assert.NoError(t, err) + + var boxes []MailBox + assert.NoError(t, testEngine.Find(&boxes)) + assert.EqualValues(t, 2, len(boxes)) + for i, box := range boxes { + assert.Equal(t, inserts[i].Id, box.Id) + assert.Equal(t, inserts[i].Username, box.Username) + assert.Equal(t, inserts[i].Password, box.Password) + } + + boxes = make([]MailBox, 0, 2) + assert.NoError(t, testEngine.Find(&boxes)) + assert.EqualValues(t, 2, len(boxes)) + for i, box := range boxes { + assert.Equal(t, inserts[i].Id, box.Id) + assert.Equal(t, inserts[i].Username, box.Username) + assert.Equal(t, inserts[i].Password, box.Password) + } + + boxes = make([]MailBox, 0, 2) + assert.NoError(t, testEngine.Alias("a").Where("a.id > -1").Asc("a.id").Find(&boxes)) + assert.EqualValues(t, 2, len(boxes)) + for i, box := range boxes { + assert.Equal(t, inserts[i].Id, box.Id) + assert.Equal(t, inserts[i].Username, box.Username) + assert.Equal(t, inserts[i].Password, box.Password) + } + + type MailBox4 struct { + Id int64 + Username string + Password string + } + + boxes2 := make([]MailBox4, 0, 2) + assert.NoError(t, testEngine.Table("mail_box").Where("mail_box.id > -1").Asc("mail_box.id").Find(&boxes2)) + assert.EqualValues(t, 2, len(boxes2)) + for i, box := range boxes2 { + assert.Equal(t, inserts[i].Id, box.Id) + assert.Equal(t, inserts[i].Username, box.Username) + assert.Equal(t, inserts[i].Password, box.Password) + } + + testEngine.SetDefaultCacher(oldCacher) +} + +func TestCacheFind2(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type MailBox2 struct { + Id uint64 `xorm:"pk"` + Username string + Password string + } + + oldCacher := testEngine.Cacher + cacher := NewLRUCacher2(NewMemoryStore(), time.Hour, 10000) + testEngine.SetDefaultCacher(cacher) + + assert.NoError(t, testEngine.Sync2(new(MailBox2))) + + var inserts = []*MailBox2{ + { + Id: 0, + Username: "user1", + Password: "pass1", + }, + { + Id: 1, + Username: "user2", + Password: "pass2", + }, + } + _, err := testEngine.Insert(inserts[0], inserts[1]) + assert.NoError(t, err) + + var boxes []MailBox2 + assert.NoError(t, testEngine.Find(&boxes)) + assert.EqualValues(t, 2, len(boxes)) + for i, box := range boxes { + assert.Equal(t, inserts[i].Id, box.Id) + assert.Equal(t, inserts[i].Username, box.Username) + assert.Equal(t, inserts[i].Password, box.Password) + } + + boxes = make([]MailBox2, 0, 2) + assert.NoError(t, testEngine.Find(&boxes)) + assert.EqualValues(t, 2, len(boxes)) + for i, box := range boxes { + assert.Equal(t, inserts[i].Id, box.Id) + assert.Equal(t, inserts[i].Username, box.Username) + assert.Equal(t, inserts[i].Password, box.Password) + } + + testEngine.SetDefaultCacher(oldCacher) +} + +func TestCacheGet(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type MailBox3 struct { + Id uint64 + Username string + Password string + } + + oldCacher := testEngine.Cacher + cacher := NewLRUCacher2(NewMemoryStore(), time.Hour, 10000) + testEngine.SetDefaultCacher(cacher) + + assert.NoError(t, testEngine.Sync2(new(MailBox3))) + + var inserts = []*MailBox3{ + { + Username: "user1", + Password: "pass1", + }, + } + _, err := testEngine.Insert(inserts[0]) + assert.NoError(t, err) + + var box1 MailBox3 + has, err := testEngine.Where("id = ?", inserts[0].Id).Get(&box1) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, "user1", box1.Username) + assert.EqualValues(t, "pass1", box1.Password) + + var box2 MailBox3 + has, err = testEngine.Where("id = ?", inserts[0].Id).Get(&box2) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, "user1", box2.Username) + assert.EqualValues(t, "pass1", box2.Password) + + testEngine.SetDefaultCacher(oldCacher) +} diff --git a/circle.yml b/circle.yml index 8faa627d..69fc7164 100644 --- a/circle.yml +++ b/circle.yml @@ -3,6 +3,8 @@ dependencies: # './...' is a relative pattern which means all subdirectories - go get -t -d -v ./... - go get -t -d -v github.com/go-xorm/tests + - go get -u github.com/go-xorm/core + - go get -u github.com/go-xorm/builder - go build -v database: @@ -19,7 +21,18 @@ database: test: override: # './...' is a relative pattern which means all subdirectories - - go test -v -race + - go get -u github.com/wadey/gocovmerge; + - go test -v -race -db="sqlite3" -conn_str="./test.db" -coverprofile=coverage1-1.txt -covermode=atomic + - go test -v -race -db="sqlite3" -conn_str="./test.db" -cache=true -coverprofile=coverage1-2.txt -covermode=atomic + - go test -v -race -db="mysql" -conn_str="root:@/xorm_test" -coverprofile=coverage2-1.txt -covermode=atomic + - go test -v -race -db="mysql" -conn_str="root:@/xorm_test" -cache=true -coverprofile=coverage2-2.txt -covermode=atomic + - go test -v -race -db="mymysql" -conn_str="xorm_test/root/" -coverprofile=coverage3-1.txt -covermode=atomic + - go test -v -race -db="mymysql" -conn_str="xorm_test/root/" -cache=true -coverprofile=coverage3-2.txt -covermode=atomic + - go test -v -race -db="postgres" -conn_str="dbname=xorm_test sslmode=disable" -coverprofile=coverage4-1.txt -covermode=atomic + - go test -v -race -db="postgres" -conn_str="dbname=xorm_test sslmode=disable" -cache=true -coverprofile=coverage4-2.txt -covermode=atomic + - gocovmerge coverage1-1.txt coverage1-2.txt coverage2-1.txt coverage2-2.txt coverage3-1.txt coverage3-2.txt coverage4-1.txt coverage4-2.txt > coverage.txt - cd /home/ubuntu/.go_workspace/src/github.com/go-xorm/tests && ./sqlite3.sh - cd /home/ubuntu/.go_workspace/src/github.com/go-xorm/tests && ./mysql.sh - - cd /home/ubuntu/.go_workspace/src/github.com/go-xorm/tests && ./postgres.sh \ No newline at end of file + - cd /home/ubuntu/.go_workspace/src/github.com/go-xorm/tests && ./postgres.sh + post: + - bash <(curl -s https://codecov.io/bash) \ No newline at end of file diff --git a/convert.go b/convert.go index fbd24b5b..0504bef1 100644 --- a/convert.go +++ b/convert.go @@ -334,3 +334,15 @@ func convertInt(v interface{}) (int64, error) { } return 0, fmt.Errorf("unsupported type: %v", v) } + +func asBool(bs []byte) (bool, error) { + if len(bs) == 0 { + return false, nil + } + if bs[0] == 0x00 { + return false, nil + } else if bs[0] == 0x01 { + return true, nil + } + return strconv.ParseBool(string(bs)) +} diff --git a/dialect_mssql.go b/dialect_mssql.go index f83cfc17..6d2291dc 100644 --- a/dialect_mssql.go +++ b/dialect_mssql.go @@ -215,7 +215,7 @@ func (db *mssql) SqlType(c *core.Column) string { var res string switch t := c.SQLType.Name; t { case core.Bool: - res = core.TinyInt + res = core.Bit if strings.EqualFold(c.Default, "true") { c.Default = "1" } else { @@ -250,6 +250,9 @@ func (db *mssql) SqlType(c *core.Column) string { case core.Uuid: res = core.Varchar c.Length = 40 + case core.TinyInt: + res = core.TinyInt + c.Length = 0 default: res = t } @@ -335,9 +338,15 @@ func (db *mssql) TableCheckSql(tableName string) (string, []interface{}) { func (db *mssql) GetColumns(tableName string) ([]string, map[string]*core.Column, error) { args := []interface{}{} s := `select a.name as name, b.name as ctype,a.max_length,a.precision,a.scale,a.is_nullable as nullable, - replace(replace(isnull(c.text,''),'(',''),')','') as vdefault - from sys.columns a left join sys.types b on a.user_type_id=b.user_type_id - left join sys.syscomments c on a.default_object_id=c.id + replace(replace(isnull(c.text,''),'(',''),')','') as vdefault, + ISNULL(i.is_primary_key, 0) + from sys.columns a + left join sys.types b on a.user_type_id=b.user_type_id + left join sys.syscomments c on a.default_object_id=c.id + LEFT OUTER JOIN + sys.index_columns ic ON ic.object_id = a.object_id AND ic.column_id = a.column_id + LEFT OUTER JOIN + sys.indexes i ON ic.object_id = i.object_id AND ic.index_id = i.index_id where a.object_id=object_id('` + tableName + `')` db.LogSQL(s, args) @@ -352,8 +361,8 @@ func (db *mssql) GetColumns(tableName string) ([]string, map[string]*core.Column for rows.Next() { var name, ctype, vdefault string var maxLen, precision, scale int - var nullable bool - err = rows.Scan(&name, &ctype, &maxLen, &precision, &scale, &nullable, &vdefault) + var nullable, isPK bool + err = rows.Scan(&name, &ctype, &maxLen, &precision, &scale, &nullable, &vdefault, &isPK) if err != nil { return nil, nil, err } @@ -363,6 +372,7 @@ func (db *mssql) GetColumns(tableName string) ([]string, map[string]*core.Column col.Name = strings.Trim(name, "` ") col.Nullable = nullable col.Default = vdefault + col.IsPrimaryKey = isPK ct := strings.ToUpper(ctype) if ct == "DECIMAL" { col.Length = precision @@ -536,7 +546,6 @@ type odbcDriver struct { func (p *odbcDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) { kv := strings.Split(dataSourceName, ";") var dbName string - for _, c := range kv { vv := strings.Split(strings.TrimSpace(c), "=") if len(vv) == 2 { diff --git a/dialect_mysql.go b/dialect_mysql.go index 55cfdd76..99100b23 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -299,7 +299,7 @@ func (db *mysql) TableCheckSql(tableName string) (string, []interface{}) { func (db *mysql) GetColumns(tableName string) ([]string, map[string]*core.Column, error) { args := []interface{}{db.DbName, tableName} s := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," + - " `COLUMN_KEY`, `EXTRA` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" + " `COLUMN_KEY`, `EXTRA`,`COLUMN_COMMENT` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" db.LogSQL(s, args) rows, err := db.DB().Query(s, args...) @@ -314,13 +314,14 @@ func (db *mysql) GetColumns(tableName string) ([]string, map[string]*core.Column col := new(core.Column) col.Indexes = make(map[string]int) - var columnName, isNullable, colType, colKey, extra string + var columnName, isNullable, colType, colKey, extra, comment string var colDefault *string - err = rows.Scan(&columnName, &isNullable, &colDefault, &colType, &colKey, &extra) + err = rows.Scan(&columnName, &isNullable, &colDefault, &colType, &colKey, &extra, &comment) if err != nil { return nil, nil, err } col.Name = strings.Trim(columnName, "` ") + col.Comment = comment if "YES" == isNullable { col.Nullable = true } @@ -407,7 +408,7 @@ func (db *mysql) GetColumns(tableName string) ([]string, map[string]*core.Column func (db *mysql) GetTables() ([]*core.Table, error) { args := []interface{}{db.DbName} - s := "SELECT `TABLE_NAME`, `ENGINE`, `TABLE_ROWS`, `AUTO_INCREMENT` from " + + s := "SELECT `TABLE_NAME`, `ENGINE`, `TABLE_ROWS`, `AUTO_INCREMENT`, `TABLE_COMMENT` from " + "`INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? AND (`ENGINE`='MyISAM' OR `ENGINE` = 'InnoDB' OR `ENGINE` = 'TokuDB')" db.LogSQL(s, args) @@ -420,14 +421,15 @@ func (db *mysql) GetTables() ([]*core.Table, error) { tables := make([]*core.Table, 0) for rows.Next() { table := core.NewEmptyTable() - var name, engine, tableRows string + var name, engine, tableRows, comment string var autoIncr *string - err = rows.Scan(&name, &engine, &tableRows, &autoIncr) + err = rows.Scan(&name, &engine, &tableRows, &autoIncr, &comment) if err != nil { return nil, err } table.Name = name + table.Comment = comment table.StoreEngine = engine tables = append(tables, table) } diff --git a/dialect_postgres.go b/dialect_postgres.go index 1d4daa27..3f5c526f 100644 --- a/dialect_postgres.go +++ b/dialect_postgres.go @@ -781,6 +781,9 @@ func (db *postgres) SqlType(c *core.Column) string { case core.TinyInt: res = core.SmallInt return res + case core.Bit: + res = core.Boolean + return res case core.MediumInt, core.Int, core.Integer: if c.IsAutoIncrement { return core.Serial diff --git a/dialect_sqlite3.go b/dialect_sqlite3.go index c13fd02b..a55b1615 100644 --- a/dialect_sqlite3.go +++ b/dialect_sqlite3.go @@ -14,10 +14,6 @@ import ( "github.com/go-xorm/core" ) -// func init() { -// RegisterDialect("sqlite3", &sqlite3{}) -// } - var ( sqlite3ReservedWords = map[string]bool{ "ABORT": true, @@ -310,11 +306,25 @@ func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*core.Colu for _, colStr := range colCreates { reg = regexp.MustCompile(`,\s`) colStr = reg.ReplaceAllString(colStr, ",") + if strings.HasPrefix(strings.TrimSpace(colStr), "PRIMARY KEY") { + parts := strings.Split(strings.TrimSpace(colStr), "(") + if len(parts) == 2 { + pkCols := strings.Split(strings.TrimRight(strings.TrimSpace(parts[1]), ")"), ",") + for _, pk := range pkCols { + if col, ok := cols[strings.Trim(strings.TrimSpace(pk), "`")]; ok { + col.IsPrimaryKey = true + } + } + } + continue + } + fields := strings.Fields(strings.TrimSpace(colStr)) col := new(core.Column) col.Indexes = make(map[string]int) col.Nullable = true col.DefaultIsEmpty = true + for idx, field := range fields { if idx == 0 { col.Name = strings.Trim(strings.Trim(field, "`[] "), `"`) diff --git a/doc.go b/doc.go index 5b36fcd8..a687e694 100644 --- a/doc.go +++ b/doc.go @@ -8,7 +8,7 @@ Package xorm is a simple and powerful ORM for Go. Installation -Make sure you have installed Go 1.1+ and then: +Make sure you have installed Go 1.6+ and then: go get github.com/go-xorm/xorm @@ -51,11 +51,15 @@ There are 8 major ORM methods and many helpful methods to use to operate databas // INSERT INTO struct1 () values () // INSERT INTO struct2 () values (),(),() -2. Query one record from database +2. Query one record or one variable from database has, err := engine.Get(&user) // SELECT * FROM user LIMIT 1 + var id int64 + has, err := engine.Table("user").Where("name = ?", name).Get(&id) + // SELECT id FROM user WHERE name = ? LIMIT 1 + 3. Query multiple records from database var sliceOfStructs []Struct @@ -86,7 +90,7 @@ another is Rows 5. Update one or more records - affected, err := engine.Id(...).Update(&user) + affected, err := engine.ID(...).Update(&user) // UPDATE user SET ... 6. Delete one or more records, Delete MUST has condition @@ -99,6 +103,9 @@ another is Rows counts, err := engine.Count(&user) // SELECT count(*) AS total FROM user + counts, err := engine.SQL("select count(*) FROM user").Count() + // select count(*) FROM user + 8. Sum records sumFloat64, err := engine.Sum(&user, "id") diff --git a/engine.go b/engine.go index a788c117..17d16063 100644 --- a/engine.go +++ b/engine.go @@ -19,6 +19,7 @@ import ( "sync" "time" + "github.com/go-xorm/builder" "github.com/go-xorm/core" ) @@ -40,7 +41,7 @@ type Engine struct { showExecTime bool logger core.ILogger - TZLocation *time.Location + TZLocation *time.Location // The timezone of the application DatabaseTZ *time.Location // The timezone of the database disableGlobalCache bool @@ -143,7 +144,6 @@ func (engine *Engine) Quote(value string) string { // QuoteTo quotes string and writes into the buffer func (engine *Engine) QuoteTo(buf *bytes.Buffer, value string) { - if buf == nil { return } @@ -169,7 +169,7 @@ func (engine *Engine) quote(sql string) string { return engine.dialect.QuoteStr() + sql + engine.dialect.QuoteStr() } -// SqlType will be depracated, please use SQLType instead +// SqlType will be deprecated, please use SQLType instead // // Deprecated: use SQLType instead func (engine *Engine) SqlType(c *core.Column) string { @@ -205,14 +205,14 @@ func (engine *Engine) SetDefaultCacher(cacher core.Cacher) { // you can use NoCache() func (engine *Engine) NoCache() *Session { session := engine.NewSession() - session.IsAutoClose = true + session.isAutoClose = true return session.NoCache() } // NoCascade If you do not want to auto cascade load object func (engine *Engine) NoCascade() *Session { session := engine.NewSession() - session.IsAutoClose = true + session.isAutoClose = true return session.NoCascade() } @@ -245,7 +245,7 @@ func (engine *Engine) Dialect() core.Dialect { // NewSession New a session func (engine *Engine) NewSession() *Session { - session := &Session{Engine: engine} + session := &Session{engine: engine} session.Init() return session } @@ -259,7 +259,6 @@ func (engine *Engine) Close() error { func (engine *Engine) Ping() error { session := engine.NewSession() defer session.Close() - engine.logger.Infof("PING DATABASE %v", engine.DriverName()) return session.Ping() } @@ -267,43 +266,13 @@ func (engine *Engine) Ping() error { func (engine *Engine) logSQL(sqlStr string, sqlArgs ...interface{}) { if engine.showSQL && !engine.showExecTime { if len(sqlArgs) > 0 { - engine.logger.Infof("[SQL] %v %v", sqlStr, sqlArgs) + engine.logger.Infof("[SQL] %v %#v", sqlStr, sqlArgs) } else { engine.logger.Infof("[SQL] %v", sqlStr) } } } -func (engine *Engine) logSQLQueryTime(sqlStr string, args []interface{}, executionBlock func() (*core.Stmt, *core.Rows, error)) (*core.Stmt, *core.Rows, error) { - if engine.showSQL && engine.showExecTime { - b4ExecTime := time.Now() - stmt, res, err := executionBlock() - execDuration := time.Since(b4ExecTime) - if len(args) > 0 { - engine.logger.Infof("[SQL] %s %v - took: %v", sqlStr, args, execDuration) - } else { - engine.logger.Infof("[SQL] %s - took: %v", sqlStr, execDuration) - } - return stmt, res, err - } - return executionBlock() -} - -func (engine *Engine) logSQLExecutionTime(sqlStr string, args []interface{}, executionBlock func() (sql.Result, error)) (sql.Result, error) { - if engine.showSQL && engine.showExecTime { - b4ExecTime := time.Now() - res, err := executionBlock() - execDuration := time.Since(b4ExecTime) - if len(args) > 0 { - engine.logger.Infof("[sql] %s [args] %v - took: %v", sqlStr, args, execDuration) - } else { - engine.logger.Infof("[sql] %s - took: %v", sqlStr, execDuration) - } - return res, err - } - return executionBlock() -} - // Sql provides raw sql input parameter. When you have a complex SQL statement // and cannot use Where, Id, In and etc. Methods to describe, you can use SQL. // @@ -320,7 +289,7 @@ func (engine *Engine) Sql(querystring string, args ...interface{}) *Session { // This code will execute "select * from user" and set the records to users func (engine *Engine) SQL(query interface{}, args ...interface{}) *Session { session := engine.NewSession() - session.IsAutoClose = true + session.isAutoClose = true return session.SQL(query, args...) } @@ -329,14 +298,14 @@ func (engine *Engine) SQL(query interface{}, args ...interface{}) *Session { // invoked. Call NoAutoTime if you dont' want to fill automatically. func (engine *Engine) NoAutoTime() *Session { session := engine.NewSession() - session.IsAutoClose = true + session.isAutoClose = true return session.NoAutoTime() } // NoAutoCondition disable auto generate Where condition from bean or not func (engine *Engine) NoAutoCondition(no ...bool) *Session { session := engine.NewSession() - session.IsAutoClose = true + session.isAutoClose = true return session.NoAutoCondition(no...) } @@ -570,56 +539,56 @@ func (engine *Engine) tbName(v reflect.Value) string { // Cascade use cascade or not func (engine *Engine) Cascade(trueOrFalse ...bool) *Session { session := engine.NewSession() - session.IsAutoClose = true + session.isAutoClose = true return session.Cascade(trueOrFalse...) } // Where method provide a condition query func (engine *Engine) Where(query interface{}, args ...interface{}) *Session { session := engine.NewSession() - session.IsAutoClose = true + session.isAutoClose = true return session.Where(query, args...) } -// Id will be depracated, please use ID instead +// Id will be deprecated, please use ID instead func (engine *Engine) Id(id interface{}) *Session { session := engine.NewSession() - session.IsAutoClose = true + session.isAutoClose = true return session.Id(id) } // ID method provoide a condition as (id) = ? func (engine *Engine) ID(id interface{}) *Session { session := engine.NewSession() - session.IsAutoClose = true + session.isAutoClose = true return session.ID(id) } // Before apply before Processor, affected bean is passed to closure arg func (engine *Engine) Before(closures func(interface{})) *Session { session := engine.NewSession() - session.IsAutoClose = true + session.isAutoClose = true return session.Before(closures) } // After apply after insert Processor, affected bean is passed to closure arg func (engine *Engine) After(closures func(interface{})) *Session { session := engine.NewSession() - session.IsAutoClose = true + session.isAutoClose = true return session.After(closures) } // Charset set charset when create table, only support mysql now func (engine *Engine) Charset(charset string) *Session { session := engine.NewSession() - session.IsAutoClose = true + session.isAutoClose = true return session.Charset(charset) } // StoreEngine set store engine when create table, only support mysql now func (engine *Engine) StoreEngine(storeEngine string) *Session { session := engine.NewSession() - session.IsAutoClose = true + session.isAutoClose = true return session.StoreEngine(storeEngine) } @@ -628,35 +597,35 @@ func (engine *Engine) StoreEngine(storeEngine string) *Session { // but distinct will not provide id func (engine *Engine) Distinct(columns ...string) *Session { session := engine.NewSession() - session.IsAutoClose = true + session.isAutoClose = true return session.Distinct(columns...) } // Select customerize your select columns or contents func (engine *Engine) Select(str string) *Session { session := engine.NewSession() - session.IsAutoClose = true + session.isAutoClose = true return session.Select(str) } // Cols only use the parameters as select or update columns func (engine *Engine) Cols(columns ...string) *Session { session := engine.NewSession() - session.IsAutoClose = true + session.isAutoClose = true return session.Cols(columns...) } // AllCols indicates that all columns should be use func (engine *Engine) AllCols() *Session { session := engine.NewSession() - session.IsAutoClose = true + session.isAutoClose = true return session.AllCols() } // MustCols specify some columns must use even if they are empty func (engine *Engine) MustCols(columns ...string) *Session { session := engine.NewSession() - session.IsAutoClose = true + session.isAutoClose = true return session.MustCols(columns...) } @@ -667,77 +636,84 @@ func (engine *Engine) MustCols(columns ...string) *Session { // it will use parameters's columns func (engine *Engine) UseBool(columns ...string) *Session { session := engine.NewSession() - session.IsAutoClose = true + session.isAutoClose = true return session.UseBool(columns...) } // Omit only not use the parameters as select or update columns func (engine *Engine) Omit(columns ...string) *Session { session := engine.NewSession() - session.IsAutoClose = true + session.isAutoClose = true return session.Omit(columns...) } // Nullable set null when column is zero-value and nullable for update func (engine *Engine) Nullable(columns ...string) *Session { session := engine.NewSession() - session.IsAutoClose = true + session.isAutoClose = true return session.Nullable(columns...) } // In will generate "column IN (?, ?)" func (engine *Engine) In(column string, args ...interface{}) *Session { session := engine.NewSession() - session.IsAutoClose = true + session.isAutoClose = true return session.In(column, args...) } +// NotIn will generate "column NOT IN (?, ?)" +func (engine *Engine) NotIn(column string, args ...interface{}) *Session { + session := engine.NewSession() + session.isAutoClose = true + return session.NotIn(column, args...) +} + // Incr provides a update string like "column = column + ?" func (engine *Engine) Incr(column string, arg ...interface{}) *Session { session := engine.NewSession() - session.IsAutoClose = true + session.isAutoClose = true return session.Incr(column, arg...) } // Decr provides a update string like "column = column - ?" func (engine *Engine) Decr(column string, arg ...interface{}) *Session { session := engine.NewSession() - session.IsAutoClose = true + session.isAutoClose = true return session.Decr(column, arg...) } // SetExpr provides a update string like "column = {expression}" func (engine *Engine) SetExpr(column string, expression string) *Session { session := engine.NewSession() - session.IsAutoClose = true + session.isAutoClose = true return session.SetExpr(column, expression) } // Table temporarily change the Get, Find, Update's table func (engine *Engine) Table(tableNameOrBean interface{}) *Session { session := engine.NewSession() - session.IsAutoClose = true + session.isAutoClose = true return session.Table(tableNameOrBean) } // Alias set the table alias func (engine *Engine) Alias(alias string) *Session { session := engine.NewSession() - session.IsAutoClose = true + session.isAutoClose = true return session.Alias(alias) } // Limit will generate "LIMIT start, limit" func (engine *Engine) Limit(limit int, start ...int) *Session { session := engine.NewSession() - session.IsAutoClose = true + session.isAutoClose = true return session.Limit(limit, start...) } // Desc will generate "ORDER BY column1 DESC, column2 DESC" func (engine *Engine) Desc(colNames ...string) *Session { session := engine.NewSession() - session.IsAutoClose = true + session.isAutoClose = true return session.Desc(colNames...) } @@ -749,38 +725,44 @@ func (engine *Engine) Desc(colNames ...string) *Session { // func (engine *Engine) Asc(colNames ...string) *Session { session := engine.NewSession() - session.IsAutoClose = true + session.isAutoClose = true return session.Asc(colNames...) } // OrderBy will generate "ORDER BY order" func (engine *Engine) OrderBy(order string) *Session { session := engine.NewSession() - session.IsAutoClose = true + session.isAutoClose = true return session.OrderBy(order) } // Join the join_operator should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN func (engine *Engine) Join(joinOperator string, tablename interface{}, condition string, args ...interface{}) *Session { session := engine.NewSession() - session.IsAutoClose = true + session.isAutoClose = true return session.Join(joinOperator, tablename, condition, args...) } // GroupBy generate group by statement func (engine *Engine) GroupBy(keys string) *Session { session := engine.NewSession() - session.IsAutoClose = true + session.isAutoClose = true return session.GroupBy(keys) } // Having generate having statement func (engine *Engine) Having(conditions string) *Session { session := engine.NewSession() - session.IsAutoClose = true + session.isAutoClose = true return session.Having(conditions) } +func (engine *Engine) unMapType(t reflect.Type) { + engine.mutex.Lock() + defer engine.mutex.Unlock() + delete(engine.Tables, t) +} + func (engine *Engine) autoMapType(v reflect.Value) (*core.Table, error) { t := v.Type() engine.mutex.Lock() @@ -1007,6 +989,10 @@ func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) { col = core.NewColumn(engine.ColumnMapper.Obj2Table(t.Field(i).Name), t.Field(i).Name, sqlType, sqlType.DefaultLength, sqlType.DefaultLength2, true) + + if fieldType.Kind() == reflect.Int64 && (strings.ToUpper(col.FieldName) == "ID" || strings.HasSuffix(strings.ToUpper(col.FieldName), ".ID")) { + idFieldColName = col.Name + } } if col.IsAutoIncrement { col.Nullable = false @@ -1014,9 +1000,6 @@ func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) { table.AddColumn(col) - if fieldType.Kind() == reflect.Int64 && (strings.ToUpper(col.FieldName) == "ID" || strings.HasSuffix(strings.ToUpper(col.FieldName), ".ID")) { - idFieldColName = col.Name - } } // end for if idFieldColName != "" && len(table.PrimaryKeys) == 0 { @@ -1097,19 +1080,39 @@ func (engine *Engine) idOfV(rv reflect.Value) (core.PK, error) { pk := make([]interface{}, len(table.PrimaryKeys)) for i, col := range table.PKColumns() { + var err error pkField := v.FieldByName(col.FieldName) switch pkField.Kind() { case reflect.String: - pk[i] = pkField.String() + pk[i], err = engine.idTypeAssertion(col, pkField.String()) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - pk[i] = pkField.Int() + pk[i], err = engine.idTypeAssertion(col, strconv.FormatInt(pkField.Int(), 10)) case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - pk[i] = pkField.Uint() + // id of uint will be converted to int64 + pk[i], err = engine.idTypeAssertion(col, strconv.FormatUint(pkField.Uint(), 10)) + } + + if err != nil { + return nil, err } } return core.PK(pk), nil } +func (engine *Engine) idTypeAssertion(col *core.Column, sid string) (interface{}, error) { + if col.SQLType.IsNumeric() { + n, err := strconv.ParseInt(sid, 10, 64) + if err != nil { + return nil, err + } + return n, nil + } else if col.SQLType.IsText() { + return sid, nil + } else { + return nil, errors.New("not supported") + } +} + // CreateIndexes create indexes func (engine *Engine) CreateIndexes(bean interface{}) error { session := engine.NewSession() @@ -1181,6 +1184,9 @@ func (engine *Engine) ClearCache(beans ...interface{}) error { // table, column, index, unique. but will not delete or change anything. // If you change some field, you should change the database manually. func (engine *Engine) Sync(beans ...interface{}) error { + session := engine.NewSession() + defer session.Close() + for _, bean := range beans { v := rValue(bean) tableName := engine.tbName(v) @@ -1189,14 +1195,12 @@ func (engine *Engine) Sync(beans ...interface{}) error { return err } - s := engine.NewSession() - defer s.Close() - isExist, err := s.Table(bean).isTableExist(tableName) + isExist, err := session.Table(bean).isTableExist(tableName) if err != nil { return err } if !isExist { - err = engine.CreateTables(bean) + err = session.createTable(bean) if err != nil { return err } @@ -1207,11 +1211,11 @@ func (engine *Engine) Sync(beans ...interface{}) error { }*/ var isEmpty bool if isEmpty { - err = engine.DropTables(bean) + err = session.dropTable(bean) if err != nil { return err } - err = engine.CreateTables(bean) + err = session.createTable(bean) if err != nil { return err } @@ -1222,9 +1226,7 @@ func (engine *Engine) Sync(beans ...interface{}) error { return err } if !isExist { - session := engine.NewSession() - defer session.Close() - if err := session.Statement.setRefValue(v); err != nil { + if err := session.statement.setRefValue(v); err != nil { return err } err = session.addColumn(col.Name) @@ -1235,21 +1237,16 @@ func (engine *Engine) Sync(beans ...interface{}) error { } for name, index := range table.Indexes { - session := engine.NewSession() - defer session.Close() - if err := session.Statement.setRefValue(v); err != nil { + if err := session.statement.setRefValue(v); err != nil { return err } if index.Type == core.UniqueType { - //isExist, err := session.isIndexExist(table.Name, name, true) isExist, err := session.isIndexExist2(tableName, index.Cols, true) if err != nil { return err } if !isExist { - session := engine.NewSession() - defer session.Close() - if err := session.Statement.setRefValue(v); err != nil { + if err := session.statement.setRefValue(v); err != nil { return err } @@ -1264,9 +1261,7 @@ func (engine *Engine) Sync(beans ...interface{}) error { return err } if !isExist { - session := engine.NewSession() - defer session.Close() - if err := session.Statement.setRefValue(v); err != nil { + if err := session.statement.setRefValue(v); err != nil { return err } @@ -1291,23 +1286,6 @@ func (engine *Engine) Sync2(beans ...interface{}) error { return s.Sync2(beans...) } -// Drop all mapped table -func (engine *Engine) dropAll() error { - session := engine.NewSession() - defer session.Close() - - err := session.Begin() - if err != nil { - return err - } - err = session.dropAll() - if err != nil { - session.Rollback() - return err - } - return session.Commit() -} - // CreateTables create tabls according bean func (engine *Engine) CreateTables(beans ...interface{}) error { session := engine.NewSession() @@ -1319,7 +1297,7 @@ func (engine *Engine) CreateTables(beans ...interface{}) error { } for _, bean := range beans { - err = session.CreateTable(bean) + err = session.createTable(bean) if err != nil { session.Rollback() return err @@ -1339,7 +1317,7 @@ func (engine *Engine) DropTables(beans ...interface{}) error { } for _, bean := range beans { - err = session.DropTable(bean) + err = session.dropTable(bean) if err != nil { session.Rollback() return err @@ -1348,10 +1326,11 @@ func (engine *Engine) DropTables(beans ...interface{}) error { return session.Commit() } -func (engine *Engine) createAll() error { +// DropIndexes drop indexes of a table +func (engine *Engine) DropIndexes(bean interface{}) error { session := engine.NewSession() defer session.Close() - return session.createAll() + return session.DropIndexes(bean) } // Exec raw sql @@ -1375,6 +1354,13 @@ func (engine *Engine) QueryString(sqlStr string, args ...interface{}) ([]map[str return session.QueryString(sqlStr, args...) } +// QueryInterface runs a raw sql and return records as []map[string]interface{} +func (engine *Engine) QueryInterface(sqlStr string, args ...interface{}) ([]map[string]interface{}, error) { + session := engine.NewSession() + defer session.Close() + return session.QueryInterface(sqlStr, args...) +} + // Insert one or more records func (engine *Engine) Insert(beans ...interface{}) (int64, error) { session := engine.NewSession() @@ -1416,6 +1402,13 @@ func (engine *Engine) Get(bean interface{}) (bool, error) { return session.Get(bean) } +// Exist returns true if the record exist otherwise return false +func (engine *Engine) Exist(bean ...interface{}) (bool, error) { + session := engine.NewSession() + defer session.Close() + return session.Exist(bean...) +} + // Find retrieve records from table, condiBeans's non-empty fields // are conditions. beans could be []Struct, []*Struct, map[int64]Struct // map[int64]*Struct @@ -1441,10 +1434,10 @@ func (engine *Engine) Rows(bean interface{}) (*Rows, error) { } // Count counts the records. bean's non-empty fields are conditions. -func (engine *Engine) Count(bean interface{}) (int64, error) { +func (engine *Engine) Count(bean ...interface{}) (int64, error) { session := engine.NewSession() defer session.Close() - return session.Count(bean) + return session.Count(bean...) } // Sum sum the records by some column. bean's non-empty fields are conditions. @@ -1454,6 +1447,13 @@ func (engine *Engine) Sum(bean interface{}, colName string) (float64, error) { return session.Sum(bean, colName) } +// SumInt sum the records by some column. bean's non-empty fields are conditions. +func (engine *Engine) SumInt(bean interface{}, colName string) (int64, error) { + session := engine.NewSession() + defer session.Close() + return session.SumInt(bean, colName) +} + // Sums sum the records by some columns. bean's non-empty fields are conditions. func (engine *Engine) Sums(bean interface{}, colNames ...string) ([]float64, error) { session := engine.NewSession() @@ -1509,7 +1509,6 @@ func (engine *Engine) Import(r io.Reader) ([]sql.Result, error) { results = append(results, result) if err != nil { return nil, err - //lastError = err } } } @@ -1517,49 +1516,28 @@ func (engine *Engine) Import(r io.Reader) ([]sql.Result, error) { return results, lastError } -// TZTime change one time to xorm time location -func (engine *Engine) TZTime(t time.Time) time.Time { - if !t.IsZero() { // if time is not initialized it's not suitable for Time.In() - return t.In(engine.TZLocation) - } - return t -} - -// NowTime return current time -func (engine *Engine) NowTime(sqlTypeName string) interface{} { - t := time.Now() - return engine.FormatTime(sqlTypeName, t) -} - // NowTime2 return current time func (engine *Engine) NowTime2(sqlTypeName string) (interface{}, time.Time) { t := time.Now() - return engine.FormatTime(sqlTypeName, t), t -} - -// FormatTime format time -func (engine *Engine) FormatTime(sqlTypeName string, t time.Time) (v interface{}) { - return engine.formatTime(engine.TZLocation, sqlTypeName, t) + return engine.formatTime(sqlTypeName, t.In(engine.DatabaseTZ)), t.In(engine.TZLocation) } func (engine *Engine) formatColTime(col *core.Column, t time.Time) (v interface{}) { - if col.DisableTimeZone { - return engine.formatTime(nil, col.SQLType.Name, t) - } else if col.TimeZone != nil { - return engine.formatTime(col.TimeZone, col.SQLType.Name, t) + if t.IsZero() { + if col.Nullable { + return nil + } + return "" } - return engine.formatTime(engine.TZLocation, col.SQLType.Name, t) + + if col.TimeZone != nil { + return engine.formatTime(col.SQLType.Name, t.In(col.TimeZone)) + } + return engine.formatTime(col.SQLType.Name, t.In(engine.DatabaseTZ)) } -func (engine *Engine) formatTime(tz *time.Location, sqlTypeName string, t time.Time) (v interface{}) { - if engine.dialect.DBType() == core.ORACLE { - return t - } - if tz != nil { - t = t.In(tz) - } else { - t = engine.TZTime(t) - } +// formatTime format time as column type +func (engine *Engine) formatTime(sqlTypeName string, t time.Time) (v interface{}) { switch sqlTypeName { case core.Time: s := t.Format("2006-01-02 15:04:05") //time.RFC3339 @@ -1567,18 +1545,10 @@ func (engine *Engine) formatTime(tz *time.Location, sqlTypeName string, t time.T case core.Date: v = t.Format("2006-01-02") case core.DateTime, core.TimeStamp: - if engine.dialect.DBType() == "ql" { - v = t - } else if engine.dialect.DBType() == "sqlite3" { - v = t.UTC().Format("2006-01-02 15:04:05") - } else { - v = t.Format("2006-01-02 15:04:05") - } + v = t.Format("2006-01-02 15:04:05") case core.TimeStampz: if engine.dialect.DBType() == core.MSSQL { v = t.Format("2006-01-02T15:04:05.9999999Z07:00") - } else if engine.DriverName() == "mssql" { - v = t } else { v = t.Format(time.RFC3339Nano) } @@ -1593,6 +1563,21 @@ func (engine *Engine) formatTime(tz *time.Location, sqlTypeName string, t time.T // Unscoped always disable struct tag "deleted" func (engine *Engine) Unscoped() *Session { session := engine.NewSession() - session.IsAutoClose = true + session.isAutoClose = true return session.Unscoped() } + +// CondDeleted returns the conditions whether a record is soft deleted. +func (engine *Engine) CondDeleted(colName string) builder.Cond { + if engine.dialect.DBType() == core.MSSQL { + return builder.IsNull{colName} + } + return builder.IsNull{colName}.Or(builder.Eq{colName: zeroTime1}) +} + +// BufferSize sets buffer size for iterate +func (engine *Engine) BufferSize(size int) *Session { + session := engine.NewSession() + session.isAutoClose = true + return session.BufferSize(size) +} diff --git a/engine_cond.go b/engine_cond.go new file mode 100644 index 00000000..6c8e3879 --- /dev/null +++ b/engine_cond.go @@ -0,0 +1,230 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import ( + "database/sql/driver" + "encoding/json" + "fmt" + "reflect" + "time" + + "github.com/go-xorm/builder" + "github.com/go-xorm/core" +) + +func (engine *Engine) buildConds(table *core.Table, bean interface{}, + includeVersion bool, includeUpdated bool, includeNil bool, + includeAutoIncr bool, allUseBool bool, useAllCols bool, unscoped bool, + mustColumnMap map[string]bool, tableName, aliasName string, addedTableName bool) (builder.Cond, error) { + var conds []builder.Cond + for _, col := range table.Columns() { + if !includeVersion && col.IsVersion { + continue + } + if !includeUpdated && col.IsUpdated { + continue + } + if !includeAutoIncr && col.IsAutoIncrement { + continue + } + + if engine.dialect.DBType() == core.MSSQL && (col.SQLType.Name == core.Text || col.SQLType.IsBlob() || col.SQLType.Name == core.TimeStampz) { + continue + } + if col.SQLType.IsJson() { + continue + } + + var colName string + if addedTableName { + var nm = tableName + if len(aliasName) > 0 { + nm = aliasName + } + colName = engine.Quote(nm) + "." + engine.Quote(col.Name) + } else { + colName = engine.Quote(col.Name) + } + + fieldValuePtr, err := col.ValueOf(bean) + if err != nil { + engine.logger.Error(err) + continue + } + + if col.IsDeleted && !unscoped { // tag "deleted" is enabled + conds = append(conds, engine.CondDeleted(colName)) + } + + fieldValue := *fieldValuePtr + if fieldValue.Interface() == nil { + continue + } + + fieldType := reflect.TypeOf(fieldValue.Interface()) + requiredField := useAllCols + + if b, ok := getFlagForColumn(mustColumnMap, col); ok { + if b { + requiredField = true + } else { + continue + } + } + + if fieldType.Kind() == reflect.Ptr { + if fieldValue.IsNil() { + if includeNil { + conds = append(conds, builder.Eq{colName: nil}) + } + continue + } else if !fieldValue.IsValid() { + continue + } else { + // dereference ptr type to instance type + fieldValue = fieldValue.Elem() + fieldType = reflect.TypeOf(fieldValue.Interface()) + requiredField = true + } + } + + var val interface{} + switch fieldType.Kind() { + case reflect.Bool: + if allUseBool || requiredField { + val = fieldValue.Interface() + } else { + // if a bool in a struct, it will not be as a condition because it default is false, + // please use Where() instead + continue + } + case reflect.String: + if !requiredField && fieldValue.String() == "" { + continue + } + // for MyString, should convert to string or panic + if fieldType.String() != reflect.String.String() { + val = fieldValue.String() + } else { + val = fieldValue.Interface() + } + case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64: + if !requiredField && fieldValue.Int() == 0 { + continue + } + val = fieldValue.Interface() + case reflect.Float32, reflect.Float64: + if !requiredField && fieldValue.Float() == 0.0 { + continue + } + val = fieldValue.Interface() + case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64: + if !requiredField && fieldValue.Uint() == 0 { + continue + } + t := int64(fieldValue.Uint()) + val = reflect.ValueOf(&t).Interface() + case reflect.Struct: + if fieldType.ConvertibleTo(core.TimeType) { + t := fieldValue.Convert(core.TimeType).Interface().(time.Time) + if !requiredField && (t.IsZero() || !fieldValue.IsValid()) { + continue + } + val = engine.formatColTime(col, t) + } else if _, ok := reflect.New(fieldType).Interface().(core.Conversion); ok { + continue + } else if valNul, ok := fieldValue.Interface().(driver.Valuer); ok { + val, _ = valNul.Value() + if val == nil { + continue + } + } else { + if col.SQLType.IsJson() { + if col.SQLType.IsText() { + bytes, err := json.Marshal(fieldValue.Interface()) + if err != nil { + engine.logger.Error(err) + continue + } + val = string(bytes) + } else if col.SQLType.IsBlob() { + var bytes []byte + var err error + bytes, err = json.Marshal(fieldValue.Interface()) + if err != nil { + engine.logger.Error(err) + continue + } + val = bytes + } + } else { + engine.autoMapType(fieldValue) + if table, ok := engine.Tables[fieldValue.Type()]; ok { + if len(table.PrimaryKeys) == 1 { + pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName) + // fix non-int pk issues + //if pkField.Int() != 0 { + if pkField.IsValid() && !isZero(pkField.Interface()) { + val = pkField.Interface() + } else { + continue + } + } else { + //TODO: how to handler? + return nil, fmt.Errorf("not supported %v as %v", fieldValue.Interface(), table.PrimaryKeys) + } + } else { + val = fieldValue.Interface() + } + } + } + case reflect.Array: + continue + case reflect.Slice, reflect.Map: + if fieldValue == reflect.Zero(fieldType) { + continue + } + if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 { + continue + } + + if col.SQLType.IsText() { + bytes, err := json.Marshal(fieldValue.Interface()) + if err != nil { + engine.logger.Error(err) + continue + } + val = string(bytes) + } else if col.SQLType.IsBlob() { + var bytes []byte + var err error + if (fieldType.Kind() == reflect.Array || fieldType.Kind() == reflect.Slice) && + fieldType.Elem().Kind() == reflect.Uint8 { + if fieldValue.Len() > 0 { + val = fieldValue.Bytes() + } else { + continue + } + } else { + bytes, err = json.Marshal(fieldValue.Interface()) + if err != nil { + engine.logger.Error(err) + continue + } + val = bytes + } + } else { + continue + } + default: + val = fieldValue.Interface() + } + + conds = append(conds, builder.Eq{colName: val}) + } + + return builder.And(conds...), nil +} diff --git a/engine_maxlife.go b/engine_maxlife.go new file mode 100644 index 00000000..21daeaa1 --- /dev/null +++ b/engine_maxlife.go @@ -0,0 +1,14 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build go1.6 + +package xorm + +import "time" + +// SetConnMaxLifetime sets the maximum amount of time a connection may be reused. +func (engine *Engine) SetConnMaxLifetime(d time.Duration) { + engine.db.SetConnMaxLifetime(d) +} diff --git a/examples/cache.go b/examples/cache.go index 72d987df..0c680d23 100644 --- a/examples/cache.go +++ b/examples/cache.go @@ -67,7 +67,7 @@ func main() { fmt.Println("users3:", users3) user4 := new(User) - has, err := Orm.Id(1).Get(user4) + has, err := Orm.ID(1).Get(user4) if err != nil { fmt.Println(err) return @@ -76,7 +76,7 @@ func main() { fmt.Println("user4:", has, user4) user4.Name = "xiaolunwen" - _, err = Orm.Id(1).Update(user4) + _, err = Orm.ID(1).Update(user4) if err != nil { fmt.Println(err) return @@ -84,14 +84,14 @@ func main() { fmt.Println("user4:", user4) user5 := new(User) - has, err = Orm.Id(1).Get(user5) + has, err = Orm.ID(1).Get(user5) if err != nil { fmt.Println(err) return } fmt.Println("user5:", has, user5) - _, err = Orm.Id(1).Delete(new(User)) + _, err = Orm.ID(1).Delete(new(User)) if err != nil { fmt.Println(err) return @@ -99,7 +99,7 @@ func main() { for { user6 := new(User) - has, err = Orm.Id(1).Get(user6) + has, err = Orm.ID(1).Get(user6) if err != nil { fmt.Println(err) return diff --git a/examples/cachegoroutine.go b/examples/cachegoroutine.go index 815e0ad1..e5b2b9d7 100644 --- a/examples/cachegoroutine.go +++ b/examples/cachegoroutine.go @@ -55,7 +55,7 @@ func test(engine *xorm.Engine) { } else if x+j < 16 { _, err = engine.Insert(&User{Name: "xlw"}) } else if x+j < 32 { - //_, err = engine.Id(1).Delete(u) + //_, err = engine.ID(1).Delete(u) _, err = engine.Delete(u) } if err != nil { diff --git a/examples/derive.go b/examples/derive.go index 86529eea..90561514 100644 --- a/examples/derive.go +++ b/examples/derive.go @@ -51,7 +51,7 @@ func main() { } info := LoginInfo{} - _, err = orm.Id(1).Get(&info) + _, err = orm.ID(1).Get(&info) if err != nil { fmt.Println(err) return diff --git a/examples/goroutine.go b/examples/goroutine.go index 629ea9ac..59e56d10 100644 --- a/examples/goroutine.go +++ b/examples/goroutine.go @@ -59,7 +59,7 @@ func test(engine *xorm.Engine) { } else if x+j < 16 { _, err = engine.Insert(&User{Name: "xlw"}) } else if x+j < 32 { - _, err = engine.Id(1).Delete(u) + _, err = engine.ID(1).Delete(u) } if err != nil { fmt.Println(err) diff --git a/examples/maxconnect.go b/examples/maxconnect.go index 507cbc3c..72d0a503 100644 --- a/examples/maxconnect.go +++ b/examples/maxconnect.go @@ -62,7 +62,7 @@ func test(engine *xorm.Engine) { } else if x+j < 16 { _, err = engine.Insert(&User{Name: "xlw"}) } else if x+j < 32 { - _, err = engine.Id(1).Delete(u) + _, err = engine.ID(1).Delete(u) } if err != nil { fmt.Println(err) diff --git a/examples/singlemapping.go b/examples/singlemapping.go index 3ae0fd1a..86e5d1d7 100644 --- a/examples/singlemapping.go +++ b/examples/singlemapping.go @@ -48,7 +48,7 @@ func main() { } info := LoginInfo{} - _, err = orm.Id(1).Get(&info) + _, err = orm.ID(1).Get(&info) if err != nil { fmt.Println(err) return diff --git a/helpers.go b/helpers.go index 324c5bea..5a0fe7c8 100644 --- a/helpers.go +++ b/helpers.go @@ -196,25 +196,43 @@ func isArrayValueZero(v reflect.Value) bool { func int64ToIntValue(id int64, tp reflect.Type) reflect.Value { var v interface{} - switch tp.Kind() { - case reflect.Int16: - v = int16(id) - case reflect.Int32: - v = int32(id) - case reflect.Int: - v = int(id) - case reflect.Int64: - v = id - case reflect.Uint16: - v = uint16(id) - case reflect.Uint32: - v = uint32(id) - case reflect.Uint64: - v = uint64(id) - case reflect.Uint: - v = uint(id) + kind := tp.Kind() + + if kind == reflect.Ptr { + kind = tp.Elem().Kind() } - return reflect.ValueOf(v).Convert(tp) + + switch kind { + case reflect.Int16: + temp := int16(id) + v = &temp + case reflect.Int32: + temp := int32(id) + v = &temp + case reflect.Int: + temp := int(id) + v = &temp + case reflect.Int64: + temp := id + v = &temp + case reflect.Uint16: + temp := uint16(id) + v = &temp + case reflect.Uint32: + temp := uint32(id) + v = &temp + case reflect.Uint64: + temp := uint64(id) + v = &temp + case reflect.Uint: + temp := uint(id) + v = &temp + } + + if tp.Kind() == reflect.Ptr { + return reflect.ValueOf(v).Convert(tp) + } + return reflect.ValueOf(v).Elem().Convert(tp) } func int64ToInt(id int64, tp reflect.Type) interface{} { @@ -302,175 +320,6 @@ func sliceEq(left, right []string) bool { return true } -func reflect2value(rawValue *reflect.Value) (str string, err error) { - aa := reflect.TypeOf((*rawValue).Interface()) - vv := reflect.ValueOf((*rawValue).Interface()) - switch aa.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - str = strconv.FormatInt(vv.Int(), 10) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - str = strconv.FormatUint(vv.Uint(), 10) - case reflect.Float32, reflect.Float64: - str = strconv.FormatFloat(vv.Float(), 'f', -1, 64) - case reflect.String: - str = vv.String() - case reflect.Array, reflect.Slice: - switch aa.Elem().Kind() { - case reflect.Uint8: - data := rawValue.Interface().([]byte) - str = string(data) - default: - err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name()) - } - // time type - case reflect.Struct: - if aa.ConvertibleTo(core.TimeType) { - str = vv.Convert(core.TimeType).Interface().(time.Time).Format(time.RFC3339Nano) - } else { - err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name()) - } - case reflect.Bool: - str = strconv.FormatBool(vv.Bool()) - case reflect.Complex128, reflect.Complex64: - str = fmt.Sprintf("%v", vv.Complex()) - /* TODO: unsupported types below - case reflect.Map: - case reflect.Ptr: - case reflect.Uintptr: - case reflect.UnsafePointer: - case reflect.Chan, reflect.Func, reflect.Interface: - */ - default: - err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name()) - } - return -} - -func value2Bytes(rawValue *reflect.Value) (data []byte, err error) { - var str string - str, err = reflect2value(rawValue) - if err != nil { - return - } - data = []byte(str) - return -} - -func value2String(rawValue *reflect.Value) (data string, err error) { - data, err = reflect2value(rawValue) - if err != nil { - return - } - return -} - -func rows2Strings(rows *core.Rows) (resultsSlice []map[string]string, err error) { - fields, err := rows.Columns() - if err != nil { - return nil, err - } - for rows.Next() { - result, err := row2mapStr(rows, fields) - if err != nil { - return nil, err - } - resultsSlice = append(resultsSlice, result) - } - - return resultsSlice, nil -} - -func rows2maps(rows *core.Rows) (resultsSlice []map[string][]byte, err error) { - fields, err := rows.Columns() - if err != nil { - return nil, err - } - for rows.Next() { - result, err := row2map(rows, fields) - if err != nil { - return nil, err - } - resultsSlice = append(resultsSlice, result) - } - - return resultsSlice, nil -} - -func row2map(rows *core.Rows, fields []string) (resultsMap map[string][]byte, err error) { - result := make(map[string][]byte) - scanResultContainers := make([]interface{}, len(fields)) - for i := 0; i < len(fields); i++ { - var scanResultContainer interface{} - scanResultContainers[i] = &scanResultContainer - } - if err := rows.Scan(scanResultContainers...); err != nil { - return nil, err - } - - for ii, key := range fields { - rawValue := reflect.Indirect(reflect.ValueOf(scanResultContainers[ii])) - //if row is null then ignore - if rawValue.Interface() == nil { - //fmt.Println("ignore ...", key, rawValue) - continue - } - - if data, err := value2Bytes(&rawValue); err == nil { - result[key] = data - } else { - return nil, err // !nashtsai! REVIEW, should return err or just error log? - } - } - return result, nil -} - -func row2mapStr(rows *core.Rows, fields []string) (resultsMap map[string]string, err error) { - result := make(map[string]string) - scanResultContainers := make([]interface{}, len(fields)) - for i := 0; i < len(fields); i++ { - var scanResultContainer interface{} - scanResultContainers[i] = &scanResultContainer - } - if err := rows.Scan(scanResultContainers...); err != nil { - return nil, err - } - - for ii, key := range fields { - rawValue := reflect.Indirect(reflect.ValueOf(scanResultContainers[ii])) - //if row is null then ignore - if rawValue.Interface() == nil { - //fmt.Println("ignore ...", key, rawValue) - continue - } - - if data, err := value2String(&rawValue); err == nil { - result[key] = data - } else { - return nil, err // !nashtsai! REVIEW, should return err or just error log? - } - } - return result, nil -} - -func txQuery2(tx *core.Tx, sqlStr string, params ...interface{}) ([]map[string]string, error) { - rows, err := tx.Query(sqlStr, params...) - if err != nil { - return nil, err - } - defer rows.Close() - - return rows2Strings(rows) -} - -func query2(db *core.DB, sqlStr string, params ...interface{}) ([]map[string]string, error) { - rows, err := db.Query(sqlStr, params...) - if err != nil { - return nil, err - } - defer rows.Close() - return rows2Strings(rows) -} - func setColumnInt(bean interface{}, col *core.Column, t int64) { v, err := col.ValueOf(bean) if err != nil { @@ -509,7 +358,7 @@ func genCols(table *core.Table, session *Session, bean interface{}, useCol bool, for _, col := range table.Columns() { if useCol && !col.IsVersion && !col.IsCreated && !col.IsUpdated { - if _, ok := getFlagForColumn(session.Statement.columnMap, col); !ok { + if _, ok := getFlagForColumn(session.statement.columnMap, col); !ok { continue } } @@ -537,6 +386,10 @@ func genCols(table *core.Table, session *Session, bean interface{}, useCol bool, if len(fieldValue.String()) == 0 { continue } + case reflect.Ptr: + if fieldValue.Pointer() == 0 { + continue + } } } @@ -544,28 +397,32 @@ func genCols(table *core.Table, session *Session, bean interface{}, useCol bool, continue } - if session.Statement.ColumnStr != "" { - if _, ok := getFlagForColumn(session.Statement.columnMap, col); !ok { + if session.statement.ColumnStr != "" { + if _, ok := getFlagForColumn(session.statement.columnMap, col); !ok { + continue + } else if _, ok := session.statement.incrColumns[col.Name]; ok { + continue + } else if _, ok := session.statement.decrColumns[col.Name]; ok { continue } } - if session.Statement.OmitStr != "" { - if _, ok := getFlagForColumn(session.Statement.columnMap, col); ok { + if session.statement.OmitStr != "" { + if _, ok := getFlagForColumn(session.statement.columnMap, col); ok { continue } } // !evalphobia! set fieldValue as nil when column is nullable and zero-value - if _, ok := getFlagForColumn(session.Statement.nullableMap, col); ok { + if _, ok := getFlagForColumn(session.statement.nullableMap, col); ok { if col.Nullable && isZero(fieldValue.Interface()) { var nilValue *int fieldValue = reflect.ValueOf(nilValue) } } - if (col.IsCreated || col.IsUpdated) && session.Statement.UseAutoTime /*&& isZero(fieldValue.Interface())*/ { + if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime /*&& isZero(fieldValue.Interface())*/ { // if time is non-empty, then set to auto time - val, t := session.Engine.NowTime2(col.SQLType.Name) + val, t := session.engine.NowTime2(col.SQLType.Name) args = append(args, val) var colName = col.Name @@ -573,7 +430,7 @@ func genCols(table *core.Table, session *Session, bean interface{}, useCol bool, col := table.GetColumn(colName) setColumnTime(bean, col, t) }) - } else if col.IsVersion && session.Statement.checkVersion { + } else if col.IsVersion && session.statement.checkVersion { args = append(args, 1) } else { arg, err := session.value2Interface(col, fieldValue) @@ -584,7 +441,7 @@ func genCols(table *core.Table, session *Session, bean interface{}, useCol bool, } if includeQuote { - colNames = append(colNames, session.Engine.Quote(col.Name)+" = ?") + colNames = append(colNames, session.engine.Quote(col.Name)+" = ?") } else { colNames = append(colNames, col.Name) } diff --git a/helpers_test.go b/helpers_test.go index 7d17383d..d57c54ae 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -1,3 +1,7 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + package xorm import "testing" diff --git a/helpler_time.go b/helpler_time.go new file mode 100644 index 00000000..f4013e27 --- /dev/null +++ b/helpler_time.go @@ -0,0 +1,21 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import "time" + +const ( + zeroTime0 = "0000-00-00 00:00:00" + zeroTime1 = "0001-01-01 00:00:00" +) + +func formatTime(t time.Time) string { + return t.Format("2006-01-02 15:04:05") +} + +func isTimeZero(t time.Time) bool { + return t.IsZero() || formatTime(t) == zeroTime0 || + formatTime(t) == zeroTime1 +} diff --git a/processors_test.go b/processors_test.go index e370d7a0..4ee59066 100644 --- a/processors_test.go +++ b/processors_test.go @@ -5,6 +5,8 @@ package xorm import ( + "errors" + "fmt" "testing" "github.com/stretchr/testify/assert" @@ -40,24 +42,24 @@ func TestBefore_Get(t *testing.T) { func TestBefore_Find(t *testing.T) { assert.NoError(t, prepareEngine()) - type BeforeTable struct { + type BeforeTable2 struct { Id int64 Name string Val string `xorm:"-"` } - assert.NoError(t, testEngine.Sync2(new(BeforeTable))) + assert.NoError(t, testEngine.Sync2(new(BeforeTable2))) - cnt, err := testEngine.Insert([]BeforeTable{ + cnt, err := testEngine.Insert([]BeforeTable2{ {Name: "test1"}, {Name: "test2"}, }) assert.NoError(t, err) assert.EqualValues(t, 2, cnt) - var be []BeforeTable + var be []BeforeTable2 err = testEngine.Before(func(bean interface{}) { - bean.(*BeforeTable).Val = "val" + bean.(*BeforeTable2).Val = "val" }).Find(&be) assert.NoError(t, err) assert.Equal(t, 2, len(be)) @@ -66,3 +68,899 @@ func TestBefore_Find(t *testing.T) { assert.Equal(t, "val", be[1].Val) assert.Equal(t, "test2", be[1].Name) } + +type ProcessorsStruct struct { + Id int64 + + B4InsertFlag int + AfterInsertedFlag int + B4UpdateFlag int + AfterUpdatedFlag int + B4DeleteFlag int `xorm:"-"` + AfterDeletedFlag int `xorm:"-"` + BeforeSetFlag int `xorm:"-"` + + B4InsertViaExt int + AfterInsertedViaExt int + B4UpdateViaExt int + AfterUpdatedViaExt int + B4DeleteViaExt int `xorm:"-"` + AfterDeletedViaExt int `xorm:"-"` + AfterSetFlag int `xorm:"-"` +} + +func (p *ProcessorsStruct) BeforeInsert() { + p.B4InsertFlag = 1 +} + +func (p *ProcessorsStruct) BeforeUpdate() { + p.B4UpdateFlag = 1 +} + +func (p *ProcessorsStruct) BeforeDelete() { + p.B4DeleteFlag = 1 +} + +func (p *ProcessorsStruct) BeforeSet(col string, cell Cell) { + p.BeforeSetFlag = p.BeforeSetFlag + 1 +} + +func (p *ProcessorsStruct) AfterInsert() { + p.AfterInsertedFlag = 1 +} + +func (p *ProcessorsStruct) AfterUpdate() { + p.AfterUpdatedFlag = 1 +} + +func (p *ProcessorsStruct) AfterDelete() { + p.AfterDeletedFlag = 1 +} + +func (p *ProcessorsStruct) AfterSet(col string, cell Cell) { + p.AfterSetFlag = p.AfterSetFlag + 1 +} + +func TestProcessors(t *testing.T) { + assert.NoError(t, prepareEngine()) + + err := testEngine.DropTables(&ProcessorsStruct{}) + if err != nil { + t.Error(err) + panic(err) + } + p := &ProcessorsStruct{} + + err = testEngine.CreateTables(&ProcessorsStruct{}) + if err != nil { + t.Error(err) + panic(err) + } + + b4InsertFunc := func(bean interface{}) { + if v, ok := (bean).(*ProcessorsStruct); ok { + v.B4InsertViaExt = 1 + } else { + t.Error(errors.New("cast to ProcessorsStruct failed, how can this be!?")) + } + } + + afterInsertFunc := func(bean interface{}) { + if v, ok := (bean).(*ProcessorsStruct); ok { + v.AfterInsertedViaExt = 1 + } else { + t.Error(errors.New("cast to ProcessorsStruct failed, how can this be!?")) + } + } + + _, err = testEngine.Before(b4InsertFunc).After(afterInsertFunc).Insert(p) + if err != nil { + t.Error(err) + panic(err) + } else { + if p.B4InsertFlag == 0 { + t.Error(errors.New("B4InsertFlag not set")) + } + if p.AfterInsertedFlag == 0 { + t.Error(errors.New("B4InsertFlag not set")) + } + if p.B4InsertViaExt == 0 { + t.Error(errors.New("B4InsertFlag not set")) + } + if p.AfterInsertedViaExt == 0 { + t.Error(errors.New("AfterInsertedViaExt not set")) + } + } + + p2 := &ProcessorsStruct{} + _, err = testEngine.ID(p.Id).Get(p2) + if err != nil { + t.Error(err) + panic(err) + } else { + if p2.B4InsertFlag == 0 { + t.Error(errors.New("B4InsertFlag not set")) + } + if p2.AfterInsertedFlag != 0 { + t.Error(errors.New("AfterInsertedFlag is set")) + } + if p2.B4InsertViaExt == 0 { + t.Error(errors.New("B4InsertViaExt not set")) + } + if p2.AfterInsertedViaExt != 0 { + t.Error(errors.New("AfterInsertedViaExt is set")) + } + if p2.BeforeSetFlag != 9 { + t.Error(fmt.Errorf("BeforeSetFlag is %d not 9", p2.BeforeSetFlag)) + } + if p2.AfterSetFlag != 9 { + t.Error(fmt.Errorf("AfterSetFlag is %d not 9", p2.BeforeSetFlag)) + } + } + // -- + + // test find processors + var p2Find []*ProcessorsStruct + err = testEngine.Find(&p2Find) + if err != nil { + t.Error(err) + panic(err) + } else { + if len(p2Find) != 1 { + err = errors.New("Should get 1") + t.Error(err) + } + p21 := p2Find[0] + if p21.B4InsertFlag == 0 { + t.Error(errors.New("B4InsertFlag not set")) + } + if p21.AfterInsertedFlag != 0 { + t.Error(errors.New("AfterInsertedFlag is set")) + } + if p21.B4InsertViaExt == 0 { + t.Error(errors.New("B4InsertViaExt not set")) + } + if p21.AfterInsertedViaExt != 0 { + t.Error(errors.New("AfterInsertedViaExt is set")) + } + if p21.BeforeSetFlag != 9 { + t.Error(fmt.Errorf("BeforeSetFlag is %d not 9", p21.BeforeSetFlag)) + } + if p21.AfterSetFlag != 9 { + t.Error(fmt.Errorf("AfterSetFlag is %d not 9", p21.BeforeSetFlag)) + } + } + // -- + + // test find map processors + var p2FindMap = make(map[int64]*ProcessorsStruct) + err = testEngine.Find(&p2FindMap) + if err != nil { + t.Error(err) + panic(err) + } else { + if len(p2FindMap) != 1 { + err = errors.New("Should get 1") + t.Error(err) + } + var p22 *ProcessorsStruct + for _, v := range p2FindMap { + p22 = v + } + + if p22.B4InsertFlag == 0 { + t.Error(errors.New("B4InsertFlag not set")) + } + if p22.AfterInsertedFlag != 0 { + t.Error(errors.New("AfterInsertedFlag is set")) + } + if p22.B4InsertViaExt == 0 { + t.Error(errors.New("B4InsertViaExt not set")) + } + if p22.AfterInsertedViaExt != 0 { + t.Error(errors.New("AfterInsertedViaExt is set")) + } + if p22.BeforeSetFlag != 9 { + t.Error(fmt.Errorf("BeforeSetFlag is %d not 9", p22.BeforeSetFlag)) + } + if p22.AfterSetFlag != 9 { + t.Error(fmt.Errorf("AfterSetFlag is %d not 9", p22.BeforeSetFlag)) + } + } + // -- + + // test update processors + b4UpdateFunc := func(bean interface{}) { + if v, ok := (bean).(*ProcessorsStruct); ok { + v.B4UpdateViaExt = 1 + } else { + t.Error(errors.New("cast to ProcessorsStruct failed, how can this be!?")) + } + } + + afterUpdateFunc := func(bean interface{}) { + if v, ok := (bean).(*ProcessorsStruct); ok { + v.AfterUpdatedViaExt = 1 + } else { + t.Error(errors.New("cast to ProcessorsStruct failed, how can this be!?")) + } + } + + p = p2 // reset + + _, err = testEngine.Before(b4UpdateFunc).After(afterUpdateFunc).Update(p) + if err != nil { + t.Error(err) + panic(err) + } else { + if p.B4UpdateFlag == 0 { + t.Error(errors.New("B4UpdateFlag not set")) + } + if p.AfterUpdatedFlag == 0 { + t.Error(errors.New("AfterUpdatedFlag not set")) + } + if p.B4UpdateViaExt == 0 { + t.Error(errors.New("B4UpdateViaExt not set")) + } + if p.AfterUpdatedViaExt == 0 { + t.Error(errors.New("AfterUpdatedViaExt not set")) + } + } + + p2 = &ProcessorsStruct{} + _, err = testEngine.ID(p.Id).Get(p2) + if err != nil { + t.Error(err) + panic(err) + } else { + if p2.B4UpdateFlag == 0 { + t.Error(errors.New("B4UpdateFlag not set")) + } + if p2.AfterUpdatedFlag != 0 { + t.Error(errors.New("AfterUpdatedFlag is set: " + string(p.AfterUpdatedFlag))) + } + if p2.B4UpdateViaExt == 0 { + t.Error(errors.New("B4UpdateViaExt not set")) + } + if p2.AfterUpdatedViaExt != 0 { + t.Error(errors.New("AfterUpdatedViaExt is set: " + string(p.AfterUpdatedViaExt))) + } + if p2.BeforeSetFlag != 9 { + t.Error(fmt.Errorf("BeforeSetFlag is %d not 9", p2.BeforeSetFlag)) + } + if p2.AfterSetFlag != 9 { + t.Error(fmt.Errorf("AfterSetFlag is %d not 9", p2.BeforeSetFlag)) + } + } + // -- + + // test delete processors + b4DeleteFunc := func(bean interface{}) { + if v, ok := (bean).(*ProcessorsStruct); ok { + v.B4DeleteViaExt = 1 + } else { + t.Error(errors.New("cast to ProcessorsStruct failed, how can this be!?")) + } + } + + afterDeleteFunc := func(bean interface{}) { + if v, ok := (bean).(*ProcessorsStruct); ok { + v.AfterDeletedViaExt = 1 + } else { + t.Error(errors.New("cast to ProcessorsStruct failed, how can this be!?")) + } + } + + p = p2 // reset + _, err = testEngine.Before(b4DeleteFunc).After(afterDeleteFunc).Delete(p) + if err != nil { + t.Error(err) + panic(err) + } else { + if p.B4DeleteFlag == 0 { + t.Error(errors.New("B4DeleteFlag not set")) + } + if p.AfterDeletedFlag == 0 { + t.Error(errors.New("AfterDeletedFlag not set")) + } + if p.B4DeleteViaExt == 0 { + t.Error(errors.New("B4DeleteViaExt not set")) + } + if p.AfterDeletedViaExt == 0 { + t.Error(errors.New("AfterDeletedViaExt not set")) + } + } + // -- + + // test insert multi + pslice := make([]*ProcessorsStruct, 0) + pslice = append(pslice, &ProcessorsStruct{}) + pslice = append(pslice, &ProcessorsStruct{}) + cnt, err := testEngine.Before(b4InsertFunc).After(afterInsertFunc).Insert(&pslice) + if err != nil { + t.Error(err) + panic(err) + } else { + if cnt != 2 { + t.Error(errors.New("incorrect insert count")) + } + for _, elem := range pslice { + if elem.B4InsertFlag == 0 { + t.Error(errors.New("B4InsertFlag not set")) + } + if elem.AfterInsertedFlag == 0 { + t.Error(errors.New("B4InsertFlag not set")) + } + if elem.B4InsertViaExt == 0 { + t.Error(errors.New("B4InsertFlag not set")) + } + if elem.AfterInsertedViaExt == 0 { + t.Error(errors.New("AfterInsertedViaExt not set")) + } + } + } + + for _, elem := range pslice { + p = &ProcessorsStruct{} + _, err = testEngine.ID(elem.Id).Get(p) + if err != nil { + t.Error(err) + panic(err) + } else { + if p2.B4InsertFlag == 0 { + t.Error(errors.New("B4InsertFlag not set")) + } + if p2.AfterInsertedFlag != 0 { + t.Error(errors.New("AfterInsertedFlag is set")) + } + if p2.B4InsertViaExt == 0 { + t.Error(errors.New("B4InsertViaExt not set")) + } + if p2.AfterInsertedViaExt != 0 { + t.Error(errors.New("AfterInsertedViaExt is set")) + } + if p2.BeforeSetFlag != 9 { + t.Error(fmt.Errorf("BeforeSetFlag is %d not 9", p2.BeforeSetFlag)) + } + if p2.AfterSetFlag != 9 { + t.Error(fmt.Errorf("AfterSetFlag is %d not 9", p2.BeforeSetFlag)) + } + } + } + // -- +} + +func TestProcessorsTx(t *testing.T) { + assert.NoError(t, prepareEngine()) + + err := testEngine.DropTables(&ProcessorsStruct{}) + if err != nil { + t.Error(err) + panic(err) + } + + err = testEngine.CreateTables(&ProcessorsStruct{}) + if err != nil { + t.Error(err) + panic(err) + } + + // test insert processors with tx rollback + session := testEngine.NewSession() + err = session.Begin() + if err != nil { + t.Error(err) + panic(err) + } + + p := &ProcessorsStruct{} + b4InsertFunc := func(bean interface{}) { + if v, ok := (bean).(*ProcessorsStruct); ok { + v.B4InsertViaExt = 1 + } else { + t.Error(errors.New("cast to ProcessorsStruct failed, how can this be!?")) + } + } + + afterInsertFunc := func(bean interface{}) { + if v, ok := (bean).(*ProcessorsStruct); ok { + v.AfterInsertedViaExt = 1 + } else { + t.Error(errors.New("cast to ProcessorsStruct failed, how can this be!?")) + } + } + _, err = session.Before(b4InsertFunc).After(afterInsertFunc).Insert(p) + if err != nil { + t.Error(err) + panic(err) + } else { + if p.B4InsertFlag == 0 { + t.Error(errors.New("B4InsertFlag not set")) + } + if p.AfterInsertedFlag != 0 { + t.Error(errors.New("B4InsertFlag is set")) + } + if p.B4InsertViaExt == 0 { + t.Error(errors.New("B4InsertViaExt not set")) + } + if p.AfterInsertedViaExt != 0 { + t.Error(errors.New("AfterInsertedViaExt is set")) + } + } + + err = session.Rollback() + if err != nil { + t.Error(err) + panic(err) + } else { + if p.B4InsertFlag == 0 { + t.Error(errors.New("B4InsertFlag not set")) + } + if p.AfterInsertedFlag != 0 { + t.Error(errors.New("B4InsertFlag is set")) + } + if p.B4InsertViaExt == 0 { + t.Error(errors.New("B4InsertViaExt not set")) + } + if p.AfterInsertedViaExt != 0 { + t.Error(errors.New("AfterInsertedViaExt is set")) + } + } + session.Close() + p2 := &ProcessorsStruct{} + _, err = testEngine.ID(p.Id).Get(p2) + if err != nil { + t.Error(err) + panic(err) + } else { + if p2.Id > 0 { + err = errors.New("tx got committed upon insert!?") + t.Error(err) + panic(err) + } + } + // -- + + // test insert processors with tx commit + session = testEngine.NewSession() + err = session.Begin() + if err != nil { + t.Error(err) + panic(err) + } + + p = &ProcessorsStruct{} + _, err = session.Before(b4InsertFunc).After(afterInsertFunc).Insert(p) + if err != nil { + t.Error(err) + panic(err) + } else { + if p.B4InsertFlag == 0 { + t.Error(errors.New("B4InsertFlag not set")) + } + if p.AfterInsertedFlag != 0 { + t.Error(errors.New("AfterInsertedFlag is set")) + } + if p.B4InsertViaExt == 0 { + t.Error(errors.New("B4InsertViaExt not set")) + } + if p.AfterInsertedViaExt != 0 { + t.Error(errors.New("AfterInsertedViaExt is set")) + } + } + + err = session.Commit() + if err != nil { + t.Error(err) + panic(err) + } else { + if p.B4InsertFlag == 0 { + t.Error(errors.New("B4InsertFlag not set")) + } + if p.AfterInsertedFlag == 0 { + t.Error(errors.New("AfterInsertedFlag not set")) + } + if p.B4InsertViaExt == 0 { + t.Error(errors.New("B4InsertViaExt not set")) + } + if p.AfterInsertedViaExt == 0 { + t.Error(errors.New("AfterInsertedViaExt not set")) + } + } + session.Close() + p2 = &ProcessorsStruct{} + _, err = testEngine.ID(p.Id).Get(p2) + if err != nil { + t.Error(err) + panic(err) + } else { + if p2.B4InsertFlag == 0 { + t.Error(errors.New("B4InsertFlag not set")) + } + if p2.AfterInsertedFlag != 0 { + t.Error(errors.New("AfterInsertedFlag is set")) + } + if p2.B4InsertViaExt == 0 { + t.Error(errors.New("B4InsertViaExt not set")) + } + if p2.AfterInsertedViaExt != 0 { + t.Error(errors.New("AfterInsertedViaExt is set")) + } + } + insertedId := p2.Id + // -- + + // test update processors with tx rollback + session = testEngine.NewSession() + err = session.Begin() + if err != nil { + t.Error(err) + panic(err) + } + + b4UpdateFunc := func(bean interface{}) { + if v, ok := (bean).(*ProcessorsStruct); ok { + v.B4UpdateViaExt = 1 + } else { + t.Error(errors.New("cast to ProcessorsStruct failed, how can this be!?")) + } + } + + afterUpdateFunc := func(bean interface{}) { + if v, ok := (bean).(*ProcessorsStruct); ok { + v.AfterUpdatedViaExt = 1 + } else { + t.Error(errors.New("cast to ProcessorsStruct failed, how can this be!?")) + } + } + + p = p2 // reset + + _, err = session.ID(insertedId).Before(b4UpdateFunc).After(afterUpdateFunc).Update(p) + if err != nil { + t.Error(err) + panic(err) + } else { + if p.B4UpdateFlag == 0 { + t.Error(errors.New("B4UpdateFlag not set")) + } + if p.AfterUpdatedFlag != 0 { + t.Error(errors.New("AfterUpdatedFlag is set")) + } + if p.B4UpdateViaExt == 0 { + t.Error(errors.New("B4UpdateViaExt not set")) + } + if p.AfterUpdatedViaExt != 0 { + t.Error(errors.New("AfterUpdatedViaExt is set")) + } + } + err = session.Rollback() + if err != nil { + t.Error(err) + panic(err) + } else { + if p.B4UpdateFlag == 0 { + t.Error(errors.New("B4UpdateFlag not set")) + } + if p.AfterUpdatedFlag != 0 { + t.Error(errors.New("AfterUpdatedFlag is set")) + } + if p.B4UpdateViaExt == 0 { + t.Error(errors.New("B4UpdateViaExt not set")) + } + if p.AfterUpdatedViaExt != 0 { + t.Error(errors.New("AfterUpdatedViaExt is set")) + } + } + + session.Close() + + p2 = &ProcessorsStruct{} + _, err = testEngine.ID(insertedId).Get(p2) + if err != nil { + t.Error(err) + panic(err) + } else { + if p2.B4UpdateFlag != 0 { + t.Error(errors.New("B4UpdateFlag is set")) + } + if p2.AfterUpdatedFlag != 0 { + t.Error(errors.New("AfterUpdatedFlag is set")) + } + if p2.B4UpdateViaExt != 0 { + t.Error(errors.New("B4UpdateViaExt not set")) + } + if p2.AfterUpdatedViaExt != 0 { + t.Error(errors.New("AfterUpdatedViaExt is set")) + } + } + // -- + + // test update processors with tx rollback + session = testEngine.NewSession() + err = session.Begin() + if err != nil { + t.Error(err) + panic(err) + } + + p = &ProcessorsStruct{Id: insertedId} + + _, err = session.Update(p) + if err != nil { + t.Error(err) + panic(err) + } else { + if p.B4UpdateFlag == 0 { + t.Error(errors.New("B4UpdateFlag not set")) + } + if p.AfterUpdatedFlag != 0 { + t.Error(errors.New("AfterUpdatedFlag is set")) + } + } + err = session.Commit() + if err != nil { + t.Error(err) + panic(err) + } else { + if p.B4UpdateFlag == 0 { + t.Error(errors.New("B4UpdateFlag not set")) + } + if p.AfterUpdatedFlag == 0 { + t.Error(errors.New("AfterUpdatedFlag not set")) + } + if p.AfterDeletedFlag != 0 { + t.Error(errors.New("AfterDeletedFlag set")) + } + if p.AfterInsertedFlag != 0 { + t.Error(errors.New("AfterInsertedFlag set")) + } + } + + session.Close() + + // test update processors with tx commit + session = testEngine.NewSession() + err = session.Begin() + if err != nil { + t.Error(err) + panic(err) + } + + p = &ProcessorsStruct{} + + _, err = session.ID(insertedId).Before(b4UpdateFunc).After(afterUpdateFunc).Update(p) + if err != nil { + t.Error(err) + panic(err) + } else { + if p.B4UpdateFlag == 0 { + t.Error(errors.New("B4UpdateFlag not set")) + } + if p.AfterUpdatedFlag != 0 { + t.Error(errors.New("AfterUpdatedFlag is set")) + } + if p.B4UpdateViaExt == 0 { + t.Error(errors.New("B4UpdateViaExt not set")) + } + if p.AfterUpdatedViaExt != 0 { + t.Error(errors.New("AfterUpdatedViaExt is set")) + } + } + err = session.Commit() + if err != nil { + t.Error(err) + panic(err) + } else { + if p.B4UpdateFlag == 0 { + t.Error(errors.New("B4UpdateFlag not set")) + } + if p.AfterUpdatedFlag == 0 { + t.Error(errors.New("AfterUpdatedFlag not set")) + } + if p.B4UpdateViaExt == 0 { + t.Error(errors.New("B4UpdateViaExt not set")) + } + if p.AfterUpdatedViaExt == 0 { + t.Error(errors.New("AfterUpdatedViaExt not set")) + } + } + session.Close() + p2 = &ProcessorsStruct{} + _, err = testEngine.ID(insertedId).Get(p2) + if err != nil { + t.Error(err) + panic(err) + } else { + if p.B4UpdateFlag == 0 { + t.Error(errors.New("B4UpdateFlag not set")) + } + if p.AfterUpdatedFlag == 0 { + t.Error(errors.New("AfterUpdatedFlag not set")) + } + if p.B4UpdateViaExt == 0 { + t.Error(errors.New("B4UpdateViaExt not set")) + } + if p.AfterUpdatedViaExt == 0 { + t.Error(errors.New("AfterUpdatedViaExt not set")) + } + } + // -- + + // test delete processors with tx rollback + session = testEngine.NewSession() + err = session.Begin() + if err != nil { + t.Error(err) + panic(err) + } + + b4DeleteFunc := func(bean interface{}) { + if v, ok := (bean).(*ProcessorsStruct); ok { + v.B4DeleteViaExt = 1 + } else { + t.Error(errors.New("cast to ProcessorsStruct failed, how can this be!?")) + } + } + + afterDeleteFunc := func(bean interface{}) { + if v, ok := (bean).(*ProcessorsStruct); ok { + v.AfterDeletedViaExt = 1 + } else { + t.Error(errors.New("cast to ProcessorsStruct failed, how can this be!?")) + } + } + + p = &ProcessorsStruct{} // reset + + _, err = session.ID(insertedId).Before(b4DeleteFunc).After(afterDeleteFunc).Delete(p) + if err != nil { + t.Error(err) + panic(err) + } else { + if p.B4DeleteFlag == 0 { + t.Error(errors.New("B4DeleteFlag not set")) + } + if p.AfterDeletedFlag != 0 { + t.Error(errors.New("AfterDeletedFlag is set")) + } + if p.B4DeleteViaExt == 0 { + t.Error(errors.New("B4DeleteViaExt not set")) + } + if p.AfterDeletedViaExt != 0 { + t.Error(errors.New("AfterDeletedViaExt is set")) + } + } + err = session.Rollback() + if err != nil { + t.Error(err) + panic(err) + } else { + if p.B4DeleteFlag == 0 { + t.Error(errors.New("B4DeleteFlag not set")) + } + if p.AfterDeletedFlag != 0 { + t.Error(errors.New("AfterDeletedFlag is set")) + } + if p.B4DeleteViaExt == 0 { + t.Error(errors.New("B4DeleteViaExt not set")) + } + if p.AfterDeletedViaExt != 0 { + t.Error(errors.New("AfterDeletedViaExt is set")) + } + } + session.Close() + + p2 = &ProcessorsStruct{} + _, err = testEngine.ID(insertedId).Get(p2) + if err != nil { + t.Error(err) + panic(err) + } else { + if p2.B4DeleteFlag != 0 { + t.Error(errors.New("B4DeleteFlag is set")) + } + if p2.AfterDeletedFlag != 0 { + t.Error(errors.New("AfterDeletedFlag is set")) + } + if p2.B4DeleteViaExt != 0 { + t.Error(errors.New("B4DeleteViaExt is set")) + } + if p2.AfterDeletedViaExt != 0 { + t.Error(errors.New("AfterDeletedViaExt is set")) + } + } + // -- + + // test delete processors with tx commit + session = testEngine.NewSession() + err = session.Begin() + if err != nil { + t.Error(err) + panic(err) + } + + p = &ProcessorsStruct{} + + _, err = session.ID(insertedId).Before(b4DeleteFunc).After(afterDeleteFunc).Delete(p) + if err != nil { + t.Error(err) + panic(err) + } else { + if p.B4DeleteFlag == 0 { + t.Error(errors.New("B4DeleteFlag not set")) + } + if p.AfterDeletedFlag != 0 { + t.Error(errors.New("AfterDeletedFlag is set")) + } + if p.B4DeleteViaExt == 0 { + t.Error(errors.New("B4DeleteViaExt not set")) + } + if p.AfterDeletedViaExt != 0 { + t.Error(errors.New("AfterDeletedViaExt is set")) + } + } + err = session.Commit() + if err != nil { + t.Error(err) + panic(err) + } else { + if p.B4DeleteFlag == 0 { + t.Error(errors.New("B4DeleteFlag not set")) + } + if p.AfterDeletedFlag == 0 { + t.Error(errors.New("AfterDeletedFlag not set")) + } + if p.B4DeleteViaExt == 0 { + t.Error(errors.New("B4DeleteViaExt not set")) + } + if p.AfterDeletedViaExt == 0 { + t.Error(errors.New("AfterDeletedViaExt not set")) + } + } + session.Close() + + // test delete processors with tx commit + session = testEngine.NewSession() + err = session.Begin() + if err != nil { + t.Error(err) + panic(err) + } + + p = &ProcessorsStruct{Id: insertedId} + fmt.Println("delete") + _, err = session.Delete(p) + + if err != nil { + t.Error(err) + panic(err) + } else { + if p.B4DeleteFlag == 0 { + t.Error(errors.New("B4DeleteFlag not set")) + } + if p.AfterDeletedFlag != 0 { + t.Error(errors.New("AfterDeletedFlag is set")) + } + } + err = session.Commit() + if err != nil { + t.Error(err) + panic(err) + } else { + if p.B4DeleteFlag == 0 { + t.Error(errors.New("B4DeleteFlag not set")) + } + if p.AfterDeletedFlag == 0 { + t.Error(errors.New("AfterDeletedFlag not set")) + } + if p.AfterInsertedFlag != 0 { + t.Error(errors.New("AfterInsertedFlag set")) + } + if p.AfterUpdatedFlag != 0 { + t.Error(errors.New("AfterUpdatedFlag set")) + } + } + session.Close() + // -- +} diff --git a/rows.go b/rows.go index 47bc322f..258d9f27 100644 --- a/rows.go +++ b/rows.go @@ -17,7 +17,6 @@ type Rows struct { NoTypeCheck bool session *Session - stmt *core.Stmt rows *core.Rows fields []string beanType reflect.Type @@ -29,53 +28,33 @@ func newRows(session *Session, bean interface{}) (*Rows, error) { rows.session = session rows.beanType = reflect.Indirect(reflect.ValueOf(bean)).Type() - defer rows.session.resetStatement() - var sqlStr string var args []interface{} + var err error - if err := rows.session.Statement.setRefValue(rValue(bean)); err != nil { + if err = rows.session.statement.setRefValue(rValue(bean)); err != nil { return nil, err } - if len(session.Statement.TableName()) <= 0 { + if len(session.statement.TableName()) <= 0 { return nil, ErrTableNotFound } - if rows.session.Statement.RawSQL == "" { - sqlStr, args = rows.session.Statement.genGetSQL(bean) - } else { - sqlStr = rows.session.Statement.RawSQL - args = rows.session.Statement.RawParams - } - - for _, filter := range rows.session.Engine.dialect.Filters() { - sqlStr = filter.Do(sqlStr, session.Engine.dialect, rows.session.Statement.RefTable) - } - - rows.session.saveLastSQL(sqlStr, args...) - var err error - if rows.session.prepareStmt { - rows.stmt, err = rows.session.DB().Prepare(sqlStr) + if rows.session.statement.RawSQL == "" { + sqlStr, args, err = rows.session.statement.genGetSQL(bean) if err != nil { - rows.lastError = err - rows.Close() - return nil, err - } - - rows.rows, err = rows.stmt.Query(args...) - if err != nil { - rows.lastError = err - rows.Close() return nil, err } } else { - rows.rows, err = rows.session.DB().Query(sqlStr, args...) - if err != nil { - rows.lastError = err - rows.Close() - return nil, err - } + sqlStr = rows.session.statement.RawSQL + args = rows.session.statement.RawParams + } + + rows.rows, err = rows.session.queryRows(sqlStr, args...) + if err != nil { + rows.lastError = err + rows.Close() + return nil, err } rows.fields, err = rows.rows.Columns() @@ -116,17 +95,22 @@ func (rows *Rows) Scan(bean interface{}) error { } dataStruct := rValue(bean) - if err := rows.session.Statement.setRefValue(dataStruct); err != nil { + if err := rows.session.statement.setRefValue(dataStruct); err != nil { return err } - _, err := rows.session.row2Bean(rows.rows, rows.fields, len(rows.fields), bean, &dataStruct, rows.session.Statement.RefTable) + scanResults, err := rows.session.row2Slice(rows.rows, rows.fields, len(rows.fields), bean) + if err != nil { + return err + } + + _, err = rows.session.slice2Bean(scanResults, rows.fields, len(rows.fields), bean, &dataStruct, rows.session.statement.RefTable) return err } // Close session if session.IsAutoClose is true, and claimed any opened resources func (rows *Rows) Close() error { - if rows.session.IsAutoClose { + if rows.session.isAutoClose { defer rows.session.Close() } @@ -134,17 +118,10 @@ func (rows *Rows) Close() error { if rows.rows != nil { rows.lastError = rows.rows.Close() if rows.lastError != nil { - defer rows.stmt.Close() return rows.lastError } } - if rows.stmt != nil { - rows.lastError = rows.stmt.Close() - } } else { - if rows.stmt != nil { - defer rows.stmt.Close() - } if rows.rows != nil { defer rows.rows.Close() } diff --git a/rows_test.go b/rows_test.go new file mode 100644 index 00000000..c48938a9 --- /dev/null +++ b/rows_test.go @@ -0,0 +1,68 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRows(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type UserRows struct { + Id int64 + IsMan bool + } + + assert.NoError(t, testEngine.Sync2(new(UserRows))) + + cnt, err := testEngine.Insert(&UserRows{ + IsMan: true, + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + rows, err := testEngine.Rows(new(UserRows)) + assert.NoError(t, err) + defer rows.Close() + + cnt = 0 + user := new(UserRows) + for rows.Next() { + err = rows.Scan(user) + assert.NoError(t, err) + cnt++ + } + assert.EqualValues(t, 1, cnt) + + sess := testEngine.NewSession() + defer sess.Close() + + rows1, err := sess.Prepare().Rows(new(UserRows)) + assert.NoError(t, err) + defer rows1.Close() + + cnt = 0 + for rows1.Next() { + err = rows1.Scan(user) + assert.NoError(t, err) + cnt++ + } + assert.EqualValues(t, 1, cnt) + + rows2, err := testEngine.SQL("SELECT * FROM user_rows").Rows(new(UserRows)) + assert.NoError(t, err) + defer rows2.Close() + + cnt = 0 + for rows2.Next() { + err = rows2.Scan(user) + assert.NoError(t, err) + cnt++ + } + assert.EqualValues(t, 1, cnt) +} diff --git a/session.go b/session.go index 475c769f..c69ac9e5 100644 --- a/session.go +++ b/session.go @@ -21,16 +21,16 @@ import ( // kind of database operations. type Session struct { db *core.DB - Engine *Engine - Tx *core.Tx - Statement Statement - IsAutoCommit bool - IsCommitedOrRollbacked bool - IsAutoClose bool + engine *Engine + tx *core.Tx + statement Statement + isAutoCommit bool + isCommitedOrRollbacked bool + isAutoClose bool // Automatically reset the statement after operations that execute a SQL // query such as Count(), Find(), Get(), ... - AutoResetStatement bool + autoResetStatement bool // !nashtsai! storing these beans due to yet committed tx afterInsertBeans map[interface{}]*[]func(interface{}) @@ -48,6 +48,8 @@ type Session struct { //beforeSQLExec func(string, ...interface{}) lastSQL string lastSQLArgs []interface{} + + err error } // Clone copy all the session's content and return a new session @@ -58,12 +60,12 @@ func (session *Session) Clone() *Session { // Init reset the session as the init status. func (session *Session) Init() { - session.Statement.Init() - session.Statement.Engine = session.Engine - session.IsAutoCommit = true - session.IsCommitedOrRollbacked = false - session.IsAutoClose = false - session.AutoResetStatement = true + session.statement.Init() + session.statement.Engine = session.engine + session.isAutoCommit = true + session.isCommitedOrRollbacked = false + session.isAutoClose = false + session.autoResetStatement = true session.prepareStmt = false // !nashtsai! is lazy init better? @@ -86,19 +88,23 @@ func (session *Session) Close() { if session.db != nil { // When Close be called, if session is a transaction and do not call // Commit or Rollback, then call Rollback. - if session.Tx != nil && !session.IsCommitedOrRollbacked { + if session.tx != nil && !session.isCommitedOrRollbacked { session.Rollback() } - session.Tx = nil + session.tx = nil session.stmtCache = nil - session.Init() session.db = nil } } +// IsClosed returns if session is closed +func (session *Session) IsClosed() bool { + return session.db == nil +} + func (session *Session) resetStatement() { - if session.AutoResetStatement { - session.Statement.Init() + if session.autoResetStatement { + session.statement.Init() } } @@ -126,75 +132,75 @@ func (session *Session) After(closures func(interface{})) *Session { // Table can input a string or pointer to struct for special a table to operate. func (session *Session) Table(tableNameOrBean interface{}) *Session { - session.Statement.Table(tableNameOrBean) + session.statement.Table(tableNameOrBean) return session } // Alias set the table alias func (session *Session) Alias(alias string) *Session { - session.Statement.Alias(alias) + session.statement.Alias(alias) return session } // NoCascade indicate that no cascade load child object func (session *Session) NoCascade() *Session { - session.Statement.UseCascade = false + session.statement.UseCascade = false return session } // ForUpdate Set Read/Write locking for UPDATE func (session *Session) ForUpdate() *Session { - session.Statement.IsForUpdate = true + session.statement.IsForUpdate = true return session } // NoAutoCondition disable generate SQL condition from beans func (session *Session) NoAutoCondition(no ...bool) *Session { - session.Statement.NoAutoCondition(no...) + session.statement.NoAutoCondition(no...) return session } // Limit provide limit and offset query condition func (session *Session) Limit(limit int, start ...int) *Session { - session.Statement.Limit(limit, start...) + session.statement.Limit(limit, start...) return session } // OrderBy provide order by query condition, the input parameter is the content // after order by on a sql statement. func (session *Session) OrderBy(order string) *Session { - session.Statement.OrderBy(order) + session.statement.OrderBy(order) return session } // Desc provide desc order by query condition, the input parameters are columns. func (session *Session) Desc(colNames ...string) *Session { - session.Statement.Desc(colNames...) + session.statement.Desc(colNames...) return session } // Asc provide asc order by query condition, the input parameters are columns. func (session *Session) Asc(colNames ...string) *Session { - session.Statement.Asc(colNames...) + session.statement.Asc(colNames...) return session } // StoreEngine is only avialble mysql dialect currently func (session *Session) StoreEngine(storeEngine string) *Session { - session.Statement.StoreEngine = storeEngine + session.statement.StoreEngine = storeEngine return session } // Charset is only avialble mysql dialect currently func (session *Session) Charset(charset string) *Session { - session.Statement.Charset = charset + session.statement.Charset = charset return session } // Cascade indicates if loading sub Struct func (session *Session) Cascade(trueOrFalse ...bool) *Session { if len(trueOrFalse) >= 1 { - session.Statement.UseCascade = trueOrFalse[0] + session.statement.UseCascade = trueOrFalse[0] } return session } @@ -202,32 +208,32 @@ func (session *Session) Cascade(trueOrFalse ...bool) *Session { // NoCache ask this session do not retrieve data from cache system and // get data from database directly. func (session *Session) NoCache() *Session { - session.Statement.UseCache = false + session.statement.UseCache = false return session } // Join join_operator should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN func (session *Session) Join(joinOperator string, tablename interface{}, condition string, args ...interface{}) *Session { - session.Statement.Join(joinOperator, tablename, condition, args...) + session.statement.Join(joinOperator, tablename, condition, args...) return session } // GroupBy Generate Group By statement func (session *Session) GroupBy(keys string) *Session { - session.Statement.GroupBy(keys) + session.statement.GroupBy(keys) return session } // Having Generate Having statement func (session *Session) Having(conditions string) *Session { - session.Statement.Having(conditions) + session.statement.Having(conditions) return session } // DB db return the wrapper of sql.DB func (session *Session) DB() *core.DB { if session.db == nil { - session.db = session.Engine.db + session.db = session.engine.db session.stmtCache = make(map[uint32]*core.Stmt, 0) } return session.db @@ -240,13 +246,13 @@ func cleanupProcessorsClosures(slices *[]func(interface{})) { } func (session *Session) canCache() bool { - if session.Statement.RefTable == nil || - session.Statement.JoinStr != "" || - session.Statement.RawSQL != "" || - !session.Statement.UseCache || - session.Statement.IsForUpdate || - session.Tx != nil || - len(session.Statement.selectStr) > 0 { + if session.statement.RefTable == nil || + session.statement.JoinStr != "" || + session.statement.RawSQL != "" || + !session.statement.UseCache || + session.statement.IsForUpdate || + session.tx != nil || + len(session.statement.selectStr) > 0 { return false } return true @@ -270,18 +276,18 @@ func (session *Session) doPrepare(sqlStr string) (stmt *core.Stmt, err error) { func (session *Session) getField(dataStruct *reflect.Value, key string, table *core.Table, idx int) *reflect.Value { var col *core.Column if col = table.GetColumnIdx(key, idx); col == nil { - //session.Engine.logger.Warnf("table %v has no column %v. %v", table.Name, key, table.ColumnsSeq()) + //session.engine.logger.Warnf("table %v has no column %v. %v", table.Name, key, table.ColumnsSeq()) return nil } fieldValue, err := col.ValueOfV(dataStruct) if err != nil { - session.Engine.logger.Error(err) + session.engine.logger.Error(err) return nil } if !fieldValue.IsValid() || !fieldValue.CanSet() { - session.Engine.logger.Warnf("table %v's column %v is not valid or cannot set", table.Name, key) + session.engine.logger.Warnf("table %v's column %v is not valid or cannot set", table.Name, key) return nil } return fieldValue @@ -297,11 +303,16 @@ func (session *Session) rows2Beans(rows *core.Rows, fields []string, fieldsCount var newValue = newElemFunc(fields) bean := newValue.Interface() dataStruct := rValue(bean) - pk, err := session.row2Bean(rows, fields, fieldsCount, bean, &dataStruct, table) + + // handle beforeClosures + scanResults, err := session.row2Slice(rows, fields, fieldsCount, bean) + if err != nil { + return err + } + pk, err := session.slice2Bean(scanResults, fields, fieldsCount, bean, &dataStruct, table) if err != nil { return err } - err = sliceValueSetFunc(&newValue, pk) if err != nil { return err @@ -310,8 +321,7 @@ func (session *Session) rows2Beans(rows *core.Rows, fields []string, fieldsCount return nil } -func (session *Session) row2Bean(rows *core.Rows, fields []string, fieldsCount int, bean interface{}, dataStruct *reflect.Value, table *core.Table) (core.PK, error) { - // handle beforeClosures +func (session *Session) row2Slice(rows *core.Rows, fields []string, fieldsCount int, bean interface{}) ([]interface{}, error) { for _, closure := range session.beforeClosures { closure(bean) } @@ -330,7 +340,10 @@ func (session *Session) row2Bean(rows *core.Rows, fields []string, fieldsCount i b.BeforeSet(key, Cell(scanResults[ii].(*interface{}))) } } + return scanResults, nil +} +func (session *Session) slice2Bean(scanResults []interface{}, fields []string, fieldsCount int, bean interface{}, dataStruct *reflect.Value, table *core.Table) (core.PK, error) { defer func() { if b, hasAfterSet := bean.(AfterSetProcessor); hasAfterSet { for ii, key := range fields { @@ -344,15 +357,6 @@ func (session *Session) row2Bean(rows *core.Rows, fields []string, fieldsCount i } }() - dbTZ := session.Engine.DatabaseTZ - if dbTZ == nil { - if session.Engine.dialect.DBType() == core.SQLITE { - dbTZ = time.UTC - } else { - dbTZ = time.Local - } - } - var tempMap = make(map[string]int) var pk core.PK for ii, key := range fields { @@ -528,11 +532,9 @@ func (session *Session) row2Bean(rows *core.Rows, fields []string, fieldsCount i } case reflect.Struct: if fieldType.ConvertibleTo(core.TimeType) { - var tz *time.Location - if col.TimeZone == nil { - tz = session.Engine.TZLocation - } else { - tz = col.TimeZone + dbTZ := session.engine.DatabaseTZ + if col.TimeZone != nil { + dbTZ = col.TimeZone } if rawValueType == core.TimeType { @@ -543,26 +545,25 @@ func (session *Session) row2Bean(rows *core.Rows, fields []string, fieldsCount i z, _ := t.Zone() // set new location if database don't save timezone or give an incorrect timezone if len(z) == 0 || t.Year() == 0 || t.Location().String() != dbTZ.String() { // !nashtsai! HACK tmp work around for lib/pq doesn't properly time with location - session.Engine.logger.Debugf("empty zone key[%v] : %v | zone: %v | location: %+v\n", key, t, z, *t.Location()) + session.engine.logger.Debugf("empty zone key[%v] : %v | zone: %v | location: %+v\n", key, t, z, *t.Location()) t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), t.Minute(), t.Second(), t.Nanosecond(), dbTZ) } - // !nashtsai! convert to engine location - t = t.In(tz) + t = t.In(session.engine.TZLocation) fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) } else if rawValueType == core.IntType || rawValueType == core.Int64Type || rawValueType == core.Int32Type { hasAssigned = true - t := time.Unix(vv.Int(), 0).In(tz) + t := time.Unix(vv.Int(), 0).In(session.engine.TZLocation) fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) } else { if d, ok := vv.Interface().([]uint8); ok { hasAssigned = true t, err := session.byte2Time(col, d) if err != nil { - session.Engine.logger.Error("byte2Time error:", err.Error()) + session.engine.logger.Error("byte2Time error:", err.Error()) hasAssigned = false } else { fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) @@ -571,20 +572,20 @@ func (session *Session) row2Bean(rows *core.Rows, fields []string, fieldsCount i hasAssigned = true t, err := session.str2Time(col, d) if err != nil { - session.Engine.logger.Error("byte2Time error:", err.Error()) + session.engine.logger.Error("byte2Time error:", err.Error()) hasAssigned = false } else { fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) } } else { - panic(fmt.Sprintf("rawValueType is %v, value is %v", rawValueType, vv.Interface())) + return nil, fmt.Errorf("rawValueType is %v, value is %v", rawValueType, vv.Interface()) } } } else if nulVal, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { // !! 增加支持sql.Scanner接口的结构,如sql.NullString hasAssigned = true if err := nulVal.Scan(vv.Interface()); err != nil { - session.Engine.logger.Error("sql.Sanner error:", err.Error()) + session.engine.logger.Error("sql.Sanner error:", err.Error()) hasAssigned = false } } else if col.SQLType.IsJson() { @@ -609,15 +610,15 @@ func (session *Session) row2Bean(rows *core.Rows, fields []string, fieldsCount i fieldValue.Set(x.Elem()) } } - } else if session.Statement.UseCascade { - table, err := session.Engine.autoMapType(*fieldValue) + } else if session.statement.UseCascade { + table, err := session.engine.autoMapType(*fieldValue) if err != nil { return nil, err } hasAssigned = true if len(table.PrimaryKeys) != 1 { - panic("unsupported non or composited primary key cascade") + return nil, errors.New("unsupported non or composited primary key cascade") } var pk = make(core.PK, len(table.PrimaryKeys)) pk[0], err = asKind(vv, rawValueType) @@ -630,9 +631,7 @@ func (session *Session) row2Bean(rows *core.Rows, fields []string, fieldsCount i // however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne // property to be fetched lazily structInter := reflect.New(fieldValue.Type()) - newsession := session.Engine.NewSession() - defer newsession.Close() - has, err := newsession.Id(pk).NoCascade().Get(structInter.Interface()) + has, err := session.ID(pk).NoCascade().get(structInter.Interface()) if err != nil { return nil, err } @@ -776,19 +775,11 @@ func (session *Session) row2Bean(rows *core.Rows, fields []string, fieldsCount i return pk, nil } -func (session *Session) queryPreprocess(sqlStr *string, paramStr ...interface{}) { - for _, filter := range session.Engine.dialect.Filters() { - *sqlStr = filter.Do(*sqlStr, session.Engine.dialect, session.Statement.RefTable) - } - - session.saveLastSQL(*sqlStr, paramStr...) -} - // saveLastSQL stores executed query information func (session *Session) saveLastSQL(sql string, args ...interface{}) { session.lastSQL = sql session.lastSQLArgs = args - session.Engine.logSQL(sql, args...) + session.engine.logSQL(sql, args...) } // LastSQL returns last query information @@ -798,8 +789,8 @@ func (session *Session) LastSQL() (string, []interface{}) { // tbName get some table's table name func (session *Session) tbNameNoSchema(table *core.Table) string { - if len(session.Statement.AltTableName) > 0 { - return session.Statement.AltTableName + if len(session.statement.AltTableName) > 0 { + return session.statement.AltTableName } return table.Name @@ -807,6 +798,6 @@ func (session *Session) tbNameNoSchema(table *core.Table) string { // Unscoped always disable struct tag "deleted" func (session *Session) Unscoped() *Session { - session.Statement.Unscoped() + session.statement.Unscoped() return session } diff --git a/session_cols.go b/session_cols.go index 91185def..9972cb0a 100644 --- a/session_cols.go +++ b/session_cols.go @@ -6,43 +6,43 @@ package xorm // Incr provides a query string like "count = count + 1" func (session *Session) Incr(column string, arg ...interface{}) *Session { - session.Statement.Incr(column, arg...) + session.statement.Incr(column, arg...) return session } // Decr provides a query string like "count = count - 1" func (session *Session) Decr(column string, arg ...interface{}) *Session { - session.Statement.Decr(column, arg...) + session.statement.Decr(column, arg...) return session } // SetExpr provides a query string like "column = {expression}" func (session *Session) SetExpr(column string, expression string) *Session { - session.Statement.SetExpr(column, expression) + session.statement.SetExpr(column, expression) return session } // Select provides some columns to special func (session *Session) Select(str string) *Session { - session.Statement.Select(str) + session.statement.Select(str) return session } // Cols provides some columns to special func (session *Session) Cols(columns ...string) *Session { - session.Statement.Cols(columns...) + session.statement.Cols(columns...) return session } // AllCols ask all columns func (session *Session) AllCols() *Session { - session.Statement.AllCols() + session.statement.AllCols() return session } // MustCols specify some columns must use even if they are empty func (session *Session) MustCols(columns ...string) *Session { - session.Statement.MustCols(columns...) + session.statement.MustCols(columns...) return session } @@ -52,7 +52,7 @@ func (session *Session) MustCols(columns ...string) *Session { // If no parameters, it will use all the bool field of struct, or // it will use parameters's columns func (session *Session) UseBool(columns ...string) *Session { - session.Statement.UseBool(columns...) + session.statement.UseBool(columns...) return session } @@ -60,25 +60,25 @@ func (session *Session) UseBool(columns ...string) *Session { // distinct will not be cached because cache system need id, // but distinct will not provide id func (session *Session) Distinct(columns ...string) *Session { - session.Statement.Distinct(columns...) + session.statement.Distinct(columns...) return session } // Omit Only not use the parameters as select or update columns func (session *Session) Omit(columns ...string) *Session { - session.Statement.Omit(columns...) + session.statement.Omit(columns...) return session } // Nullable Set null when column is zero-value and nullable for update func (session *Session) Nullable(columns ...string) *Session { - session.Statement.Nullable(columns...) + session.statement.Nullable(columns...) return session } // NoAutoTime means do not automatically give created field and updated field // the current time on the current session temporarily func (session *Session) NoAutoTime() *Session { - session.Statement.UseAutoTime = false + session.statement.UseAutoTime = false return session } diff --git a/session_cols_test.go b/session_cols_test.go index 8bef8bd7..43854723 100644 --- a/session_cols_test.go +++ b/session_cols_test.go @@ -4,13 +4,66 @@ package xorm -import "testing" +import ( + "testing" + + "github.com/go-xorm/core" + "github.com/stretchr/testify/assert" +) func TestSetExpr(t *testing.T) { - type User struct { + assert.NoError(t, prepareEngine()) + + type UserExpr struct { Id int64 Show bool } - testEngine.SetExpr("show", "NOT show").Id(1).Update(new(User)) + assert.NoError(t, testEngine.Sync2(new(UserExpr))) + + cnt, err := testEngine.Insert(&UserExpr{ + Show: true, + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var not = "NOT" + if testEngine.dialect.DBType() == core.MSSQL { + not = "~" + } + cnt, err = testEngine.SetExpr("show", not+" `show`").ID(1).Update(new(UserExpr)) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) +} + +func TestCols(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type ColsTable struct { + Id int64 + Col1 string + Col2 string + } + + assertSync(t, new(ColsTable)) + + _, err := testEngine.Insert(&ColsTable{ + Col1: "1", + Col2: "2", + }) + assert.NoError(t, err) + + sess := testEngine.ID(1) + _, err = sess.Cols("col1").Cols("col2").Update(&ColsTable{ + Col1: "", + Col2: "", + }) + assert.NoError(t, err) + + var tb ColsTable + has, err := testEngine.ID(1).Get(&tb) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, "", tb.Col1) + assert.EqualValues(t, "", tb.Col2) } diff --git a/session_cond.go b/session_cond.go index 948a90bc..e1d528f2 100644 --- a/session_cond.go +++ b/session_cond.go @@ -17,25 +17,25 @@ func (session *Session) Sql(query string, args ...interface{}) *Session { // SQL provides raw sql input parameter. When you have a complex SQL statement // and cannot use Where, Id, In and etc. Methods to describe, you can use SQL. func (session *Session) SQL(query interface{}, args ...interface{}) *Session { - session.Statement.SQL(query, args...) + session.statement.SQL(query, args...) return session } // Where provides custom query condition. func (session *Session) Where(query interface{}, args ...interface{}) *Session { - session.Statement.Where(query, args...) + session.statement.Where(query, args...) return session } // And provides custom query condition. func (session *Session) And(query interface{}, args ...interface{}) *Session { - session.Statement.And(query, args...) + session.statement.And(query, args...) return session } // Or provides custom query condition. func (session *Session) Or(query interface{}, args ...interface{}) *Session { - session.Statement.Or(query, args...) + session.statement.Or(query, args...) return session } @@ -48,23 +48,23 @@ func (session *Session) Id(id interface{}) *Session { // ID provides converting id as a query condition func (session *Session) ID(id interface{}) *Session { - session.Statement.ID(id) + session.statement.ID(id) return session } // In provides a query string like "id in (1, 2, 3)" func (session *Session) In(column string, args ...interface{}) *Session { - session.Statement.In(column, args...) + session.statement.In(column, args...) return session } // NotIn provides a query string like "id in (1, 2, 3)" func (session *Session) NotIn(column string, args ...interface{}) *Session { - session.Statement.NotIn(column, args...) + session.statement.NotIn(column, args...) return session } -// Conds returns session query conditions +// Conds returns session query conditions except auto bean conditions func (session *Session) Conds() builder.Cond { - return session.Statement.cond + return session.statement.cond } diff --git a/session_cond_test.go b/session_cond_test.go index d5a93924..5f8716f0 100644 --- a/session_cond_test.go +++ b/session_cond_test.go @@ -5,6 +5,8 @@ package xorm import ( + "errors" + "fmt" "testing" "github.com/go-xorm/builder" @@ -81,6 +83,11 @@ func TestBuilder(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 1, len(conds), "records should exist") + conds = make([]Condition, 0) + err = testEngine.NotIn("col_name", "col1", "col2").Find(&conds) + assert.NoError(t, err) + assert.EqualValues(t, 0, len(conds), "records should not exist") + // complex condtions var where = builder.NewCond() if true { @@ -93,3 +100,200 @@ func TestBuilder(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 1, len(conds), "records should exist") } + +func TestIn(t *testing.T) { + assert.NoError(t, prepareEngine()) + assert.NoError(t, testEngine.Sync2(new(Userinfo))) + + cnt, err := testEngine.Insert([]Userinfo{ + { + Username: "user1", + Departname: "dev", + }, + { + Username: "user2", + Departname: "dev", + }, + { + Username: "user3", + Departname: "dev", + }, + }) + assert.NoError(t, err) + assert.EqualValues(t, 3, cnt) + + var usrs []Userinfo + err = testEngine.Limit(3).Find(&usrs) + if err != nil { + t.Error(err) + panic(err) + } + + if len(usrs) != 3 { + err = errors.New("there are not 3 records") + t.Error(err) + panic(err) + } + + var ids []int64 + var idsStr string + for _, u := range usrs { + ids = append(ids, u.Uid) + idsStr = fmt.Sprintf("%d,", u.Uid) + } + idsStr = idsStr[:len(idsStr)-1] + + users := make([]Userinfo, 0) + err = testEngine.In("(id)", ids[0], ids[1], ids[2]).Find(&users) + if err != nil { + t.Error(err) + panic(err) + } + fmt.Println(users) + if len(users) != 3 { + err = errors.New("in uses should be " + idsStr + " total 3") + t.Error(err) + panic(err) + } + + users = make([]Userinfo, 0) + err = testEngine.In("(id)", ids).Find(&users) + if err != nil { + t.Error(err) + panic(err) + } + fmt.Println(users) + if len(users) != 3 { + err = errors.New("in uses should be " + idsStr + " total 3") + t.Error(err) + panic(err) + } + + for _, user := range users { + if user.Uid != ids[0] && user.Uid != ids[1] && user.Uid != ids[2] { + err = errors.New("in uses should be " + idsStr + " total 3") + t.Error(err) + panic(err) + } + } + + users = make([]Userinfo, 0) + var idsInterface []interface{} + for _, id := range ids { + idsInterface = append(idsInterface, id) + } + + department := "`" + testEngine.ColumnMapper.Obj2Table("Departname") + "`" + err = testEngine.Where(department+" = ?", "dev").In("(id)", idsInterface...).Find(&users) + if err != nil { + t.Error(err) + panic(err) + } + fmt.Println(users) + + if len(users) != 3 { + err = errors.New("in uses should be " + idsStr + " total 3") + t.Error(err) + panic(err) + } + + for _, user := range users { + if user.Uid != ids[0] && user.Uid != ids[1] && user.Uid != ids[2] { + err = errors.New("in uses should be " + idsStr + " total 3") + t.Error(err) + panic(err) + } + } + + dev := testEngine.ColumnMapper.Obj2Table("Dev") + + err = testEngine.In("(id)", 1).In("(id)", 2).In(department, dev).Find(&users) + + if err != nil { + t.Error(err) + panic(err) + } + fmt.Println(users) + + cnt, err = testEngine.In("(id)", ids[0]).Update(&Userinfo{Departname: "dev-"}) + if err != nil { + t.Error(err) + panic(err) + } + if cnt != 1 { + err = errors.New("update records not 1") + t.Error(err) + panic(err) + } + + user := new(Userinfo) + has, err := testEngine.ID(ids[0]).Get(user) + if err != nil { + t.Error(err) + panic(err) + } + if !has { + err = errors.New("get record not 1") + t.Error(err) + panic(err) + } + if user.Departname != "dev-" { + err = errors.New("update not success") + t.Error(err) + panic(err) + } + + cnt, err = testEngine.In("(id)", ids[0]).Update(&Userinfo{Departname: "dev"}) + if err != nil { + t.Error(err) + panic(err) + } + if cnt != 1 { + err = errors.New("update records not 1") + t.Error(err) + panic(err) + } + + cnt, err = testEngine.In("(id)", ids[1]).Delete(&Userinfo{}) + if err != nil { + t.Error(err) + panic(err) + } + if cnt != 1 { + err = errors.New("deleted records not 1") + t.Error(err) + panic(err) + } +} + +func TestFindAndCount(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type FindAndCount struct { + Id int64 + Name string + } + + assert.NoError(t, testEngine.Sync2(new(FindAndCount))) + + _, err := testEngine.Insert([]FindAndCount{ + { + Name: "test1", + }, + { + Name: "test2", + }, + }) + assert.NoError(t, err) + + var results []FindAndCount + sess := testEngine.Where("name = ?", "test1") + conds := sess.Conds() + err = sess.Find(&results) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(results)) + + total, err := testEngine.Where(conds).Count(new(FindAndCount)) + assert.NoError(t, err) + assert.EqualValues(t, 1, total) +} diff --git a/session_convert.go b/session_convert.go index 7ef57b5f..f2c949ba 100644 --- a/session_convert.go +++ b/session_convert.go @@ -23,41 +23,38 @@ func (session *Session) str2Time(col *core.Column, data string) (outTime time.Ti var x time.Time var err error - if sdata == "0000-00-00 00:00:00" || - sdata == "0001-01-01 00:00:00" { + var parseLoc = session.engine.DatabaseTZ + if col.TimeZone != nil { + parseLoc = col.TimeZone + } + + if sdata == zeroTime0 || sdata == zeroTime1 { } else if !strings.ContainsAny(sdata, "- :") { // !nashtsai! has only found that mymysql driver is using this for time type column // time stamp sd, err := strconv.ParseInt(sdata, 10, 64) if err == nil { x = time.Unix(sd, 0) - // !nashtsai! HACK mymysql driver is causing Local location being change to CHAT and cause wrong time conversion - if col.TimeZone == nil { - x = x.In(session.Engine.TZLocation) - } else { - x = x.In(col.TimeZone) - } - session.Engine.logger.Debugf("time(0) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) + session.engine.logger.Debugf("time(0) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) } else { - session.Engine.logger.Debugf("time(0) err key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) + session.engine.logger.Debugf("time(0) err key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) } } else if len(sdata) > 19 && strings.Contains(sdata, "-") { - x, err = time.ParseInLocation(time.RFC3339Nano, sdata, session.Engine.TZLocation) - session.Engine.logger.Debugf("time(1) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) + x, err = time.ParseInLocation(time.RFC3339Nano, sdata, parseLoc) + session.engine.logger.Debugf("time(1) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) if err != nil { - x, err = time.ParseInLocation("2006-01-02 15:04:05.999999999", sdata, session.Engine.TZLocation) - session.Engine.logger.Debugf("time(2) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) + x, err = time.ParseInLocation("2006-01-02 15:04:05.999999999", sdata, parseLoc) + session.engine.logger.Debugf("time(2) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) } if err != nil { - x, err = time.ParseInLocation("2006-01-02 15:04:05.9999999 Z07:00", sdata, session.Engine.TZLocation) - session.Engine.logger.Debugf("time(3) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) + x, err = time.ParseInLocation("2006-01-02 15:04:05.9999999 Z07:00", sdata, parseLoc) + session.engine.logger.Debugf("time(3) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) } - } else if len(sdata) == 19 && strings.Contains(sdata, "-") { - x, err = time.ParseInLocation("2006-01-02 15:04:05", sdata, session.Engine.TZLocation) - session.Engine.logger.Debugf("time(4) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) + x, err = time.ParseInLocation("2006-01-02 15:04:05", sdata, parseLoc) + session.engine.logger.Debugf("time(4) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) } else if len(sdata) == 10 && sdata[4] == '-' && sdata[7] == '-' { - x, err = time.ParseInLocation("2006-01-02", sdata, session.Engine.TZLocation) - session.Engine.logger.Debugf("time(5) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) + x, err = time.ParseInLocation("2006-01-02", sdata, parseLoc) + session.engine.logger.Debugf("time(5) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) } else if col.SQLType.Name == core.Time { if strings.Contains(sdata, " ") { ssd := strings.Split(sdata, " ") @@ -65,13 +62,13 @@ func (session *Session) str2Time(col *core.Column, data string) (outTime time.Ti } sdata = strings.TrimSpace(sdata) - if session.Engine.dialect.DBType() == core.MYSQL && len(sdata) > 8 { + if session.engine.dialect.DBType() == core.MYSQL && len(sdata) > 8 { sdata = sdata[len(sdata)-8:] } st := fmt.Sprintf("2006-01-02 %v", sdata) - x, err = time.ParseInLocation("2006-01-02 15:04:05", st, session.Engine.TZLocation) - session.Engine.logger.Debugf("time(6) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) + x, err = time.ParseInLocation("2006-01-02 15:04:05", st, parseLoc) + session.engine.logger.Debugf("time(6) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) } else { outErr = fmt.Errorf("unsupported time format %v", sdata) return @@ -80,7 +77,7 @@ func (session *Session) str2Time(col *core.Column, data string) (outTime time.Ti outErr = fmt.Errorf("unsupported time format %v: %v", sdata, err) return } - outTime = x + outTime = x.In(session.engine.TZLocation) return } @@ -108,7 +105,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, if len(data) > 0 { err := json.Unmarshal(data, x.Interface()) if err != nil { - session.Engine.logger.Error(err) + session.engine.logger.Error(err) return err } fieldValue.Set(x.Elem()) @@ -122,7 +119,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, if len(data) > 0 { err := json.Unmarshal(data, x.Interface()) if err != nil { - session.Engine.logger.Error(err) + session.engine.logger.Error(err) return err } fieldValue.Set(x.Elem()) @@ -135,7 +132,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, if len(data) > 0 { err := json.Unmarshal(data, x.Interface()) if err != nil { - session.Engine.logger.Error(err) + session.engine.logger.Error(err) return err } fieldValue.Set(x.Elem()) @@ -147,8 +144,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, case reflect.String: fieldValue.SetString(string(data)) case reflect.Bool: - d := string(data) - v, err := strconv.ParseBool(d) + v, err := asBool(data) if err != nil { return fmt.Errorf("arg %v as bool: %s", key, err.Error()) } @@ -159,7 +155,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, var err error // for mysql, when use bit, it returned \x01 if col.SQLType.Name == core.Bit && - session.Engine.dialect.DBType() == core.MYSQL { // !nashtsai! TODO dialect needs to provide conversion interface API + session.engine.dialect.DBType() == core.MYSQL { // !nashtsai! TODO dialect needs to provide conversion interface API if len(data) == 1 { x = int64(data[0]) } else { @@ -207,16 +203,17 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, } v = x fieldValue.Set(reflect.ValueOf(v).Convert(fieldType)) - } else if session.Statement.UseCascade { - table, err := session.Engine.autoMapType(*fieldValue) + } else if session.statement.UseCascade { + table, err := session.engine.autoMapType(*fieldValue) if err != nil { return err } // TODO: current only support 1 primary key if len(table.PrimaryKeys) > 1 { - panic("unsupported composited primary key cascade") + return errors.New("unsupported composited primary key cascade") } + var pk = make(core.PK, len(table.PrimaryKeys)) rawValueType := table.ColumnType(table.PKColumns()[0].FieldName) pk[0], err = str2PK(string(data), rawValueType) @@ -229,9 +226,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, // however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne // property to be fetched lazily structInter := reflect.New(fieldValue.Type()) - newsession := session.Engine.NewSession() - defer newsession.Close() - has, err := newsession.Id(pk).NoCascade().Get(structInter.Interface()) + has, err := session.ID(pk).NoCascade().get(structInter.Interface()) if err != nil { return err } @@ -266,7 +261,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, if len(data) > 0 { err := json.Unmarshal(data, &x) if err != nil { - session.Engine.logger.Error(err) + session.engine.logger.Error(err) return err } fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) @@ -277,7 +272,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, if len(data) > 0 { err := json.Unmarshal(data, &x) if err != nil { - session.Engine.logger.Error(err) + session.engine.logger.Error(err) return err } fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) @@ -349,7 +344,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, var err error // for mysql, when use bit, it returned \x01 if col.SQLType.Name == core.Bit && - strings.Contains(session.Engine.DriverName(), "mysql") { + strings.Contains(session.engine.DriverName(), "mysql") { if len(data) == 1 { x = int64(data[0]) } else { @@ -374,7 +369,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, var err error // for mysql, when use bit, it returned \x01 if col.SQLType.Name == core.Bit && - strings.Contains(session.Engine.DriverName(), "mysql") { + strings.Contains(session.engine.DriverName(), "mysql") { if len(data) == 1 { x = int(data[0]) } else { @@ -402,7 +397,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, var err error // for mysql, when use bit, it returned \x01 if col.SQLType.Name == core.Bit && - session.Engine.dialect.DBType() == core.MYSQL { + session.engine.dialect.DBType() == core.MYSQL { if len(data) == 1 { x = int32(data[0]) } else { @@ -430,7 +425,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, var err error // for mysql, when use bit, it returned \x01 if col.SQLType.Name == core.Bit && - strings.Contains(session.Engine.DriverName(), "mysql") { + strings.Contains(session.engine.DriverName(), "mysql") { if len(data) == 1 { x = int8(data[0]) } else { @@ -458,7 +453,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, var err error // for mysql, when use bit, it returned \x01 if col.SQLType.Name == core.Bit && - strings.Contains(session.Engine.DriverName(), "mysql") { + strings.Contains(session.engine.DriverName(), "mysql") { if len(data) == 1 { x = int16(data[0]) } else { @@ -490,16 +485,17 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, v = x fieldValue.Set(reflect.ValueOf(&x)) default: - if session.Statement.UseCascade { + if session.statement.UseCascade { structInter := reflect.New(fieldType.Elem()) - table, err := session.Engine.autoMapType(structInter.Elem()) + table, err := session.engine.autoMapType(structInter.Elem()) if err != nil { return err } if len(table.PrimaryKeys) > 1 { - panic("unsupported composited primary key cascade") + return errors.New("unsupported composited primary key cascade") } + var pk = make(core.PK, len(table.PrimaryKeys)) rawValueType := table.ColumnType(table.PKColumns()[0].FieldName) pk[0], err = str2PK(string(data), rawValueType) @@ -511,9 +507,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, // !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch // however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne // property to be fetched lazily - newsession := session.Engine.NewSession() - defer newsession.Close() - has, err := newsession.Id(pk).NoCascade().Get(structInter.Interface()) + has, err := session.ID(pk).NoCascade().get(structInter.Interface()) if err != nil { return err } @@ -570,7 +564,7 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val if fieldValue.IsNil() { return nil, nil } else if !fieldValue.IsValid() { - session.Engine.logger.Warn("the field[", col.FieldName, "] is invalid") + session.engine.logger.Warn("the field[", col.FieldName, "] is invalid") return nil, nil } else { // !nashtsai! deference pointer type to instance type @@ -588,12 +582,7 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val case reflect.Struct: if fieldType.ConvertibleTo(core.TimeType) { t := fieldValue.Convert(core.TimeType).Interface().(time.Time) - if session.Engine.dialect.DBType() == core.MSSQL { - if t.IsZero() { - return nil, nil - } - } - tf := session.Engine.FormatTime(col.SQLType.Name, t) + tf := session.engine.formatColTime(col, t) return tf, nil } @@ -603,7 +592,7 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val return v.Value() } - fieldTable, err := session.Engine.autoMapType(fieldValue) + fieldTable, err := session.engine.autoMapType(fieldValue) if err != nil { return nil, err } @@ -617,14 +606,14 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val if col.SQLType.IsText() { bytes, err := json.Marshal(fieldValue.Interface()) if err != nil { - session.Engine.logger.Error(err) + session.engine.logger.Error(err) return 0, err } return string(bytes), nil } else if col.SQLType.IsBlob() { bytes, err := json.Marshal(fieldValue.Interface()) if err != nil { - session.Engine.logger.Error(err) + session.engine.logger.Error(err) return 0, err } return bytes, nil @@ -633,7 +622,7 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val case reflect.Complex64, reflect.Complex128: bytes, err := json.Marshal(fieldValue.Interface()) if err != nil { - session.Engine.logger.Error(err) + session.engine.logger.Error(err) return 0, err } return string(bytes), nil @@ -645,7 +634,7 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val if col.SQLType.IsText() { bytes, err := json.Marshal(fieldValue.Interface()) if err != nil { - session.Engine.logger.Error(err) + session.engine.logger.Error(err) return 0, err } return string(bytes), nil @@ -658,7 +647,7 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val } else { bytes, err = json.Marshal(fieldValue.Interface()) if err != nil { - session.Engine.logger.Error(err) + session.engine.logger.Error(err) return 0, err } } diff --git a/session_delete.go b/session_delete.go index 0c1e705e..1d7d662c 100644 --- a/session_delete.go +++ b/session_delete.go @@ -12,26 +12,26 @@ import ( "github.com/go-xorm/core" ) -func (session *Session) cacheDelete(sqlStr string, args ...interface{}) error { - if session.Statement.RefTable == nil || - session.Tx != nil { +func (session *Session) cacheDelete(table *core.Table, tableName, sqlStr string, args ...interface{}) error { + if table == nil || + session.tx != nil { return ErrCacheFailed } - for _, filter := range session.Engine.dialect.Filters() { - sqlStr = filter.Do(sqlStr, session.Engine.dialect, session.Statement.RefTable) + for _, filter := range session.engine.dialect.Filters() { + sqlStr = filter.Do(sqlStr, session.engine.dialect, table) } - newsql := session.Statement.convertIDSQL(sqlStr) + newsql := session.statement.convertIDSQL(sqlStr) if newsql == "" { return ErrCacheFailed } - cacher := session.Engine.getCacher2(session.Statement.RefTable) - tableName := session.Statement.TableName() + cacher := session.engine.getCacher2(table) + pkColumns := table.PKColumns() ids, err := core.GetCacheSql(cacher, tableName, newsql, args) if err != nil { - resultsSlice, err := session.query(newsql, args...) + resultsSlice, err := session.queryBytes(newsql, args...) if err != nil { return err } @@ -40,7 +40,7 @@ func (session *Session) cacheDelete(sqlStr string, args ...interface{}) error { for _, data := range resultsSlice { var id int64 var pk core.PK = make([]interface{}, 0) - for _, col := range session.Statement.RefTable.PKColumns() { + for _, col := range pkColumns { if v, ok := data[col.Name]; !ok { return errors.New("no id") } else if col.SQLType.IsText() { @@ -58,35 +58,30 @@ func (session *Session) cacheDelete(sqlStr string, args ...interface{}) error { ids = append(ids, pk) } } - } /*else { - session.Engine.LogDebug("delete cache sql %v", newsql) - cacher.DelIds(tableName, genSqlKey(newsql, args)) - }*/ + } for _, id := range ids { - session.Engine.logger.Debug("[cacheDelete] delete cache obj", tableName, id) + session.engine.logger.Debug("[cacheDelete] delete cache obj:", tableName, id) sid, err := id.ToString() if err != nil { return err } cacher.DelBean(tableName, sid) } - session.Engine.logger.Debug("[cacheDelete] clear cache sql", tableName) + session.engine.logger.Debug("[cacheDelete] clear cache table:", tableName) cacher.ClearIds(tableName) return nil } // Delete records, bean's non-empty fields are conditions func (session *Session) Delete(bean interface{}) (int64, error) { - defer session.resetStatement() - if session.IsAutoClose { + if session.isAutoClose { defer session.Close() } - if err := session.Statement.setRefValue(rValue(bean)); err != nil { + if err := session.statement.setRefValue(rValue(bean)); err != nil { return 0, err } - var table = session.Statement.RefTable // handle before delete processors for _, closure := range session.beforeClosures { @@ -98,13 +93,17 @@ func (session *Session) Delete(bean interface{}) (int64, error) { processor.BeforeDelete() } - // -- - condSQL, condArgs, _ := session.Statement.genConds(bean) - if len(condSQL) == 0 && session.Statement.LimitN == 0 { + condSQL, condArgs, err := session.statement.genConds(bean) + if err != nil { + return 0, err + } + if len(condSQL) == 0 && session.statement.LimitN == 0 { return 0, ErrNeedDeletedCond } - var tableName = session.Engine.Quote(session.Statement.TableName()) + var tableNameNoQuote = session.statement.TableName() + var tableName = session.engine.Quote(tableNameNoQuote) + var table = session.statement.RefTable var deleteSQL string if len(condSQL) > 0 { deleteSQL = fmt.Sprintf("DELETE FROM %v WHERE %v", tableName, condSQL) @@ -113,15 +112,15 @@ func (session *Session) Delete(bean interface{}) (int64, error) { } var orderSQL string - if len(session.Statement.OrderStr) > 0 { - orderSQL += fmt.Sprintf(" ORDER BY %s", session.Statement.OrderStr) + if len(session.statement.OrderStr) > 0 { + orderSQL += fmt.Sprintf(" ORDER BY %s", session.statement.OrderStr) } - if session.Statement.LimitN > 0 { - orderSQL += fmt.Sprintf(" LIMIT %d", session.Statement.LimitN) + if session.statement.LimitN > 0 { + orderSQL += fmt.Sprintf(" LIMIT %d", session.statement.LimitN) } if len(orderSQL) > 0 { - switch session.Engine.dialect.DBType() { + switch session.engine.dialect.DBType() { case core.POSTGRES: inSQL := fmt.Sprintf("ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQL) if len(condSQL) > 0 { @@ -146,7 +145,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) { var realSQL string argsForCache := make([]interface{}, 0, len(condArgs)*2) - if session.Statement.unscoped || table.DeletedColumn() == nil { // tag "deleted" is disabled + if session.statement.unscoped || table.DeletedColumn() == nil { // tag "deleted" is disabled realSQL = deleteSQL copy(argsForCache, condArgs) argsForCache = append(condArgs, argsForCache...) @@ -157,12 +156,12 @@ func (session *Session) Delete(bean interface{}) (int64, error) { deletedColumn := table.DeletedColumn() realSQL = fmt.Sprintf("UPDATE %v SET %v = ? WHERE %v", - session.Engine.Quote(session.Statement.TableName()), - session.Engine.Quote(deletedColumn.Name), + session.engine.Quote(session.statement.TableName()), + session.engine.Quote(deletedColumn.Name), condSQL) if len(orderSQL) > 0 { - switch session.Engine.dialect.DBType() { + switch session.engine.dialect.DBType() { case core.POSTGRES: inSQL := fmt.Sprintf("ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQL) if len(condSQL) > 0 { @@ -185,12 +184,12 @@ func (session *Session) Delete(bean interface{}) (int64, error) { } } - // !oinume! Insert NowTime to the head of session.Statement.Params + // !oinume! Insert NowTime to the head of session.statement.Params condArgs = append(condArgs, "") paramsLen := len(condArgs) copy(condArgs[1:paramsLen], condArgs[0:paramsLen-1]) - val, t := session.Engine.NowTime2(deletedColumn.SQLType.Name) + val, t := session.engine.NowTime2(deletedColumn.SQLType.Name) condArgs[0] = val var colName = deletedColumn.Name @@ -200,17 +199,18 @@ func (session *Session) Delete(bean interface{}) (int64, error) { }) } - if cacher := session.Engine.getCacher2(session.Statement.RefTable); cacher != nil && session.Statement.UseCache { - session.cacheDelete(deleteSQL, argsForCache...) + if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache { + session.cacheDelete(table, tableNameNoQuote, deleteSQL, argsForCache...) } + session.statement.RefTable = table res, err := session.exec(realSQL, condArgs...) if err != nil { return 0, err } // handle after delete processors - if session.IsAutoCommit { + if session.isAutoCommit { for _, closure := range session.afterClosures { closure(bean) } diff --git a/session_delete_test.go b/session_delete_test.go new file mode 100644 index 00000000..adabb269 --- /dev/null +++ b/session_delete_test.go @@ -0,0 +1,222 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestDelete(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type UserinfoDelete struct { + Uid int64 `xorm:"id pk not null autoincr"` + IsMan bool + } + + assert.NoError(t, testEngine.Sync2(new(UserinfoDelete))) + + user := UserinfoDelete{Uid: 1} + cnt, err := testEngine.Insert(&user) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + cnt, err = testEngine.Delete(&UserinfoDelete{Uid: user.Uid}) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + user.Uid = 0 + user.IsMan = true + has, err := testEngine.ID(1).Get(&user) + assert.NoError(t, err) + assert.False(t, has) + + cnt, err = testEngine.Insert(&user) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + cnt, err = testEngine.Where("id=?", user.Uid).Delete(&UserinfoDelete{}) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + user.Uid = 0 + user.IsMan = true + has, err = testEngine.ID(2).Get(&user) + assert.NoError(t, err) + assert.False(t, has) +} + +func TestDeleted(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type Deleted struct { + Id int64 `xorm:"pk"` + Name string + DeletedAt time.Time `xorm:"deleted"` + } + + err := testEngine.DropTables(&Deleted{}) + assert.NoError(t, err) + + err = testEngine.CreateTables(&Deleted{}) + assert.NoError(t, err) + + _, err = testEngine.InsertOne(&Deleted{Id: 1, Name: "11111"}) + assert.NoError(t, err) + + _, err = testEngine.InsertOne(&Deleted{Id: 2, Name: "22222"}) + assert.NoError(t, err) + + _, err = testEngine.InsertOne(&Deleted{Id: 3, Name: "33333"}) + assert.NoError(t, err) + + // Test normal Find() + var records1 []Deleted + err = testEngine.Where("`"+testEngine.ColumnMapper.Obj2Table("Id")+"` > 0").Find(&records1, &Deleted{}) + assert.EqualValues(t, 3, len(records1)) + + // Test normal Get() + record1 := &Deleted{} + has, err := testEngine.ID(1).Get(record1) + assert.NoError(t, err) + assert.True(t, has) + + // Test Delete() with deleted + affected, err := testEngine.ID(1).Delete(&Deleted{}) + assert.NoError(t, err) + assert.EqualValues(t, 1, affected) + + has, err = testEngine.ID(1).Get(&Deleted{}) + assert.NoError(t, err) + assert.False(t, has) + + var records2 []Deleted + err = testEngine.Where("`" + testEngine.ColumnMapper.Obj2Table("Id") + "` > 0").Find(&records2) + assert.NoError(t, err) + assert.EqualValues(t, 2, len(records2)) + + // Test no rows affected after Delete() again. + affected, err = testEngine.ID(1).Delete(&Deleted{}) + assert.NoError(t, err) + assert.EqualValues(t, 0, affected) + + // Deleted.DeletedAt must not be updated. + affected, err = testEngine.ID(2).Update(&Deleted{Name: "2", DeletedAt: time.Now()}) + assert.NoError(t, err) + assert.EqualValues(t, 1, affected) + + record2 := &Deleted{} + has, err = testEngine.ID(2).Get(record2) + assert.NoError(t, err) + assert.True(t, record2.DeletedAt.IsZero()) + + // Test find all records whatever `deleted`. + var unscopedRecords1 []Deleted + err = testEngine.Unscoped().Where("`"+testEngine.ColumnMapper.Obj2Table("Id")+"` > 0").Find(&unscopedRecords1, &Deleted{}) + assert.NoError(t, err) + assert.EqualValues(t, 3, len(unscopedRecords1)) + + // Delete() must really delete a record with Unscoped() + affected, err = testEngine.Unscoped().ID(1).Delete(&Deleted{}) + assert.NoError(t, err) + assert.EqualValues(t, 1, affected) + + var unscopedRecords2 []Deleted + err = testEngine.Unscoped().Where("`"+testEngine.ColumnMapper.Obj2Table("Id")+"` > 0").Find(&unscopedRecords2, &Deleted{}) + assert.NoError(t, err) + assert.EqualValues(t, 2, len(unscopedRecords2)) + + var records3 []Deleted + err = testEngine.Where("`"+testEngine.ColumnMapper.Obj2Table("Id")+"` > 0").And("`"+testEngine.ColumnMapper.Obj2Table("Id")+"`> 1"). + Or("`"+testEngine.ColumnMapper.Obj2Table("Id")+"` = ?", 3).Find(&records3) + assert.NoError(t, err) + assert.EqualValues(t, 2, len(records3)) +} + +func TestCacheDelete(t *testing.T) { + assert.NoError(t, prepareEngine()) + + oldCacher := testEngine.Cacher + cacher := NewLRUCacher(NewMemoryStore(), 1000) + testEngine.SetDefaultCacher(cacher) + + type CacheDeleteStruct struct { + Id int64 + } + + err := testEngine.CreateTables(&CacheDeleteStruct{}) + assert.NoError(t, err) + + _, err = testEngine.Insert(&CacheDeleteStruct{}) + assert.NoError(t, err) + + aff, err := testEngine.Delete(&CacheDeleteStruct{ + Id: 1, + }) + assert.NoError(t, err) + assert.EqualValues(t, aff, 1) + + aff, err = testEngine.Unscoped().Delete(&CacheDeleteStruct{ + Id: 1, + }) + assert.NoError(t, err) + assert.EqualValues(t, aff, 0) + + testEngine.SetDefaultCacher(oldCacher) +} + +func TestUnscopeDelete(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type UnscopeDeleteStruct struct { + Id int64 + Name string + DeletedAt time.Time `xorm:"deleted"` + } + + assertSync(t, new(UnscopeDeleteStruct)) + + cnt, err := testEngine.Insert(&UnscopeDeleteStruct{ + Name: "test", + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var nowUnix = time.Now().Unix() + var s UnscopeDeleteStruct + cnt, err = testEngine.ID(1).Delete(&s) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + assert.EqualValues(t, nowUnix, s.DeletedAt.Unix()) + + var s1 UnscopeDeleteStruct + has, err := testEngine.ID(1).Get(&s1) + assert.NoError(t, err) + assert.False(t, has) + + var s2 UnscopeDeleteStruct + has, err = testEngine.ID(1).Unscoped().Get(&s2) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, "test", s2.Name) + assert.EqualValues(t, nowUnix, s2.DeletedAt.Unix()) + + cnt, err = testEngine.ID(1).Unscoped().Delete(new(UnscopeDeleteStruct)) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var s3 UnscopeDeleteStruct + has, err = testEngine.ID(1).Get(&s3) + assert.NoError(t, err) + assert.False(t, has) + + var s4 UnscopeDeleteStruct + has, err = testEngine.ID(1).Unscoped().Get(&s4) + assert.NoError(t, err) + assert.False(t, has) +} diff --git a/session_exist.go b/session_exist.go new file mode 100644 index 00000000..049c1ddf --- /dev/null +++ b/session_exist.go @@ -0,0 +1,77 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import ( + "errors" + "fmt" + "reflect" + + "github.com/go-xorm/builder" +) + +// Exist returns true if the record exist otherwise return false +func (session *Session) Exist(bean ...interface{}) (bool, error) { + if session.isAutoClose { + defer session.Close() + } + + var sqlStr string + var args []interface{} + var err error + + if session.statement.RawSQL == "" { + if len(bean) == 0 { + tableName := session.statement.TableName() + if len(tableName) <= 0 { + return false, ErrTableNotFound + } + + if session.statement.cond.IsValid() { + condSQL, condArgs, err := builder.ToSQL(session.statement.cond) + if err != nil { + return false, err + } + + sqlStr = fmt.Sprintf("SELECT * FROM %s WHERE %s LIMIT 1", tableName, condSQL) + args = condArgs + } else { + sqlStr = fmt.Sprintf("SELECT * FROM %s LIMIT 1", tableName) + args = []interface{}{} + } + } else { + beanValue := reflect.ValueOf(bean[0]) + if beanValue.Kind() != reflect.Ptr { + return false, errors.New("needs a pointer") + } + + if beanValue.Elem().Kind() == reflect.Struct { + if err := session.statement.setRefValue(beanValue.Elem()); err != nil { + return false, err + } + } + + if len(session.statement.TableName()) <= 0 { + return false, ErrTableNotFound + } + session.statement.Limit(1) + sqlStr, args, err = session.statement.genGetSQL(bean[0]) + if err != nil { + return false, err + } + } + } else { + sqlStr = session.statement.RawSQL + args = session.statement.RawParams + } + + rows, err := session.queryRows(sqlStr, args...) + if err != nil { + return false, err + } + defer rows.Close() + + return rows.Next(), nil +} diff --git a/session_exist_test.go b/session_exist_test.go new file mode 100644 index 00000000..857bf4a1 --- /dev/null +++ b/session_exist_test.go @@ -0,0 +1,76 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestExistStruct(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type RecordExist struct { + Id int64 + Name string + } + + assertSync(t, new(RecordExist)) + + has, err := testEngine.Exist(new(RecordExist)) + assert.NoError(t, err) + assert.False(t, has) + + cnt, err := testEngine.Insert(&RecordExist{ + Name: "test1", + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + has, err = testEngine.Exist(new(RecordExist)) + assert.NoError(t, err) + assert.True(t, has) + + has, err = testEngine.Exist(&RecordExist{ + Name: "test1", + }) + assert.NoError(t, err) + assert.True(t, has) + + has, err = testEngine.Exist(&RecordExist{ + Name: "test2", + }) + assert.NoError(t, err) + assert.False(t, has) + + has, err = testEngine.Where("name = ?", "test1").Exist(&RecordExist{}) + assert.NoError(t, err) + assert.True(t, has) + + has, err = testEngine.Where("name = ?", "test2").Exist(&RecordExist{}) + assert.NoError(t, err) + assert.False(t, has) + + has, err = testEngine.SQL("select * from record_exist where name = ?", "test1").Exist() + assert.NoError(t, err) + assert.True(t, has) + + has, err = testEngine.SQL("select * from record_exist where name = ?", "test2").Exist() + assert.NoError(t, err) + assert.False(t, has) + + has, err = testEngine.Table("record_exist").Exist() + assert.NoError(t, err) + assert.True(t, has) + + has, err = testEngine.Table("record_exist").Where("name = ?", "test1").Exist() + assert.NoError(t, err) + assert.True(t, has) + + has, err = testEngine.Table("record_exist").Where("name = ?", "test2").Exist() + assert.NoError(t, err) + assert.False(t, has) +} diff --git a/session_find.go b/session_find.go index 16c6ff4f..05ec724f 100644 --- a/session_find.go +++ b/session_find.go @@ -8,7 +8,6 @@ import ( "errors" "fmt" "reflect" - "strconv" "strings" "github.com/go-xorm/builder" @@ -24,11 +23,13 @@ const ( // are conditions. beans could be []Struct, []*Struct, map[int64]Struct // map[int64]*Struct func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) error { - defer session.resetStatement() - if session.IsAutoClose { + if session.isAutoClose { defer session.Close() } + return session.find(rowsSlicePtr, condiBean...) +} +func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) error { sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) if sliceValue.Kind() != reflect.Slice && sliceValue.Kind() != reflect.Map { return errors.New("needs a pointer to a slice or a map") @@ -37,11 +38,11 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) sliceElementType := sliceValue.Type().Elem() var tp = tpStruct - if session.Statement.RefTable == nil { + if session.statement.RefTable == nil { if sliceElementType.Kind() == reflect.Ptr { if sliceElementType.Elem().Kind() == reflect.Struct { pv := reflect.New(sliceElementType.Elem()) - if err := session.Statement.setRefValue(pv.Elem()); err != nil { + if err := session.statement.setRefValue(pv.Elem()); err != nil { return err } } else { @@ -49,7 +50,7 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) } } else if sliceElementType.Kind() == reflect.Struct { pv := reflect.New(sliceElementType) - if err := session.Statement.setRefValue(pv.Elem()); err != nil { + if err := session.statement.setRefValue(pv.Elem()); err != nil { return err } } else { @@ -57,61 +58,59 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) } } - var table = session.Statement.RefTable + var table = session.statement.RefTable - var addedTableName = (len(session.Statement.JoinStr) > 0) + var addedTableName = (len(session.statement.JoinStr) > 0) var autoCond builder.Cond if tp == tpStruct { - if !session.Statement.noAutoCondition && len(condiBean) > 0 { + if !session.statement.noAutoCondition && len(condiBean) > 0 { var err error - autoCond, err = session.Statement.buildConds(table, condiBean[0], true, true, false, true, addedTableName) + autoCond, err = session.statement.buildConds(table, condiBean[0], true, true, false, true, addedTableName) if err != nil { - panic(err) + return err } } else { // !oinume! Add " IS NULL" to WHERE whatever condiBean is given. // See https://github.com/go-xorm/xorm/issues/179 - if col := table.DeletedColumn(); col != nil && !session.Statement.unscoped { // tag "deleted" is enabled - var colName = session.Engine.Quote(col.Name) + if col := table.DeletedColumn(); col != nil && !session.statement.unscoped { // tag "deleted" is enabled + var colName = session.engine.Quote(col.Name) if addedTableName { - var nm = session.Statement.TableName() - if len(session.Statement.TableAlias) > 0 { - nm = session.Statement.TableAlias + var nm = session.statement.TableName() + if len(session.statement.TableAlias) > 0 { + nm = session.statement.TableAlias } - colName = session.Engine.Quote(nm) + "." + colName - } - if session.Engine.dialect.DBType() == core.MSSQL { - autoCond = builder.IsNull{colName} - } else { - autoCond = builder.IsNull{colName}.Or(builder.Eq{colName: "0001-01-01 00:00:00"}) + colName = session.engine.Quote(nm) + "." + colName } + + autoCond = session.engine.CondDeleted(colName) } } } var sqlStr string var args []interface{} - if session.Statement.RawSQL == "" { - if len(session.Statement.TableName()) <= 0 { + var err error + if session.statement.RawSQL == "" { + if len(session.statement.TableName()) <= 0 { return ErrTableNotFound } - var columnStr = session.Statement.ColumnStr - if len(session.Statement.selectStr) > 0 { - columnStr = session.Statement.selectStr + var columnStr = session.statement.ColumnStr + if len(session.statement.selectStr) > 0 { + columnStr = session.statement.selectStr } else { - if session.Statement.JoinStr == "" { + if session.statement.JoinStr == "" { if columnStr == "" { - if session.Statement.GroupByStr != "" { - columnStr = session.Statement.Engine.Quote(strings.Replace(session.Statement.GroupByStr, ",", session.Engine.Quote(","), -1)) + if session.statement.GroupByStr != "" { + columnStr = session.statement.Engine.Quote(strings.Replace(session.statement.GroupByStr, ",", session.engine.Quote(","), -1)) } else { - columnStr = session.Statement.genColumnStr() + columnStr = session.statement.genColumnStr() } } } else { if columnStr == "" { - if session.Statement.GroupByStr != "" { - columnStr = session.Statement.Engine.Quote(strings.Replace(session.Statement.GroupByStr, ",", session.Engine.Quote(","), -1)) + if session.statement.GroupByStr != "" { + columnStr = session.statement.Engine.Quote(strings.Replace(session.statement.GroupByStr, ",", session.engine.Quote(","), -1)) } else { columnStr = "*" } @@ -122,31 +121,37 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) } } - condSQL, condArgs, _ := builder.ToSQL(session.Statement.cond.And(autoCond)) + session.statement.cond = session.statement.cond.And(autoCond) + condSQL, condArgs, err := builder.ToSQL(session.statement.cond) + if err != nil { + return err + } - args = append(session.Statement.joinArgs, condArgs...) - sqlStr = session.Statement.genSelectSQL(columnStr, condSQL) + args = append(session.statement.joinArgs, condArgs...) + sqlStr, err = session.statement.genSelectSQL(columnStr, condSQL) + if err != nil { + return err + } // for mssql and use limit qs := strings.Count(sqlStr, "?") if len(args)*2 == qs { args = append(args, args...) } } else { - sqlStr = session.Statement.RawSQL - args = session.Statement.RawParams + sqlStr = session.statement.RawSQL + args = session.statement.RawParams } - var err error if session.canCache() { - if cacher := session.Engine.getCacher2(table); cacher != nil && - !session.Statement.IsDistinct && - !session.Statement.unscoped { + if cacher := session.engine.getCacher2(table); cacher != nil && + !session.statement.IsDistinct && + !session.statement.unscoped { err = session.cacheFind(sliceElementType, sqlStr, rowsSlicePtr, args...) if err != ErrCacheFailed { return err } err = nil // !nashtsai! reset err to nil for ErrCacheFailed - session.Engine.logger.Warn("Cache Find Failed") + session.engine.logger.Warn("Cache Find Failed") } } @@ -154,21 +159,13 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) } func (session *Session) noCacheFind(table *core.Table, containerValue reflect.Value, sqlStr string, args ...interface{}) error { - var rawRows *core.Rows - var err error - - session.queryPreprocess(&sqlStr, args...) - if session.IsAutoCommit { - _, rawRows, err = session.innerQuery(sqlStr, args...) - } else { - rawRows, err = session.Tx.Query(sqlStr, args...) - } + rows, err := session.queryRows(sqlStr, args...) if err != nil { return err } - defer rawRows.Close() + defer rows.Close() - fields, err := rawRows.Columns() + fields, err := rows.Columns() if err != nil { return err } @@ -238,24 +235,24 @@ func (session *Session) noCacheFind(table *core.Table, containerValue reflect.Va if elemType.Kind() == reflect.Struct { var newValue = newElemFunc(fields) dataStruct := rValue(newValue.Interface()) - tb, err := session.Engine.autoMapType(dataStruct) + tb, err := session.engine.autoMapType(dataStruct) if err != nil { return err } - return session.rows2Beans(rawRows, fields, len(fields), tb, newElemFunc, containerValueSetFunc) + return session.rows2Beans(rows, fields, len(fields), tb, newElemFunc, containerValueSetFunc) } - for rawRows.Next() { + for rows.Next() { var newValue = newElemFunc(fields) bean := newValue.Interface() switch elemType.Kind() { case reflect.Slice: - err = rawRows.ScanSlice(bean) + err = rows.ScanSlice(bean) case reflect.Map: - err = rawRows.ScanMap(bean) + err = rows.ScanMap(bean) default: - err = rawRows.Scan(bean) + err = rows.Scan(bean) } if err != nil { @@ -286,22 +283,21 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in return ErrCacheFailed } - for _, filter := range session.Engine.dialect.Filters() { - sqlStr = filter.Do(sqlStr, session.Engine.dialect, session.Statement.RefTable) + for _, filter := range session.engine.dialect.Filters() { + sqlStr = filter.Do(sqlStr, session.engine.dialect, session.statement.RefTable) } - newsql := session.Statement.convertIDSQL(sqlStr) + newsql := session.statement.convertIDSQL(sqlStr) if newsql == "" { return ErrCacheFailed } - tableName := session.Statement.TableName() - - table := session.Statement.RefTable - cacher := session.Engine.getCacher2(table) + tableName := session.statement.TableName() + table := session.statement.RefTable + cacher := session.engine.getCacher2(table) ids, err := core.GetCacheSql(cacher, tableName, newsql, args) if err != nil { - rows, err := session.DB().Query(newsql, args...) + rows, err := session.queryRows(newsql, args...) if err != nil { return err } @@ -312,7 +308,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in for rows.Next() { i++ if i > 500 { - session.Engine.logger.Debug("[cacheFind] ids length > 500, no cache") + session.engine.logger.Debug("[cacheFind] ids length > 500, no cache") return ErrCacheFailed } var res = make([]string, len(table.PrimaryKeys)) @@ -320,32 +316,24 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in if err != nil { return err } - var pk core.PK = make([]interface{}, len(table.PrimaryKeys)) for i, col := range table.PKColumns() { - if col.SQLType.IsNumeric() { - n, err := strconv.ParseInt(res[i], 10, 64) - if err != nil { - return err - } - pk[i] = n - } else if col.SQLType.IsText() { - pk[i] = res[i] - } else { - return errors.New("not supported") + pk[i], err = session.engine.idTypeAssertion(col, res[i]) + if err != nil { + return err } } ids = append(ids, pk) } - session.Engine.logger.Debug("[cacheFind] cache sql:", ids, tableName, newsql, args) + session.engine.logger.Debug("[cacheFind] cache sql:", ids, tableName, sqlStr, newsql, args) err = core.PutCacheSql(cacher, ids, tableName, newsql, args) if err != nil { return err } } else { - session.Engine.logger.Debug("[cacheFind] cache hit sql:", newsql, args) + session.engine.logger.Debug("[cacheFind] cache hit sql:", tableName, sqlStr, newsql, args) } sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) @@ -360,20 +348,20 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in return err } bean := cacher.GetBean(tableName, sid) - if bean == nil { + if bean == nil || reflect.ValueOf(bean).Elem().Type() != t { ides = append(ides, id) ididxes[sid] = idx } else { - session.Engine.logger.Debug("[cacheFind] cache hit bean:", tableName, id, bean) + session.engine.logger.Debug("[cacheFind] cache hit bean:", tableName, id, bean) - pk := session.Engine.IdOf(bean) + pk := session.engine.IdOf(bean) xid, err := pk.ToString() if err != nil { return err } if sid != xid { - session.Engine.logger.Error("[cacheFind] error cache", xid, sid, bean) + session.engine.logger.Error("[cacheFind] error cache", xid, sid, bean) return ErrCacheFailed } temps[idx] = bean @@ -381,9 +369,6 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in } if len(ides) > 0 { - newSession := session.Engine.NewSession() - defer newSession.Close() - slices := reflect.New(reflect.SliceOf(t)) beans := slices.Interface() @@ -393,18 +378,18 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in ff = append(ff, ie[0]) } - newSession.In("`"+table.PrimaryKeys[0]+"`", ff...) + session.In("`"+table.PrimaryKeys[0]+"`", ff...) } else { for _, ie := range ides { cond := builder.NewCond() for i, name := range table.PrimaryKeys { cond = cond.And(builder.Eq{"`" + name + "`": ie[i]}) } - newSession.Or(cond) + session.Or(cond) } } - err = newSession.NoCache().Find(beans) + err = session.NoCache().Table(tableName).find(beans) if err != nil { return err } @@ -415,7 +400,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in if rv.Kind() != reflect.Ptr { rv = rv.Addr() } - id, err := session.Engine.idOfV(rv) + id, err := session.engine.idOfV(rv) if err != nil { return err } @@ -426,7 +411,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in bean := rv.Interface() temps[ididxes[sid]] = bean - session.Engine.logger.Debug("[cacheFind] cache bean:", tableName, id, bean, temps) + session.engine.logger.Debug("[cacheFind] cache bean:", tableName, id, bean, temps) cacher.PutBean(tableName, sid, bean) } } @@ -434,7 +419,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in for j := 0; j < len(temps); j++ { bean := temps[j] if bean == nil { - session.Engine.logger.Warn("[cacheFind] cache no hit:", tableName, ids[j], temps) + session.engine.logger.Warn("[cacheFind] cache no hit:", tableName, ids[j], temps) // return errors.New("cache error") // !nashtsai! no need to return error, but continue instead continue } diff --git a/session_find_test.go b/session_find_test.go index ef60b68b..9739bc44 100644 --- a/session_find_test.go +++ b/session_find_test.go @@ -5,8 +5,11 @@ package xorm import ( + "errors" + "fmt" "testing" + "github.com/go-xorm/core" "github.com/stretchr/testify/assert" ) @@ -57,3 +60,431 @@ func TestJoinLimit(t *testing.T) { Find(&salaries) assert.NoError(t, err) } + +func assertSync(t *testing.T, beans ...interface{}) { + for _, bean := range beans { + assert.NoError(t, testEngine.DropTables(bean)) + assert.NoError(t, testEngine.Sync(bean)) + } +} + +func TestWhere(t *testing.T) { + assert.NoError(t, prepareEngine()) + + assertSync(t, new(Userinfo)) + + users := make([]Userinfo, 0) + err := testEngine.Where("(id) > ?", 2).Find(&users) + if err != nil { + t.Error(err) + panic(err) + } + fmt.Println(users) + + err = testEngine.Where("(id) > ?", 2).And("(id) < ?", 10).Find(&users) + if err != nil { + t.Error(err) + panic(err) + } + fmt.Println(users) +} + +func TestFind(t *testing.T) { + assert.NoError(t, prepareEngine()) + assertSync(t, new(Userinfo)) + + users := make([]Userinfo, 0) + + err := testEngine.Find(&users) + if err != nil { + t.Error(err) + panic(err) + } + for _, user := range users { + fmt.Println(user) + } + + users2 := make([]Userinfo, 0) + userinfo := testEngine.TableMapper.Obj2Table("Userinfo") + err = testEngine.Sql("select * from " + testEngine.Quote(userinfo)).Find(&users2) + if err != nil { + t.Error(err) + panic(err) + } +} + +func TestFind2(t *testing.T) { + assert.NoError(t, prepareEngine()) + users := make([]*Userinfo, 0) + + assertSync(t, new(Userinfo)) + + err := testEngine.Find(&users) + if err != nil { + t.Error(err) + panic(err) + } + for _, user := range users { + fmt.Println(user) + } +} + +type Team struct { + Id int64 +} + +type TeamUser struct { + OrgId int64 + Uid int64 + TeamId int64 +} + +func TestFind3(t *testing.T) { + assert.NoError(t, prepareEngine()) + err := testEngine.Sync2(new(Team), new(TeamUser)) + if err != nil { + t.Error(err) + panic(err.Error()) + } + + var teams []Team + err = testEngine.Cols("`team`.id"). + Where("`team_user`.org_id=?", 1). + And("`team_user`.uid=?", 2). + Join("INNER", "`team_user`", "`team_user`.team_id=`team`.id"). + Find(&teams) + if err != nil { + t.Error(err) + panic(err.Error()) + } +} + +func TestFindMap(t *testing.T) { + assert.NoError(t, prepareEngine()) + assertSync(t, new(Userinfo)) + + users := make(map[int64]Userinfo) + err := testEngine.Find(&users) + if err != nil { + t.Error(err) + panic(err) + } + for _, user := range users { + fmt.Println(user) + } +} + +func TestFindMap2(t *testing.T) { + assert.NoError(t, prepareEngine()) + assertSync(t, new(Userinfo)) + + users := make(map[int64]*Userinfo) + err := testEngine.Find(&users) + if err != nil { + t.Error(err) + panic(err) + } + for id, user := range users { + fmt.Println(id, user) + } +} + +func TestDistinct(t *testing.T) { + assert.NoError(t, prepareEngine()) + assertSync(t, new(Userinfo)) + + _, err := testEngine.Insert(&Userinfo{ + Username: "lunny", + }) + assert.NoError(t, err) + + users := make([]Userinfo, 0) + departname := testEngine.TableMapper.Obj2Table("Departname") + err = testEngine.Distinct(departname).Find(&users) + if err != nil { + t.Error(err) + panic(err) + } + if len(users) != 1 { + t.Error(err) + panic(errors.New("should be one record")) + } + + fmt.Println(users) + + type Depart struct { + Departname string + } + + users2 := make([]Depart, 0) + err = testEngine.Distinct(departname).Table(new(Userinfo)).Find(&users2) + if err != nil { + t.Error(err) + panic(err) + } + if len(users2) != 1 { + t.Error(err) + panic(errors.New("should be one record")) + } + fmt.Println(users2) +} + +func TestOrder(t *testing.T) { + assert.NoError(t, prepareEngine()) + assertSync(t, new(Userinfo)) + + users := make([]Userinfo, 0) + err := testEngine.OrderBy("id desc").Find(&users) + if err != nil { + t.Error(err) + panic(err) + } + fmt.Println(users) + + users2 := make([]Userinfo, 0) + err = testEngine.Asc("id", "username").Desc("height").Find(&users2) + if err != nil { + t.Error(err) + panic(err) + } + fmt.Println(users2) +} + +func TestHaving(t *testing.T) { + assert.NoError(t, prepareEngine()) + assertSync(t, new(Userinfo)) + + users := make([]Userinfo, 0) + err := testEngine.GroupBy("username").Having("username='xlw'").Find(&users) + if err != nil { + t.Error(err) + panic(err) + } + fmt.Println(users) + + /*users = make([]Userinfo, 0) + err = testEngine.Cols("id, username").GroupBy("username").Having("username='xlw'").Find(&users) + if err != nil { + t.Error(err) + panic(err) + } + fmt.Println(users)*/ +} + +func TestOrderSameMapper(t *testing.T) { + assert.NoError(t, prepareEngine()) + testEngine.unMapType(rValue(new(Userinfo)).Type()) + + mapper := testEngine.TableMapper + testEngine.SetMapper(core.SameMapper{}) + + defer func() { + testEngine.unMapType(rValue(new(Userinfo)).Type()) + testEngine.SetMapper(mapper) + }() + + assertSync(t, new(Userinfo)) + + users := make([]Userinfo, 0) + err := testEngine.OrderBy("(id) desc").Find(&users) + if err != nil { + t.Error(err) + panic(err) + } + fmt.Println(users) + + users2 := make([]Userinfo, 0) + err = testEngine.Asc("(id)", "Username").Desc("Height").Find(&users2) + if err != nil { + t.Error(err) + panic(err) + } + fmt.Println(users2) +} + +func TestHavingSameMapper(t *testing.T) { + assert.NoError(t, prepareEngine()) + testEngine.unMapType(rValue(new(Userinfo)).Type()) + + mapper := testEngine.TableMapper + testEngine.SetMapper(core.SameMapper{}) + defer func() { + testEngine.unMapType(rValue(new(Userinfo)).Type()) + testEngine.SetMapper(mapper) + }() + assertSync(t, new(Userinfo)) + + users := make([]Userinfo, 0) + err := testEngine.GroupBy("`Username`").Having("`Username`='xlw'").Find(&users) + if err != nil { + t.Fatal(err) + } + fmt.Println(users) +} + +func TestFindInts(t *testing.T) { + assert.NoError(t, prepareEngine()) + assertSync(t, new(Userinfo)) + + userinfo := testEngine.TableMapper.Obj2Table("Userinfo") + var idsInt64 []int64 + err := testEngine.Table(userinfo).Cols("id").Desc("id").Find(&idsInt64) + if err != nil { + t.Fatal(err) + } + fmt.Println(idsInt64) + + var idsInt32 []int32 + err = testEngine.Table(userinfo).Cols("id").Desc("id").Find(&idsInt32) + if err != nil { + t.Fatal(err) + } + fmt.Println(idsInt32) + + var idsInt []int + err = testEngine.Table(userinfo).Cols("id").Desc("id").Find(&idsInt) + if err != nil { + t.Fatal(err) + } + fmt.Println(idsInt) + + var idsUint []uint + err = testEngine.Table(userinfo).Cols("id").Desc("id").Find(&idsUint) + if err != nil { + t.Fatal(err) + } + fmt.Println(idsUint) + + type MyInt int + var idsMyInt []MyInt + err = testEngine.Table(userinfo).Cols("id").Desc("id").Find(&idsMyInt) + if err != nil { + t.Fatal(err) + } + fmt.Println(idsMyInt) +} + +func TestFindStrings(t *testing.T) { + assert.NoError(t, prepareEngine()) + assertSync(t, new(Userinfo)) + userinfo := testEngine.TableMapper.Obj2Table("Userinfo") + username := testEngine.ColumnMapper.Obj2Table("Username") + var idsString []string + err := testEngine.Table(userinfo).Cols(username).Desc("id").Find(&idsString) + if err != nil { + t.Fatal(err) + } + fmt.Println(idsString) +} + +func TestFindMyString(t *testing.T) { + assert.NoError(t, prepareEngine()) + assertSync(t, new(Userinfo)) + userinfo := testEngine.TableMapper.Obj2Table("Userinfo") + username := testEngine.ColumnMapper.Obj2Table("Username") + + var idsMyString []MyString + err := testEngine.Table(userinfo).Cols(username).Desc("id").Find(&idsMyString) + if err != nil { + t.Fatal(err) + } + fmt.Println(idsMyString) +} + +func TestFindInterface(t *testing.T) { + assert.NoError(t, prepareEngine()) + assertSync(t, new(Userinfo)) + + userinfo := testEngine.TableMapper.Obj2Table("Userinfo") + username := testEngine.ColumnMapper.Obj2Table("Username") + var idsInterface []interface{} + err := testEngine.Table(userinfo).Cols(username).Desc("id").Find(&idsInterface) + if err != nil { + t.Fatal(err) + } + fmt.Println(idsInterface) +} + +func TestFindSliceBytes(t *testing.T) { + assert.NoError(t, prepareEngine()) + assertSync(t, new(Userinfo)) + + userinfo := testEngine.TableMapper.Obj2Table("Userinfo") + var ids [][][]byte + err := testEngine.Table(userinfo).Desc("id").Find(&ids) + if err != nil { + t.Fatal(err) + } + for _, record := range ids { + fmt.Println(record) + } +} + +func TestFindSlicePtrString(t *testing.T) { + assert.NoError(t, prepareEngine()) + assertSync(t, new(Userinfo)) + + userinfo := testEngine.TableMapper.Obj2Table("Userinfo") + var ids [][]*string + err := testEngine.Table(userinfo).Desc("id").Find(&ids) + if err != nil { + t.Fatal(err) + } + for _, record := range ids { + fmt.Println(record) + } +} + +func TestFindMapBytes(t *testing.T) { + assert.NoError(t, prepareEngine()) + assertSync(t, new(Userinfo)) + + userinfo := testEngine.TableMapper.Obj2Table("Userinfo") + var ids []map[string][]byte + err := testEngine.Table(userinfo).Desc("id").Find(&ids) + if err != nil { + t.Fatal(err) + } + for _, record := range ids { + fmt.Println(record) + } +} + +func TestFindMapPtrString(t *testing.T) { + assert.NoError(t, prepareEngine()) + assertSync(t, new(Userinfo)) + + userinfo := testEngine.TableMapper.Obj2Table("Userinfo") + var ids []map[string]*string + err := testEngine.Table(userinfo).Desc("id").Find(&ids) + assert.NoError(t, err) + for _, record := range ids { + fmt.Println(record) + } +} + +func TestFindBit(t *testing.T) { + type FindBitStruct struct { + Id int64 + Msg bool `xorm:"bit"` + } + + assert.NoError(t, prepareEngine()) + assertSync(t, new(FindBitStruct)) + + cnt, err := testEngine.Insert([]FindBitStruct{ + { + Msg: false, + }, + { + Msg: true, + }, + }) + assert.NoError(t, err) + assert.EqualValues(t, 2, cnt) + + var results = make([]FindBitStruct, 0, 2) + err = testEngine.Find(&results) + assert.NoError(t, err) + assert.EqualValues(t, 2, len(results)) +} diff --git a/session_get.go b/session_get.go index c7c03d90..1f1e61cd 100644 --- a/session_get.go +++ b/session_get.go @@ -15,39 +15,49 @@ import ( // Get retrieve one record from database, bean's non-empty fields // will be as conditions func (session *Session) Get(bean interface{}) (bool, error) { - defer session.resetStatement() - if session.IsAutoClose { + if session.isAutoClose { defer session.Close() } + return session.get(bean) +} +func (session *Session) get(bean interface{}) (bool, error) { beanValue := reflect.ValueOf(bean) if beanValue.Kind() != reflect.Ptr { - return false, errors.New("needs a pointer") + return false, errors.New("needs a pointer to a value") + } else if beanValue.Elem().Kind() == reflect.Ptr { + return false, errors.New("a pointer to a pointer is not allowed") } if beanValue.Elem().Kind() == reflect.Struct { - if err := session.Statement.setRefValue(beanValue.Elem()); err != nil { + if err := session.statement.setRefValue(beanValue.Elem()); err != nil { return false, err } } var sqlStr string var args []interface{} + var err error - if session.Statement.RawSQL == "" { - if len(session.Statement.TableName()) <= 0 { + if session.statement.RawSQL == "" { + if len(session.statement.TableName()) <= 0 { return false, ErrTableNotFound } - session.Statement.Limit(1) - sqlStr, args = session.Statement.genGetSQL(bean) + session.statement.Limit(1) + sqlStr, args, err = session.statement.genGetSQL(bean) + if err != nil { + return false, err + } } else { - sqlStr = session.Statement.RawSQL - args = session.Statement.RawParams + sqlStr = session.statement.RawSQL + args = session.statement.RawParams } + table := session.statement.RefTable + if session.canCache() && beanValue.Elem().Kind() == reflect.Struct { - if cacher := session.Engine.getCacher2(session.Statement.RefTable); cacher != nil && - !session.Statement.unscoped { + if cacher := session.engine.getCacher2(table); cacher != nil && + !session.statement.unscoped { has, err := session.cacheGet(bean, sqlStr, args...) if err != ErrCacheFailed { return has, err @@ -55,49 +65,46 @@ func (session *Session) Get(bean interface{}) (bool, error) { } } - return session.nocacheGet(beanValue.Elem().Kind(), bean, sqlStr, args...) + return session.nocacheGet(beanValue.Elem().Kind(), table, bean, sqlStr, args...) } -func (session *Session) nocacheGet(beanKind reflect.Kind, bean interface{}, sqlStr string, args ...interface{}) (bool, error) { - session.queryPreprocess(&sqlStr, args...) - - var rawRows *core.Rows - var err error - if session.IsAutoCommit { - _, rawRows, err = session.innerQuery(sqlStr, args...) - } else { - rawRows, err = session.Tx.Query(sqlStr, args...) - } +func (session *Session) nocacheGet(beanKind reflect.Kind, table *core.Table, bean interface{}, sqlStr string, args ...interface{}) (bool, error) { + rows, err := session.queryRows(sqlStr, args...) if err != nil { return false, err } + defer rows.Close() - defer rawRows.Close() + if !rows.Next() { + return false, nil + } - if rawRows.Next() { - switch beanKind { - case reflect.Struct: - fields, err := rawRows.Columns() - if err != nil { - // WARN: Alougth rawRows return true, but get fields failed - return true, err - } - dataStruct := rValue(bean) - if err := session.Statement.setRefValue(dataStruct); err != nil { - return false, err - } - _, err = session.row2Bean(rawRows, fields, len(fields), bean, &dataStruct, session.Statement.RefTable) - case reflect.Slice: - err = rawRows.ScanSlice(bean) - case reflect.Map: - err = rawRows.ScanMap(bean) - default: - err = rawRows.Scan(bean) + switch beanKind { + case reflect.Struct: + fields, err := rows.Columns() + if err != nil { + // WARN: Alougth rows return true, but get fields failed + return true, err } - return true, err + scanResults, err := session.row2Slice(rows, fields, len(fields), bean) + if err != nil { + return false, err + } + // close it before covert data + rows.Close() + + dataStruct := rValue(bean) + _, err = session.slice2Bean(scanResults, fields, len(fields), bean, &dataStruct, table) + case reflect.Slice: + err = rows.ScanSlice(bean) + case reflect.Map: + err = rows.ScanMap(bean) + default: + err = rows.Scan(bean) } - return false, nil + + return true, err } func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interface{}) (has bool, err error) { @@ -106,22 +113,22 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf return false, ErrCacheFailed } - for _, filter := range session.Engine.dialect.Filters() { - sqlStr = filter.Do(sqlStr, session.Engine.dialect, session.Statement.RefTable) + for _, filter := range session.engine.dialect.Filters() { + sqlStr = filter.Do(sqlStr, session.engine.dialect, session.statement.RefTable) } - newsql := session.Statement.convertIDSQL(sqlStr) + newsql := session.statement.convertIDSQL(sqlStr) if newsql == "" { return false, ErrCacheFailed } - cacher := session.Engine.getCacher2(session.Statement.RefTable) - tableName := session.Statement.TableName() - session.Engine.logger.Debug("[cacheGet] find sql:", newsql, args) + cacher := session.engine.getCacher2(session.statement.RefTable) + tableName := session.statement.TableName() + session.engine.logger.Debug("[cacheGet] find sql:", newsql, args) + table := session.statement.RefTable ids, err := core.GetCacheSql(cacher, tableName, newsql, args) - table := session.Statement.RefTable if err != nil { var res = make([]string, len(table.PrimaryKeys)) - rows, err := session.DB().Query(newsql, args...) + rows, err := session.NoCache().queryRows(newsql, args...) if err != nil { return false, err } @@ -152,19 +159,19 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf } ids = []core.PK{pk} - session.Engine.logger.Debug("[cacheGet] cache ids:", newsql, ids) + session.engine.logger.Debug("[cacheGet] cache ids:", newsql, ids) err = core.PutCacheSql(cacher, ids, tableName, newsql, args) if err != nil { return false, err } } else { - session.Engine.logger.Debug("[cacheGet] cache hit sql:", newsql) + session.engine.logger.Debug("[cacheGet] cache hit sql:", newsql, ids) } if len(ids) > 0 { structValue := reflect.Indirect(reflect.ValueOf(bean)) id := ids[0] - session.Engine.logger.Debug("[cacheGet] get bean:", tableName, id) + session.engine.logger.Debug("[cacheGet] get bean:", tableName, id) sid, err := id.ToString() if err != nil { return false, err @@ -172,15 +179,15 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf cacheBean := cacher.GetBean(tableName, sid) if cacheBean == nil { cacheBean = bean - has, err = session.nocacheGet(reflect.Struct, cacheBean, sqlStr, args...) + has, err = session.nocacheGet(reflect.Struct, table, cacheBean, sqlStr, args...) if err != nil || !has { return has, err } - session.Engine.logger.Debug("[cacheGet] cache bean:", tableName, id, cacheBean) + session.engine.logger.Debug("[cacheGet] cache bean:", tableName, id, cacheBean) cacher.PutBean(tableName, sid, cacheBean) } else { - session.Engine.logger.Debug("[cacheGet] cache hit bean:", tableName, id, cacheBean) + session.engine.logger.Debug("[cacheGet] cache hit bean:", tableName, id, cacheBean) has = true } structValue.Set(reflect.Indirect(reflect.ValueOf(cacheBean))) diff --git a/session_get_test.go b/session_get_test.go index b1fb6bc9..91006365 100644 --- a/session_get_test.go +++ b/session_get_test.go @@ -9,6 +9,7 @@ import ( "testing" "time" + "github.com/go-xorm/core" "github.com/stretchr/testify/assert" ) @@ -45,6 +46,15 @@ func TestGetVar(t *testing.T) { assert.Equal(t, true, has) assert.Equal(t, 28, age) + var age2 int64 + has, err = testEngine.Table("get_var").Cols("age"). + Where("age > ?", 20). + And("age < ?", 30). + Get(&age2) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.EqualValues(t, 28, age2) + var money float64 has, err = testEngine.Table("get_var").Cols("money").Get(&money) assert.NoError(t, err) @@ -61,15 +71,18 @@ func TestGetVar(t *testing.T) { assert.Equal(t, "28", valuesString["age"]) assert.Equal(t, "1.5", valuesString["money"]) - var valuesInter = make(map[string]interface{}) - has, err = testEngine.Table("get_var").Where("id = ?", 1).Select("*").Get(&valuesInter) - assert.NoError(t, err) - assert.Equal(t, true, has) - assert.Equal(t, 5, len(valuesInter)) - assert.EqualValues(t, 1, valuesInter["id"]) - assert.Equal(t, "hi", fmt.Sprintf("%s", valuesInter["msg"])) - assert.EqualValues(t, 28, valuesInter["age"]) - assert.Equal(t, "1.5", fmt.Sprintf("%v", valuesInter["money"])) + // for mymysql driver, interface{} will be []byte, so ignore it currently + if testEngine.dialect.DriverName() != "mymysql" { + var valuesInter = make(map[string]interface{}) + has, err = testEngine.Table("get_var").Where("id = ?", 1).Select("*").Get(&valuesInter) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, 5, len(valuesInter)) + assert.EqualValues(t, 1, valuesInter["id"]) + assert.Equal(t, "hi", fmt.Sprintf("%s", valuesInter["msg"])) + assert.EqualValues(t, 28, valuesInter["age"]) + assert.Equal(t, "1.5", fmt.Sprintf("%v", valuesInter["money"])) + } var valuesSliceString = make([]string, 5) has, err = testEngine.Table("get_var").Get(&valuesSliceString) @@ -99,3 +112,85 @@ func TestGetVar(t *testing.T) { assert.NoError(t, err) assert.Equal(t, "1.5", fmt.Sprintf("%v", v4)) } + +func TestGetStruct(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type UserinfoGet struct { + Uid int `xorm:"pk autoincr"` + IsMan bool + } + + assert.NoError(t, testEngine.Sync(new(UserinfoGet))) + + var err error + if testEngine.dialect.DBType() == core.MSSQL { + _, err = testEngine.Exec("SET IDENTITY_INSERT userinfo_get ON") + assert.NoError(t, err) + } + cnt, err := testEngine.Insert(&UserinfoGet{Uid: 2}) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + user := UserinfoGet{Uid: 2} + has, err := testEngine.Get(&user) + assert.NoError(t, err) + assert.True(t, has) + + type NoIdUser struct { + User string `xorm:"unique"` + Remain int64 + Total int64 + } + + assert.NoError(t, testEngine.Sync(&NoIdUser{})) + + userCol := testEngine.ColumnMapper.Obj2Table("User") + _, err = testEngine.Where("`"+userCol+"` = ?", "xlw").Delete(&NoIdUser{}) + assert.NoError(t, err) + + cnt, err = testEngine.Insert(&NoIdUser{"xlw", 20, 100}) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + noIdUser := new(NoIdUser) + has, err = testEngine.Where("`"+userCol+"` = ?", "xlw").Get(noIdUser) + assert.NoError(t, err) + assert.True(t, has) +} + +func TestGetSlice(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type UserinfoSlice struct { + Uid int `xorm:"pk autoincr"` + IsMan bool + } + + assertSync(t, new(UserinfoSlice)) + + var users []UserinfoSlice + has, err := testEngine.Get(&users) + assert.False(t, has) + assert.Error(t, err) +} + +func TestGetError(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type GetError struct { + Uid int `xorm:"pk autoincr"` + IsMan bool + } + + assertSync(t, new(GetError)) + + var info = new(GetError) + has, err := testEngine.Get(&info) + assert.False(t, has) + assert.Error(t, err) + + has, err = testEngine.Get(info) + assert.False(t, has) + assert.NoError(t, err) +} diff --git a/session_insert.go b/session_insert.go index 2c8ad782..705f6a89 100644 --- a/session_insert.go +++ b/session_insert.go @@ -19,17 +19,16 @@ func (session *Session) Insert(beans ...interface{}) (int64, error) { var affected int64 var err error - if session.IsAutoClose { + if session.isAutoClose { defer session.Close() } - defer session.resetStatement() for _, bean := range beans { sliceValue := reflect.Indirect(reflect.ValueOf(bean)) if sliceValue.Kind() == reflect.Slice { size := sliceValue.Len() if size > 0 { - if session.Engine.SupportInsertMany() { + if session.engine.SupportInsertMany() { cnt, err := session.innerInsertMulti(bean) if err != nil { return affected, err @@ -67,15 +66,15 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error return 0, errors.New("could not insert a empty slice") } - if err := session.Statement.setRefValue(sliceValue.Index(0)); err != nil { + if err := session.statement.setRefValue(reflect.ValueOf(sliceValue.Index(0).Interface())); err != nil { return 0, err } - if len(session.Statement.TableName()) <= 0 { + if len(session.statement.TableName()) <= 0 { return 0, ErrTableNotFound } - table := session.Statement.RefTable + table := session.statement.RefTable size := sliceValue.Len() var colNames []string @@ -116,18 +115,18 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error if col.IsDeleted { continue } - if session.Statement.ColumnStr != "" { - if _, ok := getFlagForColumn(session.Statement.columnMap, col); !ok { + if session.statement.ColumnStr != "" { + if _, ok := getFlagForColumn(session.statement.columnMap, col); !ok { continue } } - if session.Statement.OmitStr != "" { - if _, ok := getFlagForColumn(session.Statement.columnMap, col); ok { + if session.statement.OmitStr != "" { + if _, ok := getFlagForColumn(session.statement.columnMap, col); ok { continue } } - if (col.IsCreated || col.IsUpdated) && session.Statement.UseAutoTime { - val, t := session.Engine.NowTime2(col.SQLType.Name) + if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime { + val, t := session.engine.NowTime2(col.SQLType.Name) args = append(args, val) var colName = col.Name @@ -135,7 +134,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error col := table.GetColumn(colName) setColumnTime(bean, col, t) }) - } else if col.IsVersion && session.Statement.checkVersion { + } else if col.IsVersion && session.statement.checkVersion { args = append(args, 1) var colName = col.Name session.afterClosures = append(session.afterClosures, func(bean interface{}) { @@ -171,18 +170,18 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error if col.IsDeleted { continue } - if session.Statement.ColumnStr != "" { - if _, ok := getFlagForColumn(session.Statement.columnMap, col); !ok { + if session.statement.ColumnStr != "" { + if _, ok := getFlagForColumn(session.statement.columnMap, col); !ok { continue } } - if session.Statement.OmitStr != "" { - if _, ok := getFlagForColumn(session.Statement.columnMap, col); ok { + if session.statement.OmitStr != "" { + if _, ok := getFlagForColumn(session.statement.columnMap, col); ok { continue } } - if (col.IsCreated || col.IsUpdated) && session.Statement.UseAutoTime { - val, t := session.Engine.NowTime2(col.SQLType.Name) + if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime { + val, t := session.engine.NowTime2(col.SQLType.Name) args = append(args, val) var colName = col.Name @@ -190,7 +189,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error col := table.GetColumn(colName) setColumnTime(bean, col, t) }) - } else if col.IsVersion && session.Statement.checkVersion { + } else if col.IsVersion && session.statement.checkVersion { args = append(args, 1) var colName = col.Name session.afterClosures = append(session.afterClosures, func(bean interface{}) { @@ -214,25 +213,26 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error var sql = "INSERT INTO %s (%v%v%v) VALUES (%v)" var statement string - if session.Engine.dialect.DBType() == core.ORACLE { + var tableName = session.statement.TableName() + if session.engine.dialect.DBType() == core.ORACLE { sql = "INSERT ALL INTO %s (%v%v%v) VALUES (%v) SELECT 1 FROM DUAL" temp := fmt.Sprintf(") INTO %s (%v%v%v) VALUES (", - session.Engine.Quote(session.Statement.TableName()), - session.Engine.QuoteStr(), - strings.Join(colNames, session.Engine.QuoteStr()+", "+session.Engine.QuoteStr()), - session.Engine.QuoteStr()) + session.engine.Quote(tableName), + session.engine.QuoteStr(), + strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()), + session.engine.QuoteStr()) statement = fmt.Sprintf(sql, - session.Engine.Quote(session.Statement.TableName()), - session.Engine.QuoteStr(), - strings.Join(colNames, session.Engine.QuoteStr()+", "+session.Engine.QuoteStr()), - session.Engine.QuoteStr(), + session.engine.Quote(tableName), + session.engine.QuoteStr(), + strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()), + session.engine.QuoteStr(), strings.Join(colMultiPlaces, temp)) } else { statement = fmt.Sprintf(sql, - session.Engine.Quote(session.Statement.TableName()), - session.Engine.QuoteStr(), - strings.Join(colNames, session.Engine.QuoteStr()+", "+session.Engine.QuoteStr()), - session.Engine.QuoteStr(), + session.engine.Quote(tableName), + session.engine.QuoteStr(), + strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()), + session.engine.QuoteStr(), strings.Join(colMultiPlaces, "),(")) } res, err := session.exec(statement, args...) @@ -240,8 +240,8 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error return 0, err } - if cacher := session.Engine.getCacher2(table); cacher != nil && session.Statement.UseCache { - session.cacheInsert(session.Statement.TableName()) + if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache { + session.cacheInsert(table, tableName) } lenAfterClosures := len(session.afterClosures) @@ -249,7 +249,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error elemValue := reflect.Indirect(sliceValue.Index(i)).Addr().Interface() // handle AfterInsertProcessor - if session.IsAutoCommit { + if session.isAutoCommit { // !nashtsai! does user expect it's same slice to passed closure when using Before()/After() when insert multi?? for _, closure := range session.afterClosures { closure(elemValue) @@ -280,8 +280,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error // InsertMulti insert multiple records func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) { - defer session.resetStatement() - if session.IsAutoClose { + if session.isAutoClose { defer session.Close() } @@ -299,14 +298,14 @@ func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) { } func (session *Session) innerInsert(bean interface{}) (int64, error) { - if err := session.Statement.setRefValue(rValue(bean)); err != nil { + if err := session.statement.setRefValue(rValue(bean)); err != nil { return 0, err } - if len(session.Statement.TableName()) <= 0 { + if len(session.statement.TableName()) <= 0 { return 0, ErrTableNotFound } - table := session.Statement.RefTable + table := session.statement.RefTable // handle BeforeInsertProcessor for _, closure := range session.beforeClosures { @@ -318,12 +317,12 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { processor.BeforeInsert() } // -- - colNames, args, err := genCols(session.Statement.RefTable, session, bean, false, false) + colNames, args, err := genCols(session.statement.RefTable, session, bean, false, false) if err != nil { return 0, err } // insert expr columns, override if exists - exprColumns := session.Statement.getExpr() + exprColumns := session.statement.getExpr() exprColVals := make([]string, 0, len(exprColumns)) for _, v := range exprColumns { // remove the expr columns @@ -343,18 +342,30 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { if len(exprColVals) > 0 { colPlaces = colPlaces + strings.Join(exprColVals, ", ") } else { - colPlaces = colPlaces[0 : len(colPlaces)-2] + if len(colPlaces) > 0 { + colPlaces = colPlaces[0 : len(colPlaces)-2] + } } - sqlStr := fmt.Sprintf("INSERT INTO %s (%v%v%v) VALUES (%v)", - session.Engine.Quote(session.Statement.TableName()), - session.Engine.QuoteStr(), - strings.Join(colNames, session.Engine.Quote(", ")), - session.Engine.QuoteStr(), - colPlaces) + var sqlStr string + var tableName = session.statement.TableName() + if len(colPlaces) > 0 { + sqlStr = fmt.Sprintf("INSERT INTO %s (%v%v%v) VALUES (%v)", + session.engine.Quote(tableName), + session.engine.QuoteStr(), + strings.Join(colNames, session.engine.Quote(", ")), + session.engine.QuoteStr(), + colPlaces) + } else { + if session.engine.dialect.DBType() == core.MYSQL { + sqlStr = fmt.Sprintf("INSERT INTO %s VALUES ()", session.engine.Quote(tableName)) + } else { + sqlStr = fmt.Sprintf("INSERT INTO %s DEFAULT VALUES", session.engine.Quote(tableName)) + } + } handleAfterInsertProcessorFunc := func(bean interface{}) { - if session.IsAutoCommit { + if session.isAutoCommit { for _, closure := range session.afterClosures { closure(bean) } @@ -383,23 +394,22 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { // for postgres, many of them didn't implement lastInsertId, so we should // implemented it ourself. - if session.Engine.dialect.DBType() == core.ORACLE && len(table.AutoIncrement) > 0 { - //assert table.AutoIncrement != "" - res, err := session.query("select seq_atable.currval from dual", args...) + if session.engine.dialect.DBType() == core.ORACLE && len(table.AutoIncrement) > 0 { + res, err := session.queryBytes("select seq_atable.currval from dual", args...) if err != nil { return 0, err } handleAfterInsertProcessorFunc(bean) - if cacher := session.Engine.getCacher2(table); cacher != nil && session.Statement.UseCache { - session.cacheInsert(session.Statement.TableName()) + if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache { + session.cacheInsert(table, tableName) } - if table.Version != "" && session.Statement.checkVersion { + if table.Version != "" && session.statement.checkVersion { verValue, err := table.VersionColumn().ValueOf(bean) if err != nil { - session.Engine.logger.Error(err) + session.engine.logger.Error(err) } else if verValue.IsValid() && verValue.CanSet() { verValue.SetInt(1) } @@ -417,7 +427,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { aiValue, err := table.AutoIncrColumn().ValueOf(bean) if err != nil { - session.Engine.logger.Error(err) + session.engine.logger.Error(err) } if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() { @@ -427,24 +437,24 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { aiValue.Set(int64ToIntValue(id, aiValue.Type())) return 1, nil - } else if session.Engine.dialect.DBType() == core.POSTGRES && len(table.AutoIncrement) > 0 { + } else if session.engine.dialect.DBType() == core.POSTGRES && len(table.AutoIncrement) > 0 { //assert table.AutoIncrement != "" - sqlStr = sqlStr + " RETURNING " + session.Engine.Quote(table.AutoIncrement) - res, err := session.query(sqlStr, args...) + sqlStr = sqlStr + " RETURNING " + session.engine.Quote(table.AutoIncrement) + res, err := session.queryBytes(sqlStr, args...) if err != nil { return 0, err } handleAfterInsertProcessorFunc(bean) - if cacher := session.Engine.getCacher2(table); cacher != nil && session.Statement.UseCache { - session.cacheInsert(session.Statement.TableName()) + if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache { + session.cacheInsert(table, tableName) } - if table.Version != "" && session.Statement.checkVersion { + if table.Version != "" && session.statement.checkVersion { verValue, err := table.VersionColumn().ValueOf(bean) if err != nil { - session.Engine.logger.Error(err) + session.engine.logger.Error(err) } else if verValue.IsValid() && verValue.CanSet() { verValue.SetInt(1) } @@ -462,7 +472,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { aiValue, err := table.AutoIncrColumn().ValueOf(bean) if err != nil { - session.Engine.logger.Error(err) + session.engine.logger.Error(err) } if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() { @@ -480,14 +490,14 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { defer handleAfterInsertProcessorFunc(bean) - if cacher := session.Engine.getCacher2(table); cacher != nil && session.Statement.UseCache { - session.cacheInsert(session.Statement.TableName()) + if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache { + session.cacheInsert(table, tableName) } - if table.Version != "" && session.Statement.checkVersion { + if table.Version != "" && session.statement.checkVersion { verValue, err := table.VersionColumn().ValueOf(bean) if err != nil { - session.Engine.logger.Error(err) + session.engine.logger.Error(err) } else if verValue.IsValid() && verValue.CanSet() { verValue.SetInt(1) } @@ -505,7 +515,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { aiValue, err := table.AutoIncrColumn().ValueOf(bean) if err != nil { - session.Engine.logger.Error(err) + session.engine.logger.Error(err) } if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() { @@ -522,24 +532,21 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { // The in parameter bean must a struct or a point to struct. The return // parameter is inserted and error func (session *Session) InsertOne(bean interface{}) (int64, error) { - defer session.resetStatement() - if session.IsAutoClose { + if session.isAutoClose { defer session.Close() } return session.innerInsert(bean) } -func (session *Session) cacheInsert(tables ...string) error { - if session.Statement.RefTable == nil { +func (session *Session) cacheInsert(table *core.Table, tables ...string) error { + if table == nil { return ErrCacheFailed } - table := session.Statement.RefTable - cacher := session.Engine.getCacher2(table) - + cacher := session.engine.getCacher2(table) for _, t := range tables { - session.Engine.logger.Debug("[cache] clear sql:", t) + session.engine.logger.Debug("[cache] clear sql:", t) cacher.ClearIds(t) } diff --git a/session_insert_test.go b/session_insert_test.go index b232d3f7..d4878af6 100644 --- a/session_insert_test.go +++ b/session_insert_test.go @@ -5,6 +5,9 @@ package xorm import ( + "errors" + "fmt" + "reflect" "testing" "time" @@ -26,3 +29,636 @@ func TestInsertOne(t *testing.T) { _, err := testEngine.InsertOne(data) assert.NoError(t, err) } + +func TestInsertMulti(t *testing.T) { + + assert.NoError(t, prepareEngine()) + type TestMulti struct { + Id int64 `xorm:"int(11) pk"` + Name string `xorm:"varchar(255)"` + } + + assert.NoError(t, testEngine.Sync2(new(TestMulti))) + + num, err := insertMultiDatas(1, + append([]TestMulti{}, TestMulti{1, "test1"}, TestMulti{2, "test2"}, TestMulti{3, "test3"})) + assert.NoError(t, err) + assert.EqualValues(t, 3, num) +} + +func insertMultiDatas(step int, datas interface{}) (num int64, err error) { + sliceValue := reflect.Indirect(reflect.ValueOf(datas)) + var iLen int64 + if sliceValue.Kind() != reflect.Slice { + return 0, fmt.Errorf("not silce") + } + iLen = int64(sliceValue.Len()) + if iLen == 0 { + return + } + + session := testEngine.NewSession() + defer session.Close() + + if err = callbackLooper(datas, step, + func(innerDatas interface{}) error { + n, e := session.InsertMulti(innerDatas) + if e != nil { + return e + } + num += n + return nil + }); err != nil { + return 0, err + } else if num != iLen { + return 0, fmt.Errorf("num error: %d - %d", num, iLen) + } + return +} + +func callbackLooper(datas interface{}, step int, actionFunc func(interface{}) error) (err error) { + + sliceValue := reflect.Indirect(reflect.ValueOf(datas)) + if sliceValue.Kind() != reflect.Slice { + return fmt.Errorf("not slice") + } + if sliceValue.Len() <= 0 { + return + } + + tempLen := 0 + processedLen := sliceValue.Len() + for i := 0; i < sliceValue.Len(); i += step { + if processedLen > step { + tempLen = i + step + } else { + tempLen = sliceValue.Len() + } + var tempInterface []interface{} + for j := i; j < tempLen; j++ { + tempInterface = append(tempInterface, sliceValue.Index(j).Interface()) + } + if err = actionFunc(tempInterface); err != nil { + return + } + processedLen = processedLen - step + } + return +} + +func TestInsertOneIfPkIsPoint(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type TestPoint struct { + Id *int64 `xorm:"autoincr pk notnull 'id'"` + Msg *string `xorm:"varchar(255)"` + Created *time.Time `xorm:"created"` + } + + assert.NoError(t, testEngine.Sync2(new(TestPoint))) + msg := "hi" + data := TestPoint{Msg: &msg} + _, err := testEngine.InsertOne(&data) + assert.NoError(t, err) +} + +func TestInsertOneIfPkIsPointRename(t *testing.T) { + assert.NoError(t, prepareEngine()) + type ID *int64 + type TestPoint2 struct { + Id ID `xorm:"autoincr pk notnull 'id'"` + Msg *string `xorm:"varchar(255)"` + Created *time.Time `xorm:"created"` + } + + assert.NoError(t, testEngine.Sync2(new(TestPoint2))) + msg := "hi" + data := TestPoint2{Msg: &msg} + _, err := testEngine.InsertOne(&data) + assert.NoError(t, err) +} + +func TestInsert(t *testing.T) { + assert.NoError(t, prepareEngine()) + assertSync(t, new(Userinfo)) + + user := Userinfo{0, "xiaolunwen", "dev", "lunny", time.Now(), + Userdetail{Id: 1}, 1.78, []byte{1, 2, 3}, true} + cnt, err := testEngine.Insert(&user) + fmt.Println(user.Uid) + if err != nil { + t.Error(err) + panic(err) + } + if cnt != 1 { + err = errors.New("insert not returned 1") + t.Error(err) + panic(err) + } + + if user.Uid <= 0 { + err = errors.New("not return id error") + t.Error(err) + panic(err) + } + + user.Uid = 0 + cnt, err = testEngine.Insert(&user) + if err == nil { + err = errors.New("insert failed but no return error") + t.Error(err) + panic(err) + } + if cnt != 0 { + err = errors.New("insert not returned 1") + t.Error(err) + panic(err) + return + } +} + +func TestInsertAutoIncr(t *testing.T) { + assert.NoError(t, prepareEngine()) + + assertSync(t, new(Userinfo)) + + // auto increment insert + user := Userinfo{Username: "xiaolunwen2", Departname: "dev", Alias: "lunny", Created: time.Now(), + Detail: Userdetail{Id: 1}, Height: 1.78, Avatar: []byte{1, 2, 3}, IsMan: true} + cnt, err := testEngine.Insert(&user) + fmt.Println(user.Uid) + if err != nil { + t.Error(err) + panic(err) + } + if cnt != 1 { + err = errors.New("insert not returned 1") + t.Error(err) + panic(err) + } + if user.Uid <= 0 { + t.Error(errors.New("not return id error")) + } +} + +type DefaultInsert struct { + Id int64 + Status int `xorm:"default -1"` + Name string + Created time.Time `xorm:"created"` + Updated time.Time `xorm:"updated"` +} + +func TestInsertDefault(t *testing.T) { + assert.NoError(t, prepareEngine()) + + di := new(DefaultInsert) + err := testEngine.Sync2(di) + if err != nil { + t.Error(err) + } + + var di2 = DefaultInsert{Name: "test"} + _, err = testEngine.Omit(testEngine.ColumnMapper.Obj2Table("Status")).Insert(&di2) + if err != nil { + t.Error(err) + } + + has, err := testEngine.Desc("(id)").Get(di) + if err != nil { + t.Error(err) + } + if !has { + err = errors.New("error with no data") + t.Error(err) + panic(err) + } + if di.Status != -1 { + err = errors.New("inserted error data") + t.Error(err) + panic(err) + } + if di2.Updated.Unix() != di.Updated.Unix() { + err = errors.New("updated should equal") + t.Error(err, di.Updated, di2.Updated) + panic(err) + } + if di2.Created.Unix() != di.Created.Unix() { + err = errors.New("created should equal") + t.Error(err, di.Created, di2.Created) + panic(err) + } +} + +type DefaultInsert2 struct { + Id int64 + Name string + Url string `xorm:"text"` + CheckTime time.Time `xorm:"not null default '2000-01-01 00:00:00' TIMESTAMP"` +} + +func TestInsertDefault2(t *testing.T) { + assert.NoError(t, prepareEngine()) + + di := new(DefaultInsert2) + err := testEngine.Sync2(di) + if err != nil { + t.Error(err) + } + + var di2 = DefaultInsert2{Name: "test"} + _, err = testEngine.Omit(testEngine.ColumnMapper.Obj2Table("CheckTime")).Insert(&di2) + if err != nil { + t.Error(err) + } + + has, err := testEngine.Desc("(id)").Get(di) + if err != nil { + t.Error(err) + } + if !has { + err = errors.New("error with no data") + t.Error(err) + panic(err) + } + + has, err = testEngine.NoAutoCondition().Desc("(id)").Get(&di2) + if err != nil { + t.Error(err) + } + + if !has { + err = errors.New("error with no data") + t.Error(err) + panic(err) + } + + if *di != di2 { + err = fmt.Errorf("%v is not equal to %v", di, di2) + t.Error(err) + panic(err) + } + + /*if di2.Updated.Unix() != di.Updated.Unix() { + err = errors.New("updated should equal") + t.Error(err, di.Updated, di2.Updated) + panic(err) + } + if di2.Created.Unix() != di.Created.Unix() { + err = errors.New("created should equal") + t.Error(err, di.Created, di2.Created) + panic(err) + }*/ +} + +type CreatedInsert struct { + Id int64 + Created time.Time `xorm:"created"` +} + +type CreatedInsert2 struct { + Id int64 + Created int64 `xorm:"created"` +} + +type CreatedInsert3 struct { + Id int64 + Created int `xorm:"created bigint"` +} + +type CreatedInsert4 struct { + Id int64 + Created int `xorm:"created"` +} + +type CreatedInsert5 struct { + Id int64 + Created time.Time `xorm:"created bigint"` +} + +type CreatedInsert6 struct { + Id int64 + Created time.Time `xorm:"created bigint"` +} + +func TestInsertCreated(t *testing.T) { + assert.NoError(t, prepareEngine()) + + di := new(CreatedInsert) + err := testEngine.Sync2(di) + if err != nil { + t.Fatal(err) + } + ci := &CreatedInsert{} + _, err = testEngine.Insert(ci) + if err != nil { + t.Fatal(err) + } + + has, err := testEngine.Desc("(id)").Get(di) + if err != nil { + t.Fatal(err) + } + if !has { + t.Fatal(ErrNotExist) + } + if ci.Created.Unix() != di.Created.Unix() { + t.Fatal("should equal:", ci, di) + } + fmt.Println("ci:", ci, "di:", di) + + di2 := new(CreatedInsert2) + err = testEngine.Sync2(di2) + if err != nil { + t.Fatal(err) + } + ci2 := &CreatedInsert2{} + _, err = testEngine.Insert(ci2) + if err != nil { + t.Fatal(err) + } + has, err = testEngine.Desc("(id)").Get(di2) + if err != nil { + t.Fatal(err) + } + if !has { + t.Fatal(ErrNotExist) + } + if ci2.Created != di2.Created { + t.Fatal("should equal:", ci2, di2) + } + fmt.Println("ci2:", ci2, "di2:", di2) + + di3 := new(CreatedInsert3) + err = testEngine.Sync2(di3) + if err != nil { + t.Fatal(err) + } + ci3 := &CreatedInsert3{} + _, err = testEngine.Insert(ci3) + if err != nil { + t.Fatal(err) + } + has, err = testEngine.Desc("(id)").Get(di3) + if err != nil { + t.Fatal(err) + } + if !has { + t.Fatal(ErrNotExist) + } + if ci3.Created != di3.Created { + t.Fatal("should equal:", ci3, di3) + } + fmt.Println("ci3:", ci3, "di3:", di3) + + di4 := new(CreatedInsert4) + err = testEngine.Sync2(di4) + if err != nil { + t.Fatal(err) + } + ci4 := &CreatedInsert4{} + _, err = testEngine.Insert(ci4) + if err != nil { + t.Fatal(err) + } + has, err = testEngine.Desc("(id)").Get(di4) + if err != nil { + t.Fatal(err) + } + if !has { + t.Fatal(ErrNotExist) + } + if ci4.Created != di4.Created { + t.Fatal("should equal:", ci4, di4) + } + fmt.Println("ci4:", ci4, "di4:", di4) + + di5 := new(CreatedInsert5) + err = testEngine.Sync2(di5) + if err != nil { + t.Fatal(err) + } + ci5 := &CreatedInsert5{} + _, err = testEngine.Insert(ci5) + if err != nil { + t.Fatal(err) + } + has, err = testEngine.Desc("(id)").Get(di5) + if err != nil { + t.Fatal(err) + } + if !has { + t.Fatal(ErrNotExist) + } + if ci5.Created.Unix() != di5.Created.Unix() { + t.Fatal("should equal:", ci5, di5) + } + fmt.Println("ci5:", ci5, "di5:", di5) + + di6 := new(CreatedInsert6) + err = testEngine.Sync2(di6) + if err != nil { + t.Fatal(err) + } + oldTime := time.Now().Add(-time.Hour) + ci6 := &CreatedInsert6{Created: oldTime} + _, err = testEngine.Insert(ci6) + if err != nil { + t.Fatal(err) + } + + has, err = testEngine.Desc("(id)").Get(di6) + if err != nil { + t.Fatal(err) + } + if !has { + t.Fatal(ErrNotExist) + } + if ci6.Created.Unix() != di6.Created.Unix() { + t.Fatal("should equal:", ci6, di6) + } + fmt.Println("ci6:", ci6, "di6:", di6) +} + +type JsonTime time.Time + +func (j JsonTime) format() string { + t := time.Time(j) + if t.IsZero() { + return "" + } + + return t.Format("2006-01-02") +} + +func (j JsonTime) MarshalText() ([]byte, error) { + return []byte(j.format()), nil +} + +func (j JsonTime) MarshalJSON() ([]byte, error) { + return []byte(`"` + j.format() + `"`), nil +} + +func TestDefaultTime3(t *testing.T) { + type PrepareTask struct { + Id int `xorm:"not null pk autoincr INT(11)" json:"id"` + // ... + StartTime JsonTime `xorm:"not null default '2006-01-02 15:04:05' TIMESTAMP index" json:"start_time"` + EndTime JsonTime `xorm:"not null default '2006-01-02 15:04:05' TIMESTAMP" json:"end_time"` + Cuser string `xorm:"not null default '' VARCHAR(64) index" json:"cuser"` + Muser string `xorm:"not null default '' VARCHAR(64)" json:"muser"` + Ctime JsonTime `xorm:"not null default CURRENT_TIMESTAMP TIMESTAMP created" json:"ctime"` + Mtime JsonTime `xorm:"not null default CURRENT_TIMESTAMP TIMESTAMP updated" json:"mtime"` + } + + assert.NoError(t, prepareEngine()) + assertSync(t, new(PrepareTask)) + + prepareTask := &PrepareTask{ + StartTime: JsonTime(time.Now()), + Cuser: "userId", + Muser: "userId", + } + cnt, err := testEngine.Omit("end_time").InsertOne(prepareTask) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) +} + +type MyJsonTime struct { + Id int64 `json:"id"` + Created JsonTime `xorm:"created" json:"created_at"` +} + +func TestCreatedJsonTime(t *testing.T) { + assert.NoError(t, prepareEngine()) + + di5 := new(MyJsonTime) + err := testEngine.Sync2(di5) + if err != nil { + t.Fatal(err) + } + ci5 := &MyJsonTime{} + _, err = testEngine.Insert(ci5) + if err != nil { + t.Fatal(err) + } + has, err := testEngine.Desc("(id)").Get(di5) + if err != nil { + t.Fatal(err) + } + if !has { + t.Fatal(ErrNotExist) + } + if time.Time(ci5.Created).Unix() != time.Time(di5.Created).Unix() { + t.Fatal("should equal:", time.Time(ci5.Created).Unix(), time.Time(di5.Created).Unix()) + } + fmt.Println("ci5:", ci5, "di5:", di5) + + var dis = make([]MyJsonTime, 0) + err = testEngine.Find(&dis) + if err != nil { + t.Fatal(err) + } +} + +func TestInsertMulti2(t *testing.T) { + assert.NoError(t, prepareEngine()) + + assertSync(t, new(Userinfo)) + + users := []Userinfo{ + {Username: "xlw", Departname: "dev", Alias: "lunny2", Created: time.Now()}, + {Username: "xlw2", Departname: "dev", Alias: "lunny3", Created: time.Now()}, + {Username: "xlw11", Departname: "dev", Alias: "lunny2", Created: time.Now()}, + {Username: "xlw22", Departname: "dev", Alias: "lunny3", Created: time.Now()}, + } + cnt, err := testEngine.Insert(&users) + if err != nil { + t.Error(err) + panic(err) + } + if cnt != int64(len(users)) { + err = errors.New("insert not returned 1") + t.Error(err) + panic(err) + return + } + + users2 := []*Userinfo{ + &Userinfo{Username: "1xlw", Departname: "dev", Alias: "lunny2", Created: time.Now()}, + &Userinfo{Username: "1xlw2", Departname: "dev", Alias: "lunny3", Created: time.Now()}, + &Userinfo{Username: "1xlw11", Departname: "dev", Alias: "lunny2", Created: time.Now()}, + &Userinfo{Username: "1xlw22", Departname: "dev", Alias: "lunny3", Created: time.Now()}, + } + + cnt, err = testEngine.Insert(&users2) + if err != nil { + t.Error(err) + panic(err) + } + + if cnt != int64(len(users2)) { + err = errors.New(fmt.Sprintf("insert not returned %v", len(users2))) + t.Error(err) + panic(err) + } +} + +func TestInsertTwoTable(t *testing.T) { + assert.NoError(t, prepareEngine()) + + assertSync(t, new(Userinfo), new(Userdetail)) + + userdetail := Userdetail{ /*Id: 1, */ Intro: "I'm a very beautiful women.", Profile: "sfsaf"} + userinfo := Userinfo{Username: "xlw3", Departname: "dev", Alias: "lunny4", Created: time.Now(), Detail: userdetail} + + cnt, err := testEngine.Insert(&userinfo, &userdetail) + if err != nil { + t.Error(err) + panic(err) + } + + if userinfo.Uid <= 0 { + err = errors.New("not return id error") + t.Error(err) + panic(err) + } + + if userdetail.Id <= 0 { + err = errors.New("not return id error") + t.Error(err) + panic(err) + } + + if cnt != 2 { + err = errors.New("insert not returned 2") + t.Error(err) + panic(err) + } +} + +func TestInsertCreatedInt64(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type TestCreatedInt64 struct { + Id int64 `xorm:"autoincr pk"` + Msg string `xorm:"varchar(255)"` + Created int64 `xorm:"created"` + } + + assert.NoError(t, testEngine.Sync2(new(TestCreatedInt64))) + + data := TestCreatedInt64{Msg: "hi"} + now := time.Now() + cnt, err := testEngine.Insert(&data) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + assert.True(t, now.Unix() <= data.Created) + + var data2 TestCreatedInt64 + has, err := testEngine.Get(&data2) + assert.NoError(t, err) + assert.True(t, has) + + assert.EqualValues(t, data.Created, data2.Created) +} diff --git a/session_iterate.go b/session_iterate.go index 7c148095..071fce49 100644 --- a/session_iterate.go +++ b/session_iterate.go @@ -19,6 +19,14 @@ func (session *Session) Rows(bean interface{}) (*Rows, error) { // are conditions. beans could be []Struct, []*Struct, map[int64]Struct // map[int64]*Struct func (session *Session) Iterate(bean interface{}, fun IterFunc) error { + if session.isAutoClose { + defer session.Close() + } + + if session.statement.bufferSize > 0 { + return session.bufferIterate(bean, fun) + } + rows, err := session.Rows(bean) if err != nil { return err @@ -40,3 +48,49 @@ func (session *Session) Iterate(bean interface{}, fun IterFunc) error { } return err } + +// BufferSize sets the buffersize for iterate +func (session *Session) BufferSize(size int) *Session { + session.statement.bufferSize = size + return session +} + +func (session *Session) bufferIterate(bean interface{}, fun IterFunc) error { + if session.isAutoClose { + defer session.Close() + } + + var bufferSize = session.statement.bufferSize + var limit = session.statement.LimitN + if limit > 0 && bufferSize > limit { + bufferSize = limit + } + var start = session.statement.Start + v := rValue(bean) + sliceType := reflect.SliceOf(v.Type()) + var idx = 0 + for { + slice := reflect.New(sliceType) + if err := session.Limit(bufferSize, start).find(slice.Interface(), bean); err != nil { + return err + } + + for i := 0; i < slice.Elem().Len(); i++ { + if err := fun(idx, slice.Elem().Index(i).Addr().Interface()); err != nil { + return err + } + idx++ + } + + start = start + slice.Elem().Len() + if limit > 0 && idx+bufferSize > limit { + bufferSize = limit - idx + } + + if bufferSize <= 0 || slice.Elem().Len() < bufferSize || idx == limit { + break + } + } + + return nil +} diff --git a/session_iterate_test.go b/session_iterate_test.go new file mode 100644 index 00000000..9a7ec25f --- /dev/null +++ b/session_iterate_test.go @@ -0,0 +1,92 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestIterate(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type UserIterate struct { + Id int64 + IsMan bool + } + + assert.NoError(t, testEngine.Sync2(new(UserIterate))) + + cnt, err := testEngine.Insert(&UserIterate{ + IsMan: true, + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + cnt = 0 + err = testEngine.Iterate(new(UserIterate), func(i int, bean interface{}) error { + user := bean.(*UserIterate) + assert.EqualValues(t, 1, user.Id) + assert.EqualValues(t, true, user.IsMan) + cnt++ + return nil + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) +} + +func TestBufferIterate(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type UserBufferIterate struct { + Id int64 + IsMan bool + } + + assert.NoError(t, testEngine.Sync2(new(UserBufferIterate))) + + var size = 20 + for i := 0; i < size; i++ { + cnt, err := testEngine.Insert(&UserBufferIterate{ + IsMan: true, + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + } + + var cnt = 0 + err := testEngine.BufferSize(9).Iterate(new(UserBufferIterate), func(i int, bean interface{}) error { + user := bean.(*UserBufferIterate) + assert.EqualValues(t, cnt+1, user.Id) + assert.EqualValues(t, true, user.IsMan) + cnt++ + return nil + }) + assert.NoError(t, err) + assert.EqualValues(t, size, cnt) + + cnt = 0 + err = testEngine.Limit(20).BufferSize(9).Iterate(new(UserBufferIterate), func(i int, bean interface{}) error { + user := bean.(*UserBufferIterate) + assert.EqualValues(t, cnt+1, user.Id) + assert.EqualValues(t, true, user.IsMan) + cnt++ + return nil + }) + assert.NoError(t, err) + assert.EqualValues(t, size, cnt) + + cnt = 0 + err = testEngine.Limit(7).BufferSize(9).Iterate(new(UserBufferIterate), func(i int, bean interface{}) error { + user := bean.(*UserBufferIterate) + assert.EqualValues(t, cnt+1, user.Id) + assert.EqualValues(t, true, user.IsMan) + cnt++ + return nil + }) + assert.NoError(t, err) + assert.EqualValues(t, 7, cnt) +} diff --git a/session_pk_test.go b/session_pk_test.go index 44ffb300..3370b2ad 100644 --- a/session_pk_test.go +++ b/session_pk_test.go @@ -7,6 +7,7 @@ package xorm import ( "errors" "testing" + "time" "github.com/go-xorm/core" "github.com/stretchr/testify/assert" @@ -126,7 +127,7 @@ func TestIntId(t *testing.T) { panic(err) } - cnt, err = testEngine.Id(bean.Id).Delete(&IntId{}) + cnt, err = testEngine.ID(bean.Id).Delete(&IntId{}) if err != nil { t.Error(err) panic(err) @@ -201,7 +202,7 @@ func TestInt16Id(t *testing.T) { panic(err) } - cnt, err = testEngine.Id(bean.Id).Delete(&Int16Id{}) + cnt, err = testEngine.ID(bean.Id).Delete(&Int16Id{}) if err != nil { t.Error(err) panic(err) @@ -276,7 +277,7 @@ func TestInt32Id(t *testing.T) { panic(err) } - cnt, err = testEngine.Id(bean.Id).Delete(&Int32Id{}) + cnt, err = testEngine.ID(bean.Id).Delete(&Int32Id{}) if err != nil { t.Error(err) panic(err) @@ -365,7 +366,7 @@ func TestUintId(t *testing.T) { panic(err) } - cnt, err = testEngine.Id(bean.Id).Delete(&UintId{}) + cnt, err = testEngine.ID(bean.Id).Delete(&UintId{}) if err != nil { t.Error(err) panic(err) @@ -440,7 +441,7 @@ func TestUint16Id(t *testing.T) { panic(err) } - cnt, err = testEngine.Id(bean.Id).Delete(&Uint16Id{}) + cnt, err = testEngine.ID(bean.Id).Delete(&Uint16Id{}) if err != nil { t.Error(err) panic(err) @@ -515,7 +516,7 @@ func TestUint32Id(t *testing.T) { panic(err) } - cnt, err = testEngine.Id(bean.Id).Delete(&Uint32Id{}) + cnt, err = testEngine.ID(bean.Id).Delete(&Uint32Id{}) if err != nil { t.Error(err) panic(err) @@ -603,7 +604,7 @@ func TestUint64Id(t *testing.T) { panic(errors.New("should be equal")) } - cnt, err = testEngine.Id(bean.Id).Delete(&Uint64Id{}) + cnt, err = testEngine.ID(bean.Id).Delete(&Uint64Id{}) if err != nil { t.Error(err) panic(err) @@ -678,7 +679,7 @@ func TestStringPK(t *testing.T) { panic(err) } - cnt, err = testEngine.Id(bean.Id).Delete(&StringPK{}) + cnt, err = testEngine.ID(bean.Id).Delete(&StringPK{}) if err != nil { t.Error(err) panic(err) @@ -724,7 +725,7 @@ func TestCompositeKey(t *testing.T) { } var compositeKeyVal CompositeKey - has, err := testEngine.Id(core.PK{11, 22}).Get(&compositeKeyVal) + has, err := testEngine.ID(core.PK{11, 22}).Get(&compositeKeyVal) if err != nil { t.Error(err) } else if !has { @@ -733,7 +734,7 @@ func TestCompositeKey(t *testing.T) { var compositeKeyVal2 CompositeKey // test passing PK ptr, this test seem failed withCache - has, err = testEngine.Id(&core.PK{11, 22}).Get(&compositeKeyVal2) + has, err = testEngine.ID(&core.PK{11, 22}).Get(&compositeKeyVal2) if err != nil { t.Error(err) } else if !has { @@ -763,31 +764,21 @@ func TestCompositeKey(t *testing.T) { t.Error(errors.New("failed to insert CompositeKey{22, 22}")) } - if testEngine.Cacher != nil { - testEngine.Cacher.ClearBeans(testEngine.TableInfo(compositeKeyVal).Name) - } - cps = make([]CompositeKey, 0) err = testEngine.Find(&cps) - if err != nil { - t.Error(err) - } - if len(cps) != 2 { - t.Error(errors.New("should has two record")) - } - if cps[0] != compositeKeyVal { - t.Error(errors.New("should be equeal")) - } + assert.NoError(t, err) + assert.EqualValues(t, 2, len(cps), "should has two record") + assert.EqualValues(t, compositeKeyVal, cps[0], "should be equeal") compositeKeyVal = CompositeKey{UpdateStr: "test1"} - cnt, err = testEngine.Id(core.PK{11, 22}).Update(&compositeKeyVal) + cnt, err = testEngine.ID(core.PK{11, 22}).Update(&compositeKeyVal) if err != nil { t.Error(err) } else if cnt != 1 { t.Error(errors.New("can't update CompositeKey{11, 22}")) } - cnt, err = testEngine.Id(core.PK{11, 22}).Delete(&CompositeKey{}) + cnt, err = testEngine.ID(core.PK{11, 22}).Delete(&CompositeKey{}) if err != nil { t.Error(err) } else if cnt != 1 { @@ -795,16 +786,16 @@ func TestCompositeKey(t *testing.T) { } } -type User struct { - UserId string `xorm:"varchar(19) not null pk"` - NickName string `xorm:"varchar(19) not null"` - GameId uint32 `xorm:"integer pk"` - Score int32 `xorm:"integer"` -} - func TestCompositeKey2(t *testing.T) { assert.NoError(t, prepareEngine()) + type User struct { + UserId string `xorm:"varchar(19) not null pk"` + NickName string `xorm:"varchar(19) not null"` + GameId uint32 `xorm:"integer pk"` + Score int32 `xorm:"integer"` + } + err := testEngine.DropTables(&User{}) if err != nil { @@ -831,7 +822,7 @@ func TestCompositeKey2(t *testing.T) { } var user User - has, err := testEngine.Id(core.PK{"11", 22}).Get(&user) + has, err := testEngine.ID(core.PK{"11", 22}).Get(&user) if err != nil { t.Error(err) } else if !has { @@ -839,7 +830,7 @@ func TestCompositeKey2(t *testing.T) { } // test passing PK ptr, this test seem failed withCache - has, err = testEngine.Id(&core.PK{"11", 22}).Get(&user) + has, err = testEngine.ID(&core.PK{"11", 22}).Get(&user) if err != nil { t.Error(err) } else if !has { @@ -847,14 +838,14 @@ func TestCompositeKey2(t *testing.T) { } user = User{NickName: "test1"} - cnt, err = testEngine.Id(core.PK{"11", 22}).Update(&user) + cnt, err = testEngine.ID(core.PK{"11", 22}).Update(&user) if err != nil { t.Error(err) } else if cnt != 1 { t.Error(errors.New("can't update User{11, 22}")) } - cnt, err = testEngine.Id(core.PK{"11", 22}).Delete(&User{}) + cnt, err = testEngine.ID(core.PK{"11", 22}).Delete(&User{}) if err != nil { t.Error(err) } else if cnt != 1 { @@ -899,7 +890,7 @@ func TestCompositeKey3(t *testing.T) { } var user UserPK2 - has, err := testEngine.Id(core.PK{"11", 22}).Get(&user) + has, err := testEngine.ID(core.PK{"11", 22}).Get(&user) if err != nil { t.Error(err) } else if !has { @@ -907,7 +898,7 @@ func TestCompositeKey3(t *testing.T) { } // test passing PK ptr, this test seem failed withCache - has, err = testEngine.Id(&core.PK{"11", 22}).Get(&user) + has, err = testEngine.ID(&core.PK{"11", 22}).Get(&user) if err != nil { t.Error(err) } else if !has { @@ -915,14 +906,14 @@ func TestCompositeKey3(t *testing.T) { } user = UserPK2{NickName: "test1"} - cnt, err = testEngine.Id(core.PK{"11", 22}).Update(&user) + cnt, err = testEngine.ID(core.PK{"11", 22}).Update(&user) if err != nil { t.Error(err) } else if cnt != 1 { t.Error(errors.New("can't update User{11, 22}")) } - cnt, err = testEngine.Id(core.PK{"11", 22}).Delete(&UserPK2{}) + cnt, err = testEngine.ID(core.PK{"11", 22}).Delete(&UserPK2{}) if err != nil { t.Error(err) } else if cnt != 1 { @@ -1006,7 +997,7 @@ func TestMyIntId(t *testing.T) { panic(errors.New("should be equal")) } - cnt, err = testEngine.Id(bean.ID).Delete(&MyIntPK{}) + cnt, err = testEngine.ID(bean.ID).Delete(&MyIntPK{}) if err != nil { t.Error(err) panic(err) @@ -1094,7 +1085,7 @@ func TestMyStringId(t *testing.T) { panic(errors.New("should be equal")) } - cnt, err = testEngine.Id(bean.ID).Delete(&MyStringPK{}) + cnt, err = testEngine.ID(bean.ID).Delete(&MyStringPK{}) if err != nil { t.Error(err) panic(err) @@ -1105,3 +1096,62 @@ func TestMyStringId(t *testing.T) { panic(err) } } + +func TestSingleAutoIncrColumn(t *testing.T) { + type Account struct { + Id int64 `xorm:"pk autoincr"` + } + + assert.NoError(t, prepareEngine()) + assertSync(t, new(Account)) + + _, err := testEngine.Insert(&Account{}) + assert.NoError(t, err) +} + +func TestCompositePK(t *testing.T) { + type TaskSolution struct { + UID string `xorm:"notnull pk UUID 'uid'"` + TID string `xorm:"notnull pk UUID 'tid'"` + Created time.Time `xorm:"created"` + Updated time.Time `xorm:"updated"` + } + + assert.NoError(t, prepareEngine()) + assertSync(t, new(TaskSolution)) + + assert.NoError(t, testEngine.Sync2(new(TaskSolution))) + tables, err := testEngine.DBMetas() + assert.NoError(t, err) + assert.EqualValues(t, 1, len(tables)) + pkCols := tables[0].PKColumns() + assert.EqualValues(t, 2, len(pkCols)) + assert.EqualValues(t, "uid", pkCols[0].Name) + assert.EqualValues(t, "tid", pkCols[1].Name) +} + +func TestNoPKIdQueryUpdate(t *testing.T) { + type NoPKTable struct { + Username string + } + + assert.NoError(t, prepareEngine()) + assertSync(t, new(NoPKTable)) + + cnt, err := testEngine.Insert(&NoPKTable{ + Username: "test", + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var res NoPKTable + has, err := testEngine.ID("test").Get(&res) + assert.Error(t, err) + assert.False(t, has) + + cnt, err = testEngine.ID("test").Update(&NoPKTable{ + Username: "test1", + }) + assert.Error(t, err) + assert.EqualValues(t, 0, cnt) +} diff --git a/session_query.go b/session_query.go new file mode 100644 index 00000000..a693bace --- /dev/null +++ b/session_query.go @@ -0,0 +1,177 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import ( + "fmt" + "reflect" + "strconv" + "time" + + "github.com/go-xorm/core" +) + +// Query runs a raw sql and return records as []map[string][]byte +func (session *Session) Query(sqlStr string, args ...interface{}) ([]map[string][]byte, error) { + if session.isAutoClose { + defer session.Close() + } + + return session.queryBytes(sqlStr, args...) +} + +func value2String(rawValue *reflect.Value) (str string, err error) { + aa := reflect.TypeOf((*rawValue).Interface()) + vv := reflect.ValueOf((*rawValue).Interface()) + switch aa.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + str = strconv.FormatInt(vv.Int(), 10) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + str = strconv.FormatUint(vv.Uint(), 10) + case reflect.Float32, reflect.Float64: + str = strconv.FormatFloat(vv.Float(), 'f', -1, 64) + case reflect.String: + str = vv.String() + case reflect.Array, reflect.Slice: + switch aa.Elem().Kind() { + case reflect.Uint8: + data := rawValue.Interface().([]byte) + str = string(data) + if str == "\x00" { + str = "0" + } + default: + err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name()) + } + // time type + case reflect.Struct: + if aa.ConvertibleTo(core.TimeType) { + str = vv.Convert(core.TimeType).Interface().(time.Time).Format(time.RFC3339Nano) + } else { + err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name()) + } + case reflect.Bool: + str = strconv.FormatBool(vv.Bool()) + case reflect.Complex128, reflect.Complex64: + str = fmt.Sprintf("%v", vv.Complex()) + /* TODO: unsupported types below + case reflect.Map: + case reflect.Ptr: + case reflect.Uintptr: + case reflect.UnsafePointer: + case reflect.Chan, reflect.Func, reflect.Interface: + */ + default: + err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name()) + } + return +} + +func row2mapStr(rows *core.Rows, fields []string) (resultsMap map[string]string, err error) { + result := make(map[string]string) + scanResultContainers := make([]interface{}, len(fields)) + for i := 0; i < len(fields); i++ { + var scanResultContainer interface{} + scanResultContainers[i] = &scanResultContainer + } + if err := rows.Scan(scanResultContainers...); err != nil { + return nil, err + } + + for ii, key := range fields { + rawValue := reflect.Indirect(reflect.ValueOf(scanResultContainers[ii])) + // if row is null then as empty string + if rawValue.Interface() == nil { + result[key] = "" + continue + } + + if data, err := value2String(&rawValue); err == nil { + result[key] = data + } else { + return nil, err + } + } + return result, nil +} + +func rows2Strings(rows *core.Rows) (resultsSlice []map[string]string, err error) { + fields, err := rows.Columns() + if err != nil { + return nil, err + } + for rows.Next() { + result, err := row2mapStr(rows, fields) + if err != nil { + return nil, err + } + resultsSlice = append(resultsSlice, result) + } + + return resultsSlice, nil +} + +// QueryString runs a raw sql and return records as []map[string]string +func (session *Session) QueryString(sqlStr string, args ...interface{}) ([]map[string]string, error) { + if session.isAutoClose { + defer session.Close() + } + + rows, err := session.queryRows(sqlStr, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + return rows2Strings(rows) +} + +func row2mapInterface(rows *core.Rows, fields []string) (resultsMap map[string]interface{}, err error) { + resultsMap = make(map[string]interface{}, len(fields)) + scanResultContainers := make([]interface{}, len(fields)) + for i := 0; i < len(fields); i++ { + var scanResultContainer interface{} + scanResultContainers[i] = &scanResultContainer + } + if err := rows.Scan(scanResultContainers...); err != nil { + return nil, err + } + + for ii, key := range fields { + resultsMap[key] = reflect.Indirect(reflect.ValueOf(scanResultContainers[ii])).Interface() + } + return +} + +func rows2Interfaces(rows *core.Rows) (resultsSlice []map[string]interface{}, err error) { + fields, err := rows.Columns() + if err != nil { + return nil, err + } + for rows.Next() { + result, err := row2mapInterface(rows, fields) + if err != nil { + return nil, err + } + resultsSlice = append(resultsSlice, result) + } + + return resultsSlice, nil +} + +// QueryInterface runs a raw sql and return records as []map[string]interface{} +func (session *Session) QueryInterface(sqlStr string, args ...interface{}) ([]map[string]interface{}, error) { + if session.isAutoClose { + defer session.Close() + } + + rows, err := session.queryRows(sqlStr, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + return rows2Interfaces(rows) +} diff --git a/session_query_test.go b/session_query_test.go new file mode 100644 index 00000000..4bb4598b --- /dev/null +++ b/session_query_test.go @@ -0,0 +1,136 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import ( + "fmt" + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestQueryString(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type GetVar2 struct { + Id int64 `xorm:"autoincr pk"` + Msg string `xorm:"varchar(255)"` + Age int + Money float32 + Created time.Time `xorm:"created"` + } + + assert.NoError(t, testEngine.Sync2(new(GetVar2))) + + var data = GetVar2{ + Msg: "hi", + Age: 28, + Money: 1.5, + } + _, err := testEngine.InsertOne(data) + assert.NoError(t, err) + + records, err := testEngine.QueryString("select * from get_var2") + assert.NoError(t, err) + assert.Equal(t, 1, len(records)) + assert.Equal(t, 5, len(records[0])) + assert.Equal(t, "1", records[0]["id"]) + assert.Equal(t, "hi", records[0]["msg"]) + assert.Equal(t, "28", records[0]["age"]) + assert.Equal(t, "1.5", records[0]["money"]) +} + +func TestQueryString2(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type GetVar3 struct { + Id int64 `xorm:"autoincr pk"` + Msg bool `xorm:"bit"` + } + + assert.NoError(t, testEngine.Sync2(new(GetVar3))) + + var data = GetVar3{ + Msg: false, + } + _, err := testEngine.Insert(data) + assert.NoError(t, err) + + records, err := testEngine.QueryString("select * from get_var3") + assert.NoError(t, err) + assert.Equal(t, 1, len(records)) + assert.Equal(t, 2, len(records[0])) + assert.Equal(t, "1", records[0]["id"]) + assert.True(t, "0" == records[0]["msg"] || "false" == records[0]["msg"]) +} + +func toString(i interface{}) string { + switch i.(type) { + case []byte: + return string(i.([]byte)) + case string: + return i.(string) + } + return fmt.Sprintf("%v", i) +} + +func toInt64(i interface{}) int64 { + switch i.(type) { + case []byte: + n, _ := strconv.ParseInt(string(i.([]byte)), 10, 64) + return n + case int: + return int64(i.(int)) + case int64: + return i.(int64) + } + return 0 +} + +func toFloat64(i interface{}) float64 { + switch i.(type) { + case []byte: + n, _ := strconv.ParseFloat(string(i.([]byte)), 64) + return n + case float64: + return i.(float64) + case float32: + return float64(i.(float32)) + } + return 0 +} + +func TestQueryInterface(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type GetVarInterface struct { + Id int64 `xorm:"autoincr pk"` + Msg string `xorm:"varchar(255)"` + Age int + Money float32 + Created time.Time `xorm:"created"` + } + + assert.NoError(t, testEngine.Sync2(new(GetVarInterface))) + + var data = GetVarInterface{ + Msg: "hi", + Age: 28, + Money: 1.5, + } + _, err := testEngine.InsertOne(data) + assert.NoError(t, err) + + records, err := testEngine.QueryInterface("select * from get_var_interface") + assert.NoError(t, err) + assert.Equal(t, 1, len(records)) + assert.Equal(t, 5, len(records[0])) + assert.EqualValues(t, 1, toInt64(records[0]["id"])) + assert.Equal(t, "hi", toString(records[0]["msg"])) + assert.EqualValues(t, 28, toInt64(records[0]["age"])) + assert.EqualValues(t, 1.5, toFloat64(records[0]["money"])) +} diff --git a/session_raw.go b/session_raw.go index 0f5a0a43..c225598e 100644 --- a/session_raw.go +++ b/session_raw.go @@ -6,21 +6,133 @@ package xorm import ( "database/sql" + "reflect" + "time" "github.com/go-xorm/core" ) -func (session *Session) query(sqlStr string, paramStr ...interface{}) ([]map[string][]byte, error) { - session.queryPreprocess(&sqlStr, paramStr...) - - if session.IsAutoCommit { - return session.innerQuery2(sqlStr, paramStr...) +func (session *Session) queryPreprocess(sqlStr *string, paramStr ...interface{}) { + for _, filter := range session.engine.dialect.Filters() { + *sqlStr = filter.Do(*sqlStr, session.engine.dialect, session.statement.RefTable) } - return session.txQuery(session.Tx, sqlStr, paramStr...) + + session.lastSQL = *sqlStr + session.lastSQLArgs = paramStr } -func (session *Session) txQuery(tx *core.Tx, sqlStr string, params ...interface{}) ([]map[string][]byte, error) { - rows, err := tx.Query(sqlStr, params...) +func (session *Session) queryRows(sqlStr string, args ...interface{}) (*core.Rows, error) { + defer session.resetStatement() + + session.queryPreprocess(&sqlStr, args...) + + if session.engine.showSQL { + if session.engine.showExecTime { + b4ExecTime := time.Now() + defer func() { + execDuration := time.Since(b4ExecTime) + if len(args) > 0 { + session.engine.logger.Infof("[SQL] %s %#v - took: %v", sqlStr, args, execDuration) + } else { + session.engine.logger.Infof("[SQL] %s - took: %v", sqlStr, execDuration) + } + }() + } else { + if len(args) > 0 { + session.engine.logger.Infof("[SQL] %v %#v", sqlStr, args) + } else { + session.engine.logger.Infof("[SQL] %v", sqlStr) + } + } + } + + if session.isAutoCommit { + if session.prepareStmt { + // don't clear stmt since session will cache them + stmt, err := session.doPrepare(sqlStr) + if err != nil { + return nil, err + } + + rows, err := stmt.Query(args...) + if err != nil { + return nil, err + } + return rows, nil + } + + rows, err := session.DB().Query(sqlStr, args...) + if err != nil { + return nil, err + } + return rows, nil + } + + rows, err := session.tx.Query(sqlStr, args...) + if err != nil { + return nil, err + } + return rows, nil +} + +func (session *Session) queryRow(sqlStr string, args ...interface{}) *core.Row { + return core.NewRow(session.queryRows(sqlStr, args...)) +} + +func value2Bytes(rawValue *reflect.Value) ([]byte, error) { + str, err := value2String(rawValue) + if err != nil { + return nil, err + } + return []byte(str), nil +} + +func row2map(rows *core.Rows, fields []string) (resultsMap map[string][]byte, err error) { + result := make(map[string][]byte) + scanResultContainers := make([]interface{}, len(fields)) + for i := 0; i < len(fields); i++ { + var scanResultContainer interface{} + scanResultContainers[i] = &scanResultContainer + } + if err := rows.Scan(scanResultContainers...); err != nil { + return nil, err + } + + for ii, key := range fields { + rawValue := reflect.Indirect(reflect.ValueOf(scanResultContainers[ii])) + //if row is null then ignore + if rawValue.Interface() == nil { + result[key] = []byte{} + continue + } + + if data, err := value2Bytes(&rawValue); err == nil { + result[key] = data + } else { + return nil, err // !nashtsai! REVIEW, should return err or just error log? + } + } + return result, nil +} + +func rows2maps(rows *core.Rows) (resultsSlice []map[string][]byte, err error) { + fields, err := rows.Columns() + if err != nil { + return nil, err + } + for rows.Next() { + result, err := row2map(rows, fields) + if err != nil { + return nil, err + } + resultsSlice = append(resultsSlice, result) + } + + return resultsSlice, nil +} + +func (session *Session) queryBytes(sqlStr string, args ...interface{}) ([]map[string][]byte, error) { + rows, err := session.queryRows(sqlStr, args...) if err != nil { return nil, err } @@ -29,74 +141,35 @@ func (session *Session) txQuery(tx *core.Tx, sqlStr string, params ...interface{ return rows2maps(rows) } -func (session *Session) innerQuery(sqlStr string, params ...interface{}) (*core.Stmt, *core.Rows, error) { - var callback func() (*core.Stmt, *core.Rows, error) - if session.prepareStmt { - callback = func() (*core.Stmt, *core.Rows, error) { - stmt, err := session.doPrepare(sqlStr) - if err != nil { - return nil, nil, err - } - rows, err := stmt.Query(params...) - if err != nil { - return nil, nil, err - } - return stmt, rows, nil - } - } else { - callback = func() (*core.Stmt, *core.Rows, error) { - rows, err := session.DB().Query(sqlStr, params...) - if err != nil { - return nil, nil, err - } - return nil, rows, err - } - } - stmt, rows, err := session.Engine.logSQLQueryTime(sqlStr, params, callback) - if err != nil { - return nil, nil, err - } - return stmt, rows, nil -} - -func (session *Session) innerQuery2(sqlStr string, params ...interface{}) ([]map[string][]byte, error) { - _, rows, err := session.innerQuery(sqlStr, params...) - if rows != nil { - defer rows.Close() - } - if err != nil { - return nil, err - } - return rows2maps(rows) -} - -// Query runs a raw sql and return records as []map[string][]byte -func (session *Session) Query(sqlStr string, paramStr ...interface{}) ([]map[string][]byte, error) { +func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, error) { defer session.resetStatement() - if session.IsAutoClose { - defer session.Close() - } - - return session.query(sqlStr, paramStr...) -} - -// QueryString runs a raw sql and return records as []map[string]string -func (session *Session) QueryString(sqlStr string, args ...interface{}) ([]map[string]string, error) { - defer session.resetStatement() - if session.IsAutoClose { - defer session.Close() - } session.queryPreprocess(&sqlStr, args...) - if session.IsAutoCommit { - return query2(session.DB(), sqlStr, args...) + if session.engine.showSQL { + if session.engine.showExecTime { + b4ExecTime := time.Now() + defer func() { + execDuration := time.Since(b4ExecTime) + if len(args) > 0 { + session.engine.logger.Infof("[SQL] %s %#v - took: %v", sqlStr, args, execDuration) + } else { + session.engine.logger.Infof("[SQL] %s - took: %v", sqlStr, execDuration) + } + }() + } else { + if len(args) > 0 { + session.engine.logger.Infof("[SQL] %v %#v", sqlStr, args) + } else { + session.engine.logger.Infof("[SQL] %v", sqlStr) + } + } + } + + if !session.isAutoCommit { + return session.tx.Exec(sqlStr, args...) } - return txQuery2(session.Tx, sqlStr, args...) -} -// Execute sql -func (session *Session) innerExec(sqlStr string, args ...interface{}) (sql.Result, error) { if session.prepareStmt { stmt, err := session.doPrepare(sqlStr) if err != nil { @@ -113,33 +186,9 @@ func (session *Session) innerExec(sqlStr string, args ...interface{}) (sql.Resul return session.DB().Exec(sqlStr, args...) } -func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, error) { - for _, filter := range session.Engine.dialect.Filters() { - // TODO: for table name, it's no need to RefTable - sqlStr = filter.Do(sqlStr, session.Engine.dialect, session.Statement.RefTable) - } - - session.saveLastSQL(sqlStr, args...) - - return session.Engine.logSQLExecutionTime(sqlStr, args, func() (sql.Result, error) { - if session.IsAutoCommit { - // FIXME: oci8 can not auto commit (github.com/mattn/go-oci8) - if session.Engine.dialect.DBType() == core.ORACLE { - session.Begin() - r, err := session.Tx.Exec(sqlStr, args...) - session.Commit() - return r, err - } - return session.innerExec(sqlStr, args...) - } - return session.Tx.Exec(sqlStr, args...) - }) -} - // Exec raw sql func (session *Session) Exec(sqlStr string, args ...interface{}) (sql.Result, error) { - defer session.resetStatement() - if session.IsAutoClose { + if session.isAutoClose { defer session.Close() } diff --git a/session_raw_test.go b/session_raw_test.go index 126f617f..cf381974 100644 --- a/session_raw_test.go +++ b/session_raw_test.go @@ -5,39 +5,33 @@ package xorm import ( + "strconv" "testing" - "time" "github.com/stretchr/testify/assert" ) -func TestQueryString(t *testing.T) { +func TestQuery(t *testing.T) { assert.NoError(t, prepareEngine()) - type GetVar struct { - Id int64 `xorm:"autoincr pk"` - Msg string `xorm:"varchar(255)"` - Age int - Money float32 - Created time.Time `xorm:"created"` + type UserinfoQuery struct { + Uid int + Name string } - assert.NoError(t, testEngine.Sync2(new(GetVar))) + assert.NoError(t, testEngine.Sync(new(UserinfoQuery))) - var data = GetVar{ - Msg: "hi", - Age: 28, - Money: 1.5, - } - _, err := testEngine.InsertOne(data) + res, err := testEngine.Exec("INSERT INTO `userinfo_query` (uid, name) VALUES (?, ?)", 1, "user") assert.NoError(t, err) + cnt, err := res.RowsAffected() + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) - records, err := testEngine.QueryString("select * from get_var") + results, err := testEngine.Query("select * from userinfo_query") assert.NoError(t, err) - assert.Equal(t, 1, len(records)) - assert.Equal(t, 5, len(records[0])) - assert.Equal(t, "1", records[0]["id"]) - assert.Equal(t, "hi", records[0]["msg"]) - assert.Equal(t, "28", records[0]["age"]) - assert.Equal(t, "1.5", records[0]["money"]) + assert.EqualValues(t, 1, len(results)) + id, err := strconv.Atoi(string(results[0]["uid"])) + assert.NoError(t, err) + assert.EqualValues(t, 1, id) + assert.Equal(t, "user", string(results[0]["name"])) } diff --git a/session_schema.go b/session_schema.go index 19c0cbf5..a2708b73 100644 --- a/session_schema.go +++ b/session_schema.go @@ -16,42 +16,50 @@ import ( // Ping test if database is ok func (session *Session) Ping() error { - defer session.resetStatement() - if session.IsAutoClose { + if session.isAutoClose { defer session.Close() } + session.engine.logger.Infof("PING DATABASE %v", session.engine.DriverName()) return session.DB().Ping() } // CreateTable create a table according a bean func (session *Session) CreateTable(bean interface{}) error { - v := rValue(bean) - if err := session.Statement.setRefValue(v); err != nil { - return err - } - - defer session.resetStatement() - if session.IsAutoClose { + if session.isAutoClose { defer session.Close() } - return session.createOneTable() + return session.createTable(bean) +} + +func (session *Session) createTable(bean interface{}) error { + v := rValue(bean) + if err := session.statement.setRefValue(v); err != nil { + return err + } + + sqlStr := session.statement.genCreateTableSQL() + _, err := session.exec(sqlStr) + return err } // CreateIndexes create indexes func (session *Session) CreateIndexes(bean interface{}) error { - v := rValue(bean) - if err := session.Statement.setRefValue(v); err != nil { - return err - } - - defer session.resetStatement() - if session.IsAutoClose { + if session.isAutoClose { defer session.Close() } - sqls := session.Statement.genIndexSQL() + return session.createIndexes(bean) +} + +func (session *Session) createIndexes(bean interface{}) error { + v := rValue(bean) + if err := session.statement.setRefValue(v); err != nil { + return err + } + + sqls := session.statement.genIndexSQL() for _, sqlStr := range sqls { _, err := session.exec(sqlStr) if err != nil { @@ -63,17 +71,19 @@ func (session *Session) CreateIndexes(bean interface{}) error { // CreateUniques create uniques func (session *Session) CreateUniques(bean interface{}) error { + if session.isAutoClose { + defer session.Close() + } + return session.createUniques(bean) +} + +func (session *Session) createUniques(bean interface{}) error { v := rValue(bean) - if err := session.Statement.setRefValue(v); err != nil { + if err := session.statement.setRefValue(v); err != nil { return err } - defer session.resetStatement() - if session.IsAutoClose { - defer session.Close() - } - - sqls := session.Statement.genUniqueSQL() + sqls := session.statement.genUniqueSQL() for _, sqlStr := range sqls { _, err := session.exec(sqlStr) if err != nil { @@ -83,43 +93,22 @@ func (session *Session) CreateUniques(bean interface{}) error { return nil } -func (session *Session) createOneTable() error { - sqlStr := session.Statement.genCreateTableSQL() - _, err := session.exec(sqlStr) - return err -} - -// to be deleted -func (session *Session) createAll() error { - if session.IsAutoClose { +// DropIndexes drop indexes +func (session *Session) DropIndexes(bean interface{}) error { + if session.isAutoClose { defer session.Close() } - for _, table := range session.Engine.Tables { - session.Statement.RefTable = table - session.Statement.tableName = table.Name - err := session.createOneTable() - session.resetStatement() - if err != nil { - return err - } - } - return nil + return session.dropIndexes(bean) } -// DropIndexes drop indexes -func (session *Session) DropIndexes(bean interface{}) error { +func (session *Session) dropIndexes(bean interface{}) error { v := rValue(bean) - if err := session.Statement.setRefValue(v); err != nil { + if err := session.statement.setRefValue(v); err != nil { return err } - defer session.resetStatement() - if session.IsAutoClose { - defer session.Close() - } - - sqls := session.Statement.genDelIndexSQL() + sqls := session.statement.genDelIndexSQL() for _, sqlStr := range sqls { _, err := session.exec(sqlStr) if err != nil { @@ -131,15 +120,23 @@ func (session *Session) DropIndexes(bean interface{}) error { // DropTable drop table will drop table if exist, if drop failed, it will return error func (session *Session) DropTable(beanOrTableName interface{}) error { - tableName, err := session.Engine.tableName(beanOrTableName) + if session.isAutoClose { + defer session.Close() + } + + return session.dropTable(beanOrTableName) +} + +func (session *Session) dropTable(beanOrTableName interface{}) error { + tableName, err := session.engine.tableName(beanOrTableName) if err != nil { return err } var needDrop = true - if !session.Engine.dialect.SupportDropIfExists() { - sqlStr, args := session.Engine.dialect.TableCheckSql(tableName) - results, err := session.query(sqlStr, args...) + if !session.engine.dialect.SupportDropIfExists() { + sqlStr, args := session.engine.dialect.TableCheckSql(tableName) + results, err := session.queryBytes(sqlStr, args...) if err != nil { return err } @@ -147,7 +144,7 @@ func (session *Session) DropTable(beanOrTableName interface{}) error { } if needDrop { - sqlStr := session.Engine.Dialect().DropTableSql(tableName) + sqlStr := session.engine.Dialect().DropTableSql(tableName) _, err = session.exec(sqlStr) return err } @@ -156,7 +153,11 @@ func (session *Session) DropTable(beanOrTableName interface{}) error { // IsTableExist if a table is exist func (session *Session) IsTableExist(beanOrTableName interface{}) (bool, error) { - tableName, err := session.Engine.tableName(beanOrTableName) + if session.isAutoClose { + defer session.Close() + } + + tableName, err := session.engine.tableName(beanOrTableName) if err != nil { return false, err } @@ -165,12 +166,8 @@ func (session *Session) IsTableExist(beanOrTableName interface{}) (bool, error) } func (session *Session) isTableExist(tableName string) (bool, error) { - defer session.resetStatement() - if session.IsAutoClose { - defer session.Close() - } - sqlStr, args := session.Engine.dialect.TableCheckSql(tableName) - results, err := session.query(sqlStr, args...) + sqlStr, args := session.engine.dialect.TableCheckSql(tableName) + results, err := session.queryBytes(sqlStr, args...) return len(results) > 0, err } @@ -180,6 +177,9 @@ func (session *Session) IsTableEmpty(bean interface{}) (bool, error) { t := v.Type() if t.Kind() == reflect.String { + if session.isAutoClose { + defer session.Close() + } return session.isTableEmpty(bean.(string)) } else if t.Kind() == reflect.Struct { rows, err := session.Count(bean) @@ -189,15 +189,9 @@ func (session *Session) IsTableEmpty(bean interface{}) (bool, error) { } func (session *Session) isTableEmpty(tableName string) (bool, error) { - defer session.resetStatement() - if session.IsAutoClose { - defer session.Close() - } - var total int64 - sqlStr := fmt.Sprintf("select count(*) from %s", session.Engine.Quote(tableName)) - err := session.DB().QueryRow(sqlStr).Scan(&total) - session.saveLastSQL(sqlStr) + sqlStr := fmt.Sprintf("select count(*) from %s", session.engine.Quote(tableName)) + err := session.queryRow(sqlStr).Scan(&total) if err != nil { if err == sql.ErrNoRows { err = nil @@ -208,30 +202,9 @@ func (session *Session) isTableEmpty(tableName string) (bool, error) { return total == 0, nil } -func (session *Session) isIndexExist(tableName, idxName string, unique bool) (bool, error) { - defer session.resetStatement() - if session.IsAutoClose { - defer session.Close() - } - var idx string - if unique { - idx = uniqueName(tableName, idxName) - } else { - idx = indexName(tableName, idxName) - } - sqlStr, args := session.Engine.dialect.IndexCheckSql(tableName, idx) - results, err := session.query(sqlStr, args...) - return len(results) > 0, err -} - // find if index is exist according cols func (session *Session) isIndexExist2(tableName string, cols []string, unique bool) (bool, error) { - defer session.resetStatement() - if session.IsAutoClose { - defer session.Close() - } - - indexes, err := session.Engine.dialect.GetIndexes(tableName) + indexes, err := session.engine.dialect.GetIndexes(tableName) if err != nil { return false, err } @@ -248,62 +221,34 @@ func (session *Session) isIndexExist2(tableName string, cols []string, unique bo } func (session *Session) addColumn(colName string) error { - defer session.resetStatement() - if session.IsAutoClose { - defer session.Close() - } - - col := session.Statement.RefTable.GetColumn(colName) - sql, args := session.Statement.genAddColumnStr(col) + col := session.statement.RefTable.GetColumn(colName) + sql, args := session.statement.genAddColumnStr(col) _, err := session.exec(sql, args...) return err } func (session *Session) addIndex(tableName, idxName string) error { - defer session.resetStatement() - if session.IsAutoClose { - defer session.Close() - } - index := session.Statement.RefTable.Indexes[idxName] - sqlStr := session.Engine.dialect.CreateIndexSql(tableName, index) - + index := session.statement.RefTable.Indexes[idxName] + sqlStr := session.engine.dialect.CreateIndexSql(tableName, index) _, err := session.exec(sqlStr) return err } func (session *Session) addUnique(tableName, uqeName string) error { - defer session.resetStatement() - if session.IsAutoClose { - defer session.Close() - } - index := session.Statement.RefTable.Indexes[uqeName] - sqlStr := session.Engine.dialect.CreateIndexSql(tableName, index) + index := session.statement.RefTable.Indexes[uqeName] + sqlStr := session.engine.dialect.CreateIndexSql(tableName, index) _, err := session.exec(sqlStr) return err } -// To be deleted -func (session *Session) dropAll() error { - defer session.resetStatement() - if session.IsAutoClose { - defer session.Close() - } - - for _, table := range session.Engine.Tables { - session.Statement.Init() - session.Statement.RefTable = table - sqlStr := session.Engine.Dialect().DropTableSql(session.Statement.TableName()) - _, err := session.exec(sqlStr) - if err != nil { - return err - } - } - return nil -} - // Sync2 synchronize structs to database tables func (session *Session) Sync2(beans ...interface{}) error { - engine := session.Engine + engine := session.engine + + if session.isAutoClose { + session.isAutoClose = false + defer session.Close() + } tables, err := engine.DBMetas() if err != nil { @@ -330,17 +275,17 @@ func (session *Session) Sync2(beans ...interface{}) error { } if oriTable == nil { - err = session.StoreEngine(session.Statement.StoreEngine).CreateTable(bean) + err = session.StoreEngine(session.statement.StoreEngine).createTable(bean) if err != nil { return err } - err = session.CreateUniques(bean) + err = session.createUniques(bean) if err != nil { return err } - err = session.CreateIndexes(bean) + err = session.createIndexes(bean) if err != nil { return err } @@ -365,7 +310,7 @@ func (session *Session) Sync2(beans ...interface{}) error { engine.dialect.DBType() == core.POSTGRES { engine.logger.Infof("Table %s column %s change type from %s to %s\n", tbName, col.Name, curType, expectedType) - _, err = engine.Exec(engine.dialect.ModifyColumnSql(table.Name, col)) + _, err = session.exec(engine.dialect.ModifyColumnSql(table.Name, col)) } else { engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s\n", tbName, col.Name, curType, expectedType) @@ -375,7 +320,7 @@ func (session *Session) Sync2(beans ...interface{}) error { if oriCol.Length < col.Length { engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n", tbName, col.Name, oriCol.Length, col.Length) - _, err = engine.Exec(engine.dialect.ModifyColumnSql(table.Name, col)) + _, err = session.exec(engine.dialect.ModifyColumnSql(table.Name, col)) } } } else { @@ -389,7 +334,7 @@ func (session *Session) Sync2(beans ...interface{}) error { if oriCol.Length < col.Length { engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n", tbName, col.Name, oriCol.Length, col.Length) - _, err = engine.Exec(engine.dialect.ModifyColumnSql(table.Name, col)) + _, err = session.exec(engine.dialect.ModifyColumnSql(table.Name, col)) } } } @@ -402,10 +347,8 @@ func (session *Session) Sync2(beans ...interface{}) error { tbName, col.Name, oriCol.Nullable, col.Nullable) } } else { - session := engine.NewSession() - session.Statement.RefTable = table - session.Statement.tableName = tbName - defer session.Close() + session.statement.RefTable = table + session.statement.tableName = tbName err = session.addColumn(col.Name) } if err != nil { @@ -429,7 +372,7 @@ func (session *Session) Sync2(beans ...interface{}) error { if oriIndex != nil { if oriIndex.Type != index.Type { sql := engine.dialect.DropIndexSql(tbName, oriIndex) - _, err = engine.Exec(sql) + _, err = session.exec(sql) if err != nil { return err } @@ -445,7 +388,7 @@ func (session *Session) Sync2(beans ...interface{}) error { for name2, index2 := range oriTable.Indexes { if _, ok := foundIndexNames[name2]; !ok { sql := engine.dialect.DropIndexSql(tbName, index2) - _, err = engine.Exec(sql) + _, err = session.exec(sql) if err != nil { return err } @@ -454,16 +397,12 @@ func (session *Session) Sync2(beans ...interface{}) error { for name, index := range addedNames { if index.Type == core.UniqueType { - session := engine.NewSession() - session.Statement.RefTable = table - session.Statement.tableName = tbName - defer session.Close() + session.statement.RefTable = table + session.statement.tableName = tbName err = session.addUnique(tbName, name) } else if index.Type == core.IndexType { - session := engine.NewSession() - session.Statement.RefTable = table - session.Statement.tableName = tbName - defer session.Close() + session.statement.RefTable = table + session.statement.tableName = tbName err = session.addIndex(tbName, name) } if err != nil { diff --git a/session_schema_test.go b/session_schema_test.go new file mode 100644 index 00000000..be999ce3 --- /dev/null +++ b/session_schema_test.go @@ -0,0 +1,219 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import ( + "fmt" + "os" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestStoreEngine(t *testing.T) { + assert.NoError(t, prepareEngine()) + + assert.NoError(t, testEngine.DropTables("user_store_engine")) + + type UserinfoStoreEngine struct { + Id int64 + Name string + } + + assert.NoError(t, testEngine.StoreEngine("InnoDB").Table("user_store_engine").CreateTable(&UserinfoStoreEngine{})) +} + +func TestCreateTable(t *testing.T) { + assert.NoError(t, prepareEngine()) + + assert.NoError(t, testEngine.DropTables("user_user")) + + type UserinfoCreateTable struct { + Id int64 + Name string + } + + assert.NoError(t, testEngine.Table("user_user").CreateTable(&UserinfoCreateTable{})) +} + +func TestCreateMultiTables(t *testing.T) { + assert.NoError(t, prepareEngine()) + + session := testEngine.NewSession() + defer session.Close() + + type UserinfoMultiTable struct { + Id int64 + Name string + } + + user := &UserinfoMultiTable{} + assert.NoError(t, session.Begin()) + + for i := 0; i < 10; i++ { + tableName := fmt.Sprintf("user_%v", i) + + assert.NoError(t, session.DropTable(tableName)) + + assert.NoError(t, session.Table(tableName).CreateTable(user)) + } + + assert.NoError(t, session.Commit()) +} + +type SyncTable1 struct { + Id int64 + Name string + Dev int `xorm:"index"` +} + +type SyncTable2 struct { + Id int64 + Name string `xorm:"unique"` + Number string `xorm:"index"` + Dev int + Age int +} + +func (SyncTable2) TableName() string { + return "sync_table1" +} + +func TestSyncTable(t *testing.T) { + assert.NoError(t, prepareEngine()) + + assert.NoError(t, testEngine.Sync2(new(SyncTable1))) + + assert.NoError(t, testEngine.Sync2(new(SyncTable2))) +} + +func TestIsTableExist(t *testing.T) { + assert.NoError(t, prepareEngine()) + + exist, err := testEngine.IsTableExist(new(CustomTableName)) + assert.NoError(t, err) + assert.False(t, exist) + + assert.NoError(t, testEngine.CreateTables(new(CustomTableName))) + + exist, err = testEngine.IsTableExist(new(CustomTableName)) + assert.NoError(t, err) + assert.True(t, exist) +} + +func TestIsTableEmpty(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type NumericEmpty struct { + Numeric float64 `xorm:"numeric(26,2)"` + } + + type PictureEmpty struct { + Id int64 + Url string `xorm:"unique"` //image's url + Title string + Description string + Created time.Time `xorm:"created"` + ILike int + PageView int + From_url string + Pre_url string `xorm:"unique"` //pre view image's url + Uid int64 + } + + assert.NoError(t, testEngine.DropTables(&PictureEmpty{}, &NumericEmpty{})) + + assert.NoError(t, testEngine.Sync(new(PictureEmpty), new(NumericEmpty))) + + isEmpty, err := testEngine.IsTableEmpty(&PictureEmpty{}) + assert.NoError(t, err) + assert.True(t, isEmpty) + + tbName := testEngine.TableMapper.Obj2Table("PictureEmpty") + isEmpty, err = testEngine.IsTableEmpty(tbName) + assert.NoError(t, err) + assert.True(t, isEmpty) +} + +type CustomTableName struct { + Id int64 + Name string +} + +func (c *CustomTableName) TableName() string { + return "customtablename" +} + +func TestCustomTableName(t *testing.T) { + assert.NoError(t, prepareEngine()) + + c := new(CustomTableName) + assert.NoError(t, testEngine.DropTables(c)) + + assert.NoError(t, testEngine.CreateTables(c)) +} + +func TestDump(t *testing.T) { + assert.NoError(t, prepareEngine()) + + fp := testEngine.Dialect().URI().DbName + ".sql" + os.Remove(fp) + assert.NoError(t, testEngine.DumpAllToFile(fp)) +} + +type IndexOrUnique struct { + Id int64 + Index int `xorm:"index"` + Unique int `xorm:"unique"` + Group1 int `xorm:"index(ttt)"` + Group2 int `xorm:"index(ttt)"` + UniGroup1 int `xorm:"unique(lll)"` + UniGroup2 int `xorm:"unique(lll)"` +} + +func TestIndexAndUnique(t *testing.T) { + assert.NoError(t, prepareEngine()) + + assert.NoError(t, testEngine.CreateTables(&IndexOrUnique{})) + + assert.NoError(t, testEngine.DropTables(&IndexOrUnique{})) + + assert.NoError(t, testEngine.CreateTables(&IndexOrUnique{})) + + assert.NoError(t, testEngine.CreateIndexes(&IndexOrUnique{})) + + assert.NoError(t, testEngine.CreateUniques(&IndexOrUnique{})) + + assert.NoError(t, testEngine.DropIndexes(&IndexOrUnique{})) +} + +func TestMetaInfo(t *testing.T) { + assert.NoError(t, prepareEngine()) + assert.NoError(t, testEngine.Sync2(new(CustomTableName), new(IndexOrUnique))) + + tables, err := testEngine.DBMetas() + assert.NoError(t, err) + assert.EqualValues(t, 2, len(tables)) + tableNames := []string{tables[0].Name, tables[1].Name} + assert.Contains(t, tableNames, "customtablename") + assert.Contains(t, tableNames, "index_or_unique") +} + +func TestCharst(t *testing.T) { + assert.NoError(t, prepareEngine()) + + err := testEngine.DropTables("user_charset") + if err != nil { + t.Error(err) + panic(err) + } + + err = testEngine.Charset("utf8").Table("user_charset").CreateTable(&Userinfo{}) + if err != nil { + t.Error(err) + panic(err) + } +} diff --git a/session_stats.go b/session_stats.go new file mode 100644 index 00000000..c2cac830 --- /dev/null +++ b/session_stats.go @@ -0,0 +1,98 @@ +// Copyright 2016 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import ( + "database/sql" + "errors" + "reflect" +) + +// Count counts the records. bean's non-empty fields +// are conditions. +func (session *Session) Count(bean ...interface{}) (int64, error) { + if session.isAutoClose { + defer session.Close() + } + + var sqlStr string + var args []interface{} + var err error + if session.statement.RawSQL == "" { + sqlStr, args, err = session.statement.genCountSQL(bean...) + if err != nil { + return 0, err + } + } else { + sqlStr = session.statement.RawSQL + args = session.statement.RawParams + } + + var total int64 + err = session.queryRow(sqlStr, args...).Scan(&total) + if err == sql.ErrNoRows || err == nil { + return total, nil + } + + return 0, err +} + +// sum call sum some column. bean's non-empty fields are conditions. +func (session *Session) sum(res interface{}, bean interface{}, columnNames ...string) error { + if session.isAutoClose { + defer session.Close() + } + + v := reflect.ValueOf(res) + if v.Kind() != reflect.Ptr { + return errors.New("need a pointer to a variable") + } + + var isSlice = v.Elem().Kind() == reflect.Slice + var sqlStr string + var args []interface{} + var err error + if len(session.statement.RawSQL) == 0 { + sqlStr, args, err = session.statement.genSumSQL(bean, columnNames...) + if err != nil { + return err + } + } else { + sqlStr = session.statement.RawSQL + args = session.statement.RawParams + } + + if isSlice { + err = session.queryRow(sqlStr, args...).ScanSlice(res) + } else { + err = session.queryRow(sqlStr, args...).Scan(res) + } + if err == sql.ErrNoRows || err == nil { + return nil + } + return err +} + +// Sum call sum some column. bean's non-empty fields are conditions. +func (session *Session) Sum(bean interface{}, columnName string) (res float64, err error) { + return res, session.sum(&res, bean, columnName) +} + +// SumInt call sum some column. bean's non-empty fields are conditions. +func (session *Session) SumInt(bean interface{}, columnName string) (res int64, err error) { + return res, session.sum(&res, bean, columnName) +} + +// Sums call sum some columns. bean's non-empty fields are conditions. +func (session *Session) Sums(bean interface{}, columnNames ...string) ([]float64, error) { + var res = make([]float64, len(columnNames), len(columnNames)) + return res, session.sum(&res, bean, columnNames...) +} + +// SumsInt sum specify columns and return as []int64 instead of []float64 +func (session *Session) SumsInt(bean interface{}, columnNames ...string) ([]int64, error) { + var res = make([]int64, len(columnNames), len(columnNames)) + return res, session.sum(&res, bean, columnNames...) +} diff --git a/session_sum_test.go b/session_stats_test.go similarity index 57% rename from session_sum_test.go rename to session_stats_test.go index 31a65f98..17eaf6dc 100644 --- a/session_sum_test.go +++ b/session_stats_test.go @@ -9,6 +9,7 @@ import ( "strconv" "testing" + "github.com/go-xorm/builder" "github.com/stretchr/testify/assert" ) @@ -75,27 +76,85 @@ func TestSum(t *testing.T) { func TestSumCustomColumn(t *testing.T) { assert.NoError(t, prepareEngine()) - type SumStruct struct { + type SumStruct2 struct { Int int Float float32 } var ( - cases = []SumStruct{ + cases = []SumStruct2{ {1, 6.2}, {2, 5.3}, {92, -0.2}, } ) - assert.NoError(t, testEngine.Sync2(new(SumStruct))) + assert.NoError(t, testEngine.Sync2(new(SumStruct2))) cnt, err := testEngine.Insert(cases) assert.NoError(t, err) assert.EqualValues(t, 3, cnt) - sumInt, err := testEngine.Sum(new(SumStruct), + sumInt, err := testEngine.Sum(new(SumStruct2), "CASE WHEN `int` <= 2 THEN `int` ELSE 0 END") assert.NoError(t, err) assert.EqualValues(t, 3, int(sumInt)) } + +func TestCount(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type UserinfoCount struct { + Departname string + } + assert.NoError(t, testEngine.Sync2(new(UserinfoCount))) + + colName := testEngine.ColumnMapper.Obj2Table("Departname") + var cond builder.Cond = builder.Eq{ + "`" + colName + "`": "dev", + } + + total, err := testEngine.Where(cond).Count(new(UserinfoCount)) + assert.NoError(t, err) + assert.EqualValues(t, 0, total) + + cnt, err := testEngine.Insert(&UserinfoCount{ + Departname: "dev", + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + total, err = testEngine.Where(cond).Count(new(UserinfoCount)) + assert.NoError(t, err) + assert.EqualValues(t, 1, total) + + total, err = testEngine.Where(cond).Table("userinfo_count").Count() + assert.NoError(t, err) + assert.EqualValues(t, 1, total) + + total, err = testEngine.Table("userinfo_count").Count() + assert.NoError(t, err) + assert.EqualValues(t, 1, total) +} + +func TestSQLCount(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type UserinfoCount2 struct { + Id int64 + Departname string + } + + type UserinfoBooks struct { + Id int64 + Pid int64 + IsOpen bool + } + + assertSync(t, new(UserinfoCount2), new(UserinfoBooks)) + + total, err := testEngine.SQL("SELECT count(id) FROM userinfo_count2"). + Count() + assert.NoError(t, err) + assert.EqualValues(t, 0, total) +} diff --git a/session_sum.go b/session_sum.go deleted file mode 100644 index e1409c7f..00000000 --- a/session_sum.go +++ /dev/null @@ -1,137 +0,0 @@ -// Copyright 2016 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package xorm - -import "database/sql" - -// Count counts the records. bean's non-empty fields -// are conditions. -func (session *Session) Count(bean interface{}) (int64, error) { - defer session.resetStatement() - if session.IsAutoClose { - defer session.Close() - } - - var sqlStr string - var args []interface{} - if session.Statement.RawSQL == "" { - sqlStr, args = session.Statement.genCountSQL(bean) - } else { - sqlStr = session.Statement.RawSQL - args = session.Statement.RawParams - } - - session.queryPreprocess(&sqlStr, args...) - - var err error - var total int64 - if session.IsAutoCommit { - err = session.DB().QueryRow(sqlStr, args...).Scan(&total) - } else { - err = session.Tx.QueryRow(sqlStr, args...).Scan(&total) - } - - if err == sql.ErrNoRows || err == nil { - return total, nil - } - - return 0, err -} - -// Sum call sum some column. bean's non-empty fields are conditions. -func (session *Session) Sum(bean interface{}, columnName string) (float64, error) { - defer session.resetStatement() - if session.IsAutoClose { - defer session.Close() - } - - var sqlStr string - var args []interface{} - if len(session.Statement.RawSQL) == 0 { - sqlStr, args = session.Statement.genSumSQL(bean, columnName) - } else { - sqlStr = session.Statement.RawSQL - args = session.Statement.RawParams - } - - session.queryPreprocess(&sqlStr, args...) - - var err error - var res float64 - if session.IsAutoCommit { - err = session.DB().QueryRow(sqlStr, args...).Scan(&res) - } else { - err = session.Tx.QueryRow(sqlStr, args...).Scan(&res) - } - - if err == sql.ErrNoRows || err == nil { - return res, nil - } - return 0, err -} - -// Sums call sum some columns. bean's non-empty fields are conditions. -func (session *Session) Sums(bean interface{}, columnNames ...string) ([]float64, error) { - defer session.resetStatement() - if session.IsAutoClose { - defer session.Close() - } - - var sqlStr string - var args []interface{} - if len(session.Statement.RawSQL) == 0 { - sqlStr, args = session.Statement.genSumSQL(bean, columnNames...) - } else { - sqlStr = session.Statement.RawSQL - args = session.Statement.RawParams - } - - session.queryPreprocess(&sqlStr, args...) - - var err error - var res = make([]float64, len(columnNames), len(columnNames)) - if session.IsAutoCommit { - err = session.DB().QueryRow(sqlStr, args...).ScanSlice(&res) - } else { - err = session.Tx.QueryRow(sqlStr, args...).ScanSlice(&res) - } - - if err == sql.ErrNoRows || err == nil { - return res, nil - } - return nil, err -} - -// SumsInt sum specify columns and return as []int64 instead of []float64 -func (session *Session) SumsInt(bean interface{}, columnNames ...string) ([]int64, error) { - defer session.resetStatement() - if session.IsAutoClose { - defer session.Close() - } - - var sqlStr string - var args []interface{} - if len(session.Statement.RawSQL) == 0 { - sqlStr, args = session.Statement.genSumSQL(bean, columnNames...) - } else { - sqlStr = session.Statement.RawSQL - args = session.Statement.RawParams - } - - session.queryPreprocess(&sqlStr, args...) - - var err error - var res = make([]int64, len(columnNames), len(columnNames)) - if session.IsAutoCommit { - err = session.DB().QueryRow(sqlStr, args...).ScanSlice(&res) - } else { - err = session.Tx.QueryRow(sqlStr, args...).ScanSlice(&res) - } - - if err == sql.ErrNoRows || err == nil { - return res, nil - } - return nil, err -} diff --git a/session_test.go b/session_test.go new file mode 100644 index 00000000..d003274d --- /dev/null +++ b/session_test.go @@ -0,0 +1,23 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestClose(t *testing.T) { + assert.NoError(t, prepareEngine()) + + sess1 := testEngine.NewSession() + sess1.Close() + assert.True(t, sess1.IsClosed()) + + sess2 := testEngine.Where("a = ?", 1) + sess2.Close() + assert.True(t, sess2.IsClosed()) +} diff --git a/session_tx.go b/session_tx.go index 302bc104..84d2f7f9 100644 --- a/session_tx.go +++ b/session_tx.go @@ -6,14 +6,14 @@ package xorm // Begin a transaction func (session *Session) Begin() error { - if session.IsAutoCommit { + if session.isAutoCommit { tx, err := session.DB().Begin() if err != nil { return err } - session.IsAutoCommit = false - session.IsCommitedOrRollbacked = false - session.Tx = tx + session.isAutoCommit = false + session.isCommitedOrRollbacked = false + session.tx = tx session.saveLastSQL("BEGIN TRANSACTION") } return nil @@ -21,25 +21,23 @@ func (session *Session) Begin() error { // Rollback When using transaction, you can rollback if any error func (session *Session) Rollback() error { - if !session.IsAutoCommit && !session.IsCommitedOrRollbacked { - session.saveLastSQL(session.Engine.dialect.RollBackStr()) - session.IsCommitedOrRollbacked = true - return session.Tx.Rollback() + if !session.isAutoCommit && !session.isCommitedOrRollbacked { + session.saveLastSQL(session.engine.dialect.RollBackStr()) + session.isCommitedOrRollbacked = true + return session.tx.Rollback() } return nil } // Commit When using transaction, Commit will commit all operations. func (session *Session) Commit() error { - if !session.IsAutoCommit && !session.IsCommitedOrRollbacked { + if !session.isAutoCommit && !session.isCommitedOrRollbacked { session.saveLastSQL("COMMIT") - session.IsCommitedOrRollbacked = true + session.isCommitedOrRollbacked = true var err error - if err = session.Tx.Commit(); err == nil { + if err = session.tx.Commit(); err == nil { // handle processors after tx committed - closureCallFunc := func(closuresPtr *[]func(interface{}), bean interface{}) { - if closuresPtr != nil { for _, closure := range *closuresPtr { closure(bean) diff --git a/session_tx_test.go b/session_tx_test.go new file mode 100644 index 00000000..3e71bb40 --- /dev/null +++ b/session_tx_test.go @@ -0,0 +1,192 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import ( + "fmt" + "testing" + "time" + + "github.com/go-xorm/core" + "github.com/stretchr/testify/assert" +) + +func TestTransaction(t *testing.T) { + assert.NoError(t, prepareEngine()) + assertSync(t, new(Userinfo)) + + counter := func() { + total, err := testEngine.Count(&Userinfo{}) + if err != nil { + t.Error(err) + } + fmt.Printf("----now total %v records\n", total) + } + + counter() + //defer counter() + + session := testEngine.NewSession() + defer session.Close() + + err := session.Begin() + if err != nil { + t.Error(err) + panic(err) + return + } + + user1 := Userinfo{Username: "xiaoxiao", Departname: "dev", Alias: "lunny", Created: time.Now()} + _, err = session.Insert(&user1) + if err != nil { + session.Rollback() + t.Error(err) + panic(err) + return + } + + user2 := Userinfo{Username: "yyy"} + _, err = session.Where("(id) = ?", 0).Update(&user2) + if err != nil { + session.Rollback() + fmt.Println(err) + //t.Error(err) + return + } + + _, err = session.Delete(&user2) + if err != nil { + session.Rollback() + t.Error(err) + panic(err) + return + } + + err = session.Commit() + if err != nil { + t.Error(err) + panic(err) + return + } + // panic(err) !nashtsai! should remove this +} + +func TestCombineTransaction(t *testing.T) { + assert.NoError(t, prepareEngine()) + assertSync(t, new(Userinfo)) + + counter := func() { + total, err := testEngine.Count(&Userinfo{}) + if err != nil { + t.Error(err) + } + fmt.Printf("----now total %v records\n", total) + } + + counter() + //defer counter() + session := testEngine.NewSession() + defer session.Close() + + err := session.Begin() + if err != nil { + t.Error(err) + panic(err) + } + + user1 := Userinfo{Username: "xiaoxiao2", Departname: "dev", Alias: "lunny", Created: time.Now()} + _, err = session.Insert(&user1) + if err != nil { + session.Rollback() + t.Error(err) + panic(err) + } + user2 := Userinfo{Username: "zzz"} + _, err = session.Where("id = ?", 0).Update(&user2) + if err != nil { + session.Rollback() + t.Error(err) + panic(err) + } + + _, err = session.Exec("delete from userinfo where username = ?", user2.Username) + if err != nil { + session.Rollback() + t.Error(err) + panic(err) + } + + err = session.Commit() + if err != nil { + t.Error(err) + panic(err) + } +} + +func TestCombineTransactionSameMapper(t *testing.T) { + assert.NoError(t, prepareEngine()) + + oldMapper := testEngine.ColumnMapper + testEngine.unMapType(rValue(new(Userinfo)).Type()) + testEngine.SetMapper(core.SameMapper{}) + defer func() { + testEngine.unMapType(rValue(new(Userinfo)).Type()) + testEngine.SetMapper(oldMapper) + }() + + assertSync(t, new(Userinfo)) + + counter := func() { + total, err := testEngine.Count(&Userinfo{}) + if err != nil { + t.Error(err) + } + fmt.Printf("----now total %v records\n", total) + } + + counter() + defer counter() + session := testEngine.NewSession() + defer session.Close() + + err := session.Begin() + if err != nil { + t.Error(err) + panic(err) + return + } + + user1 := Userinfo{Username: "xiaoxiao2", Departname: "dev", Alias: "lunny", Created: time.Now()} + _, err = session.Insert(&user1) + if err != nil { + session.Rollback() + t.Error(err) + panic(err) + return + } + + user2 := Userinfo{Username: "zzz"} + _, err = session.Where("(id) = ?", 0).Update(&user2) + if err != nil { + session.Rollback() + t.Error(err) + panic(err) + return + } + + _, err = session.Exec("delete from `Userinfo` where `Username` = ?", user2.Username) + if err != nil { + session.Rollback() + t.Error(err) + panic(err) + return + } + + err = session.Commit() + if err != nil { + t.Error(err) + panic(err) + } +} diff --git a/session_update.go b/session_update.go index 1d77d294..4e0f656d 100644 --- a/session_update.go +++ b/session_update.go @@ -15,20 +15,20 @@ import ( "github.com/go-xorm/core" ) -func (session *Session) cacheUpdate(sqlStr string, args ...interface{}) error { - if session.Statement.RefTable == nil || - session.Tx != nil { +func (session *Session) cacheUpdate(table *core.Table, tableName, sqlStr string, args ...interface{}) error { + if table == nil || + session.tx != nil { return ErrCacheFailed } - oldhead, newsql := session.Statement.convertUpdateSQL(sqlStr) + oldhead, newsql := session.statement.convertUpdateSQL(sqlStr) if newsql == "" { return ErrCacheFailed } - for _, filter := range session.Engine.dialect.Filters() { - newsql = filter.Do(newsql, session.Engine.dialect, session.Statement.RefTable) + for _, filter := range session.engine.dialect.Filters() { + newsql = filter.Do(newsql, session.engine.dialect, table) } - session.Engine.logger.Debug("[cacheUpdate] new sql", oldhead, newsql) + session.engine.logger.Debug("[cacheUpdate] new sql", oldhead, newsql) var nStart int if len(args) > 0 { @@ -39,13 +39,12 @@ func (session *Session) cacheUpdate(sqlStr string, args ...interface{}) error { nStart = strings.Count(oldhead, "$") } } - table := session.Statement.RefTable - cacher := session.Engine.getCacher2(table) - tableName := session.Statement.TableName() - session.Engine.logger.Debug("[cacheUpdate] get cache sql", newsql, args[nStart:]) + + cacher := session.engine.getCacher2(table) + session.engine.logger.Debug("[cacheUpdate] get cache sql", newsql, args[nStart:]) ids, err := core.GetCacheSql(cacher, tableName, newsql, args[nStart:]) if err != nil { - rows, err := session.DB().Query(newsql, args[nStart:]...) + rows, err := session.NoCache().queryRows(newsql, args[nStart:]...) if err != nil { return err } @@ -75,9 +74,9 @@ func (session *Session) cacheUpdate(sqlStr string, args ...interface{}) error { ids = append(ids, pk) } - session.Engine.logger.Debug("[cacheUpdate] find updated id", ids) + session.engine.logger.Debug("[cacheUpdate] find updated id", ids) } /*else { - session.Engine.LogDebug("[xorm:cacheUpdate] del cached sql:", tableName, newsql, args) + session.engine.LogDebug("[xorm:cacheUpdate] del cached sql:", tableName, newsql, args) cacher.DelIds(tableName, genSqlKey(newsql, args)) }*/ @@ -103,36 +102,36 @@ func (session *Session) cacheUpdate(sqlStr string, args ...interface{}) error { colName := sps2[len(sps2)-1] if strings.Contains(colName, "`") { colName = strings.TrimSpace(strings.Replace(colName, "`", "", -1)) - } else if strings.Contains(colName, session.Engine.QuoteStr()) { - colName = strings.TrimSpace(strings.Replace(colName, session.Engine.QuoteStr(), "", -1)) + } else if strings.Contains(colName, session.engine.QuoteStr()) { + colName = strings.TrimSpace(strings.Replace(colName, session.engine.QuoteStr(), "", -1)) } else { - session.Engine.logger.Debug("[cacheUpdate] cannot find column", tableName, colName) + session.engine.logger.Debug("[cacheUpdate] cannot find column", tableName, colName) return ErrCacheFailed } if col := table.GetColumn(colName); col != nil { fieldValue, err := col.ValueOf(bean) if err != nil { - session.Engine.logger.Error(err) + session.engine.logger.Error(err) } else { - session.Engine.logger.Debug("[cacheUpdate] set bean field", bean, colName, fieldValue.Interface()) - if col.IsVersion && session.Statement.checkVersion { + session.engine.logger.Debug("[cacheUpdate] set bean field", bean, colName, fieldValue.Interface()) + if col.IsVersion && session.statement.checkVersion { fieldValue.SetInt(fieldValue.Int() + 1) } else { fieldValue.Set(reflect.ValueOf(args[idx])) } } } else { - session.Engine.logger.Errorf("[cacheUpdate] ERROR: column %v is not table %v's", + session.engine.logger.Errorf("[cacheUpdate] ERROR: column %v is not table %v's", colName, table.Name) } } - session.Engine.logger.Debug("[cacheUpdate] update cache", tableName, id, bean) + session.engine.logger.Debug("[cacheUpdate] update cache", tableName, id, bean) cacher.PutBean(tableName, sid, bean) } } - session.Engine.logger.Debug("[cacheUpdate] clear cached table sql:", tableName) + session.engine.logger.Debug("[cacheUpdate] clear cached table sql:", tableName) cacher.ClearIds(tableName) return nil } @@ -144,8 +143,7 @@ func (session *Session) cacheUpdate(sqlStr string, args ...interface{}) error { // You should call UseBool if you have bool to use. // 2.float32 & float64 may be not inexact as conditions func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int64, error) { - defer session.resetStatement() - if session.IsAutoClose { + if session.isAutoClose { defer session.Close() } @@ -169,21 +167,21 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 var isMap = t.Kind() == reflect.Map var isStruct = t.Kind() == reflect.Struct if isStruct { - if err := session.Statement.setRefValue(v); err != nil { + if err := session.statement.setRefValue(v); err != nil { return 0, err } - if len(session.Statement.TableName()) <= 0 { + if len(session.statement.TableName()) <= 0 { return 0, ErrTableNotFound } - if session.Statement.ColumnStr == "" { - colNames, args = buildUpdates(session.Engine, session.Statement.RefTable, bean, false, false, - false, false, session.Statement.allUseBool, session.Statement.useAllCols, - session.Statement.mustColumnMap, session.Statement.nullableMap, - session.Statement.columnMap, true, session.Statement.unscoped) + if session.statement.ColumnStr == "" { + colNames, args = buildUpdates(session.engine, session.statement.RefTable, bean, false, false, + false, false, session.statement.allUseBool, session.statement.useAllCols, + session.statement.mustColumnMap, session.statement.nullableMap, + session.statement.columnMap, true, session.statement.unscoped) } else { - colNames, args, err = genCols(session.Statement.RefTable, session, bean, true, true) + colNames, args, err = genCols(session.statement.RefTable, session, bean, true, true) if err != nil { return 0, err } @@ -194,68 +192,71 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 bValue := reflect.Indirect(reflect.ValueOf(bean)) for _, v := range bValue.MapKeys() { - colNames = append(colNames, session.Engine.Quote(v.String())+" = ?") + colNames = append(colNames, session.engine.Quote(v.String())+" = ?") args = append(args, bValue.MapIndex(v).Interface()) } } else { return 0, ErrParamsType } - table := session.Statement.RefTable + table := session.statement.RefTable - if session.Statement.UseAutoTime && table != nil && table.Updated != "" { - colNames = append(colNames, session.Engine.Quote(table.Updated)+" = ?") - col := table.UpdatedColumn() - val, t := session.Engine.NowTime2(col.SQLType.Name) - args = append(args, val) + if session.statement.UseAutoTime && table != nil && table.Updated != "" { + if _, ok := session.statement.columnMap[strings.ToLower(table.Updated)]; !ok { + colNames = append(colNames, session.engine.Quote(table.Updated)+" = ?") + col := table.UpdatedColumn() + val, t := session.engine.NowTime2(col.SQLType.Name) + args = append(args, val) - var colName = col.Name - if isStruct { - session.afterClosures = append(session.afterClosures, func(bean interface{}) { - col := table.GetColumn(colName) - setColumnTime(bean, col, t) - }) + var colName = col.Name + if isStruct { + session.afterClosures = append(session.afterClosures, func(bean interface{}) { + col := table.GetColumn(colName) + setColumnTime(bean, col, t) + }) + } } } //for update action to like "column = column + ?" - incColumns := session.Statement.getInc() + incColumns := session.statement.getInc() for _, v := range incColumns { - colNames = append(colNames, session.Engine.Quote(v.colName)+" = "+session.Engine.Quote(v.colName)+" + ?") + colNames = append(colNames, session.engine.Quote(v.colName)+" = "+session.engine.Quote(v.colName)+" + ?") args = append(args, v.arg) } //for update action to like "column = column - ?" - decColumns := session.Statement.getDec() + decColumns := session.statement.getDec() for _, v := range decColumns { - colNames = append(colNames, session.Engine.Quote(v.colName)+" = "+session.Engine.Quote(v.colName)+" - ?") + colNames = append(colNames, session.engine.Quote(v.colName)+" = "+session.engine.Quote(v.colName)+" - ?") args = append(args, v.arg) } //for update action to like "column = expression" - exprColumns := session.Statement.getExpr() + exprColumns := session.statement.getExpr() for _, v := range exprColumns { - colNames = append(colNames, session.Engine.Quote(v.colName)+" = "+v.expr) + colNames = append(colNames, session.engine.Quote(v.colName)+" = "+v.expr) } - session.Statement.processIDParam() + if err = session.statement.processIDParam(); err != nil { + return 0, err + } var autoCond builder.Cond - if !session.Statement.noAutoCondition && len(condiBean) > 0 { + if !session.statement.noAutoCondition && len(condiBean) > 0 { var err error - autoCond, err = session.Statement.buildConds(session.Statement.RefTable, condiBean[0], true, true, false, true, false) + autoCond, err = session.statement.buildConds(session.statement.RefTable, condiBean[0], true, true, false, true, false) if err != nil { return 0, err } } - st := session.Statement - defer session.resetStatement() + st := &session.statement var sqlStr string var condArgs []interface{} var condSQL string - cond := session.Statement.cond.And(autoCond) + cond := session.statement.cond.And(autoCond) - var doIncVer = (table != nil && table.Version != "" && session.Statement.checkVersion) + var doIncVer = (table != nil && table.Version != "" && session.statement.checkVersion) var verValue *reflect.Value if doIncVer { verValue, err = table.VersionColumn().ValueOf(bean) @@ -263,11 +264,15 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 return 0, err } - cond = cond.And(builder.Eq{session.Engine.Quote(table.Version): verValue.Interface()}) - colNames = append(colNames, session.Engine.Quote(table.Version)+" = "+session.Engine.Quote(table.Version)+" + 1") + cond = cond.And(builder.Eq{session.engine.Quote(table.Version): verValue.Interface()}) + colNames = append(colNames, session.engine.Quote(table.Version)+" = "+session.engine.Quote(table.Version)+" + 1") + } + + condSQL, condArgs, err = builder.ToSQL(cond) + if err != nil { + return 0, err } - condSQL, condArgs, _ = builder.ToSQL(cond) if len(condSQL) > 0 { condSQL = "WHERE " + condSQL } @@ -276,6 +281,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 condSQL = condSQL + fmt.Sprintf(" ORDER BY %v", st.OrderStr) } + var tableName = session.statement.TableName() // TODO: Oracle support needed var top string if st.LimitN > 0 { @@ -284,27 +290,53 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } else if st.Engine.dialect.DBType() == core.SQLITE { tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN) cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)", - session.Engine.Quote(session.Statement.TableName()), tempCondSQL), condArgs...)) - condSQL, condArgs, _ = builder.ToSQL(cond) + session.engine.Quote(tableName), tempCondSQL), condArgs...)) + condSQL, condArgs, err = builder.ToSQL(cond) + if err != nil { + return 0, err + } if len(condSQL) > 0 { condSQL = "WHERE " + condSQL } } else if st.Engine.dialect.DBType() == core.POSTGRES { tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN) cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)", - session.Engine.Quote(session.Statement.TableName()), tempCondSQL), condArgs...)) - condSQL, condArgs, _ = builder.ToSQL(cond) + session.engine.Quote(tableName), tempCondSQL), condArgs...)) + condSQL, condArgs, err = builder.ToSQL(cond) + if err != nil { + return 0, err + } + if len(condSQL) > 0 { condSQL = "WHERE " + condSQL } } else if st.Engine.dialect.DBType() == core.MSSQL { - top = fmt.Sprintf("top (%d) ", st.LimitN) + if st.OrderStr != "" && st.Engine.dialect.DBType() == core.MSSQL && + table != nil && len(table.PrimaryKeys) == 1 { + cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)", + table.PrimaryKeys[0], st.LimitN, table.PrimaryKeys[0], + session.engine.Quote(tableName), condSQL), condArgs...) + + condSQL, condArgs, err = builder.ToSQL(cond) + if err != nil { + return 0, err + } + if len(condSQL) > 0 { + condSQL = "WHERE " + condSQL + } + } else { + top = fmt.Sprintf("TOP (%d) ", st.LimitN) + } } } + if len(colNames) <= 0 { + return 0, errors.New("No content found to be updated") + } + sqlStr = fmt.Sprintf("UPDATE %v%v SET %v %v", top, - session.Engine.Quote(session.Statement.TableName()), + session.engine.Quote(tableName), strings.Join(colNames, ", "), condSQL) @@ -318,19 +350,20 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } if table != nil { - if cacher := session.Engine.getCacher2(table); cacher != nil && session.Statement.UseCache { - cacher.ClearIds(session.Statement.TableName()) - cacher.ClearBeans(session.Statement.TableName()) + if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache { + //session.cacheUpdate(table, tableName, sqlStr, args...) + cacher.ClearIds(tableName) + cacher.ClearBeans(tableName) } } // handle after update processors - if session.IsAutoCommit { + if session.isAutoCommit { for _, closure := range session.afterClosures { closure(bean) } if processor, ok := interface{}(bean).(AfterUpdateProcessor); ok { - session.Engine.logger.Debug("[event]", session.Statement.TableName(), " has after update processor") + session.engine.logger.Debug("[event]", tableName, " has after update processor") processor.AfterUpdate() } } else { diff --git a/session_update_test.go b/session_update_test.go index 9eeb6186..690bd106 100644 --- a/session_update_test.go +++ b/session_update_test.go @@ -5,8 +5,13 @@ package xorm import ( + "errors" + "fmt" + "sync" "testing" + "time" + "github.com/go-xorm/core" "github.com/stretchr/testify/assert" ) @@ -38,14 +43,14 @@ func TestUpdateMap(t *testing.T) { func TestUpdateLimit(t *testing.T) { assert.NoError(t, prepareEngine()) - type UpdateTable struct { + type UpdateTable2 struct { Id int64 Name string Age int } - assert.NoError(t, testEngine.Sync2(new(UpdateTable))) - var tb = UpdateTable{ + assert.NoError(t, testEngine.Sync2(new(UpdateTable2))) + var tb = UpdateTable2{ Name: "test1", Age: 35, } @@ -59,16 +64,1154 @@ func TestUpdateLimit(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 1, cnt) - cnt, err = testEngine.OrderBy("name desc").Limit(1).Update(&UpdateTable{ + cnt, err = testEngine.OrderBy("name desc").Limit(1).Update(&UpdateTable2{ Age: 30, }) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) - var uts []UpdateTable + var uts []UpdateTable2 err = testEngine.Find(&uts) assert.NoError(t, err) assert.EqualValues(t, 2, len(uts)) assert.EqualValues(t, 35, uts[0].Age) assert.EqualValues(t, 30, uts[1].Age) } + +type ForUpdate struct { + Id int64 `xorm:"pk"` + Name string +} + +func setupForUpdate(engine *Engine) error { + v := new(ForUpdate) + err := testEngine.DropTables(v) + if err != nil { + return err + } + err = testEngine.CreateTables(v) + if err != nil { + return err + } + + list := []ForUpdate{ + {1, "data1"}, + {2, "data2"}, + {3, "data3"}, + } + + for _, f := range list { + _, err = testEngine.Insert(f) + if err != nil { + return err + } + } + return nil +} + +func TestForUpdate(t *testing.T) { + if testEngine.DriverName() != "mysql" && testEngine.DriverName() != "mymysql" { + return + } + + err := setupForUpdate(testEngine) + if err != nil { + t.Error(err) + return + } + + session1 := testEngine.NewSession() + session2 := testEngine.NewSession() + session3 := testEngine.NewSession() + defer session1.Close() + defer session2.Close() + defer session3.Close() + + // start transaction + err = session1.Begin() + if err != nil { + t.Error(err) + return + } + + // use lock + fList := make([]ForUpdate, 0) + session1.ForUpdate() + session1.Where("(id) = ?", 1) + err = session1.Find(&fList) + switch { + case err != nil: + t.Error(err) + return + case len(fList) != 1: + t.Errorf("find not returned single row") + return + case fList[0].Name != "data1": + t.Errorf("for_update.name must be `data1`") + return + } + + // wait for lock + wg := &sync.WaitGroup{} + + // lock is used + wg.Add(1) + go func() { + f2 := new(ForUpdate) + session2.Where("(id) = ?", 1).ForUpdate() + has, err := session2.Get(f2) // wait release lock + switch { + case err != nil: + t.Error(err) + case !has: + t.Errorf("cannot find target row. for_update.id = 1") + case f2.Name != "updated by session1": + t.Errorf("read lock failed") + } + wg.Done() + }() + + // lock is NOT used + wg.Add(1) + go func() { + f3 := new(ForUpdate) + session3.Where("(id) = ?", 1) + has, err := session3.Get(f3) // wait release lock + switch { + case err != nil: + t.Error(err) + case !has: + t.Errorf("cannot find target row. for_update.id = 1") + case f3.Name != "data1": + t.Errorf("read lock failed") + } + wg.Done() + }() + + // wait for go rountines + time.Sleep(50 * time.Millisecond) + + f := new(ForUpdate) + f.Name = "updated by session1" + session1.Where("(id) = ?", 1) + session1.Update(f) + + // release lock + err = session1.Commit() + if err != nil { + t.Error(err) + return + } + + wg.Wait() +} + +func TestWithIn(t *testing.T) { + type temp3 struct { + Id int64 `xorm:"Id pk autoincr"` + Name string `xorm:"Name"` + Test bool `xorm:"Test"` + } + + assert.NoError(t, prepareEngine()) + assert.NoError(t, testEngine.Sync(new(temp3))) + + testEngine.Insert(&[]temp3{ + { + Name: "user1", + }, + { + Name: "user1", + }, + { + Name: "user1", + }, + }) + + cnt, err := testEngine.In("Id", 1, 2, 3, 4).Update(&temp3{Name: "aa"}, &temp3{Name: "user1"}) + assert.NoError(t, err) + assert.EqualValues(t, 3, cnt) +} + +type Condi map[string]interface{} + +type UpdateAllCols struct { + Id int64 + Bool bool + String string + Ptr *string +} + +type UpdateMustCols struct { + Id int64 + Bool bool + String string +} + +type UpdateIncr struct { + Id int64 + Cnt int + Name string +} + +type Article struct { + Id int32 `xorm:"pk INT autoincr"` + Name string `xorm:"VARCHAR(45)"` + Img string `xorm:"VARCHAR(100)"` + Aside string `xorm:"VARCHAR(200)"` + Desc string `xorm:"VARCHAR(200)"` + Content string `xorm:"TEXT"` + Status int8 `xorm:"TINYINT(4)"` +} + +func TestUpdateMap2(t *testing.T) { + assert.NoError(t, prepareEngine()) + assertSync(t, new(UpdateMustCols)) + + _, err := testEngine.Table("update_must_cols").Where("id =?", 1).Update(map[string]interface{}{ + "bool": true, + }) + if err != nil { + t.Error(err) + panic(err) + } +} + +func TestUpdate1(t *testing.T) { + assert.NoError(t, prepareEngine()) + assertSync(t, new(Userinfo)) + + _, err := testEngine.Insert(&Userinfo{ + Username: "user1", + }) + + var ori Userinfo + has, err := testEngine.Get(&ori) + if err != nil { + t.Error(err) + panic(err) + } + if !has { + t.Error(errors.New("not exist")) + panic(errors.New("not exist")) + } + + // update by id + user := Userinfo{Username: "xxx", Height: 1.2} + cnt, err := testEngine.ID(ori.Uid).Update(&user) + if err != nil { + t.Error(err) + panic(err) + } + if cnt != 1 { + err = errors.New("update not returned 1") + t.Error(err) + panic(err) + return + } + + condi := Condi{"username": "zzz", "departname": ""} + cnt, err = testEngine.Table(&user).ID(ori.Uid).Update(&condi) + if err != nil { + t.Error(err) + panic(err) + } + if cnt != 1 { + err = errors.New("update not returned 1") + t.Error(err) + panic(err) + return + } + + cnt, err = testEngine.Update(&Userinfo{Username: "yyy"}, &user) + if err != nil { + t.Error(err) + panic(err) + } + total, err := testEngine.Count(&user) + if err != nil { + t.Error(err) + panic(err) + } + + if cnt != total { + err = errors.New("insert not returned 1") + t.Error(err) + panic(err) + return + } + + // nullable update + { + user := &Userinfo{Username: "not null data", Height: 180.5} + _, err := testEngine.Insert(user) + if err != nil { + t.Error(err) + panic(err) + } + userID := user.Uid + + has, err := testEngine.ID(userID). + And("username = ?", user.Username). + And("height = ?", user.Height). + And("departname = ?", ""). + And("detail_id = ?", 0). + And("is_man = ?", 0). + Get(&Userinfo{}) + if err != nil { + t.Error(err) + panic(err) + } + if !has { + err = errors.New("cannot insert properly") + t.Error(err) + panic(err) + } + + updatedUser := &Userinfo{Username: "null data"} + cnt, err = testEngine.ID(userID). + Nullable("height", "departname", "is_man", "created"). + Update(updatedUser) + if err != nil { + t.Error(err) + panic(err) + } + if cnt != 1 { + err = errors.New("update not returned 1") + t.Error(err) + panic(err) + } + + has, err = testEngine.ID(userID). + And("username = ?", updatedUser.Username). + And("height IS NULL"). + And("departname IS NULL"). + And("is_man IS NULL"). + And("created IS NULL"). + And("detail_id = ?", 0). + Get(&Userinfo{}) + if err != nil { + t.Error(err) + panic(err) + } + if !has { + err = errors.New("cannot update with null properly") + t.Error(err) + panic(err) + } + + cnt, err = testEngine.ID(userID).Delete(&Userinfo{}) + if err != nil { + t.Error(err) + panic(err) + } + if cnt != 1 { + err = errors.New("delete not returned 1") + t.Error(err) + panic(err) + } + } + + err = testEngine.StoreEngine("Innodb").Sync2(&Article{}) + if err != nil { + t.Error(err) + panic(err) + } + + defer func() { + err = testEngine.DropTables(&Article{}) + if err != nil { + t.Error(err) + panic(err) + } + }() + + a := &Article{0, "1", "2", "3", "4", "5", 2} + cnt, err = testEngine.Insert(a) + if err != nil { + t.Error(err) + panic(err) + } + + if cnt != 1 { + err = errors.New(fmt.Sprintf("insert not returned 1 but %d", cnt)) + t.Error(err) + panic(err) + } + + if a.Id == 0 { + err = errors.New("insert returned id is 0") + t.Error(err) + panic(err) + } + + cnt, err = testEngine.ID(a.Id).Update(&Article{Name: "6"}) + if err != nil { + t.Error(err) + panic(err) + } + + if cnt != 1 { + err = errors.New(fmt.Sprintf("insert not returned 1 but %d", cnt)) + t.Error(err) + panic(err) + return + } + + var s = "test" + + col1 := &UpdateAllCols{Ptr: &s} + err = testEngine.Sync(col1) + if err != nil { + t.Error(err) + panic(err) + } + + _, err = testEngine.Insert(col1) + if err != nil { + t.Error(err) + panic(err) + } + + col2 := &UpdateAllCols{col1.Id, true, "", nil} + _, err = testEngine.ID(col2.Id).AllCols().Update(col2) + if err != nil { + t.Error(err) + panic(err) + } + + col3 := &UpdateAllCols{} + has, err = testEngine.ID(col2.Id).Get(col3) + if err != nil { + t.Error(err) + panic(err) + } + + if !has { + err = errors.New(fmt.Sprintf("cannot get id %d", col2.Id)) + t.Error(err) + panic(err) + return + } + + if *col2 != *col3 { + err = errors.New(fmt.Sprintf("col2 should eq col3")) + t.Error(err) + panic(err) + return + } + + { + + col1 := &UpdateMustCols{} + err = testEngine.Sync(col1) + if err != nil { + t.Error(err) + panic(err) + } + + _, err = testEngine.Insert(col1) + if err != nil { + t.Error(err) + panic(err) + } + + col2 := &UpdateMustCols{col1.Id, true, ""} + boolStr := testEngine.ColumnMapper.Obj2Table("Bool") + stringStr := testEngine.ColumnMapper.Obj2Table("String") + _, err = testEngine.ID(col2.Id).MustCols(boolStr, stringStr).Update(col2) + if err != nil { + t.Error(err) + panic(err) + } + + col3 := &UpdateMustCols{} + has, err := testEngine.ID(col2.Id).Get(col3) + if err != nil { + t.Error(err) + panic(err) + } + + if !has { + err = errors.New(fmt.Sprintf("cannot get id %d", col2.Id)) + t.Error(err) + panic(err) + return + } + + if *col2 != *col3 { + err = errors.New(fmt.Sprintf("col2 should eq col3")) + t.Error(err) + panic(err) + return + } + } +} + +func TestUpdateIncrDecr(t *testing.T) { + assert.NoError(t, prepareEngine()) + + col1 := &UpdateIncr{ + Name: "test", + } + assert.NoError(t, testEngine.Sync(col1)) + + _, err := testEngine.Insert(col1) + assert.NoError(t, err) + + colName := testEngine.ColumnMapper.Obj2Table("Cnt") + + cnt, err := testEngine.ID(col1.Id).Incr(colName).Update(col1) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + newCol := new(UpdateIncr) + has, err := testEngine.ID(col1.Id).Get(newCol) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, 1, newCol.Cnt) + + cnt, err = testEngine.ID(col1.Id).Decr(colName).Update(col1) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + newCol = new(UpdateIncr) + has, err = testEngine.ID(col1.Id).Get(newCol) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, 0, newCol.Cnt) + + cnt, err = testEngine.ID(col1.Id).Cols(colName).Incr(colName).Update(col1) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) +} + +type UpdatedUpdate struct { + Id int64 + Updated time.Time `xorm:"updated"` +} + +type UpdatedUpdate2 struct { + Id int64 + Updated int64 `xorm:"updated"` +} + +type UpdatedUpdate3 struct { + Id int64 + Updated int `xorm:"updated bigint"` +} + +type UpdatedUpdate4 struct { + Id int64 + Updated int `xorm:"updated"` +} + +type UpdatedUpdate5 struct { + Id int64 + Updated time.Time `xorm:"updated bigint"` +} + +func TestUpdateUpdated(t *testing.T) { + assert.NoError(t, prepareEngine()) + + di := new(UpdatedUpdate) + err := testEngine.Sync2(di) + if err != nil { + t.Fatal(err) + } + + _, err = testEngine.Insert(&UpdatedUpdate{}) + if err != nil { + t.Fatal(err) + } + + ci := &UpdatedUpdate{} + _, err = testEngine.ID(1).Update(ci) + if err != nil { + t.Fatal(err) + } + + has, err := testEngine.ID(1).Get(di) + if err != nil { + t.Fatal(err) + } + if !has { + t.Fatal(ErrNotExist) + } + if ci.Updated.Unix() != di.Updated.Unix() { + t.Fatal("should equal:", ci, di) + } + fmt.Println("ci:", ci, "di:", di) + + di2 := new(UpdatedUpdate2) + err = testEngine.Sync2(di2) + assert.NoError(t, err) + + now := time.Now() + var di20 UpdatedUpdate2 + cnt, err := testEngine.Insert(&di20) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + assert.True(t, now.Unix() <= di20.Updated) + + var di21 UpdatedUpdate2 + has, err = testEngine.ID(di20.Id).Get(&di21) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, di20.Updated, di21.Updated) + + ci2 := &UpdatedUpdate2{} + _, err = testEngine.ID(1).Update(ci2) + assert.NoError(t, err) + + has, err = testEngine.ID(1).Get(di2) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, ci2.Updated, di2.Updated) + assert.True(t, ci2.Updated >= di21.Updated) + + di3 := new(UpdatedUpdate3) + err = testEngine.Sync2(di3) + if err != nil { + t.Fatal(err) + } + + _, err = testEngine.Insert(&UpdatedUpdate3{}) + if err != nil { + t.Fatal(err) + } + ci3 := &UpdatedUpdate3{} + _, err = testEngine.ID(1).Update(ci3) + if err != nil { + t.Fatal(err) + } + + has, err = testEngine.ID(1).Get(di3) + if err != nil { + t.Fatal(err) + } + if !has { + t.Fatal(ErrNotExist) + } + if ci3.Updated != di3.Updated { + t.Fatal("should equal:", ci3, di3) + } + fmt.Println("ci3:", ci3, "di3:", di3) + + di4 := new(UpdatedUpdate4) + err = testEngine.Sync2(di4) + if err != nil { + t.Fatal(err) + } + + _, err = testEngine.Insert(&UpdatedUpdate4{}) + if err != nil { + t.Fatal(err) + } + + ci4 := &UpdatedUpdate4{} + _, err = testEngine.ID(1).Update(ci4) + if err != nil { + t.Fatal(err) + } + + has, err = testEngine.ID(1).Get(di4) + if err != nil { + t.Fatal(err) + } + if !has { + t.Fatal(ErrNotExist) + } + if ci4.Updated != di4.Updated { + t.Fatal("should equal:", ci4, di4) + } + fmt.Println("ci4:", ci4, "di4:", di4) + + di5 := new(UpdatedUpdate5) + err = testEngine.Sync2(di5) + if err != nil { + t.Fatal(err) + } + + _, err = testEngine.Insert(&UpdatedUpdate5{}) + if err != nil { + t.Fatal(err) + } + ci5 := &UpdatedUpdate5{} + _, err = testEngine.ID(1).Update(ci5) + if err != nil { + t.Fatal(err) + } + + has, err = testEngine.ID(1).Get(di5) + if err != nil { + t.Fatal(err) + } + if !has { + t.Fatal(ErrNotExist) + } + if ci5.Updated.Unix() != di5.Updated.Unix() { + t.Fatal("should equal:", ci5, di5) + } + fmt.Println("ci5:", ci5, "di5:", di5) +} + +func TestUpdateSameMapper(t *testing.T) { + assert.NoError(t, prepareEngine()) + + oldMapper := testEngine.ColumnMapper + testEngine.unMapType(rValue(new(Userinfo)).Type()) + testEngine.unMapType(rValue(new(Condi)).Type()) + testEngine.unMapType(rValue(new(Article)).Type()) + testEngine.unMapType(rValue(new(UpdateAllCols)).Type()) + testEngine.unMapType(rValue(new(UpdateMustCols)).Type()) + testEngine.unMapType(rValue(new(UpdateIncr)).Type()) + testEngine.SetMapper(core.SameMapper{}) + defer func() { + testEngine.unMapType(rValue(new(Userinfo)).Type()) + testEngine.unMapType(rValue(new(Condi)).Type()) + testEngine.unMapType(rValue(new(Article)).Type()) + testEngine.unMapType(rValue(new(UpdateAllCols)).Type()) + testEngine.unMapType(rValue(new(UpdateMustCols)).Type()) + testEngine.unMapType(rValue(new(UpdateIncr)).Type()) + testEngine.SetMapper(oldMapper) + }() + + assertSync(t, new(Userinfo)) + + _, err := testEngine.Insert(&Userinfo{ + Username: "user1", + }) + assert.NoError(t, err) + + var ori Userinfo + has, err := testEngine.Get(&ori) + if err != nil { + t.Error(err) + panic(err) + } + if !has { + t.Error(errors.New("not exist")) + panic(errors.New("not exist")) + } + // update by id + user := Userinfo{Username: "xxx", Height: 1.2} + cnt, err := testEngine.ID(ori.Uid).Update(&user) + if err != nil { + t.Error(err) + panic(err) + } + if cnt != 1 { + err = errors.New("update not returned 1") + t.Error(err) + panic(err) + return + } + + condi := Condi{"Username": "zzz", "Departname": ""} + cnt, err = testEngine.Table(&user).ID(ori.Uid).Update(&condi) + if err != nil { + t.Error(err) + panic(err) + } + + if cnt != 1 { + err = errors.New("update not returned 1") + t.Error(err) + panic(err) + return + } + + cnt, err = testEngine.Update(&Userinfo{Username: "yyy"}, &user) + if err != nil { + t.Error(err) + panic(err) + } + + total, err := testEngine.Count(&user) + if err != nil { + t.Error(err) + panic(err) + } + + if cnt != total { + err = errors.New("insert not returned 1") + t.Error(err) + panic(err) + return + } + + err = testEngine.Sync(&Article{}) + if err != nil { + t.Error(err) + panic(err) + } + + defer func() { + err = testEngine.DropTables(&Article{}) + if err != nil { + t.Error(err) + panic(err) + } + }() + + a := &Article{0, "1", "2", "3", "4", "5", 2} + cnt, err = testEngine.Insert(a) + if err != nil { + t.Error(err) + panic(err) + } + + if cnt != 1 { + err = errors.New(fmt.Sprintf("insert not returned 1 but %d", cnt)) + t.Error(err) + panic(err) + } + + if a.Id == 0 { + err = errors.New("insert returned id is 0") + t.Error(err) + panic(err) + } + + cnt, err = testEngine.ID(a.Id).Update(&Article{Name: "6"}) + if err != nil { + t.Error(err) + panic(err) + } + + if cnt != 1 { + err = errors.New(fmt.Sprintf("insert not returned 1 but %d", cnt)) + t.Error(err) + panic(err) + return + } + + col1 := &UpdateAllCols{} + err = testEngine.Sync(col1) + if err != nil { + t.Error(err) + panic(err) + } + + _, err = testEngine.Insert(col1) + if err != nil { + t.Error(err) + panic(err) + } + + col2 := &UpdateAllCols{col1.Id, true, "", nil} + _, err = testEngine.ID(col2.Id).AllCols().Update(col2) + if err != nil { + t.Error(err) + panic(err) + } + + col3 := &UpdateAllCols{} + has, err = testEngine.ID(col2.Id).Get(col3) + if err != nil { + t.Error(err) + panic(err) + } + + if !has { + err = errors.New(fmt.Sprintf("cannot get id %d", col2.Id)) + t.Error(err) + panic(err) + return + } + + if *col2 != *col3 { + err = errors.New(fmt.Sprintf("col2 should eq col3")) + t.Error(err) + panic(err) + return + } + + { + col1 := &UpdateMustCols{} + err = testEngine.Sync(col1) + if err != nil { + t.Error(err) + panic(err) + } + + _, err = testEngine.Insert(col1) + if err != nil { + t.Error(err) + panic(err) + } + + col2 := &UpdateMustCols{col1.Id, true, ""} + boolStr := testEngine.ColumnMapper.Obj2Table("Bool") + stringStr := testEngine.ColumnMapper.Obj2Table("String") + _, err = testEngine.ID(col2.Id).MustCols(boolStr, stringStr).Update(col2) + if err != nil { + t.Error(err) + panic(err) + } + + col3 := &UpdateMustCols{} + has, err := testEngine.ID(col2.Id).Get(col3) + if err != nil { + t.Error(err) + panic(err) + } + + if !has { + err = errors.New(fmt.Sprintf("cannot get id %d", col2.Id)) + t.Error(err) + panic(err) + return + } + + if *col2 != *col3 { + err = errors.New(fmt.Sprintf("col2 should eq col3")) + t.Error(err) + panic(err) + return + } + } + + { + + col1 := &UpdateIncr{} + err = testEngine.Sync(col1) + if err != nil { + t.Error(err) + panic(err) + } + + _, err = testEngine.Insert(col1) + if err != nil { + t.Error(err) + panic(err) + } + + cnt, err := testEngine.ID(col1.Id).Incr("`Cnt`").Update(col1) + if err != nil { + t.Error(err) + panic(err) + } + if cnt != 1 { + err = errors.New("update incr failed") + t.Error(err) + panic(err) + } + + newCol := new(UpdateIncr) + has, err := testEngine.ID(col1.Id).Get(newCol) + if err != nil { + t.Error(err) + panic(err) + } + if !has { + err = errors.New("has incr failed") + t.Error(err) + panic(err) + } + if 1 != newCol.Cnt { + err = errors.New("incr failed") + t.Error(err) + panic(err) + } + } +} + +func TestUseBool(t *testing.T) { + assert.NoError(t, prepareEngine()) + assertSync(t, new(Userinfo)) + + cnt1, err := testEngine.Count(&Userinfo{}) + if err != nil { + t.Error(err) + panic(err) + } + + users := make([]Userinfo, 0) + err = testEngine.Find(&users) + if err != nil { + t.Error(err) + panic(err) + } + var fNumber int64 + for _, u := range users { + if u.IsMan == false { + fNumber += 1 + } + } + + cnt2, err := testEngine.UseBool().Update(&Userinfo{IsMan: true}) + if err != nil { + t.Error(err) + panic(err) + } + if fNumber != cnt2 { + fmt.Println("cnt1", cnt1, "fNumber", fNumber, "cnt2", cnt2) + /*err = errors.New("Updated number is not corrected.") + t.Error(err) + panic(err)*/ + } + + _, err = testEngine.Update(&Userinfo{IsMan: true}) + if err == nil { + err = errors.New("error condition") + t.Error(err) + panic(err) + } +} + +func TestBool(t *testing.T) { + assert.NoError(t, prepareEngine()) + assertSync(t, new(Userinfo)) + + _, err := testEngine.UseBool().Update(&Userinfo{IsMan: true}) + if err != nil { + t.Error(err) + panic(err) + } + users := make([]Userinfo, 0) + err = testEngine.Find(&users) + if err != nil { + t.Error(err) + panic(err) + } + for _, user := range users { + if !user.IsMan { + err = errors.New("update bool or find bool error") + t.Error(err) + panic(err) + } + } + + _, err = testEngine.UseBool().Update(&Userinfo{IsMan: false}) + if err != nil { + t.Error(err) + panic(err) + } + users = make([]Userinfo, 0) + err = testEngine.Find(&users) + if err != nil { + t.Error(err) + panic(err) + } + for _, user := range users { + if user.IsMan { + err = errors.New("update bool or find bool error") + t.Error(err) + panic(err) + } + } +} + +func TestNoUpdate(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type NoUpdate struct { + Id int64 + Content string + } + + assertSync(t, new(NoUpdate)) + + cnt, err := testEngine.Insert(&NoUpdate{ + Content: "test", + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + _, err = testEngine.ID(1).Update(&NoUpdate{}) + assert.Error(t, err) + assert.EqualValues(t, "No content found to be updated", err.Error()) +} + +func TestNewUpdate(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type TbUserInfo struct { + Id int64 `xorm:"pk autoincr unique BIGINT" json:"id"` + Phone string `xorm:"not null unique VARCHAR(20)" json:"phone"` + UserName string `xorm:"VARCHAR(20)" json:"user_name"` + Gender int `xorm:"default 0 INTEGER" json:"gender"` + Pw string `xorm:"VARCHAR(100)" json:"pw"` + Token string `xorm:"TEXT" json:"token"` + Avatar string `xorm:"TEXT" json:"avatar"` + Extras interface{} `xorm:"JSON" json:"extras"` + Created time.Time `xorm:"DATETIME created"` + Updated time.Time `xorm:"DATETIME updated"` + Deleted time.Time `xorm:"DATETIME deleted"` + } + + assertSync(t, new(TbUserInfo)) + + targetUsr := TbUserInfo{Phone: "13126564922"} + changeUsr := TbUserInfo{Token: "ABCDEFG"} + af, err := testEngine.Update(&changeUsr, &targetUsr) + assert.NoError(t, err) + assert.EqualValues(t, 0, af) + + af, err = testEngine.Table(new(TbUserInfo)).Where("phone=?", 13126564922).Update(&changeUsr) + assert.NoError(t, err) + assert.EqualValues(t, 0, af) +} + +func TestUpdateUpdate(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type PublicKeyUpdate struct { + Id int64 + UpdatedUnix int64 `xorm:"updated"` + } + + assertSync(t, new(PublicKeyUpdate)) + + cnt, err := testEngine.ID(1).Cols("updated_unix").Update(&PublicKeyUpdate{ + UpdatedUnix: time.Now().Unix(), + }) + assert.NoError(t, err) + assert.EqualValues(t, 0, cnt) +} + +func TestCreatedUpdated2(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type CreatedUpdatedStruct struct { + Id int64 + Name string + CreateAt time.Time `xorm:"created" json:"create_at"` + UpdateAt time.Time `xorm:"updated" json:"update_at"` + } + + assertSync(t, new(CreatedUpdatedStruct)) + + var s = CreatedUpdatedStruct{ + Name: "test", + } + cnt, err := testEngine.Insert(&s) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + assert.EqualValues(t, s.UpdateAt.Unix(), s.CreateAt.Unix()) + + time.Sleep(time.Second) + + var s1 = CreatedUpdatedStruct{ + Name: "test1", + CreateAt: s.CreateAt, + UpdateAt: s.UpdateAt, + } + + cnt, err = testEngine.ID(1).Update(&s1) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + assert.EqualValues(t, s.CreateAt.Unix(), s1.CreateAt.Unix()) + assert.True(t, s1.UpdateAt.Unix() > s.UpdateAt.Unix()) + + var s2 CreatedUpdatedStruct + has, err := testEngine.ID(1).Get(&s2) + assert.NoError(t, err) + assert.True(t, has) + + assert.EqualValues(t, s.CreateAt.Unix(), s2.CreateAt.Unix()) + assert.True(t, s2.UpdateAt.Unix() > s.UpdateAt.Unix()) + assert.True(t, s2.UpdateAt.Unix() > s2.CreateAt.Unix()) +} diff --git a/statement.go b/statement.go index b6f0baf2..23346c71 100644 --- a/statement.go +++ b/statement.go @@ -73,6 +73,7 @@ type Statement struct { decrColumns map[string]decrParam exprColumns map[string]exprParam cond builder.Cond + bufferSize int } // Init reset all the statement's fields @@ -111,6 +112,7 @@ func (statement *Statement) Init() { statement.decrColumns = make(map[string]decrParam) statement.exprColumns = make(map[string]exprParam) statement.cond = builder.NewCond() + statement.bufferSize = 0 } // NoAutoCondition if you do not want convert bean's field as query condition, then use this function @@ -272,6 +274,9 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{}, fieldValue := *fieldValuePtr fieldType := reflect.TypeOf(fieldValue.Interface()) + if fieldType == nil { + continue + } requiredField := useAllCols includeNil := useAllCols @@ -376,7 +381,7 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{}, if !requiredField && (t.IsZero() || !fieldValue.IsValid()) { continue } - val = engine.FormatTime(col.SQLType.Name, t) + val = engine.formatColTime(col, t) } else if nulType, ok := fieldValue.Interface().(driver.Valuer); ok { val, _ = nulType.Value() } else { @@ -490,224 +495,6 @@ func (statement *Statement) colName(col *core.Column, tableName string) string { return statement.Engine.Quote(col.Name) } -func buildConds(engine *Engine, table *core.Table, bean interface{}, - includeVersion bool, includeUpdated bool, includeNil bool, - includeAutoIncr bool, allUseBool bool, useAllCols bool, unscoped bool, - mustColumnMap map[string]bool, tableName, aliasName string, addedTableName bool) (builder.Cond, error) { - var conds []builder.Cond - for _, col := range table.Columns() { - if !includeVersion && col.IsVersion { - continue - } - if !includeUpdated && col.IsUpdated { - continue - } - if !includeAutoIncr && col.IsAutoIncrement { - continue - } - - if engine.dialect.DBType() == core.MSSQL && (col.SQLType.Name == core.Text || col.SQLType.IsBlob() || col.SQLType.Name == core.TimeStampz) { - continue - } - if col.SQLType.IsJson() { - continue - } - - var colName string - if addedTableName { - var nm = tableName - if len(aliasName) > 0 { - nm = aliasName - } - colName = engine.Quote(nm) + "." + engine.Quote(col.Name) - } else { - colName = engine.Quote(col.Name) - } - - fieldValuePtr, err := col.ValueOf(bean) - if err != nil { - engine.logger.Error(err) - continue - } - - if col.IsDeleted && !unscoped { // tag "deleted" is enabled - if engine.dialect.DBType() == core.MSSQL { - conds = append(conds, builder.IsNull{colName}) - } else { - conds = append(conds, builder.IsNull{colName}.Or(builder.Eq{colName: "0001-01-01 00:00:00"})) - } - } - - fieldValue := *fieldValuePtr - if fieldValue.Interface() == nil { - continue - } - - fieldType := reflect.TypeOf(fieldValue.Interface()) - requiredField := useAllCols - - if b, ok := getFlagForColumn(mustColumnMap, col); ok { - if b { - requiredField = true - } else { - continue - } - } - - if fieldType.Kind() == reflect.Ptr { - if fieldValue.IsNil() { - if includeNil { - conds = append(conds, builder.Eq{colName: nil}) - } - continue - } else if !fieldValue.IsValid() { - continue - } else { - // dereference ptr type to instance type - fieldValue = fieldValue.Elem() - fieldType = reflect.TypeOf(fieldValue.Interface()) - requiredField = true - } - } - - var val interface{} - switch fieldType.Kind() { - case reflect.Bool: - if allUseBool || requiredField { - val = fieldValue.Interface() - } else { - // if a bool in a struct, it will not be as a condition because it default is false, - // please use Where() instead - continue - } - case reflect.String: - if !requiredField && fieldValue.String() == "" { - continue - } - // for MyString, should convert to string or panic - if fieldType.String() != reflect.String.String() { - val = fieldValue.String() - } else { - val = fieldValue.Interface() - } - case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64: - if !requiredField && fieldValue.Int() == 0 { - continue - } - val = fieldValue.Interface() - case reflect.Float32, reflect.Float64: - if !requiredField && fieldValue.Float() == 0.0 { - continue - } - val = fieldValue.Interface() - case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64: - if !requiredField && fieldValue.Uint() == 0 { - continue - } - t := int64(fieldValue.Uint()) - val = reflect.ValueOf(&t).Interface() - case reflect.Struct: - if fieldType.ConvertibleTo(core.TimeType) { - t := fieldValue.Convert(core.TimeType).Interface().(time.Time) - if !requiredField && (t.IsZero() || !fieldValue.IsValid()) { - continue - } - val = engine.FormatTime(col.SQLType.Name, t) - } else if _, ok := reflect.New(fieldType).Interface().(core.Conversion); ok { - continue - } else if valNul, ok := fieldValue.Interface().(driver.Valuer); ok { - val, _ = valNul.Value() - if val == nil { - continue - } - } else { - if col.SQLType.IsJson() { - if col.SQLType.IsText() { - bytes, err := json.Marshal(fieldValue.Interface()) - if err != nil { - engine.logger.Error(err) - continue - } - val = string(bytes) - } else if col.SQLType.IsBlob() { - var bytes []byte - var err error - bytes, err = json.Marshal(fieldValue.Interface()) - if err != nil { - engine.logger.Error(err) - continue - } - val = bytes - } - } else { - engine.autoMapType(fieldValue) - if table, ok := engine.Tables[fieldValue.Type()]; ok { - if len(table.PrimaryKeys) == 1 { - pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName) - // fix non-int pk issues - //if pkField.Int() != 0 { - if pkField.IsValid() && !isZero(pkField.Interface()) { - val = pkField.Interface() - } else { - continue - } - } else { - //TODO: how to handler? - panic(fmt.Sprintln("not supported", fieldValue.Interface(), "as", table.PrimaryKeys)) - } - } else { - val = fieldValue.Interface() - } - } - } - case reflect.Array: - continue - case reflect.Slice, reflect.Map: - if fieldValue == reflect.Zero(fieldType) { - continue - } - if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 { - continue - } - - if col.SQLType.IsText() { - bytes, err := json.Marshal(fieldValue.Interface()) - if err != nil { - engine.logger.Error(err) - continue - } - val = string(bytes) - } else if col.SQLType.IsBlob() { - var bytes []byte - var err error - if (fieldType.Kind() == reflect.Array || fieldType.Kind() == reflect.Slice) && - fieldType.Elem().Kind() == reflect.Uint8 { - if fieldValue.Len() > 0 { - val = fieldValue.Bytes() - } else { - continue - } - } else { - bytes, err = json.Marshal(fieldValue.Interface()) - if err != nil { - engine.logger.Error(err) - continue - } - val = bytes - } - } else { - continue - } - default: - val = fieldValue.Interface() - } - - conds = append(conds, builder.Eq{colName: val}) - } - - return builder.And(conds...), nil -} - // TableName return current tableName func (statement *Statement) TableName() string { if statement.AltTableName != "" { @@ -810,6 +597,22 @@ func (statement *Statement) col2NewColsWithQuote(columns ...string) []string { return newColumns } +func (statement *Statement) colmap2NewColsWithQuote() []string { + newColumns := make([]string, 0, len(statement.columnMap)) + for col := range statement.columnMap { + fields := strings.Split(strings.TrimSpace(col), ".") + if len(fields) == 1 { + newColumns = append(newColumns, statement.Engine.quote(fields[0])) + } else if len(fields) == 2 { + newColumns = append(newColumns, statement.Engine.quote(fields[0])+"."+ + statement.Engine.quote(fields[1])) + } else { + panic(errors.New("unwanted colnames")) + } + } + return newColumns +} + // Distinct generates "DISTINCT col1, col2 " statement func (statement *Statement) Distinct(columns ...string) *Statement { statement.IsDistinct = true @@ -836,7 +639,7 @@ func (statement *Statement) Cols(columns ...string) *Statement { statement.columnMap[strings.ToLower(nc)] = true } - newColumns := statement.col2NewColsWithQuote(columns...) + newColumns := statement.colmap2NewColsWithQuote() statement.ColumnStr = strings.Join(newColumns, ", ") statement.ColumnStr = strings.Replace(statement.ColumnStr, statement.Engine.quote("*"), "*", -1) return statement @@ -1104,26 +907,35 @@ func (statement *Statement) genAddColumnStr(col *core.Column) (string, []interfa } func (statement *Statement) buildConds(table *core.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, addedTableName bool) (builder.Cond, error) { - return buildConds(statement.Engine, table, bean, includeVersion, includeUpdated, includeNil, includeAutoIncr, statement.allUseBool, statement.useAllCols, + return statement.Engine.buildConds(table, bean, includeVersion, includeUpdated, includeNil, includeAutoIncr, statement.allUseBool, statement.useAllCols, statement.unscoped, statement.mustColumnMap, statement.TableName(), statement.TableAlias, addedTableName) } -func (statement *Statement) genConds(bean interface{}) (string, []interface{}, error) { +func (statement *Statement) mergeConds(bean interface{}) error { if !statement.noAutoCondition { var addedTableName = (len(statement.JoinStr) > 0) autoCond, err := statement.buildConds(statement.RefTable, bean, true, true, false, true, addedTableName) if err != nil { - return "", nil, err + return err } statement.cond = statement.cond.And(autoCond) } - statement.processIDParam() + if err := statement.processIDParam(); err != nil { + return err + } + return nil +} + +func (statement *Statement) genConds(bean interface{}) (string, []interface{}, error) { + if err := statement.mergeConds(bean); err != nil { + return "", nil, err + } return builder.ToSQL(statement.cond) } -func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}) { +func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}, error) { v := rValue(bean) isStruct := v.Kind() == reflect.Struct if isStruct { @@ -1156,21 +968,37 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}) columnStr = "*" } - var condSQL string - var condArgs []interface{} if isStruct { - condSQL, condArgs, _ = statement.genConds(bean) - } else { - condSQL, condArgs, _ = builder.ToSQL(statement.cond) + if err := statement.mergeConds(bean); err != nil { + return "", nil, err + } + } + condSQL, condArgs, err := builder.ToSQL(statement.cond) + if err != nil { + return "", nil, err } - return statement.genSelectSQL(columnStr, condSQL), append(statement.joinArgs, condArgs...) + sqlStr, err := statement.genSelectSQL(columnStr, condSQL) + if err != nil { + return "", nil, err + } + + return sqlStr, append(statement.joinArgs, condArgs...), nil } -func (statement *Statement) genCountSQL(bean interface{}) (string, []interface{}) { - statement.setRefValue(rValue(bean)) - - condSQL, condArgs, _ := statement.genConds(bean) +func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interface{}, error) { + var condSQL string + var condArgs []interface{} + var err error + if len(beans) > 0 { + statement.setRefValue(rValue(beans[0])) + condSQL, condArgs, err = statement.genConds(beans[0]) + } else { + condSQL, condArgs, err = builder.ToSQL(statement.cond) + } + if err != nil { + return "", nil, err + } var selectSQL = statement.selectStr if len(selectSQL) <= 0 { @@ -1180,10 +1008,15 @@ func (statement *Statement) genCountSQL(bean interface{}) (string, []interface{} selectSQL = "count(*)" } } - return statement.genSelectSQL(selectSQL, condSQL), append(statement.joinArgs, condArgs...) + sqlStr, err := statement.genSelectSQL(selectSQL, condSQL) + if err != nil { + return "", nil, err + } + + return sqlStr, append(statement.joinArgs, condArgs...), nil } -func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (string, []interface{}) { +func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (string, []interface{}, error) { statement.setRefValue(rValue(bean)) var sumStrs = make([]string, 0, len(columns)) @@ -1195,12 +1028,20 @@ func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (stri } sumSelect := strings.Join(sumStrs, ", ") - condSQL, condArgs, _ := statement.genConds(bean) + condSQL, condArgs, err := statement.genConds(bean) + if err != nil { + return "", nil, err + } - return statement.genSelectSQL(sumSelect, condSQL), append(statement.joinArgs, condArgs...) + sqlStr, err := statement.genSelectSQL(sumSelect, condSQL) + if err != nil { + return "", nil, err + } + + return sqlStr, append(statement.joinArgs, condArgs...), nil } -func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string) { +func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string, err error) { var distinct string if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") { distinct = "DISTINCT " @@ -1211,7 +1052,9 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string) { var top string var mssqlCondi string - statement.processIDParam() + if err := statement.processIDParam(); err != nil { + return "", err + } var buf bytes.Buffer if len(condSQL) > 0 { @@ -1278,7 +1121,7 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string) { } // !nashtsai! REVIEW Sprintf is considered slowest mean of string concatnation, better to work with builder pattern - a = fmt.Sprintf("SELECT %v%v%v%v%v", top, distinct, columnStr, fromStr, whereStr) + a = fmt.Sprintf("SELECT %v%v%v%v%v", distinct, top, columnStr, fromStr, whereStr) if len(mssqlCondi) > 0 { if len(whereStr) > 0 { a += " AND " + mssqlCondi @@ -1314,19 +1157,23 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string) { return } -func (statement *Statement) processIDParam() { +func (statement *Statement) processIDParam() error { if statement.idParam == nil { - return + return nil + } + + if len(statement.RefTable.PrimaryKeys) != len(*statement.idParam) { + return fmt.Errorf("ID condition is error, expect %d primarykeys, there are %d", + len(statement.RefTable.PrimaryKeys), + len(*statement.idParam), + ) } for i, col := range statement.RefTable.PKColumns() { var colName = statement.colName(col, statement.TableName()) - if i < len(*(statement.idParam)) { - statement.cond = statement.cond.And(builder.Eq{colName: (*(statement.idParam))[i]}) - } else { - statement.cond = statement.cond.And(builder.Eq{colName: ""}) - } + statement.cond = statement.cond.And(builder.Eq{colName: (*(statement.idParam))[i]}) } + return nil } func (statement *Statement) joinColumns(cols []*core.Column, includeTableName bool) string { @@ -1360,7 +1207,8 @@ func (statement *Statement) convertIDSQL(sqlStr string) string { top = fmt.Sprintf("TOP %d ", statement.LimitN) } - return fmt.Sprintf("SELECT %s%s FROM %v", top, colstrs, sqls[1]) + newsql := fmt.Sprintf("SELECT %s%s FROM %v", top, colstrs, sqls[1]) + return newsql } return "" } diff --git a/statement_test.go b/statement_test.go index 01a09afc..594aa4f3 100644 --- a/statement_test.go +++ b/statement_test.go @@ -26,7 +26,7 @@ var colStrTests = []struct { } func TestColumnsStringGeneration(t *testing.T) { - if *db == "postgres" { + if dbType == "postgres" || dbType == "mssql" { return } diff --git a/tag.go b/tag.go index 4b0e3f54..e1c821fb 100644 --- a/tag.go +++ b/tag.go @@ -54,6 +54,7 @@ var ( "UNIQUE": UniqueTagHandler, "CACHE": CacheTagHandler, "NOCACHE": NoCacheTagHandler, + "COMMENT": CommentTagHandler, } ) @@ -192,6 +193,14 @@ func UniqueTagHandler(ctx *tagContext) error { return nil } +// CommentTagHandler add comment to column +func CommentTagHandler(ctx *tagContext) error { + if len(ctx.params) > 0 { + ctx.col.Comment = strings.Trim(ctx.params[0], "' ") + } + return nil +} + // SQLTypeTagHandler describes SQL Type tag handler func SQLTypeTagHandler(ctx *tagContext) error { ctx.col.SQLType = core.SQLType{Name: ctx.tagName} diff --git a/tag_cache_test.go b/tag_cache_test.go new file mode 100644 index 00000000..14a65fb8 --- /dev/null +++ b/tag_cache_test.go @@ -0,0 +1,39 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCacheTag(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type CacheDomain struct { + Id int64 `xorm:"pk cache"` + Name string + } + + assert.NoError(t, testEngine.CreateTables(&CacheDomain{})) + + table := testEngine.TableInfo(&CacheDomain{}) + assert.True(t, table.Cacher != nil) +} + +func TestNoCacheTag(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type NoCacheDomain struct { + Id int64 `xorm:"pk nocache"` + Name string + } + + assert.NoError(t, testEngine.CreateTables(&NoCacheDomain{})) + + table := testEngine.TableInfo(&NoCacheDomain{}) + assert.True(t, table.Cacher == nil) +} diff --git a/tag_extends_test.go b/tag_extends_test.go new file mode 100644 index 00000000..61a61e9e --- /dev/null +++ b/tag_extends_test.go @@ -0,0 +1,538 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import ( + "errors" + "fmt" + "testing" + "time" + + "github.com/go-xorm/core" + "github.com/stretchr/testify/assert" +) + +type tempUser struct { + Id int64 + Username string +} + +type tempUser2 struct { + TempUser tempUser `xorm:"extends"` + Departname string +} + +type tempUser3 struct { + Temp *tempUser `xorm:"extends"` + Departname string +} + +type tempUser4 struct { + TempUser2 tempUser2 `xorm:"extends"` +} + +type Userinfo struct { + Uid int64 `xorm:"id pk not null autoincr"` + Username string `xorm:"unique"` + Departname string + Alias string `xorm:"-"` + Created time.Time + Detail Userdetail `xorm:"detail_id int(11)"` + Height float64 + Avatar []byte + IsMan bool +} + +type Userdetail struct { + Id int64 + Intro string `xorm:"text"` + Profile string `xorm:"varchar(2000)"` +} + +type UserAndDetail struct { + Userinfo `xorm:"extends"` + Userdetail `xorm:"extends"` +} + +func TestExtends(t *testing.T) { + assert.NoError(t, prepareEngine()) + + err := testEngine.DropTables(&tempUser2{}) + if err != nil { + t.Error(err) + panic(err) + } + + err = testEngine.CreateTables(&tempUser2{}) + if err != nil { + t.Error(err) + panic(err) + } + + tu := &tempUser2{tempUser{0, "extends"}, "dev depart"} + _, err = testEngine.Insert(tu) + if err != nil { + t.Error(err) + panic(err) + } + + tu2 := &tempUser2{} + _, err = testEngine.Get(tu2) + if err != nil { + t.Error(err) + panic(err) + } + + tu3 := &tempUser2{tempUser{0, "extends update"}, ""} + _, err = testEngine.ID(tu2.TempUser.Id).Update(tu3) + if err != nil { + t.Error(err) + panic(err) + } + + err = testEngine.DropTables(&tempUser4{}) + if err != nil { + t.Error(err) + panic(err) + } + + err = testEngine.CreateTables(&tempUser4{}) + if err != nil { + t.Error(err) + panic(err) + } + + tu8 := &tempUser4{tempUser2{tempUser{0, "extends"}, "dev depart"}} + _, err = testEngine.Insert(tu8) + if err != nil { + t.Error(err) + panic(err) + } + + tu9 := &tempUser4{} + _, err = testEngine.Get(tu9) + if err != nil { + t.Error(err) + panic(err) + } + if tu9.TempUser2.TempUser.Username != tu8.TempUser2.TempUser.Username || tu9.TempUser2.Departname != tu8.TempUser2.Departname { + err = errors.New(fmt.Sprintln("not equal for", tu8, tu9)) + t.Error(err) + panic(err) + } + + tu10 := &tempUser4{tempUser2{tempUser{0, "extends update"}, ""}} + _, err = testEngine.ID(tu9.TempUser2.TempUser.Id).Update(tu10) + if err != nil { + t.Error(err) + panic(err) + } + + err = testEngine.DropTables(&tempUser3{}) + if err != nil { + t.Error(err) + panic(err) + } + + err = testEngine.CreateTables(&tempUser3{}) + if err != nil { + t.Error(err) + panic(err) + } + + tu4 := &tempUser3{&tempUser{0, "extends"}, "dev depart"} + _, err = testEngine.Insert(tu4) + if err != nil { + t.Error(err) + panic(err) + } + + tu5 := &tempUser3{} + _, err = testEngine.Get(tu5) + if err != nil { + t.Error(err) + panic(err) + } + if tu5.Temp == nil { + err = errors.New("error get data extends") + t.Error(err) + panic(err) + } + if tu5.Temp.Id != 1 || tu5.Temp.Username != "extends" || + tu5.Departname != "dev depart" { + err = errors.New("error get data extends") + t.Error(err) + panic(err) + } + + tu6 := &tempUser3{&tempUser{0, "extends update"}, ""} + _, err = testEngine.ID(tu5.Temp.Id).Update(tu6) + if err != nil { + t.Error(err) + panic(err) + } + + users := make([]tempUser3, 0) + err = testEngine.Find(&users) + if err != nil { + t.Error(err) + panic(err) + } + if len(users) != 1 { + err = errors.New("error get data not 1") + t.Error(err) + panic(err) + } + + assertSync(t, new(Userinfo), new(Userdetail)) + + detail := Userdetail{ + Intro: "I'm in China", + } + _, err = testEngine.Insert(&detail) + assert.NoError(t, err) + + _, err = testEngine.Insert(&Userinfo{ + Username: "lunny", + Detail: detail, + }) + assert.NoError(t, err) + + var info UserAndDetail + qt := testEngine.Quote + ui := testEngine.TableMapper.Obj2Table("Userinfo") + ud := testEngine.TableMapper.Obj2Table("Userdetail") + uiid := testEngine.TableMapper.Obj2Table("Id") + udid := "detail_id" + sql := fmt.Sprintf("select * from %s, %s where %s.%s = %s.%s", + qt(ui), qt(ud), qt(ui), qt(udid), qt(ud), qt(uiid)) + b, err := testEngine.Sql(sql).NoCascade().Get(&info) + if err != nil { + t.Error(err) + panic(err) + } + if !b { + err = errors.New("should has lest one record") + t.Error(err) + panic(err) + } + fmt.Println(info) + if info.Userinfo.Uid == 0 || info.Userdetail.Id == 0 { + err = errors.New("all of the id should has value") + t.Error(err) + panic(err) + } + + fmt.Println("----join--info2") + var info2 UserAndDetail + b, err = testEngine.Table(&Userinfo{}). + Join("LEFT", qt(ud), qt(ui)+"."+qt("detail_id")+" = "+qt(ud)+"."+qt(uiid)). + NoCascade().Get(&info2) + if err != nil { + t.Error(err) + panic(err) + } + if !b { + err = errors.New("should has lest one record") + t.Error(err) + panic(err) + } + if info2.Userinfo.Uid == 0 || info2.Userdetail.Id == 0 { + err = errors.New("all of the id should has value") + t.Error(err) + panic(err) + } + fmt.Println(info2) + + fmt.Println("----join--infos2") + var infos2 = make([]UserAndDetail, 0) + err = testEngine.Table(&Userinfo{}). + Join("LEFT", qt(ud), qt(ui)+"."+qt("detail_id")+" = "+qt(ud)+"."+qt(uiid)). + NoCascade(). + Find(&infos2) + if err != nil { + t.Error(err) + panic(err) + } + fmt.Println(infos2) +} + +type MessageBase struct { + Id int64 `xorm:"int(11) pk autoincr"` + TypeId int64 `xorm:"int(11) notnull"` +} + +type Message struct { + MessageBase `xorm:"extends"` + Title string `xorm:"varchar(100) notnull"` + Content string `xorm:"text notnull"` + Uid int64 `xorm:"int(11) notnull"` + ToUid int64 `xorm:"int(11) notnull"` + CreateTime time.Time `xorm:"datetime notnull created"` +} + +type MessageUser struct { + Id int64 + Name string +} + +type MessageType struct { + Id int64 + Name string +} + +type MessageExtend3 struct { + Message `xorm:"extends"` + Sender MessageUser `xorm:"extends"` + Receiver MessageUser `xorm:"extends"` + Type MessageType `xorm:"extends"` +} + +type MessageExtend4 struct { + Message `xorm:"extends"` + MessageUser `xorm:"extends"` + MessageType `xorm:"extends"` +} + +func TestExtends2(t *testing.T) { + assert.NoError(t, prepareEngine()) + + err := testEngine.DropTables(&Message{}, &MessageUser{}, &MessageType{}) + if err != nil { + t.Error(err) + panic(err) + } + + err = testEngine.CreateTables(&Message{}, &MessageUser{}, &MessageType{}) + if err != nil { + t.Error(err) + panic(err) + } + + var sender = MessageUser{Name: "sender"} + var receiver = MessageUser{Name: "receiver"} + var msgtype = MessageType{Name: "type"} + _, err = testEngine.Insert(&sender, &receiver, &msgtype) + if err != nil { + t.Error(err) + panic(err) + } + + msg := Message{ + MessageBase: MessageBase{ + Id: msgtype.Id, + }, + Title: "test", + Content: "test", + Uid: sender.Id, + ToUid: receiver.Id, + } + if testEngine.dialect.DBType() == core.MSSQL { + _, err = testEngine.Exec("SET IDENTITY_INSERT message ON") + assert.NoError(t, err) + } + + _, err = testEngine.Insert(&msg) + if err != nil { + t.Error(err) + panic(err) + } + + var mapper = testEngine.TableMapper.Obj2Table + userTableName := mapper("MessageUser") + typeTableName := mapper("MessageType") + msgTableName := mapper("Message") + + list := make([]Message, 0) + err = testEngine.Table(msgTableName).Join("LEFT", []string{userTableName, "sender"}, "`sender`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("Uid")+"`"). + Join("LEFT", []string{userTableName, "receiver"}, "`receiver`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("ToUid")+"`"). + Join("LEFT", []string{typeTableName, "type"}, "`type`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("Id")+"`"). + Find(&list) + if err != nil { + t.Error(err) + panic(err) + } + + if len(list) != 1 { + err = errors.New(fmt.Sprintln("should have 1 message, got", len(list))) + t.Error(err) + panic(err) + } + + if list[0].Id != msg.Id { + err = errors.New(fmt.Sprintln("should message equal", list[0], msg)) + t.Error(err) + panic(err) + } +} + +func TestExtends3(t *testing.T) { + assert.NoError(t, prepareEngine()) + + err := testEngine.DropTables(&Message{}, &MessageUser{}, &MessageType{}) + if err != nil { + t.Error(err) + panic(err) + } + + err = testEngine.CreateTables(&Message{}, &MessageUser{}, &MessageType{}) + if err != nil { + t.Error(err) + panic(err) + } + + var sender = MessageUser{Name: "sender"} + var receiver = MessageUser{Name: "receiver"} + var msgtype = MessageType{Name: "type"} + _, err = testEngine.Insert(&sender, &receiver, &msgtype) + if err != nil { + t.Error(err) + panic(err) + } + + msg := Message{ + MessageBase: MessageBase{ + Id: msgtype.Id, + }, + Title: "test", + Content: "test", + Uid: sender.Id, + ToUid: receiver.Id, + } + if testEngine.dialect.DBType() == core.MSSQL { + _, err = testEngine.Exec("SET IDENTITY_INSERT message ON") + assert.NoError(t, err) + } + _, err = testEngine.Insert(&msg) + if err != nil { + t.Error(err) + panic(err) + } + + var mapper = testEngine.TableMapper.Obj2Table + userTableName := mapper("MessageUser") + typeTableName := mapper("MessageType") + msgTableName := mapper("Message") + + list := make([]MessageExtend3, 0) + err = testEngine.Table(msgTableName).Join("LEFT", []string{userTableName, "sender"}, "`sender`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("Uid")+"`"). + Join("LEFT", []string{userTableName, "receiver"}, "`receiver`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("ToUid")+"`"). + Join("LEFT", []string{typeTableName, "type"}, "`type`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("Id")+"`"). + Find(&list) + if err != nil { + t.Error(err) + panic(err) + } + + if len(list) != 1 { + err = errors.New(fmt.Sprintln("should have 1 message, got", len(list))) + t.Error(err) + panic(err) + } + + if list[0].Message.Id != msg.Id { + err = errors.New(fmt.Sprintln("should message equal", list[0].Message, msg)) + t.Error(err) + panic(err) + } + + if list[0].Sender.Id != sender.Id || list[0].Sender.Name != sender.Name { + err = errors.New(fmt.Sprintln("should sender equal", list[0].Sender, sender)) + t.Error(err) + panic(err) + } + + if list[0].Receiver.Id != receiver.Id || list[0].Receiver.Name != receiver.Name { + err = errors.New(fmt.Sprintln("should receiver equal", list[0].Receiver, receiver)) + t.Error(err) + panic(err) + } + + if list[0].Type.Id != msgtype.Id || list[0].Type.Name != msgtype.Name { + err = errors.New(fmt.Sprintln("should msgtype equal", list[0].Type, msgtype)) + t.Error(err) + panic(err) + } +} + +func TestExtends4(t *testing.T) { + assert.NoError(t, prepareEngine()) + + err := testEngine.DropTables(&Message{}, &MessageUser{}, &MessageType{}) + if err != nil { + t.Error(err) + panic(err) + } + + err = testEngine.CreateTables(&Message{}, &MessageUser{}, &MessageType{}) + if err != nil { + t.Error(err) + panic(err) + } + + var sender = MessageUser{Name: "sender"} + var msgtype = MessageType{Name: "type"} + _, err = testEngine.Insert(&sender, &msgtype) + if err != nil { + t.Error(err) + panic(err) + } + + msg := Message{ + MessageBase: MessageBase{ + Id: msgtype.Id, + }, + Title: "test", + Content: "test", + Uid: sender.Id, + } + if testEngine.dialect.DBType() == core.MSSQL { + _, err = testEngine.Exec("SET IDENTITY_INSERT message ON") + assert.NoError(t, err) + } + _, err = testEngine.Insert(&msg) + if err != nil { + t.Error(err) + panic(err) + } + + var mapper = testEngine.TableMapper.Obj2Table + userTableName := mapper("MessageUser") + typeTableName := mapper("MessageType") + msgTableName := mapper("Message") + + list := make([]MessageExtend4, 0) + err = testEngine.Table(msgTableName).Join("LEFT", userTableName, "`"+userTableName+"`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("Uid")+"`"). + Join("LEFT", typeTableName, "`"+typeTableName+"`.`"+mapper("Id")+"`=`"+msgTableName+"`.`"+mapper("Id")+"`"). + Find(&list) + if err != nil { + t.Error(err) + panic(err) + } + + if len(list) != 1 { + err = errors.New(fmt.Sprintln("should have 1 message, got", len(list))) + t.Error(err) + panic(err) + } + + if list[0].Message.Id != msg.Id { + err = errors.New(fmt.Sprintln("should message equal", list[0].Message, msg)) + t.Error(err) + panic(err) + } + + if list[0].MessageUser.Id != sender.Id || list[0].MessageUser.Name != sender.Name { + err = errors.New(fmt.Sprintln("should sender equal", list[0].MessageUser, sender)) + t.Error(err) + panic(err) + } + + if list[0].MessageType.Id != msgtype.Id || list[0].MessageType.Name != msgtype.Name { + err = errors.New(fmt.Sprintln("should msgtype equal", list[0].MessageType, msgtype)) + t.Error(err) + panic(err) + } +} diff --git a/tag_id_test.go b/tag_id_test.go new file mode 100644 index 00000000..d22cc7b1 --- /dev/null +++ b/tag_id_test.go @@ -0,0 +1,85 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import ( + "testing" + + "github.com/go-xorm/core" + "github.com/stretchr/testify/assert" +) + +type IDGonicMapper struct { + ID int64 +} + +func TestGonicMapperID(t *testing.T) { + assert.NoError(t, prepareEngine()) + + oldMapper := testEngine.ColumnMapper + testEngine.unMapType(rValue(new(IDGonicMapper)).Type()) + testEngine.SetMapper(core.LintGonicMapper) + defer func() { + testEngine.unMapType(rValue(new(IDGonicMapper)).Type()) + testEngine.SetMapper(oldMapper) + }() + + err := testEngine.CreateTables(new(IDGonicMapper)) + if err != nil { + t.Fatal(err) + } + + tables, err := testEngine.DBMetas() + if err != nil { + t.Fatal(err) + } + + for _, tb := range tables { + if tb.Name == "id_gonic_mapper" { + if len(tb.PKColumns()) != 1 || tb.PKColumns()[0].Name != "id" { + t.Fatal(tb) + } + return + } + } + + t.Fatal("not table id_gonic_mapper") +} + +type IDSameMapper struct { + ID int64 +} + +func TestSameMapperID(t *testing.T) { + assert.NoError(t, prepareEngine()) + + oldMapper := testEngine.ColumnMapper + testEngine.unMapType(rValue(new(IDSameMapper)).Type()) + testEngine.SetMapper(core.SameMapper{}) + defer func() { + testEngine.unMapType(rValue(new(IDSameMapper)).Type()) + testEngine.SetMapper(oldMapper) + }() + + err := testEngine.CreateTables(new(IDSameMapper)) + if err != nil { + t.Fatal(err) + } + + tables, err := testEngine.DBMetas() + if err != nil { + t.Fatal(err) + } + + for _, tb := range tables { + if tb.Name == "IDSameMapper" { + if len(tb.PKColumns()) != 1 || tb.PKColumns()[0].Name != "ID" { + t.Fatal(tb) + } + return + } + } + t.Fatal("not table IDSameMapper") +} diff --git a/tag_test.go b/tag_test.go new file mode 100644 index 00000000..ef5028f6 --- /dev/null +++ b/tag_test.go @@ -0,0 +1,395 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import ( + "errors" + "fmt" + "strings" + "testing" + "time" + + "github.com/go-xorm/core" + "github.com/stretchr/testify/assert" +) + +type UserCU struct { + Id int64 + Name string + Created time.Time `xorm:"created"` + Updated time.Time `xorm:"updated"` +} + +func TestCreatedAndUpdated(t *testing.T) { + assert.NoError(t, prepareEngine()) + + u := new(UserCU) + err := testEngine.DropTables(u) + if err != nil { + t.Error(err) + panic(err) + } + + err = testEngine.CreateTables(u) + if err != nil { + t.Error(err) + panic(err) + } + + u.Name = "sss" + cnt, err := testEngine.Insert(u) + if err != nil { + t.Error(err) + panic(err) + } + if cnt != 1 { + err = errors.New("insert not returned 1") + t.Error(err) + panic(err) + return + } + + u.Name = "xxx" + cnt, err = testEngine.ID(u.Id).Update(u) + if err != nil { + t.Error(err) + panic(err) + } + if cnt != 1 { + err = errors.New("update not returned 1") + t.Error(err) + panic(err) + return + } + + u.Id = 0 + u.Created = time.Now().Add(-time.Hour * 24 * 365) + u.Updated = u.Created + fmt.Println(u) + cnt, err = testEngine.NoAutoTime().Insert(u) + if err != nil { + t.Error(err) + panic(err) + } + if cnt != 1 { + err = errors.New("insert not returned 1") + t.Error(err) + panic(err) + return + } +} + +type StrangeName struct { + Id_t int64 `xorm:"pk autoincr"` + Name string +} + +func TestStrangeName(t *testing.T) { + assert.NoError(t, prepareEngine()) + + err := testEngine.DropTables(new(StrangeName)) + if err != nil { + t.Error(err) + } + + err = testEngine.CreateTables(new(StrangeName)) + if err != nil { + t.Error(err) + } + + _, err = testEngine.Insert(&StrangeName{Name: "sfsfdsfds"}) + if err != nil { + t.Error(err) + } + + beans := make([]StrangeName, 0) + err = testEngine.Find(&beans) + if err != nil { + t.Error(err) + } +} + +func TestCreatedUpdated(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type CreatedUpdated struct { + Id int64 + Name string + Value float64 `xorm:"numeric"` + Created time.Time `xorm:"created"` + Created2 time.Time `xorm:"created"` + Updated time.Time `xorm:"updated"` + } + + err := testEngine.Sync(&CreatedUpdated{}) + assert.NoError(t, err) + + c := &CreatedUpdated{Name: "test"} + _, err = testEngine.Insert(c) + assert.NoError(t, err) + + c2 := new(CreatedUpdated) + has, err := testEngine.ID(c.Id).Get(c2) + assert.NoError(t, err) + + assert.True(t, has) + + c2.Value -= 1 + _, err = testEngine.ID(c2.Id).Update(c2) + assert.NoError(t, err) +} + +func TestCreatedUpdatedInt64(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type CreatedUpdatedInt64 struct { + Id int64 + Name string + Value float64 `xorm:"numeric"` + Created int64 `xorm:"created"` + Created2 int64 `xorm:"created"` + Updated int64 `xorm:"updated"` + } + + assertSync(t, &CreatedUpdatedInt64{}) + + c := &CreatedUpdatedInt64{Name: "test"} + _, err := testEngine.Insert(c) + assert.NoError(t, err) + + c2 := new(CreatedUpdatedInt64) + has, err := testEngine.ID(c.Id).Get(c2) + assert.NoError(t, err) + assert.True(t, has) + + c2.Value -= 1 + _, err = testEngine.ID(c2.Id).Update(c2) + assert.NoError(t, err) +} + +type Lowercase struct { + Id int64 + Name string + ended int64 `xorm:"-"` +} + +func TestLowerCase(t *testing.T) { + assert.NoError(t, prepareEngine()) + + err := testEngine.Sync(&Lowercase{}) + _, err = testEngine.Where("(id) > 0").Delete(&Lowercase{}) + if err != nil { + t.Error(err) + panic(err) + } + _, err = testEngine.Insert(&Lowercase{ended: 1}) + if err != nil { + t.Error(err) + panic(err) + } + + ls := make([]Lowercase, 0) + err = testEngine.Find(&ls) + if err != nil { + t.Error(err) + panic(err) + } + + if len(ls) != 1 { + err = errors.New("should be 1") + t.Error(err) + panic(err) + } +} + +func TestAutoIncrTag(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type TestAutoIncr1 struct { + Id int64 + } + + tb := testEngine.TableInfo(new(TestAutoIncr1)) + cols := tb.Columns() + assert.EqualValues(t, 1, len(cols)) + assert.True(t, cols[0].IsAutoIncrement) + assert.True(t, cols[0].IsPrimaryKey) + assert.Equal(t, "id", cols[0].Name) + + type TestAutoIncr2 struct { + Id int64 `xorm:"id"` + } + + tb = testEngine.TableInfo(new(TestAutoIncr2)) + cols = tb.Columns() + assert.EqualValues(t, 1, len(cols)) + assert.False(t, cols[0].IsAutoIncrement) + assert.False(t, cols[0].IsPrimaryKey) + assert.Equal(t, "id", cols[0].Name) + + type TestAutoIncr3 struct { + Id int64 `xorm:"'ID'"` + } + + tb = testEngine.TableInfo(new(TestAutoIncr3)) + cols = tb.Columns() + assert.EqualValues(t, 1, len(cols)) + assert.False(t, cols[0].IsAutoIncrement) + assert.False(t, cols[0].IsPrimaryKey) + assert.Equal(t, "ID", cols[0].Name) + + type TestAutoIncr4 struct { + Id int64 `xorm:"pk"` + } + + tb = testEngine.TableInfo(new(TestAutoIncr4)) + cols = tb.Columns() + assert.EqualValues(t, 1, len(cols)) + assert.False(t, cols[0].IsAutoIncrement) + assert.True(t, cols[0].IsPrimaryKey) + assert.Equal(t, "id", cols[0].Name) +} + +func TestTagComment(t *testing.T) { + assert.NoError(t, prepareEngine()) + // FIXME: only support mysql + if testEngine.dialect.DriverName() != core.MYSQL { + return + } + + type TestComment1 struct { + Id int64 `xorm:"comment(主键)"` + } + + assert.NoError(t, testEngine.Sync2(new(TestComment1))) + + tables, err := testEngine.DBMetas() + assert.NoError(t, err) + assert.EqualValues(t, 1, len(tables)) + assert.EqualValues(t, 1, len(tables[0].Columns())) + assert.EqualValues(t, "主键", tables[0].Columns()[0].Comment) + + assert.NoError(t, testEngine.DropTables(new(TestComment1))) + + type TestComment2 struct { + Id int64 `xorm:"comment('主键')"` + } + + assert.NoError(t, testEngine.Sync2(new(TestComment2))) + + tables, err = testEngine.DBMetas() + assert.NoError(t, err) + assert.EqualValues(t, 1, len(tables)) + assert.EqualValues(t, 1, len(tables[0].Columns())) + assert.EqualValues(t, "主键", tables[0].Columns()[0].Comment) +} + +func TestTagDefault(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type DefaultStruct struct { + Id int64 + Name string + Age int `xorm:"default(10)"` + } + + assertSync(t, new(DefaultStruct)) + + cnt, err := testEngine.Omit("age").Insert(&DefaultStruct{ + Name: "test", + Age: 20, + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var s DefaultStruct + has, err := testEngine.ID(1).Get(&s) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, 10, s.Age) + assert.EqualValues(t, "test", s.Name) +} + +func TestTagsDirection(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type OnlyFromDBStruct struct { + Id int64 + Name string + Uuid string `xorm:"<- default '1'"` + } + + assertSync(t, new(OnlyFromDBStruct)) + + cnt, err := testEngine.Insert(&OnlyFromDBStruct{ + Name: "test", + Uuid: "2", + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var s OnlyFromDBStruct + has, err := testEngine.ID(1).Get(&s) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, "1", s.Uuid) + assert.EqualValues(t, "test", s.Name) + + type OnlyToDBStruct struct { + Id int64 + Name string + Uuid string `xorm:"->"` + } + + assertSync(t, new(OnlyToDBStruct)) + + cnt, err = testEngine.Insert(&OnlyToDBStruct{ + Name: "test", + Uuid: "2", + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var s2 OnlyToDBStruct + has, err = testEngine.ID(1).Get(&s2) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, "", s2.Uuid) + assert.EqualValues(t, "test", s2.Name) +} + +func TestTagTime(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type TagUTCStruct struct { + Id int64 + Name string + Created time.Time `xorm:"created utc"` + } + + assertSync(t, new(TagUTCStruct)) + + assert.EqualValues(t, time.Local.String(), testEngine.TZLocation.String()) + + s := TagUTCStruct{ + Name: "utc", + } + cnt, err := testEngine.Insert(&s) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var u TagUTCStruct + has, err := testEngine.ID(1).Get(&u) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, s.Created.Format("2006-01-02 15:04:05"), u.Created.Format("2006-01-02 15:04:05")) + + var tm string + has, err = testEngine.Table("tag_u_t_c_struct").Cols("created").Get(&tm) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, s.Created.UTC().Format("2006-01-02 15:04:05"), + strings.Replace(strings.Replace(tm, "T", " ", -1), "Z", "", -1)) +} diff --git a/tag_version_test.go b/tag_version_test.go new file mode 100644 index 00000000..570a6754 --- /dev/null +++ b/tag_version_test.go @@ -0,0 +1,128 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import ( + "errors" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +type VersionS struct { + Id int64 + Name string + Ver int `xorm:"version"` + Created time.Time `xorm:"created"` +} + +func TestVersion1(t *testing.T) { + assert.NoError(t, prepareEngine()) + + err := testEngine.DropTables(new(VersionS)) + if err != nil { + t.Error(err) + panic(err) + } + + err = testEngine.CreateTables(new(VersionS)) + if err != nil { + t.Error(err) + panic(err) + } + + ver := &VersionS{Name: "sfsfdsfds"} + _, err = testEngine.Insert(ver) + if err != nil { + t.Error(err) + panic(err) + } + fmt.Println(ver) + if ver.Ver != 1 { + err = errors.New("insert error") + t.Error(err) + panic(err) + } + + newVer := new(VersionS) + has, err := testEngine.ID(ver.Id).Get(newVer) + if err != nil { + t.Error(err) + panic(err) + } + + if !has { + t.Error(errors.New(fmt.Sprintf("no version id is %v", ver.Id))) + panic(err) + } + fmt.Println(newVer) + if newVer.Ver != 1 { + err = errors.New("insert error") + t.Error(err) + panic(err) + } + + newVer.Name = "-------" + _, err = testEngine.ID(ver.Id).Update(newVer) + if err != nil { + t.Error(err) + panic(err) + } + if newVer.Ver != 2 { + err = errors.New("update should set version back to struct") + t.Error(err) + } + + newVer = new(VersionS) + has, err = testEngine.ID(ver.Id).Get(newVer) + if err != nil { + t.Error(err) + panic(err) + } + fmt.Println(newVer) + if newVer.Ver != 2 { + err = errors.New("insert error") + t.Error(err) + panic(err) + } +} + +func TestVersion2(t *testing.T) { + assert.NoError(t, prepareEngine()) + + err := testEngine.DropTables(new(VersionS)) + if err != nil { + t.Error(err) + panic(err) + } + + err = testEngine.CreateTables(new(VersionS)) + if err != nil { + t.Error(err) + panic(err) + } + + var vers = []VersionS{ + {Name: "sfsfdsfds"}, + {Name: "xxxxx"}, + } + _, err = testEngine.Insert(vers) + if err != nil { + t.Error(err) + panic(err) + } + + fmt.Println(vers) + + for _, v := range vers { + if v.Ver != 1 { + err := errors.New("version should be 1") + t.Error(err) + panic(err) + } + } +} diff --git a/test_mssql.sh b/test_mssql.sh new file mode 100755 index 00000000..6f9cf729 --- /dev/null +++ b/test_mssql.sh @@ -0,0 +1 @@ +go test -db=mssql -conn_str="server=192.168.1.58;user id=sa;password=123456;database=xorm_test" \ No newline at end of file diff --git a/test_mssql_cache.sh b/test_mssql_cache.sh new file mode 100755 index 00000000..76efd6ca --- /dev/null +++ b/test_mssql_cache.sh @@ -0,0 +1 @@ +go test -db=mssql -conn_str="server=192.168.1.58;user id=sa;password=123456;database=xorm_test" -cache=true \ No newline at end of file diff --git a/test_mymysql.sh b/test_mymysql.sh new file mode 100755 index 00000000..f7780d14 --- /dev/null +++ b/test_mymysql.sh @@ -0,0 +1 @@ +go test -db=mymysql -conn_str="xorm_test/root/" \ No newline at end of file diff --git a/test_mymysql_cache.sh b/test_mymysql_cache.sh new file mode 100755 index 00000000..0100286d --- /dev/null +++ b/test_mymysql_cache.sh @@ -0,0 +1 @@ +go test -db=mymysql -conn_str="xorm_test/root/" -cache=true \ No newline at end of file diff --git a/test_mysql.sh b/test_mysql.sh new file mode 100755 index 00000000..650e4ee1 --- /dev/null +++ b/test_mysql.sh @@ -0,0 +1 @@ +go test -db=mysql -conn_str="root:@/xorm_test" \ No newline at end of file diff --git a/test_mysql_cache.sh b/test_mysql_cache.sh new file mode 100755 index 00000000..c542e735 --- /dev/null +++ b/test_mysql_cache.sh @@ -0,0 +1 @@ +go test -db=mysql -conn_str="root:@/xorm_test" -cache=true \ No newline at end of file diff --git a/test_postgres.sh b/test_postgres.sh new file mode 100755 index 00000000..dc1152e0 --- /dev/null +++ b/test_postgres.sh @@ -0,0 +1 @@ +go test -db=postgres -conn_str="dbname=xorm_test sslmode=disable" \ No newline at end of file diff --git a/test_postgres_cache.sh b/test_postgres_cache.sh new file mode 100755 index 00000000..462fc948 --- /dev/null +++ b/test_postgres_cache.sh @@ -0,0 +1 @@ +go test -db=postgres -conn_str="dbname=xorm_test sslmode=disable" -cache=true \ No newline at end of file diff --git a/test_sqlite.sh b/test_sqlite.sh new file mode 100755 index 00000000..6352b5cb --- /dev/null +++ b/test_sqlite.sh @@ -0,0 +1 @@ +go test -db=sqlite3 -conn_str="./test.db?cache=shared&mode=rwc" \ No newline at end of file diff --git a/test_sqlite_cache.sh b/test_sqlite_cache.sh new file mode 100755 index 00000000..75a054c3 --- /dev/null +++ b/test_sqlite_cache.sh @@ -0,0 +1 @@ +go test -db=sqlite3 -conn_str="./test.db?cache=shared&mode=rwc" -cache=true \ No newline at end of file diff --git a/time_test.go b/time_test.go new file mode 100644 index 00000000..15b20c37 --- /dev/null +++ b/time_test.go @@ -0,0 +1,476 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import ( + "fmt" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestTimeUserTime(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type TimeUser struct { + Id string + OperTime time.Time + } + + assertSync(t, new(TimeUser)) + + var user = TimeUser{ + Id: "lunny", + OperTime: time.Now(), + } + + fmt.Println("user", user.OperTime) + + cnt, err := testEngine.Insert(&user) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var user2 TimeUser + has, err := testEngine.Get(&user2) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, user.OperTime.Unix(), user2.OperTime.Unix()) + assert.EqualValues(t, formatTime(user.OperTime), formatTime(user2.OperTime)) + fmt.Println("user2", user2.OperTime) +} + +func TestTimeUserTimeDiffLoc(t *testing.T) { + assert.NoError(t, prepareEngine()) + loc, err := time.LoadLocation("Asia/Shanghai") + assert.NoError(t, err) + testEngine.TZLocation = loc + dbLoc, err := time.LoadLocation("America/New_York") + assert.NoError(t, err) + testEngine.DatabaseTZ = dbLoc + + type TimeUser2 struct { + Id string + OperTime time.Time + } + + assertSync(t, new(TimeUser2)) + + var user = TimeUser2{ + Id: "lunny", + OperTime: time.Now(), + } + + fmt.Println("user", user.OperTime) + + cnt, err := testEngine.Insert(&user) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var user2 TimeUser2 + has, err := testEngine.Get(&user2) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, user.OperTime.Unix(), user2.OperTime.Unix()) + assert.EqualValues(t, formatTime(user.OperTime.In(loc)), formatTime(user2.OperTime)) + fmt.Println("user2", user2.OperTime) +} + +func TestTimeUserCreated(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type UserCreated struct { + Id string + CreatedAt time.Time `xorm:"created"` + } + + assertSync(t, new(UserCreated)) + + var user = UserCreated{ + Id: "lunny", + } + + fmt.Println("user", user.CreatedAt) + + cnt, err := testEngine.Insert(&user) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var user2 UserCreated + has, err := testEngine.Get(&user2) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, user.CreatedAt.Unix(), user2.CreatedAt.Unix()) + assert.EqualValues(t, formatTime(user.CreatedAt), formatTime(user2.CreatedAt)) + fmt.Println("user2", user2.CreatedAt) +} + +func TestTimeUserCreatedDiffLoc(t *testing.T) { + assert.NoError(t, prepareEngine()) + loc, err := time.LoadLocation("Asia/Shanghai") + assert.NoError(t, err) + testEngine.TZLocation = loc + dbLoc, err := time.LoadLocation("America/New_York") + assert.NoError(t, err) + testEngine.DatabaseTZ = dbLoc + + type UserCreated2 struct { + Id string + CreatedAt time.Time `xorm:"created"` + } + + assertSync(t, new(UserCreated2)) + + var user = UserCreated2{ + Id: "lunny", + } + + fmt.Println("user", user.CreatedAt) + + cnt, err := testEngine.Insert(&user) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var user2 UserCreated2 + has, err := testEngine.Get(&user2) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, user.CreatedAt.Unix(), user2.CreatedAt.Unix()) + assert.EqualValues(t, formatTime(user.CreatedAt), formatTime(user2.CreatedAt)) + fmt.Println("user2", user2.CreatedAt) +} + +func TestTimeUserUpdated(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type UserUpdated struct { + Id string + CreatedAt time.Time `xorm:"created"` + UpdatedAt time.Time `xorm:"updated"` + } + + assertSync(t, new(UserUpdated)) + + var user = UserUpdated{ + Id: "lunny", + } + + fmt.Println("user", user.CreatedAt, user.UpdatedAt) + + cnt, err := testEngine.Insert(&user) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var user2 UserUpdated + has, err := testEngine.Get(&user2) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, user.CreatedAt.Unix(), user2.CreatedAt.Unix()) + assert.EqualValues(t, formatTime(user.CreatedAt), formatTime(user2.CreatedAt)) + assert.EqualValues(t, user.UpdatedAt.Unix(), user2.UpdatedAt.Unix()) + assert.EqualValues(t, formatTime(user.UpdatedAt), formatTime(user2.UpdatedAt)) + fmt.Println("user2", user2.CreatedAt, user2.UpdatedAt) + + var user3 = UserUpdated{ + Id: "lunny2", + } + + cnt, err = testEngine.Update(&user3) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + assert.True(t, user.UpdatedAt.Unix() <= user3.UpdatedAt.Unix()) + + var user4 UserUpdated + has, err = testEngine.Get(&user4) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, user.CreatedAt.Unix(), user4.CreatedAt.Unix()) + assert.EqualValues(t, formatTime(user.CreatedAt), formatTime(user4.CreatedAt)) + assert.EqualValues(t, user3.UpdatedAt.Unix(), user4.UpdatedAt.Unix()) + assert.EqualValues(t, formatTime(user3.UpdatedAt), formatTime(user4.UpdatedAt)) + fmt.Println("user3", user.CreatedAt, user4.UpdatedAt) +} + +func TestTimeUserUpdatedDiffLoc(t *testing.T) { + assert.NoError(t, prepareEngine()) + loc, err := time.LoadLocation("Asia/Shanghai") + assert.NoError(t, err) + testEngine.TZLocation = loc + dbLoc, err := time.LoadLocation("America/New_York") + assert.NoError(t, err) + testEngine.DatabaseTZ = dbLoc + + type UserUpdated2 struct { + Id string + CreatedAt time.Time `xorm:"created"` + UpdatedAt time.Time `xorm:"updated"` + } + + assertSync(t, new(UserUpdated2)) + + var user = UserUpdated2{ + Id: "lunny", + } + + fmt.Println("user", user.CreatedAt, user.UpdatedAt) + + cnt, err := testEngine.Insert(&user) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var user2 UserUpdated2 + has, err := testEngine.Get(&user2) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, user.CreatedAt.Unix(), user2.CreatedAt.Unix()) + assert.EqualValues(t, formatTime(user.CreatedAt), formatTime(user2.CreatedAt)) + assert.EqualValues(t, user.UpdatedAt.Unix(), user2.UpdatedAt.Unix()) + assert.EqualValues(t, formatTime(user.UpdatedAt), formatTime(user2.UpdatedAt)) + fmt.Println("user2", user2.CreatedAt, user2.UpdatedAt) + + var user3 = UserUpdated2{ + Id: "lunny2", + } + + cnt, err = testEngine.Update(&user3) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + assert.True(t, user.UpdatedAt.Unix() <= user3.UpdatedAt.Unix()) + + var user4 UserUpdated2 + has, err = testEngine.Get(&user4) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, user.CreatedAt.Unix(), user4.CreatedAt.Unix()) + assert.EqualValues(t, formatTime(user.CreatedAt), formatTime(user4.CreatedAt)) + assert.EqualValues(t, user3.UpdatedAt.Unix(), user4.UpdatedAt.Unix()) + assert.EqualValues(t, formatTime(user3.UpdatedAt), formatTime(user4.UpdatedAt)) + fmt.Println("user3", user.CreatedAt, user4.UpdatedAt) +} + +func TestTimeUserDeleted(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type UserDeleted struct { + Id string + CreatedAt time.Time `xorm:"created"` + UpdatedAt time.Time `xorm:"updated"` + DeletedAt time.Time `xorm:"deleted"` + } + + assertSync(t, new(UserDeleted)) + + var user = UserDeleted{ + Id: "lunny", + } + + cnt, err := testEngine.Insert(&user) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + fmt.Println("user", user.CreatedAt, user.UpdatedAt, user.DeletedAt) + + var user2 UserDeleted + has, err := testEngine.Get(&user2) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, user.CreatedAt.Unix(), user2.CreatedAt.Unix()) + assert.EqualValues(t, formatTime(user.CreatedAt), formatTime(user2.CreatedAt)) + assert.EqualValues(t, user.UpdatedAt.Unix(), user2.UpdatedAt.Unix()) + assert.EqualValues(t, formatTime(user.UpdatedAt), formatTime(user2.UpdatedAt)) + assert.True(t, isTimeZero(user2.DeletedAt)) + fmt.Println("user2", user2.CreatedAt, user2.UpdatedAt, user2.DeletedAt) + + var user3 UserDeleted + cnt, err = testEngine.Where("id = ?", "lunny").Delete(&user3) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + assert.True(t, !isTimeZero(user3.DeletedAt)) + + var user4 UserDeleted + has, err = testEngine.Unscoped().Get(&user4) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, user3.DeletedAt.Unix(), user4.DeletedAt.Unix()) + assert.EqualValues(t, formatTime(user3.DeletedAt), formatTime(user4.DeletedAt)) + fmt.Println("user3", user3.DeletedAt, user4.DeletedAt) +} + +func TestTimeUserDeletedDiffLoc(t *testing.T) { + assert.NoError(t, prepareEngine()) + loc, err := time.LoadLocation("Asia/Shanghai") + assert.NoError(t, err) + testEngine.TZLocation = loc + dbLoc, err := time.LoadLocation("America/New_York") + assert.NoError(t, err) + testEngine.DatabaseTZ = dbLoc + + type UserDeleted2 struct { + Id string + CreatedAt time.Time `xorm:"created"` + UpdatedAt time.Time `xorm:"updated"` + DeletedAt time.Time `xorm:"deleted"` + } + + assertSync(t, new(UserDeleted2)) + + var user = UserDeleted2{ + Id: "lunny", + } + + cnt, err := testEngine.Insert(&user) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + fmt.Println("user", user.CreatedAt, user.UpdatedAt, user.DeletedAt) + + var user2 UserDeleted2 + has, err := testEngine.Get(&user2) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, user.CreatedAt.Unix(), user2.CreatedAt.Unix()) + assert.EqualValues(t, formatTime(user.CreatedAt), formatTime(user2.CreatedAt)) + assert.EqualValues(t, user.UpdatedAt.Unix(), user2.UpdatedAt.Unix()) + assert.EqualValues(t, formatTime(user.UpdatedAt), formatTime(user2.UpdatedAt)) + assert.True(t, isTimeZero(user2.DeletedAt)) + fmt.Println("user2", user2.CreatedAt, user2.UpdatedAt, user2.DeletedAt) + + var user3 UserDeleted2 + cnt, err = testEngine.Where("id = ?", "lunny").Delete(&user3) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + assert.True(t, !isTimeZero(user3.DeletedAt)) + + var user4 UserDeleted2 + has, err = testEngine.Unscoped().Get(&user4) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, user3.DeletedAt.Unix(), user4.DeletedAt.Unix()) + assert.EqualValues(t, formatTime(user3.DeletedAt), formatTime(user4.DeletedAt)) + fmt.Println("user3", user3.DeletedAt, user4.DeletedAt) +} + +type JsonDate time.Time + +func (j JsonDate) MarshalJSON() ([]byte, error) { + if time.Time(j).IsZero() { + return []byte(`""`), nil + } + return []byte(`"` + time.Time(j).Format("2006-01-02 15:04:05") + `"`), nil +} + +func (j *JsonDate) UnmarshalJSON(value []byte) error { + var v = strings.TrimSpace(strings.Trim(string(value), "\"")) + + t, err := time.ParseInLocation("2006-01-02 15:04:05", v, time.Local) + if err != nil { + return err + } + *j = JsonDate(t) + return nil +} + +func (j *JsonDate) Unix() int64 { + return (*time.Time)(j).Unix() +} + +func TestCustomTimeUserDeleted(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type UserDeleted3 struct { + Id string + CreatedAt JsonDate `xorm:"created"` + UpdatedAt JsonDate `xorm:"updated"` + DeletedAt JsonDate `xorm:"deleted"` + } + + assertSync(t, new(UserDeleted3)) + + var user = UserDeleted3{ + Id: "lunny", + } + + cnt, err := testEngine.Insert(&user) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + fmt.Println("user", user.CreatedAt, user.UpdatedAt, user.DeletedAt) + + var user2 UserDeleted3 + has, err := testEngine.Get(&user2) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, user.CreatedAt.Unix(), user2.CreatedAt.Unix()) + assert.EqualValues(t, formatTime(time.Time(user.CreatedAt)), formatTime(time.Time(user2.CreatedAt))) + assert.EqualValues(t, user.UpdatedAt.Unix(), user2.UpdatedAt.Unix()) + assert.EqualValues(t, formatTime(time.Time(user.UpdatedAt)), formatTime(time.Time(user2.UpdatedAt))) + assert.True(t, isTimeZero(time.Time(user2.DeletedAt))) + fmt.Println("user2", user2.CreatedAt, user2.UpdatedAt, user2.DeletedAt) + + var user3 UserDeleted3 + cnt, err = testEngine.Where("id = ?", "lunny").Delete(&user3) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + assert.True(t, !isTimeZero(time.Time(user3.DeletedAt))) + + var user4 UserDeleted3 + has, err = testEngine.Unscoped().Get(&user4) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, user3.DeletedAt.Unix(), user4.DeletedAt.Unix()) + assert.EqualValues(t, formatTime(time.Time(user3.DeletedAt)), formatTime(time.Time(user4.DeletedAt))) + fmt.Println("user3", user3.DeletedAt, user4.DeletedAt) +} + +func TestCustomTimeUserDeletedDiffLoc(t *testing.T) { + assert.NoError(t, prepareEngine()) + loc, err := time.LoadLocation("Asia/Shanghai") + assert.NoError(t, err) + testEngine.TZLocation = loc + dbLoc, err := time.LoadLocation("America/New_York") + assert.NoError(t, err) + testEngine.DatabaseTZ = dbLoc + + type UserDeleted4 struct { + Id string + CreatedAt JsonDate `xorm:"created"` + UpdatedAt JsonDate `xorm:"updated"` + DeletedAt JsonDate `xorm:"deleted"` + } + + assertSync(t, new(UserDeleted4)) + + var user = UserDeleted4{ + Id: "lunny", + } + + cnt, err := testEngine.Insert(&user) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + fmt.Println("user", user.CreatedAt, user.UpdatedAt, user.DeletedAt) + + var user2 UserDeleted4 + has, err := testEngine.Get(&user2) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, user.CreatedAt.Unix(), user2.CreatedAt.Unix()) + assert.EqualValues(t, formatTime(time.Time(user.CreatedAt)), formatTime(time.Time(user2.CreatedAt))) + assert.EqualValues(t, user.UpdatedAt.Unix(), user2.UpdatedAt.Unix()) + assert.EqualValues(t, formatTime(time.Time(user.UpdatedAt)), formatTime(time.Time(user2.UpdatedAt))) + assert.True(t, isTimeZero(time.Time(user2.DeletedAt))) + fmt.Println("user2", user2.CreatedAt, user2.UpdatedAt, user2.DeletedAt) + + var user3 UserDeleted4 + cnt, err = testEngine.Where("id = ?", "lunny").Delete(&user3) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + assert.True(t, !isTimeZero(time.Time(user3.DeletedAt))) + + var user4 UserDeleted4 + has, err = testEngine.Unscoped().Get(&user4) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, user3.DeletedAt.Unix(), user4.DeletedAt.Unix()) + assert.EqualValues(t, formatTime(time.Time(user3.DeletedAt)), formatTime(time.Time(user4.DeletedAt))) + fmt.Println("user3", user3.DeletedAt, user4.DeletedAt) +} diff --git a/types_null_test.go b/types_null_test.go new file mode 100644 index 00000000..22fc1024 --- /dev/null +++ b/types_null_test.go @@ -0,0 +1,404 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import ( + "database/sql" + "database/sql/driver" + "errors" + "fmt" + "strconv" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +type NullType struct { + Id int `xorm:"pk autoincr"` + Name sql.NullString + Age sql.NullInt64 + Height sql.NullFloat64 + IsMan sql.NullBool `xorm:"null"` + CustomStruct CustomStruct `xorm:"valchar(64) null"` +} + +type CustomStruct struct { + Year int + Month int + Day int +} + +func (CustomStruct) String() string { + return "CustomStruct" +} + +func (m *CustomStruct) Scan(value interface{}) error { + if value == nil { + m.Year, m.Month, m.Day = 0, 0, 0 + return nil + } + + if s, ok := value.([]byte); ok { + seps := strings.Split(string(s), "/") + m.Year, _ = strconv.Atoi(seps[0]) + m.Month, _ = strconv.Atoi(seps[1]) + m.Day, _ = strconv.Atoi(seps[2]) + return nil + } + + return errors.New("scan data not fit []byte") +} + +func (m CustomStruct) Value() (driver.Value, error) { + return fmt.Sprintf("%d/%d/%d", m.Year, m.Month, m.Day), nil +} + +func TestCreateNullStructTable(t *testing.T) { + assert.NoError(t, prepareEngine()) + + err := testEngine.CreateTables(new(NullType)) + if err != nil { + t.Error(err) + panic(err) + } +} + +func TestDropNullStructTable(t *testing.T) { + assert.NoError(t, prepareEngine()) + + err := testEngine.DropTables(new(NullType)) + if err != nil { + t.Error(err) + panic(err) + } +} + +func TestNullStructInsert(t *testing.T) { + assert.NoError(t, prepareEngine()) + assertSync(t, new(NullType)) + + if true { + item := new(NullType) + _, err := testEngine.Insert(item) + if err != nil { + t.Error(err) + panic(err) + } + fmt.Println(item) + if item.Id != 1 { + err = errors.New("insert error") + t.Error(err) + panic(err) + } + } + + if true { + item := NullType{ + Name: sql.NullString{"haolei", true}, + Age: sql.NullInt64{34, true}, + Height: sql.NullFloat64{1.72, true}, + IsMan: sql.NullBool{true, true}, + } + _, err := testEngine.Insert(&item) + if err != nil { + t.Error(err) + panic(err) + } + fmt.Println(item) + if item.Id != 2 { + err = errors.New("insert error") + t.Error(err) + panic(err) + } + } + + if true { + items := []NullType{} + + for i := 0; i < 5; i++ { + item := NullType{ + Name: sql.NullString{"haolei_" + fmt.Sprint(i+1), true}, + Age: sql.NullInt64{30 + int64(i), true}, + Height: sql.NullFloat64{1.5 + 1.1*float64(i), true}, + IsMan: sql.NullBool{true, true}, + CustomStruct: CustomStruct{i, i + 1, i + 2}, + } + + items = append(items, item) + } + + _, err := testEngine.Insert(&items) + if err != nil { + t.Error(err) + panic(err) + } + fmt.Println(items) + } +} + +func TestNullStructUpdate(t *testing.T) { + assert.NoError(t, prepareEngine()) + assertSync(t, new(NullType)) + + _, err := testEngine.Insert([]NullType{ + { + Name: sql.NullString{ + String: "name1", + Valid: true, + }, + }, + { + Name: sql.NullString{ + String: "name2", + Valid: true, + }, + }, + { + Name: sql.NullString{ + String: "name3", + Valid: true, + }, + }, + { + Name: sql.NullString{ + String: "name4", + Valid: true, + }, + }, + }) + assert.NoError(t, err) + + if true { // 测试可插入NULL + item := new(NullType) + item.Age = sql.NullInt64{23, true} + item.Height = sql.NullFloat64{0, false} // update to NULL + + affected, err := testEngine.ID(2).Cols("age", "height", "is_man").Update(item) + if err != nil { + t.Error(err) + panic(err) + } + if affected != 1 { + err := errors.New("update failed") + t.Error(err) + panic(err) + } + } + + if true { // 测试In update + item := new(NullType) + item.Age = sql.NullInt64{23, true} + affected, err := testEngine.In("id", 3, 4).Cols("age", "height", "is_man").Update(item) + if err != nil { + t.Error(err) + panic(err) + } + if affected != 2 { + err := errors.New("update failed") + t.Error(err) + panic(err) + } + } + + if true { // 测试where + item := new(NullType) + item.Name = sql.NullString{"nullname", true} + item.IsMan = sql.NullBool{true, true} + item.Age = sql.NullInt64{34, true} + + _, err := testEngine.Where("age > ?", 34).Update(item) + if err != nil { + t.Error(err) + panic(err) + } + } + + if true { // 修改全部时,插入空值 + item := &NullType{ + Name: sql.NullString{"winxxp", true}, + Age: sql.NullInt64{30, true}, + Height: sql.NullFloat64{1.72, true}, + // IsMan: sql.NullBool{true, true}, + } + + _, err := testEngine.AllCols().ID(6).Update(item) + if err != nil { + t.Error(err) + panic(err) + } + fmt.Println(item) + } + +} + +func TestNullStructFind(t *testing.T) { + assert.NoError(t, prepareEngine()) + assertSync(t, new(NullType)) + + _, err := testEngine.Insert([]NullType{ + { + Name: sql.NullString{ + String: "name1", + Valid: false, + }, + }, + { + Name: sql.NullString{ + String: "name2", + Valid: true, + }, + }, + { + Name: sql.NullString{ + String: "name3", + Valid: true, + }, + }, + { + Name: sql.NullString{ + String: "name4", + Valid: true, + }, + }, + }) + assert.NoError(t, err) + + if true { + item := new(NullType) + has, err := testEngine.ID(1).Get(item) + if err != nil { + t.Error(err) + panic(err) + } + if !has { + t.Error(errors.New("no find id 1")) + panic(err) + } + fmt.Println(item) + if item.Id != 1 || item.Name.Valid || item.Age.Valid || item.Height.Valid || + item.IsMan.Valid { + err = errors.New("insert error") + t.Error(err) + panic(err) + } + } + + if true { + item := new(NullType) + item.Id = 2 + + has, err := testEngine.Get(item) + if err != nil { + t.Error(err) + panic(err) + } + if !has { + t.Error(errors.New("no find id 2")) + panic(err) + } + fmt.Println(item) + } + + if true { + item := make([]NullType, 0) + + err := testEngine.ID(2).Find(&item) + if err != nil { + t.Error(err) + panic(err) + } + + fmt.Println(item) + } + + if true { + item := make([]NullType, 0) + + err := testEngine.Asc("age").Find(&item) + if err != nil { + t.Error(err) + panic(err) + } + + for k, v := range item { + fmt.Println(k, v) + } + } +} + +func TestNullStructIterate(t *testing.T) { + assert.NoError(t, prepareEngine()) + assertSync(t, new(NullType)) + + if true { + err := testEngine.Where("age IS NOT NULL").OrderBy("age").Iterate(new(NullType), + func(i int, bean interface{}) error { + nultype := bean.(*NullType) + fmt.Println(i, nultype) + return nil + }) + if err != nil { + t.Error(err) + panic(err) + } + } +} + +func TestNullStructCount(t *testing.T) { + assert.NoError(t, prepareEngine()) + assertSync(t, new(NullType)) + + if true { + item := new(NullType) + total, err := testEngine.Where("age IS NOT NULL").Count(item) + if err != nil { + t.Error(err) + panic(err) + } + fmt.Println(total) + } +} + +func TestNullStructRows(t *testing.T) { + assert.NoError(t, prepareEngine()) + assertSync(t, new(NullType)) + + item := new(NullType) + rows, err := testEngine.Where("id > ?", 1).Rows(item) + if err != nil { + t.Error(err) + panic(err) + } + defer rows.Close() + + for rows.Next() { + err = rows.Scan(item) + if err != nil { + t.Error(err) + panic(err) + } + fmt.Println(item) + } +} + +func TestNullStructDelete(t *testing.T) { + assert.NoError(t, prepareEngine()) + assertSync(t, new(NullType)) + + item := new(NullType) + + _, err := testEngine.ID(1).Delete(item) + if err != nil { + t.Error(err) + panic(err) + } + + _, err = testEngine.Where("id > ?", 1).Delete(item) + if err != nil { + t.Error(err) + panic(err) + } +} diff --git a/types_test.go b/types_test.go index f5a51679..df4ee70e 100644 --- a/types_test.go +++ b/types_test.go @@ -5,8 +5,12 @@ package xorm import ( + "encoding/json" + "errors" + "fmt" "testing" + "github.com/go-xorm/core" "github.com/stretchr/testify/assert" ) @@ -33,7 +37,7 @@ func TestArrayField(t *testing.T) { assert.EqualValues(t, 1, cnt) var arr ArrayStruct - has, err := testEngine.Id(1).Get(&arr) + has, err := testEngine.ID(1).Get(&arr) assert.NoError(t, err) assert.Equal(t, true, has) assert.Equal(t, as.Name, arr.Name) @@ -77,7 +81,7 @@ func TestGetBytes(t *testing.T) { assert.NoError(t, prepareEngine()) type Varbinary struct { - Data []byte `xorm:"VARBINARY"` + Data []byte `xorm:"VARBINARY(250)"` } err := testEngine.Sync2(new(Varbinary)) @@ -95,3 +99,237 @@ func TestGetBytes(t *testing.T) { assert.Equal(t, true, has) assert.Equal(t, "test", string(b.Data)) } + +type ConvString string + +func (s *ConvString) FromDB(data []byte) error { + *s = ConvString("prefix---" + string(data)) + return nil +} + +func (s *ConvString) ToDB() ([]byte, error) { + return []byte(string(*s)), nil +} + +type ConvConfig struct { + Name string + Id int64 +} + +func (s *ConvConfig) FromDB(data []byte) error { + return json.Unmarshal(data, s) +} + +func (s *ConvConfig) ToDB() ([]byte, error) { + return json.Marshal(s) +} + +type SliceType []*ConvConfig + +func (s *SliceType) FromDB(data []byte) error { + return json.Unmarshal(data, s) +} + +func (s *SliceType) ToDB() ([]byte, error) { + return json.Marshal(s) +} + +type ConvStruct struct { + Conv ConvString + Conv2 *ConvString + Cfg1 ConvConfig + Cfg2 *ConvConfig `xorm:"TEXT"` + Cfg3 core.Conversion `xorm:"BLOB"` + Slice SliceType +} + +func (c *ConvStruct) BeforeSet(name string, cell Cell) { + if name == "cfg3" || name == "Cfg3" { + c.Cfg3 = new(ConvConfig) + } +} + +func TestConversion(t *testing.T) { + assert.NoError(t, prepareEngine()) + + c := new(ConvStruct) + assert.NoError(t, testEngine.DropTables(c)) + assert.NoError(t, testEngine.Sync(c)) + + var s ConvString = "sssss" + c.Conv = "tttt" + c.Conv2 = &s + c.Cfg1 = ConvConfig{"mm", 1} + c.Cfg2 = &ConvConfig{"xx", 2} + c.Cfg3 = &ConvConfig{"zz", 3} + c.Slice = []*ConvConfig{{"yy", 4}, {"ff", 5}} + + _, err := testEngine.Insert(c) + assert.NoError(t, err) + + c1 := new(ConvStruct) + has, err := testEngine.Get(c1) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, "prefix---tttt", string(c1.Conv)) + assert.NotNil(t, c1.Conv2) + assert.EqualValues(t, "prefix---"+s, *c1.Conv2) + assert.EqualValues(t, c.Cfg1, c1.Cfg1) + assert.NotNil(t, c1.Cfg2) + assert.EqualValues(t, *c.Cfg2, *c1.Cfg2) + assert.NotNil(t, c1.Cfg3) + assert.EqualValues(t, *c.Cfg3.(*ConvConfig), *c1.Cfg3.(*ConvConfig)) + assert.EqualValues(t, 2, len(c1.Slice)) + assert.EqualValues(t, *c.Slice[0], *c1.Slice[0]) + assert.EqualValues(t, *c.Slice[1], *c1.Slice[1]) +} + +type MyInt int +type MyUInt uint +type MyFloat float64 + +type MyStruct struct { + Type MyInt + U MyUInt + F MyFloat + S MyString + IA []MyInt + UA []MyUInt + FA []MyFloat + SA []MyString + NameArray []string + Name string + UIA []uint + UIA8 []uint8 + UIA16 []uint16 + UIA32 []uint32 + UIA64 []uint64 + UI uint + //C64 complex64 + MSS map[string]string +} + +func TestCustomType1(t *testing.T) { + assert.NoError(t, prepareEngine()) + + err := testEngine.DropTables(&MyStruct{}) + assert.NoError(t, err) + + err = testEngine.CreateTables(&MyStruct{}) + assert.NoError(t, err) + + i := MyStruct{Name: "Test", Type: MyInt(1)} + i.U = 23 + i.F = 1.34 + i.S = "fafdsafdsaf" + i.UI = 2 + i.IA = []MyInt{1, 3, 5} + i.UIA = []uint{1, 3} + i.UIA16 = []uint16{2} + i.UIA32 = []uint32{4, 5} + i.UIA64 = []uint64{6, 7, 9} + i.UIA8 = []uint8{1, 2, 3, 4} + i.NameArray = []string{"ssss", "fsdf", "lllll, ss"} + i.MSS = map[string]string{"s": "sfds,ss", "x": "lfjljsl"} + + cnt, err := testEngine.Insert(&i) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + fmt.Println(i) + i.NameArray = []string{} + i.MSS = map[string]string{} + i.F = 0 + has, err := testEngine.Get(&i) + assert.NoError(t, err) + assert.True(t, has) + + ss := []MyStruct{} + err = testEngine.Find(&ss) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(ss)) + assert.EqualValues(t, i, ss[0]) + + sss := MyStruct{} + has, err = testEngine.Get(&sss) + assert.NoError(t, err) + assert.True(t, has) + + sss.NameArray = []string{} + sss.MSS = map[string]string{} + cnt, err = testEngine.Delete(&sss) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) +} + +type Status struct { + Name string + Color string +} + +var ( + _ core.Conversion = &Status{} + Registed Status = Status{"Registed", "white"} + Approved Status = Status{"Approved", "green"} + Removed Status = Status{"Removed", "red"} + Statuses map[string]Status = map[string]Status{ + Registed.Name: Registed, + Approved.Name: Approved, + Removed.Name: Removed, + } +) + +func (s *Status) FromDB(bytes []byte) error { + if r, ok := Statuses[string(bytes)]; ok { + *s = r + return nil + } else { + return errors.New("no this data") + } +} + +func (s *Status) ToDB() ([]byte, error) { + return []byte(s.Name), nil +} + +type UserCus struct { + Id int64 + Name string + Status Status `xorm:"varchar(40)"` +} + +func TestCustomType2(t *testing.T) { + assert.NoError(t, prepareEngine()) + + err := testEngine.CreateTables(&UserCus{}) + assert.NoError(t, err) + + tableName := testEngine.TableMapper.Obj2Table("UserCus") + _, err = testEngine.Exec("delete from " + testEngine.Quote(tableName)) + assert.NoError(t, err) + + if testEngine.Dialect().DBType() == core.MSSQL { + return + /*_, err = engine.Exec("set IDENTITY_INSERT " + tableName + " on") + if err != nil { + t.Fatal(err) + }*/ + } + + _, err = testEngine.Insert(&UserCus{1, "xlw", Registed}) + assert.NoError(t, err) + + user := UserCus{} + exist, err := testEngine.ID(1).Get(&user) + assert.NoError(t, err) + assert.True(t, exist) + + fmt.Println(user) + + users := make([]UserCus, 0) + err = testEngine.Where("`"+testEngine.ColumnMapper.Obj2Table("Status")+"` = ?", "Registed").Find(&users) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(users)) + + fmt.Println(users) +} diff --git a/xorm.go b/xorm.go index c22c1b65..4fdadf2f 100644 --- a/xorm.go +++ b/xorm.go @@ -17,7 +17,7 @@ import ( const ( // Version show the xorm's version - Version string = "0.6.2.0412" + Version string = "0.6.4.0910" ) func regDrvsNDialects() bool { @@ -50,10 +50,13 @@ func close(engine *Engine) { engine.Close() } +func init() { + regDrvsNDialects() +} + // NewEngine new a db manager according to the parameter. Currently support four // drivers func NewEngine(driverName string, dataSourceName string) (*Engine, error) { - regDrvsNDialects() driver := core.QueryDriver(driverName) if driver == nil { return nil, fmt.Errorf("Unsupported driver name: %v", driverName) @@ -89,6 +92,12 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) { tagHandlers: defaultTagHandlers, } + if uri.DbType == core.SQLITE { + engine.DatabaseTZ = time.UTC + } else { + engine.DatabaseTZ = time.Local + } + logger := NewSimpleLogger(os.Stdout) logger.SetLevel(core.LOG_INFO) engine.SetLogger(logger) diff --git a/xorm_test.go b/xorm_test.go index 98b42b66..1a757d3f 100644 --- a/xorm_test.go +++ b/xorm_test.go @@ -4,20 +4,25 @@ import ( "flag" "fmt" "os" + "strings" "testing" + _ "github.com/denisenkom/go-mssqldb" _ "github.com/go-sql-driver/mysql" + "github.com/go-xorm/core" _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" + _ "github.com/ziutek/mymysql/godrv" ) var ( testEngine *Engine + dbType string connString string db = flag.String("db", "sqlite3", "the tested database") showSQL = flag.Bool("show_sql", true, "show generated SQLs") - ptrConnStr = flag.String("conn_str", "", "test database connection string") + ptrConnStr = flag.String("conn_str", "./test.db?cache=shared&mode=rwc", "test database connection string") mapType = flag.String("map_type", "snake", "indicate the name mapping") cache = flag.Bool("cache", false, "if enable cache") ) @@ -31,6 +36,22 @@ func createEngine(dbType, connStr string) error { } testEngine.ShowSQL(*showSQL) + testEngine.logger.SetLevel(core.LOG_DEBUG) + if *cache { + cacher := NewLRUCacher(NewMemoryStore(), 100000) + testEngine.SetDefaultCacher(cacher) + } + + if len(*mapType) > 0 { + switch *mapType { + case "snake": + testEngine.SetMapper(core.SnakeMapper{}) + case "same": + testEngine.SetMapper(core.SameMapper{}) + case "gonic": + testEngine.SetMapper(core.LintGonicMapper) + } + } } tables, err := testEngine.DBMetas() @@ -41,19 +62,23 @@ func createEngine(dbType, connStr string) error { for _, table := range tables { tableNames = append(tableNames, table.Name) } - return testEngine.DropTables(tableNames...) + if err = testEngine.DropTables(tableNames...); err != nil { + return err + } + return nil } func prepareEngine() error { - return createEngine(*db, connString) + return createEngine(dbType, connString) } func TestMain(m *testing.M) { flag.Parse() + dbType = *db if *db == "sqlite3" { if ptrConnStr == nil { - connString = "./test.db" + connString = "./test.db?cache=shared&mode=rwc" } else { connString = *ptrConnStr } @@ -65,11 +90,28 @@ func TestMain(m *testing.M) { connString = *ptrConnStr } - if err := prepareEngine(); err != nil { - fmt.Println(err) - return + dbs := strings.Split(*db, "::") + conns := strings.Split(connString, "::") + + var res int + for i := 0; i < len(dbs); i++ { + dbType = dbs[i] + connString = conns[i] + testEngine = nil + fmt.Println("testing", dbType, connString) + + if err := prepareEngine(); err != nil { + fmt.Println(err) + return + } + + code := m.Run() + if code > 0 { + res = code + } } - os.Exit(m.Run()) + + os.Exit(res) } func TestPing(t *testing.T) {