diff --git a/.drone.yml b/.drone.yml index 300c7841..2bad4b5a 100644 --- a/.drone.yml +++ b/.drone.yml @@ -1,59 +1,109 @@ --- kind: pipeline -name: testing +name: test-mysql +environment: + GO111MODULE: "on" + GOPROXY: "https://goproxy.io" + CGO_ENABLED: 1 +trigger: + ref: + - refs/heads/master + - refs/pull/*/head steps: - name: test-vet - image: golang:1.11 # The lowest golang requirement - environment: - GO111MODULE: "on" - GOPROXY: "https://goproxy.cn" + image: golang:1.17 + pull: always + volumes: + - name: cache + path: /go/pkg/mod commands: - make vet - - make test - - make fmt-check - when: - event: - - push - - pull_request - -- name: test-sqlite - image: golang:1.12 - environment: - GO111MODULE: "on" - GOPROXY: "https://goproxy.cn" +- name: test-sqlite3 + image: golang:1.17 + volumes: + - name: cache + path: /go/pkg/mod + depends_on: + - test-vet commands: - - make test-sqlite - - TEST_CACHE_ENABLE=true make test-sqlite - - TEST_QUOTE_POLICY=reserved make test-sqlite - when: - event: - - push - - pull_request - + - make fmt-check + - make test + - make test-sqlite3 + - TEST_CACHE_ENABLE=true make test-sqlite3 +- name: test-sqlite + image: golang:1.17 + volumes: + - name: cache + path: /go/pkg/mod + depends_on: + - test-vet + commands: + - make test-sqlite + - TEST_QUOTE_POLICY=reserved make test-sqlite - name: test-mysql - image: golang:1.12 + image: golang:1.17 + pull: never + volumes: + - name: cache + path: /go/pkg/mod + depends_on: + - test-vet environment: - GO111MODULE: "on" - GOPROXY: "https://goproxy.cn" TEST_MYSQL_HOST: mysql TEST_MYSQL_CHARSET: utf8 TEST_MYSQL_DBNAME: xorm_test TEST_MYSQL_USERNAME: root TEST_MYSQL_PASSWORD: commands: - - make test-mysql - TEST_CACHE_ENABLE=true make test-mysql - - TEST_QUOTE_POLICY=reserved make test-mysql - when: - event: - - push - - pull_request - -- name: test-mysql8 - image: golang:1.12 + +- name: test-mysql-utf8mb4 + image: golang:1.17 + pull: never + volumes: + - name: cache + path: /go/pkg/mod + depends_on: + - test-mysql + environment: + TEST_MYSQL_HOST: mysql + TEST_MYSQL_CHARSET: utf8mb4 + TEST_MYSQL_DBNAME: xorm_test + TEST_MYSQL_USERNAME: root + TEST_MYSQL_PASSWORD: + commands: + - make test-mysql + - TEST_QUOTE_POLICY=reserved make test-mysql-tls + +volumes: +- name: cache + host: + path: /tmp/cache + +services: +- name: mysql + image: mysql:5.7 + environment: + MYSQL_ALLOW_EMPTY_PASSWORD: yes + MYSQL_DATABASE: xorm_test + +--- +kind: pipeline +name: test-mysql8 +depends_on: + - test-mysql +trigger: + ref: + - refs/heads/master + - refs/pull/*/head +steps: +- name: test-mysql8 + image: golang:1.17 + pull: never + volumes: + - name: cache + path: /go/pkg/mod environment: - GO111MODULE: "on" - GOPROXY: "https://goproxy.cn" TEST_MYSQL_HOST: mysql8 TEST_MYSQL_CHARSET: utf8mb4 TEST_MYSQL_DBNAME: xorm_test @@ -62,59 +112,36 @@ steps: commands: - make test-mysql - TEST_CACHE_ENABLE=true make test-mysql - - TEST_QUOTE_POLICY=reserved make test-mysql - when: - event: - - push - - pull_request -- name: test-mysql-utf8mb4 - image: golang:1.12 - depends_on: - - test-mysql +volumes: +- name: cache + host: + path: /tmp/cache + +services: +- name: mysql8 + image: mysql:8.0 environment: - GO111MODULE: "on" - GOPROXY: "https://goproxy.cn" - TEST_MYSQL_HOST: mysql - TEST_MYSQL_CHARSET: utf8mb4 - TEST_MYSQL_DBNAME: xorm_test - TEST_MYSQL_USERNAME: root - TEST_MYSQL_PASSWORD: - commands: - - make test-mysql - - TEST_CACHE_ENABLE=true make test-mysql - - TEST_QUOTE_POLICY=reserved make test-mysql - when: - event: - - push - - pull_request - -- name: test-mymysql - pull: default - image: golang:1.12 - depends_on: - - test-mysql-utf8mb4 - environment: - GO111MODULE: "on" - GOPROXY: "https://goproxy.cn" - TEST_MYSQL_HOST: mysql:3306 - TEST_MYSQL_DBNAME: xorm_test - TEST_MYSQL_USERNAME: root - TEST_MYSQL_PASSWORD: - commands: - - make test-mymysql - - TEST_CACHE_ENABLE=true make test-mymysql - - TEST_QUOTE_POLICY=reserved make test-mymysql - when: - event: - - push - - pull_request + MYSQL_ALLOW_EMPTY_PASSWORD: yes + MYSQL_DATABASE: xorm_test +--- +kind: pipeline +name: test-mariadb +depends_on: + - test-mysql8 +trigger: + ref: + - refs/heads/master + - refs/pull/*/head +steps: - name: test-mariadb - image: golang:1.12 + image: golang:1.17 + pull: never + volumes: + - name: cache + path: /go/pkg/mod environment: - GO111MODULE: "on" - GOPROXY: "https://goproxy.cn" TEST_MYSQL_HOST: mariadb TEST_MYSQL_CHARSET: utf8mb4 TEST_MYSQL_DBNAME: xorm_test @@ -122,19 +149,37 @@ steps: TEST_MYSQL_PASSWORD: commands: - make test-mysql - - TEST_CACHE_ENABLE=true make test-mysql - TEST_QUOTE_POLICY=reserved make test-mysql - when: - event: - - push - - pull_request -- name: test-postgres - pull: default - image: golang:1.12 +volumes: +- name: cache + host: + path: /tmp/cache + +services: +- name: mariadb + image: mariadb:10.4 + environment: + MYSQL_ALLOW_EMPTY_PASSWORD: yes + MYSQL_DATABASE: xorm_test + +--- +kind: pipeline +name: test-postgres +depends_on: + - test-mariadb +trigger: + ref: + - refs/heads/master + - refs/pull/*/head +steps: +- name: test-postgres + pull: never + image: golang:1.17 + volumes: + - name: cache + path: /go/pkg/mod environment: - GO111MODULE: "on" - GOPROXY: "https://goproxy.cn" TEST_PGSQL_HOST: pgsql TEST_PGSQL_DBNAME: xorm_test TEST_PGSQL_USERNAME: postgres @@ -142,79 +187,163 @@ steps: commands: - make test-postgres - TEST_CACHE_ENABLE=true make test-postgres - - TEST_QUOTE_POLICY=reserved make test-postgres - when: - event: - - push - - pull_request - name: test-postgres-schema - pull: default - image: golang:1.12 + pull: never + image: golang:1.17 + volumes: + - name: cache + path: /go/pkg/mod depends_on: - test-postgres environment: - GO111MODULE: "on" - GOPROXY: "https://goproxy.cn" TEST_PGSQL_HOST: pgsql TEST_PGSQL_SCHEMA: xorm TEST_PGSQL_DBNAME: xorm_test TEST_PGSQL_USERNAME: postgres TEST_PGSQL_PASSWORD: postgres commands: - - make test-postgres - - TEST_CACHE_ENABLE=true make test-postgres - TEST_QUOTE_POLICY=reserved make test-postgres - when: - event: - - push - - pull_request -- name: test-mssql - pull: default - image: golang:1.12 +- name: test-pgx + pull: never + image: golang:1.17 + volumes: + - name: cache + path: /go/pkg/mod + depends_on: + - test-postgres-schema + environment: + TEST_PGSQL_HOST: pgsql + TEST_PGSQL_DBNAME: xorm_test + TEST_PGSQL_USERNAME: postgres + TEST_PGSQL_PASSWORD: postgres + commands: + - make test-pgx + - TEST_CACHE_ENABLE=true make test-pgx + - TEST_QUOTE_POLICY=reserved make test-pgx + +- name: test-pgx-schema + pull: never + image: golang:1.17 + volumes: + - name: cache + path: /go/pkg/mod + depends_on: + - test-pgx + environment: + TEST_PGSQL_HOST: pgsql + TEST_PGSQL_SCHEMA: xorm + TEST_PGSQL_DBNAME: xorm_test + TEST_PGSQL_USERNAME: postgres + TEST_PGSQL_PASSWORD: postgres + commands: + - make test-pgx + - TEST_CACHE_ENABLE=true make test-pgx + - TEST_QUOTE_POLICY=reserved make test-pgx + +volumes: +- name: cache + host: + path: /tmp/cache + +services: +- name: pgsql + image: postgres:9.5 + environment: + POSTGRES_DB: xorm_test + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + +--- +kind: pipeline +name: test-mssql +depends_on: + - test-postgres +trigger: + ref: + - refs/heads/master + - refs/pull/*/head +steps: +- name: test-mssql + pull: never + image: golang:1.17 + volumes: + - name: cache + path: /go/pkg/mod environment: - GO111MODULE: "on" - GOPROXY: "https://goproxy.cn" TEST_MSSQL_HOST: mssql TEST_MSSQL_DBNAME: xorm_test TEST_MSSQL_USERNAME: sa TEST_MSSQL_PASSWORD: "yourStrong(!)Password" commands: - make test-mssql - - TEST_CACHE_ENABLE=true make test-mssql - - TEST_QUOTE_POLICY=reserved make test-mssql - TEST_MSSQL_DEFAULT_VARCHAR=NVARCHAR TEST_MSSQL_DEFAULT_CHAR=NCHAR make test-mssql - when: - event: - - push - - pull_request -- name: test-tidb - pull: default - image: golang:1.12 +volumes: +- name: cache + host: + path: /tmp/cache + +services: +- name: mssql + pull: always + image: mcr.microsoft.com/mssql/server:latest + environment: + ACCEPT_EULA: Y + SA_PASSWORD: yourStrong(!)Password + MSSQL_PID: Standard + +--- +kind: pipeline +name: test-tidb +depends_on: + - test-mssql +trigger: + ref: + - refs/heads/master + - refs/pull/*/head +steps: +- name: test-tidb + pull: never + image: golang:1.17 + volumes: + - name: cache + path: /go/pkg/mod environment: - GO111MODULE: "on" - GOPROXY: "https://goproxy.cn" TEST_TIDB_HOST: "tidb:4000" TEST_TIDB_DBNAME: xorm_test TEST_TIDB_USERNAME: root TEST_TIDB_PASSWORD: commands: - make test-tidb - - TEST_CACHE_ENABLE=true make test-tidb - - TEST_QUOTE_POLICY=reserved make test-tidb - when: - event: - - push - - pull_request +volumes: +- name: cache + host: + path: /tmp/cache + +services: +- name: tidb + image: pingcap/tidb:v3.0.3 + +--- +kind: pipeline +name: test-cockroach +depends_on: + - test-tidb +trigger: + ref: + - refs/heads/master + - refs/pull/*/head +steps: - name: test-cockroach - pull: default - image: golang:1.13 + pull: never + image: golang:1.17 + volumes: + - name: cache + path: /go/pkg/mod environment: - GO111MODULE: "on" - GOPROXY: "https://goproxy.cn" TEST_COCKROACH_HOST: "cockroach:26257" TEST_COCKROACH_DBNAME: xorm_test TEST_COCKROACH_USERNAME: root @@ -222,116 +351,87 @@ steps: commands: - sleep 10 - make test-cockroach - - TEST_CACHE_ENABLE=true make test-cockroach - when: - event: - - push - - pull_request -- name: merge_coverage - pull: default - image: golang:1.12 - environment: - GO111MODULE: "on" - GOPROXY: "https://goproxy.cn" - depends_on: - - test-vet - - test-sqlite - - test-mysql - - test-mysql8 - - test-mymysql - - test-postgres - - test-postgres-schema - - test-mssql - - test-tidb - - test-cockroach - commands: - - make coverage - when: - event: - - push - - pull_request +volumes: +- name: cache + host: + path: /tmp/cache services: - -- name: mysql - pull: default - image: mysql:5.7 - environment: - MYSQL_ALLOW_EMPTY_PASSWORD: yes - MYSQL_DATABASE: xorm_test - when: - event: - - push - - tag - - pull_request - -- name: mysql8 - pull: default - image: mysql:8.0 - environment: - MYSQL_ALLOW_EMPTY_PASSWORD: yes - MYSQL_DATABASE: xorm_test - when: - event: - - push - - tag - - pull_request - -- name: mariadb - pull: default - image: mariadb:10.4 - environment: - MYSQL_ALLOW_EMPTY_PASSWORD: yes - MYSQL_DATABASE: xorm_test - when: - event: - - push - - tag - - pull_request - -- name: pgsql - pull: default - image: postgres:9.5 - environment: - POSTGRES_DB: xorm_test - POSTGRES_USER: postgres - POSTGRES_PASSWORD: postgres - when: - event: - - push - - tag - - pull_request - -- name: mssql - pull: default - image: microsoft/mssql-server-linux:latest - environment: - ACCEPT_EULA: Y - SA_PASSWORD: yourStrong(!)Password - MSSQL_PID: Developer - when: - event: - - push - - tag - - pull_request - -- name: tidb - pull: default - image: pingcap/tidb:v3.0.3 - when: - event: - - push - - tag - - pull_request - - name: cockroach - pull: default image: cockroachdb/cockroach:v19.2.4 commands: - /cockroach/cockroach start --insecure - when: - event: - - push + +# --- +# kind: pipeline +# name: test-dameng +# depends_on: +# - test-cockroach +# trigger: +# ref: +# - refs/heads/master +# - refs/pull/*/head +# steps: +# - name: test-dameng +# pull: never +# image: golang:1.17 +# volumes: +# - name: cache +# path: /go/pkg/mod +# environment: +# TEST_DAMENG_HOST: "dameng:5236" +# TEST_DAMENG_USERNAME: SYSDBA +# TEST_DAMENG_PASSWORD: SYSDBA +# commands: +# - sleep 30 +# - make test-dameng + +# volumes: +# - name: cache +# host: +# path: /tmp/cache + +# services: +# - name: dameng +# image: lunny/dm:v1.0 +# commands: +# - /bin/bash /startDm.sh + +--- +kind: pipeline +name: merge_coverage +depends_on: + - test-mysql + - test-mysql8 + - test-mariadb + - test-postgres + - test-mssql + - test-tidb + - test-cockroach + #- test-dameng +trigger: + ref: + - refs/heads/master + - refs/pull/*/head +steps: +- name: merge_coverage + image: golang:1.17 + commands: + - make coverage + +--- +kind: pipeline +name: release-tag +trigger: + event: - tag - - pull_request +steps: +- name: release-tag-gitea + pull: always + image: plugins/gitea-release:latest + settings: + base_url: https://gitea.com + title: '${DRONE_TAG} is released' + api_key: + from_secret: gitea_token \ No newline at end of file diff --git a/.gitignore b/.gitignore index 617d5da7..a183a295 100644 --- a/.gitignore +++ b/.gitignore @@ -36,3 +36,5 @@ test.db.sql *coverage.out test.db integrations/*.sql +integrations/test_sqlite* +cover.out \ No newline at end of file diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 00000000..7b91f22d --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,24 @@ +linters: + enable: + - gosimple + - deadcode + - typecheck + - govet + - errcheck + - staticcheck + - unused + - structcheck + - varcheck + - dupl + #- gocyclo # The cyclomatic complexety of a lot of functions is too high, we should refactor those another time. + - gofmt + - misspell + - gocritic + - bidichk + - ineffassign + enable-all: false + disable-all: true + fast: false + +run: + timeout: 3m \ No newline at end of file diff --git a/.revive.toml b/.revive.toml deleted file mode 100644 index 64e223bb..00000000 --- a/.revive.toml +++ /dev/null @@ -1,25 +0,0 @@ -ignoreGeneratedHeader = false -severity = "warning" -confidence = 0.8 -errorCode = 1 -warningCode = 1 - -[rule.blank-imports] -[rule.context-as-argument] -[rule.context-keys-type] -[rule.dot-imports] -[rule.error-return] -[rule.error-strings] -[rule.error-naming] -[rule.exported] -[rule.if-return] -[rule.increment-decrement] -[rule.var-naming] -[rule.var-declaration] -[rule.package-comments] -[rule.range] -[rule.receiver-naming] -[rule.time-naming] -[rule.unexported-return] -[rule.indent-error-flow] -[rule.errorf] \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index 3cead87d..6887cb97 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,175 @@ This changelog goes through all the changes that have been made in each release without substantial changes to our git log. +## [1.3.2](https://gitea.com/xorm/xorm/releases/tag/1.3.2) - 2022-09-03 + +* BUGFIXES + * Change schemas.Column to use int64 (#2160) +* MISC + * Prevent Sync failure with non-regular indexes on Postgres (#2174) + +## [1.3.1](https://gitea.com/xorm/xorm/releases/tag/1.3.1) - 2022-06-03 + +* BREAKING + * Refactor orderby and support arguments (#2150) + * return a clear error for set TEXT type as compare condition (#2062) +* BUGFIXES + * Fix oid index for postgres (#2154) + * Add ORDER BY SEQ_IN_INDEX to MySQL GetIndexes to Fix IndexTests (#2152) + * some improvement (#2136) +* ENHANCEMENTS + * Add interface to allow structs to provide specific index information (#2137) + * MySQL/MariaDB: return max length for text columns (#2133) + * PostgreSQL: enable comment on column (#2131) +* TESTING + * Add test for find date (#2121) + +## [1.3.0](https://gitea.com/xorm/xorm/releases/tag/1.3.0) - 2022-04-14 + +* BREAKING + * New Prepare useage (#2061) + * Make Get and Rows.Scan accept multiple parameters (#2029) + * Drop sync function and rename sync2 to sync (#2018) +* FEATURES + * Add dameng support (#2007) +* BUGFIXES + * bugfix :Oid It's a special index. You can't put it in (#2105) + * Fix new-lined query execution in master DB node. (#2066) + * Fix bug of Rows (#2048) + * Fix bug (#2046) + * fix panic when `Iterate()` fails (#2040) + * fix panic when convert sql and args with nil time.Time pointer (#2038) +* ENHANCEMENTS + * Fix to add session.statement.IsForUpdate check in Session.queryRows() (#2064) + * Expose ScanString / ScanInterface and etc (#2039) +* TESTING + * Add test for mysql tls (#2049) +* BUILD + * Upgrade dependencies modules (#2078) +* MISC + * Fix oracle keyword AS (#2109) + * Some performance optimization for get (#2043) + +## [1.2.2](https://gitea.com/xorm/xorm/releases/tag/1.2.2) - 2021-08-11 + +* MISC + * Move convert back to xorm.io/xorm/convert (#2030) + +## [1.2.1](https://gitea.com/xorm/xorm/releases/tag/1.2.1) - 2021-08-08 + +* FEATURES + * Add pgx driver support (#1795) +* BUGFIXES + * Fix wrong comment (#2027) + * Fix import file bug (#2025) +* ENHANCEMENTS + * Fix timesatmp (#2021) + +## [1.2.0](https://gitea.com/xorm/xorm/releases/tag/1.2.0) - 2021-08-04 + +* BREAKING + * Exec with time arg now will obey time zone settings on engine (#1989) + * Query interface (#1965) + * Support delete with no bean (#1926) + * Nil ptr is nullable (#1919) +* FEATURES + * Support batch insert map (#2019) + * Support big.Float (#1973) +* BUGFIXES + * fix possible null dereference in internal/statements/query.go (#1988) + * Fix bug on dumptable (#1984) +* ENHANCEMENTS + * Move assign functions to convert package (#2015) + * refactor conversion (#2001) + * refactor some code (#2000) + * refactor insert condition generation (#1998) + * refactor and add setjson function (#1997) + * Get struct and Find support big.Float (#1976) + * refactor slice2Bean (#1974, #1975) + * refactor get (#1967) + * Replace #1044 (#1935) + * Support Get time.Time (#1933) +* TESTING + * Add benchmark tests (#1978) + * Add tests for github.com/shopspring/decimal support (#1977) + * Add test for get map with NULL column (#1948) + * Add test for limit with query (#1787) +* MISC + * Fix DBMetas returned unsigned tinyint (#2017) + * Fix deleted column (#2014) + * Add database alias table and fix wrong warning (#1947) + +## [1.1.2](https://gitea.com/xorm/xorm/releases/tag/1.1.2) - 2021-07-04 + +* BUILD + * Add release tag (#1966) + +## [1.1.1](https://gitea.com/xorm/xorm/releases/tag/1.1.1) - 2021-07-03 + +* BUGFIXES + * Ignore comments when deciding when to replace question marks. #1954 (#1955) + * Fix bug didn't reset statement on update (#1939) + * Fix create table with struct missing columns (#1938) + * Fix #929 (#1936) + * Fix exist (#1921) +* ENHANCEMENTS + * Improve get field value of bean (#1961) + * refactor splitTag function (#1960) + * Fix #1663 (#1952) + * fix pg GetColumns missing comment (#1949) + * Support build flag jsoniter to replace default json (#1916) + * refactor exprParam (#1825) + * Add DBVersion (#1723) +* TESTING + * Add test to confirm #1247 resolved (#1951) + * Add test for dump table with default value (#1950) + * Test for #1486 (#1942) + * Add sync tests to confirm #539 is gone (#1937) + * test for unsigned int32 (#1923) + * Add tests for array store (#1922) +* BUILD + * Remove mymysql from ci (#1928) +* MISC + * fix lint (#1953) + * Compitable with cockroach (#1930) + * Replace goracle with godror (#1914) + +## [1.1.0](https://gitea.com/xorm/xorm/releases/tag/1.1.0) - 2021-05-14 + +* FEATURES + * Unsigned Support for mysql (#1889) + * Support modernc.org/sqlite (#1850) +* TESTING + * More tests (#1890) +* MISC + * Byte strings in postgres aren't 0x... (#1906) + * Fix another bug with #1872 (#1905) + * Fix two issues with dumptables (#1903) + * Fix comments (#1896) + * Fix comments (#1893) + * MariaDB 10.5 adds a suffix on old datatypes (#1885) + +## [1.0.7](https://gitea.com/xorm/xorm/pulls?q=&type=all&state=closed&milestone=1336) - 2021-01-21 + +* BUGFIXES + * Fix bug for mssql (#1854) +* MISC + * fix_bugs_for_mssql (#1852) + +## [1.0.6](https://gitea.com/xorm/xorm/pulls?q=&type=all&state=closed&milestone=1308) - 2021-01-05 + +* BUGFIXES + * Fix bug when modify column on mssql (#1849) + * Fix find and count bug with cols (#1826) + * Fix update bug (#1823) + * Fix json tag with other type (#1822) +* ENHANCEMENTS + * prevent panic when struct with unexport field (#1839) + * Automatically convert datetime to int64 (#1715) +* MISC + * Fix index (#1841) + * Performance improvement for columnsbyName (#1788) + ## [1.0.5](https://gitea.com/xorm/xorm/pulls?q=&type=all&state=closed&milestone=1299) - 2020-09-08 * BUGFIXES diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index a6925a5c..27e6929b 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,13 +1,13 @@ ## Contributing to xorm -`xorm` has a backlog of [pull requests](https://help.github.com/articles/using-pull-requests), but contributions are still very -much welcome. You can help with patch review, submitting bug reports, +`xorm` has a backlog of [pull requests](https://gitea.com/xorm/xorm/pulls), but contributions are still very +much welcome. You can help with patch review, submitting [bug reports](https://gitea.com/xorm/xorm/issues), or adding new functionality. There is no formal style guide, but please conform to the style of existing code and general Go formatting conventions when submitting patches. -* [fork a repo](https://help.github.com/articles/fork-a-repo) -* [creating a pull request ](https://help.github.com/articles/creating-a-pull-request) +* [fork the repo](https://gitea.com/repo/fork/2038) +* [creating a pull request ](https://docs.gitea.io/en-us/pull-request/) ### Language @@ -15,7 +15,7 @@ Since `xorm` is a world-wide open source project, please describe your issues or ### Sign your codes with comments ``` -// !! your comments +// !! your comments e.g., @@ -65,7 +65,7 @@ And if your branch is related with cache, you could also enable it via `TEST_CAC ### Patch review -Help review existing open [pull requests](https://help.github.com/articles/using-pull-requests) by commenting on the code or +Help review existing open [pull requests](https://gitea.com/xorm/xorm/pulls) by commenting on the code or proposed functionality. ### Bug reports diff --git a/Makefile b/Makefile index 092f23b3..b43c4a4c 100644 --- a/Makefile +++ b/Makefile @@ -6,7 +6,9 @@ GOFMT ?= gofmt -s TAGS ?= SED_INPLACE := sed -i -GOFILES := $(shell find . -name "*.go" -type f) +GO_DIRS := caches contexts integrations core dialects internal log migrate names schemas tags +GOFILES := $(wildcard *.go) +GOFILES += $(shell find $(GO_DIRS) -name "*.go" -type f) INTEGRATION_PACKAGES := xorm.io/xorm/integrations PACKAGES ?= $(filter-out $(INTEGRATION_PACKAGES),$(shell $(GO) list ./...)) @@ -41,6 +43,10 @@ TEST_TIDB_DBNAME ?= xorm_test TEST_TIDB_USERNAME ?= root TEST_TIDB_PASSWORD ?= +TEST_DAMENG_HOST ?= dameng:5236 +TEST_DAMENG_USERNAME ?= SYSDBA +TEST_DAMENG_PASSWORD ?= SYSDBA + TEST_CACHE_ENABLE ?= false TEST_QUOTE_POLICY ?= always @@ -92,40 +98,37 @@ help: @echo " - build creates the entire project" @echo " - clean delete integration files and build files but not css and js files" @echo " - fmt format the code" - @echo " - lint run code linter revive" - @echo " - misspell check if a word is written wrong" + @echo " - lint run code linter" @echo " - test run default unit test" @echo " - test-cockroach run integration tests for cockroach" @echo " - test-mysql run integration tests for mysql" @echo " - test-mssql run integration tests for mssql" @echo " - test-postgres run integration tests for postgres" - @echo " - test-sqlite run integration tests for sqlite" + @echo " - test-sqlite3 run integration tests for sqlite" + @echo " - test-sqlite run integration tests for pure go sqlite" @echo " - test-tidb run integration tests for tidb" @echo " - vet examines Go source code and reports suspicious constructs" .PHONY: lint -lint: revive +lint: golangci-lint -.PHONY: revive -revive: - @hash revive > /dev/null 2>&1; if [ $$? -ne 0 ]; then \ - $(GO) get -u github.com/mgechev/revive; \ - fi - revive -config .revive.toml -exclude=./vendor/... ./... || exit 1 +.PHONY: golangci-lint +golangci-lint: golangci-lint-check + golangci-lint run --timeout 10m -.PHONY: misspell -misspell: - @hash misspell > /dev/null 2>&1; if [ $$? -ne 0 ]; then \ - $(GO) get -u github.com/client9/misspell/cmd/misspell; \ +.PHONY: golangci-lint-check +golangci-lint-check: + $(eval GOLANGCI_LINT_VERSION := $(shell printf "%03d%03d%03d" $(shell golangci-lint --version | grep -Eo '[0-9]+\.[0-9.]+' | tr '.' ' ');)) + $(eval MIN_GOLANGCI_LINT_VER_FMT := $(shell printf "%g.%g.%g" $(shell echo $(MIN_GOLANGCI_LINT_VERSION) | grep -o ...))) + @hash golangci-lint > /dev/null 2>&1; if [ $$? -ne 0 ]; then \ + echo "Downloading golangci-lint v${MIN_GOLANGCI_LINT_VER_FMT}"; \ + export BINARY="golangci-lint"; \ + curl -sfL "https://raw.githubusercontent.com/golangci/golangci-lint/v${MIN_GOLANGCI_LINT_VER_FMT}/install.sh" | sh -s -- -b $(GOPATH)/bin v$(MIN_GOLANGCI_LINT_VER_FMT); \ + elif [ "$(GOLANGCI_LINT_VERSION)" -lt "$(MIN_GOLANGCI_LINT_VERSION)" ]; then \ + echo "Downloading newer version of golangci-lint v${MIN_GOLANGCI_LINT_VER_FMT}"; \ + export BINARY="golangci-lint"; \ + curl -sfL "https://raw.githubusercontent.com/golangci/golangci-lint/v${MIN_GOLANGCI_LINT_VER_FMT}/install.sh" | sh -s -- -b $(GOPATH)/bin v$(MIN_GOLANGCI_LINT_VER_FMT); \ fi - misspell -w -i unknwon $(GOFILES) - -.PHONY: misspell-check -misspell-check: - @hash misspell > /dev/null 2>&1; if [ $$? -ne 0 ]; then \ - $(GO) get -u github.com/client9/misspell/cmd/misspell; \ - fi - misspell -error -i unknwon,destory $(GOFILES) .PHONY: test test: go-check @@ -135,7 +138,7 @@ test: go-check test-cockroach: go-check $(GO) test $(INTEGRATION_PACKAGES) -v -race -db=postgres -schema='$(TEST_COCKROACH_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \ -conn_str="postgres://$(TEST_COCKROACH_USERNAME):$(TEST_COCKROACH_PASSWORD)@$(TEST_COCKROACH_HOST)/$(TEST_COCKROACH_DBNAME)?sslmode=disable&experimental_serial_normalization=sql_sequence" \ - -ignore_update_limit=true -coverprofile=cockroach.$(TEST_COCKROACH_SCHEMA).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic + -ignore_update_limit=true -coverprofile=cockroach.$(TEST_COCKROACH_SCHEMA).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic -timeout=20m .PHONY: test-cockroach\#% test-cockroach\#%: go-check @@ -149,7 +152,7 @@ test-mssql: go-check -conn_str="server=$(TEST_MSSQL_HOST);user id=$(TEST_MSSQL_USERNAME);password=$(TEST_MSSQL_PASSWORD);database=$(TEST_MSSQL_DBNAME)" \ -default_varchar=$(TEST_MSSQL_DEFAULT_VARCHAR) -default_char=$(TEST_MSSQL_DEFAULT_CHAR) \ -do_nvarchar_override_test=$(TEST_MSSQL_DO_NVARCHAR_OVERRIDE_TEST) \ - -coverprofile=mssql.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic + -coverprofile=mssql.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic -timeout=20m .PNONY: test-mssql\#% test-mssql\#%: go-check @@ -163,7 +166,7 @@ test-mssql\#%: go-check test-mymysql: go-check $(GO) test $(INTEGRATION_PACKAGES) -v -race -db=mymysql -cache=$(TEST_CACHE_ENABLE) -quote=$(TEST_QUOTE_POLICY) \ -conn_str="tcp:$(TEST_MYSQL_HOST)*$(TEST_MYSQL_DBNAME)/$(TEST_MYSQL_USERNAME)/$(TEST_MYSQL_PASSWORD)" \ - -coverprofile=mymysql.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic + -coverprofile=mymysql.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic -timeout=20m .PNONY: test-mymysql\#% test-mymysql\#%: go-check @@ -175,7 +178,7 @@ test-mymysql\#%: go-check test-mysql: go-check $(GO) test $(INTEGRATION_PACKAGES) -v -race -db=mysql -cache=$(TEST_CACHE_ENABLE) -quote=$(TEST_QUOTE_POLICY) \ -conn_str="$(TEST_MYSQL_USERNAME):$(TEST_MYSQL_PASSWORD)@tcp($(TEST_MYSQL_HOST))/$(TEST_MYSQL_DBNAME)?charset=$(TEST_MYSQL_CHARSET)" \ - -coverprofile=mysql.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic + -coverprofile=mysql.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic -timeout=20m .PHONY: test-mysql\#% test-mysql\#%: go-check @@ -183,11 +186,23 @@ test-mysql\#%: go-check -conn_str="$(TEST_MYSQL_USERNAME):$(TEST_MYSQL_PASSWORD)@tcp($(TEST_MYSQL_HOST))/$(TEST_MYSQL_DBNAME)?charset=$(TEST_MYSQL_CHARSET)" \ -coverprofile=mysql.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic +.PNONY: test-mysql-tls +test-mysql-tls: go-check + $(GO) test $(INTEGRATION_PACKAGES) -v -race -db=mysql -cache=$(TEST_CACHE_ENABLE) -quote=$(TEST_QUOTE_POLICY) \ + -conn_str="$(TEST_MYSQL_USERNAME):$(TEST_MYSQL_PASSWORD)@tcp($(TEST_MYSQL_HOST))/$(TEST_MYSQL_DBNAME)?charset=$(TEST_MYSQL_CHARSET)&tls=skip-verify" \ + -coverprofile=mysql.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic -timeout=20m + +.PHONY: test-mysql-tls\#% +test-mysql-tls\#%: go-check + $(GO) test $(INTEGRATION_PACKAGES) -v -race -run $* -db=mysql -cache=$(TEST_CACHE_ENABLE) -quote=$(TEST_QUOTE_POLICY) \ + -conn_str="$(TEST_MYSQL_USERNAME):$(TEST_MYSQL_PASSWORD)@tcp($(TEST_MYSQL_HOST))/$(TEST_MYSQL_DBNAME)?charset=$(TEST_MYSQL_CHARSET)&tls=skip-verify" \ + -coverprofile=mysql.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic + .PNONY: test-postgres test-postgres: go-check $(GO) test $(INTEGRATION_PACKAGES) -v -race -db=postgres -schema='$(TEST_PGSQL_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \ -conn_str="postgres://$(TEST_PGSQL_USERNAME):$(TEST_PGSQL_PASSWORD)@$(TEST_PGSQL_HOST)/$(TEST_PGSQL_DBNAME)?sslmode=disable" \ - -quote=$(TEST_QUOTE_POLICY) -coverprofile=postgres.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic + -quote=$(TEST_QUOTE_POLICY) -coverprofile=postgres.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic -timeout=20m .PHONY: test-postgres\#% test-postgres\#%: go-check @@ -195,26 +210,53 @@ test-postgres\#%: go-check -conn_str="postgres://$(TEST_PGSQL_USERNAME):$(TEST_PGSQL_PASSWORD)@$(TEST_PGSQL_HOST)/$(TEST_PGSQL_DBNAME)?sslmode=disable" \ -quote=$(TEST_QUOTE_POLICY) -coverprofile=postgres.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic +.PHONY: test-sqlite3 +test-sqlite3: go-check + $(GO) test $(INTEGRATION_PACKAGES) -v -race -cache=$(TEST_CACHE_ENABLE) -db=sqlite3 -conn_str="./test.db?cache=shared&mode=rwc" \ + -quote=$(TEST_QUOTE_POLICY) -coverprofile=sqlite3.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic -timeout=20m + +.PHONY: test-sqlite3-schema +test-sqlite3-schema: go-check + $(GO) test $(INTEGRATION_PACKAGES) -v -race -schema=xorm -cache=$(TEST_CACHE_ENABLE) -db=sqlite3 -conn_str="./test.db?cache=shared&mode=rwc" \ + -quote=$(TEST_QUOTE_POLICY) -coverprofile=sqlite3.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic -timeout=20m + +.PHONY: test-sqlite3\#% +test-sqlite3\#%: go-check + $(GO) test $(INTEGRATION_PACKAGES) -v -race -run $* -cache=$(TEST_CACHE_ENABLE) -db=sqlite3 -conn_str="./test.db?cache=shared&mode=rwc" \ + -quote=$(TEST_QUOTE_POLICY) -coverprofile=sqlite3.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic -timeout=20m + +.PNONY: test-pgx +test-pgx: go-check + $(GO) test $(INTEGRATION_PACKAGES) -v -race -db=pgx -schema='$(TEST_PGSQL_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \ + -conn_str="postgres://$(TEST_PGSQL_USERNAME):$(TEST_PGSQL_PASSWORD)@$(TEST_PGSQL_HOST)/$(TEST_PGSQL_DBNAME)?sslmode=disable" \ + -quote=$(TEST_QUOTE_POLICY) -coverprofile=postgres.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic -timeout=20m + +.PHONY: test-pgx\#% +test-pgx\#%: go-check + $(GO) test $(INTEGRATION_PACKAGES) -v -race -run $* -db=pgx -schema='$(TEST_PGSQL_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \ + -conn_str="postgres://$(TEST_PGSQL_USERNAME):$(TEST_PGSQL_PASSWORD)@$(TEST_PGSQL_HOST)/$(TEST_PGSQL_DBNAME)?sslmode=disable" \ + -quote=$(TEST_QUOTE_POLICY) -coverprofile=postgres.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic -timeout=20m + .PHONY: test-sqlite test-sqlite: go-check - $(GO) test $(INTEGRATION_PACKAGES) -v -race -cache=$(TEST_CACHE_ENABLE) -db=sqlite3 -conn_str="./test.db?cache=shared&mode=rwc" \ - -quote=$(TEST_QUOTE_POLICY) -coverprofile=sqlite.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic + $(GO) test $(INTEGRATION_PACKAGES) -v -race -cache=$(TEST_CACHE_ENABLE) -db=sqlite -conn_str="./test.db?cache=shared&mode=rwc" \ + -quote=$(TEST_QUOTE_POLICY) -coverprofile=sqlite.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic -timeout=20m .PHONY: test-sqlite-schema test-sqlite-schema: go-check - $(GO) test $(INTEGRATION_PACKAGES) -v -race -schema=xorm -cache=$(TEST_CACHE_ENABLE) -db=sqlite3 -conn_str="./test.db?cache=shared&mode=rwc" \ - -quote=$(TEST_QUOTE_POLICY) -coverprofile=sqlite.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic + $(GO) test $(INTEGRATION_PACKAGES) -v -race -schema=xorm -cache=$(TEST_CACHE_ENABLE) -db=sqlite -conn_str="./test.db?cache=shared&mode=rwc" \ + -quote=$(TEST_QUOTE_POLICY) -coverprofile=sqlite.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic -timeout=20m .PHONY: test-sqlite\#% test-sqlite\#%: go-check - $(GO) test $(INTEGRATION_PACKAGES) -v -race -run $* -cache=$(TEST_CACHE_ENABLE) -db=sqlite3 -conn_str="./test.db?cache=shared&mode=rwc" \ + $(GO) test $(INTEGRATION_PACKAGES) -v -race -run $* -cache=$(TEST_CACHE_ENABLE) -db=sqlite -conn_str="./test.db?cache=shared&mode=rwc" \ -quote=$(TEST_QUOTE_POLICY) -coverprofile=sqlite.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic .PNONY: test-tidb test-tidb: go-check $(GO) test $(INTEGRATION_PACKAGES) -v -race -db=mysql -cache=$(TEST_CACHE_ENABLE) -ignore_select_update=true \ -conn_str="$(TEST_TIDB_USERNAME):$(TEST_TIDB_PASSWORD)@tcp($(TEST_TIDB_HOST))/$(TEST_TIDB_DBNAME)" \ - -quote=$(TEST_QUOTE_POLICY) -coverprofile=tidb.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic + -quote=$(TEST_QUOTE_POLICY) -coverprofile=tidb.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic -timeout=20m .PHONY: test-tidb\#% test-tidb\#%: go-check @@ -222,6 +264,18 @@ test-tidb\#%: go-check -conn_str="$(TEST_TIDB_USERNAME):$(TEST_TIDB_PASSWORD)@tcp($(TEST_TIDB_HOST))/$(TEST_TIDB_DBNAME)" \ -quote=$(TEST_QUOTE_POLICY) -coverprofile=tidb.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic +.PNONY: test-dameng +test-dameng: go-check + $(GO) test $(INTEGRATION_PACKAGES) -v -race -db=dm -cache=$(TEST_CACHE_ENABLE) -quote=$(TEST_QUOTE_POLICY) \ + -conn_str="dm://$(TEST_DAMENG_USERNAME):$(TEST_DAMENG_PASSWORD)@$(TEST_DAMENG_HOST)" \ + -coverprofile=dameng.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic -timeout=20m + +.PHONY: test-dameng\#% +test-dameng\#%: go-check + $(GO) test $(INTEGRATION_PACKAGES) -v -race -run $* -db=dm -cache=$(TEST_CACHE_ENABLE) -quote=$(TEST_QUOTE_POLICY) \ + -conn_str="dm://$(TEST_DAMENG_USERNAME):$(TEST_DAMENG_PASSWORD)@$(TEST_DAMENG_HOST)" \ + -coverprofile=dameng.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic -timeout=20m + .PHONY: vet vet: - $(GO) vet $(shell $(GO) list ./...) \ No newline at end of file + $(GO) vet $(shell $(GO) list ./...) diff --git a/README.md b/README.md index 67380839..f30449a1 100644 --- a/README.md +++ b/README.md @@ -41,15 +41,19 @@ Drivers for Go's sql package which currently support database/sql includes: * [Postgres](https://github.com/postgres/postgres) / [Cockroach](https://github.com/cockroachdb/cockroach) - [github.com/lib/pq](https://github.com/lib/pq) + - [github.com/jackc/pgx](https://github.com/jackc/pgx) * [SQLite](https://sqlite.org) - [github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3) + - [modernc.org/sqlite](https://gitlab.com/cznic/sqlite) (windows unsupported) * MsSql - [github.com/denisenkom/go-mssqldb](https://github.com/denisenkom/go-mssqldb) * Oracle + - [github.com/godror/godror](https://github.com/godror/godror) (experiment) - [github.com/mattn/go-oci8](https://github.com/mattn/go-oci8) (experiment) + - [github.com/sijms/go-ora](https://github.com/sijms/go-ora) (experiment) ## Installation @@ -71,7 +75,7 @@ Firstly, we should new an engine for a database. engine, err := xorm.NewEngine(driverName, dataSourceName) ``` -* Define a struct and Sync2 table struct to database +* Define a struct and Sync table struct to database ```Go type User struct { @@ -84,7 +88,7 @@ type User struct { Updated time.Time `xorm:"updated"` } -err := engine.Sync2(new(User)) +err := engine.Sync(new(User)) ``` * Create Engine Group @@ -138,6 +142,24 @@ affected, err := engine.Insert(&users) affected, err := engine.Insert(&user1, &users) // INSERT INTO struct1 () values () // INSERT INTO struct2 () values (),(),() + +affected, err := engine.Table("user").Insert(map[string]interface{}{ + "name": "lunny", + "age": 18, +}) +// INSERT INTO user (name, age) values (?,?) + +affected, err := engine.Table("user").Insert([]map[string]interface{}{ + { + "name": "lunny", + "age": 18, + }, + { + "name": "lunny2", + "age": 19, + }, +}) +// INSERT INTO user (name, age) values (?,?),(?,?) ``` * `Get` query one record from database @@ -158,6 +180,11 @@ has, err := engine.Table(&user).Where("name = ?", name).Cols("id").Get(&id) has, err := engine.SQL("select id from user").Get(&id) // SELECT id FROM user WHERE name = ? +var id int64 +var name string +has, err := engine.Table(&user).Cols("id", "name").Get(&id, &name) +// SELECT id, name FROM user LIMIT 1 + var valuesMap = make(map[string]string) has, err := engine.Table(&user).Where("id = ?", id).Get(&valuesMap) // SELECT * FROM user WHERE id = ? @@ -231,7 +258,11 @@ err := engine.BufferSize(100).Iterate(&User{Name:name}, func(idx int, bean inter }) // SELECT * FROM user Limit 0, 100 // SELECT * FROM user Limit 101, 100 +``` +You can use rows which is similiar with `sql.Rows` + +```Go rows, err := engine.Rows(&User{Name:name}) // SELECT * FROM user defer rows.Close() @@ -241,39 +272,55 @@ for rows.Next() { } ``` +or + +```Go +rows, err := engine.Cols("name", "age").Rows(&User{Name:name}) +// SELECT * FROM user +defer rows.Close() +for rows.Next() { + var name string + var age int + err = rows.Scan(&name, &age) +} +``` + * `Update` update one or more records, default will update non-empty and non-zero fields except when you use Cols, AllCols and so on. ```Go affected, err := engine.ID(1).Update(&user) -// UPDATE user SET ... Where id = ? +// UPDATE user SET ... WHERE id = ? affected, err := engine.Update(&user, &User{Name:name}) -// UPDATE user SET ... Where name = ? +// UPDATE user SET ... WHERE name = ? var ids = []int64{1, 2, 3} affected, err := engine.In("id", ids).Update(&user) -// UPDATE user SET ... Where id IN (?, ?, ?) +// UPDATE user SET ... WHERE id IN (?, ?, ?) // force update indicated columns by Cols affected, err := engine.ID(1).Cols("age").Update(&User{Name:name, Age: 12}) -// UPDATE user SET age = ?, updated=? Where id = ? +// UPDATE user SET age = ?, updated=? WHERE id = ? // force NOT update indicated columns by Omit affected, err := engine.ID(1).Omit("name").Update(&User{Name:name, Age: 12}) -// UPDATE user SET age = ?, updated=? Where id = ? +// UPDATE user SET age = ?, updated=? WHERE id = ? affected, err := engine.ID(1).AllCols().Update(&user) -// UPDATE user SET name=?,age=?,salt=?,passwd=?,updated=? Where id = ? +// UPDATE user SET name=?,age=?,salt=?,passwd=?,updated=? WHERE id = ? ``` * `Delete` delete one or more records, Delete MUST have condition ```Go affected, err := engine.Where(...).Delete(&user) -// DELETE FROM user Where ... +// DELETE FROM user WHERE ... affected, err := engine.ID(2).Delete(&user) -// DELETE FROM user Where id = ? +// DELETE FROM user WHERE id = ? + +affected, err := engine.Table("user").Where(...).Delete() +// DELETE FROM user WHERE ... ``` * `Count` count records diff --git a/README_CN.md b/README_CN.md index 80245dd3..a5aaae66 100644 --- a/README_CN.md +++ b/README_CN.md @@ -40,14 +40,17 @@ v1.0.0 相对于 v0.8.2 有以下不兼容的变更: * [Postgres](https://github.com/postgres/postgres) / [Cockroach](https://github.com/cockroachdb/cockroach) - [github.com/lib/pq](https://github.com/lib/pq) + - [github.com/jackc/pgx](https://github.com/jackc/pgx) * [SQLite](https://sqlite.org) - [github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3) + - [modernc.org/sqlite](https://gitlab.com/cznic/sqlite) (Windows试验性支持) * MsSql - [github.com/denisenkom/go-mssqldb](https://github.com/denisenkom/go-mssqldb) * Oracle + - [github.com/godror/godror](https://github.com/godror/godror) (试验性支持) - [github.com/mattn/go-oci8](https://github.com/mattn/go-oci8) (试验性支持) ## 安装 @@ -62,7 +65,7 @@ v1.0.0 相对于 v0.8.2 有以下不兼容的变更: # 快速开始 -* 第一步创建引擎,driverName, dataSourceName和database/sql接口相同 +* 第一步创建引擎,`driverName`, `dataSourceName` 和 `database/sql` 接口相同 ```Go engine, err := xorm.NewEngine(driverName, dataSourceName) @@ -81,7 +84,7 @@ type User struct { Updated time.Time `xorm:"updated"` } -err := engine.Sync2(new(User)) +err := engine.Sync(new(User)) ``` * 创建Engine组 @@ -100,7 +103,7 @@ engineGroup, err := xorm.NewEngineGroup(masterEngine, []*Engine{slave1Engine, sl 所有使用 `engine` 都可以简单的用 `engineGroup` 来替换。 -* `Query` 最原始的也支持SQL语句查询,返回的结果类型为 []map[string][]byte。`QueryString` 返回 []map[string]string, `QueryInterface` 返回 `[]map[string]interface{}`. +* `Query` 最原始的也支持SQL语句查询,返回的结果类型为 `[]map[string][]byte`。`QueryString` 返回 `[]map[string]string`, `QueryInterface` 返回 `[]map[string]interface{}`. ```Go results, err := engine.Query("select * from user") @@ -135,6 +138,24 @@ affected, err := engine.Insert(&users) affected, err := engine.Insert(&user1, &users) // INSERT INTO struct1 () values () // INSERT INTO struct2 () values (),(),() + +affected, err := engine.Table("user").Insert(map[string]interface{}{ + "name": "lunny", + "age": 18, +}) +// INSERT INTO user (name, age) values (?,?) + +affected, err := engine.Table("user").Insert([]map[string]interface{}{ + { + "name": "lunny", + "age": 18, + }, + { + "name": "lunny2", + "age": 19, + }, +}) +// INSERT INTO user (name, age) values (?,?),(?,?) ``` * `Get` 查询单条记录 @@ -155,6 +176,11 @@ has, err := engine.Table(&user).Where("name = ?", name).Cols("id").Get(&id) has, err := engine.SQL("select id from user").Get(&id) // SELECT id FROM user WHERE name = ? +var id int64 +var name string +has, err := engine.Table(&user).Cols("id", "name").Get(&id, &name) +// SELECT id, name FROM user LIMIT 1 + var valuesMap = make(map[string]string) has, err := engine.Table(&user).Where("id = ?", id).Get(&valuesMap) // SELECT * FROM user WHERE id = ? @@ -206,7 +232,7 @@ type UserDetail struct { } var users []UserDetail -err := engine.Table("user").Select("user.*, detail.*") +err := engine.Table("user").Select("user.*, detail.*"). Join("INNER", "detail", "detail.user_id = user.id"). Where("user.name = ?", name).Limit(10, 0). Find(&users) @@ -228,7 +254,11 @@ err := engine.BufferSize(100).Iterate(&User{Name:name}, func(idx int, bean inter }) // SELECT * FROM user Limit 0, 100 // SELECT * FROM user Limit 101, 100 +``` +Rows 的用法类似 `sql.Rows`。 + +```Go rows, err := engine.Rows(&User{Name:name}) // SELECT * FROM user defer rows.Close() @@ -238,6 +268,19 @@ for rows.Next() { } ``` +或者 + +```Go +rows, err := engine.Cols("name", "age").Rows(&User{Name:name}) +// SELECT * FROM user +defer rows.Close() +for rows.Next() { + var name string + var age int + err = rows.Scan(&name, &age) +} +``` + * `Update` 更新数据,除非使用Cols,AllCols函数指明,默认只更新非空和非0的字段 ```Go @@ -271,6 +314,9 @@ affected, err := engine.Where(...).Delete(&user) affected, err := engine.ID(2).Delete(&user) // DELETE FROM user Where id = ? + +affected, err := engine.Table("user").Where(...).Delete() +// DELETE FROM user WHERE ... ``` * `Count` 获取记录条数 diff --git a/caches/encode.go b/caches/encode.go index 4ba39924..8659668c 100644 --- a/caches/encode.go +++ b/caches/encode.go @@ -13,22 +13,26 @@ import ( "io" ) -// md5 hash string +// Md5 return md5 hash string func Md5(str string) string { m := md5.New() - io.WriteString(m, str) + _, _ = io.WriteString(m, str) return fmt.Sprintf("%x", m.Sum(nil)) } + +// Encode Encode data func Encode(data interface{}) ([]byte, error) { - //return JsonEncode(data) + // return JsonEncode(data) return GobEncode(data) } +// Decode decode data func Decode(data []byte, to interface{}) error { - //return JsonDecode(data, to) + // return JsonDecode(data, to) return GobDecode(data, to) } +// GobEncode encode data with gob func GobEncode(data interface{}) ([]byte, error) { var buf bytes.Buffer enc := gob.NewEncoder(&buf) @@ -39,12 +43,14 @@ func GobEncode(data interface{}) ([]byte, error) { return buf.Bytes(), nil } +// GobDecode decode data with gob func GobDecode(data []byte, to interface{}) error { buf := bytes.NewBuffer(data) dec := gob.NewDecoder(buf) return dec.Decode(to) } +// JsonEncode encode data with json func JsonEncode(data interface{}) ([]byte, error) { val, err := json.Marshal(data) if err != nil { @@ -53,6 +59,7 @@ func JsonEncode(data interface{}) ([]byte, error) { return val, nil } +// JsonDecode decode data with json func JsonDecode(data []byte, to interface{}) error { return json.Unmarshal(data, to) } diff --git a/caches/leveldb.go b/caches/leveldb.go index d1a177ad..f2f71d84 100644 --- a/caches/leveldb.go +++ b/caches/leveldb.go @@ -19,6 +19,7 @@ type LevelDBStore struct { var _ CacheStore = &LevelDBStore{} +// NewLevelDBStore creates a leveldb store func NewLevelDBStore(dbfile string) (*LevelDBStore, error) { db := &LevelDBStore{} h, err := leveldb.OpenFile(dbfile, nil) @@ -29,6 +30,7 @@ func NewLevelDBStore(dbfile string) (*LevelDBStore, error) { return db, nil } +// Put implements CacheStore func (s *LevelDBStore) Put(key string, value interface{}) error { val, err := Encode(value) if err != nil { @@ -50,6 +52,7 @@ func (s *LevelDBStore) Put(key string, value interface{}) error { return err } +// Get implements CacheStore func (s *LevelDBStore) Get(key string) (interface{}, error) { data, err := s.store.Get([]byte(key), nil) if err != nil { @@ -75,6 +78,7 @@ func (s *LevelDBStore) Get(key string) (interface{}, error) { return s.v, err } +// Del implements CacheStore func (s *LevelDBStore) Del(key string) error { err := s.store.Delete([]byte(key), nil) if err != nil { @@ -89,6 +93,7 @@ func (s *LevelDBStore) Del(key string) error { return err } +// Close implements CacheStore func (s *LevelDBStore) Close() { s.store.Close() } diff --git a/caches/lru.go b/caches/lru.go index 6b45ac94..885f02d6 100644 --- a/caches/lru.go +++ b/caches/lru.go @@ -56,7 +56,7 @@ func (m *LRUCacher) GC() { var removedNum int for e := m.idList.Front(); e != nil; { if removedNum <= CacheGcMaxRemoved && - time.Now().Sub(e.Value.(*idNode).lastVisit) > m.Expired { + time.Since(e.Value.(*idNode).lastVisit) > m.Expired { removedNum++ next := e.Next() node := e.Value.(*idNode) @@ -70,7 +70,7 @@ func (m *LRUCacher) GC() { removedNum = 0 for e := m.sqlList.Front(); e != nil; { if removedNum <= CacheGcMaxRemoved && - time.Now().Sub(e.Value.(*sqlNode).lastVisit) > m.Expired { + time.Since(e.Value.(*sqlNode).lastVisit) > m.Expired { removedNum++ next := e.Next() node := e.Value.(*sqlNode) @@ -96,7 +96,7 @@ func (m *LRUCacher) GetIds(tableName, sql string) interface{} { } else { lastTime := el.Value.(*sqlNode).lastVisit // if expired, remove the node and return nil - if time.Now().Sub(lastTime) > m.Expired { + if time.Since(lastTime) > m.Expired { m.delIds(tableName, sql) return nil } @@ -122,7 +122,7 @@ func (m *LRUCacher) GetBean(tableName string, id string) interface{} { if el, ok := m.idIndex[tableName][id]; ok { lastTime := el.Value.(*idNode).lastVisit // if expired, remove the node and return nil - if time.Now().Sub(lastTime) > m.Expired { + if time.Since(lastTime) > m.Expired { m.delBean(tableName, id) return nil } @@ -145,7 +145,7 @@ func (m *LRUCacher) clearIds(tableName string) { if tis, ok := m.sqlIndex[tableName]; ok { for sql, v := range tis { m.sqlList.Remove(v) - m.store.Del(sql) + _ = m.store.Del(sql) } } m.sqlIndex[tableName] = make(map[string]*list.Element) @@ -163,7 +163,7 @@ func (m *LRUCacher) clearBeans(tableName string) { for id, v := range tis { m.idList.Remove(v) tid := genID(tableName, id) - m.store.Del(tid) + _ = m.store.Del(tid) } } m.idIndex[tableName] = make(map[string]*list.Element) @@ -188,7 +188,7 @@ func (m *LRUCacher) PutIds(tableName, sql string, ids interface{}) { } else { el.Value.(*sqlNode).lastVisit = time.Now() } - m.store.Put(sql, ids) + _ = m.store.Put(sql, ids) if m.sqlList.Len() > m.MaxElementSize { e := m.sqlList.Front() node := e.Value.(*sqlNode) @@ -210,7 +210,7 @@ func (m *LRUCacher) PutBean(tableName string, id string, obj interface{}) { el.Value.(*idNode).lastVisit = time.Now() } - m.store.Put(genID(tableName, id), obj) + _ = m.store.Put(genID(tableName, id), obj) if m.idList.Len() > m.MaxElementSize { e := m.idList.Front() node := e.Value.(*idNode) @@ -226,7 +226,7 @@ func (m *LRUCacher) delIds(tableName, sql string) { m.sqlList.Remove(el) } } - m.store.Del(sql) + _ = m.store.Del(sql) } // DelIds deletes ids @@ -243,7 +243,7 @@ func (m *LRUCacher) delBean(tableName string, id string) { m.idList.Remove(el) m.clearIds(tableName) } - m.store.Del(tid) + _ = m.store.Del(tid) } // DelBean deletes beans in some table @@ -265,10 +265,6 @@ type sqlNode struct { lastVisit time.Time } -func genSQLKey(sql string, args interface{}) string { - return fmt.Sprintf("%s-%v", sql, args) -} - func genID(prefix string, id string) string { return fmt.Sprintf("%s-%s", prefix, id) } diff --git a/caches/manager.go b/caches/manager.go index 05045210..89a14106 100644 --- a/caches/manager.go +++ b/caches/manager.go @@ -6,6 +6,7 @@ package caches import "sync" +// Manager represents a cache manager type Manager struct { cacher Cacher disableGlobalCache bool @@ -14,6 +15,7 @@ type Manager struct { cacherLock sync.RWMutex } +// NewManager creates a cache manager func NewManager() *Manager { return &Manager{ cachers: make(map[string]Cacher), @@ -27,12 +29,14 @@ func (mgr *Manager) SetDisableGlobalCache(disable bool) { } } +// SetCacher set cacher of table func (mgr *Manager) SetCacher(tableName string, cacher Cacher) { mgr.cacherLock.Lock() mgr.cachers[tableName] = cacher mgr.cacherLock.Unlock() } +// GetCacher returns a cache of a table func (mgr *Manager) GetCacher(tableName string) Cacher { var cacher Cacher var ok bool diff --git a/contexts/hook.go b/contexts/hook.go index 71ad8e87..f6d86cfc 100644 --- a/contexts/hook.go +++ b/contexts/hook.go @@ -31,26 +31,31 @@ func NewContextHook(ctx context.Context, sql string, args []interface{}) *Contex } } +// End finish the hook invokation func (c *ContextHook) End(ctx context.Context, result sql.Result, err error) { c.Ctx = ctx c.Result = result c.Err = err - c.ExecuteTime = time.Now().Sub(c.start) + c.ExecuteTime = time.Since(c.start) } +// Hook represents a hook behaviour type Hook interface { BeforeProcess(c *ContextHook) (context.Context, error) AfterProcess(c *ContextHook) error } +// Hooks implements Hook interface but contains multiple Hook type Hooks struct { hooks []Hook } +// AddHook adds a Hook func (h *Hooks) AddHook(hooks ...Hook) { h.hooks = append(h.hooks, hooks...) } +// BeforeProcess invoked before execute the process func (h *Hooks) BeforeProcess(c *ContextHook) (context.Context, error) { ctx := c.Ctx for _, h := range h.hooks { @@ -63,6 +68,7 @@ func (h *Hooks) BeforeProcess(c *ContextHook) (context.Context, error) { return ctx, nil } +// AfterProcess invoked after exetue the process func (h *Hooks) AfterProcess(c *ContextHook) error { firstErr := c.Err for _, h := range h.hooks { diff --git a/convert.go b/convert.go deleted file mode 100644 index c19d30e0..00000000 --- a/convert.go +++ /dev/null @@ -1,422 +0,0 @@ -// Copyright 2017 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package xorm - -import ( - "database/sql/driver" - "errors" - "fmt" - "reflect" - "strconv" - "time" -) - -var errNilPtr = errors.New("destination pointer is nil") // embedded in descriptive error - -func strconvErr(err error) error { - if ne, ok := err.(*strconv.NumError); ok { - return ne.Err - } - return err -} - -func cloneBytes(b []byte) []byte { - if b == nil { - return nil - } - c := make([]byte, len(b)) - copy(c, b) - return c -} - -func asString(src interface{}) string { - switch v := src.(type) { - case string: - return v - case []byte: - return string(v) - } - rv := reflect.ValueOf(src) - switch rv.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return strconv.FormatInt(rv.Int(), 10) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - return strconv.FormatUint(rv.Uint(), 10) - case reflect.Float64: - return strconv.FormatFloat(rv.Float(), 'g', -1, 64) - case reflect.Float32: - return strconv.FormatFloat(rv.Float(), 'g', -1, 32) - case reflect.Bool: - return strconv.FormatBool(rv.Bool()) - } - return fmt.Sprintf("%v", src) -} - -func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) { - switch rv.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return strconv.AppendInt(buf, rv.Int(), 10), true - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - return strconv.AppendUint(buf, rv.Uint(), 10), true - case reflect.Float32: - return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 32), true - case reflect.Float64: - return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 64), true - case reflect.Bool: - return strconv.AppendBool(buf, rv.Bool()), true - case reflect.String: - s := rv.String() - return append(buf, s...), true - } - return -} - -// convertAssign copies to dest the value in src, converting it if possible. -// An error is returned if the copy would result in loss of information. -// dest should be a pointer type. -func convertAssign(dest, src interface{}) error { - // Common cases, without reflect. - switch s := src.(type) { - case string: - switch d := dest.(type) { - case *string: - if d == nil { - return errNilPtr - } - *d = s - return nil - case *[]byte: - if d == nil { - return errNilPtr - } - *d = []byte(s) - return nil - } - case []byte: - switch d := dest.(type) { - case *string: - if d == nil { - return errNilPtr - } - *d = string(s) - return nil - case *interface{}: - if d == nil { - return errNilPtr - } - *d = cloneBytes(s) - return nil - case *[]byte: - if d == nil { - return errNilPtr - } - *d = cloneBytes(s) - return nil - } - - case time.Time: - switch d := dest.(type) { - case *string: - *d = s.Format(time.RFC3339Nano) - return nil - case *[]byte: - if d == nil { - return errNilPtr - } - *d = []byte(s.Format(time.RFC3339Nano)) - return nil - } - case nil: - switch d := dest.(type) { - case *interface{}: - if d == nil { - return errNilPtr - } - *d = nil - return nil - case *[]byte: - if d == nil { - return errNilPtr - } - *d = nil - return nil - } - } - - var sv reflect.Value - - switch d := dest.(type) { - case *string: - sv = reflect.ValueOf(src) - switch sv.Kind() { - case reflect.Bool, - reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, - reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, - reflect.Float32, reflect.Float64: - *d = asString(src) - return nil - } - case *[]byte: - sv = reflect.ValueOf(src) - if b, ok := asBytes(nil, sv); ok { - *d = b - return nil - } - case *bool: - bv, err := driver.Bool.ConvertValue(src) - if err == nil { - *d = bv.(bool) - } - return err - case *interface{}: - *d = src - return nil - } - - dpv := reflect.ValueOf(dest) - if dpv.Kind() != reflect.Ptr { - return errors.New("destination not a pointer") - } - if dpv.IsNil() { - return errNilPtr - } - - if !sv.IsValid() { - sv = reflect.ValueOf(src) - } - - dv := reflect.Indirect(dpv) - if sv.IsValid() && sv.Type().AssignableTo(dv.Type()) { - switch b := src.(type) { - case []byte: - dv.Set(reflect.ValueOf(cloneBytes(b))) - default: - dv.Set(sv) - } - return nil - } - - if dv.Kind() == sv.Kind() && sv.Type().ConvertibleTo(dv.Type()) { - dv.Set(sv.Convert(dv.Type())) - return nil - } - - switch dv.Kind() { - case reflect.Ptr: - if src == nil { - dv.Set(reflect.Zero(dv.Type())) - return nil - } - - dv.Set(reflect.New(dv.Type().Elem())) - return convertAssign(dv.Interface(), src) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - s := asString(src) - i64, err := strconv.ParseInt(s, 10, dv.Type().Bits()) - if err != nil { - err = strconvErr(err) - return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) - } - dv.SetInt(i64) - return nil - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - s := asString(src) - u64, err := strconv.ParseUint(s, 10, dv.Type().Bits()) - if err != nil { - err = strconvErr(err) - return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) - } - dv.SetUint(u64) - return nil - case reflect.Float32, reflect.Float64: - s := asString(src) - f64, err := strconv.ParseFloat(s, dv.Type().Bits()) - if err != nil { - err = strconvErr(err) - return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) - } - dv.SetFloat(f64) - return nil - case reflect.String: - dv.SetString(asString(src)) - return nil - } - - return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dest) -} - -func asKind(vv reflect.Value, tp reflect.Type) (interface{}, error) { - switch tp.Kind() { - case reflect.Int64: - return vv.Int(), nil - case reflect.Int: - return int(vv.Int()), nil - case reflect.Int32: - return int32(vv.Int()), nil - case reflect.Int16: - return int16(vv.Int()), nil - case reflect.Int8: - return int8(vv.Int()), nil - case reflect.Uint64: - return vv.Uint(), nil - case reflect.Uint: - return uint(vv.Uint()), nil - case reflect.Uint32: - return uint32(vv.Uint()), nil - case reflect.Uint16: - return uint16(vv.Uint()), nil - case reflect.Uint8: - return uint8(vv.Uint()), nil - case reflect.String: - return vv.String(), nil - case reflect.Slice: - if tp.Elem().Kind() == reflect.Uint8 { - v, err := strconv.ParseInt(string(vv.Interface().([]byte)), 10, 64) - if err != nil { - return nil, err - } - return v, nil - } - - } - return nil, fmt.Errorf("unsupported primary key type: %v, %v", tp, vv) -} - -func asBool(bs []byte) (bool, error) { - if len(bs) == 0 { - return false, nil - } - if bs[0] == 0x00 { - return false, nil - } else if bs[0] == 0x01 { - return true, nil - } - return strconv.ParseBool(string(bs)) -} - -// str2PK convert string value to primary key value according to tp -func str2PKValue(s string, tp reflect.Type) (reflect.Value, error) { - var err error - var result interface{} - var defReturn = reflect.Zero(tp) - - switch tp.Kind() { - case reflect.Int: - result, err = strconv.Atoi(s) - if err != nil { - return defReturn, fmt.Errorf("convert %s as int: %s", s, err.Error()) - } - case reflect.Int8: - x, err := strconv.Atoi(s) - if err != nil { - return defReturn, fmt.Errorf("convert %s as int8: %s", s, err.Error()) - } - result = int8(x) - case reflect.Int16: - x, err := strconv.Atoi(s) - if err != nil { - return defReturn, fmt.Errorf("convert %s as int16: %s", s, err.Error()) - } - result = int16(x) - case reflect.Int32: - x, err := strconv.Atoi(s) - if err != nil { - return defReturn, fmt.Errorf("convert %s as int32: %s", s, err.Error()) - } - result = int32(x) - case reflect.Int64: - result, err = strconv.ParseInt(s, 10, 64) - if err != nil { - return defReturn, fmt.Errorf("convert %s as int64: %s", s, err.Error()) - } - case reflect.Uint: - x, err := strconv.ParseUint(s, 10, 64) - if err != nil { - return defReturn, fmt.Errorf("convert %s as uint: %s", s, err.Error()) - } - result = uint(x) - case reflect.Uint8: - x, err := strconv.ParseUint(s, 10, 64) - if err != nil { - return defReturn, fmt.Errorf("convert %s as uint8: %s", s, err.Error()) - } - result = uint8(x) - case reflect.Uint16: - x, err := strconv.ParseUint(s, 10, 64) - if err != nil { - return defReturn, fmt.Errorf("convert %s as uint16: %s", s, err.Error()) - } - result = uint16(x) - case reflect.Uint32: - x, err := strconv.ParseUint(s, 10, 64) - if err != nil { - return defReturn, fmt.Errorf("convert %s as uint32: %s", s, err.Error()) - } - result = uint32(x) - case reflect.Uint64: - result, err = strconv.ParseUint(s, 10, 64) - if err != nil { - return defReturn, fmt.Errorf("convert %s as uint64: %s", s, err.Error()) - } - case reflect.String: - result = s - default: - return defReturn, errors.New("unsupported convert type") - } - return reflect.ValueOf(result).Convert(tp), nil -} - -func str2PK(s string, tp reflect.Type) (interface{}, error) { - v, err := str2PKValue(s, tp) - if err != nil { - return nil, err - } - return v.Interface(), nil -} - -func int64ToIntValue(id int64, tp reflect.Type) reflect.Value { - var v interface{} - kind := tp.Kind() - - if kind == reflect.Ptr { - kind = tp.Elem().Kind() - } - - switch kind { - case reflect.Int16: - temp := int16(id) - v = &temp - case reflect.Int32: - temp := int32(id) - v = &temp - case reflect.Int: - temp := int(id) - v = &temp - case reflect.Int64: - temp := id - v = &temp - case reflect.Uint16: - temp := uint16(id) - v = &temp - case reflect.Uint32: - temp := uint32(id) - v = &temp - case reflect.Uint64: - temp := uint64(id) - v = &temp - case reflect.Uint: - temp := uint(id) - v = &temp - } - - if tp.Kind() == reflect.Ptr { - return reflect.ValueOf(v).Convert(tp) - } - return reflect.ValueOf(v).Elem().Convert(tp) -} - -func int64ToInt(id int64, tp reflect.Type) interface{} { - return int64ToIntValue(id, tp).Interface() -} diff --git a/convert/bool.go b/convert/bool.go new file mode 100644 index 00000000..58b23f4b --- /dev/null +++ b/convert/bool.go @@ -0,0 +1,51 @@ +// Copyright 2021 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package convert + +import ( + "database/sql" + "fmt" + "strconv" +) + +// AsBool convert interface as bool +func AsBool(src interface{}) (bool, error) { + switch v := src.(type) { + case bool: + return v, nil + case *bool: + return *v, nil + case *sql.NullBool: + return v.Bool, nil + case int64: + return v > 0, nil + case int: + return v > 0, nil + case int8: + return v > 0, nil + case int16: + return v > 0, nil + case int32: + return v > 0, nil + case []byte: + if len(v) == 0 { + return false, nil + } + if v[0] == 0x00 { + return false, nil + } else if v[0] == 0x01 { + return true, nil + } + return strconv.ParseBool(string(v)) + case string: + return strconv.ParseBool(v) + case *sql.NullInt64: + return v.Int64 > 0, nil + case *sql.NullInt32: + return v.Int32 > 0, nil + default: + return false, fmt.Errorf("unknow type %T as bool", src) + } +} diff --git a/convert/conversion.go b/convert/conversion.go index 16f1a92a..b69e345c 100644 --- a/convert/conversion.go +++ b/convert/conversion.go @@ -4,9 +4,386 @@ package convert +import ( + "database/sql" + "database/sql/driver" + "encoding/json" + "errors" + "fmt" + "math/big" + "reflect" + "strconv" + "time" +) + // Conversion is an interface. A type implements Conversion will according // the custom method to fill into database and retrieve from database. type Conversion interface { FromDB([]byte) error ToDB() ([]byte, error) } + +// ErrNilPtr represents an error +var ErrNilPtr = errors.New("destination pointer is nil") // embedded in descriptive error + +func strconvErr(err error) error { + if ne, ok := err.(*strconv.NumError); ok { + return ne.Err + } + return err +} + +func cloneBytes(b []byte) []byte { + if b == nil { + return nil + } + c := make([]byte, len(b)) + copy(c, b) + return c +} + +// Assign copies to dest the value in src, converting it if possible. +// An error is returned if the copy would result in loss of information. +// dest should be a pointer type. +func Assign(dest, src interface{}, originalLocation *time.Location, convertedLocation *time.Location) error { + // Common cases, without reflect. + switch s := src.(type) { + case *interface{}: + return Assign(dest, *s, originalLocation, convertedLocation) + case string: + switch d := dest.(type) { + case *string: + if d == nil { + return ErrNilPtr + } + *d = s + return nil + case *[]byte: + if d == nil { + return ErrNilPtr + } + *d = []byte(s) + return nil + } + case []byte: + switch d := dest.(type) { + case *string: + if d == nil { + return ErrNilPtr + } + *d = string(s) + return nil + case *interface{}: + if d == nil { + return ErrNilPtr + } + *d = cloneBytes(s) + return nil + case *[]byte: + if d == nil { + return ErrNilPtr + } + *d = cloneBytes(s) + return nil + } + case time.Time: + switch d := dest.(type) { + case *string: + *d = s.Format(time.RFC3339Nano) + return nil + case *[]byte: + if d == nil { + return ErrNilPtr + } + *d = []byte(s.Format(time.RFC3339Nano)) + return nil + } + case nil: + switch d := dest.(type) { + case *interface{}: + if d == nil { + return ErrNilPtr + } + *d = nil + return nil + case *[]byte: + if d == nil { + return ErrNilPtr + } + *d = nil + return nil + } + case *sql.NullString: + switch d := dest.(type) { + case *int: + if s.Valid { + *d, _ = strconv.Atoi(s.String) + } + return nil + case *int64: + if s.Valid { + *d, _ = strconv.ParseInt(s.String, 10, 64) + } + return nil + case *string: + if s.Valid { + *d = s.String + } + return nil + case *time.Time: + if s.Valid { + var err error + dt, err := String2Time(s.String, originalLocation, convertedLocation) + if err != nil { + return err + } + *d = *dt + } + return nil + case *sql.NullTime: + if s.Valid { + var err error + dt, err := String2Time(s.String, originalLocation, convertedLocation) + if err != nil { + return err + } + d.Valid = true + d.Time = *dt + } + return nil + case *big.Float: + if s.Valid { + if d == nil { + d = big.NewFloat(0) + } + d.SetString(s.String) + } + return nil + } + case *sql.NullInt32: + switch d := dest.(type) { + case *int: + if s.Valid { + *d = int(s.Int32) + } + return nil + case *int8: + if s.Valid { + *d = int8(s.Int32) + } + return nil + case *int16: + if s.Valid { + *d = int16(s.Int32) + } + return nil + case *int32: + if s.Valid { + *d = s.Int32 + } + return nil + case *int64: + if s.Valid { + *d = int64(s.Int32) + } + return nil + } + case *sql.NullInt64: + switch d := dest.(type) { + case *int: + if s.Valid { + *d = int(s.Int64) + } + return nil + case *int8: + if s.Valid { + *d = int8(s.Int64) + } + return nil + case *int16: + if s.Valid { + *d = int16(s.Int64) + } + return nil + case *int32: + if s.Valid { + *d = int32(s.Int64) + } + return nil + case *int64: + if s.Valid { + *d = s.Int64 + } + return nil + } + case *sql.NullFloat64: + switch d := dest.(type) { + case *int: + if s.Valid { + *d = int(s.Float64) + } + return nil + case *float64: + if s.Valid { + *d = s.Float64 + } + return nil + } + case *sql.NullBool: + switch d := dest.(type) { + case *bool: + if s.Valid { + *d = s.Bool + } + return nil + } + case *sql.NullTime: + switch d := dest.(type) { + case *time.Time: + if s.Valid { + *d = s.Time + } + return nil + case *string: + if s.Valid { + *d = s.Time.In(convertedLocation).Format("2006-01-02 15:04:05") + } + return nil + } + case *NullUint32: + switch d := dest.(type) { + case *uint8: + if s.Valid { + *d = uint8(s.Uint32) + } + return nil + case *uint16: + if s.Valid { + *d = uint16(s.Uint32) + } + return nil + case *uint: + if s.Valid { + *d = uint(s.Uint32) + } + return nil + } + case *NullUint64: + switch d := dest.(type) { + case *uint64: + if s.Valid { + *d = s.Uint64 + } + return nil + } + case *sql.RawBytes: + switch d := dest.(type) { + case Conversion: + return d.FromDB(*s) + } + } + + switch d := dest.(type) { + case *string: + var sv = reflect.ValueOf(src) + switch sv.Kind() { + case reflect.Bool, + reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64: + *d = AsString(src) + return nil + } + case *[]byte: + if b, ok := AsBytes(src); ok { + *d = b + return nil + } + case *bool: + bv, err := driver.Bool.ConvertValue(src) + if err == nil { + *d = bv.(bool) + } + return err + case *interface{}: + *d = src + return nil + } + + return AssignValue(reflect.ValueOf(dest), src) +} + +var ( + scannerTypePlaceHolder sql.Scanner + scannerType = reflect.TypeOf(&scannerTypePlaceHolder).Elem() +) + +// AssignValue assign src as dv +func AssignValue(dv reflect.Value, src interface{}) error { + if src == nil { + return nil + } + if v, ok := src.(*interface{}); ok { + return AssignValue(dv, *v) + } + + if dv.Type().Implements(scannerType) { + return dv.Interface().(sql.Scanner).Scan(src) + } + + switch dv.Kind() { + case reflect.Ptr: + if dv.IsNil() { + dv.Set(reflect.New(dv.Type().Elem())) + } + return AssignValue(dv.Elem(), src) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + i64, err := AsInt64(src) + if err != nil { + err = strconvErr(err) + return fmt.Errorf("converting driver.Value type %T to a %s: %v", src, dv.Kind(), err) + } + dv.SetInt(i64) + return nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + u64, err := AsUint64(src) + if err != nil { + err = strconvErr(err) + return fmt.Errorf("converting driver.Value type %T to a %s: %v", src, dv.Kind(), err) + } + dv.SetUint(u64) + return nil + case reflect.Float32, reflect.Float64: + f64, err := AsFloat64(src) + if err != nil { + err = strconvErr(err) + return fmt.Errorf("converting driver.Value type %T to a %s: %v", src, dv.Kind(), err) + } + dv.SetFloat(f64) + return nil + case reflect.String: + dv.SetString(AsString(src)) + return nil + case reflect.Bool: + b, err := AsBool(src) + if err != nil { + return err + } + dv.SetBool(b) + return nil + case reflect.Slice, reflect.Map, reflect.Struct, reflect.Array: + data, ok := AsBytes(src) + if !ok { + return fmt.Errorf("convert.AssignValue: src cannot be as bytes %#v", src) + } + if data == nil { + return nil + } + if dv.Kind() != reflect.Ptr { + dv = dv.Addr() + } + return json.Unmarshal(data, dv.Interface()) + default: + return fmt.Errorf("convert.AssignValue: unsupported Scan, storing driver.Value type %T into type %T", src, dv.Interface()) + } +} diff --git a/convert/float.go b/convert/float.go new file mode 100644 index 00000000..51b441ce --- /dev/null +++ b/convert/float.go @@ -0,0 +1,142 @@ +// Copyright 2021 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package convert + +import ( + "database/sql" + "fmt" + "math/big" + "reflect" + "strconv" +) + +// AsFloat64 convets interface as float64 +func AsFloat64(src interface{}) (float64, error) { + switch v := src.(type) { + case int: + return float64(v), nil + case int16: + return float64(v), nil + case int32: + return float64(v), nil + case int8: + return float64(v), nil + case int64: + return float64(v), nil + case uint: + return float64(v), nil + case uint8: + return float64(v), nil + case uint16: + return float64(v), nil + case uint32: + return float64(v), nil + case uint64: + return float64(v), nil + case []byte: + return strconv.ParseFloat(string(v), 64) + case string: + return strconv.ParseFloat(v, 64) + case *sql.NullString: + return strconv.ParseFloat(v.String, 64) + case *sql.NullInt32: + return float64(v.Int32), nil + case *sql.NullInt64: + return float64(v.Int64), nil + case *sql.NullFloat64: + return v.Float64, nil + } + + rv := reflect.ValueOf(src) + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return float64(rv.Int()), nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return float64(rv.Uint()), nil + case reflect.Float64, reflect.Float32: + return float64(rv.Float()), nil + case reflect.String: + return strconv.ParseFloat(rv.String(), 64) + } + return 0, fmt.Errorf("unsupported value %T as int64", src) +} + +// AsBigFloat converts interface as big.Float +func AsBigFloat(src interface{}) (*big.Float, error) { + res := big.NewFloat(0) + switch v := src.(type) { + case int: + res.SetInt64(int64(v)) + return res, nil + case int16: + res.SetInt64(int64(v)) + return res, nil + case int32: + res.SetInt64(int64(v)) + return res, nil + case int8: + res.SetInt64(int64(v)) + return res, nil + case int64: + res.SetInt64(int64(v)) + return res, nil + case uint: + res.SetUint64(uint64(v)) + return res, nil + case uint8: + res.SetUint64(uint64(v)) + return res, nil + case uint16: + res.SetUint64(uint64(v)) + return res, nil + case uint32: + res.SetUint64(uint64(v)) + return res, nil + case uint64: + res.SetUint64(uint64(v)) + return res, nil + case []byte: + res.SetString(string(v)) + return res, nil + case string: + res.SetString(v) + return res, nil + case *sql.NullString: + if v.Valid { + res.SetString(v.String) + return res, nil + } + return nil, nil + case *sql.NullInt32: + if v.Valid { + res.SetInt64(int64(v.Int32)) + return res, nil + } + return nil, nil + case *sql.NullInt64: + if v.Valid { + res.SetInt64(int64(v.Int64)) + return res, nil + } + return nil, nil + } + + rv := reflect.ValueOf(src) + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + res.SetInt64(rv.Int()) + return res, nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + res.SetUint64(rv.Uint()) + return res, nil + case reflect.Float64, reflect.Float32: + res.SetFloat64(rv.Float()) + return res, nil + case reflect.String: + res.SetString(rv.String()) + return res, nil + } + return nil, fmt.Errorf("unsupported value %T as big.Float", src) +} diff --git a/convert/int.go b/convert/int.go new file mode 100644 index 00000000..af8d4f75 --- /dev/null +++ b/convert/int.go @@ -0,0 +1,178 @@ +// Copyright 2021 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package convert + +import ( + "database/sql" + "database/sql/driver" + "fmt" + "reflect" + "strconv" +) + +// AsInt64 converts interface as int64 +func AsInt64(src interface{}) (int64, error) { + switch v := src.(type) { + case int: + return int64(v), nil + case int16: + return int64(v), nil + case int32: + return int64(v), nil + case int8: + return int64(v), nil + case int64: + return v, nil + case uint: + return int64(v), nil + case uint8: + return int64(v), nil + case uint16: + return int64(v), nil + case uint32: + return int64(v), nil + case uint64: + return int64(v), nil + case []byte: + return strconv.ParseInt(string(v), 10, 64) + case string: + return strconv.ParseInt(v, 10, 64) + case *sql.NullString: + return strconv.ParseInt(v.String, 10, 64) + case *sql.NullInt32: + return int64(v.Int32), nil + case *sql.NullInt64: + return int64(v.Int64), nil + } + + rv := reflect.ValueOf(src) + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return rv.Int(), nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return int64(rv.Uint()), nil + case reflect.Float64, reflect.Float32: + return int64(rv.Float()), nil + case reflect.String: + return strconv.ParseInt(rv.String(), 10, 64) + } + return 0, fmt.Errorf("unsupported value %T as int64", src) +} + +// AsUint64 converts interface as uint64 +func AsUint64(src interface{}) (uint64, error) { + switch v := src.(type) { + case int: + return uint64(v), nil + case int16: + return uint64(v), nil + case int32: + return uint64(v), nil + case int8: + return uint64(v), nil + case int64: + return uint64(v), nil + case uint: + return uint64(v), nil + case uint8: + return uint64(v), nil + case uint16: + return uint64(v), nil + case uint32: + return uint64(v), nil + case uint64: + return v, nil + case []byte: + return strconv.ParseUint(string(v), 10, 64) + case string: + return strconv.ParseUint(v, 10, 64) + case *sql.NullString: + return strconv.ParseUint(v.String, 10, 64) + case *sql.NullInt32: + return uint64(v.Int32), nil + case *sql.NullInt64: + return uint64(v.Int64), nil + } + + rv := reflect.ValueOf(src) + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return uint64(rv.Int()), nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return uint64(rv.Uint()), nil + case reflect.Float64, reflect.Float32: + return uint64(rv.Float()), nil + case reflect.String: + return strconv.ParseUint(rv.String(), 10, 64) + } + return 0, fmt.Errorf("unsupported value %T as uint64", src) +} + +var ( + _ sql.Scanner = &NullUint64{} +) + +// NullUint64 represents an uint64 that may be null. +// NullUint64 implements the Scanner interface so +// it can be used as a scan destination, similar to NullString. +type NullUint64 struct { + Uint64 uint64 + Valid bool +} + +// Scan implements the Scanner interface. +func (n *NullUint64) Scan(value interface{}) error { + if value == nil { + n.Uint64, n.Valid = 0, false + return nil + } + n.Valid = true + var err error + n.Uint64, err = AsUint64(value) + return err +} + +// Value implements the driver Valuer interface. +func (n NullUint64) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return n.Uint64, nil +} + +var ( + _ sql.Scanner = &NullUint32{} +) + +// NullUint32 represents an uint32 that may be null. +// NullUint32 implements the Scanner interface so +// it can be used as a scan destination, similar to NullString. +type NullUint32 struct { + Uint32 uint32 + Valid bool // Valid is true if Uint32 is not NULL +} + +// Scan implements the Scanner interface. +func (n *NullUint32) Scan(value interface{}) error { + if value == nil { + n.Uint32, n.Valid = 0, false + return nil + } + n.Valid = true + i64, err := AsUint64(value) + if err != nil { + return err + } + n.Uint32 = uint32(i64) + return nil +} + +// Value implements the driver Valuer interface. +func (n NullUint32) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return int64(n.Uint32), nil +} diff --git a/convert/interface.go b/convert/interface.go new file mode 100644 index 00000000..b0f28c81 --- /dev/null +++ b/convert/interface.go @@ -0,0 +1,49 @@ +// Copyright 2021 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package convert + +import ( + "database/sql" + "fmt" + "time" +) + +// Interface2Interface converts interface of pointer as interface of value +func Interface2Interface(userLocation *time.Location, v interface{}) (interface{}, error) { + if v == nil { + return nil, nil + } + switch vv := v.(type) { + case *int64: + return *vv, nil + case *int8: + return *vv, nil + case *sql.NullString: + return vv.String, nil + case *sql.RawBytes: + if len([]byte(*vv)) > 0 { + return []byte(*vv), nil + } + return nil, nil + case *sql.NullInt32: + return vv.Int32, nil + case *sql.NullInt64: + return vv.Int64, nil + case *sql.NullFloat64: + return vv.Float64, nil + case *sql.NullBool: + if vv.Valid { + return vv.Bool, nil + } + return nil, nil + case *sql.NullTime: + if vv.Valid { + return vv.Time.In(userLocation).Format("2006-01-02 15:04:05"), nil + } + return "", nil + default: + return "", fmt.Errorf("convert assign string unsupported type: %#v", vv) + } +} diff --git a/convert/scanner.go b/convert/scanner.go new file mode 100644 index 00000000..505d3be0 --- /dev/null +++ b/convert/scanner.go @@ -0,0 +1,19 @@ +// Copyright 2021 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package convert + +import "database/sql" + +var ( + _ sql.Scanner = &EmptyScanner{} +) + +// EmptyScanner represents an empty scanner which will ignore the scan +type EmptyScanner struct{} + +// Scan implements sql.Scanner +func (EmptyScanner) Scan(value interface{}) error { + return nil +} diff --git a/convert/string.go b/convert/string.go new file mode 100644 index 00000000..de11fa01 --- /dev/null +++ b/convert/string.go @@ -0,0 +1,75 @@ +// Copyright 2021 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package convert + +import ( + "database/sql" + "fmt" + "reflect" + "strconv" +) + +// AsString converts interface as string +func AsString(src interface{}) string { + switch v := src.(type) { + case string: + return v + case []byte: + return string(v) + case *sql.NullString: + return v.String + case *sql.NullInt32: + return fmt.Sprintf("%d", v.Int32) + case *sql.NullInt64: + return fmt.Sprintf("%d", v.Int64) + } + rv := reflect.ValueOf(src) + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return strconv.FormatInt(rv.Int(), 10) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return strconv.FormatUint(rv.Uint(), 10) + case reflect.Float64: + return strconv.FormatFloat(rv.Float(), 'g', -1, 64) + case reflect.Float32: + return strconv.FormatFloat(rv.Float(), 'g', -1, 32) + case reflect.Bool: + return strconv.FormatBool(rv.Bool()) + } + return fmt.Sprintf("%v", src) +} + +// AsBytes converts interface as bytes +func AsBytes(src interface{}) ([]byte, bool) { + switch t := src.(type) { + case []byte: + return t, true + case *sql.NullString: + if !t.Valid { + return nil, true + } + return []byte(t.String), true + case *sql.RawBytes: + return *t, true + } + + rv := reflect.ValueOf(src) + + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return strconv.AppendInt(nil, rv.Int(), 10), true + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return strconv.AppendUint(nil, rv.Uint(), 10), true + case reflect.Float32: + return strconv.AppendFloat(nil, rv.Float(), 'g', -1, 32), true + case reflect.Float64: + return strconv.AppendFloat(nil, rv.Float(), 'g', -1, 64), true + case reflect.Bool: + return strconv.AppendBool(nil, rv.Bool()), true + case reflect.String: + return []byte(rv.String()), true + } + return nil, false +} diff --git a/convert/time.go b/convert/time.go new file mode 100644 index 00000000..cc2e0a10 --- /dev/null +++ b/convert/time.go @@ -0,0 +1,127 @@ +// Copyright 2021 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package convert + +import ( + "database/sql" + "fmt" + "strconv" + "strings" + "time" + + "xorm.io/xorm/internal/utils" +) + +// String2Time converts a string to time with original location +func String2Time(s string, originalLocation *time.Location, convertedLocation *time.Location) (*time.Time, error) { + if len(s) == 19 { + if s == utils.ZeroTime0 || s == utils.ZeroTime1 { + return &time.Time{}, nil + } + dt, err := time.ParseInLocation("2006-01-02 15:04:05", s, originalLocation) + if err != nil { + return nil, err + } + dt = dt.In(convertedLocation) + return &dt, nil + } else if len(s) == 20 && s[10] == 'T' && s[19] == 'Z' { + dt, err := time.ParseInLocation("2006-01-02T15:04:05", s[:19], originalLocation) + if err != nil { + return nil, err + } + dt = dt.In(convertedLocation) + return &dt, nil + } else if len(s) == 25 && s[10] == 'T' && s[19] == '+' && s[22] == ':' { + dt, err := time.Parse(time.RFC3339, s) + if err != nil { + return nil, err + } + dt = dt.In(convertedLocation) + return &dt, nil + } else if len(s) >= 21 && s[19] == '.' { + var layout = "2006-01-02 15:04:05." + strings.Repeat("0", len(s)-20) + dt, err := time.ParseInLocation(layout, s, originalLocation) + if err != nil { + return nil, err + } + dt = dt.In(convertedLocation) + return &dt, nil + } else if len(s) == 10 && s[4] == '-' { + if s == "0000-00-00" || s == "0001-01-01" { + return &time.Time{}, nil + } + dt, err := time.ParseInLocation("2006-01-02", s, originalLocation) + if err != nil { + return nil, err + } + dt = dt.In(convertedLocation) + return &dt, nil + } else { + i, err := strconv.ParseInt(s, 10, 64) + if err == nil { + tm := time.Unix(i, 0).In(convertedLocation) + return &tm, nil + } + } + return nil, fmt.Errorf("unsupported conversion from %s to time", s) +} + +// AsTime converts interface as time +func AsTime(src interface{}, dbLoc *time.Location, uiLoc *time.Location) (*time.Time, error) { + switch t := src.(type) { + case string: + return String2Time(t, dbLoc, uiLoc) + case *sql.NullString: + if !t.Valid { + return nil, nil + } + return String2Time(t.String, dbLoc, uiLoc) + case []uint8: + if t == nil { + return nil, nil + } + return String2Time(string(t), dbLoc, uiLoc) + case *sql.NullTime: + if !t.Valid { + return nil, nil + } + z, _ := t.Time.Zone() + if len(z) == 0 || t.Time.Year() == 0 || t.Time.Location().String() != dbLoc.String() { + tm := time.Date(t.Time.Year(), t.Time.Month(), t.Time.Day(), t.Time.Hour(), + t.Time.Minute(), t.Time.Second(), t.Time.Nanosecond(), dbLoc).In(uiLoc) + return &tm, nil + } + tm := t.Time.In(uiLoc) + return &tm, nil + case *time.Time: + z, _ := t.Zone() + if len(z) == 0 || t.Year() == 0 || t.Location().String() != dbLoc.String() { + tm := time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), + t.Minute(), t.Second(), t.Nanosecond(), dbLoc).In(uiLoc) + return &tm, nil + } + tm := t.In(uiLoc) + return &tm, nil + case time.Time: + z, _ := t.Zone() + if len(z) == 0 || t.Year() == 0 || t.Location().String() != dbLoc.String() { + tm := time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), + t.Minute(), t.Second(), t.Nanosecond(), dbLoc).In(uiLoc) + return &tm, nil + } + tm := t.In(uiLoc) + return &tm, nil + case int: + tm := time.Unix(int64(t), 0).In(uiLoc) + return &tm, nil + case int64: + tm := time.Unix(t, 0).In(uiLoc) + return &tm, nil + case *sql.NullInt64: + tm := time.Unix(t.Int64, 0).In(uiLoc) + return &tm, nil + } + return nil, fmt.Errorf("unsupported value %#v as time", src) +} diff --git a/convert/time_test.go b/convert/time_test.go new file mode 100644 index 00000000..5ddceb64 --- /dev/null +++ b/convert/time_test.go @@ -0,0 +1,31 @@ +// Copyright 2021 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package convert + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestString2Time(t *testing.T) { + expectedLoc, err := time.LoadLocation("Asia/Shanghai") + assert.NoError(t, err) + + var kases = map[string]time.Time{ + "2021-08-10": time.Date(2021, 8, 10, 8, 0, 0, 0, expectedLoc), + "2021-06-06T22:58:20+08:00": time.Date(2021, 6, 6, 22, 58, 20, 0, expectedLoc), + "2021-07-11 10:44:00": time.Date(2021, 7, 11, 18, 44, 0, 0, expectedLoc), + "2021-08-10T10:33:04Z": time.Date(2021, 8, 10, 18, 33, 04, 0, expectedLoc), + } + for layout, tm := range kases { + t.Run(layout, func(t *testing.T) { + target, err := String2Time(layout, time.UTC, expectedLoc) + assert.NoError(t, err) + assert.EqualValues(t, tm, *target) + }) + } +} diff --git a/core/db.go b/core/db.go index 50c64c6f..b476ef9a 100644 --- a/core/db.go +++ b/core/db.go @@ -23,6 +23,7 @@ var ( DefaultCacheSize = 200 ) +// MapToSlice map query and struct as sql and args func MapToSlice(query string, mp interface{}) (string, []interface{}, error) { vv := reflect.ValueOf(mp) if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map { @@ -44,6 +45,7 @@ func MapToSlice(query string, mp interface{}) (string, []interface{}, error) { return query, args, err } +// StructToSlice converts a query and struct as sql and args func StructToSlice(query string, st interface{}) (string, []interface{}, error) { vv := reflect.ValueOf(st) if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct { @@ -134,7 +136,7 @@ func (db *DB) reflectNew(typ reflect.Type) reflect.Value { cs = &cacheStruct{reflect.MakeSlice(reflect.SliceOf(typ), DefaultCacheSize, DefaultCacheSize), 0} db.reflectCache[typ] = cs } else { - cs.idx = cs.idx + 1 + cs.idx++ } return cs.value.Index(cs.idx).Addr() } @@ -176,6 +178,7 @@ func (db *DB) QueryMap(query string, mp interface{}) (*Rows, error) { return db.QueryMapContext(context.Background(), query, mp) } +// QueryStructContext query rows with struct func (db *DB) QueryStructContext(ctx context.Context, query string, st interface{}) (*Rows, error) { query, args, err := StructToSlice(query, st) if err != nil { @@ -184,10 +187,12 @@ func (db *DB) QueryStructContext(ctx context.Context, query string, st interface return db.QueryContext(ctx, query, args...) } +// QueryStruct query rows with struct func (db *DB) QueryStruct(query string, st interface{}) (*Rows, error) { return db.QueryStructContext(context.Background(), query, st) } +// QueryRowContext query row with args func (db *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *Row { rows, err := db.QueryContext(ctx, query, args...) if err != nil { @@ -196,10 +201,12 @@ func (db *DB) QueryRowContext(ctx context.Context, query string, args ...interfa return &Row{rows, nil} } +// QueryRow query row with args func (db *DB) QueryRow(query string, args ...interface{}) *Row { return db.QueryRowContext(context.Background(), query, args...) } +// QueryRowMapContext query row with map func (db *DB) QueryRowMapContext(ctx context.Context, query string, mp interface{}) *Row { query, args, err := MapToSlice(query, mp) if err != nil { @@ -208,10 +215,12 @@ func (db *DB) QueryRowMapContext(ctx context.Context, query string, mp interface return db.QueryRowContext(ctx, query, args...) } +// QueryRowMap query row with map func (db *DB) QueryRowMap(query string, mp interface{}) *Row { return db.QueryRowMapContext(context.Background(), query, mp) } +// QueryRowStructContext query row with struct func (db *DB) QueryRowStructContext(ctx context.Context, query string, st interface{}) *Row { query, args, err := StructToSlice(query, st) if err != nil { @@ -220,6 +229,7 @@ func (db *DB) QueryRowStructContext(ctx context.Context, query string, st interf return db.QueryRowContext(ctx, query, args...) } +// QueryRowStruct query row with struct func (db *DB) QueryRowStruct(query string, st interface{}) *Row { return db.QueryRowStructContext(context.Background(), query, st) } @@ -239,10 +249,12 @@ func (db *DB) ExecMapContext(ctx context.Context, query string, mp interface{}) return db.ExecContext(ctx, query, args...) } +// ExecMap exec query with map func (db *DB) ExecMap(query string, mp interface{}) (sql.Result, error) { return db.ExecMapContext(context.Background(), query, mp) } +// ExecStructContext exec query with map func (db *DB) ExecStructContext(ctx context.Context, query string, st interface{}) (sql.Result, error) { query, args, err := StructToSlice(query, st) if err != nil { @@ -251,6 +263,7 @@ func (db *DB) ExecStructContext(ctx context.Context, query string, st interface{ return db.ExecContext(ctx, query, args...) } +// ExecContext exec query with args func (db *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { hookCtx := contexts.NewContextHook(ctx, query, args) ctx, err := db.beforeProcess(hookCtx) @@ -265,6 +278,7 @@ func (db *DB) ExecContext(ctx context.Context, query string, args ...interface{} return res, nil } +// ExecStruct exec query with struct func (db *DB) ExecStruct(query string, st interface{}) (sql.Result, error) { return db.ExecStructContext(context.Background(), query, st) } @@ -288,6 +302,7 @@ func (db *DB) afterProcess(c *contexts.ContextHook) error { return err } +// AddHook adds hook func (db *DB) AddHook(h ...contexts.Hook) { db.hooks.AddHook(h...) } diff --git a/core/db_test.go b/core/db_test.go index 777ab0ad..a9c19392 100644 --- a/core/db_test.go +++ b/core/db_test.go @@ -15,25 +15,26 @@ import ( _ "github.com/go-sql-driver/mysql" _ "github.com/mattn/go-sqlite3" + _ "modernc.org/sqlite" ) var ( dbtype = flag.String("dbtype", "sqlite3", "database type") dbConn = flag.String("dbConn", "./db_test.db", "database connect string") - createTableSql string + createTableSQL string ) func TestMain(m *testing.M) { flag.Parse() switch *dbtype { - case "sqlite3": - createTableSql = "CREATE TABLE IF NOT EXISTS `user` (`id` INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, `name` TEXT NULL, " + + case "sqlite3", "sqlite": + createTableSQL = "CREATE TABLE IF NOT EXISTS `user` (`id` INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, `name` TEXT NULL, " + "`title` TEXT NULL, `age` FLOAT NULL, `alias` TEXT NULL, `nick_name` TEXT NULL, `created` datetime);" case "mysql": fallthrough default: - createTableSql = "CREATE TABLE IF NOT EXISTS `user` (`id` INTEGER PRIMARY KEY AUTO_INCREMENT NOT NULL, `name` TEXT NULL, " + + createTableSQL = "CREATE TABLE IF NOT EXISTS `user` (`id` INTEGER PRIMARY KEY AUTO_INCREMENT NOT NULL, `name` TEXT NULL, " + "`title` TEXT NULL, `age` FLOAT NULL, `alias` TEXT NULL, `nick_name` TEXT NULL, `created` datetime);" } @@ -45,8 +46,11 @@ func TestMain(m *testing.M) { func testOpen() (*DB, error) { switch *dbtype { case "sqlite3": - os.Remove("./test.db") + os.Remove("./test_sqlite3.db") return Open("sqlite3", "./test.db") + case "sqlite": + os.Remove("./test_sqlite.db") + return Open("sqlite", "./test.db") case "mysql": return Open("mysql", *dbConn) default: @@ -62,7 +66,7 @@ func BenchmarkOriQuery(b *testing.B) { } defer db.Close() - _, err = db.Exec(createTableSql) + _, err = db.Exec(createTableSQL) if err != nil { b.Error(err) } @@ -92,7 +96,7 @@ func BenchmarkOriQuery(b *testing.B) { if err != nil { b.Error(err) } - //fmt.Println(Id, Name, Title, Age, Alias, NickName) + // fmt.Println(Id, Name, Title, Age, Alias, NickName) } rows.Close() } @@ -117,7 +121,7 @@ func BenchmarkStructQuery(b *testing.B) { } defer db.Close() - _, err = db.Exec(createTableSql) + _, err = db.Exec(createTableSQL) if err != nil { b.Error(err) } @@ -162,7 +166,7 @@ func BenchmarkStruct2Query(b *testing.B) { } defer db.Close() - _, err = db.Exec(createTableSql) + _, err = db.Exec(createTableSQL) if err != nil { b.Error(err) } @@ -208,7 +212,7 @@ func BenchmarkSliceInterfaceQuery(b *testing.B) { } defer db.Close() - _, err = db.Exec(createTableSql) + _, err = db.Exec(createTableSQL) if err != nil { b.Error(err) } @@ -241,13 +245,13 @@ func BenchmarkSliceInterfaceQuery(b *testing.B) { b.Error(err) } b.Log(slice) - switch slice[1].(type) { + switch st := slice[1].(type) { case *string: - if *slice[1].(*string) != "xlw" { + if *st != "xlw" { b.Error(errors.New("name should be xlw")) } case []byte: - if string(slice[1].([]byte)) != "xlw" { + if string(st) != "xlw" { b.Error(errors.New("name should be xlw")) } } @@ -266,7 +270,7 @@ func BenchmarkSliceInterfaceQuery(b *testing.B) { } defer db.Close() - _, err = db.Exec(createTableSql) + _, err = db.Exec(createTableSQL) if err != nil { b.Error(err) } @@ -317,7 +321,7 @@ func BenchmarkSliceStringQuery(b *testing.B) { } defer db.Close() - _, err = db.Exec(createTableSql) + _, err = db.Exec(createTableSQL) if err != nil { b.Error(err) } @@ -368,7 +372,7 @@ func BenchmarkMapInterfaceQuery(b *testing.B) { } defer db.Close() - _, err = db.Exec(createTableSql) + _, err = db.Exec(createTableSQL) if err != nil { b.Error(err) } @@ -395,14 +399,14 @@ func BenchmarkMapInterfaceQuery(b *testing.B) { if err != nil { b.Error(err) } - switch m["name"].(type) { + switch t := m["name"].(type) { case string: - if m["name"].(string) != "xlw" { + if t != "xlw" { b.Log(m) b.Error(errors.New("name should be xlw")) } case []byte: - if string(m["name"].([]byte)) != "xlw" { + if string(t) != "xlw" { b.Log(m) b.Error(errors.New("name should be xlw")) } @@ -422,7 +426,7 @@ func BenchmarkMapInterfaceQuery(b *testing.B) { } defer db.Close() - _, err = db.Exec(createTableSql) + _, err = db.Exec(createTableSQL) if err != nil { b.Error(err) } @@ -469,7 +473,7 @@ func BenchmarkMapStringQuery(b *testing.B) { } defer db.Close() - _, err = db.Exec(createTableSql) + _, err = db.Exec(createTableSQL) if err != nil { b.Error(err) } @@ -515,7 +519,7 @@ func BenchmarkExec(b *testing.B) { } defer db.Close() - _, err = db.Exec(createTableSql) + _, err = db.Exec(createTableSQL) if err != nil { b.Error(err) } @@ -540,7 +544,7 @@ func BenchmarkExecMap(b *testing.B) { } defer db.Close() - _, err = db.Exec(createTableSql) + _, err = db.Exec(createTableSQL) if err != nil { b.Error(err) } @@ -573,7 +577,7 @@ func TestExecMap(t *testing.T) { } defer db.Close() - _, err = db.Exec(createTableSql) + _, err = db.Exec(createTableSQL) if err != nil { t.Error(err) } @@ -616,7 +620,7 @@ func TestExecStruct(t *testing.T) { } defer db.Close() - _, err = db.Exec(createTableSql) + _, err = db.Exec(createTableSQL) if err != nil { t.Error(err) } @@ -659,7 +663,7 @@ func BenchmarkExecStruct(b *testing.B) { } defer db.Close() - _, err = db.Exec(createTableSql) + _, err = db.Exec(createTableSQL) if err != nil { b.Error(err) } diff --git a/core/rows.go b/core/rows.go index a1e8bfbc..75d6ebf0 100644 --- a/core/rows.go +++ b/core/rows.go @@ -11,11 +11,13 @@ import ( "sync" ) +// Rows represents rows of table type Rows struct { *sql.Rows db *DB } +// ToMapString returns all records func (rs *Rows) ToMapString() ([]map[string]string, error) { cols, err := rs.Columns() if err != nil { @@ -34,7 +36,7 @@ func (rs *Rows) ToMapString() ([]map[string]string, error) { return results, nil } -// scan data to a struct's pointer according field index +// ScanStructByIndex scan data to a struct's pointer according field index func (rs *Rows) ScanStructByIndex(dest ...interface{}) error { if len(dest) == 0 { return errors.New("at least one struct") @@ -60,7 +62,7 @@ func (rs *Rows) ScanStructByIndex(dest ...interface{}) error { for _, vvv := range vvvs { for j := 0; j < vvv.NumField(); j++ { newDest[i] = vvv.Field(j).Addr().Interface() - i = i + 1 + i++ } } @@ -94,7 +96,7 @@ func fieldByName(v reflect.Value, name string) reflect.Value { return reflect.Zero(t) } -// scan data to a struct's pointer according field name +// ScanStructByName scan data to a struct's pointer according field name func (rs *Rows) ScanStructByName(dest interface{}) error { vv := reflect.ValueOf(dest) if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct { @@ -120,7 +122,7 @@ func (rs *Rows) ScanStructByName(dest interface{}) error { return rs.Rows.Scan(newDest...) } -// scan data to a slice's pointer, slice's length should equal to columns' number +// ScanSlice scan data to a slice's pointer, slice's length should equal to columns' number func (rs *Rows) ScanSlice(dest interface{}) error { vv := reflect.ValueOf(dest) if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Slice { @@ -155,7 +157,7 @@ func (rs *Rows) ScanSlice(dest interface{}) error { return nil } -// scan data to a map's pointer +// ScanMap scan data to a map's pointer func (rs *Rows) ScanMap(dest interface{}) error { vv := reflect.ValueOf(dest) if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map { @@ -187,6 +189,7 @@ func (rs *Rows) ScanMap(dest interface{}) error { return nil } +// Row reprents a row of a tab type Row struct { rows *Rows // One of these two will be non-nil: @@ -205,6 +208,7 @@ func NewRow(rows *Rows, err error) *Row { return &Row{rows, err} } +// Columns returns all columns of the row func (row *Row) Columns() ([]string, error) { if row.err != nil { return nil, row.err @@ -212,6 +216,7 @@ func (row *Row) Columns() ([]string, error) { return row.rows.Columns() } +// Scan retrieves all row column values func (row *Row) Scan(dest ...interface{}) error { if row.err != nil { return row.err @@ -238,6 +243,7 @@ func (row *Row) Scan(dest ...interface{}) error { return row.rows.Close() } +// ScanStructByName retrieves all row column values into a struct func (row *Row) ScanStructByName(dest interface{}) error { if row.err != nil { return row.err @@ -258,6 +264,7 @@ func (row *Row) ScanStructByName(dest interface{}) error { return row.rows.Close() } +// ScanStructByIndex retrieves all row column values into a struct func (row *Row) ScanStructByIndex(dest interface{}) error { if row.err != nil { return row.err @@ -278,7 +285,7 @@ func (row *Row) ScanStructByIndex(dest interface{}) error { return row.rows.Close() } -// scan data to a slice's pointer, slice's length should equal to columns' number +// ScanSlice scan data to a slice's pointer, slice's length should equal to columns' number func (row *Row) ScanSlice(dest interface{}) error { if row.err != nil { return row.err @@ -300,7 +307,7 @@ func (row *Row) ScanSlice(dest interface{}) error { return row.rows.Close() } -// scan data to a map's pointer +// ScanMap scan data to a map's pointer func (row *Row) ScanMap(dest interface{}) error { if row.err != nil { return row.err @@ -322,6 +329,7 @@ func (row *Row) ScanMap(dest interface{}) error { return row.rows.Close() } +// ToMapString returns all clumns of this record func (row *Row) ToMapString() (map[string]string, error) { cols, err := row.Columns() if err != nil { diff --git a/core/scan.go b/core/scan.go index 897b5341..1e7e4525 100644 --- a/core/scan.go +++ b/core/scan.go @@ -10,12 +10,14 @@ import ( "time" ) +// NullTime defines a customize type NullTime type NullTime time.Time var ( _ driver.Valuer = NullTime{} ) +// Scan implements driver.Valuer func (ns *NullTime) Scan(value interface{}) error { if value == nil { return nil @@ -58,9 +60,11 @@ func convertTime(dest *NullTime, src interface{}) error { return nil } +// EmptyScanner represents an empty scanner type EmptyScanner struct { } +// Scan implements func (EmptyScanner) Scan(src interface{}) error { return nil } diff --git a/core/stmt.go b/core/stmt.go index d46ac9c6..3247efed 100644 --- a/core/stmt.go +++ b/core/stmt.go @@ -21,6 +21,7 @@ type Stmt struct { query string } +// PrepareContext creates a prepare statement func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) { names := make(map[string]int) var i int @@ -42,10 +43,12 @@ func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) { return &Stmt{stmt, db, names, query}, nil } +// Prepare creates a prepare statement func (db *DB) Prepare(query string) (*Stmt, error) { return db.PrepareContext(context.Background(), query) } +// ExecMapContext execute with map func (s *Stmt) ExecMapContext(ctx context.Context, mp interface{}) (sql.Result, error) { vv := reflect.ValueOf(mp) if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map { @@ -59,10 +62,12 @@ func (s *Stmt) ExecMapContext(ctx context.Context, mp interface{}) (sql.Result, return s.ExecContext(ctx, args...) } +// ExecMap executes with map func (s *Stmt) ExecMap(mp interface{}) (sql.Result, error) { return s.ExecMapContext(context.Background(), mp) } +// ExecStructContext executes with struct func (s *Stmt) ExecStructContext(ctx context.Context, st interface{}) (sql.Result, error) { vv := reflect.ValueOf(st) if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct { @@ -76,17 +81,19 @@ func (s *Stmt) ExecStructContext(ctx context.Context, st interface{}) (sql.Resul return s.ExecContext(ctx, args...) } +// ExecStruct executes with struct func (s *Stmt) ExecStruct(st interface{}) (sql.Result, error) { return s.ExecStructContext(context.Background(), st) } +// ExecContext with args func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) { hookCtx := contexts.NewContextHook(ctx, s.query, args) ctx, err := s.db.beforeProcess(hookCtx) if err != nil { return nil, err } - res, err := s.Stmt.ExecContext(ctx, args) + res, err := s.Stmt.ExecContext(ctx, args...) hookCtx.End(ctx, res, err) if err := s.db.afterProcess(hookCtx); err != nil { return nil, err @@ -94,6 +101,7 @@ func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result return res, nil } +// QueryContext query with args func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, error) { hookCtx := contexts.NewContextHook(ctx, s.query, args) ctx, err := s.db.beforeProcess(hookCtx) @@ -108,10 +116,12 @@ func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, er return &Rows{rows, s.db}, nil } +// Query query with args func (s *Stmt) Query(args ...interface{}) (*Rows, error) { return s.QueryContext(context.Background(), args...) } +// QueryMapContext query with map func (s *Stmt) QueryMapContext(ctx context.Context, mp interface{}) (*Rows, error) { vv := reflect.ValueOf(mp) if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map { @@ -126,10 +136,12 @@ func (s *Stmt) QueryMapContext(ctx context.Context, mp interface{}) (*Rows, erro return s.QueryContext(ctx, args...) } +// QueryMap query with map func (s *Stmt) QueryMap(mp interface{}) (*Rows, error) { return s.QueryMapContext(context.Background(), mp) } +// QueryStructContext query with struct func (s *Stmt) QueryStructContext(ctx context.Context, st interface{}) (*Rows, error) { vv := reflect.ValueOf(st) if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct { @@ -144,19 +156,23 @@ func (s *Stmt) QueryStructContext(ctx context.Context, st interface{}) (*Rows, e return s.QueryContext(ctx, args...) } +// QueryStruct query with struct func (s *Stmt) QueryStruct(st interface{}) (*Rows, error) { return s.QueryStructContext(context.Background(), st) } +// QueryRowContext query row with args func (s *Stmt) QueryRowContext(ctx context.Context, args ...interface{}) *Row { rows, err := s.QueryContext(ctx, args...) return &Row{rows, err} } +// QueryRow query row with args func (s *Stmt) QueryRow(args ...interface{}) *Row { return s.QueryRowContext(context.Background(), args...) } +// QueryRowMapContext query row with map func (s *Stmt) QueryRowMapContext(ctx context.Context, mp interface{}) *Row { vv := reflect.ValueOf(mp) if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map { @@ -171,10 +187,12 @@ func (s *Stmt) QueryRowMapContext(ctx context.Context, mp interface{}) *Row { return s.QueryRowContext(ctx, args...) } +// QueryRowMap query row with map func (s *Stmt) QueryRowMap(mp interface{}) *Row { return s.QueryRowMapContext(context.Background(), mp) } +// QueryRowStructContext query row with struct func (s *Stmt) QueryRowStructContext(ctx context.Context, st interface{}) *Row { vv := reflect.ValueOf(st) if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct { @@ -189,6 +207,7 @@ func (s *Stmt) QueryRowStructContext(ctx context.Context, st interface{}) *Row { return s.QueryRowContext(ctx, args...) } +// QueryRowStruct query row with struct func (s *Stmt) QueryRowStruct(st interface{}) *Row { return s.QueryRowStructContext(context.Background(), st) } diff --git a/core/tx.go b/core/tx.go index a85a6874..a2f745f8 100644 --- a/core/tx.go +++ b/core/tx.go @@ -22,6 +22,7 @@ type Tx struct { ctx context.Context } +// BeginTx begin a transaction with option func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) { hookCtx := contexts.NewContextHook(ctx, "BEGIN TRANSACTION", nil) ctx, err := db.beforeProcess(hookCtx) @@ -36,10 +37,12 @@ func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) { return &Tx{tx, db, ctx}, nil } +// Begin begins a transaction func (db *DB) Begin() (*Tx, error) { return db.BeginTx(context.Background(), nil) } +// Commit submit the transaction func (tx *Tx) Commit() error { hookCtx := contexts.NewContextHook(tx.ctx, "COMMIT", nil) ctx, err := tx.db.beforeProcess(hookCtx) @@ -48,12 +51,10 @@ func (tx *Tx) Commit() error { } err = tx.Tx.Commit() hookCtx.End(ctx, nil, err) - if err := tx.db.afterProcess(hookCtx); err != nil { - return err - } - return nil + return tx.db.afterProcess(hookCtx) } +// Rollback rollback the transaction func (tx *Tx) Rollback() error { hookCtx := contexts.NewContextHook(tx.ctx, "ROLLBACK", nil) ctx, err := tx.db.beforeProcess(hookCtx) @@ -62,12 +63,10 @@ func (tx *Tx) Rollback() error { } err = tx.Tx.Rollback() hookCtx.End(ctx, nil, err) - if err := tx.db.afterProcess(hookCtx); err != nil { - return err - } - return nil + return tx.db.afterProcess(hookCtx) } +// PrepareContext prepare the query func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) { names := make(map[string]int) var i int @@ -89,19 +88,23 @@ func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) { return &Stmt{stmt, tx.db, names, query}, nil } +// Prepare prepare the query func (tx *Tx) Prepare(query string) (*Stmt, error) { return tx.PrepareContext(context.Background(), query) } +// StmtContext creates Stmt with context func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt { stmt.Stmt = tx.Tx.StmtContext(ctx, stmt.Stmt) return stmt } +// Stmt creates Stmt func (tx *Tx) Stmt(stmt *Stmt) *Stmt { return tx.StmtContext(context.Background(), stmt) } +// ExecMapContext executes query with args in a map func (tx *Tx) ExecMapContext(ctx context.Context, query string, mp interface{}) (sql.Result, error) { query, args, err := MapToSlice(query, mp) if err != nil { @@ -110,10 +113,12 @@ func (tx *Tx) ExecMapContext(ctx context.Context, query string, mp interface{}) return tx.ExecContext(ctx, query, args...) } +// ExecMap executes query with args in a map func (tx *Tx) ExecMap(query string, mp interface{}) (sql.Result, error) { return tx.ExecMapContext(context.Background(), query, mp) } +// ExecStructContext executes query with args in a struct func (tx *Tx) ExecStructContext(ctx context.Context, query string, st interface{}) (sql.Result, error) { query, args, err := StructToSlice(query, st) if err != nil { @@ -122,6 +127,7 @@ func (tx *Tx) ExecStructContext(ctx context.Context, query string, st interface{ return tx.ExecContext(ctx, query, args...) } +// ExecContext executes a query with args func (tx *Tx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { hookCtx := contexts.NewContextHook(ctx, query, args) ctx, err := tx.db.beforeProcess(hookCtx) @@ -136,10 +142,12 @@ func (tx *Tx) ExecContext(ctx context.Context, query string, args ...interface{} return res, err } +// ExecStruct executes query with args in a struct func (tx *Tx) ExecStruct(query string, st interface{}) (sql.Result, error) { return tx.ExecStructContext(context.Background(), query, st) } +// QueryContext query with args func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { hookCtx := contexts.NewContextHook(ctx, query, args) ctx, err := tx.db.beforeProcess(hookCtx) @@ -157,10 +165,12 @@ func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{ return &Rows{rows, tx.db}, nil } +// Query query with args func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) { return tx.QueryContext(context.Background(), query, args...) } +// QueryMapContext query with args in a map func (tx *Tx) QueryMapContext(ctx context.Context, query string, mp interface{}) (*Rows, error) { query, args, err := MapToSlice(query, mp) if err != nil { @@ -169,10 +179,12 @@ func (tx *Tx) QueryMapContext(ctx context.Context, query string, mp interface{}) return tx.QueryContext(ctx, query, args...) } +// QueryMap query with args in a map func (tx *Tx) QueryMap(query string, mp interface{}) (*Rows, error) { return tx.QueryMapContext(context.Background(), query, mp) } +// QueryStructContext query with args in struct func (tx *Tx) QueryStructContext(ctx context.Context, query string, st interface{}) (*Rows, error) { query, args, err := StructToSlice(query, st) if err != nil { @@ -181,19 +193,23 @@ func (tx *Tx) QueryStructContext(ctx context.Context, query string, st interface return tx.QueryContext(ctx, query, args...) } +// QueryStruct query with args in struct func (tx *Tx) QueryStruct(query string, st interface{}) (*Rows, error) { return tx.QueryStructContext(context.Background(), query, st) } +// QueryRowContext query one row with args func (tx *Tx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *Row { rows, err := tx.QueryContext(ctx, query, args...) return &Row{rows, err} } +// QueryRow query one row with args func (tx *Tx) QueryRow(query string, args ...interface{}) *Row { return tx.QueryRowContext(context.Background(), query, args...) } +// QueryRowMapContext query one row with args in a map func (tx *Tx) QueryRowMapContext(ctx context.Context, query string, mp interface{}) *Row { query, args, err := MapToSlice(query, mp) if err != nil { @@ -202,10 +218,12 @@ func (tx *Tx) QueryRowMapContext(ctx context.Context, query string, mp interface return tx.QueryRowContext(ctx, query, args...) } +// QueryRowMap query one row with args in a map func (tx *Tx) QueryRowMap(query string, mp interface{}) *Row { return tx.QueryRowMapContext(context.Background(), query, mp) } +// QueryRowStructContext query one row with args in struct func (tx *Tx) QueryRowStructContext(ctx context.Context, query string, st interface{}) *Row { query, args, err := StructToSlice(query, st) if err != nil { @@ -214,6 +232,7 @@ func (tx *Tx) QueryRowStructContext(ctx context.Context, query string, st interf return tx.QueryRowContext(ctx, query, args...) } +// QueryRowStruct query one row with args in struct func (tx *Tx) QueryRowStruct(query string, st interface{}) *Row { return tx.QueryRowStructContext(context.Background(), query, st) } diff --git a/dialects/dameng.go b/dialects/dameng.go new file mode 100644 index 00000000..5e92ec2f --- /dev/null +++ b/dialects/dameng.go @@ -0,0 +1,1201 @@ +// Copyright 2021 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package dialects + +import ( + "context" + "database/sql" + "errors" + "fmt" + "net/url" + "strconv" + "strings" + + "xorm.io/xorm/convert" + "xorm.io/xorm/core" + "xorm.io/xorm/internal/utils" + "xorm.io/xorm/schemas" +) + +func init() { + RegisterDriver("dm", &damengDriver{}) + RegisterDialect(schemas.DAMENG, func() Dialect { + return &dameng{} + }) +} + +var ( + damengReservedWords = map[string]bool{ + "ACCESS": true, + "ACCOUNT": true, + "ACTIVATE": true, + "ADD": true, + "ADMIN": true, + "ADVISE": true, + "AFTER": true, + "ALL": true, + "ALL_ROWS": true, + "ALLOCATE": true, + "ALTER": true, + "ANALYZE": true, + "AND": true, + "ANY": true, + "ARCHIVE": true, + "ARCHIVELOG": true, + "ARRAY": true, + "AS": true, + "ASC": true, + "AT": true, + "AUDIT": true, + "AUTHENTICATED": true, + "AUTHORIZATION": true, + "AUTOEXTEND": true, + "AUTOMATIC": true, + "BACKUP": true, + "BECOME": true, + "BEFORE": true, + "BEGIN": true, + "BETWEEN": true, + "BFILE": true, + "BITMAP": true, + "BLOB": true, + "BLOCK": true, + "BODY": true, + "BY": true, + "CACHE": true, + "CACHE_INSTANCES": true, + "CANCEL": true, + "CASCADE": true, + "CAST": true, + "CFILE": true, + "CHAINED": true, + "CHANGE": true, + "CHAR": true, + "CHAR_CS": true, + "CHARACTER": true, + "CHECK": true, + "CHECKPOINT": true, + "CHOOSE": true, + "CHUNK": true, + "CLEAR": true, + "CLOB": true, + "CLONE": true, + "CLOSE": true, + "CLOSE_CACHED_OPEN_CURSORS": true, + "CLUSTER": true, + "COALESCE": true, + "COLUMN": true, + "COLUMNS": true, + "COMMENT": true, + "COMMIT": true, + "COMMITTED": true, + "COMPATIBILITY": true, + "COMPILE": true, + "COMPLETE": true, + "COMPOSITE_LIMIT": true, + "COMPRESS": true, + "COMPUTE": true, + "CONNECT": true, + "CONNECT_TIME": true, + "CONSTRAINT": true, + "CONSTRAINTS": true, + "CONTENTS": true, + "CONTINUE": true, + "CONTROLFILE": true, + "CONVERT": true, + "COST": true, + "CPU_PER_CALL": true, + "CPU_PER_SESSION": true, + "CREATE": true, + "CURRENT": true, + "CURRENT_SCHEMA": true, + "CURREN_USER": true, + "CURSOR": true, + "CYCLE": true, + "DANGLING": true, + "DATABASE": true, + "DATAFILE": true, + "DATAFILES": true, + "DATAOBJNO": true, + "DATE": true, + "DBA": true, + "DBHIGH": true, + "DBLOW": true, + "DBMAC": true, + "DEALLOCATE": true, + "DEBUG": true, + "DEC": true, + "DECIMAL": true, + "DECLARE": true, + "DEFAULT": true, + "DEFERRABLE": true, + "DEFERRED": true, + "DEGREE": true, + "DELETE": true, + "DEREF": true, + "DESC": true, + "DIRECTORY": true, + "DISABLE": true, + "DISCONNECT": true, + "DISMOUNT": true, + "DISTINCT": true, + "DISTRIBUTED": true, + "DML": true, + "DOUBLE": true, + "DROP": true, + "DUMP": true, + "EACH": true, + "ELSE": true, + "ENABLE": true, + "END": true, + "ENFORCE": true, + "ENTRY": true, + "ESCAPE": true, + "EXCEPT": true, + "EXCEPTIONS": true, + "EXCHANGE": true, + "EXCLUDING": true, + "EXCLUSIVE": true, + "EXECUTE": true, + "EXISTS": true, + "EXPIRE": true, + "EXPLAIN": true, + "EXTENT": true, + "EXTENTS": true, + "EXTERNALLY": true, + "FAILED_LOGIN_ATTEMPTS": true, + "FALSE": true, + "FAST": true, + "FILE": true, + "FIRST_ROWS": true, + "FLAGGER": true, + "FLOAT": true, + "FLOB": true, + "FLUSH": true, + "FOR": true, + "FORCE": true, + "FOREIGN": true, + "FREELIST": true, + "FREELISTS": true, + "FROM": true, + "FULL": true, + "FUNCTION": true, + "GLOBAL": true, + "GLOBALLY": true, + "GLOBAL_NAME": true, + "GRANT": true, + "GROUP": true, + "GROUPS": true, + "HASH": true, + "HASHKEYS": true, + "HAVING": true, + "HEADER": true, + "HEAP": true, + "IDENTIFIED": true, + "IDGENERATORS": true, + "IDLE_TIME": true, + "IF": true, + "IMMEDIATE": true, + "IN": true, + "INCLUDING": true, + "INCREMENT": true, + "INDEX": true, + "INDEXED": true, + "INDEXES": true, + "INDICATOR": true, + "IND_PARTITION": true, + "INITIAL": true, + "INITIALLY": true, + "INITRANS": true, + "INSERT": true, + "INSTANCE": true, + "INSTANCES": true, + "INSTEAD": true, + "INT": true, + "INTEGER": true, + "INTERMEDIATE": true, + "INTERSECT": true, + "INTO": true, + "IS": true, + "ISOLATION": true, + "ISOLATION_LEVEL": true, + "KEEP": true, + "KEY": true, + "KILL": true, + "LABEL": true, + "LAYER": true, + "LESS": true, + "LEVEL": true, + "LIBRARY": true, + "LIKE": true, + "LIMIT": true, + "LINK": true, + "LIST": true, + "LOB": true, + "LOCAL": true, + "LOCK": true, + "LOCKED": true, + "LOG": true, + "LOGFILE": true, + "LOGGING": true, + "LOGICAL_READS_PER_CALL": true, + "LOGICAL_READS_PER_SESSION": true, + "LONG": true, + "MANAGE": true, + "MASTER": true, + "MAX": true, + "MAXARCHLOGS": true, + "MAXDATAFILES": true, + "MAXEXTENTS": true, + "MAXINSTANCES": true, + "MAXLOGFILES": true, + "MAXLOGHISTORY": true, + "MAXLOGMEMBERS": true, + "MAXSIZE": true, + "MAXTRANS": true, + "MAXVALUE": true, + "MIN": true, + "MEMBER": true, + "MINIMUM": true, + "MINEXTENTS": true, + "MINUS": true, + "MINVALUE": true, + "MLSLABEL": true, + "MLS_LABEL_FORMAT": true, + "MODE": true, + "MODIFY": true, + "MOUNT": true, + "MOVE": true, + "MTS_DISPATCHERS": true, + "MULTISET": true, + "NATIONAL": true, + "NCHAR": true, + "NCHAR_CS": true, + "NCLOB": true, + "NEEDED": true, + "NESTED": true, + "NETWORK": true, + "NEW": true, + "NEXT": true, + "NOARCHIVELOG": true, + "NOAUDIT": true, + "NOCACHE": true, + "NOCOMPRESS": true, + "NOCYCLE": true, + "NOFORCE": true, + "NOLOGGING": true, + "NOMAXVALUE": true, + "NOMINVALUE": true, + "NONE": true, + "NOORDER": true, + "NOOVERRIDE": true, + "NOPARALLEL": true, + "NOREVERSE": true, + "NORMAL": true, + "NOSORT": true, + "NOT": true, + "NOTHING": true, + "NOWAIT": true, + "NULL": true, + "NUMBER": true, + "NUMERIC": true, + "NVARCHAR2": true, + "OBJECT": true, + "OBJNO": true, + "OBJNO_REUSE": true, + "OF": true, + "OFF": true, + "OFFLINE": true, + "OID": true, + "OIDINDEX": true, + "OLD": true, + "ON": true, + "ONLINE": true, + "ONLY": true, + "OPCODE": true, + "OPEN": true, + "OPTIMAL": true, + "OPTIMIZER_GOAL": true, + "OPTION": true, + "OR": true, + "ORDER": true, + "ORGANIZATION": true, + "OSLABEL": true, + "OVERFLOW": true, + "OWN": true, + "PACKAGE": true, + "PARALLEL": true, + "PARTITION": true, + "PASSWORD": true, + "PASSWORD_GRACE_TIME": true, + "PASSWORD_LIFE_TIME": true, + "PASSWORD_LOCK_TIME": true, + "PASSWORD_REUSE_MAX": true, + "PASSWORD_REUSE_TIME": true, + "PASSWORD_VERIFY_FUNCTION": true, + "PCTFREE": true, + "PCTINCREASE": true, + "PCTTHRESHOLD": true, + "PCTUSED": true, + "PCTVERSION": true, + "PERCENT": true, + "PERMANENT": true, + "PLAN": true, + "PLSQL_DEBUG": true, + "POST_TRANSACTION": true, + "PRECISION": true, + "PRESERVE": true, + "PRIMARY": true, + "PRIOR": true, + "PRIVATE": true, + "PRIVATE_SGA": true, + "PRIVILEGE": true, + "PRIVILEGES": true, + "PROCEDURE": true, + "PROFILE": true, + "PUBLIC": true, + "PURGE": true, + "QUEUE": true, + "QUOTA": true, + "RANGE": true, + "RAW": true, + "RBA": true, + "READ": true, + "READUP": true, + "REAL": true, + "REBUILD": true, + "RECOVER": true, + "RECOVERABLE": true, + "RECOVERY": true, + "REF": true, + "REFERENCES": true, + "REFERENCING": true, + "REFRESH": true, + "RENAME": true, + "REPLACE": true, + "RESET": true, + "RESETLOGS": true, + "RESIZE": true, + "RESOURCE": true, + "RESTRICTED": true, + "RETURN": true, + "RETURNING": true, + "REUSE": true, + "REVERSE": true, + "REVOKE": true, + "ROLE": true, + "ROLES": true, + "ROLLBACK": true, + "ROW": true, + "ROWID": true, + "ROWNUM": true, + "ROWS": true, + "RULE": true, + "SAMPLE": true, + "SAVEPOINT": true, + "SB4": true, + "SCAN_INSTANCES": true, + "SCHEMA": true, + "SCN": true, + "SCOPE": true, + "SD_ALL": true, + "SD_INHIBIT": true, + "SD_SHOW": true, + "SEGMENT": true, + "SEG_BLOCK": true, + "SEG_FILE": true, + "SELECT": true, + "SEQUENCE": true, + "SERIALIZABLE": true, + "SESSION": true, + "SESSION_CACHED_CURSORS": true, + "SESSIONS_PER_USER": true, + "SET": true, + "SHARE": true, + "SHARED": true, + "SHARED_POOL": true, + "SHRINK": true, + "SIZE": true, + "SKIP": true, + "SKIP_UNUSABLE_INDEXES": true, + "SMALLINT": true, + "SNAPSHOT": true, + "SOME": true, + "SORT": true, + "SPECIFICATION": true, + "SPLIT": true, + "SQL_TRACE": true, + "STANDBY": true, + "START": true, + "STATEMENT_ID": true, + "STATISTICS": true, + "STOP": true, + "STORAGE": true, + "STORE": true, + "STRUCTURE": true, + "SUCCESSFUL": true, + "SWITCH": true, + "SYS_OP_ENFORCE_NOT_NULL$": true, + "SYS_OP_NTCIMG$": true, + "SYNONYM": true, + "SYSDATE": true, + "SYSDBA": true, + "SYSOPER": true, + "SYSTEM": true, + "TABLE": true, + "TABLES": true, + "TABLESPACE": true, + "TABLESPACE_NO": true, + "TABNO": true, + "TEMPORARY": true, + "THAN": true, + "THE": true, + "THEN": true, + "THREAD": true, + "TIMESTAMP": true, + "TIME": true, + "TO": true, + "TOPLEVEL": true, + "TRACE": true, + "TRACING": true, + "TRANSACTION": true, + "TRANSITIONAL": true, + "TRIGGER": true, + "TRIGGERS": true, + "TRUE": true, + "TRUNCATE": true, + "TX": true, + "TYPE": true, + "UB2": true, + "UBA": true, + "UID": true, + "UNARCHIVED": true, + "UNDO": true, + "UNION": true, + "UNIQUE": true, + "UNLIMITED": true, + "UNLOCK": true, + "UNRECOVERABLE": true, + "UNTIL": true, + "UNUSABLE": true, + "UNUSED": true, + "UPDATABLE": true, + "UPDATE": true, + "USAGE": true, + "USE": true, + "USER": true, + "USING": true, + "VALIDATE": true, + "VALIDATION": true, + "VALUE": true, + "VALUES": true, + "VARCHAR": true, + "VARCHAR2": true, + "VARYING": true, + "VIEW": true, + "WHEN": true, + "WHENEVER": true, + "WHERE": true, + "WITH": true, + "WITHOUT": true, + "WORK": true, + "WRITE": true, + "WRITEDOWN": true, + "WRITEUP": true, + "XID": true, + "YEAR": true, + "ZONE": true, + } + + damengQuoter = schemas.Quoter{ + Prefix: '"', + Suffix: '"', + IsReserved: schemas.AlwaysReserve, + } +) + +type dameng struct { + Base +} + +func (db *dameng) Init(uri *URI) error { + db.quoter = damengQuoter + return db.Base.Init(db, uri) +} + +func (db *dameng) Version(ctx context.Context, queryer core.Queryer) (*schemas.Version, error) { + rows, err := queryer.QueryContext(ctx, "SELECT * FROM V$VERSION") // select id_code + if err != nil { + return nil, err + } + defer rows.Close() + + var version string + if !rows.Next() { + if rows.Err() != nil { + return nil, rows.Err() + } + return nil, errors.New("unknow version") + } + + if err := rows.Scan(&version); err != nil { + return nil, err + } + return &schemas.Version{ + Number: version, + }, nil +} + +func (db *dameng) Features() *DialectFeatures { + return &DialectFeatures{ + AutoincrMode: SequenceAutoincrMode, + } +} + +// DropIndexSQL returns a SQL to drop index +func (db *dameng) DropIndexSQL(tableName string, index *schemas.Index) string { + quote := db.dialect.Quoter().Quote + var name string + if index.IsRegular { + name = index.XName(tableName) + } else { + name = index.Name + } + return fmt.Sprintf("DROP INDEX %v", quote(name)) +} + +func (db *dameng) SQLType(c *schemas.Column) string { + var res string + switch t := c.SQLType.Name; t { + case schemas.TinyInt, "BYTE": + return "TINYINT" + case schemas.SmallInt, schemas.MediumInt, schemas.Int, schemas.Integer, schemas.UnsignedTinyInt: + return "INTEGER" + case schemas.BigInt, + schemas.UnsignedBigInt, schemas.UnsignedBit, schemas.UnsignedInt, + schemas.Serial, schemas.BigSerial: + return "BIGINT" + case schemas.Bit, schemas.Bool, schemas.Boolean: + return schemas.Bit + case schemas.Uuid: + res = schemas.Varchar + c.Length = 40 + case schemas.Binary: + if c.Length == 0 { + return schemas.Binary + "(MAX)" + } + case schemas.VarBinary, schemas.Blob, schemas.TinyBlob, schemas.MediumBlob, schemas.LongBlob, schemas.Bytea: + return schemas.VarBinary + case schemas.Date: + return schemas.Date + case schemas.Time: + if c.Length > 0 { + return fmt.Sprintf("%s(%d)", schemas.Time, c.Length) + } + return schemas.Time + case schemas.DateTime, schemas.TimeStamp: + res = schemas.TimeStamp + case schemas.TimeStampz: + if c.Length > 0 { + return fmt.Sprintf("TIMESTAMP(%d) WITH TIME ZONE", c.Length) + } + return "TIMESTAMP WITH TIME ZONE" + case schemas.Float: + res = "FLOAT" + case schemas.Real, schemas.Double: + res = "REAL" + case schemas.Numeric, schemas.Decimal, "NUMBER": + res = "NUMERIC" + case schemas.Text, schemas.Json: + return "TEXT" + case schemas.MediumText, schemas.LongText: + res = "CLOB" + case schemas.Char, schemas.Varchar, schemas.TinyText: + res = "VARCHAR2" + default: + res = t + } + + hasLen1 := (c.Length > 0) + hasLen2 := (c.Length2 > 0) + + if hasLen2 { + res += "(" + strconv.FormatInt(c.Length, 10) + "," + strconv.FormatInt(c.Length2, 10) + ")" + } else if hasLen1 { + res += "(" + strconv.FormatInt(c.Length, 10) + ")" + } + return res +} + +func (db *dameng) ColumnTypeKind(t string) int { + switch strings.ToUpper(t) { + case "DATE": + return schemas.TIME_TYPE + case "CHAR", "NCHAR", "VARCHAR", "VARCHAR2", "NVARCHAR2", "LONG", "CLOB", "NCLOB": + return schemas.TEXT_TYPE + case "NUMBER": + return schemas.NUMERIC_TYPE + case "BLOB": + return schemas.BLOB_TYPE + default: + return schemas.UNKNOW_TYPE + } +} + +func (db *dameng) AutoIncrStr() string { + return "IDENTITY" +} + +func (db *dameng) IsReserved(name string) bool { + _, ok := damengReservedWords[strings.ToUpper(name)] + return ok +} + +func (db *dameng) DropTableSQL(tableName string) (string, bool) { + return fmt.Sprintf("DROP TABLE %s", db.quoter.Quote(tableName)), false +} + +// ModifyColumnSQL returns a SQL to modify SQL +func (db *dameng) ModifyColumnSQL(tableName string, col *schemas.Column) string { + s, _ := ColumnString(db.dialect, col, false) + return fmt.Sprintf("ALTER TABLE %s MODIFY %s", db.quoter.Quote(tableName), s) +} + +func (db *dameng) CreateTableSQL(ctx context.Context, queryer core.Queryer, table *schemas.Table, tableName string) (string, bool, error) { + if tableName == "" { + tableName = table.Name + } + + quoter := db.Quoter() + var b strings.Builder + if _, err := b.WriteString("CREATE TABLE "); err != nil { + return "", false, err + } + if err := quoter.QuoteTo(&b, tableName); err != nil { + return "", false, err + } + if _, err := b.WriteString(" ("); err != nil { + return "", false, err + } + + pkList := table.PrimaryKeys + + for i, colName := range table.ColumnsSeq() { + col := table.GetColumn(colName) + if col.SQLType.IsBool() && !col.DefaultIsEmpty { + if col.Default == "true" { + col.Default = "1" + } else if col.Default == "false" { + col.Default = "0" + } + } + + s, _ := ColumnString(db, col, false) + if _, err := b.WriteString(s); err != nil { + return "", false, err + } + if i != len(table.ColumnsSeq())-1 { + if _, err := b.WriteString(", "); err != nil { + return "", false, err + } + } + } + + if len(pkList) > 0 { + if len(table.ColumnsSeq()) > 0 { + if _, err := b.WriteString(", "); err != nil { + return "", false, err + } + } + if _, err := b.WriteString(fmt.Sprintf("CONSTRAINT PK_%s PRIMARY KEY (", tableName)); err != nil { + return "", false, err + } + if err := quoter.JoinWrite(&b, pkList, ","); err != nil { + return "", false, err + } + if _, err := b.WriteString(")"); err != nil { + return "", false, err + } + } + if _, err := b.WriteString(")"); err != nil { + return "", false, err + } + + return b.String(), false, nil +} + +func (db *dameng) SetQuotePolicy(quotePolicy QuotePolicy) { + switch quotePolicy { + case QuotePolicyNone: + q := damengQuoter + q.IsReserved = schemas.AlwaysNoReserve + db.quoter = q + case QuotePolicyReserved: + q := damengQuoter + q.IsReserved = db.IsReserved + db.quoter = q + case QuotePolicyAlways: + fallthrough + default: + db.quoter = damengQuoter + } +} + +func (db *dameng) IndexCheckSQL(tableName, idxName string) (string, []interface{}) { + args := []interface{}{tableName, idxName} + return `SELECT INDEX_NAME FROM USER_INDEXES ` + + `WHERE TABLE_NAME = ? AND INDEX_NAME = ?`, args +} + +func (db *dameng) IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error) { + return db.HasRecords(queryer, ctx, `SELECT table_name FROM user_tables WHERE table_name = ?`, tableName) +} + +func (db *dameng) IsSequenceExist(ctx context.Context, queryer core.Queryer, seqName string) (bool, error) { + var cnt int + rows, err := queryer.QueryContext(ctx, "SELECT COUNT(*) FROM user_sequences WHERE sequence_name = ?", seqName) + if err != nil { + return false, err + } + defer rows.Close() + if !rows.Next() { + if rows.Err() != nil { + return false, rows.Err() + } + return false, errors.New("query sequence failed") + } + + if err := rows.Scan(&cnt); err != nil { + return false, err + } + return cnt > 0, nil +} + +func (db *dameng) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName, colName string) (bool, error) { + args := []interface{}{tableName, colName} + query := "SELECT column_name FROM USER_TAB_COLUMNS WHERE table_name = ?" + + " AND column_name = ?" + return db.HasRecords(queryer, ctx, query, args...) +} + +var _ sql.Scanner = &dmClobScanner{} + +type dmClobScanner struct { + valid bool + data string +} + +type dmClobObject interface { + GetLength() (int64, error) + ReadString(int, int) (string, error) +} + +// var _ dmClobObject = &dm.DmClob{} + +func (d *dmClobScanner) Scan(data interface{}) error { + if data == nil { + return nil + } + + switch t := data.(type) { + case dmClobObject: // *dm.DmClob + if t == nil { + return nil + } + l, err := t.GetLength() + if err != nil { + return err + } + if l == 0 { + d.valid = true + return nil + } + d.data, err = t.ReadString(1, int(l)) + if err != nil { + return err + } + d.valid = true + return nil + case []byte: + if t == nil { + return nil + } + d.data = string(t) + d.valid = true + return nil + default: + return fmt.Errorf("cannot convert %T as dmClobScanner", data) + } +} + +func addSingleQuote(name string) string { + if len(name) < 2 { + return name + } + if name[0] == '\'' && name[len(name)-1] == '\'' { + return name + } + return fmt.Sprintf("'%s'", name) +} + +func (db *dameng) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { + s := `select column_name from user_cons_columns + where constraint_name = (select constraint_name from user_constraints + where table_name = ? and constraint_type ='P')` + rows, err := queryer.QueryContext(ctx, s, tableName) + if err != nil { + return nil, nil, err + } + defer rows.Close() + + var pkNames []string + for rows.Next() { + var pkName string + err = rows.Scan(&pkName) + if err != nil { + return nil, nil, err + } + pkNames = append(pkNames, pkName) + } + if rows.Err() != nil { + return nil, nil, rows.Err() + } + rows.Close() + + s = `SELECT USER_TAB_COLS.COLUMN_NAME, USER_TAB_COLS.DATA_DEFAULT, USER_TAB_COLS.DATA_TYPE, USER_TAB_COLS.DATA_LENGTH, + USER_TAB_COLS.data_precision, USER_TAB_COLS.data_scale, USER_TAB_COLS.NULLABLE, + user_col_comments.comments + FROM USER_TAB_COLS + LEFT JOIN user_col_comments on user_col_comments.TABLE_NAME=USER_TAB_COLS.TABLE_NAME + AND user_col_comments.COLUMN_NAME=USER_TAB_COLS.COLUMN_NAME + WHERE USER_TAB_COLS.table_name = ?` + rows, err = queryer.QueryContext(ctx, s, tableName) + if err != nil { + return nil, nil, err + } + defer rows.Close() + + cols := make(map[string]*schemas.Column) + colSeq := make([]string, 0) + for rows.Next() { + col := new(schemas.Column) + col.Indexes = make(map[string]int) + + var colDefault dmClobScanner + var colName, nullable, dataType, dataPrecision, comment sql.NullString + var dataScale, dataLen sql.NullInt64 + + err = rows.Scan(&colName, &colDefault, &dataType, &dataLen, &dataPrecision, + &dataScale, &nullable, &comment) + if err != nil { + return nil, nil, err + } + + if !colName.Valid { + return nil, nil, errors.New("column name is nil") + } + + col.Name = strings.Trim(colName.String, `" `) + if colDefault.valid { + col.Default = colDefault.data + } else { + col.DefaultIsEmpty = true + } + + if nullable.String == "Y" { + col.Nullable = true + } else { + col.Nullable = false + } + + if !comment.Valid { + col.Comment = comment.String + } + if utils.IndexSlice(pkNames, col.Name) > -1 { + col.IsPrimaryKey = true + has, err := db.HasRecords(queryer, ctx, "SELECT * FROM USER_SEQUENCES WHERE SEQUENCE_NAME = ?", utils.SeqName(tableName)) + if err != nil { + return nil, nil, err + } + if has { + col.IsAutoIncrement = true + } + } + + var ( + ignore bool + dt string + len1, len2 int64 + ) + + dts := strings.Split(dataType.String, "(") + dt = dts[0] + if len(dts) > 1 { + lens := strings.Split(dts[1][:len(dts[1])-1], ",") + if len(lens) > 1 { + len1, _ = strconv.ParseInt(lens[0], 10, 64) + len2, _ = strconv.ParseInt(lens[1], 10, 64) + } else { + len1, _ = strconv.ParseInt(lens[0], 10, 64) + } + } + + switch dt { + case "VARCHAR2": + col.SQLType = schemas.SQLType{Name: "VARCHAR2", DefaultLength: len1, DefaultLength2: len2} + case "VARCHAR": + col.SQLType = schemas.SQLType{Name: schemas.Varchar, DefaultLength: len1, DefaultLength2: len2} + case "TIMESTAMP WITH TIME ZONE": + col.SQLType = schemas.SQLType{Name: schemas.TimeStampz, DefaultLength: 0, DefaultLength2: 0} + case "NUMBER": + col.SQLType = schemas.SQLType{Name: "NUMBER", DefaultLength: len1, DefaultLength2: len2} + case "LONG", "LONG RAW", "NCLOB", "CLOB", "TEXT": + col.SQLType = schemas.SQLType{Name: schemas.Text, DefaultLength: 0, DefaultLength2: 0} + case "RAW": + col.SQLType = schemas.SQLType{Name: schemas.Binary, DefaultLength: 0, DefaultLength2: 0} + case "ROWID": + col.SQLType = schemas.SQLType{Name: schemas.Varchar, DefaultLength: 18, DefaultLength2: 0} + case "AQ$_SUBSCRIBERS": + ignore = true + default: + col.SQLType = schemas.SQLType{Name: strings.ToUpper(dt), DefaultLength: len1, DefaultLength2: len2} + } + + if ignore { + continue + } + + if _, ok := schemas.SqlTypes[col.SQLType.Name]; !ok { + return nil, nil, fmt.Errorf("unknown colType %v %v", dataType.String, col.SQLType) + } + + if col.SQLType.Name == "TIMESTAMP" { + col.Length = dataScale.Int64 + } else { + col.Length = dataLen.Int64 + } + + if col.SQLType.IsTime() { + if !col.DefaultIsEmpty && !strings.EqualFold(col.Default, "CURRENT_TIMESTAMP") { + col.Default = addSingleQuote(col.Default) + } + } + cols[col.Name] = col + colSeq = append(colSeq, col.Name) + } + if rows.Err() != nil { + return nil, nil, rows.Err() + } + + return colSeq, cols, nil +} + +func (db *dameng) GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) { + s := "SELECT table_name FROM user_tables WHERE temporary = 'N' AND table_name NOT LIKE ?" + args := []interface{}{strings.ToUpper(db.uri.User), "%$%"} + + rows, err := queryer.QueryContext(ctx, s, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + tables := make([]*schemas.Table, 0) + for rows.Next() { + table := schemas.NewEmptyTable() + err = rows.Scan(&table.Name) + if err != nil { + return nil, err + } + + tables = append(tables, table) + } + if rows.Err() != nil { + return nil, rows.Err() + } + return tables, nil +} + +func (db *dameng) GetIndexes(queryer core.Queryer, ctx context.Context, tableName string) (map[string]*schemas.Index, error) { + args := []interface{}{tableName, tableName} + s := "SELECT t.column_name,i.uniqueness,i.index_name FROM user_ind_columns t,user_indexes i " + + "WHERE t.index_name = i.index_name and t.table_name = i.table_name and t.table_name =?" + + " AND t.index_name not in (SELECT index_name FROM ALL_CONSTRAINTS WHERE CONSTRAINT_TYPE='P' AND table_name = ?)" + + rows, err := queryer.QueryContext(ctx, s, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + indexes := make(map[string]*schemas.Index) + for rows.Next() { + var indexType int + var indexName, colName, uniqueness string + + err = rows.Scan(&colName, &uniqueness, &indexName) + if err != nil { + return nil, err + } + + indexName = strings.Trim(indexName, `" `) + + var isRegular bool + if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) { + indexName = indexName[5+len(tableName):] + isRegular = true + } + + if uniqueness == "UNIQUE" { + indexType = schemas.UniqueType + } else { + indexType = schemas.IndexType + } + + var index *schemas.Index + var ok bool + if index, ok = indexes[indexName]; !ok { + index = new(schemas.Index) + index.Type = indexType + index.Name = indexName + index.IsRegular = isRegular + indexes[indexName] = index + } + index.AddColumn(colName) + } + if rows.Err() != nil { + return nil, rows.Err() + } + return indexes, nil +} + +func (db *dameng) Filters() []Filter { + return []Filter{} +} + +type damengDriver struct { + baseDriver +} + +// Features return features +func (d *damengDriver) Features() *DriverFeatures { + return &DriverFeatures{ + SupportReturnInsertedID: false, + } +} + +// Parse parse the datasource +// dm://userName:password@ip:port +func (d *damengDriver) Parse(driverName, dataSourceName string) (*URI, error) { + u, err := url.Parse(dataSourceName) + if err != nil { + return nil, err + } + + if u.User == nil { + return nil, errors.New("user/password needed") + } + + passwd, _ := u.User.Password() + return &URI{ + DBType: schemas.DAMENG, + Proto: u.Scheme, + Host: u.Hostname(), + Port: u.Port(), + DBName: u.User.Username(), + User: u.User.Username(), + Passwd: passwd, + }, nil +} + +func (d *damengDriver) GenScanResult(colType string) (interface{}, error) { + switch colType { + case "CHAR", "NCHAR", "VARCHAR", "VARCHAR2", "NVARCHAR2", "LONG", "CLOB", "NCLOB": + var s sql.NullString + return &s, nil + case "NUMBER": + var s sql.NullString + return &s, nil + case "BIGINT": + var s sql.NullInt64 + return &s, nil + case "INTEGER": + var s sql.NullInt32 + return &s, nil + case "DATE", "TIMESTAMP": + var s sql.NullString + return &s, nil + case "BLOB": + var r sql.RawBytes + return &r, nil + case "FLOAT": + var s sql.NullFloat64 + return &s, nil + default: + var r sql.RawBytes + return &r, nil + } +} + +func (d *damengDriver) Scan(ctx *ScanContext, rows *core.Rows, types []*sql.ColumnType, vv ...interface{}) error { + scanResults := make([]interface{}, 0, len(types)) + replaces := make([]bool, 0, len(types)) + var err error + for i, v := range vv { + var replaced bool + var scanResult interface{} + switch types[i].DatabaseTypeName() { + case "CLOB", "TEXT": + scanResult = &dmClobScanner{} + replaced = true + case "TIMESTAMP": + scanResult = &sql.NullString{} + replaced = true + default: + scanResult = v + } + + scanResults = append(scanResults, scanResult) + replaces = append(replaces, replaced) + } + + if err = rows.Scan(scanResults...); err != nil { + return err + } + + for i, replaced := range replaces { + if replaced { + switch t := scanResults[i].(type) { + case *dmClobScanner: + var d interface{} + if t.valid { + d = t.data + } else { + d = nil + } + if err := convert.Assign(vv[i], d, ctx.DBLocation, ctx.UserLocation); err != nil { + return err + } + default: + switch types[i].DatabaseTypeName() { + case "TIMESTAMP": + ns := t.(*sql.NullString) + if !ns.Valid { + break + } + s := ns.String + fields := strings.Split(s, "+") + if err := convert.Assign(vv[i], strings.Replace(fields[0], "T", " ", -1), ctx.DBLocation, ctx.UserLocation); err != nil { + return err + } + default: + return fmt.Errorf("don't support convert %T to %T", t, vv[i]) + } + } + } + } + + return nil +} diff --git a/dialects/dialect.go b/dialects/dialect.go index dc96f73a..70d599e6 100644 --- a/dialects/dialect.go +++ b/dialects/dialect.go @@ -38,12 +38,27 @@ func (uri *URI) SetSchema(schema string) { } } +// enumerates all autoincr mode +const ( + IncrAutoincrMode = iota + SequenceAutoincrMode +) + +// DialectFeatures represents a dialect parameters +type DialectFeatures struct { + AutoincrMode int // 0 autoincrement column, 1 sequence +} + // Dialect represents a kind of database type Dialect interface { Init(*URI) error URI() *URI + Version(ctx context.Context, queryer core.Queryer) (*schemas.Version, error) + Features() *DialectFeatures + SQLType(*schemas.Column) string - FormatBytes(b []byte) string + Alias(string) string // return what a sql type's alias of + ColumnTypeKind(string) int // database column type kind IsReserved(string) bool Quoter() schemas.Quoter @@ -58,9 +73,13 @@ type Dialect interface { GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error) - CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) + CreateTableSQL(ctx context.Context, queryer core.Queryer, table *schemas.Table, tableName string) (string, bool, error) DropTableSQL(tableName string) (string, bool) + CreateSequenceSQL(ctx context.Context, queryer core.Queryer, seqName string) (string, error) + IsSequenceExist(ctx context.Context, queryer core.Queryer, seqName string) (bool, error) + DropSequenceSQL(seqName string) (string, error) + GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName string, colName string) (bool, error) AddColumnSQL(tableName string, col *schemas.Column) string @@ -79,32 +98,87 @@ type Base struct { quoter schemas.Quoter } -func (b *Base) Quoter() schemas.Quoter { - return b.quoter +// Alias returned col itself +func (db *Base) Alias(col string) string { + return col } -func (b *Base) Init(dialect Dialect, uri *URI) error { - b.dialect, b.uri = dialect, uri +// Quoter returns the current database Quoter +func (db *Base) Quoter() schemas.Quoter { + return db.quoter +} + +// Init initialize the dialect +func (db *Base) Init(dialect Dialect, uri *URI) error { + db.dialect, db.uri = dialect, uri return nil } -func (b *Base) URI() *URI { - return b.uri +// URI returns the uri of database +func (db *Base) URI() *URI { + return db.uri } -func (b *Base) DBType() schemas.DBType { - return b.uri.DBType +// CreateTableSQL implements Dialect +func (db *Base) CreateTableSQL(ctx context.Context, queryer core.Queryer, table *schemas.Table, tableName string) (string, bool, error) { + if tableName == "" { + tableName = table.Name + } + + quoter := db.dialect.Quoter() + var b strings.Builder + b.WriteString("CREATE TABLE IF NOT EXISTS ") + if err := quoter.QuoteTo(&b, tableName); err != nil { + return "", false, err + } + b.WriteString(" (") + + for i, colName := range table.ColumnsSeq() { + col := table.GetColumn(colName) + s, _ := ColumnString(db.dialect, col, col.IsPrimaryKey && len(table.PrimaryKeys) == 1) + b.WriteString(s) + + if i != len(table.ColumnsSeq())-1 { + b.WriteString(", ") + } + } + + if len(table.PrimaryKeys) > 1 { + b.WriteString(", PRIMARY KEY (") + b.WriteString(quoter.Join(table.PrimaryKeys, ",")) + b.WriteString(")") + } + + b.WriteString(")") + + return b.String(), false, nil } -func (b *Base) FormatBytes(bs []byte) string { - return fmt.Sprintf("0x%x", bs) +func (db *Base) CreateSequenceSQL(ctx context.Context, queryer core.Queryer, seqName string) (string, error) { + return fmt.Sprintf(`CREATE SEQUENCE %s + minvalue 1 + nomaxvalue + start with 1 + increment by 1 + nocycle + nocache`, seqName), nil } +func (db *Base) IsSequenceExist(ctx context.Context, queryer core.Queryer, seqName string) (bool, error) { + return false, fmt.Errorf("unsupported sequence feature") +} + +func (db *Base) DropSequenceSQL(seqName string) (string, error) { + return fmt.Sprintf("DROP SEQUENCE %s", seqName), nil +} + +// DropTableSQL returns drop table SQL func (db *Base) DropTableSQL(tableName string) (string, bool) { quote := db.dialect.Quoter().Quote return fmt.Sprintf("DROP TABLE IF EXISTS %s", quote(tableName)), true } +// HasRecords returns true if the SQL has records returned func (db *Base) HasRecords(queryer core.Queryer, ctx context.Context, query string, args ...interface{}) (bool, error) { rows, err := queryer.QueryContext(ctx, query, args...) if err != nil { @@ -115,9 +189,10 @@ func (db *Base) HasRecords(queryer core.Queryer, ctx context.Context, query stri if rows.Next() { return true, nil } - return false, nil + return false, rows.Err() } +// IsColumnExist returns true if the column of the table exist func (db *Base) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName, colName string) (bool, error) { quote := db.dialect.Quoter().Quote query := fmt.Sprintf( @@ -132,11 +207,13 @@ func (db *Base) IsColumnExist(queryer core.Queryer, ctx context.Context, tableNa return db.HasRecords(queryer, ctx, query, db.uri.DBName, tableName, colName) } +// AddColumnSQL returns a SQL to add a column func (db *Base) AddColumnSQL(tableName string, col *schemas.Column) string { s, _ := ColumnString(db.dialect, col, true) - return fmt.Sprintf("ALTER TABLE %v ADD %v", db.dialect.Quoter().Quote(tableName), s) + return fmt.Sprintf("ALTER TABLE %s ADD %s", db.dialect.Quoter().Quote(tableName), s) } +// CreateIndexSQL returns a SQL to create index func (db *Base) CreateIndexSQL(tableName string, index *schemas.Index) string { quoter := db.dialect.Quoter() var unique string @@ -150,6 +227,7 @@ func (db *Base) CreateIndexSQL(tableName string, index *schemas.Index) string { quoter.Join(index.Cols, ",")) } +// DropIndexSQL returns a SQL to drop index func (db *Base) DropIndexSQL(tableName string, index *schemas.Index) string { quote := db.dialect.Quoter().Quote var name string @@ -161,16 +239,19 @@ func (db *Base) DropIndexSQL(tableName string, index *schemas.Index) string { return fmt.Sprintf("DROP INDEX %v ON %s", quote(name), quote(tableName)) } +// ModifyColumnSQL returns a SQL to modify SQL func (db *Base) ModifyColumnSQL(tableName string, col *schemas.Column) string { s, _ := ColumnString(db.dialect, col, false) - return fmt.Sprintf("alter table %s MODIFY COLUMN %s", tableName, s) + return fmt.Sprintf("ALTER TABLE %s MODIFY COLUMN %s", db.quoter.Quote(tableName), s) } -func (b *Base) ForUpdateSQL(query string) string { +// ForUpdateSQL returns for updateSQL +func (db *Base) ForUpdateSQL(query string) string { return query + " FOR UPDATE" } -func (b *Base) SetParams(params map[string]string) { +// SetParams set params +func (db *Base) SetParams(params map[string]string) { } var ( @@ -206,8 +287,10 @@ func regDrvsNDialects() bool { "postgres": {"postgres", func() Driver { return &pqDriver{} }, func() Dialect { return &postgres{} }}, "pgx": {"postgres", func() Driver { return &pqDriverPgx{} }, func() Dialect { return &postgres{} }}, "sqlite3": {"sqlite3", func() Driver { return &sqlite3Driver{} }, func() Dialect { return &sqlite3{} }}, + "sqlite": {"sqlite3", func() Driver { return &sqlite3Driver{} }, func() Dialect { return &sqlite3{} }}, "oci8": {"oracle", func() Driver { return &oci8Driver{} }, func() Dialect { return &oracle{} }}, - "goracle": {"oracle", func() Driver { return &goracleDriver{} }, func() Dialect { return &oracle{} }}, + "godror": {"oracle", func() Driver { return &godrorDriver{} }, func() Dialect { return &oracle{} }}, + "oracle": {"oracle", func() Driver { return &oracleDriver{} }, func() Dialect { return &oracle{} }}, } for driverName, v := range providedDrvsNDialects { @@ -239,43 +322,41 @@ func ColumnString(dialect Dialect, col *schemas.Column, includePrimaryKey bool) return "", err } - if err := bd.WriteByte(' '); err != nil { - return "", err - } - if includePrimaryKey && col.IsPrimaryKey { - if _, err := bd.WriteString("PRIMARY KEY "); err != nil { + if _, err := bd.WriteString(" PRIMARY KEY"); err != nil { return "", err } - if col.IsAutoIncrement { - if _, err := bd.WriteString(dialect.AutoIncrStr()); err != nil { - return "", err - } if err := bd.WriteByte(' '); err != nil { return "", err } + if _, err := bd.WriteString(dialect.AutoIncrStr()); err != nil { + return "", err + } } } - if col.Default != "" { - if _, err := bd.WriteString("DEFAULT "); err != nil { + if !col.DefaultIsEmpty { + if _, err := bd.WriteString(" DEFAULT "); err != nil { return "", err } - if _, err := bd.WriteString(col.Default); err != nil { - return "", err - } - if err := bd.WriteByte(' '); err != nil { - return "", err + if col.Default == "" { + if _, err := bd.WriteString("''"); err != nil { + return "", err + } + } else { + if _, err := bd.WriteString(col.Default); err != nil { + return "", err + } } } if col.Nullable { - if _, err := bd.WriteString("NULL "); err != nil { + if _, err := bd.WriteString(" NULL"); err != nil { return "", err } } else { - if _, err := bd.WriteString("NOT NULL "); err != nil { + if _, err := bd.WriteString(" NOT NULL"); err != nil { return "", err } } diff --git a/dialects/driver.go b/dialects/driver.go index ae3afe42..c63dbfa3 100644 --- a/dialects/driver.go +++ b/dialects/driver.go @@ -5,17 +5,37 @@ package dialects import ( + "database/sql" "fmt" + "time" + + "xorm.io/xorm/core" ) +// ScanContext represents a context when Scan +type ScanContext struct { + DBLocation *time.Location + UserLocation *time.Location +} + +// DriverFeatures represents driver feature +type DriverFeatures struct { + SupportReturnInsertedID bool +} + +// Driver represents a database driver type Driver interface { Parse(string, string) (*URI, error) + Features() *DriverFeatures + GenScanResult(string) (interface{}, error) // according given column type generating a suitable scan interface + Scan(*ScanContext, *core.Rows, []*sql.ColumnType, ...interface{}) error } var ( drivers = map[string]Driver{} ) +// RegisterDriver register a driver func RegisterDriver(driverName string, driver Driver) { if driver == nil { panic("core: Register driver is nil") @@ -26,10 +46,12 @@ func RegisterDriver(driverName string, driver Driver) { drivers[driverName] = driver } +// QueryDriver query a driver with name func QueryDriver(driverName string) Driver { return drivers[driverName] } +// RegisteredDriverSize returned all drivers's length func RegisteredDriverSize() int { return len(drivers) } @@ -38,7 +60,7 @@ func RegisteredDriverSize() int { func OpenDialect(driverName, connstr string) (Dialect, error) { driver := QueryDriver(driverName) if driver == nil { - return nil, fmt.Errorf("Unsupported driver name: %v", driverName) + return nil, fmt.Errorf("unsupported driver name: %v", driverName) } uri, err := driver.Parse(driverName, connstr) @@ -48,10 +70,16 @@ func OpenDialect(driverName, connstr string) (Dialect, error) { dialect := QueryDialect(uri.DBType) if dialect == nil { - return nil, fmt.Errorf("Unsupported dialect type: %v", uri.DBType) + return nil, fmt.Errorf("unsupported dialect type: %v", uri.DBType) } dialect.Init(uri) return dialect, nil } + +type baseDriver struct{} + +func (b *baseDriver) Scan(ctx *ScanContext, rows *core.Rows, types []*sql.ColumnType, v ...interface{}) error { + return rows.Scan(v...) +} diff --git a/dialects/filter.go b/dialects/filter.go index 6968b6ce..bfe2e93e 100644 --- a/dialects/filter.go +++ b/dialects/filter.go @@ -23,13 +23,45 @@ type SeqFilter struct { func convertQuestionMark(sql, prefix string, start int) string { var buf strings.Builder var beginSingleQuote bool + var isLineComment bool + var isComment bool + var isMaybeLineComment bool + var isMaybeComment bool + var isMaybeCommentEnd bool var index = start for _, c := range sql { - if !beginSingleQuote && c == '?' { + if !beginSingleQuote && !isLineComment && !isComment && c == '?' { buf.WriteString(fmt.Sprintf("%s%v", prefix, index)) index++ } else { - if c == '\'' { + if isMaybeLineComment { + if c == '-' { + isLineComment = true + } + isMaybeLineComment = false + } else if isMaybeComment { + if c == '*' { + isComment = true + } + isMaybeComment = false + } else if isMaybeCommentEnd { + if c == '/' { + isComment = false + } + isMaybeCommentEnd = false + } else if isLineComment { + if c == '\n' { + isLineComment = false + } + } else if isComment { + if c == '*' { + isMaybeCommentEnd = true + } + } else if !beginSingleQuote && c == '-' { + isMaybeLineComment = true + } else if !beginSingleQuote && c == '/' { + isMaybeComment = true + } else if c == '\'' { beginSingleQuote = !beginSingleQuote } buf.WriteRune(c) @@ -38,6 +70,7 @@ func convertQuestionMark(sql, prefix string, start int) string { return buf.String() } +// Do implements Filter func (s *SeqFilter) Do(sql string) string { return convertQuestionMark(sql, s.Prefix, s.Start) } diff --git a/dialects/filter_test.go b/dialects/filter_test.go index 7e2ef0a2..15050656 100644 --- a/dialects/filter_test.go +++ b/dialects/filter_test.go @@ -19,3 +19,60 @@ func TestSeqFilter(t *testing.T) { assert.EqualValues(t, result, convertQuestionMark(sql, "$", 1)) } } + +func TestSeqFilterLineComment(t *testing.T) { + var kases = map[string]string{ + `SELECT * + FROM TABLE1 + WHERE foo='bar' + AND a=? -- it's a comment + AND b=?`: `SELECT * + FROM TABLE1 + WHERE foo='bar' + AND a=$1 -- it's a comment + AND b=$2`, + `SELECT * + FROM TABLE1 + WHERE foo='bar' + AND a=? -- it's a comment? + AND b=?`: `SELECT * + FROM TABLE1 + WHERE foo='bar' + AND a=$1 -- it's a comment? + AND b=$2`, + `SELECT * + FROM TABLE1 + WHERE a=? -- it's a comment? and that's okay? + AND b=?`: `SELECT * + FROM TABLE1 + WHERE a=$1 -- it's a comment? and that's okay? + AND b=$2`, + } + for sql, result := range kases { + assert.EqualValues(t, result, convertQuestionMark(sql, "$", 1)) + } +} + +func TestSeqFilterComment(t *testing.T) { + var kases = map[string]string{ + `SELECT * + FROM TABLE1 + WHERE a=? /* it's a comment */ + AND b=?`: `SELECT * + FROM TABLE1 + WHERE a=$1 /* it's a comment */ + AND b=$2`, + `SELECT /* it's a comment * ? + More comment on the next line! */ * + FROM TABLE1 + WHERE a=? /**/ + AND b=?`: `SELECT /* it's a comment * ? + More comment on the next line! */ * + FROM TABLE1 + WHERE a=$1 /**/ + AND b=$2`, + } + for sql, result := range kases { + assert.EqualValues(t, result, convertQuestionMark(sql, "$", 1)) + } +} diff --git a/dialects/mssql.go b/dialects/mssql.go index 8e76e538..1b6fe692 100644 --- a/dialects/mssql.go +++ b/dialects/mssql.go @@ -6,6 +6,7 @@ package dialects import ( "context" + "database/sql" "errors" "fmt" "net/url" @@ -220,13 +221,15 @@ type mssql struct { func (db *mssql) Init(uri *URI) error { db.quoter = mssqlQuoter + db.defaultChar = "CHAR" + db.defaultVarchar = "VARCHAR" return db.Base.Init(db, uri) } func (db *mssql) SetParams(params map[string]string) { defaultVarchar, ok := params["DEFAULT_VARCHAR"] if ok { - var t = strings.ToUpper(defaultVarchar) + t := strings.ToUpper(defaultVarchar) switch t { case "NVARCHAR", "VARCHAR": db.defaultVarchar = t @@ -239,7 +242,7 @@ func (db *mssql) SetParams(params map[string]string) { defaultChar, ok := params["DEFAULT_CHAR"] if ok { - var t = strings.ToUpper(defaultChar) + t := strings.ToUpper(defaultChar) switch t { case "NCHAR", "CHAR": db.defaultChar = t @@ -251,10 +254,44 @@ func (db *mssql) SetParams(params map[string]string) { } } +func (db *mssql) Version(ctx context.Context, queryer core.Queryer) (*schemas.Version, error) { + rows, err := queryer.QueryContext(ctx, + "SELECT SERVERPROPERTY('productversion'), SERVERPROPERTY ('productlevel') AS ProductLevel, SERVERPROPERTY ('edition') AS ProductEdition") + if err != nil { + return nil, err + } + defer rows.Close() + + var version, level, edition string + if !rows.Next() { + if rows.Err() != nil { + return nil, rows.Err() + } + return nil, errors.New("unknow version") + } + + if err := rows.Scan(&version, &level, &edition); err != nil { + return nil, err + } + + // MSSQL: Microsoft SQL Server 2017 (RTM-CU13) (KB4466404) - 14.0.3048.4 (X64) Nov 30 2018 12:57:58 Copyright (C) 2017 Microsoft Corporation Developer Edition (64-bit) on Linux (Ubuntu 16.04.5 LTS) + return &schemas.Version{ + Number: version, + Level: level, + Edition: edition, + }, nil +} + +func (db *mssql) Features() *DialectFeatures { + return &DialectFeatures{ + AutoincrMode: IncrAutoincrMode, + } +} + func (db *mssql) SQLType(c *schemas.Column) string { var res string switch t := c.SQLType.Name; t { - case schemas.Bool: + case schemas.Bool, schemas.Boolean: res = schemas.Bit if strings.EqualFold(c.Default, "true") { c.Default = "1" @@ -272,17 +309,26 @@ func (db *mssql) SQLType(c *schemas.Column) string { c.IsPrimaryKey = true c.Nullable = false res = schemas.BigInt - case schemas.Bytea, schemas.Blob, schemas.Binary, schemas.TinyBlob, schemas.MediumBlob, schemas.LongBlob: + case schemas.Bytea, schemas.Binary: res = schemas.VarBinary if c.Length == 0 { c.Length = 50 } - case schemas.TimeStamp: - res = schemas.DateTime + case schemas.Blob, schemas.TinyBlob, schemas.MediumBlob, schemas.LongBlob: + res = schemas.VarBinary + if c.Length == 0 { + res += "(MAX)" + } + case schemas.TimeStamp, schemas.DateTime: + if c.Length > 3 { + res = "DATETIME2" + } else { + return schemas.DateTime + } case schemas.TimeStampz: res = "DATETIMEOFFSET" c.Length = 7 - case schemas.MediumInt: + case schemas.MediumInt, schemas.TinyInt, schemas.SmallInt, schemas.UnsignedMediumInt, schemas.UnsignedTinyInt, schemas.UnsignedSmallInt: res = schemas.Int case schemas.Text, schemas.MediumText, schemas.TinyText, schemas.LongText, schemas.Json: res = db.defaultVarchar + "(MAX)" @@ -294,7 +340,7 @@ func (db *mssql) SQLType(c *schemas.Column) string { case schemas.TinyInt: res = schemas.TinyInt c.Length = 0 - case schemas.BigInt: + case schemas.BigInt, schemas.UnsignedBigInt, schemas.UnsignedInt: res = schemas.BigInt c.Length = 0 case schemas.NVarchar: @@ -321,7 +367,7 @@ func (db *mssql) SQLType(c *schemas.Column) string { res = t } - if res == schemas.Int || res == schemas.Bit || res == schemas.DateTime { + if res == schemas.Int || res == schemas.Bit { return res } @@ -329,13 +375,26 @@ func (db *mssql) SQLType(c *schemas.Column) string { hasLen2 := (c.Length2 > 0) if hasLen2 { - res += "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")" + res += "(" + strconv.FormatInt(c.Length, 10) + "," + strconv.FormatInt(c.Length2, 10) + ")" } else if hasLen1 { - res += "(" + strconv.Itoa(c.Length) + ")" + res += "(" + strconv.FormatInt(c.Length, 10) + ")" } return res } +func (db *mssql) ColumnTypeKind(t string) int { + switch strings.ToUpper(t) { + case "DATE", "DATETIME", "DATETIME2", "TIME": + return schemas.TIME_TYPE + case "VARCHAR", "TEXT", "CHAR", "NVARCHAR", "NCHAR", "NTEXT": + return schemas.TEXT_TYPE + case "FLOAT", "REAL", "BIGINT", "DATETIMEOFFSET", "TINYINT", "SMALLINT", "INT": + return schemas.NUMERIC_TYPE + default: + return schemas.UNKNOW_TYPE + } +} + func (db *mssql) IsReserved(name string) bool { _, ok := mssqlReservedWords[strings.ToUpper(name)] return ok @@ -344,11 +403,11 @@ func (db *mssql) IsReserved(name string) bool { func (db *mssql) SetQuotePolicy(quotePolicy QuotePolicy) { switch quotePolicy { case QuotePolicyNone: - var q = mssqlQuoter + q := mssqlQuoter q.IsReserved = schemas.AlwaysNoReserve db.quoter = q case QuotePolicyReserved: - var q = mssqlQuoter + q := mssqlQuoter q.IsReserved = db.IsReserved db.quoter = q case QuotePolicyAlways: @@ -368,6 +427,11 @@ func (db *mssql) DropTableSQL(tableName string) (string, bool) { "DROP TABLE \"%s\"", tableName, tableName), true } +func (db *mssql) ModifyColumnSQL(tableName string, col *schemas.Column) string { + s, _ := ColumnString(db.dialect, col, false) + return fmt.Sprintf("ALTER TABLE %s ALTER COLUMN %s", db.quoter.Quote(tableName), s) +} + func (db *mssql) IndexCheckSQL(tableName, idxName string) (string, []interface{}) { args := []interface{}{idxName} sql := "select name from sysindexes where id=object_id('" + tableName + "') and name=?" @@ -411,7 +475,7 @@ func (db *mssql) GetColumns(queryer core.Queryer, ctx context.Context, tableName colSeq := make([]string, 0) for rows.Next() { var name, ctype, vdefault string - var maxLen, precision, scale int + var maxLen, precision, scale int64 var nullable, isPK, defaultIsNull, isIncrement bool err = rows.Scan(&name, &ctype, &maxLen, &precision, &scale, &nullable, &defaultIsNull, &vdefault, &isPK, &isIncrement) if err != nil { @@ -444,6 +508,12 @@ func (db *mssql) GetColumns(queryer core.Queryer, ctx context.Context, tableName col.Length /= 2 col.Length2 /= 2 } + case "DATETIME2": + col.SQLType = schemas.SQLType{Name: schemas.DateTime, DefaultLength: 7, DefaultLength2: 0} + col.Length = scale + case "DATETIME": + col.SQLType = schemas.SQLType{Name: schemas.DateTime, DefaultLength: 3, DefaultLength2: 0} + col.Length = scale case "IMAGE": col.SQLType = schemas.SQLType{Name: schemas.VarBinary, DefaultLength: 0, DefaultLength2: 0} case "NCHAR": @@ -463,6 +533,9 @@ func (db *mssql) GetColumns(queryer core.Queryer, ctx context.Context, tableName cols[col.Name] = col colSeq = append(colSeq, col.Name) } + if rows.Err() != nil { + return nil, nil, rows.Err() + } return colSeq, cols, nil } @@ -487,6 +560,9 @@ func (db *mssql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schema table.Name = strings.Trim(name, "` ") tables = append(tables, table) } + if rows.Err() != nil { + return nil, rows.Err() + } return tables, nil } @@ -510,7 +586,7 @@ WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =? } defer rows.Close() - indexes := make(map[string]*schemas.Index, 0) + indexes := make(map[string]*schemas.Index) for rows.Next() { var indexType int var indexName, colName, isUnique string @@ -533,7 +609,7 @@ WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =? colName = strings.Trim(colName, "` ") var isRegular bool - if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) { + if (strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName)) && len(indexName) > (5+len(tableName)) { indexName = indexName[5+len(tableName):] isRegular = true } @@ -549,38 +625,44 @@ WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =? } index.AddColumn(colName) } + if rows.Err() != nil { + return nil, rows.Err() + } return indexes, nil } -func (db *mssql) CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) { - var sql string +func (db *mssql) CreateTableSQL(ctx context.Context, queryer core.Queryer, table *schemas.Table, tableName string) (string, bool, error) { if tableName == "" { tableName = table.Name } - sql = "IF NOT EXISTS (SELECT [name] FROM sys.tables WHERE [name] = '" + tableName + "' ) CREATE TABLE " + quoter := db.dialect.Quoter() + var b strings.Builder + b.WriteString("IF NOT EXISTS (SELECT [name] FROM sys.tables WHERE [name] = '") + quoter.QuoteTo(&b, tableName) + b.WriteString("' ) CREATE TABLE ") + quoter.QuoteTo(&b, tableName) + b.WriteString(" (") - sql += db.Quoter().Quote(tableName) + " (" - - pkList := table.PrimaryKeys - - for _, colName := range table.ColumnsSeq() { + for i, colName := range table.ColumnsSeq() { col := table.GetColumn(colName) - s, _ := ColumnString(db, col, col.IsPrimaryKey && len(pkList) == 1) - sql += s - sql = strings.TrimSpace(sql) - sql += ", " + s, _ := ColumnString(db.dialect, col, col.IsPrimaryKey && len(table.PrimaryKeys) == 1) + b.WriteString(s) + + if i != len(table.ColumnsSeq())-1 { + b.WriteString(", ") + } } - if len(pkList) > 1 { - sql += "PRIMARY KEY ( " - sql += strings.Join(pkList, ",") - sql += " ), " + if len(table.PrimaryKeys) > 1 { + b.WriteString(", PRIMARY KEY (") + b.WriteString(quoter.Join(table.PrimaryKeys, ",")) + b.WriteString(")") } - sql = sql[:len(sql)-2] + ")" - sql += ";" - return []string{sql}, true + b.WriteString(")") + + return b.String(), true, nil } func (db *mssql) ForUpdateSQL(query string) string { @@ -592,6 +674,13 @@ func (db *mssql) Filters() []Filter { } type odbcDriver struct { + baseDriver +} + +func (p *odbcDriver) Features() *DriverFeatures { + return &DriverFeatures{ + SupportReturnInsertedID: false, + } } func (p *odbcDriver) Parse(driverName, dataSourceName string) (*URI, error) { @@ -608,8 +697,7 @@ func (p *odbcDriver) Parse(driverName, dataSourceName string) (*URI, error) { for _, c := range kv { vv := strings.Split(strings.TrimSpace(c), "=") if len(vv) == 2 { - switch strings.ToLower(vv[0]) { - case "database": + if strings.ToLower(vv[0]) == "database" { dbName = vv[1] } } @@ -620,3 +708,26 @@ func (p *odbcDriver) Parse(driverName, dataSourceName string) (*URI, error) { } return &URI{DBName: dbName, DBType: schemas.MSSQL}, nil } + +func (p *odbcDriver) GenScanResult(colType string) (interface{}, error) { + switch colType { + case "VARCHAR", "TEXT", "CHAR", "NVARCHAR", "NCHAR", "NTEXT": + fallthrough + case "DATE", "DATETIME", "DATETIME2", "TIME": + var s sql.NullString + return &s, nil + case "FLOAT", "REAL": + var s sql.NullFloat64 + return &s, nil + case "BIGINT", "DATETIMEOFFSET": + var s sql.NullInt64 + return &s, nil + case "TINYINT", "SMALLINT", "INT": + var s sql.NullInt32 + return &s, nil + + default: + var r sql.RawBytes + return &r, nil + } +} diff --git a/dialects/mysql.go b/dialects/mysql.go index 32e18a17..195e1f23 100644 --- a/dialects/mysql.go +++ b/dialects/mysql.go @@ -6,7 +6,7 @@ package dialects import ( "context" - "crypto/tls" + "database/sql" "errors" "fmt" "regexp" @@ -171,16 +171,7 @@ var ( type mysql struct { Base - net string - addr string - params map[string]string - loc *time.Location - timeout time.Duration - tls *tls.Config - allowAllFiles bool - allowOldPasswords bool - clientFoundRows bool - rowFormat string + rowFormat string } func (db *mysql) Init(uri *URI) error { @@ -188,10 +179,69 @@ func (db *mysql) Init(uri *URI) error { return db.Base.Init(db, uri) } +var mysqlColAliases = map[string]string{ + "numeric": "decimal", +} + +// Alias returns a alias of column +func (db *mysql) Alias(col string) string { + v, ok := mysqlColAliases[strings.ToLower(col)] + if ok { + return v + } + return col +} + +func (db *mysql) Version(ctx context.Context, queryer core.Queryer) (*schemas.Version, error) { + rows, err := queryer.QueryContext(ctx, "SELECT @@VERSION") + if err != nil { + return nil, err + } + defer rows.Close() + + var version string + if !rows.Next() { + if rows.Err() != nil { + return nil, rows.Err() + } + return nil, errors.New("unknow version") + } + + if err := rows.Scan(&version); err != nil { + return nil, err + } + + fields := strings.Split(version, "-") + if len(fields) == 3 && fields[1] == "TiDB" { + // 5.7.25-TiDB-v3.0.3 + return &schemas.Version{ + Number: strings.TrimPrefix(fields[2], "v"), + Level: fields[0], + Edition: fields[1], + }, nil + } + + var edition string + if len(fields) == 2 { + edition = fields[1] + } + + return &schemas.Version{ + Number: fields[0], + Edition: edition, + }, nil +} + +func (db *mysql) Features() *DialectFeatures { + return &DialectFeatures{ + AutoincrMode: IncrAutoincrMode, + } +} + func (db *mysql) SetParams(params map[string]string) { rowFormat, ok := params["rowFormat"] if ok { - var t = strings.ToUpper(rowFormat) + t := strings.ToUpper(rowFormat) switch t { case "COMPACT": fallthrough @@ -201,15 +251,13 @@ func (db *mysql) SetParams(params map[string]string) { fallthrough case "COMPRESSED": db.rowFormat = t - break - default: - break } } } func (db *mysql) SQLType(c *schemas.Column) string { var res string + var isUnsigned bool switch t := c.SQLType.Name; t { case schemas.Bool: res = schemas.TinyInt @@ -254,6 +302,21 @@ func (db *mysql) SQLType(c *schemas.Column) string { c.Length = 40 case schemas.Json: res = schemas.Text + case schemas.UnsignedInt: + res = schemas.Int + isUnsigned = true + case schemas.UnsignedBigInt: + res = schemas.BigInt + isUnsigned = true + case schemas.UnsignedMediumInt: + res = schemas.MediumInt + isUnsigned = true + case schemas.UnsignedSmallInt: + res = schemas.SmallInt + isUnsigned = true + case schemas.UnsignedTinyInt: + res = schemas.TinyInt + isUnsigned = true default: res = t } @@ -267,13 +330,33 @@ func (db *mysql) SQLType(c *schemas.Column) string { } if hasLen2 { - res += "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")" + res += "(" + strconv.FormatInt(c.Length, 10) + "," + strconv.FormatInt(c.Length2, 10) + ")" } else if hasLen1 { - res += "(" + strconv.Itoa(c.Length) + ")" + res += "(" + strconv.FormatInt(c.Length, 10) + ")" } + + if isUnsigned { + res += " UNSIGNED" + } + return res } +func (db *mysql) ColumnTypeKind(t string) int { + switch strings.ToUpper(t) { + case "DATETIME": + return schemas.TIME_TYPE + case "CHAR", "VARCHAR", "TINYTEXT", "TEXT", "MEDIUMTEXT", "LONGTEXT", "ENUM", "SET": + return schemas.TEXT_TYPE + case "BIGINT", "TINYINT", "SMALLINT", "MEDIUMINT", "INT", "FLOAT", "REAL", "DOUBLE PRECISION", "DECIMAL", "NUMERIC", "BIT": + return schemas.NUMERIC_TYPE + case "BINARY", "VARBINARY", "TINYBLOB", "BLOB", "MEDIUMBLOB", "LONGBLOB": + return schemas.BLOB_TYPE + default: + return schemas.UNKNOW_TYPE + } +} + func (db *mysql) IsReserved(name string) bool { _, ok := mysqlReservedWords[strings.ToUpper(name)] return ok @@ -314,10 +397,10 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName "(SUBSTRING_INDEX(SUBSTRING(VERSION(), 4), '.', 1) = 2 && " + "SUBSTRING_INDEX(SUBSTRING(VERSION(), 6), '-', 1) >= 7)))))" s := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," + - " `COLUMN_KEY`, `EXTRA`, `COLUMN_COMMENT`, " + + " `COLUMN_KEY`, `EXTRA`, `COLUMN_COMMENT`, `CHARACTER_MAXIMUM_LENGTH`, " + alreadyQuoted + " AS NEEDS_QUOTE " + "FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" + - " ORDER BY `COLUMNS`.ORDINAL_POSITION" + " ORDER BY `COLUMNS`.ORDINAL_POSITION ASC" rows, err := queryer.QueryContext(ctx, s, args...) if err != nil { @@ -331,16 +414,16 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName col := new(schemas.Column) col.Indexes = make(map[string]int) - var columnName, isNullable, colType, colKey, extra, comment string - var alreadyQuoted bool - var colDefault *string - err = rows.Scan(&columnName, &isNullable, &colDefault, &colType, &colKey, &extra, &comment, &alreadyQuoted) + var columnName, nullableStr, colType, colKey, extra, comment string + var alreadyQuoted, isUnsigned bool + var colDefault, maxLength *string + err = rows.Scan(&columnName, &nullableStr, &colDefault, &colType, &colKey, &extra, &comment, &maxLength, &alreadyQuoted) if err != nil { return nil, nil, err } col.Name = strings.Trim(columnName, "` ") col.Comment = comment - if "YES" == isNullable { + if nullableStr == "YES" { col.Nullable = true } @@ -351,10 +434,17 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName col.DefaultIsEmpty = true } + fields := strings.Fields(colType) + if len(fields) == 2 && fields[1] == "unsigned" { + isUnsigned = true + } + colType = fields[0] cts := strings.Split(colType, "(") colName := cts[0] + // Remove the /* mariadb-5.3 */ suffix from coltypes + colName = strings.TrimSuffix(colName, "/* mariadb-5.3 */") colType = strings.ToUpper(colName) - var len1, len2 int + var len1, len2 int64 if len(cts) == 2 { idx := strings.Index(cts[1], ")") if colType == schemas.Enum && cts[1][0] == '\'' { // enum @@ -375,38 +465,43 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName } } else { lens := strings.Split(cts[1][0:idx], ",") - len1, err = strconv.Atoi(strings.TrimSpace(lens[0])) + len1, err = strconv.ParseInt(strings.TrimSpace(lens[0]), 10, 64) if err != nil { return nil, nil, err } if len(lens) == 2 { - len2, err = strconv.Atoi(lens[1]) + len2, err = strconv.ParseInt(lens[1], 10, 64) if err != nil { return nil, nil, err } } } + } else { + switch colType { + case "MEDIUMTEXT", "LONGTEXT", "TEXT": + len1, err = strconv.ParseInt(*maxLength, 10, 64) + if err != nil { + return nil, nil, err + } + } } - if colType == "FLOAT UNSIGNED" { - colType = "FLOAT" - } - if colType == "DOUBLE UNSIGNED" { - colType = "DOUBLE" + if isUnsigned { + colType = "UNSIGNED " + colType } col.Length = len1 col.Length2 = len2 if _, ok := schemas.SqlTypes[colType]; ok { col.SQLType = schemas.SQLType{Name: colType, DefaultLength: len1, DefaultLength2: len2} } else { - return nil, nil, fmt.Errorf("Unknown colType %v", colType) + return nil, nil, fmt.Errorf("unknown colType %v", colType) } if colKey == "PRI" { col.IsPrimaryKey = true } - if colKey == "UNI" { - // col.is - } + // if colKey == "UNI" { + // col.is + // } if extra == "auto_increment" { col.IsAutoIncrement = true @@ -422,6 +517,9 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName cols[col.Name] = col colSeq = append(colSeq, col.Name) } + if rows.Err() != nil { + return nil, nil, rows.Err() + } return colSeq, cols, nil } @@ -453,17 +551,20 @@ func (db *mysql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schema table.StoreEngine = engine tables = append(tables, table) } + if rows.Err() != nil { + return nil, rows.Err() + } return tables, nil } func (db *mysql) SetQuotePolicy(quotePolicy QuotePolicy) { switch quotePolicy { case QuotePolicyNone: - var q = mysqlQuoter + q := mysqlQuoter q.IsReserved = schemas.AlwaysNoReserve db.quoter = q case QuotePolicyReserved: - var q = mysqlQuoter + q := mysqlQuoter q.IsReserved = db.IsReserved db.quoter = q case QuotePolicyAlways: @@ -475,7 +576,7 @@ func (db *mysql) SetQuotePolicy(quotePolicy QuotePolicy) { func (db *mysql) GetIndexes(queryer core.Queryer, ctx context.Context, tableName string) (map[string]*schemas.Index, error) { args := []interface{}{db.uri.DBName, tableName} - s := "SELECT `INDEX_NAME`, `NON_UNIQUE`, `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" + s := "SELECT `INDEX_NAME`, `NON_UNIQUE`, `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? ORDER BY `SEQ_IN_INDEX`" rows, err := queryer.QueryContext(ctx, s, args...) if err != nil { @@ -483,7 +584,7 @@ func (db *mysql) GetIndexes(queryer core.Queryer, ctx context.Context, tableName } defer rows.Close() - indexes := make(map[string]*schemas.Index, 0) + indexes := make(map[string]*schemas.Index) for rows.Next() { var indexType int var indexName, colName, nonUnique string @@ -496,7 +597,7 @@ func (db *mysql) GetIndexes(queryer core.Queryer, ctx context.Context, tableName continue } - if "YES" == nonUnique || nonUnique == "1" { + if nonUnique == "YES" || nonUnique == "1" { indexType = schemas.IndexType } else { indexType = schemas.UniqueType @@ -520,119 +621,87 @@ func (db *mysql) GetIndexes(queryer core.Queryer, ctx context.Context, tableName } index.AddColumn(colName) } + if rows.Err() != nil { + return nil, rows.Err() + } return indexes, nil } -func (db *mysql) CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) { - var sql = "CREATE TABLE IF NOT EXISTS " +func (db *mysql) CreateTableSQL(ctx context.Context, queryer core.Queryer, table *schemas.Table, tableName string) (string, bool, error) { if tableName == "" { tableName = table.Name } - quoter := db.Quoter() + quoter := db.dialect.Quoter() + var b strings.Builder + b.WriteString("CREATE TABLE IF NOT EXISTS ") + quoter.QuoteTo(&b, tableName) + b.WriteString(" (") - sql += quoter.Quote(tableName) - sql += " (" + for i, colName := range table.ColumnsSeq() { + col := table.GetColumn(colName) + s, _ := ColumnString(db.dialect, col, col.IsPrimaryKey && len(table.PrimaryKeys) == 1) + b.WriteString(s) - if len(table.ColumnsSeq()) > 0 { - pkList := table.PrimaryKeys - - for _, colName := range table.ColumnsSeq() { - col := table.GetColumn(colName) - s, _ := ColumnString(db, col, col.IsPrimaryKey && len(pkList) == 1) - sql += s - sql = strings.TrimSpace(sql) - if len(col.Comment) > 0 { - sql += " COMMENT '" + col.Comment + "'" - } - sql += ", " + if len(col.Comment) > 0 { + b.WriteString(" COMMENT '") + b.WriteString(col.Comment) + b.WriteString("'") } - if len(pkList) > 1 { - sql += "PRIMARY KEY ( " - sql += quoter.Join(pkList, ",") - sql += " ), " + if i != len(table.ColumnsSeq())-1 { + b.WriteString(", ") } - - sql = sql[:len(sql)-2] } - sql += ")" + + if len(table.PrimaryKeys) > 1 { + b.WriteString(", PRIMARY KEY (") + b.WriteString(quoter.Join(table.PrimaryKeys, ",")) + b.WriteString(")") + } + + b.WriteString(")") if table.StoreEngine != "" { - sql += " ENGINE=" + table.StoreEngine + b.WriteString(" ENGINE=") + b.WriteString(table.StoreEngine) } - var charset = table.Charset + charset := table.Charset if len(charset) == 0 { charset = db.URI().Charset } if len(charset) != 0 { - sql += " DEFAULT CHARSET " + charset + b.WriteString(" DEFAULT CHARSET ") + b.WriteString(charset) } if db.rowFormat != "" { - sql += " ROW_FORMAT=" + db.rowFormat + b.WriteString(" ROW_FORMAT=") + b.WriteString(db.rowFormat) } - return []string{sql}, true + + if table.Comment != "" { + b.WriteString(" COMMENT='") + b.WriteString(table.Comment) + b.WriteString("'") + } + + return b.String(), true, nil } func (db *mysql) Filters() []Filter { return []Filter{} } -type mymysqlDriver struct { -} - -func (p *mymysqlDriver) Parse(driverName, dataSourceName string) (*URI, error) { - uri := &URI{DBType: schemas.MYSQL} - - pd := strings.SplitN(dataSourceName, "*", 2) - if len(pd) == 2 { - // Parse protocol part of URI - p := strings.SplitN(pd[0], ":", 2) - if len(p) != 2 { - return nil, errors.New("Wrong protocol part of URI") - } - uri.Proto = p[0] - options := strings.Split(p[1], ",") - uri.Raddr = options[0] - for _, o := range options[1:] { - kv := strings.SplitN(o, "=", 2) - var k, v string - if len(kv) == 2 { - k, v = kv[0], kv[1] - } else { - k, v = o, "true" - } - switch k { - case "laddr": - uri.Laddr = v - case "timeout": - to, err := time.ParseDuration(v) - if err != nil { - return nil, err - } - uri.Timeout = to - default: - return nil, errors.New("Unknown option: " + k) - } - } - // Remove protocol part - pd = pd[1:] - } - // Parse database part of URI - dup := strings.SplitN(pd[0], "/", 3) - if len(dup) != 3 { - return nil, errors.New("Wrong database part of URI") - } - uri.DBName = dup[0] - uri.User = dup[1] - uri.Passwd = dup[2] - - return uri, nil -} - type mysqlDriver struct { + baseDriver +} + +func (p *mysqlDriver) Features() *DriverFeatures { + return &DriverFeatures{ + SupportReturnInsertedID: true, + } } func (p *mysqlDriver) Parse(driverName, dataSourceName string) (*URI, error) { @@ -657,15 +726,99 @@ func (p *mysqlDriver) Parse(driverName, dataSourceName string) (*URI, error) { for _, kv := range kvs { splits := strings.Split(kv, "=") if len(splits) == 2 { - switch splits[0] { - case "charset": + if splits[0] == "charset" { uri.Charset = splits[1] } } } } - } } return uri, nil } + +func (p *mysqlDriver) GenScanResult(colType string) (interface{}, error) { + colType = strings.Replace(colType, "UNSIGNED ", "", -1) + switch colType { + case "CHAR", "VARCHAR", "TINYTEXT", "TEXT", "MEDIUMTEXT", "LONGTEXT", "ENUM", "SET", "JSON": + var s sql.NullString + return &s, nil + case "BIGINT": + var s sql.NullInt64 + return &s, nil + case "TINYINT", "SMALLINT", "MEDIUMINT", "INT": + var s sql.NullInt32 + return &s, nil + case "FLOAT", "REAL", "DOUBLE PRECISION", "DOUBLE": + var s sql.NullFloat64 + return &s, nil + case "DECIMAL", "NUMERIC": + var s sql.NullString + return &s, nil + case "DATETIME", "TIMESTAMP": + var s sql.NullTime + return &s, nil + case "BIT": + var s sql.RawBytes + return &s, nil + case "BINARY", "VARBINARY", "TINYBLOB", "BLOB", "MEDIUMBLOB", "LONGBLOB": + var r sql.RawBytes + return &r, nil + default: + var r sql.RawBytes + return &r, nil + } +} + +type mymysqlDriver struct { + mysqlDriver +} + +func (p *mymysqlDriver) Parse(driverName, dataSourceName string) (*URI, error) { + uri := &URI{DBType: schemas.MYSQL} + + pd := strings.SplitN(dataSourceName, "*", 2) + if len(pd) == 2 { + // Parse protocol part of URI + p := strings.SplitN(pd[0], ":", 2) + if len(p) != 2 { + return nil, errors.New("wrong protocol part of URI") + } + uri.Proto = p[0] + options := strings.Split(p[1], ",") + uri.Raddr = options[0] + for _, o := range options[1:] { + kv := strings.SplitN(o, "=", 2) + var k, v string + if len(kv) == 2 { + k, v = kv[0], kv[1] + } else { + k, v = o, "true" + } + switch k { + case "laddr": + uri.Laddr = v + case "timeout": + to, err := time.ParseDuration(v) + if err != nil { + return nil, err + } + uri.Timeout = to + default: + return nil, errors.New("unknown option: " + k) + } + } + // Remove protocol part + pd = pd[1:] + } + // Parse database part of URI + dup := strings.SplitN(pd[0], "/", 3) + if len(dup) != 3 { + return nil, errors.New("Wrong database part of URI") + } + uri.DBName = dup[0] + uri.User = dup[1] + uri.Passwd = dup[2] + + return uri, nil +} diff --git a/dialects/oracle.go b/dialects/oracle.go index 91eed251..72c26ce2 100644 --- a/dialects/oracle.go +++ b/dialects/oracle.go @@ -6,6 +6,7 @@ package dialects import ( "context" + "database/sql" "errors" "fmt" "regexp" @@ -515,10 +516,46 @@ func (db *oracle) Init(uri *URI) error { return db.Base.Init(db, uri) } +func (db *oracle) Version(ctx context.Context, queryer core.Queryer) (*schemas.Version, error) { + rows, err := queryer.QueryContext(ctx, "select * from v$version where banner like 'Oracle%'") + if err != nil { + return nil, err + } + defer rows.Close() + + var version string + if !rows.Next() { + if rows.Err() != nil { + return nil, rows.Err() + } + return nil, errors.New("unknow version") + } + + if err := rows.Scan(&version); err != nil { + return nil, err + } + return &schemas.Version{ + Number: version, + }, nil +} + +func (db *oracle) Features() *DialectFeatures { + return &DialectFeatures{ + AutoincrMode: SequenceAutoincrMode, + } +} + func (db *oracle) SQLType(c *schemas.Column) string { var res string switch t := c.SQLType.Name; t { - case schemas.Bit, schemas.TinyInt, schemas.SmallInt, schemas.MediumInt, schemas.Int, schemas.Integer, schemas.BigInt, schemas.Bool, schemas.Serial, schemas.BigSerial: + case schemas.Bool: + if c.Default == "true" { + c.Default = "1" + } else if c.Default == "false" { + c.Default = "0" + } + res = "NUMBER(1,0)" + case schemas.Bit, schemas.TinyInt, schemas.SmallInt, schemas.MediumInt, schemas.Int, schemas.Integer, schemas.BigInt, schemas.Serial, schemas.BigSerial: res = "NUMBER" case schemas.Binary, schemas.VarBinary, schemas.Blob, schemas.TinyBlob, schemas.MediumBlob, schemas.LongBlob, schemas.Bytea: return schemas.Blob @@ -540,13 +577,28 @@ func (db *oracle) SQLType(c *schemas.Column) string { hasLen2 := (c.Length2 > 0) if hasLen2 { - res += "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")" + res += "(" + strconv.FormatInt(c.Length, 10) + "," + strconv.FormatInt(c.Length2, 10) + ")" } else if hasLen1 { - res += "(" + strconv.Itoa(c.Length) + ")" + res += "(" + strconv.FormatInt(c.Length, 10) + ")" } return res } +func (db *oracle) ColumnTypeKind(t string) int { + switch strings.ToUpper(t) { + case "DATE": + return schemas.TIME_TYPE + case "CHAR", "NCHAR", "VARCHAR", "VARCHAR2", "NVARCHAR2", "LONG", "CLOB", "NCLOB": + return schemas.TEXT_TYPE + case "NUMBER": + return schemas.NUMERIC_TYPE + case "BLOB": + return schemas.BLOB_TYPE + default: + return schemas.UNKNOW_TYPE + } +} + func (db *oracle) AutoIncrStr() string { return "AUTO_INCREMENT" } @@ -560,8 +612,8 @@ func (db *oracle) DropTableSQL(tableName string) (string, bool) { return fmt.Sprintf("DROP TABLE `%s`", tableName), false } -func (db *oracle) CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) { - var sql = "CREATE TABLE " +func (db *oracle) CreateTableSQL(ctx context.Context, queryer core.Queryer, table *schemas.Table, tableName string) (string, bool, error) { + sql := "CREATE TABLE " if tableName == "" { tableName = table.Name } @@ -590,17 +642,17 @@ func (db *oracle) CreateTableSQL(table *schemas.Table, tableName string) ([]stri } sql = sql[:len(sql)-2] + ")" - return []string{sql}, false + return sql, false, nil } func (db *oracle) SetQuotePolicy(quotePolicy QuotePolicy) { switch quotePolicy { case QuotePolicyNone: - var q = oracleQuoter + q := oracleQuoter q.IsReserved = schemas.AlwaysNoReserve db.quoter = q case QuotePolicyReserved: - var q = oracleQuoter + q := oracleQuoter q.IsReserved = db.IsReserved db.quoter = q case QuotePolicyAlways: @@ -645,7 +697,7 @@ func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableNam col.Indexes = make(map[string]int) var colName, colDefault, nullable, dataType, dataPrecision, dataScale *string - var dataLen int + var dataLen int64 err = rows.Scan(&colName, &colDefault, &dataType, &dataLen, &dataPrecision, &dataScale, &nullable) @@ -668,16 +720,16 @@ func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableNam var ignore bool var dt string - var len1, len2 int + var len1, len2 int64 dts := strings.Split(*dataType, "(") dt = dts[0] if len(dts) > 1 { lens := strings.Split(dts[1][:len(dts[1])-1], ",") if len(lens) > 1 { - len1, _ = strconv.Atoi(lens[0]) - len2, _ = strconv.Atoi(lens[1]) + len1, _ = strconv.ParseInt(lens[0], 10, 64) + len2, _ = strconv.ParseInt(lens[1], 10, 64) } else { - len1, _ = strconv.Atoi(lens[0]) + len1, _ = strconv.ParseInt(lens[0], 10, 64) } } @@ -720,6 +772,9 @@ func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableNam cols[col.Name] = col colSeq = append(colSeq, col.Name) } + if rows.Err() != nil { + return nil, nil, rows.Err() + } return colSeq, cols, nil } @@ -744,6 +799,9 @@ func (db *oracle) GetTables(queryer core.Queryer, ctx context.Context) ([]*schem tables = append(tables, table) } + if rows.Err() != nil { + return nil, rows.Err() + } return tables, nil } @@ -758,7 +816,7 @@ func (db *oracle) GetIndexes(queryer core.Queryer, ctx context.Context, tableNam } defer rows.Close() - indexes := make(map[string]*schemas.Index, 0) + indexes := make(map[string]*schemas.Index) for rows.Next() { var indexType int var indexName, colName, uniqueness string @@ -793,6 +851,9 @@ func (db *oracle) GetIndexes(queryer core.Queryer, ctx context.Context, tableNam } index.AddColumn(colName) } + if rows.Err() != nil { + return nil, rows.Err() + } return indexes, nil } @@ -802,10 +863,17 @@ func (db *oracle) Filters() []Filter { } } -type goracleDriver struct { +type godrorDriver struct { + baseDriver } -func (cfg *goracleDriver) Parse(driverName, dataSourceName string) (*URI, error) { +func (g *godrorDriver) Features() *DriverFeatures { + return &DriverFeatures{ + SupportReturnInsertedID: false, + } +} + +func (g *godrorDriver) Parse(driverName, dataSourceName string) (*URI, error) { db := &URI{DBType: schemas.ORACLE} dsnPattern := regexp.MustCompile( `^(?:(?P.*?)(?::(?P.*))?@)?` + // [user[:password]@] @@ -817,8 +885,7 @@ func (cfg *goracleDriver) Parse(driverName, dataSourceName string) (*URI, error) names := dsnPattern.SubexpNames() for i, match := range matches { - switch names[i] { - case "dbname": + if names[i] == "dbname" { db.DBName = match } } @@ -828,12 +895,33 @@ func (cfg *goracleDriver) Parse(driverName, dataSourceName string) (*URI, error) return db, nil } +func (g *godrorDriver) GenScanResult(colType string) (interface{}, error) { + switch colType { + case "CHAR", "NCHAR", "VARCHAR", "VARCHAR2", "NVARCHAR2", "LONG", "CLOB", "NCLOB": + var s sql.NullString + return &s, nil + case "NUMBER": + var s sql.NullString + return &s, nil + case "DATE": + var s sql.NullTime + return &s, nil + case "BLOB": + var r sql.RawBytes + return &r, nil + default: + var r sql.RawBytes + return &r, nil + } +} + type oci8Driver struct { + godrorDriver } // dataSourceName=user/password@ipv4:port/dbname // dataSourceName=user/password@[ipv6]:port/dbname -func (p *oci8Driver) Parse(driverName, dataSourceName string) (*URI, error) { +func (o *oci8Driver) Parse(driverName, dataSourceName string) (*URI, error) { db := &URI{DBType: schemas.ORACLE} dsnPattern := regexp.MustCompile( `^(?P.*)\/(?P.*)@` + // user:password@ @@ -842,8 +930,7 @@ func (p *oci8Driver) Parse(driverName, dataSourceName string) (*URI, error) { matches := dsnPattern.FindStringSubmatch(dataSourceName) names := dsnPattern.SubexpNames() for i, match := range matches { - switch names[i] { - case "dbname": + if names[i] == "dbname" { db.DBName = match } } @@ -852,3 +939,7 @@ func (p *oci8Driver) Parse(driverName, dataSourceName string) (*URI, error) { } return db, nil } + +type oracleDriver struct { + godrorDriver +} diff --git a/dialects/postgres.go b/dialects/postgres.go index a2c0de74..5efe54f4 100644 --- a/dialects/postgres.go +++ b/dialects/postgres.go @@ -6,6 +6,7 @@ package dialects import ( "context" + "database/sql" "errors" "fmt" "net/url" @@ -777,17 +778,68 @@ var ( var ( // DefaultPostgresSchema default postgres schema DefaultPostgresSchema = "public" + postgresColAliases = map[string]string{ + "numeric": "decimal", + } ) type postgres struct { Base } +// Alias returns a alias of column +func (db *postgres) Alias(col string) string { + v, ok := postgresColAliases[strings.ToLower(col)] + if ok { + return v + } + return col +} + func (db *postgres) Init(uri *URI) error { db.quoter = postgresQuoter return db.Base.Init(db, uri) } +func (db *postgres) Version(ctx context.Context, queryer core.Queryer) (*schemas.Version, error) { + rows, err := queryer.QueryContext(ctx, "SELECT version()") + if err != nil { + return nil, err + } + defer rows.Close() + + var version string + if !rows.Next() { + if rows.Err() != nil { + return nil, rows.Err() + } + return nil, errors.New("unknow version") + } + + if err := rows.Scan(&version); err != nil { + return nil, err + } + + // Postgres: 9.5.22 on x86_64-pc-linux-gnu (Debian 9.5.22-1.pgdg90+1), compiled by gcc (Debian 6.3.0-18+deb9u1) 6.3.0 20170516, 64-bit + // CockroachDB CCL v19.2.4 (x86_64-unknown-linux-gnu, built + if strings.HasPrefix(version, "CockroachDB") { + versions := strings.Split(strings.TrimPrefix(version, "CockroachDB CCL "), " ") + return &schemas.Version{ + Number: strings.TrimPrefix(versions[0], "v"), + Edition: "CockroachDB", + }, nil + } else if strings.HasPrefix(version, "PostgreSQL") { + versions := strings.Split(strings.TrimPrefix(version, "PostgreSQL "), " on ") + return &schemas.Version{ + Number: versions[0], + Level: versions[1], + Edition: "PostgreSQL", + }, nil + } + + return nil, errors.New("unknow database version") +} + func (db *postgres) getSchema() string { if db.uri.Schema != "" { return db.uri.Schema @@ -810,11 +862,11 @@ func (db *postgres) needQuote(name string) bool { func (db *postgres) SetQuotePolicy(quotePolicy QuotePolicy) { switch quotePolicy { case QuotePolicyNone: - var q = postgresQuoter + q := postgresQuoter q.IsReserved = schemas.AlwaysNoReserve db.quoter = q case QuotePolicyReserved: - var q = postgresQuoter + q := postgresQuoter q.IsReserved = db.needQuote db.quoter = q case QuotePolicyAlways: @@ -827,18 +879,18 @@ func (db *postgres) SetQuotePolicy(quotePolicy QuotePolicy) { func (db *postgres) SQLType(c *schemas.Column) string { var res string switch t := c.SQLType.Name; t { - case schemas.TinyInt: + case schemas.TinyInt, schemas.UnsignedTinyInt: res = schemas.SmallInt return res case schemas.Bit: res = schemas.Boolean return res - case schemas.MediumInt, schemas.Int, schemas.Integer: + case schemas.MediumInt, schemas.Int, schemas.Integer, schemas.UnsignedMediumInt, schemas.UnsignedSmallInt: if c.IsAutoIncrement { return schemas.Serial } return schemas.Integer - case schemas.BigInt: + case schemas.BigInt, schemas.UnsignedBigInt, schemas.UnsignedInt: if c.IsAutoIncrement { return schemas.BigSerial } @@ -882,13 +934,34 @@ func (db *postgres) SQLType(c *schemas.Column) string { hasLen2 := (c.Length2 > 0) if hasLen2 { - res += "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")" + res += "(" + strconv.FormatInt(c.Length, 10) + "," + strconv.FormatInt(c.Length2, 10) + ")" } else if hasLen1 { - res += "(" + strconv.Itoa(c.Length) + ")" + res += "(" + strconv.FormatInt(c.Length, 10) + ")" } return res } +func (db *postgres) Features() *DialectFeatures { + return &DialectFeatures{ + AutoincrMode: IncrAutoincrMode, + } +} + +func (db *postgres) ColumnTypeKind(t string) int { + switch strings.ToUpper(t) { + case "DATETIME", "TIMESTAMP": + return schemas.TIME_TYPE + case "VARCHAR", "TEXT": + return schemas.TEXT_TYPE + case "BIGINT", "BIGSERIAL", "SMALLINT", "INT", "INT8", "INT4", "INTEGER", "SERIAL", "FLOAT", "FLOAT4", "REAL", "DOUBLE PRECISION": + return schemas.NUMERIC_TYPE + case "BOOL": + return schemas.BOOL_TYPE + default: + return schemas.UNKNOW_TYPE + } +} + func (db *postgres) IsReserved(name string) bool { _, ok := postgresReservedWords[strings.ToUpper(name)] return ok @@ -898,41 +971,6 @@ func (db *postgres) AutoIncrStr() string { return "" } -func (db *postgres) CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) { - var sql string - sql = "CREATE TABLE IF NOT EXISTS " - if tableName == "" { - tableName = table.Name - } - - quoter := db.Quoter() - sql += quoter.Quote(tableName) - sql += " (" - - if len(table.ColumnsSeq()) > 0 { - pkList := table.PrimaryKeys - - for _, colName := range table.ColumnsSeq() { - col := table.GetColumn(colName) - s, _ := ColumnString(db, col, col.IsPrimaryKey && len(pkList) == 1) - sql += s - sql = strings.TrimSpace(sql) - sql += ", " - } - - if len(pkList) > 1 { - sql += "PRIMARY KEY ( " - sql += quoter.Join(pkList, ",") - sql += " ), " - } - - sql = sql[:len(sql)-2] - } - sql += ")" - - return []string{sql}, true -} - func (db *postgres) IndexCheckSQL(tableName, idxName string) (string, []interface{}) { if len(db.getSchema()) == 0 { args := []interface{}{tableName, idxName} @@ -953,13 +991,37 @@ func (db *postgres) IsTableExist(queryer core.Queryer, ctx context.Context, tabl db.getSchema(), tableName) } -func (db *postgres) ModifyColumnSQL(tableName string, col *schemas.Column) string { +func (db *postgres) AddColumnSQL(tableName string, col *schemas.Column) string { + s, _ := ColumnString(db.dialect, col, true) + + quoter := db.dialect.Quoter() + addColumnSQL := "" + commentSQL := "; " if len(db.getSchema()) == 0 || strings.Contains(tableName, ".") { - return fmt.Sprintf("alter table %s ALTER COLUMN %s TYPE %s", - tableName, col.Name, db.SQLType(col)) + addColumnSQL = fmt.Sprintf("ALTER TABLE %s ADD %s", quoter.Quote(tableName), s) + commentSQL += fmt.Sprintf("COMMENT ON COLUMN %s.%s IS '%s'", quoter.Quote(tableName), quoter.Quote(col.Name), col.Comment) + return addColumnSQL + commentSQL } - return fmt.Sprintf("alter table %s.%s ALTER COLUMN %s TYPE %s", - db.getSchema(), tableName, col.Name, db.SQLType(col)) + + addColumnSQL = fmt.Sprintf("ALTER TABLE %s.%s ADD %s", quoter.Quote(db.getSchema()), quoter.Quote(tableName), s) + commentSQL += fmt.Sprintf("COMMENT ON COLUMN %s.%s.%s IS '%s'", quoter.Quote(db.getSchema()), quoter.Quote(tableName), quoter.Quote(col.Name), col.Comment) + return addColumnSQL + commentSQL +} + +func (db *postgres) ModifyColumnSQL(tableName string, col *schemas.Column) string { + quoter := db.dialect.Quoter() + modifyColumnSQL := "" + commentSQL := "; " + + if len(db.getSchema()) == 0 || strings.Contains(tableName, ".") { + modifyColumnSQL = fmt.Sprintf("ALTER TABLE %s ALTER COLUMN %s TYPE %s", quoter.Quote(tableName), quoter.Quote(col.Name), db.SQLType(col)) + commentSQL += fmt.Sprintf("COMMENT ON COLUMN %s.%s IS '%s'", quoter.Quote(tableName), quoter.Quote(col.Name), col.Comment) + return modifyColumnSQL + commentSQL + } + + modifyColumnSQL = fmt.Sprintf("ALTER TABLE %s.%s ALTER COLUMN %s TYPE %s", quoter.Quote(db.getSchema()), quoter.Quote(tableName), quoter.Quote(col.Name), db.SQLType(col)) + commentSQL += fmt.Sprintf("COMMENT ON COLUMN %s.%s.%s IS '%s'", quoter.Quote(db.getSchema()), quoter.Quote(tableName), quoter.Quote(col.Name), col.Comment) + return modifyColumnSQL + commentSQL } func (db *postgres) DropIndexSQL(tableName string, index *schemas.Index) string { @@ -968,11 +1030,10 @@ func (db *postgres) DropIndexSQL(tableName string, index *schemas.Index) string tableParts := strings.Split(strings.Replace(tableName, `"`, "", -1), ".") tableName = tableParts[len(tableParts)-1] - if !strings.HasPrefix(idxName, "UQE_") && - !strings.HasPrefix(idxName, "IDX_") { - if index.Type == schemas.UniqueType { + if index.IsRegular { + if index.Type == schemas.UniqueType && !strings.HasPrefix(idxName, "UQE_") { idxName = fmt.Sprintf("UQE_%v_%v", tableName, index.Name) - } else { + } else if index.Type == schemas.IndexType && !strings.HasPrefix(idxName, "IDX_") { idxName = fmt.Sprintf("IDX_%v_%v", tableName, index.Name) } } @@ -998,17 +1059,21 @@ func (db *postgres) IsColumnExist(queryer core.Queryer, ctx context.Context, tab } defer rows.Close() - return rows.Next(), nil + if rows.Next() { + return true, nil + } + return false, rows.Err() } func (db *postgres) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { args := []interface{}{tableName} - s := `SELECT column_name, column_default, is_nullable, data_type, character_maximum_length, + s := `SELECT column_name, column_default, is_nullable, data_type, character_maximum_length, description, CASE WHEN p.contype = 'p' THEN true ELSE false END AS primarykey, CASE WHEN p.contype = 'u' THEN true ELSE false END AS uniquekey FROM pg_attribute f JOIN pg_class c ON c.oid = f.attrelid JOIN pg_type t ON t.oid = f.atttypid LEFT JOIN pg_attrdef d ON d.adrelid = c.oid AND d.adnum = f.attnum + LEFT JOIN pg_description de ON f.attrelid=de.objoid AND f.attnum=de.objsubid LEFT JOIN pg_namespace n ON n.oid = c.relnamespace LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey) LEFT JOIN pg_class AS g ON p.confrelid = g.oid @@ -1037,25 +1102,29 @@ WHERE n.nspname= s.table_schema AND c.relkind = 'r'::char AND c.relname = $1%s A col.Indexes = make(map[string]int) var colName, isNullable, dataType string - var maxLenStr, colDefault *string + var maxLenStr, colDefault, description *string var isPK, isUnique bool - err = rows.Scan(&colName, &colDefault, &isNullable, &dataType, &maxLenStr, &isPK, &isUnique) + err = rows.Scan(&colName, &colDefault, &isNullable, &dataType, &maxLenStr, &description, &isPK, &isUnique) if err != nil { return nil, nil, err } - var maxLen int + var maxLen int64 if maxLenStr != nil { - maxLen, err = strconv.Atoi(*maxLenStr) + maxLen, err = strconv.ParseInt(*maxLenStr, 10, 64) if err != nil { return nil, nil, err } } + if colDefault != nil && *colDefault == "unique_rowid()" { // ignore the system column added by cockroach + continue + } + col.Name = strings.Trim(colName, `" `) if colDefault != nil { - var theDefault = *colDefault + theDefault := *colDefault // cockroach has type with the default value with ::: // and postgres with ::, we should remove them before store them idx := strings.Index(theDefault, ":::") @@ -1081,6 +1150,10 @@ WHERE n.nspname= s.table_schema AND c.relkind = 'r'::char AND c.relname = $1%s A col.DefaultIsEmpty = true } + if description != nil { + col.Comment = *description + } + if isPK { col.IsPrimaryKey = true } @@ -1112,14 +1185,14 @@ WHERE n.nspname= s.table_schema AND c.relkind = 'r'::char AND c.relname = $1%s A startIdx := strings.Index(strings.ToLower(dataType), "string(") if startIdx != -1 && strings.HasSuffix(dataType, ")") { length := dataType[startIdx+8 : len(dataType)-1] - l, _ := strconv.Atoi(length) + l, _ := strconv.ParseInt(length, 10, 64) col.SQLType = schemas.SQLType{Name: "STRING", DefaultLength: l, DefaultLength2: 0} } else { col.SQLType = schemas.SQLType{Name: strings.ToUpper(dataType), DefaultLength: 0, DefaultLength2: 0} } } if _, ok := schemas.SqlTypes[col.SQLType.Name]; !ok { - return nil, nil, fmt.Errorf("Unknown colType: %s - %s", dataType, col.SQLType.Name) + return nil, nil, fmt.Errorf("unknown colType: %s - %s", dataType, col.SQLType.Name) } col.Length = maxLen @@ -1127,19 +1200,20 @@ WHERE n.nspname= s.table_schema AND c.relkind = 'r'::char AND c.relname = $1%s A if !col.DefaultIsEmpty { if col.SQLType.IsText() { if strings.HasSuffix(col.Default, "::character varying") { - col.Default = strings.TrimRight(col.Default, "::character varying") + col.Default = strings.TrimSuffix(col.Default, "::character varying") } else if !strings.HasPrefix(col.Default, "'") { col.Default = "'" + col.Default + "'" } } else if col.SQLType.IsTime() { - if strings.HasSuffix(col.Default, "::timestamp without time zone") { - col.Default = strings.TrimRight(col.Default, "::timestamp without time zone") - } + col.Default = strings.TrimSuffix(col.Default, "::timestamp without time zone") } } cols[col.Name] = col colSeq = append(colSeq, col.Name) } + if rows.Err() != nil { + return nil, nil, rows.Err() + } return colSeq, cols, nil } @@ -1170,6 +1244,9 @@ func (db *postgres) GetTables(queryer core.Queryer, ctx context.Context) ([]*sch table.Name = name tables = append(tables, table) } + if rows.Err() != nil { + return nil, rows.Err() + } return tables, nil } @@ -1186,10 +1263,10 @@ func getIndexColName(indexdef string) []string { func (db *postgres) GetIndexes(queryer core.Queryer, ctx context.Context, tableName string) (map[string]*schemas.Index, error) { args := []interface{}{tableName} - s := fmt.Sprintf("SELECT indexname, indexdef FROM pg_indexes WHERE tablename=$1") + s := "SELECT indexname, indexdef FROM pg_indexes WHERE tablename=$1" if len(db.getSchema()) != 0 { args = append(args, db.getSchema()) - s = s + " AND schemaname=$2" + s += " AND schemaname=$2" } rows, err := queryer.QueryContext(ctx, s, args...) @@ -1198,7 +1275,7 @@ func (db *postgres) GetIndexes(queryer core.Queryer, ctx context.Context, tableN } defer rows.Close() - indexes := make(map[string]*schemas.Index, 0) + indexes := make(map[string]*schemas.Index) for rows.Next() { var indexType int var indexName, indexdef string @@ -1212,7 +1289,8 @@ func (db *postgres) GetIndexes(queryer core.Queryer, ctx context.Context, tableN continue } indexName = strings.Trim(indexName, `" `) - if strings.HasSuffix(indexName, "_pkey") { + // ignore primary index + if strings.HasSuffix(indexName, "_pkey") || strings.EqualFold(indexName, "primary") { continue } if strings.HasPrefix(indexdef, "CREATE UNIQUE INDEX") { @@ -1221,6 +1299,12 @@ func (db *postgres) GetIndexes(queryer core.Queryer, ctx context.Context, tableN indexType = schemas.IndexType } colNames = getIndexColName(indexdef) + + // Oid It's a special index. You can't put it in. TODO: This is not perfect. + if indexName == tableName+"_oid_index" && len(colNames) == 1 && colNames[0] == "oid" { + continue + } + var isRegular bool if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) { newIdxName := indexName[5+len(tableName):] @@ -1232,31 +1316,57 @@ func (db *postgres) GetIndexes(queryer core.Queryer, ctx context.Context, tableN index := &schemas.Index{Name: indexName, Type: indexType, Cols: make([]string, 0)} for _, colName := range colNames { - index.Cols = append(index.Cols, strings.TrimSpace(strings.Replace(colName, `"`, "", -1))) + col := strings.TrimSpace(strings.Replace(colName, `"`, "", -1)) + fields := strings.Split(col, " ") + index.Cols = append(index.Cols, fields[0]) } index.IsRegular = isRegular indexes[index.Name] = index } + if rows.Err() != nil { + return nil, rows.Err() + } return indexes, nil } +func (db *postgres) CreateTableSQL(ctx context.Context, queryer core.Queryer, table *schemas.Table, tableName string) (string, bool, error) { + quoter := db.dialect.Quoter() + if len(db.getSchema()) != 0 && !strings.Contains(tableName, ".") { + tableName = fmt.Sprintf("%s.%s", db.getSchema(), tableName) + } + + createTableSQL, ok, err := db.Base.CreateTableSQL(ctx, queryer, table, tableName) + if err != nil { + return "", ok, err + } + + commentSQL := "; " + if table.Comment != "" { + // support schema.table -> "schema"."table" + commentSQL += fmt.Sprintf("COMMENT ON TABLE %s IS '%s'; ", quoter.Quote(tableName), table.Comment) + } + + for _, colName := range table.ColumnsSeq() { + col := table.GetColumn(colName) + + if len(col.Comment) > 0 { + commentSQL += fmt.Sprintf("COMMENT ON COLUMN %s.%s IS '%s'; ", quoter.Quote(tableName), quoter.Quote(col.Name), col.Comment) + } + } + + return createTableSQL + commentSQL, true, nil +} + func (db *postgres) Filters() []Filter { return []Filter{&SeqFilter{Prefix: "$", Start: 1}} } type pqDriver struct { + baseDriver } type values map[string]string -func (vs values) Set(k, v string) { - vs[k] = v -} - -func (vs values) Get(k string) (v string) { - return vs[k] -} - func parseURL(connstr string) (string, error) { u, err := url.Parse(connstr) if err != nil { @@ -1276,30 +1386,94 @@ func parseURL(connstr string) (string, error) { return "", nil } -func parseOpts(name string, o values) error { - if len(name) == 0 { - return fmt.Errorf("invalid options: %s", name) +func parseOpts(urlStr string, o values) error { + if len(urlStr) == 0 { + return fmt.Errorf("invalid options: %s", urlStr) } - name = strings.TrimSpace(name) + urlStr = strings.TrimSpace(urlStr) - ps := strings.Split(name, " ") - for _, p := range ps { - kv := strings.Split(p, "=") - if len(kv) < 2 { - return fmt.Errorf("invalid option: %q", p) + var ( + inQuote bool + state int // 0 key, 1 space, 2 value, 3 equal + start int + key string + ) + for i, c := range urlStr { + switch c { + case ' ': + if !inQuote { + if state == 2 { + state = 1 + v := urlStr[start:i] + if strings.HasPrefix(v, "'") && strings.HasSuffix(v, "'") { + v = v[1 : len(v)-1] + } else if strings.HasPrefix(v, "'") || strings.HasSuffix(v, "'") { + return fmt.Errorf("wrong single quote in %d of %s", i, urlStr) + } + o[key] = v + } else if state != 1 { + return fmt.Errorf("wrong format: %v", urlStr) + } + } + case '\'': + if state == 3 { + state = 2 + start = i + } else if state != 2 { + return fmt.Errorf("wrong format: %v", urlStr) + } + inQuote = !inQuote + case '=': + if !inQuote { + if state != 0 { + return fmt.Errorf("wrong format: %v", urlStr) + } + key = urlStr[start:i] + state = 3 + } + default: + if state == 3 { + state = 2 + start = i + } else if state == 1 { + state = 0 + start = i + } + } + + if i == len(urlStr)-1 { + if state != 2 { + return errors.New("no value matched key") + } + v := urlStr[start : i+1] + if strings.HasPrefix(v, "'") && strings.HasSuffix(v, "'") { + v = v[1 : len(v)-1] + } else if strings.HasPrefix(v, "'") || strings.HasSuffix(v, "'") { + return fmt.Errorf("wrong single quote in %d of %s", i, urlStr) + } + o[key] = v } - o.Set(kv[0], kv[1]) } return nil } +func (p *pqDriver) Features() *DriverFeatures { + return &DriverFeatures{ + SupportReturnInsertedID: false, + } +} + func (p *pqDriver) Parse(driverName, dataSourceName string) (*URI, error) { db := &URI{DBType: schemas.POSTGRES} - var err error - if strings.HasPrefix(dataSourceName, "postgresql://") || strings.HasPrefix(dataSourceName, "postgres://") { + var err error + if strings.Contains(dataSourceName, "://") { + if !strings.HasPrefix(dataSourceName, "postgresql://") && !strings.HasPrefix(dataSourceName, "postgres://") { + return nil, fmt.Errorf("unsupported protocol %v", dataSourceName) + } + db.DBName, err = parseURL(dataSourceName) if err != nil { return nil, err @@ -1311,7 +1485,7 @@ func (p *pqDriver) Parse(driverName, dataSourceName string) (*URI, error) { return nil, err } - db.DBName = o.Get("dbname") + db.DBName = o["dbname"] } if db.DBName == "" { @@ -1321,6 +1495,32 @@ func (p *pqDriver) Parse(driverName, dataSourceName string) (*URI, error) { return db, nil } +func (p *pqDriver) GenScanResult(colType string) (interface{}, error) { + switch colType { + case "VARCHAR", "TEXT": + var s sql.NullString + return &s, nil + case "BIGINT", "BIGSERIAL": + var s sql.NullInt64 + return &s, nil + case "SMALLINT", "INT", "INT8", "INT4", "INTEGER", "SERIAL": + var s sql.NullInt32 + return &s, nil + case "FLOAT", "FLOAT4", "REAL", "DOUBLE PRECISION": + var s sql.NullFloat64 + return &s, nil + case "DATETIME", "TIMESTAMP": + var s sql.NullTime + return &s, nil + case "BOOL": + var s sql.NullBool + return &s, nil + default: + var r sql.RawBytes + return &r, nil + } +} + type pqDriverPgx struct { pqDriver } @@ -1348,6 +1548,9 @@ func QueryDefaultPostgresSchema(ctx context.Context, queryer core.Queryer) (stri parts := strings.Split(defaultSchema, ",") return strings.TrimSpace(parts[len(parts)-1]), nil } + if rows.Err() != nil { + return "", rows.Err() + } - return "", errors.New("No default schema") + return "", errors.New("no default schema") } diff --git a/dialects/postgres_test.go b/dialects/postgres_test.go index c0a8eb6f..bed8f307 100644 --- a/dialects/postgres_test.go +++ b/dialects/postgres_test.go @@ -22,20 +22,24 @@ func TestParsePostgres(t *testing.T) { //{"postgres://auser:パスワード@localhost:5432/データベース?sslmode=disable", "データベース", true}, {"dbname=db sslmode=disable", "db", true}, {"user=auser password=password dbname=db sslmode=disable", "db", true}, + {"user=auser password='pass word' dbname=db sslmode=disable", "db", true}, + {"user=auser password='pass word' sslmode=disable dbname='db'", "db", true}, + {"user=auser password='pass word' sslmode='disable dbname=db'", "db", false}, {"", "db", false}, {"dbname=db =disable", "db", false}, } driver := QueryDriver("postgres") - for _, test := range tests { - uri, err := driver.Parse("postgres", test.in) + t.Run(test.in, func(t *testing.T) { + uri, err := driver.Parse("postgres", test.in) - if err != nil && test.valid { - t.Errorf("%q got unexpected error: %s", test.in, err) - } else if err == nil && !reflect.DeepEqual(test.expected, uri.DBName) { - t.Errorf("%q got: %#v want: %#v", test.in, uri.DBName, test.expected) - } + if err != nil && test.valid { + t.Errorf("%q got unexpected error: %s", test.in, err) + } else if err == nil && !reflect.DeepEqual(test.expected, uri.DBName) { + t.Errorf("%q got: %#v want: %#v", test.in, uri.DBName, test.expected) + } + }) } } @@ -76,9 +80,7 @@ func TestParsePgx(t *testing.T) { } else if err == nil && !reflect.DeepEqual(test.expected, uri.DBName) { t.Errorf("%q got: %#v want: %#v", test.in, uri.DBName, test.expected) } - } - } func TestGetIndexColName(t *testing.T) { diff --git a/dialects/sqlite3.go b/dialects/sqlite3.go index 73f98beb..4ff9a39e 100644 --- a/dialects/sqlite3.go +++ b/dialects/sqlite3.go @@ -160,6 +160,36 @@ func (db *sqlite3) Init(uri *URI) error { return db.Base.Init(db, uri) } +func (db *sqlite3) Version(ctx context.Context, queryer core.Queryer) (*schemas.Version, error) { + rows, err := queryer.QueryContext(ctx, "SELECT sqlite_version()") + if err != nil { + return nil, err + } + defer rows.Close() + + var version string + if !rows.Next() { + if rows.Err() != nil { + return nil, rows.Err() + } + return nil, errors.New("unknow version") + } + + if err := rows.Scan(&version); err != nil { + return nil, err + } + return &schemas.Version{ + Number: version, + Edition: "sqlite", + }, nil +} + +func (db *sqlite3) Features() *DialectFeatures { + return &DialectFeatures{ + AutoincrMode: IncrAutoincrMode, + } +} + func (db *sqlite3) SetQuotePolicy(quotePolicy QuotePolicy) { switch quotePolicy { case QuotePolicyNone: @@ -193,7 +223,9 @@ func (db *sqlite3) SQLType(c *schemas.Column) string { case schemas.Char, schemas.Varchar, schemas.NVarchar, schemas.TinyText, schemas.Text, schemas.MediumText, schemas.LongText, schemas.Json: return schemas.Text - case schemas.Bit, schemas.TinyInt, schemas.SmallInt, schemas.MediumInt, schemas.Int, schemas.Integer, schemas.BigInt: + case schemas.Bit, schemas.TinyInt, schemas.UnsignedTinyInt, schemas.SmallInt, + schemas.UnsignedSmallInt, schemas.MediumInt, schemas.Int, schemas.UnsignedInt, + schemas.BigInt, schemas.UnsignedBigInt, schemas.Integer: return schemas.Integer case schemas.Float, schemas.Double, schemas.Real: return schemas.Real @@ -211,8 +243,19 @@ func (db *sqlite3) SQLType(c *schemas.Column) string { } } -func (db *sqlite3) FormatBytes(bs []byte) string { - return fmt.Sprintf("X'%x'", bs) +func (db *sqlite3) ColumnTypeKind(t string) int { + switch strings.ToUpper(t) { + case "DATETIME": + return schemas.TIME_TYPE + case "TEXT": + return schemas.TEXT_TYPE + case "INTEGER", "REAL", "NUMERIC", "DECIMAL": + return schemas.NUMERIC_TYPE + case "BLOB": + return schemas.BLOB_TYPE + default: + return schemas.UNKNOW_TYPE + } } func (db *sqlite3) IsReserved(name string) bool { @@ -248,41 +291,6 @@ func (db *sqlite3) DropIndexSQL(tableName string, index *schemas.Index) string { return fmt.Sprintf("DROP INDEX %v", db.Quoter().Quote(idxName)) } -func (db *sqlite3) CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) { - var sql string - sql = "CREATE TABLE IF NOT EXISTS " - if tableName == "" { - tableName = table.Name - } - - quoter := db.Quoter() - sql += quoter.Quote(tableName) - sql += " (" - - if len(table.ColumnsSeq()) > 0 { - pkList := table.PrimaryKeys - - for _, colName := range table.ColumnsSeq() { - col := table.GetColumn(colName) - s, _ := ColumnString(db, col, col.IsPrimaryKey && len(pkList) == 1) - sql += s - sql = strings.TrimSpace(sql) - sql += ", " - } - - if len(pkList) > 1 { - sql += "PRIMARY KEY ( " - sql += quoter.Join(pkList, ",") - sql += " ), " - } - - sql = sql[:len(sql)-2] - } - sql += ")" - - return []string{sql}, true -} - func (db *sqlite3) ForUpdateSQL(query string) string { return query } @@ -382,12 +390,14 @@ func (db *sqlite3) GetColumns(queryer core.Queryer, ctx context.Context, tableNa defer rows.Close() var name string - for rows.Next() { + if rows.Next() { err = rows.Scan(&name) if err != nil { return nil, nil, err } - break + } + if rows.Err() != nil { + return nil, nil, rows.Err() } if name == "" { @@ -450,6 +460,9 @@ func (db *sqlite3) GetTables(queryer core.Queryer, ctx context.Context) ([]*sche } tables = append(tables, table) } + if rows.Err() != nil { + return nil, rows.Err() + } return tables, nil } @@ -463,7 +476,7 @@ func (db *sqlite3) GetIndexes(queryer core.Queryer, ctx context.Context, tableNa } defer rows.Close() - indexes := make(map[string]*schemas.Index, 0) + indexes := make(map[string]*schemas.Index) for rows.Next() { var tmpSQL sql.NullString err = rows.Scan(&tmpSQL) @@ -483,7 +496,7 @@ func (db *sqlite3) GetIndexes(queryer core.Queryer, ctx context.Context, tableNa continue } - indexName := strings.Trim(sql[nNStart+6:nNEnd], "` []'\"") + indexName := strings.Trim(strings.TrimSpace(sql[nNStart+6:nNEnd]), "`[]'\"") var isRegular bool if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) { index.Name = indexName[5+len(tableName):] @@ -509,6 +522,9 @@ func (db *sqlite3) GetIndexes(queryer core.Queryer, ctx context.Context, tableNa index.IsRegular = isRegular indexes[index.Name] = index } + if rows.Err() != nil { + return nil, rows.Err() + } return indexes, nil } @@ -518,6 +534,13 @@ func (db *sqlite3) Filters() []Filter { } type sqlite3Driver struct { + baseDriver +} + +func (p *sqlite3Driver) Features() *DriverFeatures { + return &DriverFeatures{ + SupportReturnInsertedID: true, + } } func (p *sqlite3Driver) Parse(driverName, dataSourceName string) (*URI, error) { @@ -527,3 +550,29 @@ func (p *sqlite3Driver) Parse(driverName, dataSourceName string) (*URI, error) { return &URI{DBType: schemas.SQLITE, DBName: dataSourceName}, nil } + +func (p *sqlite3Driver) GenScanResult(colType string) (interface{}, error) { + switch colType { + case "TEXT": + var s sql.NullString + return &s, nil + case "INTEGER": + var s sql.NullInt64 + return &s, nil + case "DATETIME": + var s sql.NullTime + return &s, nil + case "REAL": + var s sql.NullFloat64 + return &s, nil + case "NUMERIC", "DECIMAL": + var s sql.NullString + return &s, nil + case "BLOB": + var s sql.RawBytes + return &s, nil + default: + var r sql.NullString + return &r, nil + } +} diff --git a/dialects/table_name.go b/dialects/table_name.go index e190cd4b..8a0baeac 100644 --- a/dialects/table_name.go +++ b/dialects/table_name.go @@ -11,14 +11,14 @@ import ( "xorm.io/xorm/internal/utils" "xorm.io/xorm/names" + "xorm.io/xorm/schemas" ) // TableNameWithSchema will add schema prefix on table name if possible func TableNameWithSchema(dialect Dialect, tableName string) string { // Add schema name as prefix of table name. // Only for postgres database. - if dialect.URI().Schema != "" && - strings.Index(tableName, ".") == -1 { + if dialect.URI().Schema != "" && !strings.Contains(tableName, ".") { return fmt.Sprintf("%s.%s", dialect.URI().Schema, tableName) } return tableName @@ -27,20 +27,21 @@ func TableNameWithSchema(dialect Dialect, tableName string) string { // TableNameNoSchema returns table name with given tableName func TableNameNoSchema(dialect Dialect, mapper names.Mapper, tableName interface{}) string { quote := dialect.Quoter().Quote - switch tableName.(type) { + switch tt := tableName.(type) { case []string: - t := tableName.([]string) - if len(t) > 1 { - return fmt.Sprintf("%v AS %v", quote(t[0]), quote(t[1])) - } else if len(t) == 1 { - return quote(t[0]) + if len(tt) > 1 { + if dialect.URI().DBType == schemas.ORACLE { + return fmt.Sprintf("%v %v", quote(tt[0]), quote(tt[1])) + } + return fmt.Sprintf("%v AS %v", quote(tt[0]), quote(tt[1])) + } else if len(tt) == 1 { + return quote(tt[0]) } case []interface{}: - t := tableName.([]interface{}) - l := len(t) + l := len(tt) var table string if l > 0 { - f := t[0] + f := tt[0] switch f.(type) { case string: table = f.(string) @@ -57,7 +58,10 @@ func TableNameNoSchema(dialect Dialect, mapper names.Mapper, tableName interface } } if l > 1 { - return fmt.Sprintf("%v AS %v", quote(table), quote(fmt.Sprintf("%v", t[1]))) + if dialect.URI().DBType == schemas.ORACLE { + return fmt.Sprintf("%v %v", quote(table), quote(fmt.Sprintf("%v", tt[1]))) + } + return fmt.Sprintf("%v AS %v", quote(table), quote(fmt.Sprintf("%v", tt[1]))) } else if l == 1 { return quote(table) } diff --git a/dialects/time.go b/dialects/time.go index b0394745..cdc896be 100644 --- a/dialects/time.go +++ b/dialects/time.go @@ -5,45 +5,59 @@ package dialects import ( + "strings" "time" "xorm.io/xorm/schemas" ) -// FormatTime format time as column type -func FormatTime(dialect Dialect, sqlTypeName string, t time.Time) (v interface{}) { - switch sqlTypeName { - case schemas.Time: - s := t.Format("2006-01-02 15:04:05") // time.RFC3339 - v = s[11:19] - case schemas.Date: - v = t.Format("2006-01-02") - case schemas.DateTime, schemas.TimeStamp, schemas.Varchar: // !DarthPestilane! format time when sqlTypeName is schemas.Varchar. - v = t.Format("2006-01-02 15:04:05") - case schemas.TimeStampz: - if dialect.URI().DBType == schemas.MSSQL { - v = t.Format("2006-01-02T15:04:05.9999999Z07:00") - } else { - v = t.Format(time.RFC3339Nano) - } - case schemas.BigInt, schemas.Int: - v = t.Unix() - default: - v = t - } - return -} - -func FormatColumnTime(dialect Dialect, defaultTimeZone *time.Location, col *schemas.Column, t time.Time) (v interface{}) { +// FormatColumnTime format column time +func FormatColumnTime(dialect Dialect, dbLocation *time.Location, col *schemas.Column, t time.Time) (interface{}, error) { if t.IsZero() { if col.Nullable { - return nil + return nil, nil + } + + if col.SQLType.IsNumeric() { + return 0, nil } - return "" } + tmZone := dbLocation if col.TimeZone != nil { - return FormatTime(dialect, col.SQLType.Name, t.In(col.TimeZone)) + tmZone = col.TimeZone + } + + t = t.In(tmZone) + + switch col.SQLType.Name { + case schemas.Date: + return t.Format("2006-01-02"), nil + case schemas.Time: + layout := "15:04:05" + if col.Length > 0 { + // we can use int(...) casting here as it's very unlikely to a huge sized field + layout += "." + strings.Repeat("0", int(col.Length)) + } + return t.Format(layout), nil + case schemas.DateTime, schemas.TimeStamp: + layout := "2006-01-02 15:04:05" + if col.Length > 0 { + // we can use int(...) casting here as it's very unlikely to a huge sized field + layout += "." + strings.Repeat("0", int(col.Length)) + } + return t.Format(layout), nil + case schemas.Varchar: + return t.Format("2006-01-02 15:04:05"), nil + case schemas.TimeStampz: + if dialect.URI().DBType == schemas.MSSQL { + return t.Format("2006-01-02T15:04:05.9999999Z07:00"), nil + } else { + return t.Format(time.RFC3339Nano), nil + } + case schemas.BigInt, schemas.Int: + return t.Unix(), nil + default: + return t, nil } - return FormatTime(dialect, col.SQLType.Name, t.In(defaultTimeZone)) } diff --git a/doc.go b/doc.go index ea6a2226..f88f5371 100644 --- a/doc.go +++ b/doc.go @@ -3,181 +3,246 @@ // license that can be found in the LICENSE file. /* - Package xorm is a simple and powerful ORM for Go. -Installation +# Installation Make sure you have installed Go 1.11+ and then: - go get xorm.io/xorm + go get xorm.io/xorm -Create Engine +# Create Engine -Firstly, we should new an engine for a database +Firstly, we should create an engine for a database - engine, err := xorm.NewEngine(driverName, dataSourceName) + engine, err := xorm.NewEngine(driverName, dataSourceName) -Method NewEngine's parameters is the same as sql.Open. It depends -drivers' implementation. -Generally, one engine for an application is enough. You can set it as package variable. +Method NewEngine's parameters are the same as sql.Open which depend drivers' implementation. +Generally, one engine for an application is enough. You can define it as a package variable. -Raw Methods +# Raw Methods -XORM also support raw SQL execution: +XORM supports raw SQL execution: -1. query a SQL string, the returned results is []map[string][]byte +1. query with a SQL string, the returned results is []map[string][]byte - results, err := engine.Query("select * from user") + results, err := engine.Query("select * from user") -2. execute a SQL string, the returned results +2. query with a SQL string, the returned results is []map[string]string - affected, err := engine.Exec("update user set .... where ...") + results, err := engine.QueryString("select * from user") -ORM Methods +3. query with a SQL string, the returned results is []map[string]interface{} + + results, err := engine.QueryInterface("select * from user") + +4. execute with a SQL string, the returned results + + affected, err := engine.Exec("update user set .... where ...") + +# ORM Methods There are 8 major ORM methods and many helpful methods to use to operate database. 1. Insert one or multiple records to database - affected, err := engine.Insert(&struct) - // INSERT INTO struct () values () - affected, err := engine.Insert(&struct1, &struct2) - // INSERT INTO struct1 () values () - // INSERT INTO struct2 () values () - affected, err := engine.Insert(&sliceOfStruct) - // INSERT INTO struct () values (),(),() - affected, err := engine.Insert(&struct1, &sliceOfStruct2) - // INSERT INTO struct1 () values () - // INSERT INTO struct2 () values (),(),() + affected, err := engine.Insert(&struct) + // INSERT INTO struct () values () + affected, err := engine.Insert(&struct1, &struct2) + // INSERT INTO struct1 () values () + // INSERT INTO struct2 () values () + affected, err := engine.Insert(&sliceOfStruct) + // INSERT INTO struct () values (),(),() + affected, err := engine.Insert(&struct1, &sliceOfStruct2) + // INSERT INTO struct1 () values () + // INSERT INTO struct2 () values (),(),() 2. Query one record or one variable from database - has, err := engine.Get(&user) - // SELECT * FROM user LIMIT 1 + has, err := engine.Get(&user) + // SELECT * FROM user LIMIT 1 - var id int64 - has, err := engine.Table("user").Where("name = ?", name).Get(&id) - // SELECT id FROM user WHERE name = ? LIMIT 1 + var id int64 + has, err := engine.Table("user").Where("name = ?", name).Get(&id) + // SELECT id FROM user WHERE name = ? LIMIT 1 + + var id int64 + var name string + has, err := engine.Table(&user).Cols("id", "name").Get(&id, &name) + // SELECT id, name FROM user LIMIT 1 3. Query multiple records from database - var sliceOfStructs []Struct - err := engine.Find(&sliceOfStructs) - // SELECT * FROM user + var sliceOfStructs []Struct + err := engine.Find(&sliceOfStructs) + // SELECT * FROM user - var mapOfStructs = make(map[int64]Struct) - err := engine.Find(&mapOfStructs) - // SELECT * FROM user + var mapOfStructs = make(map[int64]Struct) + err := engine.Find(&mapOfStructs) + // SELECT * FROM user - var int64s []int64 - err := engine.Table("user").Cols("id").Find(&int64s) - // SELECT id FROM user + var int64s []int64 + err := engine.Table("user").Cols("id").Find(&int64s) + // SELECT id FROM user 4. Query multiple records and record by record handle, there two methods, one is Iterate, another is Rows - err := engine.Iterate(...) - // SELECT * FROM user + err := engine.Iterate(new(User), func(i int, bean interface{}) error { + // do something + }) + // SELECT * FROM user - rows, err := engine.Rows(...) - // SELECT * FROM user - defer rows.Close() - bean := new(Struct) - for rows.Next() { - err = rows.Scan(bean) - } + rows, err := engine.Rows(...) + // SELECT * FROM user + defer rows.Close() + bean := new(Struct) + for rows.Next() { + err = rows.Scan(bean) + } + +or + + rows, err := engine.Cols("name", "age").Rows(...) + // SELECT * FROM user + defer rows.Close() + for rows.Next() { + var name string + var age int + err = rows.Scan(&name, &age) + } 5. Update one or more records - affected, err := engine.ID(...).Update(&user) - // UPDATE user SET ... + affected, err := engine.ID(...).Update(&user) + // UPDATE user SET ... 6. Delete one or more records, Delete MUST has condition - affected, err := engine.Where(...).Delete(&user) - // DELETE FROM user Where ... + affected, err := engine.Where(...).Delete(&user) + // DELETE FROM user Where ... 7. Count records - counts, err := engine.Count(&user) - // SELECT count(*) AS total FROM user + counts, err := engine.Count(&user) + // SELECT count(*) AS total FROM user - counts, err := engine.SQL("select count(*) FROM user").Count() - // select count(*) FROM user + counts, err := engine.SQL("select count(*) FROM user").Count() + // select count(*) FROM user 8. Sum records - sumFloat64, err := engine.Sum(&user, "id") - // SELECT sum(id) from user + sumFloat64, err := engine.Sum(&user, "id") + // SELECT sum(id) from user - sumFloat64s, err := engine.Sums(&user, "id1", "id2") - // SELECT sum(id1), sum(id2) from user + sumFloat64s, err := engine.Sums(&user, "id1", "id2") + // SELECT sum(id1), sum(id2) from user - sumInt64s, err := engine.SumsInt(&user, "id1", "id2") - // SELECT sum(id1), sum(id2) from user + sumInt64s, err := engine.SumsInt(&user, "id1", "id2") + // SELECT sum(id1), sum(id2) from user -Conditions +# Conditions The above 8 methods could use with condition methods chainable. -Attention: the above 8 methods should be the last chainable method. +Notice: the above 8 methods should be the last chainable method. 1. ID, In - engine.ID(1).Get(&user) // for single primary key - // SELECT * FROM user WHERE id = 1 - engine.ID(schemas.PK{1, 2}).Get(&user) // for composite primary keys - // SELECT * FROM user WHERE id1 = 1 AND id2 = 2 - engine.In("id", 1, 2, 3).Find(&users) - // SELECT * FROM user WHERE id IN (1, 2, 3) - engine.In("id", []int{1, 2, 3}).Find(&users) - // SELECT * FROM user WHERE id IN (1, 2, 3) + engine.ID(1).Get(&user) // for single primary key + // SELECT * FROM user WHERE id = 1 + engine.ID(schemas.PK{1, 2}).Get(&user) // for composite primary keys + // SELECT * FROM user WHERE id1 = 1 AND id2 = 2 + engine.In("id", 1, 2, 3).Find(&users) + // SELECT * FROM user WHERE id IN (1, 2, 3) + engine.In("id", []int{1, 2, 3}).Find(&users) + // SELECT * FROM user WHERE id IN (1, 2, 3) 2. Where, And, Or - engine.Where().And().Or().Find() - // SELECT * FROM user WHERE (.. AND ..) OR ... + engine.Where().And().Or().Find() + // SELECT * FROM user WHERE (.. AND ..) OR ... 3. OrderBy, Asc, Desc - engine.Asc().Desc().Find() - // SELECT * FROM user ORDER BY .. ASC, .. DESC - engine.OrderBy().Find() - // SELECT * FROM user ORDER BY .. + engine.Asc().Desc().Find() + // SELECT * FROM user ORDER BY .. ASC, .. DESC + engine.OrderBy().Find() + // SELECT * FROM user ORDER BY .. 4. Limit, Top - engine.Limit().Find() - // SELECT * FROM user LIMIT .. OFFSET .. - engine.Top(5).Find() - // SELECT TOP 5 * FROM user // for mssql - // SELECT * FROM user LIMIT .. OFFSET 0 //for other databases + engine.Limit().Find() + // SELECT * FROM user LIMIT .. OFFSET .. + engine.Top(5).Find() + // SELECT TOP 5 * FROM user // for mssql + // SELECT * FROM user LIMIT .. OFFSET 0 //for other databases 5. SQL, let you custom SQL - var users []User - engine.SQL("select * from user").Find(&users) + var users []User + engine.SQL("select * from user").Find(&users) 6. Cols, Omit, Distinct - var users []*User - engine.Cols("col1, col2").Find(&users) - // SELECT col1, col2 FROM user - engine.Cols("col1", "col2").Where().Update(user) - // UPDATE user set col1 = ?, col2 = ? Where ... - engine.Omit("col1").Find(&users) - // SELECT col2, col3 FROM user - engine.Omit("col1").Insert(&user) - // INSERT INTO table (non-col1) VALUES () - engine.Distinct("col1").Find(&users) - // SELECT DISTINCT col1 FROM user + var users []*User + engine.Cols("col1, col2").Find(&users) + // SELECT col1, col2 FROM user + engine.Cols("col1", "col2").Where().Update(user) + // UPDATE user set col1 = ?, col2 = ? Where ... + engine.Omit("col1").Find(&users) + // SELECT col2, col3 FROM user + engine.Omit("col1").Insert(&user) + // INSERT INTO table (non-col1) VALUES () + engine.Distinct("col1").Find(&users) + // SELECT DISTINCT col1 FROM user 7. Join, GroupBy, Having - engine.GroupBy("name").Having("name='xlw'").Find(&users) - //SELECT * FROM user GROUP BY name HAVING name='xlw' - engine.Join("LEFT", "userdetail", "user.id=userdetail.id").Find(&users) - //SELECT * FROM user LEFT JOIN userdetail ON user.id=userdetail.id + engine.GroupBy("name").Having("name='xlw'").Find(&users) + //SELECT * FROM user GROUP BY name HAVING name='xlw' + engine.Join("LEFT", "userdetail", "user.id=userdetail.id").Find(&users) + //SELECT * FROM user LEFT JOIN userdetail ON user.id=userdetail.id + +# Builder + +xorm could work with xorm.io/builder directly. + +1. With Where + + var cond = builder.Eq{"a":1, "b":2} + engine.Where(cond).Find(&users) + +2. With In + + var subQuery = builder.Select("name").From("group") + engine.In("group_name", subQuery).Find(&users) + +3. With Join + + var subQuery = builder.Select("name").From("group") + engine.Join("INNER", subQuery, "group.id = user.group_id").Find(&users) + +4. With SetExprs + + var subQuery = builder.Select("name").From("group") + engine.ID(1).SetExprs("name", subQuery).Update(new(User)) + +5. With SQL + + var query = builder.Select("name").From("group") + results, err := engine.SQL(query).Find(&groups) + +6. With Query + + var query = builder.Select("name").From("group") + results, err := engine.Query(query) + results, err := engine.QueryString(query) + results, err := engine.QueryInterface(query) + +7. With Exec + + var query = builder.Insert("a, b").Into("table1").Select("b, c").From("table2") + results, err := engine.Exec(query) More usage, please visit http://xorm.io/docs */ diff --git a/engine.go b/engine.go index 6c894e74..389819e7 100644 --- a/engine.go +++ b/engine.go @@ -7,11 +7,11 @@ package xorm import ( "context" "database/sql" - "errors" "fmt" "io" "os" "reflect" + "regexp" "runtime" "strconv" "strings" @@ -34,6 +34,7 @@ type Engine struct { cacherMgr *caches.Manager defaultContext context.Context dialect dialects.Dialect + driver dialects.Driver engineGroup *EngineGroup logger log.ContextLogger tagParser *tags.Parser @@ -71,6 +72,7 @@ func newEngine(driverName, dataSourceName string, dialect dialects.Dialect, db * engine := &Engine{ dialect: dialect, + driver: dialects.QueryDriver(driverName), TZLocation: time.Local, defaultContext: context.Background(), cacherMgr: cacherMgr, @@ -105,6 +107,15 @@ func NewEngineWithParams(driverName string, dataSourceName string, params map[st return engine, err } +// NewEngineWithDB new a db manager with db. The params will be passed to db. +func NewEngineWithDB(driverName string, dataSourceName string, db *core.DB) (*Engine, error) { + dialect, err := dialects.OpenDialect(driverName, dataSourceName) + if err != nil { + return nil, err + } + return newEngine(driverName, dataSourceName, dialect, db) +} + // NewEngineWithDialectAndDB new a db manager according to the parameter. // If you do not want to use your own dialect or db, please use NewEngine. // For creating dialect, you can call dialects.OpenDialect. And, for creating db, @@ -159,6 +170,8 @@ func (engine *Engine) SetLogger(logger interface{}) { realLogger = t case log.Logger: realLogger = log.NewLoggerAdapter(t) + default: + panic("logger should implement either log.ContextLogger or log.Logger") } engine.logger = realLogger engine.DB().Logger = realLogger @@ -200,6 +213,11 @@ func (engine *Engine) SetColumnMapper(mapper names.Mapper) { engine.tagParser.SetColumnMapper(mapper) } +// SetTagIdentifier set the tag identifier +func (engine *Engine) SetTagIdentifier(tagIdentifier string) { + engine.tagParser.SetIdentifier(tagIdentifier) +} + // Quote Use QuoteStr quote the string sql func (engine *Engine) Quote(value string) string { value = strings.TrimSpace(value) @@ -231,16 +249,16 @@ func (engine *Engine) SQLType(c *schemas.Column) string { return engine.dialect.SQLType(c) } -// AutoIncrStr Database's autoincrement statement -func (engine *Engine) AutoIncrStr() string { - return engine.dialect.AutoIncrStr() -} - // SetConnMaxLifetime sets the maximum amount of time a connection may be reused. func (engine *Engine) SetConnMaxLifetime(d time.Duration) { engine.DB().SetConnMaxLifetime(d) } +// SetConnMaxIdleTime sets the maximum amount of time a connection may be idle. +func (engine *Engine) SetConnMaxIdleTime(d time.Duration) { + engine.DB().SetConnMaxIdleTime(d) +} + // SetMaxOpenConns is only available for go 1.2+ func (engine *Engine) SetMaxOpenConns(conns int) { engine.DB().SetMaxOpenConns(conns) @@ -317,7 +335,7 @@ func (engine *Engine) Ping() error { // SQL method let's you manually write raw SQL and operate // For example: // -// engine.SQL("select * from user").Find(&users) +// engine.SQL("select * from user").Find(&users) // // This code will execute "select * from user" and set the records to users func (engine *Engine) SQL(query interface{}, args ...interface{}) *Session { @@ -359,13 +377,16 @@ func (engine *Engine) loadTableInfo(table *schemas.Table) error { var seq int for _, index := range indexes { for _, name := range index.Cols { - parts := strings.Split(name, " ") + parts := strings.Split(strings.TrimSpace(name), " ") if len(parts) > 1 { if parts[1] == "DESC" { seq = 1 + } else if parts[1] == "ASC" { + seq = 0 } } - if col := table.GetColumn(parts[0]); col != nil { + colName := strings.Trim(parts[0], `"`) + if col := table.GetColumn(colName); col != nil { col.Indexes[index.Name] = index.Type } else { return fmt.Errorf("Unknown col %s seq %d, in index %v of table %v, columns %v", name, seq, index.Name, table.Name, table.ColumnsSeq()) @@ -421,103 +442,47 @@ func (engine *Engine) DumpTablesToFile(tables []*schemas.Table, fp string, tp .. // DumpTables dump specify tables to io.Writer func (engine *Engine) DumpTables(tables []*schemas.Table, w io.Writer, tp ...schemas.DBType) error { - return engine.dumpTables(tables, w, tp...) + return engine.dumpTables(context.Background(), tables, w, tp...) } -func formatColumnValue(dstDialect dialects.Dialect, d interface{}, col *schemas.Column) string { - if d == nil { - return "NULL" - } - - if dq, ok := d.(bool); ok && (dstDialect.URI().DBType == schemas.SQLITE || - dstDialect.URI().DBType == schemas.MSSQL) { - if dq { +func formatBool(s bool, dstDialect dialects.Dialect) string { + if dstDialect.URI().DBType != schemas.POSTGRES { + if s { return "1" } return "0" } - - if col.SQLType.IsText() { - var v = fmt.Sprintf("%s", d) - return "'" + strings.Replace(v, "'", "''", -1) + "'" - } else if col.SQLType.IsTime() { - var v = fmt.Sprintf("%s", d) - if strings.HasSuffix(v, " +0000 UTC") { - return fmt.Sprintf("'%s'", v[0:len(v)-len(" +0000 UTC")]) - } else if strings.HasSuffix(v, " +0000 +0000") { - return fmt.Sprintf("'%s'", v[0:len(v)-len(" +0000 +0000")]) - } - return "'" + strings.Replace(v, "'", "''", -1) + "'" - } else if col.SQLType.IsBlob() { - if reflect.TypeOf(d).Kind() == reflect.Slice { - return fmt.Sprintf("%s", dstDialect.FormatBytes(d.([]byte))) - } else if reflect.TypeOf(d).Kind() == reflect.String { - return fmt.Sprintf("'%s'", d.(string)) - } - } else if col.SQLType.IsNumeric() { - switch reflect.TypeOf(d).Kind() { - case reflect.Slice: - if col.SQLType.Name == schemas.Bool { - return fmt.Sprintf("%v", strconv.FormatBool(d.([]byte)[0] != byte('0'))) - } - return fmt.Sprintf("%s", string(d.([]byte))) - case reflect.Int16, reflect.Int8, reflect.Int32, reflect.Int64, reflect.Int: - if col.SQLType.Name == schemas.Bool { - v := reflect.ValueOf(d).Int() > 0 - if dstDialect.URI().DBType == schemas.SQLITE { - if v { - return "1" - } - return "0" - } - return fmt.Sprintf("%v", strconv.FormatBool(v)) - } - return fmt.Sprintf("%v", d) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - if col.SQLType.Name == schemas.Bool { - v := reflect.ValueOf(d).Uint() > 0 - if dstDialect.URI().DBType == schemas.SQLITE { - if v { - return "1" - } - return "0" - } - return fmt.Sprintf("%v", strconv.FormatBool(v)) - } - return fmt.Sprintf("%v", d) - default: - return fmt.Sprintf("%v", d) - } - } - - s := fmt.Sprintf("%v", d) - if strings.Contains(s, ":") || strings.Contains(s, "-") { - if strings.HasSuffix(s, " +0000 UTC") { - return fmt.Sprintf("'%s'", s[0:len(s)-len(" +0000 UTC")]) - } - return fmt.Sprintf("'%s'", s) - } - return s + return strconv.FormatBool(s) } +var controlCharactersRe = regexp.MustCompile(`[\x00-\x1f\x7f]+`) + // dumpTables dump database all table structs and data to w with specify db type -func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...schemas.DBType) error { +func (engine *Engine) dumpTables(ctx context.Context, tables []*schemas.Table, w io.Writer, tp ...schemas.DBType) error { var dstDialect dialects.Dialect if len(tp) == 0 { dstDialect = engine.dialect } else { dstDialect = dialects.QueryDialect(tp[0]) if dstDialect == nil { - return errors.New("Unsupported database type") + return fmt.Errorf("unsupported database type %v", tp[0]) } uri := engine.dialect.URI() destURI := dialects.URI{ DBType: tp[0], DBName: uri.DBName, + // DO NOT SET SCHEMA HERE + } + if tp[0] == schemas.POSTGRES { + destURI.Schema = engine.dialect.URI().Schema + } + if err := dstDialect.Init(&destURI); err != nil { + return err } - dstDialect.Init(&destURI) } + cacherMgr := caches.NewManager() + dstTableCache := tags.NewParser("xorm", dstDialect, engine.GetTableMapper(), engine.GetColumnMapper(), cacherMgr) _, err := io.WriteString(w, fmt.Sprintf("/*Generated by xorm %s, from %s to %s*/\n\n", time.Now().In(engine.TZLocation).Format("2006-01-02 15:04:05"), engine.dialect.URI().DBType, dstDialect.URI().DBType)) @@ -525,10 +490,29 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch return err } + if dstDialect.URI().DBType == schemas.MYSQL { + // For MySQL set NO_BACKLASH_ESCAPES so that strings work properly + if _, err := io.WriteString(w, "SET sql_mode='NO_BACKSLASH_ESCAPES';\n"); err != nil { + return err + } + } + for i, table := range tables { - tableName := table.Name + dstTable := table + if table.Type != nil { + dstTable, err = dstTableCache.Parse(reflect.New(table.Type).Elem()) + if err != nil { + engine.logger.Errorf("Unable to infer table for %s in new dialect. Error: %v", table.Name) + dstTable = table + } + } + + dstTableName := dstTable.Name + quoter := dstDialect.Quoter().Quote + quotedDstTableName := quoter(dstTable.Name) if dstDialect.URI().Schema != "" { - tableName = fmt.Sprintf("%s.%s", dstDialect.URI().Schema, table.Name) + dstTableName = fmt.Sprintf("%s.%s", dstDialect.URI().Schema, dstTable.Name) + quotedDstTableName = fmt.Sprintf("%s.%s", quoter(dstDialect.URI().Schema), quoter(dstTable.Name)) } originalTableName := table.Name if engine.dialect.URI().Schema != "" { @@ -540,27 +524,43 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch return err } } - sqls, _ := dstDialect.CreateTableSQL(table, tableName) - for _, s := range sqls { - _, err = io.WriteString(w, s+";\n") + + if dstTable.AutoIncrement != "" && dstDialect.Features().AutoincrMode == dialects.SequenceAutoincrMode { + sqlstr, err := dstDialect.CreateSequenceSQL(ctx, engine.db, utils.SeqName(dstTableName)) + if err != nil { + return err + } + _, err = io.WriteString(w, sqlstr+";\n") if err != nil { return err } } - if len(table.PKColumns()) > 0 && dstDialect.URI().DBType == schemas.MSSQL { - fmt.Fprintf(w, "SET IDENTITY_INSERT [%s] ON;\n", table.Name) + + sqlstr, _, err := dstDialect.CreateTableSQL(ctx, engine.db, dstTable, dstTableName) + if err != nil { + return err + } + _, err = io.WriteString(w, sqlstr+";\n") + if err != nil { + return err } - for _, index := range table.Indexes { - _, err = io.WriteString(w, dstDialect.CreateIndexSQL(table.Name, index)+";\n") + if len(dstTable.PKColumns()) > 0 && dstDialect.URI().DBType == schemas.MSSQL { + fmt.Fprintf(w, "SET IDENTITY_INSERT [%s] ON;\n", dstTable.Name) + } + + for _, index := range dstTable.Indexes { + _, err = io.WriteString(w, dstDialect.CreateIndexSQL(dstTable.Name, index)+";\n") if err != nil { return err } } cols := table.ColumnsSeq() + dstCols := dstTable.ColumnsSeq() + colNames := engine.dialect.Quoter().Join(cols, ", ") - destColNames := dstDialect.Quoter().Join(cols, ", ") + destColNames := dstDialect.Quoter().Join(dstCols, ", ") rows, err := engine.DB().QueryContext(engine.defaultContext, "SELECT "+colNames+" FROM "+engine.Quote(originalTableName)) if err != nil { @@ -568,39 +568,261 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch } defer rows.Close() + types, err := rows.ColumnTypes() + if err != nil { + return err + } + + fields, err := rows.Columns() + if err != nil { + return err + } + + sess := engine.NewSession() + defer sess.Close() for rows.Next() { - dest := make([]interface{}, len(cols)) - err = rows.ScanSlice(&dest) + _, err = io.WriteString(w, "INSERT INTO "+quotedDstTableName+" ("+destColNames+") VALUES (") if err != nil { return err } - _, err = io.WriteString(w, "INSERT INTO "+dstDialect.Quoter().Quote(tableName)+" ("+destColNames+") VALUES (") + scanResults, err := sess.engine.scanStringInterface(rows, fields, types) if err != nil { return err } + for i, scanResult := range scanResults { + stp := schemas.SQLType{Name: types[i].DatabaseTypeName()} + s := scanResult.(*sql.NullString) + if !s.Valid { + if _, err = io.WriteString(w, "NULL"); err != nil { + return err + } + } else { + if table.Columns()[i].SQLType.IsBool() || stp.IsBool() || (dstDialect.URI().DBType == schemas.MSSQL && strings.EqualFold(stp.Name, schemas.Bit)) { + val, err := strconv.ParseBool(s.String) + if err != nil { + return err + } - var temp string - for i, d := range dest { - col := table.GetColumn(cols[i]) - if col == nil { - return errors.New("unknow column error") + if _, err = io.WriteString(w, formatBool(val, dstDialect)); err != nil { + return err + } + } else if stp.IsNumeric() { + if _, err = io.WriteString(w, s.String); err != nil { + return err + } + } else if sess.engine.dialect.URI().DBType == schemas.DAMENG && stp.IsTime() && len(s.String) == 25 { + r := strings.ReplaceAll(s.String[:19], "T", " ") + if _, err = io.WriteString(w, "'"+r+"'"); err != nil { + return err + } + } else if len(s.String) == 0 { + if _, err := io.WriteString(w, "''"); err != nil { + return err + } + } else if dstDialect.URI().DBType == schemas.POSTGRES { + if dstTable.Columns()[i].SQLType.IsBlob() { + // Postgres has the escape format and we should use that for bytea data + if _, err := fmt.Fprintf(w, "'\\x%x'", s.String); err != nil { + return err + } + } else { + // Postgres concatentates strings using || (NOTE: a NUL byte in a text segment will fail) + toCheck := strings.ReplaceAll(s.String, "'", "''") + for len(toCheck) > 0 { + loc := controlCharactersRe.FindStringIndex(toCheck) + if loc == nil { + if _, err := io.WriteString(w, "'"+toCheck+"'"); err != nil { + return err + } + break + } + if loc[0] > 0 { + if _, err := io.WriteString(w, "'"+toCheck[:loc[0]]+"' || "); err != nil { + return err + } + } + if _, err := io.WriteString(w, "e'"); err != nil { + return err + } + for i := loc[0]; i < loc[1]; i++ { + if _, err := fmt.Fprintf(w, "\\x%02x", toCheck[i]); err != nil { + return err + } + } + toCheck = toCheck[loc[1]:] + if len(toCheck) > 0 { + if _, err := io.WriteString(w, "' || "); err != nil { + return err + } + } else { + if _, err := io.WriteString(w, "'"); err != nil { + return err + } + } + } + } + } else if dstDialect.URI().DBType == schemas.MYSQL { + loc := controlCharactersRe.FindStringIndex(s.String) + if loc == nil { + if _, err := io.WriteString(w, "'"+strings.ReplaceAll(s.String, "'", "''")+"'"); err != nil { + return err + } + } else { + if _, err := io.WriteString(w, "CONCAT("); err != nil { + return err + } + toCheck := strings.ReplaceAll(s.String, "'", "''") + for len(toCheck) > 0 { + loc := controlCharactersRe.FindStringIndex(toCheck) + if loc == nil { + if _, err := io.WriteString(w, "'"+toCheck+"')"); err != nil { + return err + } + break + } + if loc[0] > 0 { + if _, err := io.WriteString(w, "'"+toCheck[:loc[0]]+"', "); err != nil { + return err + } + } + for i := loc[0]; i < loc[1]-1; i++ { + if _, err := io.WriteString(w, "CHAR("+strconv.Itoa(int(toCheck[i]))+"), "); err != nil { + return err + } + } + char := toCheck[loc[1]-1] + toCheck = toCheck[loc[1]:] + if len(toCheck) > 0 { + if _, err := io.WriteString(w, "CHAR("+strconv.Itoa(int(char))+"), "); err != nil { + return err + } + } else { + if _, err = io.WriteString(w, "CHAR("+strconv.Itoa(int(char))+"))"); err != nil { + return err + } + } + } + } + } else if dstDialect.URI().DBType == schemas.SQLITE { + if dstTable.Columns()[i].SQLType.IsBlob() { + // SQLite has its escape format + if _, err := fmt.Fprintf(w, "X'%x'", s.String); err != nil { + return err + } + } else { + // SQLite concatentates strings using || (NOTE: a NUL byte in a text segment will fail) + toCheck := strings.ReplaceAll(s.String, "'", "''") + for len(toCheck) > 0 { + loc := controlCharactersRe.FindStringIndex(toCheck) + if loc == nil { + if _, err := io.WriteString(w, "'"+toCheck+"'"); err != nil { + return err + } + break + } + if loc[0] > 0 { + if _, err := io.WriteString(w, "'"+toCheck[:loc[0]]+"' || "); err != nil { + return err + } + } + if _, err := fmt.Fprintf(w, "X'%x'", toCheck[loc[0]:loc[1]]); err != nil { + return err + } + toCheck = toCheck[loc[1]:] + if len(toCheck) > 0 { + if _, err := io.WriteString(w, " || "); err != nil { + return err + } + } + } + } + } else if dstDialect.URI().DBType == schemas.DAMENG || dstDialect.URI().DBType == schemas.ORACLE { + if dstTable.Columns()[i].SQLType.IsBlob() { + // ORACLE/DAMENG uses HEXTORAW + if _, err := fmt.Fprintf(w, "HEXTORAW('%x')", s.String); err != nil { + return err + } + } else { + // ORACLE/DAMENG concatentates strings in multiple ways but uses CHAR and has CONCAT + // (NOTE: a NUL byte in a text segment will fail) + if _, err := io.WriteString(w, "CONCAT("); err != nil { + return err + } + toCheck := strings.ReplaceAll(s.String, "'", "''") + for len(toCheck) > 0 { + loc := controlCharactersRe.FindStringIndex(toCheck) + if loc == nil { + if _, err := io.WriteString(w, "'"+toCheck+"')"); err != nil { + return err + } + break + } + if loc[0] > 0 { + if _, err := io.WriteString(w, "'"+toCheck[:loc[0]]+"', "); err != nil { + return err + } + } + for i := loc[0]; i < loc[1]-1; i++ { + if _, err := io.WriteString(w, "CHAR("+strconv.Itoa(int(toCheck[i]))+"), "); err != nil { + return err + } + } + char := toCheck[loc[1]-1] + toCheck = toCheck[loc[1]:] + if len(toCheck) > 0 { + if _, err := io.WriteString(w, "CHAR("+strconv.Itoa(int(char))+"), "); err != nil { + return err + } + } else { + if _, err = io.WriteString(w, "CHAR("+strconv.Itoa(int(char))+"))"); err != nil { + return err + } + } + } + } + } else if dstDialect.URI().DBType == schemas.MSSQL { + if dstTable.Columns()[i].SQLType.IsBlob() { + // MSSQL uses CONVERT(VARBINARY(MAX), '0xDEADBEEF', 1) + if _, err := fmt.Fprintf(w, "CONVERT(VARBINARY(MAX), '0x%x', 1)", s.String); err != nil { + return err + } + } else { + if _, err = io.WriteString(w, "N'"+strings.ReplaceAll(s.String, "'", "''")+"'"); err != nil { + return err + } + } + } else { + if _, err = io.WriteString(w, "'"+strings.ReplaceAll(s.String, "'", "''")+"'"); err != nil { + return err + } + } + } + if i < len(scanResults)-1 { + if _, err = io.WriteString(w, ","); err != nil { + return err + } } - temp += "," + formatColumnValue(dstDialect, d, col) } - _, err = io.WriteString(w, temp[1:]+");\n") + _, err = io.WriteString(w, ");\n") if err != nil { return err } } + if rows.Err() != nil { + return rows.Err() + } // FIXME: Hack for postgres if dstDialect.URI().DBType == schemas.POSTGRES && table.AutoIncrColumn() != nil { - _, err = io.WriteString(w, "SELECT setval('"+tableName+"_id_seq', COALESCE((SELECT MAX("+table.AutoIncrColumn().Name+") + 1 FROM "+dstDialect.Quoter().Quote(tableName)+"), 1), false);\n") + _, err = io.WriteString(w, "SELECT setval('"+dstTableName+"_id_seq', COALESCE((SELECT MAX("+table.AutoIncrColumn().Name+") + 1 FROM "+dstDialect.Quoter().Quote(dstTableName)+"), 1), false);\n") if err != nil { return err } } + // !datbeohbbh! if no error, manually close + rows.Close() + sess.Close() } return nil } @@ -782,9 +1004,8 @@ func (engine *Engine) Desc(colNames ...string) *Session { // Asc will generate "ORDER BY column1,column2 Asc" // This method can chainable use. // -// engine.Desc("name").Asc("age").Find(&users) -// // SELECT * FROM user ORDER BY name DESC, age ASC -// +// engine.Desc("name").Asc("age").Find(&users) +// // SELECT * FROM user ORDER BY name DESC, age ASC func (engine *Engine) Asc(colNames ...string) *Session { session := engine.NewSession() session.isAutoClose = true @@ -792,10 +1013,10 @@ func (engine *Engine) Asc(colNames ...string) *Session { } // OrderBy will generate "ORDER BY order" -func (engine *Engine) OrderBy(order string) *Session { +func (engine *Engine) OrderBy(order interface{}, args ...interface{}) *Session { session := engine.NewSession() session.isAutoClose = true - return session.OrderBy(order) + return session.OrderBy(order, args...) } // Prepare enables prepare statement @@ -806,7 +1027,7 @@ func (engine *Engine) Prepare() *Session { } // Join the join_operator should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN -func (engine *Engine) Join(joinOperator string, tablename interface{}, condition string, args ...interface{}) *Session { +func (engine *Engine) Join(joinOperator string, tablename interface{}, condition interface{}, args ...interface{}) *Session { session := engine.NewSession() session.isAutoClose = true return session.Join(joinOperator, tablename, condition, args...) @@ -826,15 +1047,9 @@ func (engine *Engine) Having(conditions string) *Session { return session.Having(conditions) } -// Table table struct -type Table struct { - *schemas.Table - Name string -} - -// IsValid if table is valid -func (t *Table) IsValid() bool { - return t.Table != nil && len(t.Name) > 0 +// DBVersion returns the database version +func (engine *Engine) DBVersion() (*schemas.Version, error) { + return engine.dialect.Version(engine.defaultContext, engine.db) } // TableInfo get table info according to bean's content @@ -911,104 +1126,13 @@ func (engine *Engine) UnMapType(t reflect.Type) { func (engine *Engine) Sync(beans ...interface{}) error { session := engine.NewSession() defer session.Close() - - for _, bean := range beans { - v := utils.ReflectValue(bean) - tableNameNoSchema := dialects.FullTableName(engine.dialect, engine.GetTableMapper(), bean) - table, err := engine.tagParser.ParseWithCache(v) - if err != nil { - return err - } - - isExist, err := session.Table(bean).isTableExist(tableNameNoSchema) - if err != nil { - return err - } - if !isExist { - err = session.createTable(bean) - if err != nil { - return err - } - } - /*isEmpty, err := engine.IsEmptyTable(bean) - if err != nil { - return err - }*/ - var isEmpty bool - if isEmpty { - err = session.dropTable(bean) - if err != nil { - return err - } - err = session.createTable(bean) - if err != nil { - return err - } - } else { - for _, col := range table.Columns() { - isExist, err := engine.dialect.IsColumnExist(engine.db, session.ctx, tableNameNoSchema, col.Name) - if err != nil { - return err - } - if !isExist { - if err := session.statement.SetRefBean(bean); err != nil { - return err - } - err = session.addColumn(col.Name) - if err != nil { - return err - } - } - } - - for name, index := range table.Indexes { - if err := session.statement.SetRefBean(bean); err != nil { - return err - } - if index.Type == schemas.UniqueType { - isExist, err := session.isIndexExist2(tableNameNoSchema, index.Cols, true) - if err != nil { - return err - } - if !isExist { - if err := session.statement.SetRefBean(bean); err != nil { - return err - } - - err = session.addUnique(tableNameNoSchema, name) - if err != nil { - return err - } - } - } else if index.Type == schemas.IndexType { - isExist, err := session.isIndexExist2(tableNameNoSchema, index.Cols, false) - if err != nil { - return err - } - if !isExist { - if err := session.statement.SetRefBean(bean); err != nil { - return err - } - - err = session.addIndex(tableNameNoSchema, name) - if err != nil { - return err - } - } - } else { - return errors.New("unknow index type") - } - } - } - } - return nil + return session.Sync(beans...) } // Sync2 synchronize structs to database tables +// Depricated func (engine *Engine) Sync2(beans ...interface{}) error { - s := engine.NewSession() - defer s.Close() - return s.Sync2(beans...) + return engine.Sync(beans...) } // CreateTables create tabls according bean @@ -1024,7 +1148,7 @@ func (engine *Engine) CreateTables(beans ...interface{}) error { for _, bean := range beans { err = session.createTable(bean) if err != nil { - session.Rollback() + _ = session.Rollback() return err } } @@ -1044,7 +1168,7 @@ func (engine *Engine) DropTables(beans ...interface{}) error { for _, bean := range beans { err = session.dropTable(bean) if err != nil { - session.Rollback() + _ = session.Rollback() return err } } @@ -1103,9 +1227,10 @@ func (engine *Engine) InsertOne(bean interface{}) (int64, error) { // Update records, bean's non-empty fields are updated contents, // condiBean' non-empty filds are conditions // CAUTION: -// 1.bool will defaultly be updated content nor conditions -// You should call UseBool if you have bool to use. -// 2.float32 & float64 may be not inexact as conditions +// +// 1.bool will defaultly be updated content nor conditions +// You should call UseBool if you have bool to use. +// 2.float32 & float64 may be not inexact as conditions func (engine *Engine) Update(bean interface{}, condiBeans ...interface{}) (int64, error) { session := engine.NewSession() defer session.Close() @@ -1113,18 +1238,27 @@ func (engine *Engine) Update(bean interface{}, condiBeans ...interface{}) (int64 } // Delete records, bean's non-empty fields are conditions -func (engine *Engine) Delete(bean interface{}) (int64, error) { +// At least one condition must be set. +func (engine *Engine) Delete(beans ...interface{}) (int64, error) { session := engine.NewSession() defer session.Close() - return session.Delete(bean) + return session.Delete(beans...) +} + +// Truncate records, bean's non-empty fields are conditions +// In contrast to Delete this method allows deletes without conditions. +func (engine *Engine) Truncate(beans ...interface{}) (int64, error) { + session := engine.NewSession() + defer session.Close() + return session.Truncate(beans...) } // Get retrieve one record from table, bean's non-empty fields // are conditions -func (engine *Engine) Get(bean interface{}) (bool, error) { +func (engine *Engine) Get(beans ...interface{}) (bool, error) { session := engine.NewSession() defer session.Close() - return session.Get(bean) + return session.Get(beans...) } // Exist returns true if the record exist otherwise return false @@ -1215,13 +1349,13 @@ func (engine *Engine) Import(r io.Reader) ([]sql.Result, error) { } // nowTime return current time -func (engine *Engine) nowTime(col *schemas.Column) (interface{}, time.Time) { +func (engine *Engine) nowTime(col *schemas.Column) (interface{}, time.Time, error) { t := time.Now() - var tz = engine.DatabaseTZ - if !col.DisableTimeZone && col.TimeZone != nil { - tz = col.TimeZone + result, err := dialects.FormatColumnTime(engine.dialect, engine.DatabaseTZ, col, t) + if err != nil { + return nil, time.Time{}, err } - return dialects.FormatTime(engine.dialect, col.SQLType.Name, t.In(tz)), t.In(engine.TZLocation) + return result, t.In(engine.TZLocation), nil } // GetColumnMapper returns the column name mapper @@ -1259,6 +1393,7 @@ func (engine *Engine) SetSchema(schema string) { engine.dialect.URI().SetSchema(schema) } +// AddHook adds a context Hook func (engine *Engine) AddHook(hook contexts.Hook) { engine.db.AddHook(hook) } @@ -1274,7 +1409,7 @@ func (engine *Engine) tbNameWithSchema(v string) string { return dialects.TableNameWithSchema(engine.dialect, v) } -// ContextHook creates a session with the context +// Context creates a session with the context func (engine *Engine) Context(ctx context.Context) *Session { session := engine.NewSession() session.isAutoClose = true diff --git a/engine_group.go b/engine_group.go index cdd9dd44..f2fe913d 100644 --- a/engine_group.go +++ b/engine_group.go @@ -79,7 +79,7 @@ func (eg *EngineGroup) Close() error { return nil } -// ContextHook returned a group session +// Context returned a group session func (eg *EngineGroup) Context(ctx context.Context) *Session { sess := eg.NewSession() sess.isAutoClose = true @@ -144,6 +144,7 @@ func (eg *EngineGroup) SetLogger(logger interface{}) { } } +// AddHook adds Hook func (eg *EngineGroup) AddHook(hook contexts.Hook) { eg.Engine.AddHook(hook) for i := 0; i < len(eg.slaves); i++ { @@ -167,6 +168,14 @@ func (eg *EngineGroup) SetMapper(mapper names.Mapper) { } } +// SetTagIdentifier set the tag identifier +func (eg *EngineGroup) SetTagIdentifier(tagIdentifier string) { + eg.Engine.SetTagIdentifier(tagIdentifier) + for i := 0; i < len(eg.slaves); i++ { + eg.slaves[i].SetTagIdentifier(tagIdentifier) + } +} + // SetMaxIdleConns set the max idle connections on pool, default is 2 func (eg *EngineGroup) SetMaxIdleConns(conns int) { eg.Engine.DB().SetMaxIdleConns(conns) @@ -228,3 +237,31 @@ func (eg *EngineGroup) Slave() *Engine { func (eg *EngineGroup) Slaves() []*Engine { return eg.slaves } + +// Query execcute a select SQL and return the result +func (eg *EngineGroup) Query(sqlOrArgs ...interface{}) (resultsSlice []map[string][]byte, err error) { + sess := eg.NewSession() + sess.isAutoClose = true + return sess.Query(sqlOrArgs...) +} + +// QueryInterface execcute a select SQL and return the result +func (eg *EngineGroup) QueryInterface(sqlOrArgs ...interface{}) ([]map[string]interface{}, error) { + sess := eg.NewSession() + sess.isAutoClose = true + return sess.QueryInterface(sqlOrArgs...) +} + +// QueryString execcute a select SQL and return the result +func (eg *EngineGroup) QueryString(sqlOrArgs ...interface{}) ([]map[string]string, error) { + sess := eg.NewSession() + sess.isAutoClose = true + return sess.QueryString(sqlOrArgs...) +} + +// Rows execcute a select SQL and return the result +func (eg *EngineGroup) Rows(bean interface{}) (*Rows, error) { + sess := eg.NewSession() + sess.isAutoClose = true + return sess.Rows(bean) +} diff --git a/go.mod b/go.mod index e0d22a24..7bde41ae 100644 --- a/go.mod +++ b/go.mod @@ -1,14 +1,21 @@ module xorm.io/xorm -go 1.11 +go 1.13 require ( - github.com/denisenkom/go-mssqldb v0.0.0-20200428022330-06a60b6afbbc - github.com/go-sql-driver/mysql v1.5.0 - github.com/lib/pq v1.7.0 - github.com/mattn/go-sqlite3 v1.14.0 - github.com/stretchr/testify v1.4.0 + gitee.com/travelliu/dm v1.8.11192 + github.com/denisenkom/go-mssqldb v0.10.0 + github.com/go-sql-driver/mysql v1.6.0 + github.com/goccy/go-json v0.8.1 + github.com/golang/snappy v0.0.4 // indirect + github.com/jackc/pgx/v4 v4.12.0 + github.com/json-iterator/go v1.1.12 + github.com/lib/pq v1.10.2 + github.com/mattn/go-sqlite3 v1.14.9 + github.com/shopspring/decimal v1.2.0 + github.com/stretchr/testify v1.7.0 github.com/syndtr/goleveldb v1.0.0 github.com/ziutek/mymysql v1.5.4 - xorm.io/builder v0.3.7 + modernc.org/sqlite v1.14.2 + xorm.io/builder v0.3.11-0.20220531020008-1bd24a7dc978 ) diff --git a/go.sum b/go.sum index 844dd094..8bdc9798 100644 --- a/go.sum +++ b/go.sum @@ -1,70 +1,663 @@ +cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= gitea.com/xorm/sqlfiddle v0.0.0-20180821085327-62ce714f951a h1:lSA0F4e9A2NcQSqGqTOXqu2aRi/XEQxDCBwM8yJtE6s= gitea.com/xorm/sqlfiddle v0.0.0-20180821085327-62ce714f951a/go.mod h1:EXuID2Zs0pAQhH8yz+DNjUbjppKQzKFAn28TMYPB6IU= -github.com/PuerkitoBio/goquery v1.5.1/go.mod h1:GsLWisAFVj4WgDibEWF4pvYnkVQBpKBKeU+7zCJoLcc= -github.com/andybalholm/cascadia v1.1.0/go.mod h1:GsXiBklL0woXo1j/WYWtSYYC4ouU9PqHO0sqidkEA4Y= -github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +gitee.com/travelliu/dm v1.8.11192 h1:aqJT0xhodZjRutIfEXxKYv0CxqmHUHzsbz6SFaRL6OY= +gitee.com/travelliu/dm v1.8.11192/go.mod h1:DHTzyhCrM843x9VdKVbZ+GKXGRbKM2sJ4LxihRxShkE= +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/Knetic/govaluate v3.0.1-0.20171022003610-9aa49832a739+incompatible/go.mod h1:r7JcOSlj0wfOMncg0iLm8Leh48TZaKVeNIfJntJ2wa0= +github.com/Masterminds/semver/v3 v3.1.1 h1:hLg3sBzpNErnxhQtUy/mmLR2I9foDujNK030IGemrRc= +github.com/Masterminds/semver/v3 v3.1.1/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0cBrbBpGY/8hQs= +github.com/Shopify/sarama v1.19.0/go.mod h1:FVkBWblsNy7DGZRfXLU0O9RCGt5g3g3yEuWXgklEdEo= +github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI= +github.com/VividCortex/gohistogram v1.0.0/go.mod h1:Pf5mBqqDxYaXu3hDrrU+w6nw50o/4+TcAqDqk/vUH7g= +github.com/afex/hystrix-go v0.0.0-20180502004556-fa1af6a1f4f5/go.mod h1:SkGFH1ia65gfNATL8TAiHDNxPzPdmEL5uirI2Uyuz6c= +github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= +github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= +github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= +github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= +github.com/apache/thrift v0.12.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ= +github.com/apache/thrift v0.13.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ= +github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o= +github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY= +github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= +github.com/aryann/difflib v0.0.0-20170710044230-e206f873d14a/go.mod h1:DAHtR1m6lCRdSC2Tm3DSWRPvIPr6xNKyeHdqDQSQT+A= +github.com/aws/aws-lambda-go v1.13.3/go.mod h1:4UKl9IzQMoD+QF79YdCuzCwp8VbmG4VAQwij/eHl5CU= +github.com/aws/aws-sdk-go v1.27.0/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo= +github.com/aws/aws-sdk-go-v2 v0.18.0/go.mod h1:JWVYvqSMppoMJC0x5wdwiImzgXTI9FuZwxzkQq9wy+g= +github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= +github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= +github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs= +github.com/casbin/casbin/v2 v2.1.2/go.mod h1:YcPU1XXisHhLzuxH9coDNf2FbKpjGlbCg3n9yuLkIJQ= +github.com/cenkalti/backoff v2.2.1+incompatible/go.mod h1:90ReRw6GdpyfrHakVjL/QHaoyV4aDUVVkXQJJJ3NXXM= +github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= +github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/clbanning/x2j v0.0.0-20191024224557-825249438eec/go.mod h1:jMjuTZXRI4dUb/I5gc9Hdhagfvm9+RyrPryS/auMzxE= +github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= +github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I= +github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= +github.com/cockroachdb/datadriven v0.0.0-20190809214429-80d97fb3cbaa/go.mod h1:zn76sxSg3SzpJ0PPJaLDCu+Bu0Lg3sKTORVIj19EIF8= +github.com/codahale/hdrhistogram v0.0.0-20161010025455-3a0bb77429bd/go.mod h1:sE/e/2PUdi/liOCUjSTXgM1o87ZssimdTWN964YiIeI= +github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= +github.com/coreos/go-systemd v0.0.0-20180511133405-39ca1b05acc7/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/coreos/pkg v0.0.0-20160727233714-3ac0863d7acf/go.mod h1:E3G3o1h8I7cfcXa63jLwjI0eiQQMgzzUDFVpN/nH/eA= +github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= +github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/denisenkom/go-mssqldb v0.0.0-20200428022330-06a60b6afbbc h1:VRRKCwnzqk8QCaRC4os14xoKDdbHqqlJtJA0oc1ZAjg= -github.com/denisenkom/go-mssqldb v0.0.0-20200428022330-06a60b6afbbc/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/denisenkom/go-mssqldb v0.10.0 h1:QykgLZBorFE95+gO3u9esLd0BmbvpWp0/waNNZfHBM8= +github.com/denisenkom/go-mssqldb v0.10.0/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU= +github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= +github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= +github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo= +github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= +github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs= +github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU= +github.com/eapache/queue v1.1.0/go.mod h1:6eCeP0CKFpHLu8blIFXhExK/dRa7WDZfr6jVFPTqq+I= +github.com/edsrzf/mmap-go v1.0.0/go.mod h1:YO35OhQPt3KJa3ryjFM5Bs14WD66h8eGKpfaBNrHW5M= +github.com/envoyproxy/go-control-plane v0.6.9/go.mod h1:SBwIajubJHhxtWwsL9s8ss4safvEdbitLhGGK48rN6g= +github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= +github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= +github.com/franela/goblin v0.0.0-20200105215937-c9ffbefa60db/go.mod h1:7dvUGVsVBjqR7JHJk0brhHOZYGmfBYOrK0ZhYMEtBr4= +github.com/franela/goreq v0.0.0-20171204163338-bcd34c9993f8/go.mod h1:ZhphrRTfi2rbfLwlschooIH4+wKKDR4Pdxhh+TRoA20= github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= -github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs= -github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= +github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= +github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= +github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= +github.com/go-kit/kit v0.10.0/go.mod h1:xUsJbQ/Fp4kEt7AFgCuvyX4a71u8h9jB8tj/ORgOZ7o= +github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= +github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= +github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= +github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= +github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE= +github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= +github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/goccy/go-json v0.8.1 h1:4/Wjm0JIJaTDm8K1KcGrLHJoa8EsJ13YWeX+6Kfq6uI= +github.com/goccy/go-json v0.8.1/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= +github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= +github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw= +github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= +github.com/gogo/googleapis v1.1.0/go.mod h1:gf4bu3Q80BeJ6H1S1vYPm8/ELATdvryBaNFGgqEef3s= +github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= +github.com/gogo/protobuf v1.2.0/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= +github.com/gogo/protobuf v1.2.1/go.mod h1:hp+jE20tsWTFYpLwKvXlhS1hjn+gTNwPg2I6zVXpSg4= github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe h1:lXe2qZdvpiX5WZkZR4hgp4KJVfY3nMkvmwbVkpv1rVY= github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= -github.com/golang/protobuf v1.2.0 h1:P3YflyNX/ehuJFLhxviNdFxQPkGK5cDcApsge1SqnvM= +github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= +github.com/golang/groupcache v0.0.0-20160516000752-02826c3e7903/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= +github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= +github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db h1:woRePGFeVFfLKN/pOkfl+p/TAqKOfFu+7KPlMVpok/w= +github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs= +github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= +github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= +github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= +github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= +github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.5.3 h1:x95R7cp+rSeeqAMI2knLtQ0DKlaBhv2NrtrOvafPHRo= +github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= +github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= +github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= +github.com/gorilla/mux v1.6.2/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= +github.com/gorilla/mux v1.7.3/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= +github.com/gorilla/websocket v0.0.0-20170926233335-4201258b820c/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= +github.com/grpc-ecosystem/go-grpc-middleware v1.0.1-0.20190118093823-f849b5445de4/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs= +github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgfV/d3M/q6VIi02HzZEHgUlZvzk= +github.com/grpc-ecosystem/grpc-gateway v1.9.5/go.mod h1:vNeuVxBJEsws4ogUvrchl83t/GYV9WGTSLVdBhOQFDY= +github.com/hashicorp/consul/api v1.3.0/go.mod h1:MmDNSzIMUjNpY/mQ398R4bk2FnqQLoPndWW5VkKPlCE= +github.com/hashicorp/consul/sdk v0.3.0/go.mod h1:VKf9jXwCTEY1QZP2MOLRhb5i/I/ssyNV1vwHyQBF0x8= +github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/go-cleanhttp v0.5.1/go.mod h1:JpRdi6/HCYpAwUzNwuwqhbovhLtngrth3wmdIIUrZ80= +github.com/hashicorp/go-immutable-radix v1.0.0/go.mod h1:0y9vanUI8NX6FsYoO3zeMjhV/C5i9g4Q3DwcSNZ4P60= +github.com/hashicorp/go-msgpack v0.5.3/go.mod h1:ahLV/dePpqEmjfWmKiqvPkv/twdG7iPBM1vqhUKIvfM= +github.com/hashicorp/go-multierror v1.0.0/go.mod h1:dHtQlpGsu+cZNNAkkCN/P3hoUDHhCYQXV3UM06sGGrk= +github.com/hashicorp/go-rootcerts v1.0.0/go.mod h1:K6zTfqpRlCUIjkwsN4Z+hiSfzSTQa6eBIzfwKfwNnHU= +github.com/hashicorp/go-sockaddr v1.0.0/go.mod h1:7Xibr9yA9JjQq1JpNB2Vw7kxv8xerXegt+ozgdvDeDU= +github.com/hashicorp/go-syslog v1.0.0/go.mod h1:qPfqrKkXGihmCqbJM2mZgkZGvKG1dFdvsLplgctolz4= +github.com/hashicorp/go-uuid v1.0.0/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= +github.com/hashicorp/go-uuid v1.0.1/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= +github.com/hashicorp/go-version v1.2.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= +github.com/hashicorp/go.net v0.0.1/go.mod h1:hjKkEWcCURg++eb33jQU7oqQcI9XDCnUzHA0oac0k90= +github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= +github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= +github.com/hashicorp/logutils v1.0.0/go.mod h1:QIAnNjmIWmVIIkWDTG1z5v++HQmx9WQRO+LraFDTW64= +github.com/hashicorp/mdns v1.0.0/go.mod h1:tL+uN++7HEJ6SQLQ2/p+z2pH24WQKWjBPkE0mNTz8vQ= +github.com/hashicorp/memberlist v0.1.3/go.mod h1:ajVTdAv/9Im8oMAAj5G31PhhMCZJV2pPBoIllUwCN7I= +github.com/hashicorp/serf v0.8.2/go.mod h1:6hOLApaqBFA1NXqRQAsxw9QxuDEvNxSQRwA/JwenrHc= github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= -github.com/lib/pq v1.7.0 h1:h93mCPfUSkaul3Ka/VG8uZdmW1uMHDGxzu0NWHuJmHY= -github.com/lib/pq v1.7.0/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= -github.com/mattn/go-sqlite3 v1.14.0 h1:mLyGNKR8+Vv9CAU7PphKa2hkEqxxhn8i32J6FPj1/QA= -github.com/mattn/go-sqlite3 v1.14.0/go.mod h1:JIl7NbARA7phWnGvh0LKTyg7S9BA+6gx71ShQilpsus= +github.com/hudl/fargo v1.3.0/go.mod h1:y3CKSmjA+wD2gak7sUSXTAoopbhU08POFhmITJgmKTg= +github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= +github.com/influxdata/influxdb1-client v0.0.0-20191209144304-8bf82d3c094d/go.mod h1:qj24IKcXYK6Iy9ceXlo3Tc+vtHo9lIhSX5JddghvEPo= +github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= +github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= +github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= +github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8= +github.com/jackc/chunkreader/v2 v2.0.1/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= +github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3/go.mod h1:jkELnwuX+w9qN5YIfX0fl88Ehu4XC3keFuOJJk9pcnA= +github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb/go.mod h1:lLjNuW/+OfW9/pnVKPazfWOgNfH2aPem8YQ7ilXGvJE= +github.com/jackc/pgconn v0.0.0-20190831204454-2fabfa3c18b7/go.mod h1:ZJKsE/KZfsUgOEh9hBm+xYTstcNHg7UPMVJqRfQxq4s= +github.com/jackc/pgconn v1.4.0/go.mod h1:Y2O3ZDF0q4mMacyWV3AstPJpeHXWGEetiFttmq5lahk= +github.com/jackc/pgconn v1.5.0/go.mod h1:QeD3lBfpTFe8WUnPZWN5KY/mB8FGMIYRdd8P8Jr0fAI= +github.com/jackc/pgconn v1.5.1-0.20200601181101-fa742c524853/go.mod h1:QeD3lBfpTFe8WUnPZWN5KY/mB8FGMIYRdd8P8Jr0fAI= +github.com/jackc/pgconn v1.8.0/go.mod h1:1C2Pb36bGIP9QHGBYCjnyhqu7Rv3sGshaQUvmfGIB/o= +github.com/jackc/pgconn v1.8.1/go.mod h1:JV6m6b6jhjdmzchES0drzCcYcAHS1OPD5xu3OZ/lE2g= +github.com/jackc/pgconn v1.9.0 h1:gqibKSTJup/ahCsNKyMZAniPuZEfIqfXFc8FOWVYR+Q= +github.com/jackc/pgconn v1.9.0/go.mod h1:YctiPyvzfU11JFxoXokUOOKQXQmDMoJL9vJzHH8/2JY= +github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= +github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= +github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= +github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd h1:eDErF6V/JPJON/B7s68BxwHgfmyOntHJQ8IOaz0x4R8= +github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd/go.mod h1:hrBW0Enj2AZTNpt/7Y5rr2xe/9Mn757Wtb2xeBzPv2c= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A= +github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= +github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= +github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= +github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= +github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= +github.com/jackc/pgproto3/v2 v2.0.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgproto3/v2 v2.0.6/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgproto3/v2 v2.1.1 h1:7PQ/4gLoqnl87ZxL7xjO0DR5gYuviDCZxQJsUlFW1eI= +github.com/jackc/pgproto3/v2 v2.1.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgservicefile v0.0.0-20200307190119-3430c5407db8/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= +github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg= +github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= +github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= +github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc= +github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw= +github.com/jackc/pgtype v1.2.0/go.mod h1:5m2OfMh1wTK7x+Fk952IDmI4nw3nPrvtQdM0ZT4WpC0= +github.com/jackc/pgtype v1.3.1-0.20200510190516-8cd94a14c75a/go.mod h1:vaogEUkALtxZMCH411K+tKzNpwzCKU+AnPzBKZ+I+Po= +github.com/jackc/pgtype v1.3.1-0.20200606141011-f6355165a91c/go.mod h1:cvk9Bgu/VzJ9/lxTO5R5sf80p0DiucVtN7ZxvaC4GmQ= +github.com/jackc/pgtype v1.7.0/go.mod h1:ZnHF+rMePVqDKaOfJVI4Q8IVvAQMryDlDkZnKOI75BE= +github.com/jackc/pgtype v1.8.0 h1:iFVCcVhYlw0PulYCVoguRGm0SE9guIcPcccnLzHj8bA= +github.com/jackc/pgtype v1.8.0/go.mod h1:PqDKcEBtllAtk/2p6z6SHdXW5UB+MhE75tUol2OKexE= +github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y= +github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9WuGR0JG/JseM9irFbnEPbuWV2EELPNuM= +github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc= +github.com/jackc/pgx/v4 v4.5.0/go.mod h1:EpAKPLdnTorwmPUUsqrPxy5fphV18j9q3wrfRXgo+kA= +github.com/jackc/pgx/v4 v4.6.1-0.20200510190926-94ba730bb1e9/go.mod h1:t3/cdRQl6fOLDxqtlyhe9UWgfIi9R8+8v8GKV5TRA/o= +github.com/jackc/pgx/v4 v4.6.1-0.20200606145419-4e5062306904/go.mod h1:ZDaNWkt9sW1JMiNn0kdYBaLelIhw7Pg4qd+Vk6tw7Hg= +github.com/jackc/pgx/v4 v4.11.0/go.mod h1:i62xJgdrtVDsnL3U8ekyrQXEwGNTRoG7/8r+CIdYfcc= +github.com/jackc/pgx/v4 v4.12.0 h1:xiP3TdnkwyslWNp77yE5XAPfxAsU9RMFDe0c1SwN8h4= +github.com/jackc/pgx/v4 v4.12.0/go.mod h1:fE547h6VulLPA3kySjfnSG/e2D861g/50JlVUa/ub60= +github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jackc/puddle v1.1.0/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jackc/puddle v1.1.1/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= +github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo= +github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= +github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= +github.com/json-iterator/go v1.1.8/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= +github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= +github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= +github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= +github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs= +github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8= +github.com/kisielk/errcheck v1.1.0/go.mod h1:EZBBE59ingxPouuu3KfxchcWSUPOHkagtvWXihfKN4Q= +github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.3.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.10.2 h1:AqzbZs4ZoCBp+GtejcpCpcxM3zlSMx29dXbUSeVtJb8= +github.com/lib/pq v1.10.2/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/lightstep/lightstep-tracer-common/golang/gogo v0.0.0-20190605223551-bc2310a04743/go.mod h1:qklhhLq1aX+mtWk9cPHPzaBjWImj5ULL6C7HFJtXQMM= +github.com/lightstep/lightstep-tracer-go v0.18.1/go.mod h1:jlF1pusYV4pidLvZ+XD0UBX0ZE6WURAspgAczcDHrL4= +github.com/lyft/protoc-gen-validate v0.0.13/go.mod h1:XbGvPuh87YZc5TdIa2/I4pLk0QoUACkjt2znoq26NVQ= +github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= +github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= +github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= +github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= +github.com/mattn/go-isatty v0.0.3/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= +github.com/mattn/go-isatty v0.0.4/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= +github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ= +github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHXY= +github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= +github.com/mattn/go-runewidth v0.0.2/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU= +github.com/mattn/go-sqlite3 v1.14.9 h1:10HX2Td0ocZpYEjhilsuo6WWtUqttj2Kb0KtD86/KYA= +github.com/mattn/go-sqlite3 v1.14.9/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= +github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= +github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= +github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc= +github.com/mitchellh/go-homedir v1.0.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= +github.com/mitchellh/go-testing-interface v1.0.0/go.mod h1:kRemZodwjscx+RGhAo8eIhFbs2+BFgRtFPeD/KE+zxI= +github.com/mitchellh/gox v0.4.0/go.mod h1:Sd9lOJ0+aimLBi73mGofS1ycjY8lL3uZM3JPS42BGNg= +github.com/mitchellh/iochan v1.0.0/go.mod h1:JwYml1nuB7xOzsp52dPpHFffvOCDupsG0QubkSMEySY= +github.com/mitchellh/mapstructure v0.0.0-20160808181253-ca63d7c062ee/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= +github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= +github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= +github.com/nats-io/jwt v0.3.0/go.mod h1:fRYCDE99xlTsqUzISS1Bi75UBJ6ljOJQOAAu5VglpSg= +github.com/nats-io/jwt v0.3.2/go.mod h1:/euKqTS1ZD+zzjYrY7pseZrTtWQSjujC7xjPc8wL6eU= +github.com/nats-io/nats-server/v2 v2.1.2/go.mod h1:Afk+wRZqkMQs/p45uXdrVLuab3gwv3Z8C4HTBu8GD/k= +github.com/nats-io/nats.go v1.9.1/go.mod h1:ZjDU1L/7fJ09jvUSRVBR2e7+RnLiiIQyqyzEE/Zbp4w= +github.com/nats-io/nkeys v0.1.0/go.mod h1:xpnFELMwJABBLVhffcfd1MZx6VsNRFpEugbxziKVo7w= +github.com/nats-io/nkeys v0.1.3/go.mod h1:xpnFELMwJABBLVhffcfd1MZx6VsNRFpEugbxziKVo7w= +github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c= +github.com/oklog/oklog v0.3.2/go.mod h1:FCV+B7mhrz4o+ueLpx+KqkyXRGMWOYEvfiXtdGtbWGs= +github.com/oklog/run v1.0.0/go.mod h1:dlhp/R75TPv97u0XWUtDeV/lRKWPKSdTuV0TZvrmrQA= +github.com/olekukonko/tablewriter v0.0.0-20170122224234-a0225b3f23b5/go.mod h1:vsDQFd/mU46D+Z4whnwzcISnGGzXWMclvtLoiIKAKIo= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.7.0 h1:WSHQ+IS43OoUrWtD1/bbclrwK8TTH5hzp+umCiuxHgs= github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/gomega v1.4.3 h1:RE1xgDvH7imwFD45h+u2SgIfERHlS2yNG4DObb5BSKU= github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= +github.com/op/go-logging v0.0.0-20160315200505-970db520ece7/go.mod h1:HzydrMdWErDVzsI23lYNej1Htcns9BCg93Dk0bBINWk= +github.com/opentracing-contrib/go-observer v0.0.0-20170622124052-a52f23424492/go.mod h1:Ngi6UdF0k5OKD5t5wlmGhe/EDKPoUM3BXZSSfIuJbis= +github.com/opentracing/basictracer-go v1.0.0/go.mod h1:QfBfYuafItcjQuMwinw9GhYKwFXS9KnPs5lxoYwgW74= +github.com/opentracing/opentracing-go v1.0.2/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= +github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= +github.com/openzipkin-contrib/zipkin-go-opentracing v0.4.5/go.mod h1:/wsWhb9smxSfWAKL3wpBW7V8scJMt8N8gnaMCS9E/cA= +github.com/openzipkin/zipkin-go v0.1.6/go.mod h1:QgAqvLzwWbR/WpD4A3cGpPtJrZXNIiJc5AZX7/PBEpw= +github.com/openzipkin/zipkin-go v0.2.1/go.mod h1:NaW6tEwdmWMaCDZzg8sh+IBNOxHMPnhQw8ySjnjRyN4= +github.com/openzipkin/zipkin-go v0.2.2/go.mod h1:NaW6tEwdmWMaCDZzg8sh+IBNOxHMPnhQw8ySjnjRyN4= +github.com/pact-foundation/pact-go v1.0.4/go.mod h1:uExwJY4kCzNPcHRj+hCR/HBbOOIwwtUjcrb0b5/5kLM= +github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= +github.com/pborman/uuid v1.2.0/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k= +github.com/performancecopilot/speed v3.0.0+incompatible/go.mod h1:/CLtqpZ5gBg1M9iaPbIdPPGyKcA8hKdoy6hAWba7Yac= +github.com/pierrec/lz4 v1.0.2-0.20190131084431-473cd7ce01a1/go.mod h1:3/3N9NVKO0jef7pBehbT1qWhCMrIgbYNnFAZCqQ5LRc= +github.com/pierrec/lz4 v2.0.5+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= +github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/profile v1.2.1/go.mod h1:hJw3o1OdXxsrSjjVksARp5W95eeEaEfptyVZyv6JUPA= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/posener/complete v1.1.1/go.mod h1:em0nMJCgc9GFtwrmVmEMR/ZL6WyhyjMBndrE9hABlRI= +github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= +github.com/prometheus/client_golang v0.9.3-0.20190127221311-3c4408c8b829/go.mod h1:p2iRAGwDERtqlqzRXnrOVns+ignqQo//hLXqYxZYVNs= +github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= +github.com/prometheus/client_golang v1.3.0/go.mod h1:hJaj2vgQTGQmVCsAACORcieXFeDPbaTKGT+JTgUa3og= +github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= +github.com/prometheus/client_model v0.0.0-20190115171406-56726106282f/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= +github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/prometheus/client_model v0.1.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/prometheus/common v0.2.0/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= +github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= +github.com/prometheus/common v0.7.0/go.mod h1:DjGbpBbp5NYNiECxcL/VnbXCCaQpKd3tt26CguLLsqA= +github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= +github.com/prometheus/procfs v0.0.0-20190117184657-bf6a532e95b1/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= +github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= +github.com/prometheus/procfs v0.0.8/go.mod h1:7Qr8sr6344vo1JqZ6HhLceV9o3AJ1Ff+GxbHq6oeK9A= +github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= +github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0 h1:OdAsTTz6OkFY5QxjkYwrChwuRruF69c169dPK26NUlk= +github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg= +github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= +github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU= +github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc= +github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/ryanuber/columnize v0.0.0-20160712163229-9b3edd62028f/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts= +github.com/samuel/go-zookeeper v0.0.0-20190923202752-2cc03de413da/go.mod h1:gi+0XIa01GRL2eRQVjQkKGqKF3SF9vZR/HnPullcV2E= +github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= +github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc= +github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= +github.com/shopspring/decimal v0.0.0-20200227202807-02e2044944cc/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= +github.com/shopspring/decimal v1.2.0 h1:abSATXmQEYyShuxI4/vyW3tV1MrKAJzCZ/0zLUXYbsQ= +github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= +github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= +github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= +github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= +github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= +github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= +github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= +github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM= +github.com/sony/gobreaker v0.4.1/go.mod h1:ZKptC7FHNvhBz7dN2LGjPVBz2sZJmc0/PkyDJOjmxWY= +github.com/spf13/cobra v0.0.3/go.mod h1:1l0Ry5zgKvJasoi3XT1TypsSe7PqH0Sj9dhYf7v3XqQ= +github.com/spf13/pflag v1.0.1/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= +github.com/streadway/amqp v0.0.0-20190404075320-75d898a42a94/go.mod h1:AZpEONHx3DKn8O/DFsRAY58/XVQiIPMTMB1SddzLXVw= +github.com/streadway/amqp v0.0.0-20190827072141-edfb9018d271/go.mod h1:AZpEONHx3DKn8O/DFsRAY58/XVQiIPMTMB1SddzLXVw= +github.com/streadway/handy v0.0.0-20190108123426-d5acb3125c2a/go.mod h1:qNTQ5P5JnDBl6z3cMAg/SywNDC5ABu5ApDIw6lUbRmI= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= +github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/syndtr/goleveldb v1.0.0 h1:fBdIW9lB4Iz0n9khmH8w27SJ3QEJ7+IgjPEwGSZiFdE= github.com/syndtr/goleveldb v1.0.0/go.mod h1:ZVVdQEZoIme9iO1Ch2Jdy24qqXrMMOU6lpPAyBWyWuQ= +github.com/tmc/grpc-websocket-proxy v0.0.0-20170815181823-89b8d40f7ca8/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= +github.com/urfave/cli v1.20.0/go.mod h1:70zkFmudgCuE/ngEzBv17Jvp/497gISqfk5gWijbERA= +github.com/urfave/cli v1.22.1/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= +github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU= +github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= github.com/ziutek/mymysql v1.5.4 h1:GB0qdRGsTwQSBVYuVShFBKaXSnSnYYC2d9knnE1LHFs= github.com/ziutek/mymysql v1.5.4/go.mod h1:LMSpPZ6DbqWFxNCHW77HeMg9I646SAhApZ/wKdgO/C0= +go.etcd.io/bbolt v1.3.3/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= +go.etcd.io/etcd v0.0.0-20191023171146-3cf2f69b5738/go.mod h1:dnLIgRNXwCJa5e+c6mIZCrds/GIG4ncV9HhK5PX7jPg= +go.opencensus.io v0.20.1/go.mod h1:6WKK9ahsWS3RSO+PY9ZHZUfv2irvY6gN279GOPZjmmk= +go.opencensus.io v0.20.2/go.mod h1:6WKK9ahsWS3RSO+PY9ZHZUfv2irvY6gN279GOPZjmmk= +go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= +go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= +go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= +go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= +go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= +go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= +go.uber.org/multierr v1.3.0/go.mod h1:VgVr7evmIr6uPjLBxg28wmKNXyqE9akIJ5XnfpiKl+4= +go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKYU= +go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9Ejo0C68/HhF8uaILCdgjnY+goOA= +go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= +go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= +go.uber.org/zap v1.13.0/go.mod h1:zwrFLgMcdUuIBviXEYEH1YKNaOBnKXsx2IPda5bBwHM= +golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/crypto v0.0.0-20181029021203-45a5f77698d3/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c h1:Vj5n4GlwjmQteupaxJ9+0FNOmBrHfq7vN4btdGoDZgI= golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/net v0.0.0-20180218175443-cbe0f9307d01/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20180906233101-161cd47e91fd h1:nTDtHvHSdCn1m6ITfMRqtOd/9+7a3s8RBNOZ3eYZzJA= +golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= +golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20190701094942-4def268fd1a4/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20201203163018-be400aefbc4c/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= +golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= +golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e h1:gsTQYXdTw2Gq7RBsWvlQ91b+aEQ6bXFUngBGuR8sPpI= +golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= +golang.org/x/lint v0.0.0-20190301231843-5614ed5bae6f/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= +golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= +golang.org/x/mod v0.3.0 h1:RM4zey1++hCTbCVQfnWeKs9/IEsaBLA8vTkd0WVtmH4= +golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e h1:3G+cUijn7XD+S4eJFddp53Pv7+slrESplyjG25HgL+k= -golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= -golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f h1:wMNYb4v58l5UBM7MYRLPG6ZhfOqbKu7X5eyFl8ZhKvA= +golang.org/x/net v0.0.0-20181023162649-9b4f9f5ad519/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20181201002055-351d144fa1fc/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20181220203305-927f97764cc3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190125091013-d26f9f9a57f3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= +golang.org/x/net v0.0.0-20190613194153-d28f0bde5980/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 h1:qWPm9rbaAMKs8Bq/9LRpbMqxWRVUAQwMI9fVrssnTfw= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e h1:o3PsSEY8E4eXWkXrIP9YJALUkVZqzHJT5DOasTyn8Vs= +golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a h1:1BGLXjeY4akVXGgbC9HugT3Jv3hCI0z56oJR5vAMgBU= +golang.org/x/sys v0.0.0-20181026203630-95b1ffbd15a5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20181107165924-66b7b1311ac8/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20181122145206-62eef0e2fa9b/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd h1:xhmwyvizuTgC2qz7ZlMluP20uW+C3Rm0FD/WLDX8884= -golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= +golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190826190057-c7b8b68b1456/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191220142924-d4481acd189f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201126233918-771906719818/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210902050250-f475640dd07b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac h1:oN6lz7iLW/YC7un8pq+9bOLyXrprv2+DKfkJY+2LJJw= +golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/tools v0.0.0-20180221164845-07fd8470d635/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= +golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190312170243-e65039ee4138/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= +golang.org/x/tools v0.0.0-20190823170909-c4a336ef6a2f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20200103221440-774c71fcf114/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/tools v0.0.0-20201124115921-2c860bdd6e78 h1:M8tBwCtWD/cZV9DZpFYRUgaymAYAr+aIUTWzDaM3uPs= +golang.org/x/tools v0.0.0-20201124115921-2c860bdd6e78/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= +golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/api v0.3.1/go.mod h1:6wY9I6uQWHQ8EM57III9mq/AjF+i8G65rmVagqKMtkk= +google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= +google.golang.org/appengine v1.2.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= +google.golang.org/genproto v0.0.0-20190425155659-357c62f0e4bb/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= +google.golang.org/genproto v0.0.0-20190530194941-fb225487d101/go.mod h1:z3L6/3dTEVtUr6QSP8miRzeRqwQOioJ9I66odjN4I7s= +google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= +google.golang.org/grpc v1.17.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs= +google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= +google.golang.org/grpc v1.20.0/go.mod h1:chYK+tFQF0nDUGJgXMSgLCQk3phJEuONr2DCgLDdAQM= +google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= +google.golang.org/grpc v1.21.0/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= +google.golang.org/grpc v1.22.1/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= +google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= +google.golang.org/grpc v1.23.1/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= +google.golang.org/grpc v1.26.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= +gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/cheggaaa/pb.v1 v1.0.25/go.mod h1:V/YB90LKu/1FcN3WVnfiiE5oMCibMjukxqG/qStrOgw= +gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4= gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= +gopkg.in/gcfg.v1 v1.2.3/go.mod h1:yesOnuUOFQAhST5vPY4nbZsb/huCgGGXlipJsBn0b3o= +gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:aPpfJ7XW+gOuirDoZ8gHhLh3kZ1B08FtV2bbmy7Jv3s= +gopkg.in/resty.v1 v1.12.0/go.mod h1:mDo4pnntr5jdWRML875a/NmxYqAlA73dVijT2AXvQQo= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= +gopkg.in/warnings.v0 v0.1.2/go.mod h1:jksf8JmL6Qr/oQM2OXTHunEvvTAsrWBLb6OOjuVWRNI= +gopkg.in/yaml.v2 v2.0.0-20170812160011-eb3733d160e7/go.mod h1:JAlM8MvJe8wmxCU4Bli9HhUf9+ttbYbLASfIpnQbh74= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -xorm.io/builder v0.3.7 h1:2pETdKRK+2QG4mLX4oODHEhn5Z8j1m8sXa7jfu+/SZI= -xorm.io/builder v0.3.7/go.mod h1:aUW0S9eb9VCaPohFCH3j7czOx1PMW3i1HrSzbLYGBSE= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= +lukechampine.com/uint128 v1.1.1 h1:pnxCASz787iMf+02ssImqk6OLt+Z5QHMoZyUXR4z6JU= +lukechampine.com/uint128 v1.1.1/go.mod h1:c4eWIwlEGaxC/+H1VguhU4PHXNWDCDMUlWdIWl2j1gk= +modernc.org/cc/v3 v3.33.6/go.mod h1:iPJg1pkwXqAV16SNgFBVYmggfMg6xhs+2oiO0vclK3g= +modernc.org/cc/v3 v3.33.9/go.mod h1:iPJg1pkwXqAV16SNgFBVYmggfMg6xhs+2oiO0vclK3g= +modernc.org/cc/v3 v3.33.11/go.mod h1:iPJg1pkwXqAV16SNgFBVYmggfMg6xhs+2oiO0vclK3g= +modernc.org/cc/v3 v3.34.0/go.mod h1:iPJg1pkwXqAV16SNgFBVYmggfMg6xhs+2oiO0vclK3g= +modernc.org/cc/v3 v3.35.0/go.mod h1:iPJg1pkwXqAV16SNgFBVYmggfMg6xhs+2oiO0vclK3g= +modernc.org/cc/v3 v3.35.4/go.mod h1:iPJg1pkwXqAV16SNgFBVYmggfMg6xhs+2oiO0vclK3g= +modernc.org/cc/v3 v3.35.5/go.mod h1:iPJg1pkwXqAV16SNgFBVYmggfMg6xhs+2oiO0vclK3g= +modernc.org/cc/v3 v3.35.7/go.mod h1:iPJg1pkwXqAV16SNgFBVYmggfMg6xhs+2oiO0vclK3g= +modernc.org/cc/v3 v3.35.8/go.mod h1:iPJg1pkwXqAV16SNgFBVYmggfMg6xhs+2oiO0vclK3g= +modernc.org/cc/v3 v3.35.10/go.mod h1:iPJg1pkwXqAV16SNgFBVYmggfMg6xhs+2oiO0vclK3g= +modernc.org/cc/v3 v3.35.15/go.mod h1:iPJg1pkwXqAV16SNgFBVYmggfMg6xhs+2oiO0vclK3g= +modernc.org/cc/v3 v3.35.16/go.mod h1:iPJg1pkwXqAV16SNgFBVYmggfMg6xhs+2oiO0vclK3g= +modernc.org/cc/v3 v3.35.17/go.mod h1:iPJg1pkwXqAV16SNgFBVYmggfMg6xhs+2oiO0vclK3g= +modernc.org/cc/v3 v3.35.18 h1:rMZhRcWrba0y3nVmdiQ7kxAgOOSq2m2f2VzjHLgEs6U= +modernc.org/cc/v3 v3.35.18/go.mod h1:iPJg1pkwXqAV16SNgFBVYmggfMg6xhs+2oiO0vclK3g= +modernc.org/ccgo/v3 v3.9.5/go.mod h1:umuo2EP2oDSBnD3ckjaVUXMrmeAw8C8OSICVa0iFf60= +modernc.org/ccgo/v3 v3.10.0/go.mod h1:c0yBmkRFi7uW4J7fwx/JiijwOjeAeR2NoSaRVFPmjMw= +modernc.org/ccgo/v3 v3.11.0/go.mod h1:dGNposbDp9TOZ/1KBxghxtUp/bzErD0/0QW4hhSaBMI= +modernc.org/ccgo/v3 v3.11.1/go.mod h1:lWHxfsn13L3f7hgGsGlU28D9eUOf6y3ZYHKoPaKU0ag= +modernc.org/ccgo/v3 v3.11.3/go.mod h1:0oHunRBMBiXOKdaglfMlRPBALQqsfrCKXgw9okQ3GEw= +modernc.org/ccgo/v3 v3.12.4/go.mod h1:Bk+m6m2tsooJchP/Yk5ji56cClmN6R1cqc9o/YtbgBQ= +modernc.org/ccgo/v3 v3.12.6/go.mod h1:0Ji3ruvpFPpz+yu+1m0wk68pdr/LENABhTrDkMDWH6c= +modernc.org/ccgo/v3 v3.12.8/go.mod h1:Hq9keM4ZfjCDuDXxaHptpv9N24JhgBZmUG5q60iLgUo= +modernc.org/ccgo/v3 v3.12.11/go.mod h1:0jVcmyDwDKDGWbcrzQ+xwJjbhZruHtouiBEvDfoIsdg= +modernc.org/ccgo/v3 v3.12.14/go.mod h1:GhTu1k0YCpJSuWwtRAEHAol5W7g1/RRfS4/9hc9vF5I= +modernc.org/ccgo/v3 v3.12.18/go.mod h1:jvg/xVdWWmZACSgOiAhpWpwHWylbJaSzayCqNOJKIhs= +modernc.org/ccgo/v3 v3.12.20/go.mod h1:aKEdssiu7gVgSy/jjMastnv/q6wWGRbszbheXgWRHc8= +modernc.org/ccgo/v3 v3.12.21/go.mod h1:ydgg2tEprnyMn159ZO/N4pLBqpL7NOkJ88GT5zNU2dE= +modernc.org/ccgo/v3 v3.12.22/go.mod h1:nyDVFMmMWhMsgQw+5JH6B6o4MnZ+UQNw1pp52XYFPRk= +modernc.org/ccgo/v3 v3.12.25/go.mod h1:UaLyWI26TwyIT4+ZFNjkyTbsPsY3plAEB6E7L/vZV3w= +modernc.org/ccgo/v3 v3.12.29/go.mod h1:FXVjG7YLf9FetsS2OOYcwNhcdOLGt8S9bQ48+OP75cE= +modernc.org/ccgo/v3 v3.12.36/go.mod h1:uP3/Fiezp/Ga8onfvMLpREq+KUjUmYMxXPO8tETHtA8= +modernc.org/ccgo/v3 v3.12.38/go.mod h1:93O0G7baRST1vNj4wnZ49b1kLxt0xCW5Hsa2qRaZPqc= +modernc.org/ccgo/v3 v3.12.43/go.mod h1:k+DqGXd3o7W+inNujK15S5ZYuPoWYLpF5PYougCmthU= +modernc.org/ccgo/v3 v3.12.46/go.mod h1:UZe6EvMSqOxaJ4sznY7b23/k13R8XNlyWsO5bAmSgOE= +modernc.org/ccgo/v3 v3.12.47/go.mod h1:m8d6p0zNps187fhBwzY/ii6gxfjob1VxWb919Nk1HUk= +modernc.org/ccgo/v3 v3.12.50/go.mod h1:bu9YIwtg+HXQxBhsRDE+cJjQRuINuT9PUK4orOco/JI= +modernc.org/ccgo/v3 v3.12.51/go.mod h1:gaIIlx4YpmGO2bLye04/yeblmvWEmE4BBBls4aJXFiE= +modernc.org/ccgo/v3 v3.12.53/go.mod h1:8xWGGTFkdFEWBEsUmi+DBjwu/WLy3SSOrqEmKUjMeEg= +modernc.org/ccgo/v3 v3.12.54/go.mod h1:yANKFTm9llTFVX1FqNKHE0aMcQb1fuPJx6p8AcUx+74= +modernc.org/ccgo/v3 v3.12.55/go.mod h1:rsXiIyJi9psOwiBkplOaHye5L4MOOaCjHg1Fxkj7IeU= +modernc.org/ccgo/v3 v3.12.56/go.mod h1:ljeFks3faDseCkr60JMpeDb2GSO3TKAmrzm7q9YOcMU= +modernc.org/ccgo/v3 v3.12.57/go.mod h1:hNSF4DNVgBl8wYHpMvPqQWDQx8luqxDnNGCMM4NFNMc= +modernc.org/ccgo/v3 v3.12.60/go.mod h1:k/Nn0zdO1xHVWjPYVshDeWKqbRWIfif5dtsIOCUVMqM= +modernc.org/ccgo/v3 v3.12.65/go.mod h1:D6hQtKxPNZiY6wDBtehSGKFKmyXn53F8nGTpH+POmS4= +modernc.org/ccgo/v3 v3.12.66/go.mod h1:jUuxlCFZTUZLMV08s7B1ekHX5+LIAurKTTaugUr/EhQ= +modernc.org/ccgo/v3 v3.12.67/go.mod h1:Bll3KwKvGROizP2Xj17GEGOTrlvB1XcVaBrC90ORO84= +modernc.org/ccgo/v3 v3.12.73/go.mod h1:hngkB+nUUqzOf3iqsM48Gf1FZhY599qzVg1iX+BT3cQ= +modernc.org/ccgo/v3 v3.12.81/go.mod h1:p2A1duHoBBg1mFtYvnhAnQyI6vL0uw5PGYLSIgF6rYY= +modernc.org/ccgo/v3 v3.12.82 h1:wudcnJyjLj1aQQCXF3IM9Gz2X6UNjw+afIghzdtn0v8= +modernc.org/ccgo/v3 v3.12.82/go.mod h1:ApbflUfa5BKadjHynCficldU1ghjen84tuM5jRynB7w= +modernc.org/ccorpus v1.11.1 h1:K0qPfpVG1MJh5BYazccnmhywH4zHuOgJXgbjzyp6dWA= +modernc.org/ccorpus v1.11.1/go.mod h1:2gEUTrWqdpH2pXsmTM1ZkjeSrUWDpjMu2T6m29L/ErQ= +modernc.org/httpfs v1.0.6 h1:AAgIpFZRXuYnkjftxTAZwMIiwEqAfk8aVB2/oA6nAeM= +modernc.org/httpfs v1.0.6/go.mod h1:7dosgurJGp0sPaRanU53W4xZYKh14wfzX420oZADeHM= +modernc.org/libc v1.9.8/go.mod h1:U1eq8YWr/Kc1RWCMFUWEdkTg8OTcfLw2kY8EDwl039w= +modernc.org/libc v1.9.11/go.mod h1:NyF3tsA5ArIjJ83XB0JlqhjTabTCHm9aX4XMPHyQn0Q= +modernc.org/libc v1.11.0/go.mod h1:2lOfPmj7cz+g1MrPNmX65QCzVxgNq2C5o0jdLY2gAYg= +modernc.org/libc v1.11.2/go.mod h1:ioIyrl3ETkugDO3SGZ+6EOKvlP3zSOycUETe4XM4n8M= +modernc.org/libc v1.11.5/go.mod h1:k3HDCP95A6U111Q5TmG3nAyUcp3kR5YFZTeDS9v8vSU= +modernc.org/libc v1.11.6/go.mod h1:ddqmzR6p5i4jIGK1d/EiSw97LBcE3dK24QEwCFvgNgE= +modernc.org/libc v1.11.11/go.mod h1:lXEp9QOOk4qAYOtL3BmMve99S5Owz7Qyowzvg6LiZso= +modernc.org/libc v1.11.13/go.mod h1:ZYawJWlXIzXy2Pzghaf7YfM8OKacP3eZQI81PDLFdY8= +modernc.org/libc v1.11.16/go.mod h1:+DJquzYi+DMRUtWI1YNxrlQO6TcA5+dRRiq8HWBWRC8= +modernc.org/libc v1.11.19/go.mod h1:e0dgEame6mkydy19KKaVPBeEnyJB4LGNb0bBH1EtQ3I= +modernc.org/libc v1.11.24/go.mod h1:FOSzE0UwookyT1TtCJrRkvsOrX2k38HoInhw+cSCUGk= +modernc.org/libc v1.11.26/go.mod h1:SFjnYi9OSd2W7f4ct622o/PAYqk7KHv6GS8NZULIjKY= +modernc.org/libc v1.11.27/go.mod h1:zmWm6kcFXt/jpzeCgfvUNswM0qke8qVwxqZrnddlDiE= +modernc.org/libc v1.11.28/go.mod h1:Ii4V0fTFcbq3qrv3CNn+OGHAvzqMBvC7dBNyC4vHZlg= +modernc.org/libc v1.11.31/go.mod h1:FpBncUkEAtopRNJj8aRo29qUiyx5AvAlAxzlx9GNaVM= +modernc.org/libc v1.11.34/go.mod h1:+Tzc4hnb1iaX/SKAutJmfzES6awxfU1BPvrrJO0pYLg= +modernc.org/libc v1.11.37/go.mod h1:dCQebOwoO1046yTrfUE5nX1f3YpGZQKNcITUYWlrAWo= +modernc.org/libc v1.11.39/go.mod h1:mV8lJMo2S5A31uD0k1cMu7vrJbSA3J3waQJxpV4iqx8= +modernc.org/libc v1.11.42/go.mod h1:yzrLDU+sSjLE+D4bIhS7q1L5UwXDOw99PLSX0BlZvSQ= +modernc.org/libc v1.11.44/go.mod h1:KFq33jsma7F5WXiYelU8quMJasCCTnHK0mkri4yPHgA= +modernc.org/libc v1.11.45/go.mod h1:Y192orvfVQQYFzCNsn+Xt0Hxt4DiO4USpLNXBlXg/tM= +modernc.org/libc v1.11.47/go.mod h1:tPkE4PzCTW27E6AIKIR5IwHAQKCAtudEIeAV1/SiyBg= +modernc.org/libc v1.11.49/go.mod h1:9JrJuK5WTtoTWIFQ7QjX2Mb/bagYdZdscI3xrvHbXjE= +modernc.org/libc v1.11.51/go.mod h1:R9I8u9TS+meaWLdbfQhq2kFknTW0O3aw3kEMqDDxMaM= +modernc.org/libc v1.11.53/go.mod h1:5ip5vWYPAoMulkQ5XlSJTy12Sz5U6blOQiYasilVPsU= +modernc.org/libc v1.11.54/go.mod h1:S/FVnskbzVUrjfBqlGFIPA5m7UwB3n9fojHhCNfSsnw= +modernc.org/libc v1.11.55/go.mod h1:j2A5YBRm6HjNkoSs/fzZrSxCuwWqcMYTDPLNx0URn3M= +modernc.org/libc v1.11.56/go.mod h1:pakHkg5JdMLt2OgRadpPOTnyRXm/uzu+Yyg/LSLdi18= +modernc.org/libc v1.11.58/go.mod h1:ns94Rxv0OWyoQrDqMFfWwka2BcaF6/61CqJRK9LP7S8= +modernc.org/libc v1.11.70/go.mod h1:DUOmMYe+IvKi9n6Mycyx3DbjfzSKrdr/0Vgt3j7P5gw= +modernc.org/libc v1.11.71/go.mod h1:DUOmMYe+IvKi9n6Mycyx3DbjfzSKrdr/0Vgt3j7P5gw= +modernc.org/libc v1.11.75/go.mod h1:dGRVugT6edz361wmD9gk6ax1AbDSe0x5vji0dGJiPT0= +modernc.org/libc v1.11.82/go.mod h1:NF+Ek1BOl2jeC7lw3a7Jj5PWyHPwWD4aq3wVKxqV1fI= +modernc.org/libc v1.11.86/go.mod h1:ePuYgoQLmvxdNT06RpGnaDKJmDNEkV7ZPKI2jnsvZoE= +modernc.org/libc v1.11.87 h1:PzIzOqtlzMDDcCzJ5cUP6h/Ku6Fa9iyflP2ccTY64aE= +modernc.org/libc v1.11.87/go.mod h1:Qvd5iXTeLhI5PS0XSyqMY99282y+3euapQFxM7jYnpY= +modernc.org/mathutil v1.1.1/go.mod h1:mZW8CKdRPY1v87qxC/wUdX5O1qDzXMP5TH3wjfpga6E= +modernc.org/mathutil v1.2.2/go.mod h1:mZW8CKdRPY1v87qxC/wUdX5O1qDzXMP5TH3wjfpga6E= +modernc.org/mathutil v1.4.0/go.mod h1:mZW8CKdRPY1v87qxC/wUdX5O1qDzXMP5TH3wjfpga6E= +modernc.org/mathutil v1.4.1 h1:ij3fYGe8zBF4Vu+g0oT7mB06r8sqGWKuJu1yXeR4by8= +modernc.org/mathutil v1.4.1/go.mod h1:mZW8CKdRPY1v87qxC/wUdX5O1qDzXMP5TH3wjfpga6E= +modernc.org/memory v1.0.4/go.mod h1:nV2OApxradM3/OVbs2/0OsP6nPfakXpi50C7dcoHXlc= +modernc.org/memory v1.0.5 h1:XRch8trV7GgvTec2i7jc33YlUI0RKVDBvZ5eZ5m8y14= +modernc.org/memory v1.0.5/go.mod h1:B7OYswTRnfGg+4tDH1t1OeUNnsy2viGTdME4tzd+IjM= +modernc.org/opt v0.1.1 h1:/0RX92k9vwVeDXj+Xn23DKp2VJubL7k8qNffND6qn3A= +modernc.org/opt v0.1.1/go.mod h1:WdSiB5evDcignE70guQKxYUl14mgWtbClRi5wmkkTX0= +modernc.org/sqlite v1.14.2 h1:ohsW2+e+Qe2To1W6GNezzKGwjXwSax6R+CrhRxVaFbE= +modernc.org/sqlite v1.14.2/go.mod h1:yqfn85u8wVOE6ub5UT8VI9JjhrwBUUCNyTACN0h6Sx8= +modernc.org/strutil v1.1.1 h1:xv+J1BXY3Opl2ALrBwyfEikFAj8pmqcpnfmuwUwcozs= +modernc.org/strutil v1.1.1/go.mod h1:DE+MQQ/hjKBZS2zNInV5hhcipt5rLPWkmpbGeW5mmdw= +modernc.org/tcl v1.8.13 h1:V0sTNBw0Re86PvXZxuCub3oO9WrSTqALgrwNZNvLFGw= +modernc.org/tcl v1.8.13/go.mod h1:V+q/Ef0IJaNUSECieLU4o+8IScapxnMyFV6i/7uQlAY= +modernc.org/token v1.0.0 h1:a0jaWiNMDhDUtqOj09wvjWWAqd3q7WpBulmL9H2egsk= +modernc.org/token v1.0.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= +modernc.org/z v1.2.19 h1:BGyRFWhDVn5LFS5OcX4Yd/MlpRTOc7hOPTdcIpCiUao= +modernc.org/z v1.2.19/go.mod h1:+ZpP0pc4zz97eukOzW3xagV/lS82IpPN9NGG5pNF9vY= +sigs.k8s.io/yaml v1.1.0/go.mod h1:UJmg0vDUVViEyp3mgSv9WPwZCDxu4rQW1olrI1uml+o= +sourcegraph.com/sourcegraph/appdash v0.0.0-20190731080439-ebfcffb1b5c0/go.mod h1:hI742Nqp5OhwiqlzhgfbWU4mW4yO10fP+LoT9WOswdU= +xorm.io/builder v0.3.11-0.20220531020008-1bd24a7dc978 h1:bvLlAPW1ZMTWA32LuZMBEGHAUOcATZjzHcotf3SWweM= +xorm.io/builder v0.3.11-0.20220531020008-1bd24a7dc978/go.mod h1:aUW0S9eb9VCaPohFCH3j7czOx1PMW3i1HrSzbLYGBSE= diff --git a/integrations/cache_test.go b/integrations/cache_test.go index 44e817b1..2caeaa34 100644 --- a/integrations/cache_test.go +++ b/integrations/cache_test.go @@ -26,7 +26,7 @@ func TestCacheFind(t *testing.T) { cacher := caches.NewLRUCacher2(caches.NewMemoryStore(), time.Hour, 10000) testEngine.SetDefaultCacher(cacher) - assert.NoError(t, testEngine.Sync2(new(MailBox))) + assert.NoError(t, testEngine.Sync(new(MailBox))) var inserts = []*MailBox{ { @@ -62,7 +62,8 @@ func TestCacheFind(t *testing.T) { } boxes = make([]MailBox, 0, 2) - assert.NoError(t, testEngine.Alias("a").Where("a.id > -1").Asc("a.id").Find(&boxes)) + assert.NoError(t, testEngine.Alias("a").Where("`a`.`id`> -1"). + Asc("`a`.`id`").Find(&boxes)) assert.EqualValues(t, 2, len(boxes)) for i, box := range boxes { assert.Equal(t, inserts[i].Id, box.Id) @@ -77,7 +78,8 @@ func TestCacheFind(t *testing.T) { } boxes2 := make([]MailBox4, 0, 2) - assert.NoError(t, testEngine.Table("mail_box").Where("mail_box.id > -1").Asc("mail_box.id").Find(&boxes2)) + assert.NoError(t, testEngine.Table("mail_box").Where("`mail_box`.`id` > -1"). + Asc("mail_box.id").Find(&boxes2)) assert.EqualValues(t, 2, len(boxes2)) for i, box := range boxes2 { assert.Equal(t, inserts[i].Id, box.Id) @@ -101,7 +103,7 @@ func TestCacheFind2(t *testing.T) { cacher := caches.NewLRUCacher2(caches.NewMemoryStore(), time.Hour, 10000) testEngine.SetDefaultCacher(cacher) - assert.NoError(t, testEngine.Sync2(new(MailBox2))) + assert.NoError(t, testEngine.Sync(new(MailBox2))) var inserts = []*MailBox2{ { @@ -152,7 +154,7 @@ func TestCacheGet(t *testing.T) { cacher := caches.NewLRUCacher2(caches.NewMemoryStore(), time.Hour, 10000) testEngine.SetDefaultCacher(cacher) - assert.NoError(t, testEngine.Sync2(new(MailBox3))) + assert.NoError(t, testEngine.Sync(new(MailBox3))) var inserts = []*MailBox3{ { @@ -164,14 +166,14 @@ func TestCacheGet(t *testing.T) { assert.NoError(t, err) var box1 MailBox3 - has, err := testEngine.Where("id = ?", inserts[0].Id).Get(&box1) + has, err := testEngine.Where("`id` = ?", inserts[0].Id).Get(&box1) assert.NoError(t, err) assert.True(t, has) assert.EqualValues(t, "user1", box1.Username) assert.EqualValues(t, "pass1", box1.Password) var box2 MailBox3 - has, err = testEngine.Where("id = ?", inserts[0].Id).Get(&box2) + has, err = testEngine.Where("`id` = ?", inserts[0].Id).Get(&box2) assert.NoError(t, err) assert.True(t, has) assert.EqualValues(t, "user1", box2.Username) diff --git a/integrations/engine_dm_test.go b/integrations/engine_dm_test.go new file mode 100644 index 00000000..3b195ef8 --- /dev/null +++ b/integrations/engine_dm_test.go @@ -0,0 +1,14 @@ +// Copyright 2021 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build dm +// +build dm + +package integrations + +import "xorm.io/xorm/schemas" + +func init() { + dbtypes = append(dbtypes, schemas.DAMENG) +} diff --git a/integrations/engine_test.go b/integrations/engine_test.go index 0e5d3424..730a424e 100644 --- a/integrations/engine_test.go +++ b/integrations/engine_test.go @@ -14,12 +14,15 @@ import ( "xorm.io/xorm" "xorm.io/xorm/schemas" + _ "gitee.com/travelliu/dm" _ "github.com/denisenkom/go-mssqldb" _ "github.com/go-sql-driver/mysql" + _ "github.com/jackc/pgx/v4/stdlib" _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" "github.com/stretchr/testify/assert" _ "github.com/ziutek/mymysql/godrv" + _ "modernc.org/sqlite" ) func TestPing(t *testing.T) { @@ -50,17 +53,18 @@ func TestAutoTransaction(t *testing.T) { Created time.Time `xorm:"created"` } - assert.NoError(t, testEngine.Sync2(new(TestTx))) + assert.NoError(t, testEngine.Sync(new(TestTx))) engine := testEngine.(*xorm.Engine) // will success - engine.Transaction(func(session *xorm.Session) (interface{}, error) { + _, err := engine.Transaction(func(session *xorm.Session) (interface{}, error) { _, err := session.Insert(TestTx{Msg: "hi"}) assert.NoError(t, err) return nil, nil }) + assert.NoError(t, err) has, err := engine.Exist(&TestTx{Msg: "hi"}) assert.NoError(t, err) @@ -84,7 +88,7 @@ func assertSync(t *testing.T, beans ...interface{}) { for _, bean := range beans { t.Run(testEngine.TableName(bean, true), func(t *testing.T) { assert.NoError(t, testEngine.DropTables(bean)) - assert.NoError(t, testEngine.Sync2(bean)) + assert.NoError(t, testEngine.Sync(bean)) }) } } @@ -132,11 +136,14 @@ func TestDump(t *testing.T) { } } +var dbtypes = []schemas.DBType{schemas.SQLITE, schemas.MYSQL, schemas.POSTGRES, schemas.MSSQL} + func TestDumpTables(t *testing.T) { assert.NoError(t, PrepareEngine()) type TestDumpTableStruct struct { Id int64 + Data []byte `xorm:"BLOB"` Name string IsMan bool Created time.Time `xorm:"created"` @@ -144,13 +151,18 @@ func TestDumpTables(t *testing.T) { assertSync(t, new(TestDumpTableStruct)) - testEngine.Insert([]TestDumpTableStruct{ + _, err := testEngine.Insert([]TestDumpTableStruct{ {Name: "1", IsMan: true}, - {Name: "2\n"}, - {Name: "3;"}, - {Name: "4\n;\n''"}, - {Name: "5'\n"}, + {Name: "2\n", Data: []byte{'\000', '\001', '\002'}}, + {Name: "3;", Data: []byte("0x000102")}, + {Name: "4\n;\n''", Data: []byte("Help")}, + {Name: "5'\n", Data: []byte("0x48656c70")}, + {Name: "6\\n'\n", Data: []byte("48656c70")}, + {Name: "7\\n'\r\n", Data: []byte("7\\n'\r\n")}, + {Name: "x0809ee"}, + {Name: "090a10"}, }) + assert.NoError(t, err) fp := fmt.Sprintf("%v-table.sql", testEngine.Dialect().URI().DBType) os.Remove(fp) @@ -167,12 +179,41 @@ func TestDumpTables(t *testing.T) { assert.NoError(t, err) assert.NoError(t, sess.Commit()) - for _, tp := range []schemas.DBType{schemas.SQLITE, schemas.MYSQL, schemas.POSTGRES, schemas.MSSQL} { + for _, tp := range dbtypes { name := fmt.Sprintf("dump_%v-table.sql", tp) t.Run(name, func(t *testing.T) { assert.NoError(t, testEngine.(*xorm.Engine).DumpTablesToFile([]*schemas.Table{tb}, name, tp)) }) } + + assert.NoError(t, testEngine.DropTables(new(TestDumpTableStruct))) + + importPath := fmt.Sprintf("dump_%v-table.sql", testEngine.Dialect().URI().DBType) + t.Run("import_"+importPath, func(t *testing.T) { + sess := testEngine.NewSession() + defer sess.Close() + assert.NoError(t, sess.Begin()) + _, err = sess.ImportFile(importPath) + assert.NoError(t, err) + assert.NoError(t, sess.Commit()) + }) +} + +func TestDumpTables2(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type TestDumpTableStruct2 struct { + Id int64 + Created time.Time `xorm:"Default CURRENT_TIMESTAMP"` + } + + assertSync(t, new(TestDumpTableStruct2)) + + fp := fmt.Sprintf("./dump2-%v-table.sql", testEngine.Dialect().URI().DBType) + os.Remove(fp) + tb, err := testEngine.TableInfo(new(TestDumpTableStruct2)) + assert.NoError(t, err) + assert.NoError(t, testEngine.(*xorm.Engine).DumpTablesToFile([]*schemas.Table{tb}, fp)) } func TestSetSchema(t *testing.T) { @@ -186,3 +227,95 @@ func TestSetSchema(t *testing.T) { assert.EqualValues(t, oldSchema, testEngine.Dialect().URI().Schema) } } + +func TestImport(t *testing.T) { + if testEngine.Dialect().URI().DBType != schemas.MYSQL { + t.Skip() + return + } + sess := testEngine.NewSession() + defer sess.Close() + assert.NoError(t, sess.Begin()) + _, err := sess.ImportFile("./testdata/import1.sql") + assert.NoError(t, err) + assert.NoError(t, sess.Commit()) + + assert.NoError(t, sess.Begin()) + _, err = sess.ImportFile("./testdata/import2.sql") + assert.NoError(t, err) + assert.NoError(t, sess.Commit()) +} + +func TestDBVersion(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + version, err := testEngine.DBVersion() + assert.NoError(t, err) + + fmt.Println(testEngine.Dialect().URI().DBType, "version is", version) +} + +func TestGetColumnsComment(t *testing.T) { + switch testEngine.Dialect().URI().DBType { + case schemas.POSTGRES, schemas.MYSQL: + default: + t.Skip() + return + } + comment := "this is a comment" + type TestCommentStruct struct { + HasComment int `xorm:"comment('this is a comment')"` + NoComment int + } + + assertSync(t, new(TestCommentStruct)) + + tables, err := testEngine.DBMetas() + assert.NoError(t, err) + tableName := testEngine.GetColumnMapper().Obj2Table("TestCommentStruct") + var hasComment, noComment string + for _, table := range tables { + if table.Name == tableName { + col := table.GetColumn(testEngine.GetColumnMapper().Obj2Table("HasComment")) + assert.NotNil(t, col) + hasComment = col.Comment + col2 := table.GetColumn(testEngine.GetColumnMapper().Obj2Table("NoComment")) + assert.NotNil(t, col2) + noComment = col2.Comment + break + } + } + assert.Equal(t, comment, hasComment) + assert.Zero(t, noComment) +} + +func TestGetColumnsLength(t *testing.T) { + var max_length int64 + switch testEngine.Dialect().URI().DBType { + case schemas.POSTGRES: + max_length = 0 + case schemas.MYSQL: + max_length = 65535 + default: + t.Skip() + return + } + + type TestLengthStringStruct struct { + Content string `xorm:"TEXT NOT NULL"` + } + + assertSync(t, new(TestLengthStringStruct)) + + tables, err := testEngine.DBMetas() + assert.NoError(t, err) + tableLengthStringName := testEngine.GetColumnMapper().Obj2Table("TestLengthStringStruct") + for _, table := range tables { + if table.Name == tableLengthStringName { + col := table.GetColumn("content") + assert.Equal(t, col.Length, max_length) + assert.Zero(t, col.Length2) + break + } + } +} diff --git a/integrations/performance_test.go b/integrations/performance_test.go new file mode 100644 index 00000000..49183717 --- /dev/null +++ b/integrations/performance_test.go @@ -0,0 +1,104 @@ +// Copyright 2021 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package integrations + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func BenchmarkGetVars(b *testing.B) { + b.StopTimer() + + assert.NoError(b, PrepareEngine()) + testEngine.ShowSQL(false) + + type BenchmarkGetVars struct { + Id int64 + Name string + } + + assert.NoError(b, testEngine.Sync(new(BenchmarkGetVars))) + + var v = BenchmarkGetVars{ + Name: "myname", + } + _, err := testEngine.Insert(&v) + assert.NoError(b, err) + + b.StartTimer() + var myname string + for i := 0; i < b.N; i++ { + has, err := testEngine.Cols("name").Table("benchmark_get_vars").Where("`id`=?", v.Id).Get(&myname) + b.StopTimer() + myname = "" + assert.True(b, has) + assert.NoError(b, err) + b.StartTimer() + } +} + +func BenchmarkGetStruct(b *testing.B) { + b.StopTimer() + + assert.NoError(b, PrepareEngine()) + testEngine.ShowSQL(false) + + type BenchmarkGetStruct struct { + Id int64 + Name string + } + + assert.NoError(b, testEngine.Sync(new(BenchmarkGetStruct))) + + var v = BenchmarkGetStruct{ + Name: "myname", + } + _, err := testEngine.Insert(&v) + assert.NoError(b, err) + + b.StartTimer() + var myname BenchmarkGetStruct + for i := 0; i < b.N; i++ { + has, err := testEngine.ID(v.Id).Get(&myname) + b.StopTimer() + myname.Id = 0 + myname.Name = "" + assert.True(b, has) + assert.NoError(b, err) + b.StartTimer() + } +} + +func BenchmarkFindStruct(b *testing.B) { + b.StopTimer() + + assert.NoError(b, PrepareEngine()) + testEngine.ShowSQL(false) + + type BenchmarkFindStruct struct { + Id int64 + Name string + } + + assert.NoError(b, testEngine.Sync(new(BenchmarkFindStruct))) + + var v = BenchmarkFindStruct{ + Name: "myname", + } + _, err := testEngine.Insert(&v) + assert.NoError(b, err) + + var mynames = make([]BenchmarkFindStruct, 0, 1) + b.StartTimer() + for i := 0; i < b.N; i++ { + err := testEngine.Find(&mynames) + b.StopTimer() + mynames = make([]BenchmarkFindStruct, 0, 1) + assert.NoError(b, err) + b.StartTimer() + } +} diff --git a/integrations/processors_test.go b/integrations/processors_test.go index e349988d..4c383437 100644 --- a/integrations/processors_test.go +++ b/integrations/processors_test.go @@ -23,7 +23,7 @@ func TestBefore_Get(t *testing.T) { Val string `xorm:"-"` } - assert.NoError(t, testEngine.Sync2(new(BeforeTable))) + assert.NoError(t, testEngine.Sync(new(BeforeTable))) cnt, err := testEngine.Insert(&BeforeTable{ Name: "test", @@ -50,7 +50,7 @@ func TestBefore_Find(t *testing.T) { Val string `xorm:"-"` } - assert.NoError(t, testEngine.Sync2(new(BeforeTable2))) + assert.NoError(t, testEngine.Sync(new(BeforeTable2))) cnt, err := testEngine.Insert([]BeforeTable2{ {Name: "test1"}, @@ -104,7 +104,7 @@ func (p *ProcessorsStruct) BeforeDelete() { } func (p *ProcessorsStruct) BeforeSet(col string, cell xorm.Cell) { - p.BeforeSetFlag = p.BeforeSetFlag + 1 + p.BeforeSetFlag++ } func (p *ProcessorsStruct) AfterInsert() { @@ -120,7 +120,7 @@ func (p *ProcessorsStruct) AfterDelete() { } func (p *ProcessorsStruct) AfterSet(col string, cell xorm.Cell) { - p.AfterSetFlag = p.AfterSetFlag + 1 + p.AfterSetFlag++ } func TestProcessors(t *testing.T) { diff --git a/integrations/rows_test.go b/integrations/rows_test.go index f68030a4..e354b75e 100644 --- a/integrations/rows_test.go +++ b/integrations/rows_test.go @@ -18,7 +18,7 @@ func TestRows(t *testing.T) { IsMan bool } - assert.NoError(t, testEngine.Sync2(new(UserRows))) + assert.NoError(t, testEngine.Sync(new(UserRows))) cnt, err := testEngine.Insert(&UserRows{ IsMan: true, @@ -70,7 +70,7 @@ func TestRows(t *testing.T) { } assert.EqualValues(t, 1, cnt) - var tbName = testEngine.Quote(testEngine.TableName(user, true)) + tbName := testEngine.Quote(testEngine.TableName(user, true)) rows2, err := testEngine.SQL("SELECT * FROM " + tbName).Rows(new(UserRows)) assert.NoError(t, err) defer rows2.Close() @@ -92,9 +92,9 @@ func TestRowsMyTableName(t *testing.T) { IsMan bool } - var tableName = "user_rows_my_table_name" + tableName := "user_rows_my_table_name" - assert.NoError(t, testEngine.Table(tableName).Sync2(new(UserRowsMyTable))) + assert.NoError(t, testEngine.Table(tableName).Sync(new(UserRowsMyTable))) cnt, err := testEngine.Table(tableName).Insert(&UserRowsMyTable{ IsMan: true, @@ -141,7 +141,7 @@ func (UserRowsSpecTable) TableName() string { func TestRowsSpecTableName(t *testing.T) { assert.NoError(t, PrepareEngine()) - assert.NoError(t, testEngine.Sync2(new(UserRowsSpecTable))) + assert.NoError(t, testEngine.Sync(new(UserRowsSpecTable))) cnt, err := testEngine.Insert(&UserRowsSpecTable{ IsMan: true, @@ -160,5 +160,121 @@ func TestRowsSpecTableName(t *testing.T) { assert.NoError(t, err) cnt++ } + assert.NoError(t, rows.Err()) assert.EqualValues(t, 1, cnt) } + +func TestRowsScanVars(t *testing.T) { + type RowsScanVars struct { + Id int64 + Name string + Age int + } + + assert.NoError(t, PrepareEngine()) + assert.NoError(t, testEngine.Sync2(new(RowsScanVars))) + + cnt, err := testEngine.Insert(&RowsScanVars{ + Name: "xlw", + Age: 42, + }, &RowsScanVars{ + Name: "xlw2", + Age: 24, + }) + assert.NoError(t, err) + assert.EqualValues(t, 2, cnt) + + rows, err := testEngine.Cols("name", "age").Rows(new(RowsScanVars)) + assert.NoError(t, err) + defer rows.Close() + + cnt = 0 + for rows.Next() { + var name string + var age int + err = rows.Scan(&name, &age) + assert.NoError(t, err) + if cnt == 0 { + assert.EqualValues(t, "xlw", name) + assert.EqualValues(t, 42, age) + } else if cnt == 1 { + assert.EqualValues(t, "xlw2", name) + assert.EqualValues(t, 24, age) + } + cnt++ + } + assert.NoError(t, rows.Err()) + assert.EqualValues(t, 2, cnt) +} + +func TestRowsScanBytes(t *testing.T) { + type RowsScanBytes struct { + Id int64 + Bytes1 []byte + Bytes2 []byte + } + + assert.NoError(t, PrepareEngine()) + assert.NoError(t, testEngine.Sync(new(RowsScanBytes))) + + cnt, err := testEngine.Insert(&RowsScanBytes{ + Bytes1: []byte("bytes1"), + Bytes2: []byte("bytes2"), + }, &RowsScanBytes{ + Bytes1: []byte("bytes1-1"), + Bytes2: []byte("bytes2-2"), + }) + assert.NoError(t, err) + assert.EqualValues(t, 2, cnt) + + { + rows, err := testEngine.Cols("bytes1, bytes2").Rows(new(RowsScanBytes)) + assert.NoError(t, err) + defer rows.Close() + + cnt = 0 + var bytes1 []byte + var bytes2 []byte + for rows.Next() { + err = rows.Scan(&bytes1, &bytes2) + assert.NoError(t, err) + if cnt == 0 { + assert.EqualValues(t, []byte("bytes1"), bytes1) + assert.EqualValues(t, []byte("bytes2"), bytes2) + } else if cnt == 1 { + // bytes1 now should be `bytes1` but will be override + assert.EqualValues(t, []byte("bytes1-1"), bytes1) + assert.EqualValues(t, []byte("bytes2-2"), bytes2) + } + cnt++ + } + assert.NoError(t, rows.Err()) + assert.EqualValues(t, 2, cnt) + rows.Close() + } + + { + rows, err := testEngine.Cols("bytes1, bytes2").Rows(new(RowsScanBytes)) + assert.NoError(t, err) + defer rows.Close() + + cnt = 0 + var rsb RowsScanBytes + for rows.Next() { + err = rows.Scan(&rsb) + assert.NoError(t, err) + if cnt == 0 { + assert.EqualValues(t, []byte("bytes1"), rsb.Bytes1) + assert.EqualValues(t, []byte("bytes2"), rsb.Bytes2) + } else if cnt == 1 { + // bytes1 now should be `bytes1` but will be override + assert.EqualValues(t, []byte("bytes1-1"), rsb.Bytes1) + assert.EqualValues(t, []byte("bytes2-2"), rsb.Bytes2) + } + cnt++ + } + assert.NoError(t, rows.Err()) + assert.EqualValues(t, 2, cnt) + rows.Close() + } +} diff --git a/integrations/session_cols_test.go b/integrations/session_cols_test.go index b74c6f8a..462ea7c7 100644 --- a/integrations/session_cols_test.go +++ b/integrations/session_cols_test.go @@ -20,7 +20,7 @@ func TestSetExpr(t *testing.T) { Title string } - assert.NoError(t, testEngine.Sync2(new(UserExprIssue))) + assert.NoError(t, testEngine.Sync(new(UserExprIssue))) var issue = UserExprIssue{ Title: "my issue", @@ -36,7 +36,7 @@ func TestSetExpr(t *testing.T) { Show bool } - assert.NoError(t, testEngine.Sync2(new(UserExpr))) + assert.NoError(t, testEngine.Sync(new(UserExpr))) cnt, err = testEngine.Insert(&UserExpr{ Show: true, @@ -45,7 +45,7 @@ func TestSetExpr(t *testing.T) { assert.EqualValues(t, 1, cnt) var not = "NOT" - if testEngine.Dialect().URI().DBType == schemas.MSSQL { + if testEngine.Dialect().URI().DBType == schemas.MSSQL || testEngine.Dialect().URI().DBType == schemas.DAMENG { not = "~" } cnt, err = testEngine.SetExpr("show", not+" `show`").ID(1).Update(new(UserExpr)) @@ -54,9 +54,9 @@ func TestSetExpr(t *testing.T) { tableName := testEngine.TableName(new(UserExprIssue), true) cnt, err = testEngine.SetExpr("issue_id", - builder.Select("id"). - From(tableName). - Where(builder.Eq{"id": issue.Id})). + builder.Select("`id`"). + From(testEngine.Quote(tableName)). + Where(builder.Eq{"`id`": issue.Id})). ID(1). Update(new(UserExpr)) assert.NoError(t, err) diff --git a/integrations/session_cond_test.go b/integrations/session_cond_test.go index a0a91cad..0597d74e 100644 --- a/integrations/session_cond_test.go +++ b/integrations/session_cond_test.go @@ -37,49 +37,50 @@ func TestBuilder(t *testing.T) { assert.NoError(t, err) var cond Condition - has, err := testEngine.Where(builder.Eq{"col_name": "col1"}).Get(&cond) + var q = testEngine.Quote + has, err := testEngine.Where(builder.Eq{q("col_name"): "col1"}).Get(&cond) assert.NoError(t, err) assert.Equal(t, true, has, "records should exist") - has, err = testEngine.Where(builder.Eq{"col_name": "col1"}. - And(builder.Eq{"op": OpEqual})). + has, err = testEngine.Where(builder.Eq{q("col_name"): "col1"}. + And(builder.Eq{q("op"): OpEqual})). NoAutoCondition(). Get(&cond) assert.NoError(t, err) assert.Equal(t, true, has, "records should exist") - has, err = testEngine.Where(builder.Eq{"col_name": "col1", "op": OpEqual, "value": "1"}). + has, err = testEngine.Where(builder.Eq{q("col_name"): "col1", q("op"): OpEqual, q("value"): "1"}). NoAutoCondition(). Get(&cond) assert.NoError(t, err) assert.Equal(t, true, has, "records should exist") - has, err = testEngine.Where(builder.Eq{"col_name": "col1"}. - And(builder.Neq{"op": OpEqual})). + has, err = testEngine.Where(builder.Eq{q("col_name"): "col1"}. + And(builder.Neq{q("op"): OpEqual})). NoAutoCondition(). Get(&cond) assert.NoError(t, err) assert.Equal(t, false, has, "records should not exist") var conds []Condition - err = testEngine.Where(builder.Eq{"col_name": "col1"}. - And(builder.Eq{"op": OpEqual})). + err = testEngine.Where(builder.Eq{q("col_name"): "col1"}. + And(builder.Eq{q("op"): OpEqual})). Find(&conds) assert.NoError(t, err) assert.EqualValues(t, 1, len(conds), "records should exist") conds = make([]Condition, 0) - err = testEngine.Where(builder.Like{"col_name", "col"}).Find(&conds) + err = testEngine.Where(builder.Like{q("col_name"), "col"}).Find(&conds) assert.NoError(t, err) assert.EqualValues(t, 1, len(conds), "records should exist") conds = make([]Condition, 0) - err = testEngine.Where(builder.Expr("col_name = ?", "col1")).Find(&conds) + err = testEngine.Where(builder.Expr(q("col_name")+" = ?", "col1")).Find(&conds) assert.NoError(t, err) assert.EqualValues(t, 1, len(conds), "records should exist") conds = make([]Condition, 0) - err = testEngine.Where(builder.In("col_name", "col1", "col2")).Find(&conds) + err = testEngine.Where(builder.In(q("col_name"), "col1", "col2")).Find(&conds) assert.NoError(t, err) assert.EqualValues(t, 1, len(conds), "records should exist") @@ -91,8 +92,8 @@ func TestBuilder(t *testing.T) { // complex condtions var where = builder.NewCond() if true { - where = where.And(builder.Eq{"col_name": "col1"}) - where = where.Or(builder.And(builder.In("col_name", "col1", "col2"), builder.Expr("col_name = ?", "col1"))) + where = where.And(builder.Eq{q("col_name"): "col1"}) + where = where.Or(builder.And(builder.In(q("col_name"), "col1", "col2"), builder.Expr(q("col_name")+" = ?", "col1"))) } conds = make([]Condition, 0) @@ -103,7 +104,7 @@ func TestBuilder(t *testing.T) { func TestIn(t *testing.T) { assert.NoError(t, PrepareEngine()) - assert.NoError(t, testEngine.Sync2(new(Userinfo))) + assert.NoError(t, testEngine.Sync(new(Userinfo))) cnt, err := testEngine.Insert([]Userinfo{ { @@ -202,7 +203,7 @@ func TestFindAndCount(t *testing.T) { Name string } - assert.NoError(t, testEngine.Sync2(new(FindAndCount))) + assert.NoError(t, testEngine.Sync(new(FindAndCount))) _, err := testEngine.Insert([]FindAndCount{ { @@ -215,7 +216,7 @@ func TestFindAndCount(t *testing.T) { assert.NoError(t, err) var results []FindAndCount - sess := testEngine.Where("name = ?", "test1") + sess := testEngine.Where("`name` = ?", "test1") conds := sess.Conds() err = sess.Find(&results) assert.NoError(t, err) diff --git a/integrations/session_count_test.go b/integrations/session_count_test.go new file mode 100644 index 00000000..13d84edb --- /dev/null +++ b/integrations/session_count_test.go @@ -0,0 +1,172 @@ +// Copyright 2021 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package integrations + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "xorm.io/builder" +) + +func TestCount(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type UserinfoCount struct { + Departname string + } + assert.NoError(t, testEngine.Sync(new(UserinfoCount))) + + colName := testEngine.GetColumnMapper().Obj2Table("Departname") + var cond builder.Cond = builder.Eq{ + "`" + colName + "`": "dev", + } + + total, err := testEngine.Where(cond).Count(new(UserinfoCount)) + assert.NoError(t, err) + assert.EqualValues(t, 0, total) + + cnt, err := testEngine.Insert(&UserinfoCount{ + Departname: "dev", + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + total, err = testEngine.Where(cond).Count(new(UserinfoCount)) + assert.NoError(t, err) + assert.EqualValues(t, 1, total) + + total, err = testEngine.Where(cond).Table("userinfo_count").Count() + assert.NoError(t, err) + assert.EqualValues(t, 1, total) + + total, err = testEngine.Table("userinfo_count").Count() + assert.NoError(t, err) + assert.EqualValues(t, 1, total) +} + +func TestSQLCount(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type UserinfoCount2 struct { + Id int64 + Departname string + } + + type UserinfoBooks struct { + Id int64 + Pid int64 + IsOpen bool + } + + assertSync(t, new(UserinfoCount2), new(UserinfoBooks)) + + total, err := testEngine.SQL("SELECT count(`id`) FROM " + testEngine.Quote(testEngine.TableName("userinfo_count2", true))). + Count() + assert.NoError(t, err) + assert.EqualValues(t, 0, total) +} + +func TestCountWithOthers(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type CountWithOthers struct { + Id int64 + Name string + } + + assertSync(t, new(CountWithOthers)) + + _, err := testEngine.Insert(&CountWithOthers{ + Name: "orderby", + }) + assert.NoError(t, err) + + _, err = testEngine.Insert(&CountWithOthers{ + Name: "limit", + }) + assert.NoError(t, err) + + total, err := testEngine.OrderBy("`id` desc").Limit(1).Count(new(CountWithOthers)) + assert.NoError(t, err) + assert.EqualValues(t, 2, total) +} + +type CountWithTableName struct { + Id int64 + Name string +} + +func (CountWithTableName) TableName() string { + return "count_with_table_name1" +} + +func TestWithTableName(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + assertSync(t, new(CountWithTableName)) + + _, err := testEngine.Insert(&CountWithTableName{ + Name: "orderby", + }) + assert.NoError(t, err) + + _, err = testEngine.Insert(CountWithTableName{ + Name: "limit", + }) + assert.NoError(t, err) + + total, err := testEngine.OrderBy("`id` desc").Count(new(CountWithTableName)) + assert.NoError(t, err) + assert.EqualValues(t, 2, total) + + total, err = testEngine.OrderBy("`id` desc").Count(CountWithTableName{}) + assert.NoError(t, err) + assert.EqualValues(t, 2, total) +} + +func TestCountWithSelectCols(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + assertSync(t, new(CountWithTableName)) + + _, err := testEngine.Insert(&CountWithTableName{ + Name: "orderby", + }) + assert.NoError(t, err) + + _, err = testEngine.Insert(CountWithTableName{ + Name: "limit", + }) + assert.NoError(t, err) + + total, err := testEngine.Cols("id").Count(new(CountWithTableName)) + assert.NoError(t, err) + assert.EqualValues(t, 2, total) + + total, err = testEngine.Select("count(`id`)").Count(CountWithTableName{}) + assert.NoError(t, err) + assert.EqualValues(t, 2, total) +} + +func TestCountWithGroupBy(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + assertSync(t, new(CountWithTableName)) + + _, err := testEngine.Insert(&CountWithTableName{ + Name: "1", + }) + assert.NoError(t, err) + + _, err = testEngine.Insert(CountWithTableName{ + Name: "2", + }) + assert.NoError(t, err) + + cnt, err := testEngine.GroupBy("`name`").Count(new(CountWithTableName)) + assert.NoError(t, err) + assert.EqualValues(t, 2, cnt) +} diff --git a/integrations/session_delete_test.go b/integrations/session_delete_test.go index f3565963..680c3215 100644 --- a/integrations/session_delete_test.go +++ b/integrations/session_delete_test.go @@ -22,7 +22,7 @@ func TestDelete(t *testing.T) { IsMan bool } - assert.NoError(t, testEngine.Sync2(new(UserinfoDelete))) + assert.NoError(t, testEngine.Sync(new(UserinfoDelete))) session := testEngine.NewSession() defer session.Close() @@ -97,6 +97,7 @@ func TestDeleted(t *testing.T) { // Test normal Find() var records1 []Deleted err = testEngine.Where("`"+testEngine.GetColumnMapper().Obj2Table("Id")+"` > 0").Find(&records1, &Deleted{}) + assert.NoError(t, err) assert.EqualValues(t, 3, len(records1)) // Test normal Get() @@ -132,6 +133,7 @@ func TestDeleted(t *testing.T) { record2 := &Deleted{} has, err = testEngine.ID(2).Get(record2) assert.NoError(t, err) + assert.True(t, has) assert.True(t, record2.DeletedAt.IsZero()) // Test find all records whatever `deleted`. @@ -206,12 +208,12 @@ func TestUnscopeDelete(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 1, cnt) - var nowUnix = time.Now().Unix() + nowUnix := time.Now().Unix() var s UnscopeDeleteStruct cnt, err = testEngine.ID(1).Delete(&s) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) - assert.EqualValues(t, nowUnix, s.DeletedAt.Unix()) + assert.LessOrEqual(t, int(s.DeletedAt.Unix()-nowUnix), 1) var s1 UnscopeDeleteStruct has, err := testEngine.ID(1).Get(&s1) @@ -223,7 +225,7 @@ func TestUnscopeDelete(t *testing.T) { assert.NoError(t, err) assert.True(t, has) assert.EqualValues(t, "test", s2.Name) - assert.EqualValues(t, nowUnix, s2.DeletedAt.Unix()) + assert.LessOrEqual(t, int(s2.DeletedAt.Unix()-nowUnix), 1) cnt, err = testEngine.ID(1).Unscoped().Delete(new(UnscopeDeleteStruct)) assert.NoError(t, err) @@ -239,3 +241,53 @@ func TestUnscopeDelete(t *testing.T) { assert.NoError(t, err) assert.False(t, has) } + +func TestDelete2(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type UserinfoDelete2 struct { + Uid int64 `xorm:"id pk not null autoincr"` + IsMan bool + } + + assert.NoError(t, testEngine.Sync(new(UserinfoDelete2))) + + user := UserinfoDelete2{} + cnt, err := testEngine.Insert(&user) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + cnt, err = testEngine.Table("userinfo_delete2").In("id", []int{1}).Delete() + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + user2 := UserinfoDelete2{} + has, err := testEngine.ID(1).Get(&user2) + assert.NoError(t, err) + assert.False(t, has) +} + +func TestTruncate(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type TruncateUser struct { + Uid int64 `xorm:"id pk not null autoincr"` + } + + assert.NoError(t, testEngine.Sync(new(TruncateUser))) + + cnt, err := testEngine.Insert(&TruncateUser{}) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + _, err = testEngine.Delete(&TruncateUser{}) + assert.Error(t, err) + + _, err = testEngine.Truncate(&TruncateUser{}) + assert.NoError(t, err) + + user2 := TruncateUser{} + has, err := testEngine.ID(1).Get(&user2) + assert.NoError(t, err) + assert.False(t, has) +} diff --git a/integrations/session_exist_test.go b/integrations/session_exist_test.go index 6247c91a..ca1e66ad 100644 --- a/integrations/session_exist_test.go +++ b/integrations/session_exist_test.go @@ -48,19 +48,19 @@ func TestExistStruct(t *testing.T) { assert.NoError(t, err) assert.False(t, has) - has, err = testEngine.Where("name = ?", "test1").Exist(&RecordExist{}) + has, err = testEngine.Where("`name` = ?", "test1").Exist(&RecordExist{}) assert.NoError(t, err) assert.True(t, has) - has, err = testEngine.Where("name = ?", "test2").Exist(&RecordExist{}) + has, err = testEngine.Where("`name` = ?", "test2").Exist(&RecordExist{}) assert.NoError(t, err) assert.False(t, has) - has, err = testEngine.SQL("select * from "+testEngine.TableName("record_exist", true)+" where name = ?", "test1").Exist() + has, err = testEngine.SQL("select * from "+testEngine.Quote(testEngine.TableName("record_exist", true))+" where `name` = ?", "test1").Exist() assert.NoError(t, err) assert.True(t, has) - has, err = testEngine.SQL("select * from "+testEngine.TableName("record_exist", true)+" where name = ?", "test2").Exist() + has, err = testEngine.SQL("select * from "+testEngine.Quote(testEngine.TableName("record_exist", true))+" where `name` = ?", "test2").Exist() assert.NoError(t, err) assert.False(t, has) @@ -68,13 +68,17 @@ func TestExistStruct(t *testing.T) { assert.NoError(t, err) assert.True(t, has) - has, err = testEngine.Table("record_exist").Where("name = ?", "test1").Exist() + has, err = testEngine.Table("record_exist").Where("`name` = ?", "test1").Exist() assert.NoError(t, err) assert.True(t, has) - has, err = testEngine.Table("record_exist").Where("name = ?", "test2").Exist() + has, err = testEngine.Table("record_exist").Where("`name` = ?", "test2").Exist() assert.NoError(t, err) assert.False(t, has) + + has, err = testEngine.Table(new(RecordExist)).ID(1).Cols("id").Exist() + assert.NoError(t, err) + assert.True(t, has) } func TestExistStructForJoin(t *testing.T) { @@ -95,7 +99,7 @@ func TestExistStructForJoin(t *testing.T) { Name string } - assert.NoError(t, testEngine.Sync2(new(Number), new(OrderList), new(Player))) + assert.NoError(t, testEngine.Sync(new(Number), new(OrderList), new(Player))) var ply Player cnt, err := testEngine.Insert(&ply) @@ -120,43 +124,43 @@ func TestExistStructForJoin(t *testing.T) { defer session.Close() session.Table("number"). - Join("INNER", "order_list", "order_list.id = number.lid"). - Join("LEFT", "player", "player.id = order_list.eid"). - Where("number.lid = ?", 1) + Join("INNER", "order_list", "`order_list`.`id` = `number`.`lid`"). + Join("LEFT", "player", "`player`.`id` = `order_list`.`eid`"). + Where("`number`.`lid` = ?", 1) has, err := session.Exist() assert.NoError(t, err) assert.True(t, has) session.Table("number"). - Join("INNER", "order_list", "order_list.id = number.lid"). - Join("LEFT", "player", "player.id = order_list.eid"). - Where("number.lid = ?", 2) + Join("INNER", "order_list", "`order_list`.`id` = `number`.`lid`"). + Join("LEFT", "player", "`player`.`id` = `order_list`.`eid`"). + Where("`number`.`lid` = ?", 2) has, err = session.Exist() assert.NoError(t, err) assert.False(t, has) session.Table("number"). - Select("order_list.id"). - Join("INNER", "order_list", "order_list.id = number.lid"). - Join("LEFT", "player", "player.id = order_list.eid"). - Where("order_list.id = ?", 1) + Select("`order_list`.`id`"). + Join("INNER", "order_list", "`order_list`.`id` = `number`.`lid`"). + Join("LEFT", "player", "`player`.`id` = `order_list`.`eid`"). + Where("`order_list`.`id` = ?", 1) has, err = session.Exist() assert.NoError(t, err) assert.True(t, has) session.Table("number"). Select("player.id"). - Join("INNER", "order_list", "order_list.id = number.lid"). - Join("LEFT", "player", "player.id = order_list.eid"). - Where("player.id = ?", 2) + Join("INNER", "order_list", "`order_list`.`id` = `number`.`lid`"). + Join("LEFT", "player", "`player`.`id` = `order_list`.`eid`"). + Where("`player`.`id` = ?", 2) has, err = session.Exist() assert.NoError(t, err) assert.False(t, has) session.Table("number"). Select("player.id"). - Join("INNER", "order_list", "order_list.id = number.lid"). - Join("LEFT", "player", "player.id = order_list.eid") + Join("INNER", "order_list", "`order_list`.`id` = `number`.`lid`"). + Join("LEFT", "player", "`player`.`id` = `order_list`.`eid`") has, err = session.Exist() assert.NoError(t, err) assert.True(t, has) @@ -170,15 +174,15 @@ func TestExistStructForJoin(t *testing.T) { session.Table("number"). Select("player.id"). - Join("INNER", "order_list", "order_list.id = number.lid"). - Join("LEFT", "player", "player.id = order_list.eid") + Join("INNER", "order_list", "`order_list`.`id` = `number`.`lid`"). + Join("LEFT", "player", "`player`.`id` = `order_list`.`eid`") has, err = session.Exist() assert.Error(t, err) assert.False(t, has) session.Table("number"). Select("player.id"). - Join("LEFT", "player", "player.id = number.lid") + Join("LEFT", "player", "`player`.`id` = `number`.`lid`") has, err = session.Exist() assert.NoError(t, err) assert.True(t, has) diff --git a/integrations/session_find_test.go b/integrations/session_find_test.go index c3e99183..5c2a4c68 100644 --- a/integrations/session_find_test.go +++ b/integrations/session_find_test.go @@ -8,6 +8,8 @@ import ( "testing" "time" + "xorm.io/builder" + "xorm.io/xorm" "xorm.io/xorm/internal/utils" "xorm.io/xorm/names" @@ -32,21 +34,21 @@ func TestJoinLimit(t *testing.T) { Name string } - assert.NoError(t, testEngine.Sync2(new(Salary), new(CheckList), new(Empsetting))) + assert.NoError(t, testEngine.Sync(new(Salary), new(CheckList), new(Empsetting))) var emp Empsetting cnt, err := testEngine.Insert(&emp) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) - var checklist = CheckList{ + checklist := CheckList{ Eid: emp.Id, } cnt, err = testEngine.Insert(&checklist) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) - var salary = Salary{ + salary := Salary{ Lid: checklist.Id, } cnt, err = testEngine.Insert(&salary) @@ -55,8 +57,8 @@ func TestJoinLimit(t *testing.T) { var salaries []Salary err = testEngine.Table("salary"). - Join("INNER", "check_list", "check_list.id = salary.lid"). - Join("LEFT", "empsetting", "empsetting.id = check_list.eid"). + Join("INNER", "check_list", "`check_list`.`id` = `salary`.`lid`"). + Join("LEFT", "empsetting", "`empsetting`.`id` = `check_list`.`eid`"). Limit(10, 0). Find(&salaries) assert.NoError(t, err) @@ -68,10 +70,10 @@ func TestWhere(t *testing.T) { assertSync(t, new(Userinfo)) users := make([]Userinfo, 0) - err := testEngine.Where("id > ?", 2).Find(&users) + err := testEngine.Where("`id` > ?", 2).Find(&users) assert.NoError(t, err) - err = testEngine.Where("id > ?", 2).And("id < ?", 10).Find(&users) + err = testEngine.Where("`id` > ?", 2).And("`id` < ?", 10).Find(&users) assert.NoError(t, err) } @@ -84,8 +86,11 @@ func TestFind(t *testing.T) { err := testEngine.Find(&users) assert.NoError(t, err) + err = testEngine.Limit(10, 0).Find(&users) + assert.NoError(t, err) + users2 := make([]Userinfo, 0) - var tbName = testEngine.Quote(testEngine.TableName(new(Userinfo), true)) + tbName := testEngine.Quote(testEngine.TableName(new(Userinfo), true)) err = testEngine.SQL("select * from " + tbName).Find(&users2) assert.NoError(t, err) } @@ -115,56 +120,56 @@ func (TeamUser) TableName() string { } func TestFind3(t *testing.T) { - var teamUser = new(TeamUser) + teamUser := new(TeamUser) assert.NoError(t, PrepareEngine()) - err := testEngine.Sync2(new(Team), teamUser) + err := testEngine.Sync(new(Team), teamUser) assert.NoError(t, err) var teams []Team - err = testEngine.Cols("`team`.id"). - Where("`team_user`.org_id=?", 1). - And("`team_user`.uid=?", 2). - Join("INNER", "`team_user`", "`team_user`.team_id=`team`.id"). + err = testEngine.Cols("`team`.`id`"). + Where("`team_user`.`org_id`=?", 1). + And("`team_user`.`uid`=?", 2). + Join("INNER", "`team_user`", "`team_user`.`team_id`=`team`.`id`"). Find(&teams) assert.NoError(t, err) teams = make([]Team, 0) err = testEngine.Cols("`team`.id"). - Where("`team_user`.org_id=?", 1). - And("`team_user`.uid=?", 2). - Join("INNER", teamUser, "`team_user`.team_id=`team`.id"). + Where("`team_user`.`org_id`=?", 1). + And("`team_user`.`uid`=?", 2). + Join("INNER", teamUser, "`team_user`.`team_id`=`team`.`id`"). Find(&teams) assert.NoError(t, err) teams = make([]Team, 0) - err = testEngine.Cols("`team`.id"). - Where("`team_user`.org_id=?", 1). - And("`team_user`.uid=?", 2). - Join("INNER", []interface{}{teamUser}, "`team_user`.team_id=`team`.id"). + err = testEngine.Cols("`team`.`id`"). + Where("`team_user`.`org_id`=?", 1). + And("`team_user`.`uid`=?", 2). + Join("INNER", []interface{}{teamUser}, "`team_user`.`team_id`=`team`.`id`"). Find(&teams) assert.NoError(t, err) teams = make([]Team, 0) - err = testEngine.Cols("`team`.id"). - Where("`tu`.org_id=?", 1). - And("`tu`.uid=?", 2). - Join("INNER", []string{"team_user", "tu"}, "`tu`.team_id=`team`.id"). + err = testEngine.Cols("`team`.`id`"). + Where("`tu`.`org_id`=?", 1). + And("`tu`.`uid`=?", 2). + Join("INNER", []string{"team_user", "tu"}, "`tu`.`team_id`=`team`.`id`"). Find(&teams) assert.NoError(t, err) teams = make([]Team, 0) - err = testEngine.Cols("`team`.id"). - Where("`tu`.org_id=?", 1). - And("`tu`.uid=?", 2). - Join("INNER", []interface{}{"team_user", "tu"}, "`tu`.team_id=`team`.id"). + err = testEngine.Cols("`team`.`id`"). + Where("`tu`.`org_id`=?", 1). + And("`tu`.`uid`=?", 2). + Join("INNER", []interface{}{"team_user", "tu"}, "`tu`.`team_id`=`team`.`id`"). Find(&teams) assert.NoError(t, err) teams = make([]Team, 0) - err = testEngine.Cols("`team`.id"). - Where("`tu`.org_id=?", 1). - And("`tu`.uid=?", 2). - Join("INNER", []interface{}{teamUser, "tu"}, "`tu`.team_id=`team`.id"). + err = testEngine.Cols("`team`.`id`"). + Where("`tu`.`org_id`=?", 1). + And("`tu`.`uid`=?", 2). + Join("INNER", []interface{}{teamUser, "tu"}, "`tu`.`team_id`=`team`.`id`"). Find(&teams) assert.NoError(t, err) } @@ -237,12 +242,16 @@ func TestOrder(t *testing.T) { assertSync(t, new(Userinfo)) users := make([]Userinfo, 0) - err := testEngine.OrderBy("id desc").Find(&users) + err := testEngine.OrderBy("`id` desc").Find(&users) assert.NoError(t, err) users2 := make([]Userinfo, 0) err = testEngine.Asc("id", "username").Desc("height").Find(&users2) assert.NoError(t, err) + + users = make([]Userinfo, 0) + err = testEngine.OrderBy("CASE WHEN username LIKE ? THEN 0 ELSE 1 END DESC", "a").Find(&users) + assert.NoError(t, err) } func TestGroupBy(t *testing.T) { @@ -250,7 +259,7 @@ func TestGroupBy(t *testing.T) { assertSync(t, new(Userinfo)) users := make([]Userinfo, 0) - err := testEngine.GroupBy("id, username").Find(&users) + err := testEngine.GroupBy("`id`, `username`").Find(&users) assert.NoError(t, err) } @@ -259,7 +268,7 @@ func TestHaving(t *testing.T) { assertSync(t, new(Userinfo)) users := make([]Userinfo, 0) - err := testEngine.GroupBy("username").Having("username='xlw'").Find(&users) + err := testEngine.GroupBy("`username`").Having("`username`='xlw'").Find(&users) assert.NoError(t, err) } @@ -402,16 +411,16 @@ func TestFindMapPtrString(t *testing.T) { assert.NoError(t, err) } -func TestFindBit(t *testing.T) { - type FindBitStruct struct { +func TestFindBool(t *testing.T) { + type FindBoolStruct struct { Id int64 - Msg bool `xorm:"bit"` + Msg bool } assert.NoError(t, PrepareEngine()) - assertSync(t, new(FindBitStruct)) + assertSync(t, new(FindBoolStruct)) - cnt, err := testEngine.Insert([]FindBitStruct{ + cnt, err := testEngine.Insert([]FindBoolStruct{ { Msg: false, }, @@ -422,14 +431,13 @@ func TestFindBit(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 2, cnt) - var results = make([]FindBitStruct, 0, 2) + results := make([]FindBoolStruct, 0, 2) err = testEngine.Find(&results) assert.NoError(t, err) assert.EqualValues(t, 2, len(results)) } func TestFindMark(t *testing.T) { - type Mark struct { Mark1 string `xorm:"VARCHAR(1)"` Mark2 string `xorm:"VARCHAR(1)"` @@ -454,7 +462,7 @@ func TestFindMark(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 2, cnt) - var results = make([]Mark, 0, 2) + results := make([]Mark, 0, 2) err = testEngine.Find(&results) assert.NoError(t, err) assert.EqualValues(t, 2, len(results)) @@ -464,7 +472,7 @@ func TestFindAndCountOneFunc(t *testing.T) { type FindAndCountStruct struct { Id int64 Content string - Msg bool `xorm:"bit"` + Msg bool } assert.NoError(t, PrepareEngine()) @@ -483,7 +491,7 @@ func TestFindAndCountOneFunc(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 2, cnt) - var results = make([]FindAndCountStruct, 0, 2) + results := make([]FindAndCountStruct, 0, 2) cnt, err = testEngine.Limit(1).FindAndCount(&results) assert.NoError(t, err) assert.EqualValues(t, 1, len(results)) @@ -496,7 +504,7 @@ func TestFindAndCountOneFunc(t *testing.T) { assert.EqualValues(t, 2, cnt) results = make([]FindAndCountStruct, 0, 1) - cnt, err = testEngine.Where("msg = ?", true).FindAndCount(&results) + cnt, err = testEngine.Where("`msg` = ?", true).FindAndCount(&results) assert.NoError(t, err) assert.EqualValues(t, 1, len(results)) assert.EqualValues(t, 1, cnt) @@ -546,21 +554,21 @@ func TestFindAndCountOneFunc(t *testing.T) { }, results[0]) results = make([]FindAndCountStruct, 0, 1) - cnt, err = testEngine.Where("msg = ?", true).Select("id, content, msg"). + cnt, err = testEngine.Where("`msg` = ?", true).Select("`id`, `content`, `msg`"). Limit(1).FindAndCount(&results) assert.NoError(t, err) assert.EqualValues(t, 1, len(results)) assert.EqualValues(t, 1, cnt) results = make([]FindAndCountStruct, 0, 1) - cnt, err = testEngine.Where("msg = ?", true).Cols("id", "content", "msg"). + cnt, err = testEngine.Where("`msg` = ?", true).Cols("id", "content", "msg"). Limit(1).FindAndCount(&results) assert.NoError(t, err) assert.EqualValues(t, 1, len(results)) assert.EqualValues(t, 1, cnt) results = make([]FindAndCountStruct, 0, 1) - cnt, err = testEngine.Where("msg = ?", true).Desc("id"). + cnt, err = testEngine.Where("`msg` = ?", true).Desc("id"). Limit(1).Cols("content").FindAndCount(&results) assert.NoError(t, err) assert.EqualValues(t, 1, len(results)) @@ -608,14 +616,14 @@ func TestFindAndCount2(t *testing.T) { assert.NoError(t, PrepareEngine()) assertSync(t, new(TestFindAndCountUser), new(TestFindAndCountHotel)) - var u = TestFindAndCountUser{ + u := TestFindAndCountUser{ Name: "myname", } cnt, err := testEngine.Insert(&u) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) - var hotel = TestFindAndCountHotel{ + hotel := TestFindAndCountHotel{ Name: "myhotel", Code: "111", Region: "222", @@ -646,13 +654,98 @@ func TestFindAndCount2(t *testing.T) { cnt, err = testEngine. Table(new(TestFindAndCountHotel)). Alias("t"). - Where("t.region like '6501%'"). + Where("`t`.`region` like '6501%'"). Limit(10, 0). FindAndCount(&hotels) assert.NoError(t, err) assert.EqualValues(t, 0, cnt) } +type FindAndCountWithTableName struct { + Id int64 + Name string +} + +func (FindAndCountWithTableName) TableName() string { + return "find_and_count_with_table_name1" +} + +func TestFindAndCountWithTableName(t *testing.T) { + assert.NoError(t, PrepareEngine()) + assertSync(t, new(FindAndCountWithTableName)) + + cnt, err := testEngine.Insert(&FindAndCountWithTableName{ + Name: "1", + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var res []FindAndCountWithTableName + cnt, err = testEngine.FindAndCount(&res) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) +} + +func TestFindAndCountWithGroupBy(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type FindAndCountWithGroupBy struct { + Id int64 + Age int `xorm:"index"` + Name string + } + + assert.NoError(t, testEngine.Sync(new(FindAndCountWithGroupBy))) + + _, err := testEngine.Insert([]FindAndCountWithGroupBy{ + { + Name: "test1", + Age: 10, + }, + { + Name: "test2", + Age: 20, + }, + }) + assert.NoError(t, err) + + var results []FindAndCountWithGroupBy + cnt, err := testEngine.GroupBy("`age`").FindAndCount(&results) + assert.NoError(t, err) + assert.EqualValues(t, 2, cnt) + assert.EqualValues(t, 2, len(results)) +} + +func TestFindAndCountWithDistinct(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type FindAndCountWithDistinct struct { + Id int64 + Age int `xorm:"index"` + Name string + } + + assert.NoError(t, testEngine.Sync(new(FindAndCountWithDistinct))) + + _, err := testEngine.Insert([]FindAndCountWithDistinct{ + { + Name: "test1", + Age: 10, + }, + { + Name: "test2", + Age: 20, + }, + }) + assert.NoError(t, err) + + var results []FindAndCountWithDistinct + cnt, err := testEngine.Distinct("`age`").FindAndCount(&results) + assert.NoError(t, err) + assert.EqualValues(t, 2, cnt) + assert.EqualValues(t, 2, len(results)) +} + type FindMapDevice struct { Deviceid string `xorm:"pk"` Status int @@ -677,14 +770,14 @@ func TestFindMapStringId(t *testing.T) { deviceMaps := make(map[string]*FindMapDevice, len(deviceIDs)) err = testEngine. - Where("status = ?", 1). + Where("`status` = ?", 1). In("deviceid", deviceIDs). Find(&deviceMaps) assert.NoError(t, err) deviceMaps2 := make(map[string]FindMapDevice, len(deviceIDs)) err = testEngine. - Where("status = ?", 1). + Where("`status` = ?", 1). In("deviceid", deviceIDs). Find(&deviceMaps2) assert.NoError(t, err) @@ -861,17 +954,21 @@ func TestFindJoin(t *testing.T) { assertSync(t, new(SceneItem), new(DeviceUserPrivrels), new(Order)) var scenes []SceneItem - err := testEngine.Join("LEFT OUTER", "device_user_privrels", "device_user_privrels.device_id=scene_item.device_id"). - Where("scene_item.type=?", 3).Or("device_user_privrels.user_id=?", 339).Find(&scenes) + err := testEngine.Join("LEFT OUTER", "device_user_privrels", "`device_user_privrels`.`device_id`=`scene_item`.`device_id`"). + Where("`scene_item`.`type`=?", 3).Or("`device_user_privrels`.`user_id`=?", 339).Find(&scenes) assert.NoError(t, err) scenes = make([]SceneItem, 0) - err = testEngine.Join("LEFT OUTER", new(DeviceUserPrivrels), "device_user_privrels.device_id=scene_item.device_id"). - Where("scene_item.type=?", 3).Or("device_user_privrels.user_id=?", 339).Find(&scenes) + err = testEngine.Join("LEFT OUTER", new(DeviceUserPrivrels), "`device_user_privrels`.`device_id`=`scene_item`.`device_id`"). + Where("`scene_item`.`type`=?", 3).Or("`device_user_privrels`.`user_id`=?", 339).Find(&scenes) assert.NoError(t, err) scenes = make([]SceneItem, 0) - err = testEngine.Join("INNER", "order", "`scene_item`.device_id=`order`.id").Find(&scenes) + err = testEngine.Join("INNER", "order", "`scene_item`.`device_id`=`order`.`id`").Find(&scenes) + assert.NoError(t, err) + + scenes = make([]SceneItem, 0) + err = testEngine.Join("INNER", "order", builder.Expr("`scene_item`.`device_id`=`order`.`id`")).Find(&scenes) assert.NoError(t, err) } @@ -891,7 +988,7 @@ func TestJoinFindLimit(t *testing.T) { assertSync(t, new(JoinFindLimit1), new(JoinFindLimit2)) var finds []JoinFindLimit1 - err := testEngine.Join("INNER", new(JoinFindLimit2), "join_find_limit2.eid=join_find_limit1.id"). + err := testEngine.Join("INNER", new(JoinFindLimit2), "`join_find_limit2`.`eid`=`join_find_limit1`.`id`"). Limit(10, 10).Find(&finds) assert.NoError(t, err) } @@ -923,9 +1020,9 @@ func TestMoreExtends(t *testing.T) { assertSync(t, new(MoreExtendsUsers), new(MoreExtendsBooks)) var books []MoreExtendsBooksExtend - err := testEngine.Table("more_extends_books").Select("more_extends_books.*, more_extends_users.*"). - Join("INNER", "more_extends_users", "more_extends_books.user_id = more_extends_users.id"). - Where("more_extends_books.name LIKE ?", "abc"). + err := testEngine.Table("more_extends_books").Select("`more_extends_books`.*, `more_extends_users`.*"). + Join("INNER", "more_extends_users", "`more_extends_books`.`user_id` = `more_extends_users`.`id`"). + Where("`more_extends_books`.`name` LIKE ?", "abc"). Limit(10, 10). Find(&books) assert.NoError(t, err) @@ -933,9 +1030,9 @@ func TestMoreExtends(t *testing.T) { books = make([]MoreExtendsBooksExtend, 0, len(books)) err = testEngine.Table("more_extends_books"). Alias("m"). - Select("m.*, more_extends_users.*"). - Join("INNER", "more_extends_users", "m.user_id = more_extends_users.id"). - Where("m.name LIKE ?", "abc"). + Select("`m`.*, `more_extends_users`.*"). + Join("INNER", "more_extends_users", "`m`.`user_id` = `more_extends_users`.`id`"). + Where("`m`.`name` LIKE ?", "abc"). Limit(10, 10). Find(&books) assert.NoError(t, err) @@ -962,3 +1059,140 @@ func TestDistinctAndCols(t *testing.T) { assert.EqualValues(t, 1, len(names)) assert.EqualValues(t, "test", names[0]) } + +func TestUpdateFind(t *testing.T) { + type TestUpdateFind struct { + Id int64 + Name string + } + + assert.NoError(t, PrepareEngine()) + assertSync(t, new(TestUpdateFind)) + + session := testEngine.NewSession() + defer session.Close() + + tuf := TestUpdateFind{ + Name: "test", + } + _, err := session.Insert(&tuf) + assert.NoError(t, err) + _, err = session.Where("`id` = ?", tuf.Id).Update(&TestUpdateFind{}) + assert.EqualError(t, xorm.ErrNoColumnsTobeUpdated, err.Error()) + + var tufs []TestUpdateFind + err = session.Where("`id` = ?", tuf.Id).Find(&tufs) + assert.NoError(t, err) +} + +func TestFindAnonymousStruct(t *testing.T) { + type FindAnonymousStruct struct { + Id int64 + Name string + Age int + IsMan bool + } + + assert.NoError(t, PrepareEngine()) + assertSync(t, new(FindAnonymousStruct)) + + cnt, err := testEngine.Insert(&FindAnonymousStruct{ + Name: "xlw", + Age: 42, + IsMan: true, + }) + assert.EqualValues(t, 1, cnt) + assert.NoError(t, err) + + findRes := make([]struct { + Id int64 + Name string + }, 0) + err = testEngine.Table(new(FindAnonymousStruct)).Find(&findRes) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(findRes)) + assert.EqualValues(t, 1, findRes[0].Id) + assert.EqualValues(t, "xlw", findRes[0].Name) + + findRes = make([]struct { + Id int64 + Name string + }, 0) + err = testEngine.Select("`id`,`name`").Table(new(FindAnonymousStruct)).Find(&findRes) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(findRes)) + assert.EqualValues(t, 1, findRes[0].Id) + assert.EqualValues(t, "xlw", findRes[0].Name) +} + +func TestFindBytesVars(t *testing.T) { + type FindBytesVars struct { + Id int64 + Bytes1 []byte + Bytes2 []byte + } + + assert.NoError(t, PrepareEngine()) + assertSync(t, new(FindBytesVars)) + + _, err := testEngine.Insert([]FindBytesVars{ + { + Bytes1: []byte("bytes1"), + Bytes2: []byte("bytes2"), + }, + { + Bytes1: []byte("bytes1-1"), + Bytes2: []byte("bytes2-2"), + }, + }) + assert.NoError(t, err) + + var gbv []FindBytesVars + err = testEngine.Find(&gbv) + assert.NoError(t, err) + assert.EqualValues(t, 2, len(gbv)) + assert.EqualValues(t, []byte("bytes1"), gbv[0].Bytes1) + assert.EqualValues(t, []byte("bytes2"), gbv[0].Bytes2) + assert.EqualValues(t, []byte("bytes1-1"), gbv[1].Bytes1) + assert.EqualValues(t, []byte("bytes2-2"), gbv[1].Bytes2) + + err = testEngine.Find(&gbv) + assert.NoError(t, err) + assert.EqualValues(t, 4, len(gbv)) + assert.EqualValues(t, []byte("bytes1"), gbv[0].Bytes1) + assert.EqualValues(t, []byte("bytes2"), gbv[0].Bytes2) + assert.EqualValues(t, []byte("bytes1-1"), gbv[1].Bytes1) + assert.EqualValues(t, []byte("bytes2-2"), gbv[1].Bytes2) + assert.EqualValues(t, []byte("bytes1"), gbv[2].Bytes1) + assert.EqualValues(t, []byte("bytes2"), gbv[2].Bytes2) + assert.EqualValues(t, []byte("bytes1-1"), gbv[3].Bytes1) + assert.EqualValues(t, []byte("bytes2-2"), gbv[3].Bytes2) +} + +func TestUpdateFindDate(t *testing.T) { + type TestUpdateFindDate struct { + Id int64 + Name string + Tm time.Time `xorm:"DATE created"` + } + + assert.NoError(t, PrepareEngine()) + assertSync(t, new(TestUpdateFindDate)) + + session := testEngine.NewSession() + defer session.Close() + + tuf := TestUpdateFindDate{ + Name: "test", + } + _, err := session.Insert(&tuf) + assert.NoError(t, err) + _, err = session.Where("`id` = ?", tuf.Id).Update(&TestUpdateFindDate{}) + assert.EqualError(t, xorm.ErrNoColumnsTobeUpdated, err.Error()) + + var tufs []TestUpdateFindDate + err = session.Find(&tufs) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(tufs)) + assert.EqualValues(t, tuf.Tm.Format("2006-01-02"), tufs[0].Tm.Format("2006-01-02")) +} diff --git a/integrations/session_get_test.go b/integrations/session_get_test.go index e4d9f82e..841ec709 100644 --- a/integrations/session_get_test.go +++ b/integrations/session_get_test.go @@ -6,67 +6,22 @@ package integrations import ( "database/sql" + "errors" "fmt" - "strconv" + "math/big" "testing" "time" + "xorm.io/xorm" "xorm.io/xorm/contexts" + "xorm.io/xorm/convert" + "xorm.io/xorm/dialects" "xorm.io/xorm/schemas" + "github.com/shopspring/decimal" "github.com/stretchr/testify/assert" ) -func convertInt(v interface{}) (int64, error) { - switch v.(type) { - case int: - return int64(v.(int)), nil - case int8: - return int64(v.(int8)), nil - case int16: - return int64(v.(int16)), nil - case int32: - return int64(v.(int32)), nil - case int64: - return v.(int64), nil - case []byte: - i, err := strconv.ParseInt(string(v.([]byte)), 10, 64) - if err != nil { - return 0, err - } - return i, nil - case string: - i, err := strconv.ParseInt(v.(string), 10, 64) - if err != nil { - return 0, err - } - return i, nil - } - return 0, fmt.Errorf("unsupported type: %v", v) -} - -func convertFloat(v interface{}) (float64, error) { - switch v.(type) { - case float32: - return float64(v.(float32)), nil - case float64: - return v.(float64), nil - case string: - i, err := strconv.ParseFloat(v.(string), 64) - if err != nil { - return 0, err - } - return i, nil - case []byte: - i, err := strconv.ParseFloat(string(v.([]byte)), 64) - if err != nil { - return 0, err - } - return i, nil - } - return 0, fmt.Errorf("unsupported type: %v", v) -} - func TestGetVar(t *testing.T) { assert.NoError(t, PrepareEngine()) @@ -78,9 +33,9 @@ func TestGetVar(t *testing.T) { Created time.Time `xorm:"created"` } - assert.NoError(t, testEngine.Sync2(new(GetVar))) + assert.NoError(t, testEngine.Sync(new(GetVar))) - var data = GetVar{ + data := GetVar{ Msg: "hi", Age: 28, Money: 1.5, @@ -101,15 +56,15 @@ func TestGetVar(t *testing.T) { assert.Equal(t, 28, age) var ageMax int - has, err = testEngine.SQL("SELECT max(age) FROM "+testEngine.TableName("get_var", true)+" WHERE `id` = ?", data.Id).Get(&ageMax) + has, err = testEngine.SQL("SELECT max(`age`) FROM "+testEngine.Quote(testEngine.TableName("get_var", true))+" WHERE `id` = ?", data.Id).Get(&ageMax) assert.NoError(t, err) assert.Equal(t, true, has) assert.Equal(t, 28, ageMax) var age2 int64 has, err = testEngine.Table("get_var").Cols("age"). - Where("age > ?", 20). - And("age < ?", 30). + Where("`age` > ?", 20). + And("`age` < ?", 30). Get(&age2) assert.NoError(t, err) assert.Equal(t, true, has) @@ -123,8 +78,8 @@ func TestGetVar(t *testing.T) { var age4 int16 has, err = testEngine.Table("get_var").Cols("age"). - Where("age > ?", 20). - And("age < ?", 30). + Where("`age` > ?", 20). + And("`age` < ?", 30). Get(&age4) assert.NoError(t, err) assert.Equal(t, true, has) @@ -132,8 +87,8 @@ func TestGetVar(t *testing.T) { var age5 int32 has, err = testEngine.Table("get_var").Cols("age"). - Where("age > ?", 20). - And("age < ?", 30). + Where("`age` > ?", 20). + And("`age` < ?", 30). Get(&age5) assert.NoError(t, err) assert.Equal(t, true, has) @@ -147,8 +102,8 @@ func TestGetVar(t *testing.T) { var age7 int64 has, err = testEngine.Table("get_var").Cols("age"). - Where("age > ?", 20). - And("age < ?", 30). + Where("`age` > ?", 20). + And("`age` < ?", 30). Get(&age7) assert.NoError(t, err) assert.Equal(t, true, has) @@ -162,8 +117,8 @@ func TestGetVar(t *testing.T) { var age9 int16 has, err = testEngine.Table("get_var").Cols("age"). - Where("age > ?", 20). - And("age < ?", 30). + Where("`age` > ?", 20). + And("`age` < ?", 30). Get(&age9) assert.NoError(t, err) assert.Equal(t, true, has) @@ -171,8 +126,8 @@ func TestGetVar(t *testing.T) { var age10 int32 has, err = testEngine.Table("get_var").Cols("age"). - Where("age > ?", 20). - And("age < ?", 30). + Where("`age` > ?", 20). + And("`age` < ?", 30). Get(&age10) assert.NoError(t, err) assert.Equal(t, true, has) @@ -207,20 +162,20 @@ func TestGetVar(t *testing.T) { var money2 float64 if testEngine.Dialect().URI().DBType == schemas.MSSQL { - has, err = testEngine.SQL("SELECT TOP 1 money FROM " + testEngine.TableName("get_var", true)).Get(&money2) + has, err = testEngine.SQL("SELECT TOP 1 `money` FROM " + testEngine.Quote(testEngine.TableName("get_var", true))).Get(&money2) } else { - has, err = testEngine.SQL("SELECT money FROM " + testEngine.TableName("get_var", true) + " LIMIT 1").Get(&money2) + has, err = testEngine.SQL("SELECT `money` FROM " + testEngine.Quote(testEngine.TableName("get_var", true)) + " LIMIT 1").Get(&money2) } assert.NoError(t, err) assert.Equal(t, true, has) assert.Equal(t, "1.5", fmt.Sprintf("%.1f", money2)) var money3 float64 - has, err = testEngine.SQL("SELECT money FROM " + testEngine.TableName("get_var", true) + " WHERE money > 20").Get(&money3) + has, err = testEngine.SQL("SELECT `money` FROM " + testEngine.Quote(testEngine.TableName("get_var", true)) + " WHERE `money` > 20").Get(&money3) assert.NoError(t, err) assert.Equal(t, false, has) - var valuesString = make(map[string]string) + valuesString := make(map[string]string) has, err = testEngine.Table("get_var").Get(&valuesString) assert.NoError(t, err) assert.Equal(t, true, has) @@ -232,8 +187,8 @@ func TestGetVar(t *testing.T) { // for mymysql driver, interface{} will be []byte, so ignore it currently if testEngine.DriverName() != "mymysql" { - var valuesInter = make(map[string]interface{}) - has, err = testEngine.Table("get_var").Where("id = ?", 1).Select("*").Get(&valuesInter) + valuesInter := make(map[string]interface{}) + has, err = testEngine.Table("get_var").Where("`id` = ?", 1).Select("*").Get(&valuesInter) assert.NoError(t, err) assert.Equal(t, true, has) assert.Equal(t, 5, len(valuesInter)) @@ -243,7 +198,7 @@ func TestGetVar(t *testing.T) { assert.Equal(t, "1.5", fmt.Sprintf("%v", valuesInter["money"])) } - var valuesSliceString = make([]string, 5) + valuesSliceString := make([]string, 5) has, err = testEngine.Table("get_var").Get(&valuesSliceString) assert.NoError(t, err) assert.Equal(t, true, has) @@ -252,22 +207,22 @@ func TestGetVar(t *testing.T) { assert.Equal(t, "28", valuesSliceString[2]) assert.Equal(t, "1.5", valuesSliceString[3]) - var valuesSliceInter = make([]interface{}, 5) + valuesSliceInter := make([]interface{}, 5) has, err = testEngine.Table("get_var").Get(&valuesSliceInter) assert.NoError(t, err) assert.Equal(t, true, has) - v1, err := convertInt(valuesSliceInter[0]) + v1, err := convert.AsInt64(valuesSliceInter[0]) assert.NoError(t, err) assert.EqualValues(t, 1, v1) assert.Equal(t, "hi", fmt.Sprintf("%s", valuesSliceInter[1])) - v3, err := convertInt(valuesSliceInter[2]) + v3, err := convert.AsInt64(valuesSliceInter[2]) assert.NoError(t, err) assert.EqualValues(t, 28, v3) - v4, err := convertFloat(valuesSliceInter[3]) + v4, err := convert.AsFloat64(valuesSliceInter[3]) assert.NoError(t, err) assert.Equal(t, "1.5", fmt.Sprintf("%v", v4)) } @@ -280,7 +235,7 @@ func TestGetStruct(t *testing.T) { IsMan bool } - assert.NoError(t, testEngine.Sync2(new(UserinfoGet))) + assert.NoError(t, testEngine.Sync(new(UserinfoGet))) session := testEngine.NewSession() defer session.Close() @@ -289,7 +244,7 @@ func TestGetStruct(t *testing.T) { if testEngine.Dialect().URI().DBType == schemas.MSSQL { err = session.Begin() assert.NoError(t, err) - _, err = session.Exec("SET IDENTITY_INSERT userinfo_get ON") + _, err = session.Exec("SET IDENTITY_INSERT `userinfo_get` ON") assert.NoError(t, err) } cnt, err := session.Insert(&UserinfoGet{Uid: 2}) @@ -311,7 +266,7 @@ func TestGetStruct(t *testing.T) { Total int64 } - assert.NoError(t, testEngine.Sync2(&NoIdUser{})) + assert.NoError(t, testEngine.Sync(&NoIdUser{})) userCol := testEngine.GetColumnMapper().Obj2Table("User") _, err = testEngine.Where("`"+userCol+"` = ?", "xlw").Delete(&NoIdUser{}) @@ -343,6 +298,34 @@ func TestGetSlice(t *testing.T) { assert.Error(t, err) } +func TestGetMap(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + if testEngine.Dialect().Features().AutoincrMode == dialects.SequenceAutoincrMode { + t.SkipNow() + return + } + + type UserinfoMap struct { + Uid int `xorm:"pk autoincr"` + IsMan bool + } + + assertSync(t, new(UserinfoMap)) + + tableName := testEngine.Quote(testEngine.TableName("userinfo_map", true)) + _, err := testEngine.Exec(fmt.Sprintf("INSERT INTO %s (`is_man`) VALUES (NULL)", tableName)) + assert.NoError(t, err) + + valuesString := make(map[string]string) + has, err := testEngine.Table("userinfo_map").Get(&valuesString) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, 2, len(valuesString)) + assert.Equal(t, "1", valuesString["uid"]) + assert.Equal(t, "", valuesString["is_man"]) +} + func TestGetError(t *testing.T) { assert.NoError(t, PrepareEngine()) @@ -353,7 +336,7 @@ func TestGetError(t *testing.T) { assertSync(t, new(GetError)) - var info = new(GetError) + info := new(GetError) has, err := testEngine.Get(&info) assert.False(t, has) assert.Error(t, err) @@ -473,7 +456,7 @@ func TestGetActionMapping(t *testing.T) { }) assert.NoError(t, err) - var valuesSlice = make([]string, 2) + valuesSlice := make([]string, 2) has, err := testEngine.Table(new(ActionMapping)). Cols("script_id", "rollback_id"). ID("1").Get(&valuesSlice) @@ -500,9 +483,9 @@ func TestGetStructId(t *testing.T) { Id int64 } - //var id int64 + // var id int64 var maxid maxidst - sql := "select max(id) as id from " + testEngine.TableName(&TestGetStruct{}, true) + sql := "select max(`id`) as id from " + testEngine.Quote(testEngine.TableName(&TestGetStruct{}, true)) has, err := testEngine.SQL(sql).Get(&maxid) assert.NoError(t, err) assert.True(t, has) @@ -593,7 +576,7 @@ func (MyGetCustomTableImpletation) TableName() string { func TestGetCustomTableInterface(t *testing.T) { assert.NoError(t, PrepareEngine()) - assert.NoError(t, testEngine.Table(getCustomTableName).Sync2(new(MyGetCustomTableImpletation))) + assert.NoError(t, testEngine.Table(getCustomTableName).Sync(new(MyGetCustomTableImpletation))) exist, err := testEngine.IsTableExist(getCustomTableName) assert.NoError(t, err) @@ -620,73 +603,78 @@ func TestGetNullVar(t *testing.T) { assert.NoError(t, PrepareEngine()) assertSync(t, new(TestGetNullVarStruct)) - affected, err := testEngine.Exec("insert into " + testEngine.TableName(new(TestGetNullVarStruct), true) + " (name,age) values (null,null)") + if testEngine.Dialect().Features().AutoincrMode == dialects.SequenceAutoincrMode { + t.SkipNow() + return + } + + affected, err := testEngine.Exec("insert into " + testEngine.Quote(testEngine.TableName(new(TestGetNullVarStruct), true)) + " (`name`,`age`) values (null,null)") assert.NoError(t, err) a, _ := affected.RowsAffected() assert.EqualValues(t, 1, a) var name string - has, err := testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("name").Get(&name) + has, err := testEngine.Table(new(TestGetNullVarStruct)).Where("`id` = ?", 1).Cols("name").Get(&name) assert.NoError(t, err) assert.True(t, has) assert.EqualValues(t, "", name) var age int - has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("age").Get(&age) + has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("`id` = ?", 1).Cols("age").Get(&age) assert.NoError(t, err) assert.True(t, has) assert.EqualValues(t, 0, age) var age2 int8 - has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("age").Get(&age2) + has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("`id` = ?", 1).Cols("age").Get(&age2) assert.NoError(t, err) assert.True(t, has) assert.EqualValues(t, 0, age2) var age3 int16 - has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("age").Get(&age3) + has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("`id` = ?", 1).Cols("age").Get(&age3) assert.NoError(t, err) assert.True(t, has) assert.EqualValues(t, 0, age3) var age4 int32 - has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("age").Get(&age4) + has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("`id` = ?", 1).Cols("age").Get(&age4) assert.NoError(t, err) assert.True(t, has) assert.EqualValues(t, 0, age4) var age5 int64 - has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("age").Get(&age5) + has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("`id` = ?", 1).Cols("age").Get(&age5) assert.NoError(t, err) assert.True(t, has) assert.EqualValues(t, 0, age5) var age6 uint - has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("age").Get(&age6) + has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("`id` = ?", 1).Cols("age").Get(&age6) assert.NoError(t, err) assert.True(t, has) assert.EqualValues(t, 0, age6) var age7 uint8 - has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("age").Get(&age7) + has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("`id` = ?", 1).Cols("age").Get(&age7) assert.NoError(t, err) assert.True(t, has) assert.EqualValues(t, 0, age7) var age8 int16 - has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("age").Get(&age8) + has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("`id` = ?", 1).Cols("age").Get(&age8) assert.NoError(t, err) assert.True(t, has) assert.EqualValues(t, 0, age8) var age9 int32 - has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("age").Get(&age9) + has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("`id` = ?", 1).Cols("age").Get(&age9) assert.NoError(t, err) assert.True(t, has) assert.EqualValues(t, 0, age9) var age10 int64 - has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("id = ?", 1).Cols("age").Get(&age10) + has, err = testEngine.Table(new(TestGetNullVarStruct)).Where("`id` = ?", 1).Cols("age").Get(&age10) assert.NoError(t, err) assert.True(t, has) assert.EqualValues(t, 0, age10) @@ -705,7 +693,7 @@ func TestCustomTypes(t *testing.T) { assert.NoError(t, PrepareEngine()) assertSync(t, new(TestCustomizeStruct)) - var s = TestCustomizeStruct{ + s := TestCustomizeStruct{ Name: "test", Age: 32, } @@ -720,7 +708,7 @@ func TestCustomTypes(t *testing.T) { assert.EqualValues(t, "test", name) var age MyInt - has, err = testEngine.Table(new(TestCustomizeStruct)).ID(s.Id).Select("age").Get(&age) + has, err = testEngine.Table(new(TestCustomizeStruct)).ID(s.Id).Select("`age`").Get(&age) assert.NoError(t, err) assert.True(t, has) assert.EqualValues(t, 32, age) @@ -750,3 +738,278 @@ func TestGetViaMapCond(t *testing.T) { assert.NoError(t, err) assert.False(t, has) } + +func TestGetNil(t *testing.T) { + type GetNil struct { + Id int64 + } + + assert.NoError(t, PrepareEngine()) + assertSync(t, new(GetNil)) + + var gn *GetNil + has, err := testEngine.Get(gn) + assert.True(t, errors.Is(err, xorm.ErrObjectIsNil)) + assert.False(t, has) +} + +func TestGetBigFloat(t *testing.T) { + type GetBigFloat struct { + Id int64 + Money *big.Float `xorm:"numeric(22,2)"` + } + + assert.NoError(t, PrepareEngine()) + assertSync(t, new(GetBigFloat)) + + { + gf := GetBigFloat{ + Money: big.NewFloat(999999.99), + } + _, err := testEngine.Insert(&gf) + assert.NoError(t, err) + + var m big.Float + has, err := testEngine.Table("get_big_float").Cols("money").Where("`id`=?", gf.Id).Get(&m) + assert.NoError(t, err) + assert.True(t, has) + assert.True(t, m.String() == gf.Money.String(), "%v != %v", m.String(), gf.Money.String()) + // fmt.Println(m.Cmp(gf.Money)) + // assert.True(t, m.Cmp(gf.Money) == 0, "%v != %v", m.String(), gf.Money.String()) + } + + type GetBigFloat2 struct { + Id int64 + Money *big.Float `xorm:"decimal(22,2)"` + Money2 big.Float `xorm:"decimal(22,2)"` + } + + assert.NoError(t, PrepareEngine()) + assertSync(t, new(GetBigFloat2)) + + { + gf2 := GetBigFloat2{ + Money: big.NewFloat(9999999.99), + Money2: *big.NewFloat(99.99), + } + _, err := testEngine.Insert(&gf2) + assert.NoError(t, err) + + var m2 big.Float + has, err := testEngine.Table("get_big_float2").Cols("money").Where("`id`=?", gf2.Id).Get(&m2) + assert.NoError(t, err) + assert.True(t, has) + assert.True(t, m2.String() == gf2.Money.String(), "%v != %v", m2.String(), gf2.Money.String()) + // fmt.Println(m.Cmp(gf.Money)) + // assert.True(t, m.Cmp(gf.Money) == 0, "%v != %v", m.String(), gf.Money.String()) + + var gf3 GetBigFloat2 + has, err = testEngine.ID(gf2.Id).Get(&gf3) + assert.NoError(t, err) + assert.True(t, has) + assert.True(t, gf3.Money.String() == gf2.Money.String(), "%v != %v", gf3.Money.String(), gf2.Money.String()) + assert.True(t, gf3.Money2.String() == gf2.Money2.String(), "%v != %v", gf3.Money2.String(), gf2.Money2.String()) + + var gfs []GetBigFloat2 + err = testEngine.Find(&gfs) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(gfs)) + assert.True(t, gfs[0].Money.String() == gf2.Money.String(), "%v != %v", gfs[0].Money.String(), gf2.Money.String()) + assert.True(t, gfs[0].Money2.String() == gf2.Money2.String(), "%v != %v", gfs[0].Money2.String(), gf2.Money2.String()) + } +} + +func TestGetDecimal(t *testing.T) { + type GetDecimal struct { + Id int64 + Money decimal.Decimal `xorm:"decimal(22,2)"` + } + + assert.NoError(t, PrepareEngine()) + assertSync(t, new(GetDecimal)) + + { + gf := GetDecimal{ + Money: decimal.NewFromFloat(999999.99), + } + _, err := testEngine.Insert(&gf) + assert.NoError(t, err) + + var m decimal.Decimal + has, err := testEngine.Table("get_decimal").Cols("money").Where("`id`=?", gf.Id).Get(&m) + assert.NoError(t, err) + assert.True(t, has) + assert.True(t, m.String() == gf.Money.String(), "%v != %v", m.String(), gf.Money.String()) + // fmt.Println(m.Cmp(gf.Money)) + // assert.True(t, m.Cmp(gf.Money) == 0, "%v != %v", m.String(), gf.Money.String()) + } + + type GetDecimal2 struct { + Id int64 + Money *decimal.Decimal `xorm:"decimal(22,2)"` + } + + assert.NoError(t, PrepareEngine()) + assertSync(t, new(GetDecimal2)) + + { + v := decimal.NewFromFloat(999999.99) + gf := GetDecimal2{ + Money: &v, + } + _, err := testEngine.Insert(&gf) + assert.NoError(t, err) + + var m decimal.Decimal + has, err := testEngine.Table("get_decimal2").Cols("money").Where("`id`=?", gf.Id).Get(&m) + assert.NoError(t, err) + assert.True(t, has) + assert.True(t, m.String() == gf.Money.String(), "%v != %v", m.String(), gf.Money.String()) + // fmt.Println(m.Cmp(gf.Money)) + // assert.True(t, m.Cmp(gf.Money) == 0, "%v != %v", m.String(), gf.Money.String()) + } +} + +func TestGetTime(t *testing.T) { + type GetTimeStruct struct { + Id int64 + CreateTime time.Time + } + + assert.NoError(t, PrepareEngine()) + assertSync(t, new(GetTimeStruct)) + + gts := GetTimeStruct{ + CreateTime: time.Now().In(testEngine.GetTZLocation()), + } + _, err := testEngine.Insert(>s) + assert.NoError(t, err) + + var gn time.Time + has, err := testEngine.Table("get_time_struct").Cols(colMapper.Obj2Table("CreateTime")).Get(&gn) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, gts.CreateTime.Format(time.RFC3339), gn.Format(time.RFC3339)) +} + +func TestGetVars(t *testing.T) { + type GetVars struct { + Id int64 + Name string + Age int + } + + assert.NoError(t, PrepareEngine()) + assertSync(t, new(GetVars)) + + _, err := testEngine.Insert(&GetVars{ + Name: "xlw", + Age: 42, + }) + assert.NoError(t, err) + + var name string + var age int + has, err := testEngine.Table(new(GetVars)).Cols("name", "age").Get(&name, &age) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, "xlw", name) + assert.EqualValues(t, 42, age) +} + +func TestGetWithPrepare(t *testing.T) { + type GetVarsWithPrepare struct { + Id int64 + Name string + Age int + } + + assert.NoError(t, PrepareEngine()) + assertSync(t, new(GetVarsWithPrepare)) + + _, err := testEngine.Insert(&GetVarsWithPrepare{ + Name: "xlw", + Age: 42, + }) + assert.NoError(t, err) + + var v1 GetVarsWithPrepare + has, err := testEngine.Prepare().ID(1).Get(&v1) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, "xlw", v1.Name) + assert.EqualValues(t, 42, v1.Age) + + sess := testEngine.NewSession() + defer sess.Close() + + var v2 GetVarsWithPrepare + has, err = sess.Prepare().ID(1).Get(&v2) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, "xlw", v2.Name) + assert.EqualValues(t, 42, v2.Age) + + var v3 GetVarsWithPrepare + has, err = sess.Prepare().ID(1).Get(&v3) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, "xlw", v3.Name) + assert.EqualValues(t, 42, v3.Age) + + err = sess.Begin() + assert.NoError(t, err) + + cnt, err := sess.Prepare().Insert(&GetVarsWithPrepare{ + Name: "xlw2", + Age: 12, + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + cnt, err = sess.Prepare().Insert(&GetVarsWithPrepare{ + Name: "xlw3", + Age: 13, + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + err = sess.Commit() + assert.NoError(t, err) +} + +func TestGetBytesVars(t *testing.T) { + type GetBytesVars struct { + Id int64 + Bytes1 []byte + Bytes2 []byte + } + + assert.NoError(t, PrepareEngine()) + assertSync(t, new(GetBytesVars)) + + _, err := testEngine.Insert([]GetBytesVars{ + { + Bytes1: []byte("bytes1"), + Bytes2: []byte("bytes2"), + }, + { + Bytes1: []byte("bytes1-1"), + Bytes2: []byte("bytes2-2"), + }, + }) + assert.NoError(t, err) + + var gbv GetBytesVars + has, err := testEngine.Asc("id").Get(&gbv) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, []byte("bytes1"), gbv.Bytes1) + assert.EqualValues(t, []byte("bytes2"), gbv.Bytes2) + + has, err = testEngine.Desc("id").NoAutoCondition().Get(&gbv) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, []byte("bytes1-1"), gbv.Bytes1) + assert.EqualValues(t, []byte("bytes2-2"), gbv.Bytes2) +} diff --git a/integrations/session_insert_test.go b/integrations/session_insert_test.go index 47789b8a..084deb38 100644 --- a/integrations/session_insert_test.go +++ b/integrations/session_insert_test.go @@ -11,6 +11,7 @@ import ( "time" "xorm.io/xorm" + "xorm.io/xorm/schemas" "github.com/stretchr/testify/assert" ) @@ -24,7 +25,7 @@ func TestInsertOne(t *testing.T) { Created time.Time `xorm:"created"` } - assert.NoError(t, testEngine.Sync2(new(Test))) + assert.NoError(t, testEngine.Sync(new(Test))) data := Test{Msg: "hi"} _, err := testEngine.InsertOne(data) @@ -32,14 +33,13 @@ func TestInsertOne(t *testing.T) { } func TestInsertMulti(t *testing.T) { - assert.NoError(t, PrepareEngine()) type TestMulti struct { Id int64 `xorm:"int(11) pk"` Name string `xorm:"varchar(255)"` } - assert.NoError(t, testEngine.Sync2(new(TestMulti))) + assert.NoError(t, testEngine.Sync(new(TestMulti))) num, err := insertMultiDatas(1, append([]TestMulti{}, TestMulti{1, "test1"}, TestMulti{2, "test2"}, TestMulti{3, "test3"})) @@ -78,7 +78,6 @@ func insertMultiDatas(step int, datas interface{}) (num int64, err error) { } func callbackLooper(datas interface{}, step int, actionFunc func(interface{}) error) (err error) { - sliceValue := reflect.Indirect(reflect.ValueOf(datas)) if sliceValue.Kind() != reflect.Slice { return fmt.Errorf("not slice") @@ -102,7 +101,7 @@ func callbackLooper(datas interface{}, step int, actionFunc func(interface{}) er if err = actionFunc(tempInterface); err != nil { return } - processedLen = processedLen - step + processedLen -= step } return } @@ -116,7 +115,7 @@ func TestInsertOneIfPkIsPoint(t *testing.T) { Created *time.Time `xorm:"created"` } - assert.NoError(t, testEngine.Sync2(new(TestPoint))) + assert.NoError(t, testEngine.Sync(new(TestPoint))) msg := "hi" data := TestPoint{Msg: &msg} _, err := testEngine.InsertOne(&data) @@ -132,7 +131,7 @@ func TestInsertOneIfPkIsPointRename(t *testing.T) { Created *time.Time `xorm:"created"` } - assert.NoError(t, testEngine.Sync2(new(TestPoint2))) + assert.NoError(t, testEngine.Sync(new(TestPoint2))) msg := "hi" data := TestPoint2{Msg: &msg} _, err := testEngine.InsertOne(&data) @@ -170,19 +169,19 @@ func TestInsertAutoIncr(t *testing.T) { assert.Greater(t, user.Uid, int64(0)) } -type DefaultInsert struct { - Id int64 - Status int `xorm:"default -1"` - Name string - Created time.Time `xorm:"created"` - Updated time.Time `xorm:"updated"` -} - func TestInsertDefault(t *testing.T) { assert.NoError(t, PrepareEngine()) + type DefaultInsert struct { + Id int64 + Status int `xorm:"default -1"` + Name string + Created time.Time `xorm:"created"` + Updated time.Time `xorm:"updated"` + } + di := new(DefaultInsert) - err := testEngine.Sync2(di) + err := testEngine.Sync(di) assert.NoError(t, err) var di2 = DefaultInsert{Name: "test"} @@ -193,22 +192,22 @@ func TestInsertDefault(t *testing.T) { assert.NoError(t, err) assert.True(t, has) assert.EqualValues(t, -1, di.Status) - assert.EqualValues(t, di2.Updated.Unix(), di.Updated.Unix()) - assert.EqualValues(t, di2.Created.Unix(), di.Created.Unix()) -} - -type DefaultInsert2 struct { - Id int64 - Name string - Url string `xorm:"text"` - CheckTime time.Time `xorm:"not null default '2000-01-01 00:00:00' TIMESTAMP"` + assert.EqualValues(t, di2.Updated.Unix(), di.Updated.Unix(), di.Updated) + assert.EqualValues(t, di2.Created.Unix(), di.Created.Unix(), di.Created) } func TestInsertDefault2(t *testing.T) { assert.NoError(t, PrepareEngine()) + type DefaultInsert2 struct { + Id int64 + Name string + Url string `xorm:"text"` + CheckTime time.Time `xorm:"not null default '2000-01-01 00:00:00'"` + } + di := new(DefaultInsert2) - err := testEngine.Sync2(di) + err := testEngine.Sync(di) assert.NoError(t, err) var di2 = DefaultInsert2{Name: "test"} @@ -259,7 +258,7 @@ func TestInsertCreated(t *testing.T) { assert.NoError(t, PrepareEngine()) di := new(CreatedInsert) - err := testEngine.Sync2(di) + err := testEngine.Sync(di) assert.NoError(t, err) ci := &CreatedInsert{} @@ -272,7 +271,7 @@ func TestInsertCreated(t *testing.T) { assert.EqualValues(t, ci.Created.Unix(), di.Created.Unix()) di2 := new(CreatedInsert2) - err = testEngine.Sync2(di2) + err = testEngine.Sync(di2) assert.NoError(t, err) ci2 := &CreatedInsert2{} @@ -285,7 +284,7 @@ func TestInsertCreated(t *testing.T) { assert.EqualValues(t, ci2.Created, di2.Created) di3 := new(CreatedInsert3) - err = testEngine.Sync2(di3) + err = testEngine.Sync(di3) assert.NoError(t, err) ci3 := &CreatedInsert3{} @@ -298,7 +297,7 @@ func TestInsertCreated(t *testing.T) { assert.EqualValues(t, ci3.Created, di3.Created) di4 := new(CreatedInsert4) - err = testEngine.Sync2(di4) + err = testEngine.Sync(di4) assert.NoError(t, err) ci4 := &CreatedInsert4{} @@ -311,7 +310,7 @@ func TestInsertCreated(t *testing.T) { assert.EqualValues(t, ci4.Created, di4.Created) di5 := new(CreatedInsert5) - err = testEngine.Sync2(di5) + err = testEngine.Sync(di5) assert.NoError(t, err) ci5 := &CreatedInsert5{} @@ -324,7 +323,7 @@ func TestInsertCreated(t *testing.T) { assert.EqualValues(t, ci5.Created.Unix(), di5.Created.Unix()) di6 := new(CreatedInsert6) - err = testEngine.Sync2(di6) + err = testEngine.Sync(di6) assert.NoError(t, err) oldTime := time.Now().Add(-time.Hour) @@ -338,6 +337,42 @@ func TestInsertCreated(t *testing.T) { assert.EqualValues(t, ci6.Created.Unix(), di6.Created.Unix()) } +func TestInsertTime(t *testing.T) { + type InsertTimeStruct struct { + Id int64 + CreatedAt time.Time `xorm:"created"` + UpdatedAt time.Time `xorm:"updated"` + DeletedAt time.Time `xorm:"deleted"` + Stime time.Time + Etime time.Time + } + + assert.NoError(t, PrepareEngine()) + assertSync(t, new(InsertTimeStruct)) + + its := &InsertTimeStruct{ + Stime: time.Now(), + Etime: time.Now(), + } + cnt, err := testEngine.Insert(its) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var itsGet InsertTimeStruct + has, err := testEngine.ID(1).Get(&itsGet) + assert.NoError(t, err) + assert.True(t, has) + assert.False(t, itsGet.Stime.IsZero()) + assert.False(t, itsGet.Etime.IsZero()) + + var itsFind []*InsertTimeStruct + err = testEngine.Find(&itsFind) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(itsFind)) + assert.False(t, itsFind[0].Stime.IsZero()) + assert.False(t, itsFind[0].Etime.IsZero()) +} + type JSONTime time.Time func (j JSONTime) format() string { @@ -391,7 +426,7 @@ func TestCreatedJsonTime(t *testing.T) { assert.NoError(t, PrepareEngine()) di5 := new(MyJSONTime) - err := testEngine.Sync2(di5) + err := testEngine.Sync(di5) assert.NoError(t, err) ci5 := &MyJSONTime{} @@ -490,7 +525,7 @@ func TestInsertCreatedInt64(t *testing.T) { Created int64 `xorm:"created"` } - assert.NoError(t, testEngine.Sync2(new(TestCreatedInt64))) + assert.NoError(t, testEngine.Sync(new(TestCreatedInt64))) data := TestCreatedInt64{Msg: "hi"} now := time.Now() @@ -626,6 +661,11 @@ func TestAnonymousStruct(t *testing.T) { } func TestInsertMap(t *testing.T) { + if testEngine.Dialect().URI().DBType == schemas.DAMENG { + t.SkipNow() + return + } + type InsertMap struct { Id int64 Width uint32 @@ -704,7 +744,8 @@ func TestInsertMap(t *testing.T) { assert.EqualValues(t, "lunny", ims[3].Name) } -/*INSERT INTO `issue` (`repo_id`, `poster_id`, ... ,`name`, `content`, ... ,`index`) +/* +INSERT INTO `issue` (`repo_id`, `poster_id`, ... ,`name`, `content`, ... ,`index`) SELECT $1, $2, ..., $14, $15, ..., MAX(`index`) + 1 FROM `issue` WHERE `repo_id` = $1; */ func TestInsertWhere(t *testing.T) { @@ -729,7 +770,7 @@ func TestInsertWhere(t *testing.T) { } inserted, err := testEngine.SetExpr("`index`", "coalesce(MAX(`index`),0)+1"). - Where("repo_id=?", 1). + Where("`repo_id`=?", 1). Insert(&i) assert.NoError(t, err) assert.EqualValues(t, 1, inserted) @@ -742,7 +783,12 @@ func TestInsertWhere(t *testing.T) { i.Index = 1 assert.EqualValues(t, i, j) - inserted, err = testEngine.Table(new(InsertWhere)).Where("repo_id=?", 1). + if testEngine.Dialect().URI().DBType == schemas.DAMENG { + t.SkipNow() + return + } + + inserted, err = testEngine.Table(new(InsertWhere)).Where("`repo_id`=?", 1). SetExpr("`index`", "coalesce(MAX(`index`),0)+1"). Insert(map[string]interface{}{ "repo_id": 1, @@ -763,7 +809,7 @@ func TestInsertWhere(t *testing.T) { assert.EqualValues(t, "trest2", j2.Name) assert.EqualValues(t, 2, j2.Index) - inserted, err = testEngine.Table(new(InsertWhere)).Where("repo_id=?", 1). + inserted, err = testEngine.Table(new(InsertWhere)).Where("`repo_id`=?", 1). SetExpr("`index`", "coalesce(MAX(`index`),0)+1"). SetExpr("repo_id", "1"). Insert(map[string]string{ @@ -779,7 +825,7 @@ func TestInsertWhere(t *testing.T) { assert.EqualValues(t, "trest3", j3.Name) assert.EqualValues(t, 3, j3.Index) - inserted, err = testEngine.Table(new(InsertWhere)).Where("repo_id=?", 1). + inserted, err = testEngine.Table(new(InsertWhere)).Where("`repo_id`=?", 1). SetExpr("`index`", "coalesce(MAX(`index`),0)+1"). Insert(map[string]interface{}{ "repo_id": 1, @@ -795,7 +841,7 @@ func TestInsertWhere(t *testing.T) { assert.EqualValues(t, "10';delete * from insert_where; --", j4.Name) assert.EqualValues(t, 4, j4.Index) - inserted, err = testEngine.Table(new(InsertWhere)).Where("repo_id=?", 1). + inserted, err = testEngine.Table(new(InsertWhere)).Where("`repo_id`=?", 1). SetExpr("`index`", "coalesce(MAX(`index`),0)+1"). Insert(map[string]interface{}{ "repo_id": 1, @@ -848,6 +894,11 @@ func TestInsertExpr2(t *testing.T) { assert.EqualValues(t, 1, ie2.RepoId) assert.EqualValues(t, true, ie2.IsTag) + if testEngine.Dialect().URI().DBType == schemas.DAMENG { + t.SkipNow() + return + } + inserted, err = testEngine.Table(new(InsertExprsRelease)). SetExpr("is_draft", true). SetExpr("num_commits", 0). @@ -882,7 +933,7 @@ func TestMultipleInsertTableName(t *testing.T) { assert.NoError(t, PrepareEngine()) tableName := `prd_nightly_rate_16` - assert.NoError(t, testEngine.Table(tableName).Sync2(new(NightlyRate))) + assert.NoError(t, testEngine.Table(tableName).Sync(new(NightlyRate))) trans := testEngine.NewSession() defer trans.Close() @@ -918,7 +969,7 @@ func TestInsertMultiWithOmit(t *testing.T) { Omitted string `xorm:"varchar(255) 'omitted'"` } - assert.NoError(t, testEngine.Sync2(new(TestMultiOmit))) + assert.NoError(t, testEngine.Sync(new(TestMultiOmit))) l := []interface{}{ TestMultiOmit{Id: 1, Name: "1", Omitted: "1"}, @@ -963,7 +1014,7 @@ func TestInsertTwice(t *testing.T) { FieldB int } - assert.NoError(t, testEngine.Sync2(new(InsertStructA), new(InsertStructB))) + assert.NoError(t, testEngine.Sync(new(InsertStructA), new(InsertStructB))) var sliceA []InsertStructA // sliceA is empty sliceB := []InsertStructB{ @@ -986,3 +1037,168 @@ func TestInsertTwice(t *testing.T) { assert.NoError(t, ssn.Commit()) } + +func TestInsertIntSlice(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type InsertIntSlice struct { + NameIDs []int `xorm:"json notnull"` + } + + assert.NoError(t, testEngine.Sync(new(InsertIntSlice))) + + var v = InsertIntSlice{ + NameIDs: []int{1, 2}, + } + cnt, err := testEngine.Insert(&v) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var v2 InsertIntSlice + has, err := testEngine.Get(&v2) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, v, v2) + + cnt, err = testEngine.Where("1=1").Delete(new(InsertIntSlice)) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var v3 = InsertIntSlice{ + NameIDs: nil, + } + cnt, err = testEngine.Insert(&v3) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var v4 InsertIntSlice + has, err = testEngine.Get(&v4) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, v3, v4) +} + +func TestInsertDeleted(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type InsertDeletedStructNotRight struct { + ID uint64 `xorm:"'ID' pk autoincr"` + DeletedAt time.Time `xorm:"'DELETED_AT' deleted notnull"` + } + // notnull tag will be ignored + err := testEngine.Sync(new(InsertDeletedStructNotRight)) + assert.NoError(t, err) + + type InsertDeletedStruct struct { + ID uint64 `xorm:"'ID' pk autoincr"` + DeletedAt time.Time `xorm:"'DELETED_AT' deleted"` + } + + assert.NoError(t, testEngine.Sync(new(InsertDeletedStruct))) + + var v InsertDeletedStruct + _, err = testEngine.Insert(&v) + assert.NoError(t, err) + + var v2 InsertDeletedStruct + has, err := testEngine.Get(&v2) + assert.NoError(t, err) + assert.True(t, has) + + _, err = testEngine.ID(v.ID).Delete(new(InsertDeletedStruct)) + assert.NoError(t, err) + + var v3 InsertDeletedStruct + has, err = testEngine.Get(&v3) + assert.NoError(t, err) + assert.False(t, has) + + var v4 InsertDeletedStruct + has, err = testEngine.Unscoped().Get(&v4) + assert.NoError(t, err) + assert.True(t, has) +} + +func TestInsertMultipleMap(t *testing.T) { + if testEngine.Dialect().URI().DBType == schemas.DAMENG { + t.SkipNow() + return + } + + type InsertMultipleMap struct { + Id int64 + Width uint32 + Height uint32 + Name string + } + + assert.NoError(t, PrepareEngine()) + assertSync(t, new(InsertMultipleMap)) + + cnt, err := testEngine.Table(new(InsertMultipleMap)).Insert([]map[string]interface{}{ + { + "width": 20, + "height": 10, + "name": "lunny", + }, + { + "width": 30, + "height": 20, + "name": "xiaolunwen", + }, + }) + assert.NoError(t, err) + assert.EqualValues(t, 2, cnt) + + var res []InsertMultipleMap + err = testEngine.Find(&res) + assert.NoError(t, err) + assert.EqualValues(t, 2, len(res)) + assert.EqualValues(t, InsertMultipleMap{ + Id: 1, + Width: 20, + Height: 10, + Name: "lunny", + }, res[0]) + assert.EqualValues(t, InsertMultipleMap{ + Id: 2, + Width: 30, + Height: 20, + Name: "xiaolunwen", + }, res[1]) + + assert.NoError(t, PrepareEngine()) + assertSync(t, new(InsertMultipleMap)) + + cnt, err = testEngine.Table(new(InsertMultipleMap)).Insert([]map[string]string{ + { + "width": "20", + "height": "10", + "name": "lunny", + }, + { + "width": "30", + "height": "20", + "name": "xiaolunwen", + }, + }) + assert.NoError(t, err) + assert.EqualValues(t, 2, cnt) + + res = make([]InsertMultipleMap, 0, 2) + err = testEngine.Find(&res) + assert.NoError(t, err) + assert.EqualValues(t, 2, len(res)) + assert.EqualValues(t, InsertMultipleMap{ + Id: 1, + Width: 20, + Height: 10, + Name: "lunny", + }, res[0]) + assert.EqualValues(t, InsertMultipleMap{ + Id: 2, + Width: 30, + Height: 20, + Name: "xiaolunwen", + }, res[1]) +} diff --git a/integrations/session_iterate_test.go b/integrations/session_iterate_test.go index 564f457b..c5ecc593 100644 --- a/integrations/session_iterate_test.go +++ b/integrations/session_iterate_test.go @@ -18,7 +18,7 @@ func TestIterate(t *testing.T) { IsMan bool } - assert.NoError(t, testEngine.Sync2(new(UserIterate))) + assert.NoError(t, testEngine.Sync(new(UserIterate))) cnt, err := testEngine.Insert(&UserIterate{ IsMan: true, @@ -26,16 +26,27 @@ func TestIterate(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 1, cnt) + cnt, err = testEngine.Insert(&UserIterate{ + IsMan: false, + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + cnt = 0 err = testEngine.Iterate(new(UserIterate), func(i int, bean interface{}) error { user := bean.(*UserIterate) - assert.EqualValues(t, 1, user.Id) - assert.EqualValues(t, true, user.IsMan) + if cnt == 0 { + assert.EqualValues(t, 1, user.Id) + assert.EqualValues(t, true, user.IsMan) + } else { + assert.EqualValues(t, 2, user.Id) + assert.EqualValues(t, false, user.IsMan) + } cnt++ return nil }) assert.NoError(t, err) - assert.EqualValues(t, 1, cnt) + assert.EqualValues(t, 2, cnt) } func TestBufferIterate(t *testing.T) { @@ -46,7 +57,7 @@ func TestBufferIterate(t *testing.T) { IsMan bool } - assert.NoError(t, testEngine.Sync2(new(UserBufferIterate))) + assert.NoError(t, testEngine.Sync(new(UserBufferIterate))) var size = 20 for i := 0; i < size; i++ { @@ -91,7 +102,7 @@ func TestBufferIterate(t *testing.T) { assert.EqualValues(t, 7, cnt) cnt = 0 - err = testEngine.Where("id <= 10").BufferSize(2).Iterate(new(UserBufferIterate), func(i int, bean interface{}) error { + err = testEngine.Where("`id` <= 10").BufferSize(2).Iterate(new(UserBufferIterate), func(i int, bean interface{}) error { user := bean.(*UserBufferIterate) assert.EqualValues(t, cnt+1, user.Id) assert.EqualValues(t, true, user.IsMan) diff --git a/integrations/session_pk_test.go b/integrations/session_pk_test.go index d5f23491..0244937f 100644 --- a/integrations/session_pk_test.go +++ b/integrations/session_pk_test.go @@ -121,7 +121,7 @@ func TestInt16Id(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 1, len(beans)) - beans2 := make(map[int16]Int16Id, 0) + beans2 := make(map[int16]Int16Id) err = testEngine.Find(&beans2) assert.NoError(t, err) assert.EqualValues(t, 1, len(beans2)) @@ -154,7 +154,7 @@ func TestInt32Id(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 1, len(beans)) - beans2 := make(map[int32]Int32Id, 0) + beans2 := make(map[int32]Int32Id) err = testEngine.Find(&beans2) assert.NoError(t, err) assert.EqualValues(t, 1, len(beans2)) @@ -173,6 +173,16 @@ func TestUintId(t *testing.T) { err = testEngine.CreateTables(&UintId{}) assert.NoError(t, err) + tables, err := testEngine.DBMetas() + assert.NoError(t, err) + + assert.EqualValues(t, 1, len(tables)) + cols := tables[0].PKColumns() + assert.EqualValues(t, 1, len(cols)) + if testEngine.Dialect().URI().DBType == schemas.MYSQL { + assert.EqualValues(t, "UNSIGNED INT", cols[0].SQLType.Name) + } + cnt, err := testEngine.Insert(&UintId{Name: "test"}) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) @@ -195,7 +205,7 @@ func TestUintId(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 3, len(beans)) - beans2 := make(map[uint]UintId, 0) + beans2 := make(map[uint]UintId) err = testEngine.Find(&beans2) assert.NoError(t, err) assert.EqualValues(t, 3, len(beans2)) @@ -229,7 +239,7 @@ func TestUint16Id(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 1, len(beans)) - beans2 := make(map[uint16]Uint16Id, 0) + beans2 := make(map[uint16]Uint16Id) err = testEngine.Find(&beans2) assert.NoError(t, err) assert.EqualValues(t, 1, len(beans2)) @@ -263,7 +273,7 @@ func TestUint32Id(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 1, len(beans)) - beans2 := make(map[uint32]Uint32Id, 0) + beans2 := make(map[uint32]Uint32Id) err = testEngine.Find(&beans2) assert.NoError(t, err) assert.EqualValues(t, 1, len(beans2)) @@ -300,7 +310,7 @@ func TestUint64Id(t *testing.T) { assert.EqualValues(t, 1, len(beans)) assert.EqualValues(t, *bean, beans[0]) - beans2 := make(map[uint64]Uint64Id, 0) + beans2 := make(map[uint64]Uint64Id) err = testEngine.Find(&beans2) assert.NoError(t, err) assert.EqualValues(t, 1, len(beans2)) @@ -523,7 +533,7 @@ func TestMyIntId(t *testing.T) { assert.EqualValues(t, 1, len(beans)) assert.EqualValues(t, *bean, beans[0]) - beans2 := make(map[ID]MyIntPK, 0) + beans2 := make(map[ID]MyIntPK) err = testEngine.Find(&beans2) assert.NoError(t, err) assert.EqualValues(t, 1, len(beans2)) @@ -560,7 +570,7 @@ func TestMyStringId(t *testing.T) { assert.EqualValues(t, 1, len(beans)) assert.EqualValues(t, *bean, beans[0]) - beans2 := make(map[StrID]MyStringPK, 0) + beans2 := make(map[StrID]MyStringPK) err = testEngine.Find(&beans2) assert.NoError(t, err) assert.EqualValues(t, 1, len(beans2)) @@ -597,7 +607,7 @@ func TestCompositePK(t *testing.T) { assert.NoError(t, err) assertSync(t, new(TaskSolution)) - assert.NoError(t, testEngine.Sync2(new(TaskSolution))) + assert.NoError(t, testEngine.Sync(new(TaskSolution))) tables2, err := testEngine.DBMetas() assert.NoError(t, err) diff --git a/integrations/session_query_test.go b/integrations/session_query_test.go index 30f2e6ab..b72f7ef2 100644 --- a/integrations/session_query_test.go +++ b/integrations/session_query_test.go @@ -5,7 +5,6 @@ package integrations import ( - "fmt" "strconv" "testing" "time" @@ -27,7 +26,7 @@ func TestQueryString(t *testing.T) { Created time.Time `xorm:"created"` } - assert.NoError(t, testEngine.Sync2(new(GetVar2))) + assert.NoError(t, testEngine.Sync(new(GetVar2))) var data = GetVar2{ Msg: "hi", @@ -37,7 +36,7 @@ func TestQueryString(t *testing.T) { _, err := testEngine.InsertOne(data) assert.NoError(t, err) - records, err := testEngine.QueryString("select * from " + testEngine.TableName("get_var2", true)) + records, err := testEngine.QueryString("select * from " + testEngine.Quote(testEngine.TableName("get_var2", true))) assert.NoError(t, err) assert.Equal(t, 1, len(records)) assert.Equal(t, 5, len(records[0])) @@ -52,10 +51,10 @@ func TestQueryString2(t *testing.T) { type GetVar3 struct { Id int64 `xorm:"autoincr pk"` - Msg bool `xorm:"bit"` + Msg bool } - assert.NoError(t, testEngine.Sync2(new(GetVar3))) + assert.NoError(t, testEngine.Sync(new(GetVar3))) var data = GetVar3{ Msg: false, @@ -63,7 +62,7 @@ func TestQueryString2(t *testing.T) { _, err := testEngine.Insert(data) assert.NoError(t, err) - records, err := testEngine.QueryString("select * from " + testEngine.TableName("get_var3", true)) + records, err := testEngine.QueryString("select * from " + testEngine.Quote(testEngine.TableName("get_var3", true))) assert.NoError(t, err) assert.Equal(t, 1, len(records)) assert.Equal(t, 2, len(records[0])) @@ -71,40 +70,14 @@ func TestQueryString2(t *testing.T) { assert.True(t, "0" == records[0]["msg"] || "false" == records[0]["msg"]) } -func toString(i interface{}) string { - switch i.(type) { - case []byte: - return string(i.([]byte)) - case string: - return i.(string) +func toBool(i interface{}) bool { + switch t := i.(type) { + case int32: + return t > 0 + case bool: + return t } - return fmt.Sprintf("%v", i) -} - -func toInt64(i interface{}) int64 { - switch i.(type) { - case []byte: - n, _ := strconv.ParseInt(string(i.([]byte)), 10, 64) - return n - case int: - return int64(i.(int)) - case int64: - return i.(int64) - } - return 0 -} - -func toFloat64(i interface{}) float64 { - switch i.(type) { - case []byte: - n, _ := strconv.ParseFloat(string(i.([]byte)), 64) - return n - case float64: - return i.(float64) - case float32: - return float64(i.(float32)) - } - return 0 + return false } func TestQueryInterface(t *testing.T) { @@ -118,7 +91,7 @@ func TestQueryInterface(t *testing.T) { Created time.Time `xorm:"created"` } - assert.NoError(t, testEngine.Sync2(new(GetVarInterface))) + assert.NoError(t, testEngine.Sync(new(GetVarInterface))) var data = GetVarInterface{ Msg: "hi", @@ -128,14 +101,14 @@ func TestQueryInterface(t *testing.T) { _, err := testEngine.InsertOne(data) assert.NoError(t, err) - records, err := testEngine.QueryInterface("select * from " + testEngine.TableName("get_var_interface", true)) + records, err := testEngine.QueryInterface("select * from " + testEngine.Quote(testEngine.TableName("get_var_interface", true))) assert.NoError(t, err) assert.Equal(t, 1, len(records)) assert.Equal(t, 5, len(records[0])) - assert.EqualValues(t, 1, toInt64(records[0]["id"])) - assert.Equal(t, "hi", toString(records[0]["msg"])) - assert.EqualValues(t, 28, toInt64(records[0]["age"])) - assert.EqualValues(t, 1.5, toFloat64(records[0]["money"])) + assert.EqualValues(t, int64(1), records[0]["id"]) + assert.Equal(t, "hi", records[0]["msg"]) + assert.EqualValues(t, 28, records[0]["age"]) + assert.EqualValues(t, 1.5, records[0]["money"]) } func TestQueryNoParams(t *testing.T) { @@ -151,7 +124,7 @@ func TestQueryNoParams(t *testing.T) { testEngine.ShowSQL(true) - assert.NoError(t, testEngine.Sync2(new(QueryNoParams))) + assert.NoError(t, testEngine.Sync(new(QueryNoParams))) var q = QueryNoParams{ Msg: "message", @@ -182,7 +155,7 @@ func TestQueryNoParams(t *testing.T) { assert.NoError(t, err) assertResult(t, results) - results, err = testEngine.SQL("select * from " + testEngine.TableName("query_no_params", true)).Query() + results, err = testEngine.SQL("select * from " + testEngine.Quote(testEngine.TableName("query_no_params", true))).Query() assert.NoError(t, err) assertResult(t, results) } @@ -192,10 +165,10 @@ func TestQueryStringNoParam(t *testing.T) { type GetVar4 struct { Id int64 `xorm:"autoincr pk"` - Msg bool `xorm:"bit"` + Msg bool } - assert.NoError(t, testEngine.Sync2(new(GetVar4))) + assert.NoError(t, testEngine.Sync(new(GetVar4))) var data = GetVar4{ Msg: false, @@ -213,7 +186,7 @@ func TestQueryStringNoParam(t *testing.T) { assert.EqualValues(t, "0", records[0]["msg"]) } - records, err = testEngine.Table("get_var4").Where(builder.Eq{"id": 1}).QueryString() + records, err = testEngine.Table("get_var4").Where(builder.Eq{"`id`": 1}).QueryString() assert.NoError(t, err) assert.EqualValues(t, 1, len(records)) assert.EqualValues(t, "1", records[0]["id"]) @@ -229,10 +202,10 @@ func TestQuerySliceStringNoParam(t *testing.T) { type GetVar6 struct { Id int64 `xorm:"autoincr pk"` - Msg bool `xorm:"bit"` + Msg bool } - assert.NoError(t, testEngine.Sync2(new(GetVar6))) + assert.NoError(t, testEngine.Sync(new(GetVar6))) var data = GetVar6{ Msg: false, @@ -250,7 +223,7 @@ func TestQuerySliceStringNoParam(t *testing.T) { assert.EqualValues(t, "0", records[0][1]) } - records, err = testEngine.Table("get_var6").Where(builder.Eq{"id": 1}).QuerySliceString() + records, err = testEngine.Table("get_var6").Where(builder.Eq{"`id`": 1}).QuerySliceString() assert.NoError(t, err) assert.EqualValues(t, 1, len(records)) assert.EqualValues(t, "1", records[0][0]) @@ -266,10 +239,10 @@ func TestQueryInterfaceNoParam(t *testing.T) { type GetVar5 struct { Id int64 `xorm:"autoincr pk"` - Msg bool `xorm:"bit"` + Msg bool } - assert.NoError(t, testEngine.Sync2(new(GetVar5))) + assert.NoError(t, testEngine.Sync(new(GetVar5))) var data = GetVar5{ Msg: false, @@ -280,14 +253,14 @@ func TestQueryInterfaceNoParam(t *testing.T) { records, err := testEngine.Table("get_var5").Limit(1).QueryInterface() assert.NoError(t, err) assert.EqualValues(t, 1, len(records)) - assert.EqualValues(t, 1, toInt64(records[0]["id"])) - assert.EqualValues(t, 0, toInt64(records[0]["msg"])) + assert.EqualValues(t, 1, records[0]["id"]) + assert.False(t, toBool(records[0]["msg"])) - records, err = testEngine.Table("get_var5").Where(builder.Eq{"id": 1}).QueryInterface() + records, err = testEngine.Table("get_var5").Where(builder.Eq{"`id`": 1}).QueryInterface() assert.NoError(t, err) assert.EqualValues(t, 1, len(records)) - assert.EqualValues(t, 1, toInt64(records[0]["id"])) - assert.EqualValues(t, 0, toInt64(records[0]["msg"])) + assert.EqualValues(t, 1, records[0]["id"]) + assert.False(t, toBool(records[0]["msg"])) } func TestQueryWithBuilder(t *testing.T) { @@ -303,7 +276,7 @@ func TestQueryWithBuilder(t *testing.T) { testEngine.ShowSQL(true) - assert.NoError(t, testEngine.Sync2(new(QueryWithBuilder))) + assert.NoError(t, testEngine.Sync(new(QueryWithBuilder))) var q = QueryWithBuilder{ Msg: "message", @@ -330,7 +303,7 @@ func TestQueryWithBuilder(t *testing.T) { assert.EqualValues(t, 3000, money) } - results, err := testEngine.Query(builder.Select("*").From(testEngine.TableName("query_with_builder", true))) + results, err := testEngine.Query(builder.Select("*").From(testEngine.Quote(testEngine.TableName("query_with_builder", true)))) assert.NoError(t, err) assertResult(t, results) } @@ -352,7 +325,7 @@ func TestJoinWithSubQuery(t *testing.T) { testEngine.ShowSQL(true) - assert.NoError(t, testEngine.Sync2(new(JoinWithSubQuery1), new(JoinWithSubQueryDepart))) + assert.NoError(t, testEngine.Sync(new(JoinWithSubQuery1), new(JoinWithSubQueryDepart))) var depart = JoinWithSubQueryDepart{ Name: "depart1", @@ -373,16 +346,38 @@ func TestJoinWithSubQuery(t *testing.T) { tbName := testEngine.Quote(testEngine.TableName("join_with_sub_query_depart", true)) var querys []JoinWithSubQuery1 - err = testEngine.Join("INNER", builder.Select("id").From(tbName), - "join_with_sub_query_depart.id = join_with_sub_query1.depart_id").Find(&querys) + err = testEngine.Join("INNER", builder.Select("`id`").From(tbName), + "`join_with_sub_query_depart`.`id` = `join_with_sub_query1`.`depart_id`").Find(&querys) assert.NoError(t, err) assert.EqualValues(t, 1, len(querys)) assert.EqualValues(t, q, querys[0]) querys = make([]JoinWithSubQuery1, 0, 1) - err = testEngine.Join("INNER", "(SELECT id FROM "+tbName+") join_with_sub_query_depart", "join_with_sub_query_depart.id = join_with_sub_query1.depart_id"). + err = testEngine.Join("INNER", "(SELECT `id` FROM "+tbName+") `a`", "`a`.`id` = `join_with_sub_query1`.`depart_id`"). Find(&querys) assert.NoError(t, err) assert.EqualValues(t, 1, len(querys)) assert.EqualValues(t, q, querys[0]) } + +func TestQueryStringWithLimit(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + if testEngine.Dialect().URI().DBType == schemas.MSSQL { + t.SkipNow() + return + } + + type QueryWithLimit struct { + Id int64 `xorm:"autoincr pk"` + Msg string `xorm:"varchar(255)"` + DepartId int64 + Money float32 + } + + assert.NoError(t, testEngine.Sync(new(QueryWithLimit))) + + data, err := testEngine.Table("query_with_limit").Limit(20, 20).QueryString() + assert.NoError(t, err) + assert.EqualValues(t, 0, len(data)) +} diff --git a/integrations/session_raw_test.go b/integrations/session_raw_test.go index 8b9d6766..5fa48d6e 100644 --- a/integrations/session_raw_test.go +++ b/integrations/session_raw_test.go @@ -7,6 +7,7 @@ package integrations import ( "strconv" "testing" + "time" "github.com/stretchr/testify/assert" ) @@ -19,15 +20,15 @@ func TestExecAndQuery(t *testing.T) { Name string } - assert.NoError(t, testEngine.Sync2(new(UserinfoQuery))) + assert.NoError(t, testEngine.Sync(new(UserinfoQuery))) - res, err := testEngine.Exec("INSERT INTO "+testEngine.TableName("`userinfo_query`", true)+" (uid, name) VALUES (?, ?)", 1, "user") + res, err := testEngine.Exec("INSERT INTO "+testEngine.TableName("`userinfo_query`", true)+" (`uid`, `name`) VALUES (?, ?)", 1, "user") assert.NoError(t, err) cnt, err := res.RowsAffected() assert.NoError(t, err) assert.EqualValues(t, 1, cnt) - results, err := testEngine.Query("select * from " + testEngine.TableName("userinfo_query", true)) + results, err := testEngine.Query("select * from " + testEngine.Quote(testEngine.TableName("userinfo_query", true))) assert.NoError(t, err) assert.EqualValues(t, 1, len(results)) id, err := strconv.Atoi(string(results[0]["uid"])) @@ -35,3 +36,32 @@ func TestExecAndQuery(t *testing.T) { assert.EqualValues(t, 1, id) assert.Equal(t, "user", string(results[0]["name"])) } + +func TestExecTime(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type UserinfoExecTime struct { + Uid int + Name string + Created time.Time + } + + assert.NoError(t, testEngine.Sync(new(UserinfoExecTime))) + now := time.Now() + res, err := testEngine.Exec("INSERT INTO "+testEngine.TableName("`userinfo_exec_time`", true)+" (`uid`, `name`, `created`) VALUES (?, ?, ?)", 1, "user", now) + assert.NoError(t, err) + cnt, err := res.RowsAffected() + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + results, err := testEngine.QueryString("SELECT * FROM " + testEngine.Quote(testEngine.TableName("userinfo_exec_time", true))) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(results)) + assert.EqualValues(t, now.In(testEngine.GetTZLocation()).Format("2006-01-02 15:04:05"), results[0]["created"]) + + var uet UserinfoExecTime + has, err := testEngine.Where("`uid`=?", 1).Get(&uet) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, now.In(testEngine.GetTZLocation()).Format("2006-01-02 15:04:05"), uet.Created.Format("2006-01-02 15:04:05")) +} diff --git a/integrations/session_schema_test.go b/integrations/session_schema_test.go index 3f8f810b..3212d027 100644 --- a/integrations/session_schema_test.go +++ b/integrations/session_schema_test.go @@ -6,10 +6,12 @@ package integrations import ( "fmt" + "strings" "testing" "time" "github.com/stretchr/testify/assert" + "xorm.io/xorm/schemas" ) func TestStoreEngine(t *testing.T) { @@ -38,6 +40,82 @@ func TestCreateTable(t *testing.T) { assert.NoError(t, testEngine.Table("user_user").CreateTable(&UserinfoCreateTable{})) } +func TestCreateTable2(t *testing.T) { + type BaseModelLogicalDel struct { + Id string `xorm:"varchar(46) pk"` + CreatedAt time.Time `xorm:"created"` + UpdatedAt time.Time `xorm:"updated"` + DeletedAt *time.Time `xorm:"deleted"` + } + type TestPerson struct { + BaseModelLogicalDel `xorm:"extends"` + UserId string `xorm:"varchar(46) notnull"` + PersonId string `xorm:"varchar(46) notnull"` + Star bool + SortNo int + DispName string `xorm:"varchar(100)"` + FirstName string + LastName string + FirstNameKana string + LastNameKana string + BirthYear *int + BirthMonth *int + BirthDay *int + ImageId string `xorm:"varchar(46)"` + ImageDefaultId string `xorm:"varchar(46)"` + UserText string `xorm:"varchar(2000)"` + GenderId *int + At1 string `xorm:"varchar(10)"` + At1Rate int + At2 string `xorm:"varchar(10)"` + At2Rate int + At3 string `xorm:"varchar(10)"` + At3Rate int + At4 string `xorm:"varchar(10)"` + At4Rate int + At5 string `xorm:"varchar(10)"` + At5Rate int + At6 string `xorm:"varchar(10)"` + At6Rate int + } + + assert.NoError(t, PrepareEngine()) + + tb1, err := testEngine.TableInfo(TestPerson{}) + assert.NoError(t, err) + tb2, err := testEngine.TableInfo(new(TestPerson)) + assert.NoError(t, err) + cols1, cols2 := tb1.ColumnsSeq(), tb2.ColumnsSeq() + assert.EqualValues(t, len(cols1), len(cols2)) + for i, col := range cols1 { + assert.EqualValues(t, col, cols2[i]) + } + + result, err := testEngine.IsTableExist(new(TestPerson)) + assert.NoError(t, err) + if result { + assert.NoError(t, testEngine.DropTables(new(TestPerson))) + } + + assert.NoError(t, testEngine.CreateTables(new(TestPerson))) + tables1, err := testEngine.DBMetas() + assert.NoError(t, err) + assert.Len(t, tables1, 1) + assert.EqualValues(t, len(cols1), len(tables1[0].Columns())) + + result, err = testEngine.IsTableExist(new(TestPerson)) + assert.NoError(t, err) + if result { + assert.NoError(t, testEngine.DropTables(new(TestPerson))) + } + + assert.NoError(t, testEngine.CreateTables(TestPerson{})) + tables2, err := testEngine.DBMetas() + assert.NoError(t, err) + assert.Len(t, tables2, 1) + assert.EqualValues(t, len(cols1), len(tables2[0].Columns())) +} + func TestCreateMultiTables(t *testing.T) { assert.NoError(t, PrepareEngine()) @@ -96,7 +174,7 @@ func (s *SyncTable3) TableName() string { func TestSyncTable(t *testing.T) { assert.NoError(t, PrepareEngine()) - assert.NoError(t, testEngine.Sync2(new(SyncTable1))) + assert.NoError(t, testEngine.Sync(new(SyncTable1))) tables, err := testEngine.DBMetas() assert.NoError(t, err) @@ -106,7 +184,7 @@ func TestSyncTable(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, testEngine.Dialect().SQLType(tables[0].GetColumn("name")), testEngine.Dialect().SQLType(tableInfo.GetColumn("name"))) - assert.NoError(t, testEngine.Sync2(new(SyncTable2))) + assert.NoError(t, testEngine.Sync(new(SyncTable2))) tables, err = testEngine.DBMetas() assert.NoError(t, err) @@ -116,7 +194,7 @@ func TestSyncTable(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, testEngine.Dialect().SQLType(tables[0].GetColumn("name")), testEngine.Dialect().SQLType(tableInfo.GetColumn("name"))) - assert.NoError(t, testEngine.Sync2(new(SyncTable3))) + assert.NoError(t, testEngine.Sync(new(SyncTable3))) tables, err = testEngine.DBMetas() assert.NoError(t, err) @@ -130,7 +208,7 @@ func TestSyncTable(t *testing.T) { func TestSyncTable2(t *testing.T) { assert.NoError(t, PrepareEngine()) - assert.NoError(t, testEngine.Table("sync_tablex").Sync2(new(SyncTable1))) + assert.NoError(t, testEngine.Table("sync_tablex").Sync(new(SyncTable1))) tables, err := testEngine.DBMetas() assert.NoError(t, err) @@ -143,7 +221,7 @@ func TestSyncTable2(t *testing.T) { NewCol string } - assert.NoError(t, testEngine.Table("sync_tablex").Sync2(new(SyncTable4))) + assert.NoError(t, testEngine.Table("sync_tablex").Sync(new(SyncTable4))) tables, err = testEngine.DBMetas() assert.NoError(t, err) assert.EqualValues(t, 1, len(tables)) @@ -164,14 +242,16 @@ func TestSyncTable3(t *testing.T) { assert.NoError(t, PrepareEngine()) - assert.NoError(t, testEngine.Sync2(new(SyncTable5))) + assert.NoError(t, testEngine.Sync(new(SyncTable5))) tables, err := testEngine.DBMetas() assert.NoError(t, err) tableInfo, err := testEngine.TableInfo(new(SyncTable5)) assert.NoError(t, err) assert.EqualValues(t, testEngine.Dialect().SQLType(tableInfo.GetColumn("name")), testEngine.Dialect().SQLType(tables[0].GetColumn("name"))) - assert.EqualValues(t, testEngine.Dialect().SQLType(tableInfo.GetColumn("text")), testEngine.Dialect().SQLType(tables[0].GetColumn("text"))) + /* Engine.DBMetas() returns the size of the column from the database but Engine.TableInfo() might not be able to guess the column size. + For example using MySQL/MariaDB: when utf-8 charset is used, "`xorm:"TEXT(21846)`" creates a MEDIUMTEXT column not a TEXT column. */ + assert.True(t, testEngine.Dialect().SQLType(tables[0].GetColumn("text")) == testEngine.Dialect().SQLType(tableInfo.GetColumn("text")) || strings.HasPrefix(testEngine.Dialect().SQLType(tables[0].GetColumn("text")), testEngine.Dialect().SQLType(tableInfo.GetColumn("text"))+"(")) assert.EqualValues(t, testEngine.Dialect().SQLType(tableInfo.GetColumn("char")), testEngine.Dialect().SQLType(tables[0].GetColumn("char"))) assert.EqualValues(t, testEngine.Dialect().SQLType(tableInfo.GetColumn("ten_char")), testEngine.Dialect().SQLType(tables[0].GetColumn("ten_char"))) assert.EqualValues(t, testEngine.Dialect().SQLType(tableInfo.GetColumn("ten_var_char")), testEngine.Dialect().SQLType(tables[0].GetColumn("ten_var_char"))) @@ -195,7 +275,7 @@ func TestSyncTable3(t *testing.T) { }() assert.NoError(t, PrepareEngine()) - assert.NoError(t, testEngine.Sync2(new(SyncTable5))) + assert.NoError(t, testEngine.Sync(new(SyncTable5))) tables, err := testEngine.DBMetas() assert.NoError(t, err) @@ -209,6 +289,19 @@ func TestSyncTable3(t *testing.T) { } } +func TestSyncTable4(t *testing.T) { + type SyncTable6 struct { + Id int64 + Qty float64 `xorm:"numeric(36,2)"` + } + + assert.NoError(t, PrepareEngine()) + + assert.NoError(t, testEngine.Sync(new(SyncTable6))) + + assert.NoError(t, testEngine.Sync(new(SyncTable6))) +} + func TestIsTableExist(t *testing.T) { assert.NoError(t, PrepareEngine()) @@ -238,14 +331,14 @@ func TestIsTableEmpty(t *testing.T) { Created time.Time `xorm:"created"` ILike int PageView int - From_url string + From_url string // nolint Pre_url string `xorm:"unique"` //pre view image's url Uid int64 } assert.NoError(t, testEngine.DropTables(&PictureEmpty{}, &NumericEmpty{})) - assert.NoError(t, testEngine.Sync2(new(PictureEmpty), new(NumericEmpty))) + assert.NoError(t, testEngine.Sync(new(PictureEmpty), new(NumericEmpty))) isEmpty, err := testEngine.IsTableEmpty(&PictureEmpty{}) assert.NoError(t, err) @@ -303,7 +396,7 @@ func TestIndexAndUnique(t *testing.T) { func TestMetaInfo(t *testing.T) { assert.NoError(t, PrepareEngine()) - assert.NoError(t, testEngine.Sync2(new(CustomTableName), new(IndexOrUnique))) + assert.NoError(t, testEngine.Sync(new(CustomTableName), new(IndexOrUnique))) tables, err := testEngine.DBMetas() assert.NoError(t, err) @@ -333,8 +426,8 @@ func TestSync2_1(t *testing.T) { assert.NoError(t, PrepareEngine()) assert.NoError(t, testEngine.DropTables("wx_test")) - assert.NoError(t, testEngine.Sync2(new(WxTest))) - assert.NoError(t, testEngine.Sync2(new(WxTest))) + assert.NoError(t, testEngine.Sync(new(WxTest))) + assert.NoError(t, testEngine.Sync(new(WxTest))) } func TestUnique_1(t *testing.T) { @@ -350,7 +443,7 @@ func TestUnique_1(t *testing.T) { assert.NoError(t, PrepareEngine()) assert.NoError(t, testEngine.DropTables("user_unique")) - assert.NoError(t, testEngine.Sync2(new(UserUnique))) + assert.NoError(t, testEngine.Sync(new(UserUnique))) assert.NoError(t, testEngine.DropTables("user_unique")) assert.NoError(t, testEngine.CreateTables(new(UserUnique))) @@ -369,7 +462,7 @@ func TestSync2_2(t *testing.T) { for i := 0; i < 10; i++ { tableName := fmt.Sprintf("test_sync2_index_%d", i) tableNames[tableName] = true - assert.NoError(t, testEngine.Table(tableName).Sync2(new(TestSync2Index))) + assert.NoError(t, testEngine.Table(tableName).Sync(new(TestSync2Index))) exist, err := testEngine.IsTableExist(tableName) assert.NoError(t, err) @@ -394,5 +487,52 @@ func TestSync2_Default(t *testing.T) { assert.NoError(t, PrepareEngine()) assertSync(t, new(TestSync2Default)) - assert.NoError(t, testEngine.Sync2(new(TestSync2Default))) + assert.NoError(t, testEngine.Sync(new(TestSync2Default))) +} + +func TestSync2_Default2(t *testing.T) { + type TestSync2Default2 struct { + Id int64 + UserId int64 `xorm:"default(1)"` + IsMember bool `xorm:"default(true)"` + Name string `xorm:"default('')"` + } + + assert.NoError(t, PrepareEngine()) + assertSync(t, new(TestSync2Default2)) + assert.NoError(t, testEngine.Sync(new(TestSync2Default2))) + assert.NoError(t, testEngine.Sync(new(TestSync2Default2))) + assert.NoError(t, testEngine.Sync(new(TestSync2Default2))) + + assert.NoError(t, testEngine.Sync(new(TestSync2Default2))) + assert.NoError(t, testEngine.Sync(new(TestSync2Default2))) + assert.NoError(t, testEngine.Sync(new(TestSync2Default2))) +} + +func TestModifyColum(t *testing.T) { + // Since SQLITE don't support modify column SQL, currrently just ignore + if testEngine.Dialect().URI().DBType == schemas.SQLITE { + return + } + type TestModifyColumn struct { + Id int64 + UserId int64 `xorm:"default(1)"` + IsMember bool `xorm:"default(true)"` + Name string `xorm:"char(10)"` + } + + assert.NoError(t, PrepareEngine()) + assertSync(t, new(TestModifyColumn)) + + alterSQL := testEngine.Dialect().ModifyColumnSQL("test_modify_column", &schemas.Column{ + Name: "name", + SQLType: schemas.SQLType{ + Name: "VARCHAR", + }, + Length: 16, + Nullable: false, + DefaultIsEmpty: true, + }) + _, err := testEngine.Exec(alterSQL) + assert.NoError(t, err) } diff --git a/integrations/session_stats_test.go b/integrations/session_sum_test.go similarity index 50% rename from integrations/session_stats_test.go rename to integrations/session_sum_test.go index 47a64076..e000233b 100644 --- a/integrations/session_stats_test.go +++ b/integrations/session_sum_test.go @@ -10,7 +10,6 @@ import ( "testing" "github.com/stretchr/testify/assert" - "xorm.io/builder" ) func isFloatEq(i, j float64, precision int) bool { @@ -24,7 +23,7 @@ func TestSum(t *testing.T) { } assert.NoError(t, PrepareEngine()) - assert.NoError(t, testEngine.Sync2(new(SumStruct))) + assert.NoError(t, testEngine.Sync(new(SumStruct))) var ( cases = []SumStruct{ @@ -83,7 +82,7 @@ func (s SumStructWithTableName) TableName() string { func TestSumWithTableName(t *testing.T) { assert.NoError(t, PrepareEngine()) - assert.NoError(t, testEngine.Sync2(new(SumStructWithTableName))) + assert.NoError(t, testEngine.Sync(new(SumStructWithTableName))) var ( cases = []SumStructWithTableName{ @@ -147,7 +146,7 @@ func TestSumCustomColumn(t *testing.T) { } ) - assert.NoError(t, testEngine.Sync2(new(SumStruct2))) + assert.NoError(t, testEngine.Sync(new(SumStruct2))) cnt, err := testEngine.Insert(cases) assert.NoError(t, err) @@ -158,143 +157,3 @@ func TestSumCustomColumn(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 3, int(sumInt)) } - -func TestCount(t *testing.T) { - assert.NoError(t, PrepareEngine()) - - type UserinfoCount struct { - Departname string - } - assert.NoError(t, testEngine.Sync2(new(UserinfoCount))) - - colName := testEngine.GetColumnMapper().Obj2Table("Departname") - var cond builder.Cond = builder.Eq{ - "`" + colName + "`": "dev", - } - - total, err := testEngine.Where(cond).Count(new(UserinfoCount)) - assert.NoError(t, err) - assert.EqualValues(t, 0, total) - - cnt, err := testEngine.Insert(&UserinfoCount{ - Departname: "dev", - }) - assert.NoError(t, err) - assert.EqualValues(t, 1, cnt) - - total, err = testEngine.Where(cond).Count(new(UserinfoCount)) - assert.NoError(t, err) - assert.EqualValues(t, 1, total) - - total, err = testEngine.Where(cond).Table("userinfo_count").Count() - assert.NoError(t, err) - assert.EqualValues(t, 1, total) - - total, err = testEngine.Table("userinfo_count").Count() - assert.NoError(t, err) - assert.EqualValues(t, 1, total) -} - -func TestSQLCount(t *testing.T) { - assert.NoError(t, PrepareEngine()) - - type UserinfoCount2 struct { - Id int64 - Departname string - } - - type UserinfoBooks struct { - Id int64 - Pid int64 - IsOpen bool - } - - assertSync(t, new(UserinfoCount2), new(UserinfoBooks)) - - total, err := testEngine.SQL("SELECT count(id) FROM " + testEngine.TableName("userinfo_count2", true)). - Count() - assert.NoError(t, err) - assert.EqualValues(t, 0, total) -} - -func TestCountWithOthers(t *testing.T) { - assert.NoError(t, PrepareEngine()) - - type CountWithOthers struct { - Id int64 - Name string - } - - assertSync(t, new(CountWithOthers)) - - _, err := testEngine.Insert(&CountWithOthers{ - Name: "orderby", - }) - assert.NoError(t, err) - - _, err = testEngine.Insert(&CountWithOthers{ - Name: "limit", - }) - assert.NoError(t, err) - - total, err := testEngine.OrderBy("id desc").Limit(1).Count(new(CountWithOthers)) - assert.NoError(t, err) - assert.EqualValues(t, 2, total) -} - -type CountWithTableName struct { - Id int64 - Name string -} - -func (CountWithTableName) TableName() string { - return "count_with_table_name1" -} - -func TestWithTableName(t *testing.T) { - assert.NoError(t, PrepareEngine()) - - assertSync(t, new(CountWithTableName)) - - _, err := testEngine.Insert(&CountWithTableName{ - Name: "orderby", - }) - assert.NoError(t, err) - - _, err = testEngine.Insert(CountWithTableName{ - Name: "limit", - }) - assert.NoError(t, err) - - total, err := testEngine.OrderBy("id desc").Count(new(CountWithTableName)) - assert.NoError(t, err) - assert.EqualValues(t, 2, total) - - total, err = testEngine.OrderBy("id desc").Count(CountWithTableName{}) - assert.NoError(t, err) - assert.EqualValues(t, 2, total) -} - -func TestCountWithSelectCols(t *testing.T) { - assert.NoError(t, PrepareEngine()) - - assertSync(t, new(CountWithTableName)) - - _, err := testEngine.Insert(&CountWithTableName{ - Name: "orderby", - }) - assert.NoError(t, err) - - _, err = testEngine.Insert(CountWithTableName{ - Name: "limit", - }) - assert.NoError(t, err) - - total, err := testEngine.Cols("id").Count(new(CountWithTableName)) - assert.NoError(t, err) - assert.EqualValues(t, 2, total) - - total, err = testEngine.Select("count(id)").Count(CountWithTableName{}) - assert.NoError(t, err) - assert.EqualValues(t, 2, total) -} diff --git a/integrations/session_test.go b/integrations/session_test.go index c3ef0513..a36b81bf 100644 --- a/integrations/session_test.go +++ b/integrations/session_test.go @@ -18,7 +18,7 @@ func TestClose(t *testing.T) { sess1.Close() assert.True(t, sess1.IsClosed()) - sess2 := testEngine.Where("a = ?", 1) + sess2 := testEngine.Where("`a` = ?", 1) sess2.Close() assert.True(t, sess2.IsClosed()) } @@ -32,7 +32,7 @@ func TestNullFloatStruct(t *testing.T) { } assert.NoError(t, PrepareEngine()) - assert.NoError(t, testEngine.Sync2(new(MyNullFloatStruct))) + assert.NoError(t, testEngine.Sync(new(MyNullFloatStruct))) _, err := testEngine.Insert(&MyNullFloatStruct{ Uuid: "111111", diff --git a/integrations/session_tx_test.go b/integrations/session_tx_test.go index 4cff5610..8d6519d0 100644 --- a/integrations/session_tx_test.go +++ b/integrations/session_tx_test.go @@ -37,7 +37,7 @@ func TestTransaction(t *testing.T) { assert.NoError(t, err) user2 := Userinfo{Username: "yyy"} - _, err = session.Where("id = ?", 0).Update(&user2) + _, err = session.Where("`id` = ?", 0).Update(&user2) assert.NoError(t, err) _, err = session.Delete(&user2) @@ -70,10 +70,10 @@ func TestCombineTransaction(t *testing.T) { assert.NoError(t, err) user2 := Userinfo{Username: "zzz"} - _, err = session.Where("id = ?", 0).Update(&user2) + _, err = session.Where("`id` = ?", 0).Update(&user2) assert.NoError(t, err) - _, err = session.Exec("delete from "+testEngine.TableName("userinfo", true)+" where username = ?", user2.Username) + _, err = session.Exec("delete from "+testEngine.Quote(testEngine.TableName("userinfo", true))+" where `username` = ?", user2.Username) assert.NoError(t, err) err = session.Commit() @@ -113,10 +113,10 @@ func TestCombineTransactionSameMapper(t *testing.T) { assert.NoError(t, err) user2 := Userinfo{Username: "zzz"} - _, err = session.Where("id = ?", 0).Update(&user2) + _, err = session.Where("`id` = ?", 0).Update(&user2) assert.NoError(t, err) - _, err = session.Exec("delete from "+testEngine.TableName("`Userinfo`", true)+" where `Username` = ?", user2.Username) + _, err = session.Exec("delete from "+testEngine.Quote(testEngine.TableName("Userinfo", true))+" where `Username` = ?", user2.Username) assert.NoError(t, err) err = session.Commit() @@ -144,7 +144,7 @@ func TestMultipleTransaction(t *testing.T) { assert.NoError(t, err) user2 := MultipleTransaction{Name: "zzz"} - _, err = session.Where("id = ?", 0).Update(&user2) + _, err = session.Where("`id` = ?", 0).Update(&user2) assert.NoError(t, err) err = session.Commit() @@ -158,7 +158,7 @@ func TestMultipleTransaction(t *testing.T) { err = session.Begin() assert.NoError(t, err) - _, err = session.Where("id=?", m1.Id).Delete(new(MultipleTransaction)) + _, err = session.Where("`id`=?", m1.Id).Delete(new(MultipleTransaction)) assert.NoError(t, err) err = session.Commit() diff --git a/integrations/session_update_test.go b/integrations/session_update_test.go index 07c722bd..45338cad 100644 --- a/integrations/session_update_test.go +++ b/integrations/session_update_test.go @@ -15,6 +15,7 @@ import ( "xorm.io/xorm/internal/statements" "xorm.io/xorm/internal/utils" "xorm.io/xorm/names" + "xorm.io/xorm/schemas" ) func TestUpdateMap(t *testing.T) { @@ -26,7 +27,7 @@ func TestUpdateMap(t *testing.T) { Age int } - assert.NoError(t, testEngine.Sync2(new(UpdateTable))) + assert.NoError(t, testEngine.Sync(new(UpdateTable))) var tb = UpdateTable{ Name: "test", Age: 35, @@ -34,7 +35,7 @@ func TestUpdateMap(t *testing.T) { _, err := testEngine.Insert(&tb) assert.NoError(t, err) - cnt, err := testEngine.Table("update_table").Where("id = ?", tb.Id).Update(map[string]interface{}{ + cnt, err := testEngine.Table("update_table").Where("`id` = ?", tb.Id).Update(map[string]interface{}{ "name": "test2", "age": 36, }) @@ -48,6 +49,19 @@ func TestUpdateMap(t *testing.T) { assert.Error(t, err) assert.True(t, statements.IsIDConditionWithNoTableErr(err)) assert.EqualValues(t, 0, cnt) + + cnt, err = testEngine.Table("update_table").Update(map[string]interface{}{ + "name": "test2", + "age": 36, + }, &UpdateTable{ + Id: tb.Id, + }) + assert.NoError(t, err) + if testEngine.Dialect().URI().DBType == schemas.MYSQL { + assert.EqualValues(t, 0, cnt) + } else { + assert.EqualValues(t, 1, cnt) + } } func TestUpdateLimit(t *testing.T) { @@ -64,7 +78,7 @@ func TestUpdateLimit(t *testing.T) { Age int } - assert.NoError(t, testEngine.Sync2(new(UpdateTable2))) + assert.NoError(t, testEngine.Sync(new(UpdateTable2))) var tb = UpdateTable2{ Name: "test1", Age: 35, @@ -79,7 +93,12 @@ func TestUpdateLimit(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 1, cnt) - cnt, err = testEngine.OrderBy("name desc").Limit(1).Update(&UpdateTable2{ + if testEngine.Dialect().URI().DBType == schemas.DAMENG { + t.SkipNow() + return + } + + cnt, err = testEngine.OrderBy("`name` desc").Limit(1).Update(&UpdateTable2{ Age: 30, }) assert.NoError(t, err) @@ -152,7 +171,7 @@ func TestForUpdate(t *testing.T) { // use lock fList := make([]ForUpdate, 0) session1.ForUpdate() - session1.Where("id = ?", 1) + session1.Where("`id` = ?", 1) err = session1.Find(&fList) switch { case err != nil: @@ -173,7 +192,7 @@ func TestForUpdate(t *testing.T) { wg.Add(1) go func() { f2 := new(ForUpdate) - session2.Where("id = ?", 1).ForUpdate() + session2.Where("`id` = ?", 1).ForUpdate() has, err := session2.Get(f2) // wait release lock switch { case err != nil: @@ -193,7 +212,7 @@ func TestForUpdate(t *testing.T) { wg2.Add(1) go func() { f3 := new(ForUpdate) - session3.Where("id = ?", 1) + session3.Where("`id` = ?", 1) has, err := session3.Get(f3) // wait release lock switch { case err != nil: @@ -211,15 +230,13 @@ func TestForUpdate(t *testing.T) { f := new(ForUpdate) f.Name = "updated by session1" - session1.Where("id = ?", 1) - session1.Update(f) + session1.Where("`id` = ?", 1) + _, err = session1.Update(f) + assert.NoError(t, err) // release lock err = session1.Commit() - if err != nil { - t.Error(err) - return - } + assert.NoError(t, err) wg.Wait() } @@ -234,7 +251,7 @@ func TestWithIn(t *testing.T) { assert.NoError(t, PrepareEngine()) assert.NoError(t, testEngine.Sync(new(temp3))) - testEngine.Insert(&[]temp3{ + _, err := testEngine.Insert(&[]temp3{ { Name: "user1", }, @@ -245,6 +262,7 @@ func TestWithIn(t *testing.T) { Name: "user1", }, }) + assert.NoError(t, err) cnt, err := testEngine.In("Id", 1, 2, 3, 4).Update(&temp3{Name: "aa"}, &temp3{Name: "user1"}) assert.NoError(t, err) @@ -286,7 +304,7 @@ func TestUpdateMap2(t *testing.T) { assert.NoError(t, PrepareEngine()) assertSync(t, new(UpdateMustCols)) - _, err := testEngine.Table("update_must_cols").Where("id =?", 1).Update(map[string]interface{}{ + _, err := testEngine.Table("update_must_cols").Where("`id` =?", 1).Update(map[string]interface{}{ "bool": true, }) assert.NoError(t, err) @@ -299,6 +317,7 @@ func TestUpdate1(t *testing.T) { _, err := testEngine.Insert(&Userinfo{ Username: "user1", }) + assert.NoError(t, err) var ori Userinfo has, err := testEngine.Get(&ori) @@ -331,11 +350,11 @@ func TestUpdate1(t *testing.T) { userID := user.Uid has, err := testEngine.ID(userID). - And("username = ?", user.Username). - And("height = ?", user.Height). - And("departname = ?", ""). - And("detail_id = ?", 0). - And("is_man = ?", 0). + And("`username` = ?", user.Username). + And("`height` = ?", user.Height). + And("`departname` = ?", ""). + And("`detail_id` = ?", 0). + And("`is_man` = ?", false). Get(&Userinfo{}) assert.NoError(t, err) assert.True(t, has, "cannot insert properly") @@ -348,12 +367,12 @@ func TestUpdate1(t *testing.T) { assert.EqualValues(t, 1, cnt, "update not returned 1") has, err = testEngine.ID(userID). - And("username = ?", updatedUser.Username). - And("height IS NULL"). - And("departname IS NULL"). - And("is_man IS NULL"). - And("created IS NULL"). - And("detail_id = ?", 0). + And("`username` = ?", updatedUser.Username). + And("`height` IS NULL"). + And("`departname` IS NULL"). + And("`is_man` IS NULL"). + And("`created` IS NULL"). + And("`detail_id` = ?", 0). Get(&Userinfo{}) assert.NoError(t, err) assert.True(t, has, "cannot update with null properly") @@ -363,7 +382,7 @@ func TestUpdate1(t *testing.T) { assert.EqualValues(t, 1, cnt, "delete not returned 1") } - err = testEngine.StoreEngine("Innodb").Sync2(&Article{}) + err = testEngine.StoreEngine("Innodb").Sync(&Article{}) assert.NoError(t, err) defer func() { @@ -458,6 +477,11 @@ func TestUpdateIncrDecr(t *testing.T) { cnt, err = testEngine.ID(col1.Id).Cols(colName).Incr(colName).Update(col1) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) + + testEngine.SetColumnMapper(testEngine.GetColumnMapper()) + cnt, err = testEngine.Cols(colName).Decr(colName, 2).ID(col1.Id).Update(new(UpdateIncr)) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) } type UpdatedUpdate struct { @@ -489,7 +513,7 @@ func TestUpdateUpdated(t *testing.T) { assert.NoError(t, PrepareEngine()) di := new(UpdatedUpdate) - err := testEngine.Sync2(di) + err := testEngine.Sync(di) assert.NoError(t, err) _, err = testEngine.Insert(&UpdatedUpdate{}) @@ -505,7 +529,7 @@ func TestUpdateUpdated(t *testing.T) { assert.EqualValues(t, ci.Updated.Unix(), di.Updated.Unix()) di2 := new(UpdatedUpdate2) - err = testEngine.Sync2(di2) + err = testEngine.Sync(di2) assert.NoError(t, err) now := time.Now() @@ -532,7 +556,7 @@ func TestUpdateUpdated(t *testing.T) { assert.True(t, ci2.Updated >= di21.Updated) di3 := new(UpdatedUpdate3) - err = testEngine.Sync2(di3) + err = testEngine.Sync(di3) assert.NoError(t, err) _, err = testEngine.Insert(&UpdatedUpdate3{}) @@ -548,7 +572,7 @@ func TestUpdateUpdated(t *testing.T) { assert.EqualValues(t, ci3.Updated, di3.Updated) di4 := new(UpdatedUpdate4) - err = testEngine.Sync2(di4) + err = testEngine.Sync(di4) assert.NoError(t, err) _, err = testEngine.Insert(&UpdatedUpdate4{}) @@ -564,7 +588,7 @@ func TestUpdateUpdated(t *testing.T) { assert.EqualValues(t, ci4.Updated, di4.Updated) di5 := new(UpdatedUpdate5) - err = testEngine.Sync2(di5) + err = testEngine.Sync(di5) assert.NoError(t, err) _, err = testEngine.Insert(&UpdatedUpdate5{}) @@ -778,7 +802,7 @@ func TestNoUpdate(t *testing.T) { _, err = testEngine.ID(1).Update(&NoUpdate{}) assert.Error(t, err) - assert.EqualValues(t, "No content found to be updated", err.Error()) + assert.EqualError(t, xorm.ErrNoColumnsTobeUpdated, err.Error()) } func TestNewUpdate(t *testing.T) { @@ -806,7 +830,7 @@ func TestNewUpdate(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 0, af) - af, err = testEngine.Table(new(TbUserInfo)).Where("phone=?", 13126564922).Update(&changeUsr) + af, err = testEngine.Table(new(TbUserInfo)).Where("`phone`=?", "13126564922").Update(&changeUsr) assert.NoError(t, err) assert.EqualValues(t, 0, af) } @@ -901,6 +925,7 @@ func TestDeletedUpdate(t *testing.T) { var s1 DeletedUpdatedStruct has, err := testEngine.ID(s.Id).Get(&s1) + assert.NoError(t, err) assert.EqualValues(t, true, has) cnt, err = testEngine.ID(s.Id).Delete(&DeletedUpdatedStruct{}) @@ -908,7 +933,7 @@ func TestDeletedUpdate(t *testing.T) { assert.EqualValues(t, 1, cnt) cnt, err = testEngine.ID(s.Id).Cols("deleted_at").Update(&DeletedUpdatedStruct{}) - assert.EqualValues(t, "No content found to be updated", err.Error()) + assert.EqualError(t, xorm.ErrNoColumnsTobeUpdated, err.Error()) assert.EqualValues(t, 0, cnt) cnt, err = testEngine.ID(s.Id).Unscoped().Cols("deleted_at").Update(&DeletedUpdatedStruct{}) @@ -917,6 +942,7 @@ func TestDeletedUpdate(t *testing.T) { var s2 DeletedUpdatedStruct has, err = testEngine.ID(s.Id).Get(&s2) + assert.NoError(t, err) assert.EqualValues(t, true, has) } @@ -1147,7 +1173,7 @@ func TestUpdateExprs(t *testing.T) { }) assert.NoError(t, err) - _, err = testEngine.SetExpr("num_issues", "num_issues+1").AllCols().Update(&UpdateExprs{ + _, err = testEngine.SetExpr("num_issues", "`num_issues`+1").AllCols().Update(&UpdateExprs{ NumIssues: 3, Name: "lunny xiao", }) @@ -1178,7 +1204,7 @@ func TestUpdateAlias(t *testing.T) { }) assert.NoError(t, err) - _, err = testEngine.Alias("ua").Where("ua.id = ?", 1).Update(&UpdateAlias{ + _, err = testEngine.Alias("ua").Where("ua.`id` = ?", 1).Update(&UpdateAlias{ NumIssues: 2, Name: "lunny xiao", }) @@ -1218,7 +1244,7 @@ func TestUpdateExprs2(t *testing.T) { assert.EqualValues(t, 1, inserted) updated, err := testEngine. - Where("repo_id = ? AND is_tag = ?", 1, false). + Where("`repo_id` = ? AND `is_tag` = ?", 1, false). SetExpr("is_draft", true). SetExpr("num_commits", 0). SetExpr("sha1", ""). @@ -1238,6 +1264,11 @@ func TestUpdateExprs2(t *testing.T) { } func TestUpdateMap3(t *testing.T) { + if testEngine.Dialect().URI().DBType == schemas.DAMENG { + t.SkipNow() + return + } + assert.NoError(t, PrepareEngine()) type UpdateMapUser struct { @@ -1289,12 +1320,11 @@ func TestUpdateIgnoreOnlyFromDBFields(t *testing.T) { assertGetRecord := func() *TestOnlyFromDBField { var record TestOnlyFromDBField - has, err := testEngine.Where("id = ?", 1).Get(&record) + has, err := testEngine.Where("`id` = ?", 1).Get(&record) assert.NoError(t, err) assert.EqualValues(t, true, has) assert.EqualValues(t, "", record.OnlyFromDBField) return &record - } assert.NoError(t, PrepareEngine()) assertSync(t, new(TestOnlyFromDBField)) @@ -1377,15 +1407,22 @@ func TestNilFromDB(t *testing.T) { assert.NoError(t, PrepareEngine()) assertSync(t, new(TestTable1)) - cnt, err := testEngine.Insert(&TestTable1{ + var tt0 = TestTable1{ Field1: &TestFieldType1{ cb: []byte("string"), }, UpdateTime: time.Now(), - }) + } + cnt, err := testEngine.Insert(&tt0) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) + var tt1 TestTable1 + has, err := testEngine.ID(tt0.Id).Get(&tt1) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, "string", string(tt1.Field1.cb)) + cnt, err = testEngine.Update(TestTable1{ UpdateTime: time.Now().Add(time.Second), }, TestTable1{ @@ -1399,4 +1436,37 @@ func TestNilFromDB(t *testing.T) { }) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) + + var tt = TestTable1{ + UpdateTime: time.Now(), + Field1: &TestFieldType1{ + cb: nil, + }, + } + cnt, err = testEngine.Insert(&tt) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var tt2 TestTable1 + has, err = testEngine.ID(tt.Id).Get(&tt2) + assert.NoError(t, err) + assert.True(t, has) + assert.Nil(t, tt2.Field1) + + var tt3 = TestTable1{ + UpdateTime: time.Now(), + Field1: &TestFieldType1{ + cb: []byte{}, + }, + } + cnt, err = testEngine.Insert(&tt3) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var tt4 TestTable1 + has, err = testEngine.ID(tt3.Id).Get(&tt4) + assert.NoError(t, err) + assert.True(t, has) + assert.NotNil(t, tt4.Field1) + assert.NotNil(t, tt4.Field1.cb) } diff --git a/integrations/tags_test.go b/integrations/tags_test.go index f787fffe..4c33d56c 100644 --- a/integrations/tags_test.go +++ b/integrations/tags_test.go @@ -165,7 +165,7 @@ func TestExtends(t *testing.T) { assert.True(t, info2.Userinfo.Uid > 0, "all of the id should has value") assert.True(t, info2.Userdetail.Id > 0, "all of the id should has value") - var infos2 = make([]UserAndDetail, 0) + infos2 := make([]UserAndDetail, 0) err = testEngine.Table(&Userinfo{}). Join("LEFT", qt(ud), qt(ui)+"."+qt("detail_id")+" = "+qt(ud)+"."+qt(uiid)). NoCascade(). @@ -219,9 +219,9 @@ func TestExtends2(t *testing.T) { err = testEngine.CreateTables(&Message{}, &MessageUser{}, &MessageType{}) assert.NoError(t, err) - var sender = MessageUser{Name: "sender"} - var receiver = MessageUser{Name: "receiver"} - var msgtype = MessageType{Name: "type"} + sender := MessageUser{Name: "sender"} + receiver := MessageUser{Name: "receiver"} + msgtype := MessageType{Name: "type"} _, err = testEngine.Insert(&sender, &receiver, &msgtype) assert.NoError(t, err) @@ -254,8 +254,8 @@ func TestExtends2(t *testing.T) { assert.NoError(t, err) } - var mapper = testEngine.GetTableMapper().Obj2Table - var quote = testEngine.Quote + mapper := testEngine.GetTableMapper().Obj2Table + quote := testEngine.Quote userTableName := quote(testEngine.TableName(mapper("MessageUser"), true)) typeTableName := quote(testEngine.TableName(mapper("MessageType"), true)) msgTableName := quote(testEngine.TableName(mapper("Message"), true)) @@ -280,9 +280,9 @@ func TestExtends3(t *testing.T) { err = testEngine.CreateTables(&Message{}, &MessageUser{}, &MessageType{}) assert.NoError(t, err) - var sender = MessageUser{Name: "sender"} - var receiver = MessageUser{Name: "receiver"} - var msgtype = MessageType{Name: "type"} + sender := MessageUser{Name: "sender"} + receiver := MessageUser{Name: "receiver"} + msgtype := MessageType{Name: "type"} _, err = testEngine.Insert(&sender, &receiver, &msgtype) assert.NoError(t, err) @@ -314,8 +314,8 @@ func TestExtends3(t *testing.T) { assert.NoError(t, err) } - var mapper = testEngine.GetTableMapper().Obj2Table - var quote = testEngine.Quote + mapper := testEngine.GetTableMapper().Obj2Table + quote := testEngine.Quote userTableName := quote(testEngine.TableName(mapper("MessageUser"), true)) typeTableName := quote(testEngine.TableName(mapper("MessageType"), true)) msgTableName := quote(testEngine.TableName(mapper("Message"), true)) @@ -345,8 +345,8 @@ func TestExtends4(t *testing.T) { err = testEngine.CreateTables(&Message{}, &MessageUser{}, &MessageType{}) assert.NoError(t, err) - var sender = MessageUser{Name: "sender"} - var msgtype = MessageType{Name: "type"} + sender := MessageUser{Name: "sender"} + msgtype := MessageType{Name: "type"} _, err = testEngine.Insert(&sender, &msgtype) assert.NoError(t, err) @@ -377,8 +377,8 @@ func TestExtends4(t *testing.T) { assert.NoError(t, err) } - var mapper = testEngine.GetTableMapper().Obj2Table - var quote = testEngine.Quote + mapper := testEngine.GetTableMapper().Obj2Table + quote := testEngine.Quote userTableName := quote(testEngine.TableName(mapper("MessageUser"), true)) typeTableName := quote(testEngine.TableName(mapper("MessageType"), true)) msgTableName := quote(testEngine.TableName(mapper("Message"), true)) @@ -417,29 +417,29 @@ func TestExtends5(t *testing.T) { err = testEngine.CreateTables(&Size{}, &Book{}) assert.NoError(t, err) - var sc = Size{Width: 0.2, Height: 0.4} - var so = Size{Width: 0.2, Height: 0.8} - var s = Size{Width: 0.15, Height: 1.5} - var bk1 = Book{ + sc := Size{Width: 0.2, Height: 0.4} + so := Size{Width: 0.2, Height: 0.8} + s := Size{Width: 0.15, Height: 1.5} + bk1 := Book{ SizeOpen: &so, SizeClosed: &sc, Size: &s, } - var bk2 = Book{ + bk2 := Book{ SizeOpen: &so, } - var bk3 = Book{ + bk3 := Book{ SizeClosed: &sc, Size: &s, } - var bk4 = Book{} - var bk5 = Book{Size: &s} + bk4 := Book{} + bk5 := Book{Size: &s} _, err = testEngine.Insert(&sc, &so, &s, &bk1, &bk2, &bk3, &bk4, &bk5) if err != nil { t.Fatal(err) } - var books = map[int64]Book{ + books := map[int64]Book{ bk1.ID: bk1, bk2.ID: bk2, bk3.ID: bk3, @@ -450,15 +450,15 @@ func TestExtends5(t *testing.T) { session := testEngine.NewSession() defer session.Close() - var mapper = testEngine.GetTableMapper().Obj2Table - var quote = testEngine.Quote + mapper := testEngine.GetTableMapper().Obj2Table + quote := testEngine.Quote bookTableName := quote(testEngine.TableName(mapper("Book"), true)) sizeTableName := quote(testEngine.TableName(mapper("Size"), true)) list := make([]Book, 0) err = session. Select(fmt.Sprintf( - "%s.%s, sc.%s AS %s, sc.%s AS %s, s.%s, s.%s", + "%s.%s, `sc`.%s AS %s, `sc`.%s AS %s, `s`.%s, `s`.%s", quote(bookTableName), quote("id"), quote("Width"), @@ -472,12 +472,12 @@ func TestExtends5(t *testing.T) { Join( "LEFT", sizeTableName+" AS `sc`", - bookTableName+".`SizeClosed`=sc.`id`", + bookTableName+".`SizeClosed`=`sc`.`id`", ). Join( "LEFT", sizeTableName+" AS `s`", - bookTableName+".`Size`=s.`id`", + bookTableName+".`Size`=`s`.`id`", ). Find(&list) assert.NoError(t, err) @@ -673,7 +673,7 @@ func TestCreatedUpdated(t *testing.T) { Updated time.Time `xorm:"updated"` } - err := testEngine.Sync2(&CreatedUpdated{}) + err := testEngine.Sync(&CreatedUpdated{}) assert.NoError(t, err) c := &CreatedUpdated{Name: "test"} @@ -728,9 +728,9 @@ type Lowercase struct { func TestLowerCase(t *testing.T) { assert.NoError(t, PrepareEngine()) - err := testEngine.Sync2(&Lowercase{}) + err := testEngine.Sync(&Lowercase{}) assert.NoError(t, err) - _, err = testEngine.Where("id > 0").Delete(&Lowercase{}) + _, err = testEngine.Where("`id` > 0").Delete(&Lowercase{}) assert.NoError(t, err) _, err = testEngine.Insert(&Lowercase{ended: 1}) @@ -757,6 +757,8 @@ func TestAutoIncrTag(t *testing.T) { assert.True(t, cols[0].IsAutoIncrement) assert.True(t, cols[0].IsPrimaryKey) assert.Equal(t, "id", cols[0].Name) + assert.True(t, cols[0].DefaultIsEmpty) + assert.EqualValues(t, "", cols[0].Default) type TestAutoIncr2 struct { Id int64 `xorm:"id"` @@ -770,6 +772,8 @@ func TestAutoIncrTag(t *testing.T) { assert.False(t, cols[0].IsAutoIncrement) assert.False(t, cols[0].IsPrimaryKey) assert.Equal(t, "id", cols[0].Name) + assert.True(t, cols[0].DefaultIsEmpty) + assert.EqualValues(t, "", cols[0].Default) type TestAutoIncr3 struct { Id int64 `xorm:"'ID'"` @@ -783,6 +787,8 @@ func TestAutoIncrTag(t *testing.T) { assert.False(t, cols[0].IsAutoIncrement) assert.False(t, cols[0].IsPrimaryKey) assert.Equal(t, "ID", cols[0].Name) + assert.True(t, cols[0].DefaultIsEmpty) + assert.EqualValues(t, "", cols[0].Default) type TestAutoIncr4 struct { Id int64 `xorm:"pk"` @@ -796,6 +802,8 @@ func TestAutoIncrTag(t *testing.T) { assert.False(t, cols[0].IsAutoIncrement) assert.True(t, cols[0].IsPrimaryKey) assert.Equal(t, "id", cols[0].Name) + assert.True(t, cols[0].DefaultIsEmpty) + assert.EqualValues(t, "", cols[0].Default) } func TestTagComment(t *testing.T) { @@ -809,7 +817,17 @@ func TestTagComment(t *testing.T) { Id int64 `xorm:"comment(主键)"` } - assert.NoError(t, testEngine.Sync2(new(TestComment1))) + tb, err := testEngine.TableInfo(new(TestComment1)) + assert.NoError(t, err) + cols := tb.Columns() + assert.EqualValues(t, 1, len(cols)) + assert.False(t, cols[0].IsAutoIncrement) + assert.False(t, cols[0].IsPrimaryKey) + assert.Equal(t, "id", cols[0].Name) + assert.True(t, cols[0].DefaultIsEmpty) + assert.EqualValues(t, "", cols[0].Default) + + assert.NoError(t, testEngine.Sync(new(TestComment1))) tables, err := testEngine.DBMetas() assert.NoError(t, err) @@ -823,7 +841,17 @@ func TestTagComment(t *testing.T) { Id int64 `xorm:"comment('主键')"` } - assert.NoError(t, testEngine.Sync2(new(TestComment2))) + tb, err = testEngine.TableInfo(new(TestComment2)) + assert.NoError(t, err) + cols = tb.Columns() + assert.EqualValues(t, 1, len(cols)) + assert.False(t, cols[0].IsAutoIncrement) + assert.False(t, cols[0].IsPrimaryKey) + assert.Equal(t, "id", cols[0].Name) + assert.True(t, cols[0].DefaultIsEmpty) + assert.EqualValues(t, "", cols[0].Default) + + assert.NoError(t, testEngine.Sync(new(TestComment2))) tables, err = testEngine.DBMetas() assert.NoError(t, err) @@ -841,6 +869,28 @@ func TestTagDefault(t *testing.T) { Age int `xorm:"default(10)"` } + tb, err := testEngine.TableInfo(new(DefaultStruct)) + assert.NoError(t, err) + cols := tb.Columns() + assert.EqualValues(t, 3, len(cols)) + assert.True(t, cols[0].IsAutoIncrement) + assert.True(t, cols[0].IsPrimaryKey) + assert.Equal(t, "id", cols[0].Name) + assert.True(t, cols[0].DefaultIsEmpty) + assert.EqualValues(t, "", cols[0].Default) + + assert.False(t, cols[1].IsAutoIncrement) + assert.False(t, cols[1].IsPrimaryKey) + assert.Equal(t, "name", cols[1].Name) + assert.True(t, cols[1].DefaultIsEmpty) + assert.EqualValues(t, "", cols[1].Default) + + assert.False(t, cols[2].IsAutoIncrement) + assert.False(t, cols[2].IsPrimaryKey) + assert.Equal(t, "age", cols[2].Name) + assert.False(t, cols[2].DefaultIsEmpty) + assert.EqualValues(t, "10", cols[2].Default) + assertSync(t, new(DefaultStruct)) tables, err := testEngine.DBMetas() @@ -880,10 +930,33 @@ func TestTagDefault2(t *testing.T) { assert.NoError(t, PrepareEngine()) type DefaultStruct2 struct { - Id int64 - Name string + Id int64 + Name string + NullDefault string `xorm:"default('NULL')"` } + tb, err := testEngine.TableInfo(new(DefaultStruct2)) + assert.NoError(t, err) + cols := tb.Columns() + assert.EqualValues(t, 3, len(cols)) + assert.True(t, cols[0].IsAutoIncrement) + assert.True(t, cols[0].IsPrimaryKey) + assert.Equal(t, "id", cols[0].Name) + assert.True(t, cols[0].DefaultIsEmpty) + assert.EqualValues(t, "", cols[0].Default) + + assert.False(t, cols[1].IsAutoIncrement) + assert.False(t, cols[1].IsPrimaryKey) + assert.Equal(t, "name", cols[1].Name) + assert.True(t, cols[1].DefaultIsEmpty) + assert.EqualValues(t, "", cols[1].Default) + + assert.False(t, cols[2].IsAutoIncrement) + assert.False(t, cols[2].IsPrimaryKey) + assert.Equal(t, "null_default", cols[2].Name) + assert.False(t, cols[2].DefaultIsEmpty) + assert.EqualValues(t, "'NULL'", cols[2].Default) + assertSync(t, new(DefaultStruct2)) tables, err := testEngine.DBMetas() @@ -1129,7 +1202,7 @@ func TestTagTime(t *testing.T) { assert.NoError(t, err) assert.True(t, has) assert.EqualValues(t, s.Created.UTC().Format("2006-01-02 15:04:05"), - strings.Replace(strings.Replace(tm, "T", " ", -1), "Z", "", -1)) + strings.ReplaceAll(strings.ReplaceAll(tm, "T", " "), "Z", "")) } func TestTagAutoIncr(t *testing.T) { @@ -1214,7 +1287,7 @@ func TestVersion1(t *testing.T) { assert.EqualValues(t, newVer.Ver, 2) newVer = new(VersionS) - has, err = testEngine.ID(ver.Id).Get(newVer) + _, err = testEngine.ID(ver.Id).Get(newVer) assert.NoError(t, err) assert.EqualValues(t, newVer.Ver, 2) } @@ -1228,7 +1301,7 @@ func TestVersion2(t *testing.T) { err = testEngine.CreateTables(new(VersionS)) assert.NoError(t, err) - var vers = []VersionS{ + vers := []VersionS{ {Name: "sfsfdsfds"}, {Name: "xxxxx"}, } @@ -1272,7 +1345,7 @@ func TestVersion3(t *testing.T) { assert.EqualValues(t, newVer.Ver, 2) newVer = new(VersionUintS) - has, err = testEngine.ID(ver.Id).Get(newVer) + _, err = testEngine.ID(ver.Id).Get(newVer) assert.NoError(t, err) assert.EqualValues(t, newVer.Ver, 2) } @@ -1286,7 +1359,7 @@ func TestVersion4(t *testing.T) { err = testEngine.CreateTables(new(VersionUintS)) assert.NoError(t, err) - var vers = []VersionUintS{ + vers := []VersionUintS{ {Name: "sfsfdsfds"}, {Name: "xxxxx"}, } @@ -1327,3 +1400,55 @@ func TestIndexes(t *testing.T) { assert.EqualValues(t, slice1, slice2) assert.EqualValues(t, 3, len(tables[0].Indexes)) } + +type TestTableIndicesStruct struct { + Id int64 + Name string `xorm:"index index(f_one_f_two) unique(s)"` // we're going to override the index f_one_f_two in TableIndices and remove it from this column + Email string `xorm:"index unique(s)"` + FTwo string `xorm:"index(f_two_f_one) index(f_one_f_two) f_two"` + FOne string `xorm:"index(f_two_f_one) f_one"` +} + +func (t *TestTableIndicesStruct) TableIndices() []*schemas.Index { + newIndex := schemas.NewIndex("f_one_f_two", schemas.IndexType) + newIndex.AddColumn("f_one", "f_two") + + return []*schemas.Index{newIndex} +} + +func TestTableIndices(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + assertSync(t, new(TestTableIndicesStruct)) + + tables, err := testEngine.DBMetas() + assert.NoError(t, err) + assert.EqualValues(t, 1, len(tables)) + assert.EqualValues(t, 5, len(tables[0].Columns())) + slice1 := []string{ + testEngine.GetColumnMapper().Obj2Table("Id"), + testEngine.GetColumnMapper().Obj2Table("Name"), + testEngine.GetColumnMapper().Obj2Table("Email"), + testEngine.GetColumnMapper().Obj2Table("FTwo"), + testEngine.GetColumnMapper().Obj2Table("FOne"), + } + slice2 := []string{ + tables[0].Columns()[0].Name, + tables[0].Columns()[1].Name, + tables[0].Columns()[2].Name, + tables[0].Columns()[3].Name, + tables[0].Columns()[4].Name, + } + sort.Strings(slice1) + sort.Strings(slice2) + assert.EqualValues(t, slice1, slice2) + assert.EqualValues(t, 5, len(tables[0].Indexes)) + index, ok := tables[0].Indexes["f_one_f_two"] + if assert.True(t, ok) { + assert.EqualValues(t, []string{"f_one", "f_two"}, index.Cols) + } + index, ok = tables[0].Indexes["f_two_f_one"] + if assert.True(t, ok) { + assert.EqualValues(t, []string{"f_two", "f_one"}, index.Cols) + } +} diff --git a/integrations/testdata/import1.sql b/integrations/testdata/import1.sql new file mode 100644 index 00000000..e004f41c --- /dev/null +++ b/integrations/testdata/import1.sql @@ -0,0 +1,279 @@ +SET SQL_MODE = "NO_AUTO_VALUE_ON_ZERO"; +SET time_zone = "+00:00"; + +-- 基本用户信息表 +CREATE TABLE IF NOT EXISTS `user` ( + `id` int(11) NOT NULL AUTO_INCREMENT, + PRIMARY KEY (`id`), + KEY `uid` (`id`), + `user_name` varchar(128) CHARACTER SET utf8mb4 NOT NULL, + KEY `user_name` (`user_name`), + `email` varchar(32) NOT NULL, + KEY `email` (`email`), + `pass` varchar(256) NOT NULL, + `passwd` varchar(16) NOT NULL, + `uuid` TEXT NULL DEFAULT NULL COMMENT 'uuid', + `t` int(11) NOT NULL DEFAULT '0', + `u` bigint(20) NOT NULL, + `d` bigint(20) NOT NULL, + `plan` varchar(2) CHARACTER SET utf8mb4 NOT NULL DEFAULT 'A', + `node_group` INT NOT NULL DEFAULT '0', + `auto_reset_day` INT NOT NULL DEFAULT '0', + `auto_reset_bandwidth` DECIMAL(12,2) NOT NULL DEFAULT '0.00', + `transfer_enable` BIGINT(20) NOT NULL, + `port` int(11) NOT NULL, + `protocol_param` VARCHAR(128) NULL DEFAULT NULL, + `obfs_param` VARCHAR(128) NULL DEFAULT NULL, + `switch` tinyint(4) NOT NULL DEFAULT '1', + `enable` tinyint(4) NOT NULL DEFAULT '1', + `type` tinyint(4) NOT NULL DEFAULT '1', + `last_get_gift_time` int(11) NOT NULL DEFAULT '0', + `last_check_in_time` int(11) NOT NULL DEFAULT '0', + `last_rest_pass_time` int(11) NOT NULL DEFAULT '0', + `reg_date` datetime NOT NULL, + `invite_num` int(8) NOT NULL, + `money` decimal(12,2) NOT NULL, + `ref_by` int(11) NOT NULL DEFAULT '0', + `expire_time` int(11) NOT NULL DEFAULT '0', + `is_email_verify` tinyint(4) NOT NULL DEFAULT '0', + `reg_ip` varchar(128) NOT NULL DEFAULT '127.0.0.1', + `node_speedlimit` DECIMAL(12,2) NOT NULL DEFAULT '0.00', + `node_connector` int(11) NOT NULL DEFAULT '0', + `forbidden_ip` LONGTEXT NULL DEFAULT '', + `forbidden_port` LONGTEXT NULL DEFAULT '', + `disconnect_ip` LONGTEXT NULL DEFAULT '', + `is_hide` INT NOT NULL DEFAULT '0', + `last_detect_ban_time` datetime DEFAULT '1989-06-04 00:05:00', + `all_detect_number` int(11) NOT NULL DEFAULT '0', + `is_multi_user` INT NOT NULL DEFAULT '0', + `telegram_id` BIGINT NULL, + `is_admin` int(2) NOT NULL DEFAULT '0', + `im_type` int(11) DEFAULT '1', + `im_value` text, + `last_day_t` bigint(20) NOT NULL DEFAULT '0', + `mail_notified` int(11) NOT NULL DEFAULT '0', + `class` int(11) NOT NULL DEFAULT '0', + `class_expire` datetime NOT NULL DEFAULT '1989-06-04 00:05:00', + `expire_in` datetime NOT NULL DEFAULT '2099-06-04 00:05:00', + `theme` text NOT NULL, + `ga_token` text NOT NULL, + `ga_enable` int(11) NOT NULL DEFAULT '0', + `pac` LONGTEXT, + `remark` text +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; + +-- 用户流量信息表 +-- TODO: 重写流量信息提取逻辑 +CREATE TABLE IF NOT EXISTS `user_traffic_log` ( + `id` int(11) NOT NULL AUTO_INCREMENT, + PRIMARY KEY (`id`), + `user_id` int(11) NOT NULL, + `u` BIGINT(20) NOT NULL, + `d` BIGINT(20) NOT NULL, + `node_id` int(11) NOT NULL, + `rate` float NOT NULL, + `traffic` varchar(32) NOT NULL, + `log_time` int(11) NOT NULL +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; + +-- 用户订阅 TOKEN 信息表 +CREATE TABLE IF NOT EXISTS `user_token` ( + `id` int(11) NOT NULL AUTO_INCREMENT, + PRIMARY KEY (`id`), + `token` varchar(256) NOT NULL, + `user_id` int(11) NOT NULL, + `create_time` int(11) NOT NULL, + `expire_time` int(11) NOT NULL +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; + +-- 充值码使用信息表 +CREATE TABLE IF NOT EXISTS `charge_code` ( + `id` bigint(20) NOT NULL AUTO_INCREMENT, + PRIMARY KEY (`id`), + `code` text NOT NULL, + `type` int(11) NOT NULL, + `number` DECIMAL(11,2) NOT NULL, + `isused` int(11) NOT NULL DEFAULT '0', + `userid` bigint(20) NOT NULL, + `usedatetime` datetime NOT NULL +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; + +-- 邀请码使用信息表 +CREATE TABLE IF NOT EXISTS `invite_code` ( + `id` int(11) NOT NULL AUTO_INCREMENT, + PRIMARY KEY (`id`), + `code` varchar(128) NOT NULL, + KEY `user_id` (`user_id`), + `user_id` int(11) NOT NULL, + `created_at` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + `updated_at` timestamp NOT NULL DEFAULT '2016-06-01 00:00:00' +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; + +-- 公告信息表 +CREATE TABLE IF NOT EXISTS `announcement` ( + `id` int(11) NOT NULL AUTO_INCREMENT, + PRIMARY KEY (`id`), + `date` datetime NOT NULL, + `content` LONGTEXT CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci NOT NULL, + `markdown` LONGTEXT NOT NULL +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; + +-- 节点信息表 +CREATE TABLE IF NOT EXISTS `node` ( + `id` int(11) NOT NULL AUTO_INCREMENT, + PRIMARY KEY (`id`), + `name` varchar(128) NOT NULL, + `type` int(3) NOT NULL, + `online_user` int(11) NOT NULL, + `mu_only` INT NULL DEFAULT '0', + `online` BOOLEAN NOT NULL DEFAULT TRUE, + `server` varchar(128) NOT NULL, + `method` varchar(64) NOT NULL, + `info` varchar(128) NOT NULL, + `status` varchar(128) NOT NULL, + `node_group` INT NOT NULL DEFAULT '0', + `sort` int(3) NOT NULL, + `custom_method` tinyint(1) NOT NULL DEFAULT '0', + `traffic_rate` float NOT NULL DEFAULT '1', + `node_class` int(11) NOT NULL DEFAULT '0', + `node_speedlimit` DECIMAL(12,2) NOT NULL DEFAULT '0.00', + `node_connector` int(11) NOT NULL DEFAULT '0', + `node_bandwidth` bigint(20) NOT NULL DEFAULT '0', + `node_bandwidth_limit` bigint(20) NOT NULL DEFAULT '0', + `bandwidthlimit_resetday` int(11) NOT NULL DEFAULT '0', + `node_heartbeat` bigint(20) NOT NULL DEFAULT '0', + `node_ip` text +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; +-- TODO: 修改 VPN 节点的结算说明 + +-- 商店数据表 +CREATE TABLE `shop` ( + `id` BIGINT NOT NULL AUTO_INCREMENT, + `name` TEXT NOT NULL, + `price` DECIMAL(12,2) NOT NULL, + `content` TEXT NOT NULL, + `auto_renew` INT NOT NULL, + `status` INT NOT NULL DEFAULT '1', + `auto_reset_bandwidth` INT NOT NULL DEFAULT '0', + PRIMARY KEY (`id`) +) ENGINE = InnoDB CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; + +-- 优惠券数据表 +CREATE TABLE `coupon` ( + `id` BIGINT NOT NULL AUTO_INCREMENT, + `code` TEXT NOT NULL, + `onetime` INT NOT NULL, + `expire` BIGINT NOT NULL, + `shop` TEXT NOT NULL, + `credit` INT NOT NULL, + PRIMARY KEY (`id`) +) ENGINE = InnoDB CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; + +-- 购买记录数据表 +CREATE TABLE `bought` ( + `id` BIGINT NOT NULL AUTO_INCREMENT, + `userid` BIGINT NOT NULL, + `shopid` BIGINT NOT NULL, + `coupon` TEXT NOT NULL, + `datetime` BIGINT NOT NULL, + `renew` BIGINT(11) NOT NULL, + `price` DECIMAL(12,2) NOT NULL, + PRIMARY KEY (`id`) +) ENGINE = InnoDB CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; + +-- 工单数据表 +CREATE TABLE `ticket` ( + `id` BIGINT NOT NULL AUTO_INCREMENT, + `title` LONGTEXT NOT NULL, + `status` INT NOT NULL DEFAULT '1', + `content` LONGTEXT NOT NULL, + `rootid` BIGINT NOT NULL,`userid` BIGINT NOT NULL, + `datetime` BIGINT NOT NULL, + PRIMARY KEY (`id`) +) ENGINE = InnoDB CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; + +-- 返利记录数据表 +CREATE TABLE `payback` ( + `id` BIGINT NOT NULL AUTO_INCREMENT, + `total` DECIMAL(12,2) NOT NULL, + `userid` BIGINT NOT NULL, + `ref_by` BIGINT NOT NULL, + `ref_get` DECIMAL(12,2) NOT NULL, + `datetime` BIGINT NOT NULL, + PRIMARY KEY (`id`) +) ENGINE = InnoDB CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; + +-- 审计规则数据表 +CREATE TABLE `detect_list` ( + `id` BIGINT NOT NULL AUTO_INCREMENT, + `name` LONGTEXT NOT NULL, + `type` INT NOT NULL, + `text` LONGTEXT NOT NULL, + `regex` LONGTEXT NOT NULL, + PRIMARY KEY (`id`) +) ENGINE = InnoDB CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; + +-- 审计记录数据表 +CREATE TABLE `detect_log` ( + `id` BIGINT NOT NULL AUTO_INCREMENT, + `user_id` BIGINT NOT NULL, + `node_id` INT NOT NULL, + `list_id` BIGINT NOT NULL, + `datetime` BIGINT NOT NULL, + PRIMARY KEY (`id`) +) ENGINE = InnoDB CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; + +-- 中转规则数据表 +CREATE TABLE IF NOT EXISTS `relay` ( + `id` bigint(20) NOT NULL AUTO_INCREMENT, + PRIMARY KEY (`id`), + `user_id` bigint(20) NOT NULL, + `source_node_id` bigint(20) NOT NULL, + `dist_node_id` bigint(20) NOT NULL, + `dist_ip` text NOT NULL, + `port` int(11) NOT NULL, + `priority` int(11) NOT NULL +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; + +-- 用户订阅日志 +CREATE TABLE IF NOT EXISTS `user_subscribe_log` ( + `id` int(11) unsigned NOT NULL AUTO_INCREMENT, + `user_name` varchar(128) NOT NULL COMMENT '用户名', + `user_id` int(11) NOT NULL COMMENT '用户 ID', + `email` varchar(32) NOT NULL COMMENT '用户邮箱', + `subscribe_type` varchar(20) NOT NULL COMMENT '获取的订阅类型', + `request_ip` varchar(128) NOT NULL COMMENT '请求 IP', + `request_time` datetime NOT NULL COMMENT '请求时间', + `request_user_agent` text COMMENT '请求 UA 信息', + PRIMARY KEY (`id`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='用户订阅日志'; + +-- 审计封禁日志 +CREATE TABLE IF NOT EXISTS `detect_ban_log` ( + `id` int(11) unsigned NOT NULL AUTO_INCREMENT, + `user_name` varchar(128) NOT NULL COMMENT '用户名', + `user_id` int(11) NOT NULL COMMENT '用户 ID', + `email` varchar(32) NOT NULL COMMENT '用户邮箱', + `detect_number` int(11) NOT NULL COMMENT '本次违规次数', + `ban_time` int(11) NOT NULL COMMENT '本次封禁时长', + `start_time` bigint(20) NOT NULL COMMENT '统计开始时间', + `end_time` bigint(20) NOT NULL COMMENT '统计结束时间', + `all_detect_number` int(11) NOT NULL COMMENT '累计违规次数', + PRIMARY KEY (`id`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='审计封禁日志'; + +-- 管理员操作记录 +CREATE TABLE IF NOT EXISTS `gconfig` ( + `id` int(11) unsigned NOT NULL AUTO_INCREMENT, + `key` varchar(128) NOT NULL COMMENT '配置键名', + `type` varchar(32) NOT NULL COMMENT '值类型', + `value` text NOT NULL COMMENT '配置值', + `oldvalue` text NOT NULL COMMENT '之前的配置值', + `name` varchar(128) NOT NULL COMMENT '配置名称', + `comment` text NOT NULL COMMENT '配置描述', + `operator_id` int(11) NOT NULL COMMENT '操作员 ID', + `operator_name` varchar(128) NOT NULL COMMENT '操作员名称', + `operator_email` varchar(32) NOT NULL COMMENT '操作员邮箱', + `last_update` bigint(20) NOT NULL COMMENT '修改时间', + PRIMARY KEY (`id`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='网站配置'; \ No newline at end of file diff --git a/integrations/testdata/import2.sql b/integrations/testdata/import2.sql new file mode 100644 index 00000000..469b096b --- /dev/null +++ b/integrations/testdata/import2.sql @@ -0,0 +1,3 @@ +CREATE TABLE IF NOT EXISTS `Core_Goods` (`Id` BIGINT(20) PRIMARY KEY AUTO_INCREMENT NOT NULL, `GoodsSN` VARCHAR(20) NULL COMMENT '商品序列号', `GoodsSort` INT(11) NULL COMMENT '商品排序', `GoodsName` VARCHAR(100) NULL COMMENT '商品名称', `GoodsThumb` VARCHAR(500) NULL, `GoodsUnit` VARCHAR(255) NULL, `PerUnitNum` BIGINT(20) NULL, `GoodsState` TINYINT(4) NULL COMMENT '商品状态', `IsClose` TINYINT(1) DEFAULT 0 NULL COMMENT '关闭下单', `GoodsDesc` VARCHAR(200) NULL COMMENT '商品简介', `GoodsContent` TEXT NULL COMMENT '商品详情', `GoodsImages` TEXT NULL COMMENT '商品图片', `MinOrderNum` INT(11) NULL COMMENT '最少下单数', `MaxOrderNum` INT(11) NULL COMMENT '最大下单数', `CategoryId` INT(11) NULL COMMENT '商品分类', `SupplyPrice` BIGINT(20) NULL COMMENT '供货单价', `StockNum` INT(11) NULL COMMENT '库存数量,小于0不限制', `HandleRemarks` VARCHAR(255) NULL COMMENT '处理备注', `ParamsTemplate` TEXT NULL COMMENT '下单参数模板', `PriceTemplateId` INT(11) NULL COMMENT '加价模板id', `GoodsSnapshotId` INT(11) NULL COMMENT '当前快照id', `SupplierUserId` INT(11) NULL COMMENT '供货商用户id', `CanTui` TINYINT(1) DEFAULT 0 NULL COMMENT '是否可以申请退款', `CanRepeat` TINYINT(1) DEFAULT 1 NULL COMMENT '是否可以申请退款', `GoodsType` TINYINT(4) DEFAULT 1 NULL COMMENT '商品类型', `CreatedAt` DATETIME NULL COMMENT '创建时间', `UpdatedAt` DATETIME NULL COMMENT '修改时间', `LastHandlerAdminUserId` INT(11) NULL COMMENT '最后操作管理员', `GoodsMode` TINYINT(4) DEFAULT 0 NULL COMMENT '商品属性', `SaleTotal` INT(11) NULL COMMENT '总销量', `SaleMonth` INT(11) NULL COMMENT '月销量', `Notice` TEXT NULL COMMENT '商品公告', `AfterLunchOrderState` TINYINT(4) NULL COMMENT '下单后状态', `JoinMode` TINYINT(4) DEFAULT 1 NULL COMMENT '对接模式', `ApiOrderLunchConfig` TEXT NULL COMMENT '提交订单配置', `Version` BIGINT(20) DEFAULT 1 NULL) ENGINE=InnoDB; +INSERT INTO `Core_Goods` (`Id`, `GoodsSN`, `GoodsSort`, `GoodsName`, `GoodsThumb`, `GoodsUnit`, `PerUnitNum`, `GoodsState`, `IsClose`, `GoodsDesc`, `GoodsContent`, `GoodsImages`, `MinOrderNum`, `MaxOrderNum`, `CategoryId`, `SupplyPrice`, `StockNum`, `HandleRemarks`, `ParamsTemplate`, `PriceTemplateId`, `GoodsSnapshotId`, `SupplierUserId`, `CanTui`, `CanRepeat`, `GoodsType`, `CreatedAt`, `UpdatedAt`, `LastHandlerAdminUserId`, `GoodsMode`, `SaleTotal`, `SaleMonth`, `Notice`, `AfterLunchOrderState`, `JoinMode`, `ApiOrderLunchConfig`, `Version`) VALUES (42,'2107290140000432',91,' - ','','',1,2,0,'','--','[]',3,10000,10,5974060,-1,'','',1,0,10001,0,1,1,'2021-07-29 01:40:55','2021-07-30 18:28:59',10000,2,0,0,'',0,0,'{"url":"","method":"","type":"","succContactStr":"","data":null}',35); +INSERT INTO `Core_Goods` (`Id`, `GoodsSN`, `GoodsSort`, `GoodsName`, `GoodsThumb`, `GoodsUnit`, `PerUnitNum`, `GoodsState`, `IsClose`, `GoodsDesc`, `GoodsContent`, `GoodsImages`, `MinOrderNum`, `MaxOrderNum`, `CategoryId`, `SupplyPrice`, `StockNum`, `HandleRemarks`, `ParamsTemplate`, `PriceTemplateId`, `GoodsSnapshotId`, `SupplierUserId`, `CanTui`, `CanRepeat`, `GoodsType`, `CreatedAt`, `UpdatedAt`, `LastHandlerAdminUserId`, `GoodsMode`, `SaleTotal`, `SaleMonth`, `Notice`, `AfterLunchOrderState`, `JoinMode`, `ApiOrderLunchConfig`, `Version`) VALUES (43,'2107290140000433',90,' - ','','',1,2,0,'','','[]',3,10000,10,9064091,-1,'','',1,0,10001,0,1,1,'2021-07-29 01:40:55','2021-07-30 18:28:59',10000,2,0,0,'',0,0,'{"url":"","method":"","type":"","succContactStr":"","data":null}',39); \ No newline at end of file diff --git a/integrations/tests.go b/integrations/tests.go index 31fa99bf..59f4b29a 100644 --- a/integrations/tests.go +++ b/integrations/tests.go @@ -8,6 +8,7 @@ import ( "database/sql" "flag" "fmt" + "net/url" "os" "strings" "testing" @@ -50,7 +51,7 @@ func createEngine(dbType, connStr string) error { if !*cluster { switch schemas.DBType(strings.ToLower(dbType)) { case schemas.MSSQL: - db, err := sql.Open(dbType, strings.Replace(connStr, "xorm_test", "master", -1)) + db, err := sql.Open(dbType, strings.ReplaceAll(connStr, "xorm_test", "master")) if err != nil { return err } @@ -60,7 +61,7 @@ func createEngine(dbType, connStr string) error { db.Close() *ignoreSelectUpdate = true case schemas.POSTGRES: - db, err := sql.Open(dbType, strings.Replace(connStr, "xorm_test", "postgres", -1)) + db, err := sql.Open(dbType, strings.ReplaceAll(connStr, "xorm_test", "postgres")) if err != nil { return err } @@ -89,7 +90,7 @@ func createEngine(dbType, connStr string) error { db.Close() *ignoreSelectUpdate = true case schemas.MYSQL: - db, err := sql.Open(dbType, strings.Replace(connStr, "xorm_test", "mysql", -1)) + db, err := sql.Open(dbType, strings.ReplaceAll(connStr, "xorm_test", "mysql")) if err != nil { return err } @@ -97,6 +98,13 @@ func createEngine(dbType, connStr string) error { return fmt.Errorf("db.Exec: %v", err) } db.Close() + case schemas.SQLITE, "sqlite": + u, err := url.Parse(connStr) + if err != nil { + return err + } + connStr = u.Path + *ignoreSelectUpdate = true default: *ignoreSelectUpdate = true } @@ -158,23 +166,28 @@ func createEngine(dbType, connStr string) error { for _, table := range tables { tableNames = append(tableNames, table.Name) } - if err = testEngine.DropTables(tableNames...); err != nil { - return err - } - return nil + return testEngine.DropTables(tableNames...) } +// PrepareEngine prepare tests ORM engine func PrepareEngine() error { return createEngine(dbType, connString) } +// MainTest the tests entrance func MainTest(m *testing.M) { flag.Parse() dbType = *db if *db == "sqlite3" { if ptrConnStr == nil { - connString = "./test.db?cache=shared&mode=rwc" + connString = "./test_sqlite3.db?cache=shared&mode=rwc" + } else { + connString = *ptrConnStr + } + } else if *db == "sqlite" { + if ptrConnStr == nil { + connString = "./test_sqlite.db?cache=shared&mode=rwc" } else { connString = *ptrConnStr } diff --git a/integrations/time_test.go b/integrations/time_test.go index 6d8d812c..a8447eea 100644 --- a/integrations/time_test.go +++ b/integrations/time_test.go @@ -15,8 +15,12 @@ import ( "github.com/stretchr/testify/assert" ) -func formatTime(t time.Time) string { - return t.Format("2006-01-02 15:04:05") +func formatTime(t time.Time, scales ...int) string { + var layout = "2006-01-02 15:04:05" + if len(scales) > 0 && scales[0] > 0 { + layout += "." + strings.Repeat("0", scales[0]) + } + return t.Format(layout) } func TestTimeUserTime(t *testing.T) { @@ -53,9 +57,18 @@ func TestTimeUserTimeDiffLoc(t *testing.T) { assert.NoError(t, PrepareEngine()) loc, err := time.LoadLocation("Asia/Shanghai") assert.NoError(t, err) + oldTZLoc := testEngine.GetTZLocation() + defer func() { + testEngine.SetTZLocation(oldTZLoc) + }() testEngine.SetTZLocation(loc) + dbLoc, err := time.LoadLocation("America/New_York") assert.NoError(t, err) + oldDBLoc := testEngine.GetTZDatabase() + defer func() { + testEngine.SetTZDatabase(oldDBLoc) + }() testEngine.SetTZDatabase(dbLoc) type TimeUser2 struct { @@ -118,9 +131,18 @@ func TestTimeUserCreatedDiffLoc(t *testing.T) { assert.NoError(t, PrepareEngine()) loc, err := time.LoadLocation("Asia/Shanghai") assert.NoError(t, err) + oldTZLoc := testEngine.GetTZLocation() + defer func() { + testEngine.SetTZLocation(oldTZLoc) + }() testEngine.SetTZLocation(loc) + dbLoc, err := time.LoadLocation("America/New_York") assert.NoError(t, err) + oldDBLoc := testEngine.GetTZDatabase() + defer func() { + testEngine.SetTZDatabase(oldDBLoc) + }() testEngine.SetTZDatabase(dbLoc) type UserCreated2 struct { @@ -204,9 +226,18 @@ func TestTimeUserUpdatedDiffLoc(t *testing.T) { assert.NoError(t, PrepareEngine()) loc, err := time.LoadLocation("Asia/Shanghai") assert.NoError(t, err) + oldTZLoc := testEngine.GetTZLocation() + defer func() { + testEngine.SetTZLocation(oldTZLoc) + }() testEngine.SetTZLocation(loc) + dbLoc, err := time.LoadLocation("America/New_York") assert.NoError(t, err) + oldDBLoc := testEngine.GetTZDatabase() + defer func() { + testEngine.SetTZDatabase(oldDBLoc) + }() testEngine.SetTZDatabase(dbLoc) type UserUpdated2 struct { @@ -293,7 +324,7 @@ func TestTimeUserDeleted(t *testing.T) { fmt.Println("user2 str", user2.CreatedAtStr, user2.UpdatedAtStr) var user3 UserDeleted - cnt, err = testEngine.Where("id = ?", "lunny").Delete(&user3) + cnt, err = testEngine.Where("`id` = ?", "lunny").Delete(&user3) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) assert.True(t, !utils.IsTimeZero(user3.DeletedAt)) @@ -311,9 +342,18 @@ func TestTimeUserDeletedDiffLoc(t *testing.T) { assert.NoError(t, PrepareEngine()) loc, err := time.LoadLocation("Asia/Shanghai") assert.NoError(t, err) + oldTZLoc := testEngine.GetTZLocation() + defer func() { + testEngine.SetTZLocation(oldTZLoc) + }() testEngine.SetTZLocation(loc) + dbLoc, err := time.LoadLocation("America/New_York") assert.NoError(t, err) + oldDBLoc := testEngine.GetTZDatabase() + defer func() { + testEngine.SetTZDatabase(oldDBLoc) + }() testEngine.SetTZDatabase(dbLoc) type UserDeleted2 struct { @@ -346,7 +386,7 @@ func TestTimeUserDeletedDiffLoc(t *testing.T) { fmt.Println("user2", user2.CreatedAt, user2.UpdatedAt, user2.DeletedAt) var user3 UserDeleted2 - cnt, err = testEngine.Where("id = ?", "lunny").Delete(&user3) + cnt, err = testEngine.Where("`id` = ?", "lunny").Delete(&user3) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) assert.True(t, !utils.IsTimeZero(user3.DeletedAt)) @@ -417,7 +457,7 @@ func TestCustomTimeUserDeleted(t *testing.T) { fmt.Println("user2", user2.CreatedAt, user2.UpdatedAt, user2.DeletedAt) var user3 UserDeleted3 - cnt, err = testEngine.Where("id = ?", "lunny").Delete(&user3) + cnt, err = testEngine.Where("`id` = ?", "lunny").Delete(&user3) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) assert.True(t, !utils.IsTimeZero(time.Time(user3.DeletedAt))) @@ -435,9 +475,18 @@ func TestCustomTimeUserDeletedDiffLoc(t *testing.T) { assert.NoError(t, PrepareEngine()) loc, err := time.LoadLocation("Asia/Shanghai") assert.NoError(t, err) + oldTZLoc := testEngine.GetTZLocation() + defer func() { + testEngine.SetTZLocation(oldTZLoc) + }() testEngine.SetTZLocation(loc) + dbLoc, err := time.LoadLocation("America/New_York") assert.NoError(t, err) + oldDBLoc := testEngine.GetTZDatabase() + defer func() { + testEngine.SetTZDatabase(oldDBLoc) + }() testEngine.SetTZDatabase(dbLoc) type UserDeleted4 struct { @@ -470,7 +519,7 @@ func TestCustomTimeUserDeletedDiffLoc(t *testing.T) { fmt.Println("user2", user2.CreatedAt, user2.UpdatedAt, user2.DeletedAt) var user3 UserDeleted4 - cnt, err = testEngine.Where("id = ?", "lunny").Delete(&user3) + cnt, err = testEngine.Where("`id` = ?", "lunny").Delete(&user3) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) assert.True(t, !utils.IsTimeZero(time.Time(user3.DeletedAt))) @@ -520,3 +569,53 @@ func TestDeletedInt64(t *testing.T) { assert.True(t, has) assert.EqualValues(t, d1, d4) } + +func TestTimestamp(t *testing.T) { + { + assert.NoError(t, PrepareEngine()) + + type TimestampStruct struct { + Id int64 + InsertTime time.Time `xorm:"DATETIME(6)"` + } + + assertSync(t, new(TimestampStruct)) + + var d1 = TimestampStruct{ + InsertTime: time.Now(), + } + cnt, err := testEngine.Insert(&d1) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var d2 TimestampStruct + has, err := testEngine.ID(d1.Id).Get(&d2) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, formatTime(d1.InsertTime, 6), formatTime(d2.InsertTime, 6)) + } + + /*{ + assert.NoError(t, PrepareEngine()) + + type TimestampzStruct struct { + Id int64 + InsertTime time.Time `xorm:"TIMESTAMPZ"` + } + + assertSync(t, new(TimestampzStruct)) + + var d3 = TimestampzStruct{ + InsertTime: time.Now(), + } + cnt, err := testEngine.Insert(&d3) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var d4 TimestampzStruct + has, err := testEngine.ID(d3.Id).Get(&d4) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, formatTime(d3.InsertTime, 6), formatTime(d4.InsertTime, 6)) + }*/ +} diff --git a/integrations/types_null_test.go b/integrations/types_null_test.go index 98bd86b9..8d98b456 100644 --- a/integrations/types_null_test.go +++ b/integrations/types_null_test.go @@ -7,7 +7,6 @@ package integrations import ( "database/sql" "database/sql/driver" - "errors" "fmt" "strconv" "strings" @@ -16,7 +15,7 @@ import ( "github.com/stretchr/testify/assert" ) -type NullType struct { +type NullStruct struct { Id int `xorm:"pk autoincr"` Name sql.NullString Age sql.NullInt64 @@ -42,15 +41,22 @@ func (m *CustomStruct) Scan(value interface{}) error { return nil } - if s, ok := value.([]byte); ok { - seps := strings.Split(string(s), "/") + var s string + switch t := value.(type) { + case string: + s = t + case []byte: + s = string(t) + } + if len(s) > 0 { + seps := strings.Split(s, "/") m.Year, _ = strconv.Atoi(seps[0]) m.Month, _ = strconv.Atoi(seps[1]) m.Day, _ = strconv.Atoi(seps[2]) return nil } - return errors.New("scan data not fit []byte") + return fmt.Errorf("scan data %#v not fit []byte", value) } func (m CustomStruct) Value() (driver.Value, error) { @@ -59,26 +65,26 @@ func (m CustomStruct) Value() (driver.Value, error) { func TestCreateNullStructTable(t *testing.T) { assert.NoError(t, PrepareEngine()) - err := testEngine.CreateTables(new(NullType)) + err := testEngine.CreateTables(new(NullStruct)) assert.NoError(t, err) } func TestDropNullStructTable(t *testing.T) { assert.NoError(t, PrepareEngine()) - err := testEngine.DropTables(new(NullType)) + err := testEngine.DropTables(new(NullStruct)) assert.NoError(t, err) } func TestNullStructInsert(t *testing.T) { assert.NoError(t, PrepareEngine()) - assertSync(t, new(NullType)) + assertSync(t, new(NullStruct)) - item1 := new(NullType) + item1 := new(NullStruct) _, err := testEngine.Insert(item1) assert.NoError(t, err) assert.EqualValues(t, 1, item1.Id) - item := NullType{ + item := NullStruct{ Name: sql.NullString{String: "haolei", Valid: true}, Age: sql.NullInt64{Int64: 34, Valid: true}, Height: sql.NullFloat64{Float64: 1.72, Valid: true}, @@ -89,9 +95,9 @@ func TestNullStructInsert(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 2, item.Id) - items := []NullType{} + items := []NullStruct{} for i := 0; i < 5; i++ { - item := NullType{ + item := NullStruct{ Name: sql.NullString{String: "haolei_" + fmt.Sprint(i+1), Valid: true}, Age: sql.NullInt64{Int64: 30 + int64(i), Valid: true}, Height: sql.NullFloat64{Float64: 1.5 + 1.1*float64(i), Valid: true}, @@ -105,7 +111,7 @@ func TestNullStructInsert(t *testing.T) { _, err = testEngine.Insert(&items) assert.NoError(t, err) - items = make([]NullType, 0, 7) + items = make([]NullStruct, 0, 7) err = testEngine.Find(&items) assert.NoError(t, err) assert.EqualValues(t, 7, len(items)) @@ -113,9 +119,9 @@ func TestNullStructInsert(t *testing.T) { func TestNullStructUpdate(t *testing.T) { assert.NoError(t, PrepareEngine()) - assertSync(t, new(NullType)) + assertSync(t, new(NullStruct)) - _, err := testEngine.Insert([]NullType{ + _, err := testEngine.Insert([]NullStruct{ { Name: sql.NullString{ String: "name1", @@ -144,7 +150,7 @@ func TestNullStructUpdate(t *testing.T) { assert.NoError(t, err) if true { // 测试可插入NULL - item := new(NullType) + item := new(NullStruct) item.Age = sql.NullInt64{Int64: 23, Valid: true} item.Height = sql.NullFloat64{Float64: 0, Valid: false} // update to NULL @@ -154,7 +160,7 @@ func TestNullStructUpdate(t *testing.T) { } if true { // 测试In update - item := new(NullType) + item := new(NullStruct) item.Age = sql.NullInt64{Int64: 23, Valid: true} affected, err := testEngine.In("id", 3, 4).Cols("age", "height", "is_man").Update(item) assert.NoError(t, err) @@ -162,17 +168,17 @@ func TestNullStructUpdate(t *testing.T) { } if true { // 测试where - item := new(NullType) + item := new(NullStruct) item.Name = sql.NullString{String: "nullname", Valid: true} item.IsMan = sql.NullBool{Bool: true, Valid: true} item.Age = sql.NullInt64{Int64: 34, Valid: true} - _, err := testEngine.Where("age > ?", 34).Update(item) + _, err := testEngine.Where("`age` > ?", 34).Update(item) assert.NoError(t, err) } if true { // 修改全部时,插入空值 - item := &NullType{ + item := &NullStruct{ Name: sql.NullString{String: "winxxp", Valid: true}, Age: sql.NullInt64{Int64: 30, Valid: true}, Height: sql.NullFloat64{Float64: 1.72, Valid: true}, @@ -186,9 +192,9 @@ func TestNullStructUpdate(t *testing.T) { func TestNullStructFind(t *testing.T) { assert.NoError(t, PrepareEngine()) - assertSync(t, new(NullType)) + assertSync(t, new(NullStruct)) - _, err := testEngine.Insert([]NullType{ + _, err := testEngine.Insert([]NullStruct{ { Name: sql.NullString{ String: "name1", @@ -217,7 +223,7 @@ func TestNullStructFind(t *testing.T) { assert.NoError(t, err) if true { - item := new(NullType) + item := new(NullStruct) has, err := testEngine.ID(1).Get(item) assert.NoError(t, err) assert.True(t, has) @@ -229,7 +235,7 @@ func TestNullStructFind(t *testing.T) { } if true { - item := new(NullType) + item := new(NullStruct) item.Id = 2 has, err := testEngine.Get(item) assert.NoError(t, err) @@ -237,13 +243,13 @@ func TestNullStructFind(t *testing.T) { } if true { - item := make([]NullType, 0) + item := make([]NullStruct, 0) err := testEngine.ID(2).Find(&item) assert.NoError(t, err) } if true { - item := make([]NullType, 0) + item := make([]NullStruct, 0) err := testEngine.Asc("age").Find(&item) assert.NoError(t, err) } @@ -251,12 +257,12 @@ func TestNullStructFind(t *testing.T) { func TestNullStructIterate(t *testing.T) { assert.NoError(t, PrepareEngine()) - assertSync(t, new(NullType)) + assertSync(t, new(NullStruct)) if true { - err := testEngine.Where("age IS NOT NULL").OrderBy("age").Iterate(new(NullType), + err := testEngine.Where("`age` IS NOT NULL").OrderBy("age").Iterate(new(NullStruct), func(i int, bean interface{}) error { - nultype := bean.(*NullType) + nultype := bean.(*NullStruct) fmt.Println(i, nultype) return nil }) @@ -266,21 +272,21 @@ func TestNullStructIterate(t *testing.T) { func TestNullStructCount(t *testing.T) { assert.NoError(t, PrepareEngine()) - assertSync(t, new(NullType)) + assertSync(t, new(NullStruct)) if true { - item := new(NullType) - _, err := testEngine.Where("age IS NOT NULL").Count(item) + item := new(NullStruct) + _, err := testEngine.Where("`age` IS NOT NULL").Count(item) assert.NoError(t, err) } } func TestNullStructRows(t *testing.T) { assert.NoError(t, PrepareEngine()) - assertSync(t, new(NullType)) + assertSync(t, new(NullStruct)) - item := new(NullType) - rows, err := testEngine.Where("id > ?", 1).Rows(item) + item := new(NullStruct) + rows, err := testEngine.Where("`id` > ?", 1).Rows(item) assert.NoError(t, err) defer rows.Close() @@ -292,13 +298,13 @@ func TestNullStructRows(t *testing.T) { func TestNullStructDelete(t *testing.T) { assert.NoError(t, PrepareEngine()) - assertSync(t, new(NullType)) + assertSync(t, new(NullStruct)) - item := new(NullType) + item := new(NullStruct) _, err := testEngine.ID(1).Delete(item) assert.NoError(t, err) - _, err = testEngine.Where("id > ?", 1).Delete(item) + _, err = testEngine.Where("`id` > ?", 1).Delete(item) assert.NoError(t, err) } diff --git a/integrations/types_test.go b/integrations/types_test.go index 112308f3..1c815b7a 100644 --- a/integrations/types_test.go +++ b/integrations/types_test.go @@ -7,6 +7,9 @@ package integrations import ( "errors" "fmt" + "math" + "math/big" + "strconv" "testing" "xorm.io/xorm" @@ -25,9 +28,9 @@ func TestArrayField(t *testing.T) { Name [20]byte `xorm:"char(80)"` } - assert.NoError(t, testEngine.Sync2(new(ArrayStruct))) + assert.NoError(t, testEngine.Sync(new(ArrayStruct))) - var as = ArrayStruct{ + as := ArrayStruct{ Name: [20]byte{ 96, 96, 96, 96, 96, 96, 96, 96, 96, 96, @@ -51,7 +54,7 @@ func TestArrayField(t *testing.T) { assert.EqualValues(t, 1, len(arrs)) assert.Equal(t, as.Name, arrs[0].Name) - var newName = [20]byte{ + newName := [20]byte{ 90, 96, 96, 96, 96, 96, 96, 96, 96, 96, 96, 96, 96, 96, 96, @@ -87,7 +90,7 @@ func TestGetBytes(t *testing.T) { Data []byte `xorm:"VARBINARY(250)"` } - err := testEngine.Sync2(new(Varbinary)) + err := testEngine.Sync(new(Varbinary)) assert.NoError(t, err) cnt, err := testEngine.Insert(&Varbinary{ @@ -144,13 +147,39 @@ func (s *SliceType) ToDB() ([]byte, error) { return json.DefaultJSONHandler.Marshal(s) } +type Nullable struct { + Data string +} + +func (s *Nullable) FromDB(data []byte) error { + if data == nil { + return nil + } + + *s = Nullable{ + Data: string(data), + } + + return nil +} + +func (s *Nullable) ToDB() ([]byte, error) { + if s == nil { + return nil, nil + } + + return []byte(s.Data), nil +} + type ConvStruct struct { - Conv ConvString - Conv2 *ConvString - Cfg1 ConvConfig - Cfg2 *ConvConfig `xorm:"TEXT"` - Cfg3 convert.Conversion `xorm:"BLOB"` - Slice SliceType + Conv ConvString + Conv2 *ConvString + Cfg1 ConvConfig + Cfg2 *ConvConfig `xorm:"TEXT"` + Cfg3 convert.Conversion `xorm:"BLOB"` + Slice SliceType + Nullable1 *Nullable `xorm:"null"` + Nullable2 *Nullable `xorm:"null"` } func (c *ConvStruct) BeforeSet(name string, cell xorm.Cell) { @@ -164,7 +193,7 @@ func TestConversion(t *testing.T) { c := new(ConvStruct) assert.NoError(t, testEngine.DropTables(c)) - assert.NoError(t, testEngine.Sync2(c)) + assert.NoError(t, testEngine.Sync(c)) var s ConvString = "sssss" c.Conv = "tttt" @@ -173,8 +202,10 @@ func TestConversion(t *testing.T) { c.Cfg2 = &ConvConfig{"xx", 2} c.Cfg3 = &ConvConfig{"zz", 3} c.Slice = []*ConvConfig{{"yy", 4}, {"ff", 5}} + c.Nullable1 = &Nullable{Data: "test"} + c.Nullable2 = nil - _, err := testEngine.Insert(c) + _, err := testEngine.Nullable("nullable2").Insert(c) assert.NoError(t, err) c1 := new(ConvStruct) @@ -216,11 +247,16 @@ func TestConversion(t *testing.T) { assert.EqualValues(t, 2, len(c2.Slice)) assert.EqualValues(t, *c.Slice[0], *c2.Slice[0]) assert.EqualValues(t, *c.Slice[1], *c2.Slice[1]) + assert.NotNil(t, c1.Nullable1) + assert.Equal(t, c1.Nullable1.Data, "test") + assert.Nil(t, c1.Nullable2) } -type MyInt int -type MyUInt uint -type MyFloat float64 +type ( + MyInt int + MyUInt uint + MyFloat float64 +) type MyStruct struct { Type MyInt @@ -239,7 +275,7 @@ type MyStruct struct { UIA32 []uint32 UIA64 []uint64 UI uint - //C64 complex64 + // C64 complex64 MSS map[string]string } @@ -270,6 +306,13 @@ func TestCustomType1(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 1, cnt) + // since mssql don't support use text as index condition, we have to ignore below + // get and find tests + if testEngine.Dialect().URI().DBType == schemas.MSSQL { + t.Skip() + return + } + fmt.Println(i) i.NameArray = []string{} i.MSS = map[string]string{} @@ -375,3 +418,204 @@ func TestCustomType2(t *testing.T) { fmt.Println(users) } + +func TestUnsignedUint64(t *testing.T) { + type MyUnsignedStruct struct { + Id uint64 + } + + assert.NoError(t, PrepareEngine()) + assertSync(t, new(MyUnsignedStruct)) + + tables, err := testEngine.DBMetas() + assert.NoError(t, err) + assert.EqualValues(t, 1, len(tables)) + assert.EqualValues(t, 1, len(tables[0].Columns())) + + switch testEngine.Dialect().URI().DBType { + case schemas.SQLITE: + assert.EqualValues(t, "INTEGER", tables[0].Columns()[0].SQLType.Name) + case schemas.MYSQL: + assert.EqualValues(t, "UNSIGNED BIGINT", tables[0].Columns()[0].SQLType.Name) + case schemas.POSTGRES, schemas.DAMENG: + assert.EqualValues(t, "BIGINT", tables[0].Columns()[0].SQLType.Name) + case schemas.MSSQL: + assert.EqualValues(t, "BIGINT", tables[0].Columns()[0].SQLType.Name) + default: + assert.False(t, true, "Unsigned is not implemented") + } + + // Only MYSQL database supports unsigned bigint + if testEngine.Dialect().URI().DBType != schemas.MYSQL { + return + } + + cnt, err := testEngine.Insert(&MyUnsignedStruct{ + Id: math.MaxUint64, + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var v MyUnsignedStruct + has, err := testEngine.Get(&v) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, uint64(math.MaxUint64), v.Id) +} + +func TestUnsignedUint32(t *testing.T) { + type MyUnsignedInt32Struct struct { + Id uint32 + } + + assert.NoError(t, PrepareEngine()) + assertSync(t, new(MyUnsignedInt32Struct)) + + tables, err := testEngine.DBMetas() + assert.NoError(t, err) + assert.EqualValues(t, 1, len(tables)) + assert.EqualValues(t, 1, len(tables[0].Columns())) + + switch testEngine.Dialect().URI().DBType { + case schemas.SQLITE: + assert.EqualValues(t, "INTEGER", tables[0].Columns()[0].SQLType.Name) + case schemas.MYSQL: + assert.EqualValues(t, "UNSIGNED INT", tables[0].Columns()[0].SQLType.Name) + case schemas.POSTGRES, schemas.MSSQL, schemas.DAMENG: + assert.EqualValues(t, "BIGINT", tables[0].Columns()[0].SQLType.Name) + default: + assert.False(t, true, "Unsigned is not implemented") + } + + cnt, err := testEngine.Insert(&MyUnsignedInt32Struct{ + Id: math.MaxUint32, + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var v MyUnsignedInt32Struct + has, err := testEngine.Get(&v) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, uint64(math.MaxUint32), v.Id) +} + +func TestUnsignedTinyInt(t *testing.T) { + type MyUnsignedTinyIntStruct struct { + Id uint8 `xorm:"unsigned tinyint"` + } + + assert.NoError(t, PrepareEngine()) + assertSync(t, new(MyUnsignedTinyIntStruct)) + + tables, err := testEngine.DBMetas() + assert.NoError(t, err) + assert.EqualValues(t, 1, len(tables)) + assert.EqualValues(t, 1, len(tables[0].Columns())) + + switch testEngine.Dialect().URI().DBType { + case schemas.SQLITE, schemas.DAMENG: + assert.EqualValues(t, "INTEGER", tables[0].Columns()[0].SQLType.Name) + case schemas.MYSQL: + assert.EqualValues(t, "UNSIGNED TINYINT", tables[0].Columns()[0].SQLType.Name) + case schemas.POSTGRES: + assert.EqualValues(t, "SMALLINT", tables[0].Columns()[0].SQLType.Name) + case schemas.MSSQL: + assert.EqualValues(t, "INT", tables[0].Columns()[0].SQLType.Name) + default: + assert.False(t, true, fmt.Sprintf("Unsigned is not implemented, returned %s", tables[0].Columns()[0].SQLType.Name)) + } + + cnt, err := testEngine.Insert(&MyUnsignedTinyIntStruct{ + Id: math.MaxUint8, + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var v MyUnsignedTinyIntStruct + has, err := testEngine.Get(&v) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, uint64(math.MaxUint32), v.Id) +} + +type MyDecimal big.Int + +func (d *MyDecimal) FromDB(data []byte) error { + i, _ := strconv.ParseInt(string(data), 10, 64) + if d == nil { + d = (*MyDecimal)(big.NewInt(i)) + } else { + (*big.Int)(d).SetInt64(i) + } + return nil +} + +func (d *MyDecimal) ToDB() ([]byte, error) { + return []byte(fmt.Sprintf("%d", (*big.Int)(d).Int64())), nil +} + +func (d *MyDecimal) AsBigInt() *big.Int { + return (*big.Int)(d) +} + +func (d *MyDecimal) AsInt64() int64 { + return d.AsBigInt().Int64() +} + +func TestDecimal(t *testing.T) { + type MyMoney struct { + Id int64 + Account *MyDecimal + } + + assert.NoError(t, PrepareEngine()) + assertSync(t, new(MyMoney)) + + _, err := testEngine.Insert(&MyMoney{ + Account: (*MyDecimal)(big.NewInt(10000000000000000)), + }) + assert.NoError(t, err) + + var m MyMoney + has, err := testEngine.Get(&m) + assert.NoError(t, err) + assert.True(t, has) + assert.NotNil(t, m.Account) + assert.EqualValues(t, 10000000000000000, m.Account.AsInt64()) +} + +type MyArray [20]byte + +func (d *MyArray) FromDB(data []byte) error { + for i, b := range data[:20] { + (*d)[i] = b + } + return nil +} + +func (d MyArray) ToDB() ([]byte, error) { + return d[:], nil +} + +func TestMyArray(t *testing.T) { + type MyArrayStruct struct { + Id int64 + Content MyArray + } + + assert.NoError(t, PrepareEngine()) + assertSync(t, new(MyArrayStruct)) + + v := [20]byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1} + _, err := testEngine.Insert(&MyArrayStruct{ + Content: v, + }) + assert.NoError(t, err) + + var m MyArrayStruct + has, err := testEngine.Get(&m) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, v, m.Content) +} diff --git a/interface.go b/interface.go index 0fe9cbe1..d10abe9e 100644 --- a/interface.go +++ b/interface.go @@ -30,14 +30,15 @@ type Interface interface { CreateUniques(bean interface{}) error Decr(column string, arg ...interface{}) *Session Desc(...string) *Session - Delete(interface{}) (int64, error) + Delete(...interface{}) (int64, error) + Truncate(...interface{}) (int64, error) Distinct(columns ...string) *Session DropIndexes(bean interface{}) error Exec(sqlOrArgs ...interface{}) (sql.Result, error) Exist(bean ...interface{}) (bool, error) Find(interface{}, ...interface{}) error FindAndCount(interface{}, ...interface{}) (int64, error) - Get(interface{}) (bool, error) + Get(...interface{}) (bool, error) GroupBy(keys string) *Session ID(interface{}) *Session In(string, ...interface{}) *Session @@ -51,9 +52,10 @@ type Interface interface { MustCols(columns ...string) *Session NoAutoCondition(...bool) *Session NotIn(string, ...interface{}) *Session - Join(joinOperator string, tablename interface{}, condition string, args ...interface{}) *Session + Nullable(...string) *Session + Join(joinOperator string, tablename interface{}, condition interface{}, args ...interface{}) *Session Omit(columns ...string) *Session - OrderBy(order string) *Session + OrderBy(order interface{}, args ...interface{}) *Session Ping() error Query(sqlOrArgs ...interface{}) (resultsSlice []map[string][]byte, err error) QueryInterface(sqlOrArgs ...interface{}) ([]map[string]interface{}, error) @@ -83,6 +85,7 @@ type EngineInterface interface { Context(context.Context) *Session CreateTables(...interface{}) error DBMetas() ([]*schemas.Table, error) + DBVersion() (*schemas.Version, error) Dialect() dialects.Dialect DriverName() string DropTables(...interface{}) error @@ -97,10 +100,12 @@ type EngineInterface interface { MapCacher(interface{}, caches.Cacher) error NewSession() *Session NoAutoTime() *Session + Prepare() *Session Quote(string) string SetCacher(string, caches.Cacher) SetConnMaxLifetime(time.Duration) SetColumnMapper(names.Mapper) + SetTagIdentifier(string) SetDefaultCacher(caches.Cacher) SetLogger(logger interface{}) SetLogLevel(log.LogLevel) diff --git a/internal/json/gojson.go b/internal/json/gojson.go new file mode 100644 index 00000000..9bfa5c29 --- /dev/null +++ b/internal/json/gojson.go @@ -0,0 +1,29 @@ +// Copyright 2021 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build gojson +// +build gojson + +package json + +import ( + gojson "github.com/goccy/go-json" +) + +func init() { + DefaultJSONHandler = GOjson{} +} + +// GOjson implements JSONInterface via gojson +type GOjson struct{} + +// Marshal implements JSONInterface +func (GOjson) Marshal(v interface{}) ([]byte, error) { + return gojson.Marshal(v) +} + +// Unmarshal implements JSONInterface +func (GOjson) Unmarshal(data []byte, v interface{}) error { + return gojson.Unmarshal(data, v) +} diff --git a/internal/json/json.go b/internal/json/json.go index c9a2eb4e..ef52f51f 100644 --- a/internal/json/json.go +++ b/internal/json/json.go @@ -6,15 +6,15 @@ package json import "encoding/json" -// JSONInterface represents an interface to handle json data -type JSONInterface interface { +// Interface represents an interface to handle json data +type Interface interface { Marshal(v interface{}) ([]byte, error) Unmarshal(data []byte, v interface{}) error } var ( // DefaultJSONHandler default json handler - DefaultJSONHandler JSONInterface = StdJSON{} + DefaultJSONHandler Interface = StdJSON{} ) // StdJSON implements JSONInterface via encoding/json diff --git a/internal/json/jsoniter.go b/internal/json/jsoniter.go new file mode 100644 index 00000000..be93ac4e --- /dev/null +++ b/internal/json/jsoniter.go @@ -0,0 +1,29 @@ +// Copyright 2021 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build jsoniter +// +build jsoniter + +package json + +import ( + jsoniter "github.com/json-iterator/go" +) + +func init() { + DefaultJSONHandler = JSONiter{} +} + +// JSONiter implements JSONInterface via jsoniter +type JSONiter struct{} + +// Marshal implements JSONInterface +func (JSONiter) Marshal(v interface{}) ([]byte, error) { + return jsoniter.Marshal(v) +} + +// Unmarshal implements JSONInterface +func (JSONiter) Unmarshal(data []byte, v interface{}) error { + return jsoniter.Unmarshal(data, v) +} diff --git a/internal/statements/cache.go b/internal/statements/cache.go index cb33df08..669cd018 100644 --- a/internal/statements/cache.go +++ b/internal/statements/cache.go @@ -12,6 +12,7 @@ import ( "xorm.io/xorm/schemas" ) +// ConvertIDSQL converts SQL with id func (statement *Statement) ConvertIDSQL(sqlStr string) string { if statement.RefTable != nil { cols := statement.RefTable.PKColumns() @@ -37,6 +38,7 @@ func (statement *Statement) ConvertIDSQL(sqlStr string) string { return "" } +// ConvertUpdateSQL converts update SQL func (statement *Statement) ConvertUpdateSQL(sqlStr string) (string, string) { if statement.RefTable == nil || len(statement.RefTable.PrimaryKeys) != 1 { return "", "" diff --git a/internal/statements/cond.go b/internal/statements/cond.go new file mode 100644 index 00000000..dfc6c208 --- /dev/null +++ b/internal/statements/cond.go @@ -0,0 +1,111 @@ +// Copyright 2022 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package statements + +import ( + "xorm.io/builder" + "xorm.io/xorm/schemas" +) + +type QuoteReplacer struct { + *builder.BytesWriter + quoter schemas.Quoter +} + +func (q *QuoteReplacer) Write(p []byte) (n int, err error) { + c := q.quoter.Replace(string(p)) + return q.BytesWriter.Builder.WriteString(c) +} + +func (statement *Statement) QuoteReplacer(w *builder.BytesWriter) *QuoteReplacer { + return &QuoteReplacer{ + BytesWriter: w, + quoter: statement.dialect.Quoter(), + } +} + +// Where add Where statement +func (statement *Statement) Where(query interface{}, args ...interface{}) *Statement { + return statement.And(query, args...) +} + +// And add Where & and statement +func (statement *Statement) And(query interface{}, args ...interface{}) *Statement { + switch qr := query.(type) { + case string: + cond := builder.Expr(qr, args...) + statement.cond = statement.cond.And(cond) + case map[string]interface{}: + cond := make(builder.Eq) + for k, v := range qr { + cond[statement.quote(k)] = v + } + statement.cond = statement.cond.And(cond) + case builder.Cond: + statement.cond = statement.cond.And(qr) + for _, v := range args { + if vv, ok := v.(builder.Cond); ok { + statement.cond = statement.cond.And(vv) + } + } + default: + statement.LastError = ErrConditionType + } + + return statement +} + +// Or add Where & Or statement +func (statement *Statement) Or(query interface{}, args ...interface{}) *Statement { + switch qr := query.(type) { + case string: + cond := builder.Expr(qr, args...) + statement.cond = statement.cond.Or(cond) + case map[string]interface{}: + cond := make(builder.Eq) + for k, v := range qr { + cond[statement.quote(k)] = v + } + statement.cond = statement.cond.Or(cond) + case builder.Cond: + statement.cond = statement.cond.Or(qr) + for _, v := range args { + if vv, ok := v.(builder.Cond); ok { + statement.cond = statement.cond.Or(vv) + } + } + default: + statement.LastError = ErrConditionType + } + return statement +} + +// In generate "Where column IN (?) " statement +func (statement *Statement) In(column string, args ...interface{}) *Statement { + in := builder.In(statement.quote(column), args...) + statement.cond = statement.cond.And(in) + return statement +} + +// NotIn generate "Where column NOT IN (?) " statement +func (statement *Statement) NotIn(column string, args ...interface{}) *Statement { + notIn := builder.NotIn(statement.quote(column), args...) + statement.cond = statement.cond.And(notIn) + return statement +} + +// SetNoAutoCondition if you do not want convert bean's field as query condition, then use this function +func (statement *Statement) SetNoAutoCondition(no ...bool) *Statement { + statement.NoAutoCondition = true + if len(no) > 0 { + statement.NoAutoCondition = no[0] + } + return statement +} + +// Conds returns condtions +func (statement *Statement) Conds() builder.Cond { + return statement.cond +} diff --git a/internal/statements/expr.go b/internal/statements/expr.go new file mode 100644 index 00000000..c2a2e1cc --- /dev/null +++ b/internal/statements/expr.go @@ -0,0 +1,94 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package statements + +import ( + "fmt" + "strings" + + "xorm.io/builder" + "xorm.io/xorm/schemas" +) + +// ErrUnsupportedExprType represents an error with unsupported express type +type ErrUnsupportedExprType struct { + tp string +} + +func (err ErrUnsupportedExprType) Error() string { + return fmt.Sprintf("Unsupported expression type: %v", err.tp) +} + +// Expr represents an SQL express +type Expr struct { + ColName string + Arg interface{} +} + +// WriteArgs writes args to the writer +func (expr *Expr) WriteArgs(w *builder.BytesWriter) error { + switch arg := expr.Arg.(type) { + case *builder.Builder: + if _, err := w.WriteString("("); err != nil { + return err + } + if err := arg.WriteTo(w); err != nil { + return err + } + if _, err := w.WriteString(")"); err != nil { + return err + } + case string: + if arg == "" { + arg = "''" + } + if _, err := w.WriteString(fmt.Sprintf("%v", arg)); err != nil { + return err + } + default: + if _, err := w.WriteString("?"); err != nil { + return err + } + w.Append(arg) + } + return nil +} + +type exprParams []Expr + +func (exprs exprParams) ColNames() []string { + var cols = make([]string, 0, len(exprs)) + for _, expr := range exprs { + cols = append(cols, expr.ColName) + } + return cols +} + +func (exprs *exprParams) Add(name string, arg interface{}) { + *exprs = append(*exprs, Expr{name, arg}) +} + +func (exprs exprParams) IsColExist(colName string) bool { + for _, expr := range exprs { + if strings.EqualFold(schemas.CommonQuoter.Trim(expr.ColName), schemas.CommonQuoter.Trim(colName)) { + return true + } + } + return false +} + +func (exprs exprParams) WriteArgs(w *builder.BytesWriter) error { + for i, expr := range exprs { + if err := expr.WriteArgs(w); err != nil { + return err + } + if i != len(exprs)-1 { + if _, err := w.WriteString(","); err != nil { + return err + } + } + } + return nil +} diff --git a/internal/statements/expr_param.go b/internal/statements/expr_param.go deleted file mode 100644 index 6657408e..00000000 --- a/internal/statements/expr_param.go +++ /dev/null @@ -1,126 +0,0 @@ -// Copyright 2019 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package statements - -import ( - "fmt" - "strings" - - "xorm.io/builder" - "xorm.io/xorm/schemas" -) - -type ErrUnsupportedExprType struct { - tp string -} - -func (err ErrUnsupportedExprType) Error() string { - return fmt.Sprintf("Unsupported expression type: %v", err.tp) -} - -type exprParam struct { - colName string - arg interface{} -} - -type exprParams struct { - ColNames []string - Args []interface{} -} - -func (exprs *exprParams) Len() int { - return len(exprs.ColNames) -} - -func (exprs *exprParams) addParam(colName string, arg interface{}) { - exprs.ColNames = append(exprs.ColNames, colName) - exprs.Args = append(exprs.Args, arg) -} - -func (exprs *exprParams) IsColExist(colName string) bool { - for _, name := range exprs.ColNames { - if strings.EqualFold(schemas.CommonQuoter.Trim(name), schemas.CommonQuoter.Trim(colName)) { - return true - } - } - return false -} - -func (exprs *exprParams) getByName(colName string) (exprParam, bool) { - for i, name := range exprs.ColNames { - if strings.EqualFold(name, colName) { - return exprParam{name, exprs.Args[i]}, true - } - } - return exprParam{}, false -} - -func (exprs *exprParams) WriteArgs(w *builder.BytesWriter) error { - for i, expr := range exprs.Args { - switch arg := expr.(type) { - case *builder.Builder: - if _, err := w.WriteString("("); err != nil { - return err - } - if err := arg.WriteTo(w); err != nil { - return err - } - if _, err := w.WriteString(")"); err != nil { - return err - } - case string: - if arg == "" { - arg = "''" - } - if _, err := w.WriteString(fmt.Sprintf("%v", arg)); err != nil { - return err - } - default: - if _, err := w.WriteString("?"); err != nil { - return err - } - w.Append(arg) - } - if i != len(exprs.Args)-1 { - if _, err := w.WriteString(","); err != nil { - return err - } - } - } - return nil -} - -func (exprs *exprParams) writeNameArgs(w *builder.BytesWriter) error { - for i, colName := range exprs.ColNames { - if _, err := w.WriteString(colName); err != nil { - return err - } - if _, err := w.WriteString("="); err != nil { - return err - } - - switch arg := exprs.Args[i].(type) { - case *builder.Builder: - if _, err := w.WriteString("("); err != nil { - return err - } - if err := arg.WriteTo(w); err != nil { - return err - } - if _, err := w.WriteString("("); err != nil { - return err - } - default: - w.Append(exprs.Args[i]) - } - - if i+1 != len(exprs.ColNames) { - if _, err := w.WriteString(","); err != nil { - return err - } - } - } - return nil -} diff --git a/internal/statements/insert.go b/internal/statements/insert.go index 6cbbbeda..91a33319 100644 --- a/internal/statements/insert.go +++ b/internal/statements/insert.go @@ -5,10 +5,12 @@ package statements import ( + "errors" "fmt" "strings" "xorm.io/builder" + "xorm.io/xorm/internal/utils" "xorm.io/xorm/schemas" ) @@ -17,7 +19,7 @@ func (statement *Statement) writeInsertOutput(buf *strings.Builder, table *schem if _, err := buf.WriteString(" OUTPUT Inserted."); err != nil { return err } - if _, err := buf.WriteString(table.AutoIncrement); err != nil { + if err := statement.dialect.Quoter().QuoteTo(buf, table.AutoIncrement); err != nil { return err } } @@ -41,7 +43,19 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) return "", nil, err } - if len(colNames) <= 0 { + var hasInsertColumns = len(colNames) > 0 + var needSeq = len(table.AutoIncrement) > 0 && (statement.dialect.URI().DBType == schemas.ORACLE || statement.dialect.URI().DBType == schemas.DAMENG) + if needSeq { + for _, col := range colNames { + if strings.EqualFold(col, table.AutoIncrement) { + needSeq = false + break + } + } + } + + if !hasInsertColumns && statement.dialect.URI().DBType != schemas.ORACLE && + statement.dialect.URI().DBType != schemas.DAMENG { if statement.dialect.URI().DBType == schemas.MYSQL { if _, err := buf.WriteString(" VALUES ()"); err != nil { return "", nil, err @@ -59,7 +73,11 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) return "", nil, err } - if err := statement.dialect.Quoter().JoinWrite(buf.Builder, append(colNames, exprs.ColNames...), ","); err != nil { + if needSeq { + colNames = append(colNames, table.AutoIncrement) + } + + if err := statement.dialect.Quoter().JoinWrite(buf.Builder, append(colNames, exprs.ColNames()...), ","); err != nil { return "", nil, err } @@ -79,13 +97,23 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) return "", nil, err } - if len(exprs.Args) > 0 { - if _, err := buf.WriteString(","); err != nil { + if needSeq { + if len(args) > 0 { + if _, err := buf.WriteString(","); err != nil { + return "", nil, err + } + } + if _, err := buf.WriteString(utils.SeqName(tableName) + ".nextval"); err != nil { return "", nil, err } } - if err := exprs.WriteArgs(buf); err != nil { - return "", nil, err + if len(exprs) > 0 { + if _, err := buf.WriteString(","); err != nil { + return "", nil, err + } + if err := exprs.WriteArgs(buf); err != nil { + return "", nil, err + } } if _, err := buf.WriteString(" FROM "); err != nil { @@ -112,7 +140,19 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) return "", nil, err } - if len(exprs.Args) > 0 { + // Insert tablename (id) Values(seq_tablename.nextval) + if needSeq { + if hasInsertColumns { + if _, err := buf.WriteString(","); err != nil { + return "", nil, err + } + } + if _, err := buf.WriteString(utils.SeqName(tableName) + ".nextval"); err != nil { + return "", nil, err + } + } + + if len(exprs) > 0 { if _, err := buf.WriteString(","); err != nil { return "", nil, err } @@ -152,7 +192,7 @@ func (statement *Statement) GenInsertMapSQL(columns []string, args []interface{} return "", nil, err } - if err := statement.dialect.Quoter().JoinWrite(buf.Builder, append(columns, exprs.ColNames...), ","); err != nil { + if err := statement.dialect.Quoter().JoinWrite(buf.Builder, append(columns, exprs.ColNames()...), ","); err != nil { return "", nil, err } @@ -166,7 +206,7 @@ func (statement *Statement) GenInsertMapSQL(columns []string, args []interface{} return "", nil, err } - if len(exprs.Args) > 0 { + if len(exprs) > 0 { if _, err := buf.WriteString(","); err != nil { return "", nil, err } @@ -190,7 +230,7 @@ func (statement *Statement) GenInsertMapSQL(columns []string, args []interface{} return "", nil, err } - if len(exprs.Args) > 0 { + if len(exprs) > 0 { if _, err := buf.WriteString(","); err != nil { return "", nil, err } @@ -205,3 +245,55 @@ func (statement *Statement) GenInsertMapSQL(columns []string, args []interface{} return buf.String(), buf.Args(), nil } + +func (statement *Statement) GenInsertMultipleMapSQL(columns []string, argss [][]interface{}) (string, []interface{}, error) { + var ( + buf = builder.NewWriter() + exprs = statement.ExprColumns + tableName = statement.TableName() + ) + + if _, err := buf.WriteString(fmt.Sprintf("INSERT INTO %s (", statement.quote(tableName))); err != nil { + return "", nil, err + } + + if err := statement.dialect.Quoter().JoinWrite(buf.Builder, append(columns, exprs.ColNames()...), ","); err != nil { + return "", nil, err + } + + // if insert where + if statement.Conds().IsValid() { + return "", nil, errors.New("batch insert don't support with where") + } + + if _, err := buf.WriteString(") VALUES "); err != nil { + return "", nil, err + } + for i, args := range argss { + if _, err := buf.WriteString("("); err != nil { + return "", nil, err + } + if err := statement.WriteArgs(buf, args); err != nil { + return "", nil, err + } + + if len(exprs) > 0 { + if _, err := buf.WriteString(","); err != nil { + return "", nil, err + } + if err := exprs.WriteArgs(buf); err != nil { + return "", nil, err + } + } + if _, err := buf.WriteString(")"); err != nil { + return "", nil, err + } + if i < len(argss)-1 { + if _, err := buf.WriteString(","); err != nil { + return "", nil, err + } + } + } + + return buf.String(), buf.Args(), nil +} diff --git a/internal/statements/join.go b/internal/statements/join.go new file mode 100644 index 00000000..adf349e7 --- /dev/null +++ b/internal/statements/join.go @@ -0,0 +1,96 @@ +// Copyright 2022 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package statements + +import ( + "fmt" + "strings" + + "xorm.io/builder" + "xorm.io/xorm/dialects" + "xorm.io/xorm/internal/utils" + "xorm.io/xorm/schemas" +) + +// Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN +func (statement *Statement) Join(joinOP string, tablename interface{}, condition interface{}, args ...interface{}) *Statement { + var buf strings.Builder + if len(statement.JoinStr) > 0 { + fmt.Fprintf(&buf, "%v %v JOIN ", statement.JoinStr, joinOP) + } else { + fmt.Fprintf(&buf, "%v JOIN ", joinOP) + } + + condStr := "" + condArgs := []interface{}{} + switch condTp := condition.(type) { + case string: + condStr = condTp + case builder.Cond: + var err error + condStr, condArgs, err = builder.ToSQL(condTp) + if err != nil { + statement.LastError = err + return statement + } + default: + statement.LastError = fmt.Errorf("unsupported join condition type: %v", condTp) + return statement + } + + switch tp := tablename.(type) { + case builder.Builder: + subSQL, subQueryArgs, err := tp.ToSQL() + if err != nil { + statement.LastError = err + return statement + } + + fields := strings.Split(tp.TableName(), ".") + aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1]) + aliasName = schemas.CommonQuoter.Trim(aliasName) + + fmt.Fprintf(&buf, "(%s) %s ON %v", statement.ReplaceQuote(subSQL), statement.quote(aliasName), statement.ReplaceQuote(condStr)) + statement.joinArgs = append(append(statement.joinArgs, subQueryArgs...), condArgs...) + case *builder.Builder: + subSQL, subQueryArgs, err := tp.ToSQL() + if err != nil { + statement.LastError = err + return statement + } + + fields := strings.Split(tp.TableName(), ".") + aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1]) + aliasName = schemas.CommonQuoter.Trim(aliasName) + + fmt.Fprintf(&buf, "(%s) %s ON %v", statement.ReplaceQuote(subSQL), statement.quote(aliasName), statement.ReplaceQuote(condStr)) + statement.joinArgs = append(append(statement.joinArgs, subQueryArgs...), condArgs...) + default: + tbName := dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), tablename, true) + if !utils.IsSubQuery(tbName) { + var buf strings.Builder + _ = statement.dialect.Quoter().QuoteTo(&buf, tbName) + tbName = buf.String() + } else { + tbName = statement.ReplaceQuote(tbName) + } + fmt.Fprintf(&buf, "%s ON %v", tbName, statement.ReplaceQuote(condStr)) + statement.joinArgs = append(statement.joinArgs, condArgs...) + } + + statement.JoinStr = buf.String() + statement.joinArgs = append(statement.joinArgs, args...) + return statement +} + +func (statement *Statement) writeJoin(w builder.Writer) error { + if statement.JoinStr != "" { + if _, err := fmt.Fprint(w, " ", statement.JoinStr); err != nil { + return err + } + w.Append(statement.joinArgs...) + } + return nil +} diff --git a/internal/statements/order_by.go b/internal/statements/order_by.go new file mode 100644 index 00000000..08a8263b --- /dev/null +++ b/internal/statements/order_by.go @@ -0,0 +1,90 @@ +// Copyright 2022 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package statements + +import ( + "fmt" + "strings" + + "xorm.io/builder" +) + +func (statement *Statement) HasOrderBy() bool { + return statement.orderStr != "" +} + +// ResetOrderBy reset ordery conditions +func (statement *Statement) ResetOrderBy() { + statement.orderStr = "" + statement.orderArgs = nil +} + +// WriteOrderBy write order by to writer +func (statement *Statement) WriteOrderBy(w builder.Writer) error { + if len(statement.orderStr) > 0 { + if _, err := fmt.Fprintf(w, " ORDER BY %s", statement.orderStr); err != nil { + return err + } + w.Append(statement.orderArgs...) + } + return nil +} + +// OrderBy generate "Order By order" statement +func (statement *Statement) OrderBy(order interface{}, args ...interface{}) *Statement { + if len(statement.orderStr) > 0 { + statement.orderStr += ", " + } + var rawOrder string + switch t := order.(type) { + case (*builder.Expression): + rawOrder = t.Content() + args = t.Args() + case string: + rawOrder = t + default: + statement.LastError = ErrUnSupportedSQLType + return statement + } + statement.orderStr += statement.ReplaceQuote(rawOrder) + if len(args) > 0 { + statement.orderArgs = append(statement.orderArgs, args...) + } + return statement +} + +// Desc generate `ORDER BY xx DESC` +func (statement *Statement) Desc(colNames ...string) *Statement { + var buf strings.Builder + if len(statement.orderStr) > 0 { + fmt.Fprint(&buf, statement.orderStr, ", ") + } + for i, col := range colNames { + if i > 0 { + fmt.Fprint(&buf, ", ") + } + _ = statement.dialect.Quoter().QuoteTo(&buf, col) + fmt.Fprint(&buf, " DESC") + } + statement.orderStr = buf.String() + return statement +} + +// Asc provide asc order by query condition, the input parameters are columns. +func (statement *Statement) Asc(colNames ...string) *Statement { + var buf strings.Builder + if len(statement.orderStr) > 0 { + fmt.Fprint(&buf, statement.orderStr, ", ") + } + for i, col := range colNames { + if i > 0 { + fmt.Fprint(&buf, ", ") + } + _ = statement.dialect.Quoter().QuoteTo(&buf, col) + fmt.Fprint(&buf, " ASC") + } + statement.orderStr = buf.String() + return statement +} diff --git a/internal/statements/query.go b/internal/statements/query.go index ab3021bf..f72c8602 100644 --- a/internal/statements/query.go +++ b/internal/statements/query.go @@ -11,9 +11,11 @@ import ( "strings" "xorm.io/builder" + "xorm.io/xorm/internal/utils" "xorm.io/xorm/schemas" ) +// GenQuerySQL generate query SQL func (statement *Statement) GenQuerySQL(sqlOrArgs ...interface{}) (string, []interface{}, error) { if len(sqlOrArgs) > 0 { return statement.ConvertSQLOrArgs(sqlOrArgs...) @@ -27,7 +29,7 @@ func (statement *Statement) GenQuerySQL(sqlOrArgs ...interface{}) (string, []int return "", nil, ErrTableNotFound } - var columnStr = statement.ColumnStr() + columnStr := statement.ColumnStr() if len(statement.SelectStr) > 0 { columnStr = statement.SelectStr } else { @@ -57,29 +59,20 @@ func (statement *Statement) GenQuerySQL(sqlOrArgs ...interface{}) (string, []int return "", nil, err } - sqlStr, condArgs, err := statement.genSelectSQL(columnStr, true, true) - if err != nil { - return "", nil, err - } - args := append(statement.joinArgs, condArgs...) - - // for mssql and use limit - qs := strings.Count(sqlStr, "?") - if len(args)*2 == qs { - args = append(args, args...) - } - - return sqlStr, args, nil + return statement.genSelectSQL(columnStr, true, true) } +// GenSumSQL generates sum SQL func (statement *Statement) GenSumSQL(bean interface{}, columns ...string) (string, []interface{}, error) { if statement.RawSQL != "" { return statement.GenRawSQL(), statement.RawParams, nil } - statement.SetRefBean(bean) + if err := statement.SetRefBean(bean); err != nil { + return "", nil, err + } - var sumStrs = make([]string, 0, len(columns)) + sumStrs := make([]string, 0, len(columns)) for _, colName := range columns { if !strings.Contains(colName, " ") && !strings.Contains(colName, "(") { colName = statement.quote(colName) @@ -90,26 +83,27 @@ func (statement *Statement) GenSumSQL(bean interface{}, columns ...string) (stri } sumSelect := strings.Join(sumStrs, ", ") - if err := statement.mergeConds(bean); err != nil { + if err := statement.MergeConds(bean); err != nil { return "", nil, err } - sqlStr, condArgs, err := statement.genSelectSQL(sumSelect, true, true) - if err != nil { - return "", nil, err - } - - return sqlStr, append(statement.joinArgs, condArgs...), nil + return statement.genSelectSQL(sumSelect, true, true) } +// GenGetSQL generates Get SQL func (statement *Statement) GenGetSQL(bean interface{}) (string, []interface{}, error) { - v := rValue(bean) - isStruct := v.Kind() == reflect.Struct - if isStruct { - statement.SetRefBean(bean) + var isStruct bool + if bean != nil { + v := rValue(bean) + isStruct = v.Kind() == reflect.Struct + if isStruct { + if err := statement.SetRefBean(bean); err != nil { + return "", nil, err + } + } } - var columnStr = statement.ColumnStr() + columnStr := statement.ColumnStr() if len(statement.SelectStr) > 0 { columnStr = statement.SelectStr } else { @@ -136,7 +130,7 @@ func (statement *Statement) GenGetSQL(bean interface{}) (string, []interface{}, } if isStruct { - if err := statement.mergeConds(bean); err != nil { + if err := statement.MergeConds(bean); err != nil { return "", nil, err } } else { @@ -145,12 +139,7 @@ func (statement *Statement) GenGetSQL(bean interface{}) (string, []interface{}, } } - sqlStr, condArgs, err := statement.genSelectSQL(columnStr, true, true) - if err != nil { - return "", nil, err - } - - return sqlStr, append(statement.joinArgs, condArgs...), nil + return statement.genSelectSQL(columnStr, true, true) } // GenCountSQL generates the SQL for counting @@ -162,13 +151,15 @@ func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interfa var condArgs []interface{} var err error if len(beans) > 0 { - statement.SetRefBean(beans[0]) - if err := statement.mergeConds(beans[0]); err != nil { + if err := statement.SetRefBean(beans[0]); err != nil { + return "", nil, err + } + if err := statement.MergeConds(beans[0]); err != nil { return "", nil, err } } - var selectSQL = statement.SelectStr + selectSQL := statement.SelectStr if len(selectSQL) <= 0 { if statement.IsDistinct { selectSQL = fmt.Sprintf("count(DISTINCT %s)", statement.ColumnStr()) @@ -178,49 +169,74 @@ func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interfa selectSQL = "count(*)" } } - sqlStr, condArgs, err := statement.genSelectSQL(selectSQL, false, false) + var subQuerySelect string + if statement.GroupByStr != "" { + subQuerySelect = statement.GroupByStr + } else { + subQuerySelect = selectSQL + } + + sqlStr, condArgs, err := statement.genSelectSQL(subQuerySelect, false, false) if err != nil { return "", nil, err } - return sqlStr, append(statement.joinArgs, condArgs...), nil + if statement.GroupByStr != "" { + sqlStr = fmt.Sprintf("SELECT %s FROM (%s) sub", selectSQL, sqlStr) + } + + return sqlStr, condArgs, nil +} + +func (statement *Statement) writeFrom(w builder.Writer) error { + if _, err := fmt.Fprint(w, " FROM "); err != nil { + return err + } + if err := statement.writeTableName(w); err != nil { + return err + } + if err := statement.writeAlias(w); err != nil { + return err + } + return statement.writeJoin(w) +} + +func (statement *Statement) writeLimitOffset(w builder.Writer) error { + if statement.Start > 0 { + if statement.LimitN != nil { + _, err := fmt.Fprintf(w, " LIMIT %v OFFSET %v", *statement.LimitN, statement.Start) + return err + } + _, err := fmt.Fprintf(w, " LIMIT 0 OFFSET %v", statement.Start) + return err + } + if statement.LimitN != nil { + _, err := fmt.Fprint(w, " LIMIT ", *statement.LimitN) + return err + } + // no limit statement + return nil } func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderBy bool) (string, []interface{}, error) { var ( - distinct string - dialect = statement.dialect - quote = statement.quote - fromStr = " FROM " - top, mssqlCondi, whereStr string + distinct string + dialect = statement.dialect + top, whereStr string + mssqlCondi = builder.NewWriter() ) + if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") { distinct = "DISTINCT " } - condSQL, condArgs, err := statement.GenCondSQL(statement.cond) - if err != nil { + condWriter := builder.NewWriter() + if err := statement.cond.WriteTo(statement.QuoteReplacer(condWriter)); err != nil { return "", nil, err } - if len(condSQL) > 0 { - whereStr = " WHERE " + condSQL - } - if dialect.URI().DBType == schemas.MSSQL && strings.Contains(statement.TableName(), "..") { - fromStr += statement.TableName() - } else { - fromStr += quote(statement.TableName()) - } - - if statement.TableAlias != "" { - if dialect.URI().DBType == schemas.ORACLE { - fromStr += " " + quote(statement.TableAlias) - } else { - fromStr += " AS " + quote(statement.TableAlias) - } - } - if statement.JoinStr != "" { - fromStr = fmt.Sprintf("%v %v", fromStr, statement.JoinStr) + if condWriter.Len() > 0 { + whereStr = " WHERE " } pLimitN := statement.LimitN @@ -230,6 +246,9 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB top = fmt.Sprintf("TOP %d ", LimitNValue) } if statement.Start > 0 { + if statement.RefTable == nil { + return "", nil, errors.New("Unsupported query limit without reference table") + } var column string if len(statement.RefTable.PKColumns()) == 0 { for _, index := range statement.RefTable.Indexes { @@ -246,121 +265,117 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB } if statement.needTableName() { if len(statement.TableAlias) > 0 { - column = statement.TableAlias + "." + column + column = fmt.Sprintf("%s.%s", statement.TableAlias, column) } else { - column = statement.TableName() + "." + column + column = fmt.Sprintf("%s.%s", statement.TableName(), column) } } - var orderStr string - if needOrderBy && len(statement.OrderStr) > 0 { - orderStr = " ORDER BY " + statement.OrderStr + if _, err := fmt.Fprintf(mssqlCondi, "(%s NOT IN (SELECT TOP %d %s", + column, statement.Start, column); err != nil { + return "", nil, err } - - var groupStr string - if len(statement.GroupByStr) > 0 { - groupStr = " GROUP BY " + statement.GroupByStr + if err := statement.writeFrom(mssqlCondi); err != nil { + return "", nil, err + } + if whereStr != "" { + if _, err := fmt.Fprint(mssqlCondi, whereStr); err != nil { + return "", nil, err + } + if err := utils.WriteBuilder(mssqlCondi, statement.QuoteReplacer(condWriter)); err != nil { + return "", nil, err + } + } + if needOrderBy { + if err := statement.WriteOrderBy(mssqlCondi); err != nil { + return "", nil, err + } + } + if err := statement.WriteGroupBy(mssqlCondi); err != nil { + return "", nil, err + } + if _, err := fmt.Fprint(mssqlCondi, "))"); err != nil { + return "", nil, err } - mssqlCondi = fmt.Sprintf("(%s NOT IN (SELECT TOP %d %s%s%s%s%s))", - column, statement.Start, column, fromStr, whereStr, orderStr, groupStr) } } - var buf strings.Builder - fmt.Fprintf(&buf, "SELECT %v%v%v%v%v", distinct, top, columnStr, fromStr, whereStr) - if len(mssqlCondi) > 0 { + buf := builder.NewWriter() + if _, err := fmt.Fprintf(buf, "SELECT %v%v%v", distinct, top, columnStr); err != nil { + return "", nil, err + } + if err := statement.writeFrom(buf); err != nil { + return "", nil, err + } + if whereStr != "" { + if _, err := fmt.Fprint(buf, whereStr); err != nil { + return "", nil, err + } + if err := utils.WriteBuilder(buf, statement.QuoteReplacer(condWriter)); err != nil { + return "", nil, err + } + } + if mssqlCondi.Len() > 0 { if len(whereStr) > 0 { - fmt.Fprint(&buf, " AND ", mssqlCondi) + if _, err := fmt.Fprint(buf, " AND "); err != nil { + return "", nil, err + } } else { - fmt.Fprint(&buf, " WHERE ", mssqlCondi) + if _, err := fmt.Fprint(buf, " WHERE "); err != nil { + return "", nil, err + } + } + + if err := utils.WriteBuilder(buf, mssqlCondi); err != nil { + return "", nil, err } } - if statement.GroupByStr != "" { - fmt.Fprint(&buf, " GROUP BY ", statement.GroupByStr) + if err := statement.WriteGroupBy(buf); err != nil { + return "", nil, err } - if statement.HavingStr != "" { - fmt.Fprint(&buf, " ", statement.HavingStr) + if err := statement.writeHaving(buf); err != nil { + return "", nil, err } - if needOrderBy && statement.OrderStr != "" { - fmt.Fprint(&buf, " ORDER BY ", statement.OrderStr) + if needOrderBy { + if err := statement.WriteOrderBy(buf); err != nil { + return "", nil, err + } } if needLimit { if dialect.URI().DBType != schemas.MSSQL && dialect.URI().DBType != schemas.ORACLE { - if statement.Start > 0 { - if pLimitN != nil { - fmt.Fprintf(&buf, " LIMIT %v OFFSET %v", *pLimitN, statement.Start) - } else { - fmt.Fprintf(&buf, "LIMIT 0 OFFSET %v", statement.Start) - } - } else if pLimitN != nil { - fmt.Fprint(&buf, " LIMIT ", *pLimitN) + if err := statement.writeLimitOffset(buf); err != nil { + return "", nil, err } } else if dialect.URI().DBType == schemas.ORACLE { - if statement.Start != 0 || pLimitN != nil { + if pLimitN != nil { oldString := buf.String() buf.Reset() rawColStr := columnStr if rawColStr == "*" { rawColStr = "at.*" } - fmt.Fprintf(&buf, "SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", + fmt.Fprintf(buf, "SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", columnStr, rawColStr, oldString, statement.Start+*pLimitN, statement.Start) } } } if statement.IsForUpdate { - return dialect.ForUpdateSQL(buf.String()), condArgs, nil + return dialect.ForUpdateSQL(buf.String()), buf.Args(), nil } - return buf.String(), condArgs, nil + return buf.String(), buf.Args(), nil } +// GenExistSQL generates Exist SQL func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interface{}, error) { if statement.RawSQL != "" { return statement.GenRawSQL(), statement.RawParams, nil } - var sqlStr string - var args []interface{} - var joinStr string - var err error - if len(bean) == 0 { - tableName := statement.TableName() - if len(tableName) <= 0 { - return "", nil, ErrTableNotFound - } - - tableName = statement.quote(tableName) - if len(statement.JoinStr) > 0 { - joinStr = statement.JoinStr - } - - if statement.Conds().IsValid() { - condSQL, condArgs, err := statement.GenCondSQL(statement.Conds()) - if err != nil { - return "", nil, err - } - - if statement.dialect.URI().DBType == schemas.MSSQL { - sqlStr = fmt.Sprintf("SELECT TOP 1 * FROM %s %s WHERE %s", tableName, joinStr, condSQL) - } else if statement.dialect.URI().DBType == schemas.ORACLE { - sqlStr = fmt.Sprintf("SELECT * FROM %s WHERE (%s) %s AND ROWNUM=1", tableName, joinStr, condSQL) - } else { - sqlStr = fmt.Sprintf("SELECT * FROM %s %s WHERE %s LIMIT 1", tableName, joinStr, condSQL) - } - args = condArgs - } else { - if statement.dialect.URI().DBType == schemas.MSSQL { - sqlStr = fmt.Sprintf("SELECT TOP 1 * FROM %s %s", tableName, joinStr) - } else if statement.dialect.URI().DBType == schemas.ORACLE { - sqlStr = fmt.Sprintf("SELECT * FROM %s %s WHERE ROWNUM=1", tableName, joinStr) - } else { - sqlStr = fmt.Sprintf("SELECT * FROM %s %s LIMIT 1", tableName, joinStr) - } - args = []interface{}{} - } - } else { + var b interface{} + if len(bean) > 0 { + b = bean[0] beanValue := reflect.ValueOf(bean[0]) if beanValue.Kind() != reflect.Ptr { return "", nil, errors.New("needs a pointer") @@ -371,34 +386,88 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac return "", nil, err } } + } + tableName := statement.TableName() + if len(tableName) <= 0 { + return "", nil, ErrTableNotFound + } + if statement.RefTable != nil { + return statement.Limit(1).GenGetSQL(b) + } - if len(statement.TableName()) <= 0 { - return "", nil, ErrTableNotFound + tableName = statement.quote(tableName) + + buf := builder.NewWriter() + if statement.dialect.URI().DBType == schemas.MSSQL { + if _, err := fmt.Fprintf(buf, "SELECT TOP 1 * FROM %s", tableName); err != nil { + return "", nil, err } - statement.Limit(1) - sqlStr, args, err = statement.GenGetSQL(bean[0]) - if err != nil { + if err := statement.writeJoin(buf); err != nil { + return "", nil, err + } + if statement.Conds().IsValid() { + if _, err := fmt.Fprintf(buf, " WHERE "); err != nil { + return "", nil, err + } + if err := statement.Conds().WriteTo(statement.QuoteReplacer(buf)); err != nil { + return "", nil, err + } + } + } else if statement.dialect.URI().DBType == schemas.ORACLE { + if _, err := fmt.Fprintf(buf, "SELECT * FROM %s", tableName); err != nil { + return "", nil, err + } + if err := statement.writeJoin(buf); err != nil { + return "", nil, err + } + if _, err := fmt.Fprintf(buf, " WHERE "); err != nil { + return "", nil, err + } + if statement.Conds().IsValid() { + if err := statement.Conds().WriteTo(statement.QuoteReplacer(buf)); err != nil { + return "", nil, err + } + if _, err := fmt.Fprintf(buf, " AND "); err != nil { + return "", nil, err + } + } + if _, err := fmt.Fprintf(buf, "ROWNUM=1"); err != nil { + return "", nil, err + } + } else { + if _, err := fmt.Fprintf(buf, "SELECT 1 FROM %s", tableName); err != nil { + return "", nil, err + } + if err := statement.writeJoin(buf); err != nil { + return "", nil, err + } + if statement.Conds().IsValid() { + if _, err := fmt.Fprintf(buf, " WHERE "); err != nil { + return "", nil, err + } + if err := statement.Conds().WriteTo(statement.QuoteReplacer(buf)); err != nil { + return "", nil, err + } + } + if _, err := fmt.Fprintf(buf, " LIMIT 1"); err != nil { return "", nil, err } } - return sqlStr, args, nil + return buf.String(), buf.Args(), nil } +// GenFindSQL generates Find SQL func (statement *Statement) GenFindSQL(autoCond builder.Cond) (string, []interface{}, error) { if statement.RawSQL != "" { return statement.GenRawSQL(), statement.RawParams, nil } - var sqlStr string - var args []interface{} - var err error - if len(statement.TableName()) <= 0 { return "", nil, ErrTableNotFound } - var columnStr = statement.ColumnStr() + columnStr := statement.ColumnStr() if len(statement.SelectStr) > 0 { columnStr = statement.SelectStr } else { @@ -426,16 +495,5 @@ func (statement *Statement) GenFindSQL(autoCond builder.Cond) (string, []interfa statement.cond = statement.cond.And(autoCond) - sqlStr, condArgs, err := statement.genSelectSQL(columnStr, true, true) - if err != nil { - return "", nil, err - } - args = append(statement.joinArgs, condArgs...) - // for mssql and use limit - qs := strings.Count(sqlStr, "?") - if len(args)*2 == qs { - args = append(args, args...) - } - - return sqlStr, args, nil + return statement.genSelectSQL(columnStr, true, true) } diff --git a/internal/statements/select.go b/internal/statements/select.go new file mode 100644 index 00000000..2bd2e94d --- /dev/null +++ b/internal/statements/select.go @@ -0,0 +1,137 @@ +// Copyright 2022 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package statements + +import ( + "fmt" + "strings" + + "xorm.io/xorm/schemas" +) + +// Select replace select +func (statement *Statement) Select(str string) *Statement { + statement.SelectStr = statement.ReplaceQuote(str) + return statement +} + +func col2NewCols(columns ...string) []string { + newColumns := make([]string, 0, len(columns)) + for _, col := range columns { + col = strings.Replace(col, "`", "", -1) + col = strings.Replace(col, `"`, "", -1) + ccols := strings.Split(col, ",") + for _, c := range ccols { + newColumns = append(newColumns, strings.TrimSpace(c)) + } + } + return newColumns +} + +// Cols generate "col1, col2" statement +func (statement *Statement) Cols(columns ...string) *Statement { + cols := col2NewCols(columns...) + for _, nc := range cols { + statement.ColumnMap.Add(nc) + } + return statement +} + +// ColumnStr returns column string +func (statement *Statement) ColumnStr() string { + return statement.dialect.Quoter().Join(statement.ColumnMap, ", ") +} + +// AllCols update use only: update all columns +func (statement *Statement) AllCols() *Statement { + statement.useAllCols = true + return statement +} + +// MustCols update use only: must update columns +func (statement *Statement) MustCols(columns ...string) *Statement { + newColumns := col2NewCols(columns...) + for _, nc := range newColumns { + statement.MustColumnMap[strings.ToLower(nc)] = true + } + return statement +} + +// UseBool indicates that use bool fields as update contents and query contiditions +func (statement *Statement) UseBool(columns ...string) *Statement { + if len(columns) > 0 { + statement.MustCols(columns...) + } else { + statement.allUseBool = true + } + return statement +} + +// Omit do not use the columns +func (statement *Statement) Omit(columns ...string) { + newColumns := col2NewCols(columns...) + for _, nc := range newColumns { + statement.OmitColumnMap = append(statement.OmitColumnMap, nc) + } +} + +func (statement *Statement) genColumnStr() string { + if statement.RefTable == nil { + return "" + } + + var buf strings.Builder + columns := statement.RefTable.Columns() + + for _, col := range columns { + if statement.OmitColumnMap.Contain(col.Name) { + continue + } + + if len(statement.ColumnMap) > 0 && !statement.ColumnMap.Contain(col.Name) { + continue + } + + if col.MapType == schemas.ONLYTODB { + continue + } + + if buf.Len() != 0 { + buf.WriteString(", ") + } + + if statement.JoinStr != "" { + if statement.TableAlias != "" { + buf.WriteString(statement.TableAlias) + } else { + buf.WriteString(statement.TableName()) + } + + buf.WriteString(".") + } + + statement.dialect.Quoter().QuoteTo(&buf, col.Name) + } + + return buf.String() +} + +func (statement *Statement) colName(col *schemas.Column, tableName string) string { + if statement.needTableName() { + nm := tableName + if len(statement.TableAlias) > 0 { + nm = statement.TableAlias + } + return fmt.Sprintf("%s.%s", statement.quote(nm), statement.quote(col.Name)) + } + return statement.quote(col.Name) +} + +// Distinct generates "DISTINCT col1, col2 " statement +func (statement *Statement) Distinct(columns ...string) *Statement { + statement.IsDistinct = true + statement.Cols(columns...) + return statement +} diff --git a/internal/statements/statement.go b/internal/statements/statement.go index a4294bec..a8fe34fa 100644 --- a/internal/statements/statement.go +++ b/internal/statements/statement.go @@ -8,6 +8,7 @@ import ( "database/sql/driver" "errors" "fmt" + "math/big" "reflect" "strings" "time" @@ -42,7 +43,8 @@ type Statement struct { Start int LimitN *int idParam schemas.PK - OrderStr string + orderStr string + orderArgs []interface{} JoinStr string joinArgs []interface{} GroupByStr string @@ -90,27 +92,17 @@ func NewStatement(dialect dialects.Dialect, tagParser *tags.Parser, defaultTimeZ return statement } +// SetTableName set table name func (statement *Statement) SetTableName(tableName string) { statement.tableName = tableName } -func (statement *Statement) omitStr() string { - return statement.dialect.Quoter().Join(statement.OmitColumnMap, " ,") -} - // GenRawSQL generates correct raw sql func (statement *Statement) GenRawSQL() string { return statement.ReplaceQuote(statement.RawSQL) } -func (statement *Statement) GenCondSQL(condOrBuilder interface{}) (string, []interface{}, error) { - condSQL, condArgs, err := builder.ToSQL(condOrBuilder) - if err != nil { - return "", nil, err - } - return statement.ReplaceQuote(condSQL), condArgs, nil -} - +// ReplaceQuote replace sql key words with quote func (statement *Statement) ReplaceQuote(sql string) string { if sql == "" || statement.dialect.URI().DBType == schemas.MYSQL || statement.dialect.URI().DBType == schemas.SQLITE { @@ -119,16 +111,17 @@ func (statement *Statement) ReplaceQuote(sql string) string { return statement.dialect.Quoter().Replace(sql) } +// SetContextCache sets context cache func (statement *Statement) SetContextCache(ctxCache contexts.ContextCache) { statement.Context = ctxCache } -// Init reset all the statement's fields +// Reset reset all the statement's fields func (statement *Statement) Reset() { statement.RefTable = nil statement.Start = 0 statement.LimitN = nil - statement.OrderStr = "" + statement.ResetOrderBy() statement.UseCascade = true statement.JoinStr = "" statement.joinArgs = make([]interface{}, 0) @@ -163,21 +156,6 @@ func (statement *Statement) Reset() { statement.LastError = nil } -// NoAutoCondition if you do not want convert bean's field as query condition, then use this function -func (statement *Statement) SetNoAutoCondition(no ...bool) *Statement { - statement.NoAutoCondition = true - if len(no) > 0 { - statement.NoAutoCondition = no[0] - } - return statement -} - -// Alias set the table alias -func (statement *Statement) Alias(alias string) *Statement { - statement.TableAlias = alias - return statement -} - // SQL adds raw sql statement func (statement *Statement) SQL(query interface{}, args ...interface{}) *Statement { switch query.(type) { @@ -197,80 +175,11 @@ func (statement *Statement) SQL(query interface{}, args ...interface{}) *Stateme return statement } -// Where add Where statement -func (statement *Statement) Where(query interface{}, args ...interface{}) *Statement { - return statement.And(query, args...) -} - func (statement *Statement) quote(s string) string { return statement.dialect.Quoter().Quote(s) } -// And add Where & and statement -func (statement *Statement) And(query interface{}, args ...interface{}) *Statement { - switch query.(type) { - case string: - cond := builder.Expr(query.(string), args...) - statement.cond = statement.cond.And(cond) - case map[string]interface{}: - queryMap := query.(map[string]interface{}) - newMap := make(map[string]interface{}) - for k, v := range queryMap { - newMap[statement.quote(k)] = v - } - statement.cond = statement.cond.And(builder.Eq(newMap)) - case builder.Cond: - cond := query.(builder.Cond) - statement.cond = statement.cond.And(cond) - for _, v := range args { - if vv, ok := v.(builder.Cond); ok { - statement.cond = statement.cond.And(vv) - } - } - default: - statement.LastError = ErrConditionType - } - - return statement -} - -// Or add Where & Or statement -func (statement *Statement) Or(query interface{}, args ...interface{}) *Statement { - switch query.(type) { - case string: - cond := builder.Expr(query.(string), args...) - statement.cond = statement.cond.Or(cond) - case map[string]interface{}: - cond := builder.Eq(query.(map[string]interface{})) - statement.cond = statement.cond.Or(cond) - case builder.Cond: - cond := query.(builder.Cond) - statement.cond = statement.cond.Or(cond) - for _, v := range args { - if vv, ok := v.(builder.Cond); ok { - statement.cond = statement.cond.Or(vv) - } - } - default: - // TODO: not support condition type - } - return statement -} - -// In generate "Where column IN (?) " statement -func (statement *Statement) In(column string, args ...interface{}) *Statement { - in := builder.In(statement.quote(column), args...) - statement.cond = statement.cond.And(in) - return statement -} - -// NotIn generate "Where column NOT IN (?) " statement -func (statement *Statement) NotIn(column string, args ...interface{}) *Statement { - notIn := builder.NotIn(statement.quote(column), args...) - statement.cond = statement.cond.And(notIn) - return statement -} - +// SetRefValue set ref value func (statement *Statement) SetRefValue(v reflect.Value) error { var err error statement.RefTable, err = statement.tagParser.ParseWithCache(reflect.Indirect(v)) @@ -285,6 +194,7 @@ func rValue(bean interface{}) reflect.Value { return reflect.Indirect(reflect.ValueOf(bean)) } +// SetRefBean set ref bean func (statement *Statement) SetRefBean(bean interface{}) error { var err error statement.RefTable, err = statement.tagParser.ParseWithCache(rValue(bean)) @@ -299,32 +209,12 @@ func (statement *Statement) needTableName() bool { return len(statement.JoinStr) > 0 } -func (statement *Statement) colName(col *schemas.Column, tableName string) string { - if statement.needTableName() { - var nm = tableName - if len(statement.TableAlias) > 0 { - nm = statement.TableAlias - } - return statement.quote(nm) + "." + statement.quote(col.Name) - } - return statement.quote(col.Name) -} - -// TableName return current tableName -func (statement *Statement) TableName() string { - if statement.AltTableName != "" { - return statement.AltTableName - } - - return statement.tableName -} - // Incr Generate "Update ... Set column = column + arg" statement func (statement *Statement) Incr(column string, arg ...interface{}) *Statement { if len(arg) > 0 { - statement.IncrColumns.addParam(column, arg[0]) + statement.IncrColumns.Add(column, arg[0]) } else { - statement.IncrColumns.addParam(column, 1) + statement.IncrColumns.Add(column, 1) } return statement } @@ -332,9 +222,9 @@ func (statement *Statement) Incr(column string, arg ...interface{}) *Statement { // Decr Generate "Update ... Set column = column - arg" statement func (statement *Statement) Decr(column string, arg ...interface{}) *Statement { if len(arg) > 0 { - statement.DecrColumns.addParam(column, arg[0]) + statement.DecrColumns.Add(column, arg[0]) } else { - statement.DecrColumns.addParam(column, 1) + statement.DecrColumns.Add(column, 1) } return statement } @@ -342,91 +232,19 @@ func (statement *Statement) Decr(column string, arg ...interface{}) *Statement { // SetExpr Generate "Update ... Set column = {expression}" statement func (statement *Statement) SetExpr(column string, expression interface{}) *Statement { if e, ok := expression.(string); ok { - statement.ExprColumns.addParam(column, statement.dialect.Quoter().Replace(e)) + statement.ExprColumns.Add(column, statement.dialect.Quoter().Replace(e)) } else { - statement.ExprColumns.addParam(column, expression) + statement.ExprColumns.Add(column, expression) } return statement } -// Distinct generates "DISTINCT col1, col2 " statement -func (statement *Statement) Distinct(columns ...string) *Statement { - statement.IsDistinct = true - statement.Cols(columns...) - return statement -} - // ForUpdate generates "SELECT ... FOR UPDATE" statement func (statement *Statement) ForUpdate() *Statement { statement.IsForUpdate = true return statement } -// Select replace select -func (statement *Statement) Select(str string) *Statement { - statement.SelectStr = statement.ReplaceQuote(str) - return statement -} - -func col2NewCols(columns ...string) []string { - newColumns := make([]string, 0, len(columns)) - for _, col := range columns { - col = strings.Replace(col, "`", "", -1) - col = strings.Replace(col, `"`, "", -1) - ccols := strings.Split(col, ",") - for _, c := range ccols { - newColumns = append(newColumns, strings.TrimSpace(c)) - } - } - return newColumns -} - -// Cols generate "col1, col2" statement -func (statement *Statement) Cols(columns ...string) *Statement { - cols := col2NewCols(columns...) - for _, nc := range cols { - statement.ColumnMap.Add(nc) - } - return statement -} - -func (statement *Statement) ColumnStr() string { - return statement.dialect.Quoter().Join(statement.ColumnMap, ", ") -} - -// AllCols update use only: update all columns -func (statement *Statement) AllCols() *Statement { - statement.useAllCols = true - return statement -} - -// MustCols update use only: must update columns -func (statement *Statement) MustCols(columns ...string) *Statement { - newColumns := col2NewCols(columns...) - for _, nc := range newColumns { - statement.MustColumnMap[strings.ToLower(nc)] = true - } - return statement -} - -// UseBool indicates that use bool fields as update contents and query contiditions -func (statement *Statement) UseBool(columns ...string) *Statement { - if len(columns) > 0 { - statement.MustCols(columns...) - } else { - statement.allUseBool = true - } - return statement -} - -// Omit do not use the columns -func (statement *Statement) Omit(columns ...string) { - newColumns := col2NewCols(columns...) - for _, nc := range newColumns { - statement.OmitColumnMap = append(statement.OmitColumnMap, nc) - } -} - // Nullable Update use only: update columns to null when value is nullable and zero-value func (statement *Statement) Nullable(columns ...string) { newColumns := col2NewCols(columns...) @@ -450,54 +268,7 @@ func (statement *Statement) Limit(limit int, start ...int) *Statement { return statement } -// OrderBy generate "Order By order" statement -func (statement *Statement) OrderBy(order string) *Statement { - if len(statement.OrderStr) > 0 { - statement.OrderStr += ", " - } - statement.OrderStr += statement.ReplaceQuote(order) - return statement -} - -// Desc generate `ORDER BY xx DESC` -func (statement *Statement) Desc(colNames ...string) *Statement { - var buf strings.Builder - if len(statement.OrderStr) > 0 { - fmt.Fprint(&buf, statement.OrderStr, ", ") - } - for i, col := range colNames { - if i > 0 { - fmt.Fprint(&buf, ", ") - } - statement.dialect.Quoter().QuoteTo(&buf, col) - fmt.Fprint(&buf, " DESC") - } - statement.OrderStr = buf.String() - return statement -} - -// Asc provide asc order by query condition, the input parameters are columns. -func (statement *Statement) Asc(colNames ...string) *Statement { - var buf strings.Builder - if len(statement.OrderStr) > 0 { - fmt.Fprint(&buf, statement.OrderStr, ", ") - } - for i, col := range colNames { - if i > 0 { - fmt.Fprint(&buf, ", ") - } - statement.dialect.Quoter().QuoteTo(&buf, col) - fmt.Fprint(&buf, " ASC") - } - statement.OrderStr = buf.String() - return statement -} - -func (statement *Statement) Conds() builder.Cond { - return statement.cond -} - -// Table tempororily set table name, the parameter could be a string or a pointer of struct +// SetTable tempororily set table name, the parameter could be a string or a pointer of struct func (statement *Statement) SetTable(tableNameOrBean interface{}) error { v := rValue(tableNameOrBean) t := v.Type() @@ -513,136 +284,46 @@ func (statement *Statement) SetTable(tableNameOrBean interface{}) error { return nil } -// Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN -func (statement *Statement) Join(joinOP string, tablename interface{}, condition string, args ...interface{}) *Statement { - var buf strings.Builder - if len(statement.JoinStr) > 0 { - fmt.Fprintf(&buf, "%v %v JOIN ", statement.JoinStr, joinOP) - } else { - fmt.Fprintf(&buf, "%v JOIN ", joinOP) - } - - switch tp := tablename.(type) { - case builder.Builder: - subSQL, subQueryArgs, err := tp.ToSQL() - if err != nil { - statement.LastError = err - return statement - } - - fields := strings.Split(tp.TableName(), ".") - aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1]) - aliasName = schemas.CommonQuoter.Trim(aliasName) - - fmt.Fprintf(&buf, "(%s) %s ON %v", statement.ReplaceQuote(subSQL), aliasName, statement.ReplaceQuote(condition)) - statement.joinArgs = append(statement.joinArgs, subQueryArgs...) - case *builder.Builder: - subSQL, subQueryArgs, err := tp.ToSQL() - if err != nil { - statement.LastError = err - return statement - } - - fields := strings.Split(tp.TableName(), ".") - aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1]) - aliasName = schemas.CommonQuoter.Trim(aliasName) - - fmt.Fprintf(&buf, "(%s) %s ON %v", statement.ReplaceQuote(subSQL), aliasName, statement.ReplaceQuote(condition)) - statement.joinArgs = append(statement.joinArgs, subQueryArgs...) - default: - tbName := dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), tablename, true) - if !utils.IsSubQuery(tbName) { - var buf strings.Builder - statement.dialect.Quoter().QuoteTo(&buf, tbName) - tbName = buf.String() - } - fmt.Fprintf(&buf, "%s ON %v", tbName, statement.ReplaceQuote(condition)) - } - - statement.JoinStr = buf.String() - statement.joinArgs = append(statement.joinArgs, args...) - return statement -} - -// tbName get some table's table name -func (statement *Statement) tbNameNoSchema(table *schemas.Table) string { - if len(statement.AltTableName) > 0 { - return statement.AltTableName - } - - return table.Name -} - // GroupBy generate "Group By keys" statement func (statement *Statement) GroupBy(keys string) *Statement { statement.GroupByStr = statement.ReplaceQuote(keys) return statement } +func (statement *Statement) WriteGroupBy(w builder.Writer) error { + if statement.GroupByStr == "" { + return nil + } + _, err := fmt.Fprintf(w, " GROUP BY %s", statement.GroupByStr) + return err +} + // Having generate "Having conditions" statement func (statement *Statement) Having(conditions string) *Statement { statement.HavingStr = fmt.Sprintf("HAVING %v", statement.ReplaceQuote(conditions)) return statement } -// Unscoped always disable struct tag "deleted" +func (statement *Statement) writeHaving(w builder.Writer) error { + if statement.HavingStr == "" { + return nil + } + _, err := fmt.Fprint(w, " ", statement.HavingStr) + return err +} + +// SetUnscoped always disable struct tag "deleted" func (statement *Statement) SetUnscoped() *Statement { statement.unscoped = true return statement } +// GetUnscoped return true if it's unscoped func (statement *Statement) GetUnscoped() bool { return statement.unscoped } -func (statement *Statement) genColumnStr() string { - if statement.RefTable == nil { - return "" - } - - var buf strings.Builder - columns := statement.RefTable.Columns() - - for _, col := range columns { - if statement.OmitColumnMap.Contain(col.Name) { - continue - } - - if len(statement.ColumnMap) > 0 && !statement.ColumnMap.Contain(col.Name) { - continue - } - - if col.MapType == schemas.ONLYTODB { - continue - } - - if buf.Len() != 0 { - buf.WriteString(", ") - } - - if statement.JoinStr != "" { - if statement.TableAlias != "" { - buf.WriteString(statement.TableAlias) - } else { - buf.WriteString(statement.TableName()) - } - - buf.WriteString(".") - } - - statement.dialect.Quoter().QuoteTo(&buf, col.Name) - } - - return buf.String() -} - -func (statement *Statement) GenCreateTableSQL() []string { - statement.RefTable.StoreEngine = statement.StoreEngine - statement.RefTable.Charset = statement.Charset - s, _ := statement.dialect.CreateTableSQL(statement.RefTable, statement.TableName()) - return s -} - +// GenIndexSQL generated create index SQL func (statement *Statement) GenIndexSQL() []string { var sqls []string tbName := statement.TableName() @@ -655,10 +336,7 @@ func (statement *Statement) GenIndexSQL() []string { return sqls } -func uniqueName(tableName, uqeName string) string { - return fmt.Sprintf("UQE_%v_%v", tableName, uqeName) -} - +// GenUniqueSQL generates unique SQL func (statement *Statement) GenUniqueSQL() []string { var sqls []string tbName := statement.TableName() @@ -671,6 +349,7 @@ func (statement *Statement) GenUniqueSQL() []string { return sqls } +// GenDelIndexSQL generate delete index SQL func (statement *Statement) GenDelIndexSQL() []string { var sqls []string tbName := statement.TableName() @@ -684,10 +363,147 @@ func (statement *Statement) GenDelIndexSQL() []string { return sqls } +func (statement *Statement) asDBCond(fieldValue reflect.Value, fieldType reflect.Type, col *schemas.Column, allUseBool, requiredField bool) (interface{}, bool, error) { + switch fieldType.Kind() { + case reflect.Ptr: + if fieldValue.IsNil() { + return nil, true, nil + } + return statement.asDBCond(fieldValue.Elem(), fieldType.Elem(), col, allUseBool, requiredField) + case reflect.Bool: + if allUseBool || requiredField { + return fieldValue.Interface(), true, nil + } + // if a bool in a struct, it will not be as a condition because it default is false, + // please use Where() instead + return nil, false, nil + case reflect.String: + if !requiredField && fieldValue.String() == "" { + return nil, false, nil + } + // for MyString, should convert to string or panic + if fieldType.String() != reflect.String.String() { + return fieldValue.String(), true, nil + } + return fieldValue.Interface(), true, nil + case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64: + if !requiredField && fieldValue.Int() == 0 { + return nil, false, nil + } + return fieldValue.Interface(), true, nil + case reflect.Float32, reflect.Float64: + if !requiredField && fieldValue.Float() == 0.0 { + return nil, false, nil + } + return fieldValue.Interface(), true, nil + case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64: + if !requiredField && fieldValue.Uint() == 0 { + return nil, false, nil + } + return fieldValue.Interface(), true, nil + case reflect.Struct: + if fieldType.ConvertibleTo(schemas.TimeType) { + t := fieldValue.Convert(schemas.TimeType).Interface().(time.Time) + if !requiredField && (t.IsZero() || !fieldValue.IsValid()) { + return nil, false, nil + } + res, err := dialects.FormatColumnTime(statement.dialect, statement.defaultTimeZone, col, t) + if err != nil { + return nil, false, err + } + return res, true, nil + } else if fieldType.ConvertibleTo(schemas.BigFloatType) { + t := fieldValue.Convert(schemas.BigFloatType).Interface().(big.Float) + v := t.String() + if v == "0" { + return nil, false, nil + } + return t.String(), true, nil + } else if _, ok := reflect.New(fieldType).Interface().(convert.Conversion); ok { + return nil, false, nil + } else if valNul, ok := fieldValue.Interface().(driver.Valuer); ok { + val, _ := valNul.Value() + if val == nil && !requiredField { + return nil, false, nil + } + return val, true, nil + } else { + if col.IsJSON { + if col.SQLType.IsText() { + bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface()) + if err != nil { + return nil, false, err + } + return string(bytes), true, nil + } else if col.SQLType.IsBlob() { + var bytes []byte + var err error + bytes, err = json.DefaultJSONHandler.Marshal(fieldValue.Interface()) + if err != nil { + return nil, false, err + } + return bytes, true, nil + } + } else { + table, err := statement.tagParser.ParseWithCache(fieldValue) + if err != nil { + return fieldValue.Interface(), true, nil + } + + if len(table.PrimaryKeys) == 1 { + pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName) + // fix non-int pk issues + // if pkField.Int() != 0 { + if pkField.IsValid() && !utils.IsZero(pkField.Interface()) { + return pkField.Interface(), true, nil + } + return nil, false, nil + } + return nil, false, fmt.Errorf("not supported %v as %v", fieldValue.Interface(), table.PrimaryKeys) + } + } + case reflect.Array: + return nil, false, nil + case reflect.Slice, reflect.Map: + if fieldValue == reflect.Zero(fieldType) { + return nil, false, nil + } + if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 { + return nil, false, nil + } + + if col.SQLType.IsText() { + bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface()) + if err != nil { + return nil, false, err + } + return string(bytes), true, nil + } else if col.SQLType.IsBlob() { + var bytes []byte + var err error + if (fieldType.Kind() == reflect.Array || fieldType.Kind() == reflect.Slice) && + fieldType.Elem().Kind() == reflect.Uint8 { + if fieldValue.Len() > 0 { + return fieldValue.Bytes(), true, nil + } + return nil, false, nil + } + bytes, err = json.DefaultJSONHandler.Marshal(fieldValue.Interface()) + if err != nil { + return nil, false, err + } + return bytes, true, nil + } + return nil, false, nil + } + return fieldValue.Interface(), true, nil +} + func (statement *Statement) buildConds2(table *schemas.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, allUseBool bool, useAllCols bool, unscoped bool, - mustColumnMap map[string]bool, tableName, aliasName string, addedTableName bool) (builder.Cond, error) { + mustColumnMap map[string]bool, tableName, aliasName string, addedTableName bool, +) (builder.Cond, error) { var conds []builder.Cond for _, col := range table.Columns() { if !includeVersion && col.IsVersion { @@ -700,17 +516,13 @@ func (statement *Statement) buildConds2(table *schemas.Table, bean interface{}, continue } - if statement.dialect.URI().DBType == schemas.MSSQL && (col.SQLType.Name == schemas.Text || - col.SQLType.IsBlob() || col.SQLType.Name == schemas.TimeStampz) { - continue - } if col.IsJSON { continue } var colName string if addedTableName { - var nm = tableName + nm := tableName if len(aliasName) > 0 { nm = aliasName } @@ -721,9 +533,8 @@ func (statement *Statement) buildConds2(table *schemas.Table, bean interface{}, fieldValuePtr, err := col.ValueOf(bean) if err != nil { - if !strings.Contains(err.Error(), "is not valid") { - //engine.logger.Warn(err) - } + continue + } else if fieldValuePtr == nil { continue } @@ -736,9 +547,16 @@ func (statement *Statement) buildConds2(table *schemas.Table, bean interface{}, continue } - fieldType := reflect.TypeOf(fieldValue.Interface()) - requiredField := useAllCols + if statement.dialect.URI().DBType == schemas.MSSQL && (col.SQLType.Name == schemas.Text || + col.SQLType.IsBlob() || col.SQLType.Name == schemas.TimeStampz) { + if utils.IsValueZero(fieldValue) { + continue + } + return nil, fmt.Errorf("column %s is a TEXT type with data %#v which cannot be as compare condition", col.Name, fieldValue.Interface()) + } + + requiredField := useAllCols if b, ok := getFlagForColumn(mustColumnMap, col); ok { if b { requiredField = true @@ -747,6 +565,7 @@ func (statement *Statement) buildConds2(table *schemas.Table, bean interface{}, } } + fieldType := reflect.TypeOf(fieldValue.Interface()) if fieldType.Kind() == reflect.Ptr { if fieldValue.IsNil() { if includeNil { @@ -763,131 +582,12 @@ func (statement *Statement) buildConds2(table *schemas.Table, bean interface{}, } } - var val interface{} - switch fieldType.Kind() { - case reflect.Bool: - if allUseBool || requiredField { - val = fieldValue.Interface() - } else { - // if a bool in a struct, it will not be as a condition because it default is false, - // please use Where() instead - continue - } - case reflect.String: - if !requiredField && fieldValue.String() == "" { - continue - } - // for MyString, should convert to string or panic - if fieldType.String() != reflect.String.String() { - val = fieldValue.String() - } else { - val = fieldValue.Interface() - } - case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64: - if !requiredField && fieldValue.Int() == 0 { - continue - } - val = fieldValue.Interface() - case reflect.Float32, reflect.Float64: - if !requiredField && fieldValue.Float() == 0.0 { - continue - } - val = fieldValue.Interface() - case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64: - if !requiredField && fieldValue.Uint() == 0 { - continue - } - val = fieldValue.Interface() - case reflect.Struct: - if fieldType.ConvertibleTo(schemas.TimeType) { - t := fieldValue.Convert(schemas.TimeType).Interface().(time.Time) - if !requiredField && (t.IsZero() || !fieldValue.IsValid()) { - continue - } - val = dialects.FormatColumnTime(statement.dialect, statement.defaultTimeZone, col, t) - } else if _, ok := reflect.New(fieldType).Interface().(convert.Conversion); ok { - continue - } else if valNul, ok := fieldValue.Interface().(driver.Valuer); ok { - val, _ = valNul.Value() - if val == nil && !requiredField { - continue - } - } else { - if col.IsJSON { - if col.SQLType.IsText() { - bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface()) - if err != nil { - return nil, err - } - val = string(bytes) - } else if col.SQLType.IsBlob() { - var bytes []byte - var err error - bytes, err = json.DefaultJSONHandler.Marshal(fieldValue.Interface()) - if err != nil { - return nil, err - } - val = bytes - } - } else { - table, err := statement.tagParser.ParseWithCache(fieldValue) - if err != nil { - val = fieldValue.Interface() - } else { - if len(table.PrimaryKeys) == 1 { - pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName) - // fix non-int pk issues - //if pkField.Int() != 0 { - if pkField.IsValid() && !utils.IsZero(pkField.Interface()) { - val = pkField.Interface() - } else { - continue - } - } else { - //TODO: how to handler? - return nil, fmt.Errorf("not supported %v as %v", fieldValue.Interface(), table.PrimaryKeys) - } - } - } - } - case reflect.Array: + val, ok, err := statement.asDBCond(fieldValue, fieldType, col, allUseBool, requiredField) + if err != nil { + return nil, err + } + if !ok { continue - case reflect.Slice, reflect.Map: - if fieldValue == reflect.Zero(fieldType) { - continue - } - if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 { - continue - } - - if col.SQLType.IsText() { - bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface()) - if err != nil { - return nil, err - } - val = string(bytes) - } else if col.SQLType.IsBlob() { - var bytes []byte - var err error - if (fieldType.Kind() == reflect.Array || fieldType.Kind() == reflect.Slice) && - fieldType.Elem().Kind() == reflect.Uint8 { - if fieldValue.Len() > 0 { - val = fieldValue.Bytes() - } else { - continue - } - } else { - bytes, err = json.DefaultJSONHandler.Marshal(fieldValue.Interface()) - if err != nil { - return nil, err - } - val = bytes - } - } else { - continue - } - default: - val = fieldValue.Interface() } conds = append(conds, builder.Eq{colName: val}) @@ -896,14 +596,16 @@ func (statement *Statement) buildConds2(table *schemas.Table, bean interface{}, return builder.And(conds...), nil } +// BuildConds builds condition func (statement *Statement) BuildConds(table *schemas.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, addedTableName bool) (builder.Cond, error) { return statement.buildConds2(table, bean, includeVersion, includeUpdated, includeNil, includeAutoIncr, statement.allUseBool, statement.useAllCols, statement.unscoped, statement.MustColumnMap, statement.TableName(), statement.TableAlias, addedTableName) } -func (statement *Statement) mergeConds(bean interface{}) error { +// MergeConds merge conditions from bean and id +func (statement *Statement) MergeConds(bean interface{}) error { if !statement.NoAutoCondition && statement.RefTable != nil { - var addedTableName = (len(statement.JoinStr) > 0) + addedTableName := (len(statement.JoinStr) > 0) autoCond, err := statement.BuildConds(statement.RefTable, bean, true, true, false, true, addedTableName) if err != nil { return err @@ -911,18 +613,7 @@ func (statement *Statement) mergeConds(bean interface{}) error { statement.cond = statement.cond.And(autoCond) } - if err := statement.ProcessIDParam(); err != nil { - return err - } - return nil -} - -func (statement *Statement) GenConds(bean interface{}) (string, []interface{}, error) { - if err := statement.mergeConds(bean); err != nil { - return "", nil, err - } - - return statement.GenCondSQL(statement.cond) + return statement.ProcessIDParam() } func (statement *Statement) quoteColumnStr(columnStr string) string { @@ -930,17 +621,31 @@ func (statement *Statement) quoteColumnStr(columnStr string) string { return statement.dialect.Quoter().Join(columns, ",") } +// ConvertSQLOrArgs converts sql or args func (statement *Statement) ConvertSQLOrArgs(sqlOrArgs ...interface{}) (string, []interface{}, error) { - sql, args, err := convertSQLOrArgs(sqlOrArgs...) + sql, args, err := statement.convertSQLOrArgs(sqlOrArgs...) if err != nil { return "", nil, err } return statement.ReplaceQuote(sql), args, nil } -func convertSQLOrArgs(sqlOrArgs ...interface{}) (string, []interface{}, error) { +func (statement *Statement) convertSQLOrArgs(sqlOrArgs ...interface{}) (string, []interface{}, error) { switch sqlOrArgs[0].(type) { case string: + if len(sqlOrArgs) > 1 { + newArgs := make([]interface{}, 0, len(sqlOrArgs)-1) + for _, arg := range sqlOrArgs[1:] { + if v, ok := arg.(time.Time); ok { + newArgs = append(newArgs, v.In(statement.defaultTimeZone).Format("2006-01-02 15:04:05")) + } else if v, ok := arg.(*time.Time); ok && v != nil { + newArgs = append(newArgs, v.In(statement.defaultTimeZone).Format("2006-01-02 15:04:05")) + } else { + newArgs = append(newArgs, arg) + } + } + return sqlOrArgs[0].(string), newArgs, nil + } return sqlOrArgs[0].(string), sqlOrArgs[1:], nil case *builder.Builder: return sqlOrArgs[0].(*builder.Builder).ToSQL() @@ -953,7 +658,7 @@ func convertSQLOrArgs(sqlOrArgs ...interface{}) (string, []interface{}, error) { } func (statement *Statement) joinColumns(cols []*schemas.Column, includeTableName bool) string { - var colnames = make([]string, len(cols)) + colnames := make([]string, len(cols)) for i, col := range cols { if includeTableName { colnames[i] = statement.quote(statement.TableName()) + @@ -967,7 +672,7 @@ func (statement *Statement) joinColumns(cols []*schemas.Column, includeTableName // CondDeleted returns the conditions whether a record is soft deleted. func (statement *Statement) CondDeleted(col *schemas.Column) builder.Cond { - var colName = col.Name + colName := statement.quote(col.Name) if statement.JoinStr != "" { var prefix string if statement.TableAlias != "" { @@ -977,7 +682,7 @@ func (statement *Statement) CondDeleted(col *schemas.Column) builder.Cond { } colName = statement.quote(prefix) + "." + statement.quote(col.Name) } - var cond = builder.NewCond() + cond := builder.NewCond() if col.SQLType.IsNumeric() { cond = builder.Eq{colName: 0} } else { diff --git a/internal/statements/statement_args.go b/internal/statements/statement_args.go index dc14467d..727d5977 100644 --- a/internal/statements/statement_args.go +++ b/internal/statements/statement_args.go @@ -5,78 +5,11 @@ package statements import ( - "fmt" - "reflect" - "strings" - "time" - "xorm.io/builder" "xorm.io/xorm/schemas" ) -func quoteNeeded(a interface{}) bool { - switch a.(type) { - case int, int8, int16, int32, int64: - return false - case uint, uint8, uint16, uint32, uint64: - return false - case float32, float64: - return false - case bool: - return false - case string: - return true - case time.Time, *time.Time: - return true - case builder.Builder, *builder.Builder: - return false - } - - t := reflect.TypeOf(a) - switch t.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return false - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - return false - case reflect.Float32, reflect.Float64: - return false - case reflect.Bool: - return false - case reflect.String: - return true - } - - return true -} - -func convertStringSingleQuote(arg string) string { - return "'" + strings.Replace(arg, "'", "''", -1) + "'" -} - -func convertString(arg string) string { - var buf strings.Builder - buf.WriteRune('\'') - for _, c := range arg { - if c == '\\' || c == '\'' { - buf.WriteRune('\\') - } - buf.WriteRune(c) - } - buf.WriteRune('\'') - return buf.String() -} - -func convertArg(arg interface{}, convertFunc func(string) string) string { - if quoteNeeded(arg) { - argv := fmt.Sprintf("%v", arg) - return convertFunc(argv) - } - - return fmt.Sprintf("%v", arg) -} - -const insertSelectPlaceHolder = true - +// WriteArg writes an arg func (statement *Statement) WriteArg(w *builder.BytesWriter, arg interface{}) error { switch argv := arg.(type) { case *builder.Builder: @@ -90,32 +23,23 @@ func (statement *Statement) WriteArg(w *builder.BytesWriter, arg interface{}) er return err } default: - if insertSelectPlaceHolder { - if err := w.WriteByte('?'); err != nil { - return err - } - if v, ok := arg.(bool); ok && statement.dialect.URI().DBType == schemas.MSSQL { - if v { - w.Append(1) - } else { - w.Append(0) - } + if err := w.WriteByte('?'); err != nil { + return err + } + if v, ok := arg.(bool); ok && statement.dialect.URI().DBType == schemas.MSSQL { + if v { + w.Append(1) } else { - w.Append(arg) + w.Append(0) } } else { - var convertFunc = convertStringSingleQuote - if statement.dialect.URI().DBType == schemas.MYSQL { - convertFunc = convertString - } - if _, err := w.WriteString(convertArg(arg, convertFunc)); err != nil { - return err - } + w.Append(arg) } } return nil } +// WriteArgs writes args func (statement *Statement) WriteArgs(w *builder.BytesWriter, args []interface{}) error { for i, arg := range args { if err := statement.WriteArg(w, arg); err != nil { diff --git a/internal/statements/statement_test.go b/internal/statements/statement_test.go index 15f446f4..31428efa 100644 --- a/internal/statements/statement_test.go +++ b/internal/statements/statement_test.go @@ -5,6 +5,7 @@ package statements import ( + "os" "reflect" "strings" "testing" @@ -37,6 +38,7 @@ func TestMain(m *testing.M) { panic("tags parser is nil") } m.Run() + os.Exit(0) } var colStrTests = []struct { @@ -77,8 +79,24 @@ func TestColumnsStringGeneration(t *testing.T) { } } -func BenchmarkGetFlagForColumnWithICKey_ContainsKey(b *testing.B) { +func TestConvertSQLOrArgs(t *testing.T) { + statement, err := createTestStatement() + assert.NoError(t, err) + // example orm struct + // type Table struct { + // ID int + // del *time.Time `xorm:"deleted"` + // } + args := []interface{}{ + "INSERT `table` (`id`, `del`) VALUES (?, ?)", 1, (*time.Time)(nil), + } + // before fix, here will panic + _, _, err = statement.convertSQLOrArgs(args...) + assert.NoError(t, err) +} + +func BenchmarkGetFlagForColumnWithICKey_ContainsKey(b *testing.B) { b.StopTimer() mapCols := make(map[string]bool) @@ -101,9 +119,7 @@ func BenchmarkGetFlagForColumnWithICKey_ContainsKey(b *testing.B) { b.StartTimer() for i := 0; i < b.N; i++ { - for _, col := range cols { - if _, ok := getFlagForColumn(mapCols, col); !ok { b.Fatal("Unexpected result") } @@ -112,7 +128,6 @@ func BenchmarkGetFlagForColumnWithICKey_ContainsKey(b *testing.B) { } func BenchmarkGetFlagForColumnWithICKey_EmptyMap(b *testing.B) { - b.StopTimer() mapCols := make(map[string]bool) @@ -131,9 +146,7 @@ func BenchmarkGetFlagForColumnWithICKey_EmptyMap(b *testing.B) { b.StartTimer() for i := 0; i < b.N; i++ { - for _, col := range cols { - if _, ok := getFlagForColumn(mapCols, col); ok { b.Fatal("Unexpected result") } diff --git a/internal/statements/table_name.go b/internal/statements/table_name.go new file mode 100644 index 00000000..8072a99d --- /dev/null +++ b/internal/statements/table_name.go @@ -0,0 +1,56 @@ +// Copyright 2022 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package statements + +import ( + "fmt" + "strings" + + "xorm.io/builder" + "xorm.io/xorm/schemas" +) + +// TableName return current tableName +func (statement *Statement) TableName() string { + if statement.AltTableName != "" { + return statement.AltTableName + } + + return statement.tableName +} + +// Alias set the table alias +func (statement *Statement) Alias(alias string) *Statement { + statement.TableAlias = alias + return statement +} + +func (statement *Statement) writeAlias(w builder.Writer) error { + if statement.TableAlias != "" { + if statement.dialect.URI().DBType == schemas.ORACLE { + if _, err := fmt.Fprint(w, " ", statement.quote(statement.TableAlias)); err != nil { + return err + } + } else { + if _, err := fmt.Fprint(w, " AS ", statement.quote(statement.TableAlias)); err != nil { + return err + } + } + } + return nil +} + +func (statement *Statement) writeTableName(w builder.Writer) error { + if statement.dialect.URI().DBType == schemas.MSSQL && strings.Contains(statement.TableName(), "..") { + if _, err := fmt.Fprint(w, statement.TableName()); err != nil { + return err + } + } else { + if _, err := fmt.Fprint(w, statement.quote(statement.TableName())); err != nil { + return err + } + } + return nil +} diff --git a/internal/statements/update.go b/internal/statements/update.go index 251880b2..40159e0c 100644 --- a/internal/statements/update.go +++ b/internal/statements/update.go @@ -88,6 +88,9 @@ func (statement *Statement) BuildUpdates(tableValue reflect.Value, if err != nil { return nil, nil, err } + if fieldValuePtr == nil { + continue + } fieldValue := *fieldValuePtr fieldType := reflect.TypeOf(fieldValue.Interface()) @@ -124,8 +127,12 @@ func (statement *Statement) BuildUpdates(tableValue reflect.Value, if err != nil { return nil, nil, err } - - val = data + if data != nil { + val = data + if !col.SQLType.IsBlob() { + val = string(data) + } + } goto APPEND } } @@ -135,8 +142,12 @@ func (statement *Statement) BuildUpdates(tableValue reflect.Value, if err != nil { return nil, nil, err } - - val = data + if data != nil { + val = data + if !col.SQLType.IsBlob() { + val = string(data) + } + } goto APPEND } @@ -197,7 +208,10 @@ func (statement *Statement) BuildUpdates(tableValue reflect.Value, if !requiredField && (t.IsZero() || !fieldValue.IsValid()) { continue } - val = dialects.FormatColumnTime(statement.dialect, statement.defaultTimeZone, col, t) + val, err = dialects.FormatColumnTime(statement.dialect, statement.defaultTimeZone, col, t) + if err != nil { + return nil, nil, err + } } else if nulType, ok := fieldValue.Interface().(driver.Valuer); ok { val, _ = nulType.Value() if val == nil && !requiredField { diff --git a/internal/statements/values.go b/internal/statements/values.go index 71327c55..4c1360ed 100644 --- a/internal/statements/values.go +++ b/internal/statements/values.go @@ -8,6 +8,7 @@ import ( "database/sql" "database/sql/driver" "fmt" + "math/big" "reflect" "time" @@ -19,9 +20,10 @@ import ( var ( nullFloatType = reflect.TypeOf(sql.NullFloat64{}) + bigFloatType = reflect.TypeOf(big.Float{}) ) -// Value2Interface convert a field value of a struct to interface for puting into database +// Value2Interface convert a field value of a struct to interface for putting into database func (statement *Statement) Value2Interface(col *schemas.Column, fieldValue reflect.Value) (interface{}, error) { if fieldValue.CanAddr() { if fieldConvert, ok := fieldValue.Addr().Interface().(convert.Conversion); ok { @@ -29,6 +31,12 @@ func (statement *Statement) Value2Interface(col *schemas.Column, fieldValue refl if err != nil { return nil, err } + if data == nil { + if col.Nullable { + return nil, nil + } + data = []byte{} + } if col.SQLType.IsBlob() { return data, nil } @@ -43,12 +51,15 @@ func (statement *Statement) Value2Interface(col *schemas.Column, fieldValue refl if err != nil { return nil, err } + if data == nil { + if col.Nullable { + return nil, nil + } + data = []byte{} + } if col.SQLType.IsBlob() { return data, nil } - if nil == data { - return nil, nil - } return string(data), nil } } @@ -76,14 +87,17 @@ func (statement *Statement) Value2Interface(col *schemas.Column, fieldValue refl case reflect.Struct: if fieldType.ConvertibleTo(schemas.TimeType) { t := fieldValue.Convert(schemas.TimeType).Interface().(time.Time) - tf := dialects.FormatColumnTime(statement.dialect, statement.defaultTimeZone, col, t) - return tf, nil + tf, err := dialects.FormatColumnTime(statement.dialect, statement.defaultTimeZone, col, t) + return tf, err } else if fieldType.ConvertibleTo(nullFloatType) { t := fieldValue.Convert(nullFloatType).Interface().(sql.NullFloat64) if !t.Valid { return nil, nil } return t.Float64, nil + } else if fieldType.ConvertibleTo(bigFloatType) { + t := fieldValue.Convert(bigFloatType).Interface().(big.Float) + return t.String(), nil } if !col.IsJSON { diff --git a/internal/utils/builder.go b/internal/utils/builder.go new file mode 100644 index 00000000..bc97526f --- /dev/null +++ b/internal/utils/builder.go @@ -0,0 +1,27 @@ +// Copyright 2022 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package utils + +import ( + "fmt" + + "xorm.io/builder" +) + +type BuildReader interface { + String() string + Args() []interface{} +} + +// WriteBuilder writes writers to one +func WriteBuilder(w *builder.BytesWriter, inputs ...BuildReader) error { + for _, input := range inputs { + if _, err := fmt.Fprint(w, input.String()); err != nil { + return err + } + w.Append(input.Args()...) + } + return nil +} diff --git a/internal/utils/name.go b/internal/utils/name.go index f5fc3ff7..aeef683d 100644 --- a/internal/utils/name.go +++ b/internal/utils/name.go @@ -6,8 +6,15 @@ package utils import ( "fmt" + "strings" ) +// IndexName returns index name func IndexName(tableName, idxName string) string { return fmt.Sprintf("IDX_%v_%v", tableName, idxName) } + +// SeqName returns sequence name for some table +func SeqName(tableName string) string { + return "SEQ_" + strings.ToUpper(tableName) +} diff --git a/internal/utils/new.go b/internal/utils/new.go new file mode 100644 index 00000000..e3b4eae8 --- /dev/null +++ b/internal/utils/new.go @@ -0,0 +1,25 @@ +// Copyright 2021 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package utils + +import "reflect" + +// New creates a value according type +func New(tp reflect.Type, length, cap int) reflect.Value { + switch tp.Kind() { + case reflect.Slice: + slice := reflect.MakeSlice(tp, length, cap) + x := reflect.New(slice.Type()) + x.Elem().Set(slice) + return x + case reflect.Map: + mp := reflect.MakeMapWithSize(tp, cap) + x := reflect.New(mp.Type()) + x.Elem().Set(mp) + return x + default: + return reflect.New(tp) + } +} diff --git a/internal/utils/reflect.go b/internal/utils/reflect.go index 3dad6bfe..7973d4d3 100644 --- a/internal/utils/reflect.go +++ b/internal/utils/reflect.go @@ -8,6 +8,7 @@ import ( "reflect" ) +// ReflectValue returns value of a bean func ReflectValue(bean interface{}) reflect.Value { return reflect.Indirect(reflect.ValueOf(bean)) } diff --git a/internal/utils/slice.go b/internal/utils/slice.go index 89685706..82289b1a 100644 --- a/internal/utils/slice.go +++ b/internal/utils/slice.go @@ -11,8 +11,8 @@ func SliceEq(left, right []string) bool { if len(left) != len(right) { return false } - sort.Sort(sort.StringSlice(left)) - sort.Sort(sort.StringSlice(right)) + sort.Strings(left) + sort.Strings(right) for i := 0; i < len(left); i++ { if left[i] != right[i] { return false @@ -20,3 +20,13 @@ func SliceEq(left, right []string) bool { } return true } + +// IndexSlice search c in slice s and return the index, return -1 if s don't contain c +func IndexSlice(s []string, c string) int { + for i, ss := range s { + if c == ss { + return i + } + } + return -1 +} diff --git a/internal/utils/sql.go b/internal/utils/sql.go index 5e68c4a4..369ca2b8 100644 --- a/internal/utils/sql.go +++ b/internal/utils/sql.go @@ -8,6 +8,7 @@ import ( "strings" ) +// IsSubQuery returns true if it contains a sub query func IsSubQuery(tbName string) bool { const selStr = "select" if len(tbName) <= len(selStr)+1 { diff --git a/internal/utils/strings.go b/internal/utils/strings.go index 72466705..159e2876 100644 --- a/internal/utils/strings.go +++ b/internal/utils/strings.go @@ -8,10 +8,12 @@ import ( "strings" ) +// IndexNoCase index a string in a string with no care of capitalize func IndexNoCase(s, sep string) int { return strings.Index(strings.ToLower(s), strings.ToLower(sep)) } +// SplitNoCase split a string by a separator with no care of capitalize func SplitNoCase(s, sep string) []string { idx := IndexNoCase(s, sep) if idx < 0 { @@ -20,6 +22,7 @@ func SplitNoCase(s, sep string) []string { return strings.Split(s, s[idx:idx+len(sep)]) } +// SplitNNoCase split n by a separator with no care of capitalize func SplitNNoCase(s, sep string, n int) []string { idx := IndexNoCase(s, sep) if idx < 0 { diff --git a/internal/utils/zero.go b/internal/utils/zero.go index 8f033c60..007e3c33 100644 --- a/internal/utils/zero.go +++ b/internal/utils/zero.go @@ -9,6 +9,7 @@ import ( "time" ) +// Zeroable represents an interface which could know if it's a zero value type Zeroable interface { IsZero() bool } @@ -21,39 +22,39 @@ func IsZero(k interface{}) bool { return true } - switch k.(type) { + switch t := k.(type) { case int: - return k.(int) == 0 + return t == 0 case int8: - return k.(int8) == 0 + return t == 0 case int16: - return k.(int16) == 0 + return t == 0 case int32: - return k.(int32) == 0 + return t == 0 case int64: - return k.(int64) == 0 + return t == 0 case uint: - return k.(uint) == 0 + return t == 0 case uint8: - return k.(uint8) == 0 + return t == 0 case uint16: - return k.(uint16) == 0 + return t == 0 case uint32: - return k.(uint32) == 0 + return t == 0 case uint64: - return k.(uint64) == 0 + return t == 0 case float32: - return k.(float32) == 0 + return t == 0 case float64: - return k.(float64) == 0 + return t == 0 case bool: - return k.(bool) == false + return !t case string: - return k.(string) == "" + return t == "" case *time.Time: - return k.(*time.Time) == nilTime || IsTimeZero(*k.(*time.Time)) + return t == nilTime || IsTimeZero(*t) case time.Time: - return IsTimeZero(k.(time.Time)) + return IsTimeZero(t) case Zeroable: return k.(Zeroable) == nil || k.(Zeroable).IsZero() case reflect.Value: // for go version less than 1.13 because reflect.Value has no method IsZero @@ -65,6 +66,7 @@ func IsZero(k interface{}) bool { var zeroType = reflect.TypeOf((*Zeroable)(nil)).Elem() +// IsValueZero returns true if the reflect Value is a zero func IsValueZero(v reflect.Value) bool { switch v.Kind() { case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Slice: @@ -88,6 +90,7 @@ func IsValueZero(v reflect.Value) bool { return false } +// IsStructZero returns true if the Value is a struct and all fields is zero func IsStructZero(v reflect.Value) bool { if !v.IsValid() || v.NumField() == 0 { return true @@ -120,6 +123,7 @@ func IsStructZero(v reflect.Value) bool { return true } +// IsArrayZero returns true is a slice of array is zero func IsArrayZero(v reflect.Value) bool { if !v.IsValid() || v.Len() == 0 { return true @@ -134,11 +138,13 @@ func IsArrayZero(v reflect.Value) bool { return true } +// represents all zero times const ( ZeroTime0 = "0000-00-00 00:00:00" ZeroTime1 = "0001-01-01 00:00:00" ) +// IsTimeZero return true if a time is zero func IsTimeZero(t time.Time) bool { return t.IsZero() || t.Format("2006-01-02 15:04:05") == ZeroTime0 || t.Format("2006-01-02 15:04:05") == ZeroTime1 diff --git a/log/logger.go b/log/logger.go index eeb63693..b8798c3f 100644 --- a/log/logger.go +++ b/log/logger.go @@ -130,65 +130,57 @@ func NewSimpleLogger3(out io.Writer, prefix string, flag int, l LogLevel) *Simpl // Error implement ILogger func (s *SimpleLogger) Error(v ...interface{}) { if s.level <= LOG_ERR { - s.ERR.Output(2, fmt.Sprintln(v...)) + _ = s.ERR.Output(2, fmt.Sprintln(v...)) } - return } // Errorf implement ILogger func (s *SimpleLogger) Errorf(format string, v ...interface{}) { if s.level <= LOG_ERR { - s.ERR.Output(2, fmt.Sprintf(format, v...)) + _ = s.ERR.Output(2, fmt.Sprintf(format, v...)) } - return } // Debug implement ILogger func (s *SimpleLogger) Debug(v ...interface{}) { if s.level <= LOG_DEBUG { - s.DEBUG.Output(2, fmt.Sprintln(v...)) + _ = s.DEBUG.Output(2, fmt.Sprintln(v...)) } - return } // Debugf implement ILogger func (s *SimpleLogger) Debugf(format string, v ...interface{}) { if s.level <= LOG_DEBUG { - s.DEBUG.Output(2, fmt.Sprintf(format, v...)) + _ = s.DEBUG.Output(2, fmt.Sprintf(format, v...)) } - return } // Info implement ILogger func (s *SimpleLogger) Info(v ...interface{}) { if s.level <= LOG_INFO { - s.INFO.Output(2, fmt.Sprintln(v...)) + _ = s.INFO.Output(2, fmt.Sprintln(v...)) } - return } // Infof implement ILogger func (s *SimpleLogger) Infof(format string, v ...interface{}) { if s.level <= LOG_INFO { - s.INFO.Output(2, fmt.Sprintf(format, v...)) + _ = s.INFO.Output(2, fmt.Sprintf(format, v...)) } - return } // Warn implement ILogger func (s *SimpleLogger) Warn(v ...interface{}) { if s.level <= LOG_WARNING { - s.WARN.Output(2, fmt.Sprintln(v...)) + _ = s.WARN.Output(2, fmt.Sprintln(v...)) } - return } // Warnf implement ILogger func (s *SimpleLogger) Warnf(format string, v ...interface{}) { if s.level <= LOG_WARNING { - s.WARN.Output(2, fmt.Sprintf(format, v...)) + _ = s.WARN.Output(2, fmt.Sprintf(format, v...)) } - return } // Level implement ILogger @@ -199,7 +191,6 @@ func (s *SimpleLogger) Level() LogLevel { // SetLevel implement ILogger func (s *SimpleLogger) SetLevel(l LogLevel) { s.level = l - return } // ShowSQL implement ILogger diff --git a/log/syslogger.go b/log/syslogger.go index 0b3e381c..44272586 100644 --- a/log/syslogger.go +++ b/log/syslogger.go @@ -2,6 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build !windows && !nacl && !plan9 // +build !windows,!nacl,!plan9 package log @@ -26,42 +27,42 @@ func NewSyslogLogger(w *syslog.Writer) *SyslogLogger { // Debug log content as Debug func (s *SyslogLogger) Debug(v ...interface{}) { - s.w.Debug(fmt.Sprint(v...)) + _ = s.w.Debug(fmt.Sprint(v...)) } // Debugf log content as Debug and format func (s *SyslogLogger) Debugf(format string, v ...interface{}) { - s.w.Debug(fmt.Sprintf(format, v...)) + _ = s.w.Debug(fmt.Sprintf(format, v...)) } // Error log content as Error func (s *SyslogLogger) Error(v ...interface{}) { - s.w.Err(fmt.Sprint(v...)) + _ = s.w.Err(fmt.Sprint(v...)) } // Errorf log content as Errorf and format func (s *SyslogLogger) Errorf(format string, v ...interface{}) { - s.w.Err(fmt.Sprintf(format, v...)) + _ = s.w.Err(fmt.Sprintf(format, v...)) } // Info log content as Info func (s *SyslogLogger) Info(v ...interface{}) { - s.w.Info(fmt.Sprint(v...)) + _ = s.w.Info(fmt.Sprint(v...)) } // Infof log content as Infof and format func (s *SyslogLogger) Infof(format string, v ...interface{}) { - s.w.Info(fmt.Sprintf(format, v...)) + _ = s.w.Info(fmt.Sprintf(format, v...)) } // Warn log content as Warn func (s *SyslogLogger) Warn(v ...interface{}) { - s.w.Warning(fmt.Sprint(v...)) + _ = s.w.Warning(fmt.Sprint(v...)) } // Warnf log content as Warnf and format func (s *SyslogLogger) Warnf(format string, v ...interface{}) { - s.w.Warning(fmt.Sprintf(format, v...)) + _ = s.w.Warning(fmt.Sprintf(format, v...)) } // Level shows log level diff --git a/migrate/migrate.go b/migrate/migrate.go index 82c58f90..5c259627 100644 --- a/migrate/migrate.go +++ b/migrate/migrate.go @@ -110,10 +110,7 @@ func (m *Migrate) RollbackLast() error { return err } - if err := m.RollbackMigration(lastRunnedMigration); err != nil { - return err - } - return nil + return m.RollbackMigration(lastRunnedMigration) } func (m *Migrate) getLastRunnedMigration() (*Migration, error) { @@ -206,7 +203,7 @@ func (m *Migrate) migrationDidRun(mig *Migration) (bool, error) { func (m *Migrate) isFirstRun() bool { row := m.db.DB().QueryRow(fmt.Sprintf("SELECT COUNT(*) FROM %s", m.options.TableName)) var count int - row.Scan(&count) + _ = row.Scan(&count) return count == 0 } diff --git a/migrate/migrate_test.go b/migrate/migrate_test.go index 19554f7e..750afb28 100644 --- a/migrate/migrate_test.go +++ b/migrate/migrate_test.go @@ -31,7 +31,7 @@ var ( { ID: "201608301400", Migrate: func(tx *xorm.Engine) error { - return tx.Sync2(&Person{}) + return tx.Sync(&Person{}) }, Rollback: func(tx *xorm.Engine) error { return tx.DropTables(&Person{}) @@ -40,7 +40,7 @@ var ( { ID: "201608301430", Migrate: func(tx *xorm.Engine) error { - return tx.Sync2(&Pet{}) + return tx.Sync(&Pet{}) }, Rollback: func(tx *xorm.Engine) error { return tx.DropTables(&Pet{}) @@ -103,13 +103,10 @@ func TestInitSchema(t *testing.T) { m := New(db, DefaultOptions, migrations) m.InitSchema(func(tx *xorm.Engine) error { - if err := tx.Sync2(&Person{}); err != nil { + if err := tx.Sync(&Person{}); err != nil { return err } - if err := tx.Sync2(&Pet{}); err != nil { - return err - } - return nil + return tx.Sync(&Pet{}) }) err = m.Migrate() @@ -145,6 +142,6 @@ func TestMissingID(t *testing.T) { func tableCount(db *xorm.Engine, tableName string) (count int) { row := db.DB().QueryRow(fmt.Sprintf("SELECT COUNT(*) FROM %s", tableName)) - row.Scan(&count) + _ = row.Scan(&count) return } diff --git a/names/mapper.go b/names/mapper.go index 79add76e..69f67171 100644 --- a/names/mapper.go +++ b/names/mapper.go @@ -16,6 +16,7 @@ type Mapper interface { Table2Obj(string) string } +// CacheMapper represents a cache mapper type CacheMapper struct { oriMapper Mapper obj2tableCache map[string]string @@ -24,12 +25,14 @@ type CacheMapper struct { table2objMutex sync.RWMutex } +// NewCacheMapper creates a cache mapper func NewCacheMapper(mapper Mapper) *CacheMapper { return &CacheMapper{oriMapper: mapper, obj2tableCache: make(map[string]string), table2objCache: make(map[string]string), } } +// Obj2Table implements Mapper func (m *CacheMapper) Obj2Table(o string) string { m.obj2tableMutex.RLock() t, ok := m.obj2tableCache[o] @@ -45,6 +48,7 @@ func (m *CacheMapper) Obj2Table(o string) string { return t } +// Table2Obj implements Mapper func (m *CacheMapper) Table2Obj(t string) string { m.table2objMutex.RLock() o, ok := m.table2objCache[t] @@ -60,20 +64,22 @@ func (m *CacheMapper) Table2Obj(t string) string { return o } -// SameMapper implements IMapper and provides same name between struct and +// SameMapper implements Mapper and provides same name between struct and // database table type SameMapper struct { } +// Obj2Table implements Mapper func (m SameMapper) Obj2Table(o string) string { return o } +// Table2Obj implements Mapper func (m SameMapper) Table2Obj(t string) string { return t } -// SnakeMapper implements IMapper and provides name transaltion between +// SnakeMapper implements IMapper and provides name translation between // struct and database table type SnakeMapper struct { } @@ -98,6 +104,7 @@ func snakeCasedName(name string) string { return b2s(newstr) } +// Obj2Table implements Mapper func (mapper SnakeMapper) Obj2Table(name string) string { return snakeCasedName(name) } @@ -127,6 +134,7 @@ func titleCasedName(name string) string { return b2s(newstr) } +// Table2Obj implements Mapper func (mapper SnakeMapper) Table2Obj(name string) string { return titleCasedName(name) } @@ -168,10 +176,12 @@ func gonicCasedName(name string) string { return strings.ToLower(string(newstr)) } +// Obj2Table implements Mapper func (mapper GonicMapper) Obj2Table(name string) string { return gonicCasedName(name) } +// Table2Obj implements Mapper func (mapper GonicMapper) Table2Obj(name string) string { newstr := make([]rune, 0) @@ -234,14 +244,17 @@ type PrefixMapper struct { Prefix string } +// Obj2Table implements Mapper func (mapper PrefixMapper) Obj2Table(name string) string { return mapper.Prefix + mapper.Mapper.Obj2Table(name) } +// Table2Obj implements Mapper func (mapper PrefixMapper) Table2Obj(name string) string { return mapper.Mapper.Table2Obj(name[len(mapper.Prefix):]) } +// NewPrefixMapper creates a prefix mapper func NewPrefixMapper(mapper Mapper, prefix string) PrefixMapper { return PrefixMapper{mapper, prefix} } @@ -252,14 +265,17 @@ type SuffixMapper struct { Suffix string } +// Obj2Table implements Mapper func (mapper SuffixMapper) Obj2Table(name string) string { return mapper.Mapper.Obj2Table(name) + mapper.Suffix } +// Table2Obj implements Mapper func (mapper SuffixMapper) Table2Obj(name string) string { return mapper.Mapper.Table2Obj(name[:len(name)-len(mapper.Suffix)]) } +// NewSuffixMapper creates a suffix mapper func NewSuffixMapper(mapper Mapper, suffix string) SuffixMapper { return SuffixMapper{mapper, suffix} } diff --git a/names/table_name.go b/names/table_name.go index 0afb1ae3..d7d71b51 100644 --- a/names/table_name.go +++ b/names/table_name.go @@ -14,11 +14,18 @@ type TableName interface { TableName() string } +type TableComment interface { + TableComment() string +} + var ( - tpTableName = reflect.TypeOf((*TableName)(nil)).Elem() - tvCache sync.Map + tpTableName = reflect.TypeOf((*TableName)(nil)).Elem() + tpTableComment = reflect.TypeOf((*TableComment)(nil)).Elem() + tvCache sync.Map + tcCache sync.Map ) +// GetTableName returns table name func GetTableName(mapper Mapper, v reflect.Value) string { if v.Type().Implements(tpTableName) { return v.Interface().(TableName).TableName() @@ -54,3 +61,40 @@ func GetTableName(mapper Mapper, v reflect.Value) string { return mapper.Obj2Table(v.Type().Name()) } + +// GetTableComment returns table comment +func GetTableComment(v reflect.Value) string { + if v.Type().Implements(tpTableComment) { + return v.Interface().(TableComment).TableComment() + } + + if v.Kind() == reflect.Ptr { + v = v.Elem() + if v.Type().Implements(tpTableComment) { + return v.Interface().(TableComment).TableComment() + } + } else if v.CanAddr() { + v1 := v.Addr() + if v1.Type().Implements(tpTableComment) { + return v1.Interface().(TableComment).TableComment() + } + } else { + comment, ok := tcCache.Load(v.Type()) + if ok { + if comment.(string) != "" { + return comment.(string) + } + } else { + v2 := reflect.New(v.Type()) + if v2.Type().Implements(tpTableComment) { + tableComment := v2.Interface().(TableComment).TableComment() + tcCache.Store(v.Type(), tableComment) + return tableComment + } + + tcCache.Store(v.Type(), "") + } + } + + return "" +} diff --git a/rows.go b/rows.go index a56ea1c9..4801c300 100644 --- a/rows.go +++ b/rows.go @@ -5,22 +5,19 @@ package xorm import ( - "database/sql" "errors" "fmt" "reflect" "xorm.io/builder" "xorm.io/xorm/core" - "xorm.io/xorm/internal/utils" ) // Rows rows wrapper a rows to type Rows struct { - session *Session - rows *core.Rows - beanType reflect.Type - lastError error + session *Session + rows *core.Rows + beanType reflect.Type } func newRows(session *Session, bean interface{}) (*Rows, error) { @@ -43,7 +40,7 @@ func newRows(session *Session, bean interface{}) (*Rows, error) { return nil, err } - if len(session.statement.TableName()) <= 0 { + if len(session.statement.TableName()) == 0 { return nil, ErrTableNotFound } @@ -62,15 +59,6 @@ func newRows(session *Session, bean interface{}) (*Rows, error) { // !oinume! Add " IS NULL" to WHERE whatever condiBean is given. // See https://gitea.com/xorm/xorm/issues/179 if col := table.DeletedColumn(); col != nil && !session.statement.GetUnscoped() { // tag "deleted" is enabled - var colName = session.engine.Quote(col.Name) - if addedTableName { - var nm = session.statement.TableName() - if len(session.statement.TableAlias) > 0 { - nm = session.statement.TableAlias - } - colName = session.engine.Quote(nm) + "." + colName - } - autoCond = session.statement.CondDeleted(col) } } @@ -86,7 +74,6 @@ func newRows(session *Session, bean interface{}) (*Rows, error) { rows.rows, err = rows.session.queryRows(sqlStr, args...) if err != nil { - rows.lastError = err rows.Close() return nil, err } @@ -96,48 +83,53 @@ func newRows(session *Session, bean interface{}) (*Rows, error) { // Next move cursor to next record, return false if end has reached func (rows *Rows) Next() bool { - if rows.lastError == nil && rows.rows != nil { - hasNext := rows.rows.Next() - if !hasNext { - rows.lastError = sql.ErrNoRows - } - return hasNext + if rows.rows != nil { + return rows.rows.Next() } return false } // Err returns the error, if any, that was encountered during iteration. Err may be called after an explicit or implicit Close. func (rows *Rows) Err() error { - return rows.lastError + if rows.rows != nil { + return rows.rows.Err() + } + return nil } // Scan row record to bean properties -func (rows *Rows) Scan(bean interface{}) error { - if rows.lastError != nil { - return rows.lastError +func (rows *Rows) Scan(beans ...interface{}) error { + if rows.Err() != nil { + return rows.Err() } - if reflect.Indirect(reflect.ValueOf(bean)).Type() != rows.beanType { - return fmt.Errorf("scan arg is incompatible type to [%v]", rows.beanType) + var bean = beans[0] + var tp = reflect.TypeOf(bean) + if tp.Kind() == reflect.Ptr { + tp = tp.Elem() } + var beanKind = tp.Kind() - if err := rows.session.statement.SetRefBean(bean); err != nil { - return err + if len(beans) == 1 { + if reflect.Indirect(reflect.ValueOf(bean)).Type() != rows.beanType { + return fmt.Errorf("scan arg is incompatible type to [%v]", rows.beanType) + } + + if err := rows.session.statement.SetRefBean(bean); err != nil { + return err + } } fields, err := rows.rows.Columns() if err != nil { return err } - - scanResults, err := rows.session.row2Slice(rows.rows, fields, bean) + types, err := rows.rows.ColumnTypes() if err != nil { return err } - dataStruct := utils.ReflectValue(bean) - _, err = rows.session.slice2Bean(scanResults, fields, bean, &dataStruct, rows.session.statement.RefTable) - if err != nil { + if err := rows.session.scan(rows.rows, rows.session.statement.RefTable, beanKind, beans, types, fields); err != nil { return err } @@ -154,5 +146,5 @@ func (rows *Rows) Close() error { return rows.rows.Close() } - return rows.lastError + return nil } diff --git a/scan.go b/scan.go new file mode 100644 index 00000000..00cee4d7 --- /dev/null +++ b/scan.go @@ -0,0 +1,439 @@ +// Copyright 2021 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import ( + "database/sql" + "fmt" + "math/big" + "reflect" + "time" + + "xorm.io/xorm/convert" + "xorm.io/xorm/core" + "xorm.io/xorm/dialects" + "xorm.io/xorm/schemas" +) + +// genScanResultsByBeanNullabale generates scan result +func genScanResultsByBeanNullable(bean interface{}) (interface{}, bool, error) { + switch t := bean.(type) { + case *interface{}: + return t, false, nil + case *sql.NullInt64, *sql.NullBool, *sql.NullFloat64, *sql.NullString, *sql.RawBytes, *[]byte: + return t, false, nil + case *time.Time: + return &sql.NullString{}, true, nil + case *sql.NullTime: + return &sql.NullString{}, true, nil + case *string: + return &sql.NullString{}, true, nil + case *int, *int8, *int16, *int32: + return &sql.NullInt32{}, true, nil + case *int64: + return &sql.NullInt64{}, true, nil + case *uint, *uint8, *uint16, *uint32: + return &convert.NullUint32{}, true, nil + case *uint64: + return &convert.NullUint64{}, true, nil + case *float32, *float64: + return &sql.NullFloat64{}, true, nil + case *bool: + return &sql.NullBool{}, true, nil + case sql.NullInt64, sql.NullBool, sql.NullFloat64, sql.NullString, + time.Time, + string, + int, int8, int16, int32, int64, + uint, uint8, uint16, uint32, uint64, + float32, float64, + bool: + return nil, false, fmt.Errorf("unsupported scan type: %t", t) + case convert.Conversion: + return &sql.RawBytes{}, true, nil + } + + tp := reflect.TypeOf(bean).Elem() + switch tp.Kind() { + case reflect.String: + return &sql.NullString{}, true, nil + case reflect.Int64: + return &sql.NullInt64{}, true, nil + case reflect.Int32, reflect.Int, reflect.Int16, reflect.Int8: + return &sql.NullInt32{}, true, nil + case reflect.Uint64: + return &convert.NullUint64{}, true, nil + case reflect.Uint32, reflect.Uint, reflect.Uint16, reflect.Uint8: + return &convert.NullUint32{}, true, nil + default: + return nil, false, fmt.Errorf("genScanResultsByBeanNullable: unsupported type: %#v", bean) + } +} + +func genScanResultsByBean(bean interface{}) (interface{}, bool, error) { + switch t := bean.(type) { + case *interface{}: + return t, false, nil + case *sql.NullInt64, *sql.NullBool, *sql.NullFloat64, *sql.NullString, + *sql.RawBytes, + *string, + *int, *int8, *int16, *int32, *int64, + *uint, *uint8, *uint16, *uint32, *uint64, + *float32, *float64, + *bool: + return t, false, nil + case *time.Time, *sql.NullTime: + return &sql.NullString{}, true, nil + case sql.NullInt64, sql.NullBool, sql.NullFloat64, sql.NullString, + time.Time, + string, + int, int8, int16, int32, int64, + uint, uint8, uint16, uint32, uint64, + bool: + return nil, false, fmt.Errorf("unsupported scan type: %t", t) + case convert.Conversion: + return &sql.RawBytes{}, true, nil + } + + tp := reflect.TypeOf(bean).Elem() + switch tp.Kind() { + case reflect.String: + return new(string), true, nil + case reflect.Int64: + return new(int64), true, nil + case reflect.Int32: + return new(int32), true, nil + case reflect.Int: + return new(int32), true, nil + case reflect.Int16: + return new(int32), true, nil + case reflect.Int8: + return new(int32), true, nil + case reflect.Uint64: + return new(uint64), true, nil + case reflect.Uint32: + return new(uint32), true, nil + case reflect.Uint: + return new(uint), true, nil + case reflect.Uint16: + return new(uint16), true, nil + case reflect.Uint8: + return new(uint8), true, nil + case reflect.Float32: + return new(float32), true, nil + case reflect.Float64: + return new(float64), true, nil + default: + return nil, false, fmt.Errorf("genScanResultsByBean: unsupported type: %#v", bean) + } +} + +func (engine *Engine) scanStringInterface(rows *core.Rows, fields []string, types []*sql.ColumnType) ([]interface{}, error) { + scanResults := make([]interface{}, len(types)) + for i := 0; i < len(types); i++ { + var s sql.NullString + scanResults[i] = &s + } + + if err := engine.scan(rows, fields, types, scanResults...); err != nil { + return nil, err + } + return scanResults, nil +} + +// scan is a wrap of driver.Scan but will automatically change the input values according requirements +func (engine *Engine) scan(rows *core.Rows, fields []string, types []*sql.ColumnType, vv ...interface{}) error { + scanResults := make([]interface{}, 0, len(types)) + replaces := make([]bool, 0, len(types)) + var err error + for _, v := range vv { + var replaced bool + var scanResult interface{} + switch t := v.(type) { + case *big.Float, *time.Time, *sql.NullTime: + scanResult = &sql.NullString{} + replaced = true + case sql.Scanner: + scanResult = t + case convert.Conversion: + scanResult = &sql.RawBytes{} + replaced = true + default: + nullable, ok := types[0].Nullable() + if !ok || nullable { + scanResult, replaced, err = genScanResultsByBeanNullable(v) + } else { + scanResult, replaced, err = genScanResultsByBean(v) + } + if err != nil { + return err + } + } + + scanResults = append(scanResults, scanResult) + replaces = append(replaces, replaced) + } + + if err = engine.driver.Scan(&dialects.ScanContext{ + DBLocation: engine.DatabaseTZ, + UserLocation: engine.TZLocation, + }, rows, types, scanResults...); err != nil { + return err + } + + for i, replaced := range replaces { + if replaced { + if err = convert.Assign(vv[i], scanResults[i], engine.DatabaseTZ, engine.TZLocation); err != nil { + return err + } + } + } + + return nil +} + +func (engine *Engine) scanInterfaces(rows *core.Rows, fields []string, types []*sql.ColumnType) ([]interface{}, error) { + scanResultContainers := make([]interface{}, len(types)) + for i := 0; i < len(types); i++ { + scanResult, err := engine.driver.GenScanResult(types[i].DatabaseTypeName()) + if err != nil { + return nil, err + } + scanResultContainers[i] = scanResult + } + if err := engine.scan(rows, fields, types, scanResultContainers...); err != nil { + return nil, err + } + return scanResultContainers, nil +} + +//////////////////// +// row -> map[string]interface{} + +func (engine *Engine) row2mapInterface(rows *core.Rows, types []*sql.ColumnType, fields []string) (map[string]interface{}, error) { + resultsMap := make(map[string]interface{}, len(fields)) + scanResultContainers := make([]interface{}, len(fields)) + for i := 0; i < len(fields); i++ { + scanResult, err := engine.driver.GenScanResult(types[i].DatabaseTypeName()) + if err != nil { + return nil, err + } + scanResultContainers[i] = scanResult + } + if err := engine.scan(rows, fields, types, scanResultContainers...); err != nil { + return nil, err + } + + for ii, key := range fields { + res, err := convert.Interface2Interface(engine.TZLocation, scanResultContainers[ii]) + if err != nil { + return nil, err + } + resultsMap[key] = res + } + return resultsMap, nil +} + +// ScanInterfaceMap scan result from *core.Rows and return a map +func (engine *Engine) ScanInterfaceMap(rows *core.Rows) (map[string]interface{}, error) { + fields, err := rows.Columns() + if err != nil { + return nil, err + } + types, err := rows.ColumnTypes() + if err != nil { + return nil, err + } + + return engine.row2mapInterface(rows, types, fields) +} + +// ScanInterfaceMaps scan results from *core.Rows and return a slice of map +func (engine *Engine) ScanInterfaceMaps(rows *core.Rows) (resultsSlice []map[string]interface{}, err error) { + fields, err := rows.Columns() + if err != nil { + return nil, err + } + types, err := rows.ColumnTypes() + if err != nil { + return nil, err + } + for rows.Next() { + result, err := engine.row2mapInterface(rows, types, fields) + if err != nil { + return nil, err + } + resultsSlice = append(resultsSlice, result) + } + if rows.Err() != nil { + return nil, rows.Err() + } + + return resultsSlice, nil +} + +//////////////////// +// row -> map[string]string + +func (engine *Engine) row2mapStr(rows *core.Rows, types []*sql.ColumnType, fields []string) (map[string]string, error) { + scanResults := make([]interface{}, len(fields)) + for i := 0; i < len(fields); i++ { + var s sql.NullString + scanResults[i] = &s + } + + if err := engine.driver.Scan(&dialects.ScanContext{ + DBLocation: engine.DatabaseTZ, + UserLocation: engine.TZLocation, + }, rows, types, scanResults...); err != nil { + return nil, err + } + + result := make(map[string]string, len(fields)) + for i, key := range fields { + s := scanResults[i].(*sql.NullString) + if s.String == "" { + result[key] = "" + continue + } + + if schemas.TIME_TYPE == engine.dialect.ColumnTypeKind(types[i].DatabaseTypeName()) { + t, err := convert.String2Time(s.String, engine.DatabaseTZ, engine.TZLocation) + if err != nil { + return nil, err + } + result[key] = t.Format("2006-01-02 15:04:05") + } else { + result[key] = s.String + } + } + return result, nil +} + +// ScanStringMap scan results from *core.Rows and return a map +func (engine *Engine) ScanStringMap(rows *core.Rows) (map[string]string, error) { + fields, err := rows.Columns() + if err != nil { + return nil, err + } + types, err := rows.ColumnTypes() + if err != nil { + return nil, err + } + return engine.row2mapStr(rows, types, fields) +} + +// ScanStringMaps scan results from *core.Rows and return a slice of map +func (engine *Engine) ScanStringMaps(rows *core.Rows) (resultsSlice []map[string]string, err error) { + fields, err := rows.Columns() + if err != nil { + return nil, err + } + types, err := rows.ColumnTypes() + if err != nil { + return nil, err + } + + for rows.Next() { + result, err := engine.row2mapStr(rows, types, fields) + if err != nil { + return nil, err + } + resultsSlice = append(resultsSlice, result) + } + if rows.Err() != nil { + return nil, rows.Err() + } + + return resultsSlice, nil +} + +//////////////////// +// row -> map[string][]byte + +func convertMapStr2Bytes(m map[string]string) map[string][]byte { + r := make(map[string][]byte, len(m)) + for k, v := range m { + r[k] = []byte(v) + } + return r +} + +func (engine *Engine) scanByteMaps(rows *core.Rows) (resultsSlice []map[string][]byte, err error) { + fields, err := rows.Columns() + if err != nil { + return nil, err + } + types, err := rows.ColumnTypes() + if err != nil { + return nil, err + } + for rows.Next() { + result, err := engine.row2mapStr(rows, types, fields) + if err != nil { + return nil, err + } + resultsSlice = append(resultsSlice, convertMapStr2Bytes(result)) + } + if rows.Err() != nil { + return nil, rows.Err() + } + + return resultsSlice, nil +} + +//////////////////// +// row -> []string + +func (engine *Engine) row2sliceStr(rows *core.Rows, types []*sql.ColumnType, fields []string) ([]string, error) { + scanResults, err := engine.scanStringInterface(rows, fields, types) + if err != nil { + return nil, err + } + + results := make([]string, 0, len(fields)) + for i := 0; i < len(fields); i++ { + results = append(results, scanResults[i].(*sql.NullString).String) + } + return results, nil +} + +// ScanStringSlice scan results from *core.Rows and return a slice of one row +func (engine *Engine) ScanStringSlice(rows *core.Rows) ([]string, error) { + fields, err := rows.Columns() + if err != nil { + return nil, err + } + types, err := rows.ColumnTypes() + if err != nil { + return nil, err + } + + return engine.row2sliceStr(rows, types, fields) +} + +// ScanStringSlices scan results from *core.Rows and return a slice of all rows +func (engine *Engine) ScanStringSlices(rows *core.Rows) (resultsSlice [][]string, err error) { + fields, err := rows.Columns() + if err != nil { + return nil, err + } + types, err := rows.ColumnTypes() + if err != nil { + return nil, err + } + + for rows.Next() { + record, err := engine.row2sliceStr(rows, types, fields) + if err != nil { + return nil, err + } + resultsSlice = append(resultsSlice, record) + } + if rows.Err() != nil { + return nil, rows.Err() + } + + return resultsSlice, nil +} diff --git a/schemas/column.go b/schemas/column.go index 4f32afab..001769cd 100644 --- a/schemas/column.go +++ b/schemas/column.go @@ -6,13 +6,12 @@ package schemas import ( "errors" - "fmt" "reflect" "strconv" - "strings" "time" ) +// enumerates all database mapping way const ( TWOSIDES = iota + 1 ONLYTODB @@ -23,11 +22,12 @@ const ( type Column struct { Name string TableName string - FieldName string // Avaiable only when parsed from a struct + FieldName string // Available only when parsed from a struct + FieldIndex []int // Available only when parsed from a struct SQLType SQLType IsJSON bool - Length int - Length2 int + Length int64 + Length2 int64 Nullable bool Default string Indexes map[string]int @@ -48,7 +48,7 @@ type Column struct { } // NewColumn creates a new column -func NewColumn(name, fieldName string, sqlType SQLType, len1, len2 int, nullable bool) *Column { +func NewColumn(name, fieldName string, sqlType SQLType, len1, len2 int64, nullable bool) *Column { return &Column{ Name: name, IsJSON: sqlType.IsJson(), @@ -82,41 +82,17 @@ func (col *Column) ValueOf(bean interface{}) (*reflect.Value, error) { // ValueOfV returns column's filed of struct's value accept reflevt value func (col *Column) ValueOfV(dataStruct *reflect.Value) (*reflect.Value, error) { - var fieldValue reflect.Value - fieldPath := strings.Split(col.FieldName, ".") - - if dataStruct.Type().Kind() == reflect.Map { - keyValue := reflect.ValueOf(fieldPath[len(fieldPath)-1]) - fieldValue = dataStruct.MapIndex(keyValue) - return &fieldValue, nil - } else if dataStruct.Type().Kind() == reflect.Interface { - structValue := reflect.ValueOf(dataStruct.Interface()) - dataStruct = &structValue - } - - level := len(fieldPath) - fieldValue = dataStruct.FieldByName(fieldPath[0]) - for i := 0; i < level-1; i++ { - if !fieldValue.IsValid() { - break - } - if fieldValue.Kind() == reflect.Struct { - fieldValue = fieldValue.FieldByName(fieldPath[i+1]) - } else if fieldValue.Kind() == reflect.Ptr { - if fieldValue.IsNil() { - fieldValue.Set(reflect.New(fieldValue.Type().Elem())) + v := *dataStruct + for _, i := range col.FieldIndex { + if v.Kind() == reflect.Ptr { + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) } - fieldValue = fieldValue.Elem().FieldByName(fieldPath[i+1]) - } else { - return nil, fmt.Errorf("field %v is not valid", col.FieldName) + v = v.Elem() } + v = v.FieldByIndex([]int{i}) } - - if !fieldValue.IsValid() { - return nil, fmt.Errorf("field %v is not valid", col.FieldName) - } - - return &fieldValue, nil + return &v, nil } // ConvertID converts id content to suitable type according column type diff --git a/schemas/index.go b/schemas/index.go index 9541250f..47027ea4 100644 --- a/schemas/index.go +++ b/schemas/index.go @@ -28,10 +28,11 @@ func NewIndex(name string, indexType int) *Index { return &Index{true, name, indexType, make([]string, 0)} } +// XName returns the special index name for the table func (index *Index) XName(tableName string) string { if !strings.HasPrefix(index.Name, "UQE_") && !strings.HasPrefix(index.Name, "IDX_") { - tableParts := strings.Split(strings.Replace(tableName, `"`, "", -1), ".") + tableParts := strings.Split(strings.ReplaceAll(tableName, `"`, ""), ".") tableName = tableParts[len(tableParts)-1] if index.Type == UniqueType { return fmt.Sprintf("UQE_%v_%v", tableName, index.Name) @@ -43,11 +44,10 @@ func (index *Index) XName(tableName string) string { // AddColumn add columns which will be composite index func (index *Index) AddColumn(cols ...string) { - for _, col := range cols { - index.Cols = append(index.Cols, col) - } + index.Cols = append(index.Cols, cols...) } +// Equal return true if the two Index is equal func (index *Index) Equal(dst *Index) bool { if index.Type != dst.Type { return false diff --git a/schemas/pk.go b/schemas/pk.go index 03916b44..da3c7899 100644 --- a/schemas/pk.go +++ b/schemas/pk.go @@ -11,13 +11,16 @@ import ( "xorm.io/xorm/internal/utils" ) +// PK represents primary key values type PK []interface{} +// NewPK creates primay keys func NewPK(pks ...interface{}) *PK { p := PK(pks) return &p } +// IsZero return true if primay keys are zero func (p *PK) IsZero() bool { for _, k := range *p { if utils.IsZero(k) { @@ -27,6 +30,7 @@ func (p *PK) IsZero() bool { return false } +// ToString convert to SQL string func (p *PK) ToString() (string, error) { buf := new(bytes.Buffer) enc := gob.NewEncoder(buf) @@ -34,6 +38,7 @@ func (p *PK) ToString() (string, error) { return buf.String(), err } +// FromString reads content to load primary keys func (p *PK) FromString(content string) error { dec := gob.NewDecoder(bytes.NewBufferString(content)) err := dec.Decode(p) diff --git a/schemas/quote.go b/schemas/quote.go index a0070048..6df7bf0b 100644 --- a/schemas/quote.go +++ b/schemas/quote.go @@ -16,10 +16,10 @@ type Quoter struct { } var ( - // AlwaysFalseReverse always think it's not a reverse word + // AlwaysNoReserve always think it's not a reverse word AlwaysNoReserve = func(string) bool { return false } - // AlwaysReverse always reverse the word + // AlwaysReserve always reverse the word AlwaysReserve = func(string) bool { return true } // CommanQuoteMark represnets the common quote mark @@ -29,13 +29,15 @@ var ( CommonQuoter = Quoter{CommanQuoteMark, CommanQuoteMark, AlwaysReserve} ) +// IsEmpty return true if no prefix and suffix func (q Quoter) IsEmpty() bool { return q.Prefix == 0 && q.Suffix == 0 } +// Quote quote a string func (q Quoter) Quote(s string) string { var buf strings.Builder - q.QuoteTo(&buf, s) + _ = q.QuoteTo(&buf, s) return buf.String() } @@ -59,12 +61,14 @@ func (q Quoter) Trim(s string) string { return buf.String() } +// Join joins a slice with quoters func (q Quoter) Join(a []string, sep string) string { var b strings.Builder - q.JoinWrite(&b, a, sep) + _ = q.JoinWrite(&b, a, sep) return b.String() } +// JoinWrite writes quoted content to a builder func (q Quoter) JoinWrite(b *strings.Builder, a []string, sep string) error { if len(a) == 0 { return nil @@ -82,7 +86,9 @@ func (q Quoter) JoinWrite(b *strings.Builder, a []string, sep string) error { return err } } - q.QuoteTo(b, strings.TrimSpace(s)) + if err := q.QuoteTo(b, strings.TrimSpace(s)); err != nil { + return err + } } return nil } @@ -117,7 +123,7 @@ func findStart(value string, start int) int { } if (value[k] == 'A' || value[k] == 'a') && (value[k+1] == 'S' || value[k+1] == 's') { - k = k + 2 + k += 2 } for j := k; j < len(value); j++ { @@ -157,17 +163,18 @@ func (q Quoter) quoteWordTo(buf *strings.Builder, word string) error { } // QuoteTo quotes the table or column names. i.e. if the quotes are [ and ] -// name -> [name] -// `name` -> [name] -// [name] -> [name] -// schema.name -> [schema].[name] -// `schema`.`name` -> [schema].[name] -// `schema`.name -> [schema].[name] -// schema.`name` -> [schema].[name] -// [schema].name -> [schema].[name] -// schema.[name] -> [schema].[name] -// name AS a -> [name] AS a -// schema.name AS a -> [schema].[name] AS a +// +// name -> [name] +// `name` -> [name] +// [name] -> [name] +// schema.name -> [schema].[name] +// `schema`.`name` -> [schema].[name] +// `schema`.name -> [schema].[name] +// schema.`name` -> [schema].[name] +// [schema].name -> [schema].[name] +// schema.[name] -> [schema].[name] +// name AS a -> [name] AS a +// schema.name AS a -> [schema].[name] AS a func (q Quoter) QuoteTo(buf *strings.Builder, value string) error { var i int for i < len(value) { diff --git a/schemas/quote_test.go b/schemas/quote_test.go index 8e351dc0..f84dfb7d 100644 --- a/schemas/quote_test.go +++ b/schemas/quote_test.go @@ -45,7 +45,8 @@ func TestAlwaysQuoteTo(t *testing.T) { for _, v := range kases { t.Run(v.value, func(t *testing.T) { buf := &strings.Builder{} - quoter.QuoteTo(buf, v.value) + err := quoter.QuoteTo(buf, v.value) + assert.NoError(t, err) assert.EqualValues(t, v.expected, buf.String()) }) } @@ -54,10 +55,7 @@ func TestAlwaysQuoteTo(t *testing.T) { func TestReversedQuoteTo(t *testing.T) { var ( quoter = Quoter{'[', ']', func(s string) bool { - if s == "mytable" { - return true - } - return false + return s == "mytable" }} kases = []struct { expected string @@ -118,7 +116,8 @@ func TestNoQuoteTo(t *testing.T) { for _, v := range kases { t.Run(v.value, func(t *testing.T) { buf := &strings.Builder{} - quoter.QuoteTo(buf, v.value) + err := quoter.QuoteTo(buf, v.value) + assert.NoError(t, err) assert.EqualValues(t, v.expected, buf.String()) }) } @@ -176,6 +175,10 @@ func TestReplace(t *testing.T) { "UPDATE table SET `a` = ~ `a`, `b`='abc`'", "UPDATE table SET [a] = ~ [a], [b]='abc`'", }, + { + "INSERT INTO `insert_where` (`height`,`name`,`repo_id`,`width`,`index`) SELECT $1,$2,$3,$4,coalesce(MAX(`index`),0)+1 FROM `insert_where` WHERE (`repo_id`=$5)", + "INSERT INTO [insert_where] ([height],[name],[repo_id],[width],[index]) SELECT $1,$2,$3,$4,coalesce(MAX([index]),0)+1 FROM [insert_where] WHERE ([repo_id]=$5)", + }, } for _, kase := range kases { diff --git a/schemas/table.go b/schemas/table.go index 7ca9531f..91b33e06 100644 --- a/schemas/table.go +++ b/schemas/table.go @@ -5,7 +5,6 @@ package schemas import ( - "fmt" "reflect" "strconv" "strings" @@ -90,23 +89,28 @@ func (table *Table) PKColumns() []*Column { return columns } +// ColumnType returns a column's type func (table *Table) ColumnType(name string) reflect.Type { t, _ := table.Type.FieldByName(name) return t.Type } +// AutoIncrColumn returns autoincrement column func (table *Table) AutoIncrColumn() *Column { return table.GetColumn(table.AutoIncrement) } +// VersionColumn returns version column's information func (table *Table) VersionColumn() *Column { return table.GetColumn(table.Version) } +// UpdatedColumn returns updated column's information func (table *Table) UpdatedColumn() *Column { return table.GetColumn(table.Updated) } +// DeletedColumn returns deleted column's information func (table *Table) DeletedColumn() *Column { return table.GetColumn(table.Deleted) } @@ -154,24 +158,8 @@ func (table *Table) IDOfV(rv reflect.Value) (PK, error) { for i, col := range table.PKColumns() { var err error - fieldName := col.FieldName - for { - parts := strings.SplitN(fieldName, ".", 2) - if len(parts) == 1 { - break - } + pkField := v.FieldByIndex(col.FieldIndex) - v = v.FieldByName(parts[0]) - if v.Kind() == reflect.Ptr { - v = v.Elem() - } - if v.Kind() != reflect.Struct { - return nil, fmt.Errorf("Unsupported read value of column %s from field %s", col.Name, col.FieldName) - } - fieldName = parts[1] - } - - pkField := v.FieldByName(fieldName) switch pkField.Kind() { case reflect.String: pk[i], err = col.ConvertID(pkField.String()) diff --git a/schemas/table_test.go b/schemas/table_test.go index 9bf10e33..f352675b 100644 --- a/schemas/table_test.go +++ b/schemas/table_test.go @@ -27,7 +27,6 @@ var testsGetColumn = []struct { var table *Table func init() { - table = NewEmptyTable() var name string @@ -41,7 +40,6 @@ func init() { } func TestGetColumn(t *testing.T) { - for _, test := range testsGetColumn { if table.GetColumn(test.name) == nil { t.Error("Column not found!") @@ -50,7 +48,6 @@ func TestGetColumn(t *testing.T) { } func TestGetColumnIdx(t *testing.T) { - for _, test := range testsGetColumn { if table.GetColumnIdx(test.name, test.idx) == nil { t.Errorf("Column %s with idx %d not found!", test.name, test.idx) @@ -59,10 +56,8 @@ func TestGetColumnIdx(t *testing.T) { } func BenchmarkGetColumnWithToLower(b *testing.B) { - for i := 0; i < b.N; i++ { for _, test := range testsGetColumn { - if _, ok := table.columnsMap[strings.ToLower(test.name)]; !ok { b.Errorf("Column not found:%s", test.name) } @@ -71,10 +66,8 @@ func BenchmarkGetColumnWithToLower(b *testing.B) { } func BenchmarkGetColumnIdxWithToLower(b *testing.B) { - for i := 0; i < b.N; i++ { for _, test := range testsGetColumn { - if c, ok := table.columnsMap[strings.ToLower(test.name)]; ok { if test.idx < len(c) { continue @@ -89,7 +82,6 @@ func BenchmarkGetColumnIdxWithToLower(b *testing.B) { } func BenchmarkGetColumn(b *testing.B) { - for i := 0; i < b.N; i++ { for _, test := range testsGetColumn { if table.GetColumn(test.name) == nil { @@ -100,7 +92,6 @@ func BenchmarkGetColumn(b *testing.B) { } func BenchmarkGetColumnIdx(b *testing.B) { - for i := 0; i < b.N; i++ { for _, test := range testsGetColumn { if table.GetColumnIdx(test.name, test.idx) == nil { diff --git a/schemas/type.go b/schemas/type.go index 89459a4d..b8b30851 100644 --- a/schemas/type.go +++ b/schemas/type.go @@ -5,29 +5,34 @@ package schemas import ( + "database/sql" + "math/big" "reflect" - "sort" "strings" "time" ) +// DBType represents a database type type DBType string +// enumerates all database types const ( POSTGRES DBType = "postgres" SQLITE DBType = "sqlite3" MYSQL DBType = "mysql" MSSQL DBType = "mssql" ORACLE DBType = "oracle" + DAMENG DBType = "dameng" ) // SQLType represents SQL types type SQLType struct { Name string - DefaultLength int - DefaultLength2 int + DefaultLength int64 + DefaultLength2 int64 } +// enumerates all columns types const ( UNKNOW_TYPE = iota TEXT_TYPE @@ -35,8 +40,10 @@ const ( TIME_TYPE NUMERIC_TYPE ARRAY_TYPE + BOOL_TYPE ) +// IsType reutrns ture if the column type is the same as the parameter func (s *SQLType) IsType(st int) bool { if t, ok := SqlTypes[s.Name]; ok && t == st { return true @@ -44,44 +51,69 @@ func (s *SQLType) IsType(st int) bool { return false } +// IsText returns true if column is a text type func (s *SQLType) IsText() bool { return s.IsType(TEXT_TYPE) } +// IsBlob returns true if column is a binary type func (s *SQLType) IsBlob() bool { return s.IsType(BLOB_TYPE) } +// IsTime returns true if column is a time type func (s *SQLType) IsTime() bool { return s.IsType(TIME_TYPE) } +// IsBool returns true if column is a boolean type +func (s *SQLType) IsBool() bool { + return s.IsType(BOOL_TYPE) +} + +// IsNumeric returns true if column is a numeric type func (s *SQLType) IsNumeric() bool { return s.IsType(NUMERIC_TYPE) } +// IsArray returns true if column is an array type func (s *SQLType) IsArray() bool { return s.IsType(ARRAY_TYPE) } +// IsJson returns true if column is an array type func (s *SQLType) IsJson() bool { return s.Name == Json || s.Name == Jsonb } +// IsXML returns true if column is an xml type +func (s *SQLType) IsXML() bool { + return s.Name == XML +} + +// enumerates all the database column types var ( - Bit = "BIT" - TinyInt = "TINYINT" - SmallInt = "SMALLINT" - MediumInt = "MEDIUMINT" - Int = "INT" - Integer = "INTEGER" - BigInt = "BIGINT" + Bit = "BIT" + UnsignedBit = "UNSIGNED BIT" + TinyInt = "TINYINT" + UnsignedTinyInt = "UNSIGNED TINYINT" + SmallInt = "SMALLINT" + UnsignedSmallInt = "UNSIGNED SMALLINT" + MediumInt = "MEDIUMINT" + UnsignedMediumInt = "UNSIGNED MEDIUMINT" + Int = "INT" + UnsignedInt = "UNSIGNED INT" + Integer = "INTEGER" + BigInt = "BIGINT" + UnsignedBigInt = "UNSIGNED BIGINT" + Number = "NUMBER" Enum = "ENUM" Set = "SET" Char = "CHAR" Varchar = "VARCHAR" + VARCHAR2 = "VARCHAR2" NChar = "NCHAR" NVarchar = "NVARCHAR" TinyText = "TINYTEXT" @@ -128,25 +160,36 @@ var ( Json = "JSON" Jsonb = "JSONB" + XML = "XML" Array = "ARRAY" SqlTypes = map[string]int{ - Bit: NUMERIC_TYPE, - TinyInt: NUMERIC_TYPE, - SmallInt: NUMERIC_TYPE, - MediumInt: NUMERIC_TYPE, - Int: NUMERIC_TYPE, - Integer: NUMERIC_TYPE, - BigInt: NUMERIC_TYPE, + Bit: NUMERIC_TYPE, + UnsignedBit: NUMERIC_TYPE, + TinyInt: NUMERIC_TYPE, + UnsignedTinyInt: NUMERIC_TYPE, + SmallInt: NUMERIC_TYPE, + UnsignedSmallInt: NUMERIC_TYPE, + MediumInt: NUMERIC_TYPE, + UnsignedMediumInt: NUMERIC_TYPE, + Int: NUMERIC_TYPE, + UnsignedInt: NUMERIC_TYPE, + Integer: NUMERIC_TYPE, + BigInt: NUMERIC_TYPE, + UnsignedBigInt: NUMERIC_TYPE, + Number: NUMERIC_TYPE, Enum: TEXT_TYPE, Set: TEXT_TYPE, Json: TEXT_TYPE, Jsonb: TEXT_TYPE, + XML: TEXT_TYPE, + Char: TEXT_TYPE, NChar: TEXT_TYPE, Varchar: TEXT_TYPE, + VARCHAR2: TEXT_TYPE, NVarchar: TEXT_TYPE, TinyText: TEXT_TYPE, Text: TEXT_TYPE, @@ -183,100 +226,63 @@ var ( Bytea: BLOB_TYPE, UniqueIdentifier: BLOB_TYPE, - Bool: NUMERIC_TYPE, + Bool: BOOL_TYPE, + Boolean: BOOL_TYPE, Serial: NUMERIC_TYPE, BigSerial: NUMERIC_TYPE, + "INT8": NUMERIC_TYPE, + Array: ARRAY_TYPE, } - - intTypes = sort.StringSlice{"*int", "*int16", "*int32", "*int8"} - uintTypes = sort.StringSlice{"*uint", "*uint16", "*uint32", "*uint8"} ) -// !nashtsai! treat following var as interal const values, these are used for reflect.TypeOf comparison +// enumerates all types var ( - c_EMPTY_STRING string - c_BOOL_DEFAULT bool - c_BYTE_DEFAULT byte - c_COMPLEX64_DEFAULT complex64 - c_COMPLEX128_DEFAULT complex128 - c_FLOAT32_DEFAULT float32 - c_FLOAT64_DEFAULT float64 - c_INT64_DEFAULT int64 - c_UINT64_DEFAULT uint64 - c_INT32_DEFAULT int32 - c_UINT32_DEFAULT uint32 - c_INT16_DEFAULT int16 - c_UINT16_DEFAULT uint16 - c_INT8_DEFAULT int8 - c_UINT8_DEFAULT uint8 - c_INT_DEFAULT int - c_UINT_DEFAULT uint - c_TIME_DEFAULT time.Time -) + IntType = reflect.TypeOf((*int)(nil)).Elem() + Int8Type = reflect.TypeOf((*int8)(nil)).Elem() + Int16Type = reflect.TypeOf((*int16)(nil)).Elem() + Int32Type = reflect.TypeOf((*int32)(nil)).Elem() + Int64Type = reflect.TypeOf((*int64)(nil)).Elem() -var ( - IntType = reflect.TypeOf(c_INT_DEFAULT) - Int8Type = reflect.TypeOf(c_INT8_DEFAULT) - Int16Type = reflect.TypeOf(c_INT16_DEFAULT) - Int32Type = reflect.TypeOf(c_INT32_DEFAULT) - Int64Type = reflect.TypeOf(c_INT64_DEFAULT) + UintType = reflect.TypeOf((*uint)(nil)).Elem() + Uint8Type = reflect.TypeOf((*uint8)(nil)).Elem() + Uint16Type = reflect.TypeOf((*uint16)(nil)).Elem() + Uint32Type = reflect.TypeOf((*uint32)(nil)).Elem() + Uint64Type = reflect.TypeOf((*uint64)(nil)).Elem() - UintType = reflect.TypeOf(c_UINT_DEFAULT) - Uint8Type = reflect.TypeOf(c_UINT8_DEFAULT) - Uint16Type = reflect.TypeOf(c_UINT16_DEFAULT) - Uint32Type = reflect.TypeOf(c_UINT32_DEFAULT) - Uint64Type = reflect.TypeOf(c_UINT64_DEFAULT) + Float32Type = reflect.TypeOf((*float32)(nil)).Elem() + Float64Type = reflect.TypeOf((*float64)(nil)).Elem() - Float32Type = reflect.TypeOf(c_FLOAT32_DEFAULT) - Float64Type = reflect.TypeOf(c_FLOAT64_DEFAULT) + Complex64Type = reflect.TypeOf((*complex64)(nil)).Elem() + Complex128Type = reflect.TypeOf((*complex128)(nil)).Elem() - Complex64Type = reflect.TypeOf(c_COMPLEX64_DEFAULT) - Complex128Type = reflect.TypeOf(c_COMPLEX128_DEFAULT) - - StringType = reflect.TypeOf(c_EMPTY_STRING) - BoolType = reflect.TypeOf(c_BOOL_DEFAULT) - ByteType = reflect.TypeOf(c_BYTE_DEFAULT) + StringType = reflect.TypeOf((*string)(nil)).Elem() + BoolType = reflect.TypeOf((*bool)(nil)).Elem() + ByteType = reflect.TypeOf((*byte)(nil)).Elem() BytesType = reflect.SliceOf(ByteType) - TimeType = reflect.TypeOf(c_TIME_DEFAULT) -) - -var ( - PtrIntType = reflect.PtrTo(IntType) - PtrInt8Type = reflect.PtrTo(Int8Type) - PtrInt16Type = reflect.PtrTo(Int16Type) - PtrInt32Type = reflect.PtrTo(Int32Type) - PtrInt64Type = reflect.PtrTo(Int64Type) - - PtrUintType = reflect.PtrTo(UintType) - PtrUint8Type = reflect.PtrTo(Uint8Type) - PtrUint16Type = reflect.PtrTo(Uint16Type) - PtrUint32Type = reflect.PtrTo(Uint32Type) - PtrUint64Type = reflect.PtrTo(Uint64Type) - - PtrFloat32Type = reflect.PtrTo(Float32Type) - PtrFloat64Type = reflect.PtrTo(Float64Type) - - PtrComplex64Type = reflect.PtrTo(Complex64Type) - PtrComplex128Type = reflect.PtrTo(Complex128Type) - - PtrStringType = reflect.PtrTo(StringType) - PtrBoolType = reflect.PtrTo(BoolType) - PtrByteType = reflect.PtrTo(ByteType) - - PtrTimeType = reflect.PtrTo(TimeType) + TimeType = reflect.TypeOf((*time.Time)(nil)).Elem() + BigFloatType = reflect.TypeOf((*big.Float)(nil)).Elem() + NullFloat64Type = reflect.TypeOf((*sql.NullFloat64)(nil)).Elem() + NullStringType = reflect.TypeOf((*sql.NullString)(nil)).Elem() + NullInt32Type = reflect.TypeOf((*sql.NullInt32)(nil)).Elem() + NullInt64Type = reflect.TypeOf((*sql.NullInt64)(nil)).Elem() + NullBoolType = reflect.TypeOf((*sql.NullBool)(nil)).Elem() ) // Type2SQLType generate SQLType acorrding Go's type func Type2SQLType(t reflect.Type) (st SQLType) { switch k := t.Kind(); k { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32: + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32: st = SQLType{Int, 0, 0} - case reflect.Int64, reflect.Uint64: + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32: + st = SQLType{UnsignedInt, 0, 0} + case reflect.Int64: st = SQLType{BigInt, 0, 0} + case reflect.Uint64: + st = SQLType{UnsignedBigInt, 0, 0} case reflect.Float32: st = SQLType{Float, 0, 0} case reflect.Float64: @@ -284,7 +290,7 @@ func Type2SQLType(t reflect.Type) (st SQLType) { case reflect.Complex64, reflect.Complex128: st = SQLType{Varchar, 64, 0} case reflect.Array, reflect.Slice, reflect.Map: - if t.Elem() == reflect.TypeOf(c_BYTE_DEFAULT) { + if t.Elem() == ByteType { st = SQLType{Blob, 0, 0} } else { st = SQLType{Text, 0, 0} @@ -296,6 +302,16 @@ func Type2SQLType(t reflect.Type) (st SQLType) { case reflect.Struct: if t.ConvertibleTo(TimeType) { st = SQLType{DateTime, 0, 0} + } else if t.ConvertibleTo(NullFloat64Type) { + st = SQLType{Double, 0, 0} + } else if t.ConvertibleTo(NullStringType) { + st = SQLType{Varchar, 255, 0} + } else if t.ConvertibleTo(NullInt32Type) { + st = SQLType{Integer, 0, 0} + } else if t.ConvertibleTo(NullInt64Type) { + st = SQLType{BigInt, 0, 0} + } else if t.ConvertibleTo(NullBoolType) { + st = SQLType{Boolean, 0, 0} } else { // TODO need to handle association struct st = SQLType{Text, 0, 0} @@ -308,29 +324,39 @@ func Type2SQLType(t reflect.Type) (st SQLType) { return } -// default sql type change to go types +// SQLType2Type convert default sql type change to go types func SQLType2Type(st SQLType) reflect.Type { name := strings.ToUpper(st.Name) switch name { case Bit, TinyInt, SmallInt, MediumInt, Int, Integer, Serial: - return reflect.TypeOf(1) + return IntType case BigInt, BigSerial: - return reflect.TypeOf(int64(1)) + return Int64Type + case UnsignedBit, UnsignedTinyInt, UnsignedSmallInt, UnsignedMediumInt, UnsignedInt: + return UintType + case UnsignedBigInt: + return Uint64Type case Float, Real: - return reflect.TypeOf(float32(1)) + return Float32Type case Double: - return reflect.TypeOf(float64(1)) + return Float64Type case Char, NChar, Varchar, NVarchar, TinyText, Text, NText, MediumText, LongText, Enum, Set, Uuid, Clob, SysName: - return reflect.TypeOf("") + return StringType case TinyBlob, Blob, LongBlob, Bytea, Binary, MediumBlob, VarBinary, UniqueIdentifier: - return reflect.TypeOf([]byte{}) + return BytesType case Bool: - return reflect.TypeOf(true) + return BoolType case DateTime, Date, Time, TimeStamp, TimeStampz, SmallDateTime, Year: - return reflect.TypeOf(c_TIME_DEFAULT) + return TimeType case Decimal, Numeric, Money, SmallMoney: - return reflect.TypeOf("") + return StringType default: - return reflect.TypeOf("") + return StringType } } + +// SQLTypeName returns sql type name +func SQLTypeName(tp string) string { + fields := strings.Split(tp, "(") + return fields[0] +} diff --git a/schemas/version.go b/schemas/version.go new file mode 100644 index 00000000..ba789679 --- /dev/null +++ b/schemas/version.go @@ -0,0 +1,12 @@ +// Copyright 2021 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package schemas + +// Version represents a database version +type Version struct { + Number string // the version number which could be compared + Level string + Edition string +} diff --git a/session.go b/session.go index 17abd453..e1a16e5b 100644 --- a/session.go +++ b/session.go @@ -15,8 +15,8 @@ import ( "hash/crc32" "io" "reflect" + "strconv" "strings" - "time" "xorm.io/xorm/contexts" "xorm.io/xorm/convert" @@ -34,7 +34,7 @@ type ErrFieldIsNotExist struct { } func (e ErrFieldIsNotExist) Error() string { - return fmt.Sprintf("field %s is not valid on table %s", e.FieldName, e.TableName) + return fmt.Sprintf("field %s is not exist on table %s", e.FieldName, e.TableName) } // ErrFieldIsNotValid is not valid @@ -79,7 +79,8 @@ type Session struct { afterClosures []func(interface{}) afterProcessors []executedProcessor - stmtCache map[uint32]*core.Stmt //key: hash.Hash32 of (queryStr, len(queryStr)) + stmtCache map[uint32]*core.Stmt // key: hash.Hash32 of (queryStr, len(queryStr)) + txStmtCache map[uint32]*core.Stmt // for tx statement lastSQL string lastSQLArgs []interface{} @@ -123,13 +124,14 @@ func newSession(engine *Engine) *Session { autoResetStatement: true, prepareStmt: false, - afterInsertBeans: make(map[interface{}]*[]func(interface{}), 0), - afterUpdateBeans: make(map[interface{}]*[]func(interface{}), 0), - afterDeleteBeans: make(map[interface{}]*[]func(interface{}), 0), + afterInsertBeans: make(map[interface{}]*[]func(interface{})), + afterUpdateBeans: make(map[interface{}]*[]func(interface{})), + afterDeleteBeans: make(map[interface{}]*[]func(interface{})), beforeClosures: make([]func(interface{}), 0), afterClosures: make([]func(interface{}), 0), afterProcessors: make([]executedProcessor, 0), stmtCache: make(map[uint32]*core.Stmt), + txStmtCache: make(map[uint32]*core.Stmt), lastSQL: "", lastSQLArgs: make([]interface{}, 0), @@ -150,6 +152,12 @@ func (session *Session) Close() error { } } + for _, v := range session.txStmtCache { + if err := v.Close(); err != nil { + return err + } + } + if !session.isClosed { // When Close be called, if session is a transaction and do not call // Commit or Rollback, then call Rollback. @@ -160,6 +168,7 @@ func (session *Session) Close() error { } session.tx = nil session.stmtCache = nil + session.txStmtCache = nil session.isClosed = true } return nil @@ -169,10 +178,16 @@ func (session *Session) db() *core.DB { return session.engine.db } +// Engine returns session Engine func (session *Session) Engine() *Engine { return session.engine } +// Tx returns session tx +func (session *Session) Tx() *core.Tx { + return session.tx +} + func (session *Session) getQueryer() core.Queryer { if session.tx != nil { return session.tx @@ -194,6 +209,7 @@ func (session *Session) IsClosed() bool { func (session *Session) resetStatement() { if session.autoResetStatement { session.statement.Reset() + session.prepareStmt = false } } @@ -259,8 +275,8 @@ func (session *Session) Limit(limit int, start ...int) *Session { // OrderBy provide order by query condition, the input parameter is the content // after order by on a sql statement. -func (session *Session) OrderBy(order string) *Session { - session.statement.OrderBy(order) +func (session *Session) OrderBy(order interface{}, args ...interface{}) *Session { + session.statement.OrderBy(order, args...) return session } @@ -298,7 +314,7 @@ func (session *Session) Cascade(trueOrFalse ...bool) *Session { // MustLogSQL means record SQL or not and don't follow engine's setting func (session *Session) MustLogSQL(logs ...bool) *Session { - var showSQL = true + showSQL := true if len(logs) > 0 { showSQL = logs[0] } @@ -314,7 +330,7 @@ func (session *Session) NoCache() *Session { } // Join join_operator should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN -func (session *Session) Join(joinOperator string, tablename interface{}, condition string, args ...interface{}) *Session { +func (session *Session) Join(joinOperator string, tablename interface{}, condition interface{}, args ...interface{}) *Session { session.statement.Join(joinOperator, tablename, condition, args...) return session } @@ -364,37 +380,55 @@ func (session *Session) doPrepare(db *core.DB, sqlStr string) (stmt *core.Stmt, return } -func (session *Session) getField(dataStruct *reflect.Value, key string, table *schemas.Table, idx int) (*reflect.Value, error) { - var col *schemas.Column - if col = table.GetColumnIdx(key, idx); col == nil { - return nil, ErrFieldIsNotExist{key, table.Name} +func (session *Session) doPrepareTx(sqlStr string) (stmt *core.Stmt, err error) { + crc := crc32.ChecksumIEEE([]byte(sqlStr)) + // TODO try hash(sqlStr+len(sqlStr)) + var has bool + stmt, has = session.txStmtCache[crc] + if !has { + stmt, err = session.tx.PrepareContext(session.ctx, sqlStr) + if err != nil { + return nil, err + } + session.txStmtCache[crc] = stmt + } + return +} + +func getField(dataStruct *reflect.Value, table *schemas.Table, colName string, idx int) (*schemas.Column, *reflect.Value, error) { + col := table.GetColumnIdx(colName, idx) + if col == nil { + return nil, nil, ErrFieldIsNotExist{colName, table.Name} } fieldValue, err := col.ValueOfV(dataStruct) if err != nil { - return nil, err + return nil, nil, err + } + if fieldValue == nil { + return nil, nil, ErrFieldIsNotValid{colName, table.Name} } - if !fieldValue.IsValid() || !fieldValue.CanSet() { - return nil, ErrFieldIsNotValid{key, table.Name} + return nil, nil, ErrFieldIsNotValid{colName, table.Name} } - return fieldValue, nil + return col, fieldValue, nil } // Cell cell is a result of one column field type Cell *interface{} -func (session *Session) rows2Beans(rows *core.Rows, fields []string, +func (session *Session) rows2Beans(rows *core.Rows, fields []string, types []*sql.ColumnType, table *schemas.Table, newElemFunc func([]string) reflect.Value, - sliceValueSetFunc func(*reflect.Value, schemas.PK) error) error { + sliceValueSetFunc func(*reflect.Value, schemas.PK) error, +) error { for rows.Next() { - var newValue = newElemFunc(fields) + newValue := newElemFunc(fields) bean := newValue.Interface() dataStruct := newValue.Elem() // handle beforeClosures - scanResults, err := session.row2Slice(rows, fields, bean) + scanResults, err := session.row2Slice(rows, fields, types, bean) if err != nil { return err } @@ -410,10 +444,10 @@ func (session *Session) rows2Beans(rows *core.Rows, fields []string, bean: bean, }) } - return nil + return rows.Err() } -func (session *Session) row2Slice(rows *core.Rows, fields []string, bean interface{}) ([]interface{}, error) { +func (session *Session) row2Slice(rows *core.Rows, fields []string, types []*sql.ColumnType, bean interface{}) ([]interface{}, error) { for _, closure := range session.beforeClosures { closure(bean) } @@ -423,7 +457,7 @@ func (session *Session) row2Slice(rows *core.Rows, fields []string, bean interfa var cell interface{} scanResults[i] = &cell } - if err := rows.Scan(scanResults...); err != nil { + if err := session.engine.scan(rows, fields, types, scanResults...); err != nil { return nil, err } @@ -432,6 +466,245 @@ func (session *Session) row2Slice(rows *core.Rows, fields []string, bean interfa return scanResults, nil } +func setJSON(fieldValue *reflect.Value, fieldType reflect.Type, scanResult interface{}) error { + bs, ok := convert.AsBytes(scanResult) + if !ok { + return fmt.Errorf("unsupported database data type: %#v", scanResult) + } + if len(bs) == 0 { + return nil + } + + if fieldType.Kind() == reflect.String { + fieldValue.SetString(string(bs)) + return nil + } + + if fieldValue.CanAddr() { + err := json.DefaultJSONHandler.Unmarshal(bs, fieldValue.Addr().Interface()) + if err != nil { + return err + } + } else { + x := reflect.New(fieldType) + err := json.DefaultJSONHandler.Unmarshal(bs, x.Interface()) + if err != nil { + return err + } + fieldValue.Set(x.Elem()) + } + return nil +} + +func asKind(vv reflect.Value, tp reflect.Type) (interface{}, error) { + switch tp.Kind() { + case reflect.Ptr: + return asKind(vv.Elem(), tp.Elem()) + case reflect.Int64: + return vv.Int(), nil + case reflect.Int: + return int(vv.Int()), nil + case reflect.Int32: + return int32(vv.Int()), nil + case reflect.Int16: + return int16(vv.Int()), nil + case reflect.Int8: + return int8(vv.Int()), nil + case reflect.Uint64: + return vv.Uint(), nil + case reflect.Uint: + return uint(vv.Uint()), nil + case reflect.Uint32: + return uint32(vv.Uint()), nil + case reflect.Uint16: + return uint16(vv.Uint()), nil + case reflect.Uint8: + return uint8(vv.Uint()), nil + case reflect.String: + return vv.String(), nil + case reflect.Slice: + if tp.Elem().Kind() == reflect.Uint8 { + v, err := strconv.ParseInt(string(vv.Interface().([]byte)), 10, 64) + if err != nil { + return nil, err + } + return v, nil + } + } + return nil, fmt.Errorf("unsupported primary key type: %v, %v", tp, vv) +} + +var uint8ZeroValue = reflect.ValueOf(uint8(0)) + +func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflect.Value, + scanResult interface{}, table *schemas.Table, +) error { + v, ok := scanResult.(*interface{}) + if ok { + scanResult = *v + } + if scanResult == nil { + return nil + } + + if fieldValue.CanAddr() { + if structConvert, ok := fieldValue.Addr().Interface().(convert.Conversion); ok { + data, ok := convert.AsBytes(scanResult) + if !ok { + return fmt.Errorf("cannot convert %#v as bytes", scanResult) + } + if data == nil { + return nil + } + return structConvert.FromDB(data) + } + } + + if structConvert, ok := fieldValue.Interface().(convert.Conversion); ok { + data, ok := convert.AsBytes(scanResult) + if !ok { + return fmt.Errorf("cannot convert %#v as bytes", scanResult) + } + if data == nil { + return nil + } + + if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() { + fieldValue.Set(reflect.New(fieldValue.Type().Elem())) + return fieldValue.Interface().(convert.Conversion).FromDB(data) + } + return structConvert.FromDB(data) + } + + vv := reflect.ValueOf(scanResult) + fieldType := fieldValue.Type() + + if col.IsJSON { + return setJSON(fieldValue, fieldType, scanResult) + } + + switch fieldType.Kind() { + case reflect.Ptr: + var e reflect.Value + if fieldValue.IsNil() { + e = reflect.New(fieldType.Elem()).Elem() + } else { + e = fieldValue.Elem() + } + if err := session.convertBeanField(col, &e, scanResult, table); err != nil { + return err + } + if fieldValue.IsNil() { + fieldValue.Set(e.Addr()) + } + return nil + case reflect.Complex64, reflect.Complex128: + return setJSON(fieldValue, fieldType, scanResult) + case reflect.Slice: + bs, ok := convert.AsBytes(scanResult) + if ok && fieldType.Elem().Kind() == reflect.Uint8 { + if col.SQLType.IsText() { + x := reflect.New(fieldType) + err := json.DefaultJSONHandler.Unmarshal(bs, x.Interface()) + if err != nil { + return err + } + fieldValue.Set(x.Elem()) + } else { + fieldValue.Set(reflect.ValueOf(bs)) + } + return nil + } + case reflect.Array: + bs, ok := convert.AsBytes(scanResult) + if ok && fieldType.Elem().Kind() == reflect.Uint8 { + if col.SQLType.IsText() { + x := reflect.New(fieldType) + err := json.DefaultJSONHandler.Unmarshal(bs, x.Interface()) + if err != nil { + return err + } + fieldValue.Set(x.Elem()) + } else { + if fieldValue.Len() < vv.Len() { + return fmt.Errorf("Set field %s[Array] failed because of data too long", col.Name) + } + for i := 0; i < fieldValue.Len(); i++ { + if i < vv.Len() { + fieldValue.Index(i).Set(vv.Index(i)) + } else { + fieldValue.Index(i).Set(uint8ZeroValue) + } + } + } + return nil + } + case reflect.Struct: + if fieldType.ConvertibleTo(schemas.BigFloatType) { + v, err := convert.AsBigFloat(scanResult) + if err != nil { + return err + } + fieldValue.Set(reflect.ValueOf(v).Elem().Convert(fieldType)) + return nil + } + + if fieldType.ConvertibleTo(schemas.TimeType) { + dbTZ := session.engine.DatabaseTZ + if col.TimeZone != nil { + dbTZ = col.TimeZone + } + + t, err := convert.AsTime(scanResult, dbTZ, session.engine.TZLocation) + if err != nil { + return err + } + + fieldValue.Set(reflect.ValueOf(*t).Convert(fieldType)) + return nil + } else if nulVal, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { + err := nulVal.Scan(scanResult) + if err == nil { + return nil + } + session.engine.logger.Errorf("sql.Sanner error: %v", err) + } else if session.statement.UseCascade { + table, err := session.engine.tagParser.ParseWithCache(*fieldValue) + if err != nil { + return err + } + + if len(table.PrimaryKeys) != 1 { + return errors.New("unsupported non or composited primary key cascade") + } + pk := make(schemas.PK, len(table.PrimaryKeys)) + pk[0], err = asKind(vv, reflect.TypeOf(scanResult)) + if err != nil { + return err + } + + if !pk.IsZero() { + // !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch + // however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne + // property to be fetched lazily + structInter := reflect.New(fieldValue.Type()) + has, err := session.ID(pk).NoCascade().get(structInter.Interface()) + if err != nil { + return err + } + if has { + fieldValue.Set(structInter.Elem()) + } else { + return errors.New("cascade obj is not exist") + } + } + return nil + } + } // switch fieldType.Kind() + + return convert.AssignValue(fieldValue.Addr(), scanResult) +} + func (session *Session) slice2Bean(scanResults []interface{}, fields []string, bean interface{}, dataStruct *reflect.Value, table *schemas.Table) (schemas.PK, error) { defer func() { executeAfterSet(bean, fields, scanResults) @@ -439,431 +712,36 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b buildAfterProcessors(session, bean) - var tempMap = make(map[string]int) + tempMap := make(map[string]int) var pk schemas.PK - for ii, key := range fields { + for i, colName := range fields { var idx int + lKey := strings.ToLower(colName) var ok bool - var lKey = strings.ToLower(key) + if idx, ok = tempMap[lKey]; !ok { idx = 0 } else { - idx = idx + 1 + idx++ } tempMap[lKey] = idx - fieldValue, err := session.getField(dataStruct, key, table, idx) - if err != nil { - if !strings.Contains(err.Error(), "is not valid") { - session.engine.logger.Warnf("%v", err) - } + col, fieldValue, err := getField(dataStruct, table, colName, idx) + if _, ok := err.(ErrFieldIsNotExist); ok { continue + } else if err != nil { + return nil, err } + if fieldValue == nil { continue } - rawValue := reflect.Indirect(reflect.ValueOf(scanResults[ii])) - // if row is null then ignore - if rawValue.Interface() == nil { - continue + if err := session.convertBeanField(col, fieldValue, scanResults[i], table); err != nil { + return nil, err } - - if fieldValue.CanAddr() { - if structConvert, ok := fieldValue.Addr().Interface().(convert.Conversion); ok { - if data, err := value2Bytes(&rawValue); err == nil { - if err := structConvert.FromDB(data); err != nil { - return nil, err - } - } else { - return nil, err - } - continue - } - } - - if _, ok := fieldValue.Interface().(convert.Conversion); ok { - if data, err := value2Bytes(&rawValue); err == nil { - if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() { - fieldValue.Set(reflect.New(fieldValue.Type().Elem())) - } - fieldValue.Interface().(convert.Conversion).FromDB(data) - } else { - return nil, err - } - continue - } - - rawValueType := reflect.TypeOf(rawValue.Interface()) - vv := reflect.ValueOf(rawValue.Interface()) - col := table.GetColumnIdx(key, idx) if col.IsPrimaryKey { - pk = append(pk, rawValue.Interface()) - } - fieldType := fieldValue.Type() - hasAssigned := false - - if col.IsJSON { - var bs []byte - if rawValueType.Kind() == reflect.String { - bs = []byte(vv.String()) - } else if rawValueType.ConvertibleTo(schemas.BytesType) { - bs = vv.Bytes() - } else { - return nil, fmt.Errorf("unsupported database data type: %s %v", key, rawValueType.Kind()) - } - - hasAssigned = true - - if len(bs) > 0 { - if fieldType.Kind() == reflect.String { - fieldValue.SetString(string(bs)) - continue - } - if fieldValue.CanAddr() { - err := json.DefaultJSONHandler.Unmarshal(bs, fieldValue.Addr().Interface()) - if err != nil { - return nil, err - } - } else { - x := reflect.New(fieldType) - err := json.DefaultJSONHandler.Unmarshal(bs, x.Interface()) - if err != nil { - return nil, err - } - fieldValue.Set(x.Elem()) - } - } - - continue - } - - switch fieldType.Kind() { - case reflect.Complex64, reflect.Complex128: - // TODO: reimplement this - var bs []byte - if rawValueType.Kind() == reflect.String { - bs = []byte(vv.String()) - } else if rawValueType.ConvertibleTo(schemas.BytesType) { - bs = vv.Bytes() - } - - hasAssigned = true - if len(bs) > 0 { - if fieldValue.CanAddr() { - err := json.DefaultJSONHandler.Unmarshal(bs, fieldValue.Addr().Interface()) - if err != nil { - return nil, err - } - } else { - x := reflect.New(fieldType) - err := json.DefaultJSONHandler.Unmarshal(bs, x.Interface()) - if err != nil { - return nil, err - } - fieldValue.Set(x.Elem()) - } - } - case reflect.Slice, reflect.Array: - switch rawValueType.Kind() { - case reflect.Slice, reflect.Array: - switch rawValueType.Elem().Kind() { - case reflect.Uint8: - if fieldType.Elem().Kind() == reflect.Uint8 { - hasAssigned = true - if col.SQLType.IsText() { - x := reflect.New(fieldType) - err := json.DefaultJSONHandler.Unmarshal(vv.Bytes(), x.Interface()) - if err != nil { - return nil, err - } - fieldValue.Set(x.Elem()) - } else { - if fieldValue.Len() > 0 { - for i := 0; i < fieldValue.Len(); i++ { - if i < vv.Len() { - fieldValue.Index(i).Set(vv.Index(i)) - } - } - } else { - for i := 0; i < vv.Len(); i++ { - fieldValue.Set(reflect.Append(*fieldValue, vv.Index(i))) - } - } - } - } - } - } - case reflect.String: - if rawValueType.Kind() == reflect.String { - hasAssigned = true - fieldValue.SetString(vv.String()) - } - case reflect.Bool: - if rawValueType.Kind() == reflect.Bool { - hasAssigned = true - fieldValue.SetBool(vv.Bool()) - } - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - switch rawValueType.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - hasAssigned = true - fieldValue.SetInt(vv.Int()) - } - case reflect.Float32, reflect.Float64: - switch rawValueType.Kind() { - case reflect.Float32, reflect.Float64: - hasAssigned = true - fieldValue.SetFloat(vv.Float()) - } - case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: - switch rawValueType.Kind() { - case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: - hasAssigned = true - fieldValue.SetUint(vv.Uint()) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - hasAssigned = true - fieldValue.SetUint(uint64(vv.Int())) - } - case reflect.Struct: - if fieldType.ConvertibleTo(schemas.TimeType) { - dbTZ := session.engine.DatabaseTZ - if col.TimeZone != nil { - dbTZ = col.TimeZone - } - - if rawValueType == schemas.TimeType { - hasAssigned = true - - t := vv.Convert(schemas.TimeType).Interface().(time.Time) - - z, _ := t.Zone() - // set new location if database don't save timezone or give an incorrect timezone - if len(z) == 0 || t.Year() == 0 || t.Location().String() != dbTZ.String() { // !nashtsai! HACK tmp work around for lib/pq doesn't properly time with location - session.engine.logger.Debugf("empty zone key[%v] : %v | zone: %v | location: %+v\n", key, t, z, *t.Location()) - t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), - t.Minute(), t.Second(), t.Nanosecond(), dbTZ) - } - - t = t.In(session.engine.TZLocation) - fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) - } else if rawValueType == schemas.IntType || rawValueType == schemas.Int64Type || - rawValueType == schemas.Int32Type { - hasAssigned = true - - t := time.Unix(vv.Int(), 0).In(session.engine.TZLocation) - fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) - } else { - if d, ok := vv.Interface().([]uint8); ok { - hasAssigned = true - t, err := session.byte2Time(col, d) - if err != nil { - session.engine.logger.Errorf("byte2Time error: %v", err) - hasAssigned = false - } else { - fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) - } - } else if d, ok := vv.Interface().(string); ok { - hasAssigned = true - t, err := session.str2Time(col, d) - if err != nil { - session.engine.logger.Errorf("byte2Time error: %v", err) - hasAssigned = false - } else { - fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) - } - } else { - return nil, fmt.Errorf("rawValueType is %v, value is %v", rawValueType, vv.Interface()) - } - } - } else if nulVal, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { - // !! 增加支持sql.Scanner接口的结构,如sql.NullString - hasAssigned = true - if err := nulVal.Scan(vv.Interface()); err != nil { - session.engine.logger.Errorf("sql.Sanner error: %v", err) - hasAssigned = false - } - } else if col.IsJSON { - if rawValueType.Kind() == reflect.String { - hasAssigned = true - x := reflect.New(fieldType) - if len([]byte(vv.String())) > 0 { - err := json.DefaultJSONHandler.Unmarshal([]byte(vv.String()), x.Interface()) - if err != nil { - return nil, err - } - fieldValue.Set(x.Elem()) - } - } else if rawValueType.Kind() == reflect.Slice { - hasAssigned = true - x := reflect.New(fieldType) - if len(vv.Bytes()) > 0 { - err := json.DefaultJSONHandler.Unmarshal(vv.Bytes(), x.Interface()) - if err != nil { - return nil, err - } - fieldValue.Set(x.Elem()) - } - } - } else if session.statement.UseCascade { - table, err := session.engine.tagParser.ParseWithCache(*fieldValue) - if err != nil { - return nil, err - } - - hasAssigned = true - if len(table.PrimaryKeys) != 1 { - return nil, errors.New("unsupported non or composited primary key cascade") - } - var pk = make(schemas.PK, len(table.PrimaryKeys)) - pk[0], err = asKind(vv, rawValueType) - if err != nil { - return nil, err - } - - if !pk.IsZero() { - // !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch - // however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne - // property to be fetched lazily - structInter := reflect.New(fieldValue.Type()) - has, err := session.ID(pk).NoCascade().get(structInter.Interface()) - if err != nil { - return nil, err - } - if has { - fieldValue.Set(structInter.Elem()) - } else { - return nil, errors.New("cascade obj is not exist") - } - } - } - case reflect.Ptr: - // !nashtsai! TODO merge duplicated codes above - switch fieldType { - // following types case matching ptr's native type, therefore assign ptr directly - case schemas.PtrStringType: - if rawValueType.Kind() == reflect.String { - x := vv.String() - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case schemas.PtrBoolType: - if rawValueType.Kind() == reflect.Bool { - x := vv.Bool() - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case schemas.PtrTimeType: - if rawValueType == schemas.PtrTimeType { - hasAssigned = true - var x = rawValue.Interface().(time.Time) - fieldValue.Set(reflect.ValueOf(&x)) - } - case schemas.PtrFloat64Type: - if rawValueType.Kind() == reflect.Float64 { - x := vv.Float() - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case schemas.PtrUint64Type: - if rawValueType.Kind() == reflect.Int64 { - var x = uint64(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case schemas.PtrInt64Type: - if rawValueType.Kind() == reflect.Int64 { - x := vv.Int() - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case schemas.PtrFloat32Type: - if rawValueType.Kind() == reflect.Float64 { - var x = float32(vv.Float()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case schemas.PtrIntType: - if rawValueType.Kind() == reflect.Int64 { - var x = int(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case schemas.PtrInt32Type: - if rawValueType.Kind() == reflect.Int64 { - var x = int32(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case schemas.PtrInt8Type: - if rawValueType.Kind() == reflect.Int64 { - var x = int8(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case schemas.PtrInt16Type: - if rawValueType.Kind() == reflect.Int64 { - var x = int16(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case schemas.PtrUintType: - if rawValueType.Kind() == reflect.Int64 { - var x = uint(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case schemas.PtrUint32Type: - if rawValueType.Kind() == reflect.Int64 { - var x = uint32(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case schemas.Uint8Type: - if rawValueType.Kind() == reflect.Int64 { - var x = uint8(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case schemas.Uint16Type: - if rawValueType.Kind() == reflect.Int64 { - var x = uint16(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case schemas.Complex64Type: - var x complex64 - if len([]byte(vv.String())) > 0 { - err := json.DefaultJSONHandler.Unmarshal([]byte(vv.String()), &x) - if err != nil { - return nil, err - } - fieldValue.Set(reflect.ValueOf(&x)) - } - hasAssigned = true - case schemas.Complex128Type: - var x complex128 - if len([]byte(vv.String())) > 0 { - err := json.DefaultJSONHandler.Unmarshal([]byte(vv.String()), &x) - if err != nil { - return nil, err - } - fieldValue.Set(reflect.ValueOf(&x)) - } - hasAssigned = true - } // switch fieldType - } // switch fieldType.Kind() - - // !nashtsai! for value can't be assigned directly fallback to convert to []byte then back to value - if !hasAssigned { - data, err := value2Bytes(&rawValue) - if err != nil { - return nil, err - } - - if err = session.bytes2Value(col, fieldValue, data); err != nil { - return nil, err - } + pk = append(pk, scanResults[i]) } } return pk, nil @@ -895,8 +773,14 @@ func (session *Session) incrVersionFieldValue(fieldValue *reflect.Value) { } } -// ContextHook sets the context on this session +// Context sets the context on this session func (session *Session) Context(ctx context.Context) *Session { + if session.engine.logSessionID && session.ctx != nil { + ctx = context.WithValue(ctx, log.SessionIDKey, session.ctx.Value(log.SessionIDKey)) + ctx = context.WithValue(ctx, log.SessionKey, session.ctx.Value(log.SessionKey)) + ctx = context.WithValue(ctx, log.SessionShowSQLKey, session.ctx.Value(log.SessionShowSQLKey)) + } + session.ctx = ctx return session } @@ -910,3 +794,9 @@ func (session *Session) PingContext(ctx context.Context) error { session.engine.logger.Infof("PING DATABASE %v", session.engine.DriverName()) return session.DB().PingContext(ctx) } + +// disable version check +func (session *Session) NoVersionCheck() *Session { + session.statement.CheckVersion = false + return session +} diff --git a/session_convert.go b/session_convert.go deleted file mode 100644 index a6839947..00000000 --- a/session_convert.go +++ /dev/null @@ -1,529 +0,0 @@ -// Copyright 2017 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package xorm - -import ( - "database/sql" - "errors" - "fmt" - "reflect" - "strconv" - "strings" - "time" - - "xorm.io/xorm/convert" - "xorm.io/xorm/internal/json" - "xorm.io/xorm/internal/utils" - "xorm.io/xorm/schemas" -) - -func (session *Session) str2Time(col *schemas.Column, data string) (outTime time.Time, outErr error) { - sdata := strings.TrimSpace(data) - var x time.Time - var err error - - var parseLoc = session.engine.DatabaseTZ - if col.TimeZone != nil { - parseLoc = col.TimeZone - } - - if sdata == utils.ZeroTime0 || sdata == utils.ZeroTime1 { - } else if !strings.ContainsAny(sdata, "- :") { // !nashtsai! has only found that mymysql driver is using this for time type column - // time stamp - sd, err := strconv.ParseInt(sdata, 10, 64) - if err == nil { - x = time.Unix(sd, 0) - //session.engine.logger.Debugf("time(0) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) - } else { - //session.engine.logger.Debugf("time(0) err key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) - } - } else if len(sdata) > 19 && strings.Contains(sdata, "-") { - x, err = time.ParseInLocation(time.RFC3339Nano, sdata, parseLoc) - session.engine.logger.Debugf("time(1) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) - if err != nil { - x, err = time.ParseInLocation("2006-01-02 15:04:05.999999999", sdata, parseLoc) - //session.engine.logger.Debugf("time(2) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) - } - if err != nil { - x, err = time.ParseInLocation("2006-01-02 15:04:05.9999999 Z07:00", sdata, parseLoc) - //session.engine.logger.Debugf("time(3) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) - } - } else if len(sdata) == 19 && strings.Contains(sdata, "-") { - x, err = time.ParseInLocation("2006-01-02 15:04:05", sdata, parseLoc) - //session.engine.logger.Debugf("time(4) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) - } else if len(sdata) == 10 && sdata[4] == '-' && sdata[7] == '-' { - x, err = time.ParseInLocation("2006-01-02", sdata, parseLoc) - //session.engine.logger.Debugf("time(5) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) - } else if col.SQLType.Name == schemas.Time { - if strings.Contains(sdata, " ") { - ssd := strings.Split(sdata, " ") - sdata = ssd[1] - } - - sdata = strings.TrimSpace(sdata) - if session.engine.dialect.URI().DBType == schemas.MYSQL && len(sdata) > 8 { - sdata = sdata[len(sdata)-8:] - } - - st := fmt.Sprintf("2006-01-02 %v", sdata) - x, err = time.ParseInLocation("2006-01-02 15:04:05", st, parseLoc) - //session.engine.logger.Debugf("time(6) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) - } else { - outErr = fmt.Errorf("unsupported time format %v", sdata) - return - } - if err != nil { - outErr = fmt.Errorf("unsupported time format %v: %v", sdata, err) - return - } - outTime = x.In(session.engine.TZLocation) - return -} - -func (session *Session) byte2Time(col *schemas.Column, data []byte) (outTime time.Time, outErr error) { - return session.str2Time(col, string(data)) -} - -// convert a db data([]byte) to a field value -func (session *Session) bytes2Value(col *schemas.Column, fieldValue *reflect.Value, data []byte) error { - if structConvert, ok := fieldValue.Addr().Interface().(convert.Conversion); ok { - return structConvert.FromDB(data) - } - - if structConvert, ok := fieldValue.Interface().(convert.Conversion); ok { - return structConvert.FromDB(data) - } - - var v interface{} - key := col.Name - fieldType := fieldValue.Type() - - switch fieldType.Kind() { - case reflect.Complex64, reflect.Complex128: - x := reflect.New(fieldType) - if len(data) > 0 { - err := json.DefaultJSONHandler.Unmarshal(data, x.Interface()) - if err != nil { - return err - } - fieldValue.Set(x.Elem()) - } - case reflect.Slice, reflect.Array, reflect.Map: - v = data - t := fieldType.Elem() - k := t.Kind() - if col.SQLType.IsText() { - x := reflect.New(fieldType) - if len(data) > 0 { - err := json.DefaultJSONHandler.Unmarshal(data, x.Interface()) - if err != nil { - return err - } - fieldValue.Set(x.Elem()) - } - } else if col.SQLType.IsBlob() { - if k == reflect.Uint8 { - fieldValue.Set(reflect.ValueOf(v)) - } else { - x := reflect.New(fieldType) - if len(data) > 0 { - err := json.DefaultJSONHandler.Unmarshal(data, x.Interface()) - if err != nil { - return err - } - fieldValue.Set(x.Elem()) - } - } - } else { - return ErrUnSupportedType - } - case reflect.String: - fieldValue.SetString(string(data)) - case reflect.Bool: - v, err := asBool(data) - if err != nil { - return fmt.Errorf("arg %v as bool: %s", key, err.Error()) - } - fieldValue.Set(reflect.ValueOf(v)) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - sdata := string(data) - var x int64 - var err error - // for mysql, when use bit, it returned \x01 - if col.SQLType.Name == schemas.Bit && - session.engine.dialect.URI().DBType == schemas.MYSQL { // !nashtsai! TODO dialect needs to provide conversion interface API - if len(data) == 1 { - x = int64(data[0]) - } else { - x = 0 - } - } else if strings.HasPrefix(sdata, "0x") { - x, err = strconv.ParseInt(sdata, 16, 64) - } else if strings.HasPrefix(sdata, "0") { - x, err = strconv.ParseInt(sdata, 8, 64) - } else if strings.EqualFold(sdata, "true") { - x = 1 - } else if strings.EqualFold(sdata, "false") { - x = 0 - } else { - x, err = strconv.ParseInt(sdata, 10, 64) - } - if err != nil { - return fmt.Errorf("arg %v as int: %s", key, err.Error()) - } - fieldValue.SetInt(x) - case reflect.Float32, reflect.Float64: - x, err := strconv.ParseFloat(string(data), 64) - if err != nil { - return fmt.Errorf("arg %v as float64: %s", key, err.Error()) - } - fieldValue.SetFloat(x) - case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: - x, err := strconv.ParseUint(string(data), 10, 64) - if err != nil { - return fmt.Errorf("arg %v as int: %s", key, err.Error()) - } - fieldValue.SetUint(x) - //Currently only support Time type - case reflect.Struct: - // !! 增加支持sql.Scanner接口的结构,如sql.NullString - if nulVal, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { - if err := nulVal.Scan(data); err != nil { - return fmt.Errorf("sql.Scan(%v) failed: %s ", data, err.Error()) - } - } else { - if fieldType.ConvertibleTo(schemas.TimeType) { - x, err := session.byte2Time(col, data) - if err != nil { - return err - } - v = x - fieldValue.Set(reflect.ValueOf(v).Convert(fieldType)) - } else if session.statement.UseCascade { - table, err := session.engine.tagParser.ParseWithCache(*fieldValue) - if err != nil { - return err - } - - // TODO: current only support 1 primary key - if len(table.PrimaryKeys) > 1 { - return errors.New("unsupported composited primary key cascade") - } - - var pk = make(schemas.PK, len(table.PrimaryKeys)) - rawValueType := table.ColumnType(table.PKColumns()[0].FieldName) - pk[0], err = str2PK(string(data), rawValueType) - if err != nil { - return err - } - - if !pk.IsZero() { - // !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch - // however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne - // property to be fetched lazily - structInter := reflect.New(fieldValue.Type()) - has, err := session.ID(pk).NoCascade().get(structInter.Interface()) - if err != nil { - return err - } - if has { - v = structInter.Elem().Interface() - fieldValue.Set(reflect.ValueOf(v)) - } else { - return errors.New("cascade obj is not exist") - } - } - } - } - case reflect.Ptr: - // !nashtsai! TODO merge duplicated codes above - //typeStr := fieldType.String() - switch fieldType.Elem().Kind() { - // case "*string": - case schemas.StringType.Kind(): - x := string(data) - fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) - // case "*bool": - case schemas.BoolType.Kind(): - d := string(data) - v, err := strconv.ParseBool(d) - if err != nil { - return fmt.Errorf("arg %v as bool: %s", key, err.Error()) - } - fieldValue.Set(reflect.ValueOf(&v).Convert(fieldType)) - // case "*complex64": - case schemas.Complex64Type.Kind(): - var x complex64 - if len(data) > 0 { - err := json.DefaultJSONHandler.Unmarshal(data, &x) - if err != nil { - return err - } - fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) - } - // case "*complex128": - case schemas.Complex128Type.Kind(): - var x complex128 - if len(data) > 0 { - err := json.DefaultJSONHandler.Unmarshal(data, &x) - if err != nil { - return err - } - fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) - } - // case "*float64": - case schemas.Float64Type.Kind(): - x, err := strconv.ParseFloat(string(data), 64) - if err != nil { - return fmt.Errorf("arg %v as float64: %s", key, err.Error()) - } - fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) - // case "*float32": - case schemas.Float32Type.Kind(): - var x float32 - x1, err := strconv.ParseFloat(string(data), 32) - if err != nil { - return fmt.Errorf("arg %v as float32: %s", key, err.Error()) - } - x = float32(x1) - fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) - // case "*uint64": - case schemas.Uint64Type.Kind(): - var x uint64 - x, err := strconv.ParseUint(string(data), 10, 64) - if err != nil { - return fmt.Errorf("arg %v as int: %s", key, err.Error()) - } - fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) - // case "*uint": - case schemas.UintType.Kind(): - var x uint - x1, err := strconv.ParseUint(string(data), 10, 64) - if err != nil { - return fmt.Errorf("arg %v as int: %s", key, err.Error()) - } - x = uint(x1) - fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) - // case "*uint32": - case schemas.Uint32Type.Kind(): - var x uint32 - x1, err := strconv.ParseUint(string(data), 10, 64) - if err != nil { - return fmt.Errorf("arg %v as int: %s", key, err.Error()) - } - x = uint32(x1) - fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) - // case "*uint8": - case schemas.Uint8Type.Kind(): - var x uint8 - x1, err := strconv.ParseUint(string(data), 10, 64) - if err != nil { - return fmt.Errorf("arg %v as int: %s", key, err.Error()) - } - x = uint8(x1) - fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) - // case "*uint16": - case schemas.Uint16Type.Kind(): - var x uint16 - x1, err := strconv.ParseUint(string(data), 10, 64) - if err != nil { - return fmt.Errorf("arg %v as int: %s", key, err.Error()) - } - x = uint16(x1) - fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) - // case "*int64": - case schemas.Int64Type.Kind(): - sdata := string(data) - var x int64 - var err error - // for mysql, when use bit, it returned \x01 - if col.SQLType.Name == schemas.Bit && - strings.Contains(session.engine.DriverName(), "mysql") { - if len(data) == 1 { - x = int64(data[0]) - } else { - x = 0 - } - } else if strings.HasPrefix(sdata, "0x") { - x, err = strconv.ParseInt(sdata, 16, 64) - } else if strings.HasPrefix(sdata, "0") { - x, err = strconv.ParseInt(sdata, 8, 64) - } else { - x, err = strconv.ParseInt(sdata, 10, 64) - } - if err != nil { - return fmt.Errorf("arg %v as int: %s", key, err.Error()) - } - fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) - // case "*int": - case schemas.IntType.Kind(): - sdata := string(data) - var x int - var x1 int64 - var err error - // for mysql, when use bit, it returned \x01 - if col.SQLType.Name == schemas.Bit && - strings.Contains(session.engine.DriverName(), "mysql") { - if len(data) == 1 { - x = int(data[0]) - } else { - x = 0 - } - } else if strings.HasPrefix(sdata, "0x") { - x1, err = strconv.ParseInt(sdata, 16, 64) - x = int(x1) - } else if strings.HasPrefix(sdata, "0") { - x1, err = strconv.ParseInt(sdata, 8, 64) - x = int(x1) - } else { - x1, err = strconv.ParseInt(sdata, 10, 64) - x = int(x1) - } - if err != nil { - return fmt.Errorf("arg %v as int: %s", key, err.Error()) - } - fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) - // case "*int32": - case schemas.Int32Type.Kind(): - sdata := string(data) - var x int32 - var x1 int64 - var err error - // for mysql, when use bit, it returned \x01 - if col.SQLType.Name == schemas.Bit && - session.engine.dialect.URI().DBType == schemas.MYSQL { - if len(data) == 1 { - x = int32(data[0]) - } else { - x = 0 - } - } else if strings.HasPrefix(sdata, "0x") { - x1, err = strconv.ParseInt(sdata, 16, 64) - x = int32(x1) - } else if strings.HasPrefix(sdata, "0") { - x1, err = strconv.ParseInt(sdata, 8, 64) - x = int32(x1) - } else { - x1, err = strconv.ParseInt(sdata, 10, 64) - x = int32(x1) - } - if err != nil { - return fmt.Errorf("arg %v as int: %s", key, err.Error()) - } - fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) - // case "*int8": - case schemas.Int8Type.Kind(): - sdata := string(data) - var x int8 - var x1 int64 - var err error - // for mysql, when use bit, it returned \x01 - if col.SQLType.Name == schemas.Bit && - strings.Contains(session.engine.DriverName(), "mysql") { - if len(data) == 1 { - x = int8(data[0]) - } else { - x = 0 - } - } else if strings.HasPrefix(sdata, "0x") { - x1, err = strconv.ParseInt(sdata, 16, 64) - x = int8(x1) - } else if strings.HasPrefix(sdata, "0") { - x1, err = strconv.ParseInt(sdata, 8, 64) - x = int8(x1) - } else { - x1, err = strconv.ParseInt(sdata, 10, 64) - x = int8(x1) - } - if err != nil { - return fmt.Errorf("arg %v as int: %s", key, err.Error()) - } - fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) - // case "*int16": - case schemas.Int16Type.Kind(): - sdata := string(data) - var x int16 - var x1 int64 - var err error - // for mysql, when use bit, it returned \x01 - if col.SQLType.Name == schemas.Bit && - strings.Contains(session.engine.DriverName(), "mysql") { - if len(data) == 1 { - x = int16(data[0]) - } else { - x = 0 - } - } else if strings.HasPrefix(sdata, "0x") { - x1, err = strconv.ParseInt(sdata, 16, 64) - x = int16(x1) - } else if strings.HasPrefix(sdata, "0") { - x1, err = strconv.ParseInt(sdata, 8, 64) - x = int16(x1) - } else { - x1, err = strconv.ParseInt(sdata, 10, 64) - x = int16(x1) - } - if err != nil { - return fmt.Errorf("arg %v as int: %s", key, err.Error()) - } - fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) - // case "*SomeStruct": - case reflect.Struct: - switch fieldType { - // case "*.time.Time": - case schemas.PtrTimeType: - x, err := session.byte2Time(col, data) - if err != nil { - return err - } - v = x - fieldValue.Set(reflect.ValueOf(&x)) - default: - if session.statement.UseCascade { - structInter := reflect.New(fieldType.Elem()) - table, err := session.engine.tagParser.ParseWithCache(structInter.Elem()) - if err != nil { - return err - } - - if len(table.PrimaryKeys) > 1 { - return errors.New("unsupported composited primary key cascade") - } - - var pk = make(schemas.PK, len(table.PrimaryKeys)) - rawValueType := table.ColumnType(table.PKColumns()[0].FieldName) - pk[0], err = str2PK(string(data), rawValueType) - if err != nil { - return err - } - - if !pk.IsZero() { - // !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch - // however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne - // property to be fetched lazily - has, err := session.ID(pk).NoCascade().get(structInter.Interface()) - if err != nil { - return err - } - if has { - v = structInter.Interface() - fieldValue.Set(reflect.ValueOf(v)) - } else { - return errors.New("cascade obj is not exist") - } - } - } else { - return fmt.Errorf("unsupported struct type in Scan: %s", fieldValue.Type().String()) - } - } - default: - return fmt.Errorf("unsupported type in Scan: %s", fieldValue.Type().String()) - } - default: - return fmt.Errorf("unsupported type in Scan: %s", fieldValue.Type().String()) - } - - return nil -} diff --git a/session_delete.go b/session_delete.go index 13bf791f..d36b9e52 100644 --- a/session_delete.go +++ b/session_delete.go @@ -9,7 +9,9 @@ import ( "fmt" "strconv" + "xorm.io/builder" "xorm.io/xorm/caches" + "xorm.io/xorm/internal/utils" "xorm.io/xorm/schemas" ) @@ -40,7 +42,13 @@ func (session *Session) cacheDelete(table *schemas.Table, tableName, sqlStr stri pkColumns := table.PKColumns() ids, err := caches.GetCacheSql(cacher, tableName, newsql, args) if err != nil { - resultsSlice, err := session.queryBytes(newsql, args...) + rows, err := session.queryRows(newsql, args...) + if err != nil { + return err + } + defer rows.Close() + + resultsSlice, err := session.engine.ScanStringMaps(rows) if err != nil { return err } @@ -53,9 +61,9 @@ func (session *Session) cacheDelete(table *schemas.Table, tableName, sqlStr stri if v, ok := data[col.Name]; !ok { return errors.New("no id") } else if col.SQLType.IsText() { - pk = append(pk, string(v)) + pk = append(pk, v) } else if col.SQLType.IsNumeric() { - id, err = strconv.ParseInt(string(v), 10, 64) + id, err = strconv.ParseInt(v, 10, 64) if err != nil { return err } @@ -83,7 +91,18 @@ func (session *Session) cacheDelete(table *schemas.Table, tableName, sqlStr stri } // Delete records, bean's non-empty fields are conditions -func (session *Session) Delete(bean interface{}) (int64, error) { +// At least one condition must be set. +func (session *Session) Delete(beans ...interface{}) (int64, error) { + return session.delete(beans, true) +} + +// Truncate records, bean's non-empty fields are conditions +// In contrast to Delete this method allows deletes without conditions. +func (session *Session) Truncate(beans ...interface{}) (int64, error) { + return session.delete(beans, false) +} + +func (session *Session) delete(beans []interface{}, mustHaveConditions bool) (int64, error) { if session.isAutoClose { defer session.Close() } @@ -92,118 +111,114 @@ func (session *Session) Delete(bean interface{}) (int64, error) { return 0, session.statement.LastError } - if err := session.statement.SetRefBean(bean); err != nil { + var ( + condWriter = builder.NewWriter() + err error + bean interface{} + ) + if len(beans) > 0 { + bean = beans[0] + if err = session.statement.SetRefBean(bean); err != nil { + return 0, err + } + + executeBeforeClosures(session, bean) + + if processor, ok := interface{}(bean).(BeforeDeleteProcessor); ok { + processor.BeforeDelete() + } + + if err = session.statement.MergeConds(bean); err != nil { + return 0, err + } + } + + if err = session.statement.Conds().WriteTo(session.statement.QuoteReplacer(condWriter)); err != nil { return 0, err } - executeBeforeClosures(session, bean) - - if processor, ok := interface{}(bean).(BeforeDeleteProcessor); ok { - processor.BeforeDelete() - } - - condSQL, condArgs, err := session.statement.GenConds(bean) - if err != nil { - return 0, err - } pLimitN := session.statement.LimitN - if len(condSQL) == 0 && (pLimitN == nil || *pLimitN == 0) { + if mustHaveConditions && condWriter.Len() == 0 && (pLimitN == nil || *pLimitN == 0) { return 0, ErrNeedDeletedCond } - var tableNameNoQuote = session.statement.TableName() - var tableName = session.engine.Quote(tableNameNoQuote) - var table = session.statement.RefTable - var deleteSQL string - if len(condSQL) > 0 { - deleteSQL = fmt.Sprintf("DELETE FROM %v WHERE %v", tableName, condSQL) - } else { - deleteSQL = fmt.Sprintf("DELETE FROM %v", tableName) + tableNameNoQuote := session.statement.TableName() + tableName := session.engine.Quote(tableNameNoQuote) + table := session.statement.RefTable + deleteSQLWriter := builder.NewWriter() + fmt.Fprintf(deleteSQLWriter, "DELETE FROM %v", tableName) + if condWriter.Len() > 0 { + fmt.Fprintf(deleteSQLWriter, " WHERE %v", condWriter.String()) + deleteSQLWriter.Append(condWriter.Args()...) } - var orderSQL string - if len(session.statement.OrderStr) > 0 { - orderSQL += fmt.Sprintf(" ORDER BY %s", session.statement.OrderStr) + orderSQLWriter := builder.NewWriter() + if err := session.statement.WriteOrderBy(orderSQLWriter); err != nil { + return 0, err } + if pLimitN != nil && *pLimitN > 0 { limitNValue := *pLimitN - orderSQL += fmt.Sprintf(" LIMIT %d", limitNValue) + if _, err := fmt.Fprintf(orderSQLWriter, " LIMIT %d", limitNValue); err != nil { + return 0, err + } } - if len(orderSQL) > 0 { + orderCondWriter := builder.NewWriter() + if orderSQLWriter.Len() > 0 { switch session.engine.dialect.URI().DBType { case schemas.POSTGRES: - inSQL := fmt.Sprintf("ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQL) - if len(condSQL) > 0 { - deleteSQL += " AND " + inSQL + if condWriter.Len() > 0 { + fmt.Fprintf(orderCondWriter, " AND ") } else { - deleteSQL += " WHERE " + inSQL + fmt.Fprintf(orderCondWriter, " WHERE ") } + fmt.Fprintf(orderCondWriter, "ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQLWriter.String()) + orderCondWriter.Append(orderSQLWriter.Args()...) case schemas.SQLITE: - inSQL := fmt.Sprintf("rowid IN (SELECT rowid FROM %s%s)", tableName, orderSQL) - if len(condSQL) > 0 { - deleteSQL += " AND " + inSQL + if condWriter.Len() > 0 { + fmt.Fprintf(orderCondWriter, " AND ") } else { - deleteSQL += " WHERE " + inSQL + fmt.Fprintf(orderCondWriter, " WHERE ") } + fmt.Fprintf(orderCondWriter, "rowid IN (SELECT rowid FROM %s%s)", tableName, orderSQLWriter.String()) // TODO: how to handle delete limit on mssql? case schemas.MSSQL: return 0, ErrNotImplemented default: - deleteSQL += orderSQL + fmt.Fprint(orderCondWriter, orderSQLWriter.String()) + orderCondWriter.Append(orderSQLWriter.Args()...) } } - var realSQL string - argsForCache := make([]interface{}, 0, len(condArgs)*2) - if session.statement.GetUnscoped() || table.DeletedColumn() == nil { // tag "deleted" is disabled - realSQL = deleteSQL - copy(argsForCache, condArgs) - argsForCache = append(condArgs, argsForCache...) + realSQLWriter := builder.NewWriter() + argsForCache := make([]interface{}, 0, len(deleteSQLWriter.Args())*2) + copy(argsForCache, deleteSQLWriter.Args()) + argsForCache = append(deleteSQLWriter.Args(), argsForCache...) + if session.statement.GetUnscoped() || table == nil || table.DeletedColumn() == nil { // tag "deleted" is disabled + if err := utils.WriteBuilder(realSQLWriter, deleteSQLWriter, orderCondWriter); err != nil { + return 0, err + } } else { - // !oinume! sqlStrForCache and argsForCache is needed to behave as executing "DELETE FROM ..." for caches. - copy(argsForCache, condArgs) - argsForCache = append(condArgs, argsForCache...) - deletedColumn := table.DeletedColumn() - realSQL = fmt.Sprintf("UPDATE %v SET %v = ? WHERE %v", + if _, err := fmt.Fprintf(realSQLWriter, "UPDATE %v SET %v = ? WHERE %v", session.engine.Quote(session.statement.TableName()), session.engine.Quote(deletedColumn.Name), - condSQL) + condWriter.String()); err != nil { + return 0, err + } + val, t, err := session.engine.nowTime(deletedColumn) + if err != nil { + return 0, err + } + realSQLWriter.Append(val) + realSQLWriter.Append(condWriter.Args()...) - if len(orderSQL) > 0 { - switch session.engine.dialect.URI().DBType { - case schemas.POSTGRES: - inSQL := fmt.Sprintf("ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQL) - if len(condSQL) > 0 { - realSQL += " AND " + inSQL - } else { - realSQL += " WHERE " + inSQL - } - case schemas.SQLITE: - inSQL := fmt.Sprintf("rowid IN (SELECT rowid FROM %s%s)", tableName, orderSQL) - if len(condSQL) > 0 { - realSQL += " AND " + inSQL - } else { - realSQL += " WHERE " + inSQL - } - // TODO: how to handle delete limit on mssql? - case schemas.MSSQL: - return 0, ErrNotImplemented - default: - realSQL += orderSQL - } + if err := utils.WriteBuilder(realSQLWriter, orderCondWriter); err != nil { + return 0, err } - // !oinume! Insert nowTime to the head of session.statement.Params - condArgs = append(condArgs, "") - paramsLen := len(condArgs) - copy(condArgs[1:paramsLen], condArgs[0:paramsLen-1]) - - val, t := session.engine.nowTime(deletedColumn) - condArgs[0] = val - - var colName = deletedColumn.Name + colName := deletedColumn.Name session.afterClosures = append(session.afterClosures, func(bean interface{}) { col := table.GetColumn(colName) setColumnTime(bean, col, t) @@ -211,36 +226,38 @@ func (session *Session) Delete(bean interface{}) (int64, error) { } if cacher := session.engine.GetCacher(tableNameNoQuote); cacher != nil && session.statement.UseCache { - session.cacheDelete(table, tableNameNoQuote, deleteSQL, argsForCache...) + _ = session.cacheDelete(table, tableNameNoQuote, deleteSQLWriter.String(), argsForCache...) } session.statement.RefTable = table - res, err := session.exec(realSQL, condArgs...) + res, err := session.exec(realSQLWriter.String(), realSQLWriter.Args()...) if err != nil { return 0, err } - // handle after delete processors - if session.isAutoCommit { - for _, closure := range session.afterClosures { - closure(bean) - } - if processor, ok := interface{}(bean).(AfterDeleteProcessor); ok { - processor.AfterDelete() - } - } else { - lenAfterClosures := len(session.afterClosures) - if lenAfterClosures > 0 { - if value, has := session.afterDeleteBeans[bean]; has && value != nil { - *value = append(*value, session.afterClosures...) - } else { - afterClosures := make([]func(interface{}), lenAfterClosures) - copy(afterClosures, session.afterClosures) - session.afterDeleteBeans[bean] = &afterClosures + if bean != nil { + // handle after delete processors + if session.isAutoCommit { + for _, closure := range session.afterClosures { + closure(bean) + } + if processor, ok := interface{}(bean).(AfterDeleteProcessor); ok { + processor.AfterDelete() } } else { - if _, ok := interface{}(bean).(AfterDeleteProcessor); ok { - session.afterDeleteBeans[bean] = nil + lenAfterClosures := len(session.afterClosures) + if lenAfterClosures > 0 && len(beans) > 0 { + if value, has := session.afterDeleteBeans[beans[0]]; has && value != nil { + *value = append(*value, session.afterClosures...) + } else { + afterClosures := make([]func(interface{}), lenAfterClosures) + copy(afterClosures, session.afterClosures) + session.afterDeleteBeans[bean] = &afterClosures + } + } else { + if _, ok := interface{}(bean).(AfterDeleteProcessor); ok { + session.afterDeleteBeans[bean] = nil + } } } } diff --git a/session_exist.go b/session_exist.go index e52c618e..b5e4a655 100644 --- a/session_exist.go +++ b/session_exist.go @@ -25,5 +25,8 @@ func (session *Session) Exist(bean ...interface{}) (bool, error) { } defer rows.Close() - return rows.Next(), nil + if rows.Next() { + return true, nil + } + return false, rows.Err() } diff --git a/session_find.go b/session_find.go index 0daea005..2270454b 100644 --- a/session_find.go +++ b/session_find.go @@ -6,11 +6,11 @@ package xorm import ( "errors" - "fmt" "reflect" "xorm.io/builder" "xorm.io/xorm/caches" + "xorm.io/xorm/convert" "xorm.io/xorm/internal/statements" "xorm.io/xorm/internal/utils" "xorm.io/xorm/schemas" @@ -57,12 +57,10 @@ func (session *Session) FindAndCount(rowsSlicePtr interface{}, condiBean ...inte if session.statement.SelectStr != "" { session.statement.SelectStr = "" } - if len(session.statement.ColumnMap) > 0 { + if len(session.statement.ColumnMap) > 0 && !session.statement.IsDistinct { session.statement.ColumnMap = []string{} } - if session.statement.OrderStr != "" { - session.statement.OrderStr = "" - } + session.statement.ResetOrderBy() if session.statement.LimitN != nil { session.statement.LimitN = nil } @@ -71,7 +69,11 @@ func (session *Session) FindAndCount(rowsSlicePtr interface{}, condiBean ...inte } // session has stored the conditions so we use `unscoped` to avoid duplicated condition. - return session.Unscoped().Count(reflect.New(sliceElementType).Interface()) + if sliceElementType.Kind() == reflect.Struct { + return session.Unscoped().Count(reflect.New(sliceElementType).Interface()) + } + + return session.Unscoped().Count() } func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) error { @@ -81,15 +83,15 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) } sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) - var isSlice = sliceValue.Kind() == reflect.Slice - var isMap = sliceValue.Kind() == reflect.Map + isSlice := sliceValue.Kind() == reflect.Slice + isMap := sliceValue.Kind() == reflect.Map if !isSlice && !isMap { return errors.New("needs a pointer to a slice or a map") } sliceElementType := sliceValue.Type().Elem() - var tp = tpStruct + tp := tpStruct if session.statement.RefTable == nil { if sliceElementType.Kind() == reflect.Ptr { if sliceElementType.Elem().Kind() == reflect.Struct { @@ -152,7 +154,6 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) if err != ErrCacheFailed { return err } - err = nil // !nashtsai! reset err to nil for ErrCacheFailed session.engine.logger.Warnf("Cache Find Failed") } } @@ -161,6 +162,16 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) } func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect.Value, sqlStr string, args ...interface{}) error { + elemType := containerValue.Type().Elem() + var isPointer bool + if elemType.Kind() == reflect.Ptr { + isPointer = true + elemType = elemType.Elem() + } + if elemType.Kind() == reflect.Ptr { + return errors.New("pointer to pointer is not supported") + } + rows, err := session.queryRows(sqlStr, args...) if err != nil { return err @@ -172,31 +183,13 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect return err } - var newElemFunc func(fields []string) reflect.Value - elemType := containerValue.Type().Elem() - var isPointer bool - if elemType.Kind() == reflect.Ptr { - isPointer = true - elemType = elemType.Elem() - } - if elemType.Kind() == reflect.Ptr { - return errors.New("pointer to pointer is not supported") + types, err := rows.ColumnTypes() + if err != nil { + return err } - newElemFunc = func(fields []string) reflect.Value { - switch elemType.Kind() { - case reflect.Slice: - slice := reflect.MakeSlice(elemType, len(fields), len(fields)) - x := reflect.New(slice.Type()) - x.Elem().Set(slice) - return x - case reflect.Map: - mp := reflect.MakeMap(elemType) - x := reflect.New(mp.Type()) - x.Elem().Set(mp) - return x - } - return reflect.New(elemType) + newElemFunc := func(fields []string) reflect.Value { + return utils.New(elemType, len(fields), len(fields)) } var containerValueSetFunc func(*reflect.Value, schemas.PK) error @@ -221,10 +214,15 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect containerValueSetFunc = func(newValue *reflect.Value, pk schemas.PK) error { keyValue := reflect.New(keyType) - err := convertPKToValue(table, keyValue.Interface(), pk) - if err != nil { - return err + cols := table.PKColumns() + if len(cols) == 1 { + if err := convert.AssignValue(keyValue, pk[0]); err != nil { + return err + } + } else { + keyValue.Set(reflect.ValueOf(&pk)) } + if isPointer { containerValue.SetMapIndex(keyValue.Elem(), newValue.Elem().Addr()) } else { @@ -235,13 +233,12 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect } if elemType.Kind() == reflect.Struct { - var newValue = newElemFunc(fields) - dataStruct := utils.ReflectValue(newValue.Interface()) - tb, err := session.engine.tagParser.ParseWithCache(dataStruct) + newValue := newElemFunc(fields) + tb, err := session.engine.tagParser.ParseWithCache(newValue) if err != nil { return err } - err = session.rows2Beans(rows, fields, tb, newElemFunc, containerValueSetFunc) + err = session.rows2Beans(rows, fields, types, tb, newElemFunc, containerValueSetFunc) rows.Close() if err != nil { return err @@ -250,18 +247,17 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect } for rows.Next() { - var newValue = newElemFunc(fields) + newValue := newElemFunc(fields) bean := newValue.Interface() switch elemType.Kind() { case reflect.Slice: - err = rows.ScanSlice(bean) + err = session.getSlice(rows, types, fields, bean) case reflect.Map: - err = rows.ScanMap(bean) + err = session.getMap(rows, types, fields, bean) default: err = rows.Scan(bean) } - if err != nil { return err } @@ -270,17 +266,7 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect return err } } - return nil -} - -func convertPKToValue(table *schemas.Table, dst interface{}, pk schemas.PK) error { - cols := table.PKColumns() - if len(cols) == 1 { - return convertAssign(dst, pk[0]) - } - - dst = pk - return nil + return rows.Err() } func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr interface{}, args ...interface{}) (err error) { @@ -322,7 +308,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in session.engine.logger.Debugf("[cacheFind] ids length > 500, no cache") return ErrCacheFailed } - var res = make([]string, len(table.PrimaryKeys)) + res := make([]string, len(table.PrimaryKeys)) err = rows.ScanSlice(&res) if err != nil { return err @@ -337,6 +323,9 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in ids = append(ids, pk) } + if rows.Err() != nil { + return rows.Err() + } session.engine.logger.Debugf("[cache] cache sql: %v, %v, %v, %v, %v", ids, tableName, sqlStr, newsql, args) err = caches.PutCacheSql(cacher, ids, tableName, newsql, args) @@ -351,7 +340,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in ididxes := make(map[string]int) var ides []schemas.PK - var temps = make([]interface{}, len(ids)) + temps := make([]interface{}, len(ids)) for idx, id := range ids { sid, err := id.ToString() @@ -466,14 +455,15 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in sliceValue.Set(reflect.Append(sliceValue, reflect.Indirect(reflect.ValueOf(bean)))) } } else if sliceValue.Kind() == reflect.Map { - var key = ids[j] + key := ids[j] keyType := sliceValue.Type().Key() + keyValue := reflect.New(keyType) var ikey interface{} if len(key) == 1 { - ikey, err = str2PK(fmt.Sprintf("%v", key[0]), keyType) - if err != nil { + if err := convert.AssignValue(keyValue, key[0]); err != nil { return err } + ikey = keyValue.Elem().Interface() } else { if keyType.Kind() != reflect.Slice { return errors.New("table have multiple primary keys, key is not schemas.PK or slice") diff --git a/session_get.go b/session_get.go index afedcd1f..9bb92a8b 100644 --- a/session_get.go +++ b/session_get.go @@ -8,39 +8,67 @@ import ( "database/sql" "errors" "fmt" + "math/big" "reflect" "strconv" + "time" "xorm.io/xorm/caches" + "xorm.io/xorm/convert" + "xorm.io/xorm/core" "xorm.io/xorm/internal/utils" "xorm.io/xorm/schemas" ) +var ( + // ErrObjectIsNil return error of object is nil + ErrObjectIsNil = errors.New("object should not be nil") +) + // Get retrieve one record from database, bean's non-empty fields // will be as conditions -func (session *Session) Get(bean interface{}) (bool, error) { +func (session *Session) Get(beans ...interface{}) (bool, error) { if session.isAutoClose { defer session.Close() } - return session.get(bean) + return session.get(beans...) } -func (session *Session) get(bean interface{}) (bool, error) { +func isPtrOfTime(v interface{}) bool { + if _, ok := v.(*time.Time); ok { + return true + } + + el := reflect.ValueOf(v).Elem() + if el.Kind() != reflect.Struct { + return false + } + + return el.Type().ConvertibleTo(schemas.TimeType) +} + +func (session *Session) get(beans ...interface{}) (bool, error) { defer session.resetStatement() if session.statement.LastError != nil { return false, session.statement.LastError } + if len(beans) == 0 { + return false, errors.New("needs at least one parameter for get") + } - beanValue := reflect.ValueOf(bean) + beanValue := reflect.ValueOf(beans[0]) if beanValue.Kind() != reflect.Ptr { return false, errors.New("needs a pointer to a value") } else if beanValue.Elem().Kind() == reflect.Ptr { return false, errors.New("a pointer to a pointer is not allowed") + } else if beanValue.IsNil() { + return false, ErrObjectIsNil } - if beanValue.Elem().Kind() == reflect.Struct { - if err := session.statement.SetRefBean(bean); err != nil { + var isStruct = beanValue.Elem().Kind() == reflect.Struct && !isPtrOfTime(beans[0]) + if isStruct { + if err := session.statement.SetRefBean(beans[0]); err != nil { return false, err } } @@ -50,11 +78,11 @@ func (session *Session) get(bean interface{}) (bool, error) { var err error if session.statement.RawSQL == "" { - if len(session.statement.TableName()) <= 0 { + if len(session.statement.TableName()) == 0 { return false, ErrTableNotFound } session.statement.Limit(1) - sqlStr, args, err = session.statement.GenGetSQL(bean) + sqlStr, args, err = session.statement.GenGetSQL(beans[0]) if err != nil { return false, err } @@ -65,10 +93,10 @@ func (session *Session) get(bean interface{}) (bool, error) { table := session.statement.RefTable - if session.statement.ColumnMap.IsEmpty() && session.canCache() && beanValue.Elem().Kind() == reflect.Struct { + if session.statement.ColumnMap.IsEmpty() && session.canCache() && isStruct { if cacher := session.engine.GetCacher(session.statement.TableName()); cacher != nil && !session.statement.GetUnscoped() { - has, err := session.cacheGet(bean, sqlStr, args...) + has, err := session.cacheGet(beans[0], sqlStr, args...) if err != ErrCacheFailed { return has, err } @@ -76,12 +104,12 @@ func (session *Session) get(bean interface{}) (bool, error) { } context := session.statement.Context - if context != nil { + if context != nil && isStruct { res := context.Get(fmt.Sprintf("%v-%v", sqlStr, args)) if res != nil { session.engine.logger.Debugf("hit context cache: %s", sqlStr) - structValue := reflect.Indirect(reflect.ValueOf(bean)) + structValue := reflect.Indirect(reflect.ValueOf(beans[0])) structValue.Set(reflect.Indirect(reflect.ValueOf(res))) session.lastSQL = "" session.lastSQLArgs = nil @@ -89,19 +117,33 @@ func (session *Session) get(bean interface{}) (bool, error) { } } - has, err := session.nocacheGet(beanValue.Elem().Kind(), table, bean, sqlStr, args...) + has, err := session.nocacheGet(beanValue.Elem().Kind(), table, beans, sqlStr, args...) if err != nil || !has { return has, err } - if context != nil { - context.Put(fmt.Sprintf("%v-%v", sqlStr, args), bean) + if context != nil && isStruct { + context.Put(fmt.Sprintf("%v-%v", sqlStr, args), beans[0]) } return true, nil } -func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table, bean interface{}, sqlStr string, args ...interface{}) (bool, error) { +func isScannableStruct(bean interface{}, typeLen int) bool { + switch bean.(type) { + case *time.Time: + return false + case sql.Scanner: + return false + case convert.Conversion: + return typeLen > 1 + case *big.Float: + return false + } + return true +} + +func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table, beans []interface{}, sqlStr string, args ...interface{}) (bool, error) { rows, err := session.queryRows(sqlStr, args...) if err != nil { return false, err @@ -109,161 +151,124 @@ func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table, defer rows.Close() if !rows.Next() { - if rows.Err() != nil { - return false, rows.Err() - } - return false, nil + return false, rows.Err() } - switch bean.(type) { - case sql.NullInt64, sql.NullBool, sql.NullFloat64, sql.NullString: - return true, rows.Scan(&bean) - case *sql.NullInt64, *sql.NullBool, *sql.NullFloat64, *sql.NullString: - return true, rows.Scan(bean) - case *string: - var res sql.NullString - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*string)) = res.String - } - return true, nil - case *int: - var res sql.NullInt64 - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*int)) = int(res.Int64) - } - return true, nil - case *int8: - var res sql.NullInt64 - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*int8)) = int8(res.Int64) - } - return true, nil - case *int16: - var res sql.NullInt64 - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*int16)) = int16(res.Int64) - } - return true, nil - case *int32: - var res sql.NullInt64 - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*int32)) = int32(res.Int64) - } - return true, nil - case *int64: - var res sql.NullInt64 - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*int64)) = int64(res.Int64) - } - return true, nil - case *uint: - var res sql.NullInt64 - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*uint)) = uint(res.Int64) - } - return true, nil - case *uint8: - var res sql.NullInt64 - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*uint8)) = uint8(res.Int64) - } - return true, nil - case *uint16: - var res sql.NullInt64 - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*uint16)) = uint16(res.Int64) - } - return true, nil - case *uint32: - var res sql.NullInt64 - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*uint32)) = uint32(res.Int64) - } - return true, nil - case *uint64: - var res sql.NullInt64 - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*uint64)) = uint64(res.Int64) - } - return true, nil - case *bool: - var res sql.NullBool - if err := rows.Scan(&res); err != nil { - return true, err - } - if res.Valid { - *(bean.(*bool)) = res.Bool - } - return true, nil + // WARN: Alougth rows return true, but we may also return error. + types, err := rows.ColumnTypes() + if err != nil { + return true, err + } + fields, err := rows.Columns() + if err != nil { + return true, err } - switch beanKind { - case reflect.Struct: - fields, err := rows.Columns() + if err := session.scan(rows, table, beanKind, beans, types, fields); err != nil { + return true, err + } + rows.Close() + + return true, session.executeProcessors() +} + +func (session *Session) scan(rows *core.Rows, table *schemas.Table, firstBeanKind reflect.Kind, beans []interface{}, types []*sql.ColumnType, fields []string) error { + if len(beans) == 1 { + bean := beans[0] + switch firstBeanKind { + case reflect.Struct: + if !isScannableStruct(bean, len(types)) { + break + } + scanResults, err := session.row2Slice(rows, fields, types, bean) + if err != nil { + return err + } + + dataStruct := utils.ReflectValue(bean) + _, err = session.slice2Bean(scanResults, fields, bean, &dataStruct, table) + return err + case reflect.Slice: + return session.getSlice(rows, types, fields, bean) + case reflect.Map: + return session.getMap(rows, types, fields, bean) + } + } + + if len(beans) != len(types) { + return fmt.Errorf("expected columns %d, but only %d variables", len(types), len(beans)) + } + + return session.engine.scan(rows, fields, types, beans...) +} + +func (session *Session) getSlice(rows *core.Rows, types []*sql.ColumnType, fields []string, bean interface{}) error { + switch t := bean.(type) { + case *[]string: + res, err := session.engine.scanStringInterface(rows, fields, types) if err != nil { - // WARN: Alougth rows return true, but get fields failed - return true, err + return err } - scanResults, err := session.row2Slice(rows, fields, bean) - if err != nil { - return false, err + var needAppend = len(*t) == 0 // both support slice is empty or has been initlized + for i, r := range res { + if needAppend { + *t = append(*t, r.(*sql.NullString).String) + } else { + (*t)[i] = r.(*sql.NullString).String + } } - // close it before convert data - rows.Close() - - dataStruct := utils.ReflectValue(bean) - _, err = session.slice2Bean(scanResults, fields, bean, &dataStruct, table) + return nil + case *[]interface{}: + scanResults, err := session.engine.scanInterfaces(rows, fields, types) if err != nil { - return true, err + return err } - - return true, session.executeProcessors() - case reflect.Slice: - err = rows.ScanSlice(bean) - case reflect.Map: - err = rows.ScanMap(bean) - case reflect.String, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, - reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - err = rows.Scan(bean) + var needAppend = len(*t) == 0 + for ii := range fields { + s, err := convert.Interface2Interface(session.engine.DatabaseTZ, scanResults[ii]) + if err != nil { + return err + } + if needAppend { + *t = append(*t, s) + } else { + (*t)[ii] = s + } + } + return nil default: - err = rows.Scan(bean) + return fmt.Errorf("unspoorted slice type: %t", t) } +} - return true, err +func (session *Session) getMap(rows *core.Rows, types []*sql.ColumnType, fields []string, bean interface{}) error { + switch t := bean.(type) { + case *map[string]string: + scanResults, err := session.engine.scanStringInterface(rows, fields, types) + if err != nil { + return err + } + for ii, key := range fields { + (*t)[key] = scanResults[ii].(*sql.NullString).String + } + return nil + case *map[string]interface{}: + scanResults, err := session.engine.scanInterfaces(rows, fields, types) + if err != nil { + return err + } + for ii, key := range fields { + s, err := convert.Interface2Interface(session.engine.DatabaseTZ, scanResults[ii]) + if err != nil { + return err + } + (*t)[key] = s + } + return nil + default: + return fmt.Errorf("unspoorted map type: %t", t) + } } func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interface{}) (has bool, err error) { @@ -297,9 +302,12 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf if rows.Next() { err = rows.ScanSlice(&res) if err != nil { - return false, err + return true, err } } else { + if rows.Err() != nil { + return false, rows.Err() + } return false, ErrCacheFailed } @@ -339,7 +347,7 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf cacheBean := cacher.GetBean(tableName, sid) if cacheBean == nil { cacheBean = bean - has, err = session.nocacheGet(reflect.Struct, table, cacheBean, sqlStr, args...) + has, err = session.nocacheGet(reflect.Struct, table, []interface{}{cacheBean}, sqlStr, args...) if err != nil || !has { return has, err } diff --git a/session_insert.go b/session_insert.go index 5f968151..fc025613 100644 --- a/session_insert.go +++ b/session_insert.go @@ -9,15 +9,17 @@ import ( "fmt" "reflect" "sort" - "strconv" "strings" + "time" + "xorm.io/xorm/convert" + "xorm.io/xorm/dialects" "xorm.io/xorm/internal/utils" "xorm.io/xorm/schemas" ) // ErrNoElementsOnSlice represents an error there is no element when insert -var ErrNoElementsOnSlice = errors.New("No element on slice when insert") +var ErrNoElementsOnSlice = errors.New("no element on slice when insert") // Insert insert one or more beans func (session *Session) Insert(beans ...interface{}) (int64, error) { @@ -35,71 +37,42 @@ func (session *Session) Insert(beans ...interface{}) (int64, error) { }() for _, bean := range beans { - switch bean.(type) { + var cnt int64 + var err error + switch v := bean.(type) { case map[string]interface{}: - cnt, err := session.insertMapInterface(bean.(map[string]interface{})) - if err != nil { - return affected, err - } - affected += cnt + cnt, err = session.insertMapInterface(v) case []map[string]interface{}: - s := bean.([]map[string]interface{}) - for i := 0; i < len(s); i++ { - cnt, err := session.insertMapInterface(s[i]) - if err != nil { - return affected, err - } - affected += cnt - } + cnt, err = session.insertMultipleMapInterface(v) case map[string]string: - cnt, err := session.insertMapString(bean.(map[string]string)) - if err != nil { - return affected, err - } - affected += cnt + cnt, err = session.insertMapString(v) case []map[string]string: - s := bean.([]map[string]string) - for i := 0; i < len(s); i++ { - cnt, err := session.insertMapString(s[i]) - if err != nil { - return affected, err - } - affected += cnt - } + cnt, err = session.insertMultipleMapString(v) default: sliceValue := reflect.Indirect(reflect.ValueOf(bean)) if sliceValue.Kind() == reflect.Slice { - size := sliceValue.Len() - if size <= 0 { - return 0, ErrNoElementsOnSlice - } - - cnt, err := session.innerInsertMulti(bean) - if err != nil { - return affected, err - } - affected += cnt + cnt, err = session.insertMultipleStruct(bean) } else { - cnt, err := session.innerInsert(bean) - if err != nil { - return affected, err - } - affected += cnt + cnt, err = session.insertStruct(bean) } } + if err != nil { + return affected, err + } + affected += cnt } return affected, err } -func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error) { +func (session *Session) insertMultipleStruct(rowsSlicePtr interface{}) (int64, error) { sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) if sliceValue.Kind() != reflect.Slice { return 0, errors.New("needs a pointer to a slice") } if sliceValue.Len() <= 0 { - return 0, errors.New("could not insert a empty slice") + return 0, ErrNoElementsOnSlice } if err := session.statement.SetRefBean(sliceValue.Index(0).Interface()); err != nil { @@ -107,7 +80,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error } tableName := session.statement.TableName() - if len(tableName) <= 0 { + if len(tableName) == 0 { return 0, ErrTableNotFound } @@ -117,7 +90,6 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error colNames []string colMultiPlaces []string args []interface{} - cols []*schemas.Column ) for i := 0; i < size; i++ { @@ -150,6 +122,12 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error } fieldValue := *ptrFieldValue if col.IsAutoIncrement && utils.IsZero(fieldValue.Interface()) { + if session.engine.dialect.Features().AutoincrMode == dialects.SequenceAutoincrMode { + if i == 0 { + colNames = append(colNames, col.Name) + } + colPlaces = append(colPlaces, utils.SeqName(tableName)+".nextval") + } continue } if col.MapType == schemas.ONLYFROMDB { @@ -164,8 +142,18 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error if len(session.statement.ColumnMap) > 0 && !session.statement.ColumnMap.Contain(col.Name) { continue } + // !satorunooshie! set fieldValue as nil when column is nullable and zero-value + if _, ok := getFlagForColumn(session.statement.NullableMap, col); ok { + if col.Nullable && utils.IsValueZero(fieldValue) { + var nilValue *int + fieldValue = reflect.ValueOf(nilValue) + } + } if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime { - val, t := session.engine.nowTime(col) + val, t, err := session.engine.nowTime(col) + if err != nil { + return 0, err + } args = append(args, val) var colName = col.Name @@ -190,7 +178,6 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error if i == 0 { colNames = append(colNames, col.Name) - cols = append(cols, col) } colPlaces = append(colPlaces, "?") } @@ -221,7 +208,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error return 0, err } - session.cacheInsert(tableName) + _ = session.cacheInsert(tableName) lenAfterClosures := len(session.afterClosures) for i := 0; i < size; i++ { @@ -268,18 +255,14 @@ func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) { return 0, ErrPtrSliceType } - if sliceValue.Len() <= 0 { - return 0, ErrNoElementsOnSlice - } - - return session.innerInsertMulti(rowsSlicePtr) + return session.insertMultipleStruct(rowsSlicePtr) } -func (session *Session) innerInsert(bean interface{}) (int64, error) { +func (session *Session) insertStruct(bean interface{}) (int64, error) { if err := session.statement.SetRefBean(bean); err != nil { return 0, err } - if len(session.statement.TableName()) <= 0 { + if len(session.statement.TableName()) == 0 { return 0, ErrTableNotFound } @@ -305,6 +288,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { if err != nil { return 0, err } + sqlStr = session.engine.dialect.Quoter().Replace(sqlStr) handleAfterInsertProcessorFunc := func(bean interface{}) { if session.isAutoCommit { @@ -324,7 +308,6 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { copy(afterClosures, session.afterClosures) session.afterInsertBeans[bean] = &afterClosures } - } else { if _, ok := interface{}(bean).(AfterInsertProcessor); ok { session.afterInsertBeans[bean] = nil @@ -334,17 +317,55 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { cleanupProcessorsClosures(&session.afterClosures) // cleanup after used } - // for postgres, many of them didn't implement lastInsertId, so we should - // implemented it ourself. - if session.engine.dialect.URI().DBType == schemas.ORACLE && len(table.AutoIncrement) > 0 { - res, err := session.queryBytes("select seq_atable.currval from dual", args...) - if err != nil { - return 0, err + // if there is auto increment column and driver don't support return it + if len(table.AutoIncrement) > 0 && !session.engine.driver.Features().SupportReturnInsertedID { + var sql string + var newArgs []interface{} + var needCommit bool + var id int64 + if session.engine.dialect.URI().DBType == schemas.ORACLE || session.engine.dialect.URI().DBType == schemas.DAMENG { + if session.isAutoCommit { // if it's not in transaction + if err := session.Begin(); err != nil { + return 0, err + } + needCommit = true + } + _, err := session.exec(sqlStr, args...) + if err != nil { + return 0, err + } + i := utils.IndexSlice(colNames, table.AutoIncrement) + if i > -1 { + id, err = convert.AsInt64(args[i]) + if err != nil { + return 0, err + } + } else { + sql = fmt.Sprintf("select %s.currval from dual", utils.SeqName(tableName)) + } + } else { + sql = sqlStr + newArgs = args + } + + if id == 0 { + err := session.queryRow(sql, newArgs...).Scan(&id) + if err != nil { + return 0, err + } + if needCommit { + if err := session.Commit(); err != nil { + return 0, err + } + } + if id == 0 { + return 0, errors.New("insert successfully but not returned id") + } } defer handleAfterInsertProcessorFunc(bean) - session.cacheInsert(tableName) + _ = session.cacheInsert(tableName) if table.Version != "" && session.statement.CheckVersion { verValue, err := table.VersionColumn().ValueOf(bean) @@ -355,16 +376,6 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { } } - if len(res) < 1 { - return 0, errors.New("insert no error but not returned id") - } - - idByte := res[0][table.AutoIncrement] - id, err := strconv.ParseInt(string(idByte), 10, 64) - if err != nil || id <= 0 { - return 1, err - } - aiValue, err := table.AutoIncrColumn().ValueOf(bean) if err != nil { session.engine.logger.Errorf("%v", err) @@ -374,51 +385,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { return 1, nil } - aiValue.Set(int64ToIntValue(id, aiValue.Type())) - - return 1, nil - } else if len(table.AutoIncrement) > 0 && (session.engine.dialect.URI().DBType == schemas.POSTGRES || - session.engine.dialect.URI().DBType == schemas.MSSQL) { - res, err := session.queryBytes(sqlStr, args...) - - if err != nil { - return 0, err - } - defer handleAfterInsertProcessorFunc(bean) - - session.cacheInsert(tableName) - - if table.Version != "" && session.statement.CheckVersion { - verValue, err := table.VersionColumn().ValueOf(bean) - if err != nil { - session.engine.logger.Errorf("%v", err) - } else if verValue.IsValid() && verValue.CanSet() { - session.incrVersionFieldValue(verValue) - } - } - - if len(res) < 1 { - return 0, errors.New("insert successfully but not returned id") - } - - idByte := res[0][table.AutoIncrement] - id, err := strconv.ParseInt(string(idByte), 10, 64) - if err != nil || id <= 0 { - return 1, err - } - - aiValue, err := table.AutoIncrColumn().ValueOf(bean) - if err != nil { - session.engine.logger.Errorf("%v", err) - } - - if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() { - return 1, nil - } - - aiValue.Set(int64ToIntValue(id, aiValue.Type())) - - return 1, nil + return 1, convert.AssignValue(*aiValue, id) } res, err := session.exec(sqlStr, args...) @@ -428,7 +395,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { defer handleAfterInsertProcessorFunc(bean) - session.cacheInsert(tableName) + _ = session.cacheInsert(tableName) if table.Version != "" && session.statement.CheckVersion { verValue, err := table.VersionColumn().ValueOf(bean) @@ -458,7 +425,9 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { return res.RowsAffected() } - aiValue.Set(int64ToIntValue(id, aiValue.Type())) + if err := convert.AssignValue(*aiValue, id); err != nil { + return 0, err + } return res.RowsAffected() } @@ -466,12 +435,13 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { // InsertOne insert only one struct into database as a record. // The in parameter bean must a struct or a point to struct. The return // parameter is inserted and error +// Deprecated: Please use Insert directly func (session *Session) InsertOne(bean interface{}) (int64, error) { if session.isAutoClose { defer session.Close() } - return session.innerInsert(bean) + return session.insertStruct(bean) } func (session *Session) cacheInsert(table string) error { @@ -497,19 +467,12 @@ func (session *Session) genInsertColumns(bean interface{}) ([]string, []interfac if col.MapType == schemas.ONLYFROMDB { continue } - - if col.IsDeleted { - continue - } - if session.statement.OmitColumnMap.Contain(col.Name) { continue } - if len(session.statement.ColumnMap) > 0 && !session.statement.ColumnMap.Contain(col.Name) { continue } - if session.statement.IncrColumns.IsColExist(col.Name) { continue } else if session.statement.DecrColumns.IsColExist(col.Name) { @@ -518,6 +481,16 @@ func (session *Session) genInsertColumns(bean interface{}) ([]string, []interfac continue } + if col.IsDeleted { + arg, err := dialects.FormatColumnTime(session.engine.dialect, session.engine.DatabaseTZ, col, time.Time{}) + if err != nil { + return nil, nil, err + } + args = append(args, arg) + colNames = append(colNames, col.Name) + continue + } + fieldValuePtr, err := col.ValueOf(bean) if err != nil { return nil, nil, err @@ -538,7 +511,10 @@ func (session *Session) genInsertColumns(bean interface{}) ([]string, []interfac if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime /*&& isZero(fieldValue.Interface())*/ { // if time is non-empty, then set to auto time - val, t := session.engine.nowTime(col) + val, t, err := session.engine.nowTime(col) + if err != nil { + return nil, nil, err + } args = append(args, val) var colName = col.Name @@ -567,7 +543,7 @@ func (session *Session) insertMapInterface(m map[string]interface{}) (int64, err } tableName := session.statement.TableName() - if len(tableName) <= 0 { + if len(tableName) == 0 { return 0, ErrTableNotFound } @@ -588,13 +564,44 @@ func (session *Session) insertMapInterface(m map[string]interface{}) (int64, err return session.insertMap(columns, args) } +func (session *Session) insertMultipleMapInterface(maps []map[string]interface{}) (int64, error) { + if len(maps) == 0 { + return 0, ErrNoElementsOnSlice + } + + tableName := session.statement.TableName() + if len(tableName) == 0 { + return 0, ErrTableNotFound + } + + var columns = make([]string, 0, len(maps[0])) + exprs := session.statement.ExprColumns + for k := range maps[0] { + if !exprs.IsColExist(k) { + columns = append(columns, k) + } + } + sort.Strings(columns) + + var argss = make([][]interface{}, 0, len(maps)) + for _, m := range maps { + var args = make([]interface{}, 0, len(m)) + for _, colName := range columns { + args = append(args, m[colName]) + } + argss = append(argss, args) + } + + return session.insertMultipleMap(columns, argss) +} + func (session *Session) insertMapString(m map[string]string) (int64, error) { if len(m) == 0 { return 0, ErrParamsType } tableName := session.statement.TableName() - if len(tableName) <= 0 { + if len(tableName) == 0 { return 0, ErrTableNotFound } @@ -616,9 +623,40 @@ func (session *Session) insertMapString(m map[string]string) (int64, error) { return session.insertMap(columns, args) } +func (session *Session) insertMultipleMapString(maps []map[string]string) (int64, error) { + if len(maps) == 0 { + return 0, ErrNoElementsOnSlice + } + + tableName := session.statement.TableName() + if len(tableName) == 0 { + return 0, ErrTableNotFound + } + + var columns = make([]string, 0, len(maps[0])) + exprs := session.statement.ExprColumns + for k := range maps[0] { + if !exprs.IsColExist(k) { + columns = append(columns, k) + } + } + sort.Strings(columns) + + var argss = make([][]interface{}, 0, len(maps)) + for _, m := range maps { + var args = make([]interface{}, 0, len(m)) + for _, colName := range columns { + args = append(args, m[colName]) + } + argss = append(argss, args) + } + + return session.insertMultipleMap(columns, argss) +} + func (session *Session) insertMap(columns []string, args []interface{}) (int64, error) { tableName := session.statement.TableName() - if len(tableName) <= 0 { + if len(tableName) == 0 { return 0, ErrTableNotFound } @@ -626,6 +664,34 @@ func (session *Session) insertMap(columns []string, args []interface{}) (int64, if err != nil { return 0, err } + sql = session.engine.dialect.Quoter().Replace(sql) + + if err := session.cacheInsert(tableName); err != nil { + return 0, err + } + + res, err := session.exec(sql, args...) + if err != nil { + return 0, err + } + affected, err := res.RowsAffected() + if err != nil { + return 0, err + } + return affected, nil +} + +func (session *Session) insertMultipleMap(columns []string, argss [][]interface{}) (int64, error) { + tableName := session.statement.TableName() + if len(tableName) == 0 { + return 0, ErrTableNotFound + } + + sql, args, err := session.statement.GenInsertMultipleMapSQL(columns, argss) + if err != nil { + return 0, err + } + sql = session.engine.dialect.Quoter().Replace(sql) if err := session.cacheInsert(tableName); err != nil { return 0, err diff --git a/session_iterate.go b/session_iterate.go index 8cab8f48..afb9a7cc 100644 --- a/session_iterate.go +++ b/session_iterate.go @@ -54,7 +54,7 @@ func (session *Session) Iterate(bean interface{}, fun IterFunc) error { } i++ } - return err + return rows.Err() } // BufferSize sets the buffersize for iterate @@ -95,7 +95,7 @@ func (session *Session) bufferIterate(bean interface{}, fun IterFunc) error { break } - start = start + slice.Elem().Len() + start += slice.Elem().Len() if pLimitN != nil && start+bufferSize > *pLimitN { bufferSize = *pLimitN - start } diff --git a/session_query.go b/session_query.go deleted file mode 100644 index 12136466..00000000 --- a/session_query.go +++ /dev/null @@ -1,257 +0,0 @@ -// Copyright 2017 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package xorm - -import ( - "fmt" - "reflect" - "strconv" - "time" - - "xorm.io/xorm/core" - "xorm.io/xorm/schemas" -) - -// Query runs a raw sql and return records as []map[string][]byte -func (session *Session) Query(sqlOrArgs ...interface{}) ([]map[string][]byte, error) { - if session.isAutoClose { - defer session.Close() - } - - sqlStr, args, err := session.statement.GenQuerySQL(sqlOrArgs...) - if err != nil { - return nil, err - } - - return session.queryBytes(sqlStr, args...) -} - -func value2String(rawValue *reflect.Value) (str string, err error) { - aa := reflect.TypeOf((*rawValue).Interface()) - vv := reflect.ValueOf((*rawValue).Interface()) - switch aa.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - str = strconv.FormatInt(vv.Int(), 10) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - str = strconv.FormatUint(vv.Uint(), 10) - case reflect.Float32, reflect.Float64: - str = strconv.FormatFloat(vv.Float(), 'f', -1, 64) - case reflect.String: - str = vv.String() - case reflect.Array, reflect.Slice: - switch aa.Elem().Kind() { - case reflect.Uint8: - data := rawValue.Interface().([]byte) - str = string(data) - if str == "\x00" { - str = "0" - } - default: - err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name()) - } - // time type - case reflect.Struct: - if aa.ConvertibleTo(schemas.TimeType) { - str = vv.Convert(schemas.TimeType).Interface().(time.Time).Format(time.RFC3339Nano) - } else { - err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name()) - } - case reflect.Bool: - str = strconv.FormatBool(vv.Bool()) - case reflect.Complex128, reflect.Complex64: - str = fmt.Sprintf("%v", vv.Complex()) - /* TODO: unsupported types below - case reflect.Map: - case reflect.Ptr: - case reflect.Uintptr: - case reflect.UnsafePointer: - case reflect.Chan, reflect.Func, reflect.Interface: - */ - default: - err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name()) - } - return -} - -func row2mapStr(rows *core.Rows, fields []string) (resultsMap map[string]string, err error) { - result := make(map[string]string) - scanResultContainers := make([]interface{}, len(fields)) - for i := 0; i < len(fields); i++ { - var scanResultContainer interface{} - scanResultContainers[i] = &scanResultContainer - } - if err := rows.Scan(scanResultContainers...); err != nil { - return nil, err - } - - for ii, key := range fields { - rawValue := reflect.Indirect(reflect.ValueOf(scanResultContainers[ii])) - // if row is null then as empty string - if rawValue.Interface() == nil { - result[key] = "" - continue - } - - if data, err := value2String(&rawValue); err == nil { - result[key] = data - } else { - return nil, err - } - } - return result, nil -} - -func row2sliceStr(rows *core.Rows, fields []string) (results []string, err error) { - result := make([]string, 0, len(fields)) - scanResultContainers := make([]interface{}, len(fields)) - for i := 0; i < len(fields); i++ { - var scanResultContainer interface{} - scanResultContainers[i] = &scanResultContainer - } - if err := rows.Scan(scanResultContainers...); err != nil { - return nil, err - } - - for i := 0; i < len(fields); i++ { - rawValue := reflect.Indirect(reflect.ValueOf(scanResultContainers[i])) - // if row is null then as empty string - if rawValue.Interface() == nil { - result = append(result, "") - continue - } - - if data, err := value2String(&rawValue); err == nil { - result = append(result, data) - } else { - return nil, err - } - } - return result, nil -} - -func rows2Strings(rows *core.Rows) (resultsSlice []map[string]string, err error) { - fields, err := rows.Columns() - if err != nil { - return nil, err - } - for rows.Next() { - result, err := row2mapStr(rows, fields) - if err != nil { - return nil, err - } - resultsSlice = append(resultsSlice, result) - } - - return resultsSlice, nil -} - -func rows2SliceString(rows *core.Rows) (resultsSlice [][]string, err error) { - fields, err := rows.Columns() - if err != nil { - return nil, err - } - for rows.Next() { - record, err := row2sliceStr(rows, fields) - if err != nil { - return nil, err - } - resultsSlice = append(resultsSlice, record) - } - - return resultsSlice, nil -} - -// QueryString runs a raw sql and return records as []map[string]string -func (session *Session) QueryString(sqlOrArgs ...interface{}) ([]map[string]string, error) { - if session.isAutoClose { - defer session.Close() - } - - sqlStr, args, err := session.statement.GenQuerySQL(sqlOrArgs...) - if err != nil { - return nil, err - } - - rows, err := session.queryRows(sqlStr, args...) - if err != nil { - return nil, err - } - defer rows.Close() - - return rows2Strings(rows) -} - -// QuerySliceString runs a raw sql and return records as [][]string -func (session *Session) QuerySliceString(sqlOrArgs ...interface{}) ([][]string, error) { - if session.isAutoClose { - defer session.Close() - } - - sqlStr, args, err := session.statement.GenQuerySQL(sqlOrArgs...) - if err != nil { - return nil, err - } - - rows, err := session.queryRows(sqlStr, args...) - if err != nil { - return nil, err - } - defer rows.Close() - - return rows2SliceString(rows) -} - -func row2mapInterface(rows *core.Rows, fields []string) (resultsMap map[string]interface{}, err error) { - resultsMap = make(map[string]interface{}, len(fields)) - scanResultContainers := make([]interface{}, len(fields)) - for i := 0; i < len(fields); i++ { - var scanResultContainer interface{} - scanResultContainers[i] = &scanResultContainer - } - if err := rows.Scan(scanResultContainers...); err != nil { - return nil, err - } - - for ii, key := range fields { - resultsMap[key] = reflect.Indirect(reflect.ValueOf(scanResultContainers[ii])).Interface() - } - return -} - -func rows2Interfaces(rows *core.Rows) (resultsSlice []map[string]interface{}, err error) { - fields, err := rows.Columns() - if err != nil { - return nil, err - } - for rows.Next() { - result, err := row2mapInterface(rows, fields) - if err != nil { - return nil, err - } - resultsSlice = append(resultsSlice, result) - } - - return resultsSlice, nil -} - -// QueryInterface runs a raw sql and return records as []map[string]interface{} -func (session *Session) QueryInterface(sqlOrArgs ...interface{}) ([]map[string]interface{}, error) { - if session.isAutoClose { - defer session.Close() - } - - sqlStr, args, err := session.statement.GenQuerySQL(sqlOrArgs...) - if err != nil { - return nil, err - } - - rows, err := session.queryRows(sqlStr, args...) - if err != nil { - return nil, err - } - defer rows.Close() - - return rows2Interfaces(rows) -} diff --git a/session_raw.go b/session_raw.go index 4cfe297a..add584d0 100644 --- a/session_raw.go +++ b/session_raw.go @@ -6,7 +6,7 @@ package xorm import ( "database/sql" - "reflect" + "strings" "xorm.io/xorm/core" ) @@ -33,7 +33,7 @@ func (session *Session) queryRows(sqlStr string, args ...interface{}) (*core.Row if session.isAutoCommit { var db *core.DB - if session.sessionType == groupSession { + if session.sessionType == groupSession && strings.EqualFold(strings.TrimSpace(sqlStr)[:6], "select") && !session.statement.IsForUpdate { db = session.engine.engineGroup.Slave().DB() } else { db = session.DB() @@ -46,91 +46,106 @@ func (session *Session) queryRows(sqlStr string, args ...interface{}) (*core.Row return nil, err } - rows, err := stmt.QueryContext(session.ctx, args...) - if err != nil { - return nil, err - } - return rows, nil + return stmt.QueryContext(session.ctx, args...) } - rows, err := db.QueryContext(session.ctx, sqlStr, args...) + return db.QueryContext(session.ctx, sqlStr, args...) + } + + if session.prepareStmt { + stmt, err := session.doPrepareTx(sqlStr) if err != nil { return nil, err } - return rows, nil + + return stmt.QueryContext(session.ctx, args...) } - rows, err := session.tx.QueryContext(session.ctx, sqlStr, args...) - if err != nil { - return nil, err - } - return rows, nil + return session.tx.QueryContext(session.ctx, sqlStr, args...) } func (session *Session) queryRow(sqlStr string, args ...interface{}) *core.Row { return core.NewRow(session.queryRows(sqlStr, args...)) } -func value2Bytes(rawValue *reflect.Value) ([]byte, error) { - str, err := value2String(rawValue) +// Query runs a raw sql and return records as []map[string][]byte +func (session *Session) Query(sqlOrArgs ...interface{}) ([]map[string][]byte, error) { + if session.isAutoClose { + defer session.Close() + } + + sqlStr, args, err := session.statement.GenQuerySQL(sqlOrArgs...) if err != nil { return nil, err } - return []byte(str), nil -} -func row2map(rows *core.Rows, fields []string) (resultsMap map[string][]byte, err error) { - result := make(map[string][]byte) - scanResultContainers := make([]interface{}, len(fields)) - for i := 0; i < len(fields); i++ { - var scanResultContainer interface{} - scanResultContainers[i] = &scanResultContainer - } - if err := rows.Scan(scanResultContainers...); err != nil { - return nil, err - } - - for ii, key := range fields { - rawValue := reflect.Indirect(reflect.ValueOf(scanResultContainers[ii])) - //if row is null then ignore - if rawValue.Interface() == nil { - result[key] = []byte{} - continue - } - - if data, err := value2Bytes(&rawValue); err == nil { - result[key] = data - } else { - return nil, err // !nashtsai! REVIEW, should return err or just error log? - } - } - return result, nil -} - -func rows2maps(rows *core.Rows) (resultsSlice []map[string][]byte, err error) { - fields, err := rows.Columns() - if err != nil { - return nil, err - } - for rows.Next() { - result, err := row2map(rows, fields) - if err != nil { - return nil, err - } - resultsSlice = append(resultsSlice, result) - } - - return resultsSlice, nil -} - -func (session *Session) queryBytes(sqlStr string, args ...interface{}) ([]map[string][]byte, error) { rows, err := session.queryRows(sqlStr, args...) if err != nil { return nil, err } defer rows.Close() - return rows2maps(rows) + return session.engine.scanByteMaps(rows) +} + +// QueryString runs a raw sql and return records as []map[string]string +func (session *Session) QueryString(sqlOrArgs ...interface{}) ([]map[string]string, error) { + if session.isAutoClose { + defer session.Close() + } + + sqlStr, args, err := session.statement.GenQuerySQL(sqlOrArgs...) + if err != nil { + return nil, err + } + + rows, err := session.queryRows(sqlStr, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + return session.engine.ScanStringMaps(rows) +} + +// QuerySliceString runs a raw sql and return records as [][]string +func (session *Session) QuerySliceString(sqlOrArgs ...interface{}) ([][]string, error) { + if session.isAutoClose { + defer session.Close() + } + + sqlStr, args, err := session.statement.GenQuerySQL(sqlOrArgs...) + if err != nil { + return nil, err + } + + rows, err := session.queryRows(sqlStr, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + return session.engine.ScanStringSlices(rows) +} + +// QueryInterface runs a raw sql and return records as []map[string]interface{} +func (session *Session) QueryInterface(sqlOrArgs ...interface{}) ([]map[string]interface{}, error) { + if session.isAutoClose { + defer session.Close() + } + + sqlStr, args, err := session.statement.GenQuerySQL(sqlOrArgs...) + if err != nil { + return nil, err + } + + rows, err := session.queryRows(sqlStr, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + return session.engine.ScanInterfaceMaps(rows) } func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, error) { @@ -142,6 +157,13 @@ func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, er session.lastSQLArgs = args if !session.isAutoCommit { + if session.prepareStmt { + stmt, err := session.doPrepareTx(sqlStr) + if err != nil { + return nil, err + } + return stmt.ExecContext(session.ctx, args...) + } return session.tx.ExecContext(session.ctx, sqlStr, args...) } @@ -150,12 +172,7 @@ func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, er if err != nil { return nil, err } - - res, err := stmt.ExecContext(session.ctx, args...) - if err != nil { - return nil, err - } - return res, nil + return stmt.ExecContext(session.ctx, args...) } return session.DB().ExecContext(session.ctx, sqlStr, args...) diff --git a/session_schema.go b/session_schema.go index 9ccf8abe..e66c3b42 100644 --- a/session_schema.go +++ b/session_schema.go @@ -6,12 +6,14 @@ package xorm import ( "bufio" + "context" "database/sql" "fmt" "io" "os" "strings" + "xorm.io/xorm/dialects" "xorm.io/xorm/internal/utils" "xorm.io/xorm/schemas" ) @@ -40,13 +42,28 @@ func (session *Session) createTable(bean interface{}) error { return err } - sqlStrs := session.statement.GenCreateTableSQL() - for _, s := range sqlStrs { - _, err := session.exec(s) + session.statement.RefTable.StoreEngine = session.statement.StoreEngine + session.statement.RefTable.Charset = session.statement.Charset + tableName := session.statement.TableName() + refTable := session.statement.RefTable + if refTable.AutoIncrement != "" && session.engine.dialect.Features().AutoincrMode == dialects.SequenceAutoincrMode { + sqlStr, err := session.engine.dialect.CreateSequenceSQL(context.Background(), session.engine.db, utils.SeqName(tableName)) if err != nil { return err } + if _, err := session.exec(sqlStr); err != nil { + return err + } } + + sqlStr, _, err := session.engine.dialect.CreateTableSQL(context.Background(), session.engine.db, refTable, tableName) + if err != nil { + return err + } + if _, err := session.exec(sqlStr); err != nil { + return err + } + return nil } @@ -141,11 +158,32 @@ func (session *Session) dropTable(beanOrTableName interface{}) error { checkIfExist = exist } - if checkIfExist { - _, err := session.exec(sqlStr) + if !checkIfExist { + return nil + } + if _, err := session.exec(sqlStr); err != nil { return err } - return nil + + if session.engine.dialect.Features().AutoincrMode == dialects.IncrAutoincrMode { + return nil + } + + var seqName = utils.SeqName(tableName) + exist, err := session.engine.dialect.IsSequenceExist(session.ctx, session.getQueryer(), seqName) + if err != nil { + return err + } + if !exist { + return nil + } + + sqlStr, err = session.engine.dialect.DropSequenceSQL(seqName) + if err != nil { + return err + } + _, err = session.exec(sqlStr) + return err } // IsTableExist if a table is exist @@ -185,24 +223,6 @@ func (session *Session) isTableEmpty(tableName string) (bool, error) { return total == 0, nil } -// find if index is exist according cols -func (session *Session) isIndexExist2(tableName string, cols []string, unique bool) (bool, error) { - indexes, err := session.engine.dialect.GetIndexes(session.getQueryer(), session.ctx, tableName) - if err != nil { - return false, err - } - - for _, index := range indexes { - if utils.SliceEq(index.Cols, cols) { - if unique { - return index.Type == schemas.UniqueType, nil - } - return index.Type == schemas.IndexType, nil - } - } - return false, nil -} - func (session *Session) addColumn(colName string) error { col := session.statement.RefTable.GetColumn(colName) sql := session.engine.dialect.AddColumnSQL(session.statement.TableName(), col) @@ -225,7 +245,13 @@ func (session *Session) addUnique(tableName, uqeName string) error { } // Sync2 synchronize structs to database tables +// Depricated func (session *Session) Sync2(beans ...interface{}) error { + return session.Sync(beans...) +} + +// Sync synchronize structs to database tables +func (session *Session) Sync(beans ...interface{}) error { engine := session.engine if session.isAutoClose { @@ -336,8 +362,10 @@ func (session *Session) Sync2(beans ...interface{}) error { } } else { if !(strings.HasPrefix(curType, expectedType) && curType[len(expectedType)] == '(') { - engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s", - tbNameWithSchema, col.Name, curType, expectedType) + if !strings.EqualFold(schemas.SQLTypeName(curType), engine.dialect.Alias(schemas.SQLTypeName(expectedType))) { + engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s", + tbNameWithSchema, col.Name, curType, expectedType) + } } } } else if expectedType == schemas.Varchar { @@ -348,6 +376,8 @@ func (session *Session) Sync2(beans ...interface{}) error { _, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col)) } } + } else if col.Comment != oriCol.Comment { + _, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col)) } if col.Default != oriCol.Default { @@ -448,27 +478,43 @@ func (session *Session) ImportFile(ddlPath string) ([]sql.Result, error) { // Import SQL DDL from io.Reader func (session *Session) Import(r io.Reader) ([]sql.Result, error) { - var results []sql.Result - var lastError error - scanner := bufio.NewScanner(r) + var ( + results []sql.Result + lastError error + inSingleQuote bool + startComment bool + ) - var inSingleQuote bool + scanner := bufio.NewScanner(r) semiColSpliter := func(data []byte, atEOF bool) (advance int, token []byte, err error) { if atEOF && len(data) == 0 { return 0, nil, nil } + var oriInSingleQuote = inSingleQuote for i, b := range data { - if b == '\'' { - inSingleQuote = !inSingleQuote - } - if !inSingleQuote && b == ';' { - return i + 1, data[0:i], nil + if startComment { + if b == '\n' { + startComment = false + } + } else { + if !inSingleQuote && i > 0 && data[i-1] == '-' && data[i] == '-' { + startComment = true + continue + } + + if b == '\'' { + inSingleQuote = !inSingleQuote + } + if !inSingleQuote && b == ';' { + return i + 1, data[0:i], nil + } } } // If we're at EOF, we have a final, non-terminated line. Return it. if atEOF { return len(data), data, nil } + inSingleQuote = oriInSingleQuote // Request more data. return 0, nil, nil } @@ -479,10 +525,10 @@ func (session *Session) Import(r io.Reader) ([]sql.Result, error) { query := strings.Trim(scanner.Text(), " \t\n\r") if len(query) > 0 { result, err := session.Exec(query) - results = append(results, result) if err != nil { return nil, err } + results = append(results, result) } } diff --git a/session_stats.go b/session_stats.go index 17d0a675..5d0da5e9 100644 --- a/session_stats.go +++ b/session_stats.go @@ -70,12 +70,12 @@ func (session *Session) SumInt(bean interface{}, columnName string) (res int64, // Sums call sum some columns. bean's non-empty fields are conditions. func (session *Session) Sums(bean interface{}, columnNames ...string) ([]float64, error) { - var res = make([]float64, len(columnNames), len(columnNames)) + var res = make([]float64, len(columnNames)) return res, session.sum(&res, bean, columnNames...) } // SumsInt sum specify columns and return as []int64 instead of []float64 func (session *Session) SumsInt(bean interface{}, columnNames ...string) ([]int64, error) { - var res = make([]int64, len(columnNames), len(columnNames)) + var res = make([]int64, len(columnNames)) return res, session.sum(&res, bean, columnNames...) } diff --git a/session_tx.go b/session_tx.go index 57791703..4fa56891 100644 --- a/session_tx.go +++ b/session_tx.go @@ -75,7 +75,7 @@ func (session *Session) Commit() error { } cleanUpFunc := func(slices *map[interface{}]*[]func(interface{})) { if len(*slices) > 0 { - *slices = make(map[interface{}]*[]func(interface{}), 0) + *slices = make(map[interface{}]*[]func(interface{})) } } cleanUpFunc(&session.afterInsertBeans) @@ -84,3 +84,8 @@ func (session *Session) Commit() error { } return nil } + +// IsInTx if current session is in a transaction +func (session *Session) IsInTx() bool { + return !session.isAutoCommit +} diff --git a/session_update.go b/session_update.go index 7df8c752..e7104710 100644 --- a/session_update.go +++ b/session_update.go @@ -17,6 +17,12 @@ import ( "xorm.io/xorm/schemas" ) +// enumerated all errors +var ( + ErrNoColumnsTobeUpdated = errors.New("no columns found to be updated") +) + +//revive:disable func (session *Session) cacheUpdate(table *schemas.Table, tableName, sqlStr string, args ...interface{}) error { if table == nil || session.tx != nil { @@ -34,7 +40,7 @@ func (session *Session) cacheUpdate(table *schemas.Table, tableName, sqlStr stri var nStart int if len(args) > 0 { - if strings.Index(sqlStr, "?") > -1 { + if strings.Contains(sqlStr, "?") { nStart = strings.Count(oldhead, "?") } else { // only for pq, TODO: if any other databse? @@ -54,7 +60,7 @@ func (session *Session) cacheUpdate(table *schemas.Table, tableName, sqlStr stri ids = make([]schemas.PK, 0) for rows.Next() { - var res = make([]string, len(table.PrimaryKeys)) + res := make([]string, len(table.PrimaryKeys)) err = rows.ScanSlice(&res) if err != nil { return err @@ -76,6 +82,9 @@ func (session *Session) cacheUpdate(table *schemas.Table, tableName, sqlStr stri ids = append(ids, pk) } + if rows.Err() != nil { + return rows.Err() + } session.engine.logger.Debugf("[cache] find updated id: %v", ids) } /*else { session.engine.LogDebug("[xorm:cacheUpdate] del cached sql:", tableName, newsql, args) @@ -136,14 +145,17 @@ func (session *Session) cacheUpdate(table *schemas.Table, tableName, sqlStr stri // Update records, bean's non-empty fields are updated contents, // condiBean' non-empty filds are conditions // CAUTION: -// 1.bool will defaultly be updated content nor conditions -// You should call UseBool if you have bool to use. -// 2.float32 & float64 may be not inexact as conditions +// +// 1.bool will defaultly be updated content nor conditions +// You should call UseBool if you have bool to use. +// 2.float32 & float64 may be not inexact as conditions func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int64, error) { if session.isAutoClose { defer session.Close() } + defer session.resetStatement() + if session.statement.LastError != nil { return 0, session.statement.LastError } @@ -165,14 +177,14 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 // -- var err error - var isMap = t.Kind() == reflect.Map - var isStruct = t.Kind() == reflect.Struct + isMap := t.Kind() == reflect.Map + isStruct := t.Kind() == reflect.Struct if isStruct { if err := session.statement.SetRefBean(bean); err != nil { return 0, err } - if len(session.statement.TableName()) <= 0 { + if len(session.statement.TableName()) == 0 { return 0, ErrTableNotFound } @@ -205,14 +217,17 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 !session.statement.OmitColumnMap.Contain(table.Updated) { colNames = append(colNames, session.engine.Quote(table.Updated)+" = ?") col := table.UpdatedColumn() - val, t := session.engine.nowTime(col) + val, t, err := session.engine.nowTime(col) + if err != nil { + return 0, err + } if session.engine.dialect.URI().DBType == schemas.ORACLE { args = append(args, t) } else { args = append(args, val) } - var colName = col.Name + colName := col.Name if isStruct { session.afterClosures = append(session.afterClosures, func(bean interface{}) { col := table.GetColumn(colName) @@ -224,35 +239,36 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 // for update action to like "column = column + ?" incColumns := session.statement.IncrColumns - for i, colName := range incColumns.ColNames { - colNames = append(colNames, session.engine.Quote(colName)+" = "+session.engine.Quote(colName)+" + ?") - args = append(args, incColumns.Args[i]) + for _, expr := range incColumns { + colNames = append(colNames, session.engine.Quote(expr.ColName)+" = "+session.engine.Quote(expr.ColName)+" + ?") + args = append(args, expr.Arg) } // for update action to like "column = column - ?" decColumns := session.statement.DecrColumns - for i, colName := range decColumns.ColNames { - colNames = append(colNames, session.engine.Quote(colName)+" = "+session.engine.Quote(colName)+" - ?") - args = append(args, decColumns.Args[i]) + for _, expr := range decColumns { + colNames = append(colNames, session.engine.Quote(expr.ColName)+" = "+session.engine.Quote(expr.ColName)+" - ?") + args = append(args, expr.Arg) } // for update action to like "column = expression" exprColumns := session.statement.ExprColumns - for i, colName := range exprColumns.ColNames { - switch tp := exprColumns.Args[i].(type) { + for _, expr := range exprColumns { + switch tp := expr.Arg.(type) { case string: if len(tp) == 0 { tp = "''" } - colNames = append(colNames, session.engine.Quote(colName)+"="+tp) + colNames = append(colNames, session.engine.Quote(expr.ColName)+"="+tp) case *builder.Builder: - subQuery, subArgs, err := session.statement.GenCondSQL(tp) + subQuery, subArgs, err := builder.ToSQL(tp) if err != nil { return 0, err } - colNames = append(colNames, session.engine.Quote(colName)+"=("+subQuery+")") + subQuery = session.statement.ReplaceQuote(subQuery) + colNames = append(colNames, session.engine.Quote(expr.ColName)+"=("+subQuery+")") args = append(args, subArgs...) default: - colNames = append(colNames, session.engine.Quote(colName)+"=?") - args = append(args, exprColumns.Args[i]) + colNames = append(colNames, session.engine.Quote(expr.ColName)+"=?") + args = append(args, expr.Arg) } } @@ -265,7 +281,11 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 condBeanIsStruct := false if len(condiBean) > 0 { if c, ok := condiBean[0].(map[string]interface{}); ok { - autoCond = builder.Eq(c) + eq := make(builder.Eq) + for k, v := range c { + eq[session.engine.Quote(k)] = v + } + autoCond = builder.Eq(eq) } else { ct := reflect.TypeOf(condiBean[0]) k := ct.Kind() @@ -273,8 +293,12 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 k = ct.Elem().Kind() } if k == reflect.Struct { - var err error - autoCond, err = session.statement.BuildConds(session.statement.RefTable, condiBean[0], true, true, false, true, false) + condTable, err := session.engine.TableInfo(condiBean[0]) + if err != nil { + return 0, err + } + + autoCond, err = session.statement.BuildConds(condTable, condiBean[0], true, true, false, true, false) if err != nil { return 0, err } @@ -301,11 +325,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 st := session.statement var ( - sqlStr string - condArgs []interface{} - condSQL string cond = session.statement.Conds().And(autoCond) - doIncVer = isStruct && (table != nil && table.Version != "" && session.statement.CheckVersion) verValue *reflect.Value ) @@ -321,74 +341,69 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } } - if len(colNames) <= 0 { - return 0, errors.New("No content found to be updated") + if len(colNames) == 0 { + return 0, ErrNoColumnsTobeUpdated } - condSQL, condArgs, err = session.statement.GenCondSQL(cond) - if err != nil { + whereWriter := builder.NewWriter() + if cond.IsValid() { + fmt.Fprint(whereWriter, "WHERE ") + } + if err := cond.WriteTo(st.QuoteReplacer(whereWriter)); err != nil { + return 0, err + } + if err := st.WriteOrderBy(whereWriter); err != nil { return 0, err } - if len(condSQL) > 0 { - condSQL = "WHERE " + condSQL - } - - if st.OrderStr != "" { - condSQL = condSQL + fmt.Sprintf(" ORDER BY %v", st.OrderStr) - } - - var tableName = session.statement.TableName() + tableName := session.statement.TableName() // TODO: Oracle support needed var top string if st.LimitN != nil { limitValue := *st.LimitN switch session.engine.dialect.URI().DBType { case schemas.MYSQL: - condSQL = condSQL + fmt.Sprintf(" LIMIT %d", limitValue) + fmt.Fprintf(whereWriter, " LIMIT %d", limitValue) case schemas.SQLITE: - tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue) + fmt.Fprintf(whereWriter, " LIMIT %d", limitValue) + cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)", - session.engine.Quote(tableName), tempCondSQL), condArgs...)) - condSQL, condArgs, err = session.statement.GenCondSQL(cond) - if err != nil { + session.engine.Quote(tableName), whereWriter.String()), whereWriter.Args()...)) + + whereWriter = builder.NewWriter() + fmt.Fprint(whereWriter, "WHERE ") + if err := cond.WriteTo(st.QuoteReplacer(whereWriter)); err != nil { return 0, err } - if len(condSQL) > 0 { - condSQL = "WHERE " + condSQL - } case schemas.POSTGRES: - tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue) + fmt.Fprintf(whereWriter, " LIMIT %d", limitValue) + cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)", - session.engine.Quote(tableName), tempCondSQL), condArgs...)) - condSQL, condArgs, err = session.statement.GenCondSQL(cond) - if err != nil { + session.engine.Quote(tableName), whereWriter.String()), whereWriter.Args()...)) + + whereWriter = builder.NewWriter() + fmt.Fprint(whereWriter, "WHERE ") + if err := cond.WriteTo(st.QuoteReplacer(whereWriter)); err != nil { return 0, err } - - if len(condSQL) > 0 { - condSQL = "WHERE " + condSQL - } case schemas.MSSQL: - if st.OrderStr != "" && table != nil && len(table.PrimaryKeys) == 1 { + if st.HasOrderBy() && table != nil && len(table.PrimaryKeys) == 1 { cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)", table.PrimaryKeys[0], limitValue, table.PrimaryKeys[0], - session.engine.Quote(tableName), condSQL), condArgs...) + session.engine.Quote(tableName), whereWriter.String()), whereWriter.Args()...) - condSQL, condArgs, err = session.statement.GenCondSQL(cond) - if err != nil { + whereWriter = builder.NewWriter() + fmt.Fprint(whereWriter, "WHERE ") + if err := cond.WriteTo(whereWriter); err != nil { return 0, err } - if len(condSQL) > 0 { - condSQL = "WHERE " + condSQL - } } else { top = fmt.Sprintf("TOP (%d) ", limitValue) } } } - var tableAlias = session.engine.Quote(tableName) + tableAlias := session.engine.Quote(tableName) var fromSQL string if session.statement.TableAlias != "" { switch session.engine.dialect.URI().DBType { @@ -400,14 +415,19 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } } - sqlStr = fmt.Sprintf("UPDATE %v%v SET %v %v%v", + updateWriter := builder.NewWriter() + if _, err := fmt.Fprintf(updateWriter, "UPDATE %v%v SET %v %v", top, tableAlias, strings.Join(colNames, ", "), - fromSQL, - condSQL) + fromSQL); err != nil { + return 0, err + } + if err := utils.WriteBuilder(updateWriter, whereWriter); err != nil { + return 0, err + } - res, err := session.exec(sqlStr, append(args, condArgs...)...) + res, err := session.exec(updateWriter.String(), append(args, updateWriter.Args()...)...) if err != nil { return 0, err } else if doIncVer { @@ -443,7 +463,6 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 // FIXME: if bean is a map type, it will panic because map cannot be as map key session.afterUpdateBeans[bean] = &afterClosures } - } else { if _, ok := interface{}(bean).(AfterUpdateProcessor); ok { session.afterUpdateBeans[bean] = nil @@ -508,10 +527,13 @@ func (session *Session) genUpdateColumns(bean interface{}) ([]string, []interfac if col.IsUpdated && session.statement.UseAutoTime /*&& isZero(fieldValue.Interface())*/ { // if time is non-empty, then set to auto time - val, t := session.engine.nowTime(col) + val, t, err := session.engine.nowTime(col) + if err != nil { + return nil, nil, err + } args = append(args, val) - var colName = col.Name + colName := col.Name session.afterClosures = append(session.afterClosures, func(bean interface{}) { col := table.GetColumn(colName) setColumnTime(bean, col, t) diff --git a/tags/parser.go b/tags/parser.go index a301d124..028f8d0b 100644 --- a/tags/parser.go +++ b/tags/parser.go @@ -12,6 +12,7 @@ import ( "strings" "sync" "time" + "unicode" "xorm.io/xorm/caches" "xorm.io/xorm/convert" @@ -20,10 +21,17 @@ import ( "xorm.io/xorm/schemas" ) -var ( - ErrUnsupportedType = errors.New("Unsupported type") -) +// ErrUnsupportedType represents an unsupported type error +var ErrUnsupportedType = errors.New("unsupported type") +// TableIndices is an interface that describes structs that provide additional index information above that which is automatically parsed +type TableIndices interface { + TableIndices() []*schemas.Index +} + +var tpTableIndices = reflect.TypeOf((*TableIndices)(nil)).Elem() + +// Parser represents a parser for xorm tag type Parser struct { identifier string dialect dialects.Dialect @@ -34,6 +42,7 @@ type Parser struct { tableCache sync.Map // map[reflect.Type]*schemas.Table } +// NewParser creates a tag parser func NewParser(identifier string, dialect dialects.Dialect, tableMapper, columnMapper names.Mapper, cacherMgr *caches.Manager) *Parser { return &Parser{ identifier: identifier, @@ -45,24 +54,35 @@ func NewParser(identifier string, dialect dialects.Dialect, tableMapper, columnM } } +// GetTableMapper returns table mapper func (parser *Parser) GetTableMapper() names.Mapper { return parser.tableMapper } +// SetTableMapper sets table mapper func (parser *Parser) SetTableMapper(mapper names.Mapper) { parser.ClearCaches() parser.tableMapper = mapper } +// GetColumnMapper returns column mapper func (parser *Parser) GetColumnMapper() names.Mapper { return parser.columnMapper } +// SetColumnMapper sets column mapper func (parser *Parser) SetColumnMapper(mapper names.Mapper) { parser.ClearCaches() parser.columnMapper = mapper } +// SetIdentifier sets tag identifier +func (parser *Parser) SetIdentifier(identifier string) { + parser.ClearCaches() + parser.identifier = identifier +} + +// ParseWithCache parse a struct with cache func (parser *Parser) ParseWithCache(v reflect.Value) (*schemas.Table, error) { t := v.Type() tableI, ok := parser.tableCache.Load(t) @@ -110,6 +130,183 @@ func addIndex(indexName string, table *schemas.Table, col *schemas.Column, index } } +// ErrIgnoreField represents an error to ignore field +var ErrIgnoreField = errors.New("field will be ignored") + +func (parser *Parser) getSQLTypeByType(t reflect.Type) (schemas.SQLType, error) { + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + if t.Kind() == reflect.Struct { + v, ok := parser.tableCache.Load(t) + if ok { + pkCols := v.(*schemas.Table).PKColumns() + if len(pkCols) == 1 { + return pkCols[0].SQLType, nil + } + if len(pkCols) > 1 { + return schemas.SQLType{}, fmt.Errorf("unsupported mulitiple primary key on cascade") + } + } + } + return schemas.Type2SQLType(t), nil +} + +func (parser *Parser) parseFieldWithNoTag(fieldIndex int, field reflect.StructField, fieldValue reflect.Value) (*schemas.Column, error) { + var sqlType schemas.SQLType + if fieldValue.CanAddr() { + if _, ok := fieldValue.Addr().Interface().(convert.Conversion); ok { + sqlType = schemas.SQLType{Name: schemas.Text} + } + } + if _, ok := fieldValue.Interface().(convert.Conversion); ok { + sqlType = schemas.SQLType{Name: schemas.Text} + } else { + var err error + sqlType, err = parser.getSQLTypeByType(field.Type) + if err != nil { + return nil, err + } + } + col := schemas.NewColumn(parser.columnMapper.Obj2Table(field.Name), + field.Name, sqlType, sqlType.DefaultLength, + sqlType.DefaultLength2, true) + col.FieldIndex = []int{fieldIndex} + + if field.Type.Kind() == reflect.Int64 && (strings.ToUpper(col.FieldName) == "ID" || strings.HasSuffix(strings.ToUpper(col.FieldName), ".ID")) { + col.IsAutoIncrement = true + col.IsPrimaryKey = true + col.Nullable = false + } + return col, nil +} + +func (parser *Parser) parseFieldWithTags(table *schemas.Table, fieldIndex int, field reflect.StructField, fieldValue reflect.Value, tags []tag) (*schemas.Column, error) { + col := &schemas.Column{ + FieldName: field.Name, + FieldIndex: []int{fieldIndex}, + Nullable: true, + IsPrimaryKey: false, + IsAutoIncrement: false, + MapType: schemas.TWOSIDES, + Indexes: make(map[string]int), + DefaultIsEmpty: true, + } + + ctx := Context{ + table: table, + col: col, + fieldValue: fieldValue, + indexNames: make(map[string]int), + parser: parser, + } + + for j, tag := range tags { + if ctx.ignoreNext { + ctx.ignoreNext = false + continue + } + + ctx.tag = tag + ctx.tagUname = strings.ToUpper(tag.name) + + if j > 0 { + ctx.preTag = strings.ToUpper(tags[j-1].name) + } + if j < len(tags)-1 { + ctx.nextTag = tags[j+1].name + } else { + ctx.nextTag = "" + } + + if h, ok := parser.handlers[ctx.tagUname]; ok { + if err := h(&ctx); err != nil { + return nil, err + } + } else { + if strings.HasPrefix(ctx.tag.name, "'") && strings.HasSuffix(ctx.tag.name, "'") { + col.Name = ctx.tag.name[1 : len(ctx.tag.name)-1] + } else { + col.Name = ctx.tag.name + } + } + + if ctx.hasCacheTag { + if parser.cacherMgr.GetDefaultCacher() != nil { + parser.cacherMgr.SetCacher(table.Name, parser.cacherMgr.GetDefaultCacher()) + } else { + parser.cacherMgr.SetCacher(table.Name, caches.NewLRUCacher2(caches.NewMemoryStore(), time.Hour, 10000)) + } + } + if ctx.hasNoCacheTag { + parser.cacherMgr.SetCacher(table.Name, nil) + } + } + + if col.SQLType.Name == "" { + var err error + col.SQLType, err = parser.getSQLTypeByType(field.Type) + if err != nil { + return nil, err + } + } + if ctx.isUnsigned && col.SQLType.IsNumeric() && !strings.HasPrefix(col.SQLType.Name, "UNSIGNED") { + col.SQLType.Name = "UNSIGNED " + col.SQLType.Name + } + + parser.dialect.SQLType(col) + if col.Length == 0 { + col.Length = col.SQLType.DefaultLength + } + if col.Length2 == 0 { + col.Length2 = col.SQLType.DefaultLength2 + } + if col.Name == "" { + col.Name = parser.columnMapper.Obj2Table(field.Name) + } + + if ctx.isUnique { + ctx.indexNames[col.Name] = schemas.UniqueType + } else if ctx.isIndex { + ctx.indexNames[col.Name] = schemas.IndexType + } + + for indexName, indexType := range ctx.indexNames { + addIndex(indexName, table, col, indexType) + } + + return col, nil +} + +func (parser *Parser) parseField(table *schemas.Table, fieldIndex int, field reflect.StructField, fieldValue reflect.Value) (*schemas.Column, error) { + if isNotTitle(field.Name) { + return nil, ErrIgnoreField + } + + var ( + tag = field.Tag + ormTagStr = strings.TrimSpace(tag.Get(parser.identifier)) + ) + if ormTagStr == "-" { + return nil, ErrIgnoreField + } + if ormTagStr == "" { + return parser.parseFieldWithNoTag(fieldIndex, field, fieldValue) + } + tags, err := splitTag(ormTagStr) + if err != nil { + return nil, err + } + return parser.parseFieldWithTags(table, fieldIndex, field, fieldValue, tags) +} + +func isNotTitle(n string) bool { + for _, c := range n { + return unicode.IsLower(c) + } + return true +} + // Parse parses a struct as a table information func (parser *Parser) Parse(v reflect.Value) (*schemas.Table, error) { t := v.Type() @@ -124,187 +321,59 @@ func (parser *Parser) Parse(v reflect.Value) (*schemas.Table, error) { table := schemas.NewEmptyTable() table.Type = t table.Name = names.GetTableName(parser.tableMapper, v) - - var idFieldColName string - var hasCacheTag, hasNoCacheTag bool + table.Comment = names.GetTableComment(v) for i := 0; i < t.NumField(); i++ { - tag := t.Field(i).Tag - - ormTagStr := tag.Get(parser.identifier) - var col *schemas.Column - fieldValue := v.Field(i) - fieldType := fieldValue.Type() - - if ormTagStr != "" { - col = &schemas.Column{ - FieldName: t.Field(i).Name, - Nullable: true, - IsPrimaryKey: false, - IsAutoIncrement: false, - MapType: schemas.TWOSIDES, - Indexes: make(map[string]int), - DefaultIsEmpty: true, - } - tags := splitTag(ormTagStr) - - if len(tags) > 0 { - if tags[0] == "-" { - continue - } - - var ctx = Context{ - table: table, - col: col, - fieldValue: fieldValue, - indexNames: make(map[string]int), - parser: parser, - } - - if strings.HasPrefix(strings.ToUpper(tags[0]), "EXTENDS") { - pStart := strings.Index(tags[0], "(") - if pStart > -1 && strings.HasSuffix(tags[0], ")") { - var tagPrefix = strings.TrimFunc(tags[0][pStart+1:len(tags[0])-1], func(r rune) bool { - return r == '\'' || r == '"' - }) - - ctx.params = []string{tagPrefix} - } - - if err := ExtendsTagHandler(&ctx); err != nil { - return nil, err - } - continue - } - - for j, key := range tags { - if ctx.ignoreNext { - ctx.ignoreNext = false - continue - } - - k := strings.ToUpper(key) - ctx.tagName = k - ctx.params = []string{} - - pStart := strings.Index(k, "(") - if pStart == 0 { - return nil, errors.New("( could not be the first character") - } - if pStart > -1 { - if !strings.HasSuffix(k, ")") { - return nil, fmt.Errorf("field %s tag %s cannot match ) character", col.FieldName, key) - } - - ctx.tagName = k[:pStart] - ctx.params = strings.Split(key[pStart+1:len(k)-1], ",") - } - - if j > 0 { - ctx.preTag = strings.ToUpper(tags[j-1]) - } - if j < len(tags)-1 { - ctx.nextTag = tags[j+1] - } else { - ctx.nextTag = "" - } - - if h, ok := parser.handlers[ctx.tagName]; ok { - if err := h(&ctx); err != nil { - return nil, err - } - } else { - if strings.HasPrefix(key, "'") && strings.HasSuffix(key, "'") { - col.Name = key[1 : len(key)-1] - } else { - col.Name = key - } - } - - if ctx.hasCacheTag { - hasCacheTag = true - } - if ctx.hasNoCacheTag { - hasNoCacheTag = true - } - } - - if col.SQLType.Name == "" { - col.SQLType = schemas.Type2SQLType(fieldType) - } - parser.dialect.SQLType(col) - if col.Length == 0 { - col.Length = col.SQLType.DefaultLength - } - if col.Length2 == 0 { - col.Length2 = col.SQLType.DefaultLength2 - } - if col.Name == "" { - col.Name = parser.columnMapper.Obj2Table(t.Field(i).Name) - } - - if ctx.isUnique { - ctx.indexNames[col.Name] = schemas.UniqueType - } else if ctx.isIndex { - ctx.indexNames[col.Name] = schemas.IndexType - } - - for indexName, indexType := range ctx.indexNames { - addIndex(indexName, table, col, indexType) - } - } - } else if fieldValue.CanSet() { - var sqlType schemas.SQLType - if fieldValue.CanAddr() { - if _, ok := fieldValue.Addr().Interface().(convert.Conversion); ok { - sqlType = schemas.SQLType{Name: schemas.Text} - } - } - if _, ok := fieldValue.Interface().(convert.Conversion); ok { - sqlType = schemas.SQLType{Name: schemas.Text} - } else { - sqlType = schemas.Type2SQLType(fieldType) - } - col = schemas.NewColumn(parser.columnMapper.Obj2Table(t.Field(i).Name), - t.Field(i).Name, sqlType, sqlType.DefaultLength, - sqlType.DefaultLength2, true) - - if fieldType.Kind() == reflect.Int64 && (strings.ToUpper(col.FieldName) == "ID" || strings.HasSuffix(strings.ToUpper(col.FieldName), ".ID")) { - idFieldColName = col.Name - } - } else { + col, err := parser.parseField(table, i, t.Field(i), v.Field(i)) + if err == ErrIgnoreField { continue - } - if col.IsAutoIncrement { - col.Nullable = false + } else if err != nil { + return nil, err } table.AddColumn(col) - } // end for - if idFieldColName != "" && len(table.PrimaryKeys) == 0 { - col := table.GetColumn(idFieldColName) - col.IsPrimaryKey = true - col.IsAutoIncrement = true - col.Nullable = false - table.PrimaryKeys = append(table.PrimaryKeys, col.Name) - table.AutoIncrement = col.Name - } - - if hasCacheTag { - if parser.cacherMgr.GetDefaultCacher() != nil { // !nash! use engine's cacher if provided - //engine.logger.Info("enable cache on table:", table.Name) - parser.cacherMgr.SetCacher(table.Name, parser.cacherMgr.GetDefaultCacher()) - } else { - //engine.logger.Info("enable LRU cache on table:", table.Name) - parser.cacherMgr.SetCacher(table.Name, caches.NewLRUCacher2(caches.NewMemoryStore(), time.Hour, 10000)) + indices := tableIndices(v) + for _, index := range indices { + // Override old information + if oldIndex, ok := table.Indexes[index.Name]; ok { + for _, colName := range oldIndex.Cols { + col := table.GetColumn(colName) + if col == nil { + return nil, ErrUnsupportedType + } + delete(col.Indexes, index.Name) + } + } + table.AddIndex(index) + for _, colName := range index.Cols { + col := table.GetColumn(colName) + if col == nil { + return nil, ErrUnsupportedType + } + col.Indexes[index.Name] = index.Type } - } - if hasNoCacheTag { - //engine.logger.Info("disable cache on table:", table.Name) - parser.cacherMgr.SetCacher(table.Name, nil) } return table, nil } + +func tableIndices(v reflect.Value) []*schemas.Index { + if v.Type().Implements(tpTableIndices) { + return v.Interface().(TableIndices).TableIndices() + } + + if v.Kind() == reflect.Ptr { + v = v.Elem() + if v.Type().Implements(tpTableIndices) { + return v.Interface().(TableIndices).TableIndices() + } + } else if v.CanAddr() { + v1 := v.Addr() + if v1.Type().Implements(tpTableIndices) { + return v1.Interface().(TableIndices).TableIndices() + } + } + return nil +} diff --git a/tags/parser_test.go b/tags/parser_test.go index ff304a5b..434cfc07 100644 --- a/tags/parser_test.go +++ b/tags/parser_test.go @@ -6,12 +6,16 @@ package tags import ( "reflect" + "strings" "testing" + "time" - "github.com/stretchr/testify/assert" "xorm.io/xorm/caches" "xorm.io/xorm/dialects" "xorm.io/xorm/names" + "xorm.io/xorm/schemas" + + "github.com/stretchr/testify/assert" ) type ParseTableName1 struct{} @@ -22,6 +26,20 @@ func (p ParseTableName2) TableName() string { return "p_parseTableName" } +type ParseTableComment struct{} + +type ParseTableComment1 struct{} + +type ParseTableComment2 struct{} + +func (p ParseTableComment1) TableComment() string { + return "p_parseTableComment1" +} + +func (p *ParseTableComment2) TableComment() string { + return "p_parseTableComment2" +} + func TestParseTableName(t *testing.T) { parser := NewParser( "xorm", @@ -43,6 +61,36 @@ func TestParseTableName(t *testing.T) { assert.EqualValues(t, "p_parseTableName", table.Name) } +func TestParseTableComment(t *testing.T) { + parser := NewParser( + "xorm", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.SnakeMapper{}, + caches.NewManager(), + ) + + table, err := parser.Parse(reflect.ValueOf(new(ParseTableComment))) + assert.NoError(t, err) + assert.EqualValues(t, "", table.Comment) + + table, err = parser.Parse(reflect.ValueOf(new(ParseTableComment1))) + assert.NoError(t, err) + assert.EqualValues(t, "p_parseTableComment1", table.Comment) + + table, err = parser.Parse(reflect.ValueOf(ParseTableComment1{})) + assert.NoError(t, err) + assert.EqualValues(t, "p_parseTableComment1", table.Comment) + + table, err = parser.Parse(reflect.ValueOf(new(ParseTableComment2))) + assert.NoError(t, err) + assert.EqualValues(t, "p_parseTableComment2", table.Comment) + + table, err = parser.Parse(reflect.ValueOf(ParseTableComment2{})) + assert.NoError(t, err) + assert.EqualValues(t, "p_parseTableComment2", table.Comment) +} + func TestUnexportField(t *testing.T) { parser := NewParser( "xorm", @@ -53,7 +101,7 @@ func TestUnexportField(t *testing.T) { ) type VanilaStruct struct { - private int + private int // unexported fields will be ignored Public int } table, err := parser.Parse(reflect.ValueOf(new(VanilaStruct))) @@ -67,16 +115,505 @@ func TestUnexportField(t *testing.T) { } type TaggedStruct struct { - private int `xorm:"private"` + private int `xorm:"private"` // unexported fields will be ignored Public int `xorm:"-"` } table, err = parser.Parse(reflect.ValueOf(new(TaggedStruct))) assert.NoError(t, err) assert.EqualValues(t, "tagged_struct", table.Name) + assert.EqualValues(t, 0, len(table.Columns())) +} + +func TestParseWithOtherIdentifier(t *testing.T) { + parser := NewParser( + "xorm", + dialects.QueryDialect("mysql"), + names.SameMapper{}, + names.SnakeMapper{}, + caches.NewManager(), + ) + + type StructWithDBTag struct { + FieldFoo string `db:"foo"` + } + + parser.SetIdentifier("db") + table, err := parser.Parse(reflect.ValueOf(new(StructWithDBTag))) + assert.NoError(t, err) + assert.EqualValues(t, "StructWithDBTag", table.Name) assert.EqualValues(t, 1, len(table.Columns())) for _, col := range table.Columns() { - assert.EqualValues(t, "private", col.Name) - assert.NotEqual(t, "public", col.Name) + assert.EqualValues(t, "foo", col.Name) } } + +func TestParseWithIgnore(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SameMapper{}, + names.SnakeMapper{}, + caches.NewManager(), + ) + + type StructWithIgnoreTag struct { + FieldFoo string `db:"-"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithIgnoreTag))) + assert.NoError(t, err) + assert.EqualValues(t, "StructWithIgnoreTag", table.Name) + assert.EqualValues(t, 0, len(table.Columns())) +} + +func TestParseWithAutoincrement(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithAutoIncrement struct { + ID int64 + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithAutoIncrement))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_auto_increment", table.Name) + assert.EqualValues(t, 1, len(table.Columns())) + assert.EqualValues(t, "id", table.Columns()[0].Name) + assert.True(t, table.Columns()[0].IsAutoIncrement) + assert.True(t, table.Columns()[0].IsPrimaryKey) +} + +func TestParseWithAutoincrement2(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithAutoIncrement2 struct { + ID int64 `db:"pk autoincr"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithAutoIncrement2))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_auto_increment2", table.Name) + assert.EqualValues(t, 1, len(table.Columns())) + assert.EqualValues(t, "id", table.Columns()[0].Name) + assert.True(t, table.Columns()[0].IsAutoIncrement) + assert.True(t, table.Columns()[0].IsPrimaryKey) + assert.False(t, table.Columns()[0].Nullable) +} + +func TestParseWithNullable(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithNullable struct { + Name string `db:"notnull"` + FullName string `db:"null comment('column comment,字段注释')"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithNullable))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_nullable", table.Name) + assert.EqualValues(t, 2, len(table.Columns())) + assert.EqualValues(t, "name", table.Columns()[0].Name) + assert.EqualValues(t, "full_name", table.Columns()[1].Name) + assert.False(t, table.Columns()[0].Nullable) + assert.True(t, table.Columns()[1].Nullable) + assert.EqualValues(t, "column comment,字段注释", table.Columns()[1].Comment) +} + +func TestParseWithTimes(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithTimes struct { + Name string `db:"notnull"` + CreatedAt time.Time `db:"created"` + UpdatedAt time.Time `db:"updated"` + DeletedAt time.Time `db:"deleted"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithTimes))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_times", table.Name) + assert.EqualValues(t, 4, len(table.Columns())) + assert.EqualValues(t, "name", table.Columns()[0].Name) + assert.EqualValues(t, "created_at", table.Columns()[1].Name) + assert.EqualValues(t, "updated_at", table.Columns()[2].Name) + assert.EqualValues(t, "deleted_at", table.Columns()[3].Name) + assert.False(t, table.Columns()[0].Nullable) + assert.True(t, table.Columns()[1].Nullable) + assert.True(t, table.Columns()[1].IsCreated) + assert.True(t, table.Columns()[2].Nullable) + assert.True(t, table.Columns()[2].IsUpdated) + assert.True(t, table.Columns()[3].Nullable) + assert.True(t, table.Columns()[3].IsDeleted) +} + +func TestParseWithExtends(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithEmbed struct { + Name string + CreatedAt time.Time `db:"created"` + UpdatedAt time.Time `db:"updated"` + DeletedAt time.Time `db:"deleted"` + } + + type StructWithExtends struct { + SW StructWithEmbed `db:"extends"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithExtends))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_extends", table.Name) + assert.EqualValues(t, 4, len(table.Columns())) + assert.EqualValues(t, "name", table.Columns()[0].Name) + assert.EqualValues(t, "created_at", table.Columns()[1].Name) + assert.EqualValues(t, "updated_at", table.Columns()[2].Name) + assert.EqualValues(t, "deleted_at", table.Columns()[3].Name) + assert.True(t, table.Columns()[0].Nullable) + assert.True(t, table.Columns()[1].Nullable) + assert.True(t, table.Columns()[1].IsCreated) + assert.True(t, table.Columns()[2].Nullable) + assert.True(t, table.Columns()[2].IsUpdated) + assert.True(t, table.Columns()[3].Nullable) + assert.True(t, table.Columns()[3].IsDeleted) +} + +func TestParseWithCache(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithCache struct { + Name string `db:"cache"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithCache))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_cache", table.Name) + assert.EqualValues(t, 1, len(table.Columns())) + assert.EqualValues(t, "name", table.Columns()[0].Name) + assert.True(t, table.Columns()[0].Nullable) + cacher := parser.cacherMgr.GetCacher(table.Name) + assert.NotNil(t, cacher) +} + +func TestParseWithNoCache(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithNoCache struct { + Name string `db:"nocache"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithNoCache))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_no_cache", table.Name) + assert.EqualValues(t, 1, len(table.Columns())) + assert.EqualValues(t, "name", table.Columns()[0].Name) + assert.True(t, table.Columns()[0].Nullable) + cacher := parser.cacherMgr.GetCacher(table.Name) + assert.Nil(t, cacher) +} + +func TestParseWithEnum(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithEnum struct { + Name string `db:"enum('alice', 'bob')"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithEnum))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_enum", table.Name) + assert.EqualValues(t, 1, len(table.Columns())) + assert.EqualValues(t, "name", table.Columns()[0].Name) + assert.True(t, table.Columns()[0].Nullable) + assert.EqualValues(t, schemas.Enum, strings.ToUpper(table.Columns()[0].SQLType.Name)) + assert.EqualValues(t, map[string]int{ + "alice": 0, + "bob": 1, + }, table.Columns()[0].EnumOptions) +} + +func TestParseWithSet(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithSet struct { + Name string `db:"set('alice', 'bob')"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithSet))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_set", table.Name) + assert.EqualValues(t, 1, len(table.Columns())) + assert.EqualValues(t, "name", table.Columns()[0].Name) + assert.True(t, table.Columns()[0].Nullable) + assert.EqualValues(t, schemas.Set, strings.ToUpper(table.Columns()[0].SQLType.Name)) + assert.EqualValues(t, map[string]int{ + "alice": 0, + "bob": 1, + }, table.Columns()[0].SetOptions) +} + +func TestParseWithIndex(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithIndex struct { + Name string `db:"index"` + Name2 string `db:"index(s)"` + Name3 string `db:"unique"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithIndex))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_index", table.Name) + assert.EqualValues(t, 3, len(table.Columns())) + assert.EqualValues(t, "name", table.Columns()[0].Name) + assert.EqualValues(t, "name2", table.Columns()[1].Name) + assert.EqualValues(t, "name3", table.Columns()[2].Name) + assert.True(t, table.Columns()[0].Nullable) + assert.True(t, table.Columns()[1].Nullable) + assert.True(t, table.Columns()[2].Nullable) + assert.EqualValues(t, 1, len(table.Columns()[0].Indexes)) + assert.EqualValues(t, 1, len(table.Columns()[1].Indexes)) + assert.EqualValues(t, 1, len(table.Columns()[2].Indexes)) +} + +func TestParseWithVersion(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithVersion struct { + Name string + Version int `db:"version"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithVersion))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_version", table.Name) + assert.EqualValues(t, 2, len(table.Columns())) + assert.EqualValues(t, "name", table.Columns()[0].Name) + assert.EqualValues(t, "version", table.Columns()[1].Name) + assert.True(t, table.Columns()[0].Nullable) + assert.True(t, table.Columns()[1].Nullable) + assert.True(t, table.Columns()[1].IsVersion) +} + +func TestParseWithLocale(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithLocale struct { + UTCLocale time.Time `db:"utc"` + LocalLocale time.Time `db:"local"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithLocale))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_locale", table.Name) + assert.EqualValues(t, 2, len(table.Columns())) + assert.EqualValues(t, "utc_locale", table.Columns()[0].Name) + assert.EqualValues(t, "local_locale", table.Columns()[1].Name) + assert.EqualValues(t, time.UTC, table.Columns()[0].TimeZone) + assert.EqualValues(t, time.Local, table.Columns()[1].TimeZone) +} + +func TestParseWithDefault(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.GonicMapper{}, + caches.NewManager(), + ) + + type StructWithDefault struct { + Default1 time.Time `db:"default '1970-01-01 00:00:00'"` + Default2 time.Time `db:"default(CURRENT_TIMESTAMP)"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithDefault))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_default", table.Name) + assert.EqualValues(t, 2, len(table.Columns())) + assert.EqualValues(t, "default1", table.Columns()[0].Name) + assert.EqualValues(t, "default2", table.Columns()[1].Name) + assert.EqualValues(t, "'1970-01-01 00:00:00'", table.Columns()[0].Default) + assert.EqualValues(t, "CURRENT_TIMESTAMP", table.Columns()[1].Default) +} + +func TestParseWithOnlyToDB(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.GonicMapper{ + "DB": true, + }, + names.SnakeMapper{}, + caches.NewManager(), + ) + + type StructWithOnlyToDB struct { + Default1 time.Time `db:"->"` + Default2 time.Time `db:"<-"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithOnlyToDB))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_only_to_db", table.Name) + assert.EqualValues(t, 2, len(table.Columns())) + assert.EqualValues(t, "default1", table.Columns()[0].Name) + assert.EqualValues(t, "default2", table.Columns()[1].Name) + assert.EqualValues(t, schemas.ONLYTODB, table.Columns()[0].MapType) + assert.EqualValues(t, schemas.ONLYFROMDB, table.Columns()[1].MapType) +} + +func TestParseWithJSON(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.GonicMapper{ + "JSON": true, + }, + names.SnakeMapper{}, + caches.NewManager(), + ) + + type StructWithJSON struct { + Default1 []string `db:"json"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithJSON))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_json", table.Name) + assert.EqualValues(t, 1, len(table.Columns())) + assert.EqualValues(t, "default1", table.Columns()[0].Name) + assert.True(t, table.Columns()[0].IsJSON) +} + +func TestParseWithJSONB(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("postgres"), + names.GonicMapper{ + "JSONB": true, + }, + names.SnakeMapper{}, + caches.NewManager(), + ) + + type StructWithJSONB struct { + Default1 []string `db:"jsonb"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithJSONB))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_jsonb", table.Name) + assert.EqualValues(t, 1, len(table.Columns())) + assert.EqualValues(t, "default1", table.Columns()[0].Name) + assert.True(t, table.Columns()[0].IsJSON) +} + +func TestParseWithSQLType(t *testing.T) { + parser := NewParser( + "db", + dialects.QueryDialect("mysql"), + names.GonicMapper{ + "SQL": true, + }, + names.GonicMapper{ + "UUID": true, + }, + caches.NewManager(), + ) + + type StructWithSQLType struct { + Col1 string `db:"varchar(32)"` + Col2 string `db:"char(32)"` + Int int64 `db:"bigint"` + DateTime time.Time `db:"datetime"` + UUID string `db:"uuid"` + } + + table, err := parser.Parse(reflect.ValueOf(new(StructWithSQLType))) + assert.NoError(t, err) + assert.EqualValues(t, "struct_with_sql_type", table.Name) + assert.EqualValues(t, 5, len(table.Columns())) + assert.EqualValues(t, "col1", table.Columns()[0].Name) + assert.EqualValues(t, "col2", table.Columns()[1].Name) + assert.EqualValues(t, "int", table.Columns()[2].Name) + assert.EqualValues(t, "date_time", table.Columns()[3].Name) + assert.EqualValues(t, "uuid", table.Columns()[4].Name) + + assert.EqualValues(t, "VARCHAR", table.Columns()[0].SQLType.Name) + assert.EqualValues(t, "CHAR", table.Columns()[1].SQLType.Name) + assert.EqualValues(t, "BIGINT", table.Columns()[2].SQLType.Name) + assert.EqualValues(t, "DATETIME", table.Columns()[3].SQLType.Name) + assert.EqualValues(t, "UUID", table.Columns()[4].SQLType.Name) +} diff --git a/tags/tag.go b/tags/tag.go index bb5b5838..41d525e1 100644 --- a/tags/tag.go +++ b/tags/tag.go @@ -14,30 +14,74 @@ import ( "xorm.io/xorm/schemas" ) -func splitTag(tag string) (tags []string) { - tag = strings.TrimSpace(tag) - var hasQuote = false - var lastIdx = 0 - for i, t := range tag { - if t == '\'' { - hasQuote = !hasQuote - } else if t == ' ' { - if lastIdx < i && !hasQuote { - tags = append(tags, strings.TrimSpace(tag[lastIdx:i])) - lastIdx = i + 1 +type tag struct { + name string + params []string +} + +func splitTag(tagStr string) ([]tag, error) { + tagStr = strings.TrimSpace(tagStr) + var ( + inQuote bool + inBigQuote bool + lastIdx int + curTag tag + paramStart int + tags []tag + ) + for i, t := range tagStr { + switch t { + case '\'': + inQuote = !inQuote + case ' ': + if !inQuote && !inBigQuote { + if lastIdx < i { + if curTag.name == "" { + curTag.name = tagStr[lastIdx:i] + } + tags = append(tags, curTag) + lastIdx = i + 1 + curTag = tag{} + } else if lastIdx == i { + lastIdx = i + 1 + } + } else if inBigQuote && !inQuote { + paramStart = i + 1 + } + case ',': + if !inQuote && !inBigQuote { + return nil, fmt.Errorf("comma[%d] of %s should be in quote or big quote", i, tagStr) + } + if !inQuote && inBigQuote { + curTag.params = append(curTag.params, strings.TrimSpace(tagStr[paramStart:i])) + paramStart = i + 1 + } + case '(': + inBigQuote = true + if !inQuote { + curTag.name = tagStr[lastIdx:i] + paramStart = i + 1 + } + case ')': + inBigQuote = false + if !inQuote { + curTag.params = append(curTag.params, tagStr[paramStart:i]) } } } - if lastIdx < len(tag) { - tags = append(tags, strings.TrimSpace(tag[lastIdx:])) + if lastIdx < len(tagStr) { + if curTag.name == "" { + curTag.name = tagStr[lastIdx:] + } + tags = append(tags, curTag) } - return + return tags, nil } // Context represents a context for xorm tag parse. type Context struct { - tagName string - params []string + tag + tagUname string preTag, nextTag string table *schemas.Table col *schemas.Column @@ -49,35 +93,37 @@ type Context struct { hasCacheTag bool hasNoCacheTag bool ignoreNext bool + isUnsigned bool } // Handler describes tag handler for XORM type Handler func(ctx *Context) error -var ( - // defaultTagHandlers enumerates all the default tag handler - defaultTagHandlers = map[string]Handler{ - "<-": OnlyFromDBTagHandler, - "->": OnlyToDBTagHandler, - "PK": PKTagHandler, - "NULL": NULLTagHandler, - "NOT": IgnoreTagHandler, - "AUTOINCR": AutoIncrTagHandler, - "DEFAULT": DefaultTagHandler, - "CREATED": CreatedTagHandler, - "UPDATED": UpdatedTagHandler, - "DELETED": DeletedTagHandler, - "VERSION": VersionTagHandler, - "UTC": UTCTagHandler, - "LOCAL": LocalTagHandler, - "NOTNULL": NotNullTagHandler, - "INDEX": IndexTagHandler, - "UNIQUE": UniqueTagHandler, - "CACHE": CacheTagHandler, - "NOCACHE": NoCacheTagHandler, - "COMMENT": CommentTagHandler, - } -) +// defaultTagHandlers enumerates all the default tag handler +var defaultTagHandlers = map[string]Handler{ + "-": IgnoreHandler, + "<-": OnlyFromDBTagHandler, + "->": OnlyToDBTagHandler, + "PK": PKTagHandler, + "NULL": NULLTagHandler, + "NOT": NotTagHandler, + "AUTOINCR": AutoIncrTagHandler, + "DEFAULT": DefaultTagHandler, + "CREATED": CreatedTagHandler, + "UPDATED": UpdatedTagHandler, + "DELETED": DeletedTagHandler, + "VERSION": VersionTagHandler, + "UTC": UTCTagHandler, + "LOCAL": LocalTagHandler, + "NOTNULL": NotNullTagHandler, + "INDEX": IndexTagHandler, + "UNIQUE": UniqueTagHandler, + "CACHE": CacheTagHandler, + "NOCACHE": NoCacheTagHandler, + "COMMENT": CommentTagHandler, + "EXTENDS": ExtendsTagHandler, + "UNSIGNED": UnsignedTagHandler, +} func init() { for k := range schemas.SqlTypes { @@ -85,11 +131,16 @@ func init() { } } -// IgnoreTagHandler describes ignored tag handler -func IgnoreTagHandler(ctx *Context) error { +// NotTagHandler describes ignored tag handler +func NotTagHandler(ctx *Context) error { return nil } +// IgnoreHandler represetns the field should be ignored +func IgnoreHandler(ctx *Context) error { + return ErrIgnoreField +} + // OnlyFromDBTagHandler describes mapping direction tag handler func OnlyFromDBTagHandler(ctx *Context) error { ctx.col.MapType = schemas.ONLYFROMDB @@ -124,6 +175,7 @@ func NotNullTagHandler(ctx *Context) error { // AutoIncrTagHandler describes autoincr tag handler func AutoIncrTagHandler(ctx *Context) error { ctx.col.IsAutoIncrement = true + ctx.col.Nullable = false /* if len(ctx.params) > 0 { autoStartInt, err := strconv.Atoi(ctx.params[0]) @@ -192,6 +244,7 @@ func UpdatedTagHandler(ctx *Context) error { // DeletedTagHandler describes deleted tag handler func DeletedTagHandler(ctx *Context) error { ctx.col.IsDeleted = true + ctx.col.Nullable = true return nil } @@ -215,6 +268,12 @@ func UniqueTagHandler(ctx *Context) error { return nil } +// UnsignedTagHandler represents the column is unsigned +func UnsignedTagHandler(ctx *Context) error { + ctx.isUnsigned = true + return nil +} + // CommentTagHandler add comment to column func CommentTagHandler(ctx *Context) error { if len(ctx.params) > 0 { @@ -225,41 +284,44 @@ func CommentTagHandler(ctx *Context) error { // SQLTypeTagHandler describes SQL Type tag handler func SQLTypeTagHandler(ctx *Context) error { - ctx.col.SQLType = schemas.SQLType{Name: ctx.tagName} - if strings.EqualFold(ctx.tagName, "JSON") { + ctx.col.SQLType = schemas.SQLType{Name: ctx.tagUname} + if ctx.tagUname == "JSON" || ctx.tagUname == "JSONB" { ctx.col.IsJSON = true } - if len(ctx.params) > 0 { - if ctx.tagName == schemas.Enum { - ctx.col.EnumOptions = make(map[string]int) - for k, v := range ctx.params { - v = strings.TrimSpace(v) - v = strings.Trim(v, "'") - ctx.col.EnumOptions[v] = k + if len(ctx.params) == 0 { + return nil + } + + switch ctx.tagUname { + case schemas.Enum: + ctx.col.EnumOptions = make(map[string]int) + for k, v := range ctx.params { + v = strings.TrimSpace(v) + v = strings.Trim(v, "'") + ctx.col.EnumOptions[v] = k + } + case schemas.Set: + ctx.col.SetOptions = make(map[string]int) + for k, v := range ctx.params { + v = strings.TrimSpace(v) + v = strings.Trim(v, "'") + ctx.col.SetOptions[v] = k + } + default: + var err error + if len(ctx.params) == 2 { + ctx.col.Length, err = strconv.ParseInt(ctx.params[0], 10, 64) + if err != nil { + return err } - } else if ctx.tagName == schemas.Set { - ctx.col.SetOptions = make(map[string]int) - for k, v := range ctx.params { - v = strings.TrimSpace(v) - v = strings.Trim(v, "'") - ctx.col.SetOptions[v] = k + ctx.col.Length2, err = strconv.ParseInt(ctx.params[1], 10, 64) + if err != nil { + return err } - } else { - var err error - if len(ctx.params) == 2 { - ctx.col.Length, err = strconv.Atoi(ctx.params[0]) - if err != nil { - return err - } - ctx.col.Length2, err = strconv.Atoi(ctx.params[1]) - if err != nil { - return err - } - } else if len(ctx.params) == 1 { - ctx.col.Length, err = strconv.Atoi(ctx.params[0]) - if err != nil { - return err - } + } else if len(ctx.params) == 1 { + ctx.col.Length, err = strconv.ParseInt(ctx.params[0], 10, 64) + if err != nil { + return err } } } @@ -268,8 +330,8 @@ func SQLTypeTagHandler(ctx *Context) error { // ExtendsTagHandler describes extends tag handler func ExtendsTagHandler(ctx *Context) error { - var fieldValue = ctx.fieldValue - var isPtr = false + fieldValue := ctx.fieldValue + isPtr := false switch fieldValue.Kind() { case reflect.Ptr: f := fieldValue.Type().Elem() @@ -289,11 +351,12 @@ func ExtendsTagHandler(ctx *Context) error { } for _, col := range parentTable.Columns() { col.FieldName = fmt.Sprintf("%v.%v", ctx.col.FieldName, col.FieldName) + col.FieldIndex = append(ctx.col.FieldIndex, col.FieldIndex...) - var tagPrefix = ctx.col.FieldName + tagPrefix := ctx.col.FieldName if len(ctx.params) > 0 { col.Nullable = isPtr - tagPrefix = ctx.params[0] + tagPrefix = strings.Trim(ctx.params[0], "'") if col.IsPrimaryKey { col.Name = ctx.col.FieldName col.IsPrimaryKey = false @@ -313,9 +376,9 @@ func ExtendsTagHandler(ctx *Context) error { } } default: - //TODO: warning + // TODO: warning } - return nil + return ErrIgnoreField } // CacheTagHandler describes cache tag handler diff --git a/tags/tag_test.go b/tags/tag_test.go index 5775b40a..3ceeefd1 100644 --- a/tags/tag_test.go +++ b/tags/tag_test.go @@ -7,24 +7,83 @@ package tags import ( "testing" - "xorm.io/xorm/internal/utils" + "github.com/stretchr/testify/assert" ) func TestSplitTag(t *testing.T) { var cases = []struct { tag string - tags []string + tags []tag }{ - {"not null default '2000-01-01 00:00:00' TIMESTAMP", []string{"not", "null", "default", "'2000-01-01 00:00:00'", "TIMESTAMP"}}, - {"TEXT", []string{"TEXT"}}, - {"default('2000-01-01 00:00:00')", []string{"default('2000-01-01 00:00:00')"}}, - {"json binary", []string{"json", "binary"}}, + {"not null default '2000-01-01 00:00:00' TIMESTAMP", []tag{ + { + name: "not", + }, + { + name: "null", + }, + { + name: "default", + }, + { + name: "'2000-01-01 00:00:00'", + }, + { + name: "TIMESTAMP", + }, + }, + }, + {"TEXT", []tag{ + { + name: "TEXT", + }, + }, + }, + {"default('2000-01-01 00:00:00')", []tag{ + { + name: "default", + params: []string{ + "'2000-01-01 00:00:00'", + }, + }, + }, + }, + {"json binary", []tag{ + { + name: "json", + }, + { + name: "binary", + }, + }, + }, + {"numeric(10, 2)", []tag{ + { + name: "numeric", + params: []string{"10", "2"}, + }, + }, + }, + {"numeric(10, 2) notnull", []tag{ + { + name: "numeric", + params: []string{"10", "2"}, + }, + { + name: "notnull", + }, + }, + }, } for _, kase := range cases { - tags := splitTag(kase.tag) - if !utils.SliceEq(tags, kase.tags) { - t.Fatalf("[%d]%v is not equal [%d]%v", len(tags), tags, len(kase.tags), kase.tags) - } + t.Run(kase.tag, func(t *testing.T) { + tags, err := splitTag(kase.tag) + assert.NoError(t, err) + assert.EqualValues(t, len(tags), len(kase.tags)) + for i := 0; i < len(tags); i++ { + assert.Equal(t, tags[i], kase.tags[i]) + } + }) } }