pax_global_header00006660000000000000000000000064147110066100014506gustar00rootroot0000000000000052 comment=2a853ab33ca377234373b5954b3348ee4b2991c4 golang-github-canonical-go-dqlite-2.0.0/000077500000000000000000000000001471100661000200445ustar00rootroot00000000000000golang-github-canonical-go-dqlite-2.0.0/.dir-locals.el000066400000000000000000000003771471100661000225040ustar00rootroot00000000000000;;; Directory Local Variables ;;; For more information see (info "(emacs) Directory Variables") ((go-mode . ((go-test-args . "-tags libsqlite3 -timeout 90s") (eval . (set (make-local-variable 'flycheck-go-build-tags) '("libsqlite3")))))) golang-github-canonical-go-dqlite-2.0.0/.github/000077500000000000000000000000001471100661000214045ustar00rootroot00000000000000golang-github-canonical-go-dqlite-2.0.0/.github/dependabot.yml000066400000000000000000000012001471100661000242250ustar00rootroot00000000000000# To get started with Dependabot version updates, you'll need to specify which # package ecosystems to update and where the package manifests are located. # Please see the documentation for all configuration options: # https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file version: 2 updates: - package-ecosystem: "gomod" directory: "/" labels: [] schedule: interval: "weekly" target-branch: "master" - package-ecosystem: "github-actions" directory: "/" labels: [] schedule: interval: "weekly" target-branch: "master" golang-github-canonical-go-dqlite-2.0.0/.github/workflows/000077500000000000000000000000001471100661000234415ustar00rootroot00000000000000golang-github-canonical-go-dqlite-2.0.0/.github/workflows/build-and-test.yml000066400000000000000000000043401471100661000270010ustar00rootroot00000000000000name: CI tests on: - push - pull_request jobs: build-and-test: strategy: fail-fast: false matrix: go: - '1.13' - stable os: - ubuntu-20.04 - ubuntu-22.04 - ubuntu-24.04 runs-on: ${{ matrix.os }} steps: - name: Checkout code uses: actions/checkout@v4 - name: Install Go uses: actions/setup-go@v5 with: go-version: ${{ matrix.go }} - name: Setup dependencies run: | sudo add-apt-repository ppa:dqlite/dev -y sudo apt update sudo apt install -y libsqlite3-dev libuv1-dev liblz4-dev libdqlite1.17-dev - name: Download deps run: | go get -t -tags libsqlite3 ./... - name: Test env: GO_DQLITE_MULTITHREAD: 1 run: | go test -v -tags libsqlite3 -race -coverprofile=coverage.out ./... go test -v -tags nosqlite3 ./... VERBOSE=1 ./test/dqlite-demo.sh VERBOSE=1 ./test/roles.sh VERBOSE=1 ./test/recover.sh - name: Coverage uses: coverallsapp/github-action@v2 with: file: coverage.out parallel: true - name: Benchmark env: GO_DQLITE_MULTITHREAD: 1 run: | go install -tags libsqlite3 github.com/canonical/go-dqlite/v2/cmd/dqlite-benchmark dqlite-benchmark --db 127.0.0.1:9001 --driver --cluster 127.0.0.1:9001,127.0.0.1:9002,127.0.0.1:9003 --workload kvreadwrite & masterpid=$! 
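# The two commands below start follower nodes that join the cluster at
# 127.0.0.1:9001; we then wait for the driver node started above to
# finish its workload before printing the results.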
dqlite-benchmark --db 127.0.0.1:9002 --join 127.0.0.1:9001 & dqlite-benchmark --db 127.0.0.1:9003 --join 127.0.0.1:9001 & wait $masterpid echo "Write results:" head -n 5 /tmp/dqlite-benchmark/127.0.0.1:9001/results/0-exec-* echo "" echo "Read results:" head -n 5 /tmp/dqlite-benchmark/127.0.0.1:9001/results/0-query-* - uses: actions/upload-artifact@v3 with: name: dqlite-benchmark-${{ matrix.os }}-${{ matrix.go }} path: /tmp/dqlite-benchmark/127.0.0.1:9001/results/* finish: needs: build-and-test if: ${{ always() }} runs-on: ubuntu-latest steps: - name: Finish coverage uses: coverallsapp/github-action@v2 with: parallel-finished: true golang-github-canonical-go-dqlite-2.0.0/.github/workflows/cla-check.yml000066400000000000000000000002711471100661000257760ustar00rootroot00000000000000name: Canonical CLA on: - pull_request jobs: cla-check: runs-on: ubuntu-20.04 steps: - name: Check if CLA signed uses: canonical/has-signed-canonical-cla@v1 golang-github-canonical-go-dqlite-2.0.0/.github/workflows/daily-benchmark.yml000066400000000000000000000026001471100661000272140ustar00rootroot00000000000000name: Daily benchmark on: schedule: - cron: "0 12 * * *" jobs: benchmark: runs-on: ubuntu-latest steps: - name: Checkout code uses: actions/checkout@v4 - name: Install Go uses: actions/setup-go@v5 with: go-version: stable - name: Setup dependencies run: | sudo add-apt-repository ppa:dqlite/dev -y sudo apt update sudo apt install -y libsqlite3-dev libuv1-dev liblz4-dev libraft-dev libdqlite-dev - name: Build & Benchmark env: GO_DQLITE_MULTITHREAD: 1 run: | go get -t -tags libsqlite3 ./... go install -tags libsqlite3 github.com/canonical/go-dqlite/v2/cmd/dqlite-benchmark dqlite-benchmark --db 127.0.0.1:9001 --duration 3600 --driver --cluster 127.0.0.1:9001,127.0.0.1:9002,127.0.0.1:9003 --workload kvreadwrite & masterpid=$! dqlite-benchmark --db 127.0.0.1:9002 --join 127.0.0.1:9001 & dqlite-benchmark --db 127.0.0.1:9003 --join 127.0.0.1:9001 & wait $masterpid echo "Write results:" head -n 5 /tmp/dqlite-benchmark/127.0.0.1:9001/results/0-exec-* echo "" echo "Read results:" head -n 5 /tmp/dqlite-benchmark/127.0.0.1:9001/results/0-query-* - uses: actions/upload-artifact@v3 with: name: dqlite-daily-benchmark path: /tmp/dqlite-benchmark/127.0.0.1:9001/results/* golang-github-canonical-go-dqlite-2.0.0/.github/workflows/packages.yml000066400000000000000000000033301471100661000257410ustar00rootroot00000000000000name: Build PPA source packages on: push: branches: - v2 jobs: build: if: github.repository == 'canonical/go-dqlite' strategy: fail-fast: false matrix: target: - focal - jammy - noble - oracular runs-on: ubuntu-20.04 environment: name: ppa steps: - uses: actions/checkout@v4 with: fetch-depth: 0 fetch-tags: true - uses: actions/checkout@v4 with: repository: canonical/dqlite-ppa ref: go-dqlite-v2 path: dqlite-ppa - name: Setup dependencies run: | sudo apt-get update -qq sudo apt-get install -qq debhelper devscripts dh-golang gnupg - name: Setup GPG signing key env: PPA_SECRET_KEY: ${{ secrets.PPA_SECRET_KEY }} run: | echo "$PPA_SECRET_KEY" > private-key.asc gpg --import --batch private-key.asc - name: Delete GPG signing key file if: always() run: | rm -f private-key.asc - name: Build source package env: DEBFULLNAME: "Github Actions" DEBEMAIL: "dqlitebot@lists.canonical.com" TARGET: ${{ matrix.target }} run: | cp -R dqlite-ppa/debian . 
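# Vendor the Go dependencies into the source tree, since Launchpad
# builders typically have no network access at package build time.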
go mod vendor VERSION="$(git describe --tags | sed -e "s/^v//" -e "s/-/+git/")" dch --create \ --distribution ${TARGET} \ --package go-dqlite-v2 \ --newversion ${VERSION}~${TARGET}1 \ "Automatic build from Github" debuild -S -sa -d -k${{ vars.PPA_PUBLIC_KEY }} - name: Upload to Launchpad run: | cd .. shopt -s globstar dput -U -u ppa:dqlite/dev **/*.changes golang-github-canonical-go-dqlite-2.0.0/.github/workflows/static.yml000066400000000000000000000011401471100661000254470ustar00rootroot00000000000000name: Static checks on: - push - pull_request jobs: run: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - uses: actions/setup-go@v5 with: go-version: stable - name: Install libdqlite run: | sudo add-apt-repository ppa:dqlite/dev -y sudo apt update sudo apt install -y libdqlite-dev - name: Go vet run: | go vet -tags libsqlite3 ./... - uses: dominikh/staticcheck-action@v1 with: version: '2024.1.1' build-tags: libsqlite3 install-go: false golang-github-canonical-go-dqlite-2.0.0/.gitignore000066400000000000000000000001641471100661000220350ustar00rootroot00000000000000.sqlite cmd/dqlite/dqlite cmd/dqlite-demo/dqlite-demo dqlite dqlite-demo profile.coverprofile overalls.coverprofile golang-github-canonical-go-dqlite-2.0.0/AUTHORS000066400000000000000000000003631471100661000211160ustar00rootroot00000000000000Unless mentioned otherwise in a specific file's header, all code in this project is released under the Apache 2.0 license. The list of authors and contributors can be retrieved from the git commit history and in some cases, the file headers. golang-github-canonical-go-dqlite-2.0.0/LICENSE000066400000000000000000000261351471100661000210600ustar00rootroot00000000000000 Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). 
"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. golang-github-canonical-go-dqlite-2.0.0/README.md000066400000000000000000000116461471100661000213330ustar00rootroot00000000000000go-dqlite [![CI tests](https://github.com/canonical/go-dqlite/actions/workflows/build-and-test.yml/badge.svg)](https://github.com/canonical/go-dqlite/actions/workflows/build-and-test.yml) [![Coverage Status](https://coveralls.io/repos/github/canonical/go-dqlite/badge.svg?branch=master)](https://coveralls.io/github/canonical/go-dqlite?branch=master) [![Go Report Card](https://goreportcard.com/badge/github.com/canonical/go-dqlite)](https://goreportcard.com/report/github.com/canonical/go-dqlite) [![GoDoc](https://godoc.org/github.com/canonical/go-dqlite?status.svg)](https://godoc.org/github.com/canonical/go-dqlite) ====== This repository provides the `go-dqlite` Go package, containing bindings for the [dqlite](https://github.com/canonical/dqlite) C library and a pure-Go client for the dqlite wire [protocol](https://github.com/canonical/dqlite/blob/master/doc/protocol.md). 
Usage
-----

The best way to understand how to use the `go-dqlite` package is probably by
looking at the source code of the
[demo program](https://github.com/canonical/go-dqlite/blob/master/cmd/dqlite-demo/dqlite-demo.go)
and using it as an example.

In general your application will use code such as:

```go
dir := "/path/to/data/directory"
address := "1.2.3.4:666" // Unique node address
cluster := []string{...} // Optional list of existing nodes, when starting a new node
app, err := app.New(dir, app.WithAddress(address), app.WithCluster(cluster))
if err != nil {
        // ...
}

db, err := app.Open(context.Background(), "my-database")
if err != nil {
        // ...
}

// db is a *sql.DB object
if _, err := db.Exec("CREATE TABLE my_table (n INT)"); err != nil {
        // ...
}
```

Build
-----

In order to use the go-dqlite package in your application, you'll need to have
the [dqlite](https://github.com/canonical/dqlite) C library installed on your
system, along with its dependencies.

By default, go-dqlite's `client` module supports storing a cache of the
cluster's state in a SQLite database, locally on each cluster member. (This is
not to be confused with any SQLite databases that are managed by dqlite.) In
order to do this, it imports https://github.com/mattn/go-sqlite3, and so you
can use the `libsqlite3` build tag to control whether go-sqlite3 links to a
system libsqlite3 or builds its own. You can also disable support for SQLite
node stores entirely with the `nosqlite3` build tag (unique to go-dqlite). If
you pass this tag, your application will not link *directly* to libsqlite3
(but it will still link it *indirectly* via libdqlite, unless you've dropped
the sqlite3.c amalgamation into the dqlite build).

Documentation
-------------

The documentation for this package can be found on
[pkg.go.dev](https://pkg.go.dev/github.com/canonical/go-dqlite).

Demo
----

To see dqlite in action, either install the Debian package from the PPA:

```bash
sudo add-apt-repository -y ppa:dqlite/dev
sudo apt install dqlite-tools libdqlite-dev
```

or build the dqlite C library and its dependencies from source, as described
[here](https://github.com/canonical/dqlite#build), and then run:

```
go install -tags libsqlite3 ./cmd/dqlite-demo
```

from the top-level directory of this repository.

This builds a demo dqlite application, which exposes a simple key/value store
over an HTTP API.

Once the `dqlite-demo` binary is installed (normally under `~/go/bin` or
`/usr/bin/`), start three nodes of the demo application:

```bash
dqlite-demo --api 127.0.0.1:8001 --db 127.0.0.1:9001 &
dqlite-demo --api 127.0.0.1:8002 --db 127.0.0.1:9002 --join 127.0.0.1:9001 &
dqlite-demo --api 127.0.0.1:8003 --db 127.0.0.1:9003 --join 127.0.0.1:9001 &
```

The `--api` flag tells the demo program where to expose its HTTP API.

The `--db` flag tells the demo program to use the given address for internal
database replication.

The `--join` flag is optional and should be used only for additional nodes
after the first one. It informs them about the existing cluster, so they can
automatically join it.

Now we can start using the cluster. Let's insert a key/value pair:

```bash
curl -X PUT -d my-value http://127.0.0.1:8001/my-key
```

and then retrieve it from the database:

```bash
curl http://127.0.0.1:8001/my-key
```

Currently the first node is the leader.
If we stop it and then try to query the key again curl will fail, but we can simply change the endpoint to another node and things will work since an automatic failover has taken place: ```bash kill -TERM %1; curl http://127.0.0.1:8002/my-key ``` Shell ------ A basic SQLite-like dqlite shell is available in the `dqlite-tools` package or can be built with: ``` go install -tags libsqlite3 ./cmd/dqlite ``` ``` Usage: dqlite -s [command] [flags] ``` Example usage in the case of the `dqlite-demo` example listed above: ``` dqlite -s 127.0.0.1:9001 demo dqlite> SELECT * FROM model; my-key|my-value ``` The shell supports normal SQL queries plus the special `.cluster` and `.leader` commands to inspect the cluster members and the current leader. golang-github-canonical-go-dqlite-2.0.0/SECURITY.md000066400000000000000000000017051471100661000216400ustar00rootroot00000000000000# How to report a security issue with go-dqlite If you find a security issue with go-dqlite, the best way to report it is using GitHub's private vulnerability reporting. [Here][advisory] is the form to submit a report, and [here][docs] is the detailed documentation for the GitHub feature. Once you submit a report, the dqlite team will work with you to figure out whether there is a security issue. If so, we will develop a fix, get a CVE assigned, and coordinate the release of the fix. The [Ubuntu Security disclosure and embargo policy][policy] contains more information about what you can expect during this phase, and what we expect from you. [advisory]: https://github.com/canonical/go-dqlite/security/advisories/new [docs]: https://docs.github.com/en/code-security/security-advisories/guidance-on-reporting-and-writing-information-about-vulnerabilities/privately-reporting-a-security-vulnerability [policy]: https://ubuntu.com/security/disclosure-policy golang-github-canonical-go-dqlite-2.0.0/app/000077500000000000000000000000001471100661000206245ustar00rootroot00000000000000golang-github-canonical-go-dqlite-2.0.0/app/app.go000066400000000000000000000505751471100661000217470ustar00rootroot00000000000000package app import ( "context" "crypto/tls" "database/sql" "fmt" "net" "os" "path/filepath" "runtime" "sync" "sync/atomic" "time" "github.com/canonical/go-dqlite/v2" "github.com/canonical/go-dqlite/v2/client" "github.com/canonical/go-dqlite/v2/driver" "github.com/canonical/go-dqlite/v2/internal/protocol" "github.com/pkg/errors" "golang.org/x/sync/semaphore" ) // used to create a unique driver name, MUST be modified atomically // https://pkg.go.dev/sync/atomic#AddInt64 var driverIndex int64 // App is a high-level helper for initializing a typical dqlite-based Go // application. // // It takes care of starting a dqlite node and registering a dqlite Go SQL // driver. type App struct { id uint64 address string dir string node *dqlite.Node nodeBindAddress string listener net.Listener tls *tlsSetup dialFunc client.DialFunc store client.NodeStore lc *client.Connector driver *driver.Driver driverName string log client.LogFunc ctx context.Context stop context.CancelFunc // Signal App.run() to stop. proxyCh chan struct{} // Waits for App.proxy() to return. runCh chan struct{} // Waits for App.run() to return. readyCh chan struct{} // Waits for startup tasks voters int standbys int roles RolesConfig options *options } // New creates a new application node. 
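//
// A minimal usage sketch — the directory and address below are hypothetical
// placeholders, not defaults:
//
//	a, err := app.New("/var/lib/myapp", app.WithAddress("127.0.0.1:9001"))
//	if err != nil {
//		// handle error
//	}
//	defer a.Close()
//	db, err := a.Open(context.Background(), "my-database")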
func New(dir string, options ...Option) (app *App, err error) {
	o := defaultOptions()
	for _, option := range options {
		option(o)
	}

	var nodeBindAddress string
	if o.Conn != nil {
		listener, err := net.Listen("unix", o.UnixSocket)
		if err != nil {
			return nil, fmt.Errorf("failed to autobind unix socket: %w", err)
		}
		nodeBindAddress = listener.Addr().String()
		listener.Close()
	}

	// List of cleanup functions to run in case of errors.
	cleanups := []func(){}
	defer func() {
		if err == nil {
			return
		}
		for i := range cleanups {
			i = len(cleanups) - 1 - i // Reverse order
			cleanups[i]()
		}
	}()

	// Load our ID, or generate one if we are joining.
	info := client.NodeInfo{}
	infoFileExists, err := fileExists(dir, infoFile)
	if err != nil {
		return nil, err
	}
	if !infoFileExists {
		if o.Address == "" {
			if o.Address, err = defaultAddress(); err != nil {
				return nil, err
			}
		}
		if len(o.Cluster) == 0 {
			info.ID = dqlite.BootstrapID
		} else {
			info.ID = dqlite.GenerateID(o.Address)
			if err := fileWrite(dir, joinFile, []byte{}); err != nil {
				return nil, err
			}
		}
		info.Address = o.Address

		if err := fileMarshal(dir, infoFile, info); err != nil {
			return nil, err
		}

		cleanups = append(cleanups, func() { fileRemove(dir, infoFile) })
	} else {
		if err := fileUnmarshal(dir, infoFile, &info); err != nil {
			return nil, err
		}
		if o.Address != "" && o.Address != info.Address {
			return nil, fmt.Errorf("address %q in info.yaml does not match %q", info.Address, o.Address)
		}
	}

	joinFileExists, err := fileExists(dir, joinFile)
	if err != nil {
		return nil, err
	}

	if info.ID == dqlite.BootstrapID && joinFileExists {
		return nil, fmt.Errorf("bootstrap node can't join a cluster")
	}

	// Open the nodes store.
	storeFileExists, err := fileExists(dir, storeFile)
	if err != nil {
		return nil, err
	}
	store, err := client.NewYamlNodeStore(filepath.Join(dir, storeFile))
	if err != nil {
		return nil, fmt.Errorf("open cluster.yaml node store: %w", err)
	}

	// The info file and the store file should either both exist or both
	// be missing.
	if infoFileExists != storeFileExists {
		return nil, fmt.Errorf("inconsistent info.yaml and cluster.yaml")
	}

	if !storeFileExists {
		// If this is a brand new application node, populate the store
		// either with the node's address (for bootstrap nodes) or with
		// the given cluster addresses (for joining nodes).
		nodes := []client.NodeInfo{}
		if info.ID == dqlite.BootstrapID {
			nodes = append(nodes, client.NodeInfo{Address: info.Address})
		} else {
			if len(o.Cluster) == 0 {
				return nil, fmt.Errorf("no cluster addresses provided")
			}
			for _, address := range o.Cluster {
				nodes = append(nodes, client.NodeInfo{Address: address})
			}
		}
		if err := store.Set(context.Background(), nodes); err != nil {
			return nil, fmt.Errorf("initialize node store: %w", err)
		}
		cleanups = append(cleanups, func() { fileRemove(dir, storeFile) })
	}

	// Start the local dqlite engine.
	ctx, stop := context.WithCancel(context.Background())
	var nodeDial client.DialFunc
	if o.Conn != nil {
		nodeDial = extDialFuncWithProxy(ctx, o.Conn.dialFunc)
	} else if o.TLS != nil {
		nodeBindAddress = fmt.Sprintf("@dqlite-%d", info.ID)

		// Within a snap we need to choose a different name for the abstract unix domain
		// socket to get it past the AppArmor confinement.
// See https://github.com/snapcore/snapd/blob/master/interfaces/apparmor/template.go#L357 snapInstanceName := os.Getenv("SNAP_INSTANCE_NAME") if len(snapInstanceName) > 0 { nodeBindAddress = fmt.Sprintf("@snap.%s.dqlite-%d", snapInstanceName, info.ID) } nodeDial = makeNodeDialFunc(ctx, o.TLS.Dial) } else { nodeBindAddress = info.Address nodeDial = client.DefaultDialFunc } node, err := dqlite.New( info.ID, info.Address, dir, dqlite.WithBindAddress(nodeBindAddress), dqlite.WithDialFunc(nodeDial), dqlite.WithFailureDomain(o.FailureDomain), dqlite.WithNetworkLatency(o.NetworkLatency), dqlite.WithSnapshotParams(o.SnapshotParams), dqlite.WithDiskMode(o.DiskMode), dqlite.WithAutoRecovery(o.AutoRecovery), ) if err != nil { stop() return nil, fmt.Errorf("create node: %w", err) } if err := node.Start(); err != nil { stop() return nil, fmt.Errorf("start node: %w", err) } cleanups = append(cleanups, func() { node.Close() }) // Register the local dqlite driver. driverDial := client.DefaultDialFunc if o.TLS != nil { driverDial = client.DialFuncWithTLS(driverDial, o.TLS.Dial) } else if o.Conn != nil { driverDial = o.Conn.dialFunc } driver, err := driver.New( store, driver.WithDialFunc(driverDial), driver.WithLogFunc(o.Log), driver.WithTracing(o.Tracing), driver.WithConcurrentLeaderConns(o.ConcurrentLeaderConns), ) if err != nil { stop() return nil, fmt.Errorf("create driver: %w", err) } driverName := fmt.Sprintf("dqlite-%d", atomic.AddInt64(&driverIndex, 1)) sql.Register(driverName, driver) if o.Voters < 3 || o.Voters%2 == 0 { stop() return nil, fmt.Errorf("invalid voters %d: must be an odd number greater than 1", o.Voters) } if runtime.GOOS != "linux" && nodeBindAddress[0] == '@' { // Do not use abstract socket on other platforms and left trim "@" nodeBindAddress = nodeBindAddress[1:] } lc := client.NewLeaderConnector( store, client.WithDialFunc(driverDial), client.WithLogFunc(o.Log), client.WithConcurrentLeaderConns(*o.ConcurrentLeaderConns), ) app = &App{ id: info.ID, address: info.Address, dir: dir, node: node, nodeBindAddress: nodeBindAddress, store: store, dialFunc: driverDial, lc: lc, driver: driver, driverName: driverName, log: o.Log, tls: o.TLS, ctx: ctx, stop: stop, runCh: make(chan struct{}), readyCh: make(chan struct{}), voters: o.Voters, standbys: o.StandBys, roles: RolesConfig{Voters: o.Voters, StandBys: o.StandBys}, options: o, } // Start the proxy if a TLS configuration was provided. if o.TLS != nil { listener, err := net.Listen("tcp", info.Address) if err != nil { return nil, fmt.Errorf("listen to %s: %w", info.Address, err) } proxyCh := make(chan struct{}) app.listener = listener app.proxyCh = proxyCh go app.proxy() cleanups = append(cleanups, func() { listener.Close(); <-proxyCh }) } else if o.Conn != nil { go func() { for remote := range o.Conn.acceptCh { // keep forward compatible _, isTcp := remote.(*net.TCPConn) _, isTLS := remote.(*tls.Conn) if isTcp || isTLS { // Write the status line and upgrade header by hand since w.WriteHeader() would fail after Hijack(). 
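// (This assumes the remote peer sent an HTTP "Upgrade: dqlite"
// request before being handed over on acceptCh, e.g. a connection
// hijacked from an HTTP handler; replying with a fixed 101 response
// lets both ends switch to the raw dqlite protocol.)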
					data := []byte("HTTP/1.1 101 Switching Protocols\r\nUpgrade: dqlite\r\n\r\n")
					n, err := remote.Write(data)
					if err != nil || n != len(data) {
						remote.Close()
						panic(fmt.Errorf("failed to write connection header: %w", err))
					}
				}

				local, err := net.Dial("unix", nodeBindAddress)
				if err != nil {
					remote.Close()
					panic(fmt.Errorf("failed to connect to bind address %q: %w", nodeBindAddress, err))
				}

				go proxy(app.ctx, remote, local, nil)
			}
		}()
	}

	go app.run(ctx, o, joinFileExists)

	return app, nil
}

// Handover transfers all responsibilities for this node (such as leadership
// and voting rights) to another node, if one is available.
//
// This method should always be called before invoking Close(), in order to
// gracefully shut down a node.
func (a *App) Handover(ctx context.Context) error {
	// Set a hard limit of one minute, in case the user-provided context
	// has no expiration. That avoids the call to stop responding forever
	// in case a majority of the cluster is down and no leader is available.
	// Watch out when removing or editing this context: without it, the for
	// loop at the end of this function may run forever.
	var cancel context.CancelFunc
	ctx, cancel = context.WithTimeout(ctx, time.Minute)
	defer cancel()

	cli, err := a.FindLeader(ctx)
	if err != nil {
		return fmt.Errorf("find leader: %w", err)
	}
	defer cli.Close()

	// Possibly transfer our role.
	nodes, err := cli.Cluster(ctx)
	if err != nil {
		return fmt.Errorf("cluster servers: %w", err)
	}

	changes := a.makeRolesChanges(nodes)

	role, candidates := changes.Handover(a.id)

	if role != -1 {
		for i, node := range candidates {
			if err := cli.Assign(ctx, node.ID, role); err != nil {
				a.warn("promote %s from %s to %s: %v", node.Address, node.Role, role, err)
				if i == len(candidates)-1 {
					// We could not promote any node
					return fmt.Errorf("could not promote any online node to %s", role)
				}
				continue
			}
			a.debug("promoted %s from %s to %s", node.Address, node.Role, role)
			break
		}
	}

	// Check if we are the current leader and transfer leadership if so.
	leader, err := cli.Leader(ctx)
	if err != nil {
		return fmt.Errorf("leader address: %w", err)
	}
	if leader != nil && leader.Address == a.address {
		nodes, err := cli.Cluster(ctx)
		if err != nil {
			return fmt.Errorf("cluster servers: %w", err)
		}
		changes := a.makeRolesChanges(nodes)
		voters := changes.list(client.Voter, true, nil)

		for i, voter := range voters {
			if voter.Address == a.address {
				continue
			}
			if err := cli.Transfer(ctx, voter.ID); err != nil {
				a.warn("transfer leadership to %s: %v", voter.Address, err)
				if i == len(voters)-1 {
					return fmt.Errorf("transfer leadership: %w", err)
				}
			}
			cli, err = a.FindLeader(ctx)
			if err != nil {
				return fmt.Errorf("find new leader: %w", err)
			}
			defer cli.Close()
			break
		}
	}

	// Demote ourselves if we have promoted someone else.
	if role != -1 {
		// Retry for a while before failing. The new leader may need to
		// commit an entry from its new term before it can commit the
		// last configuration change, so wait a bit for that to happen
		// instead of failing immediately.
		for {
			err = cli.Assign(ctx, a.ID(), client.Spare)
			if err == nil {
				return nil
			}
			select {
			case <-ctx.Done():
				return fmt.Errorf("demote ourselves context done: %w", err)
			default:
				// Wait a bit before trying again
				time.Sleep(time.Second)
				continue
			}
		}
	}

	return nil
}

// Close the application node, releasing all resources it created.
func (a *App) Close() error {
	// Stop the run goroutine.
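	// (Callers that want a graceful shutdown are expected to call
	// Handover() before Close(); see the Handover doc comment above.)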
	a.stop()
	<-a.runCh

	if a.listener != nil {
		a.listener.Close()
		<-a.proxyCh
	}

	if err := a.node.Close(); err != nil {
		return err
	}

	return nil
}

// ID returns the dqlite ID of this application node.
func (a *App) ID() uint64 {
	return a.id
}

// Address returns the dqlite address of this application node.
func (a *App) Address() string {
	return a.address
}

// Driver returns the name used to register the dqlite driver.
func (a *App) Driver() string {
	return a.driverName
}

// Ready can be used to wait for a node to complete some initial tasks that are
// initiated at startup. For example a brand new node will attempt to join the
// cluster, a restarted node will check if it should assume some particular
// role, etc.
//
// If this method returns without error it means that those initial tasks have
// succeeded and follow-up operations like Open() are more likely to succeed
// quickly.
func (a *App) Ready(ctx context.Context) error {
	select {
	case <-a.readyCh:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}

// Open the dqlite database with the given name.
func (a *App) Open(ctx context.Context, database string) (*sql.DB, error) {
	db, err := sql.Open(a.Driver(), database)
	if err != nil {
		return nil, err
	}

	for i := 0; i < 60; i++ {
		err = db.PingContext(ctx)
		if err == nil {
			break
		}
		cause := errors.Cause(err)
		if cause != driver.ErrNoAvailableLeader {
			return nil, err
		}
		time.Sleep(time.Second)
	}
	if err != nil {
		return nil, err
	}

	return db, nil
}

// Leader returns a client connected to the cluster leader.
//
// Prefer to use FindLeader instead unless you need to pass custom options.
func (a *App) Leader(ctx context.Context, options ...client.Option) (*client.Client, error) {
	allOptions := a.clientOptions()
	allOptions = append(allOptions, options...)
	return client.FindLeader(ctx, a.store, allOptions...)
}

// FindLeader returns a client connected to the cluster leader.
//
// Compared to Leader, this method avoids opening extra connections in many
// cases, but doesn't accept custom options.
func (a *App) FindLeader(ctx context.Context) (*client.Client, error) {
	return a.lc.Connect(ctx)
}

// Client returns a client connected to the local node.
func (a *App) Client(ctx context.Context) (*client.Client, error) {
	return client.New(ctx, a.nodeBindAddress)
}

// Proxy incoming TLS connections.
func (a *App) proxy() {
	wg := sync.WaitGroup{}
	ctx, cancel := context.WithCancel(a.ctx)
	for {
		client, err := a.listener.Accept()
		if err != nil {
			cancel()
			wg.Wait()
			close(a.proxyCh)
			return
		}
		address := client.RemoteAddr()
		a.debug("new connection from %s", address)
		server, err := net.Dial("unix", a.nodeBindAddress)
		if err != nil {
			a.error("dial local node: %v", err)
			client.Close()
			continue
		}
		wg.Add(1)
		go func() {
			defer wg.Done()
			if err := proxy(ctx, client, server, a.tls.Listen); err != nil {
				a.error("proxy: %v", err)
			}
		}()
	}
}

// Run background tasks. The join flag is true if the node is a brand new one
// and should join the cluster.
func (a *App) run(ctx context.Context, options *options, join bool) {
	defer close(a.runCh)

	delay := time.Duration(0)
	ready := false
	for {
		select {
		case <-ctx.Done():
			// If we didn't become ready yet, close the ready
			// channel, to unblock any call to Ready().
			if !ready {
				close(a.readyCh)
			}
			return
		case <-time.After(delay):
			cli, err := a.FindLeader(ctx)
			if err != nil {
				continue
			}

			// Attempt to join the cluster if this is a brand new node.
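			// (New nodes are added with the Spare role; promotion
			// to voter or stand-by happens later, via
			// maybePromoteOurselves and maybeAdjustRoles.)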
			if join {
				info := client.NodeInfo{ID: a.id, Address: a.address, Role: client.Spare}
				if err := cli.Add(ctx, info); err != nil {
					a.warn("join cluster: %v", err)
					delay = time.Second
					cli.Close()
					continue
				}
				join = false
				if err := fileRemove(a.dir, joinFile); err != nil {
					a.error("remove join file: %v", err)
				}
			}

			// Refresh our node store.
			servers, err := cli.Cluster(ctx)
			if err != nil {
				cli.Close()
				continue
			}
			if len(servers) == 0 {
				a.warn("server list empty")
				cli.Close()
				continue
			}
			a.store.Set(ctx, servers)

			// If we are starting up, let's see if we should
			// promote ourselves.
			if !ready {
				if err := a.maybePromoteOurselves(ctx, cli, servers); err != nil {
					a.warn("%v", err)
					delay = time.Second
					cli.Close()
					continue
				}
				ready = true
				delay = options.RolesAdjustmentFrequency
				close(a.readyCh)
				cli.Close()
				continue
			}

			// If we are the leader, let's see if there's any
			// adjustment we should make to node roles.
			if err := a.maybeAdjustRoles(ctx, cli); err != nil {
				a.warn("adjust roles: %v", err)
			}
			leader, err := cli.Leader(ctx)
			if err != nil {
				a.error("fetch leader info: %v", err)
				cli.Close()
				continue
			}
			err = options.OnRolesAdjustment(*leader, servers)
			if err != nil {
				a.warn("roles adjustment hook: %v", err)
			}
			cli.Close()
		}
	}
}

// Possibly change our own role at startup.
func (a *App) maybePromoteOurselves(ctx context.Context, cli *client.Client, nodes []client.NodeInfo) error {
	roles := a.makeRolesChanges(nodes)

	role := roles.Assume(a.id)
	if role == -1 {
		return nil
	}

	// Promote ourselves.
	if err := cli.Assign(ctx, a.id, role); err != nil {
		return fmt.Errorf("assign %s role to ourselves: %v", role, err)
	}

	// Possibly try to promote another node as well if we've reached the 3
	// node threshold. If we don't succeed in doing that, errors are
	// ignored since the leader will eventually notice that it doesn't have
	// enough voters and will retry.
	if role == client.Voter && roles.count(client.Voter, true) == 1 {
		for node := range roles.State {
			if node.ID == a.id || node.Role == client.Voter {
				continue
			}
			if err := cli.Assign(ctx, node.ID, client.Voter); err == nil {
				break
			} else {
				a.warn("promote %s from %s to voter: %v", node.Address, node.Role, err)
			}
		}
	}

	return nil
}

// Check if any adjustment needs to be made to existing roles.
func (a *App) maybeAdjustRoles(ctx context.Context, cli *client.Client) error {
again:
	info, err := cli.Leader(ctx)
	if err != nil {
		return err
	}
	if info.ID != a.id {
		return nil
	}

	nodes, err := cli.Cluster(ctx)
	if err != nil {
		return err
	}

	roles := a.makeRolesChanges(nodes)

	role, nodes := roles.Adjust(a.id)
	if role == -1 {
		return nil
	}

	for i, node := range nodes {
		if err := cli.Assign(ctx, node.ID, role); err != nil {
			a.warn("change %s from %s to %s: %v", node.Address, node.Role, role, err)
			if i == len(nodes)-1 {
				// We could not change any node
				return fmt.Errorf("could not assign role %s to any node", role)
			}
			continue
		}
		break
	}

	goto again
}

// Probe all given nodes for connectivity and metadata, then return a
// RolesChanges object.
func (a *App) makeRolesChanges(nodes []client.NodeInfo) RolesChanges {
	state := map[client.NodeInfo]*client.NodeMetadata{}
	for _, node := range nodes {
		state[node] = nil
	}

	var (
		mtx     sync.Mutex     // Protects state map
		wg      sync.WaitGroup // Wait for all probes to finish
		nProbes = runtime.NumCPU()
		sem     = semaphore.NewWeighted(int64(nProbes)) // Limit number of parallel probes
	)

	for _, node := range nodes {
		wg.Add(1)
		// sem.Acquire will not block forever because the goroutines
		// that release the semaphore will eventually time out.
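		// (Each probe goroutine below creates a context with a
		// 2-second timeout, so acquired slots are always released.)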
if err := sem.Acquire(context.Background(), 1); err != nil { a.warn("failed to acquire semaphore: %v", err) wg.Done() continue } go func(node protocol.NodeInfo) { defer wg.Done() defer sem.Release(1) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() cli, err := client.New(ctx, node.Address, a.clientOptions()...) if err == nil { metadata, err := cli.Describe(ctx) if err == nil { mtx.Lock() state[node] = metadata mtx.Unlock() } cli.Close() } }(node) } wg.Wait() return RolesChanges{Config: a.roles, State: state} } // Return the options to use for client.FindLeader() or client.New() func (a *App) clientOptions() []client.Option { return []client.Option{ client.WithDialFunc(a.dialFunc), client.WithLogFunc(a.log), client.WithConcurrentLeaderConns(*a.options.ConcurrentLeaderConns), } } func (a *App) debug(format string, args ...interface{}) { a.log(client.LogDebug, format, args...) } //lint:ignore U1000 not currently used but preserved for consistency func (a *App) info(format string, args ...interface{}) { a.log(client.LogInfo, format, args...) } func (a *App) warn(format string, args ...interface{}) { a.log(client.LogWarn, format, args...) } func (a *App) error(format string, args ...interface{}) { a.log(client.LogError, format, args...) } golang-github-canonical-go-dqlite-2.0.0/app/app_go1.18_test.go000066400000000000000000000076551471100661000240040ustar00rootroot00000000000000//go:build go1.18 // +build go1.18 package app_test // import ( // "context" // "crypto/tls" // "net" // "testing" // "github.com/canonical/go-dqlite/v2/app" // "github.com/canonical/go-dqlite/v2/client" // "github.com/quic-go/quic-go" // "github.com/stretchr/testify/assert" // "github.com/stretchr/testify/require" // ) // // quic.Stream doesn't implement net.Conn, so we need to wrap it. // type quicConn struct { // quic.Stream // } // func (c *quicConn) LocalAddr() net.Addr { // return nil // } // func (c *quicConn) RemoteAddr() net.Addr { // return nil // } // // TestExternalConnWithQUIC creates a 3-member cluster using external quic connection // // and ensures the cluster is successfully created, and that the connection is // // handled manually. 
// func TestExternalConnWithQUIC(t *testing.T) { // externalAddr1 := "127.0.0.1:9191" // externalAddr2 := "127.0.0.1:9292" // externalAddr3 := "127.0.0.1:9393" // acceptCh1 := make(chan net.Conn) // acceptCh2 := make(chan net.Conn) // acceptCh3 := make(chan net.Conn) // dialFunc := func(ctx context.Context, addr string) (net.Conn, error) { // conn, err := quic.DialAddrContext(ctx, addr, &tls.Config{InsecureSkipVerify: true, NextProtos: []string{"quic"}}, nil) // require.NoError(t, err) // stream, err := conn.OpenStreamSync(ctx) // require.NoError(t, err) // return &quicConn{ // Stream: stream, // }, nil // } // cert, pool := loadCert(t) // tlsconfig := app.SimpleListenTLSConfig(cert, pool) // tlsconfig.NextProtos = []string{"quic"} // tlsconfig.ClientAuth = tls.NoClientCert // serveQUIC := func(addr string, acceptCh chan net.Conn, cleanups chan func()) { // lis, err := quic.ListenAddr(addr, tlsconfig, nil) // require.NoError(t, err) // ctx, cancel := context.WithCancel(context.Background()) // go func() { // for { // select { // case <-ctx.Done(): // return // default: // conn, err := lis.Accept(context.Background()) // if err != nil { // return // } // stream, err := conn.AcceptStream(context.Background()) // if err != nil { // return // } // acceptCh <- &quicConn{ // Stream: stream, // } // } // } // }() // cleanup := func() { // cancel() // require.NoError(t, lis.Close()) // } // cleanups <- cleanup // } // liscleanups := make(chan func(), 3) // // Start up three listeners. // go serveQUIC(externalAddr1, acceptCh1, liscleanups) // go serveQUIC(externalAddr2, acceptCh2, liscleanups) // go serveQUIC(externalAddr3, acceptCh3, liscleanups) // defer func() { // for i := 0; i < 3; i++ { // cleanup := <-liscleanups // cleanup() // } // close(liscleanups) // }() // app1, cleanup := newAppWithNoTLS(t, app.WithAddress(externalAddr1), app.WithExternalConn(dialFunc, acceptCh1)) // defer cleanup() // app2, cleanup := newAppWithNoTLS(t, app.WithAddress(externalAddr2), app.WithExternalConn(dialFunc, acceptCh2), app.WithCluster([]string{externalAddr1})) // defer cleanup() // require.NoError(t, app2.Ready(context.Background())) // app3, cleanup := newAppWithNoTLS(t, app.WithAddress(externalAddr3), app.WithExternalConn(dialFunc, acceptCh3), app.WithCluster([]string{externalAddr1})) // defer cleanup() // require.NoError(t, app3.Ready(context.Background())) // // Get a client from the first node (likely the leader). // cli, err := app1.Leader(context.Background()) // require.NoError(t, err) // defer cli.Close() // // Ensure entries exist for each cluster member. // cluster, err := cli.Cluster(context.Background()) // require.NoError(t, err) // assert.Equal(t, externalAddr1, cluster[0].Address) // assert.Equal(t, externalAddr2, cluster[1].Address) // assert.Equal(t, externalAddr3, cluster[2].Address) // // Every cluster member should be a voter. 
// 	assert.Equal(t, client.Voter, cluster[0].Role)
// 	assert.Equal(t, client.Voter, cluster[1].Role)
// 	assert.Equal(t, client.Voter, cluster[2].Role)
// }
golang-github-canonical-go-dqlite-2.0.0/app/app_test.go000066400000000000000000001063201471100661000227740ustar00rootroot00000000000000package app_test

import (
	"bufio"
	"context"
	"crypto/tls"
	"crypto/x509"
	"database/sql"
	"encoding/binary"
	"fmt"
	"io/ioutil"
	"net"
	"net/http"
	"net/url"
	"os"
	"path/filepath"
	"strings"
	"testing"
	"time"

	"github.com/canonical/go-dqlite/v2"
	"github.com/canonical/go-dqlite/v2/app"
	"github.com/canonical/go-dqlite/v2/client"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// Create a pristine bootstrap node with default values.
func TestNew_PristineDefault(t *testing.T) {
	_, cleanup := newApp(t, app.WithAddress("127.0.0.1:9000"))
	defer cleanup()
}

// Create a pristine joining node.
func TestNew_PristineJoiner(t *testing.T) {
	addr1 := "127.0.0.1:9001"
	addr2 := "127.0.0.1:9002"

	app1, cleanup := newApp(t, app.WithAddress(addr1))
	defer cleanup()

	app2, cleanup := newApp(t, app.WithAddress(addr2), app.WithCluster([]string{addr1}))
	defer cleanup()

	require.NoError(t, app2.Ready(context.Background()))

	// The joining node should appear in the cluster list.
	cli, err := app1.Leader(context.Background())
	require.NoError(t, err)
	defer cli.Close()

	cluster, err := cli.Cluster(context.Background())
	require.NoError(t, err)
	assert.Equal(t, addr1, cluster[0].Address)
	assert.Equal(t, addr2, cluster[1].Address)

	// Initially the node joins as spare.
	assert.Equal(t, client.Voter, cluster[0].Role)
	assert.Equal(t, client.Spare, cluster[1].Role)
}

// Restart a node that had previously joined the cluster successfully.
func TestNew_JoinerRestart(t *testing.T) {
	addr1 := "127.0.0.1:9001"
	addr2 := "127.0.0.1:9002"

	app1, cleanup := newApp(t, app.WithAddress(addr1))
	defer cleanup()

	require.NoError(t, app1.Ready(context.Background()))

	dir2, cleanup := newDir(t)
	defer cleanup()

	app2, cleanup := newAppWithDir(t, dir2, app.WithAddress(addr2), app.WithCluster([]string{addr1}))
	require.NoError(t, app2.Ready(context.Background()))
	cleanup()

	app2, cleanup = newAppWithDir(t, dir2, app.WithAddress(addr2))
	defer cleanup()

	require.NoError(t, app2.Ready(context.Background()))
}

// The second joiner promotes itself and also the first joiner.
func TestNew_SecondJoiner(t *testing.T) {
	addr1 := "127.0.0.1:9001"
	addr2 := "127.0.0.1:9002"
	addr3 := "127.0.0.1:9003"

	app1, cleanup := newApp(t, app.WithAddress(addr1))
	defer cleanup()

	app2, cleanup := newApp(t, app.WithAddress(addr2), app.WithCluster([]string{addr1}))
	defer cleanup()

	require.NoError(t, app2.Ready(context.Background()))

	app3, cleanup := newApp(t, app.WithAddress(addr3), app.WithCluster([]string{addr1}))
	defer cleanup()

	require.NoError(t, app3.Ready(context.Background()))

	cli, err := app1.Leader(context.Background())
	require.NoError(t, err)
	defer cli.Close()

	cluster, err := cli.Cluster(context.Background())
	require.NoError(t, err)

	assert.Equal(t, addr1, cluster[0].Address)
	assert.Equal(t, addr2, cluster[1].Address)
	assert.Equal(t, addr3, cluster[2].Address)

	assert.Equal(t, client.Voter, cluster[0].Role)
	assert.Equal(t, client.Voter, cluster[1].Role)
	assert.Equal(t, client.Voter, cluster[2].Role)
}

// The third joiner gets the stand-by role.
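// (The expectations in this and the following tests match the default roles
// configuration — 3 voters and 3 stand-bys, as the assertions below suggest:
// later joiners become stand-bys, and once those slots are filled, spares.)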
func TestNew_ThirdJoiner(t *testing.T) { apps := []*app.App{} for i := 0; i < 4; i++ { addr := fmt.Sprintf("127.0.0.1:900%d", i+1) options := []app.Option{app.WithAddress(addr)} if i > 0 { options = append(options, app.WithCluster([]string{"127.0.0.1:9001"})) } app, cleanup := newApp(t, options...) defer cleanup() require.NoError(t, app.Ready(context.Background())) apps = append(apps, app) } cli, err := apps[0].Leader(context.Background()) require.NoError(t, err) defer cli.Close() cluster, err := cli.Cluster(context.Background()) require.NoError(t, err) assert.Equal(t, client.Voter, cluster[0].Role) assert.Equal(t, client.Voter, cluster[1].Role) assert.Equal(t, client.Voter, cluster[2].Role) assert.Equal(t, client.StandBy, cluster[3].Role) } // The fourth joiner gets the stand-by role. func TestNew_FourthJoiner(t *testing.T) { apps := []*app.App{} for i := 0; i < 5; i++ { addr := fmt.Sprintf("127.0.0.1:900%d", i+1) options := []app.Option{app.WithAddress(addr)} if i > 0 { options = append(options, app.WithCluster([]string{"127.0.0.1:9001"})) } app, cleanup := newApp(t, options...) defer cleanup() require.NoError(t, app.Ready(context.Background())) apps = append(apps, app) } cli, err := apps[0].Leader(context.Background()) require.NoError(t, err) defer cli.Close() cluster, err := cli.Cluster(context.Background()) require.NoError(t, err) assert.Equal(t, client.Voter, cluster[0].Role) assert.Equal(t, client.Voter, cluster[1].Role) assert.Equal(t, client.Voter, cluster[2].Role) assert.Equal(t, client.StandBy, cluster[3].Role) assert.Equal(t, client.StandBy, cluster[4].Role) } // The fifth joiner gets the stand-by role. func TestNew_FifthJoiner(t *testing.T) { apps := []*app.App{} for i := 0; i < 6; i++ { addr := fmt.Sprintf("127.0.0.1:900%d", i+1) options := []app.Option{app.WithAddress(addr)} if i > 0 { options = append(options, app.WithCluster([]string{"127.0.0.1:9001"})) } app, cleanup := newApp(t, options...) defer cleanup() require.NoError(t, app.Ready(context.Background())) apps = append(apps, app) } cli, err := apps[0].Leader(context.Background()) require.NoError(t, err) defer cli.Close() cluster, err := cli.Cluster(context.Background()) require.NoError(t, err) assert.Equal(t, client.Voter, cluster[0].Role) assert.Equal(t, client.Voter, cluster[1].Role) assert.Equal(t, client.Voter, cluster[2].Role) assert.Equal(t, client.StandBy, cluster[3].Role) assert.Equal(t, client.StandBy, cluster[4].Role) assert.Equal(t, client.StandBy, cluster[5].Role) } // The sixth joiner gets the spare role. func TestNew_SixthJoiner(t *testing.T) { apps := []*app.App{} for i := 0; i < 7; i++ { addr := fmt.Sprintf("127.0.0.1:900%d", i+1) options := []app.Option{app.WithAddress(addr)} if i > 0 { options = append(options, app.WithCluster([]string{"127.0.0.1:9001"})) } app, cleanup := newApp(t, options...) defer cleanup() require.NoError(t, app.Ready(context.Background())) apps = append(apps, app) } cli, err := apps[0].Leader(context.Background()) require.NoError(t, err) defer cli.Close() cluster, err := cli.Cluster(context.Background()) require.NoError(t, err) assert.Equal(t, client.Voter, cluster[0].Role) assert.Equal(t, client.Voter, cluster[1].Role) assert.Equal(t, client.Voter, cluster[2].Role) assert.Equal(t, client.StandBy, cluster[3].Role) assert.Equal(t, client.StandBy, cluster[4].Role) assert.Equal(t, client.StandBy, cluster[5].Role) assert.Equal(t, client.Spare, cluster[6].Role) } // Transfer voting rights to another online node. 
func TestHandover_Voter(t *testing.T) {
	n := 4
	apps := make([]*app.App, n)

	for i := 0; i < n; i++ {
		addr := fmt.Sprintf("127.0.0.1:900%d", i+1)
		options := []app.Option{app.WithAddress(addr)}
		if i > 0 {
			options = append(options, app.WithCluster([]string{"127.0.0.1:9001"}))
		}

		app, cleanup := newApp(t, options...)
		defer cleanup()

		require.NoError(t, app.Ready(context.Background()))

		apps[i] = app
	}

	cli, err := apps[0].Leader(context.Background())
	require.NoError(t, err)
	defer cli.Close()

	cluster, err := cli.Cluster(context.Background())
	require.NoError(t, err)

	assert.Equal(t, client.Voter, cluster[0].Role)
	assert.Equal(t, client.Voter, cluster[1].Role)
	assert.Equal(t, client.Voter, cluster[2].Role)
	assert.Equal(t, client.StandBy, cluster[3].Role)

	require.NoError(t, apps[2].Handover(context.Background()))

	cluster, err = cli.Cluster(context.Background())
	require.NoError(t, err)

	assert.Equal(t, client.Voter, cluster[0].Role)
	assert.Equal(t, client.Voter, cluster[1].Role)
	assert.Equal(t, client.Spare, cluster[2].Role)
	assert.Equal(t, client.Voter, cluster[3].Role)
}

// In a two-node cluster only one of them is a voter. When Handover() is called
// on the voter, the role and leadership are transferred.
func TestHandover_TwoNodes(t *testing.T) {
	n := 2
	apps := make([]*app.App, n)

	for i := 0; i < n; i++ {
		addr := fmt.Sprintf("127.0.0.1:900%d", i+1)
		options := []app.Option{app.WithAddress(addr)}
		if i > 0 {
			options = append(options, app.WithCluster([]string{"127.0.0.1:9001"}))
		}

		app, cleanup := newApp(t, options...)
		defer cleanup()

		require.NoError(t, app.Ready(context.Background()))

		apps[i] = app
	}

	err := apps[0].Handover(context.Background())
	require.NoError(t, err)

	cli, err := apps[1].Leader(context.Background())
	require.NoError(t, err)
	defer cli.Close()

	cluster, err := cli.Cluster(context.Background())
	require.NoError(t, err)
	assert.Equal(t, client.Spare, cluster[0].Role)
	assert.Equal(t, client.Voter, cluster[1].Role)
}

// Transfer voting rights to another online node. Failure domains are taken
// into account.
func TestHandover_VoterHonorFailureDomain(t *testing.T) {
	n := 6
	apps := make([]*app.App, n)

	for i := 0; i < n; i++ {
		addr := fmt.Sprintf("127.0.0.1:900%d", i+1)
		options := []app.Option{
			app.WithAddress(addr),
			app.WithFailureDomain(uint64(i % 3)),
		}
		if i > 0 {
			options = append(options, app.WithCluster([]string{"127.0.0.1:9001"}))
		}

		app, cleanup := newApp(t, options...)
		defer cleanup()

		require.NoError(t, app.Ready(context.Background()))

		apps[i] = app
	}

	cli, err := apps[0].Leader(context.Background())
	require.NoError(t, err)
	defer cli.Close()

	_, err = cli.Cluster(context.Background())
	require.NoError(t, err)

	require.NoError(t, apps[2].Handover(context.Background()))

	cluster, err := cli.Cluster(context.Background())
	require.NoError(t, err)

	assert.Equal(t, client.Voter, cluster[0].Role)
	assert.Equal(t, client.Voter, cluster[1].Role)
	assert.Equal(t, client.Spare, cluster[2].Role)
	assert.Equal(t, client.StandBy, cluster[3].Role)
	assert.Equal(t, client.StandBy, cluster[4].Role)
	assert.Equal(t, client.Voter, cluster[5].Role)
}

// Handover with a single node.
func TestHandover_SingleNode(t *testing.T) {
	dir, cleanup := newDir(t)
	defer cleanup()

	app, err := app.New(dir, app.WithAddress("127.0.0.1:9001"))
	require.NoError(t, err)
	require.NoError(t, app.Ready(context.Background()))

	require.NoError(t, app.Handover(context.Background()))
	require.NoError(t, app.Close())
}

// Exercise a sequential graceful shutdown of a 3-node cluster.
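// (Each node calls Handover() before Close(), mirroring the shutdown
// sequence recommended in the Handover doc comment.)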
func TestHandover_GracefulShutdown(t *testing.T) { n := 3 apps := make([]*app.App, n) for i := 0; i < n; i++ { dir, cleanup := newDir(t) defer cleanup() addr := fmt.Sprintf("127.0.0.1:900%d", i+1) log := func(l client.LogLevel, format string, a ...interface{}) { format = fmt.Sprintf("%s - %d: %s: %s", time.Now().Format("15:04:05.000"), i, l.String(), format) t.Logf(format, a...) } options := []app.Option{ app.WithAddress(addr), app.WithLogFunc(log), } if i > 0 { options = append(options, app.WithCluster([]string{"127.0.0.1:9001"})) } app, err := app.New(dir, options...) require.NoError(t, err) require.NoError(t, app.Ready(context.Background())) apps[i] = app } db, err := sql.Open(apps[0].Driver(), "test.db") require.NoError(t, err) _, err = db.Exec("CREATE TABLE test (n INT)") require.NoError(t, err) require.NoError(t, db.Close()) require.NoError(t, apps[0].Handover(context.Background())) require.NoError(t, apps[0].Close()) require.NoError(t, apps[1].Handover(context.Background())) require.NoError(t, apps[1].Close()) require.NoError(t, apps[2].Handover(context.Background())) require.NoError(t, apps[2].Close()) } // Transfer the stand-by role to another online node. func TestHandover_StandBy(t *testing.T) { n := 7 apps := make([]*app.App, n) for i := 0; i < n; i++ { addr := fmt.Sprintf("127.0.0.1:900%d", i+1) options := []app.Option{app.WithAddress(addr)} if i > 0 { options = append(options, app.WithCluster([]string{"127.0.0.1:9001"})) } app, cleanup := newApp(t, options...) defer cleanup() require.NoError(t, app.Ready(context.Background())) apps[i] = app } cli, err := apps[0].Leader(context.Background()) require.NoError(t, err) defer cli.Close() cluster, err := cli.Cluster(context.Background()) require.NoError(t, err) assert.Equal(t, client.Voter, cluster[0].Role) assert.Equal(t, client.Voter, cluster[1].Role) assert.Equal(t, client.Voter, cluster[2].Role) assert.Equal(t, client.StandBy, cluster[3].Role) assert.Equal(t, client.StandBy, cluster[4].Role) assert.Equal(t, client.StandBy, cluster[5].Role) assert.Equal(t, client.Spare, cluster[6].Role) require.NoError(t, apps[4].Handover(context.Background())) cluster, err = cli.Cluster(context.Background()) require.NoError(t, err) assert.Equal(t, client.Voter, cluster[0].Role) assert.Equal(t, client.Voter, cluster[1].Role) assert.Equal(t, client.Voter, cluster[2].Role) assert.Equal(t, client.StandBy, cluster[3].Role) assert.Equal(t, client.Spare, cluster[4].Role) assert.Equal(t, client.StandBy, cluster[5].Role) assert.Equal(t, client.StandBy, cluster[6].Role) } // Transfer leadership and voting rights to another node. func TestHandover_TransferLeadership(t *testing.T) { n := 4 apps := make([]*app.App, n) for i := 0; i < n; i++ { addr := fmt.Sprintf("127.0.0.1:900%d", i+1) options := []app.Option{app.WithAddress(addr)} if i > 0 { options = append(options, app.WithCluster([]string{"127.0.0.1:9001"})) } app, cleanup := newApp(t, options...)
defer cleanup() require.NoError(t, app.Ready(context.Background())) apps[i] = app } cli, err := apps[0].Leader(context.Background()) require.NoError(t, err) defer cli.Close() leader, err := cli.Leader(context.Background()) require.NoError(t, err) require.NotNil(t, leader) require.Equal(t, apps[0].ID(), leader.ID) require.NoError(t, apps[0].Handover(context.Background())) cli, err = apps[0].Leader(context.Background()) require.NoError(t, err) defer cli.Close() leader, err = cli.Leader(context.Background()) require.NoError(t, err) assert.NotEqual(t, apps[0].ID(), leader.ID) cluster, err := cli.Cluster(context.Background()) require.NoError(t, err) assert.Equal(t, client.Spare, cluster[0].Role) assert.Equal(t, client.Voter, cluster[1].Role) assert.Equal(t, client.Voter, cluster[2].Role) assert.Equal(t, client.Voter, cluster[3].Role) } // If a voter goes offline, another node takes its place. func TestRolesAdjustment_ReplaceVoter(t *testing.T) { n := 4 apps := make([]*app.App, n) cleanups := make([]func(), n) for i := 0; i < n; i++ { addr := fmt.Sprintf("127.0.0.1:900%d", i+1) options := []app.Option{ app.WithAddress(addr), app.WithRolesAdjustmentFrequency(2 * time.Second), } if i > 0 { options = append(options, app.WithCluster([]string{"127.0.0.1:9001"})) } app, cleanup := newApp(t, options...) require.NoError(t, app.Ready(context.Background())) apps[i] = app cleanups[i] = cleanup } defer cleanups[0]() defer cleanups[1]() defer cleanups[3]() // A voter goes offline. cleanups[2]() time.Sleep(8 * time.Second) cli, err := apps[0].Leader(context.Background()) require.NoError(t, err) defer cli.Close() cluster, err := cli.Cluster(context.Background()) require.NoError(t, err) assert.Equal(t, client.Voter, cluster[0].Role) assert.Equal(t, client.Voter, cluster[1].Role) assert.Equal(t, client.Spare, cluster[2].Role) assert.Equal(t, client.Voter, cluster[3].Role) } // If a voter goes offline, another node takes its place. If possible, pick a // voter from a failure domain which differs from the one of the two other // voters. func TestRolesAdjustment_ReplaceVoterHonorFailureDomain(t *testing.T) { n := 6 apps := make([]*app.App, n) cleanups := make([]func(), n) for i := 0; i < n; i++ { addr := fmt.Sprintf("127.0.0.1:900%d", i+1) options := []app.Option{ app.WithAddress(addr), app.WithRolesAdjustmentFrequency(4 * time.Second), app.WithFailureDomain(uint64(i % 3)), } if i > 0 { options = append(options, app.WithCluster([]string{"127.0.0.1:9001"})) } app, cleanup := newApp(t, options...) require.NoError(t, app.Ready(context.Background())) apps[i] = app cleanups[i] = cleanup } defer cleanups[0]() defer cleanups[1]() defer cleanups[3]() defer cleanups[4]() defer cleanups[5]() // A voter in failure domain 2 goes offline. cleanups[2]() time.Sleep(18 * time.Second) cli, err := apps[0].Leader(context.Background()) require.NoError(t, err) defer cli.Close() cluster, err := cli.Cluster(context.Background()) require.NoError(t, err) // The replacement was picked from the same failure domain as the offline voter. assert.Equal(t, client.Voter, cluster[0].Role) assert.Equal(t, client.Voter, cluster[1].Role) assert.Equal(t, client.Spare, cluster[2].Role) assert.Equal(t, client.StandBy, cluster[3].Role) assert.Equal(t, client.StandBy, cluster[4].Role) assert.Equal(t, client.Voter, cluster[5].Role) } // If the cluster is imbalanced (all voters in one failure domain), roles get re-shuffled.
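//
// "Re-shuffled" here means that the leader's periodic roles adjustment
// promotes a node from an uncovered failure domain and demotes a voter from
// the crowded one, until every populated domain holds at least one voter,
// which is what the final assertions of this test check.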
func TestRolesAdjustment_ImbalancedFailureDomain(t *testing.T) { n := 8 apps := make([]*app.App, n) cleanups := make([]func(), n) for i := 0; i < n; i++ { addr := fmt.Sprintf("127.0.0.1:900%d", i+1) // Nodes with index above n/2 go to failure domain 1, the rest to failure domain 0. fd := 0 if i > n/2 { fd = 1 } options := []app.Option{ app.WithAddress(addr), app.WithRolesAdjustmentFrequency(4 * time.Second), app.WithFailureDomain(uint64(fd)), } if i > 0 { options = append(options, app.WithCluster([]string{"127.0.0.1:9001"})) } // Nodes on failure domain 0 are started first so all voters are initially there. app, cleanup := newApp(t, options...) require.NoError(t, app.Ready(context.Background())) apps[i] = app cleanups[i] = cleanup } for i := 0; i < n; i++ { defer cleanups[i]() } for i := 0; i < n; i++ { cli, err := apps[i].Client(context.Background()) require.NoError(t, err) require.NoError(t, cli.Weight(context.Background(), uint64(n-i))) defer cli.Close() } time.Sleep(18 * time.Second) cli, err := apps[0].Leader(context.Background()) require.NoError(t, err) defer cli.Close() cluster, err := cli.Cluster(context.Background()) require.NoError(t, err) domain := map[int]bool{ 0: false, 1: false, } for i := 0; i < n; i++ { // Recompute each node's failure domain the same way it was assigned at startup. fd := 0 if i > n/2 { fd = 1 } if cluster[i].Role == client.Voter { domain[fd] = true } } // Every domain must have a voter. for _, voters := range domain { assert.True(t, voters) } } // If a voter goes offline, another node takes its place. Preference will be // given to candidates with lower weights. func TestRolesAdjustment_ReplaceVoterHonorWeight(t *testing.T) { n := 6 apps := make([]*app.App, n) cleanups := make([]func(), n) for i := 0; i < n; i++ { addr := fmt.Sprintf("127.0.0.1:900%d", i+1) options := []app.Option{ app.WithAddress(addr), app.WithRolesAdjustmentFrequency(4 * time.Second), } if i > 0 { options = append(options, app.WithCluster([]string{"127.0.0.1:9001"})) } app, cleanup := newApp(t, options...) require.NoError(t, app.Ready(context.Background())) apps[i] = app cleanups[i] = cleanup } defer cleanups[0]() defer cleanups[1]() defer cleanups[3]() defer cleanups[4]() defer cleanups[5]() // A voter goes offline. cleanups[2]() cli, err := apps[3].Client(context.Background()) require.NoError(t, err) require.NoError(t, cli.Weight(context.Background(), uint64(15))) defer cli.Close() cli, err = apps[4].Client(context.Background()) require.NoError(t, err) require.NoError(t, cli.Weight(context.Background(), uint64(5))) defer cli.Close() cli, err = apps[5].Client(context.Background()) require.NoError(t, err) require.NoError(t, cli.Weight(context.Background(), uint64(10))) defer cli.Close() time.Sleep(18 * time.Second) cli, err = apps[0].Leader(context.Background()) require.NoError(t, err) defer cli.Close() cluster, err := cli.Cluster(context.Background()) require.NoError(t, err) // The stand-by with the lowest weight was picked. assert.Equal(t, client.Voter, cluster[0].Role) assert.Equal(t, client.Voter, cluster[1].Role) assert.Equal(t, client.Spare, cluster[2].Role) assert.Equal(t, client.StandBy, cluster[3].Role) assert.Equal(t, client.Voter, cluster[4].Role) assert.Equal(t, client.StandBy, cluster[5].Role) } // If a voter goes offline but no other node can take its place, then nothing // changes.
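// Promotion requires an online stand-by or spare as a candidate; with both of
// them gone, the offline voter simply keeps its role until a candidate comes
// back online.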
func TestRolesAdjustment_CantReplaceVoter(t *testing.T) { n := 4 apps := make([]*app.App, n) cleanups := make([]func(), n) for i := 0; i < n; i++ { addr := fmt.Sprintf("127.0.0.1:900%d", i+1) options := []app.Option{ app.WithAddress(addr), app.WithRolesAdjustmentFrequency(4 * time.Second), } if i > 0 { options = append(options, app.WithCluster([]string{"127.0.0.1:9001"})) } app, cleanup := newApp(t, options...) require.NoError(t, app.Ready(context.Background())) apps[i] = app cleanups[i] = cleanup } defer cleanups[0]() defer cleanups[1]() // A voter and the stand-by go offline. cleanups[3]() cleanups[2]() time.Sleep(12 * time.Second) cli, err := apps[0].Leader(context.Background()) require.NoError(t, err) defer cli.Close() cluster, err := cli.Cluster(context.Background()) require.NoError(t, err) assert.Equal(t, client.Voter, cluster[0].Role) assert.Equal(t, client.Voter, cluster[1].Role) assert.Equal(t, client.Voter, cluster[2].Role) assert.Equal(t, client.StandBy, cluster[3].Role) } // If a stand-by goes offline, another node takes its place. func TestRolesAdjustment_ReplaceStandBy(t *testing.T) { n := 7 apps := make([]*app.App, n) cleanups := make([]func(), n) for i := 0; i < n; i++ { addr := fmt.Sprintf("127.0.0.1:900%d", i+1) options := []app.Option{ app.WithAddress(addr), app.WithRolesAdjustmentFrequency(5 * time.Second), } if i > 0 { options = append(options, app.WithCluster([]string{"127.0.0.1:9001"})) } app, cleanup := newApp(t, options...) require.NoError(t, app.Ready(context.Background())) apps[i] = app cleanups[i] = cleanup } defer cleanups[0]() defer cleanups[1]() defer cleanups[2]() defer cleanups[3]() defer cleanups[5]() defer cleanups[6]() // A stand-by goes offline. cleanups[4]() time.Sleep(20 * time.Second) cli, err := apps[0].Leader(context.Background()) require.NoError(t, err) defer cli.Close() cluster, err := cli.Cluster(context.Background()) require.NoError(t, err) assert.Equal(t, client.Voter, cluster[0].Role) assert.Equal(t, client.Voter, cluster[1].Role) assert.Equal(t, client.Voter, cluster[2].Role) assert.Equal(t, client.StandBy, cluster[3].Role) assert.Equal(t, client.Spare, cluster[4].Role) assert.Equal(t, client.StandBy, cluster[5].Role) assert.Equal(t, client.StandBy, cluster[6].Role) } // If a stand-by goes offline, another node takes its place. If possible, pick // a stand-by from a failure domain which differs from the one of the two other // stand-bys. func TestRolesAdjustment_ReplaceStandByHonorFailureDomains(t *testing.T) { n := 9 apps := make([]*app.App, n) cleanups := make([]func(), n) for i := 0; i < n; i++ { addr := fmt.Sprintf("127.0.0.1:900%d", i+1) options := []app.Option{ app.WithAddress(addr), app.WithRolesAdjustmentFrequency(5 * time.Second), app.WithFailureDomain(uint64(i % 3)), } if i > 0 { options = append(options, app.WithCluster([]string{"127.0.0.1:9001"})) } app, cleanup := newApp(t, options...) require.NoError(t, app.Ready(context.Background())) apps[i] = app cleanups[i] = cleanup } defer cleanups[0]() defer cleanups[1]() defer cleanups[2]() defer cleanups[3]() defer cleanups[5]() defer cleanups[6]() defer cleanups[7]() defer cleanups[8]() // A stand-by from failure domain 1 goes offline. cleanups[4]() time.Sleep(20 * time.Second) cli, err := apps[0].Leader(context.Background()) require.NoError(t, err) defer cli.Close() cluster, err := cli.Cluster(context.Background()) require.NoError(t, err) // The replacement was picked from the same failure domain as the offline stand-by.
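// With failure domains assigned as i%3, node 4 (domain 1) was the offline
// stand-by and node 7 is the only remaining spare in domain 1, so node 7 is
// the expected promotion target below.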
assert.Equal(t, client.Voter, cluster[0].Role) assert.Equal(t, client.Voter, cluster[1].Role) assert.Equal(t, client.Voter, cluster[2].Role) assert.Equal(t, client.StandBy, cluster[3].Role) assert.Equal(t, client.Spare, cluster[4].Role) assert.Equal(t, client.StandBy, cluster[5].Role) assert.Equal(t, client.Spare, cluster[6].Role) assert.Equal(t, client.StandBy, cluster[7].Role) assert.Equal(t, client.Spare, cluster[8].Role) } // Open a database on a fresh one-node cluster. func TestOpen(t *testing.T) { app, cleanup := newApp(t, app.WithAddress("127.0.0.1:9000")) defer cleanup() db, err := app.Open(context.Background(), "test") require.NoError(t, err) defer db.Close() _, err = db.ExecContext(context.Background(), "CREATE TABLE foo(n INT)") assert.NoError(t, err) } // Open a database with disk-mode on a fresh one-node cluster. func TestOpenDisk(t *testing.T) { app, cleanup := newApp(t, app.WithAddress("127.0.0.1:9000"), app.WithDiskMode(true)) defer cleanup() db, err := app.Open(context.Background(), "test") require.NoError(t, err) defer db.Close() _, err = db.ExecContext(context.Background(), "CREATE TABLE foo(n INT)") assert.NoError(t, err) } // Test some setup options func TestOptions(t *testing.T) { options := []app.Option{ app.WithAddress("127.0.0.1:9000"), app.WithNetworkLatency(20 * time.Millisecond), app.WithSnapshotParams(dqlite.SnapshotParams{Threshold: 1024, Trailing: 1024}), app.WithTracing(client.LogDebug), } app, cleanup := newApp(t, options...) defer cleanup() require.NotNil(t, app) } // Test client connections dropping uncleanly. func TestProxy_Error(t *testing.T) { cert, pool := loadCert(t) dial := client.DialFuncWithTLS(client.DefaultDialFunc, app.SimpleDialTLSConfig(cert, pool)) _, cleanup := newApp(t, app.WithAddress("127.0.0.1:9000")) defer cleanup() // Simulate a client which writes the protocol header, then a Leader // request and finally drops before reading the response. conn, err := dial(context.Background(), "127.0.0.1:9000") require.NoError(t, err) protocol := make([]byte, 8) binary.LittleEndian.PutUint64(protocol, uint64(1)) n, err := conn.Write(protocol) require.NoError(t, err) assert.Equal(t, n, 8) header := make([]byte, 8) binary.LittleEndian.PutUint32(header[0:], 1) header[4] = 0 header[5] = 0 binary.LittleEndian.PutUint16(header[6:], 0) n, err = conn.Write(header) require.NoError(t, err) assert.Equal(t, n, 8) body := make([]byte, 8) n, err = conn.Write(body) require.NoError(t, err) assert.Equal(t, n, 8) time.Sleep(100 * time.Millisecond) conn.Close() time.Sleep(250 * time.Millisecond) } // If the given context is cancelled before initial tasks are completed, an // error is returned. func TestReady_Cancel(t *testing.T) { app, cleanup := newApp(t, app.WithAddress("127.0.0.1:9002"), app.WithCluster([]string{"127.0.0.1:9001"})) defer cleanup() ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) defer cancel() err := app.Ready(ctx) assert.Equal(t, ctx.Err(), err) } func newApp(t *testing.T, options ...app.Option) (*app.App, func()) { t.Helper() dir, dirCleanup := newDir(t) app, appCleanup := newAppWithDir(t, dir, options...) cleanup := func() { appCleanup() dirCleanup() } return app, cleanup } // TestExternalConn creates a 3-member cluster using external http connections // and ensures the cluster is successfully created, and that the connection is // handled manually. 
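//
// The general wiring for external connections (a sketch using hypothetical
// names) is: provide a client.DialFunc that establishes the outgoing
// transport yourself, plus a channel on which every accepted incoming
// net.Conn is delivered to dqlite:
//
//	acceptCh := make(chan net.Conn)
//	a, err := app.New(dir, app.WithAddress(addr),
//		app.WithExternalConn(myDialFunc, acceptCh))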
func TestExternalConnWithTCP(t *testing.T) { externalAddr1 := "127.0.0.1:9191" externalAddr2 := "127.0.0.1:9292" externalAddr3 := "127.0.0.1:9393" acceptCh1 := make(chan net.Conn) acceptCh2 := make(chan net.Conn) acceptCh3 := make(chan net.Conn) hijackStatus := "101 Switching Protocols" dialFunc := func(ctx context.Context, addr string) (net.Conn, error) { conn, err := net.Dial("tcp", addr) require.NoError(t, err) request := &http.Request{} request.URL, err = url.Parse("http://" + addr) require.NoError(t, err) require.NoError(t, request.Write(conn)) resp, err := http.ReadResponse(bufio.NewReader(conn), request) require.NoError(t, err) require.Equal(t, hijackStatus, resp.Status) return conn, nil } newHandler := func(acceptCh chan net.Conn) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { hijacker, ok := w.(http.Hijacker) require.True(t, ok) conn, _, err := hijacker.Hijack() require.NoError(t, err) acceptCh <- conn } } // Start up three listeners. go http.ListenAndServe(externalAddr1, newHandler(acceptCh1)) go http.ListenAndServe(externalAddr2, newHandler(acceptCh2)) go http.ListenAndServe(externalAddr3, newHandler(acceptCh3)) app1, cleanup := newAppWithNoTLS(t, app.WithAddress(externalAddr1), app.WithExternalConn(dialFunc, acceptCh1)) defer cleanup() app2, cleanup := newAppWithNoTLS(t, app.WithAddress(externalAddr2), app.WithExternalConn(dialFunc, acceptCh2), app.WithCluster([]string{externalAddr1})) defer cleanup() require.NoError(t, app2.Ready(context.Background())) app3, cleanup := newAppWithNoTLS(t, app.WithAddress(externalAddr3), app.WithExternalConn(dialFunc, acceptCh3), app.WithCluster([]string{externalAddr1})) defer cleanup() require.NoError(t, app3.Ready(context.Background())) // Get a client from the first node (likely the leader). cli, err := app1.Leader(context.Background()) require.NoError(t, err) defer cli.Close() // Ensure entries exist for each cluster member. cluster, err := cli.Cluster(context.Background()) require.NoError(t, err) assert.Equal(t, externalAddr1, cluster[0].Address) assert.Equal(t, externalAddr2, cluster[1].Address) assert.Equal(t, externalAddr3, cluster[2].Address) // Every cluster member should be a voter. assert.Equal(t, client.Voter, cluster[0].Role) assert.Equal(t, client.Voter, cluster[1].Role) assert.Equal(t, client.Voter, cluster[2].Role) } // TestExternalPipe creates a 3-member cluster using net.Pipe // and ensures the cluster is successfully created, and that the connection is // handled manually. 
func TestExternalConnWithPipe(t *testing.T) { externalAddr1 := "first" externalAddr2 := "second" externalAddr3 := "third" acceptCh1 := make(chan net.Conn) acceptCh2 := make(chan net.Conn) acceptCh3 := make(chan net.Conn) dialChannels := map[string]chan net.Conn{ externalAddr1: acceptCh1, externalAddr2: acceptCh2, externalAddr3: acceptCh3, } dialFunc := func(_ context.Context, addr string) (net.Conn, error) { client, server := net.Pipe() dialChannels[addr] <- server return client, nil } app1, cleanup := newAppWithNoTLS(t, app.WithAddress(externalAddr1), app.WithExternalConn(dialFunc, acceptCh1)) defer cleanup() app2, cleanup := newAppWithNoTLS(t, app.WithAddress(externalAddr2), app.WithExternalConn(dialFunc, acceptCh2), app.WithCluster([]string{externalAddr1})) defer cleanup() require.NoError(t, app2.Ready(context.Background())) app3, cleanup := newAppWithNoTLS(t, app.WithAddress(externalAddr3), app.WithExternalConn(dialFunc, acceptCh3), app.WithCluster([]string{externalAddr1})) defer cleanup() require.NoError(t, app3.Ready(context.Background())) // Get a client from the first node (likely the leader). cli, err := app1.Leader(context.Background()) require.NoError(t, err) defer cli.Close() // Ensure entries exist for each cluster member. cluster, err := cli.Cluster(context.Background()) require.NoError(t, err) assert.Equal(t, externalAddr1, cluster[0].Address) assert.Equal(t, externalAddr2, cluster[1].Address) assert.Equal(t, externalAddr3, cluster[2].Address) // Every cluster member should be a voter. assert.Equal(t, client.Voter, cluster[0].Role) assert.Equal(t, client.Voter, cluster[1].Role) assert.Equal(t, client.Voter, cluster[2].Role) } func TestParallelNewApp(t *testing.T) { t.Parallel() for i := 0; i < 100; i++ { i := i t.Run(fmt.Sprintf("run-%d", i), func(tt *testing.T) { tt.Parallel() // TODO: switch this to tt.TempDir when we switch to tmpDir := filepath.Join(os.TempDir(), strings.ReplaceAll(tt.Name(), "/", "-")) require.NoError(tt, os.MkdirAll(tmpDir, 0700)) dqApp, err := app.New(tmpDir, app.WithAddress(fmt.Sprintf("127.0.0.1:%d", 10200+i)), ) require.NoError(tt, err) defer func() { _ = dqApp.Close() _ = os.RemoveAll(tmpDir) }() }) } } func newAppWithDir(t *testing.T, dir string, options ...app.Option) (*app.App, func()) { t.Helper() appIndex++ index := appIndex log := func(l client.LogLevel, format string, a ...interface{}) { format = fmt.Sprintf("%s - %d: %s: %s", time.Now().Format("15:04:05.000"), index, l.String(), format) t.Logf(format, a...) } cert, pool := loadCert(t) options = append(options, app.WithLogFunc(log), app.WithTLS(app.SimpleTLSConfig(cert, pool))) app, err := app.New(dir, options...) require.NoError(t, err) cleanup := func() { require.NoError(t, app.Close()) } return app, cleanup } func newAppWithNoTLS(t *testing.T, options ...app.Option) (*app.App, func()) { t.Helper() dir, dirCleanup := newDir(t) appIndex++ index := appIndex log := func(l client.LogLevel, format string, a ...interface{}) { format = fmt.Sprintf("%s - %d: %s: %s", time.Now().Format("15:04:05.000"), index, l.String(), format) t.Logf(format, a...) } options = append(options, app.WithLogFunc(log)) app, err := app.New(dir, options...) require.NoError(t, err) cleanup := func() { require.NoError(t, app.Close()) dirCleanup() } return app, cleanup } // Loads the test TLS certificates.
func loadCert(t *testing.T) (tls.Certificate, *x509.CertPool) { t.Helper() crt := filepath.Join("testdata", "cluster.crt") key := filepath.Join("testdata", "cluster.key") keypair, err := tls.LoadX509KeyPair(crt, key) require.NoError(t, err) data, err := ioutil.ReadFile(crt) require.NoError(t, err) pool := x509.NewCertPool() if !pool.AppendCertsFromPEM(data) { t.Fatal("bad certificate") } return keypair, pool } var appIndex int // Return a new temporary directory. func newDir(t *testing.T) (string, func()) { t.Helper() dir, err := ioutil.TempDir("", "dqlite-app-test-") assert.NoError(t, err) cleanup := func() { os.RemoveAll(dir) } return dir, cleanup } func Test_TxRowsAffected(t *testing.T) { app, cleanup := newAppWithNoTLS(t, app.WithAddress("127.0.0.1:9001")) defer cleanup() err := app.Ready(context.Background()) require.NoError(t, err) db, err := app.Open(context.Background(), "test") require.NoError(t, err) defer db.Close() _, err = db.ExecContext(context.Background(), ` CREATE TABLE test ( id TEXT PRIMARY KEY, value INT );`) require.NoError(t, err) // Insert watermark err = tx(context.Background(), db, func(ctx context.Context, tx *sql.Tx) error { query := ` INSERT INTO test (id, value) VALUES ('id0', -1); ` result, err := tx.ExecContext(ctx, query) if err != nil { return err } _, err = result.RowsAffected() if err != nil { return err } return nil }) require.NoError(t, err) // Update watermark err = tx(context.Background(), db, func(ctx context.Context, tx *sql.Tx) error { query := ` UPDATE test SET value = 1 WHERE id = 'id0'; ` result, err := tx.ExecContext(ctx, query) if err != nil { return err } affected, err := result.RowsAffected() if err != nil { return err } if affected != 1 { return fmt.Errorf("expected 1 row affected, got %d", affected) } return nil }) require.NoError(t, err) } func tx(ctx context.Context, db *sql.DB, fn func(context.Context, *sql.Tx) error) error { tx, err := db.BeginTx(ctx, nil) if err != nil { return err } if err := fn(ctx, tx); err != nil { _ = tx.Rollback() return err } return tx.Commit() } golang-github-canonical-go-dqlite-2.0.0/app/dial.go000066400000000000000000000027451471100661000220740ustar00rootroot00000000000000package app import ( "context" "crypto/tls" "fmt" "net" "github.com/canonical/go-dqlite/v2/client" ) // Like client.DialFuncWithTLS but also starts the proxy, since the raft // connect function only supports Unix and TCP connections. func makeNodeDialFunc(appCtx context.Context, config *tls.Config) client.DialFunc { dial := func(ctx context.Context, addr string) (net.Conn, error) { clonedConfig := config.Clone() if len(clonedConfig.ServerName) == 0 { remoteIP, _, err := net.SplitHostPort(addr) if err != nil { return nil, err } clonedConfig.ServerName = remoteIP } dialer := &net.Dialer{} conn, err := dialer.DialContext(ctx, "tcp", addr) if err != nil { return nil, err } goUnix, cUnix, err := socketpair() if err != nil { return nil, fmt.Errorf("create pair of Unix sockets: %w", err) } go proxy(appCtx, conn, goUnix, clonedConfig) return cUnix, nil } return dial } // extDialFuncWithProxy executes given DialFunc and then copies the data back // and forth between the remote connection and a local unix socket. 
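//
// The indirection exists because, as noted for makeNodeDialFunc above, the
// raft connect function only supports Unix and TCP connections: arbitrary
// external connections (HTTP hijacked sockets, net.Pipe, ...) are therefore
// bridged through a Unix socketpair, with the proxy goroutine pumping bytes
// between the two ends.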
func extDialFuncWithProxy(appCtx context.Context, dialFunc client.DialFunc) client.DialFunc { return func(ctx context.Context, addr string) (net.Conn, error) { goUnix, cUnix, err := socketpair() if err != nil { return nil, fmt.Errorf("create pair of Unix sockets: %w", err) } conn, err := dialFunc(ctx, addr) if err != nil { return nil, err } go proxy(appCtx, conn, goUnix, nil) return cUnix, nil } } golang-github-canonical-go-dqlite-2.0.0/app/example_test.go000066400000000000000000000052421471100661000236500ustar00rootroot00000000000000package app_test import ( "fmt" "io/ioutil" "os" "github.com/canonical/go-dqlite/v2/app" ) // To start the first node of a dqlite cluster for the first time, its network // address should be specified using the app.WithAddress() option. // // When the node is restarted a second time, the app.WithAddress() option might // be omitted, since the node address will be persisted in the info.yaml file. // // The very first node always has the same ID (dqlite.BootstrapID). func Example() { dir, err := ioutil.TempDir("", "dqlite-app-example-") if err != nil { return } defer os.RemoveAll(dir) node, err := app.New(dir, app.WithAddress("127.0.0.1:9001")) if err != nil { return } fmt.Printf("0x%x %s\n", node.ID(), node.Address()) if err := node.Close(); err != nil { return } node, err = app.New(dir) if err != nil { return } defer node.Close() fmt.Printf("0x%x %s\n", node.ID(), node.Address()) // Output: 0x2dc171858c3155be 127.0.0.1:9001 // 0x2dc171858c3155be 127.0.0.1:9001 } // After starting the very first node, a second node can be started by passing // the address of the first node using the app.WithCluster() option. // // In general additional nodes can be started by specifying one or more // addresses of existing nodes using the app.WithCluster() option. // // When the node is restarted a second time, the app.WithCluster() option might // be omitted, since the node has already joined the cluster. // // Each additional node will be automatically assigned a unique ID. func ExampleWithCluster() { dir1, err := ioutil.TempDir("", "dqlite-app-example-") if err != nil { return } defer os.RemoveAll(dir1) dir2, err := ioutil.TempDir("", "dqlite-app-example-") if err != nil { return } defer os.RemoveAll(dir2) dir3, err := ioutil.TempDir("", "dqlite-app-example-") if err != nil { return } defer os.RemoveAll(dir3) node1, err := app.New(dir1, app.WithAddress("127.0.0.1:9001")) if err != nil { return } defer node1.Close() node2, err := app.New(dir2, app.WithAddress("127.0.0.1:9002"), app.WithCluster([]string{"127.0.0.1:9001"})) if err != nil { return } defer node2.Close() node3, err := app.New(dir3, app.WithAddress("127.0.0.1:9003"), app.WithCluster([]string{"127.0.0.1:9001"})) if err != nil { return } fmt.Println(node1.ID() != node2.ID(), node1.ID() != node3.ID(), node2.ID() != node3.ID()) // true true true // Restart the third node, the only argument we need to pass to // app.New() is its dir. id3 := node3.ID() if err := node3.Close(); err != nil { return } node3, err = app.New(dir3) if err != nil { return } defer node3.Close() fmt.Println(node3.ID() == id3, node3.Address()) // true 127.0.0.1:9003 } golang-github-canonical-go-dqlite-2.0.0/app/files.go000066400000000000000000000035361471100661000222640ustar00rootroot00000000000000package app import ( "fmt" "io/ioutil" "os" "path/filepath" "gopkg.in/yaml.v2" "github.com/google/renameio" ) const ( // Store the node ID and address. infoFile = "info.yaml" // The node store file.
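// It caches the addresses of the known cluster members, which is what
// allows a restarted node to find the leader again without being given
// the WithCluster option a second time.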
storeFile = "cluster.yaml" // This is a "flag" file to signal when a brand new node needs to join // the cluster. If the node doesn't successfully join the cluster the first // time it's started, it will retry the next time. joinFile = "join" ) // Return true if the given file exists in the given directory. func fileExists(dir, file string) (bool, error) { path := filepath.Join(dir, file) if _, err := os.Stat(path); err != nil { if !os.IsNotExist(err) { return false, fmt.Errorf("check if %s exists: %w", file, err) } return false, nil } return true, nil } // Write a file in the given directory. func fileWrite(dir, file string, data []byte) error { path := filepath.Join(dir, file) if err := renameio.WriteFile(path, data, 0600); err != nil { return fmt.Errorf("write %s: %w", file, err) } return nil } // Marshal the given object as YAML into the given file. func fileMarshal(dir, file string, object interface{}) error { data, err := yaml.Marshal(object) if err != nil { return fmt.Errorf("marshal %s: %w", file, err) } if err := fileWrite(dir, file, data); err != nil { return err } return nil } // Unmarshal the given YAML file into the given object. func fileUnmarshal(dir, file string, object interface{}) error { path := filepath.Join(dir, file) data, err := ioutil.ReadFile(path) if err != nil { return fmt.Errorf("read %s: %w", file, err) } if err := yaml.Unmarshal(data, object); err != nil { return fmt.Errorf("unmarshal %s: %w", file, err) } return nil } // Remove a file in the given directory. func fileRemove(dir, file string) error { return os.Remove(filepath.Join(dir, file)) } golang-github-canonical-go-dqlite-2.0.0/app/options.go000066400000000000000000000234451471100661000226550ustar00rootroot00000000000000package app import ( "crypto/tls" "fmt" "log" "net" "strings" "time" "github.com/canonical/go-dqlite/v2" "github.com/canonical/go-dqlite/v2/client" "github.com/canonical/go-dqlite/v2/internal/protocol" ) // Option can be used to tweak app parameters. type Option func(*options) // WithAddress sets the network address of the application node. // // Other application nodes must be able to connect to this application node // using the given address. // // If the application node is not the first one in the cluster, the address // must match the value that was passed to the App.Add() method upon // registration. // // If not given, the first non-loopback IP address of any of the system network // interfaces will be used, with port 9000. // // The address must be stable across application restarts. func WithAddress(address string) Option { return func(options *options) { options.Address = address } } // WithCluster must be used when starting a newly added application node for // the first time. // // It should contain the addresses of one or more application nodes which are // already part of the cluster. func WithCluster(cluster []string) Option { return func(options *options) { options.Cluster = cluster } } // WithExternalConn enables passing an external dial function that will be used // whenever dqlite needs to make an outside connection. // // Also takes a net.Conn channel on which accepted incoming connections should // be delivered to the node. func WithExternalConn(dialFunc client.DialFunc, acceptCh chan net.Conn) Option { return func(options *options) { options.Conn = &connSetup{ dialFunc: dialFunc, acceptCh: acceptCh, } } } // WithTLS enables TLS encryption of network traffic.
// // The "listen" parameter must hold the TLS configuration to use when accepting // incoming connections from clients or application nodes. // // The "dial" parameter must hold the TLS configuration to use when // establishing outgoing connections to other application nodes. func WithTLS(listen *tls.Config, dial *tls.Config) Option { return func(options *options) { options.TLS = &tlsSetup{ Listen: listen, Dial: dial, } } } // WithUnixSocket allows setting a specific socket path for communication between go-dqlite and dqlite. // // The default is an empty string which means a random abstract unix socket. func WithUnixSocket(path string) Option { return func(options *options) { options.UnixSocket = path } } // WithVoters sets the number of nodes in the cluster that should have the // Voter role. // // When a new node is added to the cluster or it is started again after a // shutdown it will be assigned the Voter role in case the current number of // voters is below n. // // Similarly when a node with the Voter role is shutdown gracefully by calling // the Handover() method, it will try to transfer its Voter role to another // non-Voter node, if one is available. // // All App instances in a cluster must be created with the same WithVoters // setting. // // The given value must be an odd number greater than one. // // The default value is 3. func WithVoters(n int) Option { return func(options *options) { options.Voters = n } } // WithStandBys sets the number of nodes in the cluster that should have the // StandBy role. // // When a new node is added to the cluster or it is started again after a // shutdown it will be assigned the StandBy role in case there are already // enough online voters, but the current number of stand-bys is below n. // // Similarly when a node with the StandBy role is shutdown gracefully by // calling the Handover() method, it will try to transfer its StandBy role to // another non-StandBy node, if one is available. // // All App instances in a cluster must be created with the same WithStandBys // setting. // // The default value is 3. func WithStandBys(n int) Option { return func(options *options) { options.StandBys = n } } // WithRolesAdjustmentFrequency sets the frequency at which the current cluster // leader will check if the roles of the various nodes in the cluster match // the desired setup and perform promotions/demotions to adjust the situation // if needed. // // The default is 30 seconds. func WithRolesAdjustmentFrequency(frequency time.Duration) Option { return func(options *options) { options.RolesAdjustmentFrequency = frequency } } // WithRolesAdjustmentHook sets a hook that will be run each time the roles are // adjusted, as controlled by WithRolesAdjustmentFrequency. It provides the // current raft leader information as well as the most up-to-date list of // cluster members and their roles. func WithRolesAdjustmentHook(hook func(leader client.NodeInfo, cluster []client.NodeInfo) error) Option { return func(o *options) { o.OnRolesAdjustment = hook } } // WithLogFunc sets a custom log function. func WithLogFunc(log client.LogFunc) Option { return func(options *options) { options.Log = log } } // WithTracing will emit a log message at the given level every time a // statement gets executed. func WithTracing(level client.LogLevel) Option { return func(options *options) { options.Tracing = level } } // WithFailureDomain sets the node's failure domain. // // Failure domains are taken into account when deciding which nodes to promote // to Voter or StandBy when needed.
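//
// A typical assignment maps one domain code to each rack or availability
// zone, for example (a sketch with made-up rack numbers):
//
//	a, err := app.New(dir, app.WithAddress("10.0.1.5:9000"),
//		app.WithFailureDomain(1)) // this node lives in rack 1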
func WithFailureDomain(code uint64) Option { return func(options *options) { options.FailureDomain = code } } // WithNetworkLatency sets the average one-way network latency. func WithNetworkLatency(latency time.Duration) Option { return func(options *options) { options.NetworkLatency = latency } } // WithConcurrentLeaderConns sets the maximum number of concurrent connections // to other cluster members that will be attempted while searching for the dqlite leader. // It takes a pointer to an integer so that the value can be dynamically modified based on cluster health. // // The default is 10 connections to other cluster members. func WithConcurrentLeaderConns(maxConns *int64) Option { return func(o *options) { o.ConcurrentLeaderConns = maxConns } } // WithSnapshotParams sets the raft snapshot parameters. func WithSnapshotParams(params dqlite.SnapshotParams) Option { return func(options *options) { options.SnapshotParams = params } } // WithDiskMode enables or disables disk-mode. // WARNING: This is an experimental API, use with caution // and prepare for data loss. // UNSTABLE: Behavior can change in the future. // NOT RECOMMENDED for production use-cases, use at your own risk. func WithDiskMode(disk bool) Option { return func(options *options) { options.DiskMode = disk } } // WithAutoRecovery enables or disables auto-recovery of persisted data // at startup for this node. // // When auto-recovery is enabled, raft snapshots and segment files may be // deleted at startup if they are determined to be corrupt. This helps // the startup process to succeed in more cases, but can lead to data loss. // // Auto-recovery is enabled by default. func WithAutoRecovery(recovery bool) Option { return func(options *options) { options.AutoRecovery = recovery } } type tlsSetup struct { Listen *tls.Config Dial *tls.Config } type connSetup struct { dialFunc client.DialFunc acceptCh chan net.Conn } type options struct { Address string Cluster []string Log client.LogFunc Tracing client.LogLevel TLS *tlsSetup Conn *connSetup Voters int StandBys int RolesAdjustmentFrequency time.Duration OnRolesAdjustment func(client.NodeInfo, []client.NodeInfo) error FailureDomain uint64 NetworkLatency time.Duration ConcurrentLeaderConns *int64 UnixSocket string SnapshotParams dqlite.SnapshotParams DiskMode bool AutoRecovery bool } // Create an options object with sane defaults. func defaultOptions() *options { maxConns := protocol.MaxConcurrentLeaderConns return &options{ Log: defaultLogFunc, Tracing: client.LogNone, Voters: 3, StandBys: 3, RolesAdjustmentFrequency: 30 * time.Second, OnRolesAdjustment: func(client.NodeInfo, []client.NodeInfo) error { return nil }, DiskMode: false, // Be explicit about not enabling disk-mode by default.
AutoRecovery: true, ConcurrentLeaderConns: &maxConns, } } func isLoopback(iface *net.Interface) bool { return int(iface.Flags&net.FlagLoopback) > 0 } // see https://stackoverflow.com/a/48519490/3613657 // Valid IPv4 notations: // // "192.168.0.1": basic // "192.168.0.1:80": with port info // // Valid IPv6 notations: // // "::FFFF:C0A8:1": basic // "::FFFF:C0A8:0001": leading zeros // "0000:0000:0000:0000:0000:FFFF:C0A8:1": double colon expanded // "::FFFF:C0A8:1%1": with zone info // "::FFFF:192.168.0.1": IPv4 literal // "[::FFFF:C0A8:1]:80": with port info // "[::FFFF:C0A8:1%1]:80": with zone and port info func isIpV4(ip string) bool { return strings.Count(ip, ":") < 2 } func defaultAddress() (addr string, err error) { ifaces, err := net.Interfaces() if err != nil { return "", err } for _, iface := range ifaces { if isLoopback(&iface) { continue } addrs, err := iface.Addrs() if err != nil { continue } if len(addrs) == 0 { continue } addr, ok := addrs[0].(*net.IPNet) if !ok { continue } ipStr := addr.IP.String() if isIpV4(ipStr) { return addr.IP.String() + ":9000", nil } else { return "[" + addr.IP.String() + "]" + ":9000", nil } } return "", fmt.Errorf("no suitable net.Interface found: %v", err) } func defaultLogFunc(l client.LogLevel, format string, a ...interface{}) { // Log only error messages if l != client.LogError { return } msg := fmt.Sprintf("["+l.String()+"]"+" dqlite: "+format, a...) log.Printf("%s", msg) } golang-github-canonical-go-dqlite-2.0.0/app/proxy.go000066400000000000000000000130201471100661000223300ustar00rootroot00000000000000package app import ( "context" "crypto/tls" "fmt" "io" "net" "os" "reflect" "syscall" "time" "unsafe" "golang.org/x/sys/unix" ) // Copies data between a remote TCP network connection (possibly with TLS) and // a local unix socket. // // The function will return if one of the following events occurs: // // - the other end of the remote network socket closes the connection // - the other end of the local unix socket closes the connection // - the context is cancelled // - an error occurs when writing or reading data // // In case of errors, details are returned. func proxy(ctx context.Context, remote net.Conn, local net.Conn, config *tls.Config) error { tcp, err := tryExtractTCPConn(remote) if err == nil { if err := setKeepalive(tcp); err != nil { return err } } if config != nil { if config.ClientCAs != nil { remote = tls.Server(remote, config) } else { remote = tls.Client(remote, config) } } remoteToLocal := make(chan error) localToRemote := make(chan error) // Start copying data back and forth until either the client or the // server get closed or hit an error. go func() { _, err := io.Copy(local, remote) remoteToLocal <- err }() go func() { _, err := io.Copy(remote, local) localToRemote <- err }() errs := make([]error, 2) select { case <-ctx.Done(): // Force closing, ignore errors. 
remote.Close() local.Close() <-remoteToLocal <-localToRemote case err := <-remoteToLocal: if err != nil { errs[0] = fmt.Errorf("remote -> local: %v", err) } local.(*net.UnixConn).CloseRead() if err := <-localToRemote; err != nil { errs[1] = fmt.Errorf("local -> remote: %v", err) } remote.Close() local.Close() case err := <-localToRemote: if err != nil { errs[0] = fmt.Errorf("local -> remote: %v", err) } if tcp != nil { tcp.CloseRead() } if err := <-remoteToLocal; err != nil { errs[1] = fmt.Errorf("remote -> local: %v", err) } remote.Close() local.Close() } if errs[0] != nil || errs[1] != nil { return proxyError{first: errs[0], second: errs[1]} } return nil } // tryExtractTCPConn tries to extract the underlying net.TCPConn, potentially from a tls.Conn. func tryExtractTCPConn(conn net.Conn) (*net.TCPConn, error) { tcp, ok := conn.(*net.TCPConn) if ok { return tcp, nil } // Go doesn't currently expose the underlying TCP connection of a TLS connection, but we need it in order // to set timeout properties on the connection. We use some reflect/unsafe magic to extract the private // remote.conn field, which is indeed the underlying TCP connection. tlsConn, ok := conn.(*tls.Conn) if !ok { return nil, fmt.Errorf("connection is not a tls.Conn") } field := reflect.ValueOf(tlsConn).Elem().FieldByName("conn") field = reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem() c := field.Interface() tcpConn, ok := c.(*net.TCPConn) if !ok { return nil, fmt.Errorf("connection is not a net.TCPConn") } return tcpConn, nil } // Set TCP_USER_TIMEOUT and TCP keepalive with 3 seconds idle time, 3 seconds // retry interval with at most 3 retries. // // See https://thenotexpert.com/golang-tcp-keepalive/. func setKeepalive(conn *net.TCPConn) error { err := conn.SetKeepAlive(true) if err != nil { return err } err = conn.SetKeepAlivePeriod(time.Second * 3) if err != nil { return err } raw, err := conn.SyscallConn() if err != nil { return err } raw.Control( func(ptr uintptr) { fd := int(ptr) // Number of probes. err = syscall.SetsockoptInt(fd, syscall.IPPROTO_TCP, _TCP_KEEPCNT, 3) if err != nil { return } // Wait time after an unsuccessful probe. err = syscall.SetsockoptInt(fd, syscall.IPPROTO_TCP, _TCP_KEEPINTVL, 3) if err != nil { return } // Set TCP_USER_TIMEOUT option to limit the maximum amount of time in ms that transmitted data may remain // unacknowledged before TCP will forcefully close the corresponding connection and return ETIMEDOUT to the // application. This combined with the TCP keepalive options on the socket will ensure that should the // remote side of the connection disappear abruptly that dqlite will detect this and close the socket quickly. // Decreasing the user timeouts allows applications to "fail fast" if so desired. Otherwise it may take // up to 20 minutes with the current system defaults in a normal WAN environment if there are packets in // the send queue that will prevent the keepalive timer from working as the retransmission timers kick in. // See https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/commit/?id=dca43c75e7e545694a9dd6288553f55c53e2a3a3 err = syscall.SetsockoptInt(fd, syscall.IPPROTO_TCP, unix.TCP_USER_TIMEOUT, int(30*time.Microsecond)) if err != nil { return } }) return err } // Returns a pair of connected unix sockets. 
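//
// Both file descriptors produced by socketpair(2) are wrapped into net.Conn
// values via os.NewFile and net.FileConn (see fdToFileConn below), so that
// one end can be returned to the caller as an ordinary connection while the
// proxy above drives the other end.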
func socketpair() (net.Conn, net.Conn, error) { fds, err := syscall.Socketpair(syscall.AF_LOCAL, syscall.SOCK_STREAM, 0) if err != nil { return nil, nil, err } c1, err := fdToFileConn(fds[0]) if err != nil { return nil, nil, err } c2, err := fdToFileConn(fds[1]) if err != nil { c1.Close() return nil, nil, err } return c1, c2, err } func fdToFileConn(fd int) (net.Conn, error) { f := os.NewFile(uintptr(fd), "") defer f.Close() return net.FileConn(f) } type proxyError struct { first error second error } func (e proxyError) Error() string { msg := "" if e.first != nil { msg += "first: " + e.first.Error() } if e.second != nil { if e.first != nil { msg += " " } msg += "second: " + e.second.Error() } return msg } golang-github-canonical-go-dqlite-2.0.0/app/proxy_darwin.go000066400000000000000000000003071471100661000237000ustar00rootroot00000000000000// +build darwin package app // from netinet/tcp.h (OS X 10.9.4) const ( _TCP_KEEPINTVL = 0x101 /* interval between keepalives */ _TCP_KEEPCNT = 0x102 /* number of keepalives before close */ ) golang-github-canonical-go-dqlite-2.0.0/app/proxy_linux.go000066400000000000000000000003311471100661000235500ustar00rootroot00000000000000// +build linux package app import ( "syscall" ) const ( _TCP_KEEPINTVL = syscall.TCP_KEEPINTVL /* interval between keepalives */ _TCP_KEEPCNT = syscall.TCP_KEEPCNT /* number of keepalives before close */ ) golang-github-canonical-go-dqlite-2.0.0/app/roles.go000066400000000000000000000264551471100661000223130ustar00rootroot00000000000000package app import ( "sort" "github.com/canonical/go-dqlite/v2/client" ) const minVoters = 3 // RolesConfig can be used to tweak the algorithm implemented by RolesChanges. type RolesConfig struct { Voters int // Target number of voters, 3 by default. StandBys int // Target number of stand-bys, 3 by default. } // RolesChanges implements an algorithm to take decisions about which node // should have which role in a cluster. // // You normally don't need to use this data structure since it's already // transparently wired into the high-level App object. However this is exposed // for users who don't want to use the high-level App object but still want to // implement the same roles management algorithm. type RolesChanges struct { // Algorithm configuration. Config RolesConfig // Current state of the cluster. Each node in the cluster must be // present as a key in the map, and its value should be its associated // failure domain and weight metadata or nil if the node is currently // offline. State map[client.NodeInfo]*client.NodeMetadata } // Assume decides if a node should assume a different role than the one it // currently has. It should normally be run at node startup, where the // algorithm might decide that the node should assume the Voter or Stand-By // role in case there's a shortage of them. // // Return -1 in case no role change is needed. func (c *RolesChanges) Assume(id uint64) client.NodeRole { // If the cluster is still too small, do nothing. if c.size() < minVoters { return -1 } node := c.get(id) // If we are not in the cluster, it means we were removed, just do nothing. if node == nil { return -1 } // If we already have the Voter or StandBy role, there's nothing to do. if node.Role == client.Voter || node.Role == client.StandBy { return -1 } onlineVoters := c.list(client.Voter, true, nil) onlineStandbys := c.list(client.StandBy, true, nil) // If we have already the desired number of online voters and // stand-bys, there's nothing to do. 
if len(onlineVoters) >= c.Config.Voters && len(onlineStandbys) >= c.Config.StandBys { return -1 } // Figure out whether we need to become stand-by or voter. role := client.StandBy if len(onlineVoters) < c.Config.Voters { role = client.Voter } return role } // Handover decides if a node should transfer its current role to another // node. This is typically run when the node is shutting down and is hence going to be offline soon. // // Return the role that should be handed over and a list of candidates that // should receive it, in order of preference. func (c *RolesChanges) Handover(id uint64) (client.NodeRole, []client.NodeInfo) { node := c.get(id) // If we are not in the cluster, it means we were removed, just do nothing. if node == nil { return -1, nil } // If we aren't a voter or a stand-by, there's nothing to do. if node.Role != client.Voter && node.Role != client.StandBy { return -1, nil } // Make a list of all online nodes with the same role and get their // failure domains. peers := c.list(node.Role, true, nil) for i := range peers { if peers[i].ID == node.ID { peers = append(peers[:i], peers[i+1:]...) break } } domains := c.failureDomains(peers) // Online spare nodes are always candidates. candidates := c.list(client.Spare, true, nil) // Stand-by nodes are candidates if we need to transfer voting // rights, and they are preferred over spares. if node.Role == client.Voter { candidates = append(c.list(client.StandBy, true, nil), candidates...) } if len(candidates) == 0 { // No online node available to be promoted. return -1, nil } c.sortCandidates(candidates, domains) return node.Role, candidates } // Adjust decides if there should be changes in the current roles. // // Return the role that should be assigned and a list of candidates that should // assume it, in order of preference. func (c *RolesChanges) Adjust(leader uint64) (client.NodeRole, []client.NodeInfo) { if c.size() == 1 { return -1, nil } // If the cluster is too small, make sure we have just one voter (us). if c.size() < minVoters { for node := range c.State { if node.ID == leader || node.Role != client.Voter { continue } return client.Spare, []client.NodeInfo{node} } return -1, nil } onlineVoters := c.list(client.Voter, true, nil) onlineStandbys := c.list(client.StandBy, true, nil) offlineVoters := c.list(client.Voter, false, nil) offlineStandbys := c.list(client.StandBy, false, nil) domainsWithVoters := c.failureDomains(onlineVoters) allDomains := c.allFailureDomains() // If we do not have voters on all failure domains and we have a domain with more than one voter // we may need to send voters to domains without voters. if len(domainsWithVoters) < len(allDomains) && len(domainsWithVoters) < len(onlineVoters) { // Find the domains we need to populate with voters domainsWithoutVoters := c.domainsSubtract(allDomains, domainsWithVoters) // Find nodes in the domains we need to populate candidates := c.list(client.StandBy, true, domainsWithoutVoters) candidates = append(candidates, c.list(client.Spare, true, domainsWithoutVoters)...) if len(candidates) > 0 { c.sortCandidates(candidates, domainsWithoutVoters) return client.Voter, candidates } } // If we have exactly the desired number of voters and stand-bys, and they are all // online, we're good. if len(offlineVoters) == 0 && len(onlineVoters) == c.Config.Voters && len(offlineStandbys) == 0 && len(onlineStandbys) == c.Config.StandBys { return -1, nil } // If we have fewer online voters than desired, let's try to promote // some other node.
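// Stand-bys are appended to the candidate list before spares, so they are
// preferred; sortCandidates then favors nodes whose failure domain is not
// already covered by an online voter, breaking ties by lower weight.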
if n := len(onlineVoters); n < c.Config.Voters { candidates := c.list(client.StandBy, true, nil) candidates = append(candidates, c.list(client.Spare, true, nil)...) if len(candidates) == 0 { return -1, nil } domains := c.failureDomains(onlineVoters) c.sortCandidates(candidates, domains) return client.Voter, candidates } // If we have more online voters than desired, let's demote one of // them. if n := len(onlineVoters); n > c.Config.Voters { nodes := []client.NodeInfo{} for _, node := range onlineVoters { // Don't demote the leader. if node.ID == leader { continue } nodes = append(nodes, node) } return client.Spare, c.sortVoterCandidatesToDemote(nodes) } // If we have offline voters, let's demote one of them. if n := len(offlineVoters); n > 0 { return client.Spare, offlineVoters } // If we have fewer online stand-bys than desired, let's try to promote // some other node. if n := len(onlineStandbys); n < c.Config.StandBys { candidates := c.list(client.Spare, true, nil) if len(candidates) == 0 { return -1, nil } domains := c.failureDomains(onlineStandbys) c.sortCandidates(candidates, domains) return client.StandBy, candidates } // If we have more online stand-bys than desired, let's demote one of // them. if n := len(onlineStandbys); n > c.Config.StandBys { nodes := []client.NodeInfo{} for _, node := range onlineStandbys { // Don't demote the leader. if node.ID == leader { continue } nodes = append(nodes, node) } return client.Spare, nodes } // If we have offline stand-bys, let's demote one of them. if n := len(offlineStandbys); n > 0 { return client.Spare, offlineStandbys } return -1, nil } // Return the number of nodes in the cluster. func (c *RolesChanges) size() int { return len(c.State) } // Return information about the node with the given ID, or nil if no node // matches. func (c *RolesChanges) get(id uint64) *client.NodeInfo { for node := range c.State { if node.ID == id { return &node } } return nil } // Return the online or offline nodes with the given role (optionally) in specific domains. func (c *RolesChanges) list(role client.NodeRole, online bool, domains map[uint64]bool) []client.NodeInfo { nodes := []client.NodeInfo{} for node, metadata := range c.State { if node.Role == role && metadata != nil == online { if domains == nil || (domains != nil && domains[metadata.FailureDomain]) { nodes = append(nodes, node) } } } return nodes } // Return the number of online or offline nodes with the given role. func (c *RolesChanges) count(role client.NodeRole, online bool) int { return len(c.list(role, online, nil)) } // Return a map of the failure domains associated with the // given nodes. func (c *RolesChanges) failureDomains(nodes []client.NodeInfo) map[uint64]bool { domains := map[uint64]bool{} for _, node := range nodes { metadata := c.State[node] if metadata == nil { continue } domains[metadata.FailureDomain] = true } return domains } // Return a map of all failureDomains with online nodes. func (c *RolesChanges) allFailureDomains() map[uint64]bool { domains := map[uint64]bool{} for _, metadata := range c.State { if metadata == nil { continue } domains[metadata.FailureDomain] = true } return domains } // Return a map of domains that is the "from" minus the "subtract". func (c *RolesChanges) domainsSubtract(from map[uint64]bool, subtract map[uint64]bool) map[uint64]bool { domains := map[uint64]bool{} for fd, val := range from { _, common := subtract[fd] if !common { domains[fd] = val } } return domains } // Sort the given candidates according to their failure domain and // weight.
Candidates belonging to a failure domain different from the given // domains take precedence. func (c *RolesChanges) sortCandidates(candidates []client.NodeInfo, domains map[uint64]bool) { less := func(i, j int) bool { metadata1 := c.metadata(candidates[i]) metadata2 := c.metadata(candidates[j]) // If i's failure domain is not in the given list, but j's is, // then i takes precedence. if !domains[metadata1.FailureDomain] && domains[metadata2.FailureDomain] { return true } // If j's failure domain is not in the given list, but i's is, // then j takes precedence. if !domains[metadata2.FailureDomain] && domains[metadata1.FailureDomain] { return false } return metadata1.Weight < metadata2.Weight } sort.Slice(candidates, less) } // Sort the given candidates according to demotion priority and return the sorted list. // We prefer to select a candidate from a domain with multiple candidates, and // among those we prefer the candidate with the highest weight. func (c *RolesChanges) sortVoterCandidatesToDemote(candidates []client.NodeInfo) []client.NodeInfo { domainsMap := make(map[uint64][]client.NodeInfo) for _, node := range candidates { id := c.metadata(node).FailureDomain domain, exists := domainsMap[id] if !exists { domain = []client.NodeInfo{node} } else { domain = append(domain, node) } domainsMap[id] = domain } domains := make([][]client.NodeInfo, 0, len(domainsMap)) for _, domain := range domainsMap { domains = append(domains, domain) } sort.Slice(domains, func(i, j int) bool { return len(domains[i]) > len(domains[j]) }) for _, domain := range domains { sort.Slice(domain, func(i, j int) bool { metadata1 := c.metadata(domain[i]) metadata2 := c.metadata(domain[j]) return metadata1.Weight > metadata2.Weight }) } sortedCandidates := make([]client.NodeInfo, 0, len(candidates)) for _, domain := range domains { sortedCandidates = append(sortedCandidates, domain...) } return sortedCandidates } // Return the metadata of the given node, if any.
func (c *RolesChanges) metadata(node client.NodeInfo) *client.NodeMetadata { return c.State[node] } golang-github-canonical-go-dqlite-2.0.0/app/testdata/000077500000000000000000000000001471100661000224355ustar00rootroot00000000000000golang-github-canonical-go-dqlite-2.0.0/app/testdata/cluster.crt000066400000000000000000000011571471100661000246340ustar00rootroot00000000000000-----BEGIN CERTIFICATE----- MIIBnjCCAUSgAwIBAgIUddf2VYy/riyr+d2rByY0OT/N2HEwCgYIKoZIzj0EAwIw FjEUMBIGA1UEAwwLZHFsaXRlLXRlc3QwHhcNMjExMTE3MTIxMTU2WhcNNDkwNDA0 MTIxMTU2WjAWMRQwEgYDVQQDDAtkcWxpdGUtdGVzdDBZMBMGByqGSM49AgEGCCqG SM49AwEHA0IABHhD/t8WFSlqi04l2ce8l4ZktVjMMCwZ5edEwAjJl2QOvaW6qkP1 wFAaE9LOHTDQNEJv/BsA0XIHKXpG7fTHISajcDBuMB0GA1UdDgQWBBQ1qdnDo6Qm eJ51EH2/CS1AzxM2BTAfBgNVHSMEGDAWgBQ1qdnDo6QmeJ51EH2/CS1AzxM2BTAP BgNVHRMBAf8EBTADAQH/MBsGA1UdEQQUMBKHBH8AAAGCCmxvY2FsLnRlc3QwCgYI KoZIzj0EAwIDSAAwRQIhAJPVzO4jh61qKw0au/7UVU1TERavD3XPwzQhhq0ph9/h AiA1k0k8Iruvlty/5PA/CPKxeBH7smUyquVLYQW5Y5GbzQ== -----END CERTIFICATE----- golang-github-canonical-go-dqlite-2.0.0/app/testdata/cluster.key000066400000000000000000000004561471100661000246350ustar00rootroot00000000000000-----BEGIN EC PARAMETERS----- BggqhkjOPQMBBw== -----END EC PARAMETERS----- -----BEGIN EC PRIVATE KEY----- MHcCAQEEIBxSTUI5Xk1nsd/yfovKZ0cNdPGEcCTANDs0epC/Vo5foAoGCCqGSM49 AwEHoUQDQgAEeEP+3xYVKWqLTiXZx7yXhmS1WMwwLBnl50TACMmXZA69pbqqQ/XA UBoT0s4dMNA0Qm/8GwDRcgcpekbt9MchJg== -----END EC PRIVATE KEY----- golang-github-canonical-go-dqlite-2.0.0/app/tls.go000066400000000000000000000076421471100661000217660ustar00rootroot00000000000000package app import ( "crypto/tls" "crypto/x509" "fmt" ) // SimpleTLSConfig returns a pair of TLS configuration objects with sane // defaults, one to be used as server-side configuration when listening to // incoming connections and one to be used as client-side configuration when // establishing outgoing connections. // // The returned configs can be used as "listen" and "dial" parameters for the // WithTLS option. // // In order to generate a suitable TLS certificate you can use the openssl // command, for example: // // DNS=$(hostname) // IP=$(hostname -I | cut -f 1 -d ' ') // CN=example.com // openssl req -x509 -newkey rsa:4096 -sha256 -days 3650 \ // -nodes -keyout cluster.key -out cluster.crt -subj "/CN=$CN" \ // -addext "subjectAltName=DNS:$DNS,IP:$IP" // // then load the resulting key pair and pool with: // // cert, _ := tls.LoadX509KeyPair("cluster.crt", "cluster.key") // data, _ := ioutil.ReadFile("cluster.crt") // pool := x509.NewCertPool() // pool.AppendCertsFromPEM(data) // // and finally use the WithTLS option together with the SimpleTLSConfig helper: // // app, _ := app.New("/my/dir", app.WithTLS(app.SimpleTLSConfig(cert, pool))) // // See SimpleListenTLSConfig and SimpleDialTLSConfig for details. func SimpleTLSConfig(cert tls.Certificate, pool *x509.CertPool) (*tls.Config, *tls.Config) { listen := SimpleListenTLSConfig(cert, pool) dial := SimpleDialTLSConfig(cert, pool) return listen, dial } // SimpleListenTLSConfig returns a server-side TLS configuration with sane // defaults (e.g. TLS version, ciphers and mutual authentication). // // The cert parameter must be a public/private key pair, typically loaded from // disk using tls.LoadX509KeyPair(). // // The pool parameter can be used to specify a custom signing CA (e.g. for // self-signed certificates). // // When server and client both use the same certificate, the same key pair and // pool should be passed to SimpleDialTLSConfig() in order to generate the // client-side config. 
// // The returned config can be used as "listen" parameter for the WithTLS // option. // // A user can modify the returned config to suit their specific needs. func SimpleListenTLSConfig(cert tls.Certificate, pool *x509.CertPool) *tls.Config { config := &tls.Config{ MinVersion: tls.VersionTLS12, Certificates: []tls.Certificate{cert}, RootCAs: pool, ClientCAs: pool, ClientAuth: tls.RequireAndVerifyClientCert, } config.BuildNameToCertificate() return config } // SimpleDialTLSConfig returns a client-side TLS configuration with sane // defaults (e.g. TLS version, ciphers and mutual authentication). // // The cert parameter must be a public/private key pair, typically loaded from // disk using tls.LoadX509KeyPair(). // // The pool parameter can be used to specify a custom signing CA (e.g. for // self-signed certificates). // // When server and client both use the same certificate, the same key pair and // pool should be passed to SimpleListenTLSConfig() in order to generate the // server-side config. // // The returned config can be used as "client" parameter for the WithTLS App // option, or as "config" parameter for the client.DialFuncWithTLS() helper. // // TLS connections using the same `Config` will share a ClientSessionCache. // You can override this behaviour by setting your own ClientSessionCache or // nil. // // A user can modify the returned config to suit their specific needs. func SimpleDialTLSConfig(cert tls.Certificate, pool *x509.CertPool) *tls.Config { config := &tls.Config{ MinVersion: tls.VersionTLS12, RootCAs: pool, Certificates: []tls.Certificate{cert}, ClientSessionCache: tls.NewLRUClientSessionCache(0), } x509cert, err := x509.ParseCertificate(cert.Certificate[0]) if err != nil { panic(fmt.Errorf("parse certificate: %v", err)) } if len(x509cert.DNSNames) == 0 { panic("certificate has no DNS extension") } config.ServerName = x509cert.DNSNames[0] return config } golang-github-canonical-go-dqlite-2.0.0/benchmark/000077500000000000000000000000001471100661000217765ustar00rootroot00000000000000golang-github-canonical-go-dqlite-2.0.0/benchmark/benchmark.go000066400000000000000000000100421471100661000242560ustar00rootroot00000000000000package benchmark import ( "context" "database/sql" "errors" "fmt" "os" "path" "time" "github.com/canonical/go-dqlite/v2/app" "github.com/canonical/go-dqlite/v2/client" ) const ( kvSchema = "CREATE TABLE IF NOT EXISTS model (key TEXT, value TEXT, UNIQUE(key))" ) type Benchmark struct { app *app.App db *sql.DB dir string options *options workers []*worker } func createWorkers(o *options) []*worker { workers := make([]*worker, o.nWorkers) for i := 0; i < o.nWorkers; i++ { switch o.workload { case kvWrite: workers[i] = newWorker(kvWriter, o) case kvReadWrite: workers[i] = newWorker(kvReaderWriter, o) } } return workers } func New(app *app.App, db *sql.DB, dir string, options ...Option) (bm *Benchmark, err error) { o := defaultOptions() for _, option := range options { option(o) } bm = &Benchmark{ app: app, db: db, dir: dir, options: o, workers: createWorkers(o), } return bm, nil } func (bm *Benchmark) runWorkload(ctx context.Context) { for _, worker := range bm.workers { go worker.run(ctx, bm.db) } } func (bm *Benchmark) kvSetup() error { _, err := bm.db.Exec(kvSchema) return err } func (bm *Benchmark) setup() error { switch bm.options.workload { default: return bm.kvSetup() } } func reportName(id int, work work) string { return fmt.Sprintf("%d-%s-%d", id, work, time.Now().Unix()) } // Returns a map of file name to file content. func (bm *Benchmark) 
reportFiles() map[string]string { allReports := make(map[string]string) for i, worker := range bm.workers { reports := worker.report() for w, report := range reports { file := reportName(i, w) allReports[file] = report.String() } } return allReports } func (bm *Benchmark) reportResults() error { dir := path.Join(bm.dir, "results") if err := os.MkdirAll(dir, 0755); err != nil { return fmt.Errorf("failed to create %v: %v", dir, err) } reports := bm.reportFiles() for filename, content := range reports { f, err := os.Create(path.Join(dir, filename)) if err != nil { return fmt.Errorf("failed to create %v in %v: %v", filename, dir, err) } _, err = f.WriteString(content) if err != nil { return fmt.Errorf("failed to write %v in %v: %v", filename, dir, err) } f.Sync() } return nil } func (bm *Benchmark) nodeOnline(node *client.NodeInfo) bool { ctx, cancel := context.WithTimeout(context.Background(), time.Second*2) defer cancel() cli, err := client.New(ctx, node.Address) if err != nil { return false } cli.Close() return true } func (bm *Benchmark) allNodesOnline(ctx context.Context, cancel context.CancelFunc) { for { if errors.Is(ctx.Err(), context.DeadlineExceeded) { return } cli, err := bm.app.Client(ctx) if err != nil { continue } nodes, err := cli.Cluster(ctx) if err != nil { continue } cli.Close() n := 0 for _, needed := range bm.options.cluster { for _, present := range nodes { if needed == present.Address && bm.nodeOnline(&present) { n += 1 } } } if len(bm.options.cluster) == n { cancel() return } } } func (bm *Benchmark) waitForCluster(ch <-chan os.Signal) error { ctx, cancel := context.WithTimeout(context.Background(), time.Duration(bm.options.clusterTimeout)) defer cancel() go bm.allNodesOnline(ctx, cancel) select { case <-ctx.Done(): if !errors.Is(ctx.Err(), context.Canceled) { return fmt.Errorf("timed out waiting for cluster: %v", ctx.Err()) } return nil case <-ch: return fmt.Errorf("benchmark stopped, signal received while waiting for cluster") } } func (bm *Benchmark) Run(ch <-chan os.Signal) error { if err := bm.setup(); err != nil { return err } if err := bm.waitForCluster(ch); err != nil { return err } ctx, cancel := context.WithTimeout(context.Background(), bm.options.duration) defer cancel() bm.runWorkload(ctx) select { case <-ctx.Done(): break case <-ch: cancel() break } if err := bm.reportResults(); err != nil { return err } fmt.Printf("Benchmark done. 
Results available here:\n%s\n", path.Join(bm.dir, "results")) return nil } golang-github-canonical-go-dqlite-2.0.0/benchmark/benchmark_test.go000066400000000000000000000053561471100661000253270ustar00rootroot00000000000000package benchmark_test import ( "context" "database/sql" "io/ioutil" "os" "testing" "time" "github.com/canonical/go-dqlite/v2/app" "github.com/canonical/go-dqlite/v2/benchmark" "github.com/stretchr/testify/require" ) const ( addr1 = "127.0.0.1:9011" addr2 = "127.0.0.1:9012" addr3 = "127.0.0.1:9013" ) func bmSetup(t *testing.T, addr string, join []string) (string, *app.App, *sql.DB, func()) { t.Helper() dir, err := ioutil.TempDir("", "dqlite-app-test-") require.NoError(t, err) app, err := app.New(dir, app.WithAddress(addr), app.WithCluster(join)) require.NoError(t, err) readyCtx, cancel := context.WithTimeout(context.Background(), time.Duration(3)*time.Second) err = app.Ready(readyCtx) require.NoError(t, err) db, err := app.Open(context.Background(), "benchmark") require.NoError(t, err) cleanups := func() { os.RemoveAll(dir) cancel() } return dir, app, db, cleanups } func bmRun(t *testing.T, bm *benchmark.Benchmark, app *app.App, db *sql.DB) { defer db.Close() defer app.Close() ch := make(chan os.Signal) err := bm.Run(ch) require.NoError(t, err) } // Create a Benchmark with default values. func TestNew_Default(t *testing.T) { dir, app, db, cleanup := bmSetup(t, addr1, nil) defer cleanup() bm, err := benchmark.New( app, db, dir, benchmark.WithCluster([]string{addr1}), benchmark.WithDuration(1)) require.NoError(t, err) bmRun(t, bm, app, db) } // Create a Benchmark with a kvReadWriteWorkload. func TestNew_KvReadWrite(t *testing.T) { dir, app, db, cleanup := bmSetup(t, addr1, nil) defer cleanup() bm, err := benchmark.New( app, db, dir, benchmark.WithCluster([]string{addr1}), benchmark.WithDuration(1), benchmark.WithWorkload("KvReadWrite")) require.NoError(t, err) bmRun(t, bm, app, db) } // Create a clustered Benchmark. func TestNew_ClusteredKvReadWrite(t *testing.T) { dir, app, db, cleanup := bmSetup(t, addr1, nil) _, _, _, cleanup2 := bmSetup(t, addr2, []string{addr1}) _, _, _, cleanup3 := bmSetup(t, addr3, []string{addr1}) defer cleanup() defer cleanup2() defer cleanup3() bm, err := benchmark.New( app, db, dir, benchmark.WithCluster([]string{addr1, addr2, addr3}), benchmark.WithDuration(2)) require.NoError(t, err) bmRun(t, bm, app, db) } // Create a clustered Benchmark that times out waiting for the cluster to form. 
func TestNew_ClusteredTimeout(t *testing.T) { dir, app, db, cleanup := bmSetup(t, addr1, nil) defer cleanup() defer db.Close() defer app.Close() bm, err := benchmark.New( app, db, dir, benchmark.WithCluster([]string{addr1, addr2}), benchmark.WithClusterTimeout(2)) require.NoError(t, err) ch := make(chan os.Signal) err = bm.Run(ch) require.Errorf(t, err, "Timed out waiting for cluster: context deadline exceeded") } golang-github-canonical-go-dqlite-2.0.0/benchmark/options.go000066400000000000000000000041171471100661000240230ustar00rootroot00000000000000package benchmark import ( "strings" "time" ) type workload int32 const ( kvWrite workload = iota kvReadWrite workload = iota ) type Option func(*options) type options struct { cluster []string clusterTimeout time.Duration workload workload duration time.Duration nWorkers int kvKeySizeB int kvValueSizeB int } func parseWorkload(workload string) workload { switch strings.ToLower(workload) { case "kvwrite": return kvWrite case "kvreadwrite": return kvReadWrite default: return kvWrite } } // WithWorkload sets the workload of the benchmark. func WithWorkload(workload string) Option { return func(options *options) { options.workload = parseWorkload(workload) } } // WithDuration sets the duration of the benchmark. func WithDuration(seconds int) Option { return func(options *options) { options.duration = time.Duration(seconds) * time.Second } } // WithWorkers sets the number of workers of the benchmark. func WithWorkers(n int) Option { return func(options *options) { options.nWorkers = n } } // WithKvKeySize sets the size of the KV keys of the benchmark. func WithKvKeySize(bytes int) Option { return func(options *options) { options.kvKeySizeB = bytes } } // WithKvValueSize sets the size of the KV values of the benchmark. func WithKvValueSize(bytes int) Option { return func(options *options) { options.kvValueSizeB = bytes } } // WithCluster sets the cluster option of the benchmark. A benchmark will only // start once the whole cluster is online. 
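//
// A typical construction, mirroring the tests above (app, db and dir are
// assumed to come from app.New/app.Open; the address is a placeholder):
//
//	bm, err := benchmark.New(app, db, dir,
//		benchmark.WithCluster([]string{"127.0.0.1:9001"}),
//		benchmark.WithDuration(10))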
func WithCluster(cluster []string) Option { return func(options *options) { options.cluster = cluster } } // WithClusterTimeout sets the timeout when waiting for the whole cluster to be // online func WithClusterTimeout(cTo int) Option { return func(options *options) { options.clusterTimeout = time.Duration(cTo) * time.Second } } func defaultOptions() *options { return &options{ cluster: nil, clusterTimeout: time.Minute, duration: time.Minute, kvKeySizeB: 32, kvValueSizeB: 1024, nWorkers: 1, workload: kvWrite, } } golang-github-canonical-go-dqlite-2.0.0/benchmark/tracker.go000066400000000000000000000053541471100661000237670ustar00rootroot00000000000000package benchmark import ( "fmt" "math" "strings" "sync" "time" ) func durToMs(d time.Duration) string { ms := int64(d / time.Millisecond) rest := int64(d % time.Millisecond) return fmt.Sprintf("%d.%06d", ms, rest) } type measurement struct { start time.Time duration time.Duration } func (m measurement) String() string { return fmt.Sprintf("%v %v", m.start.UnixNano(), durToMs(m.duration)) } type measurementErr struct { start time.Time err error } func (m measurementErr) String() string { return fmt.Sprintf("%v %v", m.start.UnixNano(), m.err) } type tracker struct { lock sync.RWMutex measurements map[work][]measurement errors map[work][]measurementErr } type report struct { n int nErr int totalDuration time.Duration avgDuration time.Duration maxDuration time.Duration minDuration time.Duration measurements []measurement errors []measurementErr } func (r report) String() string { var msb strings.Builder for _, m := range r.measurements { fmt.Fprintf(&msb, "%s\n", m) } var esb strings.Builder for _, e := range r.errors { fmt.Fprintf(&esb, "%s\n", e) } return fmt.Sprintf("n %d\n"+ "n_err %d\n"+ "avg [ms] %s\n"+ "max [ms] %s\n"+ "min [ms] %s\n"+ "measurements [timestamp in ns] [ms]\n%s\n"+ "errors\n%s\n", r.n, r.nErr, durToMs(r.avgDuration), durToMs(r.maxDuration), durToMs(r.minDuration), msb.String(), esb.String()) } func (t *tracker) measure(start time.Time, work work, err *error) { t.lock.Lock() defer t.lock.Unlock() duration := time.Since(start) if *err == nil { m := measurement{start, duration} t.measurements[work] = append(t.measurements[work], m) } else { e := measurementErr{start, *err} t.errors[work] = append(t.errors[work], e) } } func (t *tracker) report() map[work]report { t.lock.RLock() defer t.lock.RUnlock() reports := make(map[work]report) for w := range t.measurements { report := report{ n: len(t.measurements[w]), nErr: len(t.errors[w]), totalDuration: 0, avgDuration: 0, maxDuration: 0, minDuration: time.Duration(math.MaxInt64), measurements: t.measurements[w], errors: t.errors[w], } for _, m := range t.measurements[w] { report.totalDuration += m.duration if m.duration < report.minDuration { report.minDuration = m.duration } if m.duration > report.maxDuration { report.maxDuration = m.duration } } if report.n > 0 { report.avgDuration = report.totalDuration / time.Duration(report.n) } reports[w] = report } return reports } func newTracker() *tracker { return &tracker{ lock: sync.RWMutex{}, measurements: make(map[work][]measurement), errors: make(map[work][]measurementErr), } } golang-github-canonical-go-dqlite-2.0.0/benchmark/worker.go000066400000000000000000000065131471100661000236430ustar00rootroot00000000000000package benchmark import ( "context" "database/sql" "errors" "fmt" "math/rand" "strings" "time" ) type work int type workerType int func (w work) String() string { switch w { case exec: return "exec" case query: return 
"query" case none: return "none" default: return "unknown" } } const ( // The type of query to perform none work = iota exec work = iota // a `write` query work = iota // a `read` kvWriter workerType = iota kvReader workerType = iota kvReaderWriter workerType = iota kvReadSql = "SELECT value FROM model WHERE key = ?" kvWriteSql = "INSERT OR REPLACE INTO model(key, value) VALUES(?, ?)" ) // A worker performs the queries to the database and keeps around some state // in order to do that. `lastWork` and `lastArgs` refer to the previously // executed operation and can be used to determine the next work the worker // should perform. `kvKeys` tells the worker which keys it has inserted in the // database. type worker struct { workerType workerType lastWork work lastArgs []interface{} tracker *tracker kvKeySizeB int kvValueSizeB int kvKeys []string } // Thanks to https://stackoverflow.com/a/22892986 var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") func randSeq(n int) string { b := make([]rune, n) for i := range b { b[i] = letters[rand.Intn(len(letters))] } return string(b) } func (w *worker) randNewKey() string { return randSeq(w.kvKeySizeB) } func (w *worker) randExistingKey() (string, error) { n := len(w.kvKeys) if n == 0 { return "", errors.New("no keys") } return w.kvKeys[rand.Intn(n)], nil } // A mix of random bytes and easily compressable bytes. func (w *worker) randValue() string { return strings.Repeat(randSeq(1), w.kvValueSizeB/2) + randSeq(w.kvValueSizeB/2) } // Returns the type of work to execute and a sql statement with arguments func (w *worker) getWork() (work, string, []interface{}) { switch w.workerType { case kvWriter: k, v := w.randNewKey(), w.randValue() return exec, kvWriteSql, []interface{}{k, v} case kvReaderWriter: read := rand.Intn(2) == 0 if read && len(w.kvKeys) != 0 { k, _ := w.randExistingKey() return query, kvReadSql, []interface{}{k} } k, v := w.randNewKey(), w.randValue() return exec, kvWriteSql, []interface{}{k, v} default: return none, "", []interface{}{} } } // Retrieve a query and execute it against the database func (w *worker) doWork(ctx context.Context, db *sql.DB) { var err error var str string work, q, args := w.getWork() w.lastWork = work w.lastArgs = args switch work { case exec: w.kvKeys = append(w.kvKeys, fmt.Sprintf("%v", (args[0]))) defer w.tracker.measure(time.Now(), work, &err) _, err = db.ExecContext(ctx, q, args...) if err != nil { w.kvKeys = w.kvKeys[:len(w.kvKeys)-1] } case query: defer w.tracker.measure(time.Now(), work, &err) err = db.QueryRowContext(ctx, q, args...).Scan(&str) default: return } } func (w *worker) run(ctx context.Context, db *sql.DB) { for { if ctx.Err() != nil { return } w.doWork(ctx, db) } } func (w *worker) report() map[work]report { return w.tracker.report() } func newWorker(workerType workerType, o *options) *worker { return &worker{ workerType: workerType, kvKeySizeB: o.kvKeySizeB, kvValueSizeB: o.kvValueSizeB, tracker: newTracker(), } } golang-github-canonical-go-dqlite-2.0.0/client/000077500000000000000000000000001471100661000213225ustar00rootroot00000000000000golang-github-canonical-go-dqlite-2.0.0/client/client.go000066400000000000000000000227711471100661000231400ustar00rootroot00000000000000package client import ( "context" "github.com/canonical/go-dqlite/v2/internal/protocol" "github.com/pkg/errors" ) // DialFunc is a function that can be used to establish a network connection. type DialFunc = protocol.DialFunc // Client speaks the dqlite wire protocol. 
type Client struct { protocol *protocol.Protocol } // Option is a function that can be used to tweak client parameters. type Option func(*options) type options struct { DialFunc DialFunc LogFunc LogFunc ConcurrentLeaderConns int64 PermitShared bool } // WithDialFunc sets a custom dial function for creating the client network // connection. func WithDialFunc(dial DialFunc) Option { return func(options *options) { options.DialFunc = dial } } // WithLogFunc sets a custom log function. func WithLogFunc(log LogFunc) Option { return func(options *options) { options.LogFunc = log } } // WithConcurrentLeaderConns sets the maximum number of concurrent connections // to other cluster members that will be attempted while searching for the dqlite leader. // // The default is 10 connections to other cluster members. func WithConcurrentLeaderConns(maxConns int64) Option { return func(o *options) { o.ConcurrentLeaderConns = maxConns } } // New creates a new client connected to the dqlite node with the given // address. func New(ctx context.Context, address string, options ...Option) (*Client, error) { o := defaultOptions() for _, option := range options { option(o) } // Establish the connection. conn, err := o.DialFunc(ctx, address) if err != nil { return nil, errors.Wrap(err, "failed to establish network connection") } protocol, err := protocol.Handshake(ctx, conn, protocol.VersionOne, address) if err != nil { conn.Close() return nil, err } return &Client{protocol}, nil } // Leader returns information about the current leader, if any. func (c *Client) Leader(ctx context.Context) (*NodeInfo, error) { request := protocol.Message{} request.Init(16) response := protocol.Message{} response.Init(512) protocol.EncodeLeader(&request) if err := c.protocol.Call(ctx, &request, &response); err != nil { return nil, errors.Wrap(err, "failed to send Leader request") } id, address, err := protocol.DecodeNode(&response) if err != nil { return nil, errors.Wrap(err, "failed to parse Node response") } info := &NodeInfo{ID: id, Address: address} return info, nil } // Cluster returns information about all nodes in the cluster. func (c *Client) Cluster(ctx context.Context) ([]NodeInfo, error) { request := protocol.Message{} request.Init(16) response := protocol.Message{} response.Init(512) protocol.EncodeCluster(&request, protocol.ClusterFormatV1) if err := c.protocol.Call(ctx, &request, &response); err != nil { return nil, errors.Wrap(err, "failed to send Cluster request") } servers, err := protocol.DecodeNodes(&response) if err != nil { return nil, errors.Wrap(err, "failed to parse Node response") } return servers, nil } // File holds the content of a single database file. type File struct { Name string Data []byte } // Dump the content of the database with the given name. Two files will be // returned: the first is the main database file (which has the same name as // the database), the second is the WAL file (which has the same name as the // database plus the suffix "-wal"). 
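//
// For example, a sketch that writes both files to the current directory
// ("demo" is a placeholder database name; error handling abbreviated):
//
//	files, err := cli.Dump(ctx, "demo")
//	for _, file := range files {
//		ioutil.WriteFile(file.Name, file.Data, 0600)
//	}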
func (c *Client) Dump(ctx context.Context, dbname string) ([]File, error) { request := protocol.Message{} request.Init(16) response := protocol.Message{} response.Init(512) protocol.EncodeDump(&request, dbname) if err := c.protocol.Call(ctx, &request, &response); err != nil { return nil, errors.Wrap(err, "failed to send dump request") } files, err := protocol.DecodeFiles(&response) if err != nil { return nil, errors.Wrap(err, "failed to parse files response") } defer files.Close() dump := make([]File, 0) for { name, data := files.Next() if name == "" { break } dump = append(dump, File{Name: name, Data: data}) } return dump, nil } // Add a node to a cluster. // // The new node will have the role specified in node.Role. Note that if the // desired role is Voter, the node being added must be online, since it will be // granted voting rights only once it catches up with the leader's log. func (c *Client) Add(ctx context.Context, node NodeInfo) error { request := protocol.Message{} response := protocol.Message{} request.Init(4096) response.Init(4096) protocol.EncodeAdd(&request, node.ID, node.Address) if err := c.protocol.Call(ctx, &request, &response); err != nil { return err } if err := protocol.DecodeEmpty(&response); err != nil { return err } // If the desired role is spare, there's nothing to do, since all newly // added nodes have the spare role. if node.Role == Spare { return nil } return c.Assign(ctx, node.ID, node.Role) } // Assign a role to a node. // // Possible roles are: // // - Voter: the node will replicate data and participate in quorum. // - StandBy: the node will replicate data but won't participate in quorum. // - Spare: the node won't replicate data and won't participate in quorum. // // If the target node does not exist or already has the desired role, an error // is returned. func (c *Client) Assign(ctx context.Context, id uint64, role NodeRole) error { request := protocol.Message{} response := protocol.Message{} request.Init(4096) response.Init(4096) protocol.EncodeAssign(&request, id, uint64(role)) if err := c.protocol.Call(ctx, &request, &response); err != nil { return err } if err := protocol.DecodeEmpty(&response); err != nil { return err } return nil } // Transfer leadership from the current leader to another node. // // This must be invoked on a client connected to the current leader. func (c *Client) Transfer(ctx context.Context, id uint64) error { request := protocol.Message{} response := protocol.Message{} request.Init(4096) response.Init(4096) protocol.EncodeTransfer(&request, id) if err := c.protocol.Call(ctx, &request, &response); err != nil { return err } if err := protocol.DecodeEmpty(&response); err != nil { return err } return nil } // Remove a node from the cluster. func (c *Client) Remove(ctx context.Context, id uint64) error { request := protocol.Message{} request.Init(4096) response := protocol.Message{} response.Init(4096) protocol.EncodeRemove(&request, id) if err := c.protocol.Call(ctx, &request, &response); err != nil { return err } if err := protocol.DecodeEmpty(&response); err != nil { return err } return nil } // NodeMetadata holds user-defined node-level metadata. type NodeMetadata struct { FailureDomain uint64 Weight uint64 } // Describe returns metadata about the node we're connected with. 
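//
// For example (a sketch; cli is an established Client):
//
//	metadata, err := cli.Describe(ctx)
//	if err == nil {
//		fmt.Println(metadata.FailureDomain, metadata.Weight)
//	}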
func (c *Client) Describe(ctx context.Context) (*NodeMetadata, error) { request := protocol.Message{} request.Init(4096) response := protocol.Message{} response.Init(4096) protocol.EncodeDescribe(&request, protocol.RequestDescribeFormatV0) if err := c.protocol.Call(ctx, &request, &response); err != nil { return nil, err } domain, weight, err := protocol.DecodeMetadata(&response) if err != nil { return nil, err } metadata := &NodeMetadata{ FailureDomain: domain, Weight: weight, } return metadata, nil } // Weight updates the weight associated to the node we're connected with. func (c *Client) Weight(ctx context.Context, weight uint64) error { request := protocol.Message{} request.Init(4096) response := protocol.Message{} response.Init(4096) protocol.EncodeWeight(&request, weight) if err := c.protocol.Call(ctx, &request, &response); err != nil { return err } if err := protocol.DecodeEmpty(&response); err != nil { return err } return nil } // Close the client. func (c *Client) Close() error { return c.protocol.Close() } // Create a client options object with sane defaults. func defaultOptions() *options { return &options{ DialFunc: DefaultDialFunc, LogFunc: DefaultLogFunc, ConcurrentLeaderConns: protocol.MaxConcurrentLeaderConns, } } // Connector is a reusable configuration for creating new Clients. // // In some cases, Connector.Connect can take advantage of state stored in the // Connector to be more efficient than New or FindLeader, so prefer to use a // Connector whenever several Clients need to be created with the same // parameters. type Connector protocol.Connector // NewLeaderConnector creates a Connector that will yield Clients connected to // the cluster leader. func NewLeaderConnector(store NodeStore, options ...Option) *Connector { opts := defaultOptions() for _, o := range options { o(opts) } config := protocol.Config{ Dial: opts.DialFunc, ConcurrentLeaderConns: opts.ConcurrentLeaderConns, PermitShared: opts.PermitShared, } inner := protocol.NewLeaderConnector(store, config, opts.LogFunc) return (*Connector)(inner) } // NewDirectConnector creates a Connector that will yield Clients connected to // the node with the given ID and address. func NewDirectConnector(id uint64, address string, options ...Option) *Connector { opts := defaultOptions() for _, o := range options { o(opts) } config := protocol.Config{Dial: opts.DialFunc} inner := protocol.NewDirectConnector(id, address, config, opts.LogFunc) return (*Connector)(inner) } // Connect opens a Client based on the Connector's configuration. 
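//
// A usage sketch (store is assumed to be a NodeStore already populated with
// candidate addresses):
//
//	connector := client.NewLeaderConnector(store)
//	cli, err := connector.Connect(ctx)
//	if err != nil {
//		return err
//	}
//	defer cli.Close()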
func (connector *Connector) Connect(ctx context.Context) (*Client, error) { protocol, err := (*protocol.Connector)(connector).Connect(ctx) if err != nil { return nil, err } return &Client{protocol}, nil } golang-github-canonical-go-dqlite-2.0.0/client/client_test.go000066400000000000000000000100221471100661000241610ustar00rootroot00000000000000package client_test import ( "context" "fmt" "io/ioutil" "os" "testing" "time" dqlite "github.com/canonical/go-dqlite/v2" "github.com/canonical/go-dqlite/v2/client" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestClient_Leader(t *testing.T) { node, cleanup := newNode(t) defer cleanup() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() client, err := client.New(ctx, node.BindAddress()) require.NoError(t, err) defer client.Close() leader, err := client.Leader(context.Background()) require.NoError(t, err) assert.Equal(t, leader.ID, uint64(1)) assert.Equal(t, leader.Address, "@1001") } func TestClient_Cluster(t *testing.T) { node, cleanup := newNode(t) defer cleanup() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() cli, err := client.New(ctx, node.BindAddress()) require.NoError(t, err) defer cli.Close() servers, err := cli.Cluster(context.Background()) require.NoError(t, err) assert.Len(t, servers, 1) assert.Equal(t, servers[0].ID, uint64(1)) assert.Equal(t, servers[0].Address, "@1001") assert.Equal(t, servers[0].Role, client.Voter) } func TestClient_Transfer(t *testing.T) { node1, cleanup := newNode(t) defer cleanup() ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() cli, err := client.New(ctx, node1.BindAddress()) require.NoError(t, err) defer cli.Close() node2, cleanup := addNode(t, cli, 2) defer cleanup() err = cli.Assign(context.Background(), 2, client.Voter) require.NoError(t, err) err = cli.Transfer(context.Background(), 2) require.NoError(t, err) leader, err := cli.Leader(context.Background()) require.NoError(t, err) assert.Equal(t, leader.ID, uint64(2)) cli, err = client.New(ctx, node2.BindAddress()) require.NoError(t, err) defer cli.Close() leader, err = cli.Leader(context.Background()) require.NoError(t, err) assert.Equal(t, leader.ID, uint64(2)) } func TestClient_Describe(t *testing.T) { node, cleanup := newNode(t) defer cleanup() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() cli, err := client.New(ctx, node.BindAddress()) require.NoError(t, err) defer cli.Close() metadata, err := cli.Describe(context.Background()) require.NoError(t, err) assert.Equal(t, uint64(0), metadata.FailureDomain) assert.Equal(t, uint64(0), metadata.Weight) require.NoError(t, cli.Weight(context.Background(), 123)) metadata, err = cli.Describe(context.Background()) require.NoError(t, err) assert.Equal(t, uint64(0), metadata.FailureDomain) assert.Equal(t, uint64(123), metadata.Weight) } func newNode(t *testing.T) (*dqlite.Node, func()) { t.Helper() dir, dirCleanup := newDir(t) id := uint64(1) address := fmt.Sprintf("@%d", id+1000) node, err := dqlite.New(uint64(1), address, dir, dqlite.WithBindAddress(address)) require.NoError(t, err) err = node.Start() require.NoError(t, err) cleanup := func() { require.NoError(t, node.Close()) dirCleanup() } return node, cleanup } func addNode(t *testing.T, cli *client.Client, id uint64) (*dqlite.Node, func()) { t.Helper() dir, dirCleanup := newDir(t) ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() address := 
fmt.Sprintf("@%d", id+1000) node, err := dqlite.New(id, address, dir, dqlite.WithBindAddress(address)) require.NoError(t, err) err = node.Start() require.NoError(t, err) info := client.NodeInfo{ ID: id, Address: address, Role: client.Spare, } err = cli.Add(ctx, info) require.NoError(t, err) cleanup := func() { require.NoError(t, node.Close()) dirCleanup() } return node, cleanup } // Return a new temporary directory. func newDir(t *testing.T) (string, func()) { t.Helper() dir, err := ioutil.TempDir("", "dqlite-replication-test-") assert.NoError(t, err) cleanup := func() { _, err := os.Stat(dir) if err != nil { assert.True(t, os.IsNotExist(err)) } else { assert.NoError(t, os.RemoveAll(dir)) } } return dir, cleanup } golang-github-canonical-go-dqlite-2.0.0/client/constants.go000066400000000000000000000002731471100661000236670ustar00rootroot00000000000000package client import ( "github.com/canonical/go-dqlite/v2/internal/protocol" ) // Node roles const ( Voter = protocol.Voter StandBy = protocol.StandBy Spare = protocol.Spare ) golang-github-canonical-go-dqlite-2.0.0/client/database_store.go000066400000000000000000000103431471100661000246320ustar00rootroot00000000000000// +build !nosqlite3 package client import ( "context" "database/sql" "fmt" "strings" _ "github.com/mattn/go-sqlite3" // Go SQLite bindings "github.com/pkg/errors" ) // Option that can be used to tweak node store parameters. type NodeStoreOption func(*nodeStoreOptions) type nodeStoreOptions struct { Where string } // DatabaseNodeStore persists a list addresses of dqlite nodes in a SQL table. type DatabaseNodeStore struct { db *sql.DB // Database handle to use. schema string // Name of the schema holding the servers table. table string // Name of the servers table. column string // Column name in the servers table holding the server address. where string // Optional WHERE filter } // DefaultNodeStore creates a new NodeStore using the given filename. // // If the filename ends with ".yaml" then the YamlNodeStore implementation will // be used. Otherwise the SQLite-based one will be picked, with default names // for the schema, table and column parameters. // // It also creates the table if it doesn't exist yet. func DefaultNodeStore(filename string) (NodeStore, error) { if strings.HasSuffix(filename, ".yaml") { return NewYamlNodeStore(filename) } // Open the database. db, err := sql.Open("sqlite3", filename) if err != nil { return nil, errors.Wrap(err, "failed to open database") } // Since we're setting SQLite single-thread mode, we need to have one // connection at most. db.SetMaxOpenConns(1) // Create the servers table if it does not exist yet. _, err = db.Exec("CREATE TABLE IF NOT EXISTS servers (address TEXT, UNIQUE(address))") if err != nil { return nil, errors.Wrap(err, "failed to create servers table") } store := NewNodeStore(db, "main", "servers", "address") return store, nil } // NewNodeStore creates a new NodeStore. func NewNodeStore(db *sql.DB, schema, table, column string, options ...NodeStoreOption) *DatabaseNodeStore { o := &nodeStoreOptions{} for _, option := range options { option(o) } return &DatabaseNodeStore{ db: db, schema: schema, table: table, column: column, where: o.Where, } } // WithNodeStoreWhereClause configures the node store to append the given // hard-coded where clause to the SELECT query used to fetch nodes. Only the // clause itself must be given, without the "WHERE" prefix. 
func WithNodeStoreWhereClause(where string) NodeStoreOption { return func(options *nodeStoreOptions) { options.Where = where } } // Get the current servers. func (d *DatabaseNodeStore) Get(ctx context.Context) ([]NodeInfo, error) { tx, err := d.db.Begin() if err != nil { return nil, errors.Wrap(err, "failed to begin transaction") } defer tx.Rollback() query := fmt.Sprintf("SELECT %s FROM %s.%s", d.column, d.schema, d.table) if d.where != "" { query += " WHERE " + d.where } rows, err := tx.QueryContext(ctx, query) if err != nil { return nil, errors.Wrap(err, "failed to query servers table") } defer rows.Close() servers := make([]NodeInfo, 0) for rows.Next() { var address string err := rows.Scan(&address) if err != nil { return nil, errors.Wrap(err, "failed to fetch server address") } servers = append(servers, NodeInfo{ID: 1, Address: address}) } if err := rows.Err(); err != nil { return nil, errors.Wrap(err, "result set failure") } return servers, nil } // Set the servers addresses. func (d *DatabaseNodeStore) Set(ctx context.Context, servers []NodeInfo) error { tx, err := d.db.Begin() if err != nil { return errors.Wrap(err, "failed to begin transaction") } query := fmt.Sprintf("DELETE FROM %s.%s", d.schema, d.table) if _, err := tx.ExecContext(ctx, query); err != nil { tx.Rollback() return errors.Wrap(err, "failed to delete existing servers rows") } query = fmt.Sprintf("INSERT INTO %s.%s(%s) VALUES (?)", d.schema, d.table, d.column) stmt, err := tx.PrepareContext(ctx, query) if err != nil { tx.Rollback() return errors.Wrap(err, "failed to prepare insert statement") } defer stmt.Close() for _, server := range servers { if _, err := stmt.ExecContext(ctx, server.Address); err != nil { tx.Rollback() return errors.Wrapf(err, "failed to insert server %s", server.Address) } } if err := tx.Commit(); err != nil { return errors.Wrap(err, "failed to commit transaction") } return nil } golang-github-canonical-go-dqlite-2.0.0/client/dial.go000066400000000000000000000020311471100661000225560ustar00rootroot00000000000000package client import ( "context" "crypto/tls" "net" "github.com/canonical/go-dqlite/v2/internal/protocol" ) // DefaultDialFunc is the default dial function, which can handle plain TCP and // Unix socket endpoints. You can customize it with WithDialFunc() func DefaultDialFunc(ctx context.Context, address string) (net.Conn, error) { return protocol.Dial(ctx, address) } // DialFuncWithTLS returns a dial function that uses TLS encryption. // // The given dial function will be used to establish the network connection, // and the given TLS config will be used for encryption. func DialFuncWithTLS(dial DialFunc, config *tls.Config) DialFunc { return func(ctx context.Context, addr string) (net.Conn, error) { clonedConfig := config.Clone() if len(clonedConfig.ServerName) == 0 { remoteIP, _, err := net.SplitHostPort(addr) if err != nil { return nil, err } clonedConfig.ServerName = remoteIP } conn, err := dial(ctx, addr) if err != nil { return nil, err } return tls.Client(conn, clonedConfig), nil } } golang-github-canonical-go-dqlite-2.0.0/client/leader.go000066400000000000000000000010211471100661000230770ustar00rootroot00000000000000package client import ( "context" ) // FindLeader returns a Client connected to the current cluster leader. // // The function will iterate through to all nodes in the given store, and for // each of them check if it's the current leader. If no leader is found, the // function will keep retrying (with a capped exponential backoff) until the // given context is canceled. 
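//
// A typical call therefore bounds the search with a timeout, e.g.:
//
//	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
//	defer cancel()
//	cli, err := client.FindLeader(ctx, store)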
func FindLeader(ctx context.Context, store NodeStore, options ...Option) (*Client, error) { return NewLeaderConnector(store, options...).Connect(ctx) } golang-github-canonical-go-dqlite-2.0.0/client/leader_test.go000066400000000000000000000024061471100661000241460ustar00rootroot00000000000000package client_test import ( "context" "fmt" "testing" "time" dqlite "github.com/canonical/go-dqlite/v2" "github.com/canonical/go-dqlite/v2/client" "github.com/stretchr/testify/require" ) func TestMembership(t *testing.T) { infos, cleanup := setup(t) defer cleanup() store := client.NewInmemNodeStore() store.Set(context.Background(), infos[:1]) ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() client, err := client.FindLeader(ctx, store) require.NoError(t, err) defer client.Close() err = client.Add(ctx, infos[1]) require.NoError(t, err) } func setup(t *testing.T) ([]client.NodeInfo, func()) { n := 3 nodes := make([]*dqlite.Node, n) infos := make([]client.NodeInfo, n) var cleanups []func() for i := range nodes { id := uint64(i + 1) address := fmt.Sprintf("@test-%d", id) dir, cleanup := newDir(t) cleanups = append(cleanups, cleanup) node, err := dqlite.New(id, address, dir, dqlite.WithBindAddress(address)) require.NoError(t, err) nodes[i] = node infos[i].ID = id infos[i].Address = address err = node.Start() require.NoError(t, err) cleanups = append(cleanups, func() { node.Close() }) } return infos, func() { for i := len(cleanups) - 1; i >= 0; i-- { cleanups[i]() } } } golang-github-canonical-go-dqlite-2.0.0/client/log.go000066400000000000000000000007721471100661000224400ustar00rootroot00000000000000package client import ( "github.com/canonical/go-dqlite/v2/logging" ) // LogFunc is a function that can be used for logging. type LogFunc = logging.Func // LogLevel defines the logging level. type LogLevel = logging.Level // Available logging levels. const ( LogNone = logging.None LogDebug = logging.Debug LogInfo = logging.Info LogWarn = logging.Warn LogError = logging.Error ) // DefaultLogFunc doesn't emit any message. func DefaultLogFunc(l LogLevel, format string, a ...interface{}) {} golang-github-canonical-go-dqlite-2.0.0/client/no_database_store.go000066400000000000000000000006361471100661000253320ustar00rootroot00000000000000// +build nosqlite3 package client import ( "strings" "github.com/pkg/errors" ) // DefaultNodeStore creates a new NodeStore using the given filename. // // The filename must end with ".yaml". func DefaultNodeStore(filename string) (NodeStore, error) { if strings.HasSuffix(filename, ".yaml") { return NewYamlNodeStore(filename) } return nil, errors.New("built without support for DatabaseNodeStore") } golang-github-canonical-go-dqlite-2.0.0/client/store.go000066400000000000000000000037021471100661000230070ustar00rootroot00000000000000package client import ( "context" "io/ioutil" "os" "sync" "github.com/google/renameio" "gopkg.in/yaml.v2" "github.com/canonical/go-dqlite/v2/internal/protocol" ) // NodeStore is used by a dqlite client to get an initial list of candidate // dqlite nodes that it can dial in order to find a leader dqlite node to use. type NodeStore = protocol.NodeStore // NodeRole identifies the role of a node. type NodeRole = protocol.NodeRole // NodeInfo holds information about a single server. type NodeInfo = protocol.NodeInfo // InmemNodeStore keeps the list of target dqlite nodes in memory. type InmemNodeStore = protocol.InmemNodeStore // NewInmemNodeStore creates NodeStore which stores its data in-memory. 
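//
// A sketch of seeding it with a placeholder address, as the tests above do:
//
//	store := client.NewInmemNodeStore()
//	store.Set(context.Background(), []client.NodeInfo{
//		{Address: "127.0.0.1:9001"},
//	})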
var NewInmemNodeStore = protocol.NewInmemNodeStore // YamlNodeStore persists a list of addresses of dqlite nodes in a YAML file. type YamlNodeStore struct { path string servers []NodeInfo mu sync.RWMutex } // NewYamlNodeStore creates a new YamlNodeStore backed by the given YAML file. func NewYamlNodeStore(path string) (*YamlNodeStore, error) { servers := []NodeInfo{} _, err := os.Stat(path) if err != nil { if !os.IsNotExist(err) { return nil, err } } else { data, err := ioutil.ReadFile(path) if err != nil { return nil, err } if err := yaml.Unmarshal(data, &servers); err != nil { return nil, err } } store := &YamlNodeStore{ path: path, servers: servers, } return store, nil } // Get the current servers. func (s *YamlNodeStore) Get(ctx context.Context) ([]NodeInfo, error) { s.mu.RLock() defer s.mu.RUnlock() ret := make([]NodeInfo, len(s.servers)) copy(ret, s.servers) return ret, nil } // Set the servers addresses. func (s *YamlNodeStore) Set(ctx context.Context, servers []NodeInfo) error { s.mu.Lock() defer s.mu.Unlock() data, err := yaml.Marshal(servers) if err != nil { return err } if err := renameio.WriteFile(s.path, data, 0600); err != nil { return err } s.servers = servers return nil } golang-github-canonical-go-dqlite-2.0.0/client/store_test.go000066400000000000000000000044341471100661000240510ustar00rootroot00000000000000// +build !nosqlite3 package client_test import ( "context" "database/sql" "testing" dqlite "github.com/canonical/go-dqlite/v2" "github.com/canonical/go-dqlite/v2/client" "github.com/canonical/go-dqlite/v2/driver" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) // Exercise setting and getting servers in a DatabaseNodeStore created with // DefaultNodeStore. func TestDefaultNodeStore(t *testing.T) { // Create a new default store. store, err := client.DefaultNodeStore(":memory:") require.NoError(t, err) // Set and get some targets. err = store.Set(context.Background(), []client.NodeInfo{ {Address: "1.2.3.4:666"}, {Address: "5.6.7.8:666"}}, ) require.NoError(t, err) servers, err := store.Get(context.Background()) require.NoError(t, err) assert.Equal(t, []client.NodeInfo{ {ID: uint64(1), Address: "1.2.3.4:666"}, {ID: uint64(1), Address: "5.6.7.8:666"}}, servers) // Set and get some new targets. err = store.Set(context.Background(), []client.NodeInfo{ {Address: "1.2.3.4:666"}, {Address: "9.9.9.9:666"}, }) require.NoError(t, err) servers, err = store.Get(context.Background()) require.NoError(t, err) assert.Equal(t, []client.NodeInfo{ {ID: uint64(1), Address: "1.2.3.4:666"}, {ID: uint64(1), Address: "9.9.9.9:666"}}, servers) // Setting duplicate targets returns an error and the change is not // persisted. 
err = store.Set(context.Background(), []client.NodeInfo{ {Address: "1.2.3.4:666"}, {Address: "1.2.3.4:666"}, }) assert.EqualError(t, err, "failed to insert server 1.2.3.4:666: UNIQUE constraint failed: servers.address") servers, err = store.Get(context.Background()) require.NoError(t, err) assert.Equal(t, []client.NodeInfo{ {ID: uint64(1), Address: "1.2.3.4:666"}, {ID: uint64(1), Address: "9.9.9.9:666"}}, servers) } func TestConfigMultiThread(t *testing.T) { cleanup := dummyDBSetup(t) defer cleanup() err := dqlite.ConfigMultiThread() assert.EqualError(t, err, "SQLite is already initialized") } func dummyDBSetup(t *testing.T) func() { store := client.NewInmemNodeStore() driver, err := driver.New(store) require.NoError(t, err) sql.Register("dummy", driver) db, err := sql.Open("dummy", "test.db") require.NoError(t, err) cleanup := func() { require.NoError(t, db.Close()) } return cleanup } golang-github-canonical-go-dqlite-2.0.0/cmd/000077500000000000000000000000001471100661000206075ustar00rootroot00000000000000golang-github-canonical-go-dqlite-2.0.0/cmd/dqlite-benchmark/000077500000000000000000000000001471100661000240215ustar00rootroot00000000000000golang-github-canonical-go-dqlite-2.0.0/cmd/dqlite-benchmark/dqlite-benchmark.go000066400000000000000000000124071471100661000275660ustar00rootroot00000000000000package main import ( "context" "fmt" "os" "os/signal" "path/filepath" "time" "github.com/canonical/go-dqlite/v2/app" "github.com/canonical/go-dqlite/v2/benchmark" "github.com/pkg/errors" "github.com/spf13/cobra" "golang.org/x/sys/unix" ) const ( defaultClusterTimeout = 120 defaultDir = "/tmp/dqlite-benchmark" defaultDiskMode = false defaultDriver = false defaultDurationS = 60 defaultKvKeySize = 32 defaultKvValueSize = 1024 defaultWorkers = 1 defaultWorkload = "kvwrite" docString = "For benchmarking dqlite.\n\n" + "Run a 1 node benchmark:\n" + "dqlite-benchmark -d 127.0.0.1:9001 --driver --cluster 127.0.0.1:9001\n\n" + "Run a multi-node benchmark, the first node will self-elect and become leader,\n" + "the driver flag results in the workload being run from the first, leader node.\n" + "dqlite-benchmark --db 127.0.0.1:9001 --driver --cluster 127.0.0.1:9001,127.0.0.1:9002,127.0.0.1:9003 &\n" + "dqlite-benchmark --db 127.0.0.1:9002 --join 127.0.0.1:9001 &\n" + "dqlite-benchmark --db 127.0.0.1:9003 --join 127.0.0.1:9001 &\n\n" + "Run a multi-node benchmark, the first node will self-elect and become leader,\n" + "the driver flag results in the workload being run from the third, non-leader node.\n" + "dqlite-benchmark --db 127.0.0.1:9001 &\n" + "dqlite-benchmark --db 127.0.0.1:9002 --join 127.0.0.1:9001 &\n" + "dqlite-benchmark --db 127.0.0.1:9003 --join 127.0.0.1:9001 --driver --cluster 127.0.0.1:9001,127.0.0.1:9002,127.0.0.1:9003 &\n\n" + "The results can be found on the `driver` node in " + defaultDir + "/results or in the directory provided to the tool.\n" + "Benchmark results are files named `n-q-timestamp` where `n` is the number of the worker,\n" + "`q` is the type of query that was tracked. 
All results in the file are in milliseconds.\n" ) func signalChannel() chan os.Signal { ch := make(chan os.Signal, 32) signal.Notify(ch, unix.SIGPWR) signal.Notify(ch, unix.SIGINT) signal.Notify(ch, unix.SIGQUIT) signal.Notify(ch, unix.SIGTERM) return ch } func main() { var cluster *[]string var clusterTimeout int var db string var dir string var driver bool var duration int var join *[]string var kvKeySize int var kvValueSize int var workers int var workload string var diskMode bool cmd := &cobra.Command{ Use: "dqlite-benchmark", Short: "For benchmarking dqlite", Long: docString, RunE: func(cmd *cobra.Command, args []string) error { dir := filepath.Join(dir, db) if err := os.MkdirAll(dir, 0755); err != nil { return errors.Wrapf(err, "can't create %s", dir) } app, err := app.New(dir, app.WithDiskMode(diskMode), app.WithAddress(db), app.WithCluster(*join)) if err != nil { return err } readyCtx, cancel := context.WithTimeout(context.Background(), time.Duration(clusterTimeout)*time.Second) defer cancel() if err := app.Ready(readyCtx); err != nil { return errors.Wrap(err, "App not ready in time") } ch := signalChannel() if !driver { fmt.Println("Benchmark client ready. Send signal to abort or when done.") <-ch return nil } if len(*cluster) == 0 { return fmt.Errorf("driver node, `--cluster` flag must be provided") } db, err := app.Open(context.Background(), "benchmark") if err != nil { return err } db.SetMaxOpenConns(500) db.SetMaxIdleConns(500) bm, err := benchmark.New( app, db, dir, benchmark.WithWorkload(workload), benchmark.WithDuration(duration), benchmark.WithWorkers(workers), benchmark.WithKvKeySize(kvKeySize), benchmark.WithKvValueSize(kvValueSize), benchmark.WithCluster(*cluster), benchmark.WithClusterTimeout(clusterTimeout), ) if err != nil { return err } if err := bm.Run(ch); err != nil { return err } db.Close() app.Close() return nil }, } flags := cmd.Flags() flags.StringVarP(&db, "db", "d", "", "Address used for internal database replication.") join = flags.StringSliceP("join", "j", nil, "Database addresses of existing nodes.") cluster = flags.StringSliceP("cluster", "c", nil, "Database addresses of all nodes taking part in the benchmark.\n"+ "The driver will wait for all nodes to be online before running the benchmark.") flags.IntVar(&clusterTimeout, "cluster-timeout", defaultClusterTimeout, "How long the benchmark should wait in seconds for the whole cluster to be online.") flags.StringVarP(&dir, "dir", "D", defaultDir, "Data directory.") flags.StringVarP(&workload, "workload", "w", defaultWorkload, "The workload to run: \"kvwrite\" or \"kvreadwrite\".") flags.BoolVar(&driver, "driver", defaultDriver, "Set this flag to run the benchmark from this instance. Must be set on 1 node.") flags.IntVar(&duration, "duration", defaultDurationS, "Run duration in seconds.") flags.IntVar(&workers, "workers", defaultWorkers, "Number of workers executing the workload.") flags.IntVar(&kvKeySize, "key-size", defaultKvKeySize, "Size of the KV keys in bytes.") flags.IntVar(&kvValueSize, "value-size", defaultKvValueSize, "Size of the KV values in bytes.") flags.BoolVar(&diskMode, "disk", defaultDiskMode, "Warning: Unstable, Experimental. 
Set this flag to enable dqlite's disk-mode.") cmd.MarkFlagRequired("db") if err := cmd.Execute(); err != nil { os.Exit(1) } } golang-github-canonical-go-dqlite-2.0.0/cmd/dqlite-demo/000077500000000000000000000000001471100661000230135ustar00rootroot00000000000000golang-github-canonical-go-dqlite-2.0.0/cmd/dqlite-demo/dqlite-demo.go000066400000000000000000000100311471100661000255410ustar00rootroot00000000000000package main import ( "context" "crypto/tls" "crypto/x509" "fmt" "io/ioutil" "log" "net" "net/http" "os" "os/signal" "path/filepath" "strings" "github.com/canonical/go-dqlite/v2/app" "github.com/canonical/go-dqlite/v2/client" "github.com/pkg/errors" "github.com/spf13/cobra" "golang.org/x/sys/unix" ) func main() { var api string var db string var join *[]string var dir string var verbose bool var diskMode bool var crt string var key string cmd := &cobra.Command{ Use: "dqlite-demo", Short: "Demo application using dqlite", Long: `This demo shows how to integrate a Go application with dqlite. Complete documentation is available at https://github.com/canonical/go-dqlite`, RunE: func(cmd *cobra.Command, args []string) error { dir := filepath.Join(dir, db) if err := os.MkdirAll(dir, 0755); err != nil { return errors.Wrapf(err, "can't create %s", dir) } logFunc := func(l client.LogLevel, format string, a ...interface{}) { if !verbose { return } log.Printf(fmt.Sprintf("%s: %s: %s\n", api, l.String(), format), a...) } options := []app.Option{ app.WithAddress(db), app.WithCluster(*join), app.WithLogFunc(logFunc), app.WithDiskMode(diskMode), } // Set TLS options if (crt != "" && key == "") || (key != "" && crt == "") { return fmt.Errorf("both TLS certificate and key must be given") } if crt != "" { cert, err := tls.LoadX509KeyPair(crt, key) if err != nil { return err } data, err := ioutil.ReadFile(crt) if err != nil { return err } pool := x509.NewCertPool() if !pool.AppendCertsFromPEM(data) { return fmt.Errorf("bad certificate") } options = append(options, app.WithTLS(app.SimpleTLSConfig(cert, pool))) } app, err := app.New(dir, options...) 
if err != nil { return err } if err := app.Ready(context.Background()); err != nil { return err } db, err := app.Open(context.Background(), "demo") if err != nil { return err } if _, err := db.Exec(schema); err != nil { return err } http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { key := strings.TrimLeft(r.URL.Path, "/") result := "" err = nil switch r.Method { case "GET": row := db.QueryRow(query, key) err = row.Scan(&result) case "PUT": result = "done" value, _ := ioutil.ReadAll(r.Body) _, err = db.Exec(update, key, string(value[:])) default: err = fmt.Errorf("unsupported method") } if err == nil { fmt.Fprint(w, result) } else { http.Error(w, err.Error(), 500) } }) listener, err := net.Listen("tcp", api) if err != nil { return err } go http.Serve(listener, nil) ch := make(chan os.Signal, 32) signal.Notify(ch, unix.SIGPWR) signal.Notify(ch, unix.SIGINT) signal.Notify(ch, unix.SIGQUIT) signal.Notify(ch, unix.SIGTERM) <-ch listener.Close() db.Close() app.Handover(context.Background()) app.Close() return nil }, } flags := cmd.Flags() flags.StringVarP(&api, "api", "a", "", "address used to expose the demo API") flags.StringVarP(&db, "db", "d", "", "address used for internal database replication") join = flags.StringSliceP("join", "j", nil, "database addresses of existing nodes") flags.StringVarP(&dir, "dir", "D", "/tmp/dqlite-demo", "data directory") flags.BoolVarP(&verbose, "verbose", "v", false, "verbose logging") flags.BoolVar(&diskMode, "disk", defaultDiskMode, "Warning: Unstable, Experimental. Set this flag to enable dqlite's disk-mode.") flags.StringVarP(&crt, "cert", "c", "", "public TLS cert") flags.StringVarP(&key, "key", "k", "", "private TLS key") cmd.MarkFlagRequired("api") cmd.MarkFlagRequired("db") if err := cmd.Execute(); err != nil { os.Exit(1) } } const ( schema = "CREATE TABLE IF NOT EXISTS model (key TEXT, value TEXT, UNIQUE(key))" query = "SELECT value FROM model WHERE key = ?" 
update = "INSERT OR REPLACE INTO model(key, value) VALUES(?, ?)" defaultDiskMode = false ) golang-github-canonical-go-dqlite-2.0.0/cmd/dqlite/000077500000000000000000000000001471100661000220715ustar00rootroot00000000000000golang-github-canonical-go-dqlite-2.0.0/cmd/dqlite/dqlite.go000066400000000000000000000067411471100661000237120ustar00rootroot00000000000000package main import ( "context" "crypto/tls" "crypto/x509" "fmt" "io" "io/ioutil" "os" "strings" "time" "github.com/peterh/liner" "github.com/spf13/cobra" "github.com/canonical/go-dqlite/v2/app" "github.com/canonical/go-dqlite/v2/client" "github.com/canonical/go-dqlite/v2/internal/shell" ) func main() { var crt string var key string var servers *[]string var format string var timeoutMsec uint cmd := &cobra.Command{ Use: "dqlite -s [command]", Short: "Standard dqlite shell", Args: cobra.RangeArgs(1, 2), RunE: func(cmd *cobra.Command, args []string) error { if len(*servers) == 0 { return fmt.Errorf("no servers provided") } var store client.NodeStore var err error first := (*servers)[0] if strings.HasPrefix(first, "file://") { if len(*servers) > 1 { return fmt.Errorf("can't mix server store and explicit list") } path := first[len("file://"):] if _, err := os.Stat(path); err != nil { return fmt.Errorf("open servers store: %w", err) } store, err = client.DefaultNodeStore(path) if err != nil { return fmt.Errorf("open servers store: %w", err) } } else { infos := make([]client.NodeInfo, len(*servers)) for i, address := range *servers { infos[i].Address = address } store = client.NewInmemNodeStore() store.Set(context.Background(), infos) } if (crt != "" && key == "") || (key != "" && crt == "") { return fmt.Errorf("both TLS certificate and key must be given") } dial := client.DefaultDialFunc if crt != "" { cert, err := tls.LoadX509KeyPair(crt, key) if err != nil { return err } data, err := ioutil.ReadFile(crt) if err != nil { return err } pool := x509.NewCertPool() if !pool.AppendCertsFromPEM(data) { return fmt.Errorf("bad certificate") } config := app.SimpleDialTLSConfig(cert, pool) dial = client.DialFuncWithTLS(dial, config) } sh, err := shell.New(args[0], store, shell.WithDialFunc(dial), shell.WithFormat(format)) if err != nil { return err } if len(args) > 1 { for _, input := range strings.Split(args[1], ";") { ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutMsec)*time.Millisecond) defer cancel() result, err := sh.Process(ctx, input) if err != nil { return err } else if result != "" { fmt.Println(result) } } return nil } line := liner.NewLiner() defer line.Close() for { input, err := line.Prompt("dqlite> ") if err != nil { if err == io.EOF { break } return err } ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutMsec)*time.Millisecond) defer cancel() result, err := sh.Process(ctx, input) if err != nil { fmt.Println("Error: ", err) } else { line.AppendHistory(input) if result != "" { fmt.Println(result) } } } return nil }, } flags := cmd.Flags() servers = flags.StringSliceP("servers", "s", nil, "comma-separated list of db servers, or file://") flags.StringVarP(&crt, "cert", "c", "", "public TLS cert") flags.StringVarP(&key, "key", "k", "", "private TLS key") flags.StringVarP(&format, "format", "f", "tabular", "output format (tabular, json)") flags.UintVar(&timeoutMsec, "timeout", 2000, "timeout of each request (msec)") cmd.MarkFlagRequired("servers") if err := cmd.Execute(); err != nil { os.Exit(1) } } 
golang-github-canonical-go-dqlite-2.0.0/config.go000066400000000000000000000032011471100661000216360ustar00rootroot00000000000000// +build !nosqlite3 package dqlite import ( "fmt" "os" "github.com/canonical/go-dqlite/v2/internal/bindings" "github.com/canonical/go-dqlite/v2/internal/protocol" "github.com/pkg/errors" ) // ConfigMultiThread sets the threading mode of SQLite to Multi-thread. // // By default go-dqlite configures SQLite to Single-thread mode, because the // dqlite engine itself is single-threaded, and enabling Multi-thread or // Serialized modes would incur a performance penalty. // // If your Go process also uses SQLite directly (e.g. using the // github.com/mattn/go-sqlite3 bindings) you might need to switch to // Multi-thread mode in order to be thread-safe. // // IMPORTANT: It's possible to successfully change SQLite's threading mode only // if no SQLite APIs have been invoked yet (e.g. no database has been opened // yet). Therefore you'll typically want to call ConfigMultiThread() very early // in your process setup. Alternatively you can set the GO_DQLITE_MULTITHREAD // environment variable to 1 at process startup, in order to prevent go-dqlite // from setting Single-thread mode at all. func ConfigMultiThread() error { if err := bindings.ConfigMultiThread(); err != nil { if err, ok := err.(protocol.Error); ok && err.Code == 21 /* SQLITE_MISUSE */ { return fmt.Errorf("SQLite is already initialized") } return errors.Wrap(err, "unknown error") } return nil } func init() { // Don't enable single thread mode by default if GO_DQLITE_MULTITHREAD // is set. if os.Getenv("GO_DQLITE_MULTITHREAD") == "1" { return } err := bindings.ConfigSingleThread() if err != nil { panic(errors.Wrap(err, "set single thread mode")) } } golang-github-canonical-go-dqlite-2.0.0/docs/000077500000000000000000000000001471100661000207745ustar00rootroot00000000000000golang-github-canonical-go-dqlite-2.0.0/docs/restore-db.md000066400000000000000000000154371471100661000233760ustar00rootroot00000000000000Note: this document is incomplete and a work in progress. # A. INFO **Always back up your database folders before performing any of the steps described below and make sure no dqlite nodes are running!** ## A.1 cluster.yaml File containing the node configuration of your installation. ### Contents ``` - Address: 127.0.0.1:9001 ID: 1 Role: 0 - Address: 127.0.0.1:9002 ID: 2 Role: 1 - Address: 127.0.0.1:9003 ID: 3 Role: 0 ``` *Address* : The `host:port` that will be used in database replication; we will refer to this as `DbAddress`. *ID* : Raft ID of the node. *Role* : - 0: Voter, takes part in quorum, replicates DB. - 1: Standby, doesn't take part in quorum, replicates DB. - 2: Backup, doesn't take part in quorum, doesn't replicate DB. ## A.2 info.yaml File containing the node-specific information. ### Contents ``` Address: 127.0.0.1:9001 ID: 1 Role: 0 ``` ## A.3 Finding the node with the most up-to-date data. 1. For every known node, make a new directory, `NodeDirectory`, and copy the data directory of the node into it. 2. Make a `cluster.yaml` conforming to the structure laid out above, with the desired node configuration, and save it at `TargetClusterYamlPath`, e.g. you can run: ``` cat <<EOF > "cluster.yaml" - Address: 127.0.0.1:9001 ID: 1 Role: 0 - Address: 127.0.0.1:9002 ID: 2 Role: 1 - Address: 127.0.0.1:9003 ID: 3 Role: 0 EOF ``` 3.
For every node, run `dqlite -s <DbAddress> <DbName> ".reconfigure <NodeDirectory> <TargetClusterYamlPath>"`. The `DbAddress` and `DbName` aren't really important, just use something syntactically correct; we are more interested in the side effects of this command on the `NodeDirectory`. The command should return `OK`. 4. Look in the `NodeDirectory` of every node; there should be at least 1 new segment file, e.g. `0000000057688811-0000000057688811`, with the start index (the number before `-`) equal to the end index (the number after `-`). This will be the most recently created segment file. Remember this index. 5. The node with the highest index from the previous step has the most up-to-date data. If there is a tie, pick one. Note: a new command that doesn't rely on the side effects of the `.reconfigure` command will be added in the future. # B. Restoring Data ## B.1 Loading existing data and existing network/node configuration in `dqlite-demo` *Use this when you have access to the machines where the database lives and want to start the database with the unaltered data of every node.* 0. Stop all database nodes & back up all the database folders. 1. Make a base directory for your data, e.g. `data`; we will refer to this as the `DataDirectory`. 2. For every node in `cluster.yaml`, create a directory with a name equal to `DbAddress` under the `DataDirectory`, unique to the node; this `host:port` will be needed later on for the `--db` argument when you start the `dqlite-demo` application, e.g. for node 1 you now have a directory `data/127.0.0.1:9001`. We will refer to this as the `NodeDirectory`. 3. For every node in `cluster.yaml`, copy all the data for that node to its `NodeDirectory`. 4. For every node in `cluster.yaml`, make sure there exists an `info.yaml` in `NodeDirectory` that contains the information as found in `cluster.yaml`. 5. For every node in `cluster.yaml`, run: `dqlite-demo --dir <DataDirectory> --api <ApiAddress> --db <DbAddress>`, where `ApiAddress` is a `host:port`, e.g. `dqlite-demo --dir data --api 127.0.0.1:8001 --db 127.0.0.1:9001`. Note that it is important that `--dir` is a path to the newly created `DataDirectory`, otherwise the demo will create a new directory without the existing data. 6. You should have an operational cluster; access it through e.g. the `dqlite` CLI tool. ## B.2 Restore existing data and new network/node configuration in `dqlite-demo`. *Use this when you don't have access to the machines where the database lives and want to start the database with data from a specific node, or when you have access to the machines but the cluster has to be reconfigured or repaired.* 0. Stop all database nodes & back up all the database folders. 1. Create a `cluster.yaml` containing your desired node configuration. We will refer to this file by `TargetClusterYaml` and to its location by `TargetClusterYamlPath`. 2. Follow steps 1 and 2 of part `B.1`, where `cluster.yaml` should be interpreted as `TargetClusterYaml`. 3. Find the node with the most up-to-date data following the steps in `A.3`, but use the directories and `cluster.yaml` created in the previous steps. 4. For every non-up-to-date node, remove the data files and metadata files from the `NodeDirectory`. 5. For every non-up-to-date node, copy the data files of the node with the most up-to-date data to the `NodeDirectory`; don't copy the metadata1 & metadata2 files over. 6. For every node, copy `TargetClusterYaml` to `NodeDirectory`, overwriting the `cluster.yaml` that's already there. 7.
For every node, make sure there is an `info.yaml` in `NodeDirectory` that is in line with `cluster.yaml` and correct for that node. 8. For every node, run: `dqlite-demo --dir <DataDirectory> --api <ApiAddress> --db <DbAddress>`. 9. You should have an operational cluster; access it through e.g. the `dqlite` CLI tool. ## Terminology - ApiAddress: `host:port` where the `dqlite-demo` REST API is available. - DataDirectory: Base directory under which the NodeDirectories are saved. - data file: segment file, snapshot file or snapshot.meta file. - DbAddress: `host:port` used for database replication. - DbName: name of the SQLite database. - metadata file: file named `metadata1` or `metadata2`. - NodeDirectory: Directory where node-specific data is saved; for `dqlite-demo` it should be named `DbAddress` and exist under `DataDirectory`. - segment file: file named like `0000000057685378-0000000057685875`, meaning `startindex-endindex`; these contain raft log entries. - snapshot file: file named like `snapshot-2818-57687002-3645852168`, meaning `snapshot-term-index-timestamp`. - snapshot.meta file: file named like `snapshot-2818-57687002-3645852168.meta`; contains metadata about the matching snapshot file. - TargetClusterYaml: `cluster.yaml` file containing the desired cluster configuration. - TargetClusterYamlPath: location of `TargetClusterYaml`. # C. Startup Errors ## C.1 raft_start(): io: closed segment 0000xxxx-0000xxxx is past last snapshot-x-xxxx-xxxxxx ### C.1.1 Method with data loss This situation can happen when, for example, you only have 1 node. 1. Back up your data folder and stop the database. 2. Remove the offending segment and try to start again. 3. Repeat step 2 if another segment is preventing you from starting. ### C.1.2 Method preventing data loss 1. Back up your data folders and stop the database. 2. TODO [Variation of the restoring data process] golang-github-canonical-go-dqlite-2.0.0/driver/000077500000000000000000000000001471100661000213375ustar00rootroot00000000000000golang-github-canonical-go-dqlite-2.0.0/driver/driver.go000066400000000000000000000642331471100661000231670ustar00rootroot00000000000000// Copyright 2017 Canonical Ltd. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package driver import ( "context" "database/sql/driver" "fmt" "io" "math" "net" "reflect" "syscall" "time" "github.com/pkg/errors" "github.com/canonical/go-dqlite/v2/client" "github.com/canonical/go-dqlite/v2/internal/protocol" "github.com/canonical/go-dqlite/v2/tracing" ) // Driver performs queries against a dqlite server. type Driver struct { log client.LogFunc // Log function to use store client.NodeStore // Holds addresses of dqlite servers context context.Context // Global cancellation context connectionTimeout time.Duration // Max time to wait for a new connection contextTimeout time.Duration // Default client context timeout.
clientConfig protocol.Config // Configuration for dqlite client instances tracing client.LogLevel // Whether to trace statements concurrentLeaderConns *int64 // Maximum number of concurrent connections to other cluster members while probing for leadership. } // Error is returned in case of database errors. type Error = protocol.Error // Error codes. Values here mostly overlap with native SQLite codes. const ( ErrBusy = 5 ErrBusyRecovery = 5 | (1 << 8) ErrBusySnapshot = 5 | (2 << 8) errIoErr = 10 ErrIoErrNotLeader = errIoErr | (40 << 8) ErrIoErrLeadershipLost = errIoErr | (41 << 8) errNotFound = 12 // Legacy error codes before version-3.32.1+replication4. Kept here // for backward compatibility, but should eventually be dropped. errIoErrNotLeaderLegacy = errIoErr | 32<<8 errIoErrLeadershipLostLegacy = errIoErr | (33 << 8) ) // Option can be used to tweak driver parameters. type Option func(*options) // NodeStore is a convenience alias of client.NodeStore. type NodeStore = client.NodeStore // NodeInfo is a convenience alias of client.NodeInfo. type NodeInfo = client.NodeInfo // DefaultNodeStore is a convenience alias of client.DefaultNodeStore. var DefaultNodeStore = client.DefaultNodeStore // WithLogFunc sets a custom logging function. func WithLogFunc(log client.LogFunc) Option { return func(options *options) { options.Log = log } } // DialFunc is a function that can be used to establish a network connection // with a dqlite node. type DialFunc = protocol.DialFunc // WithDialFunc sets a custom dial function. func WithDialFunc(dial DialFunc) Option { return func(options *options) { options.Dial = protocol.DialFunc(dial) } } // WithConnectionTimeout sets the connection timeout. // // If not used, the default is 5 seconds. // // DEPRECATED: Connection cancellation is supported via the driver.Connector // interface, which is used internally by the stdlib sql package. func WithConnectionTimeout(timeout time.Duration) Option { return func(options *options) { options.ConnectionTimeout = timeout } } // WithConnectionBackoffFactor sets the exponential backoff factor for retrying // failed connection attempts. // // If not used, the default is 100 milliseconds. func WithConnectionBackoffFactor(factor time.Duration) Option { return func(options *options) { options.ConnectionBackoffFactor = factor } } // WithConnectionBackoffCap sets the maximum connection retry backoff value // (regardless of the backoff factor) for retrying failed connection attempts. // // If not used, the default is 1 second. func WithConnectionBackoffCap(cap time.Duration) Option { return func(options *options) { options.ConnectionBackoffCap = cap } } // WithConcurrentLeaderConns sets the maximum number of concurrent connections // to other cluster members that will be attempted while searching for the dqlite leader. // It takes a pointer to an integer so that the value can be dynamically modified based on cluster health. // // The default is 10 connections to other cluster members. func WithConcurrentLeaderConns(maxConns *int64) Option { return func(o *options) { o.ConcurrentLeaderConns = maxConns } } // WithAttemptTimeout sets the timeout for each individual connection attempt. // // The Connector.Connect() and Driver.Open() methods try to find the current // leader among the servers in the store that was passed to New(). Each time // they attempt to probe an individual server for leadership this timeout will // apply, so a server which accepts the connection but is then unresponsive // won't block the line.
// // If not used, the default is 15 seconds. func WithAttemptTimeout(timeout time.Duration) Option { return func(options *options) { options.AttemptTimeout = timeout } } // WithRetryLimit sets the maximum number of connection retries. // // If not used, the default is 0 (unlimited retries). func WithRetryLimit(limit uint) Option { return func(options *options) { options.RetryLimit = limit } } // WithContext sets a global cancellation context. // // DEPRECATED: This API is now a no-op. Users should explicitly pass a context // if they wish to cancel their requests. func WithContext(context context.Context) Option { return func(options *options) { options.Context = context } } // WithContextTimeout sets the default client context timeout for DB.Begin() // when no context deadline is provided. // // DEPRECATED: Users should use db APIs that support contexts if they wish to // cancel their requests. func WithContextTimeout(timeout time.Duration) Option { return func(options *options) { options.ContextTimeout = timeout } } // WithTracing will emit a log message at the given level every time a // statement gets executed. func WithTracing(level client.LogLevel) Option { return func(options *options) { options.Tracing = level } } // New creates a new dqlite driver, which also implements the // driver.Driver interface. func New(store client.NodeStore, options ...Option) (*Driver, error) { o := defaultOptions() for _, option := range options { option(o) } driver := &Driver{ log: o.Log, store: store, context: o.Context, connectionTimeout: o.ConnectionTimeout, contextTimeout: o.ContextTimeout, tracing: o.Tracing, concurrentLeaderConns: o.ConcurrentLeaderConns, clientConfig: protocol.Config{ Dial: o.Dial, AttemptTimeout: o.AttemptTimeout, BackoffFactor: o.ConnectionBackoffFactor, BackoffCap: o.ConnectionBackoffCap, RetryLimit: o.RetryLimit, }, } return driver, nil } // Holds configuration options for a dqlite driver. type options struct { Log client.LogFunc Dial protocol.DialFunc AttemptTimeout time.Duration ConnectionTimeout time.Duration ContextTimeout time.Duration ConnectionBackoffFactor time.Duration ConnectionBackoffCap time.Duration ConcurrentLeaderConns *int64 RetryLimit uint Context context.Context Tracing client.LogLevel } // Create an options object with sane defaults. func defaultOptions() *options { maxConns := protocol.MaxConcurrentLeaderConns return &options{ Log: client.DefaultLogFunc, Dial: client.DefaultDialFunc, Tracing: client.LogNone, ConcurrentLeaderConns: &maxConns, } } // A Connector represents a driver in a fixed configuration and can create any // number of equivalent Conns for use by multiple goroutines. type Connector struct { uri string driver *Driver protocol *protocol.Connector } // Connect returns a connection to the database.
func (c *Connector) Connect(ctx context.Context) (driver.Conn, error) { if c.driver.context != nil { ctx = c.driver.context } if c.driver.connectionTimeout != 0 { var cancel func() ctx, cancel = context.WithTimeout(ctx, c.driver.connectionTimeout) defer cancel() } conn := &Conn{ log: c.driver.log, contextTimeout: c.driver.contextTimeout, tracing: c.driver.tracing, } proto, err := c.protocol.Connect(ctx) if err != nil { return nil, driverError(conn.log, errors.Wrap(err, "failed to create dqlite connection")) } conn.protocol = proto conn.request.Init(4096) conn.response.Init(4096) protocol.EncodeOpen(&conn.request, c.uri, 0, "volatile") if err := conn.protocol.Call(ctx, &conn.request, &conn.response); err != nil { conn.protocol.Close() return nil, driverError(conn.log, errors.Wrap(err, "failed to open database")) } conn.id, err = protocol.DecodeDb(&conn.response) if err != nil { conn.protocol.Close() return nil, driverError(conn.log, errors.Wrap(err, "failed to open database")) } return conn, nil } // Driver returns the underlying Driver of the Connector. func (c *Connector) Driver() driver.Driver { return c.driver } // OpenConnector creates a reusable Connector for a specific database. func (d *Driver) OpenConnector(name string) (driver.Connector, error) { config := d.clientConfig config.ConcurrentLeaderConns = *d.concurrentLeaderConns pc := protocol.NewLeaderConnector(d.store, config, d.log) connector := &Connector{ uri: name, driver: d, protocol: pc, } return connector, nil } // Open establishes a new connection to a SQLite database on the dqlite server. // // The given name must be a pure file name without any directory segment; // dqlite will connect to a database with that name in its data directory. // // Query parameters are always valid except for "mode=memory". // // If this node is not the leader, or the leader is unknown, an ErrNotLeader // error is returned. func (d *Driver) Open(uri string) (driver.Conn, error) { connector, err := d.OpenConnector(uri) if err != nil { return nil, err } return connector.Connect(context.Background()) } // SetContextTimeout sets the default client timeout when no context deadline // is provided. // // DEPRECATED: This API is now a no-op. Users should explicitly pass a context // if they wish to cancel their requests, or use the WithContextTimeout option. func (d *Driver) SetContextTimeout(timeout time.Duration) {} // ErrNoAvailableLeader is returned as root cause of Open() if there's no // leader available in the cluster. var ErrNoAvailableLeader = protocol.ErrNoAvailableLeader // Conn implements the sql.Conn interface. type Conn struct { log client.LogFunc protocol *protocol.Protocol request protocol.Message response protocol.Message id uint32 // Database ID. contextTimeout time.Duration tracing client.LogLevel } // PrepareContext returns a prepared statement, bound to this connection. The // given context is used only for the preparation of the statement; it is not // stored within the statement itself.
func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { ctx, span := tracing.Start(ctx, "dqlite.driver.PrepareContext", query) defer span.End() stmt := &Stmt{ protocol: c.protocol, request: &c.request, response: &c.response, log: c.log, tracing: c.tracing, } protocol.EncodePrepare(&c.request, uint64(c.id), query) var start time.Time if c.tracing != client.LogNone { start = time.Now() } err := c.protocol.Call(ctx, &c.request, &c.response) if c.tracing != client.LogNone { c.log(c.tracing, "%.3fs request prepared: %q", time.Since(start).Seconds(), query) } if err != nil { return nil, driverError(c.log, err) } stmt.db, stmt.id, stmt.params, err = protocol.DecodeStmt(&c.response) if err != nil { return nil, driverError(c.log, err) } if c.tracing != client.LogNone { stmt.sql = query } return stmt, nil } // Prepare returns a prepared statement, bound to this connection. func (c *Conn) Prepare(query string) (driver.Stmt, error) { return c.PrepareContext(context.Background(), query) } // ExecContext is an optional interface that may be implemented by a Conn. func (c *Conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { ctx, span := tracing.Start(ctx, "dqlite.driver.ExecContext", query) defer span.End() if int64(len(args)) > math.MaxUint32 { return nil, driverError(c.log, fmt.Errorf("too many parameters (%d)", len(args))) } else if len(args) > math.MaxUint8 { protocol.EncodeExecSQLV1(&c.request, uint64(c.id), query, args) } else { protocol.EncodeExecSQLV0(&c.request, uint64(c.id), query, args) } var start time.Time if c.tracing != client.LogNone { start = time.Now() } err := c.protocol.Call(ctx, &c.request, &c.response) if c.tracing != client.LogNone { c.log(c.tracing, "%.3fs request exec: %q", time.Since(start).Seconds(), query) } if err != nil { return nil, driverError(c.log, err) } var result protocol.Result result, err = protocol.DecodeResult(&c.response) if err != nil { return nil, driverError(c.log, err) } return &Result{result: result}, nil } // Query is an optional interface that may be implemented by a Conn. func (c *Conn) Query(query string, args []driver.Value) (driver.Rows, error) { return c.QueryContext(context.Background(), query, valuesToNamedValues(args)) } // QueryContext is an optional interface that may be implemented by a Conn. func (c *Conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { ctx, span := tracing.Start(ctx, "dqlite.driver.QueryContext", query) defer span.End() if int64(len(args)) > math.MaxUint32 { return nil, driverError(c.log, fmt.Errorf("too many parameters (%d)", len(args))) } else if len(args) > math.MaxUint8 { protocol.EncodeQuerySQLV1(&c.request, uint64(c.id), query, args) } else { protocol.EncodeQuerySQLV0(&c.request, uint64(c.id), query, args) } var start time.Time if c.tracing != client.LogNone { start = time.Now() } err := c.protocol.Call(ctx, &c.request, &c.response) if c.tracing != client.LogNone { c.log(c.tracing, "%.3fs request query: %q", time.Since(start).Seconds(), query) } if err != nil { return nil, driverError(c.log, err) } var rows protocol.Rows rows, err = protocol.DecodeRows(&c.response) if err != nil { return nil, driverError(c.log, err) } return &Rows{ ctx: ctx, request: &c.request, response: &c.response, protocol: c.protocol, rows: rows, log: c.log, }, nil } // Exec is an optional interface that may be implemented by a Conn. 
func (c *Conn) Exec(query string, args []driver.Value) (driver.Result, error) { return c.ExecContext(context.Background(), query, valuesToNamedValues(args)) } // Close invalidates and potentially stops any current prepared statements and // transactions, marking this connection as no longer in use. // // Because the sql package maintains a free pool of connections and only calls // Close when there's a surplus of idle connections, it shouldn't be necessary // for drivers to do their own connection caching. func (c *Conn) Close() error { return c.protocol.Close() } // BeginTx starts and returns a new transaction. If the context is canceled by // the user the sql package will call Tx.Rollback before discarding and closing // the connection. // // This must check opts.Isolation to determine if there is a set isolation // level. If the driver does not support a non-default level and one is set or // if there is a non-default isolation level that is not supported, an error // must be returned. // // This must also check opts.ReadOnly to determine if the read-only value is // true to either set the read-only transaction property if supported or return // an error if it is not supported. func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { if _, err := c.ExecContext(ctx, "BEGIN", nil); err != nil { return nil, err } tx := &Tx{ conn: c, log: c.log, } return tx, nil } // Begin starts and returns a new transaction. // // Deprecated: Drivers should implement ConnBeginTx instead (or additionally). func (c *Conn) Begin() (driver.Tx, error) { ctx := context.Background() if c.contextTimeout > 0 { var cancel func() ctx, cancel = context.WithTimeout(context.Background(), c.contextTimeout) defer cancel() } return c.BeginTx(ctx, driver.TxOptions{}) } // Tx is a transaction. type Tx struct { conn *Conn log client.LogFunc } // Commit the transaction. func (tx *Tx) Commit() error { ctx := context.Background() if _, err := tx.conn.ExecContext(ctx, "COMMIT", nil); err != nil { return driverError(tx.log, err) } return nil } // Rollback the transaction. func (tx *Tx) Rollback() error { ctx := context.Background() if _, err := tx.conn.ExecContext(ctx, "ROLLBACK", nil); err != nil { return driverError(tx.log, err) } return nil } // Stmt is a prepared statement. It is bound to a Conn and not // used by multiple goroutines concurrently. type Stmt struct { protocol *protocol.Protocol request *protocol.Message response *protocol.Message db uint32 id uint32 params uint64 log client.LogFunc sql string // Prepared SQL, only set when tracing tracing client.LogLevel } // Close closes the statement. func (s *Stmt) Close() error { protocol.EncodeFinalize(s.request, s.db, s.id) ctx := context.Background() if err := s.protocol.Call(ctx, s.request, s.response); err != nil { return driverError(s.log, err) } if err := protocol.DecodeEmpty(s.response); err != nil { return driverError(s.log, err) } return nil } // NumInput returns the number of placeholder parameters. func (s *Stmt) NumInput() int { return int(s.params) } // ExecContext executes a query that doesn't return rows, such // as an INSERT or UPDATE. // // ExecContext must honor the context timeout and return when it is canceled. 
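// // Note: as the implementation below shows, requests with more than 255 parameters are encoded with the V1 wire format (which carries a 32-bit parameter count), while shorter ones use the V0 format.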
func (s *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { ctx, span := tracing.Start(ctx, "dqlite.driver.Stmt.ExecContext", s.sql) defer span.End() if int64(len(args)) > math.MaxUint32 { return nil, driverError(s.log, fmt.Errorf("too many parameters (%d)", len(args))) } else if len(args) > math.MaxUint8 { protocol.EncodeExecV1(s.request, s.db, s.id, args) } else { protocol.EncodeExecV0(s.request, s.db, s.id, args) } var start time.Time if s.tracing != client.LogNone { start = time.Now() } err := s.protocol.Call(ctx, s.request, s.response) if s.tracing != client.LogNone { s.log(s.tracing, "%.3fs request exec: %q", time.Since(start).Seconds(), s.sql) } if err != nil { return nil, driverError(s.log, err) } var result protocol.Result result, err = protocol.DecodeResult(s.response) if err != nil { return nil, driverError(s.log, err) } return &Result{result: result}, nil } // Exec executes a query that doesn't return rows, such as an INSERT or UPDATE. func (s *Stmt) Exec(args []driver.Value) (driver.Result, error) { return s.ExecContext(context.Background(), valuesToNamedValues(args)) } // QueryContext executes a query that may return rows, such as a // SELECT. // // QueryContext must honor the context timeout and return when it is canceled. func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { ctx, span := tracing.Start(ctx, "dqlite.driver.Stmt.QueryContext", s.sql) defer span.End() if int64(len(args)) > math.MaxUint32 { return nil, driverError(s.log, fmt.Errorf("too many parameters (%d)", len(args))) } else if len(args) > math.MaxUint8 { protocol.EncodeQueryV1(s.request, s.db, s.id, args) } else { protocol.EncodeQueryV0(s.request, s.db, s.id, args) } var start time.Time if s.tracing != client.LogNone { start = time.Now() } err := s.protocol.Call(ctx, s.request, s.response) if s.tracing != client.LogNone { s.log(s.tracing, "%.3fs request query: %q", time.Since(start).Seconds(), s.sql) } if err != nil { return nil, driverError(s.log, err) } var rows protocol.Rows rows, err = protocol.DecodeRows(s.response) if err != nil { return nil, driverError(s.log, err) } return &Rows{ ctx: ctx, request: s.request, response: s.response, protocol: s.protocol, rows: rows, log: s.log, }, nil } // Query executes a query that may return rows, such as a SELECT. func (s *Stmt) Query(args []driver.Value) (driver.Rows, error) { return s.QueryContext(context.Background(), valuesToNamedValues(args)) } // Result is the result of a query execution. type Result struct { result protocol.Result } // LastInsertId returns the database's auto-generated ID // after, for example, an INSERT into a table with primary // key. func (r *Result) LastInsertId() (int64, error) { return int64(r.result.LastInsertID), nil } // RowsAffected returns the number of rows affected by the // query. func (r *Result) RowsAffected() (int64, error) { return int64(r.result.RowsAffected), nil } // Rows is an iterator over an executed query's results. type Rows struct { ctx context.Context protocol *protocol.Protocol request *protocol.Message response *protocol.Message rows protocol.Rows consumed bool types []string log client.LogFunc } // Columns returns the names of the columns. The number of // columns of the result is inferred from the length of the // slice. If a particular column name isn't known, an empty // string should be returned for that entry. func (r *Rows) Columns() []string { return r.rows.Columns } // Close closes the rows iterator.
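// If the result set has not been fully consumed, Close issues an Interrupt request so the server stops streaming the remaining rows (see the body below).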
func (r *Rows) Close() error { err := r.rows.Close() // If we consumed the whole result set, there's nothing to do as // there's no pending response from the server. if r.consumed { return nil } // If this was a single-response result set, we're done. if err == io.EOF { return nil } // Let's issue an interrupt request and wait until we get an empty // response, signalling that the query was interrupted. if err := r.protocol.Interrupt(r.ctx, r.request, r.response); err != nil { return driverError(r.log, err) } return nil } // Next is called to populate the next row of data into // the provided slice. The provided slice will be the same // size as the Columns() are wide. // // Next should return io.EOF when there are no more rows. func (r *Rows) Next(dest []driver.Value) error { err := r.rows.Next(dest) if err == protocol.ErrRowsPart { r.rows.Close() if err := r.protocol.More(r.ctx, r.response); err != nil { return driverError(r.log, err) } rows, err := protocol.DecodeRows(r.response) if err != nil { return driverError(r.log, err) } r.rows = rows return r.rows.Next(dest) } if err == io.EOF { r.consumed = true } return err } // ColumnTypeScanType implements RowsColumnTypeScanType. func (r *Rows) ColumnTypeScanType(i int) reflect.Type { // column := sql.NewColumn(r.rows, i) // typ, err := r.protocol.ColumnTypeScanType(context.Background(), column) // if err != nil { // return nil // } // return typ.DriverType() return nil } // ColumnTypeDatabaseTypeName implements RowsColumnTypeDatabaseTypeName. // warning: not thread safe func (r *Rows) ColumnTypeDatabaseTypeName(i int) string { if r.types == nil { var err error r.types, err = r.rows.ColumnTypes() // an error might not matter if we get our types if err != nil && i >= len(r.types) { // a panic here doesn't really help, // as an empty column type is not the end of the world // but we should still inform the user of the failure const msg = "row (%p) error returning column #%d type: %v\n" r.log(client.LogWarn, msg, r, i, err) return "" } } return r.types[i] } // Convert a driver.Value slice into a driver.NamedValue slice. func valuesToNamedValues(args []driver.Value) []driver.NamedValue { namedValues := make([]driver.NamedValue, len(args)) for i, value := range args { namedValues[i] = driver.NamedValue{ Ordinal: i + 1, Value: value, } } return namedValues } type unwrappable interface { Unwrap() error } // TODO driver.ErrBadConn should not be returned when there's a possibility that // the query has been executed. In our case there is a window in protocol.Call // between `send` and `recv` where the send has succeeded but the recv has // failed. In those cases we call driverError on the result of protocol.Call, // possibly returning ErrBadConn.
// https://cs.opensource.google/go/go/+/refs/tags/go1.20.4:src/database/sql/driver/driver.go;drc=a32a592c8c14927c20ac42808e1fb2e55b2e9470;l=162 func driverError(log client.LogFunc, err error) error { switch err := errors.Cause(err).(type) { case syscall.Errno: log(client.LogDebug, "network connection lost: %v", err) return driver.ErrBadConn case *net.OpError: log(client.LogDebug, "network connection lost: %v", err) return driver.ErrBadConn case protocol.ErrRequest: switch err.Code { case errIoErrNotLeaderLegacy: fallthrough case errIoErrLeadershipLostLegacy: fallthrough case ErrIoErrNotLeader: fallthrough case ErrIoErrLeadershipLost: log(client.LogDebug, "leadership lost (%d - %s)", err.Code, err.Description) return driver.ErrBadConn case errNotFound: log(client.LogDebug, "not found - potentially after leadership loss (%d - %s)", err.Code, err.Description) return driver.ErrBadConn default: // FIXME: the server side sometimes returns SQLITE_OK // even in case of errors. This issue is still being // investigated, but for now let's just mark this // connection as bad so the client will retry. if err.Code == 0 { log(client.LogWarn, "unexpected error code (%d - %s)", err.Code, err.Description) return driver.ErrBadConn } return Error{ Code: int(err.Code), Message: err.Description, } } default: // When using a TLS connection, the underlying error might get // wrapped by the stdlib itself with the new errors wrapping // conventions available since go 1.13. In that case we check // the underlying error with Unwrap() instead of Cause(). if root, ok := err.(unwrappable); ok { err = root.Unwrap() } switch err.(type) { case *net.OpError: log(client.LogDebug, "network connection lost: %v", err) return driver.ErrBadConn } } if errors.Is(err, io.EOF) { log(client.LogDebug, "EOF detected: %v", err) return driver.ErrBadConn } return err } golang-github-canonical-go-dqlite-2.0.0/driver/driver_test.go000066400000000000000000000467721471100661000242360ustar00rootroot00000000000000// Copyright 2017 Canonical Ltd. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License.
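A minimal sketch of wiring this driver into database/sql, following the same driver.New + sql.Register pattern the tests below use; the registered name "dqlite-example", the server address, and the database name are illustrative assumptions, not fixed by the package:
```
package main

import (
	"context"
	"database/sql"
	"log"

	"github.com/canonical/go-dqlite/v2/client"
	"github.com/canonical/go-dqlite/v2/driver"
)

func main() {
	// Point a node store at one or more dqlite servers (illustrative address).
	store := client.NewInmemNodeStore()
	if err := store.Set(context.Background(), []client.NodeInfo{{Address: "127.0.0.1:9001"}}); err != nil {
		log.Fatal(err)
	}

	// Create the dqlite driver and register it with database/sql.
	drv, err := driver.New(store)
	if err != nil {
		log.Fatal(err)
	}
	sql.Register("dqlite-example", drv)

	// Open a database by pure file name; the leader resolves it in its
	// data directory (see Driver.Open above).
	db, err := sql.Open("dqlite-example", "app.db")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	if _, err := db.Exec("CREATE TABLE IF NOT EXISTS t (n INT)"); err != nil {
		log.Fatal(err)
	}
}
```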
package driver_test import ( "context" "database/sql/driver" "io" "io/ioutil" "os" "strings" "testing" "time" dqlite "github.com/canonical/go-dqlite/v2" "github.com/canonical/go-dqlite/v2/client" dqlitedriver "github.com/canonical/go-dqlite/v2/driver" "github.com/canonical/go-dqlite/v2/logging" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestDriver_Open(t *testing.T) { driver, cleanup := newDriver(t) defer cleanup() conn, err := driver.Open("test.db") require.NoError(t, err) assert.NoError(t, conn.Close()) } func TestDriver_Prepare(t *testing.T) { driver, cleanup := newDriver(t) defer cleanup() conn, err := driver.Open("test.db") require.NoError(t, err) stmt, err := conn.Prepare("CREATE TABLE test (n INT)") require.NoError(t, err) assert.Equal(t, 0, stmt.NumInput()) assert.NoError(t, conn.Close()) } func TestConn_Exec(t *testing.T) { drv, cleanup := newDriver(t) defer cleanup() conn, err := drv.Open("test.db") require.NoError(t, err) _, err = conn.(driver.ConnBeginTx).BeginTx(context.Background(), driver.TxOptions{}) require.NoError(t, err) execer := conn.(driver.ExecerContext) _, err = execer.ExecContext(context.Background(), "CREATE TABLE test (n INT)", nil) require.NoError(t, err) result, err := execer.ExecContext(context.Background(), "INSERT INTO test(n) VALUES(1)", nil) require.NoError(t, err) lastInsertID, err := result.LastInsertId() require.NoError(t, err) assert.Equal(t, lastInsertID, int64(1)) rowsAffected, err := result.RowsAffected() require.NoError(t, err) assert.Equal(t, rowsAffected, int64(1)) assert.NoError(t, conn.Close()) } func TestConn_Query(t *testing.T) { drv, cleanup := newDriver(t) defer cleanup() conn, err := drv.Open("test.db") require.NoError(t, err) _, err = conn.(driver.ConnBeginTx).BeginTx(context.Background(), driver.TxOptions{}) require.NoError(t, err) execer := conn.(driver.ExecerContext) _, err = execer.ExecContext(context.Background(), "CREATE TABLE test (n INT)", nil) require.NoError(t, err) _, err = execer.ExecContext(context.Background(), "INSERT INTO test(n) VALUES(1)", nil) require.NoError(t, err) queryer := conn.(driver.QueryerContext) _, err = queryer.QueryContext(context.Background(), "SELECT n FROM test", nil) require.NoError(t, err) assert.NoError(t, conn.Close()) } func TestConn_QueryRow(t *testing.T) { drv, cleanup := newDriver(t) defer cleanup() conn, err := drv.Open("test.db") require.NoError(t, err) _, err = conn.(driver.ConnBeginTx).BeginTx(context.Background(), driver.TxOptions{}) require.NoError(t, err) execer := conn.(driver.ExecerContext) _, err = execer.ExecContext(context.Background(), "CREATE TABLE test (n INT)", nil) require.NoError(t, err) _, err = execer.ExecContext(context.Background(), "INSERT INTO test(n) VALUES(1)", nil) require.NoError(t, err) _, err = execer.ExecContext(context.Background(), "INSERT INTO test(n) VALUES(1)", nil) require.NoError(t, err) queryer := conn.(driver.QueryerContext) rows, err := queryer.QueryContext(context.Background(), "SELECT n FROM test", nil) require.NoError(t, err) values := make([]driver.Value, 1) require.NoError(t, rows.Next(values)) require.NoError(t, rows.Close()) assert.NoError(t, conn.Close()) } func TestConn_InterruptQuery(t *testing.T) { drv, cleanup := newDriver(t) defer cleanup() conn, err := drv.Open("test.db") require.NoError(t, err) _, err = conn.(driver.ConnBeginTx).BeginTx(context.Background(), driver.TxOptions{}) require.NoError(t, err) execer := conn.(driver.ExecerContext) _, err = execer.ExecContext(context.Background(), "CREATE TABLE 
test (n INT)", nil) require.NoError(t, err) // When querying all these rows, dqlite will fill up more than 1 // response buffer allowing us to interrupt an active query. for i := 0; i < 4098; i++ { _, err = execer.ExecContext(context.Background(), "INSERT INTO test(n) VALUES(1)", nil) require.NoError(t, err) } queryer := conn.(driver.QueryerContext) rows, err := queryer.QueryContext(context.Background(), "SELECT * FROM test", nil) require.NoError(t, err) // rows.Close() will trigger an Interrupt. require.NoError(t, rows.Close()) assert.NoError(t, conn.Close()) } func TestConn_QueryBlob(t *testing.T) { drv, cleanup := newDriver(t) defer cleanup() conn, err := drv.Open("test.db") require.NoError(t, err) _, err = conn.(driver.ConnBeginTx).BeginTx(context.Background(), driver.TxOptions{}) require.NoError(t, err) execer := conn.(driver.ExecerContext) _, err = execer.ExecContext(context.Background(), "CREATE TABLE test (data BLOB)", nil) require.NoError(t, err) values := []driver.NamedValue{ {Ordinal: 1, Value: []byte{'a', 'b', 'c'}}, } _, err = execer.ExecContext(context.Background(), "INSERT INTO test(data) VALUES(?)", values) require.NoError(t, err) queryer := conn.(driver.QueryerContext) rows, err := queryer.QueryContext(context.Background(), "SELECT data FROM test", nil) require.NoError(t, err) assert.Equal(t, rows.Columns(), []string{"data"}) rowValues := make([]driver.Value, 1) require.NoError(t, rows.Next(rowValues)) assert.Equal(t, []byte{'a', 'b', 'c'}, rowValues[0]) assert.NoError(t, conn.Close()) } func TestStmt_Exec(t *testing.T) { drv, cleanup := newDriver(t) defer cleanup() conn, err := drv.Open("test.db") require.NoError(t, err) stmt, err := conn.Prepare("CREATE TABLE test (n INT)") require.NoError(t, err) _, err = conn.(driver.ConnBeginTx).BeginTx(context.Background(), driver.TxOptions{}) require.NoError(t, err) _, err = stmt.(driver.StmtExecContext).ExecContext(context.Background(), nil) require.NoError(t, err) require.NoError(t, stmt.Close()) values := []driver.NamedValue{ {Ordinal: 1, Value: int64(1)}, } stmt, err = conn.Prepare("INSERT INTO test(n) VALUES(?)") require.NoError(t, err) result, err := stmt.(driver.StmtExecContext).ExecContext(context.Background(), values) require.NoError(t, err) lastInsertID, err := result.LastInsertId() require.NoError(t, err) assert.Equal(t, lastInsertID, int64(1)) rowsAffected, err := result.RowsAffected() require.NoError(t, err) assert.Equal(t, rowsAffected, int64(1)) require.NoError(t, stmt.Close()) assert.NoError(t, conn.Close()) } func TestStmt_ExecManyParams(t *testing.T) { drv, cleanup := newDriver(t) defer cleanup() conn, err := drv.Open("test.db") require.NoError(t, err) stmt, err := conn.Prepare("CREATE TABLE test (n INT)") require.NoError(t, err) _, err = conn.(driver.ConnBeginTx).BeginTx(context.Background(), driver.TxOptions{}) require.NoError(t, err) _, err = stmt.(driver.StmtExecContext).ExecContext(context.Background(), nil) require.NoError(t, err) require.NoError(t, stmt.Close()) stmt, err = conn.Prepare("INSERT INTO test(n) VALUES " + strings.Repeat("(?), ", 299) + " (?)") require.NoError(t, err) values := make([]driver.NamedValue, 300) for i := range values { values[i] = driver.NamedValue{Ordinal: i + 1, Value: int64(1)} } _, err = stmt.(driver.StmtExecContext).ExecContext(context.Background(), values) require.NoError(t, err) require.NoError(t, stmt.Close()) assert.NoError(t, conn.Close()) } func TestStmt_Query(t *testing.T) { drv, cleanup := newDriver(t) defer cleanup() conn, err := drv.Open("test.db") require.NoError(t, 
err) stmt, err := conn.Prepare("CREATE TABLE test (n INT)") require.NoError(t, err) _, err = conn.(driver.ConnBeginTx).BeginTx(context.Background(), driver.TxOptions{}) require.NoError(t, err) _, err = stmt.(driver.StmtExecContext).ExecContext(context.Background(), nil) require.NoError(t, err) require.NoError(t, stmt.Close()) stmt, err = conn.Prepare("INSERT INTO test(n) VALUES(-123)") require.NoError(t, err) _, err = stmt.(driver.StmtExecContext).ExecContext(context.Background(), nil) require.NoError(t, err) require.NoError(t, stmt.Close()) stmt, err = conn.Prepare("SELECT n FROM test") require.NoError(t, err) rows, err := stmt.(driver.StmtQueryContext).QueryContext(context.Background(), nil) require.NoError(t, err) assert.Equal(t, rows.Columns(), []string{"n"}) values := make([]driver.Value, 1) require.NoError(t, rows.Next(values)) assert.Equal(t, int64(-123), values[0]) require.Equal(t, io.EOF, rows.Next(values)) require.NoError(t, stmt.Close()) assert.NoError(t, conn.Close()) } func TestStmt_QueryManyParams(t *testing.T) { drv, cleanup := newDriver(t) defer cleanup() conn, err := drv.Open("test.db") require.NoError(t, err) stmt, err := conn.Prepare("CREATE TABLE test (n INT)") require.NoError(t, err) _, err = conn.(driver.ConnBeginTx).BeginTx(context.Background(), driver.TxOptions{}) require.NoError(t, err) _, err = stmt.(driver.StmtExecContext).ExecContext(context.Background(), nil) require.NoError(t, err) require.NoError(t, stmt.Close()) stmt, err = conn.Prepare("SELECT n FROM test WHERE n IN (" + strings.Repeat("?, ", 299) + " ?)") require.NoError(t, err) values := make([]driver.NamedValue, 300) for i := range values { values[i] = driver.NamedValue{Ordinal: i + 1, Value: int64(1)} } _, err = stmt.(driver.StmtQueryContext).QueryContext(context.Background(), values) require.NoError(t, err) require.NoError(t, stmt.Close()) assert.NoError(t, conn.Close()) } func TestConn_QueryParams(t *testing.T) { drv, cleanup := newDriver(t) defer cleanup() conn, err := drv.Open("test.db") require.NoError(t, err) _, err = conn.(driver.ConnBeginTx).BeginTx(context.Background(), driver.TxOptions{}) require.NoError(t, err) execer := conn.(driver.ExecerContext) _, err = execer.ExecContext(context.Background(), "CREATE TABLE test (n INT, t TEXT)", nil) require.NoError(t, err) _, err = execer.ExecContext(context.Background(), ` INSERT INTO test (n,t) VALUES (1,'a'); INSERT INTO test (n,t) VALUES (2,'a'); INSERT INTO test (n,t) VALUES (2,'b'); INSERT INTO test (n,t) VALUES (3,'b'); `, nil) require.NoError(t, err) values := []driver.NamedValue{ {Ordinal: 1, Value: int64(1)}, {Ordinal: 2, Value: "a"}, } queryer := conn.(driver.QueryerContext) rows, err := queryer.QueryContext(context.Background(), "SELECT n, t FROM test WHERE n > ? 
AND t = ?", values) require.NoError(t, err) assert.Equal(t, rows.Columns()[0], "n") rowValues := make([]driver.Value, 2) require.NoError(t, rows.Next(rowValues)) assert.Equal(t, int64(2), rowValues[0]) assert.Equal(t, "a", rowValues[1]) require.Equal(t, io.EOF, rows.Next(rowValues)) assert.NoError(t, conn.Close()) } func TestConn_QueryManyParams(t *testing.T) { drv, cleanup := newDriver(t) defer cleanup() conn, err := drv.Open("test.db") require.NoError(t, err) _, err = conn.(driver.ConnBeginTx).BeginTx(context.Background(), driver.TxOptions{}) require.NoError(t, err) execer := conn.(driver.ExecerContext) _, err = execer.ExecContext(context.Background(), "CREATE TABLE test (n INT)", nil) require.NoError(t, err) values := make([]driver.NamedValue, 300) for i := range values { values[i] = driver.NamedValue{Ordinal: i + 1, Value: int64(1)} } queryer := conn.(driver.QueryerContext) _, err = queryer.QueryContext(context.Background(), "SELECT n FROM test WHERE n IN ("+strings.Repeat("?, ", 299)+" ?)", values) require.NoError(t, err) assert.NoError(t, conn.Close()) } func TestConn_ExecManyParams(t *testing.T) { drv, cleanup := newDriver(t) defer cleanup() conn, err := drv.Open("test.db") require.NoError(t, err) _, err = conn.(driver.ConnBeginTx).BeginTx(context.Background(), driver.TxOptions{}) require.NoError(t, err) execer := conn.(driver.ExecerContext) _, err = execer.ExecContext(context.Background(), "CREATE TABLE test (n INT)", nil) require.NoError(t, err) values := make([]driver.NamedValue, 300) for i := range values { values[i] = driver.NamedValue{Ordinal: i + 1, Value: int64(1)} } _, err = execer.ExecContext(context.Background(), "INSERT INTO test(n) VALUES "+strings.Repeat("(?), ", 299)+" (?)", values) require.NoError(t, err) assert.NoError(t, conn.Close()) } func Test_ColumnTypesEmpty(t *testing.T) { t.Skip("this currently fails if the result set is empty, is dqlite skipping the header if empty set?") drv, cleanup := newDriver(t) defer cleanup() conn, err := drv.Open("test.db") require.NoError(t, err) stmt, err := conn.Prepare("CREATE TABLE test (n INT)") require.NoError(t, err) _, err = conn.(driver.ConnBeginTx).BeginTx(context.Background(), driver.TxOptions{}) require.NoError(t, err) _, err = stmt.(driver.StmtExecContext).ExecContext(context.Background(), nil) require.NoError(t, err) require.NoError(t, stmt.Close()) stmt, err = conn.Prepare("SELECT n FROM test") require.NoError(t, err) rows, err := stmt.(driver.StmtQueryContext).QueryContext(context.Background(), nil) require.NoError(t, err) require.NoError(t, err) rowTypes, ok := rows.(driver.RowsColumnTypeDatabaseTypeName) require.True(t, ok) typeName := rowTypes.ColumnTypeDatabaseTypeName(0) assert.Equal(t, "INTEGER", typeName) require.NoError(t, stmt.Close()) assert.NoError(t, conn.Close()) } func Test_ColumnTypesExists(t *testing.T) { drv, cleanup := newDriver(t) defer cleanup() conn, err := drv.Open("test.db") require.NoError(t, err) stmt, err := conn.Prepare("CREATE TABLE test (n INT)") require.NoError(t, err) _, err = conn.(driver.ConnBeginTx).BeginTx(context.Background(), driver.TxOptions{}) require.NoError(t, err) _, err = stmt.(driver.StmtExecContext).ExecContext(context.Background(), nil) require.NoError(t, err) require.NoError(t, stmt.Close()) stmt, err = conn.Prepare("INSERT INTO test(n) VALUES(-123)") require.NoError(t, err) _, err = stmt.(driver.StmtExecContext).ExecContext(context.Background(), nil) require.NoError(t, err) stmt, err = conn.Prepare("SELECT n FROM test") require.NoError(t, err) rows, err := 
stmt.(driver.StmtQueryContext).QueryContext(context.Background(), nil) require.NoError(t, err) require.NoError(t, err) rowTypes, ok := rows.(driver.RowsColumnTypeDatabaseTypeName) require.True(t, ok) typeName := rowTypes.ColumnTypeDatabaseTypeName(0) assert.Equal(t, "INTEGER", typeName) require.NoError(t, stmt.Close()) assert.NoError(t, conn.Close()) } // ensure column types data is available // even after the last row of the query func Test_ColumnTypesEnd(t *testing.T) { drv, cleanup := newDriver(t) defer cleanup() conn, err := drv.Open("test.db") require.NoError(t, err) stmt, err := conn.Prepare("CREATE TABLE test (n INT)") require.NoError(t, err) _, err = conn.(driver.ConnBeginTx).BeginTx(context.Background(), driver.TxOptions{}) require.NoError(t, err) _, err = stmt.(driver.StmtExecContext).ExecContext(context.Background(), nil) require.NoError(t, err) require.NoError(t, stmt.Close()) stmt, err = conn.Prepare("INSERT INTO test(n) VALUES(-123)") require.NoError(t, err) _, err = stmt.(driver.StmtExecContext).ExecContext(context.Background(), nil) require.NoError(t, err) stmt, err = conn.Prepare("SELECT n FROM test") require.NoError(t, err) rows, err := stmt.(driver.StmtQueryContext).QueryContext(context.Background(), nil) require.NoError(t, err) require.NoError(t, err) rowTypes, ok := rows.(driver.RowsColumnTypeDatabaseTypeName) require.True(t, ok) typeName := rowTypes.ColumnTypeDatabaseTypeName(0) assert.Equal(t, "INTEGER", typeName) values := make([]driver.Value, 1) require.NoError(t, rows.Next(values)) assert.Equal(t, int64(-123), values[0]) require.Equal(t, io.EOF, rows.Next(values)) // despite EOF we should have types cached typeName = rowTypes.ColumnTypeDatabaseTypeName(0) assert.Equal(t, "INTEGER", typeName) require.NoError(t, stmt.Close()) assert.NoError(t, conn.Close()) } func Test_ZeroColumns(t *testing.T) { drv, cleanup := newDriver(t) defer cleanup() conn, err := drv.Open("test.db") require.NoError(t, err) queryer := conn.(driver.QueryerContext) rows, err := queryer.QueryContext(context.Background(), "CREATE TABLE foo (bar INTEGER)", []driver.NamedValue{}) require.NoError(t, err) values := []driver.Value{} require.Equal(t, io.EOF, rows.Next(values)) require.NoError(t, conn.Close()) } func Test_DescribeLastEntry(t *testing.T) { dir, dirCleanup := newDir(t) defer dirCleanup() _, cleanup := newNode(t, dir) store := newStore(t, bindAddress) log := logging.Test(t) drv, err := dqlitedriver.New(store, dqlitedriver.WithLogFunc(log)) require.NoError(t, err) conn, err := drv.Open("test.db") require.NoError(t, err) _, err = conn.(driver.ExecerContext).ExecContext(context.Background(), `CREATE TABLE test (n INT)`, nil) require.NoError(t, err) stmt, err := conn.Prepare(`INSERT INTO test(n) VALUES(?)`) require.NoError(t, err) for i := 0; i < 300; i++ { values := []driver.NamedValue{{Ordinal: 1, Value: int64(i)}} _, err := stmt.(driver.StmtExecContext).ExecContext(context.Background(), values) require.NoError(t, err) } require.NoError(t, stmt.Close()) assert.NoError(t, conn.Close()) cleanup() info, err := dqlite.ReadLastEntryInfo(dir) require.NoError(t, err) assert.Equal(t, info.Index, uint64(302)) assert.Equal(t, info.Term, uint64(1)) } func Test_Dump(t *testing.T) { drv, cleanup := newDriver(t) defer cleanup() conn, err := drv.Open("test.db") require.NoError(t, err) _, err = conn.(driver.ExecerContext).ExecContext(context.Background(), `CREATE TABLE foo (n INT)`, nil) require.NoError(t, err) ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() client, err 
:= client.New(ctx, bindAddress) require.NoError(t, err) defer client.Close() files, err := client.Dump(ctx, "test.db") require.NoError(t, err) require.Len(t, files, 2) assert.Equal(t, "test.db", files[0].Name) assert.Equal(t, 4096, len(files[0].Data)) assert.Equal(t, "test.db-wal", files[1].Name) assert.Equal(t, 8272, len(files[1].Data)) } const bindAddress = "@1" func newDriver(t *testing.T) (*dqlitedriver.Driver, func()) { t.Helper() dir, dirCleanup := newDir(t) _, nodeCleanup := newNode(t, dir) store := newStore(t, bindAddress) log := logging.Test(t) driver, err := dqlitedriver.New(store, dqlitedriver.WithLogFunc(log)) require.NoError(t, err) cleanup := func() { nodeCleanup() dirCleanup() } return driver, cleanup } // Create a new in-memory server store populated with the given addresses. func newStore(t *testing.T, address string) client.NodeStore { t.Helper() store := client.NewInmemNodeStore() server := client.NodeInfo{Address: address} require.NoError(t, store.Set(context.Background(), []client.NodeInfo{server})) return store } func newNode(t *testing.T, dir string) (*dqlite.Node, func()) { t.Helper() server, err := dqlite.New(uint64(1), bindAddress, dir, dqlite.WithBindAddress(bindAddress)) require.NoError(t, err) err = server.Start() require.NoError(t, err) cleanup := func() { require.NoError(t, server.Close()) } return server, cleanup } // Return a new temporary directory. func newDir(t *testing.T) (string, func()) { t.Helper() dir, err := ioutil.TempDir("", "dqlite-replication-test-") assert.NoError(t, err) cleanup := func() { _, err := os.Stat(dir) if err != nil { assert.True(t, os.IsNotExist(err)) } else { assert.NoError(t, os.RemoveAll(dir)) } } return dir, cleanup } golang-github-canonical-go-dqlite-2.0.0/driver/integration_test.go000066400000000000000000000253341471100661000252570ustar00rootroot00000000000000package driver_test import ( "context" "database/sql" "fmt" "os" "testing" "time" dqlite "github.com/canonical/go-dqlite/v2" "github.com/canonical/go-dqlite/v2/client" "github.com/canonical/go-dqlite/v2/driver" "github.com/canonical/go-dqlite/v2/logging" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) // https://sqlite.org/rescode.html#constraint_unique const SQLITE_CONSTRAINT_UNIQUE = 2067 func TestIntegration_DatabaseSQL(t *testing.T) { db, _, cleanup := newDB(t, 3) defer cleanup() tx, err := db.Begin() require.NoError(t, err) _, err = tx.Exec(` CREATE TABLE test (n INT, s TEXT); CREATE TABLE test2 (n INT, t DATETIME DEFAULT CURRENT_TIMESTAMP) `) require.NoError(t, err) stmt, err := tx.Prepare("INSERT INTO test(n, s) VALUES(?, ?)") require.NoError(t, err) _, err = stmt.Exec(int64(123), "hello") require.NoError(t, err) require.NoError(t, stmt.Close()) _, err = tx.Exec("INSERT INTO test2(n) VALUES(?)", int64(456)) require.NoError(t, err) require.NoError(t, tx.Commit()) tx, err = db.Begin() require.NoError(t, err) rows, err := tx.Query("SELECT n, s FROM test") require.NoError(t, err) for rows.Next() { var n int64 var s string require.NoError(t, rows.Scan(&n, &s)) assert.Equal(t, int64(123), n) assert.Equal(t, "hello", s) } require.NoError(t, rows.Err()) require.NoError(t, rows.Close()) rows, err = tx.Query("SELECT n, t FROM test2") require.NoError(t, err) for rows.Next() { var n int64 var s time.Time require.NoError(t, rows.Scan(&n, &s)) assert.Equal(t, int64(456), n) } require.NoError(t, rows.Err()) require.NoError(t, rows.Close()) require.NoError(t, tx.Rollback()) } func TestIntegration_ConstraintError(t *testing.T) { db, _, cleanup := 
newDB(t, 3) defer cleanup() _, err := db.Exec("CREATE TABLE test (n INT, UNIQUE (n))") require.NoError(t, err) _, err = db.Exec("INSERT INTO test (n) VALUES (1)") require.NoError(t, err) _, err = db.Exec("INSERT INTO test (n) VALUES (1)") if derr, ok := err.(driver.Error); ok { assert.Equal(t, SQLITE_CONSTRAINT_UNIQUE, derr.Code) assert.Equal(t, "UNIQUE constraint failed: test.n", derr.Message) } else { t.Fatalf("expected driver error, got %+v", err) } } func TestIntegration_ExecBindError(t *testing.T) { db, _, cleanup := newDB(t, 1) defer cleanup() defer db.Close() ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() _, err := db.ExecContext(ctx, "CREATE TABLE test (n INT)") require.NoError(t, err) _, err = db.ExecContext(ctx, "INSERT INTO test(n) VALUES(1)", 1) assert.EqualError(t, err, "bind parameters") } func TestIntegration_QueryBindError(t *testing.T) { db, _, cleanup := newDB(t, 1) defer cleanup() defer db.Close() ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() _, err := db.QueryContext(ctx, "SELECT 1", 1) assert.EqualError(t, err, "bind parameters") } func TestIntegration_LargeQuery(t *testing.T) { db, _, cleanup := newDB(t, 3) defer cleanup() tx, err := db.Begin() require.NoError(t, err) _, err = tx.Exec("CREATE TABLE test (n INT)") require.NoError(t, err) stmt, err := tx.Prepare("INSERT INTO test(n) VALUES(?)") require.NoError(t, err) for i := 0; i < 512; i++ { _, err = stmt.Exec(int64(i)) require.NoError(t, err) } require.NoError(t, stmt.Close()) require.NoError(t, tx.Commit()) tx, err = db.Begin() require.NoError(t, err) rows, err := tx.Query("SELECT n FROM test") require.NoError(t, err) columns, err := rows.Columns() require.NoError(t, err) assert.Equal(t, []string{"n"}, columns) count := 0 for i := 0; rows.Next(); i++ { var n int64 require.NoError(t, rows.Scan(&n)) assert.Equal(t, int64(i), n) count++ } require.NoError(t, rows.Err()) require.NoError(t, rows.Close()) assert.Equal(t, 512, count) require.NoError(t, tx.Rollback()) } // Build a 2-node cluster, kill one node and recover the other. func TestIntegration_Recover(t *testing.T) { db, helpers, cleanup := newDB(t, 2) defer cleanup() _, err := db.Exec("CREATE TABLE test (n INT)") require.NoError(t, err) helpers[0].Close() helpers[1].Close() helpers[0].Create() infos := []client.NodeInfo{{ID: 1, Address: "@1", Role: client.Voter}} err = dqlite.ReconfigureMembershipExt(helpers[0].Dir, infos) require.NoError(t, err) helpers[0].Start() // FIXME: this is necessary, otherwise the INSERT below fails with "no // such table", because the replication hooks are not triggered and the // barrier is not applied. _, err = db.Exec("CREATE TABLE test2 (n INT)") require.NoError(t, err) _, err = db.Exec("INSERT INTO test(n) VALUES(1)") require.NoError(t, err) } // The db.Ping() method can be used to wait until there is a stable leader. func TestIntegration_PingOnlyWorksOnceLeaderElected(t *testing.T) { db, helpers, cleanup := newDB(t, 2) defer cleanup() helpers[0].Close() // Ping returns an error, since the cluster is not available. ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() assert.Error(t, db.PingContext(ctx)) helpers[0].Create() helpers[0].Start() // Ping now returns no error, since the cluster is available. assert.NoError(t, db.Ping()) // If leadership is lost after the first successful call, Ping() still // returns no error.
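// (database/sql serves Ping from its pool of already-open connections, so // losing the leader does not force a fresh leader probe here.)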
helpers[0].Close() assert.NoError(t, db.Ping()) } func TestIntegration_HighAvailability(t *testing.T) { db, helpers, cleanup := newDB(t, 3) defer cleanup() _, err := db.Exec("CREATE TABLE test (n INT)") require.NoError(t, err) // Shut down all three nodes. helpers[0].Close() helpers[1].Close() helpers[2].Close() // Restart two of them. helpers[1].Create() helpers[2].Create() helpers[1].Start() helpers[2].Start() // Give the cluster a chance to establish a quorum. time.Sleep(2 * time.Second) _, err = db.Exec("INSERT INTO test(n) VALUES(1)") require.NoError(t, err) } func TestIntegration_LeadershipTransfer(t *testing.T) { db, helpers, cleanup := newDB(t, 3) defer cleanup() _, err := db.Exec("CREATE TABLE test (n INT)") require.NoError(t, err) cli := helpers[0].Client() require.NoError(t, cli.Transfer(context.Background(), 2)) _, err = db.Exec("INSERT INTO test(n) VALUES(1)") require.NoError(t, err) } func TestIntegration_LeadershipTransfer_Tx(t *testing.T) { db, helpers, cleanup := newDB(t, 3) defer cleanup() _, err := db.Exec("CREATE TABLE test (n INT)") require.NoError(t, err) cli := helpers[0].Client() require.NoError(t, cli.Transfer(context.Background(), 2)) tx, err := db.Begin() require.NoError(t, err) _, err = tx.Query("SELECT * FROM test") require.NoError(t, err) require.NoError(t, tx.Commit()) } func TestOptions(t *testing.T) { // Make sure applying all options doesn't break anything. store := client.NewInmemNodeStore() log := logging.Test(t) _, err := driver.New( store, driver.WithLogFunc(log), driver.WithContext(context.Background()), driver.WithConnectionTimeout(15*time.Second), driver.WithContextTimeout(2*time.Second), driver.WithConnectionBackoffFactor(50*time.Millisecond), driver.WithConnectionBackoffCap(1*time.Second), driver.WithAttemptTimeout(5*time.Second), driver.WithRetryLimit(0), ) require.NoError(t, err) } func newDB(t *testing.T, n int) (*sql.DB, []*nodeHelper, func()) { infos := make([]client.NodeInfo, n) for i := range infos { infos[i].ID = uint64(i + 1) infos[i].Address = fmt.Sprintf("@%d", infos[i].ID) infos[i].Role = client.Voter } return newDBWithInfos(t, infos) } func newDBWithInfos(t *testing.T, infos []client.NodeInfo) (*sql.DB, []*nodeHelper, func()) { helpers, helpersCleanup := newNodeHelpers(t, infos) store := client.NewInmemNodeStore() require.NoError(t, store.Set(context.Background(), infos)) log := logging.Test(t) driver, err := driver.New(store, driver.WithLogFunc(log)) require.NoError(t, err) driverName := fmt.Sprintf("dqlite-integration-test-%d", driversCount) sql.Register(driverName, driver) driversCount++ db, err := sql.Open(driverName, "test.db") require.NoError(t, err) cleanup := func() { require.NoError(t, db.Close()) helpersCleanup() } return db, helpers, cleanup } type nodeHelper struct { t *testing.T ID uint64 Address string Dir string Node *dqlite.Node } func newNodeHelper(t *testing.T, id uint64, address string) *nodeHelper { h := &nodeHelper{ t: t, ID: id, Address: address, } h.Dir, _ = newDir(t) h.Create() h.Start() return h } func (h *nodeHelper) Client() *client.Client { client, err := client.New(context.Background(), h.Node.BindAddress()) require.NoError(h.t, err) return client } func (h *nodeHelper) Create() { var err error require.Nil(h.t, h.Node) h.Node, err = dqlite.New(h.ID, h.Address, h.Dir, dqlite.WithBindAddress(h.Address)) require.NoError(h.t, err) } func (h *nodeHelper) Start() { require.NotNil(h.t, h.Node) require.NoError(h.t, h.Node.Start()) } func (h *nodeHelper) Close() { require.NotNil(h.t, h.Node) require.NoError(h.t,
h.Node.Close()) h.Node = nil } func (h *nodeHelper) cleanup() { if h.Node != nil { h.Close() } require.NoError(h.t, os.RemoveAll(h.Dir)) } func newNodeHelpers(t *testing.T, infos []client.NodeInfo) ([]*nodeHelper, func()) { t.Helper() n := len(infos) helpers := make([]*nodeHelper, n) for i, info := range infos { helpers[i] = newNodeHelper(t, info.ID, info.Address) if i > 0 { client := helpers[0].Client() defer client.Close() require.NoError(t, client.Add(context.Background(), infos[i])) } } cleanup := func() { for _, helper := range helpers { helper.cleanup() } } return helpers, cleanup } var driversCount = 0 func TestIntegration_ColumnTypeName(t *testing.T) { db, _, cleanup := newDB(t, 1) defer cleanup() _, err := db.Exec("CREATE TABLE test (n INT, UNIQUE (n))") require.NoError(t, err) _, err = db.Exec("INSERT INTO test (n) VALUES (1)") require.NoError(t, err) rows, err := db.Query("SELECT n FROM test") require.NoError(t, err) defer rows.Close() types, err := rows.ColumnTypes() require.NoError(t, err) assert.Equal(t, "INTEGER", types[0].DatabaseTypeName()) require.True(t, rows.Next()) var n int64 err = rows.Scan(&n) require.NoError(t, err) assert.Equal(t, int64(1), n) } func TestIntegration_SqlNullTime(t *testing.T) { db, _, cleanup := newDB(t, 1) defer cleanup() _, err := db.Exec("CREATE TABLE test (tm DATETIME)") require.NoError(t, err) // Insert sql.NullTime into DB var t1 sql.NullTime res, err := db.Exec("INSERT INTO test (tm) VALUES (?)", t1) require.NoError(t, err) n, err := res.RowsAffected() require.NoError(t, err) assert.EqualValues(t, n, 1) // Retrieve inserted sql.NullTime from DB row := db.QueryRow("SELECT tm FROM test LIMIT 1") var t2 sql.NullTime err = row.Scan(&t2) require.NoError(t, err) assert.Equal(t, t1, t2) } golang-github-canonical-go-dqlite-2.0.0/go.mod000066400000000000000000000012151471100661000211510ustar00rootroot00000000000000module github.com/canonical/go-dqlite/v2 // This is to maintain the ppa package on focal go 1.13 require ( github.com/Rican7/retry v0.3.1 github.com/davecgh/go-spew v1.1.1 // indirect github.com/google/renameio v1.0.1 github.com/kr/pretty v0.1.0 // indirect github.com/mattn/go-runewidth v0.0.13 // indirect github.com/mattn/go-sqlite3 v1.14.7 github.com/peterh/liner v1.2.2 github.com/pkg/errors v0.9.1 github.com/spf13/cobra v1.8.1 github.com/stretchr/testify v1.7.0 golang.org/x/sync v0.8.0 golang.org/x/sys v0.0.0-20211117180635-dee7805ff2e1 gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect gopkg.in/yaml.v2 v2.4.0 ) golang-github-canonical-go-dqlite-2.0.0/go.sum000066400000000000000000000100751471100661000212020ustar00rootroot00000000000000github.com/Rican7/retry v0.3.1 h1:scY4IbO8swckzoA/11HgBwaZRJEyY9vaNJshcdhp1Mc= github.com/Rican7/retry v0.3.1/go.mod h1:CxSDrhAyXmTMeEuRAnArMu1FHu48vtfjLREWqVl7Vw0= github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/google/renameio v1.0.1 h1:Lh/jXZmvZxb0BBeSY5VKEfidcbcbenKjZFzM/q0fSeU= github.com/google/renameio v1.0.1/go.mod h1:t/HQoYBZSsWSNK35C6CO/TpPLDVWvxOHboWUAweKUpk= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/kr/pretty 
v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/mattn/go-runewidth v0.0.3/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU= github.com/mattn/go-runewidth v0.0.13 h1:lTGmDsbAYt5DmK6OnoV7EuIF1wEIFAcxld6ypU4OSgU= github.com/mattn/go-runewidth v0.0.13/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/mattn/go-sqlite3 v1.14.7 h1:fxWBnXkxfM6sRiuH3bqJ4CfzZojMOLVc0UTsTglEghA= github.com/mattn/go-sqlite3 v1.14.7/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/peterh/liner v1.2.2 h1:aJ4AOodmL+JxOZZEL2u9iJf8omNRpqHc/EbrK+3mAXw= github.com/peterh/liner v1.2.2/go.mod h1:xFwJyiKIXJZUKItq5dGHZSTBRAuG/CpeNpWLyiNRNwI= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM= github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20211117180635-dee7805ff2e1 h1:kwrAHlwJ0DUBZwQ238v+Uod/3eZ8B2K5rYsUHBQvzmI= golang.org/x/sys v0.0.0-20211117180635-dee7805ff2e1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 
golang-github-canonical-go-dqlite-2.0.0/internal/000077500000000000000000000000001471100661000216605ustar00rootroot00000000000000golang-github-canonical-go-dqlite-2.0.0/internal/bindings/000077500000000000000000000000001471100661000234555ustar00rootroot00000000000000golang-github-canonical-go-dqlite-2.0.0/internal/bindings/build.go000066400000000000000000000002331471100661000251010ustar00rootroot00000000000000package bindings /* #cgo linux LDFLAGS: -ldqlite */ import "C" // required dqlite version var dqliteMajorVersion int = 1 var dqliteMinorVersion int = 17 golang-github-canonical-go-dqlite-2.0.0/internal/bindings/server.go000066400000000000000000000221571471100661000253210ustar00rootroot00000000000000package bindings /* #include <dqlite.h> #include <fcntl.h> #include <signal.h> #include <stdbool.h> #include <stddef.h> #include <stdint.h> #include <stdlib.h> #include <unistd.h> #define RAFT_NOCONNECTION 16 #define EMIT_BUF_LEN 1024 typedef unsigned long long nanoseconds_t; typedef unsigned long long failure_domain_t; // Duplicate a file descriptor and prevent it from being cloned into child processes. static int dupCloexec(int oldfd) { int newfd = -1; newfd = dup(oldfd); if (newfd < 0) { return -1; } if (fcntl(newfd, F_SETFD, FD_CLOEXEC) < 0) { close(newfd); return -1; } return newfd; } // C to Go trampoline for custom connect function. int connectWithDial(uintptr_t handle, char *address, int *fd); // Wrapper to call the Go trampoline. static int connectTrampoline(void *data, const char *address, int *fd) { uintptr_t handle = (uintptr_t)(data); return connectWithDial(handle, (char*)address, fd); } // Configure a custom connect function. static int configConnectFunc(dqlite_node *t, uintptr_t handle) { return dqlite_node_set_connect_func(t, connectTrampoline, (void*)handle); } static dqlite_node_info_ext *makeInfos(int n) { return calloc(n, sizeof(dqlite_node_info_ext)); } static void setInfo(dqlite_node_info_ext *infos, unsigned i, dqlite_node_id id, const char *address, int role) { dqlite_node_info_ext *info = &infos[i]; info->size = sizeof(dqlite_node_info_ext); info->id = id; info->address = (uint64_t)(uintptr_t)address; info->dqlite_role = role; } __attribute__((weak)) int dqlite_node_set_auto_recovery(dqlite_node *t, bool on); static int setAutoRecovery(dqlite_node *t, bool on) { if (dqlite_node_set_auto_recovery == NULL) { return DQLITE_ERROR; } return dqlite_node_set_auto_recovery(t, on); } */ import "C" import ( "context" "fmt" "net" "os" "sync" "time" "unsafe" "github.com/canonical/go-dqlite/v2/internal/protocol" ) type Node struct { node *C.dqlite_node ctx context.Context cancel context.CancelFunc } type SnapshotParams struct { Threshold uint64 Trailing uint64 } // Initializes state. func init() { // FIXME: ignore SIGPIPE, see https://github.com/joyent/libuv/issues/1254 C.signal(C.SIGPIPE, C.SIG_IGN) } // NewNode creates a new Node instance. func NewNode(ctx context.Context, id uint64, address string, dir string) (*Node, error) { requiredVersion := dqliteMajorVersion*100 + dqliteMinorVersion // Remove the patch version, as patch versions should be compatible.
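// For example, a runtime dqlite 1.17.3 reports 11703 from dqlite_version_number(), // so runtimeVersion becomes 117, matching the requiredVersion encoding (1*100 + 17 = 117).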
runtimeVersion := int(C.dqlite_version_number()) / 100 if requiredVersion > runtimeVersion { return nil, fmt.Errorf("version mismatch: required version(%d.%d.x) current version(%d.%d.x)", dqliteMajorVersion, dqliteMinorVersion, runtimeVersion/100, runtimeVersion%100) } var server *C.dqlite_node cid := C.dqlite_node_id(id) caddress := C.CString(address) defer C.free(unsafe.Pointer(caddress)) cdir := C.CString(dir) defer C.free(unsafe.Pointer(cdir)) if rc := C.dqlite_node_create(cid, caddress, cdir, &server); rc != 0 { errmsg := C.GoString(C.dqlite_node_errmsg(server)) C.dqlite_node_destroy(server) return nil, fmt.Errorf("%s", errmsg) } node := &Node{node: (*C.dqlite_node)(unsafe.Pointer(server))} node.ctx, node.cancel = context.WithCancel(ctx) return node, nil } func (s *Node) SetDialFunc(dial protocol.DialFunc) error { server := (*C.dqlite_node)(unsafe.Pointer(s.node)) connectLock.Lock() defer connectLock.Unlock() connectIndex++ connectRegistry[connectIndex] = dial contextRegistry[connectIndex] = s.ctx if rc := C.configConnectFunc(server, connectIndex); rc != 0 { return fmt.Errorf("failed to set connect func") } return nil } func (s *Node) SetBindAddress(address string) error { server := (*C.dqlite_node)(unsafe.Pointer(s.node)) caddress := C.CString(address) defer C.free(unsafe.Pointer(caddress)) if rc := C.dqlite_node_set_bind_address(server, caddress); rc != 0 { return fmt.Errorf("failed to set bind address %q: %d", address, rc) } return nil } func (s *Node) SetNetworkLatency(nanoseconds uint64) error { server := (*C.dqlite_node)(unsafe.Pointer(s.node)) cnanoseconds := C.nanoseconds_t(nanoseconds) if rc := C.dqlite_node_set_network_latency(server, cnanoseconds); rc != 0 { return fmt.Errorf("failed to set network latency") } return nil } func (s *Node) SetSnapshotParams(params SnapshotParams) error { server := (*C.dqlite_node)(unsafe.Pointer(s.node)) cthreshold := C.unsigned(params.Threshold) ctrailing := C.unsigned(params.Trailing) if rc := C.dqlite_node_set_snapshot_params(server, cthreshold, ctrailing); rc != 0 { return fmt.Errorf("failed to set snapshot params") } return nil } func (s *Node) SetFailureDomain(code uint64) error { server := (*C.dqlite_node)(unsafe.Pointer(s.node)) ccode := C.failure_domain_t(code) if rc := C.dqlite_node_set_failure_domain(server, ccode); rc != 0 { return fmt.Errorf("set failure domain: %d", rc) } return nil } func (s *Node) EnableDiskMode() error { server := (*C.dqlite_node)(unsafe.Pointer(s.node)) if rc := C.dqlite_node_enable_disk_mode(server); rc != 0 { return fmt.Errorf("failed to set disk mode") } return nil } func (s *Node) SetAutoRecovery(on bool) error { server := (*C.dqlite_node)(unsafe.Pointer(s.node)) if rc := C.setAutoRecovery(server, C.bool(on)); rc != 0 { return fmt.Errorf("failed to set auto-recovery behavior") } return nil } func (s *Node) GetBindAddress() string { server := (*C.dqlite_node)(unsafe.Pointer(s.node)) return C.GoString(C.dqlite_node_get_bind_address(server)) } func (s *Node) Start() error { server := (*C.dqlite_node)(unsafe.Pointer(s.node)) if rc := C.dqlite_node_start(server); rc != 0 { errmsg := C.GoString(C.dqlite_node_errmsg(server)) return fmt.Errorf("%s", errmsg) } return nil } func (s *Node) Stop() error { server := (*C.dqlite_node)(unsafe.Pointer(s.node)) if rc := C.dqlite_node_stop(server); rc != 0 { return fmt.Errorf("task stopped with error code %d", rc) } return nil } // Close the server releasing all used resources. 
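// Closing also cancels the node's context, which aborts any dial still running // in the custom connect function registered via SetDialFunc.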
func (s *Node) Close() { defer s.cancel() server := (*C.dqlite_node)(unsafe.Pointer(s.node)) C.dqlite_node_destroy(server) } // Recover is like RecoverExt, but disregards the given node roles: every node is recovered as a voter. func (s *Node) Recover(cluster []protocol.NodeInfo) error { for i := range cluster { cluster[i].Role = protocol.Voter } return s.RecoverExt(cluster) } // RecoverExt serves the same purpose as Recover, but takes the node roles into account. func (s *Node) RecoverExt(cluster []protocol.NodeInfo) error { server := (*C.dqlite_node)(unsafe.Pointer(s.node)) n := C.int(len(cluster)) infos := C.makeInfos(n) defer C.free(unsafe.Pointer(infos)) for i, info := range cluster { cid := C.dqlite_node_id(info.ID) caddress := C.CString(info.Address) crole := C.int(info.Role) defer C.free(unsafe.Pointer(caddress)) C.setInfo(infos, C.unsigned(i), cid, caddress, crole) } if rc := C.dqlite_node_recover_ext(server, infos, n); rc != 0 { errmsg := C.GoString(C.dqlite_node_errmsg(server)) return fmt.Errorf("recover failed with error code %d, error details: %s", rc, errmsg) } return nil } func (s *Node) DescribeLastEntry() (uint64, uint64, error) { server := (*C.dqlite_node)(unsafe.Pointer(s.node)) index := C.uint64_t(0) term := C.uint64_t(0) if rc := C.dqlite_node_describe_last_entry(server, &index, &term); rc != 0 { return 0, 0, fmt.Errorf("dqlite_node_describe_last_entry failed with error code %d", rc) } return uint64(index), uint64(term), nil } // GenerateID generates a unique ID for a server. func GenerateID(address string) uint64 { caddress := C.CString(address) defer C.free(unsafe.Pointer(caddress)) id := C.dqlite_generate_node_id(caddress) return uint64(id) } // Extract the underlying socket from a connection. func connToSocket(conn net.Conn) (C.int, error) { file, err := conn.(fileConn).File() if err != nil { return C.int(-1), err } fd1 := C.int(file.Fd()) // Duplicate the file descriptor, in order to prevent Go's finalizer from // closing it. fd2 := C.dupCloexec(fd1) if fd2 < 0 { return C.int(-1), fmt.Errorf("failed to dup socket fd") } conn.Close() return fd2, nil } // Interface that net.Conn must implement in order to extract the underlying // file descriptor. type fileConn interface { File() (*os.File, error) } //export connectWithDial func connectWithDial(handle C.uintptr_t, address *C.char, fd *C.int) C.int { connectLock.Lock() defer connectLock.Unlock() dial := connectRegistry[handle] ctx := contextRegistry[handle] // TODO: make timeout customizable. dialCtx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() conn, err := dial(dialCtx, C.GoString(address)) if err != nil { return C.RAFT_NOCONNECTION } socket, err := connToSocket(conn) if err != nil { return C.RAFT_NOCONNECTION } *fd = socket return C.int(0) } // Use handles to avoid passing Go pointers to C. var contextRegistry = make(map[C.uintptr_t]context.Context) var connectRegistry = make(map[C.uintptr_t]protocol.DialFunc) var connectIndex C.uintptr_t = 100 var connectLock = sync.Mutex{} // ErrNodeStopped is returned by Node.Handle() if the server was stopped. var ErrNodeStopped = fmt.Errorf("server was stopped") // To compare bool values.
var cfalse C.bool golang-github-canonical-go-dqlite-2.0.0/internal/bindings/server_test.go000066400000000000000000000120531471100661000263520ustar00rootroot00000000000000package bindings_test import ( "context" "encoding/binary" "io/ioutil" "net" "os" "strings" "testing" "time" "github.com/canonical/go-dqlite/v2/internal/bindings" "github.com/canonical/go-dqlite/v2/internal/protocol" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestNode_Create(t *testing.T) { _, cleanup := newNode(t) defer cleanup() } func TestNode_Start(t *testing.T) { dir, cleanup := newDir(t) defer cleanup() server, err := bindings.NewNode(context.Background(), 1, "1", dir) require.NoError(t, err) defer server.Close() err = server.SetBindAddress("@") require.NoError(t, err) err = server.Start() require.NoError(t, err) conn, err := net.Dial("unix", server.GetBindAddress()) require.NoError(t, err) conn.Close() assert.True(t, strings.HasPrefix(server.GetBindAddress(), "@")) err = server.Stop() require.NoError(t, err) } func TestNode_Restart(t *testing.T) { dir, cleanup := newDir(t) defer cleanup() server, err := bindings.NewNode(context.Background(), 1, "1", dir) require.NoError(t, err) require.NoError(t, server.SetBindAddress("@abc")) require.NoError(t, server.Start()) require.NoError(t, server.Stop()) server.Close() server, err = bindings.NewNode(context.Background(), 1, "1", dir) require.NoError(t, err) require.NoError(t, server.SetBindAddress("@abc")) require.NoError(t, server.Start()) require.NoError(t, server.Stop()) server.Close() } func TestNode_Start_Inet(t *testing.T) { dir, cleanup := newDir(t) defer cleanup() server, err := bindings.NewNode(context.Background(), 1, "1", dir) require.NoError(t, err) defer server.Close() err = server.SetBindAddress("127.0.0.1:9000") require.NoError(t, err) err = server.Start() require.NoError(t, err) conn, err := net.Dial("tcp", server.GetBindAddress()) require.NoError(t, err) conn.Close() err = server.Stop() require.NoError(t, err) } func TestNode_Leader(t *testing.T) { _, cleanup := newNode(t) defer cleanup() conn := newClient(t) // Make a Leader request buf := makeClientRequest(t, conn, protocol.RequestLeader) assert.Equal(t, uint8(1), buf[0]) require.NoError(t, conn.Close()) } func TestNode_Autorecovery(t *testing.T) { dir, cleanup := newDir(t) defer cleanup() server, err := bindings.NewNode(context.Background(), 1, "1", dir) require.NoError(t, err) defer server.Close() err = server.SetAutoRecovery(false) require.NoError(t, err) } // func TestNode_Heartbeat(t *testing.T) { // server, cleanup := newNode(t) // defer cleanup() // listener, cleanup := newListener(t) // defer cleanup() // cleanup = runNode(t, server, listener) // defer cleanup() // conn := newClient(t, listener) // // Make a Heartbeat request // makeClientRequest(t, conn, bindings.RequestHeartbeat) // require.NoError(t, conn.Close()) // } // func TestNode_ConcurrentHandleAndClose(t *testing.T) { // server, cleanup := newNode(t) // defer cleanup() // listener, cleanup := newListener(t) // defer cleanup() // acceptCh := make(chan error) // go func() { // conn, err := listener.Accept() // if err != nil { // acceptCh <- err // } // server.Handle(conn) // acceptCh <- nil // }() // conn, err := net.Dial("unix", listener.Addr().String()) // require.NoError(t, err) // require.NoError(t, conn.Close()) // assert.NoError(t, <-acceptCh) // } // Create a new Node object for tests. 
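// The node binds the abstract Unix socket "@test" and is already started; the // returned cleanup function stops and destroys it and removes its data directory.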
func newNode(t *testing.T) (*bindings.Node, func()) { t.Helper() dir, dirCleanup := newDir(t) server, err := bindings.NewNode(context.Background(), 1, "1", dir) require.NoError(t, err) err = server.SetBindAddress("@test") require.NoError(t, err) require.NoError(t, server.Start()) cleanup := func() { require.NoError(t, server.Stop()) server.Close() dirCleanup() } return server, cleanup } // Create a new client network connection, performing the handshake. func newClient(t *testing.T) net.Conn { t.Helper() conn, err := net.Dial("unix", "@test") require.NoError(t, err) // Handshake err = binary.Write(conn, binary.LittleEndian, protocol.VersionLegacy) require.NoError(t, err) return conn } // Perform a client request. func makeClientRequest(t *testing.T, conn net.Conn, kind byte) []byte { t.Helper() // Number of words err := binary.Write(conn, binary.LittleEndian, uint32(1)) require.NoError(t, err) // Type, flags, extra. n, err := conn.Write([]byte{kind, 0, 0, 0}) require.NoError(t, err) require.Equal(t, 4, n) n, err = conn.Write([]byte{0, 0, 0, 0, 0, 0, 0, 0}) // Unused single-word request payload require.NoError(t, err) require.Equal(t, 8, n) // Read the response conn.SetDeadline(time.Now().Add(250 * time.Millisecond)) buf := make([]byte, 64) _, err = conn.Read(buf) require.NoError(t, err) return buf } // Return a new temporary directory. func newDir(t *testing.T) (string, func()) { t.Helper() dir, err := ioutil.TempDir("", "dqlite-replication-test-") assert.NoError(t, err) cleanup := func() { _, err := os.Stat(dir) if err != nil { assert.True(t, os.IsNotExist(err)) } else { assert.NoError(t, os.RemoveAll(dir)) } } return dir, cleanup } golang-github-canonical-go-dqlite-2.0.0/internal/bindings/sqlite3.go000066400000000000000000000013361471100661000253730ustar00rootroot00000000000000// +build !nosqlite3 package bindings import ( "github.com/canonical/go-dqlite/v2/internal/protocol" ) /* #cgo linux LDFLAGS: -lsqlite3 #include <sqlite3.h> static int sqlite3ConfigSingleThread() { return sqlite3_config(SQLITE_CONFIG_SINGLETHREAD); } static int sqlite3ConfigMultiThread() { return sqlite3_config(SQLITE_CONFIG_MULTITHREAD); } */ import "C" func ConfigSingleThread() error { if rc := C.sqlite3ConfigSingleThread(); rc != 0 { return protocol.Error{Message: C.GoString(C.sqlite3_errstr(rc)), Code: int(rc)} } return nil } func ConfigMultiThread() error { if rc := C.sqlite3ConfigMultiThread(); rc != 0 { return protocol.Error{Message: C.GoString(C.sqlite3_errstr(rc)), Code: int(rc)} } return nil } golang-github-canonical-go-dqlite-2.0.0/internal/protocol/000077500000000000000000000000001471100661000235215ustar00rootroot00000000000000golang-github-canonical-go-dqlite-2.0.0/internal/protocol/buffer.go000066400000000000000000000002671471100661000253260ustar00rootroot00000000000000package protocol // Buffer for reading responses or writing requests. type buffer struct { Bytes []byte Offset int } func (b *buffer) Advance(amount int) { b.Offset += amount } golang-github-canonical-go-dqlite-2.0.0/internal/protocol/config.go000066400000000000000000000031341471100661000253160ustar00rootroot00000000000000package protocol import ( "time" "github.com/Rican7/retry/backoff" "github.com/Rican7/retry/strategy" ) // Config holds various configuration parameters for a dqlite client. type Config struct { Dial DialFunc // Network dialer. DialTimeout time.Duration // Timeout for establishing a network connection. AttemptTimeout time.Duration // Timeout for each individual attempt to probe a server's leadership.
BackoffFactor time.Duration // Exponential backoff factor for retries. BackoffCap time.Duration // Maximum connection retry backoff value. RetryLimit uint // Maximum number of retries, or 0 for unlimited. ConcurrentLeaderConns int64 // Maximum number of concurrent connections to other cluster members while probing for leadership. PermitShared bool } // RetryStrategies returns a configuration for the retry package based on a Config. func (config Config) RetryStrategies() (strategies []strategy.Strategy) { limit, factor, cap := config.RetryLimit, config.BackoffFactor, config.BackoffCap // Fix for change in behavior: https://github.com/Rican7/retry/pull/12 if limit++; limit > 1 { strategies = append(strategies, strategy.Limit(limit)) } backoffFunc := backoff.BinaryExponential(factor) strategies = append(strategies, func(attempt uint) bool { if attempt > 0 { duration := backoffFunc(attempt) // Duration might be negative in case of integer overflow. if !(0 < duration && duration <= cap) { duration = cap } time.Sleep(duration) } return true }, ) return } golang-github-canonical-go-dqlite-2.0.0/internal/protocol/connector.go000066400000000000000000000302701471100661000260440ustar00rootroot00000000000000package protocol import ( "context" "encoding/binary" "fmt" "io" "net" "sort" "sync" "time" "github.com/Rican7/retry" "github.com/canonical/go-dqlite/v2/logging" "github.com/pkg/errors" "golang.org/x/sync/semaphore" ) // MaxConcurrentLeaderConns is the default maximum number of concurrent requests to other cluster members to probe for leadership. const MaxConcurrentLeaderConns int64 = 10 // DialFunc is a function that can be used to establish a network connection. type DialFunc func(context.Context, string) (net.Conn, error) // LeaderTracker remembers the address of the cluster leader, and possibly // holds a reusable connection to it. type LeaderTracker struct { mu sync.RWMutex lastKnownLeaderAddr string proto *Protocol } func (lt *LeaderTracker) GetLeaderAddr() string { lt.mu.RLock() defer lt.mu.RUnlock() return lt.lastKnownLeaderAddr } func (lt *LeaderTracker) SetLeaderAddr(address string) { lt.mu.Lock() defer lt.mu.Unlock() lt.lastKnownLeaderAddr = address } func (lt *LeaderTracker) UnsetLeaderAddr() { lt.mu.Lock() defer lt.mu.Unlock() lt.lastKnownLeaderAddr = "" } func (lt *LeaderTracker) TakeSharedProtocol() (proto *Protocol) { lt.mu.Lock() defer lt.mu.Unlock() if proto, lt.proto = lt.proto, nil; proto != nil { proto.lt = lt } return } func (lt *LeaderTracker) DonateSharedProtocol(proto *Protocol) (accepted bool) { lt.mu.Lock() defer lt.mu.Unlock() if accepted = lt.proto == nil; accepted { lt.proto = proto } return } type Connector struct { clientID uint64 // Conn ID to use when registering against the server. store NodeStore nodeID uint64 nodeAddress string lt *LeaderTracker config Config // Connection parameters. log logging.Func // Logging function. } // NewLeaderConnector returns a Connector that will connect to the current cluster // leader.
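// A minimal usage sketch, assuming a NodeStore already populated with the // cluster addresses: // // connector := NewLeaderConnector(store, Config{}, log) // proto, err := connector.Connect(ctx) // // When done, release the connection with proto.Close().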
func NewLeaderConnector(store NodeStore, config Config, log logging.Func) *Connector { if config.Dial == nil { config.Dial = Dial } if config.DialTimeout == 0 { config.DialTimeout = 5 * time.Second } if config.AttemptTimeout == 0 { config.AttemptTimeout = 15 * time.Second } if config.BackoffFactor == 0 { config.BackoffFactor = 100 * time.Millisecond } if config.BackoffCap == 0 { config.BackoffCap = time.Second } if config.ConcurrentLeaderConns == 0 { config.ConcurrentLeaderConns = MaxConcurrentLeaderConns } return &Connector{ store: store, lt: &LeaderTracker{}, config: config, log: log, } } // NewDirectConnector returns a Connector that will connect to the node with // the given ID and address. func NewDirectConnector(id uint64, address string, config Config, log logging.Func) *Connector { if config.Dial == nil { config.Dial = Dial } if config.DialTimeout == 0 { config.DialTimeout = 5 * time.Second } if config.AttemptTimeout == 0 { config.AttemptTimeout = 15 * time.Second } if config.BackoffFactor == 0 { config.BackoffFactor = 100 * time.Millisecond } if config.BackoffCap == 0 { config.BackoffCap = time.Second } if config.ConcurrentLeaderConns == 0 { config.ConcurrentLeaderConns = MaxConcurrentLeaderConns } return &Connector{ nodeID: id, nodeAddress: address, lt: &LeaderTracker{}, config: config, log: log, } } // Connect opens a new Protocol based on the Connector's configuration. func (c *Connector) Connect(ctx context.Context) (*Protocol, error) { if c.nodeID != 0 { ctx, cancel := context.WithTimeout(ctx, c.config.AttemptTimeout) defer cancel() conn, err := c.config.Dial(ctx, c.nodeAddress) if err != nil { return nil, errors.Wrap(err, "dial") } version := VersionOne protocol, err := Handshake(ctx, conn, version, c.nodeAddress) if err == errBadProtocol { c.log(logging.Warn, "unsupported protocol %d, attempt with legacy", version) version = VersionLegacy protocol, err = Handshake(ctx, conn, version, c.nodeAddress) } if err != nil { conn.Close() return nil, errors.Wrap(err, "handshake") } return protocol, nil } if c.config.PermitShared { if sharedProto := c.lt.TakeSharedProtocol(); sharedProto != nil { if leaderAddr, err := askLeader(ctx, sharedProto); err == nil && sharedProto.addr == leaderAddr { c.log(logging.Debug, "reusing shared connection to %s", sharedProto.addr) c.lt.SetLeaderAddr(leaderAddr) return sharedProto, nil } c.log(logging.Debug, "discarding shared connection to %s", sharedProto.addr) sharedProto.Bad() sharedProto.Close() } } var protocol *Protocol err := retry.Retry(func(attempt uint) error { log := func(l logging.Level, format string, a ...interface{}) { format = fmt.Sprintf("attempt %d: ", attempt) + format c.log(l, format, a...) } if attempt > 1 { select { case <-ctx.Done(): // Stop retrying return nil default: } } var err error protocol, err = c.connectAttemptAll(ctx, log) return err }, c.config.RetryStrategies()...) if err != nil || ctx.Err() != nil { return nil, ErrNoAvailableLeader } // At this point we should have a connected protocol object, since the // retry loop didn't hit any error and the given context hasn't // expired. if protocol == nil { panic("no protocol object") } c.lt.SetLeaderAddr(protocol.addr) if c.config.PermitShared { protocol.lt = c.lt } return protocol, nil } // connectAttemptAll tries to establish a new connection to the cluster leader. // // First, if the address of the last known leader has been recorded, try // to connect to that server and confirm its leadership. 
This is a fast path // for stable clusters that avoids opening lots of connections. If that fails, // fall back to probing all servers in parallel, checking whether each // is the leader itself or knows who the leader is. func (c *Connector) connectAttemptAll(ctx context.Context, log logging.Func) (*Protocol, error) { if addr := c.lt.GetLeaderAddr(); addr != "" { // TODO In the event of failure, we could still use the second // return value to guide the next stage of the search. if proto, _, _ := c.connectAttemptOne(ctx, ctx, addr, log); proto != nil { log(logging.Debug, "server %s: connected on fast path", addr) return proto, nil } c.lt.UnsetLeaderAddr() } servers, err := c.store.Get(ctx) if err != nil { return nil, errors.Wrap(err, "get servers") } // Probe voters before standbys before spares. Only voters can potentially // be the leader, and standbys are more likely to know who the leader is // than spares since they participate more in the cluster. sort.Slice(servers, func(i, j int) bool { return servers[i].Role < servers[j].Role }) // The new context will be cancelled when we successfully connect // to the leader. The original context will be used only for net.Dial. // Motivation: threading the cancellation through to net.Dial results // in lots of warnings being logged on remote nodes when our probing // goroutines disconnect during a TLS handshake. origCtx := ctx ctx, cancel := context.WithCancel(ctx) defer cancel() leaderCh := make(chan *Protocol) sem := semaphore.NewWeighted(c.config.ConcurrentLeaderConns) wg := &sync.WaitGroup{} wg.Add(len(servers)) go func() { wg.Wait() close(leaderCh) }() for _, server := range servers { go func(server NodeInfo) { defer wg.Done() if err := sem.Acquire(ctx, 1); err != nil { log(logging.Warn, "server %s: %v", server.Address, err) return } defer sem.Release(1) protocol, leader, err := c.connectAttemptOne(origCtx, ctx, server.Address, log) if err != nil { log(logging.Warn, "server %s: %v", server.Address, err) return } else if protocol != nil { leaderCh <- protocol return } else if leader == "" { log(logging.Warn, "server %s: no known leader", server.Address) return } // Try the server that the original server thinks is the leader. log(logging.Debug, "server %s: connect to reported leader %s", server.Address, leader) protocol, _, err = c.connectAttemptOne(origCtx, ctx, leader, log) if err != nil { log(logging.Warn, "server %s: %v", leader, err) return } else if protocol == nil { log(logging.Warn, "server %s: reported leader server is not the leader", leader) return } leaderCh <- protocol }(server) } leader, ok := <-leaderCh cancel() if !ok { return nil, ErrNoAvailableLeader } log(logging.Debug, "server %s: connected on fallback path", leader.addr) for extra := range leaderCh { extra.Close() } return leader, nil } // Perform the initial handshake using the given protocol version. func Handshake(ctx context.Context, conn net.Conn, version uint64, addr string) (*Protocol, error) { // Latest protocol version. protocol := make([]byte, 8) binary.LittleEndian.PutUint64(protocol, version) // Honor the ctx deadline, if present. if deadline, ok := ctx.Deadline(); ok { conn.SetDeadline(deadline) defer conn.SetDeadline(time.Time{}) } // Perform the protocol handshake. 
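// The handshake is a single 8-byte little-endian word carrying the protocol // version; the server sends no acknowledgement, so a complete write is the only // success criterion here.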
n, err := conn.Write(protocol) if err != nil { return nil, errors.Wrap(err, "write handshake") } if n != 8 { return nil, errors.Wrap(io.ErrShortWrite, "short handshake write") } return &Protocol{conn: conn, version: version, addr: addr}, nil } // Connect to the given dqlite server and check if it's the leader. // // dialCtx is used for net.Dial; ctx is used for all other requests. // // Return values: // // - Any failure is hit: -> nil, "", err // - Target not leader and no leader known: -> nil, "", nil // - Target not leader and leader known: -> nil, leader, nil // - Target is the leader: -> server, "", nil func (c *Connector) connectAttemptOne( dialCtx context.Context, ctx context.Context, address string, origLog logging.Func, ) (*Protocol, string, error) { log := func(l logging.Level, format string, a ...interface{}) { format = fmt.Sprintf("server %s: ", address) + format origLog(l, format, a...) } if ctx.Err() != nil { return nil, "", ctx.Err() } ctx, cancel := context.WithTimeout(ctx, c.config.AttemptTimeout) defer cancel() dialCtx, cancel = context.WithTimeout(dialCtx, c.config.DialTimeout) defer cancel() // Establish the connection. conn, err := c.config.Dial(dialCtx, address) if err != nil { return nil, "", errors.Wrap(err, "dial") } version := VersionOne protocol, err := Handshake(ctx, conn, version, address) if err == errBadProtocol { log(logging.Warn, "unsupported protocol %d, attempt with legacy", version) version = VersionLegacy protocol, err = Handshake(ctx, conn, version, address) } if err != nil { conn.Close() return nil, "", err } leader, err := askLeader(ctx, protocol) if err != nil { protocol.Close() return nil, "", err } switch leader { case "": // Currently this server does not know about any leader. protocol.Close() return nil, "", nil case address: // This server is the leader, register ourselves and return. request := Message{} request.Init(16) response := Message{} response.Init(512) EncodeClient(&request, c.clientID) if err := protocol.Call(ctx, &request, &response); err != nil { protocol.Close() return nil, "", err } _, err := DecodeWelcome(&response) if err != nil { protocol.Close() return nil, "", err } // TODO: enable heartbeat // protocol.heartbeatTimeout = time.Duration(heartbeatTimeout) * time.Millisecond // go protocol.heartbeat() return protocol, "", nil default: // This server claims to know who the current leader is. protocol.Close() return nil, leader, nil } } // TODO move client logic including Leader method to Protocol, // and get rid of this. func askLeader(ctx context.Context, protocol *Protocol) (string, error) { request := Message{} request.Init(16) response := Message{} response.Init(512) EncodeLeader(&request) if err := protocol.Call(ctx, &request, &response); err != nil { cause := errors.Cause(err) // Best-effort detection of a pre-1.0 dqlite node: when sent // version 1 it should close the connection immediately. 
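// Note the operator precedence: this reads as (ok && !err.Timeout()) || cause == io.EOF, // i.e. a non-timeout network error or a clean EOF both count as the peer hanging up.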
if err, ok := cause.(*net.OpError); ok && !err.Timeout() || cause == io.EOF { return "", errBadProtocol } return "", err } _, leader, err := DecodeNodeCompat(protocol, &response) if err != nil { return "", err } return leader, nil } var errBadProtocol = fmt.Errorf("bad protocol") golang-github-canonical-go-dqlite-2.0.0/internal/protocol/connector_test.go000066400000000000000000000324561471100661000271130ustar00rootroot00000000000000package protocol_test import ( "context" "fmt" "io/ioutil" "net" "os" "testing" "time" "github.com/canonical/go-dqlite/v2/internal/bindings" "github.com/canonical/go-dqlite/v2/internal/protocol" "github.com/canonical/go-dqlite/v2/logging" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) // Successful connection. func TestConnector_Success(t *testing.T) { address, cleanup := newNode(t, 0) defer cleanup() store := newStore(t, []string{address}) log, check := newLogFunc(t) connector := protocol.NewLeaderConnector(store, protocol.Config{}, log) ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() client, err := connector.Connect(ctx) require.NoError(t, err) assert.NoError(t, client.Close()) check([]string{ "DEBUG: attempt 1: server @test-0: connected on fallback path", }) } // Check the interaction of Connector.Connect with a leader tracker. // // The leader tracker potentially stores two pieces of data, an address and a shared connection. // This gives us four states: INIT (have neither address nor connection), HAVE_ADDR, HAVE_CONN, and HAVE_BOTH. // Transitions between these states are triggered by Connector.Connect and Protocol.Close. // This test methodically triggers all the possible transitions and checks that they have // the intended externally-observable effects. func TestConnector_LeaderTracker(t *testing.T) { // options is a configuration for calling Connector.Connect // in order to trigger a specific state transition. 
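// injectFailure runs the attempt with an already-expired deadline so that it must // fail; returnProto hands the open protocol back to the caller instead of closing it.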
type options struct { injectFailure bool returnProto bool expectedLog []string } injectFailure := func(o *options) { o.injectFailure = true o.expectedLog = append(o.expectedLog, "WARN: attempt 1: server @test-0: context deadline exceeded") } returnProto := func(o *options) { o.returnProto = true } expectDiscard := func(o *options) { o.expectedLog = append(o.expectedLog, "DEBUG: discarding shared connection to @test-0") } expectFallback := func(o *options) { o.expectedLog = append(o.expectedLog, "DEBUG: attempt 1: server @test-0: connected on fallback path") } expectFast := func(o *options) { o.expectedLog = append(o.expectedLog, "DEBUG: attempt 1: server @test-0: connected on fast path") } expectShared := func(o *options) { o.expectedLog = append(o.expectedLog, "DEBUG: reusing shared connection to @test-0") } address, cleanup := newNode(t, 0) defer cleanup() store := newStore(t, []string{address}) log, checkLog := newLogFunc(t) connector := protocol.NewLeaderConnector(store, protocol.Config{RetryLimit: 1, PermitShared: true}, log) check := func(opts ...func(*options)) *protocol.Protocol { o := &options{} for _, opt := range opts { opt(o) } ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() if o.injectFailure { ctx, cancel = context.WithDeadline(ctx, time.Unix(1, 0)) defer cancel() } proto, err := connector.Connect(ctx) if o.injectFailure { require.Equal(t, protocol.ErrNoAvailableLeader, err) } else { require.NoError(t, err) } checkLog(o.expectedLog) if o.returnProto { return proto } else if err == nil { assert.NoError(t, proto.Close()) } return nil } // INIT -> INIT check(injectFailure) // INIT -> HAVE_ADDR proto := check(expectFallback, returnProto) proto.Bad() assert.NoError(t, proto.Close()) // HAVE_ADDR -> HAVE_ADDR proto = check(expectFast, returnProto) // We need an extra protocol to trigger INIT->HAVE_CONN later. // Grab one here where it doesn't cause a state transition. protoForLater := check(expectFast, returnProto) // HAVE_ADDR -> HAVE_BOTH assert.NoError(t, proto.Close()) // HAVE_BOTH -> HAVE_ADDR -> HAVE_BOTH check(expectShared) // HAVE_BOTH -> HAVE_ADDR check(expectDiscard, injectFailure) // HAVE_ADDR -> INIT check(injectFailure) // INIT -> HAVE_CONN assert.NoError(t, protoForLater.Close()) // HAVE_CONN -> HAVE_CONN check(expectShared) // HAVE_CONN -> INIT check(expectDiscard, injectFailure) } // The network connection can't be established within the specified number of // attempts. func TestConnector_LimitRetries(t *testing.T) { store := newStore(t, []string{"@test-123"}) config := protocol.Config{ RetryLimit: 2, } log, check := newLogFunc(t) connector := protocol.NewLeaderConnector(store, config, log) _, err := connector.Connect(context.Background()) assert.Equal(t, protocol.ErrNoAvailableLeader, err) check([]string{ "WARN: attempt 1: server @test-123: dial: dial unix @test-123: connect: connection refused", "WARN: attempt 2: server @test-123: dial: dial unix @test-123: connect: connection refused", "WARN: attempt 3: server @test-123: dial: dial unix @test-123: connect: connection refused", }) } // The network connection can't be established because of a connection timeout. 
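// 8.8.8.8:9000 is routable but not expected to accept connections, so the dial // blocks until the 50ms DialTimeout fires.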
func TestConnector_DialTimeout(t *testing.T) { store := newStore(t, []string{"8.8.8.8:9000"}) log, check := newLogFunc(t) config := protocol.Config{ DialTimeout: 50 * time.Millisecond, RetryLimit: 1, } connector := protocol.NewLeaderConnector(store, config, log) _, err := connector.Connect(context.Background()) assert.Equal(t, protocol.ErrNoAvailableLeader, err) check([]string{ "WARN: attempt 1: server 8.8.8.8:9000: dial: dial tcp 8.8.8.8:9000: i/o timeout", "WARN: attempt 2: server 8.8.8.8:9000: dial: dial tcp 8.8.8.8:9000: i/o timeout", }) } // Connection failed because the server store is empty. func TestConnector_EmptyNodeStore(t *testing.T) { store := newStore(t, []string{}) log, check := newLogFunc(t) connector := protocol.NewLeaderConnector(store, protocol.Config{}, log) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) defer cancel() _, err := connector.Connect(ctx) assert.Equal(t, protocol.ErrNoAvailableLeader, err) check([]string{}) } // Connection failed because the context was canceled. func TestConnector_ContextCanceled(t *testing.T) { store := newStore(t, []string{"1.2.3.4:666"}) log, check := newLogFunc(t) connector := protocol.NewLeaderConnector(store, protocol.Config{}, log) ctx, cancel := context.WithTimeout(context.Background(), 25*time.Millisecond) defer cancel() _, err := connector.Connect(ctx) assert.Equal(t, protocol.ErrNoAvailableLeader, err) check([]string{ "WARN: attempt 1: server 1.2.3.4:666: dial: dial tcp 1.2.3.4:666: i/o timeout", }) } // Simulate a server which accepts the connection but doesn't reply within the // attempt timeout. func TestConnector_AttemptTimeout(t *testing.T) { listener, err := net.Listen("unix", "@1234") require.NoError(t, err) store := newStore(t, []string{listener.Addr().String()}) config := protocol.Config{ AttemptTimeout: 100 * time.Millisecond, RetryLimit: 1, } connector := protocol.NewLeaderConnector(store, config, logging.Test(t)) var conn net.Conn go func() { conn, err = listener.Accept() require.NoError(t, err) require.NotNil(t, conn) }() defer func() { if conn != nil { _ = conn.Close() } }() _, err = connector.Connect(context.Background()) assert.Equal(t, protocol.ErrNoAvailableLeader, err) } // If an election is in progress, the connector will retry until a leader gets // elected. // func TestConnector_Connect_ElectionInProgress(t *testing.T) { // address1, cleanup := newNode(t, 1) // defer cleanup() // address2, cleanup := newNode(t, 2) // defer cleanup() // address3, cleanup := newNode(t, 3) // defer cleanup() // store := newStore(t, []string{address1, address2, address3}) // connector := newConnector(t, store) // go func() { // // Simulate server 1 winning the election after 10ms // time.Sleep(10 * time.Millisecond) // }() // ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) // defer cancel() // client, err := connector.Connect(ctx) // require.NoError(t, err) // assert.NoError(t, client.Close()) // } // If a server reports that it knows about the leader, the hint will be taken // and an attempt will be made to connect to it. 
// func TestConnector_Connect_NodeKnowsAboutLeader(t *testing.T) { // defer bindings.AssertNoMemoryLeaks(t) // methods1 := &testClusterMethods{} // methods2 := &testClusterMethods{} // methods3 := &testClusterMethods{} // address1, cleanup := newNode(t, 1, methods1) // defer cleanup() // address2, cleanup := newNode(t, 2, methods2) // defer cleanup() // address3, cleanup := newNode(t, 3, methods3) // defer cleanup() // // Node 1 will be contacted first, which will report that server 2 is // // the leader. // store := newStore(t, []string{address1, address2, address3}) // methods1.leader = address2 // methods2.leader = address2 // methods3.leader = address2 // connector := newConnector(t, store) // ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) // defer cancel() // client, err := connector.Connect(ctx) // require.NoError(t, err) // assert.NoError(t, client.Close()) // } // If a server reports that it knows about the leader, the hint will be taken // and an attempt will be made to connect to it. If that leader has died, the // next target will be tried. // func TestConnector_Connect_NodeKnowsAboutDeadLeader(t *testing.T) { // defer bindings.AssertNoMemoryLeaks(t) // methods1 := &testClusterMethods{} // methods2 := &testClusterMethods{} // methods3 := &testClusterMethods{} // address1, cleanup := newNode(t, 1, methods1) // defer cleanup() // address2, cleanup := newNode(t, 2, methods2) // // Simulate server 2 crashing. // cleanup() // address3, cleanup := newNode(t, 3, methods3) // defer cleanup() // // Node 1 will be contacted first, which will report that server 2 is // // the leader. However server 2 has crashed, and after a bit server 1 // // gets elected. // store := newStore(t, []string{address1, address2, address3}) // methods1.leader = address2 // methods3.leader = address2 // go func() { // // Simulate server 1 becoming the new leader after server 2 // // crashed. // time.Sleep(10 * time.Millisecond) // methods1.leader = address1 // methods3.leader = address1 // }() // connector := newConnector(t, store) // ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) // defer cancel() // client, err := connector.Connect(ctx) // require.NoError(t, err) // assert.NoError(t, client.Close()) // } // If a server reports that it knows about the leader, the hint will be taken // and an attempt will be made to connect to it. If that leader is not actually // the leader the next target will be tried. // func TestConnector_Connect_NodeKnowsAboutStaleLeader(t *testing.T) { // defer bindings.AssertNoMemoryLeaks(t) // methods1 := &testClusterMethods{} // methods2 := &testClusterMethods{} // methods3 := &testClusterMethods{} // address1, cleanup := newNode(t, 1, methods1) // defer cleanup() // address2, cleanup := newNode(t, 2, methods2) // defer cleanup() // address3, cleanup := newNode(t, 3, methods3) // defer cleanup() // // Node 1 will be contacted first, which will report that server 2 is // // the leader. However server 2 thinks that 3 is the leader, and server // // 3 is actually the leader. 
// store := newStore(t, []string{address1, address2, address3}) // methods1.leader = address2 // methods2.leader = address3 // methods3.leader = address3 // connector := newConnector(t, store) // ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) // defer cancel() // client, err := connector.Connect(ctx) // require.NoError(t, err) // assert.NoError(t, client.Close()) // } // Return a log function that emits messages using the test logger as well as // collecting them into a slice. The second function returned can be used to // assert that the collected messages match the given ones. func newLogFunc(t *testing.T) (logging.Func, func([]string)) { messages := []string{} log := func(l logging.Level, format string, a ...interface{}) { message := l.String() + ": " + fmt.Sprintf(format, a...) messages = append(messages, message) t.Log(message) } check := func(expected []string) { assert.Equal(t, expected, messages) messages = messages[:0] } return log, check } // Create a new in-memory server store populated with the given addresses. func newStore(t *testing.T, addresses []string) protocol.NodeStore { t.Helper() servers := make([]protocol.NodeInfo, len(addresses)) for i, address := range addresses { servers[i].ID = uint64(i) servers[i].Address = address } store := protocol.NewInmemNodeStore() require.NoError(t, store.Set(context.Background(), servers)) return store } func newNode(t *testing.T, index int) (string, func()) { t.Helper() id := uint64(index + 1) dir, dirCleanup := newDir(t) address := fmt.Sprintf("@test-%d", index) server, err := bindings.NewNode(context.Background(), id, address, dir) require.NoError(t, err) err = server.SetBindAddress(address) require.NoError(t, err) require.NoError(t, server.Start()) cleanup := func() { require.NoError(t, server.Stop()) server.Close() dirCleanup() } return address, cleanup } // Return a new temporary directory. func newDir(t *testing.T) (string, func()) { t.Helper() dir, err := ioutil.TempDir("", "dqlite-connector-test-") assert.NoError(t, err) cleanup := func() { _, err := os.Stat(dir) if err != nil { assert.True(t, os.IsNotExist(err)) } else { assert.NoError(t, os.RemoveAll(dir)) } } return dir, cleanup } golang-github-canonical-go-dqlite-2.0.0/internal/protocol/constants.go000066400000000000000000000053771471100661000261000ustar00rootroot00000000000000package protocol // VersionOne is version 1 of the server protocol. const VersionOne = uint64(1) // VersionLegacy is the pre 1.0 dqlite server protocol version. const VersionLegacy = uint64(0x86104dd760433fe5) // Cluster response formats const ( ClusterFormatV0 = 0 ClusterFormatV1 = 1 ) // Node roles const ( Voter = NodeRole(0) StandBy = NodeRole(1) Spare = NodeRole(2) ) // SQLite datatype codes const ( Integer = 1 Float = 2 Text = 3 Blob = 4 Null = 5 ) // Special data types for time values. const ( UnixTime = 9 ISO8601 = 10 Boolean = 11 ) // Request types. const ( RequestLeader = 0 RequestClient = 1 RequestHeartbeat = 2 RequestOpen = 3 RequestPrepare = 4 RequestExec = 5 RequestQuery = 6 RequestFinalize = 7 RequestExecSQL = 8 RequestQuerySQL = 9 RequestInterrupt = 10 RequestAdd = 12 RequestAssign = 13 RequestRemove = 14 RequestDump = 15 RequestCluster = 16 RequestTransfer = 17 RequestDescribe = 18 RequestWeight = 19 ) // Formats const ( RequestDescribeFormatV0 = 0 ) // Response types. 
const ( ResponseFailure = 0 ResponseNode = 1 ResponseNodeLegacy = 1 ResponseWelcome = 2 ResponseNodes = 3 ResponseDb = 4 ResponseStmt = 5 ResponseResult = 6 ResponseRows = 7 ResponseEmpty = 8 ResponseFiles = 9 ResponseMetadata = 10 ) // Human-readable description of a request type. func requestDesc(code uint8) string { switch code { // Requests case RequestLeader: return "leader" case RequestClient: return "client" case RequestHeartbeat: return "heartbeat" case RequestOpen: return "open" case RequestPrepare: return "prepare" case RequestExec: return "exec" case RequestQuery: return "query" case RequestFinalize: return "finalize" case RequestExecSQL: return "exec-sql" case RequestQuerySQL: return "query-sql" case RequestInterrupt: return "interrupt" case RequestAdd: return "add" case RequestAssign: return "assign" case RequestRemove: return "remove" case RequestDump: return "dump" case RequestCluster: return "cluster" case RequestTransfer: return "transfer" case RequestDescribe: return "describe" } return "unknown" } // Human-readable description of a response type. func responseDesc(code uint8) string { switch code { case ResponseFailure: return "failure" case ResponseNode: return "node" case ResponseWelcome: return "welcome" case ResponseNodes: return "nodes" case ResponseDb: return "db" case ResponseStmt: return "stmt" case ResponseResult: return "result" case ResponseRows: return "rows" case ResponseEmpty: return "empty" case ResponseFiles: return "files" case ResponseMetadata: return "metadata" } return "unknown" } golang-github-canonical-go-dqlite-2.0.0/internal/protocol/dial.go000066400000000000000000000005241471100661000247620ustar00rootroot00000000000000package protocol import ( "context" "net" "strings" ) // Dial function handling plain TCP and Unix socket endpoints. func Dial(ctx context.Context, address string) (net.Conn, error) { family := "tcp" if strings.HasPrefix(address, "@") { family = "unix" } dialer := net.Dialer{} return dialer.DialContext(ctx, family, address) } golang-github-canonical-go-dqlite-2.0.0/internal/protocol/errors.go000066400000000000000000000013701471100661000253650ustar00rootroot00000000000000package protocol import ( "fmt" ) // Client errors. var ( ErrNoAvailableLeader = fmt.Errorf("no available dqlite leader server found") errNegativeRead = fmt.Errorf("reader returned negative count from Read") ) // ErrRequest is returned in case of request failure. type ErrRequest struct { Code uint64 Description string } func (e ErrRequest) Error() string { return fmt.Sprintf("%s (%d)", e.Description, e.Code) } // ErrRowsPart is returned when the first batch of a multi-response result // batch is done. var ErrRowsPart = fmt.Errorf("not all rows were returned in this response") // Error holds information about a SQLite error. type Error struct { Code int Message string } func (e Error) Error() string { return e.Message } golang-github-canonical-go-dqlite-2.0.0/internal/protocol/message.go000066400000000000000000000346221471100661000255030ustar00rootroot00000000000000package protocol import ( "bytes" "database/sql/driver" "encoding/binary" "fmt" "io" "math" "strings" "time" ) // NamedValues is a type alias of a slice of driver.NamedValue. It's used by // schema.sh to generate encoding logic for statement parameters. type NamedValues = []driver.NamedValue type NamedValues32 = []driver.NamedValue // Nodes is a type alias of a slice of NodeInfo. It's used by schema.sh to // generate decoding logic for the heartbeat response. 
type Nodes []NodeInfo // Message holds data about a single request or response. type Message struct { words uint32 mtype uint8 schema uint8 extra uint16 header []byte // Statically allocated header buffer body buffer // Message body data. } // Init initializes the message using the given initial size for the data // buffer, which is re-used across requests or responses encoded or decoded // using this message object. func (m *Message) Init(initialBufferSize int) { if (initialBufferSize % messageWordSize) != 0 { panic("initial buffer size is not aligned to word boundary") } m.header = make([]byte, messageHeaderSize) m.body.Bytes = make([]byte, initialBufferSize) m.reset() } // Reset the state of the message so it can be used to encode or decode again. func (m *Message) reset() { m.words = 0 m.mtype = 0 m.schema = 0 m.extra = 0 for i := 0; i < messageHeaderSize; i++ { m.header[i] = 0 } m.body.Offset = 0 } // Append a byte slice to the message. func (m *Message) putBlob(v []byte) { size := len(v) m.putUint64(uint64(size)) pad := 0 if (size % messageWordSize) != 0 { // Account for padding pad = messageWordSize - (size % messageWordSize) size += pad } b := m.bufferForPut(size) defer b.Advance(size) // Copy the bytes into the buffer. offset := b.Offset copy(b.Bytes[offset:], v) offset += len(v) // Add padding for i := 0; i < pad; i++ { b.Bytes[offset] = 0 offset++ } } // Append a string to the message. func (m *Message) putString(v string) { size := len(v) + 1 pad := 0 if (size % messageWordSize) != 0 { // Account for padding pad = messageWordSize - (size % messageWordSize) size += pad } b := m.bufferForPut(size) defer b.Advance(size) // Copy the string bytes into the buffer. offset := b.Offset copy(b.Bytes[offset:], v) offset += len(v) // Add a nul byte b.Bytes[offset] = 0 offset++ // Add padding for i := 0; i < pad; i++ { b.Bytes[offset] = 0 offset++ } } // Append a byte to the message. func (m *Message) putUint8(v uint8) { b := m.bufferForPut(1) defer b.Advance(1) b.Bytes[b.Offset] = v } // Append a 2-byte word to the message. func (m *Message) putUint16(v uint16) { b := m.bufferForPut(2) defer b.Advance(2) binary.LittleEndian.PutUint16(b.Bytes[b.Offset:], v) } // Append a 4-byte word to the message. func (m *Message) putUint32(v uint32) { b := m.bufferForPut(4) defer b.Advance(4) binary.LittleEndian.PutUint32(b.Bytes[b.Offset:], v) } // Append an 8-byte word to the message. func (m *Message) putUint64(v uint64) { b := m.bufferForPut(8) defer b.Advance(8) binary.LittleEndian.PutUint64(b.Bytes[b.Offset:], v) } // Append a signed 8-byte word to the message. func (m *Message) putInt64(v int64) { b := m.bufferForPut(8) defer b.Advance(8) binary.LittleEndian.PutUint64(b.Bytes[b.Offset:], uint64(v)) } // Append a floating point number to the message. 
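// The value is converted with math.Float64bits and stored as a
// little-endian 8-byte word.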
func (m *Message) putFloat64(v float64) { b := m.bufferForPut(8) defer b.Advance(8) binary.LittleEndian.PutUint64(b.Bytes[b.Offset:], math.Float64bits(v)) } func (m *Message) putNamedValuesInner(values NamedValues) { for i := range values { if values[i].Ordinal != i+1 { panic("unexpected ordinal") } switch values[i].Value.(type) { case int64: m.putUint8(Integer) case float64: m.putUint8(Float) case bool: m.putUint8(Boolean) case []byte: m.putUint8(Blob) case string: m.putUint8(Text) case nil: m.putUint8(Null) case time.Time: m.putUint8(ISO8601) default: panic("unsupported value type") } } b := m.bufferForPut(1) if trailing := b.Offset % messageWordSize; trailing != 0 { // Skip padding bytes b.Advance(messageWordSize - trailing) } for i := range values { switch v := values[i].Value.(type) { case int64: m.putInt64(v) case float64: m.putFloat64(v) case bool: if v { m.putUint64(1) } else { m.putUint64(0) } case []byte: m.putBlob(v) case string: m.putString(v) case nil: m.putInt64(0) case time.Time: timestamp := v.Format(iso8601Formats[0]) m.putString(timestamp) default: panic("unsupported value type") } } } // Encode the given driver values as binding parameters. func (m *Message) putNamedValues(values NamedValues) { l := len(values) if l == 0 { return } else if l > math.MaxUint8 { // safeguard, should have been checked beforehand. panic("too many parameters") } n := uint8(l) m.putUint8(n) m.putNamedValuesInner(values) } // Encode the given driver values as binding parameters, with a 32-bit // parameter count (new format). func (m *Message) putNamedValues32(values NamedValues) { l := len(values) if l == 0 { return } else if int64(l) > math.MaxUint32 { // safeguard, should have been checked beforehand. panic("too many parameters") } n := uint32(l) m.putUint32(n) m.putNamedValuesInner(values) } // Finalize the message by setting the message type and the number // of words in the body (calculated from the body size). func (m *Message) putHeader(mtype, schema uint8) { if m.body.Offset <= 0 { panic("static offset is not positive") } if (m.body.Offset % messageWordSize) != 0 { panic("static body is not aligned") } m.mtype = mtype m.schema = schema m.extra = 0 m.words = uint32(m.body.Offset) / messageWordSize m.finalize() } func (m *Message) finalize() { if m.words == 0 { panic("empty message body") } binary.LittleEndian.PutUint32(m.header[0:], m.words) m.header[4] = m.mtype m.header[5] = m.schema binary.LittleEndian.PutUint16(m.header[6:], m.extra) } func (m *Message) bufferForPut(size int) *buffer { for (m.body.Offset + size) > len(m.body.Bytes) { // Grow message buffer. bytes := make([]byte, len(m.body.Bytes)*2) copy(bytes, m.body.Bytes) m.body.Bytes = bytes } return &m.body } // Return the message type and its schema version. func (m *Message) getHeader() (uint8, uint8) { return m.mtype, m.schema } // Read a string from the message body. func (m *Message) getString() string { b := m.bufferForGet() index := bytes.IndexByte(b.Bytes[b.Offset:], 0) if index == -1 { panic("no string found") } s := string(b.Bytes[b.Offset : b.Offset+index]) index++ if trailing := index % messageWordSize; trailing != 0 { // Account for padding, moving index to the next word boundary. index += messageWordSize - trailing } b.Advance(index) return s } func (m *Message) getBlob() []byte { size := m.getUint64() data := make([]byte, size) b := m.bufferForGet() defer b.Advance(int(alignUp(size, messageWordSize))) copy(data, b.Bytes[b.Offset:]) return data } // Read a byte from the message body. 
func (m *Message) getUint8() uint8 { b := m.bufferForGet() defer b.Advance(1) return b.Bytes[b.Offset] } // Read a 4-byte word from the message body. func (m *Message) getUint32() uint32 { b := m.bufferForGet() defer b.Advance(4) return binary.LittleEndian.Uint32(b.Bytes[b.Offset:]) } // Read an 8-byte word from the message body. func (m *Message) getUint64() uint64 { b := m.bufferForGet() defer b.Advance(8) return binary.LittleEndian.Uint64(b.Bytes[b.Offset:]) } // Read a signed 8-byte word from the message body. func (m *Message) getInt64() int64 { b := m.bufferForGet() defer b.Advance(8) return int64(binary.LittleEndian.Uint64(b.Bytes[b.Offset:])) } // Read a floating point number from the message body. func (m *Message) getFloat64() float64 { b := m.bufferForGet() defer b.Advance(8) return math.Float64frombits(binary.LittleEndian.Uint64(b.Bytes[b.Offset:])) } // Decode a list of server objects from the message body. func (m *Message) getNodes() Nodes { n := m.getUint64() servers := make(Nodes, n) for i := 0; i < int(n); i++ { servers[i].ID = m.getUint64() servers[i].Address = m.getString() servers[i].Role = NodeRole(m.getUint64()) } return servers } // Decode a statement result object from the message body. func (m *Message) getResult() Result { return Result{ LastInsertID: m.getUint64(), RowsAffected: m.getUint64(), } } // Decode a query result set object from the message body. func (m *Message) getRows() Rows { // Read the column count and column names. columns := make([]string, m.getUint64()) for i := range columns { columns[i] = m.getString() } rows := Rows{ Columns: columns, message: m, } return rows } func (m *Message) getFiles() Files { files := Files{ n: m.getUint64(), message: m, } return files } func (m *Message) hasBeenConsumed() bool { size := int(m.words * messageWordSize) return m.body.Offset == size } func (m *Message) lastByte() byte { size := int(m.words * messageWordSize) return m.body.Bytes[size-1] } func (m *Message) bufferForGet() *buffer { size := int(m.words * messageWordSize) // The message body has been exhausted: reading any further would mean the // message is shorter than its header declares. if m.body.Offset == size { err := fmt.Errorf("short message: type=%d words=%d off=%d", m.mtype, m.words, m.body.Offset) panic(err) } return &m.body } // Result holds the result of a statement. type Result struct { LastInsertID uint64 RowsAffected uint64 } // Rows holds a result set encoded in a message body. type Rows struct { Columns []string message *Message types []uint8 } // columnTypes returns the row's column types. // If save is true, it restores the buffer offset afterwards. func (r *Rows) columnTypes(save bool) ([]uint8, error) { // Use the cached values if possible when not advancing the buffer offset. if save && r.types != nil { return r.types, nil } // Column types should never change between rows; keep a cached copy so // they can still be retrieved when no more rows are left. if r.types == nil { r.types = make([]uint8, len(r.Columns)) } // If there are zero columns, no rows can be encoded or decoded, // so we signal EOF immediately. if len(r.types) == 0 { return r.types, io.EOF } // Each column needs a 4-bit slot to store the column type. The row // header must be padded to reach a word boundary. headerBits := len(r.types) * 4 padBits := 0 if trailingBits := (headerBits % messageWordBits); trailingBits != 0 { padBits = (messageWordBits - trailingBits) } headerSize := (headerBits + padBits) / messageWordBits * messageWordSize for i := 0; i < headerSize; i++ { slot := r.message.getUint8() if slot == 0xee { // More rows are available.
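// When save is set, rewind the header bytes consumed so far so that a
// later call can decode the row header again from the start.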
if save { r.message.bufferForGet().Advance(-(i + 1)) } return r.types, ErrRowsPart } if slot == 0xff { // Rows EOF marker if save { r.message.bufferForGet().Advance(-(i + 1)) } return r.types, io.EOF } index := i * 2 if index >= len(r.types) { continue // This is padding. } r.types[index] = slot & 0x0f index++ if index >= len(r.types) { continue // This is padding. } r.types[index] = slot >> 4 } if save { r.message.bufferForGet().Advance(-headerSize) } return r.types, nil } // Next returns the next row in the result set. func (r *Rows) Next(dest []driver.Value) error { types, err := r.columnTypes(false) if err != nil { return err } for i := range types { switch types[i] { case Integer: dest[i] = r.message.getInt64() case Float: dest[i] = r.message.getFloat64() case Blob: dest[i] = r.message.getBlob() case Text: dest[i] = r.message.getString() case Null: r.message.getUint64() dest[i] = nil case UnixTime: timestamp := time.Unix(r.message.getInt64(), 0) dest[i] = timestamp case ISO8601: value := r.message.getString() if value == "" { dest[i] = nil break } var t time.Time var timeVal time.Time var err error value = strings.TrimSuffix(value, "Z") for _, format := range iso8601Formats { if timeVal, err = time.ParseInLocation(format, value, time.UTC); err == nil { t = timeVal break } } if err != nil { return err } dest[i] = t case Boolean: dest[i] = r.message.getInt64() != 0 default: panic("unknown data type") } } return nil } // Close the result set and reset the underlying message. func (r *Rows) Close() error { // If we didn't go through all rows, let's look at the last byte. var err error if !r.message.hasBeenConsumed() { slot := r.message.lastByte() if slot == 0xee { // More rows are available. err = ErrRowsPart } else if slot == 0xff { // Rows EOF marker err = io.EOF } else { err = fmt.Errorf("unexpected end of message") } } r.message.reset() return err } // Files holds a set of files encoded in a message body. type Files struct { n uint64 message *Message } func (f *Files) Next() (string, []byte) { if f.n == 0 { return "", nil } f.n-- name := f.message.getString() length := f.message.getUint64() data := make([]byte, length) for i := 0; i < int(length); i++ { data[i] = f.message.getUint8() } return name, data } func (f *Files) Close() { f.message.reset() } const ( messageWordSize = 8 messageWordBits = messageWordSize * 8 messageHeaderSize = messageWordSize messageMaxConsecutiveEmptyReads = 100 ) var iso8601Formats = []string{ // By default, store timestamps with whatever timezone they come with. // When parsed, they will be returned with the same timezone. "2006-01-02 15:04:05.999999999-07:00", "2006-01-02T15:04:05.999999999-07:00", "2006-01-02 15:04:05.999999999", "2006-01-02T15:04:05.999999999", "2006-01-02 15:04:05", "2006-01-02T15:04:05", "2006-01-02 15:04", "2006-01-02T15:04", "2006-01-02", } // ColumnTypes returns the column types for the result set. func (r *Rows) ColumnTypes() ([]string, error) { types, err := r.columnTypes(true) kinds := make([]string, len(types)) for i, t := range types { switch t { case Integer: kinds[i] = "INTEGER" case Float: kinds[i] = "FLOAT" case Blob: kinds[i] = "BLOB" case Text: kinds[i] = "TEXT" case Null: kinds[i] = "NULL" case UnixTime: kinds[i] = "TIME" case ISO8601: kinds[i] = "TIME" case Boolean: kinds[i] = "BOOL" default: return nil, fmt.Errorf("unknown data type: %d", t) } } return kinds, err } // alignUp rounds n up to a multiple of a. a must be a power of 2.
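// For example, alignUp(5, 8) == 8 and alignUp(16, 8) == 16.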
func alignUp(n, a uint64) uint64 { return (n + a - 1) &^ (a - 1) } golang-github-canonical-go-dqlite-2.0.0/internal/protocol/message_export_test.go000066400000000000000000000002241471100661000301320ustar00rootroot00000000000000package protocol func (m *Message) Body() ([]byte, int) { return m.body.Bytes, m.body.Offset } func (m *Message) Rewind() { m.body.Offset = 0 } golang-github-canonical-go-dqlite-2.0.0/internal/protocol/message_internal_test.go000066400000000000000000000164201471100661000304320ustar00rootroot00000000000000package protocol import ( "fmt" "testing" "time" "unsafe" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestMessage_StaticBytesAlignment(t *testing.T) { message := Message{} message.Init(4096) pointer := uintptr(unsafe.Pointer(&message.body.Bytes[0])) assert.Equal(t, uintptr(0), pointer%messageWordSize) } func TestMessage_putBlob(t *testing.T) { cases := []struct { Blob []byte Offset int }{ {[]byte{1, 2, 3, 4, 5}, 16}, {[]byte{1, 2, 3, 4, 5, 6, 7, 8}, 16}, {[]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, 24}, } message := Message{} message.Init(64) for _, c := range cases { t.Run(fmt.Sprintf("%d", c.Offset), func(t *testing.T) { message.putBlob(c.Blob) bytes, offset := message.Body() assert.Equal(t, bytes[8:len(c.Blob)+8], c.Blob) assert.Equal(t, offset, c.Offset) message.reset() }) } } func TestMessage_putString(t *testing.T) { cases := []struct { String string Offset int }{ {"hello", 8}, {"hello!!", 8}, {"hello world", 16}, } message := Message{} message.Init(16) for _, c := range cases { t.Run(c.String, func(t *testing.T) { message.putString(c.String) bytes, offset := message.Body() assert.Equal(t, string(bytes[:len(c.String)]), c.String) assert.Equal(t, offset, c.Offset) message.reset() }) } } func TestMessage_putUint8(t *testing.T) { message := Message{} message.Init(8) v := uint8(12) message.putUint8(v) bytes, offset := message.Body() assert.Equal(t, bytes[0], byte(v)) assert.Equal(t, offset, 1) } func TestMessage_putUint16(t *testing.T) { message := Message{} message.Init(8) v := uint16(666) message.putUint16(v) bytes, offset := message.Body() assert.Equal(t, bytes[0], byte((v & 0x00ff))) assert.Equal(t, bytes[1], byte((v&0xff00)>>8)) assert.Equal(t, offset, 2) } func TestMessage_putUint32(t *testing.T) { message := Message{} message.Init(8) v := uint32(130000) message.putUint32(v) bytes, offset := message.Body() assert.Equal(t, bytes[0], byte((v & 0x000000ff))) assert.Equal(t, bytes[1], byte((v&0x0000ff00)>>8)) assert.Equal(t, bytes[2], byte((v&0x00ff0000)>>16)) assert.Equal(t, bytes[3], byte((v&0xff000000)>>24)) assert.Equal(t, offset, 4) } func TestMessage_putUint64(t *testing.T) { message := Message{} message.Init(8) v := uint64(5000000000) message.putUint64(v) bytes, offset := message.Body() assert.Equal(t, bytes[0], byte((v & 0x00000000000000ff))) assert.Equal(t, bytes[1], byte((v&0x000000000000ff00)>>8)) assert.Equal(t, bytes[2], byte((v&0x0000000000ff0000)>>16)) assert.Equal(t, bytes[3], byte((v&0x00000000ff000000)>>24)) assert.Equal(t, bytes[4], byte((v&0x000000ff00000000)>>32)) assert.Equal(t, bytes[5], byte((v&0x0000ff0000000000)>>40)) assert.Equal(t, bytes[6], byte((v&0x00ff000000000000)>>48)) assert.Equal(t, bytes[7], byte((v&0xff00000000000000)>>56)) assert.Equal(t, offset, 8) } func TestMessage_putNamedValues(t *testing.T) { message := Message{} message.Init(256) timestamp, err := time.ParseInLocation("2006-01-02", "2018-08-01", time.UTC) require.NoError(t, err) values := NamedValues{ {Ordinal: 1, Value: int64(123)}, 
{Ordinal: 2, Value: float64(3.1415)}, {Ordinal: 3, Value: true}, {Ordinal: 4, Value: []byte{1, 2, 3, 4, 5, 6}}, {Ordinal: 5, Value: "hello"}, {Ordinal: 6, Value: nil}, {Ordinal: 7, Value: timestamp}, } message.putNamedValues(values) bytes, offset := message.Body() assert.Equal(t, 96, offset) assert.Equal(t, bytes[0], byte(7)) assert.Equal(t, bytes[1], byte(Integer)) assert.Equal(t, bytes[2], byte(Float)) assert.Equal(t, bytes[3], byte(Boolean)) assert.Equal(t, bytes[4], byte(Blob)) assert.Equal(t, bytes[5], byte(Text)) assert.Equal(t, bytes[6], byte(Null)) assert.Equal(t, bytes[7], byte(ISO8601)) } func TestMessage_putNamedValues32(t *testing.T) { message := Message{} message.Init(256) timestamp, err := time.ParseInLocation("2006-01-02", "2018-08-01", time.UTC) require.NoError(t, err) values := NamedValues{ {Ordinal: 1, Value: int64(123)}, {Ordinal: 2, Value: float64(3.1415)}, {Ordinal: 3, Value: true}, {Ordinal: 4, Value: []byte{1, 2, 3, 4, 5, 6}}, {Ordinal: 5, Value: "hello"}, {Ordinal: 6, Value: nil}, {Ordinal: 7, Value: timestamp}, } message.putNamedValues32(values) bytes, offset := message.Body() assert.Equal(t, 104, offset) assert.Equal(t, bytes[0], byte(7)) assert.Equal(t, bytes[1], byte(0)) assert.Equal(t, bytes[2], byte(0)) assert.Equal(t, bytes[3], byte(0)) assert.Equal(t, bytes[4], byte(Integer)) assert.Equal(t, bytes[5], byte(Float)) assert.Equal(t, bytes[6], byte(Boolean)) assert.Equal(t, bytes[7], byte(Blob)) assert.Equal(t, bytes[8], byte(Text)) assert.Equal(t, bytes[9], byte(Null)) assert.Equal(t, bytes[10], byte(ISO8601)) } func TestMessage_putHeader(t *testing.T) { message := Message{} message.Init(64) message.putString("hello") message.putHeader(RequestExec, 1) } func BenchmarkMessage_putString(b *testing.B) { message := Message{} message.Init(4096) b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { message.reset() message.putString("hello") } } func BenchmarkMessage_putUint64(b *testing.B) { message := Message{} message.Init(4096) b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { message.reset() message.putUint64(270) } } func TestMessage_getString(t *testing.T) { cases := []struct { String string Offset int }{ {"hello", 8}, {"hello!!", 8}, {"hello!!!", 16}, {"hello world", 16}, } for _, c := range cases { t.Run(c.String, func(t *testing.T) { message := Message{} message.Init(16) message.putString(c.String) message.putHeader(0, 0) message.Rewind() s := message.getString() _, offset := message.Body() assert.Equal(t, s, c.String) assert.Equal(t, offset, c.Offset) }) } } func TestMessage_getBlob(t *testing.T) { cases := []struct { Blob []byte Offset int }{ {[]byte{1, 2, 3, 4, 5}, 16}, {[]byte{1, 2, 3, 4, 5, 6, 7, 8}, 16}, {[]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, 24}, } for _, c := range cases { t.Run(fmt.Sprintf("%d", c.Offset), func(t *testing.T) { message := Message{} message.Init(64) message.putBlob(c.Blob) message.putHeader(0, 0) message.Rewind() bytes := message.getBlob() _, offset := message.Body() assert.Equal(t, bytes, c.Blob) assert.Equal(t, offset, c.Offset) }) } } func BenchmarkMessage_getBlob(b *testing.B) { makeBlob := func(size int) []byte { blob := make([]byte, size) for i := range blob { blob[i] = byte(i) } return blob } for _, size := range []int{16, 64, 256, 1024, 4096, 8096} { b.Run(fmt.Sprintf("%d", size), func(b *testing.B) { message := Message{} message.Init(size + 16) message.putBlob(makeBlob(size)) message.putHeader(0, 0) for i := 0; i < b.N; i++ { message.Rewind() _ = message.getBlob() } }) } } // The overflowing string ends exactly 
at word boundary. func TestMessage_getString_Overflow_WordBoundary(t *testing.T) { message := Message{} message.Init(8) message.putBlob([]byte{ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 0, 0, 0, 0, 0, 0, 0, }) message.putHeader(0, 0) message.Rewind() message.getUint64() s := message.getString() assert.Equal(t, "abcdefghilmnopqr", s) assert.Equal(t, 32, message.body.Offset) } golang-github-canonical-go-dqlite-2.0.0/internal/protocol/protocol.go000066400000000000000000000130131471100661000257100ustar00rootroot00000000000000package protocol import ( "context" "encoding/binary" "io" "net" "sync" "time" "github.com/pkg/errors" ) // Protocol sends and receives dqlite messages on the wire. type Protocol struct { version uint64 // Protocol version conn net.Conn // Underlying network connection. mu sync.Mutex // Serialize requests netErr error // A network error occurred addr string lt *LeaderTracker } // Call invokes a dqlite RPC, sending a request message and receiving a // response message. func (p *Protocol) Call(ctx context.Context, request, response *Message) (err error) { // We need to take a lock since the dqlite server currently does not // support concurrent requests. p.mu.Lock() defer p.mu.Unlock() if err = p.netErr; err != nil { return } defer func() { if err == nil { return } p.Bad() switch errors.Cause(err).(type) { case *net.OpError: p.netErr = err } }() var budget time.Duration // Honor the ctx deadline, if present. if deadline, ok := ctx.Deadline(); ok { p.conn.SetDeadline(deadline) budget = time.Until(deadline) defer p.conn.SetDeadline(time.Time{}) } desc := requestDesc(request.mtype) if err = p.send(request); err != nil { return errors.Wrapf(err, "call %s (budget %s): send", desc, budget) } if err = p.recv(response); err != nil { return errors.Wrapf(err, "call %s (budget %s): receive", desc, budget) } return } // More is used when a request maps to multiple responses. func (p *Protocol) More(ctx context.Context, response *Message) (err error) { if err = p.recv(response); err != nil { p.Bad() } return } // Interrupt sends an interrupt request and waits for the server's empty // response. func (p *Protocol) Interrupt(ctx context.Context, request *Message, response *Message) (err error) { // We need to take a lock since the dqlite server currently does not // support concurrent requests. p.mu.Lock() defer p.mu.Unlock() // Honor the ctx deadline, if present. if deadline, ok := ctx.Deadline(); ok { p.conn.SetDeadline(deadline) defer p.conn.SetDeadline(time.Time{}) } EncodeInterrupt(request, 0) defer func() { if err != nil { p.Bad() } }() if err = p.send(request); err != nil { return errors.Wrap(err, "failed to send interrupt request") } for { if err = p.recv(response); err != nil { return errors.Wrap(err, "failed to receive response") } mtype, _ := response.getHeader() if mtype == ResponseEmpty { break } } return nil } // Bad prevents a protocol from being reused when it is released. // // There is no need to call Bad after a method of Protocol returns an error. // Only call Bad when the protocol is deemed unsuitable for reuse for some // higher-level reason. func (p *Protocol) Bad() { p.lt = nil } // Close releases a protocol. // // If the protocol was associated with a LeaderTracker, it will be made // available for reuse by that tracker. Otherwise, the underlying connection // will be closed.
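// In particular, a protocol that was marked Bad no longer has a tracker,
// so its connection always gets closed for real.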
func (p *Protocol) Close() error { if tr := p.lt; tr == nil || !tr.DonateSharedProtocol(p) { return p.conn.Close() } return nil } func (p *Protocol) send(req *Message) error { if err := p.sendHeader(req); err != nil { return errors.Wrap(err, "header") } if err := p.sendBody(req); err != nil { return errors.Wrap(err, "body") } return nil } func (p *Protocol) sendHeader(req *Message) error { n, err := p.conn.Write(req.header[:]) if err != nil { return err } if n != messageHeaderSize { return io.ErrShortWrite } return nil } func (p *Protocol) sendBody(req *Message) error { buf := req.body.Bytes[:req.body.Offset] n, err := p.conn.Write(buf) if err != nil { return err } if n != len(buf) { return io.ErrShortWrite } return nil } func (p *Protocol) recv(res *Message) error { res.reset() if err := p.recvHeader(res); err != nil { return errors.Wrap(err, "header") } if err := p.recvBody(res); err != nil { return errors.Wrap(err, "body") } return nil } func (p *Protocol) recvHeader(res *Message) error { if err := p.recvPeek(res.header); err != nil { return err } res.words = binary.LittleEndian.Uint32(res.header[0:]) res.mtype = res.header[4] res.schema = res.header[5] res.extra = binary.LittleEndian.Uint16(res.header[6:]) return nil } func (p *Protocol) recvBody(res *Message) error { n := int(res.words) * messageWordSize for n > len(res.body.Bytes) { // Grow message buffer. bytes := make([]byte, len(res.body.Bytes)*2) res.body.Bytes = bytes } buf := res.body.Bytes[:n] if err := p.recvPeek(buf); err != nil { return err } return nil } // Read until buf is full. func (p *Protocol) recvPeek(buf []byte) error { for offset := 0; offset < len(buf); { n, err := p.recvFill(buf[offset:]) if err != nil { return err } offset += n } return nil } // Try to fill buf, but perform at most one read. func (p *Protocol) recvFill(buf []byte) (int, error) { // Read new data: try a limited number of times. // // This technique is copied from bufio.Reader. for i := messageMaxConsecutiveEmptyReads; i > 0; i-- { n, err := p.conn.Read(buf) if n < 0 { panic(errNegativeRead) } if err != nil { return -1, err } if n > 0 { return n, nil } } return -1, io.ErrNoProgress } // DecodeNodeCompat handles also pre-1.0 legacy server messages. func DecodeNodeCompat(protocol *Protocol, response *Message) (uint64, string, error) { if protocol.version == VersionLegacy { address, err := DecodeNodeLegacy(response) if err != nil { return 0, "", err } return 0, address, nil } return DecodeNode(response) } golang-github-canonical-go-dqlite-2.0.0/internal/protocol/protocol_test.go000066400000000000000000000107251471100661000267550ustar00rootroot00000000000000package protocol_test import ( "context" "testing" "time" "github.com/canonical/go-dqlite/v2/internal/protocol" "github.com/canonical/go-dqlite/v2/logging" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) // func TestProtocol_Heartbeat(t *testing.T) { // c, cleanup := newProtocol(t) // defer cleanup() // request, response := newMessagePair(512, 512) // protocol.EncodeHeartbeat(&request, uint64(time.Now().Unix())) // makeCall(t, c, &request, &response) // servers, err := protocol.DecodeNodes(&response) // require.NoError(t, err) // assert.Len(t, servers, 2) // assert.Equal(t, client.Nodes{ // {ID: uint64(1), Address: "1.2.3.4:666"}, // {ID: uint64(2), Address: "5.6.7.8:666"}}, // servers) // } // Test sending a request that needs to be written into the dynamic buffer. 
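// The request message below is initialized with only 64 bytes, so the
// multi-statement SQL overflows it and forces bufferForPut to grow the
// buffer.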
func TestProtocol_RequestWithDynamicBuffer(t *testing.T) { p, cleanup := newProtocol(t) defer cleanup() request, response := newMessagePair(64, 64) protocol.EncodeOpen(&request, "test.db", 0, "test-0") makeCall(t, p, &request, &response) id, err := protocol.DecodeDb(&response) require.NoError(t, err) sql := ` CREATE TABLE foo (n INT); CREATE TABLE bar (n INT); CREATE TABLE egg (n INT); CREATE TABLE baz (n INT); ` protocol.EncodeExecSQLV0(&request, uint64(id), sql, nil) makeCall(t, p, &request, &response) } func TestProtocol_Prepare(t *testing.T) { c, cleanup := newProtocol(t) defer cleanup() request, response := newMessagePair(64, 64) protocol.EncodeOpen(&request, "test.db", 0, "test-0") makeCall(t, c, &request, &response) db, err := protocol.DecodeDb(&response) require.NoError(t, err) protocol.EncodePrepare(&request, uint64(db), "CREATE TABLE test (n INT)") makeCall(t, c, &request, &response) _, stmt, params, err := protocol.DecodeStmt(&response) require.NoError(t, err) assert.Equal(t, uint32(0), stmt) assert.Equal(t, uint64(0), params) } /* func TestProtocol_Exec(t *testing.T) { client, cleanup := newProtocol(t) defer cleanup() ctx, cancel := context.WithTimeout(context.Background(), 250*time.Millisecond) defer cancel() db, err := client.Open(ctx, "test.db", "volatile") require.NoError(t, err) stmt, err := client.Prepare(ctx, db.ID, "CREATE TABLE test (n INT)") require.NoError(t, err) _, err = client.Exec(ctx, db.ID, stmt.ID) require.NoError(t, err) } func TestProtocol_Query(t *testing.T) { client, cleanup := newProtocol(t) defer cleanup() ctx, cancel := context.WithTimeout(context.Background(), 250*time.Millisecond) defer cancel() db, err := client.Open(ctx, "test.db", "volatile") require.NoError(t, err) start := time.Now() stmt, err := client.Prepare(ctx, db.ID, "CREATE TABLE test (n INT)") require.NoError(t, err) _, err = client.Exec(ctx, db.ID, stmt.ID) require.NoError(t, err) _, err = client.Finalize(ctx, db.ID, stmt.ID) require.NoError(t, err) stmt, err = client.Prepare(ctx, db.ID, "INSERT INTO test VALUES(1)") require.NoError(t, err) _, err = client.Exec(ctx, db.ID, stmt.ID) require.NoError(t, err) _, err = client.Finalize(ctx, db.ID, stmt.ID) require.NoError(t, err) stmt, err = client.Prepare(ctx, db.ID, "SELECT n FROM test") require.NoError(t, err) _, err = client.Query(ctx, db.ID, stmt.ID) require.NoError(t, err) _, err = client.Finalize(ctx, db.ID, stmt.ID) require.NoError(t, err) fmt.Printf("time %s\n", time.Since(start)) } */ func newProtocol(t *testing.T) (*protocol.Protocol, func()) { t.Helper() address, serverCleanup := newNode(t, 0) store := newStore(t, []string{address}) config := protocol.Config{ AttemptTimeout: 100 * time.Millisecond, BackoffFactor: time.Millisecond, } connector := protocol.NewLeaderConnector(store, config, logging.Test(t)) ctx, cancel := context.WithTimeout(context.Background(), 250*time.Millisecond) defer cancel() proto, err := connector.Connect(ctx) require.NoError(t, err) cleanup := func() { proto.Close() serverCleanup() } return proto, cleanup } // Perform a client call. func makeCall(t *testing.T, p *protocol.Protocol, request, response *protocol.Message) { ctx, cancel := context.WithTimeout(context.Background(), 250*time.Millisecond) defer cancel() err := p.Call(ctx, request, response) require.NoError(t, err) } // Return a new message pair to be used as request and response. 
func newMessagePair(size1, size2 int) (protocol.Message, protocol.Message) { message1 := protocol.Message{} message1.Init(size1) message2 := protocol.Message{} message2.Init(size2) return message1, message2 } golang-github-canonical-go-dqlite-2.0.0/internal/protocol/request.go000066400000000000000000000117411471100661000255440ustar00rootroot00000000000000package protocol // DO NOT EDIT // // This file was generated by ./schema.sh // EncodeLeader encodes a Leader request. func EncodeLeader(request *Message) { request.reset() request.putUint64(0) request.putHeader(RequestLeader, 0) } // EncodeClient encodes a Client request. func EncodeClient(request *Message, id uint64) { request.reset() request.putUint64(id) request.putHeader(RequestClient, 0) } // EncodeHeartbeat encodes a Heartbeat request. func EncodeHeartbeat(request *Message, timestamp uint64) { request.reset() request.putUint64(timestamp) request.putHeader(RequestHeartbeat, 0) } // EncodeOpen encodes a Open request. func EncodeOpen(request *Message, name string, flags uint64, vfs string) { request.reset() request.putString(name) request.putUint64(flags) request.putString(vfs) request.putHeader(RequestOpen, 0) } // EncodePrepare encodes a Prepare request. func EncodePrepare(request *Message, db uint64, sql string) { request.reset() request.putUint64(db) request.putString(sql) request.putHeader(RequestPrepare, 0) } // EncodeExecV0 encodes a Exec request. func EncodeExecV0(request *Message, db uint32, stmt uint32, values NamedValues) { request.reset() request.putUint32(db) request.putUint32(stmt) request.putNamedValues(values) request.putHeader(RequestExec, 0) } // EncodeExecV1 encodes a Exec request. func EncodeExecV1(request *Message, db uint32, stmt uint32, values NamedValues32) { request.reset() request.putUint32(db) request.putUint32(stmt) request.putNamedValues32(values) request.putHeader(RequestExec, 1) } // EncodeQueryV0 encodes a Query request. func EncodeQueryV0(request *Message, db uint32, stmt uint32, values NamedValues) { request.reset() request.putUint32(db) request.putUint32(stmt) request.putNamedValues(values) request.putHeader(RequestQuery, 0) } // EncodeQueryV1 encodes a Query request. func EncodeQueryV1(request *Message, db uint32, stmt uint32, values NamedValues32) { request.reset() request.putUint32(db) request.putUint32(stmt) request.putNamedValues32(values) request.putHeader(RequestQuery, 1) } // EncodeFinalize encodes a Finalize request. func EncodeFinalize(request *Message, db uint32, stmt uint32) { request.reset() request.putUint32(db) request.putUint32(stmt) request.putHeader(RequestFinalize, 0) } // EncodeExecSQLV0 encodes a ExecSQL request. func EncodeExecSQLV0(request *Message, db uint64, sql string, values NamedValues) { request.reset() request.putUint64(db) request.putString(sql) request.putNamedValues(values) request.putHeader(RequestExecSQL, 0) } // EncodeExecSQLV1 encodes a ExecSQL request. func EncodeExecSQLV1(request *Message, db uint64, sql string, values NamedValues32) { request.reset() request.putUint64(db) request.putString(sql) request.putNamedValues32(values) request.putHeader(RequestExecSQL, 1) } // EncodeQuerySQLV0 encodes a QuerySQL request. func EncodeQuerySQLV0(request *Message, db uint64, sql string, values NamedValues) { request.reset() request.putUint64(db) request.putString(sql) request.putNamedValues(values) request.putHeader(RequestQuerySQL, 0) } // EncodeQuerySQLV1 encodes a QuerySQL request. 
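// Schema version 1 differs from version 0 only in encoding the parameter
// count as a 32-bit rather than 8-bit word (see putNamedValues32), which
// allows more than 255 bound parameters.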
func EncodeQuerySQLV1(request *Message, db uint64, sql string, values NamedValues32) { request.reset() request.putUint64(db) request.putString(sql) request.putNamedValues32(values) request.putHeader(RequestQuerySQL, 1) } // EncodeInterrupt encodes a Interrupt request. func EncodeInterrupt(request *Message, db uint64) { request.reset() request.putUint64(db) request.putHeader(RequestInterrupt, 0) } // EncodeAdd encodes a Add request. func EncodeAdd(request *Message, id uint64, address string) { request.reset() request.putUint64(id) request.putString(address) request.putHeader(RequestAdd, 0) } // EncodeAssign encodes a Assign request. func EncodeAssign(request *Message, id uint64, role uint64) { request.reset() request.putUint64(id) request.putUint64(role) request.putHeader(RequestAssign, 0) } // EncodeRemove encodes a Remove request. func EncodeRemove(request *Message, id uint64) { request.reset() request.putUint64(id) request.putHeader(RequestRemove, 0) } // EncodeDump encodes a Dump request. func EncodeDump(request *Message, name string) { request.reset() request.putString(name) request.putHeader(RequestDump, 0) } // EncodeCluster encodes a Cluster request. func EncodeCluster(request *Message, format uint64) { request.reset() request.putUint64(format) request.putHeader(RequestCluster, 0) } // EncodeTransfer encodes a Transfer request. func EncodeTransfer(request *Message, id uint64) { request.reset() request.putUint64(id) request.putHeader(RequestTransfer, 0) } // EncodeDescribe encodes a Describe request. func EncodeDescribe(request *Message, format uint64) { request.reset() request.putUint64(format) request.putHeader(RequestDescribe, 0) } // EncodeWeight encodes a Weight request. func EncodeWeight(request *Message, weight uint64) { request.reset() request.putUint64(weight) request.putHeader(RequestWeight, 0) } golang-github-canonical-go-dqlite-2.0.0/internal/protocol/response.go000066400000000000000000000133401471100661000257070ustar00rootroot00000000000000package protocol // DO NOT EDIT // // This file was generated by ./schema.sh import "fmt" // DecodeFailure decodes a Failure response. func DecodeFailure(response *Message) (code uint64, message string, err error) { mtype, _ := response.getHeader() if mtype == ResponseFailure { e := ErrRequest{} e.Code = response.getUint64() e.Description = response.getString() err = e return } if mtype != ResponseFailure { err = fmt.Errorf("decode %s: unexpected type %d", responseDesc(ResponseFailure), mtype) return } code = response.getUint64() message = response.getString() return } // DecodeWelcome decodes a Welcome response. func DecodeWelcome(response *Message) (heartbeatTimeout uint64, err error) { mtype, _ := response.getHeader() if mtype == ResponseFailure { e := ErrRequest{} e.Code = response.getUint64() e.Description = response.getString() err = e return } if mtype != ResponseWelcome { err = fmt.Errorf("decode %s: unexpected type %d", responseDesc(ResponseWelcome), mtype) return } heartbeatTimeout = response.getUint64() return } // DecodeNodeLegacy decodes a NodeLegacy response. func DecodeNodeLegacy(response *Message) (address string, err error) { mtype, _ := response.getHeader() if mtype == ResponseFailure { e := ErrRequest{} e.Code = response.getUint64() e.Description = response.getString() err = e return } if mtype != ResponseNodeLegacy { err = fmt.Errorf("decode %s: unexpected type %d", responseDesc(ResponseNodeLegacy), mtype) return } address = response.getString() return } // DecodeNode decodes a Node response. 
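// As a rough usage sketch (assuming an already connected *Protocol p and a
// pair of initialized Messages):
//
//	EncodeLeader(&request)
//	if err := p.Call(ctx, &request, &response); err != nil { /* handle */ }
//	id, address, err := DecodeNode(&response)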
func DecodeNode(response *Message) (id uint64, address string, err error) { mtype, _ := response.getHeader() if mtype == ResponseFailure { e := ErrRequest{} e.Code = response.getUint64() e.Description = response.getString() err = e return } if mtype != ResponseNode { err = fmt.Errorf("decode %s: unexpected type %d", responseDesc(ResponseNode), mtype) return } id = response.getUint64() address = response.getString() return } // DecodeNodes decodes a Nodes response. func DecodeNodes(response *Message) (servers Nodes, err error) { mtype, _ := response.getHeader() if mtype == ResponseFailure { e := ErrRequest{} e.Code = response.getUint64() e.Description = response.getString() err = e return } if mtype != ResponseNodes { err = fmt.Errorf("decode %s: unexpected type %d", responseDesc(ResponseNodes), mtype) return } servers = response.getNodes() return } // DecodeDb decodes a Db response. func DecodeDb(response *Message) (id uint32, err error) { mtype, _ := response.getHeader() if mtype == ResponseFailure { e := ErrRequest{} e.Code = response.getUint64() e.Description = response.getString() err = e return } if mtype != ResponseDb { err = fmt.Errorf("decode %s: unexpected type %d", responseDesc(ResponseDb), mtype) return } id = response.getUint32() response.getUint32() return } // DecodeStmt decodes a Stmt response. func DecodeStmt(response *Message) (db uint32, id uint32, params uint64, err error) { mtype, _ := response.getHeader() if mtype == ResponseFailure { e := ErrRequest{} e.Code = response.getUint64() e.Description = response.getString() err = e return } if mtype != ResponseStmt { err = fmt.Errorf("decode %s: unexpected type %d", responseDesc(ResponseStmt), mtype) return } db = response.getUint32() id = response.getUint32() params = response.getUint64() return } // DecodeEmpty decodes a Empty response. func DecodeEmpty(response *Message) (err error) { mtype, _ := response.getHeader() if mtype == ResponseFailure { e := ErrRequest{} e.Code = response.getUint64() e.Description = response.getString() err = e return } if mtype != ResponseEmpty { err = fmt.Errorf("decode %s: unexpected type %d", responseDesc(ResponseEmpty), mtype) return } response.getUint64() return } // DecodeResult decodes a Result response. func DecodeResult(response *Message) (result Result, err error) { mtype, _ := response.getHeader() if mtype == ResponseFailure { e := ErrRequest{} e.Code = response.getUint64() e.Description = response.getString() err = e return } if mtype != ResponseResult { err = fmt.Errorf("decode %s: unexpected type %d", responseDesc(ResponseResult), mtype) return } result = response.getResult() return } // DecodeRows decodes a Rows response. func DecodeRows(response *Message) (rows Rows, err error) { mtype, _ := response.getHeader() if mtype == ResponseFailure { e := ErrRequest{} e.Code = response.getUint64() e.Description = response.getString() err = e return } if mtype != ResponseRows { err = fmt.Errorf("decode %s: unexpected type %d", responseDesc(ResponseRows), mtype) return } rows = response.getRows() return } // DecodeFiles decodes a Files response. func DecodeFiles(response *Message) (files Files, err error) { mtype, _ := response.getHeader() if mtype == ResponseFailure { e := ErrRequest{} e.Code = response.getUint64() e.Description = response.getString() err = e return } if mtype != ResponseFiles { err = fmt.Errorf("decode %s: unexpected type %d", responseDesc(ResponseFiles), mtype) return } files = response.getFiles() return } // DecodeMetadata decodes a Metadata response. 
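// Like the other decoders in this file, it first checks for a
// ResponseFailure payload and surfaces it as an ErrRequest.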
func DecodeMetadata(response *Message) (failureDomain uint64, weight uint64, err error) { mtype, _ := response.getHeader() if mtype == ResponseFailure { e := ErrRequest{} e.Code = response.getUint64() e.Description = response.getString() err = e return } if mtype != ResponseMetadata { err = fmt.Errorf("decode %s: unexpected type %d", responseDesc(ResponseMetadata), mtype) return } failureDomain = response.getUint64() weight = response.getUint64() return } golang-github-canonical-go-dqlite-2.0.0/internal/protocol/schema.go000066400000000000000000000050121471100661000253060ustar00rootroot00000000000000package protocol //go:generate ./schema.sh --request init //go:generate ./schema.sh --request Leader unused:uint64 //go:generate ./schema.sh --request Client id:uint64 //go:generate ./schema.sh --request Heartbeat timestamp:uint64 //go:generate ./schema.sh --request Open name:string flags:uint64 vfs:string //go:generate ./schema.sh --request Prepare db:uint64 sql:string //go:generate ./schema.sh --request Exec:0 db:uint32 stmt:uint32 values:NamedValues //go:generate ./schema.sh --request Exec:1 db:uint32 stmt:uint32 values:NamedValues32 //go:generate ./schema.sh --request Query:0 db:uint32 stmt:uint32 values:NamedValues //go:generate ./schema.sh --request Query:1 db:uint32 stmt:uint32 values:NamedValues32 //go:generate ./schema.sh --request Finalize db:uint32 stmt:uint32 //go:generate ./schema.sh --request ExecSQL:0 db:uint64 sql:string values:NamedValues //go:generate ./schema.sh --request ExecSQL:1 db:uint64 sql:string values:NamedValues32 //go:generate ./schema.sh --request QuerySQL:0 db:uint64 sql:string values:NamedValues //go:generate ./schema.sh --request QuerySQL:1 db:uint64 sql:string values:NamedValues32 //go:generate ./schema.sh --request Interrupt db:uint64 //go:generate ./schema.sh --request Add id:uint64 address:string //go:generate ./schema.sh --request Assign id:uint64 role:uint64 //go:generate ./schema.sh --request Remove id:uint64 //go:generate ./schema.sh --request Dump name:string //go:generate ./schema.sh --request Cluster format:uint64 //go:generate ./schema.sh --request Transfer id:uint64 //go:generate ./schema.sh --request Describe format:uint64 //go:generate ./schema.sh --request Weight weight:uint64 //go:generate ./schema.sh --response init //go:generate ./schema.sh --response Failure code:uint64 message:string //go:generate ./schema.sh --response Welcome heartbeatTimeout:uint64 //go:generate ./schema.sh --response NodeLegacy address:string //go:generate ./schema.sh --response Node id:uint64 address:string //go:generate ./schema.sh --response Nodes servers:Nodes //go:generate ./schema.sh --response Db id:uint32 unused:uint32 //go:generate ./schema.sh --response Stmt db:uint32 id:uint32 params:uint64 //go:generate ./schema.sh --response Empty unused:uint64 //go:generate ./schema.sh --response Result result:Result //go:generate ./schema.sh --response Rows rows:Rows //go:generate ./schema.sh --response Files files:Files //go:generate ./schema.sh --response Metadata failureDomain:uint64 weight:uint64 golang-github-canonical-go-dqlite-2.0.0/internal/protocol/schema.sh000077500000000000000000000044731471100661000253300ustar00rootroot00000000000000#!/bin/bash request_init() { cat > request.go < response.go <> request.go <> request.go <> request.go <> response.go <> response.go <> response.go < Remove a node from the cluster .describe
<address> Show the details of a node .weight <address> <weight>
Set the weight of a node .dump <address>
[<database>] Dump the database .reconfigure <dir> <clusteryaml> Reconfigure the cluster `[1:] } func (s *Shell) processCluster(ctx context.Context, line string) (string, error) { cli, err := client.FindLeader(ctx, s.store, client.WithDialFunc(s.dial)) if err != nil { return "", err } cluster, err := cli.Cluster(ctx) if err != nil { return "", err } result := "" switch s.format { case formatTabular: for i, server := range cluster { if i > 0 { result += "\n" } result += fmt.Sprintf("%x|%s|%s", server.ID, server.Address, server.Role) } case formatJson: data, err := json.Marshal(cluster) if err != nil { return "", err } var indented bytes.Buffer json.Indent(&indented, data, "", "\t") result = indented.String() } return result, nil } func (s *Shell) processLeader(ctx context.Context, line string) (string, error) { cli, err := client.FindLeader(ctx, s.store, client.WithDialFunc(s.dial)) if err != nil { return "", err } leader, err := cli.Leader(ctx) if err != nil { return "", err } if leader == nil { return "", nil } return leader.Address, nil } func (s *Shell) processRemove(ctx context.Context, line string) (string, error) { parts := strings.Split(line, " ") if len(parts) != 2 { return "", fmt.Errorf("bad command format, should be: .remove <address>
") } address := parts[1] cli, err := client.FindLeader(ctx, s.store, client.WithDialFunc(s.dial)) if err != nil { return "", err } cluster, err := cli.Cluster(ctx) if err != nil { return "", err } for _, node := range cluster { if node.Address != address { continue } if err := cli.Remove(ctx, node.ID); err != nil { return "", fmt.Errorf("remove node %q: %w", address, err) } return "", nil } return "", fmt.Errorf("no node has address %q", address) } func (s *Shell) processDescribe(ctx context.Context, line string) (string, error) { parts := strings.Split(line, " ") if len(parts) != 2 { return "", fmt.Errorf("bad command format, should be: .describe
") } address := parts[1] cli, err := client.New(ctx, address, client.WithDialFunc(s.dial)) if err != nil { return "", err } metadata, err := cli.Describe(ctx) if err != nil { return "", err } result := "" switch s.format { case formatTabular: result += fmt.Sprintf("%s|%d|%d", address, metadata.FailureDomain, metadata.Weight) case formatJson: data, err := json.Marshal(metadata) if err != nil { return "", err } var indented bytes.Buffer json.Indent(&indented, data, "", "\t") result = indented.String() } return result, nil } func (s *Shell) processDump(ctx context.Context, line string) (string, error) { parts := strings.Split(line, " ") if len(parts) < 2 || len(parts) > 3 { return "NOK", fmt.Errorf("bad command format, should be: .dump
[]") } address := parts[1] cli, err := client.New(ctx, address, client.WithDialFunc(s.dial)) if err != nil { return "NOK", fmt.Errorf("dial failed") } database := "db.bin" if len(parts) == 3 { database = parts[2] } files, err := cli.Dump(ctx, database) if err != nil { return "NOK", fmt.Errorf("dump failed") } dir, err := os.Getwd() if err != nil { return "NOK", fmt.Errorf("os.Getwd() failed") } for _, file := range files { path := filepath.Join(dir, file.Name) err := ioutil.WriteFile(path, file.Data, 0600) if err != nil { return "NOK", fmt.Errorf("WriteFile failed on path %s", path) } } return "OK", nil } func (s *Shell) processReconfigure(ctx context.Context, line string) (string, error) { parts := strings.Split(line, " ") if len(parts) != 3 { //lint:ignore ST1005 intentional long prosy error message return "NOK", fmt.Errorf("bad command format, should be: .reconfigure \n" + "Args:\n" + "\tdir - Directory of node with up to date data\n" + "\tclusteryaml - Path to a .yaml file containing the desired cluster configuration\n\n" + "Help:\n" + "\tUse this command when trying to preserve the data from your cluster while changing the\n" + "\tconfiguration of the cluster because e.g. your cluster is broken due to unreachablee nodes.\n" + "\t0. BACKUP ALL YOUR NODE DATA DIRECTORIES BEFORE PROCEEDING!\n" + "\t1. Stop all dqlite nodes.\n" + "\t2. Identify the dir of the node with the most up to date raft term and log, this will be the argument.\n" + "\t3. Create a .yaml file with the same format as cluster.yaml (or use/adapt an existing cluster.yaml) with the\n " + "\t desired cluster configuration. This will be the argument.\n" + "\t Don't forget to make sure the ID's in the file line up with the ID's in the info.yaml files.\n" + "\t4. Run the .reconfigure command, it should return \"OK\".\n" + "\t5. Copy the snapshot-xxx-xxx-xxx, snapshot-xxx-xxx-xxx.meta, segment files (00000xxxxx-000000xxxxx), desired cluster.yaml\n" + "\t from over to the directories of the other nodes identified in , deleting any leftover snapshot-xxx-xxx-xxx, snapshot-xxx-xxx-xxx.meta,\n" + "\t segment (00000xxxxx-000000xxxxx, open-xxx) and metadata{1,2} files that it contains.\n" + "\t Make sure an info.yaml is also present that is in line with cluster.yaml.\n" + "\t6. Start all the dqlite nodes.\n" + "\t7. If, for some reason, this fails or gives undesired results, try again with data from another node (you should still have this from step 0).\n") } dir := parts[1] clusteryamlpath := parts[2] store, err := client.NewYamlNodeStore(clusteryamlpath) if err != nil { return "NOK", fmt.Errorf("failed to create YamlNodeStore from file at %s :%v", clusteryamlpath, err) } servers, err := store.Get(ctx) if err != nil { return "NOK", fmt.Errorf("failed to retrieve NodeInfo list :%v", err) } err = dqlite.ReconfigureMembershipExt(dir, servers) if err != nil { return "NOK", fmt.Errorf("failed to reconfigure membership :%v", err) } return "OK", nil } func (s *Shell) processWeight(ctx context.Context, line string) (string, error) { parts := strings.Split(line, " ") if len(parts) != 3 { return "", fmt.Errorf("bad command format, should be: .weight
") } address := parts[1] weight, err := strconv.Atoi(parts[2]) if err != nil || weight < 0 { return "", fmt.Errorf("bad weight %q", parts[2]) } cli, err := client.New(ctx, address, client.WithDialFunc(s.dial)) if err != nil { return "", err } if err := cli.Weight(ctx, uint64(weight)); err != nil { return "", err } return "", nil } func (s *Shell) processQuery(ctx context.Context, line string) (string, error) { tx, err := s.db.BeginTx(ctx, nil) if err != nil { return "", fmt.Errorf("begin transaction: %w", err) } rows, err := tx.Query(line) if err != nil { err = fmt.Errorf("query: %w", err) if rbErr := tx.Rollback(); rbErr != nil { return "", fmt.Errorf("unable to rollback: %v", err) } return "", err } defer rows.Close() columns, err := rows.Columns() if err != nil { err = fmt.Errorf("columns: %w", err) if rbErr := tx.Rollback(); rbErr != nil { return "", fmt.Errorf("unable to rollback: %v", err) } return "", err } n := len(columns) var sb strings.Builder writer := tabwriter.NewWriter(&sb, 0, 8, 1, '\t', 0) for _, col := range columns { fmt.Fprintf(writer, "%s\t", col) } fmt.Fprintln(writer) for rows.Next() { row := make([]interface{}, n) rowPointers := make([]interface{}, n) for i := range row { rowPointers[i] = &row[i] } if err := rows.Scan(rowPointers...); err != nil { err = fmt.Errorf("scan: %w", err) if rbErr := tx.Rollback(); rbErr != nil { return "", fmt.Errorf("unable to rollback: %v", err) } return "", err } for _, column := range row { fmt.Fprintf(writer, "%v\t", column) } fmt.Fprintln(writer) } if err := rows.Err(); err != nil { err = fmt.Errorf("rows: %w", err) if rbErr := tx.Rollback(); rbErr != nil { return "", fmt.Errorf("unable to rollback: %v", err) } return "", err } if err := tx.Commit(); err != nil { return "", fmt.Errorf("commit: %w", err) } if err := writer.Flush(); err != nil { return "", fmt.Errorf("flush: %w", err) } return strings.TrimRight(sb.String(), "\n"), nil } golang-github-canonical-go-dqlite-2.0.0/logging/000077500000000000000000000000001471100661000214725ustar00rootroot00000000000000golang-github-canonical-go-dqlite-2.0.0/logging/func.go000066400000000000000000000012041471100661000227510ustar00rootroot00000000000000package logging import ( "fmt" "testing" ) // Func is a function that can be used for logging. type Func func(Level, string, ...interface{}) // Test returns a logging function that forwards messages to the test logger. func Test(t *testing.T) Func { return func(l Level, format string, a ...interface{}) { format = fmt.Sprintf("%s: %s", l.String(), format) t.Logf(format, a...) } } // Stdout returns a logging function that prints log messages on standard // output. func Stdout() Func { return func(l Level, format string, a ...interface{}) { format = fmt.Sprintf("%s: %s\n", l.String(), format) fmt.Printf(format, a...) } } golang-github-canonical-go-dqlite-2.0.0/logging/func_test.go000066400000000000000000000002601471100661000240110ustar00rootroot00000000000000package logging_test import ( "testing" "github.com/canonical/go-dqlite/v2/logging" ) func Test_TestFunc(t *testing.T) { f := logging.Test(t) f(logging.Info, "hello") } golang-github-canonical-go-dqlite-2.0.0/logging/level.go000066400000000000000000000005351471100661000231330ustar00rootroot00000000000000package logging // Level defines the logging level. type Level int // Available logging levels. 
const (
	None Level = iota
	Debug
	Info
	Warn
	Error
)

func (l Level) String() string {
	switch l {
	case Debug:
		return "DEBUG"
	case Info:
		return "INFO"
	case Warn:
		return "WARN"
	case Error:
		return "ERROR"
	default:
		return "UNKNOWN"
	}
}
golang-github-canonical-go-dqlite-2.0.0/logging/level_test.go000066400000000000000000000006731471100661000241750ustar00rootroot00000000000000
package logging_test

import (
	"testing"

	"github.com/canonical/go-dqlite/v2/logging"
	"github.com/stretchr/testify/assert"
)

func TestLevel_String(t *testing.T) {
	assert.Equal(t, "DEBUG", logging.Debug.String())
	assert.Equal(t, "INFO", logging.Info.String())
	assert.Equal(t, "WARN", logging.Warn.String())
	assert.Equal(t, "ERROR", logging.Error.String())

	unknown := logging.Level(666)
	assert.Equal(t, "UNKNOWN", unknown.String())
}
golang-github-canonical-go-dqlite-2.0.0/node.go000066400000000000000000000217161471100661000213270ustar00rootroot00000000000000
package dqlite

import (
	"context"
	"time"

	"github.com/canonical/go-dqlite/v2/client"
	"github.com/canonical/go-dqlite/v2/internal/bindings"
	"github.com/pkg/errors"
)

// Node runs a dqlite node.
type Node struct {
	server      *bindings.Node // Low-level C implementation
	acceptCh    chan error     // Receives connection handling errors
	id          uint64
	address     string
	bindAddress string
	cancel      context.CancelFunc
}

// NodeInfo is a convenience alias for client.NodeInfo.
type NodeInfo = client.NodeInfo

// SnapshotParams exposes bindings.SnapshotParams. Used for setting dqlite's
// snapshot parameters.
// SnapshotParams.Threshold controls after how many raft log entries a snapshot is
// taken. The higher this number, the lower the frequency of the snapshots.
// SnapshotParams.Trailing controls how many raft log entries are retained after
// taking a snapshot.
type SnapshotParams = bindings.SnapshotParams

// Option can be used to tweak node parameters.
type Option func(*options)

// WithDialFunc sets a custom dial function for the server.
func WithDialFunc(dial client.DialFunc) Option {
	return func(options *options) {
		options.DialFunc = dial
	}
}

// WithBindAddress sets a custom bind address for the server.
func WithBindAddress(address string) Option {
	return func(options *options) {
		options.BindAddress = address
	}
}

// WithNetworkLatency sets the average one-way network latency.
func WithNetworkLatency(latency time.Duration) Option {
	return func(options *options) {
		options.NetworkLatency = uint64(latency.Nanoseconds())
	}
}

// WithFailureDomain sets the code of the failure domain the node belongs to.
func WithFailureDomain(code uint64) Option {
	return func(options *options) {
		options.FailureDomain = code
	}
}

// WithSnapshotParams sets the snapshot parameters of the node.
func WithSnapshotParams(params SnapshotParams) Option {
	return func(options *options) {
		options.SnapshotParams = params
	}
}
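// newTunedNode is an illustrative sketch (not used elsewhere in the package)
// of how the snapshot knobs documented above compose with New. The threshold
// and trailing values are arbitrary assumptions: a higher Threshold means
// fewer snapshots, and Trailing bounds how many log entries survive each one.
func newTunedNode(id uint64, address, dir string) (*Node, error) {
	return New(id, address, dir,
		WithSnapshotParams(SnapshotParams{
			Threshold: 8192, // snapshot roughly every 8192 raft entries
			Trailing:  2048, // keep 2048 entries behind after snapshotting
		}),
	)
}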
// WithDiskMode enables dqlite disk-mode on the node.
// WARNING: This is experimental API, use with caution
// and prepare for data loss.
// UNSTABLE: Behavior can change in future.
// NOT RECOMMENDED for production use-cases, use at own risk.
func WithDiskMode(disk bool) Option {
	return func(options *options) {
		options.DiskMode = disk
	}
}

// WithAutoRecovery enables or disables auto-recovery of persisted data
// at startup for this node.
//
// When auto-recovery is enabled, raft snapshots and segment files may be
// deleted at startup if they are determined to be corrupt. This helps
// the startup process to succeed in more cases, but can lead to data loss.
//
// Auto-recovery is enabled by default.
func WithAutoRecovery(recovery bool) Option {
	return func(options *options) {
		options.AutoRecovery = recovery
	}
}

// New creates a new Node instance.
func New(id uint64, address string, dir string, options ...Option) (*Node, error) {
	o := defaultOptions()
	for _, option := range options {
		option(o)
	}

	ctx, cancel := context.WithCancel(context.Background())
	server, err := bindings.NewNode(ctx, id, address, dir)
	if err != nil {
		cancel()
		return nil, err
	}
	if o.DialFunc != nil {
		if err := server.SetDialFunc(o.DialFunc); err != nil {
			cancel()
			return nil, err
		}
	}
	if o.BindAddress != "" {
		if err := server.SetBindAddress(o.BindAddress); err != nil {
			cancel()
			return nil, err
		}
	}
	if o.NetworkLatency != 0 {
		if err := server.SetNetworkLatency(o.NetworkLatency); err != nil {
			cancel()
			return nil, err
		}
	}
	if o.FailureDomain != 0 {
		if err := server.SetFailureDomain(o.FailureDomain); err != nil {
			cancel()
			return nil, err
		}
	}
	if o.SnapshotParams.Threshold != 0 || o.SnapshotParams.Trailing != 0 {
		if err := server.SetSnapshotParams(o.SnapshotParams); err != nil {
			cancel()
			return nil, err
		}
	}
	if o.DiskMode {
		if err := server.EnableDiskMode(); err != nil {
			cancel()
			return nil, err
		}
	}
	if !o.AutoRecovery {
		if err := server.SetAutoRecovery(false); err != nil {
			cancel()
			return nil, err
		}
	}
	s := &Node{
		server:      server,
		acceptCh:    make(chan error, 1),
		id:          id,
		address:     address,
		bindAddress: o.BindAddress,
		cancel:      cancel,
	}

	return s, nil
}
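// startAndServe is an illustrative sketch (not used elsewhere) of the typical
// Node lifecycle built from the pieces above: create the node, start serving,
// and close it when done. The ID, address and directory are hypothetical, and
// the "@dqlite-1" bind address assumes an abstract Unix socket is wanted for
// the internal listener.
func startAndServe() error {
	node, err := New(1, "127.0.0.1:9001", "/var/lib/dqlite",
		WithBindAddress("@dqlite-1"),
	)
	if err != nil {
		return err
	}
	if err := node.Start(); err != nil {
		node.Close()
		return err
	}
	// ... serve until shutdown is requested ...
	return node.Close()
}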
// BindAddress returns the network address the node is listening to.
func (s *Node) BindAddress() string {
	return s.server.GetBindAddress()
}

// Start serving requests.
func (s *Node) Start() error {
	return s.server.Start()
}

// Recover a node by forcing a new cluster configuration.
//
// Deprecated: use ReconfigureMembershipExt instead, which does not require
// instantiating a new Node object.
func (s *Node) Recover(cluster []NodeInfo) error {
	return s.server.Recover(cluster)
}

// Hold configuration options for a dqlite server.
type options struct {
	Log            client.LogFunc
	DialFunc       client.DialFunc
	BindAddress    string
	NetworkLatency uint64
	FailureDomain  uint64
	SnapshotParams bindings.SnapshotParams
	DiskMode       bool
	AutoRecovery   bool
}

// Close the server, releasing all resources it created.
func (s *Node) Close() error {
	s.cancel()
	// Send a stop signal to the dqlite event loop.
	if err := s.server.Stop(); err != nil {
		return errors.Wrap(err, "server failed to stop")
	}

	s.server.Close()

	return nil
}

// BootstrapID is a magic ID that should be used for the first node in a
// cluster. Alternatively ID 1 can be used as well.
const BootstrapID = 0x2dc171858c3155be

// GenerateID generates a unique ID for a new node, based on a hash of its
// address and the current time.
func GenerateID(address string) uint64 {
	return bindings.GenerateID(address)
}

// ReconfigureMembership forces a new cluster configuration.
//
// Deprecated: this function ignores the provided node roles and makes every
// node in the new configuration a voter. Use ReconfigureMembershipExt, which
// respects the provided roles.
func ReconfigureMembership(dir string, cluster []NodeInfo) error {
	server, err := bindings.NewNode(context.Background(), 1, "1", dir)
	if err != nil {
		return err
	}
	defer server.Close()
	return server.Recover(cluster)
}

// ReconfigureMembershipExt forces a new cluster configuration.
//
// This function is useful to revive a cluster that can't achieve quorum in its
// old configuration because some nodes can't be brought online. Forcing a new
// configuration is unsafe, and you should follow these steps to avoid data
// loss and inconsistency:
//
// 1. Make sure no dqlite node in the cluster is running.
// 2. Identify all dqlite nodes that have survived and that you want to be part
// of the recovered cluster. Call this the "new member list".
// 3. Call ReadLastEntryInfo on each node in the member list, and find which
// node has the most recent entry according to LastEntryInfo.Before. Call this
// the "template node".
// 4. Invoke ReconfigureMembershipExt exactly one time, on the template node.
// The arguments are the data directory of the template node and the new
// member list.
// 5. Copy the data directory of the template node to all other nodes in the
// new member list, replacing their previous data directories.
// 6. Restart all nodes in the new member list.
func ReconfigureMembershipExt(dir string, cluster []NodeInfo) error {
	server, err := bindings.NewNode(context.Background(), 1, "1", dir)
	if err != nil {
		return err
	}
	defer server.Close()
	return server.RecoverExt(cluster)
}

// LastEntryInfo holds information about the last entry in the persistent raft
// log of a node.
//
// The zero value is not a valid entry description, and can be used as a
// sentinel.
type LastEntryInfo struct {
	Term, Index uint64
}

// Before tells whether the entry described by the receiver is strictly less
// recent than another entry.
//
// Entry A is less recent than entry B when A has a lower term number, or
// when A and B have the same term number and A has a lower index.
func (lhs LastEntryInfo) Before(rhs LastEntryInfo) bool {
	return lhs.Term < rhs.Term || (lhs.Term == rhs.Term && lhs.Index < rhs.Index)
}

// ReadLastEntryInfo reads information about the last entry in the raft
// persistent log from a node's data directory.
//
// This is intended to be used during the cluster recovery process, see
// ReconfigureMembershipExt. The node must not be running.
//
// This is a non-destructive operation, but is not read-only, since it has the
// side effect of renaming raft open segment files to closed segment files.
func ReadLastEntryInfo(dir string) (LastEntryInfo, error) {
	node, err := bindings.NewNode(context.Background(), 1, "1", dir)
	if err != nil {
		return LastEntryInfo{}, err
	}
	defer node.Close()
	if err = node.SetAutoRecovery(false); err != nil {
		return LastEntryInfo{}, err
	}
	index, term, err := node.DescribeLastEntry()
	if err != nil {
		return LastEntryInfo{}, err
	}
	return LastEntryInfo{term, index}, nil
}

// Create an options object with sane defaults.
func defaultOptions() *options {
	return &options{
		DialFunc:     client.DefaultDialFunc,
		DiskMode:     false, // Be explicit about not enabling disk-mode by default.
		AutoRecovery: true,
	}
}
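// recoverClusterExample is an illustrative sketch (not used elsewhere) of
// steps 2-4 of the recovery procedure documented on ReconfigureMembershipExt:
// given the data directories of the surviving nodes, pick the one with the
// most recent raft entry as the template and force the new configuration
// there. Copying the template directory to the other nodes (step 5) is left
// out, since it is plain file manipulation. The arguments are hypothetical.
func recoverClusterExample(dirs []string, cluster []NodeInfo) (string, error) {
	templateDir := ""
	var best LastEntryInfo // zero value: less recent than any real entry
	for _, dir := range dirs {
		info, err := ReadLastEntryInfo(dir)
		if err != nil {
			return "", err
		}
		if best.Before(info) {
			best, templateDir = info, dir
		}
	}
	// Force the new configuration exactly once, on the template node's dir.
	if err := ReconfigureMembershipExt(templateDir, cluster); err != nil {
		return "", err
	}
	return templateDir, nil
}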
golang-github-canonical-go-dqlite-2.0.0/node_test.go000066400000000000000000000011071471100661000223560ustar00rootroot00000000000000
package dqlite_test

import (
	"fmt"
	"sort"

	dqlite "github.com/canonical/go-dqlite/v2"
)

type infoSorter []dqlite.LastEntryInfo

func (is infoSorter) Len() int { return len(is) }

func (is infoSorter) Less(i, j int) bool { return is[i].Before(is[j]) }

func (is infoSorter) Swap(i, j int) { is[i], is[j] = is[j], is[i] }

func ExampleLastEntryInfo() {
	infos := []dqlite.LastEntryInfo{
		{Term: 1, Index: 2},
		{Term: 2, Index: 2},
		{Term: 1, Index: 1},
		{Term: 2, Index: 1},
	}
	sort.Sort(infoSorter(infos))
	fmt.Println(infos)
	// Output:
	// [{1 1} {1 2} {2 1} {2 2}]
}
golang-github-canonical-go-dqlite-2.0.0/test/000077500000000000000000000000001471100661000210235ustar00rootroot00000000000000
golang-github-canonical-go-dqlite-2.0.0/test/dqlite-demo-util.sh000077500000000000000000000031221471100661000245370ustar00rootroot00000000000000
# dqlite-demo test utilities

GO=${GO:-go}
ASAN=${ASAN:-}
VERBOSE=${VERBOSE:-0}
DISK=${DISK:-0}

$GO build -tags libsqlite3 $ASAN ./cmd/dqlite-demo/

DIR=$(mktemp -d)

start_node() {
    n="${1}"
    pidfile="${DIR}/pid.${n}"
    join="${2}"
    verbose=""
    disk=""
    if [ "$VERBOSE" -eq 1 ]; then
        verbose="--verbose"
    fi
    if [ "$DISK" -eq 1 ]; then
        disk="--disk"
    fi

    ./dqlite-demo --dir "$DIR" --api=127.0.0.1:800"${n}" --db=127.0.0.1:900"${n}" "$join" $verbose $disk &
    echo "${!}" > "${pidfile}"

    i=0
    while ! nc -z 127.0.0.1 800"${n}" 2>/dev/null; do
        i=$(expr $i + 1)
        sleep 0.2
        if [ "$i" -eq 25 ]; then
            echo "Error: node ${n} not yet up after 5 seconds"
            exit 1
        fi
    done
}

kill_node() {
    n=$1
    pidfile="${DIR}/pid.${n}"
    if ! [ -e "$pidfile" ]; then
        return
    fi
    pid=$(cat "${pidfile}")
    kill -TERM "$pid"
    wait "$pid"
    rm "${pidfile}"
}

set_up_node() {
    n=$1
    join=""
    if [ "$n" -ne 1 ]; then
        join=--join=127.0.0.1:9001
    fi
    echo "=> Set up dqlite-demo node $n"
    start_node "${n}" "${join}"
}

tear_down_node() {
    n=$1
    echo "=> Tear down dqlite-demo node $n"
    kill_node "$n"
}

set_up() {
    echo "=> Set up dqlite-demo cluster"
    set_up_node 1
    set_up_node 2
    set_up_node 3
}

tear_down() {
    err=$?
    trap '' HUP INT TERM
    echo "=> Tear down dqlite-demo cluster"
    tear_down_node 3
    tear_down_node 2
    tear_down_node 1
    rm -rf "$DIR"
    exit $err
}

sig_handler() {
    trap '' EXIT
    tear_down
}
golang-github-canonical-go-dqlite-2.0.0/test/dqlite-demo.sh000077500000000000000000000012571471100661000235710ustar00rootroot00000000000000
#!/bin/sh -eu
#
# Test the dqlite-demo application.

BASEDIR=$(dirname "$0")
. "$BASEDIR"/dqlite-demo-util.sh

trap tear_down EXIT
trap sig_handler HUP INT TERM

set_up

echo "=> Start test"
echo "=> Put key to node 1"
if [ "$(curl -s -X PUT -d my-value http://127.0.0.1:8001/my-key)" != "done" ]; then
    echo "Error: put key to node 1"
fi
echo "=> Get key from node 1"
if [ "$(curl -s http://127.0.0.1:8001/my-key)" != "my-value" ]; then
    echo "Error: get key from node 1"
fi
echo "=> Kill node 1"
kill_node 1
echo "=> Get key from node 2"
if [ "$(curl -s http://127.0.0.2:8002/my-key)" != "my-value" ] && [ "$(curl -s http://127.0.0.1:8002/my-key)" != "my-value" ]; then
    echo "Error: get key from node 2"
fi
echo "=> Test successful"
golang-github-canonical-go-dqlite-2.0.0/test/recover.sh000077500000000000000000000044501471100661000230320ustar00rootroot00000000000000
#!/bin/sh -eu
#
# Test the dqlite cluster recovery.

ASAN=${ASAN:-}
BASEDIR=$(dirname "$0")
"$BASEDIR"/dqlite-demo-util.sh $GO build -tags libsqlite3 $ASAN ./cmd/dqlite/ trap tear_down EXIT trap sig_handler HUP INT TERM set_up echo "=> Start test" echo "=> Put key to node 1" if [ "$(curl -s -X PUT -d my-value http://127.0.0.1:8001/my-key)" != "done" ]; then echo "Error: put key to node 1" fi echo "=> Get key from node 1" if [ "$(curl -s http://127.0.0.1:8001/my-key)" != "my-value" ]; then echo "Error: get key from node 1" fi echo "=> Stopping the cluster" tear_down_node 3 tear_down_node 2 tear_down_node 1 echo "=> Running recovery on node 1" node1_dir=$DIR/127.0.0.1:9001 node2_dir=$DIR/127.0.0.1:9002 node1_id=$(grep ID "$node1_dir"/info.yaml | cut -d" " -f2) node2_id=$(grep ID "$node2_dir"/info.yaml | cut -d" " -f2) target_yaml=${DIR}/cluster.yaml cat < "$target_yaml" - Address: 127.0.0.1:9001 ID: ${node1_id} Role: 0 - Address: 127.0.0.1:9002 ID: ${node2_id} Role: 1 EOF if ! ./dqlite -s 127.0.0.1:9001 test ".reconfigure ${node1_dir} ${target_yaml}"; then echo "Error: Reconfigure failed" exit 1 fi echo "=> Starting nodes 1 & 2" start_node 1 "" start_node 2 "" echo "=> Confirming new config" if [ "$(./dqlite -s 127.0.0.1:9001 test .leader)" != 127.0.0.1:9001 ]; then echo "Error: Expected node 1 to be leader" exit 1 fi if [ "$(./dqlite -s 127.0.0.1:9001 test .cluster | wc -l)" != 2 ]; then echo "Error: Expected 2 servers in the cluster" exit 1 fi if ! ./dqlite -s 127.0.0.1:9001 test .cluster | grep -q "127.0.0.1:9001|voter"; then echo "Error: server 1 not voter" exit 1 fi if ! ./dqlite -s 127.0.0.1:9001 test .cluster | grep -q "127.0.0.1:9002|stand-by"; then echo "Error: server 2 not stand-by" exit 1 fi echo "=> Get original key from node 1" if [ "$(curl -s http://127.0.0.1:8001/my-key)" != "my-value" ]; then echo "Error: get key from node 1" exit 1 fi echo "=> Put new key to node 1" if [ "$(curl -s -X PUT -d my-value-new http://127.0.0.1:8001/my-key-new)" != "done" ]; then echo "Error: put new key to node 1" exit 1 fi echo "=> Get new key from node 1" if [ "$(curl -s http://127.0.0.1:8001/my-key-new)" != "my-value-new" ]; then echo "Error: get new key from node 1" exit 1 fi echo "=> Test successful" golang-github-canonical-go-dqlite-2.0.0/test/roles.sh000077500000000000000000000132171471100661000225120ustar00rootroot00000000000000#!/bin/bash -eu # # Test dynamic roles management. GO=${GO:-go} ASAN=${ASAN:-} VERBOSE=${VERBOSE:-0} DIR=$(mktemp -d) BINARY=$DIR/main CLUSTER=127.0.0.1:9001,127.0.0.1:9002,127.0.0.1:9003,127.0.0.1:9004,127.0.0.1:9005,127.0.0.1:9006 N=7 DISK=${DISK:-0} $GO build -tags libsqlite3 $ASAN ./cmd/dqlite/ set_up_binary() { cat > "$DIR"/main.go < 1 { join = append(join, "127.0.0.1:9001") } addr := fmt.Sprintf("127.0.0.1:900%d", index) if err := os.MkdirAll(dir, 0755); err != nil { panic(err) } app, err := app.New( dir, app.WithAddress(addr), app.WithCluster(join), app.WithLogFunc(logFunc), app.WithRolesAdjustmentFrequency(3 * time.Second), app.WithDiskMode($DISK != 0), ) if err != nil { panic(err) } ctx, _ := context.WithTimeout(context.Background(), 30 * time.Second) if err := app.Ready(ctx); err != nil { panic(err) } <-ch ctx, cancel := context.WithTimeout(context.Background(), 2 * time.Second) defer cancel() app.Handover(ctx) app.Close() } EOF $GO build -o "$BINARY" -tags libsqlite3 $ASAN "$DIR"/main.go } start_node() { n="${1}" pidfile="${DIR}/pid.${n}" $BINARY "$n" & echo "${!}" > "${pidfile}" } kill_node() { n=$1 signal=$2 pidfile="${DIR}/pid.${n}" if ! 
[ -e "$pidfile" ]; then return fi pid=$(cat "${pidfile}") kill -"${signal}" "$pid" wait "$pid" || true rm "${pidfile}" } # Wait for the cluster to have 3 voters, 2 stand-bys and 1 spare wait_stable() { i=0 while true; do i=$(expr $i + 1) voters=$(./dqlite -s "$CLUSTER" test .cluster | grep voter | wc -l) standbys=$(./dqlite -s "$CLUSTER" test .cluster | grep stand-by | wc -l) spares=$(./dqlite -s "$CLUSTER" test .cluster | grep spare | wc -l) if [ "$voters" -eq 3 ] && [ "$standbys" -eq 3 ] && [ "$spares" -eq 1 ] ; then break fi if [ "$i" -eq 40 ]; then echo "Error: node roles not yet stable after 10 seconds" ./dqlite -s "$CLUSTER" test .cluster exit 1 fi sleep 0.25 done } # Wait for the given node to have the given role wait_role() { index=$1 role=$2 i=0 while true; do i=$(expr $i + 1) current=$(./dqlite -s "$CLUSTER" test .cluster | grep "127.0.0.1:900${index}" | cut -f 3 -d "|") if [ "$current" = "$role" ]; then break fi if [ "$i" -eq 40 ]; then echo "Error: node $index has role $current instead of $role" ./dqlite -s "$CLUSTER" test .cluster exit 1 fi sleep 0.25 done } set_up_node() { n=$1 echo "=> Set up test node $n" start_node "${n}" } set_up() { echo "=> Set up test cluster" set_up_binary for i in $(seq $N); do set_up_node "$i" done } tear_down_node() { n=$1 echo "=> Tear down test node $n" kill_node "$n" TERM } tear_down() { err=$? trap '' HUP INT TERM echo "=> Tear down test cluster" for i in $(seq $N -1 1); do tear_down_node "$i" done rm -rf "$DIR" exit $err } sig_handler() { trap '' EXIT tear_down } trap tear_down EXIT trap sig_handler HUP INT TERM set_up echo "=> Wait for roles to get stable" wait_stable # Stop one node at a time gracefully, then check that the cluster is stable. for i in $(seq 10); do index=$((1 + RANDOM % $N)) echo "=> Stop node $index" kill_node $index TERM echo "=> Wait for roles to get stable" wait_role $index spare wait_stable echo "=> Restart node $index" start_node $index sleep 2 done # Kill one node at a time ungracefully, then check that the cluster is stable. for i in $(seq 1); do index=$((1 + RANDOM % $N)) echo "=> Kill node $index" kill_node $index KILL echo "=> Wait for roles to get stable" wait_role $index spare wait_stable echo "=> Restart node $index" start_node $index sleep 2 done # Stop two nodes at a time gracefully, then check that the cluster is stable. for i in $(seq 10); do index1=$((1 + RANDOM % $N)) index2=$((1 + (index1 + $((RANDOM % ($N - 1)))) % $N)) echo "=> Stop nodes $index1 and $index2" kill_node $index1 TERM kill_node $index2 TERM sleep 2 echo "=> Restart nodes $index1 and $index2" start_node $index1 start_node $index2 echo "=> Wait for roles to get stable" wait_stable sleep 1 done # Kill two nodes at a time ungracefully, then check that the cluster is stable. 
for i in $(seq 10); do
    index1=$((1 + RANDOM % $N))
    index2=$((1 + (index1 + $((RANDOM % ($N - 1)))) % $N))
    echo "=> Stop nodes $index1 and $index2"
    kill_node $index1 KILL
    kill_node $index2 KILL
    sleep 5
    echo "=> Restart nodes $index1 and $index2"
    start_node $index1
    start_node $index2
    echo "=> Wait for roles to get stable"
    wait_stable
    sleep 1
done

echo "=> Test successful"
golang-github-canonical-go-dqlite-2.0.0/tracing/000077500000000000000000000000001471100661000214735ustar00rootroot00000000000000
golang-github-canonical-go-dqlite-2.0.0/tracing/tracing.go000066400000000000000000000043471471100661000234570ustar00rootroot00000000000000
package tracing

import "context"

type contextKey string

const (
	traceContextKey contextKey = "trace"
)

// WithTracer returns a context with the tracer embedded in the context
// under the context key.
func WithTracer(ctx context.Context, tracer Tracer) context.Context {
	return context.WithValue(ctx, traceContextKey, tracer)
}

// Start starts a new span using the tracer embedded in the context, if any.
// A valid span is always returned, even if the context does not contain a
// tracer. In that case, the span is a noop span.
func Start(ctx context.Context, name, query string) (context.Context, Span) {
	value := ctx.Value(traceContextKey)
	if value == nil {
		return ctx, noopSpan{}
	}
	tracer, ok := value.(Tracer)
	if !ok {
		return ctx, noopSpan{}
	}
	return tracer.Start(ctx, name, query)
}

// Tracer is the interface that all tracers must implement.
type Tracer interface {
	// Start creates a span and a context.Context containing the newly-created
	// span.
	//
	// If the context.Context provided in `ctx` contains a Span then the
	// newly-created Span will be a child of that span, otherwise it will be a
	// root span.
	//
	// Any Span that is created MUST also be ended. This is the responsibility
	// of the user. Implementations of this API may leak memory or other
	// resources if Spans are not ended.
	Start(context.Context, string, string) (context.Context, Span)
}

// Span is the individual component of a trace. It represents a single named
// and timed operation of a workflow that is traced. A Tracer is used to
// create a Span and it is then up to the operation the Span represents to
// properly end the Span when the operation itself ends.
type Span interface {
	// End completes the Span. The Span is considered complete and ready to be
	// delivered through the rest of the telemetry pipeline after this method
	// is called. Therefore, updates to the Span are not allowed after this
	// method has been called.
	End()
}

// noopSpan is a span that does nothing.
type noopSpan struct{}

// End completes the Span. The Span is considered complete and ready to be
// delivered through the rest of the telemetry pipeline after this method
// is called. Therefore, updates to the Span are not allowed after this
// method has been called.
func (noopSpan) End() {}
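// A minimal standalone sketch of wiring a custom Tracer through WithTracer
// and Start from the tracing package above. printTracer and printSpan are
// hypothetical names, and a real tracer would record timings and propagate
// the span through the returned context rather than just printing.
package main

import (
	"context"
	"fmt"

	"github.com/canonical/go-dqlite/v2/tracing"
)

type printTracer struct{}

type printSpan struct{ name string }

// Start satisfies tracing.Tracer by printing the operation and returning a
// span that prints again when ended.
func (printTracer) Start(ctx context.Context, name, query string) (context.Context, tracing.Span) {
	fmt.Printf("start %s: %s\n", name, query)
	return ctx, printSpan{name: name}
}

func (s printSpan) End() {
	fmt.Printf("end %s\n", s.name)
}

func main() {
	// Embed the tracer in the context; tracing.Start will find it there.
	ctx := tracing.WithTracer(context.Background(), printTracer{})
	ctx, span := tracing.Start(ctx, "query", "SELECT 1")
	_ = ctx // ... traced work would happen here ...
	span.End()
}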