ent-0.11.3 source tree (git commit 5330f877593201a1538884944f1783b104c6fbc4)

=== ent-0.11.3/.all-contributorsrc ===

{
  "files": ["doc/md/contributors.md"],
  "imageSize": 100,
  "commit": false,
  "contributors": [
    { "login": "a8m", "name": "Ariel Mashraki", "avatar_url": "https://avatars.githubusercontent.com/u/7413593?v=4", "profile": "https://github.com/a8m", "contributions": ["maintenance", "doc", "code"] },
    { "login": "alexsn", "name": "Alex Snast", "avatar_url": "https://avatars.githubusercontent.com/u/987019?v=4", "profile": "https://github.com/alexsn", "contributions": ["code"] },
    { "login": "rotemtam", "name": "Rotem Tamir", "avatar_url": "https://avatars.githubusercontent.com/u/1522681?v=4", "profile": "https://rotemtam.com/", "contributions": ["maintenance", "doc", "code"] },
    { "login": "cliedeman", "name": "Ciaran Liedeman", "avatar_url": "https://avatars.githubusercontent.com/u/3578740?v=4", "profile": "https://github.com/cliedeman", "contributions": ["code"] },
    { "login": "marwan-at-work", "name": "Marwan Sulaiman", "avatar_url": "https://avatars.githubusercontent.com/u/16294261?v=4", "profile": "https://www.marwan.io/", "contributions": ["code"] },
    { "login": "napei", "name": "Nathaniel Peiffer", "avatar_url": "https://avatars.githubusercontent.com/u/8946502?v=4", "profile": "https://nathaniel.peiffer.com.au/", "contributions": ["code"] },
    { "login": "tmc", "name": "Travis Cline", "avatar_url": "https://avatars.githubusercontent.com/u/3977?v=4", "profile": "https://github.com/tmc", "contributions": ["code"] },
    { "login": "hantmac", "name": "Jeremy", "avatar_url": "https://avatars.githubusercontent.com/u/7600925?v=4", "profile": "https://cloudsjhan.github.io/", "contributions": ["code"] },
    { "login": "aca", "name": "aca", "avatar_url": "https://avatars.githubusercontent.com/u/50316549?v=4", "profile": "https://github.com/aca", "contributions": ["code"] },
    { "login": "BrentChesny", "name": "BrentChesny", "avatar_url": "https://avatars.githubusercontent.com/u/1449435?v=4", "profile": "https://github.com/BrentChesny", "contributions": ["code", "doc"] },
    { "login": "giautm", "name": "Giau. Tran Minh", "avatar_url": "https://avatars.githubusercontent.com/u/12751435?v=4", "profile": "https://github.com/giautm", "contributions": ["code", "review"] },
    { "login": "htdvisser", "name": "Hylke Visser", "avatar_url": "https://avatars.githubusercontent.com/u/181308?v=4", "profile": "https://htdvisser.dev/", "contributions": ["code"] },
    { "login": "kerbelp", "name": "Pavel Kerbel", "avatar_url": "https://avatars.githubusercontent.com/u/3934990?v=4", "profile": "https://github.com/kerbelp", "contributions": ["code"] },
    { "login": "day-dreams", "name": "zhangnan", "avatar_url": "https://avatars.githubusercontent.com/u/24593904?v=4", "profile": "https://github.com/day-dreams", "contributions": ["code"] },
    { "login": "uta-mori", "name": "mori yuta", "avatar_url": "https://avatars.githubusercontent.com/u/59682979?v=4", "profile": "https://github.com/uta-mori", "contributions": ["code", "translation", "review"] },
    { "login": "chris-rock", "name": "Christoph Hartmann", "avatar_url": "https://avatars.githubusercontent.com/u/1178413?v=4", "profile": "http://lollyrock.com/", "contributions": ["code"] },
    { "login": "rubensayshi", "name": "Ruben de Vries", "avatar_url": "https://avatars.githubusercontent.com/u/649160?v=4", "profile": "https://github.com/rubensayshi", "contributions": ["code"] },
    { "login": "ernado", "name": "Aleksandr Razumov", "avatar_url": "https://avatars.githubusercontent.com/u/866677?v=4", "profile": "https://keybase.io/ernado", "contributions": ["code"] },
    { "login": "apbuteau", "name": "apbuteau", "avatar_url": "https://avatars.githubusercontent.com/u/6796073?v=4", "profile": "https://github.com/apbuteau", "contributions": ["code"] },
    { "login": "ichord", "name": "Harold.Luo", "avatar_url": "https://avatars.githubusercontent.com/u/1324791?v=4", "profile": "https://github.com/ichord", "contributions": ["code"] },
    { "login": "idoshveki", "name": "ido shveki", "avatar_url": "https://avatars.githubusercontent.com/u/11615669?v=4", "profile": "https://github.com/idoshveki", "contributions": ["code"] },
    { "login": "masseelch", "name": "MasseElch", "avatar_url": "https://avatars.githubusercontent.com/u/12862103?v=4", "profile": "https://github.com/masseelch", "contributions": ["code"] },
    { "login": "kidlj", "name": "Jian Li", "avatar_url": "https://avatars.githubusercontent.com/u/300616?v=4", "profile": "https://github.com/kidlj", "contributions": ["code"] },
    { "login": "nolotz", "name": "Noah-Jerome Lotzer", "avatar_url": "https://avatars.githubusercontent.com/u/5778728?v=4", "profile": "https://noah.je/", "contributions": ["code"] },
    { "login": "danf0rth", "name": "danforth", "avatar_url": "https://avatars.githubusercontent.com/u/14220891?v=4", "profile": "https://github.com/danf0rth", "contributions": ["code"] },
    { "login": "maxiloEmmmm", "name": "maxilozoz", "avatar_url": "https://avatars.githubusercontent.com/u/16779121?v=4", "profile": "https://github.com/maxiloEmmmm", "contributions": ["code"] },
    { "login": "zzwx", "name": "zzwx", "avatar_url": "https://avatars.githubusercontent.com/u/8169082?v=4", "profile": "https://gist.github.com/zzwx", "contributions": ["code"] },
    { "login": "ix64", "name": "MengYX", "avatar_url": "https://avatars.githubusercontent.com/u/13902388?v=4", "profile": "https://github.com/ix64", "contributions": ["translation"] },
    { "login": "mattn", "name": "mattn", "avatar_url": "https://avatars.githubusercontent.com/u/10111?v=4", "profile": "https://mattn.kaoriya.net/", "contributions": ["translation"] },
    { "login": "Bladrak", "name": "Hugo Briand", "avatar_url": "https://avatars.githubusercontent.com/u/1321977?v=4", "profile": "https://github.com/Bladrak", "contributions": ["code"] },
    { "login": "enmand", "name": "Dan Enman", "avatar_url": "https://avatars.githubusercontent.com/u/432487?v=4", "profile": "https://danielenman.com/", "contributions": ["code"] },
    { "login": "UnAfraid", "name": "Rumen Nikiforov", "avatar_url": "https://avatars.githubusercontent.com/u/2185291?v=4", "profile": "http://www.l2junity.org/", "contributions": ["code"] },
    { "login": "wenerme", "name": "陈杨文", "avatar_url": "https://avatars.githubusercontent.com/u/1777211?v=4", "profile": "https://wener.me", "contributions": ["code"] },
    { "login": "joesonw", "name": "Qiaosen (Joeson) Huang", "avatar_url": "https://avatars.githubusercontent.com/u/1635441?v=4", "profile": "https://djwong.net", "contributions": ["bug"] },
    { "login": "davebehr1", "name": "AlonDavidBehr", "avatar_url": "https://avatars.githubusercontent.com/u/16716239?v=4", "profile": "https://github.com/davebehr1", "contributions": ["code", "review"] },
    { "login": "DuGlaser", "name": "DuGlaser", "avatar_url": "https://avatars.githubusercontent.com/u/50506482?v=4", "profile": "http://duglaser.dev", "contributions": ["doc"] },
    { "login": "shanna", "name": "Shane Hanna", "avatar_url": "https://avatars.githubusercontent.com/u/28489?v=4", "profile": "https://github.com/shanna", "contributions": ["doc"] },
    { "login": "mahmud2011", "name": "Mahmudul Haque", "avatar_url": "https://avatars.githubusercontent.com/u/5278142?v=4", "profile": "https://www.linkedin.com/in/mahmud2011", "contributions": ["code"] },
    { "login": "sywesk", "name": "Benjamin Bourgeais", "avatar_url": "https://avatars.githubusercontent.com/u/862607?v=4", "profile": "http://blog.scaleprocess.net", "contributions": ["code"] },
    { "login": "8ayac", "name": "8ayac(Yoshinori Hayashi)", "avatar_url": "https://avatars.githubusercontent.com/u/29266382?v=4", "profile": "https://about.8ay.ac/", "contributions": ["doc"] },
    { "login": "y-yagi", "name": "y-yagi", "avatar_url": "https://avatars.githubusercontent.com/u/987638?v=4", "profile": "https://github.com/y-yagi", "contributions": ["doc"] },
    { "login": "Sacro", "name": "Ben Woodward", "avatar_url": "https://avatars.githubusercontent.com/u/2659869?v=4", "profile": "https://github.com/Sacro", "contributions": ["code"] },
    { "login": "wzyjerry", "name": "WzyJerry", "avatar_url": "https://avatars.githubusercontent.com/u/11435169?v=4", "profile": "https://github.com/wzyjerry", "contributions": ["code"] },
    { "login": "tarrencev", "name": "Tarrence van As", "avatar_url": "https://avatars.githubusercontent.com/u/4740651?v=4", "profile": "https://github.com/tarrencev", "contributions": ["doc", "code"] },
    { "login": "MONAKA0721", "name": "Yuya Sumie", "avatar_url": "https://avatars.githubusercontent.com/u/32859963?v=4", "profile": "https://mo7ka.com", "contributions": ["doc"] },
    { "login": "akfaew", "name": "Michal Mazurek", "avatar_url": "https://avatars.githubusercontent.com/u/7853732?v=4", "profile": "http://jasminek.net", "contributions": ["code"] },
    { "login": "nmemoto", "name": "Takafumi Umemoto", "avatar_url": "https://avatars.githubusercontent.com/u/1522332?v=4", "profile": "https://github.com/nmemoto", "contributions": ["doc"] },
    { "login": "squarebat", "name": "Khadija Sidhpuri", "avatar_url": "https://avatars.githubusercontent.com/u/59063821?v=4", "profile": "http://www.linkedin.com/in/khadija-sidhpuri-87709316a", "contributions": ["code"] },
    { "login": "neel229", "name": "Neel Modi", "avatar_url": "https://avatars.githubusercontent.com/u/53475167?v=4", "profile": "https://github.com/neel229", "contributions": ["code"] },
    { "login": "shomodj", "name": "Boris Shomodjvarac", "avatar_url": "https://avatars.githubusercontent.com/u/304768?v=4", "profile": "https://ie.linkedin.com/in/boris-shomodjvarac-51970879", "contributions": ["doc"] },
    { "login": "sadmansakib", "name": "Sadman Sakib", "avatar_url": "https://avatars.githubusercontent.com/u/17023844?v=4", "profile": "https://github.com/sadmansakib", "contributions": ["doc"] },
    { "login": "dakimura", "name": "dakimura", "avatar_url": "https://avatars.githubusercontent.com/u/34202807?v=4", "profile": "https://github.com/dakimura", "contributions": ["code"] },
    { "login": "RiskyFeryansyahP", "name": "Risky Feryansyah", "avatar_url": "https://avatars.githubusercontent.com/u/36788585?v=4", "profile": "https://github.com/RiskyFeryansyahP", "contributions": ["code"] },
    { "login": "seiichi1101", "name": "seiichi ", "avatar_url": "https://avatars.githubusercontent.com/u/20941952?v=4", "profile": "https://github.com/seiichi1101", "contributions": ["code"] },
    { "login": "odeke-em", "name": "Emmanuel T Odeke", "avatar_url": "https://avatars.githubusercontent.com/u/4898263?v=4", "profile": "https://orijtech.com/", "contributions": ["code"] },
    { "login": "isoppp", "name": "Hiroki Isogai", "avatar_url": "https://avatars.githubusercontent.com/u/16318727?v=4", "profile": "https://isoppp.com", "contributions": ["doc"] },
    { "login": "tsingsun", "name": "李清山", "avatar_url": "https://avatars.githubusercontent.com/u/5848549?v=4", "profile": "https://github.com/tsingsun", "contributions": ["code"] },
    { "login": "s-takehana", "name": "s-takehana", "avatar_url": "https://avatars.githubusercontent.com/u/3423547?v=4", "profile": "https://github.com/s-takehana", "contributions": ["doc"] },
    { "login": "EndlessIdea", "name": "Kuiba", "avatar_url": "https://avatars.githubusercontent.com/u/1527796?v=4", "profile": "https://github.com/EndlessIdea", "contributions": ["code"] },
    { "login": "storyicon", "name": "storyicon", "avatar_url": "https://avatars.githubusercontent.com/u/29772821?v=4", "profile": "https://github.com/storyicon", "contributions": ["code"] },
    { "login": "evanlurvey", "name": "Evan Lurvey", "avatar_url": "https://avatars.githubusercontent.com/u/54965655?v=4", "profile": "https://github.com/evanlurvey", "contributions": ["code"] },
    { "login": "attackordie", "name": "Brian", "avatar_url": "https://avatars.githubusercontent.com/u/20145334?v=4", "profile": "https://github.com/attackordie", "contributions": ["doc"] },
    { "login": "ThinkontrolSY", "name": "Shen Yang", "avatar_url": "https://avatars.githubusercontent.com/u/11331554?v=4", "profile": "http://www.thinkontrol.com", "contributions": ["code"] },
    { "login": "sivchari", "name": "sivchari", "avatar_url": "https://avatars.githubusercontent.com/u/55221074?v=4", "profile": "https://twitter.com/sivchari", "contributions": ["code"] },
    { "login": "mookjp", "name": "mook", "avatar_url": "https://avatars.githubusercontent.com/u/1519309?v=4", "profile": "https://blog.mookjp.io", "contributions": ["code"] },
    { "login": "heliumbrain", "name": "heliumbrain", "avatar_url": "https://avatars.githubusercontent.com/u/1607668?v=4", "profile": "http://www.entiros.se", "contributions": ["doc"] },
    { "login": "JeremyV2014", "name": "Jeremy Maxey-Vesperman", "avatar_url": "https://avatars.githubusercontent.com/u/9276415?v=4", "profile": "https://github.com/JeremyV2014", "contributions": ["code", "doc"] },
    { "login": "tankbusta", "name": "Christopher Schmitt", "avatar_url": "https://avatars.githubusercontent.com/u/592749?v=4", "profile": "https://github.com/tankbusta", "contributions": ["doc"] },
    { "login": "grevych", "name": "Gerardo Reyes", "avatar_url": "https://avatars.githubusercontent.com/u/3792003?v=4", "profile": "https://github.com/grevych", "contributions": ["code"] },
    { "login": "naormatania", "name": "Naor Matania", "avatar_url": "https://avatars.githubusercontent.com/u/6978437?v=4", "profile": "https://github.com/naormatania", "contributions": ["code"] },
    { "login": "idc77", "name": "idc77", "avatar_url": "https://avatars.githubusercontent.com/u/87644834?v=4", "profile": "https://github.com/idc77", "contributions": ["doc"] },
    { "login": "HurSungYun", "name": "Sungyun Hur", "avatar_url": "https://avatars.githubusercontent.com/u/8033896?v=4", "profile": "http://ethanhur.me", "contributions": ["doc"] },
    { "login": "peanut-cc", "name": "peanut-pg", "avatar_url": "https://avatars.githubusercontent.com/u/55480838?v=4", "profile": "https://github.com/peanut-cc", "contributions": ["doc"] },
    { "login": "m3hm3t", "name": "Mehmet Yılmaz", "avatar_url": "https://avatars.githubusercontent.com/u/22320354?v=4", "profile": "https://github.com/m3hm3t", "contributions": ["code"] },
    { "login": "Laconty", "name": "Roman Maklakov", "avatar_url": "https://avatars.githubusercontent.com/u/17760166?v=4", "profile": "https://github.com/Laconty", "contributions": ["code"] },
    { "login": "genevieve", "name": "Genevieve", "avatar_url": "https://avatars.githubusercontent.com/u/12158641?v=4", "profile": "https://github.com/genevieve", "contributions": ["code"] },
    { "login": "cjraa", "name": "Clarence", "avatar_url": "https://avatars.githubusercontent.com/u/62199269?v=4", "profile": "https://github.com/cjraa", "contributions": ["code"] },
    { "login": "iamnande", "name": "Nicholas Anderson", "avatar_url": "https://avatars.githubusercontent.com/u/7806510?v=4", "profile": "https://www.linkedin.com/in/iamnande/", "contributions": ["code"] },
    { "login": "hezhizhen", "name": "Zhizhen He", "avatar_url": "https://avatars.githubusercontent.com/u/7611700?v=4", "profile": "https://github.com/hezhizhen", "contributions": ["code"] },
    { "login": "crossworth", "name": "Pedro Henrique", "avatar_url": "https://avatars.githubusercontent.com/u/1251151?v=4", "profile": "https://pedro.dev.br", "contributions": ["code"] },
    { "login": "MrParano1d", "name": "MrParano1d", "avatar_url": "https://avatars.githubusercontent.com/u/7414374?v=4", "profile": "https://2jp.de", "contributions": ["code"] },
    { "login": "tprebs", "name": "Thomas Prebble", "avatar_url": "https://avatars.githubusercontent.com/u/6523587?v=4", "profile": "https://github.com/tprebs", "contributions": ["code"] },
    { "login": "imhuytq", "name": "Huy TQ", "avatar_url": "https://avatars.githubusercontent.com/u/5723282?v=4", "profile": "https://huytq.com", "contributions": ["code"] },
    { "login": "maorlipchuk", "name": "maorlipchuk", "avatar_url": "https://avatars.githubusercontent.com/u/7034637?v=4", "profile": "https://github.com/maorlipchuk", "contributions": ["code"] },
    { "login": "iwata", "name": "Motonori Iwata", "avatar_url": "https://avatars.githubusercontent.com/u/121048?v=4", "profile": "https://mobcov.hatenadiary.org/", "contributions": ["doc"] },
    { "login": "CharlesGe129", "name": "Charles Ge", "avatar_url": "https://avatars.githubusercontent.com/u/20162173?v=4", "profile": "https://github.com/CharlesGe129", "contributions": ["code"] },
    { "login": "thmeitz", "name": "Thomas Meitz", "avatar_url": "https://avatars.githubusercontent.com/u/92851940?v=4", "profile": "https://github.com/thmeitz", "contributions": ["code", "doc"] },
    { "login": "booleangate", "name": "Justin Johnson", "avatar_url": "https://avatars.githubusercontent.com/u/181567?v=4", "profile": "http://justinjohnson.org", "contributions": ["code"] },
    { "login": "hax10", "name": "hax10", "avatar_url": "https://avatars.githubusercontent.com/u/85743468?v=4", "profile": "https://github.com/hax10", "contributions": ["code"] },
    { "login": "water-a", "name": "water-a", "avatar_url": "https://avatars.githubusercontent.com/u/38114545?v=4", "profile": "https://github.com/water-a", "contributions": ["bug"] },
    { "login": "jhwz", "name": "jhwz", "avatar_url": "https://avatars.githubusercontent.com/u/52683873?v=4", "profile": "https://github.com/jhwz", "contributions": ["doc"] }
  ],
  "contributorsPerLine": 7,
  "projectName": "ent",
  "projectOwner": "ent",
  "repoType": "github",
  "repoHost": "https://github.com",
  "skipCi": true
}

=== ent-0.11.3/.bencher/config.yaml ===

suppress_failure_on_regression: true

=== ent-0.11.3/.github/ISSUE_TEMPLATE/1.bug.md ===

---
name: Bug report 🐛
about: Create a bug report.
labels: 'status: needs triage'
---

- [ ] The issue is present in the latest release.
- [ ] I have searched the [issues](https://github.com/ent/ent/issues) of this repository and believe that this is not a duplicate.

## Current Behavior 😯

## Expected Behavior 🤔

## Steps to Reproduce 🕹

Steps:

1.
2.
3.
4.

## Your Environment 🌎

| Tech     | Version                                |
| -------- | -------------------------------------- |
| Go       | 1.17.?                                 |
| Ent      | 0.9.?                                  |
| Database | MySQL                                  |
| Driver   | https://github.com/go-sql-driver/mysql |

=== ent-0.11.3/.github/ISSUE_TEMPLATE/2.feature.md ===

---
name: Feature request 🎉
about: Suggest a new idea for the project.
labels: 'status: needs triage'
---

- [ ] I have searched the [issues](https://github.com/ent/ent/issues) of this repository and believe that this is not a duplicate.
## Summary 💡

## Motivation 🔦

=== ent-0.11.3/.github/ISSUE_TEMPLATE/3.support.md ===

---
name: Question
about: General support
labels: 'status: needs triage'
---

=== ent-0.11.3/.github/ISSUE_TEMPLATE/config.yml ===

blank_issues_enabled: false # force the usage of a template
contact_links:
  - name: Something Else ❔
    url: https://gophers.slack.com/archives/C01FMSQDT53
    about: Come chat to us in the gophers slack

=== ent-0.11.3/.github/dependabot.yml ===

version: 2
updates:
  - package-ecosystem: github-actions
    directory: /
    schedule:
      interval: daily
  - package-ecosystem: gomod
    directory: /
    schedule:
      interval: daily

=== ent-0.11.3/.github/workflows/cd.yml ===

name: Continuous Deployment
on:
  push:
    branches:
      - master
    paths:
      - 'doc/**'
  schedule:
    - cron: "0 9 * * 0-5"
jobs:
  docs:
    name: docs
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
        with:
          fetch-depth: 0
      - uses: actions/setup-node@v3
        with:
          node-version: 16.14
      - name: Install Dependencies
        working-directory: doc/website
        run: yarn
      - name: Sync Translation
        working-directory: doc/website
        run: yarn crowdin:sync
        env:
          CROWDIN_TOKEN: ${{ secrets.CROWDIN_TOKEN }}
      - name: Build Docs
        working-directory: doc/website
        run: yarn build
      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@v1.7.0
        with:
          aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
          aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
          aws-region: eu-central-1
      - name: Deploy Docs
        working-directory: doc/website/build
        run: aws s3 sync . s3://entgoio --delete --exclude "images/*"
      - name: Invalidate Cache
        env:
          CDN_DISTRIBUTION_ID: ${{ secrets.CDN_DISTRIBUTION_ID }}
        run: aws cloudfront create-invalidation --distribution-id $CDN_DISTRIBUTION_ID --paths "/*" | jq -M "del(.Location)"

=== ent-0.11.3/.github/workflows/ci.yml ===

name: Continuous Integration
on:
  push:
    paths-ignore:
      - 'doc/**'
    tags-ignore:
      - '*.*'
  pull_request:
    paths-ignore:
      - 'doc/**'
jobs:
  lint:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - uses: actions/setup-go@v3
        with:
          go-version: 1.19
      - name: Run linters
        uses: golangci/golangci-lint-action@v3.2.0
        with:
          version: v1.48.0
  unit:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        go: ['1.18', '1.19']
    steps:
      - uses: actions/checkout@v3
      - uses: actions/setup-go@v3
        with:
          go-version: ${{ matrix.go }}
      - uses: actions/cache@v3
        with:
          path: ~/go/pkg/mod
          key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
          restore-keys: |
            ${{ runner.os }}-go-
      - name: Run dialect tests
        run: go test -race ./...
        working-directory: dialect
      - name: Run schema tests
        run: go test -race ./...
        working-directory: schema
      - name: Run loader tests
        run: go test -race ./...
        working-directory: entc/load
      - name: Run codegen tests
        run: go test -race ./...
        working-directory: entc/gen
      - name: Run example tests
        working-directory: examples
        run: go test -race ./...
  generate:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - uses: actions/setup-go@v3
        with:
          go-version: '1.19'
      - uses: actions/cache@v3
        with:
          path: ~/go/pkg/mod
          key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
          restore-keys: |
            ${{ runner.os }}-go-
      - name: Run go generate
        run: go generate ./...
      - name: Check generated files
        run: |
          status=$(git status --porcelain)
          if [ -n "$status" ]; then
            echo "you need to run 'go generate ./...' and commit the changes"
            echo "$status"
            exit 1
          fi
  integration:
    runs-on: ubuntu-latest
    services:
      mysql56:
        image: mysql:5.6.35
        env:
          MYSQL_DATABASE: test
          MYSQL_ROOT_PASSWORD: pass
        ports:
          - 3306:3306
        options: >-
          --health-cmd "mysqladmin ping -ppass"
          --health-interval 10s
          --health-start-period 10s
          --health-timeout 5s
          --health-retries 10
      mysql57:
        image: mysql:5.7.26
        env:
          MYSQL_DATABASE: test
          MYSQL_ROOT_PASSWORD: pass
        ports:
          - 3307:3306
        options: >-
          --health-cmd "mysqladmin ping -ppass"
          --health-interval 10s
          --health-start-period 10s
          --health-timeout 5s
          --health-retries 10
      mysql8:
        image: mysql:8
        env:
          MYSQL_DATABASE: test
          MYSQL_ROOT_PASSWORD: pass
        ports:
          - 3308:3306
        options: >-
          --health-cmd "mysqladmin ping -ppass"
          --health-interval 10s
          --health-start-period 10s
          --health-timeout 5s
          --health-retries 10
      maria:
        image: mariadb
        env:
          MYSQL_DATABASE: test
          MYSQL_ROOT_PASSWORD: pass
        ports:
          - 4306:3306
        options: >-
          --health-cmd "mysqladmin ping -ppass"
          --health-interval 10s
          --health-start-period 10s
          --health-timeout 5s
          --health-retries 10
      maria102:
        image: mariadb:10.2.32
        env:
          MYSQL_DATABASE: test
          MYSQL_ROOT_PASSWORD: pass
        ports:
          - 4307:3306
        options: >-
          --health-cmd "mysqladmin ping -ppass"
          --health-interval 10s
          --health-start-period 10s
          --health-timeout 5s
          --health-retries 10
      maria103:
        image: mariadb:10.3.13
        env:
          MYSQL_DATABASE: test
          MYSQL_ROOT_PASSWORD: pass
        ports:
          - 4308:3306
        options: >-
          --health-cmd "mysqladmin ping -ppass"
          --health-interval 10s
          --health-start-period 10s
          --health-timeout 5s
          --health-retries 10
      postgres10:
        image: postgres:10
        env:
          POSTGRES_DB: test
          POSTGRES_PASSWORD: pass
        ports:
          - 5430:5432
        options: >-
          --health-cmd pg_isready
          --health-interval 10s
          --health-timeout 5s
          --health-retries 5
      postgres11:
        image: postgres:11
        env:
          POSTGRES_DB: test
          POSTGRES_PASSWORD: pass
        ports:
          - 5431:5432
        options: >-
          --health-cmd pg_isready
          --health-interval 10s
          --health-timeout 5s
          --health-retries 5
      postgres12:
        image: postgres:12.3
        env:
          POSTGRES_DB: test
          POSTGRES_PASSWORD: pass
        ports:
          - 5432:5432
        options: >-
          --health-cmd pg_isready
          --health-interval 10s
          --health-timeout 5s
          --health-retries 5
      postgres13:
        image: postgres:13.1
        env:
          POSTGRES_DB: test
          POSTGRES_PASSWORD: pass
        ports:
          - 5433:5432
        options: >-
          --health-cmd pg_isready
          --health-interval 10s
          --health-timeout 5s
          --health-retries 5
      postgres14:
        image: postgres:14
        env:
          POSTGRES_DB: test
          POSTGRES_PASSWORD: pass
        ports:
          - 5434:5432
        options: >-
          --health-cmd pg_isready
          --health-interval 10s
          --health-timeout 5s
          --health-retries 5
      gremlin-server:
        image: entgo/gremlin-server
        ports:
          - 8182:8182
        options: >-
          --health-cmd "netstat -an | grep -q 8182"
          --health-interval 10s
          --health-timeout 5s
          --health-retries 5
    steps:
      - uses: actions/checkout@v3
      - uses: actions/setup-go@v3
        with:
          go-version: '1.18'
      - uses: actions/cache@v3
        with:
          path: ~/go/pkg/mod
          key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
          restore-keys: |
            ${{ runner.os }}-go-
      - name: Run integration tests
        working-directory: entc/integration
        run: go test -race -count=2 ./...
  migration:
    runs-on: ubuntu-latest
    if: ${{ github.ref != 'refs/heads/master' }}
    services:
      mysql56:
        image: mysql:5.6.35
        env:
          MYSQL_DATABASE: test
          MYSQL_ROOT_PASSWORD: pass
        ports:
          - 3306:3306
        options: >-
          --health-cmd "mysqladmin ping -ppass"
          --health-interval 10s
          --health-start-period 10s
          --health-timeout 5s
          --health-retries 10
      mysql57:
        image: mysql:5.7.26
        env:
          MYSQL_DATABASE: test
          MYSQL_ROOT_PASSWORD: pass
        ports:
          - 3307:3306
        options: >-
          --health-cmd "mysqladmin ping -ppass"
          --health-interval 10s
          --health-start-period 10s
          --health-timeout 5s
          --health-retries 10
      mysql8:
        image: mysql:8
        env:
          MYSQL_DATABASE: test
          MYSQL_ROOT_PASSWORD: pass
        ports:
          - 3308:3306
        options: >-
          --health-cmd "mysqladmin ping -ppass"
          --health-interval 10s
          --health-start-period 10s
          --health-timeout 5s
          --health-retries 10
      maria:
        image: mariadb
        env:
          MYSQL_DATABASE: test
          MYSQL_ROOT_PASSWORD: pass
        ports:
          - 4306:3306
        options: >-
          --health-cmd "mysqladmin ping -ppass"
          --health-interval 10s
          --health-start-period 10s
          --health-timeout 5s
          --health-retries 10
      maria102:
        image: mariadb:10.2.32
        env:
          MYSQL_DATABASE: test
          MYSQL_ROOT_PASSWORD: pass
        ports:
          - 4307:3306
        options: >-
          --health-cmd "mysqladmin ping -ppass"
          --health-interval 10s
          --health-start-period 10s
          --health-timeout 5s
          --health-retries 10
      maria103:
        image: mariadb:10.3.13
        env:
          MYSQL_DATABASE: test
          MYSQL_ROOT_PASSWORD: pass
        ports:
          - 4308:3306
        options: >-
          --health-cmd "mysqladmin ping -ppass"
          --health-interval 10s
          --health-start-period 10s
          --health-timeout 5s
          --health-retries 10
      postgres10:
        image: postgres:10
        env:
          POSTGRES_DB: test
          POSTGRES_PASSWORD: pass
        ports:
          - 5430:5432
        options: >-
          --health-cmd pg_isready
          --health-interval 10s
          --health-timeout 5s
          --health-retries 5
      postgres11:
        image: postgres:11
        env:
          POSTGRES_DB: test
          POSTGRES_PASSWORD: pass
        ports:
          - 5431:5432
        options: >-
          --health-cmd pg_isready
          --health-interval 10s
          --health-timeout 5s
          --health-retries 5
      postgres12:
        image: postgres:12.3
        env:
          POSTGRES_DB: test
          POSTGRES_PASSWORD: pass
        ports:
          - 5432:5432
        options: >-
          --health-cmd pg_isready
          --health-interval 10s
          --health-timeout 5s
          --health-retries 5
      postgres13:
        image: postgres:13.1
        env:
          POSTGRES_DB: test
          POSTGRES_PASSWORD: pass
        ports:
          - 5433:5432
        options: >-
          --health-cmd pg_isready
          --health-interval 10s
          --health-timeout 5s
          --health-retries 5
      postgres14:
        image: postgres:14
        env:
          POSTGRES_DB: test
          POSTGRES_PASSWORD: pass
        ports:
          - 5434:5432
        options: >-
          --health-cmd pg_isready
          --health-interval 10s
          --health-timeout 5s
          --health-retries 5
      gremlin-server:
        image: entgo/gremlin-server
        ports:
          - 8182:8182
        options: >-
          --health-cmd "netstat -an | grep -q 8182"
          --health-interval 10s
          --health-timeout 5s
          --health-retries 5
    steps:
      - uses: actions/checkout@v3
        with:
          fetch-depth: 0
      - uses: actions/setup-go@v3
        with:
          go-version: '1.18'
      - uses: actions/cache@v3
        with:
          path: ~/go/pkg/mod
          key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
          restore-keys: |
            ${{ runner.os }}-go-
      - name: Checkout origin/master
        run: git checkout origin/master
      - name: Run integration on origin/master
        working-directory: entc/integration
        run: go test -race -count=2 ./...
      - name: Checkout previous HEAD
        run: git checkout -
      - name: Run integration on HEAD
        working-directory: entc/integration
        run: go test -race -count=2 ./...
ent-0.11.3/.golangci.yml000066400000000000000000000027011431500740500147450ustar00rootroot00000000000000run: go: '1.19' timeout: 5m linters-settings: errcheck: ignore: fmt:.*,Read|Write|Close|Exec,io:Copy dupl: threshold: 100 funlen: lines: 115 statements: 115 goheader: template: |- Copyright 2019-present Facebook Inc. All rights reserved. This source code is licensed under the Apache 2.0 license found in the LICENSE file in the root directory of this source tree. linters: disable-all: true enable: - bodyclose - deadcode - depguard - dogsled - dupl - errcheck - funlen - gocritic # - gofmt; Enable back when upgrading CI to Go 1.20. - goheader - gosec - gosimple - govet - ineffassign - misspell - staticcheck - structcheck - stylecheck - typecheck - unconvert - unused - varcheck - whitespace issues: exclude-rules: - path: _test\.go linters: - dupl - funlen - gosec - gocritic - linters: - unused source: ent.Schema - path: dialect/sql/schema linters: - dupl - gosec - text: "Expect WriteFile permissions to be 0600 or less" linters: - gosec - path: privacy/privacy.go linters: - stylecheck - path: entc/load/schema.go linters: - staticcheck - path: entc/gen/graph.go linters: - gocritic - path: \.go linters: - staticcheck text: SA1019 ent-0.11.3/CODE_OF_CONDUCT.md000066400000000000000000000064341431500740500151670ustar00rootroot00000000000000# Code of Conduct ## Our Pledge In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to make participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation. ## Our Standards Examples of behavior that contributes to creating a positive environment include: * Using welcoming and inclusive language * Being respectful of differing viewpoints and experiences * Gracefully accepting constructive criticism * Focusing on what is best for the community * Showing empathy towards other community members Examples of unacceptable behavior by participants include: * The use of sexualized language or imagery and unwelcome sexual attention or advances * Trolling, insulting/derogatory comments, and personal or political attacks * Public or private harassment * Publishing others' private information, such as a physical or electronic address, without explicit permission * Other conduct which could reasonably be considered inappropriate in a professional setting ## Our Responsibilities Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior. Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful. ## Scope This Code of Conduct applies within all project spaces, and it also applies when an individual is representing the project or its community in public spaces. 
Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers.

## Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at . All complaints will be reviewed and investigated and will result in a response that is deemed necessary and appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately.

Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project's leadership.

## Attribution

This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html

[homepage]: https://www.contributor-covenant.org

For answers to common questions about this code of conduct, see https://www.contributor-covenant.org/faq

=== ent-0.11.3/CONTRIBUTING.md ===

# Contributing to ent

We want to make contributing to this project as easy and transparent as possible.

# Project structure

- `dialect` - Contains SQL and Gremlin code used by the generated code.
  - `dialect/sql/schema` - Auto migration logic resides there.
- `schema` - User schema API.
  - `schema/{field, edge, index, mixin}` - provides schema builders API.
  - `schema/field/gen` - Templates and codegen for numeric builders.
- `entc` - Codegen of `ent`.
  - `entc/load` - `entc` loader API for loading user schemas into Go objects at runtime.
  - `entc/gen` - The actual code generation logic resides in this package (and its `templates` package).
  - `integration` - Integration tests for `entc`.
- `privacy` - Runtime code for the [privacy layer](https://entgo.io/docs/privacy/).
- `doc` - Documentation code for `entgo.io` (uses [Docusaurus](https://docusaurus.io)).
  - `doc/md` - Markdown files for documentation.
  - `doc/website` - Website code and assets.

In order to test your documentation changes, run `npm start` from the `doc/website` directory, and open [localhost:3000](http://localhost:3000/).

# Run integration tests

If you touch any file in `entc`, run the following command in `entc`:

```
go generate ./...
```

Then, in `entc/integration` run `docker-compose` in order to spin-up all database containers:

```
docker-compose -f docker-compose.yaml up -d
```

Then, run `go test ./...` to run all integration tests.

## Pull Requests

We actively welcome your pull requests.

1. Fork the repo and create your branch from `master`.
2. If you've added code that should be tested, add tests.
3. If you've changed APIs, update the documentation.
4. Ensure the test suite passes.
5. Make sure your code lints.
6. If you haven't already, complete the Contributor License Agreement ("CLA").

## Contributor License Agreement ("CLA")

In order to accept your pull request, we need you to submit a CLA. You only need to do this once to work on any of Facebook's open source projects.

Complete your CLA here:

## Issues

We use GitHub issues to track public bugs. Please ensure your description is clear and has sufficient instructions to be able to reproduce the issue.
Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe disclosure of security bugs. In those cases, please go through the process outlined on that page and do not file a public issue.

## License

By contributing to ent, you agree that your contributions will be licensed under the LICENSE file in the root directory of this source tree.

=== ent-0.11.3/LICENSE ===

Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.

"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.

2. Grant of Copyright License.

Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.

3. Grant of Patent License.

Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.

4. Redistribution.

You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:

(a) You must give any other recipients of the Work or Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.

You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.

5. Submission of Contributions.

Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.

6. Trademarks.

This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty.

Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.

8. Limitation of Liability.

In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability.

While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives.

Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

=== ent-0.11.3/README.md ===

## ent - An Entity Framework For Go

[![Twitter](https://img.shields.io/twitter/url/https/twitter.com/entgo_io.svg?style=social&label=Follow%20%40entgo_io)](https://twitter.com/entgo_io)
[![Discord](https://img.shields.io/discord/885059418646003782?label=discord&logo=discord&style=flat-square&logoColor=white)](https://discord.gg/qZmPgTE6RX)

[English](README.md) | [中文](README_zh.md) | [日本語](README_jp.md)

A simple, yet powerful entity framework for Go that makes it easy to build and maintain applications with large data models.

- **Schema As Code** - model any database schema as Go objects.
- **Easily Traverse Any Graph** - run queries, aggregations and traverse any graph structure easily.
- **Statically Typed And Explicit API** - 100% statically typed and explicit API using code generation.
- **Multi Storage Driver** - supports MySQL, MariaDB, TiDB, PostgreSQL, CockroachDB, SQLite and Gremlin.
- **Extendable** - simple to extend and customize using Go templates.

## Quick Installation

```console
go install entgo.io/ent/cmd/ent@latest
```

For proper installation using [Go modules], visit the [entgo.io website][entgo instal].

## Docs and Support

The documentation for developing and using ent is available at: https://entgo.io

For discussion and support, [open an issue](https://github.com/ent/ent/issues/new/choose) or join our [channel](https://gophers.slack.com/archives/C01FMSQDT53) in the gophers Slack.

## Join the ent Community

Building `ent` would not have been possible without the collective work of our entire community. We maintain a [contributors page](doc/md/contributors.md) which lists the contributors to `ent`.

In order to contribute to `ent`, see the [CONTRIBUTING](CONTRIBUTING.md) file for how to get started.

If your company or your product is using `ent`, please let us know by adding yourself to the [ent users page](https://github.com/ent/ent/wiki/ent-users).

For updates, follow us on Twitter at https://twitter.com/entgo_io

## About the Project

The `ent` project was inspired by Ent, an entity framework we use internally. It is developed and maintained by [a8m](https://github.com/a8m) and [alexsn](https://github.com/alexsn) from the [Facebook Connectivity][fbc] team. It is used by multiple teams and projects in production, and the roadmap for its v1 release is described [here](https://github.com/ent/ent/issues/46). Read more about the motivation of the project [here](https://entgo.io/blog/2019/10/03/introducing-ent).

## License

ent is licensed under Apache 2.0 as found in the [LICENSE file](LICENSE).
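## Example

To make the "Schema As Code" and "Easily Traverse Any Graph" points above concrete, here is a minimal sketch; it is not part of this archive. It assumes a project scaffolded with `ent init User` (the schema shape matches the template in `cmd/internal/base` below) and code generated with `ent generate ./ent/schema`, so the `client` value and the `user` predicate package are products of that codegen, and the fields shown are illustrative.

```go
package schema

import (
	"entgo.io/ent"
	"entgo.io/ent/schema/field"
)

// User holds the schema definition for the User entity.
type User struct {
	ent.Schema
}

// Fields of the User.
func (User) Fields() []ent.Field {
	return []ent.Field{
		field.String("name"),
		field.Int("age").Positive(),
	}
}
```

After running the code generator, queries against this schema are fully typed:

```go
// All users older than 18, ordered by name (generated API, illustrative).
users, err := client.User.
	Query().
	Where(user.AgeGT(18)).
	Order(ent.Asc(user.FieldName)).
	All(ctx)
```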
[entgo instal]: https://entgo.io/docs/code-gen/#version-compatibility-between-entc-and-ent
[Go modules]: https://github.com/golang/go/wiki/Modules#quick-start
[fbc]: https://connectivity.fb.com

=== ent-0.11.3/README_jp.md ===

## ent - Goのエンティティーフレームワーク

[![Twitter](https://img.shields.io/twitter/url/https/twitter.com/entgo_io.svg?style=social&label=Follow%20%40entgo_io)](https://twitter.com/entgo_io)

[English](README.md) | [中文](README_zh.md) | [日本語](README_jp.md)

シンプルながらもパワフルなGoのエンティティフレームワークであり、大規模なデータモデルを持つアプリケーションを容易に構築・保守できるようにします。

- **Schema As Code(コードとしてのスキーマ)** - あらゆるデータベーススキーマをGoオブジェクトとしてモデル化します。
- **任意のグラフを簡単にトラバースできます** - クエリや集約の実行、任意のグラフ構造の走査を容易に実行できます。
- **100%静的に型付けされた明示的なAPI** - コード生成により、100%静的に型付けされた曖昧さのないAPIを提供します。
- **マルチストレージドライバ** - MySQL、MariaDB、 TiDB、PostgreSQL、CockroachDB、SQLite、Gremlinをサポートしています。
- **拡張性** - Goテンプレートを使用して簡単に拡張、カスタマイズできます。

## クイックインストール

```console
go install entgo.io/ent/cmd/ent@latest
```

[Go modules]を使ったインストールについては、[entgo.ioのWebサイト](https://entgo.io/ja/docs/code-gen/#entc-%E3%81%A8-ent-%E3%81%AE%E3%83%90%E3%83%BC%E3%82%B8%E3%83%A7%E3%83%B3%E3%82%92%E4%B8%80%E8%87%B4%E3%81%95%E3%81%9B%E3%82%8B)をご覧ください。

## ドキュメントとサポート

entを開発・使用するためのドキュメントは、こちら: https://entgo.io

議論やサポートについては、[Issueを開く](https://github.com/ent/ent/issues/new/choose)か、gophers Slackの[チャンネル](https://gophers.slack.com/archives/C01FMSQDT53)に参加してください。

## entコミュニティへの参加

`ent`の構築は、コミュニティ全体の協力なしには実現できませんでした。私たちは、この`ent`の貢献者をリストアップした[contributorsページ](doc/md/contributors.md)を管理しています。`ent`に貢献するときは、まず[CONTRIBUTING](CONTRIBUTING.md)を参照してください。

もし、あなたの会社や製品で`ent`を利用している場合は、[ent usersページ](https://github.com/ent/ent/wiki/ent-users)に追記する形で、そのことをぜひ教えて下さい。

最新情報については、Twitter(https://twitter.com/entgo_io)をフォローしてください。

## プロジェクトについて

`ent`プロジェクトは、私たちが社内で使用しているエンティティフレームワークであるEntからインスピレーションを得ています。entは、[Facebook Connectivity][fbc]チームの[a8m](https://github.com/a8m)と[alexsn](https://github.com/alexsn)が開発・保守しています。本番環境では複数のチームやプロジェクトで使用されており、v1リリースまでのロードマップは[こちら](https://github.com/ent/ent/issues/46)に記載されています。このプロジェクトの動機については[こちら](https://entgo.io/blog/2019/10/03/introducing-ent)をご覧ください。

## ライセンス

entは、[LICENSEファイル](LICENSE)にもある通り、Apache 2.0でライセンスされています。

[entgo instal]: https://entgo.io/docs/code-gen/#version-compatibility-between-entc-and-ent
[Go modules]: https://github.com/golang/go/wiki/Modules#quick-start
[fbc]: https://connectivity.fb.com

=== ent-0.11.3/README_zh.md ===

## ent - 一个强大的Go语言实体框架

[English](README.md) | [中文](README_zh.md) | [日本語](README_jp.md)

ent是一个简单而又功能强大的Go语言实体框架，ent易于构建和维护应用程序与大数据模型。

- **图就是代码** - 将任何数据库表建模为Go对象。
- **轻松地遍历任何图形** - 可以轻松地运行查询、聚合和遍历任何图形结构。
- **静态类型和显式API** - 使用代码生成静态类型和显式API，查询数据更加便捷。
- **多存储驱动程序** - 支持MySQL, PostgreSQL, SQLite 和 Gremlin。
- **可扩展** - 简单地扩展和使用Go模板自定义。

## 快速安装

```console
go install entgo.io/ent/cmd/ent@latest
```

请访问[entgo.io website][entgo instal]以使用[Go modules]进行正确安装。

## 文档和支持

开发和使用ent的文档请参照: https://entgo.io

如要讨论问题和支持, [创建一个issue](https://github.com/ent/ent/issues/new/choose) 或者加入我们的Gopher Slack(Slack软件，类似于论坛)[讨论组](https://gophers.slack.com/archives/C01FMSQDT53)

## 加入 ent 社区

如果你想为`ent`做出贡献, [贡献代码](CONTRIBUTING.md) 中写了如何做出自己的贡献

如果你的公司或者产品在使用`ent`，请让我们知道你已经加入 [ent 用户](https://github.com/ent/ent/wiki/ent-users)

## 关于项目

`ent` 项目灵感来自于Ent，Ent是一个facebook内部使用的一个实体框架项目。它由 [Facebook Connectivity][fbc] 团队通过 [a8m](https://github.com/a8m) 和 [alexsn](https://github.com/alexsn) 开发和维护, 它被生产中的多个团队和项目使用。它的v1版本的路线图为 [版本的路线图](https://github.com/ent/ent/issues/46).
关于项目更多的信息 [ent介绍](https://entgo.io/blog/2019/10/03/introducing-ent)。

## 声明

ent使用Apache 2.0协议授权，可以在[LICENSE文件](LICENSE)中找到。

[entgo instal]: https://entgo.io/docs/code-gen/#version-compatibility-between-entc-and-ent
[Go modules]: https://github.com/golang/go/wiki/Modules#quick-start
[fbc]: https://connectivity.fb.com

=== ent-0.11.3/cmd/ent/ent.go ===

// Copyright 2019-present Facebook Inc. All rights reserved.
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.

package main

import (
	"log"

	"entgo.io/ent/cmd/internal/base"
	"github.com/spf13/cobra"
)

func main() {
	log.SetFlags(0)
	cmd := &cobra.Command{Use: "ent"}
	cmd.AddCommand(
		base.InitCmd(),
		base.DescribeCmd(),
		base.GenerateCmd(),
	)
	_ = cmd.Execute()
}

=== ent-0.11.3/cmd/ent/ent_test.go ===

// Copyright 2019-present Facebook Inc. All rights reserved.
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.

package main

import (
	"bytes"
	"os"
	"os/exec"
	"testing"

	"github.com/stretchr/testify/require"
)

func TestCmd(t *testing.T) {
	defer os.RemoveAll("ent")
	cmd := exec.Command("go", "run", "entgo.io/ent/cmd/ent", "init", "User")
	stderr := bytes.NewBuffer(nil)
	cmd.Stderr = stderr
	require.NoError(t, cmd.Run())
	require.Zero(t, stderr.String())
	cmd = exec.Command("go", "run", "entgo.io/ent/cmd/ent", "init", "User")
	require.Error(t, cmd.Run())
	_, err := os.Stat("ent/generate.go")
	require.NoError(t, err)
	_, err = os.Stat("ent/schema/user.go")
	require.NoError(t, err)
	cmd = exec.Command("go", "run", "entgo.io/ent/cmd/ent", "generate", "./ent/schema")
	stderr = bytes.NewBuffer(nil)
	cmd.Stderr = stderr
	require.NoError(t, cmd.Run())
	require.Zero(t, stderr.String())
	_, err = os.Stat("ent/user.go")
	require.NoError(t, err)
}

=== ent-0.11.3/cmd/entc/entc.go ===

// Copyright 2019-present Facebook Inc. All rights reserved.
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.

package main

import (
	"bytes"
	"log"
	"os"
	"path/filepath"

	"entgo.io/ent/cmd/internal/base"
	"entgo.io/ent/entc/gen"
	"github.com/spf13/cobra"
)

func main() {
	log.SetFlags(0)
	cmd := &cobra.Command{Use: "entc"}
	cmd.AddCommand(
		base.InitCmd(),
		base.DescribeCmd(),
		base.GenerateCmd(migrate),
	)
	_ = cmd.Execute()
}

// migrate rewrites the generate.go file under the codegen target, if it still
// invokes the older "entgo.io/ent/cmd/entc" command, to invoke
// "entgo.io/ent/cmd/ent" instead.
func migrate(c *gen.Config) {
	var (
		target = filepath.Join(c.Target, "generate.go")
		oldCmd = []byte("entgo.io/ent/cmd/entc")
	)
	buf, err := os.ReadFile(target)
	if err != nil || !bytes.Contains(buf, oldCmd) {
		return
	}
	_ = os.WriteFile(target, bytes.ReplaceAll(buf, oldCmd, []byte("entgo.io/ent/cmd/ent")), 0644)
}

=== ent-0.11.3/cmd/entc/entc_test.go ===

// Copyright 2019-present Facebook Inc. All rights reserved.
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.
package main

import (
	"bytes"
	"os"
	"os/exec"
	"testing"

	"github.com/stretchr/testify/require"
)

func TestCmd(t *testing.T) {
	defer os.RemoveAll("ent")
	cmd := exec.Command("go", "run", "entgo.io/ent/cmd/entc", "init", "User")
	stderr := bytes.NewBuffer(nil)
	cmd.Stderr = stderr
	require.NoError(t, cmd.Run(), stderr.String())
	_, err := os.Stat("ent/generate.go")
	require.NoError(t, err)
	_, err = os.Stat("ent/schema/user.go")
	require.NoError(t, err)
	cmd = exec.Command("go", "run", "entgo.io/ent/cmd/entc", "generate", "./ent/schema")
	stderr = bytes.NewBuffer(nil)
	cmd.Stderr = stderr
	require.NoError(t, cmd.Run(), stderr.String())
	_, err = os.Stat("ent/user.go")
	require.NoError(t, err)
}

=== ent-0.11.3/cmd/internal/base/base.go ===

// Copyright 2019-present Facebook Inc. All rights reserved.
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.

// Package base defines shared basic pieces of the ent command.
package base

import (
	"bytes"
	"errors"
	"fmt"
	"log"
	"os"
	"path/filepath"
	"strings"
	"text/template"
	"unicode"

	"entgo.io/ent/cmd/internal/printer"
	"entgo.io/ent/entc"
	"entgo.io/ent/entc/gen"
	"entgo.io/ent/schema/field"
	"github.com/spf13/cobra"
)

// IDType is a custom ID implementation for pflag.
type IDType field.Type

// Set implements the Set method of the flag.Value interface.
func (t *IDType) Set(s string) error {
	switch s {
	case field.TypeInt.String():
		*t = IDType(field.TypeInt)
	case field.TypeInt64.String():
		*t = IDType(field.TypeInt64)
	case field.TypeUint.String():
		*t = IDType(field.TypeUint)
	case field.TypeUint64.String():
		*t = IDType(field.TypeUint64)
	case field.TypeString.String():
		*t = IDType(field.TypeString)
	default:
		return fmt.Errorf("invalid type %q", s)
	}
	return nil
}

// Type returns the type representation of the id option for help command.
func (IDType) Type() string {
	return fmt.Sprintf("%v", []field.Type{
		field.TypeInt,
		field.TypeInt64,
		field.TypeUint,
		field.TypeUint64,
		field.TypeString,
	})
}

// String returns the default value for the help command.
func (IDType) String() string {
	return field.TypeInt.String()
}

// InitCmd returns the init command for ent/c packages.
func InitCmd() *cobra.Command { var target, tmplPath string cmd := &cobra.Command{ Use: "init [flags] [schemas]", Short: "initialize an environment with zero or more schemas", Example: examples( "ent init Example", "ent init --target entv1/schema User Group", "ent init --template ./path/to/file.tmpl User", ), Args: func(_ *cobra.Command, names []string) error { for _, name := range names { if !unicode.IsUpper(rune(name[0])) { return errors.New("schema names must begin with uppercase") } } return nil }, Run: func(cmd *cobra.Command, names []string) { var ( err error tmpl *template.Template ) if tmplPath != "" { tmpl, err = template.ParseFiles(tmplPath) } else { tmpl, err = template.New("schema").Parse(defaultTemplate) } if err != nil { log.Fatalln(fmt.Errorf("ent/init: could not parse template: %w", err)) } if err := initEnv(target, names, tmpl); err != nil { log.Fatalln(fmt.Errorf("ent/init: %w", err)) } }, } cmd.Flags().StringVar(&target, "target", defaultSchema, "target directory for schemas") cmd.Flags().StringVar(&tmplPath, "template", "", "template to use for new schemas") return cmd } // DescribeCmd returns the describe command for ent/c packages. func DescribeCmd() *cobra.Command { return &cobra.Command{ Use: "describe [flags] path", Short: "print a description of the graph schema", Example: examples( "ent describe ./ent/schema", "ent describe github.com/a8m/x", ), Args: cobra.ExactArgs(1), Run: func(cmd *cobra.Command, path []string) { graph, err := entc.LoadGraph(path[0], &gen.Config{}) if err != nil { log.Fatalln(err) } printer.Fprint(os.Stdout, graph) }, } } // GenerateCmd returns the generate command for ent/c packages. func GenerateCmd(postRun ...func(*gen.Config)) *cobra.Command { var ( cfg gen.Config storage string features []string templates []string idtype = IDType(field.TypeInt) cmd = &cobra.Command{ Use: "generate [flags] path", Short: "generate go code for the schema directory", Example: examples( "ent generate ./ent/schema", "ent generate github.com/a8m/x", ), Args: cobra.ExactArgs(1), Run: func(cmd *cobra.Command, path []string) { opts := []entc.Option{ entc.Storage(storage), entc.FeatureNames(features...), } for _, tmpl := range templates { typ := "dir" if parts := strings.SplitN(tmpl, "=", 2); len(parts) > 1 { typ, tmpl = parts[0], parts[1] } switch typ { case "dir": opts = append(opts, entc.TemplateDir(tmpl)) case "file": opts = append(opts, entc.TemplateFiles(tmpl)) case "glob": opts = append(opts, entc.TemplateGlob(tmpl)) default: log.Fatalln("unsupported template type", typ) } } // If the target directory is not inferred from // the schema path, resolve its package path. if cfg.Target != "" { pkgPath, err := PkgPath(DefaultConfig, cfg.Target) if err != nil { log.Fatalln(err) } cfg.Package = pkgPath } cfg.IDType = &field.TypeInfo{Type: field.Type(idtype)} if err := entc.Generate(path[0], &cfg, opts...); err != nil { log.Fatalln(err) } for _, fn := range postRun { fn(&cfg) } }, } ) cmd.Flags().Var(&idtype, "idtype", "type of the id field") cmd.Flags().StringVar(&storage, "storage", "sql", "storage driver to support in codegen") cmd.Flags().StringVar(&cfg.Header, "header", "", "override codegen header") cmd.Flags().StringVar(&cfg.Target, "target", "", "target directory for codegen") cmd.Flags().StringSliceVarP(&features, "feature", "", nil, "extend codegen with additional features") cmd.Flags().StringSliceVarP(&templates, "template", "", nil, "external templates to execute") return cmd } // initEnv initializes an environment for ent codegen.
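//
// A sketch of the layout produced for initEnv("ent/schema", []string{"User"}, tmpl),
// inferred from createDir and the template execution below:
//
//	ent/generate.go     // written only when target is the default "ent/schema"
//	ent/schema/user.go  // the schema template rendered for "User"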
func initEnv(target string, names []string, tmpl *template.Template) error { if err := createDir(target); err != nil { return fmt.Errorf("create dir %s: %w", target, err) } for _, name := range names { if err := gen.ValidSchemaName(name); err != nil { return fmt.Errorf("init schema %s: %w", name, err) } if fileExists(target, name) { return fmt.Errorf("init schema %s: already exists", name) } b := bytes.NewBuffer(nil) if err := tmpl.Execute(b, name); err != nil { return fmt.Errorf("executing template %s: %w", name, err) } newFileTarget := filepath.Join(target, strings.ToLower(name+".go")) if err := os.WriteFile(newFileTarget, b.Bytes(), 0644); err != nil { return fmt.Errorf("writing file %s: %w", newFileTarget, err) } } return nil } func createDir(target string) error { _, err := os.Stat(target) if err == nil || !os.IsNotExist(err) { return err } if err := os.MkdirAll(target, os.ModePerm); err != nil { return fmt.Errorf("creating schema directory: %w", err) } if target != defaultSchema { return nil } if err := os.WriteFile("ent/generate.go", []byte(genFile), 0644); err != nil { return fmt.Errorf("creating generate.go file: %w", err) } return nil } func fileExists(target, name string) bool { var _, err = os.Stat(filepath.Join(target, strings.ToLower(name+".go"))) return err == nil } const ( // default schema package path. defaultSchema = "ent/schema" // ent/generate.go file used for "go generate" command. genFile = "package ent\n\n//go:generate go run -mod=mod entgo.io/ent/cmd/ent generate ./schema\n" // schema template for the "init" command. defaultTemplate = `package schema import "entgo.io/ent" // {{ . }} holds the schema definition for the {{ . }} entity. type {{ . }} struct { ent.Schema } // Fields of the {{ . }}. func ({{ . }}) Fields() []ent.Field { return nil } // Edges of the {{ . }}. func ({{ . }}) Edges() []ent.Edge { return nil } ` ) // examples formats the given examples to the cli. func examples(ex ...string) string { for i := range ex { ex[i] = " " + ex[i] // indent each row with 2 spaces. } return strings.Join(ex, "\n") } ent-0.11.3/cmd/internal/base/packages.go000066400000000000000000000030501431500740500177550ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package base import ( "fmt" "os" "path" "path/filepath" "golang.org/x/tools/go/packages" ) // DefaultConfig for loading Go packages. var DefaultConfig = &packages.Config{Mode: packages.NeedName} // PkgPath returns the Go package name for the given target path, // even if the path does not exist yet in the filesystem. // // If config is nil, DefaultConfig will be used to load packages. func PkgPath(config *packages.Config, target string) (string, error) { if config == nil { config = DefaultConfig } pathCheck, err := filepath.Abs(target) if err != nil { return "", err } var parts []string if _, err := os.Stat(pathCheck); os.IsNotExist(err) { parts = append(parts, filepath.Base(pathCheck)) pathCheck = filepath.Dir(pathCheck) } // Try at most 2 directories above the given // target to find the root package or module.
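// For example, assuming a hypothetical module "example.com/app" defined at the
// repository root, a target that does not exist yet resolves as:
//
//	PkgPath(nil, "./ent/schema") // "example.com/app/ent/schema"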
for i := 0; i < 2; i++ { pkgs, err := packages.Load(config, pathCheck) if err != nil { return "", fmt.Errorf("load package info: %w", err) } if len(pkgs) == 0 || len(pkgs[0].Errors) != 0 { parts = append(parts, filepath.Base(pathCheck)) pathCheck = filepath.Dir(pathCheck) continue } pkgPath := pkgs[0].PkgPath for j := len(parts) - 1; j >= 0; j-- { pkgPath = path.Join(pkgPath, parts[j]) } return pkgPath, nil } return "", fmt.Errorf("root package or module was not found for: %s", target) } ent-0.11.3/cmd/internal/base/packages_test.go000066400000000000000000000026161431500740500210230ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package base import ( "path/filepath" "testing" "github.com/stretchr/testify/require" "golang.org/x/tools/go/packages/packagestest" ) func TestPkgPath(t *testing.T) { packagestest.TestAll(t, testPkgPath) } func testPkgPath(t *testing.T, x packagestest.Exporter) { e := packagestest.Export(t, x, []packagestest.Module{ { Name: "golang.org/x", Files: map[string]any{ "x.go": "package x", "y/y.go": "package y", }, }, }) defer e.Cleanup() e.Config.Dir = filepath.Dir(e.File("golang.org/x", "x.go")) target := filepath.Join(e.Config.Dir, "ent") pkgPath, err := PkgPath(e.Config, target) require.NoError(t, err) require.Equal(t, "golang.org/x/ent", pkgPath) e.Config.Dir = filepath.Dir(e.File("golang.org/x", "y/y.go")) target = filepath.Join(e.Config.Dir, "ent") pkgPath, err = PkgPath(e.Config, target) require.NoError(t, err) require.Equal(t, "golang.org/x/y/ent", pkgPath) target = filepath.Join(e.Config.Dir, "z/ent") pkgPath, err = PkgPath(e.Config, target) require.NoError(t, err) require.Equal(t, "golang.org/x/y/z/ent", pkgPath) target = filepath.Join(e.Config.Dir, "z/e/n/t") pkgPath, err = PkgPath(e.Config, target) require.Error(t, err) require.Empty(t, pkgPath) } ent-0.11.3/cmd/internal/printer/000077500000000000000000000000001431500740500164235ustar00rootroot00000000000000ent-0.11.3/cmd/internal/printer/printer.go000066400000000000000000000041351431500740500204400ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package printer import ( "fmt" "io" "reflect" "strconv" "strings" "entgo.io/ent/entc/gen" "github.com/olekukonko/tablewriter" ) // A Config controls the output of Fprint. type Config struct { io.Writer } // Print prints a table description of the graph to the given writer. func (p Config) Print(g *gen.Graph) { for _, n := range g.Nodes { p.node(n) } } // Fprint executes the "pretty-printer" on the given writer. func Fprint(w io.Writer, g *gen.Graph) { Config{Writer: w}.Print(g) } // node prints a description of a type to the writer. The format of the description is: // // Type: // <fields table> // // <edges table> // func (p Config) node(t *gen.Type) { var ( b strings.Builder id []*gen.Field table = tablewriter.NewWriter(&b) header = []string{"Field", "Type", "Unique", "Optional", "Nillable", "Default", "UpdateDefault", "Immutable", "StructTag", "Validators"} ) b.WriteString(t.Name + ":\n") table.SetAutoFormatHeaders(false) table.SetHeader(header) if t.ID != nil { id = append(id, t.ID) } for _, f := range append(id, t.Fields...)
{ v := reflect.ValueOf(*f) row := make([]string, len(header)) for i := range row { field := v.FieldByNameFunc(func(name string) bool { // The first field is mapped from "Name" to "Field". return name == "Name" && i == 0 || name == header[i] }) row[i] = fmt.Sprint(field.Interface()) } table.Append(row) } table.Render() table = tablewriter.NewWriter(&b) table.SetAutoFormatHeaders(false) table.SetHeader([]string{"Edge", "Type", "Inverse", "BackRef", "Relation", "Unique", "Optional"}) for _, e := range t.Edges { table.Append([]string{ e.Name, e.Type.Name, strconv.FormatBool(e.IsInverse()), e.Inverse, e.Rel.Type.String(), strconv.FormatBool(e.Unique), strconv.FormatBool(e.Optional), }) } if table.NumLines() > 0 { table.Render() } io.WriteString(p, strings.ReplaceAll(b.String(), "\n", "\n\t")+"\n") } ent-0.11.3/cmd/internal/printer/printer_test.go000066400000000000000000000207111431500740500214750ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package printer import ( "strings" "testing" "entgo.io/ent/entc/gen" "entgo.io/ent/schema/field" "github.com/stretchr/testify/assert" ) func TestPrinter_Print(t *testing.T) { tests := []struct { input *gen.Graph out string }{ { input: &gen.Graph{ Nodes: []*gen.Type{ { Name: "User", ID: &gen.Field{Name: "id", Type: &field.TypeInfo{Type: field.TypeInt}}, Fields: []*gen.Field{ {Name: "name", Type: &field.TypeInfo{Type: field.TypeString}, Validators: 1}, {Name: "age", Type: &field.TypeInfo{Type: field.TypeInt}, Nillable: true}, {Name: "created_at", Type: &field.TypeInfo{Type: field.TypeTime}, Nillable: true, Immutable: true}, }, }, }, }, out: ` User: +------------+-----------+--------+----------+----------+---------+---------------+-----------+-----------+------------+ | Field | Type | Unique | Optional | Nillable | Default | UpdateDefault | Immutable | StructTag | Validators | +------------+-----------+--------+----------+----------+---------+---------------+-----------+-----------+------------+ | id | int | false | false | false | false | false | false | | 0 | | name | string | false | false | false | false | false | false | | 1 | | age | int | false | false | true | false | false | false | | 0 | | created_at | time.Time | false | false | true | false | false | true | | 0 | +------------+-----------+--------+----------+----------+---------+---------------+-----------+-----------+------------+ `, }, { input: &gen.Graph{ Nodes: []*gen.Type{ { Name: "User", ID: &gen.Field{Name: "id", Type: &field.TypeInfo{Type: field.TypeInt}}, Edges: []*gen.Edge{ {Name: "groups", Type: &gen.Type{Name: "Group"}, Rel: gen.Relation{Type: gen.M2M}, Optional: true}, {Name: "spouse", Type: &gen.Type{Name: "User"}, Unique: true, Rel: gen.Relation{Type: gen.O2O}}, }, }, }, }, out: ` User: +-------+------+--------+----------+----------+---------+---------------+-----------+-----------+------------+ | Field | Type | Unique | Optional | Nillable | Default | UpdateDefault | Immutable | StructTag | Validators | +-------+------+--------+----------+----------+---------+---------------+-----------+-----------+------------+ | id | int | false | false | false | false | false | false | | 0 | +-------+------+--------+----------+----------+---------+---------------+-----------+-----------+------------+ +--------+-------+---------+---------+----------+--------+----------+ | Edge | Type | Inverse | BackRef | Relation | Unique | 
Optional | +--------+-------+---------+---------+----------+--------+----------+ | groups | Group | false | | M2M | false | true | | spouse | User | false | | O2O | true | false | +--------+-------+---------+---------+----------+--------+----------+ `, }, { input: &gen.Graph{ Nodes: []*gen.Type{ { Name: "User", ID: &gen.Field{Name: "id", Type: &field.TypeInfo{Type: field.TypeInt}}, Fields: []*gen.Field{ {Name: "name", Type: &field.TypeInfo{Type: field.TypeString}, Validators: 1}, {Name: "age", Type: &field.TypeInfo{Type: field.TypeInt}, Nillable: true}, }, Edges: []*gen.Edge{ {Name: "groups", Type: &gen.Type{Name: "Group"}, Rel: gen.Relation{Type: gen.M2M}, Optional: true}, {Name: "spouse", Type: &gen.Type{Name: "User"}, Unique: true, Rel: gen.Relation{Type: gen.O2O}}, }, }, }, }, out: ` User: +-------+--------+--------+----------+----------+---------+---------------+-----------+-----------+------------+ | Field | Type | Unique | Optional | Nillable | Default | UpdateDefault | Immutable | StructTag | Validators | +-------+--------+--------+----------+----------+---------+---------------+-----------+-----------+------------+ | id | int | false | false | false | false | false | false | | 0 | | name | string | false | false | false | false | false | false | | 1 | | age | int | false | false | true | false | false | false | | 0 | +-------+--------+--------+----------+----------+---------+---------------+-----------+-----------+------------+ +--------+-------+---------+---------+----------+--------+----------+ | Edge | Type | Inverse | BackRef | Relation | Unique | Optional | +--------+-------+---------+---------+----------+--------+----------+ | groups | Group | false | | M2M | false | true | | spouse | User | false | | O2O | true | false | +--------+-------+---------+---------+----------+--------+----------+ `, }, { input: &gen.Graph{ Nodes: []*gen.Type{ { Name: "User", ID: &gen.Field{Name: "id", Type: &field.TypeInfo{Type: field.TypeInt}}, Fields: []*gen.Field{ {Name: "name", Type: &field.TypeInfo{Type: field.TypeString}, Validators: 1}, {Name: "age", Type: &field.TypeInfo{Type: field.TypeInt}, Nillable: true}, }, Edges: []*gen.Edge{ {Name: "groups", Type: &gen.Type{Name: "Group"}, Rel: gen.Relation{Type: gen.M2M}, Optional: true}, {Name: "spouse", Type: &gen.Type{Name: "User"}, Unique: true, Rel: gen.Relation{Type: gen.O2O}}, }, }, { Name: "Group", ID: &gen.Field{Name: "id", Type: &field.TypeInfo{Type: field.TypeInt}}, Fields: []*gen.Field{ {Name: "name", Type: &field.TypeInfo{Type: field.TypeString}}, }, Edges: []*gen.Edge{ {Name: "users", Type: &gen.Type{Name: "User"}, Rel: gen.Relation{Type: gen.M2M}, Optional: true}, }, }, }, }, out: ` User: +-------+--------+--------+----------+----------+---------+---------------+-----------+-----------+------------+ | Field | Type | Unique | Optional | Nillable | Default | UpdateDefault | Immutable | StructTag | Validators | +-------+--------+--------+----------+----------+---------+---------------+-----------+-----------+------------+ | id | int | false | false | false | false | false | false | | 0 | | name | string | false | false | false | false | false | false | | 1 | | age | int | false | false | true | false | false | false | | 0 | +-------+--------+--------+----------+----------+---------+---------------+-----------+-----------+------------+ +--------+-------+---------+---------+----------+--------+----------+ | Edge | Type | Inverse | BackRef | Relation | Unique | Optional | 
+--------+-------+---------+---------+----------+--------+----------+ | groups | Group | false | | M2M | false | true | | spouse | User | false | | O2O | true | false | +--------+-------+---------+---------+----------+--------+----------+ Group: +-------+--------+--------+----------+----------+---------+---------------+-----------+-----------+------------+ | Field | Type | Unique | Optional | Nillable | Default | UpdateDefault | Immutable | StructTag | Validators | +-------+--------+--------+----------+----------+---------+---------------+-----------+-----------+------------+ | id | int | false | false | false | false | false | false | | 0 | | name | string | false | false | false | false | false | false | | 0 | +-------+--------+--------+----------+----------+---------+---------------+-----------+-----------+------------+ +-------+------+---------+---------+----------+--------+----------+ | Edge | Type | Inverse | BackRef | Relation | Unique | Optional | +-------+------+---------+---------+----------+--------+----------+ | users | User | false | | M2M | false | true | +-------+------+---------+---------+----------+--------+----------+ `, }, } for _, tt := range tests { b := &strings.Builder{} Fprint(b, tt.input) assert.Equal(t, tt.out, "\n"+b.String()) } } ent-0.11.3/dialect/000077500000000000000000000000001431500740500137665ustar00rootroot00000000000000ent-0.11.3/dialect/dialect.go000066400000000000000000000164251431500740500157320ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package dialect import ( "context" "database/sql" "database/sql/driver" "fmt" "log" "github.com/google/uuid" ) // Dialect names for external usage. const ( MySQL = "mysql" SQLite = "sqlite3" Postgres = "postgres" Gremlin = "gremlin" ) // ExecQuerier wraps the 2 database operations. type ExecQuerier interface { // Exec executes a query that does not return records. For example, in SQL, INSERT or UPDATE. // It scans the result into the pointer v. For SQL drivers, it is dialect/sql.Result. Exec(ctx context.Context, query string, args, v any) error // Query executes a query that returns rows, typically a SELECT in SQL. // It scans the result into the pointer v. For SQL drivers, it is *dialect/sql.Rows. Query(ctx context.Context, query string, args, v any) error } // Driver is the interface that wraps all necessary operations for ent clients. type Driver interface { ExecQuerier // Tx starts and returns a new transaction. // The provided context is used until the transaction is committed or rolled back. Tx(context.Context) (Tx, error) // Close closes the underlying connection. Close() error // Dialect returns the dialect name of the driver. Dialect() string } // Tx wraps the Exec and Query operations in a transaction. type Tx interface { ExecQuerier driver.Tx } type nopTx struct { Driver } func (nopTx) Commit() error { return nil } func (nopTx) Rollback() error { return nil } // NopTx returns a Tx with no-op Commit / Rollback methods wrapping // the provided Driver d. func NopTx(d Driver) Tx { return nopTx{d} } // DebugDriver is a driver that logs all driver operations. type DebugDriver struct { Driver // underlying driver. log func(context.Context, ...any) // log function. defaults to log.Println. } // Debug gets a driver and an optional logging function, and returns // a new debugged-driver that prints all outgoing operations.
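//
// A minimal usage sketch (drv is assumed to be an existing Driver):
//
//	db := Debug(drv)                 // logs with log.Println
//	db = Debug(drv, func(v ...any) { // or with any custom func(...any)
//		logger.Println(v...)     // logger is hypothetical
//	})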
func Debug(d Driver, logger ...func(...any)) Driver { logf := log.Println if len(logger) == 1 { logf = logger[0] } drv := &DebugDriver{d, func(_ context.Context, v ...any) { logf(v...) }} return drv } // DebugWithContext gets a driver and a logging function, and returns // a new debugged-driver that prints all outgoing operations with context. func DebugWithContext(d Driver, logger func(context.Context, ...any)) Driver { drv := &DebugDriver{d, logger} return drv } // Exec logs its params and calls the underlying driver Exec method. func (d *DebugDriver) Exec(ctx context.Context, query string, args, v any) error { d.log(ctx, fmt.Sprintf("driver.Exec: query=%v args=%v", query, args)) return d.Driver.Exec(ctx, query, args, v) } // ExecContext logs its params and calls the underlying driver ExecContext method if it is supported. func (d *DebugDriver) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) { drv, ok := d.Driver.(interface { ExecContext(context.Context, string, ...any) (sql.Result, error) }) if !ok { return nil, fmt.Errorf("Driver.ExecContext is not supported") } d.log(ctx, fmt.Sprintf("driver.ExecContext: query=%v args=%v", query, args)) return drv.ExecContext(ctx, query, args...) } // Query logs its params and calls the underlying driver Query method. func (d *DebugDriver) Query(ctx context.Context, query string, args, v any) error { d.log(ctx, fmt.Sprintf("driver.Query: query=%v args=%v", query, args)) return d.Driver.Query(ctx, query, args, v) } // QueryContext logs its params and calls the underlying driver QueryContext method if it is supported. func (d *DebugDriver) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) { drv, ok := d.Driver.(interface { QueryContext(context.Context, string, ...any) (*sql.Rows, error) }) if !ok { return nil, fmt.Errorf("Driver.QueryContext is not supported") } d.log(ctx, fmt.Sprintf("driver.QueryContext: query=%v args=%v", query, args)) return drv.QueryContext(ctx, query, args...) } // Tx adds a log-id for the transaction and calls the underlying driver Tx command. func (d *DebugDriver) Tx(ctx context.Context) (Tx, error) { tx, err := d.Driver.Tx(ctx) if err != nil { return nil, err } id := uuid.New().String() d.log(ctx, fmt.Sprintf("driver.Tx(%s): started", id)) return &DebugTx{tx, id, d.log, ctx}, nil } // BeginTx adds a log-id for the transaction and calls the underlying driver BeginTx command if it is supported. func (d *DebugDriver) BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) { drv, ok := d.Driver.(interface { BeginTx(context.Context, *sql.TxOptions) (Tx, error) }) if !ok { return nil, fmt.Errorf("Driver.BeginTx is not supported") } tx, err := drv.BeginTx(ctx, opts) if err != nil { return nil, err } id := uuid.New().String() d.log(ctx, fmt.Sprintf("driver.BeginTx(%s): started", id)) return &DebugTx{tx, id, d.log, ctx}, nil } // DebugTx is a transaction implementation that logs all transaction operations. type DebugTx struct { Tx // underlying transaction. id string // transaction logging id. log func(context.Context, ...any) // log function. defaults to log.Println. ctx context.Context // underlying transaction context. } // Exec logs its params and calls the underlying transaction Exec method.
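//
// The emitted log line follows the Sprintf format below; the id and the values
// here are illustrative only:
//
//	Tx(6ba7b810-...).Exec: query=UPDATE `users` SET `name` = ? args=[a8m]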
func (d *DebugTx) Exec(ctx context.Context, query string, args, v any) error { d.log(ctx, fmt.Sprintf("Tx(%s).Exec: query=%v args=%v", d.id, query, args)) return d.Tx.Exec(ctx, query, args, v) } // ExecContext logs its params and calls the underlying transaction ExecContext method if it is supported. func (d *DebugTx) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) { drv, ok := d.Tx.(interface { ExecContext(context.Context, string, ...any) (sql.Result, error) }) if !ok { return nil, fmt.Errorf("Tx.ExecContext is not supported") } d.log(ctx, fmt.Sprintf("Tx(%s).ExecContext: query=%v args=%v", d.id, query, args)) return drv.ExecContext(ctx, query, args...) } // Query logs its params and calls the underlying transaction Query method. func (d *DebugTx) Query(ctx context.Context, query string, args, v any) error { d.log(ctx, fmt.Sprintf("Tx(%s).Query: query=%v args=%v", d.id, query, args)) return d.Tx.Query(ctx, query, args, v) } // QueryContext logs its params and calls the underlying transaction QueryContext method if it is supported. func (d *DebugTx) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) { drv, ok := d.Tx.(interface { QueryContext(context.Context, string, ...any) (*sql.Rows, error) }) if !ok { return nil, fmt.Errorf("Tx.QueryContext is not supported") } d.log(ctx, fmt.Sprintf("Tx(%s).QueryContext: query=%v args=%v", d.id, query, args)) return drv.QueryContext(ctx, query, args...) } // Commit logs this step and calls the underlying transaction Commit method. func (d *DebugTx) Commit() error { d.log(d.ctx, fmt.Sprintf("Tx(%s): committed", d.id)) return d.Tx.Commit() } // Rollback logs this step and calls the underlying transaction Rollback method. func (d *DebugTx) Rollback() error { d.log(d.ctx, fmt.Sprintf("Tx(%s): rollbacked", d.id)) return d.Tx.Rollback() } ent-0.11.3/dialect/entsql/000077500000000000000000000000001431500740500152745ustar00rootroot00000000000000ent-0.11.3/dialect/entsql/annotation.go000066400000000000000000000266641431500740500200130ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package entsql import "entgo.io/ent/schema" // Annotation is a builtin schema annotation for attaching // SQL metadata to schema objects for both codegen and runtime. type Annotation struct { // The Table option allows overriding the default table // name that is generated by ent. For example: // // entsql.Annotation{ // Table: "Users", // } // Table string `json:"table,omitempty"` // Charset defines the character-set of the table. For example: // // entsql.Annotation{ // Charset: "utf8mb4", // } // Charset string `json:"charset,omitempty"` // Collation defines the collation of the table (a set of rules for comparing // characters in a character set). For example: // // entsql.Annotation{ // Collation: "utf8mb4_bin", // } // Collation string `json:"collation,omitempty"` // Default specifies the default value of a column. Note that using this option // will override the default behavior of the code-generation. For example: // // entsql.Annotation{ // Default: "CURRENT_TIMESTAMP", // } // // entsql.Annotation{ // Default: "uuid_generate_v4()", // } // Default string `json:"default,omitempty"` // Options defines the additional table options. 
For example: // // entsql.Annotation{ // Options: "ENGINE = INNODB", // } // Options string `json:"options,omitempty"` // Size defines the column size in the generated schema. For example: // // entsql.Annotation{ // Size: 128, // } // Size int64 `json:"size,omitempty"` // Incremental defines the auto-incremental behavior of a column. For example: // // incrementalEnabled := true // entsql.Annotation{ // Incremental: &incrementalEnabled, // } // // By default, this value is nil defaulting to whatever best fits each scenario. // Incremental *bool `json:"incremental,omitempty"` // OnDelete specifies a custom referential action for DELETE operations on parent // table that has matching rows in the child table. // // For example, in order to delete rows from the parent table and automatically delete // their matching rows in the child table, pass the following annotation: // // entsql.Annotation{ // OnDelete: entsql.Cascade, // } // OnDelete ReferenceOption `json:"on_delete,omitempty"` // Check allows injecting custom "DDL" for setting an unnamed "CHECK" clause in "CREATE TABLE". // // entsql.Annotation{ // Check: "age < 10", // } // Check string `json:"check,omitempty"` // Checks allows injecting custom "DDL" for setting named "CHECK" clauses in "CREATE TABLE". // // entsql.Annotation{ // Checks: map[string]string{ // "valid_discount": "price > discount_price", // }, // } // Checks map[string]string `json:"checks,omitempty"` } // Name describes the annotation name. func (Annotation) Name() string { return "EntSQL" } // Merge implements the schema.Merger interface. func (a Annotation) Merge(other schema.Annotation) schema.Annotation { var ant Annotation switch other := other.(type) { case Annotation: ant = other case *Annotation: if other != nil { ant = *other } default: return a } if t := ant.Table; t != "" { a.Table = t } if c := ant.Charset; c != "" { a.Charset = c } if c := ant.Collation; c != "" { a.Collation = c } if o := ant.Options; o != "" { a.Options = o } if s := ant.Size; s != 0 { a.Size = s } if i := ant.Incremental; i != nil { a.Incremental = i } if od := ant.OnDelete; od != "" { a.OnDelete = od } if c := ant.Check; c != "" { a.Check = c } if checks := ant.Checks; len(checks) > 0 { if a.Checks == nil { a.Checks = make(map[string]string) } for name, check := range checks { a.Checks[name] = check } } return a } var _ interface { schema.Annotation schema.Merger } = (*Annotation)(nil) // ReferenceOption for constraint actions. type ReferenceOption string // Reference options (actions) specified by ON UPDATE and ON DELETE // subclauses of the FOREIGN KEY clause. const ( NoAction ReferenceOption = "NO ACTION" Restrict ReferenceOption = "RESTRICT" Cascade ReferenceOption = "CASCADE" SetNull ReferenceOption = "SET NULL" SetDefault ReferenceOption = "SET DEFAULT" ) // IndexAnnotation is a builtin schema annotation for attaching // SQL metadata to schema indexes for both codegen and runtime. type IndexAnnotation struct { // Prefix defines a column prefix for a single string column index. // In MySQL, the following annotation maps to: // // index.Fields("column"). // Annotation(entsql.Prefix(100)) // // CREATE INDEX `table_column` ON `table`(`column`(100)) // Prefix uint // PrefixColumns defines column prefixes for a multi-column index. // In MySQL, the following annotation maps to: // // index.Fields("c1", "c2", "c3"). 
// Annotation( // entsql.PrefixColumn("c1", 100), // entsql.PrefixColumn("c2", 200), // ) // // CREATE INDEX `table_c1_c2_c3` ON `table`(`c1`(100), `c2`(200), `c3`) // PrefixColumns map[string]uint // Desc defines the DESC clause for a single column index. // In MySQL, the following annotation maps to: // // index.Fields("column"). // Annotation(entsql.Desc()) // // CREATE INDEX `table_column` ON `table`(`column` DESC) // Desc bool // DescColumns defines the DESC clause for columns in multi-column index. // In MySQL, the following annotation maps to: // // index.Fields("c1", "c2", "c3"). // Annotation( // entsql.DescColumns("c1", "c2"), // ) // // CREATE INDEX `table_c1_c2_c3` ON `table`(`c1` DESC, `c2` DESC, `c3`) // DescColumns map[string]bool // IncludeColumns defines the INCLUDE clause for the index. // Works only in Postgres and its definition is as follows: // // index.Fields("c1"). // Annotation( // entsql.IncludeColumns("c2"), // ) // // CREATE INDEX "table_column" ON "table"("c1") INCLUDE ("c2") // IncludeColumns []string // Type defines the type of the index. // In MySQL, the following annotation maps to: // // index.Fields("c1"). // Annotation( // entsql.IndexType("FULLTEXT"), // ) // // CREATE FULLTEXT INDEX `table_c1` ON `table`(`c1`) // Type string // Types is like the Type option but allows mapping an index-type per dialect. // // index.Fields("c1"). // Annotation( // entsql.IndexTypes(map[string]string{ // dialect.MySQL: "FULLTEXT", // dialect.Postgres: "GIN", // }), // ) // Types map[string]string // IndexWhere allows configuring partial indexes in SQLite and PostgreSQL. // Read more: https://postgresql.org/docs/current/indexes-partial.html. // // Note that the `WHERE` clause should be defined exactly like it is // stored in the database (i.e. normal form). Read more about this on // the Atlas website: https://atlasgo.io/concepts/dev-database#diffing. // // index.Fields("a"). // Annotations( // entsql.IndexWhere("b AND c > 0"), // ) // CREATE INDEX "table_a" ON "table"("a") WHERE (b AND c > 0) Where string } // Prefix returns a new index annotation with a single string column index. // In MySQL, the following annotation maps to: // // index.Fields("column"). // Annotation(entsql.Prefix(100)) // // CREATE INDEX `table_column` ON `table`(`column`(100)) func Prefix(prefix uint) *IndexAnnotation { return &IndexAnnotation{ Prefix: prefix, } } // PrefixColumn returns a new index annotation with column prefix for // multi-column indexes. In MySQL, the following annotation maps to: // // index.Fields("c1", "c2", "c3"). // Annotation( // entsql.PrefixColumn("c1", 100), // entsql.PrefixColumn("c2", 200), // ) // // CREATE INDEX `table_c1_c2_c3` ON `table`(`c1`(100), `c2`(200), `c3`) func PrefixColumn(name string, prefix uint) *IndexAnnotation { return &IndexAnnotation{ PrefixColumns: map[string]uint{ name: prefix, }, } } // Desc returns a new index annotation with the DESC clause for a // single column index. In MySQL, the following annotation maps to: // // index.Fields("column"). // Annotation(entsql.Desc()) // // CREATE INDEX `table_column` ON `table`(`column` DESC) func Desc() *IndexAnnotation { return &IndexAnnotation{ Desc: true, } } // DescColumns returns a new index annotation with the DESC clause attached to // the columns in the index. In MySQL, the following annotation maps to: // // index.Fields("c1", "c2", "c3"). 
// Annotation( // entsql.DescColumns("c1", "c2"), // ) // // CREATE INDEX `table_c1_c2_c3` ON `table`(`c1` DESC, `c2` DESC, `c3`) func DescColumns(names ...string) *IndexAnnotation { ant := &IndexAnnotation{ DescColumns: make(map[string]bool, len(names)), } for i := range names { ant.DescColumns[names[i]] = true } return ant } // IncludeColumns defines the INCLUDE clause for the index. // Works only in Postgres and its definition is as follows: // // index.Fields("c1"). // Annotation( // entsql.IncludeColumns("c2"), // ) // // CREATE INDEX "table_column" ON "table"("c1") INCLUDE ("c2") func IncludeColumns(names ...string) *IndexAnnotation { return &IndexAnnotation{IncludeColumns: names} } // IndexType defines the type of the index. // In MySQL, the following annotation maps to: // // index.Fields("c1"). // Annotation( // entsql.IndexType("FULLTEXT"), // ) // // CREATE FULLTEXT INDEX `table_c1` ON `table`(`c1`) func IndexType(t string) *IndexAnnotation { return &IndexAnnotation{Type: t} } // IndexTypes is like the Type option but allows mapping an index-type per dialect. // // index.Fields("c1"). // Annotations( // entsql.IndexTypes(map[string]string{ // dialect.MySQL: "FULLTEXT", // dialect.Postgres: "GIN", // }), // ) func IndexTypes(types map[string]string) *IndexAnnotation { return &IndexAnnotation{Types: types} } // IndexWhere allows configuring partial indexes in SQLite and PostgreSQL. // Read more: https://postgresql.org/docs/current/indexes-partial.html. // // Note that the `WHERE` clause should be defined exactly like it is // stored in the database (i.e. normal form). Read more about this on the // Atlas website: https://atlasgo.io/concepts/dev-database#diffing. // // index.Fields("a"). // Annotations( // entsql.IndexWhere("b AND c > 0"), // ) // CREATE INDEX "table_a" ON "table"("a") WHERE (b AND c > 0) func IndexWhere(pred string) *IndexAnnotation { return &IndexAnnotation{Where: pred} } // Name describes the annotation name. func (IndexAnnotation) Name() string { return "EntSQLIndexes" } // Merge implements the schema.Merger interface. func (a IndexAnnotation) Merge(other schema.Annotation) schema.Annotation { var ant IndexAnnotation switch other := other.(type) { case IndexAnnotation: ant = other case *IndexAnnotation: if other != nil { ant = *other } default: return a } if ant.Prefix != 0 { a.Prefix = ant.Prefix } if ant.PrefixColumns != nil { if a.PrefixColumns == nil { a.PrefixColumns = make(map[string]uint) } for column, prefix := range ant.PrefixColumns { a.PrefixColumns[column] = prefix } } if ant.Desc { a.Desc = ant.Desc } if ant.DescColumns != nil { if a.DescColumns == nil { a.DescColumns = make(map[string]bool) } for column, desc := range ant.DescColumns { a.DescColumns[column] = desc } } if ant.IncludeColumns != nil { a.IncludeColumns = append(a.IncludeColumns, ant.IncludeColumns...) } if ant.Type != "" { a.Type = ant.Type } if ant.Types != nil { a.Types = ant.Types } if ant.Where != "" { a.Where = ant.Where } return a } var _ interface { schema.Annotation schema.Merger } = (*IndexAnnotation)(nil) ent-0.11.3/dialect/gremlin/000077500000000000000000000000001431500740500154235ustar00rootroot00000000000000ent-0.11.3/dialect/gremlin/client.go000066400000000000000000000045231431500740500172340ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. 
package gremlin import ( "context" "fmt" "net/http" ) // RoundTripper is an interface representing the ability to execute a // single gremlin transaction, obtaining the Response for a given Request. type RoundTripper interface { RoundTrip(context.Context, *Request) (*Response, error) } // The RoundTripperFunc type is an adapter to allow the use of ordinary functions as Gremlin RoundTripper. type RoundTripperFunc func(context.Context, *Request) (*Response, error) // RoundTrip calls f(ctx, r). func (f RoundTripperFunc) RoundTrip(ctx context.Context, r *Request) (*Response, error) { return f(ctx, r) } // Interceptor provides a hook to intercept the execution of a Gremlin Request. type Interceptor func(RoundTripper) RoundTripper // A Client is a gremlin client. type Client struct { // Transport specifies the mechanism by which individual // Gremlin requests are made. Transport RoundTripper } // MaxResponseSize defines the maximum response size allowed. const MaxResponseSize = 2 << 20 // NewClient creates a gremlin client from config and options. func NewClient(cfg Config, opt ...Option) (*Client, error) { return cfg.Build(opt...) } // NewHTTPClient creates an http based gremlin client. func NewHTTPClient(url string, client *http.Client) (*Client, error) { transport, err := NewHTTPTransport(url, client) if err != nil { return nil, err } return &Client{transport}, nil } // Do sends a gremlin request and returns a gremlin response. func (c Client) Do(ctx context.Context, req *Request) (*Response, error) { rsp, err := c.Transport.RoundTrip(ctx, req) if err == nil { err = rsp.Err() } // If we got an error, and the context has been canceled, // the context's error is probably more useful. if err != nil && ctx.Err() != nil { err = ctx.Err() } return rsp, err } // Query issues an eval request via the Do function. func (c Client) Query(ctx context.Context, query string) (*Response, error) { return c.Do(ctx, NewEvalRequest(query)) } // Queryf formats a query string and invokes Query. func (c Client) Queryf(ctx context.Context, format string, args ...any) (*Response, error) { return c.Query(ctx, fmt.Sprintf(format, args...)) } ent-0.11.3/dialect/gremlin/client_test.go000066400000000000000000000044721431500740500202760ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package gremlin import ( "context" "io" "net/url" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" ) func TestNewClient(t *testing.T) { var cfg Config cfg.Endpoint.URL, _ = url.Parse("http://gremlin-server/gremlin") c, err := NewClient(cfg) assert.NotNil(t, c) assert.NoError(t, err) } type mockRoundTripper struct{ mock.Mock } func (m *mockRoundTripper) RoundTrip(ctx context.Context, req *Request) (*Response, error) { args := m.Called(ctx, req) return args.Get(0).(*Response), args.Error(1) } func TestClientRequest(t *testing.T) { ctx := context.Background() req, rsp := &Request{}, &Response{} var m mockRoundTripper m.On("RoundTrip", ctx, req). Run(func(mock.Arguments) { rsp.Status.Code = StatusSuccess }). Return(rsp, nil). Once() defer m.AssertExpectations(t) response, err := Client{&m}.Do(context.Background(), req) assert.NoError(t, err) assert.Equal(t, rsp, response) } func TestClientResponseError(t *testing.T) { rsp := &Response{} var m mockRoundTripper m.On("RoundTrip", mock.Anything, mock.Anything). 
Run(func(mock.Arguments) { rsp.Status.Code = StatusServerError }). Return(rsp, nil). Once() defer m.AssertExpectations(t) _, err := Client{&m}.Do(context.Background(), nil) assert.Error(t, err) } func TestClientCanceledContext(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) var m mockRoundTripper m.On("RoundTrip", ctx, mock.Anything). Run(func(mock.Arguments) { cancel() }). Return(&Response{}, io.ErrUnexpectedEOF). Once() defer m.AssertExpectations(t) _, err := Client{&m}.Query(ctx, "g.E()") assert.EqualError(t, err, context.Canceled.Error()) } func TestClientQuery(t *testing.T) { rsp := &Response{} rsp.Status.Code = StatusNoContent var m mockRoundTripper m.On("RoundTrip", mock.Anything, mock.Anything). Run(func(args mock.Arguments) { req := args.Get(1).(*Request) assert.Equal(t, "g.V(1)", req.Arguments[ArgsGremlin]) }). Return(rsp, nil). Once() defer m.AssertExpectations(t) rsp, err := Client{&m}.Queryf(context.Background(), "g.V(%d)", 1) assert.NotNil(t, rsp) assert.NoError(t, err) } ent-0.11.3/dialect/gremlin/config.go000066400000000000000000000040361431500740500172220ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package gremlin import ( "fmt" "net/http" "net/url" ) type ( // Config offers a declarative way to construct a client. Config struct { Endpoint Endpoint `env:"ENDPOINT" long:"endpoint" default:"" description:"gremlin endpoint to connect to"` DisableExpansion bool `env:"DISABLE_EXPANSION" long:"disable-expansion" description:"disable bindings expansion"` } // An Option configures a client. Option func(*options) options struct { interceptors []Interceptor httpClient *http.Client } // Endpoint wraps a URL to add flag unmarshaling. Endpoint struct { *url.URL } ) // WithInterceptor adds interceptors to the client's transport. func WithInterceptor(interceptors ...Interceptor) Option { return func(opts *options) { opts.interceptors = append(opts.interceptors, interceptors...) } } // WithHTTPClient assigns the underlying http client to be used by the http transport. func WithHTTPClient(client *http.Client) Option { return func(opts *options) { opts.httpClient = client } } // Build constructs a client from Config. func (cfg Config) Build(opt ...Option) (c *Client, err error) { opts := cfg.buildOptions(opt) switch cfg.Endpoint.Scheme { case "http", "https": c, err = NewHTTPClient(cfg.Endpoint.String(), opts.httpClient) default: err = fmt.Errorf("unsupported endpoint scheme: %s", cfg.Endpoint.Scheme) } if err != nil { return nil, err } for i := len(opts.interceptors) - 1; i >= 0; i-- { c.Transport = opts.interceptors[i](c.Transport) } if !cfg.DisableExpansion { c.Transport = ExpandBindings(c.Transport) } return c, nil } func (Config) buildOptions(opts []Option) options { var o options for _, opt := range opts { opt(&o) } return o } // UnmarshalFlag implements flag.Unmarshaler interface. func (ep *Endpoint) UnmarshalFlag(value string) (err error) { ep.URL, err = url.Parse(value) return } ent-0.11.3/dialect/gremlin/config_test.go000066400000000000000000000067221431500740500202650ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree.
package gremlin import ( "context" "errors" "net/http" "net/url" "testing" "github.com/jessevdk/go-flags" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) func TestConfigParsing(t *testing.T) { var cfg Config _, err := flags.ParseArgs(&cfg, []string{ "--disable-expansion", "--endpoint", "http://localhost:8182/gremlin", }) assert.NoError(t, err) assert.True(t, cfg.DisableExpansion) assert.Equal(t, "http", cfg.Endpoint.Scheme) assert.Equal(t, "http://localhost:8182/gremlin", cfg.Endpoint.String()) cfg = Config{} _, err = flags.ParseArgs(&cfg, nil) assert.NoError(t, err) assert.NotNil(t, cfg.Endpoint.URL) } func TestConfigBuild(t *testing.T) { tests := []struct { name string cfg Config opts []Option wantErr bool }{ { name: "HTTP", cfg: Config{ Endpoint: Endpoint{ URL: func() *url.URL { u, _ := url.Parse("http://gremlin-server/gremlin") return u }(), }, }, }, { name: "NoScheme", cfg: Config{ Endpoint: Endpoint{ URL: &url.URL{}, }, }, wantErr: true, }, { name: "BadScheme", cfg: Config{ Endpoint: Endpoint{ URL: &url.URL{ Scheme: "bad", }, }, }, wantErr: true, }, { name: "WithOptions", cfg: Config{ Endpoint: Endpoint{ URL: func() *url.URL { u, _ := url.Parse("http://gremlin-server/gremlin") return u }(), }, DisableExpansion: true, }, opts: []Option{WithHTTPClient(&http.Client{})}, }, { name: "NoExpansion", cfg: Config{ Endpoint: Endpoint{ URL: &url.URL{ Scheme: "bad", }, }, DisableExpansion: true, }, wantErr: true, }, } for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { client, err := tc.cfg.Build(tc.opts...) if !tc.wantErr { assert.NotNil(t, client) assert.NoError(t, err) } else { assert.Error(t, err) } }) } } type testRoundTripper struct{ mock.Mock } func (rt *testRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { args := rt.Called(req) rsp, _ := args.Get(0).(*http.Response) return rsp, args.Error(1) } func TestBuildWithHTTPClient(t *testing.T) { var transport testRoundTripper transport.On("RoundTrip", mock.Anything). Return(nil, errors.New("noop")). Once() defer transport.AssertExpectations(t) u, err := url.Parse("http://gremlin-server:8182/gremlin") require.NoError(t, err) client, err := Config{Endpoint: Endpoint{u}}. Build(WithHTTPClient(&http.Client{Transport: &transport})) require.NoError(t, err) _, _ = client.Do(context.Background(), NewEvalRequest("g.V()")) } func TestExpandOrdering(t *testing.T) { var cfg Config cfg.Endpoint.URL, _ = url.Parse("http://gremlin-server/gremlin") interceptor := func(RoundTripper) RoundTripper { return RoundTripperFunc(func(ctx context.Context, req *Request) (*Response, error) { assert.Equal(t, `g.V().hasLabel("user")`, req.Arguments[ArgsGremlin]) assert.Nil(t, req.Arguments[ArgsBindings]) return nil, errors.New("noop") }) } c, err := cfg.Build(WithInterceptor(interceptor)) require.NoError(t, err) req := NewEvalRequest("g.V().hasLabel($1)", WithBindings(map[string]any{"$1": "user"})) _, _ = c.Do(context.Background(), req) } ent-0.11.3/dialect/gremlin/driver.go000066400000000000000000000032021431500740500172420ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package gremlin import ( "context" "fmt" "entgo.io/ent/dialect" "entgo.io/ent/dialect/gremlin/graph/dsl" ) // Driver is a dialect.Driver implementation for TinkerPop gremlin. 
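//
// A minimal construction sketch (the endpoint mirrors the one used in the tests above):
//
//	client, err := NewHTTPClient("http://localhost:8182/gremlin", nil)
//	if err != nil {
//		// handle the error
//	}
//	drv := NewDriver(client)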
type Driver struct { *Client } // NewDriver returns a new dialect.Driver implementation for gremlin. func NewDriver(c *Client) *Driver { c.Transport = ExpandBindings(c.Transport) return &Driver{c} } // Dialect implements the dialect.Dialect method. func (Driver) Dialect() string { return dialect.Gremlin } // Exec implements the dialect.Exec method. func (c *Driver) Exec(ctx context.Context, query string, args, v any) error { vr, ok := v.(*Response) if !ok { return fmt.Errorf("dialect/gremlin: invalid type %T. expect *gremlin.Response", v) } bindings, ok := args.(dsl.Bindings) if !ok { return fmt.Errorf("dialect/gremlin: invalid type %T. expect map[string]any for bindings", args) } res, err := c.Do(ctx, NewEvalRequest(query, WithBindings(bindings))) if err != nil { return err } *vr = *res return nil } // Query implements the dialect.Query method. func (c *Driver) Query(ctx context.Context, query string, args, v any) error { return c.Exec(ctx, query, args, v) } // Close is a nop close call. It should close the connection in case of WS client. func (Driver) Close() error { return nil } // Tx returns a nop transaction. func (c *Driver) Tx(context.Context) (dialect.Tx, error) { return dialect.NopTx(c), nil } var _ dialect.Driver = (*Driver)(nil) ent-0.11.3/dialect/gremlin/encoding/000077500000000000000000000000001431500740500172115ustar00rootroot00000000000000ent-0.11.3/dialect/gremlin/encoding/graphson/000077500000000000000000000000001431500740500210325ustar00rootroot00000000000000ent-0.11.3/dialect/gremlin/encoding/graphson/bench_test.go000066400000000000000000000034701431500740500235030ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graphson import ( "testing" jsoniter "github.com/json-iterator/go" ) type book struct { ID string `json:"id" graphson:"g:UUID"` Title string `json:"title"` Author string `json:"author"` Pages int `json:"num_pages"` Chapters []string `json:"chapters"` } func generateObject() *book { return &book{ ID: "21d5dcbf-1fd4-493e-9b74-d6c429f9e4a5", Title: "The Art of Computer Programming, Vol. 2", Author: "Donald E. Knuth", Pages: 784, Chapters: []string{"Random numbers", "Arithmetic"}, } } func BenchmarkMarshalObject(b *testing.B) { obj := generateObject() b.ReportAllocs() b.ResetTimer() for n := 0; n < b.N; n++ { _, err := Marshal(obj) if err != nil { b.Fatal(err) } } } func BenchmarkUnmarshalObject(b *testing.B) { b.ReportAllocs() out, err := Marshal(generateObject()) if err != nil { b.Fatal(err) } obj := &book{} b.ResetTimer() for n := 0; n < b.N; n++ { err = Unmarshal(out, obj) if err != nil { b.Fatal(err) } } } func BenchmarkMarshalInterface(b *testing.B) { b.ReportAllocs() data, err := jsoniter.Marshal(generateObject()) if err != nil { b.Fatal(err) } var obj any if err = jsoniter.Unmarshal(data, &obj); err != nil { b.Fatal(err) } b.ResetTimer() for n := 0; n < b.N; n++ { _, err = Marshal(obj) if err != nil { b.Fatal(err) } } } func BenchmarkUnmarshalInterface(b *testing.B) { b.ReportAllocs() data, err := Marshal(generateObject()) if err != nil { b.Fatal(err) } var obj any b.ResetTimer() for n := 0; n < b.N; n++ { err = Unmarshal(data, &obj) if err != nil { b.Fatal(err) } } } ent-0.11.3/dialect/gremlin/encoding/graphson/common_test.go000066400000000000000000000024601431500740500237120ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. 
// This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graphson import ( "unsafe" jsoniter "github.com/json-iterator/go" "github.com/stretchr/testify/mock" ) type mocker struct { mock.Mock } // Encode belongs to jsoniter.ValEncoder interface. func (m *mocker) Encode(ptr unsafe.Pointer, stream *jsoniter.Stream) { m.Called(ptr, stream) } // IsEmpty belongs to jsoniter.ValEncoder interface. func (m *mocker) IsEmpty(ptr unsafe.Pointer) bool { args := m.Called(ptr) return args.Bool(0) } // Decode implements jsoniter.ValDecoder interface. func (m *mocker) Decode(ptr unsafe.Pointer, iter *jsoniter.Iterator) { m.Called(ptr, iter) } // CheckType implements typeChecker interface. func (m *mocker) CheckType(typ Type) error { args := m.Called(typ) return args.Error(0) } // MarshalGraphson implements Marshaler interface. func (m *mocker) MarshalGraphson() ([]byte, error) { args := m.Called() data, err := args.Get(0), args.Error(1) if data == nil { return nil, err } return data.([]byte), err } // UnmarshalGraphson implements Unmarshaler interface. func (m *mocker) UnmarshalGraphson(data []byte) error { args := m.Called(data) return args.Error(0) } ent-0.11.3/dialect/gremlin/encoding/graphson/decode.go000066400000000000000000000056031431500740500226100ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graphson import ( "io" "reflect" jsoniter "github.com/json-iterator/go" "github.com/modern-go/reflect2" ) type decodeExtension struct { jsoniter.DummyExtension } // Unmarshal parses the graphson encoded data and stores the result // in the value pointed to by v. func Unmarshal(data []byte, v any) error { return config.Unmarshal(data, v) } // UnmarshalFromString parses the graphson encoded str and stores the result // in the value pointed to by v. func UnmarshalFromString(str string, v any) error { return config.UnmarshalFromString(str, v) } // Decoder defines a graphson decoder. type Decoder interface { Decode(any) error } // NewDecoder creates a graphson decoder. func NewDecoder(r io.Reader) Decoder { return config.NewDecoder(r) } // Unmarshaler is the interface implemented by types // that can unmarshal a graphson description of themselves. type Unmarshaler interface { UnmarshalGraphson([]byte) error } // UpdateStructDescriptor decorates struct field decoders for graphson tags. func (ext decodeExtension) UpdateStructDescriptor(desc *jsoniter.StructDescriptor) { for _, binding := range desc.Fields { if tag, ok := binding.Field.Tag().Lookup("graphson"); ok && tag != "-" { if dec := ext.DecoratorOfStructField(binding.Decoder, tag); dec != nil { binding.Decoder = dec } } } } // CreateDecoder returns a value decoder for type. func (ext decodeExtension) CreateDecoder(typ reflect2.Type) jsoniter.ValDecoder { if dec := ext.DecoderOfRegistered(typ); dec != nil { return dec } if dec := ext.DecoderOfUnmarshaler(typ); dec != nil { return dec } if dec := ext.DecoderOfNative(typ); dec != nil { return dec } switch typ.Kind() { case reflect.Array: return ext.DecoderOfArray(typ) case reflect.Slice: return ext.DecoderOfSlice(typ) case reflect.Map: return ext.DecoderOfMap(typ) default: return nil } } // DecorateDecoder decorates a passed-in value decoder for type.
func (ext decodeExtension) DecorateDecoder(typ reflect2.Type, dec jsoniter.ValDecoder) jsoniter.ValDecoder { if dec := ext.DecoratorOfRegistered(dec); dec != nil { return dec } if dec := ext.DecoratorOfUnmarshaler(typ, dec); dec != nil { return dec } if dec := ext.DecoratorOfTyper(typ, dec); dec != nil { return dec } if dec := ext.DecoratorOfNative(typ, dec); dec != nil { return dec } switch typ.Kind() { case reflect.Ptr, reflect.Struct: return dec case reflect.Interface: return ext.DecoratorOfInterface(typ, dec) case reflect.Slice: return ext.DecoratorOfSlice(typ, dec) case reflect.Array: return ext.DecoratorOfArray(dec) case reflect.Map: return ext.DecoratorOfMap(dec) default: return ext.DecoderOfError("graphson: unsupported type: " + typ.String()) } } ent-0.11.3/dialect/gremlin/encoding/graphson/decode_test.go000066400000000000000000000003241431500740500236420ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graphson ent-0.11.3/dialect/gremlin/encoding/graphson/encode.go000066400000000000000000000050131431500740500226150ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graphson import ( "io" "reflect" jsoniter "github.com/json-iterator/go" "github.com/modern-go/reflect2" ) type encodeExtension struct { jsoniter.DummyExtension } // Marshal returns the graphson encoding of v. func Marshal(v any) ([]byte, error) { return config.Marshal(v) } // MarshalToString returns the graphson encoding of v as a string. func MarshalToString(v any) (string, error) { return config.MarshalToString(v) } // Encoder defines a graphson encoder. type Encoder interface { Encode(any) error } // NewEncoder creates a graphson encoder. func NewEncoder(w io.Writer) Encoder { return config.NewEncoder(w) } // Marshaler is the interface implemented by types that // can marshal themselves as graphson. type Marshaler interface { MarshalGraphson() ([]byte, error) } // UpdateStructDescriptor decorates struct field encoders for graphson tags. func (ext encodeExtension) UpdateStructDescriptor(desc *jsoniter.StructDescriptor) { for _, binding := range desc.Fields { if tag, ok := binding.Field.Tag().Lookup("graphson"); ok && tag != "-" { if enc := ext.DecoratorOfStructField(binding.Encoder, tag); enc != nil { binding.Encoder = enc } } } } // CreateEncoder returns a value encoder for type. func (ext encodeExtension) CreateEncoder(typ reflect2.Type) jsoniter.ValEncoder { if enc := ext.EncoderOfRegistered(typ); enc != nil { return enc } if enc := ext.EncoderOfNative(typ); enc != nil { return enc } switch typ.Kind() { case reflect.Map: return ext.EncoderOfMap(typ) default: return nil } } // DecorateEncoder decorates a passed-in value encoder for type.
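//
// For illustration, the net effect of the decorated encoders in GraphSON 3.0
// (using the package-level Marshal above; the exact wire shape is assumed from
// the GraphSON spec, not from this file):
//
//	data, _ := Marshal(int32(1)) // {"@type":"g:Int32","@value":1}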
func (ext encodeExtension) DecorateEncoder(typ reflect2.Type, enc jsoniter.ValEncoder) jsoniter.ValEncoder { if enc := ext.DecoratorOfRegistered(enc); enc != nil { return enc } if enc := ext.DecoratorOfMarshaler(typ, enc); enc != nil { return enc } if enc := ext.DecoratorOfTyper(typ, enc); enc != nil { return enc } if enc := ext.DecoratorOfNative(typ, enc); enc != nil { return enc } switch typ.Kind() { case reflect.Ptr, reflect.Interface, reflect.Struct: return enc case reflect.Array: return ext.DecoratorOfArray(enc) case reflect.Slice: return ext.DecoratorOfSlice(typ, enc) case reflect.Map: return ext.DecoratorOfMap(enc) default: return ext.EncoderOfError("graphson: unsupported type: " + typ.String()) } } ent-0.11.3/dialect/gremlin/encoding/graphson/encode_test.go000066400000000000000000000005701431500740500236570ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graphson import ( "testing" "github.com/stretchr/testify/assert" ) func TestEncodeUnsupportedType(t *testing.T) { _, err := Marshal(func() {}) assert.Error(t, err) } ent-0.11.3/dialect/gremlin/encoding/graphson/error.go000066400000000000000000000022271431500740500225150ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graphson import ( "fmt" "unsafe" jsoniter "github.com/json-iterator/go" ) // EncoderOfError returns a value encoder which always fails to encode. func (encodeExtension) EncoderOfError(format string, args ...any) jsoniter.ValEncoder { return decoratorOfError(format, args...) } // DecoderOfError returns a value decoder which always fails to decode. func (decodeExtension) DecoderOfError(format string, args ...any) jsoniter.ValDecoder { return decoratorOfError(format, args...) } func decoratorOfError(format string, args ...any) errorCodec { err := fmt.Errorf(format, args...) return errorCodec{err} } type errorCodec struct{ error } func (ec errorCodec) Encode(_ unsafe.Pointer, stream *jsoniter.Stream) { if stream.Error == nil { stream.Error = ec.error } } func (errorCodec) IsEmpty(unsafe.Pointer) bool { return false } func (ec errorCodec) Decode(_ unsafe.Pointer, iter *jsoniter.Iterator) { if iter.Error == nil { iter.Error = ec.error } } ent-0.11.3/dialect/gremlin/encoding/graphson/error_test.go000066400000000000000000000014031431500740500235470ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. 
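// A minimal end-to-end sketch of the exported API above (Marshal, Unmarshal
// and friends). It assumes the package import path
// entgo.io/ent/dialect/gremlin/encoding/graphson (ent's module path); the
// program itself is illustrative, not part of the package.

package main

import (
	"fmt"

	"entgo.io/ent/dialect/gremlin/encoding/graphson"
)

func main() {
	// Typed values are wrapped in a {"@type", "@value"} envelope.
	out, err := graphson.MarshalToString(int32(42))
	if err != nil {
		panic(err)
	}
	fmt.Println(out) // {"@type":"g:Int32","@value":42}

	// Decoding verifies the incoming @type against the target Go type.
	var n int32
	if err := graphson.UnmarshalFromString(out, &n); err != nil {
		panic(err)
	}
	fmt.Println(n) // 42
}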
package graphson import ( "bytes" "errors" "testing" "github.com/stretchr/testify/assert" ) func TestErrorCodec(t *testing.T) { codec := errorCodec{errors.New("codec error")} assert.False(t, codec.IsEmpty(nil)) var buf bytes.Buffer stream := config.BorrowStream(&buf) defer config.ReturnStream(stream) codec.Encode(nil, stream) assert.Empty(t, buf.Bytes()) assert.EqualError(t, stream.Error, codec.Error()) iter := config.BorrowIterator([]byte{}) defer config.ReturnIterator(iter) codec.Decode(nil, iter) assert.EqualError(t, iter.Error, codec.Error()) } ent-0.11.3/dialect/gremlin/encoding/graphson/extension.go000066400000000000000000000044371431500740500234050ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graphson import ( "reflect" jsoniter "github.com/json-iterator/go" "github.com/modern-go/reflect2" ) var ( typeEncoders = map[string]jsoniter.ValEncoder{} typeDecoders = map[string]jsoniter.ValDecoder{} ) // RegisterTypeEncoder register type encoder for typ. func RegisterTypeEncoder(typ string, enc jsoniter.ValEncoder) { typeEncoders[typ] = enc } // RegisterTypeDecoder register type decoder for typ. func RegisterTypeDecoder(typ string, dec jsoniter.ValDecoder) { typeDecoders[typ] = dec } type registeredEncoder struct{ jsoniter.ValEncoder } // EncoderOfRegistered returns a value encoder of a registered type. func (encodeExtension) EncoderOfRegistered(typ reflect2.Type) jsoniter.ValEncoder { enc := typeEncoders[typ.String()] if enc != nil { return registeredEncoder{enc} } if typ.Kind() == reflect.Ptr { ptrType := typ.(reflect2.PtrType) enc := typeEncoders[ptrType.Elem().String()] if enc != nil { return registeredEncoder{ ValEncoder: &jsoniter.OptionalEncoder{ ValueEncoder: enc, }, } } } return nil } // DecoratorOfRegistered decorates a value encoder of a registered type. func (encodeExtension) DecoratorOfRegistered(enc jsoniter.ValEncoder) jsoniter.ValEncoder { if _, ok := enc.(registeredEncoder); ok { return enc } return nil } type registeredDecoder struct{ jsoniter.ValDecoder } // DecoderOfRegistered returns a value decoder of a registered type. func (decodeExtension) DecoderOfRegistered(typ reflect2.Type) jsoniter.ValDecoder { dec := typeDecoders[typ.String()] if dec != nil { return registeredDecoder{dec} } if typ.Kind() == reflect.Ptr { ptrType := typ.(reflect2.PtrType) dec := typeDecoders[ptrType.Elem().String()] if dec != nil { return registeredDecoder{ ValDecoder: &jsoniter.OptionalDecoder{ ValueType: ptrType.Elem(), ValueDecoder: dec, }, } } } return nil } // DecoratorOfRegistered decorates a value decoder of a registered type. func (decodeExtension) DecoratorOfRegistered(dec jsoniter.ValDecoder) jsoniter.ValDecoder { if _, ok := dec.(registeredDecoder); ok { return dec } return nil } ent-0.11.3/dialect/gremlin/encoding/graphson/init.go000066400000000000000000000006351431500740500223300ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. 
package graphson import ( jsoniter "github.com/json-iterator/go" ) var config = jsoniter.Config{}.Froze() func init() { config.RegisterExtension(&encodeExtension{}) config.RegisterExtension(&decodeExtension{}) } ent-0.11.3/dialect/gremlin/encoding/graphson/interface.go000066400000000000000000000072011431500740500233210ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graphson import ( "bytes" "errors" "fmt" "io" "reflect" "unsafe" jsoniter "github.com/json-iterator/go" "github.com/modern-go/reflect2" ) // DecoratorOfInterface decorates a value decoder of an interface type. func (decodeExtension) DecoratorOfInterface(typ reflect2.Type, dec jsoniter.ValDecoder) jsoniter.ValDecoder { if _, ok := typ.(*reflect2.UnsafeEFaceType); ok { return efaceDecoder{typ, dec} } return dec } type efaceDecoder struct { typ reflect2.Type jsoniter.ValDecoder } func (dec efaceDecoder) Decode(ptr unsafe.Pointer, iter *jsoniter.Iterator) { switch next := iter.WhatIsNext(); next { case jsoniter.StringValue, jsoniter.BoolValue, jsoniter.NilValue: dec.ValDecoder.Decode(ptr, iter) case jsoniter.ObjectValue: dec.decode(ptr, iter) default: iter.ReportError("decode empty interface", fmt.Sprintf("unexpected value type: %d", next)) } } func (dec efaceDecoder) decode(ptr unsafe.Pointer, iter *jsoniter.Iterator) { data := iter.SkipAndReturnBytes() if iter.Error != nil && iter.Error != io.EOF { return } rtype, err := dec.reflectBytes(data) if err != nil { iter.ReportError("decode empty interface", err.Error()) return } it := config.BorrowIterator(data) defer config.ReturnIterator(it) var val any if rtype != nil { val = rtype.New() it.ReadVal(val) val = rtype.Indirect(val) } else { if jsoniter.Get(data, TypeKey).LastError() == nil { vk := jsoniter.Get(data, ValueKey) if vk.LastError() == nil { val = vk.GetInterface() } } if val == nil { val = it.Read() } } if it.Error != nil && it.Error != io.EOF { iter.ReportError("decode empty interface", it.Error.Error()) return } // nolint: gas dec.typ.UnsafeSet(ptr, unsafe.Pointer(&val)) } func (dec efaceDecoder) reflectBytes(data []byte) (reflect2.Type, error) { typ := Type(jsoniter.Get(data, TypeKey).ToString()) rtype := dec.reflectType(typ) if rtype != nil { return rtype, nil } switch typ { case listType: return dec.reflectSlice(data) case mapType: return dec.reflectMap(data) default: return nil, nil } } func (efaceDecoder) reflectType(typ Type) reflect2.Type { switch typ { case doubleType: return reflect2.TypeOf(float64(0)) case floatType: return reflect2.TypeOf(float32(0)) case byteType: return reflect2.TypeOf(uint8(0)) case int16Type: return reflect2.TypeOf(int16(0)) case int32Type: return reflect2.TypeOf(int32(0)) case int64Type, bigIntegerType: return reflect2.TypeOf(int64(0)) case byteBufferType: return reflect2.TypeOf([]byte{}) default: return nil } } func (efaceDecoder) reflectSlice(data []byte) (reflect2.Type, error) { var elem any if err := Unmarshal(data, &[...]*any{&elem}); err != nil { return nil, fmt.Errorf("cannot read first list element: %w", err) } if elem == nil { return reflect2.TypeOf([]any{}), nil } sliceType := reflect.SliceOf(reflect.TypeOf(elem)) return reflect2.Type2(sliceType), nil } func (efaceDecoder) reflectMap(data []byte) (reflect2.Type, error) { var key, elem any if err := Unmarshal( bytes.Replace(data, []byte(mapType), []byte(listType), 1), &[...]*any{&key, &elem}, ); err != 
nil { return nil, fmt.Errorf("cannot unmarshal first map item: %w", err) } if key == nil { return reflect2.TypeOf(map[any]any{}), nil } else if elem == nil { return nil, errors.New("expect map element, but found only key") } mapType := reflect.MapOf(reflect.TypeOf(key), reflect.TypeOf(elem)) return reflect2.Type2(mapType), nil } ent-0.11.3/dialect/gremlin/encoding/graphson/interface_test.go000066400000000000000000000132021431500740500243560ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graphson import ( "fmt" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestDecodeInterface(t *testing.T) { tests := []struct { name string in string want any wantErr bool }{ { name: "Boolean", in: "false", want: false, }, { name: "String", in: `"str"`, want: "str", }, { name: "Double", in: `{ "@type": "g:Double", "@value": 3.14 }`, want: float64(3.14), }, { name: "Float", in: `{ "@type": "g:Float", "@value": -22.567 }`, want: float32(-22.567), }, { name: "Int32", in: `{ "@type": "g:Int32", "@value": 9000 }`, want: int32(9000), }, { name: "Int64", in: `{ "@type": "g:Int64", "@value": 188786 }`, want: int64(188786), }, { name: "BigInteger", in: `{ "@type": "gx:BigInteger", "@value": 352353463712 }`, want: int64(352353463712), }, { name: "Byte", in: `{ "@type": "gx:Byte", "@value": 100 }`, want: uint8(100), }, { name: "Int16", in: `{ "@type": "gx:Int16", "@value": 2000 }`, want: int16(2000), }, { name: "UnknownType", in: `{ "@type": "g:T", "@value": "label" }`, want: "label", }, { name: "UntypedArray", in: "[]", wantErr: true, }, { name: "NoType", in: `{ "@typ": "g:Int32", "@value": 345 }`, wantErr: true, }, { name: "BadObject", in: `{ "@type": "g:Int32", "@value": 345 `, wantErr: true, }, { name: "BadList", in: `{ "@type": "g:List", "@value": [ { "@type": "g:Int64", "@val": 123457990 } ] }`, wantErr: true, }, { name: "BadMap", in: `{ "@type": "g:Map", "@value": [ { "@type": "g:Int64", "@val": 123457990 }, "First" ] }`, wantErr: true, }, { name: "KeyOnlyMap", in: `{ "@type": "g:Map", "@value": ["Key"] }`, wantErr: true, }, } for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() var got any err := UnmarshalFromString(tc.in, &got) if !tc.wantErr { require.NoError(t, err) assert.Equal(t, tc.want, got) } else { assert.Error(t, err) } }) } } func TestDecodeInterfaceSlice(t *testing.T) { tests := []struct { in string want any }{ { in: `{ "@type": "g:List", "@value": [] }`, want: []any{}, }, { in: `{ "@type": "g:List", "@value": ["x", "y", "z"] }`, want: []string{"x", "y", "z"}, }, { in: `{ "@type": "g:List", "@value": [ { "@type": "g:Int64", "@value": 123457990 }, { "@type": "g:Int64", "@value": 23456111 }, { "@type": "g:Int64", "@value": -687450 } ] }`, want: []int64{123457990, 23456111, -687450}, }, { in: `{ "@type": "gx:ByteBuffer", "@value": "AQIDBAU=" }`, want: []byte{1, 2, 3, 4, 5}, }, } for _, tc := range tests { tc := tc t.Run(fmt.Sprintf("%T", tc.want), func(t *testing.T) { t.Parallel() var got any err := UnmarshalFromString(tc.in, &got) require.NoError(t, err) assert.Equal(t, tc.want, got) }) } } func TestDecodeInterfaceMap(t *testing.T) { tests := []struct { in string want any }{ { in: `{ "@type": "g:Map", "@value": [] }`, want: map[any]any{}, }, { in: `{ "@type": "g:Map", "@value": [ "Sep", { "@type": "g:Int32", "@value": 9 }, "Oct", { "@type": 
"g:Int32", "@value": 10 }, "Nov", { "@type": "g:Int32", "@value": 11 } ] }`, want: map[string]int32{ "Sep": int32(9), "Oct": int32(10), "Nov": int32(11), }, }, { in: `{ "@type": "g:Map", "@value": [ "One", { "@type": "g:List", "@value": [ { "@type": "g:Int32", "@value": 1 } ] }, "Two", { "@type": "g:List", "@value": [ { "@type": "g:Int32", "@value": 2 } ] }, "Three", { "@type": "g:List", "@value": [ { "@type": "g:Int32", "@value": 3 } ] } ] }`, want: map[string][]int32{ "One": {1}, "Two": {2}, "Three": {3}, }, }, } for _, tc := range tests { tc := tc t.Run(fmt.Sprintf("%T", tc.want), func(t *testing.T) { t.Parallel() var got any err := UnmarshalFromString(tc.in, &got) require.NoError(t, err) assert.Equal(t, tc.want, got) }) } } func TestDecodeInterfaceObject(t *testing.T) { book := struct { ID string `json:"id" graphson:"g:UUID"` Title string `json:"title"` Author string `json:"author"` Pages int `json:"num_pages"` Chapters []string `json:"chapters"` }{ ID: "21d5dcbf-1fd4-493e-9b74-d6c429f9e4a5", Title: "The Art of Computer Programming, Vol. 2", Author: "Donald E. Knuth", Pages: 784, Chapters: []string{"Random numbers", "Arithmetic"}, } data, err := Marshal(book) require.NoError(t, err) var v any err = Unmarshal(data, &v) require.NoError(t, err) obj := v.(map[string]any) assert.Equal(t, book.ID, obj["id"]) assert.Equal(t, book.Title, obj["title"]) assert.Equal(t, book.Author, obj["author"]) assert.EqualValues(t, book.Pages, obj["num_pages"]) assert.ElementsMatch(t, book.Chapters, obj["chapters"]) } ent-0.11.3/dialect/gremlin/encoding/graphson/lazy.go000066400000000000000000000036331431500740500223450ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graphson import ( "fmt" "sync" "unsafe" jsoniter "github.com/json-iterator/go" "github.com/modern-go/reflect2" ) // LazyEncoderOf returns a lazy encoder for type. func (encodeExtension) LazyEncoderOf(typ reflect2.Type) jsoniter.ValEncoder { return &lazyEncoder{resolve: func() jsoniter.ValEncoder { return config.EncoderOf(typ) }} } // LazyDecoderOf returns a lazy unique decoder for type. 
func (decodeExtension) LazyDecoderOf(typ reflect2.Type) jsoniter.ValDecoder { return &lazyDecoder{resolve: func() jsoniter.ValDecoder { dec := config.DecoderOf(reflect2.PtrTo(typ)) if td, ok := dec.(typeDecoder); ok { td.typeChecker = &uniqueType{elemChecker: td.typeChecker} dec = td } return dec }} } type lazyEncoder struct { jsoniter.ValEncoder resolve func() jsoniter.ValEncoder once sync.Once } func (enc *lazyEncoder) Encode(ptr unsafe.Pointer, stream *jsoniter.Stream) { enc.once.Do(func() { enc.ValEncoder = enc.resolve() }) enc.ValEncoder.Encode(ptr, stream) } func (enc *lazyEncoder) IsEmpty(ptr unsafe.Pointer) bool { enc.once.Do(func() { enc.ValEncoder = enc.resolve() }) return enc.ValEncoder.IsEmpty(ptr) } type lazyDecoder struct { jsoniter.ValDecoder resolve func() jsoniter.ValDecoder once sync.Once } func (dec *lazyDecoder) Decode(ptr unsafe.Pointer, iter *jsoniter.Iterator) { dec.once.Do(func() { dec.ValDecoder = dec.resolve() }) dec.ValDecoder.Decode(ptr, iter) } type uniqueType struct { typ Type once sync.Once elemChecker typeChecker } func (u *uniqueType) CheckType(other Type) error { u.once.Do(func() { u.typ = other }) if u.typ != other { return fmt.Errorf("expect type %s, but found %s", u.typ, other) } return u.elemChecker.CheckType(u.typ) } ent-0.11.3/dialect/gremlin/encoding/graphson/lazy_test.go000066400000000000000000000021521431500740500233770ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graphson import ( "sync/atomic" "testing" jsoniter "github.com/json-iterator/go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" ) func TestLazyEncode(t *testing.T) { var m mocker m.On("IsEmpty", mock.Anything).Return(false).Once() m.On("Encode", mock.Anything, mock.Anything).Once() defer m.AssertExpectations(t) var cnt uint32 var enc jsoniter.ValEncoder = &lazyEncoder{resolve: func() jsoniter.ValEncoder { assert.Equal(t, uint32(1), atomic.AddUint32(&cnt, 1)) return &m }} enc.IsEmpty(nil) enc.Encode(nil, nil) } func TestLazyDecode(t *testing.T) { var m mocker m.On("Decode", mock.Anything, mock.Anything).Times(3) defer m.AssertExpectations(t) var cnt uint32 var dec jsoniter.ValDecoder = &lazyDecoder{resolve: func() jsoniter.ValDecoder { assert.Equal(t, uint32(1), atomic.AddUint32(&cnt, 1)) return &m }} for i := 0; i < 3; i++ { dec.Decode(nil, nil) } } ent-0.11.3/dialect/gremlin/encoding/graphson/map.go000066400000000000000000000054421431500740500221430ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graphson import ( "unsafe" jsoniter "github.com/json-iterator/go" "github.com/modern-go/reflect2" ) // EncoderOfMap returns a value encoder of a map type. func (ext encodeExtension) EncoderOfMap(typ reflect2.Type) jsoniter.ValEncoder { mapType := typ.(reflect2.MapType) return &mapEncoder{ mapType: mapType, keyEnc: ext.LazyEncoderOf(mapType.Key()), elemEnc: ext.LazyEncoderOf(mapType.Elem()), } } // DecoratorOfMap decorates a value encoder of a map type. 
func (encodeExtension) DecoratorOfMap(enc jsoniter.ValEncoder) jsoniter.ValEncoder { return typeEncoder{enc, mapType} } type mapEncoder struct { mapType reflect2.MapType keyEnc jsoniter.ValEncoder elemEnc jsoniter.ValEncoder } func (enc *mapEncoder) Encode(ptr unsafe.Pointer, stream *jsoniter.Stream) { iter := enc.mapType.UnsafeIterate(ptr) if !iter.HasNext() { stream.WriteEmptyArray() return } stream.WriteArrayStart() for { key, elem := iter.UnsafeNext() enc.keyEnc.Encode(key, stream) stream.WriteMore() enc.elemEnc.Encode(elem, stream) if !iter.HasNext() { break } stream.WriteMore() } stream.WriteArrayEnd() } func (enc *mapEncoder) IsEmpty(ptr unsafe.Pointer) bool { return !enc.mapType.UnsafeIterate(ptr).HasNext() } // DecoderOfMap returns a value decoder of a map type. func (ext decodeExtension) DecoderOfMap(typ reflect2.Type) jsoniter.ValDecoder { mapType := typ.(reflect2.MapType) keyType, elemType := mapType.Key(), mapType.Elem() return &mapDecoder{ mapType: mapType, keyType: keyType, elemType: elemType, keyDec: ext.LazyDecoderOf(keyType), elemDec: ext.LazyDecoderOf(elemType), } } // DecoratorOfMap decorates a value decoder of a map type. func (decodeExtension) DecoratorOfMap(dec jsoniter.ValDecoder) jsoniter.ValDecoder { return typeDecoder{dec, mapType} } type mapDecoder struct { mapType reflect2.MapType keyType reflect2.Type elemType reflect2.Type keyDec jsoniter.ValDecoder elemDec jsoniter.ValDecoder } func (dec *mapDecoder) Decode(ptr unsafe.Pointer, iter *jsoniter.Iterator) { mapType := dec.mapType if mapType.UnsafeIsNil(ptr) { mapType.UnsafeSet(ptr, mapType.UnsafeMakeMap(0)) } var key unsafe.Pointer if !iter.ReadArrayCB(func(iter *jsoniter.Iterator) bool { if key == nil { key = dec.keyType.UnsafeNew() dec.keyDec.Decode(key, iter) return iter.Error == nil } elem := dec.elemType.UnsafeNew() dec.elemDec.Decode(elem, iter) if iter.Error != nil { return false } mapType.UnsafeSetIndex(ptr, key, elem) key = nil return true }) { return } if key != nil { iter.ReportError("decode map", "odd number of map items") } } ent-0.11.3/dialect/gremlin/encoding/graphson/map_test.go000066400000000000000000000126341431500740500232030ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. 
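// Maps are encoded as a g:Map whose @value is a flat, alternating
// [key, value, key, value, ...] array, as implemented by mapEncoder above.
// A minimal sketch (one entry, so iteration order does not matter):

package main

import (
	"fmt"

	"entgo.io/ent/dialect/gremlin/encoding/graphson"
)

func main() {
	out, err := graphson.MarshalToString(map[string]int32{"a": 1})
	if err != nil {
		panic(err)
	}
	fmt.Println(out) // {"@type":"g:Map","@value":["a",{"@type":"g:Int32","@value":1}]}
}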
package graphson import ( "strings" "testing" jsoniter "github.com/json-iterator/go" "github.com/modern-go/reflect2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestEncodeMap(t *testing.T) { tests := []struct { name string in any want string }{ { name: "simple", in: map[int32]string{ 3: "Mar", 1: "Jan", 2: "Feb", }, want: `[ { "@type": "g:Int32", "@value": 1 }, "Jan", { "@type": "g:Int32", "@value": 2 }, "Feb", { "@type": "g:Int32", "@value": 3 }, "Mar" ]`, }, { name: "mixed", in: map[string]any{ "byte": byte('a'), "string": "str", "slice": []int{1, 2, 3}, "map": map[string]int{}, }, want: `[ "byte", { "@type": "gx:Byte", "@value": 97 }, "string", "str", "slice", { "@type": "g:List", "@value": [ { "@type": "g:Int64", "@value": 1 }, { "@type": "g:Int64", "@value": 2 }, { "@type": "g:Int64", "@value": 3 } ] }, "map", { "@type": "g:Map", "@value": [] } ]`, }, { name: "struct-key", in: map[struct { K string `json:"key"` }]int32{ {"result"}: 42, }, want: `[ { "key": "result" }, { "@type": "g:Int32", "@value": 42 } ]`, }, { name: "nil", in: map[string]uint8(nil), want: "null", }, } for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() data, err := Marshal(tc.in) require.NoError(t, err) assert.Equal(t, "g:Map", jsoniter.Get(data, "@type").ToString()) var want []any err = jsoniter.UnmarshalFromString(tc.want, &want) require.NoError(t, err) got, ok := jsoniter.Get(data, "@value").GetInterface().([]any) require.True(t, ok) assert.ElementsMatch(t, want, got) }) } } func TestDecodeMap(t *testing.T) { tests := []struct { name string in string want any }{ { name: "empty", in: `{ "@type": "g:Map", "@value": [] }`, want: map[int]int{}, }, { name: "simple", in: `{ "@type": "g:Map", "@value": [ { "@type": "g:Int32", "@value": 6 }, "Jun", { "@type": "g:Int32", "@value": 7 }, "Jul", { "@type": "g:Int32", "@value": 8 }, "Aug" ] }`, want: map[int]string{ 6: "Jun", 7: "Jul", 8: "Aug", }, }, { name: "duplicate", in: `{ "@type": "g:Map", "@value": [ "Sep", { "@type": "g:Int32", "@value": 9 }, "Oct", { "@type": "g:Int32", "@value": 65 }, "Oct", { "@type": "g:Int32", "@value": 10 }, "Nov", null ] }`, want: map[string]*int{ "Sep": func() *int { v := 9; return &v }(), "Oct": func() *int { v := 10; return &v }(), "Nov": nil, }, }, } for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() typ := reflect2.TypeOf(tc.want).(reflect2.MapType) got := typ.MakeMap(0) err := UnmarshalFromString(tc.in, got) require.NoError(t, err) assert.Equal(t, tc.want, typ.Indirect(got)) }) } } func TestDecodeMapIntoNil(t *testing.T) { var got map[int64]int32 err := UnmarshalFromString(`{ "@type": "g:Map", "@value": [ { "@type": "g:Int64", "@value": 9 }, { "@type": "g:Int32", "@value": -9 }, { "@type": "g:Int64", "@value": 99 }, { "@type": "g:Int32", "@value": -99 }, { "@type": "g:Int64", "@value": 999 }, { "@type": "g:Int32", "@value": -999 } ] }`, &got) require.NoError(t, err) assert.Equal(t, map[int64]int32{9: -9, 99: -99, 999: -999}, got) } func TestDecodeBadMap(t *testing.T) { tests := []struct { name string in string }{ { name: "BadValue", in: `{ "@type": "g:Map", "@value": [ { "@type": "g:Int64", "@value": 9 }, { "@type": "g:Int32", "@value": "55" } ] }`, }, { name: "NoValue", in: `{ "@type": "g:Map", "@value": [ { "@type": "g:Int64", "@value": 9 }, { "@type": "g:Int32", "@value": 9 }, { "@type": "g:Int64", "@value": 42 } ] }`, }, { name: "AlterKeyType", in: `{ "@type": "g:Map", "@value": [ { "@type": "g:Int64", "@value": 9 }, { "@type": 
"g:Int32", "@value": 9 }, { "@type": "g:Int32", "@value": 42 }, { "@type": "g:Int32", "@value": 42 } ] }`, }, { name: "AlterValType", in: `{ "@type": "g:Map", "@value": [ { "@type": "g:Int64", "@value": 9 }, { "@type": "g:Int32", "@value": 9 }, { "@type": "g:Int64", "@value": 42 }, { "@type": "g:Int64", "@value": 42 } ] }`, }, } for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() var v map[int]int err := NewDecoder(strings.NewReader(tc.in)).Decode(&v) assert.Error(t, err) }) } } ent-0.11.3/dialect/gremlin/encoding/graphson/marshaler.go000066400000000000000000000067531431500740500233520ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graphson import ( "fmt" "io" "unsafe" jsoniter "github.com/json-iterator/go" "github.com/modern-go/reflect2" ) // DecoratorOfMarshaler decorates a value encoder of a Marshaler interface. func (ext encodeExtension) DecoratorOfMarshaler(typ reflect2.Type, enc jsoniter.ValEncoder) jsoniter.ValEncoder { if typ == marshalerType { enc := marshalerEncoder{enc, typ} return directMarshalerEncoder{enc} } if typ.Implements(marshalerType) { return marshalerEncoder{enc, typ} } ptrType := reflect2.PtrTo(typ) if ptrType.Implements(marshalerType) { ptrEnc := ext.LazyEncoderOf(ptrType) enc := marshalerEncoder{ptrEnc, ptrType} return referenceEncoder{enc} } return nil } // DecoderOfUnmarshaler returns a value decoder of an Unmarshaler interface. func (decodeExtension) DecoderOfUnmarshaler(typ reflect2.Type) jsoniter.ValDecoder { ptrType := reflect2.PtrTo(typ) if ptrType.Implements(unmarshalerType) { return referenceDecoder{ unmarshalerDecoder{ptrType}, } } return nil } // DecoratorOfUnmarshaler decorates a value encoder of an Unmarshaler interface. 
func (decodeExtension) DecoratorOfUnmarshaler(typ reflect2.Type, dec jsoniter.ValDecoder) jsoniter.ValDecoder { if reflect2.PtrTo(typ).Implements(unmarshalerType) { return dec } return nil } var ( marshalerType = reflect2.TypeOfPtr((*Marshaler)(nil)).Elem() unmarshalerType = reflect2.TypeOfPtr((*Unmarshaler)(nil)).Elem() ) type marshalerEncoder struct { jsoniter.ValEncoder reflect2.Type } func (enc marshalerEncoder) Encode(ptr unsafe.Pointer, stream *jsoniter.Stream) { marshaler := enc.Type.UnsafeIndirect(ptr).(Marshaler) enc.encode(marshaler, stream) } func (enc marshalerEncoder) encode(marshaler Marshaler, stream *jsoniter.Stream) { data, err := marshaler.MarshalGraphson() if err != nil { stream.Error = fmt.Errorf("graphson: error calling MarshalGraphson for type %s: %w", enc.Type, err) return } if !config.Valid(data) { stream.Error = fmt.Errorf("graphson: syntax error when marshaling type %s", enc.Type) return } _, stream.Error = stream.Write(data) } type directMarshalerEncoder struct { marshalerEncoder } func (enc directMarshalerEncoder) Encode(ptr unsafe.Pointer, stream *jsoniter.Stream) { marshaler := *(*Marshaler)(ptr) enc.encode(marshaler, stream) } type referenceEncoder struct { jsoniter.ValEncoder } func (enc referenceEncoder) Encode(ptr unsafe.Pointer, stream *jsoniter.Stream) { // nolint: gas enc.ValEncoder.Encode(unsafe.Pointer(&ptr), stream) } func (enc referenceEncoder) IsEmpty(ptr unsafe.Pointer) bool { // nolint: gas return enc.ValEncoder.IsEmpty(unsafe.Pointer(&ptr)) } type unmarshalerDecoder struct { reflect2.Type } func (dec unmarshalerDecoder) Decode(ptr unsafe.Pointer, iter *jsoniter.Iterator) { bytes := iter.SkipAndReturnBytes() if iter.Error != nil && iter.Error != io.EOF { return } unmarshaler := dec.UnsafeIndirect(ptr).(Unmarshaler) if err := unmarshaler.UnmarshalGraphson(bytes); err != nil { iter.ReportError( "unmarshal graphson", fmt.Sprintf( "graphson: error calling UnmarshalGraphson for type %s: %s", dec.Type, err, ), ) } } type referenceDecoder struct { jsoniter.ValDecoder } func (dec referenceDecoder) Decode(ptr unsafe.Pointer, iter *jsoniter.Iterator) { // nolint: gas dec.ValDecoder.Decode(unsafe.Pointer(&ptr), iter) } ent-0.11.3/dialect/gremlin/encoding/graphson/marshaler_test.go000066400000000000000000000050661431500740500244050ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. 
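// A sketch of a type that satisfies both Marshaler and Unmarshaler, emitting
// its own envelope. UUID is hypothetical, and jsoniter.Get is used here only
// as a convenient way to pull the @value field back out.

package main

import (
	"fmt"

	"entgo.io/ent/dialect/gremlin/encoding/graphson"
	jsoniter "github.com/json-iterator/go"
)

type UUID string

// MarshalGraphson implements graphson.Marshaler.
func (u UUID) MarshalGraphson() ([]byte, error) {
	return []byte(fmt.Sprintf(`{"@type":"g:UUID","@value":%q}`, string(u))), nil
}

// UnmarshalGraphson implements graphson.Unmarshaler.
func (u *UUID) UnmarshalGraphson(data []byte) error {
	*u = UUID(jsoniter.Get(data, "@value").ToString())
	return nil
}

func main() {
	id := UUID("cb682578-9d92-4499-9ebc-5c6aa73c5397")
	data, err := graphson.Marshal(id)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(data))

	var got UUID
	if err := graphson.Unmarshal(data, &got); err != nil {
		panic(err)
	}
	fmt.Println(got == id) // true
}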
package graphson import ( "errors" "fmt" "reflect" "testing" jsoniter "github.com/json-iterator/go" "github.com/modern-go/reflect2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) func TestMarshalerEncode(t *testing.T) { want := []byte(`{"@type": "g:Int32", "@value": 42}`) m := &mocker{} call := m.On("MarshalGraphson").Return(want, nil) defer m.AssertExpectations(t) tests := []any{m, &m, func() *Marshaler { marshaler := Marshaler(m); return &marshaler }(), Marshaler(nil)} call.Times(len(tests) - 1) for _, tc := range tests { tc := tc t.Run(fmt.Sprintf("%T", tc), func(t *testing.T) { got, err := Marshal(tc) assert.NoError(t, err) if !reflect2.IsNil(tc) { assert.Equal(t, want, got) } else { assert.Equal(t, []byte("null"), got) } }) } } func TestMarshalerError(t *testing.T) { errStr := "marshaler error" m := &mocker{} m.On("MarshalGraphson").Return(nil, errors.New(errStr)).Once() defer m.AssertExpectations(t) _, err := Marshal(m) assert.Error(t, err) assert.Contains(t, err.Error(), errStr) } func TestBadMarshaler(t *testing.T) { m := &mocker{} m.On("MarshalGraphson").Return([]byte(`{"@type": "g:Int32", "@value":`), nil).Once() defer m.AssertExpectations(t) _, err := Marshal(m) assert.Error(t, err) } func TestUnmarshalerDecode(t *testing.T) { data := `{"@type": "g:UUID", "@value": "cb682578-9d92-4499-9ebc-5c6aa73c5397"}` var value string m := &mocker{} m.On("UnmarshalGraphson", mock.Anything). Run(func(args mock.Arguments) { data := args.Get(0).([]byte) value = jsoniter.Get(data, "@value").ToString() }). Return(nil). Once() defer m.AssertExpectations(t) err := UnmarshalFromString(data, m) require.NoError(t, err) assert.Equal(t, "cb682578-9d92-4499-9ebc-5c6aa73c5397", value) } func TestUnmarshalerError(t *testing.T) { errStr := "unmarshaler error" m := &mocker{} m.On("UnmarshalGraphson", mock.Anything).Return(errors.New(errStr)).Once() defer m.AssertExpectations(t) err := Unmarshal([]byte(`{}`), m) require.Error(t, err) assert.Contains(t, err.Error(), fmt.Sprintf("graphson: error calling UnmarshalGraphson for type %s: %s", reflect.TypeOf(m), errStr, ), ) } func TestUnmarshalBadInput(t *testing.T) { m := &mocker{} defer m.AssertExpectations(t) err := UnmarshalFromString(`{"@type"}`, m) assert.Error(t, err) } ent-0.11.3/dialect/gremlin/encoding/graphson/native.go000066400000000000000000000071201431500740500226470ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graphson import ( "fmt" "io" "math" "reflect" "unsafe" jsoniter "github.com/json-iterator/go" "github.com/modern-go/reflect2" ) // EncoderOfNative returns a value encoder of a native type. func (encodeExtension) EncoderOfNative(typ reflect2.Type) jsoniter.ValEncoder { switch typ.Kind() { case reflect.Float64: return float64Encoder{typ} default: return nil } } // DecoratorOfNative decorates a value encoder of a native type. 
func (encodeExtension) DecoratorOfNative(typ reflect2.Type, enc jsoniter.ValEncoder) jsoniter.ValEncoder { switch typ.Kind() { case reflect.Bool, reflect.String: return enc case reflect.Int64, reflect.Int, reflect.Uint32: return typeEncoder{enc, int64Type} case reflect.Int32, reflect.Int8, reflect.Uint16: return typeEncoder{enc, int32Type} case reflect.Int16: return typeEncoder{enc, int16Type} case reflect.Uint64, reflect.Uint: return typeEncoder{enc, bigIntegerType} case reflect.Uint8: return typeEncoder{enc, byteType} case reflect.Float32: return typeEncoder{enc, floatType} case reflect.Float64: return typeEncoder{enc, doubleType} default: return nil } } // DecoderOfNative returns a value decoder of a native type. func (decodeExtension) DecoderOfNative(typ reflect2.Type) jsoniter.ValDecoder { switch typ.Kind() { case reflect.Float64: return float64Decoder{typ} default: return nil } } // DecoratorOfNative decorates a value decoder of a native type. func (decodeExtension) DecoratorOfNative(typ reflect2.Type, dec jsoniter.ValDecoder) jsoniter.ValDecoder { switch typ.Kind() { case reflect.Bool: return dec case reflect.String: return typeDecoder{dec, typeCheckerFunc(func(Type) error { return nil })} case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: return typeDecoder{dec, integerTypes} case reflect.Float32: return typeDecoder{dec, floatTypes} case reflect.Float64: return typeDecoder{dec, doubleTypes} default: return nil } } type float64Encoder struct { reflect2.Type } func (enc float64Encoder) IsEmpty(ptr unsafe.Pointer) bool { return enc.UnsafeIndirect(ptr).(float64) == 0 } func (enc float64Encoder) Encode(ptr unsafe.Pointer, stream *jsoniter.Stream) { f := enc.UnsafeIndirect(ptr).(float64) switch { case math.IsNaN(f): stream.WriteString("NaN") case math.IsInf(f, 1): stream.WriteString("Infinity") case math.IsInf(f, -1): stream.WriteString("-Infinity") default: stream.WriteFloat64(f) } } type float64Decoder struct { reflect2.Type } var ( integerTypes = Types{byteType, int16Type, int32Type, int64Type, bigIntegerType} floatTypes = append(integerTypes, floatType, bigDecimal) doubleTypes = append(floatTypes, doubleType) ) func (dec float64Decoder) Decode(ptr unsafe.Pointer, iter *jsoniter.Iterator) { var val float64 switch next := iter.WhatIsNext(); next { case jsoniter.NumberValue: val = iter.ReadFloat64() case jsoniter.StringValue: switch str := iter.ReadString(); str { case "NaN": val = math.NaN() case "Infinity": val = math.Inf(1) case "-Infinity": val = math.Inf(-1) default: iter.ReportError("decode float64", "invalid value "+str) } default: iter.ReportError("decode float64", fmt.Sprintf("unexpected value type: %d", next)) } if iter.Error == nil || iter.Error == io.EOF { // nolint: gas dec.UnsafeSet(ptr, unsafe.Pointer(&val)) } } ent-0.11.3/dialect/gremlin/encoding/graphson/native_test.go000066400000000000000000000130111431500740500237020ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. 
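// Non-finite float64 values have no JSON literal, so float64Encoder above
// falls back to the strings "NaN", "Infinity" and "-Infinity", and
// float64Decoder maps them back. A round-trip sketch:

package main

import (
	"fmt"
	"math"

	"entgo.io/ent/dialect/gremlin/encoding/graphson"
)

func main() {
	out, err := graphson.MarshalToString(math.Inf(-1))
	if err != nil {
		panic(err)
	}
	fmt.Println(out) // {"@type":"g:Double","@value":"-Infinity"}

	var f float64
	if err := graphson.UnmarshalFromString(out, &f); err != nil {
		panic(err)
	}
	fmt.Println(math.IsInf(f, -1)) // true
}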
package graphson import ( "fmt" "math" "testing" jsoniter "github.com/json-iterator/go" "github.com/modern-go/reflect2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestEncodeNative(t *testing.T) { tests := []struct { in any want string wantErr bool }{ { in: true, want: "true", }, { in: "hello", want: `"hello"`, }, { in: int8(120), want: `{ "@type": "g:Int32", "@value": 120 }`, }, { in: int16(-16), want: `{ "@type": "gx:Int16", "@value": -16 }`, }, { in: int32(3232), want: `{ "@type": "g:Int32", "@value": 3232 }`, }, { in: int64(646464), want: `{ "@type": "g:Int64", "@value": 646464 }`, }, { in: int(127001), want: `{ "@type": "g:Int64", "@value": 127001 }`, }, { in: uint8(81), want: `{ "@type": "gx:Byte", "@value": 81 }`, }, { in: uint16(12345), want: `{ "@type": "g:Int32", "@value": 12345 }`, }, { in: uint32(123454321), want: `{ "@type": "g:Int64", "@value": 123454321 }`, }, { in: uint64(1234567890), want: `{ "@type": "gx:BigInteger", "@value": 1234567890 }`, }, { in: uint(9876543210), want: `{ "@type" :"gx:BigInteger", "@value": 9876543210 }`, }, { in: float32(math.Pi), want: `{ "@type": "g:Float", "@value": 3.1415927 }`, }, { in: float64(math.E), want: `{ "@type": "g:Double", "@value": 2.718281828459045 }`, }, { in: math.NaN(), want: `{ "@type": "g:Double", "@value": "NaN" }`, }, { in: math.Inf(1), want: `{ "@type": "g:Double", "@value": "Infinity" }`, }, { in: math.Inf(-1), want: `{ "@type": "g:Double", "@value": "-Infinity" }`, }, { in: func() *int { v := 7142; return &v }(), want: `{ "@type": "g:Int64", "@value": 7142 }`, }, { in: func() any { v := int16(6116); return &v }(), want: `{ "@type": "gx:Int16", "@value": 6116 }`, }, { in: nil, want: "null", }, { in: make(chan int), wantErr: true, }, } for _, tc := range tests { tc := tc t.Run(fmt.Sprintf("%T", tc.in), func(t *testing.T) { t.Parallel() got, err := MarshalToString(tc.in) if !tc.wantErr { assert.NoError(t, err) assert.JSONEq(t, tc.want, got) } else { assert.Error(t, err) assert.Empty(t, got) } }) } } func TestDecodeNative(t *testing.T) { tests := []struct { in string want any }{ { in: `{"@type": "g:Float", "@value": 3.14}`, want: float32(3.14), }, { in: `{"@type": "g:Float", "@value": "Float"}`, }, { in: `{"@type": "g:Double", "@value": 2.71}`, want: float64(2.71), }, { in: `{"@type": "gx:BigDecimal", "@value": 3.142}`, want: float32(3.142), }, { in: `{"@type": "gx:BigDecimal", "@value": 55512.5176}`, want: float64(55512.5176), }, { in: `{"@type": "g:T", "@value": "world"}`, want: "world", }, } for _, tc := range tests { tc := tc t.Run(fmt.Sprintf("%T", tc.want), func(t *testing.T) { t.Parallel() if tc.want != nil { typ := reflect2.TypeOf(tc.want) got := typ.New() err := UnmarshalFromString(tc.in, got) require.NoError(t, err) assert.Equal(t, tc.want, typ.Indirect(got)) } else { var msg jsoniter.RawMessage err := UnmarshalFromString(tc.in, &msg) assert.Error(t, err) } }) } } func TestDecodeTypeMismatch(t *testing.T) { t.Run("FloatToInt", func(t *testing.T) { var v int err := UnmarshalFromString(`{"@type": "g:Float", "@value": 3.14}`, &v) assert.Error(t, err) }) t.Run("DoubleToFloat", func(t *testing.T) { var v float32 err := UnmarshalFromString(`{"@type": "g:Double", "@value": 5.51}`, &v) assert.Error(t, err) }) t.Run("BigDecimalToUint64", func(t *testing.T) { var v uint64 err := UnmarshalFromString(`{"@type": "gx:BigDecimal", "@value": 5645.51834}`, &v) assert.Error(t, err) }) } func TestDecodeNaNInfinity(t *testing.T) { tests := []struct { data []byte expect func(*testing.T, float64, 
error) }{ { data: []byte(`{"@type": "g:Double", "@value": "NaN"}`), expect: func(t *testing.T, f float64, err error) { assert.NoError(t, err) assert.True(t, math.IsNaN(f)) }, }, { data: []byte(`{"@type": "g:Double", "@value": "Infinity"}`), expect: func(t *testing.T, f float64, err error) { assert.NoError(t, err) assert.True(t, math.IsInf(f, 1)) }, }, { data: []byte(`{"@type": "g:Double", "@value": "-Infinity"}`), expect: func(t *testing.T, f float64, err error) { assert.NoError(t, err) assert.True(t, math.IsInf(f, -1)) }, }, { data: []byte(`{"@type": "g:Double", "@value": "Junk"}`), expect: func(t *testing.T, _ float64, err error) { assert.Error(t, err) }, }, { data: []byte(`{"@type": "g:Double", "@value": [42]}`), expect: func(t *testing.T, _ float64, err error) { assert.Error(t, err) }, }, } for _, tc := range tests { var f float64 err := Unmarshal(tc.data, &f) tc.expect(t, f, err) } } func TestDecodeTypeDefinition(t *testing.T) { type Status int const StatusOk Status = 42 var status Status err := UnmarshalFromString(`{"@type": "g:Int64", "@value": 42}`, &status) assert.NoError(t, err) assert.Equal(t, StatusOk, status) } ent-0.11.3/dialect/gremlin/encoding/graphson/raw.go000066400000000000000000000015611431500740500221550ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graphson import ( "errors" ) // RawMessage is a raw encoded graphson value. type RawMessage []byte // RawMessage must implement Marshaler/Unmarshaler interfaces. var ( _ Marshaler = (*RawMessage)(nil) _ Unmarshaler = (*RawMessage)(nil) ) // MarshalGraphson returns m as the graphson encoding of m. func (m RawMessage) MarshalGraphson() ([]byte, error) { if m == nil { return []byte("null"), nil } return m, nil } // UnmarshalGraphson sets *m to a copy of data. func (m *RawMessage) UnmarshalGraphson(data []byte) error { if m == nil { return errors.New("graphson.RawMessage: UnmarshalGraphson on nil pointer") } *m = append((*m)[0:0], data...) return nil } ent-0.11.3/dialect/gremlin/encoding/graphson/raw_test.go000066400000000000000000000014251431500740500232130ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graphson import ( "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestRawMessageEncoding(t *testing.T) { var s struct{ M RawMessage } got, err := MarshalToString(s) require.NoError(t, err) assert.Equal(t, `{"M":null}`, got) s.M = []byte(`"155a"`) got, err = MarshalToString(s) require.NoError(t, err) assert.JSONEq(t, `{"M": "155a"}`, got) err = (*RawMessage)(nil).UnmarshalGraphson(s.M) assert.Error(t, err) s.M = nil err = UnmarshalFromString(got, &s) require.NoError(t, err) assert.Equal(t, `"155a"`, string(s.M)) } ent-0.11.3/dialect/gremlin/encoding/graphson/slice.go000066400000000000000000000074401431500740500224650ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graphson import ( "fmt" "io" "reflect" "unsafe" jsoniter "github.com/json-iterator/go" "github.com/modern-go/reflect2" ) // DecoratorOfSlice decorates a value encoder of a slice type. 
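// Byte slices are special-cased: []byte is encoded as a base64 gx:ByteBuffer
// rather than an element-wise g:List, e.g. []byte{1, 2, 3, 4, 5} becomes
//
//	{"@type":"gx:ByteBuffer","@value":"AQIDBAU="}
//
// while every other slice, and every array (including [N]byte), is wrapped
// as g:List.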
func (encodeExtension) DecoratorOfSlice(typ reflect2.Type, enc jsoniter.ValEncoder) jsoniter.ValEncoder { encoder := typeEncoder{ValEncoder: enc} sliceType := typ.(reflect2.SliceType) if sliceType.Elem().Kind() == reflect.Uint8 { encoder.Type = byteBufferType } else { encoder.Type = listType } return sliceEncoder{sliceType, encoder} } // DecoratorOfArray decorates a value encoder of an array type. func (encodeExtension) DecoratorOfArray(enc jsoniter.ValEncoder) jsoniter.ValEncoder { return typeEncoder{enc, listType} } // DecoderOfSlice returns a value decoder of a slice type. func (ext decodeExtension) DecoderOfSlice(typ reflect2.Type) jsoniter.ValDecoder { sliceType := typ.(reflect2.SliceType) elemType := sliceType.Elem() if elemType.Kind() == reflect.Uint8 { return nil } return sliceDecoder{ sliceType: sliceType, elemDec: ext.LazyDecoderOf(elemType), } } // DecoderOfArray returns a value decoder of an array type. func (ext decodeExtension) DecoderOfArray(typ reflect2.Type) jsoniter.ValDecoder { arrayType := typ.(reflect2.ArrayType) return arrayDecoder{ arrayType: arrayType, elemDec: ext.LazyDecoderOf(arrayType.Elem()), } } // DecoratorOfSlice decorates a value decoder of a slice type. func (ext decodeExtension) DecoratorOfSlice(typ reflect2.Type, dec jsoniter.ValDecoder) jsoniter.ValDecoder { if typ.(reflect2.SliceType).Elem().Kind() == reflect.Uint8 { return typeDecoder{dec, byteBufferType} } return typeDecoder{dec, listType} } // DecoratorOfArray decorates a value decoder of an array type. func (ext decodeExtension) DecoratorOfArray(dec jsoniter.ValDecoder) jsoniter.ValDecoder { return typeDecoder{dec, listType} } type sliceEncoder struct { sliceType reflect2.SliceType jsoniter.ValEncoder } func (enc sliceEncoder) Encode(ptr unsafe.Pointer, stream *jsoniter.Stream) { if enc.sliceType.UnsafeIsNil(ptr) { stream.WriteNil() } else { enc.ValEncoder.Encode(ptr, stream) } } type sliceDecoder struct { sliceType reflect2.SliceType elemDec jsoniter.ValDecoder } func (dec sliceDecoder) Decode(ptr unsafe.Pointer, iter *jsoniter.Iterator) { dec.decode(ptr, iter) if iter.Error != nil && iter.Error != io.EOF { iter.Error = fmt.Errorf("decoding slice %s: %w", dec.sliceType, iter.Error) } } func (dec sliceDecoder) decode(ptr unsafe.Pointer, iter *jsoniter.Iterator) { sliceType := dec.sliceType if iter.ReadNil() { sliceType.UnsafeSetNil(ptr) return } sliceType.UnsafeSet(ptr, sliceType.UnsafeMakeSlice(0, 0)) var length int iter.ReadArrayCB(func(iter *jsoniter.Iterator) bool { idx := length length++ sliceType.UnsafeGrow(ptr, length) elem := sliceType.UnsafeGetIndex(ptr, idx) dec.elemDec.Decode(elem, iter) return iter.Error == nil }) } type arrayDecoder struct { arrayType reflect2.ArrayType elemDec jsoniter.ValDecoder } func (dec arrayDecoder) Decode(ptr unsafe.Pointer, iter *jsoniter.Iterator) { dec.decode(ptr, iter) if iter.Error != nil && iter.Error != io.EOF { iter.Error = fmt.Errorf("decoding array %s: %w", dec.arrayType, iter.Error) } } func (dec arrayDecoder) decode(ptr unsafe.Pointer, iter *jsoniter.Iterator) { var ( arrayType = dec.arrayType length int ) iter.ReadArrayCB(func(iter *jsoniter.Iterator) bool { if length < arrayType.Len() { idx := length length++ elem := arrayType.UnsafeGetIndex(ptr, idx) dec.elemDec.Decode(elem, iter) } else { iter.Skip() } return iter.Error == nil }) } ent-0.11.3/dialect/gremlin/encoding/graphson/slice_test.go000066400000000000000000000071011431500740500235160ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. 
// This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graphson import ( "bytes" "fmt" "strings" "testing" "github.com/modern-go/reflect2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestEncodeArray(t *testing.T) { t.Parallel() got, err := MarshalToString([...]string{"a", "b", "c"}) require.NoError(t, err) want := `{ "@type": "g:List", "@value": ["a", "b", "c"]}` assert.JSONEq(t, want, got) } func TestEncodeSlice(t *testing.T) { tests := []struct { in any want string }{ { in: []int32{5, 6, 7, 8}, want: `{ "@type": "g:List", "@value": [ { "@type": "g:Int32", "@value": 5 }, { "@type": "g:Int32", "@value": 6 }, { "@type": "g:Int32", "@value": 7 }, { "@type": "g:Int32", "@value": 8 } ] }`, }, { in: []byte{1, 2, 3, 4, 5}, want: `{ "@type": "gx:ByteBuffer", "@value": "AQIDBAU=" }`, }, { in: [...]byte{4, 5}, want: `{ "@type": "g:List", "@value": [ { "@type": "gx:Byte", "@value": 4 }, { "@type": "gx:Byte", "@value": 5 } ] }`, }, { in: []uint64(nil), want: "null", }, } for _, tc := range tests { tc := tc t.Run(fmt.Sprintf("%T", tc.in), func(t *testing.T) { t.Parallel() var got bytes.Buffer err := NewEncoder(&got).Encode(tc.in) assert.NoError(t, err) assert.JSONEq(t, tc.want, got.String()) }) } } func TestDecodeSlice(t *testing.T) { tests := []struct { in string want any }{ { in: `{ "@type": "g:List", "@value": [ { "@type": "g:Int32", "@value": 3 }, { "@type": "g:Int32", "@value": -2 }, { "@type": "g:Int32", "@value": 1 } ] }`, want: []int32{3, -2, 1}, }, { in: `{ "@type": "g:List", "@value": ["a", "b", "c"] }`, want: []string{"a", "b", "c"}, }, { in: `{ "@type": "gx:ByteBuffer", "@value": "AQIDBAU=" }`, want: []byte{1, 2, 3, 4, 5}, }, { in: `{ "@type": "g:List", "@value": [ { "@type": "gx:Byte", "@value": 42 }, { "@type": "gx:Byte", "@value": 55 }, { "@type": "gx:Byte", "@value": 94 } ] }`, want: [...]byte{42, 55}, }, } for _, tc := range tests { tc := tc t.Run(fmt.Sprintf("%T", tc.want), func(t *testing.T) { t.Parallel() typ := reflect2.TypeOf(tc.want) got := typ.New() err := NewDecoder(strings.NewReader(tc.in)).Decode(got) require.NoError(t, err) assert.Equal(t, tc.want, typ.Indirect(got)) }) } } func TestDecodeBadSlice(t *testing.T) { tests := []struct { name string in string new func() any }{ { name: "TypeMismatch", in: `{ "@type": "g:List", "@value": [ { "@type": "g:Int32", "@value": 3 }, { "@type": "g:Int64", "@value": 2 } ] }`, new: func() any { return &[]int{} }, }, { name: "BadValue", in: `{ "@type": "g:List", "@value": [ { "@type": "g:Int32", "@value": 3 }, { "@type": "g:Int32", "@value": "2" } ] }`, new: func() any { return &[2]int{} }, }, } for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() err := NewDecoder(strings.NewReader(tc.in)).Decode(tc.new()) assert.Error(t, err) }) } } ent-0.11.3/dialect/gremlin/encoding/graphson/struct.go000066400000000000000000000017401431500740500227070ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graphson import jsoniter "github.com/json-iterator/go" // DecoratorOfStructField decorates a struct field value encoder. 
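// The type part of a `graphson:"..."` tag overrides the field's inferred
// @type. A sketch (field names are illustrative):
//
//	type Request struct {
//		ID  string `json:"requestId" graphson:"g:UUID"`
//		Seq int    `json:"seq" graphson:"g:Int32"`
//	}
//
// ID marshals as {"@type":"g:UUID","@value":"..."} instead of a bare string,
// and Seq as g:Int32 instead of the default g:Int64 for int.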
func (encodeExtension) DecoratorOfStructField(enc jsoniter.ValEncoder, tag string) jsoniter.ValEncoder { typ, _ := parseTag(tag) if typ == "" { return nil } encoder, ok := enc.(typeEncoder) if !ok { encoder = typeEncoder{ValEncoder: enc} } encoder.Type = Type(typ) return encoder } // DecoratorOfStructField decorates a struct field value decoder. func (decodeExtension) DecoratorOfStructField(dec jsoniter.ValDecoder, tag string) jsoniter.ValDecoder { typ, _ := parseTag(tag) if typ == "" { return nil } decoder, ok := dec.(typeDecoder) if !ok { decoder = typeDecoder{ValDecoder: dec} } decoder.typeChecker = Type(typ) return decoder } ent-0.11.3/dialect/gremlin/encoding/graphson/struct_test.go000066400000000000000000000074431431500740500237540ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graphson import ( "testing" "github.com/modern-go/reflect2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestEncodeStruct(t *testing.T) { tests := []struct { name string in any want string }{ { name: "Simple", in: struct { S string I int }{ S: "string", I: 1000, }, want: `{ "S":"string", "I": { "@type": "g:Int64", "@value": 1000 } }`, }, { name: "Tagged", in: struct { ID string `json:"requestId" graphson:"g:UUID"` Seq int `json:"seq" graphson:"g:Int32"` Op string `json:"op" graphson:","` Args map[string]string `json:"args"` }{ ID: "cb682578-9d92-4499-9ebc-5c6aa73c5397", Seq: 42, Op: "authentication", Args: map[string]string{ "sasl": "AHN0ZXBocGhlbgBwYXNzd29yZA==", }, }, want: `{ "requestId": { "@type": "g:UUID", "@value": "cb682578-9d92-4499-9ebc-5c6aa73c5397" }, "seq": { "@type": "g:Int32", "@value": 42 }, "op": "authentication", "args": { "@type": "g:Map", "@value": ["sasl", "AHN0ZXBocGhlbgBwYXNzd29yZA=="] } }`, }, } for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() data, err := MarshalToString(tc.in) require.NoError(t, err) assert.JSONEq(t, tc.want, data) }) } } func TestEncodeNestedStruct(t *testing.T) { type S struct { Parent *S `json:"parent,omitempty"` ID int `json:"id" graphson:"g:Int32"` } v := S{Parent: &S{ID: 1}, ID: 2} want := `{ "id": { "@type": "g:Int32", "@value": 2 }, "parent": { "id": { "@type": "g:Int32", "@value": 1 } } }` got, err := MarshalToString(&v) require.NoError(t, err) assert.JSONEq(t, want, got) } func TestDecodeStruct(t *testing.T) { tests := []struct { name string in string want any }{ { name: "Simple", in: `{ "S":"str", "I": { "@type": "g:Int32", "@value": 9999 } }`, want: struct { S string I int32 }{ S: "str", I: 9999, }, }, { name: "Tagged", in: `{ "requestId": { "@type": "g:UUID", "@value": "cb682578-9d92-4499-9ebc-5c6aa73c5397" }, "seq": { "@type": "g:Int32", "@value": 42 }, "op": "authentication", "args": { "@type": "g:Map", "@value": ["sasl", "AHN0ZXBocGhlbgBwYXNzd29yZA=="] } }`, want: struct { ID string `json:"requestId" graphson:"g:UUID"` Seq int `json:"seq" graphson:"g:Int32"` Op string `json:"op" graphson:","` Args map[string]string `json:"args"` }{ ID: "cb682578-9d92-4499-9ebc-5c6aa73c5397", Seq: 42, Op: "authentication", Args: map[string]string{ "sasl": "AHN0ZXBocGhlbgBwYXNzd29yZA==", }, }, }, { name: "Empty", in: `{}`, want: struct{}{}, }, } for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() typ := reflect2.TypeOf(tc.want) got := typ.New() err := UnmarshalFromString(tc.in, got) 
require.NoError(t, err) assert.Equal(t, tc.want, typ.Indirect(got)) }) } } func TestDecodeNestedStruct(t *testing.T) { type S struct { Parent *S `json:"parent,omitempty"` ID int `json:"id" graphson:"g:Int32"` } in := `{ "id": { "@type": "g:Int32", "@value": 37 }, "parent": { "id": { "@type": "g:Int32", "@value": 65 } } }` var got S err := UnmarshalFromString(in, &got) require.NoError(t, err) want := S{Parent: &S{ID: 65}, ID: 37} assert.Equal(t, want, got) } ent-0.11.3/dialect/gremlin/encoding/graphson/tags.go000066400000000000000000000016321431500740500223210ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graphson import ( "strings" ) type tagOptions string // parseTag splits a struct field's graphson tag into its type and // comma-separated options. func parseTag(tag string) (string, tagOptions) { if idx := strings.Index(tag, ","); idx != -1 { return tag[:idx], tagOptions(tag[idx+1:]) } return tag, "" } // Contains reports whether a comma-separated list of options // contains a particular substr flag. substr must be surrounded by a // string boundary or commas. func (opts tagOptions) Contains(opt string) bool { s := string(opts) for s != "" { var next string i := strings.Index(s, ",") if i >= 0 { s, next = s[:i], s[i+1:] } if s == opt { return true } s = next } return false } ent-0.11.3/dialect/gremlin/encoding/graphson/tags_test.go000066400000000000000000000022541431500740500233610ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graphson import ( "testing" "github.com/stretchr/testify/assert" ) func TestParseTag(t *testing.T) { tests := []struct { name string tag string typ string opts tagOptions }{ { name: "Empty", }, { name: "TypeOnly", tag: "g:Int32", typ: "g:Int32", }, { name: "OptsOnly", tag: ",opt1,opt2", opts: "opt1,opt2", }, { name: "TypeAndOpts", tag: "g:UUID,opt3,opt4", typ: "g:UUID", opts: "opt3,opt4", }, } for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() typ, opts := parseTag(tc.tag) assert.Equal(t, tc.typ, typ) assert.Equal(t, tc.opts, opts) }) } } func TestTagOptionsContains(t *testing.T) { _, opts := parseTag(",opt1,opt2,opt3") assert.True(t, opts.Contains("opt1")) assert.True(t, opts.Contains("opt2")) assert.True(t, opts.Contains("opt3")) assert.False(t, opts.Contains("opt4")) assert.False(t, opts.Contains("opt11")) assert.False(t, opts.Contains("")) } ent-0.11.3/dialect/gremlin/encoding/graphson/time.go000066400000000000000000000016341431500740500223230ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. 
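// time.Time is wired up via the registered codec below: it round-trips
// through a millisecond-precision g:Timestamp (and also accepts g:Date on
// decode). A sketch:

package main

import (
	"fmt"
	"time"

	"entgo.io/ent/dialect/gremlin/encoding/graphson"
)

func main() {
	ts := time.Unix(0, 1481750076295*int64(time.Millisecond))
	out, err := graphson.MarshalToString(ts)
	if err != nil {
		panic(err)
	}
	fmt.Println(out) // {"@type":"g:Timestamp","@value":1481750076295}

	var got time.Time
	if err := graphson.UnmarshalFromString(out, &got); err != nil {
		panic(err)
	}
	fmt.Println(got.Equal(ts)) // true
}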
package graphson import ( "time" "unsafe" jsoniter "github.com/json-iterator/go" ) func init() { RegisterTypeEncoder("time.Time", typeEncoder{timeCodec{}, Timestamp}) RegisterTypeDecoder("time.Time", typeDecoder{timeCodec{}, Types{Timestamp, Date}}) } type timeCodec struct{} func (timeCodec) IsEmpty(ptr unsafe.Pointer) bool { ts := *((*time.Time)(ptr)) return ts.IsZero() } func (timeCodec) Encode(ptr unsafe.Pointer, stream *jsoniter.Stream) { ts := *((*time.Time)(ptr)) stream.WriteInt64(ts.UnixNano() / time.Millisecond.Nanoseconds()) } func (timeCodec) Decode(ptr unsafe.Pointer, iter *jsoniter.Iterator) { ns := iter.ReadInt64() * time.Millisecond.Nanoseconds() *((*time.Time)(ptr)) = time.Unix(0, ns) } ent-0.11.3/dialect/gremlin/encoding/graphson/time_test.go000066400000000000000000000016001431500740500233530ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graphson import ( "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestTimeEncoding(t *testing.T) { const ms = 1481750076295 ts := time.Unix(0, ms*time.Millisecond.Nanoseconds()) for _, v := range []any{ts, &ts} { got, err := MarshalToString(v) require.NoError(t, err) assert.JSONEq(t, `{ "@type": "g:Timestamp", "@value": 1481750076295 }`, got) } strs := []string{ `{ "@type": "g:Timestamp", "@value": 1481750076295 }`, `{ "@type": "g:Date", "@value": 1481750076295 }`, } for _, str := range strs { var v time.Time err := UnmarshalFromString(str, &v) assert.NoError(t, err) assert.True(t, ts.Equal(v)) } } ent-0.11.3/dialect/gremlin/encoding/graphson/type.go000066400000000000000000000073541431500740500223530ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graphson import ( "fmt" "reflect" "strings" "unsafe" jsoniter "github.com/json-iterator/go" "github.com/modern-go/reflect2" ) // A Type is a graphson type. type Type string // graphson typed value types. const ( // core doubleType Type = "g:Double" floatType Type = "g:Float" int32Type Type = "g:Int32" int64Type Type = "g:Int64" listType Type = "g:List" mapType Type = "g:Map" Timestamp Type = "g:Timestamp" Date Type = "g:Date" // extended bigIntegerType Type = "gx:BigInteger" bigDecimal Type = "gx:BigDecimal" byteType Type = "gx:Byte" byteBufferType Type = "gx:ByteBuffer" int16Type Type = "gx:Int16" ) // String implements fmt.Stringer interface. func (typ Type) String() string { return string(typ) } // CheckType implements typeChecker interface. func (typ Type) CheckType(other Type) error { if typ != other { return fmt.Errorf("expect type %s, but found %s", typ, other) } return nil } // Types is a slice of Type. type Types []Type // Contains reports whether a slice of types contains a particular type. func (types Types) Contains(typ Type) bool { for i := range types { if types[i] == typ { return true } } return false } // String implements fmt.Stringer interface. func (types Types) String() string { var builder strings.Builder builder.WriteByte('[') for i := range types { if i > 0 { builder.WriteByte(',') } builder.WriteString(types[i].String()) } builder.WriteByte(']') return builder.String() } // CheckType implements typeChecker interface. 
func (types Types) CheckType(typ Type) error { if !types.Contains(typ) { return fmt.Errorf("expect any of %s, but found %s", types, typ) } return nil } // Typer is the interface implemented by types that // define an underlying graphson type. type Typer interface { GraphsonType() Type } var typerType = reflect2.TypeOfPtr((*Typer)(nil)).Elem() // DecoratorOfTyper decorates a value encoder of a Typer interface. func (ext encodeExtension) DecoratorOfTyper(typ reflect2.Type, enc jsoniter.ValEncoder) jsoniter.ValEncoder { if typ.Kind() != reflect.Struct { return nil } if typ.Implements(typerType) { return typerEncoder{ typeEncoder: typeEncoder{ValEncoder: enc}, typerOf: func(ptr unsafe.Pointer) Typer { return typ.UnsafeIndirect(ptr).(Typer) }, } } ptrType := reflect2.PtrTo(typ) if ptrType.Implements(typerType) { return typerEncoder{ typeEncoder: typeEncoder{ValEncoder: enc}, typerOf: func(ptr unsafe.Pointer) Typer { // nolint: gas return ptrType.UnsafeIndirect(unsafe.Pointer(&ptr)).(Typer) }, } } return nil } // DecoratorOfTyper decorates a value decoder of a Typer interface. func (ext decodeExtension) DecoratorOfTyper(typ reflect2.Type, dec jsoniter.ValDecoder) jsoniter.ValDecoder { ptrType := reflect2.PtrTo(typ) if ptrType.Implements(typerType) { return typerDecoder{ typeDecoder: typeDecoder{ValDecoder: dec}, typerOf: func(ptr unsafe.Pointer) Typer { // nolint: gas return ptrType.UnsafeIndirect(unsafe.Pointer(&ptr)).(Typer) }, } } return nil } type typerEncoder struct { typeEncoder typerOf func(unsafe.Pointer) Typer } func (enc typerEncoder) Encode(ptr unsafe.Pointer, stream *jsoniter.Stream) { enc.typeEncoder.Type = enc.typerOf(ptr).GraphsonType() enc.typeEncoder.Encode(ptr, stream) } type typerDecoder struct { typeDecoder typerOf func(unsafe.Pointer) Typer } func (dec typerDecoder) Decode(ptr unsafe.Pointer, iter *jsoniter.Iterator) { dec.typeDecoder.typeChecker = dec.typerOf(ptr).GraphsonType() dec.typeDecoder.Decode(ptr, iter) } ent-0.11.3/dialect/gremlin/encoding/graphson/type_test.go000066400000000000000000000041461431500740500234060ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. 
package graphson import ( "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" ) func TestTypeCheckType(t *testing.T) { assert.NoError(t, int32Type.CheckType(int32Type)) assert.Error(t, int32Type.CheckType(int64Type)) } func TestTypesCheckType(t *testing.T) { assert.NoError(t, Types{int16Type, int32Type, int64Type}.CheckType(int32Type)) assert.Error(t, Types{floatType, doubleType}.CheckType(bigIntegerType)) } func TestTypesString(t *testing.T) { assert.Equal(t, "[]", Types{}.String()) assert.Equal(t, "[gx:Byte]", Types{byteType}.String()) assert.Equal(t, "[gx:Int16,g:Int32,g:Int64]", Types{int16Type, int32Type, int64Type}.String()) } type vertex struct { ID int `json:"id"` Label string `json:"label"` } func (vertex) GraphsonType() Type { return Type("g:Vertex") } type mockVertex struct { mock.Mock `json:"-"` ID int `json:"id"` Label string `json:"label"` } func (m *mockVertex) GraphsonType() Type { return m.Called().Get(0).(Type) } func TestEncodeTyper(t *testing.T) { m := &mockVertex{ID: 42, Label: "person"} m.On("GraphsonType").Return(Type("g:Vertex")).Twice() defer m.AssertExpectations(t) v := vertex{ID: m.ID, Label: m.Label} var vv Typer = v want := `{ "@type": "g:Vertex", "@value": { "id": { "@type": "g:Int64", "@value": 42 }, "label": "person" } }` for _, tc := range []any{m, &m, v, vv, &vv} { got, err := MarshalToString(tc) assert.NoError(t, err) assert.JSONEq(t, want, got) } } func TestDecodeTyper(t *testing.T) { var m mockVertex m.On("GraphsonType").Return(Type("g:Vertex")).Once() defer m.AssertExpectations(t) in := `{ "@type": "g:Vertex", "@value": { "id": { "@type": "g:Int64", "@value": 55 }, "label": "user" } }` err := UnmarshalFromString(in, &m) assert.NoError(t, err) assert.Equal(t, 55, m.ID) assert.Equal(t, "user", m.Label) } ent-0.11.3/dialect/gremlin/encoding/graphson/util.go000066400000000000000000000051611431500740500223410ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graphson import ( "errors" "io" "unsafe" jsoniter "github.com/json-iterator/go" ) // graphson encoding type / value keys const ( TypeKey = "@type" ValueKey = "@value" ) // typeEncoder adds graphson type information to a value encoder. type typeEncoder struct { jsoniter.ValEncoder Type Type } // Encode belongs to jsoniter.ValEncoder interface. func (enc typeEncoder) Encode(ptr unsafe.Pointer, stream *jsoniter.Stream) { stream.WriteObjectStart() stream.WriteObjectField(TypeKey) stream.WriteString(enc.Type.String()) stream.WriteMore() stream.WriteObjectField(ValueKey) enc.ValEncoder.Encode(ptr, stream) stream.WriteObjectEnd() } type ( // typeDecoder decorates a value decoder and adds graphson type verification. typeDecoder struct { jsoniter.ValDecoder typeChecker } // typeChecker defines an interface for graphson type verification. typeChecker interface { CheckType(Type) error } // typeCheckerFunc allows the use of functions as type checkers. typeCheckerFunc func(Type) error // typeValue defines a graphson type / value pair. typeValue struct { Type Type Value jsoniter.RawMessage } ) // Decode belongs to jsoniter.ValDecoder interface. 
func (dec typeDecoder) Decode(ptr unsafe.Pointer, iter *jsoniter.Iterator) { if iter.WhatIsNext() != jsoniter.ObjectValue { dec.ValDecoder.Decode(ptr, iter) return } data := iter.SkipAndReturnBytes() if iter.Error != nil && iter.Error != io.EOF { return } var tv typeValue if err := jsoniter.Unmarshal(data, &tv); err != nil { iter.ReportError("unmarshal type value", err.Error()) return } if err := dec.CheckType(tv.Type); err != nil { iter.ReportError("check type", err.Error()) return } it := config.BorrowIterator(tv.Value) defer config.ReturnIterator(it) dec.ValDecoder.Decode(ptr, it) if it.Error != nil && it.Error != io.EOF { iter.ReportError("decode value", it.Error.Error()) } } // UnmarshalJSON implements json.Unmarshaler interface. func (tv *typeValue) UnmarshalJSON(data []byte) error { var v struct { Type *Type `json:"@type"` Value jsoniter.RawMessage `json:"@value"` } if err := jsoniter.Unmarshal(data, &v); err != nil { return err } if v.Type == nil || v.Value == nil { return errors.New("missing type or value") } tv.Type = *v.Type tv.Value = v.Value return nil } // CheckType implements typeChecker interface. func (f typeCheckerFunc) CheckType(typ Type) error { return f(typ) } ent-0.11.3/dialect/gremlin/encoding/graphson/util_test.go000066400000000000000000000063351431500740500234040ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graphson import ( "bytes" "errors" "fmt" "testing" "unsafe" jsoniter "github.com/json-iterator/go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) func TestTypeEncode(t *testing.T) { var got bytes.Buffer stream := config.BorrowStream(&got) defer config.ReturnStream(stream) typ, val := int32Type, 42 ptr := unsafe.Pointer(&val) var m mocker m.On("Encode", ptr, stream). Run(func(args mock.Arguments) { stream := args.Get(1).(*jsoniter.Stream) stream.WriteInt(val) }). Once() defer m.AssertExpectations(t) typeEncoder{&m, typ}.Encode(ptr, stream) require.NoError(t, stream.Flush()) want := fmt.Sprintf(`{"@type": "%s", "@value": %d}`, typ, val) assert.JSONEq(t, want, got.String()) } func TestTypeDecode(t *testing.T) { typ, val := int64Type, 84 ptr := unsafe.Pointer(&val) data := fmt.Sprintf(`{"@value": %d, "@type": "%s"}`, val, typ) iter := config.BorrowIterator([]byte(data)) defer config.ReturnIterator(iter) m := &mocker{} m.On("CheckType", typ). Return(nil). Once() m.On("Decode", ptr, mock.Anything). Run(func(args mock.Arguments) { iter := args.Get(1).(*jsoniter.Iterator) assert.Equal(t, val, iter.ReadInt()) }). 
Once() defer m.AssertExpectations(t) typeDecoder{m, m}.Decode(ptr, iter) assert.NoError(t, iter.Error) } func TestTypeDecodeBadType(t *testing.T) { typ, val := int64Type, 55 ptr := unsafe.Pointer(&val) m := &mocker{} m.On("CheckType", typ).Return(errors.New("bad type")).Once() defer m.AssertExpectations(t) data := fmt.Sprintf(`{"@type": "%s", "@value": %d}`, typ, val) iter := config.BorrowIterator([]byte(data)) defer config.ReturnIterator(iter) typeDecoder{m, m}.Decode(ptr, iter) require.Error(t, iter.Error) assert.Contains(t, iter.Error.Error(), "bad type") } func TestTypeDecodeDuplicateField(t *testing.T) { data := `{"@type": "gx:Byte", "@value": 33, "@type": "g:Int32"}` iter := config.BorrowIterator([]byte(data)) defer config.ReturnIterator(iter) var ptr unsafe.Pointer m := &mocker{} m.On("CheckType", mock.MatchedBy(func(typ Type) bool { return typ == int32Type })). Return(nil). Once() m.On("Decode", ptr, mock.Anything). Run(func(args mock.Arguments) { args.Get(1).(*jsoniter.Iterator).Skip() require.NoError(t, iter.Error) }). Once() defer m.AssertExpectations(t) typeDecoder{m, m}.Decode(ptr, iter) assert.NoError(t, iter.Error) } func TestTypeDecodeMissingField(t *testing.T) { data := `{"@type": "g:Int32"}` iter := config.BorrowIterator([]byte(data)) defer config.ReturnIterator(iter) m := &mocker{} defer m.AssertExpectations(t) typeDecoder{m, m}.Decode(nil, iter) require.Error(t, iter.Error) assert.Contains(t, iter.Error.Error(), "missing type or value") } func TestTypeDecodeSyntaxError(t *testing.T) { data := `{"@type": "gx:Int16", "@value", 65000}` iter := config.BorrowIterator([]byte(data)) defer config.ReturnIterator(iter) m := &mocker{} defer m.AssertExpectations(t) typeDecoder{m, m}.Decode(nil, iter) assert.Error(t, iter.Error) } ent-0.11.3/dialect/gremlin/encoding/mime.go000066400000000000000000000012261431500740500204700ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package encoding import ( "bytes" ) // Mime defines a gremlin mime type. type Mime []byte // GraphSON3Mime mime headers. var ( GraphSON3Mime = NewMime("application/vnd.gremlin-v3.0+json") ) // NewMime creates a wire format mime header. func NewMime(s string) Mime { var buf bytes.Buffer buf.WriteByte(byte(len(s))) buf.WriteString(s) return buf.Bytes() } // String implements fmt.Stringer interface. func (m Mime) String() string { return string(m[1:]) } ent-0.11.3/dialect/gremlin/encoding/mime_test.go000066400000000000000000000010661431500740500215310ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package encoding import ( "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestNewMime(t *testing.T) { str := "application/vnd.gremlin-v2.0+json" mime := NewMime(str) require.Len(t, mime, len(str)+1) assert.EqualValues(t, len(str), mime[0]) assert.EqualValues(t, str, mime[1:]) assert.Equal(t, str, mime.String()) } ent-0.11.3/dialect/gremlin/example_test.go000066400000000000000000000016321431500740500204460ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. 
package gremlin import ( "context" "flag" "log" "os" "time" ) func ExampleClient_Query() { addr := flag.String("gremlin-server", os.Getenv("GREMLIN_SERVER"), "gremlin server address") flag.Parse() if *addr == "" { log.Fatal("missing gremlin server address") } client, err := NewHTTPClient(*addr, nil) if err != nil { log.Fatalf("creating client: %v", err) } ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) rsp, err := client.Query(ctx, "g.E()") if err != nil { log.Fatalf("executing query: %v", err) } edges, err := rsp.ReadEdges() if err != nil { log.Fatalf("unmashal edges") } defer cancel() for _, e := range edges { log.Println(e.String()) } // - Output: } ent-0.11.3/dialect/gremlin/expand.go000066400000000000000000000025011431500740500172270ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package gremlin import ( "context" "fmt" "sort" "strings" jsoniter "github.com/json-iterator/go" ) // ExpandBindings expands the given RoundTripper and expands the request bindings into the Gremlin traversal. func ExpandBindings(rt RoundTripper) RoundTripper { return RoundTripperFunc(func(ctx context.Context, r *Request) (*Response, error) { bindings, ok := r.Arguments[ArgsBindings] if !ok { return rt.RoundTrip(ctx, r) } query, ok := r.Arguments[ArgsGremlin] if !ok { return rt.RoundTrip(ctx, r) } { query, bindings := query.(string), bindings.(map[string]any) keys := make(sort.StringSlice, 0, len(bindings)) for k := range bindings { keys = append(keys, k) } sort.Sort(sort.Reverse(keys)) kv := make([]string, 0, len(bindings)*2) for _, k := range keys { s, err := jsoniter.MarshalToString(bindings[k]) if err != nil { return nil, fmt.Errorf("marshal bindings value for key %s: %w", k, err) } kv = append(kv, k, s) } delete(r.Arguments, ArgsBindings) r.Arguments[ArgsGremlin] = strings.NewReplacer(kv...).Replace(query) } return rt.RoundTrip(ctx, r) }) } ent-0.11.3/dialect/gremlin/expand_test.go000066400000000000000000000037751431500740500203040ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. 
package gremlin import ( "context" "strconv" "testing" "github.com/stretchr/testify/assert" ) func TestExpandBindings(t *testing.T) { tests := []struct { req *Request wantErr bool wantQuery string }{ { req: NewEvalRequest("no bindings"), wantQuery: "no bindings", }, { req: NewEvalRequest("g.V($0)", WithBindings(map[string]any{"$0": 1})), wantQuery: "g.V(1)", }, { req: NewEvalRequest("g.V().has($1, $2)", WithBindings(map[string]any{"$1": "name", "$2": "a8m"})), wantQuery: "g.V().has(\"name\", \"a8m\")", }, { req: NewEvalRequest("g.V().limit(n)", WithBindings(map[string]any{"n": 10})), wantQuery: "g.V().limit(10)", }, { req: NewEvalRequest("g.V()", WithBindings(map[string]any{"$0": func() {}})), wantErr: true, }, { req: NewEvalRequest("g.V().has($0, $1)", WithBindings(map[string]any{"$0": "active", "$1": true})), wantQuery: "g.V().has(\"active\", true)", }, { req: NewEvalRequest("g.V().has($1, $11)", WithBindings(map[string]any{"$1": "active", "$11": true})), wantQuery: "g.V().has(\"active\", true)", }, } for i, tt := range tests { tt := tt t.Run(strconv.Itoa(i), func(t *testing.T) { rt := ExpandBindings(RoundTripperFunc(func(ctx context.Context, r *Request) (*Response, error) { assert.Equal(t, tt.wantQuery, r.Arguments[ArgsGremlin]) return nil, nil })) _, err := rt.RoundTrip(context.Background(), tt.req) assert.Equal(t, tt.wantErr, err != nil) }) } } func TestExpandBindingsNoQuery(t *testing.T) { rt := ExpandBindings(RoundTripperFunc(func(ctx context.Context, r *Request) (*Response, error) { return nil, nil })) _, err := rt.RoundTrip(context.Background(), &Request{Arguments: map[string]any{ ArgsBindings: map[string]any{}, }}) assert.NoError(t, err) } ent-0.11.3/dialect/gremlin/graph/000077500000000000000000000000001431500740500165245ustar00rootroot00000000000000ent-0.11.3/dialect/gremlin/graph/dsl/000077500000000000000000000000001431500740500173065ustar00rootroot00000000000000ent-0.11.3/dialect/gremlin/graph/dsl/__/000077500000000000000000000000001431500740500176635ustar00rootroot00000000000000ent-0.11.3/dialect/gremlin/graph/dsl/__/dsl.go000066400000000000000000000052541431500740500210020ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package __ import "entgo.io/ent/dialect/gremlin/graph/dsl" // As is the api for calling __.As(). func As(args ...any) *dsl.Traversal { return New().As(args...) } // Is is the api for calling __.Is(). func Is(args ...any) *dsl.Traversal { return New().Is(args...) } // Not is the api for calling __.Not(). func Not(args ...any) *dsl.Traversal { return New().Not(args...) } // Has is the api for calling __.Has(). func Has(args ...any) *dsl.Traversal { return New().Has(args...) } // HasNot is the api for calling __.HasNot(). func HasNot(args ...any) *dsl.Traversal { return New().HasNot(args...) } // Or is the api for calling __.Or(). func Or(args ...any) *dsl.Traversal { return New().Or(args...) } // And is the api for calling __.And(). func And(args ...any) *dsl.Traversal { return New().And(args...) } // In is the api for calling __.In(). func In(args ...any) *dsl.Traversal { return New().In(args...) } // Out is the api for calling __.Out(). func Out(args ...any) *dsl.Traversal { return New().Out(args...) } // OutE is the api for calling __.OutE(). func OutE(args ...any) *dsl.Traversal { return New().OutE(args...) } // InE is the api for calling __.InE(). 
func InE(args ...any) *dsl.Traversal { return New().InE(args...) } // InV is the api for calling __.InV(). func InV(args ...any) *dsl.Traversal { return New().InV(args...) } // V is the api for calling __.V(). func V(args ...any) *dsl.Traversal { return New().V(args...) } // OutV is the api for calling __.OutV(). func OutV(args ...any) *dsl.Traversal { return New().OutV(args...) } // Values is the api for calling __.Values(). func Values(args ...string) *dsl.Traversal { return New().Values(args...) } // Union is the api for calling __.Union(). func Union(args ...any) *dsl.Traversal { return New().Union(args...) } // Constant is the api for calling __.Constant(). func Constant(args ...any) *dsl.Traversal { return New().Constant(args...) } // Properties is the api for calling __.Properties(). func Properties(args ...any) *dsl.Traversal { return New().Properties(args...) } // OtherV is the api for calling __.OtherV(). func OtherV() *dsl.Traversal { return New().OtherV() } // Count is the api for calling __.Count(). func Count() *dsl.Traversal { return New().Count() } // Drop is the api for calling __.Drop(). func Drop() *dsl.Traversal { return New().Drop() } // Fold is the api for calling __.Fold(). func Fold() *dsl.Traversal { return New().Fold() } // New returns a new traversal that starts with the anonymous traversal source __. func New() *dsl.Traversal { return new(dsl.Traversal).Add(dsl.Token("__")) } ent-0.11.3/dialect/gremlin/graph/dsl/dsl.go000066400000000000000000000114101431500740500204160ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. // Package dsl provides an API for writing gremlin dsl queries almost as-is // in Go without using strings in the code. // // Note that the API is not type-safe and assumes the provided query and // its arguments are valid. package dsl import ( "fmt" "strings" "time" ) // Node represents a DSL step in the traversal. type Node interface { // Code returns the code representation of the element and its bindings (if any). Code() (string, []any) } type ( // Token holds a simple token, like assignment. Token string // List represents a list of elements. List struct { Elements []any } // Func represents a function call. Func struct { Name string Args []any } // Block represents a block/group of nodes. Block struct { Nodes []any } // Var represents a variable assignment and usage. Var struct { Name string Elem any } ) // Code stringifies the token. func (t Token) Code() (string, []any) { return string(t), nil } // Code returns the code representation of a list. func (l List) Code() (string, []any) { c, args := codeList(", ", l.Elements...) return fmt.Sprintf("[%s]", c), args } // Code returns the code representation of a function call. func (f Func) Code() (string, []any) { c, args := codeList(", ", f.Args...) return fmt.Sprintf("%s(%s)", f.Name, c), args } // Code returns the code representation of group/block of nodes. func (b Block) Code() (string, []any) { return codeList("; ", b.Nodes...) } // Code returns the code representation of variable declaration or its identifier. func (v Var) Code() (string, []any) { c, args := code(v.Elem) if v.Name == "" { return c, args } return fmt.Sprintf("%s = %s", v.Name, c), args } // predefined nodes. var ( G = Token("g") Dot = Token(".") ) // NewFunc returns a new function node. func NewFunc(name string, args ...any) *Func { return &Func{Name: name, Args: args} } // NewList returns a new list node.
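// For example (an illustrative sketch, not part of the original source): // NewList(1, 2, 3).Code() yields the code "[%s, %s, %s]" with args []any{1, 2, 3}; the "%s" placeholders are later replaced by generated binding names such as "$0" when the enclosing traversal is rendered by Traversal.Query.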
func NewList(args ...any) *List { return &List{Elements: args} } // Querier is the interface that wraps the Query method. type Querier interface { // Query returns the query-string (similar to the Gremlin byte-code) and its bindings. Query() (string, Bindings) } // Bindings are used to associate a variable with a value. type Bindings map[string]any // Add adds a new value to the bindings map, formats it if needed, and returns its generated name. func (b Bindings) Add(v any) string { k := fmt.Sprintf("$%x", len(b)) switch v := v.(type) { case time.Time: b[k] = v.UnixNano() default: b[k] = v } return k } // Cardinality of vertex properties. type Cardinality string // Cardinality options. const ( Set Cardinality = "set" Single Cardinality = "single" ) // Code implements the Node interface. func (c Cardinality) Code() (string, []any) { return string(c), nil } // Keyword defines a Gremlin keyword. type Keyword string // Keyword options. const ( ID Keyword = "id" ) // Code implements the Node interface. func (k Keyword) Code() (string, []any) { return string(k), nil } // Order of vertex properties. type Order string // Order options. const ( Incr Order = "incr" Decr Order = "decr" Shuffle Order = "shuffle" ) // Code implements the Node interface. func (o Order) Code() (string, []any) { return string(o), nil } // Column references a particular type of column in a complex data structure such as a Map, a Map.Entry, or a Path. type Column string // Column options. const ( Keys Column = "keys" Values Column = "values" ) // Code implements the Node interface. func (o Column) Code() (string, []any) { return string(o), nil } // Scope is used for steps that have a variable scope, which alters the manner in which the step will behave in relation to how the traversers are processed. type Scope string // Scope options. const ( Local Scope = "local" Global Scope = "global" ) // Code implements the Node interface. func (s Scope) Code() (string, []any) { return string(s), nil } func codeList(sep string, vs ...any) (string, []any) { var ( br strings.Builder args []any ) for i, node := range vs { if i > 0 { br.WriteString(sep) } c, nargs := code(node) br.WriteString(c) args = append(args, nargs...) } return br.String(), args } func code(v any) (string, []any) { switch n := v.(type) { case Node: return n.Code() case *Traversal: var ( b strings.Builder args []any ) for i := range n.nodes { code, nargs := n.nodes[i].Code() b.WriteString(code) args = append(args, nargs...) } return b.String(), args default: return "%s", []any{v} } } func sface(args []string) (v []any) { for _, s := range args { v = append(v, s) } return } ent-0.11.3/dialect/gremlin/graph/dsl/dsl_test.go000066400000000000000000000200631431500740500214570ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree.
package dsl_test import ( "strconv" "testing" "entgo.io/ent/dialect/gremlin/graph/dsl" "entgo.io/ent/dialect/gremlin/graph/dsl/__" "entgo.io/ent/dialect/gremlin/graph/dsl/g" "entgo.io/ent/dialect/gremlin/graph/dsl/p" "github.com/stretchr/testify/require" ) func TestTraverse(t *testing.T) { tests := []struct { input dsl.Querier wantQuery string wantBinds dsl.Bindings }{ { input: g.V(5), wantQuery: "g.V($0)", wantBinds: dsl.Bindings{"$0": 5}, }, { input: g.V(2).Both("knows"), wantQuery: "g.V($0).both($1)", wantBinds: dsl.Bindings{"$0": 2, "$1": "knows"}, }, { input: g.V(49).BothE("knows").OtherV().ValueMap(), wantQuery: "g.V($0).bothE($1).otherV().valueMap()", wantBinds: dsl.Bindings{"$0": 49, "$1": "knows"}, }, { input: g.AddV("person").Property("name", "a8m").Next(), wantQuery: "g.addV($0).property($1, $2).next()", wantBinds: dsl.Bindings{"$0": "person", "$1": "name", "$2": "a8m"}, }, { input: dsl.Each([]any{1, 2, 3}, func(it *dsl.Traversal) *dsl.Traversal { return g.V(it) }), wantQuery: "[$0, $1, $2].each { g.V(it) }", wantBinds: dsl.Bindings{"$0": 1, "$1": 2, "$2": 3}, }, { input: dsl.Each([]any{g.V(1).Next()}, func(it *dsl.Traversal) *dsl.Traversal { return it.ID() }), wantQuery: "[g.V($0).next()].each { it.id() }", wantBinds: dsl.Bindings{"$0": 1}, }, { input: g.AddV("person").AddE("knows").To(g.V(2)), wantQuery: "g.addV($0).addE($1).to(g.V($2))", wantBinds: dsl.Bindings{"$0": "person", "$1": "knows", "$2": 2}, }, { input: func() *dsl.Traversal { v1 := g.V(2).Next() v2 := g.AddV("person").Property("name", "a8m") e1 := g.V(v1).AddE("knows").To(v2) return dsl.Group(v1, v2, e1) }(), wantQuery: "t0 = g.V($0).next(); t1 = g.addV($1).property($2, $3); t2 = g.V(t0).addE($4).to(t1); t2", wantBinds: dsl.Bindings{"$0": 2, "$1": "person", "$2": "name", "$3": "a8m", "$4": "knows"}, }, { input: func() *dsl.Traversal { v1 := g.AddV("person") each := dsl.Each([]any{1, 2, 3}, func(it *dsl.Traversal) *dsl.Traversal { return g.V(v1).AddE("knows").To(g.V(it)).Next() }) return dsl.Group(v1, each) }(), wantQuery: "t0 = g.addV($0); t1 = [$1, $2, $3].each { g.V(t0).addE($4).to(g.V(it)).next() }; t1", wantBinds: dsl.Bindings{"$0": "person", "$1": 1, "$2": 2, "$3": 3, "$4": "knows"}, }, { input: g.V().HasLabel("person"). 
Choose(__.Values("age").Is(p.LTE(20))), wantQuery: "g.V().hasLabel($0).choose(__.values($1).is(lte($2)))", wantBinds: dsl.Bindings{"$0": "person", "$1": "age", "$2": 20}, }, { input: g.AddV("person").Property("name", "a8m").Properties(), wantQuery: "g.addV($0).property($1, $2).properties()", wantBinds: dsl.Bindings{"$0": "person", "$1": "name", "$2": "a8m"}, }, { input: func() *dsl.Traversal { v1 := g.AddV("person").Next() e1 := g.V(v1).AddE("knows").To(g.V(2).Next()) return dsl.Group(v1, e1, g.V(v1).ValueMap(true)) }(), wantQuery: "t0 = g.addV($0).next(); t1 = g.V(t0).addE($1).to(g.V($2).next()); t2 = g.V(t0).valueMap($3); t2", wantBinds: dsl.Bindings{"$0": "person", "$1": "knows", "$2": 2, "$3": true}, }, { input: func() *dsl.Traversal { vs := g.V().HasLabel("person").ToList() edge := g.V(vs).AddE("assoc").To(g.V(1)).Iterate() each := dsl.Each(vs, func(it *dsl.Traversal) *dsl.Traversal { return g.V(1).AddE("inverse").To(it).Next() }) return dsl.Group(vs, edge, each) }(), wantQuery: "t0 = g.V().hasLabel($0).toList(); t1 = g.V(t0).addE($1).to(g.V($2)).iterate(); t2 = t0.each { g.V($3).addE($4).to(it).next() }; t2", wantBinds: dsl.Bindings{"$0": "person", "$1": "assoc", "$2": 1, "$3": 1, "$4": "inverse"}, }, { input: g.V().Where(__.Or(__.Has("age", 29), __.Has("age", 30))), wantQuery: "g.V().where(__.or(__.has($0, $1), __.has($2, $3)))", wantBinds: dsl.Bindings{"$0": "age", "$1": 29, "$2": "age", "$3": 30}, }, { input: g.V().Has("name", p.Containing("le")).Has("name", p.StartingWith("A")), wantQuery: `g.V().has($0, containing($1)).has($2, startingWith($3))`, wantBinds: dsl.Bindings{"$0": "name", "$1": "le", "$2": "name", "$3": "A"}, }, { input: g.AddV().Property(dsl.Single, "age", 32).ValueMap(), wantQuery: "g.addV().property(single, $0, $1).valueMap()", wantBinds: dsl.Bindings{"$0": "age", "$1": 32}, }, { input: g.V().Count(), wantQuery: "g.V().count()", wantBinds: dsl.Bindings{}, }, { input: g.V().HasNot("age"), wantQuery: "g.V().hasNot($0)", wantBinds: dsl.Bindings{"$0": "age"}, }, { input: func() *dsl.Traversal { v := g.V().HasID(1) u := v.Clone().InE().Drop() return dsl.Join(v, u) }(), wantQuery: "g.V().hasId($0); g.V().hasId($1).inE().drop()", wantBinds: dsl.Bindings{"$0": 1, "$1": 1}, }, { input: func() *dsl.Traversal { v := g.V().HasID(1) u := v.Clone().InE().Drop() w := u.Clone() return dsl.Join(v, u, w) }(), wantQuery: "g.V().hasId($0); g.V().hasId($1).inE().drop(); g.V().hasId($2).inE().drop()", wantBinds: dsl.Bindings{"$0": 1, "$1": 1, "$2": 1}, }, { input: g.V().OutE("knows").Where(__.InV().Has("name", "a8m")).OutV(), wantQuery: "g.V().outE($0).where(__.inV().has($1, $2)).outV()", wantBinds: dsl.Bindings{"$0": "knows", "$1": "name", "$2": "a8m"}, }, { input: g.V().Has("name", p.Within("a8m", "alex")), wantQuery: "g.V().has($0, within($1, $2))", wantBinds: dsl.Bindings{"$0": "name", "$1": "a8m", "$2": "alex"}, }, { input: g.V().HasID(p.Within(1, 2)), wantQuery: "g.V().hasId(within($0, $1))", wantBinds: dsl.Bindings{"$0": 1, "$1": 2}, }, { input: g.V().HasID(p.Without(1, 2)), wantQuery: "g.V().hasId(without($0, $1))", wantBinds: dsl.Bindings{"$0": 1, "$1": 2}, }, { input: g.V().Order().By("name"), wantQuery: "g.V().order().by($0)", wantBinds: dsl.Bindings{"$0": "name"}, }, { input: g.V().Order().By("name", dsl.Incr), wantQuery: "g.V().order().by($0, incr)", wantBinds: dsl.Bindings{"$0": "name"}, }, { input: g.V().Order().By("name", dsl.Incr).Undo(), wantQuery: "g.V().order()", wantBinds: dsl.Bindings{}, }, { input: g.V().OutE("knows").Where(__.InV().Has("name", "a8m")).Undo(), 
wantQuery: "g.V().outE($0)", wantBinds: dsl.Bindings{"$0": "knows"}, }, { input: g.V().Has("name").Group().By("name").By("age").Select(dsl.Values), wantQuery: "g.V().has($0).group().by($1).by($2).select(values)", wantBinds: dsl.Bindings{"$0": "name", "$1": "name", "$2": "age"}, }, { input: g.V().Fold().Unfold(), wantQuery: "g.V().fold().unfold()", wantBinds: dsl.Bindings{}, }, { input: g.V().Has("person", "name", "a8m").Count().Coalesce( __.Is(p.NEQ(0)).Constant("unique constraint failed"), g.AddV("person").Property("name", "a8m").ValueMap(true), ), wantQuery: "g.V().has($0, $1, $2).count().coalesce(__.is(neq($3)).constant($4), g.addV($5).property($6, $7).valueMap($8))", wantBinds: dsl.Bindings{"$0": "person", "$1": "name", "$2": "a8m", "$3": 0, "$4": "unique constraint failed", "$5": "person", "$6": "name", "$7": "a8m", "$8": true}, }, { input: g.V().Has("age").Property("age", __.Union(__.Values("age"), __.Constant(10)).Sum()).ValueMap(), wantQuery: "g.V().has($0).property($1, __.union(__.values($2), __.constant($3)).sum()).valueMap()", wantBinds: dsl.Bindings{"$0": "age", "$1": "age", "$2": "age", "$3": 10}, }, { input: g.V().Has("age").SideEffect(__.Properties("name").Drop()).ValueMap(), wantQuery: "g.V().has($0).sideEffect(__.properties($1).drop()).valueMap()", wantBinds: dsl.Bindings{"$0": "age", "$1": "name"}, }, } for i, tt := range tests { tt := tt t.Run(strconv.Itoa(i), func(t *testing.T) { query, bindings := tt.input.Query() require.Equal(t, tt.wantQuery, query) require.Equal(t, tt.wantBinds, bindings) }) } } ent-0.11.3/dialect/gremlin/graph/dsl/g/000077500000000000000000000000001431500740500175345ustar00rootroot00000000000000ent-0.11.3/dialect/gremlin/graph/dsl/g/g.go000066400000000000000000000013261431500740500203130ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package g import "entgo.io/ent/dialect/gremlin/graph/dsl" // V is the api for calling g.V(). func V(args ...any) *dsl.Traversal { return dsl.NewTraversal().V(args...) } // E is the api for calling g.E(). func E(args ...any) *dsl.Traversal { return dsl.NewTraversal().E(args...) } // AddV is the api for calling g.AddV(). func AddV(args ...any) *dsl.Traversal { return dsl.NewTraversal().AddV(args...) } // AddE is the api for calling g.AddE(). func AddE(args ...any) *dsl.Traversal { return dsl.NewTraversal().AddE(args...) } ent-0.11.3/dialect/gremlin/graph/dsl/p/000077500000000000000000000000001431500740500175455ustar00rootroot00000000000000ent-0.11.3/dialect/gremlin/graph/dsl/p/p.go000066400000000000000000000041771431500740500203440ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package p import ( "entgo.io/ent/dialect/gremlin/graph/dsl" ) // EQ is the equal predicate. func EQ(v any) *dsl.Traversal { return op("eq", v) } // NEQ is the not-equal predicate. func NEQ(v any) *dsl.Traversal { return op("neq", v) } // GT is the greater than predicate. func GT(v any) *dsl.Traversal { return op("gt", v) } // GTE is the greater than or equal predicate. func GTE(v any) *dsl.Traversal { return op("gte", v) } // LT is the less than predicate. func LT(v any) *dsl.Traversal { return op("lt", v) } // LTE is the less than or equal predicate. 
func LTE(v any) *dsl.Traversal { return op("lte", v) } // Between is the between/contains predicate. func Between(v, u any) *dsl.Traversal { return op("between", v, u) } // StartingWith is the prefix test predicate. func StartingWith(prefix string) *dsl.Traversal { return op("startingWith", prefix) } // EndingWith is the suffix test predicate. func EndingWith(suffix string) *dsl.Traversal { return op("endingWith", suffix) } // Containing is the sub string test predicate. func Containing(substr string) *dsl.Traversal { return op("containing", substr) } // NotStartingWith is the negation of StartingWith. func NotStartingWith(prefix string) *dsl.Traversal { return op("notStartingWith", prefix) } // NotEndingWith is the negation of EndingWith. func NotEndingWith(suffix string) *dsl.Traversal { return op("notEndingWith", suffix) } // NotContaining is the negation of Containing. func NotContaining(substr string) *dsl.Traversal { return op("notContaining", substr) } // Within determines if a value is within the specified list of values. func Within(args ...any) *dsl.Traversal { return op("within", args...) } // Without determines if a value is not within the specified list of values. func Without(args ...any) *dsl.Traversal { return op("without", args...) } func op(name string, args ...any) *dsl.Traversal { t := &dsl.Traversal{} return t.Add(dsl.NewFunc(name, args...)) } ent-0.11.3/dialect/gremlin/graph/dsl/traversal.go000066400000000000000000000320471431500740500216460ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package dsl import ( "fmt" "strings" ) // Traversal mimics the TinkerPop graph traversal. type Traversal struct { // nodes holds the dsl nodes. The first element is the reference name // of the TinkerGraph. It defaults to "g". nodes []Node } // NewTraversal returns a new default traversal with "g" as a reference name to the Graph. func NewTraversal() *Traversal { return &Traversal{[]Node{G}} } // Group groups a list of traversals into one. All traversals are assigned to temporary // variables named by their index. The last variable functions as a return value of the query. // Note that this "temporary hack" is not perfect and may not work in some cases because of // the limitation of evaluation order. func Group(trs ...*Traversal) *Traversal { var ( b = Block{} names = make(map[*Traversal]Token) ) for i, tr := range trs { if _, ok := names[tr]; ok { continue } v := &Var{Name: fmt.Sprintf("t%d", i), Elem: &Traversal{nodes: tr.nodes}} b.Nodes = append(b.Nodes, v) names[tr] = Token(v.Name) } for _, tr := range trs { tr.nodes = []Node{names[tr]} } b.Nodes = append(b.Nodes, names[trs[len(trs)-1]]) return &Traversal{[]Node{b}} } // Join joins a list of traversals with a semicolon separator. func Join(trs ...*Traversal) *Traversal { b := Block{} for _, tr := range trs { b.Nodes = append(b.Nodes, &Traversal{nodes: tr.nodes}) } return &Traversal{[]Node{b}} } // V step is usually used to start a traversal but it may also be used mid-traversal. func (t *Traversal) V(args ...any) *Traversal { t.Add(Dot, NewFunc("V", args...)) return t } // OtherV maps the Edge to the incident vertex that was not just traversed from in the path history. func (t *Traversal) OtherV() *Traversal { t.Add(Dot, NewFunc("otherV")) return t } // E step is usually used to start a traversal but it may also be used mid-traversal.
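// For example (an illustrative sketch, not part of the original source): // g.E() compiles to "g.E()", and g.E(11) compiles to "g.E($0)" with the binding $0 = 11.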
func (t *Traversal) E(args ...any) *Traversal { t.Add(Dot, NewFunc("E", args...)) return t } // AddV adds a vertex. func (t *Traversal) AddV(args ...any) *Traversal { t.Add(Dot, NewFunc("addV", args...)) return t } // AddE adds an edge. func (t *Traversal) AddE(args ...any) *Traversal { t.Add(Dot, NewFunc("addE", args...)) return t } // Next gets the next n-number of results from the traversal. func (t *Traversal) Next() *Traversal { return t.Add(Dot, NewFunc("next")) } // Drop removes elements and properties from the graph. func (t *Traversal) Drop() *Traversal { return t.Add(Dot, NewFunc("drop")) } // Property sets a Property value and related meta properties if supplied, // if supported by the Graph and if the Element is a VertexProperty. func (t *Traversal) Property(args ...any) *Traversal { return t.Add(Dot, NewFunc("property", args...)) } // Both maps the Vertex to its adjacent vertices given the edge labels. func (t *Traversal) Both(args ...any) *Traversal { return t.Add(Dot, NewFunc("both", args...)) } // BothE maps the Vertex to its incident edges given the edge labels. func (t *Traversal) BothE(args ...any) *Traversal { return t.Add(Dot, NewFunc("bothE", args...)) } // Has filters vertices, edges and vertex properties based on their properties. // See: http://tinkerpop.apache.org/docs/current/reference/#has-step. func (t *Traversal) Has(args ...any) *Traversal { return t.Add(Dot, NewFunc("has", args...)) } // HasNot filters vertices, edges and vertex properties based on the non-existence of properties. // See: http://tinkerpop.apache.org/docs/current/reference/#has-step. func (t *Traversal) HasNot(args ...any) *Traversal { return t.Add(Dot, NewFunc("hasNot", args...)) } // HasID filters vertices, edges and vertex properties based on their identifier. func (t *Traversal) HasID(args ...any) *Traversal { return t.Add(Dot, NewFunc("hasId", args...)) } // HasLabel filters vertices, edges and vertex properties based on their label. func (t *Traversal) HasLabel(args ...any) *Traversal { return t.Add(Dot, NewFunc("hasLabel", args...)) } // HasNext returns true if the iteration has more elements. func (t *Traversal) HasNext() *Traversal { return t.Add(Dot, NewFunc("hasNext")) } // Match maps the Traverser to a Map of bindings as specified by the provided match traversals. func (t *Traversal) Match(args ...any) *Traversal { return t.Add(Dot, NewFunc("match", args...)) } // Choose routes the current traverser to a particular traversal branch option which allows the creation of if-then-else like semantics within a traversal. func (t *Traversal) Choose(args ...any) *Traversal { return t.Add(Dot, NewFunc("choose", args...)) } // Select arbitrary values from the traversal. func (t *Traversal) Select(args ...any) *Traversal { return t.Add(Dot, NewFunc("select", args...)) } // Group organizes objects in the stream into a Map. Calls to group() are typically accompanied with by() modulators which help specify how the grouping should occur. func (t *Traversal) Group() *Traversal { return t.Add(Dot, NewFunc("group")) } // Values maps the Element to the values of the associated properties given the provided property keys. func (t *Traversal) Values(args ...string) *Traversal { return t.Add(Dot, NewFunc("values", sface(args)...)) } // ValueMap maps the Element to a Map of the property values key'd according to their Property.key().
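// For example (a sketch based on this package's tests): // g.AddV("person").Property("name", "a8m").ValueMap() compiles to "g.addV($0).property($1, $2).valueMap()"; calling ValueMap(true) additionally asks the server to include the element id and label in the returned map.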
func (t *Traversal) ValueMap(args ...any) *Traversal { return t.Add(Dot, NewFunc("valueMap", args...)) } // Properties maps the Element to its associated properties given the provided property keys. func (t *Traversal) Properties(args ...any) *Traversal { return t.Add(Dot, NewFunc("properties", args...)) } // Range filters the objects in the traversal by the number of them to pass through the stream. func (t *Traversal) Range(args ...any) *Traversal { return t.Add(Dot, NewFunc("range", args...)) } // Limit filters the objects in the traversal by the number of them to pass through the stream, where only the first n objects are allowed as defined by the limit argument. func (t *Traversal) Limit(args ...any) *Traversal { return t.Add(Dot, NewFunc("limit", args...)) } // ID maps the Element to its Element.id(). func (t *Traversal) ID() *Traversal { return t.Add(Dot, NewFunc("id")) } // Label maps the Element to its Element.label(). func (t *Traversal) Label() *Traversal { return t.Add(Dot, NewFunc("label")) } // From provides from()-modulation to respective steps. func (t *Traversal) From(args ...any) *Traversal { return t.Add(Dot, NewFunc("from", args...)) } // To, when used as a modifier to addE(String), specifies the traversal to use for selecting the incoming vertex of the newly added Edge. func (t *Traversal) To(args ...any) *Traversal { return t.Add(Dot, NewFunc("to", args...)) } // As provides a label to the step that can be accessed later in the traversal by other steps. func (t *Traversal) As(args ...any) *Traversal { return t.Add(Dot, NewFunc("as", args...)) } // Or ensures that at least one of the provided traversals yields a result. func (t *Traversal) Or(args ...any) *Traversal { return t.Add(Dot, NewFunc("or", args...)) } // And ensures that all of the provided traversals yield a result. func (t *Traversal) And(args ...any) *Traversal { return t.Add(Dot, NewFunc("and", args...)) } // Is filters the E object if it is not P.eq(V) to the provided value. func (t *Traversal) Is(args ...any) *Traversal { return t.Add(Dot, NewFunc("is", args...)) } // Not removes objects from the traversal stream when the traversal provided as an argument does not return any objects. func (t *Traversal) Not(args ...any) *Traversal { return t.Add(Dot, NewFunc("not", args...)) } // In maps the Vertex to its incoming adjacent vertices given the edge labels. func (t *Traversal) In(args ...any) *Traversal { return t.Add(Dot, NewFunc("in", args...)) } // Where filters the current object based on the object itself or the path history. func (t *Traversal) Where(args ...any) *Traversal { return t.Add(Dot, NewFunc("where", args...)) } // Out maps the Vertex to its outgoing adjacent vertices given the edge labels. func (t *Traversal) Out(args ...any) *Traversal { return t.Add(Dot, NewFunc("out", args...)) } // OutE maps the Vertex to its outgoing incident edges given the edge labels. func (t *Traversal) OutE(args ...any) *Traversal { return t.Add(Dot, NewFunc("outE", args...)) } // InE maps the Vertex to its incoming incident edges given the edge labels. func (t *Traversal) InE(args ...any) *Traversal { return t.Add(Dot, NewFunc("inE", args...)) } // OutV maps the Edge to its outgoing/tail incident Vertex. func (t *Traversal) OutV(args ...any) *Traversal { return t.Add(Dot, NewFunc("outV", args...)) } // InV maps the Edge to its incoming/head incident Vertex. func (t *Traversal) InV(args ...any) *Traversal { return t.Add(Dot, NewFunc("inV", args...)) } // ToList puts all the results into a Groovy list.
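// For example (a sketch based on this package's tests): // vs := g.V().HasLabel("person").ToList() compiles to "g.V().hasLabel($0).toList()", and vs can then feed later steps such as g.V(vs) or Each(vs, ...) once the traversals are grouped together.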
func (t *Traversal) ToList() *Traversal { return t.Add(Dot, NewFunc("toList")) } // Iterate iterates the traversal presumably for the generation of side-effects. func (t *Traversal) Iterate() *Traversal { return t.Add(Dot, NewFunc("iterate")) } // Count maps the traversal stream to its reduction as a sum of the Traverser.bulk() values // (i.e. count the number of traversers up to this point). func (t *Traversal) Count(args ...any) *Traversal { return t.Add(Dot, NewFunc("count", args...)) } // Order all the objects in the traversal up to this point and then emit them one-by-one in their ordered sequence. func (t *Traversal) Order(args ...any) *Traversal { return t.Add(Dot, NewFunc("order", args...)) } // By can be applied to a number of different steps to alter their behaviors. // This form is essentially an identity() modulation. func (t *Traversal) By(args ...any) *Traversal { return t.Add(Dot, NewFunc("by", args...)) } // Fold rolls up objects in the stream into an aggregate list. func (t *Traversal) Fold() *Traversal { return t.Add(Dot, NewFunc("fold")) } // Unfold unrolls an Iterator, Iterable or Map into a linear form or simply emits the object if it is not one of those types. func (t *Traversal) Unfold() *Traversal { return t.Add(Dot, NewFunc("unfold")) } // Sum maps the traversal stream to its reduction as a sum of the Traverser.get() values multiplied by their Traverser.bulk(). func (t *Traversal) Sum(args ...any) *Traversal { return t.Add(Dot, NewFunc("sum", args...)) } // Mean determines the mean value in the stream. func (t *Traversal) Mean(args ...any) *Traversal { return t.Add(Dot, NewFunc("mean", args...)) } // Min determines the smallest value in the stream. func (t *Traversal) Min(args ...any) *Traversal { return t.Add(Dot, NewFunc("min", args...)) } // Max determines the greatest value in the stream. func (t *Traversal) Max(args ...any) *Traversal { return t.Add(Dot, NewFunc("max", args...)) } // Coalesce evaluates the provided traversals and returns the result of the first traversal to emit at least one object. func (t *Traversal) Coalesce(args ...any) *Traversal { return t.Add(Dot, NewFunc("coalesce", args...)) } // Dedup removes all duplicates in the traversal stream up to this point. func (t *Traversal) Dedup(args ...any) *Traversal { return t.Add(Dot, NewFunc("dedup", args...)) } // Constant maps any object to a fixed E value. func (t *Traversal) Constant(args ...any) *Traversal { return t.Add(Dot, NewFunc("constant", args...)) } // Union merges the results of an arbitrary number of traversals. func (t *Traversal) Union(args ...any) *Traversal { return t.Add(Dot, NewFunc("union", args...)) } // SideEffect allows the traverser to proceed unchanged, but yield some computational // sideEffect in the process. func (t *Traversal) SideEffect(args ...any) *Traversal { return t.Add(Dot, NewFunc("sideEffect", args...)) } // Each is a Groovy each-loop function. func Each(v any, cb func(it *Traversal) *Traversal) *Traversal { t := &Traversal{} switch v := v.(type) { case *Traversal: t.Add(&Var{Elem: v}) case []any: t.Add(NewList(v...)) default: t.Add(Token("undefined")) } t.Add(Dot, Token("each"), Token(" { ")) t.Add(cb(&Traversal{[]Node{Token("it")}}).nodes...) t.Add(Token(" }")) return t } // Add is the public API for adding new nodes to the traversal by its sub packages. func (t *Traversal) Add(n ...Node) *Traversal { t.nodes = append(t.nodes, n...) return t } // Query returns the query representation of this traversal object and its bindings.
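// For example (an illustrative sketch, not part of the original source): // query, bindings := g.V().Has("name", "a8m").Query() yields query == "g.V().has($0, $1)" and bindings == Bindings{"$0": "name", "$1": "a8m"}.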
func (t *Traversal) Query() (string, Bindings) { var ( names []any query strings.Builder bindings = Bindings{} ) for _, n := range t.nodes { code, args := n.Code() query.WriteString(code) for _, arg := range args { names = append(names, bindings.Add(arg)) } } return fmt.Sprintf(query.String(), names...), bindings } // Clone creates a deep copy of an existing traversal. func (t *Traversal) Clone() *Traversal { if t == nil { return nil } return &Traversal{append(make([]Node, 0, len(t.nodes)), t.nodes...)} } // Undo reverts the last step of the traversal. func (t *Traversal) Undo() *Traversal { if n := len(t.nodes); n > 2 { t.nodes = t.nodes[:n-2] } return t } ent-0.11.3/dialect/gremlin/graph/edge.go000066400000000000000000000042311431500740500177570ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graph import ( "fmt" "entgo.io/ent/dialect/gremlin/encoding/graphson" ) type ( // An Edge between two vertices. Edge struct { Element OutV, InV Vertex } // graphson edge repr. edge struct { Element OutV any `json:"outV"` OutVLabel string `json:"outVLabel"` InV any `json:"inV"` InVLabel string `json:"inVLabel"` } ) // NewEdge creates a new graph edge. func NewEdge(id any, label string, outV, inV Vertex) Edge { return Edge{ Element: NewElement(id, label), OutV: outV, InV: inV, } } // String implements fmt.Stringer interface. func (e Edge) String() string { return fmt.Sprintf("e[%v][%v-%s->%v]", e.ID, e.OutV.ID, e.Label, e.InV.ID) } // MarshalGraphson implements graphson.Marshaler interface. func (e Edge) MarshalGraphson() ([]byte, error) { return graphson.Marshal(edge{ Element: e.Element, OutV: e.OutV.ID, OutVLabel: e.OutV.Label, InV: e.InV.ID, InVLabel: e.InV.Label, }) } // UnmarshalGraphson implements graphson.Unmarshaler interface. func (e *Edge) UnmarshalGraphson(data []byte) error { var edge edge if err := graphson.Unmarshal(data, &edge); err != nil { return fmt.Errorf("unmarshaling edge: %w", err) } *e = NewEdge( edge.ID, edge.Label, NewVertex(edge.OutV, edge.OutVLabel), NewVertex(edge.InV, edge.InVLabel), ) return nil } // GraphsonType implements graphson.Typer interface. func (edge) GraphsonType() graphson.Type { return "g:Edge" } // Property denotes a key/value pair associated with an edge. type Property struct { Key string `json:"key"` Value any `json:"value"` } // NewProperty creates a new graph edge property. func NewProperty(key string, value any) Property { return Property{key, value} } // GraphsonType implements graphson.Typer interface. func (Property) GraphsonType() graphson.Type { return "g:Property" } // String implements fmt.Stringer interface. func (p Property) String() string { return fmt.Sprintf("p[%s->%v]", p.Key, p.Value) } ent-0.11.3/dialect/gremlin/graph/edge_test.go000066400000000000000000000041041431500740500210150ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree.
package graph import ( "fmt" "testing" "entgo.io/ent/dialect/gremlin/encoding/graphson" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestEdgeString(t *testing.T) { e := NewEdge( 13, "develops", NewVertex(1, ""), NewVertex(10, ""), ) assert.Equal(t, "e[13][1-develops->10]", fmt.Sprint(e)) } func TestEdgeEncoding(t *testing.T) { t.Parallel() e := NewEdge(13, "develops", NewVertex(1, "person"), NewVertex(10, "software"), ) got, err := graphson.MarshalToString(e) require.NoError(t, err) want := `{ "@type" : "g:Edge", "@value" : { "id" : { "@type" : "g:Int64", "@value" : 13 }, "label" : "develops", "inVLabel" : "software", "outVLabel" : "person", "inV" : { "@type" : "g:Int64", "@value" : 10 }, "outV" : { "@type" : "g:Int64", "@value" : 1 } } }` assert.JSONEq(t, want, got) e = Edge{} err = graphson.UnmarshalFromString(got, &e) require.NoError(t, err) assert.Equal(t, NewElement(int64(13), "develops"), e.Element) assert.Equal(t, NewVertex(int64(1), "person"), e.OutV) assert.Equal(t, NewVertex(int64(10), "software"), e.InV) } func TestPropertyEncoding(t *testing.T) { t.Parallel() props := []Property{ NewProperty("from", int32(2017)), NewProperty("to", int32(2019)), } got, err := graphson.MarshalToString(props) require.NoError(t, err) want := `{ "@type" : "g:List", "@value" : [ { "@type" : "g:Property", "@value" : { "key" : "from", "value" : { "@type" : "g:Int32", "@value" : 2017 } } }, { "@type" : "g:Property", "@value" : { "key" : "to", "value" : { "@type" : "g:Int32", "@value" : 2019 } } } ] }` assert.JSONEq(t, want, got) } func TestPropertyString(t *testing.T) { p := NewProperty("since", 2019) assert.Equal(t, "p[since->2019]", fmt.Sprint(p)) } ent-0.11.3/dialect/gremlin/graph/element.go000066400000000000000000000007161431500740500205100ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graph // Element defines a base struct for graph elements. type Element struct { ID any `json:"id"` Label string `json:"label"` } // NewElement create a new graph element. func NewElement(id any, label string) Element { return Element{id, label} } ent-0.11.3/dialect/gremlin/graph/valuemap.go000066400000000000000000000024371431500740500206730ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graph import ( "errors" "fmt" "reflect" "github.com/mitchellh/mapstructure" ) // ValueMap models a .valueMap() gremlin response. type ValueMap []map[string]any // Decode decodes a value map into v. 
func (m ValueMap) Decode(v any) error { rv := reflect.ValueOf(v) if rv.Kind() != reflect.Ptr { return errors.New("cannot unmarshal into a non-pointer") } if rv.IsNil() { return errors.New("cannot unmarshal into a nil pointer") } if rv.Elem().Kind() != reflect.Slice { v = &[]any{v} } return m.decode(v) } func (m ValueMap) decode(v any) error { cfg := mapstructure.DecoderConfig{ DecodeHook: func(f, t reflect.Kind, data any) (any, error) { if f == reflect.Slice && t != reflect.Slice { rv := reflect.ValueOf(data) if rv.Len() == 1 { data = rv.Index(0).Interface() } } return data, nil }, Result: v, TagName: "json", } dec, err := mapstructure.NewDecoder(&cfg) if err != nil { return fmt.Errorf("creating structure decoder: %w", err) } if err := dec.Decode(m); err != nil { return fmt.Errorf("decoding value map: %w", err) } return nil } ent-0.11.3/dialect/gremlin/graph/valuemap_test.go000066400000000000000000000032261431500740500217270ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graph import ( "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestValueMapDecodeOne(t *testing.T) { vm := ValueMap{map[string]any{ "id": int64(1), "label": "person", "name": []any{"marko"}, "age": []any{int32(29)}, }} var ent struct { ID uint64 `json:"id"` Label string `json:"label"` Name string `json:"name"` Age uint8 `json:"age"` } err := vm.Decode(&ent) require.NoError(t, err) assert.Equal(t, uint64(1), ent.ID) assert.Equal(t, "person", ent.Label) assert.Equal(t, "marko", ent.Name) assert.Equal(t, uint8(29), ent.Age) } func TestValueMapDecodeMany(t *testing.T) { vm := ValueMap{ map[string]any{ "id": int64(1), "label": "person", "name": []any{"chico"}, }, map[string]any{ "id": int64(2), "label": "person", "name": []any{"dico"}, }, } ents := []struct { ID int `json:"id"` Label string `json:"label"` Name string `json:"name"` }{} err := vm.Decode(&ents) require.NoError(t, err) require.Len(t, ents, 2) assert.Equal(t, 1, ents[0].ID) assert.Equal(t, "person", ents[0].Label) assert.Equal(t, "chico", ents[0].Name) assert.Equal(t, 2, ents[1].ID) assert.Equal(t, "person", ents[1].Label) assert.Equal(t, "dico", ents[1].Name) } func TestValueMapDecodeBadInput(t *testing.T) { type s struct{ Name string } err := ValueMap{}.Decode(s{}) assert.Error(t, err) err = ValueMap{}.Decode((*s)(nil)) assert.Error(t, err) } ent-0.11.3/dialect/gremlin/graph/vertex.go000066400000000000000000000026311431500740500203720ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package graph import ( "fmt" "entgo.io/ent/dialect/gremlin/encoding/graphson" ) // Vertex represents a graph vertex. type Vertex struct { Element } // NewVertex creates a new graph vertex. func NewVertex(id any, label string) Vertex { if label == "" { label = "vertex" } return Vertex{ Element: NewElement(id, label), } } // GraphsonType implements graphson.Typer interface. func (Vertex) GraphsonType() graphson.Type { return "g:Vertex" } // String implements fmt.Stringer interface. func (v Vertex) String() string { return fmt.Sprintf("v[%v]", v.ID) } // VertexProperty denotes a key/value pair associated with a vertex.
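// On the wire (see TestVertexPropertyEncoding below) a vertex property is encoded as, e.g.: {"@type": "g:VertexProperty", "@value": {"id": "46ab60c2-918c-4cc4-a13b-350510e8908a", "label": "name", "value": "alex"}}.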
type VertexProperty struct {
	ID    any    `json:"id"`
	Key   string `json:"label"`
	Value any    `json:"value"`
}

// NewVertexProperty creates a new graph vertex property.
func NewVertexProperty(id any, key string, value any) VertexProperty {
	return VertexProperty{
		ID:    id,
		Key:   key,
		Value: value,
	}
}

// GraphsonType implements graphson.Typer interface.
func (VertexProperty) GraphsonType() graphson.Type {
	return "g:VertexProperty"
}

// String implements fmt.Stringer interface.
func (vp VertexProperty) String() string {
	return fmt.Sprintf("vp[%s->%v]", vp.Key, vp.Value)
}
ent-0.11.3/dialect/gremlin/graph/vertex_test.go000066400000000000000000000036051431500740500214330ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved.
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.

package graph

import (
	"fmt"
	"testing"

	"entgo.io/ent/dialect/gremlin/encoding/graphson"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestVertexCreation(t *testing.T) {
	v := NewVertex(45, "person")
	assert.Equal(t, 45, v.ID)
	assert.Equal(t, "person", v.Label)
	v = NewVertex(46, "")
	assert.Equal(t, "vertex", v.Label)
}

func TestVertexString(t *testing.T) {
	v := NewVertex(42, "")
	assert.Equal(t, "v[42]", fmt.Sprint(v))
}

func TestVertexEncoding(t *testing.T) {
	t.Parallel()
	v := NewVertex(1, "user")
	got, err := graphson.MarshalToString(v)
	require.NoError(t, err)
	want := `{
		"@type" : "g:Vertex",
		"@value" : {
			"id" : { "@type" : "g:Int64", "@value" : 1 },
			"label" : "user"
		}
	}`
	assert.JSONEq(t, want, got)

	v = Vertex{}
	err = graphson.UnmarshalFromString(got, &v)
	require.NoError(t, err)
	assert.Equal(t, int64(1), v.ID)
	assert.Equal(t, "user", v.Label)
}

func TestVertexPropertyEncoding(t *testing.T) {
	t.Parallel()
	vp := NewVertexProperty("46ab60c2-918c-4cc4-a13b-350510e8908a", "name", "alex")
	got, err := graphson.MarshalToString(vp)
	require.NoError(t, err)
	want := `{
		"@type" : "g:VertexProperty",
		"@value" : {
			"id" : "46ab60c2-918c-4cc4-a13b-350510e8908a",
			"label": "name",
			"value": "alex"
		}
	}`
	assert.JSONEq(t, want, got)

	vp = VertexProperty{}
	err = graphson.UnmarshalFromString(got, &vp)
	require.NoError(t, err)
	assert.Equal(t, "46ab60c2-918c-4cc4-a13b-350510e8908a", vp.ID)
	assert.Equal(t, "name", vp.Key)
	assert.Equal(t, "alex", vp.Value)
}

func TestVertexPropertyString(t *testing.T) {
	vp := NewVertexProperty(55, "country", "israel")
	assert.Equal(t, "vp[country->israel]", fmt.Sprint(vp))
}
ent-0.11.3/dialect/gremlin/http.go000066400000000000000000000044361431500740500167400ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved.
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.

package gremlin

import (
	"context"
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/url"

	"entgo.io/ent/dialect/gremlin/encoding/graphson"
	jsoniter "github.com/json-iterator/go"
)

type httpTransport struct {
	client *http.Client
	url    string
}

// NewHTTPTransport returns a new http transport.
func NewHTTPTransport(urlStr string, client *http.Client) (RoundTripper, error) {
	u, err := url.Parse(urlStr)
	if err != nil {
		return nil, fmt.Errorf("gremlin/http: parsing url: %w", err)
	}
	if client == nil {
		client = http.DefaultClient
	}
	return &httpTransport{client, u.String()}, nil
}

// RoundTrip implements the RoundTripper interface.
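//
// Usage sketch (illustrative; the endpoint URL is an assumption). The
// transport accepts only "eval" operations and posts the request arguments
// as JSON to the Gremlin HTTP endpoint:
//
//	rt, err := NewHTTPTransport("http://localhost:8182/gremlin", nil)
//	if err != nil {
//		log.Fatal(err)
//	}
//	rsp, err := rt.RoundTrip(context.Background(), NewEvalRequest("g.V().count()"))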
func (t *httpTransport) RoundTrip(ctx context.Context, req *Request) (*Response, error) {
	if req.Operation != OpsEval {
		return nil, fmt.Errorf("gremlin/http: unsupported operation: %q", req.Operation)
	}
	if _, ok := req.Arguments[ArgsGremlin]; !ok {
		return nil, errors.New("gremlin/http: missing query expression")
	}

	pr, pw := io.Pipe()
	defer pr.Close()
	go func() {
		err := jsoniter.NewEncoder(pw).Encode(req.Arguments)
		if err != nil {
			err = fmt.Errorf("gremlin/http: encoding request: %w", err)
		}
		_ = pw.CloseWithError(err)
	}()

	var br io.Reader
	{
		req, err := http.NewRequest(http.MethodPost, t.url, pr)
		if err != nil {
			return nil, fmt.Errorf("gremlin/http: creating http request: %w", err)
		}
		req.Header.Set("Content-Type", "application/json")

		rsp, err := t.client.Do(req.WithContext(ctx))
		if err != nil {
			return nil, fmt.Errorf("gremlin/http: posting http request: %w", err)
		}
		defer rsp.Body.Close()

		if rsp.StatusCode < http.StatusOK || rsp.StatusCode > http.StatusPartialContent {
			body, _ := io.ReadAll(rsp.Body)
			return nil, fmt.Errorf("gremlin/http: status=%q, body=%q", rsp.Status, body)
		}
		if rsp.ContentLength > MaxResponseSize {
			return nil, errors.New("gremlin/http: content length exceeds limit")
		}
		br = rsp.Body
	}

	var rsp Response
	if err := graphson.NewDecoder(io.LimitReader(br, MaxResponseSize)).Decode(&rsp); err != nil {
		return nil, fmt.Errorf("gremlin/http: decoding response: %w", err)
	}
	return &rsp, nil
}
ent-0.11.3/dialect/gremlin/http_test.go000066400000000000000000000070331431500740500177730ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved.
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.

package gremlin

import (
	"context"
	"io"
	"net/http"
	"net/http/httptest"
	"testing"

	"entgo.io/ent/dialect/gremlin/encoding/graphson"
	jsoniter "github.com/json-iterator/go"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestHTTPTransportRoundTripper(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		assert.Equal(t, r.Header.Get("Content-Type"), "application/json")
		got, err := io.ReadAll(r.Body)
		require.NoError(t, err)
		assert.JSONEq(t, `{"gremlin": "g.V(1)", "language": "gremlin-groovy"}`, string(got))

		_, err = io.WriteString(w, `{
			"requestId": "f679127f-8701-425c-af55-049a44720db6",
			"result": {
				"data": {
					"@type": "g:List",
					"@value": [
						{
							"@type": "g:Vertex",
							"@value": {
								"id": { "@type": "g:Int64", "@value": 1 },
								"label": "person"
							}
						}
					]
				},
				"meta": { "@type": "g:Map", "@value": [] }
			},
			"status": {
				"attributes": { "@type": "g:Map", "@value": [] },
				"code": 200,
				"message": ""
			}
		}`)
		require.NoError(t, err)
	}))
	defer srv.Close()

	transport, err := NewHTTPTransport(srv.URL, nil)
	require.NoError(t, err)
	rsp, err := transport.RoundTrip(context.Background(), NewEvalRequest("g.V(1)"))
	require.NoError(t, err)

	assert.Equal(t, "f679127f-8701-425c-af55-049a44720db6", rsp.RequestID)
	assert.Equal(t, 200, rsp.Status.Code)
	assert.Empty(t, rsp.Status.Message)

	v := jsoniter.Get(rsp.Result.Data, graphson.ValueKey, 0, graphson.ValueKey)
	require.NoError(t, v.LastError())
	assert.Equal(t, 1, v.Get("id", graphson.ValueKey).ToInt())
	assert.Equal(t, "person", v.Get("label").ToString())
}

func TestNewHTTPTransportBadURL(t *testing.T) {
	transport, err := NewHTTPTransport(":", nil)
	assert.Nil(t, transport)
	assert.Error(t, err)
}

func TestHTTPTransportBadRequest(t *testing.T) {
	transport, err := NewHTTPTransport("example.com", nil)
require.NoError(t, err) req := NewEvalRequest("g.V()") req.Operation = "" rsp, err := transport.RoundTrip(context.Background(), req) assert.EqualError(t, err, `gremlin/http: unsupported operation: ""`) assert.Nil(t, rsp) req = NewEvalRequest("g.V()") delete(req.Arguments, ArgsGremlin) rsp, err = transport.RoundTrip(context.Background(), req) assert.EqualError(t, err, "gremlin/http: missing query expression") assert.Nil(t, rsp) } func TestHTTPTransportBadResponseStatus(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusInternalServerError) })) defer srv.Close() transport, err := NewHTTPTransport(srv.URL, nil) require.NoError(t, err) _, err = transport.RoundTrip(context.Background(), NewEvalRequest("g.E().")) require.Error(t, err) assert.Contains(t, err.Error(), http.StatusText(http.StatusInternalServerError)) } func TestHTTPTransportBadResponseBody(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { _, err := io.WriteString(w, "{{{") require.NoError(t, err) })) defer srv.Close() transport, err := NewHTTPTransport(srv.URL, nil) require.NoError(t, err) _, err = transport.RoundTrip(context.Background(), NewEvalRequest("g.E().")) require.Error(t, err) assert.Contains(t, err.Error(), "decoding response") } ent-0.11.3/dialect/gremlin/internal/000077500000000000000000000000001431500740500172375ustar00rootroot00000000000000ent-0.11.3/dialect/gremlin/internal/ws/000077500000000000000000000000001431500740500176705ustar00rootroot00000000000000ent-0.11.3/dialect/gremlin/internal/ws/conn.go000066400000000000000000000172511431500740500211620ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package ws import ( "bytes" "context" "errors" "fmt" "io" "net/http" "sync" "time" "entgo.io/ent/dialect/gremlin" "entgo.io/ent/dialect/gremlin/encoding" "entgo.io/ent/dialect/gremlin/encoding/graphson" "github.com/gorilla/websocket" "golang.org/x/sync/errgroup" ) const ( // Time allowed to write a message to the peer. writeWait = 5 * time.Second // Time allowed to read the next pong message from the peer. pongWait = 10 * time.Second // Send pings to peer with this period. Must be less than pongWait. pingPeriod = (pongWait * 9) / 10 ) type ( // A Dialer contains options for connecting to Gremlin server. Dialer struct { // Underlying websocket dialer. websocket.Dialer // Gremlin server basic auth credentials. user, pass string } // Conn performs operations on a gremlin server. Conn struct { // Underlying websocket connection. conn *websocket.Conn // Credentials for basic authentication. user, pass string // Goroutine tracking. ctx context.Context grp *errgroup.Group // Channel of outbound requests. send chan io.Reader // Map of in flight requests. inflight sync.Map } // inflight tracks request state. inflight struct { // partially received data frags []graphson.RawMessage // response channel result chan<- result } // represents an execution result. result struct { rsp *gremlin.Response err error } ) var ( // DefaultDialer is a dialer with all fields set to the default values. 
	DefaultDialer = &Dialer{
		Dialer: websocket.Dialer{
			Proxy:            http.ProxyFromEnvironment,
			HandshakeTimeout: 5 * time.Second,
			WriteBufferSize:  8192,
			ReadBufferSize:   8192,
		},
	}

	// ErrConnClosed is returned by the Conn's Execute method when
	// the underlying gremlin server connection is closed.
	ErrConnClosed = errors.New("gremlin: server connection closed")

	// ErrDuplicateRequest is returned by the Conn's Execute method on
	// request identifier key collision.
	ErrDuplicateRequest = errors.New("gremlin: duplicate request")
)

// Dial creates a new connection by calling DialContext with a background context.
func (d *Dialer) Dial(uri string) (*Conn, error) {
	return d.DialContext(context.Background(), uri)
}

// DialContext creates a new Gremlin connection.
func (d *Dialer) DialContext(ctx context.Context, uri string) (*Conn, error) {
	c, rsp, err := d.Dialer.DialContext(ctx, uri, nil)
	if err != nil {
		return nil, fmt.Errorf("gremlin: dialing uri %s: %w", uri, err)
	}
	defer rsp.Body.Close()
	conn := &Conn{
		conn: c,
		user: d.user,
		pass: d.pass,
		send: make(chan io.Reader),
	}
	conn.grp, conn.ctx = errgroup.WithContext(context.Background())
	conn.grp.Go(conn.sender)
	conn.grp.Go(conn.receiver)
	return conn, nil
}

// Execute executes a request against a Gremlin server.
func (c *Conn) Execute(ctx context.Context, req *gremlin.Request) (*gremlin.Response, error) {
	// a buffered result channel prevents the receiver from blocking on context cancellation.
	result := make(chan result, 1)
	// the request id must be unique across in-flight requests.
	if _, loaded := c.inflight.LoadOrStore(req.RequestID, &inflight{result: result}); loaded {
		return nil, ErrDuplicateRequest
	}

	pr, pw := io.Pipe()
	defer pr.Close()

	// stream the graphson encoding into the request.
	c.grp.Go(func() error {
		err := graphson.NewEncoder(pw).Encode(req)
		if err != nil {
			err = fmt.Errorf("encoding request: %w", err)
		}
		pw.CloseWithError(err)
		return err
	})

	// local copy for a single write.
	send := c.send
	for {
		select {
		case <-c.ctx.Done():
			c.inflight.Delete(req.RequestID)
			return nil, ErrConnClosed
		case <-ctx.Done():
			c.inflight.Delete(req.RequestID)
			return nil, ctx.Err()
		case send <- pr:
			send = nil
		case result := <-result:
			return result.rsp, result.err
		}
	}
}

// Close closes the connection to the Gremlin server.
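//
// Shutdown sketch (illustrative; the endpoint URL is an assumption): Close
// stops the sender goroutine, which closes the websocket and completes any
// in-flight requests with ErrConnClosed:
//
//	conn, err := DefaultDialer.Dial("ws://localhost:8182/gremlin")
//	if err != nil {
//		log.Fatal(err)
//	}
//	defer conn.Close()
//	rsp, err := conn.Execute(context.Background(), gremlin.NewEvalRequest("g.V().count()"))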
func (c *Conn) Close() error {
	c.grp.Go(func() error {
		return ErrConnClosed
	})
	_ = c.grp.Wait()
	return nil
}

func (c *Conn) sender() error {
	pinger := time.NewTicker(pingPeriod)
	defer pinger.Stop()

	// closing the connection terminates the receiver.
	defer c.conn.Close()

	for {
		select {
		case r := <-c.send:
			// ensure the write completes within a window.
			c.conn.SetWriteDeadline(time.Now().Add(writeWait))
			// fetch the next message writer.
			w, err := c.conn.NextWriter(websocket.BinaryMessage)
			if err != nil {
				return fmt.Errorf("getting message writer: %w", err)
			}
			// write the mime header.
			if _, err := w.Write(encoding.GraphSON3Mime); err != nil {
				return fmt.Errorf("writing mime header: %w", err)
			}
			// write the request body.
			if _, err := io.Copy(w, r); err != nil {
				return fmt.Errorf("writing request: %w", err)
			}
			// finish the message write.
			if err := w.Close(); err != nil {
				return fmt.Errorf("closing message writer: %w", err)
			}
		case <-c.ctx.Done():
			// connection closing.
			return c.conn.WriteControl(
				websocket.CloseMessage,
				websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""),
				time.Time{},
			)
		case <-pinger.C:
			// periodic connection keepalive.
			if err := c.conn.WriteControl(websocket.PingMessage, nil, time.Now().Add(writeWait)); err != nil {
				return fmt.Errorf("writing ping message: %w", err)
			}
		}
	}
}

func (c *Conn) receiver() error {
	// handle keepalive responses.
	c.conn.SetReadDeadline(time.Now().Add(pongWait))
	c.conn.SetPongHandler(func(string) error {
		return c.conn.SetReadDeadline(time.Now().Add(pongWait))
	})

	// complete all in-flight requests on termination.
	defer c.inflight.Range(func(id, ifr any) bool {
		ifr.(*inflight).result <- result{err: ErrConnClosed}
		c.inflight.Delete(id)
		return true
	})

	for {
		// rely on sender connection close during termination.
		_, r, err := c.conn.NextReader()
		if err != nil {
			return fmt.Errorf("getting message reader: %w", err)
		}
		// decode the received response.
		var rsp gremlin.Response
		if err := graphson.NewDecoder(r).Decode(&rsp); err != nil {
			return fmt.Errorf("reading response: %w", err)
		}
		ifr, ok := c.inflight.Load(rsp.RequestID)
		if !ok {
			// context cancellation aborts in-flight requests.
			continue
		}
		// handle the incoming response.
		if done := c.receive(ifr.(*inflight), &rsp); done {
			// stop tracking finished requests.
			c.inflight.Delete(rsp.RequestID)
		}
	}
}

func (c *Conn) receive(ifr *inflight, rsp *gremlin.Response) bool {
	result := result{rsp: rsp}
	switch rsp.Status.Code {
	case gremlin.StatusSuccess:
		// quickly handle non-fragmented responses.
		if ifr.frags == nil {
			break
		}
		// handle fragment.
		fallthrough
	case gremlin.StatusPartialContent:
		// append the received fragment.
		var frag []graphson.RawMessage
		if err := graphson.Unmarshal(rsp.Result.Data, &frag); err != nil {
			result.err = fmt.Errorf("decoding response fragment: %w", err)
			break
		}
		ifr.frags = append(ifr.frags, frag...)
// partial response requires additional fragments if rsp.Status.Code == gremlin.StatusPartialContent { return false } // reassemble fragmented response if rsp.Result.Data, result.err = graphson.Marshal(ifr.frags); result.err != nil { result.err = fmt.Errorf("assembling fragmented response: %w", result.err) } case gremlin.StatusAuthenticate: // receiver should never block c.grp.Go(func() error { var buf bytes.Buffer if err := graphson.NewEncoder(&buf).Encode( gremlin.NewAuthRequest(rsp.RequestID, c.user, c.pass), ); err != nil { return fmt.Errorf("encoding auth request: %w", err) } select { case c.send <- &buf: case <-c.ctx.Done(): } return c.ctx.Err() }) return false } ifr.result <- result return true } ent-0.11.3/dialect/gremlin/internal/ws/conn_test.go000066400000000000000000000226141431500740500222200ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package ws import ( "context" "net/http" "net/http/httptest" "strconv" "sync" "testing" "entgo.io/ent/dialect/gremlin" "entgo.io/ent/dialect/gremlin/encoding/graphson" "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) type conn struct{ *websocket.Conn } func (c conn) ReadRequest() (*gremlin.Request, error) { _, data, err := c.ReadMessage() if err != nil { return nil, err } var req gremlin.Request if err := graphson.Unmarshal(data[data[0]+1:], &req); err != nil { return nil, err } return &req, nil } func (c conn) WriteResponse(rsp *gremlin.Response) error { data, err := graphson.Marshal(rsp) if err != nil { return err } return c.WriteMessage(websocket.BinaryMessage, data) } func serve(handler func(conn)) *httptest.Server { return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { upgrader := websocket.Upgrader{ReadBufferSize: 1024, WriteBufferSize: 1024} c, _ := upgrader.Upgrade(w, r, nil) defer c.Close() handler(conn{c}) for { _, _, err := c.ReadMessage() if err != nil { break } } })) } func TestConnectClosure(t *testing.T) { var wg sync.WaitGroup wg.Add(1) defer wg.Wait() srv := serve(func(conn conn) { defer wg.Done() _, _, err := conn.ReadMessage() assert.True(t, websocket.IsCloseError(err, websocket.CloseNormalClosure)) }) defer srv.Close() conn, err := DefaultDialer.Dial("ws://" + srv.Listener.Addr().String()) require.NoError(t, err) err = conn.Close() assert.NoError(t, err) _, err = conn.Execute(context.Background(), gremlin.NewEvalRequest("g.V()")) assert.EqualError(t, err, ErrConnClosed.Error()) } func TestSimpleQuery(t *testing.T) { srv := serve(func(conn conn) { typ, data, err := conn.ReadMessage() require.NoError(t, err) assert.Equal(t, websocket.BinaryMessage, typ) var req gremlin.Request err = graphson.Unmarshal(data[data[0]+1:], &req) require.NoError(t, err) assert.Equal(t, "g.V()", req.Arguments["gremlin"]) rsp := gremlin.Response{RequestID: req.RequestID} rsp.Status.Code = gremlin.StatusNoContent err = conn.WriteResponse(&rsp) require.NoError(t, err) }) defer srv.Close() conn, err := DefaultDialer.Dial("ws://" + srv.Listener.Addr().String()) require.NoError(t, err) defer assert.Condition(t, func() bool { return assert.NoError(t, conn.Close()) }) rsp, err := conn.Execute(context.Background(), gremlin.NewEvalRequest("g.V()")) assert.NoError(t, err) require.NotNil(t, rsp) assert.Equal(t, gremlin.StatusNoContent, rsp.Status.Code) } func TestDuplicateRequest(t *testing.T) { 
// skip until flakiness will be fixed. t.SkipNow() srv := serve(func(conn conn) { req, err := conn.ReadRequest() require.NoError(t, err) rsp := gremlin.Response{RequestID: req.RequestID} rsp.Status.Code = gremlin.StatusNoContent err = conn.WriteResponse(&rsp) require.NoError(t, err) }) defer srv.Close() conn, err := DefaultDialer.Dial("ws://" + srv.Listener.Addr().String()) require.NoError(t, err) defer conn.Close() var errors [2]error req := gremlin.NewEvalRequest("g.V()") var wg sync.WaitGroup wg.Add(len(errors)) for i := range errors { go func(i int) { _, errors[i] = conn.Execute(context.Background(), req) wg.Done() }(i) } wg.Wait() err = errors[0] if err == nil { err = errors[1] } assert.EqualError(t, err, ErrDuplicateRequest.Error()) } func TestConnectCancellation(t *testing.T) { srv := serve(func(conn) {}) defer srv.Close() ctx, cancel := context.WithCancel(context.Background()) cancel() conn, err := DefaultDialer.DialContext(ctx, "ws://"+srv.Listener.Addr().String()) assert.Error(t, err) assert.Nil(t, conn) } func TestQueryCancellation(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) srv := serve(func(conn conn) { if _, _, err := conn.ReadMessage(); err == nil { cancel() } }) defer srv.Close() conn, err := DefaultDialer.Dial("ws://" + srv.Listener.Addr().String()) require.NoError(t, err) defer conn.Close() _, err = conn.Execute(ctx, gremlin.NewEvalRequest("g.E()")) assert.EqualError(t, err, context.Canceled.Error()) } func TestBadResponse(t *testing.T) { tests := []struct { name string mangle func(*gremlin.Response) *gremlin.Response }{ { name: "NoStatus", mangle: func(rsp *gremlin.Response) *gremlin.Response { return rsp }, }, { name: "Malformed", mangle: func(rsp *gremlin.Response) *gremlin.Response { rsp.Status.Code = gremlin.StatusMalformedRequest rsp.Status.Message = "bad request" return rsp }, }, { name: "Unknown", mangle: func(rsp *gremlin.Response) *gremlin.Response { rsp.Status.Code = 424242 return rsp }, }, } srv := serve(func(conn conn) { for { req, err := conn.ReadRequest() if err != nil { break } idx, err := strconv.ParseInt(req.Arguments["gremlin"].(string), 10, 0) require.NoError(t, err) err = conn.WriteResponse(tests[idx].mangle(&gremlin.Response{RequestID: req.RequestID})) require.NoError(t, err) } }) defer srv.Close() conn, err := DefaultDialer.Dial("ws://" + srv.Listener.Addr().String()) require.NoError(t, err) defer conn.Close() var wg sync.WaitGroup wg.Add(len(tests)) ctx := context.Background() for i, tc := range tests { i, tc := i, tc t.Run(tc.name, func(t *testing.T) { defer wg.Done() rsp, err := conn.Execute(ctx, gremlin.NewEvalRequest(strconv.FormatInt(int64(i), 10))) assert.NoError(t, err) assert.True(t, rsp.IsErr()) }) } wg.Wait() } func TestServerHangup(t *testing.T) { // skip until flakiness will be fixed. t.SkipNow() srv := serve(func(conn conn) { _ = conn.Close() }) defer srv.Close() conn, err := DefaultDialer.Dial("ws://" + srv.Listener.Addr().String()) require.NoError(t, err) defer conn.Close() _, err = conn.Execute(context.Background(), gremlin.NewEvalRequest("g.V()")) assert.EqualError(t, err, ErrConnClosed.Error()) assert.Error(t, conn.ctx.Err()) } func TestCanceledLongRequest(t *testing.T) { // skip until flakiness will be fixed. 
t.SkipNow() ctx, cancel := context.WithCancel(context.Background()) srv := serve(func(conn conn) { var responses [3]*gremlin.Response for i := 0; i < len(responses); i++ { req, err := conn.ReadRequest() require.NoError(t, err) rsp := gremlin.Response{RequestID: req.RequestID} rsp.Status.Code = gremlin.StatusSuccess rsp.Result.Data = graphson.RawMessage(`"ok"`) responses[i] = &rsp } cancel() responses[0], responses[2] = responses[2], responses[0] for i := 0; i < len(responses); i++ { err := conn.WriteResponse(responses[i]) require.NoError(t, err) } }) defer srv.Close() conn, err := DefaultDialer.Dial("ws://" + srv.Listener.Addr().String()) require.NoError(t, err) defer conn.Close() var wg sync.WaitGroup wg.Add(3) defer wg.Wait() for i := 0; i < 3; i++ { go func(ctx context.Context, idx int) { defer wg.Done() rsp, err := conn.Execute(ctx, gremlin.NewEvalRequest("g.V()")) if idx > 0 { assert.NoError(t, err) assert.EqualValues(t, []byte(`"ok"`), rsp.Result.Data) } else { assert.EqualError(t, err, context.Canceled.Error()) } }(ctx, i) ctx = context.Background() } } func TestPartialResponse(t *testing.T) { type kv struct { Key string Value int } kvs := []kv{ {"one", 1}, {"two", 2}, {"three", 3}, } srv := serve(func(conn conn) { req, err := conn.ReadRequest() require.NoError(t, err) for i := range kvs { data, err := graphson.Marshal([]kv{kvs[i]}) require.NoError(t, err) rsp := gremlin.Response{RequestID: req.RequestID} rsp.Result.Data = graphson.RawMessage(data) if i != len(kvs)-1 { rsp.Status.Code = gremlin.StatusPartialContent } else { rsp.Status.Code = gremlin.StatusSuccess } err = conn.WriteResponse(&rsp) require.NoError(t, err) } }) defer srv.Close() conn, err := DefaultDialer.Dial("ws://" + srv.Listener.Addr().String()) require.NoError(t, err) defer conn.Close() rsp, err := conn.Execute(context.Background(), gremlin.NewEvalRequest("g.E()")) assert.NoError(t, err) var result []kv err = graphson.Unmarshal(rsp.Result.Data, &result) require.NoError(t, err) assert.Equal(t, kvs, result) } func TestAuthentication(t *testing.T) { user, pass := "username", "password" srv := serve(func(conn conn) { req, err := conn.ReadRequest() require.NoError(t, err) rsp := gremlin.Response{RequestID: req.RequestID} rsp.Status.Code = gremlin.StatusAuthenticate err = conn.WriteResponse(&rsp) require.NoError(t, err) areq, err := conn.ReadRequest() require.NoError(t, err) var acreds gremlin.Credentials err = acreds.UnmarshalText([]byte(areq.Arguments["sasl"].(string))) assert.NoError(t, err) areq.Arguments["sasl"] = acreds assert.Equal(t, gremlin.NewAuthRequest(req.RequestID, user, pass), areq) rsp = gremlin.Response{RequestID: req.RequestID} rsp.Status.Code = gremlin.StatusNoContent err = conn.WriteResponse(&rsp) require.NoError(t, err) }) defer srv.Close() dialer := *DefaultDialer dialer.user = user dialer.pass = pass client, err := dialer.Dial("ws://" + srv.Listener.Addr().String()) require.NoError(t, err) defer client.Close() _, err = client.Execute(context.Background(), gremlin.NewEvalRequest("g.E().drop()")) assert.NoError(t, err) } ent-0.11.3/dialect/gremlin/ocgremlin/000077500000000000000000000000001431500740500174025ustar00rootroot00000000000000ent-0.11.3/dialect/gremlin/ocgremlin/client.go000066400000000000000000000041301431500740500212050ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. 
package ocgremlin

import (
	"context"

	"entgo.io/ent/dialect/gremlin"
	"go.opencensus.io/trace"
)

// Transport is a gremlin.RoundTripper that instruments all outgoing requests with
// OpenCensus stats and tracing.
type Transport struct {
	// Base is a wrapped gremlin.RoundTripper that does the actual requests.
	Base gremlin.RoundTripper

	// StartOptions are applied to the span started by this Transport around each
	// request.
	//
	// StartOptions.SpanKind will always be set to trace.SpanKindClient
	// for spans started by this transport.
	StartOptions trace.StartOptions

	// GetStartOptions allows setting start options per request. If set,
	// StartOptions is going to be ignored.
	GetStartOptions func(context.Context, *gremlin.Request) trace.StartOptions

	// FormatSpanName holds the function to use for generating the span name
	// from the information found in the outgoing Gremlin Request. By default,
	// spans are named "gremlin:traversal".
	FormatSpanName func(context.Context, *gremlin.Request) string

	// WithQuery, if set to true, will enable recording of gremlin queries in spans.
	// Only enable this if it is safe to have queries recorded with respect to
	// security.
	WithQuery bool
}

// RoundTrip implements gremlin.RoundTripper, delegating to Base and recording stats and traces for the request.
func (t *Transport) RoundTrip(ctx context.Context, req *gremlin.Request) (*gremlin.Response, error) {
	spanNameFormatter := t.FormatSpanName
	if spanNameFormatter == nil {
		spanNameFormatter = func(context.Context, *gremlin.Request) string {
			return "gremlin:traversal"
		}
	}

	startOpts := t.StartOptions
	if t.GetStartOptions != nil {
		startOpts = t.GetStartOptions(ctx, req)
	}

	var rt gremlin.RoundTripper = &traceTransport{
		base:           t.Base,
		formatSpanName: spanNameFormatter,
		startOptions:   startOpts,
		withQuery:      t.WithQuery,
	}
	rt = statsTransport{rt}
	return rt.RoundTrip(ctx, req)
}
ent-0.11.3/dialect/gremlin/ocgremlin/client_test.go000066400000000000000000000032641431500740500222530ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved.
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.

package ocgremlin

import (
	"context"
	"errors"
	"testing"

	"entgo.io/ent/dialect/gremlin"
	"github.com/stretchr/testify/mock"
	"go.opencensus.io/trace"
)

type mockExporter struct {
	mock.Mock
}

func (e *mockExporter) ExportSpan(s *trace.SpanData) {
	e.Called(s)
}

func TestTransportOptions(t *testing.T) {
	tests := []struct {
		name     string
		spanName string
		wantName string
	}{
		{
			name:     "Default formatter",
			wantName: "gremlin:traversal",
		},
		{
			name:     "Custom formatter",
			spanName: "tester",
			wantName: "tester",
		},
	}
	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			var exporter mockExporter
			exporter.On(
				"ExportSpan",
				mock.MatchedBy(func(s *trace.SpanData) bool {
					return s.Name == tt.wantName
				})).
				Once()
			defer exporter.AssertExpectations(t)
			trace.RegisterExporter(&exporter)
			defer trace.UnregisterExporter(&exporter)

			transport := &mockTransport{}
			transport.On("RoundTrip", mock.Anything, mock.Anything).
				Return(nil, errors.New("noop")).
Once() defer transport.AssertExpectations(t) rt := &Transport{ Base: transport, GetStartOptions: func(context.Context, *gremlin.Request) trace.StartOptions { return trace.StartOptions{Sampler: trace.AlwaysSample()} }, } if tt.spanName != "" { rt.FormatSpanName = func(context.Context, *gremlin.Request) string { return tt.spanName } } _, _ = rt.RoundTrip(context.Background(), gremlin.NewEvalRequest("g.E()")) }) } } ent-0.11.3/dialect/gremlin/ocgremlin/stats.go000066400000000000000000000067061431500740500211000ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package ocgremlin import ( "context" "strconv" "time" "entgo.io/ent/dialect/gremlin" "go.opencensus.io/stats" "go.opencensus.io/stats/view" "go.opencensus.io/tag" ) // The following measures are supported for use in custom views. var ( RequestCount = stats.Int64( "gremlin/request_count", "Number of Gremlin requests started", stats.UnitDimensionless, ) ResponseBytes = stats.Int64( "gremlin/response_bytes", "Total number of bytes in response data", stats.UnitBytes, ) RoundTripLatency = stats.Float64( "gremlin/roundtrip_latency", "End-to-end latency", stats.UnitMilliseconds, ) ) // The following tags are applied to stats recorded by this package. var ( // StatusCode is the numeric Gremlin response status code, // or "error" if a transport error occurred and no status code was read. StatusCode, _ = tag.NewKey("gremlin_status_code") ) // Default distributions used by views in this package. var ( DefaultSizeDistribution = view.Distribution(32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576) DefaultLatencyDistribution = view.Distribution(1, 2, 3, 4, 5, 6, 8, 10, 13, 16, 20, 25, 30, 40, 50, 65, 80, 100, 130, 160, 200, 250, 300, 400, 500, 650, 800, 1000, 2000, 5000, 10000, 20000, 50000, 100000) ) // Package ocgremlin provides some convenience views for measures. // You still need to register these views for data to actually be collected. var ( RequestCountView = &view.View{ Name: "gremlin/request_count", Measure: RequestCount, Aggregation: view.Count(), Description: "Count of Gremlin requests started", } ResponseCountView = &view.View{ Name: "gremlin/response_count", Measure: RoundTripLatency, Aggregation: view.Count(), Description: "Count of responses received, by response status", TagKeys: []tag.Key{StatusCode}, } ResponseBytesView = &view.View{ Name: "gremlin/response_bytes", Measure: ResponseBytes, Aggregation: DefaultSizeDistribution, Description: "Total number of bytes in response data", } RoundTripLatencyView = &view.View{ Name: "gremlin/roundtrip_latency", Measure: RoundTripLatency, Aggregation: DefaultLatencyDistribution, Description: "End-to-end latency, by response code", TagKeys: []tag.Key{StatusCode}, } ) // Views are the default views provided by this package. func Views() []*view.View { return []*view.View{ RequestCountView, ResponseCountView, ResponseBytesView, RoundTripLatencyView, } } // statsTransport is an gremlin.RoundTripper that collects stats for the outgoing requests. 
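//
// Registration sketch (illustrative; assumes a client package importing
// ocgremlin): the measures recorded here are only exported once the
// corresponding views are registered, typically at program startup:
//
//	if err := view.Register(ocgremlin.Views()...); err != nil {
//		log.Fatal(err)
//	}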
type statsTransport struct { base gremlin.RoundTripper } func (t statsTransport) RoundTrip(ctx context.Context, req *gremlin.Request) (*gremlin.Response, error) { stats.Record(ctx, RequestCount.M(1)) start := time.Now() rsp, err := t.base.RoundTrip(ctx, req) latency := float64(time.Since(start)) / float64(time.Millisecond) var ( tags = make([]tag.Mutator, 1) ms = []stats.Measurement{RoundTripLatency.M(latency)} ) if err == nil { tags[0] = tag.Upsert(StatusCode, strconv.Itoa(rsp.Status.Code)) ms = append(ms, ResponseBytes.M(int64(len(rsp.Result.Data)))) } else { tags[0] = tag.Upsert(StatusCode, "error") } _ = stats.RecordWithTags(ctx, tags, ms...) return rsp, err } ent-0.11.3/dialect/gremlin/ocgremlin/stats_test.go000066400000000000000000000045171431500740500221350ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package ocgremlin import ( "context" "strings" "testing" "entgo.io/ent/dialect/gremlin" "entgo.io/ent/dialect/gremlin/encoding/graphson" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "go.opencensus.io/stats/view" ) func TestStatsCollection(t *testing.T) { err := view.Register( RequestCountView, ResponseCountView, ResponseBytesView, RoundTripLatencyView, ) require.NoError(t, err) req := gremlin.NewEvalRequest("g.E()") rsp := &gremlin.Response{RequestID: req.RequestID} rsp.Status.Code = gremlin.StatusSuccess rsp.Result.Data = graphson.RawMessage( `{"@type": "g:List", "@value": [{"@type": "g:Int32", "@value": 42}]}`, ) transport := &mockTransport{} transport.On("RoundTrip", mock.Anything, mock.Anything). Return(rsp, nil). Once() defer transport.AssertExpectations(t) rt := &statsTransport{transport} _, _ = rt.RoundTrip(context.Background(), req) tests := []struct { name string expect func(*testing.T, *view.Row) }{ { name: "gremlin/request_count", expect: func(t *testing.T, row *view.Row) { count, ok := row.Data.(*view.CountData) require.True(t, ok) assert.Equal(t, int64(1), count.Value) }, }, { name: "gremlin/response_count", expect: func(t *testing.T, row *view.Row) { count, ok := row.Data.(*view.CountData) require.True(t, ok) assert.Equal(t, int64(1), count.Value) }, }, { name: "gremlin/response_bytes", expect: func(t *testing.T, row *view.Row) { data, ok := row.Data.(*view.DistributionData) require.True(t, ok) assert.EqualValues(t, len(rsp.Result.Data), data.Sum()) }, }, { name: "gremlin/roundtrip_latency", expect: func(t *testing.T, row *view.Row) { data, ok := row.Data.(*view.DistributionData) require.True(t, ok) assert.NotZero(t, data.Sum()) }, }, } for _, tt := range tests { tt := tt t.Run(tt.name[strings.Index(tt.name, "/")+1:], func(t *testing.T) { v := view.Find(tt.name) assert.NotNil(t, v) rows, err := view.RetrieveData(tt.name) require.NoError(t, err) require.Len(t, rows, 1) tt.expect(t, rows[0]) }) } } ent-0.11.3/dialect/gremlin/ocgremlin/trace.go000066400000000000000000000073761431500740500210440ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package ocgremlin import ( "context" "fmt" "entgo.io/ent/dialect/gremlin" "go.opencensus.io/trace" ) // Attributes recorded on the span for the requests. 
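//
// Wiring sketch (illustrative; the endpoint URL is an assumption): wrapping a
// base transport records one client span per request, carrying the attributes
// listed below:
//
//	base, err := gremlin.NewHTTPTransport("http://localhost:8182/gremlin", nil)
//	if err != nil {
//		log.Fatal(err)
//	}
//	rt := &Transport{Base: base, WithQuery: true} // WithQuery also records the query attribute
//	rsp, err := rt.RoundTrip(context.Background(), gremlin.NewEvalRequest("g.E().count()"))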
const ( RequestIDAttribute = "gremlin.request_id" OperationAttribute = "gremlin.operation" QueryAttribute = "gremlin.query" BindingAttribute = "gremlin.binding" CodeAttribute = "gremlin.code" MessageAttribute = "gremlin.message" ) type traceTransport struct { base gremlin.RoundTripper startOptions trace.StartOptions formatSpanName func(context.Context, *gremlin.Request) string withQuery bool } func (t *traceTransport) RoundTrip(ctx context.Context, req *gremlin.Request) (*gremlin.Response, error) { ctx, span := trace.StartSpan(ctx, t.formatSpanName(ctx, req), trace.WithSampler(t.startOptions.Sampler), trace.WithSpanKind(trace.SpanKindClient), ) defer span.End() span.AddAttributes(requestAttrs(req, t.withQuery)...) rsp, err := t.base.RoundTrip(ctx, req) if err != nil { span.SetStatus(trace.Status{Code: trace.StatusCodeUnknown, Message: err.Error()}) return rsp, err } span.AddAttributes(responseAttrs(rsp)...) span.SetStatus(TraceStatus(rsp.Status.Code)) return rsp, err } func requestAttrs(req *gremlin.Request, withQuery bool) []trace.Attribute { attrs := []trace.Attribute{ trace.StringAttribute(RequestIDAttribute, req.RequestID), trace.StringAttribute(OperationAttribute, req.Operation), } if withQuery { query, _ := req.Arguments[gremlin.ArgsGremlin].(string) attrs = append(attrs, trace.StringAttribute(QueryAttribute, query)) if bindings, ok := req.Arguments[gremlin.ArgsBindings].(map[string]any); ok { attrs = append(attrs, bindingsAttrs(bindings)...) } } return attrs } func bindingsAttrs(bindings map[string]any) []trace.Attribute { attrs := make([]trace.Attribute, 0, len(bindings)) for key, val := range bindings { key = BindingAttribute + "." + key attrs = append(attrs, bindingToAttr(key, val)) } return attrs } func bindingToAttr(key string, val any) trace.Attribute { switch v := val.(type) { case nil: return trace.StringAttribute(key, "") case int64: return trace.Int64Attribute(key, v) case float64: return trace.Float64Attribute(key, v) case string: return trace.StringAttribute(key, v) case bool: return trace.BoolAttribute(key, v) default: s := fmt.Sprintf("%v", v) if len(s) > 256 { s = s[:256] } return trace.StringAttribute(key, s) } } func responseAttrs(rsp *gremlin.Response) []trace.Attribute { attrs := []trace.Attribute{ trace.Int64Attribute(CodeAttribute, int64(rsp.Status.Code)), } if rsp.Status.Message != "" { attrs = append(attrs, trace.StringAttribute(MessageAttribute, rsp.Status.Message)) } return attrs } // TraceStatus is a utility to convert the gremlin status code to a trace.Status. func TraceStatus(status int) trace.Status { var code int32 switch status { case gremlin.StatusSuccess, gremlin.StatusNoContent, gremlin.StatusPartialContent: code = trace.StatusCodeOK case gremlin.StatusUnauthorized: code = trace.StatusCodePermissionDenied case gremlin.StatusAuthenticate: code = trace.StatusCodeUnauthenticated case gremlin.StatusMalformedRequest, gremlin.StatusInvalidRequestArguments, gremlin.StatusScriptEvaluationError: code = trace.StatusCodeInvalidArgument case gremlin.StatusServerError, gremlin.StatusServerSerializationError: code = trace.StatusCodeInternal case gremlin.StatusServerTimeout: code = trace.StatusCodeDeadlineExceeded default: code = trace.StatusCodeUnknown } return trace.Status{Code: code, Message: gremlin.StatusText(status)} } ent-0.11.3/dialect/gremlin/ocgremlin/trace_test.go000066400000000000000000000165601431500740500220760ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. 
// This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package ocgremlin import ( "bytes" "context" "errors" "fmt" "testing" "entgo.io/ent/dialect/gremlin" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "go.opencensus.io/trace" ) type mockTransport struct { mock.Mock } func (t *mockTransport) RoundTrip(ctx context.Context, req *gremlin.Request) (*gremlin.Response, error) { args := t.Called(ctx, req) rsp, _ := args.Get(0).(*gremlin.Response) return rsp, args.Error(1) } func TestTraceTransportRoundTrip(t *testing.T) { _, parent := trace.StartSpan(context.Background(), "parent") tests := []struct { name string parent *trace.Span }{ { name: "no parent", parent: nil, }, { name: "parent", parent: parent, }, } for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { transport := &mockTransport{} transport.On("RoundTrip", mock.Anything, mock.Anything). Run(func(args mock.Arguments) { span := trace.FromContext(args.Get(0).(context.Context)) require.NotNil(t, span) if tt.parent != nil { assert.Equal(t, tt.parent.SpanContext().TraceID, span.SpanContext().TraceID) } }). Return(nil, errors.New("noop")). Once() defer transport.AssertExpectations(t) ctx, req := context.Background(), gremlin.NewEvalRequest("g.V()") if tt.parent != nil { ctx = trace.NewContext(ctx, tt.parent) } rt := &Transport{Base: transport} _, _ = rt.RoundTrip(ctx, req) }) } } type testExporter struct { spans []*trace.SpanData } func (t *testExporter) ExportSpan(s *trace.SpanData) { t.spans = append(t.spans, s) } func TestEndToEnd(t *testing.T) { trace.ApplyConfig(trace.Config{DefaultSampler: trace.AlwaysSample()}) var exporter testExporter trace.RegisterExporter(&exporter) defer trace.UnregisterExporter(&exporter) req := gremlin.NewEvalRequest("g.V()") rsp := &gremlin.Response{ RequestID: req.RequestID, } rsp.Status.Code = 200 rsp.Status.Message = "OK" var transport mockTransport transport.On("RoundTrip", mock.Anything, mock.Anything). Return(rsp, nil). 
Once() defer transport.AssertExpectations(t) rt := &Transport{Base: &transport, WithQuery: true} _, err := rt.RoundTrip(context.Background(), req) require.NoError(t, err) require.Len(t, exporter.spans, 1) attrs := exporter.spans[0].Attributes assert.Len(t, attrs, 5) assert.Equal(t, req.RequestID, attrs["gremlin.request_id"]) assert.Equal(t, req.Operation, attrs["gremlin.operation"]) assert.Equal(t, req.Arguments[gremlin.ArgsGremlin], attrs["gremlin.query"]) assert.Equal(t, int64(200), attrs["gremlin.code"]) assert.Equal(t, "OK", attrs["gremlin.message"]) } func TestRequestAttributes(t *testing.T) { tests := []struct { name string makeReq func() *gremlin.Request wantAttrs []trace.Attribute }{ { name: "Query without bindings", makeReq: func() *gremlin.Request { req := gremlin.NewEvalRequest("g.E().count()") req.RequestID = "a8b5c664-03ca-4175-a9e7-569b46f3551c" return req }, wantAttrs: []trace.Attribute{ trace.StringAttribute("gremlin.request_id", "a8b5c664-03ca-4175-a9e7-569b46f3551c"), trace.StringAttribute("gremlin.operation", "eval"), trace.StringAttribute("gremlin.query", "g.E().count()"), }, }, { name: "Query with bindings", makeReq: func() *gremlin.Request { bindings := map[string]any{ "$1": "user", "$2": int64(42), "$3": 3.14, "$4": bytes.Repeat([]byte{0xff}, 257), "$5": true, "$6": nil, } req := gremlin.NewEvalRequest( `g.V().hasLabel($1).has("age",$2).has("v",$3).limit($4).valueMap($5)`, gremlin.WithBindings(bindings), ) req.RequestID = "d3d986fa-bd22-41bd-b2f7-ef2f1f639260" return req }, wantAttrs: []trace.Attribute{ trace.StringAttribute("gremlin.request_id", "d3d986fa-bd22-41bd-b2f7-ef2f1f639260"), trace.StringAttribute("gremlin.operation", "eval"), trace.StringAttribute("gremlin.query", `g.V().hasLabel($1).has("age",$2).has("v",$3).limit($4).valueMap($5)`), trace.StringAttribute("gremlin.binding.$1", "user"), trace.Int64Attribute("gremlin.binding.$2", 42), trace.Float64Attribute("gremlin.binding.$3", 3.14), trace.StringAttribute("gremlin.binding.$4", func() string { str := fmt.Sprintf("%v", bytes.Repeat([]byte{0xff}, 256)) return str[:256] }()), trace.BoolAttribute("gremlin.binding.$5", true), trace.StringAttribute("gremlin.binding.$6", ""), }, }, { name: "Authentication", makeReq: func() *gremlin.Request { return gremlin.NewAuthRequest( "d239d950-59a1-41a7-a103-908f976ebd89", "user", "pass", ) }, wantAttrs: []trace.Attribute{ trace.StringAttribute("gremlin.request_id", "d239d950-59a1-41a7-a103-908f976ebd89"), trace.StringAttribute("gremlin.operation", "authentication"), trace.StringAttribute("gremlin.query", ""), }, }, } for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { req := tt.makeReq() attrs := requestAttrs(req, true) for _, attr := range attrs { assert.Contains(t, tt.wantAttrs, attr) } assert.Len(t, attrs, len(tt.wantAttrs)) }) } } func TestResponseAttributes(t *testing.T) { tests := []struct { name string makeRsp func() *gremlin.Response wantAttrs []trace.Attribute }{ { name: "Success no message", makeRsp: func() *gremlin.Response { var rsp gremlin.Response rsp.Status.Code = 204 return &rsp }, wantAttrs: []trace.Attribute{ trace.Int64Attribute("gremlin.code", 204), }, }, { name: "Authenticate with message", makeRsp: func() *gremlin.Response { var rsp gremlin.Response rsp.Status.Code = 407 rsp.Status.Message = "login required" return &rsp }, wantAttrs: []trace.Attribute{ trace.Int64Attribute("gremlin.code", 407), trace.StringAttribute("gremlin.message", "login required"), }, }, } for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { 
			rsp := tt.makeRsp()
			attrs := responseAttrs(rsp)
			assert.Equal(t, tt.wantAttrs, attrs)
		})
	}
}

func TestTraceStatus(t *testing.T) {
	tests := []struct {
		in   int
		want trace.Status
	}{
		{200, trace.Status{Code: trace.StatusCodeOK, Message: "Success"}},
		{204, trace.Status{Code: trace.StatusCodeOK, Message: "No Content"}},
		{206, trace.Status{Code: trace.StatusCodeOK, Message: "Partial Content"}},
		{401, trace.Status{Code: trace.StatusCodePermissionDenied, Message: "Unauthorized"}},
		{407, trace.Status{Code: trace.StatusCodeUnauthenticated, Message: "Authenticate"}},
		{498, trace.Status{Code: trace.StatusCodeInvalidArgument, Message: "Malformed Request"}},
		{499, trace.Status{Code: trace.StatusCodeInvalidArgument, Message: "Invalid Request Arguments"}},
		{500, trace.Status{Code: trace.StatusCodeInternal, Message: "Server Error"}},
		{597, trace.Status{Code: trace.StatusCodeInvalidArgument, Message: "Script Evaluation Error"}},
		{598, trace.Status{Code: trace.StatusCodeDeadlineExceeded, Message: "Server Timeout"}},
		{599, trace.Status{Code: trace.StatusCodeInternal, Message: "Server Serialization Error"}},
		{600, trace.Status{Code: trace.StatusCodeUnknown, Message: ""}},
	}
	for _, tt := range tests {
		assert.Equal(t, tt.want, TraceStatus(tt.in))
	}
}
ent-0.11.3/dialect/gremlin/request.go000066400000000000000000000051441431500740500174460ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved.
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.

package gremlin

import (
	"bytes"
	"encoding/base64"
	"errors"
	"time"

	"github.com/google/uuid"
)

type (
	// A Request models a request message sent to the server.
	Request struct {
		RequestID string         `json:"requestId" graphson:"g:UUID"`
		Operation string         `json:"op"`
		Processor string         `json:"processor"`
		Arguments map[string]any `json:"args"`
	}

	// RequestOption enables request customization.
	RequestOption func(*Request)

	// Credentials holds request plain auth credentials.
	Credentials struct{ Username, Password string }
)

// NewEvalRequest returns a new evaluation request.
func NewEvalRequest(query string, opts ...RequestOption) *Request {
	r := &Request{
		RequestID: uuid.New().String(),
		Operation: OpsEval,
		Arguments: map[string]any{
			ArgsGremlin:  query,
			ArgsLanguage: "gremlin-groovy",
		},
	}
	for i := range opts {
		opts[i](r)
	}
	return r
}

// NewAuthRequest returns a new auth request.
func NewAuthRequest(requestID, username, password string) *Request {
	return &Request{
		RequestID: requestID,
		Operation: OpsAuthentication,
		Arguments: map[string]any{
			ArgsSasl: Credentials{
				Username: username,
				Password: password,
			},
			ArgsSaslMechanism: "PLAIN",
		},
	}
}

// WithBindings sets request bindings.
func WithBindings(bindings map[string]any) RequestOption {
	return func(r *Request) {
		r.Arguments[ArgsBindings] = bindings
	}
}

// WithEvalTimeout sets the script evaluation timeout.
func WithEvalTimeout(timeout time.Duration) RequestOption {
	return func(r *Request) {
		r.Arguments[ArgsEvalTimeout] = int64(timeout / time.Millisecond)
	}
}

// MarshalText implements encoding.TextMarshaler interface.
func (c Credentials) MarshalText() ([]byte, error) {
	var buf bytes.Buffer
	buf.WriteByte(0)
	buf.WriteString(c.Username)
	buf.WriteByte(0)
	buf.WriteString(c.Password)
	enc := base64.StdEncoding
	text := make([]byte, enc.EncodedLen(buf.Len()))
	enc.Encode(text, buf.Bytes())
	return text, nil
}

// UnmarshalText implements encoding.TextUnmarshaler interface.
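//
// Wire-format sketch (illustrative): the text form is the SASL PLAIN
// message "\x00username\x00password" encoded as standard base64, so
// round-tripping restores the original credentials:
//
//	text, _ := Credentials{Username: "user", Password: "pass"}.MarshalText()
//	// string(text) == "AHVzZXIAcGFzcw=="
//	var c Credentials
//	_ = c.UnmarshalText(text) // c.Username == "user", c.Password == "pass"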
func (c *Credentials) UnmarshalText(text []byte) error { enc := base64.StdEncoding data := make([]byte, enc.DecodedLen(len(text))) n, err := enc.Decode(data, text) if err != nil { return err } data = data[:n] parts := bytes.SplitN(data, []byte{0}, 3) if len(parts) != 3 { return errors.New("bad credentials data") } c.Username = string(parts[1]) c.Password = string(parts[2]) return nil } ent-0.11.3/dialect/gremlin/request_test.go000066400000000000000000000057761431500740500205200ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package gremlin import ( "encoding/json" "testing" "time" "entgo.io/ent/dialect/gremlin/encoding/graphson" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestEvaluateRequestEncode(t *testing.T) { req := NewEvalRequest("g.V(x)", WithBindings(map[string]any{"x": 1}), WithEvalTimeout(time.Second), ) data, err := graphson.Marshal(req) require.NoError(t, err) var got map[string]any err = json.Unmarshal(data, &got) require.NoError(t, err) assert.Equal(t, map[string]any{ "@type": "g:UUID", "@value": req.RequestID, }, got["requestId"]) assert.Equal(t, req.Operation, got["op"]) assert.Equal(t, req.Processor, got["processor"]) args := got["args"].(map[string]any) assert.Equal(t, "g:Map", args["@type"]) assert.ElementsMatch(t, args["@value"], []any{ "gremlin", "g.V(x)", "language", "gremlin-groovy", "scriptEvaluationTimeout", map[string]any{ "@type": "g:Int64", "@value": float64(1000), }, "bindings", map[string]any{ "@type": "g:Map", "@value": []any{ "x", map[string]any{ "@type": "g:Int64", "@value": float64(1), }, }, }, }) } func TestEvaluateRequestWithoutBindingsEncode(t *testing.T) { req := NewEvalRequest("g.E()") got, err := graphson.MarshalToString(req) require.NoError(t, err) assert.NotContains(t, got, "bindings") } func TestAuthenticateRequestEncode(t *testing.T) { req := NewAuthRequest("41d2e28a-20a4-4ab0-b379-d810dede3786", "user", "pass") data, err := graphson.Marshal(req) require.NoError(t, err) var got map[string]any err = json.Unmarshal(data, &got) require.NoError(t, err) assert.Equal(t, map[string]any{ "@type": "g:UUID", "@value": req.RequestID, }, got["requestId"]) assert.Equal(t, req.Operation, got["op"]) assert.Equal(t, req.Processor, got["processor"]) args := got["args"].(map[string]any) assert.Equal(t, "g:Map", args["@type"]) assert.ElementsMatch(t, args["@value"], []any{ "sasl", "AHVzZXIAcGFzcw==", "saslMechanism", "PLAIN", }) } func TestCredentialsMarshaling(t *testing.T) { want := Credentials{ Username: "username", Password: "password", } text, err := want.MarshalText() assert.NoError(t, err) assert.Equal(t, "AHVzZXJuYW1lAHBhc3N3b3Jk", string(text)) var got Credentials err = got.UnmarshalText(text) assert.NoError(t, err) assert.Equal(t, want, got) } func TestCredentialsBadEncodingMarshaling(t *testing.T) { tests := []struct { name string text []byte }{ { name: "BadBase64", text: []byte{0x12}, }, { name: "Empty", text: []byte{}, }, { name: "BadPrefix", text: []byte("Kg=="), }, { name: "NoSeperator", text: []byte("AHVzZXI="), }, } for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { var creds Credentials err := creds.UnmarshalText(tc.text) assert.Error(t, err) }) } } ent-0.11.3/dialect/gremlin/response.go000066400000000000000000000063141431500740500176140ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. 
All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package gremlin import ( "errors" "fmt" "entgo.io/ent/dialect/gremlin/encoding/graphson" "entgo.io/ent/dialect/gremlin/graph" ) // A Response models a response message received from the server. type Response struct { RequestID string `json:"requestId" graphson:"g:UUID"` Status struct { Code int `json:"code"` Attributes map[string]any `json:"attributes"` Message string `json:"message"` } `json:"status"` Result struct { Data graphson.RawMessage `json:"data"` Meta map[string]any `json:"meta"` } `json:"result"` } // IsErr returns whether response indicates an error. func (rsp *Response) IsErr() bool { switch rsp.Status.Code { case StatusSuccess, StatusNoContent, StatusPartialContent: return false default: return true } } // Err returns an error representing response status. func (rsp *Response) Err() error { if rsp.IsErr() { return fmt.Errorf("gremlin: code=%d, message=%q", rsp.Status.Code, rsp.Status.Message) } return nil } // ReadVal reads gremlin response data into v. func (rsp *Response) ReadVal(v any) error { if err := rsp.Err(); err != nil { return err } if err := graphson.Unmarshal(rsp.Result.Data, v); err != nil { return fmt.Errorf("gremlin: unmarshal response data: type=%T: %w", v, err) } return nil } // ReadVertices returns response data as slice of vertices. func (rsp *Response) ReadVertices() ([]graph.Vertex, error) { var v []graph.Vertex err := rsp.ReadVal(&v) return v, err } // ReadVertexProperties returns response data as slice of vertex properties. func (rsp *Response) ReadVertexProperties() ([]graph.VertexProperty, error) { var vp []graph.VertexProperty err := rsp.ReadVal(&vp) return vp, err } // ReadEdges returns response data as slice of edges. func (rsp *Response) ReadEdges() ([]graph.Edge, error) { var e []graph.Edge err := rsp.ReadVal(&e) return e, err } // ReadProperties returns response data as slice of properties. func (rsp *Response) ReadProperties() ([]graph.Property, error) { var p []graph.Property err := rsp.ReadVal(&p) return p, err } // ReadValueMap returns response data as a value map. func (rsp *Response) ReadValueMap() (graph.ValueMap, error) { var m graph.ValueMap err := rsp.ReadVal(&m) return m, err } // ReadBool returns response data as a bool. func (rsp *Response) ReadBool() (bool, error) { var b [1]*bool if err := rsp.ReadVal(&b); err != nil { return false, err } if b[0] == nil { return false, errors.New("gremlin: no boolean value") } return *b[0], nil } // ReadInt returns response data as an int. func (rsp *Response) ReadInt() (int, error) { var v [1]*int if err := rsp.ReadVal(&v); err != nil { return 0, err } if v[0] == nil { return 0, errors.New("gremlin: no integer value") } return *v[0], nil } // ReadString returns response data as a string. func (rsp *Response) ReadString() (string, error) { var v [1]*string if err := rsp.ReadVal(&v); err != nil { return "", err } if v[0] == nil { return "", errors.New("gremlin: no string value") } return *v[0], nil } ent-0.11.3/dialect/gremlin/response_test.go000066400000000000000000000230541431500740500206530ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. 
package gremlin import ( "reflect" "testing" "entgo.io/ent/dialect/gremlin/encoding/graphson" "entgo.io/ent/dialect/gremlin/graph" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestDecodeResponse(t *testing.T) { in := `{ "requestId": "a65f2d39-1efa-45d2-a06a-c736476500fc", "result": { "data": { "@type": "g:List", "@value": [ { "@type": "g:Map", "@value": [ { "@type": "g:T", "@value": "id" }, { "@type": "g:Int64", "@value": 1 }, { "@type": "g:T", "@value": "label" }, "person", "name", { "@type": "g:List", "@value": [ "marko" ] }, "age", { "@type": "g:List", "@value": [ { "@type": "g:Int32", "@value": 29 } ] } ] }, { "@type": "g:Map", "@value": [ { "@type": "g:T", "@value": "id" }, { "@type": "g:Int64", "@value": 6 }, { "@type": "g:T", "@value": "label" }, "person", "name", { "@type": "g:List", "@value": [ "peter" ] }, "age", { "@type": "g:List", "@value": [ { "@type": "g:Int32", "@value": 35 } ] } ] } ] }, "meta": { "@type": "g:Map", "@value": [] } }, "status": { "attributes": { "@type": "g:Map", "@value": [] }, "code": 200, "message": "" } }` var rsp Response err := graphson.UnmarshalFromString(in, &rsp) require.NoError(t, err) assert.Equal(t, "a65f2d39-1efa-45d2-a06a-c736476500fc", rsp.RequestID) assert.Equal(t, 200, rsp.Status.Code) assert.Empty(t, rsp.Status.Message) assert.Empty(t, rsp.Status.Attributes) assert.Empty(t, rsp.Result.Meta) var vm graph.ValueMap err = graphson.Unmarshal(rsp.Result.Data, &vm) require.NoError(t, err) require.Len(t, vm, 2) type person struct { ID int64 `json:"id"` Name string `json:"name"` Age int `json:"age"` } var people []person err = vm.Decode(&people) require.NoError(t, err) assert.Equal(t, []person{ {1, "marko", 29}, {6, "peter", 35}, }, people) } func TestDecodeResponseWithError(t *testing.T) { in := `{ "requestId": "41d2e28a-20a4-4ab0-b379-d810dede3786", "result": { "data": null, "meta": { "@type": "g:Map", "@value": [] } }, "status": { "attributes": { "@type": "g:Map", "@value": [] }, "code": 500, "message": "Database Down" } }` var rsp Response err := graphson.UnmarshalFromString(in, &rsp) require.NoError(t, err) err = rsp.Err() require.Error(t, err) assert.Contains(t, err.Error(), "Database Down") rsp = Response{} err = graphson.UnmarshalFromString(`{"status": null}`, &rsp) require.NoError(t, err) assert.Error(t, rsp.Err()) } func TestResponseReadVal(t *testing.T) { var rsp Response rsp.Status.Code = StatusSuccess rsp.Result.Data = []byte(`{"@type": "g:Int32", "@value": 15}`) var v int32 err := rsp.ReadVal(&v) assert.NoError(t, err) assert.Equal(t, int32(15), v) var s string err = rsp.ReadVal(&s) assert.Error(t, err) rsp.Status.Code = StatusServerError err = rsp.ReadVal(&v) assert.Error(t, err) } func TestResponseReadGraphElements(t *testing.T) { tests := []struct { method string data string want any }{ { method: "ReadVertices", data: `{ "@type": "g:List", "@value": [ { "@type": "g:Vertex", "@value": { "id": { "@type": "g:Int64", "@value": 1 }, "label": "person" } }, { "@type": "g:Vertex", "@value": { "id": { "@type": "g:Int64", "@value": 6 }, "label": "person" } } ] }`, want: []graph.Vertex{ graph.NewVertex(int64(1), "person"), graph.NewVertex(int64(6), "person"), }, }, { method: "ReadVertexProperties", data: `{ "@type": "g:List", "@value": [ { "@type": "g:VertexProperty", "@value": { "id": { "@type": "g:Int64", "@value": 0 }, "label": "name", "value": "marko" } }, { "@type": "g:VertexProperty", "@value": { "id": { "@type": "g:Int64", "@value": 2 }, "label": "age", "value": { "@type": "g:Int32", "@value": 29 } } 
} ] }`, want: []graph.VertexProperty{ graph.NewVertexProperty(int64(0), "name", "marko"), graph.NewVertexProperty(int64(2), "age", int32(29)), }, }, { method: "ReadEdges", data: `{ "@type": "g:List", "@value": [ { "@type": "g:Edge", "@value": { "id": { "@type": "g:Int32", "@value": 12 }, "inV": { "@type": "g:Int64", "@value": 3 }, "inVLabel": "software", "label": "created", "outV": { "@type": "g:Int64", "@value": 6 }, "outVLabel": "person" } } ] }`, want: []graph.Edge{ graph.NewEdge(int32(12), "created", graph.NewVertex(int64(6), "person"), graph.NewVertex(int64(3), "software"), ), }, }, { method: "ReadProperties", data: `{ "@type": "g:List", "@value": [ { "@type": "g:Property", "@value": { "key": "weight", "value": { "@type": "g:Double", "@value": 0.2 } } } ] }`, want: []graph.Property{ graph.NewProperty("weight", float64(0.2)), }, }, } for _, tc := range tests { tc := tc t.Run(tc.method, func(t *testing.T) { t.Parallel() var rsp Response rsp.Status.Code = StatusSuccess rsp.Result.Data = []byte(tc.data) vals := reflect.ValueOf(&rsp).MethodByName(tc.method).Call(nil) require.Len(t, vals, 2) require.True(t, vals[1].IsNil()) assert.Equal(t, tc.want, vals[0].Interface()) }) } } func TestResponseReadValueMap(t *testing.T) { t.Parallel() var rsp Response rsp.Status.Code = StatusSuccess rsp.Result.Data = []byte(`{ "@type": "g:List", "@value": [ { "@type": "g:Map", "@value": [ "name", { "@type": "g:List", "@value": [ "alex" ] } ] } ] }`) m, err := rsp.ReadValueMap() require.NoError(t, err) var name string err = m.Decode(&struct { Name *string `json:"name"` }{&name}) require.NoError(t, err) assert.Equal(t, "alex", name) } func TestResponseReadBool(t *testing.T) { tests := []struct { name string data string want bool wantErr bool }{ { name: "Simple", data: `{ "@type": "g:List", "@value": [ true ] }`, want: true, }, { name: "Multi", data: `{ "@type": "g:List", "@value": [ false, true ] }`, want: false, }, { name: "Empty", data: `{ "@type": "g:List", "@value": [] }`, wantErr: true, }, { name: "BadType", data: `{ "@type": "g:List", "@value": [ "user" ] }`, wantErr: true, }, } for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() var rsp Response rsp.Status.Code = StatusSuccess rsp.Result.Data = []byte(tc.data) got, err := rsp.ReadBool() if tc.wantErr { assert.Error(t, err) } else { require.NoError(t, err) assert.Equal(t, tc.want, got) } }) } } func TestResponseReadInt(t *testing.T) { tests := []struct { name string data string want int wantErr bool }{ { name: "Simple", data: `{ "@type": "g:List", "@value": [ { "@type": "g:Int64", "@value": 42 } ] }`, want: 42, }, { name: "Multi", data: `{ "@type": "g:List", "@value": [ { "@type": "g:Int64", "@value": 55 }, { "@type": "g:Int64", "@value": 13 } ] }`, want: 55, }, { name: "Empty", data: `{ "@type": "g:List", "@value": [] }`, wantErr: true, }, { name: "BadType", data: `{ "@type": "g:List", "@value": [ true ] }`, wantErr: true, }, } for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() var rsp Response rsp.Status.Code = StatusSuccess rsp.Result.Data = []byte(tc.data) got, err := rsp.ReadInt() if tc.wantErr { assert.Error(t, err) } else { require.NoError(t, err) assert.Equal(t, tc.want, got) } }) } } func TestResponseReadString(t *testing.T) { tests := []struct { name string data string want string wantErr bool }{ { name: "Simple", data: `{ "@type": "g:List", "@value": ["foo"] }`, want: "foo", }, { name: "Empty", data: `{ "@type": "g:List", "@value": [] }`, wantErr: true, }, { name: "BadType", 
data: `{ "@type": "g:List", "@value": [ true ] }`, wantErr: true, }, } for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() var rsp Response rsp.Status.Code = StatusSuccess rsp.Result.Data = []byte(tc.data) got, err := rsp.ReadString() if tc.wantErr { assert.Error(t, err) } else { require.NoError(t, err) assert.Equal(t, tc.want, got) } }) } } ent-0.11.3/dialect/gremlin/status.go000066400000000000000000000060571431500740500173050ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package gremlin const ( // StatusSuccess is returned on success. StatusSuccess = 200 // StatusNoContent means the server processed the request but there is no result to return. StatusNoContent = 204 // StatusPartialContent indicates the server successfully returned some content, but there // is more in the stream to arrive wait for a success code to signify the end. StatusPartialContent = 206 // StatusUnauthorized means the request attempted to access resources that // the requesting user did not have access to. StatusUnauthorized = 401 // StatusAuthenticate denotes a challenge from the server for the client to authenticate its request. StatusAuthenticate = 407 // StatusMalformedRequest means the request message was not properly formatted which means it could not be parsed at // all or the "op" code was not recognized such that Gremlin Server could properly route it for processing. // Check the message format and retry the request. StatusMalformedRequest = 498 // StatusInvalidRequestArguments means the request message was parsable, but the arguments supplied in the message // were in conflict or incomplete. Check the message format and retry the request. StatusInvalidRequestArguments = 499 // StatusServerError indicates a general server error occurred that prevented the request from being processed. StatusServerError = 500 // StatusScriptEvaluationError is returned when the script submitted for processing evaluated in the ScriptEngine // with errors and could not be processed. Check the script submitted for syntax errors or other problems // and then resubmit. StatusScriptEvaluationError = 597 // StatusServerTimeout means the server exceeded one of the timeout settings for the request and could therefore // only partially responded or did not respond at all. StatusServerTimeout = 598 // StatusServerSerializationError means the server was not capable of serializing an object that was returned from the // script supplied on the request. Either transform the object into something Gremlin Server can process within // the script or install mapper serialization classes to Gremlin Server. StatusServerSerializationError = 599 ) var statusText = map[int]string{ StatusSuccess: "Success", StatusNoContent: "No Content", StatusPartialContent: "Partial Content", StatusUnauthorized: "Unauthorized", StatusAuthenticate: "Authenticate", StatusMalformedRequest: "Malformed Request", StatusInvalidRequestArguments: "Invalid Request Arguments", StatusServerError: "Server Error", StatusScriptEvaluationError: "Script Evaluation Error", StatusServerTimeout: "Server Timeout", StatusServerSerializationError: "Server Serialization Error", } // StatusText returns status text of code. 
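// For example (editor's sketch), it can be used when reporting a failed
// response; `rsp` is assumed to be a Response decoded as in the tests above:
//
//	if err := rsp.Err(); err != nil {
//		return fmt.Errorf("gremlin: %s: %w", StatusText(rsp.Status.Code), err)
//	}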
func StatusText(code int) string { return statusText[code] } ent-0.11.3/dialect/gremlin/status_test.go000066400000000000000000000006121431500740500203330ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package gremlin import ( "testing" "github.com/stretchr/testify/assert" ) func TestStatusText(t *testing.T) { assert.NotEmpty(t, StatusText(StatusSuccess)) assert.Empty(t, StatusText(4242)) } ent-0.11.3/dialect/gremlin/tokens.go000066400000000000000000000045451431500740500172650ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package gremlin // Gremlin server operations. const ( // OpsAuthentication used by the client to authenticate itself. OpsAuthentication = "authentication" // OpsBytecode used for a request that contains the Bytecode representation of a Traversal. OpsBytecode = "bytecode" // OpsEval used to evaluate a Gremlin script provided as a string. OpsEval = "eval" // OpsGather used to get a particular side-effect as produced by a previously executed Traversal. OpsGather = "gather" // OpsKeys used to get all the keys of all side-effects as produced by a previously executed Traversal. OpsKeys = "keys" // OpsClose used to close a session, or to release the side-effects held by the server for a previously executed Traversal. OpsClose = "close" ) // Gremlin server operation processors. const ( // ProcessorTraversal is the default operation processor. ProcessorTraversal = "traversal" ) const ( // ArgsBatchSize allows defining the number of iterations each ResponseMessage should contain. ArgsBatchSize = "batchSize" // ArgsBindings allows providing a map of key/value pairs to apply // as variables in the context of the Gremlin script. ArgsBindings = "bindings" // ArgsAliases allows defining aliases that represent globally bound Graph and TraversalSource objects. ArgsAliases = "aliases" // ArgsGremlin corresponds to the Traversal to evaluate. ArgsGremlin = "gremlin" // ArgsSideEffect allows specifying the unique identifier for the request. ArgsSideEffect = "sideEffect" // ArgsSideEffectKey allows specifying the key for a specific side-effect. ArgsSideEffectKey = "sideEffectKey" // ArgsAggregateTo describes how side-effect data should be treated. ArgsAggregateTo = "aggregateTo" // ArgsLanguage allows changing the flavor of Gremlin used (e.g. gremlin-groovy). ArgsLanguage = "language" // ArgsEvalTimeout allows overriding the server setting that determines // the maximum time to wait for a script to execute on the server. ArgsEvalTimeout = "scriptEvaluationTimeout" // ArgsSasl defines the response to the server authentication challenge. ArgsSasl = "sasl" // ArgsSaslMechanism defines the SASL mechanism (e.g. PLAIN). ArgsSaslMechanism = "saslMechanism" ) ent-0.11.3/dialect/sql/000077500000000000000000000000001431500740500145655ustar00rootroot00000000000000ent-0.11.3/dialect/sql/bench_test.go000066400000000000000000000020271431500740500172330ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree.
package sql import ( "testing" "entgo.io/ent/dialect" ) func BenchmarkInsertBuilder_Default(b *testing.B) { for _, d := range []string{dialect.SQLite, dialect.MySQL, dialect.Postgres} { b.Run(d, func(b *testing.B) { b.ReportAllocs() for i := 0; i < b.N; i++ { Dialect(d).Insert("users").Default().Returning("id").Query() } }) } } func BenchmarkInsertBuilder_Small(b *testing.B) { for _, d := range []string{dialect.SQLite, dialect.MySQL, dialect.Postgres} { b.Run(d, func(b *testing.B) { b.ReportAllocs() for i := 0; i < b.N; i++ { Dialect(d).Insert("users"). Columns("id", "age", "first_name", "last_name", "nickname", "spouse_id", "created_at", "updated_at"). Values(1, 30, "Ariel", "Mashraki", "a8m", 2, "2009-11-10 23:00:00", "2009-11-10 23:00:00"). Returning("id"). Query() } }) } } ent-0.11.3/dialect/sql/builder.go000066400000000000000000002643071431500740500165540ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. // Package sql provides wrappers around the standard database/sql package // to allow the generated code to interact with a statically-typed API. // // Users that are interacting with this package should be aware that the // following builders don't check the given SQL syntax nor validate or escape // user-inputs. Almost all validations are expected to happen in the generated // ent package. package sql import ( "context" "database/sql/driver" "errors" "fmt" "strconv" "strings" "entgo.io/ent/dialect" ) // Querier wraps the basic Query method that is implemented // by the different builders in this file. type Querier interface { // Query returns the query representation of the element // and its arguments (if any). Query() (string, []any) } // querierErr allows propagating Querier's inner error. type querierErr interface { Err() error } // ColumnBuilder is a builder for column definition in table creation. type ColumnBuilder struct { Builder typ string // column type. name string // column name. attr string // extra attributes. modify bool // modify existing. fk *ForeignKeyBuilder // foreign-key constraint. check func(*Builder) // column checks. } // Column returns a new ColumnBuilder with the given name. // // sql.Column("group_id").Type("int").Attr("UNIQUE") func Column(name string) *ColumnBuilder { return &ColumnBuilder{name: name} } // Type sets the column type. func (c *ColumnBuilder) Type(t string) *ColumnBuilder { c.typ = t return c } // Attr sets an extra attribute for the column, like UNIQUE or AUTO_INCREMENT. func (c *ColumnBuilder) Attr(attr string) *ColumnBuilder { if c.attr != "" && attr != "" { c.attr += " " } c.attr += attr return c } // Constraint adds the CONSTRAINT clause to the ADD COLUMN statement in SQLite. func (c *ColumnBuilder) Constraint(fk *ForeignKeyBuilder) *ColumnBuilder { c.fk = fk return c } // Check adds a CHECK clause to the ADD COLUMN statement. func (c *ColumnBuilder) Check(check func(*Builder)) *ColumnBuilder { c.check = check return c } // Query returns query representation of a Column.
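// For example (editor's sketch), with the default dialect the column
// definition above renders roughly as shown:
//
//	query, _ := Column("group_id").Type("int").Attr("UNIQUE").Query()
//	// query is roughly: `group_id` int UNIQUE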
func (c *ColumnBuilder) Query() (string, []any) { c.Ident(c.name) if c.typ != "" { if c.postgres() && c.modify { c.WriteString(" TYPE") } c.Pad().WriteString(c.typ) } if c.attr != "" { c.Pad().WriteString(c.attr) } if c.fk != nil { c.WriteString(" CONSTRAINT " + c.fk.symbol) c.Pad().Join(c.fk.ref) for _, action := range c.fk.actions { c.Pad().WriteString(action) } } if c.check != nil { c.WriteString(" CHECK ") c.Nested(c.check) } return c.String(), c.args } // TableBuilder is a query builder for `CREATE TABLE` statement. type TableBuilder struct { Builder name string // table name. exists bool // check existence. charset string // table charset. collation string // table collation. options string // table options. columns []Querier // table columns. primary []string // primary key. constraints []Querier // foreign keys and indices. checks []func(*Builder) // check constraints. } // CreateTable returns a query builder for the `CREATE TABLE` statement. // // CreateTable("users"). // Columns( // Column("id").Type("int").Attr("auto_increment"), // Column("name").Type("varchar(255)"), // ). // PrimaryKey("id") func CreateTable(name string) *TableBuilder { return &TableBuilder{name: name} } // IfNotExists appends the `IF NOT EXISTS` clause to the `CREATE TABLE` statement. func (t *TableBuilder) IfNotExists() *TableBuilder { t.exists = true return t } // Column appends the given column to the `CREATE TABLE` statement. func (t *TableBuilder) Column(c *ColumnBuilder) *TableBuilder { t.columns = append(t.columns, c) return t } // Columns appends a list of columns to the builder. func (t *TableBuilder) Columns(columns ...*ColumnBuilder) *TableBuilder { t.columns = make([]Querier, 0, len(columns)) for i := range columns { t.columns = append(t.columns, columns[i]) } return t } // PrimaryKey adds a column to the primary-key constraint in the statement. func (t *TableBuilder) PrimaryKey(column ...string) *TableBuilder { t.primary = append(t.primary, column...) return t } // ForeignKeys adds a list of foreign-keys to the statement (without constraints). func (t *TableBuilder) ForeignKeys(fks ...*ForeignKeyBuilder) *TableBuilder { queries := make([]Querier, len(fks)) for i := range fks { // Erase the constraint symbol/name. fks[i].symbol = "" queries[i] = fks[i] } t.constraints = append(t.constraints, queries...) return t } // Constraints adds a list of foreign-key constraints to the statement. func (t *TableBuilder) Constraints(fks ...*ForeignKeyBuilder) *TableBuilder { queries := make([]Querier, len(fks)) for i := range fks { queries[i] = &Wrapper{"CONSTRAINT %s", fks[i]} } t.constraints = append(t.constraints, queries...) return t } // Checks adds CHECK clauses to the CREATE TABLE statement. func (t *TableBuilder) Checks(checks ...func(*Builder)) *TableBuilder { t.checks = append(t.checks, checks...) return t } // Charset appends the `CHARACTER SET` clause to the statement. MySQL only. func (t *TableBuilder) Charset(s string) *TableBuilder { t.charset = s return t } // Collate appends the `COLLATE` clause to the statement. MySQL only. func (t *TableBuilder) Collate(s string) *TableBuilder { t.collation = s return t } // Options appends additional options to the statement (MySQL only). func (t *TableBuilder) Options(s string) *TableBuilder { t.options = s return t } // Query returns query representation of a `CREATE TABLE` statement.
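// For example (editor's sketch), the builder from the CreateTable doc above
// renders roughly as follows with the default dialect:
//
//	CreateTable("users").
//		Columns(
//			Column("id").Type("int").Attr("auto_increment"),
//			Column("name").Type("varchar(255)"),
//		).
//		PrimaryKey("id").
//		Query()
//	// CREATE TABLE `users`(`id` int auto_increment, `name` varchar(255), PRIMARY KEY(`id`))
//
// The general shape is: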
// // CREATE TABLE [IF NOT EXISTS] name // // (table definition) // [charset and collation] func (t *TableBuilder) Query() (string, []any) { t.WriteString("CREATE TABLE ") if t.exists { t.WriteString("IF NOT EXISTS ") } t.Ident(t.name) t.Nested(func(b *Builder) { b.JoinComma(t.columns...) if len(t.primary) > 0 { b.Comma().WriteString("PRIMARY KEY") b.Nested(func(b *Builder) { b.IdentComma(t.primary...) }) } if len(t.constraints) > 0 { b.Comma().JoinComma(t.constraints...) } for _, check := range t.checks { check(b.Comma()) } }) if t.charset != "" { t.WriteString(" CHARACTER SET " + t.charset) } if t.collation != "" { t.WriteString(" COLLATE " + t.collation) } if t.options != "" { t.WriteString(" " + t.options) } return t.String(), t.args } // DescribeBuilder is a query builder for `DESCRIBE` statement. type DescribeBuilder struct { Builder name string // table name. } // Describe returns a query builder for the `DESCRIBE` statement. // // Describe("users") func Describe(name string) *DescribeBuilder { return &DescribeBuilder{name: name} } // Query returns query representation of a `DESCRIBE` statement. func (t *DescribeBuilder) Query() (string, []any) { t.WriteString("DESCRIBE ") t.Ident(t.name) return t.String(), nil } // TableAlter is a query builder for `ALTER TABLE` statement. type TableAlter struct { Builder name string // table to alter. Queries []Querier // columns and foreign-keys to add. } // AlterTable returns a query builder for the `ALTER TABLE` statement. // // AlterTable("users"). // AddColumn(Column("group_id").Type("int").Attr("UNIQUE")). // AddForeignKey(ForeignKey().Columns("group_id"). // Reference(Reference().Table("groups").Columns("id")).OnDelete("CASCADE")), // ) func AlterTable(name string) *TableAlter { return &TableAlter{name: name} } // AddColumn appends the `ADD COLUMN` clause to the given `ALTER TABLE` statement. func (t *TableAlter) AddColumn(c *ColumnBuilder) *TableAlter { t.Queries = append(t.Queries, &Wrapper{"ADD COLUMN %s", c}) return t } // ModifyColumn appends the `MODIFY/ALTER COLUMN` clause to the given `ALTER TABLE` statement. func (t *TableAlter) ModifyColumn(c *ColumnBuilder) *TableAlter { switch { case t.postgres(): c.modify = true t.Queries = append(t.Queries, &Wrapper{"ALTER COLUMN %s", c}) default: t.Queries = append(t.Queries, &Wrapper{"MODIFY COLUMN %s", c}) } return t } // RenameColumn appends the `RENAME COLUMN` clause to the given `ALTER TABLE` statement. func (t *TableAlter) RenameColumn(old, new string) *TableAlter { t.Queries = append(t.Queries, Raw(fmt.Sprintf("RENAME COLUMN %s TO %s", t.Quote(old), t.Quote(new)))) return t } // ModifyColumns calls ModifyColumn with each of the given builders. func (t *TableAlter) ModifyColumns(cs ...*ColumnBuilder) *TableAlter { for _, c := range cs { t.ModifyColumn(c) } return t } // DropColumn appends the `DROP COLUMN` clause to the given `ALTER TABLE` statement. func (t *TableAlter) DropColumn(c *ColumnBuilder) *TableAlter { t.Queries = append(t.Queries, &Wrapper{"DROP COLUMN %s", c}) return t } // ChangeColumn appends the `CHANGE COLUMN` clause to the given `ALTER TABLE` statement. func (t *TableAlter) ChangeColumn(name string, c *ColumnBuilder) *TableAlter { prefix := fmt.Sprintf("CHANGE COLUMN %s", t.Quote(name)) t.Queries = append(t.Queries, &Wrapper{prefix + " %s", c}) return t } // RenameIndex appends the `RENAME INDEX` clause to the given `ALTER TABLE` statement. 
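// For example (editor's sketch), several alterations can be chained on one
// TableAlter; Query joins them with commas:
//
//	AlterTable("users").
//		AddColumn(Column("group_id").Type("int").Attr("UNIQUE")).
//		AddForeignKey(ForeignKey().Columns("group_id").
//			Reference(Reference().Table("groups").Columns("id")).
//			OnDelete("CASCADE")).
//		Query()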
func (t *TableAlter) RenameIndex(curr, new string) *TableAlter { t.Queries = append(t.Queries, Raw(fmt.Sprintf("RENAME INDEX %s TO %s", t.Quote(curr), t.Quote(new)))) return t } // DropIndex appends the `DROP INDEX` clause to the given `ALTER TABLE` statement. func (t *TableAlter) DropIndex(name string) *TableAlter { t.Queries = append(t.Queries, Raw(fmt.Sprintf("DROP INDEX %s", t.Quote(name)))) return t } // AddIndex appends the `ADD INDEX` clause to the given `ALTER TABLE` statement. func (t *TableAlter) AddIndex(idx *IndexBuilder) *TableAlter { b := &Builder{dialect: t.dialect} b.WriteString("ADD ") if idx.unique { b.WriteString("UNIQUE ") } b.WriteString("INDEX ") b.Ident(idx.name) b.Nested(func(b *Builder) { b.IdentComma(idx.columns...) }) t.Queries = append(t.Queries, b) return t } // AddForeignKey adds a foreign key constraint to the `ALTER TABLE` statement. func (t *TableAlter) AddForeignKey(fk *ForeignKeyBuilder) *TableAlter { t.Queries = append(t.Queries, &Wrapper{"ADD CONSTRAINT %s", fk}) return t } // DropConstraint appends the `DROP CONSTRAINT` clause to the given `ALTER TABLE` statement. func (t *TableAlter) DropConstraint(ident string) *TableAlter { t.Queries = append(t.Queries, Raw(fmt.Sprintf("DROP CONSTRAINT %s", t.Quote(ident)))) return t } // DropForeignKey appends the `DROP FOREIGN KEY` clause to the given `ALTER TABLE` statement. func (t *TableAlter) DropForeignKey(ident string) *TableAlter { t.Queries = append(t.Queries, Raw(fmt.Sprintf("DROP FOREIGN KEY %s", t.Quote(ident)))) return t } // Query returns query representation of the `ALTER TABLE` statement. // // ALTER TABLE name // [alter_specification] func (t *TableAlter) Query() (string, []any) { t.WriteString("ALTER TABLE ") t.Ident(t.name) t.Pad() t.JoinComma(t.Queries...) return t.String(), t.args } // IndexAlter is a query builder for `ALTER INDEX` statement. type IndexAlter struct { Builder name string // index to alter. Queries []Querier // alter options. } // AlterIndex returns a query builder for the `ALTER INDEX` statement. // // AlterIndex("old_key"). // Rename("new_key") func AlterIndex(name string) *IndexAlter { return &IndexAlter{name: name} } // Rename appends the `RENAME TO` clause to the `ALTER INDEX` statement. func (i *IndexAlter) Rename(name string) *IndexAlter { i.Queries = append(i.Queries, Raw(fmt.Sprintf("RENAME TO %s", i.Quote(name)))) return i } // Query returns query representation of the `ALTER INDEX` statement. // // ALTER INDEX name // [alter_specification] func (i *IndexAlter) Query() (string, []any) { i.WriteString("ALTER INDEX ") i.Ident(i.name) i.Pad() i.JoinComma(i.Queries...) return i.String(), i.args } // ForeignKeyBuilder is the builder for the foreign-key constraint clause. type ForeignKeyBuilder struct { Builder symbol string columns []string actions []string ref *ReferenceBuilder } // ForeignKey returns a builder for the foreign-key constraint clause in create/alter table statements. // // ForeignKey(). // Columns("group_id"). // Reference(Reference().Table("groups").Columns("id")). // OnDelete("CASCADE") func ForeignKey(symbol ...string) *ForeignKeyBuilder { fk := &ForeignKeyBuilder{} if len(symbol) != 0 { fk.symbol = symbol[0] } return fk } // Symbol sets the symbol of the foreign key. func (fk *ForeignKeyBuilder) Symbol(s string) *ForeignKeyBuilder { fk.symbol = s return fk } // Columns sets the columns of the foreign key in the source table. func (fk *ForeignKeyBuilder) Columns(s ...string) *ForeignKeyBuilder { fk.columns = append(fk.columns, s...) 
return fk } // Reference sets the reference clause. func (fk *ForeignKeyBuilder) Reference(r *ReferenceBuilder) *ForeignKeyBuilder { fk.ref = r return fk } // OnDelete sets the on delete action for this constraint. func (fk *ForeignKeyBuilder) OnDelete(action string) *ForeignKeyBuilder { fk.actions = append(fk.actions, "ON DELETE "+action) return fk } // OnUpdate sets the on update action for this constraint. func (fk *ForeignKeyBuilder) OnUpdate(action string) *ForeignKeyBuilder { fk.actions = append(fk.actions, "ON UPDATE "+action) return fk } // Query returns query representation of a foreign key constraint. func (fk *ForeignKeyBuilder) Query() (string, []any) { if fk.symbol != "" { fk.Ident(fk.symbol).Pad() } fk.WriteString("FOREIGN KEY") fk.Nested(func(b *Builder) { b.IdentComma(fk.columns...) }) fk.Pad().Join(fk.ref) for _, action := range fk.actions { fk.Pad().WriteString(action) } return fk.String(), fk.args } // ReferenceBuilder is a builder for the reference clause in constraints. For example, in foreign key creation. type ReferenceBuilder struct { Builder table string // referenced table. columns []string // referenced columns. } // Reference creates a reference builder for the reference_option clause. // // Reference().Table("groups").Columns("id") func Reference() *ReferenceBuilder { return &ReferenceBuilder{} } // Table sets the referenced table. func (r *ReferenceBuilder) Table(s string) *ReferenceBuilder { r.table = s return r } // Columns sets the columns of the referenced table. func (r *ReferenceBuilder) Columns(s ...string) *ReferenceBuilder { r.columns = append(r.columns, s...) return r } // Query returns query representation of a reference clause. func (r *ReferenceBuilder) Query() (string, []any) { r.WriteString("REFERENCES ") r.Ident(r.table) r.Nested(func(b *Builder) { b.IdentComma(r.columns...) }) return r.String(), r.args } // IndexBuilder is a builder for `CREATE INDEX` statement. type IndexBuilder struct { Builder name string unique bool exists bool table string method string columns []string } // CreateIndex creates a builder for the `CREATE INDEX` statement. // // CreateIndex("index_name"). // Unique(). // Table("users"). // Column("name") // // Or: // // CreateIndex("index_name"). // Unique(). // Table("users"). // Columns("name", "age") func CreateIndex(name string) *IndexBuilder { return &IndexBuilder{name: name} } // IfNotExists appends the `IF NOT EXISTS` clause to the `CREATE INDEX` statement. func (i *IndexBuilder) IfNotExists() *IndexBuilder { i.exists = true return i } // Unique sets the index to be a unique index. func (i *IndexBuilder) Unique() *IndexBuilder { i.unique = true return i } // Table defines the table for the index. func (i *IndexBuilder) Table(table string) *IndexBuilder { i.table = table return i } // Using sets the method to create the index with. func (i *IndexBuilder) Using(method string) *IndexBuilder { i.method = method return i } // Column appends a column to the column list for the index. func (i *IndexBuilder) Column(column string) *IndexBuilder { i.columns = append(i.columns, column) return i } // Columns appends the given columns to the column list for the index. func (i *IndexBuilder) Columns(columns ...string) *IndexBuilder { i.columns = append(i.columns, columns...) return i } // Query returns query representation of a `CREATE INDEX` statement.
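// For example (editor's sketch):
//
//	query, _ := CreateIndex("unique_name").
//		Unique().
//		Table("users").
//		Columns("first", "last").
//		Query()
//	// query is roughly: CREATE UNIQUE INDEX `unique_name` ON `users`(`first`, `last`)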
func (i *IndexBuilder) Query() (string, []any) { i.WriteString("CREATE ") if i.unique { i.WriteString("UNIQUE ") } i.WriteString("INDEX ") if i.exists { i.WriteString("IF NOT EXISTS ") } i.Ident(i.name) i.WriteString(" ON ") i.Ident(i.table) switch i.dialect { case dialect.Postgres: if i.method != "" { i.WriteString(" USING ").Ident(i.method) } i.Nested(func(b *Builder) { b.IdentComma(i.columns...) }) case dialect.MySQL: i.Nested(func(b *Builder) { b.IdentComma(i.columns...) }) if i.method != "" { i.WriteString(" USING " + i.method) } default: i.Nested(func(b *Builder) { b.IdentComma(i.columns...) }) } return i.String(), nil } // DropIndexBuilder is a builder for `DROP INDEX` statement. type DropIndexBuilder struct { Builder name string table string } // DropIndex creates a builder for the `DROP INDEX` statement. // // MySQL: // // DropIndex("index_name"). // Table("users"). // // SQLite/PostgreSQL: // // DropIndex("index_name") func DropIndex(name string) *DropIndexBuilder { return &DropIndexBuilder{name: name} } // Table defines the table for the index. func (d *DropIndexBuilder) Table(table string) *DropIndexBuilder { d.table = table return d } // Query returns query representation of a `DROP INDEX` statement. // // DROP INDEX index_name [ON table_name] func (d *DropIndexBuilder) Query() (string, []any) { d.WriteString("DROP INDEX ") d.Ident(d.name) if d.table != "" { d.WriteString(" ON ") d.Ident(d.table) } return d.String(), nil } // InsertBuilder is a builder for `INSERT INTO` statement. type InsertBuilder struct { Builder table string schema string columns []string defaults bool returning []string values [][]any conflict *conflict } // Insert creates a builder for the `INSERT INTO` statement. // // Insert("users"). // Columns("name", "age"). // Values("a8m", 10). // Values("foo", 20) // // Note: Insert inserts all values in one batch. func Insert(table string) *InsertBuilder { return &InsertBuilder{table: table} } // Schema sets the database name for the insert table. func (i *InsertBuilder) Schema(name string) *InsertBuilder { i.schema = name return i } // Set is a syntactic sugar API for inserting only one row. func (i *InsertBuilder) Set(column string, v any) *InsertBuilder { i.columns = append(i.columns, column) if len(i.values) == 0 { i.values = append(i.values, []any{v}) } else { i.values[0] = append(i.values[0], v) } return i } // Columns appends columns to the INSERT statement. func (i *InsertBuilder) Columns(columns ...string) *InsertBuilder { i.columns = append(i.columns, columns...) return i } // Values appends a value tuple for the insert statement. func (i *InsertBuilder) Values(values ...any) *InsertBuilder { i.values = append(i.values, values) return i } // Default sets the default values clause based on the dialect type. func (i *InsertBuilder) Default() *InsertBuilder { i.defaults = true return i } // Returning adds the `RETURNING` clause to the insert statement. Supported by SQLite and PostgreSQL. func (i *InsertBuilder) Returning(columns ...string) *InsertBuilder { i.returning = columns return i } type ( // conflict holds the configuration for the // `ON CONFLICT` / `ON DUPLICATE KEY` clause. conflict struct { target struct { constraint string columns []string where *Predicate } action struct { nothing bool where *Predicate update []func(*UpdateSet) } } // ConflictOption allows configuring the // conflict config using functional options.
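// For example (editor's sketch), options are passed to OnConflict:
//
//	Insert("users").
//		Columns("id", "name").
//		Values(1, "a8m").
//		OnConflict(
//			ConflictColumns("id"),
//			DoNothing(),
//		)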
ConflictOption func(*conflict) ) // ConflictColumns sets the unique constraints that trigger the conflict // resolution on insert to perform an upsert operation. The columns must // have a unique constraint applied to trigger this behaviour. // // sql.Insert("users"). // Columns("id", "name"). // Values(1, "Mashraki"). // OnConflict( // sql.ConflictColumns("id"), // sql.ResolveWithNewValues(), // ) func ConflictColumns(names ...string) ConflictOption { return func(c *conflict) { c.target.columns = names } } // ConflictConstraint allows setting the constraint // name (i.e. `ON CONSTRAINT `) for PostgreSQL. // // sql.Insert("users"). // Columns("id", "name"). // Values(1, "Mashraki"). // OnConflict( // sql.ConflictConstraint("users_pkey"), // sql.ResolveWithNewValues(), // ) func ConflictConstraint(name string) ConflictOption { return func(c *conflict) { c.target.constraint = name } } // ConflictWhere allows inference of partial unique indexes. See the PostgreSQL // docs: https://www.postgresql.org/docs/current/sql-insert.html#SQL-ON-CONFLICT func ConflictWhere(p *Predicate) ConflictOption { return func(c *conflict) { c.target.where = p } } // UpdateWhere allows setting an update condition. Only rows // for which this expression returns true will be updated. func UpdateWhere(p *Predicate) ConflictOption { return func(c *conflict) { c.action.where = p } } // DoNothing configures the conflict_action to `DO NOTHING`. // Supported by SQLite and PostgreSQL. // // sql.Insert("users"). // Columns("id", "name"). // Values(1, "Mashraki"). // OnConflict( // sql.ConflictColumns("id"), // sql.DoNothing() // ) func DoNothing() ConflictOption { return func(c *conflict) { c.action.nothing = true } } // ResolveWithIgnore sets each column to itself to force an update and return the ID, // otherwise does not change any data. This may still trigger update hooks in the database. // // sql.Insert("users"). // Columns("id"). // Values(1). // OnConflict( // sql.ConflictColumns("id"), // sql.ResolveWithIgnore() // ) // // // Output: // // MySQL: INSERT INTO `users` (`id`) VALUES(1) ON DUPLICATE KEY UPDATE `id` = `users`.`id` // // PostgreSQL: INSERT INTO "users" ("id") VALUES(1) ON CONFLICT ("id") DO UPDATE SET "id" = "users"."id" func ResolveWithIgnore() ConflictOption { return func(c *conflict) { c.action.update = append(c.action.update, func(u *UpdateSet) { for _, c := range u.columns { u.SetIgnore(c) } }) } } // ResolveWithNewValues updates columns using the new values proposed // for insertion using the special EXCLUDED/VALUES table. // // sql.Insert("users"). // Columns("id", "name"). // Values(1, "Mashraki"). // OnConflict( // sql.ConflictColumns("id"), // sql.ResolveWithNewValues() // ) // // // Output: // // MySQL: INSERT INTO `users` (`id`, `name`) VALUES(1, 'Mashraki') ON DUPLICATE KEY UPDATE `id` = VALUES(`id`), `name` = VALUES(`name`) // // PostgreSQL: INSERT INTO "users" ("id", "name") VALUES(1, 'Mashraki') ON CONFLICT ("id") DO UPDATE SET "id" = "excluded"."id", "name" = "excluded"."name" func ResolveWithNewValues() ConflictOption { return func(c *conflict) { c.action.update = append(c.action.update, func(u *UpdateSet) { for _, c := range u.columns { u.SetExcluded(c) } }) } } // ResolveWith allows setting a custom function to set the `UPDATE` clause. // // Insert("users"). // Columns("id", "name"). // Values(1, "Mashraki").
// OnConflict( // ConflictColumns("name"), // ResolveWith(func(u *UpdateSet) { // u.SetIgnore("id") // u.SetNull("created_at") // u.Set("name", Expr(u.Excluded().C("name"))) // }), // ) func ResolveWith(fn func(*UpdateSet)) ConflictOption { return func(c *conflict) { c.action.update = append(c.action.update, fn) } } // OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause // of the `INSERT` statement. For example: // // sql.Insert("users"). // Columns("id", "name"). // Values(1, "Mashraki"). // OnConflict( // sql.ConflictColumns("id"), // sql.ResolveWithNewValues() // ) func (i *InsertBuilder) OnConflict(opts ...ConflictOption) *InsertBuilder { if i.conflict == nil { i.conflict = &conflict{} } for _, opt := range opts { opt(i.conflict) } return i } // UpdateSet describes a set of changes of the `DO UPDATE` clause. type UpdateSet struct { columns []string update *UpdateBuilder } // Table returns the table the `UPSERT` statement is executed on. func (u *UpdateSet) Table() *SelectTable { return Dialect(u.update.dialect).Table(u.update.table) } // Columns returns all columns in the `INSERT` statement. func (u *UpdateSet) Columns() []string { return u.columns } // UpdateColumns returns all columns in the `UPDATE` statement. func (u *UpdateSet) UpdateColumns() []string { return append(u.update.nulls, u.update.columns...) } // Set sets a column to a given value. func (u *UpdateSet) Set(column string, v any) *UpdateSet { u.update.Set(column, v) return u } // Add adds a numeric value to the given column. func (u *UpdateSet) Add(column string, v any) *UpdateSet { u.update.Add(column, v) return u } // SetNull sets a column as null value. func (u *UpdateSet) SetNull(column string) *UpdateSet { u.update.SetNull(column) return u } // SetIgnore sets the column to itself. For example, "id" = "users"."id". func (u *UpdateSet) SetIgnore(name string) *UpdateSet { return u.Set(name, Expr(u.Table().C(name))) } // SetExcluded sets the column name to its EXCLUDED/VALUES value. // For example, "c" = "excluded"."c", or `c` = VALUES(`c`). func (u *UpdateSet) SetExcluded(name string) *UpdateSet { switch u.update.Dialect() { case dialect.MySQL: u.update.Set(name, ExprFunc(func(b *Builder) { b.WriteString("VALUES(").Ident(name).WriteByte(')') })) default: t := Dialect(u.update.dialect).Table("excluded") u.update.Set(name, Expr(t.C(name))) } return u } // Query returns query representation of an `INSERT INTO` statement. func (i *InsertBuilder) Query() (string, []any) { i.WriteString("INSERT INTO ") i.writeSchema(i.schema) i.Ident(i.table).Pad() if i.defaults && len(i.columns) == 0 { i.writeDefault() } else { i.WriteByte('(').IdentComma(i.columns...).WriteByte(')') i.WriteString(" VALUES ") for j, v := range i.values { if j > 0 { i.Comma() } i.WriteByte('(').Args(v...).WriteByte(')') } } if i.conflict != nil { i.writeConflict() } if len(i.returning) > 0 && !i.mysql() { i.WriteString(" RETURNING ") i.IdentComma(i.returning...) 
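	// Editor's note: with the Postgres dialect, an upsert built as
	//
	//	Dialect(dialect.Postgres).Insert("users").
	//		Columns("id", "name").
	//		Values(1, "Mashraki").
	//		OnConflict(ConflictColumns("id"), ResolveWithNewValues()).
	//		Query()
	//
	// renders roughly as:
	//
	//	INSERT INTO "users" ("id", "name") VALUES ($1, $2) ON CONFLICT ("id")
	//	DO UPDATE SET "id" = "excluded"."id", "name" = "excluded"."name"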
} return i.String(), i.args } func (i *InsertBuilder) writeDefault() { switch i.Dialect() { case dialect.MySQL: i.WriteString("VALUES ()") case dialect.SQLite, dialect.Postgres: i.WriteString("DEFAULT VALUES") } } func (i *InsertBuilder) writeConflict() { switch i.Dialect() { case dialect.MySQL: i.WriteString(" ON DUPLICATE KEY UPDATE ") if i.conflict.action.nothing { i.AddError(fmt.Errorf("invalid CONFLICT action ('DO NOTHING')")) } case dialect.SQLite, dialect.Postgres: i.WriteString(" ON CONFLICT") switch t := i.conflict.target; { case t.constraint != "" && len(t.columns) != 0: i.AddError(fmt.Errorf("duplicate CONFLICT clauses: %q, %q", t.constraint, t.columns)) case t.constraint != "": i.WriteString(" ON CONSTRAINT ").Ident(t.constraint) case len(t.columns) != 0: i.WriteString(" (").IdentComma(t.columns...).WriteByte(')') } if p := i.conflict.target.where; p != nil { i.WriteString(" WHERE ").Join(p) } if i.conflict.action.nothing { i.WriteString(" DO NOTHING") return } i.WriteString(" DO UPDATE SET ") } if len(i.conflict.action.update) == 0 { i.AddError(errors.New("missing action for 'DO UPDATE SET' clause")) } u := &UpdateSet{columns: i.columns, update: Dialect(i.dialect).Update(i.table)} u.update.Builder = i.Builder for _, f := range i.conflict.action.update { f(u) } u.update.writeSetter(&i.Builder) if p := i.conflict.action.where; p != nil { p.qualifier = i.table i.WriteString(" WHERE ").Join(p) } } // UpdateBuilder is a builder for `UPDATE` statement. type UpdateBuilder struct { Builder table string schema string where *Predicate nulls []string columns []string values []any order []any prefix Queries } // Update creates a builder for the `UPDATE` statement. // // Update("users").Set("name", "foo").Set("age", 10) func Update(table string) *UpdateBuilder { return &UpdateBuilder{table: table} } // Schema sets the database name for the updated table. func (u *UpdateBuilder) Schema(name string) *UpdateBuilder { u.schema = name return u } // Set sets a column to a given value. If `Set` was called before with // the same column name, it overrides the value of the previous call. func (u *UpdateBuilder) Set(column string, v any) *UpdateBuilder { for i := range u.columns { if column == u.columns[i] { u.values[i] = v return u } } u.columns = append(u.columns, column) u.values = append(u.values, v) return u } // Add adds a numeric value to the given column. Note that, calling Set(c) // after Add(c) will erase previous calls with c from the builder. func (u *UpdateBuilder) Add(column string, v any) *UpdateBuilder { u.columns = append(u.columns, column) u.values = append(u.values, ExprFunc(func(b *Builder) { b.WriteString("COALESCE") b.Nested(func(b *Builder) { b.Ident(Table(u.table).C(column)).Comma().WriteByte('0') }) b.WriteString(" + ") b.Arg(v) })) return u } // SetNull sets a column as null value. func (u *UpdateBuilder) SetNull(column string) *UpdateBuilder { u.nulls = append(u.nulls, column) return u } // Where adds a where predicate for update statement. func (u *UpdateBuilder) Where(p *Predicate) *UpdateBuilder { if u.where != nil { u.where = And(u.where, p) } else { u.where = p } return u } // FromSelect makes it possible to update entities that match the sub-query. func (u *UpdateBuilder) FromSelect(s *Selector) *UpdateBuilder { u.Where(s.where) if t := s.Table(); t != nil { u.table = t.name } return u } // Empty reports whether this builder does not contain update changes. 
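// For example (editor's sketch), Update("users").Where(EQ("id", 1)) alone is
// Empty, while the following builder is not, and renders roughly as shown:
//
//	query, args := Update("users").
//		Set("name", "foo").
//		Where(EQ("id", 1)).
//		Query()
//	// query: UPDATE `users` SET `name` = ? WHERE `id` = ?
//	// args:  ["foo", 1]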
func (u *UpdateBuilder) Empty() bool { return len(u.columns) == 0 && len(u.nulls) == 0 } // OrderBy appends the `ORDER BY` clause to the `UPDATE` statement. // Supported by SQLite and MySQL. func (u *UpdateBuilder) OrderBy(columns ...string) *UpdateBuilder { if u.postgres() { u.AddError(errors.New("ORDER BY is not supported by PostgreSQL")) return u } for i := range columns { u.order = append(u.order, columns[i]) } return u } // Prefix prefixes the UPDATE statement with a list of statements. func (u *UpdateBuilder) Prefix(stmts ...Querier) *UpdateBuilder { u.prefix = append(u.prefix, stmts...) return u } // Query returns query representation of an `UPDATE` statement. func (u *UpdateBuilder) Query() (string, []any) { b := u.Builder.clone() if len(u.prefix) > 0 { b.join(u.prefix, " ") b.Pad() } b.WriteString("UPDATE ") b.writeSchema(u.schema) b.Ident(u.table).WriteString(" SET ") u.writeSetter(&b) if u.where != nil { b.WriteString(" WHERE ") b.Join(u.where) } joinOrder(u.order, &b) return b.String(), b.args } // writeSetter writes the "SET" clause for the UPDATE statement. func (u *UpdateBuilder) writeSetter(b *Builder) { for i, c := range u.nulls { if i > 0 { b.Comma() } b.Ident(c).WriteString(" = NULL") } if len(u.nulls) > 0 && len(u.columns) > 0 { b.Comma() } for i, c := range u.columns { if i > 0 { b.Comma() } b.Ident(c).WriteString(" = ") switch v := u.values[i].(type) { case Querier: b.Join(v) default: b.Arg(v) } } } // DeleteBuilder is a builder for `DELETE` statement. type DeleteBuilder struct { Builder table string schema string where *Predicate } // Delete creates a builder for the `DELETE` statement. // // Delete("users"). // Where( // Or( // EQ("name", "foo").And().EQ("age", 10), // EQ("name", "bar").And().EQ("age", 20), // And( // EQ("name", "qux"), // EQ("age", 1).Or().EQ("age", 2), // ), // ), // ) func Delete(table string) *DeleteBuilder { return &DeleteBuilder{table: table} } // Schema sets the database name for the table whose row will be deleted. func (d *DeleteBuilder) Schema(name string) *DeleteBuilder { d.schema = name return d } // Where appends a where predicate to the `DELETE` statement. func (d *DeleteBuilder) Where(p *Predicate) *DeleteBuilder { if d.where != nil { d.where = And(d.where, p) } else { d.where = p } return d } // FromSelect makes it possible to delete rows that match the sub-query. func (d *DeleteBuilder) FromSelect(s *Selector) *DeleteBuilder { d.Where(s.where) if t := s.Table(); t != nil { d.table = t.name } return d } // Query returns query representation of a `DELETE` statement. func (d *DeleteBuilder) Query() (string, []any) { d.WriteString("DELETE FROM ") d.writeSchema(d.schema) d.Ident(d.table) if d.where != nil { d.WriteString(" WHERE ") d.Join(d.where) } return d.String(), d.args } // Predicate is a where predicate. type Predicate struct { Builder depth int fns []func(*Builder) } // P creates a new predicate. // // P().EQ("name", "a8m").And().EQ("age", 30) func P(fns ...func(*Builder)) *Predicate { return &Predicate{fns: fns} } // ExprP creates a new predicate from the given expression. // // ExprP("A = ? AND B > ?", args...) func ExprP(exr string, args ...any) *Predicate { return P(func(b *Builder) { b.Join(Expr(exr, args...)) }) } // Or combines all given predicates with OR between them. // // Or(EQ("name", "foo"), EQ("name", "bar")) func Or(preds ...*Predicate) *Predicate { p := P() return p.Append(func(b *Builder) { p.mayWrap(preds, b, "OR") }) } // False appends the FALSE keyword to the predicate.
// // Delete().From("users").Where(False()) func False() *Predicate { return P().False() } // False appends FALSE to the predicate. func (p *Predicate) False() *Predicate { return p.Append(func(b *Builder) { b.WriteString("FALSE") }) } // Not wraps the given predicate with the not predicate. // // Not(Or(EQ("name", "foo"), EQ("name", "bar"))) func Not(pred *Predicate) *Predicate { return P().Not().Append(func(b *Builder) { b.Nested(func(b *Builder) { b.Join(pred) }) }) } // Not appends NOT to the predicate. func (p *Predicate) Not() *Predicate { return p.Append(func(b *Builder) { b.WriteString("NOT ") }) } // ColumnsOp returns a new predicate between 2 columns. func ColumnsOp(col1, col2 string, op Op) *Predicate { return P().ColumnsOp(col1, col2, op) } // ColumnsOp appends the given predicate between 2 columns. func (p *Predicate) ColumnsOp(col1, col2 string, op Op) *Predicate { return p.Append(func(b *Builder) { b.Ident(col1) b.WriteOp(op) b.Ident(col2) }) } // And combines all given predicates with AND between them. func And(preds ...*Predicate) *Predicate { p := P() return p.Append(func(b *Builder) { p.mayWrap(preds, b, "AND") }) } // IsTrue appends a predicate that checks if the column value is truthy. func IsTrue(col string) *Predicate { return P().IsTrue(col) } // IsTrue appends a predicate that checks if the column value is truthy. func (p *Predicate) IsTrue(col string) *Predicate { return p.Append(func(b *Builder) { b.Ident(col) }) } // IsFalse appends a predicate that checks if the column value is falsey. func IsFalse(col string) *Predicate { return P().IsFalse(col) } // IsFalse appends a predicate that checks if the column value is falsey. func (p *Predicate) IsFalse(col string) *Predicate { return p.Append(func(b *Builder) { b.WriteString("NOT ").Ident(col) }) } // EQ returns a "=" predicate. func EQ(col string, value any) *Predicate { return P().EQ(col, value) } // EQ appends a "=" predicate. func (p *Predicate) EQ(col string, arg any) *Predicate { // A small optimization to avoid passing // arguments when it can be avoided. switch arg := arg.(type) { case bool: if arg { return IsTrue(col) } return IsFalse(col) default: return p.Append(func(b *Builder) { b.Ident(col) b.WriteOp(OpEQ) p.arg(b, arg) }) } } // ColumnsEQ appends a "=" predicate between 2 columns. func ColumnsEQ(col1, col2 string) *Predicate { return P().ColumnsEQ(col1, col2) } // ColumnsEQ appends a "=" predicate between 2 columns. func (p *Predicate) ColumnsEQ(col1, col2 string) *Predicate { return p.ColumnsOp(col1, col2, OpEQ) } // NEQ returns a "<>" predicate. func NEQ(col string, value any) *Predicate { return P().NEQ(col, value) } // NEQ appends a "<>" predicate. func (p *Predicate) NEQ(col string, arg any) *Predicate { // A small optimization to avoid passing // arguments when it can be avoided. switch arg := arg.(type) { case bool: if arg { return IsFalse(col) } return IsTrue(col) default: return p.Append(func(b *Builder) { b.Ident(col) b.WriteOp(OpNEQ) p.arg(b, arg) }) } } // ColumnsNEQ appends a "<>" predicate between 2 columns. func ColumnsNEQ(col1, col2 string) *Predicate { return P().ColumnsNEQ(col1, col2) } // ColumnsNEQ appends a "<>" predicate between 2 columns. func (p *Predicate) ColumnsNEQ(col1, col2 string) *Predicate { return p.ColumnsOp(col1, col2, OpNEQ) } // LT returns a "<" predicate. func LT(col string, value any) *Predicate { return P().LT(col, value) } // LT appends a "<" predicate. 
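// For example (editor's sketch), predicates compose with And/Or:
//
//	p := Or(And(EQ("name", "foo"), LT("age", 30)), NEQ("name", "bar"))
//	query, args := p.Query()
//	// query is roughly: (`name` = ? AND `age` < ?) OR `name` <> ?
//	// args:  ["foo", 30, "bar"]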
func (p *Predicate) LT(col string, arg any) *Predicate { return p.Append(func(b *Builder) { b.Ident(col) p.WriteOp(OpLT) p.arg(b, arg) }) } // ColumnsLT appends a "<" predicate between 2 columns. func ColumnsLT(col1, col2 string) *Predicate { return P().ColumnsLT(col1, col2) } // ColumnsLT appends a "<" predicate between 2 columns. func (p *Predicate) ColumnsLT(col1, col2 string) *Predicate { return p.ColumnsOp(col1, col2, OpLT) } // LTE returns a "<=" predicate. func LTE(col string, value any) *Predicate { return P().LTE(col, value) } // LTE appends a "<=" predicate. func (p *Predicate) LTE(col string, arg any) *Predicate { return p.Append(func(b *Builder) { b.Ident(col) p.WriteOp(OpLTE) p.arg(b, arg) }) } // ColumnsLTE appends a "<=" predicate between 2 columns. func ColumnsLTE(col1, col2 string) *Predicate { return P().ColumnsLTE(col1, col2) } // ColumnsLTE appends a "<=" predicate between 2 columns. func (p *Predicate) ColumnsLTE(col1, col2 string) *Predicate { return p.ColumnsOp(col1, col2, OpLTE) } // GT returns a ">" predicate. func GT(col string, value any) *Predicate { return P().GT(col, value) } // GT appends a ">" predicate. func (p *Predicate) GT(col string, arg any) *Predicate { return p.Append(func(b *Builder) { b.Ident(col) p.WriteOp(OpGT) p.arg(b, arg) }) } // ColumnsGT appends a ">" predicate between 2 columns. func ColumnsGT(col1, col2 string) *Predicate { return P().ColumnsGT(col1, col2) } // ColumnsGT appends a ">" predicate between 2 columns. func (p *Predicate) ColumnsGT(col1, col2 string) *Predicate { return p.ColumnsOp(col1, col2, OpGT) } // GTE returns a ">=" predicate. func GTE(col string, value any) *Predicate { return P().GTE(col, value) } // GTE appends a ">=" predicate. func (p *Predicate) GTE(col string, arg any) *Predicate { return p.Append(func(b *Builder) { b.Ident(col) p.WriteOp(OpGTE) p.arg(b, arg) }) } // ColumnsGTE appends a ">=" predicate between 2 columns. func ColumnsGTE(col1, col2 string) *Predicate { return P().ColumnsGTE(col1, col2) } // ColumnsGTE appends a ">=" predicate between 2 columns. func (p *Predicate) ColumnsGTE(col1, col2 string) *Predicate { return p.ColumnsOp(col1, col2, OpGTE) } // NotNull returns the `IS NOT NULL` predicate. func NotNull(col string) *Predicate { return P().NotNull(col) } // NotNull appends the `IS NOT NULL` predicate. func (p *Predicate) NotNull(col string) *Predicate { return p.Append(func(b *Builder) { b.Ident(col).WriteString(" IS NOT NULL") }) } // IsNull returns the `IS NULL` predicate. func IsNull(col string) *Predicate { return P().IsNull(col) } // IsNull appends the `IS NULL` predicate. func (p *Predicate) IsNull(col string) *Predicate { return p.Append(func(b *Builder) { b.Ident(col).WriteString(" IS NULL") }) } // In returns the `IN` predicate. func In(col string, args ...any) *Predicate { return P().In(col, args...) } // In appends the `IN` predicate. func (p *Predicate) In(col string, args ...any) *Predicate { // If no arguments were provided, append the FALSE constant, since // we cannot apply "IN ()". This will make this predicate falsy. if len(args) == 0 { return p.False() } return p.Append(func(b *Builder) { b.Ident(col).WriteOp(OpIn) b.Nested(func(b *Builder) { if s, ok := args[0].(*Selector); ok { b.Join(s) } else { b.Args(args...) } }) }) } // InInts returns the `IN` predicate for ints. func InInts(col string, args ...int) *Predicate { return P().InInts(col, args...) } // InValues adds the `IN` predicate for slice of driver.Value. 
func InValues(col string, args ...driver.Value) *Predicate { return P().InValues(col, args...) } // InInts adds the `IN` predicate for ints. func (p *Predicate) InInts(col string, args ...int) *Predicate { iface := make([]any, len(args)) for i := range args { iface[i] = args[i] } return p.In(col, iface...) } // InValues adds the `IN` predicate for slice of driver.Value. func (p *Predicate) InValues(col string, args ...driver.Value) *Predicate { iface := make([]any, len(args)) for i := range args { iface[i] = args[i] } return p.In(col, iface...) } // NotIn returns the `Not IN` predicate. func NotIn(col string, args ...any) *Predicate { return P().NotIn(col, args...) } // NotIn appends the `Not IN` predicate. func (p *Predicate) NotIn(col string, args ...any) *Predicate { // If no arguments were provided, append the NOT FALSE constant, since // we cannot apply "NOT IN ()". This will make this predicate truthy. if len(args) == 0 { return Not(p.False()) } return p.Append(func(b *Builder) { b.Ident(col).WriteOp(OpNotIn) b.Nested(func(b *Builder) { if s, ok := args[0].(*Selector); ok { b.Join(s) } else { b.Args(args...) } }) }) } // Exists returns the `Exists` predicate. func Exists(query Querier) *Predicate { return P().Exists(query) } // Exists appends the `EXISTS` predicate with the given query. func (p *Predicate) Exists(query Querier) *Predicate { return p.Append(func(b *Builder) { b.WriteString("EXISTS ") b.Nested(func(b *Builder) { b.Join(query) }) }) } // NotExists returns the `NotExists` predicate. func NotExists(query Querier) *Predicate { return P().NotExists(query) } // NotExists appends the `NOT EXISTS` predicate with the given query. func (p *Predicate) NotExists(query Querier) *Predicate { return p.Append(func(b *Builder) { b.WriteString("NOT EXISTS ") b.Nested(func(b *Builder) { b.Join(query) }) }) } // Like returns the `LIKE` predicate. func Like(col, pattern string) *Predicate { return P().Like(col, pattern) } // Like appends the `LIKE` predicate. func (p *Predicate) Like(col, pattern string) *Predicate { return p.Append(func(b *Builder) { b.Ident(col).WriteOp(OpLike) b.Arg(pattern) }) } // escape escapes w with the default escape character ('/'), // to be used by the pattern matching functions below. // The second return value indicates if w was escaped or not. func escape(w string) (string, bool) { var n int for i := range w { if c := w[i]; c == '%' || c == '_' || c == '\\' { n++ } } // No characters to escape. if n == 0 { return w, false } var b strings.Builder b.Grow(len(w) + n) for i := range w { if c := w[i]; c == '%' || c == '_' || c == '\\' { b.WriteByte('\\') } b.WriteByte(w[i]) } return b.String(), true } func (p *Predicate) escapedLike(col, left, right, word string) *Predicate { return p.Append(func(b *Builder) { w, escaped := escape(word) b.Ident(col).WriteOp(OpLike) b.Arg(left + w + right) if p.dialect == dialect.SQLite && escaped { p.WriteString(" ESCAPE ").Arg("\\") } }) } // HasPrefix is a helper predicate that checks prefix using the LIKE predicate. func HasPrefix(col, prefix string) *Predicate { return P().HasPrefix(col, prefix) } // HasPrefix is a helper predicate that checks prefix using the LIKE predicate. func (p *Predicate) HasPrefix(col, prefix string) *Predicate { return p.escapedLike(col, "", "%", prefix) } // HasSuffix is a helper predicate that checks suffix using the LIKE predicate. func HasSuffix(col, suffix string) *Predicate { return P().HasSuffix(col, suffix) } // HasSuffix is a helper predicate that checks suffix using the LIKE predicate. 
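// For example (editor's sketch), LIKE wildcards in the input are escaped
// before being wrapped with the matching pattern:
//
//	p := HasSuffix("name", "50%")
//	// the bound argument becomes "%50\%"; on SQLite an ESCAPE clause is appended.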
func (p *Predicate) HasSuffix(col, suffix string) *Predicate { return p.escapedLike(col, "%", "", suffix) } // EqualFold is a helper predicate that applies the "=" predicate with case-folding. func EqualFold(col, sub string) *Predicate { return P().EqualFold(col, sub) } // EqualFold is a helper predicate that applies the "=" predicate with case-folding. func (p *Predicate) EqualFold(col, sub string) *Predicate { return p.Append(func(b *Builder) { f := &Func{} f.SetDialect(b.dialect) switch b.dialect { case dialect.MySQL: // We assume the CHARACTER SET is configured to utf8mb4, // because this is how it is defined in dialect/sql/schema. b.Ident(col).WriteString(" COLLATE utf8mb4_general_ci = ") b.Arg(strings.ToLower(sub)) case dialect.Postgres: b.Ident(col).WriteString(" ILIKE ") w, _ := escape(sub) b.Arg(strings.ToLower(w)) default: // SQLite. f.Lower(col) b.WriteString(f.String()) b.WriteOp(OpEQ) b.Arg(strings.ToLower(sub)) } }) } // Contains is a helper predicate that checks substring using the LIKE predicate. func Contains(col, sub string) *Predicate { return P().Contains(col, sub) } // Contains is a helper predicate that checks substring using the LIKE predicate. func (p *Predicate) Contains(col, substr string) *Predicate { return p.escapedLike(col, "%", "%", substr) } // ContainsFold is a helper predicate that checks substring using the LIKE predicate. func ContainsFold(col, sub string) *Predicate { return P().ContainsFold(col, sub) } // ContainsFold is a helper predicate that applies the LIKE predicate with case-folding. func (p *Predicate) ContainsFold(col, substr string) *Predicate { return p.Append(func(b *Builder) { w, escaped := escape(substr) switch b.dialect { case dialect.MySQL: // We assume the CHARACTER SET is configured to utf8mb4, // because this is how it is defined in dialect/sql/schema. b.Ident(col).WriteString(" COLLATE utf8mb4_general_ci LIKE ") b.Arg("%" + strings.ToLower(w) + "%") case dialect.Postgres: b.Ident(col).WriteString(" ILIKE ") b.Arg("%" + strings.ToLower(w) + "%") default: // SQLite. var f Func f.SetDialect(b.dialect) f.Lower(col) b.WriteString(f.String()).WriteString(" LIKE ") b.Arg("%" + strings.ToLower(w) + "%") if escaped { p.WriteString(" ESCAPE ").Arg("\\") } } }) } // CompositeGT returns a composite ">" predicate. func CompositeGT(columns []string, args ...any) *Predicate { return P().CompositeGT(columns, args...) } // CompositeLT returns a composite "<" predicate. func CompositeLT(columns []string, args ...any) *Predicate { return P().CompositeLT(columns, args...) } func (p *Predicate) compositeP(operator string, columns []string, args ...any) *Predicate { return p.Append(func(b *Builder) { b.Nested(func(nb *Builder) { nb.IdentComma(columns...) }) b.WriteString(operator) b.WriteString("(") b.Args(args...) b.WriteString(")") }) } // CompositeGT returns a composite ">" predicate. func (p *Predicate) CompositeGT(columns []string, args ...any) *Predicate { const operator = " > " return p.compositeP(operator, columns, args...) } // CompositeLT appends a composite "<" predicate. func (p *Predicate) CompositeLT(columns []string, args ...any) *Predicate { const operator = " < " return p.compositeP(operator, columns, args...) } // Append appends a new function to the predicate callbacks. // The callback list is executed on call to Query. func (p *Predicate) Append(f func(*Builder)) *Predicate { p.fns = append(p.fns, f) return p } // Query returns query representation of a predicate.
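// For example (editor's sketch), the case-folding helpers above render
// differently per dialect. ContainsFold("name", "Ariel") becomes roughly:
//
//	MySQL:    `name` COLLATE utf8mb4_general_ci LIKE ? ("%ariel%")
//	Postgres: "name" ILIKE ? ("%ariel%")
//	SQLite:   LOWER(`name`) LIKE ? ("%ariel%")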
func (p *Predicate) Query() (string, []any) { if p.Len() > 0 || len(p.args) > 0 { p.Reset() p.args = nil } for _, f := range p.fns { f(&p.Builder) } return p.String(), p.args } // arg calls Builder.Arg, but wraps `a` with parens in case of a Selector. func (*Predicate) arg(b *Builder, a any) { switch a.(type) { case *Selector: b.Nested(func(b *Builder) { b.Arg(a) }) default: b.Arg(a) } } // clone returns a shallow clone of p. func (p *Predicate) clone() *Predicate { if p == nil { return p } return &Predicate{fns: append([]func(*Builder){}, p.fns...)} } func (p *Predicate) mayWrap(preds []*Predicate, b *Builder, op string) { switch n := len(preds); { case n == 1: b.Join(preds[0]) return case n > 1 && p.depth != 0: b.WriteByte('(') defer b.WriteByte(')') } for i := range preds { preds[i].depth = p.depth + 1 if i > 0 { b.WriteByte(' ') b.WriteString(op) b.WriteByte(' ') } if len(preds[i].fns) > 1 { b.Nested(func(b *Builder) { b.Join(preds[i]) }) } else { b.Join(preds[i]) } } } // Func represents an SQL function. type Func struct { Builder fns []func(*Builder) } // Lower wraps the given column with the LOWER function. // // P().EQ(sql.Lower("name"), "a8m") func Lower(ident string) string { f := &Func{} f.Lower(ident) return f.String() } // Lower wraps the given ident with the LOWER function. func (f *Func) Lower(ident string) { f.byName("LOWER", ident) } // Count wraps the ident with the COUNT aggregation function. func Count(ident string) string { f := &Func{} f.Count(ident) return f.String() } // Count wraps the ident with the COUNT aggregation function. func (f *Func) Count(ident string) { f.byName("COUNT", ident) } // Max wraps the ident with the MAX aggregation function. func Max(ident string) string { f := &Func{} f.Max(ident) return f.String() } // Max wraps the ident with the MAX aggregation function. func (f *Func) Max(ident string) { f.byName("MAX", ident) } // Min wraps the ident with the MIN aggregation function. func Min(ident string) string { f := &Func{} f.Min(ident) return f.String() } // Min wraps the ident with the MIN aggregation function. func (f *Func) Min(ident string) { f.byName("MIN", ident) } // Sum wraps the ident with the SUM aggregation function. func Sum(ident string) string { f := &Func{} f.Sum(ident) return f.String() } // Sum wraps the ident with the SUM aggregation function. func (f *Func) Sum(ident string) { f.byName("SUM", ident) } // Avg wraps the ident with the AVG aggregation function. func Avg(ident string) string { f := &Func{} f.Avg(ident) return f.String() } // Avg wraps the ident with the AVG aggregation function. func (f *Func) Avg(ident string) { f.byName("AVG", ident) } // byName wraps an identifier with a function name. func (f *Func) byName(fn, ident string) { f.Append(func(b *Builder) { f.WriteString(fn) f.Nested(func(b *Builder) { b.Ident(ident) }) }) } // Append appends a new function to the function callbacks. // The callback list is executed on call to String. func (f *Func) Append(fn func(*Builder)) *Func { f.fns = append(f.fns, fn) return f } // String implements the fmt.Stringer. func (f *Func) String() string { for _, fn := range f.fns { fn(&f.Builder) } return f.Builder.String() } // As suffixes the given column with an alias (`a` AS `b`). func As(ident string, as string) string { b := &Builder{} b.fromIdent(ident) b.Ident(ident).Pad().WriteString("AS") b.Pad().Ident(as) return b.String() } // Distinct prefixes the given columns with the `DISTINCT` keyword (DISTINCT `id`).
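// // For example, a sketch: // // Select(Distinct("id")).From(Table("users")) // SELECT DISTINCT `id` FROM `users`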
func Distinct(idents ...string) string { b := &Builder{} if len(idents) > 0 { b.fromIdent(idents[0]) } b.WriteString("DISTINCT") b.Pad().IdentComma(idents...) return b.String() } // TableView is a view that can be used as a table source in a statement. Can be a Table, Selector or a View (WITH statement). type TableView interface { view() } // queryView allows using Querier (expressions) in the FROM clause. type queryView struct{ Querier } func (*queryView) view() {} // SelectTable is a table selector. type SelectTable struct { Builder as string name string schema string quote bool } // Table returns a new table selector. // // t1 := Table("users").As("u") // return Select(t1.C("name")) func Table(name string) *SelectTable { return &SelectTable{quote: true, name: name} } // Schema sets the schema name of the table. func (s *SelectTable) Schema(name string) *SelectTable { s.schema = name return s } // As adds the AS clause to the table selector. func (s *SelectTable) As(alias string) *SelectTable { s.as = alias return s } // C returns a formatted string for the table column. func (s *SelectTable) C(column string) string { name := s.name if s.as != "" { name = s.as } b := &Builder{dialect: s.dialect} if s.as == "" { b.writeSchema(s.schema) } b.Ident(name).WriteByte('.').Ident(column) return b.String() } // Columns returns a list of formatted strings for the table columns. func (s *SelectTable) Columns(columns ...string) []string { names := make([]string, 0, len(columns)) for _, c := range columns { names = append(names, s.C(c)) } return names } // Unquote makes the table name be formatted as a raw string (unquoted). // It is useful when you don't want to query tables under the current database. // For example: "INFORMATION_SCHEMA.TABLE_CONSTRAINTS" in MySQL. func (s *SelectTable) Unquote() *SelectTable { s.quote = false return s } // ref returns the table reference. func (s *SelectTable) ref() string { if !s.quote { return s.name } b := &Builder{dialect: s.dialect} b.writeSchema(s.schema) b.Ident(s.name) if s.as != "" { b.WriteString(" AS ") b.Ident(s.as) } return b.String() } // implement the table view. func (*SelectTable) view() {} // join table option. type join struct { on *Predicate kind string table TableView } // clone a joiner. func (j join) clone() join { if sel, ok := j.table.(*Selector); ok { j.table = sel.Clone() } j.on = j.on.clone() return j } // Selector is a builder for the `SELECT` statement. type Selector struct { Builder // ctx stores contextual data typically from // generated code such as alternate table schemas. ctx context.Context as string selection []any from []TableView joins []join where *Predicate or bool not bool order []any group []string having *Predicate limit *int offset *int distinct bool union []union prefix Queries lock *LockOptions } // WithContext sets the context into the *Selector. func (s *Selector) WithContext(ctx context.Context) *Selector { if ctx == nil { panic("nil context") } s.ctx = ctx return s } // Context returns the Selector context or Background // if nil. func (s *Selector) Context() context.Context { if s.ctx != nil { return s.ctx } return context.Background() } // Select returns a new selector for the `SELECT` statement. // // t1 := Table("users").As("u") // t2 := Select().From(Table("groups")).Where(EQ("user_id", 10)).As("g") // return Select(t1.C("id"), t2.C("name")). // From(t1). // Join(t2). // On(t1.C("id"), t2.C("user_id")) func Select(columns ...string) *Selector { return (&Selector{}).Select(columns...)
} // SelectExpr is like Select, but supports passing arbitrary // expressions for the SELECT clause. func SelectExpr(exprs ...Querier) *Selector { return (&Selector{}).SelectExpr(exprs...) } // Select changes the columns selection of the SELECT statement. // Empty selection means all columns *. func (s *Selector) Select(columns ...string) *Selector { s.selection = make([]any, len(columns)) for i := range columns { s.selection[i] = columns[i] } return s } // AppendSelect appends additional columns to the SELECT statement. func (s *Selector) AppendSelect(columns ...string) *Selector { for i := range columns { s.selection = append(s.selection, columns[i]) } return s } // SelectExpr changes the columns selection of the SELECT statement // with custom list of expressions. func (s *Selector) SelectExpr(exprs ...Querier) *Selector { s.selection = make([]any, len(exprs)) for i := range exprs { s.selection[i] = exprs[i] } return s } // AppendSelectExpr appends additional expressions to the SELECT statement. func (s *Selector) AppendSelectExpr(exprs ...Querier) *Selector { for i := range exprs { s.selection = append(s.selection, exprs[i]) } return s } // AppendSelectExprAs appends an additional expression to the SELECT statement with the given name. func (s *Selector) AppendSelectExprAs(expr Querier, as string) *Selector { s.selection = append(s.selection, ExprFunc(func(b *Builder) { b.WriteByte('(') b.Join(expr) b.WriteString(") AS ") b.Ident(as) })) return s } // SelectedColumns returns the selected columns in the Selector. func (s *Selector) SelectedColumns() []string { columns := make([]string, 0, len(s.selection)) for i := range s.selection { if c, ok := s.selection[i].(string); ok { columns = append(columns, c) } } return columns } // UnqualifiedColumns returns an unqualified version of the // selected columns in the Selector. e.g. "t1"."c" => "c". func (s *Selector) UnqualifiedColumns() []string { columns := make([]string, 0, len(s.selection)) for i := range s.selection { c, ok := s.selection[i].(string) if !ok { continue } if s.isIdent(c) { parts := strings.FieldsFunc(c, func(r rune) bool { return r == '`' || r == '"' }) if n := len(parts); n > 0 && parts[n-1] != "" { c = parts[n-1] } } columns = append(columns, c) } return columns } // From sets the source of the `FROM` clause. func (s *Selector) From(t TableView) *Selector { s.from = nil return s.AppendFrom(t) } // AppendFrom appends a new TableView to the `FROM` clause. func (s *Selector) AppendFrom(t TableView) *Selector { s.from = append(s.from, t) if st, ok := t.(state); ok { st.SetDialect(s.dialect) } return s } // FromExpr sets the expression of the `FROM` clause. func (s *Selector) FromExpr(x Querier) *Selector { s.from = nil return s.AppendFromExpr(x) } // AppendFromExpr appends an expression (Querier) to the `FROM` clause. func (s *Selector) AppendFromExpr(x Querier) *Selector { s.from = append(s.from, &queryView{Querier: x}) if st, ok := x.(state); ok { st.SetDialect(s.dialect) } return s } // Distinct adds the DISTINCT keyword to the `SELECT` statement. func (s *Selector) Distinct() *Selector { s.distinct = true return s } // SetDistinct sets explicitly if the returned rows are distinct or not. func (s *Selector) SetDistinct(v bool) *Selector { s.distinct = v return s } // Limit adds the `LIMIT` clause to the `SELECT` statement. func (s *Selector) Limit(limit int) *Selector { s.limit = &limit return s } // Offset adds the `OFFSET` clause to the `SELECT` statement.
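// // Combined with Limit, a common pagination sketch: // // Select().From(Table("users")).Limit(10).Offset(20) // ... LIMIT 10 OFFSET 20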
func (s *Selector) Offset(offset int) *Selector { s.offset = &offset return s } // Where sets or appends the given predicate to the statement. func (s *Selector) Where(p *Predicate) *Selector { if s.not { p = Not(p) s.not = false } switch { case s.where == nil: s.where = p case s.where != nil && s.or: s.where = Or(s.where, p) s.or = false default: s.where = And(s.where, p) } return s } // P returns the predicate of a selector. func (s *Selector) P() *Predicate { return s.where } // SetP sets explicitly the predicate function for the selector and clears its previous state. func (s *Selector) SetP(p *Predicate) *Selector { s.where = p s.or = false s.not = false return s } // FromSelect copies the predicate from a selector. func (s *Selector) FromSelect(s2 *Selector) *Selector { s.where = s2.where return s } // Not marks the next coming predicate as negated (wrapped with NOT). func (s *Selector) Not() *Selector { s.not = true return s } // Or sets the next coming predicate with the OR operator (disjunction). func (s *Selector) Or() *Selector { s.or = true return s } // Table returns the selected table. func (s *Selector) Table() *SelectTable { if len(s.from) == 0 { return nil } return s.from[0].(*SelectTable) } // TableName returns the name of the selected table or the alias of the selector. func (s *Selector) TableName() string { switch view := s.from[0].(type) { case *SelectTable: return view.name case *Selector: return view.as default: panic(fmt.Sprintf("unhandled TableView type %T", view)) } } // Join appends a `JOIN` clause to the statement. func (s *Selector) Join(t TableView) *Selector { return s.join("JOIN", t) } // LeftJoin appends a `LEFT JOIN` clause to the statement. func (s *Selector) LeftJoin(t TableView) *Selector { return s.join("LEFT JOIN", t) } // RightJoin appends a `RIGHT JOIN` clause to the statement. func (s *Selector) RightJoin(t TableView) *Selector { return s.join("RIGHT JOIN", t) } // FullJoin appends a `FULL JOIN` clause to the statement. func (s *Selector) FullJoin(t TableView) *Selector { return s.join("FULL JOIN", t) } // join adds a join table to the selector with the given kind. func (s *Selector) join(kind string, t TableView) *Selector { s.joins = append(s.joins, join{ kind: kind, table: t, }) switch view := t.(type) { case *SelectTable: if view.as == "" { view.as = "t" + strconv.Itoa(len(s.joins)) } case *Selector: if view.as == "" { view.as = "t" + strconv.Itoa(len(s.joins)) } } if st, ok := t.(state); ok { st.SetDialect(s.dialect) } return s } // unionType describes a UNION type. type unionType string const ( unionAll unionType = "ALL" unionDistinct unionType = "DISTINCT" ) // union query option. type union struct { unionType TableView } // Union appends the UNION clause to the query. func (s *Selector) Union(t TableView) *Selector { s.union = append(s.union, union{ TableView: t, }) return s } // UnionAll appends the UNION ALL clause to the query. func (s *Selector) UnionAll(t TableView) *Selector { s.union = append(s.union, union{ unionType: unionAll, TableView: t, }) return s } // UnionDistinct appends the UNION DISTINCT clause to the query. func (s *Selector) UnionDistinct(t TableView) *Selector { s.union = append(s.union, union{ unionType: unionDistinct, TableView: t, }) return s } // Prefix prefixes the query with a list of queries. func (s *Selector) Prefix(queries ...Querier) *Selector { s.prefix = append(s.prefix, queries...) return s } // C returns a formatted string for a selected column from this statement.
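// // For example, a sketch: // // s := Select().From(Table("users")).As("u") // s.C("id") // returns "`u`.`id`"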
func (s *Selector) C(column string) string { if s.as != "" { b := &Builder{dialect: s.dialect} b.Ident(s.as) b.WriteByte('.') b.Ident(column) return b.String() } return s.Table().C(column) } // Columns returns a list of formatted strings for the selected columns from this statement. func (s *Selector) Columns(columns ...string) []string { names := make([]string, 0, len(columns)) for _, c := range columns { names = append(names, s.C(c)) } return names } // OnP sets or appends the given predicate for the `ON` clause of the statement. func (s *Selector) OnP(p *Predicate) *Selector { if len(s.joins) > 0 { join := &s.joins[len(s.joins)-1] switch { case join.on == nil: join.on = p default: join.on = And(join.on, p) } } return s } // On sets the `ON` clause for the `JOIN` operation. func (s *Selector) On(c1, c2 string) *Selector { s.OnP(P(func(builder *Builder) { builder.Ident(c1).WriteOp(OpEQ).Ident(c2) })) return s } // As gives this selection an alias. func (s *Selector) As(alias string) *Selector { s.as = alias return s } // Count sets the Select statement to be a `SELECT COUNT(*)`. func (s *Selector) Count(columns ...string) *Selector { column := "*" if len(columns) > 0 { b := &Builder{} b.IdentComma(columns...) column = b.String() } s.Select(Count(column)) return s } // LockAction tells the transaction what to do in case of // requesting a row that is locked by another transaction. type LockAction string const ( // NoWait means never wait and returns an error. NoWait LockAction = "NOWAIT" // SkipLocked means never wait and skip. SkipLocked LockAction = "SKIP LOCKED" ) // LockStrength defines the strength of the lock (see the list below). type LockStrength string // A list of all locking clauses. const ( LockShare LockStrength = "SHARE" LockUpdate LockStrength = "UPDATE" LockNoKeyUpdate LockStrength = "NO KEY UPDATE" LockKeyShare LockStrength = "KEY SHARE" ) type ( // LockOptions defines a SELECT statement // lock for protecting concurrent updates. LockOptions struct { // Strength of the lock. Strength LockStrength // Action of the lock. Action LockAction // Tables are optional tables to lock. Tables []string // custom clause for locking. clause string } // LockOption allows configuring the LockOptions using functional options. LockOption func(*LockOptions) ) // WithLockAction sets the Action of the lock. func WithLockAction(action LockAction) LockOption { return func(c *LockOptions) { c.Action = action } } // WithLockTables sets the Tables of the lock. func WithLockTables(tables ...string) LockOption { return func(c *LockOptions) { c.Tables = tables } } // WithLockClause allows providing a custom clause for // locking the statement. For example, in MySQL <= 8.22: // // Select(). // From(Table("users")). // ForShare( // WithLockClause("LOCK IN SHARE MODE"), // ) func WithLockClause(clause string) LockOption { return func(c *LockOptions) { c.clause = clause } } // For sets the lock configuration for suffixing the `SELECT` // statement with the `FOR [SHARE | UPDATE] ...` clause. func (s *Selector) For(l LockStrength, opts ...LockOption) *Selector { if s.Dialect() == dialect.SQLite { s.AddError(errors.New("sql: SELECT .. FOR UPDATE/SHARE not supported in SQLite")) } s.lock = &LockOptions{Strength: l} for _, opt := range opts { opt(s.lock) } return s } // ForShare sets the lock configuration for suffixing the // `SELECT` statement with the `FOR SHARE` clause. func (s *Selector) ForShare(opts ...LockOption) *Selector { return s.For(LockShare, opts...)
} // ForUpdate sets the lock configuration for suffixing the // `SELECT` statement with the `FOR UPDATE` clause. func (s *Selector) ForUpdate(opts ...LockOption) *Selector { return s.For(LockUpdate, opts...) } // Clone returns a duplicate of the selector, including all associated steps. It can be // used to prepare common SELECT statements and use them differently after the clone is made. func (s *Selector) Clone() *Selector { if s == nil { return nil } joins := make([]join, len(s.joins)) for i := range s.joins { joins[i] = s.joins[i].clone() } return &Selector{ Builder: s.Builder.clone(), ctx: s.ctx, as: s.as, or: s.or, not: s.not, from: s.from, limit: s.limit, offset: s.offset, distinct: s.distinct, where: s.where.clone(), having: s.having.clone(), joins: append([]join{}, joins...), group: append([]string{}, s.group...), order: append([]any{}, s.order...), selection: append([]any{}, s.selection...), } } // Asc adds the ASC suffix for the given column. func Asc(column string) string { b := &Builder{} b.Ident(column).WriteString(" ASC") return b.String() } // Desc adds the DESC suffix for the given column. func Desc(column string) string { b := &Builder{} b.Ident(column).WriteString(" DESC") return b.String() } // OrderBy appends the `ORDER BY` clause to the `SELECT` statement. func (s *Selector) OrderBy(columns ...string) *Selector { for i := range columns { s.order = append(s.order, columns[i]) } return s } // OrderColumns returns the ordered columns in the Selector. // Note that this function skips columns ordered with expressions. func (s *Selector) OrderColumns() []string { columns := make([]string, 0, len(s.order)) for i := range s.order { if c, ok := s.order[i].(string); ok { columns = append(columns, c) } } return columns } // OrderExpr appends the `ORDER BY` clause to the `SELECT` // statement with a custom list of expressions. func (s *Selector) OrderExpr(exprs ...Querier) *Selector { for i := range exprs { s.order = append(s.order, exprs[i]) } return s } // GroupBy appends the `GROUP BY` clause to the `SELECT` statement. func (s *Selector) GroupBy(columns ...string) *Selector { s.group = append(s.group, columns...) return s } // Having appends a predicate for the `HAVING` clause. func (s *Selector) Having(p *Predicate) *Selector { s.having = p return s } // Query returns query representation of a `SELECT` statement.
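// // For example, a sketch (default `?` placeholder and backtick quoting shown): // // query, args := Select("id", "name").From(Table("users")).Where(GT("age", 18)).Query() // // query == "SELECT `id`, `name` FROM `users` WHERE `age` > ?", args == []any{18}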
func (s *Selector) Query() (string, []any) { b := s.Builder.clone() s.joinPrefix(&b) b.WriteString("SELECT ") if s.distinct { b.WriteString("DISTINCT ") } if len(s.selection) > 0 { s.joinSelect(&b) } else { b.WriteString("*") } if len(s.from) > 0 { b.WriteString(" FROM ") } for i, from := range s.from { if i > 0 { b.Comma() } switch t := from.(type) { case *SelectTable: t.SetDialect(s.dialect) b.WriteString(t.ref()) case *Selector: t.SetDialect(s.dialect) b.Nested(func(b *Builder) { b.Join(t) }) b.WriteString(" AS ") b.Ident(t.as) case *WithBuilder: t.SetDialect(s.dialect) b.Ident(t.Name()) case *queryView: b.Join(t.Querier) } } for _, join := range s.joins { b.WriteString(" " + join.kind + " ") switch view := join.table.(type) { case *SelectTable: view.SetDialect(s.dialect) b.WriteString(view.ref()) case *Selector: view.SetDialect(s.dialect) b.Nested(func(b *Builder) { b.Join(view) }) b.WriteString(" AS ") b.Ident(view.as) case *WithBuilder: view.SetDialect(s.dialect) b.Ident(view.Name()) } if join.on != nil { b.WriteString(" ON ") b.Join(join.on) } } if s.where != nil { b.WriteString(" WHERE ") b.Join(s.where) } if len(s.group) > 0 { b.WriteString(" GROUP BY ") b.IdentComma(s.group...) } if s.having != nil { b.WriteString(" HAVING ") b.Join(s.having) } if len(s.union) > 0 { s.joinUnion(&b) } joinOrder(s.order, &b) if s.limit != nil { b.WriteString(" LIMIT ") b.WriteString(strconv.Itoa(*s.limit)) } if s.offset != nil { b.WriteString(" OFFSET ") b.WriteString(strconv.Itoa(*s.offset)) } s.joinLock(&b) s.total = b.total s.AddError(b.Err()) return b.String(), b.args } func (s *Selector) joinPrefix(b *Builder) { if len(s.prefix) > 0 { b.join(s.prefix, " ") b.Pad() } } func (s *Selector) joinLock(b *Builder) { if s.lock == nil { return } b.Pad() if s.lock.clause != "" { b.WriteString(s.lock.clause) return } b.WriteString("FOR ").WriteString(string(s.lock.Strength)) if len(s.lock.Tables) > 0 { b.WriteString(" OF ").IdentComma(s.lock.Tables...) } if s.lock.Action != "" { b.Pad().WriteString(string(s.lock.Action)) } } func (s *Selector) joinUnion(b *Builder) { for _, union := range s.union { b.WriteString(" UNION ") if union.unionType != "" { b.WriteString(string(union.unionType) + " ") } switch view := union.TableView.(type) { case *SelectTable: view.SetDialect(s.dialect) b.WriteString(view.ref()) case *Selector: view.SetDialect(s.dialect) b.Join(view) if view.as != "" { b.WriteString(" AS ") b.Ident(view.as) } } } } func joinOrder(order []any, b *Builder) { if len(order) == 0 { return } b.WriteString(" ORDER BY ") for i := range order { if i > 0 { b.Comma() } switch r := order[i].(type) { case string: b.Ident(r) case Querier: b.Join(r) } } } func (s *Selector) joinSelect(b *Builder) { for i := range s.selection { if i > 0 { b.Comma() } switch s := s.selection[i].(type) { case string: b.Ident(s) case Querier: b.Join(s) } } } // implement the table view interface. func (*Selector) view() {} // WithBuilder is the builder for the `WITH` statement. type WithBuilder struct { Builder recursive bool ctes []struct { name string columns []string s *Selector } } // With returns a new builder for the `WITH` statement. // // n := Queries{ // With("users_view").As(Select().From(Table("users"))), // Select().From(Table("users_view")), // } // return n.Query() func With(name string, columns ...string) *WithBuilder { return &WithBuilder{ ctes: []struct { name string columns []string s *Selector }{ {name: name, columns: columns}, }, } } // WithRecursive returns a new builder for the `WITH RECURSIVE` statement. 
// // n := Queries{ // WithRecursive("users_view").As(Select().From(Table("users"))), // Select().From(Table("users_view")), // } // return n.Query() func WithRecursive(name string, columns ...string) *WithBuilder { w := With(name, columns...) w.recursive = true return w } // Name returns the name of the view. func (w *WithBuilder) Name() string { return w.ctes[0].name } // As sets the view sub query. func (w *WithBuilder) As(s *Selector) *WithBuilder { w.ctes[len(w.ctes)-1].s = s return w } // With appends another named CTE to the statement. func (w *WithBuilder) With(name string, columns ...string) *WithBuilder { w.ctes = append(w.ctes, With(name, columns...).ctes...) return w } // C returns a formatted string for the WITH column. func (w *WithBuilder) C(column string) string { b := &Builder{dialect: w.dialect} b.Ident(w.Name()).WriteByte('.').Ident(column) return b.String() } // Query returns query representation of a `WITH` clause. func (w *WithBuilder) Query() (string, []any) { w.WriteString("WITH ") if w.recursive { w.WriteString("RECURSIVE ") } for i, cte := range w.ctes { if i > 0 { w.Comma() } w.Ident(cte.name) if len(cte.columns) > 0 { w.WriteByte('(') w.IdentComma(cte.columns...) w.WriteByte(')') } w.WriteString(" AS ") w.Nested(func(b *Builder) { b.Join(cte.s) }) } return w.String(), w.args } // implement the table view interface. func (*WithBuilder) view() {} // WindowBuilder represents a builder for a window clause. // Note that support for window functions is limited and is used // only to query rows-limited edges in pagination. type WindowBuilder struct { Builder fn string // e.g. ROW_NUMBER(), RANK(). partition func(*Builder) order []any } // RowNumber returns a new window clause with ROW_NUMBER() as its function. // Using this function will assign each row a number, from 1 to N, in the // order defined by the ORDER BY clause in the window spec. func RowNumber() *WindowBuilder { return &WindowBuilder{fn: "ROW_NUMBER"} } // PartitionBy indicates to divide the query rows into groups by the given columns. // Note that the standard SQL spec allows partitioning only by columns; in order to // use the "expression" version, use PartitionExpr. func (w *WindowBuilder) PartitionBy(columns ...string) *WindowBuilder { w.partition = func(b *Builder) { b.IdentComma(columns...) } return w } // PartitionExpr indicates to divide the query rows into groups by the given expression. func (w *WindowBuilder) PartitionExpr(x Querier) *WindowBuilder { w.partition = func(b *Builder) { b.Join(x) } return w } // OrderBy indicates how to sort rows in each partition. func (w *WindowBuilder) OrderBy(columns ...string) *WindowBuilder { for i := range columns { w.order = append(w.order, columns[i]) } return w } // OrderExpr appends the `ORDER BY` clause to the window // partition with a custom list of expressions. func (w *WindowBuilder) OrderExpr(exprs ...Querier) *WindowBuilder { for i := range exprs { w.order = append(w.order, exprs[i]) } return w } // Query returns query representation of the window function. func (w *WindowBuilder) Query() (string, []any) { w.WriteString(w.fn) w.WriteString("() OVER ") w.Nested(func(b *Builder) { if w.partition != nil { b.WriteString("PARTITION BY ") w.partition(b) } joinOrder(w.order, b) }) return w.Builder.String(), w.args } // Wrapper wraps a given Querier with a different format. // Used to prefix/suffix other queries. type Wrapper struct { format string wrapped Querier } // Query returns query representation of a wrapped Querier.
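// // For example, a minimal in-package sketch (the fields are unexported): // // w := &Wrapper{format: "NOT (%s)", wrapped: EQ("active", true)} // w.Query() // returns "NOT (`active` = ?)", []any{true}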
func (w *Wrapper) Query() (string, []any) { query, args := w.wrapped.Query() return fmt.Sprintf(w.format, query), args } // SetDialect calls SetDialect on the wrapped query. func (w *Wrapper) SetDialect(name string) { if s, ok := w.wrapped.(state); ok { s.SetDialect(name) } } // Dialect calls Dialect on the wrapped query. func (w *Wrapper) Dialect() string { if s, ok := w.wrapped.(state); ok { return s.Dialect() } return "" } // Total returns the total number of arguments so far. func (w *Wrapper) Total() int { if s, ok := w.wrapped.(state); ok { return s.Total() } return 0 } // SetTotal sets the value of the total arguments. // Used to pass this information between sub queries/expressions. func (w *Wrapper) SetTotal(total int) { if s, ok := w.wrapped.(state); ok { s.SetTotal(total) } } // Raw returns a raw SQL query that is placed as-is in the query. func Raw(s string) Querier { return &raw{s} } type raw struct{ s string } func (r *raw) Query() (string, []any) { return r.s, nil } // Expr returns an SQL expression that implements the Querier interface. func Expr(exr string, args ...any) Querier { return &expr{s: exr, args: args} } type expr struct { s string args []any } func (e *expr) Query() (string, []any) { return e.s, e.args } // ExprFunc returns an expression function that implements the Querier interface. // // Update("users"). // Set("x", ExprFunc(func(b *Builder) { // // The sql.Builder config (argc and dialect) // // was set before the function was executed. // b.Ident("x").WriteOp(OpAdd).Arg(1) // })) func ExprFunc(fn func(*Builder)) Querier { return &exprFunc{fn: fn} } type exprFunc struct { Builder fn func(*Builder) } func (e *exprFunc) Query() (string, []any) { e.fn(&e.Builder) return e.Builder.Query() } // Queries is a list of queries joined with a space between them. type Queries []Querier // Query returns query representation of Queriers. func (n Queries) Query() (string, []any) { b := &Builder{} for i := range n { if i > 0 { b.Pad() } query, args := n[i].Query() b.WriteString(query) b.args = append(b.args, args...) } return b.String(), b.args } // Builder is the base query builder for the sql dsl. type Builder struct { sb *strings.Builder // underlying builder. dialect string // configured dialect. args []any // query parameters. total int // total number of parameters in query tree. errs []error // errors that were added during the query construction. qualifier string // qualifier to prefix identifiers (e.g. table name). } // Quote quotes the given identifier with the characters based // on the configured dialect. It defaults to "`". func (b *Builder) Quote(ident string) string { quote := "`" switch { case b.postgres(): // If it was quoted with the wrong // identifier character. if strings.Contains(ident, "`") { return strings.ReplaceAll(ident, "`", `"`) } quote = `"` // An identifier for unknown dialect. case b.dialect == "" && strings.ContainsAny(ident, "`\""): return ident } return quote + ident + quote } // Ident appends the given string as an identifier. func (b *Builder) Ident(s string) *Builder { switch { case len(s) == 0: case !strings.HasSuffix(s, "*") && !b.isIdent(s) && !isFunc(s) && !isModifier(s): if b.qualifier != "" { b.WriteString(b.Quote(b.qualifier)).WriteByte('.') } b.WriteString(b.Quote(s)) case (isFunc(s) || isModifier(s)) && b.postgres(): // Modifiers and aggregation functions that // were called without dialect information.
b.WriteString(strings.ReplaceAll(s, "`", `"`)) default: b.WriteString(s) } return b } // IdentComma calls Ident on all arguments and adds a comma between them. func (b *Builder) IdentComma(s ...string) *Builder { for i := range s { if i > 0 { b.Comma() } b.Ident(s[i]) } return b } // String returns the accumulated string. func (b *Builder) String() string { if b.sb == nil { return "" } return b.sb.String() } // WriteByte wraps the strings.Builder.WriteByte to make it chainable with other methods. func (b *Builder) WriteByte(c byte) *Builder { if b.sb == nil { b.sb = &strings.Builder{} } b.sb.WriteByte(c) return b } // WriteString wraps the strings.Builder.WriteString to make it chainable with other methods. func (b *Builder) WriteString(s string) *Builder { if b.sb == nil { b.sb = &strings.Builder{} } b.sb.WriteString(s) return b } // Len returns the number of accumulated bytes. func (b *Builder) Len() int { if b.sb == nil { return 0 } return b.sb.Len() } // Reset resets the Builder to be empty. func (b *Builder) Reset() *Builder { if b.sb != nil { b.sb.Reset() } return b } // AddError appends an error to the builder errors. func (b *Builder) AddError(err error) *Builder { // A nil error is allowed to make the build process easier. if err != nil { b.errs = append(b.errs, err) } return b } func (b *Builder) writeSchema(schema string) { if schema != "" && b.dialect != dialect.SQLite { b.Ident(schema).WriteByte('.') } } // Err returns a concatenated error of all errors encountered during // the query-building, or that were added manually by calling AddError. func (b *Builder) Err() error { if len(b.errs) == 0 { return nil } br := strings.Builder{} for i := range b.errs { if i > 0 { br.WriteString("; ") } br.WriteString(b.errs[i].Error()) } return errors.New(br.String()) } // An Op represents an operator. type Op int const ( // Predicate operators. OpEQ Op = iota // = OpNEQ // <> OpGT // > OpGTE // >= OpLT // < OpLTE // <= OpIn // IN OpNotIn // NOT IN OpLike // LIKE OpIsNull // IS NULL OpNotNull // IS NOT NULL // Arithmetic operators. OpAdd // + OpSub // - OpMul // * OpDiv // / (Quotient) OpMod // % (Remainder) ) var ops = [...]string{ OpEQ: "=", OpNEQ: "<>", OpGT: ">", OpGTE: ">=", OpLT: "<", OpLTE: "<=", OpIn: "IN", OpNotIn: "NOT IN", OpLike: "LIKE", OpIsNull: "IS NULL", OpNotNull: "IS NOT NULL", OpAdd: "+", OpSub: "-", OpMul: "*", OpDiv: "/", OpMod: "%", } // WriteOp writes an operator to the builder. func (b *Builder) WriteOp(op Op) *Builder { switch { case op >= OpEQ && op <= OpLike || op >= OpAdd && op <= OpMod: b.Pad().WriteString(ops[op]).Pad() case op == OpIsNull || op == OpNotNull: b.Pad().WriteString(ops[op]) default: panic(fmt.Sprintf("invalid op %d", op)) } return b } type ( // StmtInfo holds information regarding // the statement. StmtInfo struct { // The Dialect of the SQL driver. Dialect string } // ParamFormatter wraps the FormatParam function. ParamFormatter interface { // The FormatParam function lets users define // custom placeholder formatting for their types. // For example, formatting the default placeholder // from '?' to 'ST_GeomFromWKB(?)' for MySQL dialect. FormatParam(placeholder string, info *StmtInfo) string } ) // Arg appends an input argument to the builder. func (b *Builder) Arg(a any) *Builder { switch a := a.(type) { case nil: b.WriteString("NULL") return b case *raw: b.WriteString(a.s) return b case Querier: b.Join(a) return b } b.total++ b.args = append(b.args, a) // Default placeholder param (MySQL and SQLite). param := "?"
if b.postgres() { // Postgres' arguments are referenced using the syntax $n. // $1 refers to the 1st argument, $2 to the 2nd, and so on. param = "$" + strconv.Itoa(b.total) } if f, ok := a.(ParamFormatter); ok { param = f.FormatParam(param, &StmtInfo{ Dialect: b.dialect, }) } b.WriteString(param) return b } // Args appends a list of arguments to the builder. func (b *Builder) Args(a ...any) *Builder { for i := range a { if i > 0 { b.Comma() } b.Arg(a[i]) } return b } // Comma adds a comma to the query. func (b *Builder) Comma() *Builder { return b.WriteString(", ") } // Pad adds a space to the query. func (b *Builder) Pad() *Builder { return b.WriteByte(' ') } // Join joins a list of Queries to the builder. func (b *Builder) Join(qs ...Querier) *Builder { return b.join(qs, "") } // JoinComma joins a list of Queries and adds a comma between them. func (b *Builder) JoinComma(qs ...Querier) *Builder { return b.join(qs, ", ") } // join a list of Queries to the builder with a given separator. func (b *Builder) join(qs []Querier, sep string) *Builder { for i, q := range qs { if i > 0 { b.WriteString(sep) } st, ok := q.(state) if ok { st.SetDialect(b.dialect) st.SetTotal(b.total) } query, args := q.Query() b.WriteString(query) b.args = append(b.args, args...) b.total += len(args) if qe, ok := q.(querierErr); ok { if err := qe.Err(); err != nil { b.AddError(err) } } } return b } // Nested gets a callback, and wraps its result with parentheses. func (b *Builder) Nested(f func(*Builder)) *Builder { nb := &Builder{dialect: b.dialect, total: b.total, sb: &strings.Builder{}} nb.WriteByte('(') f(nb) nb.WriteByte(')') b.WriteString(nb.String()) b.args = append(b.args, nb.args...) b.total = nb.total return b } // SetDialect sets the builder dialect. It's used for generating dialect-specific queries. func (b *Builder) SetDialect(dialect string) { b.dialect = dialect } // Dialect returns the dialect of the builder. func (b Builder) Dialect() string { return b.dialect } // Total returns the total number of arguments so far. func (b Builder) Total() int { return b.total } // SetTotal sets the value of the total arguments. // Used to pass this information between sub queries/expressions. func (b *Builder) SetTotal(total int) { b.total = total } // Query implements the Querier interface. func (b Builder) Query() (string, []any) { return b.String(), b.args } // clone returns a shallow clone of a builder. func (b Builder) clone() Builder { c := Builder{dialect: b.dialect, total: b.total, sb: &strings.Builder{}} if len(b.args) > 0 { c.args = append(c.args, b.args...) } if b.sb != nil { c.sb.WriteString(b.sb.String()) } return c } // postgres reports if the builder dialect is PostgreSQL. func (b Builder) postgres() bool { return b.Dialect() == dialect.Postgres } // mysql reports if the builder dialect is MySQL. func (b Builder) mysql() bool { return b.Dialect() == dialect.MySQL } // fromIdent sets the builder dialect from the identifier format. func (b *Builder) fromIdent(ident string) { if strings.Contains(ident, `"`) { b.SetDialect(dialect.Postgres) } // otherwise, use the default. } // isIdent reports if the given string is a dialect identifier. func (b *Builder) isIdent(s string) bool { switch { case b.postgres(): return strings.Contains(s, `"`) default: return strings.Contains(s, "`") } } // state wraps all the methods for setting and getting // the update state between all queries in the query tree.
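// // For example, a sketch of how the dialect and argument totals propagate through Join: // // b := &Builder{} // b.SetDialect(dialect.Postgres) // b.WriteString("SELECT * FROM ").Ident("users").WriteString(" WHERE ").Join(EQ("name", "a8m")) // // b.Query() returns `SELECT * FROM "users" WHERE "name" = $1`, []any{"a8m"}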
type state interface { Dialect() string SetDialect(string) Total() int SetTotal(int) } // DialectBuilder prefixes all root builders with the `Dialect` constructor. type DialectBuilder struct { dialect string } // Dialect creates a new DialectBuilder with the given dialect name. func Dialect(name string) *DialectBuilder { return &DialectBuilder{name} } // Describe creates a DescribeBuilder for the configured dialect. // // Dialect(dialect.Postgres). // Describe("users") func (d *DialectBuilder) Describe(name string) *DescribeBuilder { b := Describe(name) b.SetDialect(d.dialect) return b } // CreateTable creates a TableBuilder for the configured dialect. // // Dialect(dialect.Postgres). // CreateTable("users"). // Columns( // Column("id").Type("int").Attr("auto_increment"), // Column("name").Type("varchar(255)"), // ). // PrimaryKey("id") func (d *DialectBuilder) CreateTable(name string) *TableBuilder { b := CreateTable(name) b.SetDialect(d.dialect) return b } // AlterTable creates a TableAlter for the configured dialect. // // Dialect(dialect.Postgres). // AlterTable("users"). // AddColumn(Column("group_id").Type("int").Attr("UNIQUE")). // AddForeignKey(ForeignKey().Columns("group_id"). // Reference(Reference().Table("groups").Columns("id")). // OnDelete("CASCADE"), // ) func (d *DialectBuilder) AlterTable(name string) *TableAlter { b := AlterTable(name) b.SetDialect(d.dialect) return b } // AlterIndex creates an IndexAlter for the configured dialect. // // Dialect(dialect.Postgres). // AlterIndex("old"). // Rename("new") func (d *DialectBuilder) AlterIndex(name string) *IndexAlter { b := AlterIndex(name) b.SetDialect(d.dialect) return b } // Column creates a ColumnBuilder for the configured dialect. // // Dialect(dialect.Postgres). // Column("group_id").Type("int").Attr("UNIQUE") func (d *DialectBuilder) Column(name string) *ColumnBuilder { b := Column(name) b.SetDialect(d.dialect) return b } // Insert creates an InsertBuilder for the configured dialect. // // Dialect(dialect.Postgres). // Insert("users").Columns("age").Values(1) func (d *DialectBuilder) Insert(table string) *InsertBuilder { b := Insert(table) b.SetDialect(d.dialect) return b } // Update creates an UpdateBuilder for the configured dialect. // // Dialect(dialect.Postgres). // Update("users").Set("name", "foo") func (d *DialectBuilder) Update(table string) *UpdateBuilder { b := Update(table) b.SetDialect(d.dialect) return b } // Delete creates a DeleteBuilder for the configured dialect. // // Dialect(dialect.Postgres). // Delete().From("users") func (d *DialectBuilder) Delete(table string) *DeleteBuilder { b := Delete(table) b.SetDialect(d.dialect) return b } // Select creates a Selector for the configured dialect. // // Dialect(dialect.Postgres). // Select().From(Table("users")) func (d *DialectBuilder) Select(columns ...string) *Selector { b := Select(columns...) b.SetDialect(d.dialect) return b } // SelectExpr is like Select, but supports passing arbitrary // expressions for the SELECT clause. // // Dialect(dialect.Postgres). // SelectExpr(expr...). // From(Table("users")) func (d *DialectBuilder) SelectExpr(exprs ...Querier) *Selector { b := SelectExpr(exprs...) b.SetDialect(d.dialect) return b } // Table creates a SelectTable for the configured dialect. // // Dialect(dialect.Postgres). // Table("users").As("u") func (d *DialectBuilder) Table(name string) *SelectTable { b := Table(name) b.SetDialect(d.dialect) return b } // With creates a WithBuilder for the configured dialect. // // Dialect(dialect.Postgres).
// With("users_view"). // As(Select().From(Table("users"))) func (d *DialectBuilder) With(name string) *WithBuilder { b := With(name) b.SetDialect(d.dialect) return b } // CreateIndex creates a IndexBuilder for the configured dialect. // // Dialect(dialect.Postgres). // CreateIndex("unique_name"). // Unique(). // Table("users"). // Columns("first", "last") func (d *DialectBuilder) CreateIndex(name string) *IndexBuilder { b := CreateIndex(name) b.SetDialect(d.dialect) return b } // DropIndex creates a DropIndexBuilder for the configured dialect. // // Dialect(dialect.Postgres). // DropIndex("name") func (d *DialectBuilder) DropIndex(name string) *DropIndexBuilder { b := DropIndex(name) b.SetDialect(d.dialect) return b } func isFunc(s string) bool { return strings.Contains(s, "(") && strings.Contains(s, ")") } func isModifier(s string) bool { for _, m := range [...]string{"DISTINCT", "ALL", "WITH ROLLUP"} { if strings.HasPrefix(s, m) { return true } } return false } ent-0.11.3/dialect/sql/builder_test.go000066400000000000000000002240531431500740500176070ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package sql import ( "context" "database/sql/driver" "fmt" "strconv" "strings" "testing" "entgo.io/ent/dialect" "github.com/stretchr/testify/require" ) func TestBuilder(t *testing.T) { tests := []struct { input Querier wantQuery string wantArgs []any }{ { input: Describe("users"), wantQuery: "DESCRIBE `users`", }, { input: CreateTable("users"). Columns( Column("id").Type("int").Attr("auto_increment"), Column("name").Type("varchar(255)"), ). PrimaryKey("id"), wantQuery: "CREATE TABLE `users`(`id` int auto_increment, `name` varchar(255), PRIMARY KEY(`id`))", }, { input: Dialect(dialect.Postgres).CreateTable("users"). Columns( Column("id").Type("serial").Attr("PRIMARY KEY"), Column("name").Type("varchar"), ), wantQuery: `CREATE TABLE "users"("id" serial PRIMARY KEY, "name" varchar)`, }, { input: CreateTable("users"). Columns( Column("id").Type("int").Attr("auto_increment"), Column("name").Type("varchar(255)"), ). PrimaryKey("id"). Charset("utf8mb4"), wantQuery: "CREATE TABLE `users`(`id` int auto_increment, `name` varchar(255), PRIMARY KEY(`id`)) CHARACTER SET utf8mb4", }, { input: CreateTable("users"). Columns( Column("id").Type("int").Attr("auto_increment"), Column("name").Type("varchar(255)"), ). PrimaryKey("id"). Charset("utf8mb4"). Collate("utf8mb4_general_ci"). Options("ENGINE=InnoDB"), wantQuery: "CREATE TABLE `users`(`id` int auto_increment, `name` varchar(255), PRIMARY KEY(`id`)) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci ENGINE=InnoDB", }, { input: CreateTable("users"). IfNotExists(). Columns( Column("id").Type("int").Attr("auto_increment"), ). PrimaryKey("id", "name"), wantQuery: "CREATE TABLE IF NOT EXISTS `users`(`id` int auto_increment, PRIMARY KEY(`id`, `name`))", }, { input: CreateTable("users"). IfNotExists(). Columns( Column("id").Type("int").Attr("auto_increment"), Column("card_id").Type("int"), Column("doc").Type("longtext").Check(func(b *Builder) { b.WriteString("JSON_VALID(").Ident("doc").WriteByte(')') }), ). PrimaryKey("id", "name"). ForeignKeys(ForeignKey().Columns("card_id"). Reference(Reference().Table("cards").Columns("id")).OnDelete("SET NULL")). 
Checks(func(b *Builder) { b.WriteString("CONSTRAINT ").Ident("valid_card").WriteString(" CHECK (").Ident("card_id").WriteString(" > 0)") }), wantQuery: "CREATE TABLE IF NOT EXISTS `users`(`id` int auto_increment, `card_id` int, `doc` longtext CHECK (JSON_VALID(`doc`)), PRIMARY KEY(`id`, `name`), FOREIGN KEY(`card_id`) REFERENCES `cards`(`id`) ON DELETE SET NULL, CONSTRAINT `valid_card` CHECK (`card_id` > 0))", }, { input: Dialect(dialect.Postgres).CreateTable("users"). IfNotExists(). Columns( Column("id").Type("serial"), Column("card_id").Type("int"), ). PrimaryKey("id", "name"). ForeignKeys(ForeignKey().Columns("card_id"). Reference(Reference().Table("cards").Columns("id")).OnDelete("SET NULL")), wantQuery: `CREATE TABLE IF NOT EXISTS "users"("id" serial, "card_id" int, PRIMARY KEY("id", "name"), FOREIGN KEY("card_id") REFERENCES "cards"("id") ON DELETE SET NULL)`, }, { input: AlterTable("users"). AddColumn(Column("group_id").Type("int").Attr("UNIQUE")). AddForeignKey(ForeignKey().Columns("group_id"). Reference(Reference().Table("groups").Columns("id")). OnDelete("CASCADE"), ), wantQuery: "ALTER TABLE `users` ADD COLUMN `group_id` int UNIQUE, ADD CONSTRAINT FOREIGN KEY(`group_id`) REFERENCES `groups`(`id`) ON DELETE CASCADE", }, { input: Dialect(dialect.Postgres).AlterTable("users"). AddColumn(Column("group_id").Type("int").Attr("UNIQUE")). AddForeignKey(ForeignKey("constraint").Columns("group_id"). Reference(Reference().Table("groups").Columns("id")). OnDelete("CASCADE"), ), wantQuery: `ALTER TABLE "users" ADD COLUMN "group_id" int UNIQUE, ADD CONSTRAINT "constraint" FOREIGN KEY("group_id") REFERENCES "groups"("id") ON DELETE CASCADE`, }, { input: AlterTable("users"). AddColumn(Column("group_id").Type("int").Attr("UNIQUE")). AddForeignKey(ForeignKey().Columns("group_id"). Reference(Reference().Table("groups").Columns("id")), ), wantQuery: "ALTER TABLE `users` ADD COLUMN `group_id` int UNIQUE, ADD CONSTRAINT FOREIGN KEY(`group_id`) REFERENCES `groups`(`id`)", }, { input: Dialect(dialect.Postgres).AlterTable("users"). AddColumn(Column("group_id").Type("int").Attr("UNIQUE")). AddForeignKey(ForeignKey().Columns("group_id"). Reference(Reference().Table("groups").Columns("id")), ), wantQuery: `ALTER TABLE "users" ADD COLUMN "group_id" int UNIQUE, ADD CONSTRAINT FOREIGN KEY("group_id") REFERENCES "groups"("id")`, }, { input: AlterTable("users"). AddColumn(Column("age").Type("int")). AddColumn(Column("name").Type("varchar(255)")), wantQuery: "ALTER TABLE `users` ADD COLUMN `age` int, ADD COLUMN `name` varchar(255)", }, { input: AlterTable("users"). DropForeignKey("users_parent_id"), wantQuery: "ALTER TABLE `users` DROP FOREIGN KEY `users_parent_id`", }, { input: Dialect(dialect.Postgres).AlterTable("users"). AddColumn(Column("age").Type("int")). AddColumn(Column("name").Type("varchar(255)")). DropConstraint("users_nickname_key"), wantQuery: `ALTER TABLE "users" ADD COLUMN "age" int, ADD COLUMN "name" varchar(255), DROP CONSTRAINT "users_nickname_key"`, }, { input: AlterTable("users"). AddForeignKey(ForeignKey().Columns("group_id"). Reference(Reference().Table("groups").Columns("id")), ). AddForeignKey(ForeignKey().Columns("location_id"). Reference(Reference().Table("locations").Columns("id")), ), wantQuery: "ALTER TABLE `users` ADD CONSTRAINT FOREIGN KEY(`group_id`) REFERENCES `groups`(`id`), ADD CONSTRAINT FOREIGN KEY(`location_id`) REFERENCES `locations`(`id`)", }, { input: AlterTable("users"). 
ModifyColumn(Column("age").Type("int")), wantQuery: "ALTER TABLE `users` MODIFY COLUMN `age` int", }, { input: Dialect(dialect.Postgres).AlterTable("users"). ModifyColumn(Column("age").Type("int")), wantQuery: `ALTER TABLE "users" ALTER COLUMN "age" TYPE int`, }, { input: AlterTable("users"). ModifyColumn(Column("age").Type("int")). DropColumn(Column("name")), wantQuery: "ALTER TABLE `users` MODIFY COLUMN `age` int, DROP COLUMN `name`", }, { input: Dialect(dialect.Postgres).AlterTable("users"). ModifyColumn(Column("age").Type("int")). DropColumn(Column("name")), wantQuery: `ALTER TABLE "users" ALTER COLUMN "age" TYPE int, DROP COLUMN "name"`, }, { input: Dialect(dialect.Postgres).AlterTable("users"). ModifyColumn(Column("age").Type("int")). ModifyColumn(Column("age").Attr("SET NOT NULL")). ModifyColumn(Column("name").Attr("DROP NOT NULL")), wantQuery: `ALTER TABLE "users" ALTER COLUMN "age" TYPE int, ALTER COLUMN "age" SET NOT NULL, ALTER COLUMN "name" DROP NOT NULL`, }, { input: AlterTable("users"). ChangeColumn("old_age", Column("age").Type("int")), wantQuery: "ALTER TABLE `users` CHANGE COLUMN `old_age` `age` int", }, { input: Dialect(dialect.Postgres).AlterTable("users"). AddColumn(Column("boring").Type("varchar")). ModifyColumn(Column("age").Type("int")). DropColumn(Column("name")), wantQuery: `ALTER TABLE "users" ADD COLUMN "boring" varchar, ALTER COLUMN "age" TYPE int, DROP COLUMN "name"`, }, { input: AlterTable("users").RenameIndex("old", "new"), wantQuery: "ALTER TABLE `users` RENAME INDEX `old` TO `new`", }, { input: AlterTable("users"). DropIndex("old"). AddIndex(CreateIndex("new1").Columns("c1", "c2")). AddIndex(CreateIndex("new2").Columns("c1", "c2").Unique()), wantQuery: "ALTER TABLE `users` DROP INDEX `old`, ADD INDEX `new1`(`c1`, `c2`), ADD UNIQUE INDEX `new2`(`c1`, `c2`)", }, { input: Dialect(dialect.Postgres).AlterIndex("old"). 
Rename("new"), wantQuery: `ALTER INDEX "old" RENAME TO "new"`, }, { input: Insert("users").Columns("age").Values(1), wantQuery: "INSERT INTO `users` (`age`) VALUES (?)", wantArgs: []any{1}, }, { input: Insert("users").Columns("age").Values(1).Schema("mydb"), wantQuery: "INSERT INTO `mydb`.`users` (`age`) VALUES (?)", wantArgs: []any{1}, }, { input: Dialect(dialect.Postgres).Insert("users").Columns("age").Values(1), wantQuery: `INSERT INTO "users" ("age") VALUES ($1)`, wantArgs: []any{1}, }, { input: Dialect(dialect.Postgres).Insert("users").Columns("age").Values(1).Schema("mydb"), wantQuery: `INSERT INTO "mydb"."users" ("age") VALUES ($1)`, wantArgs: []any{1}, }, { input: Dialect(dialect.SQLite).Insert("users").Columns("age").Values(1).Schema("mydb"), wantQuery: "INSERT INTO `users` (`age`) VALUES (?)", wantArgs: []any{1}, }, { input: Dialect(dialect.Postgres).Insert("users").Columns("age").Values(1).Returning("id"), wantQuery: `INSERT INTO "users" ("age") VALUES ($1) RETURNING "id"`, wantArgs: []any{1}, }, { input: Dialect(dialect.Postgres).Insert("users").Columns("age").Values(1).Returning("id").Returning("name"), wantQuery: `INSERT INTO "users" ("age") VALUES ($1) RETURNING "name"`, wantArgs: []any{1}, }, { input: Insert("users").Columns("name", "age").Values("a8m", 10), wantQuery: "INSERT INTO `users` (`name`, `age`) VALUES (?, ?)", wantArgs: []any{"a8m", 10}, }, { input: Dialect(dialect.Postgres).Insert("users").Columns("name", "age").Values("a8m", 10), wantQuery: `INSERT INTO "users" ("name", "age") VALUES ($1, $2)`, wantArgs: []any{"a8m", 10}, }, { input: Insert("users").Columns("name", "age").Values("a8m", 10).Values("foo", 20), wantQuery: "INSERT INTO `users` (`name`, `age`) VALUES (?, ?), (?, ?)", wantArgs: []any{"a8m", 10, "foo", 20}, }, { input: Dialect(dialect.Postgres).Insert("users").Columns("name", "age").Values("a8m", 10).Values("foo", 20), wantQuery: `INSERT INTO "users" ("name", "age") VALUES ($1, $2), ($3, $4)`, wantArgs: []any{"a8m", 10, "foo", 20}, }, { input: Dialect(dialect.Postgres).Insert("users"). Columns("name", "age"). Values("a8m", 10). Values("foo", 20). Values("bar", 30), wantQuery: `INSERT INTO "users" ("name", "age") VALUES ($1, $2), ($3, $4), ($5, $6)`, wantArgs: []any{"a8m", 10, "foo", 20, "bar", 30}, }, { input: Update("users").Set("name", "foo"), wantQuery: "UPDATE `users` SET `name` = ?", wantArgs: []any{"foo"}, }, { input: Update("users").Set("name", "foo").Schema("mydb"), wantQuery: "UPDATE `mydb`.`users` SET `name` = ?", wantArgs: []any{"foo"}, }, { input: Dialect(dialect.Postgres).Update("users").Set("name", "foo"), wantQuery: `UPDATE "users" SET "name" = $1`, wantArgs: []any{"foo"}, }, { input: Dialect(dialect.Postgres).Update("users").Set("name", "foo").Schema("mydb"), wantQuery: `UPDATE "mydb"."users" SET "name" = $1`, wantArgs: []any{"foo"}, }, { input: Dialect(dialect.SQLite).Update("users").Set("name", "foo").Schema("mydb"), wantQuery: "UPDATE `users` SET `name` = ?", wantArgs: []any{"foo"}, }, { input: Update("users").Set("name", "foo").Set("age", 10), wantQuery: "UPDATE `users` SET `name` = ?, `age` = ?", wantArgs: []any{"foo", 10}, }, { input: Dialect(dialect.Postgres).Update("users").Set("name", "foo").Set("age", 10), wantQuery: `UPDATE "users" SET "name" = $1, "age" = $2`, wantArgs: []any{"foo", 10}, }, { input: Dialect(dialect.Postgres).Update("users"). Set("active", false). 
Where(P(func(b *Builder) { b.Ident("name").WriteString(" SIMILAR TO ").Arg("(b|c)%") })), wantQuery: `UPDATE "users" SET "active" = $1 WHERE "name" SIMILAR TO $2`, wantArgs: []any{false, "(b|c)%"}, }, { input: Update("users").Set("name", "foo").Where(EQ("name", "bar")), wantQuery: "UPDATE `users` SET `name` = ? WHERE `name` = ?", wantArgs: []any{"foo", "bar"}, }, { input: Update("users").Set("name", "foo").Where(EQ("name", Expr("?", "bar"))), wantQuery: "UPDATE `users` SET `name` = ? WHERE `name` = ?", wantArgs: []any{"foo", "bar"}, }, { input: Dialect(dialect.Postgres).Update("users").Set("name", "foo").Where(EQ("name", "bar")), wantQuery: `UPDATE "users" SET "name" = $1 WHERE "name" = $2`, wantArgs: []any{"foo", "bar"}, }, { input: func() Querier { p1, p2 := EQ("name", "bar"), Or(EQ("age", 10), EQ("age", 20)) return Dialect(dialect.Postgres). Update("users"). Set("name", "foo"). Where(p1). Where(p2). Where(p1). Where(p2) }(), wantQuery: `UPDATE "users" SET "name" = $1 WHERE (("name" = $2 AND ("age" = $3 OR "age" = $4)) AND "name" = $5) AND ("age" = $6 OR "age" = $7)`, wantArgs: []any{"foo", "bar", 10, 20, "bar", 10, 20}, }, { input: Update("users").Set("name", "foo").SetNull("spouse_id"), wantQuery: "UPDATE `users` SET `spouse_id` = NULL, `name` = ?", wantArgs: []any{"foo"}, }, { input: Dialect(dialect.Postgres).Update("users").Set("name", "foo").SetNull("spouse_id"), wantQuery: `UPDATE "users" SET "spouse_id" = NULL, "name" = $1`, wantArgs: []any{"foo"}, }, { input: Update("users").Set("name", "foo"). Where(EQ("name", "bar")). Where(EQ("age", 20)), wantQuery: "UPDATE `users` SET `name` = ? WHERE `name` = ? AND `age` = ?", wantArgs: []any{"foo", "bar", 20}, }, { input: Dialect(dialect.Postgres). Update("users"). Set("name", "foo"). Where(EQ("name", "bar")). Where(EQ("age", 20)), wantQuery: `UPDATE "users" SET "name" = $1 WHERE "name" = $2 AND "age" = $3`, wantArgs: []any{"foo", "bar", 20}, }, { input: Update("users"). Set("name", "foo"). Set("age", 10). Where(Or(EQ("name", "bar"), EQ("name", "baz"))), wantQuery: "UPDATE `users` SET `name` = ?, `age` = ? WHERE `name` = ? OR `name` = ?", wantArgs: []any{"foo", 10, "bar", "baz"}, }, { input: Dialect(dialect.Postgres). Update("users"). Set("name", "foo"). Set("age", 10). Where(Or(EQ("name", "bar"), EQ("name", "baz"))), wantQuery: `UPDATE "users" SET "name" = $1, "age" = $2 WHERE "name" = $3 OR "name" = $4`, wantArgs: []any{"foo", 10, "bar", "baz"}, }, { input: Update("users"). Set("name", "foo"). Set("age", 10). Where(P().EQ("name", "foo")), wantQuery: "UPDATE `users` SET `name` = ?, `age` = ? WHERE `name` = ?", wantArgs: []any{"foo", 10, "foo"}, }, { input: Dialect(dialect.Postgres). Update("users"). Add("rank", 10). Where( Or( EQ("rank", Select("rank").From(Table("ranks")).Where(EQ("name", "foo"))), GT("score", Select("score").From(Table("scores")).Where(GT("count", 0))), ), ), wantQuery: `UPDATE "users" SET "rank" = COALESCE("users"."rank", 0) + $1 WHERE "rank" = (SELECT "rank" FROM "ranks" WHERE "name" = $2) OR "score" > (SELECT "score" FROM "scores" WHERE "count" > $3)`, wantArgs: []any{10, "foo", 0}, }, { input: Update("users"). Add("rank", 10). Where( Or( EQ("rank", Select("rank").From(Table("ranks")).Where(EQ("name", "foo"))), GT("score", Select("score").From(Table("scores")).Where(GT("count", 0))), ), ), wantQuery: "UPDATE `users` SET `rank` = COALESCE(`users`.`rank`, 0) + ? WHERE `rank` = (SELECT `rank` FROM `ranks` WHERE `name` = ?) 
OR `score` > (SELECT `score` FROM `scores` WHERE `count` > ?)", wantArgs: []any{10, "foo", 0}, }, { input: Dialect(dialect.Postgres). Update("users"). Set("name", "foo"). Set("age", 10). Where(P().EQ("name", "foo")), wantQuery: `UPDATE "users" SET "name" = $1, "age" = $2 WHERE "name" = $3`, wantArgs: []any{"foo", 10, "foo"}, }, { input: Update("users"). Set("name", "foo"). Where(And(In("name", "bar", "baz"), NotIn("age", 1, 2))), wantQuery: "UPDATE `users` SET `name` = ? WHERE `name` IN (?, ?) AND `age` NOT IN (?, ?)", wantArgs: []any{"foo", "bar", "baz", 1, 2}, }, { input: Dialect(dialect.Postgres). Update("users"). Set("name", "foo"). Where(And(In("name", "bar", "baz"), NotIn("age", 1, 2))), wantQuery: `UPDATE "users" SET "name" = $1 WHERE "name" IN ($2, $3) AND "age" NOT IN ($4, $5)`, wantArgs: []any{"foo", "bar", "baz", 1, 2}, }, { input: Update("users"). Set("name", "foo"). Where(And(HasPrefix("nickname", "a8m"), Contains("lastname", "mash"))), wantQuery: "UPDATE `users` SET `name` = ? WHERE `nickname` LIKE ? AND `lastname` LIKE ?", wantArgs: []any{"foo", "a8m%", "%mash%"}, }, { input: Dialect(dialect.Postgres). Update("users"). Set("name", "foo"). Where(And(HasPrefix("nickname", "a8m"), Contains("lastname", "mash"))), wantQuery: `UPDATE "users" SET "name" = $1 WHERE "nickname" LIKE $2 AND "lastname" LIKE $3`, wantArgs: []any{"foo", "a8m%", "%mash%"}, }, { input: Update("users"). Add("age", 1). Where(HasPrefix("nickname", "a8m")), wantQuery: "UPDATE `users` SET `age` = COALESCE(`users`.`age`, 0) + ? WHERE `nickname` LIKE ?", wantArgs: []any{1, "a8m%"}, }, { input: Update("users"). Set("age", 1). Add("age", 2). Where(HasPrefix("nickname", "a8m")), wantQuery: "UPDATE `users` SET `age` = ?, `age` = COALESCE(`users`.`age`, 0) + ? WHERE `nickname` LIKE ?", wantArgs: []any{1, 2, "a8m%"}, }, { input: Update("users"). Add("age", 2). Set("age", 1). Where(HasPrefix("nickname", "a8m")), wantQuery: "UPDATE `users` SET `age` = ? WHERE `nickname` LIKE ?", wantArgs: []any{1, "a8m%"}, }, { input: Dialect(dialect.Postgres). Update("users"). Add("age", 1). Where(HasPrefix("nickname", "a8m")), wantQuery: `UPDATE "users" SET "age" = COALESCE("users"."age", 0) + $1 WHERE "nickname" LIKE $2`, wantArgs: []any{1, "a8m%"}, }, { input: Update("users"). Add("age", 1). Set("nickname", "a8m"). Add("version", 10). Set("name", "mashraki"), wantQuery: "UPDATE `users` SET `age` = COALESCE(`users`.`age`, 0) + ?, `nickname` = ?, `version` = COALESCE(`users`.`version`, 0) + ?, `name` = ?", wantArgs: []any{1, "a8m", 10, "mashraki"}, }, { input: Dialect(dialect.Postgres). Update("users"). Add("age", 1). Set("nickname", "a8m"). Add("version", 10). Set("name", "mashraki"), wantQuery: `UPDATE "users" SET "age" = COALESCE("users"."age", 0) + $1, "nickname" = $2, "version" = COALESCE("users"."version", 0) + $3, "name" = $4`, wantArgs: []any{1, "a8m", 10, "mashraki"}, }, { input: Dialect(dialect.Postgres). Update("users"). Add("age", 1). Set("nickname", "a8m"). Add("version", 10). Set("name", "mashraki"). Set("first", "ariel"). Add("score", 1e5). Where(Or(EQ("age", 1), EQ("age", 2))), wantQuery: `UPDATE "users" SET "age" = COALESCE("users"."age", 0) + $1, "nickname" = $2, "version" = COALESCE("users"."version", 0) + $3, "name" = $4, "first" = $5, "score" = COALESCE("users"."score", 0) + $6 WHERE "age" = $7 OR "age" = $8`, wantArgs: []any{1, "a8m", 10, "mashraki", "ariel", 1e5, 1, 2}, }, { input: Select(). From(Table("users")). 
Where(EQ("name", "Alex")), wantQuery: "SELECT * FROM `users` WHERE `name` = ?", wantArgs: []any{"Alex"}, }, { input: Dialect(dialect.Postgres). Select(). From(Table("users")), wantQuery: `SELECT * FROM "users"`, }, { input: Dialect(dialect.Postgres). Select(). From(Table("users")). Where(EQ("name", "Ariel")), wantQuery: `SELECT * FROM "users" WHERE "name" = $1`, wantArgs: []any{"Ariel"}, }, { input: Select(). From(Table("users")). Where(Or(EQ("name", "BAR"), EQ("name", "BAZ"))), wantQuery: "SELECT * FROM `users` WHERE `name` = ? OR `name` = ?", wantArgs: []any{"BAR", "BAZ"}, }, { input: func() Querier { t1, t2 := Table("users"), Table("pets") return Dialect(dialect.Postgres). Select(). From(t1). Where(GT(t1.C("age"), 30)). Where( And( Exists(Select().From(t2).Where(ColumnsEQ(t2.C("owner_id"), t1.C("id")))), NotExists(Select().From(t2).Where(ColumnsEQ(t2.C("owner_id"), t1.C("id")))), ), ) }(), wantQuery: `SELECT * FROM "users" WHERE "users"."age" > $1 AND (EXISTS (SELECT * FROM "pets" WHERE "pets"."owner_id" = "users"."id") AND NOT EXISTS (SELECT * FROM "pets" WHERE "pets"."owner_id" = "users"."id"))`, wantArgs: []any{30}, }, { input: Update("users"). Set("name", "foo"). Set("age", 10). Where(And(EQ("name", "foo"), EQ("age", 20))), wantQuery: "UPDATE `users` SET `name` = ?, `age` = ? WHERE `name` = ? AND `age` = ?", wantArgs: []any{"foo", 10, "foo", 20}, }, { input: Delete("users"). Where(NotNull("parent_id")), wantQuery: "DELETE FROM `users` WHERE `parent_id` IS NOT NULL", }, { input: Delete("users"). Where(NotNull("parent_id")). Schema("mydb"), wantQuery: "DELETE FROM `mydb`.`users` WHERE `parent_id` IS NOT NULL", }, { input: Dialect(dialect.SQLite). Delete("users"). Where(NotNull("parent_id")). Schema("mydb"), wantQuery: "DELETE FROM `users` WHERE `parent_id` IS NOT NULL", }, { input: Dialect(dialect.Postgres). Delete("users"). Where(IsNull("parent_id")), wantQuery: `DELETE FROM "users" WHERE "parent_id" IS NULL`, }, { input: Dialect(dialect.Postgres). Delete("users"). Where(IsNull("parent_id")). Schema("mydb"), wantQuery: `DELETE FROM "mydb"."users" WHERE "parent_id" IS NULL`, }, { input: Delete("users"). Where(And(IsNull("parent_id"), NotIn("name", "foo", "bar"))), wantQuery: "DELETE FROM `users` WHERE `parent_id` IS NULL AND `name` NOT IN (?, ?)", wantArgs: []any{"foo", "bar"}, }, { input: Dialect(dialect.Postgres). Delete("users"). Where(And(IsNull("parent_id"), NotIn("name", "foo", "bar"))), wantQuery: `DELETE FROM "users" WHERE "parent_id" IS NULL AND "name" NOT IN ($1, $2)`, wantArgs: []any{"foo", "bar"}, }, { input: Delete("users"). Where(And(IsNull("parent_id"), In("name"))), wantQuery: "DELETE FROM `users` WHERE `parent_id` IS NULL AND FALSE", }, { input: Delete("users"). Where(And(IsNull("parent_id"), NotIn("name"))), wantQuery: "DELETE FROM `users` WHERE `parent_id` IS NULL AND (NOT (FALSE))", }, { input: Delete("users"). Where(And(False(), False())), wantQuery: "DELETE FROM `users` WHERE FALSE AND FALSE", }, { input: Dialect(dialect.Postgres). Delete("users"). Where(And(False(), False())), wantQuery: `DELETE FROM "users" WHERE FALSE AND FALSE`, }, { input: Delete("users"). Where(Or(NotNull("parent_id"), EQ("parent_id", 10))), wantQuery: "DELETE FROM `users` WHERE `parent_id` IS NOT NULL OR `parent_id` = ?", wantArgs: []any{10}, }, { input: Dialect(dialect.Postgres). Delete("users"). Where(Or(NotNull("parent_id"), EQ("parent_id", 10))), wantQuery: `DELETE FROM "users" WHERE "parent_id" IS NOT NULL OR "parent_id" = $1`, wantArgs: []any{10}, }, { input: Delete("users"). 
Where( Or( And(EQ("name", "foo"), EQ("age", 10)), And(EQ("name", "bar"), EQ("age", 20)), And( EQ("name", "qux"), Or(EQ("age", 1), EQ("age", 2)), ), ), ), wantQuery: "DELETE FROM `users` WHERE (`name` = ? AND `age` = ?) OR (`name` = ? AND `age` = ?) OR (`name` = ? AND (`age` = ? OR `age` = ?))", wantArgs: []any{"foo", 10, "bar", 20, "qux", 1, 2}, }, { input: Dialect(dialect.Postgres). Delete("users"). Where( Or( And(EQ("name", "foo"), EQ("age", 10)), And(EQ("name", "bar"), EQ("age", 20)), And( EQ("name", "qux"), Or(EQ("age", 1), EQ("age", 2)), ), ), ), wantQuery: `DELETE FROM "users" WHERE ("name" = $1 AND "age" = $2) OR ("name" = $3 AND "age" = $4) OR ("name" = $5 AND ("age" = $6 OR "age" = $7))`, wantArgs: []any{"foo", 10, "bar", 20, "qux", 1, 2}, }, { input: Delete("users"). Where( Or( And(EQ("name", "foo"), EQ("age", 10)), And(EQ("name", "bar"), EQ("age", 20)), ), ). Where(EQ("role", "admin")), wantQuery: "DELETE FROM `users` WHERE ((`name` = ? AND `age` = ?) OR (`name` = ? AND `age` = ?)) AND `role` = ?", wantArgs: []any{"foo", 10, "bar", 20, "admin"}, }, { input: Dialect(dialect.Postgres). Delete("users"). Where( Or( And(EQ("name", "foo"), EQ("age", 10)), And(EQ("name", "bar"), EQ("age", 20)), ), ). Where(EQ("role", "admin")), wantQuery: `DELETE FROM "users" WHERE (("name" = $1 AND "age" = $2) OR ("name" = $3 AND "age" = $4)) AND "role" = $5`, wantArgs: []any{"foo", 10, "bar", 20, "admin"}, }, { input: Select().From(Table("users")), wantQuery: "SELECT * FROM `users`", }, { input: Dialect(dialect.Postgres).Select().From(Table("users")), wantQuery: `SELECT * FROM "users"`, }, { input: Select().From(Table("users").Unquote()), wantQuery: "SELECT * FROM users", }, { input: Dialect(dialect.Postgres).Select().From(Table("users").Unquote()), wantQuery: "SELECT * FROM users", }, { input: Select().From(Table("users").As("u")), wantQuery: "SELECT * FROM `users` AS `u`", }, { input: Dialect(dialect.Postgres).Select().From(Table("users").As("u")), wantQuery: `SELECT * FROM "users" AS "u"`, }, { input: func() Querier { t1 := Table("users").As("u") t2 := Table("groups").As("g") return Select(t1.C("id"), t2.C("name")).From(t1).Join(t2) }(), wantQuery: "SELECT `u`.`id`, `g`.`name` FROM `users` AS `u` JOIN `groups` AS `g`", }, { input: func() Querier { t1 := Table("users").As("u") t2 := Table("groups").As("g") return Dialect(dialect.Postgres).Select(t1.C("id"), t2.C("name")).From(t1).Join(t2) }(), wantQuery: `SELECT "u"."id", "g"."name" FROM "users" AS "u" JOIN "groups" AS "g"`, }, { input: func() Querier { t1 := Table("users").As("u") t2 := Table("groups").As("g") return Select(t1.C("id"), t2.C("name")). From(t1). Join(t2). On(t1.C("id"), t2.C("user_id")) }(), wantQuery: "SELECT `u`.`id`, `g`.`name` FROM `users` AS `u` JOIN `groups` AS `g` ON `u`.`id` = `g`.`user_id`", }, { input: func() Querier { t1 := Table("users").As("u") t2 := Table("groups").As("g") return Dialect(dialect.Postgres). Select(t1.C("id"), t2.C("name")). From(t1). Join(t2). On(t1.C("id"), t2.C("user_id")) }(), wantQuery: `SELECT "u"."id", "g"."name" FROM "users" AS "u" JOIN "groups" AS "g" ON "u"."id" = "g"."user_id"`, }, { input: func() Querier { t1 := Table("users").As("u") t2 := Table("groups").As("g") return Select(t1.C("id"), t2.C("name")). From(t1). Join(t2). On(t1.C("id"), t2.C("user_id")). Where(And(EQ(t1.C("name"), "bar"), NotNull(t2.C("name")))) }(), wantQuery: "SELECT `u`.`id`, `g`.`name` FROM `users` AS `u` JOIN `groups` AS `g` ON `u`.`id` = `g`.`user_id` WHERE `u`.`name` = ? 
AND `g`.`name` IS NOT NULL", wantArgs: []any{"bar"}, }, { input: func() Querier { t1 := Table("users").As("u") t2 := Table("groups").As("g") return Dialect(dialect.Postgres). Select(t1.C("id"), t2.C("name")). From(t1). Join(t2). On(t1.C("id"), t2.C("user_id")). Where(And(EQ(t1.C("name"), "bar"), NotNull(t2.C("name")))) }(), wantQuery: `SELECT "u"."id", "g"."name" FROM "users" AS "u" JOIN "groups" AS "g" ON "u"."id" = "g"."user_id" WHERE "u"."name" = $1 AND "g"."name" IS NOT NULL`, wantArgs: []any{"bar"}, }, { input: func() Querier { t1 := Table("users").As("u") t2 := Table("user_groups").As("ug") return Select(t1.C("id"), As(Count("`*`"), "group_count")). From(t1). LeftJoin(t2). On(t1.C("id"), t2.C("user_id")). GroupBy(t1.C("id")) }(), wantQuery: "SELECT `u`.`id`, COUNT(`*`) AS `group_count` FROM `users` AS `u` LEFT JOIN `user_groups` AS `ug` ON `u`.`id` = `ug`.`user_id` GROUP BY `u`.`id`", }, { input: func() Querier { t1 := Table("users").As("u") t2 := Table("user_groups").As("ug") return Select(t1.C("id"), As(Count("`*`"), "group_count")). From(t1). LeftJoin(t2). OnP(P(func(b *Builder) { b.Ident(t1.C("id")).WriteOp(OpEQ).Ident(t2.C("user_id")) })). GroupBy(t1.C("id")).Clone() }(), wantQuery: "SELECT `u`.`id`, COUNT(`*`) AS `group_count` FROM `users` AS `u` LEFT JOIN `user_groups` AS `ug` ON `u`.`id` = `ug`.`user_id` GROUP BY `u`.`id`", }, { input: func() Querier { t1 := Table("groups").As("g") t2 := Table("user_groups").As("ug") return Select(t1.C("id"), As(Count("`*`"), "user_count")). From(t1). RightJoin(t2). On(t1.C("id"), t2.C("group_id")). GroupBy(t1.C("id")) }(), wantQuery: "SELECT `g`.`id`, COUNT(`*`) AS `user_count` FROM `groups` AS `g` RIGHT JOIN `user_groups` AS `ug` ON `g`.`id` = `ug`.`group_id` GROUP BY `g`.`id`", }, { input: func() Querier { t1 := Table("groups").As("g") t2 := Table("user_groups").As("ug") return Select(t1.C("id"), As(Count("`*`"), "user_count")). From(t1). FullJoin(t2). On(t1.C("id"), t2.C("group_id")). GroupBy(t1.C("id")) }(), wantQuery: "SELECT `g`.`id`, COUNT(`*`) AS `user_count` FROM `groups` AS `g` FULL JOIN `user_groups` AS `ug` ON `g`.`id` = `ug`.`group_id` GROUP BY `g`.`id`", }, { input: func() Querier { t1 := Table("users").As("u") return Select(t1.Columns("name", "age")...).From(t1) }(), wantQuery: "SELECT `u`.`name`, `u`.`age` FROM `users` AS `u`", }, { input: func() Querier { t1 := Table("users").As("u") return Dialect(dialect.Postgres). Select(t1.Columns("name", "age")...).From(t1) }(), wantQuery: `SELECT "u"."name", "u"."age" FROM "users" AS "u"`, }, { input: func() Querier { t1 := Dialect(dialect.Postgres). Table("users").As("u") return Dialect(dialect.Postgres). Select(t1.Columns("name", "age")...).From(t1) }(), wantQuery: `SELECT "u"."name", "u"."age" FROM "users" AS "u"`, }, { input: func() Querier { t1 := Table("users").As("u") t2 := Select().From(Table("groups")).Where(EQ("user_id", 10)).As("g") return Select(t1.C("id"), t2.C("name")). From(t1). Join(t2). On(t1.C("id"), t2.C("user_id")) }(), wantQuery: "SELECT `u`.`id`, `g`.`name` FROM `users` AS `u` JOIN (SELECT * FROM `groups` WHERE `user_id` = ?) AS `g` ON `u`.`id` = `g`.`user_id`", wantArgs: []any{10}, }, { input: func() Querier { d := Dialect(dialect.Postgres) t1 := d.Table("users").As("u") t2 := d.Select().From(Table("groups")).Where(EQ("user_id", 10)).As("g") return d.Select(t1.C("id"), t2.C("name")). From(t1). Join(t2). 
On(t1.C("id"), t2.C("user_id")) }(), wantQuery: `SELECT "u"."id", "g"."name" FROM "users" AS "u" JOIN (SELECT * FROM "groups" WHERE "user_id" = $1) AS "g" ON "u"."id" = "g"."user_id"`, wantArgs: []any{10}, }, { input: func() Querier { t1 := Table("users") t2 := Table("groups") t3 := Table("user_groups") return Select(t1.C("*")).From(t1). Join(t3).On(t1.C("id"), t3.C("user_id")). Join(t2).On(t2.C("id"), t3.C("group_id")) }(), wantQuery: "SELECT `users`.* FROM `users` JOIN `user_groups` AS `t1` ON `users`.`id` = `t1`.`user_id` JOIN `groups` AS `t2` ON `t2`.`id` = `t1`.`group_id`", }, { input: func() Querier { d := Dialect(dialect.Postgres) t1 := d.Table("users") t2 := d.Table("groups") t3 := d.Table("user_groups") return d.Select(t1.C("*")).From(t1). Join(t3).On(t1.C("id"), t3.C("user_id")). Join(t2).On(t2.C("id"), t3.C("group_id")) }(), wantQuery: `SELECT "users".* FROM "users" JOIN "user_groups" AS "t1" ON "users"."id" = "t1"."user_id" JOIN "groups" AS "t2" ON "t2"."id" = "t1"."group_id"`, }, { input: func() Querier { selector := Select().Where(Or(EQ("name", "foo"), EQ("name", "bar"))) return Delete("users").FromSelect(selector) }(), wantQuery: "DELETE FROM `users` WHERE `name` = ? OR `name` = ?", wantArgs: []any{"foo", "bar"}, }, { input: func() Querier { d := Dialect(dialect.Postgres) selector := d.Select().Where(Or(EQ("name", "foo"), EQ("name", "bar"))) return d.Delete("users").FromSelect(selector) }(), wantQuery: `DELETE FROM "users" WHERE "name" = $1 OR "name" = $2`, wantArgs: []any{"foo", "bar"}, }, { input: func() Querier { selector := Select().From(Table("users")).As("t") return selector.Select(selector.C("name")) }(), wantQuery: "SELECT `t`.`name` FROM `users`", }, { input: func() Querier { selector := Dialect(dialect.Postgres). Select().From(Table("users")).As("t") return selector.Select(selector.C("name")) }(), wantQuery: `SELECT "t"."name" FROM "users"`, }, { input: func() Querier { selector := Select().From(Table("groups")).Where(EQ("name", "foo")) return Delete("users").FromSelect(selector) }(), wantQuery: "DELETE FROM `groups` WHERE `name` = ?", wantArgs: []any{"foo"}, }, { input: func() Querier { d := Dialect(dialect.Postgres) selector := d.Select().From(Table("groups")).Where(EQ("name", "foo")) return d.Delete("users").FromSelect(selector) }(), wantQuery: `DELETE FROM "groups" WHERE "name" = $1`, wantArgs: []any{"foo"}, }, { input: func() Querier { selector := Select() return Delete("users").FromSelect(selector) }(), wantQuery: "DELETE FROM `users`", }, { input: func() Querier { d := Dialect(dialect.Postgres) selector := d.Select() return d.Delete("users").FromSelect(selector) }(), wantQuery: `DELETE FROM "users"`, }, { input: Select(). From(Table("users")). Where(Not(And(EQ("name", "foo"), EQ("age", "bar")))), wantQuery: "SELECT * FROM `users` WHERE NOT (`name` = ? AND `age` = ?)", wantArgs: []any{"foo", "bar"}, }, { input: Dialect(dialect.Postgres). Select(). From(Table("users")). Where(Not(And(EQ("name", "foo"), EQ("age", "bar")))), wantQuery: `SELECT * FROM "users" WHERE NOT ("name" = $1 AND "age" = $2)`, wantArgs: []any{"foo", "bar"}, }, { input: Select(). From(Table("users")). Where(Or(EqualFold("name", "BAR"), EqualFold("name", "BAZ"))), wantQuery: "SELECT * FROM `users` WHERE LOWER(`name`) = ? OR LOWER(`name`) = ?", wantArgs: []any{"bar", "baz"}, }, { input: Dialect(dialect.Postgres). Select(). From(Table("users")). 
Where(Or(EqualFold("name", "BAR"), EqualFold("name", "BAZ"))), wantQuery: `SELECT * FROM "users" WHERE "name" ILIKE $1 OR "name" ILIKE $2`, wantArgs: []any{"bar", "baz"}, }, { input: Dialect(dialect.Postgres). Select(). From(Table("users")). Where(Or(EqualFold("name", "BAR%"), EqualFold("name", "%BAZ"))), wantQuery: `SELECT * FROM "users" WHERE "name" ILIKE $1 OR "name" ILIKE $2`, wantArgs: []any{"bar\\%", "\\%baz"}, }, { input: Dialect(dialect.Postgres). Select(). From(Table("users")). Where(Or(EqualFold("name", "BAR\\"), EqualFold("name", "\\BAZ"))), wantQuery: `SELECT * FROM "users" WHERE "name" ILIKE $1 OR "name" ILIKE $2`, wantArgs: []any{"bar\\\\", "\\\\baz"}, }, { input: Dialect(dialect.MySQL). Select(). From(Table("users")). Where(Or(EqualFold("name", "BAR"), EqualFold("name", "BAZ"))), wantQuery: "SELECT * FROM `users` WHERE `name` COLLATE utf8mb4_general_ci = ? OR `name` COLLATE utf8mb4_general_ci = ?", wantArgs: []any{"bar", "baz"}, }, { input: Dialect(dialect.SQLite). Select(). From(Table("users")). Where(And(ContainsFold("name", "Ariel"), ContainsFold("nick", "Bar"))), wantQuery: "SELECT * FROM `users` WHERE LOWER(`name`) LIKE ? AND LOWER(`nick`) LIKE ?", wantArgs: []any{"%ariel%", "%bar%"}, }, { input: Dialect(dialect.Postgres). Select(). From(Table("users")). Where(And(ContainsFold("name", "Ariel"), ContainsFold("nick", "Bar"))), wantQuery: `SELECT * FROM "users" WHERE "name" ILIKE $1 AND "nick" ILIKE $2`, wantArgs: []any{"%ariel%", "%bar%"}, }, { input: Dialect(dialect.MySQL). Select(). From(Table("users")). Where(And(ContainsFold("name", "Ariel"), ContainsFold("nick", "Bar"))), wantQuery: "SELECT * FROM `users` WHERE `name` COLLATE utf8mb4_general_ci LIKE ? AND `nick` COLLATE utf8mb4_general_ci LIKE ?", wantArgs: []any{"%ariel%", "%bar%"}, }, { input: func() Querier { s1 := Select(). From(Table("users")). Where(Not(And(EQ("name", "foo"), EQ("age", "bar")))) return Queries{With("users_view").As(s1), Select("name").From(Table("users_view"))} }(), wantQuery: "WITH `users_view` AS (SELECT * FROM `users` WHERE NOT (`name` = ? AND `age` = ?)) SELECT `name` FROM `users_view`", wantArgs: []any{"foo", "bar"}, }, { input: func() Querier { d := Dialect(dialect.Postgres) s1 := d.Select(). From(Table("users")). Where(Not(And(EQ("name", "foo"), EQ("age", "bar")))) return Queries{d.With("users_view").As(s1), d.Select("name").From(Table("users_view"))} }(), wantQuery: `WITH "users_view" AS (SELECT * FROM "users" WHERE NOT ("name" = $1 AND "age" = $2)) SELECT "name" FROM "users_view"`, wantArgs: []any{"foo", "bar"}, }, { input: func() Querier { s1 := Select().From(Table("users")).Where(Not(And(EQ("name", "foo"), EQ("age", "bar")))).As("users_view") return Select("name").From(s1) }(), wantQuery: "SELECT `name` FROM (SELECT * FROM `users` WHERE NOT (`name` = ? AND `age` = ?)) AS `users_view`", wantArgs: []any{"foo", "bar"}, }, { input: func() Querier { d := Dialect(dialect.Postgres) s1 := d.Select().From(Table("users")).Where(Not(And(EQ("name", "foo"), EQ("age", "bar")))).As("users_view") return d.Select("name").From(s1) }(), wantQuery: `SELECT "name" FROM (SELECT * FROM "users" WHERE NOT ("name" = $1 AND "age" = $2)) AS "users_view"`, wantArgs: []any{"foo", "bar"}, }, { input: func() Querier { t1 := Table("users") return Select(). From(t1). 
Where(In(t1.C("id"), Select("owner_id").From(Table("pets")).Where(EQ("name", "pedro")))) }(), wantQuery: "SELECT * FROM `users` WHERE `users`.`id` IN (SELECT `owner_id` FROM `pets` WHERE `name` = ?)", wantArgs: []any{"pedro"}, }, { input: func() Querier { t1 := Table("users") return Dialect(dialect.Postgres). Select(). From(t1). Where(In(t1.C("id"), Select("owner_id").From(Table("pets")).Where(EQ("name", "pedro")))) }(), wantQuery: `SELECT * FROM "users" WHERE "users"."id" IN (SELECT "owner_id" FROM "pets" WHERE "name" = $1)`, wantArgs: []any{"pedro"}, }, { input: func() Querier { t1 := Table("users") return Select(). From(t1). Where(Not(In(t1.C("id"), Select("owner_id").From(Table("pets")).Where(EQ("name", "pedro"))))) }(), wantQuery: "SELECT * FROM `users` WHERE NOT (`users`.`id` IN (SELECT `owner_id` FROM `pets` WHERE `name` = ?))", wantArgs: []any{"pedro"}, }, { input: func() Querier { t1 := Table("users") return Dialect(dialect.Postgres). Select(). From(t1). Where(Not(In(t1.C("id"), Select("owner_id").From(Table("pets")).Where(EQ("name", "pedro"))))) }(), wantQuery: `SELECT * FROM "users" WHERE NOT ("users"."id" IN (SELECT "owner_id" FROM "pets" WHERE "name" = $1))`, wantArgs: []any{"pedro"}, }, { input: Select().Count().From(Table("users")), wantQuery: "SELECT COUNT(*) FROM `users`", }, { input: Dialect(dialect.Postgres). Select().Count().From(Table("users")), wantQuery: `SELECT COUNT(*) FROM "users"`, }, { input: Select().Count(Distinct("id")).From(Table("users")), wantQuery: "SELECT COUNT(DISTINCT `id`) FROM `users`", }, { input: Dialect(dialect.Postgres). Select().Count(Distinct("id")).From(Table("users")), wantQuery: `SELECT COUNT(DISTINCT "id") FROM "users"`, }, { input: func() Querier { t1 := Table("users") t2 := Select().From(Table("groups")) t3 := Select().Count().From(t1).Join(t1).On(t2.C("id"), t1.C("blocked_id")) return t3.Count(Distinct(t3.Columns("id", "name")...)) }(), wantQuery: "SELECT COUNT(DISTINCT `t1`.`id`, `t1`.`name`) FROM `users` AS `t1` JOIN `users` AS `t1` ON `groups`.`id` = `t1`.`blocked_id`", }, { input: func() Querier { d := Dialect(dialect.Postgres) t1 := d.Table("users") t2 := d.Select().From(Table("groups")) t3 := d.Select().Count().From(t1).Join(t1).On(t2.C("id"), t1.C("blocked_id")) return t3.Count(Distinct(t3.Columns("id", "name")...)) }(), wantQuery: `SELECT COUNT(DISTINCT "t1"."id", "t1"."name") FROM "users" AS "t1" JOIN "users" AS "t1" ON "groups"."id" = "t1"."blocked_id"`, }, { input: Select(Sum("age"), Min("age")).From(Table("users")), wantQuery: "SELECT SUM(`age`), MIN(`age`) FROM `users`", }, { input: Dialect(dialect.Postgres). Select(Sum("age"), Min("age")). From(Table("users")), wantQuery: `SELECT SUM("age"), MIN("age") FROM "users"`, }, { input: func() Querier { t1 := Table("users").As("u") return Select(As(Max(t1.C("age")), "max_age")).From(t1) }(), wantQuery: "SELECT MAX(`u`.`age`) AS `max_age` FROM `users` AS `u`", }, { input: func() Querier { t1 := Table("users").As("u") return Dialect(dialect.Postgres). Select(As(Max(t1.C("age")), "max_age")). From(t1) }(), wantQuery: `SELECT MAX("u"."age") AS "max_age" FROM "users" AS "u"`, }, { input: Select("name", Count("*")). From(Table("users")). GroupBy("name"), wantQuery: "SELECT `name`, COUNT(*) FROM `users` GROUP BY `name`", }, { input: Dialect(dialect.Postgres). Select("name", Count("*")). From(Table("users")). GroupBy("name"), wantQuery: `SELECT "name", COUNT(*) FROM "users" GROUP BY "name"`, }, { input: Select("name", Count("*")). From(Table("users")). GroupBy("name"). 
OrderBy("name"), wantQuery: "SELECT `name`, COUNT(*) FROM `users` GROUP BY `name` ORDER BY `name`", }, { input: Dialect(dialect.Postgres). Select("name", Count("*")). From(Table("users")). GroupBy("name"). OrderBy("name"), wantQuery: `SELECT "name", COUNT(*) FROM "users" GROUP BY "name" ORDER BY "name"`, }, { input: Select("name", "age", Count("*")). From(Table("users")). GroupBy("name", "age"). OrderBy(Desc("name"), "age"), wantQuery: "SELECT `name`, `age`, COUNT(*) FROM `users` GROUP BY `name`, `age` ORDER BY `name` DESC, `age`", }, { input: Dialect(dialect.Postgres). Select("name", "age", Count("*")). From(Table("users")). GroupBy("name", "age"). OrderBy(Desc("name"), "age"), wantQuery: `SELECT "name", "age", COUNT(*) FROM "users" GROUP BY "name", "age" ORDER BY "name" DESC, "age"`, }, { input: Select("*"). From(Table("users")). Limit(1), wantQuery: "SELECT * FROM `users` LIMIT 1", }, { input: Dialect(dialect.Postgres). Select("*"). From(Table("users")). Limit(1), wantQuery: `SELECT * FROM "users" LIMIT 1`, }, { input: Select("age").Distinct().From(Table("users")), wantQuery: "SELECT DISTINCT `age` FROM `users`", }, { input: Dialect(dialect.Postgres). Select("age"). Distinct(). From(Table("users")), wantQuery: `SELECT DISTINCT "age" FROM "users"`, }, { input: Select("age", "name").From(Table("users")).Distinct().OrderBy("name"), wantQuery: "SELECT DISTINCT `age`, `name` FROM `users` ORDER BY `name`", }, { input: Dialect(dialect.Postgres). Select("age", "name"). From(Table("users")). Distinct(). OrderBy("name"), wantQuery: `SELECT DISTINCT "age", "name" FROM "users" ORDER BY "name"`, }, { input: Select("age").From(Table("users")).Where(EQ("name", "foo")).Or().Where(EQ("name", "bar")), wantQuery: "SELECT `age` FROM `users` WHERE `name` = ? OR `name` = ?", wantArgs: []any{"foo", "bar"}, }, { input: Dialect(dialect.Postgres). Select("age"). From(Table("users")). Where(EQ("name", "foo")).Or().Where(EQ("name", "bar")), wantQuery: `SELECT "age" FROM "users" WHERE "name" = $1 OR "name" = $2`, wantArgs: []any{"foo", "bar"}, }, { input: Queries{With("users_view").As(Select().From(Table("users"))), Select().From(Table("users_view"))}, wantQuery: "WITH `users_view` AS (SELECT * FROM `users`) SELECT * FROM `users_view`", }, { input: func() Querier { base := Select("*").From(Table("groups")) return Queries{With("groups").As(base.Clone().Where(EQ("name", "bar"))), base.Select("age")} }(), wantQuery: "WITH `groups` AS (SELECT * FROM `groups` WHERE `name` = ?) SELECT `age` FROM `groups`", wantArgs: []any{"bar"}, }, { input: SelectExpr(Raw("1")), wantQuery: "SELECT 1", }, { input: Select("*").From(SelectExpr(Raw("1")).As("s")), wantQuery: "SELECT * FROM (SELECT 1) AS `s`", }, { input: func() Querier { builder := Dialect(dialect.Postgres) t1 := builder.Table("groups") t2 := builder.Table("users") t3 := builder.Table("user_groups") t4 := builder.Select(t3.C("id")). From(t3). Join(t2). On(t3.C("id"), t2.C("id2")). Where(EQ(t2.C("id"), "baz")) return builder.Select(). From(t1). Join(t4). On(t1.C("id"), t4.C("id")).Limit(1) }(), wantQuery: `SELECT * FROM "groups" JOIN (SELECT "user_groups"."id" FROM "user_groups" JOIN "users" AS "t1" ON "user_groups"."id" = "t1"."id2" WHERE "t1"."id" = $1) AS "t1" ON "groups"."id" = "t1"."id" LIMIT 1`, wantArgs: []any{"baz"}, }, { input: func() Querier { t1 := Table("users") return Dialect(dialect.Postgres). Select(). From(t1). 
Where(CompositeGT(t1.Columns("id", "name"), 1, "Ariel")) }(), wantQuery: `SELECT * FROM "users" WHERE ("users"."id", "users"."name") > ($1, $2)`, wantArgs: []any{1, "Ariel"}, }, { input: func() Querier { t1 := Table("users") return Dialect(dialect.Postgres). Select(). From(t1). Where(And(EQ("name", "Ariel"), CompositeGT(t1.Columns("id", "name"), 1, "Ariel"))) }(), wantQuery: `SELECT * FROM "users" WHERE "name" = $1 AND ("users"."id", "users"."name") > ($2, $3)`, wantArgs: []any{"Ariel", 1, "Ariel"}, }, { input: func() Querier { t1 := Table("users") return Dialect(dialect.Postgres). Select(). From(t1). Where(And(EQ("name", "Ariel"), Or(EQ("surname", "Doe"), CompositeGT(t1.Columns("id", "name"), 1, "Ariel")))) }(), wantQuery: `SELECT * FROM "users" WHERE "name" = $1 AND ("surname" = $2 OR ("users"."id", "users"."name") > ($3, $4))`, wantArgs: []any{"Ariel", "Doe", 1, "Ariel"}, }, { input: func() Querier { t1 := Table("users") return Dialect(dialect.Postgres). Select(). From(Table("users")). Where(And(EQ("name", "Ariel"), CompositeLT(t1.Columns("id", "name"), 1, "Ariel"))) }(), wantQuery: `SELECT * FROM "users" WHERE "name" = $1 AND ("users"."id", "users"."name") < ($2, $3)`, wantArgs: []any{"Ariel", 1, "Ariel"}, }, { input: CreateIndex("name_index").Table("users").Column("name"), wantQuery: "CREATE INDEX `name_index` ON `users`(`name`)", }, { input: Dialect(dialect.Postgres). CreateIndex("name_index"). Table("users"). Column("name"), wantQuery: `CREATE INDEX "name_index" ON "users"("name")`, }, { input: Dialect(dialect.Postgres). CreateIndex("name_index"). IfNotExists(). Table("users"). Column("name"), wantQuery: `CREATE INDEX IF NOT EXISTS "name_index" ON "users"("name")`, }, { input: Dialect(dialect.Postgres). CreateIndex("name_index"). IfNotExists(). Table("users"). Using("gin"). Column("name"), wantQuery: `CREATE INDEX IF NOT EXISTS "name_index" ON "users" USING "gin"("name")`, }, { input: Dialect(dialect.MySQL). CreateIndex("name_index"). IfNotExists(). Table("users"). Using("HASH"). Column("name"), wantQuery: "CREATE INDEX IF NOT EXISTS `name_index` ON `users`(`name`) USING HASH", }, { input: CreateIndex("unique_name").Unique().Table("users").Columns("first", "last"), wantQuery: "CREATE UNIQUE INDEX `unique_name` ON `users`(`first`, `last`)", }, { input: Dialect(dialect.Postgres). CreateIndex("unique_name"). Unique(). Table("users"). Columns("first", "last"), wantQuery: `CREATE UNIQUE INDEX "unique_name" ON "users"("first", "last")`, }, { input: DropIndex("name_index"), wantQuery: "DROP INDEX `name_index`", }, { input: Dialect(dialect.Postgres). DropIndex("name_index"), wantQuery: `DROP INDEX "name_index"`, }, { input: DropIndex("name_index").Table("users"), wantQuery: "DROP INDEX `name_index` ON `users`", }, { input: Select(). From(Table("pragma_table_info('t1')").Unquote()). OrderBy("pk"), wantQuery: "SELECT * FROM pragma_table_info('t1') ORDER BY `pk`", }, { input: AlterTable("users"). AddColumn(Column("spouse").Type("integer"). Constraint(ForeignKey("user_spouse"). Reference(Reference().Table("users").Columns("id")). OnDelete("SET NULL"))), wantQuery: "ALTER TABLE `users` ADD COLUMN `spouse` integer CONSTRAINT user_spouse REFERENCES `users`(`id`) ON DELETE SET NULL", }, { input: Dialect(dialect.Postgres). Select("*"). From(Table("users")). Where(Or( And(EQ("id", 1), InInts("group_id", 2, 3)), And(EQ("id", 2), InValues("group_id", 4, 5)), )). Where(And( Or(EQ("a", "a"), And(EQ("b", "b"), EQ("c", "c"))), Not(Or(IsNull("d"), NotNull("e"))), )). Or(). 
Where(And(NEQ("f", "f"), NEQ("g", "g"))), wantQuery: strings.NewReplacer("\n", "", "\t", "").Replace(` SELECT * FROM "users" WHERE ( (("id" = $1 AND "group_id" IN ($2, $3)) OR ("id" = $4 AND "group_id" IN ($5, $6))) AND (("a" = $7 OR ("b" = $8 AND "c" = $9)) AND (NOT ("d" IS NULL OR "e" IS NOT NULL))) ) OR ("f" <> $10 AND "g" <> $11)`), wantArgs: []any{1, 2, 3, 2, 4, 5, "a", "b", "c", "f", "g"}, }, { input: Dialect(dialect.Postgres). Select("*"). From(Table("test")). Where(P(func(b *Builder) { b.WriteString("nlevel(").Ident("path").WriteByte(')').WriteOp(OpGT).Arg(1) })), wantQuery: `SELECT * FROM "test" WHERE nlevel("path") > $1`, wantArgs: []any{1}, }, { input: Dialect(dialect.Postgres). Select("*"). From(Table("test")). Where(P(func(b *Builder) { b.WriteString("nlevel(").Ident("path").WriteByte(')').WriteOp(OpGT).Arg(1) })), wantQuery: `SELECT * FROM "test" WHERE nlevel("path") > $1`, wantArgs: []any{1}, }, { input: Select("id").From(Table("users")).Where(ExprP("DATE(last_login_at) >= ?", "2022-05-03")), wantQuery: "SELECT `id` FROM `users` WHERE DATE(last_login_at) >= ?", wantArgs: []any{"2022-05-03"}, }, { input: Select("id"). From(Table("users")). Where(P(func(b *Builder) { b.WriteString("DATE(").Ident("last_login_at").WriteString(") >= ").Arg("2022-05-03") })), wantQuery: "SELECT `id` FROM `users` WHERE DATE(`last_login_at`) >= ?", wantArgs: []any{"2022-05-03"}, }, { input: Select("id").From(Table("events")).Where(ExprP("DATE_ADD(date, INTERVAL duration MINUTE) BETWEEN ? AND ?", "2022-05-03", "2022-05-04")), wantQuery: "SELECT `id` FROM `events` WHERE DATE_ADD(date, INTERVAL duration MINUTE) BETWEEN ? AND ?", wantArgs: []any{"2022-05-03", "2022-05-04"}, }, { input: Select("id"). From(Table("events")). Where(P(func(b *Builder) { b.WriteString("DATE_ADD(date, INTERVAL duration MINUTE) BETWEEN ").Arg("2022-05-03").WriteString(" AND ").Arg("2022-05-04") })), wantQuery: "SELECT `id` FROM `events` WHERE DATE_ADD(date, INTERVAL duration MINUTE) BETWEEN ? AND ?", wantArgs: []any{"2022-05-03", "2022-05-04"}, }, { input: func() Querier { t1, t2 := Table("users").Schema("s1"), Table("pets").Schema("s2") return Select("*"). From(t1).Join(t2). OnP(P(func(b *Builder) { b.Ident(t1.C("id")).WriteOp(OpEQ).Ident(t2.C("owner_id")) })). Where(EQ(t2.C("name"), "pedro")) }(), wantQuery: "SELECT * FROM `s1`.`users` JOIN `s2`.`pets` AS `t1` ON `s1`.`users`.`id` = `t1`.`owner_id` WHERE `t1`.`name` = ?", wantArgs: []any{"pedro"}, }, { input: func() Querier { t1, t2 := Table("users").Schema("s1"), Table("pets").Schema("s2") sel := Select("*"). From(t1).Join(t2). OnP(P(func(b *Builder) { b.Ident(t1.C("id")).WriteOp(OpEQ).Ident(t2.C("owner_id")) })). Where(EQ(t2.C("name"), "pedro")) sel.SetDialect(dialect.SQLite) return sel }(), wantQuery: "SELECT * FROM `users` JOIN `pets` AS `t1` ON `users`.`id` = `t1`.`owner_id` WHERE `t1`.`name` = ?", wantArgs: []any{"pedro"}, }, { input: Dialect(dialect.Postgres). Select("*"). From(Table("users")). Where(ExprP("name = $1", "pedro")). Where(P(func(b *Builder) { b.Join(Expr("name = $2", "pedro")) })). Where(EQ("name", "pedro")). Where( And( In( "id", Select("owner_id"). From(Table("pets")). Where(EQ("name", "luna")), ), EQ("active", true), ), ), wantQuery: `SELECT * FROM "users" WHERE ((name = $1 AND name = $2) AND "name" = $3) AND ("id" IN (SELECT "owner_id" FROM "pets" WHERE "name" = $4) AND "active")`, wantArgs: []any{"pedro", "pedro", "pedro", "luna"}, }, { input: func() Querier { t1 := Table("users") return Dialect(dialect.Postgres). Select(). From(t1). 
Where(ColumnsEQ(t1.C("id1"), t1.C("id2"))). Where(ColumnsNEQ(t1.C("id1"), t1.C("id2"))). Where(ColumnsGT(t1.C("id1"), t1.C("id2"))). Where(ColumnsGTE(t1.C("id1"), t1.C("id2"))). Where(ColumnsLT(t1.C("id1"), t1.C("id2"))). Where(ColumnsLTE(t1.C("id1"), t1.C("id2"))) }(), wantQuery: strings.ReplaceAll(` SELECT * FROM "users" WHERE (((("users"."id1" = "users"."id2" AND "users"."id1" <> "users"."id2") AND "users"."id1" > "users"."id2") AND "users"."id1" >= "users"."id2") AND "users"."id1" < "users"."id2") AND "users"."id1" <= "users"."id2"`, "\n", ""), }, } for i, tt := range tests { t.Run(strconv.Itoa(i), func(t *testing.T) { query, args := tt.input.Query() require.Equal(t, tt.wantQuery, query) require.Equal(t, tt.wantArgs, args) }) } } func TestBuilder_Err(t *testing.T) { b := Select("i-") require.NoError(t, b.Err()) b.AddError(fmt.Errorf("invalid")) require.EqualError(t, b.Err(), "invalid") b.AddError(fmt.Errorf("unexpected")) require.EqualError(t, b.Err(), "invalid; unexpected") b.Where(P(func(builder *Builder) { builder.AddError(fmt.Errorf("inner")) })) _, _ = b.Query() require.EqualError(t, b.Err(), "invalid; unexpected; inner") } func TestSelector_OrderByExpr(t *testing.T) { query, args := Select("*"). From(Table("users")). Where(GT("age", 28)). OrderBy("name"). OrderExpr(Expr("CASE WHEN id=? THEN id WHEN id=? THEN name END DESC", 1, 2)). Query() require.Equal(t, "SELECT * FROM `users` WHERE `age` > ? ORDER BY `name`, CASE WHEN id=? THEN id WHEN id=? THEN name END DESC", query) require.Equal(t, []any{28, 1, 2}, args) } func TestSelector_SelectExpr(t *testing.T) { query, args := SelectExpr( Expr("?", "a"), ExprFunc(func(b *Builder) { b.Ident("first_name").WriteOp(OpAdd).Ident("last_name") }), ExprFunc(func(b *Builder) { b.WriteString("COALESCE(").Ident("age").Comma().Arg(0).WriteByte(')') }), Expr("?", "b"), ).From(Table("users")).Query() require.Equal(t, "SELECT ?, `first_name` + `last_name`, COALESCE(`age`, ?), ? FROM `users`", query) require.Equal(t, []any{"a", 0, "b"}, args) query, args = Dialect(dialect.Postgres). Select("name"). AppendSelectExpr( Expr("age + $1", 1), ExprFunc(func(b *Builder) { b.Nested(func(b *Builder) { b.WriteString("similarity(").Ident("name").Comma().Arg("A").WriteByte(')') b.WriteOp(OpAdd) b.WriteString("similarity(").Ident("desc").Comma().Arg("D").WriteByte(')') }) b.WriteString(" AS s") }), Expr("rank + $4", 10), ). From(Table("users")). Query() require.Equal(t, `SELECT "name", age + $1, (similarity("name", $2) + similarity("desc", $3)) AS s, rank + $4 FROM "users"`, query) require.Equal(t, []any{1, "A", "D", 10}, args) } func TestSelector_Union(t *testing.T) { query, args := Dialect(dialect.Postgres). Select("*"). From(Table("users")). Where(EQ("active", true)). Union( Select("*"). From(Table("old_users1")). Where( And( EQ("is_active", true), GT("age", 20), ), ), ). UnionAll( Select("*"). From(Table("old_users2")). Where( And( EQ("is_active", "true"), LT("age", 18), ), ), ). Query() require.Equal(t, `SELECT * FROM "users" WHERE "active" UNION SELECT * FROM "old_users1" WHERE "is_active" AND "age" > $1 UNION ALL SELECT * FROM "old_users2" WHERE "is_active" = $2 AND "age" < $3`, query) require.Equal(t, []any{20, "true", 18}, args) t1, t2, t3 := Table("files"), Table("files"), Table("path") n := Queries{ WithRecursive("path", "id", "name", "parent_id"). As(Select(t1.Columns("id", "name", "parent_id")...). From(t1). Where( And( IsNull(t1.C("parent_id")), EQ(t1.C("deleted"), false), ), ). UnionAll( Select(t2.Columns("id", "name", "parent_id")...). From(t2). 
Join(t3). On(t2.C("parent_id"), t3.C("id")). Where( EQ(t2.C("deleted"), false), ), ), ), Select(t3.Columns("id", "name", "parent_id")...). From(t3), } query, args = n.Query() require.Equal(t, "WITH RECURSIVE `path`(`id`, `name`, `parent_id`) AS (SELECT `files`.`id`, `files`.`name`, `files`.`parent_id` FROM `files` WHERE `files`.`parent_id` IS NULL AND NOT `files`.`deleted` UNION ALL SELECT `files`.`id`, `files`.`name`, `files`.`parent_id` FROM `files` JOIN `path` AS `t1` ON `files`.`parent_id` = `t1`.`id` WHERE NOT `files`.`deleted`) SELECT `t1`.`id`, `t1`.`name`, `t1`.`parent_id` FROM `path` AS `t1`", query) require.Nil(t, args) } func TestBuilderContext(t *testing.T) { type key string want := "myval" ctx := context.WithValue(context.Background(), key("mykey"), want) sel := Dialect(dialect.Postgres).Select().WithContext(ctx) if got := sel.Context().Value(key("mykey")).(string); got != want { t.Fatalf("expected selector context key to be %q but got %q", want, got) } if got := sel.Clone().Context().Value(key("mykey")).(string); got != want { t.Fatalf("expected cloned selector context key to be %q but got %q", want, got) } } type point struct { xy []float64 *testing.T } // FormatParam implements the sql.ParamFormatter interface. func (p point) FormatParam(placeholder string, info *StmtInfo) string { require.Equal(p.T, dialect.MySQL, info.Dialect) return "ST_GeomFromWKB(" + placeholder + ")" } // Value implements the driver.Valuer interface. func (p point) Value() (driver.Value, error) { return p.xy, nil } func TestParamFormatter(t *testing.T) { p := point{xy: []float64{1, 2}, T: t} query, args := Dialect(dialect.MySQL). Select(). From(Table("users")). Where(EQ("point", p)). Query() require.Equal(t, "SELECT * FROM `users` WHERE `point` = ST_GeomFromWKB(?)", query) require.Equal(t, p, args[0]) } func TestSelectWithLock(t *testing.T) { query, args := Dialect(dialect.MySQL). Select(). From(Table("users")). Where(EQ("id", 1)). ForUpdate(). Query() require.Equal(t, "SELECT * FROM `users` WHERE `id` = ? FOR UPDATE", query) require.Equal(t, 1, args[0]) query, args = Dialect(dialect.Postgres). Select(). From(Table("users")). Where(EQ("id", 1)). ForUpdate(WithLockAction(NoWait)). Query() require.Equal(t, `SELECT * FROM "users" WHERE "id" = $1 FOR UPDATE NOWAIT`, query) require.Equal(t, 1, args[0]) users, pets := Table("users"), Table("pets") query, args = Dialect(dialect.Postgres). Select(). From(pets). Join(users). On(pets.C("owner_id"), users.C("id")). Where(EQ("id", 20)). ForUpdate( WithLockAction(SkipLocked), WithLockTables("pets"), ). Query() require.Equal(t, `SELECT * FROM "pets" JOIN "users" AS "t1" ON "pets"."owner_id" = "t1"."id" WHERE "id" = $1 FOR UPDATE OF "pets" SKIP LOCKED`, query) require.Equal(t, 20, args[0]) query, args = Dialect(dialect.MySQL). Select(). From(Table("users")). Where(EQ("id", 20)). ForShare(WithLockClause("LOCK IN SHARE MODE")). Query() require.Equal(t, "SELECT * FROM `users` WHERE `id` = ? LOCK IN SHARE MODE", query) require.Equal(t, 20, args[0]) s := Dialect(dialect.SQLite). Select(). From(Table("users")). Where(EQ("id", 1)). ForUpdate() s.Query() require.EqualError(t, s.Err(), "sql: SELECT .. FOR UPDATE/SHARE not supported in SQLite") } func TestSelector_UnionOrderBy(t *testing.T) { table := Table("users") query, _ := Dialect(dialect.Postgres). Select("*"). From(table). Where(EQ("active", true)). Union(Select("*").From(Table("old_users1"))). OrderBy(table.C("whatever")). 
Query() require.Equal(t, `SELECT * FROM "users" WHERE "active" UNION SELECT * FROM "old_users1" ORDER BY "users"."whatever"`, query) } func TestUpdateBuilder_SetExpr(t *testing.T) { d := Dialect(dialect.Postgres) excluded := d.Table("excluded") query, args := d.Update("users"). Set("name", "Ariel"). Set("active", Expr("NOT(active)")). Set("age", Expr(excluded.C("age"))). Set("x", ExprFunc(func(b *Builder) { b.WriteString(excluded.C("x")).WriteString(" || ' (formerly ' || ").Ident("x").WriteString(" || ')'") })). Set("y", ExprFunc(func(b *Builder) { b.Arg("~").WriteOp(OpAdd).WriteString(excluded.C("y")).WriteOp(OpAdd).Arg("~") })). Query() require.Equal(t, `UPDATE "users" SET "name" = $1, "active" = NOT(active), "age" = "excluded"."age", "x" = "excluded"."x" || ' (formerly ' || "x" || ')', "y" = $2 + "excluded"."y" + $3`, query) require.Equal(t, []any{"Ariel", "~", "~"}, args) } func TestInsert_OnConflict(t *testing.T) { t.Run("Postgres", func(t *testing.T) { // And SQLite. query, args := Dialect(dialect.Postgres). Insert("users"). Columns("id", "email", "creation_time"). Values("1", "user@example.com", 1633279231). OnConflict( ConflictColumns("email"), ConflictWhere(EQ("name", "Ariel")), ResolveWithNewValues(), // Update all new values except the id field. ResolveWith(func(u *UpdateSet) { u.SetIgnore("id") u.SetIgnore("creation_time") u.Add("version", 1) }), UpdateWhere(NEQ("updated_at", 0)), ). Query() require.Equal(t, `INSERT INTO "users" ("id", "email", "creation_time") VALUES ($1, $2, $3) ON CONFLICT ("email") WHERE "name" = $4 DO UPDATE SET "id" = "users"."id", "email" = "excluded"."email", "creation_time" = "users"."creation_time", "version" = COALESCE("users"."version", 0) + $5 WHERE "users"."updated_at" <> $6`, query) require.Equal(t, []any{"1", "user@example.com", 1633279231, "Ariel", 1, 0}, args) query, args = Dialect(dialect.Postgres). Insert("users"). Columns("id", "name"). Values("1", "Mashraki"). OnConflict( ConflictConstraint("users_pkey"), DoNothing(), ). Query() require.Equal(t, `INSERT INTO "users" ("id", "name") VALUES ($1, $2) ON CONFLICT ON CONSTRAINT "users_pkey" DO NOTHING`, query) require.Equal(t, []any{"1", "Mashraki"}, args) query, args = Dialect(dialect.Postgres). Insert("users"). Columns("id"). Values(1). OnConflict( DoNothing(), ). Query() require.Equal(t, `INSERT INTO "users" ("id") VALUES ($1) ON CONFLICT DO NOTHING`, query) require.Equal(t, []any{1}, args) query, args = Dialect(dialect.Postgres). Insert("users"). Columns("id"). Values(1). OnConflict( ConflictColumns("id"), ResolveWithIgnore(), ). Query() require.Equal(t, `INSERT INTO "users" ("id") VALUES ($1) ON CONFLICT ("id") DO UPDATE SET "id" = "users"."id"`, query) require.Equal(t, []any{1}, args) query, args = Dialect(dialect.Postgres). Insert("users"). Columns("id", "name"). Values(1, "Mashraki"). OnConflict( ConflictColumns("name"), ResolveWith(func(s *UpdateSet) { s.SetExcluded("name") s.SetNull("created_at") }), ). Query() require.Equal(t, `INSERT INTO "users" ("id", "name") VALUES ($1, $2) ON CONFLICT ("name") DO UPDATE SET "created_at" = NULL, "name" = "excluded"."name"`, query) require.Equal(t, []any{1, "Mashraki"}, args) }) t.Run("MySQL", func(t *testing.T) { query, args := Dialect(dialect.MySQL). Insert("users"). Columns("id", "email"). Values("1", "user@example.com"). OnConflict( ResolveWithNewValues(), ). Query() require.Equal(t, "INSERT INTO `users` (`id`, `email`) VALUES (?, ?) 
ON DUPLICATE KEY UPDATE `id` = VALUES(`id`), `email` = VALUES(`email`)", query) require.Equal(t, []any{"1", "user@example.com"}, args) query, args = Dialect(dialect.MySQL). Insert("users"). Columns("id", "email"). Values("1", "user@example.com"). OnConflict( ResolveWithIgnore(), ). Query() require.Equal(t, "INSERT INTO `users` (`id`, `email`) VALUES (?, ?) ON DUPLICATE KEY UPDATE `id` = `users`.`id`, `email` = `users`.`email`", query) require.Equal(t, []any{"1", "user@example.com"}, args) query, args = Dialect(dialect.MySQL). Insert("users"). Columns("id", "name"). Values("1", "Mashraki"). OnConflict( ResolveWith(func(s *UpdateSet) { s.SetExcluded("name") s.SetNull("created_at") s.Add("version", 1) }), ). Query() require.Equal(t, "INSERT INTO `users` (`id`, `name`) VALUES (?, ?) ON DUPLICATE KEY UPDATE `created_at` = NULL, `name` = VALUES(`name`), `version` = COALESCE(`users`.`version`, 0) + ?", query) require.Equal(t, []any{"1", "Mashraki", 1}, args) query, args = Dialect(dialect.MySQL). Insert("users"). Columns("name", "rank"). Values("Mashraki", nil). OnConflict( ResolveWithNewValues(), ResolveWith(func(s *UpdateSet) { s.Set("id", Expr("LAST_INSERT_ID(`id`)")) }), ). Query() require.Equal(t, "INSERT INTO `users` (`name`, `rank`) VALUES (?, NULL) ON DUPLICATE KEY UPDATE `name` = VALUES(`name`), `rank` = VALUES(`rank`), `id` = LAST_INSERT_ID(`id`)", query) require.Equal(t, []any{"Mashraki"}, args) query, args = Dialect(dialect.MySQL). Insert("users"). Columns("name", "rank"). Values("Ariel", 10). Values("Mashraki", nil). OnConflict( ResolveWithNewValues(), ResolveWith(func(s *UpdateSet) { s.Set("id", Expr("LAST_INSERT_ID(`id`)")) }), ). Query() require.Equal(t, "INSERT INTO `users` (`name`, `rank`) VALUES (?, ?), (?, NULL) ON DUPLICATE KEY UPDATE `name` = VALUES(`name`), `rank` = VALUES(`rank`), `id` = LAST_INSERT_ID(`id`)", query) require.Equal(t, []any{"Ariel", 10, "Mashraki"}, args) }) } func TestEscapePatterns(t *testing.T) { q, args := Dialect(dialect.MySQL). Update("users"). SetNull("name"). Where( Or( HasPrefix("nickname", "%a8m%"), HasSuffix("nickname", "_alexsn_"), Contains("nickname", "\\pedro\\"), ContainsFold("nickname", "%AbcD%efg"), ), ). Query() require.Equal(t, "UPDATE `users` SET `name` = NULL WHERE `nickname` LIKE ? OR `nickname` LIKE ? OR `nickname` LIKE ? OR `nickname` COLLATE utf8mb4_general_ci LIKE ?", q) require.Equal(t, []any{"\\%a8m\\%%", "%\\_alexsn\\_", "%\\\\pedro\\\\%", "%\\%abcd\\%efg%"}, args) q, args = Dialect(dialect.SQLite). Update("users"). SetNull("name"). Where( Or( HasPrefix("nickname", "%a8m%"), HasSuffix("nickname", "_alexsn_"), Contains("nickname", "\\pedro\\"), ContainsFold("nickname", "%AbcD%efg"), ), ). Query() require.Equal(t, "UPDATE `users` SET `name` = NULL WHERE `nickname` LIKE ? ESCAPE ? OR `nickname` LIKE ? ESCAPE ? OR `nickname` LIKE ? ESCAPE ? OR LOWER(`nickname`) LIKE ? 
ESCAPE ?", q) require.Equal(t, []any{"\\%a8m\\%%", "\\", "%\\_alexsn\\_", "\\", "%\\\\pedro\\\\%", "\\", "%\\%abcd\\%efg%", "\\"}, args) } func TestReusePredicates(t *testing.T) { tests := []struct { p *Predicate wantQuery string wantArgs []any }{ { p: EQ("active", false), wantQuery: `SELECT * FROM "users" WHERE NOT "active"`, }, { p: Or( EQ("a", "a"), EQ("b", "b"), ), wantQuery: `SELECT * FROM "users" WHERE "a" = $1 OR "b" = $2`, wantArgs: []any{"a", "b"}, }, { p: Or( EQ("a", "a"), In("b"), ), wantQuery: `SELECT * FROM "users" WHERE "a" = $1 OR FALSE`, wantArgs: []any{"a"}, }, { p: And( EQ("active", true), HasPrefix("name", "foo"), HasSuffix("name", "bar"), Or( In("id", Select("oid").From(Table("audit"))), In("id", Select("oid").From(Table("history"))), ), ), wantQuery: `SELECT * FROM "users" WHERE "active" AND "name" LIKE $1 AND "name" LIKE $2 AND ("id" IN (SELECT "oid" FROM "audit") OR "id" IN (SELECT "oid" FROM "history"))`, wantArgs: []any{"foo%", "%bar"}, }, { p: func() *Predicate { t1 := Table("groups") pivot := Table("user_groups") matches := Select(pivot.C("user_id")). From(pivot). Join(t1). On(pivot.C("group_id"), t1.C("id")). Where(EQ(t1.C("name"), "ent")) return And( GT("balance", 0), In("id", matches), GT("balance", 100), ) }(), wantQuery: `SELECT * FROM "users" WHERE "balance" > $1 AND "id" IN (SELECT "user_groups"."user_id" FROM "user_groups" JOIN "groups" AS "t1" ON "user_groups"."group_id" = "t1"."id" WHERE "t1"."name" = $2) AND "balance" > $3`, wantArgs: []any{0, "ent", 100}, }, } for _, tt := range tests { query, args := Dialect(dialect.Postgres).Select().From(Table("users")).Where(tt.p).Query() require.Equal(t, tt.wantQuery, query) require.Equal(t, tt.wantArgs, args) query, args = Dialect(dialect.Postgres).Select().From(Table("users")).Where(tt.p).Query() require.Equal(t, tt.wantQuery, query) require.Equal(t, tt.wantArgs, args) } } func TestBoolPredicates(t *testing.T) { t1, t2 := Table("users"), Table("posts") query, args := Select(). From(t1). Join(t2). On(t1.C("id"), t2.C("author_id")). Where( And( EQ(t1.C("active"), true), NEQ(t2.C("deleted"), true), ), ). Query() require.Nil(t, args) require.Equal(t, "SELECT * FROM `users` JOIN `posts` AS `t1` ON `users`.`id` = `t1`.`author_id` WHERE `users`.`active` AND NOT `t1`.`deleted`", query) } func TestWindowFunction(t *testing.T) { posts := Table("posts") base := Select(posts.Columns("id", "content", "author_id")...). From(posts). Where(EQ("active", true)) with := With("active_posts"). As(base). With("selected_posts"). As( Select(). AppendSelect("*"). AppendSelectExprAs( RowNumber().PartitionBy("author_id").OrderBy("id").OrderExpr(Expr("f(`s`)")), "row_number", ). 
From(Table("active_posts")), ) query, args := Select("*").From(Table("selected_posts")).Where(LTE("row_number", 2)).Prefix(with).Query() require.Equal(t, "WITH `active_posts` AS (SELECT `posts`.`id`, `posts`.`content`, `posts`.`author_id` FROM `posts` WHERE `active`), `selected_posts` AS (SELECT *, (ROW_NUMBER() OVER (PARTITION BY `author_id` ORDER BY `id`, f(`s`))) AS `row_number` FROM `active_posts`) SELECT * FROM `selected_posts` WHERE `row_number` <= ?", query) require.Equal(t, []any{2}, args) } func TestSelector_UnqualifiedColumns(t *testing.T) { t1, t2 := Table("t1"), Table("t2") s := Select(t1.C("a"), t2.C("b")) require.Equal(t, []string{"`t1`.`a`", "`t2`.`b`"}, s.SelectedColumns()) require.Equal(t, []string{"a", "b"}, s.UnqualifiedColumns()) d := Dialect(dialect.Postgres) t1, t2 = d.Table("t1"), d.Table("t2") s = d.Select(t1.C("a"), t2.C("b")) require.Equal(t, []string{`"t1"."a"`, `"t2"."b"`}, s.SelectedColumns()) require.Equal(t, []string{"a", "b"}, s.UnqualifiedColumns()) } func TestUpdateBuilder_OrderBy(t *testing.T) { u := Dialect(dialect.MySQL).Update("users").Set("id", Expr("`id` + 1")).OrderBy("id") require.NoError(t, u.Err()) query, args := u.Query() require.Nil(t, args) require.Equal(t, "UPDATE `users` SET `id` = `id` + 1 ORDER BY `id`", query) u = Dialect(dialect.Postgres).Update("users").Set("id", Expr("id + 1")).OrderBy("id") require.Error(t, u.Err()) } func TestUpdateBuilder_WithPrefix(t *testing.T) { u := Dialect(dialect.MySQL). Update("users"). Prefix(ExprFunc(func(b *Builder) { b.WriteString("SET @i = ").Arg(1).WriteByte(';') })). Set("id", Expr("(@i:=@i+1)")). OrderBy("id") require.NoError(t, u.Err()) query, args := u.Query() require.Equal(t, []any{1}, args) require.Equal(t, "SET @i = ?; UPDATE `users` SET `id` = (@i:=@i+1) ORDER BY `id`", query) u = Dialect(dialect.MySQL). Update("users"). Prefix(Expr("SET @i = 1;")). Set("id", Expr("(@i:=@i+1)")). OrderBy("id") require.NoError(t, u.Err()) query, args = u.Query() require.Empty(t, args) require.Equal(t, "SET @i = 1; UPDATE `users` SET `id` = (@i:=@i+1) ORDER BY `id`", query) } func TestMultipleFrom(t *testing.T) { query, args := Dialect(dialect.Postgres). Select("items.*", As("ts_rank_cd(search, search_query)", "rank")). From(Table("items")). AppendFrom(Table("to_tsquery('neutrino|(dark & matter)')").As("search_query")). Where(P(func(b *Builder) { b.WriteString("search @@ search_query") })). OrderBy(Desc("rank")). Query() require.Empty(t, args) require.Equal(t, `SELECT items.*, ts_rank_cd(search, search_query) AS "rank" FROM "items", to_tsquery('neutrino|(dark & matter)') AS "search_query" WHERE search @@ search_query ORDER BY "rank" DESC`, query) query, args = Dialect(dialect.Postgres). Select("items.*", As("ts_rank_cd(search, search_query)", "rank")). From(Table("items")). AppendFromExpr(Expr("to_tsquery($1) AS search_query", "neutrino|(dark & matter)")). Where(P(func(b *Builder) { b.WriteString("search @@ search_query") })). Query() require.Equal(t, []any{"neutrino|(dark & matter)"}, args) require.Equal(t, `SELECT items.*, ts_rank_cd(search, search_query) AS "rank" FROM "items", to_tsquery($1) AS search_query WHERE search @@ search_query`, query) query, args = Dialect(dialect.Postgres). Select("items.*", As("ts_rank_cd(search, search_query)", "rank")). From(Table("items")). Where(EQ("value", 10)). AppendFromExpr(ExprFunc(func(b *Builder) { b.WriteString("to_tsquery(").Arg("neutrino|(dark & matter)").WriteString(") AS search_query") })). Where(P(func(b *Builder) { b.WriteString("search @@ search_query") })). 
Query() require.Equal(t, []any{"neutrino|(dark & matter)", 10}, args) require.Equal(t, `SELECT items.*, ts_rank_cd(search, search_query) AS "rank" FROM "items", to_tsquery($1) AS search_query WHERE "value" = $2 AND search @@ search_query`, query) } ent-0.11.3/dialect/sql/driver.go000066400000000000000000000113551431500740500164140ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package sql import ( "context" "database/sql" "database/sql/driver" "fmt" "strings" "entgo.io/ent/dialect" ) // Driver is a dialect.Driver implementation for SQL based databases. type Driver struct { Conn dialect string } // NewDriver creates a new Driver with the given Conn and dialect. func NewDriver(dialect string, c Conn) *Driver { return &Driver{dialect: dialect, Conn: c} } // Open wraps the database/sql.Open method and returns a dialect.Driver that implements the an ent/dialect.Driver interface. func Open(dialect, source string) (*Driver, error) { db, err := sql.Open(dialect, source) if err != nil { return nil, err } return NewDriver(dialect, Conn{db}), nil } // OpenDB wraps the given database/sql.DB method with a Driver. func OpenDB(dialect string, db *sql.DB) *Driver { return NewDriver(dialect, Conn{db}) } // DB returns the underlying *sql.DB instance. func (d Driver) DB() *sql.DB { return d.ExecQuerier.(*sql.DB) } // Dialect implements the dialect.Dialect method. func (d Driver) Dialect() string { // If the underlying driver is wrapped with a telemetry driver. for _, name := range []string{dialect.MySQL, dialect.SQLite, dialect.Postgres} { if strings.HasPrefix(d.dialect, name) { return name } } return d.dialect } // Tx starts and returns a transaction. func (d *Driver) Tx(ctx context.Context) (dialect.Tx, error) { return d.BeginTx(ctx, nil) } // BeginTx starts a transaction with options. func (d *Driver) BeginTx(ctx context.Context, opts *TxOptions) (dialect.Tx, error) { tx, err := d.DB().BeginTx(ctx, opts) if err != nil { return nil, err } return &Tx{ Conn: Conn{tx}, Tx: tx, }, nil } // Close closes the underlying connection. func (d *Driver) Close() error { return d.DB().Close() } // Tx implements dialect.Tx interface. type Tx struct { Conn driver.Tx } // ExecQuerier wraps the standard Exec and Query methods. type ExecQuerier interface { ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) } // Conn implements dialect.ExecQuerier given ExecQuerier. type Conn struct { ExecQuerier } // Exec implements the dialect.Exec method. func (c Conn) Exec(ctx context.Context, query string, args, v any) error { argv, ok := args.([]any) if !ok { return fmt.Errorf("dialect/sql: invalid type %T. expect []any for args", v) } switch v := v.(type) { case nil: if _, err := c.ExecContext(ctx, query, argv...); err != nil { return err } case *sql.Result: res, err := c.ExecContext(ctx, query, argv...) if err != nil { return err } *v = res default: return fmt.Errorf("dialect/sql: invalid type %T. expect *sql.Result", v) } return nil } // Query implements the dialect.Query method. func (c Conn) Query(ctx context.Context, query string, args, v any) error { vr, ok := v.(*Rows) if !ok { return fmt.Errorf("dialect/sql: invalid type %T. expect *sql.Rows", v) } argv, ok := args.([]any) if !ok { return fmt.Errorf("dialect/sql: invalid type %T. 
expect []any for args", args) } rows, err := c.QueryContext(ctx, query, argv...) if err != nil { return err } *vr = Rows{rows} return nil } var _ dialect.Driver = (*Driver)(nil) type ( // Rows wraps the sql.Rows to avoid locks copy. Rows struct{ ColumnScanner } // Result is an alias to sql.Result. Result = sql.Result // NullBool is an alias to sql.NullBool. NullBool = sql.NullBool // NullInt64 is an alias to sql.NullInt64. NullInt64 = sql.NullInt64 // NullString is an alias to sql.NullString. NullString = sql.NullString // NullFloat64 is an alias to sql.NullFloat64. NullFloat64 = sql.NullFloat64 // NullTime represents a time.Time that may be null. NullTime = sql.NullTime // TxOptions holds the transaction options to be used in DB.BeginTx. TxOptions = sql.TxOptions ) // NullScanner represents an sql.Scanner that may be null. // NullScanner implements the sql.Scanner interface so it can // be used as a scan destination, similar to the types above. type NullScanner struct { S sql.Scanner Valid bool // Valid is true if the Scan value is not NULL. } // Scan implements the Scanner interface. func (n *NullScanner) Scan(value any) error { n.Valid = value != nil if n.Valid { return n.S.Scan(value) } return nil } // ColumnScanner is the interface that wraps the standard // sql.Rows methods used for scanning database rows. type ColumnScanner interface { Close() error ColumnTypes() ([]*sql.ColumnType, error) Columns() ([]string, error) Err() error Next() bool NextResultSet() bool Scan(dest ...any) error } ent-0.11.3/dialect/sql/scan.go000066400000000000000000000161331431500740500160440ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package sql import ( "database/sql" "database/sql/driver" "fmt" "reflect" "strings" ) // ScanOne scans one row to the given value. It fails if the rows holds more than 1 row. func ScanOne(rows ColumnScanner, v any) error { columns, err := rows.Columns() if err != nil { return fmt.Errorf("sql/scan: failed getting column names: %w", err) } if n := len(columns); n != 1 { return fmt.Errorf("sql/scan: unexpected number of columns: %d", n) } if !rows.Next() { if err := rows.Err(); err != nil { return err } return sql.ErrNoRows } if err := rows.Scan(v); err != nil { return err } if rows.Next() { return fmt.Errorf("sql/scan: expect exactly one row in result set") } return rows.Err() } // ScanInt64 scans and returns an int64 from the rows. func ScanInt64(rows ColumnScanner) (int64, error) { var n int64 if err := ScanOne(rows, &n); err != nil { return 0, err } return n, nil } // ScanInt scans and returns an int from the rows. func ScanInt(rows ColumnScanner) (int, error) { n, err := ScanInt64(rows) if err != nil { return 0, err } return int(n), nil } // ScanBool scans and returns a boolean from the rows. func ScanBool(rows ColumnScanner) (bool, error) { var b bool if err := ScanOne(rows, &b); err != nil { return false, err } return b, nil } // ScanString scans and returns a string from the rows. func ScanString(rows ColumnScanner) (string, error) { var s string if err := ScanOne(rows, &s); err != nil { return "", err } return s, nil } // ScanValue scans and returns a driver.Value from the rows. 
func ScanValue(rows ColumnScanner) (driver.Value, error) { var v driver.Value if err := ScanOne(rows, &v); err != nil { return nil, err } return v, nil } // ScanSlice scans the given ColumnScanner (basically, sql.Row or sql.Rows) into the given slice. func ScanSlice(rows ColumnScanner, v any) error { columns, err := rows.Columns() if err != nil { return fmt.Errorf("sql/scan: failed getting column names: %w", err) } rv := reflect.ValueOf(v) switch { case rv.Kind() != reflect.Ptr: if t := reflect.TypeOf(v); t != nil { return fmt.Errorf("sql/scan: ScanSlice(non-pointer %s)", t) } fallthrough case rv.IsNil(): return fmt.Errorf("sql/scan: ScanSlice(nil)") } rv = reflect.Indirect(rv) if k := rv.Kind(); k != reflect.Slice { return fmt.Errorf("sql/scan: invalid type %s. expected slice as an argument", k) } scan, err := scanType(rv.Type().Elem(), columns) if err != nil { return err } if n, m := len(columns), len(scan.columns); n > m { return fmt.Errorf("sql/scan: columns do not match (%d > %d)", n, m) } for rows.Next() { values := scan.values() if err := rows.Scan(values...); err != nil { return fmt.Errorf("sql/scan: failed scanning rows: %w", err) } vv := reflect.Append(rv, scan.value(values...)) rv.Set(vv) } return rows.Err() } // rowScan is the configuration for scanning one sql.Row. type rowScan struct { // column types of a row. columns []reflect.Type // value converts the row columns (result) to a reflect.Value. value func(v ...any) reflect.Value } // values returns a []any from the configured column types. func (r *rowScan) values() []any { values := make([]any, len(r.columns)) for i := range r.columns { values[i] = reflect.New(r.columns[i]).Interface() } return values } // scanType returns rowScan for the given reflect.Type. func scanType(typ reflect.Type, columns []string) (*rowScan, error) { switch k := typ.Kind(); { case assignable(typ): return &rowScan{ columns: []reflect.Type{typ}, value: func(v ...any) reflect.Value { return reflect.Indirect(reflect.ValueOf(v[0])) }, }, nil case k == reflect.Ptr: return scanPtr(typ, columns) case k == reflect.Struct: return scanStruct(typ, columns) default: return nil, fmt.Errorf("sql/scan: unsupported type ([]%s)", k) } } var scannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem() // assignable reports if the given type can be assigned directly by `Rows.Scan`. func assignable(typ reflect.Type) bool { switch k := typ.Kind(); { case typ.Implements(scannerType): case k == reflect.Interface && typ.NumMethod() == 0: case k == reflect.String || k >= reflect.Bool && k <= reflect.Float64: case (k == reflect.Slice || k == reflect.Array) && typ.Elem().Kind() == reflect.Uint8: default: return false } return true } // scanStruct returns a configuration for scanning an sql.Row into a struct. func scanStruct(typ reflect.Type, columns []string) (*rowScan, error) { var ( scan = &rowScan{} idxs = make([][]int, 0, typ.NumField()) names = make(map[string][]int, typ.NumField()) ) for i := 0; i < typ.NumField(); i++ { f := typ.Field(i) // Skip unexported fields. if f.PkgPath != "" { continue } // Support 1-level embedding to accept types such as `type T struct {ent.T; V int}`. if typ := f.Type; f.Anonymous && typ.Kind() == reflect.Struct { for j := 0; j < typ.NumField(); j++ { names[columnName(typ.Field(j))] = []int{i, j} } continue } names[columnName(f)] = []int{i} } for _, c := range columns { // Normalize columns if necessary, for example: COUNT(*) => count.
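// For instance (illustrative), both "COUNT(*)" and "count(*)" normalize
// to "count", and therefore match a struct field named Count.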
name := strings.ToLower(strings.Split(c, "(")[0]) idx, ok := names[name] if !ok { return nil, fmt.Errorf("sql/scan: missing struct field for column: %s (%s)", c, name) } idxs = append(idxs, idx) rtype := typ.Field(idx[0]).Type if len(idx) > 1 { rtype = rtype.Field(idx[1]).Type } if !nillable(rtype) { // Create a pointer to the actual reflect // types to accept optional struct fields. rtype = reflect.PtrTo(rtype) } scan.columns = append(scan.columns, rtype) } scan.value = func(vs ...any) reflect.Value { st := reflect.New(typ).Elem() for i, v := range vs { rv := reflect.Indirect(reflect.ValueOf(v)) if rv.IsNil() { continue } idx := idxs[i] rvalue := st.Field(idx[0]) if len(idx) > 1 { rvalue = rvalue.Field(idx[1]) } if !nillable(rvalue.Type()) { rv = reflect.Indirect(rv) } rvalue.Set(rv) } return st } return scan, nil } // columnName returns the column name of a struct-field. func columnName(f reflect.StructField) string { name := strings.ToLower(f.Name) if tag, ok := f.Tag.Lookup("sql"); ok { name = tag } else if tag, ok := f.Tag.Lookup("json"); ok { name = strings.Split(tag, ",")[0] } return name } // nillable reports if the reflect-type can have nil value. func nillable(t reflect.Type) bool { switch t.Kind() { case reflect.Interface, reflect.Slice, reflect.Map, reflect.Ptr, reflect.UnsafePointer: return true } return false } // scanPtr wraps the underlying type with rowScan. func scanPtr(typ reflect.Type, columns []string) (*rowScan, error) { typ = typ.Elem() scan, err := scanType(typ, columns) if err != nil { return nil, err } wrap := scan.value scan.value = func(vs ...any) reflect.Value { v := wrap(vs...) pt := reflect.PtrTo(v.Type()) pv := reflect.New(pt.Elem()) pv.Elem().Set(v) return pv } return scan, nil } ent-0.11.3/dialect/sql/scan_test.go000066400000000000000000000164131431500740500171040ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package sql import ( "database/sql" "database/sql/driver" "testing" "github.com/DATA-DOG/go-sqlmock" "github.com/google/uuid" "github.com/stretchr/testify/require" ) func TestScanSlice(t *testing.T) { mock := sqlmock.NewRows([]string{"name"}). AddRow("foo"). AddRow("bar") var v0 []string require.NoError(t, ScanSlice(toRows(mock), &v0)) require.Equal(t, []string{"foo", "bar"}, v0) mock = sqlmock.NewRows([]string{"age"}). AddRow(1). AddRow(2) var v1 []int require.NoError(t, ScanSlice(toRows(mock), &v1)) require.Equal(t, []int{1, 2}, v1) mock = sqlmock.NewRows([]string{"name", "COUNT(*)"}). AddRow("foo", 1). AddRow("bar", 2) var v2 []struct { Name string Count int } require.NoError(t, ScanSlice(toRows(mock), &v2)) require.Equal(t, "foo", v2[0].Name) require.Equal(t, "bar", v2[1].Name) require.Equal(t, 1, v2[0].Count) require.Equal(t, 2, v2[1].Count) mock = sqlmock.NewRows([]string{"nick_name", "COUNT(*)"}). AddRow("foo", 1). AddRow("bar", 2) var v3 []struct { Count int Name string `json:"nick_name"` } require.NoError(t, ScanSlice(toRows(mock), &v3)) require.Equal(t, "foo", v3[0].Name) require.Equal(t, "bar", v3[1].Name) require.Equal(t, 1, v3[0].Count) require.Equal(t, 2, v3[1].Count) mock = sqlmock.NewRows([]string{"nick_name", "COUNT(*)"}). AddRow("foo", 1). 
AddRow("bar", 2) var v4 []*struct { Count int Name string `json:"nick_name"` Ignored string `json:"string"` } require.NoError(t, ScanSlice(toRows(mock), &v4)) require.Equal(t, "foo", v4[0].Name) require.Equal(t, "bar", v4[1].Name) require.Equal(t, 1, v4[0].Count) require.Equal(t, 2, v4[1].Count) mock = sqlmock.NewRows([]string{"nick_name", "COUNT(*)"}). AddRow("foo", 1). AddRow("bar", 2) var v5 []*struct { Count int Name string `json:"name" sql:"nick_name"` } require.NoError(t, ScanSlice(toRows(mock), &v5)) require.Equal(t, "foo", v5[0].Name) require.Equal(t, "bar", v5[1].Name) require.Equal(t, 1, v5[0].Count) require.Equal(t, 2, v5[1].Count) mock = sqlmock.NewRows([]string{"age", "name"}). AddRow(1, nil). AddRow(nil, "a8m") var v6 []struct { Age NullInt64 Name NullString } require.NoError(t, ScanSlice(toRows(mock), &v6)) require.EqualValues(t, 1, v6[0].Age.Int64) require.False(t, v6[0].Name.Valid) require.False(t, v6[1].Age.Valid) require.Equal(t, "a8m", v6[1].Name.String) u1, u2 := uuid.New().String(), uuid.New().String() mock = sqlmock.NewRows([]string{"ids"}). AddRow([]byte(u1)). AddRow([]byte(u2)) var ids []uuid.UUID require.NoError(t, ScanSlice(toRows(mock), &ids)) require.Equal(t, u1, ids[0].String()) require.Equal(t, u2, ids[1].String()) mock = sqlmock.NewRows([]string{"pids"}). AddRow([]byte(u1)). AddRow([]byte(u2)) var pids []*uuid.UUID require.NoError(t, ScanSlice(toRows(mock), &pids)) require.Equal(t, u1, pids[0].String()) require.Equal(t, u2, pids[1].String()) mock = sqlmock.NewRows([]string{"id", "first", "last"}). AddRow(10, "Ariel", "Mashraki") err := ScanSlice(toRows(mock), nil) require.EqualError(t, err, "sql/scan: ScanSlice(nil)") type P struct { _ int ID int First string Last string } var p []P err = ScanSlice(toRows(mock), p) require.EqualError(t, err, "sql/scan: ScanSlice(non-pointer []sql.P)") require.NoError(t, ScanSlice(toRows(mock), &p)) require.Equal(t, 10, p[0].ID) require.Equal(t, "Ariel", p[0].First) require.Equal(t, "Mashraki", p[0].Last) var pp []struct{ _, id int } mock = sqlmock.NewRows([]string{"id"}). AddRow(10) err = ScanSlice(toRows(mock), &pp) require.EqualError(t, err, "sql/scan: missing struct field for column: id (id)") require.Empty(t, pp) } func TestScanNestedStruct(t *testing.T) { mock := sqlmock.NewRows([]string{"name", "age"}). AddRow("foo", 1). AddRow("bar", 2). AddRow("baz", nil) type T struct{ Name string } var v []struct { T Age int } require.NoError(t, ScanSlice(toRows(mock), &v)) require.Equal(t, "foo", v[0].Name) require.Equal(t, 1, v[0].Age) require.Equal(t, "bar", v[1].Name) require.Equal(t, 2, v[1].Age) require.Equal(t, "baz", v[2].Name) require.Equal(t, 0, v[2].Age) mock = sqlmock.NewRows([]string{"name", "age"}). AddRow("foo", 1). AddRow("bar", nil) type T1 struct{ Name **string } var v1 []struct { T1 Age *int } require.NoError(t, ScanSlice(toRows(mock), &v1)) require.Equal(t, "foo", **v1[0].Name) require.Equal(t, "bar", **v1[1].Name) require.Equal(t, 1, *v1[0].Age) require.Nil(t, v1[1].Age) } func TestScanSlicePtr(t *testing.T) { mock := sqlmock.NewRows([]string{"name"}). AddRow("foo"). AddRow("bar") var v0 []*string require.NoError(t, ScanSlice(toRows(mock), &v0)) require.Equal(t, "foo", *v0[0]) require.Equal(t, "bar", *v0[1]) mock = sqlmock.NewRows([]string{"age"}). AddRow(1). AddRow(2) var v1 []**int require.NoError(t, ScanSlice(toRows(mock), &v1)) require.Equal(t, 1, **v1[0]) require.Equal(t, 2, **v1[1]) mock = sqlmock.NewRows([]string{"age", "name"}). AddRow(1, "a8m"). 
AddRow(2, "nati") var v2 []*struct { Age *int Name **string } require.NoError(t, ScanSlice(toRows(mock), &v2)) require.Equal(t, 1, *v2[0].Age) require.Equal(t, "a8m", **v2[0].Name) require.Equal(t, 2, *v2[1].Age) require.Equal(t, "nati", **v2[1].Name) } func TestScanInt64(t *testing.T) { mock := sqlmock.NewRows([]string{"age"}). AddRow("10"). AddRow("20") n, err := ScanInt64(toRows(mock)) require.Error(t, err) require.Zero(t, n) mock = sqlmock.NewRows([]string{"age", "count"}). AddRow("10", "1") n, err = ScanInt64(toRows(mock)) require.Error(t, err) require.Zero(t, n) mock = sqlmock.NewRows([]string{"count"}). AddRow(10) n, err = ScanInt64(toRows(mock)) require.NoError(t, err) require.EqualValues(t, 10, n) mock = sqlmock.NewRows([]string{"count"}). AddRow("10") n, err = ScanInt64(toRows(mock)) require.NoError(t, err) require.EqualValues(t, 10, n) } func TestScanValue(t *testing.T) { mock := sqlmock.NewRows([]string{"count"}). AddRow(10) n, err := ScanValue(toRows(mock)) require.NoError(t, err) require.EqualValues(t, 10, n) } func TestScanOne(t *testing.T) { mock := sqlmock.NewRows([]string{"name"}). AddRow("10"). AddRow("20") err := ScanOne(toRows(mock), new(string)) require.Error(t, err, "multiple lines") mock = sqlmock.NewRows([]string{"name"}). AddRow("10") err = ScanOne(toRows(mock), "") require.Error(t, err, "not a pointer") mock = sqlmock.NewRows([]string{"name"}). AddRow("10") var s string err = ScanOne(toRows(mock), &s) require.NoError(t, err) require.Equal(t, "10", s) } func TestInterface(t *testing.T) { mock := sqlmock.NewRows([]string{"age"}). AddRow("10"). AddRow("20") var values []driver.Value err := ScanSlice(toRows(mock), &values) require.NoError(t, err) require.Equal(t, []driver.Value{"10", "20"}, values) mock = sqlmock.NewRows([]string{"age"}). AddRow(10). AddRow(20) values = values[:0:0] err = ScanSlice(toRows(mock), &values) require.NoError(t, err) require.Equal(t, []driver.Value{int64(10), int64(20)}, values) } func toRows(mrows *sqlmock.Rows) *sql.Rows { db, mock, _ := sqlmock.New() mock.ExpectQuery("").WillReturnRows(mrows) rows, _ := db.Query("") return rows } ent-0.11.3/dialect/sql/schema/000077500000000000000000000000001431500740500160255ustar00rootroot00000000000000ent-0.11.3/dialect/sql/schema/atlas.go000066400000000000000000000755161431500740500174760ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package schema import ( "context" "crypto/md5" "database/sql" "errors" "fmt" "net/url" "sort" "strings" "ariga.io/atlas/sql/migrate" "ariga.io/atlas/sql/schema" "ariga.io/atlas/sql/sqlclient" "ariga.io/atlas/sql/sqltool" "entgo.io/ent/dialect" entsql "entgo.io/ent/dialect/sql" "entgo.io/ent/schema/field" ) // Atlas atlas migration engine. 
type Atlas struct { atDriver migrate.Driver sqlDialect sqlDialect legacy bool // if the legacy migration engine instead of Atlas should be used withFixture bool // deprecated: with fks rename fixture sum bool // deprecated: sum file generation will be required universalID bool // global unique ids dropColumns bool // drop deleted columns dropIndexes bool // drop deleted indexes withForeignKeys bool // with foreign keys mode Mode hooks []Hook // hooks to apply before creation diffHooks []DiffHook // diff hooks to run when diffing current and desired applyHook []ApplyHook // apply hooks to run when applying the plan skip ChangeKind // what changes to skip and not apply dir migrate.Dir // the migration directory to read from fmt migrate.Formatter // how to format the plan into migration files driver dialect.Driver // driver passed in when not using an atlas URL url *url.URL // url of database connection dialect string // Ent dialect to use when generating migration files types []string // pre-existing pk range allocation for global unique id } // Diff compares the state read from a database connection or migration directory with the state defined by the Ent // schema. Changes will be written to new migration files. func Diff(ctx context.Context, u, name string, tables []*Table, opts ...MigrateOption) (err error) { m, err := NewMigrateURL(u, opts...) if err != nil { return err } return m.NamedDiff(ctx, name, tables...) } // NewMigrate creates a new Atlas from the given dialect.Driver. func NewMigrate(drv dialect.Driver, opts ...MigrateOption) (*Atlas, error) { a := &Atlas{driver: drv, withForeignKeys: true, mode: ModeInspect, sum: true} for _, opt := range opts { opt(a) } a.dialect = a.driver.Dialect() if err := a.init(); err != nil { return nil, err } return a, nil } // NewMigrateURL creates a new Atlas from the given url. func NewMigrateURL(u string, opts ...MigrateOption) (*Atlas, error) { parsed, err := url.Parse(u) if err != nil { return nil, err } a := &Atlas{url: parsed, withForeignKeys: true, mode: ModeInspect, sum: true} for _, opt := range opts { opt(a) } if a.dialect == "" { a.dialect = parsed.Scheme } if err := a.init(); err != nil { return nil, err } return a, nil } // Create creates all schema resources in the database. It works in an "append-only" // mode, which means it only creates tables, appends columns to tables, or modifies column types. // // A column can be modified by turning it from NOT NULL into NULL, or by a type conversion that does // not alter data. For example, changing varchar(255) to varchar(120) is invalid, but // changing varchar(120) to varchar(255) is valid. For more info, see the convert function below. func (a *Atlas) Create(ctx context.Context, tables ...*Table) (err error) { a.setupTables(tables) var creator Creator = CreateFunc(a.create) if a.legacy { m, err := a.legacyMigrate() if err != nil { return err } creator = CreateFunc(m.create) } for i := len(a.hooks) - 1; i >= 0; i-- { creator = a.hooks[i](creator) } return creator.Create(ctx, tables...) } // Diff compares the state read from the connected database with the state defined by Ent. // Changes will be written to migration files by the configured Planner. func (a *Atlas) Diff(ctx context.Context, tables ...*Table) error { return a.NamedDiff(ctx, "changes", tables...) } // NamedDiff compares the state read from the connected database with the state defined by Ent. // Changes will be written to migration files by the configured Planner.
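//
// A sketch of generating a named migration file (the directory path, URL, and
// migration name are illustrative assumptions):
//
//	dir, _ := migrate.NewLocalDir("ent/migrate/migrations")
//	m, _ := NewMigrateURL("mysql://user:pass@localhost:3306/test", WithDir(dir))
//	_ = m.NamedDiff(ctx, "add_users", tables...)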
func (a *Atlas) NamedDiff(ctx context.Context, name string, tables ...*Table) error { if a.dir == nil { return errors.New("no migration directory given") } opts := []migrate.PlannerOption{migrate.WithFormatter(a.fmt)} if a.sum { // Validate the migration directory before proceeding. if err := migrate.Validate(a.dir); err != nil { return fmt.Errorf("validating migration directory: %w", err) } } else { opts = append(opts, migrate.DisableChecksum()) } a.setupTables(tables) // Set up connections. if a.driver != nil { var err error a.sqlDialect, err = a.entDialect(a.driver) if err != nil { return err } a.atDriver, err = a.sqlDialect.atOpen(a.sqlDialect) if err != nil { return err } } else { c, err := sqlclient.OpenURL(ctx, a.url) if err != nil { return err } defer c.Close() a.sqlDialect, err = a.entDialect(entsql.OpenDB(a.dialect, c.DB)) if err != nil { return err } a.atDriver = c.Driver } defer func() { a.sqlDialect = nil a.atDriver = nil }() if err := a.sqlDialect.init(ctx); err != nil { return err } if a.universalID { tables = append(tables, NewTable(TypeTable). AddPrimary(&Column{Name: "id", Type: field.TypeUint, Increment: true}). AddColumn(&Column{Name: "type", Type: field.TypeString, Unique: true}), ) } switch a.mode { case ModeInspect: // Do nothing here, simply inspect later on. case ModeReplay: // We consider a database clean if there are no tables in the connected schema. s, err := a.atDriver.InspectSchema(ctx, "", nil) if err != nil { return err } if len(s.Tables) > 0 { return migrate.NotCleanError{Reason: fmt.Sprintf("found table %q", s.Tables[0].Name)} } // Clean up once done. defer func() { // We clean a database by dropping all tables inside the connected schema. s, err = a.atDriver.InspectSchema(ctx, "", nil) if err != nil { return } tbls := make([]schema.Change, len(s.Tables)) for i, t := range s.Tables { tbls[i] = &schema.DropTable{T: t} } if err2 := a.atDriver.ApplyChanges(ctx, tbls); err2 != nil { if err != nil { err = fmt.Errorf("%v: %w", err2, err) return } err = err2 return } }() // Replay the migration directory on the database. ex, err := migrate.NewExecutor(a.atDriver, a.dir, &migrate.NopRevisionReadWriter{}) if err != nil { return err } if err := ex.ExecuteN(ctx, 0); err != nil && !errors.Is(err, migrate.ErrNoPendingFiles) { return err } default: return fmt.Errorf("unknown migration mode: %q", a.mode) } plan, err := a.plan(ctx, a.sqlDialect, name, tables) if err != nil { return err } // Skip if the plan has no changes. if len(plan.Changes) == 0 { return nil } return migrate.NewPlanner(nil, a.dir, opts...).WritePlan(plan) } // VerifyTableRange ensures that the defined autoincrement starting value is set for each table as defined by the // TypeTable. This is necessary for MySQL versions < 8.0. In those versions the defined starting value for AUTOINCREMENT // columns was stored in memory, and when a server restart happens and there are no rows yet in a table, the defined // starting value is lost, which will result in incorrect behavior when working with global unique ids. Calling this // method on service start ensures the information is correct and is set again if it isn't. For MySQL versions >= 8.0, // calling this method is only required once after the upgrade.
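//
// A sketch of calling it on service start (illustrative; assumes the schema
// was created with WithGlobalUniqueID(true)):
//
//	m, _ := NewMigrate(drv, WithGlobalUniqueID(true))
//	if err := m.VerifyTableRange(ctx, tables); err != nil {
//		log.Fatal(err)
//	}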
func (a *Atlas) VerifyTableRange(ctx context.Context, tables []*Table) error { if a.driver != nil { var err error a.sqlDialect, err = a.entDialect(a.driver) if err != nil { return err } } else { c, err := sqlclient.OpenURL(ctx, a.url) if err != nil { return err } defer c.Close() a.sqlDialect, err = a.entDialect(entsql.OpenDB(a.dialect, c.DB)) if err != nil { return err } } defer func() { a.sqlDialect = nil }() vr, ok := a.sqlDialect.(verifyRanger) if !ok { return nil } types, err := a.loadTypes(ctx, a.sqlDialect) if err != nil { // In most cases this means the table does not exist, which in turn // indicates the user does not use global unique ids. if errors.Is(err, errTypeTableNotFound) { return nil } return err } for _, t := range tables { id := indexOf(types, t.Name) if id == -1 { continue } if err := vr.verifyRange(ctx, a.sqlDialect, t, int64(id<<32)); err != nil { return err } } return nil } type ( // Differ is the interface that wraps the Diff method. Differ interface { // Diff returns a list of changes that construct a migration plan. Diff(current, desired *schema.Schema) ([]schema.Change, error) } // The DiffFunc type is an adapter to allow the use of an ordinary function as a Differ. // If f is a function with the appropriate signature, DiffFunc(f) is a Differ that calls f. DiffFunc func(current, desired *schema.Schema) ([]schema.Change, error) // DiffHook defines the "diff middleware". A function that gets a Differ and returns a Differ. DiffHook func(Differ) Differ ) // Diff calls f(current, desired). func (f DiffFunc) Diff(current, desired *schema.Schema) ([]schema.Change, error) { return f(current, desired) } // WithDiffHook adds a list of DiffHook to the schema migration. // // schema.WithDiffHook(func(next schema.Differ) schema.Differ { // return schema.DiffFunc(func(current, desired *atlas.Schema) ([]atlas.Change, error) { // // Code before standard diff. // changes, err := next.Diff(current, desired) // if err != nil { // return nil, err // } // // After diff, you can filter // // changes or return new ones. // return changes, nil // }) // }) func WithDiffHook(hooks ...DiffHook) MigrateOption { return func(a *Atlas) { a.diffHooks = append(a.diffHooks, hooks...) } } // WithSkipChanges allows skipping/filtering a list of changes // returned by the Differ before executing migration planning. // // SkipChanges(schema.DropTable|schema.DropColumn) func WithSkipChanges(skip ChangeKind) MigrateOption { return func(a *Atlas) { a.skip = skip } } // A ChangeKind denotes the kind of schema change. type ChangeKind uint // List of change types. const ( NoChange ChangeKind = 0 AddSchema ChangeKind = 1 << (iota - 1) ModifySchema DropSchema AddTable ModifyTable DropTable AddColumn ModifyColumn DropColumn AddIndex ModifyIndex DropIndex AddForeignKey ModifyForeignKey DropForeignKey AddCheck ModifyCheck DropCheck ) // Is reports whether k matches the given change kind. func (k ChangeKind) Is(c ChangeKind) bool { return k == c || k&c != 0 } // filterChanges is a DiffHook for filtering changes before planning.
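// For example, with the default skip mask set in init (DropIndex|DropColumn),
// planned index and column drops are filtered out of the diff unless
// WithDropIndex/WithDropColumn were used.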
func filterChanges(skip ChangeKind) DiffHook { return func(next Differ) Differ { return DiffFunc(func(current, desired *schema.Schema) ([]schema.Change, error) { var f func([]schema.Change) []schema.Change f = func(changes []schema.Change) (keep []schema.Change) { var k ChangeKind for _, c := range changes { switch c := c.(type) { case *schema.AddSchema: k = AddSchema case *schema.ModifySchema: k = ModifySchema if !skip.Is(k) { c.Changes = f(c.Changes) } case *schema.DropSchema: k = DropSchema case *schema.AddTable: k = AddTable case *schema.ModifyTable: k = ModifyTable if !skip.Is(k) { c.Changes = f(c.Changes) } case *schema.DropTable: k = DropTable case *schema.AddColumn: k = AddColumn case *schema.ModifyColumn: k = ModifyColumn case *schema.DropColumn: k = DropColumn case *schema.AddIndex: k = AddIndex case *schema.ModifyIndex: k = ModifyIndex case *schema.DropIndex: k = DropIndex case *schema.AddForeignKey: k = AddForeignKey case *schema.ModifyForeignKey: k = ModifyForeignKey case *schema.DropForeignKey: k = DropForeignKey case *schema.AddCheck: k = AddCheck case *schema.ModifyCheck: k = ModifyCheck case *schema.DropCheck: k = DropCheck } if !skip.Is(k) { keep = append(keep, c) } } return } changes, err := next.Diff(current, desired) if err != nil { return nil, err } return f(changes), nil }) } } func withoutForeignKeys(next Differ) Differ { return DiffFunc(func(current, desired *schema.Schema) ([]schema.Change, error) { changes, err := next.Diff(current, desired) if err != nil { return nil, err } for _, c := range changes { switch c := c.(type) { case *schema.AddTable: c.T.ForeignKeys = nil case *schema.ModifyTable: c.T.ForeignKeys = nil filtered := make([]schema.Change, 0, len(c.Changes)) for _, change := range c.Changes { switch change.(type) { case *schema.AddForeignKey, *schema.DropForeignKey, *schema.ModifyForeignKey: continue default: filtered = append(filtered, change) } } c.Changes = filtered } } return changes, nil }) } type ( // Applier is the interface that wraps the Apply method. Applier interface { // Apply applies the given migrate.Plan on the database. Apply(context.Context, dialect.ExecQuerier, *migrate.Plan) error } // The ApplyFunc type is an adapter to allow the use of an ordinary function as an Applier. // If f is a function with the appropriate signature, ApplyFunc(f) is an Applier that calls f. ApplyFunc func(context.Context, dialect.ExecQuerier, *migrate.Plan) error // ApplyHook defines the "migration applying middleware". A function that gets an Applier and returns an Applier. ApplyHook func(Applier) Applier ) // Apply calls f(ctx, conn, plan). func (f ApplyFunc) Apply(ctx context.Context, conn dialect.ExecQuerier, plan *migrate.Plan) error { return f(ctx, conn, plan) } // WithApplyHook adds a list of ApplyHook to the schema migration. // // schema.WithApplyHook(func(next schema.Applier) schema.Applier { // return schema.ApplyFunc(func(ctx context.Context, conn dialect.ExecQuerier, plan *migrate.Plan) error { // // Example to hook into the apply process, or implement // // a custom applier. // // // // for _, c := range plan.Changes { // // fmt.Printf("%s: %s", c.Comment, c.Cmd) // // } // // // return next.Apply(ctx, conn, plan) // }) // }) func WithApplyHook(hooks ...ApplyHook) MigrateOption { return func(a *Atlas) { a.applyHook = append(a.applyHook, hooks...) } } // WithAtlas is an opt-out option for v0.11 indicating the migration // should be executed using the deprecated legacy engine.
// Note, in future versions, this option is going to be removed // and the Atlas (https://atlasgo.io) based migration engine should be used. // // Deprecated: The legacy engine will be removed. func WithAtlas(b bool) MigrateOption { return func(a *Atlas) { a.legacy = !b } } // WithDir sets the atlas migration directory to use to store migration files. func WithDir(dir migrate.Dir) MigrateOption { return func(a *Atlas) { a.dir = dir } } // WithFormatter sets atlas formatter to use to write changes to migration files. func WithFormatter(fmt migrate.Formatter) MigrateOption { return func(a *Atlas) { a.fmt = fmt } } // WithDialect configures the Ent dialect to use when migrating for an Atlas supported dialect flavor. // As an example, Ent can work with TiDB in MySQL dialect and Atlas can handle TiDB migrations. func WithDialect(d string) MigrateOption { return func(a *Atlas) { a.dialect = d } } // WithSumFile instructs atlas to generate a migration directory integrity sum file. // // Deprecated: generating the sum file is now opt-out. This method will be removed in future versions. func WithSumFile() MigrateOption { return func(a *Atlas) {} } // DisableChecksum instructs atlas to skip migration directory integrity sum file generation. // // Deprecated: generating the sum file will no longer be optional in future versions. func DisableChecksum() MigrateOption { return func(a *Atlas) { a.sum = false } } // WithMigrationMode instructs atlas how to compute the current state of the schema. This can be done by either // replaying (ModeReplay) the migration directory on the connected database, or by inspecting (ModeInspect) the // connection. Currently, ModeReplay is opt-in, and ModeInspect is the default. In future versions, ModeReplay will // become the default behavior. This option has no effect when using online migrations. func WithMigrationMode(mode Mode) MigrateOption { return func(a *Atlas) { a.mode = mode } } // Mode to compute the current state. type Mode uint const ( // ModeReplay computes the current state by replaying the migration directory on the connected database. ModeReplay = iota // ModeInspect computes the current state by inspecting the connected database. ModeInspect ) // StateReader returns an atlas migrate.StateReader returning the state as described by the Ent table slice. func (a *Atlas) StateReader(tables ...*Table) migrate.StateReaderFunc { return func(context.Context) (*schema.Realm, error) { ts, err := a.tables(tables) if err != nil { return nil, err } return &schema.Realm{Schemas: []*schema.Schema{{Tables: ts}}}, nil } } // atBuilder must be implemented by the different drivers in // order to convert a dialect/sql/schema to atlas/sql/schema. type atBuilder interface { atOpen(dialect.ExecQuerier) (migrate.Driver, error) atTable(*Table, *schema.Table) atTypeC(*Column, *schema.Column) error atUniqueC(*Table, *Column, *schema.Table, *schema.Column) atIncrementC(*schema.Table, *schema.Column) atIncrementT(*schema.Table, int64) atIndex(*Index, *schema.Table, *schema.Index) error atTypeRangeSQL(t ...string) string } // init initializes the configuration object based on the options passed in. 
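// By default, DropColumn and DropIndex changes are skipped unless they are
// explicitly enabled via WithDropColumn/WithDropIndex, or a custom mask is
// set with WithSkipChanges.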
func (a *Atlas) init() error { skip := DropIndex | DropColumn if a.skip != NoChange { skip = a.skip } if a.dropIndexes { skip &= ^DropIndex } if a.dropColumns { skip &= ^DropColumn } if skip != NoChange { a.diffHooks = append(a.diffHooks, filterChanges(skip)) } if !a.withForeignKeys { a.diffHooks = append(a.diffHooks, withoutForeignKeys) } if a.dir != nil && a.fmt == nil { switch a.dir.(type) { case *sqltool.GooseDir: a.fmt = sqltool.GooseFormatter case *sqltool.DBMateDir: a.fmt = sqltool.DBMateFormatter case *sqltool.FlywayDir: a.fmt = sqltool.FlywayFormatter case *sqltool.LiquibaseDir: a.fmt = sqltool.LiquibaseFormatter default: // migrate.LocalDir, sqltool.GolangMigrateDir and custom ones a.fmt = sqltool.GolangMigrateFormatter } } if a.mode == ModeReplay { // ModeReplay requires a migration directory. if a.dir == nil { return errors.New("sql/schema: WithMigrationMode(ModeReplay) requires versioned migrations: WithDir()") } // ModeReplay requires sum file generation. if !a.sum { return errors.New("sql/schema: WithMigrationMode(ModeReplay) requires migration directory integrity file") } } return nil } // create runs the Atlas-based online migration. func (a *Atlas) create(ctx context.Context, tables ...*Table) (err error) { if a.universalID { tables = append(tables, NewTable(TypeTable). AddPrimary(&Column{Name: "id", Type: field.TypeUint, Increment: true}). AddColumn(&Column{Name: "type", Type: field.TypeString, Unique: true}), ) } if a.driver != nil { a.sqlDialect, err = a.entDialect(a.driver) if err != nil { return err } } else { c, err := sqlclient.OpenURL(ctx, a.url) if err != nil { return err } defer c.Close() a.sqlDialect, err = a.entDialect(entsql.OpenDB(a.dialect, c.DB)) if err != nil { return err } } defer func() { a.sqlDialect = nil }() if err := a.sqlDialect.init(ctx); err != nil { return err } // Open a transaction for backwards compatibility, // even if the migration is not transactional. tx, err := a.sqlDialect.Tx(ctx) if err != nil { return err } a.atDriver, err = a.sqlDialect.atOpen(tx) if err != nil { return err } defer func() { a.atDriver = nil }() if err := func() error { plan, err := a.plan(ctx, tx, "changes", tables) if err != nil { return err } // Apply plan (changes). var applier Applier = ApplyFunc(func(ctx context.Context, tx dialect.ExecQuerier, plan *migrate.Plan) error { for _, c := range plan.Changes { if err := tx.Exec(ctx, c.Cmd, c.Args, nil); err != nil { if c.Comment != "" { err = fmt.Errorf("%s: %w", c.Comment, err) } return err } } return nil }) for i := len(a.applyHook) - 1; i >= 0; i-- { applier = a.applyHook[i](applier) } return applier.Apply(ctx, tx, plan) }(); err != nil { err = fmt.Errorf("sql/schema: %w", err) if rerr := tx.Rollback(); rerr != nil { err = fmt.Errorf("%w: %v", err, rerr) } return err } return tx.Commit() } // plan inspects the connected database for its current state, computes the desired state from the Ent schema, // and diffs the two to create a migration plan.
func (a *Atlas) plan(ctx context.Context, conn dialect.ExecQuerier, name string, tables []*Table) (*migrate.Plan, error) { current, err := a.atDriver.InspectSchema(ctx, "", &schema.InspectOptions{ Tables: func() (t []string) { for i := range tables { t = append(t, tables[i].Name) } return t }(), }) if err != nil { return nil, err } var types []string if a.universalID { types, err = a.loadTypes(ctx, conn) if err != nil && !errors.Is(err, errTypeTableNotFound) { return nil, err } a.types = types } desired, err := a.StateReader(tables...).ReadState(ctx) if err != nil { return nil, err } // Diff changes. changes, err := (&diffDriver{a.atDriver, a.diffHooks}).SchemaDiff(current, &schema.Schema{ Name: current.Name, Attrs: current.Attrs, Tables: desired.Schemas[0].Tables, }) if err != nil { return nil, err } // Plan changes. plan, err := a.atDriver.PlanChanges(ctx, name, changes) if err != nil { return nil, err } // Insert new types. newTypes := a.types[len(types):] if len(newTypes) > 0 { plan.Changes = append(plan.Changes, &migrate.Change{ Cmd: a.sqlDialect.atTypeRangeSQL(newTypes...), Comment: fmt.Sprintf("add pk ranges for %s tables", strings.Join(newTypes, ",")), }) } return plan, nil } var errTypeTableNotFound = errors.New("ent_type table not found") // loadTypes loads the currently saved range allocations from the TypeTable. func (a *Atlas) loadTypes(ctx context.Context, conn dialect.ExecQuerier) ([]string, error) { // Fetch pre-existing type allocations. exists, err := a.sqlDialect.tableExist(ctx, conn, TypeTable) if err != nil { return nil, err } if !exists { return nil, errTypeTableNotFound } rows := &entsql.Rows{} query, args := entsql.Dialect(a.dialect). Select("type").From(entsql.Table(TypeTable)).OrderBy(entsql.Asc("id")).Query() if err := conn.Query(ctx, query, args, rows); err != nil { return nil, fmt.Errorf("query types table: %w", err) } defer rows.Close() var types []string if err := entsql.ScanSlice(rows, &types); err != nil { return nil, err } return types, nil } type db struct{ dialect.ExecQuerier } func (d *db) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) { rows := &entsql.Rows{} if err := d.ExecQuerier.Query(ctx, query, args, rows); err != nil { return nil, err } return rows.ColumnScanner.(*sql.Rows), nil } func (d *db) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) { var r sql.Result if err := d.ExecQuerier.Exec(ctx, query, args, &r); err != nil { return nil, err } return r, nil } // tables converts an Ent table slice to an atlas table slice func (a *Atlas) tables(tables []*Table) ([]*schema.Table, error) { ts := make([]*schema.Table, len(tables)) for i, et := range tables { at := schema.NewTable(et.Name) a.sqlDialect.atTable(et, at) if a.universalID && et.Name != TypeTable && len(et.PrimaryKey) == 1 { r, err := a.pkRange(et) if err != nil { return nil, err } a.sqlDialect.atIncrementT(at, r) } if err := a.aColumns(et, at); err != nil { return nil, err } if err := a.aIndexes(et, at); err != nil { return nil, err } ts[i] = at } for i, t1 := range tables { t2 := ts[i] for _, fk1 := range t1.ForeignKeys { fk2 := schema.NewForeignKey(fk1.Symbol). SetTable(t2). SetOnUpdate(schema.ReferenceOption(fk1.OnUpdate)). 
SetOnDelete(schema.ReferenceOption(fk1.OnDelete)) for _, c1 := range fk1.Columns { c2, ok := t2.Column(c1.Name) if !ok { return nil, fmt.Errorf("unexpected fk %q column: %q", fk1.Symbol, c1.Name) } fk2.AddColumns(c2) } var refT *schema.Table for _, t2 := range ts { if t2.Name == fk1.RefTable.Name { refT = t2 break } } if refT == nil { return nil, fmt.Errorf("unexpected fk %q ref-table: %q", fk1.Symbol, fk1.RefTable.Name) } fk2.SetRefTable(refT) for _, c1 := range fk1.RefColumns { c2, ok := refT.Column(c1.Name) if !ok { return nil, fmt.Errorf("unexpected fk %q ref-column: %q", fk1.Symbol, c1.Name) } fk2.AddRefColumns(c2) } t2.AddForeignKeys(fk2) } } return ts, nil } func (a *Atlas) aColumns(et *Table, at *schema.Table) error { for _, c1 := range et.Columns { c2 := schema.NewColumn(c1.Name). SetNull(c1.Nullable) if c1.Collation != "" { c2.SetCollation(c1.Collation) } if err := a.sqlDialect.atTypeC(c1, c2); err != nil { return err } if c1.Default != nil && c1.supportDefault() { // Has default and the database supports adding this default. x := fmt.Sprint(c1.Default) if v, ok := c1.Default.(string); ok && c1.Type != field.TypeUUID && c1.Type != field.TypeTime { // Escape single quote by replacing each with 2. x = fmt.Sprintf("'%s'", strings.ReplaceAll(v, "'", "''")) } c2.SetDefault(&schema.RawExpr{X: x}) } if c1.Unique && (len(et.PrimaryKey) != 1 || et.PrimaryKey[0] != c1) { a.sqlDialect.atUniqueC(et, c1, at, c2) } if c1.Increment { a.sqlDialect.atIncrementC(at, c2) } at.AddColumns(c2) } return nil } func (a *Atlas) aIndexes(et *Table, at *schema.Table) error { // Primary-key index. pk := make([]*schema.Column, 0, len(et.PrimaryKey)) for _, c1 := range et.PrimaryKey { c2, ok := at.Column(c1.Name) if !ok { return fmt.Errorf("unexpected primary-key column: %q", c1.Name) } pk = append(pk, c2) } at.SetPrimaryKey(schema.NewPrimaryKey(pk...)) // Rest of indexes. for _, idx1 := range et.Indexes { idx2 := schema.NewIndex(idx1.Name). SetUnique(idx1.Unique) if err := a.sqlDialect.atIndex(idx1, at, idx2); err != nil { return err } desc := descIndexes(idx1) for _, p := range idx2.Parts { p.Desc = desc[p.C.Name] } at.AddIndexes(idx2) } return nil } // setupTables ensures the table is configured properly, like table columns // are linked to their indexes, and PKs columns are defined. func (a *Atlas) setupTables(tables []*Table) { for _, t := range tables { if t.columns == nil { t.columns = make(map[string]*Column, len(t.Columns)) } for _, c := range t.Columns { t.columns[c.Name] = c } for _, idx := range t.Indexes { idx.Name = a.symbol(idx.Name) for _, c := range idx.Columns { c.indexes.append(idx) } } for _, pk := range t.PrimaryKey { c := t.columns[pk.Name] c.Key = PrimaryKey pk.Key = PrimaryKey } for _, fk := range t.ForeignKeys { fk.Symbol = a.symbol(fk.Symbol) for i := range fk.Columns { fk.Columns[i].foreign = fk } } } } // symbol makes sure the symbol length is not longer than the maxlength in the dialect. func (a *Atlas) symbol(name string) string { size := 64 if a.dialect == dialect.Postgres { size = 63 } if len(name) <= size { return name } return fmt.Sprintf("%s_%x", name[:size-33], md5.Sum([]byte(name))) } // entDialect returns the Ent dialect as configured by the dialect option. 
func (a *Atlas) entDialect(drv dialect.Driver) (sqlDialect, error) { switch a.dialect { case dialect.MySQL: return &MySQL{Driver: drv}, nil case dialect.SQLite: return &SQLite{Driver: drv, WithForeignKeys: a.withForeignKeys}, nil case dialect.Postgres: return &Postgres{Driver: drv}, nil default: return nil, fmt.Errorf("sql/schema: unsupported dialect %q", a.dialect) } } func (a *Atlas) pkRange(et *Table) (int64, error) { idx := indexOf(a.types, et.Name) // If the table was re-created, reuse its range from // the past. Otherwise, allocate a new id-range. if idx == -1 { if len(a.types) > MaxTypes { return 0, fmt.Errorf("max number of types exceeded: %d", MaxTypes) } idx = len(a.types) a.types = append(a.types, et.Name) } return int64(idx << 32), nil } func setAtChecks(et *Table, at *schema.Table) { if check := et.Annotation.Check; check != "" { at.AddChecks(&schema.Check{ Expr: check, }) } if checks := et.Annotation.Checks; len(et.Annotation.Checks) > 0 { names := make([]string, 0, len(checks)) for name := range checks { names = append(names, name) } sort.Strings(names) for _, name := range names { at.AddChecks(&schema.Check{ Name: name, Expr: checks[name], }) } } } // descIndexes returns a map holding the DESC mapping if it exists. func descIndexes(idx *Index) map[string]bool { descs := make(map[string]bool) if idx.Annotation == nil { return descs } // If DESC (without a column) was defined on the // annotation, map it to the single column index. if idx.Annotation.Desc && len(idx.Columns) == 1 { descs[idx.Columns[0].Name] = idx.Annotation.Desc } for column, desc := range idx.Annotation.DescColumns { descs[column] = desc } return descs } // diffDriver decorates the atlas migrate.Driver and adds "diff hooking" functionality. type diffDriver struct { migrate.Driver hooks []DiffHook // hooks to apply } // RealmDiff creates the diff between two realms. Since Ent does not care about realms, // only about schema changes, calling this method raises an error. func (r *diffDriver) RealmDiff(_, _ *schema.Realm) ([]schema.Change, error) { return nil, errors.New("sqlDialect does not support working with realms") } // SchemaDiff creates the diff between two schemas, but includes "diff hooks". func (r *diffDriver) SchemaDiff(from, to *schema.Schema) ([]schema.Change, error) { var d Differ = DiffFunc(r.Driver.SchemaDiff) for i := len(r.hooks) - 1; i >= 0; i-- { d = r.hooks[i](d) } return d.Diff(from, to) } // legacyMigrate returns a configured legacy migration engine (before Atlas) to keep backwards compatibility. // // Deprecated: Will be removed alongside legacy migration support. func (a *Atlas) legacyMigrate() (*Migrate, error) { m := &Migrate{ universalID: a.universalID, dropColumns: a.dropColumns, dropIndexes: a.dropIndexes, withFixture: a.withFixture, withForeignKeys: a.withForeignKeys, hooks: a.hooks, atlas: a, } switch a.dialect { case dialect.MySQL: m.sqlDialect = &MySQL{Driver: a.driver} case dialect.SQLite: m.sqlDialect = &SQLite{Driver: a.driver, WithForeignKeys: a.withForeignKeys} case dialect.Postgres: m.sqlDialect = &Postgres{Driver: a.driver} default: return nil, fmt.Errorf("sql/schema: unsupported dialect %q", a.dialect) } return m, nil } ent-0.11.3/dialect/sql/schema/inspect.go000066400000000000000000000046141431500740500200260ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree.
package schema import ( "context" "fmt" "entgo.io/ent/dialect" "entgo.io/ent/dialect/sql" ) // InspectOption allows for managing schema configuration using functional options. type InspectOption func(inspect *Inspector) // WithSchema provides a schema (named-database) for reading the tables from. func WithSchema(schema string) InspectOption { return func(m *Inspector) { m.schema = schema } } // An Inspector provides methods for inspecting database tables. type Inspector struct { sqlDialect schema string } // NewInspect returns an inspector for the given SQL driver. func NewInspect(d dialect.Driver, opts ...InspectOption) (*Inspector, error) { i := &Inspector{} for _, opt := range opts { opt(i) } switch d.Dialect() { case dialect.MySQL: i.sqlDialect = &MySQL{Driver: d, schema: i.schema} case dialect.SQLite: i.sqlDialect = &SQLite{Driver: d} case dialect.Postgres: i.sqlDialect = &Postgres{Driver: d, schema: i.schema} default: return nil, fmt.Errorf("sql/schema: unsupported dialect %q", d.Dialect()) } return i, nil } // Tables returns the tables in the schema. func (i *Inspector) Tables(ctx context.Context) ([]*Table, error) { names, err := i.tables(ctx) if err != nil { return nil, err } tx := dialect.NopTx(i.sqlDialect) tables := make([]*Table, 0, len(names)) for _, name := range names { t, err := i.table(ctx, tx, name) if err != nil { return nil, err } tables = append(tables, t) } fki, ok := i.sqlDialect.(interface { foreignKeys(context.Context, dialect.Tx, []*Table) error }) if ok { if err := fki.foreignKeys(ctx, tx, tables); err != nil { return nil, err } } return tables, nil } func (i *Inspector) tables(ctx context.Context) ([]string, error) { t, ok := i.sqlDialect.(interface{ tables() sql.Querier }) if !ok { return nil, fmt.Errorf("sql/schema: %q driver does not support inspection", i.Dialect()) } query, args := t.tables().Query() var ( names []string rows = &sql.Rows{} ) if err := i.Query(ctx, query, args, rows); err != nil { return nil, fmt.Errorf("%q driver: reading table names %w", i.Dialect(), err) } defer rows.Close() if err := sql.ScanSlice(rows, &names); err != nil { return nil, err } return names, nil } ent-0.11.3/dialect/sql/schema/inspect_test.go000066400000000000000000000450361431500740500210700ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package schema import ( "context" "fmt" "math" "path" "testing" "entgo.io/ent/dialect" "entgo.io/ent/dialect/sql" "entgo.io/ent/schema/field" "github.com/DATA-DOG/go-sqlmock" "github.com/stretchr/testify/require" ) func TestInspector_Tables(t *testing.T) { tests := []struct { name string options []InspectOption before map[string]func(mysqlMock) tables func(drv string) []*Table wantErr bool }{ { name: "default schema", before: map[string]func(mysqlMock){ dialect.MySQL: func(mock mysqlMock) { mock.ExpectQuery(escape("SELECT `TABLE_NAME` FROM `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA` = (SELECT DATABASE())")). WillReturnRows(sqlmock.NewRows([]string{"TABLE_NAME"})) }, dialect.SQLite: func(mock mysqlMock) { mock.ExpectQuery(escape("SELECT `name` FROM `sqlite_schema` WHERE `type` = ?")). WithArgs("table"). WillReturnRows(sqlmock.NewRows([]string{"name"})) }, dialect.Postgres: func(mock mysqlMock) { mock.ExpectQuery(escape(`SELECT "table_name" FROM "information_schema"."tables" WHERE "table_schema" = CURRENT_SCHEMA()`)). 
WillReturnRows(sqlmock.NewRows([]string{"name"})) }, }, tables: func(drv string) []*Table { return nil }, }, { name: "custom schema", options: []InspectOption{WithSchema("public")}, before: map[string]func(mysqlMock){ dialect.MySQL: func(mock mysqlMock) { mock.ExpectQuery(escape("SELECT `TABLE_NAME` FROM `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA` = ?")). WithArgs("public"). WillReturnRows(sqlmock.NewRows([]string{"TABLE_NAME"}). AddRow("users"). AddRow("pets"). AddRow("groups"). AddRow("user_groups")) mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?")). WithArgs("public", "users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}). AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil). AddRow("name", "varchar(255)", "YES", "YES", "NULL", "", "", "", nil, nil). AddRow("text", "longtext", "YES", "YES", "NULL", "", "", "", nil, nil). AddRow("uuid", "char(36)", "YES", "YES", "NULL", "", "", "utf8mb4_bin", nil, nil). AddRow("price", "decimal(6, 4)", "NO", "YES", "NULL", "", "", "", "6", "4"). AddRow("bank_id", "varchar(255)", "NO", "YES", "NULL", "", "", "", nil, nil)) mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). WithArgs("public", "users"). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}). AddRow("PRIMARY", "id", nil, "0", "1")) mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?")). WithArgs("public", "pets"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}). AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil). AddRow("name", "varchar(255)", "YES", "YES", "NULL", "", "", "", nil, nil). AddRow("user_pets", "bigint(20)", "YES", "YES", "NULL", "", "", "", nil, nil)) mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). WithArgs("public", "pets"). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}). AddRow("PRIMARY", "id", nil, "0", "1")) mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?")). WithArgs("public", "groups"). 
WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}). AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil). AddRow("name", "varchar(255)", "NO", "YES", "NULL", "", "", "", nil, nil)) mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). WithArgs("public", "groups"). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}). AddRow("PRIMARY", "id", nil, "0", "1")) mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?")). WithArgs("public", "user_groups"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}). AddRow("user_id", "bigint(20)", "NO", "YES", "NULL", "", "", "", nil, nil). AddRow("group_id", "bigint(20)", "NO", "YES", "NULL", "", "", "", nil, nil)) mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). WithArgs("public", "user_groups"). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"})) }, dialect.SQLite: func(mock mysqlMock) { mock.ExpectQuery(escape("SELECT `name` FROM `sqlite_schema` WHERE `type` = ?")). WithArgs("table"). WillReturnRows(sqlmock.NewRows([]string{"name"}). AddRow("users"). AddRow("pets"). AddRow("groups"). AddRow("user_groups")) mock.ExpectQuery(escape("SELECT `name`, `type`, `notnull`, `dflt_value`, `pk` FROM pragma_table_info('users') ORDER BY `pk`")). WithArgs(). WillReturnRows(sqlmock.NewRows([]string{"name", "type", "notnull", "dflt_value", "pk"}). AddRow("id", "integer", 1, "NULL", 1). AddRow("name", "varchar(255)", 0, "NULL", 0). AddRow("text", "text", 0, "NULL", 0). AddRow("uuid", "uuid", 0, "NULL", 0). AddRow("price", "real", 1, "NULL", 0). AddRow("bank_id", "varchar(255)", 1, "NULL", 0)) mock.ExpectQuery(escape("SELECT `name`, `unique`, `origin` FROM pragma_index_list('users')")). WillReturnRows(sqlmock.NewRows([]string{"name", "unique", "unique"})) mock.ExpectQuery(escape("SELECT `name`, `type`, `notnull`, `dflt_value`, `pk` FROM pragma_table_info('pets') ORDER BY `pk`")). WithArgs(). WillReturnRows(sqlmock.NewRows([]string{"name", "type", "notnull", "dflt_value", "pk"}). AddRow("id", "integer", 1, "NULL", 1). AddRow("name", "varchar(255)", 0, "NULL", 0). AddRow("user_pets", "integer", 0, "NULL", 0)) mock.ExpectQuery(escape("SELECT `name`, `unique`, `origin` FROM pragma_index_list('pets')")). WillReturnRows(sqlmock.NewRows([]string{"name", "unique", "unique"})) mock.ExpectQuery(escape("SELECT `name`, `type`, `notnull`, `dflt_value`, `pk` FROM pragma_table_info('groups') ORDER BY `pk`")). WithArgs(). WillReturnRows(sqlmock.NewRows([]string{"name", "type", "notnull", "dflt_value", "pk"}). AddRow("id", "integer", 1, "NULL", 1). 
AddRow("name", "varchar(255)", 1, "NULL", 0)) mock.ExpectQuery(escape("SELECT `name`, `unique`, `origin` FROM pragma_index_list('groups')")). WillReturnRows(sqlmock.NewRows([]string{"name", "unique", "unique"})) mock.ExpectQuery(escape("SELECT `name`, `type`, `notnull`, `dflt_value`, `pk` FROM pragma_table_info('user_groups') ORDER BY `pk`")). WithArgs(). WillReturnRows(sqlmock.NewRows([]string{"name", "type", "notnull", "dflt_value", "pk"}). AddRow("user_id", "integer", 1, "NULL", 0). AddRow("group_id", "integer", 1, "NULL", 0)) mock.ExpectQuery(escape("SELECT `name`, `unique`, `origin` FROM pragma_index_list('user_groups')")). WillReturnRows(sqlmock.NewRows([]string{"name", "unique", "unique"})) }, dialect.Postgres: func(mock mysqlMock) { mock.ExpectQuery(escape(`SELECT "table_name" FROM "information_schema"."tables" WHERE "table_schema" = $1`)). WithArgs("public"). WillReturnRows(sqlmock.NewRows([]string{"name"}). AddRow("users"). AddRow("pets"). AddRow("groups"). AddRow("user_groups")) mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length" FROM "information_schema"."columns" WHERE "table_schema" = $1 AND "table_name" = $2`)). WithArgs("public", "users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length"}). AddRow("id", "bigint", "NO", "NULL", "int8", nil, nil, nil). AddRow("name", "character", "YES", "NULL", "bpchar", nil, nil, nil). AddRow("text", "text", "YES", "NULL", "text", nil, nil, nil). AddRow("uuid", "uuid", "YES", "NULL", "uuid", nil, nil, nil). AddRow("price", "numeric", "NO", "NULL", "numeric", "6", "4", nil). AddRow("bank_id", "character", "NO", "NULL", "bpchar", nil, nil, 20)) mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "$1", "users"))). WithArgs("public"). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). AddRow("users_pkey", "id", "t", "t", 0)) mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length" FROM "information_schema"."columns" WHERE "table_schema" = $1 AND "table_name" = $2`)). WithArgs("public", "pets"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length"}). AddRow("id", "bigint", "NO", "NULL", "int8", nil, nil, nil). AddRow("name", "character", "YES", "NULL", "bpchar", nil, nil, nil). AddRow("user_pets", "bigint", "YES", "NULL", "int8", nil, nil, nil)) mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "$1", "pets"))). WithArgs("public"). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). AddRow("pets_pkey", "id", "t", "t", 0)) mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length" FROM "information_schema"."columns" WHERE "table_schema" = $1 AND "table_name" = $2`)). WithArgs("public", "groups"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length"}). AddRow("id", "bigint", "NO", "NULL", "int8", nil, nil, nil). 
AddRow("name", "character", "NO", "NULL", "bpchar", nil, nil, nil)) mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "$1", "groups"))). WithArgs("public"). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). AddRow("groups_pkey", "id", "t", "t", 0)) mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length" FROM "information_schema"."columns" WHERE "table_schema" = $1 AND "table_name" = $2`)). WithArgs("public", "user_groups"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length"}). AddRow("user_id", "bigint", "NO", "NULL", "int8", nil, nil, nil). AddRow("group_id", "bigint", "NO", "NULL", "int8", nil, nil, nil)) mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "$1", "user_groups"))). WithArgs("public"). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"})) mock.ExpectQuery(escape(fmt.Sprintf(fkQuery, "users"))). WillReturnRows(sqlmock.NewRows([]string{"table_schema", "constraint_name", "table_name", "column_name", "foreign_table_schema", "foreign_table_name", "foreign_column_name"})) mock.ExpectQuery(escape(fmt.Sprintf(fkQuery, "pets"))). WillReturnRows(sqlmock.NewRows([]string{"table_schema", "constraint_name", "table_name", "column_name", "foreign_table_schema", "foreign_table_name", "foreign_column_name"}). AddRow("public", "pet_users_pets", "pets", "user_pets", "public", "users", "id")) mock.ExpectQuery(escape(fmt.Sprintf(fkQuery, "groups"))). WillReturnRows(sqlmock.NewRows([]string{"table_schema", "constraint_name", "table_name", "column_name", "foreign_table_schema", "foreign_table_name", "foreign_column_name"})) mock.ExpectQuery(escape(fmt.Sprintf(fkQuery, "user_groups"))). WillReturnRows(sqlmock.NewRows([]string{"table_schema", "constraint_name", "table_name", "column_name", "foreign_table_schema", "foreign_table_name", "foreign_column_name"}). AddRow("public", "user_groups_group_id", "user_groups", "group_id", "public", "groups", "id"). 
AddRow("public", "user_groups_user_id", "user_groups", "user_id", "public", "users", "id")) }, }, tables: func(drv string) []*Table { var ( c1 = []*Column{ {Name: "id", Type: field.TypeInt64, Increment: true}, {Name: "name", Type: field.TypeString, Size: 255, Nullable: true}, {Name: "text", Type: field.TypeString, Size: math.MaxInt32, Nullable: true}, {Name: "uuid", Type: field.TypeUUID, Nullable: true}, {Name: "price", Type: field.TypeFloat64, SchemaType: map[string]string{ dialect.MySQL: "decimal(6,4)", dialect.Postgres: "numeric(6,4)", }}, {Name: "bank_id", Type: field.TypeString, SchemaType: map[string]string{ dialect.Postgres: "varchar(20)", }}, } t1 = &Table{ Name: "users", Columns: c1, PrimaryKey: c1[0:1], } c2 = []*Column{ {Name: "id", Type: field.TypeInt64, Increment: true}, {Name: "name", Type: field.TypeString, Size: 255, Nullable: true}, {Name: "user_pets", Type: field.TypeInt64, Nullable: true}, } t2 = &Table{ Name: "pets", Columns: c2, PrimaryKey: c2[0:1], } c3 = []*Column{ {Name: "id", Type: field.TypeInt64, Increment: true}, {Name: "name", Type: field.TypeString}, } t3 = &Table{ Name: "groups", Columns: c3, PrimaryKey: c3[0:1], } c4 = []*Column{ {Name: "user_id", Type: field.TypeInt64}, {Name: "group_id", Type: field.TypeInt64}, } t4 = &Table{ Name: "user_groups", Columns: c4, } ) // Only postgres currently supports foreign key inspection if drv == dialect.Postgres { t2.ForeignKeys = []*ForeignKey{ { Symbol: "pet_users_pets", Columns: []*Column{c2[2]}, RefTable: t1, RefColumns: []*Column{c1[0]}, }, } t4.ForeignKeys = []*ForeignKey{ { Symbol: "user_groups_group_id", Columns: []*Column{c4[1]}, RefTable: t3, RefColumns: []*Column{c3[0]}, }, { Symbol: "user_groups_user_id", Columns: []*Column{c4[0]}, RefTable: t1, RefColumns: []*Column{c1[0]}, }, } } return []*Table{t1, t2, t3, t4} }, }, } for _, tt := range tests { for drv := range tt.before { t.Run(path.Join(drv, tt.name), func(t *testing.T) { db, mock, err := sqlmock.New() require.NoError(t, err) tt.before[drv](mysqlMock{mock}) inspect, err := NewInspect(sql.OpenDB(drv, db), tt.options...) 
				require.NoError(t, err)
				tt.before[drv](mysqlMock{mock})
				tables, err := inspect.Tables(context.Background())
				require.Equal(t, tt.wantErr, err != nil, err)
				tablesMatch(t, drv, tables, tt.tables(drv))
			})
		}
	}
}

func tablesMatch(t *testing.T, drv string, got, expected []*Table) {
	require.Equal(t, len(expected), len(got))
	for i := range got {
		columnsMatch(t, drv, got[i].Columns, expected[i].Columns)
		columnsMatch(t, drv, got[i].PrimaryKey, expected[i].PrimaryKey)
		foreignKeysMatch(t, drv, got[i].ForeignKeys, expected[i].ForeignKeys)
	}
}

func columnsMatch(t *testing.T, drv string, got, expected []*Column) {
	require.Equal(t, len(expected), len(got))
	for i := range got {
		c1, c2 := got[i], expected[i]
		require.Equal(t, c2.Name, c1.Name)
		require.Equal(t, c2.Nullable, c1.Nullable)
		require.True(t, c1.Type == c2.Type || c1.ConvertibleTo(c2), "mismatched types: %s - %s", c1.Type, c2.Type)
		if c2.SchemaType[drv] != "" {
			require.Equal(t, c2.SchemaType[drv], c1.SchemaType[drv])
		}
	}
}

func foreignKeysMatch(t *testing.T, drv string, got, expected []*ForeignKey) {
	require.Equal(t, len(expected), len(got))
	for i := range got {
		fk1, fk2 := got[i], expected[i]
		require.Equal(t, fk2.Symbol, fk1.Symbol)
		require.Equal(t, fk2.RefTable.Name, fk1.RefTable.Name)
		columnsMatch(t, drv, fk1.Columns, fk2.Columns)
		columnsMatch(t, drv, fk1.RefColumns, fk2.RefColumns)
	}
}
ent-0.11.3/dialect/sql/schema/migrate.go000066400000000000000000000472261431500740500200170ustar00rootroot00000000000000
// Copyright 2019-present Facebook Inc. All rights reserved.
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.

package schema

import (
	"context"
	"fmt"
	"math"

	"entgo.io/ent/dialect"
	"entgo.io/ent/dialect/sql"
	"entgo.io/ent/schema/field"
)

const (
	// TypeTable defines the table name holding the type information.
	TypeTable = "ent_types"
	// MaxTypes defines the max number of types that can be created when
	// defining universal ids. The leftmost 16 bits are reserved.
	MaxTypes = math.MaxUint16
)

// MigrateOption allows configuring Atlas using functional arguments.
type MigrateOption func(*Atlas)

// WithGlobalUniqueID sets the universal-ids option for the migration.
// Defaults to false.
func WithGlobalUniqueID(b bool) MigrateOption {
	return func(a *Atlas) {
		a.universalID = b
	}
}

// WithDropColumn sets the column-dropping option for the migration.
// Defaults to false.
func WithDropColumn(b bool) MigrateOption {
	return func(a *Atlas) {
		a.dropColumns = b
	}
}

// WithDropIndex sets the index-dropping option for the migration.
// Defaults to false.
func WithDropIndex(b bool) MigrateOption {
	return func(a *Atlas) {
		a.dropIndexes = b
	}
}

// WithFixture sets the foreign-key renaming option for the migration when
// upgrading ent from v0.1.0 (issue-#285). Defaults to false.
//
// Deprecated: This option is no longer needed with the Atlas based
// migration engine, which now is the default.
func WithFixture(b bool) MigrateOption {
	return func(a *Atlas) {
		a.withFixture = b
	}
}

// WithForeignKeys enables creating foreign-keys in the DDL. Defaults to true.
func WithForeignKeys(b bool) MigrateOption {
	return func(a *Atlas) {
		a.withForeignKeys = b
	}
}

// WithHooks adds a list of hooks to the schema migration.
func WithHooks(hooks ...Hook) MigrateOption {
	return func(a *Atlas) {
		a.hooks = append(a.hooks, hooks...)
	}
}

type (
	// Creator is the interface that wraps the Create method.
	Creator interface {
		// Create creates the given tables in the database. See Migrate.Create for more details.
		Create(context.Context, ...*Table) error
	}

	// The CreateFunc type is an adapter to allow the use of an ordinary
	// function as a Creator. If f is a function with the appropriate signature,
	// CreateFunc(f) is a Creator that calls f.
	CreateFunc func(context.Context, ...*Table) error

	// Hook defines the "create middleware". A function that gets a Creator
	// and returns a Creator. For example:
	//
	//	hook := func(next schema.Creator) schema.Creator {
	//		return schema.CreateFunc(func(ctx context.Context, tables ...*schema.Table) error {
	//			fmt.Println("Tables:", tables)
	//			return next.Create(ctx, tables...)
	//		})
	//	}
	//
	Hook func(Creator) Creator
)

// Create calls f(ctx, tables...).
func (f CreateFunc) Create(ctx context.Context, tables ...*Table) error {
	return f(ctx, tables...)
}

// Migrate runs the migration logic for the SQL dialects.
//
// Deprecated: Use the new Atlas struct instead.
type Migrate struct {
	sqlDialect
	atlas           *Atlas   // Atlas this Migrate is based on
	universalID     bool     // global unique ids
	dropColumns     bool     // drop deleted columns
	dropIndexes     bool     // drop deleted indexes
	withFixture     bool     // with fks rename fixture
	withForeignKeys bool     // with foreign keys
	typeRanges      []string // types order by their range
	hooks           []Hook   // hooks to apply before creation
}

// Create creates all schema resources in the database. It works in an "append-only"
// mode, which means it only creates tables, appends columns to tables, or modifies
// column types.
//
// A column can be modified by turning it from NOT NULL into NULL, or by a type
// conversion that does not alter data. For example, changing varchar(255) to
// varchar(120) is invalid, but changing varchar(120) to varchar(255) is valid.
// For more info, see the convert function below.
//
// Note that the SQLite dialect does not support (at this moment) the "append-only"
// mode described above, since it's used only for testing.
func (m *Migrate) Create(ctx context.Context, tables ...*Table) error {
	m.setupTables(tables)
	var creator Creator = CreateFunc(m.create)
	for i := len(m.hooks) - 1; i >= 0; i-- {
		creator = m.hooks[i](creator)
	}
	return creator.Create(ctx, tables...)
}

func (m *Migrate) create(ctx context.Context, tables ...*Table) error {
	if err := m.init(ctx); err != nil {
		return err
	}
	tx, err := m.Tx(ctx)
	if err != nil {
		return err
	}
	if m.universalID {
		if err := m.types(ctx, tx); err != nil {
			return rollback(tx, err)
		}
	}
	if err := m.txCreate(ctx, tx, tables...); err != nil {
		return rollback(tx, err)
	}
	return tx.Commit()
}

func (m *Migrate) txCreate(ctx context.Context, tx dialect.Tx, tables ...*Table) error {
	for _, t := range tables {
		switch exist, err := m.tableExist(ctx, tx, t.Name); {
		case err != nil:
			return err
		case exist:
			curr, err := m.table(ctx, tx, t.Name)
			if err != nil {
				return err
			}
			if err := m.verify(ctx, tx, curr); err != nil {
				return err
			}
			if err := m.fixture(ctx, tx, curr, t); err != nil {
				return err
			}
			change, err := m.changeSet(curr, t)
			if err != nil {
				return fmt.Errorf("creating changeset for %q: %w", t.Name, err)
			}
			if err := m.apply(ctx, tx, t.Name, change); err != nil {
				return err
			}
		default: // !exist
			query, args := m.tBuilder(t).Query()
			if err := tx.Exec(ctx, query, args, nil); err != nil {
				return fmt.Errorf("create table %q: %w", t.Name, err)
			}
			// If global unique identifier is enabled, and it's not
			// a relation table, allocate a range for the table pk.
			if m.universalID && len(t.PrimaryKey) == 1 {
				if err := m.allocPKRange(ctx, tx, t); err != nil {
					return err
				}
			}
			// indexes.
for _, idx := range t.Indexes { query, args := m.addIndex(idx, t.Name).Query() if err := tx.Exec(ctx, query, args, nil); err != nil { return fmt.Errorf("create index %q: %w", idx.Name, err) } } } } if !m.withForeignKeys { return nil } // Create foreign keys after tables were created/altered, // because circular foreign-key constraints are possible. for _, t := range tables { if len(t.ForeignKeys) == 0 { continue } fks := make([]*ForeignKey, 0, len(t.ForeignKeys)) for _, fk := range t.ForeignKeys { exist, err := m.fkExist(ctx, tx, fk.Symbol) if err != nil { return err } if !exist { fks = append(fks, fk) } } if len(fks) == 0 { continue } b := sql.Dialect(m.Dialect()).AlterTable(t.Name) for _, fk := range fks { b.AddForeignKey(fk.DSL()) } query, args := b.Query() if err := tx.Exec(ctx, query, args, nil); err != nil { return fmt.Errorf("create foreign keys for %q: %w", t.Name, err) } } return nil } // apply changes on the given table. func (m *Migrate) apply(ctx context.Context, tx dialect.Tx, table string, change *changes) error { // Constraints should be dropped before dropping columns, because if a column // is a part of multi-column constraints (like, unique index), ALTER TABLE // might fail if the intermediate state violates the constraints. if m.dropIndexes { if pr, ok := m.sqlDialect.(preparer); ok { if err := pr.prepare(ctx, tx, change, table); err != nil { return err } } for _, idx := range change.index.drop { if err := m.dropIndex(ctx, tx, idx, table); err != nil { return fmt.Errorf("drop index of table %q: %w", table, err) } } } var drop []*Column if m.dropColumns { drop = change.column.drop } queries := m.alterColumns(table, change.column.add, change.column.modify, drop) // If there's actual action to execute on ALTER TABLE. for i := range queries { query, args := queries[i].Query() if err := tx.Exec(ctx, query, args, nil); err != nil { return fmt.Errorf("alter table %q: %w", table, err) } } for _, idx := range change.index.add { query, args := m.addIndex(idx, table).Query() if err := tx.Exec(ctx, query, args, nil); err != nil { return fmt.Errorf("create index %q: %w", table, err) } } return nil } // changes to apply on existing table. type changes struct { // column changes. column struct { add []*Column drop []*Column modify []*Column } // index changes. index struct { add Indexes drop Indexes } } // dropColumn returns the dropped column by name (if any). func (c *changes) dropColumn(name string) (*Column, bool) { for _, col := range c.column.drop { if col.Name == name { return col, true } } return nil, false } // changeSet returns a changes object to be applied on existing table. // It fails if one of the changes is invalid. func (m *Migrate) changeSet(curr, new *Table) (*changes, error) { change := &changes{} // pks. if len(curr.PrimaryKey) != len(new.PrimaryKey) { return nil, fmt.Errorf("cannot change primary key for table: %q", curr.Name) } for i := range curr.PrimaryKey { if curr.PrimaryKey[i].Name != new.PrimaryKey[i].Name { return nil, fmt.Errorf("cannot change primary key for table: %q", curr.Name) } } // Add or modify columns. for _, c1 := range new.Columns { // Ignore primary keys. if c1.PrimaryKey() { continue } switch c2, ok := curr.column(c1.Name); { case !ok: change.column.add = append(change.column.add, c1) case !c2.Type.Valid(): return nil, fmt.Errorf("invalid type %q for column %q", c2.typ, c2.Name) // Modify a non-unique column to unique. 
case c1.Unique && !c2.Unique: // Make sure the table does not have unique index for this column // before adding it to the changeset, because there are 2 ways to // configure uniqueness on sqlDialect.Field (using the Unique modifier or // adding rule on the Indexes option). if idx, ok := curr.index(c1.Name); !ok || !idx.Unique { change.index.add.append(&Index{ Name: c1.Name, Unique: true, Columns: []*Column{c1}, columns: []string{c1.Name}, }) } // Modify a unique column to non-unique. case !c1.Unique && c2.Unique: // If the uniqueness was defined on the Indexes option, // or was moved from the Unique modifier to the Indexes. if idx, ok := new.index(c1.Name); ok && idx.Unique { continue } idx, ok := curr.index(c2.Name) if !ok { return nil, fmt.Errorf("missing index to drop for unique column %q", c2.Name) } change.index.drop.append(idx) // Extending column types. case m.needsConversion(c2, c1): if !c2.ConvertibleTo(c1) { return nil, fmt.Errorf("changing column type for %q is invalid (%s != %s)", c1.Name, m.cType(c1), m.cType(c2)) } fallthrough // Change nullability of a column. case c1.Nullable != c2.Nullable: change.column.modify = append(change.column.modify, c1) // Change default value. case c1.Default != nil && c2.Default == nil: change.column.modify = append(change.column.modify, c1) } } // Drop columns. for _, c1 := range curr.Columns { // If a column was dropped, multi-columns indexes that are associated with this column will // no longer behave the same. Therefore, these indexes should be dropped too. There's no need // to do it explicitly (here), because entc will remove them from the schema specification, // and they will be dropped in the block below. if _, ok := new.column(c1.Name); !ok { change.column.drop = append(change.column.drop, c1) } } // Add or modify indexes. for _, idx1 := range new.Indexes { switch idx2, ok := curr.index(idx1.Name); { case !ok: change.index.add.append(idx1) // Changing index cardinality require drop and create. case idx1.Unique != idx2.Unique: change.index.drop.append(idx2) change.index.add.append(idx1) default: im, ok := m.sqlDialect.(interface{ indexModified(old, new *Index) bool }) // If the dialect supports comparing indexes. if ok && im.indexModified(idx2, idx1) { change.index.drop.append(idx2) change.index.add.append(idx1) } } } // Drop indexes. for _, idx := range curr.Indexes { if _, isFK := new.fk(idx.Name); !isFK && !new.hasIndex(idx.Name, idx.realname) { change.index.drop.append(idx) } } return change, nil } // fixture is a special migration code for renaming foreign-key columns (issue-#285). func (m *Migrate) fixture(ctx context.Context, tx dialect.Tx, curr, new *Table) error { d, ok := m.sqlDialect.(fkRenamer) if !m.withFixture || !m.withForeignKeys || !ok { return nil } rename := make(map[string]*Index) for _, fk := range new.ForeignKeys { ok, err := m.fkExist(ctx, tx, fk.Symbol) if err != nil { return fmt.Errorf("checking foreign-key existence %q: %w", fk.Symbol, err) } if !ok { continue } column, err := m.fkColumn(ctx, tx, fk) if err != nil { return err } newcol := fk.Columns[0] if column == newcol.Name { continue } query, args := d.renameColumn(curr, &Column{Name: column}, newcol).Query() if err := tx.Exec(ctx, query, args, nil); err != nil { return fmt.Errorf("rename column %q: %w", column, err) } prev, ok := curr.column(column) if !ok { continue } // Find all indexes that ~maybe need to be renamed. for _, idx := range prev.indexes { switch _, ok := new.index(idx.Name); { // Ignore indexes that exist in the schema, PKs. 
case ok || idx.primary: // Index that was created implicitly for a unique // column needs to be renamed to the column name. case d.isImplicitIndex(idx, prev): idx2 := &Index{Name: newcol.Name, Unique: true, Columns: []*Column{newcol}} query, args := d.renameIndex(curr, idx, idx2).Query() if err := tx.Exec(ctx, query, args, nil); err != nil { return fmt.Errorf("rename index %q: %w", prev.Name, err) } idx.Name = idx2.Name default: rename[idx.Name] = idx } } // Update the name of the loaded column, so `changeSet` won't create it. prev.Name = newcol.Name } // Go over the indexes that need to be renamed // and find their ~identical in the new schema. for _, idx := range rename { Find: // Find its ~identical in the new schema, and rename it // if it doesn't exist. for _, idx2 := range new.Indexes { if _, ok := curr.index(idx2.Name); ok { continue } if idx.sameAs(idx2) { query, args := d.renameIndex(curr, idx, idx2).Query() if err := tx.Exec(ctx, query, args, nil); err != nil { return fmt.Errorf("rename index %q: %w", idx.Name, err) } idx.Name = idx2.Name break Find } } } return nil } // verify that the auto-increment counter is correct for table with universal-id support. func (m *Migrate) verify(ctx context.Context, tx dialect.Tx, t *Table) error { vr, ok := m.sqlDialect.(verifyRanger) if !ok || !m.universalID { return nil } id := indexOf(m.typeRanges, t.Name) if id == -1 { return nil } return vr.verifyRange(ctx, tx, t, int64(id<<32)) } // types loads the type list from the type store. It will create the types table, if it does not exist yet. func (m *Migrate) types(ctx context.Context, tx dialect.ExecQuerier) error { exists, err := m.tableExist(ctx, tx, TypeTable) if err != nil { return err } if !exists { t := NewTable(TypeTable). AddPrimary(&Column{Name: "id", Type: field.TypeUint, Increment: true}). AddColumn(&Column{Name: "type", Type: field.TypeString, Unique: true}) query, args := m.tBuilder(t).Query() if err := tx.Exec(ctx, query, args, nil); err != nil { return fmt.Errorf("create types table: %w", err) } return nil } rows := &sql.Rows{} query, args := sql.Dialect(m.Dialect()). Select("type").From(sql.Table(TypeTable)).OrderBy(sql.Asc("id")).Query() if err := tx.Query(ctx, query, args, rows); err != nil { return fmt.Errorf("query types table: %w", err) } defer rows.Close() return sql.ScanSlice(rows, &m.typeRanges) } func (m *Migrate) allocPKRange(ctx context.Context, conn dialect.ExecQuerier, t *Table) error { r, err := m.pkRange(ctx, conn, t) if err != nil { return err } return m.setRange(ctx, conn, t, r) } func (m *Migrate) pkRange(ctx context.Context, conn dialect.ExecQuerier, t *Table) (int64, error) { id := indexOf(m.typeRanges, t.Name) // If the table re-created, re-use its range from // the past. Otherwise, allocate a new id-range. if id == -1 { if len(m.typeRanges) > MaxTypes { return 0, fmt.Errorf("max number of types exceeded: %d", MaxTypes) } query, args := sql.Dialect(m.Dialect()).Insert(TypeTable).Columns("type").Values(t.Name).Query() if err := conn.Exec(ctx, query, args, nil); err != nil { return 0, fmt.Errorf("insert into ent_types: %w", err) } id = len(m.typeRanges) m.typeRanges = append(m.typeRanges, t.Name) } return int64(id << 32), nil } // fkColumn returns the column name of a foreign-key. 
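
// The id-range arithmetic in pkRange above is easy to get wrong when reading
// data written with WithGlobalUniqueID enabled. The sketch below is
// illustrative only (not part of ent): it assumes, as pkRange does, that the
// N-th type registered in the ent_types table owns the id block starting at
// N<<32.
func exampleIDRange() {
	// typeID is the 0-based position of a type in the ent_types table.
	for typeID := int64(0); typeID < 3; typeID++ {
		start := typeID << 32     // first id in the range
		end := (typeID+1)<<32 - 1 // last id in the range
		fmt.Printf("type %d: ids %d..%d\n", typeID, start, end)
	}
	// Output:
	// type 0: ids 0..4294967295
	// type 1: ids 4294967296..8589934591
	// type 2: ids 8589934592..12884901887
}
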
func (m *Migrate) fkColumn(ctx context.Context, tx dialect.Tx, fk *ForeignKey) (string, error) { t1 := sql.Table("INFORMATION_SCHEMA.KEY_COLUMN_USAGE AS t1").Unquote().As("t1") t2 := sql.Table("INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS t2").Unquote().As("t2") query, args := sql.Dialect(m.Dialect()). Select("column_name"). From(t1). Join(t2). On(t1.C("constraint_name"), t2.C("constraint_name")). Where(sql.And( sql.EQ(t2.C("constraint_type"), sql.Raw("'FOREIGN KEY'")), m.sqlDialect.(fkRenamer).matchSchema(t2.C("table_schema")), m.sqlDialect.(fkRenamer).matchSchema(t1.C("table_schema")), sql.EQ(t2.C("constraint_name"), fk.Symbol), )). Query() rows := &sql.Rows{} if err := tx.Query(ctx, query, args, rows); err != nil { return "", fmt.Errorf("reading foreign-key %q column: %w", fk.Symbol, err) } defer rows.Close() column, err := sql.ScanString(rows) if err != nil { return "", fmt.Errorf("scanning foreign-key %q column: %w", fk.Symbol, err) } return column, nil } // setup ensures the table is configured properly, like table columns // are linked to their indexes, and PKs columns are defined. func (m *Migrate) setupTables(tables []*Table) { m.atlas.setupTables(tables) } // rollback calls to tx.Rollback and wraps the given error with the rollback error if occurred. func rollback(tx dialect.Tx, err error) error { err = fmt.Errorf("sql/schema: %w", err) if rerr := tx.Rollback(); rerr != nil { err = fmt.Errorf("%w: %v", err, rerr) } return err } // exist checks if the given COUNT query returns a value >= 1. func exist(ctx context.Context, conn dialect.ExecQuerier, query string, args ...any) (bool, error) { rows := &sql.Rows{} if err := conn.Query(ctx, query, args, rows); err != nil { return false, fmt.Errorf("reading schema information %w", err) } defer rows.Close() n, err := sql.ScanInt(rows) if err != nil { return false, err } return n > 0, nil } func indexOf(a []string, s string) int { for i := range a { if a[i] == s { return i } } return -1 } type sqlDialect interface { atBuilder dialect.Driver init(context.Context) error table(context.Context, dialect.Tx, string) (*Table, error) tableExist(context.Context, dialect.ExecQuerier, string) (bool, error) fkExist(context.Context, dialect.Tx, string) (bool, error) setRange(context.Context, dialect.ExecQuerier, *Table, int64) error dropIndex(context.Context, dialect.Tx, *Index, string) error // table, column and index builder per dialect. cType(*Column) string tBuilder(*Table) *sql.TableBuilder addIndex(*Index, string) *sql.IndexBuilder alterColumns(table string, add, modify, drop []*Column) sql.Queries needsConversion(*Column, *Column) bool } type preparer interface { prepare(context.Context, dialect.Tx, *changes, string) error } // fkRenamer is used by the fixture migration (to solve #285), // and it's implemented by the different dialects for renaming FKs. type fkRenamer interface { matchSchema(...string) *sql.Predicate isImplicitIndex(*Index, *Column) bool renameIndex(*Table, *Index, *Index) sql.Querier renameColumn(*Table, *Column, *Column) sql.Querier } // verifyRanger wraps the method for verifying global-id range correctness. type verifyRanger interface { verifyRange(context.Context, dialect.ExecQuerier, *Table, int64) error } ent-0.11.3/dialect/sql/schema/migrate_test.go000066400000000000000000000336101431500740500210460ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. 
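
// For context before the tests below: a sketch of how the migration options
// defined in migrate.go are typically passed to a generated ent client
// (assuming the standard entc-generated Schema.Create API; "client", "ctx"
// and the generated "migrate" package are placeholders, not defined here):
//
//	if err := client.Schema.Create(ctx,
//		migrate.WithGlobalUniqueID(true),
//		migrate.WithDropColumn(true),
//		migrate.WithDropIndex(true),
//	); err != nil {
//		log.Fatalf("failed creating schema resources: %v", err)
//	}
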
package schema import ( "context" "fmt" "os" "path/filepath" "strings" "testing" "text/template" "time" "ariga.io/atlas/sql/migrate" "ariga.io/atlas/sql/schema" "ariga.io/atlas/sql/sqltool" "entgo.io/ent/dialect" "entgo.io/ent/dialect/sql" "entgo.io/ent/schema/field" "github.com/DATA-DOG/go-sqlmock" "github.com/stretchr/testify/require" _ "github.com/mattn/go-sqlite3" ) func TestMigrateHookOmitTable(t *testing.T) { db, mk, err := sqlmock.New() require.NoError(t, err) tables := []*Table{{Name: "users"}, {Name: "pets"}} mock := mysqlMock{mk} mock.start("5.7.23") mock.tableExists("pets", false) mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `pets`() CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() m, err := NewMigrate(sql.OpenDB("mysql", db), WithHooks(func(next Creator) Creator { return CreateFunc(func(ctx context.Context, tables ...*Table) error { return next.Create(ctx, tables[1]) }) }), WithAtlas(false)) require.NoError(t, err) err = m.Create(context.Background(), tables...) require.NoError(t, err) } func TestMigrateHookAddTable(t *testing.T) { db, mk, err := sqlmock.New() require.NoError(t, err) tables := []*Table{{Name: "users"}} mock := mysqlMock{mk} mock.start("5.7.23") mock.tableExists("users", false) mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `users`() CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.tableExists("pets", false) mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `pets`() CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() m, err := NewMigrate(sql.OpenDB("mysql", db), WithHooks(func(next Creator) Creator { return CreateFunc(func(ctx context.Context, tables ...*Table) error { return next.Create(ctx, tables[0], &Table{Name: "pets"}) }) }), WithAtlas(false)) require.NoError(t, err) err = m.Create(context.Background(), tables...) require.NoError(t, err) } func TestMigrate_Formatter(t *testing.T) { db, _, err := sqlmock.New() require.NoError(t, err) // If no formatter is given it will be set according to the given migration directory implementation. for _, tt := range []struct { dir migrate.Dir fmt migrate.Formatter }{ {&migrate.LocalDir{}, sqltool.GolangMigrateFormatter}, {&sqltool.GolangMigrateDir{}, sqltool.GolangMigrateFormatter}, {&sqltool.GooseDir{}, sqltool.GooseFormatter}, {&sqltool.DBMateDir{}, sqltool.DBMateFormatter}, {&sqltool.FlywayDir{}, sqltool.FlywayFormatter}, {&sqltool.LiquibaseDir{}, sqltool.LiquibaseFormatter}, {struct{ migrate.Dir }{}, sqltool.GolangMigrateFormatter}, // default one if migration dir is unknown } { m, err := NewMigrate(sql.OpenDB("", db), WithDir(tt.dir)) require.NoError(t, err) require.Equal(t, tt.fmt, m.fmt) } // If a formatter is given, it is not overridden. m, err := NewMigrate(sql.OpenDB("", db), WithDir(&migrate.LocalDir{}), WithFormatter(migrate.DefaultFormatter)) require.NoError(t, err) require.Equal(t, migrate.DefaultFormatter, m.fmt) } func TestMigrate_DiffJoinTableAllocationBC(t *testing.T) { // Due to a bug in previous versions, if the universal ID option was enabled and the schema did contain an M2M // relation, the join table would have had an entry for the join table in the types table. This test ensures, // that the PK range allocated for the join table stays in place, since it's removal would break existing projects // due to shifted ranges. 
db, err := sql.Open(dialect.SQLite, "file:test?mode=memory&_fk=1") require.NoError(t, err) // Mock an existing database with an allocation for a join table. for _, stmt := range []string{ "CREATE TABLE `groups` (`id` integer NOT NULL PRIMARY KEY AUTOINCREMENT, `name` text NOT NULL);", "CREATE INDEX `short` ON `groups` (`id`);", "CREATE INDEX `long____________________________1cb2e7e47a309191385af4ad320875b1` ON `groups` (`id`);", "CREATE TABLE `users` (`id` integer NOT NULL PRIMARY KEY AUTOINCREMENT, `name` text NOT NULL);", "INSERT INTO sqlite_sequence (name, seq) VALUES (\"users\", 4294967296);", "CREATE TABLE `user_groups` (`user_id` integer NOT NULL, `group_id` integer NOT NULL, PRIMARY KEY (`user_id`, `group_id`), CONSTRAINT `user_groups_user_id` FOREIGN KEY (`user_id`) REFERENCES `users` (`id`) ON DELETE CASCADE, CONSTRAINT `user_groups_group_id` FOREIGN KEY (`group_id`) REFERENCES `groups` (`id`) ON DELETE CASCADE);", "INSERT INTO sqlite_sequence (name, seq) VALUES (\"user_groups\", 8589934592);", "CREATE TABLE `ent_types` (`id` integer NOT NULL PRIMARY KEY AUTOINCREMENT, `type` text NOT NULL);", "CREATE UNIQUE INDEX `ent_types_type_key` ON `ent_types` (`type`);", "INSERT INTO `ent_types` (`type`) VALUES ('groups'), ('users'), ('user_groups');", "INSERT INTO `groups` (`name`) VALUES ('seniors'), ('juniors')", "INSERT INTO `users` (`name`) VALUES ('masseelch'), ('a8m'), ('rotemtam')", "INSERT INTO `user_groups` (`user_id`, `group_id`) VALUES (4294967297, 1), (4294967298, 1), (4294967299, 2)", } { _, err := db.ExecContext(context.Background(), stmt) require.NoError(t, err) } // Expect to have no changes when migration runs with fix. m, err := NewMigrate(db, WithGlobalUniqueID(true), WithDiffHook(func(next Differ) Differ { return DiffFunc(func(current, desired *schema.Schema) ([]schema.Change, error) { changes, err := next.Diff(current, desired) if err != nil { return nil, err } require.Len(t, changes, 0) return changes, nil }) })) require.NoError(t, err) require.NoError(t, m.Create(context.Background(), tables...)) // Expect to have no changes to the allocation when the join table is dropped. 
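	// Note on the fixtures above (illustrative): with universal ids, the N-th
	// type listed in ent_types owns the id block starting at N<<32. Seeding
	// sqlite_sequence with ("users", 4294967296) pins the AUTOINCREMENT
	// counter of "users" (the second type, N=1) to 1<<32, so its next rowid is
	// allocated inside its own block; "user_groups" at 8589934592 is 2<<32
	// (N=2). "groups" (N=0) needs no seed, since its block starts at 0.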
m, err = NewMigrate(db, WithGlobalUniqueID(true)) require.NoError(t, err) require.NoError(t, m.Create(context.Background(), groupsTable, usersTable)) rows, err := db.QueryContext(context.Background(), "SELECT `type` from `ent_types` ORDER BY `id` ASC") require.NoError(t, err) var types []string for rows.Next() { var typ string require.NoError(t, rows.Scan(&typ)) types = append(types, typ) } require.NoError(t, rows.Err()) require.Equal(t, []string{"groups", "users", "user_groups"}, types) } var ( groupsColumns = []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString}, } groupsTable = &Table{ Name: "groups", Columns: groupsColumns, PrimaryKey: []*Column{groupsColumns[0]}, Indexes: []*Index{ { Name: "short", Columns: []*Column{groupsColumns[0]}}, { Name: "long_" + strings.Repeat("_", 60), Columns: []*Column{groupsColumns[0]}, }, }, } usersColumns = []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString}, } usersTable = &Table{ Name: "users", Columns: usersColumns, PrimaryKey: []*Column{usersColumns[0]}, } userGroupsColumns = []*Column{ {Name: "user_id", Type: field.TypeInt}, {Name: "group_id", Type: field.TypeInt}, } userGroupsTable = &Table{ Name: "user_groups", Columns: userGroupsColumns, PrimaryKey: []*Column{userGroupsColumns[0], userGroupsColumns[1]}, ForeignKeys: []*ForeignKey{ { Symbol: "user_groups_user_id", Columns: []*Column{userGroupsColumns[0]}, RefColumns: []*Column{usersColumns[0]}, OnDelete: Cascade, }, { Symbol: "user_groups_group_id", Columns: []*Column{userGroupsColumns[1]}, RefColumns: []*Column{groupsColumns[0]}, OnDelete: Cascade, }, }, } tables = []*Table{ groupsTable, usersTable, userGroupsTable, } petColumns = []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, } petsTable = &Table{ Name: "pets", Columns: petColumns, PrimaryKey: petColumns, } ) func init() { userGroupsTable.ForeignKeys[0].RefTable = usersTable userGroupsTable.ForeignKeys[1].RefTable = groupsTable } func TestMigrate_Diff(t *testing.T) { ctx := context.Background() db, err := sql.Open(dialect.SQLite, "file:test?mode=memory&_fk=1") require.NoError(t, err) p := t.TempDir() d, err := migrate.NewLocalDir(p) require.NoError(t, err) m, err := NewMigrate(db, WithDir(d)) require.NoError(t, err) require.NoError(t, m.Diff(ctx, &Table{Name: "users"})) v := time.Now().UTC().Format("20060102150405") requireFileEqual(t, filepath.Join(p, v+"_changes.up.sql"), "-- create \"users\" table\nCREATE TABLE `users` (, PRIMARY KEY ());\n") requireFileEqual(t, filepath.Join(p, v+"_changes.down.sql"), "-- reverse: create \"users\" table\nDROP TABLE `users`;\n") require.FileExists(t, filepath.Join(p, migrate.HashFileName)) // Test integrity file. 
p = t.TempDir() d, err = migrate.NewLocalDir(p) require.NoError(t, err) m, err = NewMigrate(db, WithDir(d)) require.NoError(t, err) require.NoError(t, m.Diff(ctx, &Table{Name: "users"})) requireFileEqual(t, filepath.Join(p, v+"_changes.up.sql"), "-- create \"users\" table\nCREATE TABLE `users` (, PRIMARY KEY ());\n") requireFileEqual(t, filepath.Join(p, v+"_changes.down.sql"), "-- reverse: create \"users\" table\nDROP TABLE `users`;\n") require.FileExists(t, filepath.Join(p, migrate.HashFileName)) require.NoError(t, d.WriteFile("tmp.sql", nil)) require.ErrorIs(t, m.Diff(ctx, &Table{Name: "users"}), migrate.ErrChecksumMismatch) p = t.TempDir() d, err = migrate.NewLocalDir(p) require.NoError(t, err) f, err := migrate.NewTemplateFormatter( template.Must(template.New("").Parse("{{ .Name }}.sql")), template.Must(template.New("").Parse( `{{ range .Changes }}{{ printf "%s;\n" .Cmd }}{{ end }}`, )), ) require.NoError(t, err) // Join tables (mapping between user and group) will not result in an entry to the types table. m, err = NewMigrate(db, WithFormatter(f), WithDir(d), WithGlobalUniqueID(true)) require.NoError(t, err) require.NoError(t, m.Diff(ctx, tables...)) changesSQL := strings.Join([]string{ "CREATE TABLE `groups` (`id` integer NOT NULL PRIMARY KEY AUTOINCREMENT, `name` text NOT NULL);", "CREATE INDEX `short` ON `groups` (`id`);", "CREATE INDEX `long____________________________1cb2e7e47a309191385af4ad320875b1` ON `groups` (`id`);", "CREATE TABLE `users` (`id` integer NOT NULL PRIMARY KEY AUTOINCREMENT, `name` text NOT NULL);", fmt.Sprintf("INSERT INTO sqlite_sequence (name, seq) VALUES (\"users\", %d);", 1<<32), "CREATE TABLE `user_groups` (`user_id` integer NOT NULL, `group_id` integer NOT NULL, PRIMARY KEY (`user_id`, `group_id`), CONSTRAINT `user_groups_user_id` FOREIGN KEY (`user_id`) REFERENCES `users` (`id`) ON DELETE CASCADE, CONSTRAINT `user_groups_group_id` FOREIGN KEY (`group_id`) REFERENCES `groups` (`id`) ON DELETE CASCADE);", "CREATE TABLE `ent_types` (`id` integer NOT NULL PRIMARY KEY AUTOINCREMENT, `type` text NOT NULL);", "CREATE UNIQUE INDEX `ent_types_type_key` ON `ent_types` (`type`);", "INSERT INTO `ent_types` (`type`) VALUES ('groups'), ('users');", "", }, "\n") requireFileEqual(t, filepath.Join(p, "changes.sql"), changesSQL) // Adding another node will result in a new entry to the TypeTable (without actually creating it). _, err = db.ExecContext(ctx, changesSQL, nil, nil) require.NoError(t, err) require.NoError(t, m.NamedDiff(ctx, "changes_2", petsTable)) requireFileEqual(t, filepath.Join(p, "changes_2.sql"), strings.Join([]string{ "CREATE TABLE `pets` (`id` integer NOT NULL PRIMARY KEY AUTOINCREMENT);", fmt.Sprintf("INSERT INTO sqlite_sequence (name, seq) VALUES (\"pets\", %d);", 2<<32), "INSERT INTO `ent_types` (`type`) VALUES ('pets');", "", }, "\n")) // Checksum will be updated as well. 
require.NoError(t, migrate.Validate(d)) } func requireFileEqual(t *testing.T, name, contents string) { c, err := os.ReadFile(name) require.NoError(t, err) require.Equal(t, contents, string(c)) } func TestMigrateWithoutForeignKeys(t *testing.T) { tbl := &schema.Table{ Name: "tbl", Columns: []*schema.Column{ {Name: "id", Type: &schema.ColumnType{Type: &schema.IntegerType{T: "bigint"}}}, }, } fk := &schema.ForeignKey{ Symbol: "fk", Table: tbl, Columns: tbl.Columns[1:], RefTable: tbl, RefColumns: tbl.Columns[:1], OnUpdate: schema.NoAction, OnDelete: schema.Cascade, } tbl.ForeignKeys = append(tbl.ForeignKeys, fk) t.Run("AddTable", func(t *testing.T) { mdiff := DiffFunc(func(_, _ *schema.Schema) ([]schema.Change, error) { return []schema.Change{ &schema.AddTable{ T: tbl, }, }, nil }) df, err := withoutForeignKeys(mdiff).Diff(nil, nil) require.NoError(t, err) require.Len(t, df, 1) actual, ok := df[0].(*schema.AddTable) require.True(t, ok) require.Nil(t, actual.T.ForeignKeys) }) t.Run("ModifyTable", func(t *testing.T) { mdiff := DiffFunc(func(_, _ *schema.Schema) ([]schema.Change, error) { return []schema.Change{ &schema.ModifyTable{ T: tbl, Changes: []schema.Change{ &schema.AddIndex{ I: &schema.Index{ Name: "id_key", Parts: []*schema.IndexPart{ {C: tbl.Columns[0]}, }, }, }, &schema.DropForeignKey{ F: fk, }, &schema.AddForeignKey{ F: fk, }, &schema.ModifyForeignKey{ From: fk, To: fk, Change: schema.ChangeRefColumn, }, &schema.AddColumn{ C: &schema.Column{Name: "name", Type: &schema.ColumnType{Type: &schema.StringType{T: "varchar(255)"}}}, }, }, }, }, nil }) df, err := withoutForeignKeys(mdiff).Diff(nil, nil) require.NoError(t, err) require.Len(t, df, 1) actual, ok := df[0].(*schema.ModifyTable) require.True(t, ok) require.Len(t, actual.Changes, 2) addIndex, ok := actual.Changes[0].(*schema.AddIndex) require.True(t, ok) require.EqualValues(t, "id_key", addIndex.I.Name) addColumn, ok := actual.Changes[1].(*schema.AddColumn) require.True(t, ok) require.EqualValues(t, "name", addColumn.C.Name) }) } ent-0.11.3/dialect/sql/schema/mysql.go000066400000000000000000000705121431500740500175260ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package schema import ( "context" "fmt" "math" "strconv" "strings" "entgo.io/ent/dialect" "entgo.io/ent/dialect/entsql" "entgo.io/ent/dialect/sql" "entgo.io/ent/schema/field" "ariga.io/atlas/sql/migrate" "ariga.io/atlas/sql/mysql" "ariga.io/atlas/sql/schema" ) // MySQL is a MySQL migration driver. type MySQL struct { dialect.Driver schema string version string } // init loads the MySQL version from the database for later use in the migration process. func (d *MySQL) init(ctx context.Context) error { rows := &sql.Rows{} if err := d.Query(ctx, "SHOW VARIABLES LIKE 'version'", []any{}, rows); err != nil { return fmt.Errorf("mysql: querying mysql version %w", err) } defer rows.Close() if !rows.Next() { if err := rows.Err(); err != nil { return err } return fmt.Errorf("mysql: version variable was not found") } version := make([]string, 2) if err := rows.Scan(&version[0], &version[1]); err != nil { return fmt.Errorf("mysql: scanning mysql version: %w", err) } d.version = version[1] return nil } func (d *MySQL) tableExist(ctx context.Context, conn dialect.ExecQuerier, name string) (bool, error) { query, args := sql.Select(sql.Count("*")).From(sql.Table("TABLES").Schema("INFORMATION_SCHEMA")). 
Where(sql.And( d.matchSchema(), sql.EQ("TABLE_NAME", name), )).Query() return exist(ctx, conn, query, args...) } func (d *MySQL) fkExist(ctx context.Context, tx dialect.Tx, name string) (bool, error) { query, args := sql.Select(sql.Count("*")).From(sql.Table("TABLE_CONSTRAINTS").Schema("INFORMATION_SCHEMA")). Where(sql.And( d.matchSchema(), sql.EQ("CONSTRAINT_TYPE", "FOREIGN KEY"), sql.EQ("CONSTRAINT_NAME", name), )).Query() return exist(ctx, tx, query, args...) } // table loads the current table description from the database. func (d *MySQL) table(ctx context.Context, tx dialect.Tx, name string) (*Table, error) { rows := &sql.Rows{} query, args := sql.Select( "column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale", ). From(sql.Table("COLUMNS").Schema("INFORMATION_SCHEMA")). Where(sql.And( d.matchSchema(), sql.EQ("TABLE_NAME", name)), ).Query() if err := tx.Query(ctx, query, args, rows); err != nil { return nil, fmt.Errorf("mysql: reading table description %w", err) } // Call Close in cases of failures (Close is idempotent). defer rows.Close() t := NewTable(name) for rows.Next() { c := &Column{} if err := d.scanColumn(c, rows); err != nil { return nil, fmt.Errorf("mysql: %w", err) } t.AddColumn(c) } if err := rows.Err(); err != nil { return nil, err } if err := rows.Close(); err != nil { return nil, fmt.Errorf("mysql: closing rows %w", err) } indexes, err := d.indexes(ctx, tx, t) if err != nil { return nil, err } // Add and link indexes to table columns. for _, idx := range indexes { t.addIndex(idx) } if _, ok := d.mariadb(); ok { if err := d.normalizeJSON(ctx, tx, t); err != nil { return nil, err } } return t, nil } // table loads the table indexes from the database. func (d *MySQL) indexes(ctx context.Context, tx dialect.Tx, t *Table) ([]*Index, error) { rows := &sql.Rows{} query, args := sql.Select("index_name", "column_name", "sub_part", "non_unique", "seq_in_index"). From(sql.Table("STATISTICS").Schema("INFORMATION_SCHEMA")). Where(sql.And( d.matchSchema(), sql.EQ("TABLE_NAME", t.Name), )). OrderBy("index_name", "seq_in_index"). Query() if err := tx.Query(ctx, query, args, rows); err != nil { return nil, fmt.Errorf("mysql: reading index description %w", err) } defer rows.Close() idx, err := d.scanIndexes(rows, t) if err != nil { return nil, fmt.Errorf("mysql: %w", err) } return idx, nil } func (d *MySQL) setRange(ctx context.Context, conn dialect.ExecQuerier, t *Table, value int64) error { return conn.Exec(ctx, fmt.Sprintf("ALTER TABLE `%s` AUTO_INCREMENT = %d", t.Name, value), []any{}, nil) } func (d *MySQL) verifyRange(ctx context.Context, tx dialect.ExecQuerier, t *Table, expected int64) error { if expected == 0 { return nil } rows := &sql.Rows{} query, args := sql.Select("AUTO_INCREMENT"). From(sql.Table("TABLES").Schema("INFORMATION_SCHEMA")). Where(sql.And( d.matchSchema(), sql.EQ("TABLE_NAME", t.Name), )). Query() if err := tx.Query(ctx, query, args, rows); err != nil { return fmt.Errorf("mysql: query auto_increment %w", err) } // Call Close in cases of failures (Close is idempotent). defer rows.Close() actual := &sql.NullInt64{} if err := sql.ScanOne(rows, actual); err != nil { return fmt.Errorf("mysql: scan auto_increment %w", err) } if err := rows.Close(); err != nil { return err } // Table is empty and auto-increment is not configured. 
This can happen // because MySQL (< 8.0) stores the auto-increment counter in main memory // (not persistent), and the value is reset on restart (if table is empty). if actual.Int64 <= 1 { return d.setRange(ctx, tx, t, expected) } return nil } // tBuilder returns the MySQL DSL query for table creation. func (d *MySQL) tBuilder(t *Table) *sql.TableBuilder { b := sql.CreateTable(t.Name).IfNotExists() for _, c := range t.Columns { b.Column(d.addColumn(c)) } for _, pk := range t.PrimaryKey { b.PrimaryKey(pk.Name) } // Charset and collation config on MySQL table. // These options can be overridden by the entsql annotation. b.Charset("utf8mb4").Collate("utf8mb4_bin") if t.Annotation != nil { if charset := t.Annotation.Charset; charset != "" { b.Charset(charset) } if collate := t.Annotation.Collation; collate != "" { b.Collate(collate) } if opts := t.Annotation.Options; opts != "" { b.Options(opts) } addChecks(b, t.Annotation) } return b } // cType returns the MySQL string type for the given column. func (d *MySQL) cType(c *Column) (t string) { if c.SchemaType != nil && c.SchemaType[dialect.MySQL] != "" { // MySQL returns the column type lower cased. return strings.ToLower(c.SchemaType[dialect.MySQL]) } switch c.Type { case field.TypeBool: t = "boolean" case field.TypeInt8: t = "tinyint" case field.TypeUint8: t = "tinyint unsigned" case field.TypeInt16: t = "smallint" case field.TypeUint16: t = "smallint unsigned" case field.TypeInt32: t = "int" case field.TypeUint32: t = "int unsigned" case field.TypeInt, field.TypeInt64: t = "bigint" case field.TypeUint, field.TypeUint64: t = "bigint unsigned" case field.TypeBytes: size := int64(math.MaxUint16) if c.Size > 0 { size = c.Size } switch { case size <= math.MaxUint8: t = "tinyblob" case size <= math.MaxUint16: t = "blob" case size < 1<<24: t = "mediumblob" case size <= math.MaxUint32: t = "longblob" } case field.TypeJSON: t = "json" if compareVersions(d.version, "5.7.8") == -1 { t = "longblob" } case field.TypeString: size := c.Size if size == 0 { size = d.defaultSize(c) } switch { case c.typ == "tinytext", c.typ == "text": t = c.typ case size <= math.MaxUint16: t = fmt.Sprintf("varchar(%d)", size) case size == 1<<24-1: t = "mediumtext" default: t = "longtext" } case field.TypeFloat32, field.TypeFloat64: t = c.scanTypeOr("double") case field.TypeTime: t = c.scanTypeOr("timestamp") // In MariaDB or in MySQL < v8.0.2, the TIMESTAMP column has both `DEFAULT CURRENT_TIMESTAMP` // and `ON UPDATE CURRENT_TIMESTAMP` if neither is specified explicitly. this behavior is // suppressed if the column is defined with a `DEFAULT` clause or with the `NULL` attribute. if _, maria := d.mariadb(); maria || compareVersions(d.version, "8.0.2") == -1 && c.Default == nil { c.Nullable = c.Attr == "" } case field.TypeEnum: values := make([]string, len(c.Enums)) for i, e := range c.Enums { values[i] = fmt.Sprintf("'%s'", e) } t = fmt.Sprintf("enum(%s)", strings.Join(values, ", ")) case field.TypeUUID: t = "char(36) binary" case field.TypeOther: t = c.typ default: panic(fmt.Sprintf("unsupported type %q for column %q", c.Type.String(), c.Name)) } return t } // addColumn returns the DSL query for adding the given column to a table. // The syntax/order is: datatype [Charset] [Unique|Increment] [Collation] [Nullable]. 
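
// A minimal sketch (not part of the driver) of the size-to-blob-type
// bucketing that cType above applies to bytes fields; the thresholds mirror
// its switch statement:
func exampleBlobType(size int64) string {
	switch {
	case size <= math.MaxUint8: // up to 255 bytes
		return "tinyblob"
	case size <= math.MaxUint16: // up to 64KB-1
		return "blob"
	case size < 1<<24: // up to 16MB-1
		return "mediumblob"
	default: // cType allows up to math.MaxUint32
		return "longblob"
	}
}
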
func (d *MySQL) addColumn(c *Column) *sql.ColumnBuilder { b := sql.Column(c.Name).Type(d.cType(c)).Attr(c.Attr) c.unique(b) if c.Increment { b.Attr("AUTO_INCREMENT") } c.nullable(b) c.defaultValue(b) if c.Collation != "" { b.Attr("COLLATE " + c.Collation) } if c.Type == field.TypeJSON { // Manually add a `CHECK` clause for older versions of MariaDB for validating the // JSON documents. This constraint is automatically included from version 10.4.3. if version, ok := d.mariadb(); ok && compareVersions(version, "10.4.3") == -1 { b.Check(func(b *sql.Builder) { b.WriteString("JSON_VALID(").Ident(c.Name).WriteByte(')') }) } } return b } // addIndex returns the querying for adding an index to MySQL. func (d *MySQL) addIndex(i *Index, table string) *sql.IndexBuilder { idx := sql.CreateIndex(i.Name).Table(table) if i.Unique { idx.Unique() } parts := indexParts(i) for _, c := range i.Columns { part, ok := parts[c.Name] if !ok || part == 0 { idx.Column(c.Name) } else { idx.Column(fmt.Sprintf("%s(%d)", idx.Builder.Quote(c.Name), part)) } } return idx } // dropIndex drops a MySQL index. func (d *MySQL) dropIndex(ctx context.Context, tx dialect.Tx, idx *Index, table string) error { query, args := idx.DropBuilder(table).Query() return tx.Exec(ctx, query, args, nil) } // prepare runs preparation work that needs to be done to apply the change-set. func (d *MySQL) prepare(ctx context.Context, tx dialect.Tx, change *changes, table string) error { for _, idx := range change.index.drop { switch n := len(idx.columns); { case n == 0: return fmt.Errorf("index %q has no columns", idx.Name) case n > 1: continue // not a foreign-key index. } var qr sql.Querier Switch: switch col, ok := change.dropColumn(idx.columns[0]); { // If both the index and the column need to be dropped, the foreign-key // constraint that is associated with them need to be dropped as well. case ok: names, err := d.fkNames(ctx, tx, table, col.Name) if err != nil { return err } if len(names) == 1 { qr = sql.AlterTable(table).DropForeignKey(names[0]) } // If the uniqueness was dropped from a foreign-key column, // create a "simple index" if no other index exist for it. case !ok && idx.Unique && len(idx.Columns) > 0: col := idx.Columns[0] for _, idx2 := range col.indexes { if idx2 != idx && len(idx2.columns) == 1 { break Switch } } names, err := d.fkNames(ctx, tx, table, col.Name) if err != nil { return err } if len(names) == 1 { qr = sql.CreateIndex(names[0]).Table(table).Columns(col.Name) } } if qr != nil { query, args := qr.Query() if err := tx.Exec(ctx, query, args, nil); err != nil { return err } } } return nil } // scanColumn scans the column information from MySQL column description. 
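
// A sketch (not part of the driver) of how addIndex above assembles a MySQL
// prefix index. The "`name`(191)" expression mirrors addIndex's fmt.Sprintf
// over the quoted column name and sub_part; the printed SQL is approximate:
func exampleMySQLPrefixIndex() {
	idx := sql.CreateIndex("users_name").Table("users")
	idx.Column("`name`(191)")
	query, _ := idx.Query()
	fmt.Println(query) // roughly: CREATE INDEX `users_name` ON `users`(`name`(191))
}
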
func (d *MySQL) scanColumn(c *Column, rows *sql.Rows) error { var ( nullable sql.NullString defaults sql.NullString numericPrecision sql.NullInt64 numericScale sql.NullInt64 ) if err := rows.Scan(&c.Name, &c.typ, &nullable, &c.Key, &defaults, &c.Attr, &sql.NullString{}, &sql.NullString{}, &numericPrecision, &numericScale); err != nil { return fmt.Errorf("scanning column description: %w", err) } c.Unique = c.UniqueKey() if nullable.Valid { c.Nullable = nullable.String == "YES" } if c.typ == "" { return fmt.Errorf("missing type information for column %q", c.Name) } parts, size, unsigned, err := parseColumn(c.typ) if err != nil { return err } switch parts[0] { case "mediumint", "int": c.Type = field.TypeInt32 if unsigned { c.Type = field.TypeUint32 } case "smallint": c.Type = field.TypeInt16 if unsigned { c.Type = field.TypeUint16 } case "bigint": c.Type = field.TypeInt64 if unsigned { c.Type = field.TypeUint64 } case "tinyint": switch { case size == 1: c.Type = field.TypeBool case unsigned: c.Type = field.TypeUint8 default: c.Type = field.TypeInt8 } case "double", "float": c.Type = field.TypeFloat64 case "numeric", "decimal": c.Type = field.TypeFloat64 // If precision is specified then we should take that into account. if numericPrecision.Valid { schemaType := fmt.Sprintf("%s(%d,%d)", parts[0], numericPrecision.Int64, numericScale.Int64) c.SchemaType = map[string]string{dialect.MySQL: schemaType} } case "time", "timestamp", "date", "datetime": c.Type = field.TypeTime // The mapping from schema defaults to database // defaults is not supported for TypeTime fields. defaults = sql.NullString{} case "tinyblob": c.Size = math.MaxUint8 c.Type = field.TypeBytes case "blob": c.Size = math.MaxUint16 c.Type = field.TypeBytes case "mediumblob": c.Size = 1<<24 - 1 c.Type = field.TypeBytes case "longblob": c.Size = math.MaxUint32 c.Type = field.TypeBytes case "binary", "varbinary": c.Type = field.TypeBytes c.Size = size case "varchar": c.Type = field.TypeString c.Size = size case "text": c.Size = math.MaxUint16 c.Type = field.TypeString case "mediumtext": c.Size = 1<<24 - 1 c.Type = field.TypeString case "longtext": c.Size = math.MaxInt32 c.Type = field.TypeString case "json": c.Type = field.TypeJSON case "enum": c.Type = field.TypeEnum // Parse the enum values according to the MySQL format. // github.com/mysql/mysql-server/blob/8.0/sql/field.cc#Field_enum::sql_type values := strings.TrimSuffix(strings.TrimPrefix(c.typ, "enum("), ")") if values == "" { return fmt.Errorf("mysql: unexpected enum type: %q", c.typ) } parts := strings.Split(values, "','") for i := range parts { c.Enums = append(c.Enums, strings.Trim(parts[i], "'")) } case "char": c.Type = field.TypeOther // UUID field has length of 36 characters (32 alphanumeric characters and 4 hyphens). if size == 36 { c.Type = field.TypeUUID } case "point", "geometry", "linestring", "polygon": c.Type = field.TypeOther default: return fmt.Errorf("unknown column type %q for version %q", parts[0], d.version) } if defaults.Valid { return c.ScanDefault(defaults.String) } return nil } // scanIndexes scans sql.Rows into an Indexes list. The query for returning the rows, // should return the following 5 columns: INDEX_NAME, COLUMN_NAME, SUB_PART, NON_UNIQUE, // SEQ_IN_INDEX. SEQ_IN_INDEX specifies the position of the column in the index columns. 
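
// scanColumn above delegates type parsing to parseColumn (defined later in
// this file). A quick illustration of its contract (a sketch, not a test):
func exampleParseColumn() {
	parts, size, unsigned, _ := parseColumn("int(10) unsigned")
	fmt.Println(parts[0], size, unsigned) // int 10 true
	parts, size, unsigned, _ = parseColumn("varchar(255)")
	fmt.Println(parts[0], size, unsigned) // varchar 255 false
}
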
func (d *MySQL) scanIndexes(rows *sql.Rows, t *Table) (Indexes, error) { var ( i Indexes names = make(map[string]*Index) ) for rows.Next() { var ( name string column string nonuniq bool seqindex int subpart sql.NullInt64 ) if err := rows.Scan(&name, &column, &subpart, &nonuniq, &seqindex); err != nil { return nil, fmt.Errorf("scanning index description: %w", err) } // Skip primary keys. if name == "PRIMARY" { c, ok := t.column(column) if !ok { return nil, fmt.Errorf("missing primary-key column: %q", column) } t.PrimaryKey = append(t.PrimaryKey, c) continue } idx, ok := names[name] if !ok { idx = &Index{Name: name, Unique: !nonuniq, Annotation: &entsql.IndexAnnotation{}} i = append(i, idx) names[name] = idx } idx.columns = append(idx.columns, column) if subpart.Int64 > 0 { if idx.Annotation.PrefixColumns == nil { idx.Annotation.PrefixColumns = make(map[string]uint) } idx.Annotation.PrefixColumns[column] = uint(subpart.Int64) } } if err := rows.Err(); err != nil { return nil, err } return i, nil } // isImplicitIndex reports if the index was created implicitly for the unique column. func (d *MySQL) isImplicitIndex(idx *Index, col *Column) bool { // We execute `CHANGE COLUMN` on older versions of MySQL (<8.0), which // auto create the new index. The old one, will be dropped in `changeSet`. if compareVersions(d.version, "8.0.0") >= 0 { return idx.Name == col.Name && col.Unique } return false } // renameColumn returns the statement for renaming a column in // MySQL based on its version. func (d *MySQL) renameColumn(t *Table, old, new *Column) sql.Querier { q := sql.AlterTable(t.Name) if compareVersions(d.version, "8.0.0") >= 0 { return q.RenameColumn(old.Name, new.Name) } return q.ChangeColumn(old.Name, d.addColumn(new)) } // renameIndex returns the statement for renaming an index. func (d *MySQL) renameIndex(t *Table, old, new *Index) sql.Querier { q := sql.AlterTable(t.Name) if compareVersions(d.version, "5.7.0") >= 0 { return q.RenameIndex(old.Name, new.Name) } return q.DropIndex(old.Name).AddIndex(new.Builder(t.Name)) } // matchSchema returns the predicate for matching table schema. func (d *MySQL) matchSchema(columns ...string) *sql.Predicate { column := "TABLE_SCHEMA" if len(columns) > 0 { column = columns[0] } if d.schema != "" { return sql.EQ(column, d.schema) } return sql.EQ(column, sql.Raw("(SELECT DATABASE())")) } // tables returns the query for getting the in the schema. func (d *MySQL) tables() sql.Querier { return sql.Select("TABLE_NAME"). From(sql.Table("TABLES").Schema("INFORMATION_SCHEMA")). Where(d.matchSchema()) } // alterColumns returns the queries for applying the columns change-set. func (d *MySQL) alterColumns(table string, add, modify, drop []*Column) sql.Queries { b := sql.Dialect(dialect.MySQL).AlterTable(table) for _, c := range add { b.AddColumn(d.addColumn(c)) } for _, c := range modify { b.ModifyColumn(d.addColumn(c)) } for _, c := range drop { b.DropColumn(sql.Dialect(dialect.MySQL).Column(c.Name)) } if len(b.Queries) == 0 { return nil } return sql.Queries{b} } // normalizeJSON normalize MariaDB longtext columns to type JSON. func (d *MySQL) normalizeJSON(ctx context.Context, tx dialect.Tx, t *Table) error { columns := make(map[string]*Column) for _, c := range t.Columns { if c.typ == "longtext" { columns[c.Name] = c } } if len(columns) == 0 { return nil } rows := &sql.Rows{} query, args := sql.Select("CONSTRAINT_NAME"). From(sql.Table("CHECK_CONSTRAINTS").Schema("INFORMATION_SCHEMA")). 
Where(sql.And( d.matchSchema("CONSTRAINT_SCHEMA"), sql.EQ("TABLE_NAME", t.Name), sql.Like("CHECK_CLAUSE", "json_valid(%)"), )). Query() if err := tx.Query(ctx, query, args, rows); err != nil { return fmt.Errorf("mysql: query table constraints %w", err) } // Call Close in cases of failures (Close is idempotent). defer rows.Close() names := make([]string, 0, len(columns)) if err := sql.ScanSlice(rows, &names); err != nil { return fmt.Errorf("mysql: scan table constraints: %w", err) } if err := rows.Err(); err != nil { return err } if err := rows.Close(); err != nil { return err } for _, name := range names { c, ok := columns[name] if ok { c.Type = field.TypeJSON } } return nil } // mariadb reports if the migration runs on MariaDB and returns the semver string. func (d *MySQL) mariadb() (string, bool) { idx := strings.Index(d.version, "MariaDB") if idx == -1 { return "", false } return d.version[:idx-1], true } // parseColumn returns column parts, size and signed-info from a MySQL type. func parseColumn(typ string) (parts []string, size int64, unsigned bool, err error) { switch parts = strings.FieldsFunc(typ, func(r rune) bool { return r == '(' || r == ')' || r == ' ' || r == ',' }); parts[0] { case "tinyint", "smallint", "mediumint", "int", "bigint": switch { case len(parts) == 2 && parts[1] == "unsigned": // int unsigned unsigned = true case len(parts) == 3: // int(10) unsigned unsigned = true fallthrough case len(parts) == 2: // int(10) size, err = strconv.ParseInt(parts[1], 10, 0) } case "varbinary", "varchar", "char", "binary": if len(parts) > 1 { size, err = strconv.ParseInt(parts[1], 10, 64) } } if err != nil { return parts, size, unsigned, fmt.Errorf("converting %s size to int: %w", parts[0], err) } return parts, size, unsigned, nil } // fkNames returns the foreign-key names of a column. func (d *MySQL) fkNames(ctx context.Context, tx dialect.Tx, table, column string) ([]string, error) { query, args := sql.Select("CONSTRAINT_NAME").From(sql.Table("KEY_COLUMN_USAGE").Schema("INFORMATION_SCHEMA")). Where(sql.And( sql.EQ("TABLE_NAME", table), sql.EQ("COLUMN_NAME", column), // NULL for unique and primary-key constraints. sql.NotNull("POSITION_IN_UNIQUE_CONSTRAINT"), d.matchSchema(), )). Query() var ( names []string rows = &sql.Rows{} ) if err := tx.Query(ctx, query, args, rows); err != nil { return nil, fmt.Errorf("mysql: reading constraint names %w", err) } defer rows.Close() if err := sql.ScanSlice(rows, &names); err != nil { return nil, err } return names, nil } // defaultSize returns the default size for MySQL/MariaDB varchar type // based on column size, charset and table indexes, in order to avoid // index prefix key limit (767) for older versions of MySQL/MariaDB. func (d *MySQL) defaultSize(c *Column) int64 { size := DefaultStringLen version, checked := d.version, "5.7.0" if v, ok := d.mariadb(); ok { version, checked = v, "10.2.2" } switch { // Version is >= 5.7 for MySQL, or >= 10.2.2 for MariaDB. case compareVersions(version, checked) != -1: // Column is non-unique, or not part of any index (reaching // the error 1071). case !c.Unique && len(c.indexes) == 0 && !c.PrimaryKey(): default: size = 191 } return size } // needsConversion reports if column "old" needs to be converted // (by table altering) to column "new". func (d *MySQL) needsConversion(old, new *Column) bool { return d.cType(old) != d.cType(new) } // indexModified used by the migration differ to check if the index was modified. 
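
// The 191 default in defaultSize above comes from the InnoDB index-prefix
// limit on older MySQL/MariaDB: at most 767 bytes per key part, with utf8mb4
// reserving 4 bytes per character, so 767/4 = 191 characters (a sanity
// check, not driver code):
func exampleIndexPrefixLimit() {
	fmt.Println(767 / 4) // 191
}
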
func (d *MySQL) indexModified(old, new *Index) bool {
	oldParts, newParts := indexParts(old), indexParts(new)
	if len(oldParts) != len(newParts) {
		return true
	}
	for column, oldPart := range oldParts {
		newPart, ok := newParts[column]
		if !ok || oldPart != newPart {
			return true
		}
	}
	return false
}

// indexParts returns a map holding the sub_part mapping if it exists.
func indexParts(idx *Index) map[string]uint {
	parts := make(map[string]uint)
	if idx.Annotation == nil {
		return parts
	}
	// If a prefix (without a name) was defined on the
	// annotation, map it to the single column index.
	if idx.Annotation.Prefix > 0 && len(idx.Columns) == 1 {
		parts[idx.Columns[0].Name] = idx.Annotation.Prefix
	}
	for column, part := range idx.Annotation.PrefixColumns {
		parts[column] = part
	}
	return parts
}

// Atlas integration.

func (d *MySQL) atOpen(conn dialect.ExecQuerier) (migrate.Driver, error) {
	return mysql.Open(&db{ExecQuerier: conn})
}

func (d *MySQL) atTable(t1 *Table, t2 *schema.Table) {
	t2.SetCharset("utf8mb4").SetCollation("utf8mb4_bin")
	if t1.Annotation == nil {
		return
	}
	if charset := t1.Annotation.Charset; charset != "" {
		t2.SetCharset(charset)
	}
	if collate := t1.Annotation.Collation; collate != "" {
		t2.SetCollation(collate)
	}
	if opts := t1.Annotation.Options; opts != "" {
		t2.AddAttrs(&mysql.CreateOptions{
			V: opts,
		})
	}
	// Check if the connected database supports the CHECK clause.
	// For MySQL, it is >= "8.0.16", and for MariaDB >= "10.2.1".
	v1, v2 := d.version, "8.0.16"
	if v, ok := d.mariadb(); ok {
		v1, v2 = v, "10.2.1"
	}
	if compareVersions(v1, v2) >= 0 {
		setAtChecks(t1, t2)
	}
}

func (d *MySQL) atTypeC(c1 *Column, c2 *schema.Column) error {
	if c1.SchemaType != nil && c1.SchemaType[dialect.MySQL] != "" {
		t, err := mysql.ParseType(strings.ToLower(c1.SchemaType[dialect.MySQL]))
		if err != nil {
			return err
		}
		c2.Type.Type = t
		return nil
	}
	var t schema.Type
	switch c1.Type {
	case field.TypeBool:
		t = &schema.BoolType{T: "boolean"}
	case field.TypeInt8:
		t = &schema.IntegerType{T: mysql.TypeTinyInt}
	case field.TypeUint8:
		t = &schema.IntegerType{T: mysql.TypeTinyInt, Unsigned: true}
	case field.TypeInt16:
		t = &schema.IntegerType{T: mysql.TypeSmallInt}
	case field.TypeUint16:
		t = &schema.IntegerType{T: mysql.TypeSmallInt, Unsigned: true}
	case field.TypeInt32:
		t = &schema.IntegerType{T: mysql.TypeInt}
	case field.TypeUint32:
		t = &schema.IntegerType{T: mysql.TypeInt, Unsigned: true}
	case field.TypeInt, field.TypeInt64:
		t = &schema.IntegerType{T: mysql.TypeBigInt}
	case field.TypeUint, field.TypeUint64:
		t = &schema.IntegerType{T: mysql.TypeBigInt, Unsigned: true}
	case field.TypeBytes:
		size := int64(math.MaxUint16)
		if c1.Size > 0 {
			size = c1.Size
		}
		switch {
		case size <= math.MaxUint8:
			t = &schema.BinaryType{T: mysql.TypeTinyBlob}
		case size <= math.MaxUint16:
			t = &schema.BinaryType{T: mysql.TypeBlob}
		case size < 1<<24:
			t = &schema.BinaryType{T: mysql.TypeMediumBlob}
		case size <= math.MaxUint32:
			t = &schema.BinaryType{T: mysql.TypeLongBlob}
		}
	case field.TypeJSON:
		t = &schema.JSONType{T: mysql.TypeJSON}
		if compareVersions(d.version, "5.7.8") == -1 {
			t = &schema.BinaryType{T: mysql.TypeLongBlob}
		}
	case field.TypeString:
		size := c1.Size
		if size == 0 {
			size = d.defaultSize(c1)
		}
		switch {
		case c1.typ == "tinytext", c1.typ == "text":
			t = &schema.StringType{T: c1.typ}
		case size <= math.MaxUint16:
			t = &schema.StringType{T: mysql.TypeVarchar, Size: int(size)}
		case size == 1<<24-1:
			t = &schema.StringType{T: mysql.TypeMediumText}
		default:
			t = &schema.StringType{T: mysql.TypeLongText}
		}
	case field.TypeFloat32, field.TypeFloat64:
		t =
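			// Both float kinds map to a single MySQL type: scanTypeOr
			// presumably returns the type scanned from the database when one
			// exists, falling back to the given default (double) otherwise,
			// so existing float columns keep their current type.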
			&schema.FloatType{T: c1.scanTypeOr(mysql.TypeDouble)}
	case field.TypeTime:
		t = &schema.TimeType{T: c1.scanTypeOr(mysql.TypeTimestamp)}
		// In MariaDB or in MySQL < v8.0.2, the TIMESTAMP column has both `DEFAULT CURRENT_TIMESTAMP`
		// and `ON UPDATE CURRENT_TIMESTAMP` if neither is specified explicitly. This behavior is
		// suppressed if the column is defined with a `DEFAULT` clause or with the `NULL` attribute.
		if _, maria := d.mariadb(); maria || compareVersions(d.version, "8.0.2") == -1 && c1.Default == nil {
			c2.SetNull(c1.Attr == "")
		}
	case field.TypeEnum:
		t = &schema.EnumType{T: mysql.TypeEnum, Values: c1.Enums}
	case field.TypeUUID:
		// "CHAR(X) BINARY" is treated as "CHAR(X) COLLATE latin1_bin" in MySQL < 8,
		// and as "CHAR(X) COLLATE utf8mb4_bin" in MySQL >= 8. However, we already set
		// the table collation to utf8mb4_bin, so set it explicitly on the column as well.
		t = &schema.StringType{T: mysql.TypeChar, Size: 36}
		c2.SetCollation("utf8mb4_bin")
	default:
		typ, err := mysql.ParseType(strings.ToLower(c1.typ))
		if err != nil {
			return err
		}
		t = typ
	}
	c2.Type.Type = t
	return nil
}

func (d *MySQL) atUniqueC(t1 *Table, c1 *Column, t2 *schema.Table, c2 *schema.Column) {
	// For UNIQUE columns, MySQL creates an implicit index named
	// after the column, with a running suffix in case the name is
	// already taken (<name>, <name>_2, <name>_3, ...).
	for _, idx := range t1.Indexes {
		// Index is also defined explicitly, and will be added in atIndexes.
		if idx.Unique && d.atImplicitIndexName(idx, c1) {
			return
		}
	}
	t2.AddIndexes(schema.NewUniqueIndex(c1.Name).AddColumns(c2))
}

func (d *MySQL) atIncrementC(_ *schema.Table, c *schema.Column) {
	c.AddAttrs(&mysql.AutoIncrement{})
}

func (d *MySQL) atIncrementT(t *schema.Table, v int64) {
	t.AddAttrs(&mysql.AutoIncrement{V: v})
}

func (d *MySQL) atImplicitIndexName(idx *Index, c1 *Column) bool {
	if idx.Name == c1.Name {
		return true
	}
	if !strings.HasPrefix(idx.Name, c1.Name+"_") {
		return false
	}
	// Trim the column prefix and check for a running suffix.
	i, err := strconv.ParseInt(strings.TrimPrefix(idx.Name, c1.Name+"_"), 10, 64)
	return err == nil && i > 1
}

func (d *MySQL) atIndex(idx1 *Index, t2 *schema.Table, idx2 *schema.Index) error {
	prefix := indexParts(idx1)
	for _, c1 := range idx1.Columns {
		c2, ok := t2.Column(c1.Name)
		if !ok {
			return fmt.Errorf("unexpected index %q column: %q", idx1.Name, c1.Name)
		}
		part := &schema.IndexPart{C: c2}
		if v, ok := prefix[c1.Name]; ok {
			part.AddAttrs(&mysql.SubPart{Len: int(v)})
		}
		idx2.AddParts(part)
	}
	if t, ok := indexType(idx1, dialect.MySQL); ok {
		idx2.AddAttrs(&mysql.IndexType{T: t})
	}
	return nil
}

func indexType(idx *Index, d string) (string, bool) {
	ant := idx.Annotation
	if ant == nil {
		return "", false
	}
	if ant.Types != nil && ant.Types[d] != "" {
		return ant.Types[d], true
	}
	if ant.Type != "" {
		return ant.Type, true
	}
	return "", false
}

func (MySQL) atTypeRangeSQL(ts ...string) string {
	for i := range ts {
		ts[i] = fmt.Sprintf("('%s')", ts[i])
	}
	return fmt.Sprintf("INSERT INTO `%s` (`type`) VALUES %s", TypeTable, strings.Join(ts, ", "))
}
ent-0.11.3/dialect/sql/schema/mysql_test.go000066400000000000000000002213671431500740500205730ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved.
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.
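
// The tests below drive the migration against a sqlmock connection: every
// statement the migration is expected to execute is declared up front with
// ExpectQuery/ExpectExec, and the mock fails the test on any unexpected or
// out-of-order statement.
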
package schema import ( "context" "math" "regexp" "strings" "testing" "entgo.io/ent/dialect" "entgo.io/ent/dialect/entsql" "entgo.io/ent/dialect/sql" "entgo.io/ent/schema/field" "github.com/DATA-DOG/go-sqlmock" "github.com/stretchr/testify/require" ) func TestMySQL_Create(t *testing.T) { tests := []struct { name string tables []*Table options []MigrateOption before func(mysqlMock) wantErr bool }{ { name: "tx failed", before: func(mock mysqlMock) { mock.ExpectBegin(). WillReturnError(sqlmock.ErrCancelled) }, wantErr: true, }, { name: "no tables", before: func(mock mysqlMock) { mock.start("5.7.23") mock.ExpectCommit() }, }, { name: "create new table", tables: []*Table{ { Name: "users", PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Nullable: true}, {Name: "age", Type: field.TypeInt}, {Name: "doc", Type: field.TypeJSON, Nullable: true}, {Name: "enums", Type: field.TypeEnum, Enums: []string{"a", "b"}}, {Name: "uuid", Type: field.TypeUUID, Nullable: true}, {Name: "ts", Type: field.TypeTime}, {Name: "ts_default", Type: field.TypeTime, Default: "CURRENT_TIMESTAMP"}, {Name: "datetime", Type: field.TypeTime, SchemaType: map[string]string{dialect.MySQL: "datetime"}, Default: "CURRENT_TIMESTAMP"}, {Name: "decimal", Type: field.TypeFloat32, SchemaType: map[string]string{dialect.MySQL: "decimal(6,2)"}}, {Name: "unsigned decimal", Type: field.TypeFloat32, SchemaType: map[string]string{dialect.MySQL: "decimal(6,2) unsigned"}}, {Name: "float", Type: field.TypeFloat32, SchemaType: map[string]string{dialect.MySQL: "float"}, Default: "0"}, }, Annotation: &entsql.Annotation{ Charset: "utf8", Collation: "utf8_general_ci", Options: "ENGINE = INNODB", Check: "price > 0", Checks: map[string]string{ "valid_age": "age > 0", "valid_name": "name <> ''", }, }, }, }, before: func(mock mysqlMock) { mock.start("5.7.8") mock.tableExists("users", false) mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `users`(`id` bigint AUTO_INCREMENT NOT NULL, `name` varchar(255) NULL, `age` bigint NOT NULL, `doc` json NULL, `enums` enum('a', 'b') NOT NULL, `uuid` char(36) binary NULL, `ts` timestamp NULL, `ts_default` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP, `datetime` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP, `decimal` decimal(6,2) NOT NULL, `unsigned decimal` decimal(6,2) unsigned NOT NULL, `float` float NOT NULL DEFAULT '0', PRIMARY KEY(`id`), CHECK (price > 0), CONSTRAINT `valid_age` CHECK (age > 0), CONSTRAINT `valid_name` CHECK (name <> '')) CHARACTER SET utf8 COLLATE utf8_general_ci ENGINE = INNODB")). 
WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "create new table with specific field collation", tables: []*Table{ { Name: "users", PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Nullable: true}, {Name: "address", Type: field.TypeString, Nullable: true, Collation: "utf8_unicode_ci"}, {Name: "age", Type: field.TypeInt}, {Name: "doc", Type: field.TypeJSON, Nullable: true}, {Name: "enums", Type: field.TypeEnum, Enums: []string{"a", "b"}}, {Name: "uuid", Type: field.TypeUUID, Nullable: true}, {Name: "datetime", Type: field.TypeTime, SchemaType: map[string]string{dialect.MySQL: "datetime"}, Default: "CURRENT_TIMESTAMP"}, {Name: "decimal", Type: field.TypeFloat32, SchemaType: map[string]string{dialect.MySQL: "decimal(6,2)"}}, }, Annotation: &entsql.Annotation{ Charset: "utf8", Collation: "utf8_general_ci", Options: "ENGINE = INNODB", }, }, }, before: func(mock mysqlMock) { mock.start("5.7.33") mock.tableExists("users", false) mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `users`(`id` bigint AUTO_INCREMENT NOT NULL, `name` varchar(255) NULL, `address` varchar(255) NULL COLLATE utf8_unicode_ci, `age` bigint NOT NULL, `doc` json NULL, `enums` enum('a', 'b') NOT NULL, `uuid` char(36) binary NULL, `datetime` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP, `decimal` decimal(6,2) NOT NULL, PRIMARY KEY(`id`)) CHARACTER SET utf8 COLLATE utf8_general_ci ENGINE = INNODB")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "create new table 5.6", tables: []*Table{ { Name: "users", PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "age", Type: field.TypeInt}, {Name: "name", Type: field.TypeString, Unique: true}, {Name: "doc", Type: field.TypeJSON, Nullable: true}, }, }, }, before: func(mock mysqlMock) { mock.start("5.6.35") mock.tableExists("users", false) mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `users`(`id` bigint AUTO_INCREMENT NOT NULL, `age` bigint NOT NULL, `name` varchar(191) UNIQUE NOT NULL, `doc` longblob NULL, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "create new table with foreign key", tables: func() []*Table { var ( c1 = []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Nullable: true}, {Name: "created_at", Type: field.TypeTime}, } c2 = []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString}, {Name: "owner_id", Type: field.TypeInt, Nullable: true}, } t1 = &Table{ Name: "users", Columns: c1, PrimaryKey: c1[0:1], } t2 = &Table{ Name: "pets", Columns: c2, PrimaryKey: c2[0:1], ForeignKeys: []*ForeignKey{ { Symbol: "pets_owner", Columns: c2[2:], RefTable: t1, RefColumns: c1[0:1], OnDelete: Cascade, }, }, } ) return []*Table{t1, t2} }(), before: func(mock mysqlMock) { mock.start("5.7.23") mock.tableExists("users", false) mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `users`(`id` bigint AUTO_INCREMENT NOT NULL, `name` varchar(255) NULL, `created_at` timestamp NULL, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). 
WillReturnResult(sqlmock.NewResult(0, 1)) mock.tableExists("pets", false) mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `pets`(`id` bigint AUTO_INCREMENT NOT NULL, `name` varchar(255) NOT NULL, `owner_id` bigint NULL, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.fkExists("pets_owner", false) mock.ExpectExec(escape("ALTER TABLE `pets` ADD CONSTRAINT `pets_owner` FOREIGN KEY(`owner_id`) REFERENCES `users`(`id`) ON DELETE CASCADE")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "create new table with foreign key disabled", options: []MigrateOption{ WithForeignKeys(false), }, tables: func() []*Table { var ( c1 = []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Nullable: true}, {Name: "created_at", Type: field.TypeTime}, } c2 = []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString}, {Name: "owner_id", Type: field.TypeInt, Nullable: true}, } t1 = &Table{ Name: "users", Columns: c1, PrimaryKey: c1[0:1], } t2 = &Table{ Name: "pets", Columns: c2, PrimaryKey: c2[0:1], ForeignKeys: []*ForeignKey{ { Symbol: "pets_owner", Columns: c2[2:], RefTable: t1, RefColumns: c1[0:1], OnDelete: Cascade, }, }, } ) return []*Table{t1, t2} }(), before: func(mock mysqlMock) { mock.start("5.7.23") mock.tableExists("users", false) mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `users`(`id` bigint AUTO_INCREMENT NOT NULL, `name` varchar(255) NULL, `created_at` timestamp NULL, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.tableExists("pets", false) mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `pets`(`id` bigint AUTO_INCREMENT NOT NULL, `name` varchar(255) NOT NULL, `owner_id` bigint NULL, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). 
WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "add columns to table", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Nullable: true}, {Name: "text", Type: field.TypeString, Nullable: true, Size: math.MaxInt32}, {Name: "mediumtext", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{dialect.MySQL: "mediumtext"}}, {Name: "uuid", Type: field.TypeUUID, Nullable: true}, {Name: "date", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{dialect.MySQL: "date"}}, {Name: "age", Type: field.TypeInt}, {Name: "tiny", Type: field.TypeInt8}, {Name: "tiny_unsigned", Type: field.TypeUint8}, {Name: "small", Type: field.TypeInt16}, {Name: "small_unsigned", Type: field.TypeUint16}, {Name: "big", Type: field.TypeInt64}, {Name: "big_unsigned", Type: field.TypeUint64}, {Name: "decimal", Type: field.TypeFloat64, SchemaType: map[string]string{dialect.MySQL: "decimal(6,2)"}}, {Name: "unsigned_decimal", Type: field.TypeFloat64, SchemaType: map[string]string{dialect.MySQL: "decimal(6,2) unsigned"}}, {Name: "ts", Type: field.TypeTime}, {Name: "timestamp", Type: field.TypeTime, SchemaType: map[string]string{dialect.MySQL: "TIMESTAMP"}, Default: "CURRENT_TIMESTAMP"}, {Name: "float", Type: field.TypeFloat32, SchemaType: map[string]string{dialect.MySQL: "float"}, Default: "0"}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock mysqlMock) { mock.start("8.0.19") mock.tableExists("users", true) mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}). AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil). AddRow("name", "varchar(255)", "YES", "YES", "NULL", "", "", "", nil, nil). AddRow("text", "longtext", "YES", "YES", "NULL", "", "", "", nil, nil). AddRow("mediumtext", "mediumtext", "YES", "YES", "NULL", "", "", "", nil, nil). AddRow("uuid", "char(36)", "YES", "YES", "NULL", "", "", "utf8mb4_bin", nil, nil). AddRow("date", "date", "YES", "YES", "NULL", "", "", "", nil, nil). // 8.0.19: new int column type formats AddRow("tiny", "tinyint", "NO", "YES", "NULL", "", "", "", nil, nil). AddRow("tiny_unsigned", "tinyint unsigned", "NO", "YES", "NULL", "", "", "", nil, nil). AddRow("small", "smallint", "NO", "YES", "NULL", "", "", "", nil, nil). AddRow("small_unsigned", "smallint unsigned", "NO", "YES", "NULL", "", "", "", nil, nil). AddRow("big", "bigint", "NO", "YES", "NULL", "", "", "", nil, nil). AddRow("big_unsigned", "bigint unsigned", "NO", "YES", "NULL", "", "", "", nil, nil). AddRow("decimal", "decimal(6,2)", "NO", "YES", "NULL", "", "", "", nil, nil). AddRow("unsigned_decimal", "decimal(6,2) unsigned", "NO", "YES", "NULL", "", "", "", nil, nil). AddRow("timestamp", "timestamp", "NO", "NO", "CURRENT_TIMESTAMP", "DEFAULT_GENERATED on update CURRENT_TIMESTAMP", "", "", nil, nil). 
AddRow("float", "float", "NO", "NO", "0", "0", "", "", nil, nil)) mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}). AddRow("PRIMARY", "id", nil, "0", "1")) mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `age` bigint NOT NULL, ADD COLUMN `ts` timestamp NOT NULL, MODIFY COLUMN `timestamp` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "enums", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Nullable: true}, {Name: "enums1", Type: field.TypeEnum, Enums: []string{"a", "b"}}, // add enum. {Name: "enums2", Type: field.TypeEnum, Enums: []string{"a"}}, // remove enum. {Name: "enums3", Type: field.TypeEnum, Enums: []string{"a", "b c"}}, // no changes. }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock mysqlMock) { mock.start("5.7.23") mock.tableExists("users", true) mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}). AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil). AddRow("name", "varchar(255)", "YES", "YES", "NULL", "", "", "", nil, nil). AddRow("enums1", "enum('a')", "YES", "NO", "NULL", "", "", "", nil, nil). AddRow("enums2", "enum('b', 'a')", "NO", "YES", "NULL", "", "", "", nil, nil). AddRow("enums3", "enum('a', 'b c')", "NO", "YES", "NULL", "", "", "", nil, nil)) mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}). AddRow("PRIMARY", "id", nil, "0", "1")) mock.ExpectExec(escape("ALTER TABLE `users` MODIFY COLUMN `enums1` enum('a', 'b') NOT NULL, MODIFY COLUMN `enums2` enum('a') NOT NULL")). 
WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "datetime and timestamp", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{dialect.MySQL: "datetime"}, Nullable: true}, {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{dialect.MySQL: "datetime"}, Nullable: true}, {Name: "deleted_at", Type: field.TypeTime, Nullable: true}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock mysqlMock) { mock.start("5.7.23") mock.tableExists("users", true) mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}). AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil). AddRow("created_at", "datetime", "NO", "YES", "NULL", "", "", "", nil, nil). AddRow("updated_at", "timestamp", "NO", "YES", "NULL", "", "", "", nil, nil). AddRow("deleted_at", "datetime", "NO", "YES", "NULL", "", "", "", nil, nil)) mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}). AddRow("PRIMARY", "id", nil, "0", "1")) mock.ExpectExec(escape("ALTER TABLE `users` MODIFY COLUMN `updated_at` datetime NULL, MODIFY COLUMN `deleted_at` timestamp NULL")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "add int column with default value to table", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Nullable: true}, {Name: "age", Type: field.TypeInt, Default: 10}, {Name: "doc", Type: field.TypeJSON, Nullable: true}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock mysqlMock) { mock.start("5.6.0") mock.tableExists("users", true) mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}). AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil). AddRow("name", "varchar(255)", "YES", "YES", "NULL", "", "", "", nil, nil). 
AddRow("doc", "longblob", "YES", "YES", "NULL", "", "", "", nil, nil)) mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}). AddRow("PRIMARY", "id", nil, "0", "1")) mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `age` bigint NOT NULL DEFAULT 10")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "add blob columns", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "tiny", Type: field.TypeBytes, Size: 100}, {Name: "blob", Type: field.TypeBytes, Size: 1e3}, {Name: "medium", Type: field.TypeBytes, Size: 1e5}, {Name: "long", Type: field.TypeBytes, Size: 1e8}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock mysqlMock) { mock.start("5.7.23") mock.tableExists("users", true) mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}). AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil)) mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}). AddRow("PRIMARY", "id", nil, "0", "1")) mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `tiny` tinyblob NOT NULL, ADD COLUMN `blob` blob NOT NULL, ADD COLUMN `medium` mediumblob NOT NULL, ADD COLUMN `long` longblob NOT NULL")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "add binary column", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "binary", Type: field.TypeBytes, Size: 20, SchemaType: map[string]string{dialect.MySQL: "binary(20)"}}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock mysqlMock) { mock.start("8.0.23") mock.tableExists("users", true) mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}). 
AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil)) mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}). AddRow("PRIMARY", "id", nil, "0", "1")) mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `binary` binary(20) NOT NULL")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "accept varbinary columns", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "tiny", Type: field.TypeBytes, Size: 100}, {Name: "medium", Type: field.TypeBytes, Size: math.MaxUint32}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock mysqlMock) { mock.start("5.7.23") mock.tableExists("users", true) mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}). AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil). AddRow("tiny", "varbinary(255)", "NO", "YES", "NULL", "", "", "", nil, nil). AddRow("medium", "varbinary(255)", "NO", "YES", "NULL", "", "", "", nil, nil)) mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}). AddRow("PRIMARY", "id", nil, "0", "1")) mock.ExpectExec(escape("ALTER TABLE `users` MODIFY COLUMN `medium` longblob NOT NULL")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "add float column with default value to table", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Nullable: true}, {Name: "age", Type: field.TypeFloat64, Default: 10.1}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock mysqlMock) { mock.start("5.7.23") mock.tableExists("users", true) mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}). AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil). 
AddRow("name", "varchar(255)", "NO", "YES", "NULL", "", "", "", nil, nil)) mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}). AddRow("PRIMARY", "id", nil, "0", "1")) mock.ExpectExec("ALTER TABLE `users` ADD COLUMN `age` double NOT NULL DEFAULT 10.1"). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "add bool column with default value", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Nullable: true}, {Name: "age", Type: field.TypeBool, Default: true}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock mysqlMock) { mock.start("5.7.23") mock.tableExists("users", true) mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}). AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil). AddRow("name", "varchar(255)", "NO", "YES", "NULL", "", "", "", nil, nil)) mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}). AddRow("PRIMARY", "id", nil, "0", "1")) mock.ExpectExec("ALTER TABLE `users` ADD COLUMN `age` boolean NOT NULL DEFAULT true"). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "add string column with default value", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Nullable: true}, {Name: "nick", Type: field.TypeString, Default: "unknown"}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock mysqlMock) { mock.start("5.7.23") mock.tableExists("users", true) mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}). AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil). 
AddRow("name", "varchar(255)", "YES", "YES", "NULL", "", "", "", nil, nil)) mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}). AddRow("PRIMARY", "id", nil, "0", "1")) mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `nick` varchar(255) NOT NULL DEFAULT 'unknown'")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "add column with unsupported default value", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Nullable: true}, {Name: "nick", Type: field.TypeString, Size: 1 << 17, Default: "unknown"}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock mysqlMock) { mock.start("5.7.23") mock.tableExists("users", true) mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}). AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil). AddRow("name", "varchar(255)", "YES", "YES", "NULL", "", "", "", nil, nil)) mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}). AddRow("PRIMARY", "id", nil, "0", "1")) mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `nick` longtext NOT NULL")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "drop columns", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, options: []MigrateOption{WithDropColumn(true)}, before: func(mock mysqlMock) { mock.start("5.7.23") mock.tableExists("users", true) mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}). AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil). 
AddRow("name", "varchar(255)", "NO", "YES", "NULL", "", "", "", nil, nil)) mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}). AddRow("PRIMARY", "id", nil, "0", "1")) mock.ExpectExec(escape("ALTER TABLE `users` DROP COLUMN `name`")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "modify column to nullable", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "age", Type: field.TypeInt}, {Name: "name", Type: field.TypeString, Nullable: true}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock mysqlMock) { mock.start("5.7.23") mock.tableExists("users", true) mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}). AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil). AddRow("name", "varchar(255)", "NO", "YES", "NULL", "", "", "", nil, nil). AddRow("age", "bigint(20)", "NO", "NO", "NULL", "", "", "", nil, nil)) mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}). AddRow("PRIMARY", "id", nil, "0", "1")) mock.ExpectExec(escape("ALTER TABLE `users` MODIFY COLUMN `name` varchar(255) NULL")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "apply uniqueness on column", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "age", Type: field.TypeInt, Unique: true}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock mysqlMock) { mock.start("5.7.23") mock.tableExists("users", true) mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}). AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil). 
AddRow("age", "bigint(20)", "NO", "", "NULL", "", "", "", nil, nil)) mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}). AddRow("PRIMARY", "id", nil, "0", "1")) // create the unique index. mock.ExpectExec(escape("CREATE UNIQUE INDEX `age` ON `users`(`age`)")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "remove uniqueness from column without option", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "age", Type: field.TypeInt}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock mysqlMock) { mock.start("5.7.23") mock.tableExists("users", true) mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}). AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil). AddRow("age", "bigint(20)", "NO", "UNI", "NULL", "", "", "", nil, nil)) mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}). AddRow("PRIMARY", "id", nil, "0", "1"). AddRow("age", "age", nil, "0", "1")) mock.ExpectCommit() }, }, { name: "remove uniqueness from column with option", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "age", Type: field.TypeInt}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, options: []MigrateOption{WithDropIndex(true)}, before: func(mock mysqlMock) { mock.start("5.7.23") mock.tableExists("users", true) mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}). AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil). AddRow("age", "bigint(20)", "NO", "UNI", "NULL", "", "", "", nil, nil)) mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). WithArgs("users"). 
WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}). AddRow("PRIMARY", "id", nil, "0", "1"). AddRow("age", "age", nil, "0", "1")) // check if a foreign-key needs to be dropped. mock.ExpectQuery(escape("SELECT `CONSTRAINT_NAME` FROM `INFORMATION_SCHEMA`.`KEY_COLUMN_USAGE` WHERE `TABLE_NAME` = ? AND `COLUMN_NAME` = ? AND `POSITION_IN_UNIQUE_CONSTRAINT` IS NOT NULL AND `TABLE_SCHEMA` = (SELECT DATABASE())")). WithArgs("users", "age"). WillReturnRows(sqlmock.NewRows([]string{"CONSTRAINT_NAME"})) // drop the unique index. mock.ExpectExec(escape("DROP INDEX `age` ON `users`")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "increase index sub_part", tables: func() []*Table { t := &Table{ Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "text", Type: field.TypeString, Size: math.MaxInt32, Nullable: true}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, Indexes: []*Index{ {Name: "prefix_text", Annotation: &entsql.IndexAnnotation{Prefix: 100}}, }, } t.Indexes[0].Columns = t.Columns[1:] return []*Table{t} }(), options: []MigrateOption{WithDropIndex(true)}, before: func(mock mysqlMock) { mock.start("5.7.23") mock.tableExists("users", true) mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}). AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil). AddRow("text", "longtext", "YES", "NO", "NULL", "", "", "", nil, nil)) mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}). AddRow("PRIMARY", "id", nil, "0", "1"). AddRow("prefix_text", "text", "50", "0", "1")) mock.ExpectQuery(escape("SELECT `CONSTRAINT_NAME` FROM `INFORMATION_SCHEMA`.`KEY_COLUMN_USAGE` WHERE `TABLE_NAME` = ? AND `COLUMN_NAME` = ? AND `POSITION_IN_UNIQUE_CONSTRAINT` IS NOT NULL AND `TABLE_SCHEMA` = (SELECT DATABASE())")). WithArgs("users", "text"). WillReturnRows(sqlmock.NewRows([]string{"CONSTRAINT_NAME"})) // modify index by dropping and creating it. mock.ExpectExec(escape("DROP INDEX `prefix_text` ON `users`")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectExec(escape("CREATE INDEX `prefix_text` ON `users`(`text`(100))")). 
				WillReturnResult(sqlmock.NewResult(0, 1))
				mock.ExpectCommit()
			},
		},
		{
			name: "ignore foreign keys on index dropping",
			tables: []*Table{
				{
					Name: "users",
					Columns: []*Column{
						{Name: "id", Type: field.TypeInt, Increment: true},
						{Name: "parent_id", Type: field.TypeInt, Nullable: true},
					},
					PrimaryKey: []*Column{
						{Name: "id", Type: field.TypeInt, Increment: true},
					},
					ForeignKeys: []*ForeignKey{
						{
							Symbol: "parent_id",
							Columns: []*Column{
								{Name: "parent_id", Type: field.TypeInt, Nullable: true},
							},
						},
					},
				},
			},
			options: []MigrateOption{WithDropIndex(true)},
			before: func(mock mysqlMock) {
				mock.start("5.7.23")
				mock.tableExists("users", true)
				mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")).
					WithArgs("users").
					WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}).
						AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil).
						AddRow("parent_id", "bigint(20)", "YES", "NULL", "NULL", "", "", "", nil, nil))
				mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")).
					WithArgs("users").
					WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}).
						AddRow("PRIMARY", "id", nil, "0", "1").
						AddRow("old_index", "old", nil, "0", "1").
						AddRow("parent_id", "parent_id", nil, "0", "1"))
				// drop the unique index.
				mock.ExpectExec(escape("DROP INDEX `old_index` ON `users`")).
					WillReturnResult(sqlmock.NewResult(0, 1))
				// foreign key already exists.
				mock.fkExists("parent_id", true)
				mock.ExpectCommit()
			},
		},
		{
			name: "drop foreign key with column and index",
			tables: []*Table{
				{
					Name: "users",
					Columns: []*Column{
						{Name: "id", Type: field.TypeInt, Increment: true},
					},
					PrimaryKey: []*Column{
						{Name: "id", Type: field.TypeInt, Increment: true},
					},
				},
			},
			options: []MigrateOption{WithDropIndex(true), WithDropColumn(true)},
			before: func(mock mysqlMock) {
				mock.start("5.7.23")
				mock.tableExists("users", true)
				mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")).
					WithArgs("users").
					WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}).
						AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil).
						AddRow("parent_id", "bigint(20)", "YES", "NULL", "NULL", "", "", "", nil, nil))
				mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")).
					WithArgs("users").
					WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}).
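						// The `parent_id` index reported below is backed by a foreign key
						// (see the KEY_COLUMN_USAGE lookup that follows), so the constraint
						// must be dropped before the index and the column can be removed.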
AddRow("PRIMARY", "id", nil, "0", "1"). AddRow("parent_id", "parent_id", nil, "0", "1")) // check if a foreign-key needs to be dropped. mock.ExpectQuery(escape("SELECT `CONSTRAINT_NAME` FROM `INFORMATION_SCHEMA`.`KEY_COLUMN_USAGE` WHERE `TABLE_NAME` = ? AND `COLUMN_NAME` = ? AND `POSITION_IN_UNIQUE_CONSTRAINT` IS NOT NULL AND `TABLE_SCHEMA` = (SELECT DATABASE())")). WithArgs("users", "parent_id"). WillReturnRows(sqlmock.NewRows([]string{"CONSTRAINT_NAME"}).AddRow("users_parent_id")) mock.ExpectExec(escape("ALTER TABLE `users` DROP FOREIGN KEY `users_parent_id`")). WillReturnResult(sqlmock.NewResult(0, 1)) // drop the unique index. mock.ExpectExec(escape("DROP INDEX `parent_id` ON `users`")). WillReturnResult(sqlmock.NewResult(0, 1)) // drop the unique index. mock.ExpectExec(escape("ALTER TABLE `users` DROP COLUMN `parent_id`")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "create a new simple-index for the foreign-key", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "parent_id", Type: field.TypeInt, Nullable: true}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, options: []MigrateOption{WithDropIndex(true), WithDropColumn(true)}, before: func(mock mysqlMock) { mock.start("5.7.23") mock.tableExists("users", true) mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}). AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil). AddRow("parent_id", "bigint(20)", "YES", "NULL", "NULL", "", "", "", nil, nil)) mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}). AddRow("PRIMARY", "id", nil, "0", "1"). AddRow("parent_id", "parent_id", nil, "0", "1")) // check if there's a foreign-key that is associated with this index. mock.ExpectQuery(escape("SELECT `CONSTRAINT_NAME` FROM `INFORMATION_SCHEMA`.`KEY_COLUMN_USAGE` WHERE `TABLE_NAME` = ? AND `COLUMN_NAME` = ? AND `POSITION_IN_UNIQUE_CONSTRAINT` IS NOT NULL AND `TABLE_SCHEMA` = (SELECT DATABASE())")). WithArgs("users", "parent_id"). WillReturnRows(sqlmock.NewRows([]string{"CONSTRAINT_NAME"}).AddRow("users_parent_id")) // create a new index, to replace the old one (that needs to be dropped). mock.ExpectExec(escape("CREATE INDEX `users_parent_id` ON `users`(`parent_id`)")). WillReturnResult(sqlmock.NewResult(0, 1)) // drop the unique index. mock.ExpectExec(escape("DROP INDEX `parent_id` ON `users`")). 
WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "add edge to table", tables: func() []*Table { var ( c1 = []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Nullable: true}, {Name: "spouse_id", Type: field.TypeInt, Nullable: true}, } t1 = &Table{ Name: "users", Columns: c1, PrimaryKey: c1[0:1], ForeignKeys: []*ForeignKey{ { Symbol: "user_spouse" + strings.Repeat("_", 64), // super long fk. Columns: c1[2:], RefColumns: c1[0:1], OnDelete: Cascade, }, }, } ) t1.ForeignKeys[0].RefTable = t1 return []*Table{t1} }(), before: func(mock mysqlMock) { mock.start("5.7.23") mock.tableExists("users", true) mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}). AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil). AddRow("name", "varchar(255)", "YES", "YES", "NULL", "", "", "", nil, nil)) mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}). AddRow("PRIMARY", "id", nil, "0", "1")) mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `spouse_id` bigint NULL")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.fkExists("user_spouse_____________________390ed76f91d3c57cd3516e7690f621dc", false) mock.ExpectExec("ALTER TABLE `users` ADD CONSTRAINT `.{64}` FOREIGN KEY\\(`spouse_id`\\) REFERENCES `users`\\(`id`\\) ON DELETE CASCADE"). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "universal id for all tables", tables: []*Table{ NewTable("users").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}), NewTable("groups").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}), }, options: []MigrateOption{WithGlobalUniqueID(true)}, before: func(mock mysqlMock) { mock.start("5.7.23") mock.tableExists("ent_types", false) // create ent_types table. mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `ent_types`(`id` bigint unsigned AUTO_INCREMENT NOT NULL, `type` varchar(255) UNIQUE NOT NULL, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.tableExists("users", false) mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `users`(`id` bigint AUTO_INCREMENT NOT NULL, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). WillReturnResult(sqlmock.NewResult(0, 1)) // set users id range. mock.ExpectExec(escape("INSERT INTO `ent_types` (`type`) VALUES (?)")). WithArgs("users"). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectExec(escape("ALTER TABLE `users` AUTO_INCREMENT = 0")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectQuery(escape("SELECT COUNT(*) FROM `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). WithArgs("groups"). 
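				// Why 4294967296 below: with WithGlobalUniqueID(true), each type is
				// allocated an id range of 1<<32 by its insertion order in ent_types;
				// "users" is first (its range starts at 0) and "groups" second, so
				// its AUTO_INCREMENT starts at 1<<32 = 4294967296.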
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0)) mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `groups`(`id` bigint AUTO_INCREMENT NOT NULL, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). WillReturnResult(sqlmock.NewResult(0, 1)) // set groups id range. mock.ExpectExec(escape("INSERT INTO `ent_types` (`type`) VALUES (?)")). WithArgs("groups"). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectExec(escape("ALTER TABLE `groups` AUTO_INCREMENT = 4294967296")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "universal id for new tables", tables: []*Table{ NewTable("users").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}), NewTable("groups").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}), }, options: []MigrateOption{WithGlobalUniqueID(true)}, before: func(mock mysqlMock) { mock.start("5.7.23") mock.tableExists("ent_types", true) // query ent_types table. mock.ExpectQuery(escape("SELECT `type` FROM `ent_types` ORDER BY `id` ASC")). WillReturnRows(sqlmock.NewRows([]string{"type"}).AddRow("users")) mock.tableExists("users", true) // users table has no changes. mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}). AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil)) mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}). AddRow("PRIMARY", "id", nil, "0", "1")) // query groups table. mock.ExpectQuery(escape("SELECT COUNT(*) FROM `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). WithArgs("groups"). WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0)) mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `groups`(`id` bigint AUTO_INCREMENT NOT NULL, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). WillReturnResult(sqlmock.NewResult(0, 1)) // set groups id range. mock.ExpectExec(escape("INSERT INTO `ent_types` (`type`) VALUES (?)")). WithArgs("groups"). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectExec(escape("ALTER TABLE `groups` AUTO_INCREMENT = 4294967296")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "universal id for restored tables", tables: []*Table{ NewTable("users").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}), NewTable("groups").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}), }, options: []MigrateOption{WithGlobalUniqueID(true)}, before: func(mock mysqlMock) { mock.start("5.7.23") mock.tableExists("ent_types", true) // query ent_types table. mock.ExpectQuery(escape("SELECT `type` FROM `ent_types` ORDER BY `id` ASC")). 
WillReturnRows(sqlmock.NewRows([]string{"type"}).AddRow("users")) mock.tableExists("users", false) mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `users`(`id` bigint AUTO_INCREMENT NOT NULL, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). WillReturnResult(sqlmock.NewResult(0, 1)) // set users id range (without inserting to ent_types). mock.ExpectExec(escape("ALTER TABLE `users` AUTO_INCREMENT = 0")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.tableExists("groups", false) mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `groups`(`id` bigint AUTO_INCREMENT NOT NULL, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). WillReturnResult(sqlmock.NewResult(0, 1)) // set groups id range. mock.ExpectExec(escape("INSERT INTO `ent_types` (`type`) VALUES (?)")). WithArgs("groups"). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectExec(escape("ALTER TABLE `groups` AUTO_INCREMENT = 4294967296")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "universal id mismatch with ent_types", tables: []*Table{ NewTable("users").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}), }, options: []MigrateOption{WithGlobalUniqueID(true)}, before: func(mock mysqlMock) { mock.start("5.7.23") mock.tableExists("ent_types", true) // query ent_types table. mock.ExpectQuery(escape("SELECT `type` FROM `ent_types` ORDER BY `id` ASC")). WillReturnRows(sqlmock.NewRows([]string{"type"}). AddRow("deleted"). AddRow("users")) mock.tableExists("users", true) // users table has no changes. mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}). AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil)) mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}). AddRow("PRIMARY", "id", nil, "0", "1")) // query the auto-increment value. mock.ExpectQuery(escape("SELECT `AUTO_INCREMENT` FROM `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"AUTO_INCREMENT"}). AddRow(1)) // restore the auto-increment counter. mock.ExpectExec(escape("ALTER TABLE `users` AUTO_INCREMENT = 4294967296")). 
WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "no modify numeric column", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "price", Type: field.TypeFloat64, SchemaType: map[string]string{dialect.MySQL: "decimal(6,4)"}}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock mysqlMock) { mock.start("5.7.23") mock.tableExists("users", true) mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}). AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil). AddRow("price", "decimal(6,4)", "NO", "YES", "NULL", "", "", "", "6", "4")) mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}). AddRow("PRIMARY", "id", nil, "0", "1")) mock.ExpectCommit() }, }, { name: "modify numeric column", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "price", Type: field.TypeFloat64, SchemaType: map[string]string{dialect.MySQL: "decimal(6,4)"}}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock mysqlMock) { mock.start("5.7.23") mock.tableExists("users", true) mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}). AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil). AddRow("price", "decimal(6,4)", "NO", "YES", "NULL", "", "", "", "5", "4")) mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}). AddRow("PRIMARY", "id", nil, "0", "1")) mock.ExpectExec(escape("ALTER TABLE `users` MODIFY COLUMN `price` decimal(6,4) NOT NULL")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, // MariaDB specific tests. 
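// A hedged aside introducing the MariaDB cases: MariaDB stores JSON as an aliased
// longtext, and newer releases attach the json_valid check themselves, so the
// driver emits an explicit CHECK only for older servers; the tests below pin both
// DDL shapes. A sketch of such gating, where the "10.4.0" bound is an illustrative
// assumption (the tests only show 10.2/10.3 versus 10.5; the exact cutoff lives in
// the driver):
//
//	col := "`json` json NULL"
//	if compareVersions(mariaVersion, "10.4.0") < 0 {
//		col = "`json` json NULL CHECK (JSON_VALID(`json`))"
//	}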
{ name: "mariadb/10.2.32/create table", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "json", Type: field.TypeJSON, Nullable: true}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock mysqlMock) { mock.start("10.2.32-MariaDB") mock.tableExists("users", false) mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `users`(`id` bigint AUTO_INCREMENT NOT NULL, `json` json NULL CHECK (JSON_VALID(`json`)), PRIMARY KEY(`id`)) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "mariadb/10.3.13/create table", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "json", Type: field.TypeJSON, Nullable: true}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock mysqlMock) { mock.start("10.3.13-MariaDB-1:10.3.13+maria~bionic") mock.tableExists("users", false) mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `users`(`id` bigint AUTO_INCREMENT NOT NULL, `json` json NULL CHECK (JSON_VALID(`json`)), PRIMARY KEY(`id`)) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "mariadb/10.5.8/create table", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "json", Type: field.TypeJSON, Nullable: true}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock mysqlMock) { mock.start("10.5.8-MariaDB") mock.tableExists("users", false) mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `users`(`id` bigint AUTO_INCREMENT NOT NULL, `json` json NULL, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "mariadb/10.5.8/table exists", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Nullable: true}, {Name: "json", Type: field.TypeJSON, Nullable: true}, {Name: "longtext", Type: field.TypeString, Nullable: true, Size: math.MaxInt32}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock mysqlMock) { mock.start("10.5.8-MariaDB-1:10.5.8+maria~focal") mock.tableExists("users", true) mock.ExpectQuery(escape("SELECT `column_name`, `column_type`, `is_nullable`, `column_key`, `column_default`, `extra`, `character_set_name`, `collation_name`, `numeric_precision`, `numeric_scale` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name", "numeric_precision", "numeric_scale"}). AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil). AddRow("name", "varchar(255)", "YES", "YES", "NULL", "", "", "", nil, nil). AddRow("json", "longtext", "YES", "YES", "NULL", "", "utf8mb4", "utf8mb4_bin", nil, nil). 
AddRow("longtext", "longtext", "YES", "YES", "NULL", "", "utf8mb4", "utf8mb4_bin", nil, nil)) mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}). AddRow("PRIMARY", "id", nil, "0", "1")) mock.ExpectQuery(escape("SELECT `CONSTRAINT_NAME` FROM `INFORMATION_SCHEMA`.`CHECK_CONSTRAINTS` WHERE `CONSTRAINT_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? AND `CHECK_CLAUSE` LIKE ?")). WithArgs("users", "json_valid(%)"). WillReturnRows(sqlmock.NewRows([]string{"CONSTRAINT_NAME"}). AddRow("json")) mock.ExpectCommit() }, }, { name: "mariadb/10.1.37/create table", tables: []*Table{ { Name: "users", PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "age", Type: field.TypeInt}, {Name: "name", Type: field.TypeString, Unique: true}, }, }, }, options: []MigrateOption{WithGlobalUniqueID(true)}, before: func(mock mysqlMock) { mock.start("10.1.48-MariaDB-1~bionic") mock.tableExists("ent_types", false) // create ent_types table. mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `ent_types`(`id` bigint unsigned AUTO_INCREMENT NOT NULL, `type` varchar(191) UNIQUE NOT NULL, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.tableExists("users", false) mock.ExpectExec(escape("CREATE TABLE IF NOT EXISTS `users`(`id` bigint AUTO_INCREMENT NOT NULL, `age` bigint NOT NULL, `name` varchar(191) UNIQUE NOT NULL, PRIMARY KEY(`id`)) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin")). WillReturnResult(sqlmock.NewResult(0, 1)) // set users id range. mock.ExpectExec(escape("INSERT INTO `ent_types` (`type`) VALUES (?)")). WithArgs("users"). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectExec(escape("ALTER TABLE `users` AUTO_INCREMENT = 0")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { db, mock, err := sqlmock.New() require.NoError(t, err) tt.before(mysqlMock{mock}) migrate, err := NewMigrate(sql.OpenDB("mysql", db), append(tt.options, WithAtlas(false))...) require.NoError(t, err) err = migrate.Create(context.Background(), tt.tables...) require.Equal(t, tt.wantErr, err != nil, err) }) } } type mysqlMock struct { sqlmock.Sqlmock } func (m mysqlMock) start(version string) { m.ExpectQuery(escape("SHOW VARIABLES LIKE 'version'")). WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).AddRow("version", version)) m.ExpectBegin() } func (m mysqlMock) tableExists(table string, exists bool) { count := 0 if exists { count = 1 } m.ExpectQuery(escape("SELECT COUNT(*) FROM `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ?")). WithArgs(table). WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(count)) } func (m mysqlMock) fkExists(fk string, exists bool) { count := 0 if exists { count = 1 } m.ExpectQuery(escape("SELECT COUNT(*) FROM `INFORMATION_SCHEMA`.`TABLE_CONSTRAINTS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `CONSTRAINT_TYPE` = ? AND `CONSTRAINT_NAME` = ?")). WithArgs("FOREIGN KEY", fk). 
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(count)) } func escape(query string) string { rows := strings.Split(query, "\n") for i := range rows { rows[i] = strings.TrimPrefix(rows[i], " ") } query = strings.Join(rows, " ") return strings.TrimSpace(regexp.QuoteMeta(query)) + "$" } ent-0.11.3/dialect/sql/schema/postgres.go000066400000000000000000000625101431500740500202260ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package schema import ( "context" "fmt" "strconv" "strings" "unicode" "entgo.io/ent/dialect" "entgo.io/ent/dialect/sql" "entgo.io/ent/schema/field" "ariga.io/atlas/sql/migrate" "ariga.io/atlas/sql/postgres" "ariga.io/atlas/sql/schema" ) // Postgres is a postgres migration driver. type Postgres struct { dialect.Driver schema string version string } // init loads the Postgres version from the database for later use in the migration process. // It returns an error if the server version is lower than v10. func (d *Postgres) init(ctx context.Context) error { rows := &sql.Rows{} if err := d.Query(ctx, "SHOW server_version_num", []any{}, rows); err != nil { return fmt.Errorf("querying server version %w", err) } defer rows.Close() if !rows.Next() { if err := rows.Err(); err != nil { return err } return fmt.Errorf("server_version_num variable was not found") } var version string if err := rows.Scan(&version); err != nil { return fmt.Errorf("scanning version: %w", err) } if len(version) < 6 { return fmt.Errorf("malformed version: %s", version) } d.version = fmt.Sprintf("%s.%s.%s", version[:2], version[2:4], version[4:]) if compareVersions(d.version, "10.0.0") == -1 { return fmt.Errorf("unsupported postgres version: %s", d.version) } return nil } // tableExist checks if a table exists in the database and current schema. func (d *Postgres) tableExist(ctx context.Context, conn dialect.ExecQuerier, name string) (bool, error) { query, args := sql.Dialect(dialect.Postgres). Select(sql.Count("*")).From(sql.Table("tables").Schema("information_schema")). Where(sql.And( d.matchSchema(), sql.EQ("table_name", name), )).Query() return exist(ctx, conn, query, args...) } // tableExist checks if a foreign-key exists in the current schema. func (d *Postgres) fkExist(ctx context.Context, tx dialect.Tx, name string) (bool, error) { query, args := sql.Dialect(dialect.Postgres). Select(sql.Count("*")).From(sql.Table("table_constraints").Schema("information_schema")). Where(sql.And( d.matchSchema(), sql.EQ("constraint_type", "FOREIGN KEY"), sql.EQ("constraint_name", name), )).Query() return exist(ctx, tx, query, args...) } // setRange sets restart the identity column to the given offset. Used by the universal-id option. func (d *Postgres) setRange(ctx context.Context, conn dialect.ExecQuerier, t *Table, value int64) error { if value == 0 { value = 1 // RESTART value cannot be < 1. } pk := "id" if len(t.PrimaryKey) == 1 { pk = t.PrimaryKey[0].Name } return conn.Exec(ctx, fmt.Sprintf("ALTER TABLE %q ALTER COLUMN %q RESTART WITH %d", t.Name, pk, value), []any{}, nil) } // table loads the current table description from the database. func (d *Postgres) table(ctx context.Context, tx dialect.Tx, name string) (*Table, error) { rows := &sql.Rows{} query, args := sql.Dialect(dialect.Postgres). 
Select( "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length", ). From(sql.Table("columns").Schema("information_schema")). Where(sql.And( d.matchSchema(), sql.EQ("table_name", name), )).Query() if err := tx.Query(ctx, query, args, rows); err != nil { return nil, fmt.Errorf("postgres: reading table description %w", err) } // Call `Close` in cases of failures (`Close` is idempotent). defer rows.Close() t := NewTable(name) for rows.Next() { c := &Column{} if err := d.scanColumn(c, rows); err != nil { return nil, err } t.AddColumn(c) } if err := rows.Err(); err != nil { return nil, err } if err := rows.Close(); err != nil { return nil, fmt.Errorf("closing rows %w", err) } idxs, err := d.indexes(ctx, tx, name) if err != nil { return nil, err } // Populate the index information to the table and its columns. // We do it manually, because PK and uniqueness information does // not exist when querying the information_schema.COLUMNS above. for _, idx := range idxs { switch { case idx.primary: for _, name := range idx.columns { c, ok := t.column(name) if !ok { return nil, fmt.Errorf("index %q column %q was not found in table %q", idx.Name, name, t.Name) } c.Key = PrimaryKey t.PrimaryKey = append(t.PrimaryKey, c) } case idx.Unique && len(idx.columns) == 1: name := idx.columns[0] c, ok := t.column(name) if !ok { return nil, fmt.Errorf("index %q column %q was not found in table %q", idx.Name, name, t.Name) } c.Key = UniqueKey c.Unique = true fallthrough default: t.addIndex(idx) } } return t, nil } // indexesQuery holds a query format for retrieving // table indexes of the current schema. const indexesQuery = ` SELECT i.relname AS index_name, a.attname AS column_name, idx.indisprimary AS primary, idx.indisunique AS unique, array_position(idx.indkey, a.attnum) as seq_in_index FROM pg_class t, pg_class i, pg_index idx, pg_attribute a, pg_namespace n WHERE t.oid = idx.indrelid AND i.oid = idx.indexrelid AND n.oid = t.relnamespace AND a.attrelid = t.oid AND a.attnum = ANY(idx.indkey) AND t.relkind = 'r' AND n.nspname = %s AND t.relname = '%s' ORDER BY index_name, seq_in_index; ` // indexesQuery returns the query (and its placeholders) for getting table indexes. func (d *Postgres) indexesQuery(table string) (string, []any) { if d.schema != "" { return fmt.Sprintf(indexesQuery, "$1", table), []any{d.schema} } return fmt.Sprintf(indexesQuery, "CURRENT_SCHEMA()", table), nil } func (d *Postgres) indexes(ctx context.Context, tx dialect.Tx, table string) (Indexes, error) { rows := &sql.Rows{} query, args := d.indexesQuery(table) if err := tx.Query(ctx, query, args, rows); err != nil { return nil, fmt.Errorf("querying indexes for table %s: %w", table, err) } defer rows.Close() var ( idxs Indexes names = make(map[string]*Index) ) for rows.Next() { var ( seqindex int name, column string unique, primary bool ) if err := rows.Scan(&name, &column, &primary, &unique, &seqindex); err != nil { return nil, fmt.Errorf("scanning index description: %w", err) } // If the index is prefixed with the table, it may was added by // `addIndex` and it should be trimmed. But, since entc prefixes // all indexes with schema-type, for uncountable types (like, media // or equipment) this isn't correct, and we fallback for the real-name. 
short := strings.TrimPrefix(name, table+"_") idx, ok := names[short] if !ok { idx = &Index{Name: short, Unique: unique, primary: primary, realname: name} idxs = append(idxs, idx) names[short] = idx } idx.columns = append(idx.columns, column) } if err := rows.Err(); err != nil { return nil, err } return idxs, nil } // maxCharSize defines the maximum size of limited character types in Postgres (10 MB). const maxCharSize = 10 << 20 // scanColumn scans a column's information from the column description. func (d *Postgres) scanColumn(c *Column, rows *sql.Rows) error { var ( nullable sql.NullString defaults sql.NullString udt sql.NullString numericPrecision sql.NullInt64 numericScale sql.NullInt64 characterMaximumLen sql.NullInt64 ) if err := rows.Scan(&c.Name, &c.typ, &nullable, &defaults, &udt, &numericPrecision, &numericScale, &characterMaximumLen); err != nil { return fmt.Errorf("scanning column description: %w", err) } if nullable.Valid { c.Nullable = nullable.String == "YES" } switch c.typ { case "boolean": c.Type = field.TypeBool case "smallint": c.Type = field.TypeInt16 case "integer": c.Type = field.TypeInt32 case "bigint": c.Type = field.TypeInt64 case "real": c.Type = field.TypeFloat32 case "double precision": c.Type = field.TypeFloat64 case "numeric", "decimal": c.Type = field.TypeFloat64 // If precision is specified then we should take that into account. if numericPrecision.Valid { schemaType := fmt.Sprintf("%s(%d,%d)", c.typ, numericPrecision.Int64, numericScale.Int64) c.SchemaType = map[string]string{dialect.Postgres: schemaType} } case "text": c.Type = field.TypeString c.Size = maxCharSize + 1 case "character", "character varying": c.Type = field.TypeString // If character maximum length is specified then we should take that into account. if characterMaximumLen.Valid { schemaType := fmt.Sprintf("varchar(%d)", characterMaximumLen.Int64) c.SchemaType = map[string]string{dialect.Postgres: schemaType} } case "date", "time with time zone", "time without time zone", "timestamp with time zone", "timestamp without time zone": c.Type = field.TypeTime case "bytea": c.Type = field.TypeBytes case "jsonb": c.Type = field.TypeJSON case "uuid": c.Type = field.TypeUUID case "cidr", "inet", "macaddr", "macaddr8": c.Type = field.TypeOther case "point", "line", "lseg", "box", "path", "polygon", "circle": c.Type = field.TypeOther case "ARRAY": c.Type = field.TypeOther if !udt.Valid { return fmt.Errorf("missing array type for column %q", c.Name) } // Note that for ARRAY types, the 'udt_name' column holds the array type // prefixed with '_'. For example, for 'integer[]' the result is '_int', // and for 'text[N][M]' the result is also '_text'. That's because the // database ignores any size or multi-dimensions constraints. c.SchemaType = map[string]string{dialect.Postgres: "ARRAY"} c.typ = udt.String case "USER-DEFINED", "tstzrange", "interval": c.Type = field.TypeOther if !udt.Valid { return fmt.Errorf("missing user defined type for column %q", c.Name) } c.SchemaType = map[string]string{dialect.Postgres: udt.String} } switch { case !defaults.Valid || c.Type == field.TypeTime || callExpr(defaults.String): return nil case strings.Contains(defaults.String, "::"): parts := strings.Split(defaults.String, "::") defaults.String = strings.Trim(parts[0], "'") fallthrough default: return c.ScanDefault(defaults.String) } } // tBuilder returns the TableBuilder for the given table. func (d *Postgres) tBuilder(t *Table) *sql.TableBuilder { b := sql.Dialect(dialect.Postgres).
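// A hedged worked example for the default handling in scanColumn above: a typed
// literal such as 'unknown'::character varying is split on "::" and unquoted, so
// ScanDefault receives "unknown", while call expressions such as
// nextval('users_colname_seq'::regclass) (and time defaults, e.g. CURRENT_DATE)
// are skipped and left for the database to evaluate:
//
//	parts := strings.Split("'unknown'::character varying", "::")
//	_ = strings.Trim(parts[0], "'") // "unknown"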
CreateTable(t.Name).IfNotExists() for _, c := range t.Columns { b.Column(d.addColumn(c)) } for _, pk := range t.PrimaryKey { b.PrimaryKey(pk.Name) } if t.Annotation != nil { addChecks(b, t.Annotation) } return b } // cType returns the PostgreSQL string type for this column. func (d *Postgres) cType(c *Column) (t string) { if c.SchemaType != nil && c.SchemaType[dialect.Postgres] != "" { return c.SchemaType[dialect.Postgres] } switch c.Type { case field.TypeBool: t = "boolean" case field.TypeUint8, field.TypeInt8, field.TypeInt16, field.TypeUint16: t = "smallint" case field.TypeInt32, field.TypeUint32: t = "int" case field.TypeInt, field.TypeUint, field.TypeInt64, field.TypeUint64: t = "bigint" case field.TypeFloat32: t = c.scanTypeOr("real") case field.TypeFloat64: t = c.scanTypeOr("double precision") case field.TypeBytes: t = "bytea" case field.TypeJSON: t = "jsonb" case field.TypeUUID: t = "uuid" case field.TypeString: t = "varchar" if c.Size > maxCharSize { t = "text" } case field.TypeTime: t = c.scanTypeOr("timestamp with time zone") case field.TypeEnum: // Currently, the support for enums is weak (application level only, // like SQLite). The dialect needs to create and maintain its enum type. t = "varchar" case field.TypeOther: t = c.typ default: panic(fmt.Sprintf("unsupported type %q for column %q", c.Type.String(), c.Name)) } return t } // addColumn returns the ColumnBuilder for adding the given column to a table. func (d *Postgres) addColumn(c *Column) *sql.ColumnBuilder { b := sql.Dialect(dialect.Postgres). Column(c.Name).Type(d.cType(c)).Attr(c.Attr) c.unique(b) if c.Increment { b.Attr("GENERATED BY DEFAULT AS IDENTITY") } c.nullable(b) d.writeDefault(b, c, "DEFAULT") if c.Collation != "" { b.Attr("COLLATE " + strconv.Quote(c.Collation)) } return b } // writeDefault writes the `DEFAULT` clause to the column builder // if it exists and is supported by the driver. func (d *Postgres) writeDefault(b *sql.ColumnBuilder, c *Column, clause string) { if c.Default == nil || !c.supportDefault() { return } attr := fmt.Sprint(c.Default) switch v := c.Default.(type) { case bool: attr = strconv.FormatBool(v) case string: if t := c.Type; t != field.TypeUUID && t != field.TypeTime && !t.Numeric() { // Escape each single quote by doubling it. attr = fmt.Sprintf("'%s'", strings.ReplaceAll(v, "'", "''")) } } b.Attr(clause + " " + attr) } // alterColumn returns a list of ColumnBuilders to apply in order to alter a column. func (d *Postgres) alterColumn(c *Column) (ops []*sql.ColumnBuilder) { b := sql.Dialect(dialect.Postgres) ops = append(ops, b.Column(c.Name).Type(d.cType(c))) if c.Nullable { ops = append(ops, b.Column(c.Name).Attr("DROP NOT NULL")) } else { ops = append(ops, b.Column(c.Name).Attr("SET NOT NULL")) } if c.Default != nil && c.supportDefault() { ops = append(ops, d.writeSetDefault(b.Column(c.Name), c)) } return ops } func (d *Postgres) writeSetDefault(b *sql.ColumnBuilder, c *Column) *sql.ColumnBuilder { d.writeDefault(b, c, "SET DEFAULT") return b } // hasUniqueName reports if the index has a unique name in the schema. func hasUniqueName(i *Index) bool { // Trim the "_key" suffix if it was added by Postgres for implicit indexes. name := strings.TrimSuffix(i.Name, "_key") suffix := strings.Join(i.columnNames(), "_") if !strings.HasSuffix(name, suffix) { return true // Assume it has a custom storage-key. } // By default, the codegen prefixes indexes with the type name. // For example, an index on "users"("name") will be named "user_name".
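// A hedged worked example for the check below: an index on "users"("name") that is
// named exactly after its column list ("name") is not considered unique, so
// addIndex later prefixes it with the table name ("users_name"); a codegen name
// such as "user_name" already differs from the joined column names and is kept:
//
//	index named "name" on ("name")      -> false: equals its column list, gets table-prefixed
//	index named "user_name" on ("name") -> true: the type prefix already disambiguates it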
return name != suffix } // addIndex returns the query for adding an index to PostgreSQL. func (d *Postgres) addIndex(i *Index, table string) *sql.IndexBuilder { name := i.Name if !hasUniqueName(i) { // Since an index name should be unique in pg_class for a schema, // we prefix it with the table name and remove it on read. name = fmt.Sprintf("%s_%s", table, i.Name) } idx := sql.Dialect(dialect.Postgres). CreateIndex(name).IfNotExists().Table(table) if i.Unique { idx.Unique() } for _, c := range i.Columns { idx.Column(c.Name) } return idx } // dropIndex drops a Postgres index. func (d *Postgres) dropIndex(ctx context.Context, tx dialect.Tx, idx *Index, table string) error { name := idx.Name build := sql.Dialect(dialect.Postgres) if prefix := table + "_"; !strings.HasPrefix(name, prefix) && !hasUniqueName(idx) { name = prefix + name } query, args := sql.Dialect(dialect.Postgres). Select(sql.Count("*")).From(sql.Table("table_constraints").Schema("information_schema")). Where(sql.And( d.matchSchema(), sql.EQ("constraint_type", "UNIQUE"), sql.EQ("constraint_name", name), )). Query() exists, err := exist(ctx, tx, query, args...) if err != nil { return err } query, args = build.DropIndex(name).Query() if exists { query, args = build.AlterTable(table).DropConstraint(name).Query() } return tx.Exec(ctx, query, args, nil) } // isImplicitIndex reports if the index was created implicitly for the unique column. func (d *Postgres) isImplicitIndex(idx *Index, col *Column) bool { return strings.TrimSuffix(idx.Name, "_key") == col.Name && col.Unique } // renameColumn returns the statement for renaming a column. func (d *Postgres) renameColumn(t *Table, old, new *Column) sql.Querier { return sql.Dialect(dialect.Postgres). AlterTable(t.Name). RenameColumn(old.Name, new.Name) } // renameIndex returns the statement for renaming an index. func (d *Postgres) renameIndex(t *Table, old, new *Index) sql.Querier { if sfx := "_key"; strings.HasSuffix(old.Name, sfx) && !strings.HasSuffix(new.Name, sfx) { new.Name += sfx } if pfx := t.Name + "_"; strings.HasPrefix(old.realname, pfx) && !strings.HasPrefix(new.Name, pfx) { new.Name = pfx + new.Name } return sql.Dialect(dialect.Postgres).AlterIndex(old.realname).Rename(new.Name) } // matchSchema returns the predicate for matching table schema. func (d *Postgres) matchSchema(columns ...string) *sql.Predicate { column := "table_schema" if len(columns) > 0 { column = columns[0] } if d.schema != "" { return sql.EQ(column, d.schema) } return sql.EQ(column, sql.Raw("CURRENT_SCHEMA()")) } // tables returns the query for getting the tables in the schema. func (d *Postgres) tables() sql.Querier { return sql.Dialect(dialect.Postgres). Select("table_name"). From(sql.Table("tables").Schema("information_schema")). Where(d.matchSchema()) } // alterColumns returns the queries for applying the columns change-set. func (d *Postgres) alterColumns(table string, add, modify, drop []*Column) sql.Queries { b := sql.Dialect(dialect.Postgres).AlterTable(table) for _, c := range add { b.AddColumn(d.addColumn(c)) } for _, c := range modify { b.ModifyColumns(d.alterColumn(c)...) } for _, c := range drop { b.DropColumn(sql.Dialect(dialect.Postgres).Column(c.Name)) } if len(b.Queries) == 0 { return nil } return sql.Queries{b} } // needsConversion reports if column "old" needs to be converted // (by table altering) to column "new".
func (d *Postgres) needsConversion(old, new *Column) bool { oldT, newT := d.cType(old), d.cType(new) return oldT != newT && (oldT != "ARRAY" || !arrayType(newT)) } // callExpr reports if the given string ~looks like a function call expression. func callExpr(s string) bool { if parts := strings.Split(s, "::"); !strings.HasSuffix(s, ")") && strings.HasSuffix(parts[0], ")") { s = parts[0] } i, j := strings.IndexByte(s, '('), strings.LastIndexByte(s, ')') if i == -1 || i > j || j != len(s)-1 { return false } for i, r := range s[:i] { if !isAlpha(r, i > 0) { return false } } return true } func isAlpha(r rune, digit bool) bool { return 'a' <= r && r <= 'z' || 'A' <= r && r <= 'Z' || r == '_' || digit && '0' <= r && r <= '9' } // arrayType reports if the given string is an array type (e.g. int[], text[2]). func arrayType(t string) bool { i, j := strings.LastIndexByte(t, '['), strings.LastIndexByte(t, ']') if i == -1 || j == -1 { return false } for _, r := range t[i+1 : j] { if !unicode.IsDigit(r) { return false } } return true } // foreignKeys populates the tables' foreign keys using the information_schema tables. func (d *Postgres) foreignKeys(ctx context.Context, tx dialect.Tx, tables []*Table) error { var tableLookup = make(map[string]*Table) for _, t := range tables { tableLookup[t.Name] = t } for _, t := range tables { rows := &sql.Rows{} query := fmt.Sprintf(fkQuery, t.Name) if err := tx.Query(ctx, query, []any{}, rows); err != nil { return fmt.Errorf("querying foreign keys for table %s: %w", t.Name, err) } defer rows.Close() var tableFksLookup = make(map[string]*ForeignKey) for rows.Next() { var tableSchema, constraintName, tableName, columnName, refTableSchema, refTableName, refColumnName string if err := rows.Scan(&tableSchema, &constraintName, &tableName, &columnName, &refTableSchema, &refTableName, &refColumnName); err != nil { return fmt.Errorf("scanning foreign-key description: %w", err) } refTable := tableLookup[refTableName] if refTable == nil { return fmt.Errorf("could not find table: %s", refTableName) } column, ok := t.column(columnName) if !ok { return fmt.Errorf("could not find column: %s on table: %s", columnName, tableName) } refColumn, ok := refTable.column(refColumnName) if !ok { return fmt.Errorf("could not find ref column: %s on ref table: %s", refColumnName, refTableName) } if fk, ok := tableFksLookup[constraintName]; ok { if _, ok := fk.column(columnName); !ok { fk.Columns = append(fk.Columns, column) } if _, ok := fk.refColumn(refColumnName); !ok { fk.RefColumns = append(fk.RefColumns, refColumn) } } else { newFk := &ForeignKey{ Symbol: constraintName, Columns: []*Column{column}, RefTable: refTable, RefColumns: []*Column{refColumn}, } tableFksLookup[constraintName] = newFk t.AddForeignKey(newFk) } } if err := rows.Err(); err != nil { return err } } return nil } // fkQuery holds a query format for retrieving // foreign keys of the current schema.
const fkQuery = ` SELECT tc.table_schema, tc.constraint_name, tc.table_name, kcu.column_name, ccu.table_schema AS foreign_table_schema, ccu.table_name AS foreign_table_name, ccu.column_name AS foreign_column_name FROM information_schema.table_constraints AS tc JOIN information_schema.key_column_usage AS kcu ON tc.constraint_name = kcu.constraint_name AND tc.table_schema = kcu.table_schema JOIN information_schema.constraint_column_usage AS ccu ON ccu.constraint_name = tc.constraint_name AND ccu.table_schema = tc.table_schema WHERE tc.constraint_type = 'FOREIGN KEY' AND tc.table_name = '%s' ORDER BY constraint_name, kcu.ordinal_position; ` // Atlas integration. func (d *Postgres) atOpen(conn dialect.ExecQuerier) (migrate.Driver, error) { return postgres.Open(&db{ExecQuerier: conn}) } func (d *Postgres) atTable(t1 *Table, t2 *schema.Table) { if t1.Annotation != nil { setAtChecks(t1, t2) } } func (d *Postgres) atTypeC(c1 *Column, c2 *schema.Column) error { if c1.SchemaType != nil && c1.SchemaType[dialect.Postgres] != "" { t, err := postgres.ParseType(strings.ToLower(c1.SchemaType[dialect.Postgres])) if err != nil { return err } c2.Type.Type = t if s, ok := t.(*postgres.SerialType); c1.foreign != nil && ok { c2.Type.Type = s.IntegerType() } return nil } var t schema.Type switch c1.Type { case field.TypeBool: t = &schema.BoolType{T: postgres.TypeBoolean} case field.TypeUint8, field.TypeInt8, field.TypeInt16, field.TypeUint16: t = &schema.IntegerType{T: postgres.TypeSmallInt} case field.TypeInt32, field.TypeUint32: t = &schema.IntegerType{T: postgres.TypeInt} case field.TypeInt, field.TypeUint, field.TypeInt64, field.TypeUint64: t = &schema.IntegerType{T: postgres.TypeBigInt} case field.TypeFloat32: t = &schema.FloatType{T: c1.scanTypeOr(postgres.TypeReal)} case field.TypeFloat64: t = &schema.FloatType{T: c1.scanTypeOr(postgres.TypeDouble)} case field.TypeBytes: t = &schema.BinaryType{T: postgres.TypeBytea} case field.TypeUUID: t = &postgres.UUIDType{T: postgres.TypeUUID} case field.TypeJSON: t = &schema.JSONType{T: postgres.TypeJSONB} case field.TypeString: t = &schema.StringType{T: postgres.TypeVarChar} if c1.Size > maxCharSize { t = &schema.StringType{T: postgres.TypeText} } case field.TypeTime: t = &schema.TimeType{T: c1.scanTypeOr(postgres.TypeTimestampWTZ)} case field.TypeEnum: // Although atlas supports enum types, we keep backwards compatibility // with previous versions of ent and use varchar (see cType). t = &schema.StringType{T: postgres.TypeVarChar} case field.TypeOther: t = &schema.UnsupportedType{T: c1.typ} default: // Avoid shadowing t here, so the assignment after the switch applies to all cases. pt, err := postgres.ParseType(strings.ToLower(c1.typ)) if err != nil { return err } t = pt } c2.Type.Type = t return nil } func (d *Postgres) atUniqueC(t1 *Table, c1 *Column, t2 *schema.Table, c2 *schema.Column) { // For UNIQUE columns, PostgreSQL creates an implicit index named // "<table>_<column>_key". for _, idx := range t1.Indexes { // The index is also defined explicitly, and will be added in atIndexes.
if idx.Unique && d.atImplicitIndexName(idx, t1, c1) { return } } t2.AddIndexes(schema.NewUniqueIndex(fmt.Sprintf("%s_%s_key", t1.Name, c1.Name)).AddColumns(c2)) } func (d *Postgres) atImplicitIndexName(idx *Index, t1 *Table, c1 *Column) bool { p := fmt.Sprintf("%s_%s_key", t1.Name, c1.Name) if idx.Name == p { return true } i, err := strconv.ParseInt(strings.TrimPrefix(idx.Name, p), 10, 64) return err == nil && i > 0 } func (d *Postgres) atIncrementC(t *schema.Table, c *schema.Column) { if _, ok := c.Type.Type.(*postgres.SerialType); ok { return } id := &postgres.Identity{} for _, a := range t.Attrs { if a, ok := a.(*postgres.Identity); ok { id = a } } c.AddAttrs(id) } func (d *Postgres) atIncrementT(t *schema.Table, v int64) { t.AddAttrs(&postgres.Identity{Sequence: &postgres.Sequence{Start: v}}) } func (d *Postgres) atIndex(idx1 *Index, t2 *schema.Table, idx2 *schema.Index) error { for _, c1 := range idx1.Columns { c2, ok := t2.Column(c1.Name) if !ok { return fmt.Errorf("unexpected index %q column: %q", idx1.Name, c1.Name) } idx2.AddParts(&schema.IndexPart{C: c2}) } if t, ok := indexType(idx1, dialect.Postgres); ok { idx2.AddAttrs(&postgres.IndexType{T: t}) } if ant, supportsInclude := idx1.Annotation, compareVersions(d.version, "11.0.0") >= 0; ant != nil && len(ant.IncludeColumns) > 0 && supportsInclude { columns := make([]*schema.Column, len(ant.IncludeColumns)) for i, ic := range ant.IncludeColumns { c, ok := t2.Column(ic) if !ok { return fmt.Errorf("include column %q was not found for index %q", ic, idx1.Name) } columns[i] = c } idx2.AddAttrs(&postgres.IndexInclude{Columns: columns}) } if idx1.Annotation != nil && idx1.Annotation.Where != "" { idx2.AddAttrs(&postgres.IndexPredicate{P: idx1.Annotation.Where}) } return nil } func (Postgres) atTypeRangeSQL(ts ...string) string { for i := range ts { ts[i] = fmt.Sprintf("('%s')", ts[i]) } return fmt.Sprintf(`INSERT INTO "%s" ("type") VALUES %s`, TypeTable, strings.Join(ts, ", ")) } ent-0.11.3/dialect/sql/schema/postgres_test.go000066400000000000000000001432611431500740500212700ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. 
package schema import ( "context" "fmt" "math" "strings" "testing" "entgo.io/ent/dialect" "entgo.io/ent/dialect/entsql" "entgo.io/ent/dialect/sql" "entgo.io/ent/schema/field" "github.com/DATA-DOG/go-sqlmock" "github.com/stretchr/testify/require" ) func TestPostgres_Create(t *testing.T) { tests := []struct { name string tables []*Table options []MigrateOption before func(pgMock) wantErr bool }{ { name: "tx failed", before: func(mock pgMock) { mock.ExpectBegin().WillReturnError(sqlmock.ErrCancelled) }, wantErr: true, }, { name: "unsupported version", before: func(mock pgMock) { mock.start("90000") }, wantErr: true, }, { name: "no tables", before: func(mock pgMock) { mock.start("120000") mock.ExpectCommit() }, }, { name: "create new table", tables: []*Table{ { Name: "users", PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, Columns: []*Column{ {Name: "id", Type: field.TypeUUID, Default: "uuid_generate_v4()"}, {Name: "block_size", Type: field.TypeInt, Default: "current_setting('block_size')::bigint"}, {Name: "name", Type: field.TypeString, Nullable: true, Collation: "he_IL"}, {Name: "age", Type: field.TypeInt}, {Name: "doc", Type: field.TypeJSON, Nullable: true}, {Name: "enums", Type: field.TypeEnum, Enums: []string{"a", "b"}, Default: "a"}, {Name: "price", Type: field.TypeFloat64, SchemaType: map[string]string{dialect.Postgres: "numeric(5,2)"}}, {Name: "strings", Type: field.TypeOther, SchemaType: map[string]string{dialect.Postgres: "text[]"}, Nullable: true}, {Name: "fixed_string", Type: field.TypeString, SchemaType: map[string]string{dialect.Postgres: "varchar(100)"}}, }, Annotation: &entsql.Annotation{ Check: "price > 0", Checks: map[string]string{ "valid_age": "age > 0", "valid_name": "name <> ''", }, }, }, }, before: func(mock pgMock) { mock.start("120000") mock.tableExists("users", false) mock.ExpectExec(escape(`CREATE TABLE IF NOT EXISTS "users"("id" uuid NOT NULL DEFAULT uuid_generate_v4(), "block_size" bigint NOT NULL DEFAULT current_setting('block_size')::bigint, "name" varchar NULL COLLATE "he_IL", "age" bigint NOT NULL, "doc" jsonb NULL, "enums" varchar NOT NULL DEFAULT 'a', "price" numeric(5,2) NOT NULL, "strings" text[] NULL, "fixed_string" varchar(100) NOT NULL, PRIMARY KEY("id"), CHECK (price > 0), CONSTRAINT "valid_age" CHECK (age > 0), CONSTRAINT "valid_name" CHECK (name <> ''))`)). 
WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "create new table with foreign key", tables: func() []*Table { var ( c1 = []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Nullable: true}, {Name: "created_at", Type: field.TypeTime}, {Name: "inet", Type: field.TypeString, Unique: true, SchemaType: map[string]string{dialect.Postgres: "inet"}}, } c2 = []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString}, {Name: "owner_id", Type: field.TypeInt, Nullable: true}, } t1 = &Table{ Name: "users", Columns: c1, PrimaryKey: c1[0:1], } t2 = &Table{ Name: "pets", Columns: c2, PrimaryKey: c2[0:1], ForeignKeys: []*ForeignKey{ { Symbol: "pets_owner", Columns: c2[2:], RefTable: t1, RefColumns: c1[0:1], OnDelete: Cascade, }, }, } ) return []*Table{t1, t2} }(), before: func(mock pgMock) { mock.start("120000") mock.tableExists("users", false) mock.ExpectExec(escape(`CREATE TABLE IF NOT EXISTS "users"("id" bigint GENERATED BY DEFAULT AS IDENTITY NOT NULL, "name" varchar NULL, "created_at" timestamp with time zone NOT NULL, "inet" inet UNIQUE NOT NULL, PRIMARY KEY("id"))`)). WillReturnResult(sqlmock.NewResult(0, 1)) mock.tableExists("pets", false) mock.ExpectExec(escape(`CREATE TABLE IF NOT EXISTS "pets"("id" bigint GENERATED BY DEFAULT AS IDENTITY NOT NULL, "name" varchar NOT NULL, "owner_id" bigint NULL, PRIMARY KEY("id"))`)). WillReturnResult(sqlmock.NewResult(0, 1)) mock.fkExists("pets_owner", false) mock.ExpectExec(escape(`ALTER TABLE "pets" ADD CONSTRAINT "pets_owner" FOREIGN KEY("owner_id") REFERENCES "users"("id") ON DELETE CASCADE`)). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "create new table with foreign key disabled", options: []MigrateOption{ WithForeignKeys(false), }, tables: func() []*Table { var ( c1 = []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Nullable: true}, {Name: "created_at", Type: field.TypeTime}, } c2 = []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString}, {Name: "owner_id", Type: field.TypeInt, Nullable: true}, } t1 = &Table{ Name: "users", Columns: c1, PrimaryKey: c1[0:1], } t2 = &Table{ Name: "pets", Columns: c2, PrimaryKey: c2[0:1], ForeignKeys: []*ForeignKey{ { Symbol: "pets_owner", Columns: c2[2:], RefTable: t1, RefColumns: c1[0:1], OnDelete: Cascade, }, }, } ) return []*Table{t1, t2} }(), before: func(mock pgMock) { mock.start("120000") mock.tableExists("users", false) mock.ExpectExec(escape(`CREATE TABLE IF NOT EXISTS "users"("id" bigint GENERATED BY DEFAULT AS IDENTITY NOT NULL, "name" varchar NULL, "created_at" timestamp with time zone NOT NULL, PRIMARY KEY("id"))`)). WillReturnResult(sqlmock.NewResult(0, 1)) mock.tableExists("pets", false) mock.ExpectExec(escape(`CREATE TABLE IF NOT EXISTS "pets"("id" bigint GENERATED BY DEFAULT AS IDENTITY NOT NULL, "name" varchar NOT NULL, "owner_id" bigint NULL, PRIMARY KEY("id"))`)). 
WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "scan table with default", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "block_size", Type: field.TypeInt, Default: "current_setting('block_size')::bigint"}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock pgMock) { mock.start("120000") mock.tableExists("users", true) mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length"}). AddRow("id", "bigint", "NO", "nextval('users_colname_seq'::regclass)", "int4", nil, nil, nil). AddRow("block_size", "bigint", "NO", "current_setting('block_size')::bigint", "int4", nil, nil, nil)) mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "CURRENT_SCHEMA()", "users"))). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). AddRow("users_pkey", "id", "t", "t", 0)) mock.ExpectExec(escape(`ALTER TABLE "users" ALTER COLUMN "block_size" TYPE bigint, ALTER COLUMN "block_size" SET NOT NULL, ALTER COLUMN "block_size" SET DEFAULT current_setting('block_size')::bigint`)). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "scan table with custom type", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "custom", Type: field.TypeOther, SchemaType: map[string]string{dialect.Postgres: "customtype"}}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock pgMock) { mock.start("120000") mock.tableExists("users", true) mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length"}). AddRow("id", "bigint", "NO", "nextval('users_colname_seq'::regclass)", "NULL", nil, nil, nil). AddRow("custom", "USER-DEFINED", "NO", "NULL", "customtype", nil, nil, nil)) mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "CURRENT_SCHEMA()", "users"))). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). 
AddRow("users_pkey", "id", "t", "t", 0)) mock.ExpectCommit() }, }, { name: "add column to table", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Nullable: true}, {Name: "uuid", Type: field.TypeUUID, Nullable: true}, {Name: "text", Type: field.TypeString, Nullable: true, Size: math.MaxInt32}, {Name: "age", Type: field.TypeInt}, {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{dialect.Postgres: "date"}, Default: "CURRENT_DATE"}, {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{dialect.MySQL: "date"}, Nullable: true}, {Name: "deleted_at", Type: field.TypeTime, Nullable: true}, {Name: "cidr", Type: field.TypeString, SchemaType: map[string]string{dialect.Postgres: "cidr"}}, {Name: "point", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{dialect.Postgres: "point"}}, {Name: "line", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{dialect.Postgres: "line"}}, {Name: "lseg", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{dialect.Postgres: "lseg"}}, {Name: "box", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{dialect.Postgres: "box"}}, {Name: "path", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{dialect.Postgres: "path"}}, {Name: "polygon", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{dialect.Postgres: "polygon"}}, {Name: "circle", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{dialect.Postgres: "circle"}}, {Name: "macaddr", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{dialect.Postgres: "macaddr"}}, {Name: "macaddr8", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{dialect.Postgres: "macaddr8"}}, {Name: "strings", Type: field.TypeOther, SchemaType: map[string]string{dialect.Postgres: "text[]"}, Nullable: true}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock pgMock) { mock.start("120000") mock.tableExists("users", true) mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length"}). AddRow("id", "bigint", "NO", "NULL", "int8", nil, nil, nil). AddRow("name", "character varying", "YES", "NULL", "varchar", nil, nil, nil). AddRow("uuid", "uuid", "YES", "NULL", "uuid", nil, nil, nil). AddRow("created_at", "date", "NO", "CURRENT_DATE", "date", nil, nil, nil). AddRow("updated_at", "timestamp with time zone", "YES", "NULL", "timestamptz", nil, nil, nil). AddRow("deleted_at", "date", "YES", "NULL", "date", nil, nil, nil). AddRow("text", "text", "YES", "NULL", "text", nil, nil, nil). AddRow("cidr", "cidr", "NO", "NULL", "cidr", nil, nil, nil). AddRow("inet", "inet", "YES", "NULL", "inet", nil, nil, nil). AddRow("point", "point", "YES", "NULL", "point", nil, nil, nil). AddRow("line", "line", "YES", "NULL", "line", nil, nil, nil). AddRow("lseg", "lseg", "YES", "NULL", "lseg", nil, nil, nil). AddRow("box", "box", "YES", "NULL", "box", nil, nil, nil). 
AddRow("path", "path", "YES", "NULL", "path", nil, nil, nil). AddRow("polygon", "polygon", "YES", "NULL", "polygon", nil, nil, nil). AddRow("circle", "circle", "YES", "NULL", "circle", nil, nil, nil). AddRow("macaddr", "macaddr", "YES", "NULL", "macaddr", nil, nil, nil). AddRow("macaddr8", "macaddr8", "YES", "NULL", "macaddr8", nil, nil, nil). AddRow("strings", "ARRAY", "YES", "NULL", "_text", nil, nil, nil)) mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "CURRENT_SCHEMA()", "users"))). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). AddRow("users_pkey", "id", "t", "t", 0)) mock.ExpectExec(escape(`ALTER TABLE "users" ADD COLUMN "age" bigint NOT NULL, ALTER COLUMN "created_at" TYPE date, ALTER COLUMN "created_at" SET NOT NULL, ALTER COLUMN "created_at" SET DEFAULT CURRENT_DATE, ALTER COLUMN "deleted_at" TYPE timestamp with time zone, ALTER COLUMN "deleted_at" DROP NOT NULL`)). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "add int column with default value to table", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Nullable: true}, {Name: "age", Type: field.TypeInt, Default: 10}, {Name: "doc", Type: field.TypeJSON, Nullable: true}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock pgMock) { mock.start("120000") mock.tableExists("users", true) mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length"}). AddRow("id", "bigint", "NO", "NULL", "int8", nil, nil, nil). AddRow("name", "character", "YES", "NULL", "bpchar", nil, nil, nil). AddRow("doc", "jsonb", "YES", "NULL", "jsonb", nil, nil, nil)) mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "CURRENT_SCHEMA()", "users"))). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). AddRow("users_pkey", "id", "t", "t", 0)) mock.ExpectExec(escape(`ALTER TABLE "users" ADD COLUMN "age" bigint NOT NULL DEFAULT 10`)). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "add blob columns", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Nullable: true}, {Name: "blob", Type: field.TypeBytes, Size: 1e3}, {Name: "longblob", Type: field.TypeBytes, Size: 1e6}, {Name: "doc", Type: field.TypeJSON, Nullable: true}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock pgMock) { mock.start("120000") mock.tableExists("users", true) mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length"}). 
AddRow("id", "bigint", "NO", "NULL", "int8", nil, nil, nil). AddRow("name", "character", "YES", "NULL", "bpchar", nil, nil, nil). AddRow("doc", "jsonb", "YES", "NULL", "jsonb", nil, nil, nil)) mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "CURRENT_SCHEMA()", "users"))). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). AddRow("users_pkey", "id", "t", "t", 0)) mock.ExpectExec(escape(`ALTER TABLE "users" ADD COLUMN "blob" bytea NOT NULL, ADD COLUMN "longblob" bytea NOT NULL`)). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "add float column with default value to table", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Nullable: true}, {Name: "age", Type: field.TypeFloat64, Default: 10.1}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock pgMock) { mock.start("120000") mock.tableExists("users", true) mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length"}). AddRow("id", "bigint", "NO", "NULL", "int8", nil, nil, nil). AddRow("name", "character", "YES", "NULL", "bpchar", nil, nil, nil)) mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "CURRENT_SCHEMA()", "users"))). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). AddRow("users_pkey", "id", "t", "t", 0)) mock.ExpectExec(escape(`ALTER TABLE "users" ADD COLUMN "age" double precision NOT NULL DEFAULT 10.1`)). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "add bool column with default value to table", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Nullable: true}, {Name: "age", Type: field.TypeBool, Default: true}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock pgMock) { mock.start("120000") mock.tableExists("users", true) mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length"}). AddRow("id", "bigint", "NO", "NULL", "int8", nil, nil, nil). AddRow("name", "character", "YES", "NULL", "bpchar", nil, nil, nil)) mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "CURRENT_SCHEMA()", "users"))). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). AddRow("users_pkey", "id", "t", "t", 0)) mock.ExpectExec(escape(`ALTER TABLE "users" ADD COLUMN "age" boolean NOT NULL DEFAULT true`)). 
WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "add string column with default value to table", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Nullable: true}, {Name: "nick", Type: field.TypeString, Default: "unknown"}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock pgMock) { mock.start("120000") mock.tableExists("users", true) mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length"}). AddRow("id", "bigint", "NO", "NULL", "int8", nil, nil, nil). AddRow("name", "character", "YES", "NULL", "bpchar", nil, nil, nil)) mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "CURRENT_SCHEMA()", "users"))). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). AddRow("users_pkey", "id", "t", "t", 0)) mock.ExpectExec(escape(`ALTER TABLE "users" ADD COLUMN "nick" varchar NOT NULL DEFAULT 'unknown'`)). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "drop column to table", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, options: []MigrateOption{WithDropColumn(true)}, before: func(mock pgMock) { mock.start("120000") mock.tableExists("users", true) mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length"}). AddRow("id", "bigint", "NO", "NULL", "int8", nil, nil, nil). AddRow("name", "character", "YES", "NULL", "bpchar", nil, nil, nil)) mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "CURRENT_SCHEMA()", "users"))). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). AddRow("users_pkey", "id", "t", "t", 0)) mock.ExpectExec(escape(`ALTER TABLE "users" DROP COLUMN "name"`)). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "modify column to nullable", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Nullable: true}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock pgMock) { mock.start("120000") mock.tableExists("users", true) mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). WithArgs("users"). 
WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length"}). AddRow("id", "bigint", "NO", "NULL", "int8", nil, nil, nil). AddRow("name", "character", "NO", "NULL", "bpchar", nil, nil, nil)) mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "CURRENT_SCHEMA()", "users"))). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). AddRow("users_pkey", "id", "t", "t", 0)) mock.ExpectExec(escape(`ALTER TABLE "users" ALTER COLUMN "name" TYPE varchar, ALTER COLUMN "name" DROP NOT NULL`)). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "modify column default value", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Default: "unknown"}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock pgMock) { mock.start("120000") mock.tableExists("users", true) mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length"}). AddRow("id", "bigint", "NO", "NULL", "int8", nil, nil, nil). AddRow("name", "character", "NO", "NULL", "bpchar", nil, nil, nil)) mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "CURRENT_SCHEMA()", "users"))). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). AddRow("users_pkey", "id", "t", "t", 0)) mock.ExpectExec(escape(`ALTER TABLE "users" ALTER COLUMN "name" TYPE varchar, ALTER COLUMN "name" SET NOT NULL, ALTER COLUMN "name" SET DEFAULT 'unknown'`)). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "apply uniqueness on column", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "age", Type: field.TypeInt, Unique: true}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock pgMock) { mock.start("120000") mock.tableExists("users", true) mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length"}). AddRow("id", "bigint", "NO", "NULL", "int8", nil, nil, nil). AddRow("age", "bigint", "NO", "NULL", "int8", nil, nil, nil)) mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "CURRENT_SCHEMA()", "users"))). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). AddRow("users_pkey", "id", "t", "t", 0)) mock.ExpectExec(escape(`CREATE UNIQUE INDEX IF NOT EXISTS "users_age" ON "users"("age")`)). 
WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "remove uniqueness from column without option", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "age", Type: field.TypeInt}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock pgMock) { mock.start("120000") mock.tableExists("users", true) mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length"}). AddRow("id", "bigint", "NO", "NULL", "int8", nil, nil, nil). AddRow("age", "bigint", "NO", "NULL", "int8", nil, nil, nil)) mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "CURRENT_SCHEMA()", "users"))). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). AddRow("users_pkey", "id", "t", "t", 0). AddRow("users_age_key", "age", "f", "t", 0)) mock.ExpectCommit() }, }, { name: "remove uniqueness from column with option", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "age", Type: field.TypeInt}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, options: []MigrateOption{WithDropIndex(true)}, before: func(mock pgMock) { mock.start("120000") mock.tableExists("users", true) mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length"}). AddRow("id", "bigint", "NO", "NULL", "int8", nil, nil, nil). AddRow("age", "bigint", "NO", "NULL", "int8", nil, nil, nil)) mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "CURRENT_SCHEMA()", "users"))). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). AddRow("users_pkey", "id", "t", "t", 0). AddRow("users_age_key", "age", "f", "t", 0)) mock.ExpectQuery(escape(`SELECT COUNT(*) FROM "information_schema"."table_constraints" WHERE "table_schema" = CURRENT_SCHEMA() AND "constraint_type" = $1 AND "constraint_name" = $2`)). WithArgs("UNIQUE", "users_age_key"). WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1)) mock.ExpectExec(escape(`ALTER TABLE "users" DROP CONSTRAINT "users_age_key"`)). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "add and remove indexes", tables: func() []*Table { c1 := []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, // Add implicit index. 
{Name: "age", Type: field.TypeInt, Unique: true}, {Name: "score", Type: field.TypeInt}, } c2 := []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "score", Type: field.TypeInt}, {Name: "email", Type: field.TypeString}, } return []*Table{ { Name: "users", Columns: c1, PrimaryKey: c1[0:1], Indexes: Indexes{ // Change non-unique index to unique. {Name: "user_score", Columns: c1[2:3], Unique: true}, }, }, { Name: "equipment", Columns: c2, PrimaryKey: c2[0:1], Indexes: Indexes{ {Name: "equipment_score", Columns: c2[1:2]}, // Index should not be changed. {Name: "equipment_email", Unique: true, Columns: c2[2:]}, }, }, } }(), options: []MigrateOption{WithDropIndex(true)}, before: func(mock pgMock) { mock.start("120000") mock.tableExists("users", true) mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length"}). AddRow("id", "bigint", "NO", "NULL", "int8", nil, nil, nil). AddRow("age", "bigint", "NO", "NULL", "int8", nil, nil, nil). AddRow("score", "bigint", "NO", "NULL", "int8", nil, nil, nil)) mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "CURRENT_SCHEMA()", "users"))). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). AddRow("users_pkey", "id", "t", "t", 0). AddRow("user_score", "score", "f", "f", 0)) mock.ExpectQuery(escape(`SELECT COUNT(*) FROM "information_schema"."table_constraints" WHERE "table_schema" = CURRENT_SCHEMA() AND "constraint_type" = $1 AND "constraint_name" = $2`)). WithArgs("UNIQUE", "user_score"). WillReturnRows(sqlmock.NewRows([]string{"count"}). AddRow(0)) mock.ExpectExec(escape(`DROP INDEX "user_score"`)). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectExec(escape(`CREATE UNIQUE INDEX IF NOT EXISTS "users_age" ON "users"("age")`)). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectExec(escape(`CREATE UNIQUE INDEX IF NOT EXISTS "user_score" ON "users"("score")`)). WillReturnResult(sqlmock.NewResult(0, 1)) mock.tableExists("equipment", true) mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). WithArgs("equipment"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length"}). AddRow("id", "bigint", "NO", "NULL", "int8", nil, nil, nil). AddRow("score", "bigint", "NO", "NULL", "int8", nil, nil, nil). AddRow("email", "character varying", "YES", "NULL", "varchar", nil, nil, nil)) mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "CURRENT_SCHEMA()", "equipment"))). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). AddRow("users_pkey", "id", "t", "t", 0). AddRow("equipment_score", "score", "f", "f", 0). 
AddRow("equipment_email", "email", "f", "t", 0)) mock.ExpectCommit() }, }, { name: "add edge to table", tables: func() []*Table { var ( c1 = []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Nullable: true}, {Name: "spouse_id", Type: field.TypeInt, Nullable: true}, } t1 = &Table{ Name: "users", Columns: c1, PrimaryKey: c1[0:1], ForeignKeys: []*ForeignKey{ { Symbol: "user_spouse" + strings.Repeat("_", 64), // super long fk. Columns: c1[2:], RefColumns: c1[0:1], OnDelete: Cascade, }, }, } ) t1.ForeignKeys[0].RefTable = t1 return []*Table{t1} }(), before: func(mock pgMock) { mock.start("120000") mock.tableExists("users", true) mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length"}). AddRow("id", "bigint", "YES", "NULL", "int8", nil, nil, nil). AddRow("name", "character", "YES", "NULL", "bpchar", nil, nil, nil)) mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "CURRENT_SCHEMA()", "users"))). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). AddRow("users_pkey", "id", "t", "t", 0)) mock.ExpectExec(escape(`ALTER TABLE "users" ADD COLUMN "spouse_id" bigint NULL`)). WillReturnResult(sqlmock.NewResult(0, 1)) mock.fkExists("user_spouse____________________390ed76f91d3c57cd3516e7690f621dc", false) mock.ExpectExec(`ALTER TABLE "users" ADD CONSTRAINT ".{63}" FOREIGN KEY\("spouse_id"\) REFERENCES "users"\("id"\) ON DELETE CASCADE`). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "universal id for all tables", tables: []*Table{ NewTable("users").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}), NewTable("groups").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}), }, options: []MigrateOption{WithGlobalUniqueID(true)}, before: func(mock pgMock) { mock.start("120000") mock.tableExists("ent_types", false) // create ent_types table. mock.ExpectExec(escape(`CREATE TABLE IF NOT EXISTS "ent_types"("id" bigint GENERATED BY DEFAULT AS IDENTITY NOT NULL, "type" varchar UNIQUE NOT NULL, PRIMARY KEY("id"))`)). WillReturnResult(sqlmock.NewResult(0, 1)) mock.tableExists("users", false) mock.ExpectExec(escape(`CREATE TABLE IF NOT EXISTS "users"("id" bigint GENERATED BY DEFAULT AS IDENTITY NOT NULL, PRIMARY KEY("id"))`)). WillReturnResult(sqlmock.NewResult(0, 1)) // set users id range. mock.ExpectExec(escape(`INSERT INTO "ent_types" ("type") VALUES ($1)`)). WithArgs("users"). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectExec(`ALTER TABLE "users" ALTER COLUMN "id" RESTART WITH 1`). WillReturnResult(sqlmock.NewResult(0, 1)) mock.tableExists("groups", false) mock.ExpectExec(escape(`CREATE TABLE IF NOT EXISTS "groups"("id" bigint GENERATED BY DEFAULT AS IDENTITY NOT NULL, PRIMARY KEY("id"))`)). WillReturnResult(sqlmock.NewResult(0, 1)) // set groups id range. mock.ExpectExec(escape(`INSERT INTO "ent_types" ("type") VALUES ($1)`)). WithArgs("groups"). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectExec(`ALTER TABLE "groups" ALTER COLUMN "id" RESTART WITH 4294967296`). 
WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "universal id for new tables", tables: []*Table{ NewTable("users").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}), NewTable("groups").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}), }, options: []MigrateOption{WithGlobalUniqueID(true)}, before: func(mock pgMock) { mock.start("120000") mock.tableExists("ent_types", true) // query ent_types table. mock.ExpectQuery(`SELECT "type" FROM "ent_types" ORDER BY "id" ASC`). WillReturnRows(sqlmock.NewRows([]string{"type"}).AddRow("users")) // query users table. mock.tableExists("users", true) // users table has no changes. mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length"}). AddRow("id", "bigint", "YES", "NULL", "int8", nil, nil, nil)) mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "CURRENT_SCHEMA()", "users"))). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). AddRow("users_pkey", "id", "t", "t", 0)) // query groups table. mock.tableExists("groups", false) mock.ExpectExec(escape(`CREATE TABLE IF NOT EXISTS "groups"("id" bigint GENERATED BY DEFAULT AS IDENTITY NOT NULL, PRIMARY KEY("id"))`)). WillReturnResult(sqlmock.NewResult(0, 1)) // set groups id range. mock.ExpectExec(escape(`INSERT INTO "ent_types" ("type") VALUES ($1)`)). WithArgs("groups"). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectExec(`ALTER TABLE "groups" ALTER COLUMN "id" RESTART WITH 4294967296`). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "universal id for restored tables", tables: []*Table{ NewTable("users").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}), NewTable("groups").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}), }, options: []MigrateOption{WithGlobalUniqueID(true)}, before: func(mock pgMock) { mock.start("120000") mock.tableExists("ent_types", true) // query ent_types table. mock.ExpectQuery(`SELECT "type" FROM "ent_types" ORDER BY "id" ASC`). WillReturnRows(sqlmock.NewRows([]string{"type"}).AddRow("users")) // query and create users (restored table). mock.tableExists("users", false) mock.ExpectExec(escape(`CREATE TABLE IF NOT EXISTS "users"("id" bigint GENERATED BY DEFAULT AS IDENTITY NOT NULL, PRIMARY KEY("id"))`)). WillReturnResult(sqlmock.NewResult(0, 1)) // set users id range (without inserting to ent_types). mock.ExpectExec(`ALTER TABLE "users" ALTER COLUMN "id" RESTART WITH 1`). WillReturnResult(sqlmock.NewResult(0, 1)) // query groups table. mock.tableExists("groups", false) mock.ExpectExec(escape(`CREATE TABLE IF NOT EXISTS "groups"("id" bigint GENERATED BY DEFAULT AS IDENTITY NOT NULL, PRIMARY KEY("id"))`)). WillReturnResult(sqlmock.NewResult(0, 1)) // set groups id range. mock.ExpectExec(escape(`INSERT INTO "ent_types" ("type") VALUES ($1)`)). WithArgs("groups"). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectExec(`ALTER TABLE "groups" ALTER COLUMN "id" RESTART WITH 4294967296`). 
WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "no modify numeric column", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "price", Type: field.TypeFloat64, SchemaType: map[string]string{dialect.Postgres: "numeric(6,4)"}}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock pgMock) { mock.start("120000") mock.tableExists("users", true) mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length"}). AddRow("id", "bigint", "NO", "NULL", "int8", nil, nil, nil). AddRow("price", "numeric", "NO", "NULL", "numeric", "6", "4", nil)) mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "CURRENT_SCHEMA()", "users"))). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). AddRow("users_pkey", "id", "t", "t", 0)) mock.ExpectCommit() }, }, { name: "modify numeric column", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "price", Type: field.TypeFloat64, Nullable: false, SchemaType: map[string]string{dialect.Postgres: "numeric(6,4)"}}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock pgMock) { mock.start("120000") mock.tableExists("users", true) mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length"}). AddRow("id", "bigint", "NO", "NULL", "int8", nil, nil, nil). AddRow("price", "numeric", "NO", "NULL", "numeric", "5", "4", nil)) mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "CURRENT_SCHEMA()", "users"))). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). AddRow("users_pkey", "id", "t", "t", 0)) mock.ExpectExec(escape(`ALTER TABLE "users" ALTER COLUMN "price" TYPE numeric(6,4), ALTER COLUMN "price" SET NOT NULL`)). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, { name: "no modify fixed size varchar column", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, SchemaType: map[string]string{dialect.Postgres: "varchar(20)"}}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock pgMock) { mock.start("120000") mock.tableExists("users", true) mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). WithArgs("users"). 
WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length"}). AddRow("id", "bigint", "NO", "NULL", "int8", nil, nil, nil). AddRow("name", "character varying", "NO", "NULL", "varchar", nil, nil, 20)) mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "CURRENT_SCHEMA()", "users"))). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). AddRow("users_pkey", "id", "t", "t", 0)) mock.ExpectCommit() }, }, { name: "modify fixed size varchar column", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, SchemaType: map[string]string{dialect.Postgres: "varchar(20)"}, Default: "unknown"}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock pgMock) { mock.start("120000") mock.tableExists("users", true) mock.ExpectQuery(escape(`SELECT "column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length" FROM "information_schema"."columns" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type", "is_nullable", "column_default", "udt_name", "numeric_precision", "numeric_scale", "character_maximum_length"}). AddRow("id", "bigint", "NO", "NULL", "int8", nil, nil, nil). AddRow("name", "character varying", "NO", "NULL", "varchar", nil, nil, 10)) mock.ExpectQuery(escape(fmt.Sprintf(indexesQuery, "CURRENT_SCHEMA()", "users"))). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "primary", "unique", "seq_in_index"}). AddRow("users_pkey", "id", "t", "t", 0)) mock.ExpectExec(escape(`ALTER TABLE "users" ALTER COLUMN "name" TYPE varchar(20), ALTER COLUMN "name" SET NOT NULL, ALTER COLUMN "name" SET DEFAULT 'unknown'`)). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { db, mock, err := sqlmock.New() require.NoError(t, err) tt.before(pgMock{mock}) migrate, err := NewMigrate(sql.OpenDB("postgres", db), append(tt.options, WithAtlas(false))...) require.NoError(t, err) err = migrate.Create(context.Background(), tt.tables...) require.Equal(t, tt.wantErr, err != nil, err) }) } } type pgMock struct { sqlmock.Sqlmock } func (m pgMock) start(version string) { m.ExpectQuery(escape("SHOW server_version_num")). WillReturnRows(sqlmock.NewRows([]string{"server_version_num"}).AddRow(version)) m.ExpectBegin() } func (m pgMock) tableExists(table string, exists bool) { count := 0 if exists { count = 1 } m.ExpectQuery(escape(`SELECT COUNT(*) FROM "information_schema"."tables" WHERE "table_schema" = CURRENT_SCHEMA() AND "table_name" = $1`)). WithArgs(table). WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(count)) } func (m pgMock) fkExists(fk string, exists bool) { count := 0 if exists { count = 1 } m.ExpectQuery(escape(`SELECT COUNT(*) FROM "information_schema"."table_constraints" WHERE "table_schema" = CURRENT_SCHEMA() AND "constraint_type" = $1 AND "constraint_name" = $2`)). WithArgs("FOREIGN KEY", fk). WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(count)) } ent-0.11.3/dialect/sql/schema/schema.go000066400000000000000000000455241431500740500176260ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. 
// This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. // Package schema contains all schema migration logic for SQL dialects. package schema import ( "fmt" "sort" "strconv" "strings" "entgo.io/ent/dialect/entsql" "entgo.io/ent/dialect/sql" "entgo.io/ent/schema/field" ) const ( // DefaultStringLen describes the default length for string/varchar types. DefaultStringLen int64 = 255 // Null is the string representation of NULL in SQL. Null = "NULL" // PrimaryKey is the string representation of PKs in SQL. PrimaryKey = "PRI" // UniqueKey is the string representation of unique keys in SQL. UniqueKey = "UNI" ) // Table schema definition for SQL dialects. type Table struct { Name string Columns []*Column columns map[string]*Column Indexes []*Index PrimaryKey []*Column ForeignKeys []*ForeignKey Annotation *entsql.Annotation } // NewTable returns a new table with the given name. func NewTable(name string) *Table { return &Table{ Name: name, columns: make(map[string]*Column), } } // AddPrimary adds a new primary key to the table. func (t *Table) AddPrimary(c *Column) *Table { c.Key = PrimaryKey t.AddColumn(c) t.PrimaryKey = append(t.PrimaryKey, c) return t } // AddForeignKey adds a foreign key to the table. func (t *Table) AddForeignKey(fk *ForeignKey) *Table { t.ForeignKeys = append(t.ForeignKeys, fk) return t } // AddColumn adds a new column to the table. func (t *Table) AddColumn(c *Column) *Table { t.columns[c.Name] = c t.Columns = append(t.Columns, c) return t } // HasColumn reports if the table contains a column with the given name. func (t *Table) HasColumn(name string) bool { _, ok := t.Column(name) return ok } // Column returns the column with the given name, if it exists. func (t *Table) Column(name string) (*Column, bool) { if c, ok := t.columns[name]; ok { return c, true } // In case the column was added // directly to the Columns field. for _, c := range t.Columns { if c.Name == name { return c, true } } return nil, false } // SetAnnotation sets the entsql.Annotation on the table. func (t *Table) SetAnnotation(ant *entsql.Annotation) *Table { t.Annotation = ant return t } // AddIndex creates and adds a new index to the table from the given options. func (t *Table) AddIndex(name string, unique bool, columns []string) *Table { return t.addIndex(&Index{ Name: name, Unique: unique, columns: columns, Columns: make([]*Column, 0, len(columns)), }) } // addIndex adds the given index to the table and links it to the matching columns. func (t *Table) addIndex(idx *Index) *Table { for _, name := range idx.columns { c, ok := t.columns[name] if ok { c.indexes.append(idx) idx.Columns = append(idx.Columns, c) } } t.Indexes = append(t.Indexes, idx) return t } // column returns a table column by its name. // Faster than a map lookup in most cases. func (t *Table) column(name string) (*Column, bool) { for _, c := range t.Columns { if c.Name == name { return c, true } } return nil, false } // Index returns a table index by its exact name. func (t *Table) Index(name string) (*Index, bool) { idx, ok := t.index(name) if ok && idx.Name == name { return idx, ok } return nil, false } // index returns a table index by its name. func (t *Table) index(name string) (*Index, bool) { for _, idx := range t.Indexes { if name == idx.Name || name == idx.realname { return idx, true } // Same as below, there are cases where the index name // is unknown (created automatically on column constraint).
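// For example (names borrowed from the tests in this package, for illustration
// only): a unique column "age" in table "users" may surface as an index named
// "age", or as "users_age_key" under the Postgres convention handled below.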
if len(idx.Columns) == 1 && idx.Columns[0].Name == name { return idx, true } } // If it is an "implicit index" (unique constraint on // table creation) and it wasn't loaded in table scanning. c, ok := t.column(name) if !ok { // Postgres naming convention for unique constraint (<table>_<column>_key).
name = strings.TrimPrefix(name, t.Name+"_") name = strings.TrimSuffix(name, "_key") c, ok = t.column(name) } if ok && c.Unique { return &Index{Name: name, Unique: c.Unique, Columns: []*Column{c}, columns: []string{c.Name}}, true } return nil, false } // hasIndex reports if the table has at least one index that matches the given names. func (t *Table) hasIndex(names ...string) bool { for i := range names { if names[i] == "" { continue } if _, ok := t.index(names[i]); ok { return true } } return false } // fk returns a table foreign-key by its symbol. // Faster than a map lookup in most cases. func (t *Table) fk(symbol string) (*ForeignKey, bool) { for _, fk := range t.ForeignKeys { if fk.Symbol == symbol { return fk, true } } return nil, false } // CopyTables returns a deep-copy of the given tables. This utility function is // useful for copying the generated schema tables (i.e. migrate.Tables) before // running schema migration when there is a need to execute multiple migrations // concurrently, e.g. running parallel unit-tests using the generated enttest package. func CopyTables(tables []*Table) ([]*Table, error) { var ( copyT = make([]*Table, len(tables)) byName = make(map[string]*Table) ) for i, t := range tables { copyT[i] = &Table{ Name: t.Name, Columns: make([]*Column, len(t.Columns)), Indexes: make([]*Index, len(t.Indexes)), ForeignKeys: make([]*ForeignKey, len(t.ForeignKeys)), } for j, c := range t.Columns { cc := *c // SchemaType and Enums are read-only fields. cc.indexes = nil cc.foreign = nil copyT[i].Columns[j] = &cc } if at := t.Annotation; at != nil { cat := *at copyT[i].Annotation = &cat } byName[t.Name] = copyT[i] } for i, t := range tables { ct := copyT[i] for _, c := range t.PrimaryKey { cc, ok := ct.column(c.Name) if !ok { return nil, fmt.Errorf("sql/schema: missing primary key column %q", c.Name) } ct.PrimaryKey = append(ct.PrimaryKey, cc) } for j, idx := range t.Indexes { cidx := &Index{ Name: idx.Name, Unique: idx.Unique, Columns: make([]*Column, len(idx.Columns)), } if at := idx.Annotation; at != nil { cat := *at cidx.Annotation = &cat } for k, c := range idx.Columns { cc, ok := ct.column(c.Name) if !ok { return nil, fmt.Errorf("sql/schema: missing index column %q", c.Name) } cidx.Columns[k] = cc } ct.Indexes[j] = cidx } for j, fk := range t.ForeignKeys { cfk := &ForeignKey{ Symbol: fk.Symbol, OnUpdate: fk.OnUpdate, OnDelete: fk.OnDelete, Columns: make([]*Column, len(fk.Columns)), RefColumns: make([]*Column, len(fk.RefColumns)), } for k, c := range fk.Columns { cc, ok := ct.column(c.Name) if !ok { return nil, fmt.Errorf("sql/schema: missing foreign-key column %q", c.Name) } cfk.Columns[k] = cc } cref, ok := byName[fk.RefTable.Name] if !ok { return nil, fmt.Errorf("sql/schema: missing foreign-key ref-table %q", fk.RefTable.Name) } cfk.RefTable = cref for k, c := range fk.RefColumns { cc, ok := cref.column(c.Name) if !ok { return nil, fmt.Errorf("sql/schema: missing foreign-key ref-column %q", c.Name) } cfk.RefColumns[k] = cc } ct.ForeignKeys[j] = cfk } } return copyT, nil } // Column schema definition for SQL dialects. type Column struct { Name string // column name. Type field.Type // column type. SchemaType map[string]string // optional schema type per dialect. Attr string // extra attributes. Size int64 // max size parameter for string, blob, etc. Key string // key definition (PRI, UNI or MUL). Unique bool // column with unique constraint. Increment bool // auto increment attribute. Nullable bool // null or not null attribute. Default any // default value.
Enums []string // enum values. Collation string // collation type (utf8mb4_unicode_ci, utf8mb4_general_ci). typ string // raw column type (used for Rows.Scan). indexes Indexes // linked indexes. foreign *ForeignKey // linked foreign-key. } // UniqueKey returns a boolean indicating whether this column is a unique key. // Used by the migration tool when parsing the `DESCRIBE TABLE` output to Go objects. func (c *Column) UniqueKey() bool { return c.Key == UniqueKey } // PrimaryKey returns a boolean indicating whether this column is one of the primary key columns. // Used by the migration tool when parsing the `DESCRIBE TABLE` output to Go objects. func (c *Column) PrimaryKey() bool { return c.Key == PrimaryKey } // ConvertibleTo reports whether a column can be converted to the new column without altering its data. func (c *Column) ConvertibleTo(d *Column) bool { switch { case c.Type == d.Type: if c.Size != 0 && d.Size != 0 { // Types match and have a size constraint. return c.Size <= d.Size } return true case c.IntType() && d.IntType() || c.UintType() && d.UintType(): return c.Type <= d.Type case c.UintType() && d.IntType(): // uintX cannot be converted to intY when X > Y. return c.Type-field.TypeUint8 <= d.Type-field.TypeInt8 case c.Type == field.TypeString && d.Type == field.TypeEnum || c.Type == field.TypeEnum && d.Type == field.TypeString: return true case c.Type.Integer() && d.Type == field.TypeString: return true } return c.FloatType() && d.FloatType() } // IntType reports whether the column is an int type (int8 ... int64). func (c Column) IntType() bool { return c.Type >= field.TypeInt8 && c.Type <= field.TypeInt64 } // UintType reports whether the column is a uint type (uint8 ... uint64). func (c Column) UintType() bool { return c.Type >= field.TypeUint8 && c.Type <= field.TypeUint64 } // FloatType reports whether the column is a float type (float32, float64). func (c Column) FloatType() bool { return c.Type == field.TypeFloat32 || c.Type == field.TypeFloat64 } // ScanDefault scans the default value string to its interface type. func (c *Column) ScanDefault(value string) error { switch { case strings.ToUpper(value) == Null: // ignore.
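// For example (values from TestColumn_ScanDefault in schema_test.go): scanning
// "128" for a TypeInt64 column stores int64(128), "128.1" for a TypeFloat64
// column stores 128.1, and "Hello World" for a TypeString column stores the
// string as-is.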
case c.IntType(): v := &sql.NullInt64{} if err := v.Scan(value); err != nil { return fmt.Errorf("scanning int value for column %q: %w", c.Name, err) } c.Default = v.Int64 case c.UintType(): v := &sql.NullInt64{} if err := v.Scan(value); err != nil { return fmt.Errorf("scanning uint value for column %q: %w", c.Name, err) } c.Default = uint64(v.Int64) case c.FloatType(): v := &sql.NullFloat64{} if err := v.Scan(value); err != nil { return fmt.Errorf("scanning float value for column %q: %w", c.Name, err) } c.Default = v.Float64 case c.Type == field.TypeBool: v := &sql.NullBool{} if err := v.Scan(value); err != nil { return fmt.Errorf("scanning bool value for column %q: %w", c.Name, err) } c.Default = v.Bool case c.Type == field.TypeString || c.Type == field.TypeEnum: v := &sql.NullString{} if err := v.Scan(value); err != nil { return fmt.Errorf("scanning string value for column %q: %w", c.Name, err) } c.Default = v.String case c.Type == field.TypeJSON: v := &sql.NullString{} if err := v.Scan(value); err != nil { return fmt.Errorf("scanning json value for column %q: %w", c.Name, err) } c.Default = v.String case c.Type == field.TypeBytes: c.Default = []byte(value) case c.Type == field.TypeUUID: // Skip function defaults (e.g. gen_random_uuid()). if !strings.Contains(value, "()") { c.Default = value } default: return fmt.Errorf("unsupported default type: %v default to %q", c.Type, value) } return nil } // defaultValue adds the `DEFAULT` attribute to the column. // Note that in SQLite, if a NOT NULL constraint is specified, // then the column must have a default value that is not NULL. func (c *Column) defaultValue(b *sql.ColumnBuilder) { if c.Default == nil || !c.supportDefault() { return } // Has default and the database supports adding this default. attr := fmt.Sprint(c.Default) switch v := c.Default.(type) { case bool: attr = strconv.FormatBool(v) case string: if t := c.Type; t != field.TypeUUID && t != field.TypeTime { // Escape single quote by replacing each with 2. attr = fmt.Sprintf("'%s'", strings.ReplaceAll(v, "'", "''")) } } b.Attr("DEFAULT " + attr) } // supportDefault reports if the column type supports a default value. func (c Column) supportDefault() bool { switch t := c.Type; t { case field.TypeString, field.TypeEnum: return c.Size < 1<<16 // not a text. case field.TypeBool, field.TypeTime, field.TypeUUID: return true default: return t.Numeric() } } // unique adds the `UNIQUE` attribute if the column is a unique type. // It exists in a separate function to share the common declaration // between the two dialects. func (c *Column) unique(b *sql.ColumnBuilder) { if c.Unique { b.Attr("UNIQUE") } } // nullable adds the `NULL`/`NOT NULL` attribute to the column. It exists in // a separate function to share the common declaration between the two dialects. func (c *Column) nullable(b *sql.ColumnBuilder) { attr := Null if !c.Nullable { attr = "NOT " + attr } b.Attr(attr) } // scanTypeOr returns the scanning type or the given value. func (c *Column) scanTypeOr(t string) string { if c.typ != "" { return strings.ToLower(c.typ) } return t } // ForeignKey definition for creation. type ForeignKey struct { Symbol string // foreign-key name. Generated if empty. Columns []*Column // table columns. RefTable *Table // referenced table. RefColumns []*Column // referenced columns. OnUpdate ReferenceOption // action on update. OnDelete ReferenceOption // action on delete.
} func (fk ForeignKey) column(name string) (*Column, bool) { for _, c := range fk.Columns { if c.Name == name { return c, true } } return nil, false } func (fk ForeignKey) refColumn(name string) (*Column, bool) { for _, c := range fk.RefColumns { if c.Name == name { return c, true } } return nil, false } // DSL returns a default DSL query for a foreign-key. func (fk ForeignKey) DSL() *sql.ForeignKeyBuilder { cols := make([]string, len(fk.Columns)) refs := make([]string, len(fk.RefColumns)) for i, c := range fk.Columns { cols[i] = c.Name } for i, c := range fk.RefColumns { refs[i] = c.Name } dsl := sql.ForeignKey().Symbol(fk.Symbol). Columns(cols...). Reference(sql.Reference().Table(fk.RefTable.Name).Columns(refs...)) if action := string(fk.OnDelete); action != "" { dsl.OnDelete(action) } if action := string(fk.OnUpdate); action != "" { dsl.OnUpdate(action) } return dsl } // ReferenceOption for constraint actions. type ReferenceOption string // Reference options. const ( NoAction ReferenceOption = "NO ACTION" Restrict ReferenceOption = "RESTRICT" Cascade ReferenceOption = "CASCADE" SetNull ReferenceOption = "SET NULL" SetDefault ReferenceOption = "SET DEFAULT" ) // ConstName returns the constant name of a reference option. It's used by entc for printing the constant name in templates. func (r ReferenceOption) ConstName() string { return strings.ReplaceAll(strings.Title(strings.ToLower(string(r))), " ", "") } // Index definition for table index. type Index struct { Name string // index name. Unique bool // uniqueness. Columns []*Column // actual table columns. Annotation *entsql.IndexAnnotation // index annotation. columns []string // columns loaded from query scan. primary bool // primary key index. realname string // real name in the database (Postgres only). } // Builder returns the query builder for index creation. The DSL is identical in all dialects. func (i *Index) Builder(table string) *sql.IndexBuilder { idx := sql.CreateIndex(i.Name).Table(table) if i.Unique { idx.Unique() } for _, c := range i.Columns { idx.Column(c.Name) } return idx } // DropBuilder returns the query builder for dropping the index. func (i *Index) DropBuilder(table string) *sql.DropIndexBuilder { idx := sql.DropIndex(i.Name).Table(table) return idx } // sameAs reports if the index has the same properties // as the given index (except the name). func (i *Index) sameAs(idx *Index) bool { if i.Unique != idx.Unique || len(i.Columns) != len(idx.Columns) { return false } for j, c := range i.Columns { if c.Name != idx.Columns[j].Name { return false } } return true } // columnNames returns the names of the columns of the index. func (i *Index) columnNames() []string { if len(i.columns) > 0 { return i.columns } columns := make([]string, 0, len(i.Columns)) for _, c := range i.Columns { columns = append(columns, c.Name) } return columns } // Indexes is used for scanning all sql.Rows into a list of indexes, because // multiple sql rows can represent the same index (multi-column indexes). type Indexes []*Index // append wraps the basic `append` function by filtering duplicate indexes. func (i *Indexes) append(idx1 *Index) { for _, idx2 := range *i { if idx2.Name == idx1.Name { return } } *i = append(*i, idx1) } // compareVersions returns an integer comparing the 2 versions.
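// For example, compareVersions("8.0.19", "8.0.5") returns 1, and
// compareVersions("5.6", "5.7.23-log") returns -1 (parseVersion ignores the
// part after "-" in each component).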
func compareVersions(v1, v2 string) int { pv1, ok1 := parseVersion(v1) pv2, ok2 := parseVersion(v2) if !ok1 && !ok2 { return 0 } if !ok1 { return -1 } if !ok2 { return 1 } if v := compare(pv1.major, pv2.major); v != 0 { return v } if v := compare(pv1.minor, pv2.minor); v != 0 { return v } return compare(pv1.patch, pv2.patch) } // version represents a parsed MySQL version. type version struct { major int minor int patch int } // parseVersion parses the given version string into its major, minor and patch components. func parseVersion(v string) (*version, bool) { parts := strings.Split(v, ".") if len(parts) == 0 { return nil, false } var ( err error ver = &version{} ) for i, e := range []*int{&ver.major, &ver.minor, &ver.patch} { if i == len(parts) { break } if *e, err = strconv.Atoi(strings.Split(parts[i], "-")[0]); err != nil { return nil, false } } return ver, true } func compare(v1, v2 int) int { if v1 == v2 { return 0 } if v1 < v2 { return -1 } return 1 } // addChecks appends the CHECK clauses from the entsql.Annotation. func addChecks(t *sql.TableBuilder, ant *entsql.Annotation) { if check := ant.Check; check != "" { t.Checks(func(b *sql.Builder) { b.WriteString("CHECK " + checkExpr(check)) }) } if checks := ant.Checks; len(ant.Checks) > 0 { names := make([]string, 0, len(checks)) for name := range checks { names = append(names, name) } sort.Strings(names) for _, name := range names { name := name t.Checks(func(b *sql.Builder) { b.WriteString("CONSTRAINT ").Ident(name).WriteString(" CHECK " + checkExpr(checks[name])) }) } } } // checkExpr formats the CHECK expression. func checkExpr(expr string) string { expr = strings.TrimSpace(expr) if !strings.HasPrefix(expr, "(") && !strings.HasSuffix(expr, ")") { expr = "(" + expr + ")" } return expr } ent-0.11.3/dialect/sql/schema/schema_test.go000066400000000000000000000130511431500740500206530ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree.
package schema import ( "testing" "entgo.io/ent/dialect/entsql" "entgo.io/ent/schema/field" "github.com/stretchr/testify/require" ) func TestColumn_ConvertibleTo(t *testing.T) { c1 := &Column{Type: field.TypeString, Size: 10} require.True(t, c1.ConvertibleTo(&Column{Type: field.TypeString, Size: 10})) require.True(t, c1.ConvertibleTo(&Column{Type: field.TypeString, Size: 255})) require.False(t, c1.ConvertibleTo(&Column{Type: field.TypeString, Size: 9})) require.False(t, c1.ConvertibleTo(&Column{Type: field.TypeFloat32})) c1 = &Column{Type: field.TypeFloat32} require.True(t, c1.ConvertibleTo(&Column{Type: field.TypeFloat32})) require.True(t, c1.ConvertibleTo(&Column{Type: field.TypeFloat64})) require.False(t, c1.ConvertibleTo(&Column{Type: field.TypeString})) require.False(t, c1.ConvertibleTo(&Column{Type: field.TypeUint})) c1 = &Column{Type: field.TypeFloat64} require.True(t, c1.ConvertibleTo(&Column{Type: field.TypeFloat32})) require.True(t, c1.ConvertibleTo(&Column{Type: field.TypeFloat64})) require.False(t, c1.ConvertibleTo(&Column{Type: field.TypeString})) require.False(t, c1.ConvertibleTo(&Column{Type: field.TypeUint})) c1 = &Column{Type: field.TypeUint} require.True(t, c1.ConvertibleTo(&Column{Type: field.TypeUint})) require.True(t, c1.ConvertibleTo(&Column{Type: field.TypeInt})) require.True(t, c1.ConvertibleTo(&Column{Type: field.TypeInt64})) require.True(t, c1.ConvertibleTo(&Column{Type: field.TypeUint64})) require.False(t, c1.ConvertibleTo(&Column{Type: field.TypeInt8})) require.False(t, c1.ConvertibleTo(&Column{Type: field.TypeUint8})) require.False(t, c1.ConvertibleTo(&Column{Type: field.TypeUint16})) require.False(t, c1.ConvertibleTo(&Column{Type: field.TypeUint32})) require.True(t, c1.ConvertibleTo(&Column{Type: field.TypeString})) require.True(t, c1.ConvertibleTo(&Column{Type: field.TypeString, Size: 1})) c1 = &Column{Type: field.TypeInt} require.True(t, c1.ConvertibleTo(&Column{Type: field.TypeInt})) require.True(t, c1.ConvertibleTo(&Column{Type: field.TypeInt64})) require.False(t, c1.ConvertibleTo(&Column{Type: field.TypeInt8})) require.False(t, c1.ConvertibleTo(&Column{Type: field.TypeInt32})) require.False(t, c1.ConvertibleTo(&Column{Type: field.TypeUint})) require.False(t, c1.ConvertibleTo(&Column{Type: field.TypeUint8})) require.False(t, c1.ConvertibleTo(&Column{Type: field.TypeUint16})) require.False(t, c1.ConvertibleTo(&Column{Type: field.TypeUint32})) require.True(t, c1.ConvertibleTo(&Column{Type: field.TypeString})) require.True(t, c1.ConvertibleTo(&Column{Type: field.TypeString, Size: 1})) } func TestColumn_ScanDefault(t *testing.T) { c1 := &Column{Type: field.TypeString, Size: 10} require.NoError(t, c1.ScanDefault("Hello World")) require.Equal(t, "Hello World", c1.Default) require.NoError(t, c1.ScanDefault("1")) require.Equal(t, "1", c1.Default) c1 = &Column{Type: field.TypeInt64} require.NoError(t, c1.ScanDefault("128")) require.Equal(t, int64(128), c1.Default) require.NoError(t, c1.ScanDefault("1")) require.Equal(t, int64(1), c1.Default) require.Error(t, c1.ScanDefault("foo")) c1 = &Column{Type: field.TypeUint64} require.NoError(t, c1.ScanDefault("128")) require.Equal(t, uint64(128), c1.Default) require.NoError(t, c1.ScanDefault("1")) require.Equal(t, uint64(1), c1.Default) require.Error(t, c1.ScanDefault("foo")) c1 = &Column{Type: field.TypeFloat64} require.NoError(t, c1.ScanDefault("128.1")) require.Equal(t, 128.1, c1.Default) require.NoError(t, c1.ScanDefault("1")) require.Equal(t, float64(1), c1.Default) require.Error(t, c1.ScanDefault("foo")) c1 = 
&Column{Type: field.TypeBool} require.NoError(t, c1.ScanDefault("1")) require.Equal(t, true, c1.Default) require.NoError(t, c1.ScanDefault("true")) require.Equal(t, true, c1.Default) require.NoError(t, c1.ScanDefault("0")) require.Equal(t, false, c1.Default) require.NoError(t, c1.ScanDefault("false")) require.Equal(t, false, c1.Default) require.Error(t, c1.ScanDefault("foo")) c1 = &Column{Type: field.TypeUUID} require.NoError(t, c1.ScanDefault("gen_random_uuid()")) require.Equal(t, nil, c1.Default) require.NoError(t, c1.ScanDefault("00000000-0000-0000-0000-000000000000")) require.Equal(t, "00000000-0000-0000-0000-000000000000", c1.Default) } func TestCopyTables(t *testing.T) { users := &Table{ Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt}, {Name: "name", Type: field.TypeString}, {Name: "spouse_id", Type: field.TypeInt}, }, } users.PrimaryKey = users.Columns[:1] users.Indexes = append(users.Indexes, &Index{ Name: "name", Columns: users.Columns[1:2], }) users.AddForeignKey(&ForeignKey{ Columns: users.Columns[2:], RefTable: users, RefColumns: users.Columns[:1], OnUpdate: SetNull, }) users.SetAnnotation(&entsql.Annotation{Table: "Users"}) pets := &Table{ Name: "pets", Columns: []*Column{ {Name: "id", Type: field.TypeInt}, {Name: "name", Type: field.TypeString}, {Name: "owner_id", Type: field.TypeInt}, }, } pets.Indexes = append(pets.Indexes, &Index{ Name: "name", Unique: true, Columns: pets.Columns[1:2], Annotation: entsql.Desc(), }) pets.AddForeignKey(&ForeignKey{ Columns: pets.Columns[2:], RefTable: users, RefColumns: users.Columns[:1], OnDelete: SetDefault, }) tables := []*Table{users, pets} copyT, err := CopyTables(tables) require.NoError(t, err) require.Equal(t, tables, copyT) } ent-0.11.3/dialect/sql/schema/sqlite.go000066400000000000000000000367131431500740500176670ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package schema import ( "context" stdsql "database/sql" "fmt" "strconv" "strings" "entgo.io/ent/dialect" "entgo.io/ent/dialect/sql" "entgo.io/ent/schema/field" "ariga.io/atlas/sql/migrate" "ariga.io/atlas/sql/schema" "ariga.io/atlas/sql/sqlite" ) type ( // SQLite is an SQLite migration driver. SQLite struct { dialect.Driver WithForeignKeys bool } // SQLiteTx implements dialect.Tx. SQLiteTx struct { dialect.Tx commit func() error // Override Commit to toggle foreign keys back on after Commit. rollback func() error // Override Rollback to toggle foreign keys back on after Rollback. } ) // Tx opens a transaction. func (d *SQLite) Tx(ctx context.Context) (dialect.Tx, error) { db := &db{d} if _, err := db.ExecContext(ctx, "PRAGMA foreign_keys = off"); err != nil { return nil, fmt.Errorf("sqlite: set 'foreign_keys = off': %w", err) } t, err := d.Driver.Tx(ctx) if err != nil { return nil, err } tx := &tx{t} cm, err := sqlite.CommitFunc(ctx, db, tx, true) if err != nil { return nil, err } return &SQLiteTx{Tx: t, commit: cm, rollback: sqlite.RollbackFunc(ctx, db, tx, true)}, nil } // Commit ensures foreign keys are toggled back on after commit. func (tx *SQLiteTx) Commit() error { return tx.commit() } // Rollback ensures foreign keys are toggled back on after rollback. func (tx *SQLiteTx) Rollback() error { return tx.rollback() } // init makes sure that foreign_keys support is enabled.
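// With the github.com/mattn/go-sqlite3 driver, for example, the pragma can be
// enabled through the DSN, e.g. "file:ent?mode=memory&cache=shared&_fk=1" (the
// "_fk=1" parameter referenced by the error below).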
func (d *SQLite) init(ctx context.Context) error { on, err := exist(ctx, d, "PRAGMA foreign_keys") if err != nil { return fmt.Errorf("sqlite: check foreign_keys pragma: %w", err) } if !on { // foreign_keys pragma is off, either enable it by executing "PRAGMA foreign_keys=ON" // or add the "_fk=1" parameter to the connection string. return fmt.Errorf("sqlite: foreign_keys pragma is off: missing %q in the connection string", "_fk=1") } return nil } func (d *SQLite) tableExist(ctx context.Context, conn dialect.ExecQuerier, name string) (bool, error) { query, args := sql.Select().Count(). From(sql.Table("sqlite_master")). Where(sql.And( sql.EQ("type", "table"), sql.EQ("name", name), )). Query() return exist(ctx, conn, query, args...) } // setRange sets the start value of table PK. // SQLite tracks the AUTOINCREMENT in the "sqlite_sequence" table that is created and initialized automatically // whenever a table that contains an AUTOINCREMENT column is created. However, a row for a given table is populated // only after the table's first insertion. Therefore, we check: if a record for the given table already exists in the // "sqlite_sequence" table, we update it. Otherwise, we insert a new one. func (d *SQLite) setRange(ctx context.Context, conn dialect.ExecQuerier, t *Table, value int64) error { query, args := sql.Select().Count(). From(sql.Table("sqlite_sequence")). Where(sql.EQ("name", t.Name)). Query() exists, err := exist(ctx, conn, query, args...) switch { case err != nil: return err case exists: query, args = sql.Update("sqlite_sequence").Set("seq", value).Where(sql.EQ("name", t.Name)).Query() default: // !exists query, args = sql.Insert("sqlite_sequence").Columns("name", "seq").Values(t.Name, value).Query() } return conn.Exec(ctx, query, args, nil) } func (d *SQLite) tBuilder(t *Table) *sql.TableBuilder { b := sql.CreateTable(t.Name) for _, c := range t.Columns { b.Column(d.addColumn(c)) } if t.Annotation != nil { addChecks(b, t.Annotation) } // Unlike in MySQL, we're not able to add foreign-key constraints to a table // after it was created, and adding them to the `CREATE TABLE` statement is // not always valid (because circular foreign-keys situation is possible). We // stay consistent by not using constraints at all, and just defining the // foreign keys in the `CREATE TABLE` statement. if d.WithForeignKeys { for _, fk := range t.ForeignKeys { b.ForeignKeys(fk.DSL()) } } // If it's an ID based primary key with autoincrement, we add // the `PRIMARY KEY` clause to the column declaration. Otherwise, // we append it to the constraint clause. if len(t.PrimaryKey) == 1 && t.PrimaryKey[0].Increment { return b } for _, pk := range t.PrimaryKey { b.PrimaryKey(pk.Name) } return b } // cType returns the SQLite string type for the given column. func (*SQLite) cType(c *Column) (t string) { if c.SchemaType != nil && c.SchemaType[dialect.SQLite] != "" { return c.SchemaType[dialect.SQLite] } switch c.Type { case field.TypeBool: t = "bool" case field.TypeInt8, field.TypeUint8, field.TypeInt16, field.TypeUint16, field.TypeInt32, field.TypeUint32, field.TypeUint, field.TypeInt, field.TypeInt64, field.TypeUint64: t = "integer" case field.TypeBytes: t = "blob" case field.TypeString, field.TypeEnum: // SQLite does not impose any length restrictions on // the length of strings, BLOBs or numeric values.
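// The "varchar(N)" spelling below is therefore advisory only: SQLite assigns
// the column TEXT affinity and does not enforce the declared length.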
t = fmt.Sprintf("varchar(%d)", DefaultStringLen) case field.TypeFloat32, field.TypeFloat64: t = "real" case field.TypeTime: t = "datetime" case field.TypeJSON: t = "json" case field.TypeUUID: t = "uuid" case field.TypeOther: t = c.typ default: panic(fmt.Sprintf("unsupported type %q for column %q", c.Type, c.Name)) } return t } // addColumn returns the DSL query for adding the given column to a table. func (d *SQLite) addColumn(c *Column) *sql.ColumnBuilder { b := sql.Column(c.Name).Type(d.cType(c)).Attr(c.Attr) c.unique(b) if c.PrimaryKey() && c.Increment { b.Attr("PRIMARY KEY AUTOINCREMENT") } c.nullable(b) c.defaultValue(b) return b } // addIndex returns the query for adding an index to SQLite. func (d *SQLite) addIndex(i *Index, table string) *sql.IndexBuilder { return i.Builder(table).IfNotExists() } // dropIndex drops a SQLite index. func (d *SQLite) dropIndex(ctx context.Context, tx dialect.Tx, idx *Index, table string) error { query, args := idx.DropBuilder("").Query() return tx.Exec(ctx, query, args, nil) } // fkExist returns always true to disable foreign-keys creation after the table was created. func (d *SQLite) fkExist(context.Context, dialect.Tx, string) (bool, error) { return true, nil } // table returns always error to indicate that SQLite dialect doesn't support incremental migration. func (d *SQLite) table(ctx context.Context, tx dialect.Tx, name string) (*Table, error) { rows := &sql.Rows{} query, args := sql.Select("name", "type", "notnull", "dflt_value", "pk"). From(sql.Table(fmt.Sprintf("pragma_table_info('%s')", name)).Unquote()). OrderBy("pk"). Query() if err := tx.Query(ctx, query, args, rows); err != nil { return nil, fmt.Errorf("sqlite: reading table description %w", err) } // Call Close in cases of failures (Close is idempotent). defer rows.Close() t := NewTable(name) for rows.Next() { c := &Column{} if err := d.scanColumn(c, rows); err != nil { return nil, fmt.Errorf("sqlite: %w", err) } if c.PrimaryKey() { t.PrimaryKey = append(t.PrimaryKey, c) } t.AddColumn(c) } if err := rows.Err(); err != nil { return nil, err } if err := rows.Close(); err != nil { return nil, fmt.Errorf("sqlite: closing rows %w", err) } indexes, err := d.indexes(ctx, tx, name) if err != nil { return nil, err } // Add and link indexes to table columns. for _, idx := range indexes { switch { case idx.primary: case idx.Unique && len(idx.columns) == 1: name := idx.columns[0] c, ok := t.column(name) if !ok { return nil, fmt.Errorf("index %q column %q was not found in table %q", idx.Name, name, t.Name) } c.Key = UniqueKey c.Unique = true fallthrough default: t.addIndex(idx) } } return t, nil } // table loads the table indexes from the database. func (d *SQLite) indexes(ctx context.Context, tx dialect.Tx, name string) (Indexes, error) { rows := &sql.Rows{} query, args := sql.Select("name", "unique", "origin"). From(sql.Table(fmt.Sprintf("pragma_index_list('%s')", name)).Unquote()). 
Query() if err := tx.Query(ctx, query, args, rows); err != nil { return nil, fmt.Errorf("reading table indexes %w", err) } defer rows.Close() var idx Indexes for rows.Next() { i := &Index{} origin := sql.NullString{} if err := rows.Scan(&i.Name, &i.Unique, &origin); err != nil { return nil, fmt.Errorf("scanning index description %w", err) } i.primary = origin.String == "pk" idx = append(idx, i) } if err := rows.Err(); err != nil { return nil, err } if err := rows.Close(); err != nil { return nil, fmt.Errorf("closing rows %w", err) } for i := range idx { columns, err := d.indexColumns(ctx, tx, idx[i].Name) if err != nil { return nil, err } idx[i].columns = columns // Normalize implicit index names to ent naming convention. See: // https://github.com/sqlite/sqlite/blob/e937df8/src/build.c#L3583 if len(columns) == 1 && strings.HasPrefix(idx[i].Name, "sqlite_autoindex_"+name) { idx[i].Name = columns[0] } } return idx, nil }
// indexColumns loads index columns from index info. func (d *SQLite) indexColumns(ctx context.Context, tx dialect.Tx, name string) ([]string, error) { rows := &sql.Rows{} query, args := sql.Select("name"). From(sql.Table(fmt.Sprintf("pragma_index_info('%s')", name)).Unquote()). OrderBy("seqno"). Query() if err := tx.Query(ctx, query, args, rows); err != nil { return nil, fmt.Errorf("reading table indexes %w", err) } defer rows.Close() var names []string if err := sql.ScanSlice(rows, &names); err != nil { return nil, err } return names, nil }
// scanColumn scans the column information from SQLite column description. func (d *SQLite) scanColumn(c *Column, rows *sql.Rows) error { var ( pk sql.NullInt64 notnull sql.NullInt64 defaults sql.NullString ) if err := rows.Scan(&c.Name, &c.typ, &notnull, &defaults, &pk); err != nil { return fmt.Errorf("scanning column description: %w", err) } c.Nullable = notnull.Int64 == 0 if pk.Int64 > 0 { c.Key = PrimaryKey } if c.typ == "" { return fmt.Errorf("missing type information for column %q", c.Name) } parts, size, _, err := parseColumn(c.typ) if err != nil { return err } switch strings.ToLower(parts[0]) { case "bool", "boolean": c.Type = field.TypeBool case "blob": c.Type = field.TypeBytes case "integer": // All integer types have the same "type affinity". c.Type = field.TypeInt case "real", "float", "double": c.Type = field.TypeFloat64 case "datetime": c.Type = field.TypeTime case "json": c.Type = field.TypeJSON case "uuid": c.Type = field.TypeUUID case "varchar", "char", "text": c.Size = size c.Type = field.TypeString case "decimal", "numeric": c.Type = field.TypeOther } if defaults.Valid { return c.ScanDefault(defaults.String) } return nil }
// alterColumns returns the queries for applying the columns change-set. func (d *SQLite) alterColumns(table string, add, _, _ []*Column) sql.Queries { queries := make(sql.Queries, 0, len(add)) for i := range add { c := d.addColumn(add[i]) if fk := add[i].foreign; fk != nil { c.Constraint(fk.DSL()) } queries = append(queries, sql.Dialect(dialect.SQLite).AlterTable(table).AddColumn(c)) } // Modifying and dropping columns is not supported, and stays disabled until we // support https://www.sqlite.org/lang_altertable.html#otheralter return queries }
// tables returns the query for getting the tables in the schema. func (d *SQLite) tables() sql.Querier { return sql.Select("name"). From(sql.Table("sqlite_schema")). Where(sql.EQ("type", "table")) }
// needsConversion reports if column "old" needs to be converted // (by table altering) to column "new".
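// For example (a hypothetical change-set): a column scanned back as
// "varchar(255)" that is redeclared as field.TypeJSON renders as "json" and
// therefore needs conversion, while redeclaring it with the same string type
// renders the identical "varchar(255)" and requires no change.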
func (d *SQLite) needsConversion(old, new *Column) bool { c1, c2 := d.cType(old), d.cType(new) return c1 != c2 && old.typ != c2 }
// Atlas integration.
func (d *SQLite) atOpen(conn dialect.ExecQuerier) (migrate.Driver, error) { return sqlite.Open(&db{ExecQuerier: conn}) }
func (d *SQLite) atTable(t1 *Table, t2 *schema.Table) { if t1.Annotation != nil { setAtChecks(t1, t2) } }
func (d *SQLite) atTypeC(c1 *Column, c2 *schema.Column) error { if c1.SchemaType != nil && c1.SchemaType[dialect.SQLite] != "" { t, err := sqlite.ParseType(strings.ToLower(c1.SchemaType[dialect.SQLite])) if err != nil { return err } c2.Type.Type = t return nil } var t schema.Type switch c1.Type { case field.TypeBool: t = &schema.BoolType{T: "bool"} case field.TypeInt8, field.TypeUint8, field.TypeInt16, field.TypeUint16, field.TypeInt32, field.TypeUint32, field.TypeUint, field.TypeInt, field.TypeInt64, field.TypeUint64: t = &schema.IntegerType{T: sqlite.TypeInteger} case field.TypeBytes: t = &schema.BinaryType{T: sqlite.TypeBlob} case field.TypeString, field.TypeEnum: // SQLite does not impose any length restrictions on // the length of strings, BLOBs or numeric values. t = &schema.StringType{T: sqlite.TypeText} case field.TypeFloat32, field.TypeFloat64: t = &schema.FloatType{T: sqlite.TypeReal} case field.TypeTime: t = &schema.TimeType{T: "datetime"} case field.TypeJSON: t = &schema.JSONType{T: "json"} case field.TypeUUID: t = &sqlite.UUIDType{T: "uuid"} case field.TypeOther: t = &schema.UnsupportedType{T: c1.typ} default: // The inner t shadows the outer one. Assign and return here, so the // assignment after the switch does not overwrite c2.Type.Type with a nil type. t, err := sqlite.ParseType(strings.ToLower(c1.typ)) if err != nil { return err } c2.Type.Type = t return nil } c2.Type.Type = t return nil }
func (d *SQLite) atUniqueC(t1 *Table, c1 *Column, t2 *schema.Table, c2 *schema.Column) { // For UNIQUE columns, SQLite creates an implicit index named // "sqlite_autoindex_<table>_<i>". Ent uses the PostgreSQL approach // in its migration, and names these indexes as "<table>_<column>_key". for _, idx := range t1.Indexes { // The index is also defined explicitly, and will be added in atIndexes. if idx.Unique && d.atImplicitIndexName(idx, t1, c1) { return } } t2.AddIndexes(schema.NewUniqueIndex(fmt.Sprintf("%s_%s_key", t2.Name, c1.Name)).AddColumns(c2)) }
func (d *SQLite) atImplicitIndexName(idx *Index, t1 *Table, c1 *Column) bool { if idx.Name == c1.Name { return true } p := fmt.Sprintf("sqlite_autoindex_%s_", t1.Name) if !strings.HasPrefix(idx.Name, p) { return false } i, err := strconv.ParseInt(strings.TrimPrefix(idx.Name, p), 10, 64) return err == nil && i > 0 }
func (d *SQLite) atIncrementC(_ *schema.Table, c *schema.Column) { c.AddAttrs(&sqlite.AutoIncrement{}) }
func (d *SQLite) atIncrementT(t *schema.Table, v int64) { t.AddAttrs(&sqlite.AutoIncrement{Seq: v}) }
func (d *SQLite) atIndex(idx1 *Index, t2 *schema.Table, idx2 *schema.Index) error { for _, c1 := range idx1.Columns { c2, ok := t2.Column(c1.Name) if !ok { return fmt.Errorf("unexpected index %q column: %q", idx1.Name, c1.Name) } idx2.AddParts(&schema.IndexPart{C: c2}) } if idx1.Annotation != nil && idx1.Annotation.Where != "" { idx2.AddAttrs(&sqlite.IndexPredicate{P: idx1.Annotation.Where}) } return nil }
func (*SQLite) atTypeRangeSQL(ts ...string) string { for i := range ts { ts[i] = fmt.Sprintf("('%s')", ts[i]) } return fmt.Sprintf("INSERT INTO `%s` (`type`) VALUES %s", TypeTable, strings.Join(ts, ", ")) }
type tx struct { dialect.Tx }
func (tx *tx) QueryContext(ctx context.Context, query string, args ...any) (*stdsql.Rows, error) { rows := &sql.Rows{} if err := tx.Query(ctx, query, args, rows); err != nil { return nil, err } return rows.ColumnScanner.(*stdsql.Rows), nil }
func (tx *tx) ExecContext(ctx context.Context, query string, args ...any) (stdsql.Result, error) { var r stdsql.Result if err := tx.Exec(ctx, query, args, &r); err != nil { return nil, err } return r, nil }
ent-0.11.3/dialect/sql/schema/sqlite_test.go000066400000000000000000000424051431500740500207210ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree.
package schema import ( "context" "fmt" "math" "testing" "entgo.io/ent/dialect" "entgo.io/ent/dialect/sql" "entgo.io/ent/schema/field" "github.com/DATA-DOG/go-sqlmock" "github.com/stretchr/testify/require" )
func TestSQLite_Create(t *testing.T) { tests := []struct { name string tables []*Table options []MigrateOption before func(sqliteMock) wantErr bool }{ { name: "tx failed", before: func(mock sqliteMock) { mock.ExpectBegin().WillReturnError(sqlmock.ErrCancelled) }, wantErr: true, }, { name: "fk disabled", before: func(mock sqliteMock) { mock.ExpectBegin() mock.ExpectQuery("PRAGMA foreign_keys").
WillReturnRows(sqlmock.NewRows([]string{"foreign_keys"}).AddRow(0)) mock.ExpectRollback() }, wantErr: true, }, { name: "no tables", before: func(mock sqliteMock) { mock.start() mock.commit() }, }, { name: "create new table", tables: []*Table{ { Name: "users", PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Nullable: true}, {Name: "age", Type: field.TypeInt}, {Name: "doc", Type: field.TypeJSON, Nullable: true}, {Name: "uuid", Type: field.TypeUUID, Nullable: true}, {Name: "decimal", Type: field.TypeFloat32, SchemaType: map[string]string{dialect.SQLite: "decimal(6,2)"}}, }, }, }, before: func(mock sqliteMock) { mock.start() mock.tableExists("users", false) mock.ExpectExec(escape("CREATE TABLE `users`(`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL, `name` varchar(255) NULL, `age` integer NOT NULL, `doc` json NULL, `uuid` uuid NULL, `decimal` decimal(6,2) NOT NULL)")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.commit() }, }, { name: "create new table with foreign key", tables: func() []*Table { var ( c1 = []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Nullable: true}, {Name: "created_at", Type: field.TypeTime}, } c2 = []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString}, {Name: "owner_id", Type: field.TypeInt, Nullable: true}, } t1 = &Table{ Name: "users", Columns: c1, PrimaryKey: c1[0:1], } t2 = &Table{ Name: "pets", Columns: c2, PrimaryKey: c2[0:1], ForeignKeys: []*ForeignKey{ { Symbol: "pets_owner", Columns: c2[2:], RefTable: t1, RefColumns: c1[0:1], OnDelete: Cascade, }, }, } ) return []*Table{t1, t2} }(), before: func(mock sqliteMock) { mock.start() mock.tableExists("users", false) mock.ExpectExec(escape("CREATE TABLE `users`(`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL, `name` varchar(255) NULL, `created_at` datetime NOT NULL)")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.tableExists("pets", false) mock.ExpectExec(escape("CREATE TABLE `pets`(`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL, `name` varchar(255) NOT NULL, `owner_id` integer NULL, FOREIGN KEY(`owner_id`) REFERENCES `users`(`id`) ON DELETE CASCADE)")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.commit() }, }, { name: "create new table with foreign key disabled", options: []MigrateOption{ WithForeignKeys(false), }, tables: func() []*Table { var ( c1 = []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Nullable: true}, {Name: "created_at", Type: field.TypeTime}, } c2 = []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString}, {Name: "owner_id", Type: field.TypeInt, Nullable: true}, } t1 = &Table{ Name: "users", Columns: c1, PrimaryKey: c1[0:1], } t2 = &Table{ Name: "pets", Columns: c2, PrimaryKey: c2[0:1], ForeignKeys: []*ForeignKey{ { Symbol: "pets_owner", Columns: c2[2:], RefTable: t1, RefColumns: c1[0:1], OnDelete: Cascade, }, }, } ) return []*Table{t1, t2} }(), before: func(mock sqliteMock) { mock.start() mock.tableExists("users", false) mock.ExpectExec(escape("CREATE TABLE `users`(`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL, `name` varchar(255) NULL, `created_at` datetime NOT NULL)")). 
WillReturnResult(sqlmock.NewResult(0, 1)) mock.tableExists("pets", false) mock.ExpectExec(escape("CREATE TABLE `pets`(`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL, `name` varchar(255) NOT NULL, `owner_id` integer NULL)")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.commit() }, }, { name: "add column to table", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Nullable: true}, {Name: "text", Type: field.TypeString, Nullable: true, Size: math.MaxInt32}, {Name: "uuid", Type: field.TypeUUID, Nullable: true}, {Name: "age", Type: field.TypeInt, Default: 0}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock sqliteMock) { mock.start() mock.tableExists("users", true) mock.ExpectQuery(escape("SELECT `name`, `type`, `notnull`, `dflt_value`, `pk` FROM pragma_table_info('users') ORDER BY `pk`")). WithArgs(). WillReturnRows(sqlmock.NewRows([]string{"name", "type", "notnull", "dflt_value", "pk"}). AddRow("name", "varchar(255)", 0, nil, 0). AddRow("text", "text", 0, "NULL", 0). AddRow("uuid", "uuid", 0, "Null", 0). AddRow("id", "integer", 1, "NULL", 1)) mock.ExpectQuery(escape("SELECT `name`, `unique`, `origin` FROM pragma_index_list('users')")). WillReturnRows(sqlmock.NewRows([]string{"name", "unique", "origin"})) mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `age` integer NOT NULL DEFAULT 0")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.commit() }, }, { name: "datetime and timestamp", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "created_at", Type: field.TypeTime, Nullable: true}, {Name: "updated_at", Type: field.TypeTime, Nullable: true}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock sqliteMock) { mock.start() mock.tableExists("users", true) mock.ExpectQuery(escape("SELECT `name`, `type`, `notnull`, `dflt_value`, `pk` FROM pragma_table_info('users') ORDER BY `pk`")). WithArgs(). WillReturnRows(sqlmock.NewRows([]string{"name", "type", "notnull", "dflt_value", "pk"}). AddRow("created_at", "datetime", 0, nil, 0). AddRow("id", "integer", 1, "NULL", 1)) mock.ExpectQuery(escape("SELECT `name`, `unique`, `origin` FROM pragma_index_list('users')")). WillReturnRows(sqlmock.NewRows([]string{"name", "unique", "origin"})) mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `updated_at` datetime NULL")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.commit() }, }, { name: "add blob columns", tables: []*Table{ { Name: "blobs", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "old_tiny", Type: field.TypeBytes, Size: 100}, {Name: "old_blob", Type: field.TypeBytes, Size: 1e3}, {Name: "old_medium", Type: field.TypeBytes, Size: 1e5}, {Name: "old_long", Type: field.TypeBytes, Size: 1e8}, {Name: "new_tiny", Type: field.TypeBytes, Size: 100}, {Name: "new_blob", Type: field.TypeBytes, Size: 1e3}, {Name: "new_medium", Type: field.TypeBytes, Size: 1e5}, {Name: "new_long", Type: field.TypeBytes, Size: 1e8}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock sqliteMock) { mock.start() mock.tableExists("blobs", true) mock.ExpectQuery(escape("SELECT `name`, `type`, `notnull`, `dflt_value`, `pk` FROM pragma_table_info('blobs') ORDER BY `pk`")). WithArgs(). WillReturnRows(sqlmock.NewRows([]string{"name", "type", "notnull", "dflt_value", "pk"}). 
AddRow("old_tiny", "blob", 1, nil, 0). AddRow("old_blob", "blob", 1, nil, 0). AddRow("old_medium", "blob", 1, nil, 0). AddRow("old_long", "blob", 1, nil, 0). AddRow("id", "integer", 1, "NULL", 1)) mock.ExpectQuery(escape("SELECT `name`, `unique`, `origin` FROM pragma_index_list('blobs')")). WillReturnRows(sqlmock.NewRows([]string{"name", "unique", "unique"})) for _, c := range []string{"tiny", "blob", "medium", "long"} { mock.ExpectExec(escape(fmt.Sprintf("ALTER TABLE `blobs` ADD COLUMN `new_%s` blob NOT NULL", c))). WillReturnResult(sqlmock.NewResult(0, 1)) } mock.commit() }, }, { name: "add columns with default values", tables: []*Table{ { Name: "users", Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Default: "unknown"}, {Name: "active", Type: field.TypeBool, Default: false}, }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, }, }, }, before: func(mock sqliteMock) { mock.start() mock.tableExists("users", true) mock.ExpectQuery(escape("SELECT `name`, `type`, `notnull`, `dflt_value`, `pk` FROM pragma_table_info('users') ORDER BY `pk`")). WithArgs(). WillReturnRows(sqlmock.NewRows([]string{"name", "type", "notnull", "dflt_value", "pk"}). AddRow("id", "integer", 1, "NULL", 1)) mock.ExpectQuery(escape("SELECT `name`, `unique`, `origin` FROM pragma_index_list('users')")). WillReturnRows(sqlmock.NewRows([]string{"name", "unique", "origin"})) mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `name` varchar(255) NOT NULL DEFAULT 'unknown'")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `active` bool NOT NULL DEFAULT false")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.commit() }, }, { name: "add edge to table", tables: func() []*Table { var ( c1 = []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Nullable: true}, {Name: "spouse_id", Type: field.TypeInt, Nullable: true}, } t1 = &Table{ Name: "users", Columns: c1, PrimaryKey: c1[0:1], ForeignKeys: []*ForeignKey{ { Symbol: "user_spouse", Columns: c1[2:], RefColumns: c1[0:1], OnDelete: Cascade, }, }, } ) t1.ForeignKeys[0].RefTable = t1 return []*Table{t1} }(), before: func(mock sqliteMock) { mock.start() mock.tableExists("users", true) mock.ExpectQuery(escape("SELECT `name`, `type`, `notnull`, `dflt_value`, `pk` FROM pragma_table_info('users') ORDER BY `pk`")). WithArgs(). WillReturnRows(sqlmock.NewRows([]string{"name", "type", "notnull", "dflt_value", "pk"}). AddRow("name", "varchar(255)", 1, "NULL", 0). AddRow("id", "integer", 1, "NULL", 1)) mock.ExpectQuery(escape("SELECT `name`, `unique`, `origin` FROM pragma_index_list('users')")). WillReturnRows(sqlmock.NewRows([]string{"name", "unique", "origin"})) mock.ExpectExec(escape("ALTER TABLE `users` ADD COLUMN `spouse_id` integer NULL CONSTRAINT user_spouse REFERENCES `users`(`id`) ON DELETE CASCADE")). WillReturnResult(sqlmock.NewResult(0, 1)) mock.commit() }, }, { name: "universal id for all tables", tables: []*Table{ NewTable("users").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}), NewTable("groups").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}), }, options: []MigrateOption{WithGlobalUniqueID(true)}, before: func(mock sqliteMock) { mock.start() // creating ent_types table. mock.tableExists("ent_types", false) mock.ExpectExec(escape("CREATE TABLE `ent_types`(`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL, `type` varchar(255) UNIQUE NOT NULL)")). 
WillReturnResult(sqlmock.NewResult(0, 1)) mock.tableExists("users", false) mock.ExpectExec(escape("CREATE TABLE `users`(`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL)")). WillReturnResult(sqlmock.NewResult(0, 1)) // set users id range. mock.ExpectExec(escape("INSERT INTO `ent_types` (`type`) VALUES (?)")). WithArgs("users"). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectQuery(escape("SELECT COUNT(*) FROM `sqlite_sequence` WHERE `name` = ?")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0)) mock.ExpectExec(escape("INSERT INTO `sqlite_sequence` (`name`, `seq`) VALUES (?, ?)")). WithArgs("users", 0). WillReturnResult(sqlmock.NewResult(0, 1)) mock.tableExists("groups", false) mock.ExpectExec(escape("CREATE TABLE `groups`(`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL)")). WillReturnResult(sqlmock.NewResult(0, 1)) // set groups id range. mock.ExpectExec(escape("INSERT INTO `ent_types` (`type`) VALUES (?)")). WithArgs("groups"). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectQuery(escape("SELECT COUNT(*) FROM `sqlite_sequence` WHERE `name` = ?")). WithArgs("groups"). WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0)) mock.ExpectExec(escape("INSERT INTO `sqlite_sequence` (`name`, `seq`) VALUES (?, ?)")). WithArgs("groups", 1<<32). WillReturnResult(sqlmock.NewResult(0, 1)) mock.commit() }, }, { name: "universal id for restored tables", tables: []*Table{ NewTable("users").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}), NewTable("groups").AddPrimary(&Column{Name: "id", Type: field.TypeInt, Increment: true}), }, options: []MigrateOption{WithGlobalUniqueID(true)}, before: func(mock sqliteMock) { mock.start() // query ent_types table. mock.tableExists("ent_types", true) mock.ExpectQuery(escape("SELECT `type` FROM `ent_types` ORDER BY `id` ASC")). WillReturnRows(sqlmock.NewRows([]string{"type"}).AddRow("users")) mock.tableExists("users", false) mock.ExpectExec(escape("CREATE TABLE `users`(`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL)")). WillReturnResult(sqlmock.NewResult(0, 1)) // set users id range (without inserting to ent_types). mock.ExpectQuery(escape("SELECT COUNT(*) FROM `sqlite_sequence` WHERE `name` = ?")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1)) mock.ExpectExec(escape("UPDATE `sqlite_sequence` SET `seq` = ? WHERE `name` = ?")). WithArgs(0, "users"). WillReturnResult(sqlmock.NewResult(0, 1)) mock.tableExists("groups", false) mock.ExpectExec(escape("CREATE TABLE `groups`(`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL)")). WillReturnResult(sqlmock.NewResult(0, 1)) // set groups id range. mock.ExpectExec(escape("INSERT INTO `ent_types` (`type`) VALUES (?)")). WithArgs("groups"). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectQuery(escape("SELECT COUNT(*) FROM `sqlite_sequence` WHERE `name` = ?")). WithArgs("groups"). WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0)) mock.ExpectExec(escape("INSERT INTO `sqlite_sequence` (`name`, `seq`) VALUES (?, ?)")). WithArgs("groups", 1<<32). WillReturnResult(sqlmock.NewResult(0, 1)) mock.commit() }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { db, mock, err := sqlmock.New() require.NoError(t, err) tt.before(sqliteMock{mock}) migrate, err := NewMigrate(sql.OpenDB("sqlite3", db), append(tt.options, WithAtlas(false))...) require.NoError(t, err) err = migrate.Create(context.Background(), tt.tables...) 
require.Equal(t, tt.wantErr, err != nil, err) }) } }
type sqliteMock struct { sqlmock.Sqlmock }
func (m sqliteMock) start() { m.ExpectQuery("PRAGMA foreign_keys"). WillReturnRows(sqlmock.NewRows([]string{"foreign_keys"}).AddRow(1)) m.ExpectExec("PRAGMA foreign_keys = off"). WillReturnResult(sqlmock.NewResult(0, 1)) m.ExpectBegin() m.ExpectQuery("PRAGMA foreign_key_check"). WillReturnRows(sqlmock.NewRows([]string{})) // empty }
func (m sqliteMock) commit() { m.ExpectQuery("PRAGMA foreign_key_check"). WillReturnRows(sqlmock.NewRows([]string{})) // empty m.ExpectCommit() m.ExpectExec("PRAGMA foreign_keys = on"). WillReturnResult(sqlmock.NewResult(0, 1)) }
func (m sqliteMock) tableExists(table string, exists bool) { count := 0 if exists { count = 1 } m.ExpectQuery(escape("SELECT COUNT(*) FROM `sqlite_master` WHERE `type` = ? AND `name` = ?")). WithArgs("table", table). WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(count)) }
ent-0.11.3/dialect/sql/schema/writer.go000066400000000000000000000022701431500740500176710ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree.
package schema import ( "context" "io" "strings" "entgo.io/ent/dialect" )
// WriteDriver is a driver that writes all driver exec operations to its writer. type WriteDriver struct { dialect.Driver // underlying driver. io.Writer // target for exec statements. }
// Exec writes the statement to the writer; the underlying driver is not invoked. func (w *WriteDriver) Exec(_ context.Context, query string, _, _ any) error { if !strings.HasSuffix(query, ";") { query += ";" } _, err := io.WriteString(w, query+"\n") return err }
// Tx writes the transaction start. func (w *WriteDriver) Tx(context.Context) (dialect.Tx, error) { if _, err := io.WriteString(w, "BEGIN;\n"); err != nil { return nil, err } return w, nil }
// Commit writes the transaction commit. func (w *WriteDriver) Commit() error { _, err := io.WriteString(w, "COMMIT;\n") return err }
// Rollback writes the transaction rollback. func (w *WriteDriver) Rollback() error { _, err := io.WriteString(w, "ROLLBACK;\n") return err }
ent-0.11.3/dialect/sql/schema/writer_test.go000066400000000000000000000030211431500740500207230ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree.
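// The test below drives WriteDriver end-to-end. A minimal usage sketch
// (illustrative; nopDriver is the no-op driver defined at the bottom of
// this file):
//
//	b := &bytes.Buffer{}
//	w := WriteDriver{Driver: nopDriver{}, Writer: b}
//	tx, _ := w.Tx(context.Background())
//	_ = tx.Exec(context.Background(), "ALTER TABLE `users` ADD COLUMN `age` int", nil, nil)
//	_ = tx.Commit()
//	// b now holds: "BEGIN;\nALTER TABLE `users` ADD COLUMN `age` int;\nCOMMIT;\n"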
package schema import ( "bytes" "context" "strings" "testing" "entgo.io/ent/dialect" "github.com/stretchr/testify/require" )
func TestWriteDriver(t *testing.T) { b := &bytes.Buffer{} w := WriteDriver{Driver: nopDriver{}, Writer: b} ctx := context.Background() tx, err := w.Tx(ctx) require.NoError(t, err) err = tx.Query(ctx, "SELECT `name` FROM `users`", nil, nil) require.NoError(t, err) err = tx.Query(ctx, "SELECT `name` FROM `users`", nil, nil) require.NoError(t, err) err = tx.Exec(ctx, "ALTER TABLE `users` ADD COLUMN `age` int", nil, nil) require.NoError(t, err) err = tx.Exec(ctx, "ALTER TABLE `users` ADD COLUMN `NAME` varchar(100);", nil, nil) require.NoError(t, err) err = tx.Query(ctx, "SELECT `name` FROM `users`", nil, nil) require.NoError(t, err) require.NoError(t, tx.Commit()) lines := strings.Split(b.String(), "\n") require.Equal(t, "BEGIN;", lines[0]) require.Equal(t, "ALTER TABLE `users` ADD COLUMN `age` int;", lines[1]) require.Equal(t, "ALTER TABLE `users` ADD COLUMN `NAME` varchar(100);", lines[2]) require.Equal(t, "COMMIT;", lines[3]) require.Empty(t, lines[4], "file ends with blank line") }
type nopDriver struct { dialect.Driver }
func (nopDriver) Exec(context.Context, string, any, any) error { return nil }
func (nopDriver) Query(context.Context, string, any, any) error { return nil }
ent-0.11.3/dialect/sql/sqlgraph/000077500000000000000000000000001431500740500164065ustar00rootroot00000000000000ent-0.11.3/dialect/sql/sqlgraph/entql.go000066400000000000000000000212541431500740500200640ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree.
package sqlgraph import ( "fmt" "entgo.io/ent/dialect/sql" "entgo.io/ent/entql" )
type ( // A Schema holds a representation of ent/schema at runtime. Each Node // represents a single schema-type and its relations in the graph (storage). // // It is used for translating common graph traversal operations to the // underlying SQL storage. For example, an operation like `has_edge(E)` // will be translated to an SQL lookup based on the relation type and the // FK configuration. Schema struct { Nodes []*Node } // A Node in the graph holds the SQL information for an ent/schema. Node struct { NodeSpec // Type holds the node type (schema name). Type string // Fields maps from field names to their spec. Fields map[string]*FieldSpec // Edges maps from edge names to their spec. Edges map[string]struct { To *Node Spec *EdgeSpec } } )
// AddE adds an edge to the graph. It fails if one of the node // types is missing. // // g.AddE("pets", spec, "user", "pet") // g.AddE("friends", spec, "user", "user") // func (g *Schema) AddE(name string, spec *EdgeSpec, from, to string) error { var fromT, toT *Node for i := range g.Nodes { t := g.Nodes[i].Type if t == from { fromT = g.Nodes[i] } if t == to { toT = g.Nodes[i] } } if fromT == nil || toT == nil { return fmt.Errorf("from/to type was not found") } if fromT.Edges == nil { fromT.Edges = make(map[string]struct { To *Node Spec *EdgeSpec }) } fromT.Edges[name] = struct { To *Node Spec *EdgeSpec }{ To: toT, Spec: spec, } return nil }
// MustAddE is like AddE but panics if the edge cannot be added to the graph. func (g *Schema) MustAddE(name string, spec *EdgeSpec, from, to string) { if err := g.AddE(name, spec, from, to); err != nil { panic(err) } }
// EvalP evaluates the entql predicate on the given selector (query builder).
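// A minimal usage sketch, mirroring the tests below (assumes the schema
// graph g was populated with a "user" node type):
//
//	s := sql.Dialect(dialect.Postgres).Select().From(sql.Table("users"))
//	err := g.EvalP("user", entql.FieldHasPrefix("name", "a"), s)
//	// s now renders: SELECT * FROM "users" WHERE "users"."name" LIKE $1 (args: "a%")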
func (g *Schema) EvalP(nodeType string, p entql.P, selector *sql.Selector) error { var node *Node for i := range g.Nodes { if g.Nodes[i].Type == nodeType { node = g.Nodes[i] break } } if node == nil { return fmt.Errorf("node %s was not found in the graph schema", nodeType) } pr, err := evalExpr(node, selector, p) if err != nil { return err } selector.Where(pr) return nil } // FuncSelector represents a selector function to be used as an entql foreign-function. const FuncSelector entql.Func = "func_selector" // wrappedFunc wraps the selector-function to an ent-expression. type wrappedFunc struct { entql.Expr Func func(*sql.Selector) } // WrapFunc wraps a selector-func with an entql call expression. func WrapFunc(s func(*sql.Selector)) *entql.CallExpr { return &entql.CallExpr{ Func: FuncSelector, Args: []entql.Expr{wrappedFunc{Func: s}}, } } var ( binary = [...]sql.Op{ entql.OpEQ: sql.OpEQ, entql.OpNEQ: sql.OpNEQ, entql.OpGT: sql.OpGT, entql.OpGTE: sql.OpGTE, entql.OpLT: sql.OpLT, entql.OpLTE: sql.OpLTE, entql.OpIn: sql.OpIn, entql.OpNotIn: sql.OpNotIn, } nary = [...]func(...*sql.Predicate) *sql.Predicate{ entql.OpAnd: sql.And, entql.OpOr: sql.Or, } strFunc = map[entql.Func]func(string, string) *sql.Predicate{ entql.FuncContains: sql.Contains, entql.FuncContainsFold: sql.ContainsFold, entql.FuncEqualFold: sql.EqualFold, entql.FuncHasPrefix: sql.HasPrefix, entql.FuncHasSuffix: sql.HasSuffix, } nullFunc = [...]func(string) *sql.Predicate{ entql.OpEQ: sql.IsNull, entql.OpNEQ: sql.NotNull, } ) // state represents the state of a predicate evaluation. // Note that, the evaluation output is a predicate to be // applied on the database. type state struct { sql.Builder context *Node selector *sql.Selector } // evalExpr evaluates the entql expression and returns a new SQL predicate to be applied on the database. func evalExpr(context *Node, selector *sql.Selector, expr entql.Expr) (p *sql.Predicate, err error) { ex := &state{ context: context, selector: selector, } defer catch(&err) p = ex.evalExpr(expr) return } // evalExpr evaluates any expression. func (e *state) evalExpr(expr entql.Expr) *sql.Predicate { switch expr := expr.(type) { case *entql.BinaryExpr: return e.evalBinary(expr) case *entql.UnaryExpr: return sql.Not(e.evalExpr(expr.X)) case *entql.NaryExpr: ps := make([]*sql.Predicate, len(expr.Xs)) for i, x := range expr.Xs { ps[i] = e.evalExpr(x) } return nary[expr.Op](ps...) case *entql.CallExpr: switch expr.Func { case entql.FuncHasPrefix, entql.FuncHasSuffix, entql.FuncContains, entql.FuncEqualFold, entql.FuncContainsFold: expect(len(expr.Args) == 2, "invalid number of arguments for %s", expr.Func) f, ok := expr.Args[0].(*entql.Field) expect(ok, "*entql.Field, got %T", expr.Args[0]) v, ok := expr.Args[1].(*entql.Value) expect(ok, "*entql.Value, got %T", expr.Args[1]) s, ok := v.V.(string) expect(ok, "string value, got %T", v.V) return strFunc[expr.Func](e.field(f), s) case entql.FuncHasEdge: expect(len(expr.Args) > 0, "invalid number of arguments for %s", expr.Func) edge, ok := expr.Args[0].(*entql.Edge) expect(ok, "*entql.Edge, got %T", expr.Args[0]) return e.evalEdge(edge.Name, expr.Args[1:]...) } } panic("invalid") } // evalBinary evaluates binary expressions. 
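// For example, entql.FieldEQ("name", "a8m") hits the *entql.Value branch and
// renders as `"users"."name" = ?` with "a8m" as the bound argument, while
// entql.EQ(entql.F("name"), entql.F("last")) hits the *entql.Field branch and
// compares the two columns directly (see TestGraph_EvalP below).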
func (e *state) evalBinary(expr *entql.BinaryExpr) *sql.Predicate { switch expr.Op { case entql.OpOr: return sql.Or(e.evalExpr(expr.X), e.evalExpr(expr.Y)) case entql.OpAnd: return sql.And(e.evalExpr(expr.X), e.evalExpr(expr.Y)) case entql.OpEQ, entql.OpNEQ: if expr.Y == (*entql.Value)(nil) { f, ok := expr.X.(*entql.Field) expect(ok, "*entql.Field, got %T", expr.X) return nullFunc[expr.Op](e.field(f)) } fallthrough default: field, ok := expr.X.(*entql.Field) expect(ok, "expr.X to be *entql.Field (got %T)", expr.X) _, ok = expr.Y.(*entql.Field) if !ok { _, ok = expr.Y.(*entql.Value) } expect(ok, "expr.Y to be *entql.Field or *entql.Value (got %T)", expr.Y) switch x := expr.Y.(type) { case *entql.Field: return sql.ColumnsOp(e.field(field), e.field(x), binary[expr.Op]) case *entql.Value: c := e.field(field) return sql.P(func(b *sql.Builder) { b.Ident(c).WriteOp(binary[expr.Op]) args(b, x) }) default: panic("unreachable") } } }
// evalEdge evaluates has-edge and has-edge-with calls. func (e *state) evalEdge(name string, exprs ...entql.Expr) *sql.Predicate { edge, ok := e.context.Edges[name] expect(ok, "edge %q was not found for node %q", name, e.context.Type) var toC string switch { case edge.To.ID != nil: toC = edge.To.ID.Column // Edge-owner points to its edge schema. case edge.To.CompositeID != nil && !edge.Spec.Inverse: toC = edge.To.CompositeID[0].Column // Edge-backref points to its edge schema. case edge.To.CompositeID != nil && edge.Spec.Inverse: toC = edge.To.CompositeID[1].Column default: panic(evalError{fmt.Sprintf("expect id definition for edge %q", name)}) } step := NewStep( From(e.context.Table, e.context.ID.Column), To(edge.To.Table, toC), Edge(edge.Spec.Rel, edge.Spec.Inverse, edge.Spec.Table, edge.Spec.Columns...), ) selector := e.selector.Clone().SetP(nil) selector.SetTotal(e.Total()) if len(exprs) == 0 { HasNeighbors(selector, step) return selector.P() } HasNeighborsWith(selector, step, func(s *sql.Selector) { for _, expr := range exprs { if cx, ok := expr.(*entql.CallExpr); ok && cx.Func == FuncSelector { expect(len(cx.Args) == 1, "invalid number of arguments for %s", FuncSelector) wrapped, ok := cx.Args[0].(wrappedFunc) expect(ok, "invalid argument for %s: %T", FuncSelector, cx.Args[0]) wrapped.Func(s) } else { p, err := evalExpr(edge.To, s, expr) expect(err == nil, "edge evaluation failed for %s->%s: %s", e.context.Type, name, err) s.Where(p) } } }) return selector.P() }
func (e *state) field(f *entql.Field) string { _, ok := e.context.Fields[f.Name] expect(ok || e.context.ID.Column == f.Name, "field %q was not found for node %q", f.Name, e.context.Type) return e.selector.C(f.Name) }
func args(b *sql.Builder, v *entql.Value) { vs, ok := v.V.([]any) if !ok { b.Arg(v.V) return } b.WriteByte('(').Args(vs...).WriteByte(')') }
// expect panics if the condition is false. func expect(cond bool, msg string, args ...any) { if !cond { panic(evalError{fmt.Sprintf("expect "+msg, args...)}) } }
type evalError struct { msg string }
func (p evalError) Error() string { return fmt.Sprintf("sqlgraph: %s", p.msg) }
func catch(err *error) { if e := recover(); e != nil { xerr, ok := e.(evalError) if !ok { panic(e) } *err = xerr } }
ent-0.11.3/dialect/sql/sqlgraph/entql_test.go000066400000000000000000000151211431500740500211170ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree.
package sqlgraph import ( "strconv" "testing" "entgo.io/ent/dialect" "entgo.io/ent/dialect/sql" "entgo.io/ent/entql" "entgo.io/ent/schema/field" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestGraph_AddE(t *testing.T) { g := &Schema{ Nodes: []*Node{{Type: "user"}, {Type: "pet"}}, } err := g.AddE("pets", &EdgeSpec{Rel: O2M}, "user", "pet") assert.NoError(t, err) err = g.AddE("owner", &EdgeSpec{Rel: O2M}, "pet", "user") assert.NoError(t, err) err = g.AddE("groups", &EdgeSpec{Rel: M2M}, "pet", "groups") assert.Error(t, err) } func TestGraph_EvalP(t *testing.T) { g := &Schema{ Nodes: []*Node{ { Type: "user", NodeSpec: NodeSpec{ Table: "users", ID: &FieldSpec{Column: "uid"}, }, Fields: map[string]*FieldSpec{ "name": {Column: "name", Type: field.TypeString}, "last": {Column: "last", Type: field.TypeString}, }, }, { Type: "pet", NodeSpec: NodeSpec{ Table: "pets", ID: &FieldSpec{Column: "pid"}, }, Fields: map[string]*FieldSpec{ "name": {Column: "name", Type: field.TypeString}, }, }, { Type: "group", NodeSpec: NodeSpec{ Table: "groups", ID: &FieldSpec{Column: "gid"}, }, Fields: map[string]*FieldSpec{ "name": {Column: "name", Type: field.TypeString}, }, }, }, } err := g.AddE("pets", &EdgeSpec{Rel: O2M, Table: "pets", Columns: []string{"owner_id"}}, "user", "pet") require.NoError(t, err) err = g.AddE("owner", &EdgeSpec{Rel: M2O, Inverse: true, Table: "pets", Columns: []string{"owner_id"}}, "pet", "user") require.NoError(t, err) err = g.AddE("groups", &EdgeSpec{Rel: M2M, Table: "user_groups", Columns: []string{"user_id", "group_id"}}, "user", "group") require.NoError(t, err) err = g.AddE("users", &EdgeSpec{Rel: M2M, Inverse: true, Table: "user_groups", Columns: []string{"user_id", "group_id"}}, "group", "user") require.NoError(t, err) tests := []struct { s *sql.Selector p entql.P wantQuery string wantArgs []any wantErr bool }{ { s: sql.Dialect(dialect.Postgres).Select().From(sql.Table("users")), p: entql.FieldHasPrefix("name", "a"), wantQuery: `SELECT * FROM "users" WHERE "users"."name" LIKE $1`, wantArgs: []any{"a%"}, }, { s: sql.Dialect(dialect.Postgres).Select().From(sql.Table("users")). Where(sql.EQ("age", 1)), p: entql.FieldHasPrefix("name", "a"), wantQuery: `SELECT * FROM "users" WHERE "age" = $1 AND "users"."name" LIKE $2`, wantArgs: []any{1, "a%"}, }, { s: sql.Dialect(dialect.Postgres).Select().From(sql.Table("users")). Where(sql.EQ("age", 1)), p: entql.FieldHasPrefix("name", "a"), wantQuery: `SELECT * FROM "users" WHERE "age" = $1 AND "users"."name" LIKE $2`, wantArgs: []any{1, "a%"}, }, { s: sql.Dialect(dialect.Postgres).Select().From(sql.Table("users")), p: entql.EQ(entql.F("name"), entql.F("last")), wantQuery: `SELECT * FROM "users" WHERE "users"."name" = "users"."last"`, }, { s: sql.Dialect(dialect.Postgres).Select().From(sql.Table("users")), p: entql.EQ(entql.F("name"), entql.F("last")), wantQuery: `SELECT * FROM "users" WHERE "users"."name" = "users"."last"`, }, { s: sql.Dialect(dialect.Postgres).Select().From(sql.Table("users")), p: entql.And(entql.FieldNil("name"), entql.FieldNotNil("last")), wantQuery: `SELECT * FROM "users" WHERE "users"."name" IS NULL AND "users"."last" IS NOT NULL`, }, { s: sql.Dialect(dialect.Postgres).Select().From(sql.Table("users")). 
Where(sql.EQ("foo", "bar")), p: entql.Or(entql.FieldEQ("name", "foo"), entql.FieldEQ("name", "baz")), wantQuery: `SELECT * FROM "users" WHERE "foo" = $1 AND ("users"."name" = $2 OR "users"."name" = $3)`, wantArgs: []any{"bar", "foo", "baz"}, }, { s: sql.Dialect(dialect.Postgres).Select().From(sql.Table("users")), p: entql.HasEdge("pets"), wantQuery: `SELECT * FROM "users" WHERE "users"."uid" IN (SELECT "pets"."owner_id" FROM "pets" WHERE "pets"."owner_id" IS NOT NULL)`, }, { s: sql.Dialect(dialect.Postgres).Select().From(sql.Table("users")), p: entql.HasEdge("groups"), wantQuery: `SELECT * FROM "users" WHERE "users"."uid" IN (SELECT "user_groups"."user_id" FROM "user_groups")`, }, { s: sql.Dialect(dialect.Postgres).Select().From(sql.Table("users")), p: entql.HasEdgeWith("pets", entql.Or(entql.FieldEQ("name", "pedro"), entql.FieldEQ("name", "xabi"))), wantQuery: `SELECT * FROM "users" WHERE "users"."uid" IN (SELECT "pets"."owner_id" FROM "pets" WHERE "pets"."name" = $1 OR "pets"."name" = $2)`, wantArgs: []any{"pedro", "xabi"}, }, { s: sql.Dialect(dialect.Postgres).Select().From(sql.Table("users")).Where(sql.EQ("active", true)), p: entql.HasEdgeWith("groups", entql.Or(entql.FieldEQ("name", "GitHub"), entql.FieldEQ("name", "GitLab"))), wantQuery: `SELECT * FROM "users" WHERE "active" AND "users"."uid" IN (SELECT "user_groups"."user_id" FROM "user_groups" JOIN "groups" AS "t1" ON "user_groups"."group_id" = "t1"."gid" WHERE "t1"."name" = $1 OR "t1"."name" = $2)`, wantArgs: []any{"GitHub", "GitLab"}, }, { s: sql.Dialect(dialect.Postgres).Select().From(sql.Table("users")).Where(sql.EQ("active", true)), p: entql.And(entql.HasEdge("pets"), entql.HasEdge("groups"), entql.EQ(entql.F("name"), entql.F("uid"))), wantQuery: `SELECT * FROM "users" WHERE "active" AND ("users"."uid" IN (SELECT "pets"."owner_id" FROM "pets" WHERE "pets"."owner_id" IS NOT NULL) AND "users"."uid" IN (SELECT "user_groups"."user_id" FROM "user_groups") AND "users"."name" = "users"."uid")`, }, { s: sql.Dialect(dialect.Postgres).Select().From(sql.Table("users")).Where(sql.EQ("active", true)), p: entql.HasEdgeWith("pets", entql.FieldEQ("name", "pedro"), WrapFunc(func(s *sql.Selector) { s.Where(sql.EQ("owner_id", 10)) })), wantQuery: `SELECT * FROM "users" WHERE "active" AND "users"."uid" IN (SELECT "pets"."owner_id" FROM "pets" WHERE "pets"."name" = $1 AND "owner_id" = $2)`, wantArgs: []any{"pedro", 10}, }, } for i, tt := range tests { t.Run(strconv.Itoa(i), func(t *testing.T) { err = g.EvalP("user", tt.p, tt.s) require.Equal(t, tt.wantErr, err != nil, err) query, args := tt.s.Query() require.Equal(t, tt.wantQuery, query) require.Equal(t, tt.wantArgs, args) }) } } ent-0.11.3/dialect/sql/sqlgraph/errors.go000066400000000000000000000027401431500740500202540ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package sqlgraph import ( "errors" "strings" ) // IsConstraintError returns true if the error resulted from a database constraint violation. func IsConstraintError(err error) bool { var e *ConstraintError return errors.As(err, &e) || IsUniqueConstraintError(err) || IsForeignKeyConstraintError(err) } // IsUniqueConstraintError reports if the error resulted from a DB uniqueness constraint violation. // e.g. duplicate value in unique index. 
func IsUniqueConstraintError(err error) bool { for _, s := range []string{ "Error 1062", // MySQL "violates unique constraint", // Postgres "UNIQUE constraint failed", // SQLite } { if strings.Contains(err.Error(), s) { return true } } return false } // IsForeignKeyConstraintError reports if the error resulted from a database foreign-key constraint violation. // e.g. parent row does not exist. func IsForeignKeyConstraintError(err error) bool { for _, s := range []string{ "Error 1451", // MySQL (Cannot delete or update a parent row). "Error 1452", // MySQL (Cannot add or update a child row). "violates foreign key constraint", // Postgres "FOREIGN KEY constraint failed", // SQLite } { if strings.Contains(err.Error(), s) { return true } } return false } ent-0.11.3/dialect/sql/sqlgraph/graph.go000066400000000000000000001307261431500740500200470ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. // Package sqlgraph provides graph abstraction capabilities on top // of sql-based databases for ent codegen. package sqlgraph import ( "context" "database/sql/driver" "encoding/json" "fmt" "math" "sort" "entgo.io/ent/dialect" "entgo.io/ent/dialect/sql" "entgo.io/ent/schema/field" ) // Rel is an edge relation type. type Rel int // Relation types. const ( _ Rel = iota // Unknown. O2O // One to one / has one. O2M // One to many / has many. M2O // Many to one (inverse perspective for O2M). M2M // Many to many. ) // String returns the relation name. func (r Rel) String() (s string) { switch r { case O2O: s = "O2O" case O2M: s = "O2M" case M2O: s = "M2O" case M2M: s = "M2M" default: s = "Unknown" } return s } // A ConstraintError represents an error from mutation that violates a specific constraint. type ConstraintError struct { msg string } func (e ConstraintError) Error() string { return e.msg } // A Step provides a path-step information to the traversal functions. type Step struct { // From is the source of the step. From struct { // V can be either one vertex or set of vertices. // It can be a pre-processed step (sql.Query) or a simple Go type (integer or string). V any // Table holds the table name of V (from). Table string // Column to join with. Usually the "id" column. Column string } // Edge holds the edge information for getting the neighbors. Edge struct { // Rel of the edge. Rel Rel // Schema is an optional name of the database // where the table is defined. Schema string // Table name of where this edge columns reside. Table string // Columns of the edge. // In O2O and M2O, it holds the foreign-key column. Hence, len == 1. // In M2M, it holds the primary-key columns of the join table. Hence, len == 2. Columns []string // Inverse indicates if the edge is an inverse edge. Inverse bool } // To is the dest of the path (the neighbors). To struct { // Table holds the table name of the neighbors (to). Table string // Schema is an optional name of the database // where the table is defined. Schema string // Column to join with. Usually the "id" column. Column string } } // StepOption allows configuring Steps using functional options. type StepOption func(*Step) // From sets the source of the step. func From(table, column string, v ...any) StepOption { return func(s *Step) { s.From.Table = table s.From.Column = column if len(v) > 0 { s.From.V = v[0] } } } // To sets the destination of the step. 
func To(table, column string) StepOption { return func(s *Step) { s.To.Table = table s.To.Column = column } } // Edge sets the edge info for getting the neighbors. func Edge(rel Rel, inverse bool, table string, columns ...string) StepOption { return func(s *Step) { s.Edge.Rel = rel s.Edge.Table = table s.Edge.Columns = columns s.Edge.Inverse = inverse } } // NewStep gets list of options and returns a configured step. // // NewStep( // From("table", "pk", V), // To("table", "pk"), // Edge("name", O2M, "fk"), // ) func NewStep(opts ...StepOption) *Step { s := &Step{} for _, opt := range opts { opt(s) } return s } // Neighbors returns a Selector for evaluating the path-step // and getting the neighbors of one vertex. func Neighbors(dialect string, s *Step) (q *sql.Selector) { builder := sql.Dialect(dialect) switch r := s.Edge.Rel; { case r == M2M: pk1, pk2 := s.Edge.Columns[1], s.Edge.Columns[0] if s.Edge.Inverse { pk1, pk2 = pk2, pk1 } to := builder.Table(s.To.Table).Schema(s.To.Schema) join := builder.Table(s.Edge.Table).Schema(s.Edge.Schema) match := builder.Select(join.C(pk1)). From(join). Where(sql.EQ(join.C(pk2), s.From.V)) q = builder.Select(). From(to). Join(match). On(to.C(s.To.Column), match.C(pk1)) case r == M2O || (r == O2O && s.Edge.Inverse): t1 := builder.Table(s.To.Table).Schema(s.To.Schema) t2 := builder.Select(s.Edge.Columns[0]). From(builder.Table(s.Edge.Table).Schema(s.Edge.Schema)). Where(sql.EQ(s.From.Column, s.From.V)) q = builder.Select(). From(t1). Join(t2). On(t1.C(s.To.Column), t2.C(s.Edge.Columns[0])) case r == O2M || (r == O2O && !s.Edge.Inverse): q = builder.Select(). From(builder.Table(s.To.Table).Schema(s.To.Schema)). Where(sql.EQ(s.Edge.Columns[0], s.From.V)) } return q } // SetNeighbors returns a Selector for evaluating the path-step // and getting the neighbors of set of vertices. func SetNeighbors(dialect string, s *Step) (q *sql.Selector) { set := s.From.V.(*sql.Selector) builder := sql.Dialect(dialect) switch r := s.Edge.Rel; { case r == M2M: pk1, pk2 := s.Edge.Columns[1], s.Edge.Columns[0] if s.Edge.Inverse { pk1, pk2 = pk2, pk1 } to := builder.Table(s.To.Table).Schema(s.To.Schema) set.Select(set.C(s.From.Column)) join := builder.Table(s.Edge.Table).Schema(s.Edge.Schema) match := builder.Select(join.C(pk1)). From(join). Join(set). On(join.C(pk2), set.C(s.From.Column)) q = builder.Select(). From(to). Join(match). On(to.C(s.To.Column), match.C(pk1)) case r == M2O || (r == O2O && s.Edge.Inverse): t1 := builder.Table(s.To.Table).Schema(s.To.Schema) set.Select(set.C(s.Edge.Columns[0])) q = builder.Select(). From(t1). Join(set). On(t1.C(s.To.Column), set.C(s.Edge.Columns[0])) case r == O2M || (r == O2O && !s.Edge.Inverse): t1 := builder.Table(s.To.Table).Schema(s.To.Schema) set.Select(set.C(s.From.Column)) q = builder.Select(). From(t1). Join(set). On(t1.C(s.Edge.Columns[0]), set.C(s.From.Column)) } return q } // HasNeighbors applies on the given Selector a neighbors check. 
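// For an O2M edge such as user->pets, for instance, the check is rendered as
// (see TestGraph_EvalP in entql_test.go):
//
//	"users"."uid" IN (SELECT "pets"."owner_id" FROM "pets" WHERE "pets"."owner_id" IS NOT NULL)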
func HasNeighbors(q *sql.Selector, s *Step) { builder := sql.Dialect(q.Dialect()) switch r := s.Edge.Rel; { case r == M2M: pk1 := s.Edge.Columns[0] if s.Edge.Inverse { pk1 = s.Edge.Columns[1] } join := builder.Table(s.Edge.Table).Schema(s.Edge.Schema) q.Where( sql.In( q.C(s.From.Column), builder.Select(join.C(pk1)).From(join), ), ) case r == M2O || (r == O2O && s.Edge.Inverse): q.Where(sql.NotNull(q.C(s.Edge.Columns[0]))) case r == O2M || (r == O2O && !s.Edge.Inverse): to := builder.Table(s.Edge.Table).Schema(s.Edge.Schema) q.Where( sql.In( q.C(s.From.Column), builder.Select(to.C(s.Edge.Columns[0])). From(to). Where(sql.NotNull(to.C(s.Edge.Columns[0]))), ), ) } }
// HasNeighborsWith applies on the given Selector a neighbors check. // The given predicate applies its filtering on the selector. func HasNeighborsWith(q *sql.Selector, s *Step, pred func(*sql.Selector)) { builder := sql.Dialect(q.Dialect()) switch r := s.Edge.Rel; { case r == M2M: pk1, pk2 := s.Edge.Columns[1], s.Edge.Columns[0] if s.Edge.Inverse { pk1, pk2 = pk2, pk1 } to := builder.Table(s.To.Table).Schema(s.To.Schema) edge := builder.Table(s.Edge.Table).Schema(s.Edge.Schema) join := builder.Select(edge.C(pk2)). From(edge). Join(to). On(edge.C(pk1), to.C(s.To.Column)) matches := builder.Select().From(to) matches.WithContext(q.Context()) pred(matches) join.FromSelect(matches) q.Where(sql.In(q.C(s.From.Column), join)) case r == M2O || (r == O2O && s.Edge.Inverse): to := builder.Table(s.To.Table).Schema(s.To.Schema) matches := builder.Select(to.C(s.To.Column)). From(to) matches.WithContext(q.Context()) pred(matches) q.Where(sql.In(q.C(s.Edge.Columns[0]), matches)) case r == O2M || (r == O2O && !s.Edge.Inverse): to := builder.Table(s.Edge.Table).Schema(s.Edge.Schema) matches := builder.Select(to.C(s.Edge.Columns[0])). From(to) matches.WithContext(q.Context()) pred(matches) q.Where(sql.In(q.C(s.From.Column), matches)) } }
type ( // FieldSpec holds the information for updating a field // column in the database. FieldSpec struct { Column string Type field.Type Value driver.Value // value to be stored. } // EdgeTarget holds the information for the target nodes // of an edge. EdgeTarget struct { Nodes []driver.Value IDSpec *FieldSpec // Additional fields can be set on the // edge join table. Valid for M2M edges. Fields []*FieldSpec } // EdgeSpec holds the information for mutating an edge // in the graph. EdgeSpec struct { Rel Rel Inverse bool Table string Schema string Columns []string Bidi bool // bidirectional edge. Target *EdgeTarget // target nodes. } // EdgeSpecs is used to perform common operations on a list of edges. EdgeSpecs []*EdgeSpec // NodeSpec defines the information for querying and // decoding nodes in the graph. NodeSpec struct { Table string Schema string Columns []string ID *FieldSpec // primary key. CompositeID []*FieldSpec // composite id (edge schema). } )
type ( // CreateSpec holds the information for creating // a node in the graph. CreateSpec struct { Table string Schema string ID *FieldSpec Fields []*FieldSpec Edges []*EdgeSpec // The OnConflict option allows providing on-conflict // options to the INSERT statement. // // sqlgraph.CreateSpec{ // OnConflict: []sql.ConflictOption{ // sql.ResolveWithNewValues(), // }, // } // OnConflict []sql.ConflictOption } // BatchCreateSpec holds the information for creating // multiple nodes in the graph. BatchCreateSpec struct { Nodes []*CreateSpec // The OnConflict option allows providing on-conflict // options to the INSERT statement.
// // sqlgraph.BatchCreateSpec{ // OnConflict: []sql.ConflictOption{ // sql.ResolveWithNewValues(), // }, // } // OnConflict []sql.ConflictOption } )
// CreateNode applies the CreateSpec on the graph. The operation creates a new // record in the database, and connects it to other nodes specified in spec.Edges. func CreateNode(ctx context.Context, drv dialect.Driver, spec *CreateSpec) error { gr := graph{tx: drv, builder: sql.Dialect(drv.Dialect())} cr := &creator{CreateSpec: spec, graph: gr} return cr.node(ctx, drv) }
// BatchCreate applies the BatchCreateSpec on the graph. func BatchCreate(ctx context.Context, drv dialect.Driver, spec *BatchCreateSpec) error { gr := graph{tx: drv, builder: sql.Dialect(drv.Dialect())} cr := &batchCreator{BatchCreateSpec: spec, graph: gr} return cr.nodes(ctx, drv) }
type ( // EdgeMut defines edge mutations. EdgeMut struct { Add []*EdgeSpec Clear []*EdgeSpec } // FieldMut defines field mutations. FieldMut struct { Set []*FieldSpec // field = ? Add []*FieldSpec // field = field + ? Clear []*FieldSpec // field = NULL } // UpdateSpec holds the information for updating one // or more nodes in the graph. UpdateSpec struct { Node *NodeSpec Edges EdgeMut Fields FieldMut Predicate func(*sql.Selector) Modifiers []func(*sql.UpdateBuilder) ScanValues func(columns []string) ([]any, error) Assign func(columns []string, values []any) error } )
// UpdateNode applies the UpdateSpec on one node in the graph. func UpdateNode(ctx context.Context, drv dialect.Driver, spec *UpdateSpec) error { tx, err := drv.Tx(ctx) if err != nil { return err } gr := graph{tx: tx, builder: sql.Dialect(drv.Dialect())} cr := &updater{UpdateSpec: spec, graph: gr} if err := cr.node(ctx, tx); err != nil { return rollback(tx, err) } return tx.Commit() }
// UpdateNodes applies the UpdateSpec on a set of nodes in the graph. func UpdateNodes(ctx context.Context, drv dialect.Driver, spec *UpdateSpec) (int, error) { gr := graph{tx: drv, builder: sql.Dialect(drv.Dialect())} cr := &updater{UpdateSpec: spec, graph: gr} return cr.nodes(ctx, drv) }
// NotFoundError is returned when trying to update an // entity that was not found in the database. type NotFoundError struct { table string id driver.Value }
func (e *NotFoundError) Error() string { return fmt.Sprintf("record with id %v not found in table %s", e.id, e.table) }
// DeleteSpec holds the information for deleting one // or more nodes in the graph. type DeleteSpec struct { Node *NodeSpec Predicate func(*sql.Selector) }
// DeleteNodes applies the DeleteSpec on the graph. func DeleteNodes(ctx context.Context, drv dialect.Driver, spec *DeleteSpec) (int, error) { var ( res sql.Result builder = sql.Dialect(drv.Dialect()) ) selector := builder.Select(). From(builder.Table(spec.Node.Table).Schema(spec.Node.Schema)). WithContext(ctx) if pred := spec.Predicate; pred != nil { pred(selector) } query, args := builder.Delete(spec.Node.Table).Schema(spec.Node.Schema).FromSelect(selector).Query() if err := drv.Exec(ctx, query, args, &res); err != nil { return 0, err } affected, err := res.RowsAffected() if err != nil { return 0, err } return int(affected), nil }
// QuerySpec holds the information for querying // nodes in the graph. type QuerySpec struct { Node *NodeSpec // Nodes info. From *sql.Selector // Optional query source (from path).
Limit int Offset int Unique bool Order func(*sql.Selector) Predicate func(*sql.Selector) Modifiers []func(*sql.Selector) ScanValues func(columns []string) ([]any, error) Assign func(columns []string, values []any) error } // QueryNodes queries the nodes in the graph query and scans them to the given values. func QueryNodes(ctx context.Context, drv dialect.Driver, spec *QuerySpec) error { builder := sql.Dialect(drv.Dialect()) qr := &query{graph: graph{builder: builder}, QuerySpec: spec} return qr.nodes(ctx, drv) } // CountNodes counts the nodes in the given graph query. func CountNodes(ctx context.Context, drv dialect.Driver, spec *QuerySpec) (int, error) { builder := sql.Dialect(drv.Dialect()) qr := &query{graph: graph{builder: builder}, QuerySpec: spec} return qr.count(ctx, drv) } // EdgeQuerySpec holds the information for querying // edges in the graph. type EdgeQuerySpec struct { Edge *EdgeSpec Predicate func(*sql.Selector) ScanValues func() [2]any Assign func(out, in any) error } // QueryEdges queries the edges in the graph and scans the result with the given dest function. func QueryEdges(ctx context.Context, drv dialect.Driver, spec *EdgeQuerySpec) error { if len(spec.Edge.Columns) != 2 { return fmt.Errorf("sqlgraph: edge query requires 2 columns (out, in)") } out, in := spec.Edge.Columns[0], spec.Edge.Columns[1] if spec.Edge.Inverse { out, in = in, out } selector := sql.Dialect(drv.Dialect()). Select(out, in). From(sql.Table(spec.Edge.Table).Schema(spec.Edge.Schema)) if p := spec.Predicate; p != nil { p(selector) } rows := &sql.Rows{} query, args := selector.Query() if err := drv.Query(ctx, query, args, rows); err != nil { return err } defer rows.Close() for rows.Next() { values := spec.ScanValues() if err := rows.Scan(values[0], values[1]); err != nil { return err } if err := spec.Assign(values[0], values[1]); err != nil { return err } } return rows.Err() } type query struct { graph *QuerySpec } func (q *query) nodes(ctx context.Context, drv dialect.Driver) error { rows := &sql.Rows{} selector, err := q.selector(ctx) if err != nil { return err } query, args := selector.Query() if err := drv.Query(ctx, query, args, rows); err != nil { return err } defer rows.Close() columns, err := rows.Columns() if err != nil { return err } for rows.Next() { values, err := q.ScanValues(columns) if err != nil { return err } if err := rows.Scan(values...); err != nil { return err } if err := q.Assign(columns, values); err != nil { return err } } return rows.Err() } func (q *query) count(ctx context.Context, drv dialect.Driver) (int, error) { rows := &sql.Rows{} selector, err := q.selector(ctx) if err != nil { return 0, err } // If no columns were selected in count, // the default selection is by node ids. columns := q.Node.Columns if len(columns) == 0 && q.Node.ID != nil { columns = append(columns, q.Node.ID.Column) } for i, c := range columns { columns[i] = selector.C(c) } if q.Unique { selector.SetDistinct(false) selector.Count(sql.Distinct(columns...)) } else { selector.Count(columns...) } query, args := selector.Query() if err := drv.Query(ctx, query, args, rows); err != nil { return 0, err } defer rows.Close() return sql.ScanInt(rows) } func (q *query) selector(ctx context.Context) (*sql.Selector, error) { selector := q.builder. Select(). From(q.builder.Table(q.Node.Table).Schema(q.Node.Schema)). WithContext(ctx) if q.From != nil { selector = q.From } selector.Select(selector.Columns(q.Node.Columns...)...) 
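	// Note: the clauses below are applied in a fixed order (predicate,
	// ordering, paging, DISTINCT, and finally the user-defined modifiers),
	// so a modifier observes, and may override, everything set before it.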
	if pred := q.Predicate; pred != nil {
		pred(selector)
	}
	if order := q.Order; order != nil {
		order(selector)
	}
	if q.Offset != 0 {
		// Limit is mandatory for the offset clause. We start
		// with a default value, and override it below if needed.
		selector.Offset(q.Offset).Limit(math.MaxInt32)
	}
	if q.Limit != 0 {
		selector.Limit(q.Limit)
	}
	if q.Unique {
		selector.Distinct()
	}
	for _, m := range q.Modifiers {
		m(selector)
	}
	if err := selector.Err(); err != nil {
		return nil, err
	}
	return selector, nil
}

type updater struct {
	graph
	*UpdateSpec
}

func (u *updater) node(ctx context.Context, tx dialect.ExecQuerier) error {
	var (
		id         driver.Value
		idp        *sql.Predicate
		addEdges   = EdgeSpecs(u.Edges.Add).GroupRel()
		clearEdges = EdgeSpecs(u.Edges.Clear).GroupRel()
	)
	switch {
	// In case it is not an edge schema, the id holds the PK
	// of the node used for linking it with the other nodes.
	case u.Node.ID != nil:
		id = u.Node.ID.Value
		idp = sql.EQ(u.Node.ID.Column, id)
	case len(u.Node.CompositeID) == 2:
		idp = sql.And(
			sql.EQ(u.Node.CompositeID[0].Column, u.Node.CompositeID[0].Value),
			sql.EQ(u.Node.CompositeID[1].Column, u.Node.CompositeID[1].Value),
		)
	case len(u.Node.CompositeID) != 2:
		return fmt.Errorf("sql/sqlgraph: invalid composite id for update table %q", u.Node.Table)
	default:
		return fmt.Errorf("sql/sqlgraph: missing node id for update table %q", u.Node.Table)
	}
	update := u.builder.Update(u.Node.Table).Schema(u.Node.Schema).Where(idp)
	if pred := u.Predicate; pred != nil {
		selector := u.builder.Select().From(u.builder.Table(u.Node.Table).Schema(u.Node.Schema))
		pred(selector)
		update.FromSelect(selector)
	}
	if err := u.setTableColumns(update, addEdges, clearEdges); err != nil {
		return err
	}
	for _, m := range u.Modifiers {
		m(update)
	}
	if err := update.Err(); err != nil {
		return err
	}
	if !update.Empty() {
		var res sql.Result
		query, args := update.Query()
		if err := tx.Exec(ctx, query, args, &res); err != nil {
			return err
		}
		affected, err := res.RowsAffected()
		if err != nil {
			return err
		}
		// In case there are zero affected rows by this statement, we need to distinguish
		// between the case of "record was not found" and "record was not changed".
		if affected == 0 && u.Predicate != nil {
			if err := u.ensureExists(ctx); err != nil {
				return err
			}
		}
	}
	if id != nil { // Not an edge schema.
		if err := u.setExternalEdges(ctx, []driver.Value{id}, addEdges, clearEdges); err != nil {
			return err
		}
	}
	// Ignore querying the database when there's nothing
	// to scan into it.
	if u.ScanValues == nil {
		return nil
	}
	selector := u.builder.Select(u.Node.Columns...).
		From(u.builder.Table(u.Node.Table).Schema(u.Node.Schema)).
		// Skip adding the custom predicates that were attached
		// to the updater as they may point to columns that were
		// changed by the UPDATE statement.
		Where(idp)
	rows := &sql.Rows{}
	query, args := selector.Query()
	if err := tx.Query(ctx, query, args, rows); err != nil {
		return err
	}
	return u.scan(rows)
}

func (u *updater) nodes(ctx context.Context, drv dialect.Driver) (int, error) {
	var (
		addEdges   = EdgeSpecs(u.Edges.Add).GroupRel()
		clearEdges = EdgeSpecs(u.Edges.Clear).GroupRel()
		multiple   = hasExternalEdges(addEdges, clearEdges)
		update     = u.builder.Update(u.Node.Table).Schema(u.Node.Schema)
		selector   = u.builder.Select().
				From(u.builder.Table(u.Node.Table).Schema(u.Node.Schema)).
				WithContext(ctx)
	)
	switch {
	// In case it is not an edge schema, the ids of the
	// returned nodes are used for updating the external tables.
	case u.Node.ID != nil:
		selector.Select(u.Node.ID.Column)
	case len(u.Node.CompositeID) == 2:
		// Other edge-schemas (M2M tables) cannot be updated by this operation.
		// Also, in case there is a need to update an external foreign-key, it must
		// be a single value and the user should use the "update by id" API instead.
		if multiple {
			return 0, fmt.Errorf("sql/sqlgraph: update edge schema table %q cannot update external tables", u.Node.Table)
		}
	case len(u.Node.CompositeID) != 2:
		return 0, fmt.Errorf("sql/sqlgraph: invalid composite id for update table %q", u.Node.Table)
	default:
		return 0, fmt.Errorf("sql/sqlgraph: missing node id for update table %q", u.Node.Table)
	}
	if err := u.setTableColumns(update, addEdges, clearEdges); err != nil {
		return 0, err
	}
	if pred := u.Predicate; pred != nil {
		pred(selector)
	}
	// In case of a single statement update, avoid opening a transaction manually.
	if !multiple {
		update.FromSelect(selector)
		return u.updateTable(ctx, update)
	}
	tx, err := drv.Tx(ctx)
	if err != nil {
		return 0, err
	}
	u.tx = tx
	affected, err := func() (int, error) {
		var (
			ids         []driver.Value
			rows        = &sql.Rows{}
			query, args = selector.Query()
		)
		if err := u.tx.Query(ctx, query, args, rows); err != nil {
			return 0, fmt.Errorf("querying table %s: %w", u.Node.Table, err)
		}
		defer rows.Close()
		if err := sql.ScanSlice(rows, &ids); err != nil {
			return 0, fmt.Errorf("scan node ids: %w", err)
		}
		if err := rows.Close(); err != nil {
			return 0, err
		}
		if len(ids) == 0 {
			return 0, nil
		}
		update.Where(matchID(u.Node.ID.Column, ids))
		// In case of a multi statement update, the change can
		// affect more than 1 table, and therefore, we return
		// the list of ids as the number of affected records.
		if _, err := u.updateTable(ctx, update); err != nil {
			return 0, err
		}
		if err := u.setExternalEdges(ctx, ids, addEdges, clearEdges); err != nil {
			return 0, err
		}
		return len(ids), nil
	}()
	if err != nil {
		return 0, rollback(tx, err)
	}
	return affected, tx.Commit()
}

func (u *updater) updateTable(ctx context.Context, stmt *sql.UpdateBuilder) (int, error) {
	for _, m := range u.Modifiers {
		m(stmt)
	}
	if err := stmt.Err(); err != nil {
		return 0, err
	}
	if stmt.Empty() {
		return 0, nil
	}
	var (
		res         sql.Result
		query, args = stmt.Query()
	)
	if err := u.tx.Exec(ctx, query, args, &res); err != nil {
		return 0, err
	}
	affected, err := res.RowsAffected()
	if err != nil {
		return 0, err
	}
	return int(affected), nil
}

func (u *updater) setExternalEdges(ctx context.Context, ids []driver.Value, addEdges, clearEdges map[Rel][]*EdgeSpec) error {
	if err := u.graph.clearM2MEdges(ctx, ids, clearEdges[M2M]); err != nil {
		return err
	}
	if err := u.graph.addM2MEdges(ctx, ids, addEdges[M2M]); err != nil {
		return err
	}
	if err := u.graph.clearFKEdges(ctx, ids, append(clearEdges[O2M], clearEdges[O2O]...)); err != nil {
		return err
	}
	if err := u.graph.addFKEdges(ctx, ids, append(addEdges[O2M], addEdges[O2O]...)); err != nil {
		return err
	}
	return nil
}

// setTableColumns sets the table columns and foreign_keys used in the update.
func (u *updater) setTableColumns(update *sql.UpdateBuilder, addEdges, clearEdges map[Rel][]*EdgeSpec) error {
	// Avoid multiple assignments to the same column.
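	// For example, if the same O2O foreign-key column is both cleared and
	// re-assigned in a single mutation, the assignment wins: the column is
	// recorded in setEdges below and the matching SetNull call is skipped.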
setEdges := make(map[string]bool) for _, e := range addEdges[M2O] { setEdges[e.Columns[0]] = true } for _, e := range addEdges[O2O] { if e.Inverse || e.Bidi { setEdges[e.Columns[0]] = true } } for _, fi := range u.Fields.Clear { update.SetNull(fi.Column) } for _, e := range clearEdges[M2O] { if col := e.Columns[0]; !setEdges[col] { update.SetNull(col) } } for _, e := range clearEdges[O2O] { col := e.Columns[0] if (e.Inverse || e.Bidi) && !setEdges[col] { update.SetNull(col) } } err := setTableColumns(u.Fields.Set, addEdges, func(column string, value driver.Value) { update.Set(column, value) }) if err != nil { return err } for _, fi := range u.Fields.Add { update.Add(fi.Column, fi.Value) } return nil } func (u *updater) scan(rows *sql.Rows) error { defer rows.Close() columns, err := rows.Columns() if err != nil { return err } if !rows.Next() { if err := rows.Err(); err != nil { return err } if len(u.Node.CompositeID) == 2 { return &NotFoundError{table: u.Node.Table, id: []driver.Value{u.Node.CompositeID[0].Value, u.Node.CompositeID[1].Value}} } return &NotFoundError{table: u.Node.Table, id: u.Node.ID.Value} } values, err := u.ScanValues(columns) if err != nil { return err } if err := rows.Scan(values...); err != nil { return fmt.Errorf("failed scanning rows: %w", err) } if err := u.Assign(columns, values); err != nil { return err } return nil } func (u *updater) ensureExists(ctx context.Context) error { exists := u.builder.Select().From(u.builder.Table(u.Node.Table).Schema(u.Node.Schema)).Where(sql.EQ(u.Node.ID.Column, u.Node.ID.Value)) u.Predicate(exists) query, args := u.builder.SelectExpr(sql.Exists(exists)).Query() rows := &sql.Rows{} if err := u.tx.Query(ctx, query, args, rows); err != nil { return err } defer rows.Close() found, err := sql.ScanBool(rows) if err != nil { return err } if !found { return &NotFoundError{table: u.Node.Table, id: u.Node.ID.Value} } return nil } type creator struct { graph *CreateSpec } func (c *creator) node(ctx context.Context, drv dialect.Driver) error { var ( edges = EdgeSpecs(c.Edges).GroupRel() insert = c.builder.Insert(c.Table).Schema(c.Schema).Default() ) if err := c.setTableColumns(insert, edges); err != nil { return err } tx, err := c.mayTx(ctx, drv, edges) if err != nil { return err } if err := func() error { // In case the spec does not contain an ID field, we assume // we interact with an edge-schema with composite primary key. if c.ID == nil { c.ensureConflict(insert) query, args := insert.Query() return c.tx.Exec(ctx, query, args, nil) } if err := c.insert(ctx, insert); err != nil { return err } if err := c.graph.addM2MEdges(ctx, []driver.Value{c.ID.Value}, edges[M2M]); err != nil { return err } return c.graph.addFKEdges(ctx, []driver.Value{c.ID.Value}, append(edges[O2M], edges[O2O]...)) }(); err != nil { return rollback(tx, err) } return tx.Commit() } // mayTx opens a new transaction if the create operation spans across multiple statements. func (c *creator) mayTx(ctx context.Context, drv dialect.Driver, edges map[Rel][]*EdgeSpec) (dialect.Tx, error) { if !hasExternalEdges(edges, nil) { return dialect.NopTx(drv), nil } tx, err := drv.Tx(ctx) if err != nil { return nil, err } c.tx = tx return tx, nil } // setTableColumns sets the table columns and foreign_keys used in insert. 
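// As a rough illustration (values are arbitrary), a spec such as:
//
//	spec := &sqlgraph.CreateSpec{
//		Table: "users",
//		ID:    &sqlgraph.FieldSpec{Column: "id", Type: field.TypeInt},
//		Fields: []*sqlgraph.FieldSpec{
//			{Column: "name", Type: field.TypeString, Value: "a8m"},
//		},
//	}
//
// results in an INSERT of the form INSERT INTO `users` (`name`) VALUES (?),
// as exercised by the "fields" case in TestCreateNode below.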
func (c *creator) setTableColumns(insert *sql.InsertBuilder, edges map[Rel][]*EdgeSpec) error {
	err := setTableColumns(c.Fields, edges, func(column string, value driver.Value) {
		insert.Set(column, value)
	})
	return err
}

// insert adds the node to its table and sets its ID if it was not provided by the user.
func (c *creator) insert(ctx context.Context, insert *sql.InsertBuilder) error {
	c.ensureConflict(insert)
	// If the id field was provided by the user.
	if c.ID.Value != nil {
		insert.Set(c.ID.Column, c.ID.Value)
		// In case of "ON CONFLICT", the record may exist in the
		// database, and we need to get back the database id field.
		if len(c.CreateSpec.OnConflict) == 0 {
			query, args := insert.Query()
			return c.tx.Exec(ctx, query, args, nil)
		}
	}
	return c.insertLastID(ctx, insert.Returning(c.ID.Column))
}

// ensureConflict ensures the ON CONFLICT is added to the insert statement.
func (c *creator) ensureConflict(insert *sql.InsertBuilder) {
	if opts := c.CreateSpec.OnConflict; len(opts) > 0 {
		insert.OnConflict(opts...)
		c.ensureLastInsertID(insert)
	}
}

// ensureLastInsertID ensures the LAST_INSERT_ID was added to the
// 'ON DUPLICATE ... UPDATE' clause if it was not provided.
func (c *creator) ensureLastInsertID(insert *sql.InsertBuilder) {
	if c.ID == nil || !c.ID.Type.Numeric() || c.ID.Value != nil || insert.Dialect() != dialect.MySQL {
		return
	}
	insert.OnConflict(sql.ResolveWith(func(s *sql.UpdateSet) {
		for _, column := range s.UpdateColumns() {
			if column == c.ID.Column {
				return
			}
		}
		s.Set(c.ID.Column, sql.Expr(fmt.Sprintf("LAST_INSERT_ID(%s)", s.Table().C(c.ID.Column))))
	}))
}

type batchCreator struct {
	graph
	*BatchCreateSpec
}

func (c *batchCreator) nodes(ctx context.Context, drv dialect.Driver) error {
	if len(c.Nodes) == 0 {
		return nil
	}
	columns := make(map[string]struct{})
	values := make([]map[string]driver.Value, len(c.Nodes))
	for i, node := range c.Nodes {
		if i > 0 && node.Table != c.Nodes[i-1].Table {
			return fmt.Errorf("more than 1 table for batch insert: %q != %q", node.Table, c.Nodes[i-1].Table)
		}
		values[i] = make(map[string]driver.Value)
		if node.ID != nil && node.ID.Value != nil {
			columns[node.ID.Column] = struct{}{}
			values[i][node.ID.Column] = node.ID.Value
		}
		edges := EdgeSpecs(node.Edges).GroupRel()
		err := setTableColumns(node.Fields, edges, func(column string, value driver.Value) {
			columns[column] = struct{}{}
			values[i][column] = value
		})
		if err != nil {
			return err
		}
	}
	for column := range columns {
		for i := range values {
			if _, exists := values[i][column]; !exists {
				if c.Nodes[i].ID != nil && column == c.Nodes[i].ID.Column {
					// If the ID value was provided to one of the nodes, it should be
					// provided to all others because this affects the way we calculate
					// their values in the MySQL and SQLite dialects.
					return fmt.Errorf("inconsistent id values for batch insert")
				}
				// Assign NULL values for empty placeholders.
				values[i][column] = nil
			}
		}
	}
	sorted := keys(columns)
	insert := c.builder.Insert(c.Nodes[0].Table).Schema(c.Nodes[0].Schema).Default().Columns(sorted...)
	for i := range values {
		vs := make([]any, len(sorted))
		for j, c := range sorted {
			vs[j] = values[i][c]
		}
		insert.Values(vs...)
	}
	tx, err := c.mayTx(ctx, drv)
	if err != nil {
		return err
	}
	c.tx = tx
	if err := func() error {
		// In case the spec does not contain an ID field, we assume
		// we interact with an edge-schema with composite primary key.
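		// An edge schema exposes the join table itself as an entity: its
		// primary key is the (from_id, to_id) column pair, so there is no
		// single ID column for the INSERT below to return.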
if c.Nodes[0].ID == nil { c.ensureConflict(insert) query, args := insert.Query() return tx.Exec(ctx, query, args, nil) } if err := c.batchInsert(ctx, tx, insert); err != nil { return fmt.Errorf("insert nodes to table %q: %w", c.Nodes[0].Table, err) } if err := c.batchAddM2M(ctx, c.BatchCreateSpec); err != nil { return err } // FKs that exist in different tables can't be updated in batch (using the CASE // statement), because we rely on RowsAffected to check if the FK column is NULL. for _, node := range c.Nodes { edges := EdgeSpecs(node.Edges).GroupRel() if err := c.graph.addFKEdges(ctx, []driver.Value{node.ID.Value}, append(edges[O2M], edges[O2O]...)); err != nil { return err } } return nil }(); err != nil { return rollback(tx, err) } return tx.Commit() } // mayTx opens a new transaction if the create operation spans across multiple statements. func (c *batchCreator) mayTx(ctx context.Context, drv dialect.Driver) (dialect.Tx, error) { for _, node := range c.Nodes { for _, edge := range node.Edges { if isExternalEdge(edge) { return drv.Tx(ctx) } } } return dialect.NopTx(drv), nil } // batchInsert inserts a batch of nodes to their table and sets their ID if it was not provided by the user. func (c *batchCreator) batchInsert(ctx context.Context, tx dialect.ExecQuerier, insert *sql.InsertBuilder) error { c.ensureConflict(insert) return c.insertLastIDs(ctx, tx, insert.Returning(c.Nodes[0].ID.Column)) } // ensureConflict ensures the ON CONFLICT is added to the insert statement. func (c *batchCreator) ensureConflict(insert *sql.InsertBuilder) { if opts := c.BatchCreateSpec.OnConflict; len(opts) > 0 { insert.OnConflict(opts...) } } // GroupRel groups edges by their relation type. func (es EdgeSpecs) GroupRel() map[Rel][]*EdgeSpec { edges := make(map[Rel][]*EdgeSpec) for _, edge := range es { edges[edge.Rel] = append(edges[edge.Rel], edge) } return edges } // GroupTable groups edges by their table name. func (es EdgeSpecs) GroupTable() map[string][]*EdgeSpec { edges := make(map[string][]*EdgeSpec) for _, edge := range es { edges[edge.Table] = append(edges[edge.Table], edge) } return edges } // FilterRel returns edges for the given relation type. func (es EdgeSpecs) FilterRel(r Rel) EdgeSpecs { edges := make([]*EdgeSpec, 0, len(es)) for _, edge := range es { if edge.Rel == r { edges = append(edges, edge) } } return edges } // The common operations shared between the different builders. // // M2M edges reside in join tables and require INSERT and DELETE // queries for adding or removing edges respectively. // // O2M and non-inverse O2O edges also reside in external tables, // but use UPDATE queries (fk = ?, fk = NULL). type graph struct { tx dialect.ExecQuerier builder *sql.DialectBuilder } func (g *graph) clearM2MEdges(ctx context.Context, ids []driver.Value, edges EdgeSpecs) error { // Remove all M2M edges from the same type at once. // The EdgeSpec is the same for all members in a group. tables := edges.GroupTable() for _, table := range edgeKeys(tables) { edges := tables[table] preds := make([]*sql.Predicate, 0, len(edges)) for _, edge := range edges { fromC, toC := edge.Columns[0], edge.Columns[1] if edge.Inverse { fromC, toC = toC, fromC } // If there are no specific edges (to target-nodes) to remove, // clear all edges that go out (or come in) from the nodes. 
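			// For example, clearing all group edges of users 1 and 2 in a
			// user_groups join table produces a single predicate of the
			// form `user_id` IN (?, ?) instead of one pair per target.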
if len(edge.Target.Nodes) == 0 { preds = append(preds, matchID(fromC, ids)) if edge.Bidi { preds = append(preds, matchID(toC, ids)) } } else { pk1, pk2 := ids, edge.Target.Nodes preds = append(preds, matchIDs(fromC, pk1, toC, pk2)) if edge.Bidi { preds = append(preds, matchIDs(toC, pk1, fromC, pk2)) } } } deleter := g.builder.Delete(table).Where(sql.Or(preds...)) if edges[0].Schema != "" { // If the Schema field was provided to the EdgeSpec (by the // generated code), it should be the same for all EdgeSpecs. deleter.Schema(edges[0].Schema) } query, args := deleter.Query() if err := g.tx.Exec(ctx, query, args, nil); err != nil { return fmt.Errorf("remove m2m edge for table %s: %w", table, err) } } return nil } func (g *graph) addM2MEdges(ctx context.Context, ids []driver.Value, edges EdgeSpecs) error { // Insert all M2M edges from the same type at once. // The EdgeSpec is the same for all members in a group. tables := edges.GroupTable() for _, table := range edgeKeys(tables) { var ( edges = tables[table] columns = edges[0].Columns values = make([]any, 0, len(edges[0].Target.Fields)) ) // Specs are generated equally for all edges from the same type. for _, f := range edges[0].Target.Fields { values = append(values, f.Value) columns = append(columns, f.Column) } insert := g.builder.Insert(table).Columns(columns...) if edges[0].Schema != "" { // If the Schema field was provided to the EdgeSpec (by the // generated code), it should be the same for all EdgeSpecs. insert.Schema(edges[0].Schema) } for _, edge := range edges { pk1, pk2 := ids, edge.Target.Nodes if edge.Inverse { pk1, pk2 = pk2, pk1 } for _, pair := range product(pk1, pk2) { insert.Values(append([]any{pair[0], pair[1]}, values...)...) if edge.Bidi { insert.Values(append([]any{pair[1], pair[0]}, values...)...) } } } query, args := insert.Query() if err := g.tx.Exec(ctx, query, args, nil); err != nil { return fmt.Errorf("add m2m edge for table %s: %w", table, err) } } return nil } func (g *graph) batchAddM2M(ctx context.Context, spec *BatchCreateSpec) error { tables := make(map[string]*sql.InsertBuilder) for _, node := range spec.Nodes { edges := EdgeSpecs(node.Edges).FilterRel(M2M) for t, edges := range edges.GroupTable() { insert, ok := tables[t] if !ok { insert = g.builder.Insert(t).Columns(edges[0].Columns...) if edges[0].Schema != "" { // If the Schema field was provided to the EdgeSpec (by the // generated code), it should be the same for all EdgeSpecs. insert.Schema(edges[0].Schema) } } tables[t] = insert if len(edges) != 1 { return fmt.Errorf("expect exactly 1 edge-spec per table, but got %d", len(edges)) } edge := edges[0] pk1, pk2 := []driver.Value{node.ID.Value}, edge.Target.Nodes if edge.Inverse { pk1, pk2 = pk2, pk1 } for _, pair := range product(pk1, pk2) { insert.Values(pair[0], pair[1]) if edge.Bidi { insert.Values(pair[1], pair[0]) } } } } for _, table := range insertKeys(tables) { query, args := tables[table].Query() if err := g.tx.Exec(ctx, query, args, nil); err != nil { return fmt.Errorf("add m2m edge for table %s: %w", table, err) } } return nil } func (g *graph) clearFKEdges(ctx context.Context, ids []driver.Value, edges []*EdgeSpec) error { for _, edge := range edges { if edge.Rel == O2O && edge.Inverse { continue } // O2O relations can be cleared without // passing the target ids. pred := matchID(edge.Columns[0], ids) if nodes := edge.Target.Nodes; len(nodes) > 0 { pred = matchIDs(edge.Target.IDSpec.Column, edge.Target.Nodes, edge.Columns[0], ids) } query, args := g.builder.Update(edge.Table). 
			SetNull(edge.Columns[0]).
			Where(pred).
			Query()
		if err := g.tx.Exec(ctx, query, args, nil); err != nil {
			return fmt.Errorf("clear %s edge for table %s: %w", edge.Rel, edge.Table, err)
		}
	}
	return nil
}

func (g *graph) addFKEdges(ctx context.Context, ids []driver.Value, edges []*EdgeSpec) error {
	id := ids[0]
	if len(ids) > 1 && len(edges) != 0 {
		// O2M and non-inverse O2O edges are defined by a FK in the "other"
		// table. Therefore, ids[i+1] will override ids[i] which is invalid.
		return fmt.Errorf("unable to link FK edge to more than 1 node: %v", ids)
	}
	for _, edge := range edges {
		if edge.Rel == O2O && edge.Inverse {
			continue
		}
		p := sql.EQ(edge.Target.IDSpec.Column, edge.Target.Nodes[0])
		// Use an "IN" predicate instead of a list of "OR"
		// predicates when there is more than one node to connect.
		if len(edge.Target.Nodes) > 1 {
			p = sql.InValues(edge.Target.IDSpec.Column, edge.Target.Nodes...)
		}
		query, args := g.builder.Update(edge.Table).
			Schema(edge.Schema).
			Set(edge.Columns[0], id).
			Where(sql.And(p, sql.IsNull(edge.Columns[0]))).
			Query()
		var res sql.Result
		if err := g.tx.Exec(ctx, query, args, &res); err != nil {
			return fmt.Errorf("add %s edge for table %s: %w", edge.Rel, edge.Table, err)
		}
		affected, err := res.RowsAffected()
		if err != nil {
			return err
		}
		// Setting the FK value of the "other" table
		// without clearing it before is not allowed.
		if ids := edge.Target.Nodes; int(affected) < len(ids) {
			return &ConstraintError{msg: fmt.Sprintf("one of %v is already connected to a different %s", ids, edge.Columns[0])}
		}
	}
	return nil
}

func hasExternalEdges(addEdges, clearEdges map[Rel][]*EdgeSpec) bool {
	// M2M edges reside in a join-table, and O2M edges reside
	// in the M2O table (the entity that holds the FK).
	if len(clearEdges[M2M]) > 0 || len(addEdges[M2M]) > 0 ||
		len(clearEdges[O2M]) > 0 || len(addEdges[O2M]) > 0 {
		return true
	}
	for _, edges := range [][]*EdgeSpec{clearEdges[O2O], addEdges[O2O]} {
		for _, e := range edges {
			if !e.Inverse {
				return true
			}
		}
	}
	return false
}

// isExternalEdge reports if the given edge requires an UPDATE
// or an INSERT to another table.
func isExternalEdge(e *EdgeSpec) bool {
	return e.Rel == M2M || e.Rel == O2M || e.Rel == O2O && !e.Inverse
}

// setTableColumns is shared between the updater and the creator.
func setTableColumns(fields []*FieldSpec, edges map[Rel][]*EdgeSpec, set func(string, driver.Value)) (err error) {
	for _, fi := range fields {
		value := fi.Value
		if fi.Type == field.TypeJSON {
			buf, err := json.Marshal(value)
			if err != nil {
				return fmt.Errorf("marshal value for column %s: %w", fi.Column, err)
			}
			// If the underlying driver does not support JSON types,
			// driver.DefaultParameterConverter will convert it to uint8.
			value = json.RawMessage(buf)
		}
		set(fi.Column, value)
	}
	for _, e := range edges[M2O] {
		set(e.Columns[0], e.Target.Nodes[0])
	}
	for _, e := range edges[O2O] {
		if e.Inverse || e.Bidi {
			set(e.Columns[0], e.Target.Nodes[0])
		}
	}
	return nil
}

// insertLastID invokes the insert query on the transaction and returns the LastInsertID.
func (c *creator) insertLastID(ctx context.Context, insert *sql.InsertBuilder) error {
	query, args := insert.Query()
	if err := insert.Err(); err != nil {
		return err
	}
	// MySQL does not support the "RETURNING" clause.
	if insert.Dialect() != dialect.MySQL {
		rows := &sql.Rows{}
		if err := c.tx.Query(ctx, query, args, rows); err != nil {
			return err
		}
		defer rows.Close()
		switch _, ok := c.ID.Value.(field.ValueScanner); {
		case ok:
			// If the ID implements the sql.Scanner
			// interface, it should be a pointer type.
			return sql.ScanOne(rows, c.ID.Value)
		case c.ID.Type.Numeric():
			// Normalize the type to int64 to make it
			// look like LastInsertId.
			id, err := sql.ScanInt64(rows)
			if err != nil {
				return err
			}
			c.ID.Value = id
			return nil
		default:
			return sql.ScanOne(rows, &c.ID.Value)
		}
	}
	// MySQL.
	var res sql.Result
	if err := c.tx.Exec(ctx, query, args, &res); err != nil {
		return err
	}
	// If the ID field is not numeric (e.g. string),
	// there is no way to scan the LAST_INSERT_ID.
	if c.ID.Type.Numeric() {
		id, err := res.LastInsertId()
		if err != nil {
			return err
		}
		c.ID.Value = id
	}
	return nil
}

// insertLastIDs invokes the batch insert query on the transaction and returns the LastInsertID of all entities.
func (c *batchCreator) insertLastIDs(ctx context.Context, tx dialect.ExecQuerier, insert *sql.InsertBuilder) error {
	query, args := insert.Query()
	if err := insert.Err(); err != nil {
		return err
	}
	// MySQL does not support the "RETURNING" clause.
	if insert.Dialect() != dialect.MySQL {
		rows := &sql.Rows{}
		if err := tx.Query(ctx, query, args, rows); err != nil {
			return err
		}
		defer rows.Close()
		for i := 0; rows.Next(); i++ {
			node := c.Nodes[i]
			if node.ID.Type.Numeric() {
				// Normalize the type to int64 to make it
				// look like LastInsertId.
				var id int64
				if err := rows.Scan(&id); err != nil {
					return err
				}
				node.ID.Value = id
			} else if err := rows.Scan(&node.ID.Value); err != nil {
				return err
			}
		}
		return nil
	}
	// MySQL.
	var res sql.Result
	if err := tx.Exec(ctx, query, args, &res); err != nil {
		return err
	}
	// If the ID field is not numeric (e.g. string),
	// there is no way to scan the LAST_INSERT_ID.
	if len(c.Nodes) > 0 && c.Nodes[0].ID.Type.Numeric() {
		id, err := res.LastInsertId()
		if err != nil {
			return err
		}
		affected, err := res.RowsAffected()
		if err != nil {
			return err
		}
		// Assume the ID field is AUTO_INCREMENT
		// if its type is numeric.
		for i := 0; int64(i) < affected && i < len(c.Nodes); i++ {
			c.Nodes[i].ID.Value = id + int64(i)
		}
	}
	return nil
}

// rollback calls tx.Rollback and wraps the given error with the rollback error if it occurred.
func rollback(tx dialect.Tx, err error) error {
	if rerr := tx.Rollback(); rerr != nil {
		err = fmt.Errorf("%w: %v", err, rerr)
	}
	return err
}

func edgeKeys(m map[string][]*EdgeSpec) []string {
	keys := make([]string, 0, len(m))
	for k := range m {
		keys = append(keys, k)
	}
	sort.Strings(keys)
	return keys
}

func insertKeys(m map[string]*sql.InsertBuilder) []string {
	keys := make([]string, 0, len(m))
	for k := range m {
		keys = append(keys, k)
	}
	sort.Strings(keys)
	return keys
}

func keys(m map[string]struct{}) []string {
	keys := make([]string, 0, len(m))
	for k := range m {
		keys = append(keys, k)
	}
	sort.Strings(keys)
	return keys
}

func matchID(column string, pk []driver.Value) *sql.Predicate {
	if len(pk) > 1 {
		return sql.InValues(column, pk...)
	}
	return sql.EQ(column, pk[0])
}

func matchIDs(column1 string, pk1 []driver.Value, column2 string, pk2 []driver.Value) *sql.Predicate {
	p := matchID(column1, pk1)
	if len(pk2) > 1 {
		// Use an "IN" predicate instead of a list of "OR"
		// predicates when there is more than one node to connect.
		return sql.And(p, sql.InValues(column2, pk2...))
	}
	return sql.And(p, sql.EQ(column2, pk2[0]))
}

// product returns the cartesian product of two id sets.
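// For example, product([]driver.Value{1, 2}, []driver.Value{3}) returns
// the pairs (1, 3) and (2, 3); this is how one M2M INSERT fans out to
// every (from, to) pair when multiple ids are linked at once.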
func product(a, b []driver.Value) [][2]driver.Value {
	c := make([][2]driver.Value, 0, len(a)*len(b))
	for i := range a {
		for j := range b {
			c = append(c, [2]driver.Value{a[i], b[j]})
		}
	}
	return c
}

ent-0.11.3/dialect/sql/sqlgraph/graph_test.go

// Copyright 2019-present Facebook Inc. All rights reserved.
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.

package sqlgraph

import (
	"context"
	"database/sql/driver"
	"errors"
	"fmt"
	"regexp"
	"strings"
	"testing"

	"entgo.io/ent/dialect"
	"entgo.io/ent/dialect/sql"
	"entgo.io/ent/schema/field"

	"github.com/DATA-DOG/go-sqlmock"
	"github.com/stretchr/testify/require"
)

func TestNeighbors(t *testing.T) {
	tests := []struct {
		name      string
		input     *Step
		wantQuery string
		wantArgs  []any
	}{
		{
			name: "O2O/1type",
			// Since the relation is on the same sql.Table,
			// V is used as a reference value.
			input: NewStep(
				From("users", "id", 1),
				To("users", "id"),
				Edge(O2O, false, "users", "spouse_id"),
			),
			wantQuery: "SELECT * FROM `users` WHERE `spouse_id` = ?",
			wantArgs:  []any{1},
		},
		{
			name: "O2O/1type/inverse",
			input: NewStep(
				From("nodes", "id", 1),
				To("nodes", "id"),
				Edge(O2O, true, "nodes", "prev_id"),
			),
			wantQuery: "SELECT * FROM `nodes` JOIN (SELECT `prev_id` FROM `nodes` WHERE `id` = ?) AS `t1` ON `nodes`.`id` = `t1`.`prev_id`",
			wantArgs:  []any{1},
		},
		{
			name: "O2M/1type",
			input: NewStep(
				From("users", "id", 1),
				To("users", "id"),
				Edge(O2M, false, "users", "parent_id"),
			),
			wantQuery: "SELECT * FROM `users` WHERE `parent_id` = ?",
			wantArgs:  []any{1},
		},
		{
			name: "O2O/2types",
			input: NewStep(
				From("users", "id", 2),
				To("card", "id"),
				Edge(O2O, false, "cards", "owner_id"),
			),
			wantQuery: "SELECT * FROM `card` WHERE `owner_id` = ?",
			wantArgs:  []any{2},
		},
		{
			name: "O2O/2types/inverse",
			input: NewStep(
				From("cards", "id", 2),
				To("users", "id"),
				Edge(O2O, true, "cards", "owner_id"),
			),
			wantQuery: "SELECT * FROM `users` JOIN (SELECT `owner_id` FROM `cards` WHERE `id` = ?) AS `t1` ON `users`.`id` = `t1`.`owner_id`",
			wantArgs:  []any{2},
		},
		{
			name: "O2M/2types",
			input: NewStep(
				From("users", "id", 1),
				To("pets", "id"),
				Edge(O2M, false, "pets", "owner_id"),
			),
			wantQuery: "SELECT * FROM `pets` WHERE `owner_id` = ?",
			wantArgs:  []any{1},
		},
		{
			name: "M2O/2types/inverse",
			input: NewStep(
				From("pets", "id", 2),
				To("users", "id"),
				Edge(M2O, true, "pets", "owner_id"),
			),
			wantQuery: "SELECT * FROM `users` JOIN (SELECT `owner_id` FROM `pets` WHERE `id` = ?) AS `t1` ON `users`.`id` = `t1`.`owner_id`",
			wantArgs:  []any{2},
		},
		{
			name: "M2O/1type/inverse",
			input: NewStep(
				From("users", "id", 2),
				To("users", "id"),
				Edge(M2O, true, "users", "parent_id"),
			),
			wantQuery: "SELECT * FROM `users` JOIN (SELECT `parent_id` FROM `users` WHERE `id` = ?) AS `t1` ON `users`.`id` = `t1`.`parent_id`",
			wantArgs:  []any{2},
		},
		{
			name: "M2M/2type",
			input: NewStep(
				From("groups", "id", 2),
				To("users", "id"),
				Edge(M2M, false, "user_groups", "group_id", "user_id"),
			),
			wantQuery: "SELECT * FROM `users` JOIN (SELECT `user_groups`.`user_id` FROM `user_groups` WHERE `user_groups`.`group_id` = ?) AS `t1` ON `users`.`id` = `t1`.`user_id`",
			wantArgs:  []any{2},
		},
		{
			name: "M2M/2type/inverse",
			input: NewStep(
				From("users", "id", 2),
				To("groups", "id"),
				Edge(M2M, true, "user_groups", "group_id", "user_id"),
			),
			wantQuery: "SELECT * FROM `groups` JOIN (SELECT `user_groups`.`group_id` FROM `user_groups` WHERE `user_groups`.`user_id` = ?)
AS `t1` ON `groups`.`id` = `t1`.`group_id`", wantArgs: []any{2}, }, { name: "schema/O2O/1type", // Since the relation is on the same sql.Table, // V used as a reference value. input: func() *Step { step := NewStep( From("users", "id", 1), To("users", "id"), Edge(O2O, false, "users", "spouse_id"), ) step.To.Schema = "mydb" return step }(), wantQuery: "SELECT * FROM `mydb`.`users` WHERE `spouse_id` = ?", wantArgs: []any{1}, }, { name: "schema/O2O/1type/inverse", input: func() *Step { step := NewStep( From("nodes", "id", 1), To("nodes", "id"), Edge(O2O, true, "nodes", "prev_id"), ) step.To.Schema = "mydb" step.Edge.Schema = "mydb" return step }(), wantQuery: "SELECT * FROM `mydb`.`nodes` JOIN (SELECT `prev_id` FROM `mydb`.`nodes` WHERE `id` = ?) AS `t1` ON `mydb`.`nodes`.`id` = `t1`.`prev_id`", wantArgs: []any{1}, }, { name: "schema/O2M/1type", input: func() *Step { step := NewStep( From("users", "id", 1), To("users", "id"), Edge(O2M, false, "users", "parent_id"), ) step.To.Schema = "mydb" return step }(), wantQuery: "SELECT * FROM `mydb`.`users` WHERE `parent_id` = ?", wantArgs: []any{1}, }, { name: "schema/O2O/2types", input: func() *Step { step := NewStep( From("users", "id", 2), To("card", "id"), Edge(O2O, false, "cards", "owner_id"), ) step.To.Schema = "mydb" return step }(), wantQuery: "SELECT * FROM `mydb`.`card` WHERE `owner_id` = ?", wantArgs: []any{2}, }, { name: "schema/O2O/2types/inverse", input: func() *Step { step := NewStep( From("cards", "id", 2), To("users", "id"), Edge(O2O, true, "cards", "owner_id"), ) step.To.Schema = "mydb" step.Edge.Schema = "mydb" return step }(), wantQuery: "SELECT * FROM `mydb`.`users` JOIN (SELECT `owner_id` FROM `mydb`.`cards` WHERE `id` = ?) AS `t1` ON `mydb`.`users`.`id` = `t1`.`owner_id`", wantArgs: []any{2}, }, { name: "schema/O2M/2types", input: func() *Step { step := NewStep( From("users", "id", 1), To("pets", "id"), Edge(O2M, false, "pets", "owner_id"), ) step.To.Schema = "mydb" return step }(), wantQuery: "SELECT * FROM `mydb`.`pets` WHERE `owner_id` = ?", wantArgs: []any{1}, }, { name: "schema/M2O/2types/inverse", input: func() *Step { step := NewStep( From("pets", "id", 2), To("users", "id"), Edge(M2O, true, "pets", "owner_id"), ) step.To.Schema = "s1" step.Edge.Schema = "s2" return step }(), wantQuery: "SELECT * FROM `s1`.`users` JOIN (SELECT `owner_id` FROM `s2`.`pets` WHERE `id` = ?) AS `t1` ON `s1`.`users`.`id` = `t1`.`owner_id`", wantArgs: []any{2}, }, { name: "schema/M2O/1type/inverse", input: func() *Step { step := NewStep( From("users", "id", 2), To("users", "id"), Edge(M2O, true, "users", "parent_id"), ) step.To.Schema = "s1" step.Edge.Schema = "s1" return step }(), wantQuery: "SELECT * FROM `s1`.`users` JOIN (SELECT `parent_id` FROM `s1`.`users` WHERE `id` = ?) AS `t1` ON `s1`.`users`.`id` = `t1`.`parent_id`", wantArgs: []any{2}, }, { name: "schema/M2M/2type", input: func() *Step { step := NewStep( From("groups", "id", 2), To("users", "id"), Edge(M2M, false, "user_groups", "group_id", "user_id"), ) step.To.Schema = "s1" step.Edge.Schema = "s2" return step }(), wantQuery: "SELECT * FROM `s1`.`users` JOIN (SELECT `s2`.`user_groups`.`user_id` FROM `s2`.`user_groups` WHERE `s2`.`user_groups`.`group_id` = ?) 
AS `t1` ON `s1`.`users`.`id` = `t1`.`user_id`", wantArgs: []any{2}, }, { name: "schema/M2M/2type/inverse", input: func() *Step { step := NewStep( From("users", "id", 2), To("groups", "id"), Edge(M2M, true, "user_groups", "group_id", "user_id"), ) step.To.Schema = "s1" step.Edge.Schema = "s2" return step }(), wantQuery: "SELECT * FROM `s1`.`groups` JOIN (SELECT `s2`.`user_groups`.`group_id` FROM `s2`.`user_groups` WHERE `s2`.`user_groups`.`user_id` = ?) AS `t1` ON `s1`.`groups`.`id` = `t1`.`group_id`", wantArgs: []any{2}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { selector := Neighbors("", tt.input) query, args := selector.Query() require.Equal(t, tt.wantQuery, query) require.Equal(t, tt.wantArgs, args) }) } } func TestSetNeighbors(t *testing.T) { tests := []struct { name string input *Step wantQuery string wantArgs []any }{ { name: "O2M/2types", input: NewStep( From("users", "id", sql.Select().From(sql.Table("users")).Where(sql.EQ("name", "a8m"))), To("pets", "id"), Edge(O2M, false, "users", "owner_id"), ), wantQuery: `SELECT * FROM "pets" JOIN (SELECT "users"."id" FROM "users" WHERE "name" = $1) AS "t1" ON "pets"."owner_id" = "t1"."id"`, wantArgs: []any{"a8m"}, }, { name: "M2O/2types", input: NewStep( From("pets", "id", sql.Select().From(sql.Table("pets")).Where(sql.EQ("name", "pedro"))), To("users", "id"), Edge(M2O, true, "pets", "owner_id"), ), wantQuery: `SELECT * FROM "users" JOIN (SELECT "pets"."owner_id" FROM "pets" WHERE "name" = $1) AS "t1" ON "users"."id" = "t1"."owner_id"`, wantArgs: []any{"pedro"}, }, { name: "M2M/2types", input: NewStep( From("users", "id", sql.Select().From(sql.Table("users")).Where(sql.EQ("name", "a8m"))), To("groups", "id"), Edge(M2M, false, "user_groups", "user_id", "group_id"), ), wantQuery: ` SELECT * FROM "groups" JOIN (SELECT "user_groups"."group_id" FROM "user_groups" JOIN (SELECT "users"."id" FROM "users" WHERE "name" = $1) AS "t1" ON "user_groups"."user_id" = "t1"."id") AS "t1" ON "groups"."id" = "t1"."group_id"`, wantArgs: []any{"a8m"}, }, { name: "M2M/2types/inverse", input: NewStep( From("groups", "id", sql.Select().From(sql.Table("groups")).Where(sql.EQ("name", "GitHub"))), To("users", "id"), Edge(M2M, true, "user_groups", "user_id", "group_id"), ), wantQuery: ` SELECT * FROM "users" JOIN (SELECT "user_groups"."user_id" FROM "user_groups" JOIN (SELECT "groups"."id" FROM "groups" WHERE "name" = $1) AS "t1" ON "user_groups"."group_id" = "t1"."id") AS "t1" ON "users"."id" = "t1"."user_id"`, wantArgs: []any{"GitHub"}, }, { name: "schema/O2M/2types", input: func() *Step { step := NewStep( From("users", "id", sql.Select().From(sql.Table("users").Schema("s2")).Where(sql.EQ("name", "a8m"))), To("pets", "id"), Edge(O2M, false, "users", "owner_id"), ) step.To.Schema = "s1" return step }(), wantQuery: `SELECT * FROM "s1"."pets" JOIN (SELECT "s2"."users"."id" FROM "s2"."users" WHERE "name" = $1) AS "t1" ON "s1"."pets"."owner_id" = "t1"."id"`, wantArgs: []any{"a8m"}, }, { name: "schema/M2O/2types", input: func() *Step { step := NewStep( From("pets", "id", sql.Select().From(sql.Table("pets").Schema("s2")).Where(sql.EQ("name", "pedro"))), To("users", "id"), Edge(M2O, true, "pets", "owner_id"), ) step.To.Schema = "s1" return step }(), wantQuery: `SELECT * FROM "s1"."users" JOIN (SELECT "s2"."pets"."owner_id" FROM "s2"."pets" WHERE "name" = $1) AS "t1" ON "s1"."users"."id" = "t1"."owner_id"`, wantArgs: []any{"pedro"}, }, { name: "schema/M2M/2types", input: func() *Step { step := NewStep( From("users", "id", 
sql.Select().From(sql.Table("users").Schema("s2")).Where(sql.EQ("name", "a8m"))), To("groups", "id"), Edge(M2M, false, "user_groups", "user_id", "group_id"), ) step.To.Schema = "s1" step.Edge.Schema = "s3" return step }(), wantQuery: ` SELECT * FROM "s1"."groups" JOIN (SELECT "s3"."user_groups"."group_id" FROM "s3"."user_groups" JOIN (SELECT "s2"."users"."id" FROM "s2"."users" WHERE "name" = $1) AS "t1" ON "s3"."user_groups"."user_id" = "t1"."id") AS "t1" ON "s1"."groups"."id" = "t1"."group_id"`, wantArgs: []any{"a8m"}, }, { name: "schema/M2M/2types/inverse", input: func() *Step { step := NewStep( From("groups", "id", sql.Select().From(sql.Table("groups").Schema("s2")).Where(sql.EQ("name", "GitHub"))), To("users", "id"), Edge(M2M, true, "user_groups", "user_id", "group_id"), ) step.To.Schema = "s1" step.Edge.Schema = "s3" return step }(), wantQuery: ` SELECT * FROM "s1"."users" JOIN (SELECT "s3"."user_groups"."user_id" FROM "s3"."user_groups" JOIN (SELECT "s2"."groups"."id" FROM "s2"."groups" WHERE "name" = $1) AS "t1" ON "s3"."user_groups"."group_id" = "t1"."id") AS "t1" ON "s1"."users"."id" = "t1"."user_id"`, wantArgs: []any{"GitHub"}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { selector := SetNeighbors("postgres", tt.input) query, args := selector.Query() tt.wantQuery = strings.Join(strings.Fields(tt.wantQuery), " ") require.Equal(t, tt.wantQuery, query) require.Equal(t, tt.wantArgs, args) }) } } func TestHasNeighbors(t *testing.T) { tests := []struct { name string step *Step selector *sql.Selector wantQuery string }{ { name: "O2O/1type", // A nodes sql.Table; linked-list (next->prev). The "prev" // node holds association pointer. The neighbors query // here checks if a node "has-next". step: NewStep( From("nodes", "id"), To("nodes", "id"), Edge(O2O, false, "nodes", "prev_id"), ), selector: sql.Select("*").From(sql.Table("nodes")), wantQuery: "SELECT * FROM `nodes` WHERE `nodes`.`id` IN (SELECT `nodes`.`prev_id` FROM `nodes` WHERE `nodes`.`prev_id` IS NOT NULL)", }, { name: "O2O/1type/inverse", // Same example as above, but the neighbors // query checks if a node "has-previous". 
step: NewStep( From("nodes", "id"), To("nodes", "id"), Edge(O2O, true, "nodes", "prev_id"), ), selector: sql.Select("*").From(sql.Table("nodes")), wantQuery: "SELECT * FROM `nodes` WHERE `nodes`.`prev_id` IS NOT NULL", }, { name: "O2M/2type2", step: NewStep( From("users", "id"), To("pets", "id"), Edge(O2M, false, "pets", "owner_id"), ), selector: sql.Select("*").From(sql.Table("users")), wantQuery: "SELECT * FROM `users` WHERE `users`.`id` IN (SELECT `pets`.`owner_id` FROM `pets` WHERE `pets`.`owner_id` IS NOT NULL)", }, { name: "M2O/2type2", step: NewStep( From("pets", "id"), To("users", "id"), Edge(M2O, true, "pets", "owner_id"), ), selector: sql.Select("*").From(sql.Table("pets")), wantQuery: "SELECT * FROM `pets` WHERE `pets`.`owner_id` IS NOT NULL", }, { name: "M2M/2types", step: NewStep( From("users", "id"), To("groups", "id"), Edge(M2M, false, "user_groups", "user_id", "group_id"), ), selector: sql.Select("*").From(sql.Table("users")), wantQuery: "SELECT * FROM `users` WHERE `users`.`id` IN (SELECT `user_groups`.`user_id` FROM `user_groups`)", }, { name: "M2M/2types/inverse", step: NewStep( From("users", "id"), To("groups", "id"), Edge(M2M, true, "group_users", "group_id", "user_id"), ), selector: sql.Select("*").From(sql.Table("users")), wantQuery: "SELECT * FROM `users` WHERE `users`.`id` IN (SELECT `group_users`.`user_id` FROM `group_users`)", }, { name: "schema/O2O/1type", step: func() *Step { step := NewStep( From("nodes", "id"), To("nodes", "id"), Edge(O2O, false, "nodes", "prev_id"), ) step.Edge.Schema = "s1" return step }(), selector: sql.Select("*").From(sql.Table("nodes").Schema("s1")), wantQuery: "SELECT * FROM `s1`.`nodes` WHERE `s1`.`nodes`.`id` IN (SELECT `s1`.`nodes`.`prev_id` FROM `s1`.`nodes` WHERE `s1`.`nodes`.`prev_id` IS NOT NULL)", }, { name: "schema/O2O/1type/inverse", // Same example as above, but the neighbors // query checks if a node "has-previous". 
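			// Since the edge is inverse, this node holds the FK column
			// itself, so the check reduces to a `prev_id` IS NOT NULL
			// predicate (see wantQuery below) without a subquery.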
step: NewStep( From("nodes", "id"), To("nodes", "id"), Edge(O2O, true, "nodes", "prev_id"), ), selector: sql.Select("*").From(sql.Table("nodes").Schema("s1")), wantQuery: "SELECT * FROM `s1`.`nodes` WHERE `s1`.`nodes`.`prev_id` IS NOT NULL", }, { name: "schema/O2M/2type2", step: func() *Step { step := NewStep( From("users", "id"), To("pets", "id"), Edge(O2M, false, "pets", "owner_id"), ) step.Edge.Schema = "s2" return step }(), selector: sql.Select("*").From(sql.Table("users").Schema("s1")), wantQuery: "SELECT * FROM `s1`.`users` WHERE `s1`.`users`.`id` IN (SELECT `s2`.`pets`.`owner_id` FROM `s2`.`pets` WHERE `s2`.`pets`.`owner_id` IS NOT NULL)", }, { name: "schema/M2O/2type2", step: NewStep( From("pets", "id"), To("users", "id"), Edge(M2O, true, "pets", "owner_id"), ), selector: sql.Select("*").From(sql.Table("pets").Schema("s1")), wantQuery: "SELECT * FROM `s1`.`pets` WHERE `s1`.`pets`.`owner_id` IS NOT NULL", }, { name: "schema/M2M/2types", step: func() *Step { step := NewStep( From("users", "id"), To("groups", "id"), Edge(M2M, false, "user_groups", "user_id", "group_id"), ) step.Edge.Schema = "s2" return step }(), selector: sql.Select("*").From(sql.Table("users").Schema("s1")), wantQuery: "SELECT * FROM `s1`.`users` WHERE `s1`.`users`.`id` IN (SELECT `s2`.`user_groups`.`user_id` FROM `s2`.`user_groups`)", }, { name: "schema/M2M/2types/inverse", step: func() *Step { step := NewStep( From("users", "id"), To("groups", "id"), Edge(M2M, true, "group_users", "group_id", "user_id"), ) step.Edge.Schema = "s2" return step }(), selector: sql.Select("*").From(sql.Table("users").Schema("s1")), wantQuery: "SELECT * FROM `s1`.`users` WHERE `s1`.`users`.`id` IN (SELECT `s2`.`group_users`.`user_id` FROM `s2`.`group_users`)", }, { name: "O2M/2type2/selector", step: NewStep( From("users", "id"), To("pets", "id"), Edge(O2M, false, "pets", "owner_id"), ), selector: sql.Select("*").From(sql.Select("*").From(sql.Table("users")).As("users")).As("users"), wantQuery: "SELECT * FROM (SELECT * FROM `users`) AS `users` WHERE `users`.`id` IN (SELECT `pets`.`owner_id` FROM `pets` WHERE `pets`.`owner_id` IS NOT NULL)", }, { name: "M2O/2type2/selector", step: NewStep( From("pets", "id"), To("users", "id"), Edge(M2O, true, "pets", "owner_id"), ), selector: sql.Select("*").From(sql.Select("*").From(sql.Table("pets")).As("pets")).As("pets"), wantQuery: "SELECT * FROM (SELECT * FROM `pets`) AS `pets` WHERE `pets`.`owner_id` IS NOT NULL", }, { name: "M2M/2types/selector", step: NewStep( From("users", "id"), To("groups", "id"), Edge(M2M, false, "user_groups", "user_id", "group_id"), ), selector: sql.Select("*").From(sql.Select("*").From(sql.Table("users")).As("users")).As("users"), wantQuery: "SELECT * FROM (SELECT * FROM `users`) AS `users` WHERE `users`.`id` IN (SELECT `user_groups`.`user_id` FROM `user_groups`)", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { for _, s := range []*sql.Selector{tt.selector, tt.selector.Clone()} { HasNeighbors(s, tt.step) query, args := s.Query() require.Equal(t, tt.wantQuery, query) require.Empty(t, args) } }) } } func TestHasNeighborsWith(t *testing.T) { tests := []struct { name string step *Step selector *sql.Selector predicate func(*sql.Selector) wantQuery string wantArgs []any }{ { name: "O2O", step: NewStep( From("users", "id"), To("cards", "id"), Edge(O2O, false, "cards", "owner_id"), ), selector: sql.Dialect("postgres").Select("*").From(sql.Table("users")), predicate: func(s *sql.Selector) { s.Where(sql.EQ("expired", false)) }, wantQuery: `SELECT * FROM "users" 
WHERE "users"."id" IN (SELECT "cards"."owner_id" FROM "cards" WHERE NOT "expired")`, }, { name: "O2O/inverse", step: NewStep( From("cards", "id"), To("users", "id"), Edge(O2O, true, "cards", "owner_id"), ), selector: sql.Dialect("postgres").Select("*").From(sql.Table("cards")), predicate: func(s *sql.Selector) { s.Where(sql.EQ("name", "a8m")) }, wantQuery: `SELECT * FROM "cards" WHERE "cards"."owner_id" IN (SELECT "users"."id" FROM "users" WHERE "name" = $1)`, wantArgs: []any{"a8m"}, }, { name: "O2M", step: NewStep( From("users", "id"), To("pets", "id"), Edge(O2M, false, "pets", "owner_id"), ), selector: sql.Dialect("postgres").Select("*"). From(sql.Table("users")). Where(sql.EQ("last_name", "mashraki")), predicate: func(s *sql.Selector) { s.Where(sql.EQ("name", "pedro")) }, wantQuery: `SELECT * FROM "users" WHERE "last_name" = $1 AND "users"."id" IN (SELECT "pets"."owner_id" FROM "pets" WHERE "name" = $2)`, wantArgs: []any{"mashraki", "pedro"}, }, { name: "M2O", step: NewStep( From("pets", "id"), To("users", "id"), Edge(M2O, true, "pets", "owner_id"), ), selector: sql.Dialect("postgres").Select("*"). From(sql.Table("pets")). Where(sql.EQ("name", "pedro")), predicate: func(s *sql.Selector) { s.Where(sql.EQ("last_name", "mashraki")) }, wantQuery: `SELECT * FROM "pets" WHERE "name" = $1 AND "pets"."owner_id" IN (SELECT "users"."id" FROM "users" WHERE "last_name" = $2)`, wantArgs: []any{"pedro", "mashraki"}, }, { name: "M2M", step: NewStep( From("users", "id"), To("groups", "id"), Edge(M2M, false, "user_groups", "user_id", "group_id"), ), selector: sql.Dialect("postgres").Select("*").From(sql.Table("users")), predicate: func(s *sql.Selector) { s.Where(sql.EQ("name", "GitHub")) }, wantQuery: ` SELECT * FROM "users" WHERE "users"."id" IN (SELECT "user_groups"."user_id" FROM "user_groups" JOIN "groups" AS "t1" ON "user_groups"."group_id" = "t1"."id" WHERE "name" = $1)`, wantArgs: []any{"GitHub"}, }, { name: "M2M/inverse", step: NewStep( From("groups", "id"), To("users", "id"), Edge(M2M, true, "user_groups", "user_id", "group_id"), ), selector: sql.Dialect("postgres").Select("*").From(sql.Table("groups")), predicate: func(s *sql.Selector) { s.Where(sql.EQ("name", "a8m")) }, wantQuery: ` SELECT * FROM "groups" WHERE "groups"."id" IN (SELECT "user_groups"."group_id" FROM "user_groups" JOIN "users" AS "t1" ON "user_groups"."user_id" = "t1"."id" WHERE "name" = $1)`, wantArgs: []any{"a8m"}, }, { name: "M2M/inverse", step: NewStep( From("groups", "id"), To("users", "id"), Edge(M2M, true, "user_groups", "user_id", "group_id"), ), selector: sql.Dialect("postgres").Select("*").From(sql.Table("groups")), predicate: func(s *sql.Selector) { s.Where(sql.And(sql.NotNull("name"), sql.EQ("name", "a8m"))) }, wantQuery: ` SELECT * FROM "groups" WHERE "groups"."id" IN (SELECT "user_groups"."group_id" FROM "user_groups" JOIN "users" AS "t1" ON "user_groups"."user_id" = "t1"."id" WHERE "name" IS NOT NULL AND "name" = $1)`, wantArgs: []any{"a8m"}, }, { name: "schema/O2O", step: func() *Step { step := NewStep( From("users", "id"), To("cards", "id"), Edge(O2O, false, "cards", "owner_id"), ) step.Edge.Schema = "s2" return step }(), selector: sql.Dialect("postgres").Select("*").From(sql.Table("users").Schema("s1")), predicate: func(s *sql.Selector) { s.Where(sql.EQ("expired", false)) }, wantQuery: `SELECT * FROM "s1"."users" WHERE "s1"."users"."id" IN (SELECT "s2"."cards"."owner_id" FROM "s2"."cards" WHERE NOT "expired")`, }, { name: "schema/O2M", step: func() *Step { step := NewStep( From("users", "id"), To("pets", 
"id"), Edge(O2M, false, "pets", "owner_id"), ) step.Edge.Schema = "s2" return step }(), selector: sql.Dialect("postgres").Select("*"). From(sql.Table("users").Schema("s1")). Where(sql.EQ("last_name", "mashraki")), predicate: func(s *sql.Selector) { s.Where(sql.EQ("name", "pedro")) }, wantQuery: `SELECT * FROM "s1"."users" WHERE "last_name" = $1 AND "s1"."users"."id" IN (SELECT "s2"."pets"."owner_id" FROM "s2"."pets" WHERE "name" = $2)`, wantArgs: []any{"mashraki", "pedro"}, }, { name: "schema/M2M", step: func() *Step { step := NewStep( From("users", "id"), To("groups", "id"), Edge(M2M, false, "user_groups", "user_id", "group_id"), ) step.To.Schema = "s3" step.Edge.Schema = "s2" return step }(), selector: sql.Dialect("postgres").Select("*").From(sql.Table("users").Schema("s1")), predicate: func(s *sql.Selector) { s.Where(sql.EQ("name", "GitHub")) }, wantQuery: ` SELECT * FROM "s1"."users" WHERE "s1"."users"."id" IN (SELECT "s2"."user_groups"."user_id" FROM "s2"."user_groups" JOIN "s3"."groups" AS "t1" ON "s2"."user_groups"."group_id" = "t1"."id" WHERE "name" = $1)`, wantArgs: []any{"GitHub"}, }, { name: "O2M/selector", step: NewStep( From("users", "id"), To("pets", "id"), Edge(O2M, false, "pets", "owner_id"), ), selector: sql.Dialect("postgres").Select("*"). From(sql.Select("*").From(sql.Table("users")).As("users")). Where(sql.EQ("last_name", "mashraki")).As("users"), predicate: func(s *sql.Selector) { s.Where(sql.EQ("name", "pedro")) }, wantQuery: `SELECT * FROM (SELECT * FROM "users") AS "users" WHERE "last_name" = $1 AND "users"."id" IN (SELECT "pets"."owner_id" FROM "pets" WHERE "name" = $2)`, wantArgs: []any{"mashraki", "pedro"}, }, { name: "M2O/selector", step: NewStep( From("pets", "id"), To("users", "id"), Edge(M2O, true, "pets", "owner_id"), ), selector: sql.Dialect("postgres").Select("*"). From(sql.Select("*").From(sql.Table("pets")).As("pets")). Where(sql.EQ("name", "pedro")).As("pets"), predicate: func(s *sql.Selector) { s.Where(sql.EQ("last_name", "mashraki")) }, wantQuery: `SELECT * FROM (SELECT * FROM "pets") AS "pets" WHERE "name" = $1 AND "pets"."owner_id" IN (SELECT "users"."id" FROM "users" WHERE "last_name" = $2)`, wantArgs: []any{"pedro", "mashraki"}, }, { name: "M2M/selector", step: NewStep( From("users", "id"), To("groups", "id"), Edge(M2M, false, "user_groups", "user_id", "group_id"), ), selector: sql.Dialect("postgres").Select("*").From(sql.Select("*").From(sql.Table("users")).As("users")).As("users"), predicate: func(s *sql.Selector) { s.Where(sql.EQ("name", "GitHub")) }, wantQuery: `SELECT * FROM (SELECT * FROM "users") AS "users" WHERE "users"."id" IN (SELECT "user_groups"."user_id" FROM "user_groups" JOIN "groups" AS "t1" ON "user_groups"."group_id" = "t1"."id" WHERE "name" = $1)`, wantArgs: []any{"GitHub"}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { for _, s := range []*sql.Selector{tt.selector, tt.selector.Clone()} { HasNeighborsWith(s, tt.step, tt.predicate) query, args := s.Query() tt.wantQuery = strings.Join(strings.Fields(tt.wantQuery), " ") require.Equal(t, tt.wantQuery, query) require.Equal(t, tt.wantArgs, args) } }) } } func TestHasNeighborsWithContext(t *testing.T) { type key string ctx := context.WithValue(context.Background(), key("mykey"), "myval") for _, rel := range [...]Rel{M2M, O2M, O2O} { t.Run(rel.String(), func(t *testing.T) { sel := sql.Dialect(dialect.Postgres). Select("*"). From(sql.Table("users")). 
WithContext(ctx) step := NewStep( From("users", "id"), To("groups", "id"), Edge(rel, false, "user_groups", "user_id", "group_id"), ) var called bool pred := func(s *sql.Selector) { called = true got := s.Context().Value(key("mykey")).(string) require.Equal(t, "myval", got) } HasNeighborsWith(sel, step, pred) require.True(t, called, "expected predicate function to be called") }) } } func TestCreateNode(t *testing.T) { tests := []struct { name string spec *CreateSpec expect func(sqlmock.Sqlmock) wantErr bool }{ { name: "fields", spec: &CreateSpec{ Table: "users", ID: &FieldSpec{Column: "id", Type: field.TypeInt}, Fields: []*FieldSpec{ {Column: "age", Type: field.TypeInt, Value: 30}, {Column: "name", Type: field.TypeString, Value: "a8m"}, }, }, expect: func(m sqlmock.Sqlmock) { m.ExpectExec(escape("INSERT INTO `users` (`age`, `name`) VALUES (?, ?)")). WithArgs(30, "a8m"). WillReturnResult(sqlmock.NewResult(1, 1)) }, }, { name: "modifiers", spec: &CreateSpec{ Table: "users", ID: &FieldSpec{Column: "id", Type: field.TypeInt}, Fields: []*FieldSpec{ {Column: "age", Type: field.TypeInt, Value: 30}, {Column: "name", Type: field.TypeString, Value: "a8m"}, }, OnConflict: []sql.ConflictOption{ sql.ResolveWithNewValues(), }, }, expect: func(m sqlmock.Sqlmock) { m.ExpectExec(escape("INSERT INTO `users` (`age`, `name`) VALUES (?, ?) ON DUPLICATE KEY UPDATE `age` = VALUES(`age`), `name` = VALUES(`name`), `id` = LAST_INSERT_ID(`users`.`id`)")). WithArgs(30, "a8m"). WillReturnResult(sqlmock.NewResult(1, 1)) }, }, { name: "fields/user-defined-id", spec: &CreateSpec{ Table: "users", ID: &FieldSpec{Column: "id", Value: 1}, Fields: []*FieldSpec{ {Column: "age", Type: field.TypeInt, Value: 30}, {Column: "name", Type: field.TypeString, Value: "a8m"}, }, }, expect: func(m sqlmock.Sqlmock) { m.ExpectExec(escape("INSERT INTO `users` (`age`, `name`, `id`) VALUES (?, ?, ?)")). WithArgs(30, "a8m", 1). WillReturnResult(sqlmock.NewResult(1, 1)) }, }, { name: "fields/json", spec: &CreateSpec{ Table: "users", ID: &FieldSpec{Column: "id", Type: field.TypeInt}, Fields: []*FieldSpec{ {Column: "json", Type: field.TypeJSON, Value: struct{}{}}, }, }, expect: func(m sqlmock.Sqlmock) { m.ExpectExec(escape("INSERT INTO `users` (`json`) VALUES (?)")). WithArgs([]byte("{}")). WillReturnResult(sqlmock.NewResult(1, 1)) }, }, { name: "edges/m2o", spec: &CreateSpec{ Table: "pets", ID: &FieldSpec{Column: "id", Type: field.TypeInt}, Fields: []*FieldSpec{ {Column: "name", Type: field.TypeString, Value: "pedro"}, }, Edges: []*EdgeSpec{ {Rel: M2O, Columns: []string{"owner_id"}, Inverse: true, Target: &EdgeTarget{Nodes: []driver.Value{2}}}, }, }, expect: func(m sqlmock.Sqlmock) { m.ExpectExec(escape("INSERT INTO `pets` (`name`, `owner_id`) VALUES (?, ?)")). WithArgs("pedro", 2). WillReturnResult(sqlmock.NewResult(1, 1)) }, }, { name: "edges/o2o/inverse", spec: &CreateSpec{ Table: "cards", ID: &FieldSpec{Column: "id", Type: field.TypeInt}, Fields: []*FieldSpec{ {Column: "number", Type: field.TypeString, Value: "0001"}, }, Edges: []*EdgeSpec{ {Rel: O2O, Columns: []string{"owner_id"}, Inverse: true, Target: &EdgeTarget{Nodes: []driver.Value{2}}}, }, }, expect: func(m sqlmock.Sqlmock) { m.ExpectExec(escape("INSERT INTO `cards` (`number`, `owner_id`) VALUES (?, ?)")). WithArgs("0001", 2). 
WillReturnResult(sqlmock.NewResult(1, 1)) }, }, { name: "edges/o2m", spec: &CreateSpec{ Table: "users", ID: &FieldSpec{Column: "id", Type: field.TypeInt}, Fields: []*FieldSpec{ {Column: "name", Type: field.TypeString, Value: "a8m"}, }, Edges: []*EdgeSpec{ {Rel: O2M, Table: "pets", Columns: []string{"owner_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2}, IDSpec: &FieldSpec{Column: "id"}}}, }, }, expect: func(m sqlmock.Sqlmock) { m.ExpectBegin() m.ExpectExec(escape("INSERT INTO `users` (`name`) VALUES (?)")). WithArgs("a8m"). WillReturnResult(sqlmock.NewResult(1, 1)) m.ExpectExec(escape("UPDATE `pets` SET `owner_id` = ? WHERE `id` = ? AND `owner_id` IS NULL")). WithArgs(1, 2). WillReturnResult(sqlmock.NewResult(1, 1)) m.ExpectCommit() }, }, { name: "edges/o2m", spec: &CreateSpec{ Table: "users", ID: &FieldSpec{Column: "id", Type: field.TypeInt}, Fields: []*FieldSpec{ {Column: "name", Type: field.TypeString, Value: "a8m"}, }, Edges: []*EdgeSpec{ {Rel: O2M, Table: "pets", Columns: []string{"owner_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2, 3, 4}, IDSpec: &FieldSpec{Column: "id"}}}, }, }, expect: func(m sqlmock.Sqlmock) { m.ExpectBegin() m.ExpectExec(escape("INSERT INTO `users` (`name`) VALUES (?)")). WithArgs("a8m"). WillReturnResult(sqlmock.NewResult(1, 1)) m.ExpectExec(escape("UPDATE `pets` SET `owner_id` = ? WHERE `id` IN (?, ?, ?) AND `owner_id` IS NULL")). WithArgs(1, 2, 3, 4). WillReturnResult(sqlmock.NewResult(1, 3)) m.ExpectCommit() }, }, { name: "edges/o2o", spec: &CreateSpec{ Table: "users", ID: &FieldSpec{Column: "id", Type: field.TypeInt}, Fields: []*FieldSpec{ {Column: "name", Type: field.TypeString, Value: "a8m"}, }, Edges: []*EdgeSpec{ {Rel: O2O, Table: "cards", Columns: []string{"owner_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2}, IDSpec: &FieldSpec{Column: "id"}}}, }, }, expect: func(m sqlmock.Sqlmock) { m.ExpectBegin() m.ExpectExec(escape("INSERT INTO `users` (`name`) VALUES (?)")). WithArgs("a8m"). WillReturnResult(sqlmock.NewResult(1, 1)) m.ExpectExec(escape("UPDATE `cards` SET `owner_id` = ? WHERE `id` = ? AND `owner_id` IS NULL")). WithArgs(1, 2). WillReturnResult(sqlmock.NewResult(1, 1)) m.ExpectCommit() }, }, { name: "edges/o2o/bidi", spec: &CreateSpec{ Table: "users", ID: &FieldSpec{Column: "id", Type: field.TypeInt}, Fields: []*FieldSpec{ {Column: "name", Type: field.TypeString, Value: "a8m"}, }, Edges: []*EdgeSpec{ {Rel: O2O, Bidi: true, Table: "users", Columns: []string{"spouse_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2}, IDSpec: &FieldSpec{Column: "id"}}}, }, }, expect: func(m sqlmock.Sqlmock) { m.ExpectBegin() m.ExpectExec(escape("INSERT INTO `users` (`name`, `spouse_id`) VALUES (?, ?)")). WithArgs("a8m", 2). WillReturnResult(sqlmock.NewResult(1, 1)) m.ExpectExec(escape("UPDATE `users` SET `spouse_id` = ? WHERE `id` = ? AND `spouse_id` IS NULL")). WithArgs(1, 2). WillReturnResult(sqlmock.NewResult(1, 1)) m.ExpectCommit() }, }, { name: "edges/m2m", spec: &CreateSpec{ Table: "groups", ID: &FieldSpec{Column: "id", Type: field.TypeInt}, Fields: []*FieldSpec{ {Column: "name", Type: field.TypeString, Value: "GitHub"}, }, Edges: []*EdgeSpec{ {Rel: M2M, Table: "group_users", Columns: []string{"group_id", "user_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2}, IDSpec: &FieldSpec{Column: "id"}}}, }, }, expect: func(m sqlmock.Sqlmock) { m.ExpectBegin() m.ExpectExec(escape("INSERT INTO `groups` (`name`) VALUES (?)")). WithArgs("GitHub"). 
WillReturnResult(sqlmock.NewResult(1, 1)) m.ExpectExec(escape("INSERT INTO `group_users` (`group_id`, `user_id`) VALUES (?, ?)")). WithArgs(1, 2). WillReturnResult(sqlmock.NewResult(1, 1)) m.ExpectCommit() }, }, { name: "edges/m2m/fields", spec: &CreateSpec{ Table: "groups", ID: &FieldSpec{Column: "id", Type: field.TypeInt}, Fields: []*FieldSpec{ {Column: "name", Type: field.TypeString, Value: "GitHub"}, }, Edges: []*EdgeSpec{ {Rel: M2M, Table: "group_users", Columns: []string{"group_id", "user_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2}, IDSpec: &FieldSpec{Column: "id"}, Fields: []*FieldSpec{{Column: "ts", Type: field.TypeInt, Value: 3}}}}, }, }, expect: func(m sqlmock.Sqlmock) { m.ExpectBegin() m.ExpectExec(escape("INSERT INTO `groups` (`name`) VALUES (?)")). WithArgs("GitHub"). WillReturnResult(sqlmock.NewResult(1, 1)) m.ExpectExec(escape("INSERT INTO `group_users` (`group_id`, `user_id`, `ts`) VALUES (?, ?, ?)")). WithArgs(1, 2, 3). WillReturnResult(sqlmock.NewResult(1, 1)) m.ExpectCommit() }, }, { name: "edges/m2m/inverse", spec: &CreateSpec{ Table: "users", ID: &FieldSpec{Column: "id", Type: field.TypeInt}, Fields: []*FieldSpec{ {Column: "name", Type: field.TypeString, Value: "mashraki"}, }, Edges: []*EdgeSpec{ {Rel: M2M, Inverse: true, Table: "group_users", Columns: []string{"group_id", "user_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2}, IDSpec: &FieldSpec{Column: "id"}}}, }, }, expect: func(m sqlmock.Sqlmock) { m.ExpectBegin() m.ExpectExec(escape("INSERT INTO `users` (`name`) VALUES (?)")). WithArgs("mashraki"). WillReturnResult(sqlmock.NewResult(1, 1)) m.ExpectExec(escape("INSERT INTO `group_users` (`group_id`, `user_id`) VALUES (?, ?)")). WithArgs(2, 1). WillReturnResult(sqlmock.NewResult(1, 1)) m.ExpectCommit() }, }, { name: "edges/m2m/bidi", spec: &CreateSpec{ Table: "users", ID: &FieldSpec{Column: "id", Type: field.TypeInt}, Fields: []*FieldSpec{ {Column: "name", Type: field.TypeString, Value: "mashraki"}, }, Edges: []*EdgeSpec{ {Rel: M2M, Bidi: true, Table: "user_friends", Columns: []string{"user_id", "friend_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2}, IDSpec: &FieldSpec{Column: "id"}}}, }, }, expect: func(m sqlmock.Sqlmock) { m.ExpectBegin() m.ExpectExec(escape("INSERT INTO `users` (`name`) VALUES (?)")). WithArgs("mashraki"). WillReturnResult(sqlmock.NewResult(1, 1)) m.ExpectExec(escape("INSERT INTO `user_friends` (`user_id`, `friend_id`) VALUES (?, ?), (?, ?)")). WithArgs(1, 2, 2, 1). WillReturnResult(sqlmock.NewResult(1, 1)) m.ExpectCommit() }, }, { name: "edges/m2m/bidi/fields", spec: &CreateSpec{ Table: "users", ID: &FieldSpec{Column: "id", Type: field.TypeInt}, Fields: []*FieldSpec{ {Column: "name", Type: field.TypeString, Value: "mashraki"}, }, Edges: []*EdgeSpec{ {Rel: M2M, Bidi: true, Table: "user_friends", Columns: []string{"user_id", "friend_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2}, IDSpec: &FieldSpec{Column: "id"}, Fields: []*FieldSpec{{Column: "ts", Type: field.TypeInt, Value: 3}}}}, }, }, expect: func(m sqlmock.Sqlmock) { m.ExpectBegin() m.ExpectExec(escape("INSERT INTO `users` (`name`) VALUES (?)")). WithArgs("mashraki"). WillReturnResult(sqlmock.NewResult(1, 1)) m.ExpectExec(escape("INSERT INTO `user_friends` (`user_id`, `friend_id`, `ts`) VALUES (?, ?, ?), (?, ?, ?)")). WithArgs(1, 2, 3, 2, 1, 3). 
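/*
   Note (illustrative, not upstream prose): for a bidirectional M2M edge the graph
   layer writes the join row in both orientations, and any edge-schema fields are
   duplicated per row. A single edge (1 -> 2) with the property ts=3 therefore
   produces exactly what the "edges/m2m/bidi/fields" case asserts:

       INSERT INTO `user_friends` (`user_id`, `friend_id`, `ts`)
       VALUES (?, ?, ?), (?, ?, ?)   -- args: 1, 2, 3, 2, 1, 3
*/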
WillReturnResult(sqlmock.NewResult(1, 1)) m.ExpectCommit() }, }, { name: "edges/m2m/bidi/batch", spec: &CreateSpec{ Table: "users", ID: &FieldSpec{Column: "id", Type: field.TypeInt}, Fields: []*FieldSpec{ {Column: "name", Type: field.TypeString, Value: "mashraki"}, }, Edges: []*EdgeSpec{ {Rel: M2M, Bidi: true, Table: "user_friends", Columns: []string{"user_id", "friend_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2}, IDSpec: &FieldSpec{Column: "id"}}}, {Rel: M2M, Bidi: true, Table: "user_friends", Columns: []string{"user_id", "friend_id"}, Target: &EdgeTarget{Nodes: []driver.Value{3}, IDSpec: &FieldSpec{Column: "id"}}}, {Rel: M2M, Inverse: true, Table: "group_users", Columns: []string{"group_id", "user_id"}, Target: &EdgeTarget{Nodes: []driver.Value{4}, IDSpec: &FieldSpec{Column: "id"}}}, {Rel: M2M, Inverse: true, Table: "group_users", Columns: []string{"group_id", "user_id"}, Target: &EdgeTarget{Nodes: []driver.Value{5}, IDSpec: &FieldSpec{Column: "id"}}}, }, }, expect: func(m sqlmock.Sqlmock) { m.ExpectBegin() m.ExpectExec(escape("INSERT INTO `users` (`name`) VALUES (?)")). WithArgs("mashraki"). WillReturnResult(sqlmock.NewResult(1, 1)) m.ExpectExec(escape("INSERT INTO `group_users` (`group_id`, `user_id`) VALUES (?, ?), (?, ?)")). WithArgs(4, 1, 5, 1). WillReturnResult(sqlmock.NewResult(1, 1)) m.ExpectExec(escape("INSERT INTO `user_friends` (`user_id`, `friend_id`) VALUES (?, ?), (?, ?), (?, ?), (?, ?)")). WithArgs(1, 2, 2, 1, 1, 3, 3, 1). WillReturnResult(sqlmock.NewResult(1, 1)) m.ExpectCommit() }, }, { name: "schema", spec: &CreateSpec{ Table: "users", Schema: "mydb", ID: &FieldSpec{Column: "id", Type: field.TypeInt}, Fields: []*FieldSpec{ {Column: "age", Type: field.TypeInt, Value: 30}, {Column: "name", Type: field.TypeString, Value: "a8m"}, }, }, expect: func(m sqlmock.Sqlmock) { m.ExpectExec(escape("INSERT INTO `mydb`.`users` (`age`, `name`) VALUES (?, ?)")). WithArgs(30, "a8m"). WillReturnResult(sqlmock.NewResult(1, 1)) }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { db, mock, err := sqlmock.New() require.NoError(t, err) tt.expect(mock) err = CreateNode(context.Background(), sql.OpenDB(dialect.MySQL, db), tt.spec) require.Equal(t, tt.wantErr, err != nil, err) }) } } func TestBatchCreate(t *testing.T) { tests := []struct { name string spec *BatchCreateSpec expect func(sqlmock.Sqlmock) wantErr bool }{ { name: "empty", spec: &BatchCreateSpec{}, expect: func(m sqlmock.Sqlmock) { m.ExpectBegin() m.ExpectCommit() }, }, { name: "fields with modifiers", spec: &BatchCreateSpec{ Nodes: []*CreateSpec{ { Table: "users", ID: &FieldSpec{Column: "id", Type: field.TypeInt}, Fields: []*FieldSpec{ {Column: "age", Type: field.TypeInt, Value: 32}, {Column: "name", Type: field.TypeString, Value: "a8m"}, {Column: "active", Type: field.TypeBool, Value: false}, }, }, { Table: "users", ID: &FieldSpec{Column: "id", Type: field.TypeInt}, Fields: []*FieldSpec{ {Column: "age", Type: field.TypeInt, Value: 30}, {Column: "name", Type: field.TypeString, Value: "nati"}, {Column: "active", Type: field.TypeBool, Value: true}, }, }, }, OnConflict: []sql.ConflictOption{ sql.ResolveWithIgnore(), }, }, expect: func(m sqlmock.Sqlmock) { m.ExpectExec(escape("INSERT INTO `users` (`active`, `age`, `name`) VALUES (?, ?, ?), (?, ?, ?) ON DUPLICATE KEY UPDATE `active` = `users`.`active`, `age` = `users`.`age`, `name` = `users`.`name`")). WithArgs(false, 32, "a8m", true, 30, "nati"). 
WillReturnResult(sqlmock.NewResult(10, 2)) }, }, { name: "no tx", spec: &BatchCreateSpec{ Nodes: []*CreateSpec{ { Table: "users", ID: &FieldSpec{Column: "id", Type: field.TypeInt}, Fields: []*FieldSpec{ {Column: "age", Type: field.TypeInt, Value: 32}, {Column: "name", Type: field.TypeString, Value: "a8m"}, {Column: "active", Type: field.TypeBool, Value: false}, }, Edges: []*EdgeSpec{ {Rel: M2O, Table: "company", Columns: []string{"workplace_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2}}}, {Rel: O2O, Inverse: true, Table: "users", Columns: []string{"best_friend_id"}, Target: &EdgeTarget{Nodes: []driver.Value{3}, IDSpec: &FieldSpec{Column: "id"}}}, }, }, { Table: "users", ID: &FieldSpec{Column: "id", Type: field.TypeInt}, Fields: []*FieldSpec{ {Column: "age", Type: field.TypeInt, Value: 30}, {Column: "name", Type: field.TypeString, Value: "nati"}, }, Edges: []*EdgeSpec{ {Rel: M2O, Table: "company", Columns: []string{"workplace_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2}}}, {Rel: O2O, Inverse: true, Table: "users", Columns: []string{"best_friend_id"}, Target: &EdgeTarget{Nodes: []driver.Value{4}, IDSpec: &FieldSpec{Column: "id"}}}, }, }, }, }, expect: func(m sqlmock.Sqlmock) { // Insert nodes with FKs. m.ExpectExec(escape("INSERT INTO `users` (`active`, `age`, `best_friend_id`, `name`, `workplace_id`) VALUES (?, ?, ?, ?, ?), (NULL, ?, ?, ?, ?)")). WithArgs(false, 32, 3, "a8m", 2, 30, 4, "nati", 2). WillReturnResult(sqlmock.NewResult(10, 2)) }, }, { name: "with tx", spec: &BatchCreateSpec{ Nodes: []*CreateSpec{ { Table: "users", ID: &FieldSpec{Column: "id", Type: field.TypeInt}, Fields: []*FieldSpec{ {Column: "name", Type: field.TypeString, Value: "a8m"}, }, Edges: []*EdgeSpec{ {Rel: O2O, Table: "cards", Columns: []string{"owner_id"}, Target: &EdgeTarget{Nodes: []driver.Value{3}, IDSpec: &FieldSpec{Column: "id"}}}, }, }, { Table: "users", ID: &FieldSpec{Column: "id", Type: field.TypeInt}, Fields: []*FieldSpec{ {Column: "name", Type: field.TypeString, Value: "nati"}, }, Edges: []*EdgeSpec{ {Rel: O2O, Table: "cards", Columns: []string{"owner_id"}, Target: &EdgeTarget{Nodes: []driver.Value{4}, IDSpec: &FieldSpec{Column: "id"}}}, }, }, }, }, expect: func(m sqlmock.Sqlmock) { m.ExpectBegin() m.ExpectExec(escape("INSERT INTO `users` (`name`) VALUES (?), (?)")). WithArgs("a8m", "nati"). WillReturnResult(sqlmock.NewResult(10, 2)) m.ExpectExec(escape("UPDATE `cards` SET `owner_id` = ? WHERE `id` = ? AND `owner_id` IS NULL")). WithArgs(10 /* LAST_INSERT_ID() */, 3). WillReturnResult(sqlmock.NewResult(1, 1)) m.ExpectExec(escape("UPDATE `cards` SET `owner_id` = ? WHERE `id` = ? AND `owner_id` IS NULL")). WithArgs(11 /* LAST_INSERT_ID() + 1 */, 4). 
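/*
   Sketch of the id math the "with tx" case exercises (an assumption spelled out
   here, not a quote of the implementation): on MySQL a multi-row INSERT reports
   only the first auto-increment id, and the batch layer derives the rest by
   offset, assuming contiguous allocation:

       res, _ := tx.ExecContext(ctx,
           "INSERT INTO `users` (`name`) VALUES (?), (?)", "a8m", "nati")
       first, _ := res.LastInsertId()   // 10 in the mock above
       ids := []int64{first, first + 1} // one id per inserted row, in order

   Each O2O edge is then backfilled with its node's derived id (10 and 11 above).
*/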
WillReturnResult(sqlmock.NewResult(1, 1)) m.ExpectCommit() }, }, { name: "multiple", spec: &BatchCreateSpec{ Nodes: []*CreateSpec{ { Table: "users", ID: &FieldSpec{Column: "id", Type: field.TypeInt}, Fields: []*FieldSpec{ {Column: "age", Type: field.TypeInt, Value: 32}, {Column: "name", Type: field.TypeString, Value: "a8m"}, {Column: "active", Type: field.TypeBool, Value: false}, }, Edges: []*EdgeSpec{ {Rel: M2M, Inverse: true, Table: "group_users", Columns: []string{"group_id", "user_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2}, IDSpec: &FieldSpec{Column: "id"}}}, {Rel: M2M, Table: "user_products", Columns: []string{"user_id", "product_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2}, IDSpec: &FieldSpec{Column: "id"}}}, {Rel: M2M, Table: "user_friends", Bidi: true, Columns: []string{"user_id", "friend_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id", Type: field.TypeInt}, Nodes: []driver.Value{2}}}, {Rel: M2O, Table: "company", Columns: []string{"workplace_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2}}}, {Rel: O2M, Table: "pets", Columns: []string{"owner_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2}, IDSpec: &FieldSpec{Column: "id"}}}, }, }, { Table: "users", ID: &FieldSpec{Column: "id", Type: field.TypeInt}, Fields: []*FieldSpec{ {Column: "age", Type: field.TypeInt, Value: 30}, {Column: "name", Type: field.TypeString, Value: "nati"}, }, Edges: []*EdgeSpec{ {Rel: M2M, Inverse: true, Table: "group_users", Columns: []string{"group_id", "user_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2}, IDSpec: &FieldSpec{Column: "id"}}}, {Rel: M2M, Table: "user_products", Columns: []string{"user_id", "product_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2}, IDSpec: &FieldSpec{Column: "id"}}}, {Rel: M2M, Table: "user_friends", Bidi: true, Columns: []string{"user_id", "friend_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id", Type: field.TypeInt}, Nodes: []driver.Value{2}}}, {Rel: O2M, Table: "pets", Columns: []string{"owner_id"}, Target: &EdgeTarget{Nodes: []driver.Value{3}, IDSpec: &FieldSpec{Column: "id"}}}, }, }, }, }, expect: func(m sqlmock.Sqlmock) { m.ExpectBegin() // Insert nodes with FKs. m.ExpectExec(escape("INSERT INTO `users` (`active`, `age`, `name`, `workplace_id`) VALUES (?, ?, ?, ?), (NULL, ?, ?, NULL)")). WithArgs(false, 32, "a8m", 2, 30, "nati"). WillReturnResult(sqlmock.NewResult(10, 2)) // Insert M2M inverse-edges. m.ExpectExec(escape("INSERT INTO `group_users` (`group_id`, `user_id`) VALUES (?, ?), (?, ?)")). WithArgs(2, 10, 2, 11). WillReturnResult(sqlmock.NewResult(2, 2)) // Insert M2M bidirectional edges. m.ExpectExec(escape("INSERT INTO `user_friends` (`user_id`, `friend_id`) VALUES (?, ?), (?, ?), (?, ?), (?, ?)")). WithArgs(10, 2, 2, 10, 11, 2, 2, 11). WillReturnResult(sqlmock.NewResult(2, 2)) // Insert M2M edges. m.ExpectExec(escape("INSERT INTO `user_products` (`user_id`, `product_id`) VALUES (?, ?), (?, ?)")). WithArgs(10, 2, 11, 2). WillReturnResult(sqlmock.NewResult(2, 2)) // Update FKs exist in different tables. m.ExpectExec(escape("UPDATE `pets` SET `owner_id` = ? WHERE `id` = ? AND `owner_id` IS NULL")). WithArgs(10 /* id of the 1st new node */, 2 /* pet id */). WillReturnResult(sqlmock.NewResult(2, 2)) m.ExpectExec(escape("UPDATE `pets` SET `owner_id` = ? WHERE `id` = ? AND `owner_id` IS NULL")). WithArgs(11 /* id of the 2nd new node */, 3 /* pet id */). 
WillReturnResult(sqlmock.NewResult(2, 2)) m.ExpectCommit() }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { db, mock, err := sqlmock.New() require.NoError(t, err) tt.expect(mock) err = BatchCreate(context.Background(), sql.OpenDB("mysql", db), tt.spec) require.Equal(t, tt.wantErr, err != nil, err) }) } } type user struct { id int age int name string edges struct { fk1 int fk2 int } } func (*user) values(columns []string) ([]any, error) { values := make([]any, len(columns)) for i := range columns { switch c := columns[i]; c { case "id", "age", "fk1", "fk2": values[i] = &sql.NullInt64{} case "name": values[i] = &sql.NullString{} default: return nil, fmt.Errorf("unexpected column %q", c) } } return values, nil } func (u *user) assign(columns []string, values []any) error { if len(columns) != len(values) { return fmt.Errorf("mismatch number of values") } for i, c := range columns { switch c { case "id": u.id = int(values[i].(*sql.NullInt64).Int64) case "age": u.age = int(values[i].(*sql.NullInt64).Int64) case "name": u.name = values[i].(*sql.NullString).String case "fk1": u.edges.fk1 = int(values[i].(*sql.NullInt64).Int64) case "fk2": u.edges.fk2 = int(values[i].(*sql.NullInt64).Int64) default: return fmt.Errorf("unknown column %q", c) } } return nil } func TestUpdateNode(t *testing.T) { tests := []struct { name string spec *UpdateSpec prepare func(sqlmock.Sqlmock) wantErr bool wantUser *user }{ { name: "fields/set", spec: &UpdateSpec{ Node: &NodeSpec{ Table: "users", Columns: []string{"id", "name", "age"}, ID: &FieldSpec{Column: "id", Type: field.TypeInt, Value: 1}, }, Fields: FieldMut{ Set: []*FieldSpec{ {Column: "age", Type: field.TypeInt, Value: 30}, {Column: "name", Type: field.TypeString, Value: "Ariel"}, }, }, }, prepare: func(mock sqlmock.Sqlmock) { mock.ExpectBegin() mock.ExpectExec(escape("UPDATE `users` SET `age` = ?, `name` = ? WHERE `id` = ?")). WithArgs(30, "Ariel", 1). WillReturnResult(sqlmock.NewResult(1, 1)) mock.ExpectQuery(escape("SELECT `id`, `name`, `age` FROM `users` WHERE `id` = ?")). WithArgs(1). WillReturnRows(sqlmock.NewRows([]string{"id", "age", "name"}). AddRow(1, 30, "Ariel")) mock.ExpectCommit() }, wantUser: &user{name: "Ariel", age: 30, id: 1}, }, { name: "fields/set_modifier", spec: &UpdateSpec{ Node: &NodeSpec{ Table: "users", Columns: []string{"id", "name", "age"}, ID: &FieldSpec{Column: "id", Type: field.TypeInt, Value: 1}, }, Modifiers: []func(*sql.UpdateBuilder){ func(u *sql.UpdateBuilder) { u.Set("name", sql.Expr(sql.Lower("name"))) }, }, }, prepare: func(mock sqlmock.Sqlmock) { mock.ExpectBegin() mock.ExpectExec(escape("UPDATE `users` SET `name` = LOWER(`name`) WHERE `id` = ?")). WithArgs(1). WillReturnResult(sqlmock.NewResult(1, 1)) mock.ExpectQuery(escape("SELECT `id`, `name`, `age` FROM `users` WHERE `id` = ?")). WithArgs(1). WillReturnRows(sqlmock.NewRows([]string{"id", "age", "name"}). 
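/*
   Reader's note (illustrative): the user type above implements the two halves of
   the scanning contract that UpdateSpec and QuerySpec consume:

       spec.ScanValues = usr.values // allocate *sql.NullInt64/*sql.NullString per column
       spec.Assign     = usr.assign // copy the scanned nullable values into struct fields

   values must return one destination per requested column, and assign is invoked
   with the same column order, so the two sides must agree on the column set.
*/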
AddRow(1, 30, "Ariel")) mock.ExpectCommit() }, wantUser: &user{name: "Ariel", age: 30, id: 1}, }, { name: "fields/add_set_clear", spec: &UpdateSpec{ Node: &NodeSpec{ Table: "users", Columns: []string{"id", "name", "age"}, ID: &FieldSpec{Column: "id", Type: field.TypeInt, Value: 1}, }, Predicate: func(s *sql.Selector) { s.Where(sql.EQ("deleted", false)) }, Fields: FieldMut{ Add: []*FieldSpec{ {Column: "age", Type: field.TypeInt, Value: 1}, }, Set: []*FieldSpec{ {Column: "deleted", Type: field.TypeBool, Value: true}, }, Clear: []*FieldSpec{ {Column: "name", Type: field.TypeString}, }, }, }, prepare: func(mock sqlmock.Sqlmock) { mock.ExpectBegin() mock.ExpectExec(escape("UPDATE `users` SET `name` = NULL, `deleted` = ?, `age` = COALESCE(`users`.`age`, 0) + ? WHERE `id` = ? AND NOT `deleted`")). WithArgs(true, 1, 1). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectQuery(escape("SELECT `id`, `name`, `age` FROM `users` WHERE `id` = ?")). WithArgs(1). WillReturnRows(sqlmock.NewRows([]string{"id", "age", "name"}). AddRow(1, 31, nil)) mock.ExpectCommit() }, wantUser: &user{age: 31, id: 1}, }, { name: "fields/ensure_exists", spec: &UpdateSpec{ Node: &NodeSpec{ Table: "users", Columns: []string{"id", "name", "age"}, ID: &FieldSpec{Column: "id", Type: field.TypeInt, Value: 1}, }, Predicate: func(s *sql.Selector) { s.Where(sql.EQ("deleted", false)) }, Fields: FieldMut{ Add: []*FieldSpec{ {Column: "age", Type: field.TypeInt, Value: 1}, }, Set: []*FieldSpec{ {Column: "deleted", Type: field.TypeBool, Value: true}, }, Clear: []*FieldSpec{ {Column: "name", Type: field.TypeString}, }, }, }, prepare: func(mock sqlmock.Sqlmock) { mock.ExpectBegin() mock.ExpectExec(escape("UPDATE `users` SET `name` = NULL, `deleted` = ?, `age` = COALESCE(`users`.`age`, 0) + ? WHERE `id` = ? AND NOT `deleted`")). WithArgs(true, 1, 1). WillReturnResult(sqlmock.NewResult(0, 0)) mock.ExpectQuery(escape("SELECT EXISTS (SELECT * FROM `users` WHERE `id` = ? AND NOT `deleted`)")). WithArgs(1). WillReturnRows(sqlmock.NewRows([]string{"exists"}). AddRow(false)) mock.ExpectRollback() }, wantErr: true, wantUser: &user{}, }, { name: "edges/o2o_non_inverse and m2o", spec: &UpdateSpec{ Node: &NodeSpec{ Table: "users", Columns: []string{"id", "name", "age"}, ID: &FieldSpec{Column: "id", Type: field.TypeInt, Value: 1}, }, Edges: EdgeMut{ Clear: []*EdgeSpec{ {Rel: O2O, Columns: []string{"car_id"}, Inverse: true}, {Rel: M2O, Columns: []string{"workplace_id"}, Inverse: true}, }, Add: []*EdgeSpec{ {Rel: O2O, Columns: []string{"card_id"}, Inverse: true, Target: &EdgeTarget{Nodes: []driver.Value{2}}}, {Rel: M2O, Columns: []string{"parent_id"}, Inverse: true, Target: &EdgeTarget{Nodes: []driver.Value{2}}}, }, }, }, prepare: func(mock sqlmock.Sqlmock) { mock.ExpectBegin() mock.ExpectExec(escape("UPDATE `users` SET `workplace_id` = NULL, `car_id` = NULL, `parent_id` = ?, `card_id` = ? WHERE `id` = ?")). WithArgs(2, 2, 1). WillReturnResult(sqlmock.NewResult(1, 1)) mock.ExpectQuery(escape("SELECT `id`, `name`, `age` FROM `users` WHERE `id` = ?")). WithArgs(1). WillReturnRows(sqlmock.NewRows([]string{"id", "age", "name"}). 
AddRow(1, 31, nil)) mock.ExpectCommit() }, wantUser: &user{age: 31, id: 1}, }, { name: "edges/o2o_bidi", spec: &UpdateSpec{ Node: &NodeSpec{ Table: "users", Columns: []string{"id", "name", "age"}, ID: &FieldSpec{Column: "id", Type: field.TypeInt, Value: 1}, }, Edges: EdgeMut{ Clear: []*EdgeSpec{ {Rel: O2O, Table: "users", Bidi: true, Columns: []string{"partner_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id"}}}, {Rel: O2O, Table: "users", Bidi: true, Columns: []string{"spouse_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id", Type: field.TypeInt}, Nodes: []driver.Value{2}}}, }, Add: []*EdgeSpec{ {Rel: O2O, Table: "users", Bidi: true, Columns: []string{"spouse_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id", Type: field.TypeInt}, Nodes: []driver.Value{3}}}, }, }, }, prepare: func(mock sqlmock.Sqlmock) { mock.ExpectBegin() // Clear the "partner" from 1's column, and set "spouse 3". // "spouse 2" is implicitly removed when setting a different foreign-key. mock.ExpectExec(escape("UPDATE `users` SET `partner_id` = NULL, `spouse_id` = ? WHERE `id` = ?")). WithArgs(3, 1). WillReturnResult(sqlmock.NewResult(1, 1)) // Clear the "partner_id" column from previous 1's partner. mock.ExpectExec(escape("UPDATE `users` SET `partner_id` = NULL WHERE `partner_id` = ?")). WithArgs(1). WillReturnResult(sqlmock.NewResult(1, 1)) // Clear "spouse 1" from 3's column. mock.ExpectExec(escape("UPDATE `users` SET `spouse_id` = NULL WHERE `id` = ? AND `spouse_id` = ?")). WithArgs(2, 1). WillReturnResult(sqlmock.NewResult(1, 1)) // Set 3's column to point "spouse 1". mock.ExpectExec(escape("UPDATE `users` SET `spouse_id` = ? WHERE `id` = ? AND `spouse_id` IS NULL")). WithArgs(1, 3). WillReturnResult(sqlmock.NewResult(1, 1)) mock.ExpectQuery(escape("SELECT `id`, `name`, `age` FROM `users` WHERE `id` = ?")). WithArgs(1). WillReturnRows(sqlmock.NewRows([]string{"id", "age", "name"}). AddRow(1, 31, nil)) mock.ExpectCommit() }, wantUser: &user{age: 31, id: 1}, }, { name: "edges/clear_add_m2m", spec: &UpdateSpec{ Node: &NodeSpec{ Table: "users", Columns: []string{"id", "name", "age"}, ID: &FieldSpec{Column: "id", Type: field.TypeInt, Value: 1}, }, Edges: EdgeMut{ Clear: []*EdgeSpec{ {Rel: M2M, Table: "user_friends", Bidi: true, Columns: []string{"user_id", "friend_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id", Type: field.TypeInt}, Nodes: []driver.Value{2}}}, {Rel: M2M, Inverse: true, Table: "group_users", Columns: []string{"group_id", "user_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id", Type: field.TypeInt}, Nodes: []driver.Value{3, 7}}}, // Clear all "following" edges (and their inverse). {Rel: M2M, Table: "user_following", Bidi: true, Columns: []string{"following_id", "follower_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id"}}}, // Clear all "user_blocked" edges. {Rel: M2M, Table: "user_blocked", Columns: []string{"user_id", "blocked_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id"}}}, // Clear all "comments" edges. 
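/*
   Aside (a sketch inferred from the expectations asserted below, not upstream
   prose): clearing edges compiles to DELETEs whose predicate shape depends on the
   edge kind. A bidi M2M clear must match both column orientations:

       DELETE FROM `user_friends`
       WHERE (`user_id` = ? AND `friend_id` = ?) OR (`friend_id` = ? AND `user_id` = ?)

   while an inverse M2M clear filters only on the node's own column plus the
   target ids, e.g. WHERE `user_id` = ? AND `group_id` IN (?, ?).
*/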
{Rel: M2M, Inverse: true, Table: "comment_responders", Columns: []string{"comment_id", "responder_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id"}}}, }, Add: []*EdgeSpec{ {Rel: M2M, Table: "user_friends", Bidi: true, Columns: []string{"user_id", "friend_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id", Type: field.TypeInt}, Nodes: []driver.Value{4}}}, {Rel: M2M, Inverse: true, Table: "group_users", Columns: []string{"group_id", "user_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id", Type: field.TypeInt}, Nodes: []driver.Value{5}}}, {Rel: M2M, Inverse: true, Table: "group_users", Columns: []string{"group_id", "user_id"}, Target: &EdgeTarget{IDSpec: &FieldSpec{Column: "id", Type: field.TypeInt}, Nodes: []driver.Value{6, 8}}}, }, }, }, prepare: func(mock sqlmock.Sqlmock) { mock.ExpectBegin() // Clear comment responders. mock.ExpectExec(escape("DELETE FROM `comment_responders` WHERE `responder_id` = ?")). WithArgs(1). WillReturnResult(sqlmock.NewResult(1, 1)) // Remove user groups. mock.ExpectExec(escape("DELETE FROM `group_users` WHERE `user_id` = ? AND `group_id` IN (?, ?)")). WithArgs(1, 3, 7). WillReturnResult(sqlmock.NewResult(1, 1)) // Clear all blocked users. mock.ExpectExec(escape("DELETE FROM `user_blocked` WHERE `user_id` = ?")). WithArgs(1). WillReturnResult(sqlmock.NewResult(1, 1)) // Clear all user following. mock.ExpectExec(escape("DELETE FROM `user_following` WHERE `following_id` = ? OR `follower_id` = ?")). WithArgs(1, 1). WillReturnResult(sqlmock.NewResult(1, 2)) // Clear user friends. mock.ExpectExec(escape("DELETE FROM `user_friends` WHERE (`user_id` = ? AND `friend_id` = ?) OR (`friend_id` = ? AND `user_id` = ?)")). WithArgs(1, 2, 1, 2). WillReturnResult(sqlmock.NewResult(1, 1)) // Add new groups. mock.ExpectExec(escape("INSERT INTO `group_users` (`group_id`, `user_id`) VALUES (?, ?), (?, ?), (?, ?)")). WithArgs(5, 1, 6, 1, 8, 1). WillReturnResult(sqlmock.NewResult(1, 1)) // Add new friends. mock.ExpectExec(escape("INSERT INTO `user_friends` (`user_id`, `friend_id`) VALUES (?, ?), (?, ?)")). WithArgs(1, 4, 4, 1). WillReturnResult(sqlmock.NewResult(1, 1)) mock.ExpectQuery(escape("SELECT `id`, `name`, `age` FROM `users` WHERE `id` = ?")). WithArgs(1). WillReturnRows(sqlmock.NewRows([]string{"id", "age", "name"}). AddRow(1, 31, nil)) mock.ExpectCommit() }, wantUser: &user{age: 31, id: 1}, }, { name: "schema/fields/set", spec: &UpdateSpec{ Node: &NodeSpec{ Table: "users", Schema: "mydb", Columns: []string{"id", "name", "age"}, ID: &FieldSpec{Column: "id", Type: field.TypeInt, Value: 1}, }, Fields: FieldMut{ Set: []*FieldSpec{ {Column: "age", Type: field.TypeInt, Value: 30}, {Column: "name", Type: field.TypeString, Value: "Ariel"}, }, }, }, prepare: func(mock sqlmock.Sqlmock) { mock.ExpectBegin() mock.ExpectExec(escape("UPDATE `mydb`.`users` SET `age` = ?, `name` = ? WHERE `id` = ?")). WithArgs(30, "Ariel", 1). WillReturnResult(sqlmock.NewResult(1, 1)) mock.ExpectQuery(escape("SELECT `id`, `name`, `age` FROM `mydb`.`users` WHERE `id` = ?")). WithArgs(1). WillReturnRows(sqlmock.NewRows([]string{"id", "age", "name"}). 
AddRow(1, 30, "Ariel")) mock.ExpectCommit() }, wantUser: &user{name: "Ariel", age: 30, id: 1}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { db, mock, err := sqlmock.New() require.NoError(t, err) tt.prepare(mock) usr := &user{} tt.spec.Assign = usr.assign tt.spec.ScanValues = usr.values err = UpdateNode(context.Background(), sql.OpenDB("", db), tt.spec) require.Equal(t, tt.wantErr, err != nil, err) require.Equal(t, tt.wantUser, usr) }) } } func TestExecUpdateNode(t *testing.T) { db, mock, err := sqlmock.New() require.NoError(t, err) mock.ExpectBegin() mock.ExpectExec(escape("UPDATE `users` SET `age` = ?, `name` = ? WHERE `id` = ?")). WithArgs(30, "Ariel", 1). WillReturnResult(sqlmock.NewResult(1, 1)) mock.ExpectCommit() err = UpdateNode(context.Background(), sql.OpenDB("", db), &UpdateSpec{ Node: &NodeSpec{ Table: "users", Columns: []string{"id", "name", "age"}, ID: &FieldSpec{Column: "id", Type: field.TypeInt, Value: 1}, }, Fields: FieldMut{ Set: []*FieldSpec{ {Column: "age", Type: field.TypeInt, Value: 30}, {Column: "name", Type: field.TypeString, Value: "Ariel"}, }, }, }) require.NoError(t, err) } func TestUpdateNodes(t *testing.T) { tests := []struct { name string spec *UpdateSpec prepare func(sqlmock.Sqlmock) wantErr bool wantAffected int }{ { name: "without predicate", spec: &UpdateSpec{ Node: &NodeSpec{ Table: "users", ID: &FieldSpec{Column: "id", Type: field.TypeInt}, }, Fields: FieldMut{ Set: []*FieldSpec{ {Column: "age", Type: field.TypeInt, Value: 30}, {Column: "name", Type: field.TypeString, Value: "Ariel"}, }, }, }, prepare: func(mock sqlmock.Sqlmock) { // Apply field changes. mock.ExpectExec(escape("UPDATE `users` SET `age` = ?, `name` = ?")). WithArgs(30, "Ariel"). WillReturnResult(sqlmock.NewResult(0, 2)) }, wantAffected: 2, }, { name: "with predicate", spec: &UpdateSpec{ Node: &NodeSpec{ Table: "users", ID: &FieldSpec{Column: "id", Type: field.TypeInt}, }, Fields: FieldMut{ Clear: []*FieldSpec{ {Column: "age", Type: field.TypeInt}, {Column: "name", Type: field.TypeString}, }, }, Predicate: func(s *sql.Selector) { s.Where(sql.EQ("name", "a8m")) }, }, prepare: func(mock sqlmock.Sqlmock) { // Clear fields. mock.ExpectExec(escape("UPDATE `users` SET `age` = NULL, `name` = NULL WHERE `name` = ?")). WithArgs("a8m"). WillReturnResult(sqlmock.NewResult(0, 1)) }, wantAffected: 1, }, { name: "with modifier", spec: &UpdateSpec{ Node: &NodeSpec{ Table: "users", ID: &FieldSpec{Column: "id", Type: field.TypeInt}, }, Modifiers: []func(*sql.UpdateBuilder){ func(u *sql.UpdateBuilder) { u.Set("id", sql.Expr("id + 1")).OrderBy("id") }, }, }, prepare: func(mock sqlmock.Sqlmock) { mock.ExpectExec(escape("UPDATE `users` SET `id` = id + 1 ORDER BY `id`")). WillReturnResult(sqlmock.NewResult(0, 1)) }, wantAffected: 1, }, { name: "own_fks/m2o_o2o_inverse", spec: &UpdateSpec{ Node: &NodeSpec{ Table: "users", ID: &FieldSpec{Column: "id", Type: field.TypeInt}, }, Edges: EdgeMut{ Clear: []*EdgeSpec{ {Rel: O2O, Columns: []string{"car_id"}, Inverse: true}, {Rel: M2O, Columns: []string{"workplace_id"}, Inverse: true}, }, Add: []*EdgeSpec{ {Rel: O2O, Columns: []string{"card_id"}, Inverse: true, Target: &EdgeTarget{Nodes: []driver.Value{3}}}, {Rel: M2O, Columns: []string{"parent_id"}, Inverse: true, Target: &EdgeTarget{Nodes: []driver.Value{4}}}, }, }, }, prepare: func(mock sqlmock.Sqlmock) { // Clear "car" and "workplace" foreign_keys and add "card" and a "parent". mock.ExpectExec(escape("UPDATE `users` SET `workplace_id` = NULL, `car_id` = NULL, `parent_id` = ?, `card_id` = ?")). 
WithArgs(4, 3). WillReturnResult(sqlmock.NewResult(0, 3)) }, wantAffected: 3, }, { name: "o2m", spec: &UpdateSpec{ Node: &NodeSpec{ Table: "users", ID: &FieldSpec{Column: "id", Type: field.TypeInt}, }, Fields: FieldMut{ Add: []*FieldSpec{ {Column: "version", Type: field.TypeInt, Value: 1}, }, }, Edges: EdgeMut{ Clear: []*EdgeSpec{ {Rel: O2M, Table: "cards", Columns: []string{"owner_id"}, Target: &EdgeTarget{Nodes: []driver.Value{20, 30}, IDSpec: &FieldSpec{Column: "id"}}}, }, Add: []*EdgeSpec{ {Rel: O2M, Table: "pets", Columns: []string{"owner_id"}, Target: &EdgeTarget{Nodes: []driver.Value{40}, IDSpec: &FieldSpec{Column: "id"}}}, }, }, }, prepare: func(mock sqlmock.Sqlmock) { mock.ExpectBegin() // Get all node ids first. mock.ExpectQuery(escape("SELECT `id` FROM `users`")). WillReturnRows(sqlmock.NewRows([]string{"id"}). AddRow(10)) mock.ExpectExec(escape("UPDATE `users` SET `version` = COALESCE(`users`.`version`, 0) + ? WHERE `id` = ?")). WithArgs(1, 10). WillReturnResult(sqlmock.NewResult(0, 1)) // Clear "owner_id" column in the "cards" table. mock.ExpectExec(escape("UPDATE `cards` SET `owner_id` = NULL WHERE `id` IN (?, ?) AND `owner_id` = ?")). WithArgs(20, 30, 10). WillReturnResult(sqlmock.NewResult(0, 2)) // Set "owner_id" column in the "pets" table. mock.ExpectExec(escape("UPDATE `pets` SET `owner_id` = ? WHERE `id` = ? AND `owner_id` IS NULL")). WithArgs(10, 40). WillReturnResult(sqlmock.NewResult(0, 2)) mock.ExpectCommit() }, wantAffected: 1, }, { name: "m2m_one", spec: &UpdateSpec{ Node: &NodeSpec{ Table: "users", ID: &FieldSpec{Column: "id", Type: field.TypeInt}, }, Edges: EdgeMut{ Clear: []*EdgeSpec{ {Rel: M2M, Table: "group_users", Columns: []string{"group_id", "user_id"}, Inverse: true, Target: &EdgeTarget{Nodes: []driver.Value{2, 3}}}, {Rel: M2M, Table: "user_followers", Columns: []string{"user_id", "follower_id"}, Bidi: true, Target: &EdgeTarget{Nodes: []driver.Value{5, 6}}}, {Rel: M2M, Table: "user_friends", Columns: []string{"user_id", "friend_id"}, Bidi: true, Target: &EdgeTarget{Nodes: []driver.Value{4}}}, }, Add: []*EdgeSpec{ {Rel: M2M, Table: "group_users", Columns: []string{"group_id", "user_id"}, Inverse: true, Target: &EdgeTarget{Nodes: []driver.Value{7, 8}}}, {Rel: M2M, Table: "user_followers", Columns: []string{"user_id", "follower_id"}, Bidi: true, Target: &EdgeTarget{Nodes: []driver.Value{9}}}, }, }, }, prepare: func(mock sqlmock.Sqlmock) { mock.ExpectBegin() // Get all node ids first. mock.ExpectQuery(escape("SELECT `id` FROM `users`")). WillReturnRows(sqlmock.NewRows([]string{"id"}). AddRow(1)) // Clear user's groups. mock.ExpectExec(escape("DELETE FROM `group_users` WHERE `user_id` = ? AND `group_id` IN (?, ?)")). WithArgs(1, 2, 3). WillReturnResult(sqlmock.NewResult(0, 2)) // Clear user's followers. mock.ExpectExec(escape("DELETE FROM `user_followers` WHERE (`user_id` = ? AND `follower_id` IN (?, ?)) OR (`follower_id` = ? AND `user_id` IN (?, ?))")). WithArgs(1, 5, 6, 1, 5, 6). WillReturnResult(sqlmock.NewResult(0, 2)) // Clear user's friends. mock.ExpectExec(escape("DELETE FROM `user_friends` WHERE (`user_id` = ? AND `friend_id` = ?) OR (`friend_id` = ? AND `user_id` = ?)")). WithArgs(1, 4, 1, 4). WillReturnResult(sqlmock.NewResult(0, 2)) // Attach new groups to user. mock.ExpectExec(escape("INSERT INTO `group_users` (`group_id`, `user_id`) VALUES (?, ?), (?, ?)")). WithArgs(7, 1, 8, 1). WillReturnResult(sqlmock.NewResult(0, 2)) // Attach new friends to user. 
mock.ExpectExec(escape("INSERT INTO `user_followers` (`user_id`, `follower_id`) VALUES (?, ?), (?, ?)")). WithArgs(1, 9, 9, 1). WillReturnResult(sqlmock.NewResult(0, 2)) mock.ExpectCommit() }, wantAffected: 1, }, { name: "m2m_many", spec: &UpdateSpec{ Node: &NodeSpec{ Table: "users", ID: &FieldSpec{Column: "id", Type: field.TypeInt}, }, Edges: EdgeMut{ Clear: []*EdgeSpec{ {Rel: M2M, Table: "group_users", Columns: []string{"group_id", "user_id"}, Inverse: true, Target: &EdgeTarget{Nodes: []driver.Value{2, 3}}}, {Rel: M2M, Table: "user_followers", Columns: []string{"user_id", "follower_id"}, Bidi: true, Target: &EdgeTarget{Nodes: []driver.Value{5, 6}}}, {Rel: M2M, Table: "user_friends", Columns: []string{"user_id", "friend_id"}, Bidi: true, Target: &EdgeTarget{Nodes: []driver.Value{4}}}, }, Add: []*EdgeSpec{ {Rel: M2M, Table: "group_users", Columns: []string{"group_id", "user_id"}, Inverse: true, Target: &EdgeTarget{Nodes: []driver.Value{7, 8}}}, {Rel: M2M, Table: "user_followers", Columns: []string{"user_id", "follower_id"}, Bidi: true, Target: &EdgeTarget{Nodes: []driver.Value{9}}}, }, }, }, prepare: func(mock sqlmock.Sqlmock) { mock.ExpectBegin() // Get all node ids first. mock.ExpectQuery(escape("SELECT `id` FROM `users`")). WillReturnRows(sqlmock.NewRows([]string{"id"}). AddRow(10). AddRow(20)) // Clear user's groups. mock.ExpectExec(escape("DELETE FROM `group_users` WHERE `user_id` IN (?, ?) AND `group_id` IN (?, ?)")). WithArgs(10, 20, 2, 3). WillReturnResult(sqlmock.NewResult(0, 2)) // Clear user's followers. mock.ExpectExec(escape("DELETE FROM `user_followers` WHERE (`user_id` IN (?, ?) AND `follower_id` IN (?, ?)) OR (`follower_id` IN (?, ?) AND `user_id` IN (?, ?))")). WithArgs(10, 20, 5, 6, 10, 20, 5, 6). WillReturnResult(sqlmock.NewResult(0, 2)) // Clear user's friends. mock.ExpectExec(escape("DELETE FROM `user_friends` WHERE (`user_id` IN (?, ?) AND `friend_id` = ?) OR (`friend_id` IN (?, ?) AND `user_id` = ?)")). WithArgs(10, 20, 4, 10, 20, 4). WillReturnResult(sqlmock.NewResult(0, 2)) // Attach new groups to user. mock.ExpectExec(escape("INSERT INTO `group_users` (`group_id`, `user_id`) VALUES (?, ?), (?, ?), (?, ?), (?, ?)")). WithArgs(7, 10, 7, 20, 8, 10, 8, 20). WillReturnResult(sqlmock.NewResult(0, 4)) // Attach new friends to user. mock.ExpectExec(escape("INSERT INTO `user_followers` (`user_id`, `follower_id`) VALUES (?, ?), (?, ?), (?, ?), (?, ?)")). WithArgs(10, 9, 9, 10, 20, 9, 9, 20). WillReturnResult(sqlmock.NewResult(0, 4)) mock.ExpectCommit() }, wantAffected: 2, }, { name: "m2m_edge_schema", spec: &UpdateSpec{ Node: &NodeSpec{ Table: "users", CompositeID: []*FieldSpec{{Column: "user_id", Type: field.TypeInt}, {Column: "group_id", Type: field.TypeInt}}, }, Predicate: func(s *sql.Selector) { s.Where(sql.EQ("version", 1)) }, Fields: FieldMut{ Add: []*FieldSpec{ {Column: "version", Type: field.TypeInt, Value: 1}, }, }, }, prepare: func(mock sqlmock.Sqlmock) { mock.ExpectExec(escape("UPDATE `users` SET `version` = COALESCE(`users`.`version`, 0) + ? WHERE `version` = ?")). WithArgs(1, 1). 
WillReturnResult(sqlmock.NewResult(0, 4)) }, wantAffected: 4, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { db, mock, err := sqlmock.New() require.NoError(t, err) tt.prepare(mock) affected, err := UpdateNodes(context.Background(), sql.OpenDB("", db), tt.spec) require.Equal(t, tt.wantErr, err != nil, err) require.Equal(t, tt.wantAffected, affected) }) } } func TestDeleteNodes(t *testing.T) { db, mock, err := sqlmock.New() require.NoError(t, err) mock.ExpectExec(escape("DELETE FROM `users`")). WillReturnResult(sqlmock.NewResult(0, 2)) affected, err := DeleteNodes(context.Background(), sql.OpenDB("", db), &DeleteSpec{ Node: &NodeSpec{ Table: "users", ID: &FieldSpec{Column: "id", Type: field.TypeInt}, }, }) require.NoError(t, err) require.Equal(t, 2, affected) } func TestDeleteNodesSchema(t *testing.T) { db, mock, err := sqlmock.New() require.NoError(t, err) mock.ExpectExec(escape("DELETE FROM `mydb`.`users`")). WillReturnResult(sqlmock.NewResult(0, 2)) affected, err := DeleteNodes(context.Background(), sql.OpenDB("", db), &DeleteSpec{ Node: &NodeSpec{ Table: "users", Schema: "mydb", ID: &FieldSpec{Column: "id", Type: field.TypeInt}, }, }) require.NoError(t, err) require.Equal(t, 2, affected) } func TestQueryNodes(t *testing.T) { db, mock, err := sqlmock.New() require.NoError(t, err) mock.ExpectQuery(escape("SELECT DISTINCT `users`.`id`, `users`.`age`, `users`.`name`, `users`.`fk1`, `users`.`fk2` FROM `users` WHERE `age` < ? ORDER BY `id` LIMIT 3 OFFSET 4 FOR UPDATE NOWAIT")). WithArgs(40). WillReturnRows(sqlmock.NewRows([]string{"id", "age", "name", "fk1", "fk2"}). AddRow(1, 10, nil, nil, nil). AddRow(2, 20, "", 0, 0). AddRow(3, 30, "a8m", 1, 1)) mock.ExpectQuery(escape("SELECT COUNT(DISTINCT `users`.`id`) FROM `users` WHERE `age` < ? ORDER BY `id` LIMIT 3 OFFSET 4 FOR UPDATE NOWAIT")). WithArgs(40). WillReturnRows(sqlmock.NewRows([]string{"COUNT"}). AddRow(3)) mock.ExpectQuery(escape("SELECT COUNT(DISTINCT `users`.`name`) FROM `users` WHERE `age` < ? ORDER BY `id` LIMIT 3 OFFSET 4 FOR UPDATE NOWAIT")). WithArgs(40). WillReturnRows(sqlmock.NewRows([]string{"COUNT"}). AddRow(3)) var ( users []*user spec = &QuerySpec{ Node: &NodeSpec{ Table: "users", Columns: []string{"id", "age", "name", "fk1", "fk2"}, ID: &FieldSpec{Column: "id", Type: field.TypeInt}, }, Limit: 3, Offset: 4, Unique: true, Order: func(s *sql.Selector) { s.OrderBy("id") }, Predicate: func(s *sql.Selector) { s.Where(sql.LT("age", 40)) }, Modifiers: []func(*sql.Selector){ func(s *sql.Selector) { s.ForUpdate(sql.WithLockAction(sql.NoWait)) }, }, ScanValues: func(columns []string) ([]any, error) { u := &user{} users = append(users, u) return u.values(columns) }, Assign: func(columns []string, values []any) error { return users[len(users)-1].assign(columns, values) }, } ) // Query and scan. err = QueryNodes(context.Background(), sql.OpenDB("", db), spec) require.NoError(t, err) require.Equal(t, &user{id: 1, age: 10, name: ""}, users[0]) require.Equal(t, &user{id: 2, age: 20, name: ""}, users[1]) require.Equal(t, &user{id: 3, age: 30, name: "a8m", edges: struct{ fk1, fk2 int }{1, 1}}, users[2]) // Count nodes. spec.Node.Columns = nil n, err := CountNodes(context.Background(), sql.OpenDB("", db), spec) require.NoError(t, err) require.Equal(t, 3, n) // Count nodes. 
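/*
   Reader's note (illustrative, inferred from the expectations above): CountNodes
   reuses the same QuerySpec but swaps the projection for a COUNT; with Unique set,
   the count is DISTINCT:

       spec.Node.Columns = nil              // COUNT(DISTINCT `users`.`id`)
       n, err := CountNodes(ctx, drv, spec)

   Setting a single column counts that column instead, as the next statement shows.
*/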
spec.Node.Columns = []string{"name"} n, err = CountNodes(context.Background(), sql.OpenDB("", db), spec) require.NoError(t, err) require.Equal(t, 3, n) } func TestQueryNodesSchema(t *testing.T) { db, mock, err := sqlmock.New() require.NoError(t, err) mock.ExpectQuery(escape("SELECT DISTINCT `mydb`.`users`.`id`, `mydb`.`users`.`age`, `mydb`.`users`.`name`, `mydb`.`users`.`fk1`, `mydb`.`users`.`fk2` FROM `mydb`.`users` WHERE `age` < ? ORDER BY `id` LIMIT 3 OFFSET 4")). WithArgs(40). WillReturnRows(sqlmock.NewRows([]string{"id", "age", "name", "fk1", "fk2"}). AddRow(1, 10, nil, nil, nil). AddRow(2, 20, "", 0, 0). AddRow(3, 30, "a8m", 1, 1)) var ( users []*user spec = &QuerySpec{ Node: &NodeSpec{ Table: "users", Schema: "mydb", Columns: []string{"id", "age", "name", "fk1", "fk2"}, ID: &FieldSpec{Column: "id", Type: field.TypeInt}, }, Limit: 3, Offset: 4, Unique: true, Order: func(s *sql.Selector) { s.OrderBy("id") }, Predicate: func(s *sql.Selector) { s.Where(sql.LT("age", 40)) }, ScanValues: func(columns []string) ([]any, error) { u := &user{} users = append(users, u) return u.values(columns) }, Assign: func(columns []string, values []any) error { return users[len(users)-1].assign(columns, values) }, } ) // Query and scan. err = QueryNodes(context.Background(), sql.OpenDB("", db), spec) require.NoError(t, err) require.Equal(t, &user{id: 1, age: 10, name: ""}, users[0]) require.Equal(t, &user{id: 2, age: 20, name: ""}, users[1]) require.Equal(t, &user{id: 3, age: 30, name: "a8m", edges: struct{ fk1, fk2 int }{1, 1}}, users[2]) } func TestQueryEdges(t *testing.T) { db, mock, err := sqlmock.New() require.NoError(t, err) mock.ExpectQuery(escape("SELECT `group_id`, `user_id` FROM `user_groups` WHERE `user_id` IN (?, ?, ?)")). WithArgs(1, 2, 3). WillReturnRows(sqlmock.NewRows([]string{"group_id", "user_id"}). AddRow(4, 5). AddRow(4, 6)) var ( edges [][]int64 spec = &EdgeQuerySpec{ Edge: &EdgeSpec{ Inverse: true, Table: "user_groups", Columns: []string{"user_id", "group_id"}, }, Predicate: func(s *sql.Selector) { s.Where(sql.InValues("user_id", 1, 2, 3)) }, ScanValues: func() [2]any { return [2]any{&sql.NullInt64{}, &sql.NullInt64{}} }, Assign: func(out, in any) error { o, i := out.(*sql.NullInt64), in.(*sql.NullInt64) edges = append(edges, []int64{o.Int64, i.Int64}) return nil }, } ) // Query and scan. err = QueryEdges(context.Background(), sql.OpenDB("", db), spec) require.NoError(t, err) require.Equal(t, [][]int64{{4, 5}, {4, 6}}, edges) } func TestQueryEdgesSchema(t *testing.T) { db, mock, err := sqlmock.New() require.NoError(t, err) mock.ExpectQuery(escape("SELECT `group_id`, `user_id` FROM `mydb`.`user_groups` WHERE `user_id` IN (?, ?, ?)")). WithArgs(1, 2, 3). WillReturnRows(sqlmock.NewRows([]string{"group_id", "user_id"}). AddRow(4, 5). AddRow(4, 6)) var ( edges [][]int64 spec = &EdgeQuerySpec{ Edge: &EdgeSpec{ Inverse: true, Table: "user_groups", Schema: "mydb", Columns: []string{"user_id", "group_id"}, }, Predicate: func(s *sql.Selector) { s.Where(sql.InValues("user_id", 1, 2, 3)) }, ScanValues: func() [2]any { return [2]any{&sql.NullInt64{}, &sql.NullInt64{}} }, Assign: func(out, in any) error { o, i := out.(*sql.NullInt64), in.(*sql.NullInt64) edges = append(edges, []int64{o.Int64, i.Int64}) return nil }, } ) // Query and scan. 
err = QueryEdges(context.Background(), sql.OpenDB("", db), spec) require.NoError(t, err) require.Equal(t, [][]int64{{4, 5}, {4, 6}}, edges) } func TestIsConstraintError(t *testing.T) { tests := []struct { name string errMessage string expectedConstraint bool expectedFK bool expectedUnique bool }{ { name: "MySQL FK", errMessage: `insert node to table "pets": Error 1452: Cannot add or update a child row: a foreign key` + " constraint fails (`test`.`pets`, CONSTRAINT `pets_users_pets` FOREIGN KEY (`user_pets`) REFERENCES " + "`users` (`id`) ON DELETE SET NULL)", expectedConstraint: true, expectedFK: true, expectedUnique: false, }, { name: "SQLite FK", errMessage: `insert node to table "pets": FOREIGN KEY constraint failed`, expectedConstraint: true, expectedFK: true, expectedUnique: false, }, { name: "Postgres FK", errMessage: `insert node to table "pets": pq: insert or update on table "pets" violates foreign key constraint "pets_users_pets"`, expectedConstraint: true, expectedFK: true, expectedUnique: false, }, { name: "MySQL FK", errMessage: "Error 1451: Cannot delete or update a parent row: a foreign key constraint " + "fails (`test`.`groups`, CONSTRAINT `groups_group_infos_info` FOREIGN KEY (`group_info`) REFERENCES `group_infos` (`id`))", expectedConstraint: true, expectedFK: true, expectedUnique: false, }, { name: "SQLite FK", errMessage: `FOREIGN KEY constraint failed`, expectedConstraint: true, expectedFK: true, expectedUnique: false, }, { name: "Postgres FK", errMessage: `pq: update or delete on table "group_infos" violates foreign key constraint "groups_group_infos_info" on table "groups"`, expectedConstraint: true, expectedFK: true, expectedUnique: false, }, { name: "MySQL Unique", errMessage: `insert node to table "file_types": UNIQUE constraint failed: file_types.name ent: constraint failed: insert node to table "file_types": UNIQUE constraint failed: file_types.name`, expectedConstraint: true, expectedFK: false, expectedUnique: true, }, { name: "SQLite Unique", errMessage: `insert node to table "file_types": UNIQUE constraint failed: file_types.name ent: constraint failed: insert node to table "file_types": UNIQUE constraint failed: file_types.name`, expectedConstraint: true, expectedFK: false, expectedUnique: true, }, { name: "Postgres Unique", errMessage: `insert node to table "file_types": pq: duplicate key value violates unique constraint "file_types_name_key" ent: constraint failed: insert node to table "file_types": pq: duplicate key value violates unique constraint "file_types_name_key"`, expectedConstraint: true, expectedFK: false, expectedUnique: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { err := errors.New(tt.errMessage) require.EqualValues(t, tt.expectedConstraint, IsConstraintError(err)) require.EqualValues(t, tt.expectedFK, IsForeignKeyConstraintError(err)) require.EqualValues(t, tt.expectedUnique, IsUniqueConstraintError(err)) }) } } func escape(query string) string { rows := strings.Split(query, "\n") for i := range rows { rows[i] = strings.TrimPrefix(rows[i], " ") } query = strings.Join(rows, " ") return strings.TrimSpace(regexp.QuoteMeta(query)) + "$" } ent-0.11.3/dialect/sql/sqljson/000077500000000000000000000000001431500740500162565ustar00rootroot00000000000000ent-0.11.3/dialect/sql/sqljson/sqljson.go000066400000000000000000000377171431500740500203150ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. 
// This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package sqljson import ( "encoding/json" "fmt" "strings" "unicode" "entgo.io/ent/dialect" "entgo.io/ent/dialect/sql" ) // HasKey return a predicate for checking that a JSON key // exists and not NULL. // // sqljson.HasKey("column", sql.DotPath("a.b[2].c")) func HasKey(column string, opts ...Option) *sql.Predicate { return sql.P(func(b *sql.Builder) { switch b.Dialect() { case dialect.SQLite: // JSON_TYPE returns NULL in case the path selects an element // that does not exist. See: https://sqlite.org/json1.html#jtype. path := identPath(column, opts...) path.mysqlFunc("JSON_TYPE", b) b.WriteOp(sql.OpNotNull) default: ValuePath(b, column, opts...) b.WriteOp(sql.OpNotNull) } }) } // ValueIsNull return a predicate for checking that a JSON value // (returned by the path) is a null literal (JSON "null"). // // In order to check if the column is NULL (database NULL), or if // the JSON key exists, use sql.IsNull or sqljson.HasKey. // // sqljson.ValueIsNull("a", sqljson.Path("b")) func ValueIsNull(column string, opts ...Option) *sql.Predicate { return sql.P(func(b *sql.Builder) { switch b.Dialect() { case dialect.MySQL: path := identPath(column, opts...) b.WriteString("JSON_CONTAINS").Nested(func(b *sql.Builder) { b.Ident(column).Comma() b.WriteString("'null'").Comma() path.mysqlPath(b) }) case dialect.Postgres: ValuePath(b, column, append(opts, Cast("jsonb"))...) b.WriteOp(sql.OpEQ).WriteString("'null'::jsonb") case dialect.SQLite: path := identPath(column, opts...) path.mysqlFunc("JSON_TYPE", b) b.WriteOp(sql.OpEQ).WriteString("'null'") } }) } // ValueEQ return a predicate for checking that a JSON value // (returned by the path) is equal to the given argument. // // sqljson.ValueEQ("a", 1, sqljson.Path("b")) func ValueEQ(column string, arg any, opts ...Option) *sql.Predicate { return sql.P(func(b *sql.Builder) { opts = normalizePG(b, arg, opts) ValuePath(b, column, opts...) b.WriteOp(sql.OpEQ).Arg(arg) }) } // ValueNEQ return a predicate for checking that a JSON value // (returned by the path) is not equal to the given argument. // // sqljson.ValueNEQ("a", 1, sqljson.Path("b")) func ValueNEQ(column string, arg any, opts ...Option) *sql.Predicate { return sql.P(func(b *sql.Builder) { opts = normalizePG(b, arg, opts) ValuePath(b, column, opts...) b.WriteOp(sql.OpNEQ).Arg(arg) }) } // ValueGT return a predicate for checking that a JSON value // (returned by the path) is greater than the given argument. // // sqljson.ValueGT("a", 1, sqljson.Path("b")) func ValueGT(column string, arg any, opts ...Option) *sql.Predicate { return sql.P(func(b *sql.Builder) { opts = normalizePG(b, arg, opts) ValuePath(b, column, opts...) b.WriteOp(sql.OpGT).Arg(arg) }) } // ValueGTE return a predicate for checking that a JSON value // (returned by the path) is greater than or equal to the given // argument. // // sqljson.ValueGTE("a", 1, sqljson.Path("b")) func ValueGTE(column string, arg any, opts ...Option) *sql.Predicate { return sql.P(func(b *sql.Builder) { opts = normalizePG(b, arg, opts) ValuePath(b, column, opts...) b.WriteOp(sql.OpGTE).Arg(arg) }) } // ValueLT return a predicate for checking that a JSON value // (returned by the path) is less than the given argument. // // sqljson.ValueLT("a", 1, sqljson.Path("b")) func ValueLT(column string, arg any, opts ...Option) *sql.Predicate { return sql.P(func(b *sql.Builder) { opts = normalizePG(b, arg, opts) ValuePath(b, column, opts...) 
b.WriteOp(sql.OpLT).Arg(arg) }) } // ValueLTE return a predicate for checking that a JSON value // (returned by the path) is less than or equal to the given // argument. // // sqljson.ValueLTE("a", 1, sqljson.Path("b")) func ValueLTE(column string, arg any, opts ...Option) *sql.Predicate { return sql.P(func(b *sql.Builder) { opts = normalizePG(b, arg, opts) ValuePath(b, column, opts...) b.WriteOp(sql.OpLTE).Arg(arg) }) } // ValueContains return a predicate for checking that a JSON // value (returned by the path) contains the given argument. // // sqljson.ValueContains("a", 1, sqljson.Path("b")) func ValueContains(column string, arg any, opts ...Option) *sql.Predicate { return sql.P(func(b *sql.Builder) { path := identPath(column, opts...) switch b.Dialect() { case dialect.MySQL: b.WriteString("JSON_CONTAINS").Nested(func(b *sql.Builder) { b.Ident(column).Comma() b.Arg(marshal(arg)).Comma() path.mysqlPath(b) }) b.WriteOp(sql.OpEQ).Arg(1) case dialect.SQLite: b.WriteString("EXISTS").Nested(func(b *sql.Builder) { b.WriteString("SELECT * FROM JSON_EACH").Nested(func(b *sql.Builder) { b.Ident(column).Comma() path.mysqlPath(b) }) b.WriteString(" WHERE ").Ident("value").WriteOp(sql.OpEQ).Arg(arg) }) case dialect.Postgres: opts = normalizePG(b, arg, opts) path.Cast = "jsonb" path.value(b) b.WriteString(" @> ").Arg(marshal(arg)) } }) } // StringHasPrefix return a predicate for checking that a JSON string value // (returned by the path) has the given substring as prefix func StringHasPrefix(column string, prefix string, opts ...Option) *sql.Predicate { return sql.P(func(b *sql.Builder) { opts = append([]Option{Unquote(true)}, opts...) ValuePath(b, column, opts...) b.Join(sql.HasPrefix("", prefix)) }) } // StringHasSuffix return a predicate for checking that a JSON string value // (returned by the path) has the given substring as suffix func StringHasSuffix(column string, suffix string, opts ...Option) *sql.Predicate { return sql.P(func(b *sql.Builder) { opts = append([]Option{Unquote(true)}, opts...) ValuePath(b, column, opts...) b.Join(sql.HasSuffix("", suffix)) }) } // StringContains return a predicate for checking that a JSON string value // (returned by the path) contains the given substring func StringContains(column string, sub string, opts ...Option) *sql.Predicate { return sql.P(func(b *sql.Builder) { opts = append([]Option{Unquote(true)}, opts...) ValuePath(b, column, opts...) b.Join(sql.Contains("", sub)) }) } // ValueIn return a predicate for checking that a JSON value // (returned by the path) is IN the given arguments. // // sqljson.ValueIn("a", []any{1, 2, 3}, sqljson.Path("b")) func ValueIn(column string, args []any, opts ...Option) *sql.Predicate { return valueInOp(column, args, opts, sql.OpIn) } // ValueNotIn return a predicate for checking that a JSON value // (returned by the path) is NOT IN the given arguments. // // sqljson.ValueNotIn("a", []any{1, 2, 3}, sqljson.Path("b")) func ValueNotIn(column string, args []any, opts ...Option) *sql.Predicate { if len(args) == 0 { return sql.NotIn(column) } return valueInOp(column, args, opts, sql.OpNotIn) } func valueInOp(column string, args []any, opts []Option, op sql.Op) *sql.Predicate { return sql.P(func(b *sql.Builder) { if allString(args) { opts = append(opts, Unquote(true)) } if len(args) > 0 { opts = normalizePG(b, args[0], opts) } ValuePath(b, column, opts...) b.WriteOp(op) b.Nested(func(b *sql.Builder) { if s, ok := args[0].(*sql.Selector); ok { b.Join(s) } else { b.Args(args...) 
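/*
   Illustration (a sketch, not upstream documentation): one sqljson predicate is
   rendered as dialect-specific JSON SQL. For example, given

       p := sqljson.ValueContains("a", 1, sqljson.Path("b"))

   the builder emits roughly:

       MySQL:    JSON_CONTAINS(`a`, ?, '$.b') = ?   -- args: "1", 1
       SQLite:   EXISTS(SELECT * FROM JSON_EACH(`a`, '$.b') WHERE `value` = ?)
       Postgres: ("a"->'b')::jsonb @> $1
*/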
} }) }) } // LenEQ return a predicate for checking that an array length // of a JSON (returned by the path) is equal to the given argument. // // sqljson.LenEQ("a", 1, sqljson.Path("b")) func LenEQ(column string, size int, opts ...Option) *sql.Predicate { return sql.P(func(b *sql.Builder) { LenPath(b, column, opts...) b.WriteOp(sql.OpEQ).Arg(size) }) } // LenNEQ return a predicate for checking that an array length // of a JSON (returned by the path) is not equal to the given argument. // // sqljson.LenEQ("a", 1, sqljson.Path("b")) func LenNEQ(column string, size int, opts ...Option) *sql.Predicate { return sql.P(func(b *sql.Builder) { LenPath(b, column, opts...) b.WriteOp(sql.OpNEQ).Arg(size) }) } // LenGT return a predicate for checking that an array length // of a JSON (returned by the path) is greater than the given // argument. // // sqljson.LenGT("a", 1, sqljson.Path("b")) func LenGT(column string, size int, opts ...Option) *sql.Predicate { return sql.P(func(b *sql.Builder) { LenPath(b, column, opts...) b.WriteOp(sql.OpGT).Arg(size) }) } // LenGTE return a predicate for checking that an array length // of a JSON (returned by the path) is greater than or equal to // the given argument. // // sqljson.LenGTE("a", 1, sqljson.Path("b")) func LenGTE(column string, size int, opts ...Option) *sql.Predicate { return sql.P(func(b *sql.Builder) { LenPath(b, column, opts...) b.WriteOp(sql.OpGTE).Arg(size) }) } // LenLT return a predicate for checking that an array length // of a JSON (returned by the path) is less than the given // argument. // // sqljson.LenLT("a", 1, sqljson.Path("b")) func LenLT(column string, size int, opts ...Option) *sql.Predicate { return sql.P(func(b *sql.Builder) { LenPath(b, column, opts...) b.WriteOp(sql.OpLT).Arg(size) }) } // LenLTE return a predicate for checking that an array length // of a JSON (returned by the path) is less than or equal to // the given argument. // // sqljson.LenLTE("a", 1, sqljson.Path("b")) func LenLTE(column string, size int, opts ...Option) *sql.Predicate { return sql.P(func(b *sql.Builder) { LenPath(b, column, opts...) b.WriteOp(sql.OpLTE).Arg(size) }) } // ValuePath writes to the given SQL builder the JSON path for // getting the value of a given JSON path. // // sqljson.ValuePath(b, Path("a", "b", "[1]", "c"), Cast("int")) func ValuePath(b *sql.Builder, column string, opts ...Option) { path := identPath(column, opts...) path.value(b) } // LenPath writes to the given SQL builder the JSON path for // getting the length of a given JSON path. // // sqljson.LenPath(b, Path("a", "b", "[1]", "c")) func LenPath(b *sql.Builder, column string, opts ...Option) { path := identPath(column, opts...) path.length(b) } // Option allows for calling database JSON paths with functional options. type Option func(*PathOptions) // Path sets the path to the JSON value of a column. // // ValuePath(b, "column", Path("a", "b", "[1]", "c")) func Path(path ...string) Option { return func(p *PathOptions) { p.Path = path } } // DotPath is similar to Path, but accepts string with dot format. // // ValuePath(b, "column", DotPath("a.b.c")) // ValuePath(b, "column", DotPath("a.b[2].c")) // // Note that DotPath is ignored if the input is invalid. func DotPath(dotpath string) Option { path, _ := ParsePath(dotpath) return func(p *PathOptions) { p.Path = path } } // Unquote indicates that the result value should be unquoted. 
// // ValuePath(b, "column", Path("a", "b", "[1]", "c"), Unquote(true)) func Unquote(unquote bool) Option { return func(p *PathOptions) { p.Unquote = unquote } } // Cast indicates that the result value should be casted to the given type. // // ValuePath(b, "column", Path("a", "b", "[1]", "c"), Cast("int")) func Cast(typ string) Option { return func(p *PathOptions) { p.Cast = typ } } // PathOptions holds the options for accessing a JSON value from an identifier. type PathOptions struct { Ident string Path []string Cast string Unquote bool } // identPath creates a PathOptions for the given identifier. func identPath(ident string, opts ...Option) *PathOptions { path := &PathOptions{Ident: ident} for i := range opts { opts[i](path) } return path } // value writes the path for getting the JSON value. func (p *PathOptions) value(b *sql.Builder) { switch { case len(p.Path) == 0: b.Ident(p.Ident) case b.Dialect() == dialect.Postgres: if p.Cast != "" { b.WriteByte('(') defer b.WriteString(")::" + p.Cast) } p.pgPath(b) default: if p.Unquote && b.Dialect() == dialect.MySQL { b.WriteString("JSON_UNQUOTE(") defer b.WriteByte(')') } p.mysqlFunc("JSON_EXTRACT", b) } } // value writes the path for getting the length of a JSON value. func (p *PathOptions) length(b *sql.Builder) { switch { case b.Dialect() == dialect.Postgres: b.WriteString("JSONB_ARRAY_LENGTH(") p.pgPath(b) b.WriteByte(')') case b.Dialect() == dialect.MySQL: p.mysqlFunc("JSON_LENGTH", b) default: p.mysqlFunc("JSON_ARRAY_LENGTH", b) } } // mysqlFunc writes the JSON path in MySQL format for the // given function. `JSON_EXTRACT("a", '$.b.c')`. func (p *PathOptions) mysqlFunc(fn string, b *sql.Builder) { b.WriteString(fn).WriteByte('(') b.Ident(p.Ident).Comma() p.mysqlPath(b) b.WriteByte(')') } // mysqlPath writes the JSON path in MySQL (or SQLite) format. func (p *PathOptions) mysqlPath(b *sql.Builder) { b.WriteString(`'$`) for _, p := range p.Path { switch _, isIndex := isJSONIdx(p); { case isIndex: b.WriteString(p) case p == "*" || isQuoted(p) || isIdentifier(p): b.WriteString("." + p) default: b.WriteString(`."` + p + `"`) } } b.WriteByte('\'') } // pgPath writes the JSON path in Postgres format `"a"->'b'->>'c'`. func (p *PathOptions) pgPath(b *sql.Builder) { b.Ident(p.Ident) for i, s := range p.Path { b.WriteString("->") if p.Unquote && i == len(p.Path)-1 { b.WriteString(">") } if idx, ok := isJSONIdx(s); ok { b.WriteString(idx) } else { b.WriteString("'" + s + "'") } } } // ParsePath parses the "dotpath" for the DotPath option. // // "a.b" => ["a", "b"] // "a[1][2]" => ["a", "[1]", "[2]"] // "a.\"b.c\" => ["a", "\"b.c\""] func ParsePath(dotpath string) ([]string, error) { var ( i, p int path []string ) for i < len(dotpath) { switch r := dotpath[i]; { case r == '"': if i == len(dotpath)-1 { return nil, fmt.Errorf("unexpected quote") } idx := strings.IndexRune(dotpath[i+1:], '"') if idx == -1 || idx == 0 { return nil, fmt.Errorf("unbalanced quote") } i += idx + 2 case r == '[': if p != i { path = append(path, dotpath[p:i]) } p = i if i == len(dotpath)-1 { return nil, fmt.Errorf("unexpected bracket") } idx := strings.IndexRune(dotpath[i:], ']') if idx == -1 || idx == 1 { return nil, fmt.Errorf("unbalanced bracket") } if !isNumber(dotpath[i+1 : i+idx]) { return nil, fmt.Errorf("invalid index %q", dotpath[i:i+idx+1]) } i += idx + 1 case r == '.' 
|| r == ']': if p != i { path = append(path, dotpath[p:i]) } i++ p = i default: i++ } } if p != i { path = append(path, dotpath[p:i]) } return path, nil } // normalizePG adds cast option to the JSON path is the argument type is // not string, in order to avoid "missing type casts" error in Postgres. func normalizePG(b *sql.Builder, arg any, opts []Option) []Option { if b.Dialect() != dialect.Postgres { return opts } base := []Option{Unquote(true)} switch arg.(type) { case string: case bool: base = append(base, Cast("bool")) case float32, float64: base = append(base, Cast("float")) case int8, int16, int32, int64, int, uint8, uint16, uint32, uint64: base = append(base, Cast("int")) } return append(base, opts...) } func isIdentifier(name string) bool { if name == "" { return false } for i, c := range name { if !unicode.IsLetter(c) && c != '_' && (i == 0 || !unicode.IsDigit(c)) { return false } } return true } func isQuoted(s string) bool { if s == "" { return false } return s[0] == '"' && s[len(s)-1] == '"' } // isJSONIdx reports whether the string represents a JSON index. func isJSONIdx(s string) (string, bool) { if len(s) > 2 && s[0] == '[' && s[len(s)-1] == ']' && isNumber(s[1:len(s)-1]) { return s[1 : len(s)-1], true } return "", false } // isNumber reports whether the string is a number (category N). func isNumber(s string) bool { for _, r := range s { if !unicode.IsNumber(r) { return false } } return true } // allString reports if the slice contains only strings. func allString(v []any) bool { for i := range v { if _, ok := v[i].(string); !ok { return false } } return true } // marshal stringifies the given argument to a valid JSON document. func marshal(arg any) any { if buf, err := json.Marshal(arg); err == nil { arg = string(buf) } return arg } ent-0.11.3/dialect/sql/sqljson/sqljson_test.go000066400000000000000000000310661431500740500213430ustar00rootroot00000000000000// Copyright 2019-present Facebook Inc. All rights reserved. // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. package sqljson_test import ( "strconv" "testing" "entgo.io/ent/dialect" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqljson" "github.com/stretchr/testify/require" ) func TestWritePath(t *testing.T) { tests := []struct { input sql.Querier wantQuery string wantArgs []any }{ { input: sql.Dialect(dialect.Postgres). Select("*"). From(sql.Table("users")). Where(sqljson.ValueEQ("a", 1, sqljson.Path("b", "c", "[1]", "d"), sqljson.Cast("int"))), wantQuery: `SELECT * FROM "users" WHERE ("a"->'b'->'c'->1->>'d')::int = $1`, wantArgs: []any{1}, }, { input: sql.Dialect(dialect.MySQL). Select("*"). From(sql.Table("users")). Where(sqljson.ValueEQ("a", "a", sqljson.DotPath("b.c[1].d"))), wantQuery: "SELECT * FROM `users` WHERE JSON_EXTRACT(`a`, '$.b.c[1].d') = ?", wantArgs: []any{"a"}, }, { input: sql.Dialect(dialect.MySQL). Select("*"). From(sql.Table("users")). Where(sqljson.ValueEQ("a", "a", sqljson.DotPath("b.\"c[1]\".d[1][2].e"))), wantQuery: "SELECT * FROM `users` WHERE JSON_EXTRACT(`a`, '$.b.\"c[1]\".d[1][2].e') = ?", wantArgs: []any{"a"}, }, { input: sql.Select("*"). From(sql.Table("test")). Where(sqljson.HasKey("j", sqljson.DotPath("a.*.c"))), wantQuery: "SELECT * FROM `test` WHERE JSON_EXTRACT(`j`, '$.a.*.c') IS NOT NULL", }, { input: sql.Select("*"). From(sql.Table("test")). 
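// No dialect is set on this builder, so the default builder syntax is
// used, which matches MySQL (backtick quoting and JSON_EXTRACT), as the
// expected query below shows.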
Where(sqljson.HasKey("j", sqljson.DotPath("a.*.c"))), wantQuery: "SELECT * FROM `test` WHERE JSON_EXTRACT(`j`, '$.a.*.c') IS NOT NULL", }, { input: sql.Dialect(dialect.SQLite). Select("*"). From(sql.Table("test")). Where(sqljson.HasKey("j", sqljson.DotPath("attributes[1].body"))), wantQuery: "SELECT * FROM `test` WHERE JSON_TYPE(`j`, '$.attributes[1].body') IS NOT NULL", }, { input: sql.Dialect(dialect.SQLite). Select("*"). From(sql.Table("test")). Where(sqljson.HasKey("j", sqljson.DotPath("a.*.c"))), wantQuery: "SELECT * FROM `test` WHERE JSON_TYPE(`j`, '$.a.*.c') IS NOT NULL", }, { input: sql.Dialect(dialect.SQLite). Select("*"). From(sql.Table("test")). Where( sql.And( sql.GT("id", 100), sqljson.HasKey("j", sqljson.DotPath("a.*.c")), sql.EQ("active", true), ), ), wantQuery: "SELECT * FROM `test` WHERE `id` > ? AND JSON_TYPE(`j`, '$.a.*.c') IS NOT NULL AND `active`", wantArgs: []any{100}, }, { input: sql.Dialect(dialect.Postgres). Select("*"). From(sql.Table("test")). Where(sql.And( sql.EQ("e", 10), sqljson.ValueEQ("a", 1, sqljson.DotPath("b.c")), )), wantQuery: `SELECT * FROM "test" WHERE "e" = $1 AND ("a"->'b'->>'c')::int = $2`, wantArgs: []any{10, 1}, }, { input: sql.Dialect(dialect.MySQL). Select("*"). From(sql.Table("users")). Where(sqljson.ValueEQ("a", "a", sqljson.Path("b", "c", "[1]", "d"), sqljson.Unquote(true))), wantQuery: "SELECT * FROM `users` WHERE JSON_UNQUOTE(JSON_EXTRACT(`a`, '$.b.c[1].d')) = ?", wantArgs: []any{"a"}, }, { input: sql.Dialect(dialect.Postgres). Select("*"). From(sql.Table("users")). Where(sqljson.ValueEQ("a", "a", sqljson.Path("b", "c", "[1]", "d"), sqljson.Unquote(true))), wantQuery: `SELECT * FROM "users" WHERE "a"->'b'->'c'->1->>'d' = $1`, wantArgs: []any{"a"}, }, { input: sql.Dialect(dialect.Postgres). Select("*"). From(sql.Table("users")). Where(sqljson.ValueEQ("a", 1, sqljson.Path("b", "c", "[1]", "d"), sqljson.Cast("int"))), wantQuery: `SELECT * FROM "users" WHERE ("a"->'b'->'c'->1->>'d')::int = $1`, wantArgs: []any{1}, }, { input: sql.Dialect(dialect.Postgres). Select("*"). From(sql.Table("users")). Where( sql.Or( sqljson.ValueNEQ("a", 1, sqljson.Path("b")), sqljson.ValueGT("a", 1, sqljson.Path("c")), sqljson.ValueGTE("a", 1.1, sqljson.Path("d")), sqljson.ValueLT("a", 1, sqljson.Path("e")), sqljson.ValueLTE("a", 1, sqljson.Path("f")), ), ), wantQuery: `SELECT * FROM "users" WHERE ("a"->>'b')::int <> $1 OR ("a"->>'c')::int > $2 OR ("a"->>'d')::float >= $3 OR ("a"->>'e')::int < $4 OR ("a"->>'f')::int <= $5`, wantArgs: []any{1, 1, 1.1, 1, 1}, }, { input: sql.Dialect(dialect.Postgres). Select("*"). From(sql.Table("users")). Where(sqljson.LenEQ("a", 1)), wantQuery: `SELECT * FROM "users" WHERE JSONB_ARRAY_LENGTH("a") = $1`, wantArgs: []any{1}, }, { input: sql.Dialect(dialect.MySQL). Select("*"). From(sql.Table("users")). Where(sqljson.LenEQ("a", 1)), wantQuery: "SELECT * FROM `users` WHERE JSON_LENGTH(`a`, '$') = ?", wantArgs: []any{1}, }, { input: sql.Dialect(dialect.SQLite). Select("*"). From(sql.Table("users")). Where(sqljson.LenEQ("a", 1)), wantQuery: "SELECT * FROM `users` WHERE JSON_ARRAY_LENGTH(`a`, '$') = ?", wantArgs: []any{1}, }, { input: sql.Dialect(dialect.SQLite). Select("*"). From(sql.Table("users")). Where( sql.Or( sqljson.LenGT("a", 1, sqljson.Path("b")), sqljson.LenGTE("a", 1, sqljson.Path("c")), sqljson.LenLT("a", 1, sqljson.Path("d")), sqljson.LenLTE("a", 1, sqljson.Path("e")), ), ), wantQuery: "SELECT * FROM `users` WHERE JSON_ARRAY_LENGTH(`a`, '$.b') > ? OR JSON_ARRAY_LENGTH(`a`, '$.c') >= ? OR JSON_ARRAY_LENGTH(`a`, '$.d') < ? 
OR JSON_ARRAY_LENGTH(`a`, '$.e') <= ?", wantArgs: []any{1, 1, 1, 1}, }, { input: sql.Dialect(dialect.MySQL). Select("*"). From(sql.Table("users")). Where(sqljson.ValueContains("tags", "foo")), wantQuery: "SELECT * FROM `users` WHERE JSON_CONTAINS(`tags`, ?, '$') = ?", wantArgs: []any{"\"foo\"", 1}, }, { input: sql.Dialect(dialect.MySQL). Select("*"). From(sql.Table("users")). Where(sqljson.ValueContains("tags", 1, sqljson.Path("a"))), wantQuery: "SELECT * FROM `users` WHERE JSON_CONTAINS(`tags`, ?, '$.a') = ?", wantArgs: []any{"1", 1}, }, { input: sql.Dialect(dialect.SQLite). Select("*"). From(sql.Table("users")). Where(sqljson.ValueContains("tags", "foo")), wantQuery: "SELECT * FROM `users` WHERE EXISTS(SELECT * FROM JSON_EACH(`tags`, '$') WHERE `value` = ?)", wantArgs: []any{"foo"}, }, { input: sql.Dialect(dialect.SQLite). Select("*"). From(sql.Table("users")). Where(sqljson.ValueContains("tags", 1, sqljson.Path("a"))), wantQuery: "SELECT * FROM `users` WHERE EXISTS(SELECT * FROM JSON_EACH(`tags`, '$.a') WHERE `value` = ?)", wantArgs: []any{1}, }, { input: sql.Dialect(dialect.Postgres). Select("*"). From(sql.Table("users")). Where(sqljson.ValueContains("tags", "foo")), wantQuery: "SELECT * FROM \"users\" WHERE \"tags\" @> $1", wantArgs: []any{"\"foo\""}, }, { input: sql.Dialect(dialect.Postgres). Select("*"). From(sql.Table("users")). Where(sqljson.ValueContains("tags", 1, sqljson.Path("a"))), wantQuery: "SELECT * FROM \"users\" WHERE (\"tags\"->'a')::jsonb @> $1", wantArgs: []any{"1"}, }, { input: sql.Dialect(dialect.Postgres). Select("*"). From(sql.Table("users")). Where(sqljson.ValueIsNull("c", sqljson.Path("a"))), wantQuery: `SELECT * FROM "users" WHERE ("c"->'a')::jsonb = 'null'::jsonb`, }, { input: sql.Dialect(dialect.MySQL). Select("*"). From(sql.Table("users")). Where(sqljson.ValueIsNull("c", sqljson.Path("a"))), wantQuery: "SELECT * FROM `users` WHERE JSON_CONTAINS(`c`, 'null', '$.a')", }, { input: sql.Dialect(dialect.SQLite). Select("*"). From(sql.Table("users")). Where(sqljson.ValueIsNull("c", sqljson.Path("a"))), wantQuery: "SELECT * FROM `users` WHERE JSON_TYPE(`c`, '$.a') = 'null'", }, { input: sql.Dialect(dialect.Postgres). Select("*"). From(sql.Table("users")). Where(sqljson.StringContains("a", "substr", sqljson.Path("b", "c", "[1]", "d"))), wantQuery: `SELECT * FROM "users" WHERE "a"->'b'->'c'->1->>'d' LIKE $1`, wantArgs: []any{"%substr%"}, }, { input: sql.Dialect(dialect.Postgres). Select("*"). From(sql.Table("users")). Where( sql.And( sqljson.StringContains("a", "c", sqljson.Path("a")), sqljson.StringContains("b", "d", sqljson.Path("b")), ), ), wantQuery: `SELECT * FROM "users" WHERE "a"->>'a' LIKE $1 AND "b"->>'b' LIKE $2`, wantArgs: []any{"%c%", "%d%"}, }, { input: sql.Dialect(dialect.MySQL). Select("*"). From(sql.Table("users")). Where(sqljson.StringContains("a", "substr", sqljson.Path("b", "c", "[1]", "d"))), wantQuery: "SELECT * FROM `users` WHERE JSON_UNQUOTE(JSON_EXTRACT(`a`, '$.b.c[1].d')) LIKE ?", wantArgs: []any{"%substr%"}, }, { input: sql.Dialect(dialect.MySQL). Select("*"). From(sql.Table("users")). Where( sql.And( sqljson.StringContains("a", "c", sqljson.Path("a")), sqljson.StringContains("b", "d", sqljson.Path("b")), ), ), wantQuery: "SELECT * FROM `users` WHERE JSON_UNQUOTE(JSON_EXTRACT(`a`, '$.a')) LIKE ? AND JSON_UNQUOTE(JSON_EXTRACT(`b`, '$.b')) LIKE ?", wantArgs: []any{"%c%", "%d%"}, }, { input: sql.Dialect(dialect.Postgres). Select("*"). From(sql.Table("users")). 
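// The String* predicates (StringHasPrefix/StringHasSuffix/StringContains)
// implicitly apply the Unquote option, which is why the expected queries
// extract the last path element as text (->>) instead of as JSON (->).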
Where(sqljson.StringHasPrefix("a", "substr", sqljson.Path("b", "c", "[1]", "d"))), wantQuery: `SELECT * FROM "users" WHERE "a"->'b'->'c'->1->>'d' LIKE $1`, wantArgs: []any{"substr%"}, }, { input: sql.Dialect(dialect.MySQL). Select("*"). From(sql.Table("users")). Where(sqljson.StringHasPrefix("a", "substr", sqljson.Path("b", "c", "[1]", "d"))), wantQuery: "SELECT * FROM `users` WHERE JSON_UNQUOTE(JSON_EXTRACT(`a`, '$.b.c[1].d')) LIKE ?", wantArgs: []any{"substr%"}, }, { input: sql.Dialect(dialect.Postgres). Select("*"). From(sql.Table("users")). Where(sqljson.StringHasSuffix("a", "substr", sqljson.Path("b", "c", "[1]", "d"))), wantQuery: `SELECT * FROM "users" WHERE "a"->'b'->'c'->1->>'d' LIKE $1`, wantArgs: []any{"%substr"}, }, { input: sql.Dialect(dialect.MySQL). Select("*"). From(sql.Table("users")). Where(sqljson.StringHasSuffix("a", "substr", sqljson.Path("b", "c", "[1]", "d"))), wantQuery: "SELECT * FROM `users` WHERE JSON_UNQUOTE(JSON_EXTRACT(`a`, '$.b.c[1].d')) LIKE ?", wantArgs: []any{"%substr"}, }, { input: sql.Dialect(dialect.MySQL). Select("*"). From(sql.Table("users")). Where(sqljson.ValueIn("a", []any{"a", "b"}, sqljson.Path("b"))), wantQuery: "SELECT * FROM `users` WHERE JSON_UNQUOTE(JSON_EXTRACT(`a`, '$.b')) IN (?, ?)", wantArgs: []any{"a", "b"}, }, { input: sql.Dialect(dialect.MySQL). Select("*"). From(sql.Table("users")). Where(sqljson.ValueIn("a", []any{1, 2}, sqljson.Path("b"))), wantQuery: "SELECT * FROM `users` WHERE JSON_EXTRACT(`a`, '$.b') IN (?, ?)", wantArgs: []any{1, 2}, }, { input: sql.Dialect(dialect.MySQL). Select("*"). From(sql.Table("users")). Where(sqljson.ValueIn("a", []any{1, "a"}, sqljson.Path("b"))), wantQuery: "SELECT * FROM `users` WHERE JSON_EXTRACT(`a`, '$.b') IN (?, ?)", wantArgs: []any{1, "a"}, }, { input: sql.Dialect(dialect.MySQL). Select("*"). From(sql.Table("users")). 
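// Path elements that are not plain identifiers (here, one containing a
// dash and one starting with a digit) are double-quoted in the generated
// JSON path, as the expected query below shows.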
Where(sqljson.ValueIn("a", []any{1, 2}, sqljson.Path("foo-bar", "3000"))), wantQuery: "SELECT * FROM `users` WHERE JSON_EXTRACT(`a`, '$.\"foo-bar\".\"3000\"') IN (?, ?)", wantArgs: []any{1, 2}, }, } for i, tt := range tests { t.Run(strconv.Itoa(i), func(t *testing.T) { query, args := tt.input.Query() require.Equal(t, tt.wantQuery, query) require.Equal(t, tt.wantArgs, args) }) } } func TestParsePath(t *testing.T) { tests := []struct { input string wantPath []string wantErr bool }{ { input: "a.b.c", wantPath: []string{"a", "b", "c"}, }, { input: "a[1][2]", wantPath: []string{"a", "[1]", "[2]"}, }, { input: "a[1][2].b", wantPath: []string{"a", "[1]", "[2]", "b"}, }, { input: `a."b.c[0]"`, wantPath: []string{"a", `"b.c[0]"`}, }, { input: `a."b.c[0]".d`, wantPath: []string{"a", `"b.c[0]"`, "d"}, }, { input: `...`, }, { input: `.a.b.`, wantPath: []string{"a", "b"}, }, { input: `a."`, wantErr: true, }, { input: `a[`, wantErr: true, }, { input: `a[a]`, wantErr: true, }, } for i, tt := range tests { t.Run(strconv.Itoa(i), func(t *testing.T) { path, err := sqljson.ParsePath(tt.input) require.Equal(t, tt.wantPath, path) require.Equal(t, tt.wantErr, err != nil) }) } } ent-0.11.3/doc/000077500000000000000000000000001431500740500131265ustar00rootroot00000000000000ent-0.11.3/doc/.gitignore000077500000000000000000000002551431500740500151230ustar00rootroot00000000000000.DS_Store node_modules lib/core/metadata.js lib/core/MetadataBlog.js website/translated_docs website/build/ website/node_modules website/i18n/* website/package-lock.json ent-0.11.3/doc/md/000077500000000000000000000000001431500740500135265ustar00rootroot00000000000000ent-0.11.3/doc/md/aggregate.md000077500000000000000000000046551431500740500160130ustar00rootroot00000000000000--- id: aggregate title: Aggregation --- ## Group By Group by `name` and `age` fields of all users, and sum their total age. ```go package main import ( "context" "/ent" "/ent/user" ) func Do(ctx context.Context, client *ent.Client) { var v []struct { Name string `json:"name"` Age int `json:"age"` Sum int `json:"sum"` Count int `json:"count"` } err := client.User.Query(). GroupBy(user.FieldName, user.FieldAge). Aggregate(ent.Count(), ent.Sum(user.FieldAge)). Scan(ctx, &v) } ``` Group by one field. ```go package main import ( "context" "/ent" "/ent/user" ) func Do(ctx context.Context, client *ent.Client) { names, err := client.User. Query(). GroupBy(user.FieldName). Strings(ctx) } ``` ## Group By Edge Custom aggregation functions can be useful if you want to write your own storage-specific logic. The following shows how to group by the `id` and the `name` of all users and calculate the average `age` of their pets. ```go package main import ( "context" "log" "/ent" "/ent/pet" "/ent/user" ) func Do(ctx context.Context, client *ent.Client) { var users []struct { ID int Name string Average float64 } err := client.User.Query(). GroupBy(user.FieldID, user.FieldName). Aggregate(func(s *sql.Selector) string { t := sql.Table(pet.Table) s.Join(t).On(s.C(user.FieldID), t.C(pet.OwnerColumn)) return sql.As(sql.Avg(t.C(pet.FieldAge)), "average") }). Scan(ctx, &users) } ``` ## Having + Group By [Custom SQL modifiers](https://entgo.io/docs/feature-flags/#custom-sql-modifiers) can be useful if you want to control all query parts. The following shows how to retrieve the oldest users for each role. 
```go package main import ( "context" "log" "entgo.io/ent/dialect/sql" "/ent" "/ent/user" ) func Do(ctx context.Context, client *ent.Client) { var users []struct { Id Int Age Int Role string } err := client.User.Query(). Modify(func(s *sql.Selector) { s.GroupBy(user.Role) s.Having( sql.EQ( user.FieldAge, sql.Raw(sql.Max(user.FieldAge)), ), ) }). ScanX(ctx, &users) } ``` **Note:** The `sql.Raw` is crucial to have. It tells the predicate that `sql.Max` is not an argument. The above code essentially generates the following SQL query: ```sql SELECT * FROM user GROUP BY user.role HAVING user.age = MAX(user.age) ``` ent-0.11.3/doc/md/ci.mdx000066400000000000000000000213751431500740500146430ustar00rootroot00000000000000--- id: ci title: Continuous Integration --- import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; To ensure the quality of their software, teams often apply _Continuous Integration_ workflows, commonly known as CI. With CI, teams continuously run a suite of automated verifications against every change to the code-base. During CI, teams may run many kinds of verifications: * Compilation or build of the most recent version to make sure it isn't broken. * Linting to enforce any accepted code-style standards. * Unit tests that verify individual components work as expected and that changes to the codebase do not cause regressions in other areas. * Security scans to make sure no known vulnerabilities are introduced to the codebase. * And much more! From our discussions with the Ent community, we have learned that many teams using Ent already use CI and would like to enforce some Ent-specific verifications into their workflows. To support the community with this effort we have started this guide which documents common best practices to verify in CI and introduces [ent/contrib/ci](https://github.com/ent/contrib/ci) a GitHub Action we maintain that codifies them. ## Verify all generated files are checked in Ent heavily relies on code generation. In our experience, generated code should always be checked into source control. This is done for two reasons: * If generated code is checked into source control, it can be read along with the main application code. Having generated code present when the code is reviewed or when a repository is browsed is essential to get a complete picture of how things work. * Differences in development environments between team members can easily be spotted and remedied. This further reduces the chance of "it works on my machine" type issues since everyone is running the same code. If you're using GitHub for source control, it's easy to verify that all generated files are checked in with the `ent/contrib/ci` GitHub Action. Otherwise, we supply a simple bash script that you can integrate in your existing CI flow. Simply add a file named `.github/workflows/ent-ci.yaml` in your repository: ```yaml name: EntCI on: push: # Run whenever code is changed in the master. branches: - master # Run on PRs where something changed under the `ent/` directory. pull_request: paths: - 'ent/*' jobs: ent: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3.0.1 - uses: actions/setup-go@v3 with: go-version: 1.18 - uses: ent/contrib/ci@master ``` ```bash go generate ./... status=$(git status --porcelain) if [ -n "$status" ]; then echo "you need to run 'go generate ./...' and commit the changes" echo "$status" exit 1 fi ``` ## Lint migration files Changes to your project's Ent schema almost always result in a modification of your database. 
If you are using [Versioned Migrations](/docs/versioned-migrations) to manage changes to your database schema, you can run [migration linting](https://atlasgo.io/versioned/lint) as part of your continuous integration flow. This is done for multiple reasons: * Linting replays your migration directory on a [database container](https://atlasgo.io/concepts/dev-database) to make sure all SQL statements are valid and in the correct order. * [Migration directory integrity](/docs/versioned-migrations#atlas-migration-directory-integrity-file) is enforced - ensuring that history wasn't accidentally changed and that migrations that are planned in parallel are unified to a clean linear history. * Destructive changes are detected notifying you of any potential data loss that may be caused by your migrations way before they reach your production database. * Linting detects data-dependant changes that _may_ fail upon deployment and require more careful review from your side. If you're using GitHub, you can use the [Official Atlas Action](https://github.com/ariga/atlas-action) to run migration linting during CI. Add `.github/workflows/atlas-ci.yaml` to your repo with the following contents: ```yaml name: Atlas CI on: # Run whenever code is changed in the master branch, # change this to your root branch. push: branches: - master # Run on PRs where something changed under the `ent/migrate/migrations/` directory. pull_request: paths: - 'ent/migrate/migrations/*' jobs: lint: services: # Spin up a mysql:8.0.29 container to be used as the dev-database for analysis. mysql: image: mysql:8.0.29 env: MYSQL_ROOT_PASSWORD: pass MYSQL_DATABASE: test ports: - 3306:3306 options: >- --health-cmd "mysqladmin ping -ppass" --health-interval 10s --health-start-period 10s --health-timeout 5s --health-retries 10 runs-on: ubuntu-latest steps: - uses: actions/checkout@v3.0.1 with: fetch-depth: 0 # Mandatory unless "latest" is set below. - uses: ariga/atlas-action@v0 with: dir: ent/migrate/migrations dir-format: golang-migrate # Or: atlas, goose, dbmate dev-url: mysql://root:pass@localhost:3306/test ``` ```yaml name: Atlas CI on: # Run whenever code is changed in the master branch, # change this to your root branch. push: branches: - master # Run on PRs where something changed under the `ent/migrate/migrations/` directory. pull_request: paths: - 'ent/migrate/migrations/*' jobs: lint: services: # Spin up a maria:10.7 container to be used as the dev-database for analysis. maria: image: mariadb:10.7 env: MYSQL_DATABASE: test MYSQL_ROOT_PASSWORD: pass ports: - 3306:3306 options: >- --health-cmd "mysqladmin ping -ppass" --health-interval 10s --health-start-period 10s --health-timeout 5s --health-retries 10 runs-on: ubuntu-latest steps: - uses: actions/checkout@v3.0.1 with: fetch-depth: 0 # Mandatory unless "latest" is set below. - uses: ariga/atlas-action@v0 with: dir: ent/migrate/migrations dir-format: golang-migrate # Or: atlas, goose, dbmate dev-url: maria://root:pass@localhost:3306/test ``` ```yaml name: Atlas CI on: # Run whenever code is changed in the master branch, # change this to your root branch. push: branches: - master # Run on PRs where something changed under the `ent/migrate/migrations/` directory. pull_request: paths: - 'ent/migrate/migrations/*' jobs: lint: services: # Spin up a postgres:10 container to be used as the dev-database for analysis. 
postgres: image: postgres:10 env: POSTGRES_DB: test POSTGRES_PASSWORD: pass ports: - 5432:5432 options: >- --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5 runs-on: ubuntu-latest steps: - uses: actions/checkout@v3.0.1 with: fetch-depth: 0 # Mandatory unless "latest" is set below. - uses: ariga/atlas-action@v0 with: dir: ent/migrate/migrations dir-format: golang-migrate # Or: atlas, goose, dbmate dev-url: postgres://postgres:pass@localhost:5432/test?sslmode=disable ``` ```yaml name: Atlas CI on: # Run whenever code is changed in the master branch, # change this to your root branch. push: branches: - master # Run on PRs where something changed under the `ent/migrate/migrations/` directory. pull_request: paths: - 'ent/migrate/migrations/*' jobs: lint: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3.0.1 with: fetch-depth: 0 # Mandatory unless "latest" is set below. - uses: ariga/atlas-action@v0 with: dir: ent/migrate/migrations dir-format: golang-migrate # Or: atlas, goose, dbmate dev-url: sqlite://./dev.db?_fk=1 ``` Notice that running `atlas migrate lint` requires a clean [dev-database](https://atlasgo.io/concepts/dev-database) which is provided by the `services` block in the example code above.ent-0.11.3/doc/md/code-gen.md000077500000000000000000000226751431500740500155500ustar00rootroot00000000000000--- id: code-gen title: Introduction --- ## Installation The project comes with a codegen tool called `ent`. In order to install `ent` run the following command: ```bash go get -d entgo.io/ent/cmd/ent ``` ## Initialize A New Schema In order to generate one or more schema templates, run `ent init` as follows: ```bash go run -mod=mod entgo.io/ent/cmd/ent init User Pet ``` `init` will create the 2 schemas (`user.go` and `pet.go`) under the `ent/schema` directory. If the `ent` directory does not exist, it will create it as well. The convention is to have an `ent` directory under the root directory of the project. ## Generate Assets After adding a few [fields](schema-fields.md) and [edges](schema-edges), you want to generate the assets for working with your entities. Run `ent generate` from the root directory of the project, or use `go generate`: ```bash go generate ./ent ``` The `generate` command generates the following assets for the schemas: - `Client` and `Tx` objects used for interacting with the graph. - CRUD builders for each schema type. See [CRUD](crud.md) for more info. - Entity object (Go struct) for each of the schema types. - Package containing constants and predicates used for interacting with the builders. - A `migrate` package for SQL dialects. See [Migration](migrate.md) for more info. - A `hook` package for adding mutation middlewares. See [Hooks](hooks.md) for more info. ## Version Compatibility Between `entc` And `ent` When working with `ent` CLI in a project, you want to make sure the version being used by the CLI is **identical** to the `ent` version used by your project. One of the options for achieving this is asking `go generate` to use the version mentioned in the `go.mod` file when running `ent`. 
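For instance, assuming a standard Go modules setup, you can check which `ent` version your module currently requires (the version shown below is illustrative):

```console
go list -m entgo.io/ent
entgo.io/ent v0.11.3
```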
If your project does not use [Go modules](https://github.com/golang/go/wiki/Modules#quick-start), setup one as follows: ```console go mod init ``` And then, re-run the following command in order to add `ent` to your `go.mod` file: ```console go get -d entgo.io/ent/cmd/ent ``` Add a `generate.go` file to your project under `/ent`: ```go package ent //go:generate go run -mod=mod entgo.io/ent/cmd/ent generate ./schema ``` Finally, you can run `go generate ./ent` from the root directory of your project in order to run `ent` code generation on your project schemas. ## Code Generation Options For more info about codegen options, run `ent generate -h`: ```console generate go code for the schema directory Usage: ent generate [flags] path Examples: ent generate ./ent/schema ent generate github.com/a8m/x Flags: --feature strings extend codegen with additional features --header string override codegen header -h, --help help for generate --idtype [int string] type of the id field (default int) --storage string storage driver to support in codegen (default "sql") --target string target directory for codegen --template strings external templates to execute ``` ## Storage Options `ent` can generate assets for both SQL and Gremlin dialect. The default dialect is SQL. ## External Templates `ent` accepts external Go templates to execute. If the template name already defined by `ent`, it will override the existing one. Otherwise, it will write the execution output to a file with the same name as the template. The flag format supports `file`, `dir` and `glob` as follows: ```console go run -mod=mod entgo.io/ent/cmd/ent generate --template --template glob="path/to/*.tmpl" ./ent/schema ``` More information and examples can be found in the [external templates doc](templates.md). ## Use `entc` as a Package Another option for running `ent` code generation is to create a file named `ent/entc.go` with the following content, and then the `ent/generate.go` file to execute it: ```go title="ent/entc.go" // +build ignore package main import ( "log" "entgo.io/ent/entc" "entgo.io/ent/entc/gen" "entgo.io/ent/schema/field" ) func main() { if err := entc.Generate("./schema", &gen.Config{}); err != nil { log.Fatal("running ent codegen:", err) } } ``` ```go title="ent/generate.go" package ent //go:generate go run -mod=mod entc.go ``` The full example exists in [GitHub](https://github.com/ent/ent/tree/master/examples/entcpkg). 
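The same entrypoint can also pass code-generation options programmatically. As a minimal sketch, the following variant enables the `sql/upsert` feature flag (see the [feature-flags page](features.md)) instead of passing `--feature` on the command line:

```go title="ent/entc.go"
// +build ignore

package main

import (
	"log"

	"entgo.io/ent/entc"
	"entgo.io/ent/entc/gen"
)

func main() {
	// Feature flags and other options are passed after the config.
	err := entc.Generate("./schema", &gen.Config{},
		entc.FeatureNames("sql/upsert"),
	)
	if err != nil {
		log.Fatal("running ent codegen:", err)
	}
}
```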
## Schema Description In order to get a description of your graph schema, run: ```bash go run -mod=mod entgo.io/ent/cmd/ent describe ./ent/schema ``` An example for the output is as follows: ```console Pet: +-------+---------+--------+----------+----------+---------+---------------+-----------+-----------------------+------------+ | Field | Type | Unique | Optional | Nillable | Default | UpdateDefault | Immutable | StructTag | Validators | +-------+---------+--------+----------+----------+---------+---------------+-----------+-----------------------+------------+ | id | int | false | false | false | false | false | false | json:"id,omitempty" | 0 | | name | string | false | false | false | false | false | false | json:"name,omitempty" | 0 | +-------+---------+--------+----------+----------+---------+---------------+-----------+-----------------------+------------+ +-------+------+---------+---------+----------+--------+----------+ | Edge | Type | Inverse | BackRef | Relation | Unique | Optional | +-------+------+---------+---------+----------+--------+----------+ | owner | User | true | pets | M2O | true | true | +-------+------+---------+---------+----------+--------+----------+ User: +-------+---------+--------+----------+----------+---------+---------------+-----------+-----------------------+------------+ | Field | Type | Unique | Optional | Nillable | Default | UpdateDefault | Immutable | StructTag | Validators | +-------+---------+--------+----------+----------+---------+---------------+-----------+-----------------------+------------+ | id | int | false | false | false | false | false | false | json:"id,omitempty" | 0 | | age | int | false | false | false | false | false | false | json:"age,omitempty" | 0 | | name | string | false | false | false | false | false | false | json:"name,omitempty" | 0 | +-------+---------+--------+----------+----------+---------+---------------+-----------+-----------------------+------------+ +------+------+---------+---------+----------+--------+----------+ | Edge | Type | Inverse | BackRef | Relation | Unique | Optional | +------+------+---------+---------+----------+--------+----------+ | pets | Pet | false | | O2M | false | true | +------+------+---------+---------+----------+--------+----------+ ``` ## Code Generation Hooks The `entc` package provides an option to add a list of hooks (middlewares) to the code-generation phase. This option is ideal for adding custom validators for the schema, or for generating additional assets using the graph schema. ```go // +build ignore package main import ( "fmt" "log" "reflect" "entgo.io/ent/entc" "entgo.io/ent/entc/gen" ) func main() { err := entc.Generate("./schema", &gen.Config{ Hooks: []gen.Hook{ EnsureStructTag("json"), }, }) if err != nil { log.Fatalf("running ent codegen: %v", err) } } // EnsureStructTag ensures all fields in the graph have a specific tag name. 
func EnsureStructTag(name string) gen.Hook { return func(next gen.Generator) gen.Generator { return gen.GenerateFunc(func(g *gen.Graph) error { for _, node := range g.Nodes { for _, field := range node.Fields { tag := reflect.StructTag(field.StructTag) if _, ok := tag.Lookup(name); !ok { return fmt.Errorf("struct tag %q is missing for field %s.%s", name, node.Name, field.Name) } } } return next.Generate(g) }) } } ``` ## External Dependencies In order to extend the generated client and builders under the `ent` package, and inject them external dependencies as struct fields, use the `entc.Dependency` option in your [`ent/entc.go`](#use-entc-as-a-package) file: ```go title="ent/entc.go" {3-12} func main() { opts := []entc.Option{ entc.Dependency( entc.DependencyType(&http.Client{}), ), entc.Dependency( entc.DependencyName("Writer"), entc.DependencyTypeInfo(&field.TypeInfo{ Ident: "io.Writer", PkgPath: "io", }), ), } if err := entc.Generate("./schema", &gen.Config{}, opts...); err != nil { log.Fatalf("running ent codegen: %v", err) } } ``` Then, use it in your application: ```go title="example_test.go" {5-6,15-16} func Example_Deps() { client, err := ent.Open( "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1", ent.Writer(os.Stdout), ent.HTTPClient(http.DefaultClient), ) if err != nil { log.Fatalf("failed opening connection to sqlite: %v", err) } defer client.Close() // An example for using the injected dependencies in the generated builders. client.User.Use(func(next ent.Mutator) ent.Mutator { return hook.UserFunc(func(ctx context.Context, m *ent.UserMutation) (ent.Value, error) { _ = m.HTTPClient _ = m.Writer return next.Mutate(ctx, m) }) }) // ... } ``` The full example exists in [GitHub](https://github.com/ent/ent/tree/master/examples/entcpkg). ## Feature Flags The `entc` package provides a collection of code-generation features that be added or removed using flags. For more information, please see the [features-flags page](features.md). ent-0.11.3/doc/md/contributors.md000066400000000000000000000670531431500740500166200ustar00rootroot00000000000000--- id: contributors title: Contributors --- Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/docs/en/emoji-key)):
- Ariel Mashraki: 🚧 📖 💻
- Alex Snast: 💻
- Rotem Tamir: 🚧 📖 💻
- Ciaran Liedeman: 💻
- Marwan Sulaiman: 💻
- Nathaniel Peiffer: 💻
- Travis Cline: 💻
- Jeremy: 💻
- aca: 💻
- BrentChesny: 💻 📖
- Giau. Tran Minh: 💻 👀
- Hylke Visser: 💻
- Pavel Kerbel: 💻
- zhangnan: 💻
- mori yuta: 💻 🌍 👀
- Christoph Hartmann: 💻
- Ruben de Vries: 💻
- Aleksandr Razumov: 💻
- apbuteau: 💻
- Harold.Luo: 💻
- ido shveki: 💻
- MasseElch: 💻
- Jian Li: 💻
- Noah-Jerome Lotzer: 💻
- danforth: 💻
- maxilozoz: 💻
- zzwx: 💻
- MengYX: 🌍
- mattn: 🌍
- Hugo Briand: 💻
- Dan Enman: 💻
- Rumen Nikiforov: 💻
- 陈杨文: 💻
- Qiaosen (Joeson) Huang: 🐛
- AlonDavidBehr: 💻 👀
- DuGlaser: 📖
- Shane Hanna: 📖
- Mahmudul Haque: 💻
- Benjamin Bourgeais: 💻
- 8ayac(Yoshinori Hayashi): 📖
- y-yagi: 📖
- Ben Woodward: 💻
- WzyJerry: 💻
- Tarrence van As: 📖 💻
- Yuya Sumie: 📖
- Michal Mazurek: 💻
- Takafumi Umemoto: 📖
- Khadija Sidhpuri: 💻
- Neel Modi: 💻
- Boris Shomodjvarac: 📖
- Sadman Sakib: 📖
- dakimura: 💻
- Risky Feryansyah: 💻
- seiichi: 💻
- Emmanuel T Odeke: 💻
- Hiroki Isogai: 📖
- 李清山: 💻
- s-takehana: 📖
- Kuiba: 💻
- storyicon: 💻
- Evan Lurvey: 💻
- Brian: 📖
- Shen Yang: 💻
- sivchari: 💻
- mook: 💻
- heliumbrain: 📖
- Jeremy Maxey-Vesperman: 💻 📖
- Christopher Schmitt: 📖
- Gerardo Reyes: 💻
- Naor Matania: 💻
- idc77: 📖
- Sungyun Hur: 📖
- peanut-pg: 📖
- Mehmet Yılmaz: 💻
- Roman Maklakov: 💻
- Genevieve: 💻
- Clarence: 💻
- Nicholas Anderson: 💻
- Zhizhen He: 💻
- Pedro Henrique: 💻
- MrParano1d: 💻
- Thomas Prebble: 💻
- Huy TQ: 💻
- maorlipchuk: 💻
- Motonori Iwata: 📖
- Charles Ge: 💻
- Thomas Meitz: 💻 📖
- Justin Johnson: 💻
- hax10: 💻
- water-a: 🐛
- jhwz: 📖
This project follows the [all-contributors](https://github.com/all-contributors/all-contributors) specification. Contributions of any kind welcome! ent-0.11.3/doc/md/crud.md000077500000000000000000000232441431500740500150150ustar00rootroot00000000000000--- id: crud title: CRUD API --- As mentioned in the [introduction](code-gen.md) section, running `ent` on the schemas, will generate the following assets: - `Client` and `Tx` objects used for interacting with the graph. - CRUD builders for each schema type. See [CRUD](crud.md) for more info. - Entity object (Go struct) for each of the schema type. - Package containing constants and predicates used for interacting with the builders. - A `migrate` package for SQL dialects. See [Migration](migrate.md) for more info. ## Create A New Client **MySQL** ```go package main import ( "log" "/ent" _ "github.com/go-sql-driver/mysql" ) func main() { client, err := ent.Open("mysql", ":@tcp(:)/?parseTime=True") if err != nil { log.Fatal(err) } defer client.Close() } ``` **PostgreSQL** ```go package main import ( "log" "/ent" _ "github.com/lib/pq" ) func main() { client, err := ent.Open("postgres","host= port= user= dbname= password=") if err != nil { log.Fatal(err) } defer client.Close() } ``` **SQLite** ```go package main import ( "log" "/ent" _ "github.com/mattn/go-sqlite3" ) func main() { client, err := ent.Open("sqlite3", "file:ent?mode=memory&cache=shared&_fk=1") if err != nil { log.Fatal(err) } defer client.Close() } ``` **Gremlin (AWS Neptune)** ```go package main import ( "log" "/ent" ) func main() { client, err := ent.Open("gremlin", "http://localhost:8182") if err != nil { log.Fatal(err) } } ``` ## Create An Entity **Save** a user. ```go a8m, err := client.User. // UserClient. Create(). // User create builder. SetName("a8m"). // Set field value. SetNillableAge(age). // Avoid nil checks. AddGroups(g1, g2). // Add many edges. SetSpouse(nati). // Set unique edge. Save(ctx) // Create and return. ``` **SaveX** a pet; Unlike **Save**, **SaveX** panics if an error occurs. ```go pedro := client.Pet. // PetClient. Create(). // Pet create builder. SetName("pedro"). // Set field value. SetOwner(a8m). // Set owner (unique edge). SaveX(ctx) // Create and return. ``` ## Create Many **Save** a bulk of pets. ```go names := []string{"pedro", "xabi", "layla"} bulk := make([]*ent.PetCreate, len(names)) for i, name := range names { bulk[i] = client.Pet.Create().SetName(name).SetOwner(a8m) } pets, err := client.Pet.CreateBulk(bulk...).Save(ctx) ``` ## Update One Update an entity that was returned from the database. ```go a8m, err = a8m.Update(). // User update builder. RemoveGroup(g2). // Remove specific edge. ClearCard(). // Clear unique edge. SetAge(30). // Set field value Save(ctx) // Save and return. ``` ## Update By ID ```go pedro, err := client.Pet. // PetClient. UpdateOneID(id). // Pet update builder. SetName("pedro"). // Set field name. SetOwnerID(owner). // Set unique edge, using id. Save(ctx) // Save and return. ``` ## Update Many Filter using predicates. ```go n, err := client.User. // UserClient. Update(). // Pet update builder. Where( // user.Or( // (age >= 30 OR name = "bar") user.AgeGT(30), // user.Name("bar"), // AND ), // user.HasFollowers(), // UserHasFollowers() ). // SetName("foo"). // Set field name. Save(ctx) // exec and return. ``` Query edge-predicates. ```go n, err := client.User. // UserClient. Update(). // Pet update builder. 
Where( // user.HasFriendsWith( // UserHasFriendsWith ( user.Or( // age = 20 user.Age(20), // OR user.Age(30), // age = 30 ) // ) ), // ). // SetName("a8m"). // Set field name. Save(ctx) // exec and return. ``` ## Upsert One Ent supports [upsert](https://en.wikipedia.org/wiki/Merge_(SQL)) records using the [`sql/upsert`](features.md#upsert) feature-flag. ```go err := client.User. Create(). SetAge(30). SetName("Ariel"). OnConflict(). // Use the new values that were set on create. UpdateNewValues(). Exec(ctx) id, err := client.User. Create(). SetAge(30). SetName("Ariel"). OnConflict(). // Use the "age" that was set on create. UpdateAge(). // Set a different "name" in case of conflict. SetName("Mashraki"). ID(ctx) // Customize the UPDATE clause. err := client.User. Create(). SetAge(30). SetName("Ariel"). OnConflict(). UpdateNewValues(). // Override some of the fields with a custom update. Update(func(u *ent.UserUpsert) { u.SetAddress("localhost") u.AddCount(1) u.ClearPhone() }). Exec(ctx) ``` In PostgreSQL, the [conflict target](https://www.postgresql.org/docs/current/sql-insert.html#SQL-ON-CONFLICT) is required: ```go // Setting the column names using the fluent API. err := client.User. Create(). SetName("Ariel"). OnConflictColumns(user.FieldName). UpdateNewValues(). Exec(ctx) // Setting the column names using the SQL API. err := client.User. Create(). SetName("Ariel"). OnConflict( sql.ConflictColumns(user.FieldName), ). UpdateNewValues(). Exec(ctx) // Setting the constraint name using the SQL API. err := client.User. Create(). SetName("Ariel"). OnConflict( sql.ConflictConstraint(constraint), ). UpdateNewValues(). Exec(ctx) ``` In order to customize the executed statement, use the SQL API: ```go id, err := client.User. Create(). OnConflict( sql.ConflictColumns(...), sql.ConflictWhere(...), sql.UpdateWhere(...), ). Update(func(u *ent.UserUpsert) { u.SetAge(30) u.UpdateName() }). ID(ctx) // INSERT INTO "users" (...) VALUES (...) ON CONFLICT WHERE ... DO UPDATE SET ... WHERE ... ``` :::info Since the upsert API is implemented using the `ON CONFLICT` clause (and `ON DUPLICATE KEY` in MySQL), Ent executes only one statement to the database, and therefore, only create [hooks](hooks.md) are applied for such operations. ::: ## Upsert Many ```go err := client.User. // UserClient CreateBulk(builders...). // User bulk create. OnConflict(). // User bulk upsert. UpdateNewValues(). // Use the values that were set on create in case of conflict. Exec(ctx) // Execute the statement. ``` ## Query The Graph Get all users with followers. ```go users, err := client.User. // UserClient. Query(). // User query builder. Where(user.HasFollowers()). // filter only users with followers. All(ctx) // query and return. ``` Get all followers of a specific user; Start the traversal from a node in the graph. ```go users, err := a8m. QueryFollowers(). All(ctx) ``` Get all pets of the followers of a user. ```go users, err := a8m. QueryFollowers(). QueryPets(). All(ctx) ``` Count the number of posts without comments. ```go n, err := client.Post. Query(). Where( post.Not( post.HasComments(), ) ). Count(ctx) ``` More advance traversals can be found in the [next section](traversals.md). ## Field Selection Get all pet names. ```go names, err := client.Pet. Query(). Select(pet.FieldName). Strings(ctx) ``` Get all unique pet names. ```go names, err := client.Pet. Query(). Unique(true). Select(pet.FieldName). Strings(ctx) ``` Count the number of unique pet names. ```go n, err := client.Pet. Query(). Unique(true). Select(pet.FieldName). 
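	// Unique(true) above makes the selection DISTINCT,
	// so each name is counted only once.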
Count(ctx) ``` Select partial objects and partial associations.gs Get all pets and their owners, but select and fill only the `ID` and `Name` fields. ```go pets, err := client.Pet. Query(). Select(pet.FieldName). WithOwner(func (q *ent.UserQuery) { q.Select(user.FieldName) }). All(ctx) ``` Scan all pet names and ages to custom struct. ```go var v []struct { Age int `json:"age"` Name string `json:"name"` } err := client.Pet. Query(). Select(pet.FieldAge, pet.FieldName). Scan(ctx, &v) if err != nil { log.Fatal(err) } ``` Update an entity and return a partial of it. ```go pedro, err := client.Pet. UpdateOneID(id). SetAge(9). SetName("pedro"). // Select allows selecting one or more fields (columns) of the returned entity. // The default is selecting all fields defined in the entity schema. Select(pet.FieldName). Save(ctx) ``` ## Delete One Delete an entity. ```go err := client.User. DeleteOne(a8m). Exec(ctx) ``` Delete by ID. ```go err := client.User. DeleteOneID(id). Exec(ctx) ``` ## Delete Many Delete using predicates. ```go _, err := client.File. Delete(). Where(file.UpdatedAtLT(date)). Exec(ctx) ``` ## Mutation Each generated node type has its own type of mutation. For example, all [`User` builders](crud.md#create-an-entity), share the same generated `UserMutation` object. However, all builder types implement the generic `ent.Mutation` interface. For example, in order to write a generic code that apply a set of methods on both `ent.UserCreate` and `ent.UserUpdate`, use the `UserMutation` object: ```go func Do() { creator := client.User.Create() SetAgeName(creator.Mutation()) updater := client.User.UpdateOneID(id) SetAgeName(updater.Mutation()) } // SetAgeName sets the age and the name for any mutation. func SetAgeName(m *ent.UserMutation) { m.SetAge(32) m.SetName("Ariel") } ``` In some cases, you want to apply a set of methods on multiple types. For cases like this, either use the generic `ent.Mutation` interface, or create your own interface. ```go func Do() { creator1 := client.User.Create() SetName(creator1.Mutation(), "a8m") creator2 := client.Pet.Create() SetName(creator2.Mutation(), "pedro") } // SetNamer wraps the 2 methods for getting // and setting the "name" field in mutations. type SetNamer interface { SetName(string) Name() (string, bool) } func SetName(m SetNamer, name string) { if _, exist := m.Name(); !exist { m.SetName(name) } } ``` ent-0.11.3/doc/md/dialects.md000077500000000000000000000033031431500740500156420ustar00rootroot00000000000000--- id: dialects title: Supported Dialects --- ## MySQL MySQL supports all the features that are mentioned in the [Migration](migrate.md) section, and it's being tested constantly on the following 3 versions: `5.6.35`, `5.7.26` and `8`. ## MariaDB MariaDB supports all the features that are mentioned in the [Migration](migrate.md) section, and it's being tested constantly on the following 3 versions: `10.2`, `10.3` and latest version. ## PostgreSQL PostgreSQL supports all the features that are mentioned in the [Migration](migrate.md) section, and it's being tested constantly on the following 5 versions: `10`, `11`, `12`, `13` and `14`. ## CockroachDB **(preview)** CockroachDB support is in preview and requires the [Atlas migration engine](#atlas-integration). The integration with CRDB is currently tested on versions `v21.2.11`. ## SQLite SQLite supports all _"append-only"_ features mentioned in the [Migration](migrate.md) section. 
However, dropping or modifying resources, like [drop-index](migrate.md#drop-resources) are not supported by default by SQLite, and will be added in the future using a [temporary table](https://www.sqlite.org/lang_altertable.html#otheralter). ## Gremlin Gremlin does not support migration nor indexes, and **it's considered experimental**. ## TiDB **(preview)** TiDB support is in preview and requires the [Atlas migration engine](#atlas-integration). TiDB is MySQL compatible and thus any feature that works on MySQL _should_ work on TiDB as well. For a list of known compatibility issues, visit: https://docs.pingcap.com/tidb/stable/mysql-compatibility The integration with TiDB is currently tested on versions `5.4.0`, `6.0.0`. ent-0.11.3/doc/md/eager-load.mdx000066400000000000000000000104211431500740500162360ustar00rootroot00000000000000--- id: eager-load title: Eager Loading --- import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; ## Overview `ent` supports querying entities with their associations (through their edges). The associated entities are populated to the `Edges` field in the returned object. Let's give an example of what the API looks like for the following schema: ![er-group-users](https://entgo.io/images/assets/er_user_pets_groups.png) **Query all users with their pets:** ```go users, err := client.User. Query(). WithPets(). All(ctx) if err != nil { return err } // The returned users look as follows: // // [ // User { // ID: 1, // Name: "a8m", // Edges: { // Pets: [Pet(...), ...] // ... // } // }, // ... // ] // for _, u := range users { for _, p := range u.Edges.Pets { fmt.Printf("User(%v) -> Pet(%v)\n", u.ID, p.ID) // Output: // User(...) -> Pet(...) } } ``` Eager loading allows to query more than one association (including nested), and also filter, sort or limit their result. For example: ```go admins, err := client.User. Query(). Where(user.Admin(true)). // Populate the `pets` that associated with the `admins`. WithPets(). // Populate the first 5 `groups` that associated with the `admins`. WithGroups(func(q *ent.GroupQuery) { q.Limit(5) // Limit to 5. q.WithUsers() // Populate the `users` of each `groups`. }). All(ctx) if err != nil { return err } // The returned users look as follows: // // [ // User { // ID: 1, // Name: "admin1", // Edges: { // Pets: [Pet(...), ...] // Groups: [ // Group { // ID: 7, // Name: "GitHub", // Edges: { // Users: [User(...), ...] // ... // } // } // ] // } // }, // ... // ] // for _, admin := range admins { for _, p := range admin.Edges.Pets { fmt.Printf("Admin(%v) -> Pet(%v)\n", u.ID, p.ID) // Output: // Admin(...) -> Pet(...) } for _, g := range admin.Edges.Groups { for _, u := range g.Edges.Users { fmt.Printf("Admin(%v) -> Group(%v) -> User(%v)\n", u.ID, g.ID, u.ID) // Output: // Admin(...) -> Group(...) -> User(...) } } } ``` ## API Each query-builder has a list of methods in the form of `With(...func(Query))` for each of its edges. `` stands for the edge name (like, `WithGroups`) and `` for the edge type (like, `GroupQuery`). Note that only SQL dialects support this feature. ## Named Edges In some cases there is a need for preloading edges with custom names. For example, a GraphQL query that has two aliases referencing the same edge with different arguments. For this situation, Ent provides another API named `WithNamed` that can be enabled using the [`namedges`](features.md#named-edges) feature-flag and seamlessly integrated with [EntGQL Fields Collection](tutorial-todo-gql-field-collection.md). 
See the GraphQL tab to learn more about the motivation behind this API. ```go posts, err := client.Post.Query(). WithNamedComments("published", func(q *ent.CommentQuery) { q.Where(comment.StatusEQ(comment.StatusPublished)) }) WithNamedComments("draft", func(q *ent.CommentQuery) { q.Where(comment.StatusEQ(comment.StatusDraft)) }). Paginate(...) // Get the preloaded edges by their name: for _, p := range posts { published, err := p.Edges.NamedComments("published") if err != nil { return err } draft, err := p.Edges.NamedComments("draft") if err != nil { return err } } ``` An example of a GraphQL query that has two aliases referencing the same edge with different arguments. ```graphql query { posts { id title published: comments(where: { status: PUBLISHED }) { edges { node { text } } } draft: comments(where: { status: DRAFT }) { edges { node { text } } } } } ``` ## Implementation Since an Ent query can eager-load more than one edge, it is not possible to load all associations in a single `JOIN` operation. Therefore, Ent executes additional query to load each association. This expected to be optimized in future versions. ent-0.11.3/doc/md/extension.md000077500000000000000000000203651431500740500160750ustar00rootroot00000000000000--- id: extensions title: Extensions --- ### Introduction The Ent [Extension API](https://pkg.go.dev/entgo.io/ent/entc#Extension) facilitates the creation of code-generation extensions that bundle together [codegen hooks](code-gen.md#code-generation-hooks), [templates](templates.md) and [annotations](templates.md#annotations) to create reusable components that add new rich functionality to Ent's core. For example, Ent's [entgql plugin](https://pkg.go.dev/entgo.io/contrib/entgql#Extension) exposes an `Extension` that automatically generates GraphQL servers from an Ent schema. ### Defining a New Extension All extension's must implement the [Extension](https://pkg.go.dev/entgo.io/ent/entc#Extension) interface: ```go type Extension interface { // Hooks holds an optional list of Hooks to apply // on the graph before/after the code-generation. Hooks() []gen.Hook // Annotations injects global annotations to the gen.Config object that // can be accessed globally in all templates. Unlike schema annotations, // being serializable to JSON raw value is not mandatory. // // {{- with $.Config.Annotations.GQL }} // {{/* Annotation usage goes here. */}} // {{- end }} // Annotations() []Annotation // Templates specifies a list of alternative templates // to execute or to override the default. Templates() []*gen.Template // Options specifies a list of entc.Options to evaluate on // the gen.Config before executing the code generation. Options() []Option } ``` To simplify the development of new extensions, developers can embed [entc.DefaultExtension](https://pkg.go.dev/entgo.io/ent/entc#DefaultExtension) to create extensions without implementing all methods: ```go package hello // GreetExtension implements entc.Extension. type GreetExtension struct { entc.DefaultExtension } ``` ### Adding Templates Ent supports adding [external templates](templates.md) that will be rendered during code generation. To bundle such external templates on an extension, implement the `Templates` method: ```gotemplate title="templates/greet.tmpl" {{/* Tell Intellij/GoLand to enable the autocompletion based on the *gen.Graph type. 
*/}} {{/* gotype: entgo.io/ent/entc/gen.Graph */}} {{ define "greet" }} {{/* Add the base header for the generated file */}} {{ $pkg := base $.Config.Package }} {{ template "header" $ }} {{/* Loop over all nodes and add the Greet method */}} {{ range $n := $.Nodes }} {{ $receiver := $n.Receiver }} func ({{ $receiver }} *{{ $n.Name }}) Greet() string { return "Hello, {{ $n.Name }}" } {{ end }} {{ end }} ``` ```go func (*GreetExtension) Templates() []*gen.Template { return []*gen.Template{ gen.MustParse(gen.NewTemplate("greet").ParseFiles("templates/greet.tmpl")), } } ``` ### Adding Global Annotations Annotations are a convenient way to supply users of our extension with an API to modify the behavior of code generation. To add annotations to our extension, implement the `Annotations` method. Let's say in our `GreetExtension` we want to provide users with the ability to configure the greeting word in the generated code: ```go // GreetingWord implements entc.Annotation. type GreetingWord string // Name of the annotation. Used by the codegen templates. func (GreetingWord) Name() string { return "GreetingWord" } ``` Then add it to the `GreetExtension` struct: ```go type GreetExtension struct { entc.DefaultExtension word GreetingWord } ``` Next, implement the `Annotations` method: ```go func (s *GreetExtension) Annotations() []entc.Annotation { return []entc.Annotation{ s.word, } } ``` Now, from within your templates you can access the `GreetingWord` annotation: ```gotemplate func ({{ $receiver }} *{{ $n.Name }}) Greet() string { return "{{ $.Annotations.GreetingWord }}, {{ $n.Name }}" } ``` ### Adding Hooks The entc package provides an option to add a list of [hooks](code-gen.md#code-generation-hooks) (middlewares) to the code-generation phase. This option is ideal for adding custom validators for the schema, or for generating additional assets using the graph schema. To bundle code generation hooks with your extension, implement the `Hooks` method: ```go func (s *GreetExtension) Hooks() []gen.Hook { return []gen.Hook{ DisallowTypeName("Shalom"), } } // DisallowTypeName ensures there is no ent.Schema with the given name in the graph. func DisallowTypeName(name string) gen.Hook { return func(next gen.Generator) gen.Generator { return gen.GenerateFunc(func(g *gen.Graph) error { for _, node := range g.Nodes { if node.Name == name { return fmt.Errorf("entc: validation failed, type named %q not allowed", name) } } return next.Generate(g) }) } } ``` ### Using an Extension in Code Generation To use an extension in our code-generation configuration, use `entc.Extensions`, a helper method that returns an `entc.Option` that applies our chosen extensions: ```go title="ent/entc.go" //+build ignore package main import ( "fmt" "log" "entgo.io/ent/entc" "entgo.io/ent/entc/gen" ) func main() { err := entc.Generate("./schema", &gen.Config{}, entc.Extensions(&GreetExtension{ word: GreetingWord("Shalom"), }), ) if err != nil { log.Fatal("running ent codegen:", err) } } ``` ### Popular Extensions - **[elk (discontinued)](https://github.com/masseelch/elk)** `elk` is an extension that generates RESTful API endpoints from Ent schemas. The extension generates HTTP CRUD handlers from the Ent schema, as well as an OpenAPI JSON file. By using it, you can easily build a RESTful HTTP server for your application. Please note, that `elk` has been discontinued in favor of `entoas`. An implementation generator is in the works. 
Read [this blog post](https://entgo.io/blog/2021/07/29/generate-a-fully-working-go-crud-http-api-with-ent) on how to work with `elk`, and [this blog post](https://entgo.io/blog/2021/09/10/openapi-generator) on how to generate an [OpenAPI Specification](https://swagger.io/resources/open-api/).

- **[entoas](https://github.com/ent/contrib/tree/master/entoas)**

  `entoas` originates from `elk` and was ported into its own extension; it is now the official generator for an opinionated OpenAPI Specification document. You can use it to rapidly develop and document a RESTful HTTP server. A new extension providing a generated `ent` implementation for the document produced by `entoas` will be released soon.

- **[entgql](https://github.com/ent/contrib/tree/master/entgql)**

  This extension helps users build [GraphQL](https://graphql.org/) servers from Ent schemas. `entgql` integrates with [gqlgen](https://github.com/99designs/gqlgen), a popular, schema-first Go library for building GraphQL servers. The extension includes the generation of type-safe GraphQL filters, which enable users to effortlessly map GraphQL queries to Ent queries. Follow [this tutorial](https://entgo.io/docs/tutorial-todo-gql) to get started.

- **[entproto](https://github.com/ent/contrib/tree/master/entproto)**

  `entproto` generates Protobuf message definitions and gRPC service definitions from Ent schemas. The project also includes `protoc-gen-entgrpc`, a `protoc` (Protobuf compiler) plugin that generates a working implementation of the gRPC service definition produced by entproto. In this manner, we can easily create a gRPC server that can serve requests to our service without writing any code (aside from defining the Ent schema)! To learn how to use and set up `entproto`, read [this tutorial](https://entgo.io/docs/grpc-intro). For more background, read [this blog post](https://entgo.io/blog/2021/03/18/generating-a-grpc-server-with-ent), or [this blog post](https://entgo.io/blog/2021/06/28/gprc-ready-for-use/) discussing more `entproto` features.

- **[entviz](https://github.com/hedwigz/entviz)**

  `entviz` generates visual diagrams from Ent schemas. These diagrams visualize the schema in a web browser, and stay updated as you continue coding. `entviz` can be configured so that every time you regenerate the schema, the diagram is automatically updated, making it easy to view the changes being made. Learn how to integrate `entviz` into your project in [this blog post](https://entgo.io/blog/2021/08/26/visualizing-your-data-graph-using-entviz).
---
id: faq
title: Frequently Asked Questions (FAQ)
sidebar_label: FAQ
---

## Questions

[How to create an entity from a struct `T`?](#how-to-create-an-entity-from-a-struct-t)
[How to create a struct (or a mutation) level validator?](#how-to-create-a-mutation-level-validator)
[How to write an audit-log extension?](#how-to-write-an-audit-log-extension)
[How to write custom predicates?](#how-to-write-custom-predicates)
[How to add custom predicates to the codegen assets?](#how-to-add-custom-predicates-to-the-codegen-assets)
[How to define a network address field in PostgreSQL?](#how-to-define-a-network-address-field-in-postgresql)
[How to customize time fields to type `DATETIME` in MySQL?](#how-to-customize-time-fields-to-type-datetime-in-mysql)
[How to use a custom generator of IDs?](#how-to-use-a-custom-generator-of-ids)
[How to use a custom XID globally unique ID?](#how-to-use-a-custom-xid-globally-unique-id)
[How to define a spatial data type field in MySQL?](#how-to-define-a-spatial-data-type-field-in-mysql)
[How to extend the generated models?](#how-to-extend-the-generated-models)
[How to extend the generated builders?](#how-to-extend-the-generated-builders)
[How to store Protobuf objects in a BLOB column?](#how-to-store-protobuf-objects-in-a-blob-column)
[How to add `CHECK` constraints to table?](#how-to-add-check-constraints-to-table)
[How to define a custom precision numeric field?](#how-to-define-a-custom-precision-numeric-field)
[How to configure two or more `DB` to separate read and write?](#how-to-configure-two-or-more-db-to-separate-read-and-write)
[How to change the character set and/or collation of a MySQL table?](#how-to-change-the-character-set-andor-collation-of-a-mysql-table)

## Answers

#### How to create an entity from a struct `T`?

The different builders don't support the option of setting the entity fields (or edges) from a given struct `T`. The reason is that there's no way to distinguish between zero/real values when updating the database (for example, `&ent.T{Age: 0, Name: ""}`). Setting these values may set incorrect values in the database or update unnecessary columns.

However, the [external template](templates.md) option lets you extend the default code-generation assets by adding custom logic. For example, in order to generate a method for each of the create-builders that accepts a struct as input and configures the builder, use the following template:

```gotemplate
{{ range $n := $.Nodes }}
    {{ $builder := $n.CreateName }}
    {{ $receiver := receiver $builder }}

    func ({{ $receiver }} *{{ $builder }}) Set{{ $n.Name }}(input *{{ $n.Name }}) *{{ $builder }} {
        {{- range $f := $n.Fields }}
            {{- $setter := print "Set" $f.StructField }}
            {{ $receiver }}.{{ $setter }}(input.{{ $f.StructField }})
        {{- end }}
        return {{ $receiver }}
    }
{{ end }}
```

#### How to create a mutation level validator?

In order to implement a mutation-level validator, you can either use [schema hooks](hooks.md#schema-hooks) for validating changes applied on one entity type, or use [transaction hooks](transactions.md#hooks) for validating mutations that are being applied on multiple entity types (e.g. a GraphQL mutation). For example:

```go
// A VersionHook is a dummy example for a hook that validates the "version" field
// is incremented by 1 on each update. Note that this is just a dummy example, and
// it doesn't promise consistency in the database.
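// A production-grade version of this hook would also push the check down to the
// storage layer, e.g. by adding a SQL predicate such as WHERE "version" = oldV
// to the update statement, so concurrent writers cannot race this validation.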
func VersionHook() ent.Hook { type OldSetVersion interface { SetVersion(int) Version() (int, bool) OldVersion(context.Context) (int, error) } return func(next ent.Mutator) ent.Mutator { return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) { ver, ok := m.(OldSetVersion) if !ok { return next.Mutate(ctx, m) } oldV, err := ver.OldVersion(ctx) if err != nil { return nil, err } curV, exists := ver.Version() if !exists { return nil, fmt.Errorf("version field is required in update mutation") } if curV != oldV+1 { return nil, fmt.Errorf("version field must be incremented by 1") } // Add an SQL predicate that validates the "version" column is equal // to "oldV" (ensure it wasn't changed during the mutation by others). return next.Mutate(ctx, m) }) } } ``` #### How to write an audit-log extension? The preferred way for writing such an extension is to use [ent.Mixin](schema-mixin.md). Use the `Fields` option for setting the fields that are shared between all schemas that import the mixed-schema, and use the `Hooks` option for attaching a mutation-hook for all mutations that are being applied on these schemas. Here's an example, based on a discussion in the [repository issue-tracker](https://github.com/ent/ent/issues/830): ```go // AuditMixin implements the ent.Mixin for sharing // audit-log capabilities with package schemas. type AuditMixin struct{ mixin.Schema } // Fields of the AuditMixin. func (AuditMixin) Fields() []ent.Field { return []ent.Field{ field.Time("created_at"). Immutable(). Default(time.Now), field.Int("created_by"). Optional(), field.Time("updated_at"). Default(time.Now). UpdateDefault(time.Now), field.Int("updated_by"). Optional(), } } // Hooks of the AuditMixin. func (AuditMixin) Hooks() []ent.Hook { return []ent.Hook{ hooks.AuditHook, } } // A AuditHook is an example for audit-log hook. func AuditHook(next ent.Mutator) ent.Mutator { // AuditLogger wraps the methods that are shared between all mutations of // schemas that embed the AuditLog mixin. The variable "exists" is true, if // the field already exists in the mutation (e.g. was set by a different hook). type AuditLogger interface { SetCreatedAt(time.Time) CreatedAt() (value time.Time, exists bool) SetCreatedBy(int) CreatedBy() (id int, exists bool) SetUpdatedAt(time.Time) UpdatedAt() (value time.Time, exists bool) SetUpdatedBy(int) UpdatedBy() (id int, exists bool) } return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) { ml, ok := m.(AuditLogger) if !ok { return nil, fmt.Errorf("unexpected audit-log call from mutation type %T", m) } usr, err := viewer.UserFromContext(ctx) if err != nil { return nil, err } switch op := m.Op(); { case op.Is(ent.OpCreate): ml.SetCreatedAt(time.Now()) if _, exists := ml.CreatedBy(); !exists { ml.SetCreatedBy(usr.ID) } case op.Is(ent.OpUpdateOne | ent.OpUpdate): ml.SetUpdatedAt(time.Now()) if _, exists := ml.UpdatedBy(); !exists { ml.SetUpdatedBy(usr.ID) } } return next.Mutate(ctx, m) }) } ``` #### How to write custom predicates? Users can provide custom predicates to apply on the query before it's executed. For example: ```go pets := client.Pet. Query(). Where(predicate.Pet(func(s *sql.Selector) { s.Where(sql.InInts(pet.OwnerColumn, 1, 2, 3)) })). AllX(ctx) users := client.User. Query(). Where(predicate.User(func(s *sql.Selector) { s.Where(sqljson.ValueContains(user.FieldTags, "tag")) })). 
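	// sqljson.ValueContains builds a dialect-aware predicate that checks whether
	// the JSON value stored in the "tags" column contains the value "tag".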
AllX(ctx) ``` For more examples, go to the [predicates](predicates.md#custom-predicates) page, or search in the repository issue-tracker for more advance examples like [issue-842](https://github.com/ent/ent/issues/842#issuecomment-707896368). #### How to add custom predicates to the codegen assets? The [template](templates.md) option enables the capability for extending or overriding the default codegen assets. In order to generate a type-safe predicate for the [example above](#how-to-write-custom-predicates), use the template option for doing it as follows: ```gotemplate {{/* A template that adds the "Glob" predicate for all string fields. */}} {{ define "where/additional/strings" }} {{ range $f := $.Fields }} {{ if $f.IsString }} {{ $func := print $f.StructField "Glob" }} // {{ $func }} applies the Glob predicate on the {{ quote $f.Name }} field. func {{ $func }}(pattern string) predicate.{{ $.Name }} { return predicate.{{ $.Name }}(func(s *sql.Selector) { s.Where(sql.P(func(b *sql.Builder) { b.Ident(s.C({{ $f.Constant }})).WriteString(" glob" ).Arg(pattern) })) }) } {{ end }} {{ end }} {{ end }} ``` #### How to define a network address field in PostgreSQL? The [GoType](schema-fields.md#go-type) and the [SchemaType](schema-fields.md#database-type) options allow users to define database-specific fields. For example, in order to define a [`macaddr`](https://www.postgresql.org/docs/13/datatype-net-types.html#DATATYPE-MACADDR) field, use the following configuration: ```go func (T) Fields() []ent.Field { return []ent.Field{ field.String("mac"). GoType(&MAC{}). SchemaType(map[string]string{ dialect.Postgres: "macaddr", }). Validate(func(s string) error { _, err := net.ParseMAC(s) return err }), } } // MAC represents a physical hardware address. type MAC struct { net.HardwareAddr } // Scan implements the Scanner interface. func (m *MAC) Scan(value any) (err error) { switch v := value.(type) { case nil: case []byte: m.HardwareAddr, err = net.ParseMAC(string(v)) case string: m.HardwareAddr, err = net.ParseMAC(v) default: err = fmt.Errorf("unexpected type %T", v) } return } // Value implements the driver Valuer interface. func (m MAC) Value() (driver.Value, error) { return m.HardwareAddr.String(), nil } ``` Note that, if the database doesn't support the `macaddr` type (e.g. SQLite on testing), the field fallback to its native type (i.e. `string`). `inet` example: ```go func (T) Fields() []ent.Field { return []ent.Field{ field.String("ip"). GoType(&Inet{}). SchemaType(map[string]string{ dialect.Postgres: "inet", }). Validate(func(s string) error { if net.ParseIP(s) == nil { return fmt.Errorf("invalid value for ip %q", s) } return nil }), } } // Inet represents a single IP address type Inet struct { net.IP } // Scan implements the Scanner interface func (i *Inet) Scan(value any) (err error) { switch v := value.(type) { case nil: case []byte: if i.IP = net.ParseIP(string(v)); i.IP == nil { err = fmt.Errorf("invalid value for ip %q", v) } case string: if i.IP = net.ParseIP(v); i.IP == nil { err = fmt.Errorf("invalid value for ip %q", v) } default: err = fmt.Errorf("unexpected type %T", v) } return } // Value implements the driver Valuer interface func (i Inet) Value() (driver.Value, error) { return i.IP.String(), nil } ``` #### How to customize time fields to type `DATETIME` in MySQL? 
`Time` fields use the MySQL `TIMESTAMP` type in the schema creation by default, and this type has a range of '1970-01-01 00:00:01' UTC to '2038-01-19 03:14:07' UTC (see, [MySQL docs](https://dev.mysql.com/doc/refman/5.6/en/datetime.html)). In order to customize time fields for a wider range, use the MySQL `DATETIME` as follows: ```go field.Time("birth_date"). Optional(). SchemaType(map[string]string{ dialect.MySQL: "datetime", }), ``` #### How to use a custom generator of IDs? If you're using a custom ID generator instead of using auto-incrementing IDs in your database (e.g. Twitter's [Snowflake](https://github.com/twitter-archive/snowflake/tree/snowflake-2010)), you will need to write a custom ID field which automatically calls the generator on resource creation. To achieve this, you can either make use of `DefaultFunc` or of schema hooks - depending on your use case. If the generator does not return an error, `DefaultFunc` is more concise, whereas setting a hook on resource creation will allow you to capture errors as well. An example of how to use `DefaultFunc` can be seen in the section regarding [the ID field](schema-fields.md#id-field). Here is an example of how to use a custom generator with hooks, taking as an example [sonyflake](https://github.com/sony/sonyflake). ```go // BaseMixin to be shared will all different schemas. type BaseMixin struct { mixin.Schema } // Fields of the Mixin. func (BaseMixin) Fields() []ent.Field { return []ent.Field{ field.Uint64("id"), } } // Hooks of the Mixin. func (BaseMixin) Hooks() []ent.Hook { return []ent.Hook{ hook.On(IDHook(), ent.OpCreate), } } func IDHook() ent.Hook { sf := sonyflake.NewSonyflake(sonyflake.Settings{}) type IDSetter interface { SetID(uint64) } return func(next ent.Mutator) ent.Mutator { return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) { is, ok := m.(IDSetter) if !ok { return nil, fmt.Errorf("unexpected mutation %T", m) } id, err := sf.NextID() if err != nil { return nil, err } is.SetID(id) return next.Mutate(ctx, m) }) } } // User holds the schema definition for the User entity. type User struct { ent.Schema } // Mixin of the User. func (User) Mixin() []ent.Mixin { return []ent.Mixin{ // Embed the BaseMixin in the user schema. BaseMixin{}, } } ``` #### How to use a custom XID globally unique ID? Package [xid](https://github.com/rs/xid) is a globally unique ID generator library that uses the [Mongo Object ID](https://docs.mongodb.org/manual/reference/object-id/) algorithm to generate a 12 byte, 20 character ID with no configuration. The xid package comes with [database/sql](https://pkg.go.dev/database/sql) `sql.Scanner` and `driver.Valuer` interfaces required by Ent for serialization. To store an XID in any string field use the [GoType](schema-fields.md#go-type) schema configuration: ```go // Fields of type T. func (T) Fields() []ent.Field { return []ent.Field{ field.String("id"). GoType(xid.ID{}). DefaultFunc(xid.New), } } ``` Or as a reusable [Mixin](schema-mixin.md) across multiple schemas: ```go package schema import ( "entgo.io/ent" "entgo.io/ent/schema/field" "entgo.io/ent/schema/mixin" "github.com/rs/xid" ) // BaseMixin to be shared will all different schemas. type BaseMixin struct { mixin.Schema } // Fields of the User. func (BaseMixin) Fields() []ent.Field { return []ent.Field{ field.String("id"). GoType(xid.ID{}). DefaultFunc(xid.New), } } // User holds the schema definition for the User entity. type User struct { ent.Schema } // Mixin of the User. 
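// Embedding BaseMixin below gives this schema the XID-based "id" field defined
// above, so multiple schemas can share it without repeating the definition.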
func (User) Mixin() []ent.Mixin { return []ent.Mixin{ // Embed the BaseMixin in the user schema. BaseMixin{}, } } ``` In order to use extended identifiers (XIDs) with gqlgen, follow the configuration mentioned in the [issue tracker](https://github.com/ent/ent/issues/1526#issuecomment-831034884). #### How to define a spatial data type field in MySQL? The [GoType](schema-fields.md#go-type) and the [SchemaType](schema-fields.md#database-type) options allow users to define database-specific fields. For example, in order to define a [`POINT`](https://dev.mysql.com/doc/refman/8.0/en/spatial-type-overview.html) field, use the following configuration: ```go // Fields of the Location. func (Location) Fields() []ent.Field { return []ent.Field{ field.String("name"), field.Other("coords", &Point{}). SchemaType(Point{}.SchemaType()), } } ``` ```go package schema import ( "database/sql/driver" "fmt" "entgo.io/ent/dialect" "entgo.io/ent/dialect/sql" "github.com/paulmach/orb" "github.com/paulmach/orb/encoding/wkb" ) // A Point consists of (X,Y) or (Lat, Lon) coordinates // and it is stored in MySQL the POINT spatial data type. type Point [2]float64 // Scan implements the Scanner interface. func (p *Point) Scan(value any) error { bin, ok := value.([]byte) if !ok { return fmt.Errorf("invalid binary value for point") } var op orb.Point if err := wkb.Scanner(&op).Scan(bin[4:]); err != nil { return err } p[0], p[1] = op.X(), op.Y() return nil } // Value implements the driver Valuer interface. func (p Point) Value() (driver.Value, error) { op := orb.Point{p[0], p[1]} return wkb.Value(op).Value() } // FormatParam implements the sql.ParamFormatter interface to tell the SQL // builder that the placeholder for a Point parameter needs to be formatted. func (p Point) FormatParam(placeholder string, info *sql.StmtInfo) string { if info.Dialect == dialect.MySQL { return "ST_GeomFromWKB(" + placeholder + ")" } return placeholder } // SchemaType defines the schema-type of the Point object. func (Point) SchemaType() map[string]string { return map[string]string{ dialect.MySQL: "POINT", } } ``` A full example exists in the [example repository](https://github.com/a8m/entspatial). #### How to extend the generated models? Ent supports extending generated types (both global types and models) using custom templates. For example, in order to add additional struct fields or methods to the generated model, we can override the `model/fields/additional` template like in this [example](https://github.com/ent/ent/blob/dd4792f5b30bdd2db0d9a593a977a54cb3f0c1ce/examples/entcpkg/ent/template/static.tmpl). If your custom fields/methods require additional imports, you can add those imports using custom templates as well: ```gotemplate {{- define "import/additional/field_types" -}} "github.com/path/to/your/custom/type" {{- end -}} {{- define "import/additional/client_dependencies" -}} "github.com/path/to/your/custom/type" {{- end -}} ``` #### How to extend the generated builders? See the *[Injecting External Dependencies](code-gen.md#external-dependencies)* section, or follow the example on [GitHub](https://github.com/ent/ent/tree/master/examples/entcpkg). #### How to store Protobuf objects in a BLOB column? 
Assuming we have a Protobuf message defined: ```protobuf syntax = "proto3"; package pb; option go_package = "project/pb"; message Hi { string Greeting = 1; } ``` We add receiver methods to the generated protobuf struct such that it implements [ValueScanner](https://pkg.go.dev/entgo.io/ent/schema/field#ValueScanner) ```go func (x *Hi) Value() (driver.Value, error) { return proto.Marshal(x) } func (x *Hi) Scan(src any) error { if src == nil { return nil } if b, ok := src.([]byte); ok { if err := proto.Unmarshal(b, x); err != nil { return err } return nil } return fmt.Errorf("unexpected type %T", src) } ``` We add a new `field.Bytes` to our schema, setting the generated protobuf struct as its underlying `GoType`: ```go // Fields of the Message. func (Message) Fields() []ent.Field { return []ent.Field{ field.Bytes("hi"). GoType(&pb.Hi{}), } } ``` Test that it works: ```go package main import ( "context" "testing" "project/ent/enttest" "project/pb" _ "github.com/mattn/go-sqlite3" "github.com/stretchr/testify/require" ) func TestMain(t *testing.T) { client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1") defer client.Close() msg := client.Message.Create(). SetHi(&pb.Hi{ Greeting: "hello", }). SaveX(context.TODO()) ret := client.Message.GetX(context.TODO(), msg.ID) require.Equal(t, "hello", ret.Hi.Greeting) } ``` #### How to add `CHECK` constraints to table? The [`entsql.Annotation`](schema-annotations.md) option allows adding custom `CHECK` constraints to the `CREATE TABLE` statement. In order to add `CHECK` constraints to your schema, use the following example: ```go func (User) Annotations() []schema.Annotation { return []schema.Annotation{ &entsql.Annotation{ // The `Check` option allows adding an // unnamed CHECK constraint to table DDL. Check: "website <> 'entgo.io'", // The `Checks` option allows adding multiple CHECK constraints // to table creation. The keys are used as the constraint names. Checks: map[string]string{ "valid_nickname": "nickname <> firstname", "valid_firstname": "length(first_name) > 1", }, }, } } ``` #### How to define a custom precision numeric field? Using [GoType](schema-fields.md#go-type) and [SchemaType](schema-fields.md#database-type) it is possible to define custom precision numeric fields. For example, defining a field that uses [big.Int](https://pkg.go.dev/math/big). ```go func (T) Fields() []ent.Field { return []ent.Field{ field.Int("precise"). GoType(new(BigInt)). SchemaType(map[string]string{ dialect.SQLite: "numeric(78, 0)", dialect.Postgres: "numeric(78, 0)", }), } } type BigInt struct { big.Int } func (b *BigInt) Scan(src any) error { var i sql.NullString if err := i.Scan(src); err != nil { return err } if !i.Valid { return nil } if _, ok := b.Int.SetString(i.String, 10); ok { return nil } return fmt.Errorf("could not scan type %T with value %v into BigInt", src, src) } func (b *BigInt) Value() (driver.Value, error) { return b.String(), nil } ``` #### How to configure two or more `DB` to separate read and write? You can wrap the `dialect.Driver` with your own driver and implement this logic. For example. You can extend it, add support for multiple read replicas and add some load-balancing magic. ```go func main() { // ... 
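	// A sketch of a read/write split: "wd" is the write (primary) connection and
	// "rd" the read-replica connection. The DSNs below are placeholders; fill in
	// your own hosts, databases, and credentials.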
	wd, err := sql.Open(dialect.MySQL, "root:pass@tcp()/?parseTime=True")
	if err != nil {
		log.Fatal(err)
	}
	rd, err := sql.Open(dialect.MySQL, "readonly:pass@tcp()/?parseTime=True")
	if err != nil {
		log.Fatal(err)
	}
	client := ent.NewClient(ent.Driver(&multiDriver{w: wd, r: rd}))
	defer client.Close()
	// Use the client here.
}

type multiDriver struct {
	r, w dialect.Driver
}

var _ dialect.Driver = (*multiDriver)(nil)

func (d *multiDriver) Query(ctx context.Context, query string, args, v any) error {
	return d.r.Query(ctx, query, args, v)
}

func (d *multiDriver) Exec(ctx context.Context, query string, args, v any) error {
	return d.w.Exec(ctx, query, args, v)
}

func (d *multiDriver) Tx(ctx context.Context) (dialect.Tx, error) {
	return d.w.Tx(ctx)
}

func (d *multiDriver) BeginTx(ctx context.Context, opts *sql.TxOptions) (dialect.Tx, error) {
	return d.w.(interface {
		BeginTx(context.Context, *sql.TxOptions) (dialect.Tx, error)
	}).BeginTx(ctx, opts)
}

func (d *multiDriver) Close() error {
	rerr := d.r.Close()
	werr := d.w.Close()
	if rerr != nil {
		return rerr
	}
	if werr != nil {
		return werr
	}
	return nil
}
```

#### How to change the character set and/or collation of a MySQL table?

By default, MySQL tables are created with the `utf8mb4` character set and the `utf8mb4_bin` collation. To change the schema's character set and/or collation, use an annotation. Here's an example where we set the character set to `ascii` and the collation to `ascii_general_ci`:

```go
// Annotations of the Entity.
func (Entity) Annotations() []schema.Annotation {
	return []schema.Annotation{
		entsql.Annotation{
			Charset:   "ascii",
			Collation: "ascii_general_ci",
		},
	}
}
```

---
id: feature-flags
title: Feature Flags
sidebar_label: Feature Flags
---

The framework provides a collection of code-generation features that can be added or removed using flags.

## Usage

Feature flags can be provided either by CLI flags or as arguments to the `gen` package.

#### CLI

```console
go run -mod=mod entgo.io/ent/cmd/ent generate --feature privacy,entql ./ent/schema
```

#### Go

```go
// +build ignore

package main

import (
	"log"
	"strings"
	"text/template"

	"entgo.io/ent/entc"
	"entgo.io/ent/entc/gen"
)

func main() {
	err := entc.Generate("./schema", &gen.Config{
		Features: []gen.Feature{
			gen.FeaturePrivacy,
			gen.FeatureEntQL,
		},
		Templates: []*gen.Template{
			gen.MustParse(gen.NewTemplate("static").
				Funcs(template.FuncMap{"title": strings.ToTitle}).
				ParseFiles("template/static.tmpl")),
		},
	})
	if err != nil {
		log.Fatalf("running ent codegen: %v", err)
	}
}
```

## List of Features

### Auto-Solve Merge Conflicts

The `schema/snapshot` option tells `entc` (ent codegen) to store a snapshot of the latest schema in an internal package, and use it to automatically solve merge conflicts when the user's schema can't be built.

This option can be added to a project using the `--feature schema/snapshot` flag, but please see [ent/ent/issues/852](https://github.com/ent/ent/issues/852) to get more context about it.

### Privacy Layer

The privacy layer allows configuring privacy policy for queries and mutations of entities in the database.

This option can be added to a project using the `--feature privacy` flag, and you can learn more about it in the [privacy](privacy.md) documentation.

### EntQL Filtering

The `entql` option provides a generic and dynamic filtering capability at runtime for the different query builders.
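For example, here is a minimal sketch of runtime filtering, assuming a generated `ent.User` type with a `name` field; the `Filter` method and the `WhereName` predicate shown here are generated per schema when this feature is enabled, so the exact names depend on your own fields:

```go
import (
	"context"

	"entgo.io/ent/entql"

	"<project>/ent" // your generated ent package
)

// queryByName filters users dynamically at runtime instead of using the
// statically generated predicate package.
func queryByName(ctx context.Context, client *ent.Client, name string) ([]*ent.User, error) {
	q := client.User.Query()
	// Filter() exposes the generated entql filter of the query builder.
	// Predicates registered on it are applied when the query is executed.
	q.Filter().WhereName(entql.StringEQ(name))
	return q.All(ctx)
}
```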
This option can be added to a project using the `--feature entql` flag, and you can learn more about in the [privacy](privacy.md#multi-tenancy) documentation. ### Named Edges The `namedges` option provides an API for preloading edges with custom names. This option can be added to a project using the `--feature namedges` flag, and you can learn more about in the [Eager Loading](eager-load.mdx) documentation. ### Schema Config The `sql/schemaconfig` option lets you pass alternate SQL database names to models. This is useful when your models don't all live under one database and are spread out across different schemas. This option can be added to a project using the `--feature sql/schemaconfig` flag. Once you generate the code, you can now use a new option as such: ```go c, err := ent.Open(dialect, conn, ent.AlternateSchema(ent.SchemaConfig{ User: "usersdb", Car: "carsdb", })) c.User.Query().All(ctx) // SELECT * FROM `usersdb`.`users` c.Car.Query().All(ctx) // SELECT * FROM `carsdb`.`cars` ``` ### Row-level Locks The `sql/lock` option lets configure row-level locking using the SQL `SELECT ... FOR {UPDATE | SHARE}` syntax. This option can be added to a project using the `--feature sql/lock` flag. ```go tx, err := client.Tx(ctx) if err != nil { log.Fatal(err) } tx.Pet.Query(). Where(pet.Name(name)). ForUpdate(). Only(ctx) tx.Pet.Query(). Where(pet.ID(id)). ForShare( sql.WithLockTables(pet.Table), sql.WithLockAction(sql.NoWait), ). Only(ctx) ``` ### Custom SQL Modifiers The `sql/modifier` option lets add custom SQL modifiers to the builders and mutate the statements before they are executed. This option can be added to a project using the `--feature sql/modifier` flag. #### Modify Example 1 ```go client.Pet. Query(). Modify(func(s *sql.Selector) { s.Select("SUM(LENGTH(name))") }). IntX(ctx) ``` The above code will produce the following SQL query: ```sql SELECT SUM(LENGTH(name)) FROM `pet` ``` #### Modify Example 2 ```go var p1 []struct { ent.Pet NameLength int `sql:"length"` } client.Pet.Query(). Order(ent.Asc(pet.FieldID)). Modify(func(s *sql.Selector) { s.AppendSelect("LENGTH(name)") }). ScanX(ctx, &p1) ``` The above code will produce the following SQL query: ```sql SELECT `pet`.*, LENGTH(name) FROM `pet` ORDER BY `pet`.`id` ASC ``` #### Modify Example 3 ```go var v []struct { Count int `json:"count"` Price int `json:"price"` CreatedAt time.Time `json:"created_at"` } client.User. Query(). Where( user.CreatedAtGT(x), user.CreatedAtLT(y), ). Modify(func(s *sql.Selector) { s.Select( sql.As(sql.Count("*"), "count"), sql.As(sql.Sum("price"), "price"), sql.As("DATE(created_at)", "created_at"), ). GroupBy("DATE(created_at)"). OrderBy(sql.Desc("DATE(created_at)")) }). ScanX(ctx, &v) ``` The above code will produce the following SQL query: ```sql SELECT COUNT(*) AS `count`, SUM(`price`) AS `price`, DATE(created_at) AS `created_at` FROM `users` WHERE `created_at` > x AND `created_at` < y GROUP BY DATE(created_at) ORDER BY DATE(created_at) DESC ``` #### Modify Example 4 ```go var gs []struct { ent.Group UsersCount int `sql:"users_count"` } client.Group.Query(). Order(ent.Asc(group.FieldID)). Modify(func(s *sql.Selector) { t := sql.Table(group.UsersTable) s.LeftJoin(t). On( s.C(group.FieldID), t.C(group.UsersPrimaryKey[1]), ). // Append the "users_count" column to the selected columns. AppendSelect( sql.As(sql.Count(t.C(group.UsersPrimaryKey[1])), "users_count"), ). GroupBy(s.C(group.FieldID)) }). 
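	// ScanX maps the selected columns onto gs: the embedded ent.Group takes the
	// regular entity columns, and the appended "users_count" aggregate is
	// matched via the `sql:"users_count"` struct tag.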
ScanX(ctx, &gs) ``` The above code will produce the following SQL query: ```sql SELECT `groups`.*, COUNT(`t1`.`group_id`) AS `users_count` FROM `groups` LEFT JOIN `user_groups` AS `t1` ON `groups`.`id` = `t1`.`group_id` GROUP BY `groups`.`id` ORDER BY `groups`.`id` ASC ``` #### Modify Example 5 ```go client.User.Update(). Modify(func(s *sql.UpdateBuilder) { s.Set(user.FieldName, sql.Expr(fmt.Sprintf("UPPER(%s)", user.FieldName))) }). ExecX(ctx) ``` The above code will produce the following SQL query: ```sql UPDATE `users` SET `name` = UPPER(`name`) ``` #### Modify Example 6 ```go client.User.Update(). Modify(func(u *sql.UpdateBuilder) { u.Set(user.FieldID, sql.ExprFunc(func(b *sql.Builder) { b.Ident(user.FieldID).WriteOp(sql.OpAdd).Arg(1) })) u.OrderBy(sql.Desc(user.FieldID)) }). ExecX(ctx) ``` The above code will produce the following SQL query: ```sql UPDATE `users` SET `id` = `id` + 1 ORDER BY `id` DESC ``` ### SQL Raw API The `sql/execquery` option allows executing statements using the `ExecContext`/`QueryContext` methods of the underlying driver. For full documentation, see: [DB.ExecContext](https://pkg.go.dev/database/sql#DB.ExecContext), and [DB.QueryContext](https://pkg.go.dev/database/sql#DB.QueryContext). ```go // From ent.Client. if _, err := client.ExecContext(ctx, "TRUNCATE t1"); err != nil { return err } // From ent.Tx. tx, err := client.Tx(ctx) if err != nil { return err } if err := tx.User.Create().Exec(ctx); err != nil { return err } if _, err := tx.ExecContext("SAVEPOINT user_created"); err != nil { return err } // ... ``` :::warning Note Statements executed using `ExecContext`/`QueryContext` do not go through Ent, and may skip fundamental layers in your application such as hooks, privacy (authorization), and validators. ::: ### Upsert The `sql/upsert` option lets configure upsert and bulk-upsert logic using the SQL `ON CONFLICT` / `ON DUPLICATE KEY` syntax. For full documentation, go to the [Upsert API](crud.md#upsert-one). This option can be added to a project using the `--feature sql/upsert` flag. ```go // Use the new values that were set on create. id, err := client.User. Create(). SetAge(30). SetName("Ariel"). OnConflict(). UpdateNewValues(). ID(ctx) // In PostgreSQL, the conflict target is required. err := client.User. Create(). SetAge(30). SetName("Ariel"). OnConflictColumns(user.FieldName). UpdateNewValues(). Exec(ctx) // Bulk upsert is also supported. client.User. CreateBulk(builders...). OnConflict( sql.ConflictWhere(...), sql.UpdateWhere(...), ). UpdateNewValues(). Exec(ctx) // INSERT INTO "users" (...) VALUES ... ON CONFLICT WHERE ... DO UPDATE SET ... WHERE ... ``` ent-0.11.3/doc/md/generating-ent-schemas.md000066400000000000000000000113501431500740500204000ustar00rootroot00000000000000--- id: generating-ent-schemas title: Generating Schemas --- ## Introduction To facilitate the creation of tooling that generates `ent.Schema`s programmatically, `ent` supports the manipulation of the `schema/` directory using the `entgo.io/contrib/schemast` package. 
## API ### Loading In order to manipulate an existing schema directory we must first load it into a `schemast.Context` object: ```go package main import ( "fmt" "log" "entgo.io/contrib/schemast" ) func main() { ctx, err := schemast.Load("./ent/schema") if err != nil { log.Fatalf("failed: %v", err) } if ctx.HasType("user") { fmt.Println("schema directory contains a schema named User!") } } ``` ### Printing To print back out our context to a target directory, use `schemast.Print`: ```go package main import ( "log" "entgo.io/contrib/schemast" ) func main() { ctx, err := schemast.Load("./ent/schema") if err != nil { log.Fatalf("failed: %v", err) } // A no-op since we did not manipulate the Context at all. if err := schemast.Print("./ent/schema"); err != nil { log.Fatalf("failed: %v", err) } } ``` ### Mutators To mutate the `ent/schema` directory, we can use `schemast.Mutate`, which takes a list of `schemast.Mutator`s to apply to the context: ```go package schemast // Mutator changes a Context. type Mutator interface { Mutate(ctx *Context) error } ``` Currently, only a single type of `schemast.Mutator` is implemented, `UpsertSchema`: ```go package schemast // UpsertSchema implements Mutator. UpsertSchema will add to the Context the type named // Name if not present and rewrite the type's Fields, Edges, Indexes and Annotations methods. type UpsertSchema struct { Name string Fields []ent.Field Edges []ent.Edge Indexes []ent.Index Annotations []schema.Annotation } ``` To use it: ```go package main import ( "log" "entgo.io/contrib/schemast" "entgo.io/ent" "entgo.io/ent/schema/field" ) func main() { ctx, err := schemast.Load("./ent/schema") if err != nil { log.Fatalf("failed: %v", err) } mutations := []schemast.Mutator{ &schemast.UpsertSchema{ Name: "User", Fields: []ent.Field{ field.String("name"), }, }, &schemast.UpsertSchema{ Name: "Team", Fields: []ent.Field{ field.String("name"), }, }, } err = schemast.Mutate(ctx, mutations...) if err := ctx.Print("./ent/schema"); err != nil { log.Fatalf("failed: %v", err) } } ``` After running this program, observe two new files exist in the schema directory: `user.go` and `team.go`: ```go // user.go package schema import ( "entgo.io/ent" "entgo.io/ent/schema" "entgo.io/ent/schema/field" ) type User struct { ent.Schema } func (User) Fields() []ent.Field { return []ent.Field{field.String("name")} } func (User) Edges() []ent.Edge { return nil } func (User) Annotations() []schema.Annotation { return nil } ``` ```go package schema import ( "entgo.io/ent" "entgo.io/ent/schema" "entgo.io/ent/schema/field" ) type Team struct { ent.Schema } func (Team) Fields() []ent.Field { return []ent.Field{field.String("name")} } func (Team) Edges() []ent.Edge { return nil } func (Team) Annotations() []schema.Annotation { return nil } ``` ### Working with Edges Edges are defined in `ent` this way: ```go edge.To("edge_name", OtherSchema.Type) ``` This syntax relies on the fact that the `OtherSchema` struct already exists when we define the edge so we can refer to its `Type` method. When we are generating schemas programmatically, obviously we need somehow to describe the edge to the code-generator before the type definitions exist. 
To do this you can do something like: ```go type placeholder struct { ent.Schema } func withType(e ent.Edge, typeName string) ent.Edge { e.Descriptor().Type = typeName return e } func newEdgeTo(edgeName, otherType string) ent.Edge { // we pass a placeholder type to the edge constructor: e := edge.To(edgeName, placeholder.Type) // then we override the other type's name directly on the edge descriptor: return withType(e, otherType) } ``` ## Examples The `protoc-gen-ent` ([doc](https://github.com/ent/contrib/tree/master/entproto/cmd/protoc-gen-ent)) is a protoc plugin that programmatically generates `ent.Schema`s from .proto files, it uses the `schemast` to manipulate the target `schema` directory. To see how, [read the source code](https://github.com/ent/contrib/blob/master/entproto/cmd/protoc-gen-ent/main.go#L34). ## Caveats `schemast` is still experimental, APIs are subject to change in the future. In addition, a small portion of the `ent.Field` definition API is unsupported at this point in time, to see a full list of unsupported features see the [source code](https://github.com/ent/contrib/blob/aed7a43a3e54550c1dd9a1a066ce1236b4bae56c/schemast/field.go#L158). ent-0.11.3/doc/md/getting-started.md000077500000000000000000000362361431500740500171720ustar00rootroot00000000000000--- id: getting-started title: Quick Introduction sidebar_label: Quick Introduction --- import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; **ent** is a simple, yet powerful entity framework for Go, that makes it easy to build and maintain applications with large data-models and sticks with the following principles: - Easily model database schema as a graph structure. - Define schema as a programmatic Go code. - Static typing based on code generation. - Database queries and graph traversals are easy to write. - Simple to extend and customize using Go templates. ![gopher-schema-as-code](https://entgo.io/images/assets/gopher-schema-as-code.png) ## Setup A Go Environment If your project directory is outside [GOPATH](https://github.com/golang/go/wiki/GOPATH) or you are not familiar with GOPATH, setup a [Go module](https://github.com/golang/go/wiki/Modules#quick-start) project as follows: ```console go mod init ``` ## Create Your First Schema Go to the root directory of your project, and run: ```console go run -mod=mod entgo.io/ent/cmd/ent init User ``` The command above will generate the schema for `User` under `/ent/schema/` directory: ```go title="/ent/schema/user.go" package schema import "entgo.io/ent" // User holds the schema definition for the User entity. type User struct { ent.Schema } // Fields of the User. func (User) Fields() []ent.Field { return nil } // Edges of the User. func (User) Edges() []ent.Edge { return nil } ``` Add 2 fields to the `User` schema: ```go title="/ent/schema/user.go" package schema import ( "entgo.io/ent" "entgo.io/ent/schema/field" ) // Fields of the User. func (User) Fields() []ent.Field { return []ent.Field{ field.Int("age"). Positive(), field.String("name"). Default("unknown"), } } ``` Run `go generate` from the root directory of the project as follows: ```go go generate ./ent ``` This produces the following files: ```console {12-20} ent ├── client.go ├── config.go ├── context.go ├── ent.go ├── generate.go ├── mutation.go ... truncated ├── schema │ └── user.go ├── tx.go ├── user │ ├── user.go │ └── where.go ├── user.go ├── user_create.go ├── user_delete.go ├── user_query.go └── user_update.go ``` ## Create Your First Entity To get started, create a new `ent.Client`. 
```go title="/start/start.go" package main import ( "context" "log" "/ent" _ "github.com/mattn/go-sqlite3" ) func main() { client, err := ent.Open("sqlite3", "file:ent?mode=memory&cache=shared&_fk=1") if err != nil { log.Fatalf("failed opening connection to sqlite: %v", err) } defer client.Close() // Run the auto migration tool. if err := client.Schema.Create(context.Background()); err != nil { log.Fatalf("failed creating schema resources: %v", err) } } ``` ```go title="/start/start.go" package main import ( "context" "log" "/ent" _ "github.com/lib/pq" ) func main() { client, err := ent.Open("postgres","host= port= user= dbname= password=") if err != nil { log.Fatalf("failed opening connection to postgres: %v", err) } defer client.Close() // Run the auto migration tool. if err := client.Schema.Create(context.Background()); err != nil { log.Fatalf("failed creating schema resources: %v", err) } } ``` ```go title="/start/start.go" package main import ( "context" "log" "/ent" _ "github.com/go-sql-driver/mysql" ) func main() { client, err := ent.Open("mysql", ":@tcp(:)/?parseTime=True") if err != nil { log.Fatalf("failed opening connection to mysql: %v", err) } defer client.Close() // Run the auto migration tool. if err := client.Schema.Create(context.Background()); err != nil { log.Fatalf("failed creating schema resources: %v", err) } } ``` Now, we're ready to create our user. Let's call this function `CreateUser` for the sake of example: ```go title="/start/start.go" func CreateUser(ctx context.Context, client *ent.Client) (*ent.User, error) { u, err := client.User. Create(). SetAge(30). SetName("a8m"). Save(ctx) if err != nil { return nil, fmt.Errorf("failed creating user: %w", err) } log.Println("user was created: ", u) return u, nil } ``` ## Query Your Entities `ent` generates a package for each entity schema that contains its predicates, default values, validators and additional information about storage elements (column names, primary keys, etc). ```go title="/start/start.go" package main import ( "log" "/ent" "/ent/user" ) func QueryUser(ctx context.Context, client *ent.Client) (*ent.User, error) { u, err := client.User. Query(). Where(user.Name("a8m")). // `Only` fails if no user found, // or more than 1 user returned. Only(ctx) if err != nil { return nil, fmt.Errorf("failed querying user: %w", err) } log.Println("user returned: ", u) return u, nil } ``` ## Add Your First Edge (Relation) In this part of the tutorial, we want to declare an edge (relation) to another entity in the schema. Let's create 2 additional entities named `Car` and `Group` with a few fields. We use `ent` CLI to generate the initial schemas: ```console go run -mod=mod entgo.io/ent/cmd/ent init Car Group ``` And then we add the rest of the fields manually: ```go title="/ent/schema/car.go" // Fields of the Car. func (Car) Fields() []ent.Field { return []ent.Field{ field.String("model"), field.Time("registered_at"), } } ``` ```go title="/ent/schema/group.go" // Fields of the Group. func (Group) Fields() []ent.Field { return []ent.Field{ field.String("name"). // Regexp validation for group name. Match(regexp.MustCompile("[a-zA-Z_]+$")), } } ``` Let's define our first relation. An edge from `User` to `Car` defining that a user can **have 1 or more** cars, but a car **has only one** owner (one-to-many relation). ![er-user-cars](https://entgo.io/images/assets/re_user_cars.png) Let's add the `"cars"` edge to the `User` schema, and run `go generate ./ent`: ```go title="/ent/schema/user.go" // Edges of the User. 
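// The "cars" edge below declares the O2M relation from the diagram above:
// a user can have one or more cars, while each car has a single owner.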
func (User) Edges() []ent.Edge { return []ent.Edge{ edge.To("cars", Car.Type), } } ``` We continue our example by creating 2 cars and adding them to a user. ```go title="/start/start.go" import ( "/ent" "/ent/car" "/ent/user" ) func CreateCars(ctx context.Context, client *ent.Client) (*ent.User, error) { // Create a new car with model "Tesla". tesla, err := client.Car. Create(). SetModel("Tesla"). SetRegisteredAt(time.Now()). Save(ctx) if err != nil { return nil, fmt.Errorf("failed creating car: %w", err) } log.Println("car was created: ", tesla) // Create a new car with model "Ford". ford, err := client.Car. Create(). SetModel("Ford"). SetRegisteredAt(time.Now()). Save(ctx) if err != nil { return nil, fmt.Errorf("failed creating car: %w", err) } log.Println("car was created: ", ford) // Create a new user, and add it the 2 cars. a8m, err := client.User. Create(). SetAge(30). SetName("a8m"). AddCars(tesla, ford). Save(ctx) if err != nil { return nil, fmt.Errorf("failed creating user: %w", err) } log.Println("user was created: ", a8m) return a8m, nil } ``` But what about querying the `cars` edge (relation)? Here's how we do it: ```go title="/start/start.go" import ( "log" "/ent" "/ent/car" ) func QueryCars(ctx context.Context, a8m *ent.User) error { cars, err := a8m.QueryCars().All(ctx) if err != nil { return fmt.Errorf("failed querying user cars: %w", err) } log.Println("returned cars:", cars) // What about filtering specific cars. ford, err := a8m.QueryCars(). Where(car.Model("Ford")). Only(ctx) if err != nil { return fmt.Errorf("failed querying user cars: %w", err) } log.Println(ford) return nil } ``` ## Add Your First Inverse Edge (BackRef) Assume we have a `Car` object and we want to get its owner; the user that this car belongs to. For this, we have another type of edge called "inverse edge" that is defined using the `edge.From` function. ![er-cars-owner](https://entgo.io/images/assets/re_cars_owner.png) The new edge created in the diagram above is translucent, to emphasize that we don't create another edge in the database. It's just a back-reference to the real edge (relation). Let's add an inverse edge named `owner` to the `Car` schema, reference it to the `cars` edge in the `User` schema, and run `go generate ./ent`. ```go title="/ent/schema/car.go" // Edges of the Car. func (Car) Edges() []ent.Edge { return []ent.Edge{ // Create an inverse-edge called "owner" of type `User` // and reference it to the "cars" edge (in User schema) // explicitly using the `Ref` method. edge.From("owner", User.Type). Ref("cars"). // setting the edge to unique, ensure // that a car can have only one owner. Unique(), } } ``` We'll continue the user/cars example above by querying the inverse edge. ```go title="/start/start.go" import ( "fmt" "log" "/ent" "/ent/user" ) func QueryCarUsers(ctx context.Context, a8m *ent.User) error { cars, err := a8m.QueryCars().All(ctx) if err != nil { return fmt.Errorf("failed querying user cars: %w", err) } // Query the inverse edge. for _, c := range cars { owner, err := c.QueryOwner().Only(ctx) if err != nil { return fmt.Errorf("failed querying car %q owner: %w", c.Model, err) } log.Printf("car %q owner: %q\n", c.Model, owner.Name) } return nil } ``` ## Create Your Second Edge We'll continue our example by creating a M2M (many-to-many) relationship between users and groups. 
![er-group-users](https://entgo.io/images/assets/re_group_users.png) As you can see, each group entity can **have many** users, and a user can **be connected to many** groups; a simple "many-to-many" relationship. In the above illustration, the `Group` schema is the owner of the `users` edge (relation), and the `User` entity has a back-reference/inverse edge to this relationship named `groups`. Let's define this relationship in our schemas: ```go title="/ent/schema/group.go" // Edges of the Group. func (Group) Edges() []ent.Edge { return []ent.Edge{ edge.To("users", User.Type), } } ``` ```go title="/ent/schema/user.go" // Edges of the User. func (User) Edges() []ent.Edge { return []ent.Edge{ edge.To("cars", Car.Type), // Create an inverse-edge called "groups" of type `Group` // and reference it to the "users" edge (in Group schema) // explicitly using the `Ref` method. edge.From("groups", Group.Type). Ref("users"), } } ``` We run `ent` on the schema directory to re-generate the assets. ```console go generate ./ent ``` ## Run Your First Graph Traversal In order to run our first graph traversal, we need to generate some data (nodes and edges, or in other words, entities and relations). Let's create the following graph using the framework: ![re-graph](https://entgo.io/images/assets/re_graph_getting_started.png) ```go title="/start/start.go" func CreateGraph(ctx context.Context, client *ent.Client) error { // First, create the users. a8m, err := client.User. Create(). SetAge(30). SetName("Ariel"). Save(ctx) if err != nil { return err } neta, err := client.User. Create(). SetAge(28). SetName("Neta"). Save(ctx) if err != nil { return err } // Then, create the cars, and attach them to the users created above. err = client.Car. Create(). SetModel("Tesla"). SetRegisteredAt(time.Now()). // Attach this car to Ariel. SetOwner(a8m). Exec(ctx) if err != nil { return err } err = client.Car. Create(). SetModel("Mazda"). SetRegisteredAt(time.Now()). // Attach this car to Ariel. SetOwner(a8m). Exec(ctx) if err != nil { return err } err = client.Car. Create(). SetModel("Ford"). SetRegisteredAt(time.Now()). // Attach this graph to Neta. SetOwner(neta). Exec(ctx) if err != nil { return err } // Create the groups, and add their users in the creation. err = client.Group. Create(). SetName("GitLab"). AddUsers(neta, a8m). Exec(ctx) if err != nil { return err } err = client.Group. Create(). SetName("GitHub"). AddUsers(a8m). Exec(ctx) if err != nil { return err } log.Println("The graph was created successfully") return nil } ``` Now when we have a graph with data, we can run a few queries on it: 1. Get all user's cars within the group named "GitHub": ```go title="/start/start.go" import ( "log" "/ent" "/ent/group" ) func QueryGithub(ctx context.Context, client *ent.Client) error { cars, err := client.Group. Query(). Where(group.Name("GitHub")). // (Group(Name=GitHub),) QueryUsers(). // (User(Name=Ariel, Age=30),) QueryCars(). // (Car(Model=Tesla, RegisteredAt=