From 65846bacc36a9f0f991d89998c1b399dd39e5d59 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Fri, 2 Jul 2021 11:25:20 +0800 Subject: [PATCH 01/31] Improve QueryString performance (#1962) As title. Reviewed-on: https://gitea.com/xorm/xorm/pulls/1962 Co-authored-by: Lunny Xiao Co-committed-by: Lunny Xiao --- engine.go | 12 ++--- integrations/session_find_test.go | 15 +++--- integrations/session_insert_test.go | 30 +++++------ integrations/session_query_test.go | 8 +-- scan.go | 48 ++++++++++++++++++ session_query.go | 78 ++++++----------------------- 6 files changed, 95 insertions(+), 96 deletions(-) create mode 100644 scan.go diff --git a/engine.go b/engine.go index 76ce8f1a..0eb429b1 100644 --- a/engine.go +++ b/engine.go @@ -444,7 +444,7 @@ func (engine *Engine) DumpTables(tables []*schemas.Table, w io.Writer, tp ...sch return engine.dumpTables(tables, w, tp...) } -func formatColumnValue(dstDialect dialects.Dialect, d interface{}, col *schemas.Column) string { +func formatColumnValue(dbLocation *time.Location, dstDialect dialects.Dialect, d interface{}, col *schemas.Column) string { if d == nil { return "NULL" } @@ -473,10 +473,8 @@ func formatColumnValue(dstDialect dialects.Dialect, d interface{}, col *schemas. return "'" + strings.Replace(v, "'", "''", -1) + "'" } else if col.SQLType.IsTime() { - if dstDialect.URI().DBType == schemas.MSSQL && col.SQLType.Name == schemas.DateTime { - if t, ok := d.(time.Time); ok { - return "'" + t.UTC().Format("2006-01-02 15:04:05") + "'" - } + if t, ok := d.(time.Time); ok { + return "'" + t.In(dbLocation).Format("2006-01-02 15:04:05") + "'" } var v = fmt.Sprintf("%s", d) if strings.HasSuffix(v, " +0000 UTC") { @@ -653,7 +651,7 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch } field := dataStruct.FieldByIndex(col.FieldIndex) - temp += "," + formatColumnValue(dstDialect, field.Interface(), col) + temp += "," + formatColumnValue(engine.DatabaseTZ, dstDialect, field.Interface(), col) } _, err = io.WriteString(w, temp[1:]+");\n") if err != nil { @@ -680,7 +678,7 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch return errors.New("unknow column error") } - temp += "," + formatColumnValue(dstDialect, d, col) + temp += "," + formatColumnValue(engine.DatabaseTZ, dstDialect, d, col) } _, err = io.WriteString(w, temp[1:]+");\n") if err != nil { diff --git a/integrations/session_find_test.go b/integrations/session_find_test.go index 0ea12e26..80f3b72c 100644 --- a/integrations/session_find_test.go +++ b/integrations/session_find_test.go @@ -406,16 +406,16 @@ func TestFindMapPtrString(t *testing.T) { assert.NoError(t, err) } -func TestFindBit(t *testing.T) { - type FindBitStruct struct { +func TestFindBool(t *testing.T) { + type FindBoolStruct struct { Id int64 - Msg bool `xorm:"bit"` + Msg bool } assert.NoError(t, PrepareEngine()) - assertSync(t, new(FindBitStruct)) + assertSync(t, new(FindBoolStruct)) - cnt, err := testEngine.Insert([]FindBitStruct{ + cnt, err := testEngine.Insert([]FindBoolStruct{ { Msg: false, }, @@ -426,14 +426,13 @@ func TestFindBit(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 2, cnt) - var results = make([]FindBitStruct, 0, 2) + var results = make([]FindBoolStruct, 0, 2) err = testEngine.Find(&results) assert.NoError(t, err) assert.EqualValues(t, 2, len(results)) } func TestFindMark(t *testing.T) { - type Mark struct { Mark1 string `xorm:"VARCHAR(1)"` Mark2 string `xorm:"VARCHAR(1)"` @@ -468,7 +467,7 @@ func TestFindAndCountOneFunc(t *testing.T) { type FindAndCountStruct struct { Id int64 Content string - Msg bool `xorm:"bit"` + Msg bool } assert.NoError(t, PrepareEngine()) diff --git a/integrations/session_insert_test.go b/integrations/session_insert_test.go index e5d880ae..72e9d050 100644 --- a/integrations/session_insert_test.go +++ b/integrations/session_insert_test.go @@ -168,17 +168,17 @@ func TestInsertAutoIncr(t *testing.T) { assert.Greater(t, user.Uid, int64(0)) } -type DefaultInsert struct { - Id int64 - Status int `xorm:"default -1"` - Name string - Created time.Time `xorm:"created"` - Updated time.Time `xorm:"updated"` -} - func TestInsertDefault(t *testing.T) { assert.NoError(t, PrepareEngine()) + type DefaultInsert struct { + Id int64 + Status int `xorm:"default -1"` + Name string + Created time.Time `xorm:"created"` + Updated time.Time `xorm:"updated"` + } + di := new(DefaultInsert) err := testEngine.Sync2(di) assert.NoError(t, err) @@ -195,16 +195,16 @@ func TestInsertDefault(t *testing.T) { assert.EqualValues(t, di2.Created.Unix(), di.Created.Unix()) } -type DefaultInsert2 struct { - Id int64 - Name string - Url string `xorm:"text"` - CheckTime time.Time `xorm:"not null default '2000-01-01 00:00:00' TIMESTAMP"` -} - func TestInsertDefault2(t *testing.T) { assert.NoError(t, PrepareEngine()) + type DefaultInsert2 struct { + Id int64 + Name string + Url string `xorm:"text"` + CheckTime time.Time `xorm:"not null default '2000-01-01 00:00:00' TIMESTAMP"` + } + di := new(DefaultInsert2) err := testEngine.Sync2(di) assert.NoError(t, err) diff --git a/integrations/session_query_test.go b/integrations/session_query_test.go index 30f2e6ab..5f3a0797 100644 --- a/integrations/session_query_test.go +++ b/integrations/session_query_test.go @@ -52,7 +52,7 @@ func TestQueryString2(t *testing.T) { type GetVar3 struct { Id int64 `xorm:"autoincr pk"` - Msg bool `xorm:"bit"` + Msg bool } assert.NoError(t, testEngine.Sync2(new(GetVar3))) @@ -192,7 +192,7 @@ func TestQueryStringNoParam(t *testing.T) { type GetVar4 struct { Id int64 `xorm:"autoincr pk"` - Msg bool `xorm:"bit"` + Msg bool } assert.NoError(t, testEngine.Sync2(new(GetVar4))) @@ -229,7 +229,7 @@ func TestQuerySliceStringNoParam(t *testing.T) { type GetVar6 struct { Id int64 `xorm:"autoincr pk"` - Msg bool `xorm:"bit"` + Msg bool } assert.NoError(t, testEngine.Sync2(new(GetVar6))) @@ -266,7 +266,7 @@ func TestQueryInterfaceNoParam(t *testing.T) { type GetVar5 struct { Id int64 `xorm:"autoincr pk"` - Msg bool `xorm:"bit"` + Msg bool } assert.NoError(t, testEngine.Sync2(new(GetVar5))) diff --git a/scan.go b/scan.go new file mode 100644 index 00000000..0a9ef613 --- /dev/null +++ b/scan.go @@ -0,0 +1,48 @@ +// Copyright 2021 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import ( + "database/sql" + + "xorm.io/xorm/core" +) + +func (engine *Engine) row2mapStr(rows *core.Rows, types []*sql.ColumnType, fields []string) (map[string]string, error) { + var scanResults = make([]interface{}, len(fields)) + for i := 0; i < len(fields); i++ { + var s sql.NullString + scanResults[i] = &s + } + + if err := rows.Scan(scanResults...); err != nil { + return nil, err + } + + result := make(map[string]string, len(fields)) + for ii, key := range fields { + s := scanResults[ii].(*sql.NullString) + result[key] = s.String + } + return result, nil +} + +func (engine *Engine) row2sliceStr(rows *core.Rows, types []*sql.ColumnType, fields []string) ([]string, error) { + results := make([]string, 0, len(fields)) + var scanResults = make([]interface{}, len(fields)) + for i := 0; i < len(fields); i++ { + var s sql.NullString + scanResults[i] = &s + } + + if err := rows.Scan(scanResults...); err != nil { + return nil, err + } + + for i := 0; i < len(fields); i++ { + results = append(results, scanResults[i].(*sql.NullString).String) + } + return results, nil +} diff --git a/session_query.go b/session_query.go index 12136466..379ad0e1 100644 --- a/session_query.go +++ b/session_query.go @@ -75,69 +75,18 @@ func value2String(rawValue *reflect.Value) (str string, err error) { return } -func row2mapStr(rows *core.Rows, fields []string) (resultsMap map[string]string, err error) { - result := make(map[string]string) - scanResultContainers := make([]interface{}, len(fields)) - for i := 0; i < len(fields); i++ { - var scanResultContainer interface{} - scanResultContainers[i] = &scanResultContainer - } - if err := rows.Scan(scanResultContainers...); err != nil { - return nil, err - } - - for ii, key := range fields { - rawValue := reflect.Indirect(reflect.ValueOf(scanResultContainers[ii])) - // if row is null then as empty string - if rawValue.Interface() == nil { - result[key] = "" - continue - } - - if data, err := value2String(&rawValue); err == nil { - result[key] = data - } else { - return nil, err - } - } - return result, nil -} - -func row2sliceStr(rows *core.Rows, fields []string) (results []string, err error) { - result := make([]string, 0, len(fields)) - scanResultContainers := make([]interface{}, len(fields)) - for i := 0; i < len(fields); i++ { - var scanResultContainer interface{} - scanResultContainers[i] = &scanResultContainer - } - if err := rows.Scan(scanResultContainers...); err != nil { - return nil, err - } - - for i := 0; i < len(fields); i++ { - rawValue := reflect.Indirect(reflect.ValueOf(scanResultContainers[i])) - // if row is null then as empty string - if rawValue.Interface() == nil { - result = append(result, "") - continue - } - - if data, err := value2String(&rawValue); err == nil { - result = append(result, data) - } else { - return nil, err - } - } - return result, nil -} - -func rows2Strings(rows *core.Rows) (resultsSlice []map[string]string, err error) { +func (session *Session) rows2Strings(rows *core.Rows) (resultsSlice []map[string]string, err error) { fields, err := rows.Columns() if err != nil { return nil, err } + types, err := rows.ColumnTypes() + if err != nil { + return nil, err + } + for rows.Next() { - result, err := row2mapStr(rows, fields) + result, err := session.engine.row2mapStr(rows, types, fields) if err != nil { return nil, err } @@ -147,13 +96,18 @@ func rows2Strings(rows *core.Rows) (resultsSlice []map[string]string, err error) return resultsSlice, nil } -func rows2SliceString(rows *core.Rows) (resultsSlice [][]string, err error) { +func (session *Session) rows2SliceString(rows *core.Rows) (resultsSlice [][]string, err error) { fields, err := rows.Columns() if err != nil { return nil, err } + types, err := rows.ColumnTypes() + if err != nil { + return nil, err + } + for rows.Next() { - record, err := row2sliceStr(rows, fields) + record, err := session.engine.row2sliceStr(rows, types, fields) if err != nil { return nil, err } @@ -180,7 +134,7 @@ func (session *Session) QueryString(sqlOrArgs ...interface{}) ([]map[string]stri } defer rows.Close() - return rows2Strings(rows) + return session.rows2Strings(rows) } // QuerySliceString runs a raw sql and return records as [][]string @@ -200,7 +154,7 @@ func (session *Session) QuerySliceString(sqlOrArgs ...interface{}) ([][]string, } defer rows.Close() - return rows2SliceString(rows) + return session.rows2SliceString(rows) } func row2mapInterface(rows *core.Rows, fields []string) (resultsMap map[string]interface{}, err error) { From 66fc59b71c87a1bf1ae27f0b8cb6315b1adf1d0b Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Fri, 2 Jul 2021 12:37:03 +0800 Subject: [PATCH 02/31] Query bytes based on Query string (#1964) Reviewed-on: https://gitea.com/xorm/xorm/pulls/1964 Co-authored-by: Lunny Xiao Co-committed-by: Lunny Xiao --- scan.go | 19 +++++++++++++++++++ session_raw.go | 38 +++++++------------------------------- 2 files changed, 26 insertions(+), 31 deletions(-) diff --git a/scan.go b/scan.go index 0a9ef613..e19037a0 100644 --- a/scan.go +++ b/scan.go @@ -29,6 +29,25 @@ func (engine *Engine) row2mapStr(rows *core.Rows, types []*sql.ColumnType, field return result, nil } +func (engine *Engine) row2mapBytes(rows *core.Rows, types []*sql.ColumnType, fields []string) (map[string][]byte, error) { + var scanResults = make([]interface{}, len(fields)) + for i := 0; i < len(fields); i++ { + var s sql.NullString + scanResults[i] = &s + } + + if err := rows.Scan(scanResults...); err != nil { + return nil, err + } + + result := make(map[string][]byte, len(fields)) + for ii, key := range fields { + s := scanResults[ii].(*sql.NullString) + result[key] = []byte(s.String) + } + return result, nil +} + func (engine *Engine) row2sliceStr(rows *core.Rows, types []*sql.ColumnType, fields []string) ([]string, error) { results := make([]string, 0, len(fields)) var scanResults = make([]interface{}, len(fields)) diff --git a/session_raw.go b/session_raw.go index 4cfe297a..d5c4520b 100644 --- a/session_raw.go +++ b/session_raw.go @@ -79,41 +79,17 @@ func value2Bytes(rawValue *reflect.Value) ([]byte, error) { return []byte(str), nil } -func row2map(rows *core.Rows, fields []string) (resultsMap map[string][]byte, err error) { - result := make(map[string][]byte) - scanResultContainers := make([]interface{}, len(fields)) - for i := 0; i < len(fields); i++ { - var scanResultContainer interface{} - scanResultContainers[i] = &scanResultContainer - } - if err := rows.Scan(scanResultContainers...); err != nil { - return nil, err - } - - for ii, key := range fields { - rawValue := reflect.Indirect(reflect.ValueOf(scanResultContainers[ii])) - //if row is null then ignore - if rawValue.Interface() == nil { - result[key] = []byte{} - continue - } - - if data, err := value2Bytes(&rawValue); err == nil { - result[key] = data - } else { - return nil, err // !nashtsai! REVIEW, should return err or just error log? - } - } - return result, nil -} - -func rows2maps(rows *core.Rows) (resultsSlice []map[string][]byte, err error) { +func (session *Session) rows2maps(rows *core.Rows) (resultsSlice []map[string][]byte, err error) { fields, err := rows.Columns() if err != nil { return nil, err } + types, err := rows.ColumnTypes() + if err != nil { + return nil, err + } for rows.Next() { - result, err := row2map(rows, fields) + result, err := session.engine.row2mapBytes(rows, types, fields) if err != nil { return nil, err } @@ -130,7 +106,7 @@ func (session *Session) queryBytes(sqlStr string, args ...interface{}) ([]map[st } defer rows.Close() - return rows2maps(rows) + return session.rows2maps(rows) } func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, error) { From 962962bb64af2d2f9fd3964794d425ceaa83ea9d Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Sat, 3 Jul 2021 20:26:49 +0800 Subject: [PATCH 03/31] Fix #929 (#1936) sql server doesn't accept to insert a blank datetime like `0001-01-01 00:00:00`. So that we have a break change here that deleted column should not have a notnull tag. Reviewed-on: https://gitea.com/xorm/xorm/pulls/1936 Co-authored-by: Lunny Xiao Co-committed-by: Lunny Xiao --- integrations/session_insert_test.go | 41 +++++++++++++++++++++++++++++ internal/statements/insert.go | 2 +- internal/statements/statement.go | 2 +- session_insert.go | 11 ++++++++ tags/parser.go | 6 +++++ tags/tag.go | 1 + 6 files changed, 61 insertions(+), 2 deletions(-) diff --git a/integrations/session_insert_test.go b/integrations/session_insert_test.go index 72e9d050..a023ab72 100644 --- a/integrations/session_insert_test.go +++ b/integrations/session_insert_test.go @@ -1024,3 +1024,44 @@ func TestInsertIntSlice(t *testing.T) { assert.True(t, has) assert.EqualValues(t, v3, v4) } + +func TestInsertDeleted(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type InsertDeletedStructNotRight struct { + ID uint64 `xorm:"'ID' pk autoincr"` + DeletedAt time.Time `xorm:"'DELETED_AT' deleted notnull"` + } + // notnull tag will be ignored + err := testEngine.Sync2(new(InsertDeletedStructNotRight)) + assert.NoError(t, err) + + type InsertDeletedStruct struct { + ID uint64 `xorm:"'ID' pk autoincr"` + DeletedAt time.Time `xorm:"'DELETED_AT' deleted"` + } + + assert.NoError(t, testEngine.Sync2(new(InsertDeletedStruct))) + + var v InsertDeletedStruct + _, err = testEngine.Insert(&v) + assert.NoError(t, err) + + var v2 InsertDeletedStruct + has, err := testEngine.Get(&v2) + assert.NoError(t, err) + assert.True(t, has) + + _, err = testEngine.ID(v.ID).Delete(new(InsertDeletedStruct)) + assert.NoError(t, err) + + var v3 InsertDeletedStruct + has, err = testEngine.Get(&v3) + assert.NoError(t, err) + assert.False(t, has) + + var v4 InsertDeletedStruct + has, err = testEngine.Unscoped().Get(&v4) + assert.NoError(t, err) + assert.True(t, has) +} diff --git a/internal/statements/insert.go b/internal/statements/insert.go index 367dbdc9..4e43c5bd 100644 --- a/internal/statements/insert.go +++ b/internal/statements/insert.go @@ -17,7 +17,7 @@ func (statement *Statement) writeInsertOutput(buf *strings.Builder, table *schem if _, err := buf.WriteString(" OUTPUT Inserted."); err != nil { return err } - if _, err := buf.WriteString(table.AutoIncrement); err != nil { + if err := statement.dialect.Quoter().QuoteTo(buf, table.AutoIncrement); err != nil { return err } } diff --git a/internal/statements/statement.go b/internal/statements/statement.go index b1a5ed3c..2d173b87 100644 --- a/internal/statements/statement.go +++ b/internal/statements/statement.go @@ -978,7 +978,7 @@ func (statement *Statement) joinColumns(cols []*schemas.Column, includeTableName // CondDeleted returns the conditions whether a record is soft deleted. func (statement *Statement) CondDeleted(col *schemas.Column) builder.Cond { - var colName = col.Name + var colName = statement.quote(col.Name) if statement.JoinStr != "" { var prefix string if statement.TableAlias != "" { diff --git a/session_insert.go b/session_insert.go index 82d91969..e733e06e 100644 --- a/session_insert.go +++ b/session_insert.go @@ -11,6 +11,7 @@ import ( "sort" "strconv" "strings" + "time" "xorm.io/xorm/internal/utils" "xorm.io/xorm/schemas" @@ -497,6 +498,16 @@ func (session *Session) genInsertColumns(bean interface{}) ([]string, []interfac } if col.IsDeleted { + colNames = append(colNames, col.Name) + if !col.Nullable { + if col.SQLType.IsNumeric() { + args = append(args, 0) + } else { + args = append(args, time.Time{}.Format("2006-01-02 15:04:05")) + } + } else { + args = append(args, nil) + } continue } diff --git a/tags/parser.go b/tags/parser.go index d701e316..b793a8f1 100644 --- a/tags/parser.go +++ b/tags/parser.go @@ -296,5 +296,11 @@ func (parser *Parser) Parse(v reflect.Value) (*schemas.Table, error) { table.AddColumn(col) } // end for + deletedColumn := table.DeletedColumn() + // check columns + if deletedColumn != nil { + deletedColumn.Nullable = true + } + return table, nil } diff --git a/tags/tag.go b/tags/tag.go index 4a39ba54..641b8c52 100644 --- a/tags/tag.go +++ b/tags/tag.go @@ -238,6 +238,7 @@ func UpdatedTagHandler(ctx *Context) error { // DeletedTagHandler describes deleted tag handler func DeletedTagHandler(ctx *Context) error { ctx.col.IsDeleted = true + ctx.col.Nullable = true return nil } From 60e128eb4d92027de3a99b59fb1f4c5751eac382 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Sat, 3 Jul 2021 22:45:28 +0800 Subject: [PATCH 04/31] Changelog for v1.1.1 --- CHANGELOG.md | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 13e721ec..ce5b4fe9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,36 @@ This changelog goes through all the changes that have been made in each release without substantial changes to our git log. +## [1.1.1](https://gitea.com/xorm/xorm/releases/tag/1.1.1) - 2021-07-03 + +* BUGFIXES + * Ignore comments when deciding when to replace question marks. #1954 (#1955) + * Fix bug didn't reset statement on update (#1939) + * Fix create table with struct missing columns (#1938) + * Fix #929 (#1936) + * Fix exist (#1921) +* ENHANCEMENTS + * Improve get field value of bean (#1961) + * refactor splitTag function (#1960) + * Fix #1663 (#1952) + * fix pg GetColumns missing comment (#1949) + * Support build flag jsoniter to replace default json (#1916) + * refactor exprParam (#1825) + * Add DBVersion (#1723) +* TESTING + * Add test to confirm #1247 resolved (#1951) + * Add test for dump table with default value (#1950) + * Test for #1486 (#1942) + * Add sync tests to confirm #539 is gone (#1937) + * test for unsigned int32 (#1923) + * Add tests for array store (#1922) +* BUILD + * Remove mymysql from ci (#1928) +* MISC + * fix lint (#1953) + * Compitable with cockroach (#1930) + * Replace goracle with godror (#1914) + ## [1.1.0](https://gitea.com/xorm/xorm/releases/tag/1.1.0) - 2021-05-14 * FEATURES From cbc40dfe5c36c63bd93d567e8cafc6d77d7b9dc5 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Sun, 4 Jul 2021 18:19:46 +0800 Subject: [PATCH 05/31] Add release tag (#1966) as title Reviewed-on: https://gitea.com/xorm/xorm/pulls/1966 Co-authored-by: Lunny Xiao Co-committed-by: Lunny Xiao --- .drone.yml | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/.drone.yml b/.drone.yml index 4f84d7fa..8a9f8877 100644 --- a/.drone.yml +++ b/.drone.yml @@ -347,3 +347,19 @@ steps: image: golang:1.15 commands: - make coverage + +--- +kind: pipeline +name: release-tag +trigger: + event: + - tag +steps: +- name: release-tag-gitea + pull: always + image: plugins/gitea-release:latest + settings: + base_url: https://gitea.com + title: '${DRONE_TAG} is released' + api_key: + from_secret: gitea_token \ No newline at end of file From 4f92921e43a27093034284b6cf0109c1b214a591 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Sun, 4 Jul 2021 19:04:48 +0800 Subject: [PATCH 06/31] Add changelog for v1.1.2 --- CHANGELOG.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index ce5b4fe9..cd567b27 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,11 @@ This changelog goes through all the changes that have been made in each release without substantial changes to our git log. +## [1.1.2](https://gitea.com/xorm/xorm/releases/tag/1.1.2) - 2021-07-04 + +* BUILD + * Add release tag (#1966) + ## [1.1.1](https://gitea.com/xorm/xorm/releases/tag/1.1.1) - 2021-07-03 * BUGFIXES From d0e5dba40efff87a4e58c827ba0e4276b518539d Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Sun, 4 Jul 2021 21:23:17 +0800 Subject: [PATCH 07/31] Query interface (#1965) refactor query interface Reviewed-on: https://gitea.com/xorm/xorm/pulls/1965 Co-authored-by: Lunny Xiao Co-committed-by: Lunny Xiao --- convert/interface.go | 48 +++++++++ convert/time.go | 30 ++++++ dialects/driver.go | 18 ++++ dialects/mssql.go | 25 +++++ dialects/mysql.go | 157 ++++++++++++++++++++++------- dialects/oracle.go | 23 +++++ dialects/postgres.go | 32 ++++++ dialects/sqlite3.go | 27 +++++ engine.go | 2 + integrations/session_query_test.go | 26 +++-- scan.go | 55 +++++++++- session_query.go | 84 ++------------- session_raw.go | 73 ++++++++++---- 13 files changed, 455 insertions(+), 145 deletions(-) create mode 100644 convert/interface.go create mode 100644 convert/time.go diff --git a/convert/interface.go b/convert/interface.go new file mode 100644 index 00000000..2b055253 --- /dev/null +++ b/convert/interface.go @@ -0,0 +1,48 @@ +// Copyright 2021 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package convert + +import ( + "database/sql" + "fmt" + "time" +) + +func Interface2Interface(userLocation *time.Location, v interface{}) (interface{}, error) { + if v == nil { + return nil, nil + } + switch vv := v.(type) { + case *int64: + return *vv, nil + case *int8: + return *vv, nil + case *sql.NullString: + return vv.String, nil + case *sql.RawBytes: + if len([]byte(*vv)) > 0 { + return []byte(*vv), nil + } + return nil, nil + case *sql.NullInt32: + return vv.Int32, nil + case *sql.NullInt64: + return vv.Int64, nil + case *sql.NullFloat64: + return vv.Float64, nil + case *sql.NullBool: + if vv.Valid { + return vv.Bool, nil + } + return nil, nil + case *sql.NullTime: + if vv.Valid { + return vv.Time.In(userLocation).Format("2006-01-02 15:04:05"), nil + } + return "", nil + default: + return "", fmt.Errorf("convert assign string unsupported type: %#v", vv) + } +} diff --git a/convert/time.go b/convert/time.go new file mode 100644 index 00000000..8901279b --- /dev/null +++ b/convert/time.go @@ -0,0 +1,30 @@ +// Copyright 2021 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package convert + +import ( + "fmt" + "time" +) + +// String2Time converts a string to time with original location +func String2Time(s string, originalLocation *time.Location, convertedLocation *time.Location) (*time.Time, error) { + if len(s) == 19 { + dt, err := time.ParseInLocation("2006-01-02 15:04:05", s, originalLocation) + if err != nil { + return nil, err + } + dt = dt.In(convertedLocation) + return &dt, nil + } else if len(s) == 20 && s[10] == 'T' && s[19] == 'Z' { + dt, err := time.ParseInLocation("2006-01-02T15:04:05Z", s, originalLocation) + if err != nil { + return nil, err + } + dt = dt.In(convertedLocation) + return &dt, nil + } + return nil, fmt.Errorf("unsupported convertion from %s to time", s) +} diff --git a/dialects/driver.go b/dialects/driver.go index bb46a936..c511b665 100644 --- a/dialects/driver.go +++ b/dialects/driver.go @@ -5,12 +5,24 @@ package dialects import ( + "database/sql" "fmt" + "time" + + "xorm.io/xorm/core" ) +// ScanContext represents a context when Scan +type ScanContext struct { + DBLocation *time.Location + UserLocation *time.Location +} + // Driver represents a database driver type Driver interface { Parse(string, string) (*URI, error) + GenScanResult(string) (interface{}, error) // according given column type generating a suitable scan interface + Scan(*ScanContext, *core.Rows, []*sql.ColumnType, ...interface{}) error } var ( @@ -59,3 +71,9 @@ func OpenDialect(driverName, connstr string) (Dialect, error) { return dialect, nil } + +type baseDriver struct{} + +func (b *baseDriver) Scan(ctx *ScanContext, rows *core.Rows, types []*sql.ColumnType, v ...interface{}) error { + return rows.Scan(v...) +} diff --git a/dialects/mssql.go b/dialects/mssql.go index 7e922e62..c3c15077 100644 --- a/dialects/mssql.go +++ b/dialects/mssql.go @@ -6,6 +6,7 @@ package dialects import ( "context" + "database/sql" "errors" "fmt" "net/url" @@ -624,6 +625,7 @@ func (db *mssql) Filters() []Filter { } type odbcDriver struct { + baseDriver } func (p *odbcDriver) Parse(driverName, dataSourceName string) (*URI, error) { @@ -652,3 +654,26 @@ func (p *odbcDriver) Parse(driverName, dataSourceName string) (*URI, error) { } return &URI{DBName: dbName, DBType: schemas.MSSQL}, nil } + +func (p *odbcDriver) GenScanResult(colType string) (interface{}, error) { + switch colType { + case "VARCHAR", "TEXT", "CHAR", "NVARCHAR", "NCHAR", "NTEXT": + fallthrough + case "DATE", "DATETIME", "DATETIME2", "TIME": + var s sql.NullString + return &s, nil + case "FLOAT", "REAL": + var s sql.NullFloat64 + return &s, nil + case "BIGINT", "DATETIMEOFFSET": + var s sql.NullInt64 + return &s, nil + case "TINYINT", "SMALLINT", "INT": + var s sql.NullInt32 + return &s, nil + + default: + var r sql.RawBytes + return &r, nil + } +} diff --git a/dialects/mysql.go b/dialects/mysql.go index a169b901..03bc9a4b 100644 --- a/dialects/mysql.go +++ b/dialects/mysql.go @@ -7,6 +7,7 @@ package dialects import ( "context" "crypto/tls" + "database/sql" "errors" "fmt" "regexp" @@ -14,6 +15,7 @@ import ( "strings" "time" + "xorm.io/xorm/convert" "xorm.io/xorm/core" "xorm.io/xorm/schemas" ) @@ -630,7 +632,124 @@ func (db *mysql) Filters() []Filter { return []Filter{} } +type mysqlDriver struct { +} + +func (p *mysqlDriver) Parse(driverName, dataSourceName string) (*URI, error) { + dsnPattern := regexp.MustCompile( + `^(?:(?P.*?)(?::(?P.*))?@)?` + // [user[:password]@] + `(?:(?P[^\(]*)(?:\((?P[^\)]*)\))?)?` + // [net[(addr)]] + `\/(?P.*?)` + // /dbname + `(?:\?(?P[^\?]*))?$`) // [?param1=value1¶mN=valueN] + matches := dsnPattern.FindStringSubmatch(dataSourceName) + // tlsConfigRegister := make(map[string]*tls.Config) + names := dsnPattern.SubexpNames() + + uri := &URI{DBType: schemas.MYSQL} + + for i, match := range matches { + switch names[i] { + case "dbname": + uri.DBName = match + case "params": + if len(match) > 0 { + kvs := strings.Split(match, "&") + for _, kv := range kvs { + splits := strings.Split(kv, "=") + if len(splits) == 2 { + switch splits[0] { + case "charset": + uri.Charset = splits[1] + } + } + } + } + + } + } + return uri, nil +} + +func (p *mysqlDriver) GenScanResult(colType string) (interface{}, error) { + switch colType { + case "CHAR", "VARCHAR", "TINYTEXT", "TEXT", "MEDIUMTEXT", "LONGTEXT", "ENUM", "SET": + var s sql.NullString + return &s, nil + case "BIGINT": + var s sql.NullInt64 + return &s, nil + case "TINYINT", "SMALLINT", "MEDIUMINT", "INT": + var s sql.NullInt32 + return &s, nil + case "FLOAT", "REAL", "DOUBLE PRECISION": + var s sql.NullFloat64 + return &s, nil + case "DECIMAL", "NUMERIC": + var s sql.NullString + return &s, nil + case "DATETIME": + var s sql.NullTime + return &s, nil + case "BIT": + var s sql.RawBytes + return &s, nil + case "BINARY", "VARBINARY", "TINYBLOB", "BLOB", "MEDIUMBLOB", "LONGBLOB": + var r sql.RawBytes + return &r, nil + default: + var r sql.RawBytes + return &r, nil + } +} + +func (p *mysqlDriver) Scan(ctx *ScanContext, rows *core.Rows, types []*sql.ColumnType, scanResults ...interface{}) error { + var v2 = make([]interface{}, 0, len(scanResults)) + var turnBackIdxes = make([]int, 0, 5) + for i, vv := range scanResults { + switch vv.(type) { + case *time.Time: + v2 = append(v2, &sql.NullString{}) + turnBackIdxes = append(turnBackIdxes, i) + case *sql.NullTime: + v2 = append(v2, &sql.NullString{}) + turnBackIdxes = append(turnBackIdxes, i) + default: + v2 = append(v2, scanResults[i]) + } + } + if err := rows.Scan(v2...); err != nil { + return err + } + for _, i := range turnBackIdxes { + switch t := scanResults[i].(type) { + case *time.Time: + var s = *(v2[i].(*sql.NullString)) + if !s.Valid { + break + } + dt, err := convert.String2Time(s.String, ctx.DBLocation, ctx.UserLocation) + if err != nil { + return err + } + *t = *dt + case *sql.NullTime: + var s = *(v2[i].(*sql.NullString)) + if !s.Valid { + break + } + dt, err := convert.String2Time(s.String, ctx.DBLocation, ctx.UserLocation) + if err != nil { + return err + } + t.Time = *dt + t.Valid = true + } + } + return nil +} + type mymysqlDriver struct { + mysqlDriver } func (p *mymysqlDriver) Parse(driverName, dataSourceName string) (*URI, error) { @@ -681,41 +800,3 @@ func (p *mymysqlDriver) Parse(driverName, dataSourceName string) (*URI, error) { return uri, nil } - -type mysqlDriver struct { -} - -func (p *mysqlDriver) Parse(driverName, dataSourceName string) (*URI, error) { - dsnPattern := regexp.MustCompile( - `^(?:(?P.*?)(?::(?P.*))?@)?` + // [user[:password]@] - `(?:(?P[^\(]*)(?:\((?P[^\)]*)\))?)?` + // [net[(addr)]] - `\/(?P.*?)` + // /dbname - `(?:\?(?P[^\?]*))?$`) // [?param1=value1¶mN=valueN] - matches := dsnPattern.FindStringSubmatch(dataSourceName) - // tlsConfigRegister := make(map[string]*tls.Config) - names := dsnPattern.SubexpNames() - - uri := &URI{DBType: schemas.MYSQL} - - for i, match := range matches { - switch names[i] { - case "dbname": - uri.DBName = match - case "params": - if len(match) > 0 { - kvs := strings.Split(match, "&") - for _, kv := range kvs { - splits := strings.Split(kv, "=") - if len(splits) == 2 { - switch splits[0] { - case "charset": - uri.Charset = splits[1] - } - } - } - } - - } - } - return uri, nil -} diff --git a/dialects/oracle.go b/dialects/oracle.go index 0b06c4c6..7043972b 100644 --- a/dialects/oracle.go +++ b/dialects/oracle.go @@ -6,6 +6,7 @@ package dialects import ( "context" + "database/sql" "errors" "fmt" "regexp" @@ -823,6 +824,7 @@ func (db *oracle) Filters() []Filter { } type godrorDriver struct { + baseDriver } func (cfg *godrorDriver) Parse(driverName, dataSourceName string) (*URI, error) { @@ -848,7 +850,28 @@ func (cfg *godrorDriver) Parse(driverName, dataSourceName string) (*URI, error) return db, nil } +func (p *godrorDriver) GenScanResult(colType string) (interface{}, error) { + switch colType { + case "CHAR", "NCHAR", "VARCHAR", "VARCHAR2", "NVARCHAR2", "LONG", "CLOB", "NCLOB": + var s sql.NullString + return &s, nil + case "NUMBER": + var s sql.NullString + return &s, nil + case "DATE": + var s sql.NullTime + return &s, nil + case "BLOB": + var r sql.RawBytes + return &r, nil + default: + var r sql.RawBytes + return &r, nil + } +} + type oci8Driver struct { + godrorDriver } // dataSourceName=user/password@ipv4:port/dbname diff --git a/dialects/postgres.go b/dialects/postgres.go index 9acf763a..e4641509 100644 --- a/dialects/postgres.go +++ b/dialects/postgres.go @@ -6,6 +6,7 @@ package dialects import ( "context" + "database/sql" "errors" "fmt" "net/url" @@ -1298,6 +1299,7 @@ func (db *postgres) Filters() []Filter { } type pqDriver struct { + baseDriver } type values map[string]string @@ -1374,6 +1376,36 @@ func (p *pqDriver) Parse(driverName, dataSourceName string) (*URI, error) { return db, nil } +func (p *pqDriver) GenScanResult(colType string) (interface{}, error) { + switch colType { + case "VARCHAR", "TEXT": + var s sql.NullString + return &s, nil + case "BIGINT": + var s sql.NullInt64 + return &s, nil + case "TINYINT", "INT", "INT8", "INT4": + var s sql.NullInt32 + return &s, nil + case "FLOAT", "FLOAT4": + var s sql.NullFloat64 + return &s, nil + case "DATETIME", "TIMESTAMP": + var s sql.NullTime + return &s, nil + case "BIT": + var s sql.RawBytes + return &s, nil + case "BOOL": + var s sql.NullBool + return &s, nil + default: + fmt.Printf("unknow postgres database type: %v\n", colType) + var r sql.RawBytes + return &r, nil + } +} + type pqDriverPgx struct { pqDriver } diff --git a/dialects/sqlite3.go b/dialects/sqlite3.go index a42aad48..306f377c 100644 --- a/dialects/sqlite3.go +++ b/dialects/sqlite3.go @@ -540,6 +540,7 @@ func (db *sqlite3) Filters() []Filter { } type sqlite3Driver struct { + baseDriver } func (p *sqlite3Driver) Parse(driverName, dataSourceName string) (*URI, error) { @@ -549,3 +550,29 @@ func (p *sqlite3Driver) Parse(driverName, dataSourceName string) (*URI, error) { return &URI{DBType: schemas.SQLITE, DBName: dataSourceName}, nil } + +func (p *sqlite3Driver) GenScanResult(colType string) (interface{}, error) { + switch colType { + case "TEXT": + var s sql.NullString + return &s, nil + case "INTEGER": + var s sql.NullInt64 + return &s, nil + case "DATETIME": + var s sql.NullTime + return &s, nil + case "REAL": + var s sql.NullFloat64 + return &s, nil + case "NUMERIC": + var s sql.NullString + return &s, nil + case "BLOB": + var s sql.RawBytes + return &s, nil + default: + var r sql.NullString + return &r, nil + } +} diff --git a/engine.go b/engine.go index 0eb429b1..1064e8e1 100644 --- a/engine.go +++ b/engine.go @@ -35,6 +35,7 @@ type Engine struct { cacherMgr *caches.Manager defaultContext context.Context dialect dialects.Dialect + driver dialects.Driver engineGroup *EngineGroup logger log.ContextLogger tagParser *tags.Parser @@ -72,6 +73,7 @@ func newEngine(driverName, dataSourceName string, dialect dialects.Dialect, db * engine := &Engine{ dialect: dialect, + driver: dialects.QueryDriver(driverName), TZLocation: time.Local, defaultContext: context.Background(), cacherMgr: cacherMgr, diff --git a/integrations/session_query_test.go b/integrations/session_query_test.go index 5f3a0797..ed03ff3e 100644 --- a/integrations/session_query_test.go +++ b/integrations/session_query_test.go @@ -107,6 +107,16 @@ func toFloat64(i interface{}) float64 { return 0 } +func toBool(i interface{}) bool { + switch t := i.(type) { + case int32: + return t > 0 + case bool: + return t + } + return false +} + func TestQueryInterface(t *testing.T) { assert.NoError(t, PrepareEngine()) @@ -132,10 +142,10 @@ func TestQueryInterface(t *testing.T) { assert.NoError(t, err) assert.Equal(t, 1, len(records)) assert.Equal(t, 5, len(records[0])) - assert.EqualValues(t, 1, toInt64(records[0]["id"])) - assert.Equal(t, "hi", toString(records[0]["msg"])) - assert.EqualValues(t, 28, toInt64(records[0]["age"])) - assert.EqualValues(t, 1.5, toFloat64(records[0]["money"])) + assert.EqualValues(t, int64(1), records[0]["id"]) + assert.Equal(t, "hi", records[0]["msg"]) + assert.EqualValues(t, 28, records[0]["age"]) + assert.EqualValues(t, 1.5, records[0]["money"]) } func TestQueryNoParams(t *testing.T) { @@ -280,14 +290,14 @@ func TestQueryInterfaceNoParam(t *testing.T) { records, err := testEngine.Table("get_var5").Limit(1).QueryInterface() assert.NoError(t, err) assert.EqualValues(t, 1, len(records)) - assert.EqualValues(t, 1, toInt64(records[0]["id"])) - assert.EqualValues(t, 0, toInt64(records[0]["msg"])) + assert.EqualValues(t, 1, records[0]["id"]) + assert.False(t, toBool(records[0]["msg"])) records, err = testEngine.Table("get_var5").Where(builder.Eq{"id": 1}).QueryInterface() assert.NoError(t, err) assert.EqualValues(t, 1, len(records)) - assert.EqualValues(t, 1, toInt64(records[0]["id"])) - assert.EqualValues(t, 0, toInt64(records[0]["msg"])) + assert.EqualValues(t, 1, records[0]["id"]) + assert.False(t, toBool(records[0]["msg"])) } func TestQueryWithBuilder(t *testing.T) { diff --git a/scan.go b/scan.go index e19037a0..e11d6e8d 100644 --- a/scan.go +++ b/scan.go @@ -7,10 +7,12 @@ package xorm import ( "database/sql" + "xorm.io/xorm/convert" "xorm.io/xorm/core" + "xorm.io/xorm/dialects" ) -func (engine *Engine) row2mapStr(rows *core.Rows, types []*sql.ColumnType, fields []string) (map[string]string, error) { +func row2mapStr(rows *core.Rows, types []*sql.ColumnType, fields []string) (map[string]string, error) { var scanResults = make([]interface{}, len(fields)) for i := 0; i < len(fields); i++ { var s sql.NullString @@ -29,7 +31,7 @@ func (engine *Engine) row2mapStr(rows *core.Rows, types []*sql.ColumnType, field return result, nil } -func (engine *Engine) row2mapBytes(rows *core.Rows, types []*sql.ColumnType, fields []string) (map[string][]byte, error) { +func row2mapBytes(rows *core.Rows, types []*sql.ColumnType, fields []string) (map[string][]byte, error) { var scanResults = make([]interface{}, len(fields)) for i := 0; i < len(fields); i++ { var s sql.NullString @@ -48,7 +50,7 @@ func (engine *Engine) row2mapBytes(rows *core.Rows, types []*sql.ColumnType, fie return result, nil } -func (engine *Engine) row2sliceStr(rows *core.Rows, types []*sql.ColumnType, fields []string) ([]string, error) { +func row2sliceStr(rows *core.Rows, types []*sql.ColumnType, fields []string) ([]string, error) { results := make([]string, 0, len(fields)) var scanResults = make([]interface{}, len(fields)) for i := 0; i < len(fields); i++ { @@ -65,3 +67,50 @@ func (engine *Engine) row2sliceStr(rows *core.Rows, types []*sql.ColumnType, fie } return results, nil } + +func rows2maps(rows *core.Rows) (resultsSlice []map[string][]byte, err error) { + fields, err := rows.Columns() + if err != nil { + return nil, err + } + types, err := rows.ColumnTypes() + if err != nil { + return nil, err + } + for rows.Next() { + result, err := row2mapBytes(rows, types, fields) + if err != nil { + return nil, err + } + resultsSlice = append(resultsSlice, result) + } + + return resultsSlice, nil +} + +func (engine *Engine) row2mapInterface(rows *core.Rows, types []*sql.ColumnType, fields []string) (map[string]interface{}, error) { + var resultsMap = make(map[string]interface{}, len(fields)) + var scanResultContainers = make([]interface{}, len(fields)) + for i := 0; i < len(fields); i++ { + scanResult, err := engine.driver.GenScanResult(types[i].DatabaseTypeName()) + if err != nil { + return nil, err + } + scanResultContainers[i] = scanResult + } + if err := engine.driver.Scan(&dialects.ScanContext{ + DBLocation: engine.DatabaseTZ, + UserLocation: engine.TZLocation, + }, rows, types, scanResultContainers...); err != nil { + return nil, err + } + + for ii, key := range fields { + res, err := convert.Interface2Interface(engine.TZLocation, scanResultContainers[ii]) + if err != nil { + return nil, err + } + resultsMap[key] = res + } + return resultsMap, nil +} diff --git a/session_query.go b/session_query.go index 379ad0e1..01cd6f44 100644 --- a/session_query.go +++ b/session_query.go @@ -5,13 +5,7 @@ package xorm import ( - "fmt" - "reflect" - "strconv" - "time" - "xorm.io/xorm/core" - "xorm.io/xorm/schemas" ) // Query runs a raw sql and return records as []map[string][]byte @@ -28,53 +22,6 @@ func (session *Session) Query(sqlOrArgs ...interface{}) ([]map[string][]byte, er return session.queryBytes(sqlStr, args...) } -func value2String(rawValue *reflect.Value) (str string, err error) { - aa := reflect.TypeOf((*rawValue).Interface()) - vv := reflect.ValueOf((*rawValue).Interface()) - switch aa.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - str = strconv.FormatInt(vv.Int(), 10) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - str = strconv.FormatUint(vv.Uint(), 10) - case reflect.Float32, reflect.Float64: - str = strconv.FormatFloat(vv.Float(), 'f', -1, 64) - case reflect.String: - str = vv.String() - case reflect.Array, reflect.Slice: - switch aa.Elem().Kind() { - case reflect.Uint8: - data := rawValue.Interface().([]byte) - str = string(data) - if str == "\x00" { - str = "0" - } - default: - err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name()) - } - // time type - case reflect.Struct: - if aa.ConvertibleTo(schemas.TimeType) { - str = vv.Convert(schemas.TimeType).Interface().(time.Time).Format(time.RFC3339Nano) - } else { - err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name()) - } - case reflect.Bool: - str = strconv.FormatBool(vv.Bool()) - case reflect.Complex128, reflect.Complex64: - str = fmt.Sprintf("%v", vv.Complex()) - /* TODO: unsupported types below - case reflect.Map: - case reflect.Ptr: - case reflect.Uintptr: - case reflect.UnsafePointer: - case reflect.Chan, reflect.Func, reflect.Interface: - */ - default: - err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name()) - } - return -} - func (session *Session) rows2Strings(rows *core.Rows) (resultsSlice []map[string]string, err error) { fields, err := rows.Columns() if err != nil { @@ -86,7 +33,7 @@ func (session *Session) rows2Strings(rows *core.Rows) (resultsSlice []map[string } for rows.Next() { - result, err := session.engine.row2mapStr(rows, types, fields) + result, err := row2mapStr(rows, types, fields) if err != nil { return nil, err } @@ -107,7 +54,7 @@ func (session *Session) rows2SliceString(rows *core.Rows) (resultsSlice [][]stri } for rows.Next() { - record, err := session.engine.row2sliceStr(rows, types, fields) + record, err := row2sliceStr(rows, types, fields) if err != nil { return nil, err } @@ -157,30 +104,17 @@ func (session *Session) QuerySliceString(sqlOrArgs ...interface{}) ([][]string, return session.rows2SliceString(rows) } -func row2mapInterface(rows *core.Rows, fields []string) (resultsMap map[string]interface{}, err error) { - resultsMap = make(map[string]interface{}, len(fields)) - scanResultContainers := make([]interface{}, len(fields)) - for i := 0; i < len(fields); i++ { - var scanResultContainer interface{} - scanResultContainers[i] = &scanResultContainer - } - if err := rows.Scan(scanResultContainers...); err != nil { - return nil, err - } - - for ii, key := range fields { - resultsMap[key] = reflect.Indirect(reflect.ValueOf(scanResultContainers[ii])).Interface() - } - return -} - -func rows2Interfaces(rows *core.Rows) (resultsSlice []map[string]interface{}, err error) { +func (session *Session) rows2Interfaces(rows *core.Rows) (resultsSlice []map[string]interface{}, err error) { fields, err := rows.Columns() if err != nil { return nil, err } + types, err := rows.ColumnTypes() + if err != nil { + return nil, err + } for rows.Next() { - result, err := row2mapInterface(rows, fields) + result, err := session.engine.row2mapInterface(rows, types, fields) if err != nil { return nil, err } @@ -207,5 +141,5 @@ func (session *Session) QueryInterface(sqlOrArgs ...interface{}) ([]map[string]i } defer rows.Close() - return rows2Interfaces(rows) + return session.rows2Interfaces(rows) } diff --git a/session_raw.go b/session_raw.go index d5c4520b..bf32c6ed 100644 --- a/session_raw.go +++ b/session_raw.go @@ -6,9 +6,13 @@ package xorm import ( "database/sql" + "fmt" "reflect" + "strconv" + "time" "xorm.io/xorm/core" + "xorm.io/xorm/schemas" ) func (session *Session) queryPreprocess(sqlStr *string, paramStr ...interface{}) { @@ -71,6 +75,53 @@ func (session *Session) queryRow(sqlStr string, args ...interface{}) *core.Row { return core.NewRow(session.queryRows(sqlStr, args...)) } +func value2String(rawValue *reflect.Value) (str string, err error) { + aa := reflect.TypeOf((*rawValue).Interface()) + vv := reflect.ValueOf((*rawValue).Interface()) + switch aa.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + str = strconv.FormatInt(vv.Int(), 10) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + str = strconv.FormatUint(vv.Uint(), 10) + case reflect.Float32, reflect.Float64: + str = strconv.FormatFloat(vv.Float(), 'f', -1, 64) + case reflect.String: + str = vv.String() + case reflect.Array, reflect.Slice: + switch aa.Elem().Kind() { + case reflect.Uint8: + data := rawValue.Interface().([]byte) + str = string(data) + if str == "\x00" { + str = "0" + } + default: + err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name()) + } + // time type + case reflect.Struct: + if aa.ConvertibleTo(schemas.TimeType) { + str = vv.Convert(schemas.TimeType).Interface().(time.Time).Format(time.RFC3339Nano) + } else { + err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name()) + } + case reflect.Bool: + str = strconv.FormatBool(vv.Bool()) + case reflect.Complex128, reflect.Complex64: + str = fmt.Sprintf("%v", vv.Complex()) + /* TODO: unsupported types below + case reflect.Map: + case reflect.Ptr: + case reflect.Uintptr: + case reflect.UnsafePointer: + case reflect.Chan, reflect.Func, reflect.Interface: + */ + default: + err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name()) + } + return +} + func value2Bytes(rawValue *reflect.Value) ([]byte, error) { str, err := value2String(rawValue) if err != nil { @@ -79,26 +130,6 @@ func value2Bytes(rawValue *reflect.Value) ([]byte, error) { return []byte(str), nil } -func (session *Session) rows2maps(rows *core.Rows) (resultsSlice []map[string][]byte, err error) { - fields, err := rows.Columns() - if err != nil { - return nil, err - } - types, err := rows.ColumnTypes() - if err != nil { - return nil, err - } - for rows.Next() { - result, err := session.engine.row2mapBytes(rows, types, fields) - if err != nil { - return nil, err - } - resultsSlice = append(resultsSlice, result) - } - - return resultsSlice, nil -} - func (session *Session) queryBytes(sqlStr string, args ...interface{}) ([]map[string][]byte, error) { rows, err := session.queryRows(sqlStr, args...) if err != nil { @@ -106,7 +137,7 @@ func (session *Session) queryBytes(sqlStr string, args ...interface{}) ([]map[st } defer rows.Close() - return session.rows2maps(rows) + return rows2maps(rows) } func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, error) { From a5030dc7a444c098b8c97d514f06df90ed6b57f8 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Tue, 6 Jul 2021 16:06:04 +0800 Subject: [PATCH 08/31] refactor get (#1967) Reviewed-on: https://gitea.com/xorm/xorm/pulls/1967 Co-authored-by: Lunny Xiao Co-committed-by: Lunny Xiao --- convert.go | 416 +++++++++++++++++++++++++++++++++++++++++-- dialects/driver.go | 11 ++ dialects/mysql.go | 1 + dialects/postgres.go | 6 + dialects/sqlite3.go | 6 + scan.go | 197 +++++++++++++++++++- session_find.go | 2 +- session_get.go | 299 ++++++++++++++++--------------- session_insert.go | 6 +- session_query.go | 2 +- 10 files changed, 784 insertions(+), 162 deletions(-) diff --git a/convert.go b/convert.go index b7f30cad..f7d733ad 100644 --- a/convert.go +++ b/convert.go @@ -5,12 +5,15 @@ package xorm import ( + "database/sql" "database/sql/driver" "errors" "fmt" "reflect" "strconv" "time" + + "xorm.io/xorm/convert" ) var errNilPtr = errors.New("destination pointer is nil") // embedded in descriptive error @@ -37,6 +40,12 @@ func asString(src interface{}) string { return v case []byte: return string(v) + case *sql.NullString: + return v.String + case *sql.NullInt32: + return fmt.Sprintf("%d", v.Int32) + case *sql.NullInt64: + return fmt.Sprintf("%d", v.Int64) } rv := reflect.ValueOf(src) switch rv.Kind() { @@ -54,6 +63,156 @@ func asString(src interface{}) string { return fmt.Sprintf("%v", src) } +func asInt64(src interface{}) (int64, error) { + switch v := src.(type) { + case int: + return int64(v), nil + case int16: + return int64(v), nil + case int32: + return int64(v), nil + case int8: + return int64(v), nil + case int64: + return v, nil + case uint: + return int64(v), nil + case uint8: + return int64(v), nil + case uint16: + return int64(v), nil + case uint32: + return int64(v), nil + case uint64: + return int64(v), nil + case []byte: + return strconv.ParseInt(string(v), 10, 64) + case string: + return strconv.ParseInt(v, 10, 64) + case *sql.NullString: + return strconv.ParseInt(v.String, 10, 64) + case *sql.NullInt32: + return int64(v.Int32), nil + case *sql.NullInt64: + return int64(v.Int64), nil + } + + rv := reflect.ValueOf(src) + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return rv.Int(), nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return int64(rv.Uint()), nil + case reflect.Float64: + return int64(rv.Float()), nil + case reflect.Float32: + return int64(rv.Float()), nil + case reflect.String: + return strconv.ParseInt(rv.String(), 10, 64) + } + return 0, fmt.Errorf("unsupported value %T as int64", src) +} + +func asUint64(src interface{}) (uint64, error) { + switch v := src.(type) { + case int: + return uint64(v), nil + case int16: + return uint64(v), nil + case int32: + return uint64(v), nil + case int8: + return uint64(v), nil + case int64: + return uint64(v), nil + case uint: + return uint64(v), nil + case uint8: + return uint64(v), nil + case uint16: + return uint64(v), nil + case uint32: + return uint64(v), nil + case uint64: + return v, nil + case []byte: + return strconv.ParseUint(string(v), 10, 64) + case string: + return strconv.ParseUint(v, 10, 64) + case *sql.NullString: + return strconv.ParseUint(v.String, 10, 64) + case *sql.NullInt32: + return uint64(v.Int32), nil + case *sql.NullInt64: + return uint64(v.Int64), nil + } + + rv := reflect.ValueOf(src) + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return uint64(rv.Int()), nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return uint64(rv.Uint()), nil + case reflect.Float64: + return uint64(rv.Float()), nil + case reflect.Float32: + return uint64(rv.Float()), nil + case reflect.String: + return strconv.ParseUint(rv.String(), 10, 64) + } + return 0, fmt.Errorf("unsupported value %T as uint64", src) +} + +func asFloat64(src interface{}) (float64, error) { + switch v := src.(type) { + case int: + return float64(v), nil + case int16: + return float64(v), nil + case int32: + return float64(v), nil + case int8: + return float64(v), nil + case int64: + return float64(v), nil + case uint: + return float64(v), nil + case uint8: + return float64(v), nil + case uint16: + return float64(v), nil + case uint32: + return float64(v), nil + case uint64: + return float64(v), nil + case []byte: + return strconv.ParseFloat(string(v), 64) + case string: + return strconv.ParseFloat(v, 64) + case *sql.NullString: + return strconv.ParseFloat(v.String, 64) + case *sql.NullInt32: + return float64(v.Int32), nil + case *sql.NullInt64: + return float64(v.Int64), nil + } + + rv := reflect.ValueOf(src) + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return float64(rv.Int()), nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return float64(rv.Uint()), nil + case reflect.Float64: + return float64(rv.Float()), nil + case reflect.Float32: + return float64(rv.Float()), nil + case reflect.String: + return strconv.ParseFloat(rv.String(), 64) + } + return 0, fmt.Errorf("unsupported value %T as int64", src) +} + func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) { switch rv.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: @@ -76,7 +235,7 @@ func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) { // convertAssign copies to dest the value in src, converting it if possible. // An error is returned if the copy would result in loss of information. // dest should be a pointer type. -func convertAssign(dest, src interface{}) error { +func convertAssign(dest, src interface{}, originalLocation *time.Location, convertedLocation *time.Location) error { // Common cases, without reflect. switch s := src.(type) { case string: @@ -143,6 +302,163 @@ func convertAssign(dest, src interface{}) error { *d = nil return nil } + case *sql.NullString: + switch d := dest.(type) { + case *int: + if s.Valid { + *d, _ = strconv.Atoi(s.String) + } + case *int64: + if s.Valid { + *d, _ = strconv.ParseInt(s.String, 10, 64) + } + case *string: + if s.Valid { + *d = s.String + } + return nil + case *time.Time: + if s.Valid { + var err error + dt, err := convert.String2Time(s.String, originalLocation, convertedLocation) + if err != nil { + return err + } + *d = *dt + } + return nil + case *sql.NullTime: + if s.Valid { + var err error + dt, err := convert.String2Time(s.String, originalLocation, convertedLocation) + if err != nil { + return err + } + d.Valid = true + d.Time = *dt + } + } + case *sql.NullInt32: + switch d := dest.(type) { + case *int: + if s.Valid { + *d = int(s.Int32) + } + return nil + case *int8: + if s.Valid { + *d = int8(s.Int32) + } + return nil + case *int16: + if s.Valid { + *d = int16(s.Int32) + } + return nil + case *int32: + if s.Valid { + *d = s.Int32 + } + return nil + case *int64: + if s.Valid { + *d = int64(s.Int32) + } + return nil + } + case *sql.NullInt64: + switch d := dest.(type) { + case *int: + if s.Valid { + *d = int(s.Int64) + } + return nil + case *int8: + if s.Valid { + *d = int8(s.Int64) + } + return nil + case *int16: + if s.Valid { + *d = int16(s.Int64) + } + return nil + case *int32: + if s.Valid { + *d = int32(s.Int64) + } + return nil + case *int64: + if s.Valid { + *d = s.Int64 + } + return nil + } + case *sql.NullFloat64: + switch d := dest.(type) { + case *int: + if s.Valid { + *d = int(s.Float64) + } + return nil + case *float64: + if s.Valid { + *d = s.Float64 + } + return nil + } + case *sql.NullBool: + switch d := dest.(type) { + case *bool: + if s.Valid { + *d = s.Bool + } + return nil + } + case *sql.NullTime: + switch d := dest.(type) { + case *time.Time: + if s.Valid { + *d = s.Time + } + return nil + case *string: + if s.Valid { + *d = s.Time.In(convertedLocation).Format("2006-01-02 15:04:05") + } + return nil + } + case *NullUint32: + switch d := dest.(type) { + case *uint8: + if s.Valid { + *d = uint8(s.Uint32) + } + return nil + case *uint16: + if s.Valid { + *d = uint16(s.Uint32) + } + return nil + case *uint: + if s.Valid { + *d = uint(s.Uint32) + } + return nil + } + case *NullUint64: + switch d := dest.(type) { + case *uint64: + if s.Valid { + *d = s.Uint64 + } + return nil + } + case *sql.RawBytes: + switch d := dest.(type) { + case convert.Conversion: + return d.FromDB(*s) + } } var sv reflect.Value @@ -175,10 +491,10 @@ func convertAssign(dest, src interface{}) error { return nil } - return convertAssignV(reflect.ValueOf(dest), src) + return convertAssignV(reflect.ValueOf(dest), src, originalLocation, convertedLocation) } -func convertAssignV(dpv reflect.Value, src interface{}) error { +func convertAssignV(dpv reflect.Value, src interface{}, originalLocation, convertedLocation *time.Location) error { if dpv.Kind() != reflect.Ptr { return errors.New("destination not a pointer") } @@ -212,31 +528,28 @@ func convertAssignV(dpv reflect.Value, src interface{}) error { } dv.Set(reflect.New(dv.Type().Elem())) - return convertAssign(dv.Interface(), src) + return convertAssign(dv.Interface(), src, originalLocation, convertedLocation) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - s := asString(src) - i64, err := strconv.ParseInt(s, 10, dv.Type().Bits()) + i64, err := asInt64(src) if err != nil { err = strconvErr(err) - return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) + return fmt.Errorf("converting driver.Value type %T to a %s: %v", src, dv.Kind(), err) } dv.SetInt(i64) return nil case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - s := asString(src) - u64, err := strconv.ParseUint(s, 10, dv.Type().Bits()) + u64, err := asUint64(src) if err != nil { err = strconvErr(err) - return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) + return fmt.Errorf("converting driver.Value type %T to a %s: %v", src, dv.Kind(), err) } dv.SetUint(u64) return nil case reflect.Float32, reflect.Float64: - s := asString(src) - f64, err := strconv.ParseFloat(s, dv.Type().Bits()) + f64, err := asFloat64(src) if err != nil { err = strconvErr(err) - return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) + return fmt.Errorf("converting driver.Value type %T to a %s: %v", src, dv.Kind(), err) } dv.SetFloat(f64) return nil @@ -376,3 +689,80 @@ func str2PK(s string, tp reflect.Type) (interface{}, error) { } return v.Interface(), nil } + +var ( + _ sql.Scanner = &NullUint64{} +) + +// NullUint64 represents an uint64 that may be null. +// NullUint64 implements the Scanner interface so +// it can be used as a scan destination, similar to NullString. +type NullUint64 struct { + Uint64 uint64 + Valid bool +} + +// Scan implements the Scanner interface. +func (n *NullUint64) Scan(value interface{}) error { + if value == nil { + n.Uint64, n.Valid = 0, false + return nil + } + n.Valid = true + var err error + n.Uint64, err = asUint64(value) + return err +} + +// Value implements the driver Valuer interface. +func (n NullUint64) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return n.Uint64, nil +} + +var ( + _ sql.Scanner = &NullUint32{} +) + +// NullUint32 represents an uint32 that may be null. +// NullUint32 implements the Scanner interface so +// it can be used as a scan destination, similar to NullString. +type NullUint32 struct { + Uint32 uint32 + Valid bool // Valid is true if Uint32 is not NULL +} + +// Scan implements the Scanner interface. +func (n *NullUint32) Scan(value interface{}) error { + if value == nil { + n.Uint32, n.Valid = 0, false + return nil + } + n.Valid = true + i64, err := asUint64(value) + if err != nil { + return err + } + n.Uint32 = uint32(i64) + return nil +} + +// Value implements the driver Valuer interface. +func (n NullUint32) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return int64(n.Uint32), nil +} + +var ( + _ sql.Scanner = &EmptyScanner{} +) + +type EmptyScanner struct{} + +func (EmptyScanner) Scan(value interface{}) error { + return nil +} diff --git a/dialects/driver.go b/dialects/driver.go index c511b665..0b6187d3 100644 --- a/dialects/driver.go +++ b/dialects/driver.go @@ -18,9 +18,14 @@ type ScanContext struct { UserLocation *time.Location } +type DriverFeatures struct { + SupportNullable bool +} + // Driver represents a database driver type Driver interface { Parse(string, string) (*URI, error) + Features() DriverFeatures GenScanResult(string) (interface{}, error) // according given column type generating a suitable scan interface Scan(*ScanContext, *core.Rows, []*sql.ColumnType, ...interface{}) error } @@ -77,3 +82,9 @@ type baseDriver struct{} func (b *baseDriver) Scan(ctx *ScanContext, rows *core.Rows, types []*sql.ColumnType, v ...interface{}) error { return rows.Scan(v...) } + +func (b *baseDriver) Features() DriverFeatures { + return DriverFeatures{ + SupportNullable: true, + } +} diff --git a/dialects/mysql.go b/dialects/mysql.go index 03bc9a4b..a341ce05 100644 --- a/dialects/mysql.go +++ b/dialects/mysql.go @@ -633,6 +633,7 @@ func (db *mysql) Filters() []Filter { } type mysqlDriver struct { + baseDriver } func (p *mysqlDriver) Parse(driverName, dataSourceName string) (*URI, error) { diff --git a/dialects/postgres.go b/dialects/postgres.go index e4641509..a2611c60 100644 --- a/dialects/postgres.go +++ b/dialects/postgres.go @@ -1302,6 +1302,12 @@ type pqDriver struct { baseDriver } +func (b *pqDriver) Features() DriverFeatures { + return DriverFeatures{ + SupportNullable: false, + } +} + type values map[string]string func (vs values) Set(k, v string) { diff --git a/dialects/sqlite3.go b/dialects/sqlite3.go index 306f377c..1bc0b218 100644 --- a/dialects/sqlite3.go +++ b/dialects/sqlite3.go @@ -576,3 +576,9 @@ func (p *sqlite3Driver) GenScanResult(colType string) (interface{}, error) { return &r, nil } } + +func (b *sqlite3Driver) Features() DriverFeatures { + return DriverFeatures{ + SupportNullable: false, + } +} diff --git a/scan.go b/scan.go index e11d6e8d..c5cb77ff 100644 --- a/scan.go +++ b/scan.go @@ -6,12 +6,120 @@ package xorm import ( "database/sql" + "fmt" + "reflect" + "time" "xorm.io/xorm/convert" "xorm.io/xorm/core" "xorm.io/xorm/dialects" ) +// genScanResultsByBeanNullabale generates scan result +func genScanResultsByBeanNullable(bean interface{}) (interface{}, bool, error) { + switch t := bean.(type) { + case *sql.NullInt64, *sql.NullBool, *sql.NullFloat64, *sql.NullString, *sql.RawBytes: + return t, false, nil + case *time.Time: + return &sql.NullTime{}, true, nil + case *string: + return &sql.NullString{}, true, nil + case *int, *int8, *int16, *int32: + return &sql.NullInt32{}, true, nil + case *int64: + return &sql.NullInt64{}, true, nil + case *uint, *uint8, *uint16, *uint32: + return &NullUint32{}, true, nil + case *uint64: + return &NullUint64{}, true, nil + case *float32, *float64: + return &sql.NullFloat64{}, true, nil + case *bool: + return &sql.NullBool{}, true, nil + case sql.NullInt64, sql.NullBool, sql.NullFloat64, sql.NullString, + time.Time, + string, + int, int8, int16, int32, int64, + uint, uint8, uint16, uint32, uint64, + float32, float64, + bool: + return nil, false, fmt.Errorf("unsupported scan type: %t", t) + case convert.Conversion: + return &sql.RawBytes{}, true, nil + } + + tp := reflect.TypeOf(bean).Elem() + switch tp.Kind() { + case reflect.String: + return &sql.NullString{}, true, nil + case reflect.Int64: + return &sql.NullInt64{}, true, nil + case reflect.Int32, reflect.Int, reflect.Int16, reflect.Int8: + return &sql.NullInt32{}, true, nil + case reflect.Uint64: + return &NullUint64{}, true, nil + case reflect.Uint32, reflect.Uint, reflect.Uint16, reflect.Uint8: + return &NullUint32{}, true, nil + default: + return nil, false, fmt.Errorf("unsupported type: %#v", bean) + } +} + +func genScanResultsByBean(bean interface{}) (interface{}, bool, error) { + switch t := bean.(type) { + case *sql.NullInt64, *sql.NullBool, *sql.NullFloat64, *sql.NullString, + *string, + *int, *int8, *int16, *int32, *int64, + *uint, *uint8, *uint16, *uint32, *uint64, + *float32, *float64, + *bool: + return t, false, nil + case *time.Time: + return &sql.NullTime{}, true, nil + case sql.NullInt64, sql.NullBool, sql.NullFloat64, sql.NullString, + time.Time, + string, + int, int8, int16, int32, int64, + uint, uint8, uint16, uint32, uint64, + bool: + return nil, false, fmt.Errorf("unsupported scan type: %t", t) + case convert.Conversion: + return &sql.RawBytes{}, true, nil + } + + tp := reflect.TypeOf(bean).Elem() + switch tp.Kind() { + case reflect.String: + return new(string), true, nil + case reflect.Int64: + return new(int64), true, nil + case reflect.Int32: + return new(int32), true, nil + case reflect.Int: + return new(int32), true, nil + case reflect.Int16: + return new(int32), true, nil + case reflect.Int8: + return new(int32), true, nil + case reflect.Uint64: + return new(uint64), true, nil + case reflect.Uint32: + return new(uint32), true, nil + case reflect.Uint: + return new(uint), true, nil + case reflect.Uint16: + return new(uint16), true, nil + case reflect.Uint8: + return new(uint8), true, nil + case reflect.Float32: + return new(float32), true, nil + case reflect.Float64: + return new(float64), true, nil + default: + return nil, false, fmt.Errorf("unsupported type: %#v", bean) + } +} + func row2mapStr(rows *core.Rows, types []*sql.ColumnType, fields []string) (map[string]string, error) { var scanResults = make([]interface{}, len(fields)) for i := 0; i < len(fields); i++ { @@ -50,18 +158,97 @@ func row2mapBytes(rows *core.Rows, types []*sql.ColumnType, fields []string) (ma return result, nil } -func row2sliceStr(rows *core.Rows, types []*sql.ColumnType, fields []string) ([]string, error) { - results := make([]string, 0, len(fields)) - var scanResults = make([]interface{}, len(fields)) - for i := 0; i < len(fields); i++ { +func (engine *Engine) scanStringInterface(rows *core.Rows, types []*sql.ColumnType) ([]interface{}, error) { + var scanResults = make([]interface{}, len(types)) + for i := 0; i < len(types); i++ { var s sql.NullString scanResults[i] = &s } - if err := rows.Scan(scanResults...); err != nil { + if err := engine.driver.Scan(&dialects.ScanContext{ + DBLocation: engine.DatabaseTZ, + UserLocation: engine.TZLocation, + }, rows, types, scanResults...); err != nil { + return nil, err + } + return scanResults, nil +} + +// scan is a wrap of driver.Scan but will automatically change the input values according requirements +func (engine *Engine) scan(rows *core.Rows, fields []string, types []*sql.ColumnType, vv ...interface{}) error { + var scanResults = make([]interface{}, 0, len(types)) + var replaces = make([]bool, 0, len(types)) + var err error + for _, v := range vv { + var replaced bool + var scanResult interface{} + if _, ok := v.(sql.Scanner); !ok { + var useNullable = true + if engine.driver.Features().SupportNullable { + nullable, ok := types[0].Nullable() + useNullable = ok && nullable + } + + if useNullable { + scanResult, replaced, err = genScanResultsByBeanNullable(v) + } else { + scanResult, replaced, err = genScanResultsByBean(v) + } + if err != nil { + return err + } + } else { + scanResult = v + } + scanResults = append(scanResults, scanResult) + replaces = append(replaces, replaced) + } + + var scanCtx = dialects.ScanContext{ + DBLocation: engine.DatabaseTZ, + UserLocation: engine.TZLocation, + } + + if err = engine.driver.Scan(&scanCtx, rows, types, scanResults...); err != nil { + return err + } + + for i, replaced := range replaces { + if replaced { + if err = convertAssign(vv[i], scanResults[i], scanCtx.DBLocation, engine.TZLocation); err != nil { + return err + } + } + } + + return nil +} + +func (engine *Engine) scanInterfaces(rows *core.Rows, types []*sql.ColumnType) ([]interface{}, error) { + var scanResultContainers = make([]interface{}, len(types)) + for i := 0; i < len(types); i++ { + scanResult, err := engine.driver.GenScanResult(types[i].DatabaseTypeName()) + if err != nil { + return nil, err + } + scanResultContainers[i] = scanResult + } + if err := engine.driver.Scan(&dialects.ScanContext{ + DBLocation: engine.DatabaseTZ, + UserLocation: engine.TZLocation, + }, rows, types, scanResultContainers...); err != nil { + return nil, err + } + return scanResultContainers, nil +} + +func (engine *Engine) row2sliceStr(rows *core.Rows, types []*sql.ColumnType, fields []string) ([]string, error) { + scanResults, err := engine.scanStringInterface(rows, types) + if err != nil { return nil, err } + var results = make([]string, 0, len(fields)) for i := 0; i < len(fields); i++ { results = append(results, scanResults[i].(*sql.NullString).String) } diff --git a/session_find.go b/session_find.go index 0daea005..261e6b7f 100644 --- a/session_find.go +++ b/session_find.go @@ -276,7 +276,7 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect func convertPKToValue(table *schemas.Table, dst interface{}, pk schemas.PK) error { cols := table.PKColumns() if len(cols) == 1 { - return convertAssign(dst, pk[0]) + return convertAssign(dst, pk[0], nil, nil) } dst = pk diff --git a/session_get.go b/session_get.go index e303176d..cb2bda75 100644 --- a/session_get.go +++ b/session_get.go @@ -6,12 +6,16 @@ package xorm import ( "database/sql" + "database/sql/driver" "errors" "fmt" "reflect" "strconv" + "time" "xorm.io/xorm/caches" + "xorm.io/xorm/convert" + "xorm.io/xorm/core" "xorm.io/xorm/internal/utils" "xorm.io/xorm/schemas" ) @@ -108,6 +112,17 @@ func (session *Session) get(bean interface{}) (bool, error) { return true, nil } +var ( + valuerTypePlaceHolder driver.Valuer + valuerType = reflect.TypeOf(&valuerTypePlaceHolder).Elem() + + scannerTypePlaceHolder sql.Scanner + scannerType = reflect.TypeOf(&scannerTypePlaceHolder).Elem() + + conversionTypePlaceHolder convert.Conversion + conversionType = reflect.TypeOf(&conversionTypePlaceHolder).Elem() +) + func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table, bean interface{}, sqlStr string, args ...interface{}) (bool, error) { rows, err := session.queryRows(sqlStr, args...) if err != nil { @@ -122,155 +137,161 @@ func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table, return false, nil } - switch bean.(type) { - case sql.NullInt64, sql.NullBool, sql.NullFloat64, sql.NullString: - return true, rows.Scan(&bean) - case *sql.NullInt64, *sql.NullBool, *sql.NullFloat64, *sql.NullString: - return true, rows.Scan(bean) - case *string: - var res sql.NullString - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*string)) = res.String - } - return true, nil - case *int: - var res sql.NullInt64 - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*int)) = int(res.Int64) - } - return true, nil - case *int8: - var res sql.NullInt64 - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*int8)) = int8(res.Int64) - } - return true, nil - case *int16: - var res sql.NullInt64 - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*int16)) = int16(res.Int64) - } - return true, nil - case *int32: - var res sql.NullInt64 - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*int32)) = int32(res.Int64) - } - return true, nil - case *int64: - var res sql.NullInt64 - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*int64)) = int64(res.Int64) - } - return true, nil - case *uint: - var res sql.NullInt64 - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*uint)) = uint(res.Int64) - } - return true, nil - case *uint8: - var res sql.NullInt64 - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*uint8)) = uint8(res.Int64) - } - return true, nil - case *uint16: - var res sql.NullInt64 - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*uint16)) = uint16(res.Int64) - } - return true, nil - case *uint32: - var res sql.NullInt64 - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*uint32)) = uint32(res.Int64) - } - return true, nil - case *uint64: - var res sql.NullInt64 - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*uint64)) = uint64(res.Int64) - } - return true, nil - case *bool: - var res sql.NullBool - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*bool)) = res.Bool - } - return true, nil + // WARN: Alougth rows return true, but we may also return error. + types, err := rows.ColumnTypes() + if err != nil { + return true, err + } + fields, err := rows.Columns() + if err != nil { + return true, err } - switch beanKind { case reflect.Struct: - fields, err := rows.Columns() - if err != nil { - // WARN: Alougth rows return true, but get fields failed - return true, err + if _, ok := bean.(*time.Time); ok { + break } - - scanResults, err := session.row2Slice(rows, fields, bean) - if err != nil { - return false, err + if _, ok := bean.(sql.Scanner); ok { + break } - // close it before convert data - rows.Close() - - dataStruct := utils.ReflectValue(bean) - _, err = session.slice2Bean(scanResults, fields, bean, &dataStruct, table) - if err != nil { - return true, err + if _, ok := bean.(convert.Conversion); len(types) == 1 && ok { + break } - - return true, session.executeProcessors() + return session.getStruct(rows, types, fields, table, bean) case reflect.Slice: - err = rows.ScanSlice(bean) + return session.getSlice(rows, types, fields, bean) case reflect.Map: - err = rows.ScanMap(bean) - case reflect.String, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, - reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - err = rows.Scan(bean) - default: - err = rows.Scan(bean) + return session.getMap(rows, types, fields, bean) } - return true, err + return session.getVars(rows, types, fields, bean) +} + +func (session *Session) getSlice(rows *core.Rows, types []*sql.ColumnType, fields []string, bean interface{}) (bool, error) { + switch t := bean.(type) { + case *[]string: + res, err := session.engine.scanStringInterface(rows, types) + if err != nil { + return true, err + } + + var needAppend = len(*t) == 0 // both support slice is empty or has been initlized + for i, r := range res { + if needAppend { + *t = append(*t, r.(*sql.NullString).String) + } else { + (*t)[i] = r.(*sql.NullString).String + } + } + return true, nil + case *[]interface{}: + scanResults, err := session.engine.scanInterfaces(rows, types) + if err != nil { + return true, err + } + var needAppend = len(*t) == 0 + for ii := range fields { + s, err := convert.Interface2Interface(session.engine.DatabaseTZ, scanResults[ii]) + if err != nil { + return true, err + } + if needAppend { + *t = append(*t, s) + } else { + (*t)[ii] = s + } + } + return true, nil + default: + return true, fmt.Errorf("unspoorted slice type: %t", t) + } +} + +func (session *Session) getMap(rows *core.Rows, types []*sql.ColumnType, fields []string, bean interface{}) (bool, error) { + switch t := bean.(type) { + case *map[string]string: + scanResults, err := session.engine.scanStringInterface(rows, types) + if err != nil { + return true, err + } + for ii, key := range fields { + (*t)[key] = scanResults[ii].(*sql.NullString).String + } + return true, nil + case *map[string]interface{}: + scanResults, err := session.engine.scanInterfaces(rows, types) + if err != nil { + return true, err + } + for ii, key := range fields { + s, err := convert.Interface2Interface(session.engine.DatabaseTZ, scanResults[ii]) + if err != nil { + return true, err + } + (*t)[key] = s + } + return true, nil + default: + return true, fmt.Errorf("unspoorted map type: %t", t) + } +} + +func (session *Session) getVars(rows *core.Rows, types []*sql.ColumnType, fields []string, beans ...interface{}) (bool, error) { + if len(beans) != len(types) { + return false, fmt.Errorf("expected columns %d, but only %d variables", len(types), len(beans)) + } + var scanResults = make([]interface{}, 0, len(types)) + var replaceds = make([]bool, 0, len(types)) + for _, bean := range beans { + switch t := bean.(type) { + case sql.Scanner: + scanResults = append(scanResults, t) + replaceds = append(replaceds, false) + case convert.Conversion: + scanResults = append(scanResults, &sql.RawBytes{}) + replaceds = append(replaceds, true) + default: + scanResults = append(scanResults, bean) + replaceds = append(replaceds, false) + } + } + + err := session.engine.scan(rows, fields, types, scanResults...) + if err != nil { + return true, err + } + for i, replaced := range replaceds { + if replaced { + err = convertAssign(beans[i], scanResults[i], session.engine.DatabaseTZ, session.engine.TZLocation) + if err != nil { + return true, err + } + } + } + return true, nil +} + +func (session *Session) getStruct(rows *core.Rows, types []*sql.ColumnType, fields []string, table *schemas.Table, bean interface{}) (bool, error) { + fields, err := rows.Columns() + if err != nil { + // WARN: Alougth rows return true, but get fields failed + return true, err + } + + scanResults, err := session.row2Slice(rows, fields, bean) + if err != nil { + return false, err + } + // close it before convert data + rows.Close() + + dataStruct := utils.ReflectValue(bean) + _, err = session.slice2Bean(scanResults, fields, bean, &dataStruct, table) + if err != nil { + return true, err + } + + return true, session.executeProcessors() } func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interface{}) (has bool, err error) { diff --git a/session_insert.go b/session_insert.go index e733e06e..7f8f3008 100644 --- a/session_insert.go +++ b/session_insert.go @@ -375,7 +375,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { return 1, nil } - return 1, convertAssignV(aiValue.Addr(), id) + return 1, convertAssignV(aiValue.Addr(), id, session.engine.DatabaseTZ, session.engine.TZLocation) } else if len(table.AutoIncrement) > 0 && (session.engine.dialect.URI().DBType == schemas.POSTGRES || session.engine.dialect.URI().DBType == schemas.MSSQL) { res, err := session.queryBytes(sqlStr, args...) @@ -415,7 +415,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { return 1, nil } - return 1, convertAssignV(aiValue.Addr(), id) + return 1, convertAssignV(aiValue.Addr(), id, session.engine.DatabaseTZ, session.engine.TZLocation) } res, err := session.exec(sqlStr, args...) @@ -455,7 +455,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { return res.RowsAffected() } - if err := convertAssignV(aiValue.Addr(), id); err != nil { + if err := convertAssignV(aiValue.Addr(), id, session.engine.DatabaseTZ, session.engine.TZLocation); err != nil { return 0, err } diff --git a/session_query.go b/session_query.go index 01cd6f44..fa33496d 100644 --- a/session_query.go +++ b/session_query.go @@ -54,7 +54,7 @@ func (session *Session) rows2SliceString(rows *core.Rows) (resultsSlice [][]stri } for rows.Next() { - record, err := row2sliceStr(rows, types, fields) + record, err := session.engine.row2sliceStr(rows, types, fields) if err != nil { return nil, err } From 8f64a78cd4b287c4f24e789767b15f229b0b7ee7 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Tue, 6 Jul 2021 17:11:45 +0800 Subject: [PATCH 09/31] Support delete with no bean (#1926) Now you can use delete like this: ``` orm.Table("my_table").Where("id=?",1).Delete() ``` Reviewed-on: https://gitea.com/xorm/xorm/pulls/1926 Co-authored-by: Lunny Xiao Co-committed-by: Lunny Xiao --- README.md | 19 +++++--- README_CN.md | 3 ++ engine.go | 4 +- integrations/session_delete_test.go | 25 ++++++++++ interface.go | 2 +- session_delete.go | 76 +++++++++++++++++------------ 6 files changed, 87 insertions(+), 42 deletions(-) diff --git a/README.md b/README.md index 67380839..40826f13 100644 --- a/README.md +++ b/README.md @@ -245,35 +245,38 @@ for rows.Next() { ```Go affected, err := engine.ID(1).Update(&user) -// UPDATE user SET ... Where id = ? +// UPDATE user SET ... WHERE id = ? affected, err := engine.Update(&user, &User{Name:name}) -// UPDATE user SET ... Where name = ? +// UPDATE user SET ... WHERE name = ? var ids = []int64{1, 2, 3} affected, err := engine.In("id", ids).Update(&user) -// UPDATE user SET ... Where id IN (?, ?, ?) +// UPDATE user SET ... WHERE id IN (?, ?, ?) // force update indicated columns by Cols affected, err := engine.ID(1).Cols("age").Update(&User{Name:name, Age: 12}) -// UPDATE user SET age = ?, updated=? Where id = ? +// UPDATE user SET age = ?, updated=? WHERE id = ? // force NOT update indicated columns by Omit affected, err := engine.ID(1).Omit("name").Update(&User{Name:name, Age: 12}) -// UPDATE user SET age = ?, updated=? Where id = ? +// UPDATE user SET age = ?, updated=? WHERE id = ? affected, err := engine.ID(1).AllCols().Update(&user) -// UPDATE user SET name=?,age=?,salt=?,passwd=?,updated=? Where id = ? +// UPDATE user SET name=?,age=?,salt=?,passwd=?,updated=? WHERE id = ? ``` * `Delete` delete one or more records, Delete MUST have condition ```Go affected, err := engine.Where(...).Delete(&user) -// DELETE FROM user Where ... +// DELETE FROM user WHERE ... affected, err := engine.ID(2).Delete(&user) -// DELETE FROM user Where id = ? +// DELETE FROM user WHERE id = ? + +affected, err := engine.Table("user").Where(...).Delete() +// DELETE FROM user WHERE ... ``` * `Count` count records diff --git a/README_CN.md b/README_CN.md index 80245dd3..06706417 100644 --- a/README_CN.md +++ b/README_CN.md @@ -271,6 +271,9 @@ affected, err := engine.Where(...).Delete(&user) affected, err := engine.ID(2).Delete(&user) // DELETE FROM user Where id = ? + +affected, err := engine.Table("user").Where(...).Delete() +// DELETE FROM user WHERE ... ``` * `Count` 获取记录条数 diff --git a/engine.go b/engine.go index 1064e8e1..a45771a2 100644 --- a/engine.go +++ b/engine.go @@ -1202,10 +1202,10 @@ func (engine *Engine) Update(bean interface{}, condiBeans ...interface{}) (int64 } // Delete records, bean's non-empty fields are conditions -func (engine *Engine) Delete(bean interface{}) (int64, error) { +func (engine *Engine) Delete(beans ...interface{}) (int64, error) { session := engine.NewSession() defer session.Close() - return session.Delete(bean) + return session.Delete(beans...) } // Get retrieve one record from table, bean's non-empty fields diff --git a/integrations/session_delete_test.go b/integrations/session_delete_test.go index cc7e861d..56f6f5b8 100644 --- a/integrations/session_delete_test.go +++ b/integrations/session_delete_test.go @@ -241,3 +241,28 @@ func TestUnscopeDelete(t *testing.T) { assert.NoError(t, err) assert.False(t, has) } + +func TestDelete2(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type UserinfoDelete2 struct { + Uid int64 `xorm:"id pk not null autoincr"` + IsMan bool + } + + assert.NoError(t, testEngine.Sync2(new(UserinfoDelete2))) + + user := UserinfoDelete2{} + cnt, err := testEngine.Insert(&user) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + cnt, err = testEngine.Table("userinfo_delete2").In("id", []int{1}).Delete() + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + user2 := UserinfoDelete2{} + has, err := testEngine.ID(1).Get(&user2) + assert.NoError(t, err) + assert.False(t, has) +} diff --git a/interface.go b/interface.go index d31323ff..fbb81015 100644 --- a/interface.go +++ b/interface.go @@ -30,7 +30,7 @@ type Interface interface { CreateUniques(bean interface{}) error Decr(column string, arg ...interface{}) *Session Desc(...string) *Session - Delete(interface{}) (int64, error) + Delete(...interface{}) (int64, error) Distinct(columns ...string) *Session DropIndexes(bean interface{}) error Exec(sqlOrArgs ...interface{}) (sql.Result, error) diff --git a/session_delete.go b/session_delete.go index 13bf791f..baabb558 100644 --- a/session_delete.go +++ b/session_delete.go @@ -83,7 +83,7 @@ func (session *Session) cacheDelete(table *schemas.Table, tableName, sqlStr stri } // Delete records, bean's non-empty fields are conditions -func (session *Session) Delete(bean interface{}) (int64, error) { +func (session *Session) Delete(beans ...interface{}) (int64, error) { if session.isAutoClose { defer session.Close() } @@ -92,20 +92,32 @@ func (session *Session) Delete(bean interface{}) (int64, error) { return 0, session.statement.LastError } - if err := session.statement.SetRefBean(bean); err != nil { - return 0, err + var ( + condSQL string + condArgs []interface{} + err error + bean interface{} + ) + if len(beans) > 0 { + bean = beans[0] + if err = session.statement.SetRefBean(bean); err != nil { + return 0, err + } + + executeBeforeClosures(session, bean) + + if processor, ok := interface{}(bean).(BeforeDeleteProcessor); ok { + processor.BeforeDelete() + } + + condSQL, condArgs, err = session.statement.GenConds(bean) + } else { + condSQL, condArgs, err = session.statement.GenCondSQL(session.statement.Conds()) } - - executeBeforeClosures(session, bean) - - if processor, ok := interface{}(bean).(BeforeDeleteProcessor); ok { - processor.BeforeDelete() - } - - condSQL, condArgs, err := session.statement.GenConds(bean) if err != nil { return 0, err } + pLimitN := session.statement.LimitN if len(condSQL) == 0 && (pLimitN == nil || *pLimitN == 0) { return 0, ErrNeedDeletedCond @@ -156,7 +168,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) { var realSQL string argsForCache := make([]interface{}, 0, len(condArgs)*2) - if session.statement.GetUnscoped() || table.DeletedColumn() == nil { // tag "deleted" is disabled + if session.statement.GetUnscoped() || table == nil || table.DeletedColumn() == nil { // tag "deleted" is disabled realSQL = deleteSQL copy(argsForCache, condArgs) argsForCache = append(condArgs, argsForCache...) @@ -220,27 +232,29 @@ func (session *Session) Delete(bean interface{}) (int64, error) { return 0, err } - // handle after delete processors - if session.isAutoCommit { - for _, closure := range session.afterClosures { - closure(bean) - } - if processor, ok := interface{}(bean).(AfterDeleteProcessor); ok { - processor.AfterDelete() - } - } else { - lenAfterClosures := len(session.afterClosures) - if lenAfterClosures > 0 { - if value, has := session.afterDeleteBeans[bean]; has && value != nil { - *value = append(*value, session.afterClosures...) - } else { - afterClosures := make([]func(interface{}), lenAfterClosures) - copy(afterClosures, session.afterClosures) - session.afterDeleteBeans[bean] = &afterClosures + if bean != nil { + // handle after delete processors + if session.isAutoCommit { + for _, closure := range session.afterClosures { + closure(bean) + } + if processor, ok := interface{}(bean).(AfterDeleteProcessor); ok { + processor.AfterDelete() } } else { - if _, ok := interface{}(bean).(AfterDeleteProcessor); ok { - session.afterDeleteBeans[bean] = nil + lenAfterClosures := len(session.afterClosures) + if lenAfterClosures > 0 && len(beans) > 0 { + if value, has := session.afterDeleteBeans[beans[0]]; has && value != nil { + *value = append(*value, session.afterClosures...) + } else { + afterClosures := make([]func(interface{}), lenAfterClosures) + copy(afterClosures, session.afterClosures) + session.afterDeleteBeans[bean] = &afterClosures + } + } else { + if _, ok := interface{}(bean).(AfterDeleteProcessor); ok { + session.afterDeleteBeans[bean] = nil + } } } } From c433fd51cb154e42e4cd0163e17545616f25bdbd Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Tue, 6 Jul 2021 23:20:17 +0800 Subject: [PATCH 10/31] Nil ptr is nullable (#1919) replace #661 Co-authored-by: Jim Salem Co-authored-by: Oleh Herych Reviewed-on: https://gitea.com/xorm/xorm/pulls/1919 Co-authored-by: Lunny Xiao Co-committed-by: Lunny Xiao --- integrations/types_test.go | 45 ++++++++++++++++++++++++++++++++------ interface.go | 1 + 2 files changed, 39 insertions(+), 7 deletions(-) diff --git a/integrations/types_test.go b/integrations/types_test.go index 539171d5..f192c1ff 100644 --- a/integrations/types_test.go +++ b/integrations/types_test.go @@ -147,13 +147,39 @@ func (s *SliceType) ToDB() ([]byte, error) { return json.DefaultJSONHandler.Marshal(s) } +type Nullable struct { + Data string +} + +func (s *Nullable) FromDB(data []byte) error { + if data == nil { + return nil + } + + *s = Nullable{ + Data: string(data), + } + + return nil +} + +func (s *Nullable) ToDB() ([]byte, error) { + if s == nil { + return nil, nil + } + + return []byte(s.Data), nil +} + type ConvStruct struct { - Conv ConvString - Conv2 *ConvString - Cfg1 ConvConfig - Cfg2 *ConvConfig `xorm:"TEXT"` - Cfg3 convert.Conversion `xorm:"BLOB"` - Slice SliceType + Conv ConvString + Conv2 *ConvString + Cfg1 ConvConfig + Cfg2 *ConvConfig `xorm:"TEXT"` + Cfg3 convert.Conversion `xorm:"BLOB"` + Slice SliceType + Nullable1 *Nullable `xorm:"null"` + Nullable2 *Nullable `xorm:"null"` } func (c *ConvStruct) BeforeSet(name string, cell xorm.Cell) { @@ -176,8 +202,10 @@ func TestConversion(t *testing.T) { c.Cfg2 = &ConvConfig{"xx", 2} c.Cfg3 = &ConvConfig{"zz", 3} c.Slice = []*ConvConfig{{"yy", 4}, {"ff", 5}} + c.Nullable1 = &Nullable{Data: "test"} + c.Nullable2 = nil - _, err := testEngine.Insert(c) + _, err := testEngine.Nullable("nullable2").Insert(c) assert.NoError(t, err) c1 := new(ConvStruct) @@ -219,6 +247,9 @@ func TestConversion(t *testing.T) { assert.EqualValues(t, 2, len(c2.Slice)) assert.EqualValues(t, *c.Slice[0], *c2.Slice[0]) assert.EqualValues(t, *c.Slice[1], *c2.Slice[1]) + assert.NotNil(t, c1.Nullable1) + assert.Equal(t, c1.Nullable1.Data, "test") + assert.Nil(t, c1.Nullable2) } type MyInt int diff --git a/interface.go b/interface.go index fbb81015..5d68f536 100644 --- a/interface.go +++ b/interface.go @@ -51,6 +51,7 @@ type Interface interface { MustCols(columns ...string) *Session NoAutoCondition(...bool) *Session NotIn(string, ...interface{}) *Session + Nullable(...string) *Session Join(joinOperator string, tablename interface{}, condition string, args ...interface{}) *Session Omit(columns ...string) *Session OrderBy(order string) *Session From bb91a0773cbcac4755fbb568c3832c43a32069d5 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Wed, 7 Jul 2021 11:34:33 +0800 Subject: [PATCH 11/31] Fix postgres genScanResult (#1972) Reviewed-on: https://gitea.com/xorm/xorm/pulls/1972 Co-authored-by: Lunny Xiao Co-committed-by: Lunny Xiao --- dialects/postgres.go | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/dialects/postgres.go b/dialects/postgres.go index a2611c60..fd6d871c 100644 --- a/dialects/postgres.go +++ b/dialects/postgres.go @@ -1387,26 +1387,22 @@ func (p *pqDriver) GenScanResult(colType string) (interface{}, error) { case "VARCHAR", "TEXT": var s sql.NullString return &s, nil - case "BIGINT": + case "BIGINT", "BIGSERIAL": var s sql.NullInt64 return &s, nil - case "TINYINT", "INT", "INT8", "INT4": + case "SMALLINT", "INT", "INT8", "INT4", "INTEGER", "SERIAL": var s sql.NullInt32 return &s, nil - case "FLOAT", "FLOAT4": + case "FLOAT", "FLOAT4", "REAL", "DOUBLE PRECISION": var s sql.NullFloat64 return &s, nil case "DATETIME", "TIMESTAMP": var s sql.NullTime return &s, nil - case "BIT": - var s sql.RawBytes - return &s, nil case "BOOL": var s sql.NullBool return &s, nil default: - fmt.Printf("unknow postgres database type: %v\n", colType) var r sql.RawBytes return &r, nil } From bece9a6373852d5e4faf51758a4dec969ce01cd9 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Wed, 7 Jul 2021 13:03:05 +0800 Subject: [PATCH 12/31] refactor slice2Bean (#1974) as title. Reviewed-on: https://gitea.com/xorm/xorm/pulls/1974 Co-authored-by: Lunny Xiao Co-committed-by: Lunny Xiao --- convert.go | 2 + session.go | 800 ++++++++++++++++++++++++++--------------------------- 2 files changed, 402 insertions(+), 400 deletions(-) diff --git a/convert.go b/convert.go index f7d733ad..67183098 100644 --- a/convert.go +++ b/convert.go @@ -238,6 +238,8 @@ func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) { func convertAssign(dest, src interface{}, originalLocation *time.Location, convertedLocation *time.Location) error { // Common cases, without reflect. switch s := src.(type) { + case *interface{}: + return convertAssign(dest, *s, originalLocation, convertedLocation) case string: switch d := dest.(type) { case *string: diff --git a/session.go b/session.go index 6df9e20d..3fb92991 100644 --- a/session.go +++ b/session.go @@ -436,6 +436,397 @@ func (session *Session) row2Slice(rows *core.Rows, fields []string, bean interfa return scanResults, nil } +func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflect.Value, + scanResult interface{}, table *schemas.Table) error { + rawValue := reflect.Indirect(reflect.ValueOf(scanResult)) + + // if row is null then ignore + if rawValue.Interface() == nil { + return nil + } + + if fieldValue.CanAddr() { + if structConvert, ok := fieldValue.Addr().Interface().(convert.Conversion); ok { + data, err := value2Bytes(&rawValue) + if err != nil { + return err + } + if err := structConvert.FromDB(data); err != nil { + return err + } + return nil + } + } + + if _, ok := fieldValue.Interface().(convert.Conversion); ok { + if data, err := value2Bytes(&rawValue); err == nil { + if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() { + fieldValue.Set(reflect.New(fieldValue.Type().Elem())) + } + fieldValue.Interface().(convert.Conversion).FromDB(data) + } else { + return err + } + return nil + } + + rawValueType := reflect.TypeOf(rawValue.Interface()) + vv := reflect.ValueOf(rawValue.Interface()) + + fieldType := fieldValue.Type() + + if col.IsJSON { + var bs []byte + if rawValueType.Kind() == reflect.String { + bs = []byte(vv.String()) + } else if rawValueType.ConvertibleTo(schemas.BytesType) { + bs = vv.Bytes() + } else { + return fmt.Errorf("unsupported database data type: %s %v", col.Name, rawValueType.Kind()) + } + + if len(bs) > 0 { + if fieldType.Kind() == reflect.String { + fieldValue.SetString(string(bs)) + return nil + } + if fieldValue.CanAddr() { + err := json.DefaultJSONHandler.Unmarshal(bs, fieldValue.Addr().Interface()) + if err != nil { + return err + } + } else { + x := reflect.New(fieldType) + err := json.DefaultJSONHandler.Unmarshal(bs, x.Interface()) + if err != nil { + return err + } + fieldValue.Set(x.Elem()) + } + } + return nil + } + + switch fieldType.Kind() { + case reflect.Complex64, reflect.Complex128: + // TODO: reimplement this + var bs []byte + if rawValueType.Kind() == reflect.String { + bs = []byte(vv.String()) + } else if rawValueType.ConvertibleTo(schemas.BytesType) { + bs = vv.Bytes() + } + + if len(bs) > 0 { + if fieldValue.CanAddr() { + err := json.DefaultJSONHandler.Unmarshal(bs, fieldValue.Addr().Interface()) + if err != nil { + return err + } + } else { + x := reflect.New(fieldType) + err := json.DefaultJSONHandler.Unmarshal(bs, x.Interface()) + if err != nil { + return err + } + fieldValue.Set(x.Elem()) + } + } + return nil + case reflect.Slice, reflect.Array: + switch rawValueType.Kind() { + case reflect.Slice, reflect.Array: + switch rawValueType.Elem().Kind() { + case reflect.Uint8: + if fieldType.Elem().Kind() == reflect.Uint8 { + if col.SQLType.IsText() { + x := reflect.New(fieldType) + err := json.DefaultJSONHandler.Unmarshal(vv.Bytes(), x.Interface()) + if err != nil { + return err + } + fieldValue.Set(x.Elem()) + } else { + if fieldValue.Len() > 0 { + for i := 0; i < fieldValue.Len(); i++ { + if i < vv.Len() { + fieldValue.Index(i).Set(vv.Index(i)) + } + } + } else { + for i := 0; i < vv.Len(); i++ { + fieldValue.Set(reflect.Append(*fieldValue, vv.Index(i))) + } + } + } + return nil + } + } + } + case reflect.String: + if rawValueType.Kind() == reflect.String { + fieldValue.SetString(vv.String()) + return nil + } + case reflect.Bool: + if rawValueType.Kind() == reflect.Bool { + fieldValue.SetBool(vv.Bool()) + return nil + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + switch rawValueType.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + fieldValue.SetInt(vv.Int()) + return nil + } + case reflect.Float32, reflect.Float64: + switch rawValueType.Kind() { + case reflect.Float32, reflect.Float64: + fieldValue.SetFloat(vv.Float()) + return nil + } + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: + switch rawValueType.Kind() { + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: + fieldValue.SetUint(vv.Uint()) + return nil + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + fieldValue.SetUint(uint64(vv.Int())) + return nil + } + case reflect.Struct: + if fieldType.ConvertibleTo(schemas.TimeType) { + dbTZ := session.engine.DatabaseTZ + if col.TimeZone != nil { + dbTZ = col.TimeZone + } + + if rawValueType == schemas.TimeType { + t := vv.Convert(schemas.TimeType).Interface().(time.Time) + + z, _ := t.Zone() + // set new location if database don't save timezone or give an incorrect timezone + if len(z) == 0 || t.Year() == 0 || t.Location().String() != dbTZ.String() { // !nashtsai! HACK tmp work around for lib/pq doesn't properly time with location + session.engine.logger.Debugf("empty zone key[%v] : %v | zone: %v | location: %+v\n", col.Name, t, z, *t.Location()) + t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), + t.Minute(), t.Second(), t.Nanosecond(), dbTZ) + } + + t = t.In(session.engine.TZLocation) + fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) + return nil + } else if rawValueType == schemas.IntType || rawValueType == schemas.Int64Type || + rawValueType == schemas.Int32Type { + t := time.Unix(vv.Int(), 0).In(session.engine.TZLocation) + fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) + return nil + } else { + if d, ok := vv.Interface().([]uint8); ok { + t, err := session.byte2Time(col, d) + if err != nil { + session.engine.logger.Errorf("byte2Time error: %v", err) + } else { + fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) + return nil + } + + } else if d, ok := vv.Interface().(string); ok { + t, err := session.str2Time(col, d) + if err != nil { + session.engine.logger.Errorf("byte2Time error: %v", err) + } else { + fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) + return nil + } + } else { + return fmt.Errorf("rawValueType is %v, value is %v", rawValueType, vv.Interface()) + } + } + } else if nulVal, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { + err := nulVal.Scan(vv.Interface()) + if err == nil { + return nil + } + session.engine.logger.Errorf("sql.Sanner error: %v", err) + } else if col.IsJSON { + if rawValueType.Kind() == reflect.String { + x := reflect.New(fieldType) + if len([]byte(vv.String())) > 0 { + err := json.DefaultJSONHandler.Unmarshal([]byte(vv.String()), x.Interface()) + if err != nil { + return err + } + fieldValue.Set(x.Elem()) + } + return nil + } else if rawValueType.Kind() == reflect.Slice { + x := reflect.New(fieldType) + if len(vv.Bytes()) > 0 { + err := json.DefaultJSONHandler.Unmarshal(vv.Bytes(), x.Interface()) + if err != nil { + return err + } + fieldValue.Set(x.Elem()) + } + return nil + } + } else if session.statement.UseCascade { + table, err := session.engine.tagParser.ParseWithCache(*fieldValue) + if err != nil { + return err + } + + if len(table.PrimaryKeys) != 1 { + return errors.New("unsupported non or composited primary key cascade") + } + var pk = make(schemas.PK, len(table.PrimaryKeys)) + pk[0], err = asKind(vv, rawValueType) + if err != nil { + return err + } + + if !pk.IsZero() { + // !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch + // however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne + // property to be fetched lazily + structInter := reflect.New(fieldValue.Type()) + has, err := session.ID(pk).NoCascade().get(structInter.Interface()) + if err != nil { + return err + } + if has { + fieldValue.Set(structInter.Elem()) + } else { + return errors.New("cascade obj is not exist") + } + } + return nil + } + case reflect.Ptr: + // !nashtsai! TODO merge duplicated codes above + switch fieldType { + // following types case matching ptr's native type, therefore assign ptr directly + case schemas.PtrStringType: + if rawValueType.Kind() == reflect.String { + x := vv.String() + fieldValue.Set(reflect.ValueOf(&x)) + return nil + } + case schemas.PtrBoolType: + if rawValueType.Kind() == reflect.Bool { + x := vv.Bool() + fieldValue.Set(reflect.ValueOf(&x)) + return nil + } + case schemas.PtrTimeType: + if rawValueType == schemas.PtrTimeType { + var x = rawValue.Interface().(time.Time) + fieldValue.Set(reflect.ValueOf(&x)) + return nil + } + case schemas.PtrFloat64Type: + if rawValueType.Kind() == reflect.Float64 { + x := vv.Float() + fieldValue.Set(reflect.ValueOf(&x)) + return nil + } + case schemas.PtrUint64Type: + if rawValueType.Kind() == reflect.Int64 { + var x = uint64(vv.Int()) + fieldValue.Set(reflect.ValueOf(&x)) + return nil + } + case schemas.PtrInt64Type: + if rawValueType.Kind() == reflect.Int64 { + x := vv.Int() + fieldValue.Set(reflect.ValueOf(&x)) + return nil + } + case schemas.PtrFloat32Type: + if rawValueType.Kind() == reflect.Float64 { + var x = float32(vv.Float()) + fieldValue.Set(reflect.ValueOf(&x)) + return nil + } + case schemas.PtrIntType: + if rawValueType.Kind() == reflect.Int64 { + var x = int(vv.Int()) + fieldValue.Set(reflect.ValueOf(&x)) + return nil + } + case schemas.PtrInt32Type: + if rawValueType.Kind() == reflect.Int64 { + var x = int32(vv.Int()) + fieldValue.Set(reflect.ValueOf(&x)) + return nil + } + case schemas.PtrInt8Type: + if rawValueType.Kind() == reflect.Int64 { + var x = int8(vv.Int()) + fieldValue.Set(reflect.ValueOf(&x)) + return nil + } + case schemas.PtrInt16Type: + if rawValueType.Kind() == reflect.Int64 { + var x = int16(vv.Int()) + fieldValue.Set(reflect.ValueOf(&x)) + return nil + } + case schemas.PtrUintType: + if rawValueType.Kind() == reflect.Int64 { + var x = uint(vv.Int()) + fieldValue.Set(reflect.ValueOf(&x)) + return nil + } + case schemas.PtrUint32Type: + if rawValueType.Kind() == reflect.Int64 { + var x = uint32(vv.Int()) + fieldValue.Set(reflect.ValueOf(&x)) + return nil + } + case schemas.Uint8Type: + if rawValueType.Kind() == reflect.Int64 { + var x = uint8(vv.Int()) + fieldValue.Set(reflect.ValueOf(&x)) + return nil + } + case schemas.Uint16Type: + if rawValueType.Kind() == reflect.Int64 { + var x = uint16(vv.Int()) + fieldValue.Set(reflect.ValueOf(&x)) + return nil + } + case schemas.Complex64Type: + var x complex64 + if len([]byte(vv.String())) > 0 { + err := json.DefaultJSONHandler.Unmarshal([]byte(vv.String()), &x) + if err != nil { + return err + } + fieldValue.Set(reflect.ValueOf(&x)) + } + return nil + case schemas.Complex128Type: + var x complex128 + if len([]byte(vv.String())) > 0 { + err := json.DefaultJSONHandler.Unmarshal([]byte(vv.String()), &x) + if err != nil { + return err + } + fieldValue.Set(reflect.ValueOf(&x)) + } + return nil + } // switch fieldType + } // switch fieldType.Kind() + + data, err := value2Bytes(&rawValue) + if err != nil { + return err + } + + return session.bytes2Value(col, fieldValue, data) +} + func (session *Session) slice2Bean(scanResults []interface{}, fields []string, bean interface{}, dataStruct *reflect.Value, table *schemas.Table) (schemas.PK, error) { defer func() { executeAfterSet(bean, fields, scanResults) @@ -447,14 +838,19 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b var pk schemas.PK for ii, key := range fields { var idx int - var ok bool var lKey = strings.ToLower(key) + var ok bool + if idx, ok = tempMap[lKey]; !ok { idx = 0 } else { idx = idx + 1 } + tempMap[lKey] = idx + col := table.GetColumnIdx(key, idx) + + var scanResult = scanResults[ii] fieldValue, err := session.getField(dataStruct, key, table, idx) if err != nil { @@ -466,408 +862,12 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b if fieldValue == nil { continue } - rawValue := reflect.Indirect(reflect.ValueOf(scanResults[ii])) - // if row is null then ignore - if rawValue.Interface() == nil { - continue + if err := session.convertBeanField(col, fieldValue, scanResult, table); err != nil { + return nil, err } - - if fieldValue.CanAddr() { - if structConvert, ok := fieldValue.Addr().Interface().(convert.Conversion); ok { - if data, err := value2Bytes(&rawValue); err == nil { - if err := structConvert.FromDB(data); err != nil { - return nil, err - } - } else { - return nil, err - } - continue - } - } - - if _, ok := fieldValue.Interface().(convert.Conversion); ok { - if data, err := value2Bytes(&rawValue); err == nil { - if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() { - fieldValue.Set(reflect.New(fieldValue.Type().Elem())) - } - fieldValue.Interface().(convert.Conversion).FromDB(data) - } else { - return nil, err - } - continue - } - - rawValueType := reflect.TypeOf(rawValue.Interface()) - vv := reflect.ValueOf(rawValue.Interface()) - col := table.GetColumnIdx(key, idx) if col.IsPrimaryKey { - pk = append(pk, rawValue.Interface()) - } - fieldType := fieldValue.Type() - hasAssigned := false - - if col.IsJSON { - var bs []byte - if rawValueType.Kind() == reflect.String { - bs = []byte(vv.String()) - } else if rawValueType.ConvertibleTo(schemas.BytesType) { - bs = vv.Bytes() - } else { - return nil, fmt.Errorf("unsupported database data type: %s %v", key, rawValueType.Kind()) - } - - hasAssigned = true - - if len(bs) > 0 { - if fieldType.Kind() == reflect.String { - fieldValue.SetString(string(bs)) - continue - } - if fieldValue.CanAddr() { - err := json.DefaultJSONHandler.Unmarshal(bs, fieldValue.Addr().Interface()) - if err != nil { - return nil, err - } - } else { - x := reflect.New(fieldType) - err := json.DefaultJSONHandler.Unmarshal(bs, x.Interface()) - if err != nil { - return nil, err - } - fieldValue.Set(x.Elem()) - } - } - - continue - } - - switch fieldType.Kind() { - case reflect.Complex64, reflect.Complex128: - // TODO: reimplement this - var bs []byte - if rawValueType.Kind() == reflect.String { - bs = []byte(vv.String()) - } else if rawValueType.ConvertibleTo(schemas.BytesType) { - bs = vv.Bytes() - } - - hasAssigned = true - if len(bs) > 0 { - if fieldValue.CanAddr() { - err := json.DefaultJSONHandler.Unmarshal(bs, fieldValue.Addr().Interface()) - if err != nil { - return nil, err - } - } else { - x := reflect.New(fieldType) - err := json.DefaultJSONHandler.Unmarshal(bs, x.Interface()) - if err != nil { - return nil, err - } - fieldValue.Set(x.Elem()) - } - } - case reflect.Slice, reflect.Array: - switch rawValueType.Kind() { - case reflect.Slice, reflect.Array: - switch rawValueType.Elem().Kind() { - case reflect.Uint8: - if fieldType.Elem().Kind() == reflect.Uint8 { - hasAssigned = true - if col.SQLType.IsText() { - x := reflect.New(fieldType) - err := json.DefaultJSONHandler.Unmarshal(vv.Bytes(), x.Interface()) - if err != nil { - return nil, err - } - fieldValue.Set(x.Elem()) - } else { - if fieldValue.Len() > 0 { - for i := 0; i < fieldValue.Len(); i++ { - if i < vv.Len() { - fieldValue.Index(i).Set(vv.Index(i)) - } - } - } else { - for i := 0; i < vv.Len(); i++ { - fieldValue.Set(reflect.Append(*fieldValue, vv.Index(i))) - } - } - } - } - } - } - case reflect.String: - if rawValueType.Kind() == reflect.String { - hasAssigned = true - fieldValue.SetString(vv.String()) - } - case reflect.Bool: - if rawValueType.Kind() == reflect.Bool { - hasAssigned = true - fieldValue.SetBool(vv.Bool()) - } - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - switch rawValueType.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - hasAssigned = true - fieldValue.SetInt(vv.Int()) - } - case reflect.Float32, reflect.Float64: - switch rawValueType.Kind() { - case reflect.Float32, reflect.Float64: - hasAssigned = true - fieldValue.SetFloat(vv.Float()) - } - case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: - switch rawValueType.Kind() { - case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: - hasAssigned = true - fieldValue.SetUint(vv.Uint()) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - hasAssigned = true - fieldValue.SetUint(uint64(vv.Int())) - } - case reflect.Struct: - if fieldType.ConvertibleTo(schemas.TimeType) { - dbTZ := session.engine.DatabaseTZ - if col.TimeZone != nil { - dbTZ = col.TimeZone - } - - if rawValueType == schemas.TimeType { - hasAssigned = true - - t := vv.Convert(schemas.TimeType).Interface().(time.Time) - - z, _ := t.Zone() - // set new location if database don't save timezone or give an incorrect timezone - if len(z) == 0 || t.Year() == 0 || t.Location().String() != dbTZ.String() { // !nashtsai! HACK tmp work around for lib/pq doesn't properly time with location - session.engine.logger.Debugf("empty zone key[%v] : %v | zone: %v | location: %+v\n", key, t, z, *t.Location()) - t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), - t.Minute(), t.Second(), t.Nanosecond(), dbTZ) - } - - t = t.In(session.engine.TZLocation) - fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) - } else if rawValueType == schemas.IntType || rawValueType == schemas.Int64Type || - rawValueType == schemas.Int32Type { - hasAssigned = true - - t := time.Unix(vv.Int(), 0).In(session.engine.TZLocation) - fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) - } else { - if d, ok := vv.Interface().([]uint8); ok { - hasAssigned = true - t, err := session.byte2Time(col, d) - if err != nil { - session.engine.logger.Errorf("byte2Time error: %v", err) - hasAssigned = false - } else { - fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) - } - } else if d, ok := vv.Interface().(string); ok { - hasAssigned = true - t, err := session.str2Time(col, d) - if err != nil { - session.engine.logger.Errorf("byte2Time error: %v", err) - hasAssigned = false - } else { - fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) - } - } else { - return nil, fmt.Errorf("rawValueType is %v, value is %v", rawValueType, vv.Interface()) - } - } - } else if nulVal, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { - // !! 增加支持sql.Scanner接口的结构,如sql.NullString - hasAssigned = true - if err := nulVal.Scan(vv.Interface()); err != nil { - session.engine.logger.Errorf("sql.Sanner error: %v", err) - hasAssigned = false - } - } else if col.IsJSON { - if rawValueType.Kind() == reflect.String { - hasAssigned = true - x := reflect.New(fieldType) - if len([]byte(vv.String())) > 0 { - err := json.DefaultJSONHandler.Unmarshal([]byte(vv.String()), x.Interface()) - if err != nil { - return nil, err - } - fieldValue.Set(x.Elem()) - } - } else if rawValueType.Kind() == reflect.Slice { - hasAssigned = true - x := reflect.New(fieldType) - if len(vv.Bytes()) > 0 { - err := json.DefaultJSONHandler.Unmarshal(vv.Bytes(), x.Interface()) - if err != nil { - return nil, err - } - fieldValue.Set(x.Elem()) - } - } - } else if session.statement.UseCascade { - table, err := session.engine.tagParser.ParseWithCache(*fieldValue) - if err != nil { - return nil, err - } - - hasAssigned = true - if len(table.PrimaryKeys) != 1 { - return nil, errors.New("unsupported non or composited primary key cascade") - } - var pk = make(schemas.PK, len(table.PrimaryKeys)) - pk[0], err = asKind(vv, rawValueType) - if err != nil { - return nil, err - } - - if !pk.IsZero() { - // !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch - // however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne - // property to be fetched lazily - structInter := reflect.New(fieldValue.Type()) - has, err := session.ID(pk).NoCascade().get(structInter.Interface()) - if err != nil { - return nil, err - } - if has { - fieldValue.Set(structInter.Elem()) - } else { - return nil, errors.New("cascade obj is not exist") - } - } - } - case reflect.Ptr: - // !nashtsai! TODO merge duplicated codes above - switch fieldType { - // following types case matching ptr's native type, therefore assign ptr directly - case schemas.PtrStringType: - if rawValueType.Kind() == reflect.String { - x := vv.String() - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case schemas.PtrBoolType: - if rawValueType.Kind() == reflect.Bool { - x := vv.Bool() - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case schemas.PtrTimeType: - if rawValueType == schemas.PtrTimeType { - hasAssigned = true - var x = rawValue.Interface().(time.Time) - fieldValue.Set(reflect.ValueOf(&x)) - } - case schemas.PtrFloat64Type: - if rawValueType.Kind() == reflect.Float64 { - x := vv.Float() - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case schemas.PtrUint64Type: - if rawValueType.Kind() == reflect.Int64 { - var x = uint64(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case schemas.PtrInt64Type: - if rawValueType.Kind() == reflect.Int64 { - x := vv.Int() - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case schemas.PtrFloat32Type: - if rawValueType.Kind() == reflect.Float64 { - var x = float32(vv.Float()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case schemas.PtrIntType: - if rawValueType.Kind() == reflect.Int64 { - var x = int(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case schemas.PtrInt32Type: - if rawValueType.Kind() == reflect.Int64 { - var x = int32(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case schemas.PtrInt8Type: - if rawValueType.Kind() == reflect.Int64 { - var x = int8(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case schemas.PtrInt16Type: - if rawValueType.Kind() == reflect.Int64 { - var x = int16(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case schemas.PtrUintType: - if rawValueType.Kind() == reflect.Int64 { - var x = uint(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case schemas.PtrUint32Type: - if rawValueType.Kind() == reflect.Int64 { - var x = uint32(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case schemas.Uint8Type: - if rawValueType.Kind() == reflect.Int64 { - var x = uint8(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case schemas.Uint16Type: - if rawValueType.Kind() == reflect.Int64 { - var x = uint16(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case schemas.Complex64Type: - var x complex64 - if len([]byte(vv.String())) > 0 { - err := json.DefaultJSONHandler.Unmarshal([]byte(vv.String()), &x) - if err != nil { - return nil, err - } - fieldValue.Set(reflect.ValueOf(&x)) - } - hasAssigned = true - case schemas.Complex128Type: - var x complex128 - if len([]byte(vv.String())) > 0 { - err := json.DefaultJSONHandler.Unmarshal([]byte(vv.String()), &x) - if err != nil { - return nil, err - } - fieldValue.Set(reflect.ValueOf(&x)) - } - hasAssigned = true - } // switch fieldType - } // switch fieldType.Kind() - - // !nashtsai! for value can't be assigned directly fallback to convert to []byte then back to value - if !hasAssigned { - data, err := value2Bytes(&rawValue) - if err != nil { - return nil, err - } - - if err = session.bytes2Value(col, fieldValue, data); err != nil { - return nil, err - } + pk = append(pk, scanResult) } } return pk, nil From 54bbead2be07a68b5b8caceedb041c14f602cb7a Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Wed, 7 Jul 2021 13:59:48 +0800 Subject: [PATCH 13/31] refactor slice2Bean 2 (#1975) as title. Reviewed-on: https://gitea.com/xorm/xorm/pulls/1975 Co-authored-by: Lunny Xiao Co-committed-by: Lunny Xiao --- session.go | 132 +++++++---------------------------------------------- 1 file changed, 16 insertions(+), 116 deletions(-) diff --git a/session.go b/session.go index 3fb92991..a3b11889 100644 --- a/session.go +++ b/session.go @@ -472,7 +472,6 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec rawValueType := reflect.TypeOf(rawValue.Interface()) vv := reflect.ValueOf(rawValue.Interface()) - fieldType := fieldValue.Type() if col.IsJSON { @@ -508,6 +507,22 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec } switch fieldType.Kind() { + case reflect.Ptr: + if scanResult == nil { + return nil + } + if v, ok := scanResult.(*interface{}); ok && v == nil { + return nil + } + + var e reflect.Value + if fieldValue.IsNil() { + e = reflect.New(fieldType.Elem()).Elem() + } else { + e = fieldValue.Elem() + } + + return session.convertBeanField(col, &e, scanResult, table) case reflect.Complex64, reflect.Complex128: // TODO: reimplement this var bs []byte @@ -702,121 +717,6 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec } return nil } - case reflect.Ptr: - // !nashtsai! TODO merge duplicated codes above - switch fieldType { - // following types case matching ptr's native type, therefore assign ptr directly - case schemas.PtrStringType: - if rawValueType.Kind() == reflect.String { - x := vv.String() - fieldValue.Set(reflect.ValueOf(&x)) - return nil - } - case schemas.PtrBoolType: - if rawValueType.Kind() == reflect.Bool { - x := vv.Bool() - fieldValue.Set(reflect.ValueOf(&x)) - return nil - } - case schemas.PtrTimeType: - if rawValueType == schemas.PtrTimeType { - var x = rawValue.Interface().(time.Time) - fieldValue.Set(reflect.ValueOf(&x)) - return nil - } - case schemas.PtrFloat64Type: - if rawValueType.Kind() == reflect.Float64 { - x := vv.Float() - fieldValue.Set(reflect.ValueOf(&x)) - return nil - } - case schemas.PtrUint64Type: - if rawValueType.Kind() == reflect.Int64 { - var x = uint64(vv.Int()) - fieldValue.Set(reflect.ValueOf(&x)) - return nil - } - case schemas.PtrInt64Type: - if rawValueType.Kind() == reflect.Int64 { - x := vv.Int() - fieldValue.Set(reflect.ValueOf(&x)) - return nil - } - case schemas.PtrFloat32Type: - if rawValueType.Kind() == reflect.Float64 { - var x = float32(vv.Float()) - fieldValue.Set(reflect.ValueOf(&x)) - return nil - } - case schemas.PtrIntType: - if rawValueType.Kind() == reflect.Int64 { - var x = int(vv.Int()) - fieldValue.Set(reflect.ValueOf(&x)) - return nil - } - case schemas.PtrInt32Type: - if rawValueType.Kind() == reflect.Int64 { - var x = int32(vv.Int()) - fieldValue.Set(reflect.ValueOf(&x)) - return nil - } - case schemas.PtrInt8Type: - if rawValueType.Kind() == reflect.Int64 { - var x = int8(vv.Int()) - fieldValue.Set(reflect.ValueOf(&x)) - return nil - } - case schemas.PtrInt16Type: - if rawValueType.Kind() == reflect.Int64 { - var x = int16(vv.Int()) - fieldValue.Set(reflect.ValueOf(&x)) - return nil - } - case schemas.PtrUintType: - if rawValueType.Kind() == reflect.Int64 { - var x = uint(vv.Int()) - fieldValue.Set(reflect.ValueOf(&x)) - return nil - } - case schemas.PtrUint32Type: - if rawValueType.Kind() == reflect.Int64 { - var x = uint32(vv.Int()) - fieldValue.Set(reflect.ValueOf(&x)) - return nil - } - case schemas.Uint8Type: - if rawValueType.Kind() == reflect.Int64 { - var x = uint8(vv.Int()) - fieldValue.Set(reflect.ValueOf(&x)) - return nil - } - case schemas.Uint16Type: - if rawValueType.Kind() == reflect.Int64 { - var x = uint16(vv.Int()) - fieldValue.Set(reflect.ValueOf(&x)) - return nil - } - case schemas.Complex64Type: - var x complex64 - if len([]byte(vv.String())) > 0 { - err := json.DefaultJSONHandler.Unmarshal([]byte(vv.String()), &x) - if err != nil { - return err - } - fieldValue.Set(reflect.ValueOf(&x)) - } - return nil - case schemas.Complex128Type: - var x complex128 - if len([]byte(vv.String())) > 0 { - err := json.DefaultJSONHandler.Unmarshal([]byte(vv.String()), &x) - if err != nil { - return err - } - fieldValue.Set(reflect.ValueOf(&x)) - } - return nil - } // switch fieldType } // switch fieldType.Kind() data, err := value2Bytes(&rawValue) From b754e78269bcd507b117e39d5d7b4064797fa2fc Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Wed, 7 Jul 2021 14:00:16 +0800 Subject: [PATCH 14/31] Support big.Float (#1973) Now you can use big.Float for numeric type. ```go type MyMoney struct { Id int64 Money big.Float `xorm:"numeric(22,2)"` } ``` Reviewed-on: https://gitea.com/xorm/xorm/pulls/1973 Co-authored-by: Lunny Xiao Co-committed-by: Lunny Xiao --- convert.go | 12 ++++++++ dialects/sqlite3.go | 2 +- integrations/session_get_test.go | 51 ++++++++++++++++++++++++++++++ internal/statements/values.go | 5 +++ scan.go | 24 +++++++++------ session_get.go | 53 +++++++++++--------------------- 6 files changed, 102 insertions(+), 45 deletions(-) diff --git a/convert.go b/convert.go index 67183098..491626a8 100644 --- a/convert.go +++ b/convert.go @@ -9,6 +9,7 @@ import ( "database/sql/driver" "errors" "fmt" + "math/big" "reflect" "strconv" "time" @@ -310,10 +311,12 @@ func convertAssign(dest, src interface{}, originalLocation *time.Location, conve if s.Valid { *d, _ = strconv.Atoi(s.String) } + return nil case *int64: if s.Valid { *d, _ = strconv.ParseInt(s.String, 10, 64) } + return nil case *string: if s.Valid { *d = s.String @@ -339,6 +342,15 @@ func convertAssign(dest, src interface{}, originalLocation *time.Location, conve d.Valid = true d.Time = *dt } + return nil + case *big.Float: + if s.Valid { + if d == nil { + d = big.NewFloat(0) + } + d.SetString(s.String) + } + return nil } case *sql.NullInt32: switch d := dest.(type) { diff --git a/dialects/sqlite3.go b/dialects/sqlite3.go index 1bc0b218..04e5b457 100644 --- a/dialects/sqlite3.go +++ b/dialects/sqlite3.go @@ -565,7 +565,7 @@ func (p *sqlite3Driver) GenScanResult(colType string) (interface{}, error) { case "REAL": var s sql.NullFloat64 return &s, nil - case "NUMERIC": + case "NUMERIC", "DECIMAL": var s sql.NullString return &s, nil case "BLOB": diff --git a/integrations/session_get_test.go b/integrations/session_get_test.go index 99db98fc..6fc202bc 100644 --- a/integrations/session_get_test.go +++ b/integrations/session_get_test.go @@ -8,6 +8,7 @@ import ( "database/sql" "errors" "fmt" + "math/big" "strconv" "testing" "time" @@ -766,3 +767,53 @@ func TestGetNil(t *testing.T) { assert.True(t, errors.Is(err, xorm.ErrObjectIsNil)) assert.False(t, has) } + +func TestGetBigFloat(t *testing.T) { + type GetBigFloat struct { + Id int64 + Money *big.Float `xorm:"numeric(22,2)"` + } + + assert.NoError(t, PrepareEngine()) + assertSync(t, new(GetBigFloat)) + + { + var gf = GetBigFloat{ + Money: big.NewFloat(999999.99), + } + _, err := testEngine.Insert(&gf) + assert.NoError(t, err) + + var m big.Float + has, err := testEngine.Table("get_big_float").Cols("money").Where("id=?", gf.Id).Get(&m) + assert.NoError(t, err) + assert.True(t, has) + assert.True(t, m.String() == gf.Money.String(), "%v != %v", m.String(), gf.Money.String()) + //fmt.Println(m.Cmp(gf.Money)) + //assert.True(t, m.Cmp(gf.Money) == 0, "%v != %v", m.String(), gf.Money.String()) + } + + type GetBigFloat2 struct { + Id int64 + Money *big.Float `xorm:"decimal(22,2)"` + } + + assert.NoError(t, PrepareEngine()) + assertSync(t, new(GetBigFloat2)) + + { + var gf2 = GetBigFloat2{ + Money: big.NewFloat(9999999.99), + } + _, err := testEngine.Insert(&gf2) + assert.NoError(t, err) + + var m2 big.Float + has, err := testEngine.Table("get_big_float2").Cols("money").Where("id=?", gf2.Id).Get(&m2) + assert.NoError(t, err) + assert.True(t, has) + assert.True(t, m2.String() == gf2.Money.String(), "%v != %v", m2.String(), gf2.Money.String()) + //fmt.Println(m.Cmp(gf.Money)) + //assert.True(t, m.Cmp(gf.Money) == 0, "%v != %v", m.String(), gf.Money.String()) + } +} diff --git a/internal/statements/values.go b/internal/statements/values.go index 71327c55..994070ac 100644 --- a/internal/statements/values.go +++ b/internal/statements/values.go @@ -8,6 +8,7 @@ import ( "database/sql" "database/sql/driver" "fmt" + "math/big" "reflect" "time" @@ -19,6 +20,7 @@ import ( var ( nullFloatType = reflect.TypeOf(sql.NullFloat64{}) + bigFloatType = reflect.TypeOf(big.Float{}) ) // Value2Interface convert a field value of a struct to interface for puting into database @@ -84,6 +86,9 @@ func (statement *Statement) Value2Interface(col *schemas.Column, fieldValue refl return nil, nil } return t.Float64, nil + } else if fieldType.ConvertibleTo(bigFloatType) { + t := fieldValue.Convert(bigFloatType).Interface().(big.Float) + return t.String(), nil } if !col.IsJSON { diff --git a/scan.go b/scan.go index c5cb77ff..6396b097 100644 --- a/scan.go +++ b/scan.go @@ -7,6 +7,7 @@ package xorm import ( "database/sql" "fmt" + "math/big" "reflect" "time" @@ -182,13 +183,21 @@ func (engine *Engine) scan(rows *core.Rows, fields []string, types []*sql.Column for _, v := range vv { var replaced bool var scanResult interface{} - if _, ok := v.(sql.Scanner); !ok { + switch t := v.(type) { + case sql.Scanner: + scanResult = t + case convert.Conversion: + scanResult = &sql.RawBytes{} + replaced = true + case *big.Float: + scanResult = &sql.NullString{} + replaced = true + default: var useNullable = true if engine.driver.Features().SupportNullable { nullable, ok := types[0].Nullable() useNullable = ok && nullable } - if useNullable { scanResult, replaced, err = genScanResultsByBeanNullable(v) } else { @@ -197,25 +206,22 @@ func (engine *Engine) scan(rows *core.Rows, fields []string, types []*sql.Column if err != nil { return err } - } else { - scanResult = v } + scanResults = append(scanResults, scanResult) replaces = append(replaces, replaced) } - var scanCtx = dialects.ScanContext{ + if err = engine.driver.Scan(&dialects.ScanContext{ DBLocation: engine.DatabaseTZ, UserLocation: engine.TZLocation, - } - - if err = engine.driver.Scan(&scanCtx, rows, types, scanResults...); err != nil { + }, rows, types, scanResults...); err != nil { return err } for i, replaced := range replaces { if replaced { - if err = convertAssign(vv[i], scanResults[i], scanCtx.DBLocation, engine.TZLocation); err != nil { + if err = convertAssign(vv[i], scanResults[i], engine.DatabaseTZ, engine.TZLocation); err != nil { return err } } diff --git a/session_get.go b/session_get.go index cb2bda75..58255033 100644 --- a/session_get.go +++ b/session_get.go @@ -9,6 +9,7 @@ import ( "database/sql/driver" "errors" "fmt" + "math/big" "reflect" "strconv" "time" @@ -123,6 +124,20 @@ var ( conversionType = reflect.TypeOf(&conversionTypePlaceHolder).Elem() ) +func isScannableStruct(bean interface{}, typeLen int) bool { + switch bean.(type) { + case *time.Time: + return false + case sql.Scanner: + return false + case convert.Conversion: + return typeLen > 1 + case *big.Float: + return false + } + return true +} + func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table, bean interface{}, sqlStr string, args ...interface{}) (bool, error) { rows, err := session.queryRows(sqlStr, args...) if err != nil { @@ -148,13 +163,7 @@ func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table, } switch beanKind { case reflect.Struct: - if _, ok := bean.(*time.Time); ok { - break - } - if _, ok := bean.(sql.Scanner); ok { - break - } - if _, ok := bean.(convert.Conversion); len(types) == 1 && ok { + if !isScannableStruct(bean, len(types)) { break } return session.getStruct(rows, types, fields, table, bean) @@ -240,35 +249,9 @@ func (session *Session) getVars(rows *core.Rows, types []*sql.ColumnType, fields if len(beans) != len(types) { return false, fmt.Errorf("expected columns %d, but only %d variables", len(types), len(beans)) } - var scanResults = make([]interface{}, 0, len(types)) - var replaceds = make([]bool, 0, len(types)) - for _, bean := range beans { - switch t := bean.(type) { - case sql.Scanner: - scanResults = append(scanResults, t) - replaceds = append(replaceds, false) - case convert.Conversion: - scanResults = append(scanResults, &sql.RawBytes{}) - replaceds = append(replaceds, true) - default: - scanResults = append(scanResults, bean) - replaceds = append(replaceds, false) - } - } - err := session.engine.scan(rows, fields, types, scanResults...) - if err != nil { - return true, err - } - for i, replaced := range replaceds { - if replaced { - err = convertAssign(beans[i], scanResults[i], session.engine.DatabaseTZ, session.engine.TZLocation) - if err != nil { - return true, err - } - } - } - return true, nil + err := session.engine.scan(rows, fields, types, beans...) + return true, err } func (session *Session) getStruct(rows *core.Rows, types []*sql.ColumnType, fields []string, table *schemas.Table, bean interface{}) (bool, error) { From 46fd8f58b3b925ac7305b879e0c1f4a2fc8ad140 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Wed, 7 Jul 2021 15:46:21 +0800 Subject: [PATCH 15/31] Get struct and Find support big.Float (#1976) Reviewed-on: https://gitea.com/xorm/xorm/pulls/1976 Co-authored-by: Lunny Xiao Co-committed-by: Lunny Xiao --- convert.go | 89 ++++++++++++++++++++++++++++---- integrations/session_get_test.go | 12 +++++ schemas/type.go | 5 +- session.go | 34 ++++++++---- 4 files changed, 120 insertions(+), 20 deletions(-) diff --git a/convert.go b/convert.go index 491626a8..20a6e373 100644 --- a/convert.go +++ b/convert.go @@ -104,9 +104,7 @@ func asInt64(src interface{}) (int64, error) { return rv.Int(), nil case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: return int64(rv.Uint()), nil - case reflect.Float64: - return int64(rv.Float()), nil - case reflect.Float32: + case reflect.Float64, reflect.Float32: return int64(rv.Float()), nil case reflect.String: return strconv.ParseInt(rv.String(), 10, 64) @@ -154,9 +152,7 @@ func asUint64(src interface{}) (uint64, error) { return uint64(rv.Int()), nil case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: return uint64(rv.Uint()), nil - case reflect.Float64: - return uint64(rv.Float()), nil - case reflect.Float32: + case reflect.Float64, reflect.Float32: return uint64(rv.Float()), nil case reflect.String: return strconv.ParseUint(rv.String(), 10, 64) @@ -204,9 +200,7 @@ func asFloat64(src interface{}) (float64, error) { return float64(rv.Int()), nil case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: return float64(rv.Uint()), nil - case reflect.Float64: - return float64(rv.Float()), nil - case reflect.Float32: + case reflect.Float64, reflect.Float32: return float64(rv.Float()), nil case reflect.String: return strconv.ParseFloat(rv.String(), 64) @@ -214,6 +208,83 @@ func asFloat64(src interface{}) (float64, error) { return 0, fmt.Errorf("unsupported value %T as int64", src) } +func asBigFloat(src interface{}) (*big.Float, error) { + res := big.NewFloat(0) + switch v := src.(type) { + case int: + res.SetInt64(int64(v)) + return res, nil + case int16: + res.SetInt64(int64(v)) + return res, nil + case int32: + res.SetInt64(int64(v)) + return res, nil + case int8: + res.SetInt64(int64(v)) + return res, nil + case int64: + res.SetInt64(int64(v)) + return res, nil + case uint: + res.SetUint64(uint64(v)) + return res, nil + case uint8: + res.SetUint64(uint64(v)) + return res, nil + case uint16: + res.SetUint64(uint64(v)) + return res, nil + case uint32: + res.SetUint64(uint64(v)) + return res, nil + case uint64: + res.SetUint64(uint64(v)) + return res, nil + case []byte: + res.SetString(string(v)) + return res, nil + case string: + res.SetString(v) + return res, nil + case *sql.NullString: + if v.Valid { + res.SetString(v.String) + return res, nil + } + return nil, nil + case *sql.NullInt32: + if v.Valid { + res.SetInt64(int64(v.Int32)) + return res, nil + } + return nil, nil + case *sql.NullInt64: + if v.Valid { + res.SetInt64(int64(v.Int64)) + return res, nil + } + return nil, nil + } + + rv := reflect.ValueOf(src) + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + res.SetInt64(rv.Int()) + return res, nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + res.SetUint64(rv.Uint()) + return res, nil + case reflect.Float64, reflect.Float32: + res.SetFloat64(rv.Float()) + return res, nil + case reflect.String: + res.SetString(rv.String()) + return res, nil + } + return nil, fmt.Errorf("unsupported value %T as big.Float", src) +} + func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) { switch rv.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: diff --git a/integrations/session_get_test.go b/integrations/session_get_test.go index 6fc202bc..02b060b1 100644 --- a/integrations/session_get_test.go +++ b/integrations/session_get_test.go @@ -815,5 +815,17 @@ func TestGetBigFloat(t *testing.T) { assert.True(t, m2.String() == gf2.Money.String(), "%v != %v", m2.String(), gf2.Money.String()) //fmt.Println(m.Cmp(gf.Money)) //assert.True(t, m.Cmp(gf.Money) == 0, "%v != %v", m.String(), gf.Money.String()) + + var gf3 GetBigFloat2 + has, err = testEngine.ID(gf2.Id).Get(&gf3) + assert.NoError(t, err) + assert.True(t, has) + assert.True(t, gf3.Money.String() == gf2.Money.String(), "%v != %v", gf3.Money.String(), gf2.Money.String()) + + var gfs []GetBigFloat2 + err = testEngine.Find(&gfs) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(gfs)) + assert.True(t, gfs[0].Money.String() == gf2.Money.String(), "%v != %v", gfs[0].Money.String(), gf2.Money.String()) } } diff --git a/schemas/type.go b/schemas/type.go index fc02f015..3846b5ee 100644 --- a/schemas/type.go +++ b/schemas/type.go @@ -5,6 +5,7 @@ package schemas import ( + "math/big" "reflect" "sort" "strings" @@ -240,6 +241,7 @@ var ( intDefault int uintDefault uint timeDefault time.Time + bigFloatDefault big.Float ) // enumerates all types @@ -267,7 +269,8 @@ var ( ByteType = reflect.TypeOf(byteDefault) BytesType = reflect.SliceOf(ByteType) - TimeType = reflect.TypeOf(timeDefault) + TimeType = reflect.TypeOf(timeDefault) + BigFloatType = reflect.TypeOf(bigFloatDefault) ) // enumerates all types diff --git a/session.go b/session.go index a3b11889..64b1758a 100644 --- a/session.go +++ b/session.go @@ -438,8 +438,15 @@ func (session *Session) row2Slice(rows *core.Rows, fields []string, bean interfa func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflect.Value, scanResult interface{}, table *schemas.Table) error { - rawValue := reflect.Indirect(reflect.ValueOf(scanResult)) + v, ok := scanResult.(*interface{}) + if ok { + scanResult = *v + } + if scanResult == nil { + return nil + } + rawValue := reflect.Indirect(reflect.ValueOf(scanResult)) // if row is null then ignore if rawValue.Interface() == nil { return nil @@ -508,21 +515,19 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec switch fieldType.Kind() { case reflect.Ptr: - if scanResult == nil { - return nil - } - if v, ok := scanResult.(*interface{}); ok && v == nil { - return nil - } - var e reflect.Value if fieldValue.IsNil() { e = reflect.New(fieldType.Elem()).Elem() } else { e = fieldValue.Elem() } - - return session.convertBeanField(col, &e, scanResult, table) + if err := session.convertBeanField(col, &e, scanResult, table); err != nil { + return err + } + if fieldValue.IsNil() { + fieldValue.Set(e.Addr()) + } + return nil case reflect.Complex64, reflect.Complex128: // TODO: reimplement this var bs []byte @@ -610,6 +615,15 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec return nil } case reflect.Struct: + if fieldType.ConvertibleTo(schemas.BigFloatType) { + v, err := asBigFloat(scanResult) + if err != nil { + return err + } + fieldValue.Set(reflect.ValueOf(v).Elem().Convert(fieldType)) + return nil + } + if fieldType.ConvertibleTo(schemas.TimeType) { dbTZ := session.engine.DatabaseTZ if col.TimeZone != nil { From a38b1fddb33050b8378fb7ba292185935ba89a76 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Wed, 7 Jul 2021 17:00:58 +0800 Subject: [PATCH 16/31] Add tests for github.com/shopspring/decimal support (#1977) Reviewed-on: https://gitea.com/xorm/xorm/pulls/1977 Co-authored-by: Lunny Xiao Co-committed-by: Lunny Xiao --- go.mod | 1 + go.sum | 2 ++ integrations/session_get_test.go | 52 ++++++++++++++++++++++++++++++++ 3 files changed, 55 insertions(+) diff --git a/go.mod b/go.mod index f6e4af90..78d8d7d4 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( github.com/json-iterator/go v1.1.11 github.com/lib/pq v1.7.0 github.com/mattn/go-sqlite3 v1.14.6 + github.com/shopspring/decimal v1.2.0 github.com/stretchr/testify v1.4.0 github.com/syndtr/goleveldb v1.0.0 github.com/ziutek/mymysql v1.5.4 diff --git a/go.sum b/go.sum index 3c79850c..85953202 100644 --- a/go.sum +++ b/go.sum @@ -45,6 +45,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0 h1:OdAsTTz6OkFY5QxjkYwrChwuRruF69c169dPK26NUlk= github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +github.com/shopspring/decimal v1.2.0 h1:abSATXmQEYyShuxI4/vyW3tV1MrKAJzCZ/0zLUXYbsQ= +github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= diff --git a/integrations/session_get_test.go b/integrations/session_get_test.go index 02b060b1..f4338b4f 100644 --- a/integrations/session_get_test.go +++ b/integrations/session_get_test.go @@ -17,6 +17,7 @@ import ( "xorm.io/xorm/contexts" "xorm.io/xorm/schemas" + "github.com/shopspring/decimal" "github.com/stretchr/testify/assert" ) @@ -829,3 +830,54 @@ func TestGetBigFloat(t *testing.T) { assert.True(t, gfs[0].Money.String() == gf2.Money.String(), "%v != %v", gfs[0].Money.String(), gf2.Money.String()) } } + +func TestGetDecimal(t *testing.T) { + type GetDecimal struct { + Id int64 + Money decimal.Decimal `xorm:"decimal(22,2)"` + } + + assert.NoError(t, PrepareEngine()) + assertSync(t, new(GetDecimal)) + + { + var gf = GetDecimal{ + Money: decimal.NewFromFloat(999999.99), + } + _, err := testEngine.Insert(&gf) + assert.NoError(t, err) + + var m decimal.Decimal + has, err := testEngine.Table("get_decimal").Cols("money").Where("id=?", gf.Id).Get(&m) + assert.NoError(t, err) + assert.True(t, has) + assert.True(t, m.String() == gf.Money.String(), "%v != %v", m.String(), gf.Money.String()) + //fmt.Println(m.Cmp(gf.Money)) + //assert.True(t, m.Cmp(gf.Money) == 0, "%v != %v", m.String(), gf.Money.String()) + } + + type GetDecimal2 struct { + Id int64 + Money *decimal.Decimal `xorm:"decimal(22,2)"` + } + + assert.NoError(t, PrepareEngine()) + assertSync(t, new(GetDecimal2)) + + { + v := decimal.NewFromFloat(999999.99) + var gf = GetDecimal2{ + Money: &v, + } + _, err := testEngine.Insert(&gf) + assert.NoError(t, err) + + var m decimal.Decimal + has, err := testEngine.Table("get_decimal2").Cols("money").Where("id=?", gf.Id).Get(&m) + assert.NoError(t, err) + assert.True(t, has) + assert.True(t, m.String() == gf.Money.String(), "%v != %v", m.String(), gf.Money.String()) + //fmt.Println(m.Cmp(gf.Money)) + //assert.True(t, m.Cmp(gf.Money) == 0, "%v != %v", m.String(), gf.Money.String()) + } +} From 717e4a0d2177e14d02a3cd74aa4bf5e9f88a3731 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Wed, 7 Jul 2021 17:09:40 +0800 Subject: [PATCH 17/31] Add database alias table and fix wrong warning (#1947) fix #1831 Reviewed-on: https://gitea.com/xorm/xorm/pulls/1947 Co-authored-by: Lunny Xiao Co-committed-by: Lunny Xiao --- dialects/dialect.go | 6 ++++++ dialects/mysql.go | 15 +++++++++++++++ dialects/postgres.go | 12 ++++++++++++ integrations/session_schema_test.go | 13 +++++++++++++ schemas/type.go | 6 ++++++ session_schema.go | 6 ++++-- 6 files changed, 56 insertions(+), 2 deletions(-) diff --git a/dialects/dialect.go b/dialects/dialect.go index 325836b4..b3d374cc 100644 --- a/dialects/dialect.go +++ b/dialects/dialect.go @@ -43,6 +43,7 @@ type Dialect interface { Init(*URI) error URI() *URI SQLType(*schemas.Column) string + Alias(string) string // return what a sql type's alias of FormatBytes(b []byte) string Version(ctx context.Context, queryer core.Queryer) (*schemas.Version, error) @@ -80,6 +81,11 @@ type Base struct { quoter schemas.Quoter } +// Alias returned col itself +func (db *Base) Alias(col string) string { + return col +} + // Quoter returns the current database Quoter func (db *Base) Quoter() schemas.Quoter { return db.quoter diff --git a/dialects/mysql.go b/dialects/mysql.go index a341ce05..da19b820 100644 --- a/dialects/mysql.go +++ b/dialects/mysql.go @@ -190,6 +190,21 @@ func (db *mysql) Init(uri *URI) error { return db.Base.Init(db, uri) } +var ( + mysqlColAliases = map[string]string{ + "numeric": "decimal", + } +) + +// Alias returns a alias of column +func (db *mysql) Alias(col string) string { + v, ok := mysqlColAliases[strings.ToLower(col)] + if ok { + return v + } + return col +} + func (db *mysql) Version(ctx context.Context, queryer core.Queryer) (*schemas.Version, error) { rows, err := queryer.QueryContext(ctx, "SELECT @@VERSION") if err != nil { diff --git a/dialects/postgres.go b/dialects/postgres.go index fd6d871c..9f3c7275 100644 --- a/dialects/postgres.go +++ b/dialects/postgres.go @@ -778,12 +778,24 @@ var ( var ( // DefaultPostgresSchema default postgres schema DefaultPostgresSchema = "public" + postgresColAliases = map[string]string{ + "numeric": "decimal", + } ) type postgres struct { Base } +// Alias returns a alias of column +func (db *postgres) Alias(col string) string { + v, ok := postgresColAliases[strings.ToLower(col)] + if ok { + return v + } + return col +} + func (db *postgres) Init(uri *URI) error { db.quoter = postgresQuoter return db.Base.Init(db, uri) diff --git a/integrations/session_schema_test.go b/integrations/session_schema_test.go index 28c75119..9cbebcbf 100644 --- a/integrations/session_schema_test.go +++ b/integrations/session_schema_test.go @@ -286,6 +286,19 @@ func TestSyncTable3(t *testing.T) { } } +func TestSyncTable4(t *testing.T) { + type SyncTable6 struct { + Id int64 + Qty float64 `xorm:"numeric(36,2)"` + } + + assert.NoError(t, PrepareEngine()) + + assert.NoError(t, testEngine.Sync2(new(SyncTable6))) + + assert.NoError(t, testEngine.Sync2(new(SyncTable6))) +} + func TestIsTableExist(t *testing.T) { assert.NoError(t, PrepareEngine()) diff --git a/schemas/type.go b/schemas/type.go index 3846b5ee..f49348be 100644 --- a/schemas/type.go +++ b/schemas/type.go @@ -368,3 +368,9 @@ func SQLType2Type(st SQLType) reflect.Type { return reflect.TypeOf("") } } + +// SQLTypeName returns sql type name +func SQLTypeName(tp string) string { + fields := strings.Split(tp, "(") + return fields[0] +} diff --git a/session_schema.go b/session_schema.go index 7d36ae7f..7cfcb626 100644 --- a/session_schema.go +++ b/session_schema.go @@ -336,8 +336,10 @@ func (session *Session) Sync2(beans ...interface{}) error { } } else { if !(strings.HasPrefix(curType, expectedType) && curType[len(expectedType)] == '(') { - engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s", - tbNameWithSchema, col.Name, curType, expectedType) + if !strings.EqualFold(schemas.SQLTypeName(curType), engine.dialect.Alias(schemas.SQLTypeName(expectedType))) { + engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s", + tbNameWithSchema, col.Name, curType, expectedType) + } } } } else if expectedType == schemas.Varchar { From 375857b4bee1a1ea9ce1ca5e672edf13de497640 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Wed, 7 Jul 2021 18:17:47 +0800 Subject: [PATCH 18/31] Add benchmark tests (#1978) Reviewed-on: https://gitea.com/xorm/xorm/pulls/1978 Co-authored-by: Lunny Xiao Co-committed-by: Lunny Xiao --- integrations/performance_test.go | 104 +++++++++++++++++++++++++++++++ 1 file changed, 104 insertions(+) create mode 100644 integrations/performance_test.go diff --git a/integrations/performance_test.go b/integrations/performance_test.go new file mode 100644 index 00000000..4b54b40c --- /dev/null +++ b/integrations/performance_test.go @@ -0,0 +1,104 @@ +// Copyright 2021 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package integrations + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func BenchmarkGetVars(b *testing.B) { + b.StopTimer() + + assert.NoError(b, PrepareEngine()) + testEngine.ShowSQL(false) + + type BenchmarkGetVars struct { + Id int64 + Name string + } + + assert.NoError(b, testEngine.Sync2(new(BenchmarkGetVars))) + + var v = BenchmarkGetVars{ + Name: "myname", + } + _, err := testEngine.Insert(&v) + assert.NoError(b, err) + + b.StartTimer() + var myname string + for i := 0; i < b.N; i++ { + has, err := testEngine.Cols("name").Table("benchmark_get_vars").Where("id=?", v.Id).Get(&myname) + b.StopTimer() + myname = "" + assert.True(b, has) + assert.NoError(b, err) + b.StartTimer() + } +} + +func BenchmarkGetStruct(b *testing.B) { + b.StopTimer() + + assert.NoError(b, PrepareEngine()) + testEngine.ShowSQL(false) + + type BenchmarkGetStruct struct { + Id int64 + Name string + } + + assert.NoError(b, testEngine.Sync2(new(BenchmarkGetStruct))) + + var v = BenchmarkGetStruct{ + Name: "myname", + } + _, err := testEngine.Insert(&v) + assert.NoError(b, err) + + b.StartTimer() + var myname BenchmarkGetStruct + for i := 0; i < b.N; i++ { + has, err := testEngine.ID(v.Id).Get(&myname) + b.StopTimer() + myname.Id = 0 + myname.Name = "" + assert.True(b, has) + assert.NoError(b, err) + b.StartTimer() + } +} + +func BenchmarkFindStruct(b *testing.B) { + b.StopTimer() + + assert.NoError(b, PrepareEngine()) + testEngine.ShowSQL(false) + + type BenchmarkFindStruct struct { + Id int64 + Name string + } + + assert.NoError(b, testEngine.Sync2(new(BenchmarkFindStruct))) + + var v = BenchmarkFindStruct{ + Name: "myname", + } + _, err := testEngine.Insert(&v) + assert.NoError(b, err) + + b.StartTimer() + var mynames = make([]BenchmarkFindStruct, 0, 1) + for i := 0; i < b.N; i++ { + err := testEngine.Find(&mynames) + b.StopTimer() + mynames = make([]BenchmarkFindStruct, 0, 1) + assert.NoError(b, err) + b.StartTimer() + } +} From 27b1736c57c6243cd93e79bcce8c2bf060ba2890 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Wed, 7 Jul 2021 19:16:36 +0800 Subject: [PATCH 19/31] Add test for get map with NULL column (#1948) Add tests for #1824 Reviewed-on: https://gitea.com/xorm/xorm/pulls/1948 Co-authored-by: Lunny Xiao Co-committed-by: Lunny Xiao --- integrations/session_get_test.go | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/integrations/session_get_test.go b/integrations/session_get_test.go index f4338b4f..9f82ce73 100644 --- a/integrations/session_get_test.go +++ b/integrations/session_get_test.go @@ -347,6 +347,29 @@ func TestGetSlice(t *testing.T) { assert.Error(t, err) } +func TestGetMap(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type UserinfoMap struct { + Uid int `xorm:"pk autoincr"` + IsMan bool + } + + assertSync(t, new(UserinfoMap)) + + tableName := testEngine.Quote(testEngine.TableName("userinfo_map", true)) + _, err := testEngine.Exec(fmt.Sprintf("INSERT INTO %s (is_man) VALUES (NULL)", tableName)) + assert.NoError(t, err) + + var valuesString = make(map[string]string) + has, err := testEngine.Table("userinfo_map").Get(&valuesString) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, 2, len(valuesString)) + assert.Equal(t, "1", valuesString["uid"]) + assert.Equal(t, "", valuesString["is_man"]) +} + func TestGetError(t *testing.T) { assert.NoError(t, PrepareEngine()) From dbd45f3f8e0f2d32f3932a2a23530ccaeb611d4e Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Sat, 10 Jul 2021 23:27:55 +0800 Subject: [PATCH 20/31] set test timeout 20m (#1985) Reviewed-on: https://gitea.com/xorm/xorm/pulls/1985 Co-authored-by: Lunny Xiao Co-committed-by: Lunny Xiao --- Makefile | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/Makefile b/Makefile index bf71b0f4..1bdd44c9 100644 --- a/Makefile +++ b/Makefile @@ -138,7 +138,7 @@ test: go-check test-cockroach: go-check $(GO) test $(INTEGRATION_PACKAGES) -v -race -db=postgres -schema='$(TEST_COCKROACH_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \ -conn_str="postgres://$(TEST_COCKROACH_USERNAME):$(TEST_COCKROACH_PASSWORD)@$(TEST_COCKROACH_HOST)/$(TEST_COCKROACH_DBNAME)?sslmode=disable&experimental_serial_normalization=sql_sequence" \ - -ignore_update_limit=true -coverprofile=cockroach.$(TEST_COCKROACH_SCHEMA).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic + -ignore_update_limit=true -coverprofile=cockroach.$(TEST_COCKROACH_SCHEMA).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic -timeout=20m .PHONY: test-cockroach\#% test-cockroach\#%: go-check @@ -152,7 +152,7 @@ test-mssql: go-check -conn_str="server=$(TEST_MSSQL_HOST);user id=$(TEST_MSSQL_USERNAME);password=$(TEST_MSSQL_PASSWORD);database=$(TEST_MSSQL_DBNAME)" \ -default_varchar=$(TEST_MSSQL_DEFAULT_VARCHAR) -default_char=$(TEST_MSSQL_DEFAULT_CHAR) \ -do_nvarchar_override_test=$(TEST_MSSQL_DO_NVARCHAR_OVERRIDE_TEST) \ - -coverprofile=mssql.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic + -coverprofile=mssql.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic -timeout=20m .PNONY: test-mssql\#% test-mssql\#%: go-check @@ -166,7 +166,7 @@ test-mssql\#%: go-check test-mymysql: go-check $(GO) test $(INTEGRATION_PACKAGES) -v -race -db=mymysql -cache=$(TEST_CACHE_ENABLE) -quote=$(TEST_QUOTE_POLICY) \ -conn_str="tcp:$(TEST_MYSQL_HOST)*$(TEST_MYSQL_DBNAME)/$(TEST_MYSQL_USERNAME)/$(TEST_MYSQL_PASSWORD)" \ - -coverprofile=mymysql.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic + -coverprofile=mymysql.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic -timeout=20m .PNONY: test-mymysql\#% test-mymysql\#%: go-check @@ -178,7 +178,7 @@ test-mymysql\#%: go-check test-mysql: go-check $(GO) test $(INTEGRATION_PACKAGES) -v -race -db=mysql -cache=$(TEST_CACHE_ENABLE) -quote=$(TEST_QUOTE_POLICY) \ -conn_str="$(TEST_MYSQL_USERNAME):$(TEST_MYSQL_PASSWORD)@tcp($(TEST_MYSQL_HOST))/$(TEST_MYSQL_DBNAME)?charset=$(TEST_MYSQL_CHARSET)" \ - -coverprofile=mysql.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic + -coverprofile=mysql.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic -timeout=20m .PHONY: test-mysql\#% test-mysql\#%: go-check @@ -190,7 +190,7 @@ test-mysql\#%: go-check test-postgres: go-check $(GO) test $(INTEGRATION_PACKAGES) -v -race -db=postgres -schema='$(TEST_PGSQL_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \ -conn_str="postgres://$(TEST_PGSQL_USERNAME):$(TEST_PGSQL_PASSWORD)@$(TEST_PGSQL_HOST)/$(TEST_PGSQL_DBNAME)?sslmode=disable" \ - -quote=$(TEST_QUOTE_POLICY) -coverprofile=postgres.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic + -quote=$(TEST_QUOTE_POLICY) -coverprofile=postgres.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic -timeout=20m .PHONY: test-postgres\#% test-postgres\#%: go-check @@ -201,12 +201,12 @@ test-postgres\#%: go-check .PHONY: test-sqlite3 test-sqlite3: go-check $(GO) test $(INTEGRATION_PACKAGES) -v -race -cache=$(TEST_CACHE_ENABLE) -db=sqlite3 -conn_str="./test.db?cache=shared&mode=rwc" \ - -quote=$(TEST_QUOTE_POLICY) -coverprofile=sqlite3.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic + -quote=$(TEST_QUOTE_POLICY) -coverprofile=sqlite3.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic -timeout=20m .PHONY: test-sqlite3-schema test-sqlite3-schema: go-check $(GO) test $(INTEGRATION_PACKAGES) -v -race -schema=xorm -cache=$(TEST_CACHE_ENABLE) -db=sqlite3 -conn_str="./test.db?cache=shared&mode=rwc" \ - -quote=$(TEST_QUOTE_POLICY) -coverprofile=sqlite3.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic + -quote=$(TEST_QUOTE_POLICY) -coverprofile=sqlite3.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic -timeout=20m .PHONY: test-sqlite3\#% test-sqlite3\#%: go-check @@ -216,12 +216,12 @@ test-sqlite3\#%: go-check .PHONY: test-sqlite test-sqlite: go-check $(GO) test $(INTEGRATION_PACKAGES) -v -race -cache=$(TEST_CACHE_ENABLE) -db=sqlite -conn_str="./test.db?cache=shared&mode=rwc" \ - -quote=$(TEST_QUOTE_POLICY) -coverprofile=sqlite.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic + -quote=$(TEST_QUOTE_POLICY) -coverprofile=sqlite.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic -timeout=20m .PHONY: test-sqlite-schema test-sqlite-schema: go-check $(GO) test $(INTEGRATION_PACKAGES) -v -race -schema=xorm -cache=$(TEST_CACHE_ENABLE) -db=sqlite -conn_str="./test.db?cache=shared&mode=rwc" \ - -quote=$(TEST_QUOTE_POLICY) -coverprofile=sqlite.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic + -quote=$(TEST_QUOTE_POLICY) -coverprofile=sqlite.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic -timeout=20m .PHONY: test-sqlite\#% test-sqlite\#%: go-check @@ -233,7 +233,7 @@ test-sqlite\#%: go-check test-tidb: go-check $(GO) test $(INTEGRATION_PACKAGES) -v -race -db=mysql -cache=$(TEST_CACHE_ENABLE) -ignore_select_update=true \ -conn_str="$(TEST_TIDB_USERNAME):$(TEST_TIDB_PASSWORD)@tcp($(TEST_TIDB_HOST))/$(TEST_TIDB_DBNAME)" \ - -quote=$(TEST_QUOTE_POLICY) -coverprofile=tidb.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic + -quote=$(TEST_QUOTE_POLICY) -coverprofile=tidb.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic -timeout=20m .PHONY: test-tidb\#% test-tidb\#%: go-check From 6f46e684259937a5233af527fc2bc5000260190e Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Sun, 11 Jul 2021 09:30:33 +0800 Subject: [PATCH 21/31] Support Get time.Time (#1933) Fix #1107 Reviewed-on: https://gitea.com/xorm/xorm/pulls/1933 Co-authored-by: Lunny Xiao Co-committed-by: Lunny Xiao --- integrations/session_get_test.go | 21 +++++++++++++++++++++ scan.go | 8 +++++--- session_get.go | 15 ++++++++++++++- 3 files changed, 40 insertions(+), 4 deletions(-) diff --git a/integrations/session_get_test.go b/integrations/session_get_test.go index 9f82ce73..ca894d59 100644 --- a/integrations/session_get_test.go +++ b/integrations/session_get_test.go @@ -904,3 +904,24 @@ func TestGetDecimal(t *testing.T) { //assert.True(t, m.Cmp(gf.Money) == 0, "%v != %v", m.String(), gf.Money.String()) } } +func TestGetTime(t *testing.T) { + type GetTimeStruct struct { + Id int64 + CreateTime time.Time + } + + assert.NoError(t, PrepareEngine()) + assertSync(t, new(GetTimeStruct)) + + var gts = GetTimeStruct{ + CreateTime: time.Now(), + } + _, err := testEngine.Insert(>s) + assert.NoError(t, err) + + var gn time.Time + has, err := testEngine.Table("get_time_struct").Cols(colMapper.Obj2Table("CreateTime")).Get(&gn) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, gts.CreateTime.Format(time.RFC3339), gn.Format(time.RFC3339)) +} diff --git a/scan.go b/scan.go index 6396b097..d668208a 100644 --- a/scan.go +++ b/scan.go @@ -22,7 +22,9 @@ func genScanResultsByBeanNullable(bean interface{}) (interface{}, bool, error) { case *sql.NullInt64, *sql.NullBool, *sql.NullFloat64, *sql.NullString, *sql.RawBytes: return t, false, nil case *time.Time: - return &sql.NullTime{}, true, nil + return &sql.NullString{}, true, nil + case *sql.NullTime: + return &sql.NullString{}, true, nil case *string: return &sql.NullString{}, true, nil case *int, *int8, *int16, *int32: @@ -75,8 +77,8 @@ func genScanResultsByBean(bean interface{}) (interface{}, bool, error) { *float32, *float64, *bool: return t, false, nil - case *time.Time: - return &sql.NullTime{}, true, nil + case *time.Time, *sql.NullTime: + return &sql.NullString{}, true, nil case sql.NullInt64, sql.NullBool, sql.NullFloat64, sql.NullString, time.Time, string, diff --git a/session_get.go b/session_get.go index 58255033..f710a0b1 100644 --- a/session_get.go +++ b/session_get.go @@ -35,6 +35,19 @@ func (session *Session) Get(bean interface{}) (bool, error) { return session.get(bean) } +func isPtrOfTime(v interface{}) bool { + if _, ok := v.(*time.Time); ok { + return true + } + + el := reflect.ValueOf(v).Elem() + if el.Kind() != reflect.Struct { + return false + } + + return el.Type().ConvertibleTo(schemas.TimeType) +} + func (session *Session) get(bean interface{}) (bool, error) { defer session.resetStatement() @@ -51,7 +64,7 @@ func (session *Session) get(bean interface{}) (bool, error) { return false, ErrObjectIsNil } - if beanValue.Elem().Kind() == reflect.Struct { + if beanValue.Elem().Kind() == reflect.Struct && !isPtrOfTime(bean) { if err := session.statement.SetRefBean(bean); err != nil { return false, err } From 8bf97de140c8af4ba4f61a9275a3ecb5b56d28f6 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Sun, 11 Jul 2021 20:05:43 +0800 Subject: [PATCH 22/31] Fix bug on dumptable (#1984) Fix #1983 Reviewed-on: https://gitea.com/xorm/xorm/pulls/1984 Co-authored-by: Lunny Xiao Co-committed-by: Lunny Xiao --- convert.go | 1 - convert/time.go | 9 +- convert/time_test.go | 30 +++++ engine.go | 212 +++++++++++------------------------- integrations/engine_test.go | 13 +++ schemas/type.go | 8 +- 6 files changed, 121 insertions(+), 152 deletions(-) create mode 100644 convert/time_test.go diff --git a/convert.go b/convert.go index 20a6e373..69277734 100644 --- a/convert.go +++ b/convert.go @@ -348,7 +348,6 @@ func convertAssign(dest, src interface{}, originalLocation *time.Location, conve *d = cloneBytes(s) return nil } - case time.Time: switch d := dest.(type) { case *string: diff --git a/convert/time.go b/convert/time.go index 8901279b..696b301c 100644 --- a/convert/time.go +++ b/convert/time.go @@ -19,7 +19,14 @@ func String2Time(s string, originalLocation *time.Location, convertedLocation *t dt = dt.In(convertedLocation) return &dt, nil } else if len(s) == 20 && s[10] == 'T' && s[19] == 'Z' { - dt, err := time.ParseInLocation("2006-01-02T15:04:05Z", s, originalLocation) + dt, err := time.ParseInLocation(time.RFC3339, s, originalLocation) + if err != nil { + return nil, err + } + dt = dt.In(convertedLocation) + return &dt, nil + } else if len(s) == 25 && s[10] == 'T' && s[19] == '+' && s[22] == ':' { + dt, err := time.Parse(time.RFC3339, s) if err != nil { return nil, err } diff --git a/convert/time_test.go b/convert/time_test.go new file mode 100644 index 00000000..ef01b362 --- /dev/null +++ b/convert/time_test.go @@ -0,0 +1,30 @@ +// Copyright 2021 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package convert + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestString2Time(t *testing.T) { + expectedLoc, err := time.LoadLocation("Asia/Shanghai") + assert.NoError(t, err) + + var kases = map[string]time.Time{ + "2021-06-06T22:58:20+08:00": time.Date(2021, 6, 6, 22, 58, 20, 0, expectedLoc), + "2021-07-11 10:44:00": time.Date(2021, 7, 11, 18, 44, 0, 0, expectedLoc), + "2021-08-10T10:33:04Z": time.Date(2021, 8, 10, 18, 33, 04, 0, expectedLoc), + } + for layout, tm := range kases { + t.Run(layout, func(t *testing.T) { + target, err := String2Time(layout, time.UTC, expectedLoc) + assert.NoError(t, err) + assert.EqualValues(t, tm, *target) + }) + } +} diff --git a/engine.go b/engine.go index a45771a2..d3ee8a8c 100644 --- a/engine.go +++ b/engine.go @@ -13,7 +13,6 @@ import ( "os" "reflect" "runtime" - "strconv" "strings" "time" @@ -21,7 +20,6 @@ import ( "xorm.io/xorm/contexts" "xorm.io/xorm/core" "xorm.io/xorm/dialects" - "xorm.io/xorm/internal/json" "xorm.io/xorm/internal/utils" "xorm.io/xorm/log" "xorm.io/xorm/names" @@ -446,93 +444,14 @@ func (engine *Engine) DumpTables(tables []*schemas.Table, w io.Writer, tp ...sch return engine.dumpTables(tables, w, tp...) } -func formatColumnValue(dbLocation *time.Location, dstDialect dialects.Dialect, d interface{}, col *schemas.Column) string { - if d == nil { - return "NULL" - } - - if dq, ok := d.(bool); ok && (dstDialect.URI().DBType == schemas.SQLITE || - dstDialect.URI().DBType == schemas.MSSQL) { - if dq { +func formatBool(s string, dstDialect dialects.Dialect) string { + if dstDialect.URI().DBType == schemas.MSSQL { + switch s { + case "true": return "1" + case "false": + return "0" } - return "0" - } - - if col.SQLType.IsText() { - var v string - switch reflect.TypeOf(d).Kind() { - case reflect.Struct, reflect.Array, reflect.Slice, reflect.Map: - bytes, err := json.DefaultJSONHandler.Marshal(d) - if err != nil { - v = fmt.Sprintf("%s", d) - } else { - v = string(bytes) - } - default: - v = fmt.Sprintf("%s", d) - } - - return "'" + strings.Replace(v, "'", "''", -1) + "'" - } else if col.SQLType.IsTime() { - if t, ok := d.(time.Time); ok { - return "'" + t.In(dbLocation).Format("2006-01-02 15:04:05") + "'" - } - var v = fmt.Sprintf("%s", d) - if strings.HasSuffix(v, " +0000 UTC") { - return fmt.Sprintf("'%s'", v[0:len(v)-len(" +0000 UTC")]) - } else if strings.HasSuffix(v, " +0000 +0000") { - return fmt.Sprintf("'%s'", v[0:len(v)-len(" +0000 +0000")]) - } - return "'" + strings.Replace(v, "'", "''", -1) + "'" - } else if col.SQLType.IsBlob() { - if reflect.TypeOf(d).Kind() == reflect.Slice { - return fmt.Sprintf("%s", dstDialect.FormatBytes(d.([]byte))) - } else if reflect.TypeOf(d).Kind() == reflect.String { - return fmt.Sprintf("'%s'", d.(string)) - } - } else if col.SQLType.IsNumeric() { - switch reflect.TypeOf(d).Kind() { - case reflect.Slice: - if col.SQLType.Name == schemas.Bool { - return fmt.Sprintf("%v", strconv.FormatBool(d.([]byte)[0] != byte('0'))) - } - return fmt.Sprintf("%s", string(d.([]byte))) - case reflect.Int16, reflect.Int8, reflect.Int32, reflect.Int64, reflect.Int: - if col.SQLType.Name == schemas.Bool { - v := reflect.ValueOf(d).Int() > 0 - if dstDialect.URI().DBType == schemas.SQLITE { - if v { - return "1" - } - return "0" - } - return fmt.Sprintf("%v", strconv.FormatBool(v)) - } - return fmt.Sprintf("%d", d) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - if col.SQLType.Name == schemas.Bool { - v := reflect.ValueOf(d).Uint() > 0 - if dstDialect.URI().DBType == schemas.SQLITE { - if v { - return "1" - } - return "0" - } - return fmt.Sprintf("%v", strconv.FormatBool(v)) - } - return fmt.Sprintf("%d", d) - default: - return fmt.Sprintf("%v", d) - } - } - - s := fmt.Sprintf("%v", d) - if strings.Contains(s, ":") || strings.Contains(s, "-") { - if strings.HasSuffix(s, " +0000 UTC") { - return fmt.Sprintf("'%s'", s[0:len(s)-len(" +0000 UTC")]) - } - return fmt.Sprintf("'%s'", s) } return s } @@ -545,7 +464,7 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch } else { dstDialect = dialects.QueryDialect(tp[0]) if dstDialect == nil { - return errors.New("Unsupported database type") + return fmt.Errorf("unsupported database type %v", tp[0]) } uri := engine.dialect.URI() @@ -619,73 +538,68 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch } defer rows.Close() - if table.Type != nil { - sess := engine.NewSession() - defer sess.Close() - for rows.Next() { - beanValue := reflect.New(table.Type) - bean := beanValue.Interface() - fields, err := rows.Columns() - if err != nil { - return err - } - scanResults, err := sess.row2Slice(rows, fields, bean) - if err != nil { - return err - } + types, err := rows.ColumnTypes() + if err != nil { + return err + } - dataStruct := utils.ReflectValue(bean) - _, err = sess.slice2Bean(scanResults, fields, bean, &dataStruct, table) - if err != nil { - return err - } + sess := engine.NewSession() + defer sess.Close() + for rows.Next() { + _, err = io.WriteString(w, "INSERT INTO "+dstDialect.Quoter().Quote(dstTableName)+" ("+destColNames+") VALUES (") + if err != nil { + return err + } - _, err = io.WriteString(w, "INSERT INTO "+dstDialect.Quoter().Quote(dstTableName)+" ("+destColNames+") VALUES (") - if err != nil { - return err - } - - var temp string - for _, d := range dstCols { - col := table.GetColumn(d) - if col == nil { - return errors.New("unknown column error") + scanResults, err := sess.engine.scanStringInterface(rows, types) + if err != nil { + return err + } + for i, scanResult := range scanResults { + stp := schemas.SQLType{Name: types[i].DatabaseTypeName()} + if stp.IsNumeric() { + s := scanResult.(*sql.NullString) + if s.Valid { + if _, err = io.WriteString(w, formatBool(s.String, dstDialect)); err != nil { + return err + } + } else { + if _, err = io.WriteString(w, "NULL"); err != nil { + return err + } + } + } else if stp.IsBool() { + s := scanResult.(*sql.NullString) + if s.Valid { + if _, err = io.WriteString(w, formatBool(s.String, dstDialect)); err != nil { + return err + } + } else { + if _, err = io.WriteString(w, "NULL"); err != nil { + return err + } + } + } else { + s := scanResult.(*sql.NullString) + if s.Valid { + if _, err = io.WriteString(w, "'"+strings.ReplaceAll(s.String, "'", "''")+"'"); err != nil { + return err + } + } else { + if _, err = io.WriteString(w, "NULL"); err != nil { + return err + } } - - field := dataStruct.FieldByIndex(col.FieldIndex) - temp += "," + formatColumnValue(engine.DatabaseTZ, dstDialect, field.Interface(), col) } - _, err = io.WriteString(w, temp[1:]+");\n") - if err != nil { - return err + if i < len(scanResults)-1 { + if _, err = io.WriteString(w, ","); err != nil { + return err + } } } - } else { - for rows.Next() { - dest := make([]interface{}, len(cols)) - err = rows.ScanSlice(&dest) - if err != nil { - return err - } - - _, err = io.WriteString(w, "INSERT INTO "+dstDialect.Quoter().Quote(dstTableName)+" ("+destColNames+") VALUES (") - if err != nil { - return err - } - - var temp string - for i, d := range dest { - col := table.GetColumn(cols[i]) - if col == nil { - return errors.New("unknow column error") - } - - temp += "," + formatColumnValue(engine.DatabaseTZ, dstDialect, d, col) - } - _, err = io.WriteString(w, temp[1:]+");\n") - if err != nil { - return err - } + _, err = io.WriteString(w, ");\n") + if err != nil { + return err } } diff --git a/integrations/engine_test.go b/integrations/engine_test.go index a06d91aa..a594ee46 100644 --- a/integrations/engine_test.go +++ b/integrations/engine_test.go @@ -172,8 +172,21 @@ func TestDumpTables(t *testing.T) { name := fmt.Sprintf("dump_%v-table.sql", tp) t.Run(name, func(t *testing.T) { assert.NoError(t, testEngine.(*xorm.Engine).DumpTablesToFile([]*schemas.Table{tb}, name, tp)) + }) } + + assert.NoError(t, testEngine.DropTables(new(TestDumpTableStruct))) + + importPath := fmt.Sprintf("dump_%v-table.sql", testEngine.Dialect().URI().DBType) + t.Run("import_"+importPath, func(t *testing.T) { + sess := testEngine.NewSession() + defer sess.Close() + assert.NoError(t, sess.Begin()) + _, err = sess.ImportFile(importPath) + assert.NoError(t, err) + assert.NoError(t, sess.Commit()) + }) } func TestDumpTables2(t *testing.T) { diff --git a/schemas/type.go b/schemas/type.go index f49348be..62e66c2e 100644 --- a/schemas/type.go +++ b/schemas/type.go @@ -39,6 +39,7 @@ const ( TIME_TYPE NUMERIC_TYPE ARRAY_TYPE + BOOL_TYPE ) // IsType reutrns ture if the column type is the same as the parameter @@ -64,6 +65,10 @@ func (s *SQLType) IsTime() bool { return s.IsType(TIME_TYPE) } +func (s *SQLType) IsBool() bool { + return s.IsType(BOOL_TYPE) +} + // IsNumeric returns true if column is a numeric type func (s *SQLType) IsNumeric() bool { return s.IsType(NUMERIC_TYPE) @@ -209,7 +214,8 @@ var ( Bytea: BLOB_TYPE, UniqueIdentifier: BLOB_TYPE, - Bool: NUMERIC_TYPE, + Bool: BOOL_TYPE, + Boolean: BOOL_TYPE, Serial: NUMERIC_TYPE, BigSerial: NUMERIC_TYPE, From 394c4e1f1715421c138d59561cae44f50423cb6c Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Sun, 11 Jul 2021 21:33:01 +0800 Subject: [PATCH 23/31] Replace #1044 (#1935) Fix #1372, #765 TODO: - [x] Add tests Co-authored-by: MURAOKA Taro Reviewed-on: https://gitea.com/xorm/xorm/pulls/1935 Co-authored-by: Lunny Xiao Co-committed-by: Lunny Xiao --- integrations/session_update_test.go | 44 +++++++++++++++++++++++++++-- internal/statements/update.go | 10 ++++--- internal/statements/values.go | 15 ++++++++-- session.go | 2 +- 4 files changed, 61 insertions(+), 10 deletions(-) diff --git a/integrations/session_update_test.go b/integrations/session_update_test.go index 796bfa0a..22808d60 100644 --- a/integrations/session_update_test.go +++ b/integrations/session_update_test.go @@ -1396,15 +1396,22 @@ func TestNilFromDB(t *testing.T) { assert.NoError(t, PrepareEngine()) assertSync(t, new(TestTable1)) - cnt, err := testEngine.Insert(&TestTable1{ + var tt0 = TestTable1{ Field1: &TestFieldType1{ cb: []byte("string"), }, UpdateTime: time.Now(), - }) + } + cnt, err := testEngine.Insert(&tt0) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) + var tt1 TestTable1 + has, err := testEngine.ID(tt0.Id).Get(&tt1) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, "string", string(tt1.Field1.cb)) + cnt, err = testEngine.Update(TestTable1{ UpdateTime: time.Now().Add(time.Second), }, TestTable1{ @@ -1418,4 +1425,37 @@ func TestNilFromDB(t *testing.T) { }) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) + + var tt = TestTable1{ + UpdateTime: time.Now(), + Field1: &TestFieldType1{ + cb: nil, + }, + } + cnt, err = testEngine.Insert(&tt) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var tt2 TestTable1 + has, err = testEngine.ID(tt.Id).Get(&tt2) + assert.NoError(t, err) + assert.True(t, has) + assert.Nil(t, tt2.Field1) + + var tt3 = TestTable1{ + UpdateTime: time.Now(), + Field1: &TestFieldType1{ + cb: []byte{}, + }, + } + cnt, err = testEngine.Insert(&tt3) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var tt4 TestTable1 + has, err = testEngine.ID(tt3.Id).Get(&tt4) + assert.NoError(t, err) + assert.True(t, has) + assert.NotNil(t, tt4.Field1) + assert.NotNil(t, tt4.Field1.cb) } diff --git a/internal/statements/update.go b/internal/statements/update.go index 06cf0689..3020595b 100644 --- a/internal/statements/update.go +++ b/internal/statements/update.go @@ -127,8 +127,9 @@ func (statement *Statement) BuildUpdates(tableValue reflect.Value, if err != nil { return nil, nil, err } - - val = data + if data != nil { + val = data + } goto APPEND } } @@ -138,8 +139,9 @@ func (statement *Statement) BuildUpdates(tableValue reflect.Value, if err != nil { return nil, nil, err } - - val = data + if data != nil { + val = data + } goto APPEND } diff --git a/internal/statements/values.go b/internal/statements/values.go index 994070ac..ee3821e9 100644 --- a/internal/statements/values.go +++ b/internal/statements/values.go @@ -31,6 +31,12 @@ func (statement *Statement) Value2Interface(col *schemas.Column, fieldValue refl if err != nil { return nil, err } + if data == nil { + if col.Nullable { + return nil, nil + } + data = []byte{} + } if col.SQLType.IsBlob() { return data, nil } @@ -45,12 +51,15 @@ func (statement *Statement) Value2Interface(col *schemas.Column, fieldValue refl if err != nil { return nil, err } + if data == nil { + if col.Nullable { + return nil, nil + } + data = []byte{} + } if col.SQLType.IsBlob() { return data, nil } - if nil == data { - return nil, nil - } return string(data), nil } } diff --git a/session.go b/session.go index 64b1758a..486911a5 100644 --- a/session.go +++ b/session.go @@ -768,7 +768,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b fieldValue, err := session.getField(dataStruct, key, table, idx) if err != nil { - if !strings.Contains(err.Error(), "is not valid") { + if _, ok := err.(ErrFieldIsNotValid); !ok { session.engine.logger.Warnf("%v", err) } continue From 147328f6298a87593f1edf341ff07a23d785b673 Mon Sep 17 00:00:00 2001 From: andreasgerstmayr Date: Mon, 12 Jul 2021 23:51:50 +0800 Subject: [PATCH 24/31] fix possible null dereference in internal/statements/query.go (#1988) Make sure that pLimitN is not `nil` before dereferencing the pointer. Co-authored-by: Andreas Gerstmayr Reviewed-on: https://gitea.com/xorm/xorm/pulls/1988 Co-authored-by: andreasgerstmayr Co-committed-by: andreasgerstmayr --- internal/statements/query.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/statements/query.go b/internal/statements/query.go index a972a8e0..69f48e73 100644 --- a/internal/statements/query.go +++ b/internal/statements/query.go @@ -314,7 +314,7 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB fmt.Fprint(&buf, " LIMIT ", *pLimitN) } } else if dialect.URI().DBType == schemas.ORACLE { - if statement.Start != 0 || pLimitN != nil { + if statement.Start != 0 && pLimitN != nil { oldString := buf.String() buf.Reset() rawColStr := columnStr From b296c8f1d73a2df55c6dc967d600bbe6b4724983 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Wed, 14 Jul 2021 12:20:26 +0800 Subject: [PATCH 25/31] Exec with time arg now will obey time zone settings on engine (#1989) Fix #1770 Reviewed-on: https://gitea.com/xorm/xorm/pulls/1989 Co-authored-by: Lunny Xiao Co-committed-by: Lunny Xiao --- dialects/dialect.go | 12 ++++-------- dialects/mssql.go | 13 +++++++++++++ dialects/mysql.go | 15 +++++++++++++++ dialects/oracle.go | 15 +++++++++++++++ dialects/postgres.go | 20 +++++++++++++++----- dialects/sqlite3.go | 15 +++++++++++++-- integrations/session_raw_test.go | 30 ++++++++++++++++++++++++++++++ internal/statements/statement.go | 17 +++++++++++++++-- scan.go | 22 ++++++++++++++++++---- session_query.go | 2 +- 10 files changed, 139 insertions(+), 22 deletions(-) diff --git a/dialects/dialect.go b/dialects/dialect.go index b3d374cc..df33155d 100644 --- a/dialects/dialect.go +++ b/dialects/dialect.go @@ -42,11 +42,12 @@ func (uri *URI) SetSchema(schema string) { type Dialect interface { Init(*URI) error URI() *URI - SQLType(*schemas.Column) string - Alias(string) string // return what a sql type's alias of - FormatBytes(b []byte) string Version(ctx context.Context, queryer core.Queryer) (*schemas.Version, error) + SQLType(*schemas.Column) string + Alias(string) string // return what a sql type's alias of + ColumnTypeKind(string) int // database column type kind + IsReserved(string) bool Quoter() schemas.Quoter SetQuotePolicy(quotePolicy QuotePolicy) @@ -102,11 +103,6 @@ func (db *Base) URI() *URI { return db.uri } -// FormatBytes formats bytes -func (db *Base) FormatBytes(bs []byte) string { - return fmt.Sprintf("0x%x", bs) -} - // DropTableSQL returns drop table SQL func (db *Base) DropTableSQL(tableName string) (string, bool) { quote := db.dialect.Quoter().Quote diff --git a/dialects/mssql.go b/dialects/mssql.go index c3c15077..e708ba80 100644 --- a/dialects/mssql.go +++ b/dialects/mssql.go @@ -364,6 +364,19 @@ func (db *mssql) SQLType(c *schemas.Column) string { return res } +func (db *mssql) ColumnTypeKind(t string) int { + switch strings.ToUpper(t) { + case "DATE", "DATETIME", "DATETIME2", "TIME": + return schemas.TIME_TYPE + case "VARCHAR", "TEXT", "CHAR", "NVARCHAR", "NCHAR", "NTEXT": + return schemas.TEXT_TYPE + case "FLOAT", "REAL", "BIGINT", "DATETIMEOFFSET", "TINYINT", "SMALLINT", "INT": + return schemas.NUMERIC_TYPE + default: + return schemas.UNKNOW_TYPE + } +} + func (db *mssql) IsReserved(name string) bool { _, ok := mssqlReservedWords[strings.ToUpper(name)] return ok diff --git a/dialects/mysql.go b/dialects/mysql.go index da19b820..db45cd62 100644 --- a/dialects/mysql.go +++ b/dialects/mysql.go @@ -337,6 +337,21 @@ func (db *mysql) SQLType(c *schemas.Column) string { return res } +func (db *mysql) ColumnTypeKind(t string) int { + switch strings.ToUpper(t) { + case "DATETIME": + return schemas.TIME_TYPE + case "CHAR", "VARCHAR", "TINYTEXT", "TEXT", "MEDIUMTEXT", "LONGTEXT", "ENUM", "SET": + return schemas.TEXT_TYPE + case "BIGINT", "TINYINT", "SMALLINT", "MEDIUMINT", "INT", "FLOAT", "REAL", "DOUBLE PRECISION", "DECIMAL", "NUMERIC", "BIT": + return schemas.NUMERIC_TYPE + case "BINARY", "VARBINARY", "TINYBLOB", "BLOB", "MEDIUMBLOB", "LONGBLOB": + return schemas.BLOB_TYPE + default: + return schemas.UNKNOW_TYPE + } +} + func (db *mysql) IsReserved(name string) bool { _, ok := mysqlReservedWords[strings.ToUpper(name)] return ok diff --git a/dialects/oracle.go b/dialects/oracle.go index 7043972b..5dd92887 100644 --- a/dialects/oracle.go +++ b/dialects/oracle.go @@ -568,6 +568,21 @@ func (db *oracle) SQLType(c *schemas.Column) string { return res } +func (db *oracle) ColumnTypeKind(t string) int { + switch strings.ToUpper(t) { + case "DATE": + return schemas.TIME_TYPE + case "CHAR", "NCHAR", "VARCHAR", "VARCHAR2", "NVARCHAR2", "LONG", "CLOB", "NCLOB": + return schemas.TEXT_TYPE + case "NUMBER": + return schemas.NUMERIC_TYPE + case "BLOB": + return schemas.BLOB_TYPE + default: + return schemas.UNKNOW_TYPE + } +} + func (db *oracle) AutoIncrStr() string { return "AUTO_INCREMENT" } diff --git a/dialects/postgres.go b/dialects/postgres.go index 9f3c7275..4ec780e8 100644 --- a/dialects/postgres.go +++ b/dialects/postgres.go @@ -873,11 +873,6 @@ func (db *postgres) SetQuotePolicy(quotePolicy QuotePolicy) { } } -// FormatBytes formats bytes -func (db *postgres) FormatBytes(bs []byte) string { - return fmt.Sprintf("E'\\x%x'", bs) -} - func (db *postgres) SQLType(c *schemas.Column) string { var res string switch t := c.SQLType.Name; t { @@ -943,6 +938,21 @@ func (db *postgres) SQLType(c *schemas.Column) string { return res } +func (db *postgres) ColumnTypeKind(t string) int { + switch strings.ToUpper(t) { + case "DATETIME", "TIMESTAMP": + return schemas.TIME_TYPE + case "VARCHAR", "TEXT": + return schemas.TEXT_TYPE + case "BIGINT", "BIGSERIAL", "SMALLINT", "INT", "INT8", "INT4", "INTEGER", "SERIAL", "FLOAT", "FLOAT4", "REAL", "DOUBLE PRECISION": + return schemas.NUMERIC_TYPE + case "BOOL": + return schemas.BOOL_TYPE + default: + return schemas.UNKNOW_TYPE + } +} + func (db *postgres) IsReserved(name string) bool { _, ok := postgresReservedWords[strings.ToUpper(name)] return ok diff --git a/dialects/sqlite3.go b/dialects/sqlite3.go index 04e5b457..581272ad 100644 --- a/dialects/sqlite3.go +++ b/dialects/sqlite3.go @@ -233,8 +233,19 @@ func (db *sqlite3) SQLType(c *schemas.Column) string { } } -func (db *sqlite3) FormatBytes(bs []byte) string { - return fmt.Sprintf("X'%x'", bs) +func (db *sqlite3) ColumnTypeKind(t string) int { + switch strings.ToUpper(t) { + case "DATETIME": + return schemas.TIME_TYPE + case "TEXT": + return schemas.TEXT_TYPE + case "INTEGER", "REAL", "NUMERIC", "DECIMAL": + return schemas.NUMERIC_TYPE + case "BLOB": + return schemas.BLOB_TYPE + default: + return schemas.UNKNOW_TYPE + } } func (db *sqlite3) IsReserved(name string) bool { diff --git a/integrations/session_raw_test.go b/integrations/session_raw_test.go index 8b9d6766..36677683 100644 --- a/integrations/session_raw_test.go +++ b/integrations/session_raw_test.go @@ -7,6 +7,7 @@ package integrations import ( "strconv" "testing" + "time" "github.com/stretchr/testify/assert" ) @@ -35,3 +36,32 @@ func TestExecAndQuery(t *testing.T) { assert.EqualValues(t, 1, id) assert.Equal(t, "user", string(results[0]["name"])) } + +func TestExecTime(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type UserinfoExecTime struct { + Uid int + Name string + Created time.Time + } + + assert.NoError(t, testEngine.Sync2(new(UserinfoExecTime))) + now := time.Now() + res, err := testEngine.Exec("INSERT INTO "+testEngine.TableName("`userinfo_exec_time`", true)+" (uid, name, created) VALUES (?, ?, ?)", 1, "user", now) + assert.NoError(t, err) + cnt, err := res.RowsAffected() + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + results, err := testEngine.QueryString("SELECT * FROM " + testEngine.TableName("`userinfo_exec_time`", true)) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(results)) + assert.EqualValues(t, now.In(testEngine.GetTZLocation()).Format("2006-01-02 15:04:05"), results[0]["created"]) + + var uet UserinfoExecTime + has, err := testEngine.Where("uid=?", 1).Get(&uet) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, now.In(testEngine.GetTZLocation()).Format("2006-01-02 15:04:05"), uet.Created.Format("2006-01-02 15:04:05")) +} diff --git a/internal/statements/statement.go b/internal/statements/statement.go index 2d173b87..bfe9987f 100644 --- a/internal/statements/statement.go +++ b/internal/statements/statement.go @@ -942,16 +942,29 @@ func (statement *Statement) quoteColumnStr(columnStr string) string { // ConvertSQLOrArgs converts sql or args func (statement *Statement) ConvertSQLOrArgs(sqlOrArgs ...interface{}) (string, []interface{}, error) { - sql, args, err := convertSQLOrArgs(sqlOrArgs...) + sql, args, err := statement.convertSQLOrArgs(sqlOrArgs...) if err != nil { return "", nil, err } return statement.ReplaceQuote(sql), args, nil } -func convertSQLOrArgs(sqlOrArgs ...interface{}) (string, []interface{}, error) { +func (statement *Statement) convertSQLOrArgs(sqlOrArgs ...interface{}) (string, []interface{}, error) { switch sqlOrArgs[0].(type) { case string: + if len(sqlOrArgs) > 1 { + var newArgs = make([]interface{}, 0, len(sqlOrArgs)-1) + for _, arg := range sqlOrArgs[1:] { + if v, ok := arg.(*time.Time); ok { + newArgs = append(newArgs, v.In(statement.defaultTimeZone).Format("2006-01-02 15:04:05")) + } else if v, ok := arg.(time.Time); ok { + newArgs = append(newArgs, v.In(statement.defaultTimeZone).Format("2006-01-02 15:04:05")) + } else { + newArgs = append(newArgs, arg) + } + } + return sqlOrArgs[0].(string), newArgs, nil + } return sqlOrArgs[0].(string), sqlOrArgs[1:], nil case *builder.Builder: return sqlOrArgs[0].(*builder.Builder).ToSQL() diff --git a/scan.go b/scan.go index d668208a..2fedd415 100644 --- a/scan.go +++ b/scan.go @@ -14,6 +14,7 @@ import ( "xorm.io/xorm/convert" "xorm.io/xorm/core" "xorm.io/xorm/dialects" + "xorm.io/xorm/schemas" ) // genScanResultsByBeanNullabale generates scan result @@ -123,7 +124,7 @@ func genScanResultsByBean(bean interface{}) (interface{}, bool, error) { } } -func row2mapStr(rows *core.Rows, types []*sql.ColumnType, fields []string) (map[string]string, error) { +func (engine *Engine) row2mapStr(rows *core.Rows, types []*sql.ColumnType, fields []string) (map[string]string, error) { var scanResults = make([]interface{}, len(fields)) for i := 0; i < len(fields); i++ { var s sql.NullString @@ -135,9 +136,22 @@ func row2mapStr(rows *core.Rows, types []*sql.ColumnType, fields []string) (map[ } result := make(map[string]string, len(fields)) - for ii, key := range fields { - s := scanResults[ii].(*sql.NullString) - result[key] = s.String + for i, key := range fields { + s := scanResults[i].(*sql.NullString) + if s.String == "" { + result[key] = "" + continue + } + + if schemas.TIME_TYPE == engine.dialect.ColumnTypeKind(types[i].DatabaseTypeName()) { + t, err := convert.String2Time(s.String, engine.DatabaseTZ, engine.TZLocation) + if err != nil { + return nil, err + } + result[key] = t.Format("2006-01-02 15:04:05") + } else { + result[key] = s.String + } } return result, nil } diff --git a/session_query.go b/session_query.go index fa33496d..d14c3908 100644 --- a/session_query.go +++ b/session_query.go @@ -33,7 +33,7 @@ func (session *Session) rows2Strings(rows *core.Rows) (resultsSlice []map[string } for rows.Next() { - result, err := row2mapStr(rows, types, fields) + result, err := session.engine.row2mapStr(rows, types, fields) if err != nil { return nil, err } From 69a7db5312a1e6ef8c1edbf80127a9fb44a37cff Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Wed, 14 Jul 2021 17:06:53 +0800 Subject: [PATCH 26/31] improve uint tests (#1990) Reviewed-on: https://gitea.com/xorm/xorm/pulls/1990 Co-authored-by: Lunny Xiao Co-committed-by: Lunny Xiao --- integrations/session_pk_test.go | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/integrations/session_pk_test.go b/integrations/session_pk_test.go index d5f23491..8f7dcb55 100644 --- a/integrations/session_pk_test.go +++ b/integrations/session_pk_test.go @@ -173,6 +173,16 @@ func TestUintId(t *testing.T) { err = testEngine.CreateTables(&UintId{}) assert.NoError(t, err) + tables, err := testEngine.DBMetas() + assert.NoError(t, err) + + assert.EqualValues(t, 1, len(tables)) + cols := tables[0].PKColumns() + assert.EqualValues(t, 1, len(cols)) + if testEngine.Dialect().URI().DBType == schemas.MYSQL { + assert.EqualValues(t, "UNSIGNED INT", cols[0].SQLType.Name) + } + cnt, err := testEngine.Insert(&UintId{Name: "test"}) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) From 779a74ccff7e50341c48eb59d7fd77330a74d5b5 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Thu, 15 Jul 2021 07:06:15 +0800 Subject: [PATCH 27/31] Remove duplicated code (#1991) Reviewed-on: https://gitea.com/xorm/xorm/pulls/1991 Co-authored-by: Lunny Xiao Co-committed-by: Lunny Xiao --- session_get.go | 6 ------ 1 file changed, 6 deletions(-) diff --git a/session_get.go b/session_get.go index f710a0b1..cc6427d7 100644 --- a/session_get.go +++ b/session_get.go @@ -268,12 +268,6 @@ func (session *Session) getVars(rows *core.Rows, types []*sql.ColumnType, fields } func (session *Session) getStruct(rows *core.Rows, types []*sql.ColumnType, fields []string, table *schemas.Table, bean interface{}) (bool, error) { - fields, err := rows.Columns() - if err != nil { - // WARN: Alougth rows return true, but get fields failed - return true, err - } - scanResults, err := session.row2Slice(rows, fields, bean) if err != nil { return false, err From aaa2111e8ff6340b497b6a991d6d71e9b45282bc Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Mon, 19 Jul 2021 00:21:46 +0800 Subject: [PATCH 28/31] Refactor asbytes (#1995) Reviewed-on: https://gitea.com/xorm/xorm/pulls/1995 Co-authored-by: Lunny Xiao Co-committed-by: Lunny Xiao --- convert.go | 197 ++++++++++---- convert/time.go | 9 +- dialects/mysql.go | 47 ---- engine.go | 7 +- integrations/session_get_test.go | 2 +- integrations/time_test.go | 45 +++ integrations/types_null_test.go | 14 +- rows.go | 6 +- scan.go | 32 +-- session.go | 167 ++++-------- session_convert.go | 451 ------------------------------- session_find.go | 7 +- session_get.go | 10 +- session_insert.go | 6 +- session_raw.go | 60 ---- 15 files changed, 296 insertions(+), 764 deletions(-) diff --git a/convert.go b/convert.go index 69277734..533dbe99 100644 --- a/convert.go +++ b/convert.go @@ -7,6 +7,7 @@ package xorm import ( "database/sql" "database/sql/driver" + "encoding/json" "errors" "fmt" "math/big" @@ -285,23 +286,94 @@ func asBigFloat(src interface{}) (*big.Float, error) { return nil, fmt.Errorf("unsupported value %T as big.Float", src) } -func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) { +func asBytes(src interface{}) ([]byte, bool) { + switch t := src.(type) { + case []byte: + return t, true + case *sql.NullString: + if !t.Valid { + return nil, true + } + return []byte(t.String), true + case *sql.RawBytes: + return *t, true + } + + rv := reflect.ValueOf(src) + switch rv.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return strconv.AppendInt(buf, rv.Int(), 10), true + return strconv.AppendInt(nil, rv.Int(), 10), true case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - return strconv.AppendUint(buf, rv.Uint(), 10), true + return strconv.AppendUint(nil, rv.Uint(), 10), true case reflect.Float32: - return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 32), true + return strconv.AppendFloat(nil, rv.Float(), 'g', -1, 32), true case reflect.Float64: - return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 64), true + return strconv.AppendFloat(nil, rv.Float(), 'g', -1, 64), true case reflect.Bool: - return strconv.AppendBool(buf, rv.Bool()), true + return strconv.AppendBool(nil, rv.Bool()), true case reflect.String: - s := rv.String() - return append(buf, s...), true + return []byte(rv.String()), true } - return + return nil, false +} + +func asTime(src interface{}, dbLoc *time.Location, uiLoc *time.Location) (*time.Time, error) { + switch t := src.(type) { + case string: + return convert.String2Time(t, dbLoc, uiLoc) + case *sql.NullString: + if !t.Valid { + return nil, nil + } + return convert.String2Time(t.String, dbLoc, uiLoc) + case []uint8: + if t == nil { + return nil, nil + } + return convert.String2Time(string(t), dbLoc, uiLoc) + case *sql.NullTime: + if !t.Valid { + return nil, nil + } + z, _ := t.Time.Zone() + if len(z) == 0 || t.Time.Year() == 0 || t.Time.Location().String() != dbLoc.String() { + tm := time.Date(t.Time.Year(), t.Time.Month(), t.Time.Day(), t.Time.Hour(), + t.Time.Minute(), t.Time.Second(), t.Time.Nanosecond(), dbLoc).In(uiLoc) + return &tm, nil + } + tm := t.Time.In(uiLoc) + return &tm, nil + case *time.Time: + z, _ := t.Zone() + if len(z) == 0 || t.Year() == 0 || t.Location().String() != dbLoc.String() { + tm := time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), + t.Minute(), t.Second(), t.Nanosecond(), dbLoc).In(uiLoc) + return &tm, nil + } + tm := t.In(uiLoc) + return &tm, nil + case time.Time: + z, _ := t.Zone() + if len(z) == 0 || t.Year() == 0 || t.Location().String() != dbLoc.String() { + tm := time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), + t.Minute(), t.Second(), t.Nanosecond(), dbLoc).In(uiLoc) + return &tm, nil + } + tm := t.In(uiLoc) + return &tm, nil + case int: + tm := time.Unix(int64(t), 0).In(uiLoc) + return &tm, nil + case int64: + tm := time.Unix(t, 0).In(uiLoc) + return &tm, nil + case *sql.NullInt64: + tm := time.Unix(t.Int64, 0).In(uiLoc) + return &tm, nil + + } + return nil, fmt.Errorf("unsupported value %#v as time", src) } // convertAssign copies to dest the value in src, converting it if possible. @@ -559,8 +631,7 @@ func convertAssign(dest, src interface{}, originalLocation *time.Location, conve return nil } case *[]byte: - sv = reflect.ValueOf(src) - if b, ok := asBytes(nil, sv); ok { + if b, ok := asBytes(src); ok { *d = b return nil } @@ -575,44 +646,24 @@ func convertAssign(dest, src interface{}, originalLocation *time.Location, conve return nil } - return convertAssignV(reflect.ValueOf(dest), src, originalLocation, convertedLocation) + return convertAssignV(reflect.ValueOf(dest), src) } -func convertAssignV(dpv reflect.Value, src interface{}, originalLocation, convertedLocation *time.Location) error { - if dpv.Kind() != reflect.Ptr { - return errors.New("destination not a pointer") - } - if dpv.IsNil() { - return errNilPtr - } - - var sv = reflect.ValueOf(src) - - dv := reflect.Indirect(dpv) - if sv.IsValid() && sv.Type().AssignableTo(dv.Type()) { - switch b := src.(type) { - case []byte: - dv.Set(reflect.ValueOf(cloneBytes(b))) - default: - dv.Set(sv) - } +func convertAssignV(dv reflect.Value, src interface{}) error { + if src == nil { return nil } - if dv.Kind() == sv.Kind() && sv.Type().ConvertibleTo(dv.Type()) { - dv.Set(sv.Convert(dv.Type())) - return nil + if dv.Type().Implements(scannerType) { + return dv.Interface().(sql.Scanner).Scan(src) } switch dv.Kind() { case reflect.Ptr: - if src == nil { - dv.Set(reflect.Zero(dv.Type())) - return nil + if dv.IsNil() { + dv.Set(reflect.New(dv.Type().Elem())) } - - dv.Set(reflect.New(dv.Type().Elem())) - return convertAssign(dv.Interface(), src, originalLocation, convertedLocation) + return convertAssignV(dv.Elem(), src) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: i64, err := asInt64(src) if err != nil { @@ -640,9 +691,28 @@ func convertAssignV(dpv reflect.Value, src interface{}, originalLocation, conver case reflect.String: dv.SetString(asString(src)) return nil + case reflect.Bool: + b, err := asBool(src) + if err != nil { + return err + } + dv.SetBool(b) + return nil + case reflect.Slice, reflect.Map, reflect.Struct, reflect.Array: + data, ok := asBytes(src) + if !ok { + return fmt.Errorf("onvertAssignV: src cannot be as bytes %#v", src) + } + if data == nil { + return nil + } + if dv.Kind() != reflect.Ptr { + dv = dv.Addr() + } + return json.Unmarshal(data, dv.Interface()) + default: + return fmt.Errorf("convertAssignV: unsupported Scan, storing driver.Value type %T into type %T", src, dv.Interface()) } - - return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dpv.Interface()) } func asKind(vv reflect.Value, tp reflect.Type) (interface{}, error) { @@ -682,16 +752,43 @@ func asKind(vv reflect.Value, tp reflect.Type) (interface{}, error) { return nil, fmt.Errorf("unsupported primary key type: %v, %v", tp, vv) } -func asBool(bs []byte) (bool, error) { - if len(bs) == 0 { - return false, nil +func asBool(src interface{}) (bool, error) { + switch v := src.(type) { + case bool: + return v, nil + case *bool: + return *v, nil + case *sql.NullBool: + return v.Bool, nil + case int64: + return v > 0, nil + case int: + return v > 0, nil + case int8: + return v > 0, nil + case int16: + return v > 0, nil + case int32: + return v > 0, nil + case []byte: + if len(v) == 0 { + return false, nil + } + if v[0] == 0x00 { + return false, nil + } else if v[0] == 0x01 { + return true, nil + } + return strconv.ParseBool(string(v)) + case string: + return strconv.ParseBool(v) + case *sql.NullInt64: + return v.Int64 > 0, nil + case *sql.NullInt32: + return v.Int32 > 0, nil + default: + return false, fmt.Errorf("unknow type %T as bool", src) } - if bs[0] == 0x00 { - return false, nil - } else if bs[0] == 0x01 { - return true, nil - } - return strconv.ParseBool(string(bs)) } // str2PK convert string value to primary key value according to tp diff --git a/convert/time.go b/convert/time.go index 696b301c..5a3e5246 100644 --- a/convert/time.go +++ b/convert/time.go @@ -6,6 +6,7 @@ package convert import ( "fmt" + "strconv" "time" ) @@ -19,7 +20,7 @@ func String2Time(s string, originalLocation *time.Location, convertedLocation *t dt = dt.In(convertedLocation) return &dt, nil } else if len(s) == 20 && s[10] == 'T' && s[19] == 'Z' { - dt, err := time.ParseInLocation(time.RFC3339, s, originalLocation) + dt, err := time.ParseInLocation("2006-01-02T15:04:05", s[:19], originalLocation) if err != nil { return nil, err } @@ -32,6 +33,12 @@ func String2Time(s string, originalLocation *time.Location, convertedLocation *t } dt = dt.In(convertedLocation) return &dt, nil + } else { + i, err := strconv.ParseInt(s, 10, 64) + if err == nil { + tm := time.Unix(i, 0).In(convertedLocation) + return &tm, nil + } } return nil, fmt.Errorf("unsupported convertion from %s to time", s) } diff --git a/dialects/mysql.go b/dialects/mysql.go index db45cd62..9312c071 100644 --- a/dialects/mysql.go +++ b/dialects/mysql.go @@ -15,7 +15,6 @@ import ( "strings" "time" - "xorm.io/xorm/convert" "xorm.io/xorm/core" "xorm.io/xorm/schemas" ) @@ -733,52 +732,6 @@ func (p *mysqlDriver) GenScanResult(colType string) (interface{}, error) { } } -func (p *mysqlDriver) Scan(ctx *ScanContext, rows *core.Rows, types []*sql.ColumnType, scanResults ...interface{}) error { - var v2 = make([]interface{}, 0, len(scanResults)) - var turnBackIdxes = make([]int, 0, 5) - for i, vv := range scanResults { - switch vv.(type) { - case *time.Time: - v2 = append(v2, &sql.NullString{}) - turnBackIdxes = append(turnBackIdxes, i) - case *sql.NullTime: - v2 = append(v2, &sql.NullString{}) - turnBackIdxes = append(turnBackIdxes, i) - default: - v2 = append(v2, scanResults[i]) - } - } - if err := rows.Scan(v2...); err != nil { - return err - } - for _, i := range turnBackIdxes { - switch t := scanResults[i].(type) { - case *time.Time: - var s = *(v2[i].(*sql.NullString)) - if !s.Valid { - break - } - dt, err := convert.String2Time(s.String, ctx.DBLocation, ctx.UserLocation) - if err != nil { - return err - } - *t = *dt - case *sql.NullTime: - var s = *(v2[i].(*sql.NullString)) - if !s.Valid { - break - } - dt, err := convert.String2Time(s.String, ctx.DBLocation, ctx.UserLocation) - if err != nil { - return err - } - t.Time = *dt - t.Valid = true - } - } - return nil -} - type mymysqlDriver struct { mysqlDriver } diff --git a/engine.go b/engine.go index d3ee8a8c..b4ef9593 100644 --- a/engine.go +++ b/engine.go @@ -543,6 +543,11 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch return err } + fields, err := rows.Columns() + if err != nil { + return err + } + sess := engine.NewSession() defer sess.Close() for rows.Next() { @@ -551,7 +556,7 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch return err } - scanResults, err := sess.engine.scanStringInterface(rows, types) + scanResults, err := sess.engine.scanStringInterface(rows, fields, types) if err != nil { return err } diff --git a/integrations/session_get_test.go b/integrations/session_get_test.go index ca894d59..b1dffe14 100644 --- a/integrations/session_get_test.go +++ b/integrations/session_get_test.go @@ -914,7 +914,7 @@ func TestGetTime(t *testing.T) { assertSync(t, new(GetTimeStruct)) var gts = GetTimeStruct{ - CreateTime: time.Now(), + CreateTime: time.Now().In(testEngine.GetTZLocation()), } _, err := testEngine.Insert(>s) assert.NoError(t, err) diff --git a/integrations/time_test.go b/integrations/time_test.go index 6d8d812c..50fd1847 100644 --- a/integrations/time_test.go +++ b/integrations/time_test.go @@ -53,9 +53,18 @@ func TestTimeUserTimeDiffLoc(t *testing.T) { assert.NoError(t, PrepareEngine()) loc, err := time.LoadLocation("Asia/Shanghai") assert.NoError(t, err) + oldTZLoc := testEngine.GetTZLocation() + defer func() { + testEngine.SetTZLocation(oldTZLoc) + }() testEngine.SetTZLocation(loc) + dbLoc, err := time.LoadLocation("America/New_York") assert.NoError(t, err) + oldDBLoc := testEngine.GetTZDatabase() + defer func() { + testEngine.SetTZDatabase(oldDBLoc) + }() testEngine.SetTZDatabase(dbLoc) type TimeUser2 struct { @@ -118,9 +127,18 @@ func TestTimeUserCreatedDiffLoc(t *testing.T) { assert.NoError(t, PrepareEngine()) loc, err := time.LoadLocation("Asia/Shanghai") assert.NoError(t, err) + oldTZLoc := testEngine.GetTZLocation() + defer func() { + testEngine.SetTZLocation(oldTZLoc) + }() testEngine.SetTZLocation(loc) + dbLoc, err := time.LoadLocation("America/New_York") assert.NoError(t, err) + oldDBLoc := testEngine.GetTZDatabase() + defer func() { + testEngine.SetTZDatabase(oldDBLoc) + }() testEngine.SetTZDatabase(dbLoc) type UserCreated2 struct { @@ -204,9 +222,18 @@ func TestTimeUserUpdatedDiffLoc(t *testing.T) { assert.NoError(t, PrepareEngine()) loc, err := time.LoadLocation("Asia/Shanghai") assert.NoError(t, err) + oldTZLoc := testEngine.GetTZLocation() + defer func() { + testEngine.SetTZLocation(oldTZLoc) + }() testEngine.SetTZLocation(loc) + dbLoc, err := time.LoadLocation("America/New_York") assert.NoError(t, err) + oldDBLoc := testEngine.GetTZDatabase() + defer func() { + testEngine.SetTZDatabase(oldDBLoc) + }() testEngine.SetTZDatabase(dbLoc) type UserUpdated2 struct { @@ -311,9 +338,18 @@ func TestTimeUserDeletedDiffLoc(t *testing.T) { assert.NoError(t, PrepareEngine()) loc, err := time.LoadLocation("Asia/Shanghai") assert.NoError(t, err) + oldTZLoc := testEngine.GetTZLocation() + defer func() { + testEngine.SetTZLocation(oldTZLoc) + }() testEngine.SetTZLocation(loc) + dbLoc, err := time.LoadLocation("America/New_York") assert.NoError(t, err) + oldDBLoc := testEngine.GetTZDatabase() + defer func() { + testEngine.SetTZDatabase(oldDBLoc) + }() testEngine.SetTZDatabase(dbLoc) type UserDeleted2 struct { @@ -435,9 +471,18 @@ func TestCustomTimeUserDeletedDiffLoc(t *testing.T) { assert.NoError(t, PrepareEngine()) loc, err := time.LoadLocation("Asia/Shanghai") assert.NoError(t, err) + oldTZLoc := testEngine.GetTZLocation() + defer func() { + testEngine.SetTZLocation(oldTZLoc) + }() testEngine.SetTZLocation(loc) + dbLoc, err := time.LoadLocation("America/New_York") assert.NoError(t, err) + oldDBLoc := testEngine.GetTZDatabase() + defer func() { + testEngine.SetTZDatabase(oldDBLoc) + }() testEngine.SetTZDatabase(dbLoc) type UserDeleted4 struct { diff --git a/integrations/types_null_test.go b/integrations/types_null_test.go index 98bd86b9..86ce1939 100644 --- a/integrations/types_null_test.go +++ b/integrations/types_null_test.go @@ -7,7 +7,6 @@ package integrations import ( "database/sql" "database/sql/driver" - "errors" "fmt" "strconv" "strings" @@ -42,15 +41,22 @@ func (m *CustomStruct) Scan(value interface{}) error { return nil } - if s, ok := value.([]byte); ok { - seps := strings.Split(string(s), "/") + var s string + switch t := value.(type) { + case string: + s = t + case []byte: + s = string(t) + } + if len(s) > 0 { + seps := strings.Split(s, "/") m.Year, _ = strconv.Atoi(seps[0]) m.Month, _ = strconv.Atoi(seps[1]) m.Day, _ = strconv.Atoi(seps[2]) return nil } - return errors.New("scan data not fit []byte") + return fmt.Errorf("scan data %#v not fit []byte", value) } func (m CustomStruct) Value() (driver.Value, error) { diff --git a/rows.go b/rows.go index a56ea1c9..5e0a1ffe 100644 --- a/rows.go +++ b/rows.go @@ -129,8 +129,12 @@ func (rows *Rows) Scan(bean interface{}) error { if err != nil { return err } + types, err := rows.rows.ColumnTypes() + if err != nil { + return err + } - scanResults, err := rows.session.row2Slice(rows.rows, fields, bean) + scanResults, err := rows.session.row2Slice(rows.rows, fields, types, bean) if err != nil { return err } diff --git a/scan.go b/scan.go index 2fedd415..e4c0e4a1 100644 --- a/scan.go +++ b/scan.go @@ -20,6 +20,8 @@ import ( // genScanResultsByBeanNullabale generates scan result func genScanResultsByBeanNullable(bean interface{}) (interface{}, bool, error) { switch t := bean.(type) { + case *interface{}: + return t, false, nil case *sql.NullInt64, *sql.NullBool, *sql.NullFloat64, *sql.NullString, *sql.RawBytes: return t, false, nil case *time.Time: @@ -71,7 +73,10 @@ func genScanResultsByBeanNullable(bean interface{}) (interface{}, bool, error) { func genScanResultsByBean(bean interface{}) (interface{}, bool, error) { switch t := bean.(type) { + case *interface{}: + return t, false, nil case *sql.NullInt64, *sql.NullBool, *sql.NullFloat64, *sql.NullString, + *sql.RawBytes, *string, *int, *int8, *int16, *int32, *int64, *uint, *uint8, *uint16, *uint32, *uint64, @@ -175,17 +180,14 @@ func row2mapBytes(rows *core.Rows, types []*sql.ColumnType, fields []string) (ma return result, nil } -func (engine *Engine) scanStringInterface(rows *core.Rows, types []*sql.ColumnType) ([]interface{}, error) { +func (engine *Engine) scanStringInterface(rows *core.Rows, fields []string, types []*sql.ColumnType) ([]interface{}, error) { var scanResults = make([]interface{}, len(types)) for i := 0; i < len(types); i++ { var s sql.NullString scanResults[i] = &s } - if err := engine.driver.Scan(&dialects.ScanContext{ - DBLocation: engine.DatabaseTZ, - UserLocation: engine.TZLocation, - }, rows, types, scanResults...); err != nil { + if err := engine.scan(rows, fields, types, scanResults...); err != nil { return nil, err } return scanResults, nil @@ -200,14 +202,14 @@ func (engine *Engine) scan(rows *core.Rows, fields []string, types []*sql.Column var replaced bool var scanResult interface{} switch t := v.(type) { + case *big.Float, *time.Time, *sql.NullTime: + scanResult = &sql.NullString{} + replaced = true case sql.Scanner: scanResult = t case convert.Conversion: scanResult = &sql.RawBytes{} replaced = true - case *big.Float: - scanResult = &sql.NullString{} - replaced = true default: var useNullable = true if engine.driver.Features().SupportNullable { @@ -246,7 +248,7 @@ func (engine *Engine) scan(rows *core.Rows, fields []string, types []*sql.Column return nil } -func (engine *Engine) scanInterfaces(rows *core.Rows, types []*sql.ColumnType) ([]interface{}, error) { +func (engine *Engine) scanInterfaces(rows *core.Rows, fields []string, types []*sql.ColumnType) ([]interface{}, error) { var scanResultContainers = make([]interface{}, len(types)) for i := 0; i < len(types); i++ { scanResult, err := engine.driver.GenScanResult(types[i].DatabaseTypeName()) @@ -255,17 +257,14 @@ func (engine *Engine) scanInterfaces(rows *core.Rows, types []*sql.ColumnType) ( } scanResultContainers[i] = scanResult } - if err := engine.driver.Scan(&dialects.ScanContext{ - DBLocation: engine.DatabaseTZ, - UserLocation: engine.TZLocation, - }, rows, types, scanResultContainers...); err != nil { + if err := engine.scan(rows, fields, types, scanResultContainers...); err != nil { return nil, err } return scanResultContainers, nil } func (engine *Engine) row2sliceStr(rows *core.Rows, types []*sql.ColumnType, fields []string) ([]string, error) { - scanResults, err := engine.scanStringInterface(rows, types) + scanResults, err := engine.scanStringInterface(rows, fields, types) if err != nil { return nil, err } @@ -307,10 +306,7 @@ func (engine *Engine) row2mapInterface(rows *core.Rows, types []*sql.ColumnType, } scanResultContainers[i] = scanResult } - if err := engine.driver.Scan(&dialects.ScanContext{ - DBLocation: engine.DatabaseTZ, - UserLocation: engine.TZLocation, - }, rows, types, scanResultContainers...); err != nil { + if err := engine.scan(rows, fields, types, scanResultContainers...); err != nil { return nil, err } diff --git a/session.go b/session.go index 486911a5..5557d717 100644 --- a/session.go +++ b/session.go @@ -16,7 +16,6 @@ import ( "io" "reflect" "strings" - "time" "xorm.io/xorm/contexts" "xorm.io/xorm/convert" @@ -389,7 +388,7 @@ func (session *Session) getField(dataStruct *reflect.Value, key string, table *s // Cell cell is a result of one column field type Cell *interface{} -func (session *Session) rows2Beans(rows *core.Rows, fields []string, +func (session *Session) rows2Beans(rows *core.Rows, fields []string, types []*sql.ColumnType, table *schemas.Table, newElemFunc func([]string) reflect.Value, sliceValueSetFunc func(*reflect.Value, schemas.PK) error) error { for rows.Next() { @@ -398,7 +397,7 @@ func (session *Session) rows2Beans(rows *core.Rows, fields []string, dataStruct := newValue.Elem() // handle beforeClosures - scanResults, err := session.row2Slice(rows, fields, bean) + scanResults, err := session.row2Slice(rows, fields, types, bean) if err != nil { return err } @@ -417,7 +416,7 @@ func (session *Session) rows2Beans(rows *core.Rows, fields []string, return nil } -func (session *Session) row2Slice(rows *core.Rows, fields []string, bean interface{}) ([]interface{}, error) { +func (session *Session) row2Slice(rows *core.Rows, fields []string, types []*sql.ColumnType, bean interface{}) ([]interface{}, error) { for _, closure := range session.beforeClosures { closure(bean) } @@ -427,7 +426,7 @@ func (session *Session) row2Slice(rows *core.Rows, fields []string, bean interfa var cell interface{} scanResults[i] = &cell } - if err := rows.Scan(scanResults...); err != nil { + if err := session.engine.scan(rows, fields, types, scanResults...); err != nil { return nil, err } @@ -454,27 +453,28 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec if fieldValue.CanAddr() { if structConvert, ok := fieldValue.Addr().Interface().(convert.Conversion); ok { - data, err := value2Bytes(&rawValue) - if err != nil { - return err + data, ok := asBytes(scanResult) + if !ok { + return fmt.Errorf("cannot convert %#v as bytes", scanResult) } - if err := structConvert.FromDB(data); err != nil { - return err - } - return nil + return structConvert.FromDB(data) } } - if _, ok := fieldValue.Interface().(convert.Conversion); ok { - if data, err := value2Bytes(&rawValue); err == nil { - if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() { - fieldValue.Set(reflect.New(fieldValue.Type().Elem())) - } - fieldValue.Interface().(convert.Conversion).FromDB(data) - } else { - return err + if structConvert, ok := fieldValue.Interface().(convert.Conversion); ok { + data, ok := asBytes(scanResult) + if !ok { + return fmt.Errorf("cannot convert %#v as bytes", scanResult) } - return nil + if data == nil { + return nil + } + + if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() { + fieldValue.Set(reflect.New(fieldValue.Type().Elem())) + return fieldValue.Interface().(convert.Conversion).FromDB(data) + } + return structConvert.FromDB(data) } rawValueType := reflect.TypeOf(rawValue.Interface()) @@ -554,64 +554,28 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec } return nil case reflect.Slice, reflect.Array: - switch rawValueType.Kind() { - case reflect.Slice, reflect.Array: - switch rawValueType.Elem().Kind() { - case reflect.Uint8: - if fieldType.Elem().Kind() == reflect.Uint8 { - if col.SQLType.IsText() { - x := reflect.New(fieldType) - err := json.DefaultJSONHandler.Unmarshal(vv.Bytes(), x.Interface()) - if err != nil { - return err - } - fieldValue.Set(x.Elem()) - } else { - if fieldValue.Len() > 0 { - for i := 0; i < fieldValue.Len(); i++ { - if i < vv.Len() { - fieldValue.Index(i).Set(vv.Index(i)) - } - } - } else { - for i := 0; i < vv.Len(); i++ { - fieldValue.Set(reflect.Append(*fieldValue, vv.Index(i))) - } + bs, ok := asBytes(scanResult) + if ok && fieldType.Elem().Kind() == reflect.Uint8 { + if col.SQLType.IsText() { + x := reflect.New(fieldType) + err := json.DefaultJSONHandler.Unmarshal(bs, x.Interface()) + if err != nil { + return err + } + fieldValue.Set(x.Elem()) + } else { + if fieldValue.Len() > 0 { + for i := 0; i < fieldValue.Len(); i++ { + if i < vv.Len() { + fieldValue.Index(i).Set(vv.Index(i)) } } - return nil + } else { + for i := 0; i < vv.Len(); i++ { + fieldValue.Set(reflect.Append(*fieldValue, vv.Index(i))) + } } } - } - case reflect.String: - if rawValueType.Kind() == reflect.String { - fieldValue.SetString(vv.String()) - return nil - } - case reflect.Bool: - if rawValueType.Kind() == reflect.Bool { - fieldValue.SetBool(vv.Bool()) - return nil - } - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - switch rawValueType.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - fieldValue.SetInt(vv.Int()) - return nil - } - case reflect.Float32, reflect.Float64: - switch rawValueType.Kind() { - case reflect.Float32, reflect.Float64: - fieldValue.SetFloat(vv.Float()) - return nil - } - case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: - switch rawValueType.Kind() { - case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: - fieldValue.SetUint(vv.Uint()) - return nil - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - fieldValue.SetUint(uint64(vv.Int())) return nil } case reflect.Struct: @@ -630,47 +594,13 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec dbTZ = col.TimeZone } - if rawValueType == schemas.TimeType { - t := vv.Convert(schemas.TimeType).Interface().(time.Time) - - z, _ := t.Zone() - // set new location if database don't save timezone or give an incorrect timezone - if len(z) == 0 || t.Year() == 0 || t.Location().String() != dbTZ.String() { // !nashtsai! HACK tmp work around for lib/pq doesn't properly time with location - session.engine.logger.Debugf("empty zone key[%v] : %v | zone: %v | location: %+v\n", col.Name, t, z, *t.Location()) - t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), - t.Minute(), t.Second(), t.Nanosecond(), dbTZ) - } - - t = t.In(session.engine.TZLocation) - fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) - return nil - } else if rawValueType == schemas.IntType || rawValueType == schemas.Int64Type || - rawValueType == schemas.Int32Type { - t := time.Unix(vv.Int(), 0).In(session.engine.TZLocation) - fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) - return nil - } else { - if d, ok := vv.Interface().([]uint8); ok { - t, err := session.byte2Time(col, d) - if err != nil { - session.engine.logger.Errorf("byte2Time error: %v", err) - } else { - fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) - return nil - } - - } else if d, ok := vv.Interface().(string); ok { - t, err := session.str2Time(col, d) - if err != nil { - session.engine.logger.Errorf("byte2Time error: %v", err) - } else { - fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) - return nil - } - } else { - return fmt.Errorf("rawValueType is %v, value is %v", rawValueType, vv.Interface()) - } + t, err := asTime(scanResult, dbTZ, session.engine.TZLocation) + if err != nil { + return err } + + fieldValue.Set(reflect.ValueOf(*t).Convert(fieldType)) + return nil } else if nulVal, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { err := nulVal.Scan(vv.Interface()) if err == nil { @@ -733,12 +663,7 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec } } // switch fieldType.Kind() - data, err := value2Bytes(&rawValue) - if err != nil { - return err - } - - return session.bytes2Value(col, fieldValue, data) + return convertAssignV(fieldValue.Addr(), scanResult) } func (session *Session) slice2Bean(scanResults []interface{}, fields []string, bean interface{}, dataStruct *reflect.Value, table *schemas.Table) (schemas.PK, error) { diff --git a/session_convert.go b/session_convert.go index b8218a77..452801e2 100644 --- a/session_convert.go +++ b/session_convert.go @@ -5,16 +5,11 @@ package xorm import ( - "database/sql" - "errors" "fmt" - "reflect" "strconv" "strings" "time" - "xorm.io/xorm/convert" - "xorm.io/xorm/internal/json" "xorm.io/xorm/internal/utils" "xorm.io/xorm/schemas" ) @@ -73,449 +68,3 @@ func (session *Session) str2Time(col *schemas.Column, data string) (outTime time outTime = x.In(session.engine.TZLocation) return } - -func (session *Session) byte2Time(col *schemas.Column, data []byte) (outTime time.Time, outErr error) { - return session.str2Time(col, string(data)) -} - -// convert a db data([]byte) to a field value -func (session *Session) bytes2Value(col *schemas.Column, fieldValue *reflect.Value, data []byte) error { - if structConvert, ok := fieldValue.Addr().Interface().(convert.Conversion); ok { - return structConvert.FromDB(data) - } - - if structConvert, ok := fieldValue.Interface().(convert.Conversion); ok { - return structConvert.FromDB(data) - } - - var v interface{} - key := col.Name - fieldType := fieldValue.Type() - - switch fieldType.Kind() { - case reflect.Complex64, reflect.Complex128: - x := reflect.New(fieldType) - if len(data) > 0 { - err := json.DefaultJSONHandler.Unmarshal(data, x.Interface()) - if err != nil { - return err - } - fieldValue.Set(x.Elem()) - } - case reflect.Slice, reflect.Array, reflect.Map: - v = data - t := fieldType.Elem() - k := t.Kind() - if col.SQLType.IsText() { - x := reflect.New(fieldType) - if len(data) > 0 { - err := json.DefaultJSONHandler.Unmarshal(data, x.Interface()) - if err != nil { - return err - } - fieldValue.Set(x.Elem()) - } - } else if col.SQLType.IsBlob() { - if k == reflect.Uint8 { - fieldValue.Set(reflect.ValueOf(v)) - } else { - x := reflect.New(fieldType) - if len(data) > 0 { - err := json.DefaultJSONHandler.Unmarshal(data, x.Interface()) - if err != nil { - return err - } - fieldValue.Set(x.Elem()) - } - } - } else { - return ErrUnSupportedType - } - case reflect.String: - fieldValue.SetString(string(data)) - case reflect.Bool: - v, err := asBool(data) - if err != nil { - return fmt.Errorf("arg %v as bool: %s", key, err.Error()) - } - fieldValue.Set(reflect.ValueOf(v)) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - sdata := string(data) - var x int64 - var err error - // for mysql, when use bit, it returned \x01 - if col.SQLType.Name == schemas.Bit && - session.engine.dialect.URI().DBType == schemas.MYSQL { // !nashtsai! TODO dialect needs to provide conversion interface API - if len(data) == 1 { - x = int64(data[0]) - } else { - x = 0 - } - } else if strings.HasPrefix(sdata, "0x") { - x, err = strconv.ParseInt(sdata, 16, 64) - } else if strings.HasPrefix(sdata, "0") { - x, err = strconv.ParseInt(sdata, 8, 64) - } else if strings.EqualFold(sdata, "true") { - x = 1 - } else if strings.EqualFold(sdata, "false") { - x = 0 - } else { - x, err = strconv.ParseInt(sdata, 10, 64) - } - if err != nil { - return fmt.Errorf("arg %v as int: %s", key, err.Error()) - } - fieldValue.SetInt(x) - case reflect.Float32, reflect.Float64: - x, err := strconv.ParseFloat(string(data), 64) - if err != nil { - return fmt.Errorf("arg %v as float64: %s", key, err.Error()) - } - fieldValue.SetFloat(x) - case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: - x, err := strconv.ParseUint(string(data), 10, 64) - if err != nil { - return fmt.Errorf("arg %v as int: %s", key, err.Error()) - } - fieldValue.SetUint(x) - //Currently only support Time type - case reflect.Struct: - // !! 增加支持sql.Scanner接口的结构,如sql.NullString - if nulVal, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { - if err := nulVal.Scan(data); err != nil { - return fmt.Errorf("sql.Scan(%v) failed: %s ", data, err.Error()) - } - } else { - if fieldType.ConvertibleTo(schemas.TimeType) { - x, err := session.byte2Time(col, data) - if err != nil { - return err - } - v = x - fieldValue.Set(reflect.ValueOf(v).Convert(fieldType)) - } else if session.statement.UseCascade { - table, err := session.engine.tagParser.ParseWithCache(*fieldValue) - if err != nil { - return err - } - - // TODO: current only support 1 primary key - if len(table.PrimaryKeys) > 1 { - return errors.New("unsupported composited primary key cascade") - } - - var pk = make(schemas.PK, len(table.PrimaryKeys)) - rawValueType := table.ColumnType(table.PKColumns()[0].FieldName) - pk[0], err = str2PK(string(data), rawValueType) - if err != nil { - return err - } - - if !pk.IsZero() { - // !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch - // however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne - // property to be fetched lazily - structInter := reflect.New(fieldValue.Type()) - has, err := session.ID(pk).NoCascade().get(structInter.Interface()) - if err != nil { - return err - } - if has { - v = structInter.Elem().Interface() - fieldValue.Set(reflect.ValueOf(v)) - } else { - return errors.New("cascade obj is not exist") - } - } - } - } - case reflect.Ptr: - // !nashtsai! TODO merge duplicated codes above - //typeStr := fieldType.String() - switch fieldType.Elem().Kind() { - // case "*string": - case schemas.StringType.Kind(): - x := string(data) - fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) - // case "*bool": - case schemas.BoolType.Kind(): - d := string(data) - v, err := strconv.ParseBool(d) - if err != nil { - return fmt.Errorf("arg %v as bool: %s", key, err.Error()) - } - fieldValue.Set(reflect.ValueOf(&v).Convert(fieldType)) - // case "*complex64": - case schemas.Complex64Type.Kind(): - var x complex64 - if len(data) > 0 { - err := json.DefaultJSONHandler.Unmarshal(data, &x) - if err != nil { - return err - } - fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) - } - // case "*complex128": - case schemas.Complex128Type.Kind(): - var x complex128 - if len(data) > 0 { - err := json.DefaultJSONHandler.Unmarshal(data, &x) - if err != nil { - return err - } - fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) - } - // case "*float64": - case schemas.Float64Type.Kind(): - x, err := strconv.ParseFloat(string(data), 64) - if err != nil { - return fmt.Errorf("arg %v as float64: %s", key, err.Error()) - } - fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) - // case "*float32": - case schemas.Float32Type.Kind(): - var x float32 - x1, err := strconv.ParseFloat(string(data), 32) - if err != nil { - return fmt.Errorf("arg %v as float32: %s", key, err.Error()) - } - x = float32(x1) - fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) - // case "*uint64": - case schemas.Uint64Type.Kind(): - var x uint64 - x, err := strconv.ParseUint(string(data), 10, 64) - if err != nil { - return fmt.Errorf("arg %v as int: %s", key, err.Error()) - } - fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) - // case "*uint": - case schemas.UintType.Kind(): - var x uint - x1, err := strconv.ParseUint(string(data), 10, 64) - if err != nil { - return fmt.Errorf("arg %v as int: %s", key, err.Error()) - } - x = uint(x1) - fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) - // case "*uint32": - case schemas.Uint32Type.Kind(): - var x uint32 - x1, err := strconv.ParseUint(string(data), 10, 64) - if err != nil { - return fmt.Errorf("arg %v as int: %s", key, err.Error()) - } - x = uint32(x1) - fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) - // case "*uint8": - case schemas.Uint8Type.Kind(): - var x uint8 - x1, err := strconv.ParseUint(string(data), 10, 64) - if err != nil { - return fmt.Errorf("arg %v as int: %s", key, err.Error()) - } - x = uint8(x1) - fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) - // case "*uint16": - case schemas.Uint16Type.Kind(): - var x uint16 - x1, err := strconv.ParseUint(string(data), 10, 64) - if err != nil { - return fmt.Errorf("arg %v as int: %s", key, err.Error()) - } - x = uint16(x1) - fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) - // case "*int64": - case schemas.Int64Type.Kind(): - sdata := string(data) - var x int64 - var err error - // for mysql, when use bit, it returned \x01 - if col.SQLType.Name == schemas.Bit && - strings.Contains(session.engine.DriverName(), "mysql") { - if len(data) == 1 { - x = int64(data[0]) - } else { - x = 0 - } - } else if strings.HasPrefix(sdata, "0x") { - x, err = strconv.ParseInt(sdata, 16, 64) - } else if strings.HasPrefix(sdata, "0") { - x, err = strconv.ParseInt(sdata, 8, 64) - } else { - x, err = strconv.ParseInt(sdata, 10, 64) - } - if err != nil { - return fmt.Errorf("arg %v as int: %s", key, err.Error()) - } - fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) - // case "*int": - case schemas.IntType.Kind(): - sdata := string(data) - var x int - var x1 int64 - var err error - // for mysql, when use bit, it returned \x01 - if col.SQLType.Name == schemas.Bit && - strings.Contains(session.engine.DriverName(), "mysql") { - if len(data) == 1 { - x = int(data[0]) - } else { - x = 0 - } - } else if strings.HasPrefix(sdata, "0x") { - x1, err = strconv.ParseInt(sdata, 16, 64) - x = int(x1) - } else if strings.HasPrefix(sdata, "0") { - x1, err = strconv.ParseInt(sdata, 8, 64) - x = int(x1) - } else { - x1, err = strconv.ParseInt(sdata, 10, 64) - x = int(x1) - } - if err != nil { - return fmt.Errorf("arg %v as int: %s", key, err.Error()) - } - fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) - // case "*int32": - case schemas.Int32Type.Kind(): - sdata := string(data) - var x int32 - var x1 int64 - var err error - // for mysql, when use bit, it returned \x01 - if col.SQLType.Name == schemas.Bit && - session.engine.dialect.URI().DBType == schemas.MYSQL { - if len(data) == 1 { - x = int32(data[0]) - } else { - x = 0 - } - } else if strings.HasPrefix(sdata, "0x") { - x1, err = strconv.ParseInt(sdata, 16, 64) - x = int32(x1) - } else if strings.HasPrefix(sdata, "0") { - x1, err = strconv.ParseInt(sdata, 8, 64) - x = int32(x1) - } else { - x1, err = strconv.ParseInt(sdata, 10, 64) - x = int32(x1) - } - if err != nil { - return fmt.Errorf("arg %v as int: %s", key, err.Error()) - } - fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) - // case "*int8": - case schemas.Int8Type.Kind(): - sdata := string(data) - var x int8 - var x1 int64 - var err error - // for mysql, when use bit, it returned \x01 - if col.SQLType.Name == schemas.Bit && - strings.Contains(session.engine.DriverName(), "mysql") { - if len(data) == 1 { - x = int8(data[0]) - } else { - x = 0 - } - } else if strings.HasPrefix(sdata, "0x") { - x1, err = strconv.ParseInt(sdata, 16, 64) - x = int8(x1) - } else if strings.HasPrefix(sdata, "0") { - x1, err = strconv.ParseInt(sdata, 8, 64) - x = int8(x1) - } else { - x1, err = strconv.ParseInt(sdata, 10, 64) - x = int8(x1) - } - if err != nil { - return fmt.Errorf("arg %v as int: %s", key, err.Error()) - } - fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) - // case "*int16": - case schemas.Int16Type.Kind(): - sdata := string(data) - var x int16 - var x1 int64 - var err error - // for mysql, when use bit, it returned \x01 - if col.SQLType.Name == schemas.Bit && - strings.Contains(session.engine.DriverName(), "mysql") { - if len(data) == 1 { - x = int16(data[0]) - } else { - x = 0 - } - } else if strings.HasPrefix(sdata, "0x") { - x1, err = strconv.ParseInt(sdata, 16, 64) - x = int16(x1) - } else if strings.HasPrefix(sdata, "0") { - x1, err = strconv.ParseInt(sdata, 8, 64) - x = int16(x1) - } else { - x1, err = strconv.ParseInt(sdata, 10, 64) - x = int16(x1) - } - if err != nil { - return fmt.Errorf("arg %v as int: %s", key, err.Error()) - } - fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) - // case "*SomeStruct": - case reflect.Struct: - switch fieldType { - // case "*.time.Time": - case schemas.PtrTimeType: - x, err := session.byte2Time(col, data) - if err != nil { - return err - } - v = x - fieldValue.Set(reflect.ValueOf(&x)) - default: - if session.statement.UseCascade { - structInter := reflect.New(fieldType.Elem()) - table, err := session.engine.tagParser.ParseWithCache(structInter.Elem()) - if err != nil { - return err - } - - if len(table.PrimaryKeys) > 1 { - return errors.New("unsupported composited primary key cascade") - } - - var pk = make(schemas.PK, len(table.PrimaryKeys)) - rawValueType := table.ColumnType(table.PKColumns()[0].FieldName) - pk[0], err = str2PK(string(data), rawValueType) - if err != nil { - return err - } - - if !pk.IsZero() { - // !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch - // however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne - // property to be fetched lazily - has, err := session.ID(pk).NoCascade().get(structInter.Interface()) - if err != nil { - return err - } - if has { - v = structInter.Interface() - fieldValue.Set(reflect.ValueOf(v)) - } else { - return errors.New("cascade obj is not exist") - } - } - } else { - return fmt.Errorf("unsupported struct type in Scan: %s", fieldValue.Type().String()) - } - } - default: - return fmt.Errorf("unsupported type in Scan: %s", fieldValue.Type().String()) - } - default: - return fmt.Errorf("unsupported type in Scan: %s", fieldValue.Type().String()) - } - - return nil -} diff --git a/session_find.go b/session_find.go index 261e6b7f..41d68479 100644 --- a/session_find.go +++ b/session_find.go @@ -172,6 +172,11 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect return err } + types, err := rows.ColumnTypes() + if err != nil { + return err + } + var newElemFunc func(fields []string) reflect.Value elemType := containerValue.Type().Elem() var isPointer bool @@ -241,7 +246,7 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect if err != nil { return err } - err = session.rows2Beans(rows, fields, tb, newElemFunc, containerValueSetFunc) + err = session.rows2Beans(rows, fields, types, tb, newElemFunc, containerValueSetFunc) rows.Close() if err != nil { return err diff --git a/session_get.go b/session_get.go index cc6427d7..fa97e68e 100644 --- a/session_get.go +++ b/session_get.go @@ -192,7 +192,7 @@ func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table, func (session *Session) getSlice(rows *core.Rows, types []*sql.ColumnType, fields []string, bean interface{}) (bool, error) { switch t := bean.(type) { case *[]string: - res, err := session.engine.scanStringInterface(rows, types) + res, err := session.engine.scanStringInterface(rows, fields, types) if err != nil { return true, err } @@ -207,7 +207,7 @@ func (session *Session) getSlice(rows *core.Rows, types []*sql.ColumnType, field } return true, nil case *[]interface{}: - scanResults, err := session.engine.scanInterfaces(rows, types) + scanResults, err := session.engine.scanInterfaces(rows, fields, types) if err != nil { return true, err } @@ -232,7 +232,7 @@ func (session *Session) getSlice(rows *core.Rows, types []*sql.ColumnType, field func (session *Session) getMap(rows *core.Rows, types []*sql.ColumnType, fields []string, bean interface{}) (bool, error) { switch t := bean.(type) { case *map[string]string: - scanResults, err := session.engine.scanStringInterface(rows, types) + scanResults, err := session.engine.scanStringInterface(rows, fields, types) if err != nil { return true, err } @@ -241,7 +241,7 @@ func (session *Session) getMap(rows *core.Rows, types []*sql.ColumnType, fields } return true, nil case *map[string]interface{}: - scanResults, err := session.engine.scanInterfaces(rows, types) + scanResults, err := session.engine.scanInterfaces(rows, fields, types) if err != nil { return true, err } @@ -268,7 +268,7 @@ func (session *Session) getVars(rows *core.Rows, types []*sql.ColumnType, fields } func (session *Session) getStruct(rows *core.Rows, types []*sql.ColumnType, fields []string, table *schemas.Table, bean interface{}) (bool, error) { - scanResults, err := session.row2Slice(rows, fields, bean) + scanResults, err := session.row2Slice(rows, fields, types, bean) if err != nil { return false, err } diff --git a/session_insert.go b/session_insert.go index 7f8f3008..b41dbbac 100644 --- a/session_insert.go +++ b/session_insert.go @@ -375,7 +375,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { return 1, nil } - return 1, convertAssignV(aiValue.Addr(), id, session.engine.DatabaseTZ, session.engine.TZLocation) + return 1, convertAssignV(aiValue.Addr(), id) } else if len(table.AutoIncrement) > 0 && (session.engine.dialect.URI().DBType == schemas.POSTGRES || session.engine.dialect.URI().DBType == schemas.MSSQL) { res, err := session.queryBytes(sqlStr, args...) @@ -415,7 +415,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { return 1, nil } - return 1, convertAssignV(aiValue.Addr(), id, session.engine.DatabaseTZ, session.engine.TZLocation) + return 1, convertAssignV(*aiValue, id) } res, err := session.exec(sqlStr, args...) @@ -455,7 +455,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { return res.RowsAffected() } - if err := convertAssignV(aiValue.Addr(), id, session.engine.DatabaseTZ, session.engine.TZLocation); err != nil { + if err := convertAssignV(*aiValue, id); err != nil { return 0, err } diff --git a/session_raw.go b/session_raw.go index bf32c6ed..7eb8585d 100644 --- a/session_raw.go +++ b/session_raw.go @@ -6,13 +6,8 @@ package xorm import ( "database/sql" - "fmt" - "reflect" - "strconv" - "time" "xorm.io/xorm/core" - "xorm.io/xorm/schemas" ) func (session *Session) queryPreprocess(sqlStr *string, paramStr ...interface{}) { @@ -75,61 +70,6 @@ func (session *Session) queryRow(sqlStr string, args ...interface{}) *core.Row { return core.NewRow(session.queryRows(sqlStr, args...)) } -func value2String(rawValue *reflect.Value) (str string, err error) { - aa := reflect.TypeOf((*rawValue).Interface()) - vv := reflect.ValueOf((*rawValue).Interface()) - switch aa.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - str = strconv.FormatInt(vv.Int(), 10) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - str = strconv.FormatUint(vv.Uint(), 10) - case reflect.Float32, reflect.Float64: - str = strconv.FormatFloat(vv.Float(), 'f', -1, 64) - case reflect.String: - str = vv.String() - case reflect.Array, reflect.Slice: - switch aa.Elem().Kind() { - case reflect.Uint8: - data := rawValue.Interface().([]byte) - str = string(data) - if str == "\x00" { - str = "0" - } - default: - err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name()) - } - // time type - case reflect.Struct: - if aa.ConvertibleTo(schemas.TimeType) { - str = vv.Convert(schemas.TimeType).Interface().(time.Time).Format(time.RFC3339Nano) - } else { - err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name()) - } - case reflect.Bool: - str = strconv.FormatBool(vv.Bool()) - case reflect.Complex128, reflect.Complex64: - str = fmt.Sprintf("%v", vv.Complex()) - /* TODO: unsupported types below - case reflect.Map: - case reflect.Ptr: - case reflect.Uintptr: - case reflect.UnsafePointer: - case reflect.Chan, reflect.Func, reflect.Interface: - */ - default: - err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name()) - } - return -} - -func value2Bytes(rawValue *reflect.Value) ([]byte, error) { - str, err := value2String(rawValue) - if err != nil { - return nil, err - } - return []byte(str), nil -} - func (session *Session) queryBytes(sqlStr string, args ...interface{}) ([]map[string][]byte, error) { rows, err := session.queryRows(sqlStr, args...) if err != nil { From 5950824e37b0cbbd1996bc01c62aa78e5656bb00 Mon Sep 17 00:00:00 2001 From: raizen666 Date: Mon, 19 Jul 2021 12:49:50 +0800 Subject: [PATCH 29/31] Support build flag go-json to replace default json (#1982) `go build -tags=gojson` to use `github.com/goccy/go-json` as default json handler Co-authored-by: Lunny Xiao Reviewed-on: https://gitea.com/xorm/xorm/pulls/1982 Reviewed-by: Lunny Xiao Co-authored-by: raizen666 Co-committed-by: raizen666 --- go.mod | 1 + go.sum | 2 ++ internal/json/gojson.go | 28 ++++++++++++++++++++++++++++ 3 files changed, 31 insertions(+) create mode 100644 internal/json/gojson.go diff --git a/go.mod b/go.mod index 78d8d7d4..dbc59e76 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( github.com/json-iterator/go v1.1.11 github.com/lib/pq v1.7.0 github.com/mattn/go-sqlite3 v1.14.6 + github.com/goccy/go-json v0.7.4 github.com/shopspring/decimal v1.2.0 github.com/stretchr/testify v1.4.0 github.com/syndtr/goleveldb v1.0.0 diff --git a/go.sum b/go.sum index 85953202..da88d67a 100644 --- a/go.sum +++ b/go.sum @@ -32,6 +32,8 @@ github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHX github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/mattn/go-sqlite3 v1.14.6 h1:dNPt6NO46WmLVt2DLNpwczCmdV5boIZ6g/tlDrlRUbg= github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= +github.com/goccy/go-json v0.7.4 h1:B44qRUFwz/vxPKPISQ1KhvzRi9kZ28RAf6YtjriBZ5k= +github.com/goccy/go-json v0.7.4/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OHLH3mGKHDcjJRFFRrJa6eAM5H+CtDdOsPc= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742 h1:Esafd1046DLDQ0W1YjYsBW+p8U2u7vzgW2SQVmlNazg= diff --git a/internal/json/gojson.go b/internal/json/gojson.go new file mode 100644 index 00000000..4f1448e7 --- /dev/null +++ b/internal/json/gojson.go @@ -0,0 +1,28 @@ +// Copyright 2021 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build gojson + +package json + +import ( + gojson "github.com/goccy/go-json" +) + +func init() { + DefaultJSONHandler = GOjson{} +} + +// GOjson implements JSONInterface via gojson +type GOjson struct{} + +// Marshal implements JSONInterface +func (GOjson) Marshal(v interface{}) ([]byte, error) { + return gojson.Marshal(v) +} + +// Unmarshal implements JSONInterface +func (GOjson) Unmarshal(data []byte, v interface{}) error { + return gojson.Unmarshal(data, v) +} From 86775af2ecd1de58d172ebcbd5a10f7412dc8689 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Mon, 19 Jul 2021 13:43:53 +0800 Subject: [PATCH 30/31] refactor and add setjson function (#1997) Fix #1992 Reviewed-on: https://gitea.com/xorm/xorm/pulls/1997 Co-authored-by: Lunny Xiao Co-committed-by: Lunny Xiao --- convert.go | 4 ++ convert/time.go | 5 ++ dialects/dialect.go | 3 + dialects/mssql.go | 9 +++ dialects/mysql.go | 9 +++ dialects/oracle.go | 9 +++ dialects/postgres.go | 26 ++++++-- dialects/sqlite3.go | 11 +++- engine.go | 3 + scan.go | 3 + session.go | 154 ++++++++++++++----------------------------- session_convert.go | 70 -------------------- session_find.go | 6 ++ session_get.go | 5 +- session_iterate.go | 3 + session_query.go | 9 +++ session_update.go | 3 + 17 files changed, 148 insertions(+), 184 deletions(-) delete mode 100644 session_convert.go diff --git a/convert.go b/convert.go index 533dbe99..1aaf5dca 100644 --- a/convert.go +++ b/convert.go @@ -193,6 +193,8 @@ func asFloat64(src interface{}) (float64, error) { return float64(v.Int32), nil case *sql.NullInt64: return float64(v.Int64), nil + case *sql.NullFloat64: + return v.Float64, nil } rv := reflect.ValueOf(src) @@ -717,6 +719,8 @@ func convertAssignV(dv reflect.Value, src interface{}) error { func asKind(vv reflect.Value, tp reflect.Type) (interface{}, error) { switch tp.Kind() { + case reflect.Ptr: + return asKind(vv.Elem(), tp.Elem()) case reflect.Int64: return vv.Int(), nil case reflect.Int: diff --git a/convert/time.go b/convert/time.go index 5a3e5246..283c7f83 100644 --- a/convert/time.go +++ b/convert/time.go @@ -8,11 +8,16 @@ import ( "fmt" "strconv" "time" + + "xorm.io/xorm/internal/utils" ) // String2Time converts a string to time with original location func String2Time(s string, originalLocation *time.Location, convertedLocation *time.Location) (*time.Time, error) { if len(s) == 19 { + if s == utils.ZeroTime0 || s == utils.ZeroTime1 { + return &time.Time{}, nil + } dt, err := time.ParseInLocation("2006-01-02 15:04:05", s, originalLocation) if err != nil { return nil, err diff --git a/dialects/dialect.go b/dialects/dialect.go index df33155d..81d1ee8d 100644 --- a/dialects/dialect.go +++ b/dialects/dialect.go @@ -118,6 +118,9 @@ func (db *Base) HasRecords(queryer core.Queryer, ctx context.Context, query stri defer rows.Close() if rows.Next() { + if rows.Err() != nil { + return true, rows.Err() + } return true, nil } return false, nil diff --git a/dialects/mssql.go b/dialects/mssql.go index e708ba80..08232487 100644 --- a/dialects/mssql.go +++ b/dialects/mssql.go @@ -456,6 +456,9 @@ func (db *mssql) GetColumns(queryer core.Queryer, ctx context.Context, tableName cols := make(map[string]*schemas.Column) colSeq := make([]string, 0) for rows.Next() { + if rows.Err() != nil { + return nil, nil, rows.Err() + } var name, ctype, vdefault string var maxLen, precision, scale int var nullable, isPK, defaultIsNull, isIncrement bool @@ -524,6 +527,9 @@ func (db *mssql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schema tables := make([]*schemas.Table, 0) for rows.Next() { + if rows.Err() != nil { + return nil, rows.Err() + } table := schemas.NewEmptyTable() var name string err = rows.Scan(&name) @@ -558,6 +564,9 @@ WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =? indexes := make(map[string]*schemas.Index, 0) for rows.Next() { + if rows.Err() != nil { + return nil, rows.Err() + } var indexType int var indexName, colName, isUnique string diff --git a/dialects/mysql.go b/dialects/mysql.go index 9312c071..88c1038e 100644 --- a/dialects/mysql.go +++ b/dialects/mysql.go @@ -405,6 +405,9 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName cols := make(map[string]*schemas.Column) colSeq := make([]string, 0) for rows.Next() { + if rows.Err() != nil { + return nil, nil, rows.Err() + } col := new(schemas.Column) col.Indexes = make(map[string]int) @@ -519,6 +522,9 @@ func (db *mysql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schema tables := make([]*schemas.Table, 0) for rows.Next() { + if rows.Err() != nil { + return nil, rows.Err() + } table := schemas.NewEmptyTable() var name, engine string var autoIncr, comment *string @@ -566,6 +572,9 @@ func (db *mysql) GetIndexes(queryer core.Queryer, ctx context.Context, tableName indexes := make(map[string]*schemas.Index, 0) for rows.Next() { + if rows.Err() != nil { + return nil, rows.Err() + } var indexType int var indexName, colName, nonUnique string err = rows.Scan(&indexName, &nonUnique, &colName) diff --git a/dialects/oracle.go b/dialects/oracle.go index 5dd92887..9240046a 100644 --- a/dialects/oracle.go +++ b/dialects/oracle.go @@ -677,6 +677,9 @@ func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableNam cols := make(map[string]*schemas.Column) colSeq := make([]string, 0) for rows.Next() { + if rows.Err() != nil { + return nil, nil, rows.Err() + } col := new(schemas.Column) col.Indexes = make(map[string]int) @@ -772,6 +775,9 @@ func (db *oracle) GetTables(queryer core.Queryer, ctx context.Context) ([]*schem tables := make([]*schemas.Table, 0) for rows.Next() { + if rows.Err() != nil { + return nil, rows.Err() + } table := schemas.NewEmptyTable() err = rows.Scan(&table.Name) if err != nil { @@ -796,6 +802,9 @@ func (db *oracle) GetIndexes(queryer core.Queryer, ctx context.Context, tableNam indexes := make(map[string]*schemas.Index, 0) for rows.Next() { + if rows.Err() != nil { + return nil, rows.Err() + } var indexType int var indexName, colName, uniqueness string diff --git a/dialects/postgres.go b/dialects/postgres.go index 4ec780e8..e1dca631 100644 --- a/dialects/postgres.go +++ b/dialects/postgres.go @@ -810,7 +810,7 @@ func (db *postgres) Version(ctx context.Context, queryer core.Queryer) (*schemas var version string if !rows.Next() { - return nil, errors.New("Unknow version") + return nil, errors.New("unknow version") } if err := rows.Scan(&version); err != nil { @@ -1098,6 +1098,9 @@ WHERE n.nspname= s.table_schema AND c.relkind = 'r'::char AND c.relname = $1%s A colSeq := make([]string, 0) for rows.Next() { + if rows.Err() != nil { + return nil, nil, rows.Err() + } col := new(schemas.Column) col.Indexes = make(map[string]int) @@ -1192,7 +1195,7 @@ WHERE n.nspname= s.table_schema AND c.relkind = 'r'::char AND c.relname = $1%s A } } if _, ok := schemas.SqlTypes[col.SQLType.Name]; !ok { - return nil, nil, fmt.Errorf("Unknown colType: %s - %s", dataType, col.SQLType.Name) + return nil, nil, fmt.Errorf("unknown colType: %s - %s", dataType, col.SQLType.Name) } col.Length = maxLen @@ -1200,13 +1203,13 @@ WHERE n.nspname= s.table_schema AND c.relkind = 'r'::char AND c.relname = $1%s A if !col.DefaultIsEmpty { if col.SQLType.IsText() { if strings.HasSuffix(col.Default, "::character varying") { - col.Default = strings.TrimRight(col.Default, "::character varying") + col.Default = strings.TrimSuffix(col.Default, "::character varying") } else if !strings.HasPrefix(col.Default, "'") { col.Default = "'" + col.Default + "'" } } else if col.SQLType.IsTime() { if strings.HasSuffix(col.Default, "::timestamp without time zone") { - col.Default = strings.TrimRight(col.Default, "::timestamp without time zone") + col.Default = strings.TrimSuffix(col.Default, "::timestamp without time zone") } } } @@ -1234,6 +1237,9 @@ func (db *postgres) GetTables(queryer core.Queryer, ctx context.Context) ([]*sch tables := make([]*schemas.Table, 0) for rows.Next() { + if rows.Err() != nil { + return nil, rows.Err() + } table := schemas.NewEmptyTable() var name string err = rows.Scan(&name) @@ -1259,7 +1265,7 @@ func getIndexColName(indexdef string) []string { func (db *postgres) GetIndexes(queryer core.Queryer, ctx context.Context, tableName string) (map[string]*schemas.Index, error) { args := []interface{}{tableName} - s := fmt.Sprintf("SELECT indexname, indexdef FROM pg_indexes WHERE tablename=$1") + s := "SELECT indexname, indexdef FROM pg_indexes WHERE tablename=$1" if len(db.getSchema()) != 0 { args = append(args, db.getSchema()) s = s + " AND schemaname=$2" @@ -1271,8 +1277,11 @@ func (db *postgres) GetIndexes(queryer core.Queryer, ctx context.Context, tableN } defer rows.Close() - indexes := make(map[string]*schemas.Index, 0) + indexes := make(map[string]*schemas.Index) for rows.Next() { + if rows.Err() != nil { + return nil, rows.Err() + } var indexType int var indexName, indexdef string var colNames []string @@ -1450,6 +1459,9 @@ func QueryDefaultPostgresSchema(ctx context.Context, queryer core.Queryer) (stri } defer rows.Close() if rows.Next() { + if rows.Err() != nil { + return "", rows.Err() + } var defaultSchema string if err = rows.Scan(&defaultSchema); err != nil { return "", err @@ -1458,5 +1470,5 @@ func QueryDefaultPostgresSchema(ctx context.Context, queryer core.Queryer) (stri return strings.TrimSpace(parts[len(parts)-1]), nil } - return "", errors.New("No default schema") + return "", errors.New("no default schema") } diff --git a/dialects/sqlite3.go b/dialects/sqlite3.go index 581272ad..da28d9d1 100644 --- a/dialects/sqlite3.go +++ b/dialects/sqlite3.go @@ -415,12 +415,14 @@ func (db *sqlite3) GetColumns(queryer core.Queryer, ctx context.Context, tableNa defer rows.Close() var name string - for rows.Next() { + if rows.Next() { + if rows.Err() != nil { + return nil, nil, rows.Err() + } err = rows.Scan(&name) if err != nil { return nil, nil, err } - break } if name == "" { @@ -496,8 +498,11 @@ func (db *sqlite3) GetIndexes(queryer core.Queryer, ctx context.Context, tableNa } defer rows.Close() - indexes := make(map[string]*schemas.Index, 0) + indexes := make(map[string]*schemas.Index) for rows.Next() { + if rows.Err() != nil { + return nil, rows.Err() + } var tmpSQL sql.NullString err = rows.Scan(&tmpSQL) if err != nil { diff --git a/engine.go b/engine.go index b4ef9593..35104b04 100644 --- a/engine.go +++ b/engine.go @@ -551,6 +551,9 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch sess := engine.NewSession() defer sess.Close() for rows.Next() { + if rows.Err() != nil { + return rows.Err() + } _, err = io.WriteString(w, "INSERT INTO "+dstDialect.Quoter().Quote(dstTableName)+" ("+destColNames+") VALUES (") if err != nil { return err diff --git a/scan.go b/scan.go index e4c0e4a1..444aa8ac 100644 --- a/scan.go +++ b/scan.go @@ -286,6 +286,9 @@ func rows2maps(rows *core.Rows) (resultsSlice []map[string][]byte, err error) { return nil, err } for rows.Next() { + if rows.Err() != nil { + return nil, rows.Err() + } result, err := row2mapBytes(rows, types, fields) if err != nil { return nil, err diff --git a/session.go b/session.go index 5557d717..8c1d8c3b 100644 --- a/session.go +++ b/session.go @@ -364,25 +364,24 @@ func (session *Session) doPrepare(db *core.DB, sqlStr string) (stmt *core.Stmt, return } -func (session *Session) getField(dataStruct *reflect.Value, key string, table *schemas.Table, idx int) (*reflect.Value, error) { - var col *schemas.Column - if col = table.GetColumnIdx(key, idx); col == nil { - return nil, ErrFieldIsNotExist{key, table.Name} +func (session *Session) getField(dataStruct *reflect.Value, table *schemas.Table, colName string, idx int) (*schemas.Column, *reflect.Value, error) { + var col = table.GetColumnIdx(colName, idx) + if col == nil { + return nil, nil, ErrFieldIsNotExist{colName, table.Name} } fieldValue, err := col.ValueOfV(dataStruct) if err != nil { - return nil, err + return nil, nil, err } if fieldValue == nil { - return nil, ErrFieldIsNotValid{key, table.Name} + return nil, nil, ErrFieldIsNotValid{colName, table.Name} } - if !fieldValue.IsValid() || !fieldValue.CanSet() { - return nil, ErrFieldIsNotValid{key, table.Name} + return nil, nil, ErrFieldIsNotValid{colName, table.Name} } - return fieldValue, nil + return col, fieldValue, nil } // Cell cell is a result of one column field @@ -392,6 +391,9 @@ func (session *Session) rows2Beans(rows *core.Rows, fields []string, types []*sq table *schemas.Table, newElemFunc func([]string) reflect.Value, sliceValueSetFunc func(*reflect.Value, schemas.PK) error) error { for rows.Next() { + if rows.Err() != nil { + return rows.Err() + } var newValue = newElemFunc(fields) bean := newValue.Interface() dataStruct := newValue.Elem() @@ -435,6 +437,36 @@ func (session *Session) row2Slice(rows *core.Rows, fields []string, types []*sql return scanResults, nil } +func (session *Session) setJSON(fieldValue *reflect.Value, fieldType reflect.Type, scanResult interface{}) error { + bs, ok := asBytes(scanResult) + if !ok { + return fmt.Errorf("unsupported database data type: %#v", scanResult) + } + if len(bs) == 0 { + return nil + } + + if fieldType.Kind() == reflect.String { + fieldValue.SetString(string(bs)) + return nil + } + + if fieldValue.CanAddr() { + err := json.DefaultJSONHandler.Unmarshal(bs, fieldValue.Addr().Interface()) + if err != nil { + return err + } + } else { + x := reflect.New(fieldType) + err := json.DefaultJSONHandler.Unmarshal(bs, x.Interface()) + if err != nil { + return err + } + fieldValue.Set(x.Elem()) + } + return nil +} + func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflect.Value, scanResult interface{}, table *schemas.Table) error { v, ok := scanResult.(*interface{}) @@ -445,12 +477,6 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec return nil } - rawValue := reflect.Indirect(reflect.ValueOf(scanResult)) - // if row is null then ignore - if rawValue.Interface() == nil { - return nil - } - if fieldValue.CanAddr() { if structConvert, ok := fieldValue.Addr().Interface().(convert.Conversion); ok { data, ok := asBytes(scanResult) @@ -477,40 +503,11 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec return structConvert.FromDB(data) } - rawValueType := reflect.TypeOf(rawValue.Interface()) - vv := reflect.ValueOf(rawValue.Interface()) + vv := reflect.ValueOf(scanResult) fieldType := fieldValue.Type() if col.IsJSON { - var bs []byte - if rawValueType.Kind() == reflect.String { - bs = []byte(vv.String()) - } else if rawValueType.ConvertibleTo(schemas.BytesType) { - bs = vv.Bytes() - } else { - return fmt.Errorf("unsupported database data type: %s %v", col.Name, rawValueType.Kind()) - } - - if len(bs) > 0 { - if fieldType.Kind() == reflect.String { - fieldValue.SetString(string(bs)) - return nil - } - if fieldValue.CanAddr() { - err := json.DefaultJSONHandler.Unmarshal(bs, fieldValue.Addr().Interface()) - if err != nil { - return err - } - } else { - x := reflect.New(fieldType) - err := json.DefaultJSONHandler.Unmarshal(bs, x.Interface()) - if err != nil { - return err - } - fieldValue.Set(x.Elem()) - } - } - return nil + return session.setJSON(fieldValue, fieldType, scanResult) } switch fieldType.Kind() { @@ -529,30 +526,7 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec } return nil case reflect.Complex64, reflect.Complex128: - // TODO: reimplement this - var bs []byte - if rawValueType.Kind() == reflect.String { - bs = []byte(vv.String()) - } else if rawValueType.ConvertibleTo(schemas.BytesType) { - bs = vv.Bytes() - } - - if len(bs) > 0 { - if fieldValue.CanAddr() { - err := json.DefaultJSONHandler.Unmarshal(bs, fieldValue.Addr().Interface()) - if err != nil { - return err - } - } else { - x := reflect.New(fieldType) - err := json.DefaultJSONHandler.Unmarshal(bs, x.Interface()) - if err != nil { - return err - } - fieldValue.Set(x.Elem()) - } - } - return nil + return session.setJSON(fieldValue, fieldType, scanResult) case reflect.Slice, reflect.Array: bs, ok := asBytes(scanResult) if ok && fieldType.Elem().Kind() == reflect.Uint8 { @@ -602,33 +576,11 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec fieldValue.Set(reflect.ValueOf(*t).Convert(fieldType)) return nil } else if nulVal, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { - err := nulVal.Scan(vv.Interface()) + err := nulVal.Scan(scanResult) if err == nil { return nil } session.engine.logger.Errorf("sql.Sanner error: %v", err) - } else if col.IsJSON { - if rawValueType.Kind() == reflect.String { - x := reflect.New(fieldType) - if len([]byte(vv.String())) > 0 { - err := json.DefaultJSONHandler.Unmarshal([]byte(vv.String()), x.Interface()) - if err != nil { - return err - } - fieldValue.Set(x.Elem()) - } - return nil - } else if rawValueType.Kind() == reflect.Slice { - x := reflect.New(fieldType) - if len(vv.Bytes()) > 0 { - err := json.DefaultJSONHandler.Unmarshal(vv.Bytes(), x.Interface()) - if err != nil { - return err - } - fieldValue.Set(x.Elem()) - } - return nil - } } else if session.statement.UseCascade { table, err := session.engine.tagParser.ParseWithCache(*fieldValue) if err != nil { @@ -639,7 +591,7 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec return errors.New("unsupported non or composited primary key cascade") } var pk = make(schemas.PK, len(table.PrimaryKeys)) - pk[0], err = asKind(vv, rawValueType) + pk[0], err = asKind(vv, reflect.TypeOf(scanResult)) if err != nil { return err } @@ -675,9 +627,9 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b var tempMap = make(map[string]int) var pk schemas.PK - for ii, key := range fields { + for i, colName := range fields { var idx int - var lKey = strings.ToLower(key) + var lKey = strings.ToLower(colName) var ok bool if idx, ok = tempMap[lKey]; !ok { @@ -685,13 +637,9 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b } else { idx = idx + 1 } - tempMap[lKey] = idx - col := table.GetColumnIdx(key, idx) - var scanResult = scanResults[ii] - - fieldValue, err := session.getField(dataStruct, key, table, idx) + col, fieldValue, err := session.getField(dataStruct, table, colName, idx) if err != nil { if _, ok := err.(ErrFieldIsNotValid); !ok { session.engine.logger.Warnf("%v", err) @@ -702,11 +650,11 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b continue } - if err := session.convertBeanField(col, fieldValue, scanResult, table); err != nil { + if err := session.convertBeanField(col, fieldValue, scanResults[i], table); err != nil { return nil, err } if col.IsPrimaryKey { - pk = append(pk, scanResult) + pk = append(pk, scanResults[i]) } } return pk, nil diff --git a/session_convert.go b/session_convert.go deleted file mode 100644 index 452801e2..00000000 --- a/session_convert.go +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright 2017 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package xorm - -import ( - "fmt" - "strconv" - "strings" - "time" - - "xorm.io/xorm/internal/utils" - "xorm.io/xorm/schemas" -) - -func (session *Session) str2Time(col *schemas.Column, data string) (outTime time.Time, outErr error) { - sdata := strings.TrimSpace(data) - var x time.Time - var err error - - var parseLoc = session.engine.DatabaseTZ - if col.TimeZone != nil { - parseLoc = col.TimeZone - } - - if sdata == utils.ZeroTime0 || sdata == utils.ZeroTime1 { - } else if !strings.ContainsAny(sdata, "- :") { // !nashtsai! has only found that mymysql driver is using this for time type column - // time stamp - sd, err := strconv.ParseInt(sdata, 10, 64) - if err == nil { - x = time.Unix(sd, 0) - } - } else if len(sdata) > 19 && strings.Contains(sdata, "-") { - x, err = time.ParseInLocation(time.RFC3339Nano, sdata, parseLoc) - session.engine.logger.Debugf("time(1) key[%v]: %+v | sdata: [%v]\n", col.Name, x, sdata) - if err != nil { - x, err = time.ParseInLocation("2006-01-02 15:04:05.999999999", sdata, parseLoc) - } - if err != nil { - x, err = time.ParseInLocation("2006-01-02 15:04:05.9999999 Z07:00", sdata, parseLoc) - } - } else if len(sdata) == 19 && strings.Contains(sdata, "-") { - x, err = time.ParseInLocation("2006-01-02 15:04:05", sdata, parseLoc) - } else if len(sdata) == 10 && sdata[4] == '-' && sdata[7] == '-' { - x, err = time.ParseInLocation("2006-01-02", sdata, parseLoc) - } else if col.SQLType.Name == schemas.Time { - if strings.Contains(sdata, " ") { - ssd := strings.Split(sdata, " ") - sdata = ssd[1] - } - - sdata = strings.TrimSpace(sdata) - if session.engine.dialect.URI().DBType == schemas.MYSQL && len(sdata) > 8 { - sdata = sdata[len(sdata)-8:] - } - - st := fmt.Sprintf("2006-01-02 %v", sdata) - x, err = time.ParseInLocation("2006-01-02 15:04:05", st, parseLoc) - } else { - outErr = fmt.Errorf("unsupported time format %v", sdata) - return - } - if err != nil { - outErr = fmt.Errorf("unsupported time format %v: %v", sdata, err) - return - } - outTime = x.In(session.engine.TZLocation) - return -} diff --git a/session_find.go b/session_find.go index 41d68479..89e34e80 100644 --- a/session_find.go +++ b/session_find.go @@ -255,6 +255,9 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect } for rows.Next() { + if rows.Err() != nil { + return rows.Err() + } var newValue = newElemFunc(fields) bean := newValue.Interface() @@ -322,6 +325,9 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in var i int ids = make([]schemas.PK, 0) for rows.Next() { + if rows.Err() != nil { + return rows.Err() + } i++ if i > 500 { session.engine.logger.Debugf("[cacheFind] ids length > 500, no cache") diff --git a/session_get.go b/session_get.go index fa97e68e..1062bd9d 100644 --- a/session_get.go +++ b/session_get.go @@ -313,9 +313,12 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf defer rows.Close() if rows.Next() { + if rows.Err() != nil { + return true, rows.Err() + } err = rows.ScanSlice(&res) if err != nil { - return false, err + return true, err } } else { return false, ErrCacheFailed diff --git a/session_iterate.go b/session_iterate.go index 8cab8f48..dbbeb3f4 100644 --- a/session_iterate.go +++ b/session_iterate.go @@ -43,6 +43,9 @@ func (session *Session) Iterate(bean interface{}, fun IterFunc) error { i := 0 for rows.Next() { + if rows.Err() != nil { + return rows.Err() + } b := reflect.New(rows.beanType).Interface() err = rows.Scan(b) if err != nil { diff --git a/session_query.go b/session_query.go index d14c3908..8543ba12 100644 --- a/session_query.go +++ b/session_query.go @@ -33,6 +33,9 @@ func (session *Session) rows2Strings(rows *core.Rows) (resultsSlice []map[string } for rows.Next() { + if rows.Err() != nil { + return nil, rows.Err() + } result, err := session.engine.row2mapStr(rows, types, fields) if err != nil { return nil, err @@ -54,6 +57,9 @@ func (session *Session) rows2SliceString(rows *core.Rows) (resultsSlice [][]stri } for rows.Next() { + if rows.Err() != nil { + return nil, rows.Err() + } record, err := session.engine.row2sliceStr(rows, types, fields) if err != nil { return nil, err @@ -114,6 +120,9 @@ func (session *Session) rows2Interfaces(rows *core.Rows) (resultsSlice []map[str return nil, err } for rows.Next() { + if rows.Err() != nil { + return nil, rows.Err() + } result, err := session.engine.row2mapInterface(rows, types, fields) if err != nil { return nil, err diff --git a/session_update.go b/session_update.go index 78907e43..32e28ae0 100644 --- a/session_update.go +++ b/session_update.go @@ -59,6 +59,9 @@ func (session *Session) cacheUpdate(table *schemas.Table, tableName, sqlStr stri ids = make([]schemas.PK, 0) for rows.Next() { + if rows.Err() != nil { + return rows.Err() + } var res = make([]string, len(table.PrimaryKeys)) err = rows.ScanSlice(&res) if err != nil { From a7e010df2dd0e38ac86a587fc760858423a8f480 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Tue, 20 Jul 2021 13:46:24 +0800 Subject: [PATCH 31/31] refactor insert condition generation (#1998) Reviewed-on: https://gitea.com/xorm/xorm/pulls/1998 Co-authored-by: Lunny Xiao Co-committed-by: Lunny Xiao --- convert.go | 4 +- convert/interface.go | 1 + convert/time.go | 2 +- dialects/dialect.go | 5 +- dialects/driver.go | 11 -- dialects/mssql.go | 26 +-- dialects/mysql.go | 38 ++-- dialects/oracle.go | 35 ++-- dialects/postgres.go | 38 ++-- dialects/postgres_test.go | 2 - dialects/sqlite3.go | 26 +-- engine.go | 6 +- go.mod | 17 +- go.sum | 35 +++- integrations/engine_test.go | 1 - integrations/session_get_test.go | 10 +- integrations/session_insert_test.go | 2 +- integrations/session_update_test.go | 1 - internal/statements/statement.go | 269 ++++++++++++++-------------- internal/statements/values.go | 2 +- internal/utils/strings.go | 4 +- names/mapper.go | 2 +- rows.go | 35 +--- scan.go | 14 +- schemas/table_test.go | 2 - schemas/type.go | 1 + session.go | 5 +- session_exist.go | 5 +- session_find.go | 11 +- session_get.go | 11 +- session_insert.go | 1 - session_iterate.go | 5 +- session_query.go | 18 +- session_update.go | 6 +- tags/parser.go | 1 + 35 files changed, 324 insertions(+), 328 deletions(-) diff --git a/convert.go b/convert.go index 1aaf5dca..c3eb4de9 100644 --- a/convert.go +++ b/convert.go @@ -373,7 +373,6 @@ func asTime(src interface{}, dbLoc *time.Location, uiLoc *time.Location) (*time. case *sql.NullInt64: tm := time.Unix(t.Int64, 0).In(uiLoc) return &tm, nil - } return nil, fmt.Errorf("unsupported value %#v as time", src) } @@ -751,7 +750,6 @@ func asKind(vv reflect.Value, tp reflect.Type) (interface{}, error) { } return v, nil } - } return nil, fmt.Errorf("unsupported primary key type: %v, %v", tp, vv) } @@ -946,8 +944,10 @@ var ( _ sql.Scanner = &EmptyScanner{} ) +// EmptyScanner represents an empty scanner which will ignore the scan type EmptyScanner struct{} +// Scan implements sql.Scanner func (EmptyScanner) Scan(value interface{}) error { return nil } diff --git a/convert/interface.go b/convert/interface.go index 2b055253..b0f28c81 100644 --- a/convert/interface.go +++ b/convert/interface.go @@ -10,6 +10,7 @@ import ( "time" ) +// Interface2Interface converts interface of pointer as interface of value func Interface2Interface(userLocation *time.Location, v interface{}) (interface{}, error) { if v == nil { return nil, nil diff --git a/convert/time.go b/convert/time.go index 283c7f83..6a53171b 100644 --- a/convert/time.go +++ b/convert/time.go @@ -45,5 +45,5 @@ func String2Time(s string, originalLocation *time.Location, convertedLocation *t return &tm, nil } } - return nil, fmt.Errorf("unsupported convertion from %s to time", s) + return nil, fmt.Errorf("unsupported conversion from %s to time", s) } diff --git a/dialects/dialect.go b/dialects/dialect.go index 81d1ee8d..fc11eac1 100644 --- a/dialects/dialect.go +++ b/dialects/dialect.go @@ -118,12 +118,9 @@ func (db *Base) HasRecords(queryer core.Queryer, ctx context.Context, query stri defer rows.Close() if rows.Next() { - if rows.Err() != nil { - return true, rows.Err() - } return true, nil } - return false, nil + return false, rows.Err() } // IsColumnExist returns true if the column of the table exist diff --git a/dialects/driver.go b/dialects/driver.go index 0b6187d3..c511b665 100644 --- a/dialects/driver.go +++ b/dialects/driver.go @@ -18,14 +18,9 @@ type ScanContext struct { UserLocation *time.Location } -type DriverFeatures struct { - SupportNullable bool -} - // Driver represents a database driver type Driver interface { Parse(string, string) (*URI, error) - Features() DriverFeatures GenScanResult(string) (interface{}, error) // according given column type generating a suitable scan interface Scan(*ScanContext, *core.Rows, []*sql.ColumnType, ...interface{}) error } @@ -82,9 +77,3 @@ type baseDriver struct{} func (b *baseDriver) Scan(ctx *ScanContext, rows *core.Rows, types []*sql.ColumnType, v ...interface{}) error { return rows.Scan(v...) } - -func (b *baseDriver) Features() DriverFeatures { - return DriverFeatures{ - SupportNullable: true, - } -} diff --git a/dialects/mssql.go b/dialects/mssql.go index 08232487..742928b0 100644 --- a/dialects/mssql.go +++ b/dialects/mssql.go @@ -264,6 +264,9 @@ func (db *mssql) Version(ctx context.Context, queryer core.Queryer) (*schemas.Ve var version, level, edition string if !rows.Next() { + if rows.Err() != nil { + return nil, rows.Err() + } return nil, errors.New("unknow version") } @@ -456,9 +459,6 @@ func (db *mssql) GetColumns(queryer core.Queryer, ctx context.Context, tableName cols := make(map[string]*schemas.Column) colSeq := make([]string, 0) for rows.Next() { - if rows.Err() != nil { - return nil, nil, rows.Err() - } var name, ctype, vdefault string var maxLen, precision, scale int var nullable, isPK, defaultIsNull, isIncrement bool @@ -512,6 +512,9 @@ func (db *mssql) GetColumns(queryer core.Queryer, ctx context.Context, tableName cols[col.Name] = col colSeq = append(colSeq, col.Name) } + if rows.Err() != nil { + return nil, nil, rows.Err() + } return colSeq, cols, nil } @@ -527,9 +530,6 @@ func (db *mssql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schema tables := make([]*schemas.Table, 0) for rows.Next() { - if rows.Err() != nil { - return nil, rows.Err() - } table := schemas.NewEmptyTable() var name string err = rows.Scan(&name) @@ -539,6 +539,9 @@ func (db *mssql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schema table.Name = strings.Trim(name, "` ") tables = append(tables, table) } + if rows.Err() != nil { + return nil, rows.Err() + } return tables, nil } @@ -562,11 +565,8 @@ WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =? } defer rows.Close() - indexes := make(map[string]*schemas.Index, 0) + indexes := make(map[string]*schemas.Index) for rows.Next() { - if rows.Err() != nil { - return nil, rows.Err() - } var indexType int var indexName, colName, isUnique string @@ -604,6 +604,9 @@ WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =? } index.AddColumn(colName) } + if rows.Err() != nil { + return nil, rows.Err() + } return indexes, nil } @@ -664,8 +667,7 @@ func (p *odbcDriver) Parse(driverName, dataSourceName string) (*URI, error) { for _, c := range kv { vv := strings.Split(strings.TrimSpace(c), "=") if len(vv) == 2 { - switch strings.ToLower(vv[0]) { - case "database": + if strings.ToLower(vv[0]) == "database" { dbName = vv[1] } } diff --git a/dialects/mysql.go b/dialects/mysql.go index 88c1038e..71ee3864 100644 --- a/dialects/mysql.go +++ b/dialects/mysql.go @@ -213,7 +213,10 @@ func (db *mysql) Version(ctx context.Context, queryer core.Queryer) (*schemas.Ve var version string if !rows.Next() { - return nil, errors.New("Unknow version") + if rows.Err() != nil { + return nil, rows.Err() + } + return nil, errors.New("unknow version") } if err := rows.Scan(&version); err != nil { @@ -254,9 +257,6 @@ func (db *mysql) SetParams(params map[string]string) { fallthrough case "COMPRESSED": db.rowFormat = t - break - default: - break } } } @@ -405,9 +405,6 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName cols := make(map[string]*schemas.Column) colSeq := make([]string, 0) for rows.Next() { - if rows.Err() != nil { - return nil, nil, rows.Err() - } col := new(schemas.Column) col.Indexes = make(map[string]int) @@ -506,6 +503,9 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName cols[col.Name] = col colSeq = append(colSeq, col.Name) } + if rows.Err() != nil { + return nil, nil, rows.Err() + } return colSeq, cols, nil } @@ -522,9 +522,6 @@ func (db *mysql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schema tables := make([]*schemas.Table, 0) for rows.Next() { - if rows.Err() != nil { - return nil, rows.Err() - } table := schemas.NewEmptyTable() var name, engine string var autoIncr, comment *string @@ -540,6 +537,9 @@ func (db *mysql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schema table.StoreEngine = engine tables = append(tables, table) } + if rows.Err() != nil { + return nil, rows.Err() + } return tables, nil } @@ -570,11 +570,8 @@ func (db *mysql) GetIndexes(queryer core.Queryer, ctx context.Context, tableName } defer rows.Close() - indexes := make(map[string]*schemas.Index, 0) + indexes := make(map[string]*schemas.Index) for rows.Next() { - if rows.Err() != nil { - return nil, rows.Err() - } var indexType int var indexName, colName, nonUnique string err = rows.Scan(&indexName, &nonUnique, &colName) @@ -586,7 +583,7 @@ func (db *mysql) GetIndexes(queryer core.Queryer, ctx context.Context, tableName continue } - if "YES" == nonUnique || nonUnique == "1" { + if nonUnique == "YES" || nonUnique == "1" { indexType = schemas.IndexType } else { indexType = schemas.UniqueType @@ -610,6 +607,9 @@ func (db *mysql) GetIndexes(queryer core.Queryer, ctx context.Context, tableName } index.AddColumn(colName) } + if rows.Err() != nil { + return nil, rows.Err() + } return indexes, nil } @@ -696,14 +696,12 @@ func (p *mysqlDriver) Parse(driverName, dataSourceName string) (*URI, error) { for _, kv := range kvs { splits := strings.Split(kv, "=") if len(splits) == 2 { - switch splits[0] { - case "charset": + if splits[0] == "charset" { uri.Charset = splits[1] } } } } - } } return uri, nil @@ -720,13 +718,13 @@ func (p *mysqlDriver) GenScanResult(colType string) (interface{}, error) { case "TINYINT", "SMALLINT", "MEDIUMINT", "INT": var s sql.NullInt32 return &s, nil - case "FLOAT", "REAL", "DOUBLE PRECISION": + case "FLOAT", "REAL", "DOUBLE PRECISION", "DOUBLE": var s sql.NullFloat64 return &s, nil case "DECIMAL", "NUMERIC": var s sql.NullString return &s, nil - case "DATETIME": + case "DATETIME", "TIMESTAMP": var s sql.NullTime return &s, nil case "BIT": diff --git a/dialects/oracle.go b/dialects/oracle.go index 9240046a..902e0c66 100644 --- a/dialects/oracle.go +++ b/dialects/oracle.go @@ -525,6 +525,9 @@ func (db *oracle) Version(ctx context.Context, queryer core.Queryer) (*schemas.V var version string if !rows.Next() { + if rows.Err() != nil { + return nil, rows.Err() + } return nil, errors.New("unknow version") } @@ -677,9 +680,6 @@ func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableNam cols := make(map[string]*schemas.Column) colSeq := make([]string, 0) for rows.Next() { - if rows.Err() != nil { - return nil, nil, rows.Err() - } col := new(schemas.Column) col.Indexes = make(map[string]int) @@ -759,6 +759,9 @@ func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableNam cols[col.Name] = col colSeq = append(colSeq, col.Name) } + if rows.Err() != nil { + return nil, nil, rows.Err() + } return colSeq, cols, nil } @@ -775,9 +778,6 @@ func (db *oracle) GetTables(queryer core.Queryer, ctx context.Context) ([]*schem tables := make([]*schemas.Table, 0) for rows.Next() { - if rows.Err() != nil { - return nil, rows.Err() - } table := schemas.NewEmptyTable() err = rows.Scan(&table.Name) if err != nil { @@ -786,6 +786,9 @@ func (db *oracle) GetTables(queryer core.Queryer, ctx context.Context) ([]*schem tables = append(tables, table) } + if rows.Err() != nil { + return nil, rows.Err() + } return tables, nil } @@ -800,11 +803,8 @@ func (db *oracle) GetIndexes(queryer core.Queryer, ctx context.Context, tableNam } defer rows.Close() - indexes := make(map[string]*schemas.Index, 0) + indexes := make(map[string]*schemas.Index) for rows.Next() { - if rows.Err() != nil { - return nil, rows.Err() - } var indexType int var indexName, colName, uniqueness string @@ -838,6 +838,9 @@ func (db *oracle) GetIndexes(queryer core.Queryer, ctx context.Context, tableNam } index.AddColumn(colName) } + if rows.Err() != nil { + return nil, rows.Err() + } return indexes, nil } @@ -851,7 +854,7 @@ type godrorDriver struct { baseDriver } -func (cfg *godrorDriver) Parse(driverName, dataSourceName string) (*URI, error) { +func (g *godrorDriver) Parse(driverName, dataSourceName string) (*URI, error) { db := &URI{DBType: schemas.ORACLE} dsnPattern := regexp.MustCompile( `^(?:(?P.*?)(?::(?P.*))?@)?` + // [user[:password]@] @@ -863,8 +866,7 @@ func (cfg *godrorDriver) Parse(driverName, dataSourceName string) (*URI, error) names := dsnPattern.SubexpNames() for i, match := range matches { - switch names[i] { - case "dbname": + if names[i] == "dbname" { db.DBName = match } } @@ -874,7 +876,7 @@ func (cfg *godrorDriver) Parse(driverName, dataSourceName string) (*URI, error) return db, nil } -func (p *godrorDriver) GenScanResult(colType string) (interface{}, error) { +func (g *godrorDriver) GenScanResult(colType string) (interface{}, error) { switch colType { case "CHAR", "NCHAR", "VARCHAR", "VARCHAR2", "NVARCHAR2", "LONG", "CLOB", "NCLOB": var s sql.NullString @@ -900,7 +902,7 @@ type oci8Driver struct { // dataSourceName=user/password@ipv4:port/dbname // dataSourceName=user/password@[ipv6]:port/dbname -func (p *oci8Driver) Parse(driverName, dataSourceName string) (*URI, error) { +func (o *oci8Driver) Parse(driverName, dataSourceName string) (*URI, error) { db := &URI{DBType: schemas.ORACLE} dsnPattern := regexp.MustCompile( `^(?P.*)\/(?P.*)@` + // user:password@ @@ -909,8 +911,7 @@ func (p *oci8Driver) Parse(driverName, dataSourceName string) (*URI, error) { matches := dsnPattern.FindStringSubmatch(dataSourceName) names := dsnPattern.SubexpNames() for i, match := range matches { - switch names[i] { - case "dbname": + if names[i] == "dbname" { db.DBName = match } } diff --git a/dialects/postgres.go b/dialects/postgres.go index e1dca631..6462982d 100644 --- a/dialects/postgres.go +++ b/dialects/postgres.go @@ -810,6 +810,9 @@ func (db *postgres) Version(ctx context.Context, queryer core.Queryer) (*schemas var version string if !rows.Next() { + if rows.Err() != nil { + return nil, rows.Err() + } return nil, errors.New("unknow version") } @@ -1062,7 +1065,10 @@ func (db *postgres) IsColumnExist(queryer core.Queryer, ctx context.Context, tab } defer rows.Close() - return rows.Next(), nil + if rows.Next() { + return true, nil + } + return false, rows.Err() } func (db *postgres) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { @@ -1098,9 +1104,6 @@ WHERE n.nspname= s.table_schema AND c.relkind = 'r'::char AND c.relname = $1%s A colSeq := make([]string, 0) for rows.Next() { - if rows.Err() != nil { - return nil, nil, rows.Err() - } col := new(schemas.Column) col.Indexes = make(map[string]int) @@ -1216,6 +1219,9 @@ WHERE n.nspname= s.table_schema AND c.relkind = 'r'::char AND c.relname = $1%s A cols[col.Name] = col colSeq = append(colSeq, col.Name) } + if rows.Err() != nil { + return nil, nil, rows.Err() + } return colSeq, cols, nil } @@ -1237,9 +1243,6 @@ func (db *postgres) GetTables(queryer core.Queryer, ctx context.Context) ([]*sch tables := make([]*schemas.Table, 0) for rows.Next() { - if rows.Err() != nil { - return nil, rows.Err() - } table := schemas.NewEmptyTable() var name string err = rows.Scan(&name) @@ -1249,6 +1252,9 @@ func (db *postgres) GetTables(queryer core.Queryer, ctx context.Context) ([]*sch table.Name = name tables = append(tables, table) } + if rows.Err() != nil { + return nil, rows.Err() + } return tables, nil } @@ -1279,9 +1285,6 @@ func (db *postgres) GetIndexes(queryer core.Queryer, ctx context.Context, tableN indexes := make(map[string]*schemas.Index) for rows.Next() { - if rows.Err() != nil { - return nil, rows.Err() - } var indexType int var indexName, indexdef string var colNames []string @@ -1322,6 +1325,9 @@ func (db *postgres) GetIndexes(queryer core.Queryer, ctx context.Context, tableN index.IsRegular = isRegular indexes[index.Name] = index } + if rows.Err() != nil { + return nil, rows.Err() + } return indexes, nil } @@ -1333,12 +1339,6 @@ type pqDriver struct { baseDriver } -func (b *pqDriver) Features() DriverFeatures { - return DriverFeatures{ - SupportNullable: false, - } -} - type values map[string]string func (vs values) Set(k, v string) { @@ -1459,9 +1459,6 @@ func QueryDefaultPostgresSchema(ctx context.Context, queryer core.Queryer) (stri } defer rows.Close() if rows.Next() { - if rows.Err() != nil { - return "", rows.Err() - } var defaultSchema string if err = rows.Scan(&defaultSchema); err != nil { return "", err @@ -1469,6 +1466,9 @@ func QueryDefaultPostgresSchema(ctx context.Context, queryer core.Queryer) (stri parts := strings.Split(defaultSchema, ",") return strings.TrimSpace(parts[len(parts)-1]), nil } + if rows.Err() != nil { + return "", rows.Err() + } return "", errors.New("no default schema") } diff --git a/dialects/postgres_test.go b/dialects/postgres_test.go index c0a8eb6f..e0c36f92 100644 --- a/dialects/postgres_test.go +++ b/dialects/postgres_test.go @@ -76,9 +76,7 @@ func TestParsePgx(t *testing.T) { } else if err == nil && !reflect.DeepEqual(test.expected, uri.DBName) { t.Errorf("%q got: %#v want: %#v", test.in, uri.DBName, test.expected) } - } - } func TestGetIndexColName(t *testing.T) { diff --git a/dialects/sqlite3.go b/dialects/sqlite3.go index da28d9d1..89f86147 100644 --- a/dialects/sqlite3.go +++ b/dialects/sqlite3.go @@ -169,7 +169,10 @@ func (db *sqlite3) Version(ctx context.Context, queryer core.Queryer) (*schemas. var version string if !rows.Next() { - return nil, errors.New("Unknow version") + if rows.Err() != nil { + return nil, rows.Err() + } + return nil, errors.New("unknow version") } if err := rows.Scan(&version); err != nil { @@ -416,14 +419,14 @@ func (db *sqlite3) GetColumns(queryer core.Queryer, ctx context.Context, tableNa var name string if rows.Next() { - if rows.Err() != nil { - return nil, nil, rows.Err() - } err = rows.Scan(&name) if err != nil { return nil, nil, err } } + if rows.Err() != nil { + return nil, nil, rows.Err() + } if name == "" { return nil, nil, errors.New("no table named " + tableName) @@ -485,6 +488,9 @@ func (db *sqlite3) GetTables(queryer core.Queryer, ctx context.Context) ([]*sche } tables = append(tables, table) } + if rows.Err() != nil { + return nil, rows.Err() + } return tables, nil } @@ -500,9 +506,6 @@ func (db *sqlite3) GetIndexes(queryer core.Queryer, ctx context.Context, tableNa indexes := make(map[string]*schemas.Index) for rows.Next() { - if rows.Err() != nil { - return nil, rows.Err() - } var tmpSQL sql.NullString err = rows.Scan(&tmpSQL) if err != nil { @@ -547,6 +550,9 @@ func (db *sqlite3) GetIndexes(queryer core.Queryer, ctx context.Context, tableNa index.IsRegular = isRegular indexes[index.Name] = index } + if rows.Err() != nil { + return nil, rows.Err() + } return indexes, nil } @@ -592,9 +598,3 @@ func (p *sqlite3Driver) GenScanResult(colType string) (interface{}, error) { return &r, nil } } - -func (b *sqlite3Driver) Features() DriverFeatures { - return DriverFeatures{ - SupportNullable: false, - } -} diff --git a/engine.go b/engine.go index 35104b04..20c07e13 100644 --- a/engine.go +++ b/engine.go @@ -551,9 +551,6 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch sess := engine.NewSession() defer sess.Close() for rows.Next() { - if rows.Err() != nil { - return rows.Err() - } _, err = io.WriteString(w, "INSERT INTO "+dstDialect.Quoter().Quote(dstTableName)+" ("+destColNames+") VALUES (") if err != nil { return err @@ -610,6 +607,9 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch return err } } + if rows.Err() != nil { + return rows.Err() + } // FIXME: Hack for postgres if dstDialect.URI().DBType == schemas.POSTGRES && table.AutoIncrColumn() != nil { diff --git a/go.mod b/go.mod index dbc59e76..1b3baf0c 100644 --- a/go.mod +++ b/go.mod @@ -3,16 +3,17 @@ module xorm.io/xorm go 1.13 require ( - github.com/denisenkom/go-mssqldb v0.9.0 - github.com/go-sql-driver/mysql v1.5.0 - github.com/json-iterator/go v1.1.11 - github.com/lib/pq v1.7.0 - github.com/mattn/go-sqlite3 v1.14.6 + github.com/denisenkom/go-mssqldb v0.10.0 + github.com/go-sql-driver/mysql v1.6.0 github.com/goccy/go-json v0.7.4 + github.com/json-iterator/go v1.1.11 + github.com/lib/pq v1.10.2 + github.com/mattn/go-sqlite3 v1.14.8 github.com/shopspring/decimal v1.2.0 - github.com/stretchr/testify v1.4.0 + github.com/stretchr/testify v1.7.0 github.com/syndtr/goleveldb v1.0.0 github.com/ziutek/mymysql v1.5.4 - modernc.org/sqlite v1.10.1-0.20210314190707-798bbeb9bb84 - xorm.io/builder v0.3.8 + gopkg.in/yaml.v2 v2.2.2 // indirect + modernc.org/sqlite v1.11.2 + xorm.io/builder v0.3.9 ) diff --git a/go.sum b/go.sum index da88d67a..3d4b72a6 100644 --- a/go.sum +++ b/go.sum @@ -5,12 +5,18 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/denisenkom/go-mssqldb v0.9.0 h1:RSohk2RsiZqLZ0zCjtfn3S4Gp4exhpBWHyQ7D0yGjAk= github.com/denisenkom/go-mssqldb v0.9.0/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU= +github.com/denisenkom/go-mssqldb v0.10.0 h1:QykgLZBorFE95+gO3u9esLd0BmbvpWp0/waNNZfHBM8= +github.com/denisenkom/go-mssqldb v0.10.0/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU= github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs= github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= +github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE= +github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= +github.com/goccy/go-json v0.7.4 h1:B44qRUFwz/vxPKPISQ1KhvzRi9kZ28RAf6YtjriBZ5k= +github.com/goccy/go-json v0.7.4/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe h1:lXe2qZdvpiX5WZkZR4hgp4KJVfY3nMkvmwbVkpv1rVY= github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= github.com/golang/protobuf v1.2.0 h1:P3YflyNX/ehuJFLhxviNdFxQPkGK5cDcApsge1SqnvM= @@ -28,12 +34,14 @@ github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNU github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8= github.com/lib/pq v1.7.0 h1:h93mCPfUSkaul3Ka/VG8uZdmW1uMHDGxzu0NWHuJmHY= github.com/lib/pq v1.7.0/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/lib/pq v1.10.2 h1:AqzbZs4ZoCBp+GtejcpCpcxM3zlSMx29dXbUSeVtJb8= +github.com/lib/pq v1.10.2/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHXY= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/mattn/go-sqlite3 v1.14.6 h1:dNPt6NO46WmLVt2DLNpwczCmdV5boIZ6g/tlDrlRUbg= github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= -github.com/goccy/go-json v0.7.4 h1:B44qRUFwz/vxPKPISQ1KhvzRi9kZ28RAf6YtjriBZ5k= -github.com/goccy/go-json v0.7.4/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= +github.com/mattn/go-sqlite3 v1.14.8 h1:gDp86IdQsN/xWjIEmr9MF6o9mpksUgh0fu+9ByFxzIU= +github.com/mattn/go-sqlite3 v1.14.8/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OHLH3mGKHDcjJRFFRrJa6eAM5H+CtDdOsPc= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742 h1:Esafd1046DLDQ0W1YjYsBW+p8U2u7vzgW2SQVmlNazg= @@ -49,10 +57,13 @@ github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0 h1:OdAsTTz6O github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/shopspring/decimal v1.2.0 h1:abSATXmQEYyShuxI4/vyW3tV1MrKAJzCZ/0zLUXYbsQ= github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= +github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/syndtr/goleveldb v1.0.0 h1:fBdIW9lB4Iz0n9khmH8w27SJ3QEJ7+IgjPEwGSZiFdE= github.com/syndtr/goleveldb v1.0.0/go.mod h1:ZVVdQEZoIme9iO1Ch2Jdy24qqXrMMOU6lpPAyBWyWuQ= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= @@ -103,28 +114,46 @@ gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWD gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +lukechampine.com/uint128 v1.1.1 h1:pnxCASz787iMf+02ssImqk6OLt+Z5QHMoZyUXR4z6JU= +lukechampine.com/uint128 v1.1.1/go.mod h1:c4eWIwlEGaxC/+H1VguhU4PHXNWDCDMUlWdIWl2j1gk= modernc.org/cc/v3 v3.31.5-0.20210308123301-7a3e9dab9009 h1:u0oCo5b9wyLr++HF3AN9JicGhkUxJhMz51+8TIZH9N0= modernc.org/cc/v3 v3.31.5-0.20210308123301-7a3e9dab9009/go.mod h1:0R6jl1aZlIl2avnYfbfHBS1QB6/f+16mihBObaBC878= +modernc.org/cc/v3 v3.33.6 h1:r63dgSzVzRxUpAJFPQWHy1QeZeY1ydNENUDaBx1GqYc= +modernc.org/cc/v3 v3.33.6/go.mod h1:iPJg1pkwXqAV16SNgFBVYmggfMg6xhs+2oiO0vclK3g= modernc.org/ccgo/v3 v3.9.0 h1:JbcEIqjw4Agf+0g3Tc85YvfYqkkFOv6xBwS4zkfqSoA= modernc.org/ccgo/v3 v3.9.0/go.mod h1:nQbgkn8mwzPdp4mm6BT6+p85ugQ7FrGgIcYaE7nSrpY= +modernc.org/ccgo/v3 v3.9.5 h1:dEuUSf8WN51rDkprFuAqjfchKEzN0WttP/Py3enBwjk= +modernc.org/ccgo/v3 v3.9.5/go.mod h1:umuo2EP2oDSBnD3ckjaVUXMrmeAw8C8OSICVa0iFf60= modernc.org/httpfs v1.0.6 h1:AAgIpFZRXuYnkjftxTAZwMIiwEqAfk8aVB2/oA6nAeM= modernc.org/httpfs v1.0.6/go.mod h1:7dosgurJGp0sPaRanU53W4xZYKh14wfzX420oZADeHM= modernc.org/libc v1.7.13-0.20210308123627-12f642a52bb8/go.mod h1:U1eq8YWr/Kc1RWCMFUWEdkTg8OTcfLw2kY8EDwl039w= modernc.org/libc v1.8.0 h1:Pp4uv9g0csgBMpGPABKtkieF6O5MGhfGo6ZiOdlYfR8= modernc.org/libc v1.8.0/go.mod h1:U1eq8YWr/Kc1RWCMFUWEdkTg8OTcfLw2kY8EDwl039w= +modernc.org/libc v1.9.8/go.mod h1:U1eq8YWr/Kc1RWCMFUWEdkTg8OTcfLw2kY8EDwl039w= +modernc.org/libc v1.9.11 h1:QUxZMs48Ahg2F7SN41aERvMfGLY2HU/ADnB9DC4Yts8= +modernc.org/libc v1.9.11/go.mod h1:NyF3tsA5ArIjJ83XB0JlqhjTabTCHm9aX4XMPHyQn0Q= modernc.org/mathutil v1.1.1/go.mod h1:mZW8CKdRPY1v87qxC/wUdX5O1qDzXMP5TH3wjfpga6E= modernc.org/mathutil v1.2.2 h1:+yFk8hBprV+4c0U9GjFtL+dV3N8hOJ8JCituQcMShFY= modernc.org/mathutil v1.2.2/go.mod h1:mZW8CKdRPY1v87qxC/wUdX5O1qDzXMP5TH3wjfpga6E= +modernc.org/mathutil v1.4.0 h1:GCjoRaBew8ECCKINQA2nYjzvufFW9YiEuuB+rQ9bn2E= +modernc.org/mathutil v1.4.0/go.mod h1:mZW8CKdRPY1v87qxC/wUdX5O1qDzXMP5TH3wjfpga6E= modernc.org/memory v1.0.4 h1:utMBrFcpnQDdNsmM6asmyH/FM9TqLPS7XF7otpJmrwM= modernc.org/memory v1.0.4/go.mod h1:nV2OApxradM3/OVbs2/0OsP6nPfakXpi50C7dcoHXlc= modernc.org/opt v0.1.1 h1:/0RX92k9vwVeDXj+Xn23DKp2VJubL7k8qNffND6qn3A= modernc.org/opt v0.1.1/go.mod h1:WdSiB5evDcignE70guQKxYUl14mgWtbClRi5wmkkTX0= modernc.org/sqlite v1.10.1-0.20210314190707-798bbeb9bb84 h1:rgEUzE849tFlHSoeCrKyS9cZAljC+DY7MdMHKq6R6sY= modernc.org/sqlite v1.10.1-0.20210314190707-798bbeb9bb84/go.mod h1:PGzq6qlhyYjL6uVbSgS6WoF7ZopTW/sI7+7p+mb4ZVU= +modernc.org/sqlite v1.11.2 h1:ShWQpeD3ag/bmx6TqidBlIWonWmQaSQKls3aenCbt+w= +modernc.org/sqlite v1.11.2/go.mod h1:+mhs/P1ONd+6G7hcAs6irwDi/bjTQ7nLW6LHRBsEa3A= modernc.org/strutil v1.1.0 h1:+1/yCzZxY2pZwwrsbH+4T7BQMoLQ9QiBshRC9eicYsc= modernc.org/strutil v1.1.0/go.mod h1:lstksw84oURvj9y3tn8lGvRxyRC1S2+g5uuIzNfIOBs= +modernc.org/strutil v1.1.1 h1:xv+J1BXY3Opl2ALrBwyfEikFAj8pmqcpnfmuwUwcozs= +modernc.org/strutil v1.1.1/go.mod h1:DE+MQQ/hjKBZS2zNInV5hhcipt5rLPWkmpbGeW5mmdw= modernc.org/tcl v1.5.0 h1:euZSUNfE0Fd4W8VqXI1Ly1v7fqDJoBuAV88Ea+SnaSs= modernc.org/tcl v1.5.0/go.mod h1:gb57hj4pO8fRrK54zveIfFXBaMHK3SKJNWcmRw1cRzc= +modernc.org/tcl v1.5.5/go.mod h1:ADkaTUuwukkrlhqwERyq0SM8OvyXo7+TjFz7yAF56EI= modernc.org/token v1.0.0 h1:a0jaWiNMDhDUtqOj09wvjWWAqd3q7WpBulmL9H2egsk= modernc.org/token v1.0.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= modernc.org/z v1.0.1-0.20210308123920-1f282aa71362/go.mod h1:8/SRk5C/HgiQWCgXdfpb+1RvhORdkz5sw72d3jjtyqA= @@ -132,3 +161,5 @@ modernc.org/z v1.0.1 h1:WyIDpEpAIx4Hel6q/Pcgj/VhaQV5XPJ2I6ryIYbjnpc= modernc.org/z v1.0.1/go.mod h1:8/SRk5C/HgiQWCgXdfpb+1RvhORdkz5sw72d3jjtyqA= xorm.io/builder v0.3.8 h1:P/wPgRqa9kX5uE0aA1/ukJ23u9KH0aSRpHLwDKXigSE= xorm.io/builder v0.3.8/go.mod h1:aUW0S9eb9VCaPohFCH3j7czOx1PMW3i1HrSzbLYGBSE= +xorm.io/builder v0.3.9 h1:Sd65/LdWyO7LR8+Cbd+e7mm3sK/7U9k0jS3999IDHMc= +xorm.io/builder v0.3.9/go.mod h1:aUW0S9eb9VCaPohFCH3j7czOx1PMW3i1HrSzbLYGBSE= diff --git a/integrations/engine_test.go b/integrations/engine_test.go index a594ee46..b5ecb2c2 100644 --- a/integrations/engine_test.go +++ b/integrations/engine_test.go @@ -172,7 +172,6 @@ func TestDumpTables(t *testing.T) { name := fmt.Sprintf("dump_%v-table.sql", tp) t.Run(name, func(t *testing.T) { assert.NoError(t, testEngine.(*xorm.Engine).DumpTablesToFile([]*schemas.Table{tb}, name, tp)) - }) } diff --git a/integrations/session_get_test.go b/integrations/session_get_test.go index b1dffe14..d3ce2a11 100644 --- a/integrations/session_get_test.go +++ b/integrations/session_get_test.go @@ -818,8 +818,9 @@ func TestGetBigFloat(t *testing.T) { } type GetBigFloat2 struct { - Id int64 - Money *big.Float `xorm:"decimal(22,2)"` + Id int64 + Money *big.Float `xorm:"decimal(22,2)"` + Money2 big.Float `xorm:"decimal(22,2)"` } assert.NoError(t, PrepareEngine()) @@ -827,7 +828,8 @@ func TestGetBigFloat(t *testing.T) { { var gf2 = GetBigFloat2{ - Money: big.NewFloat(9999999.99), + Money: big.NewFloat(9999999.99), + Money2: *big.NewFloat(99.99), } _, err := testEngine.Insert(&gf2) assert.NoError(t, err) @@ -845,12 +847,14 @@ func TestGetBigFloat(t *testing.T) { assert.NoError(t, err) assert.True(t, has) assert.True(t, gf3.Money.String() == gf2.Money.String(), "%v != %v", gf3.Money.String(), gf2.Money.String()) + assert.True(t, gf3.Money2.String() == gf2.Money2.String(), "%v != %v", gf3.Money2.String(), gf2.Money2.String()) var gfs []GetBigFloat2 err = testEngine.Find(&gfs) assert.NoError(t, err) assert.EqualValues(t, 1, len(gfs)) assert.True(t, gfs[0].Money.String() == gf2.Money.String(), "%v != %v", gfs[0].Money.String(), gf2.Money.String()) + assert.True(t, gfs[0].Money2.String() == gf2.Money2.String(), "%v != %v", gfs[0].Money2.String(), gf2.Money2.String()) } } diff --git a/integrations/session_insert_test.go b/integrations/session_insert_test.go index a023ab72..ce52d3c4 100644 --- a/integrations/session_insert_test.go +++ b/integrations/session_insert_test.go @@ -202,7 +202,7 @@ func TestInsertDefault2(t *testing.T) { Id int64 Name string Url string `xorm:"text"` - CheckTime time.Time `xorm:"not null default '2000-01-01 00:00:00' TIMESTAMP"` + CheckTime time.Time `xorm:"not null default '2000-01-01 00:00:00'"` } di := new(DefaultInsert2) diff --git a/integrations/session_update_test.go b/integrations/session_update_test.go index 22808d60..cc1042b6 100644 --- a/integrations/session_update_test.go +++ b/integrations/session_update_test.go @@ -1313,7 +1313,6 @@ func TestUpdateIgnoreOnlyFromDBFields(t *testing.T) { assert.EqualValues(t, true, has) assert.EqualValues(t, "", record.OnlyFromDBField) return &record - } assert.NoError(t, PrepareEngine()) assertSync(t, new(TestOnlyFromDBField)) diff --git a/internal/statements/statement.go b/internal/statements/statement.go index bfe9987f..0e245a96 100644 --- a/internal/statements/statement.go +++ b/internal/statements/statement.go @@ -8,6 +8,7 @@ import ( "database/sql/driver" "errors" "fmt" + "math/big" "reflect" "strings" "time" @@ -662,10 +663,6 @@ func (statement *Statement) GenIndexSQL() []string { return sqls } -func uniqueName(tableName, uqeName string) string { - return fmt.Sprintf("UQE_%v_%v", tableName, uqeName) -} - // GenUniqueSQL generates unique SQL func (statement *Statement) GenUniqueSQL() []string { var sqls []string @@ -693,6 +690,138 @@ func (statement *Statement) GenDelIndexSQL() []string { return sqls } +func (statement *Statement) asDBCond(fieldValue reflect.Value, fieldType reflect.Type, col *schemas.Column, allUseBool, requiredField bool) (interface{}, bool, error) { + switch fieldType.Kind() { + case reflect.Ptr: + if fieldValue.IsNil() { + return nil, true, nil + } + return statement.asDBCond(fieldValue.Elem(), fieldType.Elem(), col, allUseBool, requiredField) + case reflect.Bool: + if allUseBool || requiredField { + return fieldValue.Interface(), true, nil + } + // if a bool in a struct, it will not be as a condition because it default is false, + // please use Where() instead + return nil, false, nil + case reflect.String: + if !requiredField && fieldValue.String() == "" { + return nil, false, nil + } + // for MyString, should convert to string or panic + if fieldType.String() != reflect.String.String() { + return fieldValue.String(), true, nil + } + return fieldValue.Interface(), true, nil + case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64: + if !requiredField && fieldValue.Int() == 0 { + return nil, false, nil + } + return fieldValue.Interface(), true, nil + case reflect.Float32, reflect.Float64: + if !requiredField && fieldValue.Float() == 0.0 { + return nil, false, nil + } + return fieldValue.Interface(), true, nil + case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64: + if !requiredField && fieldValue.Uint() == 0 { + return nil, false, nil + } + return fieldValue.Interface(), true, nil + case reflect.Struct: + if fieldType.ConvertibleTo(schemas.TimeType) { + t := fieldValue.Convert(schemas.TimeType).Interface().(time.Time) + if !requiredField && (t.IsZero() || !fieldValue.IsValid()) { + return nil, false, nil + } + return dialects.FormatColumnTime(statement.dialect, statement.defaultTimeZone, col, t), true, nil + } else if fieldType.ConvertibleTo(schemas.BigFloatType) { + t := fieldValue.Convert(schemas.BigFloatType).Interface().(big.Float) + v := t.String() + if v == "0" { + return nil, false, nil + } + return t.String(), true, nil + } else if _, ok := reflect.New(fieldType).Interface().(convert.Conversion); ok { + return nil, false, nil + } else if valNul, ok := fieldValue.Interface().(driver.Valuer); ok { + val, _ := valNul.Value() + if val == nil && !requiredField { + return nil, false, nil + } + return val, true, nil + } else { + if col.IsJSON { + if col.SQLType.IsText() { + bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface()) + if err != nil { + return nil, false, err + } + return string(bytes), true, nil + } else if col.SQLType.IsBlob() { + var bytes []byte + var err error + bytes, err = json.DefaultJSONHandler.Marshal(fieldValue.Interface()) + if err != nil { + return nil, false, err + } + return bytes, true, nil + } + } else { + table, err := statement.tagParser.ParseWithCache(fieldValue) + if err != nil { + return fieldValue.Interface(), true, nil + } + + if len(table.PrimaryKeys) == 1 { + pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName) + // fix non-int pk issues + //if pkField.Int() != 0 { + if pkField.IsValid() && !utils.IsZero(pkField.Interface()) { + return pkField.Interface(), true, nil + } + return nil, false, nil + } + return nil, false, fmt.Errorf("not supported %v as %v", fieldValue.Interface(), table.PrimaryKeys) + } + } + case reflect.Array: + return nil, false, nil + case reflect.Slice, reflect.Map: + if fieldValue == reflect.Zero(fieldType) { + return nil, false, nil + } + if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 { + return nil, false, nil + } + + if col.SQLType.IsText() { + bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface()) + if err != nil { + return nil, false, err + } + return string(bytes), true, nil + } else if col.SQLType.IsBlob() { + var bytes []byte + var err error + if (fieldType.Kind() == reflect.Array || fieldType.Kind() == reflect.Slice) && + fieldType.Elem().Kind() == reflect.Uint8 { + if fieldValue.Len() > 0 { + return fieldValue.Bytes(), true, nil + } + return nil, false, nil + } + bytes, err = json.DefaultJSONHandler.Marshal(fieldValue.Interface()) + if err != nil { + return nil, false, err + } + return bytes, true, nil + } + return nil, false, nil + } + return fieldValue.Interface(), true, nil +} + func (statement *Statement) buildConds2(table *schemas.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, allUseBool bool, useAllCols bool, unscoped bool, @@ -747,9 +876,7 @@ func (statement *Statement) buildConds2(table *schemas.Table, bean interface{}, continue } - fieldType := reflect.TypeOf(fieldValue.Interface()) requiredField := useAllCols - if b, ok := getFlagForColumn(mustColumnMap, col); ok { if b { requiredField = true @@ -758,6 +885,7 @@ func (statement *Statement) buildConds2(table *schemas.Table, bean interface{}, } } + fieldType := reflect.TypeOf(fieldValue.Interface()) if fieldType.Kind() == reflect.Ptr { if fieldValue.IsNil() { if includeNil { @@ -774,131 +902,12 @@ func (statement *Statement) buildConds2(table *schemas.Table, bean interface{}, } } - var val interface{} - switch fieldType.Kind() { - case reflect.Bool: - if allUseBool || requiredField { - val = fieldValue.Interface() - } else { - // if a bool in a struct, it will not be as a condition because it default is false, - // please use Where() instead - continue - } - case reflect.String: - if !requiredField && fieldValue.String() == "" { - continue - } - // for MyString, should convert to string or panic - if fieldType.String() != reflect.String.String() { - val = fieldValue.String() - } else { - val = fieldValue.Interface() - } - case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64: - if !requiredField && fieldValue.Int() == 0 { - continue - } - val = fieldValue.Interface() - case reflect.Float32, reflect.Float64: - if !requiredField && fieldValue.Float() == 0.0 { - continue - } - val = fieldValue.Interface() - case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64: - if !requiredField && fieldValue.Uint() == 0 { - continue - } - val = fieldValue.Interface() - case reflect.Struct: - if fieldType.ConvertibleTo(schemas.TimeType) { - t := fieldValue.Convert(schemas.TimeType).Interface().(time.Time) - if !requiredField && (t.IsZero() || !fieldValue.IsValid()) { - continue - } - val = dialects.FormatColumnTime(statement.dialect, statement.defaultTimeZone, col, t) - } else if _, ok := reflect.New(fieldType).Interface().(convert.Conversion); ok { - continue - } else if valNul, ok := fieldValue.Interface().(driver.Valuer); ok { - val, _ = valNul.Value() - if val == nil && !requiredField { - continue - } - } else { - if col.IsJSON { - if col.SQLType.IsText() { - bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface()) - if err != nil { - return nil, err - } - val = string(bytes) - } else if col.SQLType.IsBlob() { - var bytes []byte - var err error - bytes, err = json.DefaultJSONHandler.Marshal(fieldValue.Interface()) - if err != nil { - return nil, err - } - val = bytes - } - } else { - table, err := statement.tagParser.ParseWithCache(fieldValue) - if err != nil { - val = fieldValue.Interface() - } else { - if len(table.PrimaryKeys) == 1 { - pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName) - // fix non-int pk issues - //if pkField.Int() != 0 { - if pkField.IsValid() && !utils.IsZero(pkField.Interface()) { - val = pkField.Interface() - } else { - continue - } - } else { - //TODO: how to handler? - return nil, fmt.Errorf("not supported %v as %v", fieldValue.Interface(), table.PrimaryKeys) - } - } - } - } - case reflect.Array: + val, ok, err := statement.asDBCond(fieldValue, fieldType, col, allUseBool, requiredField) + if err != nil { + return nil, err + } + if !ok { continue - case reflect.Slice, reflect.Map: - if fieldValue == reflect.Zero(fieldType) { - continue - } - if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 { - continue - } - - if col.SQLType.IsText() { - bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface()) - if err != nil { - return nil, err - } - val = string(bytes) - } else if col.SQLType.IsBlob() { - var bytes []byte - var err error - if (fieldType.Kind() == reflect.Array || fieldType.Kind() == reflect.Slice) && - fieldType.Elem().Kind() == reflect.Uint8 { - if fieldValue.Len() > 0 { - val = fieldValue.Bytes() - } else { - continue - } - } else { - bytes, err = json.DefaultJSONHandler.Marshal(fieldValue.Interface()) - if err != nil { - return nil, err - } - val = bytes - } - } else { - continue - } - default: - val = fieldValue.Interface() } conds = append(conds, builder.Eq{colName: val}) diff --git a/internal/statements/values.go b/internal/statements/values.go index ee3821e9..c572ead5 100644 --- a/internal/statements/values.go +++ b/internal/statements/values.go @@ -23,7 +23,7 @@ var ( bigFloatType = reflect.TypeOf(big.Float{}) ) -// Value2Interface convert a field value of a struct to interface for puting into database +// Value2Interface convert a field value of a struct to interface for putting into database func (statement *Statement) Value2Interface(col *schemas.Column, fieldValue reflect.Value) (interface{}, error) { if fieldValue.CanAddr() { if fieldConvert, ok := fieldValue.Addr().Interface().(convert.Conversion); ok { diff --git a/internal/utils/strings.go b/internal/utils/strings.go index 86469c0f..159e2876 100644 --- a/internal/utils/strings.go +++ b/internal/utils/strings.go @@ -13,7 +13,7 @@ func IndexNoCase(s, sep string) int { return strings.Index(strings.ToLower(s), strings.ToLower(sep)) } -// SplitNoCase split a string by a seperator with no care of capitalize +// SplitNoCase split a string by a separator with no care of capitalize func SplitNoCase(s, sep string) []string { idx := IndexNoCase(s, sep) if idx < 0 { @@ -22,7 +22,7 @@ func SplitNoCase(s, sep string) []string { return strings.Split(s, s[idx:idx+len(sep)]) } -// SplitNNoCase split n by a seperator with no care of capitalize +// SplitNNoCase split n by a separator with no care of capitalize func SplitNNoCase(s, sep string, n int) []string { idx := IndexNoCase(s, sep) if idx < 0 { diff --git a/names/mapper.go b/names/mapper.go index b0ce8076..69f67171 100644 --- a/names/mapper.go +++ b/names/mapper.go @@ -79,7 +79,7 @@ func (m SameMapper) Table2Obj(t string) string { return t } -// SnakeMapper implements IMapper and provides name transaltion between +// SnakeMapper implements IMapper and provides name translation between // struct and database table type SnakeMapper struct { } diff --git a/rows.go b/rows.go index 5e0a1ffe..8e7cc075 100644 --- a/rows.go +++ b/rows.go @@ -5,7 +5,6 @@ package xorm import ( - "database/sql" "errors" "fmt" "reflect" @@ -17,10 +16,9 @@ import ( // Rows rows wrapper a rows to type Rows struct { - session *Session - rows *core.Rows - beanType reflect.Type - lastError error + session *Session + rows *core.Rows + beanType reflect.Type } func newRows(session *Session, bean interface{}) (*Rows, error) { @@ -62,15 +60,6 @@ func newRows(session *Session, bean interface{}) (*Rows, error) { // !oinume! Add " IS NULL" to WHERE whatever condiBean is given. // See https://gitea.com/xorm/xorm/issues/179 if col := table.DeletedColumn(); col != nil && !session.statement.GetUnscoped() { // tag "deleted" is enabled - var colName = session.engine.Quote(col.Name) - if addedTableName { - var nm = session.statement.TableName() - if len(session.statement.TableAlias) > 0 { - nm = session.statement.TableAlias - } - colName = session.engine.Quote(nm) + "." + colName - } - autoCond = session.statement.CondDeleted(col) } } @@ -86,7 +75,6 @@ func newRows(session *Session, bean interface{}) (*Rows, error) { rows.rows, err = rows.session.queryRows(sqlStr, args...) if err != nil { - rows.lastError = err rows.Close() return nil, err } @@ -96,25 +84,18 @@ func newRows(session *Session, bean interface{}) (*Rows, error) { // Next move cursor to next record, return false if end has reached func (rows *Rows) Next() bool { - if rows.lastError == nil && rows.rows != nil { - hasNext := rows.rows.Next() - if !hasNext { - rows.lastError = sql.ErrNoRows - } - return hasNext - } - return false + return rows.rows.Next() } // Err returns the error, if any, that was encountered during iteration. Err may be called after an explicit or implicit Close. func (rows *Rows) Err() error { - return rows.lastError + return rows.rows.Err() } // Scan row record to bean properties func (rows *Rows) Scan(bean interface{}) error { - if rows.lastError != nil { - return rows.lastError + if rows.Err() != nil { + return rows.Err() } if reflect.Indirect(reflect.ValueOf(bean)).Type() != rows.beanType { @@ -158,5 +139,5 @@ func (rows *Rows) Close() error { return rows.rows.Close() } - return rows.lastError + return rows.Err() } diff --git a/scan.go b/scan.go index 444aa8ac..ccd6938d 100644 --- a/scan.go +++ b/scan.go @@ -211,12 +211,8 @@ func (engine *Engine) scan(rows *core.Rows, fields []string, types []*sql.Column scanResult = &sql.RawBytes{} replaced = true default: - var useNullable = true - if engine.driver.Features().SupportNullable { - nullable, ok := types[0].Nullable() - useNullable = ok && nullable - } - if useNullable { + nullable, ok := types[0].Nullable() + if !ok || nullable { scanResult, replaced, err = genScanResultsByBeanNullable(v) } else { scanResult, replaced, err = genScanResultsByBean(v) @@ -286,15 +282,15 @@ func rows2maps(rows *core.Rows) (resultsSlice []map[string][]byte, err error) { return nil, err } for rows.Next() { - if rows.Err() != nil { - return nil, rows.Err() - } result, err := row2mapBytes(rows, types, fields) if err != nil { return nil, err } resultsSlice = append(resultsSlice, result) } + if rows.Err() != nil { + return nil, rows.Err() + } return resultsSlice, nil } diff --git a/schemas/table_test.go b/schemas/table_test.go index 0e35193f..f352675b 100644 --- a/schemas/table_test.go +++ b/schemas/table_test.go @@ -58,7 +58,6 @@ func TestGetColumnIdx(t *testing.T) { func BenchmarkGetColumnWithToLower(b *testing.B) { for i := 0; i < b.N; i++ { for _, test := range testsGetColumn { - if _, ok := table.columnsMap[strings.ToLower(test.name)]; !ok { b.Errorf("Column not found:%s", test.name) } @@ -69,7 +68,6 @@ func BenchmarkGetColumnWithToLower(b *testing.B) { func BenchmarkGetColumnIdxWithToLower(b *testing.B) { for i := 0; i < b.N; i++ { for _, test := range testsGetColumn { - if c, ok := table.columnsMap[strings.ToLower(test.name)]; ok { if test.idx < len(c) { continue diff --git a/schemas/type.go b/schemas/type.go index 62e66c2e..d64251bf 100644 --- a/schemas/type.go +++ b/schemas/type.go @@ -65,6 +65,7 @@ func (s *SQLType) IsTime() bool { return s.IsType(TIME_TYPE) } +// IsBool returns true if column is a boolean type func (s *SQLType) IsBool() bool { return s.IsType(BOOL_TYPE) } diff --git a/session.go b/session.go index 8c1d8c3b..62d6a770 100644 --- a/session.go +++ b/session.go @@ -391,9 +391,6 @@ func (session *Session) rows2Beans(rows *core.Rows, fields []string, types []*sq table *schemas.Table, newElemFunc func([]string) reflect.Value, sliceValueSetFunc func(*reflect.Value, schemas.PK) error) error { for rows.Next() { - if rows.Err() != nil { - return rows.Err() - } var newValue = newElemFunc(fields) bean := newValue.Interface() dataStruct := newValue.Elem() @@ -415,7 +412,7 @@ func (session *Session) rows2Beans(rows *core.Rows, fields []string, types []*sq bean: bean, }) } - return nil + return rows.Err() } func (session *Session) row2Slice(rows *core.Rows, fields []string, types []*sql.ColumnType, bean interface{}) ([]interface{}, error) { diff --git a/session_exist.go b/session_exist.go index e52c618e..b5e4a655 100644 --- a/session_exist.go +++ b/session_exist.go @@ -25,5 +25,8 @@ func (session *Session) Exist(bean ...interface{}) (bool, error) { } defer rows.Close() - return rows.Next(), nil + if rows.Next() { + return true, nil + } + return false, rows.Err() } diff --git a/session_find.go b/session_find.go index 89e34e80..010ecd6c 100644 --- a/session_find.go +++ b/session_find.go @@ -255,9 +255,6 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect } for rows.Next() { - if rows.Err() != nil { - return rows.Err() - } var newValue = newElemFunc(fields) bean := newValue.Interface() @@ -278,7 +275,7 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect return err } } - return nil + return rows.Err() } func convertPKToValue(table *schemas.Table, dst interface{}, pk schemas.PK) error { @@ -325,9 +322,6 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in var i int ids = make([]schemas.PK, 0) for rows.Next() { - if rows.Err() != nil { - return rows.Err() - } i++ if i > 500 { session.engine.logger.Debugf("[cacheFind] ids length > 500, no cache") @@ -348,6 +342,9 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in ids = append(ids, pk) } + if rows.Err() != nil { + return rows.Err() + } session.engine.logger.Debugf("[cache] cache sql: %v, %v, %v, %v, %v", ids, tableName, sqlStr, newsql, args) err = caches.PutCacheSql(cacher, ids, tableName, newsql, args) diff --git a/session_get.go b/session_get.go index 1062bd9d..08172524 100644 --- a/session_get.go +++ b/session_get.go @@ -159,10 +159,7 @@ func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table, defer rows.Close() if !rows.Next() { - if rows.Err() != nil { - return false, rows.Err() - } - return false, nil + return false, rows.Err() } // WARN: Alougth rows return true, but we may also return error. @@ -313,14 +310,14 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf defer rows.Close() if rows.Next() { - if rows.Err() != nil { - return true, rows.Err() - } err = rows.ScanSlice(&res) if err != nil { return true, err } } else { + if rows.Err() != nil { + return false, rows.Err() + } return false, ErrCacheFailed } diff --git a/session_insert.go b/session_insert.go index b41dbbac..a9b8b7d2 100644 --- a/session_insert.go +++ b/session_insert.go @@ -325,7 +325,6 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { copy(afterClosures, session.afterClosures) session.afterInsertBeans[bean] = &afterClosures } - } else { if _, ok := interface{}(bean).(AfterInsertProcessor); ok { session.afterInsertBeans[bean] = nil diff --git a/session_iterate.go b/session_iterate.go index dbbeb3f4..f6301009 100644 --- a/session_iterate.go +++ b/session_iterate.go @@ -43,9 +43,6 @@ func (session *Session) Iterate(bean interface{}, fun IterFunc) error { i := 0 for rows.Next() { - if rows.Err() != nil { - return rows.Err() - } b := reflect.New(rows.beanType).Interface() err = rows.Scan(b) if err != nil { @@ -57,7 +54,7 @@ func (session *Session) Iterate(bean interface{}, fun IterFunc) error { } i++ } - return err + return rows.Err() } // BufferSize sets the buffersize for iterate diff --git a/session_query.go b/session_query.go index 8543ba12..a4070985 100644 --- a/session_query.go +++ b/session_query.go @@ -33,15 +33,15 @@ func (session *Session) rows2Strings(rows *core.Rows) (resultsSlice []map[string } for rows.Next() { - if rows.Err() != nil { - return nil, rows.Err() - } result, err := session.engine.row2mapStr(rows, types, fields) if err != nil { return nil, err } resultsSlice = append(resultsSlice, result) } + if rows.Err() != nil { + return nil, rows.Err() + } return resultsSlice, nil } @@ -57,15 +57,15 @@ func (session *Session) rows2SliceString(rows *core.Rows) (resultsSlice [][]stri } for rows.Next() { - if rows.Err() != nil { - return nil, rows.Err() - } record, err := session.engine.row2sliceStr(rows, types, fields) if err != nil { return nil, err } resultsSlice = append(resultsSlice, record) } + if rows.Err() != nil { + return nil, rows.Err() + } return resultsSlice, nil } @@ -120,15 +120,15 @@ func (session *Session) rows2Interfaces(rows *core.Rows) (resultsSlice []map[str return nil, err } for rows.Next() { - if rows.Err() != nil { - return nil, rows.Err() - } result, err := session.engine.row2mapInterface(rows, types, fields) if err != nil { return nil, err } resultsSlice = append(resultsSlice, result) } + if rows.Err() != nil { + return nil, rows.Err() + } return resultsSlice, nil } diff --git a/session_update.go b/session_update.go index 32e28ae0..4f8e6961 100644 --- a/session_update.go +++ b/session_update.go @@ -59,9 +59,6 @@ func (session *Session) cacheUpdate(table *schemas.Table, tableName, sqlStr stri ids = make([]schemas.PK, 0) for rows.Next() { - if rows.Err() != nil { - return rows.Err() - } var res = make([]string, len(table.PrimaryKeys)) err = rows.ScanSlice(&res) if err != nil { @@ -84,6 +81,9 @@ func (session *Session) cacheUpdate(table *schemas.Table, tableName, sqlStr stri ids = append(ids, pk) } + if rows.Err() != nil { + return rows.Err() + } session.engine.logger.Debugf("[cache] find updated id: %v", ids) } /*else { session.engine.LogDebug("[xorm:cacheUpdate] del cached sql:", tableName, newsql, args) diff --git a/tags/parser.go b/tags/parser.go index b793a8f1..72baa153 100644 --- a/tags/parser.go +++ b/tags/parser.go @@ -124,6 +124,7 @@ func addIndex(indexName string, table *schemas.Table, col *schemas.Column, index } } +// ErrIgnoreField represents an error to ignore field var ErrIgnoreField = errors.New("field will be ignored") func (parser *Parser) parseFieldWithNoTag(fieldIndex int, field reflect.StructField, fieldValue reflect.Value) (*schemas.Column, error) {