diff --git a/.changelog.yml b/.changelog.yml new file mode 100644 index 00000000..1303c9cc --- /dev/null +++ b/.changelog.yml @@ -0,0 +1,53 @@ +# The full repository name +repo: xorm/xorm + +# Service type (gitea or github) +service: gitea + +# Base URL for Gitea instance if using gitea service type (optional) +# Default: https://gitea.com +base-url: + +# Changelog groups and which labeled PRs to add to each group +groups: + - + name: BREAKING + labels: + - kind/breaking + - + name: FEATURES + labels: + - kind/feature + - + name: SECURITY + labels: + - kind/security + - + name: BUGFIXES + labels: + - kind/bug + - + name: ENHANCEMENTS + labels: + - kind/enhancement + - kind/refactor + - kind/ui + - + name: TESTING + labels: + - kind/testing + - + name: BUILD + labels: + - kind/build + - kind/lint + - + name: DOCS + labels: + - kind/docs + - + name: MISC + default: true + +# regex indicating which labels to skip for the changelog +skip-labels: skip-changelog|backport\/.+ diff --git a/.circleci/config.yml b/.circleci/config.yml deleted file mode 100644 index c8f64282..00000000 --- a/.circleci/config.yml +++ /dev/null @@ -1,61 +0,0 @@ -# Golang CircleCI 2.0 configuration file -# -# Check https://circleci.com/docs/2.0/language-go/ for more details -version: 2 -jobs: - build: - docker: - # specify the version - - image: circleci/golang:1.10 - - - image: circleci/mysql:5.7 - environment: - MYSQL_ALLOW_EMPTY_PASSWORD: true - MYSQL_DATABASE: xorm_test - MYSQL_HOST: 127.0.0.1 - MYSQL_ROOT_HOST: '%' - MYSQL_USER: root - - # CircleCI PostgreSQL images available at: https://hub.docker.com/r/circleci/postgres/ - - image: circleci/postgres:9.6.2-alpine - environment: - POSTGRES_USER: circleci - POSTGRES_DB: xorm_test - - - image: microsoft/mssql-server-linux:latest - environment: - ACCEPT_EULA: Y - SA_PASSWORD: yourStrong(!)Password - MSSQL_PID: Developer - - - image: pingcap/tidb:v2.1.2 - - working_directory: /go/src/github.com/go-xorm/xorm - steps: - - checkout - - - run: go get -t -d -v ./... - - run: go get -u xorm.io/core - - run: go get -u xorm.io/builder - - run: GO111MODULE=off go build -v - - run: GO111MODULE=on go build -v - - - run: go get -u github.com/wadey/gocovmerge - - - run: go test -v -race -db="sqlite3" -conn_str="./test.db" -coverprofile=coverage1-1.txt -covermode=atomic - - run: go test -v -race -db="sqlite3" -conn_str="./test.db" -cache=true -coverprofile=coverage1-2.txt -covermode=atomic - - run: go test -v -race -db="mysql" -conn_str="root:@/xorm_test" -coverprofile=coverage2-1.txt -covermode=atomic - - run: go test -v -race -db="mysql" -conn_str="root:@/xorm_test" -cache=true -coverprofile=coverage2-2.txt -covermode=atomic - - run: go test -v -race -db="mymysql" -conn_str="xorm_test/root/" -coverprofile=coverage3-1.txt -covermode=atomic - - run: go test -v -race -db="mymysql" -conn_str="xorm_test/root/" -cache=true -coverprofile=coverage3-2.txt -covermode=atomic - - run: go test -v -race -db="postgres" -conn_str="dbname=xorm_test sslmode=disable" -coverprofile=coverage4-1.txt -covermode=atomic - - run: go test -v -race -db="postgres" -conn_str="dbname=xorm_test sslmode=disable" -cache=true -coverprofile=coverage4-2.txt -covermode=atomic - - run: go test -v -race -db="postgres" -conn_str="dbname=xorm_test sslmode=disable" -schema=xorm -coverprofile=coverage5-1.txt -covermode=atomic - - run: go test -v -race -db="postgres" -conn_str="dbname=xorm_test sslmode=disable" -schema=xorm -cache=true -coverprofile=coverage5-2.txt -covermode=atomic - - run: go test -v -race -db="mssql" -conn_str="server=localhost;user id=sa;password=yourStrong(!)Password;database=xorm_test" -coverprofile=coverage6-1.txt -covermode=atomic - - run: go test -v -race -db="mssql" -conn_str="server=localhost;user id=sa;password=yourStrong(!)Password;database=xorm_test" -cache=true -coverprofile=coverage6-2.txt -covermode=atomic - - run: go test -v -race -db="mysql" -conn_str="root:@tcp(localhost:4000)/xorm_test" -ignore_select_update=true -coverprofile=coverage7-1.txt -covermode=atomic - - run: go test -v -race -db="mysql" -conn_str="root:@tcp(localhost:4000)/xorm_test" -ignore_select_update=true -cache=true -coverprofile=coverage7-2.txt -covermode=atomic - - run: gocovmerge coverage1-1.txt coverage1-2.txt coverage2-1.txt coverage2-2.txt coverage3-1.txt coverage3-2.txt coverage4-1.txt coverage4-2.txt coverage5-1.txt coverage5-2.txt coverage6-1.txt coverage6-2.txt coverage7-1.txt coverage7-2.txt > coverage.txt - - - run: bash <(curl -s https://codecov.io/bash) \ No newline at end of file diff --git a/.drone.yml b/.drone.yml index b2198e38..7a18e0d6 100644 --- a/.drone.yml +++ b/.drone.yml @@ -1,249 +1,88 @@ --- kind: pipeline -name: go1.10-test -workspace: - base: /go - path: src/gitea.com/xorm/xorm - +name: testing steps: -- name: build - pull: default - image: golang:1.10 +- name: test-vet + image: golang:1.11 # The lowest golang requirement + environment: + GO111MODULE: "on" + GOPROXY: "https://goproxy.cn" commands: - - go get -t -d -v - - go build -v + - make vet + - make test when: event: - push - pull_request - name: test-sqlite - pull: default - image: golang:1.10 - depends_on: - - build + image: golang:1.12 + environment: + GO111MODULE: "on" + GOPROXY: "https://goproxy.cn" commands: - - "go test -v -race -db=\"sqlite3\" -conn_str=\"./test.db\" -coverprofile=coverage1-1.txt -covermode=atomic" - - "go test -v -race -db=\"sqlite3\" -conn_str=\"./test.db\" -cache=true -coverprofile=coverage1-2.txt -covermode=atomic" + - make test-sqlite + - TEST_CACHE_ENABLE=true make test-sqlite + - TEST_QUOTE_POLICY=reserved make test-sqlite when: event: - push - pull_request - name: test-mysql - pull: default - image: golang:1.10 - depends_on: - - build + image: golang:1.12 + environment: + GO111MODULE: "on" + GOPROXY: "https://goproxy.cn" + TEST_MYSQL_HOST: mysql + TEST_MYSQL_CHARSET: utf8 + TEST_MYSQL_DBNAME: xorm_test + TEST_MYSQL_USERNAME: root + TEST_MYSQL_PASSWORD: commands: - - "go test -v -race -db=\"mysql\" -conn_str=\"root:@tcp(mysql)/xorm_test\" -coverprofile=coverage2-1.txt -covermode=atomic" - - "go test -v -race -db=\"mysql\" -conn_str=\"root:@tcp(mysql)/xorm_test\" -cache=true -coverprofile=coverage2-2.txt -covermode=atomic" + - make test-mysql + - TEST_CACHE_ENABLE=true make test-mysql + - TEST_QUOTE_POLICY=reserved make test-mysql + when: + event: + - push + - pull_request + +- name: test-mysql8 + image: golang:1.12 + environment: + GO111MODULE: "on" + GOPROXY: "https://goproxy.cn" + TEST_MYSQL_HOST: mysql8 + TEST_MYSQL_CHARSET: utf8mb4 + TEST_MYSQL_DBNAME: xorm_test + TEST_MYSQL_USERNAME: root + TEST_MYSQL_PASSWORD: + commands: + - make test-mysql + - TEST_CACHE_ENABLE=true make test-mysql + - TEST_QUOTE_POLICY=reserved make test-mysql when: event: - push - pull_request - name: test-mysql-utf8mb4 - pull: default - image: golang:1.10 + image: golang:1.12 depends_on: - - test-mysql - commands: - - "go test -v -race -db=\"mysql\" -conn_str=\"root:@tcp(mysql)/xorm_test?charset=utf8mb4\" -coverprofile=coverage2.1-1.txt -covermode=atomic" - - "go test -v -race -db=\"mysql\" -conn_str=\"root:@tcp(mysql)/xorm_test?charset=utf8mb4\" -cache=true -coverprofile=coverage2.1-2.txt -covermode=atomic" - when: - event: - - push - - pull_request - -- name: test-mymysql - pull: default - image: golang:1.10 - depends_on: - - test-mysql-utf8mb4 - commands: - - "go test -v -race -db=\"mymysql\" -conn_str=\"tcp:mysql:3306*xorm_test/root/\" -coverprofile=coverage3-1.txt -covermode=atomic" - - "go test -v -race -db=\"mymysql\" -conn_str=\"tcp:mysql:3306*xorm_test/root/\" -cache=true -coverprofile=coverage3-2.txt -covermode=atomic" - when: - event: - - push - - pull_request - -- name: test-postgres - pull: default - image: golang:1.10 - depends_on: - - build - commands: - - "go test -v -race -db=\"postgres\" -conn_str=\"postgres://postgres:@pgsql/xorm_test?sslmode=disable\" -coverprofile=coverage4-1.txt -covermode=atomic" - - "go test -v -race -db=\"postgres\" -conn_str=\"postgres://postgres:@pgsql/xorm_test?sslmode=disable\" -cache=true -coverprofile=coverage4-2.txt -covermode=atomic" - when: - event: - - push - - pull_request - -- name: test-postgres-schema - pull: default - image: golang:1.10 - depends_on: - - build - commands: - - "go test -v -race -db=\"postgres\" -conn_str=\"postgres://postgres:@pgsql/xorm_test?sslmode=disable\" -schema=xorm -coverprofile=coverage5-1.txt -covermode=atomic" - - "go test -v -race -db=\"postgres\" -conn_str=\"postgres://postgres:@pgsql/xorm_test?sslmode=disable\" -schema=xorm -cache=true -coverprofile=coverage5-2.txt -covermode=atomic" - when: - event: - - push - - pull_request - -- name: test-mssql - pull: default - image: golang:1.10 - depends_on: - - build - commands: - - "go test -v -race -db=\"mssql\" -conn_str=\"server=mssql;user id=sa;password=yourStrong(!)Password;database=xorm_test\" -coverprofile=coverage6-1.txt -covermode=atomic" - - "go test -v -race -db=\"mssql\" -conn_str=\"server=mssql;user id=sa;password=yourStrong(!)Password;database=xorm_test\" -cache=true -coverprofile=coverage6-2.txt -covermode=atomic" - when: - event: - - push - - pull_request - -- name: test-tidb - pull: default - image: golang:1.10 - depends_on: - - build - commands: - - "go test -v -race -db=\"mysql\" -conn_str=\"root:@tcp(tidb:4000)/xorm_test\" -ignore_select_update=true -coverprofile=coverage7-1.txt -covermode=atomic" - - "go test -v -race -db=\"mysql\" -conn_str=\"root:@tcp(tidb:4000)/xorm_test\" -ignore_select_update=true -cache=true -coverprofile=coverage7-2.txt -covermode=atomic" - when: - event: - - push - - pull_request - -- name: test-end - pull: default - image: golang:1.10 - depends_on: - - test-sqlite - test-mysql - - test-mysql-utf8mb4 - - test-mymysql - - test-postgres - - test-postgres-schema - - test-mssql - - test-tidb - commands: - - echo "go1.10 build end" - when: - event: - - push - - pull_request - -services: -- name: mysql - pull: default - image: mysql:5.7 - environment: - MYSQL_ALLOW_EMPTY_PASSWORD: yes - MYSQL_DATABASE: xorm_test - when: - event: - - push - - tag - - pull_request - -- name: pgsql - pull: default - image: postgres:9.5 - environment: - POSTGRES_DB: xorm_test - POSTGRES_USER: postgres - when: - event: - - push - - tag - - pull_request - -- name: mssql - pull: default - image: microsoft/mssql-server-linux:latest - environment: - ACCEPT_EULA: Y - SA_PASSWORD: yourStrong(!)Password - MSSQL_PID: Developer - when: - event: - - push - - tag - - pull_request - -- name: tidb - pull: default - image: pingcap/tidb:v3.0.3 - when: - event: - - push - - tag - - pull_request - ---- -kind: pipeline -name: go1.13-test -steps: -- name: build - pull: default - image: golang:1.13 environment: GO111MODULE: "on" GOPROXY: "https://goproxy.cn" + TEST_MYSQL_HOST: mysql + TEST_MYSQL_CHARSET: utf8mb4 + TEST_MYSQL_DBNAME: xorm_test + TEST_MYSQL_USERNAME: root + TEST_MYSQL_PASSWORD: commands: - - go build -v - - go vet - when: - event: - - push - - pull_request - -- name: test-sqlite - pull: default - image: golang:1.13 - environment: - GO111MODULE: "on" - GOPROXY: "https://goproxy.cn" - commands: - - "go test -v -race -db=\"sqlite3\" -conn_str=\"./test.db\" -coverprofile=coverage1-1.txt -covermode=atomic" - - "go test -v -race -db=\"sqlite3\" -conn_str=\"./test.db\" -cache=true -coverprofile=coverage1-2.txt -covermode=atomic" - when: - event: - - push - - pull_request - -- name: test-mysql - pull: default - image: golang:1.13 - environment: - GO111MODULE: "on" - GOPROXY: "https://goproxy.cn" - commands: - - "go test -v -race -db=\"mysql\" -conn_str=\"root:@tcp(mysql)/xorm_test\" -coverprofile=coverage2-1.txt -covermode=atomic" - - "go test -v -race -db=\"mysql\" -conn_str=\"root:@tcp(mysql)/xorm_test\" -cache=true -coverprofile=coverage2-2.txt -covermode=atomic" - when: - event: - - push - - pull_request - -- name: test-mysql-utf8mb4 - pull: default - image: golang:1.13 - depends_on: - - test-mysql - environment: - GO111MODULE: "on" - GOPROXY: "https://goproxy.cn" - commands: - - "go test -v -race -db=\"mysql\" -conn_str=\"root:@tcp(mysql)/xorm_test?charset=utf8mb4\" -coverprofile=coverage2.1-1.txt -covermode=atomic" - - "go test -v -race -db=\"mysql\" -conn_str=\"root:@tcp(mysql)/xorm_test?charset=utf8mb4\" -cache=true -coverprofile=coverage2.1-2.txt -covermode=atomic" + - make test-mysql + - TEST_CACHE_ENABLE=true make test-mysql + - TEST_QUOTE_POLICY=reserved make test-mysql when: event: - push @@ -251,15 +90,20 @@ steps: - name: test-mymysql pull: default - image: golang:1.13 + image: golang:1.12 depends_on: - test-mysql-utf8mb4 environment: GO111MODULE: "on" GOPROXY: "https://goproxy.cn" + TEST_MYSQL_HOST: mysql:3306 + TEST_MYSQL_DBNAME: xorm_test + TEST_MYSQL_USERNAME: root + TEST_MYSQL_PASSWORD: commands: - - "go test -v -race -db=\"mymysql\" -conn_str=\"tcp:mysql:3306*xorm_test/root/\" -coverprofile=coverage3-1.txt -covermode=atomic" - - "go test -v -race -db=\"mymysql\" -conn_str=\"tcp:mysql:3306*xorm_test/root/\" -cache=true -coverprofile=coverage3-2.txt -covermode=atomic" + - make test-mymysql + - TEST_CACHE_ENABLE=true make test-mymysql + - TEST_QUOTE_POLICY=reserved make test-mymysql when: event: - push @@ -267,13 +111,18 @@ steps: - name: test-postgres pull: default - image: golang:1.13 + image: golang:1.12 environment: GO111MODULE: "on" GOPROXY: "https://goproxy.cn" + TEST_PGSQL_HOST: pgsql + TEST_PGSQL_DBNAME: xorm_test + TEST_PGSQL_USERNAME: postgres + TEST_PGSQL_PASSWORD: postgres commands: - - "go test -v -race -db=\"postgres\" -conn_str=\"postgres://postgres:@pgsql/xorm_test?sslmode=disable\" -coverprofile=coverage4-1.txt -covermode=atomic" - - "go test -v -race -db=\"postgres\" -conn_str=\"postgres://postgres:@pgsql/xorm_test?sslmode=disable\" -cache=true -coverprofile=coverage4-2.txt -covermode=atomic" + - make test-postgres + - TEST_CACHE_ENABLE=true make test-postgres + - TEST_QUOTE_POLICY=reserved make test-postgres when: event: - push @@ -281,13 +130,21 @@ steps: - name: test-postgres-schema pull: default - image: golang:1.13 + image: golang:1.12 + depends_on: + - test-postgres environment: GO111MODULE: "on" GOPROXY: "https://goproxy.cn" + TEST_PGSQL_HOST: pgsql + TEST_PGSQL_SCHEMA: xorm + TEST_PGSQL_DBNAME: xorm_test + TEST_PGSQL_USERNAME: postgres + TEST_PGSQL_PASSWORD: postgres commands: - - "go test -v -race -db=\"postgres\" -conn_str=\"postgres://postgres:@pgsql/xorm_test?sslmode=disable\" -schema=xorm -coverprofile=coverage5-1.txt -covermode=atomic" - - "go test -v -race -db=\"postgres\" -conn_str=\"postgres://postgres:@pgsql/xorm_test?sslmode=disable\" -schema=xorm -cache=true -coverprofile=coverage5-2.txt -covermode=atomic" + - make test-postgres + - TEST_CACHE_ENABLE=true make test-postgres + - TEST_QUOTE_POLICY=reserved make test-postgres when: event: - push @@ -295,27 +152,56 @@ steps: - name: test-mssql pull: default - image: golang:1.13 + image: golang:1.12 environment: GO111MODULE: "on" GOPROXY: "https://goproxy.cn" + TEST_MSSQL_HOST: mssql + TEST_MSSQL_DBNAME: xorm_test + TEST_MSSQL_USERNAME: sa + TEST_MSSQL_PASSWORD: "yourStrong(!)Password" commands: - - "go test -v -race -db=\"mssql\" -conn_str=\"server=mssql;user id=sa;password=yourStrong(!)Password;database=xorm_test\" -coverprofile=coverage6-1.txt -covermode=atomic" - - "go test -v -race -db=\"mssql\" -conn_str=\"server=mssql;user id=sa;password=yourStrong(!)Password;database=xorm_test\" -cache=true -coverprofile=coverage6-2.txt -covermode=atomic" + - make test-mssql + - TEST_CACHE_ENABLE=true make test-mssql + - TEST_QUOTE_POLICY=reserved make test-mssql when: event: - push - pull_request - name: test-tidb + pull: default + image: golang:1.12 + environment: + GO111MODULE: "on" + GOPROXY: "https://goproxy.cn" + TEST_TIDB_HOST: "tidb:4000" + TEST_TIDB_DBNAME: xorm_test + TEST_TIDB_USERNAME: root + TEST_TIDB_PASSWORD: + commands: + - make test-tidb + - TEST_CACHE_ENABLE=true make test-tidb + - TEST_QUOTE_POLICY=reserved make test-tidb + when: + event: + - push + - pull_request + +- name: test-cockroach pull: default image: golang:1.13 environment: GO111MODULE: "on" GOPROXY: "https://goproxy.cn" + TEST_COCKROACH_HOST: "cockroach:26257" + TEST_COCKROACH_DBNAME: xorm_test + TEST_COCKROACH_USERNAME: root + TEST_COCKROACH_PASSWORD: commands: - - "go test -v -race -db=\"mysql\" -conn_str=\"root:@tcp(tidb:4000)/xorm_test\" -ignore_select_update=true -coverprofile=coverage7-1.txt -covermode=atomic" - - "go test -v -race -db=\"mysql\" -conn_str=\"root:@tcp(tidb:4000)/xorm_test\" -ignore_select_update=true -cache=true -coverprofile=coverage7-2.txt -covermode=atomic" + - sleep 10 + - make test-cockroach + - TEST_CACHE_ENABLE=true make test-cockroach when: event: - push @@ -323,23 +209,23 @@ steps: - name: merge_coverage pull: default - image: golang:1.13 + image: golang:1.12 environment: GO111MODULE: "on" GOPROXY: "https://goproxy.cn" depends_on: - - build + - test-vet - test-sqlite - test-mysql - - test-mysql-utf8mb4 + - test-mysql8 - test-mymysql - test-postgres - test-postgres-schema - test-mssql - test-tidb + - test-cockroach commands: - - go get github.com/wadey/gocovmerge - - gocovmerge coverage1-1.txt coverage1-2.txt coverage2-1.txt coverage2-2.txt coverage2.1-1.txt coverage2.1-2.txt coverage3-1.txt coverage3-2.txt coverage4-1.txt coverage4-2.txt coverage5-1.txt coverage5-2.txt coverage6-1.txt coverage6-2.txt coverage7-1.txt coverage7-2.txt > coverage.txt + - make coverage when: event: - push @@ -359,12 +245,25 @@ services: - tag - pull_request +- name: mysql8 + pull: default + image: mysql:8.0 + environment: + MYSQL_ALLOW_EMPTY_PASSWORD: yes + MYSQL_DATABASE: xorm_test + when: + event: + - push + - tag + - pull_request + - name: pgsql pull: default image: postgres:9.5 environment: POSTGRES_DB: xorm_test POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres when: event: - push @@ -391,4 +290,15 @@ services: event: - push - tag - - pull_request \ No newline at end of file + - pull_request + +- name: cockroach + pull: default + image: cockroachdb/cockroach:v19.2.4 + commands: + - /cockroach/cockroach start --insecure + when: + event: + - push + - tag + - pull_request diff --git a/.gitignore b/.gitignore index f1757b98..617d5da7 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,7 @@ # Folders _obj _test +vendor/ # Architecture specific extensions/prefixes *.[568vq] @@ -31,3 +32,7 @@ xorm.test test.db.sql .idea/ + +*coverage.out +test.db +integrations/*.sql diff --git a/.revive.toml b/.revive.toml new file mode 100644 index 00000000..64e223bb --- /dev/null +++ b/.revive.toml @@ -0,0 +1,25 @@ +ignoreGeneratedHeader = false +severity = "warning" +confidence = 0.8 +errorCode = 1 +warningCode = 1 + +[rule.blank-imports] +[rule.context-as-argument] +[rule.context-keys-type] +[rule.dot-imports] +[rule.error-return] +[rule.error-strings] +[rule.error-naming] +[rule.exported] +[rule.if-return] +[rule.increment-decrement] +[rule.var-naming] +[rule.var-declaration] +[rule.package-comments] +[rule.range] +[rule.receiver-naming] +[rule.time-naming] +[rule.unexported-return] +[rule.indent-error-flow] +[rule.errorf] \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 00000000..22f6157a --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,197 @@ +# Changelog + +This changelog goes through all the changes that have been made in each release +without substantial changes to our git log. + +## [1.0.2](https://gitea.com/xorm/xorm/pulls?q=&type=all&state=closed&milestone=1261) - 2020-06-16 + +* FEATURES + * Add Hook (#1644) +* BUGFIXES + * Fix bug when ID used but no reference table given (#1709) + * Fix find and count bug (#1651) +* ENHANCEMENTS + * chore: improve snakeCasedName performance (#1688) + * Fix find with another struct (#1666) + * fix GetColumns missing ordinal position (#1660) +* MISC + * chore: improve titleCasedName performance (#1691) + +## [1.0.1](https://gitea.com/xorm/xorm/pulls?q=&type=all&state=closed&milestone=1253) - 2020-03-25 + +* BUGFIXES + * Oracle : Local Naming Method (#1515) + * Fix find and count bug (#1618) + * Fix duplicated deleted condition on FindAndCount (#1619) + * Fix find and count bug with cache (#1622) + * Fix postgres schema problem (#1624) + * Fix quote with blank (#1626) + +## [1.0.0](https://gitea.com/xorm/xorm/pulls?q=&type=all&state=closed&milestone=1242) - 2020-03-22 + +* BREAKING + * Add context for dialects (#1558) + * Move zero functions to a standalone package (#1548) + * Merge core package back into the main repository and split into serval sub packages. (#1543) +* FEATURES + * Use a new ContextLogger interface to implement logger (#1557) +* BUGFIXES + * Fix setschema (#1606) + * Fix dump/import bug (#1603) + * Fix pk bug (#1602) + * Fix master/slave bug (#1601) + * Fix bug when dump (#1597) + * Ignore schema when dbtype is not postgres (#1593) + * Fix table name (#1590) + * Fix find alias bug (#1581) + * Fix rows bug (#1576) + * Fix map with cols (#1575) + * Fix bug on deleted with join (#1570) + * Improve quote policy (#1567) + * Fix break session sql enable feature (#1566) + * Fix mssql quote (#1535) + * Fix join table name quote bug (#1534) + * Fix mssql issue with duplicate columns. (#1225) + * Fix mysql8.0 sync failed (#808) +* ENHANCEMENTS + * Fix batch insert interface slice be panic (#1598) + * Move some codes to statement sub package (#1574) + * Remove circle file (#1569) + * Move statement as a sub package (#1564) + * Move maptype to tag parser (#1561) + * Move caches to manager (#1553) + * Improve code (#1552) + * Improve some codes (#1551) + * Improve statement (#1549) + * Move tag parser related codes as a standalone sub package (#1547) + * Move reserve words related files into dialects sub package (#1544) + * Fix `Conversion` method `ToDB() ([]byte, error)` return type is nil (#1296) + * Check driver.Valuer response, and skip the column if nil (#1167) + * Add cockroach support and tests (#896) +* TESTING + * Improve tests (#1572) +* BUILD + * Add changelog file and tool configuration (#1546) +* DOCS + * Fix outdate changelog (#1565) + +## old changelog + +* **v0.6.5** + * Postgres schema support + * vgo support + * Add FindAndCount + * Database special params support via NewEngineWithParams + * Some bugs fixed + +* **v0.6.4** + * Automatical Read/Write seperatelly + * Query/QueryString/QueryInterface and action with Where/And + * Get support non-struct variables + * BufferSize on Iterate + * fix some other bugs. + +* **v0.6.3** + * merge tests to main project + * add `Exist` function + * add `SumInt` function + * Mysql now support read and create column comment. + * fix time related bugs. + * fix some other bugs. + +* **v0.6.2** + * refactor tag parse methods + * add Scan features to Get + * add QueryString method + +* **v0.4.5** + * many bugs fixed + * extends support unlimited deep + * Delete Limit support + +* **v0.4.4** + * ql database expriment support + * tidb database expriment support + * sql.NullString and etc. field support + * select ForUpdate support + * many bugs fixed + +* **v0.4.3** + * Json column type support + * oracle expirement support + * bug fixed + +* **v0.4.2** + * Transaction will auto rollback if not Rollback or Commit be called. + * Gonic Mapper support + * bug fixed + +* **v0.4.1** + * deleted tag support for soft delete + * bug fixed + +* **v0.4.0 RC1** + Changes: + * moved xorm cmd to [github.com/go-xorm/cmd](github.com/go-xorm/cmd) + * refactored general DB operation a core lib at [github.com/go-xorm/core](https://github.com/go-xorm/core) + * moved tests to github.com/go-xorm/tests [github.com/go-xorm/tests](github.com/go-xorm/tests) + + Improvements: + * Prepared statement cache + * Add Incr API + * Specify Timezone Location + +* **v0.3.2** + Improvements: + * Add AllCols & MustCols function + * Add TableName for custom table name + + Bug Fixes: + * #46 + * #51 + * #53 + * #89 + * #86 + * #92 + +* **v0.3.1** + + Features: + * Support MSSQL DB via ODBC driver ([github.com/lunny/godbc](https://github.com/lunny/godbc)); + * Composite Key, using multiple pk xorm tag + * Added Row() API as alternative to Iterate() API for traversing result set, provide similar usages to sql.Rows type + * ORM struct allowed declaration of pointer builtin type as members to allow null DB fields + * Before and After Event processors + + Improvements: + * Allowed int/int32/int64/uint/uint32/uint64/string as Primary Key type + * Performance improvement for Get()/Find()/Iterate() + + +* **v0.2.3** : Improved documents; Optimistic Locking support; Timestamp with time zone support; Mapper change to tableMapper and columnMapper & added PrefixMapper & SuffixMapper support custom table or column name's prefix and suffix;Insert now return affected, err instead of id, err; Added UseBool & Distinct; + +* **v0.2.2** : Postgres drivers now support lib/pq; Added method Iterate for record by record to handler;Added SetMaxConns(go1.2+) support; some bugs fixed. + +* **v0.2.1** : Added database reverse tool, now support generate go & c++ codes, see [Xorm Tool README](https://github.com/go-xorm/xorm/blob/master/xorm/README.md); some bug fixed. + +* **v0.2.0** : Added Cache supported, select is speeder up 3~5x; Added SameMapper for same name between struct and table; Added Sync method for auto added tables, columns, indexes; + +* **v0.1.9** : Added postgres and mymysql supported; Added ` and ? supported on Raw SQL even if postgres; Added Cols, StoreEngine, Charset function, Added many column data type supported, please see [Mapping Rules](#mapping). + +* **v0.1.8** : Added union index and union unique supported, please see [Mapping Rules](#mapping). + +* **v0.1.7** : Added IConnectPool interface and NoneConnectPool, SysConnectPool, SimpleConnectPool the three implements. You can choose one of them and the default is SysConnectPool. You can customrize your own connection pool. struct Engine added Close method, It should be invoked before system exit. + +* **v0.1.6** : Added conversion interface support; added struct derive support; added single mapping support + +* **v0.1.5** : Added multi threads support; added Sql() function for struct query; Get function changed return inteface; MakeSession and Create are instead with NewSession and NewEngine. + +* **v0.1.4** : Added simple cascade load support; added more data type supports. + +* **v0.1.3** : Find function now supports both slice and map; Add Table function for multi tables and temperory tables support + +* **v0.1.2** : Insert function now supports both struct and slice pointer parameters, batch inserting and auto transaction + +* **v0.1.1** : Add Id, In functions and improved README + +* **v0.1.0** : Initial release. \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 442aa4d3..a6925a5c 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -22,6 +22,47 @@ e.g., // !lunny! this is comments made by lunny ``` +### Build xorm and test it locally + +Once you write some codes on your feature branch, you could build and test locally at first. Just + +``` +make build +``` +and +``` +make test +``` + +The `make test` is an alias of `make test-sqlite`, it will run the tests on a sqlite database file. No extra thing needed to do except you need to cgo compile enviroment. + +If you write a new test method, you could run + +``` +make test-sqlite#TestMyNewMethod +``` + +that will only run the special test method. + +If you want to run another datase, you have to prepare a running database at first, and then, you could + +``` +TEST_MYSQL_HOST= TEST_MYSQL_CHARSET= TEST_MYSQL_DBNAME= TEST_MYSQL_USERNAME= TEST_MYSQL_PASSWORD= make test-mysql +``` + +or other databases: +``` +TEST_MSSQL_HOST= TEST_MSSQL_DBNAME= TEST_MSSQL_USERNAME= TEST_MSSQL_PASSWORD= make test-mssql +``` +``` +TEST_PGSQL_HOST= TEST_PGSQL_SCHEMA= TEST_PGSQL_DBNAME= TEST_PGSQL_USERNAME= TEST_PGSQL_PASSWORD= make test-postgres +``` +``` +TEST_TIDB_HOST= TEST_TIDB_DBNAME= TEST_TIDB_USERNAME= TEST_TIDB_PASSWORD= make test-tidb +``` + +And if your branch is related with cache, you could also enable it via `TEST_CACHE_ENABLE=true`. + ### Patch review Help review existing open [pull requests](https://help.github.com/articles/using-pull-requests) by commenting on the code or diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..4cccacd8 --- /dev/null +++ b/Makefile @@ -0,0 +1,220 @@ +IMPORT := xorm.io/xorm +export GO111MODULE=on + +GO ?= go +GOFMT ?= gofmt -s +TAGS ?= +SED_INPLACE := sed -i + +GOFILES := $(shell find . -name "*.go" -type f) +INTEGRATION_PACKAGES := xorm.io/xorm/integrations +PACKAGES ?= $(filter-out $(INTEGRATION_PACKAGES),$(shell $(GO) list ./...)) + +TEST_COCKROACH_HOST ?= cockroach:26257 +TEST_COCKROACH_SCHEMA ?= +TEST_COCKROACH_DBNAME ?= xorm_test +TEST_COCKROACH_USERNAME ?= postgres +TEST_COCKROACH_PASSWORD ?= + +TEST_MSSQL_HOST ?= mssql:1433 +TEST_MSSQL_DBNAME ?= gitea +TEST_MSSQL_USERNAME ?= sa +TEST_MSSQL_PASSWORD ?= MwantsaSecurePassword1 + +TEST_MYSQL_HOST ?= mysql:3306 +TEST_MYSQL_CHARSET ?= utf8 +TEST_MYSQL_DBNAME ?= xorm_test +TEST_MYSQL_USERNAME ?= root +TEST_MYSQL_PASSWORD ?= + +TEST_PGSQL_HOST ?= pgsql:5432 +TEST_PGSQL_SCHEMA ?= +TEST_PGSQL_DBNAME ?= xorm_test +TEST_PGSQL_USERNAME ?= postgres +TEST_PGSQL_PASSWORD ?= mysecretpassword + +TEST_TIDB_HOST ?= tidb:4000 +TEST_TIDB_DBNAME ?= xorm_test +TEST_TIDB_USERNAME ?= root +TEST_TIDB_PASSWORD ?= + +TEST_CACHE_ENABLE ?= false +TEST_QUOTE_POLICY ?= always + +.PHONY: all +all: build + +.PHONY: build +build: go-check $(GO_SOURCES) + $(GO) build $(PACKAGES) + +.PHONY: clean +clean: + $(GO) clean -i ./... + rm -rf *.sql *.log test.db *coverage.out coverage.all integrations/*.sql + +.PHONY: coverage +coverage: + @hash gocovmerge > /dev/null 2>&1; if [ $$? -ne 0 ]; then \ + $(GO) get -u github.com/wadey/gocovmerge; \ + fi + gocovmerge $(shell find . -type f -name "coverage.out") > coverage.all;\ + +.PHONY: fmt +fmt: + $(GOFMT) -w $(GOFILES) + +.PHONY: fmt-check +fmt-check: + # get all go files and run go fmt on them + @diff=$$($(GOFMT) -d $(GOFILES)); \ + if [ -n "$$diff" ]; then \ + echo "Please run 'make fmt' and commit the result:"; \ + echo "$${diff}"; \ + exit 1; \ + fi; + +.PHONY: go-check +go-check: + $(eval GO_VERSION := $(shell printf "%03d%03d%03d" $(shell go version | grep -Eo '[0-9]+\.?[0-9]+?\.?[0-9]?\s' | tr '.' ' ');)) + @if [ "$(GO_VERSION)" -lt "001011000" ]; then \ + echo "Gitea requires Go 1.11.0 or greater to build. You can get it at https://golang.org/dl/"; \ + exit 1; \ + fi + +.PHONY: help +help: + @echo "Make Routines:" + @echo " - equivalent to \"build\"" + @echo " - build creates the entire project" + @echo " - clean delete integration files and build files but not css and js files" + @echo " - fmt format the code" + @echo " - lint run code linter revive" + @echo " - misspell check if a word is written wrong" + @echo " - test run default unit test" + @echo " - test-cockroach run integration tests for cockroach" + @echo " - test-mysql run integration tests for mysql" + @echo " - test-mssql run integration tests for mssql" + @echo " - test-postgres run integration tests for postgres" + @echo " - test-sqlite run integration tests for sqlite" + @echo " - test-tidb run integration tests for tidb" + @echo " - vet examines Go source code and reports suspicious constructs" + +.PHONY: lint +lint: revive + +.PHONY: revive +revive: + @hash revive > /dev/null 2>&1; if [ $$? -ne 0 ]; then \ + $(GO) get -u github.com/mgechev/revive; \ + fi + revive -config .revive.toml -exclude=./vendor/... ./... || exit 1 + +.PHONY: misspell +misspell: + @hash misspell > /dev/null 2>&1; if [ $$? -ne 0 ]; then \ + $(GO) get -u github.com/client9/misspell/cmd/misspell; \ + fi + misspell -w -i unknwon $(GOFILES) + +.PHONY: misspell-check +misspell-check: + @hash misspell > /dev/null 2>&1; if [ $$? -ne 0 ]; then \ + $(GO) get -u github.com/client9/misspell/cmd/misspell; \ + fi + misspell -error -i unknwon,destory $(GOFILES) + +.PHONY: test +test: go-check + $(GO) test $(PACKAGES) + +.PNONY: test-cockroach +test-cockroach: go-check + $(GO) test $(INTEGRATION_PACKAGES) -v -race -db=postgres -schema='$(TEST_COCKROACH_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \ + -conn_str="postgres://$(TEST_COCKROACH_USERNAME):$(TEST_COCKROACH_PASSWORD)@$(TEST_COCKROACH_HOST)/$(TEST_COCKROACH_DBNAME)?sslmode=disable&experimental_serial_normalization=sql_sequence" \ + -ignore_update_limit=true -coverprofile=cockroach.$(TEST_COCKROACH_SCHEMA).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic + +.PHONY: test-cockroach\#% +test-cockroach\#%: go-check + $(GO) test $(INTEGRATION_PACKAGES) -v -race -run $* -db=postgres -schema='$(TEST_COCKROACH_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \ + -conn_str="postgres://$(TEST_COCKROACH_USERNAME):$(TEST_COCKROACH_PASSWORD)@$(TEST_COCKROACH_HOST)/$(TEST_COCKROACH_DBNAME)?sslmode=disable&experimental_serial_normalization=sql_sequence" \ + -ignore_update_limit=true -coverprofile=cockroach.$(TEST_COCKROACH_SCHEMA).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic + +.PNONY: test-mssql +test-mssql: go-check + $(GO) test $(INTEGRATION_PACKAGES) -v -race -db=mssql -cache=$(TEST_CACHE_ENABLE) -quote=$(TEST_QUOTE_POLICY) \ + -conn_str="server=$(TEST_MSSQL_HOST);user id=$(TEST_MSSQL_USERNAME);password=$(TEST_MSSQL_PASSWORD);database=$(TEST_MSSQL_DBNAME)" \ + -coverprofile=mssql.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic + +.PNONY: test-mssql\#% +test-mssql\#%: go-check + $(GO) test $(INTEGRATION_PACKAGES) -v -race -run $* -db=mssql -cache=$(TEST_CACHE_ENABLE) -quote=$(TEST_QUOTE_POLICY) \ + -conn_str="server=$(TEST_MSSQL_HOST);user id=$(TEST_MSSQL_USERNAME);password=$(TEST_MSSQL_PASSWORD);database=$(TEST_MSSQL_DBNAME)" \ + -coverprofile=mssql.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic + +.PNONY: test-mymysql +test-mymysql: go-check + $(GO) test $(INTEGRATION_PACKAGES) -v -race -db=mymysql -cache=$(TEST_CACHE_ENABLE) -quote=$(TEST_QUOTE_POLICY) \ + -conn_str="tcp:$(TEST_MYSQL_HOST)*$(TEST_MYSQL_DBNAME)/$(TEST_MYSQL_USERNAME)/$(TEST_MYSQL_PASSWORD)" \ + -coverprofile=mymysql.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic + +.PNONY: test-mymysql\#% +test-mymysql\#%: go-check + $(GO) test $(INTEGRATION_PACKAGES) -v -race -run $* -db=mymysql -cache=$(TEST_CACHE_ENABLE) -quote=$(TEST_QUOTE_POLICY) \ + -conn_str="tcp:$(TEST_MYSQL_HOST)*$(TEST_MYSQL_DBNAME)/$(TEST_MYSQL_USERNAME)/$(TEST_MYSQL_PASSWORD)" \ + -coverprofile=mymysql.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic + +.PNONY: test-mysql +test-mysql: go-check + $(GO) test $(INTEGRATION_PACKAGES) -v -race -db=mysql -cache=$(TEST_CACHE_ENABLE) -quote=$(TEST_QUOTE_POLICY) \ + -conn_str="$(TEST_MYSQL_USERNAME):$(TEST_MYSQL_PASSWORD)@tcp($(TEST_MYSQL_HOST))/$(TEST_MYSQL_DBNAME)?charset=$(TEST_MYSQL_CHARSET)" \ + -coverprofile=mysql.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic + +.PHONY: test-mysql\#% +test-mysql\#%: go-check + $(GO) test $(INTEGRATION_PACKAGES) -v -race -run $* -db=mysql -cache=$(TEST_CACHE_ENABLE) -quote=$(TEST_QUOTE_POLICY) \ + -conn_str="$(TEST_MYSQL_USERNAME):$(TEST_MYSQL_PASSWORD)@tcp($(TEST_MYSQL_HOST))/$(TEST_MYSQL_DBNAME)?charset=$(TEST_MYSQL_CHARSET)" \ + -coverprofile=mysql.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic + +.PNONY: test-postgres +test-postgres: go-check + $(GO) test $(INTEGRATION_PACKAGES) -v -race -db=postgres -schema='$(TEST_PGSQL_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \ + -conn_str="postgres://$(TEST_PGSQL_USERNAME):$(TEST_PGSQL_PASSWORD)@$(TEST_PGSQL_HOST)/$(TEST_PGSQL_DBNAME)?sslmode=disable" \ + -quote=$(TEST_QUOTE_POLICY) -coverprofile=postgres.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic + +.PHONY: test-postgres\#% +test-postgres\#%: go-check + $(GO) test $(INTEGRATION_PACKAGES) -v -race -run $* -db=postgres -schema='$(TEST_PGSQL_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \ + -conn_str="postgres://$(TEST_PGSQL_USERNAME):$(TEST_PGSQL_PASSWORD)@$(TEST_PGSQL_HOST)/$(TEST_PGSQL_DBNAME)?sslmode=disable" \ + -quote=$(TEST_QUOTE_POLICY) -coverprofile=postgres.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic + +.PHONY: test-sqlite +test-sqlite: go-check + $(GO) test $(INTEGRATION_PACKAGES) -v -race -cache=$(TEST_CACHE_ENABLE) -db=sqlite3 -conn_str="./test.db?cache=shared&mode=rwc" \ + -quote=$(TEST_QUOTE_POLICY) -coverprofile=sqlite.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic + +.PHONY: test-sqlite-schema +test-sqlite-schema: go-check + $(GO) test $(INTEGRATION_PACKAGES) -v -race -schema=xorm -cache=$(TEST_CACHE_ENABLE) -db=sqlite3 -conn_str="./test.db?cache=shared&mode=rwc" \ + -quote=$(TEST_QUOTE_POLICY) -coverprofile=sqlite.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic + +.PHONY: test-sqlite\#% +test-sqlite\#%: go-check + $(GO) test $(INTEGRATION_PACKAGES) -v -race -run $* -cache=$(TEST_CACHE_ENABLE) -db=sqlite3 -conn_str="./test.db?cache=shared&mode=rwc" \ + -quote=$(TEST_QUOTE_POLICY) -coverprofile=sqlite.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic + +.PNONY: test-tidb +test-tidb: go-check + $(GO) test $(INTEGRATION_PACKAGES) -v -race -db=mysql -cache=$(TEST_CACHE_ENABLE) -ignore_select_update=true \ + -conn_str="$(TEST_TIDB_USERNAME):$(TEST_TIDB_PASSWORD)@tcp($(TEST_TIDB_HOST))/$(TEST_TIDB_DBNAME)" \ + -quote=$(TEST_QUOTE_POLICY) -coverprofile=tidb.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic + +.PHONY: test-tidb\#% +test-tidb\#%: go-check + $(GO) test $(INTEGRATION_PACKAGES) -v -race -run $* -db=mysql -cache=$(TEST_CACHE_ENABLE) -ignore_select_update=true \ + -conn_str="$(TEST_TIDB_USERNAME):$(TEST_TIDB_PASSWORD)@tcp($(TEST_TIDB_HOST))/$(TEST_TIDB_DBNAME)" \ + -quote=$(TEST_QUOTE_POLICY) -coverprofile=tidb.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic + +.PHONY: vet +vet: + $(GO) vet $(shell $(GO) list ./...) \ No newline at end of file diff --git a/README.md b/README.md index 17a6ed37..ed866224 100644 --- a/README.md +++ b/README.md @@ -4,55 +4,52 @@ Xorm is a simple and powerful ORM for Go. -[![Build Status](https://drone.gitea.com/api/badges/xorm/xorm/status.svg)](https://drone.gitea.com/xorm/xorm) [![](http://gocover.io/_badge/xorm.io/xorm)](https://gocover.io/xorm.io/xorm) -[![](https://goreportcard.com/badge/xorm.io/xorm)](https://goreportcard.com/report/xorm.io/xorm) -[![Join the chat at https://img.shields.io/discord/323460943201959939.svg](https://img.shields.io/discord/323460943201959939.svg)](https://discord.gg/HuR2CF3) +[![Build Status](https://drone.gitea.com/api/badges/xorm/xorm/status.svg)](https://drone.gitea.com/xorm/xorm) [![](http://gocover.io/_badge/xorm.io/xorm)](https://gocover.io/xorm.io/xorm) [![](https://goreportcard.com/badge/xorm.io/xorm)](https://goreportcard.com/report/xorm.io/xorm) [![Join the chat at https://img.shields.io/discord/323460943201959939.svg](https://img.shields.io/discord/323460943201959939.svg)](https://discord.gg/HuR2CF3) + +## Notice + +v1.0.0 has some break changes from v0.8.2. + +- Removed some non gonic function name `Id`, `Sql`, please use `ID`, `SQL` instead. +- Removed the dependent from `xorm.io/core` and moved the codes to `xorm.io/xorm/core`, `xorm.io/xorm/names`, `xorm.io/xorm/schemas` and others. +- Renamed some interface names. i.e. `core.IMapper` -> `names.Mapper`, `core.ILogger` -> `log.Logger`. ## Features * Struct <-> Table Mapping Support - * Chainable APIs - * Transaction Support - * Both ORM and raw SQL operation Support - * Sync database schema Support - * Query Cache speed up - -* Database Reverse support, See [Xorm Tool README](https://github.com/go-xorm/cmd/blob/master/README.md) - +* Database Reverse support via [xorm.io/reverse](https://xorm.io/reverse) * Simple cascade loading support - * Optimistic Locking support - * SQL Builder support via [xorm.io/builder](https://xorm.io/builder) - * Automatical Read/Write seperatelly - * Postgres schema support - * Context Cache support +* Support log/SQLLog context ## Drivers Support Drivers for Go's sql package which currently support database/sql includes: -* Mysql: [github.com/go-sql-driver/mysql](https://github.com/go-sql-driver/mysql) +* [Mysql5.*](https://github.com/mysql/mysql-server/tree/5.7) / [Mysql8.*](https://github.com/mysql/mysql-server) / [Mariadb](https://github.com/MariaDB/server) / [Tidb](https://github.com/pingcap/tidb) + - [github.com/go-sql-driver/mysql](https://github.com/go-sql-driver/mysql) + - [github.com/ziutek/mymysql/godrv](https://github.com/ziutek/mymysql/godrv) -* MyMysql: [github.com/ziutek/mymysql/godrv](https://github.com/ziutek/mymysql/tree/master/godrv) +* [Postgres](https://github.com/postgres/postgres) / [Cockroach](https://github.com/cockroachdb/cockroach) + - [github.com/lib/pq](https://github.com/lib/pq) -* Postgres: [github.com/lib/pq](https://github.com/lib/pq) +* [SQLite](https://sqlite.org) + - [github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3) -* Tidb: [github.com/pingcap/tidb](https://github.com/pingcap/tidb) +* MsSql + - [github.com/denisenkom/go-mssqldb](https://github.com/denisenkom/go-mssqldb) -* SQLite: [github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3) - -* MsSql: [github.com/denisenkom/go-mssqldb](https://github.com/denisenkom/go-mssqldb) - -* Oracle: [github.com/mattn/go-oci8](https://github.com/mattn/go-oci8) (experiment) +* Oracle + - [github.com/mattn/go-oci8](https://github.com/mattn/go-oci8) (experiment) ## Installation @@ -62,12 +59,14 @@ Drivers for Go's sql package which currently support database/sql includes: * [Manual](http://xorm.io/docs) -* [GoDoc](http://godoc.org/xorm.io/xorm) +* [GoDoc](http://pkg.go.dev/xorm.io/xorm) ## Quick Start * Create Engine +Firstly, we should new an engine for a database. + ```Go engine, err := xorm.NewEngine(driverName, dataSourceName) ``` @@ -419,7 +418,7 @@ res, err := engine.Transaction(func(session *xorm.Session) (interface{}, error) ## Contributing -If you want to pull request, please see [CONTRIBUTING](https://gitea.com/xorm/xorm/src/branch/master/CONTRIBUTING.md). And we also provide [Xorm on Google Groups](https://groups.google.com/forum/#!forum/xorm) to discuss. +If you want to pull request, please see [CONTRIBUTING](https://gitea.com/xorm/xorm/src/branch/master/CONTRIBUTING.md). And you can also go to [Xorm on discourse](https://xorm.discourse.group) to discuss. ## Credits @@ -440,27 +439,7 @@ Support this project by becoming a sponsor. Your logo will show up here with a l ## Changelog -* **v0.7.0** - * Some bugs fixed - -* **v0.6.6** - * Some bugs fixed - -* **v0.6.5** - * Postgres schema support - * vgo support - * Add FindAndCount - * Database special params support via NewEngineWithParams - * Some bugs fixed - -* **v0.6.4** - * Automatical Read/Write seperatelly - * Query/QueryString/QueryInterface and action with Where/And - * Get support non-struct variables - * BufferSize on Iterate - * fix some other bugs. - -[More changes ...](https://github.com/go-xorm/manual-en-US/tree/master/chapter-16) +You can find all the changelog [here](CHANGELOG.md) ## Cases diff --git a/README_CN.md b/README_CN.md index 644bdc0b..80245dd3 100644 --- a/README_CN.md +++ b/README_CN.md @@ -2,57 +2,53 @@ [English](https://gitea.com/xorm/xorm/src/branch/master/README.md) -xorm是一个简单而强大的Go语言ORM库. 通过它可以使数据库操作非常简便。 +xorm 是一个简单而强大的Go语言ORM库. 通过它可以使数据库操作非常简便。 -[![Build Status](https://drone.gitea.com/api/badges/xorm/builder/status.svg)](https://drone.gitea.com/xorm/builder) [![](http://gocover.io/_badge/xorm.io/xorm)](https://gocover.io/xorm.io/xorm) -[![](https://goreportcard.com/badge/xorm.io/xorm)](https://goreportcard.com/report/xorm.io/xorm) -[![Join the chat at https://img.shields.io/discord/323460943201959939.svg](https://img.shields.io/discord/323460943201959939.svg)](https://discord.gg/HuR2CF3) +[![Build Status](https://drone.gitea.com/api/badges/xorm/xorm/status.svg)](https://drone.gitea.com/xorm/xorm) [![](http://gocover.io/_badge/xorm.io/xorm)](https://gocover.io/xorm.io/xorm) [![](https://goreportcard.com/badge/xorm.io/xorm)](https://goreportcard.com/report/xorm.io/xorm) [![Join the chat at https://img.shields.io/discord/323460943201959939.svg](https://img.shields.io/discord/323460943201959939.svg)](https://discord.gg/HuR2CF3) + +## Notice + +v1.0.0 相对于 v0.8.2 有以下不兼容的变更: + +- 移除了部分不符合Go语言命名的函数,如 `Id`, `Sql`,请使用 `ID`, `SQL` 替代。 +- 删除了对 `xorm.io/core` 的依赖。大部分代码迁移到了 `xorm.io/xorm/core`, `xorm.io/xorm/names`, `xorm.io/xorm/schemas` 等等几个包中. +- 重命名了几个结构体,如: `core.IMapper` -> `names.Mapper`, `core.ILogger` -> `log.Logger`. ## 特性 -* 支持Struct和数据库表之间的灵活映射,并支持自动同步 - +* 支持 Struct 和数据库表之间的灵活映射,并支持自动同步 * 事务支持 - * 同时支持原始SQL语句和ORM操作的混合执行 - * 使用连写来简化调用 - -* 支持使用Id, In, Where, Limit, Join, Having, Table, Sql, Cols等函数和结构体等方式作为条件 - +* 支持使用ID, In, Where, Limit, Join, Having, Table, SQL, Cols等函数和结构体等方式作为条件 * 支持级联加载Struct - * Schema支持(仅Postgres) - * 支持缓存 - -* 支持根据数据库自动生成xorm的结构体 - +* 通过 [xorm.io/reverse](https://xorm.io/reverse) 支持根据数据库自动生成 xorm 结构体 * 支持记录版本(即乐观锁) - -* 内置SQL Builder支持 - +* 通过 [xorm.io/builder](https://xorm.io/builder) 内置 SQL Builder 支持 * 上下文缓存支持 +* 支持日志上下文 ## 驱动支持 目前支持的Go数据库驱动和对应的数据库如下: -* Mysql: [github.com/go-sql-driver/mysql](https://github.com/go-sql-driver/mysql) +* [Mysql5.*](https://github.com/mysql/mysql-server/tree/5.7) / [Mysql8.*](https://github.com/mysql/mysql-server) / [Mariadb](https://github.com/MariaDB/server) / [Tidb](https://github.com/pingcap/tidb) + - [github.com/go-sql-driver/mysql](https://github.com/go-sql-driver/mysql) + - [github.com/ziutek/mymysql/godrv](https://github.com/ziutek/mymysql/godrv) -* MyMysql: [github.com/ziutek/mymysql/godrv](https://github.com/ziutek/mymysql/godrv) +* [Postgres](https://github.com/postgres/postgres) / [Cockroach](https://github.com/cockroachdb/cockroach) + - [github.com/lib/pq](https://github.com/lib/pq) -* Postgres: [github.com/lib/pq](https://github.com/lib/pq) +* [SQLite](https://sqlite.org) + - [github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3) -* Tidb: [github.com/pingcap/tidb](https://github.com/pingcap/tidb) +* MsSql + - [github.com/denisenkom/go-mssqldb](https://github.com/denisenkom/go-mssqldb) -* SQLite: [github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3) - -* MsSql: [github.com/denisenkom/go-mssqldb](https://github.com/denisenkom/go-mssqldb) - -* MsSql: [github.com/lunny/godbc](https://github.com/lunny/godbc) - -* Oracle: [github.com/mattn/go-oci8](https://github.com/mattn/go-oci8) (试验性支持) +* Oracle + - [github.com/mattn/go-oci8](https://github.com/mattn/go-oci8) (试验性支持) ## 安装 @@ -62,7 +58,7 @@ xorm是一个简单而强大的Go语言ORM库. 通过它可以使数据库操作 * [操作指南](http://xorm.io/docs) -* [Godoc代码文档](http://godoc.org/xorm.io/xorm) +* [Godoc代码文档](http://pkg.go.dev/xorm.io/xorm) # 快速开始 @@ -435,14 +431,14 @@ res, err := engine.Transaction(func(session *xorm.Session) (interface{}, error) # 案例 -* [Go语言中文网](http://studygolang.com/) - [github.com/studygolang/studygolang](https://github.com/studygolang/studygolang) - * [Gitea](http://gitea.io) - [github.com/go-gitea/gitea](http://github.com/go-gitea/gitea) * [Gogs](http://try.gogits.org) - [github.com/gogits/gogs](http://github.com/gogits/gogs) * [grafana](https://grafana.com/) - [github.com/grafana/grafana](http://github.com/grafana/grafana) +* [Go语言中文网](http://studygolang.com/) - [github.com/studygolang/studygolang](https://github.com/studygolang/studygolang) + * [github.com/m3ng9i/qreader](https://github.com/m3ng9i/qreader) * [Wego](http://github.com/go-tango/wego) @@ -470,27 +466,7 @@ res, err := engine.Transaction(func(session *xorm.Session) (interface{}, error) ## 更新日志 -* **v0.7.0** - * 修正部分Bug - -* **v0.6.6** - * 修正部分Bug - -* **v0.6.5** - * 通过 engine.SetSchema 来支持 schema,当前仅支持Postgres - * vgo 支持 - * 新增 `FindAndCount` 函数 - * 通过 `NewEngineWithParams` 支持数据库特别参数 - * 修正部分Bug - -* **v0.6.4** - * 自动读写分离支持 - * Query/QueryString/QueryInterface 支持与 Where/And 合用 - * `Get` 支持获取非结构体变量 - * `Iterate` 支持 `BufferSize` - * 修正部分Bug - -[更多更新日志...](https://github.com/go-xorm/manual-zh-CN/tree/master/chapter-16) +请访问 [CHANGELOG.md](CHANGELOG.md) 获得更新日志。 ## LICENSE diff --git a/caches/cache.go b/caches/cache.go new file mode 100644 index 00000000..7b80eb88 --- /dev/null +++ b/caches/cache.go @@ -0,0 +1,99 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package caches + +import ( + "bytes" + "encoding/gob" + "errors" + "fmt" + "strings" + "time" + + "xorm.io/xorm/schemas" +) + +const ( + // CacheExpired is default cache expired time + CacheExpired = 60 * time.Minute + // CacheMaxMemory is not use now + CacheMaxMemory = 256 + // CacheGcInterval represents interval time to clear all expired nodes + CacheGcInterval = 10 * time.Minute + // CacheGcMaxRemoved represents max nodes removed when gc + CacheGcMaxRemoved = 20 +) + +// list all the errors +var ( + ErrCacheMiss = errors.New("xorm/cache: key not found") + ErrNotStored = errors.New("xorm/cache: not stored") + // ErrNotExist record does not exist error + ErrNotExist = errors.New("Record does not exist") +) + +// CacheStore is a interface to store cache +type CacheStore interface { + // key is primary key or composite primary key + // value is struct's pointer + // key format : -p--... + Put(key string, value interface{}) error + Get(key string) (interface{}, error) + Del(key string) error +} + +// Cacher is an interface to provide cache +// id format : u--... +type Cacher interface { + GetIds(tableName, sql string) interface{} + GetBean(tableName string, id string) interface{} + PutIds(tableName, sql string, ids interface{}) + PutBean(tableName string, id string, obj interface{}) + DelIds(tableName, sql string) + DelBean(tableName string, id string) + ClearIds(tableName string) + ClearBeans(tableName string) +} + +func encodeIds(ids []schemas.PK) (string, error) { + buf := new(bytes.Buffer) + enc := gob.NewEncoder(buf) + err := enc.Encode(ids) + + return buf.String(), err +} + +func decodeIds(s string) ([]schemas.PK, error) { + pks := make([]schemas.PK, 0) + + dec := gob.NewDecoder(strings.NewReader(s)) + err := dec.Decode(&pks) + + return pks, err +} + +// GetCacheSql returns cacher PKs via SQL +func GetCacheSql(m Cacher, tableName, sql string, args interface{}) ([]schemas.PK, error) { + bytes := m.GetIds(tableName, GenSqlKey(sql, args)) + if bytes == nil { + return nil, errors.New("Not Exist") + } + return decodeIds(bytes.(string)) +} + +// PutCacheSql puts cacher SQL and PKs +func PutCacheSql(m Cacher, ids []schemas.PK, tableName, sql string, args interface{}) error { + bytes, err := encodeIds(ids) + if err != nil { + return err + } + m.PutIds(tableName, GenSqlKey(sql, args), bytes) + return nil +} + +// GenSqlKey generates cache key +func GenSqlKey(sql string, args interface{}) string { + return fmt.Sprintf("%v-%v", sql, args) +} diff --git a/caches/encode.go b/caches/encode.go new file mode 100644 index 00000000..4ba39924 --- /dev/null +++ b/caches/encode.go @@ -0,0 +1,58 @@ +// Copyright 2020 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package caches + +import ( + "bytes" + "crypto/md5" + "encoding/gob" + "encoding/json" + "fmt" + "io" +) + +// md5 hash string +func Md5(str string) string { + m := md5.New() + io.WriteString(m, str) + return fmt.Sprintf("%x", m.Sum(nil)) +} +func Encode(data interface{}) ([]byte, error) { + //return JsonEncode(data) + return GobEncode(data) +} + +func Decode(data []byte, to interface{}) error { + //return JsonDecode(data, to) + return GobDecode(data, to) +} + +func GobEncode(data interface{}) ([]byte, error) { + var buf bytes.Buffer + enc := gob.NewEncoder(&buf) + err := enc.Encode(&data) + if err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +func GobDecode(data []byte, to interface{}) error { + buf := bytes.NewBuffer(data) + dec := gob.NewDecoder(buf) + return dec.Decode(to) +} + +func JsonEncode(data interface{}) ([]byte, error) { + val, err := json.Marshal(data) + if err != nil { + return nil, err + } + return val, nil +} + +func JsonDecode(data []byte, to interface{}) error { + return json.Unmarshal(data, to) +} diff --git a/caches/leveldb.go b/caches/leveldb.go new file mode 100644 index 00000000..d1a177ad --- /dev/null +++ b/caches/leveldb.go @@ -0,0 +1,94 @@ +// Copyright 2020 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package caches + +import ( + "log" + + "github.com/syndtr/goleveldb/leveldb" +) + +// LevelDBStore implements CacheStore provide local machine +type LevelDBStore struct { + store *leveldb.DB + Debug bool + v interface{} +} + +var _ CacheStore = &LevelDBStore{} + +func NewLevelDBStore(dbfile string) (*LevelDBStore, error) { + db := &LevelDBStore{} + h, err := leveldb.OpenFile(dbfile, nil) + if err != nil { + return nil, err + } + db.store = h + return db, nil +} + +func (s *LevelDBStore) Put(key string, value interface{}) error { + val, err := Encode(value) + if err != nil { + if s.Debug { + log.Println("[LevelDB]EncodeErr: ", err, "Key:", key) + } + return err + } + err = s.store.Put([]byte(key), val, nil) + if err != nil { + if s.Debug { + log.Println("[LevelDB]PutErr: ", err, "Key:", key) + } + return err + } + if s.Debug { + log.Println("[LevelDB]Put: ", key) + } + return err +} + +func (s *LevelDBStore) Get(key string) (interface{}, error) { + data, err := s.store.Get([]byte(key), nil) + if err != nil { + if s.Debug { + log.Println("[LevelDB]GetErr: ", err, "Key:", key) + } + if err == leveldb.ErrNotFound { + return nil, ErrNotExist + } + return nil, err + } + + err = Decode(data, &s.v) + if err != nil { + if s.Debug { + log.Println("[LevelDB]DecodeErr: ", err, "Key:", key) + } + return nil, err + } + if s.Debug { + log.Println("[LevelDB]Get: ", key, s.v) + } + return s.v, err +} + +func (s *LevelDBStore) Del(key string) error { + err := s.store.Delete([]byte(key), nil) + if err != nil { + if s.Debug { + log.Println("[LevelDB]DelErr: ", err, "Key:", key) + } + return err + } + if s.Debug { + log.Println("[LevelDB]Del: ", key) + } + return err +} + +func (s *LevelDBStore) Close() { + s.store.Close() +} diff --git a/caches/leveldb_test.go b/caches/leveldb_test.go new file mode 100644 index 00000000..35981db1 --- /dev/null +++ b/caches/leveldb_test.go @@ -0,0 +1,39 @@ +// Copyright 2020 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package caches + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestLevelDBStore(t *testing.T) { + store, err := NewLevelDBStore("./level.db") + assert.NoError(t, err) + + var kvs = map[string]interface{}{ + "a": "b", + } + for k, v := range kvs { + assert.NoError(t, store.Put(k, v)) + } + + for k, v := range kvs { + val, err := store.Get(k) + assert.NoError(t, err) + assert.EqualValues(t, v, val) + } + + for k := range kvs { + err := store.Del(k) + assert.NoError(t, err) + } + + for k := range kvs { + _, err := store.Get(k) + assert.EqualValues(t, ErrNotExist, err) + } +} diff --git a/cache_lru.go b/caches/lru.go similarity index 93% rename from cache_lru.go rename to caches/lru.go index ab948bd2..6b45ac94 100644 --- a/cache_lru.go +++ b/caches/lru.go @@ -2,15 +2,13 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package caches import ( "container/list" "fmt" "sync" "time" - - "xorm.io/core" ) // LRUCacher implments cache object facilities @@ -19,7 +17,7 @@ type LRUCacher struct { sqlList *list.List idIndex map[string]map[string]*list.Element sqlIndex map[string]map[string]*list.Element - store core.CacheStore + store CacheStore mutex sync.Mutex MaxElementSize int Expired time.Duration @@ -27,15 +25,15 @@ type LRUCacher struct { } // NewLRUCacher creates a cacher -func NewLRUCacher(store core.CacheStore, maxElementSize int) *LRUCacher { +func NewLRUCacher(store CacheStore, maxElementSize int) *LRUCacher { return NewLRUCacher2(store, 3600*time.Second, maxElementSize) } // NewLRUCacher2 creates a cache include different params -func NewLRUCacher2(store core.CacheStore, expired time.Duration, maxElementSize int) *LRUCacher { +func NewLRUCacher2(store CacheStore, expired time.Duration, maxElementSize int) *LRUCacher { cacher := &LRUCacher{store: store, idList: list.New(), sqlList: list.New(), Expired: expired, - GcInterval: core.CacheGcInterval, MaxElementSize: maxElementSize, + GcInterval: CacheGcInterval, MaxElementSize: maxElementSize, sqlIndex: make(map[string]map[string]*list.Element), idIndex: make(map[string]map[string]*list.Element), } @@ -57,7 +55,7 @@ func (m *LRUCacher) GC() { defer m.mutex.Unlock() var removedNum int for e := m.idList.Front(); e != nil; { - if removedNum <= core.CacheGcMaxRemoved && + if removedNum <= CacheGcMaxRemoved && time.Now().Sub(e.Value.(*idNode).lastVisit) > m.Expired { removedNum++ next := e.Next() @@ -71,7 +69,7 @@ func (m *LRUCacher) GC() { removedNum = 0 for e := m.sqlList.Front(); e != nil; { - if removedNum <= core.CacheGcMaxRemoved && + if removedNum <= CacheGcMaxRemoved && time.Now().Sub(e.Value.(*sqlNode).lastVisit) > m.Expired { removedNum++ next := e.Next() @@ -268,11 +266,11 @@ type sqlNode struct { } func genSQLKey(sql string, args interface{}) string { - return fmt.Sprintf("%v-%v", sql, args) + return fmt.Sprintf("%s-%v", sql, args) } func genID(prefix string, id string) string { - return fmt.Sprintf("%v-%v", prefix, id) + return fmt.Sprintf("%s-%s", prefix, id) } func newIDNode(tbName string, id string) *idNode { diff --git a/cache_lru_test.go b/caches/lru_test.go similarity index 94% rename from cache_lru_test.go rename to caches/lru_test.go index 7da36f00..771b924c 100644 --- a/cache_lru_test.go +++ b/caches/lru_test.go @@ -2,13 +2,13 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package caches import ( "testing" - "xorm.io/core" "github.com/stretchr/testify/assert" + "xorm.io/xorm/schemas" ) func TestLRUCache(t *testing.T) { @@ -20,7 +20,7 @@ func TestLRUCache(t *testing.T) { cacher := NewLRUCacher(store, 10000) tableName := "cache_object1" - pks := []core.PK{ + pks := []schemas.PK{ {1}, {2}, } diff --git a/caches/manager.go b/caches/manager.go new file mode 100644 index 00000000..05045210 --- /dev/null +++ b/caches/manager.go @@ -0,0 +1,56 @@ +// Copyright 2020 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package caches + +import "sync" + +type Manager struct { + cacher Cacher + disableGlobalCache bool + + cachers map[string]Cacher + cacherLock sync.RWMutex +} + +func NewManager() *Manager { + return &Manager{ + cachers: make(map[string]Cacher), + } +} + +// SetDisableGlobalCache disable global cache or not +func (mgr *Manager) SetDisableGlobalCache(disable bool) { + if mgr.disableGlobalCache != disable { + mgr.disableGlobalCache = disable + } +} + +func (mgr *Manager) SetCacher(tableName string, cacher Cacher) { + mgr.cacherLock.Lock() + mgr.cachers[tableName] = cacher + mgr.cacherLock.Unlock() +} + +func (mgr *Manager) GetCacher(tableName string) Cacher { + var cacher Cacher + var ok bool + mgr.cacherLock.RLock() + cacher, ok = mgr.cachers[tableName] + mgr.cacherLock.RUnlock() + if !ok && !mgr.disableGlobalCache { + cacher = mgr.cacher + } + return cacher +} + +// SetDefaultCacher set the default cacher. Xorm's default not enable cacher. +func (mgr *Manager) SetDefaultCacher(cacher Cacher) { + mgr.cacher = cacher +} + +// GetDefaultCacher returns the default cacher +func (mgr *Manager) GetDefaultCacher() Cacher { + return mgr.cacher +} diff --git a/cache_memory_store.go b/caches/memory_store.go similarity index 93% rename from cache_memory_store.go rename to caches/memory_store.go index 0c483f45..f16254d8 100644 --- a/cache_memory_store.go +++ b/caches/memory_store.go @@ -2,15 +2,13 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package caches import ( "sync" - - "xorm.io/core" ) -var _ core.CacheStore = NewMemoryStore() +var _ CacheStore = NewMemoryStore() // MemoryStore represents in-memory store type MemoryStore struct { diff --git a/cache_memory_store_test.go b/caches/memory_store_test.go similarity index 91% rename from cache_memory_store_test.go rename to caches/memory_store_test.go index fc27ae32..12db4ea7 100644 --- a/cache_memory_store_test.go +++ b/caches/memory_store_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package caches import ( "testing" @@ -25,12 +25,12 @@ func TestMemoryStore(t *testing.T) { assert.EqualValues(t, v, val) } - for k, _ := range kvs { + for k := range kvs { err := store.Del(k) assert.NoError(t, err) } - for k, _ := range kvs { + for k := range kvs { _, err := store.Get(k) assert.EqualValues(t, ErrNotExist, err) } diff --git a/context_cache.go b/contexts/context_cache.go similarity index 97% rename from context_cache.go rename to contexts/context_cache.go index 1bc22884..0d0f0f02 100644 --- a/context_cache.go +++ b/contexts/context_cache.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package contexts // ContextCache is the interface that operates the cache data. type ContextCache interface { diff --git a/contexts/hook.go b/contexts/hook.go new file mode 100644 index 00000000..71ad8e87 --- /dev/null +++ b/contexts/hook.go @@ -0,0 +1,75 @@ +// Copyright 2020 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package contexts + +import ( + "context" + "database/sql" + "time" +) + +// ContextHook represents a hook context +type ContextHook struct { + start time.Time + Ctx context.Context + SQL string // log content or SQL + Args []interface{} // if it's a SQL, it's the arguments + Result sql.Result + ExecuteTime time.Duration + Err error // SQL executed error +} + +// NewContextHook return context for hook +func NewContextHook(ctx context.Context, sql string, args []interface{}) *ContextHook { + return &ContextHook{ + start: time.Now(), + Ctx: ctx, + SQL: sql, + Args: args, + } +} + +func (c *ContextHook) End(ctx context.Context, result sql.Result, err error) { + c.Ctx = ctx + c.Result = result + c.Err = err + c.ExecuteTime = time.Now().Sub(c.start) +} + +type Hook interface { + BeforeProcess(c *ContextHook) (context.Context, error) + AfterProcess(c *ContextHook) error +} + +type Hooks struct { + hooks []Hook +} + +func (h *Hooks) AddHook(hooks ...Hook) { + h.hooks = append(h.hooks, hooks...) +} + +func (h *Hooks) BeforeProcess(c *ContextHook) (context.Context, error) { + ctx := c.Ctx + for _, h := range h.hooks { + var err error + ctx, err = h.BeforeProcess(c) + if err != nil { + return nil, err + } + } + return ctx, nil +} + +func (h *Hooks) AfterProcess(c *ContextHook) error { + firstErr := c.Err + for _, h := range h.hooks { + err := h.AfterProcess(c) + if err != nil && firstErr == nil { + firstErr = err + } + } + return firstErr +} diff --git a/contexts/hook_test.go b/contexts/hook_test.go new file mode 100644 index 00000000..96c54e33 --- /dev/null +++ b/contexts/hook_test.go @@ -0,0 +1,140 @@ +package contexts + +import ( + "context" + "errors" + "testing" +) + +type testHook struct { + before func(c *ContextHook) (context.Context, error) + after func(c *ContextHook) error +} + +func (h *testHook) BeforeProcess(c *ContextHook) (context.Context, error) { + if h.before != nil { + return h.before(c) + } + return c.Ctx, nil +} + +func (h *testHook) AfterProcess(c *ContextHook) error { + if h.after != nil { + return h.after(c) + } + return c.Err +} + +var _ Hook = &testHook{} + +func TestBeforeProcess(t *testing.T) { + expectErr := errors.New("before error") + tests := []struct { + msg string + hooks []Hook + expect error + }{ + { + msg: "first hook return err", + hooks: []Hook{ + &testHook{ + before: func(c *ContextHook) (ctx context.Context, err error) { + return c.Ctx, expectErr + }, + }, + &testHook{ + before: func(c *ContextHook) (ctx context.Context, err error) { + return c.Ctx, nil + }, + }, + }, + expect: expectErr, + }, + { + msg: "second hook return err", + hooks: []Hook{ + &testHook{ + before: func(c *ContextHook) (ctx context.Context, err error) { + return c.Ctx, nil + }, + }, + &testHook{ + before: func(c *ContextHook) (ctx context.Context, err error) { + return c.Ctx, expectErr + }, + }, + }, + expect: expectErr, + }, + } + + for _, tt := range tests { + t.Run(tt.msg, func(t *testing.T) { + hooks := Hooks{} + hooks.AddHook(tt.hooks...) + _, err := hooks.BeforeProcess(&ContextHook{ + Ctx: context.Background(), + }) + if err != tt.expect { + t.Errorf("got %v, expect %v", err, tt.expect) + } + }) + } +} + +func TestAfterProcess(t *testing.T) { + expectErr := errors.New("expect err") + tests := []struct { + msg string + ctx *ContextHook + hooks []Hook + expect error + }{ + { + msg: "context has err", + ctx: &ContextHook{ + Ctx: context.Background(), + Err: expectErr, + }, + hooks: []Hook{ + &testHook{ + after: func(c *ContextHook) error { + return errors.New("hook err") + }, + }, + }, + expect: expectErr, + }, + { + msg: "last hook has err", + ctx: &ContextHook{ + Ctx: context.Background(), + Err: nil, + }, + hooks: []Hook{ + &testHook{ + after: func(c *ContextHook) error { + return nil + }, + }, + &testHook{ + after: func(c *ContextHook) error { + return expectErr + }, + }, + }, + expect: expectErr, + }, + } + + for _, tt := range tests { + t.Run(tt.msg, func(t *testing.T) { + hooks := Hooks{} + hooks.AddHook(tt.hooks...) + err := hooks.AfterProcess(tt.ctx) + if err != tt.expect { + t.Errorf("got %v, expect %v", err, tt.expect) + } + }) + } +} diff --git a/convert.go b/convert.go index 2316ca0b..c19d30e0 100644 --- a/convert.go +++ b/convert.go @@ -25,11 +25,10 @@ func strconvErr(err error) error { func cloneBytes(b []byte) []byte { if b == nil { return nil - } else { - c := make([]byte, len(b)) - copy(c, b) - return c } + c := make([]byte, len(b)) + copy(c, b) + return c } func asString(src interface{}) string { @@ -285,56 +284,6 @@ func asKind(vv reflect.Value, tp reflect.Type) (interface{}, error) { return nil, fmt.Errorf("unsupported primary key type: %v, %v", tp, vv) } -func convertFloat(v interface{}) (float64, error) { - switch v.(type) { - case float32: - return float64(v.(float32)), nil - case float64: - return v.(float64), nil - case string: - i, err := strconv.ParseFloat(v.(string), 64) - if err != nil { - return 0, err - } - return i, nil - case []byte: - i, err := strconv.ParseFloat(string(v.([]byte)), 64) - if err != nil { - return 0, err - } - return i, nil - } - return 0, fmt.Errorf("unsupported type: %v", v) -} - -func convertInt(v interface{}) (int64, error) { - switch v.(type) { - case int: - return int64(v.(int)), nil - case int8: - return int64(v.(int8)), nil - case int16: - return int64(v.(int16)), nil - case int32: - return int64(v.(int32)), nil - case int64: - return v.(int64), nil - case []byte: - i, err := strconv.ParseInt(string(v.([]byte)), 10, 64) - if err != nil { - return 0, err - } - return i, nil - case string: - i, err := strconv.ParseInt(v.(string), 10, 64) - if err != nil { - return 0, err - } - return i, nil - } - return 0, fmt.Errorf("unsupported type: %v", v) -} - func asBool(bs []byte) (bool, error) { if len(bs) == 0 { return false, nil @@ -346,3 +295,128 @@ func asBool(bs []byte) (bool, error) { } return strconv.ParseBool(string(bs)) } + +// str2PK convert string value to primary key value according to tp +func str2PKValue(s string, tp reflect.Type) (reflect.Value, error) { + var err error + var result interface{} + var defReturn = reflect.Zero(tp) + + switch tp.Kind() { + case reflect.Int: + result, err = strconv.Atoi(s) + if err != nil { + return defReturn, fmt.Errorf("convert %s as int: %s", s, err.Error()) + } + case reflect.Int8: + x, err := strconv.Atoi(s) + if err != nil { + return defReturn, fmt.Errorf("convert %s as int8: %s", s, err.Error()) + } + result = int8(x) + case reflect.Int16: + x, err := strconv.Atoi(s) + if err != nil { + return defReturn, fmt.Errorf("convert %s as int16: %s", s, err.Error()) + } + result = int16(x) + case reflect.Int32: + x, err := strconv.Atoi(s) + if err != nil { + return defReturn, fmt.Errorf("convert %s as int32: %s", s, err.Error()) + } + result = int32(x) + case reflect.Int64: + result, err = strconv.ParseInt(s, 10, 64) + if err != nil { + return defReturn, fmt.Errorf("convert %s as int64: %s", s, err.Error()) + } + case reflect.Uint: + x, err := strconv.ParseUint(s, 10, 64) + if err != nil { + return defReturn, fmt.Errorf("convert %s as uint: %s", s, err.Error()) + } + result = uint(x) + case reflect.Uint8: + x, err := strconv.ParseUint(s, 10, 64) + if err != nil { + return defReturn, fmt.Errorf("convert %s as uint8: %s", s, err.Error()) + } + result = uint8(x) + case reflect.Uint16: + x, err := strconv.ParseUint(s, 10, 64) + if err != nil { + return defReturn, fmt.Errorf("convert %s as uint16: %s", s, err.Error()) + } + result = uint16(x) + case reflect.Uint32: + x, err := strconv.ParseUint(s, 10, 64) + if err != nil { + return defReturn, fmt.Errorf("convert %s as uint32: %s", s, err.Error()) + } + result = uint32(x) + case reflect.Uint64: + result, err = strconv.ParseUint(s, 10, 64) + if err != nil { + return defReturn, fmt.Errorf("convert %s as uint64: %s", s, err.Error()) + } + case reflect.String: + result = s + default: + return defReturn, errors.New("unsupported convert type") + } + return reflect.ValueOf(result).Convert(tp), nil +} + +func str2PK(s string, tp reflect.Type) (interface{}, error) { + v, err := str2PKValue(s, tp) + if err != nil { + return nil, err + } + return v.Interface(), nil +} + +func int64ToIntValue(id int64, tp reflect.Type) reflect.Value { + var v interface{} + kind := tp.Kind() + + if kind == reflect.Ptr { + kind = tp.Elem().Kind() + } + + switch kind { + case reflect.Int16: + temp := int16(id) + v = &temp + case reflect.Int32: + temp := int32(id) + v = &temp + case reflect.Int: + temp := int(id) + v = &temp + case reflect.Int64: + temp := id + v = &temp + case reflect.Uint16: + temp := uint16(id) + v = &temp + case reflect.Uint32: + temp := uint32(id) + v = &temp + case reflect.Uint64: + temp := uint64(id) + v = &temp + case reflect.Uint: + temp := uint(id) + v = &temp + } + + if tp.Kind() == reflect.Ptr { + return reflect.ValueOf(v).Convert(tp) + } + return reflect.ValueOf(v).Elem().Convert(tp) +} + +func int64ToInt(id int64, tp reflect.Type) interface{} { + return int64ToIntValue(id, tp).Interface() +} diff --git a/convert/conversion.go b/convert/conversion.go new file mode 100644 index 00000000..16f1a92a --- /dev/null +++ b/convert/conversion.go @@ -0,0 +1,12 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package convert + +// Conversion is an interface. A type implements Conversion will according +// the custom method to fill into database and retrieve from database. +type Conversion interface { + FromDB([]byte) error + ToDB() ([]byte, error) +} diff --git a/core/db.go b/core/db.go new file mode 100644 index 00000000..50c64c6f --- /dev/null +++ b/core/db.go @@ -0,0 +1,293 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package core + +import ( + "context" + "database/sql" + "database/sql/driver" + "fmt" + "reflect" + "regexp" + "sync" + + "xorm.io/xorm/contexts" + "xorm.io/xorm/log" + "xorm.io/xorm/names" +) + +var ( + // DefaultCacheSize sets the default cache size + DefaultCacheSize = 200 +) + +func MapToSlice(query string, mp interface{}) (string, []interface{}, error) { + vv := reflect.ValueOf(mp) + if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map { + return "", []interface{}{}, ErrNoMapPointer + } + + args := make([]interface{}, 0, len(vv.Elem().MapKeys())) + var err error + query = re.ReplaceAllStringFunc(query, func(src string) string { + v := vv.Elem().MapIndex(reflect.ValueOf(src[1:])) + if !v.IsValid() { + err = fmt.Errorf("map key %s is missing", src[1:]) + } else { + args = append(args, v.Interface()) + } + return "?" + }) + + return query, args, err +} + +func StructToSlice(query string, st interface{}) (string, []interface{}, error) { + vv := reflect.ValueOf(st) + if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct { + return "", []interface{}{}, ErrNoStructPointer + } + + args := make([]interface{}, 0) + var err error + query = re.ReplaceAllStringFunc(query, func(src string) string { + fv := vv.Elem().FieldByName(src[1:]).Interface() + if v, ok := fv.(driver.Valuer); ok { + var value driver.Value + value, err = v.Value() + if err != nil { + return "?" + } + args = append(args, value) + } else { + args = append(args, fv) + } + return "?" + }) + if err != nil { + return "", []interface{}{}, err + } + return query, args, nil +} + +type cacheStruct struct { + value reflect.Value + idx int +} + +var ( + _ QueryExecuter = &DB{} +) + +// DB is a wrap of sql.DB with extra contents +type DB struct { + *sql.DB + Mapper names.Mapper + reflectCache map[reflect.Type]*cacheStruct + reflectCacheMutex sync.RWMutex + Logger log.ContextLogger + hooks contexts.Hooks +} + +// Open opens a database +func Open(driverName, dataSourceName string) (*DB, error) { + db, err := sql.Open(driverName, dataSourceName) + if err != nil { + return nil, err + } + return &DB{ + DB: db, + Mapper: names.NewCacheMapper(&names.SnakeMapper{}), + reflectCache: make(map[reflect.Type]*cacheStruct), + }, nil +} + +// FromDB creates a DB from a sql.DB +func FromDB(db *sql.DB) *DB { + return &DB{ + DB: db, + Mapper: names.NewCacheMapper(&names.SnakeMapper{}), + reflectCache: make(map[reflect.Type]*cacheStruct), + } +} + +// NeedLogSQL returns true if need to log SQL +func (db *DB) NeedLogSQL(ctx context.Context) bool { + if db.Logger == nil { + return false + } + + v := ctx.Value(log.SessionShowSQLKey) + if showSQL, ok := v.(bool); ok { + return showSQL + } + return db.Logger.IsShowSQL() +} + +func (db *DB) reflectNew(typ reflect.Type) reflect.Value { + db.reflectCacheMutex.Lock() + defer db.reflectCacheMutex.Unlock() + cs, ok := db.reflectCache[typ] + if !ok || cs.idx+1 > DefaultCacheSize-1 { + cs = &cacheStruct{reflect.MakeSlice(reflect.SliceOf(typ), DefaultCacheSize, DefaultCacheSize), 0} + db.reflectCache[typ] = cs + } else { + cs.idx = cs.idx + 1 + } + return cs.value.Index(cs.idx).Addr() +} + +// QueryContext overwrites sql.DB.QueryContext +func (db *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { + hookCtx := contexts.NewContextHook(ctx, query, args) + ctx, err := db.beforeProcess(hookCtx) + if err != nil { + return nil, err + } + rows, err := db.DB.QueryContext(ctx, query, args...) + hookCtx.End(ctx, nil, err) + if err := db.afterProcess(hookCtx); err != nil { + if rows != nil { + rows.Close() + } + return nil, err + } + return &Rows{rows, db}, nil +} + +// Query overwrites sql.DB.Query +func (db *DB) Query(query string, args ...interface{}) (*Rows, error) { + return db.QueryContext(context.Background(), query, args...) +} + +// QueryMapContext executes query with parameters via map and context +func (db *DB) QueryMapContext(ctx context.Context, query string, mp interface{}) (*Rows, error) { + query, args, err := MapToSlice(query, mp) + if err != nil { + return nil, err + } + return db.QueryContext(ctx, query, args...) +} + +// QueryMap executes query with parameters via map +func (db *DB) QueryMap(query string, mp interface{}) (*Rows, error) { + return db.QueryMapContext(context.Background(), query, mp) +} + +func (db *DB) QueryStructContext(ctx context.Context, query string, st interface{}) (*Rows, error) { + query, args, err := StructToSlice(query, st) + if err != nil { + return nil, err + } + return db.QueryContext(ctx, query, args...) +} + +func (db *DB) QueryStruct(query string, st interface{}) (*Rows, error) { + return db.QueryStructContext(context.Background(), query, st) +} + +func (db *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *Row { + rows, err := db.QueryContext(ctx, query, args...) + if err != nil { + return &Row{nil, err} + } + return &Row{rows, nil} +} + +func (db *DB) QueryRow(query string, args ...interface{}) *Row { + return db.QueryRowContext(context.Background(), query, args...) +} + +func (db *DB) QueryRowMapContext(ctx context.Context, query string, mp interface{}) *Row { + query, args, err := MapToSlice(query, mp) + if err != nil { + return &Row{nil, err} + } + return db.QueryRowContext(ctx, query, args...) +} + +func (db *DB) QueryRowMap(query string, mp interface{}) *Row { + return db.QueryRowMapContext(context.Background(), query, mp) +} + +func (db *DB) QueryRowStructContext(ctx context.Context, query string, st interface{}) *Row { + query, args, err := StructToSlice(query, st) + if err != nil { + return &Row{nil, err} + } + return db.QueryRowContext(ctx, query, args...) +} + +func (db *DB) QueryRowStruct(query string, st interface{}) *Row { + return db.QueryRowStructContext(context.Background(), query, st) +} + +var ( + re = regexp.MustCompile(`[?](\w+)`) +) + +// ExecMapContext exec map with context.ContextHook +// insert into (name) values (?) +// insert into (name) values (?name) +func (db *DB) ExecMapContext(ctx context.Context, query string, mp interface{}) (sql.Result, error) { + query, args, err := MapToSlice(query, mp) + if err != nil { + return nil, err + } + return db.ExecContext(ctx, query, args...) +} + +func (db *DB) ExecMap(query string, mp interface{}) (sql.Result, error) { + return db.ExecMapContext(context.Background(), query, mp) +} + +func (db *DB) ExecStructContext(ctx context.Context, query string, st interface{}) (sql.Result, error) { + query, args, err := StructToSlice(query, st) + if err != nil { + return nil, err + } + return db.ExecContext(ctx, query, args...) +} + +func (db *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + hookCtx := contexts.NewContextHook(ctx, query, args) + ctx, err := db.beforeProcess(hookCtx) + if err != nil { + return nil, err + } + res, err := db.DB.ExecContext(ctx, query, args...) + hookCtx.End(ctx, res, err) + if err := db.afterProcess(hookCtx); err != nil { + return nil, err + } + return res, nil +} + +func (db *DB) ExecStruct(query string, st interface{}) (sql.Result, error) { + return db.ExecStructContext(context.Background(), query, st) +} + +func (db *DB) beforeProcess(c *contexts.ContextHook) (context.Context, error) { + if db.NeedLogSQL(c.Ctx) { + db.Logger.BeforeSQL(log.LogContext(*c)) + } + ctx, err := db.hooks.BeforeProcess(c) + if err != nil { + return nil, err + } + return ctx, nil +} + +func (db *DB) afterProcess(c *contexts.ContextHook) error { + err := db.hooks.AfterProcess(c) + if db.NeedLogSQL(c.Ctx) { + db.Logger.AfterSQL(log.LogContext(*c)) + } + return err +} + +func (db *DB) AddHook(h ...contexts.Hook) { + db.hooks.AddHook(h...) +} diff --git a/core/db_test.go b/core/db_test.go new file mode 100644 index 00000000..777ab0ad --- /dev/null +++ b/core/db_test.go @@ -0,0 +1,685 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package core + +import ( + "errors" + "flag" + "os" + "testing" + "time" + + "xorm.io/xorm/names" + + _ "github.com/go-sql-driver/mysql" + _ "github.com/mattn/go-sqlite3" +) + +var ( + dbtype = flag.String("dbtype", "sqlite3", "database type") + dbConn = flag.String("dbConn", "./db_test.db", "database connect string") + createTableSql string +) + +func TestMain(m *testing.M) { + flag.Parse() + + switch *dbtype { + case "sqlite3": + createTableSql = "CREATE TABLE IF NOT EXISTS `user` (`id` INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, `name` TEXT NULL, " + + "`title` TEXT NULL, `age` FLOAT NULL, `alias` TEXT NULL, `nick_name` TEXT NULL, `created` datetime);" + case "mysql": + fallthrough + default: + createTableSql = "CREATE TABLE IF NOT EXISTS `user` (`id` INTEGER PRIMARY KEY AUTO_INCREMENT NOT NULL, `name` TEXT NULL, " + + "`title` TEXT NULL, `age` FLOAT NULL, `alias` TEXT NULL, `nick_name` TEXT NULL, `created` datetime);" + } + + exitCode := m.Run() + + os.Exit(exitCode) +} + +func testOpen() (*DB, error) { + switch *dbtype { + case "sqlite3": + os.Remove("./test.db") + return Open("sqlite3", "./test.db") + case "mysql": + return Open("mysql", *dbConn) + default: + panic("no db type") + } +} + +func BenchmarkOriQuery(b *testing.B) { + b.StopTimer() + db, err := testOpen() + if err != nil { + b.Error(err) + } + defer db.Close() + + _, err = db.Exec(createTableSql) + if err != nil { + b.Error(err) + } + + for i := 0; i < 50; i++ { + _, err = db.Exec("insert into user (`name`, title, age, alias, nick_name, created) values (?,?,?,?,?, ?)", + "xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now()) + if err != nil { + b.Error(err) + } + } + + b.StartTimer() + + for i := 0; i < b.N; i++ { + rows, err := db.Query("select * from user") + if err != nil { + b.Error(err) + } + + for rows.Next() { + var Id int64 + var Name, Title, Alias, NickName string + var Age float32 + var Created NullTime + err = rows.Scan(&Id, &Name, &Title, &Age, &Alias, &NickName, &Created) + if err != nil { + b.Error(err) + } + //fmt.Println(Id, Name, Title, Age, Alias, NickName) + } + rows.Close() + } +} + +type User struct { + Id int64 + Name string + Title string + Age float32 + Alias string + NickName string + Created NullTime +} + +func BenchmarkStructQuery(b *testing.B) { + b.StopTimer() + + db, err := testOpen() + if err != nil { + b.Error(err) + } + defer db.Close() + + _, err = db.Exec(createTableSql) + if err != nil { + b.Error(err) + } + + for i := 0; i < 50; i++ { + _, err = db.Exec("insert into user (`name`, title, age, alias, nick_name, created) values (?,?,?,?,?, ?)", + "xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now()) + if err != nil { + b.Error(err) + } + } + + b.StartTimer() + + for i := 0; i < b.N; i++ { + rows, err := db.Query("select * from user") + if err != nil { + b.Error(err) + } + + for rows.Next() { + var user User + err = rows.ScanStructByIndex(&user) + if err != nil { + b.Error(err) + } + if user.Name != "xlw" { + b.Log(user) + b.Error(errors.New("name should be xlw")) + } + } + rows.Close() + } +} + +func BenchmarkStruct2Query(b *testing.B) { + b.StopTimer() + + db, err := testOpen() + if err != nil { + b.Error(err) + } + defer db.Close() + + _, err = db.Exec(createTableSql) + if err != nil { + b.Error(err) + } + + for i := 0; i < 50; i++ { + _, err = db.Exec("insert into user (`name`, title, age, alias, nick_name, created) values (?,?,?,?,?,?)", + "xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now()) + if err != nil { + b.Error(err) + } + } + + db.Mapper = names.NewCacheMapper(&names.SnakeMapper{}) + b.StartTimer() + + for i := 0; i < b.N; i++ { + rows, err := db.Query("select * from user") + if err != nil { + b.Error(err) + } + + for rows.Next() { + var user User + err = rows.ScanStructByName(&user) + if err != nil { + b.Error(err) + } + if user.Name != "xlw" { + b.Log(user) + b.Error(errors.New("name should be xlw")) + } + } + rows.Close() + } +} + +func BenchmarkSliceInterfaceQuery(b *testing.B) { + b.StopTimer() + + db, err := testOpen() + if err != nil { + b.Error(err) + } + defer db.Close() + + _, err = db.Exec(createTableSql) + if err != nil { + b.Error(err) + } + + for i := 0; i < 50; i++ { + _, err = db.Exec("insert into user (`name`, title, age, alias, nick_name,created) values (?,?,?,?,?,?)", + "xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now()) + if err != nil { + b.Error(err) + } + } + + b.StartTimer() + + for i := 0; i < b.N; i++ { + rows, err := db.Query("select * from user") + if err != nil { + b.Error(err) + } + + cols, err := rows.Columns() + if err != nil { + b.Error(err) + } + + for rows.Next() { + slice := make([]interface{}, len(cols)) + err = rows.ScanSlice(&slice) + if err != nil { + b.Error(err) + } + b.Log(slice) + switch slice[1].(type) { + case *string: + if *slice[1].(*string) != "xlw" { + b.Error(errors.New("name should be xlw")) + } + case []byte: + if string(slice[1].([]byte)) != "xlw" { + b.Error(errors.New("name should be xlw")) + } + } + } + + rows.Close() + } +} + +/*func BenchmarkSliceBytesQuery(b *testing.B) { + b.StopTimer() + os.Remove("./test.db") + db, err := Open("sqlite3", "./test.db") + if err != nil { + b.Error(err) + } + defer db.Close() + + _, err = db.Exec(createTableSql) + if err != nil { + b.Error(err) + } + + for i := 0; i < 50; i++ { + _, err = db.Exec("insert into user (name, title, age, alias, nick_name,created) values (?,?,?,?,?,?)", + "xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now()) + if err != nil { + b.Error(err) + } + } + + b.StartTimer() + + for i := 0; i < b.N; i++ { + rows, err := db.Query("select * from user") + if err != nil { + b.Error(err) + } + + cols, err := rows.Columns() + if err != nil { + b.Error(err) + } + + for rows.Next() { + slice := make([][]byte, len(cols)) + err = rows.ScanSlice(&slice) + if err != nil { + b.Error(err) + } + if string(slice[1]) != "xlw" { + fmt.Println(slice) + b.Error(errors.New("name should be xlw")) + } + } + + rows.Close() + } +} +*/ + +func BenchmarkSliceStringQuery(b *testing.B) { + b.StopTimer() + db, err := testOpen() + if err != nil { + b.Error(err) + } + defer db.Close() + + _, err = db.Exec(createTableSql) + if err != nil { + b.Error(err) + } + + for i := 0; i < 50; i++ { + _, err = db.Exec("insert into user (name, title, age, alias, nick_name, created) values (?,?,?,?,?,?)", + "xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now()) + if err != nil { + b.Error(err) + } + } + + b.StartTimer() + + for i := 0; i < b.N; i++ { + rows, err := db.Query("select * from user") + if err != nil { + b.Error(err) + } + + cols, err := rows.Columns() + if err != nil { + b.Error(err) + } + + for rows.Next() { + slice := make([]*string, len(cols)) + err = rows.ScanSlice(&slice) + if err != nil { + b.Error(err) + } + if (*slice[1]) != "xlw" { + b.Log(slice) + b.Error(errors.New("name should be xlw")) + } + } + + rows.Close() + } +} + +func BenchmarkMapInterfaceQuery(b *testing.B) { + b.StopTimer() + + db, err := testOpen() + if err != nil { + b.Error(err) + } + defer db.Close() + + _, err = db.Exec(createTableSql) + if err != nil { + b.Error(err) + } + + for i := 0; i < 50; i++ { + _, err = db.Exec("insert into user (name, title, age, alias, nick_name,created) values (?,?,?,?,?,?)", + "xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now()) + if err != nil { + b.Error(err) + } + } + + b.StartTimer() + + for i := 0; i < b.N; i++ { + rows, err := db.Query("select * from user") + if err != nil { + b.Error(err) + } + + for rows.Next() { + m := make(map[string]interface{}) + err = rows.ScanMap(&m) + if err != nil { + b.Error(err) + } + switch m["name"].(type) { + case string: + if m["name"].(string) != "xlw" { + b.Log(m) + b.Error(errors.New("name should be xlw")) + } + case []byte: + if string(m["name"].([]byte)) != "xlw" { + b.Log(m) + b.Error(errors.New("name should be xlw")) + } + } + } + + rows.Close() + } +} + +/*func BenchmarkMapBytesQuery(b *testing.B) { + b.StopTimer() + os.Remove("./test.db") + db, err := Open("sqlite3", "./test.db") + if err != nil { + b.Error(err) + } + defer db.Close() + + _, err = db.Exec(createTableSql) + if err != nil { + b.Error(err) + } + + for i := 0; i < 50; i++ { + _, err = db.Exec("insert into user (name, title, age, alias, nick_name,created) values (?,?,?,?,?,?)", + "xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now()) + if err != nil { + b.Error(err) + } + } + + b.StartTimer() + + for i := 0; i < b.N; i++ { + rows, err := db.Query("select * from user") + if err != nil { + b.Error(err) + } + + for rows.Next() { + m := make(map[string][]byte) + err = rows.ScanMap(&m) + if err != nil { + b.Error(err) + } + if string(m["name"]) != "xlw" { + fmt.Println(m) + b.Error(errors.New("name should be xlw")) + } + } + + rows.Close() + } +} +*/ +/* +func BenchmarkMapStringQuery(b *testing.B) { + b.StopTimer() + os.Remove("./test.db") + db, err := Open("sqlite3", "./test.db") + if err != nil { + b.Error(err) + } + defer db.Close() + + _, err = db.Exec(createTableSql) + if err != nil { + b.Error(err) + } + + for i := 0; i < 50; i++ { + _, err = db.Exec("insert into user (name, title, age, alias, nick_name,created) values (?,?,?,?,?,?)", + "xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now()) + if err != nil { + b.Error(err) + } + } + + b.StartTimer() + + for i := 0; i < b.N; i++ { + rows, err := db.Query("select * from user") + if err != nil { + b.Error(err) + } + + for rows.Next() { + m := make(map[string]string) + err = rows.ScanMap(&m) + if err != nil { + b.Error(err) + } + if m["name"] != "xlw" { + fmt.Println(m) + b.Error(errors.New("name should be xlw")) + } + } + + rows.Close() + } +}*/ + +func BenchmarkExec(b *testing.B) { + b.StopTimer() + + db, err := testOpen() + if err != nil { + b.Error(err) + } + defer db.Close() + + _, err = db.Exec(createTableSql) + if err != nil { + b.Error(err) + } + + b.StartTimer() + + for i := 0; i < b.N; i++ { + _, err = db.Exec("insert into user (`name`, title, age, alias, nick_name,created) values (?,?,?,?,?,?)", + "xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now()) + if err != nil { + b.Error(err) + } + } +} + +func BenchmarkExecMap(b *testing.B) { + b.StopTimer() + + db, err := testOpen() + if err != nil { + b.Error(err) + } + defer db.Close() + + _, err = db.Exec(createTableSql) + if err != nil { + b.Error(err) + } + + b.StartTimer() + + mp := map[string]interface{}{ + "name": "xlw", + "title": "tester", + "age": 1.2, + "alias": "lunny", + "nick_name": "lunny xiao", + "created": time.Now(), + } + + for i := 0; i < b.N; i++ { + _, err = db.ExecMap("insert into user (`name`, title, age, alias, nick_name, created) "+ + "values (?name,?title,?age,?alias,?nick_name,?created)", + &mp) + if err != nil { + b.Error(err) + } + } +} + +func TestExecMap(t *testing.T) { + db, err := testOpen() + if err != nil { + t.Error(err) + } + defer db.Close() + + _, err = db.Exec(createTableSql) + if err != nil { + t.Error(err) + } + + mp := map[string]interface{}{ + "name": "xlw", + "title": "tester", + "age": 1.2, + "alias": "lunny", + "nick_name": "lunny xiao", + "created": time.Now(), + } + + _, err = db.ExecMap("insert into user (`name`, title, age, alias, nick_name,created) "+ + "values (?name,?title,?age,?alias,?nick_name,?created)", + &mp) + if err != nil { + t.Error(err) + } + + rows, err := db.Query("select * from user") + if err != nil { + t.Error(err) + } + + for rows.Next() { + var user User + err = rows.ScanStructByName(&user) + if err != nil { + t.Error(err) + } + t.Log("--", user) + } +} + +func TestExecStruct(t *testing.T) { + db, err := testOpen() + if err != nil { + t.Error(err) + } + defer db.Close() + + _, err = db.Exec(createTableSql) + if err != nil { + t.Error(err) + } + + user := User{Name: "xlw", + Title: "tester", + Age: 1.2, + Alias: "lunny", + NickName: "lunny xiao", + Created: NullTime(time.Now()), + } + + _, err = db.ExecStruct("insert into user (`name`, title, age, alias, nick_name,created) "+ + "values (?Name,?Title,?Age,?Alias,?NickName,?Created)", + &user) + if err != nil { + t.Error(err) + } + + rows, err := db.QueryStruct("select * from user where `name` = ?Name", &user) + if err != nil { + t.Error(err) + } + + for rows.Next() { + var user User + err = rows.ScanStructByName(&user) + if err != nil { + t.Error(err) + } + t.Log("1--", user) + } +} + +func BenchmarkExecStruct(b *testing.B) { + b.StopTimer() + db, err := testOpen() + if err != nil { + b.Error(err) + } + defer db.Close() + + _, err = db.Exec(createTableSql) + if err != nil { + b.Error(err) + } + + b.StartTimer() + + user := User{Name: "xlw", + Title: "tester", + Age: 1.2, + Alias: "lunny", + NickName: "lunny xiao", + Created: NullTime(time.Now()), + } + + for i := 0; i < b.N; i++ { + _, err = db.ExecStruct("insert into user (`name`, title, age, alias, nick_name,created) "+ + "values (?Name,?Title,?Age,?Alias,?NickName,?Created)", + &user) + if err != nil { + b.Error(err) + } + } +} diff --git a/core/error.go b/core/error.go new file mode 100644 index 00000000..1fd18348 --- /dev/null +++ b/core/error.go @@ -0,0 +1,14 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package core + +import "errors" + +var ( + // ErrNoMapPointer represents error when no map pointer + ErrNoMapPointer = errors.New("mp should be a map's pointer") + // ErrNoStructPointer represents error when no struct pointer + ErrNoStructPointer = errors.New("mp should be a struct's pointer") +) diff --git a/core/interface.go b/core/interface.go new file mode 100644 index 00000000..a5c8e4e2 --- /dev/null +++ b/core/interface.go @@ -0,0 +1,22 @@ +package core + +import ( + "context" + "database/sql" +) + +// Queryer represents an interface to query a SQL to get data from database +type Queryer interface { + QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) +} + +// Executer represents an interface to execute a SQL +type Executer interface { + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) +} + +// QueryExecuter combines the Queryer and Executer +type QueryExecuter interface { + Queryer + Executer +} diff --git a/core/rows.go b/core/rows.go new file mode 100644 index 00000000..a1e8bfbc --- /dev/null +++ b/core/rows.go @@ -0,0 +1,338 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package core + +import ( + "database/sql" + "errors" + "reflect" + "sync" +) + +type Rows struct { + *sql.Rows + db *DB +} + +func (rs *Rows) ToMapString() ([]map[string]string, error) { + cols, err := rs.Columns() + if err != nil { + return nil, err + } + + var results = make([]map[string]string, 0, 10) + for rs.Next() { + var record = make(map[string]string, len(cols)) + err = rs.ScanMap(&record) + if err != nil { + return nil, err + } + results = append(results, record) + } + return results, nil +} + +// scan data to a struct's pointer according field index +func (rs *Rows) ScanStructByIndex(dest ...interface{}) error { + if len(dest) == 0 { + return errors.New("at least one struct") + } + + vvvs := make([]reflect.Value, len(dest)) + for i, s := range dest { + vv := reflect.ValueOf(s) + if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct { + return errors.New("dest should be a struct's pointer") + } + + vvvs[i] = vv.Elem() + } + + cols, err := rs.Columns() + if err != nil { + return err + } + newDest := make([]interface{}, len(cols)) + + var i = 0 + for _, vvv := range vvvs { + for j := 0; j < vvv.NumField(); j++ { + newDest[i] = vvv.Field(j).Addr().Interface() + i = i + 1 + } + } + + return rs.Rows.Scan(newDest...) +} + +var ( + fieldCache = make(map[reflect.Type]map[string]int) + fieldCacheMutex sync.RWMutex +) + +func fieldByName(v reflect.Value, name string) reflect.Value { + t := v.Type() + fieldCacheMutex.RLock() + cache, ok := fieldCache[t] + fieldCacheMutex.RUnlock() + if !ok { + cache = make(map[string]int) + for i := 0; i < v.NumField(); i++ { + cache[t.Field(i).Name] = i + } + fieldCacheMutex.Lock() + fieldCache[t] = cache + fieldCacheMutex.Unlock() + } + + if i, ok := cache[name]; ok { + return v.Field(i) + } + + return reflect.Zero(t) +} + +// scan data to a struct's pointer according field name +func (rs *Rows) ScanStructByName(dest interface{}) error { + vv := reflect.ValueOf(dest) + if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct { + return errors.New("dest should be a struct's pointer") + } + + cols, err := rs.Columns() + if err != nil { + return err + } + + newDest := make([]interface{}, len(cols)) + var v EmptyScanner + for j, name := range cols { + f := fieldByName(vv.Elem(), rs.db.Mapper.Table2Obj(name)) + if f.IsValid() { + newDest[j] = f.Addr().Interface() + } else { + newDest[j] = &v + } + } + + return rs.Rows.Scan(newDest...) +} + +// scan data to a slice's pointer, slice's length should equal to columns' number +func (rs *Rows) ScanSlice(dest interface{}) error { + vv := reflect.ValueOf(dest) + if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Slice { + return errors.New("dest should be a slice's pointer") + } + + vvv := vv.Elem() + cols, err := rs.Columns() + if err != nil { + return err + } + + newDest := make([]interface{}, len(cols)) + + for j := 0; j < len(cols); j++ { + if j >= vvv.Len() { + newDest[j] = reflect.New(vvv.Type().Elem()).Interface() + } else { + newDest[j] = vvv.Index(j).Addr().Interface() + } + } + + err = rs.Rows.Scan(newDest...) + if err != nil { + return err + } + + srcLen := vvv.Len() + for i := srcLen; i < len(cols); i++ { + vvv = reflect.Append(vvv, reflect.ValueOf(newDest[i]).Elem()) + } + return nil +} + +// scan data to a map's pointer +func (rs *Rows) ScanMap(dest interface{}) error { + vv := reflect.ValueOf(dest) + if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map { + return errors.New("dest should be a map's pointer") + } + + cols, err := rs.Columns() + if err != nil { + return err + } + + newDest := make([]interface{}, len(cols)) + vvv := vv.Elem() + + for i := range cols { + newDest[i] = rs.db.reflectNew(vvv.Type().Elem()).Interface() + } + + err = rs.Rows.Scan(newDest...) + if err != nil { + return err + } + + for i, name := range cols { + vname := reflect.ValueOf(name) + vvv.SetMapIndex(vname, reflect.ValueOf(newDest[i]).Elem()) + } + + return nil +} + +type Row struct { + rows *Rows + // One of these two will be non-nil: + err error // deferred error for easy chaining +} + +// ErrorRow return an error row +func ErrorRow(err error) *Row { + return &Row{ + err: err, + } +} + +// NewRow from rows +func NewRow(rows *Rows, err error) *Row { + return &Row{rows, err} +} + +func (row *Row) Columns() ([]string, error) { + if row.err != nil { + return nil, row.err + } + return row.rows.Columns() +} + +func (row *Row) Scan(dest ...interface{}) error { + if row.err != nil { + return row.err + } + defer row.rows.Close() + + for _, dp := range dest { + if _, ok := dp.(*sql.RawBytes); ok { + return errors.New("sql: RawBytes isn't allowed on Row.Scan") + } + } + + if !row.rows.Next() { + if err := row.rows.Err(); err != nil { + return err + } + return sql.ErrNoRows + } + err := row.rows.Scan(dest...) + if err != nil { + return err + } + // Make sure the query can be processed to completion with no errors. + return row.rows.Close() +} + +func (row *Row) ScanStructByName(dest interface{}) error { + if row.err != nil { + return row.err + } + defer row.rows.Close() + + if !row.rows.Next() { + if err := row.rows.Err(); err != nil { + return err + } + return sql.ErrNoRows + } + err := row.rows.ScanStructByName(dest) + if err != nil { + return err + } + // Make sure the query can be processed to completion with no errors. + return row.rows.Close() +} + +func (row *Row) ScanStructByIndex(dest interface{}) error { + if row.err != nil { + return row.err + } + defer row.rows.Close() + + if !row.rows.Next() { + if err := row.rows.Err(); err != nil { + return err + } + return sql.ErrNoRows + } + err := row.rows.ScanStructByIndex(dest) + if err != nil { + return err + } + // Make sure the query can be processed to completion with no errors. + return row.rows.Close() +} + +// scan data to a slice's pointer, slice's length should equal to columns' number +func (row *Row) ScanSlice(dest interface{}) error { + if row.err != nil { + return row.err + } + defer row.rows.Close() + + if !row.rows.Next() { + if err := row.rows.Err(); err != nil { + return err + } + return sql.ErrNoRows + } + err := row.rows.ScanSlice(dest) + if err != nil { + return err + } + + // Make sure the query can be processed to completion with no errors. + return row.rows.Close() +} + +// scan data to a map's pointer +func (row *Row) ScanMap(dest interface{}) error { + if row.err != nil { + return row.err + } + defer row.rows.Close() + + if !row.rows.Next() { + if err := row.rows.Err(); err != nil { + return err + } + return sql.ErrNoRows + } + err := row.rows.ScanMap(dest) + if err != nil { + return err + } + + // Make sure the query can be processed to completion with no errors. + return row.rows.Close() +} + +func (row *Row) ToMapString() (map[string]string, error) { + cols, err := row.Columns() + if err != nil { + return nil, err + } + + var record = make(map[string]string, len(cols)) + err = row.ScanMap(&record) + if err != nil { + return nil, err + } + + return record, nil +} diff --git a/core/scan.go b/core/scan.go new file mode 100644 index 00000000..897b5341 --- /dev/null +++ b/core/scan.go @@ -0,0 +1,66 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package core + +import ( + "database/sql/driver" + "fmt" + "time" +) + +type NullTime time.Time + +var ( + _ driver.Valuer = NullTime{} +) + +func (ns *NullTime) Scan(value interface{}) error { + if value == nil { + return nil + } + return convertTime(ns, value) +} + +// Value implements the driver Valuer interface. +func (ns NullTime) Value() (driver.Value, error) { + if (time.Time)(ns).IsZero() { + return nil, nil + } + return (time.Time)(ns).Format("2006-01-02 15:04:05"), nil +} + +func convertTime(dest *NullTime, src interface{}) error { + // Common cases, without reflect. + switch s := src.(type) { + case string: + t, err := time.Parse("2006-01-02 15:04:05", s) + if err != nil { + return err + } + *dest = NullTime(t) + return nil + case []uint8: + t, err := time.Parse("2006-01-02 15:04:05", string(s)) + if err != nil { + return err + } + *dest = NullTime(t) + return nil + case time.Time: + *dest = NullTime(s) + return nil + case nil: + default: + return fmt.Errorf("unsupported driver -> Scan pair: %T -> %T", src, dest) + } + return nil +} + +type EmptyScanner struct { +} + +func (EmptyScanner) Scan(src interface{}) error { + return nil +} diff --git a/core/stmt.go b/core/stmt.go new file mode 100644 index 00000000..d46ac9c6 --- /dev/null +++ b/core/stmt.go @@ -0,0 +1,194 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package core + +import ( + "context" + "database/sql" + "errors" + "reflect" + + "xorm.io/xorm/contexts" +) + +// Stmt reprents a stmt objects +type Stmt struct { + *sql.Stmt + db *DB + names map[string]int + query string +} + +func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) { + names := make(map[string]int) + var i int + query = re.ReplaceAllStringFunc(query, func(src string) string { + names[src[1:]] = i + i++ + return "?" + }) + hookCtx := contexts.NewContextHook(ctx, "PREPARE", nil) + ctx, err := db.beforeProcess(hookCtx) + if err != nil { + return nil, err + } + stmt, err := db.DB.PrepareContext(ctx, query) + hookCtx.End(ctx, nil, err) + if err := db.afterProcess(hookCtx); err != nil { + return nil, err + } + return &Stmt{stmt, db, names, query}, nil +} + +func (db *DB) Prepare(query string) (*Stmt, error) { + return db.PrepareContext(context.Background(), query) +} + +func (s *Stmt) ExecMapContext(ctx context.Context, mp interface{}) (sql.Result, error) { + vv := reflect.ValueOf(mp) + if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map { + return nil, errors.New("mp should be a map's pointer") + } + + args := make([]interface{}, len(s.names)) + for k, i := range s.names { + args[i] = vv.Elem().MapIndex(reflect.ValueOf(k)).Interface() + } + return s.ExecContext(ctx, args...) +} + +func (s *Stmt) ExecMap(mp interface{}) (sql.Result, error) { + return s.ExecMapContext(context.Background(), mp) +} + +func (s *Stmt) ExecStructContext(ctx context.Context, st interface{}) (sql.Result, error) { + vv := reflect.ValueOf(st) + if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct { + return nil, errors.New("mp should be a map's pointer") + } + + args := make([]interface{}, len(s.names)) + for k, i := range s.names { + args[i] = vv.Elem().FieldByName(k).Interface() + } + return s.ExecContext(ctx, args...) +} + +func (s *Stmt) ExecStruct(st interface{}) (sql.Result, error) { + return s.ExecStructContext(context.Background(), st) +} + +func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) { + hookCtx := contexts.NewContextHook(ctx, s.query, args) + ctx, err := s.db.beforeProcess(hookCtx) + if err != nil { + return nil, err + } + res, err := s.Stmt.ExecContext(ctx, args) + hookCtx.End(ctx, res, err) + if err := s.db.afterProcess(hookCtx); err != nil { + return nil, err + } + return res, nil +} + +func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, error) { + hookCtx := contexts.NewContextHook(ctx, s.query, args) + ctx, err := s.db.beforeProcess(hookCtx) + if err != nil { + return nil, err + } + rows, err := s.Stmt.QueryContext(ctx, args...) + hookCtx.End(ctx, nil, err) + if err := s.db.afterProcess(hookCtx); err != nil { + return nil, err + } + return &Rows{rows, s.db}, nil +} + +func (s *Stmt) Query(args ...interface{}) (*Rows, error) { + return s.QueryContext(context.Background(), args...) +} + +func (s *Stmt) QueryMapContext(ctx context.Context, mp interface{}) (*Rows, error) { + vv := reflect.ValueOf(mp) + if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map { + return nil, errors.New("mp should be a map's pointer") + } + + args := make([]interface{}, len(s.names)) + for k, i := range s.names { + args[i] = vv.Elem().MapIndex(reflect.ValueOf(k)).Interface() + } + + return s.QueryContext(ctx, args...) +} + +func (s *Stmt) QueryMap(mp interface{}) (*Rows, error) { + return s.QueryMapContext(context.Background(), mp) +} + +func (s *Stmt) QueryStructContext(ctx context.Context, st interface{}) (*Rows, error) { + vv := reflect.ValueOf(st) + if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct { + return nil, errors.New("mp should be a map's pointer") + } + + args := make([]interface{}, len(s.names)) + for k, i := range s.names { + args[i] = vv.Elem().FieldByName(k).Interface() + } + + return s.QueryContext(ctx, args...) +} + +func (s *Stmt) QueryStruct(st interface{}) (*Rows, error) { + return s.QueryStructContext(context.Background(), st) +} + +func (s *Stmt) QueryRowContext(ctx context.Context, args ...interface{}) *Row { + rows, err := s.QueryContext(ctx, args...) + return &Row{rows, err} +} + +func (s *Stmt) QueryRow(args ...interface{}) *Row { + return s.QueryRowContext(context.Background(), args...) +} + +func (s *Stmt) QueryRowMapContext(ctx context.Context, mp interface{}) *Row { + vv := reflect.ValueOf(mp) + if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map { + return &Row{nil, errors.New("mp should be a map's pointer")} + } + + args := make([]interface{}, len(s.names)) + for k, i := range s.names { + args[i] = vv.Elem().MapIndex(reflect.ValueOf(k)).Interface() + } + + return s.QueryRowContext(ctx, args...) +} + +func (s *Stmt) QueryRowMap(mp interface{}) *Row { + return s.QueryRowMapContext(context.Background(), mp) +} + +func (s *Stmt) QueryRowStructContext(ctx context.Context, st interface{}) *Row { + vv := reflect.ValueOf(st) + if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct { + return &Row{nil, errors.New("st should be a struct's pointer")} + } + + args := make([]interface{}, len(s.names)) + for k, i := range s.names { + args[i] = vv.Elem().FieldByName(k).Interface() + } + + return s.QueryRowContext(ctx, args...) +} + +func (s *Stmt) QueryRowStruct(st interface{}) *Row { + return s.QueryRowStructContext(context.Background(), st) +} diff --git a/core/tx.go b/core/tx.go new file mode 100644 index 00000000..9b2988af --- /dev/null +++ b/core/tx.go @@ -0,0 +1,190 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package core + +import ( + "context" + "database/sql" + + "xorm.io/xorm/contexts" +) + +var ( + _ QueryExecuter = &Tx{} +) + +// Tx represents a transaction +type Tx struct { + *sql.Tx + db *DB +} + +func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) { + hookCtx := contexts.NewContextHook(ctx, "BEGIN TRANSACTION", nil) + ctx, err := db.beforeProcess(hookCtx) + if err != nil { + return nil, err + } + tx, err := db.DB.BeginTx(ctx, opts) + hookCtx.End(ctx, nil, err) + if err := db.afterProcess(hookCtx); err != nil { + return nil, err + } + return &Tx{tx, db}, nil +} + +func (db *DB) Begin() (*Tx, error) { + return db.BeginTx(context.Background(), nil) +} + +func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) { + names := make(map[string]int) + var i int + query = re.ReplaceAllStringFunc(query, func(src string) string { + names[src[1:]] = i + i++ + return "?" + }) + hookCtx := contexts.NewContextHook(ctx, "PREPARE", nil) + ctx, err := tx.db.beforeProcess(hookCtx) + if err != nil { + return nil, err + } + stmt, err := tx.Tx.PrepareContext(ctx, query) + hookCtx.End(ctx, nil, err) + if err := tx.db.afterProcess(hookCtx); err != nil { + return nil, err + } + return &Stmt{stmt, tx.db, names, query}, nil +} + +func (tx *Tx) Prepare(query string) (*Stmt, error) { + return tx.PrepareContext(context.Background(), query) +} + +func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt { + stmt.Stmt = tx.Tx.StmtContext(ctx, stmt.Stmt) + return stmt +} + +func (tx *Tx) Stmt(stmt *Stmt) *Stmt { + return tx.StmtContext(context.Background(), stmt) +} + +func (tx *Tx) ExecMapContext(ctx context.Context, query string, mp interface{}) (sql.Result, error) { + query, args, err := MapToSlice(query, mp) + if err != nil { + return nil, err + } + return tx.ExecContext(ctx, query, args...) +} + +func (tx *Tx) ExecMap(query string, mp interface{}) (sql.Result, error) { + return tx.ExecMapContext(context.Background(), query, mp) +} + +func (tx *Tx) ExecStructContext(ctx context.Context, query string, st interface{}) (sql.Result, error) { + query, args, err := StructToSlice(query, st) + if err != nil { + return nil, err + } + return tx.ExecContext(ctx, query, args...) +} + +func (tx *Tx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + hookCtx := contexts.NewContextHook(ctx, query, args) + ctx, err := tx.db.beforeProcess(hookCtx) + if err != nil { + return nil, err + } + res, err := tx.Tx.ExecContext(ctx, query, args...) + hookCtx.End(ctx, res, err) + if err := tx.db.afterProcess(hookCtx); err != nil { + return nil, err + } + return res, err +} + +func (tx *Tx) ExecStruct(query string, st interface{}) (sql.Result, error) { + return tx.ExecStructContext(context.Background(), query, st) +} + +func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { + hookCtx := contexts.NewContextHook(ctx, query, args) + ctx, err := tx.db.beforeProcess(hookCtx) + if err != nil { + return nil, err + } + rows, err := tx.Tx.QueryContext(ctx, query, args...) + hookCtx.End(ctx, nil, err) + if err := tx.db.afterProcess(hookCtx); err != nil { + if rows != nil { + rows.Close() + } + return nil, err + } + return &Rows{rows, tx.db}, nil +} + +func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) { + return tx.QueryContext(context.Background(), query, args...) +} + +func (tx *Tx) QueryMapContext(ctx context.Context, query string, mp interface{}) (*Rows, error) { + query, args, err := MapToSlice(query, mp) + if err != nil { + return nil, err + } + return tx.QueryContext(ctx, query, args...) +} + +func (tx *Tx) QueryMap(query string, mp interface{}) (*Rows, error) { + return tx.QueryMapContext(context.Background(), query, mp) +} + +func (tx *Tx) QueryStructContext(ctx context.Context, query string, st interface{}) (*Rows, error) { + query, args, err := StructToSlice(query, st) + if err != nil { + return nil, err + } + return tx.QueryContext(ctx, query, args...) +} + +func (tx *Tx) QueryStruct(query string, st interface{}) (*Rows, error) { + return tx.QueryStructContext(context.Background(), query, st) +} + +func (tx *Tx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *Row { + rows, err := tx.QueryContext(ctx, query, args...) + return &Row{rows, err} +} + +func (tx *Tx) QueryRow(query string, args ...interface{}) *Row { + return tx.QueryRowContext(context.Background(), query, args...) +} + +func (tx *Tx) QueryRowMapContext(ctx context.Context, query string, mp interface{}) *Row { + query, args, err := MapToSlice(query, mp) + if err != nil { + return &Row{nil, err} + } + return tx.QueryRowContext(ctx, query, args...) +} + +func (tx *Tx) QueryRowMap(query string, mp interface{}) *Row { + return tx.QueryRowMapContext(context.Background(), query, mp) +} + +func (tx *Tx) QueryRowStructContext(ctx context.Context, query string, st interface{}) *Row { + query, args, err := StructToSlice(query, st) + if err != nil { + return &Row{nil, err} + } + return tx.QueryRowContext(ctx, query, args...) +} + +func (tx *Tx) QueryRowStruct(query string, st interface{}) *Row { + return tx.QueryRowStructContext(context.Background(), query, st) +} diff --git a/dialects/dialect.go b/dialects/dialect.go new file mode 100644 index 00000000..dc96f73a --- /dev/null +++ b/dialects/dialect.go @@ -0,0 +1,284 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package dialects + +import ( + "context" + "fmt" + "strings" + "time" + + "xorm.io/xorm/core" + "xorm.io/xorm/schemas" +) + +// URI represents an uri to visit database +type URI struct { + DBType schemas.DBType + Proto string + Host string + Port string + DBName string + User string + Passwd string + Charset string + Laddr string + Raddr string + Timeout time.Duration + Schema string +} + +// SetSchema set schema +func (uri *URI) SetSchema(schema string) { + // hack me + if uri.DBType == schemas.POSTGRES { + uri.Schema = strings.TrimSpace(schema) + } +} + +// Dialect represents a kind of database +type Dialect interface { + Init(*URI) error + URI() *URI + SQLType(*schemas.Column) string + FormatBytes(b []byte) string + + IsReserved(string) bool + Quoter() schemas.Quoter + SetQuotePolicy(quotePolicy QuotePolicy) + + AutoIncrStr() string + + GetIndexes(queryer core.Queryer, ctx context.Context, tableName string) (map[string]*schemas.Index, error) + IndexCheckSQL(tableName, idxName string) (string, []interface{}) + CreateIndexSQL(tableName string, index *schemas.Index) string + DropIndexSQL(tableName string, index *schemas.Index) string + + GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) + IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error) + CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) + DropTableSQL(tableName string) (string, bool) + + GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) + IsColumnExist(queryer core.Queryer, ctx context.Context, tableName string, colName string) (bool, error) + AddColumnSQL(tableName string, col *schemas.Column) string + ModifyColumnSQL(tableName string, col *schemas.Column) string + + ForUpdateSQL(query string) string + + Filters() []Filter + SetParams(params map[string]string) +} + +// Base represents a basic dialect and all real dialects could embed this struct +type Base struct { + dialect Dialect + uri *URI + quoter schemas.Quoter +} + +func (b *Base) Quoter() schemas.Quoter { + return b.quoter +} + +func (b *Base) Init(dialect Dialect, uri *URI) error { + b.dialect, b.uri = dialect, uri + return nil +} + +func (b *Base) URI() *URI { + return b.uri +} + +func (b *Base) DBType() schemas.DBType { + return b.uri.DBType +} + +func (b *Base) FormatBytes(bs []byte) string { + return fmt.Sprintf("0x%x", bs) +} + +func (db *Base) DropTableSQL(tableName string) (string, bool) { + quote := db.dialect.Quoter().Quote + return fmt.Sprintf("DROP TABLE IF EXISTS %s", quote(tableName)), true +} + +func (db *Base) HasRecords(queryer core.Queryer, ctx context.Context, query string, args ...interface{}) (bool, error) { + rows, err := queryer.QueryContext(ctx, query, args...) + if err != nil { + return false, err + } + defer rows.Close() + + if rows.Next() { + return true, nil + } + return false, nil +} + +func (db *Base) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName, colName string) (bool, error) { + quote := db.dialect.Quoter().Quote + query := fmt.Sprintf( + "SELECT %v FROM %v.%v WHERE %v = ? AND %v = ? AND %v = ?", + quote("COLUMN_NAME"), + quote("INFORMATION_SCHEMA"), + quote("COLUMNS"), + quote("TABLE_SCHEMA"), + quote("TABLE_NAME"), + quote("COLUMN_NAME"), + ) + return db.HasRecords(queryer, ctx, query, db.uri.DBName, tableName, colName) +} + +func (db *Base) AddColumnSQL(tableName string, col *schemas.Column) string { + s, _ := ColumnString(db.dialect, col, true) + return fmt.Sprintf("ALTER TABLE %v ADD %v", db.dialect.Quoter().Quote(tableName), s) +} + +func (db *Base) CreateIndexSQL(tableName string, index *schemas.Index) string { + quoter := db.dialect.Quoter() + var unique string + var idxName string + if index.Type == schemas.UniqueType { + unique = " UNIQUE" + } + idxName = index.XName(tableName) + return fmt.Sprintf("CREATE%s INDEX %v ON %v (%v)", unique, + quoter.Quote(idxName), quoter.Quote(tableName), + quoter.Join(index.Cols, ",")) +} + +func (db *Base) DropIndexSQL(tableName string, index *schemas.Index) string { + quote := db.dialect.Quoter().Quote + var name string + if index.IsRegular { + name = index.XName(tableName) + } else { + name = index.Name + } + return fmt.Sprintf("DROP INDEX %v ON %s", quote(name), quote(tableName)) +} + +func (db *Base) ModifyColumnSQL(tableName string, col *schemas.Column) string { + s, _ := ColumnString(db.dialect, col, false) + return fmt.Sprintf("alter table %s MODIFY COLUMN %s", tableName, s) +} + +func (b *Base) ForUpdateSQL(query string) string { + return query + " FOR UPDATE" +} + +func (b *Base) SetParams(params map[string]string) { +} + +var ( + dialects = map[string]func() Dialect{} +) + +// RegisterDialect register database dialect +func RegisterDialect(dbName schemas.DBType, dialectFunc func() Dialect) { + if dialectFunc == nil { + panic("core: Register dialect is nil") + } + dialects[strings.ToLower(string(dbName))] = dialectFunc // !nashtsai! allow override dialect +} + +// QueryDialect query if registered database dialect +func QueryDialect(dbName schemas.DBType) Dialect { + if d, ok := dialects[strings.ToLower(string(dbName))]; ok { + return d() + } + return nil +} + +func regDrvsNDialects() bool { + providedDrvsNDialects := map[string]struct { + dbType schemas.DBType + getDriver func() Driver + getDialect func() Dialect + }{ + "mssql": {"mssql", func() Driver { return &odbcDriver{} }, func() Dialect { return &mssql{} }}, + "odbc": {"mssql", func() Driver { return &odbcDriver{} }, func() Dialect { return &mssql{} }}, // !nashtsai! TODO change this when supporting MS Access + "mysql": {"mysql", func() Driver { return &mysqlDriver{} }, func() Dialect { return &mysql{} }}, + "mymysql": {"mysql", func() Driver { return &mymysqlDriver{} }, func() Dialect { return &mysql{} }}, + "postgres": {"postgres", func() Driver { return &pqDriver{} }, func() Dialect { return &postgres{} }}, + "pgx": {"postgres", func() Driver { return &pqDriverPgx{} }, func() Dialect { return &postgres{} }}, + "sqlite3": {"sqlite3", func() Driver { return &sqlite3Driver{} }, func() Dialect { return &sqlite3{} }}, + "oci8": {"oracle", func() Driver { return &oci8Driver{} }, func() Dialect { return &oracle{} }}, + "goracle": {"oracle", func() Driver { return &goracleDriver{} }, func() Dialect { return &oracle{} }}, + } + + for driverName, v := range providedDrvsNDialects { + if driver := QueryDriver(driverName); driver == nil { + RegisterDriver(driverName, v.getDriver()) + RegisterDialect(v.dbType, v.getDialect) + } + } + return true +} + +func init() { + regDrvsNDialects() +} + +// ColumnString generate column description string according dialect +func ColumnString(dialect Dialect, col *schemas.Column, includePrimaryKey bool) (string, error) { + bd := strings.Builder{} + + if err := dialect.Quoter().QuoteTo(&bd, col.Name); err != nil { + return "", err + } + + if err := bd.WriteByte(' '); err != nil { + return "", err + } + + if _, err := bd.WriteString(dialect.SQLType(col)); err != nil { + return "", err + } + + if err := bd.WriteByte(' '); err != nil { + return "", err + } + + if includePrimaryKey && col.IsPrimaryKey { + if _, err := bd.WriteString("PRIMARY KEY "); err != nil { + return "", err + } + + if col.IsAutoIncrement { + if _, err := bd.WriteString(dialect.AutoIncrStr()); err != nil { + return "", err + } + if err := bd.WriteByte(' '); err != nil { + return "", err + } + } + } + + if col.Default != "" { + if _, err := bd.WriteString("DEFAULT "); err != nil { + return "", err + } + if _, err := bd.WriteString(col.Default); err != nil { + return "", err + } + if err := bd.WriteByte(' '); err != nil { + return "", err + } + } + + if col.Nullable { + if _, err := bd.WriteString("NULL "); err != nil { + return "", err + } + } else { + if _, err := bd.WriteString("NOT NULL "); err != nil { + return "", err + } + } + + return bd.String(), nil +} diff --git a/dialects/driver.go b/dialects/driver.go new file mode 100644 index 00000000..ae3afe42 --- /dev/null +++ b/dialects/driver.go @@ -0,0 +1,57 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package dialects + +import ( + "fmt" +) + +type Driver interface { + Parse(string, string) (*URI, error) +} + +var ( + drivers = map[string]Driver{} +) + +func RegisterDriver(driverName string, driver Driver) { + if driver == nil { + panic("core: Register driver is nil") + } + if _, dup := drivers[driverName]; dup { + panic("core: Register called twice for driver " + driverName) + } + drivers[driverName] = driver +} + +func QueryDriver(driverName string) Driver { + return drivers[driverName] +} + +func RegisteredDriverSize() int { + return len(drivers) +} + +// OpenDialect opens a dialect via driver name and connection string +func OpenDialect(driverName, connstr string) (Dialect, error) { + driver := QueryDriver(driverName) + if driver == nil { + return nil, fmt.Errorf("Unsupported driver name: %v", driverName) + } + + uri, err := driver.Parse(driverName, connstr) + if err != nil { + return nil, err + } + + dialect := QueryDialect(uri.DBType) + if dialect == nil { + return nil, fmt.Errorf("Unsupported dialect type: %v", uri.DBType) + } + + dialect.Init(uri) + + return dialect, nil +} diff --git a/dialects/filter.go b/dialects/filter.go new file mode 100644 index 00000000..6968b6ce --- /dev/null +++ b/dialects/filter.go @@ -0,0 +1,43 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package dialects + +import ( + "fmt" + "strings" +) + +// Filter is an interface to filter SQL +type Filter interface { + Do(sql string) string +} + +// SeqFilter filter SQL replace ?, ? ... to $1, $2 ... +type SeqFilter struct { + Prefix string + Start int +} + +func convertQuestionMark(sql, prefix string, start int) string { + var buf strings.Builder + var beginSingleQuote bool + var index = start + for _, c := range sql { + if !beginSingleQuote && c == '?' { + buf.WriteString(fmt.Sprintf("%s%v", prefix, index)) + index++ + } else { + if c == '\'' { + beginSingleQuote = !beginSingleQuote + } + buf.WriteRune(c) + } + } + return buf.String() +} + +func (s *SeqFilter) Do(sql string) string { + return convertQuestionMark(sql, s.Prefix, s.Start) +} diff --git a/dialects/filter_test.go b/dialects/filter_test.go new file mode 100644 index 00000000..7e2ef0a2 --- /dev/null +++ b/dialects/filter_test.go @@ -0,0 +1,21 @@ +package dialects + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSeqFilter(t *testing.T) { + var kases = map[string]string{ + "SELECT * FROM TABLE1 WHERE a=? AND b=?": "SELECT * FROM TABLE1 WHERE a=$1 AND b=$2", + "SELECT 1, '???', '2006-01-02 15:04:05' FROM TABLE1 WHERE a=? AND b=?": "SELECT 1, '???', '2006-01-02 15:04:05' FROM TABLE1 WHERE a=$1 AND b=$2", + "select '1''?' from issue": "select '1''?' from issue", + "select '1\\??' from issue": "select '1\\??' from issue", + "select '1\\\\',? from issue": "select '1\\\\',$1 from issue", + "select '1\\''?',? from issue": "select '1\\''?',$1 from issue", + } + for sql, result := range kases { + assert.EqualValues(t, result, convertQuestionMark(sql, "$", 1)) + } +} diff --git a/gen_reserved.sh b/dialects/gen_reserved.sh similarity index 100% rename from gen_reserved.sh rename to dialects/gen_reserved.sh diff --git a/dialect_mssql.go b/dialects/mssql.go similarity index 77% rename from dialect_mssql.go rename to dialects/mssql.go index 29070da2..8ef924b8 100644 --- a/dialect_mssql.go +++ b/dialects/mssql.go @@ -2,16 +2,18 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package dialects import ( + "context" "errors" "fmt" "net/url" "strconv" "strings" - "xorm.io/core" + "xorm.io/xorm/core" + "xorm.io/xorm/schemas" ) var ( @@ -202,67 +204,74 @@ var ( "EXIT": true, "PROC": true, } + + mssqlQuoter = schemas.Quoter{ + Prefix: '[', + Suffix: ']', + IsReserved: schemas.AlwaysReserve, + } ) type mssql struct { - core.Base + Base } -func (db *mssql) Init(d *core.DB, uri *core.Uri, drivername, dataSourceName string) error { - return db.Base.Init(d, db, uri, drivername, dataSourceName) +func (db *mssql) Init(uri *URI) error { + db.quoter = mssqlQuoter + return db.Base.Init(db, uri) } -func (db *mssql) SqlType(c *core.Column) string { +func (db *mssql) SQLType(c *schemas.Column) string { var res string switch t := c.SQLType.Name; t { - case core.Bool: - res = core.Bit + case schemas.Bool: + res = schemas.Bit if strings.EqualFold(c.Default, "true") { c.Default = "1" } else if strings.EqualFold(c.Default, "false") { c.Default = "0" } - case core.Serial: + case schemas.Serial: c.IsAutoIncrement = true c.IsPrimaryKey = true c.Nullable = false - res = core.Int - case core.BigSerial: + res = schemas.Int + case schemas.BigSerial: c.IsAutoIncrement = true c.IsPrimaryKey = true c.Nullable = false - res = core.BigInt - case core.Bytea, core.Blob, core.Binary, core.TinyBlob, core.MediumBlob, core.LongBlob: - res = core.VarBinary + res = schemas.BigInt + case schemas.Bytea, schemas.Blob, schemas.Binary, schemas.TinyBlob, schemas.MediumBlob, schemas.LongBlob: + res = schemas.VarBinary if c.Length == 0 { c.Length = 50 } - case core.TimeStamp: - res = core.DateTime - case core.TimeStampz: + case schemas.TimeStamp: + res = schemas.DateTime + case schemas.TimeStampz: res = "DATETIMEOFFSET" c.Length = 7 - case core.MediumInt: - res = core.Int - case core.Text, core.MediumText, core.TinyText, core.LongText, core.Json: - res = core.Varchar + "(MAX)" - case core.Double: - res = core.Real - case core.Uuid: - res = core.Varchar + case schemas.MediumInt: + res = schemas.Int + case schemas.Text, schemas.MediumText, schemas.TinyText, schemas.LongText, schemas.Json: + res = schemas.Varchar + "(MAX)" + case schemas.Double: + res = schemas.Real + case schemas.Uuid: + res = schemas.Varchar c.Length = 40 - case core.TinyInt: - res = core.TinyInt + case schemas.TinyInt: + res = schemas.TinyInt c.Length = 0 - case core.BigInt: - res = core.BigInt + case schemas.BigInt: + res = schemas.BigInt c.Length = 0 default: res = t } - if res == core.Int { - return core.Int + if res == schemas.Int { + return schemas.Int } hasLen1 := (c.Length > 0) @@ -276,88 +285,78 @@ func (db *mssql) SqlType(c *core.Column) string { return res } -func (db *mssql) SupportInsertMany() bool { - return true -} - func (db *mssql) IsReserved(name string) bool { - _, ok := mssqlReservedWords[name] + _, ok := mssqlReservedWords[strings.ToUpper(name)] return ok } -func (db *mssql) Quote(name string) string { - return "\"" + name + "\"" -} - -func (db *mssql) SupportEngine() bool { - return false +func (db *mssql) SetQuotePolicy(quotePolicy QuotePolicy) { + switch quotePolicy { + case QuotePolicyNone: + var q = mssqlQuoter + q.IsReserved = schemas.AlwaysNoReserve + db.quoter = q + case QuotePolicyReserved: + var q = mssqlQuoter + q.IsReserved = db.IsReserved + db.quoter = q + case QuotePolicyAlways: + fallthrough + default: + db.quoter = mssqlQuoter + } } func (db *mssql) AutoIncrStr() string { return "IDENTITY" } -func (db *mssql) DropTableSql(tableName string) string { +func (db *mssql) DropTableSQL(tableName string) (string, bool) { return fmt.Sprintf("IF EXISTS (SELECT * FROM sysobjects WHERE id = "+ "object_id(N'%s') and OBJECTPROPERTY(id, N'IsUserTable') = 1) "+ - "DROP TABLE \"%s\"", tableName, tableName) + "DROP TABLE \"%s\"", tableName, tableName), true } -func (db *mssql) SupportCharset() bool { - return false -} - -func (db *mssql) IndexOnTable() bool { - return true -} - -func (db *mssql) IndexCheckSql(tableName, idxName string) (string, []interface{}) { +func (db *mssql) IndexCheckSQL(tableName, idxName string) (string, []interface{}) { args := []interface{}{idxName} sql := "select name from sysindexes where id=object_id('" + tableName + "') and name=?" return sql, args } -/*func (db *mssql) ColumnCheckSql(tableName, colName string) (string, []interface{}) { - args := []interface{}{tableName, colName} - sql := `SELECT "COLUMN_NAME" FROM "INFORMATION_SCHEMA"."COLUMNS" WHERE "TABLE_NAME" = ? AND "COLUMN_NAME" = ?` - return sql, args -}*/ - -func (db *mssql) IsColumnExist(tableName, colName string) (bool, error) { +func (db *mssql) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName, colName string) (bool, error) { query := `SELECT "COLUMN_NAME" FROM "INFORMATION_SCHEMA"."COLUMNS" WHERE "TABLE_NAME" = ? AND "COLUMN_NAME" = ?` - return db.HasRecords(query, tableName, colName) + return db.HasRecords(queryer, ctx, query, tableName, colName) } -func (db *mssql) TableCheckSql(tableName string) (string, []interface{}) { - args := []interface{}{} +func (db *mssql) IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error) { sql := "select * from sysobjects where id = object_id(N'" + tableName + "') and OBJECTPROPERTY(id, N'IsUserTable') = 1" - return sql, args + return db.HasRecords(queryer, ctx, sql) } -func (db *mssql) GetColumns(tableName string) ([]string, map[string]*core.Column, error) { +func (db *mssql) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { args := []interface{}{} s := `select a.name as name, b.name as ctype,a.max_length,a.precision,a.scale,a.is_nullable as nullable, "default_is_null" = (CASE WHEN c.text is null THEN 1 ELSE 0 END), replace(replace(isnull(c.text,''),'(',''),')','') as vdefault, - ISNULL(i.is_primary_key, 0), a.is_identity as is_identity + ISNULL(p.is_primary_key, 0), a.is_identity as is_identity from sys.columns a left join sys.types b on a.user_type_id=b.user_type_id left join sys.syscomments c on a.default_object_id=c.id - LEFT OUTER JOIN - sys.index_columns ic ON ic.object_id = a.object_id AND ic.column_id = a.column_id - LEFT OUTER JOIN - sys.indexes i ON ic.object_id = i.object_id AND ic.index_id = i.index_id + LEFT OUTER JOIN (SELECT i.object_id, ic.column_id, i.is_primary_key + FROM sys.indexes i + LEFT JOIN sys.index_columns ic ON ic.object_id = i.object_id AND ic.index_id = i.index_id + WHERE i.is_primary_key = 1 + ) as p on p.object_id = a.object_id AND p.column_id = a.column_id where a.object_id=object_id('` + tableName + `')` - db.LogSQL(s, args) - rows, err := db.DB().Query(s, args...) + rows, err := queryer.QueryContext(ctx, s, args...) if err != nil { return nil, nil, err } defer rows.Close() - cols := make(map[string]*core.Column) + cols := make(map[string]*schemas.Column) colSeq := make([]string, 0) for rows.Next() { var name, ctype, vdefault string @@ -368,7 +367,7 @@ func (db *mssql) GetColumns(tableName string) ([]string, map[string]*core.Column return nil, nil, err } - col := new(core.Column) + col := new(schemas.Column) col.Indexes = make(map[string]int) col.Name = strings.Trim(name, "` ") col.Nullable = nullable @@ -387,14 +386,14 @@ func (db *mssql) GetColumns(tableName string) ([]string, map[string]*core.Column } switch ct { case "DATETIMEOFFSET": - col.SQLType = core.SQLType{Name: core.TimeStampz, DefaultLength: 0, DefaultLength2: 0} + col.SQLType = schemas.SQLType{Name: schemas.TimeStampz, DefaultLength: 0, DefaultLength2: 0} case "NVARCHAR": - col.SQLType = core.SQLType{Name: core.NVarchar, DefaultLength: 0, DefaultLength2: 0} + col.SQLType = schemas.SQLType{Name: schemas.NVarchar, DefaultLength: 0, DefaultLength2: 0} case "IMAGE": - col.SQLType = core.SQLType{Name: core.VarBinary, DefaultLength: 0, DefaultLength2: 0} + col.SQLType = schemas.SQLType{Name: schemas.VarBinary, DefaultLength: 0, DefaultLength2: 0} default: - if _, ok := core.SqlTypes[ct]; ok { - col.SQLType = core.SQLType{Name: ct, DefaultLength: 0, DefaultLength2: 0} + if _, ok := schemas.SqlTypes[ct]; ok { + col.SQLType = schemas.SQLType{Name: ct, DefaultLength: 0, DefaultLength2: 0} } else { return nil, nil, fmt.Errorf("Unknown colType %v for %v - %v", ct, tableName, col.Name) } @@ -406,20 +405,19 @@ func (db *mssql) GetColumns(tableName string) ([]string, map[string]*core.Column return colSeq, cols, nil } -func (db *mssql) GetTables() ([]*core.Table, error) { +func (db *mssql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) { args := []interface{}{} s := `select name from sysobjects where xtype ='U'` - db.LogSQL(s, args) - rows, err := db.DB().Query(s, args...) + rows, err := queryer.QueryContext(ctx, s, args...) if err != nil { return nil, err } defer rows.Close() - tables := make([]*core.Table, 0) + tables := make([]*schemas.Table, 0) for rows.Next() { - table := core.NewEmptyTable() + table := schemas.NewEmptyTable() var name string err = rows.Scan(&name) if err != nil { @@ -431,7 +429,7 @@ func (db *mssql) GetTables() ([]*core.Table, error) { return tables, nil } -func (db *mssql) GetIndexes(tableName string) (map[string]*core.Index, error) { +func (db *mssql) GetIndexes(queryer core.Queryer, ctx context.Context, tableName string) (map[string]*schemas.Index, error) { args := []interface{}{tableName} s := `SELECT IXS.NAME AS [INDEX_NAME], @@ -444,15 +442,14 @@ INNER JOIN SYS.COLUMNS C ON IXS.OBJECT_ID=C.OBJECT_ID AND IXCS.COLUMN_ID=C.COLUMN_ID WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =? ` - db.LogSQL(s, args) - rows, err := db.DB().Query(s, args...) + rows, err := queryer.QueryContext(ctx, s, args...) if err != nil { return nil, err } defer rows.Close() - indexes := make(map[string]*core.Index, 0) + indexes := make(map[string]*schemas.Index, 0) for rows.Next() { var indexType int var indexName, colName, isUnique string @@ -468,9 +465,9 @@ WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =? } if i { - indexType = core.UniqueType + indexType = schemas.UniqueType } else { - indexType = core.IndexType + indexType = schemas.IndexType } colName = strings.Trim(colName, "` ") @@ -480,10 +477,10 @@ WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =? isRegular = true } - var index *core.Index + var index *schemas.Index var ok bool if index, ok = indexes[indexName]; !ok { - index = new(core.Index) + index = new(schemas.Index) index.Type = indexType index.Name = indexName index.IsRegular = isRegular @@ -494,7 +491,7 @@ WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =? return indexes, nil } -func (db *mssql) CreateTableSql(table *core.Table, tableName, storeEngine, charset string) string { +func (db *mssql) CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) { var sql string if tableName == "" { tableName = table.Name @@ -502,17 +499,14 @@ func (db *mssql) CreateTableSql(table *core.Table, tableName, storeEngine, chars sql = "IF NOT EXISTS (SELECT [name] FROM sys.tables WHERE [name] = '" + tableName + "' ) CREATE TABLE " - sql += db.Quote(tableName) + " (" + sql += db.Quoter().Quote(tableName) + " (" pkList := table.PrimaryKeys for _, colName := range table.ColumnsSeq() { col := table.GetColumn(colName) - if col.IsPrimaryKey && len(pkList) == 1 { - sql += col.String(db) - } else { - sql += col.StringNoPk(db) - } + s, _ := ColumnString(db, col, col.IsPrimaryKey && len(pkList) == 1) + sql += s sql = strings.TrimSpace(sql) sql += ", " } @@ -525,21 +519,21 @@ func (db *mssql) CreateTableSql(table *core.Table, tableName, storeEngine, chars sql = sql[:len(sql)-2] + ")" sql += ";" - return sql + return []string{sql}, true } -func (db *mssql) ForUpdateSql(query string) string { +func (db *mssql) ForUpdateSQL(query string) string { return query } -func (db *mssql) Filters() []core.Filter { - return []core.Filter{&core.IdFilter{}, &core.QuoteFilter{}} +func (db *mssql) Filters() []Filter { + return []Filter{} } type odbcDriver struct { } -func (p *odbcDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) { +func (p *odbcDriver) Parse(driverName, dataSourceName string) (*URI, error) { var dbName string if strings.HasPrefix(dataSourceName, "sqlserver://") { @@ -563,5 +557,5 @@ func (p *odbcDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) if dbName == "" { return nil, errors.New("no db name provided") } - return &core.Uri{DbName: dbName, DbType: core.MSSQL}, nil + return &URI{DBName: dbName, DBType: schemas.MSSQL}, nil } diff --git a/dialect_mssql_test.go b/dialects/mssql_test.go similarity index 84% rename from dialect_mssql_test.go rename to dialects/mssql_test.go index acd1d059..168f1777 100644 --- a/dialect_mssql_test.go +++ b/dialects/mssql_test.go @@ -2,13 +2,11 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package dialects import ( "reflect" "testing" - - "xorm.io/core" ) func TestParseMSSQL(t *testing.T) { @@ -21,15 +19,15 @@ func TestParseMSSQL(t *testing.T) { {"server=localhost;user id=sa;password=yourStrong(!)Password;database=db", "db", true}, } - driver := core.QueryDriver("mssql") + driver := QueryDriver("mssql") for _, test := range tests { uri, err := driver.Parse("mssql", test.in) if err != nil && test.valid { t.Errorf("%q got unexpected error: %s", test.in, err) - } else if err == nil && !reflect.DeepEqual(test.expected, uri.DbName) { - t.Errorf("%q got: %#v want: %#v", test.in, uri.DbName, test.expected) + } else if err == nil && !reflect.DeepEqual(test.expected, uri.DBName) { + t.Errorf("%q got: %#v want: %#v", test.in, uri.DBName, test.expected) } } } diff --git a/dialect_mysql.go b/dialects/mysql.go similarity index 76% rename from dialect_mysql.go rename to dialects/mysql.go index cf1dbb6f..f9a2e943 100644 --- a/dialect_mysql.go +++ b/dialects/mysql.go @@ -2,9 +2,10 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package dialects import ( + "context" "crypto/tls" "errors" "fmt" @@ -13,7 +14,8 @@ import ( "strings" "time" - "xorm.io/core" + "xorm.io/xorm/core" + "xorm.io/xorm/schemas" ) var ( @@ -159,10 +161,16 @@ var ( "YEAR_MONTH": true, "ZEROFILL": true, } + + mysqlQuoter = schemas.Quoter{ + Prefix: '`', + Suffix: '`', + IsReserved: schemas.AlwaysReserve, + } ) type mysql struct { - core.Base + Base net string addr string params map[string]string @@ -175,8 +183,9 @@ type mysql struct { rowFormat string } -func (db *mysql) Init(d *core.DB, uri *core.Uri, drivername, dataSourceName string) error { - return db.Base.Init(d, db, uri, drivername, dataSourceName) +func (db *mysql) Init(uri *URI) error { + db.quoter = mysqlQuoter + return db.Base.Init(db, uri) } func (db *mysql) SetParams(params map[string]string) { @@ -199,29 +208,29 @@ func (db *mysql) SetParams(params map[string]string) { } } -func (db *mysql) SqlType(c *core.Column) string { +func (db *mysql) SQLType(c *schemas.Column) string { var res string switch t := c.SQLType.Name; t { - case core.Bool: - res = core.TinyInt + case schemas.Bool: + res = schemas.TinyInt c.Length = 1 - case core.Serial: + case schemas.Serial: c.IsAutoIncrement = true c.IsPrimaryKey = true c.Nullable = false - res = core.Int - case core.BigSerial: + res = schemas.Int + case schemas.BigSerial: c.IsAutoIncrement = true c.IsPrimaryKey = true c.Nullable = false - res = core.BigInt - case core.Bytea: - res = core.Blob - case core.TimeStampz: - res = core.Char + res = schemas.BigInt + case schemas.Bytea: + res = schemas.Blob + case schemas.TimeStampz: + res = schemas.Char c.Length = 64 - case core.Enum: // mysql enum - res = core.Enum + case schemas.Enum: // mysql enum + res = schemas.Enum res += "(" opts := "" for v := range c.EnumOptions { @@ -229,8 +238,8 @@ func (db *mysql) SqlType(c *core.Column) string { } res += strings.TrimLeft(opts, ",") res += ")" - case core.Set: // mysql set - res = core.Set + case schemas.Set: // mysql set + res = schemas.Set res += "(" opts := "" for v := range c.SetOptions { @@ -238,13 +247,13 @@ func (db *mysql) SqlType(c *core.Column) string { } res += strings.TrimLeft(opts, ",") res += ")" - case core.NVarchar: - res = core.Varchar - case core.Uuid: - res = core.Varchar + case schemas.NVarchar: + res = schemas.Varchar + case schemas.Uuid: + res = schemas.Varchar c.Length = 40 - case core.Json: - res = core.Text + case schemas.Json: + res = schemas.Text default: res = t } @@ -252,7 +261,7 @@ func (db *mysql) SqlType(c *core.Column) string { hasLen1 := (c.Length > 0) hasLen2 := (c.Length2 > 0) - if res == core.BigInt && !hasLen1 && !hasLen2 { + if res == schemas.BigInt && !hasLen1 && !hasLen2 { c.Length = 20 hasLen1 = true } @@ -265,70 +274,53 @@ func (db *mysql) SqlType(c *core.Column) string { return res } -func (db *mysql) SupportInsertMany() bool { - return true -} - func (db *mysql) IsReserved(name string) bool { - _, ok := mysqlReservedWords[name] + _, ok := mysqlReservedWords[strings.ToUpper(name)] return ok } -func (db *mysql) Quote(name string) string { - return "`" + name + "`" -} - -func (db *mysql) SupportEngine() bool { - return true -} - func (db *mysql) AutoIncrStr() string { return "AUTO_INCREMENT" } -func (db *mysql) SupportCharset() bool { - return true -} - -func (db *mysql) IndexOnTable() bool { - return true -} - -func (db *mysql) IndexCheckSql(tableName, idxName string) (string, []interface{}) { - args := []interface{}{db.DbName, tableName, idxName} +func (db *mysql) IndexCheckSQL(tableName, idxName string) (string, []interface{}) { + args := []interface{}{db.uri.DBName, tableName, idxName} sql := "SELECT `INDEX_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS`" sql += " WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `INDEX_NAME`=?" return sql, args } -/*func (db *mysql) ColumnCheckSql(tableName, colName string) (string, []interface{}) { - args := []interface{}{db.DbName, tableName, colName} - sql := "SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?" - return sql, args -}*/ - -func (db *mysql) TableCheckSql(tableName string) (string, []interface{}) { - args := []interface{}{db.DbName, tableName} +func (db *mysql) IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error) { sql := "SELECT `TABLE_NAME` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? and `TABLE_NAME`=?" - return sql, args + return db.HasRecords(queryer, ctx, sql, db.uri.DBName, tableName) } -func (db *mysql) GetColumns(tableName string) ([]string, map[string]*core.Column, error) { - args := []interface{}{db.DbName, tableName} - s := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," + - " `COLUMN_KEY`, `EXTRA`,`COLUMN_COMMENT` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" - db.LogSQL(s, args) +func (db *mysql) AddColumnSQL(tableName string, col *schemas.Column) string { + quoter := db.dialect.Quoter() + s, _ := ColumnString(db, col, true) + sql := fmt.Sprintf("ALTER TABLE %v ADD %v", quoter.Quote(tableName), s) + if len(col.Comment) > 0 { + sql += " COMMENT '" + col.Comment + "'" + } + return sql +} - rows, err := db.DB().Query(s, args...) +func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { + args := []interface{}{db.uri.DBName, tableName} + s := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," + + " `COLUMN_KEY`, `EXTRA`,`COLUMN_COMMENT` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" + + " ORDER BY `INFORMATION_SCHEMA`.`COLUMNS`.ORDINAL_POSITION" + + rows, err := queryer.QueryContext(ctx, s, args...) if err != nil { return nil, nil, err } defer rows.Close() - cols := make(map[string]*core.Column) + cols := make(map[string]*schemas.Column) colSeq := make([]string, 0) for rows.Next() { - col := new(core.Column) + col := new(schemas.Column) col.Indexes = make(map[string]int) var columnName, isNullable, colType, colKey, extra, comment string @@ -356,7 +348,7 @@ func (db *mysql) GetColumns(tableName string) ([]string, map[string]*core.Column var len1, len2 int if len(cts) == 2 { idx := strings.Index(cts[1], ")") - if colType == core.Enum && cts[1][0] == '\'' { // enum + if colType == schemas.Enum && cts[1][0] == '\'' { // enum options := strings.Split(cts[1][0:idx], ",") col.EnumOptions = make(map[string]int) for k, v := range options { @@ -364,7 +356,7 @@ func (db *mysql) GetColumns(tableName string) ([]string, map[string]*core.Column v = strings.Trim(v, "'") col.EnumOptions[v] = k } - } else if colType == core.Set && cts[1][0] == '\'' { + } else if colType == schemas.Set && cts[1][0] == '\'' { options := strings.Split(cts[1][0:idx], ",") col.SetOptions = make(map[string]int) for k, v := range options { @@ -394,8 +386,8 @@ func (db *mysql) GetColumns(tableName string) ([]string, map[string]*core.Column } col.Length = len1 col.Length2 = len2 - if _, ok := core.SqlTypes[colType]; ok { - col.SQLType = core.SQLType{Name: colType, DefaultLength: len1, DefaultLength2: len2} + if _, ok := schemas.SqlTypes[colType]; ok { + col.SQLType = schemas.SQLType{Name: colType, DefaultLength: len1, DefaultLength2: len2} } else { return nil, nil, fmt.Errorf("Unknown colType %v", colType) } @@ -424,48 +416,65 @@ func (db *mysql) GetColumns(tableName string) ([]string, map[string]*core.Column return colSeq, cols, nil } -func (db *mysql) GetTables() ([]*core.Table, error) { - args := []interface{}{db.DbName} - s := "SELECT `TABLE_NAME`, `ENGINE`, `TABLE_ROWS`, `AUTO_INCREMENT`, `TABLE_COMMENT` from " + +func (db *mysql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) { + args := []interface{}{db.uri.DBName} + s := "SELECT `TABLE_NAME`, `ENGINE`, `AUTO_INCREMENT`, `TABLE_COMMENT` from " + "`INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? AND (`ENGINE`='MyISAM' OR `ENGINE` = 'InnoDB' OR `ENGINE` = 'TokuDB')" - db.LogSQL(s, args) - rows, err := db.DB().Query(s, args...) + rows, err := queryer.QueryContext(ctx, s, args...) if err != nil { return nil, err } defer rows.Close() - tables := make([]*core.Table, 0) + tables := make([]*schemas.Table, 0) for rows.Next() { - table := core.NewEmptyTable() - var name, engine, tableRows, comment string - var autoIncr *string - err = rows.Scan(&name, &engine, &tableRows, &autoIncr, &comment) + table := schemas.NewEmptyTable() + var name, engine string + var autoIncr, comment *string + err = rows.Scan(&name, &engine, &autoIncr, &comment) if err != nil { return nil, err } table.Name = name - table.Comment = comment + if comment != nil { + table.Comment = *comment + } table.StoreEngine = engine tables = append(tables, table) } return tables, nil } -func (db *mysql) GetIndexes(tableName string) (map[string]*core.Index, error) { - args := []interface{}{db.DbName, tableName} - s := "SELECT `INDEX_NAME`, `NON_UNIQUE`, `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" - db.LogSQL(s, args) +func (db *mysql) SetQuotePolicy(quotePolicy QuotePolicy) { + switch quotePolicy { + case QuotePolicyNone: + var q = mysqlQuoter + q.IsReserved = schemas.AlwaysNoReserve + db.quoter = q + case QuotePolicyReserved: + var q = mysqlQuoter + q.IsReserved = db.IsReserved + db.quoter = q + case QuotePolicyAlways: + fallthrough + default: + db.quoter = mysqlQuoter + } +} - rows, err := db.DB().Query(s, args...) +func (db *mysql) GetIndexes(queryer core.Queryer, ctx context.Context, tableName string) (map[string]*schemas.Index, error) { + args := []interface{}{db.uri.DBName, tableName} + s := "SELECT `INDEX_NAME`, `NON_UNIQUE`, `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" + + rows, err := queryer.QueryContext(ctx, s, args...) if err != nil { return nil, err } defer rows.Close() - indexes := make(map[string]*core.Index, 0) + indexes := make(map[string]*schemas.Index, 0) for rows.Next() { var indexType int var indexName, colName, nonUnique string @@ -479,9 +488,9 @@ func (db *mysql) GetIndexes(tableName string) (map[string]*core.Index, error) { } if "YES" == nonUnique || nonUnique == "1" { - indexType = core.IndexType + indexType = schemas.IndexType } else { - indexType = core.UniqueType + indexType = schemas.UniqueType } colName = strings.Trim(colName, "` ") @@ -491,10 +500,10 @@ func (db *mysql) GetIndexes(tableName string) (map[string]*core.Index, error) { isRegular = true } - var index *core.Index + var index *schemas.Index var ok bool if index, ok = indexes[indexName]; !ok { - index = new(core.Index) + index = new(schemas.Index) index.IsRegular = isRegular index.Type = indexType index.Name = indexName @@ -505,14 +514,15 @@ func (db *mysql) GetIndexes(tableName string) (map[string]*core.Index, error) { return indexes, nil } -func (db *mysql) CreateTableSql(table *core.Table, tableName, storeEngine, charset string) string { - var sql string - sql = "CREATE TABLE IF NOT EXISTS " +func (db *mysql) CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) { + var sql = "CREATE TABLE IF NOT EXISTS " if tableName == "" { tableName = table.Name } - sql += db.Quote(tableName) + quoter := db.Quoter() + + sql += quoter.Quote(tableName) sql += " (" if len(table.ColumnsSeq()) > 0 { @@ -520,11 +530,8 @@ func (db *mysql) CreateTableSql(table *core.Table, tableName, storeEngine, chars for _, colName := range table.ColumnsSeq() { col := table.GetColumn(colName) - if col.IsPrimaryKey && len(pkList) == 1 { - sql += col.String(db) - } else { - sql += col.StringNoPk(db) - } + s, _ := ColumnString(db, col, col.IsPrimaryKey && len(pkList) == 1) + sql += s sql = strings.TrimSpace(sql) if len(col.Comment) > 0 { sql += " COMMENT '" + col.Comment + "'" @@ -534,7 +541,7 @@ func (db *mysql) CreateTableSql(table *core.Table, tableName, storeEngine, chars if len(pkList) > 1 { sql += "PRIMARY KEY ( " - sql += db.Quote(strings.Join(pkList, db.Quote(","))) + sql += quoter.Join(pkList, ",") sql += " ), " } @@ -542,10 +549,11 @@ func (db *mysql) CreateTableSql(table *core.Table, tableName, storeEngine, chars } sql += ")" - if storeEngine != "" { - sql += " ENGINE=" + storeEngine + if table.StoreEngine != "" { + sql += " ENGINE=" + table.StoreEngine } + var charset = table.Charset if len(charset) == 0 { charset = db.URI().Charset } @@ -556,18 +564,18 @@ func (db *mysql) CreateTableSql(table *core.Table, tableName, storeEngine, chars if db.rowFormat != "" { sql += " ROW_FORMAT=" + db.rowFormat } - return sql + return []string{sql}, true } -func (db *mysql) Filters() []core.Filter { - return []core.Filter{&core.IdFilter{}} +func (db *mysql) Filters() []Filter { + return []Filter{} } type mymysqlDriver struct { } -func (p *mymysqlDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) { - db := &core.Uri{DbType: core.MYSQL} +func (p *mymysqlDriver) Parse(driverName, dataSourceName string) (*URI, error) { + uri := &URI{DBType: schemas.MYSQL} pd := strings.SplitN(dataSourceName, "*", 2) if len(pd) == 2 { @@ -576,9 +584,9 @@ func (p *mymysqlDriver) Parse(driverName, dataSourceName string) (*core.Uri, err if len(p) != 2 { return nil, errors.New("Wrong protocol part of URI") } - db.Proto = p[0] + uri.Proto = p[0] options := strings.Split(p[1], ",") - db.Raddr = options[0] + uri.Raddr = options[0] for _, o := range options[1:] { kv := strings.SplitN(o, "=", 2) var k, v string @@ -589,13 +597,13 @@ func (p *mymysqlDriver) Parse(driverName, dataSourceName string) (*core.Uri, err } switch k { case "laddr": - db.Laddr = v + uri.Laddr = v case "timeout": to, err := time.ParseDuration(v) if err != nil { return nil, err } - db.Timeout = to + uri.Timeout = to default: return nil, errors.New("Unknown option: " + k) } @@ -608,17 +616,17 @@ func (p *mymysqlDriver) Parse(driverName, dataSourceName string) (*core.Uri, err if len(dup) != 3 { return nil, errors.New("Wrong database part of URI") } - db.DbName = dup[0] - db.User = dup[1] - db.Passwd = dup[2] + uri.DBName = dup[0] + uri.User = dup[1] + uri.Passwd = dup[2] - return db, nil + return uri, nil } type mysqlDriver struct { } -func (p *mysqlDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) { +func (p *mysqlDriver) Parse(driverName, dataSourceName string) (*URI, error) { dsnPattern := regexp.MustCompile( `^(?:(?P.*?)(?::(?P.*))?@)?` + // [user[:password]@] `(?:(?P[^\(]*)(?:\((?P[^\)]*)\))?)?` + // [net[(addr)]] @@ -628,12 +636,12 @@ func (p *mysqlDriver) Parse(driverName, dataSourceName string) (*core.Uri, error // tlsConfigRegister := make(map[string]*tls.Config) names := dsnPattern.SubexpNames() - uri := &core.Uri{DbType: core.MYSQL} + uri := &URI{DBType: schemas.MYSQL} for i, match := range matches { switch names[i] { case "dbname": - uri.DbName = match + uri.DBName = match case "params": if len(match) > 0 { kvs := strings.Split(match, "&") diff --git a/dialect_oracle.go b/dialects/oracle.go similarity index 83% rename from dialect_oracle.go rename to dialects/oracle.go index 15010ca5..91eed251 100644 --- a/dialect_oracle.go +++ b/dialects/oracle.go @@ -2,16 +2,18 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package dialects import ( + "context" "errors" "fmt" "regexp" "strconv" "strings" - "xorm.io/core" + "xorm.io/xorm/core" + "xorm.io/xorm/schemas" ) var ( @@ -496,32 +498,39 @@ var ( "YEAR": true, "ZONE": true, } + + oracleQuoter = schemas.Quoter{ + Prefix: '"', + Suffix: '"', + IsReserved: schemas.AlwaysReserve, + } ) type oracle struct { - core.Base + Base } -func (db *oracle) Init(d *core.DB, uri *core.Uri, drivername, dataSourceName string) error { - return db.Base.Init(d, db, uri, drivername, dataSourceName) +func (db *oracle) Init(uri *URI) error { + db.quoter = oracleQuoter + return db.Base.Init(db, uri) } -func (db *oracle) SqlType(c *core.Column) string { +func (db *oracle) SQLType(c *schemas.Column) string { var res string switch t := c.SQLType.Name; t { - case core.Bit, core.TinyInt, core.SmallInt, core.MediumInt, core.Int, core.Integer, core.BigInt, core.Bool, core.Serial, core.BigSerial: + case schemas.Bit, schemas.TinyInt, schemas.SmallInt, schemas.MediumInt, schemas.Int, schemas.Integer, schemas.BigInt, schemas.Bool, schemas.Serial, schemas.BigSerial: res = "NUMBER" - case core.Binary, core.VarBinary, core.Blob, core.TinyBlob, core.MediumBlob, core.LongBlob, core.Bytea: - return core.Blob - case core.Time, core.DateTime, core.TimeStamp: - res = core.TimeStamp - case core.TimeStampz: + case schemas.Binary, schemas.VarBinary, schemas.Blob, schemas.TinyBlob, schemas.MediumBlob, schemas.LongBlob, schemas.Bytea: + return schemas.Blob + case schemas.Time, schemas.DateTime, schemas.TimeStamp: + res = schemas.TimeStamp + case schemas.TimeStampz: res = "TIMESTAMP WITH TIME ZONE" - case core.Float, core.Double, core.Numeric, core.Decimal: + case schemas.Float, schemas.Double, schemas.Numeric, schemas.Decimal: res = "NUMBER" - case core.Text, core.MediumText, core.LongText, core.Json: + case schemas.Text, schemas.MediumText, schemas.LongText, schemas.Json: res = "CLOB" - case core.Char, core.Varchar, core.TinyText: + case schemas.Char, schemas.Varchar, schemas.TinyText: res = "VARCHAR2" default: res = t @@ -542,47 +551,23 @@ func (db *oracle) AutoIncrStr() string { return "AUTO_INCREMENT" } -func (db *oracle) SupportInsertMany() bool { - return true -} - func (db *oracle) IsReserved(name string) bool { - _, ok := oracleReservedWords[name] + _, ok := oracleReservedWords[strings.ToUpper(name)] return ok } -func (db *oracle) Quote(name string) string { - return "[" + name + "]" +func (db *oracle) DropTableSQL(tableName string) (string, bool) { + return fmt.Sprintf("DROP TABLE `%s`", tableName), false } -func (db *oracle) SupportEngine() bool { - return false -} - -func (db *oracle) SupportCharset() bool { - return false -} - -func (db *oracle) SupportDropIfExists() bool { - return false -} - -func (db *oracle) IndexOnTable() bool { - return false -} - -func (db *oracle) DropTableSql(tableName string) string { - return fmt.Sprintf("DROP TABLE `%s`", tableName) -} - -func (db *oracle) CreateTableSql(table *core.Table, tableName, storeEngine, charset string) string { - var sql string - sql = "CREATE TABLE " +func (db *oracle) CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) { + var sql = "CREATE TABLE " if tableName == "" { tableName = table.Name } - sql += db.Quote(tableName) + " (" + quoter := db.Quoter() + sql += quoter.Quote(tableName) + " (" pkList := table.PrimaryKeys @@ -591,7 +576,8 @@ func (db *oracle) CreateTableSql(table *core.Table, tableName, storeEngine, char /*if col.IsPrimaryKey && len(pkList) == 1 { sql += col.String(b.dialect) } else {*/ - sql += col.StringNoPk(db) + s, _ := ColumnString(db, col, false) + sql += s // } sql = strings.TrimSpace(sql) sql += ", " @@ -599,97 +585,63 @@ func (db *oracle) CreateTableSql(table *core.Table, tableName, storeEngine, char if len(pkList) > 0 { sql += "PRIMARY KEY ( " - sql += db.Quote(strings.Join(pkList, db.Quote(","))) + sql += quoter.Join(pkList, ",") sql += " ), " } sql = sql[:len(sql)-2] + ")" - if db.SupportEngine() && storeEngine != "" { - sql += " ENGINE=" + storeEngine - } - if db.SupportCharset() { - if len(charset) == 0 { - charset = db.URI().Charset - } - if len(charset) > 0 { - sql += " DEFAULT CHARSET " + charset - } - } - return sql + return []string{sql}, false } -func (db *oracle) IndexCheckSql(tableName, idxName string) (string, []interface{}) { +func (db *oracle) SetQuotePolicy(quotePolicy QuotePolicy) { + switch quotePolicy { + case QuotePolicyNone: + var q = oracleQuoter + q.IsReserved = schemas.AlwaysNoReserve + db.quoter = q + case QuotePolicyReserved: + var q = oracleQuoter + q.IsReserved = db.IsReserved + db.quoter = q + case QuotePolicyAlways: + fallthrough + default: + db.quoter = oracleQuoter + } +} + +func (db *oracle) IndexCheckSQL(tableName, idxName string) (string, []interface{}) { args := []interface{}{tableName, idxName} return `SELECT INDEX_NAME FROM USER_INDEXES ` + `WHERE TABLE_NAME = :1 AND INDEX_NAME = :2`, args } -func (db *oracle) TableCheckSql(tableName string) (string, []interface{}) { - args := []interface{}{tableName} - return `SELECT table_name FROM user_tables WHERE table_name = :1`, args +func (db *oracle) IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error) { + return db.HasRecords(queryer, ctx, `SELECT table_name FROM user_tables WHERE table_name = :1`, tableName) } -func (db *oracle) MustDropTable(tableName string) error { - sql, args := db.TableCheckSql(tableName) - db.LogSQL(sql, args) - - rows, err := db.DB().Query(sql, args...) - if err != nil { - return err - } - defer rows.Close() - - if !rows.Next() { - return nil - } - - sql = "Drop Table \"" + tableName + "\"" - db.LogSQL(sql, args) - - _, err = db.DB().Exec(sql) - return err -} - -/*func (db *oracle) ColumnCheckSql(tableName, colName string) (string, []interface{}) { - args := []interface{}{strings.ToUpper(tableName), strings.ToUpper(colName)} - return "SELECT column_name FROM USER_TAB_COLUMNS WHERE table_name = ?" + - " AND column_name = ?", args -}*/ - -func (db *oracle) IsColumnExist(tableName, colName string) (bool, error) { +func (db *oracle) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName, colName string) (bool, error) { args := []interface{}{tableName, colName} query := "SELECT column_name FROM USER_TAB_COLUMNS WHERE table_name = :1" + " AND column_name = :2" - db.LogSQL(query, args) - - rows, err := db.DB().Query(query, args...) - if err != nil { - return false, err - } - defer rows.Close() - - if rows.Next() { - return true, nil - } - return false, nil + return db.HasRecords(queryer, ctx, query, args...) } -func (db *oracle) GetColumns(tableName string) ([]string, map[string]*core.Column, error) { +func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { args := []interface{}{tableName} s := "SELECT column_name,data_default,data_type,data_length,data_precision,data_scale," + "nullable FROM USER_TAB_COLUMNS WHERE table_name = :1" - db.LogSQL(s, args) - rows, err := db.DB().Query(s, args...) + rows, err := queryer.QueryContext(ctx, s, args...) if err != nil { return nil, nil, err } defer rows.Close() - cols := make(map[string]*core.Column) + cols := make(map[string]*schemas.Column) colSeq := make([]string, 0) for rows.Next() { - col := new(core.Column) + col := new(schemas.Column) col.Indexes = make(map[string]int) var colName, colDefault, nullable, dataType, dataPrecision, dataScale *string @@ -731,30 +683,30 @@ func (db *oracle) GetColumns(tableName string) ([]string, map[string]*core.Colum switch dt { case "VARCHAR2": - col.SQLType = core.SQLType{Name: core.Varchar, DefaultLength: len1, DefaultLength2: len2} + col.SQLType = schemas.SQLType{Name: schemas.Varchar, DefaultLength: len1, DefaultLength2: len2} case "NVARCHAR2": - col.SQLType = core.SQLType{Name: core.NVarchar, DefaultLength: len1, DefaultLength2: len2} + col.SQLType = schemas.SQLType{Name: schemas.NVarchar, DefaultLength: len1, DefaultLength2: len2} case "TIMESTAMP WITH TIME ZONE": - col.SQLType = core.SQLType{Name: core.TimeStampz, DefaultLength: 0, DefaultLength2: 0} + col.SQLType = schemas.SQLType{Name: schemas.TimeStampz, DefaultLength: 0, DefaultLength2: 0} case "NUMBER": - col.SQLType = core.SQLType{Name: core.Double, DefaultLength: len1, DefaultLength2: len2} + col.SQLType = schemas.SQLType{Name: schemas.Double, DefaultLength: len1, DefaultLength2: len2} case "LONG", "LONG RAW": - col.SQLType = core.SQLType{Name: core.Text, DefaultLength: 0, DefaultLength2: 0} + col.SQLType = schemas.SQLType{Name: schemas.Text, DefaultLength: 0, DefaultLength2: 0} case "RAW": - col.SQLType = core.SQLType{Name: core.Binary, DefaultLength: 0, DefaultLength2: 0} + col.SQLType = schemas.SQLType{Name: schemas.Binary, DefaultLength: 0, DefaultLength2: 0} case "ROWID": - col.SQLType = core.SQLType{Name: core.Varchar, DefaultLength: 18, DefaultLength2: 0} + col.SQLType = schemas.SQLType{Name: schemas.Varchar, DefaultLength: 18, DefaultLength2: 0} case "AQ$_SUBSCRIBERS": ignore = true default: - col.SQLType = core.SQLType{Name: strings.ToUpper(dt), DefaultLength: len1, DefaultLength2: len2} + col.SQLType = schemas.SQLType{Name: strings.ToUpper(dt), DefaultLength: len1, DefaultLength2: len2} } if ignore { continue } - if _, ok := core.SqlTypes[col.SQLType.Name]; !ok { + if _, ok := schemas.SqlTypes[col.SQLType.Name]; !ok { return nil, nil, fmt.Errorf("Unknown colType %v %v", *dataType, col.SQLType) } @@ -772,20 +724,19 @@ func (db *oracle) GetColumns(tableName string) ([]string, map[string]*core.Colum return colSeq, cols, nil } -func (db *oracle) GetTables() ([]*core.Table, error) { +func (db *oracle) GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) { args := []interface{}{} s := "SELECT table_name FROM user_tables" - db.LogSQL(s, args) - rows, err := db.DB().Query(s, args...) + rows, err := queryer.QueryContext(ctx, s, args...) if err != nil { return nil, err } defer rows.Close() - tables := make([]*core.Table, 0) + tables := make([]*schemas.Table, 0) for rows.Next() { - table := core.NewEmptyTable() + table := schemas.NewEmptyTable() err = rows.Scan(&table.Name) if err != nil { return nil, err @@ -796,19 +747,18 @@ func (db *oracle) GetTables() ([]*core.Table, error) { return tables, nil } -func (db *oracle) GetIndexes(tableName string) (map[string]*core.Index, error) { +func (db *oracle) GetIndexes(queryer core.Queryer, ctx context.Context, tableName string) (map[string]*schemas.Index, error) { args := []interface{}{tableName} s := "SELECT t.column_name,i.uniqueness,i.index_name FROM user_ind_columns t,user_indexes i " + "WHERE t.index_name = i.index_name and t.table_name = i.table_name and t.table_name =:1" - db.LogSQL(s, args) - rows, err := db.DB().Query(s, args...) + rows, err := queryer.QueryContext(ctx, s, args...) if err != nil { return nil, err } defer rows.Close() - indexes := make(map[string]*core.Index, 0) + indexes := make(map[string]*schemas.Index, 0) for rows.Next() { var indexType int var indexName, colName, uniqueness string @@ -827,15 +777,15 @@ func (db *oracle) GetIndexes(tableName string) (map[string]*core.Index, error) { } if uniqueness == "UNIQUE" { - indexType = core.UniqueType + indexType = schemas.UniqueType } else { - indexType = core.IndexType + indexType = schemas.IndexType } - var index *core.Index + var index *schemas.Index var ok bool if index, ok = indexes[indexName]; !ok { - index = new(core.Index) + index = new(schemas.Index) index.Type = indexType index.Name = indexName index.IsRegular = isRegular @@ -846,15 +796,17 @@ func (db *oracle) GetIndexes(tableName string) (map[string]*core.Index, error) { return indexes, nil } -func (db *oracle) Filters() []core.Filter { - return []core.Filter{&core.QuoteFilter{}, &core.SeqFilter{Prefix: ":", Start: 1}, &core.IdFilter{}} +func (db *oracle) Filters() []Filter { + return []Filter{ + &SeqFilter{Prefix: ":", Start: 1}, + } } type goracleDriver struct { } -func (cfg *goracleDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) { - db := &core.Uri{DbType: core.ORACLE} +func (cfg *goracleDriver) Parse(driverName, dataSourceName string) (*URI, error) { + db := &URI{DBType: schemas.ORACLE} dsnPattern := regexp.MustCompile( `^(?:(?P.*?)(?::(?P.*))?@)?` + // [user[:password]@] `(?:(?P[^\(]*)(?:\((?P[^\)]*)\))?)?` + // [net[(addr)]] @@ -867,10 +819,10 @@ func (cfg *goracleDriver) Parse(driverName, dataSourceName string) (*core.Uri, e for i, match := range matches { switch names[i] { case "dbname": - db.DbName = match + db.DBName = match } } - if db.DbName == "" { + if db.DBName == "" { return nil, errors.New("dbname is empty") } return db, nil @@ -881,8 +833,8 @@ type oci8Driver struct { // dataSourceName=user/password@ipv4:port/dbname // dataSourceName=user/password@[ipv6]:port/dbname -func (p *oci8Driver) Parse(driverName, dataSourceName string) (*core.Uri, error) { - db := &core.Uri{DbType: core.ORACLE} +func (p *oci8Driver) Parse(driverName, dataSourceName string) (*URI, error) { + db := &URI{DBType: schemas.ORACLE} dsnPattern := regexp.MustCompile( `^(?P.*)\/(?P.*)@` + // user:password@ `(?P.*)` + // ip:port @@ -892,10 +844,10 @@ func (p *oci8Driver) Parse(driverName, dataSourceName string) (*core.Uri, error) for i, match := range matches { switch names[i] { case "dbname": - db.DbName = match + db.DBName = match } } - if db.DbName == "" { + if db.DBName == "" && len(matches) != 0 { return nil, errors.New("dbname is empty") } return db, nil diff --git a/dialects/oracle_test.go b/dialects/oracle_test.go new file mode 100644 index 00000000..9c3a93f5 --- /dev/null +++ b/dialects/oracle_test.go @@ -0,0 +1,34 @@ +package dialects + +import ( + "reflect" + "testing" +) + +func TestParseOracleConnStr(t *testing.T) { + tests := []struct { + in string + expected string + valid bool + }{ + {"user/pass@tcp(server:1521)/db", "db", true}, + {"user/pass@server:1521/db", "db", true}, + // test for net service name : https://docs.oracle.com/cd/B13789_01/network.101/b10775/glossary.htm#i998113 + {"user/pass@server:1521", "", true}, + {"user/pass@", "", false}, + {"user/pass", "", false}, + {"", "", false}, + } + driver := QueryDriver("oci8") + for _, test := range tests { + t.Run(test.in, func(t *testing.T) { + driver := driver + uri, err := driver.Parse("oci8", test.in) + if err != nil && test.valid { + t.Errorf("%q got unexpected error: %s", test.in, err) + } else if err == nil && !reflect.DeepEqual(test.expected, uri.DBName) { + t.Errorf("%q got: %#v want: %#v", test.in, uri.DBName, test.expected) + } + }) + } +} diff --git a/pg_reserved.txt b/dialects/pg_reserved.txt similarity index 100% rename from pg_reserved.txt rename to dialects/pg_reserved.txt diff --git a/dialect_postgres.go b/dialects/postgres.go similarity index 83% rename from dialect_postgres.go rename to dialects/postgres.go index ccef3086..1996c49d 100644 --- a/dialect_postgres.go +++ b/dialects/postgres.go @@ -2,16 +2,18 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package dialects import ( + "context" "errors" "fmt" "net/url" "strconv" "strings" - "xorm.io/core" + "xorm.io/xorm/core" + "xorm.io/xorm/schemas" ) // from http://www.postgresql.org/docs/current/static/sql-keywords-appendix.html @@ -765,71 +767,107 @@ var ( "ZONE": true, } + postgresQuoter = schemas.Quoter{ + Prefix: '"', + Suffix: '"', + IsReserved: schemas.AlwaysReserve, + } +) + +var ( // DefaultPostgresSchema default postgres schema DefaultPostgresSchema = "public" ) -const postgresPublicSchema = "public" - type postgres struct { - core.Base + Base } -func (db *postgres) Init(d *core.DB, uri *core.Uri, drivername, dataSourceName string) error { - err := db.Base.Init(d, db, uri, drivername, dataSourceName) - if err != nil { - return err - } - if db.Schema == "" { - db.Schema = DefaultPostgresSchema - } - return nil +func (db *postgres) Init(uri *URI) error { + db.quoter = postgresQuoter + return db.Base.Init(db, uri) } -func (db *postgres) SqlType(c *core.Column) string { +func (db *postgres) getSchema() string { + if db.uri.Schema != "" { + return db.uri.Schema + } + return DefaultPostgresSchema +} + +func (db *postgres) needQuote(name string) bool { + if db.IsReserved(name) { + return true + } + for _, c := range name { + if c >= 'A' && c <= 'Z' { + return true + } + } + return false +} + +func (db *postgres) SetQuotePolicy(quotePolicy QuotePolicy) { + switch quotePolicy { + case QuotePolicyNone: + var q = postgresQuoter + q.IsReserved = schemas.AlwaysNoReserve + db.quoter = q + case QuotePolicyReserved: + var q = postgresQuoter + q.IsReserved = db.needQuote + db.quoter = q + case QuotePolicyAlways: + fallthrough + default: + db.quoter = postgresQuoter + } +} + +func (db *postgres) SQLType(c *schemas.Column) string { var res string switch t := c.SQLType.Name; t { - case core.TinyInt: - res = core.SmallInt + case schemas.TinyInt: + res = schemas.SmallInt return res - case core.Bit: - res = core.Boolean + case schemas.Bit: + res = schemas.Boolean return res - case core.MediumInt, core.Int, core.Integer: + case schemas.MediumInt, schemas.Int, schemas.Integer: if c.IsAutoIncrement { - return core.Serial + return schemas.Serial } - return core.Integer - case core.BigInt: + return schemas.Integer + case schemas.BigInt: if c.IsAutoIncrement { - return core.BigSerial + return schemas.BigSerial } - return core.BigInt - case core.Serial, core.BigSerial: + return schemas.BigInt + case schemas.Serial, schemas.BigSerial: c.IsAutoIncrement = true c.Nullable = false res = t - case core.Binary, core.VarBinary: - return core.Bytea - case core.DateTime: - res = core.TimeStamp - case core.TimeStampz: + case schemas.Binary, schemas.VarBinary: + return schemas.Bytea + case schemas.DateTime: + res = schemas.TimeStamp + case schemas.TimeStampz: return "timestamp with time zone" - case core.Float: - res = core.Real - case core.TinyText, core.MediumText, core.LongText: - res = core.Text - case core.NVarchar: - res = core.Varchar - case core.Uuid: - return core.Uuid - case core.Blob, core.TinyBlob, core.MediumBlob, core.LongBlob: - return core.Bytea - case core.Double: + case schemas.Float: + res = schemas.Real + case schemas.TinyText, schemas.MediumText, schemas.LongText: + res = schemas.Text + case schemas.NVarchar: + res = schemas.Varchar + case schemas.Uuid: + return schemas.Uuid + case schemas.Blob, schemas.TinyBlob, schemas.MediumBlob, schemas.LongBlob: + return schemas.Bytea + case schemas.Double: return "DOUBLE PRECISION" default: if c.IsAutoIncrement { - return core.Serial + return schemas.Serial } res = t } @@ -849,99 +887,110 @@ func (db *postgres) SqlType(c *core.Column) string { return res } -func (db *postgres) SupportInsertMany() bool { - return true -} - func (db *postgres) IsReserved(name string) bool { - _, ok := postgresReservedWords[name] + _, ok := postgresReservedWords[strings.ToUpper(name)] return ok } -func (db *postgres) Quote(name string) string { - name = strings.Replace(name, ".", `"."`, -1) - return "\"" + name + "\"" -} - func (db *postgres) AutoIncrStr() string { return "" } -func (db *postgres) SupportEngine() bool { - return false +func (db *postgres) CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) { + var sql string + sql = "CREATE TABLE IF NOT EXISTS " + if tableName == "" { + tableName = table.Name + } + + quoter := db.Quoter() + sql += quoter.Quote(tableName) + sql += " (" + + if len(table.ColumnsSeq()) > 0 { + pkList := table.PrimaryKeys + + for _, colName := range table.ColumnsSeq() { + col := table.GetColumn(colName) + s, _ := ColumnString(db, col, col.IsPrimaryKey && len(pkList) == 1) + sql += s + sql = strings.TrimSpace(sql) + sql += ", " + } + + if len(pkList) > 1 { + sql += "PRIMARY KEY ( " + sql += quoter.Join(pkList, ",") + sql += " ), " + } + + sql = sql[:len(sql)-2] + } + sql += ")" + + return []string{sql}, true } -func (db *postgres) SupportCharset() bool { - return false -} - -func (db *postgres) IndexOnTable() bool { - return false -} - -func (db *postgres) IndexCheckSql(tableName, idxName string) (string, []interface{}) { - if len(db.Schema) == 0 { +func (db *postgres) IndexCheckSQL(tableName, idxName string) (string, []interface{}) { + if len(db.getSchema()) == 0 { args := []interface{}{tableName, idxName} return `SELECT indexname FROM pg_indexes WHERE tablename = ? AND indexname = ?`, args } - args := []interface{}{db.Schema, tableName, idxName} + args := []interface{}{db.getSchema(), tableName, idxName} return `SELECT indexname FROM pg_indexes ` + `WHERE schemaname = ? AND tablename = ? AND indexname = ?`, args } -func (db *postgres) TableCheckSql(tableName string) (string, []interface{}) { - if len(db.Schema) == 0 { - args := []interface{}{tableName} - return `SELECT tablename FROM pg_tables WHERE tablename = ?`, args +func (db *postgres) IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error) { + if len(db.getSchema()) == 0 { + return db.HasRecords(queryer, ctx, `SELECT tablename FROM pg_tables WHERE tablename = $1`, tableName) } - args := []interface{}{db.Schema, tableName} - return `SELECT tablename FROM pg_tables WHERE schemaname = ? AND tablename = ?`, args + return db.HasRecords(queryer, ctx, `SELECT tablename FROM pg_tables WHERE schemaname = $1 AND tablename = $2`, + db.getSchema(), tableName) } -func (db *postgres) ModifyColumnSql(tableName string, col *core.Column) string { - if len(db.Schema) == 0 { +func (db *postgres) ModifyColumnSQL(tableName string, col *schemas.Column) string { + if len(db.getSchema()) == 0 || strings.Contains(tableName, ".") { return fmt.Sprintf("alter table %s ALTER COLUMN %s TYPE %s", - tableName, col.Name, db.SqlType(col)) + tableName, col.Name, db.SQLType(col)) } return fmt.Sprintf("alter table %s.%s ALTER COLUMN %s TYPE %s", - db.Schema, tableName, col.Name, db.SqlType(col)) + db.getSchema(), tableName, col.Name, db.SQLType(col)) } -func (db *postgres) DropIndexSql(tableName string, index *core.Index) string { - quote := db.Quote +func (db *postgres) DropIndexSQL(tableName string, index *schemas.Index) string { idxName := index.Name - tableName = strings.Replace(tableName, `"`, "", -1) - tableName = strings.Replace(tableName, `.`, "_", -1) + tableParts := strings.Split(strings.Replace(tableName, `"`, "", -1), ".") + tableName = tableParts[len(tableParts)-1] if !strings.HasPrefix(idxName, "UQE_") && !strings.HasPrefix(idxName, "IDX_") { - if index.Type == core.UniqueType { + if index.Type == schemas.UniqueType { idxName = fmt.Sprintf("UQE_%v_%v", tableName, index.Name) } else { idxName = fmt.Sprintf("IDX_%v_%v", tableName, index.Name) } } - if db.Uri.Schema != "" { - idxName = db.Uri.Schema + "." + idxName + if db.getSchema() != "" { + idxName = db.getSchema() + "." + idxName } - return fmt.Sprintf("DROP INDEX %v", quote(idxName)) + return fmt.Sprintf("DROP INDEX %v", db.Quoter().Quote(idxName)) } -func (db *postgres) IsColumnExist(tableName, colName string) (bool, error) { - args := []interface{}{db.Schema, tableName, colName} +func (db *postgres) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName, colName string) (bool, error) { + args := []interface{}{db.getSchema(), tableName, colName} query := "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = $1 AND table_name = $2" + " AND column_name = $3" - if len(db.Schema) == 0 { + if len(db.getSchema()) == 0 { args = []interface{}{tableName, colName} query = "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" + " AND column_name = $2" } - db.LogSQL(query, args) - rows, err := db.DB().Query(query, args...) + rows, err := queryer.QueryContext(ctx, query, args...) if err != nil { return false, err } @@ -950,7 +999,7 @@ func (db *postgres) IsColumnExist(tableName, colName string) (bool, error) { return rows.Next(), nil } -func (db *postgres) GetColumns(tableName string) ([]string, map[string]*core.Column, error) { +func (db *postgres) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { args := []interface{}{tableName} s := `SELECT column_name, column_default, is_nullable, data_type, character_maximum_length, CASE WHEN p.contype = 'p' THEN true ELSE false END AS primarykey, @@ -962,28 +1011,27 @@ FROM pg_attribute f LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey) LEFT JOIN pg_class AS g ON p.confrelid = g.oid LEFT JOIN INFORMATION_SCHEMA.COLUMNS s ON s.column_name=f.attname AND c.relname=s.table_name -WHERE c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.attnum;` +WHERE n.nspname= s.table_schema AND c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.attnum;` - var f string - if len(db.Schema) != 0 { - args = append(args, db.Schema) - f = " AND s.table_schema = $2" + schema := db.getSchema() + if schema != "" { + s = fmt.Sprintf(s, "AND s.table_schema = $2") + args = append(args, schema) + } else { + s = fmt.Sprintf(s, "") } - s = fmt.Sprintf(s, f) - db.LogSQL(s, args) - - rows, err := db.DB().Query(s, args...) + rows, err := queryer.QueryContext(ctx, s, args...) if err != nil { return nil, nil, err } defer rows.Close() - cols := make(map[string]*core.Column) + cols := make(map[string]*schemas.Column) colSeq := make([]string, 0) for rows.Next() { - col := new(core.Column) + col := new(schemas.Column) col.Indexes = make(map[string]int) var colName, isNullable, dataType string @@ -994,7 +1042,6 @@ WHERE c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.att return nil, nil, err } - // fmt.Println(args, colName, isNullable, dataType, maxLenStr, colDefault, isPK, isUnique) var maxLen int if maxLenStr != nil { maxLen, err = strconv.Atoi(*maxLenStr) @@ -1006,10 +1053,27 @@ WHERE c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.att col.Name = strings.Trim(colName, `" `) if colDefault != nil { - col.Default = *colDefault + var theDefault = *colDefault + // cockroach has type with the default value with ::: + // and postgres with ::, we should remove them before store them + idx := strings.Index(theDefault, ":::") + if idx == -1 { + idx = strings.Index(theDefault, "::") + } + if idx > -1 { + theDefault = theDefault[:idx] + } + + if strings.HasSuffix(theDefault, "+00:00'") { + theDefault = theDefault[:len(theDefault)-7] + "'" + } + + col.Default = theDefault col.DefaultIsEmpty = false if strings.HasPrefix(col.Default, "nextval(") { col.IsAutoIncrement = true + col.Default = "" + col.DefaultIsEmpty = true } } else { col.DefaultIsEmpty = true @@ -1021,26 +1085,37 @@ WHERE c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.att col.Nullable = (isNullable == "YES") - switch dataType { - case "character varying", "character": - col.SQLType = core.SQLType{Name: core.Varchar, DefaultLength: 0, DefaultLength2: 0} + switch strings.ToLower(dataType) { + case "character varying", "character", "string": + col.SQLType = schemas.SQLType{Name: schemas.Varchar, DefaultLength: 0, DefaultLength2: 0} case "timestamp without time zone": - col.SQLType = core.SQLType{Name: core.DateTime, DefaultLength: 0, DefaultLength2: 0} + col.SQLType = schemas.SQLType{Name: schemas.DateTime, DefaultLength: 0, DefaultLength2: 0} case "timestamp with time zone": - col.SQLType = core.SQLType{Name: core.TimeStampz, DefaultLength: 0, DefaultLength2: 0} + col.SQLType = schemas.SQLType{Name: schemas.TimeStampz, DefaultLength: 0, DefaultLength2: 0} case "double precision": - col.SQLType = core.SQLType{Name: core.Double, DefaultLength: 0, DefaultLength2: 0} + col.SQLType = schemas.SQLType{Name: schemas.Double, DefaultLength: 0, DefaultLength2: 0} case "boolean": - col.SQLType = core.SQLType{Name: core.Bool, DefaultLength: 0, DefaultLength2: 0} + col.SQLType = schemas.SQLType{Name: schemas.Bool, DefaultLength: 0, DefaultLength2: 0} case "time without time zone": - col.SQLType = core.SQLType{Name: core.Time, DefaultLength: 0, DefaultLength2: 0} + col.SQLType = schemas.SQLType{Name: schemas.Time, DefaultLength: 0, DefaultLength2: 0} + case "bytes": + col.SQLType = schemas.SQLType{Name: schemas.Binary, DefaultLength: 0, DefaultLength2: 0} case "oid": - col.SQLType = core.SQLType{Name: core.BigInt, DefaultLength: 0, DefaultLength2: 0} + col.SQLType = schemas.SQLType{Name: schemas.BigInt, DefaultLength: 0, DefaultLength2: 0} + case "array": + col.SQLType = schemas.SQLType{Name: schemas.Array, DefaultLength: 0, DefaultLength2: 0} default: - col.SQLType = core.SQLType{Name: strings.ToUpper(dataType), DefaultLength: 0, DefaultLength2: 0} + startIdx := strings.Index(strings.ToLower(dataType), "string(") + if startIdx != -1 && strings.HasSuffix(dataType, ")") { + length := dataType[startIdx+8 : len(dataType)-1] + l, _ := strconv.Atoi(length) + col.SQLType = schemas.SQLType{Name: "STRING", DefaultLength: l, DefaultLength2: 0} + } else { + col.SQLType = schemas.SQLType{Name: strings.ToUpper(dataType), DefaultLength: 0, DefaultLength2: 0} + } } - if _, ok := core.SqlTypes[col.SQLType.Name]; !ok { - return nil, nil, fmt.Errorf("Unknown colType: %v", dataType) + if _, ok := schemas.SqlTypes[col.SQLType.Name]; !ok { + return nil, nil, fmt.Errorf("Unknown colType: %s - %s", dataType, col.SQLType.Name) } col.Length = maxLen @@ -1065,25 +1140,24 @@ WHERE c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.att return colSeq, cols, nil } -func (db *postgres) GetTables() ([]*core.Table, error) { +func (db *postgres) GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) { args := []interface{}{} s := "SELECT tablename FROM pg_tables" - if len(db.Schema) != 0 { - args = append(args, db.Schema) + schema := db.getSchema() + if schema != "" { + args = append(args, schema) s = s + " WHERE schemaname = $1" } - db.LogSQL(s, args) - - rows, err := db.DB().Query(s, args...) + rows, err := queryer.QueryContext(ctx, s, args...) if err != nil { return nil, err } defer rows.Close() - tables := make([]*core.Table, 0) + tables := make([]*schemas.Table, 0) for rows.Next() { - table := core.NewEmptyTable() + table := schemas.NewEmptyTable() var name string err = rows.Scan(&name) if err != nil { @@ -1106,22 +1180,21 @@ func getIndexColName(indexdef string) []string { return colNames } -func (db *postgres) GetIndexes(tableName string) (map[string]*core.Index, error) { +func (db *postgres) GetIndexes(queryer core.Queryer, ctx context.Context, tableName string) (map[string]*schemas.Index, error) { args := []interface{}{tableName} s := fmt.Sprintf("SELECT indexname, indexdef FROM pg_indexes WHERE tablename=$1") - if len(db.Schema) != 0 { - args = append(args, db.Schema) + if len(db.getSchema()) != 0 { + args = append(args, db.getSchema()) s = s + " AND schemaname=$2" } - db.LogSQL(s, args) - rows, err := db.DB().Query(s, args...) + rows, err := queryer.QueryContext(ctx, s, args...) if err != nil { return nil, err } defer rows.Close() - indexes := make(map[string]*core.Index, 0) + indexes := make(map[string]*schemas.Index, 0) for rows.Next() { var indexType int var indexName, indexdef string @@ -1130,14 +1203,18 @@ func (db *postgres) GetIndexes(tableName string) (map[string]*core.Index, error) if err != nil { return nil, err } + + if indexName == "primary" { + continue + } indexName = strings.Trim(indexName, `" `) if strings.HasSuffix(indexName, "_pkey") { continue } if strings.HasPrefix(indexdef, "CREATE UNIQUE INDEX") { - indexType = core.UniqueType + indexType = schemas.UniqueType } else { - indexType = core.IndexType + indexType = schemas.IndexType } colNames = getIndexColName(indexdef) var isRegular bool @@ -1149,9 +1226,9 @@ func (db *postgres) GetIndexes(tableName string) (map[string]*core.Index, error) } } - index := &core.Index{Name: indexName, Type: indexType, Cols: make([]string, 0)} + index := &schemas.Index{Name: indexName, Type: indexType, Cols: make([]string, 0)} for _, colName := range colNames { - index.Cols = append(index.Cols, strings.Trim(colName, `" `)) + index.Cols = append(index.Cols, strings.TrimSpace(strings.Replace(colName, `"`, "", -1))) } index.IsRegular = isRegular indexes[index.Name] = index @@ -1159,8 +1236,8 @@ func (db *postgres) GetIndexes(tableName string) (map[string]*core.Index, error) return indexes, nil } -func (db *postgres) Filters() []core.Filter { - return []core.Filter{&core.IdFilter{}, &core.QuoteFilter{}, &core.SeqFilter{Prefix: "$", Start: 1}} +func (db *postgres) Filters() []Filter { + return []Filter{&SeqFilter{Prefix: "$", Start: 1}} } type pqDriver struct { @@ -1214,12 +1291,12 @@ func parseOpts(name string, o values) error { return nil } -func (p *pqDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) { - db := &core.Uri{DbType: core.POSTGRES} +func (p *pqDriver) Parse(driverName, dataSourceName string) (*URI, error) { + db := &URI{DBType: schemas.POSTGRES} var err error if strings.HasPrefix(dataSourceName, "postgresql://") || strings.HasPrefix(dataSourceName, "postgres://") { - db.DbName, err = parseURL(dataSourceName) + db.DBName, err = parseURL(dataSourceName) if err != nil { return nil, err } @@ -1230,10 +1307,10 @@ func (p *pqDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) { return nil, err } - db.DbName = o.Get("dbname") + db.DBName = o.Get("dbname") } - if db.DbName == "" { + if db.DBName == "" { return nil, errors.New("dbname is empty") } @@ -1244,10 +1321,29 @@ type pqDriverPgx struct { pqDriver } -func (pgx *pqDriverPgx) Parse(driverName, dataSourceName string) (*core.Uri, error) { +func (pgx *pqDriverPgx) Parse(driverName, dataSourceName string) (*URI, error) { // Remove the leading characters for driver to work if len(dataSourceName) >= 9 && dataSourceName[0] == 0 { dataSourceName = dataSourceName[9:] } return pgx.pqDriver.Parse(driverName, dataSourceName) } + +// QueryDefaultPostgresSchema returns the default postgres schema +func QueryDefaultPostgresSchema(ctx context.Context, queryer core.Queryer) (string, error) { + rows, err := queryer.QueryContext(ctx, "SHOW SEARCH_PATH") + if err != nil { + return "", err + } + defer rows.Close() + if rows.Next() { + var defaultSchema string + if err = rows.Scan(&defaultSchema); err != nil { + return "", err + } + parts := strings.Split(defaultSchema, ",") + return strings.TrimSpace(parts[len(parts)-1]), nil + } + + return "", errors.New("No default schema") +} diff --git a/dialect_postgres_test.go b/dialects/postgres_test.go similarity index 92% rename from dialect_postgres_test.go rename to dialects/postgres_test.go index f2afdefc..c0a8eb6f 100644 --- a/dialect_postgres_test.go +++ b/dialects/postgres_test.go @@ -1,11 +1,10 @@ -package xorm +package dialects import ( "reflect" "testing" "github.com/stretchr/testify/assert" - "xorm.io/core" ) func TestParsePostgres(t *testing.T) { @@ -27,15 +26,15 @@ func TestParsePostgres(t *testing.T) { {"dbname=db =disable", "db", false}, } - driver := core.QueryDriver("postgres") + driver := QueryDriver("postgres") for _, test := range tests { uri, err := driver.Parse("postgres", test.in) if err != nil && test.valid { t.Errorf("%q got unexpected error: %s", test.in, err) - } else if err == nil && !reflect.DeepEqual(test.expected, uri.DbName) { - t.Errorf("%q got: %#v want: %#v", test.in, uri.DbName, test.expected) + } else if err == nil && !reflect.DeepEqual(test.expected, uri.DBName) { + t.Errorf("%q got: %#v want: %#v", test.in, uri.DBName, test.expected) } } } @@ -59,23 +58,23 @@ func TestParsePgx(t *testing.T) { {"dbname=db =disable", "db", false}, } - driver := core.QueryDriver("pgx") + driver := QueryDriver("pgx") for _, test := range tests { uri, err := driver.Parse("pgx", test.in) if err != nil && test.valid { t.Errorf("%q got unexpected error: %s", test.in, err) - } else if err == nil && !reflect.DeepEqual(test.expected, uri.DbName) { - t.Errorf("%q got: %#v want: %#v", test.in, uri.DbName, test.expected) + } else if err == nil && !reflect.DeepEqual(test.expected, uri.DBName) { + t.Errorf("%q got: %#v want: %#v", test.in, uri.DBName, test.expected) } // Register DriverConfig uri, err = driver.Parse("pgx", test.in) if err != nil && test.valid { t.Errorf("%q got unexpected error: %s", test.in, err) - } else if err == nil && !reflect.DeepEqual(test.expected, uri.DbName) { - t.Errorf("%q got: %#v want: %#v", test.in, uri.DbName, test.expected) + } else if err == nil && !reflect.DeepEqual(test.expected, uri.DBName) { + t.Errorf("%q got: %#v want: %#v", test.in, uri.DBName, test.expected) } } diff --git a/dialects/quote.go b/dialects/quote.go new file mode 100644 index 00000000..da4e0dd6 --- /dev/null +++ b/dialects/quote.go @@ -0,0 +1,15 @@ +// Copyright 2020 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package dialects + +// QuotePolicy describes quote handle policy +type QuotePolicy int + +// All QuotePolicies +const ( + QuotePolicyAlways QuotePolicy = iota + QuotePolicyNone + QuotePolicyReserved +) diff --git a/dialect_sqlite3.go b/dialects/sqlite3.go similarity index 67% rename from dialect_sqlite3.go rename to dialects/sqlite3.go index 0a290f3c..0e910934 100644 --- a/dialect_sqlite3.go +++ b/dialects/sqlite3.go @@ -2,16 +2,18 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package dialects import ( + "context" "database/sql" "errors" "fmt" "regexp" "strings" - "xorm.io/core" + "xorm.io/xorm/core" + "xorm.io/xorm/schemas" ) var ( @@ -141,45 +143,69 @@ var ( "WITH": true, "WITHOUT": true, } + + sqlite3Quoter = schemas.Quoter{ + Prefix: '`', + Suffix: '`', + IsReserved: schemas.AlwaysReserve, + } ) type sqlite3 struct { - core.Base + Base } -func (db *sqlite3) Init(d *core.DB, uri *core.Uri, drivername, dataSourceName string) error { - return db.Base.Init(d, db, uri, drivername, dataSourceName) +func (db *sqlite3) Init(uri *URI) error { + db.quoter = sqlite3Quoter + return db.Base.Init(db, uri) } -func (db *sqlite3) SqlType(c *core.Column) string { +func (db *sqlite3) SetQuotePolicy(quotePolicy QuotePolicy) { + switch quotePolicy { + case QuotePolicyNone: + var q = sqlite3Quoter + q.IsReserved = schemas.AlwaysNoReserve + db.quoter = q + case QuotePolicyReserved: + var q = sqlite3Quoter + q.IsReserved = db.IsReserved + db.quoter = q + case QuotePolicyAlways: + fallthrough + default: + db.quoter = sqlite3Quoter + } +} + +func (db *sqlite3) SQLType(c *schemas.Column) string { switch t := c.SQLType.Name; t { - case core.Bool: + case schemas.Bool: if c.Default == "true" { c.Default = "1" } else if c.Default == "false" { c.Default = "0" } - return core.Integer - case core.Date, core.DateTime, core.TimeStamp, core.Time: - return core.DateTime - case core.TimeStampz: - return core.Text - case core.Char, core.Varchar, core.NVarchar, core.TinyText, - core.Text, core.MediumText, core.LongText, core.Json: - return core.Text - case core.Bit, core.TinyInt, core.SmallInt, core.MediumInt, core.Int, core.Integer, core.BigInt: - return core.Integer - case core.Float, core.Double, core.Real: - return core.Real - case core.Decimal, core.Numeric: - return core.Numeric - case core.TinyBlob, core.Blob, core.MediumBlob, core.LongBlob, core.Bytea, core.Binary, core.VarBinary: - return core.Blob - case core.Serial, core.BigSerial: + return schemas.Integer + case schemas.Date, schemas.DateTime, schemas.TimeStamp, schemas.Time: + return schemas.DateTime + case schemas.TimeStampz: + return schemas.Text + case schemas.Char, schemas.Varchar, schemas.NVarchar, schemas.TinyText, + schemas.Text, schemas.MediumText, schemas.LongText, schemas.Json: + return schemas.Text + case schemas.Bit, schemas.TinyInt, schemas.SmallInt, schemas.MediumInt, schemas.Int, schemas.Integer, schemas.BigInt: + return schemas.Integer + case schemas.Float, schemas.Double, schemas.Real: + return schemas.Real + case schemas.Decimal, schemas.Numeric: + return schemas.Numeric + case schemas.TinyBlob, schemas.Blob, schemas.MediumBlob, schemas.LongBlob, schemas.Bytea, schemas.Binary, schemas.VarBinary: + return schemas.Blob + case schemas.Serial, schemas.BigSerial: c.IsPrimaryKey = true c.IsAutoIncrement = true c.Nullable = false - return core.Integer + return schemas.Integer default: return t } @@ -189,84 +215,97 @@ func (db *sqlite3) FormatBytes(bs []byte) string { return fmt.Sprintf("X'%x'", bs) } -func (db *sqlite3) SupportInsertMany() bool { - return true -} - func (db *sqlite3) IsReserved(name string) bool { - _, ok := sqlite3ReservedWords[name] + _, ok := sqlite3ReservedWords[strings.ToUpper(name)] return ok } -func (db *sqlite3) Quote(name string) string { - return "`" + name + "`" -} - func (db *sqlite3) AutoIncrStr() string { return "AUTOINCREMENT" } -func (db *sqlite3) SupportEngine() bool { - return false -} - -func (db *sqlite3) SupportCharset() bool { - return false -} - -func (db *sqlite3) IndexOnTable() bool { - return false -} - -func (db *sqlite3) IndexCheckSql(tableName, idxName string) (string, []interface{}) { +func (db *sqlite3) IndexCheckSQL(tableName, idxName string) (string, []interface{}) { args := []interface{}{idxName} return "SELECT name FROM sqlite_master WHERE type='index' and name = ?", args } -func (db *sqlite3) TableCheckSql(tableName string) (string, []interface{}) { - args := []interface{}{tableName} - return "SELECT name FROM sqlite_master WHERE type='table' and name = ?", args +func (db *sqlite3) IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error) { + return db.HasRecords(queryer, ctx, "SELECT name FROM sqlite_master WHERE type='table' and name = ?", tableName) } -func (db *sqlite3) DropIndexSql(tableName string, index *core.Index) string { +func (db *sqlite3) DropIndexSQL(tableName string, index *schemas.Index) string { // var unique string - quote := db.Quote idxName := index.Name if !strings.HasPrefix(idxName, "UQE_") && !strings.HasPrefix(idxName, "IDX_") { - if index.Type == core.UniqueType { + if index.Type == schemas.UniqueType { idxName = fmt.Sprintf("UQE_%v_%v", tableName, index.Name) } else { idxName = fmt.Sprintf("IDX_%v_%v", tableName, index.Name) } } - return fmt.Sprintf("DROP INDEX %v", quote(idxName)) + return fmt.Sprintf("DROP INDEX %v", db.Quoter().Quote(idxName)) } -func (db *sqlite3) ForUpdateSql(query string) string { +func (db *sqlite3) CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) { + var sql string + sql = "CREATE TABLE IF NOT EXISTS " + if tableName == "" { + tableName = table.Name + } + + quoter := db.Quoter() + sql += quoter.Quote(tableName) + sql += " (" + + if len(table.ColumnsSeq()) > 0 { + pkList := table.PrimaryKeys + + for _, colName := range table.ColumnsSeq() { + col := table.GetColumn(colName) + s, _ := ColumnString(db, col, col.IsPrimaryKey && len(pkList) == 1) + sql += s + sql = strings.TrimSpace(sql) + sql += ", " + } + + if len(pkList) > 1 { + sql += "PRIMARY KEY ( " + sql += quoter.Join(pkList, ",") + sql += " ), " + } + + sql = sql[:len(sql)-2] + } + sql += ")" + + return []string{sql}, true +} + +func (db *sqlite3) ForUpdateSQL(query string) string { return query } -/*func (db *sqlite3) ColumnCheckSql(tableName, colName string) (string, []interface{}) { - args := []interface{}{tableName} - sql := "SELECT name FROM sqlite_master WHERE type='table' and name = ? and ((sql like '%`" + colName + "`%') or (sql like '%[" + colName + "]%'))" - return sql, args -}*/ - -func (db *sqlite3) IsColumnExist(tableName, colName string) (bool, error) { - args := []interface{}{tableName} - query := "SELECT name FROM sqlite_master WHERE type='table' and name = ? and ((sql like '%`" + colName + "`%') or (sql like '%[" + colName + "]%'))" - db.LogSQL(query, args) - rows, err := db.DB().Query(query, args...) +func (db *sqlite3) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName, colName string) (bool, error) { + query := "SELECT * FROM " + tableName + " LIMIT 0" + rows, err := queryer.QueryContext(ctx, query) if err != nil { return false, err } defer rows.Close() - if rows.Next() { - return true, nil + cols, err := rows.Columns() + if err != nil { + return false, err } + + for _, col := range cols { + if strings.EqualFold(col, colName) { + return true, nil + } + } + return false, nil } @@ -298,9 +337,9 @@ func splitColStr(colStr string) []string { return results } -func parseString(colStr string) (*core.Column, error) { +func parseString(colStr string) (*schemas.Column, error) { fields := splitColStr(colStr) - col := new(core.Column) + col := new(schemas.Column) col.Indexes = make(map[string]int) col.Nullable = true col.DefaultIsEmpty = true @@ -310,7 +349,7 @@ func parseString(colStr string) (*core.Column, error) { col.Name = strings.Trim(strings.Trim(field, "`[] "), `"`) continue } else if idx == 1 { - col.SQLType = core.SQLType{Name: field, DefaultLength: 0, DefaultLength2: 0} + col.SQLType = schemas.SQLType{Name: field, DefaultLength: 0, DefaultLength2: 0} continue } switch field { @@ -332,11 +371,11 @@ func parseString(colStr string) (*core.Column, error) { return col, nil } -func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*core.Column, error) { +func (db *sqlite3) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { args := []interface{}{tableName} s := "SELECT sql FROM sqlite_master WHERE type='table' and name = ?" - db.LogSQL(s, args) - rows, err := db.DB().Query(s, args...) + + rows, err := queryer.QueryContext(ctx, s, args...) if err != nil { return nil, nil, err } @@ -359,7 +398,7 @@ func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*core.Colu nEnd := strings.LastIndex(name, ")") reg := regexp.MustCompile(`[^\(,\)]*(\([^\(]*\))?`) colCreates := reg.FindAllString(name[nStart+1:nEnd], -1) - cols := make(map[string]*core.Column) + cols := make(map[string]*schemas.Column) colSeq := make([]string, 0) for _, colStr := range colCreates { @@ -389,20 +428,19 @@ func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*core.Colu return colSeq, cols, nil } -func (db *sqlite3) GetTables() ([]*core.Table, error) { +func (db *sqlite3) GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) { args := []interface{}{} s := "SELECT name FROM sqlite_master WHERE type='table'" - db.LogSQL(s, args) - rows, err := db.DB().Query(s, args...) + rows, err := queryer.QueryContext(ctx, s, args...) if err != nil { return nil, err } defer rows.Close() - tables := make([]*core.Table, 0) + tables := make([]*schemas.Table, 0) for rows.Next() { - table := core.NewEmptyTable() + table := schemas.NewEmptyTable() err = rows.Scan(&table.Name) if err != nil { return nil, err @@ -415,18 +453,17 @@ func (db *sqlite3) GetTables() ([]*core.Table, error) { return tables, nil } -func (db *sqlite3) GetIndexes(tableName string) (map[string]*core.Index, error) { +func (db *sqlite3) GetIndexes(queryer core.Queryer, ctx context.Context, tableName string) (map[string]*schemas.Index, error) { args := []interface{}{tableName} s := "SELECT sql FROM sqlite_master WHERE type='index' and tbl_name = ?" - db.LogSQL(s, args) - rows, err := db.DB().Query(s, args...) + rows, err := queryer.QueryContext(ctx, s, args...) if err != nil { return nil, err } defer rows.Close() - indexes := make(map[string]*core.Index, 0) + indexes := make(map[string]*schemas.Index, 0) for rows.Next() { var tmpSQL sql.NullString err = rows.Scan(&tmpSQL) @@ -439,7 +476,7 @@ func (db *sqlite3) GetIndexes(tableName string) (map[string]*core.Index, error) } sql := tmpSQL.String - index := new(core.Index) + index := new(schemas.Index) nNStart := strings.Index(sql, "INDEX") nNEnd := strings.Index(sql, "ON") if nNStart == -1 || nNEnd == -1 { @@ -456,9 +493,9 @@ func (db *sqlite3) GetIndexes(tableName string) (map[string]*core.Index, error) } if strings.HasPrefix(sql, "CREATE UNIQUE INDEX") { - index.Type = core.UniqueType + index.Type = schemas.UniqueType } else { - index.Type = core.IndexType + index.Type = schemas.IndexType } nStart := strings.Index(sql, "(") @@ -476,17 +513,17 @@ func (db *sqlite3) GetIndexes(tableName string) (map[string]*core.Index, error) return indexes, nil } -func (db *sqlite3) Filters() []core.Filter { - return []core.Filter{&core.IdFilter{}} +func (db *sqlite3) Filters() []Filter { + return []Filter{} } type sqlite3Driver struct { } -func (p *sqlite3Driver) Parse(driverName, dataSourceName string) (*core.Uri, error) { +func (p *sqlite3Driver) Parse(driverName, dataSourceName string) (*URI, error) { if strings.Contains(dataSourceName, "?") { dataSourceName = dataSourceName[:strings.Index(dataSourceName, "?")] } - return &core.Uri{DbType: core.SQLITE, DbName: dataSourceName}, nil + return &URI{DBType: schemas.SQLITE, DBName: dataSourceName}, nil } diff --git a/dialect_sqlite3_test.go b/dialects/sqlite3_test.go similarity index 97% rename from dialect_sqlite3_test.go rename to dialects/sqlite3_test.go index a2036159..aa6c3cea 100644 --- a/dialect_sqlite3_test.go +++ b/dialects/sqlite3_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package dialects import ( "testing" diff --git a/dialects/table_name.go b/dialects/table_name.go new file mode 100644 index 00000000..e190cd4b --- /dev/null +++ b/dialects/table_name.go @@ -0,0 +1,89 @@ +// Copyright 2015 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package dialects + +import ( + "fmt" + "reflect" + "strings" + + "xorm.io/xorm/internal/utils" + "xorm.io/xorm/names" +) + +// TableNameWithSchema will add schema prefix on table name if possible +func TableNameWithSchema(dialect Dialect, tableName string) string { + // Add schema name as prefix of table name. + // Only for postgres database. + if dialect.URI().Schema != "" && + strings.Index(tableName, ".") == -1 { + return fmt.Sprintf("%s.%s", dialect.URI().Schema, tableName) + } + return tableName +} + +// TableNameNoSchema returns table name with given tableName +func TableNameNoSchema(dialect Dialect, mapper names.Mapper, tableName interface{}) string { + quote := dialect.Quoter().Quote + switch tableName.(type) { + case []string: + t := tableName.([]string) + if len(t) > 1 { + return fmt.Sprintf("%v AS %v", quote(t[0]), quote(t[1])) + } else if len(t) == 1 { + return quote(t[0]) + } + case []interface{}: + t := tableName.([]interface{}) + l := len(t) + var table string + if l > 0 { + f := t[0] + switch f.(type) { + case string: + table = f.(string) + case names.TableName: + table = f.(names.TableName).TableName() + default: + v := utils.ReflectValue(f) + t := v.Type() + if t.Kind() == reflect.Struct { + table = names.GetTableName(mapper, v) + } else { + table = quote(fmt.Sprintf("%v", f)) + } + } + } + if l > 1 { + return fmt.Sprintf("%v AS %v", quote(table), quote(fmt.Sprintf("%v", t[1]))) + } else if l == 1 { + return quote(table) + } + case names.TableName: + return tableName.(names.TableName).TableName() + case string: + return tableName.(string) + case reflect.Value: + v := tableName.(reflect.Value) + return names.GetTableName(mapper, v) + default: + v := utils.ReflectValue(tableName) + t := v.Type() + if t.Kind() == reflect.Struct { + return names.GetTableName(mapper, v) + } + return quote(fmt.Sprintf("%v", tableName)) + } + return "" +} + +// FullTableName returns table name with quote and schema according parameter +func FullTableName(dialect Dialect, mapper names.Mapper, bean interface{}, includeSchema ...bool) string { + tbName := TableNameNoSchema(dialect, mapper, bean) + if len(includeSchema) > 0 && includeSchema[0] && !utils.IsSubQuery(tbName) { + tbName = TableNameWithSchema(dialect, tbName) + } + return tbName +} diff --git a/engine_table_test.go b/dialects/table_name_test.go similarity index 60% rename from engine_table_test.go rename to dialects/table_name_test.go index 8f2300aa..66edc2b4 100644 --- a/engine_table_test.go +++ b/dialects/table_name_test.go @@ -2,11 +2,13 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package dialects import ( "testing" + "xorm.io/xorm/names" + "github.com/stretchr/testify/assert" ) @@ -20,9 +22,9 @@ func (mcc *MCC) TableName() string { return "mcc" } -func TestTableName1(t *testing.T) { - assert.NoError(t, prepareEngine()) +func TestFullTableName(t *testing.T) { + dialect := QueryDialect("mysql") - assert.EqualValues(t, "mcc", testEngine.TableName(new(MCC))) - assert.EqualValues(t, "mcc", testEngine.TableName("mcc")) + assert.EqualValues(t, "mcc", FullTableName(dialect, names.SnakeMapper{}, &MCC{})) + assert.EqualValues(t, "mcc", FullTableName(dialect, names.SnakeMapper{}, "mcc")) } diff --git a/dialects/time.go b/dialects/time.go new file mode 100644 index 00000000..b0394745 --- /dev/null +++ b/dialects/time.go @@ -0,0 +1,49 @@ +// Copyright 2015 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package dialects + +import ( + "time" + + "xorm.io/xorm/schemas" +) + +// FormatTime format time as column type +func FormatTime(dialect Dialect, sqlTypeName string, t time.Time) (v interface{}) { + switch sqlTypeName { + case schemas.Time: + s := t.Format("2006-01-02 15:04:05") // time.RFC3339 + v = s[11:19] + case schemas.Date: + v = t.Format("2006-01-02") + case schemas.DateTime, schemas.TimeStamp, schemas.Varchar: // !DarthPestilane! format time when sqlTypeName is schemas.Varchar. + v = t.Format("2006-01-02 15:04:05") + case schemas.TimeStampz: + if dialect.URI().DBType == schemas.MSSQL { + v = t.Format("2006-01-02T15:04:05.9999999Z07:00") + } else { + v = t.Format(time.RFC3339Nano) + } + case schemas.BigInt, schemas.Int: + v = t.Unix() + default: + v = t + } + return +} + +func FormatColumnTime(dialect Dialect, defaultTimeZone *time.Location, col *schemas.Column, t time.Time) (v interface{}) { + if t.IsZero() { + if col.Nullable { + return nil + } + return "" + } + + if col.TimeZone != nil { + return FormatTime(dialect, col.SQLType.Name, t.In(col.TimeZone)) + } + return FormatTime(dialect, col.SQLType.Name, t.In(defaultTimeZone)) +} diff --git a/doc.go b/doc.go index 9620bca1..ea6a2226 100644 --- a/doc.go +++ b/doc.go @@ -8,7 +8,7 @@ Package xorm is a simple and powerful ORM for Go. Installation -Make sure you have installed Go 1.6+ and then: +Make sure you have installed Go 1.11+ and then: go get xorm.io/xorm @@ -126,7 +126,7 @@ Attention: the above 8 methods should be the last chainable method. engine.ID(1).Get(&user) // for single primary key // SELECT * FROM user WHERE id = 1 - engine.ID(core.PK{1, 2}).Get(&user) // for composite primary keys + engine.ID(schemas.PK{1, 2}).Get(&user) // for composite primary keys // SELECT * FROM user WHERE id1 = 1 AND id2 = 2 engine.In("id", 1, 2, 3).Find(&users) // SELECT * FROM user WHERE id IN (1, 2, 3) diff --git a/docs/images/cache_design.graffle b/docs/images/cache_design.graffle deleted file mode 100644 index 5b7c487b..00000000 --- a/docs/images/cache_design.graffle +++ /dev/null @@ -1,2295 +0,0 @@ - - - - - ActiveLayerIndex - 0 - ApplicationVersion - - com.omnigroup.OmniGrafflePro - 139.16.0.171715 - - AutoAdjust - - BackgroundGraphic - - Bounds - {{0, 0}, {771, 554.18930041152259}} - Class - SolidGraphic - ID - 2 - Style - - fill - - Color - - b - 0.989303 - g - 0.907286 - r - 0.795377 - - FillType - 2 - GradientAngle - 78 - GradientColor - - b - 1 - g - 0.854588 - r - 0.623912 - - MiddleColor - - b - 1 - g - 0.856844 - r - 0.43695 - - TrippleBlend - YES - - shadow - - Draws - NO - - stroke - - Draws - NO - - - - BaseZoom - 0 - CanvasOrigin - {0, 0} - CanvasSize - {771, 554.18930041152259} - ColumnAlign - 1 - ColumnSpacing - 36 - CreationDate - 2013-09-29 07:57:57 +0000 - Creator - Lunny Xiao - DisplayScale - 1.000 cm = 1.000 cm - FileType - flat - GraphDocumentVersion - 8 - GraphicsList - - - Bounds - {{409.89504441572683, 415.64570506990464}, {104.42639923095703, 79.447883605957031}} - Class - ShapedGraphic - FontInfo - - Color - - b - 0.8 - g - 0.8 - r - 0.8 - - Font - Verdana - Size - 18 - - ID - 30 - Shape - Rectangle - Style - - fill - - Color - - b - 0.6 - g - 0.6 - r - 0.6 - - FillType - 2 - GradientAngle - 90 - GradientColor - - b - 0.4 - g - 0.4 - r - 0.4 - - - shadow - - Beneath - YES - Color - - a - 0.15 - b - 0 - g - 0 - r - 0 - - Fuzziness - 0.0 - ShadowVector - {2, 2} - - stroke - - Color - - b - 0.590997 - g - 0.18677 - r - 0.567819 - - Draws - NO - Width - 2 - - - Text - - Text - {\rtf1\ansi\ansicpg936\cocoartf1187\cocoasubrtf390 -\cocoascreenfonts1{\fonttbl\f0\fnil\fcharset0 Verdana;} -{\colortbl;\red255\green255\blue255;\red204\green204\blue204;} -\pard\tx560\tx1120\tx1680\tx2240\tx2800\tx3360\tx3920\tx4480\tx5040\tx5600\tx6160\tx6720\pardirnatural\qc - -\f0\fs36 \cf2 .\ -.\ -.\ -} - VerticalPad - 0 - - TextRelativeArea - {{0.10000000000000001, 0.14999999999999999}, {0.80000000000000004, 0.69999999999999996}} - - - Bounds - {{276.44083898205287, 413.07252538018992}, {112.36092376708984, 79.447883605957031}} - Class - ShapedGraphic - FontInfo - - Color - - archive - - YnBsaXN0MDDUAQIDBAUGBwpZJGFyY2hpdmVy - WCR2ZXJzaW9uVCR0b3BYJG9iamVjdHNfEA9O - U0tleWVkQXJjaGl2ZXISAAGGoNEICVRyb290 - gAGlCwwVGR5VJG51bGzUDQ4PEBESExRfEBJO - U0N1c3RvbUNvbG9yU3BhY2VXTlNXaGl0ZVxO - U0NvbG9yU3BhY2VWJGNsYXNzgAJCMAAQA4AE - 0hYQFxhUTlNJRBACgAPSGhscD1gkY2xhc3Nl - c1okY2xhc3NuYW1log8dWE5TT2JqZWN00hob - HyCiIB1XTlNDb2xvcggRGyQpMkRJTFFTWV9o - fYWSmZueoKKnrK6wtb7JzNXa3QAAAAAAAAEB - AAAAAAAAACEAAAAAAAAAAAAAAAAAAADl - - b - 0 - g - 0 - r - 0 - - Font - Verdana - Size - 18 - - ID - 29 - Shape - Rectangle - Style - - fill - - Color - - b - 0.776486 - g - 0.588495 - r - 0.670497 - - FillType - 2 - GradientAngle - 90 - GradientColor - - b - 0.618021 - g - 0.412924 - r - 0.50312 - - - shadow - - Beneath - YES - Color - - a - 0.15 - b - 0 - g - 0 - r - 0 - - Fuzziness - 0.0 - ShadowVector - {2, 2} - - stroke - - Color - - b - 0.590997 - g - 0.18677 - r - 0.567819 - - Draws - NO - Width - 2 - - - Text - - Text - {\rtf1\ansi\ansicpg936\cocoartf1187\cocoasubrtf390 -\cocoascreenfonts1{\fonttbl\f0\fnil\fcharset0 Verdana;} -{\colortbl;\red255\green255\blue255;} -\pard\tx560\tx1120\tx1680\tx2240\tx2800\tx3360\tx3920\tx4480\tx5040\tx5600\tx6160\tx6720\pardirnatural\qc - -\f0\fs36 \cf0 .\ -.\ -.} - VerticalPad - 0 - - TextRelativeArea - {{0.10000000000000001, 0.14999999999999999}, {0.80000000000000004, 0.69999999999999996}} - - - Bounds - {{409.89504441572689, 337.17180246145824}, {104.42639923095703, 51}} - Class - ShapedGraphic - FontInfo - - Color - - b - 0.8 - g - 0.8 - r - 0.8 - - Font - Verdana - Size - 18 - - ID - 28 - Shape - Rectangle - Style - - fill - - Color - - b - 0.6 - g - 0.6 - r - 0.6 - - FillType - 2 - GradientAngle - 90 - GradientColor - - b - 0.4 - g - 0.4 - r - 0.4 - - - shadow - - Beneath - YES - Color - - a - 0.15 - b - 0 - g - 0 - r - 0 - - Fuzziness - 0.0 - ShadowVector - {2, 2} - - stroke - - Color - - b - 0.590997 - g - 0.18677 - r - 0.567819 - - Draws - NO - Width - 2 - - - Text - - Text - {\rtf1\ansi\ansicpg936\cocoartf1187\cocoasubrtf390 -\cocoascreenfonts1{\fonttbl\f0\fnil\fcharset0 Verdana;} -{\colortbl;\red255\green255\blue255;\red204\green204\blue204;} -\pard\tx560\tx1120\tx1680\tx2240\tx2800\tx3360\tx3920\tx4480\tx5040\tx5600\tx6160\tx6720\pardirnatural\qc - -\f0\fs36 \cf2 user-2:User\{\}} - VerticalPad - 0 - - TextRelativeArea - {{0.10000000000000001, 0.14999999999999999}, {0.80000000000000004, 0.69999999999999996}} - - - Bounds - {{274.32251833907753, 322.94787397618234}, {112.36092376708984, 79.447883605957031}} - Class - ShapedGraphic - FontInfo - - Color - - archive - - YnBsaXN0MDDUAQIDBAUGBwpZJGFyY2hpdmVy - WCR2ZXJzaW9uVCR0b3BYJG9iamVjdHNfEA9O - U0tleWVkQXJjaGl2ZXISAAGGoNEICVRyb290 - gAGlCwwVGR5VJG51bGzUDQ4PEBESExRfEBJO - U0N1c3RvbUNvbG9yU3BhY2VXTlNXaGl0ZVxO - U0NvbG9yU3BhY2VWJGNsYXNzgAJCMAAQA4AE - 0hYQFxhUTlNJRBACgAPSGhscD1gkY2xhc3Nl - c1okY2xhc3NuYW1log8dWE5TT2JqZWN00hob - HyCiIB1XTlNDb2xvcggRGyQpMkRJTFFTWV9o - fYWSmZueoKKnrK6wtb7JzNXa3QAAAAAAAAEB - AAAAAAAAACEAAAAAAAAAAAAAAAAAAADl - - b - 0 - g - 0 - r - 0 - - Font - Verdana - Size - 18 - - ID - 27 - Shape - Rectangle - Style - - fill - - Color - - b - 0.776486 - g - 0.588495 - r - 0.670497 - - FillType - 2 - GradientAngle - 90 - GradientColor - - b - 0.618021 - g - 0.412924 - r - 0.50312 - - - shadow - - Beneath - YES - Color - - a - 0.15 - b - 0 - g - 0 - r - 0 - - Fuzziness - 0.0 - ShadowVector - {2, 2} - - stroke - - Color - - b - 0.590997 - g - 0.18677 - r - 0.567819 - - Draws - NO - Width - 2 - - - Text - - Text - {\rtf1\ansi\ansicpg936\cocoartf1187\cocoasubrtf390 -\cocoascreenfonts1{\fonttbl\f0\fnil\fcharset0 Verdana;} -{\colortbl;\red255\green255\blue255;} -\pard\tx560\tx1120\tx1680\tx2240\tx2800\tx3360\tx3920\tx4480\tx5040\tx5600\tx6160\tx6720\pardirnatural\qc - -\f0\fs36 \cf0 select id from tb3:[2,5]} - VerticalPad - 0 - - TextRelativeArea - {{0.10000000000000001, 0.14999999999999999}, {0.80000000000000004, 0.69999999999999996}} - - - Bounds - {{406.08888702072045, 256.42244420026987}, {104.42639923095703, 51}} - Class - ShapedGraphic - FontInfo - - Color - - b - 0.8 - g - 0.8 - r - 0.8 - - Font - Verdana - Size - 18 - - ID - 25 - Shape - Rectangle - Style - - fill - - Color - - b - 0.6 - g - 0.6 - r - 0.6 - - FillType - 2 - GradientAngle - 90 - GradientColor - - b - 0.4 - g - 0.4 - r - 0.4 - - - shadow - - Beneath - YES - Color - - a - 0.15 - b - 0 - g - 0 - r - 0 - - Fuzziness - 0.0 - ShadowVector - {2, 2} - - stroke - - Color - - b - 0.590997 - g - 0.18677 - r - 0.567819 - - Draws - NO - Width - 2 - - - Text - - Text - {\rtf1\ansi\ansicpg936\cocoartf1187\cocoasubrtf390 -\cocoascreenfonts1{\fonttbl\f0\fnil\fcharset0 Verdana;} -{\colortbl;\red255\green255\blue255;\red204\green204\blue204;} -\pard\tx560\tx1120\tx1680\tx2240\tx2800\tx3360\tx3920\tx4480\tx5040\tx5600\tx6160\tx6720\pardirnatural\qc - -\f0\fs36 \cf2 user-2:User\{\}} - VerticalPad - 0 - - TextRelativeArea - {{0.10000000000000001, 0.14999999999999999}, {0.80000000000000004, 0.69999999999999996}} - - - Bounds - {{406.08888302813585, 187.47137690331695}, {104.42639923095703, 51}} - Class - ShapedGraphic - FontInfo - - Color - - b - 0.8 - g - 0.8 - r - 0.8 - - Font - Verdana - Size - 18 - - ID - 24 - Shape - Rectangle - Style - - fill - - Color - - b - 0.6 - g - 0.6 - r - 0.6 - - FillType - 2 - GradientAngle - 90 - GradientColor - - b - 0.4 - g - 0.4 - r - 0.4 - - - shadow - - Beneath - YES - Color - - a - 0.15 - b - 0 - g - 0 - r - 0 - - Fuzziness - 0.0 - ShadowVector - {2, 2} - - stroke - - Color - - b - 0.590997 - g - 0.18677 - r - 0.567819 - - Draws - NO - Width - 2 - - - Text - - Text - {\rtf1\ansi\ansicpg936\cocoartf1187\cocoasubrtf390 -\cocoascreenfonts1{\fonttbl\f0\fnil\fcharset0 Verdana;} -{\colortbl;\red255\green255\blue255;\red204\green204\blue204;} -\pard\tx560\tx1120\tx1680\tx2240\tx2800\tx3360\tx3920\tx4480\tx5040\tx5600\tx6160\tx6720\pardirnatural\qc - -\f0\fs36 \cf2 table-1:Table\{\}} - VerticalPad - 0 - - TextRelativeArea - {{0.10000000000000001, 0.14999999999999999}, {0.80000000000000004, 0.69999999999999996}} - - - Bounds - {{406.08887903555114, 118.52029512620169}, {104.42639923095703, 51}} - Class - ShapedGraphic - FontInfo - - Color - - b - 0.8 - g - 0.8 - r - 0.8 - - Font - Verdana - Size - 18 - - ID - 23 - Shape - Rectangle - Style - - fill - - Color - - b - 0.6 - g - 0.6 - r - 0.6 - - FillType - 2 - GradientAngle - 90 - GradientColor - - b - 0.4 - g - 0.4 - r - 0.4 - - - shadow - - Beneath - YES - Color - - a - 0.15 - b - 0 - g - 0 - r - 0 - - Fuzziness - 0.0 - ShadowVector - {2, 2} - - stroke - - Color - - b - 0.590997 - g - 0.18677 - r - 0.567819 - - Draws - NO - Width - 2 - - - Text - - Text - {\rtf1\ansi\ansicpg936\cocoartf1187\cocoasubrtf390 -\cocoascreenfonts1{\fonttbl\f0\fnil\fcharset0 Verdana;} -{\colortbl;\red255\green255\blue255;\red204\green204\blue204;} -\pard\tx560\tx1120\tx1680\tx2240\tx2800\tx3360\tx3920\tx4480\tx5040\tx5600\tx6160\tx6720\pardirnatural\qc - -\f0\fs36 \cf2 user-1:User\{\}} - VerticalPad - 0 - - TextRelativeArea - {{0.10000000000000001, 0.14999999999999999}, {0.80000000000000004, 0.69999999999999996}} - - - Bounds - {{556.54055354053583, 325.93280718390133}, {124.41892177446698, 51}} - Class - ShapedGraphic - FontInfo - - Color - - archive - - YnBsaXN0MDDUAQIDBAUGBwpZJGFyY2hpdmVy - WCR2ZXJzaW9uVCR0b3BYJG9iamVjdHNfEA9O - U0tleWVkQXJjaGl2ZXISAAGGoNEICVRyb290 - gAGlCwwVGR5VJG51bGzUDQ4PEBESExRfEBJO - U0N1c3RvbUNvbG9yU3BhY2VXTlNXaGl0ZVxO - U0NvbG9yU3BhY2VWJGNsYXNzgAJCMAAQA4AE - 0hYQFxhUTlNJRBACgAPSGhscD1gkY2xhc3Nl - c1okY2xhc3NuYW1log8dWE5TT2JqZWN00hob - HyCiIB1XTlNDb2xvcggRGyQpMkRJTFFTWV9o - fYWSmZueoKKnrK6wtb7JzNXa3QAAAAAAAAEB - AAAAAAAAACEAAAAAAAAAAAAAAAAAAADl - - b - 0 - g - 0 - r - 0 - - Font - Verdana - Size - 18 - - ID - 22 - Shape - Rectangle - Style - - fill - - Color - - b - 0.793851 - g - 0.625208 - r - 0.562982 - - FillType - 2 - GradientAngle - 90 - GradientColor - - b - 0.639673 - g - 0.450584 - r - 0.381079 - - - shadow - - Beneath - YES - Color - - a - 0.15 - b - 0 - g - 0 - r - 0 - - Fuzziness - 0.0 - ShadowVector - {2, 2} - - stroke - - Color - - b - 0.511421 - g - 0.637255 - r - 0.120867 - - Draws - NO - Width - 2 - - - Text - - Text - {\rtf1\ansi\ansicpg936\cocoartf1187\cocoasubrtf390 -\cocoascreenfonts1{\fonttbl\f0\fnil\fcharset0 Verdana;} -{\colortbl;\red255\green255\blue255;} -\pard\tx560\tx1120\tx1680\tx2240\tx2800\tx3360\tx3920\tx4480\tx5040\tx5600\tx6160\tx6720\pardirnatural\qc - -\f0\fs36 \cf0 Del(k, v)} - VerticalPad - 0 - - TextRelativeArea - {{0.10000000000000001, 0.14999999999999999}, {0.80000000000000004, 0.69999999999999996}} - - - Bounds - {{556.54055354053583, 240.81081643580876}, {124.41892177446698, 51}} - Class - ShapedGraphic - FontInfo - - Color - - archive - - YnBsaXN0MDDUAQIDBAUGBwpZJGFyY2hpdmVy - WCR2ZXJzaW9uVCR0b3BYJG9iamVjdHNfEA9O - U0tleWVkQXJjaGl2ZXISAAGGoNEICVRyb290 - gAGlCwwVGR5VJG51bGzUDQ4PEBESExRfEBJO - U0N1c3RvbUNvbG9yU3BhY2VXTlNXaGl0ZVxO - U0NvbG9yU3BhY2VWJGNsYXNzgAJCMAAQA4AE - 0hYQFxhUTlNJRBACgAPSGhscD1gkY2xhc3Nl - c1okY2xhc3NuYW1log8dWE5TT2JqZWN00hob - HyCiIB1XTlNDb2xvcggRGyQpMkRJTFFTWV9o - fYWSmZueoKKnrK6wtb7JzNXa3QAAAAAAAAEB - AAAAAAAAACEAAAAAAAAAAAAAAAAAAADl - - b - 0 - g - 0 - r - 0 - - Font - Verdana - Size - 18 - - ID - 21 - Shape - Rectangle - Style - - fill - - Color - - b - 0.793851 - g - 0.625208 - r - 0.562982 - - FillType - 2 - GradientAngle - 90 - GradientColor - - b - 0.639673 - g - 0.450584 - r - 0.381079 - - - shadow - - Beneath - YES - Color - - a - 0.15 - b - 0 - g - 0 - r - 0 - - Fuzziness - 0.0 - ShadowVector - {2, 2} - - stroke - - Color - - b - 0.511421 - g - 0.637255 - r - 0.120867 - - Draws - NO - Width - 2 - - - Text - - Text - {\rtf1\ansi\ansicpg936\cocoartf1187\cocoasubrtf390 -\cocoascreenfonts1{\fonttbl\f0\fnil\fcharset0 Verdana;} -{\colortbl;\red255\green255\blue255;} -\pard\tx560\tx1120\tx1680\tx2240\tx2800\tx3360\tx3920\tx4480\tx5040\tx5600\tx6160\tx6720\pardirnatural\qc - -\f0\fs36 \cf0 Put(k, v)} - VerticalPad - 0 - - TextRelativeArea - {{0.10000000000000001, 0.14999999999999999}, {0.80000000000000004, 0.69999999999999996}} - - - Bounds - {{556.54054526129187, 150.52220626433376}, {124.41893005371094, 51}} - Class - ShapedGraphic - FontInfo - - Color - - archive - - YnBsaXN0MDDUAQIDBAUGBwpZJGFyY2hpdmVy - WCR2ZXJzaW9uVCR0b3BYJG9iamVjdHNfEA9O - U0tleWVkQXJjaGl2ZXISAAGGoNEICVRyb290 - gAGlCwwVGR5VJG51bGzUDQ4PEBESExRfEBJO - U0N1c3RvbUNvbG9yU3BhY2VXTlNXaGl0ZVxO - U0NvbG9yU3BhY2VWJGNsYXNzgAJCMAAQA4AE - 0hYQFxhUTlNJRBACgAPSGhscD1gkY2xhc3Nl - c1okY2xhc3NuYW1log8dWE5TT2JqZWN00hob - HyCiIB1XTlNDb2xvcggRGyQpMkRJTFFTWV9o - fYWSmZueoKKnrK6wtb7JzNXa3QAAAAAAAAEB - AAAAAAAAACEAAAAAAAAAAAAAAAAAAADl - - b - 0 - g - 0 - r - 0 - - Font - Verdana - Size - 18 - - ID - 20 - Shape - Rectangle - Style - - fill - - Color - - b - 0.793851 - g - 0.625208 - r - 0.562982 - - FillType - 2 - GradientAngle - 90 - GradientColor - - b - 0.639673 - g - 0.450584 - r - 0.381079 - - - shadow - - Beneath - YES - Color - - a - 0.15 - b - 0 - g - 0 - r - 0 - - Fuzziness - 0.0 - ShadowVector - {2, 2} - - stroke - - Color - - b - 0.511421 - g - 0.637255 - r - 0.120867 - - Draws - NO - Width - 2 - - - Text - - Text - {\rtf1\ansi\ansicpg936\cocoartf1187\cocoasubrtf390 -\cocoascreenfonts1{\fonttbl\f0\fnil\fcharset0 Verdana;} -{\colortbl;\red255\green255\blue255;} -\pard\tx560\tx1120\tx1680\tx2240\tx2800\tx3360\tx3920\tx4480\tx5040\tx5600\tx6160\tx6720\pardirnatural\qc - -\f0\fs36 \cf0 Get(k, v)} - VerticalPad - 0 - - TextRelativeArea - {{0.10000000000000001, 0.14999999999999999}, {0.80000000000000004, 0.69999999999999996}} - - - Bounds - {{276.44083898205287, 214.72973288487913}, {112.36092376708984, 79.447883605957031}} - Class - ShapedGraphic - FontInfo - - Color - - archive - - YnBsaXN0MDDUAQIDBAUGBwpZJGFyY2hpdmVy - WCR2ZXJzaW9uVCR0b3BYJG9iamVjdHNfEA9O - U0tleWVkQXJjaGl2ZXISAAGGoNEICVRyb290 - gAGlCwwVGR5VJG51bGzUDQ4PEBESExRfEBJO - U0N1c3RvbUNvbG9yU3BhY2VXTlNXaGl0ZVxO - U0NvbG9yU3BhY2VWJGNsYXNzgAJCMAAQA4AE - 0hYQFxhUTlNJRBACgAPSGhscD1gkY2xhc3Nl - c1okY2xhc3NuYW1log8dWE5TT2JqZWN00hob - HyCiIB1XTlNDb2xvcggRGyQpMkRJTFFTWV9o - fYWSmZueoKKnrK6wtb7JzNXa3QAAAAAAAAEB - AAAAAAAAACEAAAAAAAAAAAAAAAAAAADl - - b - 0 - g - 0 - r - 0 - - Font - Verdana - Size - 18 - - ID - 19 - Shape - Rectangle - Style - - fill - - Color - - b - 0.776486 - g - 0.588495 - r - 0.670497 - - FillType - 2 - GradientAngle - 90 - GradientColor - - b - 0.618021 - g - 0.412924 - r - 0.50312 - - - shadow - - Beneath - YES - Color - - a - 0.15 - b - 0 - g - 0 - r - 0 - - Fuzziness - 0.0 - ShadowVector - {2, 2} - - stroke - - Color - - b - 0.590997 - g - 0.18677 - r - 0.567819 - - Draws - NO - Width - 2 - - - Text - - Text - {\rtf1\ansi\ansicpg936\cocoartf1187\cocoasubrtf390 -\cocoascreenfonts1{\fonttbl\f0\fnil\fcharset0 Verdana;} -{\colortbl;\red255\green255\blue255;} -\pard\tx560\tx1120\tx1680\tx2240\tx2800\tx3360\tx3920\tx4480\tx5040\tx5600\tx6160\tx6720\pardirnatural\qc - -\f0\fs36 \cf0 select id from tb2:[2,5]} - VerticalPad - 0 - - TextRelativeArea - {{0.10000000000000001, 0.14999999999999999}, {0.80000000000000004, 0.69999999999999996}} - - - Bounds - {{274.32251963877655, 117.18245433711769}, {112.36092376708984, 68.777008056640625}} - Class - ShapedGraphic - FontInfo - - Color - - archive - - YnBsaXN0MDDUAQIDBAUGBwpZJGFyY2hpdmVy - WCR2ZXJzaW9uVCR0b3BYJG9iamVjdHNfEA9O - U0tleWVkQXJjaGl2ZXISAAGGoNEICVRyb290 - gAGlCwwVGR5VJG51bGzUDQ4PEBESExRfEBJO - U0N1c3RvbUNvbG9yU3BhY2VXTlNXaGl0ZVxO - U0NvbG9yU3BhY2VWJGNsYXNzgAJCMAAQA4AE - 0hYQFxhUTlNJRBACgAPSGhscD1gkY2xhc3Nl - c1okY2xhc3NuYW1log8dWE5TT2JqZWN00hob - HyCiIB1XTlNDb2xvcggRGyQpMkRJTFFTWV9o - fYWSmZueoKKnrK6wtb7JzNXa3QAAAAAAAAEB - AAAAAAAAACEAAAAAAAAAAAAAAAAAAADl - - b - 0 - g - 0 - r - 0 - - Font - Verdana - Size - 18 - - ID - 18 - Shape - Rectangle - Style - - fill - - Color - - b - 0.776486 - g - 0.588495 - r - 0.670497 - - FillType - 2 - GradientAngle - 90 - GradientColor - - b - 0.618021 - g - 0.412924 - r - 0.50312 - - - shadow - - Beneath - YES - Color - - a - 0.15 - b - 0 - g - 0 - r - 0 - - Fuzziness - 0.0 - ShadowVector - {2, 2} - - stroke - - Color - - b - 0.590997 - g - 0.18677 - r - 0.567819 - - Draws - NO - Width - 2 - - - Text - - Text - {\rtf1\ansi\ansicpg936\cocoartf1187\cocoasubrtf390 -\cocoascreenfonts1{\fonttbl\f0\fnil\fcharset0 Verdana;} -{\colortbl;\red255\green255\blue255;} -\pard\tx560\tx1120\tx1680\tx2240\tx2800\tx3360\tx3920\tx4480\tx5040\tx5600\tx6160\tx6720\pardirnatural\qc - -\f0\fs36 \cf0 select id from tb1:[1,2,3]} - VerticalPad - 0 - - TextRelativeArea - {{0.10000000000000001, 0.14999999999999999}, {0.80000000000000004, 0.69999999999999996}} - - - Bounds - {{103.00000194954862, 30.108108889738791}, {80, 51}} - Class - ShapedGraphic - FontInfo - - Color - - archive - - YnBsaXN0MDDUAQIDBAUGBwpZJGFyY2hpdmVy - WCR2ZXJzaW9uVCR0b3BYJG9iamVjdHNfEA9O - U0tleWVkQXJjaGl2ZXISAAGGoNEICVRyb290 - gAGlCwwVGR5VJG51bGzUDQ4PEBESExRfEBJO - U0N1c3RvbUNvbG9yU3BhY2VXTlNXaGl0ZVxO - U0NvbG9yU3BhY2VWJGNsYXNzgAJCMAAQA4AE - 0hYQFxhUTlNJRBACgAPSGhscD1gkY2xhc3Nl - c1okY2xhc3NuYW1log8dWE5TT2JqZWN00hob - HyCiIB1XTlNDb2xvcggRGyQpMkRJTFFTWV9o - fYWSmZueoKKnrK6wtb7JzNXa3QAAAAAAAAEB - AAAAAAAAACEAAAAAAAAAAAAAAAAAAADl - - b - 0 - g - 0 - r - 0 - - Font - Verdana - Size - 18 - - ID - 14 - Shape - Rectangle - Style - - fill - - Color - - b - 0.776486 - g - 0.588495 - r - 0.670497 - - FillType - 2 - GradientAngle - 90 - GradientColor - - b - 0.618021 - g - 0.412924 - r - 0.50312 - - - shadow - - Beneath - YES - Color - - a - 0.15 - b - 0 - g - 0 - r - 0 - - Fuzziness - 0.0 - ShadowVector - {2, 2} - - stroke - - Color - - b - 0.590997 - g - 0.18677 - r - 0.567819 - - Draws - NO - Width - 2 - - - Text - - Text - {\rtf1\ansi\ansicpg936\cocoartf1187\cocoasubrtf390 -\cocoascreenfonts1{\fonttbl\f0\fnil\fcharset134 STHeitiSC-Light;} -{\colortbl;\red255\green255\blue255;} -\pard\tx560\tx1120\tx1680\tx2240\tx2800\tx3360\tx3920\tx4480\tx5040\tx5600\tx6160\tx6720\pardirnatural\qc - -\f0\fs36 \cf0 SQL} - VerticalPad - 0 - - TextRelativeArea - {{0.10000000000000001, 0.14999999999999999}, {0.80000000000000004, 0.69999999999999996}} - - - Bounds - {{532.04631600645348, 25.096524339927051}, {166.93052673339844, 79.447883605957031}} - Class - ShapedGraphic - FontInfo - - Color - - archive - - YnBsaXN0MDDUAQIDBAUGBwpZJGFyY2hpdmVy - WCR2ZXJzaW9uVCR0b3BYJG9iamVjdHNfEA9O - U0tleWVkQXJjaGl2ZXISAAGGoNEICVRyb290 - gAGlCwwVGR5VJG51bGzUDQ4PEBESExRfEBJO - U0N1c3RvbUNvbG9yU3BhY2VXTlNXaGl0ZVxO - U0NvbG9yU3BhY2VWJGNsYXNzgAJCMAAQA4AE - 0hYQFxhUTlNJRBACgAPSGhscD1gkY2xhc3Nl - c1okY2xhc3NuYW1log8dWE5TT2JqZWN00hob - HyCiIB1XTlNDb2xvcggRGyQpMkRJTFFTWV9o - fYWSmZueoKKnrK6wtb7JzNXa3QAAAAAAAAEB - AAAAAAAAACEAAAAAAAAAAAAAAAAAAADl - - b - 0 - g - 0 - r - 0 - - Font - Verdana - Size - 18 - - ID - 13 - Shape - Rectangle - Style - - fill - - Color - - b - 0.793851 - g - 0.625208 - r - 0.562982 - - FillType - 2 - GradientAngle - 90 - GradientColor - - b - 0.639673 - g - 0.450584 - r - 0.381079 - - - shadow - - Beneath - YES - Color - - a - 0.15 - b - 0 - g - 0 - r - 0 - - Fuzziness - 0.0 - ShadowVector - {2, 2} - - stroke - - Color - - b - 0.511421 - g - 0.637255 - r - 0.120867 - - Draws - NO - Width - 2 - - - Text - - Text - {\rtf1\ansi\ansicpg936\cocoartf1187\cocoasubrtf390 -\cocoascreenfonts1{\fonttbl\f0\fnil\fcharset0 Verdana;\f1\fnil\fcharset134 STHeitiSC-Light;} -{\colortbl;\red255\green255\blue255;} -\pard\tx560\tx1120\tx1680\tx2240\tx2800\tx3360\tx3920\tx4480\tx5040\tx5600\tx6160\tx6720\pardirnatural\qc - -\f0\fs36 \cf0 Cache\ - -\f1 Store} - VerticalPad - 0 - - TextRelativeArea - {{0.10000000000000001, 0.14999999999999999}, {0.80000000000000004, 0.69999999999999996}} - - - Bounds - {{298.9845516069733, 37.412179328792348}, {173.03089904785156, 51}} - Class - ShapedGraphic - FontInfo - - Color - - b - 0.8 - g - 0.8 - r - 0.8 - - Font - Verdana - Size - 18 - - ID - 17 - Shape - Rectangle - Style - - fill - - Color - - b - 0.6 - g - 0.6 - r - 0.6 - - FillType - 2 - GradientAngle - 90 - GradientColor - - b - 0.4 - g - 0.4 - r - 0.4 - - - shadow - - Beneath - YES - Color - - a - 0.15 - b - 0 - g - 0 - r - 0 - - Fuzziness - 0.0 - ShadowVector - {2, 2} - - stroke - - Color - - b - 0.590997 - g - 0.18677 - r - 0.567819 - - Draws - NO - Width - 2 - - - Text - - Text - {\rtf1\ansi\ansicpg936\cocoartf1187\cocoasubrtf390 -\cocoascreenfonts1{\fonttbl\f0\fnil\fcharset0 Verdana;} -{\colortbl;\red255\green255\blue255;\red204\green204\blue204;} -\pard\tx560\tx1120\tx1680\tx2240\tx2800\tx3360\tx3920\tx4480\tx5040\tx5600\tx6160\tx6720\pardirnatural\qc - -\f0\fs36 \cf2 LRUCacher} - VerticalPad - 0 - - TextRelativeArea - {{0.10000000000000001, 0.14999999999999999}, {0.80000000000000004, 0.69999999999999996}} - - - Bounds - {{99.822519035537013, 333.12161552787364}, {88, 51}} - Class - ShapedGraphic - FontInfo - - Color - - archive - - YnBsaXN0MDDUAQIDBAUGBwpZJGFyY2hpdmVy - WCR2ZXJzaW9uVCR0b3BYJG9iamVjdHNfEA9O - U0tleWVkQXJjaGl2ZXISAAGGoNEICVRyb290 - gAGlCwwVGR5VJG51bGzUDQ4PEBESExRfEBJO - U0N1c3RvbUNvbG9yU3BhY2VXTlNXaGl0ZVxO - U0NvbG9yU3BhY2VWJGNsYXNzgAJCMAAQA4AE - 0hYQFxhUTlNJRBACgAPSGhscD1gkY2xhc3Nl - c1okY2xhc3NuYW1log8dWE5TT2JqZWN00hob - HyCiIB1XTlNDb2xvcggRGyQpMkRJTFFTWV9o - fYWSmZueoKKnrK6wtb7JzNXa3QAAAAAAAAEB - AAAAAAAAACEAAAAAAAAAAAAAAAAAAADl - - b - 0 - g - 0 - r - 0 - - Font - Verdana - NSKern - 0.0 - Size - 15 - - ID - 12 - Magnets - - {1, 0} - {-1, 0} - - Shape - Rectangle - Style - - fill - - Color - - b - 0.806569 - g - 0.806569 - r - 0.806569 - - FillType - 2 - GradientAngle - 90 - GradientColor - - w - 0.653285 - - - shadow - - Beneath - YES - Color - - a - 0.15 - b - 0 - g - 0 - r - 0 - - Fuzziness - 0.0 - ShadowVector - {2, 2} - - stroke - - Color - - b - 0.2 - g - 0.2 - r - 0.2 - - Draws - NO - - - Text - - Text - {\rtf1\ansi\ansicpg936\cocoartf1187\cocoasubrtf390 -\cocoascreenfonts1{\fonttbl\f0\fnil\fcharset0 Verdana;} -{\colortbl;\red255\green255\blue255;} -\pard\tx560\tx1120\tx1680\tx2240\tx2800\tx3360\tx3920\tx4480\tx5040\tx5600\tx6160\tx6720\qc - -\f0\fs30 \cf0 \expnd0\expndtw0\kerning0 -Delet\ -SQL} - VerticalPad - 0 - - TextRelativeArea - {{0.10000000000000001, 0.14999999999999999}, {0.80000000000000004, 0.69999999999999996}} - - - Bounds - {{97.118322659032714, 226.09466078541047}, {88, 51}} - Class - ShapedGraphic - FontInfo - - Color - - b - 0 - g - 0 - r - 0.501961 - - Font - Verdana - NSKern - 0.0 - Size - 15 - - ID - 15 - Magnets - - {1, 0} - {-1, 0} - - Shape - Rectangle - Style - - fill - - Color - - b - 0 - g - 0.389485 - r - 1 - - FillType - 3 - GradientCenter - {-0.34285700000000002, -0.114286} - GradientColor - - b - 0 - g - 0.495748 - r - 1 - - MiddleColor - - b - 0 - g - 0.887657 - r - 1 - - MiddleFraction - 0.6269841194152832 - TrippleBlend - YES - - shadow - - Beneath - YES - Color - - a - 0.15 - b - 0 - g - 0 - r - 0 - - Fuzziness - 0.0 - ShadowVector - {2, 2} - - stroke - - Color - - b - 0.2 - g - 0.2 - r - 0.2 - - Draws - NO - - - Text - - Text - {\rtf1\ansi\ansicpg936\cocoartf1187\cocoasubrtf390 -\cocoascreenfonts1{\fonttbl\f0\fnil\fcharset0 Verdana;} -{\colortbl;\red255\green255\blue255;\red128\green0\blue0;} -\pard\tx560\tx1120\tx1680\tx2240\tx2800\tx3360\tx3920\tx4480\tx5040\tx5600\tx6160\tx6720\qc - -\f0\fs30 \cf2 \expnd0\expndtw0\kerning0 -Update\ -SQL} - VerticalPad - 0 - - TextRelativeArea - {{0.10000000000000001, 0.14999999999999999}, {0.80000000000000004, 0.69999999999999996}} - - - Bounds - {{103, 123.08108395608006}, {80, 51}} - Class - ShapedGraphic - FontInfo - - Color - - b - 0.821332 - g - 0.672602 - r - 0.928374 - - Font - Verdana - Size - 18 - - ID - 16 - Shape - Rectangle - Style - - fill - - Color - - b - 0.436973 - g - 0.155566 - r - 0.758999 - - FillType - 2 - GradientAngle - 90 - GradientColor - - b - 0.25098 - g - 0 - r - 0.501961 - - - shadow - - Beneath - YES - Color - - a - 0.15 - b - 0 - g - 0 - r - 0 - - Fuzziness - 0.0 - ShadowVector - {2, 2} - - stroke - - Color - - b - 0.511421 - g - 0.637255 - r - 0.120867 - - Draws - NO - Width - 2 - - - Text - - Text - {\rtf1\ansi\ansicpg936\cocoartf1187\cocoasubrtf390 -\cocoascreenfonts1{\fonttbl\f0\fnil\fcharset134 STHeitiSC-Light;\f1\fnil\fcharset0 Verdana;} -{\colortbl;\red255\green255\blue255;\red237\green172\blue209;} -\pard\tx560\tx1120\tx1680\tx2240\tx2800\tx3360\tx3920\tx4480\tx5040\tx5600\tx6160\tx6720\pardirnatural\qc - -\f0\fs36 \cf2 select -\f1 -\f0 SQL} - VerticalPad - 0 - - TextRelativeArea - {{0.10000000000000001, 0.14999999999999999}, {0.80000000000000004, 0.69999999999999996}} - - - GridInfo - - GuidesLocked - NO - GuidesVisible - YES - HPages - 2 - ImageCounter - 3 - KeepToScale - - Layers - - - Lock - NO - Name - 图层 1 - Print - YES - View - YES - - - LayoutInfo - - Animate - NO - AutoLayout - 2 - circoMinDist - 18 - circoSeparation - 0.0 - layoutEngine - neato - neatoLineLength - 0.92083334922790527 - neatoSeparation - 0.0 - twopiSeparation - 0.0 - - LinksVisible - NO - MagnetsVisible - NO - MasterSheets - - ModificationDate - 2013-09-29 08:24:57 +0000 - Modifier - Lunny Xiao - NotesVisible - NO - Orientation - 2 - OriginVisible - NO - OutlineStyle - Brainstorming/Clouds - PageBreaks - NO - PrintInfo - - NSBottomMargin - - float - 41 - - NSHorizonalPagination - - coded - BAtzdHJlYW10eXBlZIHoA4QBQISEhAhOU051bWJlcgCEhAdOU1ZhbHVlAISECE5TT2JqZWN0AIWEASqEhAFxlwCG - - NSLeftMargin - - float - 18 - - NSPaperSize - - size - {595, 842} - - NSPrintReverseOrientation - - int - 0 - - NSRightMargin - - float - 18 - - NSTopMargin - - float - 18 - - - PrintOnePage - - ReadOnly - NO - RowAlign - 1 - RowSpacing - 36 - SheetTitle - 版面 1 - SmartAlignmentGuidesActive - YES - SmartDistanceGuidesActive - YES - UniqueID - 1 - UseEntirePage - - VPages - 1 - WindowInfo - - CurrentSheet - 0 - ExpandedCanvases - - FitInWindow - - Frame - {{138, 197}, {869, 617}} - ListView - - OutlineWidth - 142 - RightSidebar - - Sidebar - - SidebarWidth - 138 - VisibleRegion - {{1.0591603214876664, 1.0591603214876664}, {770.00955372153339, 553.94084813804955}} - Zoom - 0.94414412975311279 - ZoomValues - - - 版面 1 - 0.0 - 1 - - - - - diff --git a/docs/images/cache_design.png b/docs/images/cache_design.png deleted file mode 100644 index 11ce8176..00000000 Binary files a/docs/images/cache_design.png and /dev/null differ diff --git a/engine.go b/engine.go index 96cb8ee9..47c25b09 100644 --- a/engine.go +++ b/engine.go @@ -5,82 +5,120 @@ package xorm import ( - "bufio" - "bytes" "context" "database/sql" - "encoding/gob" "errors" "fmt" "io" "os" "reflect" + "runtime" "strconv" "strings" - "sync" "time" - "xorm.io/builder" - "xorm.io/core" + "xorm.io/xorm/caches" + "xorm.io/xorm/contexts" + "xorm.io/xorm/core" + "xorm.io/xorm/dialects" + "xorm.io/xorm/internal/utils" + "xorm.io/xorm/log" + "xorm.io/xorm/names" + "xorm.io/xorm/schemas" + "xorm.io/xorm/tags" ) // Engine is the major struct of xorm, it means a database manager. // Commonly, an application only need one engine type Engine struct { - db *core.DB - dialect core.Dialect + cacherMgr *caches.Manager + defaultContext context.Context + dialect dialects.Dialect + engineGroup *EngineGroup + logger log.ContextLogger + tagParser *tags.Parser + db *core.DB - ColumnMapper core.IMapper - TableMapper core.IMapper - TagIdentifier string - Tables map[reflect.Type]*core.Table + driverName string + dataSourceName string - mutex *sync.RWMutex - Cacher core.Cacher - - showSQL bool - showExecTime bool - - logger core.ILogger TZLocation *time.Location // The timezone of the application DatabaseTZ *time.Location // The timezone of the database - disableGlobalCache bool - - tagHandlers map[string]tagHandler - - engineGroup *EngineGroup - - cachers map[string]core.Cacher - cacherLock sync.RWMutex - - defaultContext context.Context + logSessionID bool // create session id } -func (engine *Engine) setCacher(tableName string, cacher core.Cacher) { - engine.cacherLock.Lock() - engine.cachers[tableName] = cacher - engine.cacherLock.Unlock() -} - -func (engine *Engine) SetCacher(tableName string, cacher core.Cacher) { - engine.setCacher(tableName, cacher) -} - -func (engine *Engine) getCacher(tableName string) core.Cacher { - var cacher core.Cacher - var ok bool - engine.cacherLock.RLock() - cacher, ok = engine.cachers[tableName] - engine.cacherLock.RUnlock() - if !ok && !engine.disableGlobalCache { - cacher = engine.Cacher +// NewEngine new a db manager according to the parameter. Currently support four +// drivers +func NewEngine(driverName string, dataSourceName string) (*Engine, error) { + dialect, err := dialects.OpenDialect(driverName, dataSourceName) + if err != nil { + return nil, err } - return cacher + + db, err := core.Open(driverName, dataSourceName) + if err != nil { + return nil, err + } + + cacherMgr := caches.NewManager() + mapper := names.NewCacheMapper(new(names.SnakeMapper)) + tagParser := tags.NewParser("xorm", dialect, mapper, mapper, cacherMgr) + + engine := &Engine{ + dialect: dialect, + TZLocation: time.Local, + defaultContext: context.Background(), + cacherMgr: cacherMgr, + tagParser: tagParser, + driverName: driverName, + dataSourceName: dataSourceName, + db: db, + logSessionID: false, + } + + if dialect.URI().DBType == schemas.SQLITE { + engine.DatabaseTZ = time.UTC + } else { + engine.DatabaseTZ = time.Local + } + + logger := log.NewSimpleLogger(os.Stdout) + logger.SetLevel(log.LOG_INFO) + engine.SetLogger(log.NewLoggerAdapter(logger)) + + runtime.SetFinalizer(engine, func(engine *Engine) { + engine.Close() + }) + + return engine, nil } -func (engine *Engine) GetCacher(tableName string) core.Cacher { - return engine.getCacher(tableName) +// NewEngineWithParams new a db manager with params. The params will be passed to dialects. +func NewEngineWithParams(driverName string, dataSourceName string, params map[string]string) (*Engine, error) { + engine, err := NewEngine(driverName, dataSourceName) + engine.dialect.SetParams(params) + return engine, err +} + +// EnableSessionID if enable session id +func (engine *Engine) EnableSessionID(enable bool) { + engine.logSessionID = enable +} + +// SetCacher sets cacher for the table +func (engine *Engine) SetCacher(tableName string, cacher caches.Cacher) { + engine.cacherMgr.SetCacher(tableName, cacher) +} + +// GetCacher returns the cachher of the special table +func (engine *Engine) GetCacher(tableName string) caches.Cacher { + return engine.cacherMgr.GetCacher(tableName) +} + +// SetQuotePolicy sets the special quote policy +func (engine *Engine) SetQuotePolicy(quotePolicy dialects.QuotePolicy) { + engine.dialect.SetQuotePolicy(quotePolicy) } // BufferSize sets buffer size for iterate @@ -90,97 +128,64 @@ func (engine *Engine) BufferSize(size int) *Session { return session.BufferSize(size) } -// CondDeleted returns the conditions whether a record is soft deleted. -func (engine *Engine) CondDeleted(colName string) builder.Cond { - if engine.dialect.DBType() == core.MSSQL { - return builder.IsNull{colName} - } - return builder.IsNull{colName}.Or(builder.Eq{colName: zeroTime1}) -} - // ShowSQL show SQL statement or not on logger if log level is great than INFO func (engine *Engine) ShowSQL(show ...bool) { engine.logger.ShowSQL(show...) - if len(show) == 0 { - engine.showSQL = true - } else { - engine.showSQL = show[0] - } -} - -// ShowExecTime show SQL statement and execute time or not on logger if log level is great than INFO -func (engine *Engine) ShowExecTime(show ...bool) { - if len(show) == 0 { - engine.showExecTime = true - } else { - engine.showExecTime = show[0] - } + engine.DB().Logger = engine.logger } // Logger return the logger interface -func (engine *Engine) Logger() core.ILogger { +func (engine *Engine) Logger() log.ContextLogger { return engine.logger } // SetLogger set the new logger -func (engine *Engine) SetLogger(logger core.ILogger) { - engine.logger = logger - engine.showSQL = logger.IsShowSQL() - engine.dialect.SetLogger(logger) +func (engine *Engine) SetLogger(logger interface{}) { + var realLogger log.ContextLogger + switch t := logger.(type) { + case log.Logger: + realLogger = log.NewLoggerAdapter(t) + case log.ContextLogger: + realLogger = t + } + engine.logger = realLogger + engine.DB().Logger = realLogger } // SetLogLevel sets the logger level -func (engine *Engine) SetLogLevel(level core.LogLevel) { +func (engine *Engine) SetLogLevel(level log.LogLevel) { engine.logger.SetLevel(level) } // SetDisableGlobalCache disable global cache or not func (engine *Engine) SetDisableGlobalCache(disable bool) { - if engine.disableGlobalCache != disable { - engine.disableGlobalCache = disable - } + engine.cacherMgr.SetDisableGlobalCache(disable) } // DriverName return the current sql driver's name func (engine *Engine) DriverName() string { - return engine.dialect.DriverName() + return engine.driverName } // DataSourceName return the current connection string func (engine *Engine) DataSourceName() string { - return engine.dialect.DataSourceName() + return engine.dataSourceName } // SetMapper set the name mapping rules -func (engine *Engine) SetMapper(mapper core.IMapper) { +func (engine *Engine) SetMapper(mapper names.Mapper) { engine.SetTableMapper(mapper) engine.SetColumnMapper(mapper) } // SetTableMapper set the table name mapping rule -func (engine *Engine) SetTableMapper(mapper core.IMapper) { - engine.TableMapper = mapper +func (engine *Engine) SetTableMapper(mapper names.Mapper) { + engine.tagParser.SetTableMapper(mapper) } // SetColumnMapper set the column name mapping rule -func (engine *Engine) SetColumnMapper(mapper core.IMapper) { - engine.ColumnMapper = mapper -} - -// SupportInsertMany If engine's database support batch insert records like -// "insert into user values (name, age), (name, age)". -// When the return is ture, then engine.Insert(&users) will -// generate batch sql and exeute. -func (engine *Engine) SupportInsertMany() bool { - return engine.dialect.SupportInsertMany() -} - -func (engine *Engine) quoteColumns(columnStr string) string { - columns := strings.Split(columnStr, ",") - for i := 0; i < len(columns); i++ { - columns[i] = engine.Quote(strings.TrimSpace(columns[i])) - } - return strings.Join(columns, ",") +func (engine *Engine) SetColumnMapper(mapper names.Mapper) { + engine.tagParser.SetColumnMapper(mapper) } // Quote Use QuoteStr quote the string sql @@ -206,64 +211,12 @@ func (engine *Engine) QuoteTo(buf *strings.Builder, value string) { if value == "" { return } - - quoteTo(buf, engine.dialect.Quote(""), value) -} - -func quoteTo(buf *strings.Builder, quotePair string, value string) { - if len(quotePair) < 2 { // no quote - _, _ = buf.WriteString(value) - return - } - - prefix, suffix := quotePair[0], quotePair[1] - - i := 0 - for i < len(value) { - // start of a token; might be already quoted - if value[i] == '.' { - _ = buf.WriteByte('.') - i++ - } else if value[i] == prefix || value[i] == '`' { - // Has quotes; skip/normalize `name` to prefix+name+sufix - var ch byte - if value[i] == prefix { - ch = suffix - } else { - ch = '`' - } - i++ - _ = buf.WriteByte(prefix) - for ; i < len(value) && value[i] != ch; i++ { - _ = buf.WriteByte(value[i]) - } - _ = buf.WriteByte(suffix) - i++ - } else { - // Requires quotes - _ = buf.WriteByte(prefix) - for ; i < len(value) && value[i] != '.'; i++ { - _ = buf.WriteByte(value[i]) - } - _ = buf.WriteByte(suffix) - } - } -} - -func (engine *Engine) quote(sql string) string { - return engine.dialect.Quote(sql) -} - -// SqlType will be deprecated, please use SQLType instead -// -// Deprecated: use SQLType instead -func (engine *Engine) SqlType(c *core.Column) string { - return engine.SQLType(c) + engine.dialect.Quoter().QuoteTo(buf, value) } // SQLType A simple wrapper to dialect's core.SqlType method -func (engine *Engine) SQLType(c *core.Column) string { - return engine.dialect.SqlType(c) +func (engine *Engine) SQLType(c *schemas.Column) string { + return engine.dialect.SQLType(c) } // AutoIncrStr Database's autoincrement statement @@ -273,27 +226,27 @@ func (engine *Engine) AutoIncrStr() string { // SetConnMaxLifetime sets the maximum amount of time a connection may be reused. func (engine *Engine) SetConnMaxLifetime(d time.Duration) { - engine.db.SetConnMaxLifetime(d) + engine.DB().SetConnMaxLifetime(d) } // SetMaxOpenConns is only available for go 1.2+ func (engine *Engine) SetMaxOpenConns(conns int) { - engine.db.SetMaxOpenConns(conns) + engine.DB().SetMaxOpenConns(conns) } // SetMaxIdleConns set the max idle connections on pool, default is 2 func (engine *Engine) SetMaxIdleConns(conns int) { - engine.db.SetMaxIdleConns(conns) + engine.DB().SetMaxIdleConns(conns) } // SetDefaultCacher set the default cacher. Xorm's default not enable cacher. -func (engine *Engine) SetDefaultCacher(cacher core.Cacher) { - engine.Cacher = cacher +func (engine *Engine) SetDefaultCacher(cacher caches.Cacher) { + engine.cacherMgr.SetDefaultCacher(cacher) } // GetDefaultCacher returns the default cacher -func (engine *Engine) GetDefaultCacher() core.Cacher { - return engine.Cacher +func (engine *Engine) GetDefaultCacher() caches.Cacher { + return engine.cacherMgr.GetDefaultCacher() } // NoCache If you has set default cacher, and you want temporilly stop use cache, @@ -312,14 +265,14 @@ func (engine *Engine) NoCascade() *Session { } // MapCacher Set a table use a special cacher -func (engine *Engine) MapCacher(bean interface{}, cacher core.Cacher) error { - engine.setCacher(engine.TableName(bean, true), cacher) +func (engine *Engine) MapCacher(bean interface{}, cacher caches.Cacher) error { + engine.SetCacher(dialects.FullTableName(engine.dialect, engine.GetTableMapper(), bean, true), cacher) return nil } // NewDB provides an interface to operate database directly func (engine *Engine) NewDB() (*core.DB, error) { - return core.OpenDialect(engine.dialect) + return core.Open(engine.driverName, engine.dataSourceName) } // DB return the wrapper of sql.DB @@ -328,20 +281,18 @@ func (engine *Engine) DB() *core.DB { } // Dialect return database dialect -func (engine *Engine) Dialect() core.Dialect { +func (engine *Engine) Dialect() dialects.Dialect { return engine.dialect } // NewSession New a session func (engine *Engine) NewSession() *Session { - session := &Session{engine: engine} - session.Init() - return session + return newSession(engine) } // Close the engine func (engine *Engine) Close() error { - return engine.db.Close() + return engine.DB().Close() } // Ping tests if database is alive @@ -351,25 +302,6 @@ func (engine *Engine) Ping() error { return session.Ping() } -// logSQL save sql -func (engine *Engine) logSQL(sqlStr string, sqlArgs ...interface{}) { - if engine.showSQL && !engine.showExecTime { - if len(sqlArgs) > 0 { - engine.logger.Infof("[SQL] %v %#v", sqlStr, sqlArgs) - } else { - engine.logger.Infof("[SQL] %v", sqlStr) - } - } -} - -// Sql provides raw sql input parameter. When you have a complex SQL statement -// and cannot use Where, Id, In and etc. Methods to describe, you can use SQL. -// -// Deprecated: use SQL instead. -func (engine *Engine) Sql(querystring string, args ...interface{}) *Session { - return engine.SQL(querystring, args...) -} - // SQL method let's you manually write raw SQL and operate // For example: // @@ -398,26 +330,33 @@ func (engine *Engine) NoAutoCondition(no ...bool) *Session { return session.NoAutoCondition(no...) } -func (engine *Engine) loadTableInfo(table *core.Table) error { - colSeq, cols, err := engine.dialect.GetColumns(table.Name) +func (engine *Engine) loadTableInfo(table *schemas.Table) error { + colSeq, cols, err := engine.dialect.GetColumns(engine.db, engine.defaultContext, table.Name) if err != nil { return err } for _, name := range colSeq { table.AddColumn(cols[name]) } - indexes, err := engine.dialect.GetIndexes(table.Name) + indexes, err := engine.dialect.GetIndexes(engine.db, engine.defaultContext, table.Name) if err != nil { return err } table.Indexes = indexes + var seq int for _, index := range indexes { for _, name := range index.Cols { - if col := table.GetColumn(name); col != nil { + parts := strings.Split(name, " ") + if len(parts) > 1 { + if parts[1] == "DESC" { + seq = 1 + } + } + if col := table.GetColumn(parts[0]); col != nil { col.Indexes[index.Name] = index.Type } else { - return fmt.Errorf("Unknown col %s in index %v of table %v, columns %v", name, index.Name, table.Name, table.ColumnsSeq()) + return fmt.Errorf("Unknown col %s seq %d, in index %v of table %v, columns %v", name, seq, index.Name, table.Name, table.ColumnsSeq()) } } } @@ -425,8 +364,8 @@ func (engine *Engine) loadTableInfo(table *core.Table) error { } // DBMetas Retrieve all tables, columns, indexes' informations from database. -func (engine *Engine) DBMetas() ([]*core.Table, error) { - tables, err := engine.dialect.GetTables() +func (engine *Engine) DBMetas() ([]*schemas.Table, error) { + tables, err := engine.dialect.GetTables(engine.db, engine.defaultContext) if err != nil { return nil, err } @@ -440,7 +379,7 @@ func (engine *Engine) DBMetas() ([]*core.Table, error) { } // DumpAllToFile dump database all table structs and data to a file -func (engine *Engine) DumpAllToFile(fp string, tp ...core.DbType) error { +func (engine *Engine) DumpAllToFile(fp string, tp ...schemas.DBType) error { f, err := os.Create(fp) if err != nil { return err @@ -450,7 +389,7 @@ func (engine *Engine) DumpAllToFile(fp string, tp ...core.DbType) error { } // DumpAll dump database all table structs and data to w -func (engine *Engine) DumpAll(w io.Writer, tp ...core.DbType) error { +func (engine *Engine) DumpAll(w io.Writer, tp ...schemas.DBType) error { tables, err := engine.DBMetas() if err != nil { return err @@ -459,7 +398,7 @@ func (engine *Engine) DumpAll(w io.Writer, tp ...core.DbType) error { } // DumpTablesToFile dump specified tables to SQL file. -func (engine *Engine) DumpTablesToFile(tables []*core.Table, fp string, tp ...core.DbType) error { +func (engine *Engine) DumpTablesToFile(tables []*schemas.Table, fp string, tp ...schemas.DBType) error { f, err := os.Create(fp) if err != nil { return err @@ -469,55 +408,70 @@ func (engine *Engine) DumpTablesToFile(tables []*core.Table, fp string, tp ...co } // DumpTables dump specify tables to io.Writer -func (engine *Engine) DumpTables(tables []*core.Table, w io.Writer, tp ...core.DbType) error { +func (engine *Engine) DumpTables(tables []*schemas.Table, w io.Writer, tp ...schemas.DBType) error { return engine.dumpTables(tables, w, tp...) } // dumpTables dump database all table structs and data to w with specify db type -func (engine *Engine) dumpTables(tables []*core.Table, w io.Writer, tp ...core.DbType) error { - var dialect core.Dialect - var distDBName string +func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...schemas.DBType) error { + var dstDialect dialects.Dialect if len(tp) == 0 { - dialect = engine.dialect - distDBName = string(engine.dialect.DBType()) + dstDialect = engine.dialect } else { - dialect = core.QueryDialect(tp[0]) - if dialect == nil { + dstDialect = dialects.QueryDialect(tp[0]) + if dstDialect == nil { return errors.New("Unsupported database type") } - dialect.Init(nil, engine.dialect.URI(), "", "") - distDBName = string(tp[0]) + + uri := engine.dialect.URI() + destURI := *uri + dstDialect.Init(&destURI) } - _, err := io.WriteString(w, fmt.Sprintf("/*Generated by xorm v%s %s, from %s to %s*/\n\n", - Version, time.Now().In(engine.TZLocation).Format("2006-01-02 15:04:05"), engine.dialect.DBType(), strings.ToUpper(distDBName))) + _, err := io.WriteString(w, fmt.Sprintf("/*Generated by xorm %s, from %s to %s*/\n\n", + time.Now().In(engine.TZLocation).Format("2006-01-02 15:04:05"), engine.dialect.URI().DBType, dstDialect.URI().DBType)) if err != nil { return err } for i, table := range tables { + tableName := table.Name + if dstDialect.URI().Schema != "" { + tableName = fmt.Sprintf("%s.%s", dstDialect.URI().Schema, table.Name) + } + originalTableName := table.Name + if engine.dialect.URI().Schema != "" { + originalTableName = fmt.Sprintf("%s.%s", engine.dialect.URI().Schema, table.Name) + } if i > 0 { _, err = io.WriteString(w, "\n") if err != nil { return err } } - _, err = io.WriteString(w, dialect.CreateTableSql(table, "", table.StoreEngine, "")+";\n") - if err != nil { - return err + sqls, _ := dstDialect.CreateTableSQL(table, tableName) + for _, s := range sqls { + _, err = io.WriteString(w, s+";\n") + if err != nil { + return err + } } + if len(table.PKColumns()) > 0 && dstDialect.URI().DBType == schemas.MSSQL { + fmt.Fprintf(w, "SET IDENTITY_INSERT [%s] ON;\n", table.Name) + } + for _, index := range table.Indexes { - _, err = io.WriteString(w, dialect.CreateIndexSql(table.Name, index)+";\n") + _, err = io.WriteString(w, dstDialect.CreateIndexSQL(table.Name, index)+";\n") if err != nil { return err } } cols := table.ColumnsSeq() - colNames := engine.dialect.Quote(strings.Join(cols, engine.dialect.Quote(", "))) - destColNames := dialect.Quote(strings.Join(cols, dialect.Quote(", "))) + colNames := engine.dialect.Quoter().Join(cols, ", ") + destColNames := dstDialect.Quoter().Join(cols, ", ") - rows, err := engine.DB().Query("SELECT " + colNames + " FROM " + engine.Quote(table.Name)) + rows, err := engine.DB().QueryContext(engine.defaultContext, "SELECT "+colNames+" FROM "+engine.Quote(originalTableName)) if err != nil { return err } @@ -530,7 +484,7 @@ func (engine *Engine) dumpTables(tables []*core.Table, w io.Writer, tp ...core.D return err } - _, err = io.WriteString(w, "INSERT INTO "+dialect.Quote(table.Name)+" ("+destColNames+") VALUES (") + _, err = io.WriteString(w, "INSERT INTO "+dstDialect.Quoter().Quote(tableName)+" ("+destColNames+") VALUES (") if err != nil { return err } @@ -553,26 +507,26 @@ func (engine *Engine) dumpTables(tables []*core.Table, w io.Writer, tp ...core.D } } else if col.SQLType.IsBlob() { if reflect.TypeOf(d).Kind() == reflect.Slice { - temp += fmt.Sprintf(", %s", dialect.FormatBytes(d.([]byte))) + temp += fmt.Sprintf(", %s", dstDialect.FormatBytes(d.([]byte))) } else if reflect.TypeOf(d).Kind() == reflect.String { temp += fmt.Sprintf(", '%s'", d.(string)) } } else if col.SQLType.IsNumeric() { switch reflect.TypeOf(d).Kind() { case reflect.Slice: - if col.SQLType.Name == core.Bool { + if col.SQLType.Name == schemas.Bool { temp += fmt.Sprintf(", %v", strconv.FormatBool(d.([]byte)[0] != byte('0'))) } else { temp += fmt.Sprintf(", %s", string(d.([]byte))) } case reflect.Int16, reflect.Int8, reflect.Int32, reflect.Int64, reflect.Int: - if col.SQLType.Name == core.Bool { + if col.SQLType.Name == schemas.Bool { temp += fmt.Sprintf(", %v", strconv.FormatBool(reflect.ValueOf(d).Int() > 0)) } else { temp += fmt.Sprintf(", %v", d) } case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - if col.SQLType.Name == core.Bool { + if col.SQLType.Name == schemas.Bool { temp += fmt.Sprintf(", %v", strconv.FormatBool(reflect.ValueOf(d).Uint() > 0)) } else { temp += fmt.Sprintf(", %v", d) @@ -600,8 +554,8 @@ func (engine *Engine) dumpTables(tables []*core.Table, w io.Writer, tp ...core.D } // FIXME: Hack for postgres - if string(dialect.DBType()) == core.POSTGRES && table.AutoIncrColumn() != nil { - _, err = io.WriteString(w, "SELECT setval('"+table.Name+"_id_seq', COALESCE((SELECT MAX("+table.AutoIncrColumn().Name+") + 1 FROM "+dialect.Quote(table.Name)+"), 1), false);\n") + if dstDialect.URI().DBType == schemas.POSTGRES && table.AutoIncrColumn() != nil { + _, err = io.WriteString(w, "SELECT setval('"+tableName+"_id_seq', COALESCE((SELECT MAX("+table.AutoIncrColumn().Name+") + 1 FROM "+dstDialect.Quoter().Quote(tableName)+"), 1), false);\n") if err != nil { return err } @@ -624,13 +578,6 @@ func (engine *Engine) Where(query interface{}, args ...interface{}) *Session { return session.Where(query, args...) } -// Id will be deprecated, please use ID instead -func (engine *Engine) Id(id interface{}) *Session { - session := engine.NewSession() - session.isAutoClose = true - return session.Id(id) -} - // ID method provoide a condition as (id) = ? func (engine *Engine) ID(id interface{}) *Session { session := engine.NewSession() @@ -838,46 +785,9 @@ func (engine *Engine) Having(conditions string) *Session { return session.Having(conditions) } -// UnMapType removes the datbase mapper of a type -func (engine *Engine) UnMapType(t reflect.Type) { - engine.mutex.Lock() - defer engine.mutex.Unlock() - delete(engine.Tables, t) -} - -func (engine *Engine) autoMapType(v reflect.Value) (*core.Table, error) { - t := v.Type() - engine.mutex.Lock() - defer engine.mutex.Unlock() - table, ok := engine.Tables[t] - if !ok { - var err error - table, err = engine.mapType(v) - if err != nil { - return nil, err - } - - engine.Tables[t] = table - if engine.Cacher != nil { - if v.CanAddr() { - engine.GobRegister(v.Addr().Interface()) - } else { - engine.GobRegister(v.Interface()) - } - } - } - return table, nil -} - -// GobRegister register one struct to gob for cache use -func (engine *Engine) GobRegister(v interface{}) *Engine { - gob.Register(v) - return engine -} - // Table table struct type Table struct { - *core.Table + *schemas.Table Name string } @@ -887,222 +797,9 @@ func (t *Table) IsValid() bool { } // TableInfo get table info according to bean's content -func (engine *Engine) TableInfo(bean interface{}) *Table { - v := rValue(bean) - tb, err := engine.autoMapType(v) - if err != nil { - engine.logger.Error(err) - } - return &Table{tb, engine.TableName(bean)} -} - -func addIndex(indexName string, table *core.Table, col *core.Column, indexType int) { - if index, ok := table.Indexes[indexName]; ok { - index.AddColumn(col.Name) - col.Indexes[index.Name] = indexType - } else { - index := core.NewIndex(indexName, indexType) - index.AddColumn(col.Name) - table.AddIndex(index) - col.Indexes[index.Name] = indexType - } -} - -// TableName table name interface to define customerize table name -type TableName interface { - TableName() string -} - -var ( - tpTableName = reflect.TypeOf((*TableName)(nil)).Elem() -) - -func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) { - t := v.Type() - table := core.NewEmptyTable() - table.Type = t - table.Name = engine.tbNameForMap(v) - - var idFieldColName string - var hasCacheTag, hasNoCacheTag bool - - for i := 0; i < t.NumField(); i++ { - tag := t.Field(i).Tag - - ormTagStr := tag.Get(engine.TagIdentifier) - var col *core.Column - fieldValue := v.Field(i) - fieldType := fieldValue.Type() - - if ormTagStr != "" { - col = &core.Column{ - FieldName: t.Field(i).Name, - Nullable: true, - IsPrimaryKey: false, - IsAutoIncrement: false, - MapType: core.TWOSIDES, - Indexes: make(map[string]int), - DefaultIsEmpty: true, - } - tags := splitTag(ormTagStr) - - if len(tags) > 0 { - if tags[0] == "-" { - continue - } - - var ctx = tagContext{ - table: table, - col: col, - fieldValue: fieldValue, - indexNames: make(map[string]int), - engine: engine, - } - - if strings.HasPrefix(strings.ToUpper(tags[0]), "EXTENDS") { - pStart := strings.Index(tags[0], "(") - if pStart > -1 && strings.HasSuffix(tags[0], ")") { - var tagPrefix = strings.TrimFunc(tags[0][pStart+1:len(tags[0])-1], func(r rune) bool { - return r == '\'' || r == '"' - }) - - ctx.params = []string{tagPrefix} - } - - if err := ExtendsTagHandler(&ctx); err != nil { - return nil, err - } - continue - } - - for j, key := range tags { - if ctx.ignoreNext { - ctx.ignoreNext = false - continue - } - - k := strings.ToUpper(key) - ctx.tagName = k - ctx.params = []string{} - - pStart := strings.Index(k, "(") - if pStart == 0 { - return nil, errors.New("( could not be the first charactor") - } - if pStart > -1 { - if !strings.HasSuffix(k, ")") { - return nil, fmt.Errorf("field %s tag %s cannot match ) charactor", col.FieldName, key) - } - - ctx.tagName = k[:pStart] - ctx.params = strings.Split(key[pStart+1:len(k)-1], ",") - } - - if j > 0 { - ctx.preTag = strings.ToUpper(tags[j-1]) - } - if j < len(tags)-1 { - ctx.nextTag = tags[j+1] - } else { - ctx.nextTag = "" - } - - if h, ok := engine.tagHandlers[ctx.tagName]; ok { - if err := h(&ctx); err != nil { - return nil, err - } - } else { - if strings.HasPrefix(key, "'") && strings.HasSuffix(key, "'") { - col.Name = key[1 : len(key)-1] - } else { - col.Name = key - } - } - - if ctx.hasCacheTag { - hasCacheTag = true - } - if ctx.hasNoCacheTag { - hasNoCacheTag = true - } - } - - if col.SQLType.Name == "" { - col.SQLType = core.Type2SQLType(fieldType) - } - engine.dialect.SqlType(col) - if col.Length == 0 { - col.Length = col.SQLType.DefaultLength - } - if col.Length2 == 0 { - col.Length2 = col.SQLType.DefaultLength2 - } - if col.Name == "" { - col.Name = engine.ColumnMapper.Obj2Table(t.Field(i).Name) - } - - if ctx.isUnique { - ctx.indexNames[col.Name] = core.UniqueType - } else if ctx.isIndex { - ctx.indexNames[col.Name] = core.IndexType - } - - for indexName, indexType := range ctx.indexNames { - addIndex(indexName, table, col, indexType) - } - } - } else { - var sqlType core.SQLType - if fieldValue.CanAddr() { - if _, ok := fieldValue.Addr().Interface().(core.Conversion); ok { - sqlType = core.SQLType{Name: core.Text} - } - } - if _, ok := fieldValue.Interface().(core.Conversion); ok { - sqlType = core.SQLType{Name: core.Text} - } else { - sqlType = core.Type2SQLType(fieldType) - } - col = core.NewColumn(engine.ColumnMapper.Obj2Table(t.Field(i).Name), - t.Field(i).Name, sqlType, sqlType.DefaultLength, - sqlType.DefaultLength2, true) - - if fieldType.Kind() == reflect.Int64 && (strings.ToUpper(col.FieldName) == "ID" || strings.HasSuffix(strings.ToUpper(col.FieldName), ".ID")) { - idFieldColName = col.Name - } - } - if col.IsAutoIncrement { - col.Nullable = false - } - - table.AddColumn(col) - - } // end for - - if idFieldColName != "" && len(table.PrimaryKeys) == 0 { - col := table.GetColumn(idFieldColName) - col.IsPrimaryKey = true - col.IsAutoIncrement = true - col.Nullable = false - table.PrimaryKeys = append(table.PrimaryKeys, col.Name) - table.AutoIncrement = col.Name - } - - if hasCacheTag { - if engine.Cacher != nil { // !nash! use engine's cacher if provided - engine.logger.Info("enable cache on table:", table.Name) - engine.setCacher(table.Name, engine.Cacher) - } else { - engine.logger.Info("enable LRU cache on table:", table.Name) - engine.setCacher(table.Name, NewLRUCacher2(NewMemoryStore(), time.Hour, 10000)) - } - } - if hasNoCacheTag { - engine.logger.Info("disable cache on table:", table.Name) - engine.setCacher(table.Name, nil) - } - - return table, nil +func (engine *Engine) TableInfo(bean interface{}) (*schemas.Table, error) { + v := utils.ReflectValue(bean) + return engine.tagParser.ParseWithCache(v) } // IsTableEmpty if a table has any reocrd @@ -1119,93 +816,9 @@ func (engine *Engine) IsTableExist(beanOrTableName interface{}) (bool, error) { return session.IsTableExist(beanOrTableName) } -// IdOf get id from one struct -// -// Deprecated: use IDOf instead. -func (engine *Engine) IdOf(bean interface{}) core.PK { - return engine.IDOf(bean) -} - -// IDOf get id from one struct -func (engine *Engine) IDOf(bean interface{}) core.PK { - return engine.IdOfV(reflect.ValueOf(bean)) -} - -// IdOfV get id from one value of struct -// -// Deprecated: use IDOfV instead. -func (engine *Engine) IdOfV(rv reflect.Value) core.PK { - return engine.IDOfV(rv) -} - -// IDOfV get id from one value of struct -func (engine *Engine) IDOfV(rv reflect.Value) core.PK { - pk, err := engine.idOfV(rv) - if err != nil { - engine.logger.Error(err) - return nil - } - return pk -} - -func (engine *Engine) idOfV(rv reflect.Value) (core.PK, error) { - v := reflect.Indirect(rv) - table, err := engine.autoMapType(v) - if err != nil { - return nil, err - } - - pk := make([]interface{}, len(table.PrimaryKeys)) - for i, col := range table.PKColumns() { - var err error - - fieldName := col.FieldName - for { - parts := strings.SplitN(fieldName, ".", 2) - if len(parts) == 1 { - break - } - - v = v.FieldByName(parts[0]) - if v.Kind() == reflect.Ptr { - v = v.Elem() - } - if v.Kind() != reflect.Struct { - return nil, ErrUnSupportedType - } - fieldName = parts[1] - } - - pkField := v.FieldByName(fieldName) - switch pkField.Kind() { - case reflect.String: - pk[i], err = engine.idTypeAssertion(col, pkField.String()) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - pk[i], err = engine.idTypeAssertion(col, strconv.FormatInt(pkField.Int(), 10)) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - // id of uint will be converted to int64 - pk[i], err = engine.idTypeAssertion(col, strconv.FormatUint(pkField.Uint(), 10)) - } - - if err != nil { - return nil, err - } - } - return core.PK(pk), nil -} - -func (engine *Engine) idTypeAssertion(col *core.Column, sid string) (interface{}, error) { - if col.SQLType.IsNumeric() { - n, err := strconv.ParseInt(sid, 10, 64) - if err != nil { - return nil, err - } - return n, nil - } else if col.SQLType.IsText() { - return sid, nil - } else { - return nil, errors.New("not supported") - } +// TableName returns table name with schema prefix if has +func (engine *Engine) TableName(bean interface{}, includeSchema ...bool) string { + return dialects.FullTableName(engine.dialect, engine.GetTableMapper(), bean, includeSchema...) } // CreateIndexes create indexes @@ -1224,8 +837,8 @@ func (engine *Engine) CreateUniques(bean interface{}) error { // ClearCacheBean if enabled cache, clear the cache bean func (engine *Engine) ClearCacheBean(bean interface{}, id string) error { - tableName := engine.TableName(bean) - cacher := engine.getCacher(tableName) + tableName := dialects.FullTableName(engine.dialect, engine.GetTableMapper(), bean) + cacher := engine.GetCacher(tableName) if cacher != nil { cacher.ClearIds(tableName) cacher.DelBean(tableName, id) @@ -1236,8 +849,8 @@ func (engine *Engine) ClearCacheBean(bean interface{}, id string) error { // ClearCache if enabled cache, clear some tables' cache func (engine *Engine) ClearCache(beans ...interface{}) error { for _, bean := range beans { - tableName := engine.TableName(bean) - cacher := engine.getCacher(tableName) + tableName := dialects.FullTableName(engine.dialect, engine.GetTableMapper(), bean) + cacher := engine.GetCacher(tableName) if cacher != nil { cacher.ClearIds(tableName) cacher.ClearBeans(tableName) @@ -1246,6 +859,11 @@ func (engine *Engine) ClearCache(beans ...interface{}) error { return nil } +// UnMapType remove table from tables cache +func (engine *Engine) UnMapType(t reflect.Type) { + engine.tagParser.ClearCacheTable(t) +} + // Sync the new struct changes to database, this method will automatically add // table, column, index, unique. but will not delete or change anything. // If you change some field, you should change the database manually. @@ -1254,9 +872,9 @@ func (engine *Engine) Sync(beans ...interface{}) error { defer session.Close() for _, bean := range beans { - v := rValue(bean) - tableNameNoSchema := engine.TableName(bean) - table, err := engine.autoMapType(v) + v := utils.ReflectValue(bean) + tableNameNoSchema := dialects.FullTableName(engine.dialect, engine.GetTableMapper(), bean) + table, err := engine.tagParser.ParseWithCache(v) if err != nil { return err } @@ -1287,12 +905,12 @@ func (engine *Engine) Sync(beans ...interface{}) error { } } else { for _, col := range table.Columns() { - isExist, err := engine.dialect.IsColumnExist(tableNameNoSchema, col.Name) + isExist, err := engine.dialect.IsColumnExist(engine.db, session.ctx, tableNameNoSchema, col.Name) if err != nil { return err } if !isExist { - if err := session.statement.setRefBean(bean); err != nil { + if err := session.statement.SetRefBean(bean); err != nil { return err } err = session.addColumn(col.Name) @@ -1303,16 +921,16 @@ func (engine *Engine) Sync(beans ...interface{}) error { } for name, index := range table.Indexes { - if err := session.statement.setRefBean(bean); err != nil { + if err := session.statement.SetRefBean(bean); err != nil { return err } - if index.Type == core.UniqueType { + if index.Type == schemas.UniqueType { isExist, err := session.isIndexExist2(tableNameNoSchema, index.Cols, true) if err != nil { return err } if !isExist { - if err := session.statement.setRefBean(bean); err != nil { + if err := session.statement.SetRefBean(bean); err != nil { return err } @@ -1321,13 +939,13 @@ func (engine *Engine) Sync(beans ...interface{}) error { return err } } - } else if index.Type == core.IndexType { + } else if index.Type == schemas.IndexType { isExist, err := session.isIndexExist2(tableNameNoSchema, index.Cols, false) if err != nil { return err } if !isExist { - if err := session.statement.setRefBean(bean); err != nil { + if err := session.statement.SetRefBean(bean); err != nil { return err } @@ -1561,108 +1179,36 @@ func (engine *Engine) SumsInt(bean interface{}, colNames ...string) ([]int64, er // ImportFile SQL DDL file func (engine *Engine) ImportFile(ddlPath string) ([]sql.Result, error) { - file, err := os.Open(ddlPath) - if err != nil { - return nil, err - } - defer file.Close() - return engine.Import(file) + session := engine.NewSession() + defer session.Close() + return session.ImportFile(ddlPath) } // Import SQL DDL from io.Reader func (engine *Engine) Import(r io.Reader) ([]sql.Result, error) { - var results []sql.Result - var lastError error - scanner := bufio.NewScanner(r) - - semiColSpliter := func(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, nil - } - if i := bytes.IndexByte(data, ';'); i >= 0 { - return i + 1, data[0:i], nil - } - // If we're at EOF, we have a final, non-terminated line. Return it. - if atEOF { - return len(data), data, nil - } - // Request more data. - return 0, nil, nil - } - - scanner.Split(semiColSpliter) - - for scanner.Scan() { - query := strings.Trim(scanner.Text(), " \t\n\r") - if len(query) > 0 { - engine.logSQL(query) - result, err := engine.DB().Exec(query) - results = append(results, result) - if err != nil { - return nil, err - } - } - } - - return results, lastError + session := engine.NewSession() + defer session.Close() + return session.Import(r) } // nowTime return current time -func (engine *Engine) nowTime(col *core.Column) (interface{}, time.Time) { +func (engine *Engine) nowTime(col *schemas.Column) (interface{}, time.Time) { t := time.Now() var tz = engine.DatabaseTZ if !col.DisableTimeZone && col.TimeZone != nil { tz = col.TimeZone } - return engine.formatTime(col.SQLType.Name, t.In(tz)), t.In(engine.TZLocation) -} - -func (engine *Engine) formatColTime(col *core.Column, t time.Time) (v interface{}) { - if t.IsZero() { - if col.Nullable { - return nil - } - return "" - } - - if col.TimeZone != nil { - return engine.formatTime(col.SQLType.Name, t.In(col.TimeZone)) - } - return engine.formatTime(col.SQLType.Name, t.In(engine.DatabaseTZ)) -} - -// formatTime format time as column type -func (engine *Engine) formatTime(sqlTypeName string, t time.Time) (v interface{}) { - switch sqlTypeName { - case core.Time: - s := t.Format("2006-01-02 15:04:05") // time.RFC3339 - v = s[11:19] - case core.Date: - v = t.Format("2006-01-02") - case core.DateTime, core.TimeStamp: - v = t.Format("2006-01-02 15:04:05") - case core.TimeStampz: - if engine.dialect.DBType() == core.MSSQL { - v = t.Format("2006-01-02T15:04:05.9999999Z07:00") - } else { - v = t.Format(time.RFC3339Nano) - } - case core.BigInt, core.Int: - v = t.Unix() - default: - v = t - } - return + return dialects.FormatTime(engine.dialect, col.SQLType.Name, t.In(tz)), t.In(engine.TZLocation) } // GetColumnMapper returns the column name mapper -func (engine *Engine) GetColumnMapper() core.IMapper { - return engine.ColumnMapper +func (engine *Engine) GetColumnMapper() names.Mapper { + return engine.tagParser.GetColumnMapper() } // GetTableMapper returns the table name mapper -func (engine *Engine) GetTableMapper() core.IMapper { - return engine.TableMapper +func (engine *Engine) GetTableMapper() names.Mapper { + return engine.tagParser.GetTableMapper() } // GetTZLocation returns time zone of the application @@ -1687,7 +1233,11 @@ func (engine *Engine) SetTZDatabase(tz *time.Location) { // SetSchema sets the schema of database func (engine *Engine) SetSchema(schema string) { - engine.dialect.URI().Schema = schema + engine.dialect.URI().SetSchema(schema) +} + +func (engine *Engine) AddHook(hook contexts.Hook) { + engine.db.AddHook(hook) } // Unscoped always disable struct tag "deleted" @@ -1696,3 +1246,47 @@ func (engine *Engine) Unscoped() *Session { session.isAutoClose = true return session.Unscoped() } + +func (engine *Engine) tbNameWithSchema(v string) string { + return dialects.TableNameWithSchema(engine.dialect, v) +} + +// ContextHook creates a session with the context +func (engine *Engine) Context(ctx context.Context) *Session { + session := engine.NewSession() + session.isAutoClose = true + return session.Context(ctx) +} + +// SetDefaultContext set the default context +func (engine *Engine) SetDefaultContext(ctx context.Context) { + engine.defaultContext = ctx +} + +// PingContext tests if database is alive +func (engine *Engine) PingContext(ctx context.Context) error { + session := engine.NewSession() + defer session.Close() + return session.PingContext(ctx) +} + +// Transaction Execute sql wrapped in a transaction(abbr as tx), tx will automatic commit if no errors occurred +func (engine *Engine) Transaction(f func(*Session) (interface{}, error)) (interface{}, error) { + session := engine.NewSession() + defer session.Close() + + if err := session.Begin(); err != nil { + return nil, err + } + + result, err := f(session) + if err != nil { + return result, err + } + + if err := session.Commit(); err != nil { + return result, err + } + + return result, nil +} diff --git a/engine_cond.go b/engine_cond.go deleted file mode 100644 index 702ac804..00000000 --- a/engine_cond.go +++ /dev/null @@ -1,232 +0,0 @@ -// Copyright 2017 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package xorm - -import ( - "database/sql/driver" - "fmt" - "reflect" - "strings" - "time" - - "xorm.io/builder" - "xorm.io/core" -) - -func (engine *Engine) buildConds(table *core.Table, bean interface{}, - includeVersion bool, includeUpdated bool, includeNil bool, - includeAutoIncr bool, allUseBool bool, useAllCols bool, unscoped bool, - mustColumnMap map[string]bool, tableName, aliasName string, addedTableName bool) (builder.Cond, error) { - var conds []builder.Cond - for _, col := range table.Columns() { - if !includeVersion && col.IsVersion { - continue - } - if !includeUpdated && col.IsUpdated { - continue - } - if !includeAutoIncr && col.IsAutoIncrement { - continue - } - - if engine.dialect.DBType() == core.MSSQL && (col.SQLType.Name == core.Text || col.SQLType.IsBlob() || col.SQLType.Name == core.TimeStampz) { - continue - } - if col.SQLType.IsJson() { - continue - } - - var colName string - if addedTableName { - var nm = tableName - if len(aliasName) > 0 { - nm = aliasName - } - colName = engine.Quote(nm) + "." + engine.Quote(col.Name) - } else { - colName = engine.Quote(col.Name) - } - - fieldValuePtr, err := col.ValueOf(bean) - if err != nil { - if !strings.Contains(err.Error(), "is not valid") { - engine.logger.Warn(err) - } - continue - } - - if col.IsDeleted && !unscoped { // tag "deleted" is enabled - conds = append(conds, engine.CondDeleted(colName)) - } - - fieldValue := *fieldValuePtr - if fieldValue.Interface() == nil { - continue - } - - fieldType := reflect.TypeOf(fieldValue.Interface()) - requiredField := useAllCols - - if b, ok := getFlagForColumn(mustColumnMap, col); ok { - if b { - requiredField = true - } else { - continue - } - } - - if fieldType.Kind() == reflect.Ptr { - if fieldValue.IsNil() { - if includeNil { - conds = append(conds, builder.Eq{colName: nil}) - } - continue - } else if !fieldValue.IsValid() { - continue - } else { - // dereference ptr type to instance type - fieldValue = fieldValue.Elem() - fieldType = reflect.TypeOf(fieldValue.Interface()) - requiredField = true - } - } - - var val interface{} - switch fieldType.Kind() { - case reflect.Bool: - if allUseBool || requiredField { - val = fieldValue.Interface() - } else { - // if a bool in a struct, it will not be as a condition because it default is false, - // please use Where() instead - continue - } - case reflect.String: - if !requiredField && fieldValue.String() == "" { - continue - } - // for MyString, should convert to string or panic - if fieldType.String() != reflect.String.String() { - val = fieldValue.String() - } else { - val = fieldValue.Interface() - } - case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64: - if !requiredField && fieldValue.Int() == 0 { - continue - } - val = fieldValue.Interface() - case reflect.Float32, reflect.Float64: - if !requiredField && fieldValue.Float() == 0.0 { - continue - } - val = fieldValue.Interface() - case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64: - if !requiredField && fieldValue.Uint() == 0 { - continue - } - t := int64(fieldValue.Uint()) - val = reflect.ValueOf(&t).Interface() - case reflect.Struct: - if fieldType.ConvertibleTo(core.TimeType) { - t := fieldValue.Convert(core.TimeType).Interface().(time.Time) - if !requiredField && (t.IsZero() || !fieldValue.IsValid()) { - continue - } - val = engine.formatColTime(col, t) - } else if _, ok := reflect.New(fieldType).Interface().(core.Conversion); ok { - continue - } else if valNul, ok := fieldValue.Interface().(driver.Valuer); ok { - val, _ = valNul.Value() - if val == nil { - continue - } - } else { - if col.SQLType.IsJson() { - if col.SQLType.IsText() { - bytes, err := DefaultJSONHandler.Marshal(fieldValue.Interface()) - if err != nil { - engine.logger.Error(err) - continue - } - val = string(bytes) - } else if col.SQLType.IsBlob() { - var bytes []byte - var err error - bytes, err = DefaultJSONHandler.Marshal(fieldValue.Interface()) - if err != nil { - engine.logger.Error(err) - continue - } - val = bytes - } - } else { - engine.autoMapType(fieldValue) - if table, ok := engine.Tables[fieldValue.Type()]; ok { - if len(table.PrimaryKeys) == 1 { - pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName) - // fix non-int pk issues - //if pkField.Int() != 0 { - if pkField.IsValid() && !isZero(pkField.Interface()) { - val = pkField.Interface() - } else { - continue - } - } else { - //TODO: how to handler? - return nil, fmt.Errorf("not supported %v as %v", fieldValue.Interface(), table.PrimaryKeys) - } - } else { - val = fieldValue.Interface() - } - } - } - case reflect.Array: - continue - case reflect.Slice, reflect.Map: - if fieldValue == reflect.Zero(fieldType) { - continue - } - if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 { - continue - } - - if col.SQLType.IsText() { - bytes, err := DefaultJSONHandler.Marshal(fieldValue.Interface()) - if err != nil { - engine.logger.Error(err) - continue - } - val = string(bytes) - } else if col.SQLType.IsBlob() { - var bytes []byte - var err error - if (fieldType.Kind() == reflect.Array || fieldType.Kind() == reflect.Slice) && - fieldType.Elem().Kind() == reflect.Uint8 { - if fieldValue.Len() > 0 { - val = fieldValue.Bytes() - } else { - continue - } - } else { - bytes, err = DefaultJSONHandler.Marshal(fieldValue.Interface()) - if err != nil { - engine.logger.Error(err) - continue - } - val = bytes - } - } else { - continue - } - default: - val = fieldValue.Interface() - } - - conds = append(conds, builder.Eq{colName: val}) - } - - return builder.And(conds...), nil -} diff --git a/engine_context.go b/engine_context.go deleted file mode 100644 index c6cbb76c..00000000 --- a/engine_context.go +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright 2019 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build go1.8 - -package xorm - -import "context" - -// Context creates a session with the context -func (engine *Engine) Context(ctx context.Context) *Session { - session := engine.NewSession() - session.isAutoClose = true - return session.Context(ctx) -} - -// SetDefaultContext set the default context -func (engine *Engine) SetDefaultContext(ctx context.Context) { - engine.defaultContext = ctx -} - -// PingContext tests if database is alive -func (engine *Engine) PingContext(ctx context.Context) error { - session := engine.NewSession() - defer session.Close() - return session.PingContext(ctx) -} diff --git a/engine_context_test.go b/engine_context_test.go deleted file mode 100644 index 1a3276ce..00000000 --- a/engine_context_test.go +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright 2017 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build go1.8 - -package xorm - -import ( - "context" - "testing" - "time" - - "github.com/stretchr/testify/assert" -) - -func TestPingContext(t *testing.T) { - assert.NoError(t, prepareEngine()) - - ctx, canceled := context.WithTimeout(context.Background(), time.Nanosecond) - defer canceled() - - time.Sleep(time.Nanosecond) - - err := testEngine.(*Engine).PingContext(ctx) - assert.Error(t, err) - assert.Contains(t, err.Error(), "context deadline exceeded") -} diff --git a/engine_group.go b/engine_group.go index 42d49eca..cdd9dd44 100644 --- a/engine_group.go +++ b/engine_group.go @@ -8,7 +8,11 @@ import ( "context" "time" - "xorm.io/core" + "xorm.io/xorm/caches" + "xorm.io/xorm/contexts" + "xorm.io/xorm/dialects" + "xorm.io/xorm/log" + "xorm.io/xorm/names" ) // EngineGroup defines an engine group @@ -75,7 +79,7 @@ func (eg *EngineGroup) Close() error { return nil } -// Context returned a group session +// ContextHook returned a group session func (eg *EngineGroup) Context(ctx context.Context) *Session { sess := eg.NewSession() sess.isAutoClose = true @@ -109,10 +113,10 @@ func (eg *EngineGroup) Ping() error { } // SetColumnMapper set the column name mapping rule -func (eg *EngineGroup) SetColumnMapper(mapper core.IMapper) { - eg.Engine.ColumnMapper = mapper +func (eg *EngineGroup) SetColumnMapper(mapper names.Mapper) { + eg.Engine.SetColumnMapper(mapper) for i := 0; i < len(eg.slaves); i++ { - eg.slaves[i].ColumnMapper = mapper + eg.slaves[i].SetColumnMapper(mapper) } } @@ -125,7 +129,7 @@ func (eg *EngineGroup) SetConnMaxLifetime(d time.Duration) { } // SetDefaultCacher set the default cacher -func (eg *EngineGroup) SetDefaultCacher(cacher core.Cacher) { +func (eg *EngineGroup) SetDefaultCacher(cacher caches.Cacher) { eg.Engine.SetDefaultCacher(cacher) for i := 0; i < len(eg.slaves); i++ { eg.slaves[i].SetDefaultCacher(cacher) @@ -133,15 +137,22 @@ func (eg *EngineGroup) SetDefaultCacher(cacher core.Cacher) { } // SetLogger set the new logger -func (eg *EngineGroup) SetLogger(logger core.ILogger) { +func (eg *EngineGroup) SetLogger(logger interface{}) { eg.Engine.SetLogger(logger) for i := 0; i < len(eg.slaves); i++ { eg.slaves[i].SetLogger(logger) } } +func (eg *EngineGroup) AddHook(hook contexts.Hook) { + eg.Engine.AddHook(hook) + for i := 0; i < len(eg.slaves); i++ { + eg.slaves[i].AddHook(hook) + } +} + // SetLogLevel sets the logger level -func (eg *EngineGroup) SetLogLevel(level core.LogLevel) { +func (eg *EngineGroup) SetLogLevel(level log.LogLevel) { eg.Engine.SetLogLevel(level) for i := 0; i < len(eg.slaves); i++ { eg.slaves[i].SetLogLevel(level) @@ -149,7 +160,7 @@ func (eg *EngineGroup) SetLogLevel(level core.LogLevel) { } // SetMapper set the name mapping rules -func (eg *EngineGroup) SetMapper(mapper core.IMapper) { +func (eg *EngineGroup) SetMapper(mapper names.Mapper) { eg.Engine.SetMapper(mapper) for i := 0; i < len(eg.slaves); i++ { eg.slaves[i].SetMapper(mapper) @@ -158,17 +169,17 @@ func (eg *EngineGroup) SetMapper(mapper core.IMapper) { // SetMaxIdleConns set the max idle connections on pool, default is 2 func (eg *EngineGroup) SetMaxIdleConns(conns int) { - eg.Engine.db.SetMaxIdleConns(conns) + eg.Engine.DB().SetMaxIdleConns(conns) for i := 0; i < len(eg.slaves); i++ { - eg.slaves[i].db.SetMaxIdleConns(conns) + eg.slaves[i].DB().SetMaxIdleConns(conns) } } // SetMaxOpenConns is only available for go 1.2+ func (eg *EngineGroup) SetMaxOpenConns(conns int) { - eg.Engine.db.SetMaxOpenConns(conns) + eg.Engine.DB().SetMaxOpenConns(conns) for i := 0; i < len(eg.slaves); i++ { - eg.slaves[i].db.SetMaxOpenConns(conns) + eg.slaves[i].DB().SetMaxOpenConns(conns) } } @@ -178,19 +189,19 @@ func (eg *EngineGroup) SetPolicy(policy GroupPolicy) *EngineGroup { return eg } -// SetTableMapper set the table name mapping rule -func (eg *EngineGroup) SetTableMapper(mapper core.IMapper) { - eg.Engine.TableMapper = mapper +// SetQuotePolicy sets the special quote policy +func (eg *EngineGroup) SetQuotePolicy(quotePolicy dialects.QuotePolicy) { + eg.Engine.SetQuotePolicy(quotePolicy) for i := 0; i < len(eg.slaves); i++ { - eg.slaves[i].TableMapper = mapper + eg.slaves[i].SetQuotePolicy(quotePolicy) } } -// ShowExecTime show SQL statement and execute time or not on logger if log level is great than INFO -func (eg *EngineGroup) ShowExecTime(show ...bool) { - eg.Engine.ShowExecTime(show...) +// SetTableMapper set the table name mapping rule +func (eg *EngineGroup) SetTableMapper(mapper names.Mapper) { + eg.Engine.SetTableMapper(mapper) for i := 0; i < len(eg.slaves); i++ { - eg.slaves[i].ShowExecTime(show...) + eg.slaves[i].SetTableMapper(mapper) } } diff --git a/engine_group_policy.go b/engine_group_policy.go index 5b56e899..1def8ce4 100644 --- a/engine_group_policy.go +++ b/engine_group_policy.go @@ -51,6 +51,7 @@ func WeightRandomPolicy(weights []int) GroupPolicyHandler { } } +// RoundRobinPolicy returns a group policy handler func RoundRobinPolicy() GroupPolicyHandler { var pos = -1 var lock sync.Mutex @@ -68,6 +69,7 @@ func RoundRobinPolicy() GroupPolicyHandler { } } +// WeightRoundRobinPolicy returns a group policy handler func WeightRoundRobinPolicy(weights []int) GroupPolicyHandler { var rands = make([]int, 0, len(weights)) for i := 0; i < len(weights); i++ { diff --git a/engine_table.go b/engine_table.go deleted file mode 100644 index eb5aa850..00000000 --- a/engine_table.go +++ /dev/null @@ -1,113 +0,0 @@ -// Copyright 2018 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package xorm - -import ( - "fmt" - "reflect" - "strings" - - "xorm.io/core" -) - -// tbNameWithSchema will automatically add schema prefix on table name -func (engine *Engine) tbNameWithSchema(v string) string { - // Add schema name as prefix of table name. - // Only for postgres database. - if engine.dialect.DBType() == core.POSTGRES && - engine.dialect.URI().Schema != "" && - engine.dialect.URI().Schema != postgresPublicSchema && - strings.Index(v, ".") == -1 { - return engine.dialect.URI().Schema + "." + v - } - return v -} - -// TableName returns table name with schema prefix if has -func (engine *Engine) TableName(bean interface{}, includeSchema ...bool) string { - tbName := engine.tbNameNoSchema(bean) - if len(includeSchema) > 0 && includeSchema[0] { - tbName = engine.tbNameWithSchema(tbName) - } - - return tbName -} - -// tbName get some table's table name -func (session *Session) tbNameNoSchema(table *core.Table) string { - if len(session.statement.AltTableName) > 0 { - return session.statement.AltTableName - } - - return table.Name -} - -func (engine *Engine) tbNameForMap(v reflect.Value) string { - if v.Type().Implements(tpTableName) { - return v.Interface().(TableName).TableName() - } - if v.Kind() == reflect.Ptr { - v = v.Elem() - if v.Type().Implements(tpTableName) { - return v.Interface().(TableName).TableName() - } - } - - return engine.TableMapper.Obj2Table(v.Type().Name()) -} - -func (engine *Engine) tbNameNoSchema(tablename interface{}) string { - switch tablename.(type) { - case []string: - t := tablename.([]string) - if len(t) > 1 { - return fmt.Sprintf("%v AS %v", engine.Quote(t[0]), engine.Quote(t[1])) - } else if len(t) == 1 { - return engine.Quote(t[0]) - } - case []interface{}: - t := tablename.([]interface{}) - l := len(t) - var table string - if l > 0 { - f := t[0] - switch f.(type) { - case string: - table = f.(string) - case TableName: - table = f.(TableName).TableName() - default: - v := rValue(f) - t := v.Type() - if t.Kind() == reflect.Struct { - table = engine.tbNameForMap(v) - } else { - table = engine.Quote(fmt.Sprintf("%v", f)) - } - } - } - if l > 1 { - return fmt.Sprintf("%v AS %v", engine.Quote(table), - engine.Quote(fmt.Sprintf("%v", t[1]))) - } else if l == 1 { - return engine.Quote(table) - } - case TableName: - return tablename.(TableName).TableName() - case string: - return tablename.(string) - case reflect.Value: - v := tablename.(reflect.Value) - return engine.tbNameForMap(v) - default: - v := rValue(tablename) - t := v.Type() - if t.Kind() == reflect.Struct { - return engine.tbNameForMap(v) - } - return engine.Quote(fmt.Sprintf("%v", tablename)) - } - return "" -} diff --git a/engine_test.go b/engine_test.go deleted file mode 100644 index 50522f5f..00000000 --- a/engine_test.go +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright 2019 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package xorm - -import ( - "strings" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestQuoteTo(t *testing.T) { - - test := func(t *testing.T, expected string, value string) { - buf := &strings.Builder{} - quoteTo(buf, "[]", value) - assert.EqualValues(t, expected, buf.String()) - } - - test(t, "[mytable]", "mytable") - test(t, "[mytable]", "`mytable`") - test(t, "[mytable]", `[mytable]`) - - test(t, `["mytable"]`, `"mytable"`) - - test(t, "[myschema].[mytable]", "myschema.mytable") - test(t, "[myschema].[mytable]", "`myschema`.mytable") - test(t, "[myschema].[mytable]", "myschema.`mytable`") - test(t, "[myschema].[mytable]", "`myschema`.`mytable`") - test(t, "[myschema].[mytable]", `[myschema].mytable`) - test(t, "[myschema].[mytable]", `myschema.[mytable]`) - test(t, "[myschema].[mytable]", `[myschema].[mytable]`) - - test(t, `["myschema].[mytable"]`, `"myschema.mytable"`) - - buf := &strings.Builder{} - quoteTo(buf, "", "noquote") - assert.EqualValues(t, "noquote", buf.String()) -} diff --git a/error.go b/error.go index a67527ac..cfa5c819 100644 --- a/error.go +++ b/error.go @@ -6,10 +6,11 @@ package xorm import ( "errors" - "fmt" ) var ( + // ErrPtrSliceType represents a type error + ErrPtrSliceType = errors.New("A point to a slice is needed") // ErrParamsType params error ErrParamsType = errors.New("Params type error") // ErrTableNotFound table not found error @@ -20,32 +21,6 @@ var ( ErrNotExist = errors.New("Record does not exist") // ErrCacheFailed cache failed error ErrCacheFailed = errors.New("Cache failed") - // ErrNeedDeletedCond delete needs less one condition error - ErrNeedDeletedCond = errors.New("Delete action needs at least one condition") - // ErrNotImplemented not implemented - ErrNotImplemented = errors.New("Not implemented") // ErrConditionType condition type unsupported ErrConditionType = errors.New("Unsupported condition type") - // ErrUnSupportedSQLType parameter of SQL is not supported - ErrUnSupportedSQLType = errors.New("unsupported sql type") ) - -// ErrFieldIsNotExist columns does not exist -type ErrFieldIsNotExist struct { - FieldName string - TableName string -} - -func (e ErrFieldIsNotExist) Error() string { - return fmt.Sprintf("field %s is not valid on table %s", e.FieldName, e.TableName) -} - -// ErrFieldIsNotValid is not valid -type ErrFieldIsNotValid struct { - FieldName string - TableName string -} - -func (e ErrFieldIsNotValid) Error() string { - return fmt.Sprintf("field %s is not valid on table %s", e.FieldName, e.TableName) -} diff --git a/examples/README.md b/examples/README.md deleted file mode 100644 index 666c6cf9..00000000 --- a/examples/README.md +++ /dev/null @@ -1,5 +0,0 @@ -# Xorm Examples - -Notice: all the examples will ask you install extra package `github.com/mattn/go-sqlite3`, since it depends on cgo. You have to compile it after you install a c++ compile. Please see [github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3). - -And then, you can run the examples via `go run xxx.go`. Every go file is a standalone example. diff --git a/examples/cache.go b/examples/cache.go deleted file mode 100644 index 5ad1de1b..00000000 --- a/examples/cache.go +++ /dev/null @@ -1,109 +0,0 @@ -package main - -import ( - "fmt" - "os" - - "xorm.io/xorm" - _ "github.com/mattn/go-sqlite3" -) - -// User describes a user -type User struct { - Id int64 - Name string -} - -func main() { - f := "cache.db" - os.Remove(f) - - Orm, err := xorm.NewEngine("sqlite3", f) - if err != nil { - fmt.Println(err) - return - } - Orm.ShowSQL(true) - cacher := xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000) - Orm.SetDefaultCacher(cacher) - - err = Orm.CreateTables(&User{}) - if err != nil { - fmt.Println(err) - return - } - - _, err = Orm.Insert(&User{Name: "xlw"}) - if err != nil { - fmt.Println(err) - return - } - - var users []User - err = Orm.Find(&users) - if err != nil { - fmt.Println(err) - return - } - - fmt.Println("users:", users) - - var users2 []User - err = Orm.Find(&users2) - if err != nil { - fmt.Println(err) - return - } - - fmt.Println("users2:", users2) - - var users3 []User - err = Orm.Find(&users3) - if err != nil { - fmt.Println(err) - return - } - - fmt.Println("users3:", users3) - - user4 := new(User) - has, err := Orm.ID(1).Get(user4) - if err != nil { - fmt.Println(err) - return - } - - fmt.Println("user4:", has, user4) - - user4.Name = "xiaolunwen" - _, err = Orm.ID(1).Update(user4) - if err != nil { - fmt.Println(err) - return - } - fmt.Println("user4:", user4) - - user5 := new(User) - has, err = Orm.ID(1).Get(user5) - if err != nil { - fmt.Println(err) - return - } - fmt.Println("user5:", has, user5) - - _, err = Orm.ID(1).Delete(new(User)) - if err != nil { - fmt.Println(err) - return - } - - for { - user6 := new(User) - has, err = Orm.ID(1).Get(user6) - if err != nil { - fmt.Println(err) - return - } - fmt.Println("user6:", has, user6) - } -} diff --git a/examples/cachegoroutine.go b/examples/cachegoroutine.go deleted file mode 100644 index c8d8ee69..00000000 --- a/examples/cachegoroutine.go +++ /dev/null @@ -1,108 +0,0 @@ -package main - -import ( - "fmt" - "os" - "time" - - _ "github.com/go-sql-driver/mysql" - "xorm.io/xorm" - _ "github.com/mattn/go-sqlite3" -) - -// User describes a user -type User struct { - Id int64 - Name string -} - -func sqliteEngine() (*xorm.Engine, error) { - os.Remove("./test.db") - return xorm.NewEngine("sqlite3", "./goroutine.db") -} - -func mysqlEngine() (*xorm.Engine, error) { - return xorm.NewEngine("mysql", "root:@/test?charset=utf8") -} - -var u = &User{} - -func test(engine *xorm.Engine) { - err := engine.CreateTables(u) - if err != nil { - fmt.Println(err) - return - } - - size := 500 - queue := make(chan int, size) - - for i := 0; i < size; i++ { - go func(x int) { - //x := i - err := engine.Ping() - if err != nil { - fmt.Println(err) - } else { - for j := 0; j < 10; j++ { - if x+j < 2 { - _, err = engine.Get(u) - } else if x+j < 4 { - users := make([]User, 0) - err = engine.Find(&users) - } else if x+j < 8 { - _, err = engine.Count(u) - } else if x+j < 16 { - _, err = engine.Insert(&User{Name: "xlw"}) - } else if x+j < 32 { - //_, err = engine.ID(1).Delete(u) - _, err = engine.Delete(u) - } - if err != nil { - fmt.Println(err) - queue <- x - return - } - } - fmt.Printf("%v success!\n", x) - } - queue <- x - }(i) - } - - for i := 0; i < size; i++ { - <-queue - } - - //conns := atomic.LoadInt32(&xorm.ConnectionNum) - //fmt.Println("connection number:", conns) - fmt.Println("end") -} - -func main() { - fmt.Println("-----start sqlite go routines-----") - engine, err := sqliteEngine() - if err != nil { - fmt.Println(err) - return - } - engine.ShowSQL(true) - cacher := xorm.NewLRUCacher2(xorm.NewMemoryStore(), time.Hour, 1000) - engine.SetDefaultCacher(cacher) - fmt.Println(engine) - test(engine) - fmt.Println("test end") - engine.Close() - - fmt.Println("-----start mysql go routines-----") - engine, err = mysqlEngine() - engine.ShowSQL(true) - cacher = xorm.NewLRUCacher2(xorm.NewMemoryStore(), time.Hour, 1000) - engine.SetDefaultCacher(cacher) - if err != nil { - fmt.Println(err) - return - } - defer engine.Close() - test(engine) -} diff --git a/examples/conversion.go b/examples/conversion.go deleted file mode 100644 index 62d4a86b..00000000 --- a/examples/conversion.go +++ /dev/null @@ -1,81 +0,0 @@ -package main - -import ( - "errors" - "fmt" - "os" - - "xorm.io/xorm" - _ "github.com/mattn/go-sqlite3" -) - -// Status describes a status -type Status struct { - Name string - Color string -} - -// defines some statuses -var ( - Registered = Status{"Registered", "white"} - Approved = Status{"Approved", "green"} - Removed = Status{"Removed", "red"} - Statuses = map[string]Status{ - Registered.Name: Registered, - Approved.Name: Approved, - Removed.Name: Removed, - } -) - -// FromDB implemented xorm.Conversion convent database data to self -func (s *Status) FromDB(bytes []byte) error { - if r, ok := Statuses[string(bytes)]; ok { - *s = r - return nil - } - return errors.New("no this data") -} - -// ToDB implemented xorm.Conversion convent to database data -func (s *Status) ToDB() ([]byte, error) { - return []byte(s.Name), nil -} - -// User describes a user -type User struct { - Id int64 - Name string - Status Status `xorm:"varchar(40)"` -} - -func main() { - f := "conversion.db" - os.Remove(f) - - Orm, err := xorm.NewEngine("sqlite3", f) - if err != nil { - fmt.Println(err) - return - } - Orm.ShowSQL(true) - err = Orm.CreateTables(&User{}) - if err != nil { - fmt.Println(err) - return - } - - _, err = Orm.Insert(&User{1, "xlw", Registered}) - if err != nil { - fmt.Println(err) - return - } - - users := make([]User, 0) - err = Orm.Find(&users) - if err != nil { - fmt.Println(err) - return - } - - fmt.Println(users) -} diff --git a/examples/derive.go b/examples/derive.go deleted file mode 100644 index 23e7f169..00000000 --- a/examples/derive.go +++ /dev/null @@ -1,70 +0,0 @@ -package main - -import ( - "fmt" - "os" - - "xorm.io/xorm" - _ "github.com/mattn/go-sqlite3" -) - -// User describes a user -type User struct { - Id int64 - Name string -} - -// LoginInfo describes a login information -type LoginInfo struct { - Id int64 - IP string - UserId int64 -} - -// LoginInfo1 describes a login information -type LoginInfo1 struct { - LoginInfo `xorm:"extends"` - UserName string -} - -func main() { - f := "derive.db" - os.Remove(f) - - orm, err := xorm.NewEngine("sqlite3", f) - if err != nil { - fmt.Println(err) - return - } - defer orm.Close() - orm.ShowSQL(true) - err = orm.CreateTables(&User{}, &LoginInfo{}) - if err != nil { - fmt.Println(err) - return - } - - _, err = orm.Insert(&User{1, "xlw"}, &LoginInfo{1, "127.0.0.1", 1}) - if err != nil { - fmt.Println(err) - return - } - - info := LoginInfo{} - _, err = orm.ID(1).Get(&info) - if err != nil { - fmt.Println(err) - return - } - fmt.Println(info) - - infos := make([]LoginInfo1, 0) - err = orm.Sql(`select *, (select name from user where id = login_info.user_id) as user_name from - login_info limit 10`).Find(&infos) - if err != nil { - fmt.Println(err) - return - } - - fmt.Println(infos) -} diff --git a/examples/find.go b/examples/find.go deleted file mode 100644 index ae27a797..00000000 --- a/examples/find.go +++ /dev/null @@ -1,51 +0,0 @@ -package main - -import ( - "fmt" - "os" - "time" - - "xorm.io/xorm" - _ "github.com/mattn/go-sqlite3" -) - -// User describes a user -type User struct { - Id int64 - Name string - Created time.Time `xorm:"created"` - Updated time.Time `xorm:"updated"` -} - -func main() { - f := "conversion.db" - os.Remove(f) - - orm, err := xorm.NewEngine("sqlite3", f) - if err != nil { - fmt.Println(err) - return - } - orm.ShowSQL(true) - - err = orm.CreateTables(&User{}) - if err != nil { - fmt.Println(err) - return - } - - _, err = orm.Insert(&User{Id: 1, Name: "xlw"}) - if err != nil { - fmt.Println(err) - return - } - - users := make([]User, 0) - err = orm.Find(&users) - if err != nil { - fmt.Println(err) - return - } - - fmt.Println(users) -} diff --git a/examples/goroutine.go b/examples/goroutine.go deleted file mode 100644 index d320714a..00000000 --- a/examples/goroutine.go +++ /dev/null @@ -1,108 +0,0 @@ -package main - -import ( - "fmt" - "os" - "runtime" - - _ "github.com/go-sql-driver/mysql" - "xorm.io/xorm" - _ "github.com/mattn/go-sqlite3" -) - -// User describes a user -type User struct { - Id int64 - Name string -} - -func sqliteEngine() (*xorm.Engine, error) { - os.Remove("./test.db") - return xorm.NewEngine("sqlite3", "./goroutine.db") -} - -func mysqlEngine() (*xorm.Engine, error) { - return xorm.NewEngine("mysql", "root:@/test?charset=utf8") -} - -var u = &User{} - -func test(engine *xorm.Engine) { - err := engine.CreateTables(u) - if err != nil { - fmt.Println(err) - return - } - - size := 100 - queue := make(chan int, size) - - for i := 0; i < size; i++ { - go func(x int) { - //x := i - err := engine.Ping() - if err != nil { - fmt.Println(err) - } else { - /*err = engine.(u) - if err != nil { - fmt.Println("Map user failed") - } else {*/ - for j := 0; j < 10; j++ { - if x+j < 2 { - _, err = engine.Get(u) - } else if x+j < 4 { - users := make([]User, 0) - err = engine.Find(&users) - } else if x+j < 8 { - _, err = engine.Count(u) - } else if x+j < 16 { - _, err = engine.Insert(&User{Name: "xlw"}) - } else if x+j < 32 { - _, err = engine.ID(1).Delete(u) - } - if err != nil { - fmt.Println(err) - queue <- x - return - } - } - fmt.Printf("%v success!\n", x) - //} - } - queue <- x - }(i) - } - - for i := 0; i < size; i++ { - <-queue - } - - //conns := atomic.LoadInt32(&xorm.ConnectionNum) - //fmt.Println("connection number:", conns) - fmt.Println("end") -} - -func main() { - runtime.GOMAXPROCS(2) - fmt.Println("-----start sqlite go routines-----") - engine, err := sqliteEngine() - if err != nil { - fmt.Println(err) - return - } - engine.ShowSQL(true) - fmt.Println(engine) - test(engine) - fmt.Println("test end") - engine.Close() - - fmt.Println("-----start mysql go routines-----") - engine, err = mysqlEngine() - if err != nil { - fmt.Println(err) - return - } - defer engine.Close() - test(engine) -} diff --git a/examples/maxconnect.go b/examples/maxconnect.go deleted file mode 100644 index d8d8b0d8..00000000 --- a/examples/maxconnect.go +++ /dev/null @@ -1,108 +0,0 @@ -package main - -import ( - "fmt" - "os" - "runtime" - - _ "github.com/go-sql-driver/mysql" - "xorm.io/xorm" - _ "github.com/mattn/go-sqlite3" -) - -// User describes a user -type User struct { - Id int64 - Name string -} - -func sqliteEngine() (*xorm.Engine, error) { - os.Remove("./test.db") - return xorm.NewEngine("sqlite3", "./goroutine.db") -} - -func mysqlEngine() (*xorm.Engine, error) { - return xorm.NewEngine("mysql", "root:@/test?charset=utf8") -} - -var u = &User{} - -func test(engine *xorm.Engine) { - err := engine.CreateTables(u) - if err != nil { - fmt.Println(err) - return - } - - engine.ShowSQL(true) - engine.SetMaxOpenConns(5) - - size := 1000 - queue := make(chan int, size) - - for i := 0; i < size; i++ { - go func(x int) { - //x := i - err := engine.Ping() - if err != nil { - fmt.Println(err) - } else { - /*err = engine.Map(u) - if err != nil { - fmt.Println("Map user failed") - } else {*/ - for j := 0; j < 10; j++ { - if x+j < 2 { - _, err = engine.Get(u) - } else if x+j < 4 { - users := make([]User, 0) - err = engine.Find(&users) - } else if x+j < 8 { - _, err = engine.Count(u) - } else if x+j < 16 { - _, err = engine.Insert(&User{Name: "xlw"}) - } else if x+j < 32 { - _, err = engine.ID(1).Delete(u) - } - if err != nil { - fmt.Println(err) - queue <- x - return - } - } - fmt.Printf("%v success!\n", x) - //} - } - queue <- x - }(i) - } - - for i := 0; i < size; i++ { - <-queue - } - - fmt.Println("end") -} - -func main() { - runtime.GOMAXPROCS(2) - fmt.Println("create engine") - engine, err := sqliteEngine() - if err != nil { - fmt.Println(err) - return - } - engine.ShowSQL(true) - fmt.Println(engine) - test(engine) - fmt.Println("------------------------") - engine.Close() - - engine, err = mysqlEngine() - if err != nil { - fmt.Println(err) - return - } - defer engine.Close() - test(engine) -} diff --git a/examples/singlemapping.go b/examples/singlemapping.go deleted file mode 100644 index 5c61448b..00000000 --- a/examples/singlemapping.go +++ /dev/null @@ -1,57 +0,0 @@ -package main - -import ( - "fmt" - "os" - - "xorm.io/xorm" - _ "github.com/mattn/go-sqlite3" -) - -// User describes a user -type User struct { - Id int64 - Name string -} - -// LoginInfo describes a login information -type LoginInfo struct { - Id int64 - IP string - UserId int64 - // timestamp should be updated by database, so only allow get from db - TimeStamp string `xorm:"<-"` - // assume - Nonuse int `xorm:"->"` -} - -func main() { - f := "singleMapping.db" - os.Remove(f) - - orm, err := xorm.NewEngine("sqlite3", f) - if err != nil { - fmt.Println(err) - return - } - orm.ShowSQL(true) - err = orm.CreateTables(&User{}, &LoginInfo{}) - if err != nil { - fmt.Println(err) - return - } - - _, err = orm.Insert(&User{1, "xlw"}, &LoginInfo{1, "127.0.0.1", 1, "", 23}) - if err != nil { - fmt.Println(err) - return - } - - info := LoginInfo{} - _, err = orm.ID(1).Get(&info) - if err != nil { - fmt.Println(err) - return - } - fmt.Println(info) -} diff --git a/examples/sync.go b/examples/sync.go deleted file mode 100644 index 92647c0f..00000000 --- a/examples/sync.go +++ /dev/null @@ -1,106 +0,0 @@ -package main - -import ( - "fmt" - - _ "github.com/go-sql-driver/mysql" - "xorm.io/xorm" - _ "github.com/lib/pq" - _ "github.com/mattn/go-sqlite3" -) - -// SyncUser2 describes a user -type SyncUser2 struct { - Id int64 - Name string `xorm:"unique"` - Age int `xorm:"index"` - Title string - Address string - Genre string - Area string - Date int -} - -// SyncLoginInfo2 describes a login information -type SyncLoginInfo2 struct { - Id int64 - IP string `xorm:"index"` - UserId int64 - AddedCol int - // timestamp should be updated by database, so only allow get from db - TimeStamp string - // assume - Nonuse int `xorm:"unique"` - Newa string `xorm:"index"` -} - -func sync(engine *xorm.Engine) error { - return engine.Sync(&SyncLoginInfo2{}, &SyncUser2{}) -} - -func sqliteEngine() (*xorm.Engine, error) { - f := "sync.db" - //os.Remove(f) - - return xorm.NewEngine("sqlite3", f) -} - -func mysqlEngine() (*xorm.Engine, error) { - return xorm.NewEngine("mysql", "root:@/test?charset=utf8") -} - -func postgresEngine() (*xorm.Engine, error) { - return xorm.NewEngine("postgres", "dbname=xorm_test sslmode=disable") -} - -type engineFunc func() (*xorm.Engine, error) - -func main() { - //engines := []engineFunc{sqliteEngine, mysqlEngine, postgresEngine} - //engines := []engineFunc{sqliteEngine} - //engines := []engineFunc{mysqlEngine} - engines := []engineFunc{postgresEngine} - for _, enginefunc := range engines { - Orm, err := enginefunc() - fmt.Println("--------", Orm.DriverName(), "----------") - if err != nil { - fmt.Println(err) - return - } - Orm.ShowSQL(true) - err = sync(Orm) - if err != nil { - fmt.Println(err) - } - - _, err = Orm.Where("id > 0").Delete(&SyncUser2{}) - if err != nil { - fmt.Println(err) - } - - user := &SyncUser2{ - Name: "testsdf", - Age: 15, - Title: "newsfds", - Address: "fasfdsafdsaf", - Genre: "fsafd", - Area: "fafdsafd", - Date: 1000, - } - _, err = Orm.Insert(user) - if err != nil { - fmt.Println(err) - return - } - - isexist, err := Orm.IsTableExist("sync_user2") - if err != nil { - fmt.Println(err) - return - } - if !isexist { - fmt.Println("sync_user2 is not exist") - return - } - } -} diff --git a/examples/tables.go b/examples/tables.go deleted file mode 100644 index fcf49219..00000000 --- a/examples/tables.go +++ /dev/null @@ -1,34 +0,0 @@ -package main - -import ( - "fmt" - "os" - - "xorm.io/xorm" - _ "github.com/mattn/go-sqlite3" -) - -func main() { - if len(os.Args) < 2 { - fmt.Println("need db path") - return - } - - orm, err := xorm.NewEngine("sqlite3", os.Args[1]) - if err != nil { - fmt.Println(err) - return - } - defer orm.Close() - orm.ShowSQL(true) - - tables, err := orm.DBMetas() - if err != nil { - fmt.Println(err) - return - } - - for _, table := range tables { - fmt.Println(table.Name) - } -} diff --git a/go.mod b/go.mod index 6d8b58f4..f6a98156 100644 --- a/go.mod +++ b/go.mod @@ -3,13 +3,12 @@ module xorm.io/xorm go 1.11 require ( - github.com/denisenkom/go-mssqldb v0.0.0-20190707035753-2be1aa521ff4 - github.com/go-sql-driver/mysql v1.4.1 - github.com/kr/pretty v0.1.0 // indirect - github.com/lib/pq v1.0.0 - github.com/mattn/go-sqlite3 v1.10.0 + github.com/denisenkom/go-mssqldb v0.0.0-20200428022330-06a60b6afbbc + github.com/go-sql-driver/mysql v1.5.0 + github.com/lib/pq v1.7.0 + github.com/mattn/go-sqlite3 v2.0.3+incompatible github.com/stretchr/testify v1.4.0 + github.com/syndtr/goleveldb v1.0.0 github.com/ziutek/mymysql v1.5.4 - xorm.io/builder v0.3.6 - xorm.io/core v0.7.2 + xorm.io/builder v0.3.7 ) diff --git a/go.sum b/go.sum index 2102cc5b..2da01eeb 100644 --- a/go.sum +++ b/go.sum @@ -1,149 +1,61 @@ -cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -cloud.google.com/go v0.37.4 h1:glPeL3BQJsbF6aIIYfZizMwc5LTYz250bDMjttbBGAU= -cloud.google.com/go v0.37.4/go.mod h1:NHPJ89PdicEuT9hdPXMROBD91xc5uRDxsMtSB16k7hw= -github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= -github.com/Shopify/sarama v1.19.0/go.mod h1:FVkBWblsNy7DGZRfXLU0O9RCGt5g3g3yEuWXgklEdEo= -github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI= -github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= -github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= -github.com/apache/thrift v0.12.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ= -github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= -github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= +gitea.com/xorm/sqlfiddle v0.0.0-20180821085327-62ce714f951a h1:lSA0F4e9A2NcQSqGqTOXqu2aRi/XEQxDCBwM8yJtE6s= +gitea.com/xorm/sqlfiddle v0.0.0-20180821085327-62ce714f951a/go.mod h1:EXuID2Zs0pAQhH8yz+DNjUbjppKQzKFAn28TMYPB6IU= +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/denisenkom/go-mssqldb v0.0.0-20190707035753-2be1aa521ff4 h1:YcpmyvADGYw5LqMnHqSkyIELsHCGF6PkrmM31V8rF7o= -github.com/denisenkom/go-mssqldb v0.0.0-20190707035753-2be1aa521ff4/go.mod h1:zAg7JM8CkOJ43xKXIj7eRO9kmWm/TW578qo+oDO6tuM= -github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs= -github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU= -github.com/eapache/queue v1.1.0/go.mod h1:6eCeP0CKFpHLu8blIFXhExK/dRa7WDZfr6jVFPTqq+I= +github.com/denisenkom/go-mssqldb v0.0.0-20200428022330-06a60b6afbbc h1:VRRKCwnzqk8QCaRC4os14xoKDdbHqqlJtJA0oc1ZAjg= +github.com/denisenkom/go-mssqldb v0.0.0-20200428022330-06a60b6afbbc/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU= +github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= -github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= -github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= -github.com/go-sql-driver/mysql v1.4.1 h1:g24URVg0OFbNUTx9qqY1IRZ9D9z3iPyi5zKhQZpNwpA= -github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= -github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= -github.com/go-xorm/sqlfiddle v0.0.0-20180821085327-62ce714f951a h1:9wScpmSP5A3Bk8V3XHWUcJmYTh+ZnlHVyc+A4oZYS3Y= -github.com/go-xorm/sqlfiddle v0.0.0-20180821085327-62ce714f951a/go.mod h1:56xuuqnHyryaerycW3BfssRdxQstACi0Epw/yC5E2xM= -github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= -github.com/gogo/protobuf v1.2.0/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= -github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= -github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= -github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs= +github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= +github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe h1:lXe2qZdvpiX5WZkZR4hgp4KJVfY3nMkvmwbVkpv1rVY= +github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= +github.com/golang/protobuf v1.2.0 h1:P3YflyNX/ehuJFLhxviNdFxQPkGK5cDcApsge1SqnvM= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db h1:woRePGFeVFfLKN/pOkfl+p/TAqKOfFu+7KPlMVpok/w= github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= -github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= -github.com/google/go-cmp v0.2.0 h1:+dTQ8DZQJz0Mb/HjFlkptS1FeQ4cWSnN941F8aEG4SQ= -github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= -github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= -github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= -github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= -github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= -github.com/gorilla/mux v1.6.2/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= -github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= +github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= -github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= -github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= -github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= -github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= -github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= -github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= -github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= -github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= -github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= -github.com/lib/pq v1.0.0 h1:X5PMW56eZitiTeO7tKzZxFCSpbFZJtkMMooicw2us9A= -github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= -github.com/mattn/go-sqlite3 v1.10.0 h1:jbhqpg7tQe4SupckyijYiy0mJJ/pRyHvXf7JdWK860o= -github.com/mattn/go-sqlite3 v1.10.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= -github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= -github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= +github.com/lib/pq v1.7.0 h1:h93mCPfUSkaul3Ka/VG8uZdmW1uMHDGxzu0NWHuJmHY= +github.com/lib/pq v1.7.0/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U= +github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/ginkgo v1.7.0 h1:WSHQ+IS43OoUrWtD1/bbclrwK8TTH5hzp+umCiuxHgs= github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/gomega v1.4.3 h1:RE1xgDvH7imwFD45h+u2SgIfERHlS2yNG4DObb5BSKU= github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= -github.com/openzipkin/zipkin-go v0.1.6/go.mod h1:QgAqvLzwWbR/WpD4A3cGpPtJrZXNIiJc5AZX7/PBEpw= -github.com/pierrec/lz4 v2.0.5+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= -github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= -github.com/prometheus/client_golang v0.9.3-0.20190127221311-3c4408c8b829/go.mod h1:p2iRAGwDERtqlqzRXnrOVns+ignqQo//hLXqYxZYVNs= -github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= -github.com/prometheus/client_model v0.0.0-20190115171406-56726106282f/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= -github.com/prometheus/common v0.2.0/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= -github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= -github.com/prometheus/procfs v0.0.0-20190117184657-bf6a532e95b1/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= -github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= -github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/syndtr/goleveldb v1.0.0 h1:fBdIW9lB4Iz0n9khmH8w27SJ3QEJ7+IgjPEwGSZiFdE= +github.com/syndtr/goleveldb v1.0.0/go.mod h1:ZVVdQEZoIme9iO1Ch2Jdy24qqXrMMOU6lpPAyBWyWuQ= github.com/ziutek/mymysql v1.5.4 h1:GB0qdRGsTwQSBVYuVShFBKaXSnSnYYC2d9knnE1LHFs= github.com/ziutek/mymysql v1.5.4/go.mod h1:LMSpPZ6DbqWFxNCHW77HeMg9I646SAhApZ/wKdgO/C0= -go.opencensus.io v0.20.1/go.mod h1:6WKK9ahsWS3RSO+PY9ZHZUfv2irvY6gN279GOPZjmmk= -golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c h1:Vj5n4GlwjmQteupaxJ9+0FNOmBrHfq7vN4btdGoDZgI= golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= -golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= -golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= -golang.org/x/lint v0.0.0-20190301231843-5614ed5bae6f/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= -golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180906233101-161cd47e91fd h1:nTDtHvHSdCn1m6ITfMRqtOd/9+7a3s8RBNOZ3eYZzJA= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190125091013-d26f9f9a57f3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= -golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f h1:wMNYb4v58l5UBM7MYRLPG6ZhfOqbKu7X5eyFl8ZhKvA= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e h1:o3PsSEY8E4eXWkXrIP9YJALUkVZqzHJT5DOasTyn8Vs= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20181122145206-62eef0e2fa9b/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a h1:1BGLXjeY4akVXGgbC9HugT3Jv3hCI0z56oJR5vAMgBU= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= -golang.org/x/tools v0.0.0-20190312170243-e65039ee4138/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -google.golang.org/api v0.3.1/go.mod h1:6wY9I6uQWHQ8EM57III9mq/AjF+i8G65rmVagqKMtkk= -google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= -google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= -google.golang.org/appengine v1.6.0 h1:Tfd7cKwKbFRsI8RMAD3oqqw7JPFRrvFlOsfbgVkjOOw= -google.golang.org/appengine v1.6.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= -google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= -google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= -google.golang.org/genproto v0.0.0-20190404172233-64821d5d2107/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= -google.golang.org/grpc v1.17.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs= -google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= -gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4= gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -xorm.io/builder v0.3.6 h1:ha28mQ2M+TFx96Hxo+iq6tQgnkC9IZkM6D8w9sKHHF8= -xorm.io/builder v0.3.6/go.mod h1:LEFAPISnRzG+zxaxj2vPicRwz67BdhFreKg8yv8/TgU= -xorm.io/core v0.7.2 h1:mEO22A2Z7a3fPaZMk6gKL/jMD80iiyNwRrX5HOv3XLw= -xorm.io/core v0.7.2/go.mod h1:jJfd0UAEzZ4t87nbQYtVjmqpIODugN6PD2D9E+dJvdM= +xorm.io/builder v0.3.7 h1:2pETdKRK+2QG4mLX4oODHEhn5Z8j1m8sXa7jfu+/SZI= +xorm.io/builder v0.3.7/go.mod h1:aUW0S9eb9VCaPohFCH3j7czOx1PMW3i1HrSzbLYGBSE= diff --git a/helpers.go b/helpers.go deleted file mode 100644 index a31e922c..00000000 --- a/helpers.go +++ /dev/null @@ -1,332 +0,0 @@ -// Copyright 2015 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package xorm - -import ( - "errors" - "fmt" - "reflect" - "sort" - "strconv" - "strings" - - "xorm.io/core" -) - -// str2PK convert string value to primary key value according to tp -func str2PKValue(s string, tp reflect.Type) (reflect.Value, error) { - var err error - var result interface{} - var defReturn = reflect.Zero(tp) - - switch tp.Kind() { - case reflect.Int: - result, err = strconv.Atoi(s) - if err != nil { - return defReturn, fmt.Errorf("convert %s as int: %s", s, err.Error()) - } - case reflect.Int8: - x, err := strconv.Atoi(s) - if err != nil { - return defReturn, fmt.Errorf("convert %s as int8: %s", s, err.Error()) - } - result = int8(x) - case reflect.Int16: - x, err := strconv.Atoi(s) - if err != nil { - return defReturn, fmt.Errorf("convert %s as int16: %s", s, err.Error()) - } - result = int16(x) - case reflect.Int32: - x, err := strconv.Atoi(s) - if err != nil { - return defReturn, fmt.Errorf("convert %s as int32: %s", s, err.Error()) - } - result = int32(x) - case reflect.Int64: - result, err = strconv.ParseInt(s, 10, 64) - if err != nil { - return defReturn, fmt.Errorf("convert %s as int64: %s", s, err.Error()) - } - case reflect.Uint: - x, err := strconv.ParseUint(s, 10, 64) - if err != nil { - return defReturn, fmt.Errorf("convert %s as uint: %s", s, err.Error()) - } - result = uint(x) - case reflect.Uint8: - x, err := strconv.ParseUint(s, 10, 64) - if err != nil { - return defReturn, fmt.Errorf("convert %s as uint8: %s", s, err.Error()) - } - result = uint8(x) - case reflect.Uint16: - x, err := strconv.ParseUint(s, 10, 64) - if err != nil { - return defReturn, fmt.Errorf("convert %s as uint16: %s", s, err.Error()) - } - result = uint16(x) - case reflect.Uint32: - x, err := strconv.ParseUint(s, 10, 64) - if err != nil { - return defReturn, fmt.Errorf("convert %s as uint32: %s", s, err.Error()) - } - result = uint32(x) - case reflect.Uint64: - result, err = strconv.ParseUint(s, 10, 64) - if err != nil { - return defReturn, fmt.Errorf("convert %s as uint64: %s", s, err.Error()) - } - case reflect.String: - result = s - default: - return defReturn, errors.New("unsupported convert type") - } - return reflect.ValueOf(result).Convert(tp), nil -} - -func str2PK(s string, tp reflect.Type) (interface{}, error) { - v, err := str2PKValue(s, tp) - if err != nil { - return nil, err - } - return v.Interface(), nil -} - -func splitTag(tag string) (tags []string) { - tag = strings.TrimSpace(tag) - var hasQuote = false - var lastIdx = 0 - for i, t := range tag { - if t == '\'' { - hasQuote = !hasQuote - } else if t == ' ' { - if lastIdx < i && !hasQuote { - tags = append(tags, strings.TrimSpace(tag[lastIdx:i])) - lastIdx = i + 1 - } - } - } - if lastIdx < len(tag) { - tags = append(tags, strings.TrimSpace(tag[lastIdx:])) - } - return -} - -type zeroable interface { - IsZero() bool -} - -func isZero(k interface{}) bool { - switch k.(type) { - case int: - return k.(int) == 0 - case int8: - return k.(int8) == 0 - case int16: - return k.(int16) == 0 - case int32: - return k.(int32) == 0 - case int64: - return k.(int64) == 0 - case uint: - return k.(uint) == 0 - case uint8: - return k.(uint8) == 0 - case uint16: - return k.(uint16) == 0 - case uint32: - return k.(uint32) == 0 - case uint64: - return k.(uint64) == 0 - case float32: - return k.(float32) == 0 - case float64: - return k.(float64) == 0 - case bool: - return k.(bool) == false - case string: - return k.(string) == "" - case zeroable: - return k.(zeroable).IsZero() - } - return false -} - -func isStructZero(v reflect.Value) bool { - if !v.IsValid() { - return true - } - - for i := 0; i < v.NumField(); i++ { - field := v.Field(i) - switch field.Kind() { - case reflect.Ptr: - field = field.Elem() - fallthrough - case reflect.Struct: - if !isStructZero(field) { - return false - } - default: - if field.CanInterface() && !isZero(field.Interface()) { - return false - } - } - } - return true -} - -func isArrayValueZero(v reflect.Value) bool { - if !v.IsValid() || v.Len() == 0 { - return true - } - - for i := 0; i < v.Len(); i++ { - if !isZero(v.Index(i).Interface()) { - return false - } - } - - return true -} - -func int64ToIntValue(id int64, tp reflect.Type) reflect.Value { - var v interface{} - kind := tp.Kind() - - if kind == reflect.Ptr { - kind = tp.Elem().Kind() - } - - switch kind { - case reflect.Int16: - temp := int16(id) - v = &temp - case reflect.Int32: - temp := int32(id) - v = &temp - case reflect.Int: - temp := int(id) - v = &temp - case reflect.Int64: - temp := id - v = &temp - case reflect.Uint16: - temp := uint16(id) - v = &temp - case reflect.Uint32: - temp := uint32(id) - v = &temp - case reflect.Uint64: - temp := uint64(id) - v = &temp - case reflect.Uint: - temp := uint(id) - v = &temp - } - - if tp.Kind() == reflect.Ptr { - return reflect.ValueOf(v).Convert(tp) - } - return reflect.ValueOf(v).Elem().Convert(tp) -} - -func int64ToInt(id int64, tp reflect.Type) interface{} { - return int64ToIntValue(id, tp).Interface() -} - -func isPKZero(pk core.PK) bool { - for _, k := range pk { - if isZero(k) { - return true - } - } - return false -} - -func indexNoCase(s, sep string) int { - return strings.Index(strings.ToLower(s), strings.ToLower(sep)) -} - -func splitNoCase(s, sep string) []string { - idx := indexNoCase(s, sep) - if idx < 0 { - return []string{s} - } - return strings.Split(s, s[idx:idx+len(sep)]) -} - -func splitNNoCase(s, sep string, n int) []string { - idx := indexNoCase(s, sep) - if idx < 0 { - return []string{s} - } - return strings.SplitN(s, s[idx:idx+len(sep)], n) -} - -func makeArray(elem string, count int) []string { - res := make([]string, count) - for i := 0; i < count; i++ { - res[i] = elem - } - return res -} - -func rValue(bean interface{}) reflect.Value { - return reflect.Indirect(reflect.ValueOf(bean)) -} - -func rType(bean interface{}) reflect.Type { - sliceValue := reflect.Indirect(reflect.ValueOf(bean)) - // return reflect.TypeOf(sliceValue.Interface()) - return sliceValue.Type() -} - -func structName(v reflect.Type) string { - for v.Kind() == reflect.Ptr { - v = v.Elem() - } - return v.Name() -} - -func sliceEq(left, right []string) bool { - if len(left) != len(right) { - return false - } - sort.Sort(sort.StringSlice(left)) - sort.Sort(sort.StringSlice(right)) - for i := 0; i < len(left); i++ { - if left[i] != right[i] { - return false - } - } - return true -} - -func indexName(tableName, idxName string) string { - return fmt.Sprintf("IDX_%v_%v", tableName, idxName) -} - -func eraseAny(value string, strToErase ...string) string { - if len(strToErase) == 0 { - return value - } - var replaceSeq []string - for _, s := range strToErase { - replaceSeq = append(replaceSeq, s, "") - } - - replacer := strings.NewReplacer(replaceSeq...) - - return replacer.Replace(value) -} - -func quoteColumns(cols []string, quoteFunc func(string) string, sep string) string { - for i := range cols { - cols[i] = quoteFunc(cols[i]) - } - return strings.Join(cols, sep+" ") -} diff --git a/helpers_test.go b/helpers_test.go deleted file mode 100644 index caf7b9f0..00000000 --- a/helpers_test.go +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright 2017 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package xorm - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestEraseAny(t *testing.T) { - raw := "SELECT * FROM `table`.[table_name]" - assert.EqualValues(t, raw, eraseAny(raw)) - assert.EqualValues(t, "SELECT * FROM table.[table_name]", eraseAny(raw, "`")) - assert.EqualValues(t, "SELECT * FROM table.table_name", eraseAny(raw, "`", "[", "]")) -} - -func TestQuoteColumns(t *testing.T) { - cols := []string{"f1", "f2", "f3"} - quoteFunc := func(value string) string { - return "[" + value + "]" - } - - assert.EqualValues(t, "[f1], [f2], [f3]", quoteColumns(cols, quoteFunc, ",")) -} diff --git a/helpler_time.go b/helpler_time.go deleted file mode 100644 index f4013e27..00000000 --- a/helpler_time.go +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright 2017 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package xorm - -import "time" - -const ( - zeroTime0 = "0000-00-00 00:00:00" - zeroTime1 = "0001-01-01 00:00:00" -) - -func formatTime(t time.Time) string { - return t.Format("2006-01-02 15:04:05") -} - -func isTimeZero(t time.Time) bool { - return t.IsZero() || formatTime(t) == zeroTime0 || - formatTime(t) == zeroTime1 -} diff --git a/cache_test.go b/integrations/cache_test.go similarity index 91% rename from cache_test.go rename to integrations/cache_test.go index 26d7ac68..44e817b1 100644 --- a/cache_test.go +++ b/integrations/cache_test.go @@ -2,17 +2,19 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package integrations import ( "testing" "time" + "xorm.io/xorm/caches" + "github.com/stretchr/testify/assert" ) func TestCacheFind(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type MailBox struct { Id int64 `xorm:"pk"` @@ -21,7 +23,7 @@ func TestCacheFind(t *testing.T) { } oldCacher := testEngine.GetDefaultCacher() - cacher := NewLRUCacher2(NewMemoryStore(), time.Hour, 10000) + cacher := caches.NewLRUCacher2(caches.NewMemoryStore(), time.Hour, 10000) testEngine.SetDefaultCacher(cacher) assert.NoError(t, testEngine.Sync2(new(MailBox))) @@ -87,7 +89,7 @@ func TestCacheFind(t *testing.T) { } func TestCacheFind2(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type MailBox2 struct { Id uint64 `xorm:"pk"` @@ -96,7 +98,7 @@ func TestCacheFind2(t *testing.T) { } oldCacher := testEngine.GetDefaultCacher() - cacher := NewLRUCacher2(NewMemoryStore(), time.Hour, 10000) + cacher := caches.NewLRUCacher2(caches.NewMemoryStore(), time.Hour, 10000) testEngine.SetDefaultCacher(cacher) assert.NoError(t, testEngine.Sync2(new(MailBox2))) @@ -138,7 +140,7 @@ func TestCacheFind2(t *testing.T) { } func TestCacheGet(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type MailBox3 struct { Id uint64 @@ -147,7 +149,7 @@ func TestCacheGet(t *testing.T) { } oldCacher := testEngine.GetDefaultCacher() - cacher := NewLRUCacher2(NewMemoryStore(), time.Hour, 10000) + cacher := caches.NewLRUCacher2(caches.NewMemoryStore(), time.Hour, 10000) testEngine.SetDefaultCacher(cacher) assert.NoError(t, testEngine.Sync2(new(MailBox3))) diff --git a/integrations/engine_group_test.go b/integrations/engine_group_test.go new file mode 100644 index 00000000..635f73a6 --- /dev/null +++ b/integrations/engine_group_test.go @@ -0,0 +1,35 @@ +// Copyright 2020 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package integrations + +import ( + "testing" + + "xorm.io/xorm" + "xorm.io/xorm/log" + "xorm.io/xorm/schemas" + + "github.com/stretchr/testify/assert" +) + +func TestEngineGroup(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + master := testEngine.(*xorm.Engine) + if master.Dialect().URI().DBType == schemas.SQLITE { + t.Skip() + return + } + + eg, err := xorm.NewEngineGroup(master, []*xorm.Engine{master}) + assert.NoError(t, err) + + eg.SetMaxIdleConns(10) + eg.SetMaxOpenConns(100) + eg.SetTableMapper(master.GetTableMapper()) + eg.SetColumnMapper(master.GetColumnMapper()) + eg.SetLogLevel(log.LOG_INFO) + eg.ShowSQL(true) +} diff --git a/integrations/engine_test.go b/integrations/engine_test.go new file mode 100644 index 00000000..19c5285d --- /dev/null +++ b/integrations/engine_test.go @@ -0,0 +1,141 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package integrations + +import ( + "context" + "fmt" + "os" + "testing" + "time" + + "xorm.io/xorm" + "xorm.io/xorm/schemas" + + _ "github.com/denisenkom/go-mssqldb" + _ "github.com/go-sql-driver/mysql" + _ "github.com/lib/pq" + _ "github.com/mattn/go-sqlite3" + "github.com/stretchr/testify/assert" + _ "github.com/ziutek/mymysql/godrv" +) + +func TestPing(t *testing.T) { + if err := testEngine.Ping(); err != nil { + t.Fatal(err) + } +} + +func TestPingContext(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + ctx, canceled := context.WithTimeout(context.Background(), time.Nanosecond) + defer canceled() + + time.Sleep(time.Nanosecond) + + err := testEngine.(*xorm.Engine).PingContext(ctx) + assert.Error(t, err) + assert.Contains(t, err.Error(), "context deadline exceeded") +} + +func TestAutoTransaction(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type TestTx struct { + Id int64 `xorm:"autoincr pk"` + Msg string `xorm:"varchar(255)"` + Created time.Time `xorm:"created"` + } + + assert.NoError(t, testEngine.Sync2(new(TestTx))) + + engine := testEngine.(*xorm.Engine) + + // will success + engine.Transaction(func(session *xorm.Session) (interface{}, error) { + _, err := session.Insert(TestTx{Msg: "hi"}) + assert.NoError(t, err) + + return nil, nil + }) + + has, err := engine.Exist(&TestTx{Msg: "hi"}) + assert.NoError(t, err) + assert.EqualValues(t, true, has) + + // will rollback + _, err = engine.Transaction(func(session *xorm.Session) (interface{}, error) { + _, err := session.Insert(TestTx{Msg: "hello"}) + assert.NoError(t, err) + + return nil, fmt.Errorf("rollback") + }) + assert.Error(t, err) + + has, err = engine.Exist(&TestTx{Msg: "hello"}) + assert.NoError(t, err) + assert.EqualValues(t, false, has) +} + +func assertSync(t *testing.T, beans ...interface{}) { + for _, bean := range beans { + t.Run(testEngine.TableName(bean, true), func(t *testing.T) { + assert.NoError(t, testEngine.DropTables(bean)) + assert.NoError(t, testEngine.Sync2(bean)) + }) + } +} + +func TestDump(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type TestDumpStruct struct { + Id int64 + Name string + } + + assertSync(t, new(TestDumpStruct)) + + testEngine.Insert([]TestDumpStruct{ + {Name: "1"}, + {Name: "2\n"}, + {Name: "3;"}, + {Name: "4\n;\n''"}, + {Name: "5'\n"}, + }) + + fp := fmt.Sprintf("%v.sql", testEngine.Dialect().URI().DBType) + os.Remove(fp) + assert.NoError(t, testEngine.DumpAllToFile(fp)) + + assert.NoError(t, PrepareEngine()) + + sess := testEngine.NewSession() + defer sess.Close() + assert.NoError(t, sess.Begin()) + _, err := sess.ImportFile(fp) + assert.NoError(t, err) + assert.NoError(t, sess.Commit()) + + for _, tp := range []schemas.DBType{schemas.SQLITE, schemas.MYSQL, schemas.POSTGRES, schemas.MSSQL} { + name := fmt.Sprintf("dump_%v.sql", tp) + t.Run(name, func(t *testing.T) { + assert.NoError(t, testEngine.DumpAllToFile(name, tp)) + }) + } +} + +func TestSetSchema(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + if testEngine.Dialect().URI().DBType == schemas.POSTGRES { + oldSchema := testEngine.Dialect().URI().Schema + testEngine.SetSchema("my_schema") + assert.EqualValues(t, "my_schema", testEngine.Dialect().URI().Schema) + testEngine.SetSchema(oldSchema) + assert.EqualValues(t, oldSchema, testEngine.Dialect().URI().Schema) + } +} diff --git a/types.go b/integrations/main_test.go similarity index 57% rename from types.go rename to integrations/main_test.go index c76a5460..225ae45a 100644 --- a/types.go +++ b/integrations/main_test.go @@ -2,15 +2,12 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package integrations import ( - "reflect" - - "xorm.io/core" + "testing" ) -var ( - ptrPkType = reflect.TypeOf(&core.PK{}) - pkType = reflect.TypeOf(core.PK{}) -) +func TestMain(m *testing.M) { + MainTest(m) +} diff --git a/processors_test.go b/integrations/processors_test.go similarity index 90% rename from processors_test.go rename to integrations/processors_test.go index d1efc047..e349988d 100644 --- a/processors_test.go +++ b/integrations/processors_test.go @@ -2,18 +2,20 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package integrations import ( "errors" "fmt" "testing" + "xorm.io/xorm" + "github.com/stretchr/testify/assert" ) func TestBefore_Get(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type BeforeTable struct { Id int64 @@ -40,7 +42,7 @@ func TestBefore_Get(t *testing.T) { } func TestBefore_Find(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type BeforeTable2 struct { Id int64 @@ -101,7 +103,7 @@ func (p *ProcessorsStruct) BeforeDelete() { p.B4DeleteFlag = 1 } -func (p *ProcessorsStruct) BeforeSet(col string, cell Cell) { +func (p *ProcessorsStruct) BeforeSet(col string, cell xorm.Cell) { p.BeforeSetFlag = p.BeforeSetFlag + 1 } @@ -117,25 +119,19 @@ func (p *ProcessorsStruct) AfterDelete() { p.AfterDeletedFlag = 1 } -func (p *ProcessorsStruct) AfterSet(col string, cell Cell) { +func (p *ProcessorsStruct) AfterSet(col string, cell xorm.Cell) { p.AfterSetFlag = p.AfterSetFlag + 1 } func TestProcessors(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) err := testEngine.DropTables(&ProcessorsStruct{}) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) p := &ProcessorsStruct{} err = testEngine.CreateTables(&ProcessorsStruct{}) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) b4InsertFunc := func(bean interface{}) { if v, ok := (bean).(*ProcessorsStruct); ok { @@ -259,42 +255,22 @@ func TestProcessors(t *testing.T) { _, err = testEngine.Before(b4UpdateFunc).After(afterUpdateFunc).Update(p) assert.NoError(t, err) - if p.B4UpdateFlag == 0 { - t.Error(errors.New("B4UpdateFlag not set")) - } - if p.AfterUpdatedFlag == 0 { - t.Error(errors.New("AfterUpdatedFlag not set")) - } - if p.B4UpdateViaExt == 0 { - t.Error(errors.New("B4UpdateViaExt not set")) - } - if p.AfterUpdatedViaExt == 0 { - t.Error(errors.New("AfterUpdatedViaExt not set")) - } + assert.False(t, p.B4UpdateFlag == 0, "B4UpdateFlag not set") + assert.False(t, p.AfterUpdatedFlag == 0, "AfterUpdatedFlag not set") + assert.False(t, p.B4UpdateViaExt == 0, "B4UpdateViaExt not set") + assert.False(t, p.AfterUpdatedViaExt == 0, "AfterUpdatedViaExt not set") p2 = &ProcessorsStruct{} has, err = testEngine.ID(p.Id).Get(p2) assert.NoError(t, err) assert.True(t, has) - if p2.B4UpdateFlag == 0 { - t.Error(errors.New("B4UpdateFlag not set")) - } - if p2.AfterUpdatedFlag != 0 { - t.Error(errors.New("AfterUpdatedFlag is set: " + string(p.AfterUpdatedFlag))) - } - if p2.B4UpdateViaExt == 0 { - t.Error(errors.New("B4UpdateViaExt not set")) - } - if p2.AfterUpdatedViaExt != 0 { - t.Error(errors.New("AfterUpdatedViaExt is set: " + string(p.AfterUpdatedViaExt))) - } - if p2.BeforeSetFlag != 9 { - t.Error(fmt.Errorf("BeforeSetFlag is %d not 9", p2.BeforeSetFlag)) - } - if p2.AfterSetFlag != 9 { - t.Error(fmt.Errorf("AfterSetFlag is %d not 9", p2.BeforeSetFlag)) - } + assert.False(t, p2.B4UpdateFlag == 0, "B4UpdateFlag not set") + assert.False(t, p2.AfterUpdatedFlag != 0, fmt.Sprintf("AfterUpdatedFlag is set: %d", p.AfterUpdatedFlag)) + assert.False(t, p2.B4UpdateViaExt == 0, "B4UpdateViaExt not set") + assert.False(t, p2.AfterUpdatedViaExt != 0, fmt.Sprintf("AfterUpdatedViaExt is set: %d", p.AfterUpdatedViaExt)) + assert.False(t, p2.BeforeSetFlag != 9, fmt.Sprintf("BeforeSetFlag is %d not 9", p2.BeforeSetFlag)) + assert.False(t, p2.AfterSetFlag != 9, fmt.Sprintf("AfterSetFlag is %d not 9", p2.BeforeSetFlag)) // -- // test delete processors @@ -382,7 +358,7 @@ func TestProcessors(t *testing.T) { } func TestProcessorsTx(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) err := testEngine.DropTables(&ProcessorsStruct{}) assert.NoError(t, err) @@ -450,12 +426,7 @@ func TestProcessorsTx(t *testing.T) { p2 := &ProcessorsStruct{} _, err = testEngine.ID(p.Id).Get(p2) assert.NoError(t, err) - - if p2.Id > 0 { - err = errors.New("tx got committed upon insert!?") - t.Error(err) - panic(err) - } + assert.False(t, p2.Id > 0, "tx got committed upon insert!?") // -- // test insert processors with tx commit @@ -516,7 +487,7 @@ func TestProcessorsTx(t *testing.T) { t.Error(errors.New("AfterInsertedViaExt is set")) } - insertedId := p2.Id + insertedID := p2.Id // -- // test update processors with tx rollback @@ -544,7 +515,7 @@ func TestProcessorsTx(t *testing.T) { p = p2 // reset - _, err = session.ID(insertedId).Before(b4UpdateFunc).After(afterUpdateFunc).Update(p) + _, err = session.ID(insertedID).Before(b4UpdateFunc).After(afterUpdateFunc).Update(p) assert.NoError(t, err) if p.B4UpdateFlag == 0 { @@ -579,7 +550,7 @@ func TestProcessorsTx(t *testing.T) { session.Close() p2 = &ProcessorsStruct{} - _, err = testEngine.ID(insertedId).Get(p2) + _, err = testEngine.ID(insertedID).Get(p2) assert.NoError(t, err) if p2.B4UpdateFlag != 0 { @@ -603,7 +574,7 @@ func TestProcessorsTx(t *testing.T) { err = session.Begin() assert.NoError(t, err) - p = &ProcessorsStruct{Id: insertedId} + p = &ProcessorsStruct{Id: insertedID} _, err = session.Update(p) assert.NoError(t, err) @@ -642,7 +613,7 @@ func TestProcessorsTx(t *testing.T) { p = &ProcessorsStruct{} - _, err = session.ID(insertedId).Before(b4UpdateFunc).After(afterUpdateFunc).Update(p) + _, err = session.ID(insertedID).Before(b4UpdateFunc).After(afterUpdateFunc).Update(p) assert.NoError(t, err) if p.B4UpdateFlag == 0 { @@ -676,7 +647,7 @@ func TestProcessorsTx(t *testing.T) { session.Close() p2 = &ProcessorsStruct{} - _, err = testEngine.ID(insertedId).Get(p2) + _, err = testEngine.ID(insertedID).Get(p2) assert.NoError(t, err) if p.B4UpdateFlag == 0 { @@ -718,7 +689,7 @@ func TestProcessorsTx(t *testing.T) { p = &ProcessorsStruct{} // reset - _, err = session.ID(insertedId).Before(b4DeleteFunc).After(afterDeleteFunc).Delete(p) + _, err = session.ID(insertedID).Before(b4DeleteFunc).After(afterDeleteFunc).Delete(p) assert.NoError(t, err) if p.B4DeleteFlag == 0 { @@ -752,7 +723,7 @@ func TestProcessorsTx(t *testing.T) { session.Close() p2 = &ProcessorsStruct{} - _, err = testEngine.ID(insertedId).Get(p2) + _, err = testEngine.ID(insertedID).Get(p2) assert.NoError(t, err) if p2.B4DeleteFlag != 0 { @@ -778,7 +749,7 @@ func TestProcessorsTx(t *testing.T) { p = &ProcessorsStruct{} - _, err = session.ID(insertedId).Before(b4DeleteFunc).After(afterDeleteFunc).Delete(p) + _, err = session.ID(insertedID).Before(b4DeleteFunc).After(afterDeleteFunc).Delete(p) assert.NoError(t, err) if p.B4DeleteFlag == 0 { @@ -819,7 +790,7 @@ func TestProcessorsTx(t *testing.T) { err = session.Begin() assert.NoError(t, err) - p = &ProcessorsStruct{Id: insertedId} + p = &ProcessorsStruct{Id: insertedID} _, err = session.Delete(p) assert.NoError(t, err) @@ -846,7 +817,6 @@ func TestProcessorsTx(t *testing.T) { t.Error(errors.New("AfterUpdatedFlag set")) } session.Close() - // -- } type AfterLoadStructA struct { @@ -862,19 +832,19 @@ type AfterLoadStructB struct { Err error `xorm:"-"` } -func (s *AfterLoadStructB) AfterLoad(session *Session) { +func (s *AfterLoadStructB) AfterLoad(session *xorm.Session) { has, err := session.ID(s.AId).NoAutoCondition().Get(&s.A) if err != nil { s.Err = err return } if !has { - s.Err = ErrNotExist + s.Err = xorm.ErrNotExist } } func TestAfterLoadProcessor(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(AfterLoadStructA), new(AfterLoadStructB)) @@ -925,7 +895,7 @@ func (a *AfterInsertStruct) AfterInsert() { } func TestAfterInsert(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(AfterInsertStruct)) diff --git a/rows_test.go b/integrations/rows_test.go similarity index 87% rename from rows_test.go rename to integrations/rows_test.go index af333861..f68030a4 100644 --- a/rows_test.go +++ b/integrations/rows_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package integrations import ( "testing" @@ -11,7 +11,7 @@ import ( ) func TestRows(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type UserRows struct { Id int64 @@ -85,7 +85,7 @@ func TestRows(t *testing.T) { } func TestRowsMyTableName(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type UserRowsMyTable struct { Id int64 @@ -104,7 +104,6 @@ func TestRowsMyTableName(t *testing.T) { rows, err := testEngine.Table(tableName).Rows(new(UserRowsMyTable)) assert.NoError(t, err) - defer rows.Close() cnt = 0 user := new(UserRowsMyTable) @@ -114,6 +113,21 @@ func TestRowsMyTableName(t *testing.T) { cnt++ } assert.EqualValues(t, 1, cnt) + + rows.Close() + + rows, err = testEngine.Table(tableName).Rows(&UserRowsMyTable{ + Id: 2, + }) + assert.NoError(t, err) + cnt = 0 + user = new(UserRowsMyTable) + for rows.Next() { + err = rows.Scan(user) + assert.NoError(t, err) + cnt++ + } + assert.EqualValues(t, 0, cnt) } type UserRowsSpecTable struct { @@ -126,7 +140,7 @@ func (UserRowsSpecTable) TableName() string { } func TestRowsSpecTableName(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assert.NoError(t, testEngine.Sync2(new(UserRowsSpecTable))) cnt, err := testEngine.Insert(&UserRowsSpecTable{ diff --git a/session_cols_test.go b/integrations/session_cols_test.go similarity index 93% rename from session_cols_test.go rename to integrations/session_cols_test.go index 96cb1620..b74c6f8a 100644 --- a/session_cols_test.go +++ b/integrations/session_cols_test.go @@ -2,18 +2,18 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package integrations import ( "testing" "github.com/stretchr/testify/assert" "xorm.io/builder" - "xorm.io/core" + "xorm.io/xorm/schemas" ) func TestSetExpr(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type UserExprIssue struct { Id int64 @@ -45,7 +45,7 @@ func TestSetExpr(t *testing.T) { assert.EqualValues(t, 1, cnt) var not = "NOT" - if testEngine.Dialect().DBType() == core.MSSQL { + if testEngine.Dialect().URI().DBType == schemas.MSSQL { not = "~" } cnt, err = testEngine.SetExpr("show", not+" `show`").ID(1).Update(new(UserExpr)) @@ -64,7 +64,7 @@ func TestSetExpr(t *testing.T) { } func TestCols(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type ColsTable struct { Id int64 @@ -96,7 +96,7 @@ func TestCols(t *testing.T) { } func TestMustCol(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type CustomerUpdate struct { Id int64 `form:"id" json:"id"` diff --git a/session_cond_test.go b/integrations/session_cond_test.go similarity index 88% rename from session_cond_test.go rename to integrations/session_cond_test.go index 10650484..a0a91cad 100644 --- a/session_cond_test.go +++ b/integrations/session_cond_test.go @@ -2,19 +2,19 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package integrations import ( "errors" "fmt" "testing" - "xorm.io/builder" "github.com/stretchr/testify/assert" + "xorm.io/builder" ) func TestBuilder(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) const ( OpEqual int = iota @@ -102,7 +102,7 @@ func TestBuilder(t *testing.T) { } func TestIn(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assert.NoError(t, testEngine.Sync2(new(Userinfo))) cnt, err := testEngine.Insert([]Userinfo{ @@ -137,15 +137,13 @@ func TestIn(t *testing.T) { idsStr = idsStr[:len(idsStr)-1] users := make([]Userinfo, 0) - err = testEngine.In("(id)", ids[0], ids[1], ids[2]).Find(&users) + err = testEngine.In("id", ids[0], ids[1], ids[2]).Find(&users) assert.NoError(t, err) - fmt.Println(users) assert.EqualValues(t, 3, len(users)) users = make([]Userinfo, 0) - err = testEngine.In("(id)", ids).Find(&users) + err = testEngine.In("id", ids).Find(&users) assert.NoError(t, err) - fmt.Println(users) assert.EqualValues(t, 3, len(users)) for _, user := range users { @@ -161,9 +159,8 @@ func TestIn(t *testing.T) { idsInterface = append(idsInterface, id) } - err = testEngine.Where(department+" = ?", "dev").In("(id)", idsInterface...).Find(&users) + err = testEngine.Where(department+" = ?", "dev").In("id", idsInterface...).Find(&users) assert.NoError(t, err) - fmt.Println(users) assert.EqualValues(t, 3, len(users)) for _, user := range users { @@ -175,11 +172,10 @@ func TestIn(t *testing.T) { dev := testEngine.GetColumnMapper().Obj2Table("Dev") - err = testEngine.In("(id)", 1).In("(id)", 2).In(department, dev).Find(&users) + err = testEngine.In("id", 1).In("id", 2).In(department, dev).Find(&users) assert.NoError(t, err) - fmt.Println(users) - cnt, err = testEngine.In("(id)", ids[0]).Update(&Userinfo{Departname: "dev-"}) + cnt, err = testEngine.In("id", ids[0]).Update(&Userinfo{Departname: "dev-"}) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) @@ -189,17 +185,17 @@ func TestIn(t *testing.T) { assert.True(t, has) assert.EqualValues(t, "dev-", user.Departname) - cnt, err = testEngine.In("(id)", ids[0]).Update(&Userinfo{Departname: "dev"}) + cnt, err = testEngine.In("id", ids[0]).Update(&Userinfo{Departname: "dev"}) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) - cnt, err = testEngine.In("(id)", ids[1]).Delete(&Userinfo{}) + cnt, err = testEngine.In("id", ids[1]).Delete(&Userinfo{}) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) } func TestFindAndCount(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type FindAndCount struct { Id int64 diff --git a/session_delete_test.go b/integrations/session_delete_test.go similarity index 93% rename from session_delete_test.go rename to integrations/session_delete_test.go index 5edb0718..f3565963 100644 --- a/session_delete_test.go +++ b/integrations/session_delete_test.go @@ -2,18 +2,20 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package integrations import ( "testing" "time" - "xorm.io/core" + "xorm.io/xorm/caches" + "xorm.io/xorm/schemas" + "github.com/stretchr/testify/assert" ) func TestDelete(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type UserinfoDelete struct { Uid int64 `xorm:"id pk not null autoincr"` @@ -26,7 +28,7 @@ func TestDelete(t *testing.T) { defer session.Close() var err error - if testEngine.Dialect().DBType() == core.MSSQL { + if testEngine.Dialect().URI().DBType == schemas.MSSQL { err = session.Begin() assert.NoError(t, err) _, err = session.Exec("SET IDENTITY_INSERT userinfo_delete ON") @@ -38,7 +40,7 @@ func TestDelete(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 1, cnt) - if testEngine.Dialect().DBType() == core.MSSQL { + if testEngine.Dialect().URI().DBType == schemas.MSSQL { err = session.Commit() assert.NoError(t, err) } @@ -69,7 +71,7 @@ func TestDelete(t *testing.T) { } func TestDeleted(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type Deleted struct { Id int64 `xorm:"pk"` @@ -156,10 +158,10 @@ func TestDeleted(t *testing.T) { } func TestCacheDelete(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) oldCacher := testEngine.GetDefaultCacher() - cacher := NewLRUCacher(NewMemoryStore(), 1000) + cacher := caches.NewLRUCacher(caches.NewMemoryStore(), 1000) testEngine.SetDefaultCacher(cacher) type CacheDeleteStruct struct { @@ -188,7 +190,7 @@ func TestCacheDelete(t *testing.T) { } func TestUnscopeDelete(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type UnscopeDeleteStruct struct { Id int64 diff --git a/integrations/session_exist_test.go b/integrations/session_exist_test.go new file mode 100644 index 00000000..6247c91a --- /dev/null +++ b/integrations/session_exist_test.go @@ -0,0 +1,208 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package integrations + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestExistStruct(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type RecordExist struct { + Id int64 + Name string + } + + assertSync(t, new(RecordExist)) + + has, err := testEngine.Exist(new(RecordExist)) + assert.NoError(t, err) + assert.False(t, has) + + cnt, err := testEngine.Insert(&RecordExist{ + Name: "test1", + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + has, err = testEngine.Exist(new(RecordExist)) + assert.NoError(t, err) + assert.True(t, has) + + has, err = testEngine.Exist(&RecordExist{ + Name: "test1", + }) + assert.NoError(t, err) + assert.True(t, has) + + has, err = testEngine.Exist(&RecordExist{ + Name: "test2", + }) + assert.NoError(t, err) + assert.False(t, has) + + has, err = testEngine.Where("name = ?", "test1").Exist(&RecordExist{}) + assert.NoError(t, err) + assert.True(t, has) + + has, err = testEngine.Where("name = ?", "test2").Exist(&RecordExist{}) + assert.NoError(t, err) + assert.False(t, has) + + has, err = testEngine.SQL("select * from "+testEngine.TableName("record_exist", true)+" where name = ?", "test1").Exist() + assert.NoError(t, err) + assert.True(t, has) + + has, err = testEngine.SQL("select * from "+testEngine.TableName("record_exist", true)+" where name = ?", "test2").Exist() + assert.NoError(t, err) + assert.False(t, has) + + has, err = testEngine.Table("record_exist").Exist() + assert.NoError(t, err) + assert.True(t, has) + + has, err = testEngine.Table("record_exist").Where("name = ?", "test1").Exist() + assert.NoError(t, err) + assert.True(t, has) + + has, err = testEngine.Table("record_exist").Where("name = ?", "test2").Exist() + assert.NoError(t, err) + assert.False(t, has) +} + +func TestExistStructForJoin(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type Number struct { + Id int64 + Lid int64 + } + + type OrderList struct { + Id int64 + Eid int64 + } + + type Player struct { + Id int64 + Name string + } + + assert.NoError(t, testEngine.Sync2(new(Number), new(OrderList), new(Player))) + + var ply Player + cnt, err := testEngine.Insert(&ply) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var orderlist = OrderList{ + Eid: ply.Id, + } + cnt, err = testEngine.Insert(&orderlist) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var um = Number{ + Lid: orderlist.Id, + } + cnt, err = testEngine.Insert(&um) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + session := testEngine.NewSession() + defer session.Close() + + session.Table("number"). + Join("INNER", "order_list", "order_list.id = number.lid"). + Join("LEFT", "player", "player.id = order_list.eid"). + Where("number.lid = ?", 1) + has, err := session.Exist() + assert.NoError(t, err) + assert.True(t, has) + + session.Table("number"). + Join("INNER", "order_list", "order_list.id = number.lid"). + Join("LEFT", "player", "player.id = order_list.eid"). + Where("number.lid = ?", 2) + has, err = session.Exist() + assert.NoError(t, err) + assert.False(t, has) + + session.Table("number"). + Select("order_list.id"). + Join("INNER", "order_list", "order_list.id = number.lid"). + Join("LEFT", "player", "player.id = order_list.eid"). + Where("order_list.id = ?", 1) + has, err = session.Exist() + assert.NoError(t, err) + assert.True(t, has) + + session.Table("number"). + Select("player.id"). + Join("INNER", "order_list", "order_list.id = number.lid"). + Join("LEFT", "player", "player.id = order_list.eid"). + Where("player.id = ?", 2) + has, err = session.Exist() + assert.NoError(t, err) + assert.False(t, has) + + session.Table("number"). + Select("player.id"). + Join("INNER", "order_list", "order_list.id = number.lid"). + Join("LEFT", "player", "player.id = order_list.eid") + has, err = session.Exist() + assert.NoError(t, err) + assert.True(t, has) + + err = session.DropTable("order_list") + assert.NoError(t, err) + + exist, err := session.IsTableExist("order_list") + assert.NoError(t, err) + assert.False(t, exist) + + session.Table("number"). + Select("player.id"). + Join("INNER", "order_list", "order_list.id = number.lid"). + Join("LEFT", "player", "player.id = order_list.eid") + has, err = session.Exist() + assert.Error(t, err) + assert.False(t, has) + + session.Table("number"). + Select("player.id"). + Join("LEFT", "player", "player.id = number.lid") + has, err = session.Exist() + assert.NoError(t, err) + assert.True(t, has) +} + +func TestExistContext(t *testing.T) { + type ContextQueryStruct struct { + Id int64 + Name string + } + + assert.NoError(t, PrepareEngine()) + assertSync(t, new(ContextQueryStruct)) + + _, err := testEngine.Insert(&ContextQueryStruct{Name: "1"}) + assert.NoError(t, err) + + ctx, cancel := context.WithTimeout(context.Background(), time.Nanosecond) + defer cancel() + + time.Sleep(time.Nanosecond) + + has, err := testEngine.Context(ctx).Exist(&ContextQueryStruct{Name: "1"}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "context deadline exceeded") + assert.False(t, has) +} diff --git a/session_find_test.go b/integrations/session_find_test.go similarity index 60% rename from session_find_test.go rename to integrations/session_find_test.go index f805f06e..95cf9384 100644 --- a/session_find_test.go +++ b/integrations/session_find_test.go @@ -2,20 +2,20 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package integrations import ( - "errors" - "fmt" "testing" "time" - "xorm.io/core" + "xorm.io/xorm/internal/utils" + "xorm.io/xorm/names" + "github.com/stretchr/testify/assert" ) func TestJoinLimit(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type Salary struct { Id int64 @@ -62,45 +62,27 @@ func TestJoinLimit(t *testing.T) { assert.NoError(t, err) } -func assertSync(t *testing.T, beans ...interface{}) { - for _, bean := range beans { - assert.NoError(t, testEngine.DropTables(bean)) - assert.NoError(t, testEngine.Sync2(bean)) - } -} - func TestWhere(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(Userinfo)) users := make([]Userinfo, 0) - err := testEngine.Where("(id) > ?", 2).Find(&users) - if err != nil { - t.Error(err) - panic(err) - } - fmt.Println(users) + err := testEngine.Where("id > ?", 2).Find(&users) + assert.NoError(t, err) - err = testEngine.Where("(id) > ?", 2).And("(id) < ?", 10).Find(&users) - if err != nil { - t.Error(err) - panic(err) - } - fmt.Println(users) + err = testEngine.Where("id > ?", 2).And("id < ?", 10).Find(&users) + assert.NoError(t, err) } func TestFind(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(Userinfo)) users := make([]Userinfo, 0) err := testEngine.Find(&users) assert.NoError(t, err) - for _, user := range users { - fmt.Println(user) - } users2 := make([]Userinfo, 0) var tbName = testEngine.Quote(testEngine.TableName(new(Userinfo), true)) @@ -109,17 +91,13 @@ func TestFind(t *testing.T) { } func TestFind2(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) users := make([]*Userinfo, 0) assertSync(t, new(Userinfo)) err := testEngine.Find(&users) assert.NoError(t, err) - - for _, user := range users { - fmt.Println(user) - } } type Team struct { @@ -138,7 +116,7 @@ func (TeamUser) TableName() string { func TestFind3(t *testing.T) { var teamUser = new(TeamUser) - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) err := testEngine.Sync2(new(Team), teamUser) assert.NoError(t, err) @@ -192,37 +170,45 @@ func TestFind3(t *testing.T) { } func TestFindMap(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(Userinfo)) + cnt, err := testEngine.Insert(&Userinfo{ + Username: "lunny", + Departname: "depart1", + IsMan: true, + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + users := make(map[int64]Userinfo) - err := testEngine.Find(&users) - if err != nil { - t.Error(err) - panic(err) - } - for _, user := range users { - fmt.Println(user) - } + err = testEngine.Find(&users) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(users)) + assert.EqualValues(t, "lunny", users[1].Username) + assert.EqualValues(t, "depart1", users[1].Departname) + assert.True(t, users[1].IsMan) + + users = make(map[int64]Userinfo) + err = testEngine.Cols("username, departname").Find(&users) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(users)) + assert.EqualValues(t, "lunny", users[1].Username) + assert.EqualValues(t, "depart1", users[1].Departname) + assert.False(t, users[1].IsMan) } func TestFindMap2(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(Userinfo)) users := make(map[int64]*Userinfo) err := testEngine.Find(&users) - if err != nil { - t.Error(err) - panic(err) - } - for id, user := range users { - fmt.Println(id, user) - } + assert.NoError(t, err) } func TestDistinct(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(Userinfo)) _, err := testEngine.Insert(&Userinfo{ @@ -236,8 +222,6 @@ func TestDistinct(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 1, len(users)) - fmt.Println(users) - type Depart struct { Departname string } @@ -245,31 +229,24 @@ func TestDistinct(t *testing.T) { users2 := make([]Depart, 0) err = testEngine.Distinct(departname).Table(new(Userinfo)).Find(&users2) assert.NoError(t, err) - if len(users2) != 1 { - fmt.Println(len(users2)) - t.Error(err) - panic(errors.New("should be one record")) - } - fmt.Println(users2) + assert.EqualValues(t, 1, len(users2)) } func TestOrder(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(Userinfo)) users := make([]Userinfo, 0) err := testEngine.OrderBy("id desc").Find(&users) assert.NoError(t, err) - fmt.Println(users) users2 := make([]Userinfo, 0) err = testEngine.Asc("id", "username").Desc("height").Find(&users2) assert.NoError(t, err) - fmt.Println(users2) } func TestGroupBy(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(Userinfo)) users := make([]Userinfo, 0) @@ -278,207 +255,151 @@ func TestGroupBy(t *testing.T) { } func TestHaving(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(Userinfo)) users := make([]Userinfo, 0) err := testEngine.GroupBy("username").Having("username='xlw'").Find(&users) assert.NoError(t, err) - fmt.Println(users) - - /*users = make([]Userinfo, 0) - err = testEngine.Cols("id, username").GroupBy("username").Having("username='xlw'").Find(&users) - if err != nil { - t.Error(err) - panic(err) - } - fmt.Println(users)*/ } func TestOrderSameMapper(t *testing.T) { - assert.NoError(t, prepareEngine()) - testEngine.UnMapType(rValue(new(Userinfo)).Type()) + assert.NoError(t, PrepareEngine()) + testEngine.UnMapType(utils.ReflectValue(new(Userinfo)).Type()) mapper := testEngine.GetTableMapper() - testEngine.SetMapper(core.SameMapper{}) + testEngine.SetMapper(names.SameMapper{}) defer func() { - testEngine.UnMapType(rValue(new(Userinfo)).Type()) + testEngine.UnMapType(utils.ReflectValue(new(Userinfo)).Type()) testEngine.SetMapper(mapper) }() assertSync(t, new(Userinfo)) users := make([]Userinfo, 0) - err := testEngine.OrderBy("(id) desc").Find(&users) + err := testEngine.OrderBy("id desc").Find(&users) assert.NoError(t, err) - fmt.Println(users) users2 := make([]Userinfo, 0) - err = testEngine.Asc("(id)", "Username").Desc("Height").Find(&users2) + err = testEngine.Asc("id", "Username").Desc("Height").Find(&users2) assert.NoError(t, err) - fmt.Println(users2) } func TestHavingSameMapper(t *testing.T) { - assert.NoError(t, prepareEngine()) - testEngine.UnMapType(rValue(new(Userinfo)).Type()) + assert.NoError(t, PrepareEngine()) + testEngine.UnMapType(utils.ReflectValue(new(Userinfo)).Type()) mapper := testEngine.GetTableMapper() - testEngine.SetMapper(core.SameMapper{}) + testEngine.SetMapper(names.SameMapper{}) defer func() { - testEngine.UnMapType(rValue(new(Userinfo)).Type()) + testEngine.UnMapType(utils.ReflectValue(new(Userinfo)).Type()) testEngine.SetMapper(mapper) }() assertSync(t, new(Userinfo)) users := make([]Userinfo, 0) err := testEngine.GroupBy("`Username`").Having("`Username`='xlw'").Find(&users) - if err != nil { - t.Fatal(err) - } - fmt.Println(users) + assert.NoError(t, err) } func TestFindInts(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(Userinfo)) userinfo := testEngine.GetTableMapper().Obj2Table("Userinfo") var idsInt64 []int64 err := testEngine.Table(userinfo).Cols("id").Desc("id").Find(&idsInt64) - if err != nil { - t.Fatal(err) - } - fmt.Println(idsInt64) + assert.NoError(t, err) var idsInt32 []int32 err = testEngine.Table(userinfo).Cols("id").Desc("id").Find(&idsInt32) - if err != nil { - t.Fatal(err) - } - fmt.Println(idsInt32) + assert.NoError(t, err) var idsInt []int err = testEngine.Table(userinfo).Cols("id").Desc("id").Find(&idsInt) - if err != nil { - t.Fatal(err) - } - fmt.Println(idsInt) + assert.NoError(t, err) var idsUint []uint err = testEngine.Table(userinfo).Cols("id").Desc("id").Find(&idsUint) - if err != nil { - t.Fatal(err) - } - fmt.Println(idsUint) + assert.NoError(t, err) type MyInt int var idsMyInt []MyInt err = testEngine.Table(userinfo).Cols("id").Desc("id").Find(&idsMyInt) - if err != nil { - t.Fatal(err) - } - fmt.Println(idsMyInt) + assert.NoError(t, err) } func TestFindStrings(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(Userinfo)) userinfo := testEngine.GetTableMapper().Obj2Table("Userinfo") username := testEngine.GetColumnMapper().Obj2Table("Username") var idsString []string err := testEngine.Table(userinfo).Cols(username).Desc("id").Find(&idsString) - if err != nil { - t.Fatal(err) - } - fmt.Println(idsString) + assert.NoError(t, err) } func TestFindMyString(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(Userinfo)) userinfo := testEngine.GetTableMapper().Obj2Table("Userinfo") username := testEngine.GetColumnMapper().Obj2Table("Username") var idsMyString []MyString err := testEngine.Table(userinfo).Cols(username).Desc("id").Find(&idsMyString) - if err != nil { - t.Fatal(err) - } - fmt.Println(idsMyString) + assert.NoError(t, err) } func TestFindInterface(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(Userinfo)) userinfo := testEngine.GetTableMapper().Obj2Table("Userinfo") username := testEngine.GetColumnMapper().Obj2Table("Username") var idsInterface []interface{} err := testEngine.Table(userinfo).Cols(username).Desc("id").Find(&idsInterface) - if err != nil { - t.Fatal(err) - } - fmt.Println(idsInterface) + assert.NoError(t, err) } func TestFindSliceBytes(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(Userinfo)) userinfo := testEngine.GetTableMapper().Obj2Table("Userinfo") var ids [][][]byte err := testEngine.Table(userinfo).Desc("id").Find(&ids) - if err != nil { - t.Fatal(err) - } - for _, record := range ids { - fmt.Println(record) - } + assert.NoError(t, err) } func TestFindSlicePtrString(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(Userinfo)) userinfo := testEngine.GetTableMapper().Obj2Table("Userinfo") var ids [][]*string err := testEngine.Table(userinfo).Desc("id").Find(&ids) - if err != nil { - t.Fatal(err) - } - for _, record := range ids { - fmt.Println(record) - } + assert.NoError(t, err) } func TestFindMapBytes(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(Userinfo)) userinfo := testEngine.GetTableMapper().Obj2Table("Userinfo") var ids []map[string][]byte err := testEngine.Table(userinfo).Desc("id").Find(&ids) - if err != nil { - t.Fatal(err) - } - for _, record := range ids { - fmt.Println(record) - } + assert.NoError(t, err) } func TestFindMapPtrString(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(Userinfo)) userinfo := testEngine.GetTableMapper().Obj2Table("Userinfo") var ids []map[string]*string err := testEngine.Table(userinfo).Desc("id").Find(&ids) assert.NoError(t, err) - for _, record := range ids { - fmt.Println(record) - } } func TestFindBit(t *testing.T) { @@ -487,7 +408,7 @@ func TestFindBit(t *testing.T) { Msg bool `xorm:"bit"` } - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(FindBitStruct)) cnt, err := testEngine.Insert([]FindBitStruct{ @@ -515,7 +436,7 @@ func TestFindMark(t *testing.T) { MarkA string `xorm:"VARCHAR(1)"` } - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(Mark)) cnt, err := testEngine.Insert([]Mark{ @@ -546,7 +467,7 @@ func TestFindAndCountOneFunc(t *testing.T) { Msg bool `xorm:"bit"` } - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(FindAndCountStruct)) cnt, err := testEngine.Insert([]FindAndCountStruct{ @@ -563,6 +484,12 @@ func TestFindAndCountOneFunc(t *testing.T) { assert.EqualValues(t, 2, cnt) var results = make([]FindAndCountStruct, 0, 2) + cnt, err = testEngine.Limit(1).FindAndCount(&results) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(results)) + assert.EqualValues(t, 2, cnt) + + results = make([]FindAndCountStruct, 0, 2) cnt, err = testEngine.FindAndCount(&results) assert.NoError(t, err) assert.EqualValues(t, 2, len(results)) @@ -575,10 +502,48 @@ func TestFindAndCountOneFunc(t *testing.T) { assert.EqualValues(t, 1, cnt) results = make([]FindAndCountStruct, 0, 1) - cnt, err = testEngine.Where("msg = ?", true).Limit(1).FindAndCount(&results) + cnt, err = testEngine.Where("1=1").Limit(1).FindAndCount(&results) assert.NoError(t, err) assert.EqualValues(t, 1, len(results)) - assert.EqualValues(t, 1, cnt) + assert.EqualValues(t, 2, cnt) + assert.EqualValues(t, FindAndCountStruct{ + Id: 1, + Content: "111", + Msg: false, + }, results[0]) + + results = make([]FindAndCountStruct, 0, 1) + cnt, err = testEngine.Where("1=1").Limit(1).FindAndCount(&results) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(results)) + assert.EqualValues(t, 2, cnt) + assert.EqualValues(t, FindAndCountStruct{ + Id: 1, + Content: "111", + Msg: false, + }, results[0]) + + results = make([]FindAndCountStruct, 0, 1) + cnt, err = testEngine.Where("1=1").Limit(1, 1).FindAndCount(&results) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(results)) + assert.EqualValues(t, 2, cnt) + assert.EqualValues(t, FindAndCountStruct{ + Id: 2, + Content: "222", + Msg: true, + }, results[0]) + + results = make([]FindAndCountStruct, 0, 1) + cnt, err = testEngine.Where("1=1").Limit(1, 1).FindAndCount(&results) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(results)) + assert.EqualValues(t, 2, cnt) + assert.EqualValues(t, FindAndCountStruct{ + Id: 2, + Content: "222", + Msg: true, + }, results[0]) results = make([]FindAndCountStruct, 0, 1) cnt, err = testEngine.Where("msg = ?", true).Select("id, content, msg"). @@ -589,10 +554,96 @@ func TestFindAndCountOneFunc(t *testing.T) { results = make([]FindAndCountStruct, 0, 1) cnt, err = testEngine.Where("msg = ?", true).Desc("id"). - Limit(1).FindAndCount(&results) + Limit(1).Cols("content").FindAndCount(&results) assert.NoError(t, err) assert.EqualValues(t, 1, len(results)) assert.EqualValues(t, 1, cnt) + + ids := make([]int64, 0, 2) + tableName := testEngine.GetTableMapper().Obj2Table("FindAndCountStruct") + cnt, err = testEngine.Table(tableName).Limit(1).Cols("id").FindAndCount(&ids) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(ids)) + assert.EqualValues(t, 2, cnt) +} + +func TestFindAndCountOneFuncWithDeleted(t *testing.T) { + type CommentWithDeleted struct { + Id int `xorm:"pk autoincr"` + DeletedAt int64 `xorm:"deleted notnull default(0) index"` + } + + assert.NoError(t, PrepareEngine()) + assertSync(t, new(CommentWithDeleted)) + + var comments []CommentWithDeleted + cnt, err := testEngine.FindAndCount(&comments) + assert.NoError(t, err) + assert.EqualValues(t, 0, cnt) +} + +func TestFindAndCount2(t *testing.T) { + // User + type TestFindAndCountUser struct { + Id int64 `xorm:"bigint(11) pk autoincr"` + Name string `xorm:"'name'"` + } + + // Hotel + type TestFindAndCountHotel struct { + Id int64 `xorm:"bigint(11) pk autoincr"` + Name string `xorm:"'name'"` + Code string `xorm:"'code'"` + Region string `xorm:"'region'"` + CreateBy *TestFindAndCountUser `xorm:"'create_by'"` + } + + assert.NoError(t, PrepareEngine()) + assertSync(t, new(TestFindAndCountUser), new(TestFindAndCountHotel)) + + var u = TestFindAndCountUser{ + Name: "myname", + } + cnt, err := testEngine.Insert(&u) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var hotel = TestFindAndCountHotel{ + Name: "myhotel", + Code: "111", + Region: "222", + CreateBy: &u, + } + cnt, err = testEngine.Insert(&hotel) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + hotels := make([]*TestFindAndCountHotel, 0) + cnt, err = testEngine. + Alias("t"). + Limit(10, 0). + FindAndCount(&hotels) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + hotels = make([]*TestFindAndCountHotel, 0) + cnt, err = testEngine. + Table(new(TestFindAndCountHotel)). + Alias("t"). + Limit(10, 0). + FindAndCount(&hotels) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + hotels = make([]*TestFindAndCountHotel, 0) + cnt, err = testEngine. + Table(new(TestFindAndCountHotel)). + Alias("t"). + Where("t.region like '6501%'"). + Limit(10, 0). + FindAndCount(&hotels) + assert.NoError(t, err) + assert.EqualValues(t, 0, cnt) } type FindMapDevice struct { @@ -605,7 +656,7 @@ func (device *FindMapDevice) TableName() string { } func TestFindMapStringId(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(FindMapDevice)) cnt, err := testEngine.Insert(&FindMapDevice{ @@ -676,7 +727,7 @@ func TestFindExtends(t *testing.T) { FindExtendsB `xorm:"extends"` } - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(FindExtendsA)) cnt, err := testEngine.Insert(&FindExtendsA{ @@ -695,6 +746,13 @@ func TestFindExtends(t *testing.T) { err = testEngine.Find(&results) assert.NoError(t, err) assert.EqualValues(t, 2, len(results)) + + results = make([]FindExtendsA, 0, 2) + err = testEngine.Find(&results, &FindExtendsB{ + ID: 1, + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(results)) } func TestFindExtends3(t *testing.T) { @@ -711,7 +769,7 @@ func TestFindExtends3(t *testing.T) { FindExtendsBB `xorm:"extends"` } - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(FindExtendsAA)) cnt, err := testEngine.Insert(&FindExtendsAA{ @@ -747,7 +805,7 @@ func TestFindCacheLimit(t *testing.T) { Created time.Time `xorm:"created"` } - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(InviteCode)) cnt, err := testEngine.Insert(&InviteCode{ @@ -788,8 +846,12 @@ func TestFindJoin(t *testing.T) { DeviceId int64 } - assert.NoError(t, prepareEngine()) - assertSync(t, new(SceneItem), new(DeviceUserPrivrels)) + type Order struct { + Id int64 + } + + assert.NoError(t, PrepareEngine()) + assertSync(t, new(SceneItem), new(DeviceUserPrivrels), new(Order)) var scenes []SceneItem err := testEngine.Join("LEFT OUTER", "device_user_privrels", "device_user_privrels.device_id=scene_item.device_id"). @@ -800,4 +862,96 @@ func TestFindJoin(t *testing.T) { err = testEngine.Join("LEFT OUTER", new(DeviceUserPrivrels), "device_user_privrels.device_id=scene_item.device_id"). Where("scene_item.type=?", 3).Or("device_user_privrels.user_id=?", 339).Find(&scenes) assert.NoError(t, err) + + scenes = make([]SceneItem, 0) + err = testEngine.Join("INNER", "order", "`scene_item`.device_id=`order`.id").Find(&scenes) + assert.NoError(t, err) +} + +func TestJoinFindLimit(t *testing.T) { + type JoinFindLimit1 struct { + Id int64 + Name string + } + + type JoinFindLimit2 struct { + Id int64 + Eid int64 `xorm:"index"` + Name string + } + + assert.NoError(t, PrepareEngine()) + assertSync(t, new(JoinFindLimit1), new(JoinFindLimit2)) + + var finds []JoinFindLimit1 + err := testEngine.Join("INNER", new(JoinFindLimit2), "join_find_limit2.eid=join_find_limit1.id"). + Limit(10, 10).Find(&finds) + assert.NoError(t, err) +} + +func TestMoreExtends(t *testing.T) { + type MoreExtendsUsers struct { + ID int64 `xorm:"id autoincr pk" json:"id"` + Name string `xorm:"name not null" json:"name"` + CreatedAt time.Time `xorm:"created not null" json:"created_at"` + UpdatedAt time.Time `xorm:"updated not null" json:"updated_at"` + DeletedAt time.Time `xorm:"deleted" json:"deleted_at"` + } + + type MoreExtendsBooks struct { + ID int64 `xorm:"id autoincr pk" json:"id"` + Name string `xorm:"name not null" json:"name"` + UserID int64 `xorm:"user_id not null" json:"user_id"` + CreatedAt time.Time `xorm:"created not null" json:"created_at"` + UpdatedAt time.Time `xorm:"updated not null" json:"updated_at"` + DeletedAt time.Time `xorm:"deleted" json:"deleted_at"` + } + + type MoreExtendsBooksExtend struct { + MoreExtendsBooks `xorm:"extends"` + Users MoreExtendsUsers `xorm:"extends" json:"users"` + } + + assert.NoError(t, PrepareEngine()) + assertSync(t, new(MoreExtendsUsers), new(MoreExtendsBooks)) + + var books []MoreExtendsBooksExtend + err := testEngine.Table("more_extends_books").Select("more_extends_books.*, more_extends_users.*"). + Join("INNER", "more_extends_users", "more_extends_books.user_id = more_extends_users.id"). + Where("more_extends_books.name LIKE ?", "abc"). + Limit(10, 10). + Find(&books) + assert.NoError(t, err) + + books = make([]MoreExtendsBooksExtend, 0, len(books)) + err = testEngine.Table("more_extends_books"). + Alias("m"). + Select("m.*, more_extends_users.*"). + Join("INNER", "more_extends_users", "m.user_id = more_extends_users.id"). + Where("m.name LIKE ?", "abc"). + Limit(10, 10). + Find(&books) + assert.NoError(t, err) +} + +func TestDistinctAndCols(t *testing.T) { + type DistinctAndCols struct { + Id int64 + Name string + } + + assert.NoError(t, PrepareEngine()) + assertSync(t, new(DistinctAndCols)) + + cnt, err := testEngine.Insert(&DistinctAndCols{ + Name: "test", + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var names []string + err = testEngine.Table("distinct_and_cols").Cols("name").Distinct("name").Find(&names) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(names)) + assert.EqualValues(t, "test", names[0]) } diff --git a/session_get_test.go b/integrations/session_get_test.go similarity index 88% rename from session_get_test.go rename to integrations/session_get_test.go index fcef992e..4e50f9ab 100644 --- a/session_get_test.go +++ b/integrations/session_get_test.go @@ -2,20 +2,73 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package integrations import ( "database/sql" "fmt" + "strconv" "testing" "time" + "xorm.io/xorm/contexts" + "xorm.io/xorm/schemas" + "github.com/stretchr/testify/assert" - "xorm.io/core" ) +func convertInt(v interface{}) (int64, error) { + switch v.(type) { + case int: + return int64(v.(int)), nil + case int8: + return int64(v.(int8)), nil + case int16: + return int64(v.(int16)), nil + case int32: + return int64(v.(int32)), nil + case int64: + return v.(int64), nil + case []byte: + i, err := strconv.ParseInt(string(v.([]byte)), 10, 64) + if err != nil { + return 0, err + } + return i, nil + case string: + i, err := strconv.ParseInt(v.(string), 10, 64) + if err != nil { + return 0, err + } + return i, nil + } + return 0, fmt.Errorf("unsupported type: %v", v) +} + +func convertFloat(v interface{}) (float64, error) { + switch v.(type) { + case float32: + return float64(v.(float32)), nil + case float64: + return v.(float64), nil + case string: + i, err := strconv.ParseFloat(v.(string), 64) + if err != nil { + return 0, err + } + return i, nil + case []byte: + i, err := strconv.ParseFloat(string(v.([]byte)), 64) + if err != nil { + return 0, err + } + return i, nil + } + return 0, fmt.Errorf("unsupported type: %v", v) +} + func TestGetVar(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type GetVar struct { Id int64 `xorm:"autoincr pk"` @@ -153,7 +206,7 @@ func TestGetVar(t *testing.T) { assert.Equal(t, "1.5", fmt.Sprintf("%.1f", money)) var money2 float64 - if testEngine.Dialect().DBType() == core.MSSQL { + if testEngine.Dialect().URI().DBType == schemas.MSSQL { has, err = testEngine.SQL("SELECT TOP 1 money FROM " + testEngine.TableName("get_var", true)).Get(&money2) } else { has, err = testEngine.SQL("SELECT money FROM " + testEngine.TableName("get_var", true) + " LIMIT 1").Get(&money2) @@ -178,7 +231,7 @@ func TestGetVar(t *testing.T) { assert.Equal(t, "1.5", valuesString["money"]) // for mymysql driver, interface{} will be []byte, so ignore it currently - if testEngine.Dialect().DriverName() != "mymysql" { + if testEngine.DriverName() != "mymysql" { var valuesInter = make(map[string]interface{}) has, err = testEngine.Table("get_var").Where("id = ?", 1).Select("*").Get(&valuesInter) assert.NoError(t, err) @@ -220,7 +273,7 @@ func TestGetVar(t *testing.T) { } func TestGetStruct(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type UserinfoGet struct { Uid int `xorm:"pk autoincr"` @@ -233,7 +286,7 @@ func TestGetStruct(t *testing.T) { defer session.Close() var err error - if testEngine.Dialect().DBType() == core.MSSQL { + if testEngine.Dialect().URI().DBType == schemas.MSSQL { err = session.Begin() assert.NoError(t, err) _, err = session.Exec("SET IDENTITY_INSERT userinfo_get ON") @@ -242,7 +295,7 @@ func TestGetStruct(t *testing.T) { cnt, err := session.Insert(&UserinfoGet{Uid: 2}) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) - if testEngine.Dialect().DBType() == core.MSSQL { + if testEngine.Dialect().URI().DBType == schemas.MSSQL { err = session.Commit() assert.NoError(t, err) } @@ -275,7 +328,7 @@ func TestGetStruct(t *testing.T) { } func TestGetSlice(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type UserinfoSlice struct { Uid int `xorm:"pk autoincr"` @@ -291,7 +344,7 @@ func TestGetSlice(t *testing.T) { } func TestGetError(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type GetError struct { Uid int `xorm:"pk autoincr"` @@ -311,7 +364,7 @@ func TestGetError(t *testing.T) { } func TestJSONString(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type JsonString struct { Id int64 @@ -334,17 +387,17 @@ func TestJSONString(t *testing.T) { assert.NoError(t, err) assert.True(t, has) assert.EqualValues(t, 1, js.Id) - assert.EqualValues(t, `["1","2"]`, js.Content) + assert.True(t, `["1","2"]` == js.Content || `["1", "2"]` == js.Content) var jss []JsonString err = testEngine.Table("json_json").Find(&jss) assert.NoError(t, err) assert.EqualValues(t, 1, len(jss)) - assert.EqualValues(t, `["1","2"]`, jss[0].Content) + assert.True(t, `["1","2"]` == jss[0].Content || `["1", "2"]` == jss[0].Content) } func TestGetActionMapping(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type ActionMapping struct { ActionId string `xorm:"pk"` @@ -381,7 +434,7 @@ func TestGetStructId(t *testing.T) { Id int64 } - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(TestGetStruct)) _, err := testEngine.Insert(&TestGetStruct{}) @@ -408,7 +461,7 @@ func TestContextGet(t *testing.T) { Name string } - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(ContextGetStruct)) _, err := testEngine.Insert(&ContextGetStruct{Name: "1"}) @@ -417,7 +470,7 @@ func TestContextGet(t *testing.T) { sess := testEngine.NewSession() defer sess.Close() - context := NewMemoryContextCache() + context := contexts.NewMemoryContextCache() var c2 ContextGetStruct has, err := sess.ID(1).NoCache().ContextCache(context).Get(&c2) @@ -446,13 +499,13 @@ func TestContextGet2(t *testing.T) { Name string } - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(ContextGetStruct2)) _, err := testEngine.Insert(&ContextGetStruct2{Name: "1"}) assert.NoError(t, err) - context := NewMemoryContextCache() + context := contexts.NewMemoryContextCache() var c2 ContextGetStruct2 has, err := testEngine.ID(1).NoCache().ContextCache(context).Get(&c2) @@ -480,12 +533,12 @@ type MyGetCustomTableImpletation struct { const getCustomTableName = "GetCustomTableInterface" -func (m *MyGetCustomTableImpletation) TableName() string { +func (MyGetCustomTableImpletation) TableName() string { return getCustomTableName } func TestGetCustomTableInterface(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assert.NoError(t, testEngine.Table(getCustomTableName).Sync2(new(MyGetCustomTableImpletation))) exist, err := testEngine.IsTableExist(getCustomTableName) @@ -510,7 +563,7 @@ func TestGetNullVar(t *testing.T) { Age int } - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(TestGetNullVarStruct)) affected, err := testEngine.Exec("insert into " + testEngine.TableName(new(TestGetNullVarStruct), true) + " (name,age) values (null,null)") @@ -595,7 +648,7 @@ func TestCustomTypes(t *testing.T) { Age MyInt } - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(TestCustomizeStruct)) var s = TestCustomizeStruct{ @@ -626,7 +679,7 @@ func TestGetViaMapCond(t *testing.T) { Index int } - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(GetViaMapCond)) var ( diff --git a/session_insert_test.go b/integrations/session_insert_test.go similarity index 71% rename from session_insert_test.go rename to integrations/session_insert_test.go index e6100fdc..47789b8a 100644 --- a/session_insert_test.go +++ b/integrations/session_insert_test.go @@ -2,20 +2,21 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package integrations import ( - "errors" "fmt" "reflect" "testing" "time" + "xorm.io/xorm" + "github.com/stretchr/testify/assert" ) func TestInsertOne(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type Test struct { Id int64 `xorm:"autoincr pk"` @@ -32,7 +33,7 @@ func TestInsertOne(t *testing.T) { func TestInsertMulti(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type TestMulti struct { Id int64 `xorm:"int(11) pk"` Name string `xorm:"varchar(255)"` @@ -107,7 +108,7 @@ func callbackLooper(datas interface{}, step int, actionFunc func(interface{}) er } func TestInsertOneIfPkIsPoint(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type TestPoint struct { Id *int64 `xorm:"autoincr pk notnull 'id'"` @@ -123,7 +124,7 @@ func TestInsertOneIfPkIsPoint(t *testing.T) { } func TestInsertOneIfPkIsPointRename(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type ID *int64 type TestPoint2 struct { Id ID `xorm:"autoincr pk notnull 'id'"` @@ -139,7 +140,7 @@ func TestInsertOneIfPkIsPointRename(t *testing.T) { } func TestInsert(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(Userinfo)) user := Userinfo{0, "xiaolunwen", "dev", "lunny", time.Now(), @@ -154,32 +155,19 @@ func TestInsert(t *testing.T) { // Username is unique, so this should return error assert.Error(t, err, "insert should fail but no error returned") assert.EqualValues(t, 0, cnt, "insert not returned 1") - if err == nil { - panic("should return err") - } } func TestInsertAutoIncr(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(Userinfo)) // auto increment insert user := Userinfo{Username: "xiaolunwen2", Departname: "dev", Alias: "lunny", Created: time.Now(), Detail: Userdetail{Id: 1}, Height: 1.78, Avatar: []byte{1, 2, 3}, IsMan: true} cnt, err := testEngine.Insert(&user) - fmt.Println(user.Uid) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("insert not returned 1") - t.Error(err) - panic(err) - } - if user.Uid <= 0 { - t.Error(errors.New("not return id error")) - } + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + assert.Greater(t, user.Uid, int64(0)) } type DefaultInsert struct { @@ -191,7 +179,7 @@ type DefaultInsert struct { } func TestInsertDefault(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) di := new(DefaultInsert) err := testEngine.Sync2(di) @@ -201,28 +189,12 @@ func TestInsertDefault(t *testing.T) { _, err = testEngine.Omit(testEngine.GetColumnMapper().Obj2Table("Status")).Insert(&di2) assert.NoError(t, err) - has, err := testEngine.Desc("(id)").Get(di) + has, err := testEngine.Desc("id").Get(di) assert.NoError(t, err) - if !has { - err = errors.New("error with no data") - t.Error(err) - panic(err) - } - if di.Status != -1 { - err = errors.New("inserted error data") - t.Error(err) - panic(err) - } - if di2.Updated.Unix() != di.Updated.Unix() { - err = errors.New("updated should equal") - t.Error(err, di.Updated, di2.Updated) - panic(err) - } - if di2.Created.Unix() != di.Created.Unix() { - err = errors.New("created should equal") - t.Error(err, di.Created, di2.Created) - panic(err) - } + assert.True(t, has) + assert.EqualValues(t, -1, di.Status) + assert.EqualValues(t, di2.Updated.Unix(), di.Updated.Unix()) + assert.EqualValues(t, di2.Created.Unix(), di.Created.Unix()) } type DefaultInsert2 struct { @@ -233,57 +205,24 @@ type DefaultInsert2 struct { } func TestInsertDefault2(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) di := new(DefaultInsert2) err := testEngine.Sync2(di) - if err != nil { - t.Error(err) - } + assert.NoError(t, err) var di2 = DefaultInsert2{Name: "test"} _, err = testEngine.Omit(testEngine.GetColumnMapper().Obj2Table("CheckTime")).Insert(&di2) - if err != nil { - t.Error(err) - } + assert.NoError(t, err) - has, err := testEngine.Desc("(id)").Get(di) - if err != nil { - t.Error(err) - } - if !has { - err = errors.New("error with no data") - t.Error(err) - panic(err) - } + has, err := testEngine.Desc("id").Get(di) + assert.NoError(t, err) + assert.True(t, has) - has, err = testEngine.NoAutoCondition().Desc("(id)").Get(&di2) - if err != nil { - t.Error(err) - } - - if !has { - err = errors.New("error with no data") - t.Error(err) - panic(err) - } - - if *di != di2 { - err = fmt.Errorf("%v is not equal to %v", di, di2) - t.Error(err) - panic(err) - } - - /*if di2.Updated.Unix() != di.Updated.Unix() { - err = errors.New("updated should equal") - t.Error(err, di.Updated, di2.Updated) - panic(err) - } - if di2.Created.Unix() != di.Created.Unix() { - err = errors.New("created should equal") - t.Error(err, di.Created, di2.Created) - panic(err) - }*/ + has, err = testEngine.NoAutoCondition().Desc("id").Get(&di2) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, *di, di2) } type CreatedInsert struct { @@ -317,147 +256,91 @@ type CreatedInsert6 struct { } func TestInsertCreated(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) di := new(CreatedInsert) err := testEngine.Sync2(di) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + ci := &CreatedInsert{} _, err = testEngine.Insert(ci) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) - has, err := testEngine.Desc("(id)").Get(di) - if err != nil { - t.Fatal(err) - } - if !has { - t.Fatal(ErrNotExist) - } - if ci.Created.Unix() != di.Created.Unix() { - t.Fatal("should equal:", ci, di) - } - fmt.Println("ci:", ci, "di:", di) + has, err := testEngine.Desc("id").Get(di) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, ci.Created.Unix(), di.Created.Unix()) di2 := new(CreatedInsert2) err = testEngine.Sync2(di2) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + ci2 := &CreatedInsert2{} _, err = testEngine.Insert(ci2) - if err != nil { - t.Fatal(err) - } - has, err = testEngine.Desc("(id)").Get(di2) - if err != nil { - t.Fatal(err) - } - if !has { - t.Fatal(ErrNotExist) - } - if ci2.Created != di2.Created { - t.Fatal("should equal:", ci2, di2) - } - fmt.Println("ci2:", ci2, "di2:", di2) + assert.NoError(t, err) + + has, err = testEngine.Desc("id").Get(di2) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, ci2.Created, di2.Created) di3 := new(CreatedInsert3) err = testEngine.Sync2(di3) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + ci3 := &CreatedInsert3{} _, err = testEngine.Insert(ci3) - if err != nil { - t.Fatal(err) - } - has, err = testEngine.Desc("(id)").Get(di3) - if err != nil { - t.Fatal(err) - } - if !has { - t.Fatal(ErrNotExist) - } - if ci3.Created != di3.Created { - t.Fatal("should equal:", ci3, di3) - } - fmt.Println("ci3:", ci3, "di3:", di3) + assert.NoError(t, err) + + has, err = testEngine.Desc("id").Get(di3) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, ci3.Created, di3.Created) di4 := new(CreatedInsert4) err = testEngine.Sync2(di4) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + ci4 := &CreatedInsert4{} _, err = testEngine.Insert(ci4) - if err != nil { - t.Fatal(err) - } - has, err = testEngine.Desc("(id)").Get(di4) - if err != nil { - t.Fatal(err) - } - if !has { - t.Fatal(ErrNotExist) - } - if ci4.Created != di4.Created { - t.Fatal("should equal:", ci4, di4) - } - fmt.Println("ci4:", ci4, "di4:", di4) + assert.NoError(t, err) + + has, err = testEngine.Desc("id").Get(di4) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, ci4.Created, di4.Created) di5 := new(CreatedInsert5) err = testEngine.Sync2(di5) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + ci5 := &CreatedInsert5{} _, err = testEngine.Insert(ci5) - if err != nil { - t.Fatal(err) - } - has, err = testEngine.Desc("(id)").Get(di5) - if err != nil { - t.Fatal(err) - } - if !has { - t.Fatal(ErrNotExist) - } - if ci5.Created.Unix() != di5.Created.Unix() { - t.Fatal("should equal:", ci5, di5) - } - fmt.Println("ci5:", ci5, "di5:", di5) + assert.NoError(t, err) + + has, err = testEngine.Desc("id").Get(di5) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, ci5.Created.Unix(), di5.Created.Unix()) di6 := new(CreatedInsert6) err = testEngine.Sync2(di6) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + oldTime := time.Now().Add(-time.Hour) ci6 := &CreatedInsert6{Created: oldTime} _, err = testEngine.Insert(ci6) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) - has, err = testEngine.Desc("(id)").Get(di6) - if err != nil { - t.Fatal(err) - } - if !has { - t.Fatal(ErrNotExist) - } - if ci6.Created.Unix() != di6.Created.Unix() { - t.Fatal("should equal:", ci6, di6) - } - fmt.Println("ci6:", ci6, "di6:", di6) + has, err = testEngine.Desc("id").Get(di6) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, ci6.Created.Unix(), di6.Created.Unix()) } -type JsonTime time.Time +type JSONTime time.Time -func (j JsonTime) format() string { +func (j JSONTime) format() string { t := time.Time(j) if t.IsZero() { return "" @@ -466,11 +349,11 @@ func (j JsonTime) format() string { return t.Format("2006-01-02") } -func (j JsonTime) MarshalText() ([]byte, error) { +func (j JSONTime) MarshalText() ([]byte, error) { return []byte(j.format()), nil } -func (j JsonTime) MarshalJSON() ([]byte, error) { +func (j JSONTime) MarshalJSON() ([]byte, error) { return []byte(`"` + j.format() + `"`), nil } @@ -478,64 +361,55 @@ func TestDefaultTime3(t *testing.T) { type PrepareTask struct { Id int `xorm:"not null pk autoincr INT(11)" json:"id"` // ... - StartTime JsonTime `xorm:"not null default '2006-01-02 15:04:05' TIMESTAMP index" json:"start_time"` - EndTime JsonTime `xorm:"not null default '2006-01-02 15:04:05' TIMESTAMP" json:"end_time"` + StartTime JSONTime `xorm:"not null default '2006-01-02 15:04:05' TIMESTAMP index" json:"start_time"` + EndTime JSONTime `xorm:"not null default '2006-01-02 15:04:05' TIMESTAMP" json:"end_time"` Cuser string `xorm:"not null default '' VARCHAR(64) index" json:"cuser"` Muser string `xorm:"not null default '' VARCHAR(64)" json:"muser"` - Ctime JsonTime `xorm:"not null default CURRENT_TIMESTAMP TIMESTAMP created" json:"ctime"` - Mtime JsonTime `xorm:"not null default CURRENT_TIMESTAMP TIMESTAMP updated" json:"mtime"` + Ctime JSONTime `xorm:"not null default CURRENT_TIMESTAMP TIMESTAMP created" json:"ctime"` + Mtime JSONTime `xorm:"not null default CURRENT_TIMESTAMP TIMESTAMP updated" json:"mtime"` } - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(PrepareTask)) prepareTask := &PrepareTask{ - StartTime: JsonTime(time.Now()), + StartTime: JSONTime(time.Now()), Cuser: "userId", Muser: "userId", } - cnt, err := testEngine.Omit("end_time").InsertOne(prepareTask) + cnt, err := testEngine.Omit("end_time").Insert(prepareTask) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) } -type MyJsonTime struct { +type MyJSONTime struct { Id int64 `json:"id"` - Created JsonTime `xorm:"created" json:"created_at"` + Created JSONTime `xorm:"created" json:"created_at"` } func TestCreatedJsonTime(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) - di5 := new(MyJsonTime) + di5 := new(MyJSONTime) err := testEngine.Sync2(di5) - if err != nil { - t.Fatal(err) - } - ci5 := &MyJsonTime{} - _, err = testEngine.Insert(ci5) - if err != nil { - t.Fatal(err) - } - has, err := testEngine.Desc("(id)").Get(di5) - if err != nil { - t.Fatal(err) - } - if !has { - t.Fatal(ErrNotExist) - } - if time.Time(ci5.Created).Unix() != time.Time(di5.Created).Unix() { - t.Fatal("should equal:", time.Time(ci5.Created).Unix(), time.Time(di5.Created).Unix()) - } - fmt.Println("ci5:", ci5, "di5:", di5) + assert.NoError(t, err) - var dis = make([]MyJsonTime, 0) + ci5 := &MyJSONTime{} + _, err = testEngine.Insert(ci5) + assert.NoError(t, err) + + has, err := testEngine.Desc("id").Get(di5) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, time.Time(ci5.Created).Unix(), time.Time(di5.Created).Unix()) + + var dis = make([]MyJSONTime, 0) err = testEngine.Find(&dis) assert.NoError(t, err) } func TestInsertMulti2(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(Userinfo)) @@ -545,6 +419,34 @@ func TestInsertMulti2(t *testing.T) { {Username: "xlw11", Departname: "dev", Alias: "lunny2", Created: time.Now()}, {Username: "xlw22", Departname: "dev", Alias: "lunny3", Created: time.Now()}, } + cnt, err := testEngine.Insert(&users) + assert.NoError(t, err) + assert.EqualValues(t, len(users), cnt) + + users2 := []*Userinfo{ + {Username: "1xlw", Departname: "dev", Alias: "lunny2", Created: time.Now()}, + {Username: "1xlw2", Departname: "dev", Alias: "lunny3", Created: time.Now()}, + {Username: "1xlw11", Departname: "dev", Alias: "lunny2", Created: time.Now()}, + {Username: "1xlw22", Departname: "dev", Alias: "lunny3", Created: time.Now()}, + } + + cnt, err = testEngine.Insert(&users2) + assert.NoError(t, err) + assert.EqualValues(t, len(users2), cnt) +} + +func TestInsertMulti2Interface(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + assertSync(t, new(Userinfo)) + + users := []interface{}{ + Userinfo{Username: "xlw", Departname: "dev", Alias: "lunny2", Created: time.Now()}, + Userinfo{Username: "xlw2", Departname: "dev", Alias: "lunny3", Created: time.Now()}, + Userinfo{Username: "xlw11", Departname: "dev", Alias: "lunny2", Created: time.Now()}, + Userinfo{Username: "xlw22", Departname: "dev", Alias: "lunny3", Created: time.Now()}, + } + cnt, err := testEngine.Insert(&users) if err != nil { t.Error(err) @@ -552,7 +454,7 @@ func TestInsertMulti2(t *testing.T) { } assert.EqualValues(t, len(users), cnt) - users2 := []*Userinfo{ + users2 := []interface{}{ &Userinfo{Username: "1xlw", Departname: "dev", Alias: "lunny2", Created: time.Now()}, &Userinfo{Username: "1xlw2", Departname: "dev", Alias: "lunny3", Created: time.Now()}, &Userinfo{Username: "1xlw11", Departname: "dev", Alias: "lunny2", Created: time.Now()}, @@ -565,7 +467,7 @@ func TestInsertMulti2(t *testing.T) { } func TestInsertTwoTable(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(Userinfo), new(Userdetail)) @@ -573,32 +475,14 @@ func TestInsertTwoTable(t *testing.T) { userinfo := Userinfo{Username: "xlw3", Departname: "dev", Alias: "lunny4", Created: time.Now(), Detail: userdetail} cnt, err := testEngine.Insert(&userinfo, &userdetail) - if err != nil { - t.Error(err) - panic(err) - } - - if userinfo.Uid <= 0 { - err = errors.New("not return id error") - t.Error(err) - panic(err) - } - - if userdetail.Id <= 0 { - err = errors.New("not return id error") - t.Error(err) - panic(err) - } - - if cnt != 2 { - err = errors.New("insert not returned 2") - t.Error(err) - panic(err) - } + assert.NoError(t, err) + assert.Greater(t, userinfo.Uid, int64(0)) + assert.Greater(t, userdetail.Id, int64(0)) + assert.EqualValues(t, 2, cnt) } func TestInsertCreatedInt64(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type TestCreatedInt64 struct { Id int64 `xorm:"autoincr pk"` @@ -630,7 +514,7 @@ func (MyUserinfo) TableName() string { } func TestInsertMulti3(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) testEngine.ShowSQL(true) assertSync(t, new(MyUserinfo)) @@ -646,10 +530,10 @@ func TestInsertMulti3(t *testing.T) { assert.EqualValues(t, len(users), cnt) users2 := []*MyUserinfo{ - &MyUserinfo{Username: "1xlw", Departname: "dev", Alias: "lunny2", Created: time.Now()}, - &MyUserinfo{Username: "1xlw2", Departname: "dev", Alias: "lunny3", Created: time.Now()}, - &MyUserinfo{Username: "1xlw11", Departname: "dev", Alias: "lunny2", Created: time.Now()}, - &MyUserinfo{Username: "1xlw22", Departname: "dev", Alias: "lunny3", Created: time.Now()}, + {Username: "1xlw", Departname: "dev", Alias: "lunny2", Created: time.Now()}, + {Username: "1xlw2", Departname: "dev", Alias: "lunny3", Created: time.Now()}, + {Username: "1xlw11", Departname: "dev", Alias: "lunny2", Created: time.Now()}, + {Username: "1xlw22", Departname: "dev", Alias: "lunny3", Created: time.Now()}, } cnt, err = testEngine.Insert(&users2) @@ -674,7 +558,7 @@ func (MyUserinfo2) TableName() string { } func TestInsertMulti4(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) testEngine.ShowSQL(false) assertSync(t, new(MyUserinfo2)) @@ -691,10 +575,10 @@ func TestInsertMulti4(t *testing.T) { assert.EqualValues(t, len(users), cnt) users2 := []*MyUserinfo2{ - &MyUserinfo2{Username: "1xlw", Departname: "dev", Alias: "lunny2", Created: time.Now()}, - &MyUserinfo2{Username: "1xlw2", Departname: "dev", Alias: "lunny3", Created: time.Now()}, - &MyUserinfo2{Username: "1xlw11", Departname: "dev", Alias: "lunny2", Created: time.Now()}, - &MyUserinfo2{Username: "1xlw22", Departname: "dev", Alias: "lunny3", Created: time.Now()}, + {Username: "1xlw", Departname: "dev", Alias: "lunny2", Created: time.Now()}, + {Username: "1xlw2", Departname: "dev", Alias: "lunny3", Created: time.Now()}, + {Username: "1xlw11", Departname: "dev", Alias: "lunny2", Created: time.Now()}, + {Username: "1xlw22", Departname: "dev", Alias: "lunny3", Created: time.Now()}, } cnt, err = testEngine.Insert(&users2) @@ -720,7 +604,7 @@ func TestAnonymousStruct(t *testing.T) { } `json:"ext" xorm:"'EXT' json notnull"` } - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(PlainFoo)) _, err := testEngine.Insert(&PlainFoo{ @@ -749,7 +633,7 @@ func TestInsertMap(t *testing.T) { Name string } - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(InsertMap)) cnt, err := testEngine.Table(new(InsertMap)).Insert(map[string]interface{}{ @@ -834,7 +718,7 @@ func TestInsertWhere(t *testing.T) { IsTrue bool } - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(InsertWhere)) var i = InsertWhere{ @@ -928,6 +812,64 @@ func TestInsertWhere(t *testing.T) { assert.EqualValues(t, 5, j5.Index) } +func TestInsertExpr2(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type InsertExprsRelease struct { + Id int64 + RepoId int + IsTag bool + IsDraft bool + NumCommits int + Sha1 string + } + + assertSync(t, new(InsertExprsRelease)) + + var ie = InsertExprsRelease{ + RepoId: 1, + IsTag: true, + } + inserted, err := testEngine. + SetExpr("is_draft", true). + SetExpr("num_commits", 0). + SetExpr("sha1", ""). + Insert(&ie) + assert.NoError(t, err) + assert.EqualValues(t, 1, inserted) + + var ie2 InsertExprsRelease + has, err := testEngine.ID(ie.Id).Get(&ie2) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, true, ie2.IsDraft) + assert.EqualValues(t, "", ie2.Sha1) + assert.EqualValues(t, 0, ie2.NumCommits) + assert.EqualValues(t, 1, ie2.RepoId) + assert.EqualValues(t, true, ie2.IsTag) + + inserted, err = testEngine.Table(new(InsertExprsRelease)). + SetExpr("is_draft", true). + SetExpr("num_commits", 0). + SetExpr("sha1", ""). + Insert(map[string]interface{}{ + "repo_id": 1, + "is_tag": true, + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, inserted) + + var ie3 InsertExprsRelease + has, err = testEngine.ID(ie.Id + 1).Get(&ie3) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, true, ie3.IsDraft) + assert.EqualValues(t, "", ie3.Sha1) + assert.EqualValues(t, 0, ie3.NumCommits) + assert.EqualValues(t, 1, ie3.RepoId) + assert.EqualValues(t, true, ie3.IsTag) +} + type NightlyRate struct { ID int64 `xorm:"'id' not null pk BIGINT(20)" json:"id"` } @@ -937,7 +879,7 @@ func (NightlyRate) TableName() string { } func TestMultipleInsertTableName(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) tableName := `prd_nightly_rate_16` assert.NoError(t, testEngine.Table(tableName).Sync2(new(NightlyRate))) @@ -968,7 +910,7 @@ func TestMultipleInsertTableName(t *testing.T) { } func TestInsertMultiWithOmit(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type TestMultiOmit struct { Id int64 `xorm:"int(11) pk"` @@ -1009,3 +951,38 @@ func TestInsertMultiWithOmit(t *testing.T) { assert.EqualValues(t, 3, num) check() } + +func TestInsertTwice(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type InsertStructA struct { + FieldA int + } + + type InsertStructB struct { + FieldB int + } + + assert.NoError(t, testEngine.Sync2(new(InsertStructA), new(InsertStructB))) + + var sliceA []InsertStructA // sliceA is empty + sliceB := []InsertStructB{ + { + FieldB: 1, + }, + } + + ssn := testEngine.NewSession() + defer ssn.Close() + + err := ssn.Begin() + assert.NoError(t, err) + + _, err = ssn.Insert(sliceA) + assert.EqualValues(t, xorm.ErrNoElementsOnSlice, err) + + _, err = ssn.Insert(sliceB) + assert.NoError(t, err) + + assert.NoError(t, ssn.Commit()) +} diff --git a/session_iterate_test.go b/integrations/session_iterate_test.go similarity index 83% rename from session_iterate_test.go rename to integrations/session_iterate_test.go index 9a7ec25f..564f457b 100644 --- a/session_iterate_test.go +++ b/integrations/session_iterate_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package integrations import ( "testing" @@ -11,7 +11,7 @@ import ( ) func TestIterate(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type UserIterate struct { Id int64 @@ -39,7 +39,7 @@ func TestIterate(t *testing.T) { } func TestBufferIterate(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type UserBufferIterate struct { Id int64 @@ -89,4 +89,15 @@ func TestBufferIterate(t *testing.T) { }) assert.NoError(t, err) assert.EqualValues(t, 7, cnt) + + cnt = 0 + err = testEngine.Where("id <= 10").BufferSize(2).Iterate(new(UserBufferIterate), func(i int, bean interface{}) error { + user := bean.(*UserBufferIterate) + assert.EqualValues(t, cnt+1, user.Id) + assert.EqualValues(t, true, user.IsMan) + cnt++ + return nil + }) + assert.NoError(t, err) + assert.EqualValues(t, 10, cnt) } diff --git a/integrations/session_pk_test.go b/integrations/session_pk_test.go new file mode 100644 index 00000000..d5f23491 --- /dev/null +++ b/integrations/session_pk_test.go @@ -0,0 +1,673 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package integrations + +import ( + "sort" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "xorm.io/xorm/schemas" +) + +type IntId struct { + Id int `xorm:"pk autoincr"` + Name string +} + +type Int16Id struct { + Id int16 `xorm:"pk autoincr"` + Name string +} + +type Int32Id struct { + Id int32 `xorm:"pk autoincr"` + Name string +} + +type UintId struct { + Id uint `xorm:"pk autoincr"` + Name string +} + +type Uint16Id struct { + Id uint16 `xorm:"pk autoincr"` + Name string +} + +type Uint32Id struct { + Id uint32 `xorm:"pk autoincr"` + Name string +} + +type Uint64Id struct { + Id uint64 `xorm:"pk autoincr"` + Name string +} + +type StringPK struct { + Id string `xorm:"pk notnull"` + Name string +} + +type ID int64 +type MyIntPK struct { + ID ID `xorm:"pk autoincr"` + Name string +} + +type StrID string +type MyStringPK struct { + ID StrID `xorm:"pk notnull"` + Name string +} + +func TestIntId(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + err := testEngine.DropTables(&IntId{}) + assert.NoError(t, err) + + err = testEngine.CreateTables(&IntId{}) + assert.NoError(t, err) + + cnt, err := testEngine.Insert(&IntId{Name: "test"}) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + bean := new(IntId) + has, err := testEngine.Get(bean) + assert.NoError(t, err) + assert.True(t, has) + + beans := make([]IntId, 0) + err = testEngine.Find(&beans) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(beans)) + + beans2 := make(map[int]IntId) + err = testEngine.Find(&beans2) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(beans2)) + + cnt, err = testEngine.ID(bean.Id).Delete(&IntId{}) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) +} + +func TestInt16Id(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + err := testEngine.DropTables(&Int16Id{}) + assert.NoError(t, err) + + err = testEngine.CreateTables(&Int16Id{}) + assert.NoError(t, err) + + cnt, err := testEngine.Insert(&Int16Id{Name: "test"}) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + bean := new(Int16Id) + has, err := testEngine.Get(bean) + assert.NoError(t, err) + assert.True(t, has) + + beans := make([]Int16Id, 0) + err = testEngine.Find(&beans) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(beans)) + + beans2 := make(map[int16]Int16Id, 0) + err = testEngine.Find(&beans2) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(beans2)) + + cnt, err = testEngine.ID(bean.Id).Delete(&Int16Id{}) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) +} + +func TestInt32Id(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + err := testEngine.DropTables(&Int32Id{}) + assert.NoError(t, err) + + err = testEngine.CreateTables(&Int32Id{}) + assert.NoError(t, err) + + cnt, err := testEngine.Insert(&Int32Id{Name: "test"}) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + bean := new(Int32Id) + has, err := testEngine.Get(bean) + assert.NoError(t, err) + assert.True(t, has) + + beans := make([]Int32Id, 0) + err = testEngine.Find(&beans) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(beans)) + + beans2 := make(map[int32]Int32Id, 0) + err = testEngine.Find(&beans2) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(beans2)) + + cnt, err = testEngine.ID(bean.Id).Delete(&Int32Id{}) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) +} + +func TestUintId(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + err := testEngine.DropTables(&UintId{}) + assert.NoError(t, err) + + err = testEngine.CreateTables(&UintId{}) + assert.NoError(t, err) + + cnt, err := testEngine.Insert(&UintId{Name: "test"}) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var inserts = []UintId{ + {Name: "test1"}, + {Name: "test2"}, + } + cnt, err = testEngine.Insert(&inserts) + assert.NoError(t, err) + assert.EqualValues(t, 2, cnt) + + bean := new(UintId) + has, err := testEngine.Get(bean) + assert.NoError(t, err) + assert.True(t, has) + + beans := make([]UintId, 0) + err = testEngine.Find(&beans) + assert.NoError(t, err) + assert.EqualValues(t, 3, len(beans)) + + beans2 := make(map[uint]UintId, 0) + err = testEngine.Find(&beans2) + assert.NoError(t, err) + assert.EqualValues(t, 3, len(beans2)) + + cnt, err = testEngine.ID(bean.Id).Delete(&UintId{}) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) +} + +func TestUint16Id(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + err := testEngine.DropTables(&Uint16Id{}) + assert.NoError(t, err) + + err = testEngine.CreateTables(&Uint16Id{}) + assert.NoError(t, err) + + cnt, err := testEngine.Insert(&Uint16Id{Name: "test"}) + assert.NoError(t, err) + + assert.EqualValues(t, 1, cnt) + + bean := new(Uint16Id) + has, err := testEngine.Get(bean) + assert.NoError(t, err) + assert.True(t, has) + + beans := make([]Uint16Id, 0) + err = testEngine.Find(&beans) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(beans)) + + beans2 := make(map[uint16]Uint16Id, 0) + err = testEngine.Find(&beans2) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(beans2)) + + cnt, err = testEngine.ID(bean.Id).Delete(&Uint16Id{}) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) +} + +func TestUint32Id(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + err := testEngine.DropTables(&Uint32Id{}) + assert.NoError(t, err) + + err = testEngine.CreateTables(&Uint32Id{}) + assert.NoError(t, err) + + cnt, err := testEngine.Insert(&Uint32Id{Name: "test"}) + assert.NoError(t, err) + + assert.EqualValues(t, 1, cnt) + + bean := new(Uint32Id) + has, err := testEngine.Get(bean) + assert.NoError(t, err) + assert.True(t, has) + + beans := make([]Uint32Id, 0) + err = testEngine.Find(&beans) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(beans)) + + beans2 := make(map[uint32]Uint32Id, 0) + err = testEngine.Find(&beans2) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(beans2)) + + cnt, err = testEngine.ID(bean.Id).Delete(&Uint32Id{}) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) +} + +func TestUint64Id(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + err := testEngine.DropTables(&Uint64Id{}) + assert.NoError(t, err) + + err = testEngine.CreateTables(&Uint64Id{}) + assert.NoError(t, err) + + idbean := &Uint64Id{Name: "test"} + cnt, err := testEngine.Insert(idbean) + assert.NoError(t, err) + + assert.EqualValues(t, 1, cnt) + + bean := new(Uint64Id) + has, err := testEngine.Get(bean) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, bean.Id, idbean.Id) + + beans := make([]Uint64Id, 0) + err = testEngine.Find(&beans) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(beans)) + assert.EqualValues(t, *bean, beans[0]) + + beans2 := make(map[uint64]Uint64Id, 0) + err = testEngine.Find(&beans2) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(beans2)) + assert.EqualValues(t, *bean, beans2[bean.Id]) + + cnt, err = testEngine.ID(bean.Id).Delete(&Uint64Id{}) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) +} + +func TestStringPK(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + err := testEngine.DropTables(&StringPK{}) + assert.NoError(t, err) + + err = testEngine.CreateTables(&StringPK{}) + assert.NoError(t, err) + + cnt, err := testEngine.Insert(&StringPK{Id: "1-1-2", Name: "test"}) + assert.NoError(t, err) + + assert.EqualValues(t, 1, cnt) + + bean := new(StringPK) + has, err := testEngine.Get(bean) + assert.NoError(t, err) + assert.True(t, has) + + beans := make([]StringPK, 0) + err = testEngine.Find(&beans) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(beans)) + + beans2 := make(map[string]StringPK) + err = testEngine.Find(&beans2) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(beans2)) + + cnt, err = testEngine.ID(bean.Id).Delete(&StringPK{}) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) +} + +type CompositeKey struct { + Id1 int64 `xorm:"id1 pk"` + Id2 int64 `xorm:"id2 pk"` + UpdateStr string +} + +func TestCompositeKey(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + err := testEngine.DropTables(&CompositeKey{}) + assert.NoError(t, err) + + err = testEngine.CreateTables(&CompositeKey{}) + assert.NoError(t, err) + + cnt, err := testEngine.Insert(&CompositeKey{11, 22, ""}) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + cnt, err = testEngine.Insert(&CompositeKey{11, 22, ""}) + assert.Error(t, err) + assert.NotEqual(t, int64(1), cnt) + + var compositeKeyVal CompositeKey + has, err := testEngine.ID(schemas.PK{11, 22}).Get(&compositeKeyVal) + assert.NoError(t, err) + assert.True(t, has) + + var compositeKeyVal2 CompositeKey + // test passing PK ptr, this test seem failed withCache + has, err = testEngine.ID(&schemas.PK{11, 22}).Get(&compositeKeyVal2) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, compositeKeyVal, compositeKeyVal2) + + var cps = make([]CompositeKey, 0) + err = testEngine.Find(&cps) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(cps)) + assert.EqualValues(t, cps[0], compositeKeyVal) + + cnt, err = testEngine.Insert(&CompositeKey{22, 22, ""}) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + cps = make([]CompositeKey, 0) + err = testEngine.Find(&cps) + assert.NoError(t, err) + assert.EqualValues(t, 2, len(cps), "should has two record") + assert.EqualValues(t, compositeKeyVal, cps[0], "should be equeal") + + compositeKeyVal = CompositeKey{UpdateStr: "test1"} + cnt, err = testEngine.ID(schemas.PK{11, 22}).Update(&compositeKeyVal) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + cnt, err = testEngine.ID(schemas.PK{11, 22}).Delete(&CompositeKey{}) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) +} + +func TestCompositeKey2(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type User struct { + UserId string `xorm:"varchar(19) not null pk"` + NickName string `xorm:"varchar(19) not null"` + GameId uint32 `xorm:"integer pk"` + Score int32 `xorm:"integer"` + } + + err := testEngine.DropTables(&User{}) + assert.NoError(t, err) + + err = testEngine.CreateTables(&User{}) + assert.NoError(t, err) + + cnt, err := testEngine.Insert(&User{"11", "nick", 22, 5}) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + cnt, err = testEngine.Insert(&User{"11", "nick", 22, 6}) + assert.Error(t, err) + assert.NotEqual(t, 1, cnt) + + var user User + has, err := testEngine.ID(schemas.PK{"11", 22}).Get(&user) + assert.NoError(t, err) + assert.True(t, has) + + // test passing PK ptr, this test seem failed withCache + has, err = testEngine.ID(&schemas.PK{"11", 22}).Get(&user) + assert.NoError(t, err) + assert.True(t, has) + + user = User{NickName: "test1"} + cnt, err = testEngine.ID(schemas.PK{"11", 22}).Update(&user) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + cnt, err = testEngine.ID(schemas.PK{"11", 22}).Delete(&User{}) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) +} + +type MyString string +type UserPK2 struct { + UserId MyString `xorm:"varchar(19) not null pk"` + NickName string `xorm:"varchar(19) not null"` + GameId uint32 `xorm:"integer pk"` + Score int32 `xorm:"integer"` +} + +func TestCompositeKey3(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + err := testEngine.DropTables(&UserPK2{}) + + assert.NoError(t, err) + + err = testEngine.CreateTables(&UserPK2{}) + assert.NoError(t, err) + + cnt, err := testEngine.Insert(&UserPK2{"11", "nick", 22, 5}) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + cnt, err = testEngine.Insert(&UserPK2{"11", "nick", 22, 6}) + assert.Error(t, err) + assert.NotEqual(t, 1, cnt) + + var user UserPK2 + has, err := testEngine.ID(schemas.PK{"11", 22}).Get(&user) + assert.NoError(t, err) + assert.True(t, has) + + // test passing PK ptr, this test seem failed withCache + has, err = testEngine.ID(&schemas.PK{"11", 22}).Get(&user) + assert.NoError(t, err) + assert.True(t, has) + + user = UserPK2{NickName: "test1"} + cnt, err = testEngine.ID(schemas.PK{"11", 22}).Update(&user) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + cnt, err = testEngine.ID(schemas.PK{"11", 22}).Delete(&UserPK2{}) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) +} + +func TestMyIntId(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + err := testEngine.DropTables(&MyIntPK{}) + assert.NoError(t, err) + + err = testEngine.CreateTables(&MyIntPK{}) + assert.NoError(t, err) + + idbean := &MyIntPK{Name: "test"} + cnt, err := testEngine.Insert(idbean) + assert.NoError(t, err) + + assert.EqualValues(t, 1, cnt) + + bean := new(MyIntPK) + has, err := testEngine.Get(bean) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, bean.ID, idbean.ID) + + var beans []MyIntPK + err = testEngine.Find(&beans) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(beans)) + assert.EqualValues(t, *bean, beans[0]) + + beans2 := make(map[ID]MyIntPK, 0) + err = testEngine.Find(&beans2) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(beans2)) + assert.EqualValues(t, *bean, beans2[bean.ID]) + + cnt, err = testEngine.ID(bean.ID).Delete(&MyIntPK{}) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) +} + +func TestMyStringId(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + err := testEngine.DropTables(&MyStringPK{}) + assert.NoError(t, err) + + err = testEngine.CreateTables(&MyStringPK{}) + assert.NoError(t, err) + + idbean := &MyStringPK{ID: "1111", Name: "test"} + cnt, err := testEngine.Insert(idbean) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + bean := new(MyStringPK) + has, err := testEngine.Get(bean) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, bean.ID, idbean.ID) + + var beans []MyStringPK + err = testEngine.Find(&beans) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(beans)) + assert.EqualValues(t, *bean, beans[0]) + + beans2 := make(map[StrID]MyStringPK, 0) + err = testEngine.Find(&beans2) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(beans2)) + assert.EqualValues(t, *bean, beans2[bean.ID]) + + cnt, err = testEngine.ID(bean.ID).Delete(&MyStringPK{}) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) +} + +func TestSingleAutoIncrColumn(t *testing.T) { + type Account struct { + Id int64 `xorm:"pk autoincr"` + } + + assert.NoError(t, PrepareEngine()) + assertSync(t, new(Account)) + + _, err := testEngine.Insert(&Account{}) + assert.NoError(t, err) +} + +func TestCompositePK(t *testing.T) { + type TaskSolution struct { + UID string `xorm:"notnull pk UUID 'uid'"` + TID string `xorm:"notnull pk UUID 'tid'"` + Created time.Time `xorm:"created"` + Updated time.Time `xorm:"updated"` + } + + assert.NoError(t, PrepareEngine()) + + tables1, err := testEngine.DBMetas() + assert.NoError(t, err) + + assertSync(t, new(TaskSolution)) + assert.NoError(t, testEngine.Sync2(new(TaskSolution))) + + tables2, err := testEngine.DBMetas() + assert.NoError(t, err) + assert.EqualValues(t, 1+len(tables1), len(tables2)) + + var table *schemas.Table + for _, t := range tables2 { + if t.Name == testEngine.GetTableMapper().Obj2Table("TaskSolution") { + table = t + break + } + } + + assert.NotEqual(t, nil, table) + + pkCols := table.PKColumns() + assert.EqualValues(t, 2, len(pkCols)) + + names := []string{pkCols[0].Name, pkCols[1].Name} + sort.Strings(names) + assert.EqualValues(t, []string{"tid", "uid"}, names) +} + +func TestNoPKIdQueryUpdate(t *testing.T) { + type NoPKTable struct { + Username string + } + + assert.NoError(t, PrepareEngine()) + assertSync(t, new(NoPKTable)) + + cnt, err := testEngine.Insert(&NoPKTable{ + Username: "test", + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var res NoPKTable + has, err := testEngine.ID("test").Get(&res) + assert.Error(t, err) + assert.False(t, has) + + cnt, err = testEngine.ID("test").Update(&NoPKTable{ + Username: "test1", + }) + assert.Error(t, err) + assert.EqualValues(t, 0, cnt) + + type UnvalidPKTable struct { + ID int `xorm:"id"` + Username string + } + + assertSync(t, new(UnvalidPKTable)) + + cnt, err = testEngine.Insert(&UnvalidPKTable{ + ID: 1, + Username: "test", + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var res2 UnvalidPKTable + has, err = testEngine.ID(1).Get(&res2) + assert.Error(t, err) + assert.False(t, has) + + cnt, err = testEngine.ID(1).Update(&UnvalidPKTable{ + Username: "test1", + }) + assert.Error(t, err) + assert.EqualValues(t, 0, cnt) +} diff --git a/session_query_test.go b/integrations/session_query_test.go similarity index 87% rename from session_query_test.go rename to integrations/session_query_test.go index 772206a8..30f2e6ab 100644 --- a/session_query_test.go +++ b/integrations/session_query_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package integrations import ( "fmt" @@ -11,13 +11,13 @@ import ( "time" "xorm.io/builder" - "xorm.io/core" + "xorm.io/xorm/schemas" "github.com/stretchr/testify/assert" ) func TestQueryString(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type GetVar2 struct { Id int64 `xorm:"autoincr pk"` @@ -48,7 +48,7 @@ func TestQueryString(t *testing.T) { } func TestQueryString2(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type GetVar3 struct { Id int64 `xorm:"autoincr pk"` @@ -108,7 +108,7 @@ func toFloat64(i interface{}) float64 { } func TestQueryInterface(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type GetVarInterface struct { Id int64 `xorm:"autoincr pk"` @@ -139,7 +139,7 @@ func TestQueryInterface(t *testing.T) { } func TestQueryNoParams(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type QueryNoParams struct { Id int64 `xorm:"autoincr pk"` @@ -188,7 +188,7 @@ func TestQueryNoParams(t *testing.T) { } func TestQueryStringNoParam(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type GetVar4 struct { Id int64 `xorm:"autoincr pk"` @@ -207,7 +207,7 @@ func TestQueryStringNoParam(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 1, len(records)) assert.EqualValues(t, "1", records[0]["id"]) - if testEngine.Dialect().DBType() == core.POSTGRES || testEngine.Dialect().DBType() == core.MSSQL { + if testEngine.Dialect().URI().DBType == schemas.POSTGRES || testEngine.Dialect().URI().DBType == schemas.MSSQL { assert.EqualValues(t, "false", records[0]["msg"]) } else { assert.EqualValues(t, "0", records[0]["msg"]) @@ -217,7 +217,7 @@ func TestQueryStringNoParam(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 1, len(records)) assert.EqualValues(t, "1", records[0]["id"]) - if testEngine.Dialect().DBType() == core.POSTGRES || testEngine.Dialect().DBType() == core.MSSQL { + if testEngine.Dialect().URI().DBType == schemas.POSTGRES || testEngine.Dialect().URI().DBType == schemas.MSSQL { assert.EqualValues(t, "false", records[0]["msg"]) } else { assert.EqualValues(t, "0", records[0]["msg"]) @@ -225,7 +225,7 @@ func TestQueryStringNoParam(t *testing.T) { } func TestQuerySliceStringNoParam(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type GetVar6 struct { Id int64 `xorm:"autoincr pk"` @@ -244,7 +244,7 @@ func TestQuerySliceStringNoParam(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 1, len(records)) assert.EqualValues(t, "1", records[0][0]) - if testEngine.Dialect().DBType() == core.POSTGRES || testEngine.Dialect().DBType() == core.MSSQL { + if testEngine.Dialect().URI().DBType == schemas.POSTGRES || testEngine.Dialect().URI().DBType == schemas.MSSQL { assert.EqualValues(t, "false", records[0][1]) } else { assert.EqualValues(t, "0", records[0][1]) @@ -254,7 +254,7 @@ func TestQuerySliceStringNoParam(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 1, len(records)) assert.EqualValues(t, "1", records[0][0]) - if testEngine.Dialect().DBType() == core.POSTGRES || testEngine.Dialect().DBType() == core.MSSQL { + if testEngine.Dialect().URI().DBType == schemas.POSTGRES || testEngine.Dialect().URI().DBType == schemas.MSSQL { assert.EqualValues(t, "false", records[0][1]) } else { assert.EqualValues(t, "0", records[0][1]) @@ -262,7 +262,7 @@ func TestQuerySliceStringNoParam(t *testing.T) { } func TestQueryInterfaceNoParam(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type GetVar5 struct { Id int64 `xorm:"autoincr pk"` @@ -291,7 +291,7 @@ func TestQueryInterfaceNoParam(t *testing.T) { } func TestQueryWithBuilder(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type QueryWithBuilder struct { Id int64 `xorm:"autoincr pk"` @@ -336,7 +336,7 @@ func TestQueryWithBuilder(t *testing.T) { } func TestJoinWithSubQuery(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type JoinWithSubQuery1 struct { Id int64 `xorm:"autoincr pk"` @@ -371,10 +371,18 @@ func TestJoinWithSubQuery(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 1, cnt) + tbName := testEngine.Quote(testEngine.TableName("join_with_sub_query_depart", true)) var querys []JoinWithSubQuery1 - err = testEngine.Join("INNER", builder.Select("id").From(testEngine.Quote(testEngine.TableName("join_with_sub_query_depart", true))), + err = testEngine.Join("INNER", builder.Select("id").From(tbName), "join_with_sub_query_depart.id = join_with_sub_query1.depart_id").Find(&querys) assert.NoError(t, err) assert.EqualValues(t, 1, len(querys)) assert.EqualValues(t, q, querys[0]) + + querys = make([]JoinWithSubQuery1, 0, 1) + err = testEngine.Join("INNER", "(SELECT id FROM "+tbName+") join_with_sub_query_depart", "join_with_sub_query_depart.id = join_with_sub_query1.depart_id"). + Find(&querys) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(querys)) + assert.EqualValues(t, q, querys[0]) } diff --git a/session_raw_test.go b/integrations/session_raw_test.go similarity index 94% rename from session_raw_test.go rename to integrations/session_raw_test.go index 766206a4..8b9d6766 100644 --- a/session_raw_test.go +++ b/integrations/session_raw_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package integrations import ( "strconv" @@ -12,7 +12,7 @@ import ( ) func TestExecAndQuery(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type UserinfoQuery struct { Uid int diff --git a/session_schema_test.go b/integrations/session_schema_test.go similarity index 90% rename from session_schema_test.go rename to integrations/session_schema_test.go index 141f4d5b..c17d9a1d 100644 --- a/session_schema_test.go +++ b/integrations/session_schema_test.go @@ -2,11 +2,10 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package integrations import ( "fmt" - "os" "testing" "time" @@ -14,7 +13,7 @@ import ( ) func TestStoreEngine(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assert.NoError(t, testEngine.DropTables("user_store_engine")) @@ -27,7 +26,7 @@ func TestStoreEngine(t *testing.T) { } func TestCreateTable(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assert.NoError(t, testEngine.DropTables("user_user")) @@ -40,7 +39,7 @@ func TestCreateTable(t *testing.T) { } func TestCreateMultiTables(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) session := testEngine.NewSession() defer session.Close() @@ -95,7 +94,7 @@ func (s *SyncTable3) TableName() string { } func TestSyncTable(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assert.NoError(t, testEngine.Sync2(new(SyncTable1))) @@ -120,7 +119,7 @@ func TestSyncTable(t *testing.T) { } func TestSyncTable2(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assert.NoError(t, testEngine.Table("sync_tablex").Sync2(new(SyncTable1))) @@ -145,7 +144,7 @@ func TestSyncTable2(t *testing.T) { } func TestIsTableExist(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) exist, err := testEngine.IsTableExist(new(CustomTableName)) assert.NoError(t, err) @@ -159,7 +158,7 @@ func TestIsTableExist(t *testing.T) { } func TestIsTableEmpty(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type NumericEmpty struct { Numeric float64 `xorm:"numeric(26,2)"` @@ -202,7 +201,7 @@ func (c *CustomTableName) TableName() string { } func TestCustomTableName(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) c := new(CustomTableName) assert.NoError(t, testEngine.DropTables(c)) @@ -210,14 +209,6 @@ func TestCustomTableName(t *testing.T) { assert.NoError(t, testEngine.CreateTables(c)) } -func TestDump(t *testing.T) { - assert.NoError(t, prepareEngine()) - - fp := testEngine.Dialect().URI().DbName + ".sql" - os.Remove(fp) - assert.NoError(t, testEngine.DumpAllToFile(fp)) -} - type IndexOrUnique struct { Id int64 Index int `xorm:"index"` @@ -229,7 +220,7 @@ type IndexOrUnique struct { } func TestIndexAndUnique(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assert.NoError(t, testEngine.CreateTables(&IndexOrUnique{})) @@ -245,7 +236,7 @@ func TestIndexAndUnique(t *testing.T) { } func TestMetaInfo(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assert.NoError(t, testEngine.Sync2(new(CustomTableName), new(IndexOrUnique))) tables, err := testEngine.DBMetas() @@ -257,19 +248,13 @@ func TestMetaInfo(t *testing.T) { } func TestCharst(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) err := testEngine.DropTables("user_charset") - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) err = testEngine.Charset("utf8").Table("user_charset").CreateTable(&Userinfo{}) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) } func TestSync2_1(t *testing.T) { @@ -279,7 +264,7 @@ func TestSync2_1(t *testing.T) { Id_delete int8 `xorm:"null int default 1"` } - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assert.NoError(t, testEngine.DropTables("wx_test")) assert.NoError(t, testEngine.Sync2(new(WxTest))) @@ -296,7 +281,7 @@ func TestUnique_1(t *testing.T) { UpdatedAt time.Time `xorm:"updated"` } - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assert.NoError(t, testEngine.DropTables("user_unique")) assert.NoError(t, testEngine.Sync2(new(UserUnique))) @@ -312,7 +297,7 @@ func TestSync2_2(t *testing.T) { UserId int64 `xorm:"index"` } - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) var tableNames = make(map[string]bool) for i := 0; i < 10; i++ { @@ -341,7 +326,7 @@ func TestSync2_Default(t *testing.T) { Name string `xorm:"default('my_name')"` } - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(TestSync2Default)) assert.NoError(t, testEngine.Sync2(new(TestSync2Default))) } diff --git a/session_stats_test.go b/integrations/session_stats_test.go similarity index 87% rename from session_stats_test.go rename to integrations/session_stats_test.go index 01c76ba5..47a64076 100644 --- a/session_stats_test.go +++ b/integrations/session_stats_test.go @@ -2,15 +2,15 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package integrations import ( "fmt" "strconv" "testing" - "xorm.io/builder" "github.com/stretchr/testify/assert" + "xorm.io/builder" ) func isFloatEq(i, j float64, precision int) bool { @@ -23,7 +23,7 @@ func TestSum(t *testing.T) { Float float32 } - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assert.NoError(t, testEngine.Sync2(new(SumStruct))) var ( @@ -82,7 +82,7 @@ func (s SumStructWithTableName) TableName() string { } func TestSumWithTableName(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assert.NoError(t, testEngine.Sync2(new(SumStructWithTableName))) var ( @@ -132,7 +132,7 @@ func TestSumWithTableName(t *testing.T) { } func TestSumCustomColumn(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type SumStruct2 struct { Int int @@ -160,7 +160,7 @@ func TestSumCustomColumn(t *testing.T) { } func TestCount(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type UserinfoCount struct { Departname string @@ -196,7 +196,7 @@ func TestCount(t *testing.T) { } func TestSQLCount(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type UserinfoCount2 struct { Id int64 @@ -218,7 +218,7 @@ func TestSQLCount(t *testing.T) { } func TestCountWithOthers(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type CountWithOthers struct { Id int64 @@ -252,7 +252,7 @@ func (CountWithTableName) TableName() string { } func TestWithTableName(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(CountWithTableName)) @@ -274,3 +274,27 @@ func TestWithTableName(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 2, total) } + +func TestCountWithSelectCols(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + assertSync(t, new(CountWithTableName)) + + _, err := testEngine.Insert(&CountWithTableName{ + Name: "orderby", + }) + assert.NoError(t, err) + + _, err = testEngine.Insert(CountWithTableName{ + Name: "limit", + }) + assert.NoError(t, err) + + total, err := testEngine.Cols("id").Count(new(CountWithTableName)) + assert.NoError(t, err) + assert.EqualValues(t, 2, total) + + total, err = testEngine.Select("count(id)").Count(CountWithTableName{}) + assert.NoError(t, err) + assert.EqualValues(t, 2, total) +} diff --git a/session_test.go b/integrations/session_test.go similarity index 70% rename from session_test.go rename to integrations/session_test.go index 343f9baa..bdf3278d 100644 --- a/session_test.go +++ b/integrations/session_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package integrations import ( "database/sql" @@ -12,7 +12,7 @@ import ( ) func TestClose(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) sess1 := testEngine.NewSession() sess1.Close() @@ -31,7 +31,7 @@ func TestNullFloatStruct(t *testing.T) { Amount MyNullFloat64 } - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assert.NoError(t, testEngine.Sync2(new(MyNullFloatStruct))) _, err := testEngine.Insert(&MyNullFloatStruct{ @@ -43,3 +43,14 @@ func TestNullFloatStruct(t *testing.T) { }) assert.NoError(t, err) } + +func TestMustLogSQL(t *testing.T) { + assert.NoError(t, PrepareEngine()) + testEngine.ShowSQL(false) + defer testEngine.ShowSQL(true) + + assertSync(t, new(Userinfo)) + + _, err := testEngine.Table("userinfo").MustLogSQL(true).Get(new(Userinfo)) + assert.NoError(t, err) +} diff --git a/session_tx_test.go b/integrations/session_tx_test.go similarity index 84% rename from session_tx_test.go rename to integrations/session_tx_test.go index 23e1bf28..4cff5610 100644 --- a/session_tx_test.go +++ b/integrations/session_tx_test.go @@ -2,30 +2,28 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package integrations import ( "fmt" "testing" "time" - "xorm.io/core" "github.com/stretchr/testify/assert" + "xorm.io/xorm/internal/utils" + "xorm.io/xorm/names" ) func TestTransaction(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(Userinfo)) - counter := func() { - total, err := testEngine.Count(&Userinfo{}) - if err != nil { - t.Error(err) - } - fmt.Printf("----now total %v records\n", total) + counter := func(t *testing.T) { + _, err := testEngine.Count(&Userinfo{}) + assert.NoError(t, err) } - counter() + counter(t) //defer counter() session := testEngine.NewSession() @@ -39,7 +37,7 @@ func TestTransaction(t *testing.T) { assert.NoError(t, err) user2 := Userinfo{Username: "yyy"} - _, err = session.Where("(id) = ?", 0).Update(&user2) + _, err = session.Where("id = ?", 0).Update(&user2) assert.NoError(t, err) _, err = session.Delete(&user2) @@ -50,14 +48,12 @@ func TestTransaction(t *testing.T) { } func TestCombineTransaction(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(Userinfo)) counter := func() { total, err := testEngine.Count(&Userinfo{}) - if err != nil { - t.Error(err) - } + assert.NoError(t, err) fmt.Printf("----now total %v records\n", total) } @@ -85,13 +81,13 @@ func TestCombineTransaction(t *testing.T) { } func TestCombineTransactionSameMapper(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) oldMapper := testEngine.GetColumnMapper() - testEngine.UnMapType(rValue(new(Userinfo)).Type()) - testEngine.SetMapper(core.SameMapper{}) + testEngine.UnMapType(utils.ReflectValue(new(Userinfo)).Type()) + testEngine.SetMapper(names.SameMapper{}) defer func() { - testEngine.UnMapType(rValue(new(Userinfo)).Type()) + testEngine.UnMapType(utils.ReflectValue(new(Userinfo)).Type()) testEngine.SetMapper(oldMapper) }() @@ -99,9 +95,7 @@ func TestCombineTransactionSameMapper(t *testing.T) { counter := func() { total, err := testEngine.Count(&Userinfo{}) - if err != nil { - t.Error(err) - } + assert.NoError(t, err) fmt.Printf("----now total %v records\n", total) } @@ -119,7 +113,7 @@ func TestCombineTransactionSameMapper(t *testing.T) { assert.NoError(t, err) user2 := Userinfo{Username: "zzz"} - _, err = session.Where("(id) = ?", 0).Update(&user2) + _, err = session.Where("id = ?", 0).Update(&user2) assert.NoError(t, err) _, err = session.Exec("delete from "+testEngine.TableName("`Userinfo`", true)+" where `Username` = ?", user2.Username) @@ -130,7 +124,7 @@ func TestCombineTransactionSameMapper(t *testing.T) { } func TestMultipleTransaction(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type MultipleTransaction struct { Id int64 diff --git a/session_update_test.go b/integrations/session_update_test.go similarity index 74% rename from session_update_test.go rename to integrations/session_update_test.go index 386a68d1..1bc1f32a 100644 --- a/session_update_test.go +++ b/integrations/session_update_test.go @@ -2,21 +2,23 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package integrations import ( - "errors" "fmt" "sync" "testing" "time" "github.com/stretchr/testify/assert" - "xorm.io/core" + "xorm.io/xorm" + "xorm.io/xorm/internal/statements" + "xorm.io/xorm/internal/utils" + "xorm.io/xorm/names" ) func TestUpdateMap(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type UpdateTable struct { Id int64 @@ -38,10 +40,23 @@ func TestUpdateMap(t *testing.T) { }) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) + + cnt, err = testEngine.Table("update_table").ID(tb.Id).Update(map[string]interface{}{ + "name": "test2", + "age": 36, + }) + assert.Error(t, err) + assert.True(t, statements.IsIDConditionWithNoTableErr(err)) + assert.EqualValues(t, 0, cnt) } func TestUpdateLimit(t *testing.T) { - assert.NoError(t, prepareEngine()) + if *ingoreUpdateLimit { + t.Skip() + return + } + + assert.NoError(t, PrepareEngine()) type UpdateTable2 struct { Id int64 @@ -83,7 +98,7 @@ type ForUpdate struct { Name string } -func setupForUpdate(engine EngineInterface) error { +func setupForUpdate(engine xorm.EngineInterface) error { v := new(ForUpdate) err := testEngine.DropTables(v) if err != nil { @@ -137,7 +152,7 @@ func TestForUpdate(t *testing.T) { // use lock fList := make([]ForUpdate, 0) session1.ForUpdate() - session1.Where("(id) = ?", 1) + session1.Where("id = ?", 1) err = session1.Find(&fList) switch { case err != nil: @@ -158,7 +173,7 @@ func TestForUpdate(t *testing.T) { wg.Add(1) go func() { f2 := new(ForUpdate) - session2.Where("(id) = ?", 1).ForUpdate() + session2.Where("id = ?", 1).ForUpdate() has, err := session2.Get(f2) // wait release lock switch { case err != nil: @@ -175,7 +190,7 @@ func TestForUpdate(t *testing.T) { wg.Add(1) go func() { f3 := new(ForUpdate) - session3.Where("(id) = ?", 1) + session3.Where("id = ?", 1) has, err := session3.Get(f3) // wait release lock switch { case err != nil: @@ -193,7 +208,7 @@ func TestForUpdate(t *testing.T) { f := new(ForUpdate) f.Name = "updated by session1" - session1.Where("(id) = ?", 1) + session1.Where("id = ?", 1) session1.Update(f) // release lock @@ -213,7 +228,7 @@ func TestWithIn(t *testing.T) { Test bool `xorm:"Test"` } - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assert.NoError(t, testEngine.Sync(new(temp3))) testEngine.Insert(&[]temp3{ @@ -265,20 +280,17 @@ type Article struct { } func TestUpdateMap2(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(UpdateMustCols)) _, err := testEngine.Table("update_must_cols").Where("id =?", 1).Update(map[string]interface{}{ "bool": true, }) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) } func TestUpdate1(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(Userinfo)) _, err := testEngine.Insert(&Userinfo{ @@ -287,14 +299,8 @@ func TestUpdate1(t *testing.T) { var ori Userinfo has, err := testEngine.Get(&ori) - if err != nil { - t.Error(err) - panic(err) - } - if !has { - t.Error(errors.New("not exist")) - panic(errors.New("not exist")) - } + assert.NoError(t, err) + assert.True(t, has) // update by id user := Userinfo{Username: "xxx", Height: 1.2} @@ -318,10 +324,7 @@ func TestUpdate1(t *testing.T) { { user := &Userinfo{Username: "not null data", Height: 180.5} _, err := testEngine.Insert(user) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) userID := user.Uid has, err := testEngine.ID(userID). @@ -331,29 +334,15 @@ func TestUpdate1(t *testing.T) { And("detail_id = ?", 0). And("is_man = ?", 0). Get(&Userinfo{}) - if err != nil { - t.Error(err) - panic(err) - } - if !has { - err = errors.New("cannot insert properly") - t.Error(err) - panic(err) - } + assert.NoError(t, err) + assert.True(t, has, "cannot insert properly") updatedUser := &Userinfo{Username: "null data"} cnt, err = testEngine.ID(userID). Nullable("height", "departname", "is_man", "created"). Update(updatedUser) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("update not returned 1") - t.Error(err) - panic(err) - } + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt, "update not returned 1") has, err = testEngine.ID(userID). And("username = ?", updatedUser.Username). @@ -363,60 +352,27 @@ func TestUpdate1(t *testing.T) { And("created IS NULL"). And("detail_id = ?", 0). Get(&Userinfo{}) - if err != nil { - t.Error(err) - panic(err) - } - if !has { - err = errors.New("cannot update with null properly") - t.Error(err) - panic(err) - } + assert.NoError(t, err) + assert.True(t, has, "cannot update with null properly") cnt, err = testEngine.ID(userID).Delete(&Userinfo{}) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("delete not returned 1") - t.Error(err) - panic(err) - } + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt, "delete not returned 1") } err = testEngine.StoreEngine("Innodb").Sync2(&Article{}) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) defer func() { err = testEngine.DropTables(&Article{}) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) }() a := &Article{0, "1", "2", "3", "4", "5", 2} cnt, err = testEngine.Insert(a) - if err != nil { - t.Error(err) - panic(err) - } - - if cnt != 1 { - err = errors.New(fmt.Sprintf("insert not returned 1 but %d", cnt)) - t.Error(err) - panic(err) - } - - if a.Id == 0 { - err = errors.New("insert returned id is 0") - t.Error(err) - panic(err) - } + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt, fmt.Sprintf("insert not returned 1 but %d", cnt)) + assert.Greater(t, a.Id, int32(0), "insert returned id is 0") cnt, err = testEngine.ID(a.Id).Update(&Article{Name: "6"}) assert.NoError(t, err) @@ -442,28 +398,18 @@ func TestUpdate1(t *testing.T) { assert.EqualValues(t, *col2, *col3) { - col1 := &UpdateMustCols{} err = testEngine.Sync(col1) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) _, err = testEngine.Insert(col1) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) col2 := &UpdateMustCols{col1.Id, true, ""} boolStr := testEngine.GetColumnMapper().Obj2Table("Bool") stringStr := testEngine.GetColumnMapper().Obj2Table("String") _, err = testEngine.ID(col2.Id).MustCols(boolStr, stringStr).Update(col2) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) col3 := &UpdateMustCols{} has, err := testEngine.ID(col2.Id).Get(col3) @@ -474,7 +420,7 @@ func TestUpdate1(t *testing.T) { } func TestUpdateIncrDecr(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) col1 := &UpdateIncr{ Name: "test", @@ -537,36 +483,23 @@ type UpdatedUpdate5 struct { } func TestUpdateUpdated(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) di := new(UpdatedUpdate) err := testEngine.Sync2(di) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) _, err = testEngine.Insert(&UpdatedUpdate{}) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) ci := &UpdatedUpdate{} _, err = testEngine.ID(1).Update(ci) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) has, err := testEngine.ID(1).Get(di) - if err != nil { - t.Fatal(err) - } - if !has { - t.Fatal(ErrNotExist) - } - if ci.Updated.Unix() != di.Updated.Unix() { - t.Fatal("should equal:", ci, di) - } - fmt.Println("ci:", ci, "di:", di) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, ci.Updated.Unix(), di.Updated.Unix()) di2 := new(UpdatedUpdate2) err = testEngine.Sync2(di2) @@ -597,108 +530,71 @@ func TestUpdateUpdated(t *testing.T) { di3 := new(UpdatedUpdate3) err = testEngine.Sync2(di3) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) _, err = testEngine.Insert(&UpdatedUpdate3{}) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + ci3 := &UpdatedUpdate3{} _, err = testEngine.ID(1).Update(ci3) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) has, err = testEngine.ID(1).Get(di3) - if err != nil { - t.Fatal(err) - } - if !has { - t.Fatal(ErrNotExist) - } - if ci3.Updated != di3.Updated { - t.Fatal("should equal:", ci3, di3) - } - fmt.Println("ci3:", ci3, "di3:", di3) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, ci3.Updated, di3.Updated) di4 := new(UpdatedUpdate4) err = testEngine.Sync2(di4) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) _, err = testEngine.Insert(&UpdatedUpdate4{}) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) ci4 := &UpdatedUpdate4{} _, err = testEngine.ID(1).Update(ci4) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) has, err = testEngine.ID(1).Get(di4) - if err != nil { - t.Fatal(err) - } - if !has { - t.Fatal(ErrNotExist) - } - if ci4.Updated != di4.Updated { - t.Fatal("should equal:", ci4, di4) - } - fmt.Println("ci4:", ci4, "di4:", di4) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, ci4.Updated, di4.Updated) di5 := new(UpdatedUpdate5) err = testEngine.Sync2(di5) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) _, err = testEngine.Insert(&UpdatedUpdate5{}) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + ci5 := &UpdatedUpdate5{} _, err = testEngine.ID(1).Update(ci5) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) has, err = testEngine.ID(1).Get(di5) - if err != nil { - t.Fatal(err) - } - if !has { - t.Fatal(ErrNotExist) - } - if ci5.Updated.Unix() != di5.Updated.Unix() { - t.Fatal("should equal:", ci5, di5) - } - fmt.Println("ci5:", ci5, "di5:", di5) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, ci5.Updated.Unix(), di5.Updated.Unix()) } func TestUpdateSameMapper(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) oldMapper := testEngine.GetTableMapper() - testEngine.UnMapType(rValue(new(Userinfo)).Type()) - testEngine.UnMapType(rValue(new(Condi)).Type()) - testEngine.UnMapType(rValue(new(Article)).Type()) - testEngine.UnMapType(rValue(new(UpdateAllCols)).Type()) - testEngine.UnMapType(rValue(new(UpdateMustCols)).Type()) - testEngine.UnMapType(rValue(new(UpdateIncr)).Type()) - testEngine.SetMapper(core.SameMapper{}) + testEngine.UnMapType(utils.ReflectValue(new(Userinfo)).Type()) + testEngine.UnMapType(utils.ReflectValue(new(Condi)).Type()) + testEngine.UnMapType(utils.ReflectValue(new(Article)).Type()) + testEngine.UnMapType(utils.ReflectValue(new(UpdateAllCols)).Type()) + testEngine.UnMapType(utils.ReflectValue(new(UpdateMustCols)).Type()) + testEngine.UnMapType(utils.ReflectValue(new(UpdateIncr)).Type()) + testEngine.SetMapper(names.SameMapper{}) defer func() { - testEngine.UnMapType(rValue(new(Userinfo)).Type()) - testEngine.UnMapType(rValue(new(Condi)).Type()) - testEngine.UnMapType(rValue(new(Article)).Type()) - testEngine.UnMapType(rValue(new(UpdateAllCols)).Type()) - testEngine.UnMapType(rValue(new(UpdateMustCols)).Type()) - testEngine.UnMapType(rValue(new(UpdateIncr)).Type()) + testEngine.UnMapType(utils.ReflectValue(new(Userinfo)).Type()) + testEngine.UnMapType(utils.ReflectValue(new(Condi)).Type()) + testEngine.UnMapType(utils.ReflectValue(new(Article)).Type()) + testEngine.UnMapType(utils.ReflectValue(new(UpdateAllCols)).Type()) + testEngine.UnMapType(utils.ReflectValue(new(UpdateMustCols)).Type()) + testEngine.UnMapType(utils.ReflectValue(new(UpdateIncr)).Type()) testEngine.SetMapper(oldMapper) }() @@ -743,18 +639,8 @@ func TestUpdateSameMapper(t *testing.T) { a := &Article{0, "1", "2", "3", "4", "5", 2} cnt, err = testEngine.Insert(a) assert.NoError(t, err) - - if cnt != 1 { - err = errors.New(fmt.Sprintf("insert not returned 1 but %d", cnt)) - t.Error(err) - panic(err) - } - - if a.Id == 0 { - err = errors.New("insert returned id is 0") - t.Error(err) - panic(err) - } + assert.EqualValues(t, 1, cnt) + assert.Greater(t, a.Id, int32(0)) cnt, err = testEngine.ID(a.Id).Update(&Article{Name: "6"}) assert.NoError(t, err) @@ -801,75 +687,42 @@ func TestUpdateSameMapper(t *testing.T) { { col1 := &UpdateIncr{} err = testEngine.Sync(col1) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) _, err = testEngine.Insert(col1) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) cnt, err := testEngine.ID(col1.Id).Incr("`Cnt`").Update(col1) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("update incr failed") - t.Error(err) - panic(err) - } + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) newCol := new(UpdateIncr) has, err := testEngine.ID(col1.Id).Get(newCol) - if err != nil { - t.Error(err) - panic(err) - } - if !has { - err = errors.New("has incr failed") - t.Error(err) - panic(err) - } - if 1 != newCol.Cnt { - err = errors.New("incr failed") - t.Error(err) - panic(err) - } + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, 1, newCol.Cnt) } } func TestUseBool(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(Userinfo)) cnt1, err := testEngine.Count(&Userinfo{}) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) users := make([]Userinfo, 0) err = testEngine.Find(&users) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) var fNumber int64 for _, u := range users { if u.IsMan == false { - fNumber += 1 + fNumber++ } } cnt2, err := testEngine.UseBool().Update(&Userinfo{IsMan: true}) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) if fNumber != cnt2 { fmt.Println("cnt1", cnt1, "fNumber", fNumber, "cnt2", cnt2) /*err = errors.New("Updated number is not corrected.") @@ -878,58 +731,34 @@ func TestUseBool(t *testing.T) { } _, err = testEngine.Update(&Userinfo{IsMan: true}) - if err == nil { - err = errors.New("error condition") - t.Error(err) - panic(err) - } + assert.Error(t, err) } func TestBool(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(Userinfo)) _, err := testEngine.UseBool().Update(&Userinfo{IsMan: true}) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) users := make([]Userinfo, 0) err = testEngine.Find(&users) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) for _, user := range users { - if !user.IsMan { - err = errors.New("update bool or find bool error") - t.Error(err) - panic(err) - } + assert.True(t, user.IsMan) } _, err = testEngine.UseBool().Update(&Userinfo{IsMan: false}) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) users = make([]Userinfo, 0) err = testEngine.Find(&users) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) for _, user := range users { - if user.IsMan { - err = errors.New("update bool or find bool error") - t.Error(err) - panic(err) - } + assert.True(t, user.IsMan) } } func TestNoUpdate(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type NoUpdate struct { Id int64 @@ -950,7 +779,7 @@ func TestNoUpdate(t *testing.T) { } func TestNewUpdate(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type TbUserInfo struct { Id int64 `xorm:"pk autoincr unique BIGINT" json:"id"` @@ -980,7 +809,7 @@ func TestNewUpdate(t *testing.T) { } func TestUpdateUpdate(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type PublicKeyUpdate struct { Id int64 @@ -997,7 +826,7 @@ func TestUpdateUpdate(t *testing.T) { } func TestCreatedUpdated2(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type CreatedUpdatedStruct struct { Id int64 @@ -1041,7 +870,7 @@ func TestCreatedUpdated2(t *testing.T) { } func TestDeletedUpdate(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type DeletedUpdatedStruct struct { Id int64 @@ -1089,7 +918,7 @@ func TestDeletedUpdate(t *testing.T) { } func TestUpdateMapCondition(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type UpdateMapCondition struct { Id int64 @@ -1120,7 +949,7 @@ func TestUpdateMapCondition(t *testing.T) { } func TestUpdateMapContent(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type UpdateMapContent struct { Id int64 @@ -1168,7 +997,7 @@ func TestUpdateMapContent(t *testing.T) { assert.EqualValues(t, false, c2.IsMan) assert.EqualValues(t, 2, c2.Gender) - cnt, err = testEngine.Table(testEngine.TableName(new(UpdateMapContent))).ID(c.Id).Update(map[string]interface{}{ + cnt, err = testEngine.Table(new(UpdateMapContent)).ID(c.Id).Update(map[string]interface{}{ "age": 15, "is_man": true, "gender": 1, @@ -1195,7 +1024,7 @@ func TestUpdateCondiBean(t *testing.T) { Name string } - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(NeedUpdateBean)) cnt, err := testEngine.Insert(&NeedUpdateBean{ @@ -1245,7 +1074,7 @@ func TestWhereCondErrorWhenUpdate(t *testing.T) { RequestToken string } - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(AuthRequestError)) _, err := testEngine.Cols("challenge_token", "request_token", "challenge_agent", "status"). @@ -1254,11 +1083,11 @@ func TestWhereCondErrorWhenUpdate(t *testing.T) { ChallengeToken: "2", }) assert.Error(t, err) - assert.EqualValues(t, ErrConditionType, err) + assert.EqualValues(t, xorm.ErrConditionType, err) } func TestUpdateDeleted(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type UpdateDeletedStruct struct { Id int64 @@ -1299,7 +1128,7 @@ func TestUpdateDeleted(t *testing.T) { } func TestUpdateExprs(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type UpdateExprs struct { Id int64 @@ -1330,7 +1159,7 @@ func TestUpdateExprs(t *testing.T) { } func TestUpdateAlias(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type UpdateAlias struct { Id int64 @@ -1359,3 +1188,165 @@ func TestUpdateAlias(t *testing.T) { assert.EqualValues(t, 2, ue.NumIssues) assert.EqualValues(t, "lunny xiao", ue.Name) } + +func TestUpdateExprs2(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type UpdateExprsRelease struct { + Id int64 + RepoId int + IsTag bool + IsDraft bool + NumCommits int + Sha1 string + } + + assertSync(t, new(UpdateExprsRelease)) + + var uer = UpdateExprsRelease{ + RepoId: 1, + IsTag: false, + IsDraft: false, + NumCommits: 1, + Sha1: "sha1", + } + inserted, err := testEngine.Insert(&uer) + assert.NoError(t, err) + assert.EqualValues(t, 1, inserted) + + updated, err := testEngine. + Where("repo_id = ? AND is_tag = ?", 1, false). + SetExpr("is_draft", true). + SetExpr("num_commits", 0). + SetExpr("sha1", ""). + Update(new(UpdateExprsRelease)) + assert.NoError(t, err) + assert.EqualValues(t, 1, updated) + + var uer2 UpdateExprsRelease + has, err := testEngine.ID(uer.Id).Get(&uer2) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, 1, uer2.RepoId) + assert.EqualValues(t, false, uer2.IsTag) + assert.EqualValues(t, true, uer2.IsDraft) + assert.EqualValues(t, 0, uer2.NumCommits) + assert.EqualValues(t, "", uer2.Sha1) +} + +func TestUpdateMap3(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type UpdateMapUser struct { + Id uint64 `xorm:"PK autoincr"` + Name string `xorm:""` + Ver uint64 `xorm:"version"` + } + + oldMapper := testEngine.GetColumnMapper() + defer func() { + testEngine.SetColumnMapper(oldMapper) + }() + + mapper := names.NewPrefixMapper(names.SnakeMapper{}, "F") + testEngine.SetColumnMapper(mapper) + + assertSync(t, new(UpdateMapUser)) + + _, err := testEngine.Table(new(UpdateMapUser)).Insert(map[string]interface{}{ + "Fname": "first user name", + "Fver": 1, + }) + assert.NoError(t, err) + + update := map[string]interface{}{ + "Fname": "user name", + "Fver": 1, + } + rows, err := testEngine.Table(new(UpdateMapUser)).ID(1).Update(update) + assert.NoError(t, err) + assert.EqualValues(t, 1, rows) + + update = map[string]interface{}{ + "Name": "user name", + "Ver": 1, + } + rows, err = testEngine.Table(new(UpdateMapUser)).ID(1).Update(update) + assert.Error(t, err) + assert.EqualValues(t, 0, rows) +} + +func TestUpdateIgnoreOnlyFromDBFields(t *testing.T) { + type TestOnlyFromDBField struct { + Id int64 `xorm:"PK"` + OnlyFromDBField string `xorm:"<-"` + OnlyToDBField string `xorm:"->"` + IngoreField string `xorm:"-"` + } + + assertGetRecord := func() *TestOnlyFromDBField { + var record TestOnlyFromDBField + has, err := testEngine.Where("id = ?", 1).Get(&record) + assert.NoError(t, err) + assert.EqualValues(t, true, has) + assert.EqualValues(t, "", record.OnlyFromDBField) + return &record + + } + assert.NoError(t, PrepareEngine()) + assertSync(t, new(TestOnlyFromDBField)) + + _, err := testEngine.Insert(&TestOnlyFromDBField{ + Id: 1, + OnlyFromDBField: "a", + OnlyToDBField: "b", + IngoreField: "c", + }) + assert.NoError(t, err) + + assertGetRecord() + + _, err = testEngine.ID(1).Update(&TestOnlyFromDBField{ + OnlyToDBField: "b", + OnlyFromDBField: "test", + }) + assert.NoError(t, err) + assertGetRecord() +} + +func TestUpdateMultiplePK(t *testing.T) { + type TestUpdateMultiplePKStruct struct { + Id string `xorm:"notnull pk" description:"唯一ID号"` + Name string `xorm:"notnull pk" description:"名称"` + Value string `xorm:"notnull varchar(4000)" description:"值"` + } + + assert.NoError(t, PrepareEngine()) + assertSync(t, new(TestUpdateMultiplePKStruct)) + + test := &TestUpdateMultiplePKStruct{ + Id: "ID1", + Name: "Name1", + Value: "1", + } + _, err := testEngine.Insert(test) + assert.NoError(t, err) + + test.Value = "2" + _, err = testEngine.Where("`id` = ? And `name` = ?", test.Id, test.Name).Cols("Value").Update(test) + assert.NoError(t, err) + + test.Value = "3" + num, err := testEngine.Where("`id` = ? And `name` = ?", test.Id, test.Name).Update(test) + assert.NoError(t, err) + assert.EqualValues(t, 1, num) + + test.Value = "4" + _, err = testEngine.ID([]interface{}{test.Id, test.Name}).Update(test) + assert.NoError(t, err) + + type MySlice []interface{} + test.Value = "5" + _, err = testEngine.ID(&MySlice{test.Id, test.Name}).Update(test) + assert.NoError(t, err) +} diff --git a/integrations/tags_test.go b/integrations/tags_test.go new file mode 100644 index 00000000..f787fffe --- /dev/null +++ b/integrations/tags_test.go @@ -0,0 +1,1329 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package integrations + +import ( + "fmt" + "sort" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "xorm.io/xorm/internal/utils" + "xorm.io/xorm/names" + "xorm.io/xorm/schemas" +) + +type tempUser struct { + Id int64 + Username string +} + +type tempUser2 struct { + TempUser tempUser `xorm:"extends"` + Departname string +} + +type tempUser3 struct { + Temp *tempUser `xorm:"extends"` + Departname string +} + +type tempUser4 struct { + TempUser2 tempUser2 `xorm:"extends"` +} + +type Userinfo struct { + Uid int64 `xorm:"id pk not null autoincr"` + Username string `xorm:"unique"` + Departname string + Alias string `xorm:"-"` + Created time.Time + Detail Userdetail `xorm:"detail_id int(11)"` + Height float64 + Avatar []byte + IsMan bool +} + +type Userdetail struct { + Id int64 + Intro string `xorm:"text"` + Profile string `xorm:"varchar(2000)"` +} + +type UserAndDetail struct { + Userinfo `xorm:"extends"` + Userdetail `xorm:"extends"` +} + +func TestExtends(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + err := testEngine.DropTables(&tempUser2{}) + assert.NoError(t, err) + + err = testEngine.CreateTables(&tempUser2{}) + assert.NoError(t, err) + + tu := &tempUser2{tempUser{0, "extends"}, "dev depart"} + _, err = testEngine.Insert(tu) + assert.NoError(t, err) + + tu2 := &tempUser2{} + _, err = testEngine.Get(tu2) + assert.NoError(t, err) + + tu3 := &tempUser2{tempUser{0, "extends update"}, ""} + _, err = testEngine.ID(tu2.TempUser.Id).Update(tu3) + assert.NoError(t, err) + + err = testEngine.DropTables(&tempUser4{}) + assert.NoError(t, err) + + err = testEngine.CreateTables(&tempUser4{}) + assert.NoError(t, err) + + tu8 := &tempUser4{tempUser2{tempUser{0, "extends"}, "dev depart"}} + _, err = testEngine.Insert(tu8) + assert.NoError(t, err) + + tu9 := &tempUser4{} + _, err = testEngine.Get(tu9) + assert.NoError(t, err) + assert.EqualValues(t, tu8.TempUser2.TempUser.Username, tu9.TempUser2.TempUser.Username) + assert.EqualValues(t, tu8.TempUser2.Departname, tu9.TempUser2.Departname) + + tu10 := &tempUser4{tempUser2{tempUser{0, "extends update"}, ""}} + _, err = testEngine.ID(tu9.TempUser2.TempUser.Id).Update(tu10) + assert.NoError(t, err) + + err = testEngine.DropTables(&tempUser3{}) + assert.NoError(t, err) + + err = testEngine.CreateTables(&tempUser3{}) + assert.NoError(t, err) + + tu4 := &tempUser3{&tempUser{0, "extends"}, "dev depart"} + _, err = testEngine.Insert(tu4) + assert.NoError(t, err) + + tu5 := &tempUser3{} + _, err = testEngine.Get(tu5) + assert.NoError(t, err) + + assert.NotNil(t, tu5.Temp) + assert.EqualValues(t, 1, tu5.Temp.Id) + assert.EqualValues(t, "extends", tu5.Temp.Username) + assert.EqualValues(t, "dev depart", tu5.Departname) + + tu6 := &tempUser3{&tempUser{0, "extends update"}, ""} + _, err = testEngine.ID(tu5.Temp.Id).Update(tu6) + assert.NoError(t, err) + + users := make([]tempUser3, 0) + err = testEngine.Find(&users) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(users), "error get data not 1") + + assertSync(t, new(Userinfo), new(Userdetail)) + + detail := Userdetail{ + Intro: "I'm in China", + } + _, err = testEngine.Insert(&detail) + assert.NoError(t, err) + + _, err = testEngine.Insert(&Userinfo{ + Username: "lunny", + Detail: detail, + }) + assert.NoError(t, err) + + var info UserAndDetail + qt := testEngine.Quote + ui := testEngine.TableName(new(Userinfo), true) + ud := testEngine.TableName(&detail, true) + uiid := testEngine.GetColumnMapper().Obj2Table("Id") + udid := "detail_id" + sql := fmt.Sprintf("select * from %s, %s where %s.%s = %s.%s", + qt(ui), qt(ud), qt(ui), qt(udid), qt(ud), qt(uiid)) + b, err := testEngine.SQL(sql).NoCascade().Get(&info) + assert.NoError(t, err) + assert.True(t, b, "should has lest one record") + assert.True(t, info.Userinfo.Uid > 0, "all of the id should has value") + assert.True(t, info.Userdetail.Id > 0, "all of the id should has value") + + var info2 UserAndDetail + b, err = testEngine.Table(&Userinfo{}). + Join("LEFT", qt(ud), qt(ui)+"."+qt("detail_id")+" = "+qt(ud)+"."+qt(uiid)). + NoCascade().Get(&info2) + assert.NoError(t, err) + assert.True(t, b) + assert.True(t, info2.Userinfo.Uid > 0, "all of the id should has value") + assert.True(t, info2.Userdetail.Id > 0, "all of the id should has value") + + var infos2 = make([]UserAndDetail, 0) + err = testEngine.Table(&Userinfo{}). + Join("LEFT", qt(ud), qt(ui)+"."+qt("detail_id")+" = "+qt(ud)+"."+qt(uiid)). + NoCascade(). + Find(&infos2) + assert.NoError(t, err) +} + +type MessageBase struct { + Id int64 `xorm:"int(11) pk autoincr"` + TypeId int64 `xorm:"int(11) notnull"` +} + +type Message struct { + MessageBase `xorm:"extends"` + Title string `xorm:"varchar(100) notnull"` + Content string `xorm:"text notnull"` + Uid int64 `xorm:"int(11) notnull"` + ToUid int64 `xorm:"int(11) notnull"` + CreateTime time.Time `xorm:"datetime notnull created"` +} + +type MessageUser struct { + Id int64 + Name string +} + +type MessageType struct { + Id int64 + Name string +} + +type MessageExtend3 struct { + Message `xorm:"extends"` + Sender MessageUser `xorm:"extends"` + Receiver MessageUser `xorm:"extends"` + Type MessageType `xorm:"extends"` +} + +type MessageExtend4 struct { + Message `xorm:"extends"` + MessageUser `xorm:"extends"` + MessageType `xorm:"extends"` +} + +func TestExtends2(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + err := testEngine.DropTables(&Message{}, &MessageUser{}, &MessageType{}) + assert.NoError(t, err) + + err = testEngine.CreateTables(&Message{}, &MessageUser{}, &MessageType{}) + assert.NoError(t, err) + + var sender = MessageUser{Name: "sender"} + var receiver = MessageUser{Name: "receiver"} + var msgtype = MessageType{Name: "type"} + _, err = testEngine.Insert(&sender, &receiver, &msgtype) + assert.NoError(t, err) + + msg := Message{ + MessageBase: MessageBase{ + Id: msgtype.Id, + }, + Title: "test", + Content: "test", + Uid: sender.Id, + ToUid: receiver.Id, + } + + session := testEngine.NewSession() + defer session.Close() + + // MSSQL deny insert identity column excep declare as below + if testEngine.Dialect().URI().DBType == schemas.MSSQL { + err = session.Begin() + assert.NoError(t, err) + _, err = session.Exec("SET IDENTITY_INSERT message ON") + assert.NoError(t, err) + } + cnt, err := session.Insert(&msg) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + if testEngine.Dialect().URI().DBType == schemas.MSSQL { + err = session.Commit() + assert.NoError(t, err) + } + + var mapper = testEngine.GetTableMapper().Obj2Table + var quote = testEngine.Quote + userTableName := quote(testEngine.TableName(mapper("MessageUser"), true)) + typeTableName := quote(testEngine.TableName(mapper("MessageType"), true)) + msgTableName := quote(testEngine.TableName(mapper("Message"), true)) + + list := make([]Message, 0) + err = session.Table(msgTableName).Join("LEFT", []string{userTableName, "sender"}, "`sender`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Uid")+"`"). + Join("LEFT", []string{userTableName, "receiver"}, "`receiver`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("ToUid")+"`"). + Join("LEFT", []string{typeTableName, "type"}, "`type`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Id")+"`"). + Find(&list) + assert.NoError(t, err) + + assert.EqualValues(t, 1, len(list), fmt.Sprintln("should have 1 message, got", len(list))) + assert.EqualValues(t, msg.Id, list[0].Id, fmt.Sprintln("should message equal", list[0], msg)) +} + +func TestExtends3(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + err := testEngine.DropTables(&Message{}, &MessageUser{}, &MessageType{}) + assert.NoError(t, err) + + err = testEngine.CreateTables(&Message{}, &MessageUser{}, &MessageType{}) + assert.NoError(t, err) + + var sender = MessageUser{Name: "sender"} + var receiver = MessageUser{Name: "receiver"} + var msgtype = MessageType{Name: "type"} + _, err = testEngine.Insert(&sender, &receiver, &msgtype) + assert.NoError(t, err) + + msg := Message{ + MessageBase: MessageBase{ + Id: msgtype.Id, + }, + Title: "test", + Content: "test", + Uid: sender.Id, + ToUid: receiver.Id, + } + + session := testEngine.NewSession() + defer session.Close() + + // MSSQL deny insert identity column excep declare as below + if testEngine.Dialect().URI().DBType == schemas.MSSQL { + err = session.Begin() + assert.NoError(t, err) + _, err = session.Exec("SET IDENTITY_INSERT message ON") + assert.NoError(t, err) + } + _, err = session.Insert(&msg) + assert.NoError(t, err) + + if testEngine.Dialect().URI().DBType == schemas.MSSQL { + err = session.Commit() + assert.NoError(t, err) + } + + var mapper = testEngine.GetTableMapper().Obj2Table + var quote = testEngine.Quote + userTableName := quote(testEngine.TableName(mapper("MessageUser"), true)) + typeTableName := quote(testEngine.TableName(mapper("MessageType"), true)) + msgTableName := quote(testEngine.TableName(mapper("Message"), true)) + + list := make([]MessageExtend3, 0) + err = session.Table(msgTableName).Join("LEFT", []string{userTableName, "sender"}, "`sender`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Uid")+"`"). + Join("LEFT", []string{userTableName, "receiver"}, "`receiver`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("ToUid")+"`"). + Join("LEFT", []string{typeTableName, "type"}, "`type`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Id")+"`"). + Find(&list) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(list)) + assert.EqualValues(t, list[0].Message.Id, msg.Id) + assert.EqualValues(t, list[0].Sender.Id, sender.Id) + assert.EqualValues(t, list[0].Sender.Name, sender.Name) + assert.EqualValues(t, list[0].Receiver.Id, receiver.Id) + assert.EqualValues(t, list[0].Receiver.Name, receiver.Name) + assert.EqualValues(t, list[0].Type.Id, msgtype.Id) + assert.EqualValues(t, list[0].Type.Name, msgtype.Name) +} + +func TestExtends4(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + err := testEngine.DropTables(&Message{}, &MessageUser{}, &MessageType{}) + assert.NoError(t, err) + + err = testEngine.CreateTables(&Message{}, &MessageUser{}, &MessageType{}) + assert.NoError(t, err) + + var sender = MessageUser{Name: "sender"} + var msgtype = MessageType{Name: "type"} + _, err = testEngine.Insert(&sender, &msgtype) + assert.NoError(t, err) + + msg := Message{ + MessageBase: MessageBase{ + Id: msgtype.Id, + }, + Title: "test", + Content: "test", + Uid: sender.Id, + } + + session := testEngine.NewSession() + defer session.Close() + + // MSSQL deny insert identity column excep declare as below + if testEngine.Dialect().URI().DBType == schemas.MSSQL { + err = session.Begin() + assert.NoError(t, err) + _, err = session.Exec("SET IDENTITY_INSERT message ON") + assert.NoError(t, err) + } + _, err = session.Insert(&msg) + assert.NoError(t, err) + + if testEngine.Dialect().URI().DBType == schemas.MSSQL { + err = session.Commit() + assert.NoError(t, err) + } + + var mapper = testEngine.GetTableMapper().Obj2Table + var quote = testEngine.Quote + userTableName := quote(testEngine.TableName(mapper("MessageUser"), true)) + typeTableName := quote(testEngine.TableName(mapper("MessageType"), true)) + msgTableName := quote(testEngine.TableName(mapper("Message"), true)) + + list := make([]MessageExtend4, 0) + err = session.Table(msgTableName).Join("LEFT", userTableName, userTableName+".`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Uid")+"`"). + Join("LEFT", typeTableName, typeTableName+".`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Id")+"`"). + Find(&list) + assert.NoError(t, err) + assert.EqualValues(t, len(list), 1) + assert.EqualValues(t, list[0].Message.Id, msg.Id) + assert.EqualValues(t, list[0].MessageUser.Id, sender.Id) + assert.EqualValues(t, list[0].MessageUser.Name, sender.Name) + assert.EqualValues(t, list[0].MessageType.Id, msgtype.Id) + assert.EqualValues(t, list[0].MessageType.Name, msgtype.Name) +} + +type Size struct { + ID int64 `xorm:"int(4) 'id' pk autoincr"` + Width float32 `json:"width" xorm:"float 'Width'"` + Height float32 `json:"height" xorm:"float 'Height'"` +} + +type Book struct { + ID int64 `xorm:"int(4) 'id' pk autoincr"` + SizeOpen *Size `xorm:"extends('Open')"` + SizeClosed *Size `xorm:"extends('Closed')"` + Size *Size `xorm:"extends('')"` +} + +func TestExtends5(t *testing.T) { + assert.NoError(t, PrepareEngine()) + err := testEngine.DropTables(&Book{}, &Size{}) + assert.NoError(t, err) + + err = testEngine.CreateTables(&Size{}, &Book{}) + assert.NoError(t, err) + + var sc = Size{Width: 0.2, Height: 0.4} + var so = Size{Width: 0.2, Height: 0.8} + var s = Size{Width: 0.15, Height: 1.5} + var bk1 = Book{ + SizeOpen: &so, + SizeClosed: &sc, + Size: &s, + } + var bk2 = Book{ + SizeOpen: &so, + } + var bk3 = Book{ + SizeClosed: &sc, + Size: &s, + } + var bk4 = Book{} + var bk5 = Book{Size: &s} + _, err = testEngine.Insert(&sc, &so, &s, &bk1, &bk2, &bk3, &bk4, &bk5) + if err != nil { + t.Fatal(err) + } + + var books = map[int64]Book{ + bk1.ID: bk1, + bk2.ID: bk2, + bk3.ID: bk3, + bk4.ID: bk4, + bk5.ID: bk5, + } + + session := testEngine.NewSession() + defer session.Close() + + var mapper = testEngine.GetTableMapper().Obj2Table + var quote = testEngine.Quote + bookTableName := quote(testEngine.TableName(mapper("Book"), true)) + sizeTableName := quote(testEngine.TableName(mapper("Size"), true)) + + list := make([]Book, 0) + err = session. + Select(fmt.Sprintf( + "%s.%s, sc.%s AS %s, sc.%s AS %s, s.%s, s.%s", + quote(bookTableName), + quote("id"), + quote("Width"), + quote("ClosedWidth"), + quote("Height"), + quote("ClosedHeight"), + quote("Width"), + quote("Height"), + )). + Table(bookTableName). + Join( + "LEFT", + sizeTableName+" AS `sc`", + bookTableName+".`SizeClosed`=sc.`id`", + ). + Join( + "LEFT", + sizeTableName+" AS `s`", + bookTableName+".`Size`=s.`id`", + ). + Find(&list) + assert.NoError(t, err) + + for _, book := range list { + if ok := assert.Equal(t, books[book.ID].SizeClosed.Width, book.SizeClosed.Width); !ok { + t.Error("Not bounded size closed") + panic("Not bounded size closed") + } + + if ok := assert.Equal(t, books[book.ID].SizeClosed.Height, book.SizeClosed.Height); !ok { + t.Error("Not bounded size closed") + panic("Not bounded size closed") + } + + if books[book.ID].Size != nil || book.Size != nil { + if ok := assert.Equal(t, books[book.ID].Size.Width, book.Size.Width); !ok { + t.Error("Not bounded size") + panic("Not bounded size") + } + + if ok := assert.Equal(t, books[book.ID].Size.Height, book.Size.Height); !ok { + t.Error("Not bounded size") + panic("Not bounded size") + } + } + } +} + +func TestCacheTag(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type CacheDomain struct { + Id int64 `xorm:"pk cache"` + Name string + } + + assert.NoError(t, testEngine.CreateTables(&CacheDomain{})) + assert.True(t, testEngine.GetCacher(testEngine.TableName(&CacheDomain{})) != nil) +} + +func TestNoCacheTag(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type NoCacheDomain struct { + Id int64 `xorm:"pk nocache"` + Name string + } + + assert.NoError(t, testEngine.CreateTables(&NoCacheDomain{})) + assert.True(t, testEngine.GetCacher(testEngine.TableName(&NoCacheDomain{})) == nil) +} + +type IDGonicMapper struct { + ID int64 +} + +func TestGonicMapperID(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + oldMapper := testEngine.GetColumnMapper() + testEngine.UnMapType(utils.ReflectValue(new(IDGonicMapper)).Type()) + testEngine.SetMapper(names.LintGonicMapper) + defer func() { + testEngine.UnMapType(utils.ReflectValue(new(IDGonicMapper)).Type()) + testEngine.SetMapper(oldMapper) + }() + + err := testEngine.CreateTables(new(IDGonicMapper)) + if err != nil { + t.Fatal(err) + } + + tables, err := testEngine.DBMetas() + if err != nil { + t.Fatal(err) + } + + for _, tb := range tables { + if tb.Name == "id_gonic_mapper" { + if len(tb.PKColumns()) != 1 || tb.PKColumns()[0].Name != "id" { + t.Fatal(tb) + } + return + } + } + + t.Fatal("not table id_gonic_mapper") +} + +type IDSameMapper struct { + ID int64 +} + +func TestSameMapperID(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + oldMapper := testEngine.GetColumnMapper() + testEngine.UnMapType(utils.ReflectValue(new(IDSameMapper)).Type()) + testEngine.SetMapper(names.SameMapper{}) + defer func() { + testEngine.UnMapType(utils.ReflectValue(new(IDSameMapper)).Type()) + testEngine.SetMapper(oldMapper) + }() + + err := testEngine.CreateTables(new(IDSameMapper)) + if err != nil { + t.Fatal(err) + } + + tables, err := testEngine.DBMetas() + if err != nil { + t.Fatal(err) + } + + for _, tb := range tables { + if tb.Name == "IDSameMapper" { + if len(tb.PKColumns()) != 1 || tb.PKColumns()[0].Name != "ID" { + t.Fatalf("tb %s tb.PKColumns() is %d not 1, tb.PKColumns()[0].Name is %s not ID", tb.Name, len(tb.PKColumns()), tb.PKColumns()[0].Name) + } + return + } + } + t.Fatal("not table IDSameMapper") +} + +type UserCU struct { + Id int64 + Name string + Created time.Time `xorm:"created"` + Updated time.Time `xorm:"updated"` +} + +func TestCreatedAndUpdated(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + u := new(UserCU) + err := testEngine.DropTables(u) + assert.NoError(t, err) + + err = testEngine.CreateTables(u) + assert.NoError(t, err) + + u.Name = "sss" + cnt, err := testEngine.Insert(u) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + u.Name = "xxx" + cnt, err = testEngine.ID(u.Id).Update(u) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + u.Id = 0 + u.Created = time.Now().Add(-time.Hour * 24 * 365) + u.Updated = u.Created + cnt, err = testEngine.NoAutoTime().Insert(u) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) +} + +type StrangeName struct { + Id_t int64 `xorm:"pk autoincr"` + Name string +} + +func TestStrangeName(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + err := testEngine.DropTables(new(StrangeName)) + assert.NoError(t, err) + + err = testEngine.CreateTables(new(StrangeName)) + assert.NoError(t, err) + + _, err = testEngine.Insert(&StrangeName{Name: "sfsfdsfds"}) + assert.NoError(t, err) + + beans := make([]StrangeName, 0) + err = testEngine.Find(&beans) + assert.NoError(t, err) +} + +func TestCreatedUpdated(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type CreatedUpdated struct { + Id int64 + Name string + Value float64 `xorm:"numeric"` + Created time.Time `xorm:"created"` + Created2 time.Time `xorm:"created"` + Updated time.Time `xorm:"updated"` + } + + err := testEngine.Sync2(&CreatedUpdated{}) + assert.NoError(t, err) + + c := &CreatedUpdated{Name: "test"} + _, err = testEngine.Insert(c) + assert.NoError(t, err) + + c2 := new(CreatedUpdated) + has, err := testEngine.ID(c.Id).Get(c2) + assert.NoError(t, err) + + assert.True(t, has) + + c2.Value-- + _, err = testEngine.ID(c2.Id).Update(c2) + assert.NoError(t, err) +} + +func TestCreatedUpdatedInt64(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type CreatedUpdatedInt64 struct { + Id int64 + Name string + Value float64 `xorm:"numeric"` + Created int64 `xorm:"created"` + Created2 int64 `xorm:"created"` + Updated int64 `xorm:"updated"` + } + + assertSync(t, &CreatedUpdatedInt64{}) + + c := &CreatedUpdatedInt64{Name: "test"} + _, err := testEngine.Insert(c) + assert.NoError(t, err) + + c2 := new(CreatedUpdatedInt64) + has, err := testEngine.ID(c.Id).Get(c2) + assert.NoError(t, err) + assert.True(t, has) + + c2.Value-- + _, err = testEngine.ID(c2.Id).Update(c2) + assert.NoError(t, err) +} + +type Lowercase struct { + Id int64 + Name string + ended int64 `xorm:"-"` +} + +func TestLowerCase(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + err := testEngine.Sync2(&Lowercase{}) + assert.NoError(t, err) + _, err = testEngine.Where("id > 0").Delete(&Lowercase{}) + assert.NoError(t, err) + + _, err = testEngine.Insert(&Lowercase{ended: 1}) + assert.NoError(t, err) + + ls := make([]Lowercase, 0) + err = testEngine.Find(&ls) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(ls)) +} + +func TestAutoIncrTag(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type TestAutoIncr1 struct { + Id int64 + } + + tb, err := testEngine.TableInfo(new(TestAutoIncr1)) + assert.NoError(t, err) + + cols := tb.Columns() + assert.EqualValues(t, 1, len(cols)) + assert.True(t, cols[0].IsAutoIncrement) + assert.True(t, cols[0].IsPrimaryKey) + assert.Equal(t, "id", cols[0].Name) + + type TestAutoIncr2 struct { + Id int64 `xorm:"id"` + } + + tb, err = testEngine.TableInfo(new(TestAutoIncr2)) + assert.NoError(t, err) + + cols = tb.Columns() + assert.EqualValues(t, 1, len(cols)) + assert.False(t, cols[0].IsAutoIncrement) + assert.False(t, cols[0].IsPrimaryKey) + assert.Equal(t, "id", cols[0].Name) + + type TestAutoIncr3 struct { + Id int64 `xorm:"'ID'"` + } + + tb, err = testEngine.TableInfo(new(TestAutoIncr3)) + assert.NoError(t, err) + + cols = tb.Columns() + assert.EqualValues(t, 1, len(cols)) + assert.False(t, cols[0].IsAutoIncrement) + assert.False(t, cols[0].IsPrimaryKey) + assert.Equal(t, "ID", cols[0].Name) + + type TestAutoIncr4 struct { + Id int64 `xorm:"pk"` + } + + tb, err = testEngine.TableInfo(new(TestAutoIncr4)) + assert.NoError(t, err) + + cols = tb.Columns() + assert.EqualValues(t, 1, len(cols)) + assert.False(t, cols[0].IsAutoIncrement) + assert.True(t, cols[0].IsPrimaryKey) + assert.Equal(t, "id", cols[0].Name) +} + +func TestTagComment(t *testing.T) { + assert.NoError(t, PrepareEngine()) + // FIXME: only support mysql + if testEngine.Dialect().URI().DBType != schemas.MYSQL { + return + } + + type TestComment1 struct { + Id int64 `xorm:"comment(主键)"` + } + + assert.NoError(t, testEngine.Sync2(new(TestComment1))) + + tables, err := testEngine.DBMetas() + assert.NoError(t, err) + assert.EqualValues(t, 1, len(tables)) + assert.EqualValues(t, 1, len(tables[0].Columns())) + assert.EqualValues(t, "主键", tables[0].Columns()[0].Comment) + + assert.NoError(t, testEngine.DropTables(new(TestComment1))) + + type TestComment2 struct { + Id int64 `xorm:"comment('主键')"` + } + + assert.NoError(t, testEngine.Sync2(new(TestComment2))) + + tables, err = testEngine.DBMetas() + assert.NoError(t, err) + assert.EqualValues(t, 1, len(tables)) + assert.EqualValues(t, 1, len(tables[0].Columns())) + assert.EqualValues(t, "主键", tables[0].Columns()[0].Comment) +} + +func TestTagDefault(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type DefaultStruct struct { + Id int64 + Name string + Age int `xorm:"default(10)"` + } + + assertSync(t, new(DefaultStruct)) + + tables, err := testEngine.DBMetas() + assert.NoError(t, err) + + var defaultVal string + var isDefaultExist bool + tableName := testEngine.GetColumnMapper().Obj2Table("DefaultStruct") + for _, table := range tables { + if table.Name == tableName { + col := table.GetColumn("age") + assert.NotNil(t, col) + defaultVal = col.Default + isDefaultExist = !col.DefaultIsEmpty + break + } + } + assert.True(t, isDefaultExist) + assert.EqualValues(t, "10", defaultVal) + + cnt, err := testEngine.Omit("age").Insert(&DefaultStruct{ + Name: "test", + Age: 20, + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var s DefaultStruct + has, err := testEngine.ID(1).Get(&s) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, 10, s.Age) + assert.EqualValues(t, "test", s.Name) +} + +func TestTagDefault2(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type DefaultStruct2 struct { + Id int64 + Name string + } + + assertSync(t, new(DefaultStruct2)) + + tables, err := testEngine.DBMetas() + assert.NoError(t, err) + + var defaultVal string + var isDefaultExist bool + tableName := testEngine.GetColumnMapper().Obj2Table("DefaultStruct2") + for _, table := range tables { + if table.Name == tableName { + col := table.GetColumn("name") + assert.NotNil(t, col) + defaultVal = col.Default + isDefaultExist = !col.DefaultIsEmpty + break + } + } + assert.False(t, isDefaultExist, fmt.Sprintf("default value is --%v--", defaultVal)) + assert.EqualValues(t, "", defaultVal) +} + +func TestTagDefault3(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type DefaultStruct3 struct { + Id int64 + Name string `xorm:"default('myname')"` + } + + assertSync(t, new(DefaultStruct3)) + + tables, err := testEngine.DBMetas() + assert.NoError(t, err) + + var defaultVal string + var isDefaultExist bool + tableName := testEngine.GetColumnMapper().Obj2Table("DefaultStruct3") + for _, table := range tables { + if table.Name == tableName { + col := table.GetColumn("name") + assert.NotNil(t, col) + defaultVal = col.Default + isDefaultExist = !col.DefaultIsEmpty + break + } + } + assert.True(t, isDefaultExist) + assert.EqualValues(t, "'myname'", defaultVal) +} + +func TestTagDefault4(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type DefaultStruct4 struct { + Id int64 + Created time.Time `xorm:"default(CURRENT_TIMESTAMP)"` + } + + assertSync(t, new(DefaultStruct4)) + + tables, err := testEngine.DBMetas() + assert.NoError(t, err) + + var defaultVal string + var isDefaultExist bool + tableName := testEngine.GetColumnMapper().Obj2Table("DefaultStruct4") + for _, table := range tables { + if table.Name == tableName { + col := table.GetColumn("created") + assert.NotNil(t, col) + defaultVal = col.Default + isDefaultExist = !col.DefaultIsEmpty + break + } + } + assert.True(t, isDefaultExist) + assert.True(t, "CURRENT_TIMESTAMP" == defaultVal || + "current_timestamp()" == defaultVal || // for cockroach + "now()" == defaultVal || + "getdate" == defaultVal, defaultVal) +} + +func TestTagDefault5(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type DefaultStruct5 struct { + Id int64 + Created time.Time `xorm:"default('2006-01-02 15:04:05')"` + } + + assertSync(t, new(DefaultStruct5)) + table, err := testEngine.TableInfo(new(DefaultStruct5)) + assert.NoError(t, err) + + createdCol := table.GetColumn("created") + assert.NotNil(t, createdCol) + assert.EqualValues(t, "'2006-01-02 15:04:05'", createdCol.Default) + assert.False(t, createdCol.DefaultIsEmpty) + + tables, err := testEngine.DBMetas() + assert.NoError(t, err) + + var defaultVal string + var isDefaultExist bool + tableName := testEngine.GetColumnMapper().Obj2Table("DefaultStruct5") + for _, table := range tables { + if table.Name == tableName { + col := table.GetColumn("created") + assert.NotNil(t, col) + defaultVal = col.Default + isDefaultExist = !col.DefaultIsEmpty + break + } + } + assert.True(t, isDefaultExist) + assert.EqualValues(t, "'2006-01-02 15:04:05'", defaultVal) +} + +func TestTagDefault6(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type DefaultStruct6 struct { + Id int64 + IsMan bool `xorm:"default(true)"` + } + + assertSync(t, new(DefaultStruct6)) + + tables, err := testEngine.DBMetas() + assert.NoError(t, err) + + var defaultVal string + var isDefaultExist bool + tableName := testEngine.GetColumnMapper().Obj2Table("DefaultStruct6") + for _, table := range tables { + if table.Name == tableName { + col := table.GetColumn("is_man") + assert.NotNil(t, col) + defaultVal = col.Default + isDefaultExist = !col.DefaultIsEmpty + break + } + } + assert.True(t, isDefaultExist) + if defaultVal == "1" { + defaultVal = "true" + } else if defaultVal == "0" { + defaultVal = "false" + } + assert.EqualValues(t, "true", defaultVal) +} + +func TestTagsDirection(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type OnlyFromDBStruct struct { + Id int64 + Name string + Uuid string `xorm:"<- default '1'"` + } + + assertSync(t, new(OnlyFromDBStruct)) + + cnt, err := testEngine.Insert(&OnlyFromDBStruct{ + Name: "test", + Uuid: "2", + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var s OnlyFromDBStruct + has, err := testEngine.ID(1).Get(&s) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, "1", s.Uuid) + assert.EqualValues(t, "test", s.Name) + + cnt, err = testEngine.ID(1).Update(&OnlyFromDBStruct{ + Uuid: "3", + Name: "test1", + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var s3 OnlyFromDBStruct + has, err = testEngine.ID(1).Get(&s3) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, "1", s3.Uuid) + assert.EqualValues(t, "test1", s3.Name) + + type OnlyToDBStruct struct { + Id int64 + Name string + Uuid string `xorm:"->"` + } + + assertSync(t, new(OnlyToDBStruct)) + + cnt, err = testEngine.Insert(&OnlyToDBStruct{ + Name: "test", + Uuid: "2", + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var s2 OnlyToDBStruct + has, err = testEngine.ID(1).Get(&s2) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, "", s2.Uuid) + assert.EqualValues(t, "test", s2.Name) +} + +func TestTagTime(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type TagUTCStruct struct { + Id int64 + Name string + Created time.Time `xorm:"created utc"` + } + + assertSync(t, new(TagUTCStruct)) + + assert.EqualValues(t, time.Local.String(), testEngine.GetTZLocation().String()) + + s := TagUTCStruct{ + Name: "utc", + } + cnt, err := testEngine.Insert(&s) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var u TagUTCStruct + has, err := testEngine.ID(1).Get(&u) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, s.Created.Format("2006-01-02 15:04:05"), u.Created.Format("2006-01-02 15:04:05")) + + var tm string + has, err = testEngine.Table("tag_u_t_c_struct").Cols("created").Get(&tm) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, s.Created.UTC().Format("2006-01-02 15:04:05"), + strings.Replace(strings.Replace(tm, "T", " ", -1), "Z", "", -1)) +} + +func TestTagAutoIncr(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type TagAutoIncr struct { + Id int64 + Name string + } + + assertSync(t, new(TagAutoIncr)) + + tables, err := testEngine.DBMetas() + assert.NoError(t, err) + assert.EqualValues(t, 1, len(tables)) + assert.EqualValues(t, tableMapper.Obj2Table("TagAutoIncr"), tables[0].Name) + col := tables[0].GetColumn(colMapper.Obj2Table("Id")) + assert.NotNil(t, col) + assert.True(t, col.IsPrimaryKey) + assert.True(t, col.IsAutoIncrement) + + col2 := tables[0].GetColumn(colMapper.Obj2Table("Name")) + assert.NotNil(t, col2) + assert.False(t, col2.IsPrimaryKey) + assert.False(t, col2.IsAutoIncrement) +} + +func TestTagPrimarykey(t *testing.T) { + assert.NoError(t, PrepareEngine()) + type TagPrimaryKey struct { + Id int64 `xorm:"pk"` + Name string `xorm:"VARCHAR(20) pk"` + } + + assertSync(t, new(TagPrimaryKey)) + + tables, err := testEngine.DBMetas() + assert.NoError(t, err) + assert.EqualValues(t, 1, len(tables)) + assert.EqualValues(t, tableMapper.Obj2Table("TagPrimaryKey"), tables[0].Name) + col := tables[0].GetColumn(colMapper.Obj2Table("Id")) + assert.NotNil(t, col) + assert.True(t, col.IsPrimaryKey) + assert.False(t, col.IsAutoIncrement) + + col2 := tables[0].GetColumn(colMapper.Obj2Table("Name")) + assert.NotNil(t, col2) + assert.True(t, col2.IsPrimaryKey) + assert.False(t, col2.IsAutoIncrement) +} + +type VersionS struct { + Id int64 + Name string + Ver int `xorm:"version"` + Created time.Time `xorm:"created"` +} + +func TestVersion1(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + err := testEngine.DropTables(new(VersionS)) + assert.NoError(t, err) + + err = testEngine.CreateTables(new(VersionS)) + assert.NoError(t, err) + + ver := &VersionS{Name: "sfsfdsfds"} + _, err = testEngine.Insert(ver) + assert.NoError(t, err) + assert.EqualValues(t, ver.Ver, 1) + + newVer := new(VersionS) + has, err := testEngine.ID(ver.Id).Get(newVer) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, newVer.Ver, 1) + + newVer.Name = "-------" + _, err = testEngine.ID(ver.Id).Update(newVer) + assert.NoError(t, err) + assert.EqualValues(t, newVer.Ver, 2) + + newVer = new(VersionS) + has, err = testEngine.ID(ver.Id).Get(newVer) + assert.NoError(t, err) + assert.EqualValues(t, newVer.Ver, 2) +} + +func TestVersion2(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + err := testEngine.DropTables(new(VersionS)) + assert.NoError(t, err) + + err = testEngine.CreateTables(new(VersionS)) + assert.NoError(t, err) + + var vers = []VersionS{ + {Name: "sfsfdsfds"}, + {Name: "xxxxx"}, + } + _, err = testEngine.Insert(vers) + assert.NoError(t, err) + for _, v := range vers { + assert.EqualValues(t, v.Ver, 1) + } +} + +type VersionUintS struct { + Id int64 + Name string + Ver uint `xorm:"version"` + Created time.Time `xorm:"created"` +} + +func TestVersion3(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + err := testEngine.DropTables(new(VersionUintS)) + assert.NoError(t, err) + + err = testEngine.CreateTables(new(VersionUintS)) + assert.NoError(t, err) + + ver := &VersionUintS{Name: "sfsfdsfds"} + _, err = testEngine.Insert(ver) + assert.NoError(t, err) + assert.EqualValues(t, ver.Ver, 1) + + newVer := new(VersionUintS) + has, err := testEngine.ID(ver.Id).Get(newVer) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, newVer.Ver, 1) + + newVer.Name = "-------" + _, err = testEngine.ID(ver.Id).Update(newVer) + assert.NoError(t, err) + assert.EqualValues(t, newVer.Ver, 2) + + newVer = new(VersionUintS) + has, err = testEngine.ID(ver.Id).Get(newVer) + assert.NoError(t, err) + assert.EqualValues(t, newVer.Ver, 2) +} + +func TestVersion4(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + err := testEngine.DropTables(new(VersionUintS)) + assert.NoError(t, err) + + err = testEngine.CreateTables(new(VersionUintS)) + assert.NoError(t, err) + + var vers = []VersionUintS{ + {Name: "sfsfdsfds"}, + {Name: "xxxxx"}, + } + _, err = testEngine.Insert(vers) + assert.NoError(t, err) + for _, v := range vers { + assert.EqualValues(t, v.Ver, 1) + } +} + +func TestIndexes(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type TestIndexesStruct struct { + Id int64 + Name string `xorm:"index unique(s)"` + Email string `xorm:"index unique(s)"` + } + + assertSync(t, new(TestIndexesStruct)) + + tables, err := testEngine.DBMetas() + assert.NoError(t, err) + assert.EqualValues(t, 1, len(tables)) + assert.EqualValues(t, 3, len(tables[0].Columns())) + slice1 := []string{ + testEngine.GetColumnMapper().Obj2Table("Id"), + testEngine.GetColumnMapper().Obj2Table("Name"), + testEngine.GetColumnMapper().Obj2Table("Email"), + } + slice2 := []string{ + tables[0].Columns()[0].Name, + tables[0].Columns()[1].Name, + tables[0].Columns()[2].Name, + } + sort.Strings(slice1) + sort.Strings(slice2) + assert.EqualValues(t, slice1, slice2) + assert.EqualValues(t, 3, len(tables[0].Indexes)) +} diff --git a/xorm_test.go b/integrations/tests.go similarity index 65% rename from xorm_test.go rename to integrations/tests.go index 21715256..c8219935 100644 --- a/xorm_test.go +++ b/integrations/tests.go @@ -1,28 +1,27 @@ -// Copyright 2018 The Xorm Authors. All rights reserved. +// Copyright 2017 The Xorm Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package integrations import ( "database/sql" "flag" "fmt" - "log" "os" "strings" "testing" - _ "github.com/denisenkom/go-mssqldb" - _ "github.com/go-sql-driver/mysql" - _ "github.com/lib/pq" - _ "github.com/mattn/go-sqlite3" - _ "github.com/ziutek/mymysql/godrv" - "xorm.io/core" + "xorm.io/xorm" + "xorm.io/xorm/caches" + "xorm.io/xorm/dialects" + "xorm.io/xorm/log" + "xorm.io/xorm/names" + "xorm.io/xorm/schemas" ) var ( - testEngine EngineInterface + testEngine xorm.EngineInterface dbType string connString string @@ -30,14 +29,15 @@ var ( showSQL = flag.Bool("show_sql", true, "show generated SQLs") ptrConnStr = flag.String("conn_str", "./test.db?cache=shared&mode=rwc", "test database connection string") mapType = flag.String("map_type", "snake", "indicate the name mapping") - cache = flag.Bool("cache", false, "if enable cache") + cacheFlag = flag.Bool("cache", false, "if enable cache") cluster = flag.Bool("cluster", false, "if this is a cluster") splitter = flag.String("splitter", ";", "the splitter on connstr for cluster") schema = flag.String("schema", "", "specify the schema") ignoreSelectUpdate = flag.Bool("ignore_select_update", false, "ignore select update if implementation difference, only for tidb") - - tableMapper core.IMapper - colMapper core.IMapper + ingoreUpdateLimit = flag.Bool("ignore_update_limit", false, "ignore update limit if implementation difference, only for cockroach") + quotePolicyStr = flag.String("quote", "always", "quote could be always, none, reversed") + tableMapper names.Mapper + colMapper names.Mapper ) func createEngine(dbType, connStr string) error { @@ -45,8 +45,8 @@ func createEngine(dbType, connStr string) error { var err error if !*cluster { - switch strings.ToLower(dbType) { - case core.MSSQL: + switch schemas.DBType(strings.ToLower(dbType)) { + case schemas.MSSQL: db, err := sql.Open(dbType, strings.Replace(connStr, "xorm_test", "master", -1)) if err != nil { return err @@ -56,12 +56,12 @@ func createEngine(dbType, connStr string) error { } db.Close() *ignoreSelectUpdate = true - case core.POSTGRES: - db, err := sql.Open(dbType, connStr) + case schemas.POSTGRES: + db, err := sql.Open(dbType, strings.Replace(connStr, "xorm_test", "postgres", -1)) if err != nil { return err } - rows, err := db.Query(fmt.Sprintf("SELECT 1 FROM pg_database WHERE datname = 'xorm_test'")) + rows, err := db.Query("SELECT 1 FROM pg_database WHERE datname = 'xorm_test'") if err != nil { return fmt.Errorf("db.Query: %v", err) } @@ -73,13 +73,19 @@ func createEngine(dbType, connStr string) error { } } if *schema != "" { + db.Close() + db, err = sql.Open(dbType, connStr) + if err != nil { + return err + } + defer db.Close() if _, err = db.Exec("CREATE SCHEMA IF NOT EXISTS " + *schema); err != nil { return fmt.Errorf("CREATE SCHEMA: %v", err) } } db.Close() *ignoreSelectUpdate = true - case core.MYSQL: + case schemas.MYSQL: db, err := sql.Open(dbType, strings.Replace(connStr, "xorm_test", "mysql", -1)) if err != nil { return err @@ -92,9 +98,9 @@ func createEngine(dbType, connStr string) error { *ignoreSelectUpdate = true } - testEngine, err = NewEngine(dbType, connStr) + testEngine, err = xorm.NewEngine(dbType, connStr) } else { - testEngine, err = NewEngineGroup(dbType, strings.Split(connStr, *splitter)) + testEngine, err = xorm.NewEngineGroup(dbType, strings.Split(connStr, *splitter)) if dbType != "mysql" && dbType != "mymysql" { *ignoreSelectUpdate = true } @@ -107,22 +113,30 @@ func createEngine(dbType, connStr string) error { testEngine.SetSchema(*schema) } testEngine.ShowSQL(*showSQL) - testEngine.SetLogLevel(core.LOG_DEBUG) - if *cache { - cacher := NewLRUCacher(NewMemoryStore(), 100000) + testEngine.SetLogLevel(log.LOG_DEBUG) + if *cacheFlag { + cacher := caches.NewLRUCacher(caches.NewMemoryStore(), 100000) testEngine.SetDefaultCacher(cacher) } if len(*mapType) > 0 { switch *mapType { case "snake": - testEngine.SetMapper(core.SnakeMapper{}) + testEngine.SetMapper(names.SnakeMapper{}) case "same": - testEngine.SetMapper(core.SameMapper{}) + testEngine.SetMapper(names.SameMapper{}) case "gonic": - testEngine.SetMapper(core.LintGonicMapper) + testEngine.SetMapper(names.LintGonicMapper) } } + + if *quotePolicyStr == "none" { + testEngine.SetQuotePolicy(dialects.QuotePolicyNone) + } else if *quotePolicyStr == "reserved" { + testEngine.SetQuotePolicy(dialects.QuotePolicyReserved) + } else { + testEngine.SetQuotePolicy(dialects.QuotePolicyAlways) + } } tableMapper = testEngine.GetTableMapper() @@ -142,11 +156,11 @@ func createEngine(dbType, connStr string) error { return nil } -func prepareEngine() error { +func PrepareEngine() error { return createEngine(dbType, connString) } -func TestMain(m *testing.M) { +func MainTest(m *testing.M) { flag.Parse() dbType = *db @@ -158,7 +172,7 @@ func TestMain(m *testing.M) { } } else { if ptrConnStr == nil { - log.Fatal("you should indicate conn string") + fmt.Println("you should indicate conn string") return } connString = *ptrConnStr @@ -174,8 +188,9 @@ func TestMain(m *testing.M) { testEngine = nil fmt.Println("testing", dbType, connString) - if err := prepareEngine(); err != nil { - log.Fatal(err) + if err := PrepareEngine(); err != nil { + fmt.Println(err) + os.Exit(1) return } @@ -187,9 +202,3 @@ func TestMain(m *testing.M) { os.Exit(res) } - -func TestPing(t *testing.T) { - if err := testEngine.Ping(); err != nil { - t.Fatal(err) - } -} diff --git a/time_test.go b/integrations/time_test.go similarity index 83% rename from time_test.go rename to integrations/time_test.go index b7e4d12b..6d8d812c 100644 --- a/time_test.go +++ b/integrations/time_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package integrations import ( "fmt" @@ -10,11 +10,17 @@ import ( "testing" "time" + "xorm.io/xorm/internal/utils" + "github.com/stretchr/testify/assert" ) +func formatTime(t time.Time) string { + return t.Format("2006-01-02 15:04:05") +} + func TestTimeUserTime(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type TimeUser struct { Id string @@ -44,7 +50,7 @@ func TestTimeUserTime(t *testing.T) { } func TestTimeUserTimeDiffLoc(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) loc, err := time.LoadLocation("Asia/Shanghai") assert.NoError(t, err) testEngine.SetTZLocation(loc) @@ -80,7 +86,7 @@ func TestTimeUserTimeDiffLoc(t *testing.T) { } func TestTimeUserCreated(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type UserCreated struct { Id string @@ -109,7 +115,7 @@ func TestTimeUserCreated(t *testing.T) { } func TestTimeUserCreatedDiffLoc(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) loc, err := time.LoadLocation("Asia/Shanghai") assert.NoError(t, err) testEngine.SetTZLocation(loc) @@ -144,7 +150,7 @@ func TestTimeUserCreatedDiffLoc(t *testing.T) { } func TestTimeUserUpdated(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type UserUpdated struct { Id string @@ -195,7 +201,7 @@ func TestTimeUserUpdated(t *testing.T) { } func TestTimeUserUpdatedDiffLoc(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) loc, err := time.LoadLocation("Asia/Shanghai") assert.NoError(t, err) testEngine.SetTZLocation(loc) @@ -252,13 +258,15 @@ func TestTimeUserUpdatedDiffLoc(t *testing.T) { } func TestTimeUserDeleted(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type UserDeleted struct { - Id string - CreatedAt time.Time `xorm:"created"` - UpdatedAt time.Time `xorm:"updated"` - DeletedAt time.Time `xorm:"deleted"` + Id string + CreatedAt time.Time `xorm:"created"` + UpdatedAt time.Time `xorm:"updated"` + DeletedAt time.Time `xorm:"deleted"` + CreatedAtStr string `xorm:"datetime created"` + UpdatedAtStr string `xorm:"datetime updated"` } assertSync(t, new(UserDeleted)) @@ -280,14 +288,15 @@ func TestTimeUserDeleted(t *testing.T) { assert.EqualValues(t, formatTime(user.CreatedAt), formatTime(user2.CreatedAt)) assert.EqualValues(t, user.UpdatedAt.Unix(), user2.UpdatedAt.Unix()) assert.EqualValues(t, formatTime(user.UpdatedAt), formatTime(user2.UpdatedAt)) - assert.True(t, isTimeZero(user2.DeletedAt)) + assert.True(t, utils.IsTimeZero(user2.DeletedAt)) fmt.Println("user2", user2.CreatedAt, user2.UpdatedAt, user2.DeletedAt) + fmt.Println("user2 str", user2.CreatedAtStr, user2.UpdatedAtStr) var user3 UserDeleted cnt, err = testEngine.Where("id = ?", "lunny").Delete(&user3) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) - assert.True(t, !isTimeZero(user3.DeletedAt)) + assert.True(t, !utils.IsTimeZero(user3.DeletedAt)) var user4 UserDeleted has, err = testEngine.Unscoped().Get(&user4) @@ -299,7 +308,7 @@ func TestTimeUserDeleted(t *testing.T) { } func TestTimeUserDeletedDiffLoc(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) loc, err := time.LoadLocation("Asia/Shanghai") assert.NoError(t, err) testEngine.SetTZLocation(loc) @@ -333,14 +342,14 @@ func TestTimeUserDeletedDiffLoc(t *testing.T) { assert.EqualValues(t, formatTime(user.CreatedAt), formatTime(user2.CreatedAt)) assert.EqualValues(t, user.UpdatedAt.Unix(), user2.UpdatedAt.Unix()) assert.EqualValues(t, formatTime(user.UpdatedAt), formatTime(user2.UpdatedAt)) - assert.True(t, isTimeZero(user2.DeletedAt)) + assert.True(t, utils.IsTimeZero(user2.DeletedAt)) fmt.Println("user2", user2.CreatedAt, user2.UpdatedAt, user2.DeletedAt) var user3 UserDeleted2 cnt, err = testEngine.Where("id = ?", "lunny").Delete(&user3) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) - assert.True(t, !isTimeZero(user3.DeletedAt)) + assert.True(t, !utils.IsTimeZero(user3.DeletedAt)) var user4 UserDeleted2 has, err = testEngine.Unscoped().Get(&user4) @@ -351,38 +360,38 @@ func TestTimeUserDeletedDiffLoc(t *testing.T) { fmt.Println("user3", user3.DeletedAt, user4.DeletedAt) } -type JsonDate time.Time +type JSONDate time.Time -func (j JsonDate) MarshalJSON() ([]byte, error) { +func (j JSONDate) MarshalJSON() ([]byte, error) { if time.Time(j).IsZero() { return []byte(`""`), nil } return []byte(`"` + time.Time(j).Format("2006-01-02 15:04:05") + `"`), nil } -func (j *JsonDate) UnmarshalJSON(value []byte) error { +func (j *JSONDate) UnmarshalJSON(value []byte) error { var v = strings.TrimSpace(strings.Trim(string(value), "\"")) t, err := time.ParseInLocation("2006-01-02 15:04:05", v, time.Local) if err != nil { return err } - *j = JsonDate(t) + *j = JSONDate(t) return nil } -func (j *JsonDate) Unix() int64 { +func (j *JSONDate) Unix() int64 { return (*time.Time)(j).Unix() } func TestCustomTimeUserDeleted(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type UserDeleted3 struct { Id string - CreatedAt JsonDate `xorm:"created"` - UpdatedAt JsonDate `xorm:"updated"` - DeletedAt JsonDate `xorm:"deleted"` + CreatedAt JSONDate `xorm:"created"` + UpdatedAt JSONDate `xorm:"updated"` + DeletedAt JSONDate `xorm:"deleted"` } assertSync(t, new(UserDeleted3)) @@ -404,14 +413,14 @@ func TestCustomTimeUserDeleted(t *testing.T) { assert.EqualValues(t, formatTime(time.Time(user.CreatedAt)), formatTime(time.Time(user2.CreatedAt))) assert.EqualValues(t, user.UpdatedAt.Unix(), user2.UpdatedAt.Unix()) assert.EqualValues(t, formatTime(time.Time(user.UpdatedAt)), formatTime(time.Time(user2.UpdatedAt))) - assert.True(t, isTimeZero(time.Time(user2.DeletedAt))) + assert.True(t, utils.IsTimeZero(time.Time(user2.DeletedAt))) fmt.Println("user2", user2.CreatedAt, user2.UpdatedAt, user2.DeletedAt) var user3 UserDeleted3 cnt, err = testEngine.Where("id = ?", "lunny").Delete(&user3) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) - assert.True(t, !isTimeZero(time.Time(user3.DeletedAt))) + assert.True(t, !utils.IsTimeZero(time.Time(user3.DeletedAt))) var user4 UserDeleted3 has, err = testEngine.Unscoped().Get(&user4) @@ -423,7 +432,7 @@ func TestCustomTimeUserDeleted(t *testing.T) { } func TestCustomTimeUserDeletedDiffLoc(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) loc, err := time.LoadLocation("Asia/Shanghai") assert.NoError(t, err) testEngine.SetTZLocation(loc) @@ -433,9 +442,9 @@ func TestCustomTimeUserDeletedDiffLoc(t *testing.T) { type UserDeleted4 struct { Id string - CreatedAt JsonDate `xorm:"created"` - UpdatedAt JsonDate `xorm:"updated"` - DeletedAt JsonDate `xorm:"deleted"` + CreatedAt JSONDate `xorm:"created"` + UpdatedAt JSONDate `xorm:"updated"` + DeletedAt JSONDate `xorm:"deleted"` } assertSync(t, new(UserDeleted4)) @@ -457,14 +466,14 @@ func TestCustomTimeUserDeletedDiffLoc(t *testing.T) { assert.EqualValues(t, formatTime(time.Time(user.CreatedAt)), formatTime(time.Time(user2.CreatedAt))) assert.EqualValues(t, user.UpdatedAt.Unix(), user2.UpdatedAt.Unix()) assert.EqualValues(t, formatTime(time.Time(user.UpdatedAt)), formatTime(time.Time(user2.UpdatedAt))) - assert.True(t, isTimeZero(time.Time(user2.DeletedAt))) + assert.True(t, utils.IsTimeZero(time.Time(user2.DeletedAt))) fmt.Println("user2", user2.CreatedAt, user2.UpdatedAt, user2.DeletedAt) var user3 UserDeleted4 cnt, err = testEngine.Where("id = ?", "lunny").Delete(&user3) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) - assert.True(t, !isTimeZero(time.Time(user3.DeletedAt))) + assert.True(t, !utils.IsTimeZero(time.Time(user3.DeletedAt))) var user4 UserDeleted4 has, err = testEngine.Unscoped().Get(&user4) @@ -474,3 +483,40 @@ func TestCustomTimeUserDeletedDiffLoc(t *testing.T) { assert.EqualValues(t, formatTime(time.Time(user3.DeletedAt)), formatTime(time.Time(user4.DeletedAt))) fmt.Println("user3", user3.DeletedAt, user4.DeletedAt) } + +func TestDeletedInt64(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type DeletedInt64Struct struct { + Id int64 + Deleted int64 `xorm:"deleted default(0) notnull"` // timestamp + } + + assertSync(t, new(DeletedInt64Struct)) + + var d1 DeletedInt64Struct + cnt, err := testEngine.Insert(&d1) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var d2 DeletedInt64Struct + has, err := testEngine.ID(d1.Id).Get(&d2) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, d1, d2) + + cnt, err = testEngine.ID(d1.Id).NoAutoCondition().Delete(&d1) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var d3 DeletedInt64Struct + has, err = testEngine.ID(d1.Id).Get(&d3) + assert.NoError(t, err) + assert.False(t, has) + + var d4 DeletedInt64Struct + has, err = testEngine.ID(d1.Id).Unscoped().Get(&d4) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, d1, d4) +} diff --git a/types_null_test.go b/integrations/types_null_test.go similarity index 59% rename from types_null_test.go rename to integrations/types_null_test.go index 7a13837e..98bd86b9 100644 --- a/types_null_test.go +++ b/integrations/types_null_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package integrations import ( "database/sql" @@ -22,7 +22,8 @@ type NullType struct { Age sql.NullInt64 Height sql.NullFloat64 IsMan sql.NullBool `xorm:"null"` - CustomStruct CustomStruct `xorm:"valchar(64) null"` + Nil driver.Valuer + CustomStruct CustomStruct `xorm:"varchar(64) null"` } type CustomStruct struct { @@ -57,90 +58,61 @@ func (m CustomStruct) Value() (driver.Value, error) { } func TestCreateNullStructTable(t *testing.T) { - assert.NoError(t, prepareEngine()) - + assert.NoError(t, PrepareEngine()) err := testEngine.CreateTables(new(NullType)) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) } func TestDropNullStructTable(t *testing.T) { - assert.NoError(t, prepareEngine()) - + assert.NoError(t, PrepareEngine()) err := testEngine.DropTables(new(NullType)) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) } func TestNullStructInsert(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(NullType)) - if true { - item := new(NullType) - _, err := testEngine.Insert(item) - if err != nil { - t.Error(err) - panic(err) - } - fmt.Println(item) - if item.Id != 1 { - err = errors.New("insert error") - t.Error(err) - panic(err) - } - } + item1 := new(NullType) + _, err := testEngine.Insert(item1) + assert.NoError(t, err) + assert.EqualValues(t, 1, item1.Id) - if true { + item := NullType{ + Name: sql.NullString{String: "haolei", Valid: true}, + Age: sql.NullInt64{Int64: 34, Valid: true}, + Height: sql.NullFloat64{Float64: 1.72, Valid: true}, + IsMan: sql.NullBool{Bool: true, Valid: true}, + Nil: nil, + } + _, err = testEngine.Insert(&item) + assert.NoError(t, err) + assert.EqualValues(t, 2, item.Id) + + items := []NullType{} + for i := 0; i < 5; i++ { item := NullType{ - Name: sql.NullString{String: "haolei", Valid: true}, - Age: sql.NullInt64{Int64: 34, Valid: true}, - Height: sql.NullFloat64{Float64: 1.72, Valid: true}, - IsMan: sql.NullBool{Bool: true, Valid: true}, - } - _, err := testEngine.Insert(&item) - if err != nil { - t.Error(err) - panic(err) - } - fmt.Println(item) - if item.Id != 2 { - err = errors.New("insert error") - t.Error(err) - panic(err) + Name: sql.NullString{String: "haolei_" + fmt.Sprint(i+1), Valid: true}, + Age: sql.NullInt64{Int64: 30 + int64(i), Valid: true}, + Height: sql.NullFloat64{Float64: 1.5 + 1.1*float64(i), Valid: true}, + IsMan: sql.NullBool{Bool: true, Valid: true}, + CustomStruct: CustomStruct{i, i + 1, i + 2}, + Nil: nil, } + items = append(items, item) } - if true { - items := []NullType{} + _, err = testEngine.Insert(&items) + assert.NoError(t, err) - for i := 0; i < 5; i++ { - item := NullType{ - Name: sql.NullString{String: "haolei_" + fmt.Sprint(i+1), Valid: true}, - Age: sql.NullInt64{Int64: 30 + int64(i), Valid: true}, - Height: sql.NullFloat64{Float64: 1.5 + 1.1*float64(i), Valid: true}, - IsMan: sql.NullBool{Bool: true, Valid: true}, - CustomStruct: CustomStruct{i, i + 1, i + 2}, - } - - items = append(items, item) - } - - _, err := testEngine.Insert(&items) - if err != nil { - t.Error(err) - panic(err) - } - fmt.Println(items) - } + items = make([]NullType, 0, 7) + err = testEngine.Find(&items) + assert.NoError(t, err) + assert.EqualValues(t, 7, len(items)) } func TestNullStructUpdate(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(NullType)) _, err := testEngine.Insert([]NullType{ @@ -177,30 +149,16 @@ func TestNullStructUpdate(t *testing.T) { item.Height = sql.NullFloat64{Float64: 0, Valid: false} // update to NULL affected, err := testEngine.ID(2).Cols("age", "height", "is_man").Update(item) - if err != nil { - t.Error(err) - panic(err) - } - if affected != 1 { - err := errors.New("update failed") - t.Error(err) - panic(err) - } + assert.NoError(t, err) + assert.EqualValues(t, 1, affected) } if true { // 测试In update item := new(NullType) item.Age = sql.NullInt64{Int64: 23, Valid: true} affected, err := testEngine.In("id", 3, 4).Cols("age", "height", "is_man").Update(item) - if err != nil { - t.Error(err) - panic(err) - } - if affected != 2 { - err := errors.New("update failed") - t.Error(err) - panic(err) - } + assert.NoError(t, err) + assert.EqualValues(t, 2, affected) } if true { // 测试where @@ -210,10 +168,7 @@ func TestNullStructUpdate(t *testing.T) { item.Age = sql.NullInt64{Int64: 34, Valid: true} _, err := testEngine.Where("age > ?", 34).Update(item) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) } if true { // 修改全部时,插入空值 @@ -225,17 +180,12 @@ func TestNullStructUpdate(t *testing.T) { } _, err := testEngine.AllCols().ID(6).Update(item) - if err != nil { - t.Error(err) - panic(err) - } - fmt.Println(item) + assert.NoError(t, err) } - } func TestNullStructFind(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(NullType)) _, err := testEngine.Insert([]NullType{ @@ -269,68 +219,38 @@ func TestNullStructFind(t *testing.T) { if true { item := new(NullType) has, err := testEngine.ID(1).Get(item) - if err != nil { - t.Error(err) - panic(err) - } - if !has { - t.Error(errors.New("no find id 1")) - panic(err) - } - fmt.Println(item) - if item.Id != 1 || item.Name.Valid || item.Age.Valid || item.Height.Valid || - item.IsMan.Valid { - err = errors.New("insert error") - t.Error(err) - panic(err) - } + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, item.Id, 1) + assert.False(t, item.Name.Valid) + assert.False(t, item.Age.Valid) + assert.False(t, item.Height.Valid) + assert.False(t, item.IsMan.Valid) } if true { item := new(NullType) item.Id = 2 - has, err := testEngine.Get(item) - if err != nil { - t.Error(err) - panic(err) - } - if !has { - t.Error(errors.New("no find id 2")) - panic(err) - } - fmt.Println(item) + assert.NoError(t, err) + assert.True(t, has) } if true { item := make([]NullType, 0) - err := testEngine.ID(2).Find(&item) - if err != nil { - t.Error(err) - panic(err) - } - - fmt.Println(item) + assert.NoError(t, err) } if true { item := make([]NullType, 0) - err := testEngine.Asc("age").Find(&item) - if err != nil { - t.Error(err) - panic(err) - } - - for k, v := range item { - fmt.Println(k, v) - } + assert.NoError(t, err) } } func TestNullStructIterate(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(NullType)) if true { @@ -340,65 +260,45 @@ func TestNullStructIterate(t *testing.T) { fmt.Println(i, nultype) return nil }) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) } } func TestNullStructCount(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(NullType)) if true { item := new(NullType) - total, err := testEngine.Where("age IS NOT NULL").Count(item) - if err != nil { - t.Error(err) - panic(err) - } - fmt.Println(total) + _, err := testEngine.Where("age IS NOT NULL").Count(item) + assert.NoError(t, err) } } func TestNullStructRows(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(NullType)) item := new(NullType) rows, err := testEngine.Where("id > ?", 1).Rows(item) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) defer rows.Close() for rows.Next() { err = rows.Scan(item) - if err != nil { - t.Error(err) - panic(err) - } - fmt.Println(item) + assert.NoError(t, err) } } func TestNullStructDelete(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) assertSync(t, new(NullType)) item := new(NullType) _, err := testEngine.ID(1).Delete(item) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) _, err = testEngine.Where("id > ?", 1).Delete(item) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) } diff --git a/types_test.go b/integrations/types_test.go similarity index 76% rename from types_test.go rename to integrations/types_test.go index 274609b2..112308f3 100644 --- a/types_test.go +++ b/integrations/types_test.go @@ -2,19 +2,23 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package integrations import ( "errors" "fmt" "testing" - "xorm.io/core" + "xorm.io/xorm" + "xorm.io/xorm/convert" + "xorm.io/xorm/internal/json" + "xorm.io/xorm/schemas" + "github.com/stretchr/testify/assert" ) func TestArrayField(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type ArrayStruct struct { Id int64 @@ -77,7 +81,7 @@ func TestArrayField(t *testing.T) { } func TestGetBytes(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) type Varbinary struct { Data []byte `xorm:"VARBINARY(250)"` @@ -116,40 +120,47 @@ type ConvConfig struct { } func (s *ConvConfig) FromDB(data []byte) error { - return DefaultJSONHandler.Unmarshal(data, s) + if data == nil { + s = nil + return nil + } + return json.DefaultJSONHandler.Unmarshal(data, s) } func (s *ConvConfig) ToDB() ([]byte, error) { - return DefaultJSONHandler.Marshal(s) + if s == nil { + return nil, nil + } + return json.DefaultJSONHandler.Marshal(s) } type SliceType []*ConvConfig func (s *SliceType) FromDB(data []byte) error { - return DefaultJSONHandler.Unmarshal(data, s) + return json.DefaultJSONHandler.Unmarshal(data, s) } func (s *SliceType) ToDB() ([]byte, error) { - return DefaultJSONHandler.Marshal(s) + return json.DefaultJSONHandler.Marshal(s) } type ConvStruct struct { Conv ConvString Conv2 *ConvString Cfg1 ConvConfig - Cfg2 *ConvConfig `xorm:"TEXT"` - Cfg3 core.Conversion `xorm:"BLOB"` + Cfg2 *ConvConfig `xorm:"TEXT"` + Cfg3 convert.Conversion `xorm:"BLOB"` Slice SliceType } -func (c *ConvStruct) BeforeSet(name string, cell Cell) { +func (c *ConvStruct) BeforeSet(name string, cell xorm.Cell) { if name == "cfg3" || name == "Cfg3" { c.Cfg3 = new(ConvConfig) } } func TestConversion(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) c := new(ConvStruct) assert.NoError(t, testEngine.DropTables(c)) @@ -181,6 +192,30 @@ func TestConversion(t *testing.T) { assert.EqualValues(t, 2, len(c1.Slice)) assert.EqualValues(t, *c.Slice[0], *c1.Slice[0]) assert.EqualValues(t, *c.Slice[1], *c1.Slice[1]) + + cnt, err := testEngine.Where("1=1").Delete(new(ConvStruct)) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + c.Cfg2 = nil + + _, err = testEngine.Insert(c) + assert.NoError(t, err) + + c2 := new(ConvStruct) + has, err = testEngine.Get(c2) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, "prefix---tttt", string(c2.Conv)) + assert.NotNil(t, c2.Conv2) + assert.EqualValues(t, "prefix---"+s, *c2.Conv2) + assert.EqualValues(t, c.Cfg1, c2.Cfg1) + assert.Nil(t, c2.Cfg2) + assert.NotNil(t, c2.Cfg3) + assert.EqualValues(t, *c.Cfg3.(*ConvConfig), *c2.Cfg3.(*ConvConfig)) + assert.EqualValues(t, 2, len(c2.Slice)) + assert.EqualValues(t, *c.Slice[0], *c2.Slice[0]) + assert.EqualValues(t, *c.Slice[1], *c2.Slice[1]) } type MyInt int @@ -209,7 +244,7 @@ type MyStruct struct { } func TestCustomType1(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) err := testEngine.DropTables(&MyStruct{}) assert.NoError(t, err) @@ -267,14 +302,14 @@ type Status struct { } var ( - _ core.Conversion = &Status{} - Registed Status = Status{"Registed", "white"} - Approved Status = Status{"Approved", "green"} - Removed Status = Status{"Removed", "red"} - Statuses map[string]Status = map[string]Status{ - Registed.Name: Registed, - Approved.Name: Approved, - Removed.Name: Removed, + _ convert.Conversion = &Status{} + Registered = Status{"Registered", "white"} + Approved = Status{"Approved", "green"} + Removed = Status{"Removed", "red"} + Statuses = map[string]Status{ + Registered.Name: Registered, + Approved.Name: Approved, + Removed.Name: Removed, } ) @@ -282,9 +317,8 @@ func (s *Status) FromDB(bytes []byte) error { if r, ok := Statuses[string(bytes)]; ok { *s = r return nil - } else { - return errors.New("no this data") } + return errors.New("no this data") } func (s *Status) ToDB() ([]byte, error) { @@ -298,7 +332,7 @@ type UserCus struct { } func TestCustomType2(t *testing.T) { - assert.NoError(t, prepareEngine()) + assert.NoError(t, PrepareEngine()) var uc UserCus err := testEngine.CreateTables(&uc) @@ -311,18 +345,18 @@ func TestCustomType2(t *testing.T) { session := testEngine.NewSession() defer session.Close() - if testEngine.Dialect().DBType() == core.MSSQL { + if testEngine.Dialect().URI().DBType == schemas.MSSQL { err = session.Begin() assert.NoError(t, err) _, err = session.Exec("set IDENTITY_INSERT " + tableName + " on") assert.NoError(t, err) } - cnt, err := session.Insert(&UserCus{1, "xlw", Registed}) + cnt, err := session.Insert(&UserCus{1, "xlw", Registered}) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) - if testEngine.Dialect().DBType() == core.MSSQL { + if testEngine.Dialect().URI().DBType == schemas.MSSQL { err = session.Commit() assert.NoError(t, err) } @@ -335,7 +369,7 @@ func TestCustomType2(t *testing.T) { fmt.Println(user) users := make([]UserCus, 0) - err = testEngine.Where("`"+testEngine.GetColumnMapper().Obj2Table("Status")+"` = ?", "Registed").Find(&users) + err = testEngine.Where("`"+testEngine.GetColumnMapper().Obj2Table("Status")+"` = ?", "Registered").Find(&users) assert.NoError(t, err) assert.EqualValues(t, 1, len(users)) diff --git a/interface.go b/interface.go index 0c42f60d..ab150fe6 100644 --- a/interface.go +++ b/interface.go @@ -10,7 +10,12 @@ import ( "reflect" "time" - "xorm.io/core" + "xorm.io/xorm/caches" + "xorm.io/xorm/contexts" + "xorm.io/xorm/dialects" + "xorm.io/xorm/log" + "xorm.io/xorm/names" + "xorm.io/xorm/schemas" ) // Interface defines the interface which Engine, EngineGroup and Session will implementate. @@ -56,6 +61,7 @@ type Interface interface { QueryString(sqlOrArgs ...interface{}) ([]map[string]string, error) Rows(bean interface{}) (*Rows, error) SetExpr(string, interface{}) *Session + Select(string) *Session SQL(interface{}, ...interface{}) *Session Sum(bean interface{}, colName string) (float64, error) SumInt(bean interface{}, colName string) (int64, error) @@ -77,37 +83,42 @@ type EngineInterface interface { ClearCache(...interface{}) error Context(context.Context) *Session CreateTables(...interface{}) error - DBMetas() ([]*core.Table, error) - Dialect() core.Dialect + DBMetas() ([]*schemas.Table, error) + Dialect() dialects.Dialect + DriverName() string DropTables(...interface{}) error - DumpAllToFile(fp string, tp ...core.DbType) error - GetCacher(string) core.Cacher - GetColumnMapper() core.IMapper - GetDefaultCacher() core.Cacher - GetTableMapper() core.IMapper + DumpAllToFile(fp string, tp ...schemas.DBType) error + GetCacher(string) caches.Cacher + GetColumnMapper() names.Mapper + GetDefaultCacher() caches.Cacher + GetTableMapper() names.Mapper GetTZDatabase() *time.Location GetTZLocation() *time.Location - MapCacher(interface{}, core.Cacher) error + ImportFile(fp string) ([]sql.Result, error) + MapCacher(interface{}, caches.Cacher) error NewSession() *Session NoAutoTime() *Session Quote(string) string - SetCacher(string, core.Cacher) + SetCacher(string, caches.Cacher) SetConnMaxLifetime(time.Duration) - SetDefaultCacher(core.Cacher) - SetLogger(logger core.ILogger) - SetLogLevel(core.LogLevel) - SetMapper(core.IMapper) + SetColumnMapper(names.Mapper) + SetDefaultCacher(caches.Cacher) + SetLogger(logger interface{}) + SetLogLevel(log.LogLevel) + SetMapper(names.Mapper) SetMaxOpenConns(int) SetMaxIdleConns(int) + SetQuotePolicy(dialects.QuotePolicy) SetSchema(string) + SetTableMapper(names.Mapper) SetTZDatabase(tz *time.Location) SetTZLocation(tz *time.Location) - ShowExecTime(...bool) + AddHook(hook contexts.Hook) ShowSQL(show ...bool) Sync(...interface{}) error Sync2(...interface{}) error StoreEngine(storeEngine string) *Session - TableInfo(bean interface{}) *Table + TableInfo(bean interface{}) (*schemas.Table, error) TableName(interface{}, ...bool) string UnMapType(reflect.Type) } diff --git a/json.go b/internal/json/json.go similarity index 98% rename from json.go rename to internal/json/json.go index fdb6ce56..c9a2eb4e 100644 --- a/json.go +++ b/internal/json/json.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package json import "encoding/json" diff --git a/internal/statements/cache.go b/internal/statements/cache.go new file mode 100644 index 00000000..cb33df08 --- /dev/null +++ b/internal/statements/cache.go @@ -0,0 +1,79 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package statements + +import ( + "fmt" + "strings" + + "xorm.io/xorm/internal/utils" + "xorm.io/xorm/schemas" +) + +func (statement *Statement) ConvertIDSQL(sqlStr string) string { + if statement.RefTable != nil { + cols := statement.RefTable.PKColumns() + if len(cols) == 0 { + return "" + } + + colstrs := statement.joinColumns(cols, false) + sqls := utils.SplitNNoCase(sqlStr, " from ", 2) + if len(sqls) != 2 { + return "" + } + + var top string + pLimitN := statement.LimitN + if pLimitN != nil && statement.dialect.URI().DBType == schemas.MSSQL { + top = fmt.Sprintf("TOP %d ", *pLimitN) + } + + newsql := fmt.Sprintf("SELECT %s%s FROM %v", top, colstrs, sqls[1]) + return newsql + } + return "" +} + +func (statement *Statement) ConvertUpdateSQL(sqlStr string) (string, string) { + if statement.RefTable == nil || len(statement.RefTable.PrimaryKeys) != 1 { + return "", "" + } + + colstrs := statement.joinColumns(statement.RefTable.PKColumns(), true) + sqls := utils.SplitNNoCase(sqlStr, "where", 2) + if len(sqls) != 2 { + if len(sqls) == 1 { + return sqls[0], fmt.Sprintf("SELECT %v FROM %v", + colstrs, statement.quote(statement.TableName())) + } + return "", "" + } + + var whereStr = sqls[1] + + // TODO: for postgres only, if any other database? + var paraStr string + if statement.dialect.URI().DBType == schemas.POSTGRES { + paraStr = "$" + } else if statement.dialect.URI().DBType == schemas.MSSQL { + paraStr = ":" + } + + if paraStr != "" { + if strings.Contains(sqls[1], paraStr) { + dollers := strings.Split(sqls[1], paraStr) + whereStr = dollers[0] + for i, c := range dollers[1:] { + ccs := strings.SplitN(c, " ", 2) + whereStr += fmt.Sprintf(paraStr+"%v %v", i+1, ccs[1]) + } + } + } + + return sqls[0], fmt.Sprintf("SELECT %v FROM %v WHERE %v", + colstrs, statement.quote(statement.TableName()), + whereStr) +} diff --git a/internal/statements/column_map.go b/internal/statements/column_map.go new file mode 100644 index 00000000..bb764b4e --- /dev/null +++ b/internal/statements/column_map.go @@ -0,0 +1,66 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package statements + +import ( + "strings" + + "xorm.io/xorm/schemas" +) + +type columnMap []string + +func (m columnMap) Contain(colName string) bool { + if len(m) == 0 { + return false + } + + n := len(colName) + for _, mk := range m { + if len(mk) != n { + continue + } + if strings.EqualFold(mk, colName) { + return true + } + } + + return false +} + +func (m columnMap) Len() int { + return len(m) +} + +func (m columnMap) IsEmpty() bool { + return len(m) == 0 +} + +func (m *columnMap) Add(colName string) bool { + if m.Contain(colName) { + return false + } + *m = append(*m, colName) + return true +} + +func getFlagForColumn(m map[string]bool, col *schemas.Column) (val bool, has bool) { + if len(m) == 0 { + return false, false + } + + n := len(col.Name) + + for mk := range m { + if len(mk) != n { + continue + } + if strings.EqualFold(mk, col.Name) { + return m[mk], true + } + } + + return false, false +} diff --git a/statement_exprparam.go b/internal/statements/expr_param.go similarity index 66% rename from statement_exprparam.go rename to internal/statements/expr_param.go index 4da4f1ea..6657408e 100644 --- a/statement_exprparam.go +++ b/internal/statements/expr_param.go @@ -2,13 +2,14 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package statements import ( "fmt" "strings" "xorm.io/builder" + "xorm.io/xorm/schemas" ) type ErrUnsupportedExprType struct { @@ -25,22 +26,22 @@ type exprParam struct { } type exprParams struct { - colNames []string - args []interface{} + ColNames []string + Args []interface{} } func (exprs *exprParams) Len() int { - return len(exprs.colNames) + return len(exprs.ColNames) } func (exprs *exprParams) addParam(colName string, arg interface{}) { - exprs.colNames = append(exprs.colNames, colName) - exprs.args = append(exprs.args, arg) + exprs.ColNames = append(exprs.ColNames, colName) + exprs.Args = append(exprs.Args, arg) } -func (exprs *exprParams) isColExist(colName string) bool { - for _, name := range exprs.colNames { - if strings.EqualFold(trimQuote(name), trimQuote(colName)) { +func (exprs *exprParams) IsColExist(colName string) bool { + for _, name := range exprs.ColNames { + if strings.EqualFold(schemas.CommonQuoter.Trim(name), schemas.CommonQuoter.Trim(colName)) { return true } } @@ -48,16 +49,16 @@ func (exprs *exprParams) isColExist(colName string) bool { } func (exprs *exprParams) getByName(colName string) (exprParam, bool) { - for i, name := range exprs.colNames { + for i, name := range exprs.ColNames { if strings.EqualFold(name, colName) { - return exprParam{name, exprs.args[i]}, true + return exprParam{name, exprs.Args[i]}, true } } return exprParam{}, false } -func (exprs *exprParams) writeArgs(w *builder.BytesWriter) error { - for i, expr := range exprs.args { +func (exprs *exprParams) WriteArgs(w *builder.BytesWriter) error { + for i, expr := range exprs.Args { switch arg := expr.(type) { case *builder.Builder: if _, err := w.WriteString("("); err != nil { @@ -69,12 +70,20 @@ func (exprs *exprParams) writeArgs(w *builder.BytesWriter) error { if _, err := w.WriteString(")"); err != nil { return err } - default: + case string: + if arg == "" { + arg = "''" + } if _, err := w.WriteString(fmt.Sprintf("%v", arg)); err != nil { return err } + default: + if _, err := w.WriteString("?"); err != nil { + return err + } + w.Append(arg) } - if i != len(exprs.args)-1 { + if i != len(exprs.Args)-1 { if _, err := w.WriteString(","); err != nil { return err } @@ -84,7 +93,7 @@ func (exprs *exprParams) writeArgs(w *builder.BytesWriter) error { } func (exprs *exprParams) writeNameArgs(w *builder.BytesWriter) error { - for i, colName := range exprs.colNames { + for i, colName := range exprs.ColNames { if _, err := w.WriteString(colName); err != nil { return err } @@ -92,7 +101,7 @@ func (exprs *exprParams) writeNameArgs(w *builder.BytesWriter) error { return err } - switch arg := exprs.args[i].(type) { + switch arg := exprs.Args[i].(type) { case *builder.Builder: if _, err := w.WriteString("("); err != nil { return err @@ -104,10 +113,10 @@ func (exprs *exprParams) writeNameArgs(w *builder.BytesWriter) error { return err } default: - w.Append(exprs.args[i]) + w.Append(exprs.Args[i]) } - if i+1 != len(exprs.colNames) { + if i+1 != len(exprs.ColNames) { if _, err := w.WriteString(","); err != nil { return err } diff --git a/internal/statements/insert.go b/internal/statements/insert.go new file mode 100644 index 00000000..6cbbbeda --- /dev/null +++ b/internal/statements/insert.go @@ -0,0 +1,207 @@ +// Copyright 2020 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package statements + +import ( + "fmt" + "strings" + + "xorm.io/builder" + "xorm.io/xorm/schemas" +) + +func (statement *Statement) writeInsertOutput(buf *strings.Builder, table *schemas.Table) error { + if statement.dialect.URI().DBType == schemas.MSSQL && len(table.AutoIncrement) > 0 { + if _, err := buf.WriteString(" OUTPUT Inserted."); err != nil { + return err + } + if _, err := buf.WriteString(table.AutoIncrement); err != nil { + return err + } + } + return nil +} + +// GenInsertSQL generates insert beans SQL +func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) (string, []interface{}, error) { + var ( + buf = builder.NewWriter() + exprs = statement.ExprColumns + table = statement.RefTable + tableName = statement.TableName() + ) + + if _, err := buf.WriteString("INSERT INTO "); err != nil { + return "", nil, err + } + + if err := statement.dialect.Quoter().QuoteTo(buf.Builder, tableName); err != nil { + return "", nil, err + } + + if len(colNames) <= 0 { + if statement.dialect.URI().DBType == schemas.MYSQL { + if _, err := buf.WriteString(" VALUES ()"); err != nil { + return "", nil, err + } + } else { + if err := statement.writeInsertOutput(buf.Builder, table); err != nil { + return "", nil, err + } + if _, err := buf.WriteString(" DEFAULT VALUES"); err != nil { + return "", nil, err + } + } + } else { + if _, err := buf.WriteString(" ("); err != nil { + return "", nil, err + } + + if err := statement.dialect.Quoter().JoinWrite(buf.Builder, append(colNames, exprs.ColNames...), ","); err != nil { + return "", nil, err + } + + if _, err := buf.WriteString(")"); err != nil { + return "", nil, err + } + if err := statement.writeInsertOutput(buf.Builder, table); err != nil { + return "", nil, err + } + + if statement.Conds().IsValid() { + if _, err := buf.WriteString(" SELECT "); err != nil { + return "", nil, err + } + + if err := statement.WriteArgs(buf, args); err != nil { + return "", nil, err + } + + if len(exprs.Args) > 0 { + if _, err := buf.WriteString(","); err != nil { + return "", nil, err + } + } + if err := exprs.WriteArgs(buf); err != nil { + return "", nil, err + } + + if _, err := buf.WriteString(" FROM "); err != nil { + return "", nil, err + } + + if err := statement.dialect.Quoter().QuoteTo(buf.Builder, tableName); err != nil { + return "", nil, err + } + + if _, err := buf.WriteString(" WHERE "); err != nil { + return "", nil, err + } + + if err := statement.Conds().WriteTo(buf); err != nil { + return "", nil, err + } + } else { + if _, err := buf.WriteString(" VALUES ("); err != nil { + return "", nil, err + } + + if err := statement.WriteArgs(buf, args); err != nil { + return "", nil, err + } + + if len(exprs.Args) > 0 { + if _, err := buf.WriteString(","); err != nil { + return "", nil, err + } + } + + if err := exprs.WriteArgs(buf); err != nil { + return "", nil, err + } + + if _, err := buf.WriteString(")"); err != nil { + return "", nil, err + } + } + } + + if len(table.AutoIncrement) > 0 && statement.dialect.URI().DBType == schemas.POSTGRES { + if _, err := buf.WriteString(" RETURNING "); err != nil { + return "", nil, err + } + if err := statement.dialect.Quoter().QuoteTo(buf.Builder, table.AutoIncrement); err != nil { + return "", nil, err + } + } + + return buf.String(), buf.Args(), nil +} + +// GenInsertMapSQL generates insert map SQL +func (statement *Statement) GenInsertMapSQL(columns []string, args []interface{}) (string, []interface{}, error) { + var ( + buf = builder.NewWriter() + exprs = statement.ExprColumns + tableName = statement.TableName() + ) + + if _, err := buf.WriteString(fmt.Sprintf("INSERT INTO %s (", statement.quote(tableName))); err != nil { + return "", nil, err + } + + if err := statement.dialect.Quoter().JoinWrite(buf.Builder, append(columns, exprs.ColNames...), ","); err != nil { + return "", nil, err + } + + // if insert where + if statement.Conds().IsValid() { + if _, err := buf.WriteString(") SELECT "); err != nil { + return "", nil, err + } + + if err := statement.WriteArgs(buf, args); err != nil { + return "", nil, err + } + + if len(exprs.Args) > 0 { + if _, err := buf.WriteString(","); err != nil { + return "", nil, err + } + if err := exprs.WriteArgs(buf); err != nil { + return "", nil, err + } + } + + if _, err := buf.WriteString(fmt.Sprintf(" FROM %s WHERE ", statement.quote(tableName))); err != nil { + return "", nil, err + } + + if err := statement.Conds().WriteTo(buf); err != nil { + return "", nil, err + } + } else { + if _, err := buf.WriteString(") VALUES ("); err != nil { + return "", nil, err + } + if err := statement.WriteArgs(buf, args); err != nil { + return "", nil, err + } + + if len(exprs.Args) > 0 { + if _, err := buf.WriteString(","); err != nil { + return "", nil, err + } + if err := exprs.WriteArgs(buf); err != nil { + return "", nil, err + } + } + if _, err := buf.WriteString(")"); err != nil { + return "", nil, err + } + } + + return buf.String(), buf.Args(), nil +} diff --git a/internal/statements/pk.go b/internal/statements/pk.go new file mode 100644 index 00000000..59da89c0 --- /dev/null +++ b/internal/statements/pk.go @@ -0,0 +1,98 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package statements + +import ( + "fmt" + "reflect" + + "xorm.io/builder" + "xorm.io/xorm/schemas" +) + +var ( + ptrPkType = reflect.TypeOf(&schemas.PK{}) + pkType = reflect.TypeOf(schemas.PK{}) + stringType = reflect.TypeOf("") + intType = reflect.TypeOf(int64(0)) + uintType = reflect.TypeOf(uint64(0)) +) + +// ErrIDConditionWithNoTable represents an error there is no reference table with an ID condition +type ErrIDConditionWithNoTable struct { + ID schemas.PK +} + +func (err ErrIDConditionWithNoTable) Error() string { + return fmt.Sprintf("ID condition %#v need reference table", err.ID) +} + +// IsIDConditionWithNoTableErr return true if the err is ErrIDConditionWithNoTable +func IsIDConditionWithNoTableErr(err error) bool { + _, ok := err.(ErrIDConditionWithNoTable) + return ok +} + +// ID generate "where id = ? " statement or for composite key "where key1 = ? and key2 = ?" +func (statement *Statement) ID(id interface{}) *Statement { + switch t := id.(type) { + case *schemas.PK: + statement.idParam = *t + case schemas.PK: + statement.idParam = t + case string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + statement.idParam = schemas.PK{id} + default: + idValue := reflect.ValueOf(id) + idType := idValue.Type() + + switch idType.Kind() { + case reflect.String: + statement.idParam = schemas.PK{idValue.Convert(stringType).Interface()} + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + statement.idParam = schemas.PK{idValue.Convert(intType).Interface()} + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + statement.idParam = schemas.PK{idValue.Convert(uintType).Interface()} + case reflect.Slice: + if idType.ConvertibleTo(pkType) { + statement.idParam = idValue.Convert(pkType).Interface().(schemas.PK) + } + case reflect.Ptr: + if idType.ConvertibleTo(ptrPkType) { + statement.idParam = idValue.Convert(ptrPkType).Elem().Interface().(schemas.PK) + } + } + } + + if statement.idParam == nil { + statement.LastError = fmt.Errorf("ID param %#v is not supported", id) + } + + return statement +} + +// ProcessIDParam handles the process of id condition +func (statement *Statement) ProcessIDParam() error { + if statement.idParam == nil { + return nil + } + + if statement.RefTable == nil { + return ErrIDConditionWithNoTable{statement.idParam} + } + + if len(statement.RefTable.PrimaryKeys) != len(statement.idParam) { + return fmt.Errorf("ID condition is error, expect %d primarykeys, there are %d", + len(statement.RefTable.PrimaryKeys), + len(statement.idParam), + ) + } + + for i, col := range statement.RefTable.PKColumns() { + var colName = statement.colName(col, statement.TableName()) + statement.cond = statement.cond.And(builder.Eq{colName: statement.idParam[i]}) + } + return nil +} diff --git a/internal/statements/query.go b/internal/statements/query.go new file mode 100644 index 00000000..ab3021bf --- /dev/null +++ b/internal/statements/query.go @@ -0,0 +1,441 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package statements + +import ( + "errors" + "fmt" + "reflect" + "strings" + + "xorm.io/builder" + "xorm.io/xorm/schemas" +) + +func (statement *Statement) GenQuerySQL(sqlOrArgs ...interface{}) (string, []interface{}, error) { + if len(sqlOrArgs) > 0 { + return statement.ConvertSQLOrArgs(sqlOrArgs...) + } + + if statement.RawSQL != "" { + return statement.GenRawSQL(), statement.RawParams, nil + } + + if len(statement.TableName()) <= 0 { + return "", nil, ErrTableNotFound + } + + var columnStr = statement.ColumnStr() + if len(statement.SelectStr) > 0 { + columnStr = statement.SelectStr + } else { + if statement.JoinStr == "" { + if columnStr == "" { + if statement.GroupByStr != "" { + columnStr = statement.quoteColumnStr(statement.GroupByStr) + } else { + columnStr = statement.genColumnStr() + } + } + } else { + if columnStr == "" { + if statement.GroupByStr != "" { + columnStr = statement.quoteColumnStr(statement.GroupByStr) + } else { + columnStr = "*" + } + } + } + if columnStr == "" { + columnStr = "*" + } + } + + if err := statement.ProcessIDParam(); err != nil { + return "", nil, err + } + + sqlStr, condArgs, err := statement.genSelectSQL(columnStr, true, true) + if err != nil { + return "", nil, err + } + args := append(statement.joinArgs, condArgs...) + + // for mssql and use limit + qs := strings.Count(sqlStr, "?") + if len(args)*2 == qs { + args = append(args, args...) + } + + return sqlStr, args, nil +} + +func (statement *Statement) GenSumSQL(bean interface{}, columns ...string) (string, []interface{}, error) { + if statement.RawSQL != "" { + return statement.GenRawSQL(), statement.RawParams, nil + } + + statement.SetRefBean(bean) + + var sumStrs = make([]string, 0, len(columns)) + for _, colName := range columns { + if !strings.Contains(colName, " ") && !strings.Contains(colName, "(") { + colName = statement.quote(colName) + } else { + colName = statement.ReplaceQuote(colName) + } + sumStrs = append(sumStrs, fmt.Sprintf("COALESCE(sum(%s),0)", colName)) + } + sumSelect := strings.Join(sumStrs, ", ") + + if err := statement.mergeConds(bean); err != nil { + return "", nil, err + } + + sqlStr, condArgs, err := statement.genSelectSQL(sumSelect, true, true) + if err != nil { + return "", nil, err + } + + return sqlStr, append(statement.joinArgs, condArgs...), nil +} + +func (statement *Statement) GenGetSQL(bean interface{}) (string, []interface{}, error) { + v := rValue(bean) + isStruct := v.Kind() == reflect.Struct + if isStruct { + statement.SetRefBean(bean) + } + + var columnStr = statement.ColumnStr() + if len(statement.SelectStr) > 0 { + columnStr = statement.SelectStr + } else { + // TODO: always generate column names, not use * even if join + if len(statement.JoinStr) == 0 { + if len(columnStr) == 0 { + if len(statement.GroupByStr) > 0 { + columnStr = statement.quoteColumnStr(statement.GroupByStr) + } else { + columnStr = statement.genColumnStr() + } + } + } else { + if len(columnStr) == 0 { + if len(statement.GroupByStr) > 0 { + columnStr = statement.quoteColumnStr(statement.GroupByStr) + } + } + } + } + + if len(columnStr) == 0 { + columnStr = "*" + } + + if isStruct { + if err := statement.mergeConds(bean); err != nil { + return "", nil, err + } + } else { + if err := statement.ProcessIDParam(); err != nil { + return "", nil, err + } + } + + sqlStr, condArgs, err := statement.genSelectSQL(columnStr, true, true) + if err != nil { + return "", nil, err + } + + return sqlStr, append(statement.joinArgs, condArgs...), nil +} + +// GenCountSQL generates the SQL for counting +func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interface{}, error) { + if statement.RawSQL != "" { + return statement.GenRawSQL(), statement.RawParams, nil + } + + var condArgs []interface{} + var err error + if len(beans) > 0 { + statement.SetRefBean(beans[0]) + if err := statement.mergeConds(beans[0]); err != nil { + return "", nil, err + } + } + + var selectSQL = statement.SelectStr + if len(selectSQL) <= 0 { + if statement.IsDistinct { + selectSQL = fmt.Sprintf("count(DISTINCT %s)", statement.ColumnStr()) + } else if statement.ColumnStr() != "" { + selectSQL = fmt.Sprintf("count(%s)", statement.ColumnStr()) + } else { + selectSQL = "count(*)" + } + } + sqlStr, condArgs, err := statement.genSelectSQL(selectSQL, false, false) + if err != nil { + return "", nil, err + } + + return sqlStr, append(statement.joinArgs, condArgs...), nil +} + +func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderBy bool) (string, []interface{}, error) { + var ( + distinct string + dialect = statement.dialect + quote = statement.quote + fromStr = " FROM " + top, mssqlCondi, whereStr string + ) + if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") { + distinct = "DISTINCT " + } + + condSQL, condArgs, err := statement.GenCondSQL(statement.cond) + if err != nil { + return "", nil, err + } + if len(condSQL) > 0 { + whereStr = " WHERE " + condSQL + } + + if dialect.URI().DBType == schemas.MSSQL && strings.Contains(statement.TableName(), "..") { + fromStr += statement.TableName() + } else { + fromStr += quote(statement.TableName()) + } + + if statement.TableAlias != "" { + if dialect.URI().DBType == schemas.ORACLE { + fromStr += " " + quote(statement.TableAlias) + } else { + fromStr += " AS " + quote(statement.TableAlias) + } + } + if statement.JoinStr != "" { + fromStr = fmt.Sprintf("%v %v", fromStr, statement.JoinStr) + } + + pLimitN := statement.LimitN + if dialect.URI().DBType == schemas.MSSQL { + if pLimitN != nil { + LimitNValue := *pLimitN + top = fmt.Sprintf("TOP %d ", LimitNValue) + } + if statement.Start > 0 { + var column string + if len(statement.RefTable.PKColumns()) == 0 { + for _, index := range statement.RefTable.Indexes { + if len(index.Cols) == 1 { + column = index.Cols[0] + break + } + } + if len(column) == 0 { + column = statement.RefTable.ColumnsSeq()[0] + } + } else { + column = statement.RefTable.PKColumns()[0].Name + } + if statement.needTableName() { + if len(statement.TableAlias) > 0 { + column = statement.TableAlias + "." + column + } else { + column = statement.TableName() + "." + column + } + } + + var orderStr string + if needOrderBy && len(statement.OrderStr) > 0 { + orderStr = " ORDER BY " + statement.OrderStr + } + + var groupStr string + if len(statement.GroupByStr) > 0 { + groupStr = " GROUP BY " + statement.GroupByStr + } + mssqlCondi = fmt.Sprintf("(%s NOT IN (SELECT TOP %d %s%s%s%s%s))", + column, statement.Start, column, fromStr, whereStr, orderStr, groupStr) + } + } + + var buf strings.Builder + fmt.Fprintf(&buf, "SELECT %v%v%v%v%v", distinct, top, columnStr, fromStr, whereStr) + if len(mssqlCondi) > 0 { + if len(whereStr) > 0 { + fmt.Fprint(&buf, " AND ", mssqlCondi) + } else { + fmt.Fprint(&buf, " WHERE ", mssqlCondi) + } + } + + if statement.GroupByStr != "" { + fmt.Fprint(&buf, " GROUP BY ", statement.GroupByStr) + } + if statement.HavingStr != "" { + fmt.Fprint(&buf, " ", statement.HavingStr) + } + if needOrderBy && statement.OrderStr != "" { + fmt.Fprint(&buf, " ORDER BY ", statement.OrderStr) + } + if needLimit { + if dialect.URI().DBType != schemas.MSSQL && dialect.URI().DBType != schemas.ORACLE { + if statement.Start > 0 { + if pLimitN != nil { + fmt.Fprintf(&buf, " LIMIT %v OFFSET %v", *pLimitN, statement.Start) + } else { + fmt.Fprintf(&buf, "LIMIT 0 OFFSET %v", statement.Start) + } + } else if pLimitN != nil { + fmt.Fprint(&buf, " LIMIT ", *pLimitN) + } + } else if dialect.URI().DBType == schemas.ORACLE { + if statement.Start != 0 || pLimitN != nil { + oldString := buf.String() + buf.Reset() + rawColStr := columnStr + if rawColStr == "*" { + rawColStr = "at.*" + } + fmt.Fprintf(&buf, "SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", + columnStr, rawColStr, oldString, statement.Start+*pLimitN, statement.Start) + } + } + } + if statement.IsForUpdate { + return dialect.ForUpdateSQL(buf.String()), condArgs, nil + } + + return buf.String(), condArgs, nil +} + +func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interface{}, error) { + if statement.RawSQL != "" { + return statement.GenRawSQL(), statement.RawParams, nil + } + + var sqlStr string + var args []interface{} + var joinStr string + var err error + if len(bean) == 0 { + tableName := statement.TableName() + if len(tableName) <= 0 { + return "", nil, ErrTableNotFound + } + + tableName = statement.quote(tableName) + if len(statement.JoinStr) > 0 { + joinStr = statement.JoinStr + } + + if statement.Conds().IsValid() { + condSQL, condArgs, err := statement.GenCondSQL(statement.Conds()) + if err != nil { + return "", nil, err + } + + if statement.dialect.URI().DBType == schemas.MSSQL { + sqlStr = fmt.Sprintf("SELECT TOP 1 * FROM %s %s WHERE %s", tableName, joinStr, condSQL) + } else if statement.dialect.URI().DBType == schemas.ORACLE { + sqlStr = fmt.Sprintf("SELECT * FROM %s WHERE (%s) %s AND ROWNUM=1", tableName, joinStr, condSQL) + } else { + sqlStr = fmt.Sprintf("SELECT * FROM %s %s WHERE %s LIMIT 1", tableName, joinStr, condSQL) + } + args = condArgs + } else { + if statement.dialect.URI().DBType == schemas.MSSQL { + sqlStr = fmt.Sprintf("SELECT TOP 1 * FROM %s %s", tableName, joinStr) + } else if statement.dialect.URI().DBType == schemas.ORACLE { + sqlStr = fmt.Sprintf("SELECT * FROM %s %s WHERE ROWNUM=1", tableName, joinStr) + } else { + sqlStr = fmt.Sprintf("SELECT * FROM %s %s LIMIT 1", tableName, joinStr) + } + args = []interface{}{} + } + } else { + beanValue := reflect.ValueOf(bean[0]) + if beanValue.Kind() != reflect.Ptr { + return "", nil, errors.New("needs a pointer") + } + + if beanValue.Elem().Kind() == reflect.Struct { + if err := statement.SetRefBean(bean[0]); err != nil { + return "", nil, err + } + } + + if len(statement.TableName()) <= 0 { + return "", nil, ErrTableNotFound + } + statement.Limit(1) + sqlStr, args, err = statement.GenGetSQL(bean[0]) + if err != nil { + return "", nil, err + } + } + + return sqlStr, args, nil +} + +func (statement *Statement) GenFindSQL(autoCond builder.Cond) (string, []interface{}, error) { + if statement.RawSQL != "" { + return statement.GenRawSQL(), statement.RawParams, nil + } + + var sqlStr string + var args []interface{} + var err error + + if len(statement.TableName()) <= 0 { + return "", nil, ErrTableNotFound + } + + var columnStr = statement.ColumnStr() + if len(statement.SelectStr) > 0 { + columnStr = statement.SelectStr + } else { + if statement.JoinStr == "" { + if columnStr == "" { + if statement.GroupByStr != "" { + columnStr = statement.quoteColumnStr(statement.GroupByStr) + } else { + columnStr = statement.genColumnStr() + } + } + } else { + if columnStr == "" { + if statement.GroupByStr != "" { + columnStr = statement.quoteColumnStr(statement.GroupByStr) + } else { + columnStr = "*" + } + } + } + if columnStr == "" { + columnStr = "*" + } + } + + statement.cond = statement.cond.And(autoCond) + + sqlStr, condArgs, err := statement.genSelectSQL(columnStr, true, true) + if err != nil { + return "", nil, err + } + args = append(statement.joinArgs, condArgs...) + // for mssql and use limit + qs := strings.Count(sqlStr, "?") + if len(args)*2 == qs { + args = append(args, args...) + } + + return sqlStr, args, nil +} diff --git a/internal/statements/statement.go b/internal/statements/statement.go new file mode 100644 index 00000000..ed7bdaeb --- /dev/null +++ b/internal/statements/statement.go @@ -0,0 +1,995 @@ +// Copyright 2015 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package statements + +import ( + "database/sql/driver" + "errors" + "fmt" + "reflect" + "strings" + "time" + + "xorm.io/builder" + "xorm.io/xorm/contexts" + "xorm.io/xorm/convert" + "xorm.io/xorm/dialects" + "xorm.io/xorm/internal/json" + "xorm.io/xorm/internal/utils" + "xorm.io/xorm/schemas" + "xorm.io/xorm/tags" +) + +var ( + // ErrConditionType condition type unsupported + ErrConditionType = errors.New("Unsupported condition type") + // ErrUnSupportedSQLType parameter of SQL is not supported + ErrUnSupportedSQLType = errors.New("Unsupported sql type") + // ErrUnSupportedType unsupported error + ErrUnSupportedType = errors.New("Unsupported type error") + // ErrTableNotFound table not found error + ErrTableNotFound = errors.New("Table not found") +) + +// Statement save all the sql info for executing SQL +type Statement struct { + RefTable *schemas.Table + dialect dialects.Dialect + defaultTimeZone *time.Location + tagParser *tags.Parser + Start int + LimitN *int + idParam schemas.PK + OrderStr string + JoinStr string + joinArgs []interface{} + GroupByStr string + HavingStr string + SelectStr string + useAllCols bool + AltTableName string + tableName string + RawSQL string + RawParams []interface{} + UseCascade bool + UseAutoJoin bool + StoreEngine string + Charset string + UseCache bool + UseAutoTime bool + NoAutoCondition bool + IsDistinct bool + IsForUpdate bool + TableAlias string + allUseBool bool + CheckVersion bool + unscoped bool + ColumnMap columnMap + OmitColumnMap columnMap + MustColumnMap map[string]bool + NullableMap map[string]bool + IncrColumns exprParams + DecrColumns exprParams + ExprColumns exprParams + cond builder.Cond + BufferSize int + Context contexts.ContextCache + LastError error +} + +// NewStatement creates a new statement +func NewStatement(dialect dialects.Dialect, tagParser *tags.Parser, defaultTimeZone *time.Location) *Statement { + statement := &Statement{ + dialect: dialect, + tagParser: tagParser, + defaultTimeZone: defaultTimeZone, + } + statement.Reset() + return statement +} + +func (statement *Statement) SetTableName(tableName string) { + statement.tableName = tableName +} + +func (statement *Statement) omitStr() string { + return statement.dialect.Quoter().Join(statement.OmitColumnMap, " ,") +} + +// GenRawSQL generates correct raw sql +func (statement *Statement) GenRawSQL() string { + return statement.ReplaceQuote(statement.RawSQL) +} + +func (statement *Statement) GenCondSQL(condOrBuilder interface{}) (string, []interface{}, error) { + condSQL, condArgs, err := builder.ToSQL(condOrBuilder) + if err != nil { + return "", nil, err + } + return statement.ReplaceQuote(condSQL), condArgs, nil +} + +func (statement *Statement) ReplaceQuote(sql string) string { + if sql == "" || statement.dialect.URI().DBType == schemas.MYSQL || + statement.dialect.URI().DBType == schemas.SQLITE { + return sql + } + return statement.dialect.Quoter().Replace(sql) +} + +func (statement *Statement) SetContextCache(ctxCache contexts.ContextCache) { + statement.Context = ctxCache +} + +// Init reset all the statement's fields +func (statement *Statement) Reset() { + statement.RefTable = nil + statement.Start = 0 + statement.LimitN = nil + statement.OrderStr = "" + statement.UseCascade = true + statement.JoinStr = "" + statement.joinArgs = make([]interface{}, 0) + statement.GroupByStr = "" + statement.HavingStr = "" + statement.ColumnMap = columnMap{} + statement.OmitColumnMap = columnMap{} + statement.AltTableName = "" + statement.tableName = "" + statement.idParam = nil + statement.RawSQL = "" + statement.RawParams = make([]interface{}, 0) + statement.UseCache = true + statement.UseAutoTime = true + statement.NoAutoCondition = false + statement.IsDistinct = false + statement.IsForUpdate = false + statement.TableAlias = "" + statement.SelectStr = "" + statement.allUseBool = false + statement.useAllCols = false + statement.MustColumnMap = make(map[string]bool) + statement.NullableMap = make(map[string]bool) + statement.CheckVersion = true + statement.unscoped = false + statement.IncrColumns = exprParams{} + statement.DecrColumns = exprParams{} + statement.ExprColumns = exprParams{} + statement.cond = builder.NewCond() + statement.BufferSize = 0 + statement.Context = nil + statement.LastError = nil +} + +// NoAutoCondition if you do not want convert bean's field as query condition, then use this function +func (statement *Statement) SetNoAutoCondition(no ...bool) *Statement { + statement.NoAutoCondition = true + if len(no) > 0 { + statement.NoAutoCondition = no[0] + } + return statement +} + +// Alias set the table alias +func (statement *Statement) Alias(alias string) *Statement { + statement.TableAlias = alias + return statement +} + +// SQL adds raw sql statement +func (statement *Statement) SQL(query interface{}, args ...interface{}) *Statement { + switch query.(type) { + case (*builder.Builder): + var err error + statement.RawSQL, statement.RawParams, err = query.(*builder.Builder).ToSQL() + if err != nil { + statement.LastError = err + } + case string: + statement.RawSQL = query.(string) + statement.RawParams = args + default: + statement.LastError = ErrUnSupportedSQLType + } + + return statement +} + +// Where add Where statement +func (statement *Statement) Where(query interface{}, args ...interface{}) *Statement { + return statement.And(query, args...) +} + +func (statement *Statement) quote(s string) string { + return statement.dialect.Quoter().Quote(s) +} + +// And add Where & and statement +func (statement *Statement) And(query interface{}, args ...interface{}) *Statement { + switch query.(type) { + case string: + cond := builder.Expr(query.(string), args...) + statement.cond = statement.cond.And(cond) + case map[string]interface{}: + queryMap := query.(map[string]interface{}) + newMap := make(map[string]interface{}) + for k, v := range queryMap { + newMap[statement.quote(k)] = v + } + statement.cond = statement.cond.And(builder.Eq(newMap)) + case builder.Cond: + cond := query.(builder.Cond) + statement.cond = statement.cond.And(cond) + for _, v := range args { + if vv, ok := v.(builder.Cond); ok { + statement.cond = statement.cond.And(vv) + } + } + default: + statement.LastError = ErrConditionType + } + + return statement +} + +// Or add Where & Or statement +func (statement *Statement) Or(query interface{}, args ...interface{}) *Statement { + switch query.(type) { + case string: + cond := builder.Expr(query.(string), args...) + statement.cond = statement.cond.Or(cond) + case map[string]interface{}: + cond := builder.Eq(query.(map[string]interface{})) + statement.cond = statement.cond.Or(cond) + case builder.Cond: + cond := query.(builder.Cond) + statement.cond = statement.cond.Or(cond) + for _, v := range args { + if vv, ok := v.(builder.Cond); ok { + statement.cond = statement.cond.Or(vv) + } + } + default: + // TODO: not support condition type + } + return statement +} + +// In generate "Where column IN (?) " statement +func (statement *Statement) In(column string, args ...interface{}) *Statement { + in := builder.In(statement.quote(column), args...) + statement.cond = statement.cond.And(in) + return statement +} + +// NotIn generate "Where column NOT IN (?) " statement +func (statement *Statement) NotIn(column string, args ...interface{}) *Statement { + notIn := builder.NotIn(statement.quote(column), args...) + statement.cond = statement.cond.And(notIn) + return statement +} + +func (statement *Statement) SetRefValue(v reflect.Value) error { + var err error + statement.RefTable, err = statement.tagParser.ParseWithCache(reflect.Indirect(v)) + if err != nil { + return err + } + statement.tableName = dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), v, true) + return nil +} + +func rValue(bean interface{}) reflect.Value { + return reflect.Indirect(reflect.ValueOf(bean)) +} + +func (statement *Statement) SetRefBean(bean interface{}) error { + var err error + statement.RefTable, err = statement.tagParser.ParseWithCache(rValue(bean)) + if err != nil { + return err + } + statement.tableName = dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), bean, true) + return nil +} + +func (statement *Statement) needTableName() bool { + return len(statement.JoinStr) > 0 +} + +func (statement *Statement) colName(col *schemas.Column, tableName string) string { + if statement.needTableName() { + var nm = tableName + if len(statement.TableAlias) > 0 { + nm = statement.TableAlias + } + return statement.quote(nm) + "." + statement.quote(col.Name) + } + return statement.quote(col.Name) +} + +// TableName return current tableName +func (statement *Statement) TableName() string { + if statement.AltTableName != "" { + return statement.AltTableName + } + + return statement.tableName +} + +// Incr Generate "Update ... Set column = column + arg" statement +func (statement *Statement) Incr(column string, arg ...interface{}) *Statement { + if len(arg) > 0 { + statement.IncrColumns.addParam(column, arg[0]) + } else { + statement.IncrColumns.addParam(column, 1) + } + return statement +} + +// Decr Generate "Update ... Set column = column - arg" statement +func (statement *Statement) Decr(column string, arg ...interface{}) *Statement { + if len(arg) > 0 { + statement.DecrColumns.addParam(column, arg[0]) + } else { + statement.DecrColumns.addParam(column, 1) + } + return statement +} + +// SetExpr Generate "Update ... Set column = {expression}" statement +func (statement *Statement) SetExpr(column string, expression interface{}) *Statement { + if e, ok := expression.(string); ok { + statement.ExprColumns.addParam(column, statement.dialect.Quoter().Replace(e)) + } else { + statement.ExprColumns.addParam(column, expression) + } + return statement +} + +// Distinct generates "DISTINCT col1, col2 " statement +func (statement *Statement) Distinct(columns ...string) *Statement { + statement.IsDistinct = true + statement.Cols(columns...) + return statement +} + +// ForUpdate generates "SELECT ... FOR UPDATE" statement +func (statement *Statement) ForUpdate() *Statement { + statement.IsForUpdate = true + return statement +} + +// Select replace select +func (statement *Statement) Select(str string) *Statement { + statement.SelectStr = statement.ReplaceQuote(str) + return statement +} + +func col2NewCols(columns ...string) []string { + newColumns := make([]string, 0, len(columns)) + for _, col := range columns { + col = strings.Replace(col, "`", "", -1) + col = strings.Replace(col, `"`, "", -1) + ccols := strings.Split(col, ",") + for _, c := range ccols { + newColumns = append(newColumns, strings.TrimSpace(c)) + } + } + return newColumns +} + +// Cols generate "col1, col2" statement +func (statement *Statement) Cols(columns ...string) *Statement { + cols := col2NewCols(columns...) + for _, nc := range cols { + statement.ColumnMap.Add(nc) + } + return statement +} + +func (statement *Statement) ColumnStr() string { + return statement.dialect.Quoter().Join(statement.ColumnMap, ", ") +} + +// AllCols update use only: update all columns +func (statement *Statement) AllCols() *Statement { + statement.useAllCols = true + return statement +} + +// MustCols update use only: must update columns +func (statement *Statement) MustCols(columns ...string) *Statement { + newColumns := col2NewCols(columns...) + for _, nc := range newColumns { + statement.MustColumnMap[strings.ToLower(nc)] = true + } + return statement +} + +// UseBool indicates that use bool fields as update contents and query contiditions +func (statement *Statement) UseBool(columns ...string) *Statement { + if len(columns) > 0 { + statement.MustCols(columns...) + } else { + statement.allUseBool = true + } + return statement +} + +// Omit do not use the columns +func (statement *Statement) Omit(columns ...string) { + newColumns := col2NewCols(columns...) + for _, nc := range newColumns { + statement.OmitColumnMap = append(statement.OmitColumnMap, nc) + } +} + +// Nullable Update use only: update columns to null when value is nullable and zero-value +func (statement *Statement) Nullable(columns ...string) { + newColumns := col2NewCols(columns...) + for _, nc := range newColumns { + statement.NullableMap[strings.ToLower(nc)] = true + } +} + +// Top generate LIMIT limit statement +func (statement *Statement) Top(limit int) *Statement { + statement.Limit(limit) + return statement +} + +// Limit generate LIMIT start, limit statement +func (statement *Statement) Limit(limit int, start ...int) *Statement { + statement.LimitN = &limit + if len(start) > 0 { + statement.Start = start[0] + } + return statement +} + +// OrderBy generate "Order By order" statement +func (statement *Statement) OrderBy(order string) *Statement { + if len(statement.OrderStr) > 0 { + statement.OrderStr += ", " + } + statement.OrderStr += statement.ReplaceQuote(order) + return statement +} + +// Desc generate `ORDER BY xx DESC` +func (statement *Statement) Desc(colNames ...string) *Statement { + var buf strings.Builder + if len(statement.OrderStr) > 0 { + fmt.Fprint(&buf, statement.OrderStr, ", ") + } + for i, col := range colNames { + if i > 0 { + fmt.Fprint(&buf, ", ") + } + statement.dialect.Quoter().QuoteTo(&buf, col) + fmt.Fprint(&buf, " DESC") + } + statement.OrderStr = buf.String() + return statement +} + +// Asc provide asc order by query condition, the input parameters are columns. +func (statement *Statement) Asc(colNames ...string) *Statement { + var buf strings.Builder + if len(statement.OrderStr) > 0 { + fmt.Fprint(&buf, statement.OrderStr, ", ") + } + for i, col := range colNames { + if i > 0 { + fmt.Fprint(&buf, ", ") + } + statement.dialect.Quoter().QuoteTo(&buf, col) + fmt.Fprint(&buf, " ASC") + } + statement.OrderStr = buf.String() + return statement +} + +func (statement *Statement) Conds() builder.Cond { + return statement.cond +} + +// Table tempororily set table name, the parameter could be a string or a pointer of struct +func (statement *Statement) SetTable(tableNameOrBean interface{}) error { + v := rValue(tableNameOrBean) + t := v.Type() + if t.Kind() == reflect.Struct { + var err error + statement.RefTable, err = statement.tagParser.ParseWithCache(v) + if err != nil { + return err + } + } + + statement.AltTableName = dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), tableNameOrBean, true) + return nil +} + +// Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN +func (statement *Statement) Join(joinOP string, tablename interface{}, condition string, args ...interface{}) *Statement { + var buf strings.Builder + if len(statement.JoinStr) > 0 { + fmt.Fprintf(&buf, "%v %v JOIN ", statement.JoinStr, joinOP) + } else { + fmt.Fprintf(&buf, "%v JOIN ", joinOP) + } + + switch tp := tablename.(type) { + case builder.Builder: + subSQL, subQueryArgs, err := tp.ToSQL() + if err != nil { + statement.LastError = err + return statement + } + + fields := strings.Split(tp.TableName(), ".") + aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1]) + aliasName = schemas.CommonQuoter.Trim(aliasName) + + fmt.Fprintf(&buf, "(%s) %s ON %v", statement.ReplaceQuote(subSQL), aliasName, statement.ReplaceQuote(condition)) + statement.joinArgs = append(statement.joinArgs, subQueryArgs...) + case *builder.Builder: + subSQL, subQueryArgs, err := tp.ToSQL() + if err != nil { + statement.LastError = err + return statement + } + + fields := strings.Split(tp.TableName(), ".") + aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1]) + aliasName = schemas.CommonQuoter.Trim(aliasName) + + fmt.Fprintf(&buf, "(%s) %s ON %v", statement.ReplaceQuote(subSQL), aliasName, statement.ReplaceQuote(condition)) + statement.joinArgs = append(statement.joinArgs, subQueryArgs...) + default: + tbName := dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), tablename, true) + if !utils.IsSubQuery(tbName) { + var buf strings.Builder + statement.dialect.Quoter().QuoteTo(&buf, tbName) + tbName = buf.String() + } + fmt.Fprintf(&buf, "%s ON %v", tbName, statement.ReplaceQuote(condition)) + } + + statement.JoinStr = buf.String() + statement.joinArgs = append(statement.joinArgs, args...) + return statement +} + +// tbName get some table's table name +func (statement *Statement) tbNameNoSchema(table *schemas.Table) string { + if len(statement.AltTableName) > 0 { + return statement.AltTableName + } + + return table.Name +} + +// GroupBy generate "Group By keys" statement +func (statement *Statement) GroupBy(keys string) *Statement { + statement.GroupByStr = statement.ReplaceQuote(keys) + return statement +} + +// Having generate "Having conditions" statement +func (statement *Statement) Having(conditions string) *Statement { + statement.HavingStr = fmt.Sprintf("HAVING %v", statement.ReplaceQuote(conditions)) + return statement +} + +// Unscoped always disable struct tag "deleted" +func (statement *Statement) SetUnscoped() *Statement { + statement.unscoped = true + return statement +} + +func (statement *Statement) GetUnscoped() bool { + return statement.unscoped +} + +func (statement *Statement) genColumnStr() string { + if statement.RefTable == nil { + return "" + } + + var buf strings.Builder + columns := statement.RefTable.Columns() + + for _, col := range columns { + if statement.OmitColumnMap.Contain(col.Name) { + continue + } + + if len(statement.ColumnMap) > 0 && !statement.ColumnMap.Contain(col.Name) { + continue + } + + if col.MapType == schemas.ONLYTODB { + continue + } + + if buf.Len() != 0 { + buf.WriteString(", ") + } + + if statement.JoinStr != "" { + if statement.TableAlias != "" { + buf.WriteString(statement.TableAlias) + } else { + buf.WriteString(statement.TableName()) + } + + buf.WriteString(".") + } + + statement.dialect.Quoter().QuoteTo(&buf, col.Name) + } + + return buf.String() +} + +func (statement *Statement) GenCreateTableSQL() []string { + statement.RefTable.StoreEngine = statement.StoreEngine + statement.RefTable.Charset = statement.Charset + s, _ := statement.dialect.CreateTableSQL(statement.RefTable, statement.TableName()) + return s +} + +func (statement *Statement) GenIndexSQL() []string { + var sqls []string + tbName := statement.TableName() + for _, index := range statement.RefTable.Indexes { + if index.Type == schemas.IndexType { + sql := statement.dialect.CreateIndexSQL(tbName, index) + sqls = append(sqls, sql) + } + } + return sqls +} + +func uniqueName(tableName, uqeName string) string { + return fmt.Sprintf("UQE_%v_%v", tableName, uqeName) +} + +func (statement *Statement) GenUniqueSQL() []string { + var sqls []string + tbName := statement.TableName() + for _, index := range statement.RefTable.Indexes { + if index.Type == schemas.UniqueType { + sql := statement.dialect.CreateIndexSQL(tbName, index) + sqls = append(sqls, sql) + } + } + return sqls +} + +func (statement *Statement) GenDelIndexSQL() []string { + var sqls []string + tbName := statement.TableName() + idx := strings.Index(tbName, ".") + if idx > -1 { + tbName = tbName[idx+1:] + } + for _, index := range statement.RefTable.Indexes { + sqls = append(sqls, statement.dialect.DropIndexSQL(tbName, index)) + } + return sqls +} + +func (statement *Statement) buildConds2(table *schemas.Table, bean interface{}, + includeVersion bool, includeUpdated bool, includeNil bool, + includeAutoIncr bool, allUseBool bool, useAllCols bool, unscoped bool, + mustColumnMap map[string]bool, tableName, aliasName string, addedTableName bool) (builder.Cond, error) { + var conds []builder.Cond + for _, col := range table.Columns() { + if !includeVersion && col.IsVersion { + continue + } + if !includeUpdated && col.IsUpdated { + continue + } + if !includeAutoIncr && col.IsAutoIncrement { + continue + } + + if statement.dialect.URI().DBType == schemas.MSSQL && (col.SQLType.Name == schemas.Text || + col.SQLType.IsBlob() || col.SQLType.Name == schemas.TimeStampz) { + continue + } + if col.SQLType.IsJson() { + continue + } + + var colName string + if addedTableName { + var nm = tableName + if len(aliasName) > 0 { + nm = aliasName + } + colName = statement.quote(nm) + "." + statement.quote(col.Name) + } else { + colName = statement.quote(col.Name) + } + + fieldValuePtr, err := col.ValueOf(bean) + if err != nil { + if !strings.Contains(err.Error(), "is not valid") { + //engine.logger.Warn(err) + } + continue + } + + if col.IsDeleted && !unscoped { // tag "deleted" is enabled + conds = append(conds, statement.CondDeleted(col)) + } + + fieldValue := *fieldValuePtr + if fieldValue.Interface() == nil { + continue + } + + fieldType := reflect.TypeOf(fieldValue.Interface()) + requiredField := useAllCols + + if b, ok := getFlagForColumn(mustColumnMap, col); ok { + if b { + requiredField = true + } else { + continue + } + } + + if fieldType.Kind() == reflect.Ptr { + if fieldValue.IsNil() { + if includeNil { + conds = append(conds, builder.Eq{colName: nil}) + } + continue + } else if !fieldValue.IsValid() { + continue + } else { + // dereference ptr type to instance type + fieldValue = fieldValue.Elem() + fieldType = reflect.TypeOf(fieldValue.Interface()) + requiredField = true + } + } + + var val interface{} + switch fieldType.Kind() { + case reflect.Bool: + if allUseBool || requiredField { + val = fieldValue.Interface() + } else { + // if a bool in a struct, it will not be as a condition because it default is false, + // please use Where() instead + continue + } + case reflect.String: + if !requiredField && fieldValue.String() == "" { + continue + } + // for MyString, should convert to string or panic + if fieldType.String() != reflect.String.String() { + val = fieldValue.String() + } else { + val = fieldValue.Interface() + } + case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64: + if !requiredField && fieldValue.Int() == 0 { + continue + } + val = fieldValue.Interface() + case reflect.Float32, reflect.Float64: + if !requiredField && fieldValue.Float() == 0.0 { + continue + } + val = fieldValue.Interface() + case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64: + if !requiredField && fieldValue.Uint() == 0 { + continue + } + val = fieldValue.Interface() + case reflect.Struct: + if fieldType.ConvertibleTo(schemas.TimeType) { + t := fieldValue.Convert(schemas.TimeType).Interface().(time.Time) + if !requiredField && (t.IsZero() || !fieldValue.IsValid()) { + continue + } + val = dialects.FormatColumnTime(statement.dialect, statement.defaultTimeZone, col, t) + } else if _, ok := reflect.New(fieldType).Interface().(convert.Conversion); ok { + continue + } else if valNul, ok := fieldValue.Interface().(driver.Valuer); ok { + val, _ = valNul.Value() + if val == nil && !requiredField { + continue + } + } else { + if col.SQLType.IsJson() { + if col.SQLType.IsText() { + bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface()) + if err != nil { + return nil, err + } + val = string(bytes) + } else if col.SQLType.IsBlob() { + var bytes []byte + var err error + bytes, err = json.DefaultJSONHandler.Marshal(fieldValue.Interface()) + if err != nil { + return nil, err + } + val = bytes + } + } else { + table, err := statement.tagParser.ParseWithCache(fieldValue) + if err != nil { + val = fieldValue.Interface() + } else { + if len(table.PrimaryKeys) == 1 { + pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName) + // fix non-int pk issues + //if pkField.Int() != 0 { + if pkField.IsValid() && !utils.IsZero(pkField.Interface()) { + val = pkField.Interface() + } else { + continue + } + } else { + //TODO: how to handler? + return nil, fmt.Errorf("not supported %v as %v", fieldValue.Interface(), table.PrimaryKeys) + } + } + } + } + case reflect.Array: + continue + case reflect.Slice, reflect.Map: + if fieldValue == reflect.Zero(fieldType) { + continue + } + if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 { + continue + } + + if col.SQLType.IsText() { + bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface()) + if err != nil { + return nil, err + } + val = string(bytes) + } else if col.SQLType.IsBlob() { + var bytes []byte + var err error + if (fieldType.Kind() == reflect.Array || fieldType.Kind() == reflect.Slice) && + fieldType.Elem().Kind() == reflect.Uint8 { + if fieldValue.Len() > 0 { + val = fieldValue.Bytes() + } else { + continue + } + } else { + bytes, err = json.DefaultJSONHandler.Marshal(fieldValue.Interface()) + if err != nil { + return nil, err + } + val = bytes + } + } else { + continue + } + default: + val = fieldValue.Interface() + } + + conds = append(conds, builder.Eq{colName: val}) + } + + return builder.And(conds...), nil +} + +func (statement *Statement) BuildConds(table *schemas.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, addedTableName bool) (builder.Cond, error) { + return statement.buildConds2(table, bean, includeVersion, includeUpdated, includeNil, includeAutoIncr, statement.allUseBool, statement.useAllCols, + statement.unscoped, statement.MustColumnMap, statement.TableName(), statement.TableAlias, addedTableName) +} + +func (statement *Statement) mergeConds(bean interface{}) error { + if !statement.NoAutoCondition && statement.RefTable != nil { + var addedTableName = (len(statement.JoinStr) > 0) + autoCond, err := statement.BuildConds(statement.RefTable, bean, true, true, false, true, addedTableName) + if err != nil { + return err + } + statement.cond = statement.cond.And(autoCond) + } + + if err := statement.ProcessIDParam(); err != nil { + return err + } + return nil +} + +func (statement *Statement) GenConds(bean interface{}) (string, []interface{}, error) { + if err := statement.mergeConds(bean); err != nil { + return "", nil, err + } + + return statement.GenCondSQL(statement.cond) +} + +func (statement *Statement) quoteColumnStr(columnStr string) string { + columns := strings.Split(columnStr, ",") + return statement.dialect.Quoter().Join(columns, ",") +} + +func (statement *Statement) ConvertSQLOrArgs(sqlOrArgs ...interface{}) (string, []interface{}, error) { + sql, args, err := convertSQLOrArgs(sqlOrArgs...) + if err != nil { + return "", nil, err + } + return statement.ReplaceQuote(sql), args, nil +} + +func convertSQLOrArgs(sqlOrArgs ...interface{}) (string, []interface{}, error) { + switch sqlOrArgs[0].(type) { + case string: + return sqlOrArgs[0].(string), sqlOrArgs[1:], nil + case *builder.Builder: + return sqlOrArgs[0].(*builder.Builder).ToSQL() + case builder.Builder: + bd := sqlOrArgs[0].(builder.Builder) + return bd.ToSQL() + } + + return "", nil, ErrUnSupportedType +} + +func (statement *Statement) joinColumns(cols []*schemas.Column, includeTableName bool) string { + var colnames = make([]string, len(cols)) + for i, col := range cols { + if includeTableName { + colnames[i] = statement.quote(statement.TableName()) + + "." + statement.quote(col.Name) + } else { + colnames[i] = statement.quote(col.Name) + } + } + return strings.Join(colnames, ", ") +} + +// CondDeleted returns the conditions whether a record is soft deleted. +func (statement *Statement) CondDeleted(col *schemas.Column) builder.Cond { + var colName = col.Name + if statement.JoinStr != "" { + var prefix string + if statement.TableAlias != "" { + prefix = statement.TableAlias + } else { + prefix = statement.TableName() + } + colName = statement.quote(prefix) + "." + statement.quote(col.Name) + } + var cond = builder.NewCond() + if col.SQLType.IsNumeric() { + cond = builder.Eq{colName: 0} + } else { + // FIXME: mssql: The conversion of a nvarchar data type to a datetime data type resulted in an out-of-range value. + if statement.dialect.URI().DBType != schemas.MSSQL { + cond = builder.Eq{colName: utils.ZeroTime1} + } + } + + if col.Nullable { + cond = cond.Or(builder.IsNull{colName}) + } + + return cond +} diff --git a/statement_args.go b/internal/statements/statement_args.go similarity index 64% rename from statement_args.go rename to internal/statements/statement_args.go index 310f24d6..dc14467d 100644 --- a/statement_args.go +++ b/internal/statements/statement_args.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package statements import ( "fmt" @@ -11,7 +11,7 @@ import ( "time" "xorm.io/builder" - "xorm.io/core" + "xorm.io/xorm/schemas" ) func quoteNeeded(a interface{}) bool { @@ -77,30 +77,8 @@ func convertArg(arg interface{}, convertFunc func(string) string) string { const insertSelectPlaceHolder = true -func (statement *Statement) writeArg(w *builder.BytesWriter, arg interface{}) error { +func (statement *Statement) WriteArg(w *builder.BytesWriter, arg interface{}) error { switch argv := arg.(type) { - case bool: - if statement.Engine.dialect.DBType() == core.MSSQL { - if argv { - if _, err := w.WriteString("1"); err != nil { - return err - } - } else { - if _, err := w.WriteString("0"); err != nil { - return err - } - } - } else { - if argv { - if _, err := w.WriteString("true"); err != nil { - return err - } - } else { - if _, err := w.WriteString("false"); err != nil { - return err - } - } - } case *builder.Builder: if _, err := w.WriteString("("); err != nil { return err @@ -116,10 +94,18 @@ func (statement *Statement) writeArg(w *builder.BytesWriter, arg interface{}) er if err := w.WriteByte('?'); err != nil { return err } - w.Append(arg) + if v, ok := arg.(bool); ok && statement.dialect.URI().DBType == schemas.MSSQL { + if v { + w.Append(1) + } else { + w.Append(0) + } + } else { + w.Append(arg) + } } else { var convertFunc = convertStringSingleQuote - if statement.Engine.dialect.DBType() == core.MYSQL { + if statement.dialect.URI().DBType == schemas.MYSQL { convertFunc = convertString } if _, err := w.WriteString(convertArg(arg, convertFunc)); err != nil { @@ -130,9 +116,9 @@ func (statement *Statement) writeArg(w *builder.BytesWriter, arg interface{}) er return nil } -func (statement *Statement) writeArgs(w *builder.BytesWriter, args []interface{}) error { +func (statement *Statement) WriteArgs(w *builder.BytesWriter, args []interface{}) error { for i, arg := range args { - if err := statement.writeArg(w, arg); err != nil { + if err := statement.WriteArg(w, arg); err != nil { return err } @@ -144,27 +130,3 @@ func (statement *Statement) writeArgs(w *builder.BytesWriter, args []interface{} } return nil } - -func writeStrings(w *builder.BytesWriter, cols []string, leftQuote, rightQuote string) error { - for i, colName := range cols { - if len(leftQuote) > 0 && colName[0] != '`' { - if _, err := w.WriteString(leftQuote); err != nil { - return err - } - } - if _, err := w.WriteString(colName); err != nil { - return err - } - if len(rightQuote) > 0 && colName[len(colName)-1] != '`' { - if _, err := w.WriteString(rightQuote); err != nil { - return err - } - } - if i+1 != len(cols) { - if _, err := w.WriteString(","); err != nil { - return err - } - } - } - return nil -} diff --git a/statement_test.go b/internal/statements/statement_test.go similarity index 57% rename from statement_test.go rename to internal/statements/statement_test.go index acc542ab..15f446f4 100644 --- a/statement_test.go +++ b/internal/statements/statement_test.go @@ -2,17 +2,43 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package statements import ( "reflect" "strings" "testing" + "time" "github.com/stretchr/testify/assert" - "xorm.io/core" + "xorm.io/xorm/caches" + "xorm.io/xorm/dialects" + "xorm.io/xorm/names" + "xorm.io/xorm/schemas" + "xorm.io/xorm/tags" + + _ "github.com/mattn/go-sqlite3" ) +var ( + dialect dialects.Dialect + tagParser *tags.Parser +) + +func TestMain(m *testing.M) { + var err error + dialect, err = dialects.OpenDialect("sqlite3", "./test.db") + if err != nil { + panic("unknow dialect") + } + + tagParser = tags.NewParser("xorm", dialect, names.SnakeMapper{}, names.SnakeMapper{}, caches.NewManager()) + if tagParser == nil { + panic("tags parser is nil") + } + m.Run() +} + var colStrTests = []struct { omitColumn string onlyToDBColumnNdx int @@ -27,14 +53,9 @@ var colStrTests = []struct { } func TestColumnsStringGeneration(t *testing.T) { - if dbType == "postgres" || dbType == "mssql" { - return - } - - var statement *Statement - for ndx, testCase := range colStrTests { - statement = createTestStatement() + statement, err := createTestStatement() + assert.NoError(t, err) if testCase.omitColumn != "" { statement.Omit(testCase.omitColumn) @@ -42,7 +63,7 @@ func TestColumnsStringGeneration(t *testing.T) { columns := statement.RefTable.Columns() if testCase.onlyToDBColumnNdx >= 0 { - columns[testCase.onlyToDBColumnNdx].MapType = core.ONLYTODB + columns[testCase.onlyToDBColumnNdx].MapType = schemas.ONLYTODB } actual := statement.genColumnStr() @@ -51,34 +72,7 @@ func TestColumnsStringGeneration(t *testing.T) { t.Errorf("[test #%d] Unexpected columns string:\nwant:\t%s\nhave:\t%s", ndx, testCase.expected, actual) } if testCase.onlyToDBColumnNdx >= 0 { - columns[testCase.onlyToDBColumnNdx].MapType = core.TWOSIDES - } - } -} - -func BenchmarkColumnsStringGeneration(b *testing.B) { - b.StopTimer() - - statement := createTestStatement() - - testCase := colStrTests[0] - - if testCase.omitColumn != "" { - statement.Omit(testCase.omitColumn) // !nemec784! Column must be skipped - } - - if testCase.onlyToDBColumnNdx >= 0 { - columns := statement.RefTable.Columns() - columns[testCase.onlyToDBColumnNdx].MapType = core.ONLYTODB // !nemec784! Column must be skipped - } - - b.StartTimer() - - for i := 0; i < b.N; i++ { - actual := statement.genColumnStr() - - if actual != testCase.expected { - b.Errorf("Unexpected columns string:\nwant:\t%s\nhave:\t%s", testCase.expected, actual) + columns[testCase.onlyToDBColumnNdx].MapType = schemas.TWOSIDES } } } @@ -88,7 +82,7 @@ func BenchmarkGetFlagForColumnWithICKey_ContainsKey(b *testing.B) { b.StopTimer() mapCols := make(map[string]bool) - cols := []*core.Column{ + cols := []*schemas.Column{ {Name: `ID`}, {Name: `IsDeleted`}, {Name: `Caption`}, @@ -122,7 +116,7 @@ func BenchmarkGetFlagForColumnWithICKey_EmptyMap(b *testing.B) { b.StopTimer() mapCols := make(map[string]bool) - cols := []*core.Column{ + cols := []*schemas.Column{ {Name: `ID`}, {Name: `IsDeleted`}, {Name: `Caption`}, @@ -163,86 +157,40 @@ func (TestType) TableName() string { return "TestTable" } -func createTestStatement() *Statement { - if engine, ok := testEngine.(*Engine); ok { - statement := &Statement{} - statement.Init() - statement.Engine = engine - statement.setRefValue(reflect.ValueOf(TestType{})) - - return statement - } else if eg, ok := testEngine.(*EngineGroup); ok { - statement := &Statement{} - statement.Init() - statement.Engine = eg.Engine - statement.setRefValue(reflect.ValueOf(TestType{})) - - return statement +func createTestStatement() (*Statement, error) { + statement := NewStatement(dialect, tagParser, time.Local) + if err := statement.SetRefValue(reflect.ValueOf(TestType{})); err != nil { + return nil, err } - return nil + return statement, nil } -func TestDistinctAndCols(t *testing.T) { - type DistinctAndCols struct { - Id int64 - Name string +func BenchmarkColumnsStringGeneration(b *testing.B) { + b.StopTimer() + + statement, err := createTestStatement() + if err != nil { + panic(err) } - assert.NoError(t, prepareEngine()) - assertSync(t, new(DistinctAndCols)) + testCase := colStrTests[0] - cnt, err := testEngine.Insert(&DistinctAndCols{ - Name: "test", - }) - assert.NoError(t, err) - assert.EqualValues(t, 1, cnt) - - var names []string - err = testEngine.Table("distinct_and_cols").Cols("name").Distinct("name").Find(&names) - assert.NoError(t, err) - assert.EqualValues(t, 1, len(names)) - assert.EqualValues(t, "test", names[0]) -} - -func TestUpdateIgnoreOnlyFromDBFields(t *testing.T) { - type TestOnlyFromDBField struct { - Id int64 `xorm:"PK"` - OnlyFromDBField string `xorm:"<-"` - OnlyToDBField string `xorm:"->"` - IngoreField string `xorm:"-"` + if testCase.omitColumn != "" { + statement.Omit(testCase.omitColumn) // !nemec784! Column must be skipped } - assertGetRecord := func() *TestOnlyFromDBField { - var record TestOnlyFromDBField - has, err := testEngine.Where("id = ?", 1).Get(&record) - assert.NoError(t, err) - assert.EqualValues(t, true, has) - assert.EqualValues(t, "", record.OnlyFromDBField) - return &record - + if testCase.onlyToDBColumnNdx >= 0 { + columns := statement.RefTable.Columns() + columns[testCase.onlyToDBColumnNdx].MapType = schemas.ONLYTODB // !nemec784! Column must be skipped } - assert.NoError(t, prepareEngine()) - assertSync(t, new(TestOnlyFromDBField)) - _, err := testEngine.Insert(&TestOnlyFromDBField{ - Id: 1, - OnlyFromDBField: "a", - OnlyToDBField: "b", - IngoreField: "c", - }) - assert.NoError(t, err) + b.StartTimer() - record := assertGetRecord() - record.OnlyFromDBField = "test" - testEngine.Update(record) - assertGetRecord() -} - -func TestCol2NewColsWithQuote(t *testing.T) { - cols := []string{"f1", "f2", "t3.f3"} - - statement := createTestStatement() - - quotedCols := statement.col2NewColsWithQuote(cols...) - assert.EqualValues(t, []string{statement.Engine.Quote("f1"), statement.Engine.Quote("f2"), statement.Engine.Quote("t3.f3")}, quotedCols) + for i := 0; i < b.N; i++ { + actual := statement.genColumnStr() + + if actual != testCase.expected { + b.Errorf("Unexpected columns string:\nwant:\t%s\nhave:\t%s", testCase.expected, actual) + } + } } diff --git a/internal/statements/update.go b/internal/statements/update.go new file mode 100644 index 00000000..b6ae118e --- /dev/null +++ b/internal/statements/update.go @@ -0,0 +1,294 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package statements + +import ( + "database/sql/driver" + "errors" + "fmt" + "reflect" + "time" + + "xorm.io/xorm/convert" + "xorm.io/xorm/dialects" + "xorm.io/xorm/internal/json" + "xorm.io/xorm/internal/utils" + "xorm.io/xorm/schemas" +) + +func (statement *Statement) ifAddColUpdate(col *schemas.Column, includeVersion, includeUpdated, includeNil, + includeAutoIncr, update bool) (bool, error) { + columnMap := statement.ColumnMap + omitColumnMap := statement.OmitColumnMap + unscoped := statement.unscoped + + if !includeVersion && col.IsVersion { + return false, nil + } + if col.IsCreated && !columnMap.Contain(col.Name) { + return false, nil + } + if !includeUpdated && col.IsUpdated { + return false, nil + } + if !includeAutoIncr && col.IsAutoIncrement { + return false, nil + } + if col.IsDeleted && !unscoped { + return false, nil + } + if omitColumnMap.Contain(col.Name) { + return false, nil + } + if len(columnMap) > 0 && !columnMap.Contain(col.Name) { + return false, nil + } + + if col.MapType == schemas.ONLYFROMDB { + return false, nil + } + + if statement.IncrColumns.IsColExist(col.Name) { + return false, nil + } else if statement.DecrColumns.IsColExist(col.Name) { + return false, nil + } else if statement.ExprColumns.IsColExist(col.Name) { + return false, nil + } + + return true, nil +} + +// BuildUpdates auto generating update columnes and values according a struct +func (statement *Statement) BuildUpdates(tableValue reflect.Value, + includeVersion, includeUpdated, includeNil, + includeAutoIncr, update bool) ([]string, []interface{}, error) { + table := statement.RefTable + allUseBool := statement.allUseBool + useAllCols := statement.useAllCols + mustColumnMap := statement.MustColumnMap + nullableMap := statement.NullableMap + + var colNames = make([]string, 0) + var args = make([]interface{}, 0) + + for _, col := range table.Columns() { + ok, err := statement.ifAddColUpdate(col, includeVersion, includeUpdated, includeNil, + includeAutoIncr, update) + if err != nil { + return nil, nil, err + } + if !ok { + continue + } + + fieldValuePtr, err := col.ValueOfV(&tableValue) + if err != nil { + return nil, nil, err + } + + fieldValue := *fieldValuePtr + fieldType := reflect.TypeOf(fieldValue.Interface()) + if fieldType == nil { + continue + } + + requiredField := useAllCols + includeNil := useAllCols + + if b, ok := getFlagForColumn(mustColumnMap, col); ok { + if b { + requiredField = true + } else { + continue + } + } + + // !evalphobia! set fieldValue as nil when column is nullable and zero-value + if b, ok := getFlagForColumn(nullableMap, col); ok { + if b && col.Nullable && utils.IsZero(fieldValue.Interface()) { + var nilValue *int + fieldValue = reflect.ValueOf(nilValue) + fieldType = reflect.TypeOf(fieldValue.Interface()) + includeNil = true + } + } + + var val interface{} + + if fieldValue.CanAddr() { + if structConvert, ok := fieldValue.Addr().Interface().(convert.Conversion); ok { + data, err := structConvert.ToDB() + if err != nil { + return nil, nil, err + } + + val = data + goto APPEND + } + } + + if structConvert, ok := fieldValue.Interface().(convert.Conversion); ok { + data, err := structConvert.ToDB() + if err != nil { + return nil, nil, err + } + + val = data + goto APPEND + } + + if fieldType.Kind() == reflect.Ptr { + if fieldValue.IsNil() { + if includeNil { + args = append(args, nil) + colNames = append(colNames, fmt.Sprintf("%v=?", statement.quote(col.Name))) + } + continue + } else if !fieldValue.IsValid() { + continue + } else { + // dereference ptr type to instance type + fieldValue = fieldValue.Elem() + fieldType = reflect.TypeOf(fieldValue.Interface()) + requiredField = true + } + } + + switch fieldType.Kind() { + case reflect.Bool: + if allUseBool || requiredField { + val = fieldValue.Interface() + } else { + // if a bool in a struct, it will not be as a condition because it default is false, + // please use Where() instead + continue + } + case reflect.String: + if !requiredField && fieldValue.String() == "" { + continue + } + // for MyString, should convert to string or panic + if fieldType.String() != reflect.String.String() { + val = fieldValue.String() + } else { + val = fieldValue.Interface() + } + case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64: + if !requiredField && fieldValue.Int() == 0 { + continue + } + val = fieldValue.Interface() + case reflect.Float32, reflect.Float64: + if !requiredField && fieldValue.Float() == 0.0 { + continue + } + val = fieldValue.Interface() + case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64: + if !requiredField && fieldValue.Uint() == 0 { + continue + } + val = fieldValue.Interface() + case reflect.Struct: + if fieldType.ConvertibleTo(schemas.TimeType) { + t := fieldValue.Convert(schemas.TimeType).Interface().(time.Time) + if !requiredField && (t.IsZero() || !fieldValue.IsValid()) { + continue + } + val = dialects.FormatColumnTime(statement.dialect, statement.defaultTimeZone, col, t) + } else if nulType, ok := fieldValue.Interface().(driver.Valuer); ok { + val, _ = nulType.Value() + if val == nil && !requiredField { + continue + } + } else { + if !col.SQLType.IsJson() { + table, err := statement.tagParser.ParseWithCache(fieldValue) + if err != nil { + val = fieldValue.Interface() + } else { + if len(table.PrimaryKeys) == 1 { + pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName) + // fix non-int pk issues + if pkField.IsValid() && (!requiredField && !utils.IsZero(pkField.Interface())) { + val = pkField.Interface() + } else { + continue + } + } else { + return nil, nil, errors.New("Not supported multiple primary keys") + } + } + } else { + // Blank struct could not be as update data + if requiredField || !utils.IsStructZero(fieldValue) { + bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface()) + if err != nil { + return nil, nil, fmt.Errorf("mashal %v failed", fieldValue.Interface()) + } + if col.SQLType.IsText() { + val = string(bytes) + } else if col.SQLType.IsBlob() { + val = bytes + } + } else { + continue + } + } + } + case reflect.Array, reflect.Slice, reflect.Map: + if !requiredField { + if fieldValue == reflect.Zero(fieldType) { + continue + } + if fieldType.Kind() == reflect.Array { + if utils.IsArrayZero(fieldValue) { + continue + } + } else if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 { + continue + } + } + + if col.SQLType.IsText() { + bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface()) + if err != nil { + return nil, nil, err + } + val = string(bytes) + } else if col.SQLType.IsBlob() { + var bytes []byte + var err error + if fieldType.Kind() == reflect.Slice && + fieldType.Elem().Kind() == reflect.Uint8 { + if fieldValue.Len() > 0 { + val = fieldValue.Bytes() + } else { + continue + } + } else if fieldType.Kind() == reflect.Array && + fieldType.Elem().Kind() == reflect.Uint8 { + val = fieldValue.Slice(0, 0).Interface() + } else { + bytes, err = json.DefaultJSONHandler.Marshal(fieldValue.Interface()) + if err != nil { + return nil, nil, err + } + val = bytes + } + } else { + continue + } + default: + val = fieldValue.Interface() + } + + APPEND: + args = append(args, val) + colNames = append(colNames, fmt.Sprintf("%v = ?", statement.quote(col.Name))) + } + + return colNames, args, nil +} diff --git a/internal/statements/values.go b/internal/statements/values.go new file mode 100644 index 00000000..a1102c54 --- /dev/null +++ b/internal/statements/values.go @@ -0,0 +1,154 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package statements + +import ( + "database/sql" + "database/sql/driver" + "fmt" + "reflect" + "time" + + "xorm.io/xorm/convert" + "xorm.io/xorm/dialects" + "xorm.io/xorm/internal/json" + "xorm.io/xorm/schemas" +) + +var ( + nullFloatType = reflect.TypeOf(sql.NullFloat64{}) +) + +// Value2Interface convert a field value of a struct to interface for puting into database +func (statement *Statement) Value2Interface(col *schemas.Column, fieldValue reflect.Value) (interface{}, error) { + if fieldValue.CanAddr() { + if fieldConvert, ok := fieldValue.Addr().Interface().(convert.Conversion); ok { + data, err := fieldConvert.ToDB() + if err != nil { + return nil, err + } + if col.SQLType.IsBlob() { + return data, nil + } + return string(data), nil + } + } + + if fieldConvert, ok := fieldValue.Interface().(convert.Conversion); ok { + data, err := fieldConvert.ToDB() + if err != nil { + return nil, err + } + if col.SQLType.IsBlob() { + return data, nil + } + if nil == data { + return nil, nil + } + return string(data), nil + } + + fieldType := fieldValue.Type() + k := fieldType.Kind() + if k == reflect.Ptr { + if fieldValue.IsNil() { + return nil, nil + } else if !fieldValue.IsValid() { + return nil, nil + } else { + // !nashtsai! deference pointer type to instance type + fieldValue = fieldValue.Elem() + fieldType = fieldValue.Type() + k = fieldType.Kind() + } + } + + switch k { + case reflect.Bool: + return fieldValue.Bool(), nil + case reflect.String: + return fieldValue.String(), nil + case reflect.Struct: + if fieldType.ConvertibleTo(schemas.TimeType) { + t := fieldValue.Convert(schemas.TimeType).Interface().(time.Time) + tf := dialects.FormatColumnTime(statement.dialect, statement.defaultTimeZone, col, t) + return tf, nil + } else if fieldType.ConvertibleTo(nullFloatType) { + t := fieldValue.Convert(nullFloatType).Interface().(sql.NullFloat64) + if !t.Valid { + return nil, nil + } + return t.Float64, nil + } + + if !col.SQLType.IsJson() { + // !! 增加支持driver.Valuer接口的结构,如sql.NullString + if v, ok := fieldValue.Interface().(driver.Valuer); ok { + return v.Value() + } + + fieldTable, err := statement.tagParser.ParseWithCache(fieldValue) + if err != nil { + return nil, err + } + if len(fieldTable.PrimaryKeys) == 1 { + pkField := reflect.Indirect(fieldValue).FieldByName(fieldTable.PKColumns()[0].FieldName) + return pkField.Interface(), nil + } + return nil, fmt.Errorf("no primary key for col %v", col.Name) + } + + if col.SQLType.IsText() { + bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface()) + if err != nil { + return nil, err + } + return string(bytes), nil + } else if col.SQLType.IsBlob() { + bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface()) + if err != nil { + return nil, err + } + return bytes, nil + } + return nil, fmt.Errorf("Unsupported type %v", fieldValue.Type()) + case reflect.Complex64, reflect.Complex128: + bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface()) + if err != nil { + return nil, err + } + return string(bytes), nil + case reflect.Array, reflect.Slice, reflect.Map: + if !fieldValue.IsValid() { + return fieldValue.Interface(), nil + } + + if col.SQLType.IsText() { + bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface()) + if err != nil { + return nil, err + } + return string(bytes), nil + } else if col.SQLType.IsBlob() { + var bytes []byte + var err error + if (k == reflect.Slice) && + (fieldValue.Type().Elem().Kind() == reflect.Uint8) { + bytes = fieldValue.Bytes() + } else { + bytes, err = json.DefaultJSONHandler.Marshal(fieldValue.Interface()) + if err != nil { + return nil, err + } + } + return bytes, nil + } + return nil, ErrUnSupportedType + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: + return fieldValue.Uint(), nil + default: + return fieldValue.Interface(), nil + } +} diff --git a/internal/utils/name.go b/internal/utils/name.go new file mode 100644 index 00000000..f5fc3ff7 --- /dev/null +++ b/internal/utils/name.go @@ -0,0 +1,13 @@ +// Copyright 2020 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package utils + +import ( + "fmt" +) + +func IndexName(tableName, idxName string) string { + return fmt.Sprintf("IDX_%v_%v", tableName, idxName) +} diff --git a/internal/utils/reflect.go b/internal/utils/reflect.go new file mode 100644 index 00000000..3dad6bfe --- /dev/null +++ b/internal/utils/reflect.go @@ -0,0 +1,13 @@ +// Copyright 2020 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package utils + +import ( + "reflect" +) + +func ReflectValue(bean interface{}) reflect.Value { + return reflect.Indirect(reflect.ValueOf(bean)) +} diff --git a/internal/utils/slice.go b/internal/utils/slice.go new file mode 100644 index 00000000..89685706 --- /dev/null +++ b/internal/utils/slice.go @@ -0,0 +1,22 @@ +// Copyright 2020 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package utils + +import "sort" + +// SliceEq return true if two slice have the same elements even if different sort. +func SliceEq(left, right []string) bool { + if len(left) != len(right) { + return false + } + sort.Sort(sort.StringSlice(left)) + sort.Sort(sort.StringSlice(right)) + for i := 0; i < len(left); i++ { + if left[i] != right[i] { + return false + } + } + return true +} diff --git a/internal/utils/sql.go b/internal/utils/sql.go new file mode 100644 index 00000000..5e68c4a4 --- /dev/null +++ b/internal/utils/sql.go @@ -0,0 +1,19 @@ +// Copyright 2020 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package utils + +import ( + "strings" +) + +func IsSubQuery(tbName string) bool { + const selStr = "select" + if len(tbName) <= len(selStr)+1 { + return false + } + + return strings.EqualFold(tbName[:len(selStr)], selStr) || + strings.EqualFold(tbName[:len(selStr)+1], "("+selStr) +} diff --git a/internal/utils/strings.go b/internal/utils/strings.go new file mode 100644 index 00000000..b5dc37b7 --- /dev/null +++ b/internal/utils/strings.go @@ -0,0 +1,30 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package utils + +import ( + "strings" +) + +func IndexNoCase(s, sep string) int { + return strings.Index(strings.ToLower(s), strings.ToLower(sep)) +} + +func SplitNoCase(s, sep string) []string { + idx := IndexNoCase(s, sep) + if idx < 0 { + return []string{s} + } + return strings.Split(s, s[idx:idx+len(sep)]) +} + +func SplitNNoCase(s, sep string, n int) []string { + idx := IndexNoCase(s, sep) + if idx < 0 { + return []string{s} + } + return strings.SplitN(s, s[idx:idx+len(sep)], n) +} + diff --git a/internal/utils/zero.go b/internal/utils/zero.go new file mode 100644 index 00000000..8f033c60 --- /dev/null +++ b/internal/utils/zero.go @@ -0,0 +1,145 @@ +// Copyright 2020 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package utils + +import ( + "reflect" + "time" +) + +type Zeroable interface { + IsZero() bool +} + +var nilTime *time.Time + +// IsZero returns false if k is nil or has a zero value +func IsZero(k interface{}) bool { + if k == nil { + return true + } + + switch k.(type) { + case int: + return k.(int) == 0 + case int8: + return k.(int8) == 0 + case int16: + return k.(int16) == 0 + case int32: + return k.(int32) == 0 + case int64: + return k.(int64) == 0 + case uint: + return k.(uint) == 0 + case uint8: + return k.(uint8) == 0 + case uint16: + return k.(uint16) == 0 + case uint32: + return k.(uint32) == 0 + case uint64: + return k.(uint64) == 0 + case float32: + return k.(float32) == 0 + case float64: + return k.(float64) == 0 + case bool: + return k.(bool) == false + case string: + return k.(string) == "" + case *time.Time: + return k.(*time.Time) == nilTime || IsTimeZero(*k.(*time.Time)) + case time.Time: + return IsTimeZero(k.(time.Time)) + case Zeroable: + return k.(Zeroable) == nil || k.(Zeroable).IsZero() + case reflect.Value: // for go version less than 1.13 because reflect.Value has no method IsZero + return IsValueZero(k.(reflect.Value)) + } + + return IsValueZero(reflect.ValueOf(k)) +} + +var zeroType = reflect.TypeOf((*Zeroable)(nil)).Elem() + +func IsValueZero(v reflect.Value) bool { + switch v.Kind() { + case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Slice: + return v.IsNil() + case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int, reflect.Int64: + return v.Int() == 0 + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint, reflect.Uint64: + return v.Uint() == 0 + case reflect.String: + return v.Len() == 0 + case reflect.Ptr: + if v.IsNil() { + return true + } + return IsValueZero(v.Elem()) + case reflect.Struct: + return IsStructZero(v) + case reflect.Array: + return IsArrayZero(v) + } + return false +} + +func IsStructZero(v reflect.Value) bool { + if !v.IsValid() || v.NumField() == 0 { + return true + } + + if v.Type().Implements(zeroType) { + f := v.MethodByName("IsZero") + if f.IsValid() { + res := f.Call(nil) + return len(res) == 1 && res[0].Bool() + } + } + + for i := 0; i < v.NumField(); i++ { + field := v.Field(i) + switch field.Kind() { + case reflect.Ptr: + field = field.Elem() + fallthrough + case reflect.Struct: + if !IsStructZero(field) { + return false + } + default: + if field.CanInterface() && !IsZero(field.Interface()) { + return false + } + } + } + return true +} + +func IsArrayZero(v reflect.Value) bool { + if !v.IsValid() || v.Len() == 0 { + return true + } + + for i := 0; i < v.Len(); i++ { + if !IsZero(v.Index(i).Interface()) { + return false + } + } + + return true +} + +const ( + ZeroTime0 = "0000-00-00 00:00:00" + ZeroTime1 = "0001-01-01 00:00:00" +) + +func IsTimeZero(t time.Time) bool { + return t.IsZero() || t.Format("2006-01-02 15:04:05") == ZeroTime0 || + t.Format("2006-01-02 15:04:05") == ZeroTime1 +} diff --git a/internal/utils/zero_test.go b/internal/utils/zero_test.go new file mode 100644 index 00000000..a5f4912a --- /dev/null +++ b/internal/utils/zero_test.go @@ -0,0 +1,73 @@ +// Copyright 2020 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package utils + +import ( + "fmt" + "reflect" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +type MyInt int +type ZeroStruct struct{} + +func TestZero(t *testing.T) { + var zeroValues = []interface{}{ + int8(0), + int16(0), + int(0), + int32(0), + int64(0), + uint8(0), + uint16(0), + uint(0), + uint32(0), + uint64(0), + MyInt(0), + reflect.ValueOf(0), + nil, + time.Time{}, + &time.Time{}, + nilTime, + ZeroStruct{}, + &ZeroStruct{}, + } + + for _, v := range zeroValues { + t.Run(fmt.Sprintf("%#v", v), func(t *testing.T) { + assert.True(t, IsZero(v)) + }) + } +} + +func TestIsValueZero(t *testing.T) { + var zeroReflectValues = []reflect.Value{ + reflect.ValueOf(int8(0)), + reflect.ValueOf(int16(0)), + reflect.ValueOf(int(0)), + reflect.ValueOf(int32(0)), + reflect.ValueOf(int64(0)), + reflect.ValueOf(uint8(0)), + reflect.ValueOf(uint16(0)), + reflect.ValueOf(uint(0)), + reflect.ValueOf(uint32(0)), + reflect.ValueOf(uint64(0)), + reflect.ValueOf(MyInt(0)), + reflect.ValueOf(time.Time{}), + reflect.ValueOf(&time.Time{}), + reflect.ValueOf(nilTime), + reflect.ValueOf(ZeroStruct{}), + reflect.ValueOf(&ZeroStruct{}), + } + + for _, v := range zeroReflectValues { + t.Run(fmt.Sprintf("%#v", v), func(t *testing.T) { + assert.True(t, IsValueZero(v)) + }) + } +} diff --git a/logger.go b/log/logger.go similarity index 64% rename from logger.go rename to log/logger.go index 7b26e77f..eeb63693 100644 --- a/logger.go +++ b/log/logger.go @@ -2,26 +2,56 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package log import ( "fmt" "io" "log" +) - "xorm.io/core" +// LogLevel defines a log level +type LogLevel int + +// enumerate all LogLevels +const ( + // !nashtsai! following level also match syslog.Priority value + LOG_DEBUG LogLevel = iota + LOG_INFO + LOG_WARNING + LOG_ERR + LOG_OFF + LOG_UNKNOWN ) // default log options const ( DEFAULT_LOG_PREFIX = "[xorm]" DEFAULT_LOG_FLAG = log.Ldate | log.Lmicroseconds - DEFAULT_LOG_LEVEL = core.LOG_DEBUG + DEFAULT_LOG_LEVEL = LOG_DEBUG ) -var _ core.ILogger = DiscardLogger{} +// Logger is a logger interface +type Logger interface { + Debug(v ...interface{}) + Debugf(format string, v ...interface{}) + Error(v ...interface{}) + Errorf(format string, v ...interface{}) + Info(v ...interface{}) + Infof(format string, v ...interface{}) + Warn(v ...interface{}) + Warnf(format string, v ...interface{}) -// DiscardLogger don't log implementation for core.ILogger + Level() LogLevel + SetLevel(l LogLevel) + + ShowSQL(show ...bool) + IsShowSQL() bool +} + +var _ Logger = DiscardLogger{} + +// DiscardLogger don't log implementation for ILogger type DiscardLogger struct{} // Debug empty implementation @@ -49,12 +79,12 @@ func (DiscardLogger) Warn(v ...interface{}) {} func (DiscardLogger) Warnf(format string, v ...interface{}) {} // Level empty implementation -func (DiscardLogger) Level() core.LogLevel { - return core.LOG_UNKNOWN +func (DiscardLogger) Level() LogLevel { + return LOG_UNKNOWN } // SetLevel empty implementation -func (DiscardLogger) SetLevel(l core.LogLevel) {} +func (DiscardLogger) SetLevel(l LogLevel) {} // ShowSQL empty implementation func (DiscardLogger) ShowSQL(show ...bool) {} @@ -64,17 +94,17 @@ func (DiscardLogger) IsShowSQL() bool { return false } -// SimpleLogger is the default implment of core.ILogger +// SimpleLogger is the default implment of ILogger type SimpleLogger struct { DEBUG *log.Logger ERR *log.Logger INFO *log.Logger WARN *log.Logger - level core.LogLevel + level LogLevel showSQL bool } -var _ core.ILogger = &SimpleLogger{} +var _ Logger = &SimpleLogger{} // NewSimpleLogger use a special io.Writer as logger output func NewSimpleLogger(out io.Writer) *SimpleLogger { @@ -87,7 +117,7 @@ func NewSimpleLogger2(out io.Writer, prefix string, flag int) *SimpleLogger { } // NewSimpleLogger3 let you customrize your logger prefix and flag and logLevel -func NewSimpleLogger3(out io.Writer, prefix string, flag int, l core.LogLevel) *SimpleLogger { +func NewSimpleLogger3(out io.Writer, prefix string, flag int, l LogLevel) *SimpleLogger { return &SimpleLogger{ DEBUG: log.New(out, fmt.Sprintf("%s [debug] ", prefix), flag), ERR: log.New(out, fmt.Sprintf("%s [error] ", prefix), flag), @@ -97,82 +127,82 @@ func NewSimpleLogger3(out io.Writer, prefix string, flag int, l core.LogLevel) * } } -// Error implement core.ILogger +// Error implement ILogger func (s *SimpleLogger) Error(v ...interface{}) { - if s.level <= core.LOG_ERR { - s.ERR.Output(2, fmt.Sprint(v...)) + if s.level <= LOG_ERR { + s.ERR.Output(2, fmt.Sprintln(v...)) } return } -// Errorf implement core.ILogger +// Errorf implement ILogger func (s *SimpleLogger) Errorf(format string, v ...interface{}) { - if s.level <= core.LOG_ERR { + if s.level <= LOG_ERR { s.ERR.Output(2, fmt.Sprintf(format, v...)) } return } -// Debug implement core.ILogger +// Debug implement ILogger func (s *SimpleLogger) Debug(v ...interface{}) { - if s.level <= core.LOG_DEBUG { - s.DEBUG.Output(2, fmt.Sprint(v...)) + if s.level <= LOG_DEBUG { + s.DEBUG.Output(2, fmt.Sprintln(v...)) } return } -// Debugf implement core.ILogger +// Debugf implement ILogger func (s *SimpleLogger) Debugf(format string, v ...interface{}) { - if s.level <= core.LOG_DEBUG { + if s.level <= LOG_DEBUG { s.DEBUG.Output(2, fmt.Sprintf(format, v...)) } return } -// Info implement core.ILogger +// Info implement ILogger func (s *SimpleLogger) Info(v ...interface{}) { - if s.level <= core.LOG_INFO { - s.INFO.Output(2, fmt.Sprint(v...)) + if s.level <= LOG_INFO { + s.INFO.Output(2, fmt.Sprintln(v...)) } return } -// Infof implement core.ILogger +// Infof implement ILogger func (s *SimpleLogger) Infof(format string, v ...interface{}) { - if s.level <= core.LOG_INFO { + if s.level <= LOG_INFO { s.INFO.Output(2, fmt.Sprintf(format, v...)) } return } -// Warn implement core.ILogger +// Warn implement ILogger func (s *SimpleLogger) Warn(v ...interface{}) { - if s.level <= core.LOG_WARNING { - s.WARN.Output(2, fmt.Sprint(v...)) + if s.level <= LOG_WARNING { + s.WARN.Output(2, fmt.Sprintln(v...)) } return } -// Warnf implement core.ILogger +// Warnf implement ILogger func (s *SimpleLogger) Warnf(format string, v ...interface{}) { - if s.level <= core.LOG_WARNING { + if s.level <= LOG_WARNING { s.WARN.Output(2, fmt.Sprintf(format, v...)) } return } -// Level implement core.ILogger -func (s *SimpleLogger) Level() core.LogLevel { +// Level implement ILogger +func (s *SimpleLogger) Level() LogLevel { return s.level } -// SetLevel implement core.ILogger -func (s *SimpleLogger) SetLevel(l core.LogLevel) { +// SetLevel implement ILogger +func (s *SimpleLogger) SetLevel(l LogLevel) { s.level = l return } -// ShowSQL implement core.ILogger +// ShowSQL implement ILogger func (s *SimpleLogger) ShowSQL(show ...bool) { if len(show) == 0 { s.showSQL = true @@ -181,7 +211,7 @@ func (s *SimpleLogger) ShowSQL(show ...bool) { s.showSQL = show[0] } -// IsShowSQL implement core.ILogger +// IsShowSQL implement ILogger func (s *SimpleLogger) IsShowSQL() bool { return s.showSQL } diff --git a/log/logger_context.go b/log/logger_context.go new file mode 100644 index 00000000..6b7252ef --- /dev/null +++ b/log/logger_context.go @@ -0,0 +1,115 @@ +// Copyright 2020 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package log + +import ( + "fmt" + + "xorm.io/xorm/contexts" +) + +// LogContext represents a log context +type LogContext contexts.ContextHook + +// SQLLogger represents an interface to log SQL +type SQLLogger interface { + BeforeSQL(context LogContext) // only invoked when IsShowSQL is true + AfterSQL(context LogContext) // only invoked when IsShowSQL is true +} + +// ContextLogger represents a logger interface with context +type ContextLogger interface { + SQLLogger + + Debugf(format string, v ...interface{}) + Errorf(format string, v ...interface{}) + Infof(format string, v ...interface{}) + Warnf(format string, v ...interface{}) + + Level() LogLevel + SetLevel(l LogLevel) + + ShowSQL(show ...bool) + IsShowSQL() bool +} + +var ( + _ ContextLogger = &LoggerAdapter{} +) + +// enumerate all the context keys +var ( + SessionIDKey = "__xorm_session_id" + SessionShowSQLKey = "__xorm_show_sql" +) + +// LoggerAdapter wraps a Logger interface as LoggerContext interface +type LoggerAdapter struct { + logger Logger +} + +// NewLoggerAdapter creates an adapter for old xorm logger interface +func NewLoggerAdapter(logger Logger) ContextLogger { + return &LoggerAdapter{ + logger: logger, + } +} + +// BeforeSQL implements ContextLogger +func (l *LoggerAdapter) BeforeSQL(ctx LogContext) {} + +// AfterSQL implements ContextLogger +func (l *LoggerAdapter) AfterSQL(ctx LogContext) { + var sessionPart string + v := ctx.Ctx.Value(SessionIDKey) + if key, ok := v.(string); ok { + sessionPart = fmt.Sprintf(" [%s]", key) + } + if ctx.ExecuteTime > 0 { + l.logger.Infof("[SQL]%s %s %v - %v", sessionPart, ctx.SQL, ctx.Args, ctx.ExecuteTime) + } else { + l.logger.Infof("[SQL]%s %s %v", sessionPart, ctx.SQL, ctx.Args) + } +} + +// Debugf implements ContextLogger +func (l *LoggerAdapter) Debugf(format string, v ...interface{}) { + l.logger.Debugf(format, v...) +} + +// Errorf implements ContextLogger +func (l *LoggerAdapter) Errorf(format string, v ...interface{}) { + l.logger.Errorf(format, v...) +} + +// Infof implements ContextLogger +func (l *LoggerAdapter) Infof(format string, v ...interface{}) { + l.logger.Infof(format, v...) +} + +// Warnf implements ContextLogger +func (l *LoggerAdapter) Warnf(format string, v ...interface{}) { + l.logger.Warnf(format, v...) +} + +// Level implements ContextLogger +func (l *LoggerAdapter) Level() LogLevel { + return l.logger.Level() +} + +// SetLevel implements ContextLogger +func (l *LoggerAdapter) SetLevel(lv LogLevel) { + l.logger.SetLevel(lv) +} + +// ShowSQL implements ContextLogger +func (l *LoggerAdapter) ShowSQL(show ...bool) { + l.logger.ShowSQL(show...) +} + +// IsShowSQL implements ContextLogger +func (l *LoggerAdapter) IsShowSQL() bool { + return l.logger.IsShowSQL() +} diff --git a/syslogger.go b/log/syslogger.go similarity index 88% rename from syslogger.go rename to log/syslogger.go index 11ba01e7..0b3e381c 100644 --- a/syslogger.go +++ b/log/syslogger.go @@ -4,16 +4,14 @@ // +build !windows,!nacl,!plan9 -package xorm +package log import ( "fmt" "log/syslog" - - "xorm.io/core" ) -var _ core.ILogger = &SyslogLogger{} +var _ Logger = &SyslogLogger{} // SyslogLogger will be depricated type SyslogLogger struct { @@ -21,7 +19,7 @@ type SyslogLogger struct { showSQL bool } -// NewSyslogLogger implements core.ILogger +// NewSyslogLogger implements Logger func NewSyslogLogger(w *syslog.Writer) *SyslogLogger { return &SyslogLogger{w: w} } @@ -67,12 +65,12 @@ func (s *SyslogLogger) Warnf(format string, v ...interface{}) { } // Level shows log level -func (s *SyslogLogger) Level() core.LogLevel { - return core.LOG_UNKNOWN +func (s *SyslogLogger) Level() LogLevel { + return LOG_UNKNOWN } // SetLevel always return error, as current log/syslog package doesn't allow to set priority level after syslog.Writer created -func (s *SyslogLogger) SetLevel(l core.LogLevel) {} +func (s *SyslogLogger) SetLevel(l LogLevel) {} // ShowSQL set if logging SQL func (s *SyslogLogger) ShowSQL(show ...bool) { diff --git a/migrate/migrate.go b/migrate/migrate.go index ed7b401c..82c58f90 100644 --- a/migrate/migrate.go +++ b/migrate/migrate.go @@ -13,7 +13,7 @@ type MigrateFunc func(*xorm.Engine) error // RollbackFunc is the func signature for rollbacking. type RollbackFunc func(*xorm.Engine) error -// InitSchemaFunc is the func signature for initializing the schema. +// InitSchemaFunc is the func signature for initializing the schemas. type InitSchemaFunc func(*xorm.Engine) error // Options define options for all migrations. @@ -34,7 +34,7 @@ type Migration struct { Rollback RollbackFunc } -// Migrate represents a collection of all migrations of a database schema. +// Migrate represents a collection of all migrations of a database schemas. type Migrate struct { db *xorm.Engine options *Options diff --git a/migrate/migrate_test.go b/migrate/migrate_test.go index 3a52787c..19554f7e 100644 --- a/migrate/migrate_test.go +++ b/migrate/migrate_test.go @@ -6,9 +6,9 @@ import ( "os" "testing" - "xorm.io/xorm" _ "github.com/mattn/go-sqlite3" "github.com/stretchr/testify/assert" + "xorm.io/xorm" ) type Person struct { diff --git a/names/mapper.go b/names/mapper.go new file mode 100644 index 00000000..79add76e --- /dev/null +++ b/names/mapper.go @@ -0,0 +1,265 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package names + +import ( + "strings" + "sync" + "unsafe" +) + +// Mapper represents a name convertation between struct's fields name and table's column name +type Mapper interface { + Obj2Table(string) string + Table2Obj(string) string +} + +type CacheMapper struct { + oriMapper Mapper + obj2tableCache map[string]string + obj2tableMutex sync.RWMutex + table2objCache map[string]string + table2objMutex sync.RWMutex +} + +func NewCacheMapper(mapper Mapper) *CacheMapper { + return &CacheMapper{oriMapper: mapper, obj2tableCache: make(map[string]string), + table2objCache: make(map[string]string), + } +} + +func (m *CacheMapper) Obj2Table(o string) string { + m.obj2tableMutex.RLock() + t, ok := m.obj2tableCache[o] + m.obj2tableMutex.RUnlock() + if ok { + return t + } + + t = m.oriMapper.Obj2Table(o) + m.obj2tableMutex.Lock() + m.obj2tableCache[o] = t + m.obj2tableMutex.Unlock() + return t +} + +func (m *CacheMapper) Table2Obj(t string) string { + m.table2objMutex.RLock() + o, ok := m.table2objCache[t] + m.table2objMutex.RUnlock() + if ok { + return o + } + + o = m.oriMapper.Table2Obj(t) + m.table2objMutex.Lock() + m.table2objCache[t] = o + m.table2objMutex.Unlock() + return o +} + +// SameMapper implements IMapper and provides same name between struct and +// database table +type SameMapper struct { +} + +func (m SameMapper) Obj2Table(o string) string { + return o +} + +func (m SameMapper) Table2Obj(t string) string { + return t +} + +// SnakeMapper implements IMapper and provides name transaltion between +// struct and database table +type SnakeMapper struct { +} + +func b2s(b []byte) string { + return *(*string)(unsafe.Pointer(&b)) +} + +func snakeCasedName(name string) string { + newstr := make([]byte, 0, len(name)+1) + for i := 0; i < len(name); i++ { + c := name[i] + if isUpper := 'A' <= c && c <= 'Z'; isUpper { + if i > 0 { + newstr = append(newstr, '_') + } + c += 'a' - 'A' + } + newstr = append(newstr, c) + } + + return b2s(newstr) +} + +func (mapper SnakeMapper) Obj2Table(name string) string { + return snakeCasedName(name) +} + +func titleCasedName(name string) string { + newstr := make([]byte, 0, len(name)) + upNextChar := true + + name = strings.ToLower(name) + + for i := 0; i < len(name); i++ { + c := name[i] + switch { + case upNextChar: + upNextChar = false + if 'a' <= c && c <= 'z' { + c -= 'a' - 'A' + } + case c == '_': + upNextChar = true + continue + } + + newstr = append(newstr, c) + } + + return b2s(newstr) +} + +func (mapper SnakeMapper) Table2Obj(name string) string { + return titleCasedName(name) +} + +// GonicMapper implements IMapper. It will consider initialisms when mapping names. +// E.g. id -> ID, user -> User and to table names: UserID -> user_id, MyUID -> my_uid +type GonicMapper map[string]bool + +func isASCIIUpper(r rune) bool { + return 'A' <= r && r <= 'Z' +} + +func toASCIIUpper(r rune) rune { + if 'a' <= r && r <= 'z' { + r -= ('a' - 'A') + } + return r +} + +func gonicCasedName(name string) string { + newstr := make([]rune, 0, len(name)+3) + for idx, chr := range name { + if isASCIIUpper(chr) && idx > 0 { + if !isASCIIUpper(newstr[len(newstr)-1]) { + newstr = append(newstr, '_') + } + } + + if !isASCIIUpper(chr) && idx > 1 { + l := len(newstr) + if isASCIIUpper(newstr[l-1]) && isASCIIUpper(newstr[l-2]) { + newstr = append(newstr, newstr[l-1]) + newstr[l-1] = '_' + } + } + + newstr = append(newstr, chr) + } + return strings.ToLower(string(newstr)) +} + +func (mapper GonicMapper) Obj2Table(name string) string { + return gonicCasedName(name) +} + +func (mapper GonicMapper) Table2Obj(name string) string { + newstr := make([]rune, 0) + + name = strings.ToLower(name) + parts := strings.Split(name, "_") + + for _, p := range parts { + _, isInitialism := mapper[strings.ToUpper(p)] + for i, r := range p { + if i == 0 || isInitialism { + r = toASCIIUpper(r) + } + newstr = append(newstr, r) + } + } + + return string(newstr) +} + +// LintGonicMapper is A GonicMapper that contains a list of common initialisms taken from golang/lint +var LintGonicMapper = GonicMapper{ + "API": true, + "ASCII": true, + "CPU": true, + "CSS": true, + "DNS": true, + "EOF": true, + "GUID": true, + "HTML": true, + "HTTP": true, + "HTTPS": true, + "ID": true, + "IP": true, + "JSON": true, + "LHS": true, + "QPS": true, + "RAM": true, + "RHS": true, + "RPC": true, + "SLA": true, + "SMTP": true, + "SSH": true, + "TLS": true, + "TTL": true, + "UI": true, + "UID": true, + "UUID": true, + "URI": true, + "URL": true, + "UTF8": true, + "VM": true, + "XML": true, + "XSRF": true, + "XSS": true, +} + +// PrefixMapper provides prefix table name support +type PrefixMapper struct { + Mapper Mapper + Prefix string +} + +func (mapper PrefixMapper) Obj2Table(name string) string { + return mapper.Prefix + mapper.Mapper.Obj2Table(name) +} + +func (mapper PrefixMapper) Table2Obj(name string) string { + return mapper.Mapper.Table2Obj(name[len(mapper.Prefix):]) +} + +func NewPrefixMapper(mapper Mapper, prefix string) PrefixMapper { + return PrefixMapper{mapper, prefix} +} + +// SuffixMapper provides suffix table name support +type SuffixMapper struct { + Mapper Mapper + Suffix string +} + +func (mapper SuffixMapper) Obj2Table(name string) string { + return mapper.Mapper.Obj2Table(name) + mapper.Suffix +} + +func (mapper SuffixMapper) Table2Obj(name string) string { + return mapper.Mapper.Table2Obj(name[:len(name)-len(mapper.Suffix)]) +} + +func NewSuffixMapper(mapper Mapper, suffix string) SuffixMapper { + return SuffixMapper{mapper, suffix} +} diff --git a/names/mapper_test.go b/names/mapper_test.go new file mode 100644 index 00000000..a39cb569 --- /dev/null +++ b/names/mapper_test.go @@ -0,0 +1,70 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package names + +import ( + "strings" + "testing" +) + +func TestGonicMapperFromObj(t *testing.T) { + testCases := map[string]string{ + "HTTPLib": "http_lib", + "id": "id", + "ID": "id", + "IDa": "i_da", + "iDa": "i_da", + "IDAa": "id_aa", + "aID": "a_id", + "aaID": "aa_id", + "aaaID": "aaa_id", + "MyREalFunkYLONgNAME": "my_r_eal_funk_ylo_ng_name", + } + + for in, expected := range testCases { + out := gonicCasedName(in) + if out != expected { + t.Errorf("Given %s, expected %s but got %s", in, expected, out) + } + } +} + +func TestGonicMapperToObj(t *testing.T) { + testCases := map[string]string{ + "http_lib": "HTTPLib", + "id": "ID", + "ida": "Ida", + "id_aa": "IDAa", + "aa_id": "AaID", + "my_r_eal_funk_ylo_ng_name": "MyREalFunkYloNgName", + } + + for in, expected := range testCases { + out := LintGonicMapper.Table2Obj(in) + if out != expected { + t.Errorf("Given %s, expected %s but got %s", in, expected, out) + } + } +} + +func BenchmarkSnakeCasedName(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + + s := strings.Repeat("FooBar", 32) + for i := 0; i < b.N; i++ { + _ = snakeCasedName(s) + } +} + +func BenchmarkTitleCasedName(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + + s := strings.Repeat("foo_bar", 32) + for i := 0; i < b.N; i++ { + _ = titleCasedName(s) + } +} diff --git a/names/table_name.go b/names/table_name.go new file mode 100644 index 00000000..0afb1ae3 --- /dev/null +++ b/names/table_name.go @@ -0,0 +1,56 @@ +// Copyright 2020 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package names + +import ( + "reflect" + "sync" +) + +// TableName table name interface to define customerize table name +type TableName interface { + TableName() string +} + +var ( + tpTableName = reflect.TypeOf((*TableName)(nil)).Elem() + tvCache sync.Map +) + +func GetTableName(mapper Mapper, v reflect.Value) string { + if v.Type().Implements(tpTableName) { + return v.Interface().(TableName).TableName() + } + + if v.Kind() == reflect.Ptr { + v = v.Elem() + if v.Type().Implements(tpTableName) { + return v.Interface().(TableName).TableName() + } + } else if v.CanAddr() { + v1 := v.Addr() + if v1.Type().Implements(tpTableName) { + return v1.Interface().(TableName).TableName() + } + } else { + name, ok := tvCache.Load(v.Type()) + if ok { + if name.(string) != "" { + return name.(string) + } + } else { + v2 := reflect.New(v.Type()) + if v2.Type().Implements(tpTableName) { + tableName := v2.Interface().(TableName).TableName() + tvCache.Store(v.Type(), tableName) + return tableName + } + + tvCache.Store(v.Type(), "") + } + } + + return mapper.Obj2Table(v.Type().Name()) +} diff --git a/names/table_name_test.go b/names/table_name_test.go new file mode 100644 index 00000000..76da4135 --- /dev/null +++ b/names/table_name_test.go @@ -0,0 +1,140 @@ +// Copyright 2020 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package names + +import ( + "fmt" + "reflect" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +type Userinfo struct { + Uid int64 `xorm:"id pk not null autoincr"` + Username string `xorm:"unique"` + Departname string + Alias string `xorm:"-"` + Created time.Time + Detail Userdetail `xorm:"detail_id int(11)"` + Height float64 + Avatar []byte + IsMan bool +} + +type Userdetail struct { + Id int64 + Intro string `xorm:"text"` + Profile string `xorm:"varchar(2000)"` +} + +type MyGetCustomTableImpletation struct { + Id int64 `json:"id"` + Name string `json:"name"` +} + +const getCustomTableName = "GetCustomTableInterface" + +func (MyGetCustomTableImpletation) TableName() string { + return getCustomTableName +} + +type TestTableNameStruct struct{} + +const getTestTableName = "my_test_table_name_struct" + +func (t *TestTableNameStruct) TableName() string { + return getTestTableName +} + +func TestGetTableName(t *testing.T) { + var kases = []struct { + mapper Mapper + v reflect.Value + expectedTableName string + }{ + { + SnakeMapper{}, + reflect.ValueOf(new(Userinfo)), + "userinfo", + }, + { + SnakeMapper{}, + reflect.ValueOf(Userinfo{}), + "userinfo", + }, + { + SameMapper{}, + reflect.ValueOf(new(Userinfo)), + "Userinfo", + }, + { + SameMapper{}, + reflect.ValueOf(Userinfo{}), + "Userinfo", + }, + { + SnakeMapper{}, + reflect.ValueOf(new(MyGetCustomTableImpletation)), + getCustomTableName, + }, + { + SnakeMapper{}, + reflect.ValueOf(MyGetCustomTableImpletation{}), + getCustomTableName, + }, + { + SnakeMapper{}, + reflect.ValueOf(new(TestTableNameStruct)), + new(TestTableNameStruct).TableName(), + }, + { + SnakeMapper{}, + reflect.ValueOf(new(TestTableNameStruct)), + getTestTableName, + }, + { + SnakeMapper{}, + reflect.ValueOf(TestTableNameStruct{}), + getTestTableName, + }, + } + + for _, kase := range kases { + assert.EqualValues(t, kase.expectedTableName, GetTableName(kase.mapper, kase.v)) + } +} + +type OAuth2Application struct { +} + +// TableName sets the table name to `oauth2_application` +func (app *OAuth2Application) TableName() string { + return "oauth2_application" +} + +func TestGonicMapperCustomTable(t *testing.T) { + assert.EqualValues(t, "oauth2_application", + GetTableName(LintGonicMapper, reflect.ValueOf(new(OAuth2Application)))) + assert.EqualValues(t, "oauth2_application", + GetTableName(LintGonicMapper, reflect.ValueOf(OAuth2Application{}))) +} + +type MyTable struct { + Idx int +} + +func (t *MyTable) TableName() string { + return fmt.Sprintf("mytable_%d", t.Idx) +} + +func TestMyTable(t *testing.T) { + var table MyTable + for i := 0; i < 10; i++ { + table.Idx = i + assert.EqualValues(t, fmt.Sprintf("mytable_%d", i), GetTableName(SameMapper{}, reflect.ValueOf(&table))) + } +} diff --git a/processors.go b/processors.go index dcd9c6ac..8697e302 100644 --- a/processors.go +++ b/processors.go @@ -76,3 +76,69 @@ func (session *Session) executeProcessors() error { } return nil } + +func cleanupProcessorsClosures(slices *[]func(interface{})) { + if len(*slices) > 0 { + *slices = make([]func(interface{}), 0) + } +} + +func executeBeforeClosures(session *Session, bean interface{}) { + // handle before delete processors + for _, closure := range session.beforeClosures { + closure(bean) + } + cleanupProcessorsClosures(&session.beforeClosures) +} + +func executeBeforeSet(bean interface{}, fields []string, scanResults []interface{}) { + if b, hasBeforeSet := bean.(BeforeSetProcessor); hasBeforeSet { + for ii, key := range fields { + b.BeforeSet(key, Cell(scanResults[ii].(*interface{}))) + } + } +} + +func executeAfterSet(bean interface{}, fields []string, scanResults []interface{}) { + if b, hasAfterSet := bean.(AfterSetProcessor); hasAfterSet { + for ii, key := range fields { + b.AfterSet(key, Cell(scanResults[ii].(*interface{}))) + } + } +} + +func buildAfterProcessors(session *Session, bean interface{}) { + // handle afterClosures + for _, closure := range session.afterClosures { + session.afterProcessors = append(session.afterProcessors, executedProcessor{ + fun: func(sess *Session, bean interface{}) error { + closure(bean) + return nil + }, + session: session, + bean: bean, + }) + } + + if a, has := bean.(AfterLoadProcessor); has { + session.afterProcessors = append(session.afterProcessors, executedProcessor{ + fun: func(sess *Session, bean interface{}) error { + a.AfterLoad() + return nil + }, + session: session, + bean: bean, + }) + } + + if a, has := bean.(AfterLoadSessionProcessor); has { + session.afterProcessors = append(session.afterProcessors, executedProcessor{ + fun: func(sess *Session, bean interface{}) error { + a.AfterLoad(sess) + return nil + }, + session: session, + bean: bean, + }) + } +} diff --git a/rows.go b/rows.go index bdd44589..a56ea1c9 100644 --- a/rows.go +++ b/rows.go @@ -6,10 +6,13 @@ package xorm import ( "database/sql" + "errors" "fmt" "reflect" - "xorm.io/core" + "xorm.io/builder" + "xorm.io/xorm/core" + "xorm.io/xorm/internal/utils" ) // Rows rows wrapper a rows to @@ -29,7 +32,14 @@ func newRows(session *Session, bean interface{}) (*Rows, error) { var args []interface{} var err error - if err = rows.session.statement.setRefBean(bean); err != nil { + beanValue := reflect.ValueOf(bean) + if beanValue.Kind() != reflect.Ptr { + return nil, errors.New("needs a pointer to a value") + } else if beanValue.Elem().Kind() == reflect.Ptr { + return nil, errors.New("a pointer to a pointer is not allowed") + } + + if err = rows.session.statement.SetRefBean(bean); err != nil { return nil, err } @@ -38,12 +48,39 @@ func newRows(session *Session, bean interface{}) (*Rows, error) { } if rows.session.statement.RawSQL == "" { - sqlStr, args, err = rows.session.statement.genGetSQL(bean) + var autoCond builder.Cond + var addedTableName = (len(session.statement.JoinStr) > 0) + var table = rows.session.statement.RefTable + + if !session.statement.NoAutoCondition { + var err error + autoCond, err = session.statement.BuildConds(table, bean, true, true, false, true, addedTableName) + if err != nil { + return nil, err + } + } else { + // !oinume! Add " IS NULL" to WHERE whatever condiBean is given. + // See https://gitea.com/xorm/xorm/issues/179 + if col := table.DeletedColumn(); col != nil && !session.statement.GetUnscoped() { // tag "deleted" is enabled + var colName = session.engine.Quote(col.Name) + if addedTableName { + var nm = session.statement.TableName() + if len(session.statement.TableAlias) > 0 { + nm = session.statement.TableAlias + } + colName = session.engine.Quote(nm) + "." + colName + } + + autoCond = session.statement.CondDeleted(col) + } + } + + sqlStr, args, err = rows.session.statement.GenFindSQL(autoCond) if err != nil { return nil, err } } else { - sqlStr = rows.session.statement.RawSQL + sqlStr = rows.session.statement.GenRawSQL() args = rows.session.statement.RawParams } @@ -84,7 +121,7 @@ func (rows *Rows) Scan(bean interface{}) error { return fmt.Errorf("scan arg is incompatible type to [%v]", rows.beanType) } - if err := rows.session.statement.setRefBean(bean); err != nil { + if err := rows.session.statement.SetRefBean(bean); err != nil { return err } @@ -98,7 +135,7 @@ func (rows *Rows) Scan(bean interface{}) error { return err } - dataStruct := rValue(bean) + dataStruct := utils.ReflectValue(bean) _, err = rows.session.slice2Bean(scanResults, fields, bean, &dataStruct, rows.session.statement.RefTable) if err != nil { return err diff --git a/schemas/column.go b/schemas/column.go new file mode 100644 index 00000000..db66a3a6 --- /dev/null +++ b/schemas/column.go @@ -0,0 +1,133 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package schemas + +import ( + "errors" + "fmt" + "reflect" + "strconv" + "strings" + "time" +) + +const ( + TWOSIDES = iota + 1 + ONLYTODB + ONLYFROMDB +) + +// Column defines database column +type Column struct { + Name string + TableName string + FieldName string // Avaiable only when parsed from a struct + SQLType SQLType + IsJSON bool + Length int + Length2 int + Nullable bool + Default string + Indexes map[string]int + IsPrimaryKey bool + IsAutoIncrement bool + MapType int + IsCreated bool + IsUpdated bool + IsDeleted bool + IsCascade bool + IsVersion bool + DefaultIsEmpty bool // false means column has no default set, but not default value is empty + EnumOptions map[string]int + SetOptions map[string]int + DisableTimeZone bool + TimeZone *time.Location // column specified time zone + Comment string +} + +// NewColumn creates a new column +func NewColumn(name, fieldName string, sqlType SQLType, len1, len2 int, nullable bool) *Column { + return &Column{ + Name: name, + TableName: "", + FieldName: fieldName, + SQLType: sqlType, + Length: len1, + Length2: len2, + Nullable: nullable, + Default: "", + Indexes: make(map[string]int), + IsPrimaryKey: false, + IsAutoIncrement: false, + MapType: TWOSIDES, + IsCreated: false, + IsUpdated: false, + IsDeleted: false, + IsCascade: false, + IsVersion: false, + DefaultIsEmpty: true, // default should be no default + EnumOptions: make(map[string]int), + Comment: "", + } +} + +// ValueOf returns column's filed of struct's value +func (col *Column) ValueOf(bean interface{}) (*reflect.Value, error) { + dataStruct := reflect.Indirect(reflect.ValueOf(bean)) + return col.ValueOfV(&dataStruct) +} + +// ValueOfV returns column's filed of struct's value accept reflevt value +func (col *Column) ValueOfV(dataStruct *reflect.Value) (*reflect.Value, error) { + var fieldValue reflect.Value + fieldPath := strings.Split(col.FieldName, ".") + + if dataStruct.Type().Kind() == reflect.Map { + keyValue := reflect.ValueOf(fieldPath[len(fieldPath)-1]) + fieldValue = dataStruct.MapIndex(keyValue) + return &fieldValue, nil + } else if dataStruct.Type().Kind() == reflect.Interface { + structValue := reflect.ValueOf(dataStruct.Interface()) + dataStruct = &structValue + } + + level := len(fieldPath) + fieldValue = dataStruct.FieldByName(fieldPath[0]) + for i := 0; i < level-1; i++ { + if !fieldValue.IsValid() { + break + } + if fieldValue.Kind() == reflect.Struct { + fieldValue = fieldValue.FieldByName(fieldPath[i+1]) + } else if fieldValue.Kind() == reflect.Ptr { + if fieldValue.IsNil() { + fieldValue.Set(reflect.New(fieldValue.Type().Elem())) + } + fieldValue = fieldValue.Elem().FieldByName(fieldPath[i+1]) + } else { + return nil, fmt.Errorf("field %v is not valid", col.FieldName) + } + } + + if !fieldValue.IsValid() { + return nil, fmt.Errorf("field %v is not valid", col.FieldName) + } + + return &fieldValue, nil +} + +// ConvertID converts id content to suitable type according column type +func (col *Column) ConvertID(sid string) (interface{}, error) { + if col.SQLType.IsNumeric() { + n, err := strconv.ParseInt(sid, 10, 64) + if err != nil { + return nil, err + } + return n, nil + } else if col.SQLType.IsText() { + return sid, nil + } + return nil, errors.New("not supported") +} diff --git a/schemas/index.go b/schemas/index.go new file mode 100644 index 00000000..9541250f --- /dev/null +++ b/schemas/index.go @@ -0,0 +1,72 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package schemas + +import ( + "fmt" + "strings" +) + +// enumerate all index types +const ( + IndexType = iota + 1 + UniqueType +) + +// Index represents a database index +type Index struct { + IsRegular bool + Name string + Type int + Cols []string +} + +// NewIndex new an index object +func NewIndex(name string, indexType int) *Index { + return &Index{true, name, indexType, make([]string, 0)} +} + +func (index *Index) XName(tableName string) string { + if !strings.HasPrefix(index.Name, "UQE_") && + !strings.HasPrefix(index.Name, "IDX_") { + tableParts := strings.Split(strings.Replace(tableName, `"`, "", -1), ".") + tableName = tableParts[len(tableParts)-1] + if index.Type == UniqueType { + return fmt.Sprintf("UQE_%v_%v", tableName, index.Name) + } + return fmt.Sprintf("IDX_%v_%v", tableName, index.Name) + } + return index.Name +} + +// AddColumn add columns which will be composite index +func (index *Index) AddColumn(cols ...string) { + for _, col := range cols { + index.Cols = append(index.Cols, col) + } +} + +func (index *Index) Equal(dst *Index) bool { + if index.Type != dst.Type { + return false + } + if len(index.Cols) != len(dst.Cols) { + return false + } + + for i := 0; i < len(index.Cols); i++ { + var found bool + for j := 0; j < len(dst.Cols); j++ { + if index.Cols[i] == dst.Cols[j] { + found = true + break + } + } + if !found { + return false + } + } + return true +} diff --git a/schemas/pk.go b/schemas/pk.go new file mode 100644 index 00000000..03916b44 --- /dev/null +++ b/schemas/pk.go @@ -0,0 +1,41 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package schemas + +import ( + "bytes" + "encoding/gob" + + "xorm.io/xorm/internal/utils" +) + +type PK []interface{} + +func NewPK(pks ...interface{}) *PK { + p := PK(pks) + return &p +} + +func (p *PK) IsZero() bool { + for _, k := range *p { + if utils.IsZero(k) { + return true + } + } + return false +} + +func (p *PK) ToString() (string, error) { + buf := new(bytes.Buffer) + enc := gob.NewEncoder(buf) + err := enc.Encode(*p) + return buf.String(), err +} + +func (p *PK) FromString(content string) error { + dec := gob.NewDecoder(bytes.NewBufferString(content)) + err := dec.Decode(p) + return err +} diff --git a/schemas/pk_test.go b/schemas/pk_test.go new file mode 100644 index 00000000..a88b70da --- /dev/null +++ b/schemas/pk_test.go @@ -0,0 +1,36 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package schemas + +import ( + "reflect" + "testing" +) + +func TestPK(t *testing.T) { + p := NewPK(1, 3, "string") + str, err := p.ToString() + if err != nil { + t.Error(err) + } + t.Log(str) + + s := &PK{} + err = s.FromString(str) + if err != nil { + t.Error(err) + } + t.Log(s) + + if len(*p) != len(*s) { + t.Fatal("p", *p, "should be equal", *s) + } + + for i, ori := range *p { + if ori != (*s)[i] { + t.Fatal("ori", ori, reflect.ValueOf(ori), "should be equal", (*s)[i], reflect.ValueOf((*s)[i])) + } + } +} diff --git a/schemas/quote.go b/schemas/quote.go new file mode 100644 index 00000000..c44abe25 --- /dev/null +++ b/schemas/quote.go @@ -0,0 +1,240 @@ +// Copyright 2020 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package schemas + +import ( + "strings" +) + +// Quoter represents a quoter to the SQL table name and column name +type Quoter struct { + Prefix byte + Suffix byte + IsReserved func(string) bool +} + +var ( + // AlwaysFalseReverse always think it's not a reverse word + AlwaysNoReserve = func(string) bool { return false } + + // AlwaysReverse always reverse the word + AlwaysReserve = func(string) bool { return true } + + // CommanQuoteMark represnets the common quote mark + CommanQuoteMark byte = '`' + + // CommonQuoter represetns a common quoter + CommonQuoter = Quoter{CommanQuoteMark, CommanQuoteMark, AlwaysReserve} +) + +func (q Quoter) IsEmpty() bool { + return q.Prefix == 0 && q.Suffix == 0 +} + +func (q Quoter) Quote(s string) string { + var buf strings.Builder + q.QuoteTo(&buf, s) + return buf.String() +} + +// Trim removes quotes from s +func (q Quoter) Trim(s string) string { + if len(s) < 2 { + return s + } + + var buf strings.Builder + for i := 0; i < len(s); i++ { + switch { + case i == 0 && s[i] == q.Prefix: + case i == len(s)-1 && s[i] == q.Suffix: + case s[i] == q.Suffix && s[i+1] == '.': + case s[i] == q.Prefix && s[i-1] == '.': + default: + buf.WriteByte(s[i]) + } + } + return buf.String() +} + +func (q Quoter) Join(a []string, sep string) string { + var b strings.Builder + q.JoinWrite(&b, a, sep) + return b.String() +} + +func (q Quoter) JoinWrite(b *strings.Builder, a []string, sep string) error { + if len(a) == 0 { + return nil + } + + n := len(sep) * (len(a) - 1) + for i := 0; i < len(a); i++ { + n += len(a[i]) + } + + b.Grow(n) + for i, s := range a { + if i > 0 { + if _, err := b.WriteString(sep); err != nil { + return err + } + } + if s != "*" { + q.QuoteTo(b, strings.TrimSpace(s)) + } + } + return nil +} + +func findWord(v string, start int) int { + for j := start; j < len(v); j++ { + switch v[j] { + case '.', ' ': + return j + } + } + return len(v) +} + +func findStart(value string, start int) int { + if value[start] == '.' { + return start + 1 + } + if value[start] != ' ' { + return start + } + + var k = -1 + for j := start; j < len(value); j++ { + if value[j] != ' ' { + k = j + break + } + } + if k == -1 { + return len(value) + } + + if (value[k] == 'A' || value[k] == 'a') && (value[k+1] == 'S' || value[k+1] == 's') { + k = k + 2 + } + + for j := k; j < len(value); j++ { + if value[j] != ' ' { + return j + } + } + return len(value) +} + +func (q Quoter) quoteWordTo(buf *strings.Builder, word string) error { + var realWord = word + if (word[0] == CommanQuoteMark && word[len(word)-1] == CommanQuoteMark) || + (word[0] == q.Prefix && word[len(word)-1] == q.Suffix) { + realWord = word[1 : len(word)-1] + } + + if q.IsEmpty() { + _, err := buf.WriteString(realWord) + return err + } + + isReserved := q.IsReserved(realWord) + if isReserved { + if err := buf.WriteByte(q.Prefix); err != nil { + return err + } + } + if _, err := buf.WriteString(realWord); err != nil { + return err + } + if isReserved { + return buf.WriteByte(q.Suffix) + } + + return nil +} + +// QuoteTo quotes the table or column names. i.e. if the quotes are [ and ] +// name -> [name] +// `name` -> [name] +// [name] -> [name] +// schema.name -> [schema].[name] +// `schema`.`name` -> [schema].[name] +// `schema`.name -> [schema].[name] +// schema.`name` -> [schema].[name] +// [schema].name -> [schema].[name] +// schema.[name] -> [schema].[name] +// name AS a -> [name] AS a +// schema.name AS a -> [schema].[name] AS a +func (q Quoter) QuoteTo(buf *strings.Builder, value string) error { + var i int + for i < len(value) { + start := findStart(value, i) + if start > i { + if _, err := buf.WriteString(value[i:start]); err != nil { + return err + } + } + if start == len(value) { + return nil + } + + var nextEnd = findWord(value, start) + if err := q.quoteWordTo(buf, value[start:nextEnd]); err != nil { + return err + } + i = nextEnd + } + return nil +} + +// Strings quotes a slice of string +func (q Quoter) Strings(s []string) []string { + var res = make([]string, 0, len(s)) + for _, a := range s { + res = append(res, q.Quote(a)) + } + return res +} + +// Replace replaces common quote(`) as the quotes on the sql +func (q Quoter) Replace(sql string) string { + if q.IsEmpty() { + return sql + } + + var buf strings.Builder + buf.Grow(len(sql)) + + var beginSingleQuote bool + for i := 0; i < len(sql); i++ { + if !beginSingleQuote && sql[i] == CommanQuoteMark { + var j = i + 1 + for ; j < len(sql); j++ { + if sql[j] == CommanQuoteMark { + break + } + } + word := sql[i+1 : j] + isReserved := q.IsReserved(word) + if isReserved { + buf.WriteByte(q.Prefix) + } + buf.WriteString(word) + if isReserved { + buf.WriteByte(q.Suffix) + } + i = j + } else { + if sql[i] == '\'' { + beginSingleQuote = !beginSingleQuote + } + buf.WriteByte(sql[i]) + } + } + return buf.String() +} diff --git a/schemas/quote_test.go b/schemas/quote_test.go new file mode 100644 index 00000000..708b450e --- /dev/null +++ b/schemas/quote_test.go @@ -0,0 +1,181 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package schemas + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestAlwaysQuoteTo(t *testing.T) { + var ( + quoter = Quoter{'[', ']', AlwaysReserve} + kases = []struct { + expected string + value string + }{ + {"[mytable]", "mytable"}, + {"[mytable]", "`mytable`"}, + {"[mytable]", `[mytable]`}, + {`["mytable"]`, `"mytable"`}, + {"[myschema].[mytable]", "myschema.mytable"}, + {"[myschema].[mytable]", "`myschema`.mytable"}, + {"[myschema].[mytable]", "myschema.`mytable`"}, + {"[myschema].[mytable]", "`myschema`.`mytable`"}, + {"[myschema].[mytable]", `[myschema].mytable`}, + {"[myschema].[mytable]", `myschema.[mytable]`}, + {"[myschema].[mytable]", `[myschema].[mytable]`}, + {`["myschema].[mytable"]`, `"myschema.mytable"`}, + {"[message_user] AS [sender]", "`message_user` AS `sender`"}, + {"[myschema].[mytable] AS [table]", "myschema.mytable AS table"}, + {" [mytable]", " mytable"}, + {" [mytable]", " mytable"}, + {"[mytable] ", "mytable "}, + {"[mytable] ", "mytable "}, + {" [mytable] ", " mytable "}, + {" [mytable] ", " mytable "}, + } + ) + + for _, v := range kases { + t.Run(v.value, func(t *testing.T) { + buf := &strings.Builder{} + quoter.QuoteTo(buf, v.value) + assert.EqualValues(t, v.expected, buf.String()) + }) + } +} + +func TestReversedQuoteTo(t *testing.T) { + var ( + quoter = Quoter{'[', ']', func(s string) bool { + if s == "mytable" { + return true + } + return false + }} + kases = []struct { + expected string + value string + }{ + {"[mytable]", "mytable"}, + {"[mytable]", "`mytable`"}, + {"[mytable]", `[mytable]`}, + {`"mytable"`, `"mytable"`}, + {"myschema.[mytable]", "myschema.mytable"}, + {"myschema.[mytable]", "`myschema`.mytable"}, + {"myschema.[mytable]", "myschema.`mytable`"}, + {"myschema.[mytable]", "`myschema`.`mytable`"}, + {"myschema.[mytable]", `[myschema].mytable`}, + {"myschema.[mytable]", `myschema.[mytable]`}, + {"myschema.[mytable]", `[myschema].[mytable]`}, + {`"myschema.mytable"`, `"myschema.mytable"`}, + {"message_user AS sender", "`message_user` AS `sender`"}, + {"myschema.[mytable] AS table", "myschema.mytable AS table"}, + } + ) + + for _, v := range kases { + t.Run(v.value, func(t *testing.T) { + buf := &strings.Builder{} + quoter.QuoteTo(buf, v.value) + assert.EqualValues(t, v.expected, buf.String()) + }) + } +} + +func TestNoQuoteTo(t *testing.T) { + var ( + quoter = Quoter{'[', ']', AlwaysNoReserve} + kases = []struct { + expected string + value string + }{ + {"mytable", "mytable"}, + {"mytable", "`mytable`"}, + {"mytable", `[mytable]`}, + {`"mytable"`, `"mytable"`}, + {"myschema.mytable", "myschema.mytable"}, + {"myschema.mytable", "`myschema`.mytable"}, + {"myschema.mytable", "myschema.`mytable`"}, + {"myschema.mytable", "`myschema`.`mytable`"}, + {"myschema.mytable", `[myschema].mytable`}, + {"myschema.mytable", `myschema.[mytable]`}, + {"myschema.mytable", `[myschema].[mytable]`}, + {`"myschema.mytable"`, `"myschema.mytable"`}, + {"message_user AS sender", "`message_user` AS `sender`"}, + {"myschema.mytable AS table", "myschema.mytable AS table"}, + } + ) + + for _, v := range kases { + t.Run(v.value, func(t *testing.T) { + buf := &strings.Builder{} + quoter.QuoteTo(buf, v.value) + assert.EqualValues(t, v.expected, buf.String()) + }) + } +} + +func TestJoin(t *testing.T) { + cols := []string{"f1", "f2", "f3"} + quoter := Quoter{'[', ']', AlwaysReserve} + + assert.EqualValues(t, "[a],[b]", quoter.Join([]string{"a", " b"}, ",")) + + assert.EqualValues(t, "[f1], [f2], [f3]", quoter.Join(cols, ", ")) + + quoter.IsReserved = AlwaysNoReserve + assert.EqualValues(t, "f1, f2, f3", quoter.Join(cols, ", ")) +} + +func TestStrings(t *testing.T) { + cols := []string{"f1", "f2", "t3.f3"} + quoter := Quoter{'[', ']', AlwaysReserve} + + quotedCols := quoter.Strings(cols) + assert.EqualValues(t, []string{"[f1]", "[f2]", "[t3].[f3]"}, quotedCols) +} + +func TestTrim(t *testing.T) { + var kases = map[string]string{ + "[table_name]": "table_name", + "[schema].[table_name]": "schema.table_name", + } + + for src, dst := range kases { + assert.EqualValues(t, src, CommonQuoter.Trim(src)) + assert.EqualValues(t, dst, Quoter{'[', ']', AlwaysReserve}.Trim(src)) + } +} + +func TestReplace(t *testing.T) { + q := Quoter{'[', ']', AlwaysReserve} + var kases = []struct { + source string + expected string + }{ + { + "SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?", + "SELECT [COLUMN_NAME] FROM [INFORMATION_SCHEMA].[COLUMNS] WHERE [TABLE_SCHEMA] = ? AND [TABLE_NAME] = ? AND [COLUMN_NAME] = ?", + }, + { + "SELECT 'abc```test```''', `a` FROM b", + "SELECT 'abc```test```''', [a] FROM b", + }, + { + "UPDATE table SET `a` = ~ `a`, `b`='abc`'", + "UPDATE table SET [a] = ~ [a], [b]='abc`'", + }, + } + + for _, kase := range kases { + t.Run(kase.source, func(t *testing.T) { + assert.EqualValues(t, kase.expected, q.Replace(kase.source)) + }) + } +} diff --git a/schemas/table.go b/schemas/table.go new file mode 100644 index 00000000..6c57a7e3 --- /dev/null +++ b/schemas/table.go @@ -0,0 +1,195 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package schemas + +import ( + "fmt" + "reflect" + "strconv" + "strings" +) + +// Table represents a database table +type Table struct { + Name string + Type reflect.Type + columnsSeq []string + columnsMap map[string][]*Column + columns []*Column + Indexes map[string]*Index + PrimaryKeys []string + AutoIncrement string + Created map[string]bool + Updated string + Deleted string + Version string + StoreEngine string + Charset string + Comment string +} + +// NewEmptyTable creates an empty table +func NewEmptyTable() *Table { + return NewTable("", nil) +} + +// NewTable creates a new Table object +func NewTable(name string, t reflect.Type) *Table { + return &Table{Name: name, Type: t, + columnsSeq: make([]string, 0), + columns: make([]*Column, 0), + columnsMap: make(map[string][]*Column), + Indexes: make(map[string]*Index), + Created: make(map[string]bool), + PrimaryKeys: make([]string, 0), + } +} + +// Columns returns table's columns +func (table *Table) Columns() []*Column { + return table.columns +} + +// ColumnsSeq returns table's column names according sequence +func (table *Table) ColumnsSeq() []string { + return table.columnsSeq +} + +func (table *Table) columnsByName(name string) []*Column { + for k, cols := range table.columnsMap { + if strings.EqualFold(k, name) { + return cols + } + } + return nil +} + +// GetColumn returns column according column name, if column not found, return nil +func (table *Table) GetColumn(name string) *Column { + cols := table.columnsByName(name) + if cols != nil { + return cols[0] + } + + return nil +} + +// GetColumnIdx returns column according name and idx +func (table *Table) GetColumnIdx(name string, idx int) *Column { + cols := table.columnsByName(name) + if cols != nil && idx < len(cols) { + return cols[idx] + } + + return nil +} + +// PKColumns reprents all primary key columns +func (table *Table) PKColumns() []*Column { + columns := make([]*Column, len(table.PrimaryKeys)) + for i, name := range table.PrimaryKeys { + columns[i] = table.GetColumn(name) + } + return columns +} + +func (table *Table) ColumnType(name string) reflect.Type { + t, _ := table.Type.FieldByName(name) + return t.Type +} + +func (table *Table) AutoIncrColumn() *Column { + return table.GetColumn(table.AutoIncrement) +} + +func (table *Table) VersionColumn() *Column { + return table.GetColumn(table.Version) +} + +func (table *Table) UpdatedColumn() *Column { + return table.GetColumn(table.Updated) +} + +func (table *Table) DeletedColumn() *Column { + return table.GetColumn(table.Deleted) +} + +// AddColumn adds a column to table +func (table *Table) AddColumn(col *Column) { + table.columnsSeq = append(table.columnsSeq, col.Name) + table.columns = append(table.columns, col) + colName := strings.ToLower(col.Name) + if c, ok := table.columnsMap[colName]; ok { + table.columnsMap[colName] = append(c, col) + } else { + table.columnsMap[colName] = []*Column{col} + } + + if col.IsPrimaryKey { + table.PrimaryKeys = append(table.PrimaryKeys, col.Name) + } + if col.IsAutoIncrement { + table.AutoIncrement = col.Name + } + if col.IsCreated { + table.Created[col.Name] = true + } + if col.IsUpdated { + table.Updated = col.Name + } + if col.IsDeleted { + table.Deleted = col.Name + } + if col.IsVersion { + table.Version = col.Name + } +} + +// AddIndex adds an index or an unique to table +func (table *Table) AddIndex(index *Index) { + table.Indexes[index.Name] = index +} + +// IDOfV get id from one value of struct +func (table *Table) IDOfV(rv reflect.Value) (PK, error) { + v := reflect.Indirect(rv) + pk := make([]interface{}, len(table.PrimaryKeys)) + for i, col := range table.PKColumns() { + var err error + + fieldName := col.FieldName + for { + parts := strings.SplitN(fieldName, ".", 2) + if len(parts) == 1 { + break + } + + v = v.FieldByName(parts[0]) + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + if v.Kind() != reflect.Struct { + return nil, fmt.Errorf("Unsupported read value of column %s from field %s", col.Name, col.FieldName) + } + fieldName = parts[1] + } + + pkField := v.FieldByName(fieldName) + switch pkField.Kind() { + case reflect.String: + pk[i], err = col.ConvertID(pkField.String()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + pk[i], err = col.ConvertID(strconv.FormatInt(pkField.Int(), 10)) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + // id of uint will be converted to int64 + pk[i], err = col.ConvertID(strconv.FormatUint(pkField.Uint(), 10)) + } + + if err != nil { + return nil, err + } + } + return PK(pk), nil +} diff --git a/schemas/table_test.go b/schemas/table_test.go new file mode 100644 index 00000000..9bf10e33 --- /dev/null +++ b/schemas/table_test.go @@ -0,0 +1,111 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package schemas + +import ( + "strings" + "testing" +) + +var testsGetColumn = []struct { + name string + idx int +}{ + {"Id", 0}, + {"Deleted", 0}, + {"Caption", 0}, + {"Code_1", 0}, + {"Code_2", 0}, + {"Code_3", 0}, + {"Parent_Id", 0}, + {"Latitude", 0}, + {"Longitude", 0}, +} + +var table *Table + +func init() { + + table = NewEmptyTable() + + var name string + + for i := 0; i < len(testsGetColumn); i++ { + // as in Table.AddColumn func + name = strings.ToLower(testsGetColumn[i].name) + + table.columnsMap[name] = append(table.columnsMap[name], &Column{}) + } +} + +func TestGetColumn(t *testing.T) { + + for _, test := range testsGetColumn { + if table.GetColumn(test.name) == nil { + t.Error("Column not found!") + } + } +} + +func TestGetColumnIdx(t *testing.T) { + + for _, test := range testsGetColumn { + if table.GetColumnIdx(test.name, test.idx) == nil { + t.Errorf("Column %s with idx %d not found!", test.name, test.idx) + } + } +} + +func BenchmarkGetColumnWithToLower(b *testing.B) { + + for i := 0; i < b.N; i++ { + for _, test := range testsGetColumn { + + if _, ok := table.columnsMap[strings.ToLower(test.name)]; !ok { + b.Errorf("Column not found:%s", test.name) + } + } + } +} + +func BenchmarkGetColumnIdxWithToLower(b *testing.B) { + + for i := 0; i < b.N; i++ { + for _, test := range testsGetColumn { + + if c, ok := table.columnsMap[strings.ToLower(test.name)]; ok { + if test.idx < len(c) { + continue + } else { + b.Errorf("Bad idx in: %s, %d", test.name, test.idx) + } + } else { + b.Errorf("Column not found: %s, %d", test.name, test.idx) + } + } + } +} + +func BenchmarkGetColumn(b *testing.B) { + + for i := 0; i < b.N; i++ { + for _, test := range testsGetColumn { + if table.GetColumn(test.name) == nil { + b.Errorf("Column not found:%s", test.name) + } + } + } +} + +func BenchmarkGetColumnIdx(b *testing.B) { + + for i := 0; i < b.N; i++ { + for _, test := range testsGetColumn { + if table.GetColumnIdx(test.name, test.idx) == nil { + b.Errorf("Column not found:%s, %d", test.name, test.idx) + } + } + } +} diff --git a/schemas/type.go b/schemas/type.go new file mode 100644 index 00000000..89459a4d --- /dev/null +++ b/schemas/type.go @@ -0,0 +1,336 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package schemas + +import ( + "reflect" + "sort" + "strings" + "time" +) + +type DBType string + +const ( + POSTGRES DBType = "postgres" + SQLITE DBType = "sqlite3" + MYSQL DBType = "mysql" + MSSQL DBType = "mssql" + ORACLE DBType = "oracle" +) + +// SQLType represents SQL types +type SQLType struct { + Name string + DefaultLength int + DefaultLength2 int +} + +const ( + UNKNOW_TYPE = iota + TEXT_TYPE + BLOB_TYPE + TIME_TYPE + NUMERIC_TYPE + ARRAY_TYPE +) + +func (s *SQLType) IsType(st int) bool { + if t, ok := SqlTypes[s.Name]; ok && t == st { + return true + } + return false +} + +func (s *SQLType) IsText() bool { + return s.IsType(TEXT_TYPE) +} + +func (s *SQLType) IsBlob() bool { + return s.IsType(BLOB_TYPE) +} + +func (s *SQLType) IsTime() bool { + return s.IsType(TIME_TYPE) +} + +func (s *SQLType) IsNumeric() bool { + return s.IsType(NUMERIC_TYPE) +} + +func (s *SQLType) IsArray() bool { + return s.IsType(ARRAY_TYPE) +} + +func (s *SQLType) IsJson() bool { + return s.Name == Json || s.Name == Jsonb +} + +var ( + Bit = "BIT" + TinyInt = "TINYINT" + SmallInt = "SMALLINT" + MediumInt = "MEDIUMINT" + Int = "INT" + Integer = "INTEGER" + BigInt = "BIGINT" + + Enum = "ENUM" + Set = "SET" + + Char = "CHAR" + Varchar = "VARCHAR" + NChar = "NCHAR" + NVarchar = "NVARCHAR" + TinyText = "TINYTEXT" + Text = "TEXT" + NText = "NTEXT" + Clob = "CLOB" + MediumText = "MEDIUMTEXT" + LongText = "LONGTEXT" + Uuid = "UUID" + UniqueIdentifier = "UNIQUEIDENTIFIER" + SysName = "SYSNAME" + + Date = "DATE" + DateTime = "DATETIME" + SmallDateTime = "SMALLDATETIME" + Time = "TIME" + TimeStamp = "TIMESTAMP" + TimeStampz = "TIMESTAMPZ" + Year = "YEAR" + + Decimal = "DECIMAL" + Numeric = "NUMERIC" + Money = "MONEY" + SmallMoney = "SMALLMONEY" + + Real = "REAL" + Float = "FLOAT" + Double = "DOUBLE" + + Binary = "BINARY" + VarBinary = "VARBINARY" + TinyBlob = "TINYBLOB" + Blob = "BLOB" + MediumBlob = "MEDIUMBLOB" + LongBlob = "LONGBLOB" + Bytea = "BYTEA" + + Bool = "BOOL" + Boolean = "BOOLEAN" + + Serial = "SERIAL" + BigSerial = "BIGSERIAL" + + Json = "JSON" + Jsonb = "JSONB" + + Array = "ARRAY" + + SqlTypes = map[string]int{ + Bit: NUMERIC_TYPE, + TinyInt: NUMERIC_TYPE, + SmallInt: NUMERIC_TYPE, + MediumInt: NUMERIC_TYPE, + Int: NUMERIC_TYPE, + Integer: NUMERIC_TYPE, + BigInt: NUMERIC_TYPE, + + Enum: TEXT_TYPE, + Set: TEXT_TYPE, + Json: TEXT_TYPE, + Jsonb: TEXT_TYPE, + + Char: TEXT_TYPE, + NChar: TEXT_TYPE, + Varchar: TEXT_TYPE, + NVarchar: TEXT_TYPE, + TinyText: TEXT_TYPE, + Text: TEXT_TYPE, + NText: TEXT_TYPE, + MediumText: TEXT_TYPE, + LongText: TEXT_TYPE, + Uuid: TEXT_TYPE, + Clob: TEXT_TYPE, + SysName: TEXT_TYPE, + + Date: TIME_TYPE, + DateTime: TIME_TYPE, + Time: TIME_TYPE, + TimeStamp: TIME_TYPE, + TimeStampz: TIME_TYPE, + SmallDateTime: TIME_TYPE, + Year: TIME_TYPE, + + Decimal: NUMERIC_TYPE, + Numeric: NUMERIC_TYPE, + Real: NUMERIC_TYPE, + Float: NUMERIC_TYPE, + Double: NUMERIC_TYPE, + Money: NUMERIC_TYPE, + SmallMoney: NUMERIC_TYPE, + + Binary: BLOB_TYPE, + VarBinary: BLOB_TYPE, + + TinyBlob: BLOB_TYPE, + Blob: BLOB_TYPE, + MediumBlob: BLOB_TYPE, + LongBlob: BLOB_TYPE, + Bytea: BLOB_TYPE, + UniqueIdentifier: BLOB_TYPE, + + Bool: NUMERIC_TYPE, + + Serial: NUMERIC_TYPE, + BigSerial: NUMERIC_TYPE, + + Array: ARRAY_TYPE, + } + + intTypes = sort.StringSlice{"*int", "*int16", "*int32", "*int8"} + uintTypes = sort.StringSlice{"*uint", "*uint16", "*uint32", "*uint8"} +) + +// !nashtsai! treat following var as interal const values, these are used for reflect.TypeOf comparison +var ( + c_EMPTY_STRING string + c_BOOL_DEFAULT bool + c_BYTE_DEFAULT byte + c_COMPLEX64_DEFAULT complex64 + c_COMPLEX128_DEFAULT complex128 + c_FLOAT32_DEFAULT float32 + c_FLOAT64_DEFAULT float64 + c_INT64_DEFAULT int64 + c_UINT64_DEFAULT uint64 + c_INT32_DEFAULT int32 + c_UINT32_DEFAULT uint32 + c_INT16_DEFAULT int16 + c_UINT16_DEFAULT uint16 + c_INT8_DEFAULT int8 + c_UINT8_DEFAULT uint8 + c_INT_DEFAULT int + c_UINT_DEFAULT uint + c_TIME_DEFAULT time.Time +) + +var ( + IntType = reflect.TypeOf(c_INT_DEFAULT) + Int8Type = reflect.TypeOf(c_INT8_DEFAULT) + Int16Type = reflect.TypeOf(c_INT16_DEFAULT) + Int32Type = reflect.TypeOf(c_INT32_DEFAULT) + Int64Type = reflect.TypeOf(c_INT64_DEFAULT) + + UintType = reflect.TypeOf(c_UINT_DEFAULT) + Uint8Type = reflect.TypeOf(c_UINT8_DEFAULT) + Uint16Type = reflect.TypeOf(c_UINT16_DEFAULT) + Uint32Type = reflect.TypeOf(c_UINT32_DEFAULT) + Uint64Type = reflect.TypeOf(c_UINT64_DEFAULT) + + Float32Type = reflect.TypeOf(c_FLOAT32_DEFAULT) + Float64Type = reflect.TypeOf(c_FLOAT64_DEFAULT) + + Complex64Type = reflect.TypeOf(c_COMPLEX64_DEFAULT) + Complex128Type = reflect.TypeOf(c_COMPLEX128_DEFAULT) + + StringType = reflect.TypeOf(c_EMPTY_STRING) + BoolType = reflect.TypeOf(c_BOOL_DEFAULT) + ByteType = reflect.TypeOf(c_BYTE_DEFAULT) + BytesType = reflect.SliceOf(ByteType) + + TimeType = reflect.TypeOf(c_TIME_DEFAULT) +) + +var ( + PtrIntType = reflect.PtrTo(IntType) + PtrInt8Type = reflect.PtrTo(Int8Type) + PtrInt16Type = reflect.PtrTo(Int16Type) + PtrInt32Type = reflect.PtrTo(Int32Type) + PtrInt64Type = reflect.PtrTo(Int64Type) + + PtrUintType = reflect.PtrTo(UintType) + PtrUint8Type = reflect.PtrTo(Uint8Type) + PtrUint16Type = reflect.PtrTo(Uint16Type) + PtrUint32Type = reflect.PtrTo(Uint32Type) + PtrUint64Type = reflect.PtrTo(Uint64Type) + + PtrFloat32Type = reflect.PtrTo(Float32Type) + PtrFloat64Type = reflect.PtrTo(Float64Type) + + PtrComplex64Type = reflect.PtrTo(Complex64Type) + PtrComplex128Type = reflect.PtrTo(Complex128Type) + + PtrStringType = reflect.PtrTo(StringType) + PtrBoolType = reflect.PtrTo(BoolType) + PtrByteType = reflect.PtrTo(ByteType) + + PtrTimeType = reflect.PtrTo(TimeType) +) + +// Type2SQLType generate SQLType acorrding Go's type +func Type2SQLType(t reflect.Type) (st SQLType) { + switch k := t.Kind(); k { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32: + st = SQLType{Int, 0, 0} + case reflect.Int64, reflect.Uint64: + st = SQLType{BigInt, 0, 0} + case reflect.Float32: + st = SQLType{Float, 0, 0} + case reflect.Float64: + st = SQLType{Double, 0, 0} + case reflect.Complex64, reflect.Complex128: + st = SQLType{Varchar, 64, 0} + case reflect.Array, reflect.Slice, reflect.Map: + if t.Elem() == reflect.TypeOf(c_BYTE_DEFAULT) { + st = SQLType{Blob, 0, 0} + } else { + st = SQLType{Text, 0, 0} + } + case reflect.Bool: + st = SQLType{Bool, 0, 0} + case reflect.String: + st = SQLType{Varchar, 255, 0} + case reflect.Struct: + if t.ConvertibleTo(TimeType) { + st = SQLType{DateTime, 0, 0} + } else { + // TODO need to handle association struct + st = SQLType{Text, 0, 0} + } + case reflect.Ptr: + st = Type2SQLType(t.Elem()) + default: + st = SQLType{Text, 0, 0} + } + return +} + +// default sql type change to go types +func SQLType2Type(st SQLType) reflect.Type { + name := strings.ToUpper(st.Name) + switch name { + case Bit, TinyInt, SmallInt, MediumInt, Int, Integer, Serial: + return reflect.TypeOf(1) + case BigInt, BigSerial: + return reflect.TypeOf(int64(1)) + case Float, Real: + return reflect.TypeOf(float32(1)) + case Double: + return reflect.TypeOf(float64(1)) + case Char, NChar, Varchar, NVarchar, TinyText, Text, NText, MediumText, LongText, Enum, Set, Uuid, Clob, SysName: + return reflect.TypeOf("") + case TinyBlob, Blob, LongBlob, Bytea, Binary, MediumBlob, VarBinary, UniqueIdentifier: + return reflect.TypeOf([]byte{}) + case Bool: + return reflect.TypeOf(true) + case DateTime, Date, Time, TimeStamp, TimeStampz, SmallDateTime, Year: + return reflect.TypeOf(c_TIME_DEFAULT) + case Decimal, Numeric, Money, SmallMoney: + return reflect.TypeOf("") + default: + return reflect.TypeOf("") + } +} diff --git a/session.go b/session.go index 83071935..761b1415 100644 --- a/session.go +++ b/session.go @@ -6,35 +6,65 @@ package xorm import ( "context" + "crypto/rand" + "crypto/sha256" "database/sql" + "encoding/hex" "errors" "fmt" "hash/crc32" + "io" "reflect" "strings" "time" - "xorm.io/core" + "xorm.io/xorm/contexts" + "xorm.io/xorm/convert" + "xorm.io/xorm/core" + "xorm.io/xorm/internal/json" + "xorm.io/xorm/internal/statements" + "xorm.io/xorm/log" + "xorm.io/xorm/schemas" ) -type sessionType int +// ErrFieldIsNotExist columns does not exist +type ErrFieldIsNotExist struct { + FieldName string + TableName string +} + +func (e ErrFieldIsNotExist) Error() string { + return fmt.Sprintf("field %s is not valid on table %s", e.FieldName, e.TableName) +} + +// ErrFieldIsNotValid is not valid +type ErrFieldIsNotValid struct { + FieldName string + TableName string +} + +func (e ErrFieldIsNotValid) Error() string { + return fmt.Sprintf("field %s is not valid on table %s", e.FieldName, e.TableName) +} + +type sessionType bool const ( - engineSession sessionType = iota - groupSession + engineSession sessionType = false + groupSession sessionType = true ) // Session keep a pointer to sql.DB and provides all execution of all // kind of database operations. type Session struct { - db *core.DB engine *Engine tx *core.Tx - statement Statement + statement *statements.Statement isAutoCommit bool isCommitedOrRollbacked bool isAutoClose bool - + isClosed bool + prepareStmt bool // Automatically reset the statement after operations that execute a SQL // query such as Count(), Find(), Get(), ... autoResetStatement bool @@ -45,89 +75,117 @@ type Session struct { afterDeleteBeans map[interface{}]*[]func(interface{}) // -- - beforeClosures []func(interface{}) - afterClosures []func(interface{}) - + beforeClosures []func(interface{}) + afterClosures []func(interface{}) afterProcessors []executedProcessor - prepareStmt bool - stmtCache map[uint32]*core.Stmt //key: hash.Hash32 of (queryStr, len(queryStr)) + stmtCache map[uint32]*core.Stmt //key: hash.Hash32 of (queryStr, len(queryStr)) - // !evalphobia! stored the last executed query on this session - //beforeSQLExec func(string, ...interface{}) lastSQL string lastSQLArgs []interface{} - showSQL bool ctx context.Context sessionType sessionType } -// Clone copy all the session's content and return a new session -func (session *Session) Clone() *Session { - var sess = *session - return &sess +func newSessionID() string { + hash := sha256.New() + _, err := io.CopyN(hash, rand.Reader, 50) + if err != nil { + return "????????????????????" + } + md := hash.Sum(nil) + mdStr := hex.EncodeToString(md) + return mdStr[0:20] } -// Init reset the session as the init status. -func (session *Session) Init() { - session.statement.Init() - session.statement.Engine = session.engine - session.showSQL = session.engine.showSQL - session.isAutoCommit = true - session.isCommitedOrRollbacked = false - session.isAutoClose = false - session.autoResetStatement = true - session.prepareStmt = false +func newSession(engine *Engine) *Session { + var ctx context.Context + if engine.logSessionID { + ctx = context.WithValue(engine.defaultContext, log.SessionIDKey, newSessionID()) + } else { + ctx = engine.defaultContext + } - // !nashtsai! is lazy init better? - session.afterInsertBeans = make(map[interface{}]*[]func(interface{}), 0) - session.afterUpdateBeans = make(map[interface{}]*[]func(interface{}), 0) - session.afterDeleteBeans = make(map[interface{}]*[]func(interface{}), 0) - session.beforeClosures = make([]func(interface{}), 0) - session.afterClosures = make([]func(interface{}), 0) - session.stmtCache = make(map[uint32]*core.Stmt) + return &Session{ + ctx: ctx, + engine: engine, + tx: nil, + statement: statements.NewStatement( + engine.dialect, + engine.tagParser, + engine.DatabaseTZ, + ), + isClosed: false, + isAutoCommit: true, + isCommitedOrRollbacked: false, + isAutoClose: false, + autoResetStatement: true, + prepareStmt: false, - session.afterProcessors = make([]executedProcessor, 0) + afterInsertBeans: make(map[interface{}]*[]func(interface{}), 0), + afterUpdateBeans: make(map[interface{}]*[]func(interface{}), 0), + afterDeleteBeans: make(map[interface{}]*[]func(interface{}), 0), + beforeClosures: make([]func(interface{}), 0), + afterClosures: make([]func(interface{}), 0), + afterProcessors: make([]executedProcessor, 0), + stmtCache: make(map[uint32]*core.Stmt), - session.lastSQL = "" - session.lastSQLArgs = []interface{}{} + lastSQL: "", + lastSQLArgs: make([]interface{}, 0), - session.ctx = session.engine.defaultContext + sessionType: engineSession, + } } // Close release the connection from pool -func (session *Session) Close() { +func (session *Session) Close() error { for _, v := range session.stmtCache { - v.Close() + if err := v.Close(); err != nil { + return err + } } - if session.db != nil { + if !session.isClosed { // When Close be called, if session is a transaction and do not call // Commit or Rollback, then call Rollback. if session.tx != nil && !session.isCommitedOrRollbacked { - session.Rollback() + if err := session.Rollback(); err != nil { + return err + } } session.tx = nil session.stmtCache = nil - session.db = nil + session.isClosed = true } + return nil +} + +func (session *Session) db() *core.DB { + return session.engine.db +} + +func (session *Session) getQueryer() core.Queryer { + if session.tx != nil { + return session.tx + } + return session.db() } // ContextCache enable context cache or not -func (session *Session) ContextCache(context ContextCache) *Session { - session.statement.context = context +func (session *Session) ContextCache(context contexts.ContextCache) *Session { + session.statement.SetContextCache(context) return session } // IsClosed returns if session is closed func (session *Session) IsClosed() bool { - return session.db == nil + return session.isClosed } func (session *Session) resetStatement() { if session.autoResetStatement { - session.statement.Init() + session.statement.Reset() } } @@ -155,7 +213,9 @@ func (session *Session) After(closures func(interface{})) *Session { // Table can input a string or pointer to struct for special a table to operate. func (session *Session) Table(tableNameOrBean interface{}) *Session { - session.statement.Table(tableNameOrBean) + if err := session.statement.SetTable(tableNameOrBean); err != nil { + session.statement.LastError = err + } return session } @@ -179,7 +239,7 @@ func (session *Session) ForUpdate() *Session { // NoAutoCondition disable generate SQL condition from beans func (session *Session) NoAutoCondition(no ...bool) *Session { - session.statement.NoAutoCondition(no...) + session.statement.SetNoAutoCondition(no...) return session } @@ -229,12 +289,12 @@ func (session *Session) Cascade(trueOrFalse ...bool) *Session { } // MustLogSQL means record SQL or not and don't follow engine's setting -func (session *Session) MustLogSQL(log ...bool) *Session { - if len(log) > 0 { - session.showSQL = log[0] - } else { - session.showSQL = true +func (session *Session) MustLogSQL(logs ...bool) *Session { + var showSQL = true + if len(logs) > 0 { + showSQL = logs[0] } + session.ctx = context.WithValue(session.ctx, log.SessionShowSQLKey, showSQL) return session } @@ -265,17 +325,7 @@ func (session *Session) Having(conditions string) *Session { // DB db return the wrapper of sql.DB func (session *Session) DB() *core.DB { - if session.db == nil { - session.db = session.engine.db - session.stmtCache = make(map[uint32]*core.Stmt, 0) - } - return session.db -} - -func cleanupProcessorsClosures(slices *[]func(interface{})) { - if len(*slices) > 0 { - *slices = make([]func(interface{}), 0) - } + return session.db() } func (session *Session) canCache() bool { @@ -285,7 +335,7 @@ func (session *Session) canCache() bool { !session.statement.UseCache || session.statement.IsForUpdate || session.tx != nil || - len(session.statement.selectStr) > 0 { + len(session.statement.SelectStr) > 0 { return false } return true @@ -306,8 +356,8 @@ func (session *Session) doPrepare(db *core.DB, sqlStr string) (stmt *core.Stmt, return } -func (session *Session) getField(dataStruct *reflect.Value, key string, table *core.Table, idx int) (*reflect.Value, error) { - var col *core.Column +func (session *Session) getField(dataStruct *reflect.Value, key string, table *schemas.Table, idx int) (*reflect.Value, error) { + var col *schemas.Column if col = table.GetColumnIdx(key, idx); col == nil { return nil, ErrFieldIsNotExist{key, table.Name} } @@ -328,8 +378,8 @@ func (session *Session) getField(dataStruct *reflect.Value, key string, table *c type Cell *interface{} func (session *Session) rows2Beans(rows *core.Rows, fields []string, - table *core.Table, newElemFunc func([]string) reflect.Value, - sliceValueSetFunc func(*reflect.Value, core.PK) error) error { + table *schemas.Table, newElemFunc func([]string) reflect.Value, + sliceValueSetFunc func(*reflect.Value, schemas.PK) error) error { for rows.Next() { var newValue = newElemFunc(fields) bean := newValue.Interface() @@ -369,59 +419,20 @@ func (session *Session) row2Slice(rows *core.Rows, fields []string, bean interfa return nil, err } - if b, hasBeforeSet := bean.(BeforeSetProcessor); hasBeforeSet { - for ii, key := range fields { - b.BeforeSet(key, Cell(scanResults[ii].(*interface{}))) - } - } + executeBeforeSet(bean, fields, scanResults) + return scanResults, nil } -func (session *Session) slice2Bean(scanResults []interface{}, fields []string, bean interface{}, dataStruct *reflect.Value, table *core.Table) (core.PK, error) { +func (session *Session) slice2Bean(scanResults []interface{}, fields []string, bean interface{}, dataStruct *reflect.Value, table *schemas.Table) (schemas.PK, error) { defer func() { - if b, hasAfterSet := bean.(AfterSetProcessor); hasAfterSet { - for ii, key := range fields { - b.AfterSet(key, Cell(scanResults[ii].(*interface{}))) - } - } + executeAfterSet(bean, fields, scanResults) }() - // handle afterClosures - for _, closure := range session.afterClosures { - session.afterProcessors = append(session.afterProcessors, executedProcessor{ - fun: func(sess *Session, bean interface{}) error { - closure(bean) - return nil - }, - session: session, - bean: bean, - }) - } - - if a, has := bean.(AfterLoadProcessor); has { - session.afterProcessors = append(session.afterProcessors, executedProcessor{ - fun: func(sess *Session, bean interface{}) error { - a.AfterLoad() - return nil - }, - session: session, - bean: bean, - }) - } - - if a, has := bean.(AfterLoadSessionProcessor); has { - session.afterProcessors = append(session.afterProcessors, executedProcessor{ - fun: func(sess *Session, bean interface{}) error { - a.AfterLoad(sess) - return nil - }, - session: session, - bean: bean, - }) - } + buildAfterProcessors(session, bean) var tempMap = make(map[string]int) - var pk core.PK + var pk schemas.PK for ii, key := range fields { var idx int var ok bool @@ -436,7 +447,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b fieldValue, err := session.getField(dataStruct, key, table, idx) if err != nil { if !strings.Contains(err.Error(), "is not valid") { - session.engine.logger.Warn(err) + session.engine.logger.Warnf("%v", err) } continue } @@ -451,7 +462,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b } if fieldValue.CanAddr() { - if structConvert, ok := fieldValue.Addr().Interface().(core.Conversion); ok { + if structConvert, ok := fieldValue.Addr().Interface().(convert.Conversion); ok { if data, err := value2Bytes(&rawValue); err == nil { if err := structConvert.FromDB(data); err != nil { return nil, err @@ -463,12 +474,12 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b } } - if _, ok := fieldValue.Interface().(core.Conversion); ok { + if _, ok := fieldValue.Interface().(convert.Conversion); ok { if data, err := value2Bytes(&rawValue); err == nil { if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() { fieldValue.Set(reflect.New(fieldValue.Type().Elem())) } - fieldValue.Interface().(core.Conversion).FromDB(data) + fieldValue.Interface().(convert.Conversion).FromDB(data) } else { return nil, err } @@ -488,7 +499,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b var bs []byte if rawValueType.Kind() == reflect.String { bs = []byte(vv.String()) - } else if rawValueType.ConvertibleTo(core.BytesType) { + } else if rawValueType.ConvertibleTo(schemas.BytesType) { bs = vv.Bytes() } else { return nil, fmt.Errorf("unsupported database data type: %s %v", key, rawValueType.Kind()) @@ -502,13 +513,13 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b continue } if fieldValue.CanAddr() { - err := DefaultJSONHandler.Unmarshal(bs, fieldValue.Addr().Interface()) + err := json.DefaultJSONHandler.Unmarshal(bs, fieldValue.Addr().Interface()) if err != nil { return nil, err } } else { x := reflect.New(fieldType) - err := DefaultJSONHandler.Unmarshal(bs, x.Interface()) + err := json.DefaultJSONHandler.Unmarshal(bs, x.Interface()) if err != nil { return nil, err } @@ -525,20 +536,20 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b var bs []byte if rawValueType.Kind() == reflect.String { bs = []byte(vv.String()) - } else if rawValueType.ConvertibleTo(core.BytesType) { + } else if rawValueType.ConvertibleTo(schemas.BytesType) { bs = vv.Bytes() } hasAssigned = true if len(bs) > 0 { if fieldValue.CanAddr() { - err := DefaultJSONHandler.Unmarshal(bs, fieldValue.Addr().Interface()) + err := json.DefaultJSONHandler.Unmarshal(bs, fieldValue.Addr().Interface()) if err != nil { return nil, err } } else { x := reflect.New(fieldType) - err := DefaultJSONHandler.Unmarshal(bs, x.Interface()) + err := json.DefaultJSONHandler.Unmarshal(bs, x.Interface()) if err != nil { return nil, err } @@ -554,7 +565,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b hasAssigned = true if col.SQLType.IsText() { x := reflect.New(fieldType) - err := DefaultJSONHandler.Unmarshal(vv.Bytes(), x.Interface()) + err := json.DefaultJSONHandler.Unmarshal(vv.Bytes(), x.Interface()) if err != nil { return nil, err } @@ -607,16 +618,16 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b fieldValue.SetUint(uint64(vv.Int())) } case reflect.Struct: - if fieldType.ConvertibleTo(core.TimeType) { + if fieldType.ConvertibleTo(schemas.TimeType) { dbTZ := session.engine.DatabaseTZ if col.TimeZone != nil { dbTZ = col.TimeZone } - if rawValueType == core.TimeType { + if rawValueType == schemas.TimeType { hasAssigned = true - t := vv.Convert(core.TimeType).Interface().(time.Time) + t := vv.Convert(schemas.TimeType).Interface().(time.Time) z, _ := t.Zone() // set new location if database don't save timezone or give an incorrect timezone @@ -628,8 +639,8 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b t = t.In(session.engine.TZLocation) fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) - } else if rawValueType == core.IntType || rawValueType == core.Int64Type || - rawValueType == core.Int32Type { + } else if rawValueType == schemas.IntType || rawValueType == schemas.Int64Type || + rawValueType == schemas.Int32Type { hasAssigned = true t := time.Unix(vv.Int(), 0).In(session.engine.TZLocation) @@ -639,7 +650,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b hasAssigned = true t, err := session.byte2Time(col, d) if err != nil { - session.engine.logger.Error("byte2Time error:", err.Error()) + session.engine.logger.Errorf("byte2Time error: %v", err) hasAssigned = false } else { fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) @@ -648,7 +659,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b hasAssigned = true t, err := session.str2Time(col, d) if err != nil { - session.engine.logger.Error("byte2Time error:", err.Error()) + session.engine.logger.Errorf("byte2Time error: %v", err) hasAssigned = false } else { fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) @@ -661,7 +672,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b // !! 增加支持sql.Scanner接口的结构,如sql.NullString hasAssigned = true if err := nulVal.Scan(vv.Interface()); err != nil { - session.engine.logger.Error("sql.Sanner error:", err.Error()) + session.engine.logger.Errorf("sql.Sanner error: %v", err) hasAssigned = false } } else if col.SQLType.IsJson() { @@ -669,7 +680,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b hasAssigned = true x := reflect.New(fieldType) if len([]byte(vv.String())) > 0 { - err := DefaultJSONHandler.Unmarshal([]byte(vv.String()), x.Interface()) + err := json.DefaultJSONHandler.Unmarshal([]byte(vv.String()), x.Interface()) if err != nil { return nil, err } @@ -679,7 +690,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b hasAssigned = true x := reflect.New(fieldType) if len(vv.Bytes()) > 0 { - err := DefaultJSONHandler.Unmarshal(vv.Bytes(), x.Interface()) + err := json.DefaultJSONHandler.Unmarshal(vv.Bytes(), x.Interface()) if err != nil { return nil, err } @@ -687,7 +698,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b } } } else if session.statement.UseCascade { - table, err := session.engine.autoMapType(*fieldValue) + table, err := session.engine.tagParser.ParseWithCache(*fieldValue) if err != nil { return nil, err } @@ -696,13 +707,13 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b if len(table.PrimaryKeys) != 1 { return nil, errors.New("unsupported non or composited primary key cascade") } - var pk = make(core.PK, len(table.PrimaryKeys)) + var pk = make(schemas.PK, len(table.PrimaryKeys)) pk[0], err = asKind(vv, rawValueType) if err != nil { return nil, err } - if !isPKZero(pk) { + if !pk.IsZero() { // !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch // however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne // property to be fetched lazily @@ -722,110 +733,110 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b // !nashtsai! TODO merge duplicated codes above switch fieldType { // following types case matching ptr's native type, therefore assign ptr directly - case core.PtrStringType: + case schemas.PtrStringType: if rawValueType.Kind() == reflect.String { x := vv.String() hasAssigned = true fieldValue.Set(reflect.ValueOf(&x)) } - case core.PtrBoolType: + case schemas.PtrBoolType: if rawValueType.Kind() == reflect.Bool { x := vv.Bool() hasAssigned = true fieldValue.Set(reflect.ValueOf(&x)) } - case core.PtrTimeType: - if rawValueType == core.PtrTimeType { + case schemas.PtrTimeType: + if rawValueType == schemas.PtrTimeType { hasAssigned = true var x = rawValue.Interface().(time.Time) fieldValue.Set(reflect.ValueOf(&x)) } - case core.PtrFloat64Type: + case schemas.PtrFloat64Type: if rawValueType.Kind() == reflect.Float64 { x := vv.Float() hasAssigned = true fieldValue.Set(reflect.ValueOf(&x)) } - case core.PtrUint64Type: + case schemas.PtrUint64Type: if rawValueType.Kind() == reflect.Int64 { var x = uint64(vv.Int()) hasAssigned = true fieldValue.Set(reflect.ValueOf(&x)) } - case core.PtrInt64Type: + case schemas.PtrInt64Type: if rawValueType.Kind() == reflect.Int64 { x := vv.Int() hasAssigned = true fieldValue.Set(reflect.ValueOf(&x)) } - case core.PtrFloat32Type: + case schemas.PtrFloat32Type: if rawValueType.Kind() == reflect.Float64 { var x = float32(vv.Float()) hasAssigned = true fieldValue.Set(reflect.ValueOf(&x)) } - case core.PtrIntType: + case schemas.PtrIntType: if rawValueType.Kind() == reflect.Int64 { var x = int(vv.Int()) hasAssigned = true fieldValue.Set(reflect.ValueOf(&x)) } - case core.PtrInt32Type: + case schemas.PtrInt32Type: if rawValueType.Kind() == reflect.Int64 { var x = int32(vv.Int()) hasAssigned = true fieldValue.Set(reflect.ValueOf(&x)) } - case core.PtrInt8Type: + case schemas.PtrInt8Type: if rawValueType.Kind() == reflect.Int64 { var x = int8(vv.Int()) hasAssigned = true fieldValue.Set(reflect.ValueOf(&x)) } - case core.PtrInt16Type: + case schemas.PtrInt16Type: if rawValueType.Kind() == reflect.Int64 { var x = int16(vv.Int()) hasAssigned = true fieldValue.Set(reflect.ValueOf(&x)) } - case core.PtrUintType: + case schemas.PtrUintType: if rawValueType.Kind() == reflect.Int64 { var x = uint(vv.Int()) hasAssigned = true fieldValue.Set(reflect.ValueOf(&x)) } - case core.PtrUint32Type: + case schemas.PtrUint32Type: if rawValueType.Kind() == reflect.Int64 { var x = uint32(vv.Int()) hasAssigned = true fieldValue.Set(reflect.ValueOf(&x)) } - case core.Uint8Type: + case schemas.Uint8Type: if rawValueType.Kind() == reflect.Int64 { var x = uint8(vv.Int()) hasAssigned = true fieldValue.Set(reflect.ValueOf(&x)) } - case core.Uint16Type: + case schemas.Uint16Type: if rawValueType.Kind() == reflect.Int64 { var x = uint16(vv.Int()) hasAssigned = true fieldValue.Set(reflect.ValueOf(&x)) } - case core.Complex64Type: + case schemas.Complex64Type: var x complex64 if len([]byte(vv.String())) > 0 { - err := DefaultJSONHandler.Unmarshal([]byte(vv.String()), &x) + err := json.DefaultJSONHandler.Unmarshal([]byte(vv.String()), &x) if err != nil { return nil, err } fieldValue.Set(reflect.ValueOf(&x)) } hasAssigned = true - case core.Complex128Type: + case schemas.Complex128Type: var x complex128 if len([]byte(vv.String())) > 0 { - err := DefaultJSONHandler.Unmarshal([]byte(vv.String()), &x) + err := json.DefaultJSONHandler.Unmarshal([]byte(vv.String()), &x) if err != nil { return nil, err } @@ -854,17 +865,6 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b func (session *Session) saveLastSQL(sql string, args ...interface{}) { session.lastSQL = sql session.lastSQLArgs = args - session.logSQL(sql, args...) -} - -func (session *Session) logSQL(sqlStr string, sqlArgs ...interface{}) { - if session.showSQL && !session.engine.showExecTime { - if len(sqlArgs) > 0 { - session.engine.logger.Infof("[SQL] %v %#v", sqlStr, sqlArgs) - } else { - session.engine.logger.Infof("[SQL] %v", sqlStr) - } - } } // LastSQL returns last query information @@ -874,7 +874,7 @@ func (session *Session) LastSQL() (string, []interface{}) { // Unscoped always disable struct tag "deleted" func (session *Session) Unscoped() *Session { - session.statement.Unscoped() + session.statement.SetUnscoped() return session } @@ -886,3 +886,19 @@ func (session *Session) incrVersionFieldValue(fieldValue *reflect.Value) { fieldValue.SetUint(fieldValue.Uint() + 1) } } + +// ContextHook sets the context on this session +func (session *Session) Context(ctx context.Context) *Session { + session.ctx = ctx + return session +} + +// PingContext test if database is ok +func (session *Session) PingContext(ctx context.Context) error { + if session.isAutoClose { + defer session.Close() + } + + session.engine.logger.Infof("PING DATABASE %v", session.engine.DriverName()) + return session.DB().PingContext(ctx) +} diff --git a/session_cols.go b/session_cols.go index 1558074f..ca3589ab 100644 --- a/session_cols.go +++ b/session_cols.go @@ -9,10 +9,10 @@ import ( "strings" "time" - "xorm.io/core" + "xorm.io/xorm/schemas" ) -func setColumnInt(bean interface{}, col *core.Column, t int64) { +func setColumnInt(bean interface{}, col *schemas.Column, t int64) { v, err := col.ValueOf(bean) if err != nil { return @@ -27,7 +27,7 @@ func setColumnInt(bean interface{}, col *core.Column, t int64) { } } -func setColumnTime(bean interface{}, col *core.Column, t time.Time) { +func setColumnTime(bean interface{}, col *schemas.Column, t time.Time) { v, err := col.ValueOf(bean) if err != nil { return @@ -44,7 +44,7 @@ func setColumnTime(bean interface{}, col *core.Column, t time.Time) { } } -func getFlagForColumn(m map[string]bool, col *core.Column) (val bool, has bool) { +func getFlagForColumn(m map[string]bool, col *schemas.Column) (val bool, has bool) { if len(m) == 0 { return false, false } @@ -63,19 +63,6 @@ func getFlagForColumn(m map[string]bool, col *core.Column) (val bool, has bool) return false, false } -func col2NewCols(columns ...string) []string { - newColumns := make([]string, 0, len(columns)) - for _, col := range columns { - col = strings.Replace(col, "`", "", -1) - col = strings.Replace(col, `"`, "", -1) - ccols := strings.Split(col, ",") - for _, c := range ccols { - newColumns = append(newColumns, strings.TrimSpace(c)) - } - } - return newColumns -} - // Incr provides a query string like "count = count + 1" func (session *Session) Incr(column string, arg ...interface{}) *Session { session.statement.Incr(column, arg...) diff --git a/session_cond.go b/session_cond.go index b16bdea8..25d17148 100644 --- a/session_cond.go +++ b/session_cond.go @@ -6,14 +6,6 @@ package xorm import "xorm.io/builder" -// Sql provides raw sql input parameter. When you have a complex SQL statement -// and cannot use Where, Id, In and etc. Methods to describe, you can use SQL. -// -// Deprecated: use SQL instead. -func (session *Session) Sql(query string, args ...interface{}) *Session { - return session.SQL(query, args...) -} - // SQL provides raw sql input parameter. When you have a complex SQL statement // and cannot use Where, Id, In and etc. Methods to describe, you can use SQL. func (session *Session) SQL(query interface{}, args ...interface{}) *Session { @@ -39,13 +31,6 @@ func (session *Session) Or(query interface{}, args ...interface{}) *Session { return session } -// Id provides converting id as a query condition -// -// Deprecated: use ID instead -func (session *Session) Id(id interface{}) *Session { - return session.ID(id) -} - // ID provides converting id as a query condition func (session *Session) ID(id interface{}) *Session { session.statement.ID(id) @@ -66,5 +51,5 @@ func (session *Session) NotIn(column string, args ...interface{}) *Session { // Conds returns session query conditions except auto bean conditions func (session *Session) Conds() builder.Cond { - return session.statement.cond + return session.statement.Conds() } diff --git a/session_context.go b/session_context.go deleted file mode 100644 index 915f0568..00000000 --- a/session_context.go +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright 2019 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package xorm - -import "context" - -// Context sets the context on this session -func (session *Session) Context(ctx context.Context) *Session { - session.ctx = ctx - return session -} - -// PingContext test if database is ok -func (session *Session) PingContext(ctx context.Context) error { - if session.isAutoClose { - defer session.Close() - } - - session.engine.logger.Infof("PING DATABASE %v", session.engine.DriverName()) - return session.DB().PingContext(ctx) -} diff --git a/session_context_test.go b/session_context_test.go deleted file mode 100644 index 2784468d..00000000 --- a/session_context_test.go +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright 2019 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package xorm - -import ( - "context" - "testing" - "time" - - "github.com/stretchr/testify/assert" -) - -func TestQueryContext(t *testing.T) { - type ContextQueryStruct struct { - Id int64 - Name string - } - - assert.NoError(t, prepareEngine()) - assertSync(t, new(ContextQueryStruct)) - - _, err := testEngine.Insert(&ContextQueryStruct{Name: "1"}) - assert.NoError(t, err) - - ctx, cancel := context.WithTimeout(context.Background(), time.Nanosecond) - defer cancel() - - time.Sleep(time.Nanosecond) - - has, err := testEngine.Context(ctx).Exist(&ContextQueryStruct{Name: "1"}) - assert.Error(t, err) - assert.Contains(t, err.Error(), "context deadline exceeded") - assert.False(t, has) -} diff --git a/session_convert.go b/session_convert.go index 7f11354d..a6839947 100644 --- a/session_convert.go +++ b/session_convert.go @@ -6,7 +6,6 @@ package xorm import ( "database/sql" - "database/sql/driver" "errors" "fmt" "reflect" @@ -14,10 +13,13 @@ import ( "strings" "time" - "xorm.io/core" + "xorm.io/xorm/convert" + "xorm.io/xorm/internal/json" + "xorm.io/xorm/internal/utils" + "xorm.io/xorm/schemas" ) -func (session *Session) str2Time(col *core.Column, data string) (outTime time.Time, outErr error) { +func (session *Session) str2Time(col *schemas.Column, data string) (outTime time.Time, outErr error) { sdata := strings.TrimSpace(data) var x time.Time var err error @@ -27,7 +29,7 @@ func (session *Session) str2Time(col *core.Column, data string) (outTime time.Ti parseLoc = col.TimeZone } - if sdata == zeroTime0 || sdata == zeroTime1 { + if sdata == utils.ZeroTime0 || sdata == utils.ZeroTime1 { } else if !strings.ContainsAny(sdata, "- :") { // !nashtsai! has only found that mymysql driver is using this for time type column // time stamp sd, err := strconv.ParseInt(sdata, 10, 64) @@ -54,14 +56,14 @@ func (session *Session) str2Time(col *core.Column, data string) (outTime time.Ti } else if len(sdata) == 10 && sdata[4] == '-' && sdata[7] == '-' { x, err = time.ParseInLocation("2006-01-02", sdata, parseLoc) //session.engine.logger.Debugf("time(5) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) - } else if col.SQLType.Name == core.Time { + } else if col.SQLType.Name == schemas.Time { if strings.Contains(sdata, " ") { ssd := strings.Split(sdata, " ") sdata = ssd[1] } sdata = strings.TrimSpace(sdata) - if session.engine.dialect.DBType() == core.MYSQL && len(sdata) > 8 { + if session.engine.dialect.URI().DBType == schemas.MYSQL && len(sdata) > 8 { sdata = sdata[len(sdata)-8:] } @@ -80,21 +82,17 @@ func (session *Session) str2Time(col *core.Column, data string) (outTime time.Ti return } -func (session *Session) byte2Time(col *core.Column, data []byte) (outTime time.Time, outErr error) { +func (session *Session) byte2Time(col *schemas.Column, data []byte) (outTime time.Time, outErr error) { return session.str2Time(col, string(data)) } -var ( - nullFloatType = reflect.TypeOf(sql.NullFloat64{}) -) - // convert a db data([]byte) to a field value -func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, data []byte) error { - if structConvert, ok := fieldValue.Addr().Interface().(core.Conversion); ok { +func (session *Session) bytes2Value(col *schemas.Column, fieldValue *reflect.Value, data []byte) error { + if structConvert, ok := fieldValue.Addr().Interface().(convert.Conversion); ok { return structConvert.FromDB(data) } - if structConvert, ok := fieldValue.Interface().(core.Conversion); ok { + if structConvert, ok := fieldValue.Interface().(convert.Conversion); ok { return structConvert.FromDB(data) } @@ -106,9 +104,8 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, case reflect.Complex64, reflect.Complex128: x := reflect.New(fieldType) if len(data) > 0 { - err := DefaultJSONHandler.Unmarshal(data, x.Interface()) + err := json.DefaultJSONHandler.Unmarshal(data, x.Interface()) if err != nil { - session.engine.logger.Error(err) return err } fieldValue.Set(x.Elem()) @@ -120,9 +117,8 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, if col.SQLType.IsText() { x := reflect.New(fieldType) if len(data) > 0 { - err := DefaultJSONHandler.Unmarshal(data, x.Interface()) + err := json.DefaultJSONHandler.Unmarshal(data, x.Interface()) if err != nil { - session.engine.logger.Error(err) return err } fieldValue.Set(x.Elem()) @@ -133,9 +129,8 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, } else { x := reflect.New(fieldType) if len(data) > 0 { - err := DefaultJSONHandler.Unmarshal(data, x.Interface()) + err := json.DefaultJSONHandler.Unmarshal(data, x.Interface()) if err != nil { - session.engine.logger.Error(err) return err } fieldValue.Set(x.Elem()) @@ -157,8 +152,8 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, var x int64 var err error // for mysql, when use bit, it returned \x01 - if col.SQLType.Name == core.Bit && - session.engine.dialect.DBType() == core.MYSQL { // !nashtsai! TODO dialect needs to provide conversion interface API + if col.SQLType.Name == schemas.Bit && + session.engine.dialect.URI().DBType == schemas.MYSQL { // !nashtsai! TODO dialect needs to provide conversion interface API if len(data) == 1 { x = int64(data[0]) } else { @@ -199,7 +194,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, return fmt.Errorf("sql.Scan(%v) failed: %s ", data, err.Error()) } } else { - if fieldType.ConvertibleTo(core.TimeType) { + if fieldType.ConvertibleTo(schemas.TimeType) { x, err := session.byte2Time(col, data) if err != nil { return err @@ -207,7 +202,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, v = x fieldValue.Set(reflect.ValueOf(v).Convert(fieldType)) } else if session.statement.UseCascade { - table, err := session.engine.autoMapType(*fieldValue) + table, err := session.engine.tagParser.ParseWithCache(*fieldValue) if err != nil { return err } @@ -217,14 +212,14 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, return errors.New("unsupported composited primary key cascade") } - var pk = make(core.PK, len(table.PrimaryKeys)) + var pk = make(schemas.PK, len(table.PrimaryKeys)) rawValueType := table.ColumnType(table.PKColumns()[0].FieldName) pk[0], err = str2PK(string(data), rawValueType) if err != nil { return err } - if !isPKZero(pk) { + if !pk.IsZero() { // !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch // however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne // property to be fetched lazily @@ -247,11 +242,11 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, //typeStr := fieldType.String() switch fieldType.Elem().Kind() { // case "*string": - case core.StringType.Kind(): + case schemas.StringType.Kind(): x := string(data) fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) // case "*bool": - case core.BoolType.Kind(): + case schemas.BoolType.Kind(): d := string(data) v, err := strconv.ParseBool(d) if err != nil { @@ -259,36 +254,34 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, } fieldValue.Set(reflect.ValueOf(&v).Convert(fieldType)) // case "*complex64": - case core.Complex64Type.Kind(): + case schemas.Complex64Type.Kind(): var x complex64 if len(data) > 0 { - err := DefaultJSONHandler.Unmarshal(data, &x) + err := json.DefaultJSONHandler.Unmarshal(data, &x) if err != nil { - session.engine.logger.Error(err) return err } fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) } // case "*complex128": - case core.Complex128Type.Kind(): + case schemas.Complex128Type.Kind(): var x complex128 if len(data) > 0 { - err := DefaultJSONHandler.Unmarshal(data, &x) + err := json.DefaultJSONHandler.Unmarshal(data, &x) if err != nil { - session.engine.logger.Error(err) return err } fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) } // case "*float64": - case core.Float64Type.Kind(): + case schemas.Float64Type.Kind(): x, err := strconv.ParseFloat(string(data), 64) if err != nil { return fmt.Errorf("arg %v as float64: %s", key, err.Error()) } fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) // case "*float32": - case core.Float32Type.Kind(): + case schemas.Float32Type.Kind(): var x float32 x1, err := strconv.ParseFloat(string(data), 32) if err != nil { @@ -297,7 +290,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, x = float32(x1) fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) // case "*uint64": - case core.Uint64Type.Kind(): + case schemas.Uint64Type.Kind(): var x uint64 x, err := strconv.ParseUint(string(data), 10, 64) if err != nil { @@ -305,7 +298,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, } fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) // case "*uint": - case core.UintType.Kind(): + case schemas.UintType.Kind(): var x uint x1, err := strconv.ParseUint(string(data), 10, 64) if err != nil { @@ -314,7 +307,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, x = uint(x1) fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) // case "*uint32": - case core.Uint32Type.Kind(): + case schemas.Uint32Type.Kind(): var x uint32 x1, err := strconv.ParseUint(string(data), 10, 64) if err != nil { @@ -323,7 +316,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, x = uint32(x1) fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) // case "*uint8": - case core.Uint8Type.Kind(): + case schemas.Uint8Type.Kind(): var x uint8 x1, err := strconv.ParseUint(string(data), 10, 64) if err != nil { @@ -332,7 +325,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, x = uint8(x1) fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) // case "*uint16": - case core.Uint16Type.Kind(): + case schemas.Uint16Type.Kind(): var x uint16 x1, err := strconv.ParseUint(string(data), 10, 64) if err != nil { @@ -341,12 +334,12 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, x = uint16(x1) fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) // case "*int64": - case core.Int64Type.Kind(): + case schemas.Int64Type.Kind(): sdata := string(data) var x int64 var err error // for mysql, when use bit, it returned \x01 - if col.SQLType.Name == core.Bit && + if col.SQLType.Name == schemas.Bit && strings.Contains(session.engine.DriverName(), "mysql") { if len(data) == 1 { x = int64(data[0]) @@ -365,13 +358,13 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, } fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) // case "*int": - case core.IntType.Kind(): + case schemas.IntType.Kind(): sdata := string(data) var x int var x1 int64 var err error // for mysql, when use bit, it returned \x01 - if col.SQLType.Name == core.Bit && + if col.SQLType.Name == schemas.Bit && strings.Contains(session.engine.DriverName(), "mysql") { if len(data) == 1 { x = int(data[0]) @@ -393,14 +386,14 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, } fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) // case "*int32": - case core.Int32Type.Kind(): + case schemas.Int32Type.Kind(): sdata := string(data) var x int32 var x1 int64 var err error // for mysql, when use bit, it returned \x01 - if col.SQLType.Name == core.Bit && - session.engine.dialect.DBType() == core.MYSQL { + if col.SQLType.Name == schemas.Bit && + session.engine.dialect.URI().DBType == schemas.MYSQL { if len(data) == 1 { x = int32(data[0]) } else { @@ -421,13 +414,13 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, } fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) // case "*int8": - case core.Int8Type.Kind(): + case schemas.Int8Type.Kind(): sdata := string(data) var x int8 var x1 int64 var err error // for mysql, when use bit, it returned \x01 - if col.SQLType.Name == core.Bit && + if col.SQLType.Name == schemas.Bit && strings.Contains(session.engine.DriverName(), "mysql") { if len(data) == 1 { x = int8(data[0]) @@ -449,13 +442,13 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, } fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) // case "*int16": - case core.Int16Type.Kind(): + case schemas.Int16Type.Kind(): sdata := string(data) var x int16 var x1 int64 var err error // for mysql, when use bit, it returned \x01 - if col.SQLType.Name == core.Bit && + if col.SQLType.Name == schemas.Bit && strings.Contains(session.engine.DriverName(), "mysql") { if len(data) == 1 { x = int16(data[0]) @@ -480,7 +473,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, case reflect.Struct: switch fieldType { // case "*.time.Time": - case core.PtrTimeType: + case schemas.PtrTimeType: x, err := session.byte2Time(col, data) if err != nil { return err @@ -490,7 +483,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, default: if session.statement.UseCascade { structInter := reflect.New(fieldType.Elem()) - table, err := session.engine.autoMapType(structInter.Elem()) + table, err := session.engine.tagParser.ParseWithCache(structInter.Elem()) if err != nil { return err } @@ -499,14 +492,14 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, return errors.New("unsupported composited primary key cascade") } - var pk = make(core.PK, len(table.PrimaryKeys)) + var pk = make(schemas.PK, len(table.PrimaryKeys)) rawValueType := table.ColumnType(table.PKColumns()[0].FieldName) pk[0], err = str2PK(string(data), rawValueType) if err != nil { return err } - if !isPKZero(pk) { + if !pk.IsZero() { // !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch // however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne // property to be fetched lazily @@ -534,138 +527,3 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, return nil } - -// convert a field value of a struct to interface for put into db -func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Value) (interface{}, error) { - if fieldValue.CanAddr() { - if fieldConvert, ok := fieldValue.Addr().Interface().(core.Conversion); ok { - data, err := fieldConvert.ToDB() - if err != nil { - return 0, err - } - if col.SQLType.IsBlob() { - return data, nil - } - return string(data), nil - } - } - - if fieldConvert, ok := fieldValue.Interface().(core.Conversion); ok { - data, err := fieldConvert.ToDB() - if err != nil { - return 0, err - } - if col.SQLType.IsBlob() { - return data, nil - } - return string(data), nil - } - - fieldType := fieldValue.Type() - k := fieldType.Kind() - if k == reflect.Ptr { - if fieldValue.IsNil() { - return nil, nil - } else if !fieldValue.IsValid() { - session.engine.logger.Warn("the field[", col.FieldName, "] is invalid") - return nil, nil - } else { - // !nashtsai! deference pointer type to instance type - fieldValue = fieldValue.Elem() - fieldType = fieldValue.Type() - k = fieldType.Kind() - } - } - - switch k { - case reflect.Bool: - return fieldValue.Bool(), nil - case reflect.String: - return fieldValue.String(), nil - case reflect.Struct: - if fieldType.ConvertibleTo(core.TimeType) { - t := fieldValue.Convert(core.TimeType).Interface().(time.Time) - tf := session.engine.formatColTime(col, t) - return tf, nil - } else if fieldType.ConvertibleTo(nullFloatType) { - t := fieldValue.Convert(nullFloatType).Interface().(sql.NullFloat64) - if !t.Valid { - return nil, nil - } - return t.Float64, nil - } - - if !col.SQLType.IsJson() { - // !! 增加支持driver.Valuer接口的结构,如sql.NullString - if v, ok := fieldValue.Interface().(driver.Valuer); ok { - return v.Value() - } - - fieldTable, err := session.engine.autoMapType(fieldValue) - if err != nil { - return nil, err - } - if len(fieldTable.PrimaryKeys) == 1 { - pkField := reflect.Indirect(fieldValue).FieldByName(fieldTable.PKColumns()[0].FieldName) - return pkField.Interface(), nil - } - return 0, fmt.Errorf("no primary key for col %v", col.Name) - } - - if col.SQLType.IsText() { - bytes, err := DefaultJSONHandler.Marshal(fieldValue.Interface()) - if err != nil { - session.engine.logger.Error(err) - return 0, err - } - return string(bytes), nil - } else if col.SQLType.IsBlob() { - bytes, err := DefaultJSONHandler.Marshal(fieldValue.Interface()) - if err != nil { - session.engine.logger.Error(err) - return 0, err - } - return bytes, nil - } - return nil, fmt.Errorf("Unsupported type %v", fieldValue.Type()) - case reflect.Complex64, reflect.Complex128: - bytes, err := DefaultJSONHandler.Marshal(fieldValue.Interface()) - if err != nil { - session.engine.logger.Error(err) - return 0, err - } - return string(bytes), nil - case reflect.Array, reflect.Slice, reflect.Map: - if !fieldValue.IsValid() { - return fieldValue.Interface(), nil - } - - if col.SQLType.IsText() { - bytes, err := DefaultJSONHandler.Marshal(fieldValue.Interface()) - if err != nil { - session.engine.logger.Error(err) - return 0, err - } - return string(bytes), nil - } else if col.SQLType.IsBlob() { - var bytes []byte - var err error - if (k == reflect.Slice) && - (fieldValue.Type().Elem().Kind() == reflect.Uint8) { - bytes = fieldValue.Bytes() - } else { - bytes, err = DefaultJSONHandler.Marshal(fieldValue.Interface()) - if err != nil { - session.engine.logger.Error(err) - return 0, err - } - } - return bytes, nil - } - return nil, ErrUnSupportedType - case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: - return int64(fieldValue.Uint()), nil - default: - return fieldValue.Interface(), nil - } -} diff --git a/session_delete.go b/session_delete.go index 675d4d8c..13bf791f 100644 --- a/session_delete.go +++ b/session_delete.go @@ -9,37 +9,46 @@ import ( "fmt" "strconv" - "xorm.io/core" + "xorm.io/xorm/caches" + "xorm.io/xorm/schemas" ) -func (session *Session) cacheDelete(table *core.Table, tableName, sqlStr string, args ...interface{}) error { +var ( + // ErrNeedDeletedCond delete needs less one condition error + ErrNeedDeletedCond = errors.New("Delete action needs at least one condition") + + // ErrNotImplemented not implemented + ErrNotImplemented = errors.New("Not implemented") +) + +func (session *Session) cacheDelete(table *schemas.Table, tableName, sqlStr string, args ...interface{}) error { if table == nil || session.tx != nil { return ErrCacheFailed } for _, filter := range session.engine.dialect.Filters() { - sqlStr = filter.Do(sqlStr, session.engine.dialect, table) + sqlStr = filter.Do(sqlStr) } - newsql := session.statement.convertIDSQL(sqlStr) + newsql := session.statement.ConvertIDSQL(sqlStr) if newsql == "" { return ErrCacheFailed } - cacher := session.engine.getCacher(tableName) + cacher := session.engine.cacherMgr.GetCacher(tableName) pkColumns := table.PKColumns() - ids, err := core.GetCacheSql(cacher, tableName, newsql, args) + ids, err := caches.GetCacheSql(cacher, tableName, newsql, args) if err != nil { resultsSlice, err := session.queryBytes(newsql, args...) if err != nil { return err } - ids = make([]core.PK, 0) + ids = make([]schemas.PK, 0) if len(resultsSlice) > 0 { for _, data := range resultsSlice { var id int64 - var pk core.PK = make([]interface{}, 0) + var pk schemas.PK = make([]interface{}, 0) for _, col := range pkColumns { if v, ok := data[col.Name]; !ok { return errors.New("no id") @@ -61,14 +70,14 @@ func (session *Session) cacheDelete(table *core.Table, tableName, sqlStr string, } for _, id := range ids { - session.engine.logger.Debug("[cacheDelete] delete cache obj:", tableName, id) + session.engine.logger.Debugf("[cache] delete cache obj: %v, %v", tableName, id) sid, err := id.ToString() if err != nil { return err } cacher.DelBean(tableName, sid) } - session.engine.logger.Debug("[cacheDelete] clear cache table:", tableName) + session.engine.logger.Debugf("[cache] clear cache table: %v", tableName) cacher.ClearIds(tableName) return nil } @@ -79,29 +88,26 @@ func (session *Session) Delete(bean interface{}) (int64, error) { defer session.Close() } - if session.statement.lastError != nil { - return 0, session.statement.lastError + if session.statement.LastError != nil { + return 0, session.statement.LastError } - if err := session.statement.setRefBean(bean); err != nil { + if err := session.statement.SetRefBean(bean); err != nil { return 0, err } - // handle before delete processors - for _, closure := range session.beforeClosures { - closure(bean) - } - cleanupProcessorsClosures(&session.beforeClosures) + executeBeforeClosures(session, bean) if processor, ok := interface{}(bean).(BeforeDeleteProcessor); ok { processor.BeforeDelete() } - condSQL, condArgs, err := session.statement.genConds(bean) + condSQL, condArgs, err := session.statement.GenConds(bean) if err != nil { return 0, err } - if len(condSQL) == 0 && session.statement.LimitN == 0 { + pLimitN := session.statement.LimitN + if len(condSQL) == 0 && (pLimitN == nil || *pLimitN == 0) { return 0, ErrNeedDeletedCond } @@ -119,28 +125,29 @@ func (session *Session) Delete(bean interface{}) (int64, error) { if len(session.statement.OrderStr) > 0 { orderSQL += fmt.Sprintf(" ORDER BY %s", session.statement.OrderStr) } - if session.statement.LimitN > 0 { - orderSQL += fmt.Sprintf(" LIMIT %d", session.statement.LimitN) + if pLimitN != nil && *pLimitN > 0 { + limitNValue := *pLimitN + orderSQL += fmt.Sprintf(" LIMIT %d", limitNValue) } if len(orderSQL) > 0 { - switch session.engine.dialect.DBType() { - case core.POSTGRES: + switch session.engine.dialect.URI().DBType { + case schemas.POSTGRES: inSQL := fmt.Sprintf("ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQL) if len(condSQL) > 0 { deleteSQL += " AND " + inSQL } else { deleteSQL += " WHERE " + inSQL } - case core.SQLITE: + case schemas.SQLITE: inSQL := fmt.Sprintf("rowid IN (SELECT rowid FROM %s%s)", tableName, orderSQL) if len(condSQL) > 0 { deleteSQL += " AND " + inSQL } else { deleteSQL += " WHERE " + inSQL } - // TODO: how to handle delete limit on mssql? - case core.MSSQL: + // TODO: how to handle delete limit on mssql? + case schemas.MSSQL: return 0, ErrNotImplemented default: deleteSQL += orderSQL @@ -149,12 +156,12 @@ func (session *Session) Delete(bean interface{}) (int64, error) { var realSQL string argsForCache := make([]interface{}, 0, len(condArgs)*2) - if session.statement.unscoped || table.DeletedColumn() == nil { // tag "deleted" is disabled + if session.statement.GetUnscoped() || table.DeletedColumn() == nil { // tag "deleted" is disabled realSQL = deleteSQL copy(argsForCache, condArgs) argsForCache = append(condArgs, argsForCache...) } else { - // !oinume! sqlStrForCache and argsForCache is needed to behave as executing "DELETE FROM ..." for cache. + // !oinume! sqlStrForCache and argsForCache is needed to behave as executing "DELETE FROM ..." for caches. copy(argsForCache, condArgs) argsForCache = append(condArgs, argsForCache...) @@ -165,23 +172,23 @@ func (session *Session) Delete(bean interface{}) (int64, error) { condSQL) if len(orderSQL) > 0 { - switch session.engine.dialect.DBType() { - case core.POSTGRES: + switch session.engine.dialect.URI().DBType { + case schemas.POSTGRES: inSQL := fmt.Sprintf("ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQL) if len(condSQL) > 0 { realSQL += " AND " + inSQL } else { realSQL += " WHERE " + inSQL } - case core.SQLITE: + case schemas.SQLITE: inSQL := fmt.Sprintf("rowid IN (SELECT rowid FROM %s%s)", tableName, orderSQL) if len(condSQL) > 0 { realSQL += " AND " + inSQL } else { realSQL += " WHERE " + inSQL } - // TODO: how to handle delete limit on mssql? - case core.MSSQL: + // TODO: how to handle delete limit on mssql? + case schemas.MSSQL: return 0, ErrNotImplemented default: realSQL += orderSQL @@ -203,7 +210,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) { }) } - if cacher := session.engine.getCacher(tableNameNoQuote); cacher != nil && session.statement.UseCache { + if cacher := session.engine.GetCacher(tableNameNoQuote); cacher != nil && session.statement.UseCache { session.cacheDelete(table, tableNameNoQuote, deleteSQL, argsForCache...) } diff --git a/session_exist.go b/session_exist.go index 660cc47e..e52c618e 100644 --- a/session_exist.go +++ b/session_exist.go @@ -4,86 +4,19 @@ package xorm -import ( - "errors" - "fmt" - "reflect" - - "xorm.io/builder" - "xorm.io/core" -) - // Exist returns true if the record exist otherwise return false func (session *Session) Exist(bean ...interface{}) (bool, error) { if session.isAutoClose { defer session.Close() } - if session.statement.lastError != nil { - return false, session.statement.lastError + if session.statement.LastError != nil { + return false, session.statement.LastError } - var sqlStr string - var args []interface{} - var err error - - if session.statement.RawSQL == "" { - if len(bean) == 0 { - tableName := session.statement.TableName() - if len(tableName) <= 0 { - return false, ErrTableNotFound - } - - tableName = session.statement.Engine.Quote(tableName) - - if session.statement.cond.IsValid() { - condSQL, condArgs, err := builder.ToSQL(session.statement.cond) - if err != nil { - return false, err - } - - if session.engine.dialect.DBType() == core.MSSQL { - sqlStr = fmt.Sprintf("SELECT TOP 1 * FROM %s WHERE %s", tableName, condSQL) - } else if session.engine.dialect.DBType() == core.ORACLE { - sqlStr = fmt.Sprintf("SELECT * FROM %s WHERE (%s) AND ROWNUM=1", tableName, condSQL) - } else { - sqlStr = fmt.Sprintf("SELECT * FROM %s WHERE %s LIMIT 1", tableName, condSQL) - } - args = condArgs - } else { - if session.engine.dialect.DBType() == core.MSSQL { - sqlStr = fmt.Sprintf("SELECT TOP 1 * FROM %s", tableName) - } else if session.engine.dialect.DBType() == core.ORACLE { - sqlStr = fmt.Sprintf("SELECT * FROM %s WHERE ROWNUM=1", tableName) - } else { - sqlStr = fmt.Sprintf("SELECT * FROM %s LIMIT 1", tableName) - } - args = []interface{}{} - } - } else { - beanValue := reflect.ValueOf(bean[0]) - if beanValue.Kind() != reflect.Ptr { - return false, errors.New("needs a pointer") - } - - if beanValue.Elem().Kind() == reflect.Struct { - if err := session.statement.setRefBean(bean[0]); err != nil { - return false, err - } - } - - if len(session.statement.TableName()) <= 0 { - return false, ErrTableNotFound - } - session.statement.Limit(1) - sqlStr, args, err = session.statement.genGetSQL(bean[0]) - if err != nil { - return false, err - } - } - } else { - sqlStr = session.statement.RawSQL - args = session.statement.RawParams + sqlStr, args, err := session.statement.GenExistSQL(bean...) + if err != nil { + return false, err } rows, err := session.queryRows(sqlStr, args...) diff --git a/session_exist_test.go b/session_exist_test.go deleted file mode 100644 index 9d985771..00000000 --- a/session_exist_test.go +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright 2017 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package xorm - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestExistStruct(t *testing.T) { - assert.NoError(t, prepareEngine()) - - type RecordExist struct { - Id int64 - Name string - } - - assertSync(t, new(RecordExist)) - - has, err := testEngine.Exist(new(RecordExist)) - assert.NoError(t, err) - assert.False(t, has) - - cnt, err := testEngine.Insert(&RecordExist{ - Name: "test1", - }) - assert.NoError(t, err) - assert.EqualValues(t, 1, cnt) - - has, err = testEngine.Exist(new(RecordExist)) - assert.NoError(t, err) - assert.True(t, has) - - has, err = testEngine.Exist(&RecordExist{ - Name: "test1", - }) - assert.NoError(t, err) - assert.True(t, has) - - has, err = testEngine.Exist(&RecordExist{ - Name: "test2", - }) - assert.NoError(t, err) - assert.False(t, has) - - has, err = testEngine.Where("name = ?", "test1").Exist(&RecordExist{}) - assert.NoError(t, err) - assert.True(t, has) - - has, err = testEngine.Where("name = ?", "test2").Exist(&RecordExist{}) - assert.NoError(t, err) - assert.False(t, has) - - has, err = testEngine.SQL("select * from "+testEngine.TableName("record_exist", true)+" where name = ?", "test1").Exist() - assert.NoError(t, err) - assert.True(t, has) - - has, err = testEngine.SQL("select * from "+testEngine.TableName("record_exist", true)+" where name = ?", "test2").Exist() - assert.NoError(t, err) - assert.False(t, has) - - has, err = testEngine.Table("record_exist").Exist() - assert.NoError(t, err) - assert.True(t, has) - - has, err = testEngine.Table("record_exist").Where("name = ?", "test1").Exist() - assert.NoError(t, err) - assert.True(t, has) - - has, err = testEngine.Table("record_exist").Where("name = ?", "test2").Exist() - assert.NoError(t, err) - assert.False(t, has) -} diff --git a/session_find.go b/session_find.go index e16ae54c..642093f2 100644 --- a/session_find.go +++ b/session_find.go @@ -8,10 +8,12 @@ import ( "errors" "fmt" "reflect" - "strings" "xorm.io/builder" - "xorm.io/core" + "xorm.io/xorm/caches" + "xorm.io/xorm/internal/statements" + "xorm.io/xorm/internal/utils" + "xorm.io/xorm/schemas" ) const ( @@ -52,25 +54,33 @@ func (session *Session) FindAndCount(rowsSlicePtr interface{}, condiBean ...inte } session.autoResetStatement = true - if session.statement.selectStr != "" { - session.statement.selectStr = "" + if session.statement.SelectStr != "" { + session.statement.SelectStr = "" } if session.statement.OrderStr != "" { session.statement.OrderStr = "" } + if session.statement.LimitN != nil { + session.statement.LimitN = nil + } + if session.statement.Start > 0 { + session.statement.Start = 0 + } - return session.Count(reflect.New(sliceElementType).Interface()) + // session has stored the conditions so we use `unscoped` to avoid duplicated condition. + return session.Unscoped().Count(reflect.New(sliceElementType).Interface()) } func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) error { defer session.resetStatement() - - if session.statement.lastError != nil { - return session.statement.lastError + if session.statement.LastError != nil { + return session.statement.LastError } sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) - if sliceValue.Kind() != reflect.Slice && sliceValue.Kind() != reflect.Map { + var isSlice = sliceValue.Kind() == reflect.Slice + var isMap = sliceValue.Kind() == reflect.Map + if !isSlice && !isMap { return errors.New("needs a pointer to a slice or a map") } @@ -81,7 +91,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) if sliceElementType.Kind() == reflect.Ptr { if sliceElementType.Elem().Kind() == reflect.Struct { pv := reflect.New(sliceElementType.Elem()) - if err := session.statement.setRefValue(pv); err != nil { + if err := session.statement.SetRefValue(pv); err != nil { return err } } else { @@ -89,7 +99,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) } } else if sliceElementType.Kind() == reflect.Struct { pv := reflect.New(sliceElementType) - if err := session.statement.setRefValue(pv); err != nil { + if err := session.statement.SetRefValue(pv); err != nil { return err } } else { @@ -97,107 +107,57 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) } } - var table = session.statement.RefTable - - var addedTableName = (len(session.statement.JoinStr) > 0) - var autoCond builder.Cond + var ( + table = session.statement.RefTable + addedTableName = (len(session.statement.JoinStr) > 0) + autoCond builder.Cond + ) if tp == tpStruct { - if !session.statement.noAutoCondition && len(condiBean) > 0 { - var err error - autoCond, err = session.statement.buildConds(table, condiBean[0], true, true, false, true, addedTableName) + if !session.statement.NoAutoCondition && len(condiBean) > 0 { + condTable, err := session.engine.tagParser.Parse(reflect.ValueOf(condiBean[0])) + if err != nil { + return err + } + autoCond, err = session.statement.BuildConds(condTable, condiBean[0], true, true, false, true, addedTableName) if err != nil { return err } } else { - // !oinume! Add " IS NULL" to WHERE whatever condiBean is given. - // See https://gitea.com/xorm/xorm/issues/179 - if col := table.DeletedColumn(); col != nil && !session.statement.unscoped { // tag "deleted" is enabled - var colName = session.engine.Quote(col.Name) - if addedTableName { - var nm = session.statement.TableName() - if len(session.statement.TableAlias) > 0 { - nm = session.statement.TableAlias - } - colName = session.engine.Quote(nm) + "." + colName - } - - autoCond = session.engine.CondDeleted(colName) + if col := table.DeletedColumn(); col != nil && !session.statement.GetUnscoped() { // tag "deleted" is enabled + autoCond = session.statement.CondDeleted(col) } } } - var sqlStr string - var args []interface{} - var err error - if session.statement.RawSQL == "" { - if len(session.statement.TableName()) <= 0 { - return ErrTableNotFound + // if it's a map with Cols but primary key not in column list, we still need the primary key + if isMap && !session.statement.ColumnMap.IsEmpty() { + for _, k := range session.statement.RefTable.PrimaryKeys { + session.statement.ColumnMap.Add(k) } - - var columnStr = session.statement.ColumnStr - if len(session.statement.selectStr) > 0 { - columnStr = session.statement.selectStr - } else { - if session.statement.JoinStr == "" { - if columnStr == "" { - if session.statement.GroupByStr != "" { - columnStr = session.engine.quoteColumns(session.statement.GroupByStr) - } else { - columnStr = session.statement.genColumnStr() - } - } - } else { - if columnStr == "" { - if session.statement.GroupByStr != "" { - columnStr = session.engine.quoteColumns(session.statement.GroupByStr) - } else { - columnStr = "*" - } - } - } - if columnStr == "" { - columnStr = "*" - } - } - - session.statement.cond = session.statement.cond.And(autoCond) - condSQL, condArgs, err := builder.ToSQL(session.statement.cond) - if err != nil { - return err - } - - args = append(session.statement.joinArgs, condArgs...) - sqlStr, err = session.statement.genSelectSQL(columnStr, condSQL, true, true) - if err != nil { - return err - } - // for mssql and use limit - qs := strings.Count(sqlStr, "?") - if len(args)*2 == qs { - args = append(args, args...) - } - } else { - sqlStr = session.statement.RawSQL - args = session.statement.RawParams } - if session.canCache() { - if cacher := session.engine.getCacher(session.statement.TableName()); cacher != nil && + sqlStr, args, err := session.statement.GenFindSQL(autoCond) + if err != nil { + return err + } + + if session.statement.ColumnMap.IsEmpty() && session.canCache() { + if cacher := session.engine.GetCacher(session.statement.TableName()); cacher != nil && !session.statement.IsDistinct && - !session.statement.unscoped { + !session.statement.GetUnscoped() { err = session.cacheFind(sliceElementType, sqlStr, rowsSlicePtr, args...) if err != ErrCacheFailed { return err } err = nil // !nashtsai! reset err to nil for ErrCacheFailed - session.engine.logger.Warn("Cache Find Failed") + session.engine.logger.Warnf("Cache Find Failed") } } return session.noCacheFind(table, sliceValue, sqlStr, args...) } -func (session *Session) noCacheFind(table *core.Table, containerValue reflect.Value, sqlStr string, args ...interface{}) error { +func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect.Value, sqlStr string, args ...interface{}) error { rows, err := session.queryRows(sqlStr, args...) if err != nil { return err @@ -236,10 +196,10 @@ func (session *Session) noCacheFind(table *core.Table, containerValue reflect.Va return reflect.New(elemType) } - var containerValueSetFunc func(*reflect.Value, core.PK) error + var containerValueSetFunc func(*reflect.Value, schemas.PK) error if containerValue.Kind() == reflect.Slice { - containerValueSetFunc = func(newValue *reflect.Value, pk core.PK) error { + containerValueSetFunc = func(newValue *reflect.Value, pk schemas.PK) error { if isPointer { containerValue.Set(reflect.Append(containerValue, newValue.Elem().Addr())) } else { @@ -256,7 +216,7 @@ func (session *Session) noCacheFind(table *core.Table, containerValue reflect.Va return errors.New("don't support multiple primary key's map has non-slice key type") } - containerValueSetFunc = func(newValue *reflect.Value, pk core.PK) error { + containerValueSetFunc = func(newValue *reflect.Value, pk schemas.PK) error { keyValue := reflect.New(keyType) err := convertPKToValue(table, keyValue.Interface(), pk) if err != nil { @@ -273,8 +233,8 @@ func (session *Session) noCacheFind(table *core.Table, containerValue reflect.Va if elemType.Kind() == reflect.Struct { var newValue = newElemFunc(fields) - dataStruct := rValue(newValue.Interface()) - tb, err := session.engine.autoMapType(dataStruct) + dataStruct := utils.ReflectValue(newValue.Interface()) + tb, err := session.engine.tagParser.ParseWithCache(dataStruct) if err != nil { return err } @@ -310,7 +270,7 @@ func (session *Session) noCacheFind(table *core.Table, containerValue reflect.Va return nil } -func convertPKToValue(table *core.Table, dst interface{}, pk core.PK) error { +func convertPKToValue(table *schemas.Table, dst interface{}, pk schemas.PK) error { cols := table.PKColumns() if len(cols) == 1 { return convertAssign(dst, pk[0]) @@ -322,28 +282,28 @@ func convertPKToValue(table *core.Table, dst interface{}, pk core.PK) error { func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr interface{}, args ...interface{}) (err error) { if !session.canCache() || - indexNoCase(sqlStr, "having") != -1 || - indexNoCase(sqlStr, "group by") != -1 { + utils.IndexNoCase(sqlStr, "having") != -1 || + utils.IndexNoCase(sqlStr, "group by") != -1 { return ErrCacheFailed } tableName := session.statement.TableName() - cacher := session.engine.getCacher(tableName) + cacher := session.engine.cacherMgr.GetCacher(tableName) if cacher == nil { return nil } for _, filter := range session.engine.dialect.Filters() { - sqlStr = filter.Do(sqlStr, session.engine.dialect, session.statement.RefTable) + sqlStr = filter.Do(sqlStr) } - newsql := session.statement.convertIDSQL(sqlStr) + newsql := session.statement.ConvertIDSQL(sqlStr) if newsql == "" { return ErrCacheFailed } table := session.statement.RefTable - ids, err := core.GetCacheSql(cacher, tableName, newsql, args) + ids, err := caches.GetCacheSql(cacher, tableName, newsql, args) if err != nil { rows, err := session.queryRows(newsql, args...) if err != nil { @@ -352,11 +312,11 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in defer rows.Close() var i int - ids = make([]core.PK, 0) + ids = make([]schemas.PK, 0) for rows.Next() { i++ if i > 500 { - session.engine.logger.Debug("[cacheFind] ids length > 500, no cache") + session.engine.logger.Debugf("[cacheFind] ids length > 500, no cache") return ErrCacheFailed } var res = make([]string, len(table.PrimaryKeys)) @@ -364,9 +324,9 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in if err != nil { return err } - var pk core.PK = make([]interface{}, len(table.PrimaryKeys)) + var pk schemas.PK = make([]interface{}, len(table.PrimaryKeys)) for i, col := range table.PKColumns() { - pk[i], err = session.engine.idTypeAssertion(col, res[i]) + pk[i], err = col.ConvertID(res[i]) if err != nil { return err } @@ -375,19 +335,19 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in ids = append(ids, pk) } - session.engine.logger.Debug("[cacheFind] cache sql:", ids, tableName, sqlStr, newsql, args) - err = core.PutCacheSql(cacher, ids, tableName, newsql, args) + session.engine.logger.Debugf("[cache] cache sql: %v, %v, %v, %v, %v", ids, tableName, sqlStr, newsql, args) + err = caches.PutCacheSql(cacher, ids, tableName, newsql, args) if err != nil { return err } } else { - session.engine.logger.Debug("[cacheFind] cache hit sql:", tableName, sqlStr, newsql, args) + session.engine.logger.Debugf("[cache] cache hit sql: %v, %v, %v, %v", tableName, sqlStr, newsql, args) } sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) ididxes := make(map[string]int) - var ides []core.PK + var ides []schemas.PK var temps = make([]interface{}, len(ids)) for idx, id := range ids { @@ -396,20 +356,38 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in return err } bean := cacher.GetBean(tableName, sid) - if bean == nil || reflect.ValueOf(bean).Elem().Type() != t { + + // fix issue #894 + isHit := func() (ht bool) { + if bean == nil { + ht = false + return + } + ckb := reflect.ValueOf(bean).Elem().Type() + ht = ckb == t + if !ht && t.Kind() == reflect.Ptr { + ht = t.Elem() == ckb + } + return + } + if !isHit() { ides = append(ides, id) ididxes[sid] = idx } else { - session.engine.logger.Debug("[cacheFind] cache hit bean:", tableName, id, bean) + session.engine.logger.Debugf("[cache] cache hit bean: %v, %v, %v", tableName, id, bean) + + pk, err := table.IDOfV(reflect.ValueOf(bean)) + if err != nil { + return err + } - pk := session.engine.IdOf(bean) xid, err := pk.ToString() if err != nil { return err } if sid != xid { - session.engine.logger.Error("[cacheFind] error cache", xid, sid, bean) + session.engine.logger.Errorf("[cache] error cache: %v, %v, %v", xid, sid, bean) return ErrCacheFailed } temps[idx] = bean @@ -420,6 +398,12 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in slices := reflect.New(reflect.SliceOf(t)) beans := slices.Interface() + statement := session.statement + session.statement = statements.NewStatement( + session.engine.dialect, + session.engine.tagParser, + session.engine.DatabaseTZ, + ) if len(table.PrimaryKeys) == 1 { ff := make([]interface{}, 0, len(ides)) for _, ie := range ides { @@ -441,6 +425,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in if err != nil { return err } + session.statement = statement vs := reflect.Indirect(reflect.ValueOf(beans)) for i := 0; i < vs.Len(); i++ { @@ -448,7 +433,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in if rv.Kind() != reflect.Ptr { rv = rv.Addr() } - id, err := session.engine.idOfV(rv) + id, err := table.IDOfV(rv) if err != nil { return err } @@ -459,7 +444,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in bean := rv.Interface() temps[ididxes[sid]] = bean - session.engine.logger.Debug("[cacheFind] cache bean:", tableName, id, bean, temps) + session.engine.logger.Debugf("[cache] cache bean: %v, %v, %v, %v", tableName, id, bean, temps) cacher.PutBean(tableName, sid, bean) } } @@ -467,7 +452,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in for j := 0; j < len(temps); j++ { bean := temps[j] if bean == nil { - session.engine.logger.Warn("[cacheFind] cache no hit:", tableName, ids[j], temps) + session.engine.logger.Warnf("[cache] cache no hit: %v, %v, %v", tableName, ids[j], temps) // return errors.New("cache error") // !nashtsai! no need to return error, but continue instead continue } @@ -488,7 +473,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in } } else { if keyType.Kind() != reflect.Slice { - return errors.New("table have multiple primary keys, key is not core.PK or slice") + return errors.New("table have multiple primary keys, key is not schemas.PK or slice") } ikey = key } diff --git a/session_get.go b/session_get.go index cc0a2019..afedcd1f 100644 --- a/session_get.go +++ b/session_get.go @@ -11,7 +11,9 @@ import ( "reflect" "strconv" - "xorm.io/core" + "xorm.io/xorm/caches" + "xorm.io/xorm/internal/utils" + "xorm.io/xorm/schemas" ) // Get retrieve one record from database, bean's non-empty fields @@ -26,8 +28,8 @@ func (session *Session) Get(bean interface{}) (bool, error) { func (session *Session) get(bean interface{}) (bool, error) { defer session.resetStatement() - if session.statement.lastError != nil { - return false, session.statement.lastError + if session.statement.LastError != nil { + return false, session.statement.LastError } beanValue := reflect.ValueOf(bean) @@ -38,7 +40,7 @@ func (session *Session) get(bean interface{}) (bool, error) { } if beanValue.Elem().Kind() == reflect.Struct { - if err := session.statement.setRefBean(bean); err != nil { + if err := session.statement.SetRefBean(bean); err != nil { return false, err } } @@ -52,20 +54,20 @@ func (session *Session) get(bean interface{}) (bool, error) { return false, ErrTableNotFound } session.statement.Limit(1) - sqlStr, args, err = session.statement.genGetSQL(bean) + sqlStr, args, err = session.statement.GenGetSQL(bean) if err != nil { return false, err } } else { - sqlStr = session.statement.RawSQL + sqlStr = session.statement.GenRawSQL() args = session.statement.RawParams } table := session.statement.RefTable - if session.canCache() && beanValue.Elem().Kind() == reflect.Struct { - if cacher := session.engine.getCacher(session.statement.TableName()); cacher != nil && - !session.statement.unscoped { + if session.statement.ColumnMap.IsEmpty() && session.canCache() && beanValue.Elem().Kind() == reflect.Struct { + if cacher := session.engine.GetCacher(session.statement.TableName()); cacher != nil && + !session.statement.GetUnscoped() { has, err := session.cacheGet(bean, sqlStr, args...) if err != ErrCacheFailed { return has, err @@ -73,11 +75,11 @@ func (session *Session) get(bean interface{}) (bool, error) { } } - context := session.statement.context + context := session.statement.Context if context != nil { res := context.Get(fmt.Sprintf("%v-%v", sqlStr, args)) if res != nil { - session.engine.logger.Debug("hit context cache", sqlStr) + session.engine.logger.Debugf("hit context cache: %s", sqlStr) structValue := reflect.Indirect(reflect.ValueOf(bean)) structValue.Set(reflect.Indirect(reflect.ValueOf(res))) @@ -99,7 +101,7 @@ func (session *Session) get(bean interface{}) (bool, error) { return true, nil } -func (session *Session) nocacheGet(beanKind reflect.Kind, table *core.Table, bean interface{}, sqlStr string, args ...interface{}) (bool, error) { +func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table, bean interface{}, sqlStr string, args ...interface{}) (bool, error) { rows, err := session.queryRows(sqlStr, args...) if err != nil { return false, err @@ -240,10 +242,10 @@ func (session *Session) nocacheGet(beanKind reflect.Kind, table *core.Table, bea if err != nil { return false, err } - // close it before covert data + // close it before convert data rows.Close() - dataStruct := rValue(bean) + dataStruct := utils.ReflectValue(bean) _, err = session.slice2Bean(scanResults, fields, bean, &dataStruct, table) if err != nil { return true, err @@ -271,19 +273,19 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf } for _, filter := range session.engine.dialect.Filters() { - sqlStr = filter.Do(sqlStr, session.engine.dialect, session.statement.RefTable) + sqlStr = filter.Do(sqlStr) } - newsql := session.statement.convertIDSQL(sqlStr) + newsql := session.statement.ConvertIDSQL(sqlStr) if newsql == "" { return false, ErrCacheFailed } tableName := session.statement.TableName() - cacher := session.engine.getCacher(tableName) + cacher := session.engine.cacherMgr.GetCacher(tableName) - session.engine.logger.Debug("[cacheGet] find sql:", newsql, args) + session.engine.logger.Debugf("[cache] Get SQL: %s, %v", newsql, args) table := session.statement.RefTable - ids, err := core.GetCacheSql(cacher, tableName, newsql, args) + ids, err := caches.GetCacheSql(cacher, tableName, newsql, args) if err != nil { var res = make([]string, len(table.PrimaryKeys)) rows, err := session.NoCache().queryRows(newsql, args...) @@ -301,7 +303,7 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf return false, ErrCacheFailed } - var pk core.PK = make([]interface{}, len(table.PrimaryKeys)) + var pk schemas.PK = make([]interface{}, len(table.PrimaryKeys)) for i, col := range table.PKColumns() { if col.SQLType.IsText() { pk[i] = res[i] @@ -316,20 +318,20 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf } } - ids = []core.PK{pk} - session.engine.logger.Debug("[cacheGet] cache ids:", newsql, ids) - err = core.PutCacheSql(cacher, ids, tableName, newsql, args) + ids = []schemas.PK{pk} + session.engine.logger.Debugf("[cache] cache ids: %s, %v", newsql, ids) + err = caches.PutCacheSql(cacher, ids, tableName, newsql, args) if err != nil { return false, err } } else { - session.engine.logger.Debug("[cacheGet] cache hit sql:", newsql, ids) + session.engine.logger.Debugf("[cache] cache hit: %s, %v", newsql, ids) } if len(ids) > 0 { structValue := reflect.Indirect(reflect.ValueOf(bean)) id := ids[0] - session.engine.logger.Debug("[cacheGet] get bean:", tableName, id) + session.engine.logger.Debugf("[cache] get bean: %s, %v", tableName, id) sid, err := id.ToString() if err != nil { return false, err @@ -342,10 +344,10 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf return has, err } - session.engine.logger.Debug("[cacheGet] cache bean:", tableName, id, cacheBean) + session.engine.logger.Debugf("[cache] cache bean: %s, %v, %v", tableName, id, cacheBean) cacher.PutBean(tableName, sid, cacheBean) } else { - session.engine.logger.Debug("[cacheGet] cache hit bean:", tableName, id, cacheBean) + session.engine.logger.Debugf("[cache] cache hit: %s, %v, %v", tableName, id, cacheBean) has = true } structValue.Set(reflect.Indirect(reflect.ValueOf(cacheBean))) diff --git a/session_insert.go b/session_insert.go index 1e19ce7a..5f968151 100644 --- a/session_insert.go +++ b/session_insert.go @@ -12,10 +12,13 @@ import ( "strconv" "strings" - "xorm.io/builder" - "xorm.io/core" + "xorm.io/xorm/internal/utils" + "xorm.io/xorm/schemas" ) +// ErrNoElementsOnSlice represents an error there is no element when insert +var ErrNoElementsOnSlice = errors.New("No element on slice when insert") + // Insert insert one or more beans func (session *Session) Insert(beans ...interface{}) (int64, error) { var affected int64 @@ -67,23 +70,15 @@ func (session *Session) Insert(beans ...interface{}) (int64, error) { sliceValue := reflect.Indirect(reflect.ValueOf(bean)) if sliceValue.Kind() == reflect.Slice { size := sliceValue.Len() - if size > 0 { - if session.engine.SupportInsertMany() { - cnt, err := session.innerInsertMulti(bean) - if err != nil { - return affected, err - } - affected += cnt - } else { - for i := 0; i < size; i++ { - cnt, err := session.innerInsert(sliceValue.Index(i).Interface()) - if err != nil { - return affected, err - } - affected += cnt - } - } + if size <= 0 { + return 0, ErrNoElementsOnSlice } + + cnt, err := session.innerInsertMulti(bean) + if err != nil { + return affected, err + } + affected += cnt } else { cnt, err := session.innerInsert(bean) if err != nil { @@ -107,7 +102,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error return 0, errors.New("could not insert a empty slice") } - if err := session.statement.setRefBean(sliceValue.Index(0).Interface()); err != nil { + if err := session.statement.SetRefBean(sliceValue.Index(0).Interface()); err != nil { return 0, err } @@ -116,17 +111,24 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error return 0, ErrTableNotFound } - table := session.statement.RefTable - size := sliceValue.Len() - - var colNames []string - var colMultiPlaces []string - var args []interface{} - var cols []*core.Column + var ( + table = session.statement.RefTable + size = sliceValue.Len() + colNames []string + colMultiPlaces []string + args []interface{} + cols []*schemas.Column + ) for i := 0; i < size; i++ { v := sliceValue.Index(i) - vv := reflect.Indirect(v) + var vv reflect.Value + switch v.Kind() { + case reflect.Interface: + vv = reflect.Indirect(v.Elem()) + default: + vv = reflect.Indirect(v) + } elemValue := v.Interface() var colPlaces []string @@ -141,123 +143,77 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error } // -- - if i == 0 { - for _, col := range table.Columns() { - ptrFieldValue, err := col.ValueOfV(&vv) + for _, col := range table.Columns() { + ptrFieldValue, err := col.ValueOfV(&vv) + if err != nil { + return 0, err + } + fieldValue := *ptrFieldValue + if col.IsAutoIncrement && utils.IsZero(fieldValue.Interface()) { + continue + } + if col.MapType == schemas.ONLYFROMDB { + continue + } + if col.IsDeleted { + continue + } + if session.statement.OmitColumnMap.Contain(col.Name) { + continue + } + if len(session.statement.ColumnMap) > 0 && !session.statement.ColumnMap.Contain(col.Name) { + continue + } + if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime { + val, t := session.engine.nowTime(col) + args = append(args, val) + + var colName = col.Name + session.afterClosures = append(session.afterClosures, func(bean interface{}) { + col := table.GetColumn(colName) + setColumnTime(bean, col, t) + }) + } else if col.IsVersion && session.statement.CheckVersion { + args = append(args, 1) + var colName = col.Name + session.afterClosures = append(session.afterClosures, func(bean interface{}) { + col := table.GetColumn(colName) + setColumnInt(bean, col, 1) + }) + } else { + arg, err := session.statement.Value2Interface(col, fieldValue) if err != nil { return 0, err } - fieldValue := *ptrFieldValue - if col.IsAutoIncrement && isZero(fieldValue.Interface()) { - continue - } - if col.MapType == core.ONLYFROMDB { - continue - } - if col.IsDeleted { - continue - } - if session.statement.omitColumnMap.contain(col.Name) { - continue - } - if len(session.statement.columnMap) > 0 && !session.statement.columnMap.contain(col.Name) { - continue - } - if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime { - val, t := session.engine.nowTime(col) - args = append(args, val) - - var colName = col.Name - session.afterClosures = append(session.afterClosures, func(bean interface{}) { - col := table.GetColumn(colName) - setColumnTime(bean, col, t) - }) - } else if col.IsVersion && session.statement.checkVersion { - args = append(args, 1) - var colName = col.Name - session.afterClosures = append(session.afterClosures, func(bean interface{}) { - col := table.GetColumn(colName) - setColumnInt(bean, col, 1) - }) - } else { - arg, err := session.value2Interface(col, fieldValue) - if err != nil { - return 0, err - } - args = append(args, arg) - } + args = append(args, arg) + } + if i == 0 { colNames = append(colNames, col.Name) cols = append(cols, col) - colPlaces = append(colPlaces, "?") - } - } else { - for _, col := range cols { - ptrFieldValue, err := col.ValueOfV(&vv) - if err != nil { - return 0, err - } - fieldValue := *ptrFieldValue - - if col.IsAutoIncrement && isZero(fieldValue.Interface()) { - continue - } - if col.MapType == core.ONLYFROMDB { - continue - } - if col.IsDeleted { - continue - } - if session.statement.omitColumnMap.contain(col.Name) { - continue - } - if len(session.statement.columnMap) > 0 && !session.statement.columnMap.contain(col.Name) { - continue - } - if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime { - val, t := session.engine.nowTime(col) - args = append(args, val) - - var colName = col.Name - session.afterClosures = append(session.afterClosures, func(bean interface{}) { - col := table.GetColumn(colName) - setColumnTime(bean, col, t) - }) - } else if col.IsVersion && session.statement.checkVersion { - args = append(args, 1) - var colName = col.Name - session.afterClosures = append(session.afterClosures, func(bean interface{}) { - col := table.GetColumn(colName) - setColumnInt(bean, col, 1) - }) - } else { - arg, err := session.value2Interface(col, fieldValue) - if err != nil { - return 0, err - } - args = append(args, arg) - } - - colPlaces = append(colPlaces, "?") } + colPlaces = append(colPlaces, "?") } + colMultiPlaces = append(colMultiPlaces, strings.Join(colPlaces, ", ")) } cleanupProcessorsClosures(&session.beforeClosures) + quoter := session.engine.dialect.Quoter() var sql string - if session.engine.dialect.DBType() == core.ORACLE { + colStr := quoter.Join(colNames, ",") + if session.engine.dialect.URI().DBType == schemas.ORACLE { temp := fmt.Sprintf(") INTO %s (%v) VALUES (", - session.engine.Quote(tableName), - quoteColumns(colNames, session.engine.Quote, ",")) + quoter.Quote(tableName), + colStr) sql = fmt.Sprintf("INSERT ALL INTO %s (%v) VALUES (%v) SELECT 1 FROM DUAL", - session.engine.Quote(tableName), - quoteColumns(colNames, session.engine.Quote, ","), + quoter.Quote(tableName), + colStr, strings.Join(colMultiPlaces, temp)) } else { sql = fmt.Sprintf("INSERT INTO %s (%v) VALUES (%v)", - session.engine.Quote(tableName), - quoteColumns(colNames, session.engine.Quote, ","), + quoter.Quote(tableName), + colStr, strings.Join(colMultiPlaces, "),(")) } res, err := session.exec(sql, args...) @@ -277,7 +233,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error for _, closure := range session.afterClosures { closure(elemValue) } - if processor, ok := interface{}(elemValue).(AfterInsertProcessor); ok { + if processor, ok := elemValue.(AfterInsertProcessor); ok { processor.AfterInsert() } } else { @@ -290,7 +246,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error session.afterInsertBeans[elemValue] = &afterClosures } } else { - if _, ok := interface{}(elemValue).(AfterInsertProcessor); ok { + if _, ok := elemValue.(AfterInsertProcessor); ok { session.afterInsertBeans[elemValue] = nil } } @@ -309,27 +265,24 @@ func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) { sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) if sliceValue.Kind() != reflect.Slice { - return 0, ErrParamsType - + return 0, ErrPtrSliceType } if sliceValue.Len() <= 0 { - return 0, nil + return 0, ErrNoElementsOnSlice } return session.innerInsertMulti(rowsSlicePtr) } func (session *Session) innerInsert(bean interface{}) (int64, error) { - if err := session.statement.setRefBean(bean); err != nil { + if err := session.statement.SetRefBean(bean); err != nil { return 0, err } if len(session.statement.TableName()) <= 0 { return 0, ErrTableNotFound } - table := session.statement.RefTable - // handle BeforeInsertProcessor for _, closure := range session.beforeClosures { closure(bean) @@ -340,100 +293,19 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { processor.BeforeInsert() } + var tableName = session.statement.TableName() + table := session.statement.RefTable + colNames, args, err := session.genInsertColumns(bean) if err != nil { return 0, err } - exprs := session.statement.exprColumns - colPlaces := strings.Repeat("?, ", len(colNames)) - if exprs.Len() <= 0 && len(colPlaces) > 0 { - colPlaces = colPlaces[0 : len(colPlaces)-2] - } - - var tableName = session.statement.TableName() - var output string - if session.engine.dialect.DBType() == core.MSSQL && len(table.AutoIncrement) > 0 { - output = fmt.Sprintf(" OUTPUT Inserted.%s", table.AutoIncrement) - } - - var buf = builder.NewWriter() - if _, err := buf.WriteString(fmt.Sprintf("INSERT INTO %s", session.engine.Quote(tableName))); err != nil { + sqlStr, args, err := session.statement.GenInsertSQL(colNames, args) + if err != nil { return 0, err } - if len(colPlaces) <= 0 { - if session.engine.dialect.DBType() == core.MYSQL { - if _, err := buf.WriteString(" VALUES ()"); err != nil { - return 0, err - } - } else { - if _, err := buf.WriteString(fmt.Sprintf("%s DEFAULT VALUES", output)); err != nil { - return 0, err - } - } - } else { - if _, err := buf.WriteString(" ("); err != nil { - return 0, err - } - - if err := writeStrings(buf, append(colNames, exprs.colNames...), "`", "`"); err != nil { - return 0, err - } - - if session.statement.cond.IsValid() { - if _, err := buf.WriteString(fmt.Sprintf(")%s SELECT ", output)); err != nil { - return 0, err - } - - if err := session.statement.writeArgs(buf, args); err != nil { - return 0, err - } - - if len(exprs.args) > 0 { - if _, err := buf.WriteString(","); err != nil { - return 0, err - } - } - if err := exprs.writeArgs(buf); err != nil { - return 0, err - } - - if _, err := buf.WriteString(fmt.Sprintf(" FROM %v WHERE ", session.engine.Quote(tableName))); err != nil { - return 0, err - } - - if err := session.statement.cond.WriteTo(buf); err != nil { - return 0, err - } - } else { - buf.Append(args...) - - if _, err := buf.WriteString(fmt.Sprintf(")%s VALUES (%v", - output, - colPlaces)); err != nil { - return 0, err - } - - if err := exprs.writeArgs(buf); err != nil { - return 0, err - } - - if _, err := buf.WriteString(")"); err != nil { - return 0, err - } - } - } - - if len(table.AutoIncrement) > 0 && session.engine.dialect.DBType() == core.POSTGRES { - if _, err := buf.WriteString(" RETURNING " + session.engine.Quote(table.AutoIncrement)); err != nil { - return 0, err - } - } - - sqlStr := buf.String() - args = buf.Args() - handleAfterInsertProcessorFunc := func(bean interface{}) { if session.isAutoCommit { for _, closure := range session.afterClosures { @@ -464,7 +336,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { // for postgres, many of them didn't implement lastInsertId, so we should // implemented it ourself. - if session.engine.dialect.DBType() == core.ORACLE && len(table.AutoIncrement) > 0 { + if session.engine.dialect.URI().DBType == schemas.ORACLE && len(table.AutoIncrement) > 0 { res, err := session.queryBytes("select seq_atable.currval from dual", args...) if err != nil { return 0, err @@ -474,10 +346,10 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { session.cacheInsert(tableName) - if table.Version != "" && session.statement.checkVersion { + if table.Version != "" && session.statement.CheckVersion { verValue, err := table.VersionColumn().ValueOf(bean) if err != nil { - session.engine.logger.Error(err) + session.engine.logger.Errorf("%v", err) } else if verValue.IsValid() && verValue.CanSet() { session.incrVersionFieldValue(verValue) } @@ -495,7 +367,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { aiValue, err := table.AutoIncrColumn().ValueOf(bean) if err != nil { - session.engine.logger.Error(err) + session.engine.logger.Errorf("%v", err) } if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() { @@ -505,7 +377,8 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { aiValue.Set(int64ToIntValue(id, aiValue.Type())) return 1, nil - } else if len(table.AutoIncrement) > 0 && (session.engine.dialect.DBType() == core.POSTGRES || session.engine.dialect.DBType() == core.MSSQL) { + } else if len(table.AutoIncrement) > 0 && (session.engine.dialect.URI().DBType == schemas.POSTGRES || + session.engine.dialect.URI().DBType == schemas.MSSQL) { res, err := session.queryBytes(sqlStr, args...) if err != nil { @@ -515,10 +388,10 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { session.cacheInsert(tableName) - if table.Version != "" && session.statement.checkVersion { + if table.Version != "" && session.statement.CheckVersion { verValue, err := table.VersionColumn().ValueOf(bean) if err != nil { - session.engine.logger.Error(err) + session.engine.logger.Errorf("%v", err) } else if verValue.IsValid() && verValue.CanSet() { session.incrVersionFieldValue(verValue) } @@ -536,7 +409,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { aiValue, err := table.AutoIncrColumn().ValueOf(bean) if err != nil { - session.engine.logger.Error(err) + session.engine.logger.Errorf("%v", err) } if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() { @@ -546,48 +419,48 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { aiValue.Set(int64ToIntValue(id, aiValue.Type())) return 1, nil - } else { - res, err := session.exec(sqlStr, args...) + } + + res, err := session.exec(sqlStr, args...) + if err != nil { + return 0, err + } + + defer handleAfterInsertProcessorFunc(bean) + + session.cacheInsert(tableName) + + if table.Version != "" && session.statement.CheckVersion { + verValue, err := table.VersionColumn().ValueOf(bean) if err != nil { - return 0, err + session.engine.logger.Errorf("%v", err) + } else if verValue.IsValid() && verValue.CanSet() { + session.incrVersionFieldValue(verValue) } + } - defer handleAfterInsertProcessorFunc(bean) - - session.cacheInsert(tableName) - - if table.Version != "" && session.statement.checkVersion { - verValue, err := table.VersionColumn().ValueOf(bean) - if err != nil { - session.engine.logger.Error(err) - } else if verValue.IsValid() && verValue.CanSet() { - session.incrVersionFieldValue(verValue) - } - } - - if table.AutoIncrement == "" { - return res.RowsAffected() - } - - var id int64 - id, err = res.LastInsertId() - if err != nil || id <= 0 { - return res.RowsAffected() - } - - aiValue, err := table.AutoIncrColumn().ValueOf(bean) - if err != nil { - session.engine.logger.Error(err) - } - - if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() { - return res.RowsAffected() - } - - aiValue.Set(int64ToIntValue(id, aiValue.Type())) - + if table.AutoIncrement == "" { return res.RowsAffected() } + + var id int64 + id, err = res.LastInsertId() + if err != nil || id <= 0 { + return res.RowsAffected() + } + + aiValue, err := table.AutoIncrColumn().ValueOf(bean) + if err != nil { + session.engine.logger.Errorf("%v", err) + } + + if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() { + return res.RowsAffected() + } + + aiValue.Set(int64ToIntValue(id, aiValue.Type())) + + return res.RowsAffected() } // InsertOne insert only one struct into database as a record. @@ -605,11 +478,11 @@ func (session *Session) cacheInsert(table string) error { if !session.statement.UseCache { return nil } - cacher := session.engine.getCacher(table) + cacher := session.engine.cacherMgr.GetCacher(table) if cacher == nil { return nil } - session.engine.logger.Debug("[cache] clear sql:", table) + session.engine.logger.Debugf("[cache] clear SQL: %v", table) cacher.ClearIds(table) return nil } @@ -621,7 +494,7 @@ func (session *Session) genInsertColumns(bean interface{}) ([]string, []interfac args := make([]interface{}, 0, len(table.ColumnsSeq())) for _, col := range table.Columns() { - if col.MapType == core.ONLYFROMDB { + if col.MapType == schemas.ONLYFROMDB { continue } @@ -629,19 +502,19 @@ func (session *Session) genInsertColumns(bean interface{}) ([]string, []interfac continue } - if session.statement.omitColumnMap.contain(col.Name) { + if session.statement.OmitColumnMap.Contain(col.Name) { continue } - if len(session.statement.columnMap) > 0 && !session.statement.columnMap.contain(col.Name) { + if len(session.statement.ColumnMap) > 0 && !session.statement.ColumnMap.Contain(col.Name) { continue } - if session.statement.incrColumns.isColExist(col.Name) { + if session.statement.IncrColumns.IsColExist(col.Name) { continue - } else if session.statement.decrColumns.isColExist(col.Name) { + } else if session.statement.DecrColumns.IsColExist(col.Name) { continue - } else if session.statement.exprColumns.isColExist(col.Name) { + } else if session.statement.ExprColumns.IsColExist(col.Name) { continue } @@ -651,30 +524,13 @@ func (session *Session) genInsertColumns(bean interface{}) ([]string, []interfac } fieldValue := *fieldValuePtr - if col.IsAutoIncrement { - switch fieldValue.Type().Kind() { - case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int, reflect.Int64: - if fieldValue.Int() == 0 { - continue - } - case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint, reflect.Uint64: - if fieldValue.Uint() == 0 { - continue - } - case reflect.String: - if len(fieldValue.String()) == 0 { - continue - } - case reflect.Ptr: - if fieldValue.Pointer() == 0 { - continue - } - } + if col.IsAutoIncrement && utils.IsValueZero(fieldValue) { + continue } // !evalphobia! set fieldValue as nil when column is nullable and zero-value - if _, ok := getFlagForColumn(session.statement.nullableMap, col); ok { - if col.Nullable && isZero(fieldValue.Interface()) { + if _, ok := getFlagForColumn(session.statement.NullableMap, col); ok { + if col.Nullable && utils.IsValueZero(fieldValue) { var nilValue *int fieldValue = reflect.ValueOf(nilValue) } @@ -690,10 +546,10 @@ func (session *Session) genInsertColumns(bean interface{}) ([]string, []interfac col := table.GetColumn(colName) setColumnTime(bean, col, t) }) - } else if col.IsVersion && session.statement.checkVersion { + } else if col.IsVersion && session.statement.CheckVersion { args = append(args, 1) } else { - arg, err := session.value2Interface(col, fieldValue) + arg, err := session.statement.Value2Interface(col, fieldValue) if err != nil { return colNames, args, err } @@ -716,9 +572,9 @@ func (session *Session) insertMapInterface(m map[string]interface{}) (int64, err } var columns = make([]string, 0, len(m)) - exprs := session.statement.exprColumns + exprs := session.statement.ExprColumns for k := range m { - if !exprs.isColExist(k) { + if !exprs.IsColExist(k) { columns = append(columns, k) } } @@ -729,66 +585,7 @@ func (session *Session) insertMapInterface(m map[string]interface{}) (int64, err args = append(args, m[colName]) } - w := builder.NewWriter() - if session.statement.cond.IsValid() { - if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (", session.engine.Quote(tableName))); err != nil { - return 0, err - } - - if err := writeStrings(w, append(columns, exprs.colNames...), "`", "`"); err != nil { - return 0, err - } - - if _, err := w.WriteString(") SELECT "); err != nil { - return 0, err - } - - if err := session.statement.writeArgs(w, args); err != nil { - return 0, err - } - - if len(exprs.args) > 0 { - if _, err := w.WriteString(","); err != nil { - return 0, err - } - if err := exprs.writeArgs(w); err != nil { - return 0, err - } - } - - if _, err := w.WriteString(fmt.Sprintf(" FROM %s WHERE ", session.engine.Quote(tableName))); err != nil { - return 0, err - } - - if err := session.statement.cond.WriteTo(w); err != nil { - return 0, err - } - } else { - qm := strings.Repeat("?,", len(columns)) - qm = qm[:len(qm)-1] - - if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (`%s`) VALUES (%s)", session.engine.Quote(tableName), strings.Join(columns, "`,`"), qm)); err != nil { - return 0, err - } - w.Append(args...) - } - - sql := w.String() - args = w.Args() - - if err := session.cacheInsert(tableName); err != nil { - return 0, err - } - - res, err := session.exec(sql, args...) - if err != nil { - return 0, err - } - affected, err := res.RowsAffected() - if err != nil { - return 0, err - } - return affected, nil + return session.insertMap(columns, args) } func (session *Session) insertMapString(m map[string]string) (int64, error) { @@ -802,12 +599,13 @@ func (session *Session) insertMapString(m map[string]string) (int64, error) { } var columns = make([]string, 0, len(m)) - exprs := session.statement.exprColumns + exprs := session.statement.ExprColumns for k := range m { - if !exprs.isColExist(k) { + if !exprs.IsColExist(k) { columns = append(columns, k) } } + sort.Strings(columns) var args = make([]interface{}, 0, len(m)) @@ -815,52 +613,19 @@ func (session *Session) insertMapString(m map[string]string) (int64, error) { args = append(args, m[colName]) } - w := builder.NewWriter() - if session.statement.cond.IsValid() { - if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (", session.engine.Quote(tableName))); err != nil { - return 0, err - } + return session.insertMap(columns, args) +} - if err := writeStrings(w, append(columns, exprs.colNames...), "`", "`"); err != nil { - return 0, err - } - - if _, err := w.WriteString(") SELECT "); err != nil { - return 0, err - } - - if err := session.statement.writeArgs(w, args); err != nil { - return 0, err - } - - if len(exprs.args) > 0 { - if _, err := w.WriteString(","); err != nil { - return 0, err - } - if err := exprs.writeArgs(w); err != nil { - return 0, err - } - } - - if _, err := w.WriteString(fmt.Sprintf(" FROM %s WHERE ", session.engine.Quote(tableName))); err != nil { - return 0, err - } - - if err := session.statement.cond.WriteTo(w); err != nil { - return 0, err - } - } else { - qm := strings.Repeat("?,", len(columns)) - qm = qm[:len(qm)-1] - - if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (`%s`) VALUES (%s)", session.engine.Quote(tableName), strings.Join(columns, "`,`"), qm)); err != nil { - return 0, err - } - w.Append(args...) +func (session *Session) insertMap(columns []string, args []interface{}) (int64, error) { + tableName := session.statement.TableName() + if len(tableName) <= 0 { + return 0, ErrTableNotFound } - sql := w.String() - args = w.Args() + sql, args, err := session.statement.GenInsertMapSQL(columns, args) + if err != nil { + return 0, err + } if err := session.cacheInsert(tableName); err != nil { return 0, err diff --git a/session_iterate.go b/session_iterate.go index ca996c28..8cab8f48 100644 --- a/session_iterate.go +++ b/session_iterate.go @@ -4,7 +4,11 @@ package xorm -import "reflect" +import ( + "reflect" + + "xorm.io/xorm/internal/utils" +) // IterFunc only use by Iterate type IterFunc func(idx int, bean interface{}) error @@ -23,11 +27,11 @@ func (session *Session) Iterate(bean interface{}, fun IterFunc) error { defer session.Close() } - if session.statement.lastError != nil { - return session.statement.lastError + if session.statement.LastError != nil { + return session.statement.LastError } - if session.statement.bufferSize > 0 { + if session.statement.BufferSize > 0 { return session.bufferIterate(bean, fun) } @@ -55,27 +59,28 @@ func (session *Session) Iterate(bean interface{}, fun IterFunc) error { // BufferSize sets the buffersize for iterate func (session *Session) BufferSize(size int) *Session { - session.statement.bufferSize = size + session.statement.BufferSize = size return session } func (session *Session) bufferIterate(bean interface{}, fun IterFunc) error { - if session.isAutoClose { - defer session.Close() - } - - var bufferSize = session.statement.bufferSize - var limit = session.statement.LimitN - if limit > 0 && bufferSize > limit { - bufferSize = limit + var bufferSize = session.statement.BufferSize + var pLimitN = session.statement.LimitN + if pLimitN != nil && bufferSize > *pLimitN { + bufferSize = *pLimitN } var start = session.statement.Start - v := rValue(bean) + v := utils.ReflectValue(bean) sliceType := reflect.SliceOf(v.Type()) var idx = 0 - for { + session.autoResetStatement = false + defer func() { + session.autoResetStatement = true + }() + + for bufferSize > 0 { slice := reflect.New(sliceType) - if err := session.Limit(bufferSize, start).find(slice.Interface(), bean); err != nil { + if err := session.NoCache().Limit(bufferSize, start).find(slice.Interface(), bean); err != nil { return err } @@ -86,13 +91,13 @@ func (session *Session) bufferIterate(bean interface{}, fun IterFunc) error { idx++ } - start = start + slice.Elem().Len() - if limit > 0 && idx+bufferSize > limit { - bufferSize = limit - idx + if bufferSize > slice.Elem().Len() { + break } - if bufferSize <= 0 || slice.Elem().Len() < bufferSize || idx == limit { - break + start = start + slice.Elem().Len() + if pLimitN != nil && start+bufferSize > *pLimitN { + bufferSize = *pLimitN - start } } diff --git a/session_pk_test.go b/session_pk_test.go deleted file mode 100644 index 4c066634..00000000 --- a/session_pk_test.go +++ /dev/null @@ -1,1197 +0,0 @@ -// Copyright 2017 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package xorm - -import ( - "errors" - "testing" - "time" - - "xorm.io/core" - "github.com/stretchr/testify/assert" -) - -type IntId struct { - Id int `xorm:"pk autoincr"` - Name string -} - -type Int16Id struct { - Id int16 `xorm:"pk autoincr"` - Name string -} - -type Int32Id struct { - Id int32 `xorm:"pk autoincr"` - Name string -} - -type UintId struct { - Id uint `xorm:"pk autoincr"` - Name string -} - -type Uint16Id struct { - Id uint16 `xorm:"pk autoincr"` - Name string -} - -type Uint32Id struct { - Id uint32 `xorm:"pk autoincr"` - Name string -} - -type Uint64Id struct { - Id uint64 `xorm:"pk autoincr"` - Name string -} - -type StringPK struct { - Id string `xorm:"pk notnull"` - Name string -} - -type ID int64 -type MyIntPK struct { - ID ID `xorm:"pk autoincr"` - Name string -} - -type StrID string -type MyStringPK struct { - ID StrID `xorm:"pk notnull"` - Name string -} - -func TestIntId(t *testing.T) { - assert.NoError(t, prepareEngine()) - - err := testEngine.DropTables(&IntId{}) - if err != nil { - t.Error(err) - panic(err) - } - - err = testEngine.CreateTables(&IntId{}) - if err != nil { - t.Error(err) - panic(err) - } - - cnt, err := testEngine.Insert(&IntId{Name: "test"}) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("insert count should be one") - t.Error(err) - panic(err) - } - - bean := new(IntId) - has, err := testEngine.Get(bean) - if err != nil { - t.Error(err) - panic(err) - } - if !has { - err = errors.New("get count should be one") - t.Error(err) - panic(err) - } - - beans := make([]IntId, 0) - err = testEngine.Find(&beans) - if err != nil { - t.Error(err) - panic(err) - } - if len(beans) != 1 { - err = errors.New("get count should be one") - t.Error(err) - panic(err) - } - - beans2 := make(map[int]IntId) - err = testEngine.Find(&beans2) - if err != nil { - t.Error(err) - panic(err) - } - if len(beans2) != 1 { - err = errors.New("get count should be one") - t.Error(err) - panic(err) - } - - cnt, err = testEngine.ID(bean.Id).Delete(&IntId{}) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("insert count should be one") - t.Error(err) - panic(err) - } -} - -func TestInt16Id(t *testing.T) { - assert.NoError(t, prepareEngine()) - - err := testEngine.DropTables(&Int16Id{}) - if err != nil { - t.Error(err) - panic(err) - } - - err = testEngine.CreateTables(&Int16Id{}) - if err != nil { - t.Error(err) - panic(err) - } - - cnt, err := testEngine.Insert(&Int16Id{Name: "test"}) - if err != nil { - t.Error(err) - panic(err) - } - - if cnt != 1 { - err = errors.New("insert count should be one") - t.Error(err) - panic(err) - } - - bean := new(Int16Id) - has, err := testEngine.Get(bean) - if err != nil { - t.Error(err) - panic(err) - } - if !has { - err = errors.New("get count should be one") - t.Error(err) - panic(err) - } - - beans := make([]Int16Id, 0) - err = testEngine.Find(&beans) - if err != nil { - t.Error(err) - panic(err) - } - if len(beans) != 1 { - err = errors.New("get count should be one") - t.Error(err) - panic(err) - } - - beans2 := make(map[int16]Int16Id, 0) - err = testEngine.Find(&beans2) - if err != nil { - t.Error(err) - panic(err) - } - if len(beans2) != 1 { - err = errors.New("get count should be one") - t.Error(err) - panic(err) - } - - cnt, err = testEngine.ID(bean.Id).Delete(&Int16Id{}) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("insert count should be one") - t.Error(err) - panic(err) - } -} - -func TestInt32Id(t *testing.T) { - assert.NoError(t, prepareEngine()) - - err := testEngine.DropTables(&Int32Id{}) - if err != nil { - t.Error(err) - panic(err) - } - - err = testEngine.CreateTables(&Int32Id{}) - if err != nil { - t.Error(err) - panic(err) - } - - cnt, err := testEngine.Insert(&Int32Id{Name: "test"}) - if err != nil { - t.Error(err) - panic(err) - } - - if cnt != 1 { - err = errors.New("insert count should be one") - t.Error(err) - panic(err) - } - - bean := new(Int32Id) - has, err := testEngine.Get(bean) - if err != nil { - t.Error(err) - panic(err) - } - if !has { - err = errors.New("get count should be one") - t.Error(err) - panic(err) - } - - beans := make([]Int32Id, 0) - err = testEngine.Find(&beans) - if err != nil { - t.Error(err) - panic(err) - } - if len(beans) != 1 { - err = errors.New("get count should be one") - t.Error(err) - panic(err) - } - - beans2 := make(map[int32]Int32Id, 0) - err = testEngine.Find(&beans2) - if err != nil { - t.Error(err) - panic(err) - } - if len(beans2) != 1 { - err = errors.New("get count should be one") - t.Error(err) - panic(err) - } - - cnt, err = testEngine.ID(bean.Id).Delete(&Int32Id{}) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("insert count should be one") - t.Error(err) - panic(err) - } -} - -func TestUintId(t *testing.T) { - assert.NoError(t, prepareEngine()) - - err := testEngine.DropTables(&UintId{}) - if err != nil { - t.Error(err) - panic(err) - } - - err = testEngine.CreateTables(&UintId{}) - if err != nil { - t.Error(err) - panic(err) - } - - cnt, err := testEngine.Insert(&UintId{Name: "test"}) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("insert count should be one") - t.Error(err) - panic(err) - } - - var inserts = []UintId{ - {Name: "test1"}, - {Name: "test2"}, - } - cnt, err = testEngine.Insert(&inserts) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 2 { - err = errors.New("insert count should be two") - t.Error(err) - panic(err) - } - - bean := new(UintId) - has, err := testEngine.Get(bean) - if err != nil { - t.Error(err) - panic(err) - } - if !has { - err = errors.New("get count should be one") - t.Error(err) - panic(err) - } - - beans := make([]UintId, 0) - err = testEngine.Find(&beans) - if err != nil { - t.Error(err) - panic(err) - } - if len(beans) != 3 { - err = errors.New("get count should be three") - t.Error(err) - panic(err) - } - - beans2 := make(map[uint]UintId, 0) - err = testEngine.Find(&beans2) - if err != nil { - t.Error(err) - panic(err) - } - if len(beans2) != 3 { - err = errors.New("get count should be three") - t.Error(err) - panic(err) - } - - cnt, err = testEngine.ID(bean.Id).Delete(&UintId{}) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("insert count should be one") - t.Error(err) - panic(err) - } -} - -func TestUint16Id(t *testing.T) { - assert.NoError(t, prepareEngine()) - - err := testEngine.DropTables(&Uint16Id{}) - if err != nil { - t.Error(err) - panic(err) - } - - err = testEngine.CreateTables(&Uint16Id{}) - if err != nil { - t.Error(err) - panic(err) - } - - cnt, err := testEngine.Insert(&Uint16Id{Name: "test"}) - if err != nil { - t.Error(err) - panic(err) - } - - if cnt != 1 { - err = errors.New("insert count should be one") - t.Error(err) - panic(err) - } - - bean := new(Uint16Id) - has, err := testEngine.Get(bean) - if err != nil { - t.Error(err) - panic(err) - } - if !has { - err = errors.New("get count should be one") - t.Error(err) - panic(err) - } - - beans := make([]Uint16Id, 0) - err = testEngine.Find(&beans) - if err != nil { - t.Error(err) - panic(err) - } - if len(beans) != 1 { - err = errors.New("get count should be one") - t.Error(err) - panic(err) - } - - beans2 := make(map[uint16]Uint16Id, 0) - err = testEngine.Find(&beans2) - if err != nil { - t.Error(err) - panic(err) - } - if len(beans2) != 1 { - err = errors.New("get count should be one") - t.Error(err) - panic(err) - } - - cnt, err = testEngine.ID(bean.Id).Delete(&Uint16Id{}) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("insert count should be one") - t.Error(err) - panic(err) - } -} - -func TestUint32Id(t *testing.T) { - assert.NoError(t, prepareEngine()) - - err := testEngine.DropTables(&Uint32Id{}) - if err != nil { - t.Error(err) - panic(err) - } - - err = testEngine.CreateTables(&Uint32Id{}) - if err != nil { - t.Error(err) - panic(err) - } - - cnt, err := testEngine.Insert(&Uint32Id{Name: "test"}) - if err != nil { - t.Error(err) - panic(err) - } - - if cnt != 1 { - err = errors.New("insert count should be one") - t.Error(err) - panic(err) - } - - bean := new(Uint32Id) - has, err := testEngine.Get(bean) - if err != nil { - t.Error(err) - panic(err) - } - if !has { - err = errors.New("get count should be one") - t.Error(err) - panic(err) - } - - beans := make([]Uint32Id, 0) - err = testEngine.Find(&beans) - if err != nil { - t.Error(err) - panic(err) - } - if len(beans) != 1 { - err = errors.New("get count should be one") - t.Error(err) - panic(err) - } - - beans2 := make(map[uint32]Uint32Id, 0) - err = testEngine.Find(&beans2) - if err != nil { - t.Error(err) - panic(err) - } - if len(beans2) != 1 { - err = errors.New("get count should be one") - t.Error(err) - panic(err) - } - - cnt, err = testEngine.ID(bean.Id).Delete(&Uint32Id{}) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("insert count should be one") - t.Error(err) - panic(err) - } -} - -func TestUint64Id(t *testing.T) { - assert.NoError(t, prepareEngine()) - - err := testEngine.DropTables(&Uint64Id{}) - if err != nil { - t.Error(err) - panic(err) - } - - err = testEngine.CreateTables(&Uint64Id{}) - if err != nil { - t.Error(err) - panic(err) - } - - idbean := &Uint64Id{Name: "test"} - cnt, err := testEngine.Insert(idbean) - if err != nil { - t.Error(err) - panic(err) - } - - if cnt != 1 { - err = errors.New("insert count should be one") - t.Error(err) - panic(err) - } - - bean := new(Uint64Id) - has, err := testEngine.Get(bean) - if err != nil { - t.Error(err) - panic(err) - } - if !has { - err = errors.New("get count should be one") - t.Error(err) - panic(err) - } - - if bean.Id != idbean.Id { - panic(errors.New("should be equal")) - } - - beans := make([]Uint64Id, 0) - err = testEngine.Find(&beans) - if err != nil { - t.Error(err) - panic(err) - } - if len(beans) != 1 { - err = errors.New("get count should be one") - t.Error(err) - panic(err) - } - - if *bean != beans[0] { - panic(errors.New("should be equal")) - } - - beans2 := make(map[uint64]Uint64Id, 0) - err = testEngine.Find(&beans2) - if err != nil { - t.Error(err) - panic(err) - } - if len(beans2) != 1 { - err = errors.New("get count should be one") - t.Error(err) - panic(err) - } - - if *bean != beans2[bean.Id] { - panic(errors.New("should be equal")) - } - - cnt, err = testEngine.ID(bean.Id).Delete(&Uint64Id{}) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("insert count should be one") - t.Error(err) - panic(err) - } -} - -func TestStringPK(t *testing.T) { - assert.NoError(t, prepareEngine()) - - err := testEngine.DropTables(&StringPK{}) - if err != nil { - t.Error(err) - panic(err) - } - - err = testEngine.CreateTables(&StringPK{}) - if err != nil { - t.Error(err) - panic(err) - } - - cnt, err := testEngine.Insert(&StringPK{Id: "1-1-2", Name: "test"}) - if err != nil { - t.Error(err) - panic(err) - } - - if cnt != 1 { - err = errors.New("insert count should be one") - t.Error(err) - panic(err) - } - - bean := new(StringPK) - has, err := testEngine.Get(bean) - if err != nil { - t.Error(err) - panic(err) - } - if !has { - err = errors.New("get count should be one") - t.Error(err) - panic(err) - } - - beans := make([]StringPK, 0) - err = testEngine.Find(&beans) - if err != nil { - t.Error(err) - panic(err) - } - if len(beans) != 1 { - err = errors.New("get count should be one") - t.Error(err) - panic(err) - } - - beans2 := make(map[string]StringPK) - err = testEngine.Find(&beans2) - if err != nil { - t.Error(err) - panic(err) - } - if len(beans2) != 1 { - err = errors.New("get count should be one") - t.Error(err) - panic(err) - } - - cnt, err = testEngine.ID(bean.Id).Delete(&StringPK{}) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("insert count should be one") - t.Error(err) - panic(err) - } -} - -type CompositeKey struct { - Id1 int64 `xorm:"id1 pk"` - Id2 int64 `xorm:"id2 pk"` - UpdateStr string -} - -func TestCompositeKey(t *testing.T) { - assert.NoError(t, prepareEngine()) - - err := testEngine.DropTables(&CompositeKey{}) - if err != nil { - t.Error(err) - panic(err) - } - - err = testEngine.CreateTables(&CompositeKey{}) - if err != nil { - t.Error(err) - panic(err) - } - - cnt, err := testEngine.Insert(&CompositeKey{11, 22, ""}) - if err != nil { - t.Error(err) - } else if cnt != 1 { - t.Error(errors.New("failed to insert CompositeKey{11, 22}")) - } - - cnt, err = testEngine.Insert(&CompositeKey{11, 22, ""}) - if err == nil || cnt == 1 { - t.Error(errors.New("inserted CompositeKey{11, 22}")) - } - - var compositeKeyVal CompositeKey - has, err := testEngine.ID(core.PK{11, 22}).Get(&compositeKeyVal) - if err != nil { - t.Error(err) - } else if !has { - t.Error(errors.New("can't get CompositeKey{11, 22}")) - } - - var compositeKeyVal2 CompositeKey - // test passing PK ptr, this test seem failed withCache - has, err = testEngine.ID(&core.PK{11, 22}).Get(&compositeKeyVal2) - if err != nil { - t.Error(err) - } else if !has { - t.Error(errors.New("can't get CompositeKey{11, 22}")) - } - - if compositeKeyVal != compositeKeyVal2 { - t.Error(errors.New("should be equal")) - } - - var cps = make([]CompositeKey, 0) - err = testEngine.Find(&cps) - if err != nil { - t.Error(err) - } - if len(cps) != 1 { - t.Error(errors.New("should has one record")) - } - if cps[0] != compositeKeyVal { - t.Error(errors.New("should be equal")) - } - - cnt, err = testEngine.Insert(&CompositeKey{22, 22, ""}) - if err != nil { - t.Error(err) - } else if cnt != 1 { - t.Error(errors.New("failed to insert CompositeKey{22, 22}")) - } - - cps = make([]CompositeKey, 0) - err = testEngine.Find(&cps) - assert.NoError(t, err) - assert.EqualValues(t, 2, len(cps), "should has two record") - assert.EqualValues(t, compositeKeyVal, cps[0], "should be equeal") - - compositeKeyVal = CompositeKey{UpdateStr: "test1"} - cnt, err = testEngine.ID(core.PK{11, 22}).Update(&compositeKeyVal) - if err != nil { - t.Error(err) - } else if cnt != 1 { - t.Error(errors.New("can't update CompositeKey{11, 22}")) - } - - cnt, err = testEngine.ID(core.PK{11, 22}).Delete(&CompositeKey{}) - if err != nil { - t.Error(err) - } else if cnt != 1 { - t.Error(errors.New("can't delete CompositeKey{11, 22}")) - } -} - -func TestCompositeKey2(t *testing.T) { - assert.NoError(t, prepareEngine()) - - type User struct { - UserId string `xorm:"varchar(19) not null pk"` - NickName string `xorm:"varchar(19) not null"` - GameId uint32 `xorm:"integer pk"` - Score int32 `xorm:"integer"` - } - - err := testEngine.DropTables(&User{}) - - if err != nil { - t.Error(err) - panic(err) - } - - err = testEngine.CreateTables(&User{}) - if err != nil { - t.Error(err) - panic(err) - } - - cnt, err := testEngine.Insert(&User{"11", "nick", 22, 5}) - if err != nil { - t.Error(err) - } else if cnt != 1 { - t.Error(errors.New("failed to insert User{11, 22}")) - } - - cnt, err = testEngine.Insert(&User{"11", "nick", 22, 6}) - if err == nil || cnt == 1 { - t.Error(errors.New("inserted User{11, 22}")) - } - - var user User - has, err := testEngine.ID(core.PK{"11", 22}).Get(&user) - if err != nil { - t.Error(err) - } else if !has { - t.Error(errors.New("can't get User{11, 22}")) - } - - // test passing PK ptr, this test seem failed withCache - has, err = testEngine.ID(&core.PK{"11", 22}).Get(&user) - if err != nil { - t.Error(err) - } else if !has { - t.Error(errors.New("can't get User{11, 22}")) - } - - user = User{NickName: "test1"} - cnt, err = testEngine.ID(core.PK{"11", 22}).Update(&user) - if err != nil { - t.Error(err) - } else if cnt != 1 { - t.Error(errors.New("can't update User{11, 22}")) - } - - cnt, err = testEngine.ID(core.PK{"11", 22}).Delete(&User{}) - if err != nil { - t.Error(err) - } else if cnt != 1 { - t.Error(errors.New("can't delete CompositeKey{11, 22}")) - } -} - -type MyString string -type UserPK2 struct { - UserId MyString `xorm:"varchar(19) not null pk"` - NickName string `xorm:"varchar(19) not null"` - GameId uint32 `xorm:"integer pk"` - Score int32 `xorm:"integer"` -} - -func TestCompositeKey3(t *testing.T) { - assert.NoError(t, prepareEngine()) - - err := testEngine.DropTables(&UserPK2{}) - - if err != nil { - t.Error(err) - panic(err) - } - - err = testEngine.CreateTables(&UserPK2{}) - if err != nil { - t.Error(err) - panic(err) - } - - cnt, err := testEngine.Insert(&UserPK2{"11", "nick", 22, 5}) - if err != nil { - t.Error(err) - } else if cnt != 1 { - t.Error(errors.New("failed to insert User{11, 22}")) - } - - cnt, err = testEngine.Insert(&UserPK2{"11", "nick", 22, 6}) - if err == nil || cnt == 1 { - t.Error(errors.New("inserted User{11, 22}")) - } - - var user UserPK2 - has, err := testEngine.ID(core.PK{"11", 22}).Get(&user) - if err != nil { - t.Error(err) - } else if !has { - t.Error(errors.New("can't get User{11, 22}")) - } - - // test passing PK ptr, this test seem failed withCache - has, err = testEngine.ID(&core.PK{"11", 22}).Get(&user) - if err != nil { - t.Error(err) - } else if !has { - t.Error(errors.New("can't get User{11, 22}")) - } - - user = UserPK2{NickName: "test1"} - cnt, err = testEngine.ID(core.PK{"11", 22}).Update(&user) - if err != nil { - t.Error(err) - } else if cnt != 1 { - t.Error(errors.New("can't update User{11, 22}")) - } - - cnt, err = testEngine.ID(core.PK{"11", 22}).Delete(&UserPK2{}) - if err != nil { - t.Error(err) - } else if cnt != 1 { - t.Error(errors.New("can't delete CompositeKey{11, 22}")) - } -} - -func TestMyIntId(t *testing.T) { - assert.NoError(t, prepareEngine()) - - err := testEngine.DropTables(&MyIntPK{}) - if err != nil { - t.Error(err) - panic(err) - } - - err = testEngine.CreateTables(&MyIntPK{}) - if err != nil { - t.Error(err) - panic(err) - } - - idbean := &MyIntPK{Name: "test"} - cnt, err := testEngine.Insert(idbean) - if err != nil { - t.Error(err) - panic(err) - } - - if cnt != 1 { - err = errors.New("insert count should be one") - t.Error(err) - panic(err) - } - - bean := new(MyIntPK) - has, err := testEngine.Get(bean) - if err != nil { - t.Error(err) - panic(err) - } - if !has { - err = errors.New("get count should be one") - t.Error(err) - panic(err) - } - - if bean.ID != idbean.ID { - panic(errors.New("should be equal")) - } - - var beans []MyIntPK - err = testEngine.Find(&beans) - if err != nil { - t.Error(err) - panic(err) - } - if len(beans) != 1 { - err = errors.New("get count should be one") - t.Error(err) - panic(err) - } - - if *bean != beans[0] { - panic(errors.New("should be equal")) - } - - beans2 := make(map[ID]MyIntPK, 0) - err = testEngine.Find(&beans2) - if err != nil { - t.Error(err) - panic(err) - } - if len(beans2) != 1 { - err = errors.New("get count should be one") - t.Error(err) - panic(err) - } - - if *bean != beans2[bean.ID] { - panic(errors.New("should be equal")) - } - - cnt, err = testEngine.ID(bean.ID).Delete(&MyIntPK{}) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("insert count should be one") - t.Error(err) - panic(err) - } -} - -func TestMyStringId(t *testing.T) { - assert.NoError(t, prepareEngine()) - - err := testEngine.DropTables(&MyStringPK{}) - if err != nil { - t.Error(err) - panic(err) - } - - err = testEngine.CreateTables(&MyStringPK{}) - if err != nil { - t.Error(err) - panic(err) - } - - idbean := &MyStringPK{ID: "1111", Name: "test"} - cnt, err := testEngine.Insert(idbean) - if err != nil { - t.Error(err) - panic(err) - } - - if cnt != 1 { - err = errors.New("insert count should be one") - t.Error(err) - panic(err) - } - - bean := new(MyStringPK) - has, err := testEngine.Get(bean) - if err != nil { - t.Error(err) - panic(err) - } - if !has { - err = errors.New("get count should be one") - t.Error(err) - panic(err) - } - - if bean.ID != idbean.ID { - panic(errors.New("should be equal")) - } - - var beans []MyStringPK - err = testEngine.Find(&beans) - if err != nil { - t.Error(err) - panic(err) - } - if len(beans) != 1 { - err = errors.New("get count should be one") - t.Error(err) - panic(err) - } - - if *bean != beans[0] { - panic(errors.New("should be equal")) - } - - beans2 := make(map[StrID]MyStringPK, 0) - err = testEngine.Find(&beans2) - if err != nil { - t.Error(err) - panic(err) - } - if len(beans2) != 1 { - err = errors.New("get count should be one") - t.Error(err) - panic(err) - } - - if *bean != beans2[bean.ID] { - panic(errors.New("should be equal")) - } - - cnt, err = testEngine.ID(bean.ID).Delete(&MyStringPK{}) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("insert count should be one") - t.Error(err) - panic(err) - } -} - -func TestSingleAutoIncrColumn(t *testing.T) { - type Account struct { - Id int64 `xorm:"pk autoincr"` - } - - assert.NoError(t, prepareEngine()) - assertSync(t, new(Account)) - - _, err := testEngine.Insert(&Account{}) - assert.NoError(t, err) -} - -func TestCompositePK(t *testing.T) { - type TaskSolution struct { - UID string `xorm:"notnull pk UUID 'uid'"` - TID string `xorm:"notnull pk UUID 'tid'"` - Created time.Time `xorm:"created"` - Updated time.Time `xorm:"updated"` - } - - assert.NoError(t, prepareEngine()) - - tables1, err := testEngine.DBMetas() - assert.NoError(t, err) - - assertSync(t, new(TaskSolution)) - assert.NoError(t, testEngine.Sync2(new(TaskSolution))) - - tables2, err := testEngine.DBMetas() - assert.NoError(t, err) - assert.EqualValues(t, 1+len(tables1), len(tables2)) - - var table *core.Table - for _, t := range tables2 { - if t.Name == testEngine.GetTableMapper().Obj2Table("TaskSolution") { - table = t - break - } - } - - assert.NotEqual(t, nil, table) - - pkCols := table.PKColumns() - assert.EqualValues(t, 2, len(pkCols)) - assert.EqualValues(t, "uid", pkCols[0].Name) - assert.EqualValues(t, "tid", pkCols[1].Name) -} - -func TestNoPKIdQueryUpdate(t *testing.T) { - type NoPKTable struct { - Username string - } - - assert.NoError(t, prepareEngine()) - assertSync(t, new(NoPKTable)) - - cnt, err := testEngine.Insert(&NoPKTable{ - Username: "test", - }) - assert.NoError(t, err) - assert.EqualValues(t, 1, cnt) - - var res NoPKTable - has, err := testEngine.ID("test").Get(&res) - assert.Error(t, err) - assert.False(t, has) - - cnt, err = testEngine.ID("test").Update(&NoPKTable{ - Username: "test1", - }) - assert.Error(t, err) - assert.EqualValues(t, 0, cnt) - - type UnvalidPKTable struct { - ID int `xorm:"id"` - Username string - } - - assertSync(t, new(UnvalidPKTable)) - - cnt, err = testEngine.Insert(&UnvalidPKTable{ - ID: 1, - Username: "test", - }) - assert.NoError(t, err) - assert.EqualValues(t, 1, cnt) - - var res2 UnvalidPKTable - has, err = testEngine.ID(1).Get(&res2) - assert.Error(t, err) - assert.False(t, has) - - cnt, err = testEngine.ID(1).Update(&UnvalidPKTable{ - Username: "test1", - }) - assert.Error(t, err) - assert.EqualValues(t, 0, cnt) -} diff --git a/session_query.go b/session_query.go index 21c00b8d..12136466 100644 --- a/session_query.go +++ b/session_query.go @@ -8,82 +8,19 @@ import ( "fmt" "reflect" "strconv" - "strings" "time" - "xorm.io/builder" - "xorm.io/core" + "xorm.io/xorm/core" + "xorm.io/xorm/schemas" ) -func (session *Session) genQuerySQL(sqlOrArgs ...interface{}) (string, []interface{}, error) { - if len(sqlOrArgs) > 0 { - return convertSQLOrArgs(sqlOrArgs...) - } - - if session.statement.RawSQL != "" { - return session.statement.RawSQL, session.statement.RawParams, nil - } - - if len(session.statement.TableName()) <= 0 { - return "", nil, ErrTableNotFound - } - - var columnStr = session.statement.ColumnStr - if len(session.statement.selectStr) > 0 { - columnStr = session.statement.selectStr - } else { - if session.statement.JoinStr == "" { - if columnStr == "" { - if session.statement.GroupByStr != "" { - columnStr = session.engine.quoteColumns(session.statement.GroupByStr) - } else { - columnStr = session.statement.genColumnStr() - } - } - } else { - if columnStr == "" { - if session.statement.GroupByStr != "" { - columnStr = session.engine.quoteColumns(session.statement.GroupByStr) - } else { - columnStr = "*" - } - } - } - if columnStr == "" { - columnStr = "*" - } - } - - if err := session.statement.processIDParam(); err != nil { - return "", nil, err - } - - condSQL, condArgs, err := builder.ToSQL(session.statement.cond) - if err != nil { - return "", nil, err - } - - args := append(session.statement.joinArgs, condArgs...) - sqlStr, err := session.statement.genSelectSQL(columnStr, condSQL, true, true) - if err != nil { - return "", nil, err - } - // for mssql and use limit - qs := strings.Count(sqlStr, "?") - if len(args)*2 == qs { - args = append(args, args...) - } - - return sqlStr, args, nil -} - // Query runs a raw sql and return records as []map[string][]byte func (session *Session) Query(sqlOrArgs ...interface{}) ([]map[string][]byte, error) { if session.isAutoClose { defer session.Close() } - sqlStr, args, err := session.genQuerySQL(sqlOrArgs...) + sqlStr, args, err := session.statement.GenQuerySQL(sqlOrArgs...) if err != nil { return nil, err } @@ -116,8 +53,8 @@ func value2String(rawValue *reflect.Value) (str string, err error) { } // time type case reflect.Struct: - if aa.ConvertibleTo(core.TimeType) { - str = vv.Convert(core.TimeType).Interface().(time.Time).Format(time.RFC3339Nano) + if aa.ConvertibleTo(schemas.TimeType) { + str = vv.Convert(schemas.TimeType).Interface().(time.Time).Format(time.RFC3339Nano) } else { err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name()) } @@ -232,7 +169,7 @@ func (session *Session) QueryString(sqlOrArgs ...interface{}) ([]map[string]stri defer session.Close() } - sqlStr, args, err := session.genQuerySQL(sqlOrArgs...) + sqlStr, args, err := session.statement.GenQuerySQL(sqlOrArgs...) if err != nil { return nil, err } @@ -252,7 +189,7 @@ func (session *Session) QuerySliceString(sqlOrArgs ...interface{}) ([][]string, defer session.Close() } - sqlStr, args, err := session.genQuerySQL(sqlOrArgs...) + sqlStr, args, err := session.statement.GenQuerySQL(sqlOrArgs...) if err != nil { return nil, err } @@ -305,7 +242,7 @@ func (session *Session) QueryInterface(sqlOrArgs ...interface{}) ([]map[string]i defer session.Close() } - sqlStr, args, err := session.genQuerySQL(sqlOrArgs...) + sqlStr, args, err := session.statement.GenQuerySQL(sqlOrArgs...) if err != nil { return nil, err } diff --git a/session_raw.go b/session_raw.go index a9298296..4cfe297a 100644 --- a/session_raw.go +++ b/session_raw.go @@ -7,15 +7,13 @@ package xorm import ( "database/sql" "reflect" - "time" - "xorm.io/builder" - "xorm.io/core" + "xorm.io/xorm/core" ) func (session *Session) queryPreprocess(sqlStr *string, paramStr ...interface{}) { for _, filter := range session.engine.dialect.Filters() { - *sqlStr = filter.Do(*sqlStr, session.engine.dialect, session.statement.RefTable) + *sqlStr = filter.Do(*sqlStr) } session.lastSQL = *sqlStr @@ -24,30 +22,14 @@ func (session *Session) queryPreprocess(sqlStr *string, paramStr ...interface{}) func (session *Session) queryRows(sqlStr string, args ...interface{}) (*core.Rows, error) { defer session.resetStatement() + if session.statement.LastError != nil { + return nil, session.statement.LastError + } session.queryPreprocess(&sqlStr, args...) - if session.showSQL { - session.lastSQL = sqlStr - session.lastSQLArgs = args - if session.engine.showExecTime { - b4ExecTime := time.Now() - defer func() { - execDuration := time.Since(b4ExecTime) - if len(args) > 0 { - session.engine.logger.Infof("[SQL] %s %#v - took: %v", sqlStr, args, execDuration) - } else { - session.engine.logger.Infof("[SQL] %s - took: %v", sqlStr, execDuration) - } - }() - } else { - if len(args) > 0 { - session.engine.logger.Infof("[SQL] %v %#v", sqlStr, args) - } else { - session.engine.logger.Infof("[SQL] %v", sqlStr) - } - } - } + session.lastSQL = sqlStr + session.lastSQLArgs = args if session.isAutoCommit { var db *core.DB @@ -156,25 +138,8 @@ func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, er session.queryPreprocess(&sqlStr, args...) - if session.engine.showSQL { - if session.engine.showExecTime { - b4ExecTime := time.Now() - defer func() { - execDuration := time.Since(b4ExecTime) - if len(args) > 0 { - session.engine.logger.Infof("[SQL] %s %#v - took: %v", sqlStr, args, execDuration) - } else { - session.engine.logger.Infof("[SQL] %s - took: %v", sqlStr, execDuration) - } - }() - } else { - if len(args) > 0 { - session.engine.logger.Infof("[SQL] %v %#v", sqlStr, args) - } else { - session.engine.logger.Infof("[SQL] %v", sqlStr) - } - } - } + session.lastSQL = sqlStr + session.lastSQLArgs = args if !session.isAutoCommit { return session.tx.ExecContext(session.ctx, sqlStr, args...) @@ -196,20 +161,6 @@ func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, er return session.DB().ExecContext(session.ctx, sqlStr, args...) } -func convertSQLOrArgs(sqlOrArgs ...interface{}) (string, []interface{}, error) { - switch sqlOrArgs[0].(type) { - case string: - return sqlOrArgs[0].(string), sqlOrArgs[1:], nil - case *builder.Builder: - return sqlOrArgs[0].(*builder.Builder).ToSQL() - case builder.Builder: - bd := sqlOrArgs[0].(builder.Builder) - return bd.ToSQL() - } - - return "", nil, ErrUnSupportedType -} - // Exec raw sql func (session *Session) Exec(sqlOrArgs ...interface{}) (sql.Result, error) { if session.isAutoClose { @@ -220,7 +171,7 @@ func (session *Session) Exec(sqlOrArgs ...interface{}) (sql.Result, error) { return nil, ErrUnSupportedType } - sqlStr, args, err := convertSQLOrArgs(sqlOrArgs...) + sqlStr, args, err := session.statement.ConvertSQLOrArgs(sqlOrArgs...) if err != nil { return nil, err } diff --git a/session_schema.go b/session_schema.go index bcfd2aa1..8520d00b 100644 --- a/session_schema.go +++ b/session_schema.go @@ -5,12 +5,16 @@ package xorm import ( + "bufio" "database/sql" "fmt" "regexp" + "io" + "os" "strings" - "xorm.io/core" + "xorm.io/xorm/internal/utils" + "xorm.io/xorm/schemas" ) // Ping test if database is ok @@ -33,13 +37,18 @@ func (session *Session) CreateTable(bean interface{}) error { } func (session *Session) createTable(bean interface{}) error { - if err := session.statement.setRefBean(bean); err != nil { + if err := session.statement.SetRefBean(bean); err != nil { return err } - sqlStr := session.statement.genCreateTableSQL() - _, err := session.exec(sqlStr) - return err + sqlStrs := session.statement.GenCreateTableSQL() + for _, s := range sqlStrs { + _, err := session.exec(s) + if err != nil { + return err + } + } + return nil } // CreateIndexes create indexes @@ -52,11 +61,11 @@ func (session *Session) CreateIndexes(bean interface{}) error { } func (session *Session) createIndexes(bean interface{}) error { - if err := session.statement.setRefBean(bean); err != nil { + if err := session.statement.SetRefBean(bean); err != nil { return err } - sqls := session.statement.genIndexSQL() + sqls := session.statement.GenIndexSQL() for _, sqlStr := range sqls { _, err := session.exec(sqlStr) if err != nil { @@ -75,11 +84,11 @@ func (session *Session) CreateUniques(bean interface{}) error { } func (session *Session) createUniques(bean interface{}) error { - if err := session.statement.setRefBean(bean); err != nil { + if err := session.statement.SetRefBean(bean); err != nil { return err } - sqls := session.statement.genUniqueSQL() + sqls := session.statement.GenUniqueSQL() for _, sqlStr := range sqls { _, err := session.exec(sqlStr) if err != nil { @@ -99,11 +108,11 @@ func (session *Session) DropIndexes(bean interface{}) error { } func (session *Session) dropIndexes(bean interface{}) error { - if err := session.statement.setRefBean(bean); err != nil { + if err := session.statement.SetRefBean(bean); err != nil { return err } - sqls := session.statement.genDelIndexSQL() + sqls := session.statement.GenDelIndexSQL() for _, sqlStr := range sqls { _, err := session.exec(sqlStr) if err != nil { @@ -293,18 +302,16 @@ func (session *Session) DropTable(beanOrTableName interface{}) error { func (session *Session) dropTable(beanOrTableName interface{}) error { tableName := session.engine.TableName(beanOrTableName) - var needDrop = true - if !session.engine.dialect.SupportDropIfExists() { - sqlStr, args := session.engine.dialect.TableCheckSql(tableName) - results, err := session.queryBytes(sqlStr, args...) + sqlStr, checkIfExist := session.engine.dialect.DropTableSQL(session.engine.TableName(tableName, true)) + if !checkIfExist { + exist, err := session.engine.dialect.IsTableExist(session.getQueryer(), session.ctx, tableName) if err != nil { return err } - needDrop = len(results) > 0 + checkIfExist = exist } - if needDrop { - sqlStr := session.engine.Dialect().DropTableSql(session.engine.TableName(tableName, true)) + if checkIfExist { _, err := session.exec(sqlStr) return err } @@ -323,9 +330,7 @@ func (session *Session) IsTableExist(beanOrTableName interface{}) (bool, error) } func (session *Session) isTableExist(tableName string) (bool, error) { - sqlStr, args := session.engine.dialect.TableCheckSql(tableName) - results, err := session.queryBytes(sqlStr, args...) - return len(results) > 0, err + return session.engine.dialect.IsTableExist(session.getQueryer(), session.ctx, tableName) } // IsTableEmpty if table have any records @@ -352,17 +357,17 @@ func (session *Session) isTableEmpty(tableName string) (bool, error) { // find if index is exist according cols func (session *Session) isIndexExist2(tableName string, cols []string, unique bool) (bool, error) { - indexes, err := session.engine.dialect.GetIndexes(tableName) + indexes, err := session.engine.dialect.GetIndexes(session.getQueryer(), session.ctx, tableName) if err != nil { return false, err } for _, index := range indexes { - if sliceEq(index.Cols, cols) { + if utils.SliceEq(index.Cols, cols) { if unique { - return index.Type == core.UniqueType, nil + return index.Type == schemas.UniqueType, nil } - return index.Type == core.IndexType, nil + return index.Type == schemas.IndexType, nil } } return false, nil @@ -370,21 +375,21 @@ func (session *Session) isIndexExist2(tableName string, cols []string, unique bo func (session *Session) addColumn(colName string) error { col := session.statement.RefTable.GetColumn(colName) - sql, args := session.statement.genAddColumnStr(col) - _, err := session.exec(sql, args...) + sql := session.engine.dialect.AddColumnSQL(session.statement.TableName(), col) + _, err := session.exec(sql) return err } func (session *Session) addIndex(tableName, idxName string) error { index := session.statement.RefTable.Indexes[idxName] - sqlStr := session.engine.dialect.CreateIndexSql(tableName, index) + sqlStr := session.engine.dialect.CreateIndexSQL(tableName, index) _, err := session.exec(sqlStr) return err } func (session *Session) addUnique(tableName, uqeName string) error { index := session.statement.RefTable.Indexes[uqeName] - sqlStr := session.engine.dialect.CreateIndexSql(tableName, index) + sqlStr := session.engine.dialect.CreateIndexSQL(tableName, index) _, err := session.exec(sqlStr) return err } @@ -398,7 +403,7 @@ func (session *Session) Sync2(beans ...interface{}) error { defer session.Close() } - tables, err := engine.dialect.GetTables() + tables, err := engine.dialect.GetTables(session.getQueryer(), session.ctx) if err != nil { return err } @@ -410,8 +415,8 @@ func (session *Session) Sync2(beans ...interface{}) error { }() for _, bean := range beans { - v := rValue(bean) - table, err := engine.mapType(v) + v := utils.ReflectValue(bean) + table, err := engine.tagParser.ParseWithCache(v) if err != nil { return err } @@ -423,7 +428,7 @@ func (session *Session) Sync2(beans ...interface{}) error { } tbNameWithSchema := engine.tbNameWithSchema(tbName) - var oriTable *core.Table + var oriTable *schemas.Table for _, tb := range tables { if strings.EqualFold(engine.tbNameWithSchema(tb.Name), engine.tbNameWithSchema(tbName)) { oriTable = tb @@ -457,7 +462,7 @@ func (session *Session) Sync2(beans ...interface{}) error { // check columns for _, col := range table.Columns() { - var oriCol *core.Column + var oriCol *schemas.Column for _, col2 := range oriTable.Columns() { if strings.EqualFold(col.Name, col2.Name) { oriCol = col2 @@ -468,7 +473,7 @@ func (session *Session) Sync2(beans ...interface{}) error { // column is not exist on table if oriCol == nil { session.statement.RefTable = table - session.statement.tableName = tbNameWithSchema + session.statement.SetTableName(tbNameWithSchema) if err = session.addColumn(col.Name); err != nil { return err } @@ -476,27 +481,27 @@ func (session *Session) Sync2(beans ...interface{}) error { } err = nil - expectedType := engine.dialect.SqlType(col) - curType := engine.dialect.SqlType(oriCol) + expectedType := engine.dialect.SQLType(col) + curType := engine.dialect.SQLType(oriCol) if expectedType != curType { - if expectedType == core.Text && - strings.HasPrefix(curType, core.Varchar) { + if expectedType == schemas.Text && + strings.HasPrefix(curType, schemas.Varchar) { // currently only support mysql & postgres - if engine.dialect.DBType() == core.MYSQL || - engine.dialect.DBType() == core.POSTGRES { + if engine.dialect.URI().DBType == schemas.MYSQL || + engine.dialect.URI().DBType == schemas.POSTGRES { engine.logger.Infof("Table %s column %s change type from %s to %s\n", tbNameWithSchema, col.Name, curType, expectedType) - _, err = session.exec(engine.dialect.ModifyColumnSql(tbNameWithSchema, col)) + _, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col)) } else { engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s\n", tbNameWithSchema, col.Name, curType, expectedType) } - } else if strings.HasPrefix(curType, core.Varchar) && strings.HasPrefix(expectedType, core.Varchar) { - if engine.dialect.DBType() == core.MYSQL { + } else if strings.HasPrefix(curType, schemas.Varchar) && strings.HasPrefix(expectedType, schemas.Varchar) { + if engine.dialect.URI().DBType == schemas.MYSQL { if oriCol.Length < col.Length { engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n", tbNameWithSchema, col.Name, oriCol.Length, col.Length) - _, err = session.exec(engine.dialect.ModifyColumnSql(tbNameWithSchema, col)) + _, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col)) } } } else { @@ -505,21 +510,23 @@ func (session *Session) Sync2(beans ...interface{}) error { tbNameWithSchema, col.Name, curType, expectedType) } } - } else if expectedType == core.Varchar { - if engine.dialect.DBType() == core.MYSQL { + } else if expectedType == schemas.Varchar { + if engine.dialect.URI().DBType == schemas.MYSQL { if oriCol.Length < col.Length { engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n", tbNameWithSchema, col.Name, oriCol.Length, col.Length) - _, err = session.exec(engine.dialect.ModifyColumnSql(tbNameWithSchema, col)) + _, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col)) } } } if col.Default != oriCol.Default { - if (col.SQLType.Name == core.Bool || col.SQLType.Name == core.Boolean) && + switch { + case col.IsAutoIncrement: // For autoincrement column, don't check default + case (col.SQLType.Name == schemas.Bool || col.SQLType.Name == schemas.Boolean) && ((strings.EqualFold(col.Default, "true") && oriCol.Default == "1") || - (strings.EqualFold(col.Default, "false") && oriCol.Default == "0")) { - } else { + (strings.EqualFold(col.Default, "false") && oriCol.Default == "0")): + default: engine.logger.Warnf("Table %s Column %s db default is %s, struct default is %s", tbName, col.Name, oriCol.Default, col.Default) } @@ -535,10 +542,10 @@ func (session *Session) Sync2(beans ...interface{}) error { } var foundIndexNames = make(map[string]bool) - var addedNames = make(map[string]*core.Index) + var addedNames = make(map[string]*schemas.Index) for name, index := range table.Indexes { - var oriIndex *core.Index + var oriIndex *schemas.Index for name2, index2 := range oriTable.Indexes { if index.Equal(index2) { oriIndex = index2 @@ -549,7 +556,7 @@ func (session *Session) Sync2(beans ...interface{}) error { if oriIndex != nil { if oriIndex.Type != index.Type { - sql := engine.dialect.DropIndexSql(tbNameWithSchema, oriIndex) + sql := engine.dialect.DropIndexSQL(tbNameWithSchema, oriIndex) _, err = session.exec(sql) if err != nil { return err @@ -565,7 +572,7 @@ func (session *Session) Sync2(beans ...interface{}) error { for name2, index2 := range oriTable.Indexes { if _, ok := foundIndexNames[name2]; !ok { - sql := engine.dialect.DropIndexSql(tbNameWithSchema, index2) + sql := engine.dialect.DropIndexSQL(tbNameWithSchema, index2) _, err = session.exec(sql) if err != nil { return err @@ -574,13 +581,13 @@ func (session *Session) Sync2(beans ...interface{}) error { } for name, index := range addedNames { - if index.Type == core.UniqueType { + if index.Type == schemas.UniqueType { session.statement.RefTable = table - session.statement.tableName = tbNameWithSchema + session.statement.SetTableName(tbNameWithSchema) err = session.addUnique(tbNameWithSchema, name) - } else if index.Type == core.IndexType { + } else if index.Type == schemas.IndexType { session.statement.RefTable = table - session.statement.tableName = tbNameWithSchema + session.statement.SetTableName(tbNameWithSchema) err = session.addIndex(tbNameWithSchema, name) } if err != nil { @@ -598,3 +605,56 @@ func (session *Session) Sync2(beans ...interface{}) error { return nil } + +// ImportFile SQL DDL file +func (session *Session) ImportFile(ddlPath string) ([]sql.Result, error) { + file, err := os.Open(ddlPath) + if err != nil { + return nil, err + } + defer file.Close() + return session.Import(file) +} + +// Import SQL DDL from io.Reader +func (session *Session) Import(r io.Reader) ([]sql.Result, error) { + var results []sql.Result + var lastError error + scanner := bufio.NewScanner(r) + + var inSingleQuote bool + semiColSpliter := func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + for i, b := range data { + if b == '\'' { + inSingleQuote = !inSingleQuote + } + if !inSingleQuote && b == ';' { + return i + 1, data[0:i], nil + } + } + // If we're at EOF, we have a final, non-terminated line. Return it. + if atEOF { + return len(data), data, nil + } + // Request more data. + return 0, nil, nil + } + + scanner.Split(semiColSpliter) + + for scanner.Scan() { + query := strings.Trim(scanner.Text(), " \t\n\r") + if len(query) > 0 { + result, err := session.Exec(query) + results = append(results, result) + if err != nil { + return nil, err + } + } + } + + return results, lastError +} diff --git a/session_stats.go b/session_stats.go index c2cac830..17d0a675 100644 --- a/session_stats.go +++ b/session_stats.go @@ -17,17 +17,9 @@ func (session *Session) Count(bean ...interface{}) (int64, error) { defer session.Close() } - var sqlStr string - var args []interface{} - var err error - if session.statement.RawSQL == "" { - sqlStr, args, err = session.statement.genCountSQL(bean...) - if err != nil { - return 0, err - } - } else { - sqlStr = session.statement.RawSQL - args = session.statement.RawParams + sqlStr, args, err := session.statement.GenCountSQL(bean...) + if err != nil { + return 0, err } var total int64 @@ -50,21 +42,12 @@ func (session *Session) sum(res interface{}, bean interface{}, columnNames ...st return errors.New("need a pointer to a variable") } - var isSlice = v.Elem().Kind() == reflect.Slice - var sqlStr string - var args []interface{} - var err error - if len(session.statement.RawSQL) == 0 { - sqlStr, args, err = session.statement.genSumSQL(bean, columnNames...) - if err != nil { - return err - } - } else { - sqlStr = session.statement.RawSQL - args = session.statement.RawParams + sqlStr, args, err := session.statement.GenSumSQL(bean, columnNames...) + if err != nil { + return err } - if isSlice { + if v.Elem().Kind() == reflect.Slice { err = session.queryRow(sqlStr, args...).ScanSlice(res) } else { err = session.queryRow(sqlStr, args...).Scan(res) diff --git a/session_tx.go b/session_tx.go index ee3d473f..cd23cf89 100644 --- a/session_tx.go +++ b/session_tx.go @@ -4,6 +4,12 @@ package xorm +import ( + "time" + + "xorm.io/xorm/log" +) + // Begin a transaction func (session *Session) Begin() error { if session.isAutoCommit { @@ -14,6 +20,7 @@ func (session *Session) Begin() error { session.isAutoCommit = false session.isCommitedOrRollbacked = false session.tx = tx + session.saveLastSQL("BEGIN TRANSACTION") } return nil @@ -22,10 +29,28 @@ func (session *Session) Begin() error { // Rollback When using transaction, you can rollback if any error func (session *Session) Rollback() error { if !session.isAutoCommit && !session.isCommitedOrRollbacked { - session.saveLastSQL(session.engine.dialect.RollBackStr()) + session.saveLastSQL("ROLL BACK") session.isCommitedOrRollbacked = true session.isAutoCommit = true - return session.tx.Rollback() + + start := time.Now() + needSQL := session.DB().NeedLogSQL(session.ctx) + if needSQL { + session.engine.logger.BeforeSQL(log.LogContext{ + Ctx: session.ctx, + SQL: "ROLL BACK", + }) + } + err := session.tx.Rollback() + if needSQL { + session.engine.logger.AfterSQL(log.LogContext{ + Ctx: session.ctx, + SQL: "ROLL BACK", + ExecuteTime: time.Now().Sub(start), + Err: err, + }) + } + return err } return nil } @@ -36,48 +61,67 @@ func (session *Session) Commit() error { session.saveLastSQL("COMMIT") session.isCommitedOrRollbacked = true session.isAutoCommit = true - var err error - if err = session.tx.Commit(); err == nil { - // handle processors after tx committed - closureCallFunc := func(closuresPtr *[]func(interface{}), bean interface{}) { - if closuresPtr != nil { - for _, closure := range *closuresPtr { - closure(bean) - } - } - } - for bean, closuresPtr := range session.afterInsertBeans { - closureCallFunc(closuresPtr, bean) - - if processor, ok := interface{}(bean).(AfterInsertProcessor); ok { - processor.AfterInsert() - } - } - for bean, closuresPtr := range session.afterUpdateBeans { - closureCallFunc(closuresPtr, bean) - - if processor, ok := interface{}(bean).(AfterUpdateProcessor); ok { - processor.AfterUpdate() - } - } - for bean, closuresPtr := range session.afterDeleteBeans { - closureCallFunc(closuresPtr, bean) - - if processor, ok := interface{}(bean).(AfterDeleteProcessor); ok { - processor.AfterDelete() - } - } - cleanUpFunc := func(slices *map[interface{}]*[]func(interface{})) { - if len(*slices) > 0 { - *slices = make(map[interface{}]*[]func(interface{}), 0) - } - } - cleanUpFunc(&session.afterInsertBeans) - cleanUpFunc(&session.afterUpdateBeans) - cleanUpFunc(&session.afterDeleteBeans) + start := time.Now() + needSQL := session.DB().NeedLogSQL(session.ctx) + if needSQL { + session.engine.logger.BeforeSQL(log.LogContext{ + Ctx: session.ctx, + SQL: "COMMIT", + }) } - return err + err := session.tx.Commit() + if needSQL { + session.engine.logger.AfterSQL(log.LogContext{ + Ctx: session.ctx, + SQL: "COMMIT", + ExecuteTime: time.Now().Sub(start), + Err: err, + }) + } + + if err != nil { + return err + } + + // handle processors after tx committed + closureCallFunc := func(closuresPtr *[]func(interface{}), bean interface{}) { + if closuresPtr != nil { + for _, closure := range *closuresPtr { + closure(bean) + } + } + } + + for bean, closuresPtr := range session.afterInsertBeans { + closureCallFunc(closuresPtr, bean) + + if processor, ok := interface{}(bean).(AfterInsertProcessor); ok { + processor.AfterInsert() + } + } + for bean, closuresPtr := range session.afterUpdateBeans { + closureCallFunc(closuresPtr, bean) + + if processor, ok := interface{}(bean).(AfterUpdateProcessor); ok { + processor.AfterUpdate() + } + } + for bean, closuresPtr := range session.afterDeleteBeans { + closureCallFunc(closuresPtr, bean) + + if processor, ok := interface{}(bean).(AfterDeleteProcessor); ok { + processor.AfterDelete() + } + } + cleanUpFunc := func(slices *map[interface{}]*[]func(interface{})) { + if len(*slices) > 0 { + *slices = make(map[interface{}]*[]func(interface{}), 0) + } + } + cleanUpFunc(&session.afterInsertBeans) + cleanUpFunc(&session.afterUpdateBeans) + cleanUpFunc(&session.afterDeleteBeans) } return nil } diff --git a/session_update.go b/session_update.go index 231163e0..62116c47 100644 --- a/session_update.go +++ b/session_update.go @@ -12,23 +12,25 @@ import ( "strings" "xorm.io/builder" - "xorm.io/core" + "xorm.io/xorm/caches" + "xorm.io/xorm/internal/utils" + "xorm.io/xorm/schemas" ) -func (session *Session) cacheUpdate(table *core.Table, tableName, sqlStr string, args ...interface{}) error { +func (session *Session) cacheUpdate(table *schemas.Table, tableName, sqlStr string, args ...interface{}) error { if table == nil || session.tx != nil { return ErrCacheFailed } - oldhead, newsql := session.statement.convertUpdateSQL(sqlStr) + oldhead, newsql := session.statement.ConvertUpdateSQL(sqlStr) if newsql == "" { return ErrCacheFailed } for _, filter := range session.engine.dialect.Filters() { - newsql = filter.Do(newsql, session.engine.dialect, table) + newsql = filter.Do(newsql) } - session.engine.logger.Debug("[cacheUpdate] new sql", oldhead, newsql) + session.engine.logger.Debugf("[cache] new sql: %v, %v", oldhead, newsql) var nStart int if len(args) > 0 { @@ -40,9 +42,9 @@ func (session *Session) cacheUpdate(table *core.Table, tableName, sqlStr string, } } - cacher := session.engine.getCacher(tableName) - session.engine.logger.Debug("[cacheUpdate] get cache sql", newsql, args[nStart:]) - ids, err := core.GetCacheSql(cacher, tableName, newsql, args[nStart:]) + cacher := session.engine.GetCacher(tableName) + session.engine.logger.Debugf("[cache] get cache sql: %v, %v", newsql, args[nStart:]) + ids, err := caches.GetCacheSql(cacher, tableName, newsql, args[nStart:]) if err != nil { rows, err := session.NoCache().queryRows(newsql, args[nStart:]...) if err != nil { @@ -50,14 +52,14 @@ func (session *Session) cacheUpdate(table *core.Table, tableName, sqlStr string, } defer rows.Close() - ids = make([]core.PK, 0) + ids = make([]schemas.PK, 0) for rows.Next() { var res = make([]string, len(table.PrimaryKeys)) err = rows.ScanSlice(&res) if err != nil { return err } - var pk core.PK = make([]interface{}, len(table.PrimaryKeys)) + var pk schemas.PK = make([]interface{}, len(table.PrimaryKeys)) for i, col := range table.PKColumns() { if col.SQLType.IsNumeric() { n, err := strconv.ParseInt(res[i], 10, 64) @@ -74,7 +76,7 @@ func (session *Session) cacheUpdate(table *core.Table, tableName, sqlStr string, ids = append(ids, pk) } - session.engine.logger.Debug("[cacheUpdate] find updated id", ids) + session.engine.logger.Debugf("[cache] find updated id: %v", ids) } /*else { session.engine.LogDebug("[xorm:cacheUpdate] del cached sql:", tableName, newsql, args) cacher.DelIds(tableName, genSqlKey(newsql, args)) @@ -86,12 +88,12 @@ func (session *Session) cacheUpdate(table *core.Table, tableName, sqlStr string, return err } if bean := cacher.GetBean(tableName, sid); bean != nil { - sqls := splitNNoCase(sqlStr, "where", 2) + sqls := utils.SplitNNoCase(sqlStr, "where", 2) if len(sqls) == 0 || len(sqls) > 2 { return ErrCacheFailed } - sqls = splitNNoCase(sqls[0], "set", 2) + sqls = utils.SplitNNoCase(sqls[0], "set", 2) if len(sqls) != 2 { return ErrCacheFailed } @@ -101,38 +103,32 @@ func (session *Session) cacheUpdate(table *core.Table, tableName, sqlStr string, sps := strings.SplitN(kv, "=", 2) sps2 := strings.Split(sps[0], ".") colName := sps2[len(sps2)-1] - // treat quote prefix, suffix and '`' as quotes - quotes := append(strings.Split(session.engine.Quote(""), ""), "`") - if strings.ContainsAny(colName, strings.Join(quotes, "")) { - colName = strings.TrimSpace(eraseAny(colName, quotes...)) - } else { - session.engine.logger.Debug("[cacheUpdate] cannot find column", tableName, colName) - return ErrCacheFailed - } + colName = session.engine.dialect.Quoter().Trim(colName) + colName = schemas.CommonQuoter.Trim(colName) if col := table.GetColumn(colName); col != nil { fieldValue, err := col.ValueOf(bean) if err != nil { - session.engine.logger.Error(err) + session.engine.logger.Errorf("%v", err) } else { - session.engine.logger.Debug("[cacheUpdate] set bean field", bean, colName, fieldValue.Interface()) - if col.IsVersion && session.statement.checkVersion { + session.engine.logger.Debugf("[cache] set bean field: %v, %v, %v", bean, colName, fieldValue.Interface()) + if col.IsVersion && session.statement.CheckVersion { session.incrVersionFieldValue(fieldValue) } else { fieldValue.Set(reflect.ValueOf(args[idx])) } } } else { - session.engine.logger.Errorf("[cacheUpdate] ERROR: column %v is not table %v's", + session.engine.logger.Errorf("[cache] ERROR: column %v is not table %v's", colName, table.Name) } } - session.engine.logger.Debug("[cacheUpdate] update cache", tableName, id, bean) + session.engine.logger.Debugf("[cache] update cache: %v, %v, %v", tableName, id, bean) cacher.PutBean(tableName, sid, bean) } } - session.engine.logger.Debug("[cacheUpdate] clear cached table sql:", tableName) + session.engine.logger.Debugf("[cache] clear cached table sql: %v", tableName) cacher.ClearIds(tableName) return nil } @@ -148,11 +144,11 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 defer session.Close() } - if session.statement.lastError != nil { - return 0, session.statement.lastError + if session.statement.LastError != nil { + return 0, session.statement.LastError } - v := rValue(bean) + v := utils.ReflectValue(bean) t := v.Type() var colNames []string @@ -172,7 +168,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 var isMap = t.Kind() == reflect.Map var isStruct = t.Kind() == reflect.Struct if isStruct { - if err := session.statement.setRefBean(bean); err != nil { + if err := session.statement.SetRefBean(bean); err != nil { return 0, err } @@ -180,14 +176,14 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 return 0, ErrTableNotFound } - if session.statement.ColumnStr == "" { - colNames, args = session.statement.buildUpdates(bean, false, false, + if session.statement.ColumnStr() == "" { + colNames, args, err = session.statement.BuildUpdates(v, false, false, false, false, true) } else { colNames, args, err = session.genUpdateColumns(bean) - if err != nil { - return 0, err - } + } + if err != nil { + return 0, err } } else if isMap { colNames = make([]string, 0) @@ -205,8 +201,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 table := session.statement.RefTable if session.statement.UseAutoTime && table != nil && table.Updated != "" { - if !session.statement.columnMap.contain(table.Updated) && - !session.statement.omitColumnMap.contain(table.Updated) { + if !session.statement.ColumnMap.Contain(table.Updated) && + !session.statement.OmitColumnMap.Contain(table.Updated) { colNames = append(colNames, session.engine.Quote(table.Updated)+" = ?") col := table.UpdatedColumn() val, t := session.engine.nowTime(col) @@ -223,39 +219,45 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } // for update action to like "column = column + ?" - incColumns := session.statement.incrColumns - for i, colName := range incColumns.colNames { + incColumns := session.statement.IncrColumns + for i, colName := range incColumns.ColNames { colNames = append(colNames, session.engine.Quote(colName)+" = "+session.engine.Quote(colName)+" + ?") - args = append(args, incColumns.args[i]) + args = append(args, incColumns.Args[i]) } // for update action to like "column = column - ?" - decColumns := session.statement.decrColumns - for i, colName := range decColumns.colNames { + decColumns := session.statement.DecrColumns + for i, colName := range decColumns.ColNames { colNames = append(colNames, session.engine.Quote(colName)+" = "+session.engine.Quote(colName)+" - ?") - args = append(args, decColumns.args[i]) + args = append(args, decColumns.Args[i]) } // for update action to like "column = expression" - exprColumns := session.statement.exprColumns - for i, colName := range exprColumns.colNames { - switch tp := exprColumns.args[i].(type) { + exprColumns := session.statement.ExprColumns + for i, colName := range exprColumns.ColNames { + switch tp := exprColumns.Args[i].(type) { case string: - colNames = append(colNames, session.engine.Quote(colName)+" = "+tp) + if len(tp) == 0 { + tp = "''" + } + colNames = append(colNames, session.engine.Quote(colName)+"="+tp) case *builder.Builder: - subQuery, subArgs, err := builder.ToSQL(tp) + subQuery, subArgs, err := session.statement.GenCondSQL(tp) if err != nil { return 0, err } - colNames = append(colNames, session.engine.Quote(colName)+" = ("+subQuery+")") + colNames = append(colNames, session.engine.Quote(colName)+"=("+subQuery+")") args = append(args, subArgs...) + default: + colNames = append(colNames, session.engine.Quote(colName)+"=?") + args = append(args, exprColumns.Args[i]) } } - if err = session.statement.processIDParam(); err != nil { + if err = session.statement.ProcessIDParam(); err != nil { return 0, err } var autoCond builder.Cond - if !session.statement.noAutoCondition { + if !session.statement.NoAutoCondition { condBeanIsStruct := false if len(condiBean) > 0 { if c, ok := condiBean[0].(map[string]interface{}); ok { @@ -268,7 +270,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } if k == reflect.Struct { var err error - autoCond, err = session.statement.buildConds(session.statement.RefTable, condiBean[0], true, true, false, true, false) + autoCond, err = session.statement.BuildConds(session.statement.RefTable, condiBean[0], true, true, false, true, false) if err != nil { return 0, err } @@ -280,8 +282,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } if !condBeanIsStruct && table != nil { - if col := table.DeletedColumn(); col != nil && !session.statement.unscoped { // tag "deleted" is enabled - autoCond1 := session.engine.CondDeleted(session.engine.Quote(col.Name)) + if col := table.DeletedColumn(); col != nil && !session.statement.GetUnscoped() { // tag "deleted" is enabled + autoCond1 := session.statement.CondDeleted(col) if autoCond == nil { autoCond = autoCond1 @@ -292,26 +294,34 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } } - st := &session.statement + st := session.statement - var sqlStr string - var condArgs []interface{} - var condSQL string - cond := session.statement.cond.And(autoCond) + var ( + sqlStr string + condArgs []interface{} + condSQL string + cond = session.statement.Conds().And(autoCond) - var doIncVer = (table != nil && table.Version != "" && session.statement.checkVersion) - var verValue *reflect.Value + doIncVer = isStruct && (table != nil && table.Version != "" && session.statement.CheckVersion) + verValue *reflect.Value + ) if doIncVer { verValue, err = table.VersionColumn().ValueOf(bean) if err != nil { return 0, err } - cond = cond.And(builder.Eq{session.engine.Quote(table.Version): verValue.Interface()}) - colNames = append(colNames, session.engine.Quote(table.Version)+" = "+session.engine.Quote(table.Version)+" + 1") + if verValue != nil { + cond = cond.And(builder.Eq{session.engine.Quote(table.Version): verValue.Interface()}) + colNames = append(colNames, session.engine.Quote(table.Version)+" = "+session.engine.Quote(table.Version)+" + 1") + } } - condSQL, condArgs, err = builder.ToSQL(cond) + if len(colNames) <= 0 { + return 0, errors.New("No content found to be updated") + } + + condSQL, condArgs, err = session.statement.GenCondSQL(cond) if err != nil { return 0, err } @@ -327,25 +337,27 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 var tableName = session.statement.TableName() // TODO: Oracle support needed var top string - if st.LimitN > 0 { - if st.Engine.dialect.DBType() == core.MYSQL { - condSQL = condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN) - } else if st.Engine.dialect.DBType() == core.SQLITE { - tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN) + if st.LimitN != nil { + limitValue := *st.LimitN + switch session.engine.dialect.URI().DBType { + case schemas.MYSQL: + condSQL = condSQL + fmt.Sprintf(" LIMIT %d", limitValue) + case schemas.SQLITE: + tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue) cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)", session.engine.Quote(tableName), tempCondSQL), condArgs...)) - condSQL, condArgs, err = builder.ToSQL(cond) + condSQL, condArgs, err = session.statement.GenCondSQL(cond) if err != nil { return 0, err } if len(condSQL) > 0 { condSQL = "WHERE " + condSQL } - } else if st.Engine.dialect.DBType() == core.POSTGRES { - tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN) + case schemas.POSTGRES: + tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue) cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)", session.engine.Quote(tableName), tempCondSQL), condArgs...)) - condSQL, condArgs, err = builder.ToSQL(cond) + condSQL, condArgs, err = session.statement.GenCondSQL(cond) if err != nil { return 0, err } @@ -353,14 +365,13 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 if len(condSQL) > 0 { condSQL = "WHERE " + condSQL } - } else if st.Engine.dialect.DBType() == core.MSSQL { - if st.OrderStr != "" && st.Engine.dialect.DBType() == core.MSSQL && - table != nil && len(table.PrimaryKeys) == 1 { + case schemas.MSSQL: + if st.OrderStr != "" && table != nil && len(table.PrimaryKeys) == 1 { cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)", - table.PrimaryKeys[0], st.LimitN, table.PrimaryKeys[0], + table.PrimaryKeys[0], limitValue, table.PrimaryKeys[0], session.engine.Quote(tableName), condSQL), condArgs...) - condSQL, condArgs, err = builder.ToSQL(cond) + condSQL, condArgs, err = session.statement.GenCondSQL(cond) if err != nil { return 0, err } @@ -368,20 +379,16 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 condSQL = "WHERE " + condSQL } } else { - top = fmt.Sprintf("TOP (%d) ", st.LimitN) + top = fmt.Sprintf("TOP (%d) ", limitValue) } } } - if len(colNames) <= 0 { - return 0, errors.New("No content found to be updated") - } - var tableAlias = session.engine.Quote(tableName) var fromSQL string if session.statement.TableAlias != "" { - switch session.engine.dialect.DBType() { - case core.MSSQL: + switch session.engine.dialect.URI().DBType { + case schemas.MSSQL: fromSQL = fmt.Sprintf("FROM %s %s ", tableAlias, session.statement.TableAlias) tableAlias = session.statement.TableAlias default: @@ -405,9 +412,9 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } } - if cacher := session.engine.getCacher(tableName); cacher != nil && session.statement.UseCache { + if cacher := session.engine.GetCacher(tableName); cacher != nil && session.statement.UseCache { // session.cacheUpdate(table, tableName, sqlStr, args...) - session.engine.logger.Debug("[cacheUpdate] clear table ", tableName) + session.engine.logger.Debugf("[cache] clear table: %v", tableName) cacher.ClearIds(tableName) cacher.ClearBeans(tableName) } @@ -418,7 +425,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 closure(bean) } if processor, ok := interface{}(bean).(AfterUpdateProcessor); ok { - session.engine.logger.Debug("[event]", tableName, " has after update processor") + session.engine.logger.Debugf("[event] %v has after update processor", tableName) processor.AfterUpdate() } } else { @@ -452,11 +459,11 @@ func (session *Session) genUpdateColumns(bean interface{}) ([]string, []interfac for _, col := range table.Columns() { if !col.IsVersion && !col.IsCreated && !col.IsUpdated { - if session.statement.omitColumnMap.contain(col.Name) { + if session.statement.OmitColumnMap.Contain(col.Name) { continue } } - if col.MapType == core.ONLYFROMDB { + if col.MapType == schemas.ONLYFROMDB { continue } @@ -466,47 +473,30 @@ func (session *Session) genUpdateColumns(bean interface{}) ([]string, []interfac } fieldValue := *fieldValuePtr - if col.IsAutoIncrement { - switch fieldValue.Type().Kind() { - case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int, reflect.Int64: - if fieldValue.Int() == 0 { - continue - } - case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint, reflect.Uint64: - if fieldValue.Uint() == 0 { - continue - } - case reflect.String: - if len(fieldValue.String()) == 0 { - continue - } - case reflect.Ptr: - if fieldValue.Pointer() == 0 { - continue - } - } + if col.IsAutoIncrement && utils.IsValueZero(fieldValue) { + continue } - if (col.IsDeleted && !session.statement.unscoped) || col.IsCreated { + if (col.IsDeleted && !session.statement.GetUnscoped()) || col.IsCreated { continue } // if only update specify columns - if len(session.statement.columnMap) > 0 && !session.statement.columnMap.contain(col.Name) { + if len(session.statement.ColumnMap) > 0 && !session.statement.ColumnMap.Contain(col.Name) { continue } - if session.statement.incrColumns.isColExist(col.Name) { + if session.statement.IncrColumns.IsColExist(col.Name) { continue - } else if session.statement.decrColumns.isColExist(col.Name) { + } else if session.statement.DecrColumns.IsColExist(col.Name) { continue - } else if session.statement.exprColumns.isColExist(col.Name) { + } else if session.statement.ExprColumns.IsColExist(col.Name) { continue } // !evalphobia! set fieldValue as nil when column is nullable and zero-value - if _, ok := getFlagForColumn(session.statement.nullableMap, col); ok { - if col.Nullable && isZero(fieldValue.Interface()) { + if _, ok := getFlagForColumn(session.statement.NullableMap, col); ok { + if col.Nullable && utils.IsValueZero(fieldValue) { var nilValue *int fieldValue = reflect.ValueOf(nilValue) } @@ -522,10 +512,10 @@ func (session *Session) genUpdateColumns(bean interface{}) ([]string, []interfac col := table.GetColumn(colName) setColumnTime(bean, col, t) }) - } else if col.IsVersion && session.statement.checkVersion { + } else if col.IsVersion && session.statement.CheckVersion { args = append(args, 1) } else { - arg, err := session.value2Interface(col, fieldValue) + arg, err := session.statement.Value2Interface(col, fieldValue) if err != nil { return colNames, args, err } diff --git a/statement.go b/statement.go deleted file mode 100644 index 67e35213..00000000 --- a/statement.go +++ /dev/null @@ -1,1256 +0,0 @@ -// Copyright 2015 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package xorm - -import ( - "database/sql/driver" - "fmt" - "reflect" - "strings" - "time" - - "xorm.io/builder" - "xorm.io/core" -) - -// Statement save all the sql info for executing SQL -type Statement struct { - RefTable *core.Table - Engine *Engine - Start int - LimitN int - idParam *core.PK - OrderStr string - JoinStr string - joinArgs []interface{} - GroupByStr string - HavingStr string - ColumnStr string - selectStr string - useAllCols bool - OmitStr string - AltTableName string - tableName string - RawSQL string - RawParams []interface{} - UseCascade bool - UseAutoJoin bool - StoreEngine string - Charset string - UseCache bool - UseAutoTime bool - noAutoCondition bool - IsDistinct bool - IsForUpdate bool - TableAlias string - allUseBool bool - checkVersion bool - unscoped bool - columnMap columnMap - omitColumnMap columnMap - mustColumnMap map[string]bool - nullableMap map[string]bool - incrColumns exprParams - decrColumns exprParams - exprColumns exprParams - cond builder.Cond - bufferSize int - context ContextCache - lastError error -} - -// Init reset all the statement's fields -func (statement *Statement) Init() { - statement.RefTable = nil - statement.Start = 0 - statement.LimitN = 0 - statement.OrderStr = "" - statement.UseCascade = true - statement.JoinStr = "" - statement.joinArgs = make([]interface{}, 0) - statement.GroupByStr = "" - statement.HavingStr = "" - statement.ColumnStr = "" - statement.OmitStr = "" - statement.columnMap = columnMap{} - statement.omitColumnMap = columnMap{} - statement.AltTableName = "" - statement.tableName = "" - statement.idParam = nil - statement.RawSQL = "" - statement.RawParams = make([]interface{}, 0) - statement.UseCache = true - statement.UseAutoTime = true - statement.noAutoCondition = false - statement.IsDistinct = false - statement.IsForUpdate = false - statement.TableAlias = "" - statement.selectStr = "" - statement.allUseBool = false - statement.useAllCols = false - statement.mustColumnMap = make(map[string]bool) - statement.nullableMap = make(map[string]bool) - statement.checkVersion = true - statement.unscoped = false - statement.incrColumns = exprParams{} - statement.decrColumns = exprParams{} - statement.exprColumns = exprParams{} - statement.cond = builder.NewCond() - statement.bufferSize = 0 - statement.context = nil - statement.lastError = nil -} - -// NoAutoCondition if you do not want convert bean's field as query condition, then use this function -func (statement *Statement) NoAutoCondition(no ...bool) *Statement { - statement.noAutoCondition = true - if len(no) > 0 { - statement.noAutoCondition = no[0] - } - return statement -} - -// Alias set the table alias -func (statement *Statement) Alias(alias string) *Statement { - statement.TableAlias = alias - return statement -} - -// SQL adds raw sql statement -func (statement *Statement) SQL(query interface{}, args ...interface{}) *Statement { - switch query.(type) { - case (*builder.Builder): - var err error - statement.RawSQL, statement.RawParams, err = query.(*builder.Builder).ToSQL() - if err != nil { - statement.lastError = err - } - case string: - statement.RawSQL = query.(string) - statement.RawParams = args - default: - statement.lastError = ErrUnSupportedSQLType - } - - return statement -} - -// Where add Where statement -func (statement *Statement) Where(query interface{}, args ...interface{}) *Statement { - return statement.And(query, args...) -} - -// And add Where & and statement -func (statement *Statement) And(query interface{}, args ...interface{}) *Statement { - switch query.(type) { - case string: - cond := builder.Expr(query.(string), args...) - statement.cond = statement.cond.And(cond) - case map[string]interface{}: - queryMap := query.(map[string]interface{}) - newMap := make(map[string]interface{}) - for k, v := range queryMap { - newMap[statement.Engine.Quote(k)] = v - } - statement.cond = statement.cond.And(builder.Eq(newMap)) - case builder.Cond: - cond := query.(builder.Cond) - statement.cond = statement.cond.And(cond) - for _, v := range args { - if vv, ok := v.(builder.Cond); ok { - statement.cond = statement.cond.And(vv) - } - } - default: - statement.lastError = ErrConditionType - } - - return statement -} - -// Or add Where & Or statement -func (statement *Statement) Or(query interface{}, args ...interface{}) *Statement { - switch query.(type) { - case string: - cond := builder.Expr(query.(string), args...) - statement.cond = statement.cond.Or(cond) - case map[string]interface{}: - cond := builder.Eq(query.(map[string]interface{})) - statement.cond = statement.cond.Or(cond) - case builder.Cond: - cond := query.(builder.Cond) - statement.cond = statement.cond.Or(cond) - for _, v := range args { - if vv, ok := v.(builder.Cond); ok { - statement.cond = statement.cond.Or(vv) - } - } - default: - // TODO: not support condition type - } - return statement -} - -// In generate "Where column IN (?) " statement -func (statement *Statement) In(column string, args ...interface{}) *Statement { - in := builder.In(statement.Engine.Quote(column), args...) - statement.cond = statement.cond.And(in) - return statement -} - -// NotIn generate "Where column NOT IN (?) " statement -func (statement *Statement) NotIn(column string, args ...interface{}) *Statement { - notIn := builder.NotIn(statement.Engine.Quote(column), args...) - statement.cond = statement.cond.And(notIn) - return statement -} - -func (statement *Statement) setRefValue(v reflect.Value) error { - var err error - statement.RefTable, err = statement.Engine.autoMapType(reflect.Indirect(v)) - if err != nil { - return err - } - statement.tableName = statement.Engine.TableName(v, true) - return nil -} - -func (statement *Statement) setRefBean(bean interface{}) error { - var err error - statement.RefTable, err = statement.Engine.autoMapType(rValue(bean)) - if err != nil { - return err - } - statement.tableName = statement.Engine.TableName(bean, true) - return nil -} - -// Auto generating update columnes and values according a struct -func (statement *Statement) buildUpdates(bean interface{}, - includeVersion, includeUpdated, includeNil, - includeAutoIncr, update bool) ([]string, []interface{}) { - engine := statement.Engine - table := statement.RefTable - allUseBool := statement.allUseBool - useAllCols := statement.useAllCols - mustColumnMap := statement.mustColumnMap - nullableMap := statement.nullableMap - columnMap := statement.columnMap - omitColumnMap := statement.omitColumnMap - unscoped := statement.unscoped - - var colNames = make([]string, 0) - var args = make([]interface{}, 0) - for _, col := range table.Columns() { - if !includeVersion && col.IsVersion { - continue - } - if col.IsCreated { - continue - } - if !includeUpdated && col.IsUpdated { - continue - } - if !includeAutoIncr && col.IsAutoIncrement { - continue - } - if col.IsDeleted && !unscoped { - continue - } - if omitColumnMap.contain(col.Name) { - continue - } - if len(columnMap) > 0 && !columnMap.contain(col.Name) { - continue - } - - if col.MapType == core.ONLYFROMDB { - continue - } - - if statement.incrColumns.isColExist(col.Name) { - continue - } else if statement.decrColumns.isColExist(col.Name) { - continue - } else if statement.exprColumns.isColExist(col.Name) { - continue - } - - fieldValuePtr, err := col.ValueOf(bean) - if err != nil { - engine.logger.Error(err) - continue - } - - fieldValue := *fieldValuePtr - fieldType := reflect.TypeOf(fieldValue.Interface()) - if fieldType == nil { - continue - } - - requiredField := useAllCols - includeNil := useAllCols - - if b, ok := getFlagForColumn(mustColumnMap, col); ok { - if b { - requiredField = true - } else { - continue - } - } - - // !evalphobia! set fieldValue as nil when column is nullable and zero-value - if b, ok := getFlagForColumn(nullableMap, col); ok { - if b && col.Nullable && isZero(fieldValue.Interface()) { - var nilValue *int - fieldValue = reflect.ValueOf(nilValue) - fieldType = reflect.TypeOf(fieldValue.Interface()) - includeNil = true - } - } - - var val interface{} - - if fieldValue.CanAddr() { - if structConvert, ok := fieldValue.Addr().Interface().(core.Conversion); ok { - data, err := structConvert.ToDB() - if err != nil { - engine.logger.Error(err) - } else { - val = data - } - goto APPEND - } - } - - if structConvert, ok := fieldValue.Interface().(core.Conversion); ok { - data, err := structConvert.ToDB() - if err != nil { - engine.logger.Error(err) - } else { - val = data - } - goto APPEND - } - - if fieldType.Kind() == reflect.Ptr { - if fieldValue.IsNil() { - if includeNil { - args = append(args, nil) - colNames = append(colNames, fmt.Sprintf("%v=?", engine.Quote(col.Name))) - } - continue - } else if !fieldValue.IsValid() { - continue - } else { - // dereference ptr type to instance type - fieldValue = fieldValue.Elem() - fieldType = reflect.TypeOf(fieldValue.Interface()) - requiredField = true - } - } - - switch fieldType.Kind() { - case reflect.Bool: - if allUseBool || requiredField { - val = fieldValue.Interface() - } else { - // if a bool in a struct, it will not be as a condition because it default is false, - // please use Where() instead - continue - } - case reflect.String: - if !requiredField && fieldValue.String() == "" { - continue - } - // for MyString, should convert to string or panic - if fieldType.String() != reflect.String.String() { - val = fieldValue.String() - } else { - val = fieldValue.Interface() - } - case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64: - if !requiredField && fieldValue.Int() == 0 { - continue - } - val = fieldValue.Interface() - case reflect.Float32, reflect.Float64: - if !requiredField && fieldValue.Float() == 0.0 { - continue - } - val = fieldValue.Interface() - case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64: - if !requiredField && fieldValue.Uint() == 0 { - continue - } - t := int64(fieldValue.Uint()) - val = reflect.ValueOf(&t).Interface() - case reflect.Struct: - if fieldType.ConvertibleTo(core.TimeType) { - t := fieldValue.Convert(core.TimeType).Interface().(time.Time) - if !requiredField && (t.IsZero() || !fieldValue.IsValid()) { - continue - } - val = engine.formatColTime(col, t) - } else if nulType, ok := fieldValue.Interface().(driver.Valuer); ok { - val, _ = nulType.Value() - } else { - if !col.SQLType.IsJson() { - engine.autoMapType(fieldValue) - if table, ok := engine.Tables[fieldValue.Type()]; ok { - if len(table.PrimaryKeys) == 1 { - pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName) - // fix non-int pk issues - if pkField.IsValid() && (!requiredField && !isZero(pkField.Interface())) { - val = pkField.Interface() - } else { - continue - } - } else { - // TODO: how to handler? - panic("not supported") - } - } else { - val = fieldValue.Interface() - } - } else { - // Blank struct could not be as update data - if requiredField || !isStructZero(fieldValue) { - bytes, err := DefaultJSONHandler.Marshal(fieldValue.Interface()) - if err != nil { - panic(fmt.Sprintf("mashal %v failed", fieldValue.Interface())) - } - if col.SQLType.IsText() { - val = string(bytes) - } else if col.SQLType.IsBlob() { - val = bytes - } - } else { - continue - } - } - } - case reflect.Array, reflect.Slice, reflect.Map: - if !requiredField { - if fieldValue == reflect.Zero(fieldType) { - continue - } - if fieldType.Kind() == reflect.Array { - if isArrayValueZero(fieldValue) { - continue - } - } else if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 { - continue - } - } - - if col.SQLType.IsText() { - bytes, err := DefaultJSONHandler.Marshal(fieldValue.Interface()) - if err != nil { - engine.logger.Error(err) - continue - } - val = string(bytes) - } else if col.SQLType.IsBlob() { - var bytes []byte - var err error - if fieldType.Kind() == reflect.Slice && - fieldType.Elem().Kind() == reflect.Uint8 { - if fieldValue.Len() > 0 { - val = fieldValue.Bytes() - } else { - continue - } - } else if fieldType.Kind() == reflect.Array && - fieldType.Elem().Kind() == reflect.Uint8 { - val = fieldValue.Slice(0, 0).Interface() - } else { - bytes, err = DefaultJSONHandler.Marshal(fieldValue.Interface()) - if err != nil { - engine.logger.Error(err) - continue - } - val = bytes - } - } else { - continue - } - default: - val = fieldValue.Interface() - } - - APPEND: - args = append(args, val) - if col.IsPrimaryKey && engine.dialect.DBType() == "ql" { - continue - } - colNames = append(colNames, fmt.Sprintf("%v = ?", engine.Quote(col.Name))) - } - - return colNames, args -} - -func (statement *Statement) needTableName() bool { - return len(statement.JoinStr) > 0 -} - -func (statement *Statement) colName(col *core.Column, tableName string) string { - if statement.needTableName() { - var nm = tableName - if len(statement.TableAlias) > 0 { - nm = statement.TableAlias - } - return statement.Engine.Quote(nm) + "." + statement.Engine.Quote(col.Name) - } - return statement.Engine.Quote(col.Name) -} - -// TableName return current tableName -func (statement *Statement) TableName() string { - if statement.AltTableName != "" { - return statement.AltTableName - } - - return statement.tableName -} - -// ID generate "where id = ? " statement or for composite key "where key1 = ? and key2 = ?" -func (statement *Statement) ID(id interface{}) *Statement { - idValue := reflect.ValueOf(id) - idType := reflect.TypeOf(idValue.Interface()) - - switch idType { - case ptrPkType: - if pkPtr, ok := (id).(*core.PK); ok { - statement.idParam = pkPtr - return statement - } - case pkType: - if pk, ok := (id).(core.PK); ok { - statement.idParam = &pk - return statement - } - } - - switch idType.Kind() { - case reflect.String: - statement.idParam = &core.PK{idValue.Convert(reflect.TypeOf("")).Interface()} - return statement - } - - statement.idParam = &core.PK{id} - return statement -} - -// Incr Generate "Update ... Set column = column + arg" statement -func (statement *Statement) Incr(column string, arg ...interface{}) *Statement { - if len(arg) > 0 { - statement.incrColumns.addParam(column, arg[0]) - } else { - statement.incrColumns.addParam(column, 1) - } - return statement -} - -// Decr Generate "Update ... Set column = column - arg" statement -func (statement *Statement) Decr(column string, arg ...interface{}) *Statement { - if len(arg) > 0 { - statement.decrColumns.addParam(column, arg[0]) - } else { - statement.decrColumns.addParam(column, 1) - } - return statement -} - -// SetExpr Generate "Update ... Set column = {expression}" statement -func (statement *Statement) SetExpr(column string, expression interface{}) *Statement { - statement.exprColumns.addParam(column, expression) - return statement -} - -func (statement *Statement) col2NewColsWithQuote(columns ...string) []string { - newColumns := make([]string, 0) - quotes := append(strings.Split(statement.Engine.Quote(""), ""), "`") - for _, col := range columns { - newColumns = append(newColumns, statement.Engine.Quote(eraseAny(col, quotes...))) - } - return newColumns -} - -func (statement *Statement) colmap2NewColsWithQuote() []string { - newColumns := make([]string, len(statement.columnMap), len(statement.columnMap)) - copy(newColumns, statement.columnMap) - for i := 0; i < len(statement.columnMap); i++ { - newColumns[i] = statement.Engine.Quote(newColumns[i]) - } - return newColumns -} - -// Distinct generates "DISTINCT col1, col2 " statement -func (statement *Statement) Distinct(columns ...string) *Statement { - statement.IsDistinct = true - statement.Cols(columns...) - return statement -} - -// ForUpdate generates "SELECT ... FOR UPDATE" statement -func (statement *Statement) ForUpdate() *Statement { - statement.IsForUpdate = true - return statement -} - -// Select replace select -func (statement *Statement) Select(str string) *Statement { - statement.selectStr = str - return statement -} - -// Cols generate "col1, col2" statement -func (statement *Statement) Cols(columns ...string) *Statement { - cols := col2NewCols(columns...) - for _, nc := range cols { - statement.columnMap.add(nc) - } - - newColumns := statement.colmap2NewColsWithQuote() - - statement.ColumnStr = strings.Join(newColumns, ", ") - statement.ColumnStr = strings.Replace(statement.ColumnStr, statement.Engine.quote("*"), "*", -1) - return statement -} - -// AllCols update use only: update all columns -func (statement *Statement) AllCols() *Statement { - statement.useAllCols = true - return statement -} - -// MustCols update use only: must update columns -func (statement *Statement) MustCols(columns ...string) *Statement { - newColumns := col2NewCols(columns...) - for _, nc := range newColumns { - statement.mustColumnMap[strings.ToLower(nc)] = true - } - return statement -} - -// UseBool indicates that use bool fields as update contents and query contiditions -func (statement *Statement) UseBool(columns ...string) *Statement { - if len(columns) > 0 { - statement.MustCols(columns...) - } else { - statement.allUseBool = true - } - return statement -} - -// Omit do not use the columns -func (statement *Statement) Omit(columns ...string) { - newColumns := col2NewCols(columns...) - for _, nc := range newColumns { - statement.omitColumnMap = append(statement.omitColumnMap, nc) - } - statement.OmitStr = statement.Engine.Quote(strings.Join(newColumns, statement.Engine.Quote(", "))) -} - -// Nullable Update use only: update columns to null when value is nullable and zero-value -func (statement *Statement) Nullable(columns ...string) { - newColumns := col2NewCols(columns...) - for _, nc := range newColumns { - statement.nullableMap[strings.ToLower(nc)] = true - } -} - -// Top generate LIMIT limit statement -func (statement *Statement) Top(limit int) *Statement { - statement.Limit(limit) - return statement -} - -// Limit generate LIMIT start, limit statement -func (statement *Statement) Limit(limit int, start ...int) *Statement { - statement.LimitN = limit - if len(start) > 0 { - statement.Start = start[0] - } - return statement -} - -// OrderBy generate "Order By order" statement -func (statement *Statement) OrderBy(order string) *Statement { - if len(statement.OrderStr) > 0 { - statement.OrderStr += ", " - } - statement.OrderStr += order - return statement -} - -// Desc generate `ORDER BY xx DESC` -func (statement *Statement) Desc(colNames ...string) *Statement { - var buf strings.Builder - if len(statement.OrderStr) > 0 { - fmt.Fprint(&buf, statement.OrderStr, ", ") - } - newColNames := statement.col2NewColsWithQuote(colNames...) - fmt.Fprintf(&buf, "%v DESC", strings.Join(newColNames, " DESC, ")) - statement.OrderStr = buf.String() - return statement -} - -// Asc provide asc order by query condition, the input parameters are columns. -func (statement *Statement) Asc(colNames ...string) *Statement { - var buf strings.Builder - if len(statement.OrderStr) > 0 { - fmt.Fprint(&buf, statement.OrderStr, ", ") - } - newColNames := statement.col2NewColsWithQuote(colNames...) - fmt.Fprintf(&buf, "%v ASC", strings.Join(newColNames, " ASC, ")) - statement.OrderStr = buf.String() - return statement -} - -// Table tempororily set table name, the parameter could be a string or a pointer of struct -func (statement *Statement) Table(tableNameOrBean interface{}) *Statement { - v := rValue(tableNameOrBean) - t := v.Type() - if t.Kind() == reflect.Struct { - var err error - statement.RefTable, err = statement.Engine.autoMapType(v) - if err != nil { - statement.Engine.logger.Error(err) - return statement - } - } - - statement.AltTableName = statement.Engine.TableName(tableNameOrBean, true) - return statement -} - -// Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN -func (statement *Statement) Join(joinOP string, tablename interface{}, condition string, args ...interface{}) *Statement { - var buf strings.Builder - if len(statement.JoinStr) > 0 { - fmt.Fprintf(&buf, "%v %v JOIN ", statement.JoinStr, joinOP) - } else { - fmt.Fprintf(&buf, "%v JOIN ", joinOP) - } - - switch tp := tablename.(type) { - case builder.Builder: - subSQL, subQueryArgs, err := tp.ToSQL() - if err != nil { - statement.lastError = err - return statement - } - tbs := strings.Split(tp.TableName(), ".") - quotes := append(strings.Split(statement.Engine.Quote(""), ""), "`") - - var aliasName = strings.Trim(tbs[len(tbs)-1], strings.Join(quotes, "")) - fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition) - statement.joinArgs = append(statement.joinArgs, subQueryArgs...) - case *builder.Builder: - subSQL, subQueryArgs, err := tp.ToSQL() - if err != nil { - statement.lastError = err - return statement - } - tbs := strings.Split(tp.TableName(), ".") - quotes := append(strings.Split(statement.Engine.Quote(""), ""), "`") - - var aliasName = strings.Trim(tbs[len(tbs)-1], strings.Join(quotes, "")) - fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition) - statement.joinArgs = append(statement.joinArgs, subQueryArgs...) - default: - tbName := statement.Engine.TableName(tablename, true) - fmt.Fprintf(&buf, "%s ON %v", tbName, condition) - } - - statement.JoinStr = buf.String() - statement.joinArgs = append(statement.joinArgs, args...) - return statement -} - -// GroupBy generate "Group By keys" statement -func (statement *Statement) GroupBy(keys string) *Statement { - statement.GroupByStr = keys - return statement -} - -// Having generate "Having conditions" statement -func (statement *Statement) Having(conditions string) *Statement { - statement.HavingStr = fmt.Sprintf("HAVING %v", conditions) - return statement -} - -// Unscoped always disable struct tag "deleted" -func (statement *Statement) Unscoped() *Statement { - statement.unscoped = true - return statement -} - -func (statement *Statement) genColumnStr() string { - if statement.RefTable == nil { - return "" - } - - var buf strings.Builder - columns := statement.RefTable.Columns() - - for _, col := range columns { - if statement.omitColumnMap.contain(col.Name) { - continue - } - - if len(statement.columnMap) > 0 && !statement.columnMap.contain(col.Name) { - continue - } - - if col.MapType == core.ONLYTODB { - continue - } - - if buf.Len() != 0 { - buf.WriteString(", ") - } - - if statement.JoinStr != "" { - if statement.TableAlias != "" { - buf.WriteString(statement.TableAlias) - } else { - buf.WriteString(statement.TableName()) - } - - buf.WriteString(".") - } - - statement.Engine.QuoteTo(&buf, col.Name) - } - - return buf.String() -} - -func (statement *Statement) genCreateTableSQL() string { - return statement.Engine.dialect.CreateTableSql(statement.RefTable, statement.TableName(), - statement.StoreEngine, statement.Charset) -} - -func (statement *Statement) genIndexSQL() []string { - var sqls []string - tbName := statement.TableName() - for _, index := range statement.RefTable.Indexes { - if index.Type == core.IndexType { - sql := statement.Engine.dialect.CreateIndexSql(tbName, index) - /*idxTBName := strings.Replace(tbName, ".", "_", -1) - idxTBName = strings.Replace(idxTBName, `"`, "", -1) - sql := fmt.Sprintf("CREATE INDEX %v ON %v (%v);", quote(indexName(idxTBName, idxName)), - quote(tbName), quote(strings.Join(index.Cols, quote(","))))*/ - sqls = append(sqls, sql) - } - } - return sqls -} - -func uniqueName(tableName, uqeName string) string { - return fmt.Sprintf("UQE_%v_%v", tableName, uqeName) -} - -func (statement *Statement) genUniqueSQL() []string { - var sqls []string - tbName := statement.TableName() - for _, index := range statement.RefTable.Indexes { - if index.Type == core.UniqueType { - sql := statement.Engine.dialect.CreateIndexSql(tbName, index) - sqls = append(sqls, sql) - } - } - return sqls -} - -func (statement *Statement) genDelIndexSQL() []string { - var sqls []string - tbName := statement.TableName() - idxPrefixName := strings.Replace(tbName, `"`, "", -1) - idxPrefixName = strings.Replace(idxPrefixName, `.`, "_", -1) - for idxName, index := range statement.RefTable.Indexes { - var rIdxName string - if index.Type == core.UniqueType { - rIdxName = uniqueName(idxPrefixName, idxName) - } else if index.Type == core.IndexType { - rIdxName = indexName(idxPrefixName, idxName) - } - sql := fmt.Sprintf("DROP INDEX %v", statement.Engine.Quote(statement.Engine.TableName(rIdxName, true))) - if statement.Engine.dialect.IndexOnTable() { - sql += fmt.Sprintf(" ON %v", statement.Engine.Quote(tbName)) - } - sqls = append(sqls, sql) - } - return sqls -} - -func (statement *Statement) genAddColumnStr(col *core.Column) (string, []interface{}) { - quote := statement.Engine.Quote - sql := fmt.Sprintf("ALTER TABLE %v ADD %v", quote(statement.TableName()), - col.String(statement.Engine.dialect)) - if statement.Engine.dialect.DBType() == core.MYSQL && len(col.Comment) > 0 { - sql += " COMMENT '" + col.Comment + "'" - } - sql += ";" - return sql, []interface{}{} -} - -func (statement *Statement) buildConds(table *core.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, addedTableName bool) (builder.Cond, error) { - return statement.Engine.buildConds(table, bean, includeVersion, includeUpdated, includeNil, includeAutoIncr, statement.allUseBool, statement.useAllCols, - statement.unscoped, statement.mustColumnMap, statement.TableName(), statement.TableAlias, addedTableName) -} - -func (statement *Statement) mergeConds(bean interface{}) error { - if !statement.noAutoCondition { - var addedTableName = (len(statement.JoinStr) > 0) - autoCond, err := statement.buildConds(statement.RefTable, bean, true, true, false, true, addedTableName) - if err != nil { - return err - } - statement.cond = statement.cond.And(autoCond) - } - - if err := statement.processIDParam(); err != nil { - return err - } - return nil -} - -func (statement *Statement) genConds(bean interface{}) (string, []interface{}, error) { - if err := statement.mergeConds(bean); err != nil { - return "", nil, err - } - - return builder.ToSQL(statement.cond) -} - -func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}, error) { - v := rValue(bean) - isStruct := v.Kind() == reflect.Struct - if isStruct { - statement.setRefBean(bean) - } - - var columnStr = statement.ColumnStr - if len(statement.selectStr) > 0 { - columnStr = statement.selectStr - } else { - // TODO: always generate column names, not use * even if join - if len(statement.JoinStr) == 0 { - if len(columnStr) == 0 { - if len(statement.GroupByStr) > 0 { - columnStr = statement.Engine.quoteColumns(statement.GroupByStr) - } else { - columnStr = statement.genColumnStr() - } - } - } else { - if len(columnStr) == 0 { - if len(statement.GroupByStr) > 0 { - columnStr = statement.Engine.quoteColumns(statement.GroupByStr) - } - } - } - } - - if len(columnStr) == 0 { - columnStr = "*" - } - - if isStruct { - if err := statement.mergeConds(bean); err != nil { - return "", nil, err - } - } else { - if err := statement.processIDParam(); err != nil { - return "", nil, err - } - } - condSQL, condArgs, err := builder.ToSQL(statement.cond) - if err != nil { - return "", nil, err - } - - sqlStr, err := statement.genSelectSQL(columnStr, condSQL, true, true) - if err != nil { - return "", nil, err - } - - return sqlStr, append(statement.joinArgs, condArgs...), nil -} - -func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interface{}, error) { - var condSQL string - var condArgs []interface{} - var err error - if len(beans) > 0 { - statement.setRefBean(beans[0]) - condSQL, condArgs, err = statement.genConds(beans[0]) - } else { - condSQL, condArgs, err = builder.ToSQL(statement.cond) - } - if err != nil { - return "", nil, err - } - - var selectSQL = statement.selectStr - if len(selectSQL) <= 0 { - if statement.IsDistinct { - selectSQL = fmt.Sprintf("count(DISTINCT %s)", statement.ColumnStr) - } else { - selectSQL = "count(*)" - } - } - sqlStr, err := statement.genSelectSQL(selectSQL, condSQL, false, false) - if err != nil { - return "", nil, err - } - - return sqlStr, append(statement.joinArgs, condArgs...), nil -} - -func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (string, []interface{}, error) { - statement.setRefBean(bean) - - var sumStrs = make([]string, 0, len(columns)) - for _, colName := range columns { - if !strings.Contains(colName, " ") && !strings.Contains(colName, "(") { - colName = statement.Engine.Quote(colName) - } - sumStrs = append(sumStrs, fmt.Sprintf("COALESCE(sum(%s),0)", colName)) - } - sumSelect := strings.Join(sumStrs, ", ") - - condSQL, condArgs, err := statement.genConds(bean) - if err != nil { - return "", nil, err - } - - sqlStr, err := statement.genSelectSQL(sumSelect, condSQL, true, true) - if err != nil { - return "", nil, err - } - - return sqlStr, append(statement.joinArgs, condArgs...), nil -} - -func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, needOrderBy bool) (string, error) { - var ( - distinct string - dialect = statement.Engine.Dialect() - quote = statement.Engine.Quote - fromStr = " FROM " - top, mssqlCondi, whereStr string - ) - if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") { - distinct = "DISTINCT " - } - if len(condSQL) > 0 { - whereStr = " WHERE " + condSQL - } - - if dialect.DBType() == core.MSSQL && strings.Contains(statement.TableName(), "..") { - fromStr += statement.TableName() - } else { - fromStr += quote(statement.TableName()) - } - - if statement.TableAlias != "" { - if dialect.DBType() == core.ORACLE { - fromStr += " " + quote(statement.TableAlias) - } else { - fromStr += " AS " + quote(statement.TableAlias) - } - } - if statement.JoinStr != "" { - fromStr = fmt.Sprintf("%v %v", fromStr, statement.JoinStr) - } - - if dialect.DBType() == core.MSSQL { - if statement.LimitN > 0 { - top = fmt.Sprintf("TOP %d ", statement.LimitN) - } - if statement.Start > 0 { - var column string - if len(statement.RefTable.PKColumns()) == 0 { - for _, index := range statement.RefTable.Indexes { - if len(index.Cols) == 1 { - column = index.Cols[0] - break - } - } - if len(column) == 0 { - column = statement.RefTable.ColumnsSeq()[0] - } - } else { - column = statement.RefTable.PKColumns()[0].Name - } - if statement.needTableName() { - if len(statement.TableAlias) > 0 { - column = statement.TableAlias + "." + column - } else { - column = statement.TableName() + "." + column - } - } - - var orderStr string - if needOrderBy && len(statement.OrderStr) > 0 { - orderStr = " ORDER BY " + statement.OrderStr - } - - var groupStr string - if len(statement.GroupByStr) > 0 { - groupStr = " GROUP BY " + statement.GroupByStr - } - mssqlCondi = fmt.Sprintf("(%s NOT IN (SELECT TOP %d %s%s%s%s%s))", - column, statement.Start, column, fromStr, whereStr, orderStr, groupStr) - } - } - - var buf strings.Builder - fmt.Fprintf(&buf, "SELECT %v%v%v%v%v", distinct, top, columnStr, fromStr, whereStr) - if len(mssqlCondi) > 0 { - if len(whereStr) > 0 { - fmt.Fprint(&buf, " AND ", mssqlCondi) - } else { - fmt.Fprint(&buf, " WHERE ", mssqlCondi) - } - } - - if statement.GroupByStr != "" { - fmt.Fprint(&buf, " GROUP BY ", statement.GroupByStr) - } - if statement.HavingStr != "" { - fmt.Fprint(&buf, " ", statement.HavingStr) - } - if needOrderBy && statement.OrderStr != "" { - fmt.Fprint(&buf, " ORDER BY ", statement.OrderStr) - } - if needLimit { - if dialect.DBType() != core.MSSQL && dialect.DBType() != core.ORACLE { - if statement.Start > 0 { - fmt.Fprintf(&buf, " LIMIT %v OFFSET %v", statement.LimitN, statement.Start) - } else if statement.LimitN > 0 { - fmt.Fprint(&buf, " LIMIT ", statement.LimitN) - } - } else if dialect.DBType() == core.ORACLE { - if statement.Start != 0 || statement.LimitN != 0 { - oldString := buf.String() - buf.Reset() - rawColStr := columnStr - if rawColStr == "*" { - rawColStr = "at.*" - } - fmt.Fprintf(&buf, "SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", - columnStr, rawColStr, oldString, statement.Start+statement.LimitN, statement.Start) - } - } - } - if statement.IsForUpdate { - return dialect.ForUpdateSql(buf.String()), nil - } - - return buf.String(), nil -} - -func (statement *Statement) processIDParam() error { - if statement.idParam == nil || statement.RefTable == nil { - return nil - } - - if len(statement.RefTable.PrimaryKeys) != len(*statement.idParam) { - return fmt.Errorf("ID condition is error, expect %d primarykeys, there are %d", - len(statement.RefTable.PrimaryKeys), - len(*statement.idParam), - ) - } - - for i, col := range statement.RefTable.PKColumns() { - var colName = statement.colName(col, statement.TableName()) - statement.cond = statement.cond.And(builder.Eq{colName: (*(statement.idParam))[i]}) - } - return nil -} - -func (statement *Statement) joinColumns(cols []*core.Column, includeTableName bool) string { - var colnames = make([]string, len(cols)) - for i, col := range cols { - if includeTableName { - colnames[i] = statement.Engine.Quote(statement.TableName()) + - "." + statement.Engine.Quote(col.Name) - } else { - colnames[i] = statement.Engine.Quote(col.Name) - } - } - return strings.Join(colnames, ", ") -} - -func (statement *Statement) convertIDSQL(sqlStr string) string { - if statement.RefTable != nil { - cols := statement.RefTable.PKColumns() - if len(cols) == 0 { - return "" - } - - colstrs := statement.joinColumns(cols, false) - sqls := splitNNoCase(sqlStr, " from ", 2) - if len(sqls) != 2 { - return "" - } - - var top string - if statement.LimitN > 0 && statement.Engine.dialect.DBType() == core.MSSQL { - top = fmt.Sprintf("TOP %d ", statement.LimitN) - } - - newsql := fmt.Sprintf("SELECT %s%s FROM %v", top, colstrs, sqls[1]) - return newsql - } - return "" -} - -func (statement *Statement) convertUpdateSQL(sqlStr string) (string, string) { - if statement.RefTable == nil || len(statement.RefTable.PrimaryKeys) != 1 { - return "", "" - } - - colstrs := statement.joinColumns(statement.RefTable.PKColumns(), true) - sqls := splitNNoCase(sqlStr, "where", 2) - if len(sqls) != 2 { - if len(sqls) == 1 { - return sqls[0], fmt.Sprintf("SELECT %v FROM %v", - colstrs, statement.Engine.Quote(statement.TableName())) - } - return "", "" - } - - var whereStr = sqls[1] - - // TODO: for postgres only, if any other database? - var paraStr string - if statement.Engine.dialect.DBType() == core.POSTGRES { - paraStr = "$" - } else if statement.Engine.dialect.DBType() == core.MSSQL { - paraStr = ":" - } - - if paraStr != "" { - if strings.Contains(sqls[1], paraStr) { - dollers := strings.Split(sqls[1], paraStr) - whereStr = dollers[0] - for i, c := range dollers[1:] { - ccs := strings.SplitN(c, " ", 2) - whereStr += fmt.Sprintf(paraStr+"%v %v", i+1, ccs[1]) - } - } - } - - return sqls[0], fmt.Sprintf("SELECT %v FROM %v WHERE %v", - colstrs, statement.Engine.Quote(statement.TableName()), - whereStr) -} diff --git a/statement_columnmap.go b/statement_columnmap.go deleted file mode 100644 index b6523b1e..00000000 --- a/statement_columnmap.go +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright 2019 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package xorm - -import "strings" - -type columnMap []string - -func (m columnMap) contain(colName string) bool { - if len(m) == 0 { - return false - } - - n := len(colName) - for _, mk := range m { - if len(mk) != n { - continue - } - if strings.EqualFold(mk, colName) { - return true - } - } - - return false -} - -func (m *columnMap) add(colName string) bool { - if m.contain(colName) { - return false - } - *m = append(*m, colName) - return true -} diff --git a/statement_quote.go b/statement_quote.go deleted file mode 100644 index e22e0d14..00000000 --- a/statement_quote.go +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright 2019 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package xorm - -func trimQuote(s string) string { - if len(s) == 0 { - return s - } - - if s[0] == '`' { - s = s[1:] - } - if len(s) > 0 && s[len(s)-1] == '`' { - return s[:len(s)-1] - } - return s -} diff --git a/tag_cache_test.go b/tag_cache_test.go deleted file mode 100644 index 30e2c51a..00000000 --- a/tag_cache_test.go +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright 2017 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package xorm - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestCacheTag(t *testing.T) { - assert.NoError(t, prepareEngine()) - - type CacheDomain struct { - Id int64 `xorm:"pk cache"` - Name string - } - - assert.NoError(t, testEngine.CreateTables(&CacheDomain{})) - assert.True(t, testEngine.GetCacher(testEngine.TableName(&CacheDomain{})) != nil) -} - -func TestNoCacheTag(t *testing.T) { - assert.NoError(t, prepareEngine()) - - type NoCacheDomain struct { - Id int64 `xorm:"pk nocache"` - Name string - } - - assert.NoError(t, testEngine.CreateTables(&NoCacheDomain{})) - assert.True(t, testEngine.GetCacher(testEngine.TableName(&NoCacheDomain{})) == nil) -} diff --git a/tag_extends_test.go b/tag_extends_test.go deleted file mode 100644 index 5a8031f0..00000000 --- a/tag_extends_test.go +++ /dev/null @@ -1,608 +0,0 @@ -// Copyright 2017 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package xorm - -import ( - "errors" - "fmt" - "testing" - "time" - - "xorm.io/core" - "github.com/stretchr/testify/assert" -) - -type tempUser struct { - Id int64 - Username string -} - -type tempUser2 struct { - TempUser tempUser `xorm:"extends"` - Departname string -} - -type tempUser3 struct { - Temp *tempUser `xorm:"extends"` - Departname string -} - -type tempUser4 struct { - TempUser2 tempUser2 `xorm:"extends"` -} - -type Userinfo struct { - Uid int64 `xorm:"id pk not null autoincr"` - Username string `xorm:"unique"` - Departname string - Alias string `xorm:"-"` - Created time.Time - Detail Userdetail `xorm:"detail_id int(11)"` - Height float64 - Avatar []byte - IsMan bool -} - -type Userdetail struct { - Id int64 - Intro string `xorm:"text"` - Profile string `xorm:"varchar(2000)"` -} - -type UserAndDetail struct { - Userinfo `xorm:"extends"` - Userdetail `xorm:"extends"` -} - -func TestExtends(t *testing.T) { - assert.NoError(t, prepareEngine()) - - err := testEngine.DropTables(&tempUser2{}) - assert.NoError(t, err) - - err = testEngine.CreateTables(&tempUser2{}) - assert.NoError(t, err) - - tu := &tempUser2{tempUser{0, "extends"}, "dev depart"} - _, err = testEngine.Insert(tu) - assert.NoError(t, err) - - tu2 := &tempUser2{} - _, err = testEngine.Get(tu2) - assert.NoError(t, err) - - tu3 := &tempUser2{tempUser{0, "extends update"}, ""} - _, err = testEngine.ID(tu2.TempUser.Id).Update(tu3) - assert.NoError(t, err) - - err = testEngine.DropTables(&tempUser4{}) - assert.NoError(t, err) - - err = testEngine.CreateTables(&tempUser4{}) - assert.NoError(t, err) - - tu8 := &tempUser4{tempUser2{tempUser{0, "extends"}, "dev depart"}} - _, err = testEngine.Insert(tu8) - assert.NoError(t, err) - - tu9 := &tempUser4{} - _, err = testEngine.Get(tu9) - assert.NoError(t, err) - - if tu9.TempUser2.TempUser.Username != tu8.TempUser2.TempUser.Username || tu9.TempUser2.Departname != tu8.TempUser2.Departname { - err = errors.New(fmt.Sprintln("not equal for", tu8, tu9)) - t.Error(err) - panic(err) - } - - tu10 := &tempUser4{tempUser2{tempUser{0, "extends update"}, ""}} - _, err = testEngine.ID(tu9.TempUser2.TempUser.Id).Update(tu10) - assert.NoError(t, err) - - err = testEngine.DropTables(&tempUser3{}) - assert.NoError(t, err) - - err = testEngine.CreateTables(&tempUser3{}) - assert.NoError(t, err) - - tu4 := &tempUser3{&tempUser{0, "extends"}, "dev depart"} - _, err = testEngine.Insert(tu4) - assert.NoError(t, err) - - tu5 := &tempUser3{} - _, err = testEngine.Get(tu5) - assert.NoError(t, err) - - if tu5.Temp == nil { - err = errors.New("error get data extends") - t.Error(err) - panic(err) - } - if tu5.Temp.Id != 1 || tu5.Temp.Username != "extends" || - tu5.Departname != "dev depart" { - err = errors.New("error get data extends") - t.Error(err) - panic(err) - } - - tu6 := &tempUser3{&tempUser{0, "extends update"}, ""} - _, err = testEngine.ID(tu5.Temp.Id).Update(tu6) - assert.NoError(t, err) - - users := make([]tempUser3, 0) - err = testEngine.Find(&users) - assert.NoError(t, err) - assert.EqualValues(t, 1, len(users), "error get data not 1") - - assertSync(t, new(Userinfo), new(Userdetail)) - - detail := Userdetail{ - Intro: "I'm in China", - } - _, err = testEngine.Insert(&detail) - assert.NoError(t, err) - - _, err = testEngine.Insert(&Userinfo{ - Username: "lunny", - Detail: detail, - }) - assert.NoError(t, err) - - var info UserAndDetail - qt := testEngine.Quote - ui := testEngine.TableName(new(Userinfo), true) - ud := testEngine.TableName(&detail, true) - uiid := testEngine.GetColumnMapper().Obj2Table("Id") - udid := "detail_id" - sql := fmt.Sprintf("select * from %s, %s where %s.%s = %s.%s", - qt(ui), qt(ud), qt(ui), qt(udid), qt(ud), qt(uiid)) - b, err := testEngine.SQL(sql).NoCascade().Get(&info) - assert.NoError(t, err) - if !b { - err = errors.New("should has lest one record") - t.Error(err) - panic(err) - } - fmt.Println(info) - if info.Userinfo.Uid == 0 || info.Userdetail.Id == 0 { - err = errors.New("all of the id should has value") - t.Error(err) - panic(err) - } - - fmt.Println("----join--info2") - var info2 UserAndDetail - b, err = testEngine.Table(&Userinfo{}). - Join("LEFT", qt(ud), qt(ui)+"."+qt("detail_id")+" = "+qt(ud)+"."+qt(uiid)). - NoCascade().Get(&info2) - if err != nil { - t.Error(err) - panic(err) - } - if !b { - err = errors.New("should has lest one record") - t.Error(err) - panic(err) - } - if info2.Userinfo.Uid == 0 || info2.Userdetail.Id == 0 { - err = errors.New("all of the id should has value") - t.Error(err) - panic(err) - } - fmt.Println(info2) - - fmt.Println("----join--infos2") - var infos2 = make([]UserAndDetail, 0) - err = testEngine.Table(&Userinfo{}). - Join("LEFT", qt(ud), qt(ui)+"."+qt("detail_id")+" = "+qt(ud)+"."+qt(uiid)). - NoCascade(). - Find(&infos2) - assert.NoError(t, err) - fmt.Println(infos2) -} - -type MessageBase struct { - Id int64 `xorm:"int(11) pk autoincr"` - TypeId int64 `xorm:"int(11) notnull"` -} - -type Message struct { - MessageBase `xorm:"extends"` - Title string `xorm:"varchar(100) notnull"` - Content string `xorm:"text notnull"` - Uid int64 `xorm:"int(11) notnull"` - ToUid int64 `xorm:"int(11) notnull"` - CreateTime time.Time `xorm:"datetime notnull created"` -} - -type MessageUser struct { - Id int64 - Name string -} - -type MessageType struct { - Id int64 - Name string -} - -type MessageExtend3 struct { - Message `xorm:"extends"` - Sender MessageUser `xorm:"extends"` - Receiver MessageUser `xorm:"extends"` - Type MessageType `xorm:"extends"` -} - -type MessageExtend4 struct { - Message `xorm:"extends"` - MessageUser `xorm:"extends"` - MessageType `xorm:"extends"` -} - -func TestExtends2(t *testing.T) { - assert.NoError(t, prepareEngine()) - - err := testEngine.DropTables(&Message{}, &MessageUser{}, &MessageType{}) - assert.NoError(t, err) - - err = testEngine.CreateTables(&Message{}, &MessageUser{}, &MessageType{}) - assert.NoError(t, err) - - var sender = MessageUser{Name: "sender"} - var receiver = MessageUser{Name: "receiver"} - var msgtype = MessageType{Name: "type"} - _, err = testEngine.Insert(&sender, &receiver, &msgtype) - assert.NoError(t, err) - - msg := Message{ - MessageBase: MessageBase{ - Id: msgtype.Id, - }, - Title: "test", - Content: "test", - Uid: sender.Id, - ToUid: receiver.Id, - } - - session := testEngine.NewSession() - defer session.Close() - - // MSSQL deny insert identity column excep declare as below - if testEngine.Dialect().DBType() == core.MSSQL { - err = session.Begin() - assert.NoError(t, err) - _, err = session.Exec("SET IDENTITY_INSERT message ON") - assert.NoError(t, err) - } - cnt, err := session.Insert(&msg) - assert.NoError(t, err) - assert.EqualValues(t, 1, cnt) - - if testEngine.Dialect().DBType() == core.MSSQL { - err = session.Commit() - assert.NoError(t, err) - } - - var mapper = testEngine.GetTableMapper().Obj2Table - var quote = testEngine.Quote - userTableName := quote(testEngine.TableName(mapper("MessageUser"), true)) - typeTableName := quote(testEngine.TableName(mapper("MessageType"), true)) - msgTableName := quote(testEngine.TableName(mapper("Message"), true)) - - list := make([]Message, 0) - err = session.Table(msgTableName).Join("LEFT", []string{userTableName, "sender"}, "`sender`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Uid")+"`"). - Join("LEFT", []string{userTableName, "receiver"}, "`receiver`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("ToUid")+"`"). - Join("LEFT", []string{typeTableName, "type"}, "`type`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Id")+"`"). - Find(&list) - assert.NoError(t, err) - - assert.EqualValues(t, 1, len(list), fmt.Sprintln("should have 1 message, got", len(list))) - assert.EqualValues(t, msg.Id, list[0].Id, fmt.Sprintln("should message equal", list[0], msg)) -} - -func TestExtends3(t *testing.T) { - assert.NoError(t, prepareEngine()) - - err := testEngine.DropTables(&Message{}, &MessageUser{}, &MessageType{}) - if err != nil { - t.Error(err) - panic(err) - } - - err = testEngine.CreateTables(&Message{}, &MessageUser{}, &MessageType{}) - if err != nil { - t.Error(err) - panic(err) - } - - var sender = MessageUser{Name: "sender"} - var receiver = MessageUser{Name: "receiver"} - var msgtype = MessageType{Name: "type"} - _, err = testEngine.Insert(&sender, &receiver, &msgtype) - if err != nil { - t.Error(err) - panic(err) - } - - msg := Message{ - MessageBase: MessageBase{ - Id: msgtype.Id, - }, - Title: "test", - Content: "test", - Uid: sender.Id, - ToUid: receiver.Id, - } - - session := testEngine.NewSession() - defer session.Close() - - // MSSQL deny insert identity column excep declare as below - if testEngine.Dialect().DBType() == core.MSSQL { - err = session.Begin() - assert.NoError(t, err) - _, err = session.Exec("SET IDENTITY_INSERT message ON") - assert.NoError(t, err) - } - _, err = session.Insert(&msg) - assert.NoError(t, err) - - if testEngine.Dialect().DBType() == core.MSSQL { - err = session.Commit() - assert.NoError(t, err) - } - - var mapper = testEngine.GetTableMapper().Obj2Table - var quote = testEngine.Quote - userTableName := quote(testEngine.TableName(mapper("MessageUser"), true)) - typeTableName := quote(testEngine.TableName(mapper("MessageType"), true)) - msgTableName := quote(testEngine.TableName(mapper("Message"), true)) - - list := make([]MessageExtend3, 0) - err = session.Table(msgTableName).Join("LEFT", []string{userTableName, "sender"}, "`sender`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Uid")+"`"). - Join("LEFT", []string{userTableName, "receiver"}, "`receiver`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("ToUid")+"`"). - Join("LEFT", []string{typeTableName, "type"}, "`type`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Id")+"`"). - Find(&list) - assert.NoError(t, err) - - if len(list) != 1 { - err = errors.New(fmt.Sprintln("should have 1 message, got", len(list))) - t.Error(err) - panic(err) - } - - if list[0].Message.Id != msg.Id { - err = errors.New(fmt.Sprintln("should message equal", list[0].Message, msg)) - t.Error(err) - panic(err) - } - - if list[0].Sender.Id != sender.Id || list[0].Sender.Name != sender.Name { - err = errors.New(fmt.Sprintln("should sender equal", list[0].Sender, sender)) - t.Error(err) - panic(err) - } - - if list[0].Receiver.Id != receiver.Id || list[0].Receiver.Name != receiver.Name { - err = errors.New(fmt.Sprintln("should receiver equal", list[0].Receiver, receiver)) - t.Error(err) - panic(err) - } - - if list[0].Type.Id != msgtype.Id || list[0].Type.Name != msgtype.Name { - err = errors.New(fmt.Sprintln("should msgtype equal", list[0].Type, msgtype)) - t.Error(err) - panic(err) - } -} - -func TestExtends4(t *testing.T) { - assert.NoError(t, prepareEngine()) - - err := testEngine.DropTables(&Message{}, &MessageUser{}, &MessageType{}) - if err != nil { - t.Error(err) - panic(err) - } - - err = testEngine.CreateTables(&Message{}, &MessageUser{}, &MessageType{}) - if err != nil { - t.Error(err) - panic(err) - } - - var sender = MessageUser{Name: "sender"} - var msgtype = MessageType{Name: "type"} - _, err = testEngine.Insert(&sender, &msgtype) - if err != nil { - t.Error(err) - panic(err) - } - - msg := Message{ - MessageBase: MessageBase{ - Id: msgtype.Id, - }, - Title: "test", - Content: "test", - Uid: sender.Id, - } - - session := testEngine.NewSession() - defer session.Close() - - // MSSQL deny insert identity column excep declare as below - if testEngine.Dialect().DBType() == core.MSSQL { - err = session.Begin() - assert.NoError(t, err) - _, err = session.Exec("SET IDENTITY_INSERT message ON") - assert.NoError(t, err) - } - _, err = session.Insert(&msg) - assert.NoError(t, err) - - if testEngine.Dialect().DBType() == core.MSSQL { - err = session.Commit() - assert.NoError(t, err) - } - - var mapper = testEngine.GetTableMapper().Obj2Table - var quote = testEngine.Quote - userTableName := quote(testEngine.TableName(mapper("MessageUser"), true)) - typeTableName := quote(testEngine.TableName(mapper("MessageType"), true)) - msgTableName := quote(testEngine.TableName(mapper("Message"), true)) - - list := make([]MessageExtend4, 0) - err = session.Table(msgTableName).Join("LEFT", userTableName, userTableName+".`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Uid")+"`"). - Join("LEFT", typeTableName, typeTableName+".`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Id")+"`"). - Find(&list) - if err != nil { - t.Error(err) - panic(err) - } - - if len(list) != 1 { - err = errors.New(fmt.Sprintln("should have 1 message, got", len(list))) - t.Error(err) - panic(err) - } - - if list[0].Message.Id != msg.Id { - err = errors.New(fmt.Sprintln("should message equal", list[0].Message, msg)) - t.Error(err) - panic(err) - } - - if list[0].MessageUser.Id != sender.Id || list[0].MessageUser.Name != sender.Name { - err = errors.New(fmt.Sprintln("should sender equal", list[0].MessageUser, sender)) - t.Error(err) - panic(err) - } - - if list[0].MessageType.Id != msgtype.Id || list[0].MessageType.Name != msgtype.Name { - err = errors.New(fmt.Sprintln("should msgtype equal", list[0].MessageType, msgtype)) - t.Error(err) - panic(err) - } -} - -type Size struct { - ID int64 `xorm:"int(4) 'id' pk autoincr"` - Width float32 `json:"width" xorm:"float 'Width'"` - Height float32 `json:"height" xorm:"float 'Height'"` -} - -type Book struct { - ID int64 `xorm:"int(4) 'id' pk autoincr"` - SizeOpen *Size `xorm:"extends('Open')"` - SizeClosed *Size `xorm:"extends('Closed')"` - Size *Size `xorm:"extends('')"` -} - -func TestExtends5(t *testing.T) { - assert.NoError(t, prepareEngine()) - err := testEngine.DropTables(&Book{}, &Size{}) - if err != nil { - t.Error(err) - panic(err) - } - - err = testEngine.CreateTables(&Size{}, &Book{}) - if err != nil { - t.Error(err) - panic(err) - } - - var sc = Size{Width: 0.2, Height: 0.4} - var so = Size{Width: 0.2, Height: 0.8} - var s = Size{Width: 0.15, Height: 1.5} - var bk1 = Book{ - SizeOpen: &so, - SizeClosed: &sc, - Size: &s, - } - var bk2 = Book{ - SizeOpen: &so, - } - var bk3 = Book{ - SizeClosed: &sc, - Size: &s, - } - var bk4 = Book{} - var bk5 = Book{Size: &s} - _, err = testEngine.Insert(&sc, &so, &s, &bk1, &bk2, &bk3, &bk4, &bk5) - if err != nil { - t.Fatal(err) - } - - var books = map[int64]Book{ - bk1.ID: bk1, - bk2.ID: bk2, - bk3.ID: bk3, - bk4.ID: bk4, - bk5.ID: bk5, - } - - session := testEngine.NewSession() - defer session.Close() - - var mapper = testEngine.GetTableMapper().Obj2Table - var quote = testEngine.Quote - bookTableName := quote(testEngine.TableName(mapper("Book"), true)) - sizeTableName := quote(testEngine.TableName(mapper("Size"), true)) - - list := make([]Book, 0) - err = session. - Select(fmt.Sprintf( - "%s.%s, sc.%s AS %s, sc.%s AS %s, s.%s, s.%s", - quote(bookTableName), - quote("id"), - quote("Width"), - quote("ClosedWidth"), - quote("Height"), - quote("ClosedHeight"), - quote("Width"), - quote("Height"), - )). - Table(bookTableName). - Join( - "LEFT", - sizeTableName+" AS `sc`", - bookTableName+".`SizeClosed`=sc.`id`", - ). - Join( - "LEFT", - sizeTableName+" AS `s`", - bookTableName+".`Size`=s.`id`", - ). - Find(&list) - if err != nil { - t.Error(err) - panic(err) - } - - for _, book := range list { - if ok := assert.Equal(t, books[book.ID].SizeClosed.Width, book.SizeClosed.Width); !ok { - t.Error("Not bounded size closed") - panic("Not bounded size closed") - } - - if ok := assert.Equal(t, books[book.ID].SizeClosed.Height, book.SizeClosed.Height); !ok { - t.Error("Not bounded size closed") - panic("Not bounded size closed") - } - - if books[book.ID].Size != nil || book.Size != nil { - if ok := assert.Equal(t, books[book.ID].Size.Width, book.Size.Width); !ok { - t.Error("Not bounded size") - panic("Not bounded size") - } - - if ok := assert.Equal(t, books[book.ID].Size.Height, book.Size.Height); !ok { - t.Error("Not bounded size") - panic("Not bounded size") - } - } - } -} diff --git a/tag_id_test.go b/tag_id_test.go deleted file mode 100644 index f1c5a6bc..00000000 --- a/tag_id_test.go +++ /dev/null @@ -1,85 +0,0 @@ -// Copyright 2017 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package xorm - -import ( - "testing" - - "xorm.io/core" - "github.com/stretchr/testify/assert" -) - -type IDGonicMapper struct { - ID int64 -} - -func TestGonicMapperID(t *testing.T) { - assert.NoError(t, prepareEngine()) - - oldMapper := testEngine.GetColumnMapper() - testEngine.UnMapType(rValue(new(IDGonicMapper)).Type()) - testEngine.SetMapper(core.LintGonicMapper) - defer func() { - testEngine.UnMapType(rValue(new(IDGonicMapper)).Type()) - testEngine.SetMapper(oldMapper) - }() - - err := testEngine.CreateTables(new(IDGonicMapper)) - if err != nil { - t.Fatal(err) - } - - tables, err := testEngine.DBMetas() - if err != nil { - t.Fatal(err) - } - - for _, tb := range tables { - if tb.Name == "id_gonic_mapper" { - if len(tb.PKColumns()) != 1 || tb.PKColumns()[0].Name != "id" { - t.Fatal(tb) - } - return - } - } - - t.Fatal("not table id_gonic_mapper") -} - -type IDSameMapper struct { - ID int64 -} - -func TestSameMapperID(t *testing.T) { - assert.NoError(t, prepareEngine()) - - oldMapper := testEngine.GetColumnMapper() - testEngine.UnMapType(rValue(new(IDSameMapper)).Type()) - testEngine.SetMapper(core.SameMapper{}) - defer func() { - testEngine.UnMapType(rValue(new(IDSameMapper)).Type()) - testEngine.SetMapper(oldMapper) - }() - - err := testEngine.CreateTables(new(IDSameMapper)) - if err != nil { - t.Fatal(err) - } - - tables, err := testEngine.DBMetas() - if err != nil { - t.Fatal(err) - } - - for _, tb := range tables { - if tb.Name == "IDSameMapper" { - if len(tb.PKColumns()) != 1 || tb.PKColumns()[0].Name != "ID" { - t.Fatal(tb) - } - return - } - } - t.Fatal("not table IDSameMapper") -} diff --git a/tag_test.go b/tag_test.go deleted file mode 100644 index 979ba929..00000000 --- a/tag_test.go +++ /dev/null @@ -1,600 +0,0 @@ -// Copyright 2017 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package xorm - -import ( - "fmt" - "strings" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "xorm.io/core" -) - -type UserCU struct { - Id int64 - Name string - Created time.Time `xorm:"created"` - Updated time.Time `xorm:"updated"` -} - -func TestCreatedAndUpdated(t *testing.T) { - assert.NoError(t, prepareEngine()) - - u := new(UserCU) - err := testEngine.DropTables(u) - assert.NoError(t, err) - - err = testEngine.CreateTables(u) - assert.NoError(t, err) - - u.Name = "sss" - cnt, err := testEngine.Insert(u) - assert.NoError(t, err) - assert.EqualValues(t, 1, cnt) - - u.Name = "xxx" - cnt, err = testEngine.ID(u.Id).Update(u) - assert.NoError(t, err) - assert.EqualValues(t, 1, cnt) - - u.Id = 0 - u.Created = time.Now().Add(-time.Hour * 24 * 365) - u.Updated = u.Created - cnt, err = testEngine.NoAutoTime().Insert(u) - assert.NoError(t, err) - assert.EqualValues(t, 1, cnt) -} - -type StrangeName struct { - Id_t int64 `xorm:"pk autoincr"` - Name string -} - -func TestStrangeName(t *testing.T) { - assert.NoError(t, prepareEngine()) - - err := testEngine.DropTables(new(StrangeName)) - assert.NoError(t, err) - - err = testEngine.CreateTables(new(StrangeName)) - assert.NoError(t, err) - - _, err = testEngine.Insert(&StrangeName{Name: "sfsfdsfds"}) - assert.NoError(t, err) - - beans := make([]StrangeName, 0) - err = testEngine.Find(&beans) - assert.NoError(t, err) -} - -func TestCreatedUpdated(t *testing.T) { - assert.NoError(t, prepareEngine()) - - type CreatedUpdated struct { - Id int64 - Name string - Value float64 `xorm:"numeric"` - Created time.Time `xorm:"created"` - Created2 time.Time `xorm:"created"` - Updated time.Time `xorm:"updated"` - } - - err := testEngine.Sync2(&CreatedUpdated{}) - assert.NoError(t, err) - - c := &CreatedUpdated{Name: "test"} - _, err = testEngine.Insert(c) - assert.NoError(t, err) - - c2 := new(CreatedUpdated) - has, err := testEngine.ID(c.Id).Get(c2) - assert.NoError(t, err) - - assert.True(t, has) - - c2.Value -= 1 - _, err = testEngine.ID(c2.Id).Update(c2) - assert.NoError(t, err) -} - -func TestCreatedUpdatedInt64(t *testing.T) { - assert.NoError(t, prepareEngine()) - - type CreatedUpdatedInt64 struct { - Id int64 - Name string - Value float64 `xorm:"numeric"` - Created int64 `xorm:"created"` - Created2 int64 `xorm:"created"` - Updated int64 `xorm:"updated"` - } - - assertSync(t, &CreatedUpdatedInt64{}) - - c := &CreatedUpdatedInt64{Name: "test"} - _, err := testEngine.Insert(c) - assert.NoError(t, err) - - c2 := new(CreatedUpdatedInt64) - has, err := testEngine.ID(c.Id).Get(c2) - assert.NoError(t, err) - assert.True(t, has) - - c2.Value -= 1 - _, err = testEngine.ID(c2.Id).Update(c2) - assert.NoError(t, err) -} - -type Lowercase struct { - Id int64 - Name string - ended int64 `xorm:"-"` -} - -func TestLowerCase(t *testing.T) { - assert.NoError(t, prepareEngine()) - - err := testEngine.Sync2(&Lowercase{}) - assert.NoError(t, err) - _, err = testEngine.Where("id > 0").Delete(&Lowercase{}) - assert.NoError(t, err) - - _, err = testEngine.Insert(&Lowercase{ended: 1}) - assert.NoError(t, err) - - ls := make([]Lowercase, 0) - err = testEngine.Find(&ls) - assert.NoError(t, err) - assert.EqualValues(t, 1, len(ls)) -} - -func TestAutoIncrTag(t *testing.T) { - assert.NoError(t, prepareEngine()) - - type TestAutoIncr1 struct { - Id int64 - } - - tb := testEngine.TableInfo(new(TestAutoIncr1)) - cols := tb.Columns() - assert.EqualValues(t, 1, len(cols)) - assert.True(t, cols[0].IsAutoIncrement) - assert.True(t, cols[0].IsPrimaryKey) - assert.Equal(t, "id", cols[0].Name) - - type TestAutoIncr2 struct { - Id int64 `xorm:"id"` - } - - tb = testEngine.TableInfo(new(TestAutoIncr2)) - cols = tb.Columns() - assert.EqualValues(t, 1, len(cols)) - assert.False(t, cols[0].IsAutoIncrement) - assert.False(t, cols[0].IsPrimaryKey) - assert.Equal(t, "id", cols[0].Name) - - type TestAutoIncr3 struct { - Id int64 `xorm:"'ID'"` - } - - tb = testEngine.TableInfo(new(TestAutoIncr3)) - cols = tb.Columns() - assert.EqualValues(t, 1, len(cols)) - assert.False(t, cols[0].IsAutoIncrement) - assert.False(t, cols[0].IsPrimaryKey) - assert.Equal(t, "ID", cols[0].Name) - - type TestAutoIncr4 struct { - Id int64 `xorm:"pk"` - } - - tb = testEngine.TableInfo(new(TestAutoIncr4)) - cols = tb.Columns() - assert.EqualValues(t, 1, len(cols)) - assert.False(t, cols[0].IsAutoIncrement) - assert.True(t, cols[0].IsPrimaryKey) - assert.Equal(t, "id", cols[0].Name) -} - -func TestTagComment(t *testing.T) { - assert.NoError(t, prepareEngine()) - // FIXME: only support mysql - if testEngine.Dialect().DriverName() != core.MYSQL { - return - } - - type TestComment1 struct { - Id int64 `xorm:"comment(主键)"` - } - - assert.NoError(t, testEngine.Sync2(new(TestComment1))) - - tables, err := testEngine.DBMetas() - assert.NoError(t, err) - assert.EqualValues(t, 1, len(tables)) - assert.EqualValues(t, 1, len(tables[0].Columns())) - assert.EqualValues(t, "主键", tables[0].Columns()[0].Comment) - - assert.NoError(t, testEngine.DropTables(new(TestComment1))) - - type TestComment2 struct { - Id int64 `xorm:"comment('主键')"` - } - - assert.NoError(t, testEngine.Sync2(new(TestComment2))) - - tables, err = testEngine.DBMetas() - assert.NoError(t, err) - assert.EqualValues(t, 1, len(tables)) - assert.EqualValues(t, 1, len(tables[0].Columns())) - assert.EqualValues(t, "主键", tables[0].Columns()[0].Comment) -} - -func TestTagDefault(t *testing.T) { - assert.NoError(t, prepareEngine()) - - type DefaultStruct struct { - Id int64 - Name string - Age int `xorm:"default(10)"` - } - - assertSync(t, new(DefaultStruct)) - - tables, err := testEngine.DBMetas() - assert.NoError(t, err) - - var defaultVal string - var isDefaultExist bool - tableName := testEngine.GetColumnMapper().Obj2Table("DefaultStruct") - for _, table := range tables { - if table.Name == tableName { - col := table.GetColumn("age") - assert.NotNil(t, col) - defaultVal = col.Default - isDefaultExist = !col.DefaultIsEmpty - break - } - } - assert.True(t, isDefaultExist) - assert.EqualValues(t, "10", defaultVal) - - cnt, err := testEngine.Omit("age").Insert(&DefaultStruct{ - Name: "test", - Age: 20, - }) - assert.NoError(t, err) - assert.EqualValues(t, 1, cnt) - - var s DefaultStruct - has, err := testEngine.ID(1).Get(&s) - assert.NoError(t, err) - assert.True(t, has) - assert.EqualValues(t, 10, s.Age) - assert.EqualValues(t, "test", s.Name) -} - -func TestTagDefault2(t *testing.T) { - assert.NoError(t, prepareEngine()) - - type DefaultStruct2 struct { - Id int64 - Name string - } - - assertSync(t, new(DefaultStruct2)) - - tables, err := testEngine.DBMetas() - assert.NoError(t, err) - - var defaultVal string - var isDefaultExist bool - tableName := testEngine.GetColumnMapper().Obj2Table("DefaultStruct2") - for _, table := range tables { - if table.Name == tableName { - col := table.GetColumn("name") - assert.NotNil(t, col) - defaultVal = col.Default - isDefaultExist = !col.DefaultIsEmpty - break - } - } - assert.False(t, isDefaultExist, fmt.Sprintf("default value is --%v--", defaultVal)) - assert.EqualValues(t, "", defaultVal) -} - -func TestTagDefault3(t *testing.T) { - assert.NoError(t, prepareEngine()) - - type DefaultStruct3 struct { - Id int64 - Name string `xorm:"default('myname')"` - } - - assertSync(t, new(DefaultStruct3)) - - tables, err := testEngine.DBMetas() - assert.NoError(t, err) - - var defaultVal string - var isDefaultExist bool - tableName := testEngine.GetColumnMapper().Obj2Table("DefaultStruct3") - for _, table := range tables { - if table.Name == tableName { - col := table.GetColumn("name") - assert.NotNil(t, col) - defaultVal = col.Default - isDefaultExist = !col.DefaultIsEmpty - break - } - } - assert.True(t, isDefaultExist) - assert.EqualValues(t, "'myname'", defaultVal) -} - -func TestTagDefault4(t *testing.T) { - assert.NoError(t, prepareEngine()) - - type DefaultStruct4 struct { - Id int64 - Created time.Time `xorm:"default(CURRENT_TIMESTAMP)"` - } - - assertSync(t, new(DefaultStruct4)) - - tables, err := testEngine.DBMetas() - assert.NoError(t, err) - - var defaultVal string - var isDefaultExist bool - tableName := testEngine.GetColumnMapper().Obj2Table("DefaultStruct4") - for _, table := range tables { - if table.Name == tableName { - col := table.GetColumn("created") - assert.NotNil(t, col) - defaultVal = col.Default - isDefaultExist = !col.DefaultIsEmpty - break - } - } - assert.True(t, isDefaultExist) - assert.True(t, "CURRENT_TIMESTAMP" == defaultVal || - "now()" == defaultVal || - "getdate" == defaultVal, defaultVal) -} - -func TestTagDefault5(t *testing.T) { - assert.NoError(t, prepareEngine()) - - type DefaultStruct5 struct { - Id int64 - Created time.Time `xorm:"default('2006-01-02 15:04:05')"` - } - - assertSync(t, new(DefaultStruct5)) - table := testEngine.TableInfo(new(DefaultStruct5)) - createdCol := table.GetColumn("created") - assert.NotNil(t, createdCol) - assert.EqualValues(t, "'2006-01-02 15:04:05'", createdCol.Default) - assert.False(t, createdCol.DefaultIsEmpty) - - tables, err := testEngine.DBMetas() - assert.NoError(t, err) - - var defaultVal string - var isDefaultExist bool - tableName := testEngine.GetColumnMapper().Obj2Table("DefaultStruct5") - for _, table := range tables { - if table.Name == tableName { - col := table.GetColumn("created") - assert.NotNil(t, col) - defaultVal = col.Default - isDefaultExist = !col.DefaultIsEmpty - break - } - } - assert.True(t, isDefaultExist) - assert.EqualValues(t, "'2006-01-02 15:04:05'", defaultVal) -} - -func TestTagDefault6(t *testing.T) { - assert.NoError(t, prepareEngine()) - - type DefaultStruct6 struct { - Id int64 - IsMan bool `xorm:"default(true)"` - } - - assertSync(t, new(DefaultStruct6)) - - tables, err := testEngine.DBMetas() - assert.NoError(t, err) - - var defaultVal string - var isDefaultExist bool - tableName := testEngine.GetColumnMapper().Obj2Table("DefaultStruct6") - for _, table := range tables { - if table.Name == tableName { - col := table.GetColumn("is_man") - assert.NotNil(t, col) - defaultVal = col.Default - isDefaultExist = !col.DefaultIsEmpty - break - } - } - assert.True(t, isDefaultExist) - if defaultVal == "1" { - defaultVal = "true" - } else if defaultVal == "0" { - defaultVal = "false" - } - assert.EqualValues(t, "true", defaultVal) -} - -func TestTagsDirection(t *testing.T) { - assert.NoError(t, prepareEngine()) - - type OnlyFromDBStruct struct { - Id int64 - Name string - Uuid string `xorm:"<- default '1'"` - } - - assertSync(t, new(OnlyFromDBStruct)) - - cnt, err := testEngine.Insert(&OnlyFromDBStruct{ - Name: "test", - Uuid: "2", - }) - assert.NoError(t, err) - assert.EqualValues(t, 1, cnt) - - var s OnlyFromDBStruct - has, err := testEngine.ID(1).Get(&s) - assert.NoError(t, err) - assert.True(t, has) - assert.EqualValues(t, "1", s.Uuid) - assert.EqualValues(t, "test", s.Name) - - cnt, err = testEngine.ID(1).Update(&OnlyFromDBStruct{ - Uuid: "3", - Name: "test1", - }) - assert.NoError(t, err) - assert.EqualValues(t, 1, cnt) - - var s3 OnlyFromDBStruct - has, err = testEngine.ID(1).Get(&s3) - assert.NoError(t, err) - assert.True(t, has) - assert.EqualValues(t, "1", s3.Uuid) - assert.EqualValues(t, "test1", s3.Name) - - type OnlyToDBStruct struct { - Id int64 - Name string - Uuid string `xorm:"->"` - } - - assertSync(t, new(OnlyToDBStruct)) - - cnt, err = testEngine.Insert(&OnlyToDBStruct{ - Name: "test", - Uuid: "2", - }) - assert.NoError(t, err) - assert.EqualValues(t, 1, cnt) - - var s2 OnlyToDBStruct - has, err = testEngine.ID(1).Get(&s2) - assert.NoError(t, err) - assert.True(t, has) - assert.EqualValues(t, "", s2.Uuid) - assert.EqualValues(t, "test", s2.Name) -} - -func TestTagTime(t *testing.T) { - assert.NoError(t, prepareEngine()) - - type TagUTCStruct struct { - Id int64 - Name string - Created time.Time `xorm:"created utc"` - } - - assertSync(t, new(TagUTCStruct)) - - assert.EqualValues(t, time.Local.String(), testEngine.GetTZLocation().String()) - - s := TagUTCStruct{ - Name: "utc", - } - cnt, err := testEngine.Insert(&s) - assert.NoError(t, err) - assert.EqualValues(t, 1, cnt) - - var u TagUTCStruct - has, err := testEngine.ID(1).Get(&u) - assert.NoError(t, err) - assert.True(t, has) - assert.EqualValues(t, s.Created.Format("2006-01-02 15:04:05"), u.Created.Format("2006-01-02 15:04:05")) - - var tm string - has, err = testEngine.Table("tag_u_t_c_struct").Cols("created").Get(&tm) - assert.NoError(t, err) - assert.True(t, has) - assert.EqualValues(t, s.Created.UTC().Format("2006-01-02 15:04:05"), - strings.Replace(strings.Replace(tm, "T", " ", -1), "Z", "", -1)) -} - -func TestSplitTag(t *testing.T) { - var cases = []struct { - tag string - tags []string - }{ - {"not null default '2000-01-01 00:00:00' TIMESTAMP", []string{"not", "null", "default", "'2000-01-01 00:00:00'", "TIMESTAMP"}}, - {"TEXT", []string{"TEXT"}}, - {"default('2000-01-01 00:00:00')", []string{"default('2000-01-01 00:00:00')"}}, - {"json binary", []string{"json", "binary"}}, - } - - for _, kase := range cases { - tags := splitTag(kase.tag) - if !sliceEq(tags, kase.tags) { - t.Fatalf("[%d]%v is not equal [%d]%v", len(tags), tags, len(kase.tags), kase.tags) - } - } -} - -func TestTagAutoIncr(t *testing.T) { - assert.NoError(t, prepareEngine()) - - type TagAutoIncr struct { - Id int64 - Name string - } - - assertSync(t, new(TagAutoIncr)) - - tables, err := testEngine.DBMetas() - assert.NoError(t, err) - assert.EqualValues(t, 1, len(tables)) - assert.EqualValues(t, tableMapper.Obj2Table("TagAutoIncr"), tables[0].Name) - col := tables[0].GetColumn(colMapper.Obj2Table("Id")) - assert.NotNil(t, col) - assert.True(t, col.IsPrimaryKey) - assert.True(t, col.IsAutoIncrement) - - col2 := tables[0].GetColumn(colMapper.Obj2Table("Name")) - assert.NotNil(t, col2) - assert.False(t, col2.IsPrimaryKey) - assert.False(t, col2.IsAutoIncrement) -} - -func TestTagPrimarykey(t *testing.T) { - assert.NoError(t, prepareEngine()) - type TagPrimaryKey struct { - Id int64 `xorm:"pk"` - Name string `xorm:"VARCHAR(20) pk"` - } - - assertSync(t, new(TagPrimaryKey)) - - tables, err := testEngine.DBMetas() - assert.NoError(t, err) - assert.EqualValues(t, 1, len(tables)) - assert.EqualValues(t, tableMapper.Obj2Table("TagPrimaryKey"), tables[0].Name) - col := tables[0].GetColumn(colMapper.Obj2Table("Id")) - assert.NotNil(t, col) - assert.True(t, col.IsPrimaryKey) - assert.False(t, col.IsAutoIncrement) - - col2 := tables[0].GetColumn(colMapper.Obj2Table("Name")) - assert.NotNil(t, col2) - assert.True(t, col2.IsPrimaryKey) - assert.False(t, col2.IsAutoIncrement) -} diff --git a/tag_version_test.go b/tag_version_test.go deleted file mode 100644 index cd6dc935..00000000 --- a/tag_version_test.go +++ /dev/null @@ -1,242 +0,0 @@ -// Copyright 2017 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package xorm - -import ( - "errors" - "fmt" - "testing" - "time" - - "github.com/stretchr/testify/assert" -) - -type VersionS struct { - Id int64 - Name string - Ver int `xorm:"version"` - Created time.Time `xorm:"created"` -} - -func TestVersion1(t *testing.T) { - assert.NoError(t, prepareEngine()) - - err := testEngine.DropTables(new(VersionS)) - if err != nil { - t.Error(err) - panic(err) - } - - err = testEngine.CreateTables(new(VersionS)) - if err != nil { - t.Error(err) - panic(err) - } - - ver := &VersionS{Name: "sfsfdsfds"} - _, err = testEngine.Insert(ver) - if err != nil { - t.Error(err) - panic(err) - } - fmt.Println(ver) - if ver.Ver != 1 { - err = errors.New("insert error") - t.Error(err) - panic(err) - } - - newVer := new(VersionS) - has, err := testEngine.ID(ver.Id).Get(newVer) - if err != nil { - t.Error(err) - panic(err) - } - - if !has { - t.Error(errors.New(fmt.Sprintf("no version id is %v", ver.Id))) - panic(err) - } - fmt.Println(newVer) - if newVer.Ver != 1 { - err = errors.New("insert error") - t.Error(err) - panic(err) - } - - newVer.Name = "-------" - _, err = testEngine.ID(ver.Id).Update(newVer) - if err != nil { - t.Error(err) - panic(err) - } - if newVer.Ver != 2 { - err = errors.New("update should set version back to struct") - t.Error(err) - } - - newVer = new(VersionS) - has, err = testEngine.ID(ver.Id).Get(newVer) - if err != nil { - t.Error(err) - panic(err) - } - fmt.Println(newVer) - if newVer.Ver != 2 { - err = errors.New("update error") - t.Error(err) - panic(err) - } -} - -func TestVersion2(t *testing.T) { - assert.NoError(t, prepareEngine()) - - err := testEngine.DropTables(new(VersionS)) - if err != nil { - t.Error(err) - panic(err) - } - - err = testEngine.CreateTables(new(VersionS)) - if err != nil { - t.Error(err) - panic(err) - } - - var vers = []VersionS{ - {Name: "sfsfdsfds"}, - {Name: "xxxxx"}, - } - _, err = testEngine.Insert(vers) - if err != nil { - t.Error(err) - panic(err) - } - - fmt.Println(vers) - - for _, v := range vers { - if v.Ver != 1 { - err := errors.New("version should be 1") - t.Error(err) - panic(err) - } - } -} - -type VersionUintS struct { - Id int64 - Name string - Ver uint `xorm:"version"` - Created time.Time `xorm:"created"` -} - -func TestVersion3(t *testing.T) { - assert.NoError(t, prepareEngine()) - - err := testEngine.DropTables(new(VersionUintS)) - if err != nil { - t.Error(err) - panic(err) - } - - err = testEngine.CreateTables(new(VersionUintS)) - if err != nil { - t.Error(err) - panic(err) - } - - ver := &VersionUintS{Name: "sfsfdsfds"} - _, err = testEngine.Insert(ver) - if err != nil { - t.Error(err) - panic(err) - } - fmt.Println(ver) - if ver.Ver != 1 { - err = errors.New("insert error") - t.Error(err) - panic(err) - } - - newVer := new(VersionUintS) - has, err := testEngine.ID(ver.Id).Get(newVer) - if err != nil { - t.Error(err) - panic(err) - } - - if !has { - t.Error(errors.New(fmt.Sprintf("no version id is %v", ver.Id))) - panic(err) - } - fmt.Println(newVer) - if newVer.Ver != 1 { - err = errors.New("insert error") - t.Error(err) - panic(err) - } - - newVer.Name = "-------" - _, err = testEngine.ID(ver.Id).Update(newVer) - if err != nil { - t.Error(err) - panic(err) - } - if newVer.Ver != 2 { - err = errors.New("update should set version back to struct") - t.Error(err) - } - - newVer = new(VersionUintS) - has, err = testEngine.ID(ver.Id).Get(newVer) - if err != nil { - t.Error(err) - panic(err) - } - fmt.Println(newVer) - if newVer.Ver != 2 { - err = errors.New("update error") - t.Error(err) - panic(err) - } -} - -func TestVersion4(t *testing.T) { - assert.NoError(t, prepareEngine()) - - err := testEngine.DropTables(new(VersionUintS)) - if err != nil { - t.Error(err) - panic(err) - } - - err = testEngine.CreateTables(new(VersionUintS)) - if err != nil { - t.Error(err) - panic(err) - } - - var vers = []VersionUintS{ - {Name: "sfsfdsfds"}, - {Name: "xxxxx"}, - } - _, err = testEngine.Insert(vers) - if err != nil { - t.Error(err) - panic(err) - } - - fmt.Println(vers) - - for _, v := range vers { - if v.Ver != 1 { - err := errors.New("version should be 1") - t.Error(err) - panic(err) - } - } -} diff --git a/tags/parser.go b/tags/parser.go new file mode 100644 index 00000000..add30a13 --- /dev/null +++ b/tags/parser.go @@ -0,0 +1,308 @@ +// Copyright 2020 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tags + +import ( + "encoding/gob" + "errors" + "fmt" + "reflect" + "strings" + "sync" + "time" + + "xorm.io/xorm/caches" + "xorm.io/xorm/convert" + "xorm.io/xorm/dialects" + "xorm.io/xorm/names" + "xorm.io/xorm/schemas" +) + +var ( + ErrUnsupportedType = errors.New("Unsupported type") +) + +type Parser struct { + identifier string + dialect dialects.Dialect + columnMapper names.Mapper + tableMapper names.Mapper + handlers map[string]Handler + cacherMgr *caches.Manager + tableCache sync.Map // map[reflect.Type]*schemas.Table +} + +func NewParser(identifier string, dialect dialects.Dialect, tableMapper, columnMapper names.Mapper, cacherMgr *caches.Manager) *Parser { + return &Parser{ + identifier: identifier, + dialect: dialect, + tableMapper: tableMapper, + columnMapper: columnMapper, + handlers: defaultTagHandlers, + cacherMgr: cacherMgr, + } +} + +func (parser *Parser) GetTableMapper() names.Mapper { + return parser.tableMapper +} + +func (parser *Parser) SetTableMapper(mapper names.Mapper) { + parser.ClearCaches() + parser.tableMapper = mapper +} + +func (parser *Parser) GetColumnMapper() names.Mapper { + return parser.columnMapper +} + +func (parser *Parser) SetColumnMapper(mapper names.Mapper) { + parser.ClearCaches() + parser.columnMapper = mapper +} + +func (parser *Parser) ParseWithCache(v reflect.Value) (*schemas.Table, error) { + t := v.Type() + tableI, ok := parser.tableCache.Load(t) + if ok { + return tableI.(*schemas.Table), nil + } + + table, err := parser.Parse(v) + if err != nil { + return nil, err + } + + parser.tableCache.Store(t, table) + + if parser.cacherMgr.GetDefaultCacher() != nil { + if v.CanAddr() { + gob.Register(v.Addr().Interface()) + } else { + gob.Register(v.Interface()) + } + } + + return table, nil +} + +// ClearCacheTable removes the database mapper of a type from the cache +func (parser *Parser) ClearCacheTable(t reflect.Type) { + parser.tableCache.Delete(t) +} + +// ClearCaches removes all the cached table information parsed by structs +func (parser *Parser) ClearCaches() { + parser.tableCache = sync.Map{} +} + +func addIndex(indexName string, table *schemas.Table, col *schemas.Column, indexType int) { + if index, ok := table.Indexes[indexName]; ok { + index.AddColumn(col.Name) + col.Indexes[index.Name] = indexType + } else { + index := schemas.NewIndex(indexName, indexType) + index.AddColumn(col.Name) + table.AddIndex(index) + col.Indexes[index.Name] = indexType + } +} + +// Parse parses a struct as a table information +func (parser *Parser) Parse(v reflect.Value) (*schemas.Table, error) { + t := v.Type() + if t.Kind() == reflect.Ptr { + t = t.Elem() + v = v.Elem() + } + if t.Kind() != reflect.Struct { + return nil, ErrUnsupportedType + } + + table := schemas.NewEmptyTable() + table.Type = t + table.Name = names.GetTableName(parser.tableMapper, v) + + var idFieldColName string + var hasCacheTag, hasNoCacheTag bool + + for i := 0; i < t.NumField(); i++ { + tag := t.Field(i).Tag + + ormTagStr := tag.Get(parser.identifier) + var col *schemas.Column + fieldValue := v.Field(i) + fieldType := fieldValue.Type() + + if ormTagStr != "" { + col = &schemas.Column{ + FieldName: t.Field(i).Name, + Nullable: true, + IsPrimaryKey: false, + IsAutoIncrement: false, + MapType: schemas.TWOSIDES, + Indexes: make(map[string]int), + DefaultIsEmpty: true, + } + tags := splitTag(ormTagStr) + + if len(tags) > 0 { + if tags[0] == "-" { + continue + } + + var ctx = Context{ + table: table, + col: col, + fieldValue: fieldValue, + indexNames: make(map[string]int), + parser: parser, + } + + if strings.HasPrefix(strings.ToUpper(tags[0]), "EXTENDS") { + pStart := strings.Index(tags[0], "(") + if pStart > -1 && strings.HasSuffix(tags[0], ")") { + var tagPrefix = strings.TrimFunc(tags[0][pStart+1:len(tags[0])-1], func(r rune) bool { + return r == '\'' || r == '"' + }) + + ctx.params = []string{tagPrefix} + } + + if err := ExtendsTagHandler(&ctx); err != nil { + return nil, err + } + continue + } + + for j, key := range tags { + if ctx.ignoreNext { + ctx.ignoreNext = false + continue + } + + k := strings.ToUpper(key) + ctx.tagName = k + ctx.params = []string{} + + pStart := strings.Index(k, "(") + if pStart == 0 { + return nil, errors.New("( could not be the first character") + } + if pStart > -1 { + if !strings.HasSuffix(k, ")") { + return nil, fmt.Errorf("field %s tag %s cannot match ) character", col.FieldName, key) + } + + ctx.tagName = k[:pStart] + ctx.params = strings.Split(key[pStart+1:len(k)-1], ",") + } + + if j > 0 { + ctx.preTag = strings.ToUpper(tags[j-1]) + } + if j < len(tags)-1 { + ctx.nextTag = tags[j+1] + } else { + ctx.nextTag = "" + } + + if h, ok := parser.handlers[ctx.tagName]; ok { + if err := h(&ctx); err != nil { + return nil, err + } + } else { + if strings.HasPrefix(key, "'") && strings.HasSuffix(key, "'") { + col.Name = key[1 : len(key)-1] + } else { + col.Name = key + } + } + + if ctx.hasCacheTag { + hasCacheTag = true + } + if ctx.hasNoCacheTag { + hasNoCacheTag = true + } + } + + if col.SQLType.Name == "" { + col.SQLType = schemas.Type2SQLType(fieldType) + } + parser.dialect.SQLType(col) + if col.Length == 0 { + col.Length = col.SQLType.DefaultLength + } + if col.Length2 == 0 { + col.Length2 = col.SQLType.DefaultLength2 + } + if col.Name == "" { + col.Name = parser.columnMapper.Obj2Table(t.Field(i).Name) + } + + if ctx.isUnique { + ctx.indexNames[col.Name] = schemas.UniqueType + } else if ctx.isIndex { + ctx.indexNames[col.Name] = schemas.IndexType + } + + for indexName, indexType := range ctx.indexNames { + addIndex(indexName, table, col, indexType) + } + } + } else { + var sqlType schemas.SQLType + if fieldValue.CanAddr() { + if _, ok := fieldValue.Addr().Interface().(convert.Conversion); ok { + sqlType = schemas.SQLType{Name: schemas.Text} + } + } + if _, ok := fieldValue.Interface().(convert.Conversion); ok { + sqlType = schemas.SQLType{Name: schemas.Text} + } else { + sqlType = schemas.Type2SQLType(fieldType) + } + col = schemas.NewColumn(parser.columnMapper.Obj2Table(t.Field(i).Name), + t.Field(i).Name, sqlType, sqlType.DefaultLength, + sqlType.DefaultLength2, true) + + if fieldType.Kind() == reflect.Int64 && (strings.ToUpper(col.FieldName) == "ID" || strings.HasSuffix(strings.ToUpper(col.FieldName), ".ID")) { + idFieldColName = col.Name + } + } + if col.IsAutoIncrement { + col.Nullable = false + } + + table.AddColumn(col) + + } // end for + + if idFieldColName != "" && len(table.PrimaryKeys) == 0 { + col := table.GetColumn(idFieldColName) + col.IsPrimaryKey = true + col.IsAutoIncrement = true + col.Nullable = false + table.PrimaryKeys = append(table.PrimaryKeys, col.Name) + table.AutoIncrement = col.Name + } + + if hasCacheTag { + if parser.cacherMgr.GetDefaultCacher() != nil { // !nash! use engine's cacher if provided + //engine.logger.Info("enable cache on table:", table.Name) + parser.cacherMgr.SetCacher(table.Name, parser.cacherMgr.GetDefaultCacher()) + } else { + //engine.logger.Info("enable LRU cache on table:", table.Name) + parser.cacherMgr.SetCacher(table.Name, caches.NewLRUCacher2(caches.NewMemoryStore(), time.Hour, 10000)) + } + } + if hasNoCacheTag { + //engine.logger.Info("disable cache on table:", table.Name) + parser.cacherMgr.SetCacher(table.Name, nil) + } + + return table, nil +} diff --git a/tags/parser_test.go b/tags/parser_test.go new file mode 100644 index 00000000..6065bf2e --- /dev/null +++ b/tags/parser_test.go @@ -0,0 +1,44 @@ +// Copyright 2020 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tags + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + "xorm.io/xorm/caches" + "xorm.io/xorm/dialects" + "xorm.io/xorm/names" +) + +type ParseTableName1 struct{} + +type ParseTableName2 struct{} + +func (p ParseTableName2) TableName() string { + return "p_parseTableName" +} + +func TestParseTableName(t *testing.T) { + parser := NewParser( + "xorm", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.SnakeMapper{}, + caches.NewManager(), + ) + table, err := parser.Parse(reflect.ValueOf(new(ParseTableName1))) + assert.NoError(t, err) + assert.EqualValues(t, "parse_table_name1", table.Name) + + table, err = parser.Parse(reflect.ValueOf(new(ParseTableName2))) + assert.NoError(t, err) + assert.EqualValues(t, "p_parseTableName", table.Name) + + table, err = parser.Parse(reflect.ValueOf(ParseTableName2{})) + assert.NoError(t, err) + assert.EqualValues(t, "p_parseTableName", table.Name) +} diff --git a/tag.go b/tags/tag.go similarity index 72% rename from tag.go rename to tags/tag.go index ec8d5cf0..ee3f1e82 100644 --- a/tag.go +++ b/tags/tag.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package tags import ( "fmt" @@ -11,31 +11,52 @@ import ( "strings" "time" - "xorm.io/core" + "xorm.io/xorm/schemas" ) -type tagContext struct { +func splitTag(tag string) (tags []string) { + tag = strings.TrimSpace(tag) + var hasQuote = false + var lastIdx = 0 + for i, t := range tag { + if t == '\'' { + hasQuote = !hasQuote + } else if t == ' ' { + if lastIdx < i && !hasQuote { + tags = append(tags, strings.TrimSpace(tag[lastIdx:i])) + lastIdx = i + 1 + } + } + } + if lastIdx < len(tag) { + tags = append(tags, strings.TrimSpace(tag[lastIdx:])) + } + return +} + +// Context represents a context for xorm tag parse. +type Context struct { tagName string params []string preTag, nextTag string - table *core.Table - col *core.Column + table *schemas.Table + col *schemas.Column fieldValue reflect.Value isIndex bool isUnique bool indexNames map[string]int - engine *Engine + parser *Parser hasCacheTag bool hasNoCacheTag bool ignoreNext bool } -// tagHandler describes tag handler for XORM -type tagHandler func(ctx *tagContext) error +// Handler describes tag handler for XORM +type Handler func(ctx *Context) error var ( // defaultTagHandlers enumerates all the default tag handler - defaultTagHandlers = map[string]tagHandler{ + defaultTagHandlers = map[string]Handler{ "<-": OnlyFromDBTagHandler, "->": OnlyToDBTagHandler, "PK": PKTagHandler, @@ -59,49 +80,49 @@ var ( ) func init() { - for k := range core.SqlTypes { + for k := range schemas.SqlTypes { defaultTagHandlers[k] = SQLTypeTagHandler } } // IgnoreTagHandler describes ignored tag handler -func IgnoreTagHandler(ctx *tagContext) error { +func IgnoreTagHandler(ctx *Context) error { return nil } // OnlyFromDBTagHandler describes mapping direction tag handler -func OnlyFromDBTagHandler(ctx *tagContext) error { - ctx.col.MapType = core.ONLYFROMDB +func OnlyFromDBTagHandler(ctx *Context) error { + ctx.col.MapType = schemas.ONLYFROMDB return nil } // OnlyToDBTagHandler describes mapping direction tag handler -func OnlyToDBTagHandler(ctx *tagContext) error { - ctx.col.MapType = core.ONLYTODB +func OnlyToDBTagHandler(ctx *Context) error { + ctx.col.MapType = schemas.ONLYTODB return nil } -// PKTagHandler decribes primary key tag handler -func PKTagHandler(ctx *tagContext) error { +// PKTagHandler describes primary key tag handler +func PKTagHandler(ctx *Context) error { ctx.col.IsPrimaryKey = true ctx.col.Nullable = false return nil } // NULLTagHandler describes null tag handler -func NULLTagHandler(ctx *tagContext) error { +func NULLTagHandler(ctx *Context) error { ctx.col.Nullable = (strings.ToUpper(ctx.preTag) != "NOT") return nil } // NotNullTagHandler describes notnull tag handler -func NotNullTagHandler(ctx *tagContext) error { +func NotNullTagHandler(ctx *Context) error { ctx.col.Nullable = false return nil } // AutoIncrTagHandler describes autoincr tag handler -func AutoIncrTagHandler(ctx *tagContext) error { +func AutoIncrTagHandler(ctx *Context) error { ctx.col.IsAutoIncrement = true /* if len(ctx.params) > 0 { @@ -118,7 +139,7 @@ func AutoIncrTagHandler(ctx *tagContext) error { } // DefaultTagHandler describes default tag handler -func DefaultTagHandler(ctx *tagContext) error { +func DefaultTagHandler(ctx *Context) error { if len(ctx.params) > 0 { ctx.col.Default = ctx.params[0] } else { @@ -130,26 +151,26 @@ func DefaultTagHandler(ctx *tagContext) error { } // CreatedTagHandler describes created tag handler -func CreatedTagHandler(ctx *tagContext) error { +func CreatedTagHandler(ctx *Context) error { ctx.col.IsCreated = true return nil } // VersionTagHandler describes version tag handler -func VersionTagHandler(ctx *tagContext) error { +func VersionTagHandler(ctx *Context) error { ctx.col.IsVersion = true ctx.col.Default = "1" return nil } // UTCTagHandler describes utc tag handler -func UTCTagHandler(ctx *tagContext) error { +func UTCTagHandler(ctx *Context) error { ctx.col.TimeZone = time.UTC return nil } // LocalTagHandler describes local tag handler -func LocalTagHandler(ctx *tagContext) error { +func LocalTagHandler(ctx *Context) error { if len(ctx.params) == 0 { ctx.col.TimeZone = time.Local } else { @@ -163,21 +184,21 @@ func LocalTagHandler(ctx *tagContext) error { } // UpdatedTagHandler describes updated tag handler -func UpdatedTagHandler(ctx *tagContext) error { +func UpdatedTagHandler(ctx *Context) error { ctx.col.IsUpdated = true return nil } // DeletedTagHandler describes deleted tag handler -func DeletedTagHandler(ctx *tagContext) error { +func DeletedTagHandler(ctx *Context) error { ctx.col.IsDeleted = true return nil } // IndexTagHandler describes index tag handler -func IndexTagHandler(ctx *tagContext) error { +func IndexTagHandler(ctx *Context) error { if len(ctx.params) > 0 { - ctx.indexNames[ctx.params[0]] = core.IndexType + ctx.indexNames[ctx.params[0]] = schemas.IndexType } else { ctx.isIndex = true } @@ -185,9 +206,9 @@ func IndexTagHandler(ctx *tagContext) error { } // UniqueTagHandler describes unique tag handler -func UniqueTagHandler(ctx *tagContext) error { +func UniqueTagHandler(ctx *Context) error { if len(ctx.params) > 0 { - ctx.indexNames[ctx.params[0]] = core.UniqueType + ctx.indexNames[ctx.params[0]] = schemas.UniqueType } else { ctx.isUnique = true } @@ -195,7 +216,7 @@ func UniqueTagHandler(ctx *tagContext) error { } // CommentTagHandler add comment to column -func CommentTagHandler(ctx *tagContext) error { +func CommentTagHandler(ctx *Context) error { if len(ctx.params) > 0 { ctx.col.Comment = strings.Trim(ctx.params[0], "' ") } @@ -203,17 +224,17 @@ func CommentTagHandler(ctx *tagContext) error { } // SQLTypeTagHandler describes SQL Type tag handler -func SQLTypeTagHandler(ctx *tagContext) error { - ctx.col.SQLType = core.SQLType{Name: ctx.tagName} +func SQLTypeTagHandler(ctx *Context) error { + ctx.col.SQLType = schemas.SQLType{Name: ctx.tagName} if len(ctx.params) > 0 { - if ctx.tagName == core.Enum { + if ctx.tagName == schemas.Enum { ctx.col.EnumOptions = make(map[string]int) for k, v := range ctx.params { v = strings.TrimSpace(v) v = strings.Trim(v, "'") ctx.col.EnumOptions[v] = k } - } else if ctx.tagName == core.Set { + } else if ctx.tagName == schemas.Set { ctx.col.SetOptions = make(map[string]int) for k, v := range ctx.params { v = strings.TrimSpace(v) @@ -243,7 +264,7 @@ func SQLTypeTagHandler(ctx *tagContext) error { } // ExtendsTagHandler describes extends tag handler -func ExtendsTagHandler(ctx *tagContext) error { +func ExtendsTagHandler(ctx *Context) error { var fieldValue = ctx.fieldValue var isPtr = false switch fieldValue.Kind() { @@ -259,7 +280,7 @@ func ExtendsTagHandler(ctx *tagContext) error { isPtr = true fallthrough case reflect.Struct: - parentTable, err := ctx.engine.mapType(fieldValue) + parentTable, err := ctx.parser.Parse(fieldValue) if err != nil { return err } @@ -295,7 +316,7 @@ func ExtendsTagHandler(ctx *tagContext) error { } // CacheTagHandler describes cache tag handler -func CacheTagHandler(ctx *tagContext) error { +func CacheTagHandler(ctx *Context) error { if !ctx.hasCacheTag { ctx.hasCacheTag = true } @@ -303,7 +324,7 @@ func CacheTagHandler(ctx *tagContext) error { } // NoCacheTagHandler describes nocache tag handler -func NoCacheTagHandler(ctx *tagContext) error { +func NoCacheTagHandler(ctx *Context) error { if !ctx.hasNoCacheTag { ctx.hasNoCacheTag = true } diff --git a/tags/tag_test.go b/tags/tag_test.go new file mode 100644 index 00000000..5775b40a --- /dev/null +++ b/tags/tag_test.go @@ -0,0 +1,30 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tags + +import ( + "testing" + + "xorm.io/xorm/internal/utils" +) + +func TestSplitTag(t *testing.T) { + var cases = []struct { + tag string + tags []string + }{ + {"not null default '2000-01-01 00:00:00' TIMESTAMP", []string{"not", "null", "default", "'2000-01-01 00:00:00'", "TIMESTAMP"}}, + {"TEXT", []string{"TEXT"}}, + {"default('2000-01-01 00:00:00')", []string{"default('2000-01-01 00:00:00')"}}, + {"json binary", []string{"json", "binary"}}, + } + + for _, kase := range cases { + tags := splitTag(kase.tag) + if !utils.SliceEq(tags, kase.tags) { + t.Fatalf("[%d]%v is not equal [%d]%v", len(tags), tags, len(kase.tags), kase.tags) + } + } +} diff --git a/test_mssql.sh b/test_mssql.sh deleted file mode 100755 index 7f060cff..00000000 --- a/test_mssql.sh +++ /dev/null @@ -1 +0,0 @@ -go test -db=mssql -conn_str="server=localhost;user id=sa;password=yourStrong(!)Password;database=xorm_test" \ No newline at end of file diff --git a/test_mssql_cache.sh b/test_mssql_cache.sh deleted file mode 100755 index 76efd6ca..00000000 --- a/test_mssql_cache.sh +++ /dev/null @@ -1 +0,0 @@ -go test -db=mssql -conn_str="server=192.168.1.58;user id=sa;password=123456;database=xorm_test" -cache=true \ No newline at end of file diff --git a/test_mymysql.sh b/test_mymysql.sh deleted file mode 100755 index f7780d14..00000000 --- a/test_mymysql.sh +++ /dev/null @@ -1 +0,0 @@ -go test -db=mymysql -conn_str="xorm_test/root/" \ No newline at end of file diff --git a/test_mymysql_cache.sh b/test_mymysql_cache.sh deleted file mode 100755 index 0100286d..00000000 --- a/test_mymysql_cache.sh +++ /dev/null @@ -1 +0,0 @@ -go test -db=mymysql -conn_str="xorm_test/root/" -cache=true \ No newline at end of file diff --git a/test_mysql.sh b/test_mysql.sh deleted file mode 100755 index 650e4ee1..00000000 --- a/test_mysql.sh +++ /dev/null @@ -1 +0,0 @@ -go test -db=mysql -conn_str="root:@/xorm_test" \ No newline at end of file diff --git a/test_mysql_cache.sh b/test_mysql_cache.sh deleted file mode 100755 index c542e735..00000000 --- a/test_mysql_cache.sh +++ /dev/null @@ -1 +0,0 @@ -go test -db=mysql -conn_str="root:@/xorm_test" -cache=true \ No newline at end of file diff --git a/test_postgres.sh b/test_postgres.sh deleted file mode 100755 index dc1152e0..00000000 --- a/test_postgres.sh +++ /dev/null @@ -1 +0,0 @@ -go test -db=postgres -conn_str="dbname=xorm_test sslmode=disable" \ No newline at end of file diff --git a/test_postgres_cache.sh b/test_postgres_cache.sh deleted file mode 100755 index 462fc948..00000000 --- a/test_postgres_cache.sh +++ /dev/null @@ -1 +0,0 @@ -go test -db=postgres -conn_str="dbname=xorm_test sslmode=disable" -cache=true \ No newline at end of file diff --git a/test_sqlite.sh b/test_sqlite.sh deleted file mode 100755 index 6352b5cb..00000000 --- a/test_sqlite.sh +++ /dev/null @@ -1 +0,0 @@ -go test -db=sqlite3 -conn_str="./test.db?cache=shared&mode=rwc" \ No newline at end of file diff --git a/test_sqlite_cache.sh b/test_sqlite_cache.sh deleted file mode 100755 index 75a054c3..00000000 --- a/test_sqlite_cache.sh +++ /dev/null @@ -1 +0,0 @@ -go test -db=sqlite3 -conn_str="./test.db?cache=shared&mode=rwc" -cache=true \ No newline at end of file diff --git a/test_tidb.sh b/test_tidb.sh deleted file mode 100755 index 03d2d6cd..00000000 --- a/test_tidb.sh +++ /dev/null @@ -1 +0,0 @@ -go test -db=mysql -conn_str="root:@tcp(localhost:4000)/xorm_test" -ignore_select_update=true \ No newline at end of file diff --git a/transaction.go b/transaction.go deleted file mode 100644 index 4104103f..00000000 --- a/transaction.go +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright 2018 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package xorm - -// Transaction Execute sql wrapped in a transaction(abbr as tx), tx will automatic commit if no errors occurred -func (engine *Engine) Transaction(f func(*Session) (interface{}, error)) (interface{}, error) { - session := engine.NewSession() - defer session.Close() - - if err := session.Begin(); err != nil { - return nil, err - } - - result, err := f(session) - if err != nil { - return nil, err - } - - if err := session.Commit(); err != nil { - return nil, err - } - - return result, nil -} diff --git a/transancation_test.go b/transancation_test.go deleted file mode 100644 index b9a89878..00000000 --- a/transancation_test.go +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright 2017 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package xorm - -import ( - "fmt" - "testing" - "time" - - "github.com/stretchr/testify/assert" -) - -func TestAutoTransaction(t *testing.T) { - assert.NoError(t, prepareEngine()) - - type TestTx struct { - Id int64 `xorm:"autoincr pk"` - Msg string `xorm:"varchar(255)"` - Created time.Time `xorm:"created"` - } - - assert.NoError(t, testEngine.Sync2(new(TestTx))) - - engine := testEngine.(*Engine) - - // will success - engine.Transaction(func(session *Session) (interface{}, error) { - _, err := session.Insert(TestTx{Msg: "hi"}) - assert.NoError(t, err) - - return nil, nil - }) - - has, err := engine.Exist(&TestTx{Msg: "hi"}) - assert.NoError(t, err) - assert.EqualValues(t, true, has) - - // will rollback - _, err = engine.Transaction(func(session *Session) (interface{}, error) { - _, err := session.Insert(TestTx{Msg: "hello"}) - assert.NoError(t, err) - - return nil, fmt.Errorf("rollback") - }) - assert.Error(t, err) - - has, err = engine.Exist(&TestTx{Msg: "hello"}) - assert.NoError(t, err) - assert.EqualValues(t, false, has) -} diff --git a/xorm.go b/xorm.go deleted file mode 100644 index e1c83b56..00000000 --- a/xorm.go +++ /dev/null @@ -1,126 +0,0 @@ -// Copyright 2015 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build go1.8 - -package xorm - -import ( - "context" - "fmt" - "os" - "reflect" - "runtime" - "sync" - "time" - - "xorm.io/core" -) - -const ( - // Version show the xorm's version - Version string = "0.8.0.1015" -) - -func regDrvsNDialects() bool { - providedDrvsNDialects := map[string]struct { - dbType core.DbType - getDriver func() core.Driver - getDialect func() core.Dialect - }{ - "mssql": {"mssql", func() core.Driver { return &odbcDriver{} }, func() core.Dialect { return &mssql{} }}, - "odbc": {"mssql", func() core.Driver { return &odbcDriver{} }, func() core.Dialect { return &mssql{} }}, // !nashtsai! TODO change this when supporting MS Access - "mysql": {"mysql", func() core.Driver { return &mysqlDriver{} }, func() core.Dialect { return &mysql{} }}, - "mymysql": {"mysql", func() core.Driver { return &mymysqlDriver{} }, func() core.Dialect { return &mysql{} }}, - "postgres": {"postgres", func() core.Driver { return &pqDriver{} }, func() core.Dialect { return &postgres{} }}, - "pgx": {"postgres", func() core.Driver { return &pqDriverPgx{} }, func() core.Dialect { return &postgres{} }}, - "sqlite3": {"sqlite3", func() core.Driver { return &sqlite3Driver{} }, func() core.Dialect { return &sqlite3{} }}, - "oci8": {"oracle", func() core.Driver { return &oci8Driver{} }, func() core.Dialect { return &oracle{} }}, - "goracle": {"oracle", func() core.Driver { return &goracleDriver{} }, func() core.Dialect { return &oracle{} }}, - } - - for driverName, v := range providedDrvsNDialects { - if driver := core.QueryDriver(driverName); driver == nil { - core.RegisterDriver(driverName, v.getDriver()) - core.RegisterDialect(v.dbType, v.getDialect) - } - } - return true -} - -func close(engine *Engine) { - engine.Close() -} - -func init() { - regDrvsNDialects() -} - -// NewEngine new a db manager according to the parameter. Currently support four -// drivers -func NewEngine(driverName string, dataSourceName string) (*Engine, error) { - driver := core.QueryDriver(driverName) - if driver == nil { - return nil, fmt.Errorf("Unsupported driver name: %v", driverName) - } - - uri, err := driver.Parse(driverName, dataSourceName) - if err != nil { - return nil, err - } - - dialect := core.QueryDialect(uri.DbType) - if dialect == nil { - return nil, fmt.Errorf("Unsupported dialect type: %v", uri.DbType) - } - - db, err := core.Open(driverName, dataSourceName) - if err != nil { - return nil, err - } - - err = dialect.Init(db, uri, driverName, dataSourceName) - if err != nil { - return nil, err - } - - engine := &Engine{ - db: db, - dialect: dialect, - Tables: make(map[reflect.Type]*core.Table), - mutex: &sync.RWMutex{}, - TagIdentifier: "xorm", - TZLocation: time.Local, - tagHandlers: defaultTagHandlers, - cachers: make(map[string]core.Cacher), - defaultContext: context.Background(), - } - - if uri.DbType == core.SQLITE { - engine.DatabaseTZ = time.UTC - } else { - engine.DatabaseTZ = time.Local - } - - logger := NewSimpleLogger(os.Stdout) - logger.SetLevel(core.LOG_INFO) - engine.SetLogger(logger) - engine.SetMapper(core.NewCacheMapper(new(core.SnakeMapper))) - - runtime.SetFinalizer(engine, close) - - return engine, nil -} - -// NewEngineWithParams new a db manager with params. The params will be passed to dialect. -func NewEngineWithParams(driverName string, dataSourceName string, params map[string]string) (*Engine, error) { - engine, err := NewEngine(driverName, dataSourceName) - engine.dialect.SetParams(params) - return engine, err -} - -// Clone clone an engine -func (engine *Engine) Clone() (*Engine, error) { - return NewEngine(engine.DriverName(), engine.DataSourceName()) -}