pax_global_header00006660000000000000000000000064137753353720014531gustar00rootroot0000000000000052 comment=2aad872da50499ce0f2c4a380ee6ca73ebd47285 ent-0.5.4/000077500000000000000000000000001377533537200123255ustar00rootroot00000000000000ent-0.5.4/.circleci/000077500000000000000000000000001377533537200141605ustar00rootroot00000000000000ent-0.5.4/.circleci/config.yml000066400000000000000000000146031377533537200161540ustar00rootroot00000000000000version: 2.1 aliases: - &mktestdir run: name: Create results directory command: mkdir -p ~/test-results - &storetestdir store_test_results: path: ~/test-results orbs: aws-cli: circleci/aws-cli@1.0.0 go: circleci/go@1.1.1 commands: getmods: steps: - go/load-cache - go/mod-download - go/save-cache jobs: lint: docker: - image: golangci/golangci-lint:v1.28-alpine steps: - checkout - *mktestdir - run: name: Run linters command: golangci-lint run --out-format junit-xml > ~/test-results/lint.xml - *storetestdir unit: executor: name: go/default tag: '1.15' steps: - checkout - *mktestdir - getmods - run: name: Dialect tests command: gotestsum -f short-verbose --junitfile ~/test-results/dialect.xml working_directory: dialect - run: name: Schema tests command: gotestsum -f short-verbose --junitfile ~/test-results/schema.xml working_directory: schema - run: name: Loader tests command: gotestsum -f short-verbose --junitfile ~/test-results/load.xml working_directory: entc/load - run: name: Codegen tests command: gotestsum -f short-verbose --junitfile ~/test-results/gen.xml working_directory: entc/gen - *storetestdir integration: docker: &integration-docker - image: circleci/golang - image: circleci/mysql:5.6.35 environment: &mysql_env MYSQL_DATABASE: test MYSQL_ROOT_PASSWORD: pass - image: circleci/mysql:5.7.26 environment: <<: *mysql_env MYSQL_TCP_PORT: 3307 - image: circleci/mysql environment: <<: *mysql_env MYSQL_TCP_PORT: 3308 - image: circleci/mariadb environment: <<: *mysql_env MYSQL_TCP_PORT: 4306 - image: 
docker.io/bitnami/mariadb:10.2-debian-10 environment: MARIADB_DATABASE: test MARIADB_ROOT_PASSWORD: pass MARIADB_PORT_NUMBER: 4307 - image: circleci/postgres:10.0 environment: POSTGRES_DB: test POSTGRES_PASSWORD: pass command: -p 5430 - image: circleci/postgres:11.0 environment: POSTGRES_DB: test POSTGRES_PASSWORD: pass command: -p 5431 - image: circleci/postgres:12.3 environment: POSTGRES_DB: test POSTGRES_PASSWORD: pass command: -p 5433 - image: circleci/postgres:13.1 environment: POSTGRES_DB: test POSTGRES_PASSWORD: pass command: -p 5434 - image: entgo/gremlin-server entrypoint: /opt/gremlin-server/bin/gremlin-server.sh command: conf/gremlin-server.yaml steps: - checkout - run: &integration-wait name: Wait for databases command: >- dockerize -timeout 2m -wait tcp://localhost:3306 -wait tcp://localhost:3307 -wait tcp://localhost:3308 -wait tcp://localhost:4306 -wait tcp://localhost:4307 -wait tcp://localhost:5430 -wait tcp://localhost:5431 -wait tcp://localhost:5433 -wait tcp://localhost:8182 - *mktestdir - getmods - run: name: Run codegen for examples working_directory: examples command: go generate ./... - run: name: Run example tests working_directory: examples command: gotestsum -f short-verbose --junitfile ~/test-results/examples.xml -- -race ./... - run: name: Run codegen for entc/load working_directory: entc/load command: go generate - run: name: Run codegen for entc/gen working_directory: entc/gen command: go generate - run: name: Run codegen for entc/integration working_directory: entc/integration command: go generate - run: name: Check untracked files command: | if [[ `git status --porcelain` ]] then echo "Running 'go generate ./...' introduced untracked files" git status --porcelain exit 1 fi - run: name: Run integration tests working_directory: entc/integration command: gotestsum -f short-verbose --junitfile ~/test-results/integration.xml -- -race -count=2 -tags='json1' ./... 
- *storetestdir migration: docker: *integration-docker steps: - checkout - run: *integration-wait - *mktestdir - getmods - run: name: Checkout master command: git checkout origin/master - run: name: Run integration on master working_directory: entc/integration command: gotestsum -f short-verbose --junitfile ~/test-results/master-integration.xml -- -race -count=2 . - run: name: Checkout PR branch command: git checkout "$CIRCLE_BRANCH" - run: name: Run integration on PR branch working_directory: entc/integration command: gotestsum -f short-verbose --junitfile ~/test-results/pr-integration.xml -- -race -count=2 . - *storetestdir docs: docker: - image: circleci/node steps: - checkout - run: name: Checking Docs Modified command: | if [[ ! $(git diff master^ --name-only doc/) ]]; then echo "docs not modified; no need to deploy" circleci step halt fi - run: name: Install Dependencies working_directory: ~/project/doc/website command: yarn - run: name: Build Docs working_directory: ~/project/doc/website command: yarn build - aws-cli/setup - run: name: Deploy Docs working_directory: ~/project/doc/website/build/ent command: aws s3 sync . s3://entgo.io --delete --exclude "assets/*" - run: name: Invalidate Cache command: aws cloudfront create-invalidation --distribution-id $CDN_DISTRIBUTION_ID --paths "/*" | jq -M "del(.Location)" workflows: version: 2.1 all: jobs: - lint - unit - integration - migration: filters: branches: ignore: master - docs: filters: branches: only: master ent-0.5.4/.github/000077500000000000000000000000001377533537200136655ustar00rootroot00000000000000ent-0.5.4/.github/ISSUE_TEMPLATE/000077500000000000000000000000001377533537200160505ustar00rootroot00000000000000ent-0.5.4/.github/ISSUE_TEMPLATE/1.bug.md000066400000000000000000000021411377533537200173040ustar00rootroot00000000000000--- name: Bug report 🐛 about: Create a bug report. labels: 'status: needs triage' --- - [ ] The issue is present in the latest release. 
- [ ] I have searched the [issues](https://github.com/facebook/ent/issues) of this repository and believe that this is not a duplicate. ## Current Behavior 😯 ## Expected Behavior 🤔 ## Steps to Reproduce 🕹 Steps: 1. 2. 3. 4. ## Your Environment 🌎 | Tech | Version | | ----------- | ------- | | Go | 1.15.? | | Ent | 0.5.? | | Database | Mysql | | Driver | https://github.com/go-sql-driver/mysql |ent-0.5.4/.github/ISSUE_TEMPLATE/2.feature.md000066400000000000000000000014241377533537200201660ustar00rootroot00000000000000--- name: Feature request 🎉 about: Suggest a new idea for the project. labels: 'status: needs triage' --- - [ ] I have searched the [issues](https://github.com/facebook/ent/issues) of this repository and believe that this is not a duplicate. ## Summary 💡 ## Motivation 🔦 ent-0.5.4/.github/ISSUE_TEMPLATE/3.support.md000066400000000000000000000001141377533537200202430ustar00rootroot00000000000000--- name: Question about: General support labels: 'status: needs triage' ---ent-0.5.4/.github/ISSUE_TEMPLATE/config.yml000066400000000000000000000003171377533537200200410ustar00rootroot00000000000000blank_issues_enabled: false # force the usage of a template contact_links: - name: Something Else ❔ url: https://gophers.slack.com/archives/C01FMSQDT53 about: Come chat to us in the gophers slackent-0.5.4/.github/dependabot.yml000066400000000000000000000002611377533537200165140ustar00rootroot00000000000000version: 2 updates: - package-ecosystem: github-actions directory: / schedule: interval: daily - package-ecosystem: gomod directory: / schedule: interval: daily ent-0.5.4/.github/workflows/000077500000000000000000000000001377533537200157225ustar00rootroot00000000000000ent-0.5.4/.github/workflows/cd.yml000066400000000000000000000021351377533537200170340ustar00rootroot00000000000000name: Continuous Deployment on: push: branches: - master paths: - 'doc/**' jobs: docs: name: docs runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - uses: actions/setup-node@v2.1.4 
with: node-version: 14 - name: Install Dependencies working-directory: doc/website run: yarn - name: Build Docs working-directory: doc/website run: yarn build - name: Configure AWS credentials uses: aws-actions/configure-aws-credentials@v1 with: aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} aws-region: eu-central-1 - name: Deploy Docs working-directory: doc/website/build/ent run: aws s3 sync . s3://entgo.io --delete --exclude "assets/*" - name: Invalidate Cache env: CDN_DISTRIBUTION_ID: ${{ secrets.CDN_DISTRIBUTION_ID }} run: aws cloudfront create-invalidation --distribution-id $CDN_DISTRIBUTION_ID --paths "/*" | jq -M "del(.Location)" ent-0.5.4/.github/workflows/ci.yml000066400000000000000000000230011377533537200170340ustar00rootroot00000000000000name: Continuous Integration on: push: paths-ignore: - 'doc/**' tags-ignore: - '*.*' pull_request: paths-ignore: - 'doc/**' jobs: lint: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - name: Run linters uses: golangci/golangci-lint-action@v2.3.0 with: version: v1.28 unit: runs-on: ubuntu-latest strategy: matrix: go: ['1.15', '1.14'] steps: - uses: actions/checkout@v2 - uses: actions/setup-go@v2 with: go-version: ${{ matrix.go }} - uses: actions/cache@v2 with: path: ~/go/pkg/mod key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} restore-keys: | ${{ runner.os }}-go- - name: Run dialect tests run: go test -race ./... working-directory: dialect - name: Run schema tests run: go test -race ./... working-directory: schema - name: Run loader tests run: go test -race ./... working-directory: entc/load - name: Run codegen tests run: go test -race ./... working-directory: entc/gen - name: Run example tests working-directory: examples run: go test -race ./... 
generate: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - uses: actions/setup-go@v2 - uses: actions/cache@v2 with: path: ~/go/pkg/mod key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} restore-keys: | ${{ runner.os }}-go- - name: Run go generate run: go generate ./... - name: Check generated files run: | status=$(git status --porcelain) if [ -n "$status" ]; then echo "you need to run 'go generate ./...' and commit the changes" echo "$status" exit 1 fi integration: runs-on: ubuntu-latest services: mysql56: image: mysql:5.6.35 env: MYSQL_DATABASE: test MYSQL_ROOT_PASSWORD: pass ports: - 3306:3306 options: >- --health-cmd "mysqladmin ping -ppass" --health-interval 10s --health-start-period 10s --health-timeout 5s --health-retries 10 mysql57: image: mysql:5.7.26 env: MYSQL_DATABASE: test MYSQL_ROOT_PASSWORD: pass ports: - 3307:3306 options: >- --health-cmd "mysqladmin ping -ppass" --health-interval 10s --health-start-period 10s --health-timeout 5s --health-retries 10 mysql8: image: mysql:8 env: MYSQL_DATABASE: test MYSQL_ROOT_PASSWORD: pass ports: - 3308:3306 options: >- --health-cmd "mysqladmin ping -ppass" --health-interval 10s --health-start-period 10s --health-timeout 5s --health-retries 10 maria: image: mariadb env: MYSQL_DATABASE: test MYSQL_ROOT_PASSWORD: pass ports: - 4306:3306 options: >- --health-cmd "mysqladmin ping -ppass" --health-interval 10s --health-start-period 10s --health-timeout 5s --health-retries 10 maria102: image: mariadb:10.2.32 env: MYSQL_DATABASE: test MYSQL_ROOT_PASSWORD: pass ports: - 4307:3306 options: >- --health-cmd "mysqladmin ping -ppass" --health-interval 10s --health-start-period 10s --health-timeout 5s --health-retries 10 postgres10: image: postgres:10 env: POSTGRES_DB: test POSTGRES_PASSWORD: pass ports: - 5430:5432 options: >- --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5 postgres11: image: postgres:11 env: POSTGRES_DB: test POSTGRES_PASSWORD: pass ports: - 5431:5432 options: 
>- --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5 postgres12: image: postgres:12.3 env: POSTGRES_DB: test POSTGRES_PASSWORD: pass ports: - 5433:5432 options: >- --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5 postgres13: image: postgres:13.1 env: POSTGRES_DB: test POSTGRES_PASSWORD: pass ports: - 5434:5432 options: >- --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5 gremlin-server: image: entgo/gremlin-server ports: - 8182:8182 options: >- --health-cmd "netstat -an | grep -q 8182" --health-interval 10s --health-timeout 5s --health-retries 5 steps: - uses: actions/checkout@v2 - uses: actions/setup-go@v2 - uses: actions/cache@v2 with: path: ~/go/pkg/mod key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} restore-keys: | ${{ runner.os }}-go- - name: Run integration tests working-directory: entc/integration run: go test -race -count=2 -tags='json1' ./... migration: runs-on: ubuntu-latest if: ${{ github.ref != 'refs/heads/master' }} services: mysql56: image: mysql:5.6.35 env: MYSQL_DATABASE: test MYSQL_ROOT_PASSWORD: pass ports: - 3306:3306 options: >- --health-cmd "mysqladmin ping -ppass" --health-interval 10s --health-start-period 10s --health-timeout 5s --health-retries 10 mysql57: image: mysql:5.7.26 env: MYSQL_DATABASE: test MYSQL_ROOT_PASSWORD: pass ports: - 3307:3306 options: >- --health-cmd "mysqladmin ping -ppass" --health-interval 10s --health-start-period 10s --health-timeout 5s --health-retries 10 mysql8: image: mysql:8 env: MYSQL_DATABASE: test MYSQL_ROOT_PASSWORD: pass ports: - 3308:3306 options: >- --health-cmd "mysqladmin ping -ppass" --health-interval 10s --health-start-period 10s --health-timeout 5s --health-retries 10 maria: image: mariadb env: MYSQL_DATABASE: test MYSQL_ROOT_PASSWORD: pass ports: - 4306:3306 options: >- --health-cmd "mysqladmin ping -ppass" --health-interval 10s --health-start-period 10s --health-timeout 5s 
--health-retries 10 maria102: image: mariadb:10.2.32 env: MYSQL_DATABASE: test MYSQL_ROOT_PASSWORD: pass ports: - 4307:3306 options: >- --health-cmd "mysqladmin ping -ppass" --health-interval 10s --health-start-period 10s --health-timeout 5s --health-retries 10 postgres10: image: postgres:10 env: POSTGRES_DB: test POSTGRES_PASSWORD: pass ports: - 5430:5432 options: >- --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5 postgres11: image: postgres:11 env: POSTGRES_DB: test POSTGRES_PASSWORD: pass ports: - 5431:5432 options: >- --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5 postgres12: image: postgres:12.3 env: POSTGRES_DB: test POSTGRES_PASSWORD: pass ports: - 5433:5432 options: >- --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5 postgres13: image: postgres:13.1 env: POSTGRES_DB: test POSTGRES_PASSWORD: pass ports: - 5434:5432 options: >- --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5 gremlin-server: image: entgo/gremlin-server ports: - 8182:8182 options: >- --health-cmd "netstat -an | grep -q 8182" --health-interval 10s --health-timeout 5s --health-retries 5 steps: - uses: actions/checkout@v2 with: fetch-depth: 0 - uses: actions/setup-go@v2 - uses: actions/cache@v2 with: path: ~/go/pkg/mod key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} restore-keys: | ${{ runner.os }}-go- - name: Checkout origin/master run: git checkout origin/master - name: Run integration on origin/master working-directory: entc/integration run: go test -race -count=2 -tags='json1' ./... - name: Checkout previous HEAD run: git checkout - - name: Run integration on HEAD working-directory: entc/integration run: go test -race -count=2 -tags='json1' ./... 
ent-0.5.4/.golangci.yml000066400000000000000000000027311377533537200147140ustar00rootroot00000000000000run: timeout: 3m linters-settings: errcheck: ignore: fmt:.*,Read|Write|Close|Exec,io:Copy dupl: threshold: 100 funlen: lines: 130 statements: 80 goheader: template: |- Copyright 2019-present Facebook Inc. All rights reserved. This source code is licensed under the Apache 2.0 license found in the LICENSE file in the root directory of this source tree. linters: disable-all: true enable: - bodyclose - deadcode - depguard - dogsled - dupl - errcheck - funlen - gocritic - gofmt - goheader - gosec - gosimple - govet - ineffassign - interfacer - misspell - staticcheck - structcheck - stylecheck - typecheck - unconvert - unused - varcheck - whitespace issues: exclude-rules: - path: _test\.go linters: - dupl - funlen - gosec - linters: - unused source: ent.Schema - path: entc/integration/ent/schema/card.go text: "`internal` is unused" - path: dialect/sql/builder.go text: "can be `Querier`" linters: - interfacer - path: dialect/sql/builder.go text: "SQL string concatenation" linters: - gosec - path: dialect/sql/schema linters: - dupl - gosec - text: "Expect WriteFile permissions to be 0600 or less" linters: - gosec - path: privacy/privacy.go linters: - stylecheck ent-0.5.4/CODE_OF_CONDUCT.md000066400000000000000000000064341377533537200151330ustar00rootroot00000000000000# Code of Conduct ## Our Pledge In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to make participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation. 
## Our Standards Examples of behavior that contributes to creating a positive environment include: * Using welcoming and inclusive language * Being respectful of differing viewpoints and experiences * Gracefully accepting constructive criticism * Focusing on what is best for the community * Showing empathy towards other community members Examples of unacceptable behavior by participants include: * The use of sexualized language or imagery and unwelcome sexual attention or advances * Trolling, insulting/derogatory comments, and personal or political attacks * Public or private harassment * Publishing others' private information, such as a physical or electronic address, without explicit permission * Other conduct which could reasonably be considered inappropriate in a professional setting ## Our Responsibilities Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior. Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful. ## Scope This Code of Conduct applies within all project spaces, and it also applies when an individual is representing the project or its community in public spaces. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers. ## Enforcement Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at . 
All complaints will be reviewed and investigated and will result in a response that is deemed necessary and appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately. Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project's leadership. ## Attribution This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html [homepage]: https://www.contributor-covenant.org For answers to common questions about this code of conduct, see https://www.contributor-covenant.org/faq ent-0.5.4/CONTRIBUTING.md000066400000000000000000000050551377533537200145630ustar00rootroot00000000000000# Contributing to ent We want to make contributing to this project as easy and transparent as possible. # Project structure - `dialect` - Contains SQL and Gremlin code used by the generated code. - `dialect/sql/schema` - Auto migration logic resides there. - `schema` - User schema API. - `schema/{field, edge, index, mixin}` - provides schema builders API. - `schema/field/gen` - Templates and codegen for numeric builders. - `entc` - Codegen of `ent`. - `entc/load` - `entc` loader API for loading user schemas into a Go objects at runtime. - `entc/gen` - The actual code generation logic resides in this package (and its `templates` package). - `integration` - Integration tests for `entc`. - `privacy` - Runtime code for [privacy layer](https://entgo.io/docs/privacy/). - `doc` - Documentation code for `entgo.io` (uses [Docusaurus](https://docusaurus.io)). - `doc/md` - Markdown files for documentation. - `doc/website` - Website code and assets. 
In order to test your documentation changes, run `npm start` from the `doc/website` directory, and open [localhost:3000](http://localhost:3000/). # Run integration tests If you touch any file in `entc`, run the following command in `entc/integration`: ``` go generate ./... ``` Then, run `docker-compose` in order to spin-up all database containers: ``` docker-compose -f compose/docker-compose.yaml up -d --scale test=0 ``` Then, run `go test ./...` to run all integration tests. ## Pull Requests We actively welcome your pull requests. 1. Fork the repo and create your branch from `master`. 2. If you've added code that should be tested, add tests. 3. If you've changed APIs, update the documentation. 4. Ensure the test suite passes. 5. Make sure your code lints. 6. If you haven't already, complete the Contributor License Agreement ("CLA"). ## Contributor License Agreement ("CLA") In order to accept your pull request, we need you to submit a CLA. You only need to do this once to work on any of Facebook's open source projects. Complete your CLA here: ## Issues We use GitHub issues to track public bugs. Please ensure your description is clear and has sufficient instructions to be able to reproduce the issue. Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe disclosure of security bugs. In those cases, please go through the process outlined on that page and do not file a public issue. ## License By contributing to ent, you agree that your contributions will be licensed under the LICENSE file in the root directory of this source tree. ent-0.5.4/LICENSE000066400000000000000000000261351377533537200133410ustar00rootroot00000000000000 Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. 
"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. 
"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the 
Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. 
Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. 
To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.ent-0.5.4/README.md000066400000000000000000000043361377533537200136120ustar00rootroot00000000000000## ent - An Entity Framework For Go Simple, yet powerful entity framework for Go, that makes it easy to build and maintain applications with large data-models. - **Schema As Code** - model any database schema as Go objects. - **Easily Traverse Any Graph** - run queries, aggregations and traverse any graph structure easily. - **Statically Typed And Explicit API** - 100% statically typed and explicit API using code generation. - **Multi Storage Driver** - supports MySQL, PostgreSQL, SQLite and Gremlin. - **Extendable** - simple to extend and customize using Go templates. ## Quick Installation ```console go get github.com/facebook/ent/cmd/ent ``` For proper installation using [Go modules], visit [entgo.io website][entgo instal]. 
## Docs and Support The documentation for developing and using ent is available at: https://entgo.io For discussion and support, [open an issue](https://github.com/facebook/ent/issues/new/choose) or join our [channel](https://gophers.slack.com/archives/C01FMSQDT53) in the gophers Slack. ## Join the ent Community In order to contribute to `ent`, see the [CONTRIBUTING](CONTRIBUTING.md) file for how to go get started. If your company or your product is using `ent`, please let us know by adding yourself to the [ent users page](https://github.com/facebook/ent/wiki/ent-users). ## About the Project The `ent` project was inspired by Ent, an entity framework we use internally. It is developed and maintained by [a8m](https://github.com/a8m) and [alexsn](https://github.com/alexsn) from the [Facebook Connectivity][fbc] team. It is used by multiple teams and projects in production, and the roadmap for its v1 release is described [here](https://github.com/facebook/ent/issues/46). Read more about the motivation of the project [here](https://entgo.io/blog/2019/10/03/introducing-ent). ## License ent is licensed under Apache 2.0 as found in the [LICENSE file](LICENSE). [entgo instal]: https://entgo.io/docs/code-gen/#version-compatibility-between-entc-and-ent [Go modules]: https://github.com/golang/go/wiki/Modules#quick-start [fbc]: https://connectivity.fb.com ent-0.5.4/cmd/000077500000000000000000000000001377533537200130705ustar00rootroot00000000000000ent-0.5.4/cmd/ent/000077500000000000000000000000001377533537200136565ustar00rootroot00000000000000ent-0.5.4/cmd/ent/ent.go000066400000000000000000000007261377533537200150000ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. 
package main import ( "log" "github.com/facebook/ent/cmd/internal/base" "github.com/spf13/cobra" ) func main() { log.SetFlags(0) cmd := &cobra.Command{Use: "ent"} cmd.AddCommand( base.InitCmd(), base.DescribeCmd(), base.GenerateCmd(), ) _ = cmd.Execute() } ent-0.5.4/cmd/ent/ent_test.go000066400000000000000000000016371377533537200160410ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package main import ( "bytes" "os" "os/exec" "testing" "github.com/stretchr/testify/require" ) func TestCmd(t *testing.T) { defer os.RemoveAll("ent") cmd := exec.Command("go", "run", "github.com/facebook/ent/cmd/ent", "init", "User") stderr := bytes.NewBuffer(nil) cmd.Stderr = stderr require.NoError(t, cmd.Run(), stderr.String()) _, err := os.Stat("ent/generate.go") require.NoError(t, err) _, err = os.Stat("ent/schema/user.go") require.NoError(t, err) cmd = exec.Command("go", "run", "github.com/facebook/ent/cmd/ent", "generate", "./ent/schema") stderr = bytes.NewBuffer(nil) cmd.Stderr = stderr require.NoError(t, cmd.Run(), stderr.String()) _, err = os.Stat("ent/user.go") require.NoError(t, err) } ent-0.5.4/cmd/entc/000077500000000000000000000000001377533537200140215ustar00rootroot00000000000000ent-0.5.4/cmd/entc/entc.go000066400000000000000000000016141377533537200153030ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. 
package main import ( "bytes" "io/ioutil" "log" "path/filepath" "github.com/facebook/ent/cmd/internal/base" "github.com/facebook/ent/entc/gen" "github.com/spf13/cobra" ) func main() { log.SetFlags(0) cmd := &cobra.Command{Use: "entc"} cmd.AddCommand( base.InitCmd(), base.DescribeCmd(), base.GenerateCmd(migrate), ) _ = cmd.Execute() } func migrate(c *gen.Config) { var ( target = filepath.Join(c.Target, "generate.go") oldCmd = []byte("github.com/facebook/ent/cmd/entc") ) buf, err := ioutil.ReadFile(target) if err != nil || !bytes.Contains(buf, oldCmd) { return } _ = ioutil.WriteFile(target, bytes.ReplaceAll(buf, oldCmd, []byte("github.com/facebook/ent/cmd/ent")), 0644) } ent-0.5.4/cmd/entc/entc_test.go000066400000000000000000000016411377533537200163420ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package main import ( "bytes" "os" "os/exec" "testing" "github.com/stretchr/testify/require" ) func TestCmd(t *testing.T) { defer os.RemoveAll("ent") cmd := exec.Command("go", "run", "github.com/facebook/ent/cmd/entc", "init", "User") stderr := bytes.NewBuffer(nil) cmd.Stderr = stderr require.NoError(t, cmd.Run(), stderr.String()) _, err := os.Stat("ent/generate.go") require.NoError(t, err) _, err = os.Stat("ent/schema/user.go") require.NoError(t, err) cmd = exec.Command("go", "run", "github.com/facebook/ent/cmd/entc", "generate", "./ent/schema") stderr = bytes.NewBuffer(nil) cmd.Stderr = stderr require.NoError(t, cmd.Run(), stderr.String()) _, err = os.Stat("ent/user.go") require.NoError(t, err) } 
ent-0.5.4/cmd/internal/000077500000000000000000000000001377533537200147045ustar00rootroot00000000000000ent-0.5.4/cmd/internal/base/000077500000000000000000000000001377533537200156165ustar00rootroot00000000000000ent-0.5.4/cmd/internal/base/base.go000066400000000000000000000155261377533537200170700ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. // Package base defines shared basic pieces of the ent command. package base import ( "bytes" "errors" "fmt" "io/ioutil" "log" "os" "path/filepath" "strings" "text/template" "unicode" "github.com/facebook/ent/cmd/internal/printer" "github.com/facebook/ent/entc" "github.com/facebook/ent/entc/gen" "github.com/facebook/ent/schema/field" "github.com/spf13/cobra" ) // IDType is a custom ID implementation for pflag. type IDType field.Type // Set implements the Set method of the flag.Value interface. func (t *IDType) Set(s string) error { switch s { case field.TypeInt.String(): *t = IDType(field.TypeInt) case field.TypeInt64.String(): *t = IDType(field.TypeInt64) case field.TypeUint.String(): *t = IDType(field.TypeUint) case field.TypeUint64.String(): *t = IDType(field.TypeUint64) case field.TypeString.String(): *t = IDType(field.TypeString) default: return fmt.Errorf("invalid type %q", s) } return nil } // Type returns the type representation of the id option for help command. func (IDType) Type() string { return fmt.Sprintf("%v", []field.Type{ field.TypeInt, field.TypeInt64, field.TypeUint, field.TypeUint64, field.TypeString, }) } // String returns the default value for the help command. func (IDType) String() string { return field.TypeInt.String() } // InitCmd returns the init command for ent/c packages. 
func InitCmd() *cobra.Command { var target string cmd := &cobra.Command{ Use: "init [flags] [schemas]", Short: "initialize an environment with zero or more schemas", Example: examples( "ent init Example", "ent init --target entv1/schema User Group", ), Args: func(_ *cobra.Command, names []string) error { for _, name := range names { if !unicode.IsUpper(rune(name[0])) { return errors.New("schema names must begin with uppercase") } } return nil }, Run: func(cmd *cobra.Command, names []string) { if err := initEnv(target, names); err != nil { log.Fatalln(fmt.Errorf("ent/init: %w", err)) } }, } cmd.Flags().StringVar(&target, "target", defaultSchema, "target directory for schemas") return cmd } // DescribeCmd returns the describe command for ent/c packages. func DescribeCmd() *cobra.Command { return &cobra.Command{ Use: "describe [flags] path", Short: "printer a description of the graph schema", Example: examples( "ent describe ./ent/schema", "ent describe github.com/a8m/x", ), Args: cobra.ExactArgs(1), Run: func(cmd *cobra.Command, path []string) { graph, err := entc.LoadGraph(path[0], &gen.Config{}) if err != nil { log.Fatalln(err) } printer.Fprint(os.Stdout, graph) }, } } // GenerateCmd returns the generate command for ent/c packages. 
func GenerateCmd(postRun ...func(*gen.Config)) *cobra.Command { var ( cfg gen.Config storage string features []string templates []string idtype = IDType(field.TypeInt) cmd = &cobra.Command{ Use: "generate [flags] path", Short: "generate go code for the schema directory", Example: examples( "ent generate ./ent/schema", "ent generate github.com/a8m/x", ), Args: cobra.ExactArgs(1), Run: func(cmd *cobra.Command, path []string) { opts := []entc.Option{ entc.Storage(storage), entc.FeatureNames(features...), } for _, tmpl := range templates { typ := "dir" if parts := strings.SplitN(tmpl, "=", 2); len(parts) > 1 { typ, tmpl = parts[0], parts[1] } switch typ { case "dir": opts = append(opts, entc.TemplateDir(tmpl)) case "file": opts = append(opts, entc.TemplateFiles(tmpl)) case "glob": opts = append(opts, entc.TemplateGlob(tmpl)) default: log.Fatalln("unsupported template type", typ) } } // If the target directory is not inferred from // the schema path, resolve its package path. if cfg.Target != "" { pkgPath, err := PkgPath(DefaultConfig, cfg.Target) if err != nil { log.Fatalln(err) } cfg.Package = pkgPath } cfg.IDType = &field.TypeInfo{Type: field.Type(idtype)} if err := entc.Generate(path[0], &cfg, opts...); err != nil { log.Fatalln(err) } for _, fn := range postRun { fn(&cfg) } }, } ) cmd.Flags().Var(&idtype, "idtype", "type of the id field") cmd.Flags().StringVar(&storage, "storage", "sql", "storage driver to support in codegen") cmd.Flags().StringVar(&cfg.Header, "header", "", "override codegen header") cmd.Flags().StringVar(&cfg.Target, "target", "", "target directory for codegen") cmd.Flags().StringSliceVarP(&features, "feature", "", nil, "extend codegen with additional features") cmd.Flags().StringSliceVarP(&templates, "template", "", nil, "external templates to execute") return cmd } // initEnv initialize an environment for ent codegen. 
func initEnv(target string, names []string) error { if err := createDir(target); err != nil { return fmt.Errorf("create dir %s: %w", target, err) } for _, name := range names { if err := gen.ValidSchemaName(name); err != nil { return fmt.Errorf("init schema %s: %w", name, err) } b := bytes.NewBuffer(nil) if err := tmpl.Execute(b, name); err != nil { return fmt.Errorf("executing template %s: %w", name, err) } newFileTarget := filepath.Join(target, strings.ToLower(name+".go")) if err := ioutil.WriteFile(newFileTarget, b.Bytes(), 0644); err != nil { return fmt.Errorf("writing file %s: %w", newFileTarget, err) } } return nil } func createDir(target string) error { _, err := os.Stat(target) if err == nil || !os.IsNotExist(err) { return err } if err := os.MkdirAll(target, os.ModePerm); err != nil { return fmt.Errorf("creating schema directory: %w", err) } if target != defaultSchema { return nil } if err := ioutil.WriteFile("ent/generate.go", []byte(genFile), 0644); err != nil { return fmt.Errorf("creating generate.go file: %w", err) } return nil } // schema template for the "init" command. var tmpl = template.Must(template.New("schema"). Parse(`package schema import "github.com/facebook/ent" // {{ . }} holds the schema definition for the {{ . }} entity. type {{ . }} struct { ent.Schema } // Fields of the {{ . }}. func ({{ . }}) Fields() []ent.Field { return nil } // Edges of the {{ . }}. func ({{ . }}) Edges() []ent.Edge { return nil } `)) const ( // default schema package path. defaultSchema = "ent/schema" // ent/generate.go file used for "go generate" command. genFile = "package ent\n\n//go:generate go run github.com/facebook/ent/cmd/ent generate ./schema\n" ) // examples formats the given examples to the cli. func examples(ex ...string) string { for i := range ex { ex[i] = " " + ex[i] // indent each row with 2 spaces. 
} return strings.Join(ex, "\n") } ent-0.5.4/cmd/internal/base/packages.go000066400000000000000000000030501377533537200177210ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package base import ( "fmt" "os" "path" "path/filepath" "golang.org/x/tools/go/packages" ) // DefaultConfig for loading Go base. var DefaultConfig = &packages.Config{Mode: packages.NeedName} // PkgPath returns the Go package name for given target path. // Even if the existing path is not exist yet in the filesystem. // // If base.Config is nil, DefaultConfig will be used to load base. func PkgPath(config *packages.Config, target string) (string, error) { if config == nil { config = DefaultConfig } pathCheck, err := filepath.Abs(target) if err != nil { return "", err } var parts []string if _, err := os.Stat(pathCheck); os.IsNotExist(err) { parts = append(parts, filepath.Base(pathCheck)) pathCheck = filepath.Dir(pathCheck) } // Try maximum 2 directories above the given // target to find the root package or module. for i := 0; i < 2; i++ { pkgs, err := packages.Load(config, pathCheck) if err != nil { return "", fmt.Errorf("load package info: %v", err) } if len(pkgs) == 0 || len(pkgs[0].Errors) != 0 { parts = append(parts, filepath.Base(pathCheck)) pathCheck = filepath.Dir(pathCheck) continue } pkgPath := pkgs[0].PkgPath for j := len(parts) - 1; j >= 0; j-- { pkgPath = path.Join(pkgPath, parts[j]) } return pkgPath, nil } return "", fmt.Errorf("root package or module was not found for: %s", target) } ent-0.5.4/cmd/internal/base/packages_test.go000066400000000000000000000026261377533537200207700ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. 
package base import ( "path/filepath" "testing" "github.com/stretchr/testify/require" "golang.org/x/tools/go/packages/packagestest" ) func TestPkgPath(t *testing.T) { packagestest.TestAll(t, testPkgPath) } func testPkgPath(t *testing.T, x packagestest.Exporter) { e := packagestest.Export(t, x, []packagestest.Module{ { Name: "golang.org/x", Files: map[string]interface{}{ "x.go": "package x", "y/y.go": "package y", }, }, }) defer e.Cleanup() e.Config.Dir = filepath.Dir(e.File("golang.org/x", "x.go")) target := filepath.Join(e.Config.Dir, "ent") pkgPath, err := PkgPath(e.Config, target) require.NoError(t, err) require.Equal(t, "golang.org/x/ent", pkgPath) e.Config.Dir = filepath.Dir(e.File("golang.org/x", "y/y.go")) target = filepath.Join(e.Config.Dir, "ent") pkgPath, err = PkgPath(e.Config, target) require.NoError(t, err) require.Equal(t, "golang.org/x/y/ent", pkgPath) target = filepath.Join(e.Config.Dir, "z/ent") pkgPath, err = PkgPath(e.Config, target) require.NoError(t, err) require.Equal(t, "golang.org/x/y/z/ent", pkgPath) target = filepath.Join(e.Config.Dir, "z/e/n/t") pkgPath, err = PkgPath(e.Config, target) require.Error(t, err) require.Empty(t, pkgPath) } ent-0.5.4/cmd/internal/printer/000077500000000000000000000000001377533537200163675ustar00rootroot00000000000000ent-0.5.4/cmd/internal/printer/printer.go000066400000000000000000000040701377533537200204020ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package printer import ( "fmt" "io" "reflect" "strconv" "strings" "github.com/facebook/ent/entc/gen" "github.com/olekukonko/tablewriter" ) // A Config controls the output of Fprint. type Config struct { io.Writer } // Print prints a table description of the graph to the given writer. 
func (p Config) Print(g *gen.Graph) { for _, n := range g.Nodes { p.node(n) } } // Fprint executes "pretty-printer" on the given writer. func Fprint(w io.Writer, g *gen.Graph) { Config{Writer: w}.Print(g) } // node returns description of a type. The format of the description is: // // Type: // // // // func (p Config) node(t *gen.Type) { var ( b strings.Builder table = tablewriter.NewWriter(&b) header = []string{"Field", "Type", "Unique", "Optional", "Nillable", "Default", "UpdateDefault", "Immutable", "StructTag", "Validators"} ) b.WriteString(t.Name + ":\n") table.SetAutoFormatHeaders(false) table.SetHeader(header) for _, f := range append([]*gen.Field{t.ID}, t.Fields...) { v := reflect.ValueOf(*f) row := make([]string, len(header)) for i := range row { field := v.FieldByNameFunc(func(name string) bool { // The first field is mapped from "Name" to "Field". return name == "Name" && i == 0 || name == header[i] }) row[i] = fmt.Sprint(field.Interface()) } table.Append(row) } table.Render() table = tablewriter.NewWriter(&b) table.SetAutoFormatHeaders(false) table.SetHeader([]string{"Edge", "Type", "Inverse", "BackRef", "Relation", "Unique", "Optional"}) for _, e := range t.Edges { table.Append([]string{ e.Name, e.Type.Name, strconv.FormatBool(e.IsInverse()), e.Inverse, e.Rel.Type.String(), strconv.FormatBool(e.Unique), strconv.FormatBool(e.Optional), }) } if table.NumLines() > 0 { table.Render() } io.WriteString(p, strings.ReplaceAll(b.String(), "\n", "\n\t")+"\n") } ent-0.5.4/cmd/internal/printer/printer_test.go000066400000000000000000000207371377533537200214510ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. 
package printer import ( "strings" "testing" "github.com/facebook/ent/entc/gen" "github.com/facebook/ent/schema/field" "github.com/stretchr/testify/assert" ) func TestPrinter_Print(t *testing.T) { tests := []struct { input *gen.Graph out string }{ { input: &gen.Graph{ Nodes: []*gen.Type{ { Name: "User", ID: &gen.Field{Name: "id", Type: &field.TypeInfo{Type: field.TypeInt}}, Fields: []*gen.Field{ {Name: "name", Type: &field.TypeInfo{Type: field.TypeString}, Validators: 1}, {Name: "age", Type: &field.TypeInfo{Type: field.TypeInt}, Nillable: true}, {Name: "created_at", Type: &field.TypeInfo{Type: field.TypeTime}, Nillable: true, Immutable: true}, }, }, }, }, out: ` User: +------------+-----------+--------+----------+----------+---------+---------------+-----------+-----------+------------+ | Field | Type | Unique | Optional | Nillable | Default | UpdateDefault | Immutable | StructTag | Validators | +------------+-----------+--------+----------+----------+---------+---------------+-----------+-----------+------------+ | id | int | false | false | false | false | false | false | | 0 | | name | string | false | false | false | false | false | false | | 1 | | age | int | false | false | true | false | false | false | | 0 | | created_at | time.Time | false | false | true | false | false | true | | 0 | +------------+-----------+--------+----------+----------+---------+---------------+-----------+-----------+------------+ `, }, { input: &gen.Graph{ Nodes: []*gen.Type{ { Name: "User", ID: &gen.Field{Name: "id", Type: &field.TypeInfo{Type: field.TypeInt}}, Edges: []*gen.Edge{ {Name: "groups", Type: &gen.Type{Name: "Group"}, Rel: gen.Relation{Type: gen.M2M}, Optional: true}, {Name: "spouse", Type: &gen.Type{Name: "User"}, Unique: true, Rel: gen.Relation{Type: gen.O2O}}, }, }, }, }, out: ` User: +-------+------+--------+----------+----------+---------+---------------+-----------+-----------+------------+ | Field | Type | Unique | Optional | Nillable | Default | UpdateDefault | 
Immutable | StructTag | Validators | +-------+------+--------+----------+----------+---------+---------------+-----------+-----------+------------+ | id | int | false | false | false | false | false | false | | 0 | +-------+------+--------+----------+----------+---------+---------------+-----------+-----------+------------+ +--------+-------+---------+---------+----------+--------+----------+ | Edge | Type | Inverse | BackRef | Relation | Unique | Optional | +--------+-------+---------+---------+----------+--------+----------+ | groups | Group | false | | M2M | false | true | | spouse | User | false | | O2O | true | false | +--------+-------+---------+---------+----------+--------+----------+ `, }, { input: &gen.Graph{ Nodes: []*gen.Type{ { Name: "User", ID: &gen.Field{Name: "id", Type: &field.TypeInfo{Type: field.TypeInt}}, Fields: []*gen.Field{ {Name: "name", Type: &field.TypeInfo{Type: field.TypeString}, Validators: 1}, {Name: "age", Type: &field.TypeInfo{Type: field.TypeInt}, Nillable: true}, }, Edges: []*gen.Edge{ {Name: "groups", Type: &gen.Type{Name: "Group"}, Rel: gen.Relation{Type: gen.M2M}, Optional: true}, {Name: "spouse", Type: &gen.Type{Name: "User"}, Unique: true, Rel: gen.Relation{Type: gen.O2O}}, }, }, }, }, out: ` User: +-------+--------+--------+----------+----------+---------+---------------+-----------+-----------+------------+ | Field | Type | Unique | Optional | Nillable | Default | UpdateDefault | Immutable | StructTag | Validators | +-------+--------+--------+----------+----------+---------+---------------+-----------+-----------+------------+ | id | int | false | false | false | false | false | false | | 0 | | name | string | false | false | false | false | false | false | | 1 | | age | int | false | false | true | false | false | false | | 0 | +-------+--------+--------+----------+----------+---------+---------------+-----------+-----------+------------+ +--------+-------+---------+---------+----------+--------+----------+ | Edge | Type | 
Inverse | BackRef | Relation | Unique | Optional | +--------+-------+---------+---------+----------+--------+----------+ | groups | Group | false | | M2M | false | true | | spouse | User | false | | O2O | true | false | +--------+-------+---------+---------+----------+--------+----------+ `, }, { input: &gen.Graph{ Nodes: []*gen.Type{ { Name: "User", ID: &gen.Field{Name: "id", Type: &field.TypeInfo{Type: field.TypeInt}}, Fields: []*gen.Field{ {Name: "name", Type: &field.TypeInfo{Type: field.TypeString}, Validators: 1}, {Name: "age", Type: &field.TypeInfo{Type: field.TypeInt}, Nillable: true}, }, Edges: []*gen.Edge{ {Name: "groups", Type: &gen.Type{Name: "Group"}, Rel: gen.Relation{Type: gen.M2M}, Optional: true}, {Name: "spouse", Type: &gen.Type{Name: "User"}, Unique: true, Rel: gen.Relation{Type: gen.O2O}}, }, }, { Name: "Group", ID: &gen.Field{Name: "id", Type: &field.TypeInfo{Type: field.TypeInt}}, Fields: []*gen.Field{ {Name: "name", Type: &field.TypeInfo{Type: field.TypeString}}, }, Edges: []*gen.Edge{ {Name: "users", Type: &gen.Type{Name: "User"}, Rel: gen.Relation{Type: gen.M2M}, Optional: true}, }, }, }, }, out: ` User: +-------+--------+--------+----------+----------+---------+---------------+-----------+-----------+------------+ | Field | Type | Unique | Optional | Nillable | Default | UpdateDefault | Immutable | StructTag | Validators | +-------+--------+--------+----------+----------+---------+---------------+-----------+-----------+------------+ | id | int | false | false | false | false | false | false | | 0 | | name | string | false | false | false | false | false | false | | 1 | | age | int | false | false | true | false | false | false | | 0 | +-------+--------+--------+----------+----------+---------+---------------+-----------+-----------+------------+ +--------+-------+---------+---------+----------+--------+----------+ | Edge | Type | Inverse | BackRef | Relation | Unique | Optional | 
+--------+-------+---------+---------+----------+--------+----------+ | groups | Group | false | | M2M | false | true | | spouse | User | false | | O2O | true | false | +--------+-------+---------+---------+----------+--------+----------+ Group: +-------+--------+--------+----------+----------+---------+---------------+-----------+-----------+------------+ | Field | Type | Unique | Optional | Nillable | Default | UpdateDefault | Immutable | StructTag | Validators | +-------+--------+--------+----------+----------+---------+---------------+-----------+-----------+------------+ | id | int | false | false | false | false | false | false | | 0 | | name | string | false | false | false | false | false | false | | 0 | +-------+--------+--------+----------+----------+---------+---------------+-----------+-----------+------------+ +-------+------+---------+---------+----------+--------+----------+ | Edge | Type | Inverse | BackRef | Relation | Unique | Optional | +-------+------+---------+---------+----------+--------+----------+ | users | User | false | | M2M | false | true | +-------+------+---------+---------+----------+--------+----------+ `, }, } for _, tt := range tests { b := &strings.Builder{} Fprint(b, tt.input) assert.Equal(t, tt.out, "\n"+b.String()) } } ent-0.5.4/dialect/000077500000000000000000000000001377533537200137325ustar00rootroot00000000000000ent-0.5.4/dialect/dialect.go000066400000000000000000000114421377533537200156700ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package dialect import ( "context" "database/sql/driver" "fmt" "log" "github.com/google/uuid" ) // Dialect names for external usage. const ( MySQL = "mysql" SQLite = "sqlite3" Postgres = "postgres" Gremlin = "gremlin" ) // ExecQuerier wraps the 2 database operations. 
type ExecQuerier interface { // Exec executes a query that doesn't return rows. For example, in SQL, INSERT or UPDATE. // It scans the result into the pointer v. In SQL, you it's usually sql.Result. Exec(ctx context.Context, query string, args, v interface{}) error // Query executes a query that returns rows, typically a SELECT in SQL. // It scans the result into the pointer v. In SQL, you it's usually *sql.Rows. Query(ctx context.Context, query string, args, v interface{}) error } // Driver is the interface that wraps all necessary operations for ent clients. type Driver interface { ExecQuerier // Tx starts and returns a new transaction. // The provided context is used until the transaction is committed or rolled back. Tx(context.Context) (Tx, error) // Close closes the underlying connection. Close() error // Dialect returns the dialect name of the driver. Dialect() string } // Tx wraps the Exec and Query operations in transaction. type Tx interface { ExecQuerier driver.Tx } type nopTx struct { Driver } func (nopTx) Commit() error { return nil } func (nopTx) Rollback() error { return nil } // NopTx returns a Tx with a no-op Commit / Rollback methods wrapping // the provided Driver d. func NopTx(d Driver) Tx { return nopTx{d} } // DebugDriver is a driver that logs all driver operations. type DebugDriver struct { Driver // underlying driver. log func(context.Context, ...interface{}) // log function. defaults to log.Println. } // Debug gets a driver and an optional logging function, and returns // a new debugged-driver that prints all outgoing operations. func Debug(d Driver, logger ...func(...interface{})) Driver { logf := log.Println if len(logger) == 1 { logf = logger[0] } drv := &DebugDriver{d, func(_ context.Context, v ...interface{}) { logf(v...) }} return drv } // DebugWithContext gets a driver and a logging function, and returns // a new debugged-driver that prints all outgoing operations with context. 
func DebugWithContext(d Driver, logger func(context.Context, ...interface{})) Driver { drv := &DebugDriver{d, logger} return drv } // Exec logs its params and calls the underlying driver Exec method. func (d *DebugDriver) Exec(ctx context.Context, query string, args, v interface{}) error { d.log(ctx, fmt.Sprintf("driver.Exec: query=%v args=%v", query, args)) return d.Driver.Exec(ctx, query, args, v) } // Query logs its params and calls the underlying driver Query method. func (d *DebugDriver) Query(ctx context.Context, query string, args, v interface{}) error { d.log(ctx, fmt.Sprintf("driver.Query: query=%v args=%v", query, args)) return d.Driver.Query(ctx, query, args, v) } // Tx adds an log-id for the transaction and calls the underlying driver Tx command. func (d *DebugDriver) Tx(ctx context.Context) (Tx, error) { tx, err := d.Driver.Tx(ctx) if err != nil { return nil, err } id := uuid.New().String() d.log(ctx, fmt.Sprintf("driver.Tx(%s): started", id)) return &DebugTx{tx, id, d.log, ctx}, nil } // DebugTx is a transaction implementation that logs all transaction operations. type DebugTx struct { Tx // underlying transaction. id string // transaction logging id. log func(context.Context, ...interface{}) // log function. defaults to fmt.Println. ctx context.Context // underlying transaction context. } // Exec logs its params and calls the underlying transaction Exec method. func (d *DebugTx) Exec(ctx context.Context, query string, args, v interface{}) error { d.log(ctx, fmt.Sprintf("Tx(%s).Exec: query=%v args=%v", d.id, query, args)) return d.Tx.Exec(ctx, query, args, v) } // Query logs its params and calls the underlying transaction Query method. func (d *DebugTx) Query(ctx context.Context, query string, args, v interface{}) error { d.log(ctx, fmt.Sprintf("Tx(%s).Query: query=%v args=%v", d.id, query, args)) return d.Tx.Query(ctx, query, args, v) } // Commit logs this step and calls the underlying transaction Commit method. 
func (d *DebugTx) Commit() error { d.log(d.ctx, fmt.Sprintf("Tx(%s): committed", d.id)) return d.Tx.Commit() } // Rollback logs this step and calls the underlying transaction Rollback method. func (d *DebugTx) Rollback() error { d.log(d.ctx, fmt.Sprintf("Tx(%s): rollbacked", d.id)) return d.Tx.Rollback() } ent-0.5.4/dialect/entsql/000077500000000000000000000000001377533537200152405ustar00rootroot00000000000000ent-0.5.4/dialect/entsql/annotation.go000066400000000000000000000046611377533537200177500ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package entsql import "github.com/facebook/ent/schema" // Annotation is a builtin schema annotation for attaching // SQL metadata to schema objects for both codegen and runtime. type Annotation struct { // The Table option allows overriding the default table // name that is generated by ent. For example: // // entsql.Annotation{ // Table: "Users", // } // Table string `json:"table,omitempty"` // Charset defines the character-set of the table. For example: // // entsql.Annotation{ // Charset: "utf8mb4", // } // Charset string `json:"charset,omitempty"` // Collation defines the collation of the table (a set of rules for comparing // characters in a character set). For example: // // entsql.Annotation{ // Collation: "utf8mb4_bin", // } // Collation string `json:"collation,omitempty"` // Options defines the additional table options. For example: // // entsql.Annotation{ // Options: "ENGINE = INNODB", // } // Options string `json:"options,omitempty"` // Size defines the column size in the generated schema. For example: // // entsql.Annotation{ // Size: 128, // } // Size int64 `json:"size,omitempty"` // Incremental defines the autoincremental behavior of a column. 
For example: // // incrementalEnabled := true // entsql.Annotation{ // Incremental: &incrementalEnabled, // } // // By default, this value is nil defaulting to whatever best fits each scenario. // Incremental *bool `json:"incremental,omitempty"` } // Name describes the annotation name. func (Annotation) Name() string { return "EntSQL" } // Merge implements the schema.Merger interface. func (a Annotation) Merge(other schema.Annotation) schema.Annotation { var ant Annotation switch other := other.(type) { case Annotation: ant = other case *Annotation: if other != nil { ant = *other } default: return a } if t := ant.Table; t != "" { a.Table = t } if c := ant.Charset; c != "" { a.Charset = c } if c := ant.Collation; c != "" { a.Collation = c } if o := ant.Options; o != "" { a.Options = o } if s := ant.Size; s != 0 { a.Size = s } if s := ant.Incremental; s != nil { a.Incremental = s } return a } var ( _ schema.Annotation = (*Annotation)(nil) _ schema.Merger = (*Annotation)(nil) ) ent-0.5.4/dialect/gremlin/000077500000000000000000000000001377533537200153675ustar00rootroot00000000000000ent-0.5.4/dialect/gremlin/client.go000066400000000000000000000045331377533537200172010ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package gremlin import ( "context" "fmt" "net/http" ) // RoundTripper is an interface representing the ability to execute a // single gremlin transaction, obtaining the Response for a given Request. type RoundTripper interface { RoundTrip(context.Context, *Request) (*Response, error) } // The RoundTripperFunc type is an adapter to allow the use of ordinary functions as Gremlin RoundTripper. type RoundTripperFunc func(context.Context, *Request) (*Response, error) // RoundTrip calls f(ctx, r). 
func (f RoundTripperFunc) RoundTrip(ctx context.Context, r *Request) (*Response, error) { return f(ctx, r) } // Interceptor provides a hook to intercept the execution of a Gremlin Request. type Interceptor func(RoundTripper) RoundTripper // A Client is a gremlin client. type Client struct { // Transport specifies the mechanism by which individual // Gremlin requests are made. Transport RoundTripper } // MaxResponseSize defines the maximum response size allowed. const MaxResponseSize = 2 << 20 // NewClient creates a gremlin client from config and options. func NewClient(cfg Config, opt ...Option) (*Client, error) { return cfg.Build(opt...) } // NewHTTPClient creates an http based gremlin client. func NewHTTPClient(url string, client *http.Client) (*Client, error) { transport, err := NewHTTPTransport(url, client) if err != nil { return nil, err } return &Client{transport}, nil } // Do sends a gremlin request and returns a gremlin response. func (c Client) Do(ctx context.Context, req *Request) (*Response, error) { rsp, err := c.Transport.RoundTrip(ctx, req) if err == nil { err = rsp.Err() } // If we got an error, and the context has been canceled, // the context's error is probably more useful. if err != nil && ctx.Err() != nil { err = ctx.Err() } return rsp, err } // Query issues an eval request via the Do function. func (c Client) Query(ctx context.Context, query string) (*Response, error) { return c.Do(ctx, NewEvalRequest(query)) } // Queryf formats a query string and invokes Query. func (c Client) Queryf(ctx context.Context, format string, args ...interface{}) (*Response, error) { return c.Query(ctx, fmt.Sprintf(format, args...)) } ent-0.5.4/dialect/gremlin/client_test.go000066400000000000000000000044721377533537200202420ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. 
package gremlin import ( "context" "io" "net/url" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" ) func TestNewClient(t *testing.T) { var cfg Config cfg.Endpoint.URL, _ = url.Parse("http://gremlin-server/gremlin") c, err := NewClient(cfg) assert.NotNil(t, c) assert.NoError(t, err) } type mockRoundTripper struct{ mock.Mock } func (m *mockRoundTripper) RoundTrip(ctx context.Context, req *Request) (*Response, error) { args := m.Called(ctx, req) return args.Get(0).(*Response), args.Error(1) } func TestClientRequest(t *testing.T) { ctx := context.Background() req, rsp := &Request{}, &Response{} var m mockRoundTripper m.On("RoundTrip", ctx, req). Run(func(mock.Arguments) { rsp.Status.Code = StatusSuccess }). Return(rsp, nil). Once() defer m.AssertExpectations(t) response, err := Client{&m}.Do(context.Background(), req) assert.NoError(t, err) assert.Equal(t, rsp, response) } func TestClientResponseError(t *testing.T) { rsp := &Response{} var m mockRoundTripper m.On("RoundTrip", mock.Anything, mock.Anything). Run(func(mock.Arguments) { rsp.Status.Code = StatusServerError }). Return(rsp, nil). Once() defer m.AssertExpectations(t) _, err := Client{&m}.Do(context.Background(), nil) assert.Error(t, err) } func TestClientCanceledContext(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) var m mockRoundTripper m.On("RoundTrip", ctx, mock.Anything). Run(func(mock.Arguments) { cancel() }). Return(&Response{}, io.ErrUnexpectedEOF). Once() defer m.AssertExpectations(t) _, err := Client{&m}.Query(ctx, "g.E()") assert.EqualError(t, err, context.Canceled.Error()) } func TestClientQuery(t *testing.T) { rsp := &Response{} rsp.Status.Code = StatusNoContent var m mockRoundTripper m.On("RoundTrip", mock.Anything, mock.Anything). Run(func(args mock.Arguments) { req := args.Get(1).(*Request) assert.Equal(t, "g.V(1)", req.Arguments[ArgsGremlin]) }). Return(rsp, nil). 
Once() defer m.AssertExpectations(t) rsp, err := Client{&m}.Queryf(context.Background(), "g.V(%d)", 1) assert.NotNil(t, rsp) assert.NoError(t, err) } ent-0.5.4/dialect/gremlin/config.go000066400000000000000000000040361377533537200171660ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package gremlin import ( "fmt" "net/http" "net/url" ) type ( // Config offers a declarative way to construct a client. Config struct { Endpoint Endpoint `env:"ENDPOINT" long:"endpoint" default:"" description:"gremlin endpoint to connect to"` DisableExpansion bool `env:"DISABLE_EXPANSION" long:"disable-expansion" description:"disable bindings expansion"` } // An Option configured client. Option func(*options) options struct { interceptors []Interceptor httpClient *http.Client } // Endpoint wraps a url to add flag unmarshaling. Endpoint struct { *url.URL } ) // WithInterceptor adds interceptors to the client's transport. func WithInterceptor(interceptors ...Interceptor) Option { return func(opts *options) { opts.interceptors = append(opts.interceptors, interceptors...) } } // WithHTTPClient assigns underlying http client to be used by http transport. func WithHTTPClient(client *http.Client) Option { return func(opts *options) { opts.httpClient = client } } // Build constructs a client from Config. 
func (cfg Config) Build(opt ...Option) (c *Client, err error) { opts := cfg.buildOptions(opt) switch cfg.Endpoint.Scheme { case "http", "https": c, err = NewHTTPClient(cfg.Endpoint.String(), opts.httpClient) default: err = fmt.Errorf("unsupported endpoint scheme: %s", cfg.Endpoint.Scheme) } if err != nil { return nil, err } for i := len(opts.interceptors) - 1; i >= 0; i-- { c.Transport = opts.interceptors[i](c.Transport) } if !cfg.DisableExpansion { c.Transport = ExpandBindings(c.Transport) } return c, nil } func (Config) buildOptions(opts []Option) options { var o options for _, opt := range opts { opt(&o) } return o } // UnmarshalFlag implements flag.Unmarshaler interface. func (ep *Endpoint) UnmarshalFlag(value string) (err error) { ep.URL, err = url.Parse(value) return } ent-0.5.4/dialect/gremlin/config_test.go000066400000000000000000000067321377533537200202320ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. 
package gremlin import ( "context" "errors" "net/http" "net/url" "testing" "github.com/jessevdk/go-flags" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) func TestConfigParsing(t *testing.T) { var cfg Config _, err := flags.ParseArgs(&cfg, []string{ "--disable-expansion", "--endpoint", "http://localhost:8182/gremlin", }) assert.NoError(t, err) assert.True(t, cfg.DisableExpansion) assert.Equal(t, "http", cfg.Endpoint.Scheme) assert.Equal(t, "http://localhost:8182/gremlin", cfg.Endpoint.String()) cfg = Config{} _, err = flags.ParseArgs(&cfg, nil) assert.NoError(t, err) assert.NotNil(t, cfg.Endpoint.URL) } func TestConfigBuild(t *testing.T) { tests := []struct { name string cfg Config opts []Option wantErr bool }{ { name: "HTTP", cfg: Config{ Endpoint: Endpoint{ URL: func() *url.URL { u, _ := url.Parse("http://gremlin-server/gremlin") return u }(), }, }, }, { name: "NoScheme", cfg: Config{ Endpoint: Endpoint{ URL: &url.URL{}, }, }, wantErr: true, }, { name: "BadScheme", cfg: Config{ Endpoint: Endpoint{ URL: &url.URL{ Scheme: "bad", }, }, }, wantErr: true, }, { name: "WithOptions", cfg: Config{ Endpoint: Endpoint{ URL: func() *url.URL { u, _ := url.Parse("http://gremlin-server/gremlin") return u }(), }, DisableExpansion: true, }, opts: []Option{WithHTTPClient(&http.Client{})}, }, { name: "NoExpansion", cfg: Config{ Endpoint: Endpoint{ URL: &url.URL{ Scheme: "bad", }, }, DisableExpansion: true, }, wantErr: true, }, } for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { client, err := tc.cfg.Build(tc.opts...) 
if !tc.wantErr { assert.NotNil(t, client) assert.NoError(t, err) } else { assert.Error(t, err) } }) } } type testRoundTripper struct{ mock.Mock } func (rt *testRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { args := rt.Called(req) rsp, _ := args.Get(0).(*http.Response) return rsp, args.Error(1) } func TestBuildWithHTTPClient(t *testing.T) { var transport testRoundTripper transport.On("RoundTrip", mock.Anything). Return(nil, errors.New("noop")). Once() defer transport.AssertExpectations(t) u, err := url.Parse("http://gremlin-server:8182/gremlin") require.NoError(t, err) client, err := Config{Endpoint: Endpoint{u}}. Build(WithHTTPClient(&http.Client{Transport: &transport})) require.NoError(t, err) _, _ = client.Do(context.Background(), NewEvalRequest("g.V()")) } func TestExpandOrdering(t *testing.T) { var cfg Config cfg.Endpoint.URL, _ = url.Parse("http://gremlin-server/gremlin") interceptor := func(RoundTripper) RoundTripper { return RoundTripperFunc(func(ctx context.Context, req *Request) (*Response, error) { assert.Equal(t, `g.V().hasLabel("user")`, req.Arguments[ArgsGremlin]) assert.Nil(t, req.Arguments[ArgsBindings]) return nil, errors.New("noop") }) } c, err := cfg.Build(WithInterceptor(interceptor)) require.NoError(t, err) req := NewEvalRequest("g.V().hasLabel($1)", WithBindings(map[string]interface{}{"$1": "user"})) _, _ = c.Do(context.Background(), req) } ent-0.5.4/dialect/gremlin/driver.go000066400000000000000000000032601377533537200172120ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package gremlin import ( "context" "fmt" "github.com/facebook/ent/dialect" "github.com/facebook/ent/dialect/gremlin/graph/dsl" ) // Driver is a dialect.Driver implementation for TinkerPop gremlin. 
type Driver struct { *Client } // NewDriver returns a new dialect.Driver implementation for gremlin. func NewDriver(c *Client) *Driver { c.Transport = ExpandBindings(c.Transport) return &Driver{c} } // Dialect implements the dialect.Dialect method. func (Driver) Dialect() string { return dialect.Gremlin } // Exec implements the dialect.Exec method. func (c *Driver) Exec(ctx context.Context, query string, args, v interface{}) error { vr, ok := v.(*Response) if !ok { return fmt.Errorf("dialect/gremlin: invalid type %T. expect *gremlin.Response", v) } bindings, ok := args.(dsl.Bindings) if !ok { return fmt.Errorf("dialect/gremlin: invalid type %T. expect map[string]interface{} for bindings", args) } res, err := c.Do(ctx, NewEvalRequest(query, WithBindings(bindings))) if err != nil { return err } *vr = *res return nil } // Query implements the dialect.Query method. func (c *Driver) Query(ctx context.Context, query string, args, v interface{}) error { return c.Exec(ctx, query, args, v) } // Close is a nop close call. It should close the connection in case of WS client. func (Driver) Close() error { return nil } // Tx returns a nop transaction. func (c *Driver) Tx(context.Context) (dialect.Tx, error) { return dialect.NopTx(c), nil } var _ dialect.Driver = (*Driver)(nil) ent-0.5.4/dialect/gremlin/encoding/000077500000000000000000000000001377533537200171555ustar00rootroot00000000000000ent-0.5.4/dialect/gremlin/encoding/graphson/000077500000000000000000000000001377533537200207765ustar00rootroot00000000000000ent-0.5.4/dialect/gremlin/encoding/graphson/bench_test.go000066400000000000000000000034001377533537200234400ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. 
package graphson import ( "testing" jsoniter "github.com/json-iterator/go" ) type book struct { ID string `json:"id" graphson:"g:UUID"` Title string `json:"title"` Author string `json:"author"` Pages int `json:"num_pages"` Chapters []string `json:"chapters"` } func generateObject() *book { return &book{ ID: "21d5dcbf-1fd4-493e-9b74-d6c429f9e4a5", Title: "The Art of Computer Programming, Vol. 2", Author: "Donald E. Knuth", Pages: 784, Chapters: []string{"Random numbers", "Arithmetic"}, } } func BenchmarkMarshalObject(b *testing.B) { obj := generateObject() b.ResetTimer() for n := 0; n < b.N; n++ { _, err := Marshal(obj) if err != nil { b.Fatal(err) } } } func BenchmarkUnmarshalObject(b *testing.B) { out, err := Marshal(generateObject()) if err != nil { b.Fatal(err) } obj := &book{} b.ResetTimer() for n := 0; n < b.N; n++ { err = Unmarshal(out, obj) if err != nil { b.Fatal(err) } } } func BenchmarkMarshalInterface(b *testing.B) { data, err := jsoniter.Marshal(generateObject()) if err != nil { b.Fatal(err) } var obj interface{} if err = jsoniter.Unmarshal(data, &obj); err != nil { b.Fatal(err) } b.ResetTimer() for n := 0; n < b.N; n++ { _, err = Marshal(obj) if err != nil { b.Fatal(err) } } } func BenchmarkUnmarshalInterface(b *testing.B) { data, err := Marshal(generateObject()) if err != nil { b.Fatal(err) } var obj interface{} b.ResetTimer() for n := 0; n < b.N; n++ { err = Unmarshal(data, &obj) if err != nil { b.Fatal(err) } } } ent-0.5.4/dialect/gremlin/encoding/graphson/common_test.go000066400000000000000000000024601377533537200236560ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graphson import ( "unsafe" jsoniter "github.com/json-iterator/go" "github.com/stretchr/testify/mock" ) type mocker struct { mock.Mock } // Encode belongs to jsoniter.ValEncoder interface. 
func (m *mocker) Encode(ptr unsafe.Pointer, stream *jsoniter.Stream) { m.Called(ptr, stream) } // IsEmpty belongs to jsoniter.ValEncoder interface. func (m *mocker) IsEmpty(ptr unsafe.Pointer) bool { args := m.Called(ptr) return args.Bool(0) } // Decode implements jsoniter.ValDecoder interface. func (m *mocker) Decode(ptr unsafe.Pointer, iter *jsoniter.Iterator) { m.Called(ptr, iter) } // CheckType implements typeChecker interface. func (m *mocker) CheckType(typ Type) error { args := m.Called(typ) return args.Error(0) } // MarshalGraphson implements Marshaler interface. func (m *mocker) MarshalGraphson() ([]byte, error) { args := m.Called() data, err := args.Get(0), args.Error(1) if data == nil { return nil, err } return data.([]byte), err } // UnmarshalGraphson implements Unmarshaler interface. func (m *mocker) UnmarshalGraphson(data []byte) error { args := m.Called(data) return args.Error(0) } ent-0.5.4/dialect/gremlin/encoding/graphson/decode.go000066400000000000000000000056331377533537200225570ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graphson import ( "io" "reflect" jsoniter "github.com/json-iterator/go" "github.com/modern-go/reflect2" ) type decodeExtension struct { jsoniter.DummyExtension } // Unmarshal parses the graphson encoded data and stores the result // in the value pointed to by v. func Unmarshal(data []byte, v interface{}) error { return config.Unmarshal(data, v) } // UnmarshalFromString parses the graphson encoded str and stores the result // in the value pointed to by v. func UnmarshalFromString(str string, v interface{}) error { return config.UnmarshalFromString(str, v) } // Decoder defines a graphson decoder. type Decoder interface { Decode(interface{}) error } // NewDecoder create a graphson decoder. 
func NewDecoder(r io.Reader) Decoder { return config.NewDecoder(r) } // Unmarshaler is the interface implemented by types // that can unmarshal a graphson description of themselves. type Unmarshaler interface { UnmarshalGraphson([]byte) error } // UpdateStructDescriptor decorates struct field encoders for graphson tags. func (ext decodeExtension) UpdateStructDescriptor(desc *jsoniter.StructDescriptor) { for _, binding := range desc.Fields { if tag, ok := binding.Field.Tag().Lookup("graphson"); ok && tag != "-" { if dec := ext.DecoratorOfStructField(binding.Decoder, tag); dec != nil { binding.Decoder = dec } } } } // CreateDecoder returns a value decoder for type. func (ext decodeExtension) CreateDecoder(typ reflect2.Type) jsoniter.ValDecoder { if dec := ext.DecoderOfRegistered(typ); dec != nil { return dec } if dec := ext.DecoderOfUnmarshaler(typ); dec != nil { return dec } if dec := ext.DecoderOfNative(typ); dec != nil { return dec } switch typ.Kind() { case reflect.Array: return ext.DecoderOfArray(typ) case reflect.Slice: return ext.DecoderOfSlice(typ) case reflect.Map: return ext.DecoderOfMap(typ) default: return nil } } // DecorateDecoder decorates an passed in value decoder for type. 
func (ext decodeExtension) DecorateDecoder(typ reflect2.Type, dec jsoniter.ValDecoder) jsoniter.ValDecoder { if dec := ext.DecoratorOfRegistered(dec); dec != nil { return dec } if dec := ext.DecoratorOfUnmarshaler(typ, dec); dec != nil { return dec } if dec := ext.DecoratorOfTyper(typ, dec); dec != nil { return dec } if dec := ext.DecoratorOfNative(typ, dec); dec != nil { return dec } switch typ.Kind() { case reflect.Ptr, reflect.Struct: return dec case reflect.Interface: return ext.DecoratorOfInterface(typ, dec) case reflect.Slice: return ext.DecoratorOfSlice(typ, dec) case reflect.Array: return ext.DecoratorOfArray(dec) case reflect.Map: return ext.DecoratorOfMap(dec) default: return ext.DecoderOfError("graphson: unsupported type: " + typ.String()) } } ent-0.5.4/dialect/gremlin/encoding/graphson/decode_test.go000066400000000000000000000003241377533537200236060ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graphson ent-0.5.4/dialect/gremlin/encoding/graphson/encode.go000066400000000000000000000050431377533537200225640ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graphson import ( "io" "reflect" jsoniter "github.com/json-iterator/go" "github.com/modern-go/reflect2" ) type encodeExtension struct { jsoniter.DummyExtension } // Marshal returns the graphson encoding of v. func Marshal(v interface{}) ([]byte, error) { return config.Marshal(v) } // MarshalToString returns the graphson encoding of v as string. func MarshalToString(v interface{}) (string, error) { return config.MarshalToString(v) } // Encoder defines a graphson encoder. 
type Encoder interface { Encode(interface{}) error } // NewEncoder create a graphson encoder. func NewEncoder(w io.Writer) Encoder { return config.NewEncoder(w) } // Marshaler is the interface implemented by types that // can marshal themselves as graphson. type Marshaler interface { MarshalGraphson() ([]byte, error) } // UpdateStructDescriptor decorates struct field encoders for graphson tags. func (ext encodeExtension) UpdateStructDescriptor(desc *jsoniter.StructDescriptor) { for _, binding := range desc.Fields { if tag, ok := binding.Field.Tag().Lookup("graphson"); ok && tag != "-" { if enc := ext.DecoratorOfStructField(binding.Encoder, tag); enc != nil { binding.Encoder = enc } } } } // CreateEncoder returns a value encoder for type. func (ext encodeExtension) CreateEncoder(typ reflect2.Type) jsoniter.ValEncoder { if enc := ext.EncoderOfRegistered(typ); enc != nil { return enc } if enc := ext.EncoderOfNative(typ); enc != nil { return enc } switch typ.Kind() { case reflect.Map: return ext.EncoderOfMap(typ) default: return nil } } // DecorateEncoder decorates an passed in value encoder for type. 
func (ext encodeExtension) DecorateEncoder(typ reflect2.Type, enc jsoniter.ValEncoder) jsoniter.ValEncoder { if enc := ext.DecoratorOfRegistered(enc); enc != nil { return enc } if enc := ext.DecoratorOfMarshaler(typ, enc); enc != nil { return enc } if enc := ext.DecoratorOfTyper(typ, enc); enc != nil { return enc } if enc := ext.DecoratorOfNative(typ, enc); enc != nil { return enc } switch typ.Kind() { case reflect.Ptr, reflect.Interface, reflect.Struct: return enc case reflect.Array: return ext.DecoratorOfArray(enc) case reflect.Slice: return ext.DecoratorOfSlice(typ, enc) case reflect.Map: return ext.DecoratorOfMap(enc) default: return ext.EncoderOfError("graphson: unsupported type: " + typ.String()) } } ent-0.5.4/dialect/gremlin/encoding/graphson/encode_test.go000066400000000000000000000005701377533537200236230ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graphson import ( "testing" "github.com/stretchr/testify/assert" ) func TestEncodeUnsupportedType(t *testing.T) { _, err := Marshal(func() {}) assert.Error(t, err) } ent-0.5.4/dialect/gremlin/encoding/graphson/error.go000066400000000000000000000023041377533537200224550ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graphson import ( "unsafe" jsoniter "github.com/json-iterator/go" "github.com/pkg/errors" ) // EncoderOfError returns a value encoder which always fails to encode. func (encodeExtension) EncoderOfError(format string, args ...interface{}) jsoniter.ValEncoder { return decoratorOfError(format, args...) } // DecoderOfError returns a value decoder which always fails to decode. 
func (decodeExtension) DecoderOfError(format string, args ...interface{}) jsoniter.ValDecoder { return decoratorOfError(format, args...) } func decoratorOfError(format string, args ...interface{}) errorCodec { err := errors.Errorf(format, args...) return errorCodec{err} } type errorCodec struct{ error } func (ec errorCodec) Encode(_ unsafe.Pointer, stream *jsoniter.Stream) { if stream.Error == nil { stream.Error = ec.error } } func (errorCodec) IsEmpty(unsafe.Pointer) bool { return false } func (ec errorCodec) Decode(_ unsafe.Pointer, iter *jsoniter.Iterator) { if iter.Error == nil { iter.Error = ec.error } } ent-0.5.4/dialect/gremlin/encoding/graphson/error_test.go000066400000000000000000000014031377533537200235130ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graphson import ( "bytes" "errors" "testing" "github.com/stretchr/testify/assert" ) func TestErrorCodec(t *testing.T) { codec := errorCodec{errors.New("codec error")} assert.False(t, codec.IsEmpty(nil)) var buf bytes.Buffer stream := config.BorrowStream(&buf) defer config.ReturnStream(stream) codec.Encode(nil, stream) assert.Empty(t, buf.Bytes()) assert.EqualError(t, stream.Error, codec.Error()) iter := config.BorrowIterator([]byte{}) defer config.ReturnIterator(iter) codec.Decode(nil, iter) assert.EqualError(t, iter.Error, codec.Error()) } ent-0.5.4/dialect/gremlin/encoding/graphson/extension.go000066400000000000000000000044311377533537200233430ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. 
package graphson import ( "reflect" jsoniter "github.com/json-iterator/go" "github.com/modern-go/reflect2" ) var ( typeEncoders = map[string]jsoniter.ValEncoder{} typeDecoders = map[string]jsoniter.ValDecoder{} ) // RegisterTypeEncoder register type encoder for typ. func RegisterTypeEncoder(typ string, enc jsoniter.ValEncoder) { typeEncoders[typ] = enc } // RegisterTypeDecoder register type decoder for typ. func RegisterTypeDecoder(typ string, dec jsoniter.ValDecoder) { typeDecoders[typ] = dec } type registeredEncoder struct{ jsoniter.ValEncoder } // EncoderOfNative returns a value encoder of a registered type. func (encodeExtension) EncoderOfRegistered(typ reflect2.Type) jsoniter.ValEncoder { enc := typeEncoders[typ.String()] if enc != nil { return registeredEncoder{enc} } if typ.Kind() == reflect.Ptr { ptrType := typ.(reflect2.PtrType) enc := typeEncoders[ptrType.Elem().String()] if enc != nil { return registeredEncoder{ ValEncoder: &jsoniter.OptionalEncoder{ ValueEncoder: enc, }, } } } return nil } // DecoratorOfRegistered decorates a value encoder of a registered type. func (encodeExtension) DecoratorOfRegistered(enc jsoniter.ValEncoder) jsoniter.ValEncoder { if _, ok := enc.(registeredEncoder); ok { return enc } return nil } type registeredDecoder struct{ jsoniter.ValDecoder } // DecoratorOfRegistered returns a value decoder of a registered type. func (decodeExtension) DecoderOfRegistered(typ reflect2.Type) jsoniter.ValDecoder { dec := typeDecoders[typ.String()] if dec != nil { return registeredDecoder{dec} } if typ.Kind() == reflect.Ptr { ptrType := typ.(reflect2.PtrType) dec := typeDecoders[ptrType.Elem().String()] if dec != nil { return registeredDecoder{ ValDecoder: &jsoniter.OptionalDecoder{ ValueType: ptrType.Elem(), ValueDecoder: dec, }, } } } return nil } // DecoratorOfNative decorates a value decoder of a registered type. 
func (decodeExtension) DecoratorOfRegistered(dec jsoniter.ValDecoder) jsoniter.ValDecoder { if _, ok := dec.(registeredDecoder); ok { return dec } return nil } ent-0.5.4/dialect/gremlin/encoding/graphson/init.go000066400000000000000000000006351377533537200222740ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graphson import ( jsoniter "github.com/json-iterator/go" ) var config = jsoniter.Config{}.Froze() func init() { config.RegisterExtension(&encodeExtension{}) config.RegisterExtension(&decodeExtension{}) } ent-0.5.4/dialect/gremlin/encoding/graphson/interface.go000066400000000000000000000073121377533537200232700ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graphson import ( "bytes" "fmt" "io" "reflect" "unsafe" jsoniter "github.com/json-iterator/go" "github.com/modern-go/reflect2" "github.com/pkg/errors" ) // DecoratorOfInterface decorates a value decoder of an interface type. 
func (decodeExtension) DecoratorOfInterface(typ reflect2.Type, dec jsoniter.ValDecoder) jsoniter.ValDecoder { if _, ok := typ.(*reflect2.UnsafeEFaceType); ok { return efaceDecoder{typ, dec} } return dec } type efaceDecoder struct { typ reflect2.Type jsoniter.ValDecoder } func (dec efaceDecoder) Decode(ptr unsafe.Pointer, iter *jsoniter.Iterator) { switch next := iter.WhatIsNext(); next { case jsoniter.StringValue, jsoniter.BoolValue, jsoniter.NilValue: dec.ValDecoder.Decode(ptr, iter) case jsoniter.ObjectValue: dec.decode(ptr, iter) default: iter.ReportError("decode empty interface", fmt.Sprintf("unexpected value type: %d", next)) } } func (dec efaceDecoder) decode(ptr unsafe.Pointer, iter *jsoniter.Iterator) { data := iter.SkipAndReturnBytes() if iter.Error != nil && iter.Error != io.EOF { return } rtype, err := dec.reflectBytes(data) if err != nil { iter.ReportError("decode empty interface", err.Error()) return } it := config.BorrowIterator(data) defer config.ReturnIterator(it) var val interface{} if rtype != nil { val = rtype.New() it.ReadVal(val) val = rtype.Indirect(val) } else { if jsoniter.Get(data, TypeKey).LastError() == nil { vk := jsoniter.Get(data, ValueKey) if vk.LastError() == nil { val = vk.GetInterface() } } if val == nil { val = it.Read() } } if it.Error != nil && it.Error != io.EOF { iter.ReportError("decode empty interface", it.Error.Error()) return } // nolint: gas dec.typ.UnsafeSet(ptr, unsafe.Pointer(&val)) } func (dec efaceDecoder) reflectBytes(data []byte) (reflect2.Type, error) { typ := Type(jsoniter.Get(data, TypeKey).ToString()) rtype := dec.reflectType(typ) if rtype != nil { return rtype, nil } switch typ { case listType: return dec.reflectSlice(data) case mapType: return dec.reflectMap(data) default: return nil, nil } } func (efaceDecoder) reflectType(typ Type) reflect2.Type { switch typ { case doubleType: return reflect2.TypeOf(float64(0)) case floatType: return reflect2.TypeOf(float32(0)) case byteType: return 
reflect2.TypeOf(uint8(0)) case int16Type: return reflect2.TypeOf(int16(0)) case int32Type: return reflect2.TypeOf(int32(0)) case int64Type, bigIntegerType: return reflect2.TypeOf(int64(0)) case byteBufferType: return reflect2.TypeOf([]byte{}) default: return nil } } func (efaceDecoder) reflectSlice(data []byte) (reflect2.Type, error) { var elem interface{} if err := Unmarshal(data, &[...]*interface{}{&elem}); err != nil { return nil, errors.Wrap(err, "cannot read first list element") } if elem == nil { return reflect2.TypeOf([]interface{}{}), nil } sliceType := reflect.SliceOf(reflect.TypeOf(elem)) return reflect2.Type2(sliceType), nil } func (efaceDecoder) reflectMap(data []byte) (reflect2.Type, error) { var key, elem interface{} if err := Unmarshal( bytes.Replace(data, []byte(mapType), []byte(listType), 1), &[...]*interface{}{&key, &elem}, ); err != nil { return nil, errors.Wrap(err, "cannot unmarshal first map item") } if key == nil { return reflect2.TypeOf(map[interface{}]interface{}{}), nil } else if elem == nil { return nil, errors.New("expect map element, but found only key") } mapType := reflect.MapOf(reflect.TypeOf(key), reflect.TypeOf(elem)) return reflect2.Type2(mapType), nil } ent-0.5.4/dialect/gremlin/encoding/graphson/interface_test.go000066400000000000000000000133321377533537200243260ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. 
package graphson import ( "fmt" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestDecodeInterface(t *testing.T) { tests := []struct { name string in string want interface{} wantErr bool }{ { name: "Boolean", in: "false", want: false, }, { name: "String", in: `"str"`, want: "str", }, { name: "Double", in: `{ "@type": "g:Double", "@value": 3.14 }`, want: float64(3.14), }, { name: "Float", in: `{ "@type": "g:Float", "@value": -22.567 }`, want: float32(-22.567), }, { name: "Int32", in: `{ "@type": "g:Int32", "@value": 9000 }`, want: int32(9000), }, { name: "Int64", in: `{ "@type": "g:Int64", "@value": 188786 }`, want: int64(188786), }, { name: "BigInteger", in: `{ "@type": "gx:BigInteger", "@value": 352353463712 }`, want: int64(352353463712), }, { name: "Byte", in: `{ "@type": "gx:Byte", "@value": 100 }`, want: uint8(100), }, { name: "Int16", in: `{ "@type": "gx:Int16", "@value": 2000 }`, want: int16(2000), }, { name: "UnknownType", in: `{ "@type": "g:T", "@value": "label" }`, want: "label", }, { name: "UntypedArray", in: "[]", wantErr: true, }, { name: "NoType", in: `{ "@typ": "g:Int32", "@value": 345 }`, wantErr: true, }, { name: "BadObject", in: `{ "@type": "g:Int32", "@value": 345 `, wantErr: true, }, { name: "BadList", in: `{ "@type": "g:List", "@value": [ { "@type": "g:Int64", "@val": 123457990 } ] }`, wantErr: true, }, { name: "BadMap", in: `{ "@type": "g:Map", "@value": [ { "@type": "g:Int64", "@val": 123457990 }, "First" ] }`, wantErr: true, }, { name: "KeyOnlyMap", in: `{ "@type": "g:Map", "@value": ["Key"] }`, wantErr: true, }, } for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() var got interface{} err := UnmarshalFromString(tc.in, &got) if !tc.wantErr { require.NoError(t, err) assert.Equal(t, tc.want, got) } else { assert.Error(t, err) } }) } } func TestDecodeInterfaceSlice(t *testing.T) { tests := []struct { in string want interface{} }{ { in: `{ "@type": "g:List", "@value": 
[] }`, want: []interface{}{}, }, { in: `{ "@type": "g:List", "@value": ["x", "y", "z"] }`, want: []string{"x", "y", "z"}, }, { in: `{ "@type": "g:List", "@value": [ { "@type": "g:Int64", "@value": 123457990 }, { "@type": "g:Int64", "@value": 23456111 }, { "@type": "g:Int64", "@value": -687450 } ] }`, want: []int64{123457990, 23456111, -687450}, }, { in: `{ "@type": "gx:ByteBuffer", "@value": "AQIDBAU=" }`, want: []byte{1, 2, 3, 4, 5}, }, } for _, tc := range tests { tc := tc t.Run(fmt.Sprintf("%T", tc.want), func(t *testing.T) { t.Parallel() var got interface{} err := UnmarshalFromString(tc.in, &got) require.NoError(t, err) assert.Equal(t, tc.want, got) }) } } func TestDecodeInterfaceMap(t *testing.T) { tests := []struct { in string want interface{} }{ { in: `{ "@type": "g:Map", "@value": [] }`, want: map[interface{}]interface{}{}, }, { in: `{ "@type": "g:Map", "@value": [ "Sep", { "@type": "g:Int32", "@value": 9 }, "Oct", { "@type": "g:Int32", "@value": 10 }, "Nov", { "@type": "g:Int32", "@value": 11 } ] }`, want: map[string]int32{ "Sep": int32(9), "Oct": int32(10), "Nov": int32(11), }, }, { in: `{ "@type": "g:Map", "@value": [ "One", { "@type": "g:List", "@value": [ { "@type": "g:Int32", "@value": 1 } ] }, "Two", { "@type": "g:List", "@value": [ { "@type": "g:Int32", "@value": 2 } ] }, "Three", { "@type": "g:List", "@value": [ { "@type": "g:Int32", "@value": 3 } ] } ] }`, want: map[string][]int32{ "One": {1}, "Two": {2}, "Three": {3}, }, }, } for _, tc := range tests { tc := tc t.Run(fmt.Sprintf("%T", tc.want), func(t *testing.T) { t.Parallel() var got interface{} err := UnmarshalFromString(tc.in, &got) require.NoError(t, err) assert.Equal(t, tc.want, got) }) } } func TestDecodeInterfaceObject(t *testing.T) { book := struct { ID string `json:"id" graphson:"g:UUID"` Title string `json:"title"` Author string `json:"author"` Pages int `json:"num_pages"` Chapters []string `json:"chapters"` }{ ID: "21d5dcbf-1fd4-493e-9b74-d6c429f9e4a5", Title: "The Art of Computer 
Programming, Vol. 2", Author: "Donald E. Knuth", Pages: 784, Chapters: []string{"Random numbers", "Arithmetic"}, } data, err := Marshal(book) require.NoError(t, err) var v interface{} err = Unmarshal(data, &v) require.NoError(t, err) obj := v.(map[string]interface{}) assert.Equal(t, book.ID, obj["id"]) assert.Equal(t, book.Title, obj["title"]) assert.Equal(t, book.Author, obj["author"]) assert.EqualValues(t, book.Pages, obj["num_pages"]) assert.ElementsMatch(t, book.Chapters, obj["chapters"]) } ent-0.5.4/dialect/gremlin/encoding/graphson/lazy.go000066400000000000000000000036601377533537200223110ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graphson import ( "sync" "unsafe" jsoniter "github.com/json-iterator/go" "github.com/modern-go/reflect2" "github.com/pkg/errors" ) // LazyEncoderOf returns a lazy encoder for type. func (encodeExtension) LazyEncoderOf(typ reflect2.Type) jsoniter.ValEncoder { return &lazyEncoder{resolve: func() jsoniter.ValEncoder { return config.EncoderOf(typ) }} } // LazyDecoderOf returns a lazy unique decoder for type. 
func (decodeExtension) LazyDecoderOf(typ reflect2.Type) jsoniter.ValDecoder { return &lazyDecoder{resolve: func() jsoniter.ValDecoder { dec := config.DecoderOf(reflect2.PtrTo(typ)) if td, ok := dec.(typeDecoder); ok { td.typeChecker = &uniqueType{elemChecker: td.typeChecker} dec = td } return dec }} } type lazyEncoder struct { jsoniter.ValEncoder resolve func() jsoniter.ValEncoder once sync.Once } func (enc *lazyEncoder) Encode(ptr unsafe.Pointer, stream *jsoniter.Stream) { enc.once.Do(func() { enc.ValEncoder = enc.resolve() }) enc.ValEncoder.Encode(ptr, stream) } func (enc *lazyEncoder) IsEmpty(ptr unsafe.Pointer) bool { enc.once.Do(func() { enc.ValEncoder = enc.resolve() }) return enc.ValEncoder.IsEmpty(ptr) } type lazyDecoder struct { jsoniter.ValDecoder resolve func() jsoniter.ValDecoder once sync.Once } func (dec *lazyDecoder) Decode(ptr unsafe.Pointer, iter *jsoniter.Iterator) { dec.once.Do(func() { dec.ValDecoder = dec.resolve() }) dec.ValDecoder.Decode(ptr, iter) } type uniqueType struct { typ Type once sync.Once elemChecker typeChecker } func (u *uniqueType) CheckType(other Type) error { u.once.Do(func() { u.typ = other }) if u.typ != other { return errors.Errorf("expect type %s, but found %s", u.typ, other) } return u.elemChecker.CheckType(u.typ) } ent-0.5.4/dialect/gremlin/encoding/graphson/lazy_test.go000066400000000000000000000021521377533537200233430ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. 
package graphson import ( "sync/atomic" "testing" jsoniter "github.com/json-iterator/go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" ) func TestLazyEncode(t *testing.T) { var m mocker m.On("IsEmpty", mock.Anything).Return(false).Once() m.On("Encode", mock.Anything, mock.Anything).Once() defer m.AssertExpectations(t) var cnt uint32 var enc jsoniter.ValEncoder = &lazyEncoder{resolve: func() jsoniter.ValEncoder { assert.Equal(t, uint32(1), atomic.AddUint32(&cnt, 1)) return &m }} enc.IsEmpty(nil) enc.Encode(nil, nil) } func TestLazyDecode(t *testing.T) { var m mocker m.On("Decode", mock.Anything, mock.Anything).Times(3) defer m.AssertExpectations(t) var cnt uint32 var dec jsoniter.ValDecoder = &lazyDecoder{resolve: func() jsoniter.ValDecoder { assert.Equal(t, uint32(1), atomic.AddUint32(&cnt, 1)) return &m }} for i := 0; i < 3; i++ { dec.Decode(nil, nil) } } ent-0.5.4/dialect/gremlin/encoding/graphson/map.go000066400000000000000000000054421377533537200221070ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graphson import ( "unsafe" jsoniter "github.com/json-iterator/go" "github.com/modern-go/reflect2" ) // EncoderOfMap returns a value encoder of a map type. func (ext encodeExtension) EncoderOfMap(typ reflect2.Type) jsoniter.ValEncoder { mapType := typ.(reflect2.MapType) return &mapEncoder{ mapType: mapType, keyEnc: ext.LazyEncoderOf(mapType.Key()), elemEnc: ext.LazyEncoderOf(mapType.Elem()), } } // DecoratorOfMap decorates a value encoder of a map type. 
func (encodeExtension) DecoratorOfMap(enc jsoniter.ValEncoder) jsoniter.ValEncoder { return typeEncoder{enc, mapType} } type mapEncoder struct { mapType reflect2.MapType keyEnc jsoniter.ValEncoder elemEnc jsoniter.ValEncoder } func (enc *mapEncoder) Encode(ptr unsafe.Pointer, stream *jsoniter.Stream) { iter := enc.mapType.UnsafeIterate(ptr) if !iter.HasNext() { stream.WriteEmptyArray() return } stream.WriteArrayStart() for { key, elem := iter.UnsafeNext() enc.keyEnc.Encode(key, stream) stream.WriteMore() enc.elemEnc.Encode(elem, stream) if !iter.HasNext() { break } stream.WriteMore() } stream.WriteArrayEnd() } func (enc *mapEncoder) IsEmpty(ptr unsafe.Pointer) bool { return !enc.mapType.UnsafeIterate(ptr).HasNext() } // DecoderOfMap returns a value decoder of a map type. func (ext decodeExtension) DecoderOfMap(typ reflect2.Type) jsoniter.ValDecoder { mapType := typ.(reflect2.MapType) keyType, elemType := mapType.Key(), mapType.Elem() return &mapDecoder{ mapType: mapType, keyType: keyType, elemType: elemType, keyDec: ext.LazyDecoderOf(keyType), elemDec: ext.LazyDecoderOf(elemType), } } // DecoratorOfMap decorates a value decoder of a map type. 
func (decodeExtension) DecoratorOfMap(dec jsoniter.ValDecoder) jsoniter.ValDecoder { return typeDecoder{dec, mapType} } type mapDecoder struct { mapType reflect2.MapType keyType reflect2.Type elemType reflect2.Type keyDec jsoniter.ValDecoder elemDec jsoniter.ValDecoder } func (dec *mapDecoder) Decode(ptr unsafe.Pointer, iter *jsoniter.Iterator) { mapType := dec.mapType if mapType.UnsafeIsNil(ptr) { mapType.UnsafeSet(ptr, mapType.UnsafeMakeMap(0)) } var key unsafe.Pointer if !iter.ReadArrayCB(func(iter *jsoniter.Iterator) bool { if key == nil { key = dec.keyType.UnsafeNew() dec.keyDec.Decode(key, iter) return iter.Error == nil } elem := dec.elemType.UnsafeNew() dec.elemDec.Decode(elem, iter) if iter.Error != nil { return false } mapType.UnsafeSetIndex(ptr, key, elem) key = nil return true }) { return } if key != nil { iter.ReportError("decode map", "odd number of map items") } } ent-0.5.4/dialect/gremlin/encoding/graphson/map_test.go000066400000000000000000000127041377533537200231450ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. 
package graphson import ( "strings" "testing" jsoniter "github.com/json-iterator/go" "github.com/modern-go/reflect2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestEncodeMap(t *testing.T) { tests := []struct { name string in interface{} want string }{ { name: "simple", in: map[int32]string{ 3: "Mar", 1: "Jan", 2: "Feb", }, want: `[ { "@type": "g:Int32", "@value": 1 }, "Jan", { "@type": "g:Int32", "@value": 2 }, "Feb", { "@type": "g:Int32", "@value": 3 }, "Mar" ]`, }, { name: "mixed", in: map[string]interface{}{ "byte": byte('a'), "string": "str", "slice": []int{1, 2, 3}, "map": map[string]int{}, }, want: `[ "byte", { "@type": "gx:Byte", "@value": 97 }, "string", "str", "slice", { "@type": "g:List", "@value": [ { "@type": "g:Int64", "@value": 1 }, { "@type": "g:Int64", "@value": 2 }, { "@type": "g:Int64", "@value": 3 } ] }, "map", { "@type": "g:Map", "@value": [] } ]`, }, { name: "struct-key", in: map[struct { K string `json:"key"` }]int32{ {"result"}: 42, }, want: `[ { "key": "result" }, { "@type": "g:Int32", "@value": 42 } ]`, }, { name: "nil", in: map[string]uint8(nil), want: "null", }, } for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() data, err := Marshal(tc.in) require.NoError(t, err) assert.Equal(t, "g:Map", jsoniter.Get(data, "@type").ToString()) var want []interface{} err = jsoniter.UnmarshalFromString(tc.want, &want) require.NoError(t, err) got, ok := jsoniter.Get(data, "@value").GetInterface().([]interface{}) require.True(t, ok) assert.ElementsMatch(t, want, got) }) } } func TestDecodeMap(t *testing.T) { tests := []struct { name string in string want interface{} }{ { name: "empty", in: `{ "@type": "g:Map", "@value": [] }`, want: map[int]int{}, }, { name: "simple", in: `{ "@type": "g:Map", "@value": [ { "@type": "g:Int32", "@value": 6 }, "Jun", { "@type": "g:Int32", "@value": 7 }, "Jul", { "@type": "g:Int32", "@value": 8 }, "Aug" ] }`, want: map[int]string{ 6: "Jun", 7: "Jul", 8: 
"Aug", }, }, { name: "duplicate", in: `{ "@type": "g:Map", "@value": [ "Sep", { "@type": "g:Int32", "@value": 9 }, "Oct", { "@type": "g:Int32", "@value": 65 }, "Oct", { "@type": "g:Int32", "@value": 10 }, "Nov", null ] }`, want: map[string]*int{ "Sep": func() *int { v := 9; return &v }(), "Oct": func() *int { v := 10; return &v }(), "Nov": nil, }, }, } for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() typ := reflect2.TypeOf(tc.want).(reflect2.MapType) got := typ.MakeMap(0) err := UnmarshalFromString(tc.in, got) require.NoError(t, err) assert.Equal(t, tc.want, typ.Indirect(got)) }) } } func TestDecodeMapIntoNil(t *testing.T) { var got map[int64]int32 err := UnmarshalFromString(`{ "@type": "g:Map", "@value": [ { "@type": "g:Int64", "@value": 9 }, { "@type": "g:Int32", "@value": -9 }, { "@type": "g:Int64", "@value": 99 }, { "@type": "g:Int32", "@value": -99 }, { "@type": "g:Int64", "@value": 999 }, { "@type": "g:Int32", "@value": -999 } ] }`, &got) require.NoError(t, err) assert.Equal(t, map[int64]int32{9: -9, 99: -99, 999: -999}, got) } func TestDecodeBadMap(t *testing.T) { tests := []struct { name string in string }{ { name: "BadValue", in: `{ "@type": "g:Map", "@value": [ { "@type": "g:Int64", "@value": 9 }, { "@type": "g:Int32", "@value": "55" } ] }`, }, { name: "NoValue", in: `{ "@type": "g:Map", "@value": [ { "@type": "g:Int64", "@value": 9 }, { "@type": "g:Int32", "@value": 9 }, { "@type": "g:Int64", "@value": 42 } ] }`, }, { name: "AlterKeyType", in: `{ "@type": "g:Map", "@value": [ { "@type": "g:Int64", "@value": 9 }, { "@type": "g:Int32", "@value": 9 }, { "@type": "g:Int32", "@value": 42 }, { "@type": "g:Int32", "@value": 42 } ] }`, }, { name: "AlterValType", in: `{ "@type": "g:Map", "@value": [ { "@type": "g:Int64", "@value": 9 }, { "@type": "g:Int32", "@value": 9 }, { "@type": "g:Int64", "@value": 42 }, { "@type": "g:Int64", "@value": 42 } ] }`, }, } for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { 
t.Parallel() var v map[int]int err := NewDecoder(strings.NewReader(tc.in)).Decode(&v) assert.Error(t, err) }) } } ent-0.5.4/dialect/gremlin/encoding/graphson/marshaler.go000066400000000000000000000070051377533537200233050ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graphson import ( "fmt" "io" "unsafe" jsoniter "github.com/json-iterator/go" "github.com/modern-go/reflect2" "github.com/pkg/errors" ) // DecoratorOfMarshaler decorates a value encoder of a Marshaler interface. func (ext encodeExtension) DecoratorOfMarshaler(typ reflect2.Type, enc jsoniter.ValEncoder) jsoniter.ValEncoder { if typ == marshalerType { enc := marshalerEncoder{enc, typ} return directMarshalerEncoder{enc} } if typ.Implements(marshalerType) { return marshalerEncoder{enc, typ} } ptrType := reflect2.PtrTo(typ) if ptrType.Implements(marshalerType) { ptrEnc := ext.LazyEncoderOf(ptrType) enc := marshalerEncoder{ptrEnc, ptrType} return referenceEncoder{enc} } return nil } // DecoderOfUnmarshaler returns a value decoder of an Unmarshaler interface. func (decodeExtension) DecoderOfUnmarshaler(typ reflect2.Type) jsoniter.ValDecoder { ptrType := reflect2.PtrTo(typ) if ptrType.Implements(unmarshalerType) { return referenceDecoder{ unmarshalerDecoder{ptrType}, } } return nil } // DecoratorOfUnmarshaler decorates a value encoder of an Unmarshaler interface. 
func (decodeExtension) DecoratorOfUnmarshaler(typ reflect2.Type, dec jsoniter.ValDecoder) jsoniter.ValDecoder { if reflect2.PtrTo(typ).Implements(unmarshalerType) { return dec } return nil } var ( marshalerType = reflect2.TypeOfPtr((*Marshaler)(nil)).Elem() unmarshalerType = reflect2.TypeOfPtr((*Unmarshaler)(nil)).Elem() ) type marshalerEncoder struct { jsoniter.ValEncoder reflect2.Type } func (enc marshalerEncoder) Encode(ptr unsafe.Pointer, stream *jsoniter.Stream) { marshaler := enc.Type.UnsafeIndirect(ptr).(Marshaler) enc.encode(marshaler, stream) } func (enc marshalerEncoder) encode(marshaler Marshaler, stream *jsoniter.Stream) { data, err := marshaler.MarshalGraphson() if err != nil { stream.Error = errors.Wrapf(err, "graphson: error calling MarshalGraphson for type %s", enc.Type) return } if !config.Valid(data) { stream.Error = errors.Errorf("graphson: syntax error when marshaling type %s", enc.Type) return } _, stream.Error = stream.Write(data) } type directMarshalerEncoder struct { marshalerEncoder } func (enc directMarshalerEncoder) Encode(ptr unsafe.Pointer, stream *jsoniter.Stream) { marshaler := *(*Marshaler)(ptr) enc.encode(marshaler, stream) } type referenceEncoder struct { jsoniter.ValEncoder } func (enc referenceEncoder) Encode(ptr unsafe.Pointer, stream *jsoniter.Stream) { // nolint: gas enc.ValEncoder.Encode(unsafe.Pointer(&ptr), stream) } func (enc referenceEncoder) IsEmpty(ptr unsafe.Pointer) bool { // nolint: gas return enc.ValEncoder.IsEmpty(unsafe.Pointer(&ptr)) } type unmarshalerDecoder struct { reflect2.Type } func (dec unmarshalerDecoder) Decode(ptr unsafe.Pointer, iter *jsoniter.Iterator) { bytes := iter.SkipAndReturnBytes() if iter.Error != nil && iter.Error != io.EOF { return } unmarshaler := dec.UnsafeIndirect(ptr).(Unmarshaler) if err := unmarshaler.UnmarshalGraphson(bytes); err != nil { iter.ReportError( "unmarshal graphson", fmt.Sprintf( "graphson: error calling UnmarshalGraphson for type %s: %s", dec.Type, err, ), ) } } type 
referenceDecoder struct { jsoniter.ValDecoder } func (dec referenceDecoder) Decode(ptr unsafe.Pointer, iter *jsoniter.Iterator) { // nolint: gas dec.ValDecoder.Decode(unsafe.Pointer(&ptr), iter) } ent-0.5.4/dialect/gremlin/encoding/graphson/marshaler_test.go000066400000000000000000000050761377533537200243520ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graphson import ( "errors" "fmt" "reflect" "testing" jsoniter "github.com/json-iterator/go" "github.com/modern-go/reflect2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) func TestMarshalerEncode(t *testing.T) { want := []byte(`{"@type": "g:Int32", "@value": 42}`) m := &mocker{} call := m.On("MarshalGraphson").Return(want, nil) defer m.AssertExpectations(t) tests := []interface{}{m, &m, func() *Marshaler { marshaler := Marshaler(m); return &marshaler }(), Marshaler(nil)} call.Times(len(tests) - 1) for _, tc := range tests { tc := tc t.Run(fmt.Sprintf("%T", tc), func(t *testing.T) { got, err := Marshal(tc) assert.NoError(t, err) if !reflect2.IsNil(tc) { assert.Equal(t, want, got) } else { assert.Equal(t, []byte("null"), got) } }) } } func TestMarshalerError(t *testing.T) { errStr := "marshaler error" m := &mocker{} m.On("MarshalGraphson").Return(nil, errors.New(errStr)).Once() defer m.AssertExpectations(t) _, err := Marshal(m) assert.Error(t, err) assert.Contains(t, err.Error(), errStr) } func TestBadMarshaler(t *testing.T) { m := &mocker{} m.On("MarshalGraphson").Return([]byte(`{"@type": "g:Int32", "@value":`), nil).Once() defer m.AssertExpectations(t) _, err := Marshal(m) assert.Error(t, err) } func TestUnmarshalerDecode(t *testing.T) { data := `{"@type": "g:UUID", "@value": "cb682578-9d92-4499-9ebc-5c6aa73c5397"}` var value string m := &mocker{} 
m.On("UnmarshalGraphson", mock.Anything). Run(func(args mock.Arguments) { data := args.Get(0).([]byte) value = jsoniter.Get(data, "@value").ToString() }). Return(nil). Once() defer m.AssertExpectations(t) err := UnmarshalFromString(data, m) require.NoError(t, err) assert.Equal(t, "cb682578-9d92-4499-9ebc-5c6aa73c5397", value) } func TestUnmarshalerError(t *testing.T) { errStr := "unmarshaler error" m := &mocker{} m.On("UnmarshalGraphson", mock.Anything).Return(errors.New(errStr)).Once() defer m.AssertExpectations(t) err := Unmarshal([]byte(`{}`), m) require.Error(t, err) assert.Contains(t, err.Error(), fmt.Sprintf("graphson: error calling UnmarshalGraphson for type %s: %s", reflect.TypeOf(m), errStr, ), ) } func TestUnmarshalBadInput(t *testing.T) { m := &mocker{} defer m.AssertExpectations(t) err := UnmarshalFromString(`{"@type"}`, m) assert.Error(t, err) } ent-0.5.4/dialect/gremlin/encoding/graphson/native.go000066400000000000000000000071201377533537200226130ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graphson import ( "fmt" "io" "math" "reflect" "unsafe" jsoniter "github.com/json-iterator/go" "github.com/modern-go/reflect2" ) // EncoderOfNative returns a value encoder of a native type. func (encodeExtension) EncoderOfNative(typ reflect2.Type) jsoniter.ValEncoder { switch typ.Kind() { case reflect.Float64: return float64Encoder{typ} default: return nil } } // DecoratorOfNative decorates a value encoder of a native type. 
func (encodeExtension) DecoratorOfNative(typ reflect2.Type, enc jsoniter.ValEncoder) jsoniter.ValEncoder { switch typ.Kind() { case reflect.Bool, reflect.String: return enc case reflect.Int64, reflect.Int, reflect.Uint32: return typeEncoder{enc, int64Type} case reflect.Int32, reflect.Int8, reflect.Uint16: return typeEncoder{enc, int32Type} case reflect.Int16: return typeEncoder{enc, int16Type} case reflect.Uint64, reflect.Uint: return typeEncoder{enc, bigIntegerType} case reflect.Uint8: return typeEncoder{enc, byteType} case reflect.Float32: return typeEncoder{enc, floatType} case reflect.Float64: return typeEncoder{enc, doubleType} default: return nil } } // DecoderOfNative returns a value decoder of a native type. func (decodeExtension) DecoderOfNative(typ reflect2.Type) jsoniter.ValDecoder { switch typ.Kind() { case reflect.Float64: return float64Decoder{typ} default: return nil } } // DecoratorOfNative decorates a value decoder of a native type. func (decodeExtension) DecoratorOfNative(typ reflect2.Type, dec jsoniter.ValDecoder) jsoniter.ValDecoder { switch typ.Kind() { case reflect.Bool: return dec case reflect.String: return typeDecoder{dec, typeCheckerFunc(func(Type) error { return nil })} case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: return typeDecoder{dec, integerTypes} case reflect.Float32: return typeDecoder{dec, floatTypes} case reflect.Float64: return typeDecoder{dec, doubleTypes} default: return nil } } type float64Encoder struct { reflect2.Type } func (enc float64Encoder) IsEmpty(ptr unsafe.Pointer) bool { return enc.UnsafeIndirect(ptr).(float64) == 0 } func (enc float64Encoder) Encode(ptr unsafe.Pointer, stream *jsoniter.Stream) { f := enc.UnsafeIndirect(ptr).(float64) switch { case math.IsNaN(f): stream.WriteString("NaN") case math.IsInf(f, 1): stream.WriteString("Infinity") case math.IsInf(f, -1): stream.WriteString("-Infinity") default: 
stream.WriteFloat64(f) } } type float64Decoder struct { reflect2.Type } var ( integerTypes = Types{byteType, int16Type, int32Type, int64Type, bigIntegerType} floatTypes = append(integerTypes, floatType, bigDecimal) doubleTypes = append(floatTypes, doubleType) ) func (dec float64Decoder) Decode(ptr unsafe.Pointer, iter *jsoniter.Iterator) { var val float64 switch next := iter.WhatIsNext(); next { case jsoniter.NumberValue: val = iter.ReadFloat64() case jsoniter.StringValue: switch str := iter.ReadString(); str { case "NaN": val = math.NaN() case "Infinity": val = math.Inf(1) case "-Infinity": val = math.Inf(-1) default: iter.ReportError("decode float64", "invalid value "+str) } default: iter.ReportError("decode float64", fmt.Sprintf("unexpected value type: %d", next)) } if iter.Error == nil || iter.Error == io.EOF { // nolint: gas dec.UnsafeSet(ptr, unsafe.Pointer(&val)) } } ent-0.5.4/dialect/gremlin/encoding/graphson/native_test.go000066400000000000000000000130411377533537200236510ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. 
package graphson import ( "fmt" "math" "testing" jsoniter "github.com/json-iterator/go" "github.com/modern-go/reflect2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestEncodeNative(t *testing.T) { tests := []struct { in interface{} want string wantErr bool }{ { in: true, want: "true", }, { in: "hello", want: `"hello"`, }, { in: int8(120), want: `{ "@type": "g:Int32", "@value": 120 }`, }, { in: int16(-16), want: `{ "@type": "gx:Int16", "@value": -16 }`, }, { in: int32(3232), want: `{ "@type": "g:Int32", "@value": 3232 }`, }, { in: int64(646464), want: `{ "@type": "g:Int64", "@value": 646464 }`, }, { in: int(127001), want: `{ "@type": "g:Int64", "@value": 127001 }`, }, { in: uint8(81), want: `{ "@type": "gx:Byte", "@value": 81 }`, }, { in: uint16(12345), want: `{ "@type": "g:Int32", "@value": 12345 }`, }, { in: uint32(123454321), want: `{ "@type": "g:Int64", "@value": 123454321 }`, }, { in: uint64(1234567890), want: `{ "@type": "gx:BigInteger", "@value": 1234567890 }`, }, { in: uint(9876543210), want: `{ "@type" :"gx:BigInteger", "@value": 9876543210 }`, }, { in: float32(math.Pi), want: `{ "@type": "g:Float", "@value": 3.1415927 }`, }, { in: float64(math.E), want: `{ "@type": "g:Double", "@value": 2.718281828459045 }`, }, { in: math.NaN(), want: `{ "@type": "g:Double", "@value": "NaN" }`, }, { in: math.Inf(1), want: `{ "@type": "g:Double", "@value": "Infinity" }`, }, { in: math.Inf(-1), want: `{ "@type": "g:Double", "@value": "-Infinity" }`, }, { in: func() *int { v := 7142; return &v }(), want: `{ "@type": "g:Int64", "@value": 7142 }`, }, { in: func() interface{} { v := int16(6116); return &v }(), want: `{ "@type": "gx:Int16", "@value": 6116 }`, }, { in: nil, want: "null", }, { in: make(chan int), wantErr: true, }, } for _, tc := range tests { tc := tc t.Run(fmt.Sprintf("%T", tc.in), func(t *testing.T) { t.Parallel() got, err := MarshalToString(tc.in) if !tc.wantErr { assert.NoError(t, err) assert.JSONEq(t, tc.want, got) } else 
{ assert.Error(t, err) assert.Empty(t, got) } }) } } func TestDecodeNative(t *testing.T) { tests := []struct { in string want interface{} }{ { in: `{"@type": "g:Float", "@value": 3.14}`, want: float32(3.14), }, { in: `{"@type": "g:Float", "@value": "Float"}`, }, { in: `{"@type": "g:Double", "@value": 2.71}`, want: float64(2.71), }, { in: `{"@type": "gx:BigDecimal", "@value": 3.142}`, want: float32(3.142), }, { in: `{"@type": "gx:BigDecimal", "@value": 55512.5176}`, want: float64(55512.5176), }, { in: `{"@type": "g:T", "@value": "world"}`, want: "world", }, } for _, tc := range tests { tc := tc t.Run(fmt.Sprintf("%T", tc.want), func(t *testing.T) { t.Parallel() if tc.want != nil { typ := reflect2.TypeOf(tc.want) got := typ.New() err := UnmarshalFromString(tc.in, got) require.NoError(t, err) assert.Equal(t, tc.want, typ.Indirect(got)) } else { var msg jsoniter.RawMessage err := UnmarshalFromString(tc.in, &msg) assert.Error(t, err) } }) } } func TestDecodeTypeMismatch(t *testing.T) { t.Run("FloatToInt", func(t *testing.T) { var v int err := UnmarshalFromString(`{"@type": "g:Float", "@value": 3.14}`, &v) assert.Error(t, err) }) t.Run("DoubleToFloat", func(t *testing.T) { var v float32 err := UnmarshalFromString(`{"@type": "g:Double", "@value": 5.51}`, &v) assert.Error(t, err) }) t.Run("BigDecimalToUint64", func(t *testing.T) { var v uint64 err := UnmarshalFromString(`{"@type": "gx:BigDecimal", "@value": 5645.51834}`, &v) assert.Error(t, err) }) } func TestDecodeNaNInfinity(t *testing.T) { tests := []struct { data []byte expect func(*testing.T, float64, error) }{ { data: []byte(`{"@type": "g:Double", "@value": "NaN"}`), expect: func(t *testing.T, f float64, err error) { assert.NoError(t, err) assert.True(t, math.IsNaN(f)) }, }, { data: []byte(`{"@type": "g:Double", "@value": "Infinity"}`), expect: func(t *testing.T, f float64, err error) { assert.NoError(t, err) assert.True(t, math.IsInf(f, 1)) }, }, { data: []byte(`{"@type": "g:Double", "@value": "-Infinity"}`), 
expect: func(t *testing.T, f float64, err error) { assert.NoError(t, err) assert.True(t, math.IsInf(f, -1)) }, }, { data: []byte(`{"@type": "g:Double", "@value": "Junk"}`), expect: func(t *testing.T, _ float64, err error) { assert.Error(t, err) }, }, { data: []byte(`{"@type": "g:Double", "@value": [42]}`), expect: func(t *testing.T, _ float64, err error) { assert.Error(t, err) }, }, } for _, tc := range tests { var f float64 err := Unmarshal(tc.data, &f) tc.expect(t, f, err) } } func TestDecodeTypeDefinition(t *testing.T) { type Status int const StatusOk Status = 42 var status Status err := UnmarshalFromString(`{"@type": "g:Int64", "@value": 42}`, &status) assert.NoError(t, err) assert.Equal(t, StatusOk, status) } ent-0.5.4/dialect/gremlin/encoding/graphson/raw.go000066400000000000000000000016001377533537200221130ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graphson import ( "github.com/pkg/errors" ) // RawMessage is a raw encoded graphson value. type RawMessage []byte // RawMessage must implement Marshaler/Unmarshaler interfaces. var ( _ Marshaler = (*RawMessage)(nil) _ Unmarshaler = (*RawMessage)(nil) ) // MarshalGraphson returns m as the graphson encoding of m. func (m RawMessage) MarshalGraphson() ([]byte, error) { if m == nil { return []byte("null"), nil } return m, nil } // UnmarshalGraphson sets *m to a copy of data. func (m *RawMessage) UnmarshalGraphson(data []byte) error { if m == nil { return errors.New("graphson.RawMessage: UnmarshalGraphson on nil pointer") } *m = append((*m)[0:0], data...) return nil } ent-0.5.4/dialect/gremlin/encoding/graphson/raw_test.go000066400000000000000000000014251377533537200231570ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. 
// This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graphson import ( "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestRawMessageEncoding(t *testing.T) { var s struct{ M RawMessage } got, err := MarshalToString(s) require.NoError(t, err) assert.Equal(t, `{"M":null}`, got) s.M = []byte(`"155a"`) got, err = MarshalToString(s) require.NoError(t, err) assert.JSONEq(t, `{"M": "155a"}`, got) err = (*RawMessage)(nil).UnmarshalGraphson(s.M) assert.Error(t, err) s.M = nil err = UnmarshalFromString(got, &s) require.NoError(t, err) assert.Equal(t, `"155a"`, string(s.M)) } ent-0.5.4/dialect/gremlin/encoding/graphson/slice.go000066400000000000000000000074561377533537200224400ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graphson import ( "io" "reflect" "unsafe" jsoniter "github.com/json-iterator/go" "github.com/modern-go/reflect2" "github.com/pkg/errors" ) // DecoratorOfSlice decorates a value encoder of a slice type. func (encodeExtension) DecoratorOfSlice(typ reflect2.Type, enc jsoniter.ValEncoder) jsoniter.ValEncoder { encoder := typeEncoder{ValEncoder: enc} sliceType := typ.(reflect2.SliceType) if sliceType.Elem().Kind() == reflect.Uint8 { encoder.Type = byteBufferType } else { encoder.Type = listType } return sliceEncoder{sliceType, encoder} } // DecoratorOfArray decorates a value encoder of an array type. func (encodeExtension) DecoratorOfArray(enc jsoniter.ValEncoder) jsoniter.ValEncoder { return typeEncoder{enc, listType} } // DecoderOfSlice returns a value decoder of a slice type. 
func (ext decodeExtension) DecoderOfSlice(typ reflect2.Type) jsoniter.ValDecoder { sliceType := typ.(reflect2.SliceType) elemType := sliceType.Elem() if elemType.Kind() == reflect.Uint8 { return nil } return sliceDecoder{ sliceType: sliceType, elemDec: ext.LazyDecoderOf(elemType), } } // DecoderOfArray returns a value decoder of an array type. func (ext decodeExtension) DecoderOfArray(typ reflect2.Type) jsoniter.ValDecoder { arrayType := typ.(reflect2.ArrayType) return arrayDecoder{ arrayType: arrayType, elemDec: ext.LazyDecoderOf(arrayType.Elem()), } } // DecoratorOfSlice decorates a value decoder of a slice type. func (ext decodeExtension) DecoratorOfSlice(typ reflect2.Type, dec jsoniter.ValDecoder) jsoniter.ValDecoder { if typ.(reflect2.SliceType).Elem().Kind() == reflect.Uint8 { return typeDecoder{dec, byteBufferType} } return typeDecoder{dec, listType} } // DecoratorOfArray decorates a value decoder of an array type. func (ext decodeExtension) DecoratorOfArray(dec jsoniter.ValDecoder) jsoniter.ValDecoder { return typeDecoder{dec, listType} } type sliceEncoder struct { sliceType reflect2.SliceType jsoniter.ValEncoder } func (enc sliceEncoder) Encode(ptr unsafe.Pointer, stream *jsoniter.Stream) { if enc.sliceType.UnsafeIsNil(ptr) { stream.WriteNil() } else { enc.ValEncoder.Encode(ptr, stream) } } type sliceDecoder struct { sliceType reflect2.SliceType elemDec jsoniter.ValDecoder } func (dec sliceDecoder) Decode(ptr unsafe.Pointer, iter *jsoniter.Iterator) { dec.decode(ptr, iter) if iter.Error != nil && iter.Error != io.EOF { iter.Error = errors.Wrapf(iter.Error, "decoding slice %s", dec.sliceType) } } func (dec sliceDecoder) decode(ptr unsafe.Pointer, iter *jsoniter.Iterator) { sliceType := dec.sliceType if iter.ReadNil() { sliceType.UnsafeSetNil(ptr) return } sliceType.UnsafeSet(ptr, sliceType.UnsafeMakeSlice(0, 0)) var length int iter.ReadArrayCB(func(iter *jsoniter.Iterator) bool { idx := length length++ sliceType.UnsafeGrow(ptr, length) elem := 
sliceType.UnsafeGetIndex(ptr, idx) dec.elemDec.Decode(elem, iter) return iter.Error == nil }) } type arrayDecoder struct { arrayType reflect2.ArrayType elemDec jsoniter.ValDecoder } func (dec arrayDecoder) Decode(ptr unsafe.Pointer, iter *jsoniter.Iterator) { dec.decode(ptr, iter) if iter.Error != nil && iter.Error != io.EOF { iter.Error = errors.Wrapf(iter.Error, "decoding array %s", dec.arrayType) } } func (dec arrayDecoder) decode(ptr unsafe.Pointer, iter *jsoniter.Iterator) { var ( arrayType = dec.arrayType length int ) iter.ReadArrayCB(func(iter *jsoniter.Iterator) bool { if length < arrayType.Len() { idx := length length++ elem := arrayType.UnsafeGetIndex(ptr, idx) dec.elemDec.Decode(elem, iter) } else { iter.Skip() } return iter.Error == nil }) } ent-0.5.4/dialect/gremlin/encoding/graphson/slice_test.go000066400000000000000000000073021377533537200234650ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. 
package graphson import ( "bytes" "fmt" "strings" "testing" "github.com/modern-go/reflect2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestEncodeArray(t *testing.T) { t.Parallel() got, err := MarshalToString([...]string{"a", "b", "c"}) require.NoError(t, err) want := `{ "@type": "g:List", "@value": ["a", "b", "c"]}` assert.JSONEq(t, want, got) } func TestEncodeSlice(t *testing.T) { tests := []struct { in interface{} want string }{ { in: []int32{5, 6, 7, 8}, want: `{ "@type": "g:List", "@value": [ { "@type": "g:Int32", "@value": 5 }, { "@type": "g:Int32", "@value": 6 }, { "@type": "g:Int32", "@value": 7 }, { "@type": "g:Int32", "@value": 8 } ] }`, }, { in: []byte{1, 2, 3, 4, 5}, want: `{ "@type": "gx:ByteBuffer", "@value": "AQIDBAU=" }`, }, { in: [...]byte{4, 5}, want: `{ "@type": "g:List", "@value": [ { "@type": "gx:Byte", "@value": 4 }, { "@type": "gx:Byte", "@value": 5 } ] }`, }, { in: []uint64(nil), want: "null", }, } for _, tc := range tests { tc := tc t.Run(fmt.Sprintf("%T", tc.in), func(t *testing.T) { t.Parallel() var got bytes.Buffer err := NewEncoder(&got).Encode(tc.in) assert.NoError(t, err) assert.JSONEq(t, tc.want, got.String()) }) } } func TestDecodeSlice(t *testing.T) { tests := []struct { in string want interface{} }{ { in: `{ "@type": "g:List", "@value": [ { "@type": "g:Int32", "@value": 3 }, { "@type": "g:Int32", "@value": -2 }, { "@type": "g:Int32", "@value": 1 } ] }`, want: []int32{3, -2, 1}, }, { in: `{ "@type": "g:List", "@value": ["a", "b", "c"] }`, want: []string{"a", "b", "c"}, }, { in: `{ "@type": "gx:ByteBuffer", "@value": "AQIDBAU=" }`, want: []byte{1, 2, 3, 4, 5}, }, { in: `{ "@type": "g:List", "@value": [ { "@type": "gx:Byte", "@value": 42 }, { "@type": "gx:Byte", "@value": 55 }, { "@type": "gx:Byte", "@value": 94 } ] }`, want: [...]byte{42, 55}, }, { in: `{ "@type": "g:List", "@value": null }`, want: []int(nil), }, } for _, tc := range tests { tc := tc t.Run(fmt.Sprintf("%T", tc.want), func(t 
*testing.T) { t.Parallel() typ := reflect2.TypeOf(tc.want) got := typ.New() err := NewDecoder(strings.NewReader(tc.in)).Decode(got) require.NoError(t, err) assert.Equal(t, tc.want, typ.Indirect(got)) }) } } func TestDecodeBadSlice(t *testing.T) { tests := []struct { name string in string new func() interface{} }{ { name: "TypeMismatch", in: `{ "@type": "g:List", "@value": [ { "@type": "g:Int32", "@value": 3 }, { "@type": "g:Int64", "@value": 2 } ] }`, new: func() interface{} { return &[]int{} }, }, { name: "BadValue", in: `{ "@type": "g:List", "@value": [ { "@type": "g:Int32", "@value": 3 }, { "@type": "g:Int32", "@value": "2" } ] }`, new: func() interface{} { return &[2]int{} }, }, } for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() err := NewDecoder(strings.NewReader(tc.in)).Decode(tc.new()) assert.Error(t, err) }) } } ent-0.5.4/dialect/gremlin/encoding/graphson/struct.go000066400000000000000000000017401377533537200226530ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graphson import jsoniter "github.com/json-iterator/go" // DecoratorOfStructField decorates a struct field value encoder. func (encodeExtension) DecoratorOfStructField(enc jsoniter.ValEncoder, tag string) jsoniter.ValEncoder { typ, _ := parseTag(tag) if typ == "" { return nil } encoder, ok := enc.(typeEncoder) if !ok { encoder = typeEncoder{ValEncoder: enc} } encoder.Type = Type(typ) return encoder } // DecoratorOfStructField decorates a struct field value decoder. 
func (decodeExtension) DecoratorOfStructField(dec jsoniter.ValDecoder, tag string) jsoniter.ValDecoder { typ, _ := parseTag(tag) if typ == "" { return nil } decoder, ok := dec.(typeDecoder) if !ok { decoder = typeDecoder{ValDecoder: dec} } decoder.typeChecker = Type(typ) return decoder } ent-0.5.4/dialect/gremlin/encoding/graphson/struct_test.go000066400000000000000000000074631377533537200237220ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graphson import ( "testing" "github.com/modern-go/reflect2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestEncodeStruct(t *testing.T) { tests := []struct { name string in interface{} want string }{ { name: "Simple", in: struct { S string I int }{ S: "string", I: 1000, }, want: `{ "S":"string", "I": { "@type": "g:Int64", "@value": 1000 } }`, }, { name: "Tagged", in: struct { ID string `json:"requestId" graphson:"g:UUID"` Seq int `json:"seq" graphson:"g:Int32"` Op string `json:"op" graphson:","` Args map[string]string `json:"args"` }{ ID: "cb682578-9d92-4499-9ebc-5c6aa73c5397", Seq: 42, Op: "authentication", Args: map[string]string{ "sasl": "AHN0ZXBocGhlbgBwYXNzd29yZA==", }, }, want: `{ "requestId": { "@type": "g:UUID", "@value": "cb682578-9d92-4499-9ebc-5c6aa73c5397" }, "seq": { "@type": "g:Int32", "@value": 42 }, "op": "authentication", "args": { "@type": "g:Map", "@value": ["sasl", "AHN0ZXBocGhlbgBwYXNzd29yZA=="] } }`, }, } for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() data, err := MarshalToString(tc.in) require.NoError(t, err) assert.JSONEq(t, tc.want, data) }) } } func TestEncodeNestedStruct(t *testing.T) { type S struct { Parent *S `json:"parent,omitempty"` ID int `json:"id" graphson:"g:Int32"` } v := S{Parent: &S{ID: 1}, ID: 2} want := `{ "id": { "@type": "g:Int32", 
"@value": 2 }, "parent": { "id": { "@type": "g:Int32", "@value": 1 } } }` got, err := MarshalToString(&v) require.NoError(t, err) assert.JSONEq(t, want, got) } func TestDecodeStruct(t *testing.T) { tests := []struct { name string in string want interface{} }{ { name: "Simple", in: `{ "S":"str", "I": { "@type": "g:Int32", "@value": 9999 } }`, want: struct { S string I int32 }{ S: "str", I: 9999, }, }, { name: "Tagged", in: `{ "requestId": { "@type": "g:UUID", "@value": "cb682578-9d92-4499-9ebc-5c6aa73c5397" }, "seq": { "@type": "g:Int32", "@value": 42 }, "op": "authentication", "args": { "@type": "g:Map", "@value": ["sasl", "AHN0ZXBocGhlbgBwYXNzd29yZA=="] } }`, want: struct { ID string `json:"requestId" graphson:"g:UUID"` Seq int `json:"seq" graphson:"g:Int32"` Op string `json:"op" graphson:","` Args map[string]string `json:"args"` }{ ID: "cb682578-9d92-4499-9ebc-5c6aa73c5397", Seq: 42, Op: "authentication", Args: map[string]string{ "sasl": "AHN0ZXBocGhlbgBwYXNzd29yZA==", }, }, }, { name: "Empty", in: `{}`, want: struct{}{}, }, } for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() typ := reflect2.TypeOf(tc.want) got := typ.New() err := UnmarshalFromString(tc.in, got) require.NoError(t, err) assert.Equal(t, tc.want, typ.Indirect(got)) }) } } func TestDecodeNestedStruct(t *testing.T) { type S struct { Parent *S `json:"parent,omitempty"` ID int `json:"id" graphson:"g:Int32"` } in := `{ "id": { "@type": "g:Int32", "@value": 37 }, "parent": { "id": { "@type": "g:Int32", "@value": 65 } } }` var got S err := UnmarshalFromString(in, &got) require.NoError(t, err) want := S{Parent: &S{ID: 65}, ID: 37} assert.Equal(t, want, got) } ent-0.5.4/dialect/gremlin/encoding/graphson/tags.go000066400000000000000000000016321377533537200222650ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. 
// This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graphson import ( "strings" ) type tagOptions string // parseTag splits a struct field's graphson tag into its type and // comma-separated options. func parseTag(tag string) (string, tagOptions) { if idx := strings.Index(tag, ","); idx != -1 { return tag[:idx], tagOptions(tag[idx+1:]) } return tag, "" } // Contains reports whether a comma-separated list of options // contains a particular substr flag. substr must be surrounded by a // string boundary or commas. func (opts tagOptions) Contains(opt string) bool { s := string(opts) for s != "" { var next string i := strings.Index(s, ",") if i >= 0 { s, next = s[:i], s[i+1:] } if s == opt { return true } s = next } return false } ent-0.5.4/dialect/gremlin/encoding/graphson/tags_test.go000066400000000000000000000022541377533537200233250ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. 
package graphson import ( "testing" "github.com/stretchr/testify/assert" ) func TestParseTag(t *testing.T) { tests := []struct { name string tag string typ string opts tagOptions }{ { name: "Empty", }, { name: "TypeOnly", tag: "g:Int32", typ: "g:Int32", }, { name: "OptsOnly", tag: ",opt1,opt2", opts: "opt1,opt2", }, { name: "TypeAndOpts", tag: "g:UUID,opt3,opt4", typ: "g:UUID", opts: "opt3,opt4", }, } for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() typ, opts := parseTag(tc.tag) assert.Equal(t, tc.typ, typ) assert.Equal(t, tc.opts, opts) }) } } func TestTagOptionsContains(t *testing.T) { _, opts := parseTag(",opt1,opt2,opt3") assert.True(t, opts.Contains("opt1")) assert.True(t, opts.Contains("opt2")) assert.True(t, opts.Contains("opt3")) assert.False(t, opts.Contains("opt4")) assert.False(t, opts.Contains("opt11")) assert.False(t, opts.Contains("")) } ent-0.5.4/dialect/gremlin/encoding/graphson/time.go000066400000000000000000000016341377533537200222670ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. 
package graphson import ( "time" "unsafe" jsoniter "github.com/json-iterator/go" ) func init() { RegisterTypeEncoder("time.Time", typeEncoder{timeCodec{}, Timestamp}) RegisterTypeDecoder("time.Time", typeDecoder{timeCodec{}, Types{Timestamp, Date}}) } type timeCodec struct{} func (timeCodec) IsEmpty(ptr unsafe.Pointer) bool { ts := *((*time.Time)(ptr)) return ts.IsZero() } func (timeCodec) Encode(ptr unsafe.Pointer, stream *jsoniter.Stream) { ts := *((*time.Time)(ptr)) stream.WriteInt64(ts.UnixNano() / time.Millisecond.Nanoseconds()) } func (timeCodec) Decode(ptr unsafe.Pointer, iter *jsoniter.Iterator) { ns := iter.ReadInt64() * time.Millisecond.Nanoseconds() *((*time.Time)(ptr)) = time.Unix(0, ns) } ent-0.5.4/dialect/gremlin/encoding/graphson/time_test.go000066400000000000000000000016101377533537200233200ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graphson import ( "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestTimeEncoding(t *testing.T) { const ms = 1481750076295 ts := time.Unix(0, ms*time.Millisecond.Nanoseconds()) for _, v := range []interface{}{ts, &ts} { got, err := MarshalToString(v) require.NoError(t, err) assert.JSONEq(t, `{ "@type": "g:Timestamp", "@value": 1481750076295 }`, got) } strs := []string{ `{ "@type": "g:Timestamp", "@value": 1481750076295 }`, `{ "@type": "g:Date", "@value": 1481750076295 }`, } for _, str := range strs { var v time.Time err := UnmarshalFromString(str, &v) assert.NoError(t, err) assert.True(t, ts.Equal(v)) } } ent-0.5.4/dialect/gremlin/encoding/graphson/type.go000066400000000000000000000074041377533537200223130ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. 
// This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graphson import ( "reflect" "strings" "unsafe" jsoniter "github.com/json-iterator/go" "github.com/modern-go/reflect2" "github.com/pkg/errors" ) // A Type is a graphson type. type Type string // graphson typed value types. const ( // core doubleType Type = "g:Double" floatType Type = "g:Float" int32Type Type = "g:Int32" int64Type Type = "g:Int64" listType Type = "g:List" mapType Type = "g:Map" Timestamp Type = "g:Timestamp" Date Type = "g:Date" // extended bigIntegerType Type = "gx:BigInteger" bigDecimal Type = "gx:BigDecimal" byteType Type = "gx:Byte" byteBufferType Type = "gx:ByteBuffer" int16Type Type = "gx:Int16" ) // String implements fmt.Stringer interface. func (typ Type) String() string { return string(typ) } // CheckType implements typeChecker interface. func (typ Type) CheckType(other Type) error { if typ != other { return errors.Errorf("expect type %s, but found %s", typ, other) } return nil } // Types is a slice of Type. type Types []Type // Contains reports whether a slice of types contains a particular type. func (types Types) Contains(typ Type) bool { for i := range types { if types[i] == typ { return true } } return false } // String implements fmt.Stringer interface. func (types Types) String() string { var builder strings.Builder builder.WriteByte('[') for i := range types { if i > 0 { builder.WriteByte(',') } builder.WriteString(types[i].String()) } builder.WriteByte(']') return builder.String() } // CheckType implements typeChecker interface. func (types Types) CheckType(typ Type) error { if !types.Contains(typ) { return errors.Errorf("expect any of %s, but found %s", types, typ) } return nil } // Typer is the interface implemented by types that // define an underlying graphson type. 
type Typer interface { GraphsonType() Type } var typerType = reflect2.TypeOfPtr((*Typer)(nil)).Elem() // DecoratorOfTyper decorates a value encoder of a Typer interface. func (ext encodeExtension) DecoratorOfTyper(typ reflect2.Type, enc jsoniter.ValEncoder) jsoniter.ValEncoder { if typ.Kind() != reflect.Struct { return nil } if typ.Implements(typerType) { return typerEncoder{ typeEncoder: typeEncoder{ValEncoder: enc}, typerOf: func(ptr unsafe.Pointer) Typer { return typ.UnsafeIndirect(ptr).(Typer) }, } } ptrType := reflect2.PtrTo(typ) if ptrType.Implements(typerType) { return typerEncoder{ typeEncoder: typeEncoder{ValEncoder: enc}, typerOf: func(ptr unsafe.Pointer) Typer { // nolint: gas return ptrType.UnsafeIndirect(unsafe.Pointer(&ptr)).(Typer) }, } } return nil } // DecoratorOfTyper decorates a value decoder of a Typer interface. func (ext decodeExtension) DecoratorOfTyper(typ reflect2.Type, dec jsoniter.ValDecoder) jsoniter.ValDecoder { ptrType := reflect2.PtrTo(typ) if ptrType.Implements(typerType) { return typerDecoder{ typeDecoder: typeDecoder{ValDecoder: dec}, typerOf: func(ptr unsafe.Pointer) Typer { // nolint: gas return ptrType.UnsafeIndirect(unsafe.Pointer(&ptr)).(Typer) }, } } return nil } type typerEncoder struct { typeEncoder typerOf func(unsafe.Pointer) Typer } func (enc typerEncoder) Encode(ptr unsafe.Pointer, stream *jsoniter.Stream) { enc.typeEncoder.Type = enc.typerOf(ptr).GraphsonType() enc.typeEncoder.Encode(ptr, stream) } type typerDecoder struct { typeDecoder typerOf func(unsafe.Pointer) Typer } func (dec typerDecoder) Decode(ptr unsafe.Pointer, iter *jsoniter.Iterator) { dec.typeDecoder.typeChecker = dec.typerOf(ptr).GraphsonType() dec.typeDecoder.Decode(ptr, iter) } ent-0.5.4/dialect/gremlin/encoding/graphson/type_test.go000066400000000000000000000041561377533537200233530ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. 
// This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graphson import ( "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" ) func TestTypeCheckType(t *testing.T) { assert.NoError(t, int32Type.CheckType(int32Type)) assert.Error(t, int32Type.CheckType(int64Type)) } func TestTypesCheckType(t *testing.T) { assert.NoError(t, Types{int16Type, int32Type, int64Type}.CheckType(int32Type)) assert.Error(t, Types{floatType, doubleType}.CheckType(bigIntegerType)) } func TestTypesString(t *testing.T) { assert.Equal(t, "[]", Types{}.String()) assert.Equal(t, "[gx:Byte]", Types{byteType}.String()) assert.Equal(t, "[gx:Int16,g:Int32,g:Int64]", Types{int16Type, int32Type, int64Type}.String()) } type vertex struct { ID int `json:"id"` Label string `json:"label"` } func (vertex) GraphsonType() Type { return Type("g:Vertex") } type mockVertex struct { mock.Mock `json:"-"` ID int `json:"id"` Label string `json:"label"` } func (m *mockVertex) GraphsonType() Type { return m.Called().Get(0).(Type) } func TestEncodeTyper(t *testing.T) { m := &mockVertex{ID: 42, Label: "person"} m.On("GraphsonType").Return(Type("g:Vertex")).Twice() defer m.AssertExpectations(t) v := vertex{ID: m.ID, Label: m.Label} var vv Typer = v want := `{ "@type": "g:Vertex", "@value": { "id": { "@type": "g:Int64", "@value": 42 }, "label": "person" } }` for _, tc := range []interface{}{m, &m, v, vv, &vv} { got, err := MarshalToString(tc) assert.NoError(t, err) assert.JSONEq(t, want, got) } } func TestDecodeTyper(t *testing.T) { var m mockVertex m.On("GraphsonType").Return(Type("g:Vertex")).Once() defer m.AssertExpectations(t) in := `{ "@type": "g:Vertex", "@value": { "id": { "@type": "g:Int64", "@value": 55 }, "label": "user" } }` err := UnmarshalFromString(in, &m) assert.NoError(t, err) assert.Equal(t, 55, m.ID) assert.Equal(t, "user", m.Label) } 
ent-0.5.4/dialect/gremlin/encoding/graphson/util.go000066400000000000000000000052001377533537200222770ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graphson import ( "io" "unsafe" jsoniter "github.com/json-iterator/go" "github.com/pkg/errors" ) // graphson encoding type / value keys const ( TypeKey = "@type" ValueKey = "@value" ) // typeEncoder adds graphson type information to a value encoder. type typeEncoder struct { jsoniter.ValEncoder Type Type } // Encode belongs to jsoniter.ValEncoder interface. func (enc typeEncoder) Encode(ptr unsafe.Pointer, stream *jsoniter.Stream) { stream.WriteObjectStart() stream.WriteObjectField(TypeKey) stream.WriteString(enc.Type.String()) stream.WriteMore() stream.WriteObjectField(ValueKey) enc.ValEncoder.Encode(ptr, stream) stream.WriteObjectEnd() } type ( // typeDecoder decorates a value decoder and adds graphson type verification. typeDecoder struct { jsoniter.ValDecoder typeChecker } // typeChecker defines an interface for graphson type verification. typeChecker interface { CheckType(Type) error } // typeCheckerFunc allows the use of functions as type checkers. typeCheckerFunc func(Type) error // typeValue defines a graphson type / value pair. typeValue struct { Type Type Value jsoniter.RawMessage } ) // Decode belongs to jsoniter.ValDecoder interface. 
func (dec typeDecoder) Decode(ptr unsafe.Pointer, iter *jsoniter.Iterator) { if iter.WhatIsNext() != jsoniter.ObjectValue { dec.ValDecoder.Decode(ptr, iter) return } data := iter.SkipAndReturnBytes() if iter.Error != nil && iter.Error != io.EOF { return } var tv typeValue if err := jsoniter.Unmarshal(data, &tv); err != nil { iter.ReportError("unmarshal type value", err.Error()) return } if err := dec.CheckType(tv.Type); err != nil { iter.ReportError("check type", err.Error()) return } it := config.BorrowIterator(tv.Value) defer config.ReturnIterator(it) dec.ValDecoder.Decode(ptr, it) if it.Error != nil && it.Error != io.EOF { iter.ReportError("decode value", it.Error.Error()) } } // UnmarshalJSON implements json.Unmarshaler interface. func (tv *typeValue) UnmarshalJSON(data []byte) error { var v struct { Type *Type `json:"@type"` Value jsoniter.RawMessage `json:"@value"` } if err := jsoniter.Unmarshal(data, &v); err != nil { return err } if v.Type == nil || v.Value == nil { return errors.New("missing type or value") } tv.Type = *v.Type tv.Value = v.Value return nil } // CheckType implements typeChecker interface. func (f typeCheckerFunc) CheckType(typ Type) error { return f(typ) } ent-0.5.4/dialect/gremlin/encoding/graphson/util_test.go000066400000000000000000000063351377533537200233500ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graphson import ( "bytes" "errors" "fmt" "testing" "unsafe" jsoniter "github.com/json-iterator/go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) func TestTypeEncode(t *testing.T) { var got bytes.Buffer stream := config.BorrowStream(&got) defer config.ReturnStream(stream) typ, val := int32Type, 42 ptr := unsafe.Pointer(&val) var m mocker m.On("Encode", ptr, stream). 
Run(func(args mock.Arguments) { stream := args.Get(1).(*jsoniter.Stream) stream.WriteInt(val) }). Once() defer m.AssertExpectations(t) typeEncoder{&m, typ}.Encode(ptr, stream) require.NoError(t, stream.Flush()) want := fmt.Sprintf(`{"@type": "%s", "@value": %d}`, typ, val) assert.JSONEq(t, want, got.String()) } func TestTypeDecode(t *testing.T) { typ, val := int64Type, 84 ptr := unsafe.Pointer(&val) data := fmt.Sprintf(`{"@value": %d, "@type": "%s"}`, val, typ) iter := config.BorrowIterator([]byte(data)) defer config.ReturnIterator(iter) m := &mocker{} m.On("CheckType", typ). Return(nil). Once() m.On("Decode", ptr, mock.Anything). Run(func(args mock.Arguments) { iter := args.Get(1).(*jsoniter.Iterator) assert.Equal(t, val, iter.ReadInt()) }). Once() defer m.AssertExpectations(t) typeDecoder{m, m}.Decode(ptr, iter) assert.NoError(t, iter.Error) } func TestTypeDecodeBadType(t *testing.T) { typ, val := int64Type, 55 ptr := unsafe.Pointer(&val) m := &mocker{} m.On("CheckType", typ).Return(errors.New("bad type")).Once() defer m.AssertExpectations(t) data := fmt.Sprintf(`{"@type": "%s", "@value": %d}`, typ, val) iter := config.BorrowIterator([]byte(data)) defer config.ReturnIterator(iter) typeDecoder{m, m}.Decode(ptr, iter) require.Error(t, iter.Error) assert.Contains(t, iter.Error.Error(), "bad type") } func TestTypeDecodeDuplicateField(t *testing.T) { data := `{"@type": "gx:Byte", "@value": 33, "@type": "g:Int32"}` iter := config.BorrowIterator([]byte(data)) defer config.ReturnIterator(iter) var ptr unsafe.Pointer m := &mocker{} m.On("CheckType", mock.MatchedBy(func(typ Type) bool { return typ == int32Type })). Return(nil). Once() m.On("Decode", ptr, mock.Anything). Run(func(args mock.Arguments) { args.Get(1).(*jsoniter.Iterator).Skip() require.NoError(t, iter.Error) }). 
Once() defer m.AssertExpectations(t) typeDecoder{m, m}.Decode(ptr, iter) assert.NoError(t, iter.Error) } func TestTypeDecodeMissingField(t *testing.T) { data := `{"@type": "g:Int32"}` iter := config.BorrowIterator([]byte(data)) defer config.ReturnIterator(iter) m := &mocker{} defer m.AssertExpectations(t) typeDecoder{m, m}.Decode(nil, iter) require.Error(t, iter.Error) assert.Contains(t, iter.Error.Error(), "missing type or value") } func TestTypeDecodeSyntaxError(t *testing.T) { data := `{"@type": "gx:Int16", "@value", 65000}` iter := config.BorrowIterator([]byte(data)) defer config.ReturnIterator(iter) m := &mocker{} defer m.AssertExpectations(t) typeDecoder{m, m}.Decode(nil, iter) assert.Error(t, iter.Error) } ent-0.5.4/dialect/gremlin/encoding/mime.go000066400000000000000000000012211377533537200204270ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package encoding import ( "bytes" ) // Mime defines a gremlin mime type. type Mime []byte // Graphson mime headers. var ( GraphSON3Mime = NewMime("application/vnd.gremlin-v3.0+json") ) // NewMime creates a wire format mime header. func NewMime(s string) Mime { var buf bytes.Buffer buf.WriteByte(byte(len(s))) buf.WriteString(s) return buf.Bytes() } // String implements fmt.Stringer interface. func (m Mime) String() string { return string(m[1:]) } ent-0.5.4/dialect/gremlin/encoding/mime_test.go000066400000000000000000000010661377533537200214750ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. 
package encoding import ( "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestNewMime(t *testing.T) { str := "application/vnd.gremlin-v2.0+json" mime := NewMime(str) require.Len(t, mime, len(str)+1) assert.EqualValues(t, len(str), mime[0]) assert.EqualValues(t, str, mime[1:]) assert.Equal(t, str, mime.String()) } ent-0.5.4/dialect/gremlin/example_test.go000066400000000000000000000016311377533537200204110ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package gremlin import ( "context" "flag" "log" "os" "time" ) func ExampleClient_Query() { addr := flag.String("gremlin-server", os.Getenv("GREMLIN_SERVER"), "gremlin server address") flag.Parse() if *addr == "" { log.Fatal("missing gremlin server address") } client, err := NewHTTPClient(*addr, nil) if err != nil { log.Fatalf("creating client: %v", err) } ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() rsp, err := client.Query(ctx, "g.E()") if err != nil { log.Fatalf("executing query: %v", err) } edges, err := rsp.ReadEdges() if err != nil { log.Fatalf("unmashal edges") } for _, e := range edges { log.Println(e.String()) } // - Output: } ent-0.5.4/dialect/gremlin/expand.go000066400000000000000000000025401377533537200171760ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package gremlin import ( "context" "sort" "strings" jsoniter "github.com/json-iterator/go" "github.com/pkg/errors" ) // ExpandBindings expands the given RoundTripper and expands the request bindings into the Gremlin traversal. 
func ExpandBindings(rt RoundTripper) RoundTripper { return RoundTripperFunc(func(ctx context.Context, r *Request) (*Response, error) { bindings, ok := r.Arguments[ArgsBindings] if !ok { return rt.RoundTrip(ctx, r) } query, ok := r.Arguments[ArgsGremlin] if !ok { return rt.RoundTrip(ctx, r) } { query, bindings := query.(string), bindings.(map[string]interface{}) keys := make(sort.StringSlice, 0, len(bindings)) for k := range bindings { keys = append(keys, k) } sort.Sort(sort.Reverse(keys)) kv := make([]string, 0, len(bindings)*2) for _, k := range keys { s, err := jsoniter.MarshalToString(bindings[k]) if err != nil { return nil, errors.WithMessagef(err, "marshal bindings value for key %s", k) } kv = append(kv, k, s) } delete(r.Arguments, ArgsBindings) r.Arguments[ArgsGremlin] = strings.NewReplacer(kv...).Replace(query) } return rt.RoundTrip(ctx, r) }) } ent-0.5.4/dialect/gremlin/expand_test.go000066400000000000000000000040751377533537200202420ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. 
package gremlin import ( "context" "strconv" "testing" "github.com/stretchr/testify/assert" ) func TestExpandBindings(t *testing.T) { tests := []struct { req *Request wantErr bool wantQuery string }{ { req: NewEvalRequest("no bindings"), wantQuery: "no bindings", }, { req: NewEvalRequest("g.V($0)", WithBindings(map[string]interface{}{"$0": 1})), wantQuery: "g.V(1)", }, { req: NewEvalRequest("g.V().has($1, $2)", WithBindings(map[string]interface{}{"$1": "name", "$2": "a8m"})), wantQuery: "g.V().has(\"name\", \"a8m\")", }, { req: NewEvalRequest("g.V().limit(n)", WithBindings(map[string]interface{}{"n": 10})), wantQuery: "g.V().limit(10)", }, { req: NewEvalRequest("g.V()", WithBindings(map[string]interface{}{"$0": func() {}})), wantErr: true, }, { req: NewEvalRequest("g.V().has($0, $1)", WithBindings(map[string]interface{}{"$0": "active", "$1": true})), wantQuery: "g.V().has(\"active\", true)", }, { req: NewEvalRequest("g.V().has($1, $11)", WithBindings(map[string]interface{}{"$1": "active", "$11": true})), wantQuery: "g.V().has(\"active\", true)", }, } for i, tt := range tests { tt := tt t.Run(strconv.Itoa(i), func(t *testing.T) { rt := ExpandBindings(RoundTripperFunc(func(ctx context.Context, r *Request) (*Response, error) { assert.Equal(t, tt.wantQuery, r.Arguments[ArgsGremlin]) return nil, nil })) _, err := rt.RoundTrip(context.Background(), tt.req) assert.Equal(t, tt.wantErr, err != nil) }) } } func TestExpandBindingsNoQuery(t *testing.T) { rt := ExpandBindings(RoundTripperFunc(func(ctx context.Context, r *Request) (*Response, error) { return nil, nil })) _, err := rt.RoundTrip(context.Background(), &Request{Arguments: map[string]interface{}{ ArgsBindings: map[string]interface{}{}, }}) assert.NoError(t, err) } 
ent-0.5.4/dialect/gremlin/graph/000077500000000000000000000000001377533537200164705ustar00rootroot00000000000000ent-0.5.4/dialect/gremlin/graph/dsl/000077500000000000000000000000001377533537200172525ustar00rootroot00000000000000ent-0.5.4/dialect/gremlin/graph/dsl/__/000077500000000000000000000000001377533537200176275ustar00rootroot00000000000000ent-0.5.4/dialect/gremlin/graph/dsl/__/dsl.go000066400000000000000000000054771377533537200207550ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package __ import "github.com/facebook/ent/dialect/gremlin/graph/dsl" // As is the api for calling __.As(). func As(args ...interface{}) *dsl.Traversal { return New().As(args...) } // Is is the api for calling __.Is(). func Is(args ...interface{}) *dsl.Traversal { return New().Is(args...) } // Not is the api for calling __.Not(). func Not(args ...interface{}) *dsl.Traversal { return New().Not(args...) } // Has is the api for calling __.Has(). func Has(args ...interface{}) *dsl.Traversal { return New().Has(args...) } // HasNot is the api for calling __.HasNot(). func HasNot(args ...interface{}) *dsl.Traversal { return New().HasNot(args...) } // Or is the api for calling __.Or(). func Or(args ...interface{}) *dsl.Traversal { return New().Or(args...) } // And is the api for calling __.And(). func And(args ...interface{}) *dsl.Traversal { return New().And(args...) } // In is the api for calling __.In(). func In(args ...interface{}) *dsl.Traversal { return New().In(args...) } // Out is the api for calling __.Out(). func Out(args ...interface{}) *dsl.Traversal { return New().Out(args...) } // OutE is the api for calling __.OutE(). func OutE(args ...interface{}) *dsl.Traversal { return New().OutE(args...) } // InE is the api for calling __.InE(). func InE(args ...interface{}) *dsl.Traversal { return New().InE(args...) 
} // InV is the api for calling __.InV(). func InV(args ...interface{}) *dsl.Traversal { return New().InV(args...) } // V is the api for calling __.V(). func V(args ...interface{}) *dsl.Traversal { return New().V(args...) } // OutV is the api for calling __.OutV(). func OutV(args ...interface{}) *dsl.Traversal { return New().OutV(args...) } // Values is the api for calling __.Values(). func Values(args ...string) *dsl.Traversal { return New().Values(args...) } // Union is the api for calling __.Union(). func Union(args ...interface{}) *dsl.Traversal { return New().Union(args...) } // Constant is the api for calling __.Constant(). func Constant(args ...interface{}) *dsl.Traversal { return New().Constant(args...) } // Properties is the api for calling __.Properties(). func Properties(args ...interface{}) *dsl.Traversal { return New().Properties(args...) } // OtherV is the api for calling __.OtherV(). func OtherV() *dsl.Traversal { return New().OtherV() } // Count is the api for calling __.Count(). func Count() *dsl.Traversal { return New().Count() } // Drop is the api for calling __.Drop(). func Drop() *dsl.Traversal { return New().Drop() } // Fold is the api for calling __.Fold(). func Fold() *dsl.Traversal { return New().Fold() } func New() *dsl.Traversal { return new(dsl.Traversal).Add(dsl.Token("__")) } ent-0.5.4/dialect/gremlin/graph/dsl/dsl.go000066400000000000000000000114011377533537200203600ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. // Package dsl provide an API for writing gremlin dsl queries almost as-is // in Go without using strings in the code. // // Note that, the API is not type-safe and assume the provided query and // its arguments are valid. package dsl import ( "fmt" "strings" "time" ) // Node represents a DSL step in the traversal. 
type Node interface { // Code returns the code representation of the element and its bindings (if any). Code() (string, []interface{}) } type ( // Token holds a simple token, like assignment. Token string // List represents a list of elements. List struct { Elements []interface{} } // Func represents a function call. Func struct { Name string Args []interface{} } // Block represents a block/group of nodes. Block struct { Nodes []interface{} } // Var represents a variable assignment and usage. Var struct { Name string Elem interface{} } ) // Code stringified the token. func (t Token) Code() (string, []interface{}) { return string(t), nil } // Code returns the code representation of a list. func (l List) Code() (string, []interface{}) { c, args := codeList(", ", l.Elements...) return fmt.Sprintf("[%s]", c), args } // Code returns the code representation of a function call. func (f Func) Code() (string, []interface{}) { c, args := codeList(", ", f.Args...) return fmt.Sprintf("%s(%s)", f.Name, c), args } // Code returns the code representation of group/block of nodes. func (b Block) Code() (string, []interface{}) { return codeList("; ", b.Nodes...) } // Code returns the code representation of variable declaration or its identifier. func (v Var) Code() (string, []interface{}) { c, args := code(v.Elem) if v.Name == "" { return c, args } return fmt.Sprintf("%s = %s", v.Name, c), args } // predefined nodes. var ( G = Token("g") Dot = Token(".") ) // NewFunc returns a new function node. func NewFunc(name string, args ...interface{}) *Func { return &Func{Name: name, Args: args} } // NewList returns a new list node. func NewList(args ...interface{}) *List { return &List{Elements: args} } // Querier is the interface that wraps the Query method. type Querier interface { // Query returns the query-string (similar to the Gremlin byte-code) and its bindings. Query() (string, Bindings) } // Bindings are used to associate a variable with a value. 
type Bindings map[string]interface{} // Add adds new value to the bindings map, formats it if needed, and returns its generated name. func (b Bindings) Add(v interface{}) string { k := fmt.Sprintf("$%x", len(b)) switch v := v.(type) { case time.Time: b[k] = v.UnixNano() default: b[k] = v } return k } // Cardinality of vertex properties. type Cardinality string // Cardinality options. const ( Set Cardinality = "set" Single Cardinality = "single" ) // Code implements the Node interface. func (c Cardinality) Code() (string, []interface{}) { return string(c), nil } // Order of vertex properties. type Order string // Order options. const ( Incr Order = "incr" Decr Order = "decr" Shuffle Order = "shuffle" ) // Code implements the Node interface. func (o Order) Code() (string, []interface{}) { return string(o), nil } // Column references a particular type of column in a complex data structure such as a Map, a Map.Entry, or a Path. type Column string // Column options. const ( Keys Column = "keys" Values Column = "values" ) // Code implements the Node interface. func (o Column) Code() (string, []interface{}) { return string(o), nil } // Scope used for steps that have a variable scope which alter the manner in which the step will behave in relation to how the traverses are processed. type Scope string // Scope options. const ( Local Scope = "local" Global Scope = "global" ) // Code implements the Node interface. func (s Scope) Code() (string, []interface{}) { return string(s), nil } func codeList(sep string, vs ...interface{}) (string, []interface{}) { var ( br strings.Builder args []interface{} ) for i, node := range vs { if i > 0 { br.WriteString(sep) } c, nargs := code(node) br.WriteString(c) args = append(args, nargs...) 
} return br.String(), args } func code(v interface{}) (string, []interface{}) { switch n := v.(type) { case Node: return n.Code() case *Traversal: var ( b strings.Builder args []interface{} ) for i := range n.nodes { code, nargs := n.nodes[i].Code() b.WriteString(code) args = append(args, nargs...) } return b.String(), args default: return "%s", []interface{}{v} } } func sface(args []string) (v []interface{}) { for _, s := range args { v = append(v, s) } return } ent-0.5.4/dialect/gremlin/graph/dsl/dsl_test.go000066400000000000000000000201671377533537200214300ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package dsl_test import ( "strconv" "testing" "github.com/facebook/ent/dialect/gremlin/graph/dsl" "github.com/facebook/ent/dialect/gremlin/graph/dsl/__" "github.com/facebook/ent/dialect/gremlin/graph/dsl/g" "github.com/facebook/ent/dialect/gremlin/graph/dsl/p" "github.com/stretchr/testify/require" ) func TestTraverse(t *testing.T) { tests := []struct { input dsl.Querier wantQuery string wantBinds dsl.Bindings }{ { input: g.V(5), wantQuery: "g.V($0)", wantBinds: dsl.Bindings{"$0": 5}, }, { input: g.V(2).Both("knows"), wantQuery: "g.V($0).both($1)", wantBinds: dsl.Bindings{"$0": 2, "$1": "knows"}, }, { input: g.V(49).BothE("knows").OtherV().ValueMap(), wantQuery: "g.V($0).bothE($1).otherV().valueMap()", wantBinds: dsl.Bindings{"$0": 49, "$1": "knows"}, }, { input: g.AddV("person").Property("name", "a8m").Next(), wantQuery: "g.addV($0).property($1, $2).next()", wantBinds: dsl.Bindings{"$0": "person", "$1": "name", "$2": "a8m"}, }, { input: dsl.Each([]interface{}{1, 2, 3}, func(it *dsl.Traversal) *dsl.Traversal { return g.V(it) }), wantQuery: "[$0, $1, $2].each { g.V(it) }", wantBinds: dsl.Bindings{"$0": 1, "$1": 2, "$2": 3}, }, { input: dsl.Each([]interface{}{g.V(1).Next()}, func(it 
*dsl.Traversal) *dsl.Traversal { return it.ID() }), wantQuery: "[g.V($0).next()].each { it.id() }", wantBinds: dsl.Bindings{"$0": 1}, }, { input: g.AddV("person").AddE("knows").To(g.V(2)), wantQuery: "g.addV($0).addE($1).to(g.V($2))", wantBinds: dsl.Bindings{"$0": "person", "$1": "knows", "$2": 2}, }, { input: func() *dsl.Traversal { v1 := g.V(2).Next() v2 := g.AddV("person").Property("name", "a8m") e1 := g.V(v1).AddE("knows").To(v2) return dsl.Group(v1, v2, e1) }(), wantQuery: "t0 = g.V($0).next(); t1 = g.addV($1).property($2, $3); t2 = g.V(t0).addE($4).to(t1); t2", wantBinds: dsl.Bindings{"$0": 2, "$1": "person", "$2": "name", "$3": "a8m", "$4": "knows"}, }, { input: func() *dsl.Traversal { v1 := g.AddV("person") each := dsl.Each([]interface{}{1, 2, 3}, func(it *dsl.Traversal) *dsl.Traversal { return g.V(v1).AddE("knows").To(g.V(it)).Next() }) return dsl.Group(v1, each) }(), wantQuery: "t0 = g.addV($0); t1 = [$1, $2, $3].each { g.V(t0).addE($4).to(g.V(it)).next() }; t1", wantBinds: dsl.Bindings{"$0": "person", "$1": 1, "$2": 2, "$3": 3, "$4": "knows"}, }, { input: g.V().HasLabel("person"). 
Choose(__.Values("age").Is(p.LTE(20))), wantQuery: "g.V().hasLabel($0).choose(__.values($1).is(lte($2)))", wantBinds: dsl.Bindings{"$0": "person", "$1": "age", "$2": 20}, }, { input: g.AddV("person").Property("name", "a8m").Properties(), wantQuery: "g.addV($0).property($1, $2).properties()", wantBinds: dsl.Bindings{"$0": "person", "$1": "name", "$2": "a8m"}, }, { input: func() *dsl.Traversal { v1 := g.AddV("person").Next() e1 := g.V(v1).AddE("knows").To(g.V(2).Next()) return dsl.Group(v1, e1, g.V(v1).ValueMap(true)) }(), wantQuery: "t0 = g.addV($0).next(); t1 = g.V(t0).addE($1).to(g.V($2).next()); t2 = g.V(t0).valueMap($3); t2", wantBinds: dsl.Bindings{"$0": "person", "$1": "knows", "$2": 2, "$3": true}, }, { input: func() *dsl.Traversal { vs := g.V().HasLabel("person").ToList() edge := g.V(vs).AddE("assoc").To(g.V(1)).Iterate() each := dsl.Each(vs, func(it *dsl.Traversal) *dsl.Traversal { return g.V(1).AddE("inverse").To(it).Next() }) return dsl.Group(vs, edge, each) }(), wantQuery: "t0 = g.V().hasLabel($0).toList(); t1 = g.V(t0).addE($1).to(g.V($2)).iterate(); t2 = t0.each { g.V($3).addE($4).to(it).next() }; t2", wantBinds: dsl.Bindings{"$0": "person", "$1": "assoc", "$2": 1, "$3": 1, "$4": "inverse"}, }, { input: g.V().Where(__.Or(__.Has("age", 29), __.Has("age", 30))), wantQuery: "g.V().where(__.or(__.has($0, $1), __.has($2, $3)))", wantBinds: dsl.Bindings{"$0": "age", "$1": 29, "$2": "age", "$3": 30}, }, { input: g.V().Has("name", p.Containing("le")).Has("name", p.StartingWith("A")), wantQuery: `g.V().has($0, containing($1)).has($2, startingWith($3))`, wantBinds: dsl.Bindings{"$0": "name", "$1": "le", "$2": "name", "$3": "A"}, }, { input: g.AddV().Property(dsl.Single, "age", 32).ValueMap(), wantQuery: "g.addV().property(single, $0, $1).valueMap()", wantBinds: dsl.Bindings{"$0": "age", "$1": 32}, }, { input: g.V().Count(), wantQuery: "g.V().count()", wantBinds: dsl.Bindings{}, }, { input: g.V().HasNot("age"), wantQuery: "g.V().hasNot($0)", wantBinds: 
dsl.Bindings{"$0": "age"}, }, { input: func() *dsl.Traversal { v := g.V().HasID(1) u := v.Clone().InE().Drop() return dsl.Join(v, u) }(), wantQuery: "g.V().hasId($0); g.V().hasId($1).inE().drop()", wantBinds: dsl.Bindings{"$0": 1, "$1": 1}, }, { input: func() *dsl.Traversal { v := g.V().HasID(1) u := v.Clone().InE().Drop() w := u.Clone() return dsl.Join(v, u, w) }(), wantQuery: "g.V().hasId($0); g.V().hasId($1).inE().drop(); g.V().hasId($2).inE().drop()", wantBinds: dsl.Bindings{"$0": 1, "$1": 1, "$2": 1}, }, { input: g.V().OutE("knows").Where(__.InV().Has("name", "a8m")).OutV(), wantQuery: "g.V().outE($0).where(__.inV().has($1, $2)).outV()", wantBinds: dsl.Bindings{"$0": "knows", "$1": "name", "$2": "a8m"}, }, { input: g.V().Has("name", p.Within("a8m", "alex")), wantQuery: "g.V().has($0, within($1, $2))", wantBinds: dsl.Bindings{"$0": "name", "$1": "a8m", "$2": "alex"}, }, { input: g.V().HasID(p.Within(1, 2)), wantQuery: "g.V().hasId(within($0, $1))", wantBinds: dsl.Bindings{"$0": 1, "$1": 2}, }, { input: g.V().HasID(p.Without(1, 2)), wantQuery: "g.V().hasId(without($0, $1))", wantBinds: dsl.Bindings{"$0": 1, "$1": 2}, }, { input: g.V().Order().By("name"), wantQuery: "g.V().order().by($0)", wantBinds: dsl.Bindings{"$0": "name"}, }, { input: g.V().Order().By("name", dsl.Incr), wantQuery: "g.V().order().by($0, incr)", wantBinds: dsl.Bindings{"$0": "name"}, }, { input: g.V().Order().By("name", dsl.Incr).Undo(), wantQuery: "g.V().order()", wantBinds: dsl.Bindings{}, }, { input: g.V().OutE("knows").Where(__.InV().Has("name", "a8m")).Undo(), wantQuery: "g.V().outE($0)", wantBinds: dsl.Bindings{"$0": "knows"}, }, { input: g.V().Has("name").Group().By("name").By("age").Select(dsl.Values), wantQuery: "g.V().has($0).group().by($1).by($2).select(values)", wantBinds: dsl.Bindings{"$0": "name", "$1": "name", "$2": "age"}, }, { input: g.V().Fold().Unfold(), wantQuery: "g.V().fold().unfold()", wantBinds: dsl.Bindings{}, }, { input: g.V().Has("person", "name", 
"a8m").Count().Coalesce( __.Is(p.NEQ(0)).Constant("unique constraint failed"), g.AddV("person").Property("name", "a8m").ValueMap(true), ), wantQuery: "g.V().has($0, $1, $2).count().coalesce(__.is(neq($3)).constant($4), g.addV($5).property($6, $7).valueMap($8))", wantBinds: dsl.Bindings{"$0": "person", "$1": "name", "$2": "a8m", "$3": 0, "$4": "unique constraint failed", "$5": "person", "$6": "name", "$7": "a8m", "$8": true}, }, { input: g.V().Has("age").Property("age", __.Union(__.Values("age"), __.Constant(10)).Sum()).ValueMap(), wantQuery: "g.V().has($0).property($1, __.union(__.values($2), __.constant($3)).sum()).valueMap()", wantBinds: dsl.Bindings{"$0": "age", "$1": "age", "$2": "age", "$3": 10}, }, { input: g.V().Has("age").SideEffect(__.Properties("name").Drop()).ValueMap(), wantQuery: "g.V().has($0).sideEffect(__.properties($1).drop()).valueMap()", wantBinds: dsl.Bindings{"$0": "age", "$1": "name"}, }, } for i, tt := range tests { tt := tt t.Run(strconv.Itoa(i), func(t *testing.T) { query, bindings := tt.input.Query() require.Equal(t, tt.wantQuery, query) require.Equal(t, tt.wantBinds, bindings) }) } } ent-0.5.4/dialect/gremlin/graph/dsl/g/000077500000000000000000000000001377533537200175005ustar00rootroot00000000000000ent-0.5.4/dialect/gremlin/graph/dsl/g/g.go000066400000000000000000000014011377533537200202510ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package g import "github.com/facebook/ent/dialect/gremlin/graph/dsl" // V is the api for calling g.V(). func V(args ...interface{}) *dsl.Traversal { return dsl.NewTraversal().V(args...) } // E is the api for calling g.E(). func E(args ...interface{}) *dsl.Traversal { return dsl.NewTraversal().E(args...) } // AddV is the api for calling g.AddV(). func AddV(args ...interface{}) *dsl.Traversal { return dsl.NewTraversal().AddV(args...) 
} // AddE is the api for calling g.AddE(). func AddE(args ...interface{}) *dsl.Traversal { return dsl.NewTraversal().AddE(args...) } ent-0.5.4/dialect/gremlin/graph/dsl/p/000077500000000000000000000000001377533537200175115ustar00rootroot00000000000000ent-0.5.4/dialect/gremlin/graph/dsl/p/p.go000066400000000000000000000043321377533537200203010ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package p import ( "github.com/facebook/ent/dialect/gremlin/graph/dsl" ) // EQ is the equal predicate. func EQ(v interface{}) *dsl.Traversal { return op("eq", v) } // NEQ is the not-equal predicate. func NEQ(v interface{}) *dsl.Traversal { return op("neq", v) } // GT is the greater than predicate. func GT(v interface{}) *dsl.Traversal { return op("gt", v) } // GTE is the greater than or equal predicate. func GTE(v interface{}) *dsl.Traversal { return op("gte", v) } // LT is the less than predicate. func LT(v interface{}) *dsl.Traversal { return op("lt", v) } // LTE is the less than or equal predicate. func LTE(v interface{}) *dsl.Traversal { return op("lte", v) } // Between is the between/contains predicate. func Between(v, u interface{}) *dsl.Traversal { return op("between", v, u) } // StartingWith is the prefix test predicate. func StartingWith(prefix string) *dsl.Traversal { return op("startingWith", prefix) } // EndingWith is the suffix test predicate. func EndingWith(suffix string) *dsl.Traversal { return op("endingWith", suffix) } // Containing is the sub string test predicate. func Containing(substr string) *dsl.Traversal { return op("containing", substr) } // NotStartingWith is the negation of StartingWith. func NotStartingWith(prefix string) *dsl.Traversal { return op("notStartingWith", prefix) } // NotEndingWith is the negation of EndingWith. 
func NotEndingWith(suffix string) *dsl.Traversal { return op("notEndingWith", suffix) } // NotContaining is the negation of Containing. func NotContaining(substr string) *dsl.Traversal { return op("notContaining", substr) } // Within Determines if a value is within the specified list of values. func Within(args ...interface{}) *dsl.Traversal { return op("within", args...) } // Without determines if a value is not within the specified list of values. func Without(args ...interface{}) *dsl.Traversal { return op("without", args...) } func op(name string, args ...interface{}) *dsl.Traversal { t := &dsl.Traversal{} return t.Add(dsl.NewFunc(name, args...)) } ent-0.5.4/dialect/gremlin/graph/dsl/traversal.go000066400000000000000000000326371377533537200216170ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package dsl import ( "fmt" "strings" ) // Traversal mimics the TinkerPop graph traversal. type Traversal struct { // nodes holds the dsl nodes. first element is the reference name // of the TinkerGraph. defaults to "g". nodes []Node } // NewTraversal returns a new default traversal with "g" as a reference name to the Graph. func NewTraversal() *Traversal { return &Traversal{[]Node{G}} } // Group groups a list of traversals into one. all traversals are assigned into a temporary // variables named by their index. The last variable functions as a return value of the query. // Note that, this "temporary hack" is not perfect and may not work in some cases because of // the limitation of evaluation order. 
func Group(trs ...*Traversal) *Traversal { var ( b = Block{} names = make(map[*Traversal]Token) ) for i, tr := range trs { if _, ok := names[tr]; ok { continue } v := &Var{Name: fmt.Sprintf("t%d", i), Elem: &Traversal{nodes: tr.nodes}} b.Nodes = append(b.Nodes, v) names[tr] = Token(v.Name) } for _, tr := range trs { tr.nodes = []Node{names[tr]} } b.Nodes = append(b.Nodes, names[trs[len(trs)-1]]) return &Traversal{[]Node{b}} } // Join joins a list of traversals with a semicolon separator. func Join(trs ...*Traversal) *Traversal { b := Block{} for _, tr := range trs { b.Nodes = append(b.Nodes, &Traversal{nodes: tr.nodes}) } return &Traversal{[]Node{b}} } // V step is usually used to start a traversal but it may also be used mid-traversal. func (t *Traversal) V(args ...interface{}) *Traversal { t.Add(Dot, NewFunc("V", args...)) return t } // OtherV maps the Edge to the incident vertex that was not just traversed from in the path history. func (t *Traversal) OtherV() *Traversal { t.Add(Dot, NewFunc("otherV")) return t } // E step is usually used to start a traversal but it may also be used mid-traversal. func (t *Traversal) E(args ...interface{}) *Traversal { t.Add(Dot, NewFunc("E", args...)) return t } // AddV adds a vertex. func (t *Traversal) AddV(args ...interface{}) *Traversal { t.Add(Dot, NewFunc("addV", args...)) return t } // AddE adds an edge. func (t *Traversal) AddE(args ...interface{}) *Traversal { t.Add(Dot, NewFunc("addE", args...)) return t } // Next gets the next n-number of results from the traversal. func (t *Traversal) Next() *Traversal { return t.Add(Dot, NewFunc("next")) } // Drop removes elements and properties from the graph. func (t *Traversal) Drop() *Traversal { return t.Add(Dot, NewFunc("drop")) } // Property sets a Property value and related meta properties if supplied, // if supported by the Graph and if the Element is a VertexProperty. 
func (t *Traversal) Property(args ...interface{}) *Traversal { return t.Add(Dot, NewFunc("property", args...)) } // Both maps the Vertex to its adjacent vertices given the edge labels. func (t *Traversal) Both(args ...interface{}) *Traversal { return t.Add(Dot, NewFunc("both", args...)) } // BothE maps the Vertex to its incident edges given the edge labels. func (t *Traversal) BothE(args ...interface{}) *Traversal { return t.Add(Dot, NewFunc("bothE", args...)) } // Has filters vertices, edges and vertex properties based on their properties. // See: http://tinkerpop.apache.org/docs/current/reference/#has-step. func (t *Traversal) Has(args ...interface{}) *Traversal { return t.Add(Dot, NewFunc("has", args...)) } // HasNot filters vertices, edges and vertex properties based on the non-existence of properties. // See: http://tinkerpop.apache.org/docs/current/reference/#has-step. func (t *Traversal) HasNot(args ...interface{}) *Traversal { return t.Add(Dot, NewFunc("hasNot", args...)) } // HasID filters vertices, edges and vertex properties based on their identifier. func (t *Traversal) HasID(args ...interface{}) *Traversal { return t.Add(Dot, NewFunc("hasId", args...)) } // HasLabel filters vertices, edges and vertex properties based on their label. func (t *Traversal) HasLabel(args ...interface{}) *Traversal { return t.Add(Dot, NewFunc("hasLabel", args...)) } // HasNext returns true if the iteration has more elements. func (t *Traversal) HasNext() *Traversal { return t.Add(Dot, NewFunc("hasNext")) } // Match maps the Traverser to a Map of bindings as specified by the provided match traversals. func (t *Traversal) Match(args ...interface{}) *Traversal { return t.Add(Dot, NewFunc("match", args...)) } // Choose routes the current traverser to a particular traversal branch option which allows the creation of if-then-else like semantics within a traversal. 
func (t *Traversal) Choose(args ...interface{}) *Traversal { return t.Add(Dot, NewFunc("choose", args...)) } // Select arbitrary values from the traversal. func (t *Traversal) Select(args ...interface{}) *Traversal { return t.Add(Dot, NewFunc("select", args...)) } // Group organizes objects in the stream into a Map.Calls to group() are typically accompanied with by() modulators which help specify how the grouping should occur. func (t *Traversal) Group() *Traversal { return t.Add(Dot, NewFunc("group")) } // Values maps the Element to the values of the associated properties given the provide property keys. func (t *Traversal) Values(args ...string) *Traversal { return t.Add(Dot, NewFunc("values", sface(args)...)) } // ValueMap maps the Element to a Map of the property values key'd according to their Property.key(). func (t *Traversal) ValueMap(args ...interface{}) *Traversal { return t.Add(Dot, NewFunc("valueMap", args...)) } // Properties maps the Element to its associated properties given the provide property keys. func (t *Traversal) Properties(args ...interface{}) *Traversal { return t.Add(Dot, NewFunc("properties", args...)) } // Range filters the objects in the traversal by the number of them to pass through the stream. func (t *Traversal) Range(args ...interface{}) *Traversal { return t.Add(Dot, NewFunc("range", args...)) } // Limit filters the objects in the traversal by the number of them to pass through the stream, where only the first n objects are allowed as defined by the limit argument. func (t *Traversal) Limit(args ...interface{}) *Traversal { return t.Add(Dot, NewFunc("limit", args...)) } // ID maps the Element to its Element.id(). func (t *Traversal) ID() *Traversal { return t.Add(Dot, NewFunc("id")) } // Label maps the Element to its Element.label(). func (t *Traversal) Label() *Traversal { return t.Add(Dot, NewFunc("label")) } // From provides from()-modulation to respective steps. 
func (t *Traversal) From(args ...interface{}) *Traversal { return t.Add(Dot, NewFunc("from", args...)) } // To used as a modifier to addE(String) this method specifies the traversal to use for selecting the incoming vertex of the newly added Edge. func (t *Traversal) To(args ...interface{}) *Traversal { return t.Add(Dot, NewFunc("to", args...)) } // As provides a label to the step that can be accessed later in the traversal by other steps. func (t *Traversal) As(args ...interface{}) *Traversal { return t.Add(Dot, NewFunc("as", args...)) } // Or ensures that at least one of the provided traversals yield a result. func (t *Traversal) Or(args ...interface{}) *Traversal { return t.Add(Dot, NewFunc("or", args...)) } // And ensures that all of the provided traversals yield a result. func (t *Traversal) And(args ...interface{}) *Traversal { return t.Add(Dot, NewFunc("and", args...)) } // Is filters the E object if it is not P.eq(V) to the provided value. func (t *Traversal) Is(args ...interface{}) *Traversal { return t.Add(Dot, NewFunc("is", args...)) } // Not removes objects from the traversal stream when the traversal provided as an argument does not return any objects. func (t *Traversal) Not(args ...interface{}) *Traversal { return t.Add(Dot, NewFunc("not", args...)) } // In maps the Vertex to its incoming adjacent vertices given the edge labels. func (t *Traversal) In(args ...interface{}) *Traversal { return t.Add(Dot, NewFunc("in", args...)) } // Where filters the current object based on the object itself or the path history. func (t *Traversal) Where(args ...interface{}) *Traversal { return t.Add(Dot, NewFunc("where", args...)) } // Out maps the Vertex to its outgoing adjacent vertices given the edge labels. func (t *Traversal) Out(args ...interface{}) *Traversal { return t.Add(Dot, NewFunc("out", args...)) } // OutE maps the Vertex to its outgoing incident edges given the edge labels. 
func (t *Traversal) OutE(args ...interface{}) *Traversal { return t.Add(Dot, NewFunc("outE", args...)) } // InE maps the Vertex to its incoming incident edges given the edge labels. func (t *Traversal) InE(args ...interface{}) *Traversal { return t.Add(Dot, NewFunc("inE", args...)) } // OutV maps the Edge to its outgoing/tail incident Vertex. func (t *Traversal) OutV(args ...interface{}) *Traversal { return t.Add(Dot, NewFunc("outV", args...)) } // InV maps the Edge to its incoming/head incident Vertex. func (t *Traversal) InV(args ...interface{}) *Traversal { return t.Add(Dot, NewFunc("inV", args...)) } // ToList puts all the results into a Groovy list. func (t *Traversal) ToList() *Traversal { return t.Add(Dot, NewFunc("toList")) } // Iterate iterates the traversal presumably for the generation of side-effects. func (t *Traversal) Iterate() *Traversal { return t.Add(Dot, NewFunc("iterate")) } // Count maps the traversal stream to its reduction as a sum of the Traverser.bulk() values // (i.e. count the number of traversers up to this point). func (t *Traversal) Count(args ...interface{}) *Traversal { return t.Add(Dot, NewFunc("count", args...)) } // Order all the objects in the traversal up to this point and then emit them one-by-one in their ordered sequence. func (t *Traversal) Order(args ...interface{}) *Traversal { return t.Add(Dot, NewFunc("order", args...)) } // By can be applied to a number of different step to alter their behaviors. // This form is essentially an identity() modulation. func (t *Traversal) By(args ...interface{}) *Traversal { return t.Add(Dot, NewFunc("by", args...)) } // Fold rolls up objects in the stream into an aggregate list.. func (t *Traversal) Fold() *Traversal { return t.Add(Dot, NewFunc("fold")) } // Unfold unrolls a Iterator, Iterable or Map into a linear form or simply emits the object if it is not one of those types. 
func (t *Traversal) Unfold() *Traversal { return t.Add(Dot, NewFunc("unfold")) } // Sum maps the traversal stream to its reduction as a sum of the Traverser.get() values multiplied by their Traverser.bulk(). func (t *Traversal) Sum(args ...interface{}) *Traversal { return t.Add(Dot, NewFunc("sum", args...)) } // Mean determines the mean value in the stream. func (t *Traversal) Mean(args ...interface{}) *Traversal { return t.Add(Dot, NewFunc("mean", args...)) } // Min determines the smallest value in the stream. func (t *Traversal) Min(args ...interface{}) *Traversal { return t.Add(Dot, NewFunc("min", args...)) } // Max determines the greatest value in the stream. func (t *Traversal) Max(args ...interface{}) *Traversal { return t.Add(Dot, NewFunc("max", args...)) } // Coalesce evaluates the provided traversals and returns the result of the first traversal to emit at least one object. func (t *Traversal) Coalesce(args ...interface{}) *Traversal { return t.Add(Dot, NewFunc("coalesce", args...)) } // Dedup removes all duplicates in the traversal stream up to this point. func (t *Traversal) Dedup(args ...interface{}) *Traversal { return t.Add(Dot, NewFunc("dedup", args...)) } // Constant maps any object to a fixed E value. func (t *Traversal) Constant(args ...interface{}) *Traversal { return t.Add(Dot, NewFunc("constant", args...)) } // Union merges the results of an arbitrary number of traversals. func (t *Traversal) Union(args ...interface{}) *Traversal { return t.Add(Dot, NewFunc("union", args...)) } // SideEffect allows the traverser to proceed unchanged, but yield some computational // sideEffect in the process. func (t *Traversal) SideEffect(args ...interface{}) *Traversal { return t.Add(Dot, NewFunc("sideEffect", args...)) } // Each is a Groovy each-loop function. 
func Each(v interface{}, cb func(it *Traversal) *Traversal) *Traversal { t := &Traversal{} switch v := v.(type) { case *Traversal: t.Add(&Var{Elem: v}) case []interface{}: t.Add(NewList(v...)) default: t.Add(Token("undefined")) } t.Add(Dot, Token("each"), Token(" { ")) t.Add(cb(&Traversal{[]Node{Token("it")}}).nodes...) t.Add(Token(" }")) return t } // Add is the public API for adding new nodes to the traversal by its sub packages. func (t *Traversal) Add(n ...Node) *Traversal { t.nodes = append(t.nodes, n...) return t } // Query returns the query-representation and its binding of this traversal object. func (t *Traversal) Query() (string, Bindings) { var ( names []interface{} query strings.Builder bindings = Bindings{} ) for _, n := range t.nodes { code, args := n.Code() query.WriteString(code) for _, arg := range args { names = append(names, bindings.Add(arg)) } } return fmt.Sprintf(query.String(), names...), bindings } // Clone creates a deep copy of an existing traversal. func (t *Traversal) Clone() *Traversal { if t == nil { return nil } return &Traversal{append(make([]Node, 0, len(t.nodes)), t.nodes...)} } // Undo reverts the last-step of the traversal. func (t *Traversal) Undo() *Traversal { if n := len(t.nodes); n > 2 { t.nodes = t.nodes[:n-2] } return t } ent-0.5.4/dialect/gremlin/graph/edge.go000066400000000000000000000043511377533537200177260ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graph import ( "fmt" "github.com/facebook/ent/dialect/gremlin/encoding/graphson" "github.com/pkg/errors" ) type ( // An Edge between two vertices. Edge struct { Element OutV, InV Vertex } // graphson edge repr. 
edge struct { Element OutV interface{} `json:"outV"` OutVLabel string `json:"outVLabel"` InV interface{} `json:"inV"` InVLabel string `json:"inVLabel"` } ) // NewEdge create a new graph edge. func NewEdge(id interface{}, label string, outV, inV Vertex) Edge { return Edge{ Element: NewElement(id, label), OutV: outV, InV: inV, } } // String implements fmt.Stringer interface. func (e Edge) String() string { return fmt.Sprintf("e[%v][%v-%s->%v]", e.ID, e.OutV.ID, e.Label, e.InV.ID) } // MarshalGraphson implements graphson.Marshaler interface. func (e Edge) MarshalGraphson() ([]byte, error) { return graphson.Marshal(edge{ Element: e.Element, OutV: e.OutV.ID, OutVLabel: e.OutV.Label, InV: e.InV.ID, InVLabel: e.InV.Label, }) } // UnmarshalGraphson implements graphson.Unmarshaler interface. func (e *Edge) UnmarshalGraphson(data []byte) error { var edge edge if err := graphson.Unmarshal(data, &edge); err != nil { return errors.Wrap(err, "unmarshaling edge") } *e = NewEdge( edge.ID, edge.Label, NewVertex(edge.OutV, edge.OutVLabel), NewVertex(edge.InV, edge.InVLabel), ) return nil } // GraphsonType implements graphson.Typer interface. func (edge) GraphsonType() graphson.Type { return "g:Edge" } // Property denotes a key/value pair associated with an edge. type Property struct { Key string `json:"key"` Value interface{} `json:"value"` } // NewProperty create a new graph edge property. func NewProperty(key string, value interface{}) Property { return Property{key, value} } // GraphsonType implements graphson.Typer interface. func (Property) GraphsonType() graphson.Type { return "g:Property" } // String implements fmt.Stringer interface. func (p Property) String() string { return fmt.Sprintf("p[%s->%v]", p.Key, p.Value) } ent-0.5.4/dialect/gremlin/graph/edge_test.go000066400000000000000000000041171377533537200207650ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. 
// This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graph import ( "fmt" "testing" "github.com/facebook/ent/dialect/gremlin/encoding/graphson" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestEdgeString(t *testing.T) { e := NewEdge( 13, "develops", NewVertex(1, ""), NewVertex(10, ""), ) assert.Equal(t, "e[13][1-develops->10]", fmt.Sprint(e)) } func TestEdgeEncoding(t *testing.T) { t.Parallel() e := NewEdge(13, "develops", NewVertex(1, "person"), NewVertex(10, "software"), ) got, err := graphson.MarshalToString(e) require.NoError(t, err) want := `{ "@type" : "g:Edge", "@value" : { "id" : { "@type" : "g:Int64", "@value" : 13 }, "label" : "develops", "inVLabel" : "software", "outVLabel" : "person", "inV" : { "@type" : "g:Int64", "@value" : 10 }, "outV" : { "@type" : "g:Int64", "@value" : 1 } } }` assert.JSONEq(t, want, got) e = Edge{} err = graphson.UnmarshalFromString(got, &e) require.NoError(t, err) assert.Equal(t, NewElement(int64(13), "develops"), e.Element) assert.Equal(t, NewVertex(int64(1), "person"), e.OutV) assert.Equal(t, NewVertex(int64(10), "software"), e.InV) } func TestPropertyEncoding(t *testing.T) { t.Parallel() props := []Property{ NewProperty("from", int32(2017)), NewProperty("to", int32(2019)), } got, err := graphson.MarshalToString(props) require.NoError(t, err) want := `{ "@type" : "g:List", "@value" : [ { "@type" : "g:Property", "@value" : { "key" : "from", "value" : { "@type" : "g:Int32", "@value" : 2017 } } }, { "@type" : "g:Property", "@value" : { "key" : "to", "value" : { "@type" : "g:Int32", "@value" : 2019 } } } ] }` assert.JSONEq(t, want, got) } func TestPropertyString(t *testing.T) { p := NewProperty("since", 2019) assert.Equal(t, "p[since->2019]", fmt.Sprint(p)) } ent-0.5.4/dialect/gremlin/graph/element.go000066400000000000000000000007401377533537200204510ustar00rootroot00000000000000// Copyright 2019-present 
Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graph // Element defines a base struct for graph elements. type Element struct { ID interface{} `json:"id"` Label string `json:"label"` } // NewElement create a new graph element. func NewElement(id interface{}, label string) Element { return Element{id, label} } ent-0.5.4/dialect/gremlin/graph/valuemap.go000066400000000000000000000025211377533537200206310ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graph import ( "reflect" "github.com/mitchellh/mapstructure" "github.com/pkg/errors" ) // ValueMap models a .valueMap() gremlin response. type ValueMap []map[string]interface{} // Decode decodes a value map into v. func (m ValueMap) Decode(v interface{}) error { rv := reflect.ValueOf(v) if rv.Kind() != reflect.Ptr { return errors.New("cannot unmarshal into a non pointer") } if rv.IsNil() { return errors.New("cannot unmarshal into a nil pointer") } if rv.Elem().Kind() != reflect.Slice { v = &[]interface{}{v} } return m.decode(v) } func (m ValueMap) decode(v interface{}) error { cfg := mapstructure.DecoderConfig{ DecodeHook: func(f, t reflect.Kind, data interface{}) (interface{}, error) { if f == reflect.Slice && t != reflect.Slice { rv := reflect.ValueOf(data) if rv.Len() == 1 { data = rv.Index(0).Interface() } } return data, nil }, Result: v, TagName: "json", } dec, err := mapstructure.NewDecoder(&cfg) if err != nil { return errors.Wrap(err, "creating structure decoder") } if err := dec.Decode(m); err != nil { return errors.Wrap(err, "decoding value map") } return nil } ent-0.5.4/dialect/gremlin/graph/valuemap_test.go000066400000000000000000000033161377533537200216730ustar00rootroot00000000000000// Copyright 
2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graph import ( "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestValueMapDecodeOne(t *testing.T) { vm := ValueMap{map[string]interface{}{ "id": int64(1), "label": "person", "name": []interface{}{"marko"}, "age": []interface{}{int32(29)}, }} var ent struct { ID uint64 `json:"id"` Label string `json:"label"` Name string `json:"name"` Age uint8 `json:"age"` } err := vm.Decode(&ent) require.NoError(t, err) assert.Equal(t, uint64(1), ent.ID) assert.Equal(t, "person", ent.Label) assert.Equal(t, "marko", ent.Name) assert.Equal(t, uint8(29), ent.Age) } func TestValueMapDecodeMany(t *testing.T) { vm := ValueMap{ map[string]interface{}{ "id": int64(1), "label": "person", "name": []interface{}{"chico"}, }, map[string]interface{}{ "id": int64(2), "label": "person", "name": []interface{}{"dico"}, }, } ents := []struct { ID int `json:"id"` Label string `json:"label"` Name string `json:"name"` }{} err := vm.Decode(&ents) require.NoError(t, err) require.Len(t, ents, 2) assert.Equal(t, 1, ents[0].ID) assert.Equal(t, "person", ents[0].Label) assert.Equal(t, "chico", ents[0].Name) assert.Equal(t, 2, ents[1].ID) assert.Equal(t, "person", ents[1].Label) assert.Equal(t, "dico", ents[1].Name) } func TestValueMapDecodeBadInput(t *testing.T) { type s struct{ Name string } err := ValueMap{}.Decode(s{}) assert.Error(t, err) err = ValueMap{}.Decode((*s)(nil)) assert.Error(t, err) } ent-0.5.4/dialect/gremlin/graph/vertex.go000066400000000000000000000027131377533537200203370ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. 
package graph import ( "fmt" "github.com/facebook/ent/dialect/gremlin/encoding/graphson" ) // Vertex represents a graph vertex. type Vertex struct { Element } // NewVertex create a new graph vertex. func NewVertex(id interface{}, label string) Vertex { if label == "" { label = "vertex" } return Vertex{ Element: NewElement(id, label), } } // GraphsonType implements graphson.Typer interface. func (Vertex) GraphsonType() graphson.Type { return "g:Vertex" } // String implements fmt.Stringer interface. func (v Vertex) String() string { return fmt.Sprintf("v[%v]", v.ID) } // VertexProperty denotes a key/value pair associated with a vertex. type VertexProperty struct { ID interface{} `json:"id"` Key string `json:"label"` Value interface{} `json:"value"` } // NewVertexProperty create a new graph vertex property. func NewVertexProperty(id interface{}, key string, value interface{}) VertexProperty { return VertexProperty{ ID: id, Key: key, Value: value, } } // GraphsonType implements graphson.Typer interface. func (VertexProperty) GraphsonType() graphson.Type { return "g:VertexProperty" } // String implements fmt.Stringer interface. func (vp VertexProperty) String() string { return fmt.Sprintf("vp[%s->%v]", vp.Key, vp.Value) } ent-0.5.4/dialect/gremlin/graph/vertex_test.go000066400000000000000000000036201377533537200213740ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. 
package graph import ( "fmt" "testing" "github.com/facebook/ent/dialect/gremlin/encoding/graphson" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestVertexCreation(t *testing.T) { v := NewVertex(45, "person") assert.Equal(t, 45, v.ID) assert.Equal(t, "person", v.Label) v = NewVertex(46, "") assert.Equal(t, "vertex", v.Label) } func TestVertexString(t *testing.T) { v := NewVertex(42, "") assert.Equal(t, "v[42]", fmt.Sprint(v)) } func TestVertexEncoding(t *testing.T) { t.Parallel() v := NewVertex(1, "user") got, err := graphson.MarshalToString(v) require.NoError(t, err) want := `{ "@type" : "g:Vertex", "@value" : { "id" : { "@type" : "g:Int64", "@value" : 1 }, "label" : "user" } }` assert.JSONEq(t, want, got) v = Vertex{} err = graphson.UnmarshalFromString(got, &v) require.NoError(t, err) assert.Equal(t, int64(1), v.ID) assert.Equal(t, "user", v.Label) } func TestVertexPropertyEncoding(t *testing.T) { t.Parallel() vp := NewVertexProperty("46ab60c2-918c-4cc4-a13b-350510e8908a", "name", "alex") got, err := graphson.MarshalToString(vp) require.NoError(t, err) want := `{ "@type" : "g:VertexProperty", "@value" : { "id" : "46ab60c2-918c-4cc4-a13b-350510e8908a", "label": "name", "value": "alex" } }` assert.JSONEq(t, want, got) vp = VertexProperty{} err = graphson.UnmarshalFromString(got, &vp) require.NoError(t, err) assert.Equal(t, "46ab60c2-918c-4cc4-a13b-350510e8908a", vp.ID) assert.Equal(t, "name", vp.Key) assert.Equal(t, "alex", vp.Value) } func TestVertexPropertyString(t *testing.T) { vp := NewVertexProperty(55, "country", "israel") assert.Equal(t, "vp[country->israel]", fmt.Sprint(vp)) } ent-0.5.4/dialect/gremlin/http.go000066400000000000000000000044261377533537200167030ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. 
package gremlin import ( "context" "io" "io/ioutil" "net/http" "net/url" "github.com/facebook/ent/dialect/gremlin/encoding/graphson" jsoniter "github.com/json-iterator/go" "github.com/pkg/errors" ) type httpTransport struct { client *http.Client url string } // NewHTTPTransport returns a new http transport. func NewHTTPTransport(urlStr string, client *http.Client) (RoundTripper, error) { u, err := url.Parse(urlStr) if err != nil { return nil, errors.Wrap(err, "gremlin/http: parsing url") } if client == nil { client = http.DefaultClient } return &httpTransport{client, u.String()}, nil } // RoundTrip implements RouterTripper interface. func (t *httpTransport) RoundTrip(ctx context.Context, req *Request) (*Response, error) { if req.Operation != OpsEval { return nil, errors.Errorf("gremlin/http: unsupported operation: %q", req.Operation) } if _, ok := req.Arguments[ArgsGremlin]; !ok { return nil, errors.New("gremlin/http: missing query expression") } pr, pw := io.Pipe() defer pr.Close() go func() { err := jsoniter.NewEncoder(pw).Encode(req.Arguments) _ = pw.CloseWithError(errors.Wrap(err, "gremlin/http: encoding request")) }() var br io.Reader { req, err := http.NewRequest(http.MethodPost, t.url, pr) if err != nil { return nil, errors.Wrap(err, "gremlin/http: creating http request") } req.Header.Set("Content-Type", "application/json") rsp, err := t.client.Do(req.WithContext(ctx)) if err != nil { return nil, errors.Wrap(err, "gremlin/http: posting http request") } defer rsp.Body.Close() if rsp.StatusCode < http.StatusOK || rsp.StatusCode > http.StatusPartialContent { body, _ := ioutil.ReadAll(rsp.Body) return nil, errors.Errorf("gremlin/http: status=%q, body=%q", rsp.Status, body) } if rsp.ContentLength > MaxResponseSize { return nil, errors.New("gremlin/http: context length exceeds limit") } br = rsp.Body } var rsp Response if err := graphson.NewDecoder(io.LimitReader(br, MaxResponseSize)).Decode(&rsp); err != nil { return nil, errors.Wrap(err, "gremlin/http: decoding 
response") } return &rsp, nil } ent-0.5.4/dialect/gremlin/http_test.go000066400000000000000000000070671377533537200177460ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package gremlin import ( "context" "io" "io/ioutil" "net/http" "net/http/httptest" "testing" "github.com/facebook/ent/dialect/gremlin/encoding/graphson" jsoniter "github.com/json-iterator/go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestHTTPTransportRoundTripper(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { assert.Equal(t, r.Header.Get("Content-Type"), "application/json") got, err := ioutil.ReadAll(r.Body) require.NoError(t, err) assert.JSONEq(t, `{"gremlin": "g.V(1)", "language": "gremlin-groovy"}`, string(got)) _, err = io.WriteString(w, `{ "requestId": "f679127f-8701-425c-af55-049a44720db6", "result": { "data": { "@type": "g:List", "@value": [ { "@type": "g:Vertex", "@value": { "id": { "@type": "g:Int64", "@value": 1 }, "label": "person" } } ] }, "meta": { "@type": "g:Map", "@value": [] } }, "status": { "attributes": { "@type": "g:Map", "@value": [] }, "code": 200, "message": "" } }`) require.NoError(t, err) })) defer srv.Close() transport, err := NewHTTPTransport(srv.URL, nil) require.NoError(t, err) rsp, err := transport.RoundTrip(context.Background(), NewEvalRequest("g.V(1)")) require.NoError(t, err) assert.Equal(t, "f679127f-8701-425c-af55-049a44720db6", rsp.RequestID) assert.Equal(t, 200, rsp.Status.Code) assert.Empty(t, rsp.Status.Message) v := jsoniter.Get(rsp.Result.Data, graphson.ValueKey, 0, graphson.ValueKey) require.NoError(t, v.LastError()) assert.Equal(t, 1, v.Get("id", graphson.ValueKey).ToInt()) assert.Equal(t, "person", v.Get("label").ToString()) } func TestNewHTTPTransportBadURL(t *testing.T) { transport, 
err := NewHTTPTransport(":", nil) assert.Nil(t, transport) assert.Error(t, err) } func TestHTTPTransportBadRequest(t *testing.T) { transport, err := NewHTTPTransport("example.com", nil) require.NoError(t, err) req := NewEvalRequest("g.V()") req.Operation = "" rsp, err := transport.RoundTrip(context.Background(), req) assert.EqualError(t, err, `gremlin/http: unsupported operation: ""`) assert.Nil(t, rsp) req = NewEvalRequest("g.V()") delete(req.Arguments, ArgsGremlin) rsp, err = transport.RoundTrip(context.Background(), req) assert.EqualError(t, err, "gremlin/http: missing query expression") assert.Nil(t, rsp) } func TestHTTPTransportBadResponseStatus(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusInternalServerError) })) defer srv.Close() transport, err := NewHTTPTransport(srv.URL, nil) require.NoError(t, err) _, err = transport.RoundTrip(context.Background(), NewEvalRequest("g.E().")) require.Error(t, err) assert.Contains(t, err.Error(), http.StatusText(http.StatusInternalServerError)) } func TestHTTPTransportBadResponseBody(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { _, err := io.WriteString(w, "{{{") require.NoError(t, err) })) defer srv.Close() transport, err := NewHTTPTransport(srv.URL, nil) require.NoError(t, err) _, err = transport.RoundTrip(context.Background(), NewEvalRequest("g.E().")) require.Error(t, err) assert.Contains(t, err.Error(), "decoding response") } ent-0.5.4/dialect/gremlin/internal/000077500000000000000000000000001377533537200172035ustar00rootroot00000000000000ent-0.5.4/dialect/gremlin/internal/ws/000077500000000000000000000000001377533537200176345ustar00rootroot00000000000000ent-0.5.4/dialect/gremlin/internal/ws/conn.go000066400000000000000000000172651377533537200211330ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. 
// This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package ws import ( "bytes" "context" "io" "net/http" "sync" "time" "github.com/facebook/ent/dialect/gremlin" "github.com/facebook/ent/dialect/gremlin/encoding" "github.com/facebook/ent/dialect/gremlin/encoding/graphson" "github.com/gorilla/websocket" "github.com/pkg/errors" "golang.org/x/sync/errgroup" ) const ( // Time allowed to write a message to the peer. writeWait = 5 * time.Second // Time allowed to read the next pong message from the peer. pongWait = 10 * time.Second // Send pings to peer with this period. Must be less than pongWait. pingPeriod = (pongWait * 9) / 10 ) type ( // A Dialer contains options for connecting to Gremlin server. Dialer struct { // Underlying websocket dialer. websocket.Dialer // Gremlin server basic auth credentials. user, pass string } // Conn performs operations on a gremlin server. Conn struct { // Underlying websocket connection. conn *websocket.Conn // Credentials for basic authentication. user, pass string // Goroutine tracking. ctx context.Context grp *errgroup.Group // Channel of outbound requests. send chan io.Reader // Map of in flight requests. inflight sync.Map } // inflight tracks request state. inflight struct { // partially received data frags []graphson.RawMessage // response channel result chan<- result } // represents an execution result. result struct { rsp *gremlin.Response err error } ) var ( // DefaultDialer is a dialer with all fields set to the default values. DefaultDialer = &Dialer{ Dialer: websocket.Dialer{ Proxy: http.ProxyFromEnvironment, HandshakeTimeout: 5 * time.Second, WriteBufferSize: 8192, ReadBufferSize: 8192, }, } // ErrConnClosed is returned by the Conns Execute method when // the underlying gremlin server connection is closed. 
ErrConnClosed = errors.New("gremlin: server connection closed") // ErrDuplicateRequest is returned by the Conns Execute method on // request identifier key collision. ErrDuplicateRequest = errors.New("gremlin: duplicate request") ) // Dial creates a new connection by calling DialContext with a background context. func (d *Dialer) Dial(uri string) (*Conn, error) { return d.DialContext(context.Background(), uri) } // DialContext creates a new Gremlin connection. func (d *Dialer) DialContext(ctx context.Context, uri string) (*Conn, error) { c, rsp, err := d.Dialer.DialContext(ctx, uri, nil) if err != nil { return nil, errors.Wrapf(err, "gremlin: dialing uri %s", uri) } defer rsp.Body.Close() conn := &Conn{ conn: c, user: d.user, pass: d.pass, send: make(chan io.Reader), } conn.grp, conn.ctx = errgroup.WithContext(context.Background()) conn.grp.Go(conn.sender) conn.grp.Go(conn.receiver) return conn, nil } // Execute executes a request against a Gremlin server. func (c *Conn) Execute(ctx context.Context, req *gremlin.Request) (*gremlin.Response, error) { // buffered result channel prevents receiver block on context cancellation result := make(chan result, 1) // request id must be unique across inflight request if _, loaded := c.inflight.LoadOrStore(req.RequestID, &inflight{result: result}); loaded { return nil, ErrDuplicateRequest } pr, pw := io.Pipe() defer pr.Close() // stream graphson encoding into request c.grp.Go(func() error { err := graphson.NewEncoder(pw).Encode(req) if err != nil { err = errors.Wrap(err, "encoding request") } pw.CloseWithError(err) return err }) // local copy for single write send := c.send for { select { case <-c.ctx.Done(): c.inflight.Delete(req.RequestID) return nil, ErrConnClosed case <-ctx.Done(): c.inflight.Delete(req.RequestID) return nil, ctx.Err() case send <- pr: send = nil case result := <-result: return result.rsp, result.err } } } // Close connection with a Gremlin server. 
func (c *Conn) Close() error { c.grp.Go(func() error { return ErrConnClosed }) _ = c.grp.Wait() return nil } func (c *Conn) sender() error { pinger := time.NewTicker(pingPeriod) defer pinger.Stop() // closing connection terminates receiver defer c.conn.Close() for { select { case r := <-c.send: // ensure write completes within a window c.conn.SetWriteDeadline(time.Now().Add(writeWait)) // fetch next message writer w, err := c.conn.NextWriter(websocket.BinaryMessage) if err != nil { return errors.Wrap(err, "getting message writer") } // write mime header if _, err := w.Write(encoding.GraphSON3Mime); err != nil { return errors.Wrap(err, "writing mime header") } // write request body if _, err := io.Copy(w, r); err != nil { return errors.Wrap(err, "writing request") } // finish message write if err := w.Close(); err != nil { return errors.Wrap(err, "closing message writer") } case <-c.ctx.Done(): // connection closing return c.conn.WriteControl( websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Time{}, ) case <-pinger.C: // periodic connection keepalive if err := c.conn.WriteControl(websocket.PingMessage, nil, time.Now().Add(writeWait)); err != nil { return errors.Wrap(err, "writing ping message") } } } } func (c *Conn) receiver() error { // handle keepalive responses c.conn.SetReadDeadline(time.Now().Add(pongWait)) c.conn.SetPongHandler(func(string) error { return c.conn.SetReadDeadline(time.Now().Add(pongWait)) }) // complete all in flight requests on termination defer c.inflight.Range(func(id, ifr interface{}) bool { ifr.(*inflight).result <- result{err: ErrConnClosed} c.inflight.Delete(id) return true }) for { // rely on sender connection close during termination _, r, err := c.conn.NextReader() if err != nil { return errors.Wrap(err, "getting next reader") } // decode received response var rsp gremlin.Response if err := graphson.NewDecoder(r).Decode(&rsp); err != nil { return errors.Wrap(err, "reading response") } ifr, 
ok := c.inflight.Load(rsp.RequestID) if !ok { // context cancellation aborts inflight requests continue } // handle incoming response if done := c.receive(ifr.(*inflight), &rsp); done { // stop tracking finished requests c.inflight.Delete(rsp.RequestID) } } } func (c *Conn) receive(ifr *inflight, rsp *gremlin.Response) bool { result := result{rsp: rsp} switch rsp.Status.Code { case gremlin.StatusSuccess: // quickly handle non fragmented responses if ifr.frags == nil { break } // handle fragment fallthrough case gremlin.StatusPartialContent: // append received fragment var frag []graphson.RawMessage if err := graphson.Unmarshal(rsp.Result.Data, &frag); err != nil { result.err = errors.Wrap(err, "decoding response fragment") break } ifr.frags = append(ifr.frags, frag...) // partial response requires additional fragments if rsp.Status.Code == gremlin.StatusPartialContent { return false } // reassemble fragmented response if rsp.Result.Data, result.err = graphson.Marshal(ifr.frags); result.err != nil { result.err = errors.Wrap(result.err, "assembling fragmented response") } case gremlin.StatusAuthenticate: // receiver should never block c.grp.Go(func() error { var buf bytes.Buffer if err := graphson.NewEncoder(&buf).Encode( gremlin.NewAuthRequest(rsp.RequestID, c.user, c.pass), ); err != nil { return errors.Wrap(err, "encoding auth request") } select { case c.send <- &buf: case <-c.ctx.Done(): } return c.ctx.Err() }) return false } ifr.result <- result return true } ent-0.5.4/dialect/gremlin/internal/ws/conn_test.go000066400000000000000000000226421377533537200221650ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. 
package ws import ( "context" "net/http" "net/http/httptest" "strconv" "sync" "testing" "github.com/facebook/ent/dialect/gremlin" "github.com/facebook/ent/dialect/gremlin/encoding/graphson" "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) type conn struct{ *websocket.Conn } func (c conn) ReadRequest() (*gremlin.Request, error) { _, data, err := c.ReadMessage() if err != nil { return nil, err } var req gremlin.Request if err := graphson.Unmarshal(data[data[0]+1:], &req); err != nil { return nil, err } return &req, nil } func (c conn) WriteResponse(rsp *gremlin.Response) error { data, err := graphson.Marshal(rsp) if err != nil { return err } return c.WriteMessage(websocket.BinaryMessage, data) } func serve(handler func(conn)) *httptest.Server { return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { upgrader := websocket.Upgrader{ReadBufferSize: 1024, WriteBufferSize: 1024} c, _ := upgrader.Upgrade(w, r, nil) defer c.Close() handler(conn{c}) for { _, _, err := c.ReadMessage() if err != nil { break } } })) } func TestConnectClosure(t *testing.T) { var wg sync.WaitGroup wg.Add(1) defer wg.Wait() srv := serve(func(conn conn) { defer wg.Done() _, _, err := conn.ReadMessage() assert.True(t, websocket.IsCloseError(err, websocket.CloseNormalClosure)) }) defer srv.Close() conn, err := DefaultDialer.Dial("ws://" + srv.Listener.Addr().String()) require.NoError(t, err) err = conn.Close() assert.NoError(t, err) _, err = conn.Execute(context.Background(), gremlin.NewEvalRequest("g.V()")) assert.EqualError(t, err, ErrConnClosed.Error()) } func TestSimpleQuery(t *testing.T) { srv := serve(func(conn conn) { typ, data, err := conn.ReadMessage() require.NoError(t, err) assert.Equal(t, websocket.BinaryMessage, typ) var req gremlin.Request err = graphson.Unmarshal(data[data[0]+1:], &req) require.NoError(t, err) assert.Equal(t, "g.V()", req.Arguments["gremlin"]) rsp := gremlin.Response{RequestID: 
req.RequestID} rsp.Status.Code = gremlin.StatusNoContent err = conn.WriteResponse(&rsp) require.NoError(t, err) }) defer srv.Close() conn, err := DefaultDialer.Dial("ws://" + srv.Listener.Addr().String()) require.NoError(t, err) defer assert.Condition(t, func() bool { return assert.NoError(t, conn.Close()) }) rsp, err := conn.Execute(context.Background(), gremlin.NewEvalRequest("g.V()")) assert.NoError(t, err) require.NotNil(t, rsp) assert.Equal(t, gremlin.StatusNoContent, rsp.Status.Code) } func TestDuplicateRequest(t *testing.T) { // skip until flakiness will be fixed. t.SkipNow() srv := serve(func(conn conn) { req, err := conn.ReadRequest() require.NoError(t, err) rsp := gremlin.Response{RequestID: req.RequestID} rsp.Status.Code = gremlin.StatusNoContent err = conn.WriteResponse(&rsp) require.NoError(t, err) }) defer srv.Close() conn, err := DefaultDialer.Dial("ws://" + srv.Listener.Addr().String()) require.NoError(t, err) defer conn.Close() var errors [2]error req := gremlin.NewEvalRequest("g.V()") var wg sync.WaitGroup wg.Add(len(errors)) for i := range errors { go func(i int) { _, errors[i] = conn.Execute(context.Background(), req) wg.Done() }(i) } wg.Wait() err = errors[0] if err == nil { err = errors[1] } assert.EqualError(t, err, ErrDuplicateRequest.Error()) } func TestConnectCancellation(t *testing.T) { srv := serve(func(conn) {}) defer srv.Close() ctx, cancel := context.WithCancel(context.Background()) cancel() conn, err := DefaultDialer.DialContext(ctx, "ws://"+srv.Listener.Addr().String()) assert.Error(t, err) assert.Nil(t, conn) } func TestQueryCancellation(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) srv := serve(func(conn conn) { if _, _, err := conn.ReadMessage(); err == nil { cancel() } }) defer srv.Close() conn, err := DefaultDialer.Dial("ws://" + srv.Listener.Addr().String()) require.NoError(t, err) defer conn.Close() _, err = conn.Execute(ctx, gremlin.NewEvalRequest("g.E()")) assert.EqualError(t, err, 
context.Canceled.Error()) } func TestBadResponse(t *testing.T) { tests := []struct { name string mangle func(*gremlin.Response) *gremlin.Response }{ { name: "NoStatus", mangle: func(rsp *gremlin.Response) *gremlin.Response { return rsp }, }, { name: "Malformed", mangle: func(rsp *gremlin.Response) *gremlin.Response { rsp.Status.Code = gremlin.StatusMalformedRequest rsp.Status.Message = "bad request" return rsp }, }, { name: "Unknown", mangle: func(rsp *gremlin.Response) *gremlin.Response { rsp.Status.Code = 424242 return rsp }, }, } srv := serve(func(conn conn) { for { req, err := conn.ReadRequest() if err != nil { break } idx, err := strconv.ParseInt(req.Arguments["gremlin"].(string), 10, 0) require.NoError(t, err) err = conn.WriteResponse(tests[idx].mangle(&gremlin.Response{RequestID: req.RequestID})) require.NoError(t, err) } }) defer srv.Close() conn, err := DefaultDialer.Dial("ws://" + srv.Listener.Addr().String()) require.NoError(t, err) defer conn.Close() var wg sync.WaitGroup wg.Add(len(tests)) ctx := context.Background() for i, tc := range tests { i, tc := i, tc t.Run(tc.name, func(t *testing.T) { defer wg.Done() rsp, err := conn.Execute(ctx, gremlin.NewEvalRequest(strconv.FormatInt(int64(i), 10))) assert.NoError(t, err) assert.True(t, rsp.IsErr()) }) } wg.Wait() } func TestServerHangup(t *testing.T) { // skip until flakiness will be fixed. t.SkipNow() srv := serve(func(conn conn) { _ = conn.Close() }) defer srv.Close() conn, err := DefaultDialer.Dial("ws://" + srv.Listener.Addr().String()) require.NoError(t, err) defer conn.Close() _, err = conn.Execute(context.Background(), gremlin.NewEvalRequest("g.V()")) assert.EqualError(t, err, ErrConnClosed.Error()) assert.Error(t, conn.ctx.Err()) } func TestCanceledLongRequest(t *testing.T) { // skip until flakiness will be fixed. 
t.SkipNow() ctx, cancel := context.WithCancel(context.Background()) srv := serve(func(conn conn) { var responses [3]*gremlin.Response for i := 0; i < len(responses); i++ { req, err := conn.ReadRequest() require.NoError(t, err) rsp := gremlin.Response{RequestID: req.RequestID} rsp.Status.Code = gremlin.StatusSuccess rsp.Result.Data = graphson.RawMessage(`"ok"`) responses[i] = &rsp } cancel() responses[0], responses[2] = responses[2], responses[0] for i := 0; i < len(responses); i++ { err := conn.WriteResponse(responses[i]) require.NoError(t, err) } }) defer srv.Close() conn, err := DefaultDialer.Dial("ws://" + srv.Listener.Addr().String()) require.NoError(t, err) defer conn.Close() var wg sync.WaitGroup wg.Add(3) defer wg.Wait() for i := 0; i < 3; i++ { go func(ctx context.Context, idx int) { defer wg.Done() rsp, err := conn.Execute(ctx, gremlin.NewEvalRequest("g.V()")) if idx > 0 { assert.NoError(t, err) assert.EqualValues(t, []byte(`"ok"`), rsp.Result.Data) } else { assert.EqualError(t, err, context.Canceled.Error()) } }(ctx, i) ctx = context.Background() } } func TestPartialResponse(t *testing.T) { type kv struct { Key string Value int } kvs := []kv{ {"one", 1}, {"two", 2}, {"three", 3}, } srv := serve(func(conn conn) { req, err := conn.ReadRequest() require.NoError(t, err) for i := range kvs { data, err := graphson.Marshal([]kv{kvs[i]}) require.NoError(t, err) rsp := gremlin.Response{RequestID: req.RequestID} rsp.Result.Data = graphson.RawMessage(data) if i != len(kvs)-1 { rsp.Status.Code = gremlin.StatusPartialContent } else { rsp.Status.Code = gremlin.StatusSuccess } err = conn.WriteResponse(&rsp) require.NoError(t, err) } }) defer srv.Close() conn, err := DefaultDialer.Dial("ws://" + srv.Listener.Addr().String()) require.NoError(t, err) defer conn.Close() rsp, err := conn.Execute(context.Background(), gremlin.NewEvalRequest("g.E()")) assert.NoError(t, err) var result []kv err = graphson.Unmarshal(rsp.Result.Data, &result) require.NoError(t, err) 
assert.Equal(t, kvs, result) } func TestAuthentication(t *testing.T) { user, pass := "username", "password" srv := serve(func(conn conn) { req, err := conn.ReadRequest() require.NoError(t, err) rsp := gremlin.Response{RequestID: req.RequestID} rsp.Status.Code = gremlin.StatusAuthenticate err = conn.WriteResponse(&rsp) require.NoError(t, err) areq, err := conn.ReadRequest() require.NoError(t, err) var acreds gremlin.Credentials err = acreds.UnmarshalText([]byte(areq.Arguments["sasl"].(string))) assert.NoError(t, err) areq.Arguments["sasl"] = acreds assert.Equal(t, gremlin.NewAuthRequest(req.RequestID, user, pass), areq) rsp = gremlin.Response{RequestID: req.RequestID} rsp.Status.Code = gremlin.StatusNoContent err = conn.WriteResponse(&rsp) require.NoError(t, err) }) defer srv.Close() dialer := *DefaultDialer dialer.user = user dialer.pass = pass client, err := dialer.Dial("ws://" + srv.Listener.Addr().String()) require.NoError(t, err) defer client.Close() _, err = client.Execute(context.Background(), gremlin.NewEvalRequest("g.E().drop()")) assert.NoError(t, err) } ent-0.5.4/dialect/gremlin/ocgremlin/000077500000000000000000000000001377533537200173465ustar00rootroot00000000000000ent-0.5.4/dialect/gremlin/ocgremlin/client.go000066400000000000000000000041431377533537200211550ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package ocgremlin import ( "context" "github.com/facebook/ent/dialect/gremlin" "go.opencensus.io/trace" ) // Transport is an gremlin.RoundTripper that instruments all outgoing requests with // OpenCensus stats and tracing. type Transport struct { // Base is a wrapped gremlin.RoundTripper that does the actual requests. Base gremlin.RoundTripper // StartOptions are applied to the span started by this Transport around each // request. 
// // StartOptions.SpanKind will always be set to trace.SpanKindClient // for spans started by this transport. StartOptions trace.StartOptions // GetStartOptions allows to set start options per request. If set, // StartOptions is going to be ignored. GetStartOptions func(context.Context, *gremlin.Request) trace.StartOptions // NameFromRequest holds the function to use for generating the span name // from the information found in the outgoing Gremlin Request. By default the // name equals the URL Path. FormatSpanName func(context.Context, *gremlin.Request) string // WithQuery, if set to true, will enable recording of gremlin queries in spans. // Only allow this if it is safe to have queries recorded with respect to // security. WithQuery bool } // RoundTrip implements gremlin.RoundTripper, delegating to Base and recording stats and traces for the request. func (t *Transport) RoundTrip(ctx context.Context, req *gremlin.Request) (*gremlin.Response, error) { spanNameFormatter := t.FormatSpanName if spanNameFormatter == nil { spanNameFormatter = func(context.Context, *gremlin.Request) string { return "gremlin:traversal" } } startOpts := t.StartOptions if t.GetStartOptions != nil { startOpts = t.GetStartOptions(ctx, req) } var rt gremlin.RoundTripper = &traceTransport{ base: t.Base, formatSpanName: spanNameFormatter, startOptions: startOpts, withQuery: t.WithQuery, } rt = statsTransport{rt} return rt.RoundTrip(ctx, req) } ent-0.5.4/dialect/gremlin/ocgremlin/client_test.go000066400000000000000000000032771377533537200222230ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. 
package ocgremlin import ( "context" "errors" "testing" "github.com/facebook/ent/dialect/gremlin" "github.com/stretchr/testify/mock" "go.opencensus.io/trace" ) type mockExporter struct { mock.Mock } func (e *mockExporter) ExportSpan(s *trace.SpanData) { e.Called(s) } func TestTransportOptions(t *testing.T) { tests := []struct { name string spanName string wantName string }{ { name: "Default formatter", wantName: "gremlin:traversal", }, { name: "Custom formatter", spanName: "tester", wantName: "tester", }, } for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { var exporter mockExporter exporter.On( "ExportSpan", mock.MatchedBy(func(s *trace.SpanData) bool { return s.Name == tt.wantName })). Once() defer exporter.AssertExpectations(t) trace.RegisterExporter(&exporter) defer trace.UnregisterExporter(&exporter) transport := &mockTransport{} transport.On("RoundTrip", mock.Anything, mock.Anything). Return(nil, errors.New("noop")). Once() defer transport.AssertExpectations(t) rt := &Transport{ Base: transport, GetStartOptions: func(context.Context, *gremlin.Request) trace.StartOptions { return trace.StartOptions{Sampler: trace.AlwaysSample()} }, } if tt.spanName != "" { rt.FormatSpanName = func(context.Context, *gremlin.Request) string { return tt.spanName } } _, _ = rt.RoundTrip(context.Background(), gremlin.NewEvalRequest("g.E()")) }) } } ent-0.5.4/dialect/gremlin/ocgremlin/stats.go000066400000000000000000000067211377533537200210410ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package ocgremlin import ( "context" "strconv" "time" "github.com/facebook/ent/dialect/gremlin" "go.opencensus.io/stats" "go.opencensus.io/stats/view" "go.opencensus.io/tag" ) // The following measures are supported for use in custom views. 
var ( RequestCount = stats.Int64( "gremlin/request_count", "Number of Gremlin requests started", stats.UnitDimensionless, ) ResponseBytes = stats.Int64( "gremlin/response_bytes", "Total number of bytes in response data", stats.UnitBytes, ) RoundTripLatency = stats.Float64( "gremlin/roundtrip_latency", "End-to-end latency", stats.UnitMilliseconds, ) ) // The following tags are applied to stats recorded by this package. var ( // StatusCode is the numeric Gremlin response status code, // or "error" if a transport error occurred and no status code was read. StatusCode, _ = tag.NewKey("gremlin_status_code") ) // Default distributions used by views in this package. var ( DefaultSizeDistribution = view.Distribution(32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576) DefaultLatencyDistribution = view.Distribution(1, 2, 3, 4, 5, 6, 8, 10, 13, 16, 20, 25, 30, 40, 50, 65, 80, 100, 130, 160, 200, 250, 300, 400, 500, 650, 800, 1000, 2000, 5000, 10000, 20000, 50000, 100000) ) // Package ocgremlin provides some convenience views for measures. // You still need to register these views for data to actually be collected. 
var ( RequestCountView = &view.View{ Name: "gremlin/request_count", Measure: RequestCount, Aggregation: view.Count(), Description: "Count of Gremlin requests started", } ResponseCountView = &view.View{ Name: "gremlin/response_count", Measure: RoundTripLatency, Aggregation: view.Count(), Description: "Count of responses received, by response status", TagKeys: []tag.Key{StatusCode}, } ResponseBytesView = &view.View{ Name: "gremlin/response_bytes", Measure: ResponseBytes, Aggregation: DefaultSizeDistribution, Description: "Total number of bytes in response data", } RoundTripLatencyView = &view.View{ Name: "gremlin/roundtrip_latency", Measure: RoundTripLatency, Aggregation: DefaultLatencyDistribution, Description: "End-to-end latency, by response code", TagKeys: []tag.Key{StatusCode}, } ) // Views are the default views provided by this package. func Views() []*view.View { return []*view.View{ RequestCountView, ResponseCountView, ResponseBytesView, RoundTripLatencyView, } } // statsTransport is an gremlin.RoundTripper that collects stats for the outgoing requests. type statsTransport struct { base gremlin.RoundTripper } func (t statsTransport) RoundTrip(ctx context.Context, req *gremlin.Request) (*gremlin.Response, error) { stats.Record(ctx, RequestCount.M(1)) start := time.Now() rsp, err := t.base.RoundTrip(ctx, req) latency := float64(time.Since(start)) / float64(time.Millisecond) var ( tags = make([]tag.Mutator, 1) ms = []stats.Measurement{RoundTripLatency.M(latency)} ) if err == nil { tags[0] = tag.Upsert(StatusCode, strconv.Itoa(rsp.Status.Code)) ms = append(ms, ResponseBytes.M(int64(len(rsp.Result.Data)))) } else { tags[0] = tag.Upsert(StatusCode, "error") } _ = stats.RecordWithTags(ctx, tags, ms...) return rsp, err } ent-0.5.4/dialect/gremlin/ocgremlin/stats_test.go000066400000000000000000000045451377533537200221020ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. 
// This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package ocgremlin import ( "context" "strings" "testing" "github.com/facebook/ent/dialect/gremlin" "github.com/facebook/ent/dialect/gremlin/encoding/graphson" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "go.opencensus.io/stats/view" ) func TestStatsCollection(t *testing.T) { err := view.Register( RequestCountView, ResponseCountView, ResponseBytesView, RoundTripLatencyView, ) require.NoError(t, err) req := gremlin.NewEvalRequest("g.E()") rsp := &gremlin.Response{RequestID: req.RequestID} rsp.Status.Code = gremlin.StatusSuccess rsp.Result.Data = graphson.RawMessage( `{"@type": "g:List", "@value": [{"@type": "g:Int32", "@value": 42}]}`, ) transport := &mockTransport{} transport.On("RoundTrip", mock.Anything, mock.Anything). Return(rsp, nil). Once() defer transport.AssertExpectations(t) rt := &statsTransport{transport} _, _ = rt.RoundTrip(context.Background(), req) tests := []struct { name string expect func(*testing.T, *view.Row) }{ { name: "gremlin/request_count", expect: func(t *testing.T, row *view.Row) { count, ok := row.Data.(*view.CountData) require.True(t, ok) assert.Equal(t, int64(1), count.Value) }, }, { name: "gremlin/response_count", expect: func(t *testing.T, row *view.Row) { count, ok := row.Data.(*view.CountData) require.True(t, ok) assert.Equal(t, int64(1), count.Value) }, }, { name: "gremlin/response_bytes", expect: func(t *testing.T, row *view.Row) { data, ok := row.Data.(*view.DistributionData) require.True(t, ok) assert.EqualValues(t, len(rsp.Result.Data), data.Sum()) }, }, { name: "gremlin/roundtrip_latency", expect: func(t *testing.T, row *view.Row) { data, ok := row.Data.(*view.DistributionData) require.True(t, ok) assert.NotZero(t, data.Sum()) }, }, } for _, tt := range tests { tt := tt t.Run(tt.name[strings.Index(tt.name, "/")+1:], func(t 
*testing.T) { v := view.Find(tt.name) assert.NotNil(t, v) rows, err := view.RetrieveData(tt.name) require.NoError(t, err) require.Len(t, rows, 1) tt.expect(t, rows[0]) }) } } ent-0.5.4/dialect/gremlin/ocgremlin/trace.go000066400000000000000000000074411377533537200210010ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package ocgremlin import ( "context" "fmt" "github.com/facebook/ent/dialect/gremlin" "go.opencensus.io/trace" ) // Attributes recorded on the span for the requests. const ( RequestIDAttribute = "gremlin.request_id" OperationAttribute = "gremlin.operation" QueryAttribute = "gremlin.query" BindingAttribute = "gremlin.binding" CodeAttribute = "gremlin.code" MessageAttribute = "gremlin.message" ) type traceTransport struct { base gremlin.RoundTripper startOptions trace.StartOptions formatSpanName func(context.Context, *gremlin.Request) string withQuery bool } func (t *traceTransport) RoundTrip(ctx context.Context, req *gremlin.Request) (*gremlin.Response, error) { ctx, span := trace.StartSpan(ctx, t.formatSpanName(ctx, req), trace.WithSampler(t.startOptions.Sampler), trace.WithSpanKind(trace.SpanKindClient), ) defer span.End() span.AddAttributes(requestAttrs(req, t.withQuery)...) rsp, err := t.base.RoundTrip(ctx, req) if err != nil { span.SetStatus(trace.Status{Code: trace.StatusCodeUnknown, Message: err.Error()}) return rsp, err } span.AddAttributes(responseAttrs(rsp)...) 
span.SetStatus(TraceStatus(rsp.Status.Code)) return rsp, err } func requestAttrs(req *gremlin.Request, withQuery bool) []trace.Attribute { attrs := []trace.Attribute{ trace.StringAttribute(RequestIDAttribute, req.RequestID), trace.StringAttribute(OperationAttribute, req.Operation), } if withQuery { query, _ := req.Arguments[gremlin.ArgsGremlin].(string) attrs = append(attrs, trace.StringAttribute(QueryAttribute, query)) if bindings, ok := req.Arguments[gremlin.ArgsBindings].(map[string]interface{}); ok { attrs = append(attrs, bindingsAttrs(bindings)...) } } return attrs } func bindingsAttrs(bindings map[string]interface{}) []trace.Attribute { attrs := make([]trace.Attribute, 0, len(bindings)) for key, val := range bindings { key = BindingAttribute + "." + key attrs = append(attrs, bindingToAttr(key, val)) } return attrs } func bindingToAttr(key string, val interface{}) trace.Attribute { switch v := val.(type) { case nil: return trace.StringAttribute(key, "") case int64: return trace.Int64Attribute(key, v) case float64: return trace.Float64Attribute(key, v) case string: return trace.StringAttribute(key, v) case bool: return trace.BoolAttribute(key, v) default: s := fmt.Sprintf("%v", v) if len(s) > 256 { s = s[:256] } return trace.StringAttribute(key, s) } } func responseAttrs(rsp *gremlin.Response) []trace.Attribute { attrs := []trace.Attribute{ trace.Int64Attribute(CodeAttribute, int64(rsp.Status.Code)), } if rsp.Status.Message != "" { attrs = append(attrs, trace.StringAttribute(MessageAttribute, rsp.Status.Message)) } return attrs } // TraceStatus is a utility to convert the gremlin status code to a trace.Status. 
func TraceStatus(status int) trace.Status { var code int32 switch status { case gremlin.StatusSuccess, gremlin.StatusNoContent, gremlin.StatusPartialContent: code = trace.StatusCodeOK case gremlin.StatusUnauthorized: code = trace.StatusCodePermissionDenied case gremlin.StatusAuthenticate: code = trace.StatusCodeUnauthenticated case gremlin.StatusMalformedRequest, gremlin.StatusInvalidRequestArguments, gremlin.StatusScriptEvaluationError: code = trace.StatusCodeInvalidArgument case gremlin.StatusServerError, gremlin.StatusServerSerializationError: code = trace.StatusCodeInternal case gremlin.StatusServerTimeout: code = trace.StatusCodeDeadlineExceeded default: code = trace.StatusCodeUnknown } return trace.Status{Code: code, Message: gremlin.StatusText(status)} } ent-0.5.4/dialect/gremlin/ocgremlin/trace_test.go000066400000000000000000000166031377533537200220400ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package ocgremlin import ( "bytes" "context" "errors" "fmt" "testing" "github.com/facebook/ent/dialect/gremlin" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "go.opencensus.io/trace" ) type mockTransport struct { mock.Mock } func (t *mockTransport) RoundTrip(ctx context.Context, req *gremlin.Request) (*gremlin.Response, error) { args := t.Called(ctx, req) rsp, _ := args.Get(0).(*gremlin.Response) return rsp, args.Error(1) } func TestTraceTransportRoundTrip(t *testing.T) { _, parent := trace.StartSpan(context.Background(), "parent") tests := []struct { name string parent *trace.Span }{ { name: "no parent", parent: nil, }, { name: "parent", parent: parent, }, } for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { transport := &mockTransport{} transport.On("RoundTrip", mock.Anything, mock.Anything). 
Run(func(args mock.Arguments) { span := trace.FromContext(args.Get(0).(context.Context)) require.NotNil(t, span) if tt.parent != nil { assert.Equal(t, tt.parent.SpanContext().TraceID, span.SpanContext().TraceID) } }). Return(nil, errors.New("noop")). Once() defer transport.AssertExpectations(t) ctx, req := context.Background(), gremlin.NewEvalRequest("g.V()") if tt.parent != nil { ctx = trace.NewContext(ctx, tt.parent) } rt := &Transport{Base: transport} _, _ = rt.RoundTrip(ctx, req) }) } } type testExporter struct { spans []*trace.SpanData } func (t *testExporter) ExportSpan(s *trace.SpanData) { t.spans = append(t.spans, s) } func TestEndToEnd(t *testing.T) { trace.ApplyConfig(trace.Config{DefaultSampler: trace.AlwaysSample()}) var exporter testExporter trace.RegisterExporter(&exporter) defer trace.UnregisterExporter(&exporter) req := gremlin.NewEvalRequest("g.V()") rsp := &gremlin.Response{ RequestID: req.RequestID, } rsp.Status.Code = 200 rsp.Status.Message = "OK" var transport mockTransport transport.On("RoundTrip", mock.Anything, mock.Anything). Return(rsp, nil). 
Once() defer transport.AssertExpectations(t) rt := &Transport{Base: &transport, WithQuery: true} _, err := rt.RoundTrip(context.Background(), req) require.NoError(t, err) require.Len(t, exporter.spans, 1) attrs := exporter.spans[0].Attributes assert.Len(t, attrs, 5) assert.Equal(t, req.RequestID, attrs["gremlin.request_id"]) assert.Equal(t, req.Operation, attrs["gremlin.operation"]) assert.Equal(t, req.Arguments[gremlin.ArgsGremlin], attrs["gremlin.query"]) assert.Equal(t, int64(200), attrs["gremlin.code"]) assert.Equal(t, "OK", attrs["gremlin.message"]) } func TestRequestAttributes(t *testing.T) { tests := []struct { name string makeReq func() *gremlin.Request wantAttrs []trace.Attribute }{ { name: "Query without bindings", makeReq: func() *gremlin.Request { req := gremlin.NewEvalRequest("g.E().count()") req.RequestID = "a8b5c664-03ca-4175-a9e7-569b46f3551c" return req }, wantAttrs: []trace.Attribute{ trace.StringAttribute("gremlin.request_id", "a8b5c664-03ca-4175-a9e7-569b46f3551c"), trace.StringAttribute("gremlin.operation", "eval"), trace.StringAttribute("gremlin.query", "g.E().count()"), }, }, { name: "Query with bindings", makeReq: func() *gremlin.Request { bindings := map[string]interface{}{ "$1": "user", "$2": int64(42), "$3": 3.14, "$4": bytes.Repeat([]byte{0xff}, 257), "$5": true, "$6": nil, } req := gremlin.NewEvalRequest( `g.V().hasLabel($1).has("age",$2).has("v",$3).limit($4).valueMap($5)`, gremlin.WithBindings(bindings), ) req.RequestID = "d3d986fa-bd22-41bd-b2f7-ef2f1f639260" return req }, wantAttrs: []trace.Attribute{ trace.StringAttribute("gremlin.request_id", "d3d986fa-bd22-41bd-b2f7-ef2f1f639260"), trace.StringAttribute("gremlin.operation", "eval"), trace.StringAttribute("gremlin.query", `g.V().hasLabel($1).has("age",$2).has("v",$3).limit($4).valueMap($5)`), trace.StringAttribute("gremlin.binding.$1", "user"), trace.Int64Attribute("gremlin.binding.$2", 42), trace.Float64Attribute("gremlin.binding.$3", 3.14), 
trace.StringAttribute("gremlin.binding.$4", func() string { str := fmt.Sprintf("%v", bytes.Repeat([]byte{0xff}, 256)) return str[:256] }()), trace.BoolAttribute("gremlin.binding.$5", true), trace.StringAttribute("gremlin.binding.$6", ""), }, }, { name: "Authentication", makeReq: func() *gremlin.Request { return gremlin.NewAuthRequest( "d239d950-59a1-41a7-a103-908f976ebd89", "user", "pass", ) }, wantAttrs: []trace.Attribute{ trace.StringAttribute("gremlin.request_id", "d239d950-59a1-41a7-a103-908f976ebd89"), trace.StringAttribute("gremlin.operation", "authentication"), trace.StringAttribute("gremlin.query", ""), }, }, } for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { req := tt.makeReq() attrs := requestAttrs(req, true) for _, attr := range attrs { assert.Contains(t, tt.wantAttrs, attr) } assert.Len(t, attrs, len(tt.wantAttrs)) }) } } func TestResponseAttributes(t *testing.T) { tests := []struct { name string makeRsp func() *gremlin.Response wantAttrs []trace.Attribute }{ { name: "Success no message", makeRsp: func() *gremlin.Response { var rsp gremlin.Response rsp.Status.Code = 204 return &rsp }, wantAttrs: []trace.Attribute{ trace.Int64Attribute("gremlin.code", 204), }, }, { name: "Authenticate with message", makeRsp: func() *gremlin.Response { var rsp gremlin.Response rsp.Status.Code = 407 rsp.Status.Message = "login required" return &rsp }, wantAttrs: []trace.Attribute{ trace.Int64Attribute("gremlin.code", 407), trace.StringAttribute("gremlin.message", "login required"), }, }, } for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { rsp := tt.makeRsp() attrs := responseAttrs(rsp) assert.Equal(t, tt.wantAttrs, attrs) }) } } func TestTraceStatus(t *testing.T) { tests := []struct { in int want trace.Status }{ {200, trace.Status{Code: trace.StatusCodeOK, Message: "Success"}}, {204, trace.Status{Code: trace.StatusCodeOK, Message: "No Content"}}, {206, trace.Status{Code: trace.StatusCodeOK, Message: "Partial Content"}}, {401, 
trace.Status{Code: trace.StatusCodePermissionDenied, Message: "Unauthorized"}}, {407, trace.Status{Code: trace.StatusCodeUnauthenticated, Message: "Authenticate"}}, {498, trace.Status{Code: trace.StatusCodeInvalidArgument, Message: "Malformed Request"}}, {499, trace.Status{Code: trace.StatusCodeInvalidArgument, Message: "Invalid Request Arguments"}}, {500, trace.Status{Code: trace.StatusCodeInternal, Message: "Server Error"}}, {597, trace.Status{Code: trace.StatusCodeInvalidArgument, Message: "Script Evaluation Error"}}, {598, trace.Status{Code: trace.StatusCodeDeadlineExceeded, Message: "Server Timeout"}}, {599, trace.Status{Code: trace.StatusCodeInternal, Message: "Server Serialization Error"}}, {600, trace.Status{Code: trace.StatusCodeUnknown, Message: ""}}, } for _, tt := range tests { assert.Equal(t, tt.want, TraceStatus(tt.in)) } } ent-0.5.4/dialect/gremlin/request.go000066400000000000000000000052531377533537200174130ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package gremlin import ( "bytes" "encoding/base64" "time" "github.com/google/uuid" "github.com/pkg/errors" ) type ( // A Request models a request message sent to the server. Request struct { RequestID string `json:"requestId" graphson:"g:UUID"` Operation string `json:"op"` Processor string `json:"processor"` Arguments map[string]interface{} `json:"args"` } // RequestOption enables request customization. RequestOption func(*Request) // Credentials holds request plain auth credentials. Credentials struct{ Username, Password string } ) // NewEvalRequest returns a new evaluation request request. 
func NewEvalRequest(query string, opts ...RequestOption) *Request { r := &Request{ RequestID: uuid.New().String(), Operation: OpsEval, Arguments: map[string]interface{}{ ArgsGremlin: query, ArgsLanguage: "gremlin-groovy", }, } for i := range opts { opts[i](r) } return r } // NewAuthRequest returns a new auth request. func NewAuthRequest(requestID, username, password string) *Request { return &Request{ RequestID: requestID, Operation: OpsAuthentication, Arguments: map[string]interface{}{ ArgsSasl: Credentials{ Username: username, Password: password, }, ArgsSaslMechanism: "PLAIN", }, } } // WithBindings sets request bindings. func WithBindings(bindings map[string]interface{}) RequestOption { return func(r *Request) { r.Arguments[ArgsBindings] = bindings } } // WithEvalTimeout sets script evaluation timeout. func WithEvalTimeout(timeout time.Duration) RequestOption { return func(r *Request) { r.Arguments[ArgsEvalTimeout] = int64(timeout / time.Millisecond) } } // MarshalText implements encoding.TextMarshaler interface. func (c Credentials) MarshalText() ([]byte, error) { var buf bytes.Buffer buf.WriteByte(0) buf.WriteString(c.Username) buf.WriteByte(0) buf.WriteString(c.Password) enc := base64.StdEncoding text := make([]byte, enc.EncodedLen(buf.Len())) enc.Encode(text, buf.Bytes()) return text, nil } // UnmarshalText implements encoding.TextUnmarshaler interface. func (c *Credentials) UnmarshalText(text []byte) error { enc := base64.StdEncoding data := make([]byte, enc.DecodedLen(len(text))) n, err := enc.Decode(data, text) if err != nil { return err } data = data[:n] parts := bytes.SplitN(data, []byte{0}, 3) if len(parts) != 3 { return errors.New("bad credentials data") } c.Username = string(parts[1]) c.Password = string(parts[2]) return nil } ent-0.5.4/dialect/gremlin/request_test.go000066400000000000000000000061611377533537200204510ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. 
// This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package gremlin import ( "encoding/json" "testing" "time" "github.com/facebook/ent/dialect/gremlin/encoding/graphson" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestEvaluateRequestEncode(t *testing.T) { req := NewEvalRequest("g.V(x)", WithBindings(map[string]interface{}{"x": 1}), WithEvalTimeout(time.Second), ) data, err := graphson.Marshal(req) require.NoError(t, err) var got map[string]interface{} err = json.Unmarshal(data, &got) require.NoError(t, err) assert.Equal(t, map[string]interface{}{ "@type": "g:UUID", "@value": req.RequestID, }, got["requestId"]) assert.Equal(t, req.Operation, got["op"]) assert.Equal(t, req.Processor, got["processor"]) args := got["args"].(map[string]interface{}) assert.Equal(t, "g:Map", args["@type"]) assert.ElementsMatch(t, args["@value"], []interface{}{ "gremlin", "g.V(x)", "language", "gremlin-groovy", "scriptEvaluationTimeout", map[string]interface{}{ "@type": "g:Int64", "@value": float64(1000), }, "bindings", map[string]interface{}{ "@type": "g:Map", "@value": []interface{}{ "x", map[string]interface{}{ "@type": "g:Int64", "@value": float64(1), }, }, }, }) } func TestEvaluateRequestWithoutBindingsEncode(t *testing.T) { req := NewEvalRequest("g.E()") got, err := graphson.MarshalToString(req) require.NoError(t, err) assert.NotContains(t, got, "bindings") } func TestAuthenticateRequestEncode(t *testing.T) { req := NewAuthRequest("41d2e28a-20a4-4ab0-b379-d810dede3786", "user", "pass") data, err := graphson.Marshal(req) require.NoError(t, err) var got map[string]interface{} err = json.Unmarshal(data, &got) require.NoError(t, err) assert.Equal(t, map[string]interface{}{ "@type": "g:UUID", "@value": req.RequestID, }, got["requestId"]) assert.Equal(t, req.Operation, got["op"]) assert.Equal(t, req.Processor, got["processor"]) args := got["args"].(map[string]interface{}) 
assert.Equal(t, "g:Map", args["@type"]) assert.ElementsMatch(t, args["@value"], []interface{}{ "sasl", "AHVzZXIAcGFzcw==", "saslMechanism", "PLAIN", }) } func TestCredentialsMarshaling(t *testing.T) { want := Credentials{ Username: "username", Password: "password", } text, err := want.MarshalText() assert.NoError(t, err) assert.Equal(t, "AHVzZXJuYW1lAHBhc3N3b3Jk", string(text)) var got Credentials err = got.UnmarshalText(text) assert.NoError(t, err) assert.Equal(t, want, got) } func TestCredentialsBadEncodingMarshaling(t *testing.T) { tests := []struct { name string text []byte }{ { name: "BadBase64", text: []byte{0x12}, }, { name: "Empty", text: []byte{}, }, { name: "BadPrefix", text: []byte("Kg=="), }, { name: "NoSeperator", text: []byte("AHVzZXI="), }, } for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { var creds Credentials err := creds.UnmarshalText(tc.text) assert.Error(t, err) }) } } ent-0.5.4/dialect/gremlin/response.go000066400000000000000000000064211377533537200175570ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package gremlin import ( "github.com/facebook/ent/dialect/gremlin/encoding/graphson" "github.com/facebook/ent/dialect/gremlin/graph" "github.com/pkg/errors" ) // A Response models a response message received from the server. type Response struct { RequestID string `json:"requestId" graphson:"g:UUID"` Status struct { Code int `json:"code"` Attributes map[string]interface{} `json:"attributes"` Message string `json:"message"` } `json:"status"` Result struct { Data graphson.RawMessage `json:"data"` Meta map[string]interface{} `json:"meta"` } `json:"result"` } // IsErr returns whether response indicates an error. 
func (rsp *Response) IsErr() bool { switch rsp.Status.Code { case StatusSuccess, StatusNoContent, StatusPartialContent: return false default: return true } } // Err returns an error representing response status. func (rsp *Response) Err() error { if rsp.IsErr() { return errors.Errorf("gremlin: code=%d, message=%q", rsp.Status.Code, rsp.Status.Message) } return nil } // ReadVal reads gremlin response data into v. func (rsp *Response) ReadVal(v interface{}) error { if err := rsp.Err(); err != nil { return err } if err := graphson.Unmarshal(rsp.Result.Data, v); err != nil { return errors.Wrapf(err, "gremlin: unmarshal response data: type=%T", v) } return nil } // ReadVertices returns response data as slice of vertices. func (rsp *Response) ReadVertices() ([]graph.Vertex, error) { var v []graph.Vertex err := rsp.ReadVal(&v) return v, err } // ReadVertexProperties returns response data as slice of vertex properties. func (rsp *Response) ReadVertexProperties() ([]graph.VertexProperty, error) { var vp []graph.VertexProperty err := rsp.ReadVal(&vp) return vp, err } // ReadEdges returns response data as slice of edges. func (rsp *Response) ReadEdges() ([]graph.Edge, error) { var e []graph.Edge err := rsp.ReadVal(&e) return e, err } // ReadProperties returns response data as slice of properties. func (rsp *Response) ReadProperties() ([]graph.Property, error) { var p []graph.Property err := rsp.ReadVal(&p) return p, err } // ReadValueMap returns response data as a value map. func (rsp *Response) ReadValueMap() (graph.ValueMap, error) { var m graph.ValueMap err := rsp.ReadVal(&m) return m, err } // ReadBool returns response data as a bool. func (rsp *Response) ReadBool() (bool, error) { var b [1]*bool if err := rsp.ReadVal(&b); err != nil { return false, err } if b[0] == nil { return false, errors.New("gremlin: no boolean value") } return *b[0], nil } // ReadInt returns response data as an int. 
func (rsp *Response) ReadInt() (int, error) { var v [1]*int if err := rsp.ReadVal(&v); err != nil { return 0, err } if v[0] == nil { return 0, errors.New("gremlin: no integer value") } return *v[0], nil } // ReadString returns response data as a string. func (rsp *Response) ReadString() (string, error) { var v [1]*string if err := rsp.ReadVal(&v); err != nil { return "", err } if v[0] == nil { return "", errors.New("gremlin: no string value") } return *v[0], nil } ent-0.5.4/dialect/gremlin/response_test.go000066400000000000000000000231121377533537200206120ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package gremlin import ( "reflect" "testing" "github.com/facebook/ent/dialect/gremlin/encoding/graphson" "github.com/facebook/ent/dialect/gremlin/graph" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestDecodeResponse(t *testing.T) { in := `{ "requestId": "a65f2d39-1efa-45d2-a06a-c736476500fc", "result": { "data": { "@type": "g:List", "@value": [ { "@type": "g:Map", "@value": [ { "@type": "g:T", "@value": "id" }, { "@type": "g:Int64", "@value": 1 }, { "@type": "g:T", "@value": "label" }, "person", "name", { "@type": "g:List", "@value": [ "marko" ] }, "age", { "@type": "g:List", "@value": [ { "@type": "g:Int32", "@value": 29 } ] } ] }, { "@type": "g:Map", "@value": [ { "@type": "g:T", "@value": "id" }, { "@type": "g:Int64", "@value": 6 }, { "@type": "g:T", "@value": "label" }, "person", "name", { "@type": "g:List", "@value": [ "peter" ] }, "age", { "@type": "g:List", "@value": [ { "@type": "g:Int32", "@value": 35 } ] } ] } ] }, "meta": { "@type": "g:Map", "@value": [] } }, "status": { "attributes": { "@type": "g:Map", "@value": [] }, "code": 200, "message": "" } }` var rsp Response err := graphson.UnmarshalFromString(in, &rsp) require.NoError(t, err) 
assert.Equal(t, "a65f2d39-1efa-45d2-a06a-c736476500fc", rsp.RequestID) assert.Equal(t, 200, rsp.Status.Code) assert.Empty(t, rsp.Status.Message) assert.Empty(t, rsp.Status.Attributes) assert.Empty(t, rsp.Result.Meta) var vm graph.ValueMap err = graphson.Unmarshal(rsp.Result.Data, &vm) require.NoError(t, err) require.Len(t, vm, 2) type person struct { ID int64 `json:"id"` Name string `json:"name"` Age int `json:"age"` } var people []person err = vm.Decode(&people) require.NoError(t, err) assert.Equal(t, []person{ {1, "marko", 29}, {6, "peter", 35}, }, people) } func TestDecodeResponseWithError(t *testing.T) { in := `{ "requestId": "41d2e28a-20a4-4ab0-b379-d810dede3786", "result": { "data": null, "meta": { "@type": "g:Map", "@value": [] } }, "status": { "attributes": { "@type": "g:Map", "@value": [] }, "code": 500, "message": "Database Down" } }` var rsp Response err := graphson.UnmarshalFromString(in, &rsp) require.NoError(t, err) err = rsp.Err() require.Error(t, err) assert.Contains(t, err.Error(), "Database Down") rsp = Response{} err = graphson.UnmarshalFromString(`{"status": null}`, &rsp) require.NoError(t, err) assert.Error(t, rsp.Err()) } func TestResponseReadVal(t *testing.T) { var rsp Response rsp.Status.Code = StatusSuccess rsp.Result.Data = []byte(`{"@type": "g:Int32", "@value": 15}`) var v int32 err := rsp.ReadVal(&v) assert.NoError(t, err) assert.Equal(t, int32(15), v) var s string err = rsp.ReadVal(&s) assert.Error(t, err) rsp.Status.Code = StatusServerError err = rsp.ReadVal(&v) assert.Error(t, err) } func TestResponseReadGraphElements(t *testing.T) { tests := []struct { method string data string want interface{} }{ { method: "ReadVertices", data: `{ "@type": "g:List", "@value": [ { "@type": "g:Vertex", "@value": { "id": { "@type": "g:Int64", "@value": 1 }, "label": "person" } }, { "@type": "g:Vertex", "@value": { "id": { "@type": "g:Int64", "@value": 6 }, "label": "person" } } ] }`, want: []graph.Vertex{ graph.NewVertex(int64(1), "person"), 
graph.NewVertex(int64(6), "person"), }, }, { method: "ReadVertexProperties", data: `{ "@type": "g:List", "@value": [ { "@type": "g:VertexProperty", "@value": { "id": { "@type": "g:Int64", "@value": 0 }, "label": "name", "value": "marko" } }, { "@type": "g:VertexProperty", "@value": { "id": { "@type": "g:Int64", "@value": 2 }, "label": "age", "value": { "@type": "g:Int32", "@value": 29 } } } ] }`, want: []graph.VertexProperty{ graph.NewVertexProperty(int64(0), "name", "marko"), graph.NewVertexProperty(int64(2), "age", int32(29)), }, }, { method: "ReadEdges", data: `{ "@type": "g:List", "@value": [ { "@type": "g:Edge", "@value": { "id": { "@type": "g:Int32", "@value": 12 }, "inV": { "@type": "g:Int64", "@value": 3 }, "inVLabel": "software", "label": "created", "outV": { "@type": "g:Int64", "@value": 6 }, "outVLabel": "person" } } ] }`, want: []graph.Edge{ graph.NewEdge(int32(12), "created", graph.NewVertex(int64(6), "person"), graph.NewVertex(int64(3), "software"), ), }, }, { method: "ReadProperties", data: `{ "@type": "g:List", "@value": [ { "@type": "g:Property", "@value": { "key": "weight", "value": { "@type": "g:Double", "@value": 0.2 } } } ] }`, want: []graph.Property{ graph.NewProperty("weight", float64(0.2)), }, }, } for _, tc := range tests { tc := tc t.Run(tc.method, func(t *testing.T) { t.Parallel() var rsp Response rsp.Status.Code = StatusSuccess rsp.Result.Data = []byte(tc.data) vals := reflect.ValueOf(&rsp).MethodByName(tc.method).Call(nil) require.Len(t, vals, 2) require.True(t, vals[1].IsNil()) assert.Equal(t, tc.want, vals[0].Interface()) }) } } func TestResponseReadValueMap(t *testing.T) { t.Parallel() var rsp Response rsp.Status.Code = StatusSuccess rsp.Result.Data = []byte(`{ "@type": "g:List", "@value": [ { "@type": "g:Map", "@value": [ "name", { "@type": "g:List", "@value": [ "alex" ] } ] } ] }`) m, err := rsp.ReadValueMap() require.NoError(t, err) var name string err = m.Decode(&struct { Name *string `json:"name"` }{&name}) require.NoError(t, 
err) assert.Equal(t, "alex", name) } func TestResponseReadBool(t *testing.T) { tests := []struct { name string data string want bool wantErr bool }{ { name: "Simple", data: `{ "@type": "g:List", "@value": [ true ] }`, want: true, }, { name: "Multi", data: `{ "@type": "g:List", "@value": [ false, true ] }`, want: false, }, { name: "Empty", data: `{ "@type": "g:List", "@value": [] }`, wantErr: true, }, { name: "BadType", data: `{ "@type": "g:List", "@value": [ "user" ] }`, wantErr: true, }, } for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() var rsp Response rsp.Status.Code = StatusSuccess rsp.Result.Data = []byte(tc.data) got, err := rsp.ReadBool() if tc.wantErr { assert.Error(t, err) } else { require.NoError(t, err) assert.Equal(t, tc.want, got) } }) } } func TestResponseReadInt(t *testing.T) { tests := []struct { name string data string want int wantErr bool }{ { name: "Simple", data: `{ "@type": "g:List", "@value": [ { "@type": "g:Int64", "@value": 42 } ] }`, want: 42, }, { name: "Multi", data: `{ "@type": "g:List", "@value": [ { "@type": "g:Int64", "@value": 55 }, { "@type": "g:Int64", "@value": 13 } ] }`, want: 55, }, { name: "Empty", data: `{ "@type": "g:List", "@value": [] }`, wantErr: true, }, { name: "BadType", data: `{ "@type": "g:List", "@value": [ true ] }`, wantErr: true, }, } for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() var rsp Response rsp.Status.Code = StatusSuccess rsp.Result.Data = []byte(tc.data) got, err := rsp.ReadInt() if tc.wantErr { assert.Error(t, err) } else { require.NoError(t, err) assert.Equal(t, tc.want, got) } }) } } func TestResponseReadString(t *testing.T) { tests := []struct { name string data string want string wantErr bool }{ { name: "Simple", data: `{ "@type": "g:List", "@value": ["foo"] }`, want: "foo", }, { name: "Empty", data: `{ "@type": "g:List", "@value": [] }`, wantErr: true, }, { name: "BadType", data: `{ "@type": "g:List", "@value": [ true ] }`, 
wantErr: true, }, } for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() var rsp Response rsp.Status.Code = StatusSuccess rsp.Result.Data = []byte(tc.data) got, err := rsp.ReadString() if tc.wantErr { assert.Error(t, err) } else { require.NoError(t, err) assert.Equal(t, tc.want, got) } }) } } ent-0.5.4/dialect/gremlin/status.go000066400000000000000000000060571377533537200172510ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package gremlin const ( // StatusSuccess is returned on success. StatusSuccess = 200 // StatusNoContent means the server processed the request but there is no result to return. StatusNoContent = 204 // StatusPartialContent indicates the server successfully returned some content, but there // is more in the stream to arrive wait for a success code to signify the end. StatusPartialContent = 206 // StatusUnauthorized means the request attempted to access resources that // the requesting user did not have access to. StatusUnauthorized = 401 // StatusAuthenticate denotes a challenge from the server for the client to authenticate its request. StatusAuthenticate = 407 // StatusMalformedRequest means the request message was not properly formatted which means it could not be parsed at // all or the "op" code was not recognized such that Gremlin Server could properly route it for processing. // Check the message format and retry the request. StatusMalformedRequest = 498 // StatusInvalidRequestArguments means the request message was parsable, but the arguments supplied in the message // were in conflict or incomplete. Check the message format and retry the request. StatusInvalidRequestArguments = 499 // StatusServerError indicates a general server error occurred that prevented the request from being processed. 
StatusServerError = 500 // StatusScriptEvaluationError is returned when the script submitted for processing evaluated in the ScriptEngine // with errors and could not be processed. Check the script submitted for syntax errors or other problems // and then resubmit. StatusScriptEvaluationError = 597 // StatusServerTimeout means the server exceeded one of the timeout settings for the request and could therefore // only partially responded or did not respond at all. StatusServerTimeout = 598 // StatusServerSerializationError means the server was not capable of serializing an object that was returned from the // script supplied on the request. Either transform the object into something Gremlin Server can process within // the script or install mapper serialization classes to Gremlin Server. StatusServerSerializationError = 599 ) var statusText = map[int]string{ StatusSuccess: "Success", StatusNoContent: "No Content", StatusPartialContent: "Partial Content", StatusUnauthorized: "Unauthorized", StatusAuthenticate: "Authenticate", StatusMalformedRequest: "Malformed Request", StatusInvalidRequestArguments: "Invalid Request Arguments", StatusServerError: "Server Error", StatusScriptEvaluationError: "Script Evaluation Error", StatusServerTimeout: "Server Timeout", StatusServerSerializationError: "Server Serialization Error", } // StatusText returns status text of code. func StatusText(code int) string { return statusText[code] } ent-0.5.4/dialect/gremlin/status_test.go000066400000000000000000000006121377533537200202770ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. 
package gremlin import ( "testing" "github.com/stretchr/testify/assert" ) func TestStatusText(t *testing.T) { assert.NotEmpty(t, StatusText(StatusSuccess)) assert.Empty(t, StatusText(4242)) } ent-0.5.4/dialect/gremlin/tokens.go000066400000000000000000000045451377533537200172310ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package gremlin // Gremlin server operations. const ( // OpsAuthentication used by the client to authenticate itself. OpsAuthentication = "authentication" // OpsBytecode used for a request that contains the Bytecode representation of a Traversal. OpsBytecode = "bytecode" // OpsEval used to evaluate a Gremlin script provided as a string. OpsEval = "eval" // OpsGather used to get a particular side-effect as produced by a previously executed Traversal. OpsGather = "gather" // OpsKeys used to get all the keys of all side-effects as produced by a previously executed Traversal. OpsKeys = "keys" // OpsClose used to get all the keys of all side-effects as produced by a previously executed Traversal. OpsClose = "close" ) // Gremlin server operation processors. const ( // ProcessorTraversal is the default operation processor. ProcessorTraversal = "traversal" ) const ( // ArgsBatchSize allows to defines the number of iterations each ResponseMessage should contain ArgsBatchSize = "batchSize" // ArgsBindings allows to provide a map of key/value pairs to apply // as variables in the context of the Gremlin script. ArgsBindings = "bindings" // ArgsAliases allows to define aliases that represent globally bound Graph and TraversalSource objects. ArgsAliases = "aliases" // ArgsGremlin corresponds to the Traversal to evaluate. ArgsGremlin = "gremlin" // ArgsSideEffect allows to specify the unique identifier for the request. 
ArgsSideEffect = "sideEffect" // ArgsSideEffectKey allows to specify the key for a specific side-effect. ArgsSideEffectKey = "sideEffectKey" // ArgsAggregateTo describes how side-effect data should be treated. ArgsAggregateTo = "aggregateTo" // ArgsLanguage allows to change the flavor of Gremlin used (e.g. gremlin-groovy). ArgsLanguage = "language" // ArgsEvalTimeout allows to override the server setting that determines // the maximum time to wait for a script to execute on the server. ArgsEvalTimeout = "scriptEvaluationTimeout" // ArgsSasl defines the response to the server authentication challenge. ArgsSasl = "sasl" // ArgsSaslMechanism defines the SASL mechanism (e.g. PLAIN). ArgsSaslMechanism = "saslMechanism" ) ent-0.5.4/dialect/sql/000077500000000000000000000000001377533537200145315ustar00rootroot00000000000000ent-0.5.4/dialect/sql/builder.go000066400000000000000000001665051377533537200165230ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. // Package sql provides wrappers around the standard database/sql package // to allow the generated code to interact with a statically-typed API. // // Users that are interacting with this package should be aware that the // following builders don't check the given SQL syntax nor validate or escape // user-inputs. ~All validations are expected to be happened in the generated // ent package. package sql import ( "bytes" "database/sql/driver" "fmt" "strconv" "strings" "github.com/facebook/ent/dialect" ) // Querier wraps the basic Query method that is implemented // by the different builders in this file. type Querier interface { // Query returns the query representation of the element // and its arguments (if any). Query() (string, []interface{}) } // ColumnBuilder is a builder for column definition in table creation. 
type ColumnBuilder struct { Builder typ string // column type. name string // column name. attr string // extra attributes. modify bool // modify existing. fk *ForeignKeyBuilder // foreign-key constraint. check func(*Builder) // column checks. } // Column returns a new ColumnBuilder with the given name. // // sql.Column("group_id").Type("int").Attr("UNIQUE") // func Column(name string) *ColumnBuilder { return &ColumnBuilder{name: name} } // Type sets the column type. func (c *ColumnBuilder) Type(t string) *ColumnBuilder { c.typ = t return c } // Attr sets an extra attribute for the column, like UNIQUE or AUTO_INCREMENT. func (c *ColumnBuilder) Attr(attr string) *ColumnBuilder { if c.attr != "" && attr != "" { c.attr += " " } c.attr += attr return c } // Constraint adds the CONSTRAINT clause to the ADD COLUMN statement in SQLite. func (c *ColumnBuilder) Constraint(fk *ForeignKeyBuilder) *ColumnBuilder { c.fk = fk return c } // Check adds a CHECK clause to the ADD COLUMN statement. func (c *ColumnBuilder) Check(check func(*Builder)) *ColumnBuilder { c.check = check return c } // Query returns query representation of a Column. func (c *ColumnBuilder) Query() (string, []interface{}) { c.Ident(c.name) if c.typ != "" { if c.postgres() && c.modify { c.WriteString(" TYPE") } c.Pad().WriteString(c.typ) } if c.attr != "" { c.Pad().WriteString(c.attr) } if c.fk != nil { c.WriteString(" CONSTRAINT " + c.fk.symbol) c.Pad().Join(c.fk.ref) for _, action := range c.fk.actions { c.Pad().WriteString(action) } } if c.check != nil { c.WriteString(" CHECK ") c.Nested(c.check) } return c.String(), c.args } // TableBuilder is a query builder for `CREATE TABLE` statement. type TableBuilder struct { Builder name string // table name. exists bool // check existence. charset string // table charset. collation string // table collation. options string // table options. columns []Querier // table columns. primary []string // primary key. constraints []Querier // foreign keys and indices. 
} // CreateTable returns a query builder for the `CREATE TABLE` statement. // // CreateTable("users"). // Columns( // Column("id").Type("int").Attr("auto_increment"), // Column("name").Type("varchar(255)"), // ). // PrimaryKey("id") // func CreateTable(name string) *TableBuilder { return &TableBuilder{name: name} } // IfNotExists appends the `IF NOT EXISTS` clause to the `CREATE TABLE` statement. func (t *TableBuilder) IfNotExists() *TableBuilder { t.exists = true return t } // Column appends the given column to the `CREATE TABLE` statement. func (t *TableBuilder) Column(c *ColumnBuilder) *TableBuilder { t.columns = append(t.columns, c) return t } // Columns appends the a list of columns to the builder. func (t *TableBuilder) Columns(columns ...*ColumnBuilder) *TableBuilder { t.columns = make([]Querier, 0, len(columns)) for i := range columns { t.columns = append(t.columns, columns[i]) } return t } // PrimaryKey adds a column to the primary-key constraint in the statement. func (t *TableBuilder) PrimaryKey(column ...string) *TableBuilder { t.primary = append(t.primary, column...) return t } // ForeignKeys adds a list of foreign-keys to the statement (without constraints). func (t *TableBuilder) ForeignKeys(fks ...*ForeignKeyBuilder) *TableBuilder { queries := make([]Querier, len(fks)) for i := range fks { // erase the constraint symbol/name. fks[i].symbol = "" queries[i] = fks[i] } t.constraints = append(t.constraints, queries...) return t } // Constraints adds a list of foreign-key constraints to the statement. func (t *TableBuilder) Constraints(fks ...*ForeignKeyBuilder) *TableBuilder { queries := make([]Querier, len(fks)) for i := range fks { queries[i] = &Wrapper{"CONSTRAINT %s", fks[i]} } t.constraints = append(t.constraints, queries...) return t } // Charset appends the `CHARACTER SET` clause to the statement. MySQL only. func (t *TableBuilder) Charset(s string) *TableBuilder { t.charset = s return t } // Collate appends the `COLLATE` clause to the statement. 
MySQL only. func (t *TableBuilder) Collate(s string) *TableBuilder { t.collation = s return t } // Options appends additional options to to the statement (MySQL only). func (t *TableBuilder) Options(s string) *TableBuilder { t.options = s return t } // Query returns query representation of a `CREATE TABLE` statement. // // CREATE TABLE [IF NOT EXISTS] name // (table definition) // [charset and collation] // func (t *TableBuilder) Query() (string, []interface{}) { t.WriteString("CREATE TABLE ") if t.exists { t.WriteString("IF NOT EXISTS ") } t.Ident(t.name) t.Nested(func(b *Builder) { b.JoinComma(t.columns...) if len(t.primary) > 0 { b.Comma().WriteString("PRIMARY KEY") b.Nested(func(b *Builder) { b.IdentComma(t.primary...) }) } if len(t.constraints) > 0 { b.Comma().JoinComma(t.constraints...) } }) if t.charset != "" { t.WriteString(" CHARACTER SET " + t.charset) } if t.collation != "" { t.WriteString(" COLLATE " + t.collation) } if t.options != "" { t.WriteString(" " + t.options) } return t.String(), t.args } // DescribeBuilder is a query builder for `DESCRIBE` statement. type DescribeBuilder struct { Builder name string // table name. } // Describe returns a query builder for the `DESCRIBE` statement. // // Describe("users") // func Describe(name string) *DescribeBuilder { return &DescribeBuilder{name: name} } // Query returns query representation of a `DESCRIBE` statement. func (t *DescribeBuilder) Query() (string, []interface{}) { t.WriteString("DESCRIBE ") t.Ident(t.name) return t.String(), nil } // TableAlter is a query builder for `ALTER TABLE` statement. type TableAlter struct { Builder name string // table to alter. Queries []Querier // columns and foreign-keys to add. } // AlterTable returns a query builder for the `ALTER TABLE` statement. // // AlterTable("users"). // AddColumn(Column("group_id").Type("int").Attr("UNIQUE")). // AddForeignKey(ForeignKey().Columns("group_id"). 
// Reference(Reference().Table("groups").Columns("id")).OnDelete("CASCADE")), // ) // func AlterTable(name string) *TableAlter { return &TableAlter{name: name} } // AddColumn appends the `ADD COLUMN` clause to the given `ALTER TABLE` statement. func (t *TableAlter) AddColumn(c *ColumnBuilder) *TableAlter { t.Queries = append(t.Queries, &Wrapper{"ADD COLUMN %s", c}) return t } // ModifyColumn appends the `MODIFY/ALTER COLUMN` clause to the given `ALTER TABLE` statement. func (t *TableAlter) ModifyColumn(c *ColumnBuilder) *TableAlter { switch { case t.postgres(): c.modify = true t.Queries = append(t.Queries, &Wrapper{"ALTER COLUMN %s", c}) default: t.Queries = append(t.Queries, &Wrapper{"MODIFY COLUMN %s", c}) } return t } // RenameColumn appends the `RENAME COLUMN` clause to the given `ALTER TABLE` statement. func (t *TableAlter) RenameColumn(old, new string) *TableAlter { t.Queries = append(t.Queries, Raw(fmt.Sprintf("RENAME COLUMN %s TO %s", t.Quote(old), t.Quote(new)))) return t } // ModifyColumns calls ModifyColumn with each of the given builders. func (t *TableAlter) ModifyColumns(cs ...*ColumnBuilder) *TableAlter { for _, c := range cs { t.ModifyColumn(c) } return t } // DropColumn appends the `DROP COLUMN` clause to the given `ALTER TABLE` statement. func (t *TableAlter) DropColumn(c *ColumnBuilder) *TableAlter { t.Queries = append(t.Queries, &Wrapper{"DROP COLUMN %s", c}) return t } // ChangeColumn appends the `CHANGE COLUMN` clause to the given `ALTER TABLE` statement. func (t *TableAlter) ChangeColumn(name string, c *ColumnBuilder) *TableAlter { prefix := fmt.Sprintf("CHANGE COLUMN %s", t.Quote(name)) t.Queries = append(t.Queries, &Wrapper{prefix + " %s", c}) return t } // RenameIndex appends the `RENAME INDEX` clause to the given `ALTER TABLE` statement. 
func (t *TableAlter) RenameIndex(curr, new string) *TableAlter { t.Queries = append(t.Queries, Raw(fmt.Sprintf("RENAME INDEX %s TO %s", t.Quote(curr), t.Quote(new)))) return t } // DropIndex appends the `DROP INDEX` clause to the given `ALTER TABLE` statement. func (t *TableAlter) DropIndex(name string) *TableAlter { t.Queries = append(t.Queries, Raw(fmt.Sprintf("DROP INDEX %s", t.Quote(name)))) return t } // AddIndex appends the `ADD INDEX` clause to the given `ALTER TABLE` statement. func (t *TableAlter) AddIndex(idx *IndexBuilder) *TableAlter { b := &Builder{dialect: t.dialect} b.WriteString("ADD ") if idx.unique { b.WriteString("UNIQUE ") } b.WriteString("INDEX ") b.Ident(idx.name) b.Nested(func(b *Builder) { b.IdentComma(idx.columns...) }) t.Queries = append(t.Queries, b) return t } // AddForeignKey adds a foreign key constraint to the `ALTER TABLE` statement. func (t *TableAlter) AddForeignKey(fk *ForeignKeyBuilder) *TableAlter { t.Queries = append(t.Queries, &Wrapper{"ADD CONSTRAINT %s", fk}) return t } // DropConstraint appends the `DROP CONSTRAINT` clause to the given `ALTER TABLE` statement. func (t *TableAlter) DropConstraint(ident string) *TableAlter { t.Queries = append(t.Queries, Raw(fmt.Sprintf("DROP CONSTRAINT %s", t.Quote(ident)))) return t } // DropForeignKey appends the `DROP FOREIGN KEY` clause to the given `ALTER TABLE` statement. func (t *TableAlter) DropForeignKey(ident string) *TableAlter { t.Queries = append(t.Queries, Raw(fmt.Sprintf("DROP FOREIGN KEY %s", t.Quote(ident)))) return t } // Query returns query representation of the `ALTER TABLE` statement. // // ALTER TABLE name // [alter_specification] // func (t *TableAlter) Query() (string, []interface{}) { t.WriteString("ALTER TABLE ") t.Ident(t.name) t.Pad() t.JoinComma(t.Queries...) return t.String(), t.args } // IndexAlter is a query builder for `ALTER INDEX` statement. type IndexAlter struct { Builder name string // index to alter. Queries []Querier // alter options. 
} // AlterIndex returns a query builder for the `ALTER INDEX` statement. // // AlterIndex("old_key"). // Rename("new_key") // func AlterIndex(name string) *IndexAlter { return &IndexAlter{name: name} } // Rename appends the `RENAME TO` clause to the `ALTER INDEX` statement. func (i *IndexAlter) Rename(name string) *IndexAlter { i.Queries = append(i.Queries, Raw(fmt.Sprintf("RENAME TO %s", i.Quote(name)))) return i } // Query returns query representation of the `ALTER INDEX` statement. // // ALTER INDEX name // [alter_specification] // func (i *IndexAlter) Query() (string, []interface{}) { i.WriteString("ALTER INDEX ") i.Ident(i.name) i.Pad() i.JoinComma(i.Queries...) return i.String(), i.args } // ForeignKeyBuilder is the builder for the foreign-key constraint clause. type ForeignKeyBuilder struct { Builder symbol string columns []string actions []string ref *ReferenceBuilder } // ForeignKey returns a builder for the foreign-key constraint clause in create/alter table statements. // // ForeignKey(). // Columns("group_id"). // Reference(Reference().Table("groups").Columns("id")). // OnDelete("CASCADE") // func ForeignKey(symbol ...string) *ForeignKeyBuilder { fk := &ForeignKeyBuilder{} if len(symbol) != 0 { fk.symbol = symbol[0] } return fk } // Symbol sets the symbol of the foreign key. func (fk *ForeignKeyBuilder) Symbol(s string) *ForeignKeyBuilder { fk.symbol = s return fk } // Columns sets the columns of the foreign key in the source table. func (fk *ForeignKeyBuilder) Columns(s ...string) *ForeignKeyBuilder { fk.columns = append(fk.columns, s...) return fk } // Reference sets the reference clause. func (fk *ForeignKeyBuilder) Reference(r *ReferenceBuilder) *ForeignKeyBuilder { fk.ref = r return fk } // OnDelete sets the on delete action for this constraint. func (fk *ForeignKeyBuilder) OnDelete(action string) *ForeignKeyBuilder { fk.actions = append(fk.actions, "ON DELETE "+action) return fk } // OnUpdate sets the on delete action for this constraint. 
func (fk *ForeignKeyBuilder) OnUpdate(action string) *ForeignKeyBuilder { fk.actions = append(fk.actions, "ON UPDATE "+action) return fk } // Query returns query representation of a foreign key constraint. func (fk *ForeignKeyBuilder) Query() (string, []interface{}) { if fk.symbol != "" { fk.Ident(fk.symbol).Pad() } fk.WriteString("FOREIGN KEY") fk.Nested(func(b *Builder) { b.IdentComma(fk.columns...) }) fk.Pad().Join(fk.ref) for _, action := range fk.actions { fk.Pad().WriteString(action) } return fk.String(), fk.args } // ReferenceBuilder is a builder for the reference clause in constraints. For example, in foreign key creation. type ReferenceBuilder struct { Builder table string // referenced table. columns []string // referenced columns. } // Reference create a reference builder for the reference_option clause. // // Reference().Table("groups").Columns("id") // func Reference() *ReferenceBuilder { return &ReferenceBuilder{} } // Table sets the referenced table. func (r *ReferenceBuilder) Table(s string) *ReferenceBuilder { r.table = s return r } // Columns sets the columns of the referenced table. func (r *ReferenceBuilder) Columns(s ...string) *ReferenceBuilder { r.columns = append(r.columns, s...) return r } // Query returns query representation of a reference clause. func (r *ReferenceBuilder) Query() (string, []interface{}) { r.WriteString("REFERENCES ") r.Ident(r.table) r.Nested(func(b *Builder) { b.IdentComma(r.columns...) }) return r.String(), r.args } // IndexBuilder is a builder for `CREATE INDEX` statement. type IndexBuilder struct { Builder name string unique bool table string columns []string } // CreateIndex creates a builder for the `CREATE INDEX` statement. // // CreateIndex("index_name"). // Unique(). // Table("users"). // Column("name") // // Or: // // CreateIndex("index_name"). // Unique(). // Table("users"). 
// Columns("name", "age") // func CreateIndex(name string) *IndexBuilder { return &IndexBuilder{name: name} } // Unique sets the index to be a unique index. func (i *IndexBuilder) Unique() *IndexBuilder { i.unique = true return i } // Table defines the table for the index. func (i *IndexBuilder) Table(table string) *IndexBuilder { i.table = table return i } // Column appends a column to the column list for the index. func (i *IndexBuilder) Column(column string) *IndexBuilder { i.columns = append(i.columns, column) return i } // Columns appends the given columns to the column list for the index. func (i *IndexBuilder) Columns(columns ...string) *IndexBuilder { i.columns = append(i.columns, columns...) return i } // Query returns query representation of a reference clause. func (i *IndexBuilder) Query() (string, []interface{}) { i.WriteString("CREATE ") if i.unique { i.WriteString("UNIQUE ") } i.WriteString("INDEX ") i.Ident(i.name) i.WriteString(" ON ") i.Ident(i.table).Nested(func(b *Builder) { b.IdentComma(i.columns...) }) return i.String(), nil } // DropIndexBuilder is a builder for `DROP INDEX` statement. type DropIndexBuilder struct { Builder name string table string } // DropIndex creates a builder for the `DROP INDEX` statement. // // MySQL: // // DropIndex("index_name"). // Table("users"). // // SQLite/PostgreSQL: // // DropIndex("index_name") // func DropIndex(name string) *DropIndexBuilder { return &DropIndexBuilder{name: name} } // Table defines the table for the index. func (d *DropIndexBuilder) Table(table string) *DropIndexBuilder { d.table = table return d } // Query returns query representation of a reference clause. // // DROP INDEX index_name [ON table_name] // func (d *DropIndexBuilder) Query() (string, []interface{}) { d.WriteString("DROP INDEX ") d.Ident(d.name) if d.table != "" { d.WriteString(" ON ") d.Ident(d.table) } return d.String(), nil } // InsertBuilder is a builder for `INSERT INTO` statement. 
type InsertBuilder struct { Builder table string schema string columns []string defaults string returning []string values [][]interface{} } // Insert creates a builder for the `INSERT INTO` statement. // // Insert("users"). // Columns("name", "age"). // Values("a8m", 10). // Values("foo", 20) // // Note: Insert inserts all values in one batch. func Insert(table string) *InsertBuilder { return &InsertBuilder{table: table} } // Schema sets the database name for the insert table. func (i *InsertBuilder) Schema(name string) *InsertBuilder { i.schema = name return i } // Set is a syntactic sugar API for inserting only one row. func (i *InsertBuilder) Set(column string, v interface{}) *InsertBuilder { i.columns = append(i.columns, column) if len(i.values) == 0 { i.values = append(i.values, []interface{}{v}) } else { i.values[0] = append(i.values[0], v) } return i } // Columns sets the columns of the insert statement. func (i *InsertBuilder) Columns(columns ...string) *InsertBuilder { i.columns = append(i.columns, columns...) return i } // Values append a value tuple for the insert statement. func (i *InsertBuilder) Values(values ...interface{}) *InsertBuilder { i.values = append(i.values, values) return i } // Default sets the default values clause based on the dialect type. func (i *InsertBuilder) Default() *InsertBuilder { switch i.Dialect() { case dialect.MySQL: i.defaults = "VALUES ()" case dialect.SQLite, dialect.Postgres: i.defaults = "DEFAULT VALUES" } return i } // Returning adds the `RETURNING` clause to the insert statement. PostgreSQL only. func (i *InsertBuilder) Returning(columns ...string) *InsertBuilder { i.returning = columns return i } // Query returns query representation of an `INSERT INTO` statement. 
func (i *InsertBuilder) Query() (string, []interface{}) {
	i.WriteString("INSERT INTO ")
	i.writeSchema(i.schema)
	i.Ident(i.table).Pad()
	if i.defaults != "" && len(i.columns) == 0 {
		// No columns were set; emit the dialect-specific
		// default-values clause configured by Default().
		i.WriteString(i.defaults)
	} else {
		i.Nested(func(b *Builder) {
			b.IdentComma(i.columns...)
		})
		i.WriteString(" VALUES ")
		// One parenthesized tuple per call to Values/Set.
		for j, v := range i.values {
			if j > 0 {
				i.Comma()
			}
			i.Nested(func(b *Builder) {
				b.Args(v...)
			})
		}
	}
	// RETURNING is emitted only for the PostgreSQL dialect.
	if len(i.returning) > 0 && i.postgres() {
		i.WriteString(" RETURNING ")
		i.IdentComma(i.returning...)
	}
	return i.String(), i.args
}

// UpdateBuilder is a builder for `UPDATE` statement.
type UpdateBuilder struct {
	Builder
	table   string        // table being updated.
	schema  string        // optional database/schema qualifier.
	where   *Predicate    // optional WHERE predicate.
	nulls   []string      // columns to be set to NULL.
	columns []string      // columns assigned via Set/Add.
	values  []interface{} // values matching columns (1:1 by index).
}

// Update creates a builder for the `UPDATE` statement.
//
//	Update("users").Set("name", "foo").Set("age", 10)
//
func Update(table string) *UpdateBuilder { return &UpdateBuilder{table: table} }

// Schema sets the database name for the updated table.
func (u *UpdateBuilder) Schema(name string) *UpdateBuilder {
	u.schema = name
	return u
}

// Set sets a column and its value.
func (u *UpdateBuilder) Set(column string, v interface{}) *UpdateBuilder {
	u.columns = append(u.columns, column)
	u.values = append(u.values, v)
	return u
}

// Add adds a numeric value to the given column.
// NULL column values are treated as 0 via COALESCE.
func (u *UpdateBuilder) Add(column string, v interface{}) *UpdateBuilder {
	u.columns = append(u.columns, column)
	u.values = append(u.values, P().Append(func(b *Builder) {
		b.WriteString("COALESCE")
		b.Nested(func(b *Builder) {
			b.Ident(column).Comma().Arg(0)
		})
		b.WriteString(" + ")
		b.Arg(v)
	}))
	return u
}

// SetNull sets a column as null value.
func (u *UpdateBuilder) SetNull(column string) *UpdateBuilder {
	u.nulls = append(u.nulls, column)
	return u
}

// Where adds a where predicate for update statement.
func (u *UpdateBuilder) Where(p *Predicate) *UpdateBuilder { if u.where != nil { u.where = And(u.where, p) } else { u.where = p } return u } // FromSelect makes it possible to update entities that match the sub-query. func (u *UpdateBuilder) FromSelect(s *Selector) *UpdateBuilder { u.Where(s.where) if table, _ := s.from.(*SelectTable); table != nil { u.table = table.name } return u } // Empty reports whether this builder does not contain update changes. func (u *UpdateBuilder) Empty() bool { return len(u.columns) == 0 && len(u.nulls) == 0 } // Query returns query representation of an `UPDATE` statement. func (u *UpdateBuilder) Query() (string, []interface{}) { u.WriteString("UPDATE ") u.writeSchema(u.schema) u.Ident(u.table).WriteString(" SET ") for i, c := range u.nulls { if i > 0 { u.Comma() } u.Ident(c).WriteString(" = NULL") } if len(u.nulls) > 0 && len(u.columns) > 0 { u.Comma() } for i, c := range u.columns { if i > 0 { u.Comma() } u.Ident(c).WriteString(" = ") switch v := u.values[i].(type) { case Querier: u.Join(v) default: u.Arg(v) } } if u.where != nil { u.WriteString(" WHERE ") u.Join(u.where) } return u.String(), u.args } // DeleteBuilder is a builder for `DELETE` statement. type DeleteBuilder struct { Builder table string schema string where *Predicate } // Delete creates a builder for the `DELETE` statement. // // Delete("users"). // Where( // Or( // EQ("name", "foo").And().EQ("age", 10), // EQ("name", "bar").And().EQ("age", 20), // And( // EQ("name", "qux"), // EQ("age", 1).Or().EQ("age", 2), // ), // ), // ) // func Delete(table string) *DeleteBuilder { return &DeleteBuilder{table: table} } // Schema sets the database name for the table whose row will be deleted. func (d *DeleteBuilder) Schema(name string) *DeleteBuilder { d.schema = name return d } // Where appends a where predicate to the `DELETE` statement. 
func (d *DeleteBuilder) Where(p *Predicate) *DeleteBuilder { if d.where != nil { d.where = And(d.where, p) } else { d.where = p } return d } // FromSelect makes it possible to delete a sub query. func (d *DeleteBuilder) FromSelect(s *Selector) *DeleteBuilder { d.Where(s.where) if table, _ := s.from.(*SelectTable); table != nil { d.table = table.name } return d } // Query returns query representation of a `DELETE` statement. func (d *DeleteBuilder) Query() (string, []interface{}) { d.WriteString("DELETE FROM ") d.writeSchema(d.schema) d.Ident(d.table) if d.where != nil { d.WriteString(" WHERE ") d.Join(d.where) } return d.String(), d.args } // Predicate is a where predicate. type Predicate struct { Builder depth int fns []func(*Builder) } // P creates a new predicate. // // P().EQ("name", "a8m").And().EQ("age", 30) // func P(fns ...func(*Builder)) *Predicate { return &Predicate{fns: fns} } // Or combines all given predicates with OR between them. // // Or(EQ("name", "foo"), EQ("name", "bar")) // func Or(preds ...*Predicate) *Predicate { p := P() return p.Append(func(b *Builder) { p.mayWrap(preds, b, "OR") }) } // False appends the FALSE keyword to the predicate. // // Delete().From("users").Where(False()) // func False() *Predicate { return P().False() } // False appends FALSE to the predicate. func (p *Predicate) False() *Predicate { return p.Append(func(b *Builder) { b.WriteString("FALSE") }) } // Not wraps the given predicate with the not predicate. // // Not(Or(EQ("name", "foo"), EQ("name", "bar"))) // func Not(pred *Predicate) *Predicate { return P().Not().Append(func(b *Builder) { b.Nested(func(b *Builder) { b.Join(pred) }) }) } // Not appends NOT to the predicate. func (p *Predicate) Not() *Predicate { return p.Append(func(b *Builder) { b.WriteString("NOT ") }) } // And combines all given predicates with AND between them. 
func And(preds ...*Predicate) *Predicate { p := P() return p.Append(func(b *Builder) { p.mayWrap(preds, b, "AND") }) } // EQ returns a "=" predicate. func EQ(col string, value interface{}) *Predicate { return P().EQ(col, value) } // EQ appends a "=" predicate. func (p *Predicate) EQ(col string, arg interface{}) *Predicate { return p.Append(func(b *Builder) { b.Ident(col) b.WriteOp(OpEQ) b.Arg(arg) }) } // NEQ returns a "<>" predicate. func NEQ(col string, value interface{}) *Predicate { return P().NEQ(col, value) } // NEQ appends a "<>" predicate. func (p *Predicate) NEQ(col string, arg interface{}) *Predicate { return p.Append(func(b *Builder) { b.Ident(col) b.WriteOp(OpNEQ) b.Arg(arg) }) } // LT returns a "<" predicate. func LT(col string, value interface{}) *Predicate { return P().LT(col, value) } // LT appends a "<" predicate. func (p *Predicate) LT(col string, arg interface{}) *Predicate { return p.Append(func(b *Builder) { b.Ident(col) p.WriteOp(OpLT) b.Arg(arg) }) } // LTE returns a "<=" predicate. func LTE(col string, value interface{}) *Predicate { return P().LTE(col, value) } // LTE appends a "<=" predicate. func (p *Predicate) LTE(col string, arg interface{}) *Predicate { return p.Append(func(b *Builder) { b.Ident(col) p.WriteOp(OpLTE) b.Arg(arg) }) } // GT returns a ">" predicate. func GT(col string, value interface{}) *Predicate { return P().GT(col, value) } // GT appends a ">" predicate. func (p *Predicate) GT(col string, arg interface{}) *Predicate { return p.Append(func(b *Builder) { b.Ident(col) p.WriteOp(OpGT) b.Arg(arg) }) } // GTE returns a ">=" predicate. func GTE(col string, value interface{}) *Predicate { return P().GTE(col, value) } // GTE appends a ">=" predicate. func (p *Predicate) GTE(col string, arg interface{}) *Predicate { return p.Append(func(b *Builder) { b.Ident(col) p.WriteOp(OpGTE) b.Arg(arg) }) } // NotNull returns the `IS NOT NULL` predicate. 
func NotNull(col string) *Predicate { return P().NotNull(col) } // NotNull appends the `IS NOT NULL` predicate. func (p *Predicate) NotNull(col string) *Predicate { return p.Append(func(b *Builder) { b.Ident(col).WriteString(" IS NOT NULL") }) } // IsNull returns the `IS NULL` predicate. func IsNull(col string) *Predicate { return P().IsNull(col) } // IsNull appends the `IS NULL` predicate. func (p *Predicate) IsNull(col string) *Predicate { return p.Append(func(b *Builder) { b.Ident(col).WriteString(" IS NULL") }) } // In returns the `IN` predicate. func In(col string, args ...interface{}) *Predicate { return P().In(col, args...) } // In appends the `IN` predicate. func (p *Predicate) In(col string, args ...interface{}) *Predicate { if len(args) == 0 { return p } return p.Append(func(b *Builder) { b.Ident(col).WriteOp(OpIn) b.Nested(func(b *Builder) { if s, ok := args[0].(*Selector); ok { b.Join(s) } else { b.Args(args...) } }) }) } // InInts returns the `IN` predicate for ints. func InInts(col string, args ...int) *Predicate { return P().InInts(col, args...) } // InValues adds the `IN` predicate for slice of driver.Value. func InValues(col string, args ...driver.Value) *Predicate { return P().InValues(col, args...) } // InInts adds the `IN` predicate for ints. func (p *Predicate) InInts(col string, args ...int) *Predicate { iface := make([]interface{}, len(args)) for i := range args { iface[i] = args[i] } return p.In(col, iface...) } // InValues adds the `IN` predicate for slice of driver.Value. func (p *Predicate) InValues(col string, args ...driver.Value) *Predicate { iface := make([]interface{}, len(args)) for i := range args { iface[i] = args[i] } return p.In(col, iface...) } // NotIn returns the `Not IN` predicate. func NotIn(col string, args ...interface{}) *Predicate { return P().NotIn(col, args...) } // NotIn appends the `Not IN` predicate. 
func (p *Predicate) NotIn(col string, args ...interface{}) *Predicate { if len(args) == 0 { return p } return p.Append(func(b *Builder) { b.Ident(col).WriteOp(OpNotIn) b.Nested(func(b *Builder) { if s, ok := args[0].(*Selector); ok { b.Join(s) } else { b.Args(args...) } }) }) } // Like returns the `LIKE` predicate. func Like(col, pattern string) *Predicate { return P().Like(col, pattern) } // Like appends the `LIKE` predicate. func (p *Predicate) Like(col, pattern string) *Predicate { return p.Append(func(b *Builder) { b.Ident(col).WriteOp(OpLike) b.Arg(pattern) }) } // HasPrefix is a helper predicate that checks prefix using the LIKE predicate. func HasPrefix(col, prefix string) *Predicate { return P().HasPrefix(col, prefix) } // HasPrefix is a helper predicate that checks prefix using the LIKE predicate. func (p *Predicate) HasPrefix(col, prefix string) *Predicate { return p.Like(col, prefix+"%") } // HasSuffix is a helper predicate that checks suffix using the LIKE predicate. func HasSuffix(col, suffix string) *Predicate { return P().HasSuffix(col, suffix) } // HasSuffix is a helper predicate that checks suffix using the LIKE predicate. func (p *Predicate) HasSuffix(col, suffix string) *Predicate { return p.Like(col, "%"+suffix) } // EqualFold is a helper predicate that applies the "=" predicate with case-folding. func EqualFold(col, sub string) *Predicate { return P().EqualFold(col, sub) } // EqualFold is a helper predicate that applies the "=" predicate with case-folding. func (p *Predicate) EqualFold(col, sub string) *Predicate { return p.Append(func(b *Builder) { f := &Func{} f.SetDialect(b.dialect) f.Lower(col) b.WriteString(f.String()) b.WriteOp(OpEQ) b.Arg(strings.ToLower(sub)) }) } // Contains is a helper predicate that checks substring using the LIKE predicate. func Contains(col, sub string) *Predicate { return P().Contains(col, sub) } // Contains is a helper predicate that checks substring using the LIKE predicate. 
func (p *Predicate) Contains(col, sub string) *Predicate { return p.Like(col, "%"+sub+"%") }

// ContainsFold is a helper predicate that checks substring using the LIKE predicate.
func ContainsFold(col, sub string) *Predicate { return P().ContainsFold(col, sub) }

// ContainsFold is a helper predicate that applies the LIKE predicate with case-folding.
func (p *Predicate) ContainsFold(col, sub string) *Predicate {
	return p.Append(func(b *Builder) {
		f := &Func{}
		f.SetDialect(b.dialect)
		switch b.dialect {
		case dialect.MySQL:
			// We assume the CHARACTER SET is configured to utf8mb4,
			// because this is how it is defined in dialect/sql/schema.
			b.Ident(col).WriteString(" COLLATE utf8mb4_general_ci LIKE ")
		case dialect.Postgres:
			// PostgreSQL has a native case-insensitive LIKE.
			b.Ident(col).WriteString(" ILIKE ")
		default: // SQLite.
			f.Lower(col)
			b.WriteString(f.String()).WriteString(" LIKE ")
		}
		b.Arg("%" + strings.ToLower(sub) + "%")
	})
}

// CompositeGT returns a composite ">" predicate
func CompositeGT(columns []string, args ...interface{}) *Predicate {
	return P().CompositeGT(columns, args...)
}

// CompositeLT returns a composite "<" predicate
func CompositeLT(columns []string, args ...interface{}) *Predicate {
	return P().CompositeLT(columns, args...)
}

// compositeP writes a row-value comparison of the form
// (col1, col2, ...) <operator> (arg1, arg2, ...).
func (p *Predicate) compositeP(operator string, columns []string, args ...interface{}) *Predicate {
	return p.Append(func(b *Builder) {
		b.Nested(func(nb *Builder) {
			nb.IdentComma(columns...)
		})
		b.WriteString(operator)
		b.WriteString("(")
		b.Args(args...)
		b.WriteString(")")
	})
}

// CompositeGT returns a composite ">" predicate.
func (p *Predicate) CompositeGT(columns []string, args ...interface{}) *Predicate {
	const operator = " > "
	return p.compositeP(operator, columns, args...)
}

// CompositeLT appends a composite "<" predicate.
func (p *Predicate) CompositeLT(columns []string, args ...interface{}) *Predicate {
	const operator = " < "
	return p.compositeP(operator, columns, args...)
}

// Append appends a new function to the predicate callbacks.
// The callback list are executed on call to Query. func (p *Predicate) Append(f func(*Builder)) *Predicate { p.fns = append(p.fns, f) return p } // Query returns query representation of a predicate. func (p *Predicate) Query() (string, []interface{}) { if p.Len() > 0 || len(p.args) > 0 { p.Reset() p.args = nil } for _, f := range p.fns { f(&p.Builder) } return p.String(), p.args } // clone returns a shallow clone of p. func (p *Predicate) clone() *Predicate { if p == nil { return p } return &Predicate{fns: append([]func(*Builder){}, p.fns...)} } func (p *Predicate) mayWrap(preds []*Predicate, b *Builder, op string) { switch n := len(preds); { case n == 1: b.Join(preds[0]) return case n > 1 && p.depth != 0: b.WriteByte('(') defer b.WriteByte(')') } for i := range preds { preds[i].depth = p.depth + 1 if i > 0 { b.WriteByte(' ') b.WriteString(op) b.WriteByte(' ') } if len(preds[i].fns) > 1 { b.Nested(func(b *Builder) { b.Join(preds[i]) }) } else { b.Join(preds[i]) } } } // Func represents an SQL function. type Func struct { Builder fns []func(*Builder) } // Lower wraps the given column with the LOWER function. // // P().EQ(sql.Lower("name"), "a8m") // func Lower(ident string) string { f := &Func{} f.Lower(ident) return f.String() } // Lower wraps the given ident with the LOWER function. func (f *Func) Lower(ident string) { f.byName("LOWER", ident) } // Count wraps the ident with the COUNT aggregation function. func Count(ident string) string { f := &Func{} f.Count(ident) return f.String() } // Count wraps the ident with the COUNT aggregation function. func (f *Func) Count(ident string) { f.byName("COUNT", ident) } // Max wraps the ident with the MAX aggregation function. func Max(ident string) string { f := &Func{} f.Max(ident) return f.String() } // Max wraps the ident with the MAX aggregation function. func (f *Func) Max(ident string) { f.byName("MAX", ident) } // Min wraps the ident with the MIN aggregation function. 
func Min(ident string) string {
	f := &Func{}
	f.Min(ident)
	return f.String()
}

// Min wraps the ident with the MIN aggregation function.
func (f *Func) Min(ident string) { f.byName("MIN", ident) }

// Sum wraps the ident with the SUM aggregation function.
func Sum(ident string) string {
	f := &Func{}
	f.Sum(ident)
	return f.String()
}

// Sum wraps the ident with the SUM aggregation function.
func (f *Func) Sum(ident string) { f.byName("SUM", ident) }

// Avg wraps the ident with the AVG aggregation function.
func Avg(ident string) string {
	f := &Func{}
	f.Avg(ident)
	return f.String()
}

// Avg wraps the ident with the AVG aggregation function.
func (f *Func) Avg(ident string) { f.byName("AVG", ident) }

// byName wraps an identifier with a function name, e.g. FN(`ident`).
func (f *Func) byName(fn, ident string) {
	f.Append(func(b *Builder) {
		f.WriteString(fn)
		f.Nested(func(b *Builder) {
			b.Ident(ident)
		})
	})
}

// Append appends a new function to the function callbacks.
// The callback list are executed on call to String.
func (f *Func) Append(fn func(*Builder)) *Func {
	f.fns = append(f.fns, fn)
	return f
}

// String implements the fmt.Stringer.
func (f *Func) String() string {
	for _, fn := range f.fns {
		fn(&f.Builder)
	}
	return f.Builder.String()
}

// As suffixes the given column with an alias (`a` AS `b`).
func As(ident string, as string) string {
	b := &Builder{}
	// Infer the dialect from the identifier's quoting style.
	b.fromIdent(ident)
	b.Ident(ident).Pad().WriteString("AS")
	b.Pad().Ident(as)
	return b.String()
}

// Distinct prefixes the given columns with the `DISTINCT` keyword (DISTINCT `id`).
func Distinct(idents ...string) string {
	b := &Builder{}
	if len(idents) > 0 {
		b.fromIdent(idents[0])
	}
	b.WriteString("DISTINCT")
	b.Pad().IdentComma(idents...)
	return b.String()
}

// TableView is a view that returns a table view. Can be a Table, Selector or a View (WITH statement).
type TableView interface {
	view()
}

// SelectTable is a table selector.
type SelectTable struct {
	Builder
	as     string // optional alias (AS clause).
	name   string // table name.
	schema string // optional database/schema qualifier.
	quote  bool   // whether to quote the table name in output.
}

// Table returns a new table selector.
//
//	t1 := Table("users").As("u")
//	return Select(t1.C("name"))
//
func Table(name string) *SelectTable {
	return &SelectTable{quote: true, name: name}
}

// Schema sets the schema name of the table.
func (s *SelectTable) Schema(name string) *SelectTable {
	s.schema = name
	return s
}

// As adds the AS clause to the table selector.
func (s *SelectTable) As(alias string) *SelectTable {
	s.as = alias
	return s
}

// C returns a formatted string for the table column.
func (s *SelectTable) C(column string) string {
	// Prefer the alias when one is set.
	name := s.name
	if s.as != "" {
		name = s.as
	}
	b := &Builder{dialect: s.dialect}
	if s.as == "" {
		// The schema qualifier only applies to the real table name.
		b.writeSchema(s.schema)
	}
	b.Ident(name).WriteByte('.').Ident(column)
	return b.String()
}

// Columns returns a list of formatted strings for the table columns.
func (s *SelectTable) Columns(columns ...string) []string {
	names := make([]string, 0, len(columns))
	for _, c := range columns {
		names = append(names, s.C(c))
	}
	return names
}

// Unquote makes the table name to be formatted as raw string (unquoted).
// It is useful when you don't want to query tables under the current database.
// For example: "INFORMATION_SCHEMA.TABLE_CONSTRAINTS" in MySQL.
func (s *SelectTable) Unquote() *SelectTable {
	s.quote = false
	return s
}

// ref returns the table reference.
func (s *SelectTable) ref() string {
	if !s.quote {
		return s.name
	}
	b := &Builder{dialect: s.dialect}
	b.writeSchema(s.schema)
	b.Ident(s.name)
	if s.as != "" {
		b.WriteString(" AS ")
		b.Ident(s.as)
	}
	return b.String()
}

// implement the table view.
func (*SelectTable) view() {}

// join table option.
type join struct {
	on    *Predicate // optional ON predicate.
	kind  string     // JOIN, LEFT JOIN or RIGHT JOIN.
	table TableView  // joined table or sub-query.
}

// clone a joiner.
func (j join) clone() join {
	if sel, ok := j.table.(*Selector); ok {
		j.table = sel.Clone()
	}
	j.on = j.on.clone()
	return j
}

// Selector is a builder for the `SELECT` statement.
type Selector struct { Builder as string columns []string from TableView joins []join where *Predicate or bool not bool order []string group []string having *Predicate limit *int offset *int distinct bool } // Select returns a new selector for the `SELECT` statement. // // t1 := Table("users").As("u") // t2 := Select().From(Table("groups")).Where(EQ("user_id", 10)).As("g") // return Select(t1.C("id"), t2.C("name")). // From(t1). // Join(t2). // On(t1.C("id"), t2.C("user_id")) // func Select(columns ...string) *Selector { return (&Selector{}).Select(columns...) } // Select changes the columns selection of the SELECT statement. // Empty selection means all columns *. func (s *Selector) Select(columns ...string) *Selector { s.columns = columns return s } // From sets the source of `FROM` clause. func (s *Selector) From(t TableView) *Selector { s.from = t if st, ok := t.(state); ok { st.SetDialect(s.dialect) } return s } // Distinct adds the DISTINCT keyword to the `SELECT` statement. func (s *Selector) Distinct() *Selector { s.distinct = true return s } // SetDistinct sets explicitly if the returned rows are distinct or indistinct. func (s *Selector) SetDistinct(v bool) *Selector { s.distinct = v return s } // Limit adds the `LIMIT` clause to the `SELECT` statement. func (s *Selector) Limit(limit int) *Selector { s.limit = &limit return s } // Offset adds the `OFFSET` clause to the `SELECT` statement. func (s *Selector) Offset(offset int) *Selector { s.offset = &offset return s } // Where sets or appends the given predicate to the statement. func (s *Selector) Where(p *Predicate) *Selector { if s.not { p = Not(p) s.not = false } switch { case s.where == nil: s.where = p case s.where != nil && s.or: s.where = Or(s.where, p) s.or = false default: s.where = And(s.where, p) } return s } // P returns the predicate of a selector. func (s *Selector) P() *Predicate { return s.where } // SetP sets explicitly the predicate function for the selector and clear its previous state. 
func (s *Selector) SetP(p *Predicate) *Selector { s.where = p s.or = false s.not = false return s } // FromSelect copies the predicate from a selector. func (s *Selector) FromSelect(s2 *Selector) *Selector { s.where = s2.where return s } // Not sets the next coming predicate with not. func (s *Selector) Not() *Selector { s.not = true return s } // Or sets the next coming predicate with OR operator (disjunction). func (s *Selector) Or() *Selector { s.or = true return s } // Table returns the selected table. func (s *Selector) Table() *SelectTable { return s.from.(*SelectTable) } // Join appends a `JOIN` clause to the statement. func (s *Selector) Join(t TableView) *Selector { return s.join("JOIN", t) } // LeftJoin appends a `LEFT JOIN` clause to the statement. func (s *Selector) LeftJoin(t TableView) *Selector { return s.join("LEFT JOIN", t) } // RightJoin appends a `RIGHT JOIN` clause to the statement. func (s *Selector) RightJoin(t TableView) *Selector { return s.join("RIGHT JOIN", t) } // join adds a join table to the selector with the given kind. func (s *Selector) join(kind string, t TableView) *Selector { s.joins = append(s.joins, join{ kind: kind, table: t, }) switch view := t.(type) { case *SelectTable: if view.as == "" { view.as = "t0" } case *Selector: if view.as == "" { view.as = "t" + strconv.Itoa(len(s.joins)) } } if st, ok := t.(state); ok { st.SetDialect(s.dialect) } return s } // C returns a formatted string for a selected column from this statement. func (s *Selector) C(column string) string { if s.as != "" { b := &Builder{dialect: s.dialect} b.Ident(s.as) b.WriteByte('.') b.Ident(column) return b.String() } return s.Table().C(column) } // Columns returns a list of formatted strings for a selected columns from this statement. 
func (s *Selector) Columns(columns ...string) []string { names := make([]string, 0, len(columns)) for _, c := range columns { names = append(names, s.C(c)) } return names } // OnP sets or appends the given predicate for the `ON` clause of the statement. func (s *Selector) OnP(p *Predicate) *Selector { if len(s.joins) > 0 { join := &s.joins[len(s.joins)-1] switch { case join.on == nil: join.on = p default: join.on = And(join.on, p) } } return s } // On sets the `ON` clause for the `JOIN` operation. func (s *Selector) On(c1, c2 string) *Selector { s.OnP(P(func(builder *Builder) { builder.Ident(c1).WriteOp(OpEQ).Ident(c2) })) return s } // As give this selection an alias. func (s *Selector) As(alias string) *Selector { s.as = alias return s } // Count sets the Select statement to be a `SELECT COUNT(*)`. func (s *Selector) Count(columns ...string) *Selector { column := "*" if len(columns) > 0 { b := &Builder{} b.IdentComma(columns...) column = b.String() } s.columns = []string{Count(column)} return s } // Clone returns a duplicate of the selector, including all associated steps. It can be // used to prepare common SELECT statements and use them differently after the clone is made. func (s *Selector) Clone() *Selector { if s == nil { return nil } joins := make([]join, len(s.joins)) for i := range s.joins { joins[i] = s.joins[i].clone() } return &Selector{ Builder: s.Builder.clone(), as: s.as, or: s.or, not: s.not, from: s.from, limit: s.limit, offset: s.offset, distinct: s.distinct, where: s.where.clone(), having: s.having.clone(), joins: append([]join{}, joins...), group: append([]string{}, s.group...), order: append([]string{}, s.order...), columns: append([]string{}, s.columns...), } } // Asc adds the ASC suffix for the given column. func Asc(column string) string { b := &Builder{} b.Ident(column).WriteString(" ASC") return b.String() } // Desc adds the DESC suffix for the given column. 
func Desc(column string) string { b := &Builder{} b.Ident(column).WriteString(" DESC") return b.String() } // OrderBy appends the `ORDER BY` clause to the `SELECT` statement. func (s *Selector) OrderBy(columns ...string) *Selector { s.order = append(s.order, columns...) return s } // GroupBy appends the `GROUP BY` clause to the `SELECT` statement. func (s *Selector) GroupBy(columns ...string) *Selector { s.group = append(s.group, columns...) return s } // Having appends a predicate for the `HAVING` clause. func (s *Selector) Having(p *Predicate) *Selector { s.having = p return s } // Query returns query representation of a `SELECT` statement. func (s *Selector) Query() (string, []interface{}) { b := s.Builder.clone() b.WriteString("SELECT ") if s.distinct { b.WriteString("DISTINCT ") } if len(s.columns) > 0 { b.IdentComma(s.columns...) } else { b.WriteString("*") } b.WriteString(" FROM ") switch t := s.from.(type) { case *SelectTable: t.SetDialect(s.dialect) b.WriteString(t.ref()) case *Selector: t.SetDialect(s.dialect) b.Nested(func(b *Builder) { b.Join(t) }) b.WriteString(" AS ") b.Ident(t.as) } for _, join := range s.joins { b.WriteString(" " + join.kind + " ") switch view := join.table.(type) { case *SelectTable: view.SetDialect(s.dialect) b.WriteString(view.ref()) case *Selector: view.SetDialect(s.dialect) b.Nested(func(b *Builder) { b.Join(view) }) b.WriteString(" AS ") b.Ident(view.as) } if join.on != nil { b.WriteString(" ON ") b.Join(join.on) } } if s.where != nil { b.WriteString(" WHERE ") b.Join(s.where) } if len(s.group) > 0 { b.WriteString(" GROUP BY ") b.IdentComma(s.group...) } if s.having != nil { b.WriteString(" HAVING ") b.Join(s.having) } if len(s.order) > 0 { b.WriteString(" ORDER BY ") b.IdentComma(s.order...) 
} if s.limit != nil { b.WriteString(" LIMIT ") b.WriteString(strconv.Itoa(*s.limit)) } if s.offset != nil { b.WriteString(" OFFSET ") b.WriteString(strconv.Itoa(*s.offset)) } s.total = b.total return b.String(), b.args } // implement the table view interface. func (*Selector) view() {} // WithBuilder is the builder for the `WITH` statement. type WithBuilder struct { Builder name string s *Selector } // With returns a new builder for the `WITH` statement. // // n := Queries{With("users_view").As(Select().From(Table("users"))), Select().From(Table("users_view"))} // return n.Query() // func With(name string) *WithBuilder { return &WithBuilder{name: name} } // Name returns the name of the view. func (w *WithBuilder) Name() string { return w.name } // As sets the view sub query. func (w *WithBuilder) As(s *Selector) *WithBuilder { w.s = s return w } // Query returns query representation of a `WITH` clause. func (w *WithBuilder) Query() (string, []interface{}) { w.WriteString(fmt.Sprintf("WITH %s AS ", w.name)) w.Nested(func(b *Builder) { b.Join(w.s) }) return w.String(), w.args } // implement the table view interface. func (*WithBuilder) view() {} // Wrapper wraps a given Querier with different format. // Used to prefix/suffix other queries. type Wrapper struct { format string wrapped Querier } // Query returns query representation of a wrapped Querier. func (w *Wrapper) Query() (string, []interface{}) { query, args := w.wrapped.Query() return fmt.Sprintf(w.format, query), args } // SetDialect calls SetDialect on the wrapped query. func (w *Wrapper) SetDialect(name string) { if s, ok := w.wrapped.(state); ok { s.SetDialect(name) } } // Dialect calls Dialect on the wrapped query. func (w *Wrapper) Dialect() string { if s, ok := w.wrapped.(state); ok { return s.Dialect() } return "" } // Total returns the total number of arguments so far. 
func (w *Wrapper) Total() int { if s, ok := w.wrapped.(state); ok { return s.Total() } return 0 } // SetTotal sets the value of the total arguments. // Used to pass this information between sub queries/expressions. func (w *Wrapper) SetTotal(total int) { if s, ok := w.wrapped.(state); ok { s.SetTotal(total) } } // Raw returns a raw sql Querier that is placed as-is in the query. func Raw(s string) Querier { return &raw{s} } type raw struct{ s string } func (r *raw) Query() (string, []interface{}) { return r.s, nil } // Queries are list of queries join with space between them. type Queries []Querier // Query returns query representation of Queriers. func (n Queries) Query() (string, []interface{}) { b := &Builder{} for i := range n { if i > 0 { b.Pad() } query, args := n[i].Query() b.WriteString(query) b.args = append(b.args, args...) } return b.String(), b.args } // Builder is the base query builder for the sql dsl. type Builder struct { bytes.Buffer // underlying buffer. dialect string // configured dialect. args []interface{} // query parameters. total int // total number of parameters in query tree. errs []error // errors that added during the query construction. } // Quote quotes the given identifier with the characters based // on the configured dialect. It defaults to "`". func (b *Builder) Quote(ident string) string { switch { case b.postgres(): // If it was quoted with the wrong // identifier character. if strings.Contains(ident, "`") { return strings.ReplaceAll(ident, "`", `"`) } return strconv.Quote(ident) // An identifier for unknown dialect. case b.dialect == "" && strings.ContainsAny(ident, "`\""): return ident default: return fmt.Sprintf("`%s`", ident) } } // Ident appends the given string as an identifier. 
func (b *Builder) Ident(s string) *Builder { switch { case len(s) == 0: case s != "*" && !b.isIdent(s) && !isFunc(s) && !isModifier(s): b.WriteString(b.Quote(s)) case (isFunc(s) || isModifier(s)) && b.postgres(): // Modifiers and aggregation functions that // were called without dialect information. b.WriteString(strings.ReplaceAll(s, "`", `"`)) default: b.WriteString(s) } return b } // IdentComma calls Ident on all arguments and adds a comma between them. func (b *Builder) IdentComma(s ...string) *Builder { for i := range s { if i > 0 { b.Comma() } b.Ident(s[i]) } return b } // WriteByte wraps the Buffer.WriteByte to make it chainable with other methods. func (b *Builder) WriteByte(c byte) *Builder { b.Buffer.WriteByte(c) return b } // WriteString wraps the Buffer.WriteString to make it chainable with other methods. func (b *Builder) WriteString(s string) *Builder { b.Buffer.WriteString(s) return b } // AddError appends an error to the builder errors. func (b *Builder) AddError(err error) *Builder { b.errs = append(b.errs, err) return b } func (b *Builder) writeSchema(schema string) { if schema != "" && b.dialect != dialect.SQLite { b.Ident(schema).WriteByte('.') } } // Err returns a concatenated error of all errors encountered during // the query-building, or were added manually by calling AddError. func (b *Builder) Err() error { if len(b.errs) == 0 { return nil } br := strings.Builder{} for i := range b.errs { if i > 0 { br.WriteString("; ") } br.WriteString(b.errs[i].Error()) } return fmt.Errorf(br.String()) } // An Op represents a predicate operator. type Op int // Predicate operators const ( OpEQ Op = iota // logical and. 
OpNEQ // <> OpGT // > OpGTE // >= OpLT // < OpLTE // <= OpIn // IN OpNotIn // NOT IN OpLike // LIKE OpIsNull // IS NULL OpNotNull // IS NOT NULL ) var ops = [...]string{ OpEQ: "=", OpNEQ: "<>", OpGT: ">", OpGTE: ">=", OpLT: "<", OpLTE: "<=", OpIn: "IN", OpNotIn: "NOT IN", OpLike: "LIKE", OpIsNull: "IS NULL", OpNotNull: "IS NOT NULL", } // WriteOp writes an operator to the builder. func (b *Builder) WriteOp(op Op) *Builder { switch { case op >= OpEQ && op <= OpLike: b.Pad().WriteString(ops[op]).Pad() case op == OpIsNull || op == OpNotNull: b.Pad().WriteString(ops[op]) default: panic(fmt.Sprintf("invalid op %d", op)) } return b } // Arg appends an input argument to the builder. func (b *Builder) Arg(a interface{}) *Builder { if r, ok := a.(*raw); ok { b.WriteString(r.s) return b } b.total++ b.args = append(b.args, a) switch { case b.postgres(): // PostgreSQL arguments are referenced using the syntax $n. // $1 refers to the 1st argument, $2 to the 2nd, and so on. b.WriteString("$" + strconv.Itoa(b.total)) default: b.WriteString("?") } return b } // Args appends a list of arguments to the builder. func (b *Builder) Args(a ...interface{}) *Builder { for i := range a { if i > 0 { b.Comma() } b.Arg(a[i]) } return b } // Comma adds a comma to the query. func (b *Builder) Comma() *Builder { b.WriteString(", ") return b } // Pad adds a space to the query. func (b *Builder) Pad() *Builder { b.WriteString(" ") return b } // Join joins a list of Queries to the builder. func (b *Builder) Join(qs ...Querier) *Builder { return b.join(qs, "") } // JoinComma joins a list of Queries and adds comma between them. func (b *Builder) JoinComma(qs ...Querier) *Builder { return b.join(qs, ", ") } // join joins a list of Queries to the builder with a given separator. 
func (b *Builder) join(qs []Querier, sep string) *Builder {
	for i, q := range qs {
		if i > 0 {
			b.WriteString(sep)
		}
		// Propagate dialect and the running argument counter into
		// sub-builders so placeholder numbering (e.g. $n) stays correct.
		st, ok := q.(state)
		if ok {
			st.SetDialect(b.dialect)
			st.SetTotal(b.total)
		}
		query, args := q.Query()
		b.WriteString(query)
		b.args = append(b.args, args...)
		b.total = len(b.args)
		if ok {
			// A stateful sub-query knows its own final total; trust it
			// over the plain args count.
			b.total = st.Total()
		}
	}
	return b
}

// Nested gets a callback, and wraps its result with parentheses.
func (b *Builder) Nested(f func(*Builder)) *Builder {
	// Build into a fresh builder that inherits dialect and argument
	// counter, then merge its output and args back into the receiver.
	nb := &Builder{dialect: b.dialect, total: b.total}
	nb.WriteByte('(')
	f(nb)
	nb.WriteByte(')')
	nb.WriteTo(b)
	b.args = append(b.args, nb.args...)
	b.total = nb.total
	return b
}

// SetDialect sets the builder dialect. It's used for garnering dialect specific queries.
func (b *Builder) SetDialect(dialect string) {
	b.dialect = dialect
}

// Dialect returns the dialect of the builder.
func (b Builder) Dialect() string {
	return b.dialect
}

// Total returns the total number of arguments so far.
func (b Builder) Total() int {
	return b.total
}

// SetTotal sets the value of the total arguments.
// Used to pass this information between sub queries/expressions.
func (b *Builder) SetTotal(total int) {
	b.total = total
}

// Query implements the Querier interface.
func (b Builder) Query() (string, []interface{}) {
	return b.String(), b.args
}

// clone returns a shallow clone of a builder.
func (b Builder) clone() Builder {
	c := Builder{dialect: b.dialect, total: b.total}
	if len(b.args) > 0 {
		// Copy the args slice so the clone does not share backing storage.
		c.args = append(c.args, b.args...)
	}
	c.Buffer.Write(b.Bytes())
	return c
}

// postgres reports if the builder dialect is PostgreSQL.
func (b Builder) postgres() bool {
	return b.Dialect() == dialect.Postgres
}

// fromIdent sets the builder dialect from the identifier format.
func (b *Builder) fromIdent(ident string) {
	// Double quotes imply PostgreSQL-style quoting.
	if strings.Contains(ident, `"`) {
		b.SetDialect(dialect.Postgres)
	}
	// otherwise, use the default.
}

// isIdent reports if the given string is a dialect identifier.
func (b *Builder) isIdent(s string) bool { switch { case b.postgres(): return strings.Contains(s, `"`) default: return strings.Contains(s, "`") } } // state wraps the all methods for setting and getting // update state between all queries in the query tree. type state interface { Dialect() string SetDialect(string) Total() int SetTotal(int) } // DialectBuilder prefixes all root builders with the `Dialect` constructor. type DialectBuilder struct { dialect string } // Dialect creates a new DialectBuilder with the given dialect name. func Dialect(name string) *DialectBuilder { return &DialectBuilder{name} } // Describe creates a DescribeBuilder for the configured dialect. // // Dialect(dialect.Postgres). // Describe("users") // func (d *DialectBuilder) Describe(name string) *DescribeBuilder { b := Describe(name) b.SetDialect(d.dialect) return b } // CreateTable creates a TableBuilder for the configured dialect. // // Dialect(dialect.Postgres). // CreateTable("users"). // Columns( // Column("id").Type("int").Attr("auto_increment"), // Column("name").Type("varchar(255)"), // ). // PrimaryKey("id") // func (d *DialectBuilder) CreateTable(name string) *TableBuilder { b := CreateTable(name) b.SetDialect(d.dialect) return b } // AlterTable creates a TableAlter for the configured dialect. // // Dialect(dialect.Postgres). // AlterTable("users"). // AddColumn(Column("group_id").Type("int").Attr("UNIQUE")). // AddForeignKey(ForeignKey().Columns("group_id"). // Reference(Reference().Table("groups").Columns("id")). // OnDelete("CASCADE"), // ) // func (d *DialectBuilder) AlterTable(name string) *TableAlter { b := AlterTable(name) b.SetDialect(d.dialect) return b } // AlterIndex creates an IndexAlter for the configured dialect. // // Dialect(dialect.Postgres). // AlterIndex("old"). // Rename("new") // func (d *DialectBuilder) AlterIndex(name string) *IndexAlter { b := AlterIndex(name) b.SetDialect(d.dialect) return b } // Column creates a ColumnBuilder for the configured dialect. 
// // Dialect(dialect.Postgres).. // Column("group_id").Type("int").Attr("UNIQUE") // func (d *DialectBuilder) Column(name string) *ColumnBuilder { b := Column(name) b.SetDialect(d.dialect) return b } // Insert creates a InsertBuilder for the configured dialect. // // Dialect(dialect.Postgres). // Insert("users").Columns("age").Values(1) // func (d *DialectBuilder) Insert(table string) *InsertBuilder { b := Insert(table) b.SetDialect(d.dialect) return b } // Update creates a UpdateBuilder for the configured dialect. // // Dialect(dialect.Postgres). // Update("users").Set("name", "foo") // func (d *DialectBuilder) Update(table string) *UpdateBuilder { b := Update(table) b.SetDialect(d.dialect) return b } // Delete creates a DeleteBuilder for the configured dialect. // // Dialect(dialect.Postgres). // Delete().From("users") // func (d *DialectBuilder) Delete(table string) *DeleteBuilder { b := Delete(table) b.SetDialect(d.dialect) return b } // Select creates a Selector for the configured dialect. // // Dialect(dialect.Postgres). // Select().From(Table("users")) // func (d *DialectBuilder) Select(columns ...string) *Selector { b := Select(columns...) b.SetDialect(d.dialect) return b } // Table creates a SelectTable for the configured dialect. // // Dialect(dialect.Postgres). // Table("users").As("u") // func (d *DialectBuilder) Table(name string) *SelectTable { b := Table(name) b.SetDialect(d.dialect) return b } // With creates a WithBuilder for the configured dialect. // // Dialect(dialect.Postgres). // With("users_view"). // As(Select().From(Table("users"))) // func (d *DialectBuilder) With(name string) *WithBuilder { b := With(name) b.SetDialect(d.dialect) return b } // CreateIndex creates a IndexBuilder for the configured dialect. // // Dialect(dialect.Postgres). // CreateIndex("unique_name"). // Unique(). // Table("users"). 
// Columns("first", "last") // func (d *DialectBuilder) CreateIndex(name string) *IndexBuilder { b := CreateIndex(name) b.SetDialect(d.dialect) return b } // DropIndex creates a DropIndexBuilder for the configured dialect. // // Dialect(dialect.Postgres). // DropIndex("name") // func (d *DialectBuilder) DropIndex(name string) *DropIndexBuilder { b := DropIndex(name) b.SetDialect(d.dialect) return b } func isFunc(s string) bool { return strings.Contains(s, "(") && strings.Contains(s, ")") } func isModifier(s string) bool { for _, m := range [...]string{"DISTINCT", "ALL", "WITH ROLLUP"} { if strings.HasPrefix(s, m) { return true } } return false } ent-0.5.4/dialect/sql/builder_test.go000066400000000000000000001346071377533537200175600ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package sql import ( "fmt" "strconv" "strings" "testing" "github.com/facebook/ent/dialect" "github.com/stretchr/testify/require" ) func TestBuilder(t *testing.T) { tests := []struct { input Querier wantQuery string wantArgs []interface{} }{ { input: Describe("users"), wantQuery: "DESCRIBE `users`", }, { input: CreateTable("users"). Columns( Column("id").Type("int").Attr("auto_increment"), Column("name").Type("varchar(255)"), ). PrimaryKey("id"), wantQuery: "CREATE TABLE `users`(`id` int auto_increment, `name` varchar(255), PRIMARY KEY(`id`))", }, { input: Dialect(dialect.Postgres).CreateTable("users"). Columns( Column("id").Type("serial").Attr("PRIMARY KEY"), Column("name").Type("varchar"), ), wantQuery: `CREATE TABLE "users"("id" serial PRIMARY KEY, "name" varchar)`, }, { input: CreateTable("users"). Columns( Column("id").Type("int").Attr("auto_increment"), Column("name").Type("varchar(255)"), ). PrimaryKey("id"). 
Charset("utf8mb4"), wantQuery: "CREATE TABLE `users`(`id` int auto_increment, `name` varchar(255), PRIMARY KEY(`id`)) CHARACTER SET utf8mb4", }, { input: CreateTable("users"). Columns( Column("id").Type("int").Attr("auto_increment"), Column("name").Type("varchar(255)"), ). PrimaryKey("id"). Charset("utf8mb4"). Collate("utf8mb4_general_ci"). Options("ENGINE=InnoDB"), wantQuery: "CREATE TABLE `users`(`id` int auto_increment, `name` varchar(255), PRIMARY KEY(`id`)) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci ENGINE=InnoDB", }, { input: CreateTable("users"). IfNotExists(). Columns( Column("id").Type("int").Attr("auto_increment"), ). PrimaryKey("id", "name"), wantQuery: "CREATE TABLE IF NOT EXISTS `users`(`id` int auto_increment, PRIMARY KEY(`id`, `name`))", }, { input: CreateTable("users"). IfNotExists(). Columns( Column("id").Type("int").Attr("auto_increment"), Column("card_id").Type("int"), Column("doc").Type("longtext").Check(func(b *Builder) { b.WriteString("JSON_VALID(").Ident("doc").WriteByte(')') }), ). PrimaryKey("id", "name"). ForeignKeys(ForeignKey().Columns("card_id"). Reference(Reference().Table("cards").Columns("id")).OnDelete("SET NULL")), wantQuery: "CREATE TABLE IF NOT EXISTS `users`(`id` int auto_increment, `card_id` int, `doc` longtext CHECK (JSON_VALID(`doc`)), PRIMARY KEY(`id`, `name`), FOREIGN KEY(`card_id`) REFERENCES `cards`(`id`) ON DELETE SET NULL)", }, { input: Dialect(dialect.Postgres).CreateTable("users"). IfNotExists(). Columns( Column("id").Type("serial"), Column("card_id").Type("int"), ). PrimaryKey("id", "name"). ForeignKeys(ForeignKey().Columns("card_id"). Reference(Reference().Table("cards").Columns("id")).OnDelete("SET NULL")), wantQuery: `CREATE TABLE IF NOT EXISTS "users"("id" serial, "card_id" int, PRIMARY KEY("id", "name"), FOREIGN KEY("card_id") REFERENCES "cards"("id") ON DELETE SET NULL)`, }, { input: AlterTable("users"). AddColumn(Column("group_id").Type("int").Attr("UNIQUE")). 
AddForeignKey(ForeignKey().Columns("group_id"). Reference(Reference().Table("groups").Columns("id")). OnDelete("CASCADE"), ), wantQuery: "ALTER TABLE `users` ADD COLUMN `group_id` int UNIQUE, ADD CONSTRAINT FOREIGN KEY(`group_id`) REFERENCES `groups`(`id`) ON DELETE CASCADE", }, { input: Dialect(dialect.Postgres).AlterTable("users"). AddColumn(Column("group_id").Type("int").Attr("UNIQUE")). AddForeignKey(ForeignKey("constraint").Columns("group_id"). Reference(Reference().Table("groups").Columns("id")). OnDelete("CASCADE"), ), wantQuery: `ALTER TABLE "users" ADD COLUMN "group_id" int UNIQUE, ADD CONSTRAINT "constraint" FOREIGN KEY("group_id") REFERENCES "groups"("id") ON DELETE CASCADE`, }, { input: AlterTable("users"). AddColumn(Column("group_id").Type("int").Attr("UNIQUE")). AddForeignKey(ForeignKey().Columns("group_id"). Reference(Reference().Table("groups").Columns("id")), ), wantQuery: "ALTER TABLE `users` ADD COLUMN `group_id` int UNIQUE, ADD CONSTRAINT FOREIGN KEY(`group_id`) REFERENCES `groups`(`id`)", }, { input: Dialect(dialect.Postgres).AlterTable("users"). AddColumn(Column("group_id").Type("int").Attr("UNIQUE")). AddForeignKey(ForeignKey().Columns("group_id"). Reference(Reference().Table("groups").Columns("id")), ), wantQuery: `ALTER TABLE "users" ADD COLUMN "group_id" int UNIQUE, ADD CONSTRAINT FOREIGN KEY("group_id") REFERENCES "groups"("id")`, }, { input: AlterTable("users"). AddColumn(Column("age").Type("int")). AddColumn(Column("name").Type("varchar(255)")), wantQuery: "ALTER TABLE `users` ADD COLUMN `age` int, ADD COLUMN `name` varchar(255)", }, { input: AlterTable("users"). DropForeignKey("users_parent_id"), wantQuery: "ALTER TABLE `users` DROP FOREIGN KEY `users_parent_id`", }, { input: Dialect(dialect.Postgres).AlterTable("users"). AddColumn(Column("age").Type("int")). AddColumn(Column("name").Type("varchar(255)")). 
DropConstraint("users_nickname_key"), wantQuery: `ALTER TABLE "users" ADD COLUMN "age" int, ADD COLUMN "name" varchar(255), DROP CONSTRAINT "users_nickname_key"`, }, { input: AlterTable("users"). AddForeignKey(ForeignKey().Columns("group_id"). Reference(Reference().Table("groups").Columns("id")), ). AddForeignKey(ForeignKey().Columns("location_id"). Reference(Reference().Table("locations").Columns("id")), ), wantQuery: "ALTER TABLE `users` ADD CONSTRAINT FOREIGN KEY(`group_id`) REFERENCES `groups`(`id`), ADD CONSTRAINT FOREIGN KEY(`location_id`) REFERENCES `locations`(`id`)", }, { input: AlterTable("users"). ModifyColumn(Column("age").Type("int")), wantQuery: "ALTER TABLE `users` MODIFY COLUMN `age` int", }, { input: Dialect(dialect.Postgres).AlterTable("users"). ModifyColumn(Column("age").Type("int")), wantQuery: `ALTER TABLE "users" ALTER COLUMN "age" TYPE int`, }, { input: AlterTable("users"). ModifyColumn(Column("age").Type("int")). DropColumn(Column("name")), wantQuery: "ALTER TABLE `users` MODIFY COLUMN `age` int, DROP COLUMN `name`", }, { input: Dialect(dialect.Postgres).AlterTable("users"). ModifyColumn(Column("age").Type("int")). DropColumn(Column("name")), wantQuery: `ALTER TABLE "users" ALTER COLUMN "age" TYPE int, DROP COLUMN "name"`, }, { input: Dialect(dialect.Postgres).AlterTable("users"). ModifyColumn(Column("age").Type("int")). ModifyColumn(Column("age").Attr("SET NOT NULL")). ModifyColumn(Column("name").Attr("DROP NOT NULL")), wantQuery: `ALTER TABLE "users" ALTER COLUMN "age" TYPE int, ALTER COLUMN "age" SET NOT NULL, ALTER COLUMN "name" DROP NOT NULL`, }, { input: AlterTable("users"). ChangeColumn("old_age", Column("age").Type("int")), wantQuery: "ALTER TABLE `users` CHANGE COLUMN `old_age` `age` int", }, { input: Dialect(dialect.Postgres).AlterTable("users"). AddColumn(Column("boring").Type("varchar")). ModifyColumn(Column("age").Type("int")). 
DropColumn(Column("name")), wantQuery: `ALTER TABLE "users" ADD COLUMN "boring" varchar, ALTER COLUMN "age" TYPE int, DROP COLUMN "name"`, }, { input: AlterTable("users").RenameIndex("old", "new"), wantQuery: "ALTER TABLE `users` RENAME INDEX `old` TO `new`", }, { input: AlterTable("users"). DropIndex("old"). AddIndex(CreateIndex("new1").Columns("c1", "c2")). AddIndex(CreateIndex("new2").Columns("c1", "c2").Unique()), wantQuery: "ALTER TABLE `users` DROP INDEX `old`, ADD INDEX `new1`(`c1`, `c2`), ADD UNIQUE INDEX `new2`(`c1`, `c2`)", }, { input: Dialect(dialect.Postgres).AlterIndex("old"). Rename("new"), wantQuery: `ALTER INDEX "old" RENAME TO "new"`, }, { input: Insert("users").Columns("age").Values(1), wantQuery: "INSERT INTO `users` (`age`) VALUES (?)", wantArgs: []interface{}{1}, }, { input: Insert("users").Columns("age").Values(1).Schema("mydb"), wantQuery: "INSERT INTO `mydb`.`users` (`age`) VALUES (?)", wantArgs: []interface{}{1}, }, { input: Dialect(dialect.Postgres).Insert("users").Columns("age").Values(1), wantQuery: `INSERT INTO "users" ("age") VALUES ($1)`, wantArgs: []interface{}{1}, }, { input: Dialect(dialect.Postgres).Insert("users").Columns("age").Values(1).Schema("mydb"), wantQuery: `INSERT INTO "mydb"."users" ("age") VALUES ($1)`, wantArgs: []interface{}{1}, }, { input: Dialect(dialect.SQLite).Insert("users").Columns("age").Values(1).Schema("mydb"), wantQuery: "INSERT INTO `users` (`age`) VALUES (?)", wantArgs: []interface{}{1}, }, { input: Dialect(dialect.Postgres).Insert("users").Columns("age").Values(1).Returning("id"), wantQuery: `INSERT INTO "users" ("age") VALUES ($1) RETURNING "id"`, wantArgs: []interface{}{1}, }, { input: Dialect(dialect.Postgres).Insert("users").Columns("age").Values(1).Returning("id").Returning("name"), wantQuery: `INSERT INTO "users" ("age") VALUES ($1) RETURNING "name"`, wantArgs: []interface{}{1}, }, { input: Insert("users").Columns("name", "age").Values("a8m", 10), wantQuery: "INSERT INTO `users` (`name`, `age`) 
VALUES (?, ?)", wantArgs: []interface{}{"a8m", 10}, }, { input: Dialect(dialect.Postgres).Insert("users").Columns("name", "age").Values("a8m", 10), wantQuery: `INSERT INTO "users" ("name", "age") VALUES ($1, $2)`, wantArgs: []interface{}{"a8m", 10}, }, { input: Insert("users").Columns("name", "age").Values("a8m", 10).Values("foo", 20), wantQuery: "INSERT INTO `users` (`name`, `age`) VALUES (?, ?), (?, ?)", wantArgs: []interface{}{"a8m", 10, "foo", 20}, }, { input: Dialect(dialect.Postgres).Insert("users").Columns("name", "age").Values("a8m", 10).Values("foo", 20), wantQuery: `INSERT INTO "users" ("name", "age") VALUES ($1, $2), ($3, $4)`, wantArgs: []interface{}{"a8m", 10, "foo", 20}, }, { input: Dialect(dialect.Postgres).Insert("users"). Columns("name", "age"). Values("a8m", 10). Values("foo", 20). Values("bar", 30), wantQuery: `INSERT INTO "users" ("name", "age") VALUES ($1, $2), ($3, $4), ($5, $6)`, wantArgs: []interface{}{"a8m", 10, "foo", 20, "bar", 30}, }, { input: Update("users").Set("name", "foo"), wantQuery: "UPDATE `users` SET `name` = ?", wantArgs: []interface{}{"foo"}, }, { input: Update("users").Set("name", "foo").Schema("mydb"), wantQuery: "UPDATE `mydb`.`users` SET `name` = ?", wantArgs: []interface{}{"foo"}, }, { input: Dialect(dialect.Postgres).Update("users").Set("name", "foo"), wantQuery: `UPDATE "users" SET "name" = $1`, wantArgs: []interface{}{"foo"}, }, { input: Dialect(dialect.Postgres).Update("users").Set("name", "foo").Schema("mydb"), wantQuery: `UPDATE "mydb"."users" SET "name" = $1`, wantArgs: []interface{}{"foo"}, }, { input: Dialect(dialect.SQLite).Update("users").Set("name", "foo").Schema("mydb"), wantQuery: "UPDATE `users` SET `name` = ?", wantArgs: []interface{}{"foo"}, }, { input: Update("users").Set("name", "foo").Set("age", 10), wantQuery: "UPDATE `users` SET `name` = ?, `age` = ?", wantArgs: []interface{}{"foo", 10}, }, { input: Dialect(dialect.Postgres).Update("users").Set("name", "foo").Set("age", 10), wantQuery: `UPDATE 
"users" SET "name" = $1, "age" = $2`, wantArgs: []interface{}{"foo", 10}, }, { input: Update("users").Set("name", "foo").Where(EQ("name", "bar")), wantQuery: "UPDATE `users` SET `name` = ? WHERE `name` = ?", wantArgs: []interface{}{"foo", "bar"}, }, { input: Dialect(dialect.Postgres).Update("users").Set("name", "foo").Where(EQ("name", "bar")), wantQuery: `UPDATE "users" SET "name" = $1 WHERE "name" = $2`, wantArgs: []interface{}{"foo", "bar"}, }, { input: func() Querier { p1, p2 := EQ("name", "bar"), Or(EQ("age", 10), EQ("age", 20)) return Dialect(dialect.Postgres). Update("users"). Set("name", "foo"). Where(p1). Where(p2). Where(p1). Where(p2) }(), wantQuery: `UPDATE "users" SET "name" = $1 WHERE (("name" = $2 AND ("age" = $3 OR "age" = $4)) AND "name" = $5) AND ("age" = $6 OR "age" = $7)`, wantArgs: []interface{}{"foo", "bar", 10, 20, "bar", 10, 20}, }, { input: Update("users").Set("name", "foo").SetNull("spouse_id"), wantQuery: "UPDATE `users` SET `spouse_id` = NULL, `name` = ?", wantArgs: []interface{}{"foo"}, }, { input: Dialect(dialect.Postgres).Update("users").Set("name", "foo").SetNull("spouse_id"), wantQuery: `UPDATE "users" SET "spouse_id" = NULL, "name" = $1`, wantArgs: []interface{}{"foo"}, }, { input: Update("users").Set("name", "foo"). Where(EQ("name", "bar")). Where(EQ("age", 20)), wantQuery: "UPDATE `users` SET `name` = ? WHERE `name` = ? AND `age` = ?", wantArgs: []interface{}{"foo", "bar", 20}, }, { input: Dialect(dialect.Postgres). Update("users"). Set("name", "foo"). Where(EQ("name", "bar")). Where(EQ("age", 20)), wantQuery: `UPDATE "users" SET "name" = $1 WHERE "name" = $2 AND "age" = $3`, wantArgs: []interface{}{"foo", "bar", 20}, }, { input: Update("users"). Set("name", "foo"). Set("age", 10). Where(Or(EQ("name", "bar"), EQ("name", "baz"))), wantQuery: "UPDATE `users` SET `name` = ?, `age` = ? WHERE `name` = ? OR `name` = ?", wantArgs: []interface{}{"foo", 10, "bar", "baz"}, }, { input: Dialect(dialect.Postgres). Update("users"). 
Set("name", "foo"). Set("age", 10). Where(Or(EQ("name", "bar"), EQ("name", "baz"))), wantQuery: `UPDATE "users" SET "name" = $1, "age" = $2 WHERE "name" = $3 OR "name" = $4`, wantArgs: []interface{}{"foo", 10, "bar", "baz"}, }, { input: Update("users"). Set("name", "foo"). Set("age", 10). Where(P().EQ("name", "foo")), wantQuery: "UPDATE `users` SET `name` = ?, `age` = ? WHERE `name` = ?", wantArgs: []interface{}{"foo", 10, "foo"}, }, { input: Dialect(dialect.Postgres). Update("users"). Set("name", "foo"). Set("age", 10). Where(P().EQ("name", "foo")), wantQuery: `UPDATE "users" SET "name" = $1, "age" = $2 WHERE "name" = $3`, wantArgs: []interface{}{"foo", 10, "foo"}, }, { input: Update("users"). Set("name", "foo"). Where(And(In("name", "bar", "baz"), NotIn("age", 1, 2))), wantQuery: "UPDATE `users` SET `name` = ? WHERE `name` IN (?, ?) AND `age` NOT IN (?, ?)", wantArgs: []interface{}{"foo", "bar", "baz", 1, 2}, }, { input: Dialect(dialect.Postgres). Update("users"). Set("name", "foo"). Where(And(In("name", "bar", "baz"), NotIn("age", 1, 2))), wantQuery: `UPDATE "users" SET "name" = $1 WHERE "name" IN ($2, $3) AND "age" NOT IN ($4, $5)`, wantArgs: []interface{}{"foo", "bar", "baz", 1, 2}, }, { input: Update("users"). Set("name", "foo"). Where(And(HasPrefix("nickname", "a8m"), Contains("lastname", "mash"))), wantQuery: "UPDATE `users` SET `name` = ? WHERE `nickname` LIKE ? AND `lastname` LIKE ?", wantArgs: []interface{}{"foo", "a8m%", "%mash%"}, }, { input: Dialect(dialect.Postgres). Update("users"). Set("name", "foo"). Where(And(HasPrefix("nickname", "a8m"), Contains("lastname", "mash"))), wantQuery: `UPDATE "users" SET "name" = $1 WHERE "nickname" LIKE $2 AND "lastname" LIKE $3`, wantArgs: []interface{}{"foo", "a8m%", "%mash%"}, }, { input: Update("users"). Add("age", 1). Where(HasPrefix("nickname", "a8m")), wantQuery: "UPDATE `users` SET `age` = COALESCE(`age`, ?) + ? 
WHERE `nickname` LIKE ?", wantArgs: []interface{}{0, 1, "a8m%"}, }, { input: Dialect(dialect.Postgres). Update("users"). Add("age", 1). Where(HasPrefix("nickname", "a8m")), wantQuery: `UPDATE "users" SET "age" = COALESCE("age", $1) + $2 WHERE "nickname" LIKE $3`, wantArgs: []interface{}{0, 1, "a8m%"}, }, { input: Update("users"). Add("age", 1). Set("nickname", "a8m"). Add("version", 10). Set("name", "mashraki"), wantQuery: "UPDATE `users` SET `age` = COALESCE(`age`, ?) + ?, `nickname` = ?, `version` = COALESCE(`version`, ?) + ?, `name` = ?", wantArgs: []interface{}{0, 1, "a8m", 0, 10, "mashraki"}, }, { input: Dialect(dialect.Postgres). Update("users"). Add("age", 1). Set("nickname", "a8m"). Add("version", 10). Set("name", "mashraki"), wantQuery: `UPDATE "users" SET "age" = COALESCE("age", $1) + $2, "nickname" = $3, "version" = COALESCE("version", $4) + $5, "name" = $6`, wantArgs: []interface{}{0, 1, "a8m", 0, 10, "mashraki"}, }, { input: Dialect(dialect.Postgres). Update("users"). Add("age", 1). Set("nickname", "a8m"). Add("version", 10). Set("name", "mashraki"). Set("first", "ariel"). Add("score", 1e5). Where(Or(EQ("age", 1), EQ("age", 2))), wantQuery: `UPDATE "users" SET "age" = COALESCE("age", $1) + $2, "nickname" = $3, "version" = COALESCE("version", $4) + $5, "name" = $6, "first" = $7, "score" = COALESCE("score", $8) + $9 WHERE "age" = $10 OR "age" = $11`, wantArgs: []interface{}{0, 1, "a8m", 0, 10, "mashraki", "ariel", 0, 1e5, 1, 2}, }, { input: Select(). From(Table("users")). Where(EQ("name", "Alex")), wantQuery: "SELECT * FROM `users` WHERE `name` = ?", wantArgs: []interface{}{"Alex"}, }, { input: Dialect(dialect.Postgres). Select(). From(Table("users")), wantQuery: `SELECT * FROM "users"`, }, { input: Dialect(dialect.Postgres). Select(). From(Table("users")). Where(EQ("name", "Ariel")), wantQuery: `SELECT * FROM "users" WHERE "name" = $1`, wantArgs: []interface{}{"Ariel"}, }, { input: Select(). From(Table("users")). 
Where(Or(EQ("name", "BAR"), EQ("name", "BAZ"))), wantQuery: "SELECT * FROM `users` WHERE `name` = ? OR `name` = ?", wantArgs: []interface{}{"BAR", "BAZ"}, }, { input: Update("users"). Set("name", "foo"). Set("age", 10). Where(And(EQ("name", "foo"), EQ("age", 20))), wantQuery: "UPDATE `users` SET `name` = ?, `age` = ? WHERE `name` = ? AND `age` = ?", wantArgs: []interface{}{"foo", 10, "foo", 20}, }, { input: Delete("users"). Where(NotNull("parent_id")), wantQuery: "DELETE FROM `users` WHERE `parent_id` IS NOT NULL", }, { input: Delete("users"). Where(NotNull("parent_id")). Schema("mydb"), wantQuery: "DELETE FROM `mydb`.`users` WHERE `parent_id` IS NOT NULL", }, { input: Dialect(dialect.SQLite). Delete("users"). Where(NotNull("parent_id")). Schema("mydb"), wantQuery: "DELETE FROM `users` WHERE `parent_id` IS NOT NULL", }, { input: Dialect(dialect.Postgres). Delete("users"). Where(IsNull("parent_id")), wantQuery: `DELETE FROM "users" WHERE "parent_id" IS NULL`, }, { input: Dialect(dialect.Postgres). Delete("users"). Where(IsNull("parent_id")). Schema("mydb"), wantQuery: `DELETE FROM "mydb"."users" WHERE "parent_id" IS NULL`, }, { input: Delete("users"). Where(And(IsNull("parent_id"), NotIn("name", "foo", "bar"))), wantQuery: "DELETE FROM `users` WHERE `parent_id` IS NULL AND `name` NOT IN (?, ?)", wantArgs: []interface{}{"foo", "bar"}, }, { input: Dialect(dialect.Postgres). Delete("users"). Where(And(IsNull("parent_id"), NotIn("name", "foo", "bar"))), wantQuery: `DELETE FROM "users" WHERE "parent_id" IS NULL AND "name" NOT IN ($1, $2)`, wantArgs: []interface{}{"foo", "bar"}, }, { input: Delete("users"). Where(And(False(), False())), wantQuery: "DELETE FROM `users` WHERE FALSE AND FALSE", }, { input: Dialect(dialect.Postgres). Delete("users"). Where(And(False(), False())), wantQuery: `DELETE FROM "users" WHERE FALSE AND FALSE`, }, { input: Delete("users"). 
Where(Or(NotNull("parent_id"), EQ("parent_id", 10))), wantQuery: "DELETE FROM `users` WHERE `parent_id` IS NOT NULL OR `parent_id` = ?", wantArgs: []interface{}{10}, }, { input: Dialect(dialect.Postgres). Delete("users"). Where(Or(NotNull("parent_id"), EQ("parent_id", 10))), wantQuery: `DELETE FROM "users" WHERE "parent_id" IS NOT NULL OR "parent_id" = $1`, wantArgs: []interface{}{10}, }, { input: Delete("users"). Where( Or( And(EQ("name", "foo"), EQ("age", 10)), And(EQ("name", "bar"), EQ("age", 20)), And( EQ("name", "qux"), Or(EQ("age", 1), EQ("age", 2)), ), ), ), wantQuery: "DELETE FROM `users` WHERE (`name` = ? AND `age` = ?) OR (`name` = ? AND `age` = ?) OR (`name` = ? AND (`age` = ? OR `age` = ?))", wantArgs: []interface{}{"foo", 10, "bar", 20, "qux", 1, 2}, }, { input: Dialect(dialect.Postgres). Delete("users"). Where( Or( And(EQ("name", "foo"), EQ("age", 10)), And(EQ("name", "bar"), EQ("age", 20)), And( EQ("name", "qux"), Or(EQ("age", 1), EQ("age", 2)), ), ), ), wantQuery: `DELETE FROM "users" WHERE ("name" = $1 AND "age" = $2) OR ("name" = $3 AND "age" = $4) OR ("name" = $5 AND ("age" = $6 OR "age" = $7))`, wantArgs: []interface{}{"foo", 10, "bar", 20, "qux", 1, 2}, }, { input: Delete("users"). Where( Or( And(EQ("name", "foo"), EQ("age", 10)), And(EQ("name", "bar"), EQ("age", 20)), ), ). Where(EQ("role", "admin")), wantQuery: "DELETE FROM `users` WHERE ((`name` = ? AND `age` = ?) OR (`name` = ? AND `age` = ?)) AND `role` = ?", wantArgs: []interface{}{"foo", 10, "bar", 20, "admin"}, }, { input: Dialect(dialect.Postgres). Delete("users"). Where( Or( And(EQ("name", "foo"), EQ("age", 10)), And(EQ("name", "bar"), EQ("age", 20)), ), ). 
Where(EQ("role", "admin")), wantQuery: `DELETE FROM "users" WHERE (("name" = $1 AND "age" = $2) OR ("name" = $3 AND "age" = $4)) AND "role" = $5`, wantArgs: []interface{}{"foo", 10, "bar", 20, "admin"}, }, { input: Select().From(Table("users")), wantQuery: "SELECT * FROM `users`", }, { input: Dialect(dialect.Postgres).Select().From(Table("users")), wantQuery: `SELECT * FROM "users"`, }, { input: Select().From(Table("users").Unquote()), wantQuery: "SELECT * FROM users", }, { input: Dialect(dialect.Postgres).Select().From(Table("users").Unquote()), wantQuery: "SELECT * FROM users", }, { input: Select().From(Table("users").As("u")), wantQuery: "SELECT * FROM `users` AS `u`", }, { input: Dialect(dialect.Postgres).Select().From(Table("users").As("u")), wantQuery: `SELECT * FROM "users" AS "u"`, }, { input: func() Querier { t1 := Table("users").As("u") t2 := Table("groups").As("g") return Select(t1.C("id"), t2.C("name")).From(t1).Join(t2) }(), wantQuery: "SELECT `u`.`id`, `g`.`name` FROM `users` AS `u` JOIN `groups` AS `g`", }, { input: func() Querier { t1 := Table("users").As("u") t2 := Table("groups").As("g") return Dialect(dialect.Postgres).Select(t1.C("id"), t2.C("name")).From(t1).Join(t2) }(), wantQuery: `SELECT "u"."id", "g"."name" FROM "users" AS "u" JOIN "groups" AS "g"`, }, { input: func() Querier { t1 := Table("users").As("u") t2 := Table("groups").As("g") return Select(t1.C("id"), t2.C("name")). From(t1). Join(t2). On(t1.C("id"), t2.C("user_id")) }(), wantQuery: "SELECT `u`.`id`, `g`.`name` FROM `users` AS `u` JOIN `groups` AS `g` ON `u`.`id` = `g`.`user_id`", }, { input: func() Querier { t1 := Table("users").As("u") t2 := Table("groups").As("g") return Dialect(dialect.Postgres). Select(t1.C("id"), t2.C("name")). From(t1). Join(t2). 
On(t1.C("id"), t2.C("user_id")) }(), wantQuery: `SELECT "u"."id", "g"."name" FROM "users" AS "u" JOIN "groups" AS "g" ON "u"."id" = "g"."user_id"`, }, { input: func() Querier { t1 := Table("users").As("u") t2 := Table("groups").As("g") return Select(t1.C("id"), t2.C("name")). From(t1). Join(t2). On(t1.C("id"), t2.C("user_id")). Where(And(EQ(t1.C("name"), "bar"), NotNull(t2.C("name")))) }(), wantQuery: "SELECT `u`.`id`, `g`.`name` FROM `users` AS `u` JOIN `groups` AS `g` ON `u`.`id` = `g`.`user_id` WHERE `u`.`name` = ? AND `g`.`name` IS NOT NULL", wantArgs: []interface{}{"bar"}, }, { input: func() Querier { t1 := Table("users").As("u") t2 := Table("groups").As("g") return Dialect(dialect.Postgres). Select(t1.C("id"), t2.C("name")). From(t1). Join(t2). On(t1.C("id"), t2.C("user_id")). Where(And(EQ(t1.C("name"), "bar"), NotNull(t2.C("name")))) }(), wantQuery: `SELECT "u"."id", "g"."name" FROM "users" AS "u" JOIN "groups" AS "g" ON "u"."id" = "g"."user_id" WHERE "u"."name" = $1 AND "g"."name" IS NOT NULL`, wantArgs: []interface{}{"bar"}, }, { input: func() Querier { t1 := Table("users").As("u") t2 := Table("user_groups").As("ug") return Select(t1.C("id"), As(Count("`*`"), "group_count")). From(t1). LeftJoin(t2). On(t1.C("id"), t2.C("user_id")). GroupBy(t1.C("id")) }(), wantQuery: "SELECT `u`.`id`, COUNT(`*`) AS `group_count` FROM `users` AS `u` LEFT JOIN `user_groups` AS `ug` ON `u`.`id` = `ug`.`user_id` GROUP BY `u`.`id`", }, { input: func() Querier { t1 := Table("users").As("u") t2 := Table("user_groups").As("ug") return Select(t1.C("id"), As(Count("`*`"), "group_count")). From(t1). LeftJoin(t2). OnP(P(func(b *Builder) { b.Ident(t1.C("id")).WriteOp(OpEQ).Ident(t2.C("user_id")) })). 
GroupBy(t1.C("id")).Clone() }(), wantQuery: "SELECT `u`.`id`, COUNT(`*`) AS `group_count` FROM `users` AS `u` LEFT JOIN `user_groups` AS `ug` ON `u`.`id` = `ug`.`user_id` GROUP BY `u`.`id`", }, { input: func() Querier { t1 := Table("groups").As("g") t2 := Table("user_groups").As("ug") return Select(t1.C("id"), As(Count("`*`"), "user_count")). From(t1). RightJoin(t2). On(t1.C("id"), t2.C("group_id")). GroupBy(t1.C("id")) }(), wantQuery: "SELECT `g`.`id`, COUNT(`*`) AS `user_count` FROM `groups` AS `g` RIGHT JOIN `user_groups` AS `ug` ON `g`.`id` = `ug`.`group_id` GROUP BY `g`.`id`", }, { input: func() Querier { t1 := Table("users").As("u") return Select(t1.Columns("name", "age")...).From(t1) }(), wantQuery: "SELECT `u`.`name`, `u`.`age` FROM `users` AS `u`", }, { input: func() Querier { t1 := Table("users").As("u") return Dialect(dialect.Postgres). Select(t1.Columns("name", "age")...).From(t1) }(), wantQuery: `SELECT "u"."name", "u"."age" FROM "users" AS "u"`, }, { input: func() Querier { t1 := Dialect(dialect.Postgres). Table("users").As("u") return Dialect(dialect.Postgres). Select(t1.Columns("name", "age")...).From(t1) }(), wantQuery: `SELECT "u"."name", "u"."age" FROM "users" AS "u"`, }, { input: func() Querier { t1 := Table("users").As("u") t2 := Select().From(Table("groups")).Where(EQ("user_id", 10)).As("g") return Select(t1.C("id"), t2.C("name")). From(t1). Join(t2). On(t1.C("id"), t2.C("user_id")) }(), wantQuery: "SELECT `u`.`id`, `g`.`name` FROM `users` AS `u` JOIN (SELECT * FROM `groups` WHERE `user_id` = ?) AS `g` ON `u`.`id` = `g`.`user_id`", wantArgs: []interface{}{10}, }, { input: func() Querier { d := Dialect(dialect.Postgres) t1 := d.Table("users").As("u") t2 := d.Select().From(Table("groups")).Where(EQ("user_id", 10)).As("g") return d.Select(t1.C("id"), t2.C("name")). From(t1). Join(t2). 
On(t1.C("id"), t2.C("user_id")) }(), wantQuery: `SELECT "u"."id", "g"."name" FROM "users" AS "u" JOIN (SELECT * FROM "groups" WHERE "user_id" = $1) AS "g" ON "u"."id" = "g"."user_id"`, wantArgs: []interface{}{10}, }, { input: func() Querier { selector := Select().Where(Or(EQ("name", "foo"), EQ("name", "bar"))) return Delete("users").FromSelect(selector) }(), wantQuery: "DELETE FROM `users` WHERE `name` = ? OR `name` = ?", wantArgs: []interface{}{"foo", "bar"}, }, { input: func() Querier { d := Dialect(dialect.Postgres) selector := d.Select().Where(Or(EQ("name", "foo"), EQ("name", "bar"))) return d.Delete("users").FromSelect(selector) }(), wantQuery: `DELETE FROM "users" WHERE "name" = $1 OR "name" = $2`, wantArgs: []interface{}{"foo", "bar"}, }, { input: func() Querier { selector := Select().From(Table("users")).As("t") return selector.Select(selector.C("name")) }(), wantQuery: "SELECT `t`.`name` FROM `users`", }, { input: func() Querier { selector := Dialect(dialect.Postgres). Select().From(Table("users")).As("t") return selector.Select(selector.C("name")) }(), wantQuery: `SELECT "t"."name" FROM "users"`, }, { input: func() Querier { selector := Select().From(Table("groups")).Where(EQ("name", "foo")) return Delete("users").FromSelect(selector) }(), wantQuery: "DELETE FROM `groups` WHERE `name` = ?", wantArgs: []interface{}{"foo"}, }, { input: func() Querier { d := Dialect(dialect.Postgres) selector := d.Select().From(Table("groups")).Where(EQ("name", "foo")) return d.Delete("users").FromSelect(selector) }(), wantQuery: `DELETE FROM "groups" WHERE "name" = $1`, wantArgs: []interface{}{"foo"}, }, { input: func() Querier { selector := Select() return Delete("users").FromSelect(selector) }(), wantQuery: "DELETE FROM `users`", }, { input: func() Querier { d := Dialect(dialect.Postgres) selector := d.Select() return d.Delete("users").FromSelect(selector) }(), wantQuery: `DELETE FROM "users"`, }, { input: Select(). From(Table("users")). 
Where(Not(And(EQ("name", "foo"), EQ("age", "bar")))), wantQuery: "SELECT * FROM `users` WHERE NOT (`name` = ? AND `age` = ?)", wantArgs: []interface{}{"foo", "bar"}, }, { input: Dialect(dialect.Postgres). Select(). From(Table("users")). Where(Not(And(EQ("name", "foo"), EQ("age", "bar")))), wantQuery: `SELECT * FROM "users" WHERE NOT ("name" = $1 AND "age" = $2)`, wantArgs: []interface{}{"foo", "bar"}, }, { input: Select(). From(Table("users")). Where(Or(EqualFold("name", "BAR"), EqualFold("name", "BAZ"))), wantQuery: "SELECT * FROM `users` WHERE LOWER(`name`) = ? OR LOWER(`name`) = ?", wantArgs: []interface{}{"bar", "baz"}, }, { input: Dialect(dialect.Postgres). Select(). From(Table("users")). Where(Or(EqualFold("name", "BAR"), EqualFold("name", "BAZ"))), wantQuery: `SELECT * FROM "users" WHERE LOWER("name") = $1 OR LOWER("name") = $2`, wantArgs: []interface{}{"bar", "baz"}, }, { input: Dialect(dialect.SQLite). Select(). From(Table("users")). Where(And(ContainsFold("name", "Ariel"), ContainsFold("nick", "Bar"))), wantQuery: "SELECT * FROM `users` WHERE LOWER(`name`) LIKE ? AND LOWER(`nick`) LIKE ?", wantArgs: []interface{}{"%ariel%", "%bar%"}, }, { input: Dialect(dialect.Postgres). Select(). From(Table("users")). Where(And(ContainsFold("name", "Ariel"), ContainsFold("nick", "Bar"))), wantQuery: `SELECT * FROM "users" WHERE "name" ILIKE $1 AND "nick" ILIKE $2`, wantArgs: []interface{}{"%ariel%", "%bar%"}, }, { input: Dialect(dialect.MySQL). Select(). From(Table("users")). Where(And(ContainsFold("name", "Ariel"), ContainsFold("nick", "Bar"))), wantQuery: "SELECT * FROM `users` WHERE `name` COLLATE utf8mb4_general_ci LIKE ? AND `nick` COLLATE utf8mb4_general_ci LIKE ?", wantArgs: []interface{}{"%ariel%", "%bar%"}, }, { input: func() Querier { s1 := Select(). From(Table("users")). 
Where(Not(And(EQ("name", "foo"), EQ("age", "bar")))) return Queries{With("users_view").As(s1), Select("name").From(Table("users_view"))} }(), wantQuery: "WITH users_view AS (SELECT * FROM `users` WHERE NOT (`name` = ? AND `age` = ?)) SELECT `name` FROM `users_view`", wantArgs: []interface{}{"foo", "bar"}, }, { input: func() Querier { d := Dialect(dialect.Postgres) s1 := d.Select(). From(Table("users")). Where(Not(And(EQ("name", "foo"), EQ("age", "bar")))) return Queries{d.With("users_view").As(s1), d.Select("name").From(Table("users_view"))} }(), wantQuery: `WITH users_view AS (SELECT * FROM "users" WHERE NOT ("name" = $1 AND "age" = $2)) SELECT "name" FROM "users_view"`, wantArgs: []interface{}{"foo", "bar"}, }, { input: func() Querier { s1 := Select().From(Table("users")).Where(Not(And(EQ("name", "foo"), EQ("age", "bar")))).As("users_view") return Select("name").From(s1) }(), wantQuery: "SELECT `name` FROM (SELECT * FROM `users` WHERE NOT (`name` = ? AND `age` = ?)) AS `users_view`", wantArgs: []interface{}{"foo", "bar"}, }, { input: func() Querier { d := Dialect(dialect.Postgres) s1 := d.Select().From(Table("users")).Where(Not(And(EQ("name", "foo"), EQ("age", "bar")))).As("users_view") return d.Select("name").From(s1) }(), wantQuery: `SELECT "name" FROM (SELECT * FROM "users" WHERE NOT ("name" = $1 AND "age" = $2)) AS "users_view"`, wantArgs: []interface{}{"foo", "bar"}, }, { input: func() Querier { t1 := Table("users") return Select(). From(t1). Where(In(t1.C("id"), Select("owner_id").From(Table("pets")).Where(EQ("name", "pedro")))) }(), wantQuery: "SELECT * FROM `users` WHERE `users`.`id` IN (SELECT `owner_id` FROM `pets` WHERE `name` = ?)", wantArgs: []interface{}{"pedro"}, }, { input: func() Querier { t1 := Table("users") return Dialect(dialect.Postgres). Select(). From(t1). 
Where(In(t1.C("id"), Select("owner_id").From(Table("pets")).Where(EQ("name", "pedro")))) }(), wantQuery: `SELECT * FROM "users" WHERE "users"."id" IN (SELECT "owner_id" FROM "pets" WHERE "name" = $1)`, wantArgs: []interface{}{"pedro"}, }, { input: func() Querier { t1 := Table("users") return Select(). From(t1). Where(Not(In(t1.C("id"), Select("owner_id").From(Table("pets")).Where(EQ("name", "pedro"))))) }(), wantQuery: "SELECT * FROM `users` WHERE NOT (`users`.`id` IN (SELECT `owner_id` FROM `pets` WHERE `name` = ?))", wantArgs: []interface{}{"pedro"}, }, { input: func() Querier { t1 := Table("users") return Dialect(dialect.Postgres). Select(). From(t1). Where(Not(In(t1.C("id"), Select("owner_id").From(Table("pets")).Where(EQ("name", "pedro"))))) }(), wantQuery: `SELECT * FROM "users" WHERE NOT ("users"."id" IN (SELECT "owner_id" FROM "pets" WHERE "name" = $1))`, wantArgs: []interface{}{"pedro"}, }, { input: Select().Count().From(Table("users")), wantQuery: "SELECT COUNT(*) FROM `users`", }, { input: Dialect(dialect.Postgres). Select().Count().From(Table("users")), wantQuery: `SELECT COUNT(*) FROM "users"`, }, { input: Select().Count(Distinct("id")).From(Table("users")), wantQuery: "SELECT COUNT(DISTINCT `id`) FROM `users`", }, { input: Dialect(dialect.Postgres). 
Select().Count(Distinct("id")).From(Table("users")), wantQuery: `SELECT COUNT(DISTINCT "id") FROM "users"`, }, { input: func() Querier { t1 := Table("users") t2 := Select().From(Table("groups")) t3 := Select().Count().From(t1).Join(t1).On(t2.C("id"), t1.C("blocked_id")) return t3.Count(Distinct(t3.Columns("id", "name")...)) }(), wantQuery: "SELECT COUNT(DISTINCT `t0`.`id`, `t0`.`name`) FROM `users` AS `t0` JOIN `users` AS `t0` ON `groups`.`id` = `t0`.`blocked_id`", }, { input: func() Querier { d := Dialect(dialect.Postgres) t1 := d.Table("users") t2 := d.Select().From(Table("groups")) t3 := d.Select().Count().From(t1).Join(t1).On(t2.C("id"), t1.C("blocked_id")) return t3.Count(Distinct(t3.Columns("id", "name")...)) }(), wantQuery: `SELECT COUNT(DISTINCT "t0"."id", "t0"."name") FROM "users" AS "t0" JOIN "users" AS "t0" ON "groups"."id" = "t0"."blocked_id"`, }, { input: Select(Sum("age"), Min("age")).From(Table("users")), wantQuery: "SELECT SUM(`age`), MIN(`age`) FROM `users`", }, { input: Dialect(dialect.Postgres). Select(Sum("age"), Min("age")). From(Table("users")), wantQuery: `SELECT SUM("age"), MIN("age") FROM "users"`, }, { input: func() Querier { t1 := Table("users").As("u") return Select(As(Max(t1.C("age")), "max_age")).From(t1) }(), wantQuery: "SELECT MAX(`u`.`age`) AS `max_age` FROM `users` AS `u`", }, { input: func() Querier { t1 := Table("users").As("u") return Dialect(dialect.Postgres). Select(As(Max(t1.C("age")), "max_age")). From(t1) }(), wantQuery: `SELECT MAX("u"."age") AS "max_age" FROM "users" AS "u"`, }, { input: Select("name", Count("*")). From(Table("users")). GroupBy("name"), wantQuery: "SELECT `name`, COUNT(*) FROM `users` GROUP BY `name`", }, { input: Dialect(dialect.Postgres). Select("name", Count("*")). From(Table("users")). GroupBy("name"), wantQuery: `SELECT "name", COUNT(*) FROM "users" GROUP BY "name"`, }, { input: Select("name", Count("*")). From(Table("users")). GroupBy("name"). 
OrderBy("name"), wantQuery: "SELECT `name`, COUNT(*) FROM `users` GROUP BY `name` ORDER BY `name`", }, { input: Dialect(dialect.Postgres). Select("name", Count("*")). From(Table("users")). GroupBy("name"). OrderBy("name"), wantQuery: `SELECT "name", COUNT(*) FROM "users" GROUP BY "name" ORDER BY "name"`, }, { input: Select("name", "age", Count("*")). From(Table("users")). GroupBy("name", "age"). OrderBy(Desc("name"), "age"), wantQuery: "SELECT `name`, `age`, COUNT(*) FROM `users` GROUP BY `name`, `age` ORDER BY `name` DESC, `age`", }, { input: Dialect(dialect.Postgres). Select("name", "age", Count("*")). From(Table("users")). GroupBy("name", "age"). OrderBy(Desc("name"), "age"), wantQuery: `SELECT "name", "age", COUNT(*) FROM "users" GROUP BY "name", "age" ORDER BY "name" DESC, "age"`, }, { input: Select("*"). From(Table("users")). Limit(1), wantQuery: "SELECT * FROM `users` LIMIT 1", }, { input: Dialect(dialect.Postgres). Select("*"). From(Table("users")). Limit(1), wantQuery: `SELECT * FROM "users" LIMIT 1`, }, { input: Select("age").Distinct().From(Table("users")), wantQuery: "SELECT DISTINCT `age` FROM `users`", }, { input: Dialect(dialect.Postgres). Select("age"). Distinct(). From(Table("users")), wantQuery: `SELECT DISTINCT "age" FROM "users"`, }, { input: Select("age", "name").From(Table("users")).Distinct().OrderBy("name"), wantQuery: "SELECT DISTINCT `age`, `name` FROM `users` ORDER BY `name`", }, { input: Dialect(dialect.Postgres). Select("age", "name"). From(Table("users")). Distinct(). OrderBy("name"), wantQuery: `SELECT DISTINCT "age", "name" FROM "users" ORDER BY "name"`, }, { input: Select("age").From(Table("users")).Where(EQ("name", "foo")).Or().Where(EQ("name", "bar")), wantQuery: "SELECT `age` FROM `users` WHERE `name` = ? OR `name` = ?", wantArgs: []interface{}{"foo", "bar"}, }, { input: Dialect(dialect.Postgres). Select("age"). From(Table("users")). 
Where(EQ("name", "foo")).Or().Where(EQ("name", "bar")), wantQuery: `SELECT "age" FROM "users" WHERE "name" = $1 OR "name" = $2`, wantArgs: []interface{}{"foo", "bar"}, }, { input: Queries{With("users_view").As(Select().From(Table("users"))), Select().From(Table("users_view"))}, wantQuery: "WITH users_view AS (SELECT * FROM `users`) SELECT * FROM `users_view`", }, { input: func() Querier { base := Select("*").From(Table("groups")) return Queries{With("groups").As(base.Clone().Where(EQ("name", "bar"))), base.Select("age")} }(), wantQuery: "WITH groups AS (SELECT * FROM `groups` WHERE `name` = ?) SELECT `age` FROM `groups`", wantArgs: []interface{}{"bar"}, }, { input: func() Querier { builder := Dialect(dialect.Postgres) t1 := builder.Table("groups") t2 := builder.Table("users") t3 := builder.Table("user_groups") t4 := builder.Select(t3.C("id")). From(t3). Join(t2). On(t3.C("id"), t2.C("id2")). Where(EQ(t2.C("id"), "baz")) return builder.Select(). From(t1). Join(t4). On(t1.C("id"), t4.C("id")).Limit(1) }(), wantQuery: `SELECT * FROM "groups" JOIN (SELECT "user_groups"."id" FROM "user_groups" JOIN "users" AS "t0" ON "user_groups"."id" = "t0"."id2" WHERE "t0"."id" = $1) AS "t1" ON "groups"."id" = "t1"."id" LIMIT 1`, wantArgs: []interface{}{"baz"}, }, { input: func() Querier { t1 := Table("users") return Dialect(dialect.Postgres). Select(). From(t1). Where(CompositeGT(t1.Columns("id", "name"), 1, "Ariel")) }(), wantQuery: `SELECT * FROM "users" WHERE ("users"."id", "users"."name") > ($1, $2)`, wantArgs: []interface{}{1, "Ariel"}, }, { input: func() Querier { t1 := Table("users") return Dialect(dialect.Postgres). Select(). From(t1). Where(And(EQ("name", "Ariel"), CompositeGT(t1.Columns("id", "name"), 1, "Ariel"))) }(), wantQuery: `SELECT * FROM "users" WHERE "name" = $1 AND ("users"."id", "users"."name") > ($2, $3)`, wantArgs: []interface{}{"Ariel", 1, "Ariel"}, }, { input: func() Querier { t1 := Table("users") return Dialect(dialect.Postgres). Select(). From(t1). 
Where(And(EQ("name", "Ariel"), Or(EQ("surname", "Doe"), CompositeGT(t1.Columns("id", "name"), 1, "Ariel")))) }(), wantQuery: `SELECT * FROM "users" WHERE "name" = $1 AND ("surname" = $2 OR ("users"."id", "users"."name") > ($3, $4))`, wantArgs: []interface{}{"Ariel", "Doe", 1, "Ariel"}, }, { input: func() Querier { t1 := Table("users") return Dialect(dialect.Postgres). Select(). From(Table("users")). Where(And(EQ("name", "Ariel"), CompositeLT(t1.Columns("id", "name"), 1, "Ariel"))) }(), wantQuery: `SELECT * FROM "users" WHERE "name" = $1 AND ("users"."id", "users"."name") < ($2, $3)`, wantArgs: []interface{}{"Ariel", 1, "Ariel"}, }, { input: CreateIndex("name_index").Table("users").Column("name"), wantQuery: "CREATE INDEX `name_index` ON `users`(`name`)", }, { input: Dialect(dialect.Postgres). CreateIndex("name_index"). Table("users"). Column("name"), wantQuery: `CREATE INDEX "name_index" ON "users"("name")`, }, { input: CreateIndex("unique_name").Unique().Table("users").Columns("first", "last"), wantQuery: "CREATE UNIQUE INDEX `unique_name` ON `users`(`first`, `last`)", }, { input: Dialect(dialect.Postgres). CreateIndex("unique_name"). Unique(). Table("users"). Columns("first", "last"), wantQuery: `CREATE UNIQUE INDEX "unique_name" ON "users"("first", "last")`, }, { input: DropIndex("name_index"), wantQuery: "DROP INDEX `name_index`", }, { input: Dialect(dialect.Postgres). DropIndex("name_index"), wantQuery: `DROP INDEX "name_index"`, }, { input: DropIndex("name_index").Table("users"), wantQuery: "DROP INDEX `name_index` ON `users`", }, { input: Select(). From(Table("pragma_table_info('t1')").Unquote()). OrderBy("pk"), wantQuery: "SELECT * FROM pragma_table_info('t1') ORDER BY `pk`", }, { input: AlterTable("users"). AddColumn(Column("spouse").Type("integer"). Constraint(ForeignKey("user_spouse"). Reference(Reference().Table("users").Columns("id")). 
OnDelete("SET NULL"))), wantQuery: "ALTER TABLE `users` ADD COLUMN `spouse` integer CONSTRAINT user_spouse REFERENCES `users`(`id`) ON DELETE SET NULL", }, { input: Dialect(dialect.Postgres). Select("*"). From(Table("users")). Where(Or( And(EQ("id", 1), InInts("group_id", 2, 3)), And(EQ("id", 2), InValues("group_id", 4, 5)), )). Where(And( Or(EQ("a", "a"), And(EQ("b", "b"), EQ("c", "c"))), Not(Or(IsNull("d"), NotNull("e"))), )). Or(). Where(And(NEQ("f", "f"), NEQ("g", "g"))), wantQuery: strings.NewReplacer("\n", "", "\t", "").Replace(` SELECT * FROM "users" WHERE ( (("id" = $1 AND "group_id" IN ($2, $3)) OR ("id" = $4 AND "group_id" IN ($5, $6))) AND (("a" = $7 OR ("b" = $8 AND "c" = $9)) AND (NOT ("d" IS NULL OR "e" IS NOT NULL))) ) OR ("f" <> $10 AND "g" <> $11)`), wantArgs: []interface{}{1, 2, 3, 2, 4, 5, "a", "b", "c", "f", "g"}, }, { input: Dialect(dialect.Postgres). Select("*"). From(Table("test")). Where(P(func(b *Builder) { b.WriteString("nlevel(").Ident("path").WriteByte(')').WriteOp(OpGT).Arg(1) })), wantQuery: `SELECT * FROM "test" WHERE nlevel("path") > $1`, wantArgs: []interface{}{1}, }, { input: Dialect(dialect.Postgres). Select("*"). From(Table("test")). Where(P(func(b *Builder) { b.WriteString("nlevel(").Ident("path").WriteByte(')').WriteOp(OpGT).Arg(1) })), wantQuery: `SELECT * FROM "test" WHERE nlevel("path") > $1`, wantArgs: []interface{}{1}, }, { input: func() Querier { t1, t2 := Table("users").Schema("s1"), Table("pets").Schema("s2") return Select("*"). From(t1).Join(t2). OnP(P(func(b *Builder) { b.Ident(t1.C("id")).WriteOp(OpEQ).Ident(t2.C("owner_id")) })). Where(EQ(t2.C("name"), "pedro")) }(), wantQuery: "SELECT * FROM `s1`.`users` JOIN `s2`.`pets` AS `t0` ON `s1`.`users`.`id` = `t0`.`owner_id` WHERE `t0`.`name` = ?", wantArgs: []interface{}{"pedro"}, }, { input: func() Querier { t1, t2 := Table("users").Schema("s1"), Table("pets").Schema("s2") sel := Select("*"). From(t1).Join(t2). 
OnP(P(func(b *Builder) { b.Ident(t1.C("id")).WriteOp(OpEQ).Ident(t2.C("owner_id")) })). Where(EQ(t2.C("name"), "pedro")) sel.SetDialect(dialect.SQLite) return sel }(), wantQuery: "SELECT * FROM `users` JOIN `pets` AS `t0` ON `users`.`id` = `t0`.`owner_id` WHERE `t0`.`name` = ?", wantArgs: []interface{}{"pedro"}, }, } for i, tt := range tests { t.Run(strconv.Itoa(i), func(t *testing.T) { query, args := tt.input.Query() require.Equal(t, tt.wantQuery, query) require.Equal(t, tt.wantArgs, args) }) } } func TestBuilder_Err(t *testing.T) { b := Select("i-") require.NoError(t, b.Err()) b.AddError(fmt.Errorf("invalid")) require.EqualError(t, b.Err(), "invalid") b.AddError(fmt.Errorf("unexpected")) require.EqualError(t, b.Err(), "invalid; unexpected") } ent-0.5.4/dialect/sql/driver.go000066400000000000000000000076511377533537200163640ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package sql import ( "context" "database/sql" "database/sql/driver" "fmt" "strings" "github.com/facebook/ent/dialect" ) // Driver is a dialect.Driver implementation for SQL based databases. type Driver struct { Conn dialect string } // Open wraps the database/sql.Open method and returns a dialect.Driver that implements the an ent/dialect.Driver interface. func Open(driver, source string) (*Driver, error) { db, err := sql.Open(driver, source) if err != nil { return nil, err } return &Driver{Conn{db}, driver}, nil } // OpenDB wraps the given database/sql.DB method with a Driver. func OpenDB(driver string, db *sql.DB) *Driver { return &Driver{Conn{db}, driver} } // DB returns the underlying *sql.DB instance. func (d Driver) DB() *sql.DB { return d.ExecQuerier.(*sql.DB) } // Dialect implements the dialect.Dialect method. func (d Driver) Dialect() string { // If the underlying driver is wrapped with opencensus driver. 
for _, name := range []string{dialect.MySQL, dialect.SQLite, dialect.Postgres} { if strings.HasPrefix(d.dialect, name) { return name } } return d.dialect } // Tx starts and returns a transaction. func (d *Driver) Tx(ctx context.Context) (dialect.Tx, error) { return d.BeginTx(ctx, nil) } // BeginTx starts a transaction with options. func (d *Driver) BeginTx(ctx context.Context, opts *TxOptions) (dialect.Tx, error) { tx, err := d.DB().BeginTx(ctx, opts) if err != nil { return nil, err } return &Tx{ ExecQuerier: Conn{tx}, Tx: tx, }, nil } // Close closes the underlying connection. func (d *Driver) Close() error { return d.DB().Close() } // Tx implements dialect.Tx interface. type Tx struct { dialect.ExecQuerier driver.Tx } // ExecQuerier wraps the standard Exec and Query methods. type ExecQuerier interface { ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) } // Conn implements dialect.ExecQuerier given ExecQuerier. type Conn struct { ExecQuerier } // Exec implements the dialect.Exec method. func (c Conn) Exec(ctx context.Context, query string, args, v interface{}) error { argv, ok := args.([]interface{}) if !ok { return fmt.Errorf("dialect/sql: invalid type %T. expect []interface{} for args", v) } switch v := v.(type) { case nil: if _, err := c.ExecContext(ctx, query, argv...); err != nil { return err } case *sql.Result: res, err := c.ExecContext(ctx, query, argv...) if err != nil { return err } *v = res default: return fmt.Errorf("dialect/sql: invalid type %T. expect *sql.Result", v) } return nil } // Query implements the dialect.Query method. func (c Conn) Query(ctx context.Context, query string, args, v interface{}) error { vr, ok := v.(*Rows) if !ok { return fmt.Errorf("dialect/sql: invalid type %T. expect *sql.Rows", v) } argv, ok := args.([]interface{}) if !ok { return fmt.Errorf("dialect/sql: invalid type %T. 
expect []interface{} for args", args) } rows, err := c.QueryContext(ctx, query, argv...) if err != nil { return err } *vr = Rows{rows} return nil } var _ dialect.Driver = (*Driver)(nil) type ( // Rows wraps the sql.Rows to avoid locks copy. Rows struct{ *sql.Rows } // Result is an alias to sql.Result. Result = sql.Result // NullBool is an alias to sql.NullBool. NullBool = sql.NullBool // NullInt64 is an alias to sql.NullInt64. NullInt64 = sql.NullInt64 // NullString is an alias to sql.NullString. NullString = sql.NullString // NullFloat64 is an alias to sql.NullFloat64. NullFloat64 = sql.NullFloat64 // NullTime represents a time.Time that may be null. NullTime = sql.NullTime // TxOptions holds the transaction options to be used in DB.BeginTx. TxOptions = sql.TxOptions ) ent-0.5.4/dialect/sql/scan.go000066400000000000000000000135061377533537200160110ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package sql import ( "database/sql" "database/sql/driver" "fmt" "reflect" "strings" ) // ColumnScanner is the interface that wraps the // four sql.Rows methods used for scanning. type ColumnScanner interface { Next() bool Scan(...interface{}) error Columns() ([]string, error) Err() error } // ScanOne scans one row to the given value. It fails if the rows holds more than 1 row. 
func ScanOne(rows ColumnScanner, v interface{}) error { columns, err := rows.Columns() if err != nil { return fmt.Errorf("sql/scan: failed getting column names: %v", err) } if n := len(columns); n != 1 { return fmt.Errorf("sql/scan: unexpected number of columns: %d", n) } if !rows.Next() { if err := rows.Err(); err != nil { return err } return sql.ErrNoRows } if err := rows.Scan(v); err != nil { return err } if rows.Next() { return fmt.Errorf("sql/scan: expect exactly one row in result set") } return rows.Err() } // ScanInt64 scans and returns an int64 from the rows columns. func ScanInt64(rows ColumnScanner) (int64, error) { var n int64 if err := ScanOne(rows, &n); err != nil { return 0, err } return n, nil } // ScanInt scans and returns an int from the rows columns. func ScanInt(rows ColumnScanner) (int, error) { n, err := ScanInt64(rows) if err != nil { return 0, err } return int(n), nil } // ScanString scans and returns a string from the rows columns. func ScanString(rows ColumnScanner) (string, error) { var s string if err := ScanOne(rows, &s); err != nil { return "", err } return s, nil } // ScanValue scans and returns a driver.Value from the rows columns. func ScanValue(rows ColumnScanner) (driver.Value, error) { var v driver.Value if err := ScanOne(rows, &v); err != nil { return "", err } return v, nil } // ScanSlice scans the given ColumnScanner (basically, sql.Row or sql.Rows) into the given slice. func ScanSlice(rows ColumnScanner, v interface{}) error { columns, err := rows.Columns() if err != nil { return fmt.Errorf("sql/scan: failed getting column names: %v", err) } rv := reflect.Indirect(reflect.ValueOf(v)) if k := rv.Kind(); k != reflect.Slice { return fmt.Errorf("sql/scan: invalid type %s. 
expected slice as an argument", k) } scan, err := scanType(rv.Type().Elem(), columns) if err != nil { return err } if n, m := len(columns), len(scan.columns); n > m { return fmt.Errorf("sql/scan: columns do not match (%d > %d)", n, m) } for rows.Next() { values := scan.values() if err := rows.Scan(values...); err != nil { return fmt.Errorf("sql/scan: failed scanning rows: %v", err) } vv := reflect.Append(rv, scan.value(values...)) rv.Set(vv) } return rows.Err() } // rowScan is the configuration for scanning one sql.Row. type rowScan struct { // column types of a row. columns []reflect.Type // value functions that converts the row columns (result) to a reflect.Value. value func(v ...interface{}) reflect.Value } // values returns a []interface{} from the configured column types. func (r *rowScan) values() []interface{} { values := make([]interface{}, len(r.columns)) for i := range r.columns { values[i] = reflect.New(r.columns[i]).Interface() } return values } // scanType returns rowScan for the given reflect.Type. func scanType(typ reflect.Type, columns []string) (*rowScan, error) { switch k := typ.Kind(); { case assignable(typ): return &rowScan{ columns: []reflect.Type{typ}, value: func(v ...interface{}) reflect.Value { return reflect.Indirect(reflect.ValueOf(v[0])) }, }, nil case k == reflect.Ptr: return scanPtr(typ, columns) case k == reflect.Struct: return scanStruct(typ, columns) default: return nil, fmt.Errorf("sql/scan: unsupported type ([]%s)", k) } } var scannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem() // assignable reports if the given type can be assigned directly by `Rows.Scan`. 
func assignable(typ reflect.Type) bool {
	switch k := typ.Kind(); {
	case typ.Implements(scannerType):
		// Implements sql.Scanner — the driver hands it the raw value.
	case k == reflect.Interface && typ.NumMethod() == 0:
		// Empty interface (interface{} / driver.Value).
	case k == reflect.String || k >= reflect.Bool && k <= reflect.Float64:
		// Strings, booleans and all numeric kinds.
	case (k == reflect.Slice || k == reflect.Array) && typ.Elem().Kind() == reflect.Uint8:
		// Byte slices/arrays ([]byte, [N]byte).
	default:
		return false
	}
	return true
}

// scanStruct returns a configuration for scanning an sql.Row into a struct.
// Columns are matched to struct fields by the "sql" tag first, then the
// "json" tag (name part only), then the lower-cased field name.
func scanStruct(typ reflect.Type, columns []string) (*rowScan, error) {
	var (
		scan  = &rowScan{}
		names = make(map[string]int)
		idx   = make([]int, 0, typ.NumField())
	)
	for i := 0; i < typ.NumField(); i++ {
		f := typ.Field(i)
		name := strings.ToLower(f.Name)
		if tag, ok := f.Tag.Lookup("sql"); ok {
			name = tag
		} else if tag, ok := f.Tag.Lookup("json"); ok {
			// Strip json tag options such as ",omitempty".
			name = strings.Split(tag, ",")[0]
		}
		names[name] = i
	}
	for _, c := range columns {
		// normalize columns if necessary, for example: COUNT(*) => count.
		name := strings.ToLower(strings.Split(c, "(")[0])
		i, ok := names[name]
		if !ok {
			return nil, fmt.Errorf("sql/scan: missing struct field for column: %s (%s)", c, name)
		}
		idx = append(idx, i)
		scan.columns = append(scan.columns, typ.Field(i).Type)
	}
	// Build a fresh struct value from one scanned row.
	scan.value = func(vs ...interface{}) reflect.Value {
		st := reflect.New(typ).Elem()
		for i, v := range vs {
			st.Field(idx[i]).Set(reflect.Indirect(reflect.ValueOf(v)))
		}
		return st
	}
	return scan, nil
}

// scanPtr wraps the underlying type with rowScan.
// It scans the element type and then re-wraps each produced value in a
// newly-allocated pointer.
func scanPtr(typ reflect.Type, columns []string) (*rowScan, error) {
	typ = typ.Elem()
	scan, err := scanType(typ, columns)
	if err != nil {
		return nil, err
	}
	wrap := scan.value
	scan.value = func(vs ...interface{}) reflect.Value {
		v := wrap(vs...)
		pt := reflect.PtrTo(v.Type())
		pv := reflect.New(pt.Elem())
		pv.Elem().Set(v)
		return pv
	}
	return scan, nil
}
ent-0.5.4/dialect/sql/scan_test.go000066400000000000000000000133311377533537200170440ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved.
// This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package sql import ( "database/sql" "database/sql/driver" "testing" "github.com/DATA-DOG/go-sqlmock" "github.com/google/uuid" "github.com/stretchr/testify/require" ) func TestScanSlice(t *testing.T) { mock := sqlmock.NewRows([]string{"name"}). AddRow("foo"). AddRow("bar") var v0 []string require.NoError(t, ScanSlice(toRows(mock), &v0)) require.Equal(t, []string{"foo", "bar"}, v0) mock = sqlmock.NewRows([]string{"age"}). AddRow(1). AddRow(2) var v1 []int require.NoError(t, ScanSlice(toRows(mock), &v1)) require.Equal(t, []int{1, 2}, v1) mock = sqlmock.NewRows([]string{"name", "COUNT(*)"}). AddRow("foo", 1). AddRow("bar", 2) var v2 []struct { Name string Count int } require.NoError(t, ScanSlice(toRows(mock), &v2)) require.Equal(t, "foo", v2[0].Name) require.Equal(t, "bar", v2[1].Name) require.Equal(t, 1, v2[0].Count) require.Equal(t, 2, v2[1].Count) mock = sqlmock.NewRows([]string{"nick_name", "COUNT(*)"}). AddRow("foo", 1). AddRow("bar", 2) var v3 []struct { Count int Name string `json:"nick_name"` } require.NoError(t, ScanSlice(toRows(mock), &v3)) require.Equal(t, "foo", v3[0].Name) require.Equal(t, "bar", v3[1].Name) require.Equal(t, 1, v3[0].Count) require.Equal(t, 2, v3[1].Count) mock = sqlmock.NewRows([]string{"nick_name", "COUNT(*)"}). AddRow("foo", 1). AddRow("bar", 2) var v4 []*struct { Count int Name string `json:"nick_name"` Ignored string `json:"string"` } require.NoError(t, ScanSlice(toRows(mock), &v4)) require.Equal(t, "foo", v4[0].Name) require.Equal(t, "bar", v4[1].Name) require.Equal(t, 1, v4[0].Count) require.Equal(t, 2, v4[1].Count) mock = sqlmock.NewRows([]string{"nick_name", "COUNT(*)"}). AddRow("foo", 1). 
AddRow("bar", 2) var v5 []*struct { Count int Name string `json:"name" sql:"nick_name"` } require.NoError(t, ScanSlice(toRows(mock), &v5)) require.Equal(t, "foo", v5[0].Name) require.Equal(t, "bar", v5[1].Name) require.Equal(t, 1, v5[0].Count) require.Equal(t, 2, v5[1].Count) mock = sqlmock.NewRows([]string{"age", "name"}). AddRow(1, nil). AddRow(nil, "a8m") var v6 []struct { Age NullInt64 Name NullString } require.NoError(t, ScanSlice(toRows(mock), &v6)) require.EqualValues(t, 1, v6[0].Age.Int64) require.False(t, v6[0].Name.Valid) require.False(t, v6[1].Age.Valid) require.Equal(t, "a8m", v6[1].Name.String) u1, u2 := uuid.New().String(), uuid.New().String() mock = sqlmock.NewRows([]string{"ids"}). AddRow([]byte(u1)). AddRow([]byte(u2)) var ids []uuid.UUID require.NoError(t, ScanSlice(toRows(mock), &ids)) require.Equal(t, u1, ids[0].String()) require.Equal(t, u2, ids[1].String()) mock = sqlmock.NewRows([]string{"pids"}). AddRow([]byte(u1)). AddRow([]byte(u2)) var pids []*uuid.UUID require.NoError(t, ScanSlice(toRows(mock), &pids)) require.Equal(t, u1, pids[0].String()) require.Equal(t, u2, pids[1].String()) } func TestScanSlicePtr(t *testing.T) { mock := sqlmock.NewRows([]string{"name"}). AddRow("foo"). AddRow("bar") var v0 []*string require.NoError(t, ScanSlice(toRows(mock), &v0)) require.Equal(t, "foo", *v0[0]) require.Equal(t, "bar", *v0[1]) mock = sqlmock.NewRows([]string{"age"}). AddRow(1). AddRow(2) var v1 []**int require.NoError(t, ScanSlice(toRows(mock), &v1)) require.Equal(t, 1, **v1[0]) require.Equal(t, 2, **v1[1]) mock = sqlmock.NewRows([]string{"age", "name"}). AddRow(1, "a8m"). AddRow(2, "nati") var v2 []*struct { Age *int Name **string } require.NoError(t, ScanSlice(toRows(mock), &v2)) require.Equal(t, 1, *v2[0].Age) require.Equal(t, "a8m", **v2[0].Name) require.Equal(t, 2, *v2[1].Age) require.Equal(t, "nati", **v2[1].Name) } func TestScanInt64(t *testing.T) { mock := sqlmock.NewRows([]string{"age"}). AddRow("10"). 
AddRow("20") n, err := ScanInt64(toRows(mock)) require.Error(t, err) require.Zero(t, n) mock = sqlmock.NewRows([]string{"age", "count"}). AddRow("10", "1") n, err = ScanInt64(toRows(mock)) require.Error(t, err) require.Zero(t, n) mock = sqlmock.NewRows([]string{"count"}). AddRow(10) n, err = ScanInt64(toRows(mock)) require.NoError(t, err) require.EqualValues(t, 10, n) mock = sqlmock.NewRows([]string{"count"}). AddRow("10") n, err = ScanInt64(toRows(mock)) require.NoError(t, err) require.EqualValues(t, 10, n) } func TestScanValue(t *testing.T) { mock := sqlmock.NewRows([]string{"count"}). AddRow(10) n, err := ScanValue(toRows(mock)) require.NoError(t, err) require.EqualValues(t, 10, n) } func TestScanOne(t *testing.T) { mock := sqlmock.NewRows([]string{"name"}). AddRow("10"). AddRow("20") err := ScanOne(toRows(mock), new(string)) require.Error(t, err, "multiple lines") mock = sqlmock.NewRows([]string{"name"}). AddRow("10") err = ScanOne(toRows(mock), "") require.Error(t, err, "not a pointer") mock = sqlmock.NewRows([]string{"name"}). AddRow("10") var s string err = ScanOne(toRows(mock), &s) require.NoError(t, err) require.Equal(t, "10", s) } func TestInterface(t *testing.T) { mock := sqlmock.NewRows([]string{"age"}). AddRow("10"). AddRow("20") var values []driver.Value err := ScanSlice(toRows(mock), &values) require.NoError(t, err) require.Equal(t, []driver.Value{"10", "20"}, values) mock = sqlmock.NewRows([]string{"age"}). AddRow(10). 
AddRow(20) values = values[:0:0] err = ScanSlice(toRows(mock), &values) require.NoError(t, err) require.Equal(t, []driver.Value{int64(10), int64(20)}, values) } func toRows(mrows *sqlmock.Rows) *sql.Rows { db, mock, _ := sqlmock.New() mock.ExpectQuery("").WillReturnRows(mrows) rows, _ := db.Query("") return rows } ent-0.5.4/dialect/sql/schema/000077500000000000000000000000001377533537200157715ustar00rootroot00000000000000ent-0.5.4/dialect/sql/schema/migrate.go000066400000000000000000000447071377533537200177640ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package schema import ( "context" "crypto/md5" "fmt" "math" "sort" "github.com/facebook/ent/dialect" "github.com/facebook/ent/dialect/sql" "github.com/facebook/ent/schema/field" ) const ( // TypeTable defines the table name holding the type information. TypeTable = "ent_types" // MaxTypes defines the max number of types can be created when // defining universal ids. The left 16-bits are reserved. MaxTypes = math.MaxUint16 ) // MigrateOption allows for managing schema configuration using functional options. type MigrateOption func(*Migrate) // WithGlobalUniqueID sets the universal ids options to the migration. // Defaults to false. func WithGlobalUniqueID(b bool) MigrateOption { return func(m *Migrate) { m.universalID = b } } // WithDropColumn sets the columns dropping option to the migration. // Defaults to false. func WithDropColumn(b bool) MigrateOption { return func(m *Migrate) { m.dropColumns = b } } // WithDropIndex sets the indexes dropping option to the migration. // Defaults to false. func WithDropIndex(b bool) MigrateOption { return func(m *Migrate) { m.dropIndexes = b } } // WithFixture sets the foreign-key renaming option to the migration when upgrading // ent from v0.1.0 (issue-#285). Defaults to false. 
func WithFixture(b bool) MigrateOption { return func(m *Migrate) { m.withFixture = b } } // WithForeignKeys enables creating foreign-key in ddl. Defaults to true. func WithForeignKeys(b bool) MigrateOption { return func(m *Migrate) { m.withForeignKeys = b } } // Migrate runs the migrations logic for the SQL dialects. type Migrate struct { sqlDialect universalID bool // global unique ids. dropColumns bool // drop deleted columns. dropIndexes bool // drop deleted indexes. withFixture bool // with fks rename fixture. withForeignKeys bool // with foreign keys typeRanges []string // types order by their range. } // NewMigrate create a migration structure for the given SQL driver. func NewMigrate(d dialect.Driver, opts ...MigrateOption) (*Migrate, error) { m := &Migrate{withForeignKeys: true} for _, opt := range opts { opt(m) } switch d.Dialect() { case dialect.MySQL: m.sqlDialect = &MySQL{Driver: d} case dialect.SQLite: m.sqlDialect = &SQLite{Driver: d, WithForeignKeys: m.withForeignKeys} case dialect.Postgres: m.sqlDialect = &Postgres{Driver: d} default: return nil, fmt.Errorf("sql/schema: unsupported dialect %q", d.Dialect()) } return m, nil } // Create creates all schema resources in the database. It works in an "append-only" // mode, which means, it only create tables, append column to tables or modifying column type. // // Column can be modified by turning into a NULL from NOT NULL, or having a type conversion not // resulting data altering. From example, changing varchar(255) to varchar(120) is invalid, but // changing varchar(120) to varchar(255) is valid. For more info, see the convert function below. // // Note that SQLite dialect does not support (this moment) the "append-only" mode describe above, // since it's used only for testing. 
func (m *Migrate) Create(ctx context.Context, tables ...*Table) error { tx, err := m.Tx(ctx) if err != nil { return err } if err := m.init(ctx, tx); err != nil { return rollback(tx, err) } if m.universalID { if err := m.types(ctx, tx); err != nil { return rollback(tx, err) } } if err := m.create(ctx, tx, tables...); err != nil { return rollback(tx, err) } return tx.Commit() } func (m *Migrate) create(ctx context.Context, tx dialect.Tx, tables ...*Table) error { for _, t := range tables { m.setupTable(t) switch exist, err := m.tableExist(ctx, tx, t.Name); { case err != nil: return err case exist: curr, err := m.table(ctx, tx, t.Name) if err != nil { return err } if err := m.verify(ctx, tx, curr); err != nil { return err } if err := m.fixture(ctx, tx, curr, t); err != nil { return err } change, err := m.changeSet(curr, t) if err != nil { return err } if err := m.apply(ctx, tx, t.Name, change); err != nil { return err } default: // !exist query, args := m.tBuilder(t).Query() if err := tx.Exec(ctx, query, args, nil); err != nil { return fmt.Errorf("create table %q: %v", t.Name, err) } // If global unique identifier is enabled and it's not // a relation table, allocate a range for the table pk. if m.universalID && len(t.PrimaryKey) == 1 { if err := m.allocPKRange(ctx, tx, t); err != nil { return err } } // indexes. for _, idx := range t.Indexes { query, args := m.addIndex(idx, t.Name).Query() if err := tx.Exec(ctx, query, args, nil); err != nil { return fmt.Errorf("create index %q: %v", idx.Name, err) } } } } if !m.withForeignKeys { return nil } // Create foreign keys after tables were created/altered, // because circular foreign-key constraints are possible. 
for _, t := range tables { if len(t.ForeignKeys) == 0 { continue } fks := make([]*ForeignKey, 0, len(t.ForeignKeys)) for _, fk := range t.ForeignKeys { exist, err := m.fkExist(ctx, tx, fk.Symbol) if err != nil { return err } if !exist { fks = append(fks, fk) } } if len(fks) == 0 { continue } b := sql.Dialect(m.Dialect()).AlterTable(t.Name) for _, fk := range fks { b.AddForeignKey(fk.DSL()) } query, args := b.Query() if err := tx.Exec(ctx, query, args, nil); err != nil { return fmt.Errorf("create foreign keys for %q: %v", t.Name, err) } } return nil } // apply applies changes on the given table. func (m *Migrate) apply(ctx context.Context, tx dialect.Tx, table string, change *changes) error { // Constraints should be dropped before dropping columns, because if a column // is a part of multi-column constraints (like, unique index), ALTER TABLE // might fail if the intermediate state violates the constraints. if m.dropIndexes { if pr, ok := m.sqlDialect.(preparer); ok { if err := pr.prepare(ctx, tx, change, table); err != nil { return err } } for _, idx := range change.index.drop { if err := m.dropIndex(ctx, tx, idx, table); err != nil { return fmt.Errorf("drop index of table %q: %v", table, err) } } } var drop []*Column if m.dropColumns { drop = change.column.drop } queries := m.alterColumns(table, change.column.add, change.column.modify, drop) // If there's actual action to execute on ALTER TABLE. for i := range queries { query, args := queries[i].Query() if err := tx.Exec(ctx, query, args, nil); err != nil { return fmt.Errorf("alter table %q: %v", table, err) } } for _, idx := range change.index.add { query, args := m.addIndex(idx, table).Query() if err := tx.Exec(ctx, query, args, nil); err != nil { return fmt.Errorf("create index %q: %v", table, err) } } return nil } // changes to apply on existing table. type changes struct { // column changes. column struct { add []*Column drop []*Column modify []*Column } // index changes. 
index struct { add Indexes drop Indexes } } // dropColumn returns the dropped column by name (if any). func (c *changes) dropColumn(name string) (*Column, bool) { for _, col := range c.column.drop { if col.Name == name { return col, true } } return nil, false } // changeSet returns a changes object to be applied on existing table. // It fails if one of the changes is invalid. func (m *Migrate) changeSet(curr, new *Table) (*changes, error) { change := &changes{} // pks. if len(curr.PrimaryKey) != len(new.PrimaryKey) { return nil, fmt.Errorf("cannot change primary key for table: %q", curr.Name) } sort.Slice(new.PrimaryKey, func(i, j int) bool { return new.PrimaryKey[i].Name < new.PrimaryKey[j].Name }) sort.Slice(curr.PrimaryKey, func(i, j int) bool { return curr.PrimaryKey[i].Name < curr.PrimaryKey[j].Name }) for i := range curr.PrimaryKey { if curr.PrimaryKey[i].Name != new.PrimaryKey[i].Name { return nil, fmt.Errorf("cannot change primary key for table: %q", curr.Name) } } // Add or modify columns. for _, c1 := range new.Columns { // Ignore primary keys. if c1.PrimaryKey() { continue } switch c2, ok := curr.column(c1.Name); { case !ok: change.column.add = append(change.column.add, c1) case !c2.Type.Valid(): return nil, fmt.Errorf("invalid type %q for column %q", c2.typ, c2.Name) // Modify a non-unique column to unique. case c1.Unique && !c2.Unique: change.index.add.append(&Index{ Name: c1.Name, Unique: true, Columns: []*Column{c1}, columns: []string{c1.Name}, }) // Modify a unique column to non-unique. case !c1.Unique && c2.Unique: idx, ok := curr.index(c2.Name) if !ok { return nil, fmt.Errorf("missing index to drop for column %q", c2.Name) } change.index.drop.append(idx) // Extending column types. case m.cType(c1) != m.cType(c2): if !c2.ConvertibleTo(c1) { return nil, fmt.Errorf("changing column type for %q is invalid (%s != %s)", c1.Name, m.cType(c1), m.cType(c2)) } fallthrough // Change nullability of a column. 
case c1.Nullable != c2.Nullable: change.column.modify = append(change.column.modify, c1) } } // Drop columns. for _, c1 := range curr.Columns { // If a column was dropped, multi-columns indexes that are associated with this column will // no longer behave the same. Therefore, these indexes should be dropped too. There's no need // to do it explicitly (here), because entc will remove them from the schema specification, // and they will be dropped in the block below. if _, ok := new.column(c1.Name); !ok { change.column.drop = append(change.column.drop, c1) } } // Add or modify indexes. for _, idx1 := range new.Indexes { switch idx2, ok := curr.index(idx1.Name); { case !ok: change.index.add.append(idx1) // Changing index cardinality require drop and create. case idx1.Unique != idx2.Unique: change.index.drop.append(idx2) change.index.add.append(idx1) } } // Drop indexes. for _, idx := range curr.Indexes { _, ok1 := new.fk(idx.Name) _, ok2 := new.index(idx.Name) if !ok1 && !ok2 { change.index.drop.append(idx) } } return change, nil } // fixture is a special migration code for renaming foreign-key columns (issue-#285). func (m *Migrate) fixture(ctx context.Context, tx dialect.Tx, curr, new *Table) error { d, ok := m.sqlDialect.(fkRenamer) if !m.withFixture || !m.withForeignKeys || !ok { return nil } rename := make(map[string]*Index) for _, fk := range new.ForeignKeys { ok, err := m.fkExist(ctx, tx, fk.Symbol) if err != nil { return fmt.Errorf("checking foreign-key existence %q: %v", fk.Symbol, err) } if !ok { continue } column, err := m.fkColumn(ctx, tx, fk) if err != nil { return err } newcol := fk.Columns[0] if column == newcol.Name { continue } query, args := d.renameColumn(curr, &Column{Name: column}, newcol).Query() if err := tx.Exec(ctx, query, args, nil); err != nil { return fmt.Errorf("rename column %q: %v", column, err) } prev, ok := curr.column(column) if !ok { continue } // Find all indexes that ~maybe need to be renamed. 
for _, idx := range prev.indexes { switch _, ok := new.index(idx.Name); { // Ignore indexes that exist in the schema, PKs. case ok || idx.primary: // Index that was created implicitly for a unique // column needs to be renamed to the column name. case d.isImplicitIndex(idx, prev): idx2 := &Index{Name: newcol.Name, Unique: true, Columns: []*Column{newcol}} query, args := d.renameIndex(curr, idx, idx2).Query() if err := tx.Exec(ctx, query, args, nil); err != nil { return fmt.Errorf("rename index %q: %v", prev.Name, err) } idx.Name = idx2.Name default: rename[idx.Name] = idx } } // Update the name of the loaded column, so `changeSet` won't create it. prev.Name = newcol.Name } // Go over the indexes that need to be renamed // and find their ~identical in the new schema. for _, idx := range rename { Find: // Find its ~identical in the new schema, and rename it // if it doesn't exist. for _, idx2 := range new.Indexes { if _, ok := curr.index(idx2.Name); ok { continue } if idx.sameAs(idx2) { query, args := d.renameIndex(curr, idx, idx2).Query() if err := tx.Exec(ctx, query, args, nil); err != nil { return fmt.Errorf("rename index %q: %v", idx.Name, err) } idx.Name = idx2.Name break Find } } } return nil } // verify verifies that the auto-increment counter is correct for table with universal-id support. func (m *Migrate) verify(ctx context.Context, tx dialect.Tx, t *Table) error { vr, ok := m.sqlDialect.(verifyRanger) if !ok || !m.universalID { return nil } id := indexOf(m.typeRanges, t.Name) if id == -1 { return nil } return vr.verifyRange(ctx, tx, t, id<<32) } // types loads the type list from the database. // If the table does not create, it will create one. func (m *Migrate) types(ctx context.Context, tx dialect.Tx) error { exists, err := m.tableExist(ctx, tx, TypeTable) if err != nil { return err } if !exists { t := NewTable(TypeTable). AddPrimary(&Column{Name: "id", Type: field.TypeUint, Increment: true}). 
AddColumn(&Column{Name: "type", Type: field.TypeString, Unique: true}) query, args := m.tBuilder(t).Query() if err := tx.Exec(ctx, query, args, nil); err != nil { return fmt.Errorf("create types table: %v", err) } return nil } rows := &sql.Rows{} query, args := sql.Dialect(m.Dialect()). Select("type").From(sql.Table(TypeTable)).OrderBy(sql.Asc("id")).Query() if err := tx.Query(ctx, query, args, rows); err != nil { return fmt.Errorf("query types table: %v", err) } defer rows.Close() return sql.ScanSlice(rows, &m.typeRanges) } func (m *Migrate) allocPKRange(ctx context.Context, tx dialect.Tx, t *Table) error { id := indexOf(m.typeRanges, t.Name) // If the table re-created, re-use its range from // the past. otherwise, allocate a new id-range. if id == -1 { if len(m.typeRanges) > MaxTypes { return fmt.Errorf("max number of types exceeded: %d", MaxTypes) } query, args := sql.Dialect(m.Dialect()). Insert(TypeTable).Columns("type").Values(t.Name).Query() if err := tx.Exec(ctx, query, args, nil); err != nil { return fmt.Errorf("insert into type: %v", err) } id = len(m.typeRanges) m.typeRanges = append(m.typeRanges, t.Name) } // Set the id offset for table. return m.setRange(ctx, tx, t, id<<32) } // fkColumn returns the column name of a foreign-key. func (m *Migrate) fkColumn(ctx context.Context, tx dialect.Tx, fk *ForeignKey) (string, error) { t1 := sql.Table("INFORMATION_SCHEMA.KEY_COLUMN_USAGE AS t1").Unquote().As("t1") t2 := sql.Table("INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS t2").Unquote().As("t2") query, args := sql.Dialect(m.Dialect()). Select("column_name"). From(t1). Join(t2). On(t1.C("constraint_name"), t2.C("constraint_name")). Where(sql.And( sql.EQ(t2.C("constraint_type"), sql.Raw("'FOREIGN KEY'")), sql.EQ(t2.C("table_schema"), m.sqlDialect.(fkRenamer).tableSchema()), sql.EQ(t1.C("table_schema"), m.sqlDialect.(fkRenamer).tableSchema()), sql.EQ(t2.C("constraint_name"), fk.Symbol), )). 
Query() rows := &sql.Rows{} if err := tx.Query(ctx, query, args, rows); err != nil { return "", fmt.Errorf("reading foreign-key %q column: %v", fk.Symbol, err) } defer rows.Close() column, err := sql.ScanString(rows) if err != nil { return "", fmt.Errorf("scanning foreign-key %q column: %v", fk.Symbol, err) } return column, nil } // setup ensures the table is configured properly, like table columns // are linked to their indexes, and PKs columns are defined. func (m *Migrate) setupTable(t *Table) { if t.columns == nil { t.columns = make(map[string]*Column, len(t.Columns)) } for _, c := range t.Columns { t.columns[c.Name] = c } for _, idx := range t.Indexes { idx.Name = m.symbol(idx.Name) for _, c := range idx.Columns { c.indexes.append(idx) } } for _, pk := range t.PrimaryKey { c := t.columns[pk.Name] c.Key = PrimaryKey pk.Key = PrimaryKey } for _, fk := range t.ForeignKeys { fk.Symbol = m.symbol(fk.Symbol) for i := range fk.Columns { fk.Columns[i].foreign = fk } } } // symbol makes sure the symbol length is not longer than the maxlength in the dialect. func (m *Migrate) symbol(name string) string { size := 64 if m.Dialect() == dialect.Postgres { size = 63 } if len(name) <= size { return name } return fmt.Sprintf("%s_%x", name[:size-33], md5.Sum([]byte(name))) } // rollback calls to tx.Rollback and wraps the given error with the rollback error if occurred. func rollback(tx dialect.Tx, err error) error { err = fmt.Errorf("sql/schema: %v", err) if rerr := tx.Rollback(); rerr != nil { err = fmt.Errorf("%s: %v", err.Error(), rerr) } return err } // exist checks if the given COUNT query returns a value >= 1. 
func exist(ctx context.Context, tx dialect.Tx, query string, args ...interface{}) (bool, error) { rows := &sql.Rows{} if err := tx.Query(ctx, query, args, rows); err != nil { return false, fmt.Errorf("reading schema information %v", err) } defer rows.Close() n, err := sql.ScanInt(rows) if err != nil { return false, err } return n > 0, nil } func indexOf(a []string, s string) int { for i := range a { if a[i] == s { return i } } return -1 } type sqlDialect interface { dialect.Driver init(context.Context, dialect.Tx) error table(context.Context, dialect.Tx, string) (*Table, error) tableExist(context.Context, dialect.Tx, string) (bool, error) fkExist(context.Context, dialect.Tx, string) (bool, error) setRange(context.Context, dialect.Tx, *Table, int) error dropIndex(context.Context, dialect.Tx, *Index, string) error // table, column and index builder per dialect. cType(*Column) string tBuilder(*Table) *sql.TableBuilder addIndex(*Index, string) *sql.IndexBuilder alterColumns(table string, add, modify, drop []*Column) sql.Queries } type preparer interface { prepare(context.Context, dialect.Tx, *changes, string) error } // fkRenamer is used by the fixture migration (to solve #285), // and it's implemented by the different dialects for renaming FKs. type fkRenamer interface { tableSchema() sql.Querier isImplicitIndex(*Index, *Column) bool renameIndex(*Table, *Index, *Index) sql.Querier renameColumn(*Table, *Column, *Column) sql.Querier } // verifyRanger wraps the method for verifying global-id range correctness. type verifyRanger interface { verifyRange(context.Context, dialect.Tx, *Table, int) error } ent-0.5.4/dialect/sql/schema/mysql.go000066400000000000000000000447111377533537200174740ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. 
package schema import ( "context" "database/sql/driver" "fmt" "math" "strconv" "strings" "github.com/facebook/ent/dialect" "github.com/facebook/ent/dialect/sql" "github.com/facebook/ent/schema/field" ) // MySQL is a MySQL migration driver. type MySQL struct { dialect.Driver version string } // init loads the MySQL version from the database for later use in the migration process. func (d *MySQL) init(ctx context.Context, tx dialect.Tx) error { rows := &sql.Rows{} if err := tx.Query(ctx, "SHOW VARIABLES LIKE 'version'", []interface{}{}, rows); err != nil { return fmt.Errorf("mysql: querying mysql version %v", err) } defer rows.Close() if !rows.Next() { if err := rows.Err(); err != nil { return err } return fmt.Errorf("mysql: version variable was not found") } version := make([]string, 2) if err := rows.Scan(&version[0], &version[1]); err != nil { return fmt.Errorf("mysql: scanning mysql version: %v", err) } d.version = version[1] return nil } func (d *MySQL) tableExist(ctx context.Context, tx dialect.Tx, name string) (bool, error) { query, args := sql.Select(sql.Count("*")).From(sql.Table("TABLES").Schema("INFORMATION_SCHEMA")). Where(sql.And( sql.EQ("TABLE_SCHEMA", sql.Raw("(SELECT DATABASE())")), sql.EQ("TABLE_NAME", name), )).Query() return exist(ctx, tx, query, args...) } func (d *MySQL) fkExist(ctx context.Context, tx dialect.Tx, name string) (bool, error) { query, args := sql.Select(sql.Count("*")).From(sql.Table("TABLE_CONSTRAINTS").Schema("INFORMATION_SCHEMA")). Where(sql.And( sql.EQ("TABLE_SCHEMA", sql.Raw("(SELECT DATABASE())")), sql.EQ("CONSTRAINT_TYPE", "FOREIGN KEY"), sql.EQ("CONSTRAINT_NAME", name), )).Query() return exist(ctx, tx, query, args...) } // table loads the current table description from the database. 
func (d *MySQL) table(ctx context.Context, tx dialect.Tx, name string) (*Table, error) { rows := &sql.Rows{} query, args := sql.Select("column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name"). From(sql.Table("COLUMNS").Schema("INFORMATION_SCHEMA")). Where(sql.And( sql.EQ("TABLE_SCHEMA", sql.Raw("(SELECT DATABASE())")), sql.EQ("TABLE_NAME", name)), ).Query() if err := tx.Query(ctx, query, args, rows); err != nil { return nil, fmt.Errorf("mysql: reading table description %v", err) } // Call Close in cases of failures (Close is idempotent). defer rows.Close() t := NewTable(name) for rows.Next() { c := &Column{} if err := d.scanColumn(c, rows); err != nil { return nil, fmt.Errorf("mysql: %v", err) } if c.PrimaryKey() { t.PrimaryKey = append(t.PrimaryKey, c) } t.AddColumn(c) } if err := rows.Err(); err != nil { return nil, err } if err := rows.Close(); err != nil { return nil, fmt.Errorf("mysql: closing rows %v", err) } indexes, err := d.indexes(ctx, tx, name) if err != nil { return nil, err } // Add and link indexes to table columns. for _, idx := range indexes { t.AddIndex(idx.Name, idx.Unique, idx.columns) } if _, ok := d.mariadb(); ok { if err := d.normalizeJSON(ctx, tx, t); err != nil { return nil, err } } return t, nil } // table loads the table indexes from the database. func (d *MySQL) indexes(ctx context.Context, tx dialect.Tx, name string) ([]*Index, error) { rows := &sql.Rows{} query, args := sql.Select("index_name", "column_name", "non_unique", "seq_in_index"). From(sql.Table("STATISTICS").Schema("INFORMATION_SCHEMA")). Where(sql.And( sql.EQ("TABLE_SCHEMA", sql.Raw("(SELECT DATABASE())")), sql.EQ("TABLE_NAME", name), )). OrderBy("index_name", "seq_in_index"). 
Query() if err := tx.Query(ctx, query, args, rows); err != nil { return nil, fmt.Errorf("mysql: reading index description %v", err) } defer rows.Close() idx, err := d.scanIndexes(rows) if err != nil { return nil, fmt.Errorf("mysql: %v", err) } return idx, nil } func (d *MySQL) setRange(ctx context.Context, tx dialect.Tx, t *Table, value int) error { return tx.Exec(ctx, fmt.Sprintf("ALTER TABLE `%s` AUTO_INCREMENT = %d", t.Name, value), []interface{}{}, nil) } func (d *MySQL) verifyRange(ctx context.Context, tx dialect.Tx, t *Table, expected int) error { if expected == 0 { return nil } rows := &sql.Rows{} query, args := sql.Select("AUTO_INCREMENT"). From(sql.Table("TABLES").Schema("INFORMATION_SCHEMA")). Where(sql.And( sql.EQ("TABLE_SCHEMA", sql.Raw("(SELECT DATABASE())")), sql.EQ("TABLE_NAME", t.Name), )). Query() if err := tx.Query(ctx, query, args, rows); err != nil { return fmt.Errorf("mysql: query auto_increment %v", err) } // Call Close in cases of failures (Close is idempotent). defer rows.Close() actual := &sql.NullInt64{} if err := sql.ScanOne(rows, actual); err != nil { return fmt.Errorf("mysql: scan auto_increment %v", err) } if err := rows.Close(); err != nil { return err } // Table is empty and auto-increment is not configured. This can happen // because MySQL (< 8.0) stores the auto-increment counter in main memory // (not persistent), and the value is reset on restart (if table is empty). if actual.Int64 <= 1 { return d.setRange(ctx, tx, t, expected) } return nil } // tBuilder returns the MySQL DSL query for table creation. func (d *MySQL) tBuilder(t *Table) *sql.TableBuilder { b := sql.CreateTable(t.Name).IfNotExists() for _, c := range t.Columns { b.Column(d.addColumn(c)) } for _, pk := range t.PrimaryKey { b.PrimaryKey(pk.Name) } // Charset and collation config on MySQL table. // These options can be overridden by the entsql annotation. 
b.Charset("utf8mb4").Collate("utf8mb4_bin") if t.Annotation != nil { if charset := t.Annotation.Charset; charset != "" { b.Charset(charset) } if collate := t.Annotation.Collation; collate != "" { b.Collate(collate) } if opts := t.Annotation.Options; opts != "" { b.Options(opts) } } return b } // cType returns the MySQL string type for the given column. func (d *MySQL) cType(c *Column) (t string) { if c.SchemaType != nil && c.SchemaType[dialect.MySQL] != "" { // MySQL returns the column type lower cased. return strings.ToLower(c.SchemaType[dialect.MySQL]) } switch c.Type { case field.TypeBool: t = "boolean" case field.TypeInt8: t = "tinyint" case field.TypeUint8: t = "tinyint unsigned" case field.TypeInt16: t = "smallint" case field.TypeUint16: t = "smallint unsigned" case field.TypeInt32: t = "int" case field.TypeUint32: t = "int unsigned" case field.TypeInt, field.TypeInt64: t = "bigint" case field.TypeUint, field.TypeUint64: t = "bigint unsigned" case field.TypeBytes: size := int64(math.MaxUint16) if c.Size > 0 { size = c.Size } switch { case size <= math.MaxUint8: t = "tinyblob" case size <= math.MaxUint16: t = "blob" case size < 1<<24: t = "mediumblob" case size <= math.MaxUint32: t = "longblob" } case field.TypeJSON: t = "json" if compareVersions(d.version, "5.7.8") == -1 { t = "longblob" } case field.TypeString: size := c.Size if size == 0 { size = c.defaultSize(d.version) } if size <= math.MaxUint16 { t = fmt.Sprintf("varchar(%d)", size) } else { t = "longtext" } case field.TypeFloat32, field.TypeFloat64: t = c.scanTypeOr("double") case field.TypeTime: t = c.scanTypeOr("timestamp") // In MySQL, timestamp columns are `NOT NULL` by default, and assigning NULL // assigns the current_timestamp(). We avoid this if not set otherwise. 
c.Nullable = c.Attr == "" case field.TypeEnum: values := make([]string, len(c.Enums)) for i, e := range c.Enums { values[i] = fmt.Sprintf("'%s'", e) } t = fmt.Sprintf("enum(%s)", strings.Join(values, ", ")) case field.TypeUUID: t = "char(36) binary" default: panic(fmt.Sprintf("unsupported type %q for column %q", c.Type.String(), c.Name)) } return t } // addColumn returns the DSL query for adding the given column to a table. // The syntax/order is: datatype [Charset] [Unique|Increment] [Collation] [Nullable]. func (d *MySQL) addColumn(c *Column) *sql.ColumnBuilder { b := sql.Column(c.Name).Type(d.cType(c)).Attr(c.Attr) c.unique(b) if c.Increment { b.Attr("AUTO_INCREMENT") } c.nullable(b) c.defaultValue(b) if c.Type == field.TypeJSON { // Manually add a `CHECK` clause for older versions of MariaDB for validating the // JSON documents. This constraint is automatically included from version 10.4.3. if version, ok := d.mariadb(); ok && compareVersions(version, "10.4.3") == -1 { b.Check(func(b *sql.Builder) { b.WriteString("JSON_VALID(").Ident(c.Name).WriteByte(')') }) } } return b } // addIndex returns the querying for adding an index to MySQL. func (d *MySQL) addIndex(i *Index, table string) *sql.IndexBuilder { return i.Builder(table) } // dropIndex drops a MySQL index. func (d *MySQL) dropIndex(ctx context.Context, tx dialect.Tx, idx *Index, table string) error { query, args := idx.DropBuilder(table).Query() return tx.Exec(ctx, query, args, nil) } // prepare runs preparation work that needs to be done to apply the change-set. func (d *MySQL) prepare(ctx context.Context, tx dialect.Tx, change *changes, table string) error { for _, idx := range change.index.drop { switch n := len(idx.columns); { case n == 0: return fmt.Errorf("index %q has no columns", idx.Name) case n > 1: continue // not a foreign-key index. 
} var qr sql.Querier Switch: switch col, ok := change.dropColumn(idx.columns[0]); { // If both the index and the column need to be dropped, the foreign-key // constraint that is associated with them need to be dropped as well. case ok: names, err := fkNames(ctx, tx, table, col.Name) if err != nil { return err } if len(names) == 1 { qr = sql.AlterTable(table).DropForeignKey(names[0]) } // If the uniqueness was dropped from a foreign-key column, // create a "simple index" if no other index exist for it. case !ok && idx.Unique && len(idx.Columns) > 0: col := idx.Columns[0] for _, idx2 := range col.indexes { if idx2 != idx && len(idx2.columns) == 1 { break Switch } } names, err := fkNames(ctx, tx, table, col.Name) if err != nil { return err } if len(names) == 1 { qr = sql.CreateIndex(names[0]).Table(table).Columns(col.Name) } } if qr != nil { query, args := qr.Query() if err := tx.Exec(ctx, query, args, nil); err != nil { return err } } } return nil } // scanColumn scans the column information from MySQL column description. 
func (d *MySQL) scanColumn(c *Column, rows *sql.Rows) error { var ( nullable sql.NullString defaults sql.NullString ) if err := rows.Scan(&c.Name, &c.typ, &nullable, &c.Key, &defaults, &c.Attr, &sql.NullString{}, &sql.NullString{}); err != nil { return fmt.Errorf("scanning column description: %v", err) } c.Unique = c.UniqueKey() if nullable.Valid { c.Nullable = nullable.String == "YES" } parts, size, unsigned, err := parseColumn(c.typ) if err != nil { return err } switch parts[0] { case "mediumint", "int": c.Type = field.TypeInt32 if unsigned { c.Type = field.TypeUint32 } case "smallint": c.Type = field.TypeInt16 if unsigned { c.Type = field.TypeUint16 } case "bigint": c.Type = field.TypeInt64 if unsigned { c.Type = field.TypeUint64 } case "tinyint": switch { case size == 1: c.Type = field.TypeBool case unsigned: c.Type = field.TypeUint8 default: c.Type = field.TypeInt8 } case "numeric", "decimal", "double": c.Type = field.TypeFloat64 case "time", "timestamp", "date", "datetime": c.Type = field.TypeTime // The mapping from schema defaults to database // defaults is not supported for TypeTime fields. defaults = sql.NullString{} case "tinyblob": c.Size = math.MaxUint8 c.Type = field.TypeBytes case "blob": c.Size = math.MaxUint16 c.Type = field.TypeBytes case "mediumblob": c.Size = 1<<24 - 1 c.Type = field.TypeBytes case "longblob": c.Size = math.MaxUint32 c.Type = field.TypeBytes case "binary", "varbinary": c.Type = field.TypeBytes c.Size = size case "varchar": c.Type = field.TypeString c.Size = size case "longtext": c.Size = math.MaxInt32 c.Type = field.TypeString case "json": c.Type = field.TypeJSON case "enum": c.Type = field.TypeEnum c.Enums = make([]string, len(parts)-1) for i, e := range parts[1:] { c.Enums[i] = strings.Trim(e, "'") } case "char": // UUID field has length of 36 characters (32 alphanumeric characters and 4 hyphens). 
if size != 36 { return fmt.Errorf("unknown char(%d) type (not a uuid)", size) } c.Type = field.TypeUUID default: return fmt.Errorf("unknown column type %q for version %q", parts[0], d.version) } if defaults.Valid { return c.ScanDefault(defaults.String) } return nil } // scanIndexes scans sql.Rows into an Indexes list. The query for returning the rows, // should return the following 4 columns: INDEX_NAME, COLUMN_NAME, NON_UNIQUE, SEQ_IN_INDEX. // SEQ_IN_INDEX specifies the position of the column in the index columns. func (d *MySQL) scanIndexes(rows *sql.Rows) (Indexes, error) { var ( i Indexes names = make(map[string]*Index) ) for rows.Next() { var ( name string column string nonuniq bool seqindex int ) if err := rows.Scan(&name, &column, &nonuniq, &seqindex); err != nil { return nil, fmt.Errorf("scanning index description: %v", err) } // Ignore primary keys. if name == "PRIMARY" { continue } idx, ok := names[name] if !ok { idx = &Index{Name: name, Unique: !nonuniq} i = append(i, idx) names[name] = idx } idx.columns = append(idx.columns, column) } if err := rows.Err(); err != nil { return nil, err } return i, nil } // isImplicitIndex reports if the index was created implicitly for the unique column. func (d *MySQL) isImplicitIndex(idx *Index, col *Column) bool { // We execute `CHANGE COLUMN` on older versions of MySQL (<8.0), which // auto create the new index. The old one, will be dropped in `changeSet`. if compareVersions(d.version, "8.0.0") >= 0 { return idx.Name == col.Name && col.Unique } return false } // renameColumn returns the statement for renaming a column in // MySQL based on its version. func (d *MySQL) renameColumn(t *Table, old, new *Column) sql.Querier { q := sql.AlterTable(t.Name) if compareVersions(d.version, "8.0.0") >= 0 { return q.RenameColumn(old.Name, new.Name) } return q.ChangeColumn(old.Name, d.addColumn(new)) } // renameIndex returns the statement for renaming an index. 
func (d *MySQL) renameIndex(t *Table, old, new *Index) sql.Querier { q := sql.AlterTable(t.Name) if compareVersions(d.version, "5.7.0") >= 0 { return q.RenameIndex(old.Name, new.Name) } return q.DropIndex(old.Name).AddIndex(new.Builder(t.Name)) } // tableSchema returns the query for getting the table schema. func (d *MySQL) tableSchema() sql.Querier { return sql.Raw("(SELECT DATABASE())") } // alterColumns returns the queries for applying the columns change-set. func (d *MySQL) alterColumns(table string, add, modify, drop []*Column) sql.Queries { b := sql.Dialect(dialect.MySQL).AlterTable(table) for _, c := range add { b.AddColumn(d.addColumn(c)) } for _, c := range modify { b.ModifyColumn(d.addColumn(c)) } for _, c := range drop { b.DropColumn(sql.Dialect(dialect.MySQL).Column(c.Name)) } if len(b.Queries) == 0 { return nil } return sql.Queries{b} } // normalizeJSON normalize MariaDB longtext columns to type JSON. func (d *MySQL) normalizeJSON(ctx context.Context, tx dialect.Tx, t *Table) error { var ( names []driver.Value columns = make(map[string]*Column) ) for _, c := range t.Columns { if c.typ == "longtext" { columns[c.Name] = c names = append(names, c.Name) } } if len(names) == 0 { return nil } rows := &sql.Rows{} query, args := sql.Select("CONSTRAINT_NAME", "CHECK_CLAUSE"). From(sql.Table("CHECK_CONSTRAINTS").Schema("INFORMATION_SCHEMA")). Where(sql.And( sql.EQ("CONSTRAINT_SCHEMA", sql.Raw("(SELECT DATABASE())")), sql.EQ("TABLE_NAME", t.Name), sql.InValues("CONSTRAINT_NAME", names...), )). Query() if err := tx.Query(ctx, query, args, rows); err != nil { return fmt.Errorf("mysql: query table constraints %v", err) } // Call Close in cases of failures (Close is idempotent). 
defer rows.Close() for rows.Next() { var name, check string if err := rows.Scan(&name, &check); err != nil { return fmt.Errorf("mysql: scan table constraints") } c, ok := columns[name] if !ok || !strings.HasPrefix(check, "json_valid") { continue } c.Type = field.TypeJSON } if err := rows.Err(); err != nil { return err } return rows.Close() } // mariadb reports if the migration runs on MariaDB and returns the semver string. func (d *MySQL) mariadb() (string, bool) { idx := strings.Index(d.version, "MariaDB") if idx == -1 { return "", false } return d.version[:idx-1], true } // parseColumn returns column parts, size and signed-info from a MySQL type. func parseColumn(typ string) (parts []string, size int64, unsigned bool, err error) { switch parts = strings.FieldsFunc(typ, func(r rune) bool { return r == '(' || r == ')' || r == ' ' || r == ',' }); parts[0] { case "tinyint", "smallint", "mediumint", "int", "bigint": switch { case len(parts) == 2 && parts[1] == "unsigned": // int unsigned unsigned = true case len(parts) == 3: // int(10) unsigned unsigned = true fallthrough case len(parts) == 2: // int(10) size, err = strconv.ParseInt(parts[1], 10, 0) } case "varbinary", "varchar", "char", "binary": size, err = strconv.ParseInt(parts[1], 10, 64) } if err != nil { return parts, size, unsigned, fmt.Errorf("converting %s size to int: %v", parts[0], err) } return parts, size, unsigned, nil } // fkNames returns the foreign-key names of a column. func fkNames(ctx context.Context, tx dialect.Tx, table, column string) ([]string, error) { query, args := sql.Select("CONSTRAINT_NAME").From(sql.Table("KEY_COLUMN_USAGE").Schema("INFORMATION_SCHEMA")). Where(sql.And( sql.EQ("TABLE_NAME", table), sql.EQ("COLUMN_NAME", column), // NULL for unique and primary-key constraints. sql.NotNull("POSITION_IN_UNIQUE_CONSTRAINT"), sql.EQ("TABLE_SCHEMA", sql.Raw("(SELECT DATABASE())")), )). 
Query() var ( names []string rows = &sql.Rows{} ) if err := tx.Query(ctx, query, args, rows); err != nil { return nil, fmt.Errorf("mysql: reading constraint names %v", err) } defer rows.Close() if err := sql.ScanSlice(rows, &names); err != nil { return nil, err } return names, nil } ent-0.5.4/dialect/sql/schema/mysql_test.go000066400000000000000000001643311377533537200205340ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package schema import ( "context" "math" "regexp" "strings" "testing" "github.com/facebook/ent/dialect" "github.com/facebook/ent/dialect/entsql" "github.com/facebook/ent/dialect/sql" "github.com/facebook/ent/schema/field" "github.com/DATA-DOG/go-sqlmock" "github.com/stretchr/testify/require" ) func TestMySQL_Create(t *testing.T) { tests := []struct { name string tables []*Table options []MigrateOption before func(mysqlMock) wantErr bool }{ { name: "tx failed", before: func(mock mysqlMock) { mock.ExpectBegin().WillReturnError(sqlmock.ErrCancelled) }, wantErr: true, }, { name: "no tables", before: func(mock mysqlMock) { mock.start("5.7.23") mock.ExpectCommit() }, }, { name: "create new table", tables: []*Table{ { Name: "users", PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Nullable: true}, {Name: "age", Type: field.TypeInt}, {Name: "doc", Type: field.TypeJSON, Nullable: true}, {Name: "enums", Type: field.TypeEnum, Enums: []string{"a", "b"}}, {Name: "uuid", Type: field.TypeUUID, Nullable: true}, {Name: "datetime", Type: field.TypeTime, SchemaType: map[string]string{dialect.MySQL: "datetime"}, Nullable: true}, {Name: "decimal", Type: field.TypeFloat32, SchemaType: map[string]string{dialect.MySQL: "decimal(6,2)"}}, }, Annotation: 
&entsql.Annotation{ Charset: "utf8", Collation: "utf8_general_ci", Options: "ENGINE = INNODB", }, }, }, before: func(mock mysqlMock) { mock.start("5.7.8") mock.tableExists("users", false) mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `users`(`id` bigint AUTO_INCREMENT NOT NULL, `name` varchar(255) NULL, `age` bigint NOT NULL, `doc` json NULL, `enums` enum('a', 'b') NOT NULL, `uuid` char(36) binary NULL, `datetime` datetime NULL, `decimal` decimal(6,2) NOT NULL, PRIMARY KEY(`id`)) CHARACTER SET utf8 COLLATE utf8_general_ci ENGINE = INNODB")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "create new table 5.6", tables: []*Table{ { Name: "users", PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "age", Type: field.TypeInt}, {Name: "name", Type: field.TypeString, Unique: true}, {Name: "doc", Type: field.TypeJSON, Nullable: true}, }, }, }, before: func(mock mysqlMock) { mock.start("5.6.35") mock.tableExists("users", false) mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `users`(`id` bigint AUTO_INCREMENT NOT NULL, `age` bigint NOT NULL, `name` varchar(191) UNIQUE NOT NULL, `doc` longblob NULL, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). 
WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "create new table with foreign key", tables: func() []*Table { var ( c1 = []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Nullable: true}, {Name: "created_at", Type: field.TypeTime}, } c2 = []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString}, {Name: "owner_id", Type: field.TypeInt, Nullable: true}, } t1 = &Table{ Name: "users", Columns: c1, PrimaryKey: c1[0:1], } t2 = &Table{ Name: "pets", Columns: c2, PrimaryKey: c2[0:1], ForeignKeys: []*ForeignKey{ { Symbol: "pets_owner", Columns: c2[2:], RefTable: t1, RefColumns: c1[0:1], OnDelete: Cascade, }, }, } ) return []*Table{t1, t2} }(), before: func(mock mysqlMock) { mock.start("5.7.23") mock.tableExists("users", false) mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `users`(`id` bigint AUTO_INCREMENT NOT NULL, `name` varchar(255) NULL, `created_at` timestamp NULL, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.tableExists("pets", false) mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `pets`(`id` bigint AUTO_INCREMENT NOT NULL, `name` varchar(255) NOT NULL, `owner_id` bigint NULL, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.fkExists("pets_owner", false) mock.ExpectExec(escape("ALTER TABLE `pets` ADD CONSTRAINT `pets_owner` FOREIGN KEY(`owner_id`) REFERENCES `users`(`id`) ON DELETE CASCADE")). 
WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "create new table with foreign key disabled", options: []MigrateOption{ WithForeignKeys(false), }, tables: func() []*Table { var ( c1 = []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Nullable: true}, {Name: "created_at", Type: field.TypeTime}, } c2 = []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString}, {Name: "owner_id", Type: field.TypeInt, Nullable: true}, } t1 = &Table{ Name: "users", Columns: c1, PrimaryKey: c1[0:1], } t2 = &Table{ Name: "pets", Columns: c2, PrimaryKey: c2[0:1], ForeignKeys: []*ForeignKey{ { Symbol: "pets_owner", Columns: c2[2:], RefTable: t1, RefColumns: c1[0:1], OnDelete: Cascade, }, }, } ) return []*Table{t1, t2} }(), before: func(mock mysqlMock) { mock.start("5.7.23") mock.tableExists("users", false) mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `users`(`id` bigint AUTO_INCREMENT NOT NULL, `name` varchar(255) NULL, `created_at` timestamp NULL, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.tableExists("pets", false) mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `pets`(`id` bigint AUTO_INCREMENT NOT NULL, `name` varchar(255) NOT NULL, `owner_id` bigint NULL, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). 
WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "add column to table", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Nullable: true}, {Name: "text", Type: field.TypeString, Nullable: true, Size: math.MaxInt32}, {Name: "uuid", Type: field.TypeUUID, Nullable: true}, {Name: "date", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{dialect.MySQL: "date"}}, {Name: "age", Type: field.TypeInt}, {Name: "tiny", Type: field.TypeInt8}, {Name: "tiny_unsigned", Type: field.TypeUint8}, {Name: "small", Type: field.TypeInt16}, {Name: "small_unsigned", Type: field.TypeUint16}, {Name: "big", Type: field.TypeInt64}, {Name: "big_unsigned", Type: field.TypeUint64}, {Name: "decimal", Type: field.TypeFloat64, SchemaType: map[string]string{dialect.MySQL: "decimal(6,2)"}}, {Name: "timestamp", Type: field.TypeTime, SchemaType: map[string]string{dialect.MySQL: "TIMESTAMP"}}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock mysqlMock) { mock.start("8.0.19") mock.tableExists("users", true) mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name"}). AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", ""). AddRow("name", "varchar(255)", "YES", "YES", "NULL", "", "", ""). AddRow("text", "longtext", "YES", "YES", "NULL", "", "", ""). AddRow("uuid", "char(36)", "YES", "YES", "NULL", "", "", "utf8mb4_bin"). AddRow("date", "date", "YES", "YES", "NULL", "", "", ""). 
// 8.0.19: new int column type formats AddRow("tiny", "tinyint", "NO", "YES", "NULL", "", "", ""). AddRow("tiny_unsigned", "tinyint unsigned", "NO", "YES", "NULL", "", "", ""). AddRow("small", "smallint", "NO", "YES", "NULL", "", "", ""). AddRow("small_unsigned", "smallint unsigned", "NO", "YES", "NULL", "", "", ""). AddRow("big", "bigint", "NO", "YES", "NULL", "", "", ""). AddRow("big_unsigned", "bigint unsigned", "NO", "YES", "NULL", "", "", ""). AddRow("decimal", "decimal(6,2)", "NO", "YES", "NULL", "", "", ""). AddRow("timestamp", "timestamp", "NO", "NO", "CURRENT_TIMESTAMP", "DEFAULT_GENERATED on update CURRENT_TIMESTAMP", "", "")) mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "non_unique", "seq_in_index"}). AddRow("PRIMARY", "id", "0", "1")) mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `age` bigint NOT NULL")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "enums", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Nullable: true}, {Name: "enums1", Type: field.TypeEnum, Enums: []string{"a", "b"}}, // add enum. {Name: "enums2", Type: field.TypeEnum, Enums: []string{"a"}}, // remove enum. }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock mysqlMock) { mock.start("5.7.23") mock.tableExists("users", true) mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). WithArgs("users"). 
WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name"}). AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", ""). AddRow("name", "varchar(255)", "YES", "YES", "NULL", "", "", ""). AddRow("enums1", "enum('a')", "YES", "NO", "NULL", "", "", ""). AddRow("enums2", "enum('b', 'a')", "NO", "YES", "NULL", "", "", "")) mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "non_unique", "seq_in_index"}). AddRow("PRIMARY", "id", "0", "1")) mock.ExpectExec(escape("ALTER TABLE `users` MODIFY COLUMN `enums1` enum('a', 'b') NOT NULL, MODIFY COLUMN `enums2` enum('a') NOT NULL")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "datetime and timestamp", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{dialect.MySQL: "datetime"}, Nullable: true}, {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{dialect.MySQL: "datetime"}, Nullable: true}, {Name: "deleted_at", Type: field.TypeTime, Nullable: true}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock mysqlMock) { mock.start("5.7.23") mock.tableExists("users", true) mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). WithArgs("users"). 
WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name"}). AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", ""). AddRow("created_at", "datetime", "NO", "YES", "NULL", "", "", ""). AddRow("updated_at", "timestamp", "NO", "YES", "NULL", "", "", ""). AddRow("deleted_at", "datetime", "NO", "YES", "NULL", "", "", "")) mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "non_unique", "seq_in_index"}). AddRow("PRIMARY", "id", "0", "1")) mock.ExpectExec(escape("ALTER TABLE `users` MODIFY COLUMN `updated_at` datetime NULL, MODIFY COLUMN `deleted_at` timestamp NULL")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "add int column with default value to table", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Nullable: true}, {Name: "age", Type: field.TypeInt, Default: 10}, {Name: "doc", Type: field.TypeJSON, Nullable: true}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock mysqlMock) { mock.start("5.6.0") mock.tableExists("users", true) mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name"}). 
AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", ""). AddRow("name", "varchar(255)", "YES", "YES", "NULL", "", "", ""). AddRow("doc", "longblob", "YES", "YES", "NULL", "", "", "")) mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "non_unique", "seq_in_index"}). AddRow("PRIMARY", "id", "0", "1")) mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `age` bigint NOT NULL DEFAULT 10")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "add blob columns", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "tiny", Type: field.TypeBytes, Size: 100}, {Name: "blob", Type: field.TypeBytes, Size: 1e3}, {Name: "medium", Type: field.TypeBytes, Size: 1e5}, {Name: "long", Type: field.TypeBytes, Size: 1e8}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock mysqlMock) { mock.start("5.7.23") mock.tableExists("users", true) mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name"}). AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "")) mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? 
ORDER BY `index_name`, `seq_in_index`")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "non_unique", "seq_in_index"}). AddRow("PRIMARY", "id", "0", "1")) mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `tiny` tinyblob NOT NULL, ADD COLUMN `blob` blob NOT NULL, ADD COLUMN `medium` mediumblob NOT NULL, ADD COLUMN `long` longblob NOT NULL")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "add binary column", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "binary", Type: field.TypeBytes, Size: 20, SchemaType: map[string]string{dialect.MySQL: "binary(20)"}}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock mysqlMock) { mock.start("8.0.23") mock.tableExists("users", true) mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name"}). AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "")) mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "non_unique", "seq_in_index"}). AddRow("PRIMARY", "id", "0", "1")) mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `binary` binary(20) NOT NULL")). 
WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "accept varbinary columns", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "tiny", Type: field.TypeBytes, Size: 100}, {Name: "medium", Type: field.TypeBytes, Size: math.MaxUint32}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock mysqlMock) { mock.start("5.7.23") mock.tableExists("users", true) mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name"}). AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", ""). AddRow("tiny", "varbinary(255)", "NO", "YES", "NULL", "", "", ""). AddRow("medium", "varbinary(255)", "NO", "YES", "NULL", "", "", "")) mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "non_unique", "seq_in_index"}). AddRow("PRIMARY", "id", "0", "1")) mock.ExpectExec(escape("ALTER TABLE `users` MODIFY COLUMN `medium` longblob NOT NULL")). 
WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "add float column with default value to table", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Nullable: true}, {Name: "age", Type: field.TypeFloat64, Default: 10.1}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock mysqlMock) { mock.start("5.7.23") mock.tableExists("users", true) mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name"}). AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", ""). AddRow("name", "varchar(255)", "NO", "YES", "NULL", "", "", "")) mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "non_unique", "seq_in_index"}). AddRow("PRIMARY", "id", "0", "1")) mock.ExpectExec("ALTER TABLE `users` ADD COLUMN `age` double NOT NULL DEFAULT 10.1"). 
WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "add bool column with default value", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Nullable: true}, {Name: "age", Type: field.TypeBool, Default: true}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock mysqlMock) { mock.start("5.7.23") mock.tableExists("users", true) mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name"}). AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", ""). AddRow("name", "varchar(255)", "NO", "YES", "NULL", "", "", "")) mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "non_unique", "seq_in_index"}). AddRow("PRIMARY", "id", "0", "1")) mock.ExpectExec("ALTER TABLE `users` ADD COLUMN `age` boolean NOT NULL DEFAULT true"). 
WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "add string column with default value", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Nullable: true}, {Name: "nick", Type: field.TypeString, Default: "unknown"}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock mysqlMock) { mock.start("5.7.23") mock.tableExists("users", true) mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name"}). AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", ""). AddRow("name", "varchar(255)", "YES", "YES", "NULL", "", "", "")) mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "non_unique", "seq_in_index"}). AddRow("PRIMARY", "id", "0", "1")) mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `nick` varchar(255) NOT NULL DEFAULT 'unknown'")). 
WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "add column with unsupported default value", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Nullable: true}, {Name: "nick", Type: field.TypeString, Size: 1 << 17, Default: "unknown"}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock mysqlMock) { mock.start("5.7.23") mock.tableExists("users", true) mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name"}). AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", ""). AddRow("name", "varchar(255)", "YES", "YES", "NULL", "", "", "")) mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "non_unique", "seq_in_index"}). AddRow("PRIMARY", "id", "0", "1")) mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `nick` longtext NOT NULL")). 
WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "drop columns", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, options: []MigrateOption{WithDropColumn(true)}, before: func(mock mysqlMock) { mock.start("5.7.23") mock.tableExists("users", true) mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name"}). AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", ""). AddRow("name", "varchar(255)", "NO", "YES", "NULL", "", "", "")) mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "non_unique", "seq_in_index"}). AddRow("PRIMARY", "id", "0", "1")) mock.ExpectExec(escape("ALTER TABLE `users` DROP COLUMN `name`")). 
WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "modify column to nullable", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "age", Type: field.TypeInt}, {Name: "name", Type: field.TypeString, Nullable: true}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock mysqlMock) { mock.start("5.7.23") mock.tableExists("users", true) mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name"}). AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", ""). AddRow("name", "varchar(255)", "NO", "YES", "NULL", "", "", ""). AddRow("age", "bigint(20)", "NO", "NO", "NULL", "", "", "")) mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "non_unique", "seq_in_index"}). AddRow("PRIMARY", "id", "0", "1")) mock.ExpectExec(escape("ALTER TABLE `users` MODIFY COLUMN `name` varchar(255) NULL")). 
WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "apply uniqueness on column", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "age", Type: field.TypeInt, Unique: true}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock mysqlMock) { mock.start("5.7.23") mock.tableExists("users", true) mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name"}). AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", ""). AddRow("age", "bigint(20)", "NO", "", "NULL", "", "", "")) mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "non_unique", "seq_in_index"}). AddRow("PRIMARY", "id", "0", "1")) // create the unique index. mock.ExpectExec(escape("CREATE UNIQUE INDEX `age` ON `users`(`age`)")). 
WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "remove uniqueness from column without option", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "age", Type: field.TypeInt}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock mysqlMock) { mock.start("5.7.23") mock.tableExists("users", true) mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name"}). AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", ""). AddRow("age", "bigint(20)", "NO", "UNI", "NULL", "", "", "")) mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "non_unique", "seq_in_index"}). AddRow("PRIMARY", "id", "0", "1"). 
AddRow("age", "age", "0", "1")) mock.ExpectCommit() }, }, { name: "remove uniqueness from column with option", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "age", Type: field.TypeInt}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, options: []MigrateOption{WithDropIndex(true)}, before: func(mock mysqlMock) { mock.start("5.7.23") mock.tableExists("users", true) mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name"}). AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", ""). AddRow("age", "bigint(20)", "NO", "UNI", "NULL", "", "", "")) mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "non_unique", "seq_in_index"}). AddRow("PRIMARY", "id", "0", "1"). AddRow("age", "age", "0", "1")) // check if a foreign-key needs to be dropped. mock.ExpectQuery(escape("SELECT `CONSTRAINT_NAME` FROM `INFORMATION_SCHEMA`.`KEY_COLUMN_USAGE` WHERE `TABLE_NAME` = ? AND `COLUMN_NAME` = ? AND `POSITION_IN_UNIQUE_CONSTRAINT` IS NOT NULL AND `TABLE_SCHEMA` = (SELECT DATABASE())")). WithArgs("users", "age"). WillReturnRows(sqlmock.NewRows([]string{"CONSTRAINT_NAME"})) // drop the unique index. mock.ExpectExec(escape("DROP INDEX `age` ON `users`")). 
WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "ignore foreign keys on index dropping", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "parent_id", Type: field.TypeInt, Nullable: true}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, ForeignKeys: []*ForeignKey{ { Symbol: "parent_id", Columns: []*Column{ {Name: "parent_id", Type: field.TypeInt, Nullable: true}, }, }, }, }, }, options: []MigrateOption{WithDropIndex(true)}, before: func(mock mysqlMock) { mock.start("5.7.23") mock.tableExists("users", true) mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name"}). AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", ""). AddRow("parent_id", "bigint(20)", "YES", "NULL", "NULL", "", "", "")) mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "non_unique", "seq_in_index"}). AddRow("PRIMARY", "id", "0", "1"). AddRow("old_index", "old", "0", "1"). AddRow("parent_id", "parent_id", "0", "1")) // drop the unique index. mock.ExpectExec(escape("DROP INDEX `old_index` ON `users`")). WillReturnResult(sqlmock.NewResult(0, 1)) // foreign key already exist. 
mock.fkExists("parent_id", true) mock.ExpectCommit() }, }, { name: "drop foreign key with column and index", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, options: []MigrateOption{WithDropIndex(true), WithDropColumn(true)}, before: func(mock mysqlMock) { mock.start("5.7.23") mock.tableExists("users", true) mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name"}). AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", ""). AddRow("parent_id", "bigint(20)", "YES", "NULL", "NULL", "", "", "")) mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "non_unique", "seq_in_index"}). AddRow("PRIMARY", "id", "0", "1"). AddRow("parent_id", "parent_id", "0", "1")) // check if a foreign-key needs to be dropped. mock.ExpectQuery(escape("SELECT `CONSTRAINT_NAME` FROM `INFORMATION_SCHEMA`.`KEY_COLUMN_USAGE` WHERE `TABLE_NAME` = ? AND `COLUMN_NAME` = ? AND `POSITION_IN_UNIQUE_CONSTRAINT` IS NOT NULL AND `TABLE_SCHEMA` = (SELECT DATABASE())")). WithArgs("users", "parent_id"). WillReturnRows(sqlmock.NewRows([]string{"CONSTRAINT_NAME"}).AddRow("users_parent_id")) mock.ExpectExec(escape("ALTER TABLE `users` DROP FOREIGN KEY `users_parent_id`")). 
WillReturnResult(sqlmock.NewResult(0, 1)) // drop the unique index. mock.ExpectExec(escape("DROP INDEX `parent_id` ON `users`")). WillReturnResult(sqlmock.NewResult(0, 1)) // drop the unique index. mock.ExpectExec(escape("ALTER TABLE `users` DROP COLUMN `parent_id`")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "create a new simple-index for the foreign-key", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "parent_id", Type: field.TypeInt, Nullable: true}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, options: []MigrateOption{WithDropIndex(true), WithDropColumn(true)}, before: func(mock mysqlMock) { mock.start("5.7.23") mock.tableExists("users", true) mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name"}). AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", ""). AddRow("parent_id", "bigint(20)", "YES", "NULL", "NULL", "", "", "")) mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "non_unique", "seq_in_index"}). AddRow("PRIMARY", "id", "0", "1"). AddRow("parent_id", "parent_id", "0", "1")) // check if there's a foreign-key that is associated with this index. 
mock.ExpectQuery(escape("SELECT `CONSTRAINT_NAME` FROM `INFORMATION_SCHEMA`.`KEY_COLUMN_USAGE` WHERE `TABLE_NAME` = ? AND `COLUMN_NAME` = ? AND `POSITION_IN_UNIQUE_CONSTRAINT` IS NOT NULL AND `TABLE_SCHEMA` = (SELECT DATABASE())")). WithArgs("users", "parent_id"). WillReturnRows(sqlmock.NewRows([]string{"CONSTRAINT_NAME"}).AddRow("users_parent_id")) // create a new index, to replace the old one (that needs to be dropped). mock.ExpectExec(escape("CREATE INDEX `users_parent_id` ON `users`(`parent_id`)")). WillReturnResult(sqlmock.NewResult(0, 1)) // drop the unique index. mock.ExpectExec(escape("DROP INDEX `parent_id` ON `users`")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "add edge to table", tables: func() []*Table { var ( c1 = []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Nullable: true}, {Name: "spouse_id", Type: field.TypeInt, Nullable: true}, } t1 = &Table{ Name: "users", Columns: c1, PrimaryKey: c1[0:1], ForeignKeys: []*ForeignKey{ { Symbol: "user_spouse" + strings.Repeat("_", 64), // super long fk. Columns: c1[2:], RefColumns: c1[0:1], OnDelete: Cascade, }, }, } ) t1.ForeignKeys[0].RefTable = t1 return []*Table{t1} }(), before: func(mock mysqlMock) { mock.start("5.7.23") mock.tableExists("users", true) mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name"}). AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", ""). 
AddRow("name", "varchar(255)", "YES", "YES", "NULL", "", "", "")) mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "non_unique", "seq_in_index"}). AddRow("PRIMARY", "id", "0", "1")) mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `spouse_id` bigint NULL")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.fkExists("user_spouse_____________________390ed76f91d3c57cd3516e7690f621dc", false) mock.ExpectExec("ALTER TABLE `users` ADD CONSTRAINT `.{64}` FOREIGN KEY\\(`spouse_id`\\) REFERENCES `users`\\(`id`\\) ON DELETE CASCADE"). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "universal id for all tables", tables: []*Table{ NewTable("users").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}), NewTable("groups").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}), }, options: []MigrateOption{WithGlobalUniqueID(true)}, before: func(mock mysqlMock) { mock.start("5.7.23") mock.tableExists("ent_types", false) // create ent_types table. mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `ent_types`(`id` bigint unsigned AUTO_INCREMENT NOT NULL, `type` varchar(255) UNIQUE NOT NULL, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.tableExists("users", false) mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `users`(`id` bigint AUTO_INCREMENT NOT NULL, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). WillReturnResult(sqlmock.NewResult(0, 1)) // set users id range. mock.ExpectExec(escape("INSERT INTO `ent_types` (`type`) VALUES (?)")). WithArgs("users"). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectExec(escape("ALTER TABLE `users` AUTO_INCREMENT = 0")). 
WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectQuery(escape("SELECT COUNT(*) FROM `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). WithArgs("groups"). WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0)) mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `groups`(`id` bigint AUTO_INCREMENT NOT NULL, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). WillReturnResult(sqlmock.NewResult(0, 1)) // set groups id range. mock.ExpectExec(escape("INSERT INTO `ent_types` (`type`) VALUES (?)")). WithArgs("groups"). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectExec(escape("ALTER TABLE `groups` AUTO_INCREMENT = 4294967296")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "universal id for new tables", tables: []*Table{ NewTable("users").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}), NewTable("groups").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}), }, options: []MigrateOption{WithGlobalUniqueID(true)}, before: func(mock mysqlMock) { mock.start("5.7.23") mock.tableExists("ent_types", true) // query ent_types table. mock.ExpectQuery(escape("SELECT `type` FROM `ent_types` ORDER BY `id` ASC")). WillReturnRows(sqlmock.NewRows([]string{"type"}).AddRow("users")) mock.tableExists("users", true) // users table has no changes. mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name"}). 
AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "")) mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "non_unique", "seq_in_index"}). AddRow("PRIMARY", "id", "0", "1")) // query groups table. mock.ExpectQuery(escape("SELECT COUNT(*) FROM `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). WithArgs("groups"). WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0)) mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `groups`(`id` bigint AUTO_INCREMENT NOT NULL, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). WillReturnResult(sqlmock.NewResult(0, 1)) // set groups id range. mock.ExpectExec(escape("INSERT INTO `ent_types` (`type`) VALUES (?)")). WithArgs("groups"). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectExec(escape("ALTER TABLE `groups` AUTO_INCREMENT = 4294967296")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "universal id for restored tables", tables: []*Table{ NewTable("users").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}), NewTable("groups").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}), }, options: []MigrateOption{WithGlobalUniqueID(true)}, before: func(mock mysqlMock) { mock.start("5.7.23") mock.tableExists("ent_types", true) // query ent_types table. mock.ExpectQuery(escape("SELECT `type` FROM `ent_types` ORDER BY `id` ASC")). WillReturnRows(sqlmock.NewRows([]string{"type"}).AddRow("users")) mock.tableExists("users", false) mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `users`(`id` bigint AUTO_INCREMENT NOT NULL, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). 
WillReturnResult(sqlmock.NewResult(0, 1)) // set users id range (without inserting to ent_types). mock.ExpectExec(escape("ALTER TABLE `users` AUTO_INCREMENT = 0")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.tableExists("groups", false) mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `groups`(`id` bigint AUTO_INCREMENT NOT NULL, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). WillReturnResult(sqlmock.NewResult(0, 1)) // set groups id range. mock.ExpectExec(escape("INSERT INTO `ent_types` (`type`) VALUES (?)")). WithArgs("groups"). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectExec(escape("ALTER TABLE `groups` AUTO_INCREMENT = 4294967296")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "universal id mismatch with ent_types", tables: []*Table{ NewTable("users").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}), }, options: []MigrateOption{WithGlobalUniqueID(true)}, before: func(mock mysqlMock) { mock.start("5.7.23") mock.tableExists("ent_types", true) // query ent_types table. mock.ExpectQuery(escape("SELECT `type` FROM `ent_types` ORDER BY `id` ASC")). WillReturnRows(sqlmock.NewRows([]string{"type"}). AddRow("deleted"). AddRow("users")) mock.tableExists("users", true) // users table has no changes. mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name"}). 
AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "")) mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "non_unique", "seq_in_index"}). AddRow("PRIMARY", "id", "0", "1")) // query the auto-increment value. mock.ExpectQuery(escape("SELECT `AUTO_INCREMENT` FROM `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"AUTO_INCREMENT"}). AddRow(1)) // restore the auto-increment counter. mock.ExpectExec(escape("ALTER TABLE `users` AUTO_INCREMENT = 4294967296")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, // MariaDB specific tests. { name: "mariadb/10.2.32/create table", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "json", Type: field.TypeJSON, Nullable: true}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock mysqlMock) { mock.start("10.2.32-MariaDB") mock.tableExists("users", false) mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `users`(`id` bigint AUTO_INCREMENT NOT NULL, `json` json NULL CHECK (JSON_VALID(`json`)), PRIMARY KEY(`id`)) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). 
WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "mariadb/10.5.8/create table", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "json", Type: field.TypeJSON, Nullable: true}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock mysqlMock) { mock.start("10.5.8-MariaDB") mock.tableExists("users", false) mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `users`(`id` bigint AUTO_INCREMENT NOT NULL, `json` json NULL, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "mariadb/10.5.8/table exists", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Nullable: true}, {Name: "json", Type: field.TypeJSON, Nullable: true}, {Name: "longtext", Type: field.TypeString, Nullable: true, Size: math.MaxInt32}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock mysqlMock) { mock.start("10.5.8-MariaDB-1:10.5.8+maria~focal") mock.tableExists("users", true) mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name"}). AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", ""). AddRow("name", "varchar(255)", "YES", "YES", "NULL", "", "", ""). AddRow("json", "longtext", "YES", "YES", "NULL", "", "utf8mb4", "utf8mb4_bin"). 
AddRow("longtext", "longtext", "YES", "YES", "NULL", "", "utf8mb4", "utf8mb4_bin")) mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "non_unique", "seq_in_index"}). AddRow("PRIMARY", "id", "0", "1")) mock.ExpectQuery(escape("SELECT `CONSTRAINT_NAME`, `CHECK_CLAUSE` FROM `INFORMATION_SCHEMA`.`CHECK_CONSTRAINTS` WHERE `CONSTRAINT_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? AND `CONSTRAINT_NAME` IN (?, ?)")). WithArgs("users", "json", "longtext"). WillReturnRows(sqlmock.NewRows([]string{"CONSTRAINT_NAME", "CHECK_CLAUSE"}). AddRow("json", "json_valid(`json`)")) mock.ExpectCommit() }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { db, mock, err := sqlmock.New() require.NoError(t, err) tt.before(mysqlMock{mock}) migrate, err := NewMigrate(sql.OpenDB("mysql", db), tt.options...) require.NoError(t, err) err = migrate.Create(context.Background(), tt.tables...) require.Equal(t, tt.wantErr, err != nil, err) }) } } type mysqlMock struct { sqlmock.Sqlmock } func (m mysqlMock) start(version string) { m.ExpectBegin() m.ExpectQuery(escape("SHOW VARIABLES LIKE 'version'")). WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).AddRow("version", version)) } func (m mysqlMock) tableExists(table string, exists bool) { count := 0 if exists { count = 1 } m.ExpectQuery(escape("SELECT COUNT(*) FROM `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). WithArgs(table). 
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(count)) } func (m mysqlMock) fkExists(fk string, exists bool) { count := 0 if exists { count = 1 } m.ExpectQuery(escape("SELECT COUNT(*) FROM `INFORMATION_SCHEMA`.`TABLE_CONSTRAINTS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `CONSTRAINT_TYPE` = ? AND `CONSTRAINT_NAME` = ?")). WithArgs("FOREIGN KEY", fk). WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(count)) } func escape(query string) string { rows := strings.Split(query, "\n") for i := range rows { rows[i] = strings.TrimPrefix(rows[i], " ") } query = strings.Join(rows, " ") return strings.TrimSpace(regexp.QuoteMeta(query)) + "$" } ent-0.5.4/dialect/sql/schema/postgres.go000066400000000000000000000344371377533537200202010ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package schema import ( "context" "fmt" "strings" "github.com/facebook/ent/dialect" "github.com/facebook/ent/dialect/sql" "github.com/facebook/ent/schema/field" ) // Postgres is a postgres migration driver. type Postgres struct { dialect.Driver version string } // init loads the Postgres version from the database for later use in the migration process. // It returns an error if the server version is lower than v10. 
func (d *Postgres) init(ctx context.Context, tx dialect.Tx) error { rows := &sql.Rows{} if err := tx.Query(ctx, "SHOW server_version_num", []interface{}{}, rows); err != nil { return fmt.Errorf("querying server version %v", err) } defer rows.Close() if !rows.Next() { if err := rows.Err(); err != nil { return err } return fmt.Errorf("server_version_num variable was not found") } var version string if err := rows.Scan(&version); err != nil { return fmt.Errorf("scanning version: %v", err) } if len(version) < 6 { return fmt.Errorf("malformed version: %s", version) } d.version = fmt.Sprintf("%s.%s.%s", version[:2], version[2:4], version[4:]) if compareVersions(d.version, "10.0.0") == -1 { return fmt.Errorf("unsupported postgres version: %s", d.version) } return nil } // tableExist checks if a table exists in the database and current schema. func (d *Postgres) tableExist(ctx context.Context, tx dialect.Tx, name string) (bool, error) { query, args := sql.Dialect(dialect.Postgres). Select(sql.Count("*")).From(sql.Table("tables").Schema("information_schema")). Where(sql.And( sql.EQ("table_schema", sql.Raw("CURRENT_SCHEMA()")), sql.EQ("table_name", name), )).Query() return exist(ctx, tx, query, args...) } // tableExist checks if a foreign-key exists in the current schema. func (d *Postgres) fkExist(ctx context.Context, tx dialect.Tx, name string) (bool, error) { query, args := sql.Dialect(dialect.Postgres). Select(sql.Count("*")).From(sql.Table("table_constraints").Schema("information_schema")). Where(sql.And( sql.EQ("table_schema", sql.Raw("CURRENT_SCHEMA()")), sql.EQ("constraint_type", "FOREIGN KEY"), sql.EQ("constraint_name", name), )).Query() return exist(ctx, tx, query, args...) } // setRange sets restart the identity column to the given offset. Used by the universal-id option. func (d *Postgres) setRange(ctx context.Context, tx dialect.Tx, t *Table, value int) error { if value == 0 { value = 1 // RESTART value cannot be < 1. 
} pk := "id" if len(t.PrimaryKey) == 1 { pk = t.PrimaryKey[0].Name } return tx.Exec(ctx, fmt.Sprintf("ALTER TABLE %s ALTER COLUMN %s RESTART WITH %d", t.Name, pk, value), []interface{}{}, nil) } // table loads the current table description from the database. func (d *Postgres) table(ctx context.Context, tx dialect.Tx, name string) (*Table, error) { rows := &sql.Rows{} query, args := sql.Dialect(dialect.Postgres). Select("column_name", "data_type", "is_nullable", "column_default", "udt_name"). From(sql.Table("columns").Schema("information_schema")). Where(sql.And( sql.EQ("table_schema", sql.Raw("CURRENT_SCHEMA()")), sql.EQ("table_name", name), )).Query() if err := tx.Query(ctx, query, args, rows); err != nil { return nil, fmt.Errorf("postgres: reading table description %v", err) } // Call `Close` in cases of failures (`Close` is idempotent). defer rows.Close() t := NewTable(name) for rows.Next() { c := &Column{} if err := d.scanColumn(c, rows); err != nil { return nil, err } t.AddColumn(c) } if err := rows.Err(); err != nil { return nil, err } if err := rows.Close(); err != nil { return nil, fmt.Errorf("closing rows %v", err) } idxs, err := d.indexes(ctx, tx, name) if err != nil { return nil, err } // Populate the index information to the table and its columns. // We do it manually, because PK and uniqueness information does // not exist when querying the information_schema.COLUMNS above. 
for _, idx := range idxs { switch { case idx.primary: for _, name := range idx.columns { c, ok := t.column(name) if !ok { return nil, fmt.Errorf("index %q column %q was not found in table %q", idx.Name, name, t.Name) } c.Key = PrimaryKey t.PrimaryKey = append(t.PrimaryKey, c) } case idx.Unique && len(idx.columns) == 1: name := idx.columns[0] c, ok := t.column(name) if !ok { return nil, fmt.Errorf("index %q column %q was not found in table %q", idx.Name, name, t.Name) } c.Key = UniqueKey c.Unique = true fallthrough default: t.addIndex(idx) } } return t, nil } // indexesQuery holds a query format for retrieving // table indexes of the current schema. const indexesQuery = ` SELECT i.relname AS index_name, a.attname AS column_name, idx.indisprimary AS primary, idx.indisunique AS unique, array_position(idx.indkey, a.attnum) as seq_in_index FROM pg_class t, pg_class i, pg_index idx, pg_attribute a, pg_namespace n WHERE t.oid = idx.indrelid AND i.oid = idx.indexrelid AND n.oid = t.relnamespace AND a.attrelid = t.oid AND a.attnum = ANY(idx.indkey) AND t.relkind = 'r' AND n.nspname = CURRENT_SCHEMA() AND t.relname = '%s' ORDER BY index_name, seq_in_index; ` func (d *Postgres) indexes(ctx context.Context, tx dialect.Tx, table string) (Indexes, error) { rows := &sql.Rows{} if err := tx.Query(ctx, fmt.Sprintf(indexesQuery, table), []interface{}{}, rows); err != nil { return nil, fmt.Errorf("querying indexes for table %s: %v", table, err) } defer rows.Close() var ( idxs Indexes names = make(map[string]*Index) ) for rows.Next() { var ( seqindex int name, column string unique, primary bool ) if err := rows.Scan(&name, &column, &primary, &unique, &seqindex); err != nil { return nil, fmt.Errorf("scanning index description: %v", err) } // If the index is prefixed with the table, it may was added by // `addIndex` and it should be trimmed. 
But, since entc prefixes // all indexes with schema-type, for uncountable types (like, media // or equipment) this isn't correct, and we fallback for the real-name. short := strings.TrimPrefix(name, table+"_") idx, ok := names[short] if !ok { idx = &Index{Name: short, Unique: unique, primary: primary, realname: name} idxs = append(idxs, idx) names[short] = idx } idx.columns = append(idx.columns, column) } if err := rows.Err(); err != nil { return nil, err } return idxs, nil } // maxCharSize defines the maximum size of limited character types in Postgres (10 MB). const maxCharSize = 10 << 20 // scanColumn scans the information a column from column description. func (d *Postgres) scanColumn(c *Column, rows *sql.Rows) error { var ( nullable sql.NullString defaults sql.NullString udt sql.NullString ) if err := rows.Scan(&c.Name, &c.typ, &nullable, &defaults, &udt); err != nil { return fmt.Errorf("scanning column description: %v", err) } if nullable.Valid { c.Nullable = nullable.String == "YES" } switch c.typ { case "boolean": c.Type = field.TypeBool case "smallint": c.Type = field.TypeInt16 case "integer": c.Type = field.TypeInt32 case "bigint": c.Type = field.TypeInt64 case "real": c.Type = field.TypeFloat32 case "numeric", "decimal", "double precision": c.Type = field.TypeFloat64 case "text": c.Type = field.TypeString c.Size = maxCharSize + 1 case "character", "character varying": c.Type = field.TypeString case "date", "time", "timestamp", "timestamp with time zone", "timestamp without time zone": c.Type = field.TypeTime case "bytea": c.Type = field.TypeBytes case "jsonb": c.Type = field.TypeJSON case "uuid": c.Type = field.TypeUUID case "cidr", "inet", "macaddr", "macaddr8": c.Type = field.TypeOther case "USER-DEFINED": c.Type = field.TypeOther if !udt.Valid { return fmt.Errorf("missing user defined type for column %q", c.Name) } c.SchemaType = map[string]string{dialect.Postgres: udt.String} } switch { case !defaults.Valid || c.Type == field.TypeTime || 
seqfunc(defaults.String): return nil case strings.Contains(defaults.String, "::"): parts := strings.Split(defaults.String, "::") defaults.String = strings.Trim(parts[0], "'") fallthrough default: return c.ScanDefault(defaults.String) } } // tBuilder returns the TableBuilder for the given table. func (d *Postgres) tBuilder(t *Table) *sql.TableBuilder { b := sql.Dialect(dialect.Postgres). CreateTable(t.Name).IfNotExists() for _, c := range t.Columns { b.Column(d.addColumn(c)) } for _, pk := range t.PrimaryKey { b.PrimaryKey(pk.Name) } return b } // cType returns the PostgreSQL string type for this column. func (d *Postgres) cType(c *Column) (t string) { if c.SchemaType != nil && c.SchemaType[dialect.Postgres] != "" { return c.SchemaType[dialect.Postgres] } switch c.Type { case field.TypeBool: t = "boolean" case field.TypeUint8, field.TypeInt8, field.TypeInt16, field.TypeUint16: t = "smallint" case field.TypeInt32, field.TypeUint32: t = "int" case field.TypeInt, field.TypeUint, field.TypeInt64, field.TypeUint64: t = "bigint" case field.TypeFloat32: t = c.scanTypeOr("real") case field.TypeFloat64: t = c.scanTypeOr("double precision") case field.TypeBytes: t = "bytea" case field.TypeJSON: t = "jsonb" case field.TypeUUID: t = "uuid" case field.TypeString: t = "varchar" if c.Size > maxCharSize { t = "text" } case field.TypeTime: t = c.scanTypeOr("timestamp with time zone") case field.TypeEnum: // Currently, the support for enums is weak (application level only. // like SQLite). Dialect needs to create and maintain its enum type. t = "varchar" case field.TypeOther: t = c.typ default: panic(fmt.Sprintf("unsupported type %q for column %q", c.Type.String(), c.Name)) } return t } // addColumn returns the ColumnBuilder for adding the given column to a table. func (d *Postgres) addColumn(c *Column) *sql.ColumnBuilder { b := sql.Dialect(dialect.Postgres). 
Column(c.Name).Type(d.cType(c)).Attr(c.Attr) c.unique(b) if c.Increment { b.Attr("GENERATED BY DEFAULT AS IDENTITY") } c.nullable(b) c.defaultValue(b) return b } // alterColumn returns list of ColumnBuilder for applying in order to alter a column. func (d *Postgres) alterColumn(c *Column) (ops []*sql.ColumnBuilder) { b := sql.Dialect(dialect.Postgres) ops = append(ops, b.Column(c.Name).Type(d.cType(c))) if c.Nullable { ops = append(ops, b.Column(c.Name).Attr("DROP NOT NULL")) } else { ops = append(ops, b.Column(c.Name).Attr("SET NOT NULL")) } return ops } // hasUniqueName reports if the index has a unique name in the schema. func hasUniqueName(i *Index) bool { name := i.Name // The "_key" suffix is added by Postgres for implicit indexes. if strings.HasSuffix(name, "_key") { name = strings.TrimSuffix(name, "_key") } suffix := strings.Join(i.columnNames(), "_") if !strings.HasSuffix(name, suffix) { return true // Assume it has a custom storage-key. } // The codegen prefixes by default indexes with the type name. // For example, an index "users"("name"), will named as "user_name". return name != suffix } // addIndex returns the querying for adding an index to PostgreSQL. func (d *Postgres) addIndex(i *Index, table string) *sql.IndexBuilder { name := i.Name if !hasUniqueName(i) { // Since index name should be unique in pg_class for schema, // we prefix it with the table name and remove on read. name = fmt.Sprintf("%s_%s", table, i.Name) } idx := sql.Dialect(dialect.Postgres). CreateIndex(name).Table(table) if i.Unique { idx.Unique() } for _, c := range i.Columns { idx.Column(c.Name) } return idx } // dropIndex drops a Postgres index. func (d *Postgres) dropIndex(ctx context.Context, tx dialect.Tx, idx *Index, table string) error { name := idx.Name build := sql.Dialect(dialect.Postgres) if prefix := table + "_"; !strings.HasPrefix(name, prefix) && !hasUniqueName(idx) { name = prefix + name } query, args := sql.Dialect(dialect.Postgres). 
Select(sql.Count("*")).From(sql.Table("table_constraints").Schema("information_schema")). Where(sql.And( sql.EQ("table_schema", sql.Raw("CURRENT_SCHEMA()")), sql.EQ("constraint_type", "UNIQUE"), sql.EQ("constraint_name", name), )). Query() exists, err := exist(ctx, tx, query, args...) if err != nil { return err } query, args = build.DropIndex(name).Query() if exists { query, args = build.AlterTable(table).DropConstraint(name).Query() } return tx.Exec(ctx, query, args, nil) } // isImplicitIndex reports if the index was created implicitly for the unique column. func (d *Postgres) isImplicitIndex(idx *Index, col *Column) bool { return strings.TrimSuffix(idx.Name, "_key") == col.Name && col.Unique } // renameColumn returns the statement for renaming a column. func (d *Postgres) renameColumn(t *Table, old, new *Column) sql.Querier { return sql.Dialect(dialect.Postgres). AlterTable(t.Name). RenameColumn(old.Name, new.Name) } // renameIndex returns the statement for renaming an index. func (d *Postgres) renameIndex(t *Table, old, new *Index) sql.Querier { if sfx := "_key"; strings.HasSuffix(old.Name, sfx) && !strings.HasSuffix(new.Name, sfx) { new.Name += sfx } if pfx := t.Name + "_"; strings.HasPrefix(old.realname, pfx) && !strings.HasPrefix(new.Name, pfx) { new.Name = pfx + new.Name } return sql.Dialect(dialect.Postgres).AlterIndex(old.realname).Rename(new.Name) } // tableSchema returns the query for getting the table schema. func (d *Postgres) tableSchema() sql.Querier { return sql.Raw("(CURRENT_SCHEMA())") } // alterColumns returns the queries for applying the columns change-set. func (d *Postgres) alterColumns(table string, add, modify, drop []*Column) sql.Queries { b := sql.Dialect(dialect.Postgres).AlterTable(table) for _, c := range add { b.AddColumn(d.addColumn(c)) } for _, c := range modify { b.ModifyColumns(d.alterColumn(c)...) 
} for _, c := range drop { b.DropColumn(sql.Dialect(dialect.Postgres).Column(c.Name)) } if len(b.Queries) == 0 { return nil } return sql.Queries{b} } // seqfunc reports if the given string is a sequence function. func seqfunc(defaults string) bool { for _, fn := range [...]string{"currval", "lastval", "setval", "nextval"} { if strings.HasPrefix(defaults, fn+"(") && strings.HasSuffix(defaults, ")") { return true } } return false } ent-0.5.4/dialect/sql/schema/postgres_test.go000066400000000000000000001101501377533537200212230ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package schema import ( "context" "fmt" "math" "strings" "testing" "github.com/facebook/ent/dialect" "github.com/facebook/ent/dialect/sql" "github.com/facebook/ent/schema/field" "github.com/DATA-DOG/go-sqlmock" "github.com/stretchr/testify/require" ) func TestPostgres_Create(t *testing.T) { tests := []struct { name string tables []*Table options []MigrateOption before func(pgMock) wantErr bool }{ { name: "tx failed", before: func(mock pgMock) { mock.ExpectBegin().WillReturnError(sqlmock.ErrCancelled) }, wantErr: true, }, { name: "unsupported version", before: func(mock pgMock) { mock.start("90000") }, wantErr: true, }, { name: "no tables", before: func(mock pgMock) { mock.start("120000") mock.ExpectCommit() }, }, { name: "create new table", tables: []*Table{ { Name: "users", PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Nullable: true}, {Name: "age", Type: field.TypeInt}, {Name: "doc", Type: field.TypeJSON, Nullable: true}, {Name: "enums", Type: field.TypeEnum, Enums: []string{"a", "b"}, Default: "a"}, {Name: "uuid", Type: field.TypeUUID}, {Name: "price", Type: field.TypeFloat64, 
SchemaType: map[string]string{dialect.Postgres: "numeric(5,2)"}}, }, }, }, before: func(mock pgMock) { mock.start("120000") mock.tableExists("users", false) mock.ExpectExec(escape(`CREATE TABLE IF NOT EXISTS "users"("id" bigint GENERATED BY DEFAULT AS IDENTITY NOT NULL, "name" varchar NULL, "age" bigint NOT NULL, "doc" jsonb NULL, "enums" varchar NOT NULL DEFAULT 'a', "uuid" uuid NOT NULL, "price" numeric(5,2) NOT NULL, PRIMARY KEY("id"))`)). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "create new table with foreign key", tables: func() []*Table { var ( c1 = []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Nullable: true}, {Name: "created_at", Type: field.TypeTime}, {Name: "inet", Type: field.TypeString, Unique: true, SchemaType: map[string]string{dialect.Postgres: "inet"}}, } c2 = []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString}, {Name: "owner_id", Type: field.TypeInt, Nullable: true}, } t1 = &Table{ Name: "users", Columns: c1, PrimaryKey: c1[0:1], } t2 = &Table{ Name: "pets", Columns: c2, PrimaryKey: c2[0:1], ForeignKeys: []*ForeignKey{ { Symbol: "pets_owner", Columns: c2[2:], RefTable: t1, RefColumns: c1[0:1], OnDelete: Cascade, }, }, } ) return []*Table{t1, t2} }(), before: func(mock pgMock) { mock.start("120000") mock.tableExists("users", false) mock.ExpectExec(escape(`CREATE TABLE IF NOT EXISTS "users"("id" bigint GENERATED BY DEFAULT AS IDENTITY NOT NULL, "name" varchar NULL, "created_at" timestamp with time zone NOT NULL, "inet" inet UNIQUE NOT NULL, PRIMARY KEY("id"))`)). WillReturnResult(sqlmock.NewResult(0, 1)) mock.tableExists("pets", false) mock.ExpectExec(escape(`CREATE TABLE IF NOT EXISTS "pets"("id" bigint GENERATED BY DEFAULT AS IDENTITY NOT NULL, "name" varchar NOT NULL, "owner_id" bigint NULL, PRIMARY KEY("id"))`)). 
WillReturnResult(sqlmock.NewResult(0, 1)) mock.fkExists("pets_owner", false) mock.ExpectExec(escape(`ALTER TABLE "pets" ADD CONSTRAINT "pets_owner" FOREIGN KEY("owner_id") REFERENCES "users"("id") ON DELETE CASCADE`)). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "create new table with foreign key disabled", options: []MigrateOption{ WithForeignKeys(false), }, tables: func() []*Table { var ( c1 = []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Nullable: true}, {Name: "created_at", Type: field.TypeTime}, } c2 = []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString}, {Name: "owner_id", Type: field.TypeInt, Nullable: true}, } t1 = &Table{ Name: "users", Columns: c1, PrimaryKey: c1[0:1], } t2 = &Table{ Name: "pets", Columns: c2, PrimaryKey: c2[0:1], ForeignKeys: []*ForeignKey{ { Symbol: "pets_owner", Columns: c2[2:], RefTable: t1, RefColumns: c1[0:1], OnDelete: Cascade, }, }, } ) return []*Table{t1, t2} }(), before: func(mock pgMock) { mock.start("120000") mock.tableExists("users", false) mock.ExpectExec(escape(`CREATE TABLE IF NOT EXISTS "users"("id" bigint GENERATED BY DEFAULT AS IDENTITY NOT NULL, "name" varchar NULL, "created_at" timestamp with time zone NOT NULL, PRIMARY KEY("id"))`)). WillReturnResult(sqlmock.NewResult(0, 1)) mock.tableExists("pets", false) mock.ExpectExec(escape(`CREATE TABLE IF NOT EXISTS "pets"("id" bigint GENERATED BY DEFAULT AS IDENTITY NOT NULL, "name" varchar NOT NULL, "owner_id" bigint NULL, PRIMARY KEY("id"))`)). 
WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "scan table with default set to serial", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock pgMock) { mock.start("120000") mock.tableExists("users", true) mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name"}). AddRow("id", "bigint", "NO", "nextval('users_colname_seq'::regclass)", "int4")) mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "users"))). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). AddRow("users_pkey", "id", "t", "t", 0)) mock.ExpectCommit() }, }, { name: "scan table with custom type", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "custom", Type: field.TypeOther, SchemaType: map[string]string{dialect.Postgres: "customtype"}}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock pgMock) { mock.start("120000") mock.tableExists("users", true) mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name"}). AddRow("id", "bigint", "NO", "nextval('users_colname_seq'::regclass)", "NULL"). 
AddRow("custom", "USER-DEFINED", "NO", "NULL", "customtype")) mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "users"))). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). AddRow("users_pkey", "id", "t", "t", 0)) mock.ExpectCommit() }, }, { name: "add column to table", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Nullable: true}, {Name: "uuid", Type: field.TypeUUID, Nullable: true}, {Name: "text", Type: field.TypeString, Nullable: true, Size: math.MaxInt32}, {Name: "age", Type: field.TypeInt}, {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{dialect.Postgres: "date"}}, {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{dialect.MySQL: "date"}, Nullable: true}, {Name: "deleted_at", Type: field.TypeTime, Nullable: true}, {Name: "cidr", Type: field.TypeString, SchemaType: map[string]string{dialect.Postgres: "cidr"}}, {Name: "inet", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{dialect.Postgres: "inet"}}, {Name: "macaddr", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{dialect.Postgres: "macaddr"}}, {Name: "macaddr8", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{dialect.Postgres: "macaddr8"}}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock pgMock) { mock.start("120000") mock.tableExists("users", true) mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name"}). AddRow("id", "bigint", "NO", "NULL", "int8"). 
AddRow("name", "character varying", "YES", "NULL", "varchar"). AddRow("uuid", "uuid", "YES", "NULL", "uuid"). AddRow("created_at", "date", "NO", "CURRENT_DATE", "date"). AddRow("updated_at", "timestamp", "YES", "NULL", "timestamptz"). AddRow("deleted_at", "date", "YES", "NULL", "date"). AddRow("text", "text", "YES", "NULL", "text"). AddRow("cidr", "cidr", "NO", "NULL", "cidr"). AddRow("inet", "inet", "YES", "NULL", "inet"). AddRow("macaddr", "macaddr", "YES", "NULL", "macaddr"). AddRow("macaddr8", "macaddr8", "YES", "NULL", "macaddr8")) mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "users"))). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). AddRow("users_pkey", "id", "t", "t", 0)) mock.ExpectExec(escape(`ALTER TABLE "users" ADD COLUMN "age" bigint NOT NULL, ALTER COLUMN "updated_at" TYPE timestamp with time zone, ALTER COLUMN "updated_at" DROP NOT NULL, ALTER COLUMN "deleted_at" TYPE timestamp with time zone, ALTER COLUMN "deleted_at" DROP NOT NULL`)). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "add int column with default value to table", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Nullable: true}, {Name: "age", Type: field.TypeInt, Default: 10}, {Name: "doc", Type: field.TypeJSON, Nullable: true}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock pgMock) { mock.start("120000") mock.tableExists("users", true) mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name"}). AddRow("id", "bigint", "NO", "NULL", "int8"). 
AddRow("name", "character", "YES", "NULL", "bpchar"). AddRow("doc", "jsonb", "YES", "NULL", "jsonb")) mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "users"))). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). AddRow("users_pkey", "id", "t", "t", 0)) mock.ExpectExec(escape(`ALTER TABLE "users" ADD COLUMN "age" bigint NOT NULL DEFAULT 10`)). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "add blob columns", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Nullable: true}, {Name: "blob", Type: field.TypeBytes, Size: 1e3}, {Name: "longblob", Type: field.TypeBytes, Size: 1e6}, {Name: "doc", Type: field.TypeJSON, Nullable: true}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock pgMock) { mock.start("120000") mock.tableExists("users", true) mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name"}). AddRow("id", "bigint", "NO", "NULL", "int8"). AddRow("name", "character", "YES", "NULL", "bpchar"). AddRow("doc", "jsonb", "YES", "NULL", "jsonb")) mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "users"))). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). AddRow("users_pkey", "id", "t", "t", 0)) mock.ExpectExec(escape(`ALTER TABLE "users" ADD COLUMN "blob" bytea NOT NULL, ADD COLUMN "longblob" bytea NOT NULL`)). 
WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "add float column with default value to table", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Nullable: true}, {Name: "age", Type: field.TypeFloat64, Default: 10.1}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock pgMock) { mock.start("120000") mock.tableExists("users", true) mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name"}). AddRow("id", "bigint", "NO", "NULL", "int8"). AddRow("name", "character", "YES", "NULL", "bpchar")) mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "users"))). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). AddRow("users_pkey", "id", "t", "t", 0)) mock.ExpectExec(escape(`ALTER TABLE "users" ADD COLUMN "age" double precision NOT NULL DEFAULT 10.1`)). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "add bool column with default value to table", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Nullable: true}, {Name: "age", Type: field.TypeBool, Default: true}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock pgMock) { mock.start("120000") mock.tableExists("users", true) mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). WithArgs("users"). 
WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name"}). AddRow("id", "bigint", "NO", "NULL", "int8"). AddRow("name", "character", "YES", "NULL", "bpchar")) mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "users"))). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). AddRow("users_pkey", "id", "t", "t", 0)) mock.ExpectExec(escape(`ALTER TABLE "users" ADD COLUMN "age" boolean NOT NULL DEFAULT true`)). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "add string column with default value to table", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Nullable: true}, {Name: "nick", Type: field.TypeString, Default: "unknown"}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock pgMock) { mock.start("120000") mock.tableExists("users", true) mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name"}). AddRow("id", "bigint", "NO", "NULL", "int8"). AddRow("name", "character", "YES", "NULL", "bpchar")) mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "users"))). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). AddRow("users_pkey", "id", "t", "t", 0)) mock.ExpectExec(escape(`ALTER TABLE "users" ADD COLUMN "nick" varchar NOT NULL DEFAULT 'unknown'`)). 
WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "drop column to table", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, options: []MigrateOption{WithDropColumn(true)}, before: func(mock pgMock) { mock.start("120000") mock.tableExists("users", true) mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name"}). AddRow("id", "bigint", "NO", "NULL", "int8"). AddRow("name", "character", "YES", "NULL", "bpchar")) mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "users"))). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). AddRow("users_pkey", "id", "t", "t", 0)) mock.ExpectExec(escape(`ALTER TABLE "users" DROP COLUMN "name"`)). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "modify column to nullable", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Nullable: true}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock pgMock) { mock.start("120000") mock.tableExists("users", true) mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name"}). AddRow("id", "bigint", "NO", "NULL", "int8"). 
AddRow("name", "character", "NO", "NULL", "bpchar")) mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "users"))). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). AddRow("users_pkey", "id", "t", "t", 0)) mock.ExpectExec(escape(`ALTER TABLE "users" ALTER COLUMN "name" TYPE varchar, ALTER COLUMN "name" DROP NOT NULL`)). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "apply uniqueness on column", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "age", Type: field.TypeInt, Unique: true}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock pgMock) { mock.start("120000") mock.tableExists("users", true) mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name"}). AddRow("id", "bigint", "NO", "NULL", "int8"). AddRow("age", "bigint", "NO", "NULL", "int8")) mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "users"))). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). AddRow("users_pkey", "id", "t", "t", 0)) mock.ExpectExec(escape(`CREATE UNIQUE INDEX "users_age" ON "users"("age")`)). 
WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "remove uniqueness from column without option", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "age", Type: field.TypeInt}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock pgMock) { mock.start("120000") mock.tableExists("users", true) mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name"}). AddRow("id", "bigint", "NO", "NULL", "int8"). AddRow("age", "bigint", "NO", "NULL", "int8")) mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "users"))). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). AddRow("users_pkey", "id", "t", "t", 0). AddRow("users_age_key", "age", "f", "t", 0)) mock.ExpectCommit() }, }, { name: "remove uniqueness from column with option", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "age", Type: field.TypeInt}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, options: []MigrateOption{WithDropIndex(true)}, before: func(mock pgMock) { mock.start("120000") mock.tableExists("users", true) mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name"}). AddRow("id", "bigint", "NO", "NULL", "int8"). 
AddRow("age", "bigint", "NO", "NULL", "int8")) mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "users"))). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). AddRow("users_pkey", "id", "t", "t", 0). AddRow("users_age_key", "age", "f", "t", 0)) mock.ExpectQuery(escape(`SELECT COUNT(*) FROM "information_schema"."table_constraints" WHERE "table_schema" = CURRENT_SCHEMA() AND "constraint_type" = $1 AND "constraint_name" = $2`)). WithArgs("UNIQUE", "users_age_key"). WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1)) mock.ExpectExec(escape(`ALTER TABLE "users" DROP CONSTRAINT "users_age_key"`)). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "add and remove indexes", tables: func() []*Table { c1 := []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, // Add implicit index. {Name: "age", Type: field.TypeInt, Unique: true}, {Name: "score", Type: field.TypeInt}, } c2 := []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "score", Type: field.TypeInt}, } return []*Table{ { Name: "users", Columns: c1, PrimaryKey: c1[0:1], Indexes: Indexes{ // Change non-unique index to unique. {Name: "user_score", Columns: c1[2:3], Unique: true}, }, }, { Name: "equipment", Columns: c2, PrimaryKey: c2[0:1], Indexes: Indexes{ {Name: "equipment_score", Columns: c2[1:]}, }, }, } }(), options: []MigrateOption{WithDropIndex(true)}, before: func(mock pgMock) { mock.start("120000") mock.tableExists("users", true) mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name"}). AddRow("id", "bigint", "NO", "NULL", "int8"). AddRow("age", "bigint", "NO", "NULL", "int8"). 
AddRow("score", "bigint", "NO", "NULL", "int8")) mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "users"))). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). AddRow("users_pkey", "id", "t", "t", 0). AddRow("user_score", "score", "f", "f", 0)) mock.ExpectQuery(escape(`SELECT COUNT(*) FROM "information_schema"."table_constraints" WHERE "table_schema" = CURRENT_SCHEMA() AND "constraint_type" = $1 AND "constraint_name" = $2`)). WithArgs("UNIQUE", "user_score"). WillReturnRows(sqlmock.NewRows([]string{"count"}). AddRow(0)) mock.ExpectExec(escape(`DROP INDEX "user_score"`)). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectExec(escape(`CREATE UNIQUE INDEX "users_age" ON "users"("age")`)). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectExec(escape(`CREATE UNIQUE INDEX "user_score" ON "users"("score")`)). WillReturnResult(sqlmock.NewResult(0, 1)) mock.tableExists("equipment", true) mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). WithArgs("equipment"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name"}). AddRow("id", "bigint", "NO", "NULL", "int8"). AddRow("score", "bigint", "NO", "NULL", "int8")) mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "equipment"))). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). AddRow("users_pkey", "id", "t", "t", 0). 
AddRow("equipment_score", "score", "f", "f", 0)) mock.ExpectCommit() }, }, { name: "add edge to table", tables: func() []*Table { var ( c1 = []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Nullable: true}, {Name: "spouse_id", Type: field.TypeInt, Nullable: true}, } t1 = &Table{ Name: "users", Columns: c1, PrimaryKey: c1[0:1], ForeignKeys: []*ForeignKey{ { Symbol: "user_spouse" + strings.Repeat("_", 64), // super long fk. Columns: c1[2:], RefColumns: c1[0:1], OnDelete: Cascade, }, }, } ) t1.ForeignKeys[0].RefTable = t1 return []*Table{t1} }(), before: func(mock pgMock) { mock.start("120000") mock.tableExists("users", true) mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name"}). AddRow("id", "bigint", "YES", "NULL", "int8"). AddRow("name", "character", "YES", "NULL", "bpchar")) mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "users"))). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). AddRow("users_pkey", "id", "t", "t", 0)) mock.ExpectExec(escape(`ALTER TABLE "users" ADD COLUMN "spouse_id" bigint NULL`)). WillReturnResult(sqlmock.NewResult(0, 1)) mock.fkExists("user_spouse____________________390ed76f91d3c57cd3516e7690f621dc", false) mock.ExpectExec(`ALTER TABLE "users" ADD CONSTRAINT ".{63}" FOREIGN KEY\("spouse_id"\) REFERENCES "users"\("id"\) ON DELETE CASCADE`). 
WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "universal id for all tables", tables: []*Table{ NewTable("users").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}), NewTable("groups").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}), }, options: []MigrateOption{WithGlobalUniqueID(true)}, before: func(mock pgMock) { mock.start("120000") mock.tableExists("ent_types", false) // create ent_types table. mock.ExpectExec(escape(`CREATE TABLE IF NOT EXISTS "ent_types"("id" bigint GENERATED BY DEFAULT AS IDENTITY NOT NULL, "type" varchar UNIQUE NOT NULL, PRIMARY KEY("id"))`)). WillReturnResult(sqlmock.NewResult(0, 1)) mock.tableExists("users", false) mock.ExpectExec(escape(`CREATE TABLE IF NOT EXISTS "users"("id" bigint GENERATED BY DEFAULT AS IDENTITY NOT NULL, PRIMARY KEY("id"))`)). WillReturnResult(sqlmock.NewResult(0, 1)) // set users id range. mock.ExpectExec(escape(`INSERT INTO "ent_types" ("type") VALUES ($1)`)). WithArgs("users"). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectExec("ALTER TABLE users ALTER COLUMN id RESTART WITH 1"). WillReturnResult(sqlmock.NewResult(0, 1)) mock.tableExists("groups", false) mock.ExpectExec(escape(`CREATE TABLE IF NOT EXISTS "groups"("id" bigint GENERATED BY DEFAULT AS IDENTITY NOT NULL, PRIMARY KEY("id"))`)). WillReturnResult(sqlmock.NewResult(0, 1)) // set groups id range. mock.ExpectExec(escape(`INSERT INTO "ent_types" ("type") VALUES ($1)`)). WithArgs("groups"). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectExec("ALTER TABLE groups ALTER COLUMN id RESTART WITH 4294967296"). 
WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "universal id for new tables", tables: []*Table{ NewTable("users").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}), NewTable("groups").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}), }, options: []MigrateOption{WithGlobalUniqueID(true)}, before: func(mock pgMock) { mock.start("120000") mock.tableExists("ent_types", true) // query ent_types table. mock.ExpectQuery(`SELECT "type" FROM "ent_types" ORDER BY "id" ASC`). WillReturnRows(sqlmock.NewRows([]string{"type"}).AddRow("users")) // query users table. mock.tableExists("users", true) // users table has no changes. mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name"}). AddRow("id", "bigint", "YES", "NULL", "int8")) mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "users"))). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). AddRow("users_pkey", "id", "t", "t", 0)) // query groups table. mock.tableExists("groups", false) mock.ExpectExec(escape(`CREATE TABLE IF NOT EXISTS "groups"("id" bigint GENERATED BY DEFAULT AS IDENTITY NOT NULL, PRIMARY KEY("id"))`)). WillReturnResult(sqlmock.NewResult(0, 1)) // set groups id range. mock.ExpectExec(escape(`INSERT INTO "ent_types" ("type") VALUES ($1)`)). WithArgs("groups"). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectExec("ALTER TABLE groups ALTER COLUMN id RESTART WITH 4294967296"). 
WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "universal id for restored tables", tables: []*Table{ NewTable("users").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}), NewTable("groups").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}), }, options: []MigrateOption{WithGlobalUniqueID(true)}, before: func(mock pgMock) { mock.start("120000") mock.tableExists("ent_types", true) // query ent_types table. mock.ExpectQuery(`SELECT "type" FROM "ent_types" ORDER BY "id" ASC`). WillReturnRows(sqlmock.NewRows([]string{"type"}).AddRow("users")) // query and create users (restored table). mock.tableExists("users", false) mock.ExpectExec(escape(`CREATE TABLE IF NOT EXISTS "users"("id" bigint GENERATED BY DEFAULT AS IDENTITY NOT NULL, PRIMARY KEY("id"))`)). WillReturnResult(sqlmock.NewResult(0, 1)) // set users id range (without inserting to ent_types). mock.ExpectExec("ALTER TABLE users ALTER COLUMN id RESTART WITH 1"). WillReturnResult(sqlmock.NewResult(0, 1)) // query groups table. mock.tableExists("groups", false) mock.ExpectExec(escape(`CREATE TABLE IF NOT EXISTS "groups"("id" bigint GENERATED BY DEFAULT AS IDENTITY NOT NULL, PRIMARY KEY("id"))`)). WillReturnResult(sqlmock.NewResult(0, 1)) // set groups id range. mock.ExpectExec(escape(`INSERT INTO "ent_types" ("type") VALUES ($1)`)). WithArgs("groups"). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectExec("ALTER TABLE groups ALTER COLUMN id RESTART WITH 4294967296"). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { db, mock, err := sqlmock.New() require.NoError(t, err) tt.before(pgMock{mock}) migrate, err := NewMigrate(sql.OpenDB("postgres", db), tt.options...) require.NoError(t, err) err = migrate.Create(context.Background(), tt.tables...) 
require.Equal(t, tt.wantErr, err != nil, err) }) } } type pgMock struct { sqlmock.Sqlmock } func (m pgMock) start(version string) { m.ExpectBegin() m.ExpectQuery(escape("SHOW server_version_num")). WillReturnRows(sqlmock.NewRows([]string{"server_version_num"}).AddRow(version)) } func (m pgMock) tableExists(table string, exists bool) { count := 0 if exists { count = 1 } m.ExpectQuery(escape(`SELECT COUNT(*) FROM "information_schema"."tables" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). WithArgs(table). WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(count)) } func (m pgMock) fkExists(fk string, exists bool) { count := 0 if exists { count = 1 } m.ExpectQuery(escape(`SELECT COUNT(*) FROM "information_schema"."table_constraints" WHERE "table_schema" = CURRENT_SCHEMA() AND "constraint_type" = $1 AND "constraint_name" = $2`)). WithArgs("FOREIGN KEY", fk). WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(count)) } ent-0.5.4/dialect/sql/schema/schema.go000066400000000000000000000344051377533537200175660ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. // Package schema contains all schema migration logic for SQL dialects. package schema import ( "fmt" "strconv" "strings" "github.com/facebook/ent/dialect/entsql" "github.com/facebook/ent/dialect/sql" "github.com/facebook/ent/schema/field" ) const ( // DefaultStringLen describes the default length for string/varchar types. DefaultStringLen int64 = 255 // Null is the string representation of NULL in SQL. Null = "NULL" // PrimaryKey is the string representation of PKs in SQL. PrimaryKey = "PRI" // UniqueKey is the string representation of PKs in SQL. UniqueKey = "UNI" ) // Table schema definition for SQL dialects. 
type Table struct { Name string Columns []*Column columns map[string]*Column Indexes []*Index PrimaryKey []*Column ForeignKeys []*ForeignKey Annotation *entsql.Annotation } // NewTable returns a new table with the given name. func NewTable(name string) *Table { return &Table{ Name: name, columns: make(map[string]*Column), } } // AddPrimary adds a new primary key to the table. func (t *Table) AddPrimary(c *Column) *Table { c.Key = PrimaryKey t.AddColumn(c) t.PrimaryKey = append(t.PrimaryKey, c) return t } // AddForeignKey adds a foreign key to the table. func (t *Table) AddForeignKey(fk *ForeignKey) *Table { t.ForeignKeys = append(t.ForeignKeys, fk) return t } // AddColumn adds a new column to the table. func (t *Table) AddColumn(c *Column) *Table { t.columns[c.Name] = c t.Columns = append(t.Columns, c) return t } // SetAnnotation the entsql.Annotation on the table. func (t *Table) SetAnnotation(ant *entsql.Annotation) *Table { t.Annotation = ant return t } // AddIndex creates and adds a new index to the table from the given options. func (t *Table) AddIndex(name string, unique bool, columns []string) *Table { return t.addIndex(&Index{ Name: name, Unique: unique, columns: columns, Columns: make([]*Column, 0, len(columns)), }) } // AddIndex creates and adds a new index to the table from the given options. func (t *Table) addIndex(idx *Index) *Table { for _, name := range idx.columns { c, ok := t.columns[name] if ok { c.indexes.append(idx) idx.Columns = append(idx.Columns, c) } } t.Indexes = append(t.Indexes, idx) return t } // column returns a table column by its name. // faster than map lookup for most cases. func (t *Table) column(name string) (*Column, bool) { for _, c := range t.Columns { if c.Name == name { return c, true } } return nil, false } // index returns a table index by its name. 
func (t *Table) index(name string) (*Index, bool) { for _, idx := range t.Indexes { if name == idx.Name || name == idx.realname { return idx, true } // Same as below, there are cases where the index name // is unknown (created automatically on column constraint). if len(idx.Columns) == 1 && idx.Columns[0].Name == name { return idx, true } } // If it is an "implicit index" (unique constraint on // table creation) and it wasn't loaded in table scanning. c, ok := t.column(name) if !ok { // Postgres naming convention for unique constraint (__key). name = strings.TrimPrefix(name, t.Name+"_") name = strings.TrimSuffix(name, "_key") c, ok = t.column(name) } if ok && c.Unique { return &Index{Name: name, Unique: c.Unique, Columns: []*Column{c}, columns: []string{c.Name}}, true } return nil, false } // fk returns a table foreign-key by its symbol. // faster than map lookup for most cases. func (t *Table) fk(symbol string) (*ForeignKey, bool) { for _, fk := range t.ForeignKeys { if fk.Symbol == symbol { return fk, true } } return nil, false } // Column schema definition for SQL dialects. type Column struct { Name string // column name. Type field.Type // column type. SchemaType map[string]string // optional schema type per dialect. Attr string // extra attributes. Size int64 // max size parameter for string, blob, etc. Key string // key definition (PRI, UNI or MUL). Unique bool // column with unique constraint. Increment bool // auto increment attribute. Nullable bool // null or not null attribute. Default interface{} // default value. Enums []string // enum values. typ string // row column type (used for Rows.Scan). indexes Indexes // linked indexes. foreign *ForeignKey // linked foreign-key. } // UniqueKey returns boolean indicates if this column is a unique key. // Used by the migration tool when parsing the `DESCRIBE TABLE` output Go objects. 
func (c *Column) UniqueKey() bool { return c.Key == UniqueKey } // PrimaryKey returns boolean indicates if this column is on of the primary key columns. // Used by the migration tool when parsing the `DESCRIBE TABLE` output Go objects. func (c *Column) PrimaryKey() bool { return c.Key == PrimaryKey } // ConvertibleTo reports whether a column can be converted to the new column without altering its data. func (c *Column) ConvertibleTo(d *Column) bool { switch { case c.Type == d.Type: if c.Size != 0 && d.Size != 0 { // Types match and have a size constraint. return c.Size <= d.Size } return true case c.IntType() && d.IntType() || c.UintType() && d.UintType(): return c.Type <= d.Type case c.UintType() && d.IntType(): // uintX can not be converted to intY, when X > Y. return c.Type-field.TypeUint8 <= d.Type-field.TypeInt8 case c.Type == field.TypeString && d.Type == field.TypeEnum || c.Type == field.TypeEnum && d.Type == field.TypeString: return true case c.Type.Integer() && d.Type == field.TypeString: return true } return c.FloatType() && d.FloatType() } // IntType reports whether the column is an int type (int8 ... int64). func (c Column) IntType() bool { return c.Type >= field.TypeInt8 && c.Type <= field.TypeInt64 } // UintType reports of the given type is a uint type (int8 ... int64). func (c Column) UintType() bool { return c.Type >= field.TypeUint8 && c.Type <= field.TypeUint64 } // FloatType reports of the given type is a float type (float32, float64). func (c Column) FloatType() bool { return c.Type == field.TypeFloat32 || c.Type == field.TypeFloat64 } // ScanDefault scans the default value string to its interface type. func (c *Column) ScanDefault(value string) error { switch { case strings.ToUpper(value) == Null: // ignore. 
case c.IntType(): v := &sql.NullInt64{} if err := v.Scan(value); err != nil { return fmt.Errorf("scanning int value for column %q: %v", c.Name, err) } c.Default = v.Int64 case c.UintType(): v := &sql.NullInt64{} if err := v.Scan(value); err != nil { return fmt.Errorf("scanning uint value for column %q: %v", c.Name, err) } c.Default = uint64(v.Int64) case c.FloatType(): v := &sql.NullFloat64{} if err := v.Scan(value); err != nil { return fmt.Errorf("scanning float value for column %q: %v", c.Name, err) } c.Default = v.Float64 case c.Type == field.TypeBool: v := &sql.NullBool{} if err := v.Scan(value); err != nil { return fmt.Errorf("scanning bool value for column %q: %v", c.Name, err) } c.Default = v.Bool case c.Type == field.TypeString || c.Type == field.TypeEnum: v := &sql.NullString{} if err := v.Scan(value); err != nil { return fmt.Errorf("scanning string value for column %q: %v", c.Name, err) } c.Default = v.String case c.Type == field.TypeJSON: v := &sql.NullString{} if err := v.Scan(value); err != nil { return fmt.Errorf("scanning json value for column %q: %v", c.Name, err) } c.Default = v.String default: return fmt.Errorf("unsupported default type: %v", c.Type) } return nil } // defaultValue adds tge `DEFAULT` attribute the the column. // Note that, in SQLite if a NOT NULL constraint is specified, // then the column must have a default value which not NULL. func (c *Column) defaultValue(b *sql.ColumnBuilder) { // has default, and it's supported in the database level. if c.Default != nil && c.supportDefault() { attr := "DEFAULT " switch v := c.Default.(type) { case bool: attr += strconv.FormatBool(v) case string: // Escape single quote by replacing each with 2. attr += fmt.Sprintf("'%s'", strings.ReplaceAll(v, "'", "''")) default: attr += fmt.Sprint(v) } b.Attr(attr) } } // supportDefault reports if the column type supports default value. 
func (c Column) supportDefault() bool { switch { case c.Type == field.TypeString || c.Type == field.TypeEnum: return c.Size < 1<<16 // not a text. case c.Type.Numeric(), c.Type == field.TypeBool: return true default: return false } } // unique adds the `UNIQUE` attribute if the column is a unique type. // it is exist in a different function to share the common declaration // between the two dialects. func (c *Column) unique(b *sql.ColumnBuilder) { if c.Unique { b.Attr("UNIQUE") } } // nullable adds the `NULL`/`NOT NULL` attribute to the column. it is exist in // a different function to share the common declaration between the two dialects. func (c *Column) nullable(b *sql.ColumnBuilder) { attr := Null if !c.Nullable { attr = "NOT " + attr } b.Attr(attr) } // defaultSize returns the default size for MySQL varchar type based // on column size, charset and table indexes, in order to avoid index // prefix key limit (767). func (c *Column) defaultSize(version string) int64 { size := DefaultStringLen switch { // version is >= 5.7. case compareVersions(version, "5.7.0") != -1: // non-unique, or not part of any index (reaching the error 1071). case !c.Unique && len(c.indexes) == 0: default: size = 191 } return size } // scanTypeOr returns the scanning type or the given value. func (c *Column) scanTypeOr(t string) string { if c.typ != "" { return strings.ToLower(c.typ) } return t } // ForeignKey definition for creation. type ForeignKey struct { Symbol string // foreign-key name. Generated if empty. Columns []*Column // table column RefTable *Table // referenced table. RefColumns []*Column // referenced columns. OnUpdate ReferenceOption // action on update. OnDelete ReferenceOption // action on delete. } // DSL returns a default DSL query for a foreign-key. 
func (fk ForeignKey) DSL() *sql.ForeignKeyBuilder { cols := make([]string, len(fk.Columns)) refs := make([]string, len(fk.RefColumns)) for i, c := range fk.Columns { cols[i] = c.Name } for i, c := range fk.RefColumns { refs[i] = c.Name } dsl := sql.ForeignKey().Symbol(fk.Symbol). Columns(cols...). Reference(sql.Reference().Table(fk.RefTable.Name).Columns(refs...)) if action := string(fk.OnDelete); action != "" { dsl.OnDelete(action) } if action := string(fk.OnUpdate); action != "" { dsl.OnUpdate(action) } return dsl } // ReferenceOption for constraint actions. type ReferenceOption string // Reference options. const ( NoAction ReferenceOption = "NO ACTION" Restrict ReferenceOption = "RESTRICT" Cascade ReferenceOption = "CASCADE" SetNull ReferenceOption = "SET NULL" SetDefault ReferenceOption = "SET DEFAULT" ) // ConstName returns the constant name of a reference option. It's used by entc for printing the constant name in templates. func (r ReferenceOption) ConstName() string { if r == NoAction { return "" } return strings.ReplaceAll(strings.Title(strings.ToLower(string(r))), " ", "") } // Index definition for table index. type Index struct { Name string // index name. Unique bool // uniqueness. Columns []*Column // actual table columns. columns []string // columns loaded from query scan. primary bool // primary key index. realname string // real name in the database (Postgres only). } // Builder returns the query builder for index creation. The DSL is identical in all dialects. func (i *Index) Builder(table string) *sql.IndexBuilder { idx := sql.CreateIndex(i.Name).Table(table) if i.Unique { idx.Unique() } for _, c := range i.Columns { idx.Column(c.Name) } return idx } // DropBuilder returns the query builder for the drop index. func (i *Index) DropBuilder(table string) *sql.DropIndexBuilder { idx := sql.DropIndex(i.Name).Table(table) return idx } // sameAs reports if the index has the same properties // as the given index (except the name). 
func (i *Index) sameAs(idx *Index) bool { if i.Unique != idx.Unique || len(i.Columns) != len(idx.Columns) { return false } for j, c := range i.Columns { if c.Name != idx.Columns[j].Name { return false } } return true } // columnNames returns the names of the columns of the index. func (i *Index) columnNames() []string { if len(i.columns) > 0 { return i.columns } columns := make([]string, 0, len(i.Columns)) for _, c := range i.Columns { columns = append(columns, c.Name) } return columns } // Indexes used for scanning all sql.Rows into a list of indexes, because // multiple sql rows can represent the same index (multi-columns indexes). type Indexes []*Index // append wraps the basic `append` function by filtering duplicates indexes. func (i *Indexes) append(idx1 *Index) { for _, idx2 := range *i { if idx2.Name == idx1.Name { return } } *i = append(*i, idx1) } // compareVersions returns an integer comparing the 2 versions. func compareVersions(v1, v2 string) int { pv1, ok1 := parseVersion(v1) pv2, ok2 := parseVersion(v2) if !ok1 && !ok2 { return 0 } if !ok1 { return -1 } if !ok2 { return 1 } if v := compare(pv1.major, pv2.major); v != 0 { return v } if v := compare(pv1.minor, pv2.minor); v != 0 { return v } return compare(pv1.patch, pv2.patch) } // version represents a parsed MySQL version. type version struct { major int minor int patch int } // parseVersion returns an integer comparing the 2 versions. 
func parseVersion(v string) (*version, bool) { parts := strings.Split(v, ".") if len(parts) == 0 { return nil, false } var ( err error ver = &version{} ) for i, e := range []*int{&ver.major, &ver.minor, &ver.patch} { if i == len(parts) { break } if *e, err = strconv.Atoi(strings.Split(parts[i], "-")[0]); err != nil { return nil, false } } return ver, true } func compare(v1, v2 int) int { if v1 == v2 { return 0 } if v1 < v2 { return -1 } return 1 } ent-0.5.4/dialect/sql/schema/schema_test.go000066400000000000000000000101351377533537200206170ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package schema import ( "testing" "github.com/facebook/ent/schema/field" "github.com/stretchr/testify/require" ) func TestColumn_ConvertibleTo(t *testing.T) { c1 := &Column{Type: field.TypeString, Size: 10} require.True(t, c1.ConvertibleTo(&Column{Type: field.TypeString, Size: 10})) require.True(t, c1.ConvertibleTo(&Column{Type: field.TypeString, Size: 255})) require.False(t, c1.ConvertibleTo(&Column{Type: field.TypeString, Size: 9})) require.False(t, c1.ConvertibleTo(&Column{Type: field.TypeFloat32})) c1 = &Column{Type: field.TypeFloat32} require.True(t, c1.ConvertibleTo(&Column{Type: field.TypeFloat32})) require.True(t, c1.ConvertibleTo(&Column{Type: field.TypeFloat64})) require.False(t, c1.ConvertibleTo(&Column{Type: field.TypeString})) require.False(t, c1.ConvertibleTo(&Column{Type: field.TypeUint})) c1 = &Column{Type: field.TypeFloat64} require.True(t, c1.ConvertibleTo(&Column{Type: field.TypeFloat32})) require.True(t, c1.ConvertibleTo(&Column{Type: field.TypeFloat64})) require.False(t, c1.ConvertibleTo(&Column{Type: field.TypeString})) require.False(t, c1.ConvertibleTo(&Column{Type: field.TypeUint})) c1 = &Column{Type: field.TypeUint} require.True(t, c1.ConvertibleTo(&Column{Type: field.TypeUint})) 
require.True(t, c1.ConvertibleTo(&Column{Type: field.TypeInt})) require.True(t, c1.ConvertibleTo(&Column{Type: field.TypeInt64})) require.True(t, c1.ConvertibleTo(&Column{Type: field.TypeUint64})) require.False(t, c1.ConvertibleTo(&Column{Type: field.TypeInt8})) require.False(t, c1.ConvertibleTo(&Column{Type: field.TypeUint8})) require.False(t, c1.ConvertibleTo(&Column{Type: field.TypeUint16})) require.False(t, c1.ConvertibleTo(&Column{Type: field.TypeUint32})) require.True(t, c1.ConvertibleTo(&Column{Type: field.TypeString})) require.True(t, c1.ConvertibleTo(&Column{Type: field.TypeString, Size: 1})) c1 = &Column{Type: field.TypeInt} require.True(t, c1.ConvertibleTo(&Column{Type: field.TypeInt})) require.True(t, c1.ConvertibleTo(&Column{Type: field.TypeInt64})) require.False(t, c1.ConvertibleTo(&Column{Type: field.TypeInt8})) require.False(t, c1.ConvertibleTo(&Column{Type: field.TypeInt32})) require.False(t, c1.ConvertibleTo(&Column{Type: field.TypeUint})) require.False(t, c1.ConvertibleTo(&Column{Type: field.TypeUint8})) require.False(t, c1.ConvertibleTo(&Column{Type: field.TypeUint16})) require.False(t, c1.ConvertibleTo(&Column{Type: field.TypeUint32})) require.True(t, c1.ConvertibleTo(&Column{Type: field.TypeString})) require.True(t, c1.ConvertibleTo(&Column{Type: field.TypeString, Size: 1})) } func TestColumn_ScanDefault(t *testing.T) { c1 := &Column{Type: field.TypeString, Size: 10} require.NoError(t, c1.ScanDefault("Hello World")) require.Equal(t, "Hello World", c1.Default) require.NoError(t, c1.ScanDefault("1")) require.Equal(t, "1", c1.Default) c1 = &Column{Type: field.TypeInt64} require.NoError(t, c1.ScanDefault("128")) require.Equal(t, int64(128), c1.Default) require.NoError(t, c1.ScanDefault("1")) require.Equal(t, int64(1), c1.Default) require.Error(t, c1.ScanDefault("foo")) c1 = &Column{Type: field.TypeUint64} require.NoError(t, c1.ScanDefault("128")) require.Equal(t, uint64(128), c1.Default) require.NoError(t, c1.ScanDefault("1")) require.Equal(t, 
uint64(1), c1.Default) require.Error(t, c1.ScanDefault("foo")) c1 = &Column{Type: field.TypeFloat64} require.NoError(t, c1.ScanDefault("128.1")) require.Equal(t, 128.1, c1.Default) require.NoError(t, c1.ScanDefault("1")) require.Equal(t, float64(1), c1.Default) require.Error(t, c1.ScanDefault("foo")) c1 = &Column{Type: field.TypeBool} require.NoError(t, c1.ScanDefault("1")) require.Equal(t, true, c1.Default) require.NoError(t, c1.ScanDefault("true")) require.Equal(t, true, c1.Default) require.NoError(t, c1.ScanDefault("0")) require.Equal(t, false, c1.Default) require.NoError(t, c1.ScanDefault("false")) require.Equal(t, false, c1.Default) require.Error(t, c1.ScanDefault("foo")) } ent-0.5.4/dialect/sql/schema/sqlite.go000066400000000000000000000235701377533537200176300ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package schema import ( "context" "fmt" "strings" "github.com/facebook/ent/dialect" "github.com/facebook/ent/dialect/sql" "github.com/facebook/ent/schema/field" ) // SQLite is an SQLite migration driver. type SQLite struct { dialect.Driver WithForeignKeys bool } // init makes sure that foreign_keys support is enabled. func (d *SQLite) init(ctx context.Context, tx dialect.Tx) error { on, err := exist(ctx, tx, "PRAGMA foreign_keys") if err != nil { return fmt.Errorf("sqlite: check foreign_keys pragma: %v", err) } if !on { // foreign_keys pragma is off, either enable it by execute "PRAGMA foreign_keys=ON" // or add the following parameter in the connection string "_fk=1". return fmt.Errorf("sqlite: foreign_keys pragma is off: missing %q is the connection string", "_fk=1") } return nil } func (d *SQLite) tableExist(ctx context.Context, tx dialect.Tx, name string) (bool, error) { query, args := sql.Select().Count(). From(sql.Table("sqlite_master")). 
Where(sql.And( sql.EQ("type", "table"), sql.EQ("name", name), )). Query() return exist(ctx, tx, query, args...) } // setRange sets the start value of table PK. // SQLite tracks the AUTOINCREMENT in the "sqlite_sequence" table that is created and initialized automatically // whenever a table that contains an AUTOINCREMENT column is created. However, it populates to it a rows (for tables) // only after the first insertion. Therefore, we check. If a record (for the given table) already exists in the "sqlite_sequence" // table, we updated it. Otherwise, we insert a new value. func (d *SQLite) setRange(ctx context.Context, tx dialect.Tx, t *Table, value int) error { query, args := sql.Select().Count(). From(sql.Table("sqlite_sequence")). Where(sql.EQ("name", t.Name)). Query() exists, err := exist(ctx, tx, query, args...) switch { case err != nil: return err case exists: query, args = sql.Update("sqlite_sequence").Set("seq", value).Where(sql.EQ("name", t.Name)).Query() default: // !exists query, args = sql.Insert("sqlite_sequence").Columns("name", "seq").Values(t.Name, value).Query() } return tx.Exec(ctx, query, args, nil) } func (d *SQLite) tBuilder(t *Table) *sql.TableBuilder { b := sql.CreateTable(t.Name) for _, c := range t.Columns { b.Column(d.addColumn(c)) } // Unlike in MySQL, we're not able to add foreign-key constraints to table // after it was created, and adding them to the `CREATE TABLE` statement is // not always valid (because circular foreign-keys situation is possible). // We stay consistent by not using constraints at all, and just defining the // foreign keys in the `CREATE TABLE` statement. if d.WithForeignKeys { for _, fk := range t.ForeignKeys { b.ForeignKeys(fk.DSL()) } } // If it's an ID based primary key with autoincrement, we add // the `PRIMARY KEY` clause to the column declaration. Otherwise, // we append it to the constraint clause. 
if len(t.PrimaryKey) == 1 && t.PrimaryKey[0].Increment { return b } for _, pk := range t.PrimaryKey { b.PrimaryKey(pk.Name) } return b } // cType returns the SQLite string type for the given column. func (*SQLite) cType(c *Column) (t string) { if c.SchemaType != nil && c.SchemaType[dialect.SQLite] != "" { return c.SchemaType[dialect.SQLite] } switch c.Type { case field.TypeBool: t = "bool" case field.TypeInt8, field.TypeUint8, field.TypeInt16, field.TypeUint16, field.TypeInt32, field.TypeUint32, field.TypeUint, field.TypeInt, field.TypeInt64, field.TypeUint64: t = "integer" case field.TypeBytes: t = "blob" case field.TypeString, field.TypeEnum: // SQLite does not impose any length restrictions on // the length of strings, BLOBs or numeric values. t = fmt.Sprintf("varchar(%d)", DefaultStringLen) case field.TypeFloat32, field.TypeFloat64: t = "real" case field.TypeTime: t = "datetime" case field.TypeJSON: t = "json" case field.TypeUUID: t = "uuid" default: panic("unsupported type " + c.Type.String()) } return t } // addColumn returns the DSL query for adding the given column to a table. func (d *SQLite) addColumn(c *Column) *sql.ColumnBuilder { b := sql.Column(c.Name).Type(d.cType(c)).Attr(c.Attr) c.unique(b) if c.PrimaryKey() && c.Increment { b.Attr("PRIMARY KEY AUTOINCREMENT") } c.nullable(b) c.defaultValue(b) return b } // addIndex returns the querying for adding an index to SQLite. func (d *SQLite) addIndex(i *Index, table string) *sql.IndexBuilder { return i.Builder(table) } // dropIndex drops a SQLite index. func (d *SQLite) dropIndex(ctx context.Context, tx dialect.Tx, idx *Index, table string) error { query, args := idx.DropBuilder("").Query() return tx.Exec(ctx, query, args, nil) } // fkExist returns always true to disable foreign-keys creation after the table was created. 
func (d *SQLite) fkExist(context.Context, dialect.Tx, string) (bool, error) { return true, nil } // table returns always error to indicate that SQLite dialect doesn't support incremental migration. func (d *SQLite) table(ctx context.Context, tx dialect.Tx, name string) (*Table, error) { rows := &sql.Rows{} query, args := sql.Select("name", "type", "notnull", "dflt_value", "pk"). From(sql.Table(fmt.Sprintf("pragma_table_info('%s')", name)).Unquote()). OrderBy("pk"). Query() if err := tx.Query(ctx, query, args, rows); err != nil { return nil, fmt.Errorf("sqlite: reading table description %v", err) } // Call Close in cases of failures (Close is idempotent). defer rows.Close() t := NewTable(name) for rows.Next() { c := &Column{} if err := d.scanColumn(c, rows); err != nil { return nil, fmt.Errorf("sqlite: %v", err) } if c.PrimaryKey() { t.PrimaryKey = append(t.PrimaryKey, c) } t.AddColumn(c) } if err := rows.Err(); err != nil { return nil, err } if err := rows.Close(); err != nil { return nil, fmt.Errorf("sqlite: closing rows %v", err) } indexes, err := d.indexes(ctx, tx, name) if err != nil { return nil, err } // Add and link indexes to table columns. for _, idx := range indexes { switch { case idx.primary: case idx.Unique && len(idx.columns) == 1: name := idx.columns[0] c, ok := t.column(name) if !ok { return nil, fmt.Errorf("index %q column %q was not found in table %q", idx.Name, name, t.Name) } c.Key = UniqueKey c.Unique = true fallthrough default: t.addIndex(idx) } } return t, nil } // table loads the table indexes from the database. func (d *SQLite) indexes(ctx context.Context, tx dialect.Tx, name string) (Indexes, error) { rows := &sql.Rows{} query, args := sql.Select("name", "unique", "origin"). From(sql.Table(fmt.Sprintf("pragma_index_list('%s')", name)).Unquote()). 
Query() if err := tx.Query(ctx, query, args, rows); err != nil { return nil, fmt.Errorf("reading table indexes %v", err) } defer rows.Close() var idx Indexes for rows.Next() { i := &Index{} origin := sql.NullString{} if err := rows.Scan(&i.Name, &i.Unique, &origin); err != nil { return nil, fmt.Errorf("scanning index description %v", err) } i.primary = origin.String == "pk" idx = append(idx, i) } if err := rows.Err(); err != nil { return nil, err } if err := rows.Close(); err != nil { return nil, fmt.Errorf("closing rows %v", err) } for i := range idx { columns, err := d.indexColumns(ctx, tx, idx[i].Name) if err != nil { return nil, err } idx[i].columns = columns // Normalize implicit index names to ent naming convention. See: // https://github.com/sqlite/sqlite/blob/e937df8/src/build.c#L3583 if len(columns) == 1 && strings.HasPrefix(idx[i].Name, "sqlite_autoindex_"+name) { idx[i].Name = columns[0] } } return idx, nil } // indexColumns loads index columns from index info. func (d *SQLite) indexColumns(ctx context.Context, tx dialect.Tx, name string) ([]string, error) { rows := &sql.Rows{} query, args := sql.Select("name"). From(sql.Table(fmt.Sprintf("pragma_index_info('%s')", name)).Unquote()). OrderBy("seqno"). Query() if err := tx.Query(ctx, query, args, rows); err != nil { return nil, fmt.Errorf("reading table indexes %v", err) } defer rows.Close() var names []string if err := sql.ScanSlice(rows, &names); err != nil { return nil, err } return names, nil } // scanColumn scans the column information from SQLite column description. 
func (d *SQLite) scanColumn(c *Column, rows *sql.Rows) error { var ( pk sql.NullInt64 notnull sql.NullInt64 defaults sql.NullString ) if err := rows.Scan(&c.Name, &c.typ, ¬null, &defaults, &pk); err != nil { return fmt.Errorf("scanning column description: %v", err) } c.Nullable = notnull.Int64 == 0 if pk.Int64 > 0 { c.Key = PrimaryKey } parts, _, _, err := parseColumn(c.typ) if err != nil { return err } switch parts[0] { case "bool", "boolean": c.Type = field.TypeBool case "blob": c.Type = field.TypeBytes case "integer": // All integer types have the same "type affinity". c.Type = field.TypeInt case "real", "float", "double": c.Type = field.TypeFloat64 case "datetime": c.Type = field.TypeTime case "json": c.Type = field.TypeJSON case "uuid": c.Type = field.TypeUUID case "varchar", "text": c.Size = DefaultStringLen c.Type = field.TypeString } if defaults.Valid { return c.ScanDefault(defaults.String) } return nil } // alterColumns returns the queries for applying the columns change-set. func (d *SQLite) alterColumns(table string, add, _, _ []*Column) sql.Queries { queries := make(sql.Queries, 0, len(add)) for i := range add { c := d.addColumn(add[i]) if fk := add[i].foreign; fk != nil { c.Constraint(fk.DSL()) } queries = append(queries, sql.Dialect(dialect.SQLite).AlterTable(table).AddColumn(c)) } // Modifying and dropping columns is not supported and disabled until we // will support https://www.sqlite.org/lang_altertable.html#otheralter return queries } ent-0.5.4/dialect/sql/schema/sqlite_test.go000066400000000000000000000417721377533537200206730ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. 
package schema import ( "context" "fmt" "math" "testing" "github.com/facebook/ent/dialect" "github.com/facebook/ent/dialect/sql" "github.com/facebook/ent/schema/field" "github.com/DATA-DOG/go-sqlmock" "github.com/stretchr/testify/require" ) func TestSQLite_Create(t *testing.T) { tests := []struct { name string tables []*Table options []MigrateOption before func(sqliteMock) wantErr bool }{ { name: "tx failed", before: func(mock sqliteMock) { mock.ExpectBegin().WillReturnError(sqlmock.ErrCancelled) }, wantErr: true, }, { name: "fk disabled", before: func(mock sqliteMock) { mock.ExpectBegin() mock.ExpectQuery("PRAGMA foreign_keys"). WillReturnRows(sqlmock.NewRows([]string{"foreign_keys"}).AddRow(0)) mock.ExpectRollback() }, wantErr: true, }, { name: "no tables", before: func(mock sqliteMock) { mock.start() mock.ExpectCommit() }, }, { name: "create new table", tables: []*Table{ { Name: "users", PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Nullable: true}, {Name: "age", Type: field.TypeInt}, {Name: "doc", Type: field.TypeJSON, Nullable: true}, {Name: "uuid", Type: field.TypeUUID, Nullable: true}, {Name: "decimal", Type: field.TypeFloat32, SchemaType: map[string]string{dialect.SQLite: "decimal(6,2)"}}, }, }, }, before: func(mock sqliteMock) { mock.start() mock.tableExists("users", false) mock.ExpectExec(escape("CREATE TABLE `users`(`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL, `name` varchar(255) NULL, `age` integer NOT NULL, `doc` json NULL, `uuid` uuid NULL, `decimal` decimal(6,2) NOT NULL)")). 
WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "create new table with foreign key", tables: func() []*Table { var ( c1 = []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Nullable: true}, {Name: "created_at", Type: field.TypeTime}, } c2 = []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString}, {Name: "owner_id", Type: field.TypeInt, Nullable: true}, } t1 = &Table{ Name: "users", Columns: c1, PrimaryKey: c1[0:1], } t2 = &Table{ Name: "pets", Columns: c2, PrimaryKey: c2[0:1], ForeignKeys: []*ForeignKey{ { Symbol: "pets_owner", Columns: c2[2:], RefTable: t1, RefColumns: c1[0:1], OnDelete: Cascade, }, }, } ) return []*Table{t1, t2} }(), before: func(mock sqliteMock) { mock.start() mock.tableExists("users", false) mock.ExpectExec(escape("CREATE TABLE `users`(`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL, `name` varchar(255) NULL, `created_at` datetime NOT NULL)")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.tableExists("pets", false) mock.ExpectExec(escape("CREATE TABLE `pets`(`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL, `name` varchar(255) NOT NULL, `owner_id` integer NULL, FOREIGN KEY(`owner_id`) REFERENCES `users`(`id`) ON DELETE CASCADE)")). 
WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "create new table with foreign key disabled", options: []MigrateOption{ WithForeignKeys(false), }, tables: func() []*Table { var ( c1 = []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Nullable: true}, {Name: "created_at", Type: field.TypeTime}, } c2 = []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString}, {Name: "owner_id", Type: field.TypeInt, Nullable: true}, } t1 = &Table{ Name: "users", Columns: c1, PrimaryKey: c1[0:1], } t2 = &Table{ Name: "pets", Columns: c2, PrimaryKey: c2[0:1], ForeignKeys: []*ForeignKey{ { Symbol: "pets_owner", Columns: c2[2:], RefTable: t1, RefColumns: c1[0:1], OnDelete: Cascade, }, }, } ) return []*Table{t1, t2} }(), before: func(mock sqliteMock) { mock.start() mock.tableExists("users", false) mock.ExpectExec(escape("CREATE TABLE `users`(`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL, `name` varchar(255) NULL, `created_at` datetime NOT NULL)")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.tableExists("pets", false) mock.ExpectExec(escape("CREATE TABLE `pets`(`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL, `name` varchar(255) NOT NULL, `owner_id` integer NULL)")). 
WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "add column to table", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Nullable: true}, {Name: "text", Type: field.TypeString, Nullable: true, Size: math.MaxInt32}, {Name: "uuid", Type: field.TypeUUID, Nullable: true}, {Name: "age", Type: field.TypeInt, Default: 0}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock sqliteMock) { mock.start() mock.tableExists("users", true) mock.ExpectQuery(escape("SELECT `name`, `type`, `notnull`, `dflt_value`, `pk` FROM pragma_table_info('users') ORDER BY `pk`")). WithArgs(). WillReturnRows(sqlmock.NewRows([]string{"name", "type", "notnull", "dflt_value", "pk"}). AddRow("name", "varchar(255)", 0, nil, 0). AddRow("text", "text", 0, "NULL", 0). AddRow("uuid", "uuid", 0, "Null", 0). AddRow("id", "integer", 1, "NULL", 1)) mock.ExpectQuery(escape("SELECT `name`, `unique`, `origin` FROM pragma_index_list('users')")). WithArgs(). WillReturnRows(sqlmock.NewRows([]string{"name", "unique", "origin"})) mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `age` integer NOT NULL DEFAULT 0")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "datetime and timestamp", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "created_at", Type: field.TypeTime, Nullable: true}, {Name: "updated_at", Type: field.TypeTime, Nullable: true}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock sqliteMock) { mock.start() mock.tableExists("users", true) mock.ExpectQuery(escape("SELECT `name`, `type`, `notnull`, `dflt_value`, `pk` FROM pragma_table_info('users') ORDER BY `pk`")). WithArgs(). WillReturnRows(sqlmock.NewRows([]string{"name", "type", "notnull", "dflt_value", "pk"}). 
AddRow("created_at", "datetime", 0, nil, 0). AddRow("id", "integer", 1, "NULL", 1)) mock.ExpectQuery(escape("SELECT `name`, `unique`, `origin` FROM pragma_index_list('users')")). WithArgs(). WillReturnRows(sqlmock.NewRows([]string{"name", "unique", "origin"})) mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `updated_at` datetime NULL")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "add blob columns", tables: []*Table{ { Name: "blobs", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "old_tiny", Type: field.TypeBytes, Size: 100}, {Name: "old_blob", Type: field.TypeBytes, Size: 1e3}, {Name: "old_medium", Type: field.TypeBytes, Size: 1e5}, {Name: "old_long", Type: field.TypeBytes, Size: 1e8}, {Name: "new_tiny", Type: field.TypeBytes, Size: 100}, {Name: "new_blob", Type: field.TypeBytes, Size: 1e3}, {Name: "new_medium", Type: field.TypeBytes, Size: 1e5}, {Name: "new_long", Type: field.TypeBytes, Size: 1e8}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock sqliteMock) { mock.start() mock.tableExists("blobs", true) mock.ExpectQuery(escape("SELECT `name`, `type`, `notnull`, `dflt_value`, `pk` FROM pragma_table_info('blobs') ORDER BY `pk`")). WithArgs(). WillReturnRows(sqlmock.NewRows([]string{"name", "type", "notnull", "dflt_value", "pk"}). AddRow("old_tiny", "blob", 1, nil, 0). AddRow("old_blob", "blob", 1, nil, 0). AddRow("old_medium", "blob", 1, nil, 0). AddRow("old_long", "blob", 1, nil, 0). AddRow("id", "integer", 1, "NULL", 1)) mock.ExpectQuery(escape("SELECT `name`, `unique`, `origin` FROM pragma_index_list('blobs')")). WithArgs(). WillReturnRows(sqlmock.NewRows([]string{"name", "unique", "unique"})) for _, c := range []string{"tiny", "blob", "medium", "long"} { mock.ExpectExec(escape(fmt.Sprintf("ALTER TABLE `blobs` ADD COLUMN `new_%s` blob NOT NULL", c))). 
WillReturnResult(sqlmock.NewResult(0, 1)) } mock.ExpectCommit() }, }, { name: "add columns with default values", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Default: "unknown"}, {Name: "active", Type: field.TypeBool, Default: false}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock sqliteMock) { mock.start() mock.tableExists("users", true) mock.ExpectQuery(escape("SELECT `name`, `type`, `notnull`, `dflt_value`, `pk` FROM pragma_table_info('users') ORDER BY `pk`")). WithArgs(). WillReturnRows(sqlmock.NewRows([]string{"name", "type", "notnull", "dflt_value", "pk"}). AddRow("id", "integer", 1, "NULL", 1)) mock.ExpectQuery(escape("SELECT `name`, `unique`, `origin` FROM pragma_index_list('users')")). WithArgs(). WillReturnRows(sqlmock.NewRows([]string{"name", "unique", "origin"})) mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `name` varchar(255) NOT NULL DEFAULT 'unknown'")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `active` bool NOT NULL DEFAULT false")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "add edge to table", tables: func() []*Table { var ( c1 = []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Nullable: true}, {Name: "spouse_id", Type: field.TypeInt, Nullable: true}, } t1 = &Table{ Name: "users", Columns: c1, PrimaryKey: c1[0:1], ForeignKeys: []*ForeignKey{ { Symbol: "user_spouse", Columns: c1[2:], RefColumns: c1[0:1], OnDelete: Cascade, }, }, } ) t1.ForeignKeys[0].RefTable = t1 return []*Table{t1} }(), before: func(mock sqliteMock) { mock.start() mock.tableExists("users", true) mock.ExpectQuery(escape("SELECT `name`, `type`, `notnull`, `dflt_value`, `pk` FROM pragma_table_info('users') ORDER BY `pk`")). WithArgs(). 
WillReturnRows(sqlmock.NewRows([]string{"name", "type", "notnull", "dflt_value", "pk"}). AddRow("name", "varchar(255)", 1, "NULL", 0). AddRow("id", "integer", 1, "NULL", 1)) mock.ExpectQuery(escape("SELECT `name`, `unique`, `origin` FROM pragma_index_list('users')")). WithArgs(). WillReturnRows(sqlmock.NewRows([]string{"name", "unique", "origin"})) mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `spouse_id` integer NULL CONSTRAINT user_spouse REFERENCES `users`(`id`) ON DELETE CASCADE")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "universal id for all tables", tables: []*Table{ NewTable("users").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}), NewTable("groups").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}), }, options: []MigrateOption{WithGlobalUniqueID(true)}, before: func(mock sqliteMock) { mock.start() // creating ent_types table. mock.tableExists("ent_types", false) mock.ExpectExec(escape("CREATE TABLE `ent_types`(`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL, `type` varchar(255) UNIQUE NOT NULL)")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.tableExists("users", false) mock.ExpectExec(escape("CREATE TABLE `users`(`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL)")). WillReturnResult(sqlmock.NewResult(0, 1)) // set users id range. mock.ExpectExec(escape("INSERT INTO `ent_types` (`type`) VALUES (?)")). WithArgs("users"). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectQuery(escape("SELECT COUNT(*) FROM `sqlite_sequence` WHERE `name` = ?")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0)) mock.ExpectExec(escape("INSERT INTO `sqlite_sequence` (`name`, `seq`) VALUES (?, ?)")). WithArgs("users", 0). WillReturnResult(sqlmock.NewResult(0, 1)) mock.tableExists("groups", false) mock.ExpectExec(escape("CREATE TABLE `groups`(`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL)")). WillReturnResult(sqlmock.NewResult(0, 1)) // set groups id range. 
mock.ExpectExec(escape("INSERT INTO `ent_types` (`type`) VALUES (?)")). WithArgs("groups"). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectQuery(escape("SELECT COUNT(*) FROM `sqlite_sequence` WHERE `name` = ?")). WithArgs("groups"). WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0)) mock.ExpectExec(escape("INSERT INTO `sqlite_sequence` (`name`, `seq`) VALUES (?, ?)")). WithArgs("groups", 1<<32). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "universal id for restored tables", tables: []*Table{ NewTable("users").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}), NewTable("groups").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}), }, options: []MigrateOption{WithGlobalUniqueID(true)}, before: func(mock sqliteMock) { mock.start() // query ent_types table. mock.tableExists("ent_types", true) mock.ExpectQuery(escape("SELECT `type` FROM `ent_types` ORDER BY `id` ASC")). WillReturnRows(sqlmock.NewRows([]string{"type"}).AddRow("users")) mock.tableExists("users", false) mock.ExpectExec(escape("CREATE TABLE `users`(`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL)")). WillReturnResult(sqlmock.NewResult(0, 1)) // set users id range (without inserting to ent_types). mock.ExpectQuery(escape("SELECT COUNT(*) FROM `sqlite_sequence` WHERE `name` = ?")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1)) mock.ExpectExec(escape("UPDATE `sqlite_sequence` SET `seq` = ? WHERE `name` = ?")). WithArgs(0, "users"). WillReturnResult(sqlmock.NewResult(0, 1)) mock.tableExists("groups", false) mock.ExpectExec(escape("CREATE TABLE `groups`(`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL)")). WillReturnResult(sqlmock.NewResult(0, 1)) // set groups id range. mock.ExpectExec(escape("INSERT INTO `ent_types` (`type`) VALUES (?)")). WithArgs("groups"). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectQuery(escape("SELECT COUNT(*) FROM `sqlite_sequence` WHERE `name` = ?")). 
WithArgs("groups"). WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0)) mock.ExpectExec(escape("INSERT INTO `sqlite_sequence` (`name`, `seq`) VALUES (?, ?)")). WithArgs("groups", 1<<32). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { db, mock, err := sqlmock.New() require.NoError(t, err) tt.before(sqliteMock{mock}) migrate, err := NewMigrate(sql.OpenDB("sqlite3", db), tt.options...) require.NoError(t, err) err = migrate.Create(context.Background(), tt.tables...) require.Equal(t, tt.wantErr, err != nil, err) }) } } type sqliteMock struct { sqlmock.Sqlmock } func (m sqliteMock) start() { m.ExpectBegin() m.ExpectQuery("PRAGMA foreign_keys"). WillReturnRows(sqlmock.NewRows([]string{"foreign_keys"}).AddRow(1)) } func (m sqliteMock) tableExists(table string, exists bool) { count := 0 if exists { count = 1 } m.ExpectQuery(escape("SELECT COUNT(*) FROM `sqlite_master` WHERE `type` = ? AND `name` = ?")). WithArgs("table", table). WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(count)) } ent-0.5.4/dialect/sql/schema/writer.go000066400000000000000000000023131377533537200176330ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package schema import ( "context" "io" "strings" "github.com/facebook/ent/dialect" ) // WriteDriver is a driver that writes all driver exec operations to its writer. type WriteDriver struct { dialect.Driver // underlying driver. io.Writer // target for exec statements. } // Exec writes its query and calls the underlying driver Exec method. func (w *WriteDriver) Exec(_ context.Context, query string, _, _ interface{}) error { if !strings.HasSuffix(query, ";") { query += ";" } _, err := io.WriteString(w, query+"\n") return err } // Tx writes the transaction start. 
func (w *WriteDriver) Tx(context.Context) (dialect.Tx, error) { if _, err := io.WriteString(w, "BEGIN;\n"); err != nil { return nil, err } return w, nil } // Commit writes the transaction commit. func (w *WriteDriver) Commit() error { _, err := io.WriteString(w, "COMMIT;\n") return err } // Rollback writes the transaction rollback. func (w *WriteDriver) Rollback() error { _, err := io.WriteString(w, "ROLLBACK;\n") return err } ent-0.5.4/dialect/sql/schema/writer_test.go000066400000000000000000000030741377533537200206770ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package schema import ( "bytes" "context" "strings" "testing" "github.com/facebook/ent/dialect" "github.com/stretchr/testify/require" ) func TestWriteDriver(t *testing.T) { b := &bytes.Buffer{} w := WriteDriver{Driver: nopDriver{}, Writer: b} ctx := context.Background() tx, err := w.Tx(ctx) require.NoError(t, err) err = tx.Query(ctx, "SELECT `name` FROM `users`", nil, nil) require.NoError(t, err) err = tx.Query(ctx, "SELECT `name` FROM `users`", nil, nil) require.NoError(t, err) err = tx.Exec(ctx, "ALTER TABLE `users` ADD COLUMN `age` int", nil, nil) require.NoError(t, err) err = tx.Exec(ctx, "ALTER TABLE `users` ADD COLUMN `NAME` varchar(100);", nil, nil) require.NoError(t, err) err = tx.Query(ctx, "SELECT `name` FROM `users`", nil, nil) require.NoError(t, err) require.NoError(t, tx.Commit()) lines := strings.Split(b.String(), "\n") require.Equal(t, "BEGIN;", lines[0]) require.Equal(t, "ALTER TABLE `users` ADD COLUMN `age` int;", lines[1]) require.Equal(t, "ALTER TABLE `users` ADD COLUMN `NAME` varchar(100);", lines[2]) require.Equal(t, "COMMIT;", lines[3]) require.Empty(t, lines[4], "file ends with blank line") } type nopDriver struct { dialect.Driver } func (nopDriver) Exec(context.Context, string, interface{}, interface{}) 
error { return nil } func (nopDriver) Query(context.Context, string, interface{}, interface{}) error { return nil } ent-0.5.4/dialect/sql/sqlgraph/000077500000000000000000000000001377533537200163525ustar00rootroot00000000000000ent-0.5.4/dialect/sql/sqlgraph/entql.go000066400000000000000000000202531377533537200200260ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package sqlgraph import ( "fmt" "github.com/facebook/ent/dialect/sql" "github.com/facebook/ent/entql" ) type ( // A Schema holds a representation of ent/schema at runtime. Each Node // represents a single schema-type and its relations in the graph (storage). // // It is used for translating common graph traversal operations to the // underlying SQL storage. For example, an operation like `has_edge(E)`, // will be translated to an SQL lookup based on the relation type and the // FK configuration. Schema struct { Nodes []*Node } // A Node in the graph holds the SQL information for an ent/schema. Node struct { NodeSpec // Type holds the node type (schema name). Type string // Fields maps from field names to their spec. Fields map[string]*FieldSpec // Edges maps from edge names to their spec. Edges map[string]struct { To *Node Spec *EdgeSpec } } ) // AddE adds an edge to the graph. It fails, if one of the node // types is missing. 
// // g.AddE("pets", spec, "user", "pet") // g.AddE("friends", spec, "user", "user") // func (g *Schema) AddE(name string, spec *EdgeSpec, from, to string) error { var fromT, toT *Node for i := range g.Nodes { t := g.Nodes[i].Type if t == from { fromT = g.Nodes[i] } if t == to { toT = g.Nodes[i] } } if fromT == nil || toT == nil { return fmt.Errorf("from/to type was not found") } if fromT.Edges == nil { fromT.Edges = make(map[string]struct { To *Node Spec *EdgeSpec }) } fromT.Edges[name] = struct { To *Node Spec *EdgeSpec }{ To: toT, Spec: spec, } return nil } // MustAddE is like AddE but panics if the edge can be added to the graph. func (g *Schema) MustAddE(name string, spec *EdgeSpec, from, to string) { if err := g.AddE(name, spec, from, to); err != nil { panic(err) } } // EvalP evaluates the entql predicate on the given selector (query builder). func (g *Schema) EvalP(nodeType string, p entql.P, selector *sql.Selector) error { var node *Node for i := range g.Nodes { if g.Nodes[i].Type == nodeType { node = g.Nodes[i] break } } if node == nil { return fmt.Errorf("node %s was not found in the graph schema", nodeType) } pr, err := evalExpr(node, selector, p) if err != nil { return err } selector.Where(pr) return nil } // FuncSelector represents a selector function to be used as an entql foreign-function. const FuncSelector entql.Func = "func_selector" // wrappedFunc wraps the selector-function to an ent-expression. type wrappedFunc struct { entql.Expr Func func(*sql.Selector) } // WrapFunc wraps a selector-func with an entql call expression. 
func WrapFunc(s func(*sql.Selector)) *entql.CallExpr { return &entql.CallExpr{ Func: FuncSelector, Args: []entql.Expr{wrappedFunc{Func: s}}, } } var ( binary = [...]sql.Op{ entql.OpEQ: sql.OpEQ, entql.OpNEQ: sql.OpNEQ, entql.OpGT: sql.OpGT, entql.OpGTE: sql.OpGTE, entql.OpLT: sql.OpLT, entql.OpLTE: sql.OpLTE, entql.OpIn: sql.OpIn, entql.OpNotIn: sql.OpNotIn, } nary = [...]func(...*sql.Predicate) *sql.Predicate{ entql.OpAnd: sql.And, entql.OpOr: sql.Or, } strFunc = map[entql.Func]func(string, string) *sql.Predicate{ entql.FuncContains: sql.Contains, entql.FuncContainsFold: sql.ContainsFold, entql.FuncEqualFold: sql.EqualFold, entql.FuncHasPrefix: sql.HasPrefix, entql.FuncHasSuffix: sql.HasSuffix, } nullFunc = [...]func(string) *sql.Predicate{ entql.OpEQ: sql.IsNull, entql.OpNEQ: sql.NotNull, } ) // state represents the state of a predicate evaluation. // Note that, the evaluation output is a predicate to be // applied on the database. type state struct { sql.Builder context *Node selector *sql.Selector } // evalExpr evaluates the entql expression and returns a new SQL predicate to be applied on the database. func evalExpr(context *Node, selector *sql.Selector, expr entql.Expr) (p *sql.Predicate, err error) { ex := &state{ context: context, selector: selector, } defer catch(&err) p = ex.evalExpr(expr) return } // evalExpr evaluates any expression. func (e *state) evalExpr(expr entql.Expr) *sql.Predicate { switch expr := expr.(type) { case *entql.BinaryExpr: return e.evalBinary(expr) case *entql.UnaryExpr: return sql.Not(e.evalExpr(expr.X)) case *entql.NaryExpr: ps := make([]*sql.Predicate, len(expr.Xs)) for i, x := range expr.Xs { ps[i] = e.evalExpr(x) } return nary[expr.Op](ps...) 
case *entql.CallExpr: switch expr.Func { case entql.FuncHasPrefix, entql.FuncHasSuffix, entql.FuncContains, entql.FuncEqualFold, entql.FuncContainsFold: expect(len(expr.Args) == 2, "invalid number of arguments for %s", expr.Func) f, ok := expr.Args[0].(*entql.Field) expect(ok, "*entql.Field, got %T", expr.Args[0]) v, ok := expr.Args[1].(*entql.Value) expect(ok, "*entql.Value, got %T", expr.Args[1]) s, ok := v.V.(string) expect(ok, "string value, got %T", v.V) return strFunc[expr.Func](e.field(f), s) case entql.FuncHasEdge: expect(len(expr.Args) > 0, "invalid number of arguments for %s", expr.Func) edge, ok := expr.Args[0].(*entql.Edge) expect(ok, "*entql.Edge, got %T", expr.Args[0]) return e.evalEdge(edge.Name, expr.Args[1:]...) } } panic("invalid") } // evalBinary evaluates binary expressions. func (e *state) evalBinary(expr *entql.BinaryExpr) *sql.Predicate { switch expr.Op { case entql.OpOr: return sql.Or(e.evalExpr(expr.X), e.evalExpr(expr.Y)) case entql.OpAnd: return sql.And(e.evalExpr(expr.X), e.evalExpr(expr.Y)) case entql.OpEQ, entql.OpNEQ: if expr.Y == (*entql.Value)(nil) { f, ok := expr.X.(*entql.Field) expect(ok, "*entql.Field, got %T", expr.Y) return nullFunc[expr.Op](e.field(f)) } fallthrough default: field, ok := expr.X.(*entql.Field) expect(ok, "expr.X to be *entql.Field (got %T)", expr.X) _, ok = expr.Y.(*entql.Field) if !ok { _, ok = expr.Y.(*entql.Value) } expect(ok, "expr.Y to be *entql.Field or *entql.Value (got %T)", expr.X) return sql.P(func(b *sql.Builder) { b.Ident(e.field(field)) b.WriteOp(binary[expr.Op]) switch x := expr.Y.(type) { case *entql.Field: b.Ident(e.field(x)) case *entql.Value: args(b, x) } }) } } // evalEdge evaluates has-edge and has-edge-with calls. 
func (e *state) evalEdge(name string, exprs ...entql.Expr) *sql.Predicate { edge, ok := e.context.Edges[name] expect(ok, "edge %q was not found for node %q", name, e.context.Type) step := NewStep( From(e.context.Table, e.context.ID.Column), To(edge.To.Table, edge.To.ID.Column), Edge(edge.Spec.Rel, edge.Spec.Inverse, edge.Spec.Table, edge.Spec.Columns...), ) selector := e.selector.Clone().SetP(nil) selector.SetTotal(e.Total()) if len(exprs) == 0 { HasNeighbors(selector, step) return selector.P() } HasNeighborsWith(selector, step, func(s *sql.Selector) { for _, expr := range exprs { if cx, ok := expr.(*entql.CallExpr); ok && cx.Func == FuncSelector { expect(len(cx.Args) == 1, "invalid number of arguments for %s", FuncSelector) wrapped, ok := cx.Args[0].(wrappedFunc) expect(ok, "invalid argument for %s: %T", FuncSelector, cx.Args[0]) wrapped.Func(s) } else { p, err := evalExpr(edge.To, s, expr) expect(err == nil, "edge evaluation failed for %s->%s: %s", e.context.Type, name, err) s.Where(p) } } }) return selector.P() } func (e *state) field(f *entql.Field) string { _, ok := e.context.Fields[f.Name] expect(ok || e.context.ID.Column == f.Name, "field %q was not found for node %q", f.Name, e.context.Type) return f.Name } func args(b *sql.Builder, v *entql.Value) { vs, ok := v.V.([]interface{}) if !ok { b.Arg(v.V) return } b.Args(vs...) } // expect panics if the condition is false. func expect(cond bool, msg string, args ...interface{}) { if !cond { panic(evalError{fmt.Sprintf("expect "+msg, args...)}) } } type evalError struct { msg string } func (p evalError) Error() string { return fmt.Sprintf("sqlgraph: %s", p.msg) } func catch(err *error) { if e := recover(); e != nil { xerr, ok := e.(evalError) if !ok { panic(e) } *err = xerr } } ent-0.5.4/dialect/sql/sqlgraph/entql_test.go000066400000000000000000000151641377533537200210720ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. 
// This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package sqlgraph import ( "strconv" "testing" "github.com/facebook/ent/dialect" "github.com/facebook/ent/dialect/sql" "github.com/facebook/ent/entql" "github.com/facebook/ent/schema/field" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestGraph_AddE(t *testing.T) { g := &Schema{ Nodes: []*Node{{Type: "user"}, {Type: "pet"}}, } err := g.AddE("pets", &EdgeSpec{Rel: O2M}, "user", "pet") assert.NoError(t, err) err = g.AddE("owner", &EdgeSpec{Rel: O2M}, "pet", "user") assert.NoError(t, err) err = g.AddE("groups", &EdgeSpec{Rel: M2M}, "pet", "groups") assert.Error(t, err) } func TestGraph_EvalP(t *testing.T) { g := &Schema{ Nodes: []*Node{ { Type: "user", NodeSpec: NodeSpec{ Table: "users", ID: &FieldSpec{Column: "uid"}, }, Fields: map[string]*FieldSpec{ "name": {Column: "name", Type: field.TypeString}, "last": {Column: "last", Type: field.TypeString}, }, }, { Type: "pet", NodeSpec: NodeSpec{ Table: "pets", ID: &FieldSpec{Column: "pid"}, }, Fields: map[string]*FieldSpec{ "name": {Column: "name", Type: field.TypeString}, }, }, { Type: "group", NodeSpec: NodeSpec{ Table: "groups", ID: &FieldSpec{Column: "gid"}, }, Fields: map[string]*FieldSpec{ "name": {Column: "name", Type: field.TypeString}, }, }, }, } err := g.AddE("pets", &EdgeSpec{Rel: O2M, Table: "pets", Columns: []string{"owner_id"}}, "user", "pet") require.NoError(t, err) err = g.AddE("owner", &EdgeSpec{Rel: M2O, Inverse: true, Table: "pets", Columns: []string{"owner_id"}}, "pet", "user") require.NoError(t, err) err = g.AddE("groups", &EdgeSpec{Rel: M2M, Table: "user_groups", Columns: []string{"user_id", "group_id"}}, "user", "group") require.NoError(t, err) err = g.AddE("users", &EdgeSpec{Rel: M2M, Inverse: true, Table: "user_groups", Columns: []string{"user_id", "group_id"}}, "group", "user") require.NoError(t, err) tests := []struct { s 
*sql.Selector p entql.P wantQuery string wantArgs []interface{} wantErr bool }{ { s: sql.Dialect(dialect.Postgres).Select().From(sql.Table("users")), p: entql.FieldHasPrefix("name", "a"), wantQuery: `SELECT * FROM "users" WHERE "name" LIKE $1`, wantArgs: []interface{}{"a%"}, }, { s: sql.Dialect(dialect.Postgres).Select().From(sql.Table("users")). Where(sql.EQ("age", 1)), p: entql.FieldHasPrefix("name", "a"), wantQuery: `SELECT * FROM "users" WHERE "age" = $1 AND "name" LIKE $2`, wantArgs: []interface{}{1, "a%"}, }, { s: sql.Dialect(dialect.Postgres).Select().From(sql.Table("users")). Where(sql.EQ("age", 1)), p: entql.FieldHasPrefix("name", "a"), wantQuery: `SELECT * FROM "users" WHERE "age" = $1 AND "name" LIKE $2`, wantArgs: []interface{}{1, "a%"}, }, { s: sql.Dialect(dialect.Postgres).Select().From(sql.Table("users")), p: entql.EQ(entql.F("name"), entql.F("last")), wantQuery: `SELECT * FROM "users" WHERE "name" = "last"`, }, { s: sql.Dialect(dialect.Postgres).Select().From(sql.Table("users")), p: entql.EQ(entql.F("name"), entql.F("last")), wantQuery: `SELECT * FROM "users" WHERE "name" = "last"`, }, { s: sql.Dialect(dialect.Postgres).Select().From(sql.Table("users")), p: entql.And(entql.FieldNil("name"), entql.FieldNotNil("last")), wantQuery: `SELECT * FROM "users" WHERE "name" IS NULL AND "last" IS NOT NULL`, }, { s: sql.Dialect(dialect.Postgres).Select().From(sql.Table("users")). 
Where(sql.EQ("foo", "bar")), p: entql.Or(entql.FieldEQ("name", "foo"), entql.FieldEQ("name", "baz")), wantQuery: `SELECT * FROM "users" WHERE "foo" = $1 AND ("name" = $2 OR "name" = $3)`, wantArgs: []interface{}{"bar", "foo", "baz"}, }, { s: sql.Dialect(dialect.Postgres).Select().From(sql.Table("users")), p: entql.HasEdge("pets"), wantQuery: `SELECT * FROM "users" WHERE "users"."uid" IN (SELECT "pets"."owner_id" FROM "pets" WHERE "pets"."owner_id" IS NOT NULL)`, }, { s: sql.Dialect(dialect.Postgres).Select().From(sql.Table("users")), p: entql.HasEdge("groups"), wantQuery: `SELECT * FROM "users" WHERE "users"."uid" IN (SELECT "user_groups"."user_id" FROM "user_groups")`, }, { s: sql.Dialect(dialect.Postgres).Select().From(sql.Table("users")), p: entql.HasEdgeWith("pets", entql.Or(entql.FieldEQ("name", "pedro"), entql.FieldEQ("name", "xabi"))), wantQuery: `SELECT * FROM "users" WHERE "users"."uid" IN (SELECT "pets"."owner_id" FROM "pets" WHERE "name" = $1 OR "name" = $2)`, wantArgs: []interface{}{"pedro", "xabi"}, }, { s: sql.Dialect(dialect.Postgres).Select().From(sql.Table("users")).Where(sql.EQ("active", true)), p: entql.HasEdgeWith("groups", entql.Or(entql.FieldEQ("name", "GitHub"), entql.FieldEQ("name", "GitLab"))), wantQuery: `SELECT * FROM "users" WHERE "active" = $1 AND "users"."uid" IN (SELECT "user_groups"."user_id" FROM "user_groups" JOIN "groups" AS "t0" ON "user_groups"."group_id" = "t0"."gid" WHERE "name" = $2 OR "name" = $3)`, wantArgs: []interface{}{true, "GitHub", "GitLab"}, }, { s: sql.Dialect(dialect.Postgres).Select().From(sql.Table("users")).Where(sql.EQ("active", true)), p: entql.And(entql.HasEdge("pets"), entql.HasEdge("groups"), entql.EQ(entql.F("name"), entql.F("uid"))), wantQuery: `SELECT * FROM "users" WHERE "active" = $1 AND ("users"."uid" IN (SELECT "pets"."owner_id" FROM "pets" WHERE "pets"."owner_id" IS NOT NULL) AND "users"."uid" IN (SELECT "user_groups"."user_id" FROM "user_groups") AND "name" = "uid")`, wantArgs: []interface{}{true}, 
}, { s: sql.Dialect(dialect.Postgres).Select().From(sql.Table("users")).Where(sql.EQ("active", true)), p: entql.HasEdgeWith("pets", entql.FieldEQ("name", "pedro"), WrapFunc(func(s *sql.Selector) { s.Where(sql.EQ("owner_id", 10)) })), wantQuery: `SELECT * FROM "users" WHERE "active" = $1 AND "users"."uid" IN (SELECT "pets"."owner_id" FROM "pets" WHERE "name" = $2 AND "owner_id" = $3)`, wantArgs: []interface{}{true, "pedro", 10}, }, } for i, tt := range tests { t.Run(strconv.Itoa(i), func(t *testing.T) { err = g.EvalP("user", tt.p, tt.s) require.Equal(t, tt.wantErr, err != nil, err) query, args := tt.s.Query() require.Equal(t, tt.wantQuery, query) require.Equal(t, tt.wantArgs, args) }) } } ent-0.5.4/dialect/sql/sqlgraph/graph.go000066400000000000000000001060721377533537200200100ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. // Package sqlgraph provides graph abstraction capabilities on top // of sql-based databases for ent codegen. package sqlgraph import ( "context" "database/sql/driver" "encoding/json" "fmt" "math" "sort" "github.com/facebook/ent/dialect" "github.com/facebook/ent/dialect/sql" "github.com/facebook/ent/schema/field" ) // Rel is a relation type of an edge. type Rel int // Relation types. const ( _ Rel = iota // Unknown. O2O // One to one / has one. O2M // One to many / has many. M2O // Many to one (inverse perspective for O2M). M2M // Many to many. ) // String returns the relation name. func (r Rel) String() (s string) { switch r { case O2O: s = "O2O" case O2M: s = "O2M" case M2O: s = "M2O" case M2M: s = "M2M" default: s = "Unknown" } return s } // A ConstraintError represents an error from mutation that violates a specific constraint. 
type ConstraintError struct { msg string } func (e ConstraintError) Error() string { return e.msg } // A Step provides a path-step information to the traversal functions. type Step struct { // From is the source of the step. From struct { // V can be either one vertex or set of vertices. // It can be a pre-processed step (sql.Query) or a simple Go type (integer or string). V interface{} // Table holds the table name of V (from). Table string // Column to join with. Usually the "id" column. Column string } // Edge holds the edge information for getting the neighbors. Edge struct { // Rel of the edge. Rel Rel // Table name of where this edge columns reside. Table string // Columns of the edge. // In O2O and M2O, it holds the foreign-key column. Hence, len == 1. // In M2M, it holds the primary-key columns of the join table. Hence, len == 2. Columns []string // Inverse indicates if the edge is an inverse edge. Inverse bool } // To is the dest of the path (the neighbors). To struct { // Table holds the table name of the neighbors (to). Table string // Column to join with. Usually the "id" column. Column string } } // StepOption allows configuring Steps using functional options. type StepOption func(*Step) // From sets the source of the step. func From(table, column string, v ...interface{}) StepOption { return func(s *Step) { s.From.Table = table s.From.Column = column if len(v) > 0 { s.From.V = v[0] } } } // To sets the destination of the step. func To(table, column string) StepOption { return func(s *Step) { s.To.Table = table s.To.Column = column } } // Edge sets the edge info for getting the neighbors. func Edge(rel Rel, inverse bool, table string, columns ...string) StepOption { return func(s *Step) { s.Edge.Rel = rel s.Edge.Table = table s.Edge.Columns = columns s.Edge.Inverse = inverse } } // NewStep gets list of options and returns a configured step. 
// // NewStep( // From("table", "pk", V), // To("table", "pk"), // Edge("name", O2M, "fk"), // ) // func NewStep(opts ...StepOption) *Step { s := &Step{} for _, opt := range opts { opt(s) } return s } // Neighbors returns a Selector for evaluating the path-step // and getting the neighbors of one vertex. func Neighbors(dialect string, s *Step) (q *sql.Selector) { builder := sql.Dialect(dialect) switch r := s.Edge.Rel; { case r == M2M: pk1, pk2 := s.Edge.Columns[1], s.Edge.Columns[0] if s.Edge.Inverse { pk1, pk2 = pk2, pk1 } to := builder.Table(s.To.Table) join := builder.Table(s.Edge.Table) match := builder.Select(join.C(pk1)). From(join). Where(sql.EQ(join.C(pk2), s.From.V)) q = builder.Select(). From(to). Join(match). On(to.C(s.To.Column), match.C(pk1)) case r == M2O || (r == O2O && s.Edge.Inverse): t1 := builder.Table(s.To.Table) t2 := builder.Select(s.Edge.Columns[0]). From(builder.Table(s.Edge.Table)). Where(sql.EQ(s.From.Column, s.From.V)) q = builder.Select(). From(t1). Join(t2). On(t1.C(s.To.Column), t2.C(s.Edge.Columns[0])) case r == O2M || (r == O2O && !s.Edge.Inverse): q = builder.Select(). From(builder.Table(s.To.Table)). Where(sql.EQ(s.Edge.Columns[0], s.From.V)) } return q } // SetNeighbors returns a Selector for evaluating the path-step // and getting the neighbors of set of vertices. func SetNeighbors(dialect string, s *Step) (q *sql.Selector) { set := s.From.V.(*sql.Selector) builder := sql.Dialect(dialect) switch r := s.Edge.Rel; { case r == M2M: pk1, pk2 := s.Edge.Columns[1], s.Edge.Columns[0] if s.Edge.Inverse { pk1, pk2 = pk2, pk1 } to := builder.Table(s.To.Table) set.Select(set.C(s.From.Column)) join := builder.Table(s.Edge.Table) match := builder.Select(join.C(pk1)). From(join). Join(set). On(join.C(pk2), set.C(s.From.Column)) q = builder.Select(). From(to). Join(match). On(to.C(s.To.Column), match.C(pk1)) case r == M2O || (r == O2O && s.Edge.Inverse): t1 := builder.Table(s.To.Table) set.Select(set.C(s.Edge.Columns[0])) q = builder.Select(). 
From(t1). Join(set). On(t1.C(s.To.Column), set.C(s.Edge.Columns[0])) case r == O2M || (r == O2O && !s.Edge.Inverse): t1 := builder.Table(s.To.Table) set.Select(set.C(s.From.Column)) q = builder.Select(). From(t1). Join(set). On(t1.C(s.Edge.Columns[0]), set.C(s.From.Column)) } return q } // HasNeighbors applies on the given Selector a neighbors check. func HasNeighbors(q *sql.Selector, s *Step) { builder := sql.Dialect(q.Dialect()) switch r := s.Edge.Rel; { case r == M2M: pk1 := s.Edge.Columns[0] if s.Edge.Inverse { pk1 = s.Edge.Columns[1] } from := q.Table() join := builder.Table(s.Edge.Table) q.Where( sql.In( from.C(s.From.Column), builder.Select(join.C(pk1)).From(join), ), ) case r == M2O || (r == O2O && s.Edge.Inverse): from := q.Table() q.Where(sql.NotNull(from.C(s.Edge.Columns[0]))) case r == O2M || (r == O2O && !s.Edge.Inverse): from := q.Table() to := builder.Table(s.Edge.Table) q.Where( sql.In( from.C(s.From.Column), builder.Select(to.C(s.Edge.Columns[0])). From(to). Where(sql.NotNull(to.C(s.Edge.Columns[0]))), ), ) } } // HasNeighborsWith applies on the given Selector a neighbors check. // The given predicate applies its filtering on the selector. func HasNeighborsWith(q *sql.Selector, s *Step, pred func(*sql.Selector)) { builder := sql.Dialect(q.Dialect()) switch r := s.Edge.Rel; { case r == M2M: pk1, pk2 := s.Edge.Columns[1], s.Edge.Columns[0] if s.Edge.Inverse { pk1, pk2 = pk2, pk1 } from := q.Table() to := builder.Table(s.To.Table) edge := builder.Table(s.Edge.Table) join := builder.Select(edge.C(pk2)). From(edge). Join(to). On(edge.C(pk1), to.C(s.To.Column)) matches := builder.Select().From(to) pred(matches) join.FromSelect(matches) q.Where(sql.In(from.C(s.From.Column), join)) case r == M2O || (r == O2O && s.Edge.Inverse): from := q.Table() to := builder.Table(s.To.Table) matches := builder.Select(to.C(s.To.Column)). 
From(to) pred(matches) q.Where(sql.In(from.C(s.Edge.Columns[0]), matches)) case r == O2M || (r == O2O && !s.Edge.Inverse): from := q.Table() to := builder.Table(s.Edge.Table) matches := builder.Select(to.C(s.Edge.Columns[0])). From(to) pred(matches) q.Where(sql.In(from.C(s.From.Column), matches)) } } type ( // FieldSpec holds the information for updating a field // column in the database. FieldSpec struct { Column string Type field.Type Value driver.Value // value to be stored. } // EdgeTarget holds the information for the target nodes // of an edge. EdgeTarget struct { Nodes []driver.Value IDSpec *FieldSpec } // EdgeSpec holds the information for updating a field // column in the database. EdgeSpec struct { Rel Rel Inverse bool Table string Columns []string Bidi bool // bidirectional edge. Target *EdgeTarget // target nodes. } // EdgeSpecs used for perform common operations on list of edges. EdgeSpecs []*EdgeSpec // NodeSpec defines the information for querying and // decoding nodes in the graph. NodeSpec struct { Table string Columns []string ID *FieldSpec } ) type ( // CreateSpec holds the information for creating // a node in the graph. CreateSpec struct { Table string ID *FieldSpec Fields []*FieldSpec Edges []*EdgeSpec } // BatchCreateSpec holds the information for creating // multiple nodes in the graph. BatchCreateSpec struct { Nodes []*CreateSpec } ) // CreateNode applies the CreateSpec on the graph. func CreateNode(ctx context.Context, drv dialect.Driver, spec *CreateSpec) error { tx, err := drv.Tx(ctx) if err != nil { return err } gr := graph{tx: tx, builder: sql.Dialect(drv.Dialect())} cr := &creator{CreateSpec: spec, graph: gr} if err := cr.node(ctx, tx); err != nil { return rollback(tx, err) } return tx.Commit() } // BatchCreate applies the BatchCreateSpec on the graph. 
func BatchCreate(ctx context.Context, drv dialect.Driver, spec *BatchCreateSpec) error { tx, err := drv.Tx(ctx) if err != nil { return err } gr := graph{tx: tx, builder: sql.Dialect(drv.Dialect())} cr := &creator{BatchCreateSpec: spec, graph: gr} if err := cr.nodes(ctx, tx); err != nil { return rollback(tx, err) } return tx.Commit() } type ( // EdgeMut defines edge mutations. EdgeMut struct { Add []*EdgeSpec Clear []*EdgeSpec } // FieldMut defines field mutations. FieldMut struct { Set []*FieldSpec // field = ? Add []*FieldSpec // field = field + ? Clear []*FieldSpec // field = NULL } // UpdateSpec holds the information for updating one // or more nodes in the graph. UpdateSpec struct { Node *NodeSpec Edges EdgeMut Fields FieldMut Predicate func(*sql.Selector) ScanValues func(columns []string) ([]interface{}, error) Assign func(columns []string, values []interface{}) error } ) // UpdateNode applies the UpdateSpec on one node in the graph. func UpdateNode(ctx context.Context, drv dialect.Driver, spec *UpdateSpec) error { tx, err := drv.Tx(ctx) if err != nil { return err } gr := graph{tx: tx, builder: sql.Dialect(drv.Dialect())} cr := &updater{UpdateSpec: spec, graph: gr} if err := cr.node(ctx, tx); err != nil { return rollback(tx, err) } return tx.Commit() } // UpdateNodes applies the UpdateSpec on a set of nodes in the graph. func UpdateNodes(ctx context.Context, drv dialect.Driver, spec *UpdateSpec) (int, error) { tx, err := drv.Tx(ctx) if err != nil { return 0, err } gr := graph{tx: tx, builder: sql.Dialect(drv.Dialect())} cr := &updater{UpdateSpec: spec, graph: gr} affected, err := cr.nodes(ctx, tx) if err != nil { return 0, rollback(tx, err) } return affected, tx.Commit() } // NotFoundError returns when trying to update an // entity and it was not found in the database. 
type NotFoundError struct { table string id driver.Value } func (e *NotFoundError) Error() string { return fmt.Sprintf("record with id %v not found in table %s", e.id, e.table) } // DeleteSpec holds the information for delete one // or more nodes in the graph. type DeleteSpec struct { Node *NodeSpec Predicate func(*sql.Selector) } // DeleteNodes applies the DeleteSpec on the graph. func DeleteNodes(ctx context.Context, drv dialect.Driver, spec *DeleteSpec) (int, error) { tx, err := drv.Tx(ctx) if err != nil { return 0, err } var ( res sql.Result builder = sql.Dialect(drv.Dialect()) ) selector := builder.Select(). From(builder.Table(spec.Node.Table)) if pred := spec.Predicate; pred != nil { pred(selector) } query, args := builder.Delete(spec.Node.Table).FromSelect(selector).Query() if err := tx.Exec(ctx, query, args, &res); err != nil { return 0, rollback(tx, err) } affected, err := res.RowsAffected() if err != nil { return 0, rollback(tx, err) } return int(affected), tx.Commit() } // QuerySpec holds the information for querying // nodes in the graph. type QuerySpec struct { Node *NodeSpec // Nodes info. From *sql.Selector // Optional query source (from path). Limit int Offset int Unique bool Order func(*sql.Selector) Predicate func(*sql.Selector) ScanValues func(columns []string) ([]interface{}, error) Assign func(columns []string, values []interface{}) error } // QueryNodes queries the nodes in the graph query and scans them to the given values. func QueryNodes(ctx context.Context, drv dialect.Driver, spec *QuerySpec) error { builder := sql.Dialect(drv.Dialect()) qr := &query{graph: graph{builder: builder}, QuerySpec: spec} return qr.nodes(ctx, drv) } // CountNodes counts the nodes in the given graph query. 
func CountNodes(ctx context.Context, drv dialect.Driver, spec *QuerySpec) (int, error) { builder := sql.Dialect(drv.Dialect()) qr := &query{graph: graph{builder: builder}, QuerySpec: spec} return qr.count(ctx, drv) } // EdgeQuerySpec holds the information for querying // edges in the graph. type EdgeQuerySpec struct { Edge *EdgeSpec Predicate func(*sql.Selector) ScanValues func() [2]interface{} Assign func(out, in interface{}) error } // QueryEdges queries the edges in the graph and scans the result with the given dest function. func QueryEdges(ctx context.Context, drv dialect.Driver, spec *EdgeQuerySpec) error { if len(spec.Edge.Columns) != 2 { return fmt.Errorf("sqlgraph: edge query requires 2 columns (out, in)") } out, in := spec.Edge.Columns[0], spec.Edge.Columns[1] if spec.Edge.Inverse { out, in = in, out } selector := sql.Dialect(drv.Dialect()). Select(out, in). From(sql.Table(spec.Edge.Table)) if p := spec.Predicate; p != nil { p(selector) } rows := &sql.Rows{} query, args := selector.Query() if err := drv.Query(ctx, query, args, rows); err != nil { return err } defer rows.Close() for rows.Next() { values := spec.ScanValues() if err := rows.Scan(values[0], values[1]); err != nil { return err } if err := spec.Assign(values[0], values[1]); err != nil { return err } } return rows.Err() } type query struct { graph *QuerySpec } func (q *query) nodes(ctx context.Context, drv dialect.Driver) error { rows := &sql.Rows{} selector, err := q.selector() if err != nil { return err } query, args := selector.Query() if err := drv.Query(ctx, query, args, rows); err != nil { return err } defer rows.Close() columns, err := rows.Columns() if err != nil { return err } for rows.Next() { values, err := q.ScanValues(columns) if err != nil { return err } if err := rows.Scan(values...); err != nil { return err } if err := q.Assign(columns, values); err != nil { return err } } return rows.Err() } func (q *query) count(ctx context.Context, drv dialect.Driver) (int, error) { rows := 
&sql.Rows{} selector, err := q.selector() if err != nil { return 0, err } selector.Count(selector.C(q.Node.ID.Column)) if q.Unique { selector.SetDistinct(false) selector.Count(sql.Distinct(selector.C(q.Node.ID.Column))) } query, args := selector.Query() if err := drv.Query(ctx, query, args, rows); err != nil { return 0, err } defer rows.Close() return sql.ScanInt(rows) } func (q *query) selector() (*sql.Selector, error) { selector := q.builder.Select().From(q.builder.Table(q.Node.Table)) if q.From != nil { selector = q.From } selector.Select(selector.Columns(q.Node.Columns...)...) if pred := q.Predicate; pred != nil { pred(selector) } if order := q.Order; order != nil { order(selector) } if q.Offset != 0 { // Limit is mandatory for the offset clause. We start // with default value, and override it below if needed. selector.Offset(q.Offset).Limit(math.MaxInt32) } if q.Limit != 0 { selector.Limit(q.Limit) } if q.Unique { selector.Distinct() } if err := selector.Err(); err != nil { return nil, err } return selector, nil } type updater struct { graph *UpdateSpec } func (u *updater) node(ctx context.Context, tx dialect.ExecQuerier) error { var ( // id holds the PK of the node used for linking // it with the other nodes. id = u.Node.ID.Value addEdges = EdgeSpecs(u.Edges.Add).GroupRel() clearEdges = EdgeSpecs(u.Edges.Clear).GroupRel() ) update := u.builder.Update(u.Node.Table).Where(sql.EQ(u.Node.ID.Column, id)) if err := u.setTableColumns(update, addEdges, clearEdges); err != nil { return err } if !update.Empty() { var res sql.Result query, args := update.Query() if err := tx.Exec(ctx, query, args, &res); err != nil { return err } } if err := u.setExternalEdges(ctx, []driver.Value{id}, addEdges, clearEdges); err != nil { return err } selector := u.builder.Select(u.Node.Columns...). From(u.builder.Table(u.Node.Table)). 
Where(sql.EQ(u.Node.ID.Column, u.Node.ID.Value)) rows := &sql.Rows{} query, args := selector.Query() if err := tx.Query(ctx, query, args, rows); err != nil { return err } return u.scan(rows) } func (u *updater) nodes(ctx context.Context, tx dialect.ExecQuerier) (int, error) { var ( ids []driver.Value addEdges = EdgeSpecs(u.Edges.Add).GroupRel() clearEdges = EdgeSpecs(u.Edges.Clear).GroupRel() multiple = u.hasExternalEdges(addEdges, clearEdges) update = u.builder.Update(u.Node.Table) selector = u.builder.Select(u.Node.ID.Column). From(u.builder.Table(u.Node.Table)) ) if pred := u.Predicate; pred != nil { pred(selector) } // If this change-set contains multiple table updates. if multiple { query, args := selector.Query() rows := &sql.Rows{} if err := u.tx.Query(ctx, query, args, rows); err != nil { return 0, fmt.Errorf("querying table %s: %v", u.Node.Table, err) } defer rows.Close() if err := sql.ScanSlice(rows, &ids); err != nil { return 0, fmt.Errorf("scan node ids: %v", err) } if err := rows.Close(); err != nil { return 0, err } if len(ids) == 0 { return 0, nil } update.Where(matchID(u.Node.ID.Column, ids)) } else { update.FromSelect(selector) } if err := u.setTableColumns(update, addEdges, clearEdges); err != nil { return 0, err } if !update.Empty() { var res sql.Result query, args := update.Query() if err := tx.Exec(ctx, query, args, &res); err != nil { return 0, err } if !multiple { affected, err := res.RowsAffected() if err != nil { return 0, err } return int(affected), nil } } if len(ids) > 0 { if err := u.setExternalEdges(ctx, ids, addEdges, clearEdges); err != nil { return 0, err } } return len(ids), nil } func (u *updater) setExternalEdges(ctx context.Context, ids []driver.Value, addEdges, clearEdges map[Rel][]*EdgeSpec) error { if err := u.graph.clearM2MEdges(ctx, ids, clearEdges[M2M]); err != nil { return err } if err := u.graph.addM2MEdges(ctx, ids, addEdges[M2M]); err != nil { return err } if err := u.graph.clearFKEdges(ctx, ids, 
append(clearEdges[O2M], clearEdges[O2O]...)); err != nil { return err } if err := u.graph.addFKEdges(ctx, ids, append(addEdges[O2M], addEdges[O2O]...)); err != nil { return err } return nil } func (*updater) hasExternalEdges(addEdges, clearEdges map[Rel][]*EdgeSpec) bool { // M2M edges reside in a join-table, and O2M edges reside // in the M2O table (the entity that holds the FK). if len(clearEdges[M2M]) > 0 || len(addEdges[M2M]) > 0 || len(clearEdges[O2M]) > 0 || len(addEdges[O2M]) > 0 { return true } for _, edges := range [][]*EdgeSpec{clearEdges[O2O], addEdges[O2O]} { for _, e := range edges { if !e.Inverse { return true } } } return false } // setTableColumns sets the table columns and foreign_keys used in insert. func (u *updater) setTableColumns(update *sql.UpdateBuilder, addEdges, clearEdges map[Rel][]*EdgeSpec) error { // Avoid multiple assignments to the same column. setEdges := make(map[string]bool) for _, e := range addEdges[M2O] { setEdges[e.Columns[0]] = true } for _, e := range addEdges[O2O] { if e.Inverse || e.Bidi { setEdges[e.Columns[0]] = true } } for _, fi := range u.Fields.Clear { update.SetNull(fi.Column) } for _, e := range clearEdges[M2O] { if col := e.Columns[0]; !setEdges[col] { update.SetNull(col) } } for _, e := range clearEdges[O2O] { col := e.Columns[0] if (e.Inverse || e.Bidi) && !setEdges[col] { update.SetNull(col) } } err := setTableColumns(u.Fields.Set, addEdges, func(column string, value driver.Value) { update.Set(column, value) }) if err != nil { return err } for _, fi := range u.Fields.Add { update.Add(fi.Column, fi.Value) } return nil } func (u *updater) scan(rows *sql.Rows) error { defer rows.Close() columns, err := rows.Columns() if err != nil { return err } if !rows.Next() { if err := rows.Err(); err != nil { return err } return &NotFoundError{table: u.Node.Table, id: u.Node.ID.Value} } values, err := u.ScanValues(columns) if err != nil { return err } if err := rows.Scan(values...); err != nil { return fmt.Errorf("failed 
scanning rows: %v", err) } if err := u.Assign(columns, values); err != nil { return err } return nil } type creator struct { graph *CreateSpec *BatchCreateSpec } func (c *creator) node(ctx context.Context, tx dialect.ExecQuerier) error { var ( edges = EdgeSpecs(c.Edges).GroupRel() insert = c.builder.Insert(c.Table).Default() ) // Set and create the node. if err := c.setTableColumns(insert, edges); err != nil { return err } if err := c.insert(ctx, tx, insert); err != nil { return fmt.Errorf("insert node to table %q: %v", c.Table, err) } if err := c.graph.addM2MEdges(ctx, []driver.Value{c.ID.Value}, edges[M2M]); err != nil { return err } if err := c.graph.addFKEdges(ctx, []driver.Value{c.ID.Value}, append(edges[O2M], edges[O2O]...)); err != nil { return err } return nil } func (c *creator) nodes(ctx context.Context, tx dialect.ExecQuerier) error { if len(c.Nodes) == 0 { return nil } columns := make(map[string]struct{}) values := make([]map[string]driver.Value, len(c.Nodes)) for i, node := range c.Nodes { if i > 0 && node.Table != c.Nodes[i-1].Table { return fmt.Errorf("more than 1 table for batch insert: %q != %q", node.Table, c.Nodes[i-1].Table) } values[i] = make(map[string]driver.Value) if node.ID.Value != nil { columns[node.ID.Column] = struct{}{} values[i][node.ID.Column] = node.ID.Value } edges := EdgeSpecs(node.Edges).GroupRel() err := setTableColumns(node.Fields, edges, func(column string, value driver.Value) { columns[column] = struct{}{} values[i][column] = value }) if err != nil { return err } } for column := range columns { for i := range values { switch _, exists := values[i][column]; { case column == c.Nodes[i].ID.Column && !exists: // If the ID value was provided to one of the nodes, it should be // provided to all others because this affects the way we calculate // their values in MySQL and SQLite dialects. return fmt.Errorf("incosistent id values for batch insert") case !exists: // Assign NULL values for empty placeholders. 
values[i][column] = nil } } } sorted := keys(columns) insert := c.builder.Insert(c.Nodes[0].Table).Default().Columns(sorted...) for i := range values { vs := make([]interface{}, len(sorted)) for j, c := range sorted { vs[j] = values[i][c] } insert.Values(vs...) } if err := c.batchInsert(ctx, tx, insert); err != nil { return fmt.Errorf("insert nodes to table %q: %v", c.Nodes[0].Table, err) } if err := c.batchAddM2M(ctx, c.BatchCreateSpec); err != nil { return err } // FKs that exist in different tables can't be updated in batch (using the CASE // statement), because we rely on RowsAffected to check if the FK column is NULL. for _, node := range c.Nodes { edges := EdgeSpecs(node.Edges).GroupRel() if err := c.graph.addFKEdges(ctx, []driver.Value{node.ID.Value}, append(edges[O2M], edges[O2O]...)); err != nil { return err } } return nil } // setTableColumns sets the table columns and foreign_keys used in insert. func (c *creator) setTableColumns(insert *sql.InsertBuilder, edges map[Rel][]*EdgeSpec) error { err := setTableColumns(c.Fields, edges, func(column string, value driver.Value) { insert.Set(column, value) }) return err } // insert inserts the node to its table and sets its ID if it wasn't provided by the user. func (c *creator) insert(ctx context.Context, tx dialect.ExecQuerier, insert *sql.InsertBuilder) error { var res sql.Result // If the id field was provided by the user. if c.ID.Value != nil { insert.Set(c.ID.Column, c.ID.Value) query, args := insert.Query() return tx.Exec(ctx, query, args, &res) } id, err := insertLastID(ctx, tx, insert.Returning(c.ID.Column)) if err != nil { return err } c.ID.Value = id return nil } // batchInsert inserts a batch of nodes to their table and sets their ID if it wasn't provided by the user. 
func (c *creator) batchInsert(ctx context.Context, tx dialect.ExecQuerier, insert *sql.InsertBuilder) error { ids, err := insertLastIDs(ctx, tx, insert.Returning(c.Nodes[0].ID.Column)) if err != nil { return err } for i, node := range c.Nodes { node.ID.Value = ids[i] } return nil } // GroupRel groups edges by their relation type. func (es EdgeSpecs) GroupRel() map[Rel][]*EdgeSpec { edges := make(map[Rel][]*EdgeSpec) for _, edge := range es { edges[edge.Rel] = append(edges[edge.Rel], edge) } return edges } // GroupTable groups edges by their table name. func (es EdgeSpecs) GroupTable() map[string][]*EdgeSpec { edges := make(map[string][]*EdgeSpec) for _, edge := range es { edges[edge.Table] = append(edges[edge.Table], edge) } return edges } // FilterRel returns edges for the given relation type. func (es EdgeSpecs) FilterRel(r Rel) EdgeSpecs { edges := make([]*EdgeSpec, 0, len(es)) for _, edge := range es { if edge.Rel == r { edges = append(edges, edge) } } return edges } // The common operations shared between the different builders. // // M2M edges reside in join tables and require INSERT and DELETE // queries for adding or removing edges respectively. // // O2M and non-inverse O2O edges also reside in external tables, // but use UPDATE queries (fk = ?, fk = NULL). type graph struct { tx dialect.ExecQuerier builder *sql.DialectBuilder } func (g *graph) clearM2MEdges(ctx context.Context, ids []driver.Value, edges EdgeSpecs) error { var ( res sql.Result // Remove all M2M edges from the same type at once. // The EdgeSpec is the same for all members in a group. tables = edges.GroupTable() ) for _, table := range edgeKeys(tables) { edges := tables[table] preds := make([]*sql.Predicate, 0, len(edges)) for _, edge := range edges { fromC, toC := edge.Columns[0], edge.Columns[1] if edge.Inverse { fromC, toC = toC, fromC } // If there are no specific edges (to target-nodes) to remove, // clear all edges that go out (or come in) from the nodes. 
if len(edge.Target.Nodes) == 0 { preds = append(preds, matchID(fromC, ids)) if edge.Bidi { preds = append(preds, matchID(toC, ids)) } } else { pk1, pk2 := ids, edge.Target.Nodes preds = append(preds, matchIDs(fromC, pk1, toC, pk2)) if edge.Bidi { preds = append(preds, matchIDs(toC, pk1, fromC, pk2)) } } } query, args := g.builder.Delete(table).Where(sql.Or(preds...)).Query() if err := g.tx.Exec(ctx, query, args, &res); err != nil { return fmt.Errorf("remove m2m edge for table %s: %v", table, err) } } return nil } func (g *graph) addM2MEdges(ctx context.Context, ids []driver.Value, edges EdgeSpecs) error { var ( res sql.Result // Insert all M2M edges from the same type at once. // The EdgeSpec is the same for all members in a group. tables = edges.GroupTable() ) for _, table := range edgeKeys(tables) { edges := tables[table] insert := g.builder.Insert(table).Columns(edges[0].Columns...) for _, edge := range edges { pk1, pk2 := ids, edge.Target.Nodes if edge.Inverse { pk1, pk2 = pk2, pk1 } for _, pair := range product(pk1, pk2) { insert.Values(pair[0], pair[1]) if edge.Bidi { insert.Values(pair[1], pair[0]) } } } query, args := insert.Query() if err := g.tx.Exec(ctx, query, args, &res); err != nil { return fmt.Errorf("add m2m edge for table %s: %v", table, err) } } return nil } func (g *graph) batchAddM2M(ctx context.Context, spec *BatchCreateSpec) error { tables := make(map[string]*sql.InsertBuilder) for _, node := range spec.Nodes { edges := EdgeSpecs(node.Edges).FilterRel(M2M) for t, edges := range edges.GroupTable() { insert, ok := tables[t] if !ok { insert = g.builder.Insert(t).Columns(edges[0].Columns...) 
} tables[t] = insert if len(edges) != 1 { return fmt.Errorf("expect exactly 1 edge-spec per table, but got %d", len(edges)) } edge := edges[0] pk1, pk2 := []driver.Value{node.ID.Value}, edge.Target.Nodes if edge.Inverse { pk1, pk2 = pk2, pk1 } for _, pair := range product(pk1, pk2) { insert.Values(pair[0], pair[1]) if edge.Bidi { insert.Values(pair[1], pair[0]) } } } } for _, table := range insertKeys(tables) { var ( res sql.Result query, args = tables[table].Query() ) if err := g.tx.Exec(ctx, query, args, &res); err != nil { return fmt.Errorf("add m2m edge for table %s: %v", table, err) } } return nil } func (g *graph) clearFKEdges(ctx context.Context, ids []driver.Value, edges []*EdgeSpec) error { for _, edge := range edges { if edge.Rel == O2O && edge.Inverse { continue } // O2O relations can be cleared without // passing the target ids. pred := matchID(edge.Columns[0], ids) if nodes := edge.Target.Nodes; len(nodes) > 0 { pred = matchIDs(edge.Target.IDSpec.Column, edge.Target.Nodes, edge.Columns[0], ids) } query, args := g.builder.Update(edge.Table). SetNull(edge.Columns[0]). Where(pred). Query() var res sql.Result if err := g.tx.Exec(ctx, query, args, &res); err != nil { return fmt.Errorf("add %s edge for table %s: %v", edge.Rel, edge.Table, err) } } return nil } func (g *graph) addFKEdges(ctx context.Context, ids []driver.Value, edges []*EdgeSpec) error { id := ids[0] if len(ids) > 1 && len(edges) != 0 { // O2M and O2O edges are defined by a FK in the "other" table. // Therefore, ids[i+1] will override ids[i] which is invalid. return fmt.Errorf("unable to link FK edge to more than 1 node: %v", ids) } for _, edge := range edges { if edge.Rel == O2O && edge.Inverse { continue } p := sql.EQ(edge.Target.IDSpec.Column, edge.Target.Nodes[0]) // Use "IN" predicate instead of list of "OR" // in case of more than on nodes to connect. if len(edge.Target.Nodes) > 1 { p = sql.InValues(edge.Target.IDSpec.Column, edge.Target.Nodes...) 
} query, args := g.builder.Update(edge.Table). Set(edge.Columns[0], id). Where(sql.And(p, sql.IsNull(edge.Columns[0]))). Query() var res sql.Result if err := g.tx.Exec(ctx, query, args, &res); err != nil { return fmt.Errorf("add %s edge for table %s: %v", edge.Rel, edge.Table, err) } affected, err := res.RowsAffected() if err != nil { return err } // Setting the FK value of the "other" table // without clearing it before, is not allowed. if ids := edge.Target.Nodes; int(affected) < len(ids) { return &ConstraintError{msg: fmt.Sprintf("one of %v is already connected to a different %s", ids, edge.Columns[0])} } } return nil } // setTableColumns is shared between updater and creator. func setTableColumns(fields []*FieldSpec, edges map[Rel][]*EdgeSpec, set func(string, driver.Value)) (err error) { for _, fi := range fields { value := fi.Value if fi.Type == field.TypeJSON { buf, err := json.Marshal(value) if err != nil { return fmt.Errorf("marshal value for column %s: %v", fi.Column, err) } // If the underlying driver does not support JSON types, // driver.DefaultParameterConverter will convert it to uint8. value = json.RawMessage(buf) } set(fi.Column, value) } for _, e := range edges[M2O] { set(e.Columns[0], e.Target.Nodes[0]) } for _, e := range edges[O2O] { if e.Inverse || e.Bidi { set(e.Columns[0], e.Target.Nodes[0]) } } return nil } // insertLastID invokes the insert query on the transaction and returns the LastInsertID. func insertLastID(ctx context.Context, tx dialect.ExecQuerier, insert *sql.InsertBuilder) (driver.Value, error) { query, args := insert.Query() // PostgreSQL does not support the LastInsertId() method of sql.Result // on Exec, and should be extracted manually using the `RETURNING` clause. if insert.Dialect() == dialect.Postgres { rows := &sql.Rows{} if err := tx.Query(ctx, query, args, rows); err != nil { return 0, err } defer rows.Close() return sql.ScanValue(rows) } // MySQL, SQLite, etc. 
var res sql.Result if err := tx.Exec(ctx, query, args, &res); err != nil { return 0, err } return res.LastInsertId() } // insertLastIDs invokes the batch insert query on the transaction and returns the LastInsertID of all entities. func insertLastIDs(ctx context.Context, tx dialect.ExecQuerier, insert *sql.InsertBuilder) (ids []driver.Value, err error) { query, args := insert.Query() // PostgreSQL does not support the LastInsertId() method of sql.Result // on Exec, and should be extracted manually using the `RETURNING` clause. if insert.Dialect() == dialect.Postgres { rows := &sql.Rows{} if err := tx.Query(ctx, query, args, rows); err != nil { return nil, err } defer rows.Close() return ids, sql.ScanSlice(rows, &ids) } // MySQL, SQLite, etc. var res sql.Result if err := tx.Exec(ctx, query, args, &res); err != nil { return nil, err } id, err := res.LastInsertId() if err != nil { return nil, err } affected, err := res.RowsAffected() if err != nil { return nil, err } ids = make([]driver.Value, 0, affected) switch insert.Dialect() { case dialect.SQLite: id -= affected - 1 fallthrough case dialect.MySQL: for i := int64(0); i < affected; i++ { ids = append(ids, id+i) } } return ids, nil } // rollback calls to tx.Rollback and wraps the given error with the rollback error if occurred. 
func rollback(tx dialect.Tx, err error) error { if rerr := tx.Rollback(); rerr != nil { err = fmt.Errorf("%s: %v", err.Error(), rerr) } return err } func edgeKeys(m map[string][]*EdgeSpec) []string { keys := make([]string, 0, len(m)) for k := range m { keys = append(keys, k) } sort.Strings(keys) return keys } func insertKeys(m map[string]*sql.InsertBuilder) []string { keys := make([]string, 0, len(m)) for k := range m { keys = append(keys, k) } sort.Strings(keys) return keys } func keys(m map[string]struct{}) []string { keys := make([]string, 0, len(m)) for k := range m { keys = append(keys, k) } sort.Strings(keys) return keys } func matchID(column string, pk []driver.Value) *sql.Predicate { if len(pk) > 1 { return sql.InValues(column, pk...) } return sql.EQ(column, pk[0]) } func matchIDs(column1 string, pk1 []driver.Value, column2 string, pk2 []driver.Value) *sql.Predicate { p := matchID(column1, pk1) if len(pk2) > 1 { // Use "IN" predicate instead of list of "OR" // in case of more than on nodes to connect. return sql.And(p, sql.InValues(column2, pk2...)) } return sql.And(p, sql.EQ(column2, pk2[0])) } // cartesian product of 2 id sets. func product(a, b []driver.Value) [][2]driver.Value { c := make([][2]driver.Value, 0, len(a)*len(b)) for i := range a { for j := range b { c = append(c, [2]driver.Value{a[i], b[j]}) } } return c } ent-0.5.4/dialect/sql/sqlgraph/graph_test.go000066400000000000000000001401011377533537200210360ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. 
// Tests for the sqlgraph traversal/mutation builders. Each table-driven
// case pins the exact SQL string and bound arguments produced for a given
// graph step or spec, executed against a go-sqlmock driver.
package sqlgraph

import (
	"context"
	"database/sql/driver"
	"fmt"
	"regexp"
	"strings"
	"testing"

	"github.com/facebook/ent/dialect/sql"
	"github.com/facebook/ent/schema/field"

	"github.com/DATA-DOG/go-sqlmock"
	"github.com/stretchr/testify/require"
)

// TestNeighbors pins the query built by Neighbors for each relation type
// (O2O/O2M/M2O/M2M) in both assoc and inverse directions.
func TestNeighbors(t *testing.T) {
	tests := []struct {
		name      string
		input     *Step
		wantQuery string
		wantArgs  []interface{}
	}{
		{
			name: "O2O/1type",
			// Since the relation is on the same sql.Table,
			// V used as a reference value.
			input: NewStep(
				From("users", "id", 1),
				To("users", "id"),
				Edge(O2O, false, "users", "spouse_id"),
			),
			wantQuery: "SELECT * FROM `users` WHERE `spouse_id` = ?",
			wantArgs:  []interface{}{1},
		},
		{
			name: "O2O/1type/inverse",
			input: NewStep(
				From("nodes", "id", 1),
				To("nodes", "id"),
				Edge(O2O, true, "nodes", "prev_id"),
			),
			wantQuery: "SELECT * FROM `nodes` JOIN (SELECT `prev_id` FROM `nodes` WHERE `id` = ?) AS `t1` ON `nodes`.`id` = `t1`.`prev_id`",
			wantArgs:  []interface{}{1},
		},
		{
			name: "O2M/1type",
			input: NewStep(
				From("users", "id", 1),
				To("users", "id"),
				Edge(O2M, false, "users", "parent_id"),
			),
			wantQuery: "SELECT * FROM `users` WHERE `parent_id` = ?",
			wantArgs:  []interface{}{1},
		},
		{
			name: "O2O/2types",
			input: NewStep(
				From("users", "id", 2),
				To("card", "id"),
				Edge(O2O, false, "cards", "owner_id"),
			),
			wantQuery: "SELECT * FROM `card` WHERE `owner_id` = ?",
			wantArgs:  []interface{}{2},
		},
		{
			name: "O2O/2types/inverse",
			input: NewStep(
				From("cards", "id", 2),
				To("users", "id"),
				Edge(O2O, true, "cards", "owner_id"),
			),
			wantQuery: "SELECT * FROM `users` JOIN (SELECT `owner_id` FROM `cards` WHERE `id` = ?) AS `t1` ON `users`.`id` = `t1`.`owner_id`",
			wantArgs:  []interface{}{2},
		},
		{
			name: "O2M/2types",
			input: NewStep(
				From("users", "id", 1),
				To("pets", "id"),
				Edge(O2M, false, "pets", "owner_id"),
			),
			wantQuery: "SELECT * FROM `pets` WHERE `owner_id` = ?",
			wantArgs:  []interface{}{1},
		},
		{
			name: "M2O/2types/inverse",
			input: NewStep(
				From("pets", "id", 2),
				To("users", "id"),
				Edge(M2O, true, "pets", "owner_id"),
			),
			wantQuery: "SELECT * FROM `users` JOIN (SELECT `owner_id` FROM `pets` WHERE `id` = ?) AS `t1` ON `users`.`id` = `t1`.`owner_id`",
			wantArgs:  []interface{}{2},
		},
		{
			name: "M2O/1type/inverse",
			input: NewStep(
				From("users", "id", 2),
				To("users", "id"),
				Edge(M2O, true, "users", "parent_id"),
			),
			wantQuery: "SELECT * FROM `users` JOIN (SELECT `parent_id` FROM `users` WHERE `id` = ?) AS `t1` ON `users`.`id` = `t1`.`parent_id`",
			wantArgs:  []interface{}{2},
		},
		{
			name: "M2M/2type",
			input: NewStep(
				From("groups", "id", 2),
				To("users", "id"),
				Edge(M2M, false, "user_groups", "group_id", "user_id"),
			),
			wantQuery: "SELECT * FROM `users` JOIN (SELECT `user_groups`.`user_id` FROM `user_groups` WHERE `user_groups`.`group_id` = ?) AS `t1` ON `users`.`id` = `t1`.`user_id`",
			wantArgs:  []interface{}{2},
		},
		{
			name: "M2M/2type/inverse",
			input: NewStep(
				From("users", "id", 2),
				To("groups", "id"),
				Edge(M2M, true, "user_groups", "group_id", "user_id"),
			),
			wantQuery: "SELECT * FROM `groups` JOIN (SELECT `user_groups`.`group_id` FROM `user_groups` WHERE `user_groups`.`user_id` = ?) AS `t1` ON `groups`.`id` = `t1`.`group_id`",
			wantArgs:  []interface{}{2},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			selector := Neighbors("", tt.input)
			query, args := selector.Query()
			require.Equal(t, tt.wantQuery, query)
			require.Equal(t, tt.wantArgs, args)
		})
	}
}

// TestSetNeighbors is like TestNeighbors, but the step starts from a
// selector (a sub-query of nodes) rather than a single reference id.
// Expected queries are whitespace-normalized before comparison.
func TestSetNeighbors(t *testing.T) {
	tests := []struct {
		name      string
		input     *Step
		wantQuery string
		wantArgs  []interface{}
	}{
		{
			name: "O2M/2types",
			input: NewStep(
				From("users", "id", sql.Select().From(sql.Table("users")).Where(sql.EQ("name", "a8m"))),
				To("pets", "id"),
				Edge(O2M, false, "users", "owner_id"),
			),
			wantQuery: `SELECT * FROM "pets" JOIN (SELECT "users"."id" FROM "users" WHERE "name" = $1) AS "t1" ON "pets"."owner_id" = "t1"."id"`,
			wantArgs:  []interface{}{"a8m"},
		},
		{
			name: "M2O/2types",
			input: NewStep(
				From("pets", "id", sql.Select().From(sql.Table("pets")).Where(sql.EQ("name", "pedro"))),
				To("users", "id"),
				Edge(M2O, true, "pets", "owner_id"),
			),
			wantQuery: `SELECT * FROM "users" JOIN (SELECT "pets"."owner_id" FROM "pets" WHERE "name" = $1) AS "t1" ON "users"."id" = "t1"."owner_id"`,
			wantArgs:  []interface{}{"pedro"},
		},
		{
			name: "M2M/2types",
			input: NewStep(
				From("users", "id", sql.Select().From(sql.Table("users")).Where(sql.EQ("name", "a8m"))),
				To("groups", "id"),
				Edge(M2M, false, "user_groups", "user_id", "group_id"),
			),
			wantQuery: ` SELECT * FROM "groups" JOIN (SELECT "user_groups"."group_id" FROM "user_groups" JOIN (SELECT "users"."id" FROM "users" WHERE "name" = $1) AS "t1" ON "user_groups"."user_id" = "t1"."id") AS "t1" ON "groups"."id" = "t1"."group_id"`,
			wantArgs:  []interface{}{"a8m"},
		},
		{
			name: "M2M/2types/inverse",
			input: NewStep(
				From("groups", "id", sql.Select().From(sql.Table("groups")).Where(sql.EQ("name", "GitHub"))),
				To("users", "id"),
				Edge(M2M, true, "user_groups", "user_id", "group_id"),
			),
			wantQuery: ` SELECT * FROM "users" JOIN (SELECT "user_groups"."user_id" FROM "user_groups" JOIN (SELECT "groups"."id" FROM "groups" WHERE "name" = $1) AS "t1" ON "user_groups"."group_id" = "t1"."id") AS "t1" ON "users"."id" = "t1"."user_id"`,
			wantArgs:  []interface{}{"GitHub"},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			selector := SetNeighbors("postgres", tt.input)
			query, args := selector.Query()
			// Normalize whitespace so multi-line expectations compare cleanly.
			tt.wantQuery = strings.Join(strings.Fields(tt.wantQuery), " ")
			require.Equal(t, tt.wantQuery, query)
			require.Equal(t, tt.wantArgs, args)
		})
	}
}

// TestHasNeighbors pins the existence predicate ("has edge") that
// HasNeighbors adds to an existing selector; no bound args are expected.
func TestHasNeighbors(t *testing.T) {
	tests := []struct {
		name      string
		step      *Step
		selector  *sql.Selector
		wantQuery string
	}{
		{
			name: "O2O/1type",
			// A nodes sql.Table; linked-list (next->prev). The "prev"
			// node holds association pointer. The neighbors query
			// here checks if a node "has-next".
			step: NewStep(
				From("nodes", "id"),
				To("nodes", "id"),
				Edge(O2O, false, "nodes", "prev_id"),
			),
			selector:  sql.Select("*").From(sql.Table("nodes")),
			wantQuery: "SELECT * FROM `nodes` WHERE `nodes`.`id` IN (SELECT `nodes`.`prev_id` FROM `nodes` WHERE `nodes`.`prev_id` IS NOT NULL)",
		},
		{
			name: "O2O/1type/inverse",
			// Same example as above, but the neighbors
			// query checks if a node "has-previous".
			step: NewStep(
				From("nodes", "id"),
				To("nodes", "id"),
				Edge(O2O, true, "nodes", "prev_id"),
			),
			selector:  sql.Select("*").From(sql.Table("nodes")),
			wantQuery: "SELECT * FROM `nodes` WHERE `nodes`.`prev_id` IS NOT NULL",
		},
		{
			name: "O2M/2type2",
			step: NewStep(
				From("users", "id"),
				To("pets", "id"),
				Edge(O2M, false, "pets", "owner_id"),
			),
			selector:  sql.Select("*").From(sql.Table("users")),
			wantQuery: "SELECT * FROM `users` WHERE `users`.`id` IN (SELECT `pets`.`owner_id` FROM `pets` WHERE `pets`.`owner_id` IS NOT NULL)",
		},
		{
			name: "M2O/2type2",
			step: NewStep(
				From("pets", "id"),
				To("users", "id"),
				Edge(M2O, true, "pets", "owner_id"),
			),
			selector:  sql.Select("*").From(sql.Table("pets")),
			wantQuery: "SELECT * FROM `pets` WHERE `pets`.`owner_id` IS NOT NULL",
		},
		{
			name: "M2M/2types",
			step: NewStep(
				From("users", "id"),
				To("groups", "id"),
				Edge(M2M, false, "user_groups", "user_id", "group_id"),
			),
			selector:  sql.Select("*").From(sql.Table("users")),
			wantQuery: "SELECT * FROM `users` WHERE `users`.`id` IN (SELECT `user_groups`.`user_id` FROM `user_groups`)",
		},
		{
			name: "M2M/2types/inverse",
			step: NewStep(
				From("users", "id"),
				To("groups", "id"),
				Edge(M2M, true, "group_users", "group_id", "user_id"),
			),
			selector:  sql.Select("*").From(sql.Table("users")),
			wantQuery: "SELECT * FROM `users` WHERE `users`.`id` IN (SELECT `group_users`.`user_id` FROM `group_users`)",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			HasNeighbors(tt.selector, tt.step)
			query, args := tt.selector.Query()
			require.Equal(t, tt.wantQuery, query)
			require.Empty(t, args)
		})
	}
}

// TestHasNeighborsWith is like TestHasNeighbors, but the neighbor set is
// further filtered by a caller-supplied predicate on the target table.
func TestHasNeighborsWith(t *testing.T) {
	tests := []struct {
		name      string
		step      *Step
		selector  *sql.Selector
		predicate func(*sql.Selector)
		wantQuery string
		wantArgs  []interface{}
	}{
		{
			name: "O2O",
			step: NewStep(
				From("users", "id"),
				To("cards", "id"),
				Edge(O2O, false, "cards", "owner_id"),
			),
			selector: sql.Dialect("postgres").Select("*").From(sql.Table("users")),
			predicate: func(s *sql.Selector) {
				s.Where(sql.EQ("expired", false))
			},
			wantQuery: `SELECT * FROM "users" WHERE "users"."id" IN (SELECT "cards"."owner_id" FROM "cards" WHERE "expired" = $1)`,
			wantArgs:  []interface{}{false},
		},
		{
			name: "O2O/inverse",
			step: NewStep(
				From("cards", "id"),
				To("users", "id"),
				Edge(O2O, true, "cards", "owner_id"),
			),
			selector: sql.Dialect("postgres").Select("*").From(sql.Table("cards")),
			predicate: func(s *sql.Selector) {
				s.Where(sql.EQ("name", "a8m"))
			},
			wantQuery: `SELECT * FROM "cards" WHERE "cards"."owner_id" IN (SELECT "users"."id" FROM "users" WHERE "name" = $1)`,
			wantArgs:  []interface{}{"a8m"},
		},
		{
			name: "O2M",
			step: NewStep(
				From("users", "id"),
				To("pets", "id"),
				Edge(O2M, false, "pets", "owner_id"),
			),
			selector: sql.Dialect("postgres").Select("*").
				From(sql.Table("users")).
				Where(sql.EQ("last_name", "mashraki")),
			predicate: func(s *sql.Selector) {
				s.Where(sql.EQ("name", "pedro"))
			},
			wantQuery: `SELECT * FROM "users" WHERE "last_name" = $1 AND "users"."id" IN (SELECT "pets"."owner_id" FROM "pets" WHERE "name" = $2)`,
			wantArgs:  []interface{}{"mashraki", "pedro"},
		},
		{
			name: "M2O",
			step: NewStep(
				From("pets", "id"),
				To("users", "id"),
				Edge(M2O, true, "pets", "owner_id"),
			),
			selector: sql.Dialect("postgres").Select("*").
				From(sql.Table("pets")).
				Where(sql.EQ("name", "pedro")),
			predicate: func(s *sql.Selector) {
				s.Where(sql.EQ("last_name", "mashraki"))
			},
			wantQuery: `SELECT * FROM "pets" WHERE "name" = $1 AND "pets"."owner_id" IN (SELECT "users"."id" FROM "users" WHERE "last_name" = $2)`,
			wantArgs:  []interface{}{"pedro", "mashraki"},
		},
		{
			name: "M2M",
			step: NewStep(
				From("users", "id"),
				To("groups", "id"),
				Edge(M2M, false, "user_groups", "user_id", "group_id"),
			),
			selector: sql.Dialect("postgres").Select("*").From(sql.Table("users")),
			predicate: func(s *sql.Selector) {
				s.Where(sql.EQ("name", "GitHub"))
			},
			wantQuery: ` SELECT * FROM "users" WHERE "users"."id" IN (SELECT "user_groups"."user_id" FROM "user_groups" JOIN "groups" AS "t0" ON "user_groups"."group_id" = "t0"."id" WHERE "name" = $1)`,
			wantArgs:  []interface{}{"GitHub"},
		},
		{
			name: "M2M/inverse",
			step: NewStep(
				From("groups", "id"),
				To("users", "id"),
				Edge(M2M, true, "user_groups", "user_id", "group_id"),
			),
			selector: sql.Dialect("postgres").Select("*").From(sql.Table("groups")),
			predicate: func(s *sql.Selector) {
				s.Where(sql.EQ("name", "a8m"))
			},
			wantQuery: ` SELECT * FROM "groups" WHERE "groups"."id" IN (SELECT "user_groups"."group_id" FROM "user_groups" JOIN "users" AS "t0" ON "user_groups"."user_id" = "t0"."id" WHERE "name" = $1)`,
			wantArgs:  []interface{}{"a8m"},
		},
		{
			name: "M2M/inverse",
			step: NewStep(
				From("groups", "id"),
				To("users", "id"),
				Edge(M2M, true, "user_groups", "user_id", "group_id"),
			),
			selector: sql.Dialect("postgres").Select("*").From(sql.Table("groups")),
			predicate: func(s *sql.Selector) {
				s.Where(sql.And(sql.NotNull("name"), sql.EQ("name", "a8m")))
			},
			wantQuery: ` SELECT * FROM "groups" WHERE "groups"."id" IN (SELECT "user_groups"."group_id" FROM "user_groups" JOIN "users" AS "t0" ON "user_groups"."user_id" = "t0"."id" WHERE "name" IS NOT NULL AND "name" = $1)`,
			wantArgs:  []interface{}{"a8m"},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			HasNeighborsWith(tt.selector, tt.step, tt.predicate)
			query, args := tt.selector.Query()
			// Normalize whitespace so multi-line expectations compare cleanly.
			tt.wantQuery = strings.Join(strings.Fields(tt.wantQuery), " ")
			require.Equal(t, tt.wantQuery, query)
			require.Equal(t, tt.wantArgs, args)
		})
	}
}

// TestCreateNode pins the exact statement sequence CreateNode issues for a
// CreateSpec: the INSERT of the node itself, plus UPDATE/INSERT statements
// for each edge kind, all inside one transaction.
func TestCreateNode(t *testing.T) {
	tests := []struct {
		name    string
		spec    *CreateSpec
		expect  func(sqlmock.Sqlmock)
		wantErr bool
	}{
		{
			name: "fields",
			spec: &CreateSpec{
				Table: "users",
				ID:    &FieldSpec{Column: "id"},
				Fields: []*FieldSpec{
					{Column: "age", Type: field.TypeInt, Value: 30},
					{Column: "name", Type: field.TypeString, Value: "a8m"},
				},
			},
			expect: func(m sqlmock.Sqlmock) {
				m.ExpectBegin()
				m.ExpectExec(escape("INSERT INTO `users` (`age`, `name`) VALUES (?, ?)")).
					WithArgs(30, "a8m").
					WillReturnResult(sqlmock.NewResult(1, 1))
				m.ExpectCommit()
			},
		},
		{
			name: "fields/user-defined-id",
			spec: &CreateSpec{
				Table: "users",
				ID:    &FieldSpec{Column: "id", Value: 1},
				Fields: []*FieldSpec{
					{Column: "age", Type: field.TypeInt, Value: 30},
					{Column: "name", Type: field.TypeString, Value: "a8m"},
				},
			},
			expect: func(m sqlmock.Sqlmock) {
				m.ExpectBegin()
				m.ExpectExec(escape("INSERT INTO `users` (`age`, `name`, `id`) VALUES (?, ?, ?)")).
					WithArgs(30, "a8m", 1).
					WillReturnResult(sqlmock.NewResult(1, 1))
				m.ExpectCommit()
			},
		},
		{
			name: "fields/json",
			spec: &CreateSpec{
				Table: "users",
				ID:    &FieldSpec{Column: "id"},
				Fields: []*FieldSpec{
					{Column: "json", Type: field.TypeJSON, Value: struct{}{}},
				},
			},
			expect: func(m sqlmock.Sqlmock) {
				m.ExpectBegin()
				m.ExpectExec(escape("INSERT INTO `users` (`json`) VALUES (?)")).
					WithArgs([]byte("{}")).
					WillReturnResult(sqlmock.NewResult(1, 1))
				m.ExpectCommit()
			},
		},
		{
			name: "edges/m2o",
			spec: &CreateSpec{
				Table: "pets",
				ID:    &FieldSpec{Column: "id"},
				Fields: []*FieldSpec{
					{Column: "name", Type: field.TypeString, Value: "pedro"},
				},
				Edges: []*EdgeSpec{
					{Rel: M2O, Columns: []string{"owner_id"}, Inverse: true, Target: &EdgeTarget{Nodes: []driver.Value{2}}},
				},
			},
			expect: func(m sqlmock.Sqlmock) {
				m.ExpectBegin()
				m.ExpectExec(escape("INSERT INTO `pets` (`name`, `owner_id`) VALUES (?, ?)")).
					WithArgs("pedro", 2).
					WillReturnResult(sqlmock.NewResult(1, 1))
				m.ExpectCommit()
			},
		},
		{
			name: "edges/o2o/inverse",
			spec: &CreateSpec{
				Table: "cards",
				ID:    &FieldSpec{Column: "id"},
				Fields: []*FieldSpec{
					{Column: "number", Type: field.TypeString, Value: "0001"},
				},
				Edges: []*EdgeSpec{
					{Rel: O2O, Columns: []string{"owner_id"}, Inverse: true, Target: &EdgeTarget{Nodes: []driver.Value{2}}},
				},
			},
			expect: func(m sqlmock.Sqlmock) {
				m.ExpectBegin()
				m.ExpectExec(escape("INSERT INTO `cards` (`number`, `owner_id`) VALUES (?, ?)")).
					WithArgs("0001", 2).
					WillReturnResult(sqlmock.NewResult(1, 1))
				m.ExpectCommit()
			},
		},
		{
			name: "edges/o2m",
			spec: &CreateSpec{
				Table: "users",
				ID:    &FieldSpec{Column: "id"},
				Fields: []*FieldSpec{
					{Column: "name", Type: field.TypeString, Value: "a8m"},
				},
				Edges: []*EdgeSpec{
					{Rel: O2M, Table: "pets", Columns: []string{"owner_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2}, IDSpec: &FieldSpec{Column: "id"}}},
				},
			},
			expect: func(m sqlmock.Sqlmock) {
				m.ExpectBegin()
				m.ExpectExec(escape("INSERT INTO `users` (`name`) VALUES (?)")).
					WithArgs("a8m").
					WillReturnResult(sqlmock.NewResult(1, 1))
				m.ExpectExec(escape("UPDATE `pets` SET `owner_id` = ? WHERE `id` = ? AND `owner_id` IS NULL")).
					WithArgs(1, 2).
					WillReturnResult(sqlmock.NewResult(1, 1))
				m.ExpectCommit()
			},
		},
		{
			name: "edges/o2m",
			spec: &CreateSpec{
				Table: "users",
				ID:    &FieldSpec{Column: "id"},
				Fields: []*FieldSpec{
					{Column: "name", Type: field.TypeString, Value: "a8m"},
				},
				Edges: []*EdgeSpec{
					{Rel: O2M, Table: "pets", Columns: []string{"owner_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2, 3, 4}, IDSpec: &FieldSpec{Column: "id"}}},
				},
			},
			expect: func(m sqlmock.Sqlmock) {
				m.ExpectBegin()
				m.ExpectExec(escape("INSERT INTO `users` (`name`) VALUES (?)")).
					WithArgs("a8m").
					WillReturnResult(sqlmock.NewResult(1, 1))
				m.ExpectExec(escape("UPDATE `pets` SET `owner_id` = ? WHERE `id` IN (?, ?, ?) AND `owner_id` IS NULL")).
					WithArgs(1, 2, 3, 4).
					WillReturnResult(sqlmock.NewResult(1, 3))
				m.ExpectCommit()
			},
		},
		{
			name: "edges/o2o",
			spec: &CreateSpec{
				Table: "users",
				ID:    &FieldSpec{Column: "id"},
				Fields: []*FieldSpec{
					{Column: "name", Type: field.TypeString, Value: "a8m"},
				},
				Edges: []*EdgeSpec{
					{Rel: O2O, Table: "cards", Columns: []string{"owner_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2}, IDSpec: &FieldSpec{Column: "id"}}},
				},
			},
			expect: func(m sqlmock.Sqlmock) {
				m.ExpectBegin()
				m.ExpectExec(escape("INSERT INTO `users` (`name`) VALUES (?)")).
					WithArgs("a8m").
					WillReturnResult(sqlmock.NewResult(1, 1))
				m.ExpectExec(escape("UPDATE `cards` SET `owner_id` = ? WHERE `id` = ? AND `owner_id` IS NULL")).
					WithArgs(1, 2).
					WillReturnResult(sqlmock.NewResult(1, 1))
				m.ExpectCommit()
			},
		},
		{
			name: "edges/o2o/bidi",
			spec: &CreateSpec{
				Table: "users",
				ID:    &FieldSpec{Column: "id"},
				Fields: []*FieldSpec{
					{Column: "name", Type: field.TypeString, Value: "a8m"},
				},
				Edges: []*EdgeSpec{
					{Rel: O2O, Bidi: true, Table: "users", Columns: []string{"spouse_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2}, IDSpec: &FieldSpec{Column: "id"}}},
				},
			},
			expect: func(m sqlmock.Sqlmock) {
				m.ExpectBegin()
				m.ExpectExec(escape("INSERT INTO `users` (`name`, `spouse_id`) VALUES (?, ?)")).
					WithArgs("a8m", 2).
					WillReturnResult(sqlmock.NewResult(1, 1))
				m.ExpectExec(escape("UPDATE `users` SET `spouse_id` = ? WHERE `id` = ? AND `spouse_id` IS NULL")).
					WithArgs(1, 2).
					WillReturnResult(sqlmock.NewResult(1, 1))
				m.ExpectCommit()
			},
		},
		{
			name: "edges/m2m",
			spec: &CreateSpec{
				Table: "groups",
				ID:    &FieldSpec{Column: "id"},
				Fields: []*FieldSpec{
					{Column: "name", Type: field.TypeString, Value: "GitHub"},
				},
				Edges: []*EdgeSpec{
					{Rel: M2M, Table: "group_users", Columns: []string{"group_id", "user_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2}, IDSpec: &FieldSpec{Column: "id"}}},
				},
			},
			expect: func(m sqlmock.Sqlmock) {
				m.ExpectBegin()
				m.ExpectExec(escape("INSERT INTO `groups` (`name`) VALUES (?)")).
					WithArgs("GitHub").
					WillReturnResult(sqlmock.NewResult(1, 1))
				m.ExpectExec(escape("INSERT INTO `group_users` (`group_id`, `user_id`) VALUES (?, ?)")).
					WithArgs(1, 2).
					WillReturnResult(sqlmock.NewResult(1, 1))
				m.ExpectCommit()
			},
		},
		{
			name: "edges/m2m/inverse",
			spec: &CreateSpec{
				Table: "users",
				ID:    &FieldSpec{Column: "id"},
				Fields: []*FieldSpec{
					{Column: "name", Type: field.TypeString, Value: "mashraki"},
				},
				Edges: []*EdgeSpec{
					{Rel: M2M, Inverse: true, Table: "group_users", Columns: []string{"group_id", "user_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2}, IDSpec: &FieldSpec{Column: "id"}}},
				},
			},
			expect: func(m sqlmock.Sqlmock) {
				m.ExpectBegin()
				m.ExpectExec(escape("INSERT INTO `users` (`name`) VALUES (?)")).
					WithArgs("mashraki").
					WillReturnResult(sqlmock.NewResult(1, 1))
				m.ExpectExec(escape("INSERT INTO `group_users` (`group_id`, `user_id`) VALUES (?, ?)")).
					WithArgs(2, 1).
					WillReturnResult(sqlmock.NewResult(1, 1))
				m.ExpectCommit()
			},
		},
		{
			name: "edges/m2m/bidi",
			spec: &CreateSpec{
				Table: "users",
				ID:    &FieldSpec{Column: "id"},
				Fields: []*FieldSpec{
					{Column: "name", Type: field.TypeString, Value: "mashraki"},
				},
				Edges: []*EdgeSpec{
					{Rel: M2M, Bidi: true, Table: "user_friends", Columns: []string{"user_id", "friend_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2}, IDSpec: &FieldSpec{Column: "id"}}},
				},
			},
			expect: func(m sqlmock.Sqlmock) {
				m.ExpectBegin()
				m.ExpectExec(escape("INSERT INTO `users` (`name`) VALUES (?)")).
					WithArgs("mashraki").
					WillReturnResult(sqlmock.NewResult(1, 1))
				m.ExpectExec(escape("INSERT INTO `user_friends` (`user_id`, `friend_id`) VALUES (?, ?), (?, ?)")).
					WithArgs(1, 2, 2, 1).
					WillReturnResult(sqlmock.NewResult(1, 1))
				m.ExpectCommit()
			},
		},
		{
			name: "edges/m2m/bidi/batch",
			spec: &CreateSpec{
				Table: "users",
				ID:    &FieldSpec{Column: "id"},
				Fields: []*FieldSpec{
					{Column: "name", Type: field.TypeString, Value: "mashraki"},
				},
				Edges: []*EdgeSpec{
					{Rel: M2M, Bidi: true, Table: "user_friends", Columns: []string{"user_id", "friend_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2}, IDSpec: &FieldSpec{Column: "id"}}},
					{Rel: M2M, Bidi: true, Table: "user_friends", Columns: []string{"user_id", "friend_id"}, Target: &EdgeTarget{Nodes: []driver.Value{3}, IDSpec: &FieldSpec{Column: "id"}}},
					{Rel: M2M, Inverse: true, Table: "group_users", Columns: []string{"group_id", "user_id"}, Target: &EdgeTarget{Nodes: []driver.Value{4}, IDSpec: &FieldSpec{Column: "id"}}},
					{Rel: M2M, Inverse: true, Table: "group_users", Columns: []string{"group_id", "user_id"}, Target: &EdgeTarget{Nodes: []driver.Value{5}, IDSpec: &FieldSpec{Column: "id"}}},
				},
			},
			expect: func(m sqlmock.Sqlmock) {
				m.ExpectBegin()
				m.ExpectExec(escape("INSERT INTO `users` (`name`) VALUES (?)")).
					WithArgs("mashraki").
					WillReturnResult(sqlmock.NewResult(1, 1))
				m.ExpectExec(escape("INSERT INTO `group_users` (`group_id`, `user_id`) VALUES (?, ?), (?, ?)")).
					WithArgs(4, 1, 5, 1).
					WillReturnResult(sqlmock.NewResult(1, 1))
				m.ExpectExec(escape("INSERT INTO `user_friends` (`user_id`, `friend_id`) VALUES (?, ?), (?, ?), (?, ?), (?, ?)")).
					WithArgs(1, 2, 2, 1, 1, 3, 3, 1).
					WillReturnResult(sqlmock.NewResult(1, 1))
				m.ExpectCommit()
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			db, mock, err := sqlmock.New()
			require.NoError(t, err)
			tt.expect(mock)
			err = CreateNode(context.Background(), sql.OpenDB("", db), tt.spec)
			require.Equal(t, tt.wantErr, err != nil, err)
		})
	}
}

// TestBatchCreate pins the statement sequence for a multi-node batch
// insert: one multi-row INSERT for the nodes, batched edge INSERTs, and
// per-node FK UPDATEs on other tables.
func TestBatchCreate(t *testing.T) {
	tests := []struct {
		name    string
		nodes   []*CreateSpec
		expect  func(sqlmock.Sqlmock)
		wantErr bool
	}{
		{
			name: "empty",
			expect: func(m sqlmock.Sqlmock) {
				m.ExpectBegin()
				m.ExpectCommit()
			},
		},
		{
			name: "multiple",
			nodes: []*CreateSpec{
				{
					Table: "users",
					ID:    &FieldSpec{Column: "id"},
					Fields: []*FieldSpec{
						{Column: "age", Type: field.TypeInt, Value: 32},
						{Column: "name", Type: field.TypeString, Value: "a8m"},
						{Column: "active", Type: field.TypeBool, Value: false},
					},
					Edges: []*EdgeSpec{
						{Rel: M2M, Inverse: true, Table: "group_users", Columns: []string{"group_id", "user_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2}, IDSpec: &FieldSpec{Column: "id"}}},
						{Rel: M2M, Table: "user_products", Columns: []string{"user_id", "product_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2}, IDSpec: &FieldSpec{Column: "id"}}},
						{Rel: M2M, Table: "user_friends", Bidi: true, Columns: []string{"user_id", "friend_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id"}, Nodes: []driver.Value{2}}},
						{Rel: M2O, Table: "company", Columns: []string{"workplace_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2}}},
						{Rel: O2M, Table: "pets", Columns: []string{"owner_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2}, IDSpec: &FieldSpec{Column: "id"}}},
					},
				},
				{
					Table: "users",
					ID:    &FieldSpec{Column: "id"},
					Fields: []*FieldSpec{
						{Column: "age", Type: field.TypeInt, Value: 30},
						{Column: "name", Type: field.TypeString, Value: "nati"},
					},
					Edges: []*EdgeSpec{
						{Rel: M2M, Inverse: true, Table: "group_users", Columns: []string{"group_id", "user_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2}, IDSpec: &FieldSpec{Column: "id"}}},
						{Rel: M2M, Table: "user_products", Columns: []string{"user_id", "product_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2}, IDSpec: &FieldSpec{Column: "id"}}},
						{Rel: M2M, Table: "user_friends", Bidi: true, Columns: []string{"user_id", "friend_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id"}, Nodes: []driver.Value{2}}},
						{Rel: O2M, Table: "pets", Columns: []string{"owner_id"}, Target: &EdgeTarget{Nodes: []driver.Value{3}, IDSpec: &FieldSpec{Column: "id"}}},
					},
				},
			},
			expect: func(m sqlmock.Sqlmock) {
				m.ExpectBegin()
				// Insert nodes with FKs.
				m.ExpectExec(escape("INSERT INTO `users` (`active`, `age`, `name`, `workplace_id`) VALUES (?, ?, ?, ?), (?, ?, ?, ?)")).
					WithArgs(false, 32, "a8m", 2, nil, 30, "nati", nil).
					WillReturnResult(sqlmock.NewResult(10, 2))
				// Insert M2M inverse-edges.
				m.ExpectExec(escape("INSERT INTO `group_users` (`group_id`, `user_id`) VALUES (?, ?), (?, ?)")).
					WithArgs(2, 10, 2, 11).
					WillReturnResult(sqlmock.NewResult(2, 2))
				// Insert M2M bidirectional edges.
				m.ExpectExec(escape("INSERT INTO `user_friends` (`user_id`, `friend_id`) VALUES (?, ?), (?, ?), (?, ?), (?, ?)")).
					WithArgs(10, 2, 2, 10, 11, 2, 2, 11).
					WillReturnResult(sqlmock.NewResult(2, 2))
				// Insert M2M edges.
				m.ExpectExec(escape("INSERT INTO `user_products` (`user_id`, `product_id`) VALUES (?, ?), (?, ?)")).
					WithArgs(10, 2, 11, 2).
					WillReturnResult(sqlmock.NewResult(2, 2))
				// Update FKs exist in different tables.
				m.ExpectExec(escape("UPDATE `pets` SET `owner_id` = ? WHERE `id` = ? AND `owner_id` IS NULL")).
					WithArgs(10 /* id of the 1st new node */, 2 /* pet id */).
					WillReturnResult(sqlmock.NewResult(2, 2))
				m.ExpectExec(escape("UPDATE `pets` SET `owner_id` = ? WHERE `id` = ? AND `owner_id` IS NULL")).
					WithArgs(11 /* id of the 2nd new node */, 3 /* pet id */).
					WillReturnResult(sqlmock.NewResult(2, 2))
				m.ExpectCommit()
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			db, mock, err := sqlmock.New()
			require.NoError(t, err)
			tt.expect(mock)
			err = BatchCreate(context.Background(), sql.OpenDB("mysql", db), &BatchCreateSpec{Nodes: tt.nodes})
			require.Equal(t, tt.wantErr, err != nil, err)
		})
	}
}

// user is a minimal scan target used by the update tests below; values and
// assign implement the ScanValues/Assign hooks of UpdateSpec.
type user struct {
	id    int
	age   int
	name  string
	edges struct {
		fk1 int
		fk2 int
	}
}

// values returns one nullable scan destination per requested column, or an
// error for an unknown column name.
func (*user) values(columns []string) ([]interface{}, error) {
	values := make([]interface{}, len(columns))
	for i := range columns {
		switch c := columns[i]; c {
		case "id", "age", "fk1", "fk2":
			values[i] = &sql.NullInt64{}
		case "name":
			values[i] = &sql.NullString{}
		default:
			return nil, fmt.Errorf("unexpected column %q", c)
		}
	}
	return values, nil
}

// assign copies the scanned values (as produced by values above) into the
// corresponding user fields.
func (u *user) assign(columns []string, values []interface{}) error {
	if len(columns) != len(values) {
		return fmt.Errorf("mismatch number of values")
	}
	for i, c := range columns {
		switch c {
		case "id":
			u.id = int(values[i].(*sql.NullInt64).Int64)
		case "age":
			u.age = int(values[i].(*sql.NullInt64).Int64)
		case "name":
			u.name = values[i].(*sql.NullString).String
		case "fk1":
			u.edges.fk1 = int(values[i].(*sql.NullInt64).Int64)
		case "fk2":
			u.edges.fk2 = int(values[i].(*sql.NullInt64).Int64)
		default:
			return fmt.Errorf("unknown column %q", c)
		}
	}
	return nil
}

// TestUpdateNode pins the statement sequence UpdateNode issues for a single
// node: field set/add/clear, edge clear/add, then re-query of the node.
func TestUpdateNode(t *testing.T) {
	tests := []struct {
		name     string
		spec     *UpdateSpec
		prepare  func(sqlmock.Sqlmock)
		wantErr  bool
		wantUser *user
	}{
		{
			name: "fields/set",
			spec: &UpdateSpec{
				Node: &NodeSpec{
					Table:   "users",
					Columns: []string{"id", "name", "age"},
					ID:      &FieldSpec{Column: "id", Type: field.TypeInt, Value: 1},
				},
				Fields: FieldMut{
					Set: []*FieldSpec{
						{Column: "age", Type: field.TypeInt, Value: 30},
						{Column: "name", Type: field.TypeString, Value: "Ariel"},
					},
				},
			},
			prepare: func(mock sqlmock.Sqlmock) {
				mock.ExpectBegin()
				mock.ExpectExec(escape("UPDATE `users` SET `age` = ?, `name` = ? WHERE `id` = ?")).
					WithArgs(30, "Ariel", 1).
					WillReturnResult(sqlmock.NewResult(1, 1))
				mock.ExpectQuery(escape("SELECT `id`, `name`, `age` FROM `users` WHERE `id` = ?")).
					WithArgs(1).
					WillReturnRows(sqlmock.NewRows([]string{"id", "age", "name"}).
						AddRow(1, 30, "Ariel"))
				mock.ExpectCommit()
			},
			wantUser: &user{name: "Ariel", age: 30, id: 1},
		},
		{
			name: "fields/add_clear",
			spec: &UpdateSpec{
				Node: &NodeSpec{
					Table:   "users",
					Columns: []string{"id", "name", "age"},
					ID:      &FieldSpec{Column: "id", Type: field.TypeInt, Value: 1},
				},
				Fields: FieldMut{
					Add: []*FieldSpec{
						{Column: "age", Type: field.TypeInt, Value: 1},
					},
					Clear: []*FieldSpec{
						{Column: "name", Type: field.TypeString},
					},
				},
			},
			prepare: func(mock sqlmock.Sqlmock) {
				mock.ExpectBegin()
				mock.ExpectExec(escape("UPDATE `users` SET `name` = NULL, `age` = COALESCE(`age`, ?) + ? WHERE `id` = ?")).
					WithArgs(0, 1, 1).
					WillReturnResult(sqlmock.NewResult(1, 1))
				mock.ExpectQuery(escape("SELECT `id`, `name`, `age` FROM `users` WHERE `id` = ?")).
					WithArgs(1).
					WillReturnRows(sqlmock.NewRows([]string{"id", "age", "name"}).
						AddRow(1, 31, nil))
				mock.ExpectCommit()
			},
			wantUser: &user{age: 31, id: 1},
		},
		{
			name: "edges/o2o_non_inverse and m2o",
			spec: &UpdateSpec{
				Node: &NodeSpec{
					Table:   "users",
					Columns: []string{"id", "name", "age"},
					ID:      &FieldSpec{Column: "id", Type: field.TypeInt, Value: 1},
				},
				Edges: EdgeMut{
					Clear: []*EdgeSpec{
						{Rel: O2O, Columns: []string{"car_id"}, Inverse: true},
						{Rel: M2O, Columns: []string{"workplace_id"}, Inverse: true},
					},
					Add: []*EdgeSpec{
						{Rel: O2O, Columns: []string{"card_id"}, Inverse: true, Target: &EdgeTarget{Nodes: []driver.Value{2}}},
						{Rel: M2O, Columns: []string{"parent_id"}, Inverse: true, Target: &EdgeTarget{Nodes: []driver.Value{2}}},
					},
				},
			},
			prepare: func(mock sqlmock.Sqlmock) {
				mock.ExpectBegin()
				mock.ExpectExec(escape("UPDATE `users` SET `workplace_id` = NULL, `car_id` = NULL, `parent_id` = ?, `card_id` = ? WHERE `id` = ?")).
					WithArgs(2, 2, 1).
					WillReturnResult(sqlmock.NewResult(1, 1))
				mock.ExpectQuery(escape("SELECT `id`, `name`, `age` FROM `users` WHERE `id` = ?")).
					WithArgs(1).
					WillReturnRows(sqlmock.NewRows([]string{"id", "age", "name"}).
						AddRow(1, 31, nil))
				mock.ExpectCommit()
			},
			wantUser: &user{age: 31, id: 1},
		},
		{
			name: "edges/o2o_bidi",
			spec: &UpdateSpec{
				Node: &NodeSpec{
					Table:   "users",
					Columns: []string{"id", "name", "age"},
					ID:      &FieldSpec{Column: "id", Type: field.TypeInt, Value: 1},
				},
				Edges: EdgeMut{
					Clear: []*EdgeSpec{
						{Rel: O2O, Table: "users", Bidi: true, Columns: []string{"partner_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id"}}},
						{Rel: O2O, Table: "users", Bidi: true, Columns: []string{"spouse_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id"}, Nodes: []driver.Value{2}}},
					},
					Add: []*EdgeSpec{
						{Rel: O2O, Table: "users", Bidi: true, Columns: []string{"spouse_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id"}, Nodes: []driver.Value{3}}},
					},
				},
			},
			prepare: func(mock sqlmock.Sqlmock) {
				mock.ExpectBegin()
				// Clear the "partner" from 1's column, and set "spouse 3".
				// "spouse 2" is implicitly removed when setting a different foreign-key.
				mock.ExpectExec(escape("UPDATE `users` SET `partner_id` = NULL, `spouse_id` = ? WHERE `id` = ?")).
					WithArgs(3, 1).
					WillReturnResult(sqlmock.NewResult(1, 1))
				// Clear the "partner_id" column from previous 1's partner.
				mock.ExpectExec(escape("UPDATE `users` SET `partner_id` = NULL WHERE `partner_id` = ?")).
					WithArgs(1).
					WillReturnResult(sqlmock.NewResult(1, 1))
				// Clear "spouse 1" from 3's column.
				mock.ExpectExec(escape("UPDATE `users` SET `spouse_id` = NULL WHERE `id` = ? AND `spouse_id` = ?")).
					WithArgs(2, 1).
					WillReturnResult(sqlmock.NewResult(1, 1))
				// Set 3's column to point "spouse 1".
				mock.ExpectExec(escape("UPDATE `users` SET `spouse_id` = ? WHERE `id` = ? AND `spouse_id` IS NULL")).
					WithArgs(1, 3).
					WillReturnResult(sqlmock.NewResult(1, 1))
				mock.ExpectQuery(escape("SELECT `id`, `name`, `age` FROM `users` WHERE `id` = ?")).
					WithArgs(1).
					WillReturnRows(sqlmock.NewRows([]string{"id", "age", "name"}).
						AddRow(1, 31, nil))
				mock.ExpectCommit()
			},
			wantUser: &user{age: 31, id: 1},
		},
		{
			name: "edges/clear_add_m2m",
			spec: &UpdateSpec{
				Node: &NodeSpec{
					Table:   "users",
					Columns: []string{"id", "name", "age"},
					ID:      &FieldSpec{Column: "id", Type: field.TypeInt, Value: 1},
				},
				Edges: EdgeMut{
					Clear: []*EdgeSpec{
						{Rel: M2M, Table: "user_friends", Bidi: true, Columns: []string{"user_id", "friend_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id"}, Nodes: []driver.Value{2}}},
						{Rel: M2M, Inverse: true, Table: "group_users", Columns: []string{"group_id", "user_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id"}, Nodes: []driver.Value{3, 7}}},
						// Clear all "following" edges (and their inverse).
						{Rel: M2M, Table: "user_following", Bidi: true, Columns: []string{"following_id", "follower_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id"}}},
						// Clear all "user_blocked" edges.
						{Rel: M2M, Table: "user_blocked", Columns: []string{"user_id", "blocked_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id"}}},
						// Clear all "comments" edges.
						{Rel: M2M, Inverse: true, Table: "comment_responders", Columns: []string{"comment_id", "responder_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id"}}},
					},
					Add: []*EdgeSpec{
						{Rel: M2M, Table: "user_friends", Bidi: true, Columns: []string{"user_id", "friend_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id"}, Nodes: []driver.Value{4}}},
						{Rel: M2M, Inverse: true, Table: "group_users", Columns: []string{"group_id", "user_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id"}, Nodes: []driver.Value{5}}},
						{Rel: M2M, Inverse: true, Table: "group_users", Columns: []string{"group_id", "user_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id"}, Nodes: []driver.Value{6, 8}}},
					},
				},
			},
			prepare: func(mock sqlmock.Sqlmock) {
				mock.ExpectBegin()
				// Clear comment responders.
				mock.ExpectExec(escape("DELETE FROM `comment_responders` WHERE `responder_id` = ?")).
					WithArgs(1).
					WillReturnResult(sqlmock.NewResult(1, 1))
				// Remove user groups.
				mock.ExpectExec(escape("DELETE FROM `group_users` WHERE `user_id` = ? AND `group_id` IN (?, ?)")).
					WithArgs(1, 3, 7).
					WillReturnResult(sqlmock.NewResult(1, 1))
				// Clear all blocked users.
				mock.ExpectExec(escape("DELETE FROM `user_blocked` WHERE `user_id` = ?")).
					WithArgs(1).
					WillReturnResult(sqlmock.NewResult(1, 1))
				// Clear all user following.
				mock.ExpectExec(escape("DELETE FROM `user_following` WHERE `following_id` = ? OR `follower_id` = ?")).
					WithArgs(1, 1).
					WillReturnResult(sqlmock.NewResult(1, 2))
				// Clear user friends.
				mock.ExpectExec(escape("DELETE FROM `user_friends` WHERE (`user_id` = ? AND `friend_id` = ?) OR (`friend_id` = ? AND `user_id` = ?)")).
					WithArgs(1, 2, 1, 2).
					WillReturnResult(sqlmock.NewResult(1, 1))
				// Add new groups.
				mock.ExpectExec(escape("INSERT INTO `group_users` (`group_id`, `user_id`) VALUES (?, ?), (?, ?), (?, ?)")).
					WithArgs(5, 1, 6, 1, 8, 1).
					WillReturnResult(sqlmock.NewResult(1, 1))
				// Add new friends.
				mock.ExpectExec(escape("INSERT INTO `user_friends` (`user_id`, `friend_id`) VALUES (?, ?), (?, ?)")).
					WithArgs(1, 4, 4, 1).
					WillReturnResult(sqlmock.NewResult(1, 1))
				mock.ExpectQuery(escape("SELECT `id`, `name`, `age` FROM `users` WHERE `id` = ?")).
					WithArgs(1).
					WillReturnRows(sqlmock.NewRows([]string{"id", "age", "name"}).
						AddRow(1, 31, nil))
				mock.ExpectCommit()
			},
			wantUser: &user{age: 31, id: 1},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			db, mock, err := sqlmock.New()
			require.NoError(t, err)
			tt.prepare(mock)
			usr := &user{}
			tt.spec.Assign = usr.assign
			tt.spec.ScanValues = usr.values
			err = UpdateNode(context.Background(), sql.OpenDB("", db), tt.spec)
			require.Equal(t, tt.wantErr, err != nil, err)
			require.Equal(t, tt.wantUser, usr)
		})
	}
}

// TestUpdateNodes pins the statements issued when updating MANY nodes at
// once (optionally filtered by a predicate).
// NOTE(review): this function continues beyond the visible chunk; the
// portion below is reproduced unchanged.
func TestUpdateNodes(t *testing.T) {
	tests := []struct {
		name         string
		spec         *UpdateSpec
		prepare      func(sqlmock.Sqlmock)
		wantErr      bool
		wantAffected int
	}{
		{
			name: "without predicate",
			spec: &UpdateSpec{
				Node: &NodeSpec{
					Table: "users",
					ID:    &FieldSpec{Column: "id", Type: field.TypeInt},
				},
				Fields: FieldMut{
					Set: []*FieldSpec{
						{Column: "age", Type: field.TypeInt, Value: 30},
						{Column: "name", Type: field.TypeString, Value: "Ariel"},
					},
				},
			},
			prepare: func(mock sqlmock.Sqlmock) {
				mock.ExpectBegin()
				// Apply field changes.
				mock.ExpectExec(escape("UPDATE `users` SET `age` = ?, `name` = ?")).
					WithArgs(30, "Ariel").
					WillReturnResult(sqlmock.NewResult(0, 2))
				mock.ExpectCommit()
			},
			wantAffected: 2,
		},
		{
			name: "with predicate",
			spec: &UpdateSpec{
				Node: &NodeSpec{
					Table: "users",
					ID:    &FieldSpec{Column: "id", Type: field.TypeInt},
				},
				Fields: FieldMut{
					Clear: []*FieldSpec{
						{Column: "age", Type: field.TypeInt},
						{Column: "name", Type: field.TypeString},
					},
				},
				Predicate: func(s *sql.Selector) {
					s.Where(sql.EQ("name", "a8m"))
				},
			},
			prepare: func(mock sqlmock.Sqlmock) {
				mock.ExpectBegin()
				// Clear fields.
				mock.ExpectExec(escape("UPDATE `users` SET `age` = NULL, `name` = NULL WHERE `name` = ?")).
					WithArgs("a8m").
					WillReturnResult(sqlmock.NewResult(0, 1))
				mock.ExpectCommit()
			},
			wantAffected: 1,
		},
		{
			name: "own_fks/m2o_o2o_inverse",
			spec: &UpdateSpec{
				Node: &NodeSpec{
					Table: "users",
					ID:    &FieldSpec{Column: "id", Type: field.TypeInt},
				},
				Edges: EdgeMut{
					Clear: []*EdgeSpec{
						{Rel: O2O, Columns: []string{"car_id"}, Inverse: true},
						{Rel: M2O, Columns: []string{"workplace_id"}, Inverse: true},
					},
					Add: []*EdgeSpec{
						{Rel: O2O, Columns: []string{"card_id"}, Inverse: true, Target: &EdgeTarget{Nodes: []driver.Value{3}}},
						{Rel: M2O, Columns: []string{"parent_id"}, Inverse: true, Target: &EdgeTarget{Nodes: []driver.Value{4}}},
					},
				},
			},
			prepare: func(mock sqlmock.Sqlmock) {
				mock.ExpectBegin()
				// Clear "car" and "workplace" foreign_keys and add "card" and a "parent".
				mock.ExpectExec(escape("UPDATE `users` SET `workplace_id` = NULL, `car_id` = NULL, `parent_id` = ?, `card_id` = ?")).
					WithArgs(4, 3).
					WillReturnResult(sqlmock.NewResult(0, 3))
				mock.ExpectCommit()
			},
			wantAffected: 3,
		},
		{
			name: "m2m_one",
			spec: &UpdateSpec{
				Node: &NodeSpec{
					Table: "users",
					ID:    &FieldSpec{Column: "id", Type: field.TypeInt},
				},
				Edges: EdgeMut{
					Clear: []*EdgeSpec{
						{Rel: M2M, Table: "group_users", Columns: []string{"group_id", "user_id"}, Inverse: true, Target: &EdgeTarget{Nodes: []driver.Value{2, 3}}},
						{Rel: M2M, Table: "user_followers", Columns: []string{"user_id", "follower_id"}, Bidi: true, Target: &EdgeTarget{Nodes: []driver.Value{5, 6}}},
						{Rel: M2M, Table: "user_friends", Columns: []string{"user_id", "friend_id"}, Bidi: true, Target: &EdgeTarget{Nodes: []driver.Value{4}}},
					},
					Add: []*EdgeSpec{
						{Rel: M2M, Table: "group_users", Columns: []string{"group_id", "user_id"}, Inverse: true, Target: &EdgeTarget{Nodes: []driver.Value{7, 8}}},
						{Rel: M2M, Table: "user_followers", Columns: []string{"user_id", "follower_id"}, Bidi: true, Target: &EdgeTarget{Nodes: []driver.Value{9}}},
					},
				},
			},
			prepare: func(mock sqlmock.Sqlmock) {
				mock.ExpectBegin()
				// Get all node ids first.
mock.ExpectQuery(escape("SELECT `id` FROM `users`")). WillReturnRows(sqlmock.NewRows([]string{"id"}). AddRow(1)) // Clear user's groups. mock.ExpectExec(escape("DELETE FROM `group_users` WHERE `user_id` = ? AND `group_id` IN (?, ?)")). WithArgs(1, 2, 3). WillReturnResult(sqlmock.NewResult(0, 2)) // Clear user's followers. mock.ExpectExec(escape("DELETE FROM `user_followers` WHERE (`user_id` = ? AND `follower_id` IN (?, ?)) OR (`follower_id` = ? AND `user_id` IN (?, ?))")). WithArgs(1, 5, 6, 1, 5, 6). WillReturnResult(sqlmock.NewResult(0, 2)) // Clear user's friends. mock.ExpectExec(escape("DELETE FROM `user_friends` WHERE (`user_id` = ? AND `friend_id` = ?) OR (`friend_id` = ? AND `user_id` = ?)")). WithArgs(1, 4, 1, 4). WillReturnResult(sqlmock.NewResult(0, 2)) // Attach new groups to user. mock.ExpectExec(escape("INSERT INTO `group_users` (`group_id`, `user_id`) VALUES (?, ?), (?, ?)")). WithArgs(7, 1, 8, 1). WillReturnResult(sqlmock.NewResult(0, 2)) // Attach new friends to user. mock.ExpectExec(escape("INSERT INTO `user_followers` (`user_id`, `follower_id`) VALUES (?, ?), (?, ?)")). WithArgs(1, 9, 9, 1). 
WillReturnResult(sqlmock.NewResult(0, 2)) mock.ExpectCommit() }, wantAffected: 1, }, { name: "m2m_many", spec: &UpdateSpec{ Node: &NodeSpec{ Table: "users", ID: &FieldSpec{Column: "id", Type: field.TypeInt}, }, Edges: EdgeMut{ Clear: []*EdgeSpec{ {Rel: M2M, Table: "group_users", Columns: []string{"group_id", "user_id"}, Inverse: true, Target: &EdgeTarget{Nodes: []driver.Value{2, 3}}}, {Rel: M2M, Table: "user_followers", Columns: []string{"user_id", "follower_id"}, Bidi: true, Target: &EdgeTarget{Nodes: []driver.Value{5, 6}}}, {Rel: M2M, Table: "user_friends", Columns: []string{"user_id", "friend_id"}, Bidi: true, Target: &EdgeTarget{Nodes: []driver.Value{4}}}, }, Add: []*EdgeSpec{ {Rel: M2M, Table: "group_users", Columns: []string{"group_id", "user_id"}, Inverse: true, Target: &EdgeTarget{Nodes: []driver.Value{7, 8}}}, {Rel: M2M, Table: "user_followers", Columns: []string{"user_id", "follower_id"}, Bidi: true, Target: &EdgeTarget{Nodes: []driver.Value{9}}}, }, }, }, prepare: func(mock sqlmock.Sqlmock) { mock.ExpectBegin() // Get all node ids first. mock.ExpectQuery(escape("SELECT `id` FROM `users`")). WillReturnRows(sqlmock.NewRows([]string{"id"}). AddRow(10). AddRow(20)) // Clear user's groups. mock.ExpectExec(escape("DELETE FROM `group_users` WHERE `user_id` IN (?, ?) AND `group_id` IN (?, ?)")). WithArgs(10, 20, 2, 3). WillReturnResult(sqlmock.NewResult(0, 2)) // Clear user's followers. mock.ExpectExec(escape("DELETE FROM `user_followers` WHERE (`user_id` IN (?, ?) AND `follower_id` IN (?, ?)) OR (`follower_id` IN (?, ?) AND `user_id` IN (?, ?))")). WithArgs(10, 20, 5, 6, 10, 20, 5, 6). WillReturnResult(sqlmock.NewResult(0, 2)) // Clear user's friends. mock.ExpectExec(escape("DELETE FROM `user_friends` WHERE (`user_id` IN (?, ?) AND `friend_id` = ?) OR (`friend_id` IN (?, ?) AND `user_id` = ?)")). WithArgs(10, 20, 4, 10, 20, 4). WillReturnResult(sqlmock.NewResult(0, 2)) // Attach new groups to user. 
mock.ExpectExec(escape("INSERT INTO `group_users` (`group_id`, `user_id`) VALUES (?, ?), (?, ?), (?, ?), (?, ?)")). WithArgs(7, 10, 7, 20, 8, 10, 8, 20). WillReturnResult(sqlmock.NewResult(0, 4)) // Attach new friends to user. mock.ExpectExec(escape("INSERT INTO `user_followers` (`user_id`, `follower_id`) VALUES (?, ?), (?, ?), (?, ?), (?, ?)")). WithArgs(10, 9, 9, 10, 20, 9, 9, 20). WillReturnResult(sqlmock.NewResult(0, 4)) mock.ExpectCommit() }, wantAffected: 2, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { db, mock, err := sqlmock.New() require.NoError(t, err) tt.prepare(mock) affected, err := UpdateNodes(context.Background(), sql.OpenDB("", db), tt.spec) require.Equal(t, tt.wantErr, err != nil, err) require.Equal(t, tt.wantAffected, affected) }) } } func TestDeleteNodes(t *testing.T) { db, mock, err := sqlmock.New() require.NoError(t, err) mock.ExpectBegin() mock.ExpectExec(escape("DELETE FROM `users`")). WillReturnResult(sqlmock.NewResult(0, 2)) mock.ExpectCommit() affected, err := DeleteNodes(context.Background(), sql.OpenDB("", db), &DeleteSpec{ Node: &NodeSpec{ Table: "users", ID: &FieldSpec{Column: "id", Type: field.TypeInt}, }, }) require.NoError(t, err) require.Equal(t, 2, affected) } func TestQueryNodes(t *testing.T) { db, mock, err := sqlmock.New() require.NoError(t, err) mock.ExpectQuery(escape("SELECT DISTINCT `users`.`id`, `users`.`age`, `users`.`name`, `users`.`fk1`, `users`.`fk2` FROM `users` WHERE `age` < ? ORDER BY `id` LIMIT 3 OFFSET 4")). WithArgs(40). WillReturnRows(sqlmock.NewRows([]string{"id", "age", "name", "fk1", "fk2"}). AddRow(1, 10, nil, nil, nil). AddRow(2, 20, "", 0, 0). AddRow(3, 30, "a8m", 1, 1)) mock.ExpectQuery(escape("SELECT COUNT(DISTINCT `users`.`id`) FROM `users` WHERE `age` < ? ORDER BY `id` LIMIT 3 OFFSET 4")). WithArgs(40). WillReturnRows(sqlmock.NewRows([]string{"COUNT"}). 
AddRow(3)) var ( users []*user spec = &QuerySpec{ Node: &NodeSpec{ Table: "users", Columns: []string{"id", "age", "name", "fk1", "fk2"}, ID: &FieldSpec{Column: "id", Type: field.TypeInt}, }, Limit: 3, Offset: 4, Unique: true, Order: func(s *sql.Selector) { s.OrderBy("id") }, Predicate: func(s *sql.Selector) { s.Where(sql.LT("age", 40)) }, ScanValues: func(columns []string) ([]interface{}, error) { u := &user{} users = append(users, u) return u.values(columns) }, Assign: func(columns []string, values []interface{}) error { return users[len(users)-1].assign(columns, values) }, } ) // Query and scan. err = QueryNodes(context.Background(), sql.OpenDB("", db), spec) require.NoError(t, err) require.Equal(t, &user{id: 1, age: 10, name: ""}, users[0]) require.Equal(t, &user{id: 2, age: 20, name: ""}, users[1]) require.Equal(t, &user{id: 3, age: 30, name: "a8m", edges: struct{ fk1, fk2 int }{1, 1}}, users[2]) // Count nodes. n, err := CountNodes(context.Background(), sql.OpenDB("", db), spec) require.NoError(t, err) require.Equal(t, 3, n) } func TestQueryEdges(t *testing.T) { db, mock, err := sqlmock.New() require.NoError(t, err) mock.ExpectQuery(escape("SELECT `group_id`, `user_id` FROM `user_groups` WHERE `user_id` IN (?, ?, ?)")). WithArgs(1, 2, 3). WillReturnRows(sqlmock.NewRows([]string{"group_id", "user_id"}). AddRow(4, 5). AddRow(4, 6)) var ( edges [][]int64 spec = &EdgeQuerySpec{ Edge: &EdgeSpec{ Inverse: true, Table: "user_groups", Columns: []string{"user_id", "group_id"}, }, Predicate: func(s *sql.Selector) { s.Where(sql.InValues("user_id", 1, 2, 3)) }, ScanValues: func() [2]interface{} { return [2]interface{}{&sql.NullInt64{}, &sql.NullInt64{}} }, Assign: func(out, in interface{}) error { o, i := out.(*sql.NullInt64), in.(*sql.NullInt64) edges = append(edges, []int64{o.Int64, i.Int64}) return nil }, } ) // Query and scan. 
err = QueryEdges(context.Background(), sql.OpenDB("", db), spec) require.NoError(t, err) require.Equal(t, [][]int64{{4, 5}, {4, 6}}, edges) } func escape(query string) string { rows := strings.Split(query, "\n") for i := range rows { rows[i] = strings.TrimPrefix(rows[i], " ") } query = strings.Join(rows, " ") return strings.TrimSpace(regexp.QuoteMeta(query)) + "$" } ent-0.5.4/dialect/sql/sqljson/000077500000000000000000000000001377533537200162225ustar00rootroot00000000000000ent-0.5.4/dialect/sql/sqljson/sqljson.go000066400000000000000000000302651377533537200202500ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package sqljson import ( "encoding/json" "fmt" "strings" "unicode" "github.com/facebook/ent/dialect" "github.com/facebook/ent/dialect/sql" ) // HasKey return a predicate for checking that a JSON key // exists and not NULL. // // sqljson.HasKey("column", sql.DotPath("a.b[2].c")) // func HasKey(column string, opts ...Option) *sql.Predicate { return sql.P(func(b *sql.Builder) { ValuePath(b, column, opts...) b.WriteOp(sql.OpNotNull) }) } // ValueEQ return a predicate for checking that a JSON value // (returned by the path) is equal to the given argument. // // sqljson.ValueEQ("a", 1, sqljson.Path("b")) // func ValueEQ(column string, arg interface{}, opts ...Option) *sql.Predicate { return sql.P(func(b *sql.Builder) { opts, arg = normalizePG(b, arg, opts) ValuePath(b, column, opts...) b.WriteOp(sql.OpEQ).Arg(arg) }) } // ValueNEQ return a predicate for checking that a JSON value // (returned by the path) is not equal to the given argument. // // sqljson.ValueNEQ("a", 1, sqljson.Path("b")) // func ValueNEQ(column string, arg interface{}, opts ...Option) *sql.Predicate { return sql.P(func(b *sql.Builder) { opts, arg = normalizePG(b, arg, opts) ValuePath(b, column, opts...) 
b.WriteOp(sql.OpNEQ).Arg(arg) }) } // ValueGT return a predicate for checking that a JSON value // (returned by the path) is greater than the given argument. // // sqljson.ValueGT("a", 1, sqljson.Path("b")) // func ValueGT(column string, arg interface{}, opts ...Option) *sql.Predicate { return sql.P(func(b *sql.Builder) { opts, arg = normalizePG(b, arg, opts) ValuePath(b, column, opts...) b.WriteOp(sql.OpGT).Arg(arg) }) } // ValueGTE return a predicate for checking that a JSON value // (returned by the path) is greater than or equal to the given // argument. // // sqljson.ValueGTE("a", 1, sqljson.Path("b")) // func ValueGTE(column string, arg interface{}, opts ...Option) *sql.Predicate { return sql.P(func(b *sql.Builder) { opts, arg = normalizePG(b, arg, opts) ValuePath(b, column, opts...) b.WriteOp(sql.OpGTE).Arg(arg) }) } // ValueLT return a predicate for checking that a JSON value // (returned by the path) is less than the given argument. // // sqljson.ValueLT("a", 1, sqljson.Path("b")) // func ValueLT(column string, arg interface{}, opts ...Option) *sql.Predicate { return sql.P(func(b *sql.Builder) { opts, arg = normalizePG(b, arg, opts) ValuePath(b, column, opts...) b.WriteOp(sql.OpLT).Arg(arg) }) } // ValueLTE return a predicate for checking that a JSON value // (returned by the path) is less than or equal to the given // argument. // // sqljson.ValueLTE("a", 1, sqljson.Path("b")) // func ValueLTE(column string, arg interface{}, opts ...Option) *sql.Predicate { return sql.P(func(b *sql.Builder) { opts, arg = normalizePG(b, arg, opts) ValuePath(b, column, opts...) b.WriteOp(sql.OpLTE).Arg(arg) }) } // ValueContains return a predicate for checking that a JSON // value (returned by the path) contains the given argument. 
// // sqljson.ValueContains("a", 1, sqljson.Path("b")) // func ValueContains(column string, arg interface{}, opts ...Option) *sql.Predicate { return sql.P(func(b *sql.Builder) { path := &PathOptions{Ident: column} for i := range opts { opts[i](path) } switch b.Dialect() { case dialect.MySQL: b.WriteString("JSON_CONTAINS").Nested(func(b *sql.Builder) { b.Ident(column).Comma() b.Arg(marshal(arg)).Comma() path.mysqlPath(b) }) b.WriteOp(sql.OpEQ).Arg(1) case dialect.SQLite: b.WriteString("EXISTS").Nested(func(b *sql.Builder) { b.WriteString("SELECT * FROM JSON_EACH").Nested(func(b *sql.Builder) { b.Ident(column).Comma() path.mysqlPath(b) }) b.WriteString(" WHERE ").Ident("value").WriteOp(sql.OpEQ).Arg(arg) }) case dialect.Postgres: opts, arg = normalizePG(b, arg, opts) path.Cast = "jsonb" path.value(b) b.WriteString(" @> ").Arg(marshal(arg)) } }) } // LenEQ return a predicate for checking that an array length // of a JSON (returned by the path) is equal to the given argument. // // sqljson.LenEQ("a", 1, sqljson.Path("b")) // func LenEQ(column string, size int, opts ...Option) *sql.Predicate { return sql.P(func(b *sql.Builder) { LenPath(b, column, opts...) b.WriteOp(sql.OpEQ).Arg(size) }) } // LenNEQ return a predicate for checking that an array length // of a JSON (returned by the path) is not equal to the given argument. // // sqljson.LenEQ("a", 1, sqljson.Path("b")) // func LenNEQ(column string, size int, opts ...Option) *sql.Predicate { return sql.P(func(b *sql.Builder) { LenPath(b, column, opts...) b.WriteOp(sql.OpNEQ).Arg(size) }) } // LenGT return a predicate for checking that an array length // of a JSON (returned by the path) is greater than the given // argument. // // sqljson.LenGT("a", 1, sqljson.Path("b")) // func LenGT(column string, size int, opts ...Option) *sql.Predicate { return sql.P(func(b *sql.Builder) { LenPath(b, column, opts...) 
b.WriteOp(sql.OpGT).Arg(size) }) } // LenGTE return a predicate for checking that an array length // of a JSON (returned by the path) is greater than or equal to // the given argument. // // sqljson.LenGTE("a", 1, sqljson.Path("b")) // func LenGTE(column string, size int, opts ...Option) *sql.Predicate { return sql.P(func(b *sql.Builder) { LenPath(b, column, opts...) b.WriteOp(sql.OpGTE).Arg(size) }) } // LenLT return a predicate for checking that an array length // of a JSON (returned by the path) is less than the given // argument. // // sqljson.LenLT("a", 1, sqljson.Path("b")) // func LenLT(column string, size int, opts ...Option) *sql.Predicate { return sql.P(func(b *sql.Builder) { LenPath(b, column, opts...) b.WriteOp(sql.OpLT).Arg(size) }) } // LenLTE return a predicate for checking that an array length // of a JSON (returned by the path) is less than or equal to // the given argument. // // sqljson.LenLTE("a", 1, sqljson.Path("b")) // func LenLTE(column string, size int, opts ...Option) *sql.Predicate { return sql.P(func(b *sql.Builder) { LenPath(b, column, opts...) b.WriteOp(sql.OpLTE).Arg(size) }) } // ValuePath writes to the given SQL builder the JSON path for // getting the value of a given JSON path. // // sqljson.ValuePath(b, Path("a", "b", "[1]", "c"), Cast("int")) // func ValuePath(b *sql.Builder, column string, opts ...Option) { path := &PathOptions{Ident: column} for i := range opts { opts[i](path) } path.value(b) } // LenPath writes to the given SQL builder the JSON path for // getting the length of a given JSON path. // // sqljson.LenPath(b, Path("a", "b", "[1]", "c")) // func LenPath(b *sql.Builder, column string, opts ...Option) { path := &PathOptions{Ident: column} for i := range opts { opts[i](path) } path.length(b) } // Option allows for calling database JSON paths with functional options. type Option func(*PathOptions) // Path sets the path to the JSON value of a column. 
// // ValuePath(b, "column", Path("a", "b", "[1]", "c")) // func Path(path ...string) Option { return func(p *PathOptions) { p.Path = path } } // DotPath is similar to Path, but accepts string with dot format. // // ValuePath(b, "column", DotPath("a.b.c")) // ValuePath(b, "column", DotPath("a.b[2].c")) // // Note that DotPath is ignored if the input is invalid. func DotPath(dotpath string) Option { path, _ := ParsePath(dotpath) return func(p *PathOptions) { p.Path = path } } // Unquote indicates that the result value should be unquoted. // // ValuePath(b, "column", Path("a", "b", "[1]", "c"), Unquote(true)) // func Unquote(unquote bool) Option { return func(p *PathOptions) { p.Unquote = unquote } } // Cast indicates that the result value should be casted to the given type. // // ValuePath(b, "column", Path("a", "b", "[1]", "c"), Cast("int")) // func Cast(typ string) Option { return func(p *PathOptions) { p.Cast = typ } } // PathOptions holds the options for accessing a JSON value from an identifier. type PathOptions struct { Ident string Path []string Cast string Unquote bool } // value writes the path for getting the JSON value. func (p *PathOptions) value(b *sql.Builder) { switch { case len(p.Path) == 0: b.Ident(p.Ident) case b.Dialect() == dialect.Postgres: if p.Cast != "" { b.WriteByte('(') defer b.WriteString(")::" + p.Cast) } p.pgPath(b) default: if p.Unquote && b.Dialect() == dialect.MySQL { b.WriteString("JSON_UNQUOTE(") defer b.WriteByte(')') } p.mysqlFunc("JSON_EXTRACT", b) } } // value writes the path for getting the length of a JSON value. func (p *PathOptions) length(b *sql.Builder) { switch { case b.Dialect() == dialect.Postgres: b.WriteString("JSONB_ARRAY_LENGTH(") p.pgPath(b) b.WriteByte(')') case b.Dialect() == dialect.MySQL: p.mysqlFunc("JSON_LENGTH", b) default: p.mysqlFunc("JSON_ARRAY_LENGTH", b) } } // mysqlFunc writes the JSON path in MySQL format for the // the given function. `JSON_EXTRACT("a", '$.b.c')`. 
func (p *PathOptions) mysqlFunc(fn string, b *sql.Builder) { b.WriteString(fn).WriteByte('(') b.Ident(p.Ident).Comma() p.mysqlPath(b) b.WriteByte(')') } // mysqlPath writes the JSON path in MySQL (or SQLite) format. func (p *PathOptions) mysqlPath(b *sql.Builder) { b.WriteString(`"$`) for _, p := range p.Path { if _, ok := isJSONIdx(p); ok { b.WriteString(p) } else { b.WriteString("." + p) } } b.WriteByte('"') } // pgPath writes the JSON path in Postgres format `"a"->'b'->>'c'`. func (p *PathOptions) pgPath(b *sql.Builder) { b.Ident(p.Ident) for i, s := range p.Path { b.WriteString("->") if p.Unquote && i == len(p.Path)-1 { b.WriteString(">") } if idx, ok := isJSONIdx(s); ok { b.WriteString(idx) } else { b.WriteString("'" + s + "'") } } } // ParsePath parses the "dotpath" for the DotPath option. // // "a.b" => ["a", "b"] // "a[1][2]" => ["a", "[1]", "[2]"] // "a.\"b.c\" => ["a", "\"b.c\""] // func ParsePath(dotpath string) ([]string, error) { var ( i, p int path []string ) for i < len(dotpath) { switch r := dotpath[i]; { case r == '"': if i == len(dotpath)-1 { return nil, fmt.Errorf("unexpected quote") } idx := strings.IndexRune(dotpath[i+1:], '"') if idx == -1 || idx == 0 { return nil, fmt.Errorf("unbalanced quote") } i += idx + 2 case r == '[': if p != i { path = append(path, dotpath[p:i]) } p = i if i == len(dotpath)-1 { return nil, fmt.Errorf("unexpected bracket") } idx := strings.IndexRune(dotpath[i:], ']') if idx == -1 || idx == 1 { return nil, fmt.Errorf("unbalanced bracket") } if !isNumber(dotpath[i+1 : i+idx]) { return nil, fmt.Errorf("invalid index %q", dotpath[i:i+idx+1]) } i += idx + 1 case r == '.' || r == ']': if p != i { path = append(path, dotpath[p:i]) } i++ p = i default: i++ } } if p != i { path = append(path, dotpath[p:i]) } return path, nil } // normalizePG adds cast option to the JSON path is the argument type is // not string, in order to avoid "missing type casts" error in Postgres. 
func normalizePG(b *sql.Builder, arg interface{}, opts []Option) ([]Option, interface{}) { if b.Dialect() != dialect.Postgres { return opts, arg } base := []Option{Unquote(true)} switch arg.(type) { case string: case bool: base = append(base, Cast("bool")) case float32, float64: base = append(base, Cast("float")) case int8, int16, int32, int64, int, uint8, uint16, uint32, uint64: base = append(base, Cast("int")) default: // convert unknown types to text. arg = marshal(arg) } return append(base, opts...), arg } // isJSONIdx reports whether the string represents a JSON index. func isJSONIdx(s string) (string, bool) { if len(s) > 2 && s[0] == '[' && s[len(s)-1] == ']' && isNumber(s[1:len(s)-1]) { return s[1 : len(s)-1], true } return "", false } // isNumber reports whether the string is a number (category N). func isNumber(s string) bool { for _, r := range s { if !unicode.IsNumber(r) { return false } } return true } // marshal stringifies the given argument to a valid JSON document. func marshal(arg interface{}) interface{} { if buf, err := json.Marshal(arg); err == nil { arg = string(buf) } return arg } ent-0.5.4/dialect/sql/sqljson/sqljson_test.go000066400000000000000000000171631377533537200213110ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package sqljson_test import ( "strconv" "testing" "github.com/facebook/ent/dialect" "github.com/facebook/ent/dialect/sql" "github.com/facebook/ent/dialect/sql/sqljson" "github.com/stretchr/testify/require" ) func TestWritePath(t *testing.T) { tests := []struct { input sql.Querier wantQuery string wantArgs []interface{} }{ { input: sql.Dialect(dialect.Postgres). Select("*"). From(sql.Table("users")). 
Where(sqljson.ValueEQ("a", 1, sqljson.Path("b", "c", "[1]", "d"), sqljson.Cast("int"))), wantQuery: `SELECT * FROM "users" WHERE ("a"->'b'->'c'->1->>'d')::int = $1`, wantArgs: []interface{}{1}, }, { input: sql.Dialect(dialect.MySQL). Select("*"). From(sql.Table("users")). Where(sqljson.ValueEQ("a", "a", sqljson.DotPath("b.c[1].d"))), wantQuery: "SELECT * FROM `users` WHERE JSON_EXTRACT(`a`, \"$.b.c[1].d\") = ?", wantArgs: []interface{}{"a"}, }, { input: sql.Dialect(dialect.MySQL). Select("*"). From(sql.Table("users")). Where(sqljson.ValueEQ("a", "a", sqljson.DotPath("b.\"c[1]\".d[1][2].e"))), wantQuery: "SELECT * FROM `users` WHERE JSON_EXTRACT(`a`, \"$.b.\"c[1]\".d[1][2].e\") = ?", wantArgs: []interface{}{"a"}, }, { input: sql.Select("*"). From(sql.Table("test")). Where(sqljson.HasKey("j", sqljson.DotPath("a.*.c"))), wantQuery: "SELECT * FROM `test` WHERE JSON_EXTRACT(`j`, \"$.a.*.c\") IS NOT NULL", }, { input: sql.Dialect(dialect.Postgres). Select("*"). From(sql.Table("test")). Where(sqljson.HasKey("j", sqljson.DotPath("a.b.c"))), wantQuery: `SELECT * FROM "test" WHERE "j"->'a'->'b'->'c' IS NOT NULL`, }, { input: sql.Dialect(dialect.Postgres). Select("*"). From(sql.Table("test")). Where(sql.And( sql.EQ("e", 10), sqljson.ValueEQ("a", 1, sqljson.DotPath("b.c")), )), wantQuery: `SELECT * FROM "test" WHERE "e" = $1 AND ("a"->'b'->>'c')::int = $2`, wantArgs: []interface{}{10, 1}, }, { input: sql.Dialect(dialect.MySQL). Select("*"). From(sql.Table("users")). Where(sqljson.ValueEQ("a", "a", sqljson.Path("b", "c", "[1]", "d"), sqljson.Unquote(true))), wantQuery: "SELECT * FROM `users` WHERE JSON_UNQUOTE(JSON_EXTRACT(`a`, \"$.b.c[1].d\")) = ?", wantArgs: []interface{}{"a"}, }, { input: sql.Dialect(dialect.Postgres). Select("*"). From(sql.Table("users")). 
Where(sqljson.ValueEQ("a", "a", sqljson.Path("b", "c", "[1]", "d"), sqljson.Unquote(true))), wantQuery: `SELECT * FROM "users" WHERE "a"->'b'->'c'->1->>'d' = $1`, wantArgs: []interface{}{"a"}, }, { input: sql.Dialect(dialect.Postgres). Select("*"). From(sql.Table("users")). Where(sqljson.ValueEQ("a", 1, sqljson.Path("b", "c", "[1]", "d"), sqljson.Cast("int"))), wantQuery: `SELECT * FROM "users" WHERE ("a"->'b'->'c'->1->>'d')::int = $1`, wantArgs: []interface{}{1}, }, { input: sql.Dialect(dialect.Postgres). Select("*"). From(sql.Table("users")). Where( sql.Or( sqljson.ValueNEQ("a", 1, sqljson.Path("b")), sqljson.ValueGT("a", 1, sqljson.Path("c")), sqljson.ValueGTE("a", 1.1, sqljson.Path("d")), sqljson.ValueLT("a", 1, sqljson.Path("e")), sqljson.ValueLTE("a", 1, sqljson.Path("f")), ), ), wantQuery: `SELECT * FROM "users" WHERE ("a"->>'b')::int <> $1 OR ("a"->>'c')::int > $2 OR ("a"->>'d')::float >= $3 OR ("a"->>'e')::int < $4 OR ("a"->>'f')::int <= $5`, wantArgs: []interface{}{1, 1, 1.1, 1, 1}, }, { input: sql.Dialect(dialect.Postgres). Select("*"). From(sql.Table("users")). Where(sqljson.LenEQ("a", 1)), wantQuery: `SELECT * FROM "users" WHERE JSONB_ARRAY_LENGTH("a") = $1`, wantArgs: []interface{}{1}, }, { input: sql.Dialect(dialect.MySQL). Select("*"). From(sql.Table("users")). Where(sqljson.LenEQ("a", 1)), wantQuery: "SELECT * FROM `users` WHERE JSON_LENGTH(`a`, \"$\") = ?", wantArgs: []interface{}{1}, }, { input: sql.Dialect(dialect.SQLite). Select("*"). From(sql.Table("users")). Where(sqljson.LenEQ("a", 1)), wantQuery: "SELECT * FROM `users` WHERE JSON_ARRAY_LENGTH(`a`, \"$\") = ?", wantArgs: []interface{}{1}, }, { input: sql.Dialect(dialect.SQLite). Select("*"). From(sql.Table("users")). Where( sql.Or( sqljson.LenGT("a", 1, sqljson.Path("b")), sqljson.LenGTE("a", 1, sqljson.Path("c")), sqljson.LenLT("a", 1, sqljson.Path("d")), sqljson.LenLTE("a", 1, sqljson.Path("e")), ), ), wantQuery: "SELECT * FROM `users` WHERE JSON_ARRAY_LENGTH(`a`, \"$.b\") > ? 
OR JSON_ARRAY_LENGTH(`a`, \"$.c\") >= ? OR JSON_ARRAY_LENGTH(`a`, \"$.d\") < ? OR JSON_ARRAY_LENGTH(`a`, \"$.e\") <= ?", wantArgs: []interface{}{1, 1, 1, 1}, }, { input: sql.Dialect(dialect.MySQL). Select("*"). From(sql.Table("users")). Where(sqljson.ValueContains("tags", "foo")), wantQuery: "SELECT * FROM `users` WHERE JSON_CONTAINS(`tags`, ?, \"$\") = ?", wantArgs: []interface{}{"\"foo\"", 1}, }, { input: sql.Dialect(dialect.MySQL). Select("*"). From(sql.Table("users")). Where(sqljson.ValueContains("tags", 1, sqljson.Path("a"))), wantQuery: "SELECT * FROM `users` WHERE JSON_CONTAINS(`tags`, ?, \"$.a\") = ?", wantArgs: []interface{}{"1", 1}, }, { input: sql.Dialect(dialect.SQLite). Select("*"). From(sql.Table("users")). Where(sqljson.ValueContains("tags", "foo")), wantQuery: "SELECT * FROM `users` WHERE EXISTS(SELECT * FROM JSON_EACH(`tags`, \"$\") WHERE `value` = ?)", wantArgs: []interface{}{"foo"}, }, { input: sql.Dialect(dialect.SQLite). Select("*"). From(sql.Table("users")). Where(sqljson.ValueContains("tags", 1, sqljson.Path("a"))), wantQuery: "SELECT * FROM `users` WHERE EXISTS(SELECT * FROM JSON_EACH(`tags`, \"$.a\") WHERE `value` = ?)", wantArgs: []interface{}{1}, }, { input: sql.Dialect(dialect.Postgres). Select("*"). From(sql.Table("users")). Where(sqljson.ValueContains("tags", "foo")), wantQuery: "SELECT * FROM \"users\" WHERE \"tags\" @> $1", wantArgs: []interface{}{"\"foo\""}, }, { input: sql.Dialect(dialect.Postgres). Select("*"). From(sql.Table("users")). 
Where(sqljson.ValueContains("tags", 1, sqljson.Path("a"))), wantQuery: "SELECT * FROM \"users\" WHERE (\"tags\"->'a')::jsonb @> $1", wantArgs: []interface{}{"1"}, }, } for i, tt := range tests { t.Run(strconv.Itoa(i), func(t *testing.T) { query, args := tt.input.Query() require.Equal(t, tt.wantQuery, query) require.Equal(t, tt.wantArgs, args) }) } } func TestParsePath(t *testing.T) { tests := []struct { input string wantPath []string wantErr bool }{ { input: "a.b.c", wantPath: []string{"a", "b", "c"}, }, { input: "a[1][2]", wantPath: []string{"a", "[1]", "[2]"}, }, { input: "a[1][2].b", wantPath: []string{"a", "[1]", "[2]", "b"}, }, { input: `a."b.c[0]"`, wantPath: []string{"a", `"b.c[0]"`}, }, { input: `a."b.c[0]".d`, wantPath: []string{"a", `"b.c[0]"`, "d"}, }, { input: `...`, }, { input: `.a.b.`, wantPath: []string{"a", "b"}, }, { input: `a."`, wantErr: true, }, { input: `a[`, wantErr: true, }, { input: `a[a]`, wantErr: true, }, } for i, tt := range tests { t.Run(strconv.Itoa(i), func(t *testing.T) { path, err := sqljson.ParsePath(tt.input) require.Equal(t, tt.wantPath, path) require.Equal(t, tt.wantErr, err != nil) }) } } ent-0.5.4/doc/000077500000000000000000000000001377533537200130725ustar00rootroot00000000000000ent-0.5.4/doc/.gitignore000077500000000000000000000002771377533537200150730ustar00rootroot00000000000000.DS_Store node_modules lib/core/metadata.js lib/core/MetadataBlog.js website/translated_docs website/build/ website/yarn.lock website/node_modules website/i18n/* website/package-lock.json ent-0.5.4/doc/md/000077500000000000000000000000001377533537200134725ustar00rootroot00000000000000ent-0.5.4/doc/md/aggregate.md000077500000000000000000000014221377533537200157440ustar00rootroot00000000000000--- id: aggregate title: Aggregation --- ## Group By Group by `name` and `age` fields of all users, and sum their total age. 
```go package main import ( "context" "/ent" "/ent/user" ) func Do(ctx context.Context, client *ent.Client) { var v []struct { Name string `json:"name"` Age int `json:"age"` Sum int `json:"sum"` Count int `json:"count"` } err := client.User.Query(). GroupBy(user.FieldName, user.FieldAge). Aggregate(ent.Count(), ent.Sum(user.FieldAge)). Scan(ctx, &v) } ``` Group by one field. ```go package main import ( "context" "/ent" "/ent/user" ) func Do(ctx context.Context, client *ent.Client) { names, err := client.User. Query(). GroupBy(user.FieldName). Strings(ctx) } ``` ent-0.5.4/doc/md/code-gen.md000077500000000000000000000176361377533537200155150ustar00rootroot00000000000000--- id: code-gen title: Introduction --- ## Installation The project comes with a codegen tool called `ent`. In order to install `ent` run the following command: ```bash go get github.com/facebook/ent/cmd/ent ``` ## Initialize A New Schema In order to generate one or more schema templates, run `ent init` as follows: ```bash go run github.com/facebook/ent/cmd/ent init User Pet ``` `init` will create the 2 schemas (`user.go` and `pet.go`) under the `ent/schema` directory. If the `ent` directory does not exist, it will create it as well. The convention is to have an `ent` directory under the root directory of the project. ## Generate Assets After adding a few [fields](schema-fields.md) and [edges](schema-edges.md), you want to generate the assets for working with your entities. Run `ent generate` from the root directory of the project, or use `go generate`: ```bash go generate ./ent ``` The `generate` command generates the following assets for the schemas: - `Client` and `Tx` objects used for interacting with the graph. - CRUD builders for each schema type. See [CRUD](crud.md) for more info. - Entity object (Go struct) for each of the schema types. - Package containing constants and predicates used for interacting with the builders. - A `migrate` package for SQL dialects. 
See [Migration](migrate.md) for more info. ## Version Compatibility Between `entc` And `ent` When working with `ent` CLI in a project, you want to make sure the version being used by the CLI is **identical** to the `ent` version used by your project. One of the options for achieving this is asking `go generate` to use the version mentioned in the `go.mod` file when running `ent`. If your project does not use [Go modules](https://github.com/golang/go/wiki/Modules#quick-start), setup one as follows: ```console go mod init ``` And then, re-run the following command in order to add `ent` to your `go.mod` file: ```console go get github.com/facebook/ent/cmd/ent ``` Add a `generate.go` file to your project under `/ent`: ```go package ent //go:generate go run github.com/facebook/ent/cmd/ent generate ./schema ``` Finally, you can run `go generate ./ent` from the root directory of your project in order to run `ent` code generation on your project schemas. ## Code Generation Options For more info about codegen options, run `ent generate -h`: ```console generate go code for the schema directory Usage: ent generate [flags] path Examples: ent generate ./ent/schema ent generate github.com/a8m/x Flags: --feature strings extend codegen with additional features --header string override codegen header -h, --help help for generate --idtype [int string] type of the id field (default int) --storage string storage driver to support in codegen (default "sql") --target string target directory for codegen --template strings external templates to execute ``` ## Storage Options `ent` can generate assets for both SQL and Gremlin dialect. The default dialect is SQL. ## External Templates `ent` accepts external Go templates to execute. If the template name already defined by `ent`, it will override the existing one. Otherwise, it will write the execution output to a file with the same name as the template. 
The flag format supports `file`, `dir` and `glob` as follows: ```console go run github.com/facebook/ent/cmd/ent generate --template --template glob="path/to/*.tmpl" ./ent/schema ``` More information and examples can be found in the [external templates doc](templates.md). ## Use `entc` As A Package Another option for running `ent` CLI is to use it as a package as follows: ```go package main import ( "log" "github.com/facebook/ent/entc" "github.com/facebook/ent/entc/gen" "github.com/facebook/ent/schema/field" ) func main() { err := entc.Generate("./schema", &gen.Config{ Header: "// Your Custom Header", IDType: &field.TypeInfo{Type: field.TypeInt}, }) if err != nil { log.Fatal("running ent codegen:", err) } } ``` The full example exists in [GitHub](https://github.com/facebook/ent/tree/master/examples/entcpkg). ## Schema Description In order to get a description of your graph schema, run: ```bash go run github.com/facebook/ent/cmd/ent describe ./ent/schema ``` An example for the output is as follows: ```console Pet: +-------+---------+--------+----------+----------+---------+---------------+-----------+-----------------------+------------+ | Field | Type | Unique | Optional | Nillable | Default | UpdateDefault | Immutable | StructTag | Validators | +-------+---------+--------+----------+----------+---------+---------------+-----------+-----------------------+------------+ | id | int | false | false | false | false | false | false | json:"id,omitempty" | 0 | | name | string | false | false | false | false | false | false | json:"name,omitempty" | 0 | +-------+---------+--------+----------+----------+---------+---------------+-----------+-----------------------+------------+ +-------+------+---------+---------+----------+--------+----------+ | Edge | Type | Inverse | BackRef | Relation | Unique | Optional | +-------+------+---------+---------+----------+--------+----------+ | owner | User | true | pets | M2O | true | true | 
+-------+------+---------+---------+----------+--------+----------+ User: +-------+---------+--------+----------+----------+---------+---------------+-----------+-----------------------+------------+ | Field | Type | Unique | Optional | Nillable | Default | UpdateDefault | Immutable | StructTag | Validators | +-------+---------+--------+----------+----------+---------+---------------+-----------+-----------------------+------------+ | id | int | false | false | false | false | false | false | json:"id,omitempty" | 0 | | age | int | false | false | false | false | false | false | json:"age,omitempty" | 0 | | name | string | false | false | false | false | false | false | json:"name,omitempty" | 0 | +-------+---------+--------+----------+----------+---------+---------------+-----------+-----------------------+------------+ +------+------+---------+---------+----------+--------+----------+ | Edge | Type | Inverse | BackRef | Relation | Unique | Optional | +------+------+---------+---------+----------+--------+----------+ | pets | Pet | false | | O2M | false | true | +------+------+---------+---------+----------+--------+----------+ ``` ## Code Generation Hooks The `entc` package provides an option to add a list of hooks (middlewares) to the code-generation phase. This option is ideal for adding custom validators for the schema, or for generating additional assets using the graph schema. ```go // +build ignore package main import ( "fmt" "log" "reflect" "github.com/facebook/ent/entc" "github.com/facebook/ent/entc/gen" ) func main() { err := entc.Generate("./schema", &gen.Config{ Hooks: []gen.Hook{ EnsureStructTag("json"), }, }) if err != nil { log.Fatalf("running ent codegen: %v", err) } } // EnsureStructTag ensures all fields in the graph have a specific tag name. 
func EnsureStructTag(name string) gen.Hook { return func(next gen.Generator) gen.Generator { return gen.GenerateFunc(func(g *gen.Graph) error { for _, node := range g.Nodes { for _, field := range node.Fields { tag := reflect.StructTag(field.StructTag) if _, ok := tag.Lookup(name); !ok { return fmt.Errorf("struct tag %q is missing for field %s.%s", name, node.Name, f.Name) } } } return next.Generate(g) }) } } ``` ## Feature Flags The `entc` package provides a collection of code-generation features that be added or removed using flags. For more information, please see the [features-flags page](features.md). ent-0.5.4/doc/md/crud.md000077500000000000000000000150751377533537200147640ustar00rootroot00000000000000--- id: crud title: CRUD API --- As mentioned in the [introduction](code-gen.md) section, running `ent` on the schemas, will generate the following assets: - `Client` and `Tx` objects used for interacting with the graph. - CRUD builders for each schema type. See [CRUD](crud.md) for more info. - Entity object (Go struct) for each of the schema type. - Package containing constants and predicates used for interacting with the builders. - A `migrate` package for SQL dialects. See [Migration](migrate.md) for more info. 
## Create A New Client **MySQL** ```go package main import ( "log" "/ent" _ "github.com/go-sql-driver/mysql" ) func main() { client, err := ent.Open("mysql", ":@tcp(:)/?parseTime=True") if err != nil { log.Fatal(err) } defer client.Close() } ``` **PostgreSQL** ```go package main import ( "log" "/ent" _ "github.com/lib/pq" ) func main() { client, err := ent.Open("postgres","host= port= user= dbname= password=") if err != nil { log.Fatal(err) } defer client.Close() } ``` **SQLite** ```go package main import ( "log" "/ent" _ "github.com/mattn/go-sqlite3" ) func main() { client, err := ent.Open("sqlite3", "file:ent?mode=memory&cache=shared&_fk=1") if err != nil { log.Fatal(err) } defer client.Close() } ``` **Gremlin (AWS Neptune)** ```go package main import ( "log" "/ent" ) func main() { client, err := ent.Open("gremlin", "http://localhost:8182") if err != nil { log.Fatal(err) } } ``` ## Create An Entity **Save** a user. ```go a8m, err := client.User. // UserClient. Create(). // User create builder. SetName("a8m"). // Set field value. SetNillableAge(age). // Avoid nil checks. AddGroups(g1, g2). // Add many edges. SetSpouse(nati). // Set unique edge. Save(ctx) // Create and return. ``` **SaveX** a pet; Unlike **Save**, **SaveX** panics if an error occurs. ```go pedro := client.Pet. // PetClient. Create(). // Pet create builder. SetName("pedro"). // Set field value. SetOwner(a8m). // Set owner (unique edge). SaveX(ctx) // Create and return. ``` ## Create Many **Save** a bulk of pets. ```go names := []string{"pedro", "xabi", "layla"} bulk := make([]*ent.PetCreate, len(names)) for i, name := range names { bulk[i] = client.Pet.Create().SetName(name).SetOwner(a8m) } pets, err := client.Pet.CreateBulk(bulk...).Save(ctx) ``` ## Update One Update an entity that was returned from the database. ```go a8m, err = a8m.Update(). // User update builder. RemoveGroup(g2). // Remove specific edge. ClearCard(). // Clear unique edge. SetAge(30). 
// Set field value Save(ctx) // Save and return. ``` ## Update By ID ```go pedro, err := client.Pet. // PetClient. UpdateOneID(id). // Pet update builder. SetName("pedro"). // Set field name. SetOwnerID(owner). // Set unique edge, using id. Save(ctx) // Save and return. ``` ## Update Many Filter using predicates. ```go n, err := client.User. // UserClient. Update(). // Pet update builder. Where( // user.Or( // (age >= 30 OR name = "bar") user.AgeEQ(30), // user.Name("bar"), // AND ), // user.HasFollowers(), // UserHasFollowers() ). // SetName("foo"). // Set field name. Save(ctx) // exec and return. ``` Query edge-predicates. ```go n, err := client.User. // UserClient. Update(). // Pet update builder. Where( // user.HasFriendsWith( // UserHasFriendsWith ( user.Or( // age = 20 user.Age(20), // OR user.Age(30), // age = 30 ) // ) ), // ). // SetName("a8m"). // Set field name. Save(ctx) // exec and return. ``` ## Query The Graph Get all users with followers. ```go users, err := client.User. // UserClient. Query(). // User query builder. Where(user.HasFollowers()). // filter only users with followers. All(ctx) // query and return. ``` Get all followers of a specific user; Start the traversal from a node in the graph. ```go users, err := a8m. QueryFollowers(). All(ctx) ``` Get all pets of the followers of a user. ```go users, err := a8m. QueryFollowers(). QueryPets(). All(ctx) ``` More advance traversals can be found in the [next section](traversals.md). ## Field Selection Get all pet names. ```go names, err := client.Pet. Query(). Select(pet.FieldName). Strings(ctx) ``` Select partial objects and partial associations.gs Get all pets and their owners, but select and fill only the `ID` and `Name` fields. ```go pets, err := client.Pet. Query(). Select(pet.FieldName). WithOwner(func (q *ent.UserQuery) { q.Select(user.FieldName) }). All(ctx) ``` Scan all pet names and ages to custom struct. 
```go var v []struct { Age int `json:"age"` Name string `json:"name"` } err := client.Pet. Query(). Select(pet.FieldAge, pet.FieldName). Scan(ctx, &v) if err != nil { log.Fatal(err) } ``` ## Delete One Delete an entity. ```go err := client.User. DeleteOne(a8m). Exec(ctx) ``` Delete by ID. ```go err := client.User. DeleteOneID(id). Exec(ctx) ``` ## Delete Many Delete using predicates. ```go _, err := client.File. Delete(). Where(file.UpdatedAtLT(date)). Exec(ctx) ``` ## Mutation Each generated node type has its own type of mutation. For example, all [`User` builders](crud.md#create-an-entity), share the same generated `UserMutation` object. However, all builder types implement the generic `ent.Mutation` interface. For example, in order to write a generic code that apply a set of methods on both `ent.UserCreate` and `ent.UserUpdate`, use the `UserMutation` object: ```go func Do() { creator := client.User.Create() SetAgeName(creator.Mutation()) updater := client.User.UpdateOneID(id) SetAgeName(updater.Mutation()) } // SetAgeName sets the age and the name for any mutation. func SetAgeName(m *ent.UserMutation) { m.SetAge(32) m.SetName("Ariel") } ``` In some cases, you want to apply a set of methods on multiple types. For cases like this, either use the generic `ent.Mutation` interface, or create your own interface. ```go func Do() { creator1 := client.User.Create() SetName(creator1.Mutation(), "a8m") creator2 := client.Pet.Create() SetName(creator2.Mutation(), "pedro") } // SetNamer wraps the 2 methods for getting // and setting the "name" field in mutations. 
type SetNamer interface { SetName(string) Name() (string, bool) } func SetName(m SetNamer, name string) { if _, exist := m.Name(); !exist { m.SetName(name) } } ``` ent-0.5.4/doc/md/dialects.md000077500000000000000000000016011377533537200156050ustar00rootroot00000000000000--- id: dialects title: Supported Dialects --- ## MySQL MySQL supports all the features that are mentioned in the [Migration](migrate.md) section, and it's being tested constantly on the following 3 versions: `5.6.35`, `5.7.26` and `8`. ## PostgreSQL PostgreSQL supports all the features that are mentioned in the [Migration](migrate.md) section, and it's being tested constantly on the following 3 versions: `10`, `11` and `12`. ## SQLite SQLite supports all _"append-only"_ features mentioned in the [Migration](migrate.md) section. However, dropping or modifying resources, like [drop-index](migrate.md#drop-resources) are not supported by default by SQLite, and will be added in the future using a [temporary table](https://www.sqlite.org/lang_altertable.html#otheralter). ## Gremlin Gremlin does not support migration nor indexes, and **it's considered experimental**. ent-0.5.4/doc/md/eager-load.md000066400000000000000000000052231377533537200160160ustar00rootroot00000000000000--- id: eager-load title: Eager Loading --- ## Overview `ent` supports querying entities with their associations (through their edges). The associated entities are populated to the `Edges` field in the returned object. Let's give an example hows does the API look like for the following schema: ![er-group-users](https://entgo.io/assets/er_user_pets_groups.png) **Query all users with their pets:** ```go users, err := client.User. Query(). WithPets(). All(ctx) if err != nil { return err } // The returned users look as follows: // // [ // User { // ID: 1, // Name: "a8m", // Edges: { // Pets: [Pet(...), ...] // ... // } // }, // ... 
// ] // for _, u := range users { for _, p := range u.Edges.Pets { fmt.Printf("User(%v) -> Pet(%v)\n", u.ID, p.ID) // Output: // User(...) -> Pet(...) } } ``` Eager loading allows to query more than one association (including nested), and also filter, sort or limit their result. For example: ```go admins, err := client.User. Query(). Where(user.Admin(true)). // Populate the `pets` that associated with the `admins`. WithPets(). // Populate the first 5 `groups` that associated with the `admins`. WithGroups(func(q *ent.GroupQuery) { q.Limit(5) // Limit to 5. q.WithUsers().Limit(5) // Populate the `users` of each `groups`. }). All(ctx) if err != nil { return err } // The returned users look as follows: // // [ // User { // ID: 1, // Name: "admin1", // Edges: { // Pets: [Pet(...), ...] // Groups: [ // Group { // ID: 7, // Name: "GitHub", // Edges: { // Users: [User(...), ...] // ... // } // } // ] // } // }, // ... // ] // for _, admin := range admins { for _, p := range admin.Edges.Pets { fmt.Printf("Admin(%v) -> Pet(%v)\n", u.ID, p.ID) // Output: // Admin(...) -> Pet(...) } for _, g := range admin.Edges.Groups { for _, u := range g.Edges.Users { fmt.Printf("Admin(%v) -> Group(%v) -> User(%v)\n", u.ID, g.ID, u.ID) // Output: // Admin(...) -> Group(...) -> User(...) } } } ``` ## API Each query-builder has a list of methods in the form of `With(...func(Query))` for each of its edges. `` stands for the edge name (like, `WithGroups`) and `` for the edge type (like, `GroupQuery`). Note that, only SQL dialects support this feature. ## Implementation Since a query-builder can load more than one association, it's not possible to load them using one `JOIN` operation. Therefore, `ent` executes additional queries for loading associations. One query for `M2O/O2M` and `O2O` edges, and 2 queries for loading `M2M` edges. Note that, we expect to improve this in the next versions of `ent`. 
ent-0.5.4/doc/md/faq.md000066400000000000000000000241671377533537200145750ustar00rootroot00000000000000--- id: faq title: Frequently Asked Questions (FAQ) sidebar_label: FAQ --- ## Questions [How to create an entity from a struct `T`?](#how-to-create-an-entity-from-a-struct-t) [How to create a struct (or a mutation) level validator?](#how-to-create-a-mutation-level-validator) [How to write an audit-log extension?](#how-to-write-an-audit-log-extension) [How to write custom predicates?](#how-to-write-custom-predicates) [How to add custom predicates to the codegen assets?](#how-to-add-custom-predicates-to-the-codegen-assets) [How to define a network address field in PostgreSQL?](#how-to-define-a-network-address-field-in-postgresql) [How to customize time fields to type `DATETIME` in MySQL?](#how-to-customize-time-fields-to-type-datetime-in-mysql) ## Answers #### How to create an entity from a struct `T`? The different builders don't support the option of setting the entity fields (or edges) from a given struct `T`. The reason is that there's no way to distinguish between zero/real values when updating the database (for example, `&ent.T{Age: 0, Name: ""}`). Setting these values, may set incorrect values in the database or update unnecessary columns. However, the [external template](templates.md) option lets you extend the default code-generation assets by adding custom logic. For example, in order to generate a method for each of the create-builders, that accepts a struct as an input and configure the builder, use the following template: ```gotemplate {{ range $n := $.Nodes }} {{ $builder := $n.CreateName }} {{ $receiver := receiver $builder }} func ({{ $receiver }} *{{ $builder }}) Set{{ $n.Name }}(input *{{ $n.Name }}) *{{ $builder }} { {{- range $f := $n.Fields }} {{- $setter := print "Set" $f.StructField }} {{ $receiver }}.{{ $setter }}(input.{{ $f.StructField }}) {{- end }} return {{ $receiver }} } {{ end }} ``` #### How to create a mutation level validator? 
In order to implement a mutation-level validator, you can either use [schema hooks](hooks.md#schema-hooks) for validating changes applied on one entity type, or use [transaction hooks](transactions.md#hooks) for validating mutations that being applied on multiple entity types (e.g. a GraphQL mutation). For example: ```go // A VersionHook is a dummy example for a hook that validates the "version" field // is incremented by 1 on each update. Note that this is just a dummy example, and // it doesn't promise consistency in the database. func VersionHook() ent.Hook { type OldSetVersion interface { SetVersion(int) Version() (int, bool) OldVersion(context.Context) (int, error) } return func(next ent.Mutator) ent.Mutator { return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) { ver, ok := m.(OldSetVersion) if !ok { return next.Mutate(ctx, m) } oldV, err := ver.OldVersion(ctx) if err != nil { return nil, err } curV, exists := ver.Version() if !exists { return nil, fmt.Errorf("version field is required in update mutation") } if curV != oldV+1 { return nil, fmt.Errorf("version field must be incremented by 1") } // Add an SQL predicate that validates the "version" column is equal // to "oldV" (ensure it wasn't changed during the mutation by others). return next.Mutate(ctx, m) }) } } ``` #### How to write an audit-log extension? The preferred way for writing such an extension is to use [ent.Mixin](schema-mixin.md). Use the `Fields` option for setting the fields that are shared between all schemas that import the mixed-schema, and use the `Hooks` option for attaching a mutation-hook for all mutations that are being applied on these schemas. Here's an example, based on a discussion in the [repository issue-tracker](https://github.com/facebook/ent/issues/830): ```go // AuditMixin implements the ent.Mixin for sharing // audit-log capabilities with package schemas. type AuditMixin struct{ mixin.Schema } // Fields of the AuditMixin. 
func (AuditMixin) Fields() []ent.Field { return []ent.Field{ field.Time("created_at"). Immutable(). Default(time.Now), field.Int("created_by"). Optional(), field.Time("updated_at"). Default(time.Now). UpdateDefault(time.Now), field.Int("updated_by"). Optional(), } } // Hooks of the AuditMixin. func (AuditMixin) Hooks() []ent.Hook { return []ent.Hook{ hooks.AuditHook, } } // A AuditHook is an example for audit-log hook. func AuditHook(next ent.Mutator) ent.Mutator { // AuditLogger wraps the methods that are shared between all mutations of // schemas that embed the AuditLog mixin. The variable "exists" is true, if // the field already exists in the mutation (e.g. was set by a different hook). type AuditLogger interface { SetCreatedAt(time.Time) CreatedAt() (value time.Time, exists bool) SetCreatedBy(int) CreatedBy() (id int, exists bool) SetUpdatedAt(time.Time) UpdatedAt() (value time.Time, exists bool) SetUpdatedBy(int) UpdatedBy() (id int, exists bool) } return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) { ml, ok := m.(AuditLogger) if !ok { return nil, fmt.Errorf("unexpected audit-log call from mutation type %T", m) } usr, err := viewer.UserFromContext(ctx) if err != nil { return nil, err } switch op := m.Op(); { case op.Is(ent.OpCreate): ml.SetCreatedAt(time.Now()) if _, exists := ml.CreatedBy(); !exists { ml.SetCreatedBy(usr.ID) } case op.Is(ent.OpUpdateOne | ent.OpUpdate): ml.SetUpdatedAt(time.Now()) if _, exists := ml.UpdatedBy(); !exists { ml.SetUpdatedBy(usr.ID) } } return next.Mutate(ctx, m) }) } ``` #### How to write custom predicates? Users can provide custom predicates to apply on the query before it's executed. For example: ```go pets := client.Pet. Query(). Where(predicate.Pet(func(s *sql.Selector) { s.Where(sql.InInts(pet.OwnerColumn, 1, 2, 3)) })). AllX(ctx) users := client.User. Query(). Where(predicate.User(func(s *sql.Selector) { s.Where(sqljson.ValueContains(user.FieldTags, "tag")) })). 
AllX(ctx) ``` For more examples, go to the [predicates](predicates.md#custom-predicates) page, or search in the repository issue-tracker for more advance examples like [issue-842](https://github.com/facebook/ent/issues/842#issuecomment-707896368). #### How to add custom predicates to the codegen assets? The [template](templates.md) option enables the capability for extending or overriding the default codegen assets. In order to generate a type-safe predicate for the [example above](#how-to-write-custom-predicates), use the template option for doing it as follows: ```gotemplate {{/* A template that adds the "Glob" predicate for all string fields. */}} {{ define "where/additional/strings" }} {{ range $f := $.Fields }} {{ if $f.IsString }} {{ $func := print $f.StructField "Glob" }} // {{ $func }} applies the Glob predicate on the {{ quote $f.Name }} field. func {{ $func }}(pattern string) predicate.{{ $.Name }} { return predicate.{{ $.Name }}(func(s *sql.Selector) { s.Where(sql.P(func(b *sql.Builder) { b.Ident(s.C({{ $f.Constant }})).WriteString(" glob" ).Arg(pattern) })) }) } {{ end }} {{ end }} {{ end }} ``` #### How to define a network address field in PostgreSQL? The [GoType](http://localhost:3000/docs/schema-fields#go-type) and the [SchemaType](http://localhost:3000/docs/schema-fields#database-type) options allow users to define database-specific fields. For example, in order to define a [`macaddr`](https://www.postgresql.org/docs/13/datatype-net-types.html#DATATYPE-MACADDR) field, use the following configuration: ```go func (T) Fields() []ent.Field { return []ent.Field{ field.String("mac"). GoType(&MAC{}). SchemaType(map[string]string{ dialect.Postgres: "macaddr", }). Validate(func(s string) error { _, err := net.ParseMAC(s) return err }), } } // MAC represents a physical hardware address. type MAC struct { net.HardwareAddr } // Scan implements the Scanner interface. 
func (m *MAC) Scan(value interface{}) (err error) { switch v := value.(type) { case nil: case []byte: m.HardwareAddr, err = net.ParseMAC(string(v)) case string: m.HardwareAddr, err = net.ParseMAC(v) default: err = fmt.Errorf("unexpected type %T", v) } return } // Value implements the driver Valuer interface. func (m MAC) Value() (driver.Value, error) { return m.HardwareAddr.String(), nil } ``` Note that, if the database doesn't support the `macaddr` type (e.g. SQLite on testing), the field fallback to its native type (i.e. `string`). `inet` example: ```go func (T) Fields() []ent.Field { return []ent.Field{ field.String("ip"). GoType(&Inet{}). SchemaType(map[string]string{ dialect.Postgres: "inet", }). Validate(func(s string) error { if net.ParseIP(s) == nil { return fmt.Errorf("invalid value for ip %q", s) } return nil }), } } // Inet represents a single IP address type Inet struct { net.IP } // Scan implements the Scanner interface func (i *Inet) Scan(value interface{}) (err error) { switch v := value.(type) { case nil: case []byte: if i.IP = net.ParseIP(string(v)); i.IP == nil { err = fmt.Errorf("invalid value for ip %q", s) } case string: if i.IP = net.ParseIP(v); i.IP == nil { err = fmt.Errorf("invalid value for ip %q", s) } default: err = fmt.Errorf("unexpected type %T", v) } return } // Value implements the driver Valuer interface func (i Inet) Value() (driver.Value, error) { return i.IP.String(), nil } ``` #### How to customize time fields to type `DATETIME` in MySQL? `Time` fields use the MySQL `TIMESTAMP` type in the schema creation by default, and this type has a range of '1970-01-01 00:00:01' UTC to '2038-01-19 03:14:07' UTC (see, [MySQL docs](https://dev.mysql.com/doc/refman/5.6/en/datetime.html)). In order to customize time fields for a wider range, use the MySQL `DATETIME` as follows: ```go field.Time("birth_date"). Optional(). 
SchemaType(map[string]string{ dialect.MySQL: "datetime", }), ```ent-0.5.4/doc/md/features.md000066400000000000000000000036441377533537200156410ustar00rootroot00000000000000--- id: feature-flags title: Feature Flags sidebar_label: Feature Flags --- The framework provides a collection of code-generation features that be added or removed using flags. ## Usage Feature flags can be provided either by CLI flags or as arguments to the `gen` package. #### CLI ```console go run github.com/facebook/ent/cmd/ent generate --feature privacy,entql ./ent/schema ``` #### Go ```go // +build ignore package main import ( "log" "text/template" "github.com/facebook/ent/entc" "github.com/facebook/ent/entc/gen" ) func main() { err := entc.Generate("./schema", &gen.Config{ Features: []*gen.Feature{ gen.FeaturePrivacy, gen.FeatureEntQL, }, Templates: []*gen.Template{ gen.MustParse(gen.NewTemplate("static"). Funcs(template.FuncMap{"title": strings.ToTitle}). ParseFiles("template/static.tmpl")), }, }) if err != nil { log.Fatalf("running ent codegen: %v", err) } } ``` ## List of Features #### Privacy Layer The privacy layer allows configuring privacy policy for queries and mutations of entities in the database. This option can be added to projects using the `--feature privacy` flag, and its full documentation exists in the [privacy page](privacy.md). #### EntQL Filtering The `entql` option provides a generic and dynamic filtering capability at runtime for the different query builders. This option can be added to projects using the `--feature entql` flag, and more information about it exists in the [privacy page](privacy.md#multi-tenancy). #### Auto-Solve Merge Conflicts The `schema/snapshot` option tells `entc` (ent codegen) to store a snapshot of the latest schema in an internal package, and use it to automatically solve merge conflicts when user's schema can't be built. 
This option can be added to projects using the `--feature schema/snapshot` flag, but please see [facebook/ent/issues/852](https://github.com/facebook/ent/issues/852) to get more context about it.ent-0.5.4/doc/md/getting-started.md000077500000000000000000000340311377533537200171250ustar00rootroot00000000000000--- id: getting-started title: Quick Introduction sidebar_label: Quick Introduction --- **ent** is a simple, yet powerful entity framework for Go, that makes it easy to build and maintain applications with large data-models and sticks with the following principles: - Easily model database schema as a graph structure. - Define schema as a programmatic Go code. - Static typing based on code generation. - Database queries and graph traversals are easy to write. - Simple to extend and customize using Go templates.
![gopher-schema-as-code](https://entgo.io/assets/gopher-schema-as-code.png) ## Installation ```console go get github.com/facebook/ent/cmd/ent ``` After installing `ent` codegen tool, you should have it in your `PATH`. If you don't find it your path, you can also run: `go run github.com/facebook/ent/cmd/ent ` ## Setup A Go Environment If your project directory is outside [GOPATH](https://github.com/golang/go/wiki/GOPATH) or you are not familiar with GOPATH, setup a [Go module](https://github.com/golang/go/wiki/Modules#quick-start) project as follows: ```console go mod init ``` ## Create Your First Schema Go to the root directory of your project, and run: ```console ent init User ``` The command above will generate the schema for `User` under `/ent/schema/` directory: ```go // /ent/schema/user.go package schema import "github.com/facebook/ent" // User holds the schema definition for the User entity. type User struct { ent.Schema } // Fields of the User. func (User) Fields() []ent.Field { return nil } // Edges of the User. func (User) Edges() []ent.Edge { return nil } ``` Add 2 fields to the `User` schema: ```go package schema import ( "github.com/facebook/ent" "github.com/facebook/ent/schema/field" ) // Fields of the User. func (User) Fields() []ent.Field { return []ent.Field{ field.Int("age"). Positive(), field.String("name"). Default("unknown"), } } ``` Run `go generate` from the root directory of the project as follows: ```go go generate ./ent ``` This produces the following files: ``` ent ├── client.go ├── config.go ├── context.go ├── ent.go ├── migrate │ ├── migrate.go │ └── schema.go ├── predicate │ └── predicate.go ├── schema │ └── user.go ├── tx.go ├── user │ ├── user.go │ └── where.go ├── user.go ├── user_create.go ├── user_delete.go ├── user_query.go └── user_update.go ``` ## Create Your First Entity To get started, create a new `ent.Client`. For this example, we will use SQLite3. 
```go package main import ( "context" "log" "/ent" _ "github.com/mattn/go-sqlite3" ) func main() { client, err := ent.Open("sqlite3", "file:ent?mode=memory&cache=shared&_fk=1") if err != nil { log.Fatalf("failed opening connection to sqlite: %v", err) } defer client.Close() // Run the auto migration tool. if err := client.Schema.Create(context.Background()); err != nil { log.Fatalf("failed creating schema resources: %v", err) } } ``` Now, we're ready to create our user. Let's call this function `CreateUser` for the sake of example: ```go func CreateUser(ctx context.Context, client *ent.Client) (*ent.User, error) { u, err := client.User. Create(). SetAge(30). SetName("a8m"). Save(ctx) if err != nil { return nil, fmt.Errorf("failed creating user: %v", err) } log.Println("user was created: ", u) return u, nil } ``` ## Query Your Entities `ent` generates a package for each entity schema that contains its predicates, default values, validators and additional information about storage elements (column names, primary keys, etc). ```go package main import ( "log" "/ent" "/ent/user" ) func QueryUser(ctx context.Context, client *ent.Client) (*ent.User, error) { u, err := client.User. Query(). Where(user.NameEQ("a8m")). // `Only` fails if no user found, // or more than 1 user returned. Only(ctx) if err != nil { return nil, fmt.Errorf("failed querying user: %v", err) } log.Println("user returned: ", u) return u, nil } ``` ## Add Your First Edge (Relation) In this part of the tutorial, we want to declare an edge (relation) to another entity in the schema. Let's create 2 additional entities named `Car` and `Group` with a few fields. We use `ent` CLI to generate the initial schemas: ```console go run github.com/facebook/ent/cmd/ent init Car Group ``` And then we add the rest of the fields manually: ```go import ( "regexp" "github.com/facebook/ent" "github.com/facebook/ent/schema/field" ) // Fields of the Car. 
func (Car) Fields() []ent.Field { return []ent.Field{ field.String("model"), field.Time("registered_at"), } } // Fields of the Group. func (Group) Fields() []ent.Field { return []ent.Field{ field.String("name"). // Regexp validation for group name. Match(regexp.MustCompile("[a-zA-Z_]+$")), } } ``` Let's define our first relation. An edge from `User` to `Car` defining that a user can **have 1 or more** cars, but a car **has only one** owner (one-to-many relation). ![er-user-cars](https://entgo.io/assets/re_user_cars.png) Let's add the `"cars"` edge to the `User` schema, and run `go generate ./ent`: ```go import ( "log" "github.com/facebook/ent" "github.com/facebook/ent/schema/edge" ) // Edges of the User. func (User) Edges() []ent.Edge { return []ent.Edge{ edge.To("cars", Car.Type), } } ``` We continue our example by creating 2 cars and adding them to a user. ```go func CreateCars(ctx context.Context, client *ent.Client) (*ent.User, error) { // Create a new car with model "Tesla". tesla, err := client.Car. Create(). SetModel("Tesla"). SetRegisteredAt(time.Now()). Save(ctx) if err != nil { return nil, fmt.Errorf("failed creating car: %v", err) } // Create a new car with model "Ford". ford, err := client.Car. Create(). SetModel("Ford"). SetRegisteredAt(time.Now()). Save(ctx) if err != nil { return nil, fmt.Errorf("failed creating car: %v", err) } log.Println("car was created: ", ford) // Create a new user, and add it the 2 cars. a8m, err := client.User. Create(). SetAge(30). SetName("a8m"). AddCars(tesla, ford). Save(ctx) if err != nil { return nil, fmt.Errorf("failed creating user: %v", err) } log.Println("user was created: ", a8m) return a8m, nil } ``` But what about querying the `cars` edge (relation)? 
Here's how we do it: ```go import ( "log" "/ent" "/ent/car" ) func QueryCars(ctx context.Context, a8m *ent.User) error { cars, err := a8m.QueryCars().All(ctx) if err != nil { return fmt.Errorf("failed querying user cars: %v", err) } log.Println("returned cars:", cars) // What about filtering specific cars. ford, err := a8m.QueryCars(). Where(car.ModelEQ("Ford")). Only(ctx) if err != nil { return fmt.Errorf("failed querying user cars: %v", err) } log.Println(ford) return nil } ``` ## Add Your First Inverse Edge (BackRef) Assume we have a `Car` object and we want to get its owner; the user that this car belongs to. For this, we have another type of edge called "inverse edge" that is defined using the `edge.From` function. ![er-cars-owner](https://entgo.io/assets/re_cars_owner.png) The new edge created in the diagram above is translucent, to emphasize that we don't create another edge in the database. It's just a back-reference to the real edge (relation). Let's add an inverse edge named `owner` to the `Car` schema, reference it to the `cars` edge in the `User` schema, and run `go generate ./ent`. ```go import ( "log" "github.com/facebook/ent" "github.com/facebook/ent/schema/edge" ) // Edges of the Car. func (Car) Edges() []ent.Edge { return []ent.Edge{ // Create an inverse-edge called "owner" of type `User` // and reference it to the "cars" edge (in User schema) // explicitly using the `Ref` method. edge.From("owner", User.Type). Ref("cars"). // setting the edge to unique, ensure // that a car can have only one owner. Unique(), } } ``` We'll continue the user/cars example above by querying the inverse edge. ```go import ( "log" "/ent" ) func QueryCarUsers(ctx context.Context, a8m *ent.User) error { cars, err := a8m.QueryCars().All(ctx) if err != nil { return fmt.Errorf("failed querying user cars: %v", err) } // Query the inverse edge. 
for _, ca := range cars { owner, err := ca.QueryOwner().Only(ctx) if err != nil { return fmt.Errorf("failed querying car %q owner: %v", ca.Model, err) } log.Printf("car %q owner: %q\n", ca.Model, owner.Name) } return nil } ``` ## Create Your Second Edge We'll continue our example by creating a M2M (many-to-many) relationship between users and groups. ![er-group-users](https://entgo.io/assets/re_group_users.png) As you can see, each group entity can **have many** users, and a user can **be connected to many** groups; a simple "many-to-many" relationship. In the above illustration, the `Group` schema is the owner of the `users` edge (relation), and the `User` entity has a back-reference/inverse edge to this relationship named `groups`. Let's define this relationship in our schemas: - `/ent/schema/group.go`: ```go import ( "log" "github.com/facebook/ent" "github.com/facebook/ent/schema/edge" ) // Edges of the Group. func (Group) Edges() []ent.Edge { return []ent.Edge{ edge.To("users", User.Type), } } ``` - `/ent/schema/user.go`: ```go import ( "log" "github.com/facebook/ent" "github.com/facebook/ent/schema/edge" ) // Edges of the User. func (User) Edges() []ent.Edge { return []ent.Edge{ edge.To("cars", Car.Type), // Create an inverse-edge called "groups" of type `Group` // and reference it to the "users" edge (in Group schema) // explicitly using the `Ref` method. edge.From("groups", Group.Type). Ref("users"), } } ``` We run `ent` on the schema directory to re-generate the assets. ```console go generate ./ent ``` ## Run Your First Graph Traversal In order to run our first graph traversal, we need to generate some data (nodes and edges, or in other words, entities and relations). Let's create the following graph using the framework: ![re-graph](https://entgo.io/assets/re_graph_getting_started.png) ```go func CreateGraph(ctx context.Context, client *ent.Client) error { // First, create the users. a8m, err := client.User. Create(). SetAge(30). SetName("Ariel"). 
Save(ctx) if err != nil { return err } neta, err := client.User. Create(). SetAge(28). SetName("Neta"). Save(ctx) if err != nil { return err } // Then, create the cars, and attach them to the users in the creation. _, err = client.Car. Create(). SetModel("Tesla"). SetRegisteredAt(time.Now()). // ignore the time in the graph. SetOwner(a8m). // attach this graph to Ariel. Save(ctx) if err != nil { return err } _, err = client.Car. Create(). SetModel("Mazda"). SetRegisteredAt(time.Now()). // ignore the time in the graph. SetOwner(a8m). // attach this graph to Ariel. Save(ctx) if err != nil { return err } _, err = client.Car. Create(). SetModel("Ford"). SetRegisteredAt(time.Now()). // ignore the time in the graph. SetOwner(neta). // attach this graph to Neta. Save(ctx) if err != nil { return err } // Create the groups, and add their users in the creation. _, err = client.Group. Create(). SetName("GitLab"). AddUsers(neta, a8m). Save(ctx) if err != nil { return err } _, err = client.Group. Create(). SetName("GitHub"). AddUsers(a8m). Save(ctx) if err != nil { return err } log.Println("The graph was created successfully") return nil } ``` Now when we have a graph with data, we can run a few queries on it: 1. Get all user's cars within the group named "GitHub": ```go import ( "log" "/ent" "/ent/group" ) func QueryGithub(ctx context.Context, client *ent.Client) error { cars, err := client.Group. Query(). Where(group.Name("GitHub")). // (Group(Name=GitHub),) QueryUsers(). // (User(Name=Ariel, Age=30),) QueryCars(). // (Car(Model=Tesla, RegisteredAt=
`ent.Mutation` interface. ## Hooks Hooks are functions that get an `ent.Mutator` and return a mutator back. They function as middleware between mutators. It's similar to the popular HTTP middleware pattern. ```go type ( // Mutator is the interface that wraps the Mutate method. Mutator interface { // Mutate apply the given mutation on the graph. Mutate(context.Context, Mutation) (Value, error) } // Hook defines the "mutation middleware". A function that gets a Mutator // and returns a Mutator. For example: // // hook := func(next ent.Mutator) ent.Mutator { // return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) { // fmt.Printf("Type: %s, Operation: %s, ConcreteType: %T\n", m.Type(), m.Op(), m) // return next.Mutate(ctx, m) // }) // } // Hook func(Mutator) Mutator ) ``` There are 2 types of mutation hooks - **schema hooks** and **runtime hooks**. **Schema hooks** are mainly used for defining custom mutation logic in the schema, and **runtime hooks** are used for adding things like logging, metrics, tracing, etc. Let's go over the 2 versions: ## Runtime hooks Let's start with a short example that logs all mutation operations of all types: ```go func main() { client, err := ent.Open("sqlite3", "file:ent?mode=memory&cache=shared&_fk=1") if err != nil { log.Fatalf("failed opening connection to sqlite: %v", err) } defer client.Close() ctx := context.Background() // Run the auto migration tool. if err := client.Schema.Create(ctx); err != nil { log.Fatalf("failed creating schema resources: %v", err) } // Add a global hook that runs on all types and all operations. 
client.Use(func(next ent.Mutator) ent.Mutator { return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) { start := time.Now() defer func() { log.Printf("Op=%s\tType=%s\tTime=%s\tConcreteType=%T\n", m.Op(), m.Type(), time.Since(start), m) }() return next.Mutate(ctx, m) }) }) client.User.Create().SetName("a8m").SaveX(ctx) // Output: // 2020/03/21 10:59:10 Op=Create Type=User Time=46.23µs ConcreteType=*ent.UserMutation } ``` Global hooks are useful for adding traces, metrics, logs and more. But sometimes, users want more granularity: ```go func main() { // // Add a hook only on user mutations. client.User.Use(func(next ent.Mutator) ent.Mutator { // Use the "/ent/hook" to get the concrete type of the mutation. return hook.UserFunc(func(ctx context.Context, m *ent.UserMutation) (ent.Value, error) { return next.Mutate(ctx, m) }) }) // Add a hook only on update operations. client.Use(hook.On(Logger(), ent.OpUpdate|ent.OpUpdateOne)) // Reject delete operations. client.Use(hook.Reject(ent.OpDelete|ent.OpDeleteOne)) } ``` Assume you want to share a hook that mutates a field between multiple types (e.g. `Group` and `User`). There are ~2 ways to do this: ```go // Option 1: use type assertion. client.Use(func(next ent.Mutator) ent.Mutator { type NameSetter interface { SetName(value string) } return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) { // A schema with a "name" field must implement the NameSetter interface. if ns, ok := m.(NameSetter); ok { ns.SetName("Ariel Mashraki") } return next.Mutate(ctx, m) }) }) // Option 2: use the generic ent.Mutation interface. client.Use(func(next ent.Mutator) ent.Mutator { return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) { if err := m.SetField("name", "Ariel Mashraki"); err != nil { // An error is returned if the field is not defined in // the schema, or if the type mismatches the field type.
} return next.Mutate(ctx, m) }) }) ``` ## Schema hooks Schema hooks are defined in the type schema and applied only on mutations that match the schema type. The motivation for defining hooks in the schema is to gather all logic regarding the node type in one place, which is the schema. ```go package schema import ( "context" "fmt" gen "/ent" "/ent/hook" "github.com/facebook/ent" ) // Card holds the schema definition for the CreditCard entity. type Card struct { ent.Schema } // Hooks of the Card. func (Card) Hooks() []ent.Hook { return []ent.Hook{ // First hook. hook.On( func(next ent.Mutator) ent.Mutator { return hook.CardFunc(func(ctx context.Context, m *gen.CardMutation) (ent.Value, error) { if num, ok := m.Number(); ok && len(num) < 10 { return nil, fmt.Errorf("card number is too short") } return next.Mutate(ctx, m) }) }, // Limit the hook only for these operations. ent.OpCreate|ent.OpUpdate|ent.OpUpdateOne, ), // Second hook. func(next ent.Mutator) ent.Mutator { return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) { if s, ok := m.(interface{ SetName(string) }); ok { s.SetName("Boring") } return next.Mutate(ctx, m) }) }, } } ``` > **Note that** if you use **schema hooks**, you **MUST** add the following import in the > main package, because a circular import is possible. > > ```go > import _ "/ent/runtime" > ``` ## Evaluation order Hooks are called in the order they were registered to the client. Thus, `client.Use(f, g, h)` executes `f(g(h(...)))` on mutations. Also note, that **runtime hooks** are called before **schema hooks**. That is, if `g`, and `h` were defined in the schema, and `f` was registered using `client.Use(...)`, they will be executed as follows: `f(g(h(...)))`. ## Hook helpers The generated hooks package provides several helpers that can help you control when a hook will be executed. 
```go package schema import ( "context" "fmt" "/ent/hook" "github.com/facebook/ent" "github.com/facebook/ent/schema/mixin" ) type SomeMixin struct { mixin.Schema } func (SomeMixin) Hooks() []ent.Hook { return []ent.Hook{ // Execute "HookA" only for the UpdateOne and DeleteOne operations. hook.On(HookA(), ent.OpUpdateOne|ent.OpDeleteOne), // Don't execute "HookB" on Create operation. hook.Unless(HookB(), ent.OpCreate), // Execute "HookC" only if the ent.Mutation is changing the "status" field, // and clearing the "dirty" field. hook.If(HookC(), hook.And(hook.HasFields("status"), hook.HasClearedFields("dirty"))), } } ``` ## Transaction Hooks Hooks can also be registered on active transactions, and will be executed on `Tx.Commit` or `Tx.Rollback`. For more information, read about it in the [transactions page](transactions.md#hooks). ## Codegen Hooks The `entc` package provides an option to add a list of hooks (middlewares) to the code-generation phase. For more information, read about it in the [codegen page](code-gen.md#code-generation-hooks). ent-0.5.4/doc/md/migrate.md000077500000000000000000000125621377533537200154550ustar00rootroot00000000000000--- id: migrate title: Database Migration --- The migration support for `ent` provides the option for keeping the database schema aligned with the schema objects defined in `ent/migrate/schema.go` under the root of your project. ## Auto Migration Run the auto-migration logic in the initialization of the application: ```go if err := client.Schema.Create(ctx); err != nil { log.Fatalf("failed creating schema resources: %v", err) } ``` `Create` creates all database resources needed for your `ent` project. By default, `Create` works in an *"append-only"* mode; which means, it only creates new tables and indexes, appends columns to tables or extends column types. For example, changing `int` to `bigint`. What about dropping columns or indexes? 
## Drop Resources `WithDropIndex` and `WithDropColumn` are 2 options for dropping table columns and indexes. ```go package main import ( "context" "log" "/ent" "/ent/migrate" ) func main() { client, err := ent.Open("mysql", "root:pass@tcp(localhost:3306)/test") if err != nil { log.Fatalf("failed connecting to mysql: %v", err) } defer client.Close() ctx := context.Background() // Run migration. err = client.Schema.Create( ctx, migrate.WithDropIndex(true), migrate.WithDropColumn(true), ) if err != nil { log.Fatalf("failed creating schema resources: %v", err) } } ``` In order to run the migration in debug mode (printing all SQL queries), run: ```go err := client.Debug().Schema.Create( ctx, migrate.WithDropIndex(true), migrate.WithDropColumn(true), ) if err != nil { log.Fatalf("failed creating schema resources: %v", err) } ``` ## Universal IDs By default, SQL primary-keys start from 1 for each table; which means that multiple entities of different types can share the same ID. Unlike AWS Neptune, where node IDs are UUIDs. This does not work well if you work with [GraphQL](https://graphql.org/learn/schema/#scalar-types), which requires the object ID to be unique. To enable the Universal-IDs support for your project, pass the `WithGlobalUniqueID` option to the migration. ```go package main import ( "context" "log" "/ent" "/ent/migrate" ) func main() { client, err := ent.Open("mysql", "root:pass@tcp(localhost:3306)/test") if err != nil { log.Fatalf("failed connecting to mysql: %v", err) } defer client.Close() ctx := context.Background() // Run migration. if err := client.Schema.Create(ctx, migrate.WithGlobalUniqueID(true)); err != nil { log.Fatalf("failed creating schema resources: %v", err) } } ``` **How does it work?** `ent` migration allocates a 1<<32 range for the IDs of each entity (table), and store this information in a table named `ent_types`. 
For example, type `A` will have the range of `[1,4294967296)` for its IDs, and type `B` will have the range of `[4294967296,8589934592)`, etc. Note that if this option is enabled, the maximum number of possible tables is **65535**. ## Offline Mode Offline mode allows you to write the schema changes to an `io.Writer` before executing them on the database. It's useful for verifying the SQL commands before they're executed on the database, or to get an SQL script to run manually. **Print changes** ```go package main import ( "context" "log" "os" "/ent" "/ent/migrate" ) func main() { client, err := ent.Open("mysql", "root:pass@tcp(localhost:3306)/test") if err != nil { log.Fatalf("failed connecting to mysql: %v", err) } defer client.Close() ctx := context.Background() // Dump migration changes to stdout. if err := client.Schema.WriteTo(ctx, os.Stdout); err != nil { log.Fatalf("failed printing schema changes: %v", err) } } ``` **Write changes to a file** ```go package main import ( "context" "log" "os" "/ent" "/ent/migrate" ) func main() { client, err := ent.Open("mysql", "root:pass@tcp(localhost:3306)/test") if err != nil { log.Fatalf("failed connecting to mysql: %v", err) } defer client.Close() ctx := context.Background() // Dump migration changes to an SQL script. f, err := os.Create("migrate.sql") if err != nil { log.Fatalf("create migrate file: %v", err) } defer f.Close() if err := client.Schema.WriteTo(ctx, f); err != nil { log.Fatalf("failed printing schema changes: %v", err) } } ``` ## Foreign Keys By default, `ent` uses foreign-keys when defining relationships (edges) to enforce correctness and consistency on the database side. However, `ent` also provide an option to disable this functionality using the `WithForeignKeys` option. You should note that setting this option to `false`, will tell the migration to not create foreign-keys in the schema DDL and the edges validation and clearing must be handled manually by the developer. 
We expect to provide a set of hooks for implementing the foreign-key constraints in the application level in the near future. ```go package main import ( "context" "log" "/ent" "/ent/migrate" ) func main() { client, err := ent.Open("mysql", "root:pass@tcp(localhost:3306)/test") if err != nil { log.Fatalf("failed connecting to mysql: %v", err) } defer client.Close() ctx := context.Background() // Run migration. err = client.Schema.Create( ctx, migrate.WithForeignKeys(false), // Disable foreign keys. ) if err != nil { log.Fatalf("failed creating schema resources: %v", err) } } ```ent-0.5.4/doc/md/paging.md000077500000000000000000000011061377533537200152620ustar00rootroot00000000000000--- id: paging title: Paging And Ordering --- ## Limit `Limit` limits the query result to `n` entities. ```go users, err := client.User. Query(). Limit(n). All(ctx) ``` ## Offset `Offset` sets the first node to return from the query. ```go users, err := client.User. Query(). Offset(10). All(ctx) ``` ## Ordering `Order` returns the entities sorted by the values of one or more fields. Note that, an error is returned if the given fields are not valid columns or foreign-keys. ```go users, err := client.User.Query(). Order(ent.Asc(user.FieldName)). All(ctx) ``` ent-0.5.4/doc/md/predicates.md000077500000000000000000000031501377533537200161410ustar00rootroot00000000000000--- id: predicates title: Predicates --- ## Field Predicates - **Bool**: - =, != - **Numeric**: - =, !=, >, <, >=, <=, - IN, NOT IN - **Time**: - =, !=, >, <, >=, <= - IN, NOT IN - **String**: - =, !=, >, <, >=, <= - IN, NOT IN - Contains, HasPrefix, HasSuffix - ContainsFold, EqualFold (**SQL** specific) - **JSON** - =, != - =, !=, >, <, >=, <= on nested values (JSON path). - Contains on nested values (JSON path). - HasKey, Len

- **Optional** fields: - IsNil, NotNil ## Edge Predicates - **HasEdge**. For example, for edge named `owner` of type `Pet`, use: ```go client.Pet. Query(). Where(pet.HasOwner()). All(ctx) ``` - **HasEdgeWith**. Add list of predicates for edge predicate. ```go client.Pet. Query(). Where(pet.HasOwnerWith(user.Name("a8m"))). All(ctx) ``` ## Negation (NOT) ```go client.Pet. Query(). Where(pet.Not(pet.NameHasPrefix("Ari"))). All(ctx) ``` ## Disjunction (OR) ```go client.Pet. Query(). Where( pet.Or( pet.HasOwner(), pet.Not(pet.HasFriends()), ) ). All(ctx) ``` ## Conjunction (AND) ```go client.Pet. Query(). Where( pet.And( pet.HasOwner(), pet.Not(pet.HasFriends()), ) ). All(ctx) ``` ## Custom Predicates Custom predicates can be useful if you want to write your own dialect-specific logic. ```go pets := client.Pet. Query(). Where(predicate.Pet(func(s *sql.Selector) { s.Where(sql.InInts(pet.OwnerColumn, 1, 2, 3)) })). AllX(ctx) users := client.User. Query(). Where(predicate.User(func(s *sql.Selector) { s.Where(sqljson.HasKey(user.FieldURL, sqljson.Path("Scheme"))) })). AllX(ctx) ``` ent-0.5.4/doc/md/privacy.md000066400000000000000000000435771377533537200155110ustar00rootroot00000000000000--- id: privacy title: Privacy --- The `Policy` option in the schema allows configuring privacy policy for queries and mutations of entities in the database. ![gopher-privacy](https://entgo.io/assets/gopher-privacy-opacity.png) The main advantage of the privacy layer is that, you write the privacy policy **once** (in the schema), and it is **always** evaluated. No matter where queries and mutations are performed in your codebase, it will always go through the privacy layer. In this tutorial, we will start by going over the basic terms we use in the framework, continue with a section for configuring the policy feature to your project, and finish with a few examples. ## Basic Terms ### Policy The `ent.Policy` interface contains two methods: `EvalQuery` and `EvalMutation`. 
The first defines the read-policy, and the second defines the write-policy. A policy contains zero or more privacy rules (see below). These rules are evaluated in the same order they are declared in the schema. If all rules are evaluated without returning an error, the evaluation finishes successfully, and the executed operation gets access to the target nodes. ![privacy-rules](https://entgo.io/assets/permission_1.png) However, if one of the evaluated rules returns an error or a `privacy.Deny` decision (see below), the executed operation returns an error, and it is cancelled. ![privacy-deny](https://entgo.io/assets/permission_2.png) ### Privacy Rules Each policy (mutation or query) includes one or more privacy rules. The function signature for these rules is as follows: ```go // EvalQuery defines a read-policy rule. func(Policy) EvalQuery(context.Context, Query) error // EvalMutation defines a write-policy rule. func(Policy) EvalMutation(context.Context, Mutation) error ``` ### Privacy Decisions There are three types of decision that can help you control the privacy rules evaluation. - `privacy.Allow` - If returned from a privacy rule, the evaluation stops (next rules will be skipped), and the executed operation (query or mutation) gets access to the target nodes. - `privacy.Deny` - If returned from a privacy rule, the evaluation stops (next rules will be skipped), and the executed operation is cancelled. This is equivalent to returning any error. - `privacy.Skip` - Skip the current rule, and jump to the next privacy rule. This is equivalent to returning a `nil` error. ![privacy-allow](https://entgo.io/assets/permission_3.png) Now that we’ve covered the basic terms, let’s start writing some code. ## Configuration In order to enable the privacy option in your code generation, enable the `privacy` feature with one of two options: 1\.
If you are using the default go generate config, add `--feature privacy` option to the `ent/generate.go` file as follows: ```go package ent //go:generate go run github.com/facebook/ent/cmd/ent generate --feature privacy ./schema ``` It is recommended to add the [`schema/snapshot`](features.md#auto-solve-merge-conflicts) feature-flag along with the `privacy` to enhance the development experience (e.g. `--feature privacy,schema/snapshot`) 2\. If you are using the configuration from the GraphQL documentation, add the feature flag as follows: ```go // Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. // +build ignore package main import ( "log" "github.com/facebook/ent/entc" "github.com/facebook/ent/entc/gen" "github.com/facebookincubator/ent-contrib/entgql" ) func main() { opts := []entc.Option{ entc.FeatureNames("privacy"), } err := entc.Generate("./schema", &gen.Config{ Templates: entgql.AllTemplates, }, opts...) if err != nil { log.Fatalf("running ent codegen: %v", err) } } ``` > You should notice that, similar to schema hooks, if you use the **`Policy`** option in your schema, > you **MUST** add the following import in the main package, because a circular import is possible: > > ```go > import _ "/ent/runtime" > ``` ## Examples ### Admin Only We start with a simple example of an application that lets any user read any data, and accepts mutations only from users with admin role. We will create 2 additional packages for the purpose of the examples: - `rule` - for holding the different privacy rules in our schema. - `viewer` - for getting and setting the user/viewer who's executing the operation. In this simple example, it can be either a normal user or an admin.
After running the code-generation (with the feature-flag for privacy), we add the `Policy` method with 2 generated policy rules. ```go package schema import ( "github.com/facebook/ent" "github.com/facebook/ent/examples/privacyadmin/ent/privacy" ) // User holds the schema definition for the User entity. type User struct { ent.Schema } // Policy defines the privacy policy of the User. func (User) Policy() ent.Policy { return privacy.Policy{ Mutation: privacy.MutationPolicy{ // Deny if not set otherwise. privacy.AlwaysDenyRule(), }, Query: privacy.QueryPolicy{ // Allow any viewer to read anything. privacy.AlwaysAllowRule(), }, } } ``` We defined a policy that rejects any mutation and accepts any query. However, as mentioned above, in this example, we accept mutations only from viewers with admin role. Let's create 2 privacy rules to enforce this: ```go package rule import ( "context" "github.com/facebook/ent/examples/privacyadmin/ent/privacy" "github.com/facebook/ent/examples/privacyadmin/viewer" ) // DenyIfNoViewer is a rule that returns Deny decision if the viewer is // missing in the context. func DenyIfNoViewer() privacy.QueryMutationRule { return privacy.ContextQueryMutationRule(func(ctx context.Context) error { view := viewer.FromContext(ctx) if view == nil { return privacy.Denyf("viewer-context is missing") } // Skip to the next privacy rule (equivalent to returning nil). return privacy.Skip }) } // AllowIfAdmin is a rule that returns Allow decision if the viewer is admin. func AllowIfAdmin() privacy.QueryMutationRule { return privacy.ContextQueryMutationRule(func(ctx context.Context) error { view := viewer.FromContext(ctx) if view.Admin() { return privacy.Allow } // Skip to the next privacy rule (equivalent to returning nil). return privacy.Skip }) } ``` As you can see, the first rule `DenyIfNoViewer`, makes sure every operation has a viewer in its context, otherwise, the operation rejected. 
The second rule `AllowIfAdmin`, accepts any operation from viewer with admin role. Let's add them to the schema, and run the code-generation: ```go // Policy defines the privacy policy of the User. func (User) Policy() ent.Policy { return privacy.Policy{ Mutation: privacy.MutationPolicy{ rule.DenyIfNoViewer(), rule.AllowIfAdmin(), privacy.AlwaysDenyRule(), }, Query: privacy.QueryPolicy{ privacy.AlwaysAllowRule(), }, } } ``` Since we define the `DenyIfNoViewer` first, it will be executed before all other rules, and accessing the `viewer.Viewer` object is safe in the `AllowIfAdmin` rule. After adding the rules above and running the code-generation, we expect the privacy-layer logic to be applied on `ent.Client` operations. ```go func Do(ctx context.Context, client *ent.Client) error { // Expect operation to fail, because viewer-context // is missing (first mutation rule check). if _, err := client.User.Create().Save(ctx); !errors.Is(err, privacy.Deny) { return fmt.Errorf("expect operation to fail, but got %v", err) } // Apply the same operation with "Admin" role. admin := viewer.NewContext(ctx, viewer.UserViewer{Role: viewer.Admin}) if _, err := client.User.Create().Save(admin); err != nil { return fmt.Errorf("expect operation to pass, but got %v", err) } // Apply the same operation with "ViewOnly" role. viewOnly := viewer.NewContext(ctx, viewer.UserViewer{Role: viewer.View}) if _, err := client.User.Create().Save(viewOnly); !errors.Is(err, privacy.Deny) { return fmt.Errorf("expect operation to fail, but got %v", err) } // Allow all viewers to query users. for _, ctx := range []context.Context{ctx, viewOnly, admin} { // Operation should pass for all viewers. count := client.User.Query().CountX(ctx) fmt.Println(count) } return nil } ``` ### Decision Context Sometimes, we want to bind a specific privacy decision to the `context.Context`. In cases like this, we can use the `privacy.DecisionContext` function to create a new context with a privacy decision attached to it. 
```go func Do(ctx context.Context, client *ent.Client) error { // Bind a privacy decision to the context (bypass all other rules). allow := privacy.DecisionContext(ctx, privacy.Allow) if _, err := client.User.Create().Save(allow); err != nil { return fmt.Errorf("expect operation to pass, but got %v", err) } return nil } ``` The full example exists in [GitHub](https://github.com/facebook/ent/tree/master/examples/privacyadmin). ### Multi Tenancy In this example, we're going to create a schema with 3 entity types - `Tenant`, `User` and `Group`. The helper packages `viewer` and `rule` (as mentioned above) also exist in this example to help us structure the application. ![tenant-example](https://entgo.io/assets/tenant_medium.png) Let's start building this application piece by piece. We begin by creating 3 different schemas (see the full code [here](https://github.com/facebook/ent/tree/master/examples/privacytenant/ent/schema)), and since we want to share some logic between them, we create another [mixed-in schema](schema-mixin.md) and add it to all other schemas as follows: ```go // BaseMixin for all schemas in the graph. type BaseMixin struct { mixin.Schema } // Policy defines the privacy policy of the BaseMixin. func (BaseMixin) Policy() ent.Policy { return privacy.Policy{ Mutation: privacy.MutationPolicy{ rule.DenyIfNoViewer(), }, Query: privacy.QueryPolicy{ rule.DenyIfNoViewer(), }, } } // Mixin of the Tenant schema. func (Tenant) Mixin() []ent.Mixin { return []ent.Mixin{ BaseMixin{}, } } ``` As explained in the first example, the `DenyIfNoViewer` privacy rule denies the operation if the `context.Context` does not contain the `viewer.Viewer` information. Similar to the previous example, we want to add a constraint that only admin users can create tenants (and deny otherwise). We do it by copying the `AllowIfAdmin` rule from above, and adding it to the `Policy` of the `Tenant` schema: ```go // Policy defines the privacy policy of the Tenant.
func (Tenant) Policy() ent.Policy { return privacy.Policy{ Mutation: privacy.MutationPolicy{ // For Tenant type, we only allow admin users to mutate // the tenant information and deny otherwise. rule.AllowIfAdmin(), privacy.AlwaysDenyRule(), }, } } ``` Then, we expect the following code to run successfully: ```go func Do(ctx context.Context, client *ent.Client) error { // Expect operation to fail, because viewer-context // is missing (first mutation rule check). if _, err := client.Tenant.Create().Save(ctx); !errors.Is(err, privacy.Deny) { return fmt.Errorf("expect operation to fail, but got %v", err) } // Deny tenant creation if the viewer is not admin. viewOnly := viewer.NewContext(ctx, viewer.UserViewer{Role: viewer.View}) if _, err := client.Tenant.Create().Save(viewOnly); !errors.Is(err, privacy.Deny) { return fmt.Errorf("expect operation to fail, but got %v", err) } // Apply the same operation with "Admin" role. admin := viewer.NewContext(ctx, viewer.UserViewer{Role: viewer.Admin}) hub, err := client.Tenant.Create().SetName("GitHub").Save(admin) if err != nil { return fmt.Errorf("expect operation to pass, but got %v", err) } fmt.Println(hub) lab, err := client.Tenant.Create().SetName("GitLab").Save(admin) if err != nil { return fmt.Errorf("expect operation to pass, but got %v", err) } fmt.Println(lab) return nil } ``` We continue by adding the rest of the edges in our data-model (see image above), and since both `User` and `Group` have an edge to the `Tenant` schema, we create a shared [mixed-in schema](schema-mixin.md) named `TenantMixin` for this: ```go // TenantMixin for embedding the tenant info in different schemas. type TenantMixin struct { mixin.Schema } // Edges for all schemas that embed TenantMixin. func (TenantMixin) Edges() []ent.Edge { return []ent.Edge{ edge.To("tenant", Tenant.Type). Unique(). Required(), } } ``` Now, we want to enforce that viewers can see only groups and users that are connected to the tenant they belong to. 
In this case, there's another type of privacy rule named `FilterRule`. This rule can help us to filters out entities that are not connected to the same tenant. > Note, the filtering option for privacy needs to be enabled using the `entql` feature-flag (see instructions [above](#configuration)). ```go // FilterTenantRule is a query rule that filters out entities that are not in the tenant. func FilterTenantRule() privacy.QueryRule { type TeamsFilter interface { WhereHasTenantWith(...predicate.Tenant) } return privacy.FilterFunc(func(ctx context.Context, f privacy.Filter) error { view := viewer.FromContext(ctx) if view.Tenant() == "" { return privacy.Denyf("missing tenant information in viewer") } tf, ok := f.(TeamsFilter) if !ok { return privacy.Denyf("unexpected filter type %T", f) } // Make sure that a tenant reads only entities that has an edge to it. tf.WhereHasTenantWith(tenant.Name(view.Tenant())) // Skip to the next privacy rule (equivalent to return nil). return privacy.Skip }) } ``` After creating the `FilterTenantRule` privacy rule, we add it to the `TenantMixin` to make sure **all schemas** that use this mixin, will also have this privacy rule. ```go // Policy for all schemas that embed TenantMixin. func (TenantMixin) Policy() ent.Policy { return privacy.Policy{ Query: privacy.QueryPolicy{ rule.AllowIfAdmin(), // Filter out entities that are not connected to the tenant. // If the viewer is admin, this policy rule is skipped above. rule.FilterTenantRule(), }, } } ``` Then, after running the code-generation, we expect the privacy-rules to take effect on the client operations. ```go func Do(ctx context.Context, client *ent.Client) error { // A continuation of the code-block above. // Create 2 users connected to the 2 tenants we created above (a8m->GitHub, nati->GitLab). 
a8m := client.User.Create().SetName("a8m").SetTenant(hub).SaveX(admin) nati := client.User.Create().SetName("nati").SetTenant(lab).SaveX(admin) hubView := viewer.NewContext(ctx, viewer.UserViewer{T: hub}) out := client.User.Query().OnlyX(hubView) // Expect that "GitHub" tenant to read only its users (i.e. a8m). if out.ID != a8m.ID { return fmt.Errorf("expect result for user query, got %v", out) } fmt.Println(out) labView := viewer.NewContext(ctx, viewer.UserViewer{T: lab}) out = client.User.Query().OnlyX(labView) // Expect that "GitLab" tenant to read only its users (i.e. nati). if out.ID != nati.ID { return fmt.Errorf("expect result for user query, got %v", out) } fmt.Println(out) return nil } ``` We finish our example with another privacy-rule named `DenyMismatchedTenants` on the `Group` schema. The `DenyMismatchedTenants` rule rejects the group creation if the associated users don't belong to the same tenant as the group. ```go // DenyMismatchedTenants is a rule runs only on create operations, and returns a deny decision // if the operation tries to add users to groups that are not in the same tenant. func DenyMismatchedTenants() privacy.MutationRule { // Create a rule, and limit it to create operations below. rule := privacy.GroupMutationRuleFunc(func(ctx context.Context, m *ent.GroupMutation) error { tid, exists := m.TenantID() if !exists { return privacy.Denyf("missing tenant information in mutation") } users := m.UsersIDs() // If there are no users in the mutation, skip this rule-check. if len(users) == 0 { return privacy.Skip } // Query the tenant-id of all users. Expect to have exact 1 result, // and it matches the tenant-id of the group above. 
uid, err := m.Client().User.Query().Where(user.IDIn(users...)).QueryTenant().OnlyID(ctx) if err != nil { return privacy.Denyf("querying the tenant-id %v", err) } if uid != tid { return privacy.Denyf("mismatch tenant-ids for group/users %d != %d", tid, uid) } // Skip to the next privacy rule (equivalent to return nil). return privacy.Skip }) // Evaluate the mutation rule only on group creation. return privacy.OnMutationOperation(rule, ent.OpCreate) } ``` We add this rule to the `Group` schema and run code-generation. ```go // Policy defines the privacy policy of the Group. func (Group) Policy() ent.Policy { return privacy.Policy{ Mutation: privacy.MutationPolicy{ rule.DenyMismatchedTenants(), }, } } ``` Again, we expect the privacy-rules to take effect on the client operations. ```go func Do(ctx context.Context, client *ent.Client) error { // A continuation of the code-block above. // We expect operation to fail, because the DenyMismatchedTenants rule // makes sure the group and the users are connected to the same tenant. _, err = client.Group.Create().SetName("entgo.io").SetTenant(hub).AddUsers(nati).Save(admin) if !errors.Is(err, privacy.Deny) { return fmt.Errorf("expect operation to fail, since user (nati) is not connected to the same tenant") } _, err = client.Group.Create().SetName("entgo.io").SetTenant(hub).AddUsers(nati, a8m).Save(admin) if !errors.Is(err, privacy.Deny) { return fmt.Errorf("expect operation to fail, since some users (nati) are not connected to the same tenant") } entgo, err := client.Group.Create().SetName("entgo.io").SetTenant(hub).AddUsers(a8m).Save(admin) if err != nil { return fmt.Errorf("expect operation to pass, but got %v", err) } fmt.Println(entgo) return nil } ``` The full example exists in [GitHub](https://github.com/facebook/ent/tree/master/examples/privacytenant). 
Please note that this documentation is under active development.ent-0.5.4/doc/md/schema-annotations.md000077500000000000000000000022321377533537200176110ustar00rootroot00000000000000--- id: schema-annotations title: Annotations --- Schema annotations allow attaching metadata to schema objects like fields and edges and inject them to external templates. An annotation must be a Go type that is serializable to JSON raw value (e.g. struct, map or slice) and implement the [Annotation](https://pkg.go.dev/github.com/facebook/ent/schema/field?tab=doc#Annotation) interface. The builtin annotations allow configuring the different storage drivers (like SQL) and control the code generation output. ## Custom Table Name A custom table name can be provided for types using the `entsql` annotation as follows: ```go package schema import ( "github.com/facebook/ent" "github.com/facebook/ent/dialect/entsql" "github.com/facebook/ent/schema" "github.com/facebook/ent/schema/field" ) // User holds the schema definition for the User entity. type User struct { ent.Schema } // Annotations of the User. func (User) Annotations() []schema.Annotation { return []schema.Annotation{ entsql.Annotation{Table: "Users"}, } } // Fields of the User. func (User) Fields() []ent.Field { return []ent.Field{ field.Int("age"), field.String("name"), } } ``` ent-0.5.4/doc/md/schema-def.md000077500000000000000000000030351377533537200160140ustar00rootroot00000000000000--- id: schema-def title: Introduction --- ## Quick Summary Schema describes the definition of one entity type in the graph, like `User` or `Group`, and can contain the following configurations: - Entity fields (or properties), like: name or age of a `User`. - Entity edges (or relations), like: `User`'s groups, or `User`'s friends. - Database specific options, like: indexes or unique indexes.
Here's an example of a schema: ```go package schema import ( "github.com/facebook/ent" "github.com/facebook/ent/schema/field" "github.com/facebook/ent/schema/edge" "github.com/facebook/ent/schema/index" ) type User struct { ent.Schema } func (User) Fields() []ent.Field { return []ent.Field{ field.Int("age"), field.String("name"), field.String("nickname"). Unique(), } } func (User) Edges() []ent.Edge { return []ent.Edge{ edge.To("groups", Group.Type), edge.To("friends", User.Type), } } func (User) Indexes() []ent.Index { return []ent.Index{ index.Fields("age", "name"). Unique(), } } ``` Entity schemas are usually stored inside `ent/schema` directory under the root directory of your project, and can be generated by `entc` as follows: ```console go run github.com/facebook/ent/cmd/ent init User Group ``` ## It's Just Another ORM If you are used to the definition of relations over edges, that's fine. The modeling is the same. You can model with `ent` whatever you can model with other traditional ORMs. There are many examples in this website that can help you get started in the [Edges](schema-edges.md) section. ent-0.5.4/doc/md/schema-edges.md000077500000000000000000000517361377533537200163550ustar00rootroot00000000000000--- id: schema-edges title: Edges --- ## Quick Summary Edges are the relations (or associations) of entities. For example, user's pets, or group's users. ![er-group-users](https://entgo.io/assets/er_user_pets_groups.png) In the example above, you can see 2 relations declared using edges. Let's go over them. 1\. `pets` / `owner` edges; user's pets and pet's owner - `ent/schema/user.go` ```go package schema import ( "github.com/facebook/ent" "github.com/facebook/ent/schema/edge" ) // User schema. type User struct { ent.Schema } // Fields of the user. func (User) Fields() []ent.Field { return []ent.Field{ // ... } } // Edges of the user. 
func (User) Edges() []ent.Edge { return []ent.Edge{ edge.To("pets", Pet.Type), } } ``` `ent/schema/pet.go` ```go package schema import ( "github.com/facebook/ent" "github.com/facebook/ent/schema/edge" ) // Pet holds the schema definition for the Pet entity. type Pet struct { ent.Schema } // Fields of the Pet. func (Pet) Fields() []ent.Field { return []ent.Field{ // ... } } // Edges of the Pet. func (Pet) Edges() []ent.Edge { return []ent.Edge{ edge.From("owner", User.Type). Ref("pets"). Unique(), } } ``` As you can see, a `User` entity can **have many** pets, but a `Pet` entity can **have only one** owner. In relationship definition, the `pets` edge is a *O2M* (one-to-many) relationship, and the `owner` edge is a *M2O* (many-to-one) relationship. The `User` schema **owns** the `pets/owner` relationship because it uses `edge.To`, and the `Pet` schema just has a back-reference to it, declared using `edge.From` with the `Ref` method. The `Ref` method describes which edge of the `User` schema we're referencing because there can be multiple references from one schema to other. The cardinality of the edge/relationship can be controlled using the `Unique` method, and it's explained more widely below. 2\. `users` / `groups` edges; group's users and user's groups - `ent/schema/group.go` ```go package schema import ( "github.com/facebook/ent" "github.com/facebook/ent/schema/edge" ) // Group schema. type Group struct { ent.Schema } // Fields of the group. func (Group) Fields() []ent.Field { return []ent.Field{ // ... } } // Edges of the group. func (Group) Edges() []ent.Edge { return []ent.Edge{ edge.To("users", User.Type), } } ``` `ent/schema/user.go` ```go package schema import ( "github.com/facebook/ent" "github.com/facebook/ent/schema/edge" ) // User schema. type User struct { ent.Schema } // Fields of the user. func (User) Fields() []ent.Field { return []ent.Field{ // ... } } // Edges of the user. 
func (User) Edges() []ent.Edge { return []ent.Edge{ edge.From("groups", Group.Type). Ref("users"), // "pets" declared in the example above. edge.To("pets", Pet.Type), } } ``` As you can see, a Group entity can **have many** users, and a User entity can **have many** groups. In relationship definition, the `users` edge is a *M2M* (many-to-many) relationship, and the `groups` edge is also a *M2M* (many-to-many) relationship. ## To and From `edge.To` and `edge.From` are the 2 builders for creating edges/relations. A schema that defines an edge using the `edge.To` builder owns the relation, unlike using the `edge.From` builder that gives only a back-reference for the relation (with a different name). Let's go over a few examples that show how to define different relation types using edges. ## Relationship - [O2O Two Types](#o2o-two-types) - [O2O Same Type](#o2o-same-type) - [O2O Bidirectional](#o2o-bidirectional) - [O2M Two Types](#o2m-two-types) - [O2M Same Type](#o2m-same-type) - [M2M Two Types](#m2m-two-types) - [M2M Same Type](#m2m-same-type) - [M2M Bidirectional](#m2m-bidirectional) ## O2O Two Types ![er-user-card](https://entgo.io/assets/er_user_card.png) In this example, a user **has only one** credit-card, and a card **has only one** owner. The `User` schema defines an `edge.To` card named `card`, and the `Card` schema defines a back-reference to this edge using `edge.From` named `owner`. `ent/schema/user.go` ```go // Edges of the user. func (User) Edges() []ent.Edge { return []ent.Edge{ edge.To("card", Card.Type). Unique(), } } ``` `ent/schema/card.go` ```go // Edges of the Card. func (Card) Edges() []ent.Edge { return []ent.Edge{ edge.From("owner", User.Type). Ref("card"). Unique(). // We add the "Required" method to the builder // to make this edge required on entity creation. // i.e. Card cannot be created without its owner. 
Required(), } } ``` The API for interacting with these edges is as follows: ```go func Do(ctx context.Context, client *ent.Client) error { a8m, err := client.User. Create(). SetAge(30). SetName("Mashraki"). Save(ctx) if err != nil { return fmt.Errorf("creating user: %v", err) } log.Println("user:", a8m) card1, err := client.Card. Create(). SetOwner(a8m). SetNumber("1020"). SetExpired(time.Now().Add(time.Minute)). Save(ctx) if err != nil { return fmt.Errorf("creating card: %v", err) } log.Println("card:", card1) // Only returns the card of the user, // and expects that there's only one. card2, err := a8m.QueryCard().Only(ctx) if err != nil { return fmt.Errorf("querying card: %v", err) } log.Println("card:", card2) // The Card entity is able to query its owner using // its back-reference. owner, err := card2.QueryOwner().Only(ctx) if err != nil { return fmt.Errorf("querying owner: %v", err) } log.Println("owner:", owner) return nil } ``` The full example exists in [GitHub](https://github.com/facebook/ent/tree/master/examples/o2o2types). ## O2O Same Type ![er-linked-list](https://entgo.io/assets/er_linked_list.png) In this linked-list example, we have a **recursive relation** named `next`/`prev`. Each node in the list can **have only one** `next` node. If a node A points (using `next`) to node B, B can get its pointer using `prev` (the back-reference edge). `ent/schema/node.go` ```go // Edges of the Node. func (Node) Edges() []ent.Edge { return []ent.Edge{ edge.To("next", Node.Type). Unique(). From("prev"). Unique(), } } ``` As you can see, in cases of relations of the same type, you can declare the edge and its reference in the same builder. ```diff func (Node) Edges() []ent.Edge { return []ent.Edge{ + edge.To("next", Node.Type). + Unique(). + From("prev"). + Unique(), - edge.To("next", Node.Type). - Unique(), - edge.From("prev", Node.Type). - Ref("next). 
- Unique(), } } ``` The API for interacting with these edges is as follows: ```go func Do(ctx context.Context, client *ent.Client) error { head, err := client.Node. Create(). SetValue(1). Save(ctx) if err != nil { return fmt.Errorf("creating the head: %v", err) } curr := head // Generate the following linked-list: 1<->2<->3<->4<->5. for i := 0; i < 4; i++ { curr, err = client.Node. Create(). SetValue(curr.Value + 1). SetPrev(curr). Save(ctx) if err != nil { return err } } // Loop over the list and print it. `FirstX` panics if an error occur. for curr = head; curr != nil; curr = curr.QueryNext().FirstX(ctx) { fmt.Printf("%d ", curr.Value) } // Output: 1 2 3 4 5 // Make the linked-list circular: // The tail of the list, has no "next". tail, err := client.Node. Query(). Where(node.Not(node.HasNext())). Only(ctx) if err != nil { return fmt.Errorf("getting the tail of the list: %v", tail) } tail, err = tail.Update().SetNext(head).Save(ctx) if err != nil { return err } // Check that the change actually applied: prev, err := head.QueryPrev().Only(ctx) if err != nil { return fmt.Errorf("getting head's prev: %v", err) } fmt.Printf("\n%v", prev.Value == tail.Value) // Output: true return nil } ``` The full example exists in [GitHub](https://github.com/facebook/ent/tree/master/examples/o2orecur). ## O2O Bidirectional ![er-user-spouse](https://entgo.io/assets/er_user_spouse.png) In this user-spouse example, we have a **symmetric O2O relation** named `spouse`. Each user can **have only one** spouse. If user A sets its spouse (using `spouse`) to B, B can get its spouse using the `spouse` edge. Note that there are no owner/inverse terms in cases of bidirectional edges. `ent/schema/user.go` ```go // Edges of the User. func (User) Edges() []ent.Edge { return []ent.Edge{ edge.To("spouse", User.Type). Unique(), } } ``` The API for interacting with this edge is as follows: ```go func Do(ctx context.Context, client *ent.Client) error { a8m, err := client.User. Create(). SetAge(30). 
SetName("a8m"). Save(ctx) if err != nil { return fmt.Errorf("creating user: %v", err) } nati, err := client.User. Create(). SetAge(28). SetName("nati"). SetSpouse(a8m). Save(ctx) if err != nil { return fmt.Errorf("creating user: %v", err) } // Query the spouse edge. // Unlike `Only`, `OnlyX` panics if an error occurs. spouse := nati.QuerySpouse().OnlyX(ctx) fmt.Println(spouse.Name) // Output: a8m spouse = a8m.QuerySpouse().OnlyX(ctx) fmt.Println(spouse.Name) // Output: nati // Query how many users have a spouse. // Unlike `Count`, `CountX` panics if an error occurs. count := client.User. Query(). Where(user.HasSpouse()). CountX(ctx) fmt.Println(count) // Output: 2 // Get the user, that has a spouse with name="a8m". spouse = client.User. Query(). Where(user.HasSpouseWith(user.Name("a8m"))). OnlyX(ctx) fmt.Println(spouse.Name) // Output: nati return nil } ``` The full example exists in [GitHub](https://github.com/facebook/ent/tree/master/examples/o2obidi). ## O2M Two Types ![er-user-pets](https://entgo.io/assets/er_user_pets.png) In this user-pets example, we have a O2M relation between user and its pets. Each user **has many** pets, and a pet **has one** owner. If user A adds a pet B using the `pets` edge, B can get its owner using the `owner` edge (the back-reference edge). Note that this relation is also a M2O (many-to-one) from the point of view of the `Pet` schema. `ent/schema/user.go` ```go // Edges of the User. func (User) Edges() []ent.Edge { return []ent.Edge{ edge.To("pets", Pet.Type), } } ``` `ent/schema/pet.go` ```go // Edges of the Pet. func (Pet) Edges() []ent.Edge { return []ent.Edge{ edge.From("owner", User.Type). Ref("pets"). Unique(), } } ``` The API for interacting with these edges is as follows: ```go func Do(ctx context.Context, client *ent.Client) error { // Create the 2 pets. pedro, err := client.Pet. Create(). SetName("pedro"). Save(ctx) if err != nil { return fmt.Errorf("creating pet: %v", err) } lola, err := client.Pet. Create(). 
SetName("lola"). Save(ctx) if err != nil { return fmt.Errorf("creating pet: %v", err) } // Create the user, and add its pets on the creation. a8m, err := client.User. Create(). SetAge(30). SetName("a8m"). AddPets(pedro, lola). Save(ctx) if err != nil { return fmt.Errorf("creating user: %v", err) } fmt.Println("User created:", a8m) // Output: User(id=1, age=30, name=a8m) // Query the owner. Unlike `Only`, `OnlyX` panics if an error occurs. owner := pedro.QueryOwner().OnlyX(ctx) fmt.Println(owner.Name) // Output: a8m // Traverse the sub-graph. Unlike `Count`, `CountX` panics if an error occurs. count := pedro. QueryOwner(). // a8m QueryPets(). // pedro, lola CountX(ctx) // count fmt.Println(count) // Output: 2 return nil } ``` The full example exists in [GitHub](https://github.com/facebook/ent/tree/master/examples/o2m2types). ## O2M Same Type ![er-tree](https://entgo.io/assets/er_tree.png) In this example, we have a recursive O2M relation between tree's nodes and their children (or their parent). Each node in the tree **has many** children, and **has one** parent. If node A adds B to its children, B can get its owner using the `owner` edge. `ent/schema/node.go` ```go // Edges of the Node. func (Node) Edges() []ent.Edge { return []ent.Edge{ edge.To("children", Node.Type). From("parent"). Unique(), } } ``` As you can see, in cases of relations of the same type, you can declare the edge and its reference in the same builder. ```diff func (Node) Edges() []ent.Edge { return []ent.Edge{ + edge.To("children", Node.Type). + From("parent"). + Unique(), - edge.To("children", Node.Type), - edge.From("parent", Node.Type). - Ref("children"). - Unique(), } } ``` The API for interacting with these edges is as follows: ```go func Do(ctx context.Context, client *ent.Client) error { root, err := client.Node. Create(). SetValue(2). 
Save(ctx) if err != nil { return fmt.Errorf("creating the root: %v", err) } // Add additional nodes to the tree: // // 2 // / \ // 1 4 // / \ // 3 5 // // Unlike `Save`, `SaveX` panics if an error occurs. n1 := client.Node. Create(). SetValue(1). SetParent(root). SaveX(ctx) n4 := client.Node. Create(). SetValue(4). SetParent(root). SaveX(ctx) n3 := client.Node. Create(). SetValue(3). SetParent(n4). SaveX(ctx) n5 := client.Node. Create(). SetValue(5). SetParent(n4). SaveX(ctx) fmt.Println("Tree leafs", []int{n1.Value, n3.Value, n5.Value}) // Output: Tree leafs [1 3 5] // Get all leafs (nodes without children). // Unlike `Int`, `IntX` panics if an error occurs. ints := client.Node. Query(). // All nodes. Where(node.Not(node.HasChildren())). // Only leafs. Order(ent.Asc(node.FieldValue)). // Order by their `value` field. GroupBy(node.FieldValue). // Extract only the `value` field. IntsX(ctx) fmt.Println(ints) // Output: [1 3 5] // Get orphan nodes (nodes without parent). // Unlike `Only`, `OnlyX` panics if an error occurs. orphan := client.Node. Query(). Where(node.Not(node.HasParent())). OnlyX(ctx) fmt.Println(orphan) // Output: Node(id=1, value=2) return nil } ``` The full example exists in [GitHub](https://github.com/facebook/ent/tree/master/examples/o2mrecur). ## M2M Two Types ![er-user-groups](https://entgo.io/assets/er_user_groups.png) In this groups-users example, we have a M2M relation between groups and their users. Each group **has many** users, and each user can be joined to **many** groups. `ent/schema/group.go` ```go // Edges of the Group. func (Group) Edges() []ent.Edge { return []ent.Edge{ edge.To("users", User.Type), } } ``` `ent/schema/user.go` ```go // Edges of the User. func (User) Edges() []ent.Edge { return []ent.Edge{ edge.From("groups", Group.Type). Ref("users"), } } ``` The API for interacting with these edges is as follows: ```go func Do(ctx context.Context, client *ent.Client) error { // Unlike `Save`, `SaveX` panics if an error occurs. 
hub := client.Group. Create(). SetName("GitHub"). SaveX(ctx) lab := client.Group. Create(). SetName("GitLab"). SaveX(ctx) a8m := client.User. Create(). SetAge(30). SetName("a8m"). AddGroups(hub, lab). SaveX(ctx) nati := client.User. Create(). SetAge(28). SetName("nati"). AddGroups(hub). SaveX(ctx) // Query the edges. groups, err := a8m. QueryGroups(). All(ctx) if err != nil { return fmt.Errorf("querying a8m groups: %v", err) } fmt.Println(groups) // Output: [Group(id=1, name=GitHub) Group(id=2, name=GitLab)] groups, err = nati. QueryGroups(). All(ctx) if err != nil { return fmt.Errorf("querying nati groups: %v", err) } fmt.Println(groups) // Output: [Group(id=1, name=GitHub)] // Traverse the graph. users, err := a8m. QueryGroups(). // [hub, lab] Where(group.Not(group.HasUsersWith(user.Name("nati")))). // [lab] QueryUsers(). // [a8m] QueryGroups(). // [hub, lab] QueryUsers(). // [a8m, nati] All(ctx) if err != nil { return fmt.Errorf("traversing the graph: %v", err) } fmt.Println(users) // Output: [User(id=1, age=30, name=a8m) User(id=2, age=28, name=nati)] return nil } ``` The full example exists in [GitHub](https://github.com/facebook/ent/tree/master/examples/m2m2types). ## M2M Same Type ![er-following-followers](https://entgo.io/assets/er_following_followers.png) In this following-followers example, we have a M2M relation between users to their followers. Each user can follow **many** users, and can have **many** followers. `ent/schema/user.go` ```go // Edges of the User. func (User) Edges() []ent.Edge { return []ent.Edge{ edge.To("following", User.Type). From("followers"), } } ``` As you can see, in cases of relations of the same type, you can declare the edge and its reference in the same builder. ```diff func (User) Edges() []ent.Edge { return []ent.Edge{ + edge.To("following", User.Type). + From("followers"), - edge.To("following", User.Type), - edge.From("followers", User.Type). 
- Ref("following"), } } ``` The API for interacting with these edges is as follows: ```go func Do(ctx context.Context, client *ent.Client) error { // Unlike `Save`, `SaveX` panics if an error occurs. a8m := client.User. Create(). SetAge(30). SetName("a8m"). SaveX(ctx) nati := client.User. Create(). SetAge(28). SetName("nati"). AddFollowers(a8m). SaveX(ctx) // Query following/followers: flw := a8m.QueryFollowing().AllX(ctx) fmt.Println(flw) // Output: [User(id=2, age=28, name=nati)] flr := a8m.QueryFollowers().AllX(ctx) fmt.Println(flr) // Output: [] flw = nati.QueryFollowing().AllX(ctx) fmt.Println(flw) // Output: [] flr = nati.QueryFollowers().AllX(ctx) fmt.Println(flr) // Output: [User(id=1, age=30, name=a8m)] // Traverse the graph: ages := nati. QueryFollowers(). // [a8m] QueryFollowing(). // [nati] GroupBy(user.FieldAge). // [28] IntsX(ctx) fmt.Println(ages) // Output: [28] names := client.User. Query(). Where(user.Not(user.HasFollowers())). GroupBy(user.FieldName). StringsX(ctx) fmt.Println(names) // Output: [a8m] return nil } ``` The full example exists in [GitHub](https://github.com/facebook/ent/tree/master/examples/m2mrecur). ## M2M Bidirectional ![er-user-friends](https://entgo.io/assets/er_user_friends.png) In this user-friends example, we have a **symmetric M2M relation** named `friends`. Each user can **have many** friends. If user A becomes a friend of B, B is also a friend of A. Note that there are no owner/inverse terms in cases of bidirectional edges. `ent/schema/user.go` ```go // Edges of the User. func (User) Edges() []ent.Edge { return []ent.Edge{ edge.To("friends", User.Type), } } ``` The API for interacting with these edges is as follows: ```go func Do(ctx context.Context, client *ent.Client) error { // Unlike `Save`, `SaveX` panics if an error occurs. a8m := client.User. Create(). SetAge(30). SetName("a8m"). SaveX(ctx) nati := client.User. Create(). SetAge(28). SetName("nati"). AddFriends(a8m). SaveX(ctx) // Query friends. 
Unlike `All`, `AllX` panics if an error occurs. friends := nati. QueryFriends(). AllX(ctx) fmt.Println(friends) // Output: [User(id=1, age=30, name=a8m)] friends = a8m. QueryFriends(). AllX(ctx) fmt.Println(friends) // Output: [User(id=2, age=28, name=nati)] // Query the graph: friends = client.User. Query(). Where(user.HasFriends()). AllX(ctx) fmt.Println(friends) // Output: [User(id=1, age=30, name=a8m) User(id=2, age=28, name=nati)] return nil } ``` The full example exists in [GitHub](https://github.com/facebook/ent/tree/master/examples/m2mbidi). ## Required Edges can be defined as required in the entity creation using the `Required` method on the builder. ```go // Edges of the Card. func (Card) Edges() []ent.Edge { return []ent.Edge{ edge.From("owner", User.Type). Ref("card"). Unique(). Required(), } } ``` In the example above, a card entity cannot be created without its owner. ## StorageKey Custom storage configuration can be provided for edges using the `StorageKey` method. ```go // Edges of the User. func (User) Edges() []ent.Edge { return []ent.Edge{ edge.To("pets", Pet.Type). // Set the column name in the "pets" table for O2M relationship. StorageKey(edge.Column("owner_id")), edge.To("friends", User.Type). // Set the join-table and the column names for M2M relationship. StorageKey(edge.Table("friends"), edge.Columns("user_id", "friend_id")), } } ``` ## Indexes Indexes can be defined on multiple fields and some types of edges as well. However, you should note, that this is currently an SQL-only feature. Read more about this in the [Indexes](schema-indexes.md) section. ## Annotations `Annotations` is used to attach arbitrary metadata to the edge object in code generation. Template extensions can retrieve this metadata and use it inside their templates. Note that the metadata object must be serializable to a JSON raw value (e.g. struct, map or slice). ```go // Pet schema. type Pet struct { ent.Schema } // Edges of the Pet. 
func (Pet) Edges() []ent.Edge { return []ent.Edge{ edge.From("owner", User.Type). Ref("pets"). Unique(). Annotations(entgql.Annotation{ OrderField: "OWNER", }), } }
```go package schema import ( "time" "net/url" "github.com/google/uuid" "github.com/facebook/ent" "github.com/facebook/ent/schema/field" ) // User schema. type User struct { ent.Schema } // Fields of the user. func (User) Fields() []ent.Field { return []ent.Field{ field.Int("age"). Positive(), field.Float("rank"). Optional(), field.Bool("active"). Default(false), field.String("name"). Unique(), field.Time("created_at"). Default(time.Now), field.JSON("url", &url.URL{}). Optional(), field.JSON("strings", []string{}). Optional(), field.Enum("state"). Values("on", "off"). Optional(), field.UUID("uuid", uuid.UUID{}). Default(uuid.New), } } ``` To read more about how each type is mapped to its database-type, go to the [Migration](migrate.md) section. ## ID Field The `id` field is builtin in the schema and does not need declaration. In SQL-based databases, its type defaults to `int` (but can be changed with a [codegen option](code-gen.md#code-generation-options)) and auto-incremented in the database. In order to configure the `id` field to be unique across all tables, use the [WithGlobalUniqueID](migrate.md#universal-ids) option when running schema migration. If a different configuration for the `id` field is needed, or the `id` value should be provided on entity creation by the application (e.g. UUID), override the builtin `id` configuration. For example: ```go // Fields of the Group. func (Group) Fields() []ent.Field { return []ent.Field{ field.Int("id"). StructTag(`json:"oid,omitempty"`), } } // Fields of the Blob. func (Blob) Fields() []ent.Field { return []ent.Field{ field.UUID("id", uuid.UUID{}). Default(uuid.New), } } // Fields of the Pet. func (Pet) Fields() []ent.Field { return []ent.Field{ field.String("id"). MaxLen(25). NotEmpty(). Unique(). Immutable(), } } ``` ## Database Type Each database dialect has its own mapping from Go type to database type. For example, the MySQL dialect creates `float64` fields as `double` columns in the database. 
However, there is an option to override the default behavior using the `SchemaType` method. ```go package schema import ( "github.com/facebook/ent" "github.com/facebook/ent/dialect" "github.com/facebook/ent/schema/field" ) // Card schema. type Card struct { ent.Schema } // Fields of the Card. func (Card) Fields() []ent.Field { return []ent.Field{ field.Float("amount"). SchemaType(map[string]string{ dialect.MySQL: "decimal(6,2)", // Override MySQL. dialect.Postgres: "numeric", // Override Postgres. }), } } ``` ## Go Type The default type for fields are the basic Go types. For example, for string fields, the type is `string`, and for time fields, the type is `time.Time`. The `GoType` method provides an option to override the default ent type with a custom one. The custom type must be either a type that is convertible to the Go basic type, or a type that implements the [ValueScanner](https://pkg.go.dev/github.com/facebook/ent/schema/field?tab=doc#ValueScanner) interface. ```go package schema import ( "database/sql" "github.com/facebook/ent" "github.com/facebook/ent/dialect" "github.com/facebook/ent/schema/field" ) // Amount is a custom Go type that's convertible to the basic float64 type. type Amount float64 // Card schema. type Card struct { ent.Schema } // Fields of the Card. func (Card) Fields() []ent.Field { return []ent.Field{ field.Float("amount"). GoType(Amount(0)), field.String("name"). Optional(). // A ValueScanner type. GoType(&sql.NullString{}), field.Enum("role"). // A convertible type to string. GoType(role.Unknown), } } ``` ## Default Values **Non-unique** fields support default values using the `Default` and `UpdateDefault` methods. ```go // Fields of the User. func (User) Fields() []ent.Field { return []ent.Field{ field.Time("created_at"). Default(time.Now), field.Time("updated_at"). Default(time.Now). UpdateDefault(time.Now), field.String("name"). Default("unknown"), field.String("cuid"). 
DefaultFunc(cuid.New), } } ``` ## Validators A field validator is a function from type `func(T) error` that is defined in the schema using the `Validate` method, and applied on the field value before creating or updating the entity. The supported types of field validators are `string` and all numeric types. ```go package schema import ( "errors" "regexp" "strings" "time" "github.com/facebook/ent" "github.com/facebook/ent/schema/field" ) // Group schema. type Group struct { ent.Schema } // Fields of the group. func (Group) Fields() []ent.Field { return []ent.Field{ field.String("name"). Match(regexp.MustCompile("[a-zA-Z_]+$")). Validate(func(s string) error { if strings.ToLower(s) == s { return errors.New("group name must begin with uppercase") } return nil }), } } ``` Here is another example for writing a reusable validator: ```go // MaxRuneCount validates the rune length of a string by using the unicode/utf8 package. func MaxRuneCount(maxLen int) func(s string) error { return func(s string) error { if utf8.RuneCountInString(s) > maxLen { return errors.New("value is more than the max length") } return nil } } field.String("name"). Validate(MaxRuneCount(10)) field.String("nickname"). Validate(MaxRuneCount(20)) ``` ## Built-in Validators The framework provides a few built-in validators for each type: - Numeric types: - `Positive()` - `Negative()` - `NonNegative()` - `Min(i)` - Validate that the given value is > i. - `Max(i)` - Validate that the given value is < i. - `Range(i, j)` - Validate that the given value is within the range [i, j]. - `string` - `MinLen(i)` - `MaxLen(i)` - `Match(regexp.Regexp)` - `NotEmpty` ## Optional Optional fields are fields that are not required in the entity creation, and will be set to nullable fields in the database. Unlike edges, **fields are required by default**, and setting them to optional should be done explicitly using the `Optional` method. ```go // Fields of the user. 
func (User) Fields() []ent.Field { return []ent.Field{ field.String("required_name"), field.String("optional_name"). Optional(), } } ``` ## Nillable Sometimes you want to be able to distinguish between the zero value of fields and `nil`; for example if the database column contains `0` or `NULL`. The `Nillable` option exists exactly for this. If you have an `Optional` field of type `T`, setting it to `Nillable` will generate a struct field with type `*T`. Hence, if the database returns `NULL` for this field, the struct field will be `nil`. Otherwise, it will contains a pointer to the actual data. For example, given this schema: ```go // Fields of the user. func (User) Fields() []ent.Field { return []ent.Field{ field.String("required_name"), field.String("optional_name"). Optional(), field.String("nillable_name"). Optional(). Nillable(), } } ``` The generated struct for the `User` entity will be as follows: ```go // ent/user.go package ent // User entity. type User struct { RequiredName string `json:"required_name,omitempty"` OptionalName string `json:"optional_name,omitempty"` NillableName *string `json:"nillable_name,omitempty"` } ``` ## Immutable Immutable fields are fields that can be set only in the creation of the entity. i.e., no setters will be generated for the entity updater. ```go // Fields of the user. func (User) Fields() []ent.Field { return []ent.Field{ field.String("name"), field.Time("created_at"). Default(time.Now). Immutable(), } } ``` ## Uniqueness Fields can be defined as unique using the `Unique` method. Note that unique fields cannot have default values. ```go // Fields of the user. func (User) Fields() []ent.Field { return []ent.Field{ field.String("name"), field.String("nickname"). Unique(), } } ``` ## Storage Key Custom storage name can be configured using the `StorageKey` method. It's mapped to a column name in SQL dialects and to property name in Gremlin. ```go // Fields of the user. 
func (User) Fields() []ent.Field { return []ent.Field{ field.String("name"). StorageKey("old_name"), } } ``` ## Indexes Indexes can be defined on multi fields and some types of edges as well. However, you should note, that this is currently an SQL-only feature. Read more about this in the [Indexes](schema-indexes.md) section. ## Struct Tags Custom struct tags can be added to the generated entities using the `StructTag` method. Note that if this option was not provided, or provided and did not contain the `json` tag, the default `json` tag will be added with the field name. ```go // Fields of the user. func (User) Fields() []ent.Field { return []ent.Field{ field.String("name"). StructTag(`gqlgen:"gql_name"`), } } ``` ## Additional Struct Fields By default, `ent` generates the entity model with fields that are configured in the `schema.Fields` method. For example, given this schema configuration: ```go // User schema. type User struct { ent.Schema } // Fields of the user. func (User) Fields() []ent.Field { return []ent.Field{ field.Int("age"). Optional(). Nillable(), field.String("name"). StructTag(`gqlgen:"gql_name"`), } } ``` The generated model will be as follows: ```go // User is the model entity for the User schema. type User struct { // Age holds the value of the "age" field. Age *int `json:"age,omitempty"` // Name holds the value of the "name" field. Name string `json:"name,omitempty" gqlgen:"gql_name"` } ``` In order to add additional fields to the generated struct **that are not stored in the database**, use [external templates](code-gen.md/#external-templates). For example: ```gotemplate {{ define "model/fields/additional" }} {{- if eq $.Name "User" }} // StaticField defined by template. StaticField string `json:"static,omitempty"` {{- end }} {{ end }} ``` The generated model will be as follows: ```go // User is the model entity for the User schema. type User struct { // Age holds the value of the "age" field. 
Age *int `json:"age,omitempty"` // Name holds the value of the "name" field. Name string `json:"name,omitempty" gqlgen:"gql_name"` // StaticField defined by template. StaticField string `json:"static,omitempty"` } ``` ## Sensitive Fields String fields can be defined as sensitive using the `Sensitive` method. Sensitive fields won't be printed and they will be omitted when encoding. Note that sensitive fields cannot have struct tags. ```go // User schema. type User struct { ent.Schema } // Fields of the user. func (User) Fields() []ent.Field { return []ent.Field{ field.String("password"). Sensitive(), } } ``` ## Annotations `Annotations` is used to attach arbitrary metadata to the field object in code generation. Template extensions can retrieve this metadata and use it inside their templates. Note that the metadata object must be serializable to a JSON raw value (e.g. struct, map or slice). ```go // User schema. type User struct { ent.Schema } // Fields of the user. func (User) Fields() []ent.Field { return []ent.Field{ field.Time("creation_date"). Annotations(entgql.Annotation{ OrderField: "CREATED_AT", }), } } ``` Read more about annotations and their usage in templates in the [template doc](templates.md#annotations). ent-0.5.4/doc/md/schema-indexes.md000077500000000000000000000056661377533537200167310ustar00rootroot00000000000000--- id: schema-indexes title: Indexes --- ## Multiple Fields Indexes can be configured on one or more fields in order to improve speed of data retrieval, or defining uniqueness. ```go package schema import ( "github.com/facebook/ent" "github.com/facebook/ent/schema/index" ) // User holds the schema definition for the User entity. type User struct { ent.Schema } func (User) Indexes() []ent.Index { return []ent.Index{ // non-unique index. index.Fields("field1", "field2"), // unique index. index.Fields("first_name", "last_name"). 
Unique(), } } ``` Note that for setting a single field as unique, use the `Unique` method on the field builder as follows: ```go func (User) Fields() []ent.Field { return []ent.Field{ field.String("phone"). Unique(), } } ``` ## Index On Edges Indexes can be configured on composition of fields and edges. The main use-case is setting uniqueness on fields under a specific relation. Let's take an example: ![er-city-streets](https://entgo.io/assets/er_city_streets.png) In the example above, we have a `City` with many `Street`s, and we want to set the street name to be unique under each city. `ent/schema/city.go` ```go // City holds the schema definition for the City entity. type City struct { ent.Schema } // Fields of the City. func (City) Fields() []ent.Field { return []ent.Field{ field.String("name"), } } // Edges of the City. func (City) Edges() []ent.Edge { return []ent.Edge{ edge.To("streets", Street.Type), } } ``` `ent/schema/street.go` ```go // Street holds the schema definition for the Street entity. type Street struct { ent.Schema } // Fields of the Street. func (Street) Fields() []ent.Field { return []ent.Field{ field.String("name"), } } // Edges of the Street. func (Street) Edges() []ent.Edge { return []ent.Edge{ edge.From("city", City.Type). Ref("streets"). Unique(), } } // Indexes of the Street. func (Street) Indexes() []ent.Index { return []ent.Index{ index.Fields("name"). Edges("city"). Unique(), } } ``` `example.go` ```go func Do(ctx context.Context, client *ent.Client) error { // Unlike `Save`, `SaveX` panics if an error occurs. tlv := client.City. Create(). SetName("TLV"). SaveX(ctx) nyc := client.City. Create(). SetName("NYC"). SaveX(ctx) // Add a street "ST" to "TLV". client.Street. Create(). SetName("ST"). SetCity(tlv). SaveX(ctx) // This operation will fail because "ST" // is already created under "TLV". _, err := client.Street. Create(). SetName("ST"). SetCity(tlv). 
Save(ctx) if err == nil { return fmt.Errorf("expecting creation to fail") } // Add a street "ST" to "NYC". client.Street. Create(). SetName("ST"). SetCity(nyc). SaveX(ctx) return nil } ``` The full example exists in [GitHub](https://github.com/facebook/ent/tree/master/examples/edgeindex). ## Dialect Support Indexes currently support only SQL dialects, and do not support Gremlin. ent-0.5.4/doc/md/schema-mixin.md000077500000000000000000000062301377533537200164020ustar00rootroot00000000000000--- id: schema-mixin title: Mixin --- A `Mixin` allows you to create reusable pieces of `ent.Schema` code. The `ent.Mixin` interface is as follows: ```go type Mixin interface { // Fields returns a slice of fields to add to the schema. Fields() []Field // Edges returns a slice of edges to add to the schema. Edges() []Edge // Indexes returns a slice of indexes to add to the schema. Indexes() []Index // Hooks returns a slice of hooks to add to the schema. // Note that mixin hooks are executed before schema hooks. Hooks() []Hook // Policy returns a privacy policy to add to the schema. // Note that mixin policy are executed before schema policy. Policy() Policy // Annotations returns a list of schema annotations to add // to the schema annotations. Annotations() []schema.Annotation } ``` ## Example A common use case for `Mixin` is to mix-in a list of common fields to your schema. ```go package schema import ( "time" "github.com/facebook/ent" "github.com/facebook/ent/schema/field" "github.com/facebook/ent/schema/mixin" ) // ------------------------------------------------- // Mixin definition // TimeMixin implements the ent.Mixin for sharing // time fields with package schemas. type TimeMixin struct{ // We embed the `mixin.Schema` to avoid // implementing the rest of the methods. mixin.Schema } func (TimeMixin) Fields() []ent.Field { return []ent.Field{ field.Time("created_at"). Immutable(). Default(time.Now), field.Time("updated_at"). Default(time.Now). 
UpdateDefault(time.Now), } } // DetailsMixin implements the ent.Mixin for sharing // entity details fields with package schemas. type DetailsMixin struct{ // We embed the `mixin.Schema` to avoid // implementing the rest of the methods. mixin.Schema } func (DetailsMixin) Fields() []ent.Field { return []ent.Field{ field.Int("age"). Positive(), field.String("name"). NotEmpty(), } } // ------------------------------------------------- // Schema definition // User schema mixed-in the TimeMixin and DetailsMixin fields and therefore // has 5 fields: `created_at`, `updated_at`, `age`, `name` and `nickname`. type User struct { ent.Schema } func (User) Mixin() []ent.Mixin { return []ent.Mixin{ TimeMixin{}, DetailsMixin{}, } } func (User) Fields() []ent.Field { return []ent.Field{ field.String("nickname"). Unique(), } } // Pet schema mixed-in the DetailsMixin fields and therefore // has 3 fields: `age`, `name` and `weight`. type Pet struct { ent.Schema } func (Pet) Mixin() []ent.Mixin { return []ent.Mixin{ DetailsMixin{}, } } func (Pet) Fields() []ent.Field { return []ent.Field{ field.Float("weight"), } } ``` ## Builtin Mixin Package `mixin` provides a few builtin mixins that can be used for adding the `create_time` and `update_time` fields to the schema. In order to use them, add the `mixin.Time` mixin to your schema as follows: ```go package schema import ( "github.com/facebook/ent" "github.com/facebook/ent/schema/mixin" ) type Pet struct { ent.Schema } func (Pet) Mixin() []ent.Mixin { return []ent.Mixin{ mixin.Time{}, // Or, mixin.CreateTime only for create_time // and mixin.UpdateTime only for update_time. } } ``` ent-0.5.4/doc/md/sql-integration.md000066400000000000000000000053751377533537200171460ustar00rootroot00000000000000--- id: sql-integration title: sql.DB Integration --- The following examples show how to pass a custom `sql.DB` object to `ent.Client`. 
## Configure `sql.DB` First option: ```go package main import ( "time" "/ent" "github.com/facebook/ent/dialect/sql" ) func Open() (*ent.Client, error) { drv, err := sql.Open("mysql", "") if err != nil { return nil, err } // Get the underlying sql.DB object of the driver. db := drv.DB() db.SetMaxIdleConns(10) db.SetMaxOpenConns(100) db.SetConnMaxLifetime(time.Hour) return ent.NewClient(ent.Driver(drv)), nil } ``` Second option: ```go package main import ( "database/sql" "time" "/ent" entsql "github.com/facebook/ent/dialect/sql" ) func Open() (*ent.Client, error) { db, err := sql.Open("mysql", "") if err != nil { return nil, err } db.SetMaxIdleConns(10) db.SetMaxOpenConns(100) db.SetConnMaxLifetime(time.Hour) // Create an ent.Driver from `db`. drv := entsql.OpenDB("mysql", db) return ent.NewClient(ent.Driver(drv)), nil } ``` ## Use Opencensus With MySQL ```go package main import ( "context" "database/sql" "database/sql/driver" "/ent" "contrib.go.opencensus.io/integrations/ocsql" "github.com/go-sql-driver/mysql" entsql "github.com/facebook/ent/dialect/sql" ) type connector struct { dsn string } func (c connector) Connect(context.Context) (driver.Conn, error) { return c.Driver().Open(c.dsn) } func (connector) Driver() driver.Driver { return ocsql.Wrap( mysql.MySQLDriver{}, ocsql.WithAllTraceOptions(), ocsql.WithRowsClose(false), ocsql.WithRowsNext(false), ocsql.WithDisableErrSkip(true), ) } // Open new connection and start stats recorder. func Open(dsn string) *ent.Client { db := sql.OpenDB(connector{dsn}) // Create an ent.Driver from `db`. 
drv := entsql.OpenDB("mysql", db) return ent.NewClient(ent.Driver(drv)) } ``` ## Use pgx with PostgreSQL ```go package main import ( "context" "database/sql" "log" "/ent" "github.com/facebook/ent/dialect" entsql "github.com/facebook/ent/dialect/sql" _ "github.com/jackc/pgx/v4/stdlib" ) // Open new connection func Open(databaseUrl string) *ent.Client { db, err := sql.Open("pgx", databaseUrl) if err != nil { log.Fatal(err) } // Create an ent.Driver from `db`. drv := entsql.OpenDB(dialect.Postgres, db) return ent.NewClient(ent.Driver(drv)) } func main() { client := Open("postgresql://user:password@127.0.0.1/database") // Your code. For example: ctx := context.Background() if err := client.Schema.Create(ctx); err != nil { log.Fatal(err) } users, err := client.User.Query().All(ctx) if err != nil { log.Fatal(err) } log.Println(users) } ``` ent-0.5.4/doc/md/templates.md000066400000000000000000000077631377533537200160270ustar00rootroot00000000000000--- id: templates title: External Templates --- `ent` accepts external [Go templates](https://golang.org/pkg/text/template) to execute using the `--template` flag. If the template name already defined by `ent`, it will override the existing one. Otherwise, it will write the execution output to a file with the same name as the template. For example: `stringer.tmpl` - This template example will be written in a file named: `ent/stringer.go`. ```gotemplate {{ define "stringer" }} {{/* Add the base header for the generated file */}} {{ $pkg := base $.Config.Package }} {{ template "header" $ }} {{/* Loop over all nodes and add implement the "GoStringer" interface */}} {{ range $n := $.Nodes }} {{ $receiver := $n.Receiver }} func ({{ $receiver }} *{{ $n.Name }}) GoString() string { if {{ $receiver }} == nil { return fmt.Sprintf("{{ $n.Name }}(nil)") } return {{ $receiver }}.String() } {{ end }} {{ end }} ``` `debug.tmpl` - This template example will be written in a file named: `ent/debug.go`. 
```gotemplate {{ define "debug" }} {{/* A template that adds the functionality for running each client in debug mode */}} {{/* Add the base header for the generated file */}} {{ $pkg := base $.Config.Package }} {{ template "header" $ }} {{/* Loop over all nodes and add option the "Debug" method */}} {{ range $n := $.Nodes }} {{ $client := print $n.Name "Client" }} func (c *{{ $client }}) Debug() *{{ $client }} { if c.debug { return c } cfg := config{driver: dialect.Debug(c.driver, c.log), log: c.log, debug: true, hooks: c.hooks} return &{{ $client }}{config: cfg} } {{ end }} {{ end }} ``` In order to override an existing template, use its name. For example: ```gotemplate {{/* A template for adding additional fields to specific types. */}} {{ define "model/fields/additional" }} {{- /* Add static fields to the "Card" entity. */}} {{- if eq $.Name "Card" }} // StaticField defined by templates. StaticField string `json:"static_field,omitempty"` {{- end }} {{ end }} ``` ## Annotations Schema annotations allow attaching metadata to fields and edges and inject them to external templates. An annotation must be a Go type that is serializable to JSON raw value (e.g. struct, map or slice) and implement the [Annotation](https://pkg.go.dev/github.com/facebook/ent/schema/field?tab=doc#Annotation) interface. Here's an example of an annotation and its usage in schema and template: 1\. An annotation definition: ```go package entgql // Annotation annotates fields with metadata for templates. type Annotation struct { // OrderField is the ordering field as defined in graphql schema. OrderField string } // Name implements ent.Annotation interface. func (Annotation) Name() string { return "EntGQL" } ``` 2\. Annotation usage in ent/schema: ```go // User schema. type User struct { ent.Schema } // Fields of the user. func (User) Fields() []ent.Field { return []ent.Field{ field.Time("creation_date"). Annotations(entgql.Annotation{ OrderField: "CREATED_AT", }), } } ``` 3\. 
Annotation usage in external template: ```gotemplate {{ range $node := $.Nodes }} {{ range $f := $node.Fields }} {{/* Get the annotation by its name. See: Annotation.Name */}} {{ if $annotation := $f.Annotations.EntGQL }} {{/* Get the field from the annotation. */}} {{ $orderField := $annotation.OrderField }} {{ end }} {{ end }} {{ end }} ``` ## Examples - A custom template for implementing the `Node` API for GraphQL - [Github](https://github.com/facebook/ent/blob/master/entc/integration/template/ent/template/node.tmpl). - An example for executing external templates with custom functions. See [configuration](https://github.com/facebook/ent/blob/master/examples/entcpkg/ent/entc.go) and its [README](https://github.com/facebook/ent/blob/master/examples/entcpkg) file. ## Documentation Templates are executed on either a specific node-type or the entire schema graph. For API documentation, see the
GoDoc. ent-0.5.4/doc/md/testing.md000066400000000000000000000014601377533537200154720ustar00rootroot00000000000000--- id: testing title: Testing --- If you're using `ent.Client` in your unit-tests, you can use the generated `enttest` package for creating a client and auto-running the schema migration as follows: ```go package main import ( "testing" "/ent/enttest" _ "github.com/mattn/go-sqlite3" ) func TestXXX(t *testing.T) { client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1") defer client.Close() // ... } ``` In order to pass functional options to `Open`, use `enttest.Option`: ```go func TestXXX(t *testing.T) { opts := []enttest.Option{ enttest.WithOptions(ent.Log(t.Log)), enttest.WithMigrateOptions(migrate.WithGlobalUniqueID(true)), } client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1", opts...) defer client.Close() // ... } ``` ent-0.5.4/doc/md/transactions.md000077500000000000000000000100501377533537200165230ustar00rootroot00000000000000--- id: transactions title: Transactions --- ## Starting A Transaction ```go // GenTx generates group of entities in a transaction. func GenTx(ctx context.Context, client *ent.Client) error { tx, err := client.Tx(ctx) if err != nil { return fmt.Errorf("starting a transaction: %v", err) } hub, err := tx.Group. Create(). SetName("Github"). Save(ctx) if err != nil { return rollback(tx, fmt.Errorf("failed creating the group: %v", err)) } // Create the admin of the group. dan, err := tx.User. Create(). SetAge(29). SetName("Dan"). AddManage(hub). Save(ctx) if err != nil { return rollback(tx, err) } // Create user "Ariel". a8m, err := tx.User. Create(). SetAge(30). SetName("Ariel"). AddGroups(hub). AddFriends(dan). Save(ctx) if err != nil { return rollback(tx, err) } fmt.Println(a8m) // Output: // User(id=2, age=30, name=Ariel) // Commit the transaction. return tx.Commit() } // rollback calls to tx.Rollback and wraps the given error // with the rollback error if occurred. 
func rollback(tx *ent.Tx, err error) error { if rerr := tx.Rollback(); rerr != nil { err = fmt.Errorf("%v: %v", err, rerr) } return err } ``` The full example exists in [GitHub](https://github.com/facebook/ent/tree/master/examples/traversal). ## Transactional Client Sometimes, you have an existing code that already works with `*ent.Client`, and you want to change it (or wrap it) to interact with transactions. For these use cases, you have a transactional client. An `*ent.Client` that you can get from an existing transaction. ```go // WrapGen wraps the existing "Gen" function in a transaction. func WrapGen(ctx context.Context, client *ent.Client) error { tx, err := client.Tx(ctx) if err != nil { return err } txClient := tx.Client() // Use the "Gen" below, but give it the transactional client; no code changes to "Gen". if err := Gen(ctx, txClient); err != nil { return rollback(tx, err) } return tx.Commit() } // Gen generates a group of entities. func Gen(ctx context.Context, client *ent.Client) error { // ... return nil } ``` The full example exists in [GitHub](https://github.com/facebook/ent/tree/master/examples/traversal). ## Best Practices Reusable function that runs callbacks in a transaction: ```go func WithTx(ctx context.Context, client *ent.Client, fn func(tx *ent.Tx) error) error { tx, err := client.Tx(ctx) if err != nil { return err } defer func() { if v := recover(); v != nil { tx.Rollback() panic(v) } }() if err := fn(tx); err != nil { if rerr := tx.Rollback(); rerr != nil { err = errors.Wrapf(err, "rolling back transaction: %v", rerr) } return err } if err := tx.Commit(); err != nil { return errors.Wrapf(err, "committing transaction: %v", err) } return nil } ``` Its usage: ```go func Do(ctx context.Context, client *ent.Client) { // WithTx helper. 
if err := WithTx(ctx, client, func(tx *ent.Tx) error { return Gen(ctx, tx.Client()) }); err != nil { log.Fatal(err) } } ``` ## Hooks Same as [schema hooks](hooks.md#schema-hooks) and [runtime hooks](hooks.md#runtime-hooks), hooks can be registered on active transactions, and will be executed on `Tx.Commit` or `Tx.Rollback`: ```go func Do(ctx context.Context, client *ent.Client) error { tx, err := client.Tx(ctx) if err != nil { return err } // Add a hook on Tx.Commit. tx.OnCommit(func(next ent.Committer) ent.Committer { return ent.CommitFunc(func(ctx context.Context, tx *ent.Tx) error { // Code before the actual commit. err := next.Commit(ctx, tx) // Code after the transaction was committed. return err }) }) // Add a hook on Tx.Rollback. tx.OnRollback(func(next ent.Rollbacker) ent.Rollbacker { return ent.RollbackFunc(func(ctx context.Context, tx *ent.Tx) error { // Code before the actual rollback. err := next.Rollback(ctx, tx) // Code after the transaction was rolled back. return err }) }) // // // return err } ``` ent-0.5.4/doc/md/traversals.md000077500000000000000000000112151377533537200162050ustar00rootroot00000000000000--- id: traversals title: Graph Traversal --- For the purpose of the example, we'll generate the following graph: ![er-traversal-graph](https://entgo.io/assets/er_traversal_graph.png) The first step is to generate the 3 schemas: `Pet`, `User`, `Group`. ```console go run github.com/facebook/ent/cmd/ent init Pet User Group ``` Add the necessary fields and edges for the schemas: `ent/schema/pet.go` ```go // Pet holds the schema definition for the Pet entity. type Pet struct { ent.Schema } // Fields of the Pet. func (Pet) Fields() []ent.Field { return []ent.Field{ field.String("name"), } } // Edges of the Pet. func (Pet) Edges() []ent.Edge { return []ent.Edge{ edge.To("friends", Pet.Type), edge.From("owner", User.Type). Ref("pets"). Unique(), } } ``` `ent/schema/user.go` ```go // User holds the schema definition for the User entity. 
type User struct { ent.Schema } // Fields of the User. func (User) Fields() []ent.Field { return []ent.Field{ field.Int("age"), field.String("name"), } } // Edges of the User. func (User) Edges() []ent.Edge { return []ent.Edge{ edge.To("pets", Pet.Type), edge.To("friends", User.Type), edge.From("groups", Group.Type). Ref("users"), edge.From("manage", Group.Type). Ref("admin"), } } ``` `ent/schema/group.go` ```go // Group holds the schema definition for the Group entity. type Group struct { ent.Schema } // Fields of the Group. func (Group) Fields() []ent.Field { return []ent.Field{ field.String("name"), } } // Edges of the Group. func (Group) Edges() []ent.Edge { return []ent.Edge{ edge.To("users", User.Type), edge.To("admin", User.Type). Unique(), } } ``` Let's write the code for populating the vertices and the edges to the graph: ```go func Gen(ctx context.Context, client *ent.Client) error { hub, err := client.Group. Create(). SetName("Github"). Save(ctx) if err != nil { return fmt.Errorf("failed creating the group: %v", err) } // Create the admin of the group. // Unlike `Save`, `SaveX` panics if an error occurs. dan := client.User. Create(). SetAge(29). SetName("Dan"). AddManage(hub). SaveX(ctx) // Create "Ariel" and its pets. a8m := client.User. Create(). SetAge(30). SetName("Ariel"). AddGroups(hub). AddFriends(dan). SaveX(ctx) pedro := client.Pet. Create(). SetName("Pedro"). SetOwner(a8m). SaveX(ctx) xabi := client.Pet. Create(). SetName("Xabi"). SetOwner(a8m). SaveX(ctx) // Create "Alex" and its pets. alex := client.User. Create(). SetAge(37). SetName("Alex"). SaveX(ctx) coco := client.Pet. Create(). SetName("Coco"). SetOwner(alex). AddFriends(pedro). 
SaveX(ctx) fmt.Println("Pets created:", pedro, xabi, coco) // Output: // Pets created: Pet(id=1, name=Pedro) Pet(id=2, name=Xabi) Pet(id=3, name=Coco) return nil } ``` Let's go over a few traversals, and show the code for them: ![er-traversal-graph-gopher](https://entgo.io/assets/er_traversal_graph_gopher.png) The traversal above starts from a `Group` entity, continues to its `admin` (edge), continues to its `friends` (edge), gets their `pets` (edge), gets each pet's `friends` (edge), and requests their owners. ```go func Traverse(ctx context.Context, client *ent.Client) error { owner, err := client.Group. // GroupClient. Query(). // Query builder. Where(group.Name("Github")). // Filter only Github group (only 1). QueryAdmin(). // Getting Dan. QueryFriends(). // Getting Dan's friends: [Ariel]. QueryPets(). // Their pets: [Pedro, Xabi]. QueryFriends(). // Pedro's friends: [Coco], Xabi's friends: []. QueryOwner(). // Coco's owner: Alex. Only(ctx) // Expect only one entity to return in the query. if err != nil { return fmt.Errorf("failed querying the owner: %v", err) } fmt.Println(owner) // Output: // User(id=3, age=37, name=Alex) return nil } ``` What about the following traversal? ![er-traversal-graph-gopher-query](https://entgo.io/assets/er_traversal_graph_gopher_query.png) We want to get all pets (entities) that have an `owner` (`edge`) that is a `friend` (edge) of some group `admin` (edge). ```go func Traverse(ctx context.Context, client *ent.Client) error { pets, err := client.Pet. Query(). Where( pet.HasOwnerWith( user.HasFriendsWith( user.HasManage(), ), ), ). All(ctx) if err != nil { return fmt.Errorf("failed querying the pets: %v", err) } fmt.Println(pets) // Output: // [Pet(id=1, name=Pedro) Pet(id=2, name=Xabi)] return nil } ``` The full example exists in [GitHub](https://github.com/facebook/ent/tree/master/examples/traversal). 
ent-0.5.4/doc/website/000077500000000000000000000000001377533537200145345ustar00rootroot00000000000000ent-0.5.4/doc/website/README.md000077500000000000000000000077411377533537200160270ustar00rootroot00000000000000This website was created with [Docusaurus](https://docusaurus.io/). # What's In This Document * [Get Started in 5 Minutes](#get-started-in-5-minutes) * [Directory Structure](#directory-structure) * [Editing Content](#editing-content) * [Adding Content](#adding-content) * [Full Documentation](#full-documentation) # Get Started in 5 Minutes 1. Make sure all the dependencies for the website are installed: ```sh # Install dependencies $ yarn ``` 2. Run your dev server: ```sh # Start the site $ yarn start ``` ## Directory Structure Your project file structure should look something like this ``` my-docusaurus/ docs/ doc-1.md doc-2.md doc-3.md website/ blog/ 2016-3-11-oldest-post.md 2017-10-24-newest-post.md core/ node_modules/ pages/ static/ css/ img/ package.json sidebar.json siteConfig.js ``` # Editing Content ## Editing an existing docs page Edit docs by navigating to `docs/` and editing the corresponding document: `docs/doc-to-be-edited.md` ```markdown --- id: page-needs-edit title: This Doc Needs To Be Edited --- Edit me... ``` For more information about docs, click [here](https://docusaurus.io/docs/en/navigation) ## Editing an existing blog post Edit blog posts by navigating to `website/blog` and editing the corresponding post: `website/blog/post-to-be-edited.md` ```markdown --- id: post-needs-edit title: This Blog Post Needs To Be Edited --- Edit me... ``` For more information about blog posts, click [here](https://docusaurus.io/docs/en/adding-blog) # Adding Content ## Adding a new docs page to an existing sidebar 1. Create the doc as a new markdown file in `/docs`, example `docs/newly-created-doc.md`: ```md --- id: newly-created-doc title: This Doc Needs To Be Edited --- My new content here.. ``` 1. 
Refer to that doc's ID in an existing sidebar in `website/sidebar.json`: ```javascript // Add newly-created-doc to the Getting Started category of docs { "docs": { "Getting Started": [ "quick-start", "newly-created-doc" // new doc here ], ... }, ... } ``` For more information about adding new docs, click [here](https://docusaurus.io/docs/en/navigation) ## Adding a new blog post 1. Make sure there is a header link to your blog in `website/siteConfig.js`: `website/siteConfig.js` ```javascript headerLinks: [ ... { blog: true, label: 'Blog' }, ... ] ``` 2. Create the blog post with the format `YYYY-MM-DD-My-Blog-Post-Title.md` in `website/blog`: `website/blog/2018-05-21-New-Blog-Post.md` ```markdown --- author: Frank Li authorURL: https://twitter.com/foobarbaz authorFBID: 503283835 title: New Blog Post --- Lorem Ipsum... ``` For more information about blog posts, click [here](https://docusaurus.io/docs/en/adding-blog) ## Adding items to your site's top navigation bar 1. Add links to docs, custom pages or external links by editing the headerLinks field of `website/siteConfig.js`: `website/siteConfig.js` ```javascript { headerLinks: [ ... /* you can add docs */ { doc: 'my-examples', label: 'Examples' }, /* you can add custom pages */ { page: 'help', label: 'Help' }, /* you can add external links */ { href: 'https://github.com/facebook/Docusaurus', label: 'GitHub' }, ... ], ... } ``` For more information about the navigation bar, click [here](https://docusaurus.io/docs/en/navigation) ## Adding custom pages 1. Docusaurus uses React components to build pages. The components are saved as .js files in `website/pages/en`: 1. If you want your page to show up in your navigation header, you will need to update `website/siteConfig.js` to add to the `headerLinks` element: `website/siteConfig.js` ```javascript { headerLinks: [ ... { page: 'my-new-custom-page', label: 'My New Custom Page' }, ... ], ... 
} ``` For more information about custom pages, click [here](https://docusaurus.io/docs/en/custom-pages). # Full Documentation Full documentation can be found on the [website](https://docusaurus.io/). ent-0.5.4/doc/website/blog/000077500000000000000000000000001377533537200154575ustar00rootroot00000000000000ent-0.5.4/doc/website/blog/2019-10-03-introducing-ent.md000066400000000000000000000055301377533537200221440ustar00rootroot00000000000000--- title: Introducing ent author: Ariel Mashraki authorURL: https://github.com/a8m authorImageURL: https://avatars0.githubusercontent.com/u/7413593 authorTwitter: arielmashraki --- ## The state of Go in Facebook Connectivity Tel Aviv 20 months ago, I joined Facebook Connectivity (FBC) team in Tel Aviv after ~5 years of programming in Go and embedding it in a few companies. I joined a team that was working on a new project and we needed to choose a language for this mission. We compared a few languages and decided to go with Go. Since then, Go continued to spread across other FBC projects and became a big success with around 15 Go engineers in Tel Aviv alone. **New services are now written in Go**. ## The motivation for writing a new ORM in Go Most of my work in my 5 years before Facebook was on infra tooling and micro-services without too much data-model work. A service that was needed to do a little amount of work with an SQL database used one of the existing open-source solutions, but one that had worked with a complicated data model was written in a different language with a robust ORM. For example, Python with SQLAlchemy. At Facebook we like to think about our data-model in graph concepts. We've had a good experience with this model internally. The lack of a proper Graph-based ORM for Go, led us to write one here with the following principles: - **Schema As Code** - defining types, relations and constraints should be in Go code (not struct tags), and should be validated using a CLI tool. 
We have good experience with a similar tool internally at Facebook. - **Statically typed and explicit API** using codegen - API with `interface{}`s everywhere affects developers efficiency; especially project newbies. - **Queries, aggregations and graph traversals** should be simple - developers don’t want to deal with raw SQL queries nor SQL terms. - **Predicates should be statically typed**. No strings everywhere. - Full support for `context.Context` - This helps us to get full visibility in our traces and logs systems, and it’s important for other features like cancellation. - **Storage agnostic** - we tried to keep the storage layer dynamic using codegen templates, since the development initially started on Gremlin (AWS Neptune) and switched later to MySQL. ## Open-sourcing ent **ent** is an entity framework (ORM) for Go, built with the principles described above. **ent** makes it possible to define any data model or graph-structure in Go code easily; The schema configuration is verified by **entc** (the ent codegen) that generates an idiomatic and statically-typed API that keeps Go developers productive and happy. It supports MySQL, SQLite (mainly for testing) and Gremlin. PostgreSQL will be added soon. We’re open-sourcing **ent** today, and invite you to get started → [entgo.io/docs/getting-started](/docs/getting-started). ent-0.5.4/doc/website/core/000077500000000000000000000000001377533537200154645ustar00rootroot00000000000000ent-0.5.4/doc/website/core/Footer.js000077500000000000000000000054501377533537200172670ustar00rootroot00000000000000/** * Copyright (c) 2017-present, Facebook, Inc. * * This source code is licensed under the MIT license found in the * LICENSE file in the root directory of this source tree. * * @format */ const React = require('react'); class Footer extends React.Component { docUrl(doc, language) { const baseUrl = this.props.config.baseUrl; const docsUrl = this.props.config.docsUrl; const docsPart = `${docsUrl ? 
`${docsUrl}/` : ''}`; const langPart = `${language ? `${language}/` : ''}`; return `${baseUrl}${docsPart}${langPart}${doc}`; } pageUrl(doc, language) { const baseUrl = this.props.config.baseUrl; return baseUrl + (language ? `${language}/` : '') + doc; } render() { return (

); } } module.exports = Footer; ent-0.5.4/doc/website/package.json000066400000000000000000000005701377533537200170240ustar00rootroot00000000000000{ "scripts": { "examples": "docusaurus-examples", "start": "docusaurus-start", "build": "docusaurus-build", "publish-gh-pages": "docusaurus-publish", "write-translations": "docusaurus-write-translations", "version": "docusaurus-version", "rename-version": "docusaurus-rename-version" }, "devDependencies": { "docusaurus": "^1.14.0" } } ent-0.5.4/doc/website/pages/000077500000000000000000000000001377533537200156335ustar00rootroot00000000000000ent-0.5.4/doc/website/pages/en/000077500000000000000000000000001377533537200162355ustar00rootroot00000000000000ent-0.5.4/doc/website/pages/en/index.js000077500000000000000000000113211377533537200177030ustar00rootroot00000000000000/** * Copyright (c) 2017-present, Facebook, Inc. * * This source code is licensed under the MIT license found in the * LICENSE file in the root directory of this source tree. * * @format */ const React = require('react'); const CompLibrary = require('../../core/CompLibrary.js'); const MarkdownBlock = CompLibrary.MarkdownBlock; /* Used to read markdown */ const Container = CompLibrary.Container; const GridBlock = CompLibrary.GridBlock; const arrow = '\u2192'; const Block = props => ( ); const Features = () => (
); class HomeSplash extends React.Component { render() { const {siteConfig, language = ''} = this.props; const {baseUrl, docsUrl} = siteConfig; const docsPart = `${docsUrl ? `${docsUrl}/` : ''}`; const langPart = `${language ? `${language}/` : ''}`; const docUrl = doc => `${baseUrl}${docsPart}${langPart}${doc}`; const SplashContainer = props => (
{props.children}
); const Logo = props => (
Project Logo
); const ProjectTitle = () => (

{siteConfig.tagline}

Simple, yet powerful ORM for modeling and querying data.

); const PromoSection = props => (
{props.children}
); const Button = props => ( ); return ( {/**/} ); } } class Index extends React.Component { render() { const {config: siteConfig, language = ''} = this.props; const {baseUrl} = siteConfig; const Showcase = () => { if ((siteConfig.users || []).length === 0) { return null; } const showcase = siteConfig.users .filter(user => user.pinned) .map(user => ( {user.caption} )); const pageUrl = page => baseUrl + (language ? `${language}/` : '') + page; return (

Who is Using This?

This project is used by all these people

{showcase}
); }; return (
); } } module.exports = Index; ent-0.5.4/doc/website/sidebars.json000077500000000000000000000012001377533537200172170ustar00rootroot00000000000000{ "md": { "Getting Started": [ "getting-started" ], "Schema": [ "schema-def", "schema-fields", "schema-edges", "schema-indexes", "schema-mixin", "schema-annotations" ], "Code Generation": [ "code-gen", "crud", "traversals", "eager-load", "hooks", "privacy", "transactions", "predicates", "aggregate", "paging" ], "Migration": [ "migrate", "dialects" ], "Misc": [ "templates", "graphql", "sql-integration", "testing", "faq", "feature-flags" ] } } ent-0.5.4/doc/website/siteConfig.js000066400000000000000000000110271377533537200171650ustar00rootroot00000000000000/** * Copyright (c) 2017-present, Facebook, Inc. * * This source code is licensed under the MIT license found in the * LICENSE file in the root directory of this source tree. * * @format */ // See https://docusaurus.io/docs/site-config for all the possible // site configuration options. // List of projects/orgs using your project for the users page. const users = [ { caption: 'User1', // You will need to prepend the image path with your baseUrl // if it is not '/', like: '/test-site/img/image.jpg'. image: '/img/undraw_open_source.svg', infoLink: 'https://www.facebook.com', pinned: true, }, ]; const siteConfig = { title: 'ent', // Title for your website. 
tagline: 'An entity framework for Go', url: 'https://entgo.io', // Your website URL baseUrl: '/', // Base URL for your project */ // Used for publishing and more projectName: 'ent', organizationName: 'facebook', customDocsPath: 'md', // For no header links in the top nav bar -> headerLinks: [], headerLinks: [ {doc: 'getting-started', label: 'Docs'}, {href: 'https://pkg.go.dev/github.com/facebook/ent?tab=doc', label: 'GoDoc'}, {href: 'https://github.com/facebook/ent', label: 'Github'}, { blog: true, label: 'Blog' }, ], // If you have users set above, you add it here: users, /* path to images for header/footer */ headerIcon: 'img/logo.png', favicon: 'img/favicon.ico', /* Colors for website */ colors: { primaryColor: '#85daff', secondaryColor: '#4d8eaa', }, /* Custom fonts for website */ /* fonts: { myFont: [ "Times New Roman", "Serif" ], myOtherFont: [ "-apple-system", "system-ui" ] }, */ // This copyright info is used in /core/Footer.js and blog RSS/Atom feeds. copyright: `Copyright ${new Date().getFullYear()} Facebook Inc.`, highlight: { // Highlight.js theme to use for syntax highlighting in code blocks. 
theme: 'androidstudio', hljs: function(hljs) { hljs.registerLanguage('gotemplate', function(hljs) { var GO_KEYWORDS = { keyword: 'break default func interface select case map struct chan else goto package switch ' + 'const fallthrough if range type continue for import return var go defer ' + 'bool byte complex64 complex128 float32 float64 int8 int16 int32 int64 string uint8 ' + 'uint16 uint32 uint64 int uint uintptr rune with define block end', literal: 'true false iota nil', built_in: 'append cap close complex copy imag len make new panic print println real recover delete' + 'printf fail slice dict list' }; return { name: 'GoTemplate', aliases: ['gotmpl'], keywords: GO_KEYWORDS, contains: [ hljs.COMMENT('{{-* */\\*', '\\*/ *-*}}'), hljs.C_LINE_COMMENT_MODE, { className: 'string', variants: [ hljs.QUOTE_STRING_MODE, hljs.APOS_STRING_MODE, {begin: '`', end: '`'}, ] }, { begin: /:=/ }, ] }; }); } }, // Add custom scripts here that would be placed in