pax_global_header00006660000000000000000000000064146750404430014521gustar00rootroot0000000000000052 comment=953e0df999cabb3f5eef714df9921c00e9f632c2 cassandra-gocql-driver-1.7.0/000077500000000000000000000000001467504044300160615ustar00rootroot00000000000000cassandra-gocql-driver-1.7.0/.asf.yaml000066400000000000000000000023331467504044300175750ustar00rootroot00000000000000# Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. notifications: commits: commits@cassandra.apache.org issues: commits@cassandra.apache.org pullrequests: pr@cassandra.apache.org jira_options: link worklog github: description: "GoCQL Driver for Apache Cassandra®" homepage: https://cassandra.apache.org/ enabled_merge_buttons: squash: false merge: false rebase: true features: wiki: false issues: true projects: false autolink_jira: - CASSANDRA cassandra-gocql-driver-1.7.0/.github/000077500000000000000000000000001467504044300174215ustar00rootroot00000000000000cassandra-gocql-driver-1.7.0/.github/issue_template.md000066400000000000000000000011641467504044300227700ustar00rootroot00000000000000Please answer these questions before submitting your issue. Thanks! ### What version of Cassandra are you using? ### What version of Gocql are you using? ### What version of Go are you using? 
### What did you do? ### What did you expect to see? ### What did you see instead? --- If you are having connectivity related issues please share the following additional information ### Describe your Cassandra cluster please provide the following information - output of `nodetool status` - output of `SELECT peer, rpc_address FROM system.peers` - rebuild your application with the `gocql_debug` tag and post the output cassandra-gocql-driver-1.7.0/.github/workflows/000077500000000000000000000000001467504044300214565ustar00rootroot00000000000000cassandra-gocql-driver-1.7.0/.github/workflows/main.yml000066400000000000000000000156741467504044300231420ustar00rootroot00000000000000name: Build on: push: branches: - master pull_request: types: [ opened, synchronize, reopened ] env: CCM_VERSION: "6e71061146f7ae67b84ccd2b1d90d7319b640e4c" jobs: build: name: Unit tests runs-on: ubuntu-latest strategy: matrix: go: [ '1.22', '1.23' ] steps: - uses: actions/checkout@v3 - uses: actions/setup-go@v4 with: go-version: ${{ matrix.go }} - run: go vet - name: Run unit tests run: go test -v -tags unit -race integration-cassandra: timeout-minutes: 15 needs: - build name: Integration Tests runs-on: ubuntu-latest strategy: fail-fast: false matrix: go: [ '1.22', '1.23' ] cassandra_version: [ '4.0.13', '4.1.6' ] auth: [ "false" ] compressor: [ "snappy" ] tags: [ "cassandra", "integration", "ccm" ] steps: - uses: actions/checkout@v2 - uses: actions/setup-go@v2 with: go-version: ${{ matrix.go }} - uses: actions/cache@v2 id: gomod-cache with: path: ~/go/pkg/mod key: ${{ runner.os }}-go-${{ hashFiles('go.mod') }} restore-keys: | ${{ runner.os }}-go- - name: Install CCM run: pip install "git+https://github.com/riptano/ccm.git@${CCM_VERSION}" - name: Start cassandra nodes run: | VERSION=${{ matrix.cassandra_version }} keypath="$(pwd)/testdata/pki" conf=( "client_encryption_options.enabled: true" "client_encryption_options.keystore: $keypath/.keystore" "client_encryption_options.keystore_password: 
cassandra" "client_encryption_options.require_client_auth: true" "client_encryption_options.truststore: $keypath/.truststore" "client_encryption_options.truststore_password: cassandra" "concurrent_reads: 2" "concurrent_writes: 2" "write_request_timeout_in_ms: 5000" "read_request_timeout_in_ms: 5000" ) if [[ $VERSION == 3.*.* ]]; then conf+=( "rpc_server_type: sync" "rpc_min_threads: 2" "rpc_max_threads: 2" "enable_user_defined_functions: true" "enable_materialized_views: true" ) elif [[ $VERSION == 4.0.* ]]; then conf+=( "enable_user_defined_functions: true" "enable_materialized_views: true" ) else conf+=( "user_defined_functions_enabled: true" "materialized_views_enabled: true" ) fi ccm remove test || true ccm create test -v $VERSION -n 3 -d --vnodes --jvm_arg="-Xmx256m -XX:NewSize=100m" ccm updateconf "${conf[@]}" export JVM_EXTRA_OPTS=" -Dcassandra.test.fail_writes_ks=test -Dcassandra.custom_query_handler_class=org.apache.cassandra.cql3.CustomPayloadMirroringQueryHandler" ccm start --wait-for-binary-proto --verbose ccm status ccm node1 nodetool status args="-gocql.timeout=60s -runssl -proto=4 -rf=3 -clusterSize=3 -autowait=2000ms -compressor=${{ matrix.compressor }} -gocql.cversion=$VERSION -cluster=$(ccm liveset) ./..." 
echo "args=$args" >> $GITHUB_ENV echo "JVM_EXTRA_OPTS=$JVM_EXTRA_OPTS" >> $GITHUB_ENV - name: Integration tests run: | export JVM_EXTRA_OPTS="${{env.JVM_EXTRA_OPTS}}" go test -v -tags "${{ matrix.tags }} gocql_debug" -timeout=5m -race ${{ env.args }} - name: 'Save ccm logs' if: 'failure()' uses: actions/upload-artifact@v3 with: name: ccm-cluster path: /home/runner/.ccm/test retention-days: 5 integration-auth-cassandra: timeout-minutes: 15 needs: - build name: Integration Tests with auth runs-on: ubuntu-latest strategy: fail-fast: false matrix: go: [ '1.22', '1.23' ] cassandra_version: [ '4.0.13' ] compressor: [ "snappy" ] tags: [ "integration" ] steps: - uses: actions/checkout@v3 - uses: actions/setup-go@v4 with: go-version: ${{ matrix.go }} - name: Install CCM run: pip install "git+https://github.com/riptano/ccm.git@${CCM_VERSION}" - name: Start cassandra nodes run: | VERSION=${{ matrix.cassandra_version }} keypath="$(pwd)/testdata/pki" conf=( "client_encryption_options.enabled: true" "client_encryption_options.keystore: $keypath/.keystore" "client_encryption_options.keystore_password: cassandra" "client_encryption_options.require_client_auth: true" "client_encryption_options.truststore: $keypath/.truststore" "client_encryption_options.truststore_password: cassandra" "concurrent_reads: 2" "concurrent_writes: 2" "write_request_timeout_in_ms: 5000" "read_request_timeout_in_ms: 5000" "authenticator: PasswordAuthenticator" "authorizer: CassandraAuthorizer" "enable_user_defined_functions: true" ) if [[ $VERSION == 3.*.* ]]; then conf+=( "rpc_server_type: sync" "rpc_min_threads: 2" "rpc_max_threads: 2" "enable_user_defined_functions: true" "enable_materialized_views: true" ) elif [[ $VERSION == 4.0.* ]]; then conf+=( "enable_user_defined_functions: true" "enable_materialized_views: true" ) else conf+=( "user_defined_functions_enabled: true" "materialized_views_enabled: true" ) fi ccm remove test || true ccm create test -v $VERSION -n 1 -d --vnodes --jvm_arg="-Xmx256m 
-XX:NewSize=100m" ccm updateconf "${conf[@]}" rm -rf $HOME/.ccm/test/node1/data/system_auth export JVM_EXTRA_OPTS=" -Dcassandra.test.fail_writes_ks=test -Dcassandra.custom_query_handler_class=org.apache.cassandra.cql3.CustomPayloadMirroringQueryHandler" ccm start --wait-for-binary-proto --verbose ccm status ccm node1 nodetool status args="-gocql.timeout=60s -runssl -proto=4 -rf=3 -clusterSize=1 -autowait=2000ms -compressor=${{ matrix.compressor }} -gocql.cversion=$VERSION -cluster=$(ccm liveset) ./..." echo "args=$args" >> $GITHUB_ENV echo "JVM_EXTRA_OPTS=$JVM_EXTRA_OPTS" >> $GITHUB_ENV sleep 30s - name: Integration tests run: | export JVM_EXTRA_OPTS="${{env.JVM_EXTRA_OPTS}}" go test -v -run=TestAuthentication -tags "${{ matrix.tags }} gocql_debug" -timeout=15s -runauth ${{ env.args }} cassandra-gocql-driver-1.7.0/.gitignore000066400000000000000000000000621467504044300200470ustar00rootroot00000000000000gocql-fuzz fuzz-corpus fuzz-work gocql.test .idea cassandra-gocql-driver-1.7.0/CHANGELOG.md000066400000000000000000000147761467504044300177110ustar00rootroot00000000000000# Changelog All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] ### Added ### Changed ### Fixed ## [1.7.0] - 2024-09-23 This release is the first after the donation of gocql to the Apache Software Foundation (ASF) ### Changed - Update DRIVER_NAME parameter in STARTUP messages to a different value intended to clearly identify this driver as an ASF driver. This should clearly distinguish this release (and future gocql-cassandra-driver releases) from prior versions. (#1824) - Supported Go versions updated to 1.23 and 1.22 to conform to gocql's sunset model. (#1825) ## [1.6.0] - 2023-08-28 ### Added - Added the InstaclustrPasswordAuthenticator to the list of default approved authenticators. 
(#1711) - Added the `com.scylladb.auth.SaslauthdAuthenticator` and `com.scylladb.auth.TransitionalAuthenticator` to the list of default approved authenticators. (#1712) - Added transferring Keyspace and Table names to the Query from the prepared response and updating information about that every time this information is received. (#1714) ### Changed - Tracer created with NewTraceWriter now includes the thread information from trace events in the output. (#1716) - Increased default timeouts so that they are higher than Cassandra default timeouts. This should help prevent issues where a default configuration overloads a server using default timeouts during retries. (#1701, #1719) ## [1.5.2] - 2023-06-12 Same as 1.5.0. GitHub does not like gpg signed text in the tag message (even with prefixed armor), so pushing a new tag. ## [1.5.1] - 2023-06-12 Same as 1.5.0. GitHub does not like gpg signed text in the tag message, so pushing a new tag. ## [1.5.0] - 2023-06-12 ### Added - gocql now advertises the driver name and version in the STARTUP message to the server. The values are taken from the Go module's path and version (or from the replacement module, if used). (#1702) That allows the server to track which fork of the driver is being used. - Query.Values() to retrieve the values bound to the Query. This makes writing wrappers around Query easier. (#1700) ### Fixed - Potential panic on deserialization (#1695) - Unmarshalling of dates outside of `[1677-09-22, 2262-04-11]` range. (#1692) ## [1.4.0] - 2023-04-26 ### Added ### Changed - gocql now refreshes the entire ring when it receives a topology change event and when control connection is re-connected. This simplifies code managing ring state. (#1680) - Supported versions of Cassandra that we test against are now 4.0.x and 4.1.x. (#1685) - Default HostDialer now uses already-resolved connect address instead of hostname when establishing TCP connections (#1683). ### Fixed - Deadlock in Session.Close(). 
(#1688) - Race between Query.Release() and speculative executions (#1684) - Missed ring update during control connection reconnection (#1680) ## [1.3.2] - 2023-03-27 ### Changed - Supported versions of Go that we test against are now Go 1.19 and Go 1.20. ### Fixed - Node event handling now processes topology events before status events. This fixes some cases where new nodes were missed. (#1682) - Learning a new IP address for an existing node (identified by host ID) now triggers replacement of that host. This fixes some Kubernetes reconnection failures. (#1682) - Refresh ring when processing a node UP event for an unknown host. This fixes some cases where new nodes were missed. (#1669) ## [1.3.1] - 2022-12-13 ### Fixed - Panic in RackAwareRoundRobinPolicy caused by wrong alignment on 32-bit platforms. (#1666) ## [1.3.0] - 2022-11-29 ### Added - Added a RackAwareRoundRobinPolicy that attempts to keep client->server traffic in the same rack when possible. ### Changed - Supported versions of Go that we test against are now Go 1.18 and Go 1.19. ## [1.2.1] - 2022-09-02 ### Changed - GetCustomPayload now returns nil instead of panicking in case of query error. (#1385) ### Fixed - Nil pointer dereference in events.go when handling node removal. (#1652) - Reading peers from DataStax Enterprise clusters. This was a regression in 1.2.0. (#1646) - Unmarshaling maps did not pre-allocate the map. (#1642) ## [1.2.0] - 2022-07-07 This release improves support for connecting through proxies and some improvements when using Cassandra 4.0 or later. ### Added - HostDialer interface now allows customizing connection including TLS setup per host. (#1629) ### Changed - The driver now uses `host_id` instead of connect address to identify nodes. (#1632) - gocql reads `system.peers_v2` instead of `system.peers` when connected to Cassandra 4.0 or later and populates `HostInfo.Port` using the native port. (#1635) ### Fixed - Data race in `HostInfo.HostnameAndPort()`. 
(#1631) - Handling of nils when marshaling/unmarshaling lists and maps. (#1630) - Silent data corruption in case a map was serialized into UDT and some fields in the UDT were not present in the map. The driver now correctly writes nulls instead of shifting fields. (#1626, #1639) ## [1.1.0] - 2022-04-29 ### Added - Changelog. - StreamObserver and StreamObserverContext interfaces to allow observing CQL streams. - ClusterConfig.WriteTimeout option now allows to specify a write-timeout different from read-timeout. - TypeInfo.NewWithError method. ### Changed - Supported versions of Go that we test against are now Go 1.17 and Go 1.18. - The driver now returns an error if SetWriteDeadline fails. If you need to run gocql on a platform that does not support SetWriteDeadline, set WriteTimeout to zero to disable the timeout. - Creating streams on a connection that is closing now fails early. - HostFilter now also applies to control connections. - TokenAwareHostPolicy now panics immediately during initialization instead of at random point later if you reuse the TokenAwareHostPolicy between multiple sessions. Reusing TokenAwareHostPolicy between sessions was never supported. ### Fixed - The driver no longer resets the network connection if a write fails with non-network-related error. - Blocked network write to a network could block other goroutines, this is now fixed. - Fixed panic in unmarshalUDT when trying to unmarshal a user-defined-type to a non-pointer Go type. - Fixed panic when trying to unmarshal unknown/custom CQL type. ## Deprecated - TypeInfo.New, please use TypeInfo.NewWithError instead. ## [1.0.0] - 2022-03-04 ### Changed - Started tagging versions with semantic version tags cassandra-gocql-driver-1.7.0/CONTRIBUTING.md000066400000000000000000000114171467504044300203160ustar00rootroot00000000000000# Contributing to the Apache Cassandra GoCQL Driver **TL;DR** - this manifesto sets out the bare minimum requirements for submitting a patch to gocql. 
This guide outlines the process of landing patches in gocql and the general approach to maintaining the code base. ## Background The goal of the gocql project is to provide a stable and robust CQL driver for Go. This is a community driven project that is coordinated by a small team of developers in and around the Apache Cassandra project. For security, governance and administration issues please refer to the Cassandra Project Management Committee. ## Minimum Requirement Checklist The following is a check list of requirements that need to be satisfied in order for us to merge your patch: * You should raise a pull request to apache/cassandra-gocql-driver on Github * The pull request has a title that clearly summarizes the purpose of the patch * The motivation behind the patch is clearly defined in the pull request summary * You agree that your contribution is donated to the Apache Software Foundation (appropriate copyright is on all new files) * The patch will merge cleanly * The test coverage does not fall * The merge commit passes the regression test suite on GitHub Actions * `go fmt` has been applied to the submitted code * Notable changes (i.e. new features or changed behavior, bugfixes) are appropriately documented in CHANGELOG.md, functional changes also in godoc * A correctly formatted commit message, see below If there are any requirements that can't be reasonably satisfied, please state this either on the pull request or as part of discussion on the mailing list. Where appropriate, the core team may apply discretion and make an exception to these requirements. ## Commit Message The Apache Cassandra project has a commit message precendence like ``` patch by ; reviewed by for CASSANDRA-##### ``` The 'patch by …; reviewed by' line is important. It permits our review-than-commit procedure, allowing commits from non-git-branch patches. 
It is also parsed to build the project contribulyse statistics found [here](https://nightlies.apache.org/cassandra/devbranch/misc/contribulyze/html/). Background: https://cassandra.apache.org/_/development/how_to_commit.html#tips ## Beyond The Checklist In addition to stating the hard requirements, there are a bunch of things that we consider when assessing changes to the library. These soft requirements are helpful pointers of how to get a patch landed quicker and with less fuss. ### General QA Approach The Cassandra project needs to consider the ongoing maintainability of the library at all times. Patches that look like they will introduce maintenance issues for the team will not be accepted. Your patch will get merged quicker if you have decent test cases that provide test coverage for the new behavior you wish to introduce. Unit tests are good, integration tests are even better. An example of a unit test is `marshal_test.go` - this tests the serialization code in isolation. `cassandra_test.go` is an integration test suite that is executed against every version of Cassandra that gocql supports as part of the CI process on Travis. That said, the point of writing tests is to provide a safety net to catch regressions, so there is no need to go overboard with tests. Remember that the more tests you write, the more code we will have to maintain. So there's a balance to strike there. ### Sign Off Procedure Generally speaking, a pull request can get merged by any one of the project's committers. If your change is minor, chances are that one team member will just go ahead and merge it there and then. As stated earlier, suitable test coverage will increase the likelihood that a single reviewer will assess and merge your change. If your change has no test coverage, or looks like it may have wider implications for the health and stability of the library, the reviewer may elect to refer the change to another team member to achieve consensus before proceeding. 
Therefore, the tighter and cleaner your patch is, the quicker it will go through the review process. ### Supported Features gocql is a low level wire driver for Cassandra CQL. By and large, we would like to keep the functional scope of the library as narrow as possible. We think that gocql should be tight and focused, and we will be naturally skeptical of things that could just as easily be implemented in a higher layer. Inevitably you will come across something that could be implemented in a higher layer, save for a minor change to the core API. In this instance, please strike up a conversation in the Cassandra community. Chances are we will understand what you are trying to achieve and will try to accommodate this in a maintainable way. cassandra-gocql-driver-1.7.0/LICENSE000066400000000000000000000261371467504044300170770ustar00rootroot00000000000000 Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. 
"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." 
"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. 
Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. cassandra-gocql-driver-1.7.0/NOTICE000066400000000000000000000172041467504044300167710ustar00rootroot00000000000000Apache Cassandra GoCQL Driver Copyright 2024 The Apache Software Foundation This product includes software developed at The Apache Software Foundation (http://www.apache.org/). This product originates, before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40, from software from the Gocql Authors, with copyright and license as follows: Copyright (c) 2016, The Gocql authors All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. * Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. Where The Gocql Authors for copyright purposes are below. Those marked with asterisk have agreed to donate (copyright assign) their contributions to the Apache Software Foundation, signing CLAs when appropriate. Christoph Hack Jonathan Rudenberg * Thorsten von Eicken * Matt Robenolt Phillip Couto * Niklas Korz Nimi Wariboko Jr Ghais Issa * Sasha Klizhentas Konstantin Cherkasov Ben Hood <0x6e6562@gmail.com> Pete Hopkins Chris Bannister * Maxim Bublis Alex Zorin Kasper Middelboe Petersen Harpreet Sawhney Charlie Andrews * Stanislavs Koikovs Dan Forest Miguel Serrano * Stefan Radomski Josh Wright Jacob Rhoden Ben Frye Fred McCann * Dan Simmons * Muir Manders Sankar P * Julien Da Silva Dan Kennedy * Nick Dhupia Yasuharu Goto * Jeremy Schlatter * Matthias Kadenbach Dean Elbaz Mike Berman Dmitriy Fedorenko * Zach Marcantel James Maloney Ashwin Purohit * Dan Kinder * Oliver Beattie * Justin Corpron * Miles Delahunty Zach Badgett Maciek Sakrejda * Jeff Mitchell Baptiste Fontaine * Matt Heath * Jamie Cuthill Adrian Casajus * John Weldon * Adrien Bustany * Andrey Smirnov * Adam Weiner * Daniel Cannon Johnny Bergström Adriano Orioli * Claudiu Raveica * Artem Chernyshev * Ference Fu LOVOO nikandfor * Anthony Woods * Alexander Inozemtsev * Rob McColl ; * Viktor Tönköl * Ian Lozinski Michael Highstead * Sarah Brown * Caleb Doxsey * Frederic Hemery * Pekka Enberg * Mark M Bartosz Burclaf * Marcus King * Andrew de Andrade Robert Nix Nathan Youngman * Charles Law ; * Nathan 
Davies * Bo Blanton Vincent Rischmann * Jesse Claven * Derrick Wippler Leigh McCulloch Ron Kuris Raphael Gavache * Yasser Abdolmaleki Krishnanand Thommandra Blake Atkinson Dharmendra Parsaila Nayef Ghattas * Michał Matczuk * Ben Krebsbach * Vivian Mathews * Sascha Steinbiss * Seth Rosenblum * Javier Zunzunegui Luke Hines * Zhixin Wen * Chang Liu Ingo Oeser * Luke Hines Jacob Greenleaf Alex Lourie ; * Marco Cadetg * Karl Matthias * Thomas Meson * Martin Sucha ; * Pavel Buchinchik Rintaro Okamura * Yura Sokolov ; Jorge Bay * Dmitriy Kozlov * Alexey Romanovsky Jaume Marhuenda Beltran Piotr Dulikowski Árni Dagur * Tushar Das * Maxim Vladimirskiy * Bogdan-Ciprian Rusu Yuto Doi * Krishna Vadali Jens-W. Schicke-Uffmann * Ondrej Polakovič * Sergei Karetnikov * Stefan Miklosovic * Adam Burk * Valerii Ponomarov * Neal Turett * Doug Schaapveld * Steven Seidman Wojciech Przytuła * João Reis * Lauro Ramos Venancio Dmitry Kropachev Oliver Boyle * Jackson Fleming * Sylwia Szunejko * cassandra-gocql-driver-1.7.0/README.md000066400000000000000000000221171467504044300173430ustar00rootroot00000000000000Apache Cassandra GoCQL Driver ===== [!Join the chat at https://the-asf.slack.com/archives/C05LPRVNZV1](https://the-asf.slack.com/archives/C05LPRVNZV1) ![go build](https://github.com/apache/cassandra-gocql-driver/actions/workflows/main.yml/badge.svg) [![GoDoc](https://godoc.org/github.com/gocql/gocql?status.svg)](https://godoc.org/github.com/gocql/gocql) Package gocql implements a fast and robust Cassandra client for the Go programming language. Project Website: https://cassandra.apache.org
API documentation: https://godoc.org/github.com/gocql/gocql
Discussions: https://cassandra.apache.org/_/community.html#discussions Supported Versions ------------------ The following matrix shows the versions of Go and Cassandra that are tested with the integration test suite as part of the CI build: | Go/Cassandra | 4.0.x | 4.1.x | |--------------|-------|-------| | 1.22 | yes | yes | | 1.23 | yes | yes | Gocql has been tested in production against many versions of Cassandra. Due to limits in our CI setup we only test against the latest 2 GA releases. Sunsetting Model ---------------- In general, the Cassandra community will focus on supporting the current and previous versions of Go. gocql may still work with older versions of Go, but official support for these versions will have been sunset. Installation ------------ go get github.com/gocql/gocql Features -------- * Modern Cassandra client using the native transport * Automatic type conversions between Cassandra and Go * Support for all common types including sets, lists and maps * Custom types can implement a `Marshaler` and `Unmarshaler` interface * Strict type conversions without any loss of precision * Built-In support for UUIDs (version 1 and 4) * Support for logged, unlogged and counter batches * Cluster management * Automatic reconnect on connection failures with exponential falloff * Round robin distribution of queries to different hosts * Round robin distribution of queries to different connections on a host * Each connection can execute up to n concurrent queries (whereby n is the limit set by the protocol version the client chooses to use) * Optional automatic discovery of nodes * Policy based connection pool with token aware and round-robin policy implementations * Support for password authentication * Iteration over paged results with configurable page size * Support for TLS/SSL * Optional frame compression (using snappy) * Automatic query preparation * Support for query tracing * Support for Cassandra 2.1+ [binary protocol version 
3](https://github.com/apache/cassandra/blob/trunk/doc/native_protocol_v3.spec) * Support for up to 32768 streams * Support for tuple types * Support for client side timestamps by default * Support for UDTs via a custom marshaller or struct tags * Support for Cassandra 3.0+ [binary protocol version 4](https://github.com/apache/cassandra/blob/trunk/doc/native_protocol_v4.spec) * An API to access the schema metadata of a given keyspace Performance ----------- While the driver strives to be highly performant, there are cases where it is difficult to test and verify. The driver is built with maintainability and code readability in mind first and then performance and features, as such every now and then performance may degrade, if this occurs please report and issue and it will be looked at and remedied. The only time the driver copies data from its read buffer is when it Unmarshal's data into supplied types. Some tips for getting more performance from the driver: * Use the TokenAware policy * Use many goroutines when doing inserts, the driver is asynchronous but provides a synchronous API, it can execute many queries concurrently * Tune query page size * Reading data from the network to unmarshal will incur a large amount of allocations, this can adversely affect the garbage collector, tune `GOGC` * Close iterators after use to recycle byte buffers Important Default Keyspace Changes ---------------------------------- gocql no longer supports executing "use " statements to simplify the library. The user still has the ability to define the default keyspace for connections but now the keyspace can only be defined before a session is created. Queries can still access keyspaces by indicating the keyspace in the query: ```sql SELECT * FROM example2.table; ``` Example of correct usage: ```go cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2", "192.168.1.3") cluster.Keyspace = "example" ... 
session, err := cluster.CreateSession() ``` Example of incorrect usage: ```go cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2", "192.168.1.3") cluster.Keyspace = "example" ... session, err := cluster.CreateSession() if err = session.Query("use example2").Exec(); err != nil { log.Fatal(err) } ``` This will result in an err being returned from the session.Query line as the user is trying to execute a "use" statement. Example ------- See [package documentation](https://pkg.go.dev/github.com/gocql/gocql#pkg-examples). Data Binding ------------ There are various ways to bind application level data structures to CQL statements: * You can write the data binding by hand, as outlined in the Tweet example. This provides you with the greatest flexibility, but it does mean that you need to keep your application code in sync with your Cassandra schema. * You can dynamically marshal an entire query result into an `[]map[string]interface{}` using the `SliceMap()` API. This returns a slice of row maps keyed by CQL column names. This method requires no special interaction with the gocql API, but it does require your application to be able to deal with a key value view of your data. * As a refinement on the `SliceMap()` API you can also call `MapScan()` which returns `map[string]interface{}` instances in a row by row fashion. * The `Bind()` API provides a client app with a low level mechanism to introspect query meta data and extract appropriate field values from application level data structures. * The [gocqlx](https://github.com/scylladb/gocqlx) package is an idiomatic extension to gocql that provides usability features. With gocqlx you can bind the query parameters from maps and structs, use named query parameters (:identifier) and scan the query results into structs and slices. It comes with a fluent and flexible CQL query builder that supports full CQL spec, including BATCH statements and custom functions. 
* Building on top of the gocql driver, [cqlr](https://github.com/relops/cqlr) adds the ability to auto-bind a CQL iterator to a struct or to bind a struct to an INSERT statement. * Another external project that layers on top of gocql is [cqlc](http://relops.com/cqlc) which generates gocql compliant code from your Cassandra schema so that you can write type safe CQL statements in Go with a natural query syntax. * [gocassa](https://github.com/hailocab/gocassa) is an external project that layers on top of gocql to provide convenient query building and data binding. * [gocqltable](https://github.com/kristoiv/gocqltable) provides an ORM-style convenience layer to make CRUD operations with gocql easier. Ecosystem --------- The following community maintained tools are known to integrate with gocql: * [gocqlx](https://github.com/scylladb/gocqlx) is a gocql extension that automates data binding, adds named queries support, provides flexible query builders and plays well with gocql. * [journey](https://github.com/db-journey/journey) is a migration tool with Cassandra support. * [negronicql](https://github.com/mikebthun/negronicql) is gocql middleware for Negroni. * [cqlr](https://github.com/relops/cqlr) adds the ability to auto-bind a CQL iterator to a struct or to bind a struct to an INSERT statement. * [cqlc](http://relops.com/cqlc) generates gocql compliant code from your Cassandra schema so that you can write type safe CQL statements in Go with a natural query syntax. * [gocassa](https://github.com/hailocab/gocassa) provides query building, adds data binding, and provides easy-to-use "recipe" tables for common query use-cases. * [gocqltable](https://github.com/kristoiv/gocqltable) is a wrapper around gocql that aims to simplify common operations. 
* [gockle](https://github.com/willfaught/gockle) provides simple, mockable interfaces that wrap gocql types * [scylladb](https://github.com/scylladb/scylla) is a fast Apache Cassandra-compatible NoSQL database * [go-cql-driver](https://github.com/MichaelS11/go-cql-driver) is an CQL driver conforming to the built-in database/sql interface. It is good for simple use cases where the database/sql interface is wanted. The CQL driver is a wrapper around this project. Other Projects -------------- * [gocqldriver](https://github.com/tux21b/gocqldriver) is the predecessor of gocql based on Go's `database/sql` package. This project isn't maintained anymore, because Cassandra wasn't a good fit for the traditional `database/sql` API. Use this package instead. SEO --- For some reason, when you Google `golang cassandra`, this project doesn't feature very highly in the result list. But if you Google `go cassandra`, then we're a bit higher up the list. So this is note to try to convince Google that golang is an alias for Go. cassandra-gocql-driver-1.7.0/address_translators.go000066400000000000000000000037201467504044300224730ustar00rootroot00000000000000/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. */ package gocql import "net" // AddressTranslator provides a way to translate node addresses (and ports) that are // discovered or received as a node event. This can be useful in an ec2 environment, // for instance, to translate public IPs to private IPs. type AddressTranslator interface { // Translate will translate the provided address and/or port to another // address and/or port. If no translation is possible, Translate will return the // address and port provided to it. Translate(addr net.IP, port int) (net.IP, int) } type AddressTranslatorFunc func(addr net.IP, port int) (net.IP, int) func (fn AddressTranslatorFunc) Translate(addr net.IP, port int) (net.IP, int) { return fn(addr, port) } // IdentityTranslator will do nothing but return what it was provided. It is essentially a no-op. func IdentityTranslator() AddressTranslator { return AddressTranslatorFunc(func(addr net.IP, port int) (net.IP, int) { return addr, port }) } cassandra-gocql-driver-1.7.0/address_translators_test.go000066400000000000000000000036301467504044300235320ustar00rootroot00000000000000/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. */ package gocql import ( "net" "testing" ) func TestIdentityAddressTranslator_NilAddrAndZeroPort(t *testing.T) { var tr AddressTranslator = IdentityTranslator() hostIP := net.ParseIP("") if hostIP != nil { t.Errorf("expected host ip to be (nil) but was (%+v) instead", hostIP) } addr, port := tr.Translate(hostIP, 0) if addr != nil { t.Errorf("expected translated host to be (nil) but was (%+v) instead", addr) } assertEqual(t, "translated port", 0, port) } func TestIdentityAddressTranslator_HostProvided(t *testing.T) { var tr AddressTranslator = IdentityTranslator() hostIP := net.ParseIP("10.1.2.3") if hostIP == nil { t.Error("expected host ip not to be (nil)") } addr, port := tr.Translate(hostIP, 9042) if !hostIP.Equal(addr) { t.Errorf("expected translated addr to be (%+v) but was (%+v) instead", hostIP, addr) } assertEqual(t, "translated port", 9042, port) } cassandra-gocql-driver-1.7.0/batch_test.go000066400000000000000000000047751467504044300205450ustar00rootroot00000000000000//go:build all || cassandra // +build all cassandra /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. */ package gocql import ( "testing" "time" ) func TestBatch_Errors(t *testing.T) { if *flagProto == 1 { } session := createSession(t) defer session.Close() if session.cfg.ProtoVersion < protoVersion2 { t.Skip("atomic batches not supported. Please use Cassandra >= 2.0") } if err := createTable(session, `CREATE TABLE gocql_test.batch_errors (id int primary key, val inet)`); err != nil { t.Fatal(err) } b := session.NewBatch(LoggedBatch) b.Query("SELECT * FROM batch_errors WHERE id=2 AND val=?", nil) if err := session.ExecuteBatch(b); err == nil { t.Fatal("expected to get error for invalid query in batch") } } func TestBatch_WithTimestamp(t *testing.T) { session := createSession(t) defer session.Close() if session.cfg.ProtoVersion < protoVersion3 { t.Skip("Batch timestamps are only available on protocol >= 3") } if err := createTable(session, `CREATE TABLE gocql_test.batch_ts (id int primary key, val text)`); err != nil { t.Fatal(err) } micros := time.Now().UnixNano()/1e3 - 1000 b := session.NewBatch(LoggedBatch) b.WithTimestamp(micros) b.Query("INSERT INTO batch_ts (id, val) VALUES (?, ?)", 1, "val") if err := session.ExecuteBatch(b); err != nil { t.Fatal(err) } var storedTs int64 if err := session.Query(`SELECT writetime(val) FROM batch_ts WHERE id = ?`, 1).Scan(&storedTs); err != nil { t.Fatal(err) } if storedTs != micros { t.Errorf("got ts 
%d, expected %d", storedTs, micros) } } cassandra-gocql-driver-1.7.0/cass1batch_test.go000066400000000000000000000052161467504044300214670ustar00rootroot00000000000000//go:build all || cassandra // +build all cassandra /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. 
*/ package gocql import ( "strings" "testing" ) func TestProto1BatchInsert(t *testing.T) { session := createSession(t) defer session.Close() if err := createTable(session, "CREATE TABLE gocql_test.large (id int primary key)"); err != nil { t.Fatal(err) } begin := "BEGIN BATCH" end := "APPLY BATCH" query := "INSERT INTO large (id) VALUES (?)" fullQuery := strings.Join([]string{begin, query, end}, "\n") args := []interface{}{5} if err := session.Query(fullQuery, args...).Consistency(Quorum).Exec(); err != nil { t.Fatal(err) } } func TestShouldPrepareFunction(t *testing.T) { var shouldPrepareTests = []struct { Stmt string Result bool }{ {` BEGIN BATCH INSERT INTO users (userID, password) VALUES ('smith', 'secret') APPLY BATCH ; `, true}, {`INSERT INTO users (userID, password, name) VALUES ('user2', 'ch@ngem3b', 'second user')`, true}, {`BEGIN COUNTER BATCH UPDATE stats SET views = views + 1 WHERE pageid = 1 APPLY BATCH`, true}, {`delete name from users where userID = 'smith';`, true}, {` UPDATE users SET password = 'secret' WHERE userID = 'smith' `, true}, {`CREATE TABLE users ( user_name varchar PRIMARY KEY, password varchar, gender varchar, session_token varchar, state varchar, birth_year bigint );`, false}, } for _, test := range shouldPrepareTests { q := &Query{stmt: test.Stmt, routingInfo: &queryRoutingInfo{}} if got := q.shouldPrepare(); got != test.Result { t.Fatalf("%q: got %v, expected %v\n", test.Stmt, got, test.Result) } } } cassandra-gocql-driver-1.7.0/cassandra_test.go000066400000000000000000003165241467504044300214210ustar00rootroot00000000000000//go:build all || cassandra // +build all cassandra /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. 
The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. */ package gocql import ( "bytes" "context" "errors" "fmt" "io" "math" "math/big" "net" "reflect" "strconv" "strings" "sync" "testing" "time" "unicode" inf "gopkg.in/inf.v0" ) func TestEmptyHosts(t *testing.T) { cluster := createCluster() cluster.Hosts = nil if session, err := cluster.CreateSession(); err == nil { session.Close() t.Error("expected err, got nil") } } func TestInvalidPeerEntry(t *testing.T) { t.Skip("dont mutate system tables, rewrite this to test what we mean to test") session := createSession(t) // rack, release_version, schema_version, tokens are all null query := session.Query("INSERT into system.peers (peer, data_center, host_id, rpc_address) VALUES (?, ?, ?, ?)", "169.254.235.45", "datacenter1", "35c0ec48-5109-40fd-9281-9e9d4add2f1e", "169.254.235.45", ) if err := query.Exec(); err != nil { t.Fatal(err) } session.Close() cluster := createCluster() cluster.PoolConfig.HostSelectionPolicy = TokenAwareHostPolicy(RoundRobinHostPolicy()) session = createSessionFromCluster(cluster, t) defer func() { session.Query("DELETE from system.peers where peer = ?", "169.254.235.45").Exec() session.Close() }() // check we can perform a query iter := session.Query("select peer 
from system.peers").Iter() var peer string for iter.Scan(&peer) { } if err := iter.Close(); err != nil { t.Fatal(err) } } // TestUseStatementError checks to make sure the correct error is returned when the user tries to execute a use statement. func TestUseStatementError(t *testing.T) { session := createSession(t) defer session.Close() if err := session.Query("USE gocql_test").Exec(); err != nil { if err != ErrUseStmt { t.Fatalf("expected ErrUseStmt, got " + err.Error()) } } else { t.Fatal("expected err, got nil.") } } // TestInvalidKeyspace checks that an invalid keyspace will return promptly and without a flood of connections func TestInvalidKeyspace(t *testing.T) { cluster := createCluster() cluster.Keyspace = "invalidKeyspace" session, err := cluster.CreateSession() if err != nil { if err != ErrNoConnectionsStarted { t.Fatalf("Expected ErrNoConnections but got %v", err) } } else { session.Close() //Clean up the session t.Fatal("expected err, got nil.") } } func TestTracing(t *testing.T) { session := createSession(t) defer session.Close() if err := createTable(session, `CREATE TABLE gocql_test.trace (id int primary key)`); err != nil { t.Fatal("create:", err) } buf := &bytes.Buffer{} trace := &traceWriter{session: session, w: buf} if err := session.Query(`INSERT INTO trace (id) VALUES (?)`, 42).Trace(trace).Exec(); err != nil { t.Fatal("insert:", err) } else if buf.Len() == 0 { t.Fatal("insert: failed to obtain any tracing") } trace.mu.Lock() buf.Reset() trace.mu.Unlock() var value int if err := session.Query(`SELECT id FROM trace WHERE id = ?`, 42).Trace(trace).Scan(&value); err != nil { t.Fatal("select:", err) } else if value != 42 { t.Fatalf("value: expected %d, got %d", 42, value) } else if buf.Len() == 0 { t.Fatal("select: failed to obtain any tracing") } // also works from session tracer session.SetTrace(trace) trace.mu.Lock() buf.Reset() trace.mu.Unlock() if err := session.Query(`SELECT id FROM trace WHERE id = ?`, 42).Scan(&value); err != nil { 
t.Fatal("select:", err) } if buf.Len() == 0 { t.Fatal("select: failed to obtain any tracing") } } func TestObserve(t *testing.T) { session := createSession(t) defer session.Close() if err := createTable(session, `CREATE TABLE gocql_test.observe (id int primary key)`); err != nil { t.Fatal("create:", err) } var ( observedErr error observedKeyspace string observedStmt string ) const keyspace = "gocql_test" resetObserved := func() { observedErr = errors.New("placeholder only") // used to distinguish err=nil cases observedKeyspace = "" observedStmt = "" } observer := funcQueryObserver(func(ctx context.Context, o ObservedQuery) { observedKeyspace = o.Keyspace observedStmt = o.Statement observedErr = o.Err }) // select before inserted, will error but the reporting is err=nil as the query is valid resetObserved() var value int if err := session.Query(`SELECT id FROM observe WHERE id = ?`, 43).Observer(observer).Scan(&value); err == nil { t.Fatal("select: expected error") } else if observedErr != nil { t.Fatalf("select: observed error expected nil, got %q", observedErr) } else if observedKeyspace != keyspace { t.Fatal("select: unexpected observed keyspace", observedKeyspace) } else if observedStmt != `SELECT id FROM observe WHERE id = ?` { t.Fatal("select: unexpected observed stmt", observedStmt) } resetObserved() if err := session.Query(`INSERT INTO observe (id) VALUES (?)`, 42).Observer(observer).Exec(); err != nil { t.Fatal("insert:", err) } else if observedErr != nil { t.Fatal("insert:", observedErr) } else if observedKeyspace != keyspace { t.Fatal("insert: unexpected observed keyspace", observedKeyspace) } else if observedStmt != `INSERT INTO observe (id) VALUES (?)` { t.Fatal("insert: unexpected observed stmt", observedStmt) } resetObserved() value = 0 if err := session.Query(`SELECT id FROM observe WHERE id = ?`, 42).Observer(observer).Scan(&value); err != nil { t.Fatal("select:", err) } else if value != 42 { t.Fatalf("value: expected %d, got %d", 42, value) } else 
if observedErr != nil { t.Fatal("select:", observedErr) } else if observedKeyspace != keyspace { t.Fatal("select: unexpected observed keyspace", observedKeyspace) } else if observedStmt != `SELECT id FROM observe WHERE id = ?` { t.Fatal("select: unexpected observed stmt", observedStmt) } // also works from session observer resetObserved() oSession := createSession(t, func(config *ClusterConfig) { config.QueryObserver = observer }) if err := oSession.Query(`SELECT id FROM observe WHERE id = ?`, 42).Scan(&value); err != nil { t.Fatal("select:", err) } else if observedErr != nil { t.Fatal("select:", err) } else if observedKeyspace != keyspace { t.Fatal("select: unexpected observed keyspace", observedKeyspace) } else if observedStmt != `SELECT id FROM observe WHERE id = ?` { t.Fatal("select: unexpected observed stmt", observedStmt) } // reports errors when the query is poorly formed resetObserved() value = 0 if err := session.Query(`SELECT id FROM unknown_table WHERE id = ?`, 42).Observer(observer).Scan(&value); err == nil { t.Fatal("select: expecting error") } else if observedErr == nil { t.Fatal("select: expecting observed error") } else if observedKeyspace != keyspace { t.Fatal("select: unexpected observed keyspace", observedKeyspace) } else if observedStmt != `SELECT id FROM unknown_table WHERE id = ?` { t.Fatal("select: unexpected observed stmt", observedStmt) } } func TestObserve_Pagination(t *testing.T) { session := createSession(t) defer session.Close() if err := createTable(session, `CREATE TABLE gocql_test.observe2 (id int, PRIMARY KEY (id))`); err != nil { t.Fatal("create:", err) } var observedRows int resetObserved := func() { observedRows = -1 } observer := funcQueryObserver(func(ctx context.Context, o ObservedQuery) { observedRows = o.Rows }) // insert 100 entries, relevant for pagination for i := 0; i < 50; i++ { if err := session.Query(`INSERT INTO observe2 (id) VALUES (?)`, i).Exec(); err != nil { t.Fatal("insert:", err) } } resetObserved() // read the 
100 entries in paginated entries of size 10. Expecting 5 observations, each with 10 rows scanner := session.Query(`SELECT id FROM observe2 LIMIT 100`). Observer(observer). PageSize(10). Iter().Scanner() for i := 0; i < 50; i++ { if !scanner.Next() { t.Fatalf("next: should still be true: %d: %v", i, scanner.Err()) } if i%10 == 0 { if observedRows != 10 { t.Fatalf("next: expecting a paginated query with 10 entries, got: %d (%d)", observedRows, i) } } else if observedRows != -1 { t.Fatalf("next: not expecting paginated query (-1 entries), got: %d", observedRows) } resetObserved() } if scanner.Next() { t.Fatal("next: no more entries where expected") } } func TestPaging(t *testing.T) { session := createSession(t) defer session.Close() if session.cfg.ProtoVersion == 1 { t.Skip("Paging not supported. Please use Cassandra >= 2.0") } if err := createTable(session, "CREATE TABLE gocql_test.paging (id int primary key)"); err != nil { t.Fatal("create table:", err) } for i := 0; i < 100; i++ { if err := session.Query("INSERT INTO paging (id) VALUES (?)", i).Exec(); err != nil { t.Fatal("insert:", err) } } iter := session.Query("SELECT id FROM paging").PageSize(10).Iter() var id int count := 0 for iter.Scan(&id) { count++ } if err := iter.Close(); err != nil { t.Fatal("close:", err) } if count != 100 { t.Fatalf("expected %d, got %d", 100, count) } } func TestPagingWithBind(t *testing.T) { session := createSession(t) defer session.Close() if session.cfg.ProtoVersion == 1 { t.Skip("Paging not supported. Please use Cassandra >= 2.0") } if err := createTable(session, "CREATE TABLE gocql_test.paging_bind (id int, val int, primary key(id,val))"); err != nil { t.Fatal("create table:", err) } for i := 0; i < 100; i++ { if err := session.Query("INSERT INTO paging_bind (id,val) VALUES (?,?)", 1, i).Exec(); err != nil { t.Fatal("insert:", err) } } q := session.Query("SELECT val FROM paging_bind WHERE id = ? 
AND val < ?", 1, 50).PageSize(10) iter := q.Iter() var id int count := 0 for iter.Scan(&id) { count++ } if err := iter.Close(); err != nil { t.Fatal("close:", err) } if count != 50 { t.Fatalf("expected %d, got %d", 50, count) } iter = q.Bind(1, 20).Iter() count = 0 for iter.Scan(&id) { count++ } if count != 20 { t.Fatalf("expected %d, got %d", 20, count) } if err := iter.Close(); err != nil { t.Fatal("close:", err) } } func TestCAS(t *testing.T) { cluster := createCluster() cluster.SerialConsistency = LocalSerial session := createSessionFromCluster(cluster, t) defer session.Close() if session.cfg.ProtoVersion == 1 { t.Skip("lightweight transactions not supported. Please use Cassandra >= 2.0") } if err := createTable(session, `CREATE TABLE gocql_test.cas_table ( title varchar, revid timeuuid, last_modified timestamp, PRIMARY KEY (title, revid) )`); err != nil { t.Fatal("create:", err) } title, revid, modified := "baz", TimeUUID(), time.Now() var titleCAS string var revidCAS UUID var modifiedCAS time.Time if applied, err := session.Query(`INSERT INTO cas_table (title, revid, last_modified) VALUES (?, ?, ?) IF NOT EXISTS`, title, revid, modified).ScanCAS(&titleCAS, &revidCAS, &modifiedCAS); err != nil { t.Fatal("insert:", err) } else if !applied { t.Fatal("insert should have been applied") } if applied, err := session.Query(`INSERT INTO cas_table (title, revid, last_modified) VALUES (?, ?, ?) IF NOT EXISTS`, title, revid, modified).ScanCAS(&titleCAS, &revidCAS, &modifiedCAS); err != nil { t.Fatal("insert:", err) } else if applied { t.Fatal("insert should not have been applied") } else if title != titleCAS || revid != revidCAS { t.Fatalf("expected %s/%v/%v but got %s/%v/%v", title, revid, modified, titleCAS, revidCAS, modifiedCAS) } tenSecondsLater := modified.Add(10 * time.Second) if applied, err := session.Query(`DELETE FROM cas_table WHERE title = ? and revid = ? 
IF last_modified = ?`, title, revid, tenSecondsLater).ScanCAS(&modifiedCAS); err != nil { t.Fatal("delete:", err) } else if applied { t.Fatal("delete should have not been applied") } if modifiedCAS.Unix() != tenSecondsLater.Add(-10*time.Second).Unix() { t.Fatalf("Was expecting modified CAS to be %v; but was one second later", modifiedCAS.UTC()) } if _, err := session.Query(`DELETE FROM cas_table WHERE title = ? and revid = ? IF last_modified = ?`, title, revid, tenSecondsLater).ScanCAS(); !strings.HasPrefix(err.Error(), "gocql: not enough columns to scan into") { t.Fatalf("delete: was expecting count mismatch error but got: %q", err.Error()) } if applied, err := session.Query(`DELETE FROM cas_table WHERE title = ? and revid = ? IF last_modified = ?`, title, revid, modified).ScanCAS(&modifiedCAS); err != nil { t.Fatal("delete:", err) } else if !applied { t.Fatal("delete should have been applied") } if err := session.Query(`TRUNCATE cas_table`).Exec(); err != nil { t.Fatal("truncate:", err) } successBatch := session.NewBatch(LoggedBatch) successBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES (?, ?, ?) IF NOT EXISTS", title, revid, modified) if applied, _, err := session.ExecuteBatchCAS(successBatch, &titleCAS, &revidCAS, &modifiedCAS); err != nil { t.Fatal("insert:", err) } else if !applied { t.Fatalf("insert should have been applied: title=%v revID=%v modified=%v", titleCAS, revidCAS, modifiedCAS) } successBatch = session.NewBatch(LoggedBatch) successBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES (?, ?, ?) IF NOT EXISTS", title+"_foo", revid, modified) casMap := make(map[string]interface{}) if applied, _, err := session.MapExecuteBatchCAS(successBatch, casMap); err != nil { t.Fatal("insert:", err) } else if !applied { t.Fatal("insert should have been applied") } failBatch := session.NewBatch(LoggedBatch) failBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES (?, ?, ?) 
IF NOT EXISTS", title, revid, modified) if applied, _, err := session.ExecuteBatchCAS(successBatch, &titleCAS, &revidCAS, &modifiedCAS); err != nil { t.Fatal("insert:", err) } else if applied { t.Fatalf("insert should have been applied: title=%v revID=%v modified=%v", titleCAS, revidCAS, modifiedCAS) } insertBatch := session.NewBatch(LoggedBatch) insertBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES ('_foo', 2c3af400-73a4-11e5-9381-29463d90c3f0, DATEOF(NOW()))") insertBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES ('_foo', 3e4ad2f1-73a4-11e5-9381-29463d90c3f0, DATEOF(NOW()))") if err := session.ExecuteBatch(insertBatch); err != nil { t.Fatal("insert:", err) } failBatch = session.NewBatch(LoggedBatch) failBatch.Query("UPDATE cas_table SET last_modified = DATEOF(NOW()) WHERE title='_foo' AND revid=2c3af400-73a4-11e5-9381-29463d90c3f0 IF last_modified=DATEOF(NOW());") failBatch.Query("UPDATE cas_table SET last_modified = DATEOF(NOW()) WHERE title='_foo' AND revid=3e4ad2f1-73a4-11e5-9381-29463d90c3f0 IF last_modified=DATEOF(NOW());") if applied, iter, err := session.ExecuteBatchCAS(failBatch, &titleCAS, &revidCAS, &modifiedCAS); err != nil { t.Fatal("insert:", err) } else if applied { t.Fatalf("insert should have been applied: title=%v revID=%v modified=%v", titleCAS, revidCAS, modifiedCAS) } else { if scan := iter.Scan(&applied, &titleCAS, &revidCAS, &modifiedCAS); scan && applied { t.Fatalf("insert should have been applied: title=%v revID=%v modified=%v", titleCAS, revidCAS, modifiedCAS) } else if !scan { t.Fatal("should have scanned another row") } if err := iter.Close(); err != nil { t.Fatal("scan:", err) } } } func TestDurationType(t *testing.T) { session := createSession(t) defer session.Close() if session.cfg.ProtoVersion < 5 { t.Skip("Duration type is not supported. 
Please use protocol version >= 4 and cassandra version >= 3.11") } if err := createTable(session, `CREATE TABLE gocql_test.duration_table ( k int primary key, v duration )`); err != nil { t.Fatal("create:", err) } durations := []Duration{ Duration{ Months: 250, Days: 500, Nanoseconds: 300010001, }, Duration{ Months: -250, Days: -500, Nanoseconds: -300010001, }, Duration{ Months: 0, Days: 128, Nanoseconds: 127, }, Duration{ Months: 0x7FFFFFFF, Days: 0x7FFFFFFF, Nanoseconds: 0x7FFFFFFFFFFFFFFF, }, } for _, durationSend := range durations { if err := session.Query(`INSERT INTO gocql_test.duration_table (k, v) VALUES (1, ?)`, durationSend).Exec(); err != nil { t.Fatal(err) } var id int var duration Duration if err := session.Query(`SELECT k, v FROM gocql_test.duration_table`).Scan(&id, &duration); err != nil { t.Fatal(err) } if duration.Months != durationSend.Months || duration.Days != durationSend.Days || duration.Nanoseconds != durationSend.Nanoseconds { t.Fatalf("Unexpeted value returned, expected=%v, received=%v", durationSend, duration) } } } func TestMapScanCAS(t *testing.T) { session := createSession(t) defer session.Close() if session.cfg.ProtoVersion == 1 { t.Skip("lightweight transactions not supported. Please use Cassandra >= 2.0") } if err := createTable(session, `CREATE TABLE gocql_test.cas_table2 ( title varchar, revid timeuuid, last_modified timestamp, deleted boolean, PRIMARY KEY (title, revid) )`); err != nil { t.Fatal("create:", err) } title, revid, modified, deleted := "baz", TimeUUID(), time.Now(), false mapCAS := map[string]interface{}{} if applied, err := session.Query(`INSERT INTO cas_table2 (title, revid, last_modified, deleted) VALUES (?, ?, ?, ?) 
IF NOT EXISTS`, title, revid, modified, deleted).MapScanCAS(mapCAS); err != nil { t.Fatal("insert:", err) } else if !applied { t.Fatalf("insert should have been applied: title=%v revID=%v modified=%v", title, revid, modified) } mapCAS = map[string]interface{}{} if applied, err := session.Query(`INSERT INTO cas_table2 (title, revid, last_modified, deleted) VALUES (?, ?, ?, ?) IF NOT EXISTS`, title, revid, modified, deleted).MapScanCAS(mapCAS); err != nil { t.Fatal("insert:", err) } else if applied { t.Fatalf("insert should have been applied: title=%v revID=%v modified=%v", title, revid, modified) } else if title != mapCAS["title"] || revid != mapCAS["revid"] || deleted != mapCAS["deleted"] { t.Fatalf("expected %s/%v/%v/%v but got %s/%v/%v%v", title, revid, modified, false, mapCAS["title"], mapCAS["revid"], mapCAS["last_modified"], mapCAS["deleted"]) } } func TestBatch(t *testing.T) { session := createSession(t) defer session.Close() if session.cfg.ProtoVersion == 1 { t.Skip("atomic batches not supported. Please use Cassandra >= 2.0") } if err := createTable(session, `CREATE TABLE gocql_test.batch_table (id int primary key)`); err != nil { t.Fatal("create table:", err) } batch := session.NewBatch(LoggedBatch) for i := 0; i < 100; i++ { batch.Query(`INSERT INTO batch_table (id) VALUES (?)`, i) } if err := session.ExecuteBatch(batch); err != nil { t.Fatal("execute batch:", err) } count := 0 if err := session.Query(`SELECT COUNT(*) FROM batch_table`).Scan(&count); err != nil { t.Fatal("select count:", err) } else if count != 100 { t.Fatalf("count: expected %d, got %d\n", 100, count) } } func TestUnpreparedBatch(t *testing.T) { t.Skip("FLAKE skipping") session := createSession(t) defer session.Close() if session.cfg.ProtoVersion == 1 { t.Skip("atomic batches not supported. 
Please use Cassandra >= 2.0") } if err := createTable(session, `CREATE TABLE gocql_test.batch_unprepared (id int primary key, c counter)`); err != nil { t.Fatal("create table:", err) } var batch *Batch if session.cfg.ProtoVersion == 2 { batch = session.NewBatch(CounterBatch) } else { batch = session.NewBatch(UnloggedBatch) } for i := 0; i < 100; i++ { batch.Query(`UPDATE batch_unprepared SET c = c + 1 WHERE id = 1`) } if err := session.ExecuteBatch(batch); err != nil { t.Fatal("execute batch:", err) } count := 0 if err := session.Query(`SELECT COUNT(*) FROM batch_unprepared`).Scan(&count); err != nil { t.Fatal("select count:", err) } else if count != 1 { t.Fatalf("count: expected %d, got %d\n", 100, count) } if err := session.Query(`SELECT c FROM batch_unprepared`).Scan(&count); err != nil { t.Fatal("select count:", err) } else if count != 100 { t.Fatalf("count: expected %d, got %d\n", 100, count) } } // TestBatchLimit tests gocql to make sure batch operations larger than the maximum // statement limit are not submitted to a cassandra node. func TestBatchLimit(t *testing.T) { session := createSession(t) defer session.Close() if session.cfg.ProtoVersion == 1 { t.Skip("atomic batches not supported. 
Please use Cassandra >= 2.0") } if err := createTable(session, `CREATE TABLE gocql_test.batch_table2 (id int primary key)`); err != nil { t.Fatal("create table:", err) } batch := session.NewBatch(LoggedBatch) for i := 0; i < 65537; i++ { batch.Query(`INSERT INTO batch_table2 (id) VALUES (?)`, i) } if err := session.ExecuteBatch(batch); err != ErrTooManyStmts { t.Fatal("gocql attempted to execute a batch larger than the support limit of statements.") } } func TestWhereIn(t *testing.T) { session := createSession(t) defer session.Close() if err := createTable(session, `CREATE TABLE gocql_test.where_in_table (id int, cluster int, primary key (id,cluster))`); err != nil { t.Fatal("create table:", err) } if err := session.Query("INSERT INTO where_in_table (id, cluster) VALUES (?,?)", 100, 200).Exec(); err != nil { t.Fatal("insert:", err) } iter := session.Query("SELECT * FROM where_in_table WHERE id = ? AND cluster IN (?)", 100, 200).Iter() var id, cluster int count := 0 for iter.Scan(&id, &cluster) { count++ } if id != 100 || cluster != 200 { t.Fatalf("Was expecting id and cluster to be (100,200) but were (%d,%d)", id, cluster) } } // TestTooManyQueryArgs tests to make sure the library correctly handles the application level bug // whereby too many query arguments are passed to a query func TestTooManyQueryArgs(t *testing.T) { session := createSession(t) defer session.Close() if session.cfg.ProtoVersion == 1 { t.Skip("atomic batches not supported. 
Please use Cassandra >= 2.0") } if err := createTable(session, `CREATE TABLE gocql_test.too_many_query_args (id int primary key, value int)`); err != nil { t.Fatal("create table:", err) } _, err := session.Query(`SELECT * FROM too_many_query_args WHERE id = ?`, 1, 2).Iter().SliceMap() if err == nil { t.Fatal("'`SELECT * FROM too_many_query_args WHERE id = ?`, 1, 2' should return an error") } batch := session.NewBatch(UnloggedBatch) batch.Query("INSERT INTO too_many_query_args (id, value) VALUES (?, ?)", 1, 2, 3) err = session.ExecuteBatch(batch) if err == nil { t.Fatal("'`INSERT INTO too_many_query_args (id, value) VALUES (?, ?)`, 1, 2, 3' should return an error") } // TODO: should indicate via an error code that it is an invalid arg? } // TestNotEnoughQueryArgs tests to make sure the library correctly handles the application level bug // whereby not enough query arguments are passed to a query func TestNotEnoughQueryArgs(t *testing.T) { session := createSession(t) defer session.Close() if session.cfg.ProtoVersion == 1 { t.Skip("atomic batches not supported. Please use Cassandra >= 2.0") } if err := createTable(session, `CREATE TABLE gocql_test.not_enough_query_args (id int, cluster int, value int, primary key (id, cluster))`); err != nil { t.Fatal("create table:", err) } _, err := session.Query(`SELECT * FROM not_enough_query_args WHERE id = ? and cluster = ?`, 1).Iter().SliceMap() if err == nil { t.Fatal("'`SELECT * FROM not_enough_query_args WHERE id = ? and cluster = ?`, 1' should return an error") } batch := session.NewBatch(UnloggedBatch) batch.Query("INSERT INTO not_enough_query_args (id, cluster, value) VALUES (?, ?, ?)", 1, 2) err = session.ExecuteBatch(batch) if err == nil { t.Fatal("'`INSERT INTO not_enough_query_args (id, cluster, value) VALUES (?, ?, ?)`, 1, 2' should return an error") } } // TestCreateSessionTimeout tests to make sure the CreateSession function timeouts out correctly // and prevents an infinite loop of connection retries. 
func TestCreateSessionTimeout(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() go func() { select { case <-time.After(2 * time.Second): t.Error("no startup timeout") case <-ctx.Done(): } }() cluster := createCluster() cluster.Hosts = []string{"127.0.0.1:1"} session, err := cluster.CreateSession() if err == nil { session.Close() t.Fatal("expected ErrNoConnectionsStarted, but no error was returned.") } } func TestReconnection(t *testing.T) { cluster := createCluster() cluster.ReconnectInterval = 1 * time.Second session := createSessionFromCluster(cluster, t) defer session.Close() h := session.ring.allHosts()[0] session.handleNodeDown(h.ConnectAddress(), h.Port()) if h.State() != NodeDown { t.Fatal("Host should be NodeDown but not.") } time.Sleep(cluster.ReconnectInterval + h.Version().nodeUpDelay() + 1*time.Second) if h.State() != NodeUp { t.Fatal("Host should be NodeUp but not. Failed to reconnect.") } } type FullName struct { FirstName string LastName string } func (n FullName) MarshalCQL(info TypeInfo) ([]byte, error) { return []byte(n.FirstName + " " + n.LastName), nil } func (n *FullName) UnmarshalCQL(info TypeInfo, data []byte) error { t := strings.SplitN(string(data), " ", 2) n.FirstName, n.LastName = t[0], t[1] return nil } func TestMapScanWithRefMap(t *testing.T) { session := createSession(t) defer session.Close() if err := createTable(session, `CREATE TABLE gocql_test.scan_map_ref_table ( testtext text PRIMARY KEY, testfullname text, testint int, )`); err != nil { t.Fatal("create table:", err) } m := make(map[string]interface{}) m["testtext"] = "testtext" m["testfullname"] = FullName{"John", "Doe"} m["testint"] = 100 if err := session.Query(`INSERT INTO scan_map_ref_table (testtext, testfullname, testint) values (?,?,?)`, m["testtext"], m["testfullname"], m["testint"]).Exec(); err != nil { t.Fatal("insert:", err) } var testText string var testFullName FullName ret := map[string]interface{}{ "testtext": &testText, 
"testfullname": &testFullName, // testint is not set here. } iter := session.Query(`SELECT * FROM scan_map_ref_table`).Iter() if ok := iter.MapScan(ret); !ok { t.Fatal("select:", iter.Close()) } else { if ret["testtext"] != "testtext" { t.Fatal("returned testtext did not match") } f := ret["testfullname"].(FullName) if f.FirstName != "John" || f.LastName != "Doe" { t.Fatal("returned testfullname did not match") } if ret["testint"] != 100 { t.Fatal("returned testinit did not match") } } if testText != "testtext" { t.Fatal("returned testtext did not match") } if testFullName.FirstName != "John" || testFullName.LastName != "Doe" { t.Fatal("returned testfullname did not match") } // using MapScan to read a nil int value intp := new(int64) ret = map[string]interface{}{ "testint": &intp, } if err := session.Query("INSERT INTO scan_map_ref_table(testtext, testint) VALUES(?, ?)", "null-int", nil).Exec(); err != nil { t.Fatal(err) } err := session.Query(`SELECT testint FROM scan_map_ref_table WHERE testtext = ?`, "null-int").MapScan(ret) if err != nil { t.Fatal(err) } else if v := ret["testint"].(*int64); v != nil { t.Fatalf("testint should be nil got %+#v", v) } } func TestMapScan(t *testing.T) { session := createSession(t) defer session.Close() if err := createTable(session, `CREATE TABLE gocql_test.scan_map_table ( fullname text PRIMARY KEY, age int, address inet, )`); err != nil { t.Fatal("create table:", err) } if err := session.Query(`INSERT INTO scan_map_table (fullname, age, address) values (?,?,?)`, "Grace Hopper", 31, net.ParseIP("10.0.0.1")).Exec(); err != nil { t.Fatal("insert:", err) } if err := session.Query(`INSERT INTO scan_map_table (fullname, age, address) values (?,?,?)`, "Ada Lovelace", 30, net.ParseIP("10.0.0.2")).Exec(); err != nil { t.Fatal("insert:", err) } iter := session.Query(`SELECT * FROM scan_map_table`).Iter() // First iteration row := make(map[string]interface{}) if !iter.MapScan(row) { t.Fatal("select:", iter.Close()) } assertEqual(t, 
"fullname", "Ada Lovelace", row["fullname"]) assertEqual(t, "age", 30, row["age"]) assertEqual(t, "address", "10.0.0.2", row["address"]) // Second iteration using a new map row = make(map[string]interface{}) if !iter.MapScan(row) { t.Fatal("select:", iter.Close()) } assertEqual(t, "fullname", "Grace Hopper", row["fullname"]) assertEqual(t, "age", 31, row["age"]) assertEqual(t, "address", "10.0.0.1", row["address"]) } func TestSliceMap(t *testing.T) { session := createSession(t) defer session.Close() if err := createTable(session, `CREATE TABLE gocql_test.slice_map_table ( testuuid timeuuid PRIMARY KEY, testtimestamp timestamp, testvarchar varchar, testbigint bigint, testblob blob, testbool boolean, testfloat float, testdouble double, testint int, testdecimal decimal, testlist list, testset set, testmap map, testvarint varint, testinet inet )`); err != nil { t.Fatal("create table:", err) } m := make(map[string]interface{}) bigInt := new(big.Int) if _, ok := bigInt.SetString("830169365738487321165427203929228", 10); !ok { t.Fatal("Failed setting bigint by string") } m["testuuid"] = TimeUUID() m["testvarchar"] = "Test VarChar" m["testbigint"] = time.Now().Unix() m["testtimestamp"] = time.Now().Truncate(time.Millisecond).UTC() m["testblob"] = []byte("test blob") m["testbool"] = true m["testfloat"] = float32(4.564) m["testdouble"] = float64(4.815162342) m["testint"] = 2343 m["testdecimal"] = inf.NewDec(100, 0) m["testlist"] = []string{"quux", "foo", "bar", "baz", "quux"} m["testset"] = []int{1, 2, 3, 4, 5, 6, 7, 8, 9} m["testmap"] = map[string]string{"field1": "val1", "field2": "val2", "field3": "val3"} m["testvarint"] = bigInt m["testinet"] = "213.212.2.19" sliceMap := []map[string]interface{}{m} if err := session.Query(`INSERT INTO slice_map_table (testuuid, testtimestamp, testvarchar, testbigint, testblob, testbool, testfloat, testdouble, testint, testdecimal, testlist, testset, testmap, testvarint, testinet) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, 
m["testuuid"], m["testtimestamp"], m["testvarchar"], m["testbigint"], m["testblob"], m["testbool"], m["testfloat"], m["testdouble"], m["testint"], m["testdecimal"], m["testlist"], m["testset"], m["testmap"], m["testvarint"], m["testinet"]).Exec(); err != nil { t.Fatal("insert:", err) } if returned, retErr := session.Query(`SELECT * FROM slice_map_table`).Iter().SliceMap(); retErr != nil { t.Fatal("select:", retErr) } else { matchSliceMap(t, sliceMap, returned[0]) } // Test for Iter.MapScan() { testMap := make(map[string]interface{}) if !session.Query(`SELECT * FROM slice_map_table`).Iter().MapScan(testMap) { t.Fatal("MapScan failed to work with one row") } matchSliceMap(t, sliceMap, testMap) } // Test for Query.MapScan() { testMap := make(map[string]interface{}) if session.Query(`SELECT * FROM slice_map_table`).MapScan(testMap) != nil { t.Fatal("MapScan failed to work with one row") } matchSliceMap(t, sliceMap, testMap) } } func matchSliceMap(t *testing.T, sliceMap []map[string]interface{}, testMap map[string]interface{}) { if sliceMap[0]["testuuid"] != testMap["testuuid"] { t.Fatal("returned testuuid did not match") } if sliceMap[0]["testtimestamp"] != testMap["testtimestamp"] { t.Fatal("returned testtimestamp did not match") } if sliceMap[0]["testvarchar"] != testMap["testvarchar"] { t.Fatal("returned testvarchar did not match") } if sliceMap[0]["testbigint"] != testMap["testbigint"] { t.Fatal("returned testbigint did not match") } if !reflect.DeepEqual(sliceMap[0]["testblob"], testMap["testblob"]) { t.Fatal("returned testblob did not match") } if sliceMap[0]["testbool"] != testMap["testbool"] { t.Fatal("returned testbool did not match") } if sliceMap[0]["testfloat"] != testMap["testfloat"] { t.Fatal("returned testfloat did not match") } if sliceMap[0]["testdouble"] != testMap["testdouble"] { t.Fatal("returned testdouble did not match") } if sliceMap[0]["testinet"] != testMap["testinet"] { t.Fatal("returned testinet did not match") } expectedDecimal := 
sliceMap[0]["testdecimal"].(*inf.Dec) returnedDecimal := testMap["testdecimal"].(*inf.Dec) if expectedDecimal.Cmp(returnedDecimal) != 0 { t.Fatal("returned testdecimal did not match") } if !reflect.DeepEqual(sliceMap[0]["testlist"], testMap["testlist"]) { t.Fatal("returned testlist did not match") } if !reflect.DeepEqual(sliceMap[0]["testset"], testMap["testset"]) { t.Fatal("returned testset did not match") } if !reflect.DeepEqual(sliceMap[0]["testmap"], testMap["testmap"]) { t.Fatal("returned testmap did not match") } if sliceMap[0]["testint"] != testMap["testint"] { t.Fatal("returned testint did not match") } } func TestSmallInt(t *testing.T) { session := createSession(t) defer session.Close() if session.cfg.ProtoVersion < protoVersion4 { t.Skip("smallint is only supported in cassandra 2.2+") } if err := createTable(session, `CREATE TABLE gocql_test.smallint_table ( testsmallint smallint PRIMARY KEY, )`); err != nil { t.Fatal("create table:", err) } m := make(map[string]interface{}) m["testsmallint"] = int16(2) sliceMap := []map[string]interface{}{m} if err := session.Query(`INSERT INTO smallint_table (testsmallint) VALUES (?)`, m["testsmallint"]).Exec(); err != nil { t.Fatal("insert:", err) } if returned, retErr := session.Query(`SELECT * FROM smallint_table`).Iter().SliceMap(); retErr != nil { t.Fatal("select:", retErr) } else { if sliceMap[0]["testsmallint"] != returned[0]["testsmallint"] { t.Fatal("returned testsmallint did not match") } } } func TestScanWithNilArguments(t *testing.T) { session := createSession(t) defer session.Close() if err := createTable(session, `CREATE TABLE gocql_test.scan_with_nil_arguments ( foo varchar, bar int, PRIMARY KEY (foo, bar) )`); err != nil { t.Fatal("create:", err) } for i := 1; i <= 20; i++ { if err := session.Query("INSERT INTO scan_with_nil_arguments (foo, bar) VALUES (?, ?)", "squares", i*i).Exec(); err != nil { t.Fatal("insert:", err) } } iter := session.Query("SELECT * FROM scan_with_nil_arguments WHERE foo = ?", 
"squares").Iter() var n int count := 0 for iter.Scan(nil, &n) { count += n } if err := iter.Close(); err != nil { t.Fatal("close:", err) } if count != 2870 { t.Fatalf("expected %d, got %d", 2870, count) } } func TestScanCASWithNilArguments(t *testing.T) { session := createSession(t) defer session.Close() if session.cfg.ProtoVersion == 1 { t.Skip("lightweight transactions not supported. Please use Cassandra >= 2.0") } if err := createTable(session, `CREATE TABLE gocql_test.scan_cas_with_nil_arguments ( foo varchar, bar varchar, PRIMARY KEY (foo, bar) )`); err != nil { t.Fatal("create:", err) } foo := "baz" var cas string if applied, err := session.Query(`INSERT INTO scan_cas_with_nil_arguments (foo, bar) VALUES (?, ?) IF NOT EXISTS`, foo, foo).ScanCAS(nil, nil); err != nil { t.Fatal("insert:", err) } else if !applied { t.Fatal("insert should have been applied") } if applied, err := session.Query(`INSERT INTO scan_cas_with_nil_arguments (foo, bar) VALUES (?, ?) IF NOT EXISTS`, foo, foo).ScanCAS(&cas, nil); err != nil { t.Fatal("insert:", err) } else if applied { t.Fatal("insert should not have been applied") } else if foo != cas { t.Fatalf("expected %v but got %v", foo, cas) } if applied, err := session.Query(`INSERT INTO scan_cas_with_nil_arguments (foo, bar) VALUES (?, ?) 
IF NOT EXISTS`, foo, foo).ScanCAS(nil, &cas); err != nil { t.Fatal("insert:", err) } else if applied { t.Fatal("insert should not have been applied") } else if foo != cas { t.Fatalf("expected %v but got %v", foo, cas) } } func TestRebindQueryInfo(t *testing.T) { session := createSession(t) defer session.Close() if err := createTable(session, "CREATE TABLE gocql_test.rebind_query (id int, value text, PRIMARY KEY (id))"); err != nil { t.Fatalf("failed to create table with error '%v'", err) } if err := session.Query("INSERT INTO rebind_query (id, value) VALUES (?, ?)", 23, "quux").Exec(); err != nil { t.Fatalf("insert into rebind_query failed, err '%v'", err) } if err := session.Query("INSERT INTO rebind_query (id, value) VALUES (?, ?)", 24, "w00t").Exec(); err != nil { t.Fatalf("insert into rebind_query failed, err '%v'", err) } q := session.Query("SELECT value FROM rebind_query WHERE ID = ?") q.Bind(23) iter := q.Iter() var value string for iter.Scan(&value) { } if value != "quux" { t.Fatalf("expected %v but got %v", "quux", value) } q.Bind(24) iter = q.Iter() for iter.Scan(&value) { } if value != "w00t" { t.Fatalf("expected %v but got %v", "w00t", value) } } // TestStaticQueryInfo makes sure that the application can manually bind query parameters using the simplest possible static binding strategy func TestStaticQueryInfo(t *testing.T) { session := createSession(t) defer session.Close() if err := createTable(session, "CREATE TABLE gocql_test.static_query_info (id int, value text, PRIMARY KEY (id))"); err != nil { t.Fatalf("failed to create table with error '%v'", err) } if err := session.Query("INSERT INTO static_query_info (id, value) VALUES (?, ?)", 113, "foo").Exec(); err != nil { t.Fatalf("insert into static_query_info failed, err '%v'", err) } autobinder := func(q *QueryInfo) ([]interface{}, error) { values := make([]interface{}, 1) values[0] = 113 return values, nil } qry := session.Bind("SELECT id, value FROM static_query_info WHERE id = ?", autobinder) if 
err := qry.Exec(); err != nil { t.Fatalf("expose query info failed, error '%v'", err) } iter := qry.Iter() var id int var value string iter.Scan(&id, &value) if err := iter.Close(); err != nil { t.Fatalf("query with exposed info failed, err '%v'", err) } if value != "foo" { t.Fatalf("Expected value %s, but got %s", "foo", value) } } type ClusteredKeyValue struct { Id int Cluster int Value string } func (kv *ClusteredKeyValue) Bind(q *QueryInfo) ([]interface{}, error) { values := make([]interface{}, len(q.Args)) for i, info := range q.Args { fieldName := upcaseInitial(info.Name) value := reflect.ValueOf(kv) field := reflect.Indirect(value).FieldByName(fieldName) values[i] = field.Addr().Interface() } return values, nil } func upcaseInitial(str string) string { for i, v := range str { return string(unicode.ToUpper(v)) + str[i+1:] } return "" } // TestBoundQueryInfo makes sure that the application can manually bind query parameters using the query meta data supplied at runtime func TestBoundQueryInfo(t *testing.T) { session := createSession(t) defer session.Close() if err := createTable(session, "CREATE TABLE gocql_test.clustered_query_info (id int, cluster int, value text, PRIMARY KEY (id, cluster))"); err != nil { t.Fatalf("failed to create table with error '%v'", err) } write := &ClusteredKeyValue{Id: 200, Cluster: 300, Value: "baz"} insert := session.Bind("INSERT INTO clustered_query_info (id, cluster, value) VALUES (?, ?,?)", write.Bind) if err := insert.Exec(); err != nil { t.Fatalf("insert into clustered_query_info failed, err '%v'", err) } read := &ClusteredKeyValue{Id: 200, Cluster: 300} qry := session.Bind("SELECT id, cluster, value FROM clustered_query_info WHERE id = ? 
and cluster = ?", read.Bind) iter := qry.Iter() var id, cluster int var value string iter.Scan(&id, &cluster, &value) if err := iter.Close(); err != nil { t.Fatalf("query with clustered_query_info info failed, err '%v'", err) } if value != "baz" { t.Fatalf("Expected value %s, but got %s", "baz", value) } } // TestBatchQueryInfo makes sure that the application can manually bind query parameters when executing in a batch func TestBatchQueryInfo(t *testing.T) { session := createSession(t) defer session.Close() if session.cfg.ProtoVersion == 1 { t.Skip("atomic batches not supported. Please use Cassandra >= 2.0") } if err := createTable(session, "CREATE TABLE gocql_test.batch_query_info (id int, cluster int, value text, PRIMARY KEY (id, cluster))"); err != nil { t.Fatalf("failed to create table with error '%v'", err) } write := func(q *QueryInfo) ([]interface{}, error) { values := make([]interface{}, 3) values[0] = 4000 values[1] = 5000 values[2] = "bar" return values, nil } batch := session.NewBatch(LoggedBatch) batch.Bind("INSERT INTO batch_query_info (id, cluster, value) VALUES (?, ?,?)", write) if err := session.ExecuteBatch(batch); err != nil { t.Fatalf("batch insert into batch_query_info failed, err '%v'", err) } read := func(q *QueryInfo) ([]interface{}, error) { values := make([]interface{}, 2) values[0] = 4000 values[1] = 5000 return values, nil } qry := session.Bind("SELECT id, cluster, value FROM batch_query_info WHERE id = ? 
and cluster = ?", read) iter := qry.Iter() var id, cluster int var value string iter.Scan(&id, &cluster, &value) if err := iter.Close(); err != nil { t.Fatalf("query with batch_query_info info failed, err '%v'", err) } if value != "bar" { t.Fatalf("Expected value %s, but got %s", "bar", value) } } func getRandomConn(t *testing.T, session *Session) *Conn { conn := session.getConn() if conn == nil { t.Fatal("unable to get a connection") } return conn } func injectInvalidPreparedStatement(t *testing.T, session *Session, table string) (string, *Conn) { if err := createTable(session, `CREATE TABLE gocql_test.`+table+` ( foo varchar, bar int, PRIMARY KEY (foo, bar) )`); err != nil { t.Fatal("create:", err) } stmt := "INSERT INTO " + table + " (foo, bar) VALUES (?, 7)" conn := getRandomConn(t, session) flight := new(inflightPrepare) key := session.stmtsLRU.keyFor(conn.host.HostID(), "", stmt) session.stmtsLRU.add(key, flight) flight.preparedStatment = &preparedStatment{ id: []byte{'f', 'o', 'o', 'b', 'a', 'r'}, request: preparedMetadata{ resultMetadata: resultMetadata{ colCount: 1, actualColCount: 1, columns: []ColumnInfo{ { Keyspace: "gocql_test", Table: table, Name: "foo", TypeInfo: NativeType{ typ: TypeVarchar, }, }, }, }, }, } return stmt, conn } func TestPrepare_MissingSchemaPrepare(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() s := createSession(t) conn := getRandomConn(t, s) defer s.Close() insertQry := s.Query("INSERT INTO invalidschemaprep (val) VALUES (?)", 5) if err := conn.executeQuery(ctx, insertQry).err; err == nil { t.Fatal("expected error, but got nil.") } if err := createTable(s, "CREATE TABLE gocql_test.invalidschemaprep (val int, PRIMARY KEY (val))"); err != nil { t.Fatal("create table:", err) } if err := conn.executeQuery(ctx, insertQry).err; err != nil { t.Fatal(err) // unconfigured columnfamily } } func TestPrepare_ReprepareStatement(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer 
cancel() session := createSession(t) defer session.Close() stmt, conn := injectInvalidPreparedStatement(t, session, "test_reprepare_statement") query := session.Query(stmt, "bar") if err := conn.executeQuery(ctx, query).Close(); err != nil { t.Fatalf("Failed to execute query for reprepare statement: %v", err) } } func TestPrepare_ReprepareBatch(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() session := createSession(t) defer session.Close() if session.cfg.ProtoVersion == 1 { t.Skip("atomic batches not supported. Please use Cassandra >= 2.0") } stmt, conn := injectInvalidPreparedStatement(t, session, "test_reprepare_statement_batch") batch := session.NewBatch(UnloggedBatch) batch.Query(stmt, "bar") if err := conn.executeBatch(ctx, batch).Close(); err != nil { t.Fatalf("Failed to execute query for reprepare statement: %v", err) } } func TestQueryInfo(t *testing.T) { session := createSession(t) defer session.Close() conn := getRandomConn(t, session) info, err := conn.prepareStatement(context.Background(), "SELECT release_version, host_id FROM system.local WHERE key = ?", nil) if err != nil { t.Fatalf("Failed to execute query for preparing statement: %v", err) } if x := len(info.request.columns); x != 1 { t.Fatalf("Was not expecting meta data for %d query arguments, but got %d\n", 1, x) } if session.cfg.ProtoVersion > 1 { if x := len(info.response.columns); x != 2 { t.Fatalf("Was not expecting meta data for %d result columns, but got %d\n", 2, x) } } } // TestPreparedCacheEviction will make sure that the cache size is maintained func TestPrepare_PreparedCacheEviction(t *testing.T) { const maxPrepared = 4 clusterHosts := getClusterHosts() host := clusterHosts[0] cluster := createCluster() cluster.MaxPreparedStmts = maxPrepared cluster.Events.DisableSchemaEvents = true cluster.Hosts = []string{host} cluster.HostFilter = WhiteListHostFilter(host) session := createSessionFromCluster(cluster, t) defer session.Close() if err := 
createTable(session, "CREATE TABLE gocql_test.prepcachetest (id int,mod int,PRIMARY KEY (id))"); err != nil { t.Fatalf("failed to create table with error '%v'", err) } // clear the cache session.stmtsLRU.clear() //Fill the table for i := 0; i < 2; i++ { if err := session.Query("INSERT INTO prepcachetest (id,mod) VALUES (?, ?)", i, 10000%(i+1)).Exec(); err != nil { t.Fatalf("insert into prepcachetest failed, err '%v'", err) } } //Populate the prepared statement cache with select statements var id, mod int for i := 0; i < 2; i++ { err := session.Query("SELECT id,mod FROM prepcachetest WHERE id = "+strconv.FormatInt(int64(i), 10)).Scan(&id, &mod) if err != nil { t.Fatalf("select from prepcachetest failed, error '%v'", err) } } //generate an update statement to test they are prepared err := session.Query("UPDATE prepcachetest SET mod = ? WHERE id = ?", 1, 11).Exec() if err != nil { t.Fatalf("update prepcachetest failed, error '%v'", err) } //generate a delete statement to test they are prepared err = session.Query("DELETE FROM prepcachetest WHERE id = ?", 1).Exec() if err != nil { t.Fatalf("delete from prepcachetest failed, error '%v'", err) } //generate an insert statement to test they are prepared err = session.Query("INSERT INTO prepcachetest (id,mod) VALUES (?, ?)", 3, 11).Exec() if err != nil { t.Fatalf("insert into prepcachetest failed, error '%v'", err) } session.stmtsLRU.mu.Lock() defer session.stmtsLRU.mu.Unlock() //Make sure the cache size is maintained if session.stmtsLRU.lru.Len() != session.stmtsLRU.lru.MaxEntries { t.Fatalf("expected cache size of %v, got %v", session.stmtsLRU.lru.MaxEntries, session.stmtsLRU.lru.Len()) } // Walk through all the configured hosts and test cache retention and eviction for _, host := range session.ring.hosts { _, ok := session.stmtsLRU.lru.Get(session.stmtsLRU.keyFor(host.HostID(), session.cfg.Keyspace, "SELECT id,mod FROM prepcachetest WHERE id = 0")) if ok { t.Errorf("expected first select to be purged but was in cache for 
host=%q", host) } _, ok = session.stmtsLRU.lru.Get(session.stmtsLRU.keyFor(host.HostID(), session.cfg.Keyspace, "SELECT id,mod FROM prepcachetest WHERE id = 1")) if !ok { t.Errorf("exepected second select to be in cache for host=%q", host) } _, ok = session.stmtsLRU.lru.Get(session.stmtsLRU.keyFor(host.HostID(), session.cfg.Keyspace, "INSERT INTO prepcachetest (id,mod) VALUES (?, ?)")) if !ok { t.Errorf("expected insert to be in cache for host=%q", host) } _, ok = session.stmtsLRU.lru.Get(session.stmtsLRU.keyFor(host.HostID(), session.cfg.Keyspace, "UPDATE prepcachetest SET mod = ? WHERE id = ?")) if !ok { t.Errorf("expected update to be in cached for host=%q", host) } _, ok = session.stmtsLRU.lru.Get(session.stmtsLRU.keyFor(host.HostID(), session.cfg.Keyspace, "DELETE FROM prepcachetest WHERE id = ?")) if !ok { t.Errorf("expected delete to be cached for host=%q", host) } } } func TestPrepare_PreparedCacheKey(t *testing.T) { session := createSession(t) defer session.Close() // create a second keyspace cluster2 := createCluster() createKeyspace(t, cluster2, "gocql_test2") cluster2.Keyspace = "gocql_test2" session2, err := cluster2.CreateSession() if err != nil { t.Fatal("create session:", err) } defer session2.Close() // both keyspaces have a table named "test_stmt_cache_key" if err := createTable(session, "CREATE TABLE gocql_test.test_stmt_cache_key (id varchar primary key, field varchar)"); err != nil { t.Fatal("create table:", err) } if err := createTable(session2, "CREATE TABLE gocql_test2.test_stmt_cache_key (id varchar primary key, field varchar)"); err != nil { t.Fatal("create table:", err) } // both tables have a single row with the same partition key but different column value if err = session.Query(`INSERT INTO test_stmt_cache_key (id, field) VALUES (?, ?)`, "key", "one").Exec(); err != nil { t.Fatal("insert:", err) } if err = session2.Query(`INSERT INTO test_stmt_cache_key (id, field) VALUES (?, ?)`, "key", "two").Exec(); err != nil { t.Fatal("insert:", 
err) } // should be able to see different values in each keyspace var value string if err = session.Query("SELECT field FROM test_stmt_cache_key WHERE id = ?", "key").Scan(&value); err != nil { t.Fatal("select:", err) } if value != "one" { t.Errorf("Expected one, got %s", value) } if err = session2.Query("SELECT field FROM test_stmt_cache_key WHERE id = ?", "key").Scan(&value); err != nil { t.Fatal("select:", err) } if value != "two" { t.Errorf("Expected two, got %s", value) } } // TestMarshalFloat64Ptr tests to see that a pointer to a float64 is marshalled correctly. func TestMarshalFloat64Ptr(t *testing.T) { session := createSession(t) defer session.Close() if err := createTable(session, "CREATE TABLE gocql_test.float_test (id double, test double, primary key (id))"); err != nil { t.Fatal("create table:", err) } testNum := float64(7500) if err := session.Query(`INSERT INTO float_test (id,test) VALUES (?,?)`, float64(7500.00), &testNum).Exec(); err != nil { t.Fatal("insert float64:", err) } } // TestMarshalInet tests to see that a pointer to a float64 is marshalled correctly. 
func TestMarshalInet(t *testing.T) { session := createSession(t) defer session.Close() if err := createTable(session, "CREATE TABLE gocql_test.inet_test (ip inet, name text, primary key (ip))"); err != nil { t.Fatal("create table:", err) } stringIp := "123.34.45.56" if err := session.Query(`INSERT INTO inet_test (ip,name) VALUES (?,?)`, stringIp, "Test IP 1").Exec(); err != nil { t.Fatal("insert string inet:", err) } var stringResult string if err := session.Query("SELECT ip FROM inet_test").Scan(&stringResult); err != nil { t.Fatalf("select for string from inet_test 1 failed: %v", err) } if stringResult != stringIp { t.Errorf("Expected %s, was %s", stringIp, stringResult) } var ipResult net.IP if err := session.Query("SELECT ip FROM inet_test").Scan(&ipResult); err != nil { t.Fatalf("select for net.IP from inet_test 1 failed: %v", err) } if ipResult.String() != stringIp { t.Errorf("Expected %s, was %s", stringIp, ipResult.String()) } if err := session.Query(`DELETE FROM inet_test WHERE ip = ?`, stringIp).Exec(); err != nil { t.Fatal("delete inet table:", err) } netIp := net.ParseIP("222.43.54.65") if err := session.Query(`INSERT INTO inet_test (ip,name) VALUES (?,?)`, netIp, "Test IP 2").Exec(); err != nil { t.Fatal("insert netIp inet:", err) } if err := session.Query("SELECT ip FROM inet_test").Scan(&stringResult); err != nil { t.Fatalf("select for string from inet_test 2 failed: %v", err) } if stringResult != netIp.String() { t.Errorf("Expected %s, was %s", netIp.String(), stringResult) } if err := session.Query("SELECT ip FROM inet_test").Scan(&ipResult); err != nil { t.Fatalf("select for net.IP from inet_test 2 failed: %v", err) } if ipResult.String() != netIp.String() { t.Errorf("Expected %s, was %s", netIp.String(), ipResult.String()) } } func TestVarint(t *testing.T) { session := createSession(t) defer session.Close() if err := createTable(session, "CREATE TABLE gocql_test.varint_test (id varchar, test varint, test2 varint, primary key (id))"); err != nil { 
t.Fatalf("failed to create table with error '%v'", err) } if err := session.Query(`INSERT INTO varint_test (id, test) VALUES (?, ?)`, "id", 0).Exec(); err != nil { t.Fatalf("insert varint: %v", err) } var result int if err := session.Query("SELECT test FROM varint_test").Scan(&result); err != nil { t.Fatalf("select from varint_test failed: %v", err) } if result != 0 { t.Errorf("Expected 0, was %d", result) } if err := session.Query(`INSERT INTO varint_test (id, test) VALUES (?, ?)`, "id", -1).Exec(); err != nil { t.Fatalf("insert varint: %v", err) } if err := session.Query("SELECT test FROM varint_test").Scan(&result); err != nil { t.Fatalf("select from varint_test failed: %v", err) } if result != -1 { t.Errorf("Expected -1, was %d", result) } if err := session.Query(`INSERT INTO varint_test (id, test) VALUES (?, ?)`, "id", nil).Exec(); err != nil { t.Fatalf("insert varint: %v", err) } if err := session.Query("SELECT test FROM varint_test").Scan(&result); err != nil { t.Fatalf("select from varint_test failed: %v", err) } if result != 0 { t.Errorf("Expected 0, was %d", result) } var nullableResult *int if err := session.Query("SELECT test FROM varint_test").Scan(&nullableResult); err != nil { t.Fatalf("select from varint_test failed: %v", err) } if nullableResult != nil { t.Errorf("Expected nil, was %d", nullableResult) } if err := session.Query(`INSERT INTO varint_test (id, test) VALUES (?, ?)`, "id", int64(math.MaxInt32)+1).Exec(); err != nil { t.Fatalf("insert varint: %v", err) } var result64 int64 if err := session.Query("SELECT test FROM varint_test").Scan(&result64); err != nil { t.Fatalf("select from varint_test failed: %v", err) } if result64 != int64(math.MaxInt32)+1 { t.Errorf("Expected %d, was %d", int64(math.MaxInt32)+1, result64) } biggie := new(big.Int) biggie.SetString("36893488147419103232", 10) // > 2**64 if err := session.Query(`INSERT INTO varint_test (id, test) VALUES (?, ?)`, "id", biggie).Exec(); err != nil { t.Fatalf("insert varint: %v", err) 
} resultBig := new(big.Int) if err := session.Query("SELECT test FROM varint_test").Scan(resultBig); err != nil { t.Fatalf("select from varint_test failed: %v", err) } if resultBig.String() != biggie.String() { t.Errorf("Expected %s, was %s", biggie.String(), resultBig.String()) } err := session.Query("SELECT test FROM varint_test").Scan(&result64) if err == nil || strings.Index(err.Error(), "out of range") == -1 { t.Errorf("expected out of range error since value is too big for int64") } // value not set in cassandra, leave bind variable empty resultBig = new(big.Int) if err := session.Query("SELECT test2 FROM varint_test").Scan(resultBig); err != nil { t.Fatalf("select from varint_test failed: %v", err) } if resultBig.Int64() != 0 { t.Errorf("Expected %s, was %s", biggie.String(), resultBig.String()) } // can use double pointer to explicitly detect value is not set in cassandra if err := session.Query("SELECT test2 FROM varint_test").Scan(&resultBig); err != nil { t.Fatalf("select from varint_test failed: %v", err) } if resultBig != nil { t.Errorf("Expected %v, was %v", nil, *resultBig) } } // TestQueryStats confirms that the stats are returning valid data. Accuracy may be questionable. func TestQueryStats(t *testing.T) { session := createSession(t) defer session.Close() qry := session.Query("SELECT * FROM system.peers") if err := qry.Exec(); err != nil { t.Fatalf("query failed. %v", err) } else { if qry.Attempts() < 1 { t.Fatal("expected at least 1 attempt, but got 0") } if qry.Latency() <= 0 { t.Fatalf("expected latency to be greater than 0, but got %v instead.", qry.Latency()) } } } // TestIterHosts confirms that host is added to Iter when the query succeeds. func TestIterHost(t *testing.T) { session := createSession(t) defer session.Close() iter := session.Query("SELECT * FROM system.peers").Iter() // check if Host method works if iter.Host() == nil { t.Error("No host in iter") } } // TestBatchStats confirms that the stats are returning valid data. 
Accuracy may be questionable. func TestBatchStats(t *testing.T) { session := createSession(t) defer session.Close() if session.cfg.ProtoVersion == 1 { t.Skip("atomic batches not supported. Please use Cassandra >= 2.0") } if err := createTable(session, "CREATE TABLE gocql_test.batchStats (id int, PRIMARY KEY (id))"); err != nil { t.Fatalf("failed to create table with error '%v'", err) } b := session.NewBatch(LoggedBatch) b.Query("INSERT INTO batchStats (id) VALUES (?)", 1) b.Query("INSERT INTO batchStats (id) VALUES (?)", 2) if err := session.ExecuteBatch(b); err != nil { t.Fatalf("query failed. %v", err) } else { if b.Attempts() < 1 { t.Fatal("expected at least 1 attempt, but got 0") } if b.Latency() <= 0 { t.Fatalf("expected latency to be greater than 0, but got %v instead.", b.Latency()) } } } type funcBatchObserver func(context.Context, ObservedBatch) func (f funcBatchObserver) ObserveBatch(ctx context.Context, o ObservedBatch) { f(ctx, o) } func TestBatchObserve(t *testing.T) { session := createSession(t) defer session.Close() if session.cfg.ProtoVersion == 1 { t.Skip("atomic batches not supported. 
Please use Cassandra >= 2.0") } if err := createTable(session, `CREATE TABLE gocql_test.batch_observe_table (id int, other int, PRIMARY KEY (id))`); err != nil { t.Fatal("create table:", err) } type observation struct { observedErr error observedKeyspace string observedStmts []string observedValues [][]interface{} } var observedBatch *observation batch := session.NewBatch(LoggedBatch) batch.Observer(funcBatchObserver(func(ctx context.Context, o ObservedBatch) { if observedBatch != nil { t.Fatal("batch observe called more than once") } observedBatch = &observation{ observedKeyspace: o.Keyspace, observedStmts: o.Statements, observedErr: o.Err, observedValues: o.Values, } })) for i := 0; i < 100; i++ { // hard coding 'i' into one of the values for better testing of observation batch.Query(fmt.Sprintf(`INSERT INTO batch_observe_table (id,other) VALUES (?,%d)`, i), i) } if err := session.ExecuteBatch(batch); err != nil { t.Fatal("execute batch:", err) } if observedBatch == nil { t.Fatal("batch observation has not been called") } if len(observedBatch.observedStmts) != 100 { t.Fatal("expecting 100 observed statements, got", len(observedBatch.observedStmts)) } if observedBatch.observedErr != nil { t.Fatal("not expecting to observe an error", observedBatch.observedErr) } if observedBatch.observedKeyspace != "gocql_test" { t.Fatalf("expecting keyspace 'gocql_test', got %q", observedBatch.observedKeyspace) } for i, stmt := range observedBatch.observedStmts { if stmt != fmt.Sprintf(`INSERT INTO batch_observe_table (id,other) VALUES (?,%d)`, i) { t.Fatal("unexpected query", stmt) } assertDeepEqual(t, "observed value", []interface{}{i}, observedBatch.observedValues[i]) } } // TestNilInQuery tests to see that a nil value passed to a query is handled by Cassandra // TODO validate the nil value by reading back the nil. Need to fix Unmarshalling. 
func TestNilInQuery(t *testing.T) { session := createSession(t) defer session.Close() if err := createTable(session, "CREATE TABLE gocql_test.testNilInsert (id int, count int, PRIMARY KEY (id))"); err != nil { t.Fatalf("failed to create table with error '%v'", err) } if err := session.Query("INSERT INTO testNilInsert (id,count) VALUES (?,?)", 1, nil).Exec(); err != nil { t.Fatalf("failed to insert with err: %v", err) } var id int if err := session.Query("SELECT id FROM testNilInsert").Scan(&id); err != nil { t.Fatalf("failed to select with err: %v", err) } else if id != 1 { t.Fatalf("expected id to be 1, got %v", id) } } // Don't initialize time.Time bind variable if cassandra timestamp column is empty func TestEmptyTimestamp(t *testing.T) { session := createSession(t) defer session.Close() if err := createTable(session, "CREATE TABLE gocql_test.test_empty_timestamp (id int, time timestamp, num int, PRIMARY KEY (id))"); err != nil { t.Fatalf("failed to create table with error '%v'", err) } if err := session.Query("INSERT INTO test_empty_timestamp (id, num) VALUES (?,?)", 1, 561).Exec(); err != nil { t.Fatalf("failed to insert with err: %v", err) } var timeVal time.Time if err := session.Query("SELECT time FROM test_empty_timestamp where id = ?", 1).Scan(&timeVal); err != nil { t.Fatalf("failed to select with err: %v", err) } if !timeVal.IsZero() { t.Errorf("time.Time bind variable should still be empty (was %s)", timeVal) } } // Integration test of just querying for data from the system.schema_keyspace table where the keyspace DOES exist. 
func TestGetKeyspaceMetadata(t *testing.T) { session := createSession(t) defer session.Close() keyspaceMetadata, err := getKeyspaceMetadata(session, "gocql_test") if err != nil { t.Fatalf("failed to query the keyspace metadata with err: %v", err) } if keyspaceMetadata == nil { t.Fatal("failed to query the keyspace metadata, nil returned") } if keyspaceMetadata.Name != "gocql_test" { t.Errorf("Expected keyspace name to be 'gocql' but was '%s'", keyspaceMetadata.Name) } if keyspaceMetadata.StrategyClass != "org.apache.cassandra.locator.SimpleStrategy" { t.Errorf("Expected replication strategy class to be 'org.apache.cassandra.locator.SimpleStrategy' but was '%s'", keyspaceMetadata.StrategyClass) } if keyspaceMetadata.StrategyOptions == nil { t.Error("Expected replication strategy options map but was nil") } rfStr, ok := keyspaceMetadata.StrategyOptions["replication_factor"] if !ok { t.Fatalf("Expected strategy option 'replication_factor' but was not found in %v", keyspaceMetadata.StrategyOptions) } rfInt, err := strconv.Atoi(rfStr.(string)) if err != nil { t.Fatalf("Error converting string to int with err: %v", err) } if rfInt != *flagRF { t.Errorf("Expected replication factor to be %d but was %d", *flagRF, rfInt) } } // Integration test of just querying for data from the system.schema_keyspace table where the keyspace DOES NOT exist. func TestGetKeyspaceMetadataFails(t *testing.T) { session := createSession(t) defer session.Close() _, err := getKeyspaceMetadata(session, "gocql_keyspace_does_not_exist") if err != ErrKeyspaceDoesNotExist || err == nil { t.Fatalf("Expected error of type ErrKeySpaceDoesNotExist. 
Instead, error was %v", err) } } // Integration test of just querying for data from the system.schema_columnfamilies table func TestGetTableMetadata(t *testing.T) { session := createSession(t) defer session.Close() if err := createTable(session, "CREATE TABLE gocql_test.test_table_metadata (first_id int, second_id int, third_id int, PRIMARY KEY (first_id, second_id))"); err != nil { t.Fatalf("failed to create table with error '%v'", err) } tables, err := getTableMetadata(session, "gocql_test") if err != nil { t.Fatalf("failed to query the table metadata with err: %v", err) } if tables == nil { t.Fatal("failed to query the table metadata, nil returned") } var testTable *TableMetadata // verify all tables have minimum expected data for i := range tables { table := &tables[i] if table.Name == "" { t.Errorf("Expected table name to be set, but it was empty: index=%d metadata=%+v", i, table) } if table.Keyspace != "gocql_test" { t.Errorf("Expected keyspace for '%s' table metadata to be 'gocql_test' but was '%s'", table.Name, table.Keyspace) } if session.cfg.ProtoVersion < 4 { // TODO(zariel): there has to be a better way to detect what metadata version // we are in, and a better way to structure the code so that it is abstracted away // from us here if table.KeyValidator == "" { t.Errorf("Expected key validator to be set for table %s", table.Name) } if table.Comparator == "" { t.Errorf("Expected comparator to be set for table %s", table.Name) } if table.DefaultValidator == "" { t.Errorf("Expected default validator to be set for table %s", table.Name) } } // these fields are not set until the metadata is compiled if table.PartitionKey != nil { t.Errorf("Did not expect partition key for table %s", table.Name) } if table.ClusteringColumns != nil { t.Errorf("Did not expect clustering columns for table %s", table.Name) } if table.Columns != nil { t.Errorf("Did not expect columns for table %s", table.Name) } // for the next part of the test after this loop, find the metadata 
for the test table if table.Name == "test_table_metadata" { testTable = table } } // verify actual values on the test tables if testTable == nil { t.Fatal("Expected table metadata for name 'test_table_metadata'") } if session.cfg.ProtoVersion == protoVersion1 { if testTable.KeyValidator != "org.apache.cassandra.db.marshal.Int32Type" { t.Errorf("Expected test_table_metadata key validator to be 'org.apache.cassandra.db.marshal.Int32Type' but was '%s'", testTable.KeyValidator) } if testTable.Comparator != "org.apache.cassandra.db.marshal.CompositeType(org.apache.cassandra.db.marshal.Int32Type,org.apache.cassandra.db.marshal.UTF8Type)" { t.Errorf("Expected test_table_metadata key validator to be 'org.apache.cassandra.db.marshal.CompositeType(org.apache.cassandra.db.marshal.Int32Type,org.apache.cassandra.db.marshal.UTF8Type)' but was '%s'", testTable.Comparator) } if testTable.DefaultValidator != "org.apache.cassandra.db.marshal.BytesType" { t.Errorf("Expected test_table_metadata key validator to be 'org.apache.cassandra.db.marshal.BytesType' but was '%s'", testTable.DefaultValidator) } expectedKeyAliases := []string{"first_id"} if !reflect.DeepEqual(testTable.KeyAliases, expectedKeyAliases) { t.Errorf("Expected key aliases %v but was %v", expectedKeyAliases, testTable.KeyAliases) } expectedColumnAliases := []string{"second_id"} if !reflect.DeepEqual(testTable.ColumnAliases, expectedColumnAliases) { t.Errorf("Expected key aliases %v but was %v", expectedColumnAliases, testTable.ColumnAliases) } } if testTable.ValueAlias != "" { t.Errorf("Expected value alias '' but was '%s'", testTable.ValueAlias) } } // Integration test of just querying for data from the system.schema_columns table func TestGetColumnMetadata(t *testing.T) { session := createSession(t) defer session.Close() if err := createTable(session, "CREATE TABLE gocql_test.test_column_metadata (first_id int, second_id int, third_id int, PRIMARY KEY (first_id, second_id))"); err != nil { t.Fatalf("failed to create 
table with error '%v'", err) } if err := session.Query("CREATE INDEX index_column_metadata ON test_column_metadata ( third_id )").Exec(); err != nil { t.Fatalf("failed to create index with err: %v", err) } columns, err := getColumnMetadata(session, "gocql_test") if err != nil { t.Fatalf("failed to query column metadata with err: %v", err) } if columns == nil { t.Fatal("failed to query column metadata, nil returned") } testColumns := map[string]*ColumnMetadata{} // verify actual values on the test columns for i := range columns { column := &columns[i] if column.Name == "" { t.Errorf("Expected column name to be set, but it was empty: index=%d metadata=%+v", i, column) } if column.Table == "" { t.Errorf("Expected column %s table name to be set, but it was empty", column.Name) } if column.Keyspace != "gocql_test" { t.Errorf("Expected column %s keyspace name to be 'gocql_test', but it was '%s'", column.Name, column.Keyspace) } if column.Kind == ColumnUnkownKind { t.Errorf("Expected column %s kind to be set, but it was empty", column.Name) } if session.cfg.ProtoVersion == 1 && column.Kind != ColumnRegular { t.Errorf("Expected column %s kind to be set to 'regular' for proto V1 but it was '%s'", column.Name, column.Kind) } if column.Validator == "" { t.Errorf("Expected column %s validator to be set, but it was empty", column.Name) } // find the test table columns for the next step after this loop if column.Table == "test_column_metadata" { testColumns[column.Name] = column } } if session.cfg.ProtoVersion == 1 { // V1 proto only returns "regular columns" if len(testColumns) != 1 { t.Errorf("Expected 1 test columns but there were %d", len(testColumns)) } thirdID, found := testColumns["third_id"] if !found { t.Fatalf("Expected to find column 'third_id' metadata but there was only %v", testColumns) } if thirdID.Kind != ColumnRegular { t.Errorf("Expected %s column kind to be '%s' but it was '%s'", thirdID.Name, ColumnRegular, thirdID.Kind) } if thirdID.Index.Name != 
"index_column_metadata" { t.Errorf("Expected %s column index name to be 'index_column_metadata' but it was '%s'", thirdID.Name, thirdID.Index.Name) } } else { if len(testColumns) != 3 { t.Errorf("Expected 3 test columns but there were %d", len(testColumns)) } firstID, found := testColumns["first_id"] if !found { t.Fatalf("Expected to find column 'first_id' metadata but there was only %v", testColumns) } secondID, found := testColumns["second_id"] if !found { t.Fatalf("Expected to find column 'second_id' metadata but there was only %v", testColumns) } thirdID, found := testColumns["third_id"] if !found { t.Fatalf("Expected to find column 'third_id' metadata but there was only %v", testColumns) } if firstID.Kind != ColumnPartitionKey { t.Errorf("Expected %s column kind to be '%s' but it was '%s'", firstID.Name, ColumnPartitionKey, firstID.Kind) } if secondID.Kind != ColumnClusteringKey { t.Errorf("Expected %s column kind to be '%s' but it was '%s'", secondID.Name, ColumnClusteringKey, secondID.Kind) } if thirdID.Kind != ColumnRegular { t.Errorf("Expected %s column kind to be '%s' but it was '%s'", thirdID.Name, ColumnRegular, thirdID.Kind) } if !session.useSystemSchema && thirdID.Index.Name != "index_column_metadata" { // TODO(zariel): update metadata to scan index from system_schema t.Errorf("Expected %s column index name to be 'index_column_metadata' but it was '%s'", thirdID.Name, thirdID.Index.Name) } } } func TestViewMetadata(t *testing.T) { session := createSession(t) defer session.Close() createViews(t, session) views, err := getViewsMetadata(session, "gocql_test") if err != nil { t.Fatalf("failed to query view metadata with err: %v", err) } if views == nil { t.Fatal("failed to query view metadata, nil returned") } if len(views) != 1 { t.Fatal("expected one view") } textType := TypeText if flagCassVersion.Before(3, 0, 0) { textType = TypeVarchar } expectedView := ViewMetadata{ Keyspace: "gocql_test", Name: "basicview", FieldNames: []string{"birthday", 
"nationality", "weight", "height"}, FieldTypes: []TypeInfo{ NativeType{typ: TypeTimestamp}, NativeType{typ: textType}, NativeType{typ: textType}, NativeType{typ: textType}, }, } if !reflect.DeepEqual(views[0], expectedView) { t.Fatalf("view is %+v, but expected %+v", views[0], expectedView) } } func TestMaterializedViewMetadata(t *testing.T) { if flagCassVersion.Before(3, 0, 0) { return } session := createSession(t) defer session.Close() createMaterializedViews(t, session) materializedViews, err := getMaterializedViewsMetadata(session, "gocql_test") if err != nil { t.Fatalf("failed to query view metadata with err: %v", err) } if materializedViews == nil { t.Fatal("failed to query view metadata, nil returned") } if len(materializedViews) != 2 { t.Fatal("expected two views") } expectedChunkLengthInKB := "16" expectedDCLocalReadRepairChance := float64(0) expectedSpeculativeRetry := "99p" if flagCassVersion.Before(4, 0, 0) { expectedChunkLengthInKB = "64" expectedDCLocalReadRepairChance = 0.1 expectedSpeculativeRetry = "99PERCENTILE" } expectedView1 := MaterializedViewMetadata{ Keyspace: "gocql_test", Name: "view_view", baseTableName: "view_table", BloomFilterFpChance: 0.01, Caching: map[string]string{"keys": "ALL", "rows_per_partition": "NONE"}, Comment: "", Compaction: map[string]string{"class": "org.apache.cassandra.db.compaction.SizeTieredCompactionStrategy", "max_threshold": "32", "min_threshold": "4"}, Compression: map[string]string{"chunk_length_in_kb": expectedChunkLengthInKB, "class": "org.apache.cassandra.io.compress.LZ4Compressor"}, CrcCheckChance: 1, DcLocalReadRepairChance: expectedDCLocalReadRepairChance, DefaultTimeToLive: 0, Extensions: map[string]string{}, GcGraceSeconds: 864000, IncludeAllColumns: false, MaxIndexInterval: 2048, MemtableFlushPeriodInMs: 0, MinIndexInterval: 128, ReadRepairChance: 0, SpeculativeRetry: expectedSpeculativeRetry, } expectedView2 := MaterializedViewMetadata{ Keyspace: "gocql_test", Name: "view_view2", baseTableName: 
"view_table2", BloomFilterFpChance: 0.01, Caching: map[string]string{"keys": "ALL", "rows_per_partition": "NONE"}, Comment: "", Compaction: map[string]string{"class": "org.apache.cassandra.db.compaction.SizeTieredCompactionStrategy", "max_threshold": "32", "min_threshold": "4"}, Compression: map[string]string{"chunk_length_in_kb": expectedChunkLengthInKB, "class": "org.apache.cassandra.io.compress.LZ4Compressor"}, CrcCheckChance: 1, DcLocalReadRepairChance: expectedDCLocalReadRepairChance, DefaultTimeToLive: 0, Extensions: map[string]string{}, GcGraceSeconds: 864000, IncludeAllColumns: false, MaxIndexInterval: 2048, MemtableFlushPeriodInMs: 0, MinIndexInterval: 128, ReadRepairChance: 0, SpeculativeRetry: expectedSpeculativeRetry, } expectedView1.BaseTableId = materializedViews[0].BaseTableId expectedView1.Id = materializedViews[0].Id if !reflect.DeepEqual(materializedViews[0], expectedView1) { t.Fatalf("materialized view is %+v, but expected %+v", materializedViews[0], expectedView1) } expectedView2.BaseTableId = materializedViews[1].BaseTableId expectedView2.Id = materializedViews[1].Id if !reflect.DeepEqual(materializedViews[1], expectedView2) { t.Fatalf("materialized view is %+v, but expected %+v", materializedViews[1], expectedView2) } } func TestAggregateMetadata(t *testing.T) { session := createSession(t) defer session.Close() createAggregate(t, session) aggregates, err := getAggregatesMetadata(session, "gocql_test") if err != nil { t.Fatalf("failed to query aggregate metadata with err: %v", err) } if aggregates == nil { t.Fatal("failed to query aggregate metadata, nil returned") } if len(aggregates) != 2 { t.Fatal("expected two aggregates") } expectedAggregrate := AggregateMetadata{ Keyspace: "gocql_test", Name: "average", ArgumentTypes: []TypeInfo{NativeType{typ: TypeInt}}, InitCond: "(0, 0)", ReturnType: NativeType{typ: TypeDouble}, StateType: TupleTypeInfo{ NativeType: NativeType{typ: TypeTuple}, Elems: []TypeInfo{ NativeType{typ: TypeInt}, 
NativeType{typ: TypeBigInt}, }, }, stateFunc: "avgstate", finalFunc: "avgfinal", } // In this case cassandra is returning a blob if flagCassVersion.Before(3, 0, 0) { expectedAggregrate.InitCond = string([]byte{0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0}) } if !reflect.DeepEqual(aggregates[0], expectedAggregrate) { t.Fatalf("aggregate 'average' is %+v, but expected %+v", aggregates[0], expectedAggregrate) } expectedAggregrate.Name = "average2" if !reflect.DeepEqual(aggregates[1], expectedAggregrate) { t.Fatalf("aggregate 'average2' is %+v, but expected %+v", aggregates[1], expectedAggregrate) } } func TestFunctionMetadata(t *testing.T) { session := createSession(t) defer session.Close() createFunctions(t, session) functions, err := getFunctionsMetadata(session, "gocql_test") if err != nil { t.Fatalf("failed to query function metadata with err: %v", err) } if functions == nil { t.Fatal("failed to query function metadata, nil returned") } if len(functions) != 2 { t.Fatal("expected two functions") } avgState := functions[1] avgFinal := functions[0] avgStateBody := "if (val !=null) {state.setInt(0, state.getInt(0)+1); state.setLong(1, state.getLong(1)+val.intValue());}return state;" expectedAvgState := FunctionMetadata{ Keyspace: "gocql_test", Name: "avgstate", ArgumentTypes: []TypeInfo{ TupleTypeInfo{ NativeType: NativeType{typ: TypeTuple}, Elems: []TypeInfo{ NativeType{typ: TypeInt}, NativeType{typ: TypeBigInt}, }, }, NativeType{typ: TypeInt}, }, ArgumentNames: []string{"state", "val"}, ReturnType: TupleTypeInfo{ NativeType: NativeType{typ: TypeTuple}, Elems: []TypeInfo{ NativeType{typ: TypeInt}, NativeType{typ: TypeBigInt}, }, }, CalledOnNullInput: true, Language: "java", Body: avgStateBody, } if !reflect.DeepEqual(avgState, expectedAvgState) { t.Fatalf("function is %+v, but expected %+v", avgState, expectedAvgState) } finalStateBody := "double r = 0; if (state.getInt(0) == 0) return null; r = state.getLong(1); r/= state.getInt(0); return 
Double.valueOf(r);" expectedAvgFinal := FunctionMetadata{ Keyspace: "gocql_test", Name: "avgfinal", ArgumentTypes: []TypeInfo{ TupleTypeInfo{ NativeType: NativeType{typ: TypeTuple}, Elems: []TypeInfo{ NativeType{typ: TypeInt}, NativeType{typ: TypeBigInt}, }, }, }, ArgumentNames: []string{"state"}, ReturnType: NativeType{typ: TypeDouble}, CalledOnNullInput: true, Language: "java", Body: finalStateBody, } if !reflect.DeepEqual(avgFinal, expectedAvgFinal) { t.Fatalf("function is %+v, but expected %+v", avgFinal, expectedAvgFinal) } } // Integration test of querying and composition the keyspace metadata func TestKeyspaceMetadata(t *testing.T) { session := createSession(t) defer session.Close() if err := createTable(session, "CREATE TABLE gocql_test.test_metadata (first_id int, second_id int, third_id int, PRIMARY KEY (first_id, second_id))"); err != nil { t.Fatalf("failed to create table with error '%v'", err) } createAggregate(t, session) createViews(t, session) createMaterializedViews(t, session) if err := session.Query("CREATE INDEX index_metadata ON test_metadata ( third_id )").Exec(); err != nil { t.Fatalf("failed to create index with err: %v", err) } keyspaceMetadata, err := session.KeyspaceMetadata("gocql_test") if err != nil { t.Fatalf("failed to query keyspace metadata with err: %v", err) } if keyspaceMetadata == nil { t.Fatal("expected the keyspace metadata to not be nil, but it was nil") } if keyspaceMetadata.Name != session.cfg.Keyspace { t.Fatalf("Expected the keyspace name to be %s but was %s", session.cfg.Keyspace, keyspaceMetadata.Name) } if len(keyspaceMetadata.Tables) == 0 { t.Errorf("Expected tables but there were none") } tableMetadata, found := keyspaceMetadata.Tables["test_metadata"] if !found { t.Fatalf("failed to find the test_metadata table metadata") } if len(tableMetadata.PartitionKey) != 1 { t.Errorf("expected partition key length of 1, but was %d", len(tableMetadata.PartitionKey)) } for i, column := range tableMetadata.PartitionKey { if 
column == nil { t.Errorf("partition key column metadata at index %d was nil", i) } } if tableMetadata.PartitionKey[0].Name != "first_id" { t.Errorf("Expected the first partition key column to be 'first_id' but was '%s'", tableMetadata.PartitionKey[0].Name) } if len(tableMetadata.ClusteringColumns) != 1 { t.Fatalf("expected clustering columns length of 1, but was %d", len(tableMetadata.ClusteringColumns)) } for i, column := range tableMetadata.ClusteringColumns { if column == nil { t.Fatalf("clustering column metadata at index %d was nil", i) } } if tableMetadata.ClusteringColumns[0].Name != "second_id" { t.Errorf("Expected the first clustering column to be 'second_id' but was '%s'", tableMetadata.ClusteringColumns[0].Name) } thirdColumn, found := tableMetadata.Columns["third_id"] if !found { t.Fatalf("Expected a column definition for 'third_id'") } if !session.useSystemSchema && thirdColumn.Index.Name != "index_metadata" { // TODO(zariel): scan index info from system_schema t.Errorf("Expected column index named 'index_metadata' but was '%s'", thirdColumn.Index.Name) } aggregate, found := keyspaceMetadata.Aggregates["average"] if !found { t.Fatal("failed to find the aggregate 'average' in metadata") } if aggregate.FinalFunc.Name != "avgfinal" { t.Fatalf("expected final function %s, but got %s", "avgFinal", aggregate.FinalFunc.Name) } if aggregate.StateFunc.Name != "avgstate" { t.Fatalf("expected state function %s, but got %s", "avgstate", aggregate.StateFunc.Name) } aggregate, found = keyspaceMetadata.Aggregates["average2"] if !found { t.Fatal("failed to find the aggregate 'average2' in metadata") } if aggregate.FinalFunc.Name != "avgfinal" { t.Fatalf("expected final function %s, but got %s", "avgFinal", aggregate.FinalFunc.Name) } if aggregate.StateFunc.Name != "avgstate" { t.Fatalf("expected state function %s, but got %s", "avgstate", aggregate.StateFunc.Name) } _, found = keyspaceMetadata.Views["basicview"] if !found { t.Fatal("failed to find the view in 
metadata") } _, found = keyspaceMetadata.UserTypes["basicview"] if !found { t.Fatal("failed to find the types in metadata") } textType := TypeText if flagCassVersion.Before(3, 0, 0) { textType = TypeVarchar } expectedType := UserTypeMetadata{ Keyspace: "gocql_test", Name: "basicview", FieldNames: []string{"birthday", "nationality", "weight", "height"}, FieldTypes: []TypeInfo{ NativeType{typ: TypeTimestamp}, NativeType{typ: textType}, NativeType{typ: textType}, NativeType{typ: textType}, }, } if !reflect.DeepEqual(*keyspaceMetadata.UserTypes["basicview"], expectedType) { t.Fatalf("type is %+v, but expected %+v", keyspaceMetadata.UserTypes["basicview"], expectedType) } if flagCassVersion.Major >= 3 { materializedView, found := keyspaceMetadata.MaterializedViews["view_view"] if !found { t.Fatal("failed to find materialized view view_view in metadata") } if materializedView.BaseTable.Name != "view_table" { t.Fatalf("expected name: %s, materialized view base table name: %s", "view_table", materializedView.BaseTable.Name) } materializedView, found = keyspaceMetadata.MaterializedViews["view_view2"] if !found { t.Fatal("failed to find materialized view view_view2 in metadata") } if materializedView.BaseTable.Name != "view_table2" { t.Fatalf("expected name: %s, materialized view base table name: %s", "view_table2", materializedView.BaseTable.Name) } } } // Integration test of the routing key calculation func TestRoutingKey(t *testing.T) { session := createSession(t) defer session.Close() if err := createTable(session, "CREATE TABLE gocql_test.test_single_routing_key (first_id int, second_id int, PRIMARY KEY (first_id, second_id))"); err != nil { t.Fatalf("failed to create table with error '%v'", err) } if err := createTable(session, "CREATE TABLE gocql_test.test_composite_routing_key (first_id int, second_id int, PRIMARY KEY ((first_id, second_id)))"); err != nil { t.Fatalf("failed to create table with error '%v'", err) } routingKeyInfo, err := 
session.routingKeyInfo(context.Background(), "SELECT * FROM test_single_routing_key WHERE second_id=? AND first_id=?") if err != nil { t.Fatalf("failed to get routing key info due to error: %v", err) } if routingKeyInfo == nil { t.Fatal("Expected routing key info, but was nil") } if len(routingKeyInfo.indexes) != 1 { t.Fatalf("Expected routing key indexes length to be 1 but was %d", len(routingKeyInfo.indexes)) } if routingKeyInfo.indexes[0] != 1 { t.Errorf("Expected routing key index[0] to be 1 but was %d", routingKeyInfo.indexes[0]) } if len(routingKeyInfo.types) != 1 { t.Fatalf("Expected routing key types length to be 1 but was %d", len(routingKeyInfo.types)) } if routingKeyInfo.types[0] == nil { t.Fatal("Expected routing key types[0] to be non-nil") } if routingKeyInfo.types[0].Type() != TypeInt { t.Fatalf("Expected routing key types[0].Type to be %v but was %v", TypeInt, routingKeyInfo.types[0].Type()) } // verify the cache is working routingKeyInfo, err = session.routingKeyInfo(context.Background(), "SELECT * FROM test_single_routing_key WHERE second_id=? 
AND first_id=?") if err != nil { t.Fatalf("failed to get routing key info due to error: %v", err) } if len(routingKeyInfo.indexes) != 1 { t.Fatalf("Expected routing key indexes length to be 1 but was %d", len(routingKeyInfo.indexes)) } if routingKeyInfo.indexes[0] != 1 { t.Errorf("Expected routing key index[0] to be 1 but was %d", routingKeyInfo.indexes[0]) } if len(routingKeyInfo.types) != 1 { t.Fatalf("Expected routing key types length to be 1 but was %d", len(routingKeyInfo.types)) } if routingKeyInfo.types[0] == nil { t.Fatal("Expected routing key types[0] to be non-nil") } if routingKeyInfo.types[0].Type() != TypeInt { t.Fatalf("Expected routing key types[0] to be %v but was %v", TypeInt, routingKeyInfo.types[0].Type()) } cacheSize := session.routingKeyInfoCache.lru.Len() if cacheSize != 1 { t.Errorf("Expected cache size to be 1 but was %d", cacheSize) } query := session.Query("SELECT * FROM test_single_routing_key WHERE second_id=? AND first_id=?", 1, 2) routingKey, err := query.GetRoutingKey() if err != nil { t.Fatalf("Failed to get routing key due to error: %v", err) } expectedRoutingKey := []byte{0, 0, 0, 2} if !reflect.DeepEqual(expectedRoutingKey, routingKey) { t.Errorf("Expected routing key %v but was %v", expectedRoutingKey, routingKey) } routingKeyInfo, err = session.routingKeyInfo(context.Background(), "SELECT * FROM test_composite_routing_key WHERE second_id=? 
AND first_id=?") if err != nil { t.Fatalf("failed to get routing key info due to error: %v", err) } if routingKeyInfo == nil { t.Fatal("Expected routing key info, but was nil") } if len(routingKeyInfo.indexes) != 2 { t.Fatalf("Expected routing key indexes length to be 2 but was %d", len(routingKeyInfo.indexes)) } if routingKeyInfo.indexes[0] != 1 { t.Errorf("Expected routing key index[0] to be 1 but was %d", routingKeyInfo.indexes[0]) } if routingKeyInfo.indexes[1] != 0 { t.Errorf("Expected routing key index[1] to be 0 but was %d", routingKeyInfo.indexes[1]) } if len(routingKeyInfo.types) != 2 { t.Fatalf("Expected routing key types length to be 1 but was %d", len(routingKeyInfo.types)) } if routingKeyInfo.types[0] == nil { t.Fatal("Expected routing key types[0] to be non-nil") } if routingKeyInfo.types[0].Type() != TypeInt { t.Fatalf("Expected routing key types[0] to be %v but was %v", TypeInt, routingKeyInfo.types[0].Type()) } if routingKeyInfo.types[1] == nil { t.Fatal("Expected routing key types[1] to be non-nil") } if routingKeyInfo.types[1].Type() != TypeInt { t.Fatalf("Expected routing key types[0] to be %v but was %v", TypeInt, routingKeyInfo.types[1].Type()) } query = session.Query("SELECT * FROM test_composite_routing_key WHERE second_id=? 
AND first_id=?", 1, 2) routingKey, err = query.GetRoutingKey() if err != nil { t.Fatalf("Failed to get routing key due to error: %v", err) } expectedRoutingKey = []byte{0, 4, 0, 0, 0, 2, 0, 0, 4, 0, 0, 0, 1, 0} if !reflect.DeepEqual(expectedRoutingKey, routingKey) { t.Errorf("Expected routing key %v but was %v", expectedRoutingKey, routingKey) } // verify the cache is working cacheSize = session.routingKeyInfoCache.lru.Len() if cacheSize != 2 { t.Errorf("Expected cache size to be 2 but was %d", cacheSize) } } // Integration test of the token-aware policy-based connection pool func TestTokenAwareConnPool(t *testing.T) { cluster := createCluster() cluster.PoolConfig.HostSelectionPolicy = TokenAwareHostPolicy(RoundRobinHostPolicy()) // force metadata query to page cluster.PageSize = 1 session := createSessionFromCluster(cluster, t) defer session.Close() expectedPoolSize := cluster.NumConns * len(session.ring.allHosts()) // wait for pool to fill for i := 0; i < 10; i++ { if session.pool.Size() == expectedPoolSize { break } time.Sleep(100 * time.Millisecond) } if expectedPoolSize != session.pool.Size() { t.Errorf("Expected pool size %d but was %d", expectedPoolSize, session.pool.Size()) } // add another cf so there are two pages when fetching table metadata from our keyspace if err := createTable(session, "CREATE TABLE gocql_test.test_token_aware_other_cf (id int, data text, PRIMARY KEY (id))"); err != nil { t.Fatalf("failed to create test_token_aware table with err: %v", err) } if err := createTable(session, "CREATE TABLE gocql_test.test_token_aware (id int, data text, PRIMARY KEY (id))"); err != nil { t.Fatalf("failed to create test_token_aware table with err: %v", err) } query := session.Query("INSERT INTO test_token_aware (id, data) VALUES (?,?)", 42, "8 * 6 =") if err := query.Exec(); err != nil { t.Fatalf("failed to insert with err: %v", err) } query = session.Query("SELECT data FROM test_token_aware where id = ?", 42).Consistency(One) var data string if err := 
query.Scan(&data); err != nil { t.Error(err) } // TODO add verification that the query went to the correct host } func TestNegativeStream(t *testing.T) { session := createSession(t) defer session.Close() conn := getRandomConn(t, session) const stream = -50 writer := frameWriterFunc(func(f *framer, streamID int) error { f.writeHeader(0, opOptions, stream) return f.finish() }) frame, err := conn.exec(context.Background(), writer, nil) if err == nil { t.Fatalf("expected to get an error on stream %d", stream) } else if frame != nil { t.Fatalf("expected to get nil frame got %+v", frame) } } func TestManualQueryPaging(t *testing.T) { const rowsToInsert = 5 session := createSession(t) defer session.Close() if err := createTable(session, "CREATE TABLE gocql_test.testManualPaging (id int, count int, PRIMARY KEY (id))"); err != nil { t.Fatal(err) } for i := 0; i < rowsToInsert; i++ { err := session.Query("INSERT INTO testManualPaging(id, count) VALUES(?, ?)", i, i*i).Exec() if err != nil { t.Fatal(err) } } // disable auto paging, 1 page per iteration query := session.Query("SELECT id, count FROM testManualPaging").PageState(nil).PageSize(2) var id, count, fetched int iter := query.Iter() // NOTE: this isnt very indicative of how it should be used, the idea is that // the page state is returned to some client who will send it back to manually // page through the results. 
for { for iter.Scan(&id, &count) { if count != (id * id) { t.Fatalf("got wrong value from iteration: got %d expected %d", count, id*id) } fetched++ } if len(iter.PageState()) > 0 { // more pages iter = query.PageState(iter.PageState()).Iter() } else { break } } if err := iter.Close(); err != nil { t.Fatal(err) } if fetched != rowsToInsert { t.Fatalf("expected to fetch %d rows got %d", rowsToInsert, fetched) } } func TestLexicalUUIDType(t *testing.T) { session := createSession(t) defer session.Close() if err := createTable(session, `CREATE TABLE gocql_test.test_lexical_uuid ( key varchar, column1 'org.apache.cassandra.db.marshal.LexicalUUIDType', value int, PRIMARY KEY (key, column1) )`); err != nil { t.Fatal("create:", err) } key := TimeUUID().String() column1 := TimeUUID() err := session.Query("INSERT INTO test_lexical_uuid(key, column1, value) VALUES(?, ?, ?)", key, column1, 55).Exec() if err != nil { t.Fatal(err) } var gotUUID UUID if err := session.Query("SELECT column1 from test_lexical_uuid where key = ? 
AND column1 = ?", key, column1).Scan(&gotUUID); err != nil { t.Fatal(err) } if gotUUID != column1 { t.Errorf("got %s, expected %s", gotUUID, column1) } } // Issue 475 func TestSessionBindRoutingKey(t *testing.T) { cluster := createCluster() cluster.PoolConfig.HostSelectionPolicy = TokenAwareHostPolicy(RoundRobinHostPolicy()) session := createSessionFromCluster(cluster, t) defer session.Close() if err := createTable(session, `CREATE TABLE gocql_test.test_bind_routing_key ( key varchar, value int, PRIMARY KEY (key) )`); err != nil { t.Fatal(err) } const ( key = "routing-key" value = 5 ) fn := func(info *QueryInfo) ([]interface{}, error) { return []interface{}{key, value}, nil } q := session.Bind("INSERT INTO test_bind_routing_key(key, value) VALUES(?, ?)", fn) if err := q.Exec(); err != nil { t.Fatal(err) } } func TestJSONSupport(t *testing.T) { session := createSession(t) defer session.Close() if session.cfg.ProtoVersion < 4 { t.Skip("skipping JSON support on proto < 4") } if err := createTable(session, `CREATE TABLE gocql_test.test_json ( id text PRIMARY KEY, age int, state text )`); err != nil { t.Fatal(err) } err := session.Query("INSERT INTO test_json JSON ?", `{"id": "user123", "age": 42, "state": "TX"}`).Exec() if err != nil { t.Fatal(err) } var ( id string age int state string ) err = session.Query("SELECT id, age, state FROM test_json WHERE id = ?", "user123").Scan(&id, &age, &state) if err != nil { t.Fatal(err) } if id != "user123" { t.Errorf("got id %q expected %q", id, "user123") } if age != 42 { t.Errorf("got age %d expected %d", age, 42) } if state != "TX" { t.Errorf("got state %q expected %q", state, "TX") } } func TestDiscoverViaProxy(t *testing.T) { // This (complicated) test tests that when the driver is given an initial host // that is infact a proxy it discovers the rest of the ring behind the proxy // and does not store the proxies address as a host in its connection pool. 
// See https://github.com/apache/cassandra-gocql-driver/issues/481 clusterHosts := getClusterHosts() proxy, err := net.Listen("tcp", "localhost:0") if err != nil { t.Fatalf("unable to create proxy listener: %v", err) } ctx, cancel := context.WithCancel(context.Background()) defer cancel() var ( mu sync.Mutex proxyConns []net.Conn closed bool ) go func() { cassandraAddr := JoinHostPort(clusterHosts[0], 9042) cassandra := func() (net.Conn, error) { return net.Dial("tcp", cassandraAddr) } proxyFn := func(errs chan error, from, to net.Conn) { _, err := io.Copy(to, from) if err != nil { errs <- err } } // handle dials cassandra and then proxies requests and reponsess. It waits // for both the read and write side of the TCP connection to close before // returning. handle := func(conn net.Conn) error { cass, err := cassandra() if err != nil { return err } defer cass.Close() errs := make(chan error, 2) go proxyFn(errs, conn, cass) go proxyFn(errs, cass, conn) select { case <-ctx.Done(): return ctx.Err() case err := <-errs: return err } } for { // proxy just accepts connections and then proxies them to cassandra, // it runs until it is closed. 
conn, err := proxy.Accept() if err != nil { mu.Lock() if !closed { t.Error(err) } mu.Unlock() return } mu.Lock() proxyConns = append(proxyConns, conn) mu.Unlock() go func(conn net.Conn) { defer conn.Close() if err := handle(conn); err != nil { mu.Lock() if !closed { t.Error(err) } mu.Unlock() } }(conn) } }() proxyAddr := proxy.Addr().String() cluster := createCluster() cluster.NumConns = 1 // initial host is the proxy address cluster.Hosts = []string{proxyAddr} session := createSessionFromCluster(cluster, t) defer session.Close() // we shouldnt need this but to be safe time.Sleep(1 * time.Second) session.pool.mu.RLock() for _, host := range clusterHosts { found := false for _, hi := range session.pool.hostConnPools { if hi.host.ConnectAddress().String() == host { found = true break } } if !found { t.Errorf("missing host in pool after discovery: %q", host) } } session.pool.mu.RUnlock() mu.Lock() closed = true if err := proxy.Close(); err != nil { t.Log(err) } for _, conn := range proxyConns { if err := conn.Close(); err != nil { t.Log(err) } } mu.Unlock() } func TestUnmarshallNestedTypes(t *testing.T) { session := createSession(t) defer session.Close() if session.cfg.ProtoVersion < protoVersion3 { t.Skip("can not have frozen types in cassandra < 2.1.3") } if err := createTable(session, `CREATE TABLE gocql_test.test_557 ( id text PRIMARY KEY, val list > > )`); err != nil { t.Fatal(err) } m := []map[string]string{ {"key1": "val1"}, {"key2": "val2"}, } const id = "key" err := session.Query("INSERT INTO test_557(id, val) VALUES(?, ?)", id, m).Exec() if err != nil { t.Fatal(err) } var data []map[string]string if err := session.Query("SELECT val FROM test_557 WHERE id = ?", id).Scan(&data); err != nil { t.Fatal(err) } if !reflect.DeepEqual(data, m) { t.Fatalf("%+#v != %+#v", data, m) } } func TestSchemaReset(t *testing.T) { if flagCassVersion.Major == 0 || flagCassVersion.Before(2, 1, 3) { t.Skipf("skipping TestSchemaReset due to CASSANDRA-7910 in Cassandra <2.1.3 
version=%v", flagCassVersion) } cluster := createCluster() cluster.NumConns = 1 session := createSessionFromCluster(cluster, t) defer session.Close() if err := createTable(session, `CREATE TABLE gocql_test.test_schema_reset ( id text PRIMARY KEY)`); err != nil { t.Fatal(err) } const key = "test" err := session.Query("INSERT INTO test_schema_reset(id) VALUES(?)", key).Exec() if err != nil { t.Fatal(err) } var id string err = session.Query("SELECT * FROM test_schema_reset WHERE id=?", key).Scan(&id) if err != nil { t.Fatal(err) } else if id != key { t.Fatalf("expected to get id=%q got=%q", key, id) } if err := createTable(session, `ALTER TABLE gocql_test.test_schema_reset ADD val text`); err != nil { t.Fatal(err) } const expVal = "test-val" err = session.Query("INSERT INTO test_schema_reset(id, val) VALUES(?, ?)", key, expVal).Exec() if err != nil { t.Fatal(err) } var val string err = session.Query("SELECT * FROM test_schema_reset WHERE id=?", key).Scan(&id, &val) if err != nil { t.Fatal(err) } if id != key { t.Errorf("expected to get id=%q got=%q", key, id) } if val != expVal { t.Errorf("expected to get val=%q got=%q", expVal, val) } } func TestCreateSession_DontSwallowError(t *testing.T) { t.Skip("This test is bad, and the resultant error from cassandra changes between versions") cluster := createCluster() cluster.ProtoVersion = 0x100 session, err := cluster.CreateSession() if err == nil { session.Close() t.Fatal("expected to get an error for unsupported protocol") } if flagCassVersion.Major < 3 { // TODO: we should get a distinct error type here which include the underlying // cassandra error about the protocol version, for now check this here. 
if !strings.Contains(err.Error(), "Invalid or unsupported protocol version") { t.Fatalf(`expcted to get error "unsupported protocol version" got: %q`, err) } } else { if !strings.Contains(err.Error(), "unsupported response version") { t.Fatalf(`expcted to get error "unsupported response version" got: %q`, err) } } } func TestControl_DiscoverProtocol(t *testing.T) { cluster := createCluster() cluster.ProtoVersion = 0 session, err := cluster.CreateSession() if err != nil { t.Fatal(err) } defer session.Close() if session.cfg.ProtoVersion == 0 { t.Fatal("did not discovery protocol") } } // TestUnsetCol verify unset column will not replace an existing column func TestUnsetCol(t *testing.T) { session := createSession(t) defer session.Close() if session.cfg.ProtoVersion < 4 { t.Skip("Unset Values are not supported in protocol < 4") } if err := createTable(session, "CREATE TABLE gocql_test.testUnsetInsert (id int, my_int int, my_text text, PRIMARY KEY (id))"); err != nil { t.Fatalf("failed to create table with error '%v'", err) } if err := session.Query("INSERT INTO testUnSetInsert (id,my_int,my_text) VALUES (?,?,?)", 1, 2, "3").Exec(); err != nil { t.Fatalf("failed to insert with err: %v", err) } if err := session.Query("INSERT INTO testUnSetInsert (id,my_int,my_text) VALUES (?,?,?)", 1, UnsetValue, UnsetValue).Exec(); err != nil { t.Fatalf("failed to insert with err: %v", err) } var id, mInt int var mText string if err := session.Query("SELECT id, my_int ,my_text FROM testUnsetInsert").Scan(&id, &mInt, &mText); err != nil { t.Fatalf("failed to select with err: %v", err) } else if id != 1 || mInt != 2 || mText != "3" { t.Fatalf("Expected results: 1, 2, \"3\", got %v, %v, %v", id, mInt, mText) } } // TestUnsetColBatch verify unset column will not replace a column in batch func TestUnsetColBatch(t *testing.T) { session := createSession(t) defer session.Close() if session.cfg.ProtoVersion < 4 { t.Skip("Unset Values are not supported in protocol < 4") } if err := 
createTable(session, "CREATE TABLE gocql_test.batchUnsetInsert (id int, my_int int, my_text text, PRIMARY KEY (id))"); err != nil { t.Fatalf("failed to create table with error '%v'", err) } b := session.NewBatch(LoggedBatch) b.Query("INSERT INTO gocql_test.batchUnsetInsert(id, my_int, my_text) VALUES (?,?,?)", 1, 1, UnsetValue) b.Query("INSERT INTO gocql_test.batchUnsetInsert(id, my_int, my_text) VALUES (?,?,?)", 1, UnsetValue, "") b.Query("INSERT INTO gocql_test.batchUnsetInsert(id, my_int, my_text) VALUES (?,?,?)", 2, 2, UnsetValue) if err := session.ExecuteBatch(b); err != nil { t.Fatalf("query failed. %v", err) } else { if b.Attempts() < 1 { t.Fatal("expected at least 1 attempt, but got 0") } if b.Latency() <= 0 { t.Fatalf("expected latency to be greater than 0, but got %v instead.", b.Latency()) } } var id, mInt, count int var mText string if err := session.Query("SELECT count(*) FROM gocql_test.batchUnsetInsert;").Scan(&count); err != nil { t.Fatalf("Failed to select with err: %v", err) } else if count != 2 { t.Fatalf("Expected Batch Insert count 2, got %v", count) } if err := session.Query("SELECT id, my_int ,my_text FROM gocql_test.batchUnsetInsert where id=1;").Scan(&id, &mInt, &mText); err != nil { t.Fatalf("failed to select with err: %v", err) } else if id != mInt { t.Fatalf("expected id, my_int to be 1, got %v and %v", id, mInt) } } func TestQuery_NamedValues(t *testing.T) { session := createSession(t) defer session.Close() if session.cfg.ProtoVersion < 3 { t.Skip("named Values are not supported in protocol < 3") } if err := createTable(session, "CREATE TABLE gocql_test.named_query(id int, value text, PRIMARY KEY (id))"); err != nil { t.Fatal(err) } err := session.Query("INSERT INTO gocql_test.named_query(id, value) VALUES(:id, :value)", NamedValue("id", 1), NamedValue("value", "i am a value")).Exec() if err != nil { t.Fatal(err) } var value string if err := session.Query("SELECT VALUE from gocql_test.named_query WHERE id = :id", NamedValue("id", 
1)).Scan(&value); err != nil { t.Fatal(err) } } cassandra-gocql-driver-1.7.0/cluster.go000066400000000000000000000326501467504044300200770ustar00rootroot00000000000000/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2012, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. */ package gocql import ( "context" "errors" "net" "time" ) // PoolConfig configures the connection pool used by the driver, it defaults to // using a round-robin host selection policy and a round-robin connection selection // policy for each host. type PoolConfig struct { // HostSelectionPolicy sets the policy for selecting which host to use for a // given query (default: RoundRobinHostPolicy()) // It is not supported to use a single HostSelectionPolicy in multiple sessions // (even if you close the old session before using in a new session). HostSelectionPolicy HostSelectionPolicy } func (p PoolConfig) buildPool(session *Session) *policyConnPool { return newPolicyConnPool(session) } // ClusterConfig is a struct to configure the default cluster implementation // of gocql. 
It has a variety of attributes that can be used to modify the // behavior to fit the most common use cases. Applications that require a // different setup must implement their own cluster. type ClusterConfig struct { // addresses for the initial connections. It is recommended to use the value set in // the Cassandra config for broadcast_address or listen_address, an IP address not // a domain name. This is because events from Cassandra will use the configured IP // address, which is used to index connected hosts. If the domain name specified // resolves to more than 1 IP address then the driver may connect multiple times to // the same host, and will not mark the node being down or up from events. Hosts []string // CQL version (default: 3.0.0) CQLVersion string // ProtoVersion sets the version of the native protocol to use, this will // enable features in the driver for specific protocol versions, generally this // should be set to a known version (2,3,4) for the cluster being connected to. // // If it is 0 or unset (the default) then the driver will attempt to discover the // highest supported protocol for the cluster. In clusters with nodes of different // versions the protocol selected is not defined (ie, it can be any of the supported in the cluster) ProtoVersion int // Timeout limits the time spent on the client side while executing a query. // Specifically, query or batch execution will return an error if the client does not receive a response // from the server within the Timeout period. // Timeout is also used to configure the read timeout on the underlying network connection. // Client Timeout should always be higher than the request timeouts configured on the server, // so that retries don't overload the server. // Timeout has a default value of 11 seconds, which is higher than default server timeout for most query types. // Timeout is not applied to requests during initial connection setup, see ConnectTimeout. 
Timeout time.Duration // ConnectTimeout limits the time spent during connection setup. // During initial connection setup, internal queries, AUTH requests will return an error if the client // does not receive a response within the ConnectTimeout period. // ConnectTimeout is applied to the connection setup queries independently. // ConnectTimeout also limits the duration of dialing a new TCP connection // in case there is no Dialer nor HostDialer configured. // ConnectTimeout has a default value of 11 seconds. ConnectTimeout time.Duration // WriteTimeout limits the time the driver waits to write a request to a network connection. // WriteTimeout should be lower than or equal to Timeout. // WriteTimeout defaults to the value of Timeout. WriteTimeout time.Duration // Port used when dialing. // Default: 9042 Port int // Initial keyspace. Optional. Keyspace string // Number of connections per host. // Default: 2 NumConns int // Default consistency level. // Default: Quorum Consistency Consistency // Compression algorithm. // Default: nil Compressor Compressor // Default: nil Authenticator Authenticator // An Authenticator factory. Can be used to create alternative authenticators. // Default: nil AuthProvider func(h *HostInfo) (Authenticator, error) // Default retry policy to use for queries. // Default: no retries. RetryPolicy RetryPolicy // ConvictionPolicy decides whether to mark host as down based on the error and host info. // Default: SimpleConvictionPolicy ConvictionPolicy ConvictionPolicy // Default reconnection policy to use for reconnecting before trying to mark host as down. ReconnectionPolicy ReconnectionPolicy // The keepalive period to use, enabled if > 0 (default: 0) // SocketKeepalive is used to set up the default dialer and is ignored if Dialer or HostDialer is provided. SocketKeepalive time.Duration // Maximum cache size for prepared statements globally for gocql. 
// Default: 1000 MaxPreparedStmts int // Maximum cache size for query info about statements for each session. // Default: 1000 MaxRoutingKeyInfo int // Default page size to use for created sessions. // Default: 5000 PageSize int // Consistency for the serial part of queries, values can be either SERIAL or LOCAL_SERIAL. // Default: unset SerialConsistency SerialConsistency // SslOpts configures TLS use when HostDialer is not set. // SslOpts is ignored if HostDialer is set. SslOpts *SslOptions // Sends a client side timestamp for all requests which overrides the timestamp at which it arrives at the server. // Default: true, only enabled for protocol 3 and above. DefaultTimestamp bool // PoolConfig configures the underlying connection pool, allowing the // configuration of host selection and connection selection policies. PoolConfig PoolConfig // If not zero, gocql attempt to reconnect known DOWN nodes in every ReconnectInterval. ReconnectInterval time.Duration // The maximum amount of time to wait for schema agreement in a cluster after // receiving a schema change frame. (default: 60s) MaxWaitSchemaAgreement time.Duration // HostFilter will filter all incoming events for host, any which don't pass // the filter will be ignored. If set will take precedence over any options set // via Discovery HostFilter HostFilter // AddressTranslator will translate addresses found on peer discovery and/or // node change events. AddressTranslator AddressTranslator // If IgnorePeerAddr is true and the address in system.peers does not match // the supplied host by either initial hosts or discovered via events then the // host will be replaced with the supplied address. // // For example if an event comes in with host=10.0.0.1 but when looking up that // address in system.local or system.peers returns 127.0.0.1, the peer will be // set to 10.0.0.1 which is what will be used to connect to. 
IgnorePeerAddr bool // If DisableInitialHostLookup then the driver will not attempt to get host info // from the system.peers table, this will mean that the driver will connect to // hosts supplied and will not attempt to lookup the hosts information, this will // mean that data_centre, rack and token information will not be available and as // such host filtering and token aware query routing will not be available. DisableInitialHostLookup bool // Configure events the driver will register for Events struct { // disable registering for status events (node up/down) DisableNodeStatusEvents bool // disable registering for topology events (node added/removed/moved) DisableTopologyEvents bool // disable registering for schema events (keyspace/table/function removed/created/updated) DisableSchemaEvents bool } // DisableSkipMetadata will override the internal result metadata cache so that the driver does not // send skip_metadata for queries, this means that the result will always contain // the metadata to parse the rows and will not reuse the metadata from the prepared // statement. // // See https://issues.apache.org/jira/browse/CASSANDRA-10786 DisableSkipMetadata bool // QueryObserver will set the provided query observer on all queries created from this session. // Use it to collect metrics / stats from queries by providing an implementation of QueryObserver. QueryObserver QueryObserver // BatchObserver will set the provided batch observer on all queries created from this session. // Use it to collect metrics / stats from batch queries by providing an implementation of BatchObserver. BatchObserver BatchObserver // ConnectObserver will set the provided connect observer on all queries // created from this session. ConnectObserver ConnectObserver // FrameHeaderObserver will set the provided frame header observer on all frames' headers created from this session. // Use it to collect metrics / stats from frames by providing an implementation of FrameHeaderObserver. 
FrameHeaderObserver FrameHeaderObserver // StreamObserver will be notified of stream state changes. // This can be used to track in-flight protocol requests and responses. StreamObserver StreamObserver // Default idempotence for queries DefaultIdempotence bool // The time to wait for frames before flushing the frames connection to Cassandra. // Can help reduce syscall overhead by making less calls to write. Set to 0 to // disable. // // (default: 200 microseconds) WriteCoalesceWaitTime time.Duration // Dialer will be used to establish all connections created for this Cluster. // If not provided, a default dialer configured with ConnectTimeout will be used. // Dialer is ignored if HostDialer is provided. Dialer Dialer // HostDialer will be used to establish all connections for this Cluster. // If not provided, Dialer will be used instead. HostDialer HostDialer // Logger for this ClusterConfig. // If not specified, defaults to the global gocql.Logger. Logger StdLogger // internal config for testing disableControlConn bool } type Dialer interface { DialContext(ctx context.Context, network, addr string) (net.Conn, error) } // NewCluster generates a new config for the default cluster implementation. // // The supplied hosts are used to initially connect to the cluster then the rest of // the ring will be automatically discovered. It is recommended to use the value set in // the Cassandra config for broadcast_address or listen_address, an IP address not // a domain name. This is because events from Cassandra will use the configured IP // address, which is used to index connected hosts. If the domain name specified // resolves to more than 1 IP address then the driver may connect multiple times to // the same host, and will not mark the node being down or up from events. 
func NewCluster(hosts ...string) *ClusterConfig { cfg := &ClusterConfig{ Hosts: hosts, CQLVersion: "3.0.0", Timeout: 11 * time.Second, ConnectTimeout: 11 * time.Second, Port: 9042, NumConns: 2, Consistency: Quorum, MaxPreparedStmts: defaultMaxPreparedStmts, MaxRoutingKeyInfo: 1000, PageSize: 5000, DefaultTimestamp: true, MaxWaitSchemaAgreement: 60 * time.Second, ReconnectInterval: 60 * time.Second, ConvictionPolicy: &SimpleConvictionPolicy{}, ReconnectionPolicy: &ConstantReconnectionPolicy{MaxRetries: 3, Interval: 1 * time.Second}, WriteCoalesceWaitTime: 200 * time.Microsecond, } return cfg } func (cfg *ClusterConfig) logger() StdLogger { if cfg.Logger == nil { return Logger } return cfg.Logger } // CreateSession initializes the cluster based on this config and returns a // session object that can be used to interact with the database. func (cfg *ClusterConfig) CreateSession() (*Session, error) { return NewSession(*cfg) } // translateAddressPort is a helper method that will use the given AddressTranslator // if defined, to translate the given address and port into a possibly new address // and port, If no AddressTranslator or if an error occurs, the given address and // port will be returned. 
func (cfg *ClusterConfig) translateAddressPort(addr net.IP, port int) (net.IP, int) { if cfg.AddressTranslator == nil || len(addr) == 0 { return addr, port } newAddr, newPort := cfg.AddressTranslator.Translate(addr, port) if gocqlDebug { cfg.logger().Printf("gocql: translating address '%v:%d' to '%v:%d'", addr, port, newAddr, newPort) } return newAddr, newPort } func (cfg *ClusterConfig) filterHost(host *HostInfo) bool { return !(cfg.HostFilter == nil || cfg.HostFilter.Accept(host)) } var ( ErrNoHosts = errors.New("no hosts provided") ErrNoConnectionsStarted = errors.New("no connections were made when creating the session") ErrHostQueryFailed = errors.New("unable to populate Hosts") ) cassandra-gocql-driver-1.7.0/cluster_test.go000066400000000000000000000072451467504044300211400ustar00rootroot00000000000000/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. 
*/ package gocql import ( "net" "reflect" "testing" "time" ) func TestNewCluster_Defaults(t *testing.T) { cfg := NewCluster() assertEqual(t, "cluster config cql version", "3.0.0", cfg.CQLVersion) assertEqual(t, "cluster config timeout", 11*time.Second, cfg.Timeout) assertEqual(t, "cluster config port", 9042, cfg.Port) assertEqual(t, "cluster config num-conns", 2, cfg.NumConns) assertEqual(t, "cluster config consistency", Quorum, cfg.Consistency) assertEqual(t, "cluster config max prepared statements", defaultMaxPreparedStmts, cfg.MaxPreparedStmts) assertEqual(t, "cluster config max routing key info", 1000, cfg.MaxRoutingKeyInfo) assertEqual(t, "cluster config page-size", 5000, cfg.PageSize) assertEqual(t, "cluster config default timestamp", true, cfg.DefaultTimestamp) assertEqual(t, "cluster config max wait schema agreement", 60*time.Second, cfg.MaxWaitSchemaAgreement) assertEqual(t, "cluster config reconnect interval", 60*time.Second, cfg.ReconnectInterval) assertTrue(t, "cluster config conviction policy", reflect.DeepEqual(&SimpleConvictionPolicy{}, cfg.ConvictionPolicy)) assertTrue(t, "cluster config reconnection policy", reflect.DeepEqual(&ConstantReconnectionPolicy{MaxRetries: 3, Interval: 1 * time.Second}, cfg.ReconnectionPolicy)) } func TestNewCluster_WithHosts(t *testing.T) { cfg := NewCluster("addr1", "addr2") assertEqual(t, "cluster config hosts length", 2, len(cfg.Hosts)) assertEqual(t, "cluster config host 0", "addr1", cfg.Hosts[0]) assertEqual(t, "cluster config host 1", "addr2", cfg.Hosts[1]) } func TestClusterConfig_translateAddressAndPort_NilTranslator(t *testing.T) { cfg := NewCluster() assertNil(t, "cluster config address translator", cfg.AddressTranslator) newAddr, newPort := cfg.translateAddressPort(net.ParseIP("10.0.0.1"), 1234) assertTrue(t, "same address as provided", net.ParseIP("10.0.0.1").Equal(newAddr)) assertEqual(t, "translated host and port", 1234, newPort) } func TestClusterConfig_translateAddressAndPort_EmptyAddr(t *testing.T) { cfg 
:= NewCluster() cfg.AddressTranslator = staticAddressTranslator(net.ParseIP("10.10.10.10"), 5432) newAddr, newPort := cfg.translateAddressPort(net.IP([]byte{}), 0) assertTrue(t, "translated address is still empty", len(newAddr) == 0) assertEqual(t, "translated port", 0, newPort) } func TestClusterConfig_translateAddressAndPort_Success(t *testing.T) { cfg := NewCluster() cfg.AddressTranslator = staticAddressTranslator(net.ParseIP("10.10.10.10"), 5432) newAddr, newPort := cfg.translateAddressPort(net.ParseIP("10.0.0.1"), 2345) assertTrue(t, "translated address", net.ParseIP("10.10.10.10").Equal(newAddr)) assertEqual(t, "translated port", 5432, newPort) } cassandra-gocql-driver-1.7.0/common_test.go000066400000000000000000000222521467504044300207420ustar00rootroot00000000000000/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. 
*/ package gocql import ( "flag" "fmt" "log" "net" "reflect" "strings" "sync" "testing" "time" ) var ( flagCluster = flag.String("cluster", "127.0.0.1", "a comma-separated list of host:port tuples") flagProto = flag.Int("proto", 0, "protcol version") flagCQL = flag.String("cql", "3.0.0", "CQL version") flagRF = flag.Int("rf", 1, "replication factor for test keyspace") clusterSize = flag.Int("clusterSize", 1, "the expected size of the cluster") flagRetry = flag.Int("retries", 5, "number of times to retry queries") flagAutoWait = flag.Duration("autowait", 1000*time.Millisecond, "time to wait for autodiscovery to fill the hosts poll") flagRunSslTest = flag.Bool("runssl", false, "Set to true to run ssl test") flagRunAuthTest = flag.Bool("runauth", false, "Set to true to run authentication test") flagCompressTest = flag.String("compressor", "", "compressor to use") flagTimeout = flag.Duration("gocql.timeout", 5*time.Second, "sets the connection `timeout` for all operations") flagCassVersion cassVersion ) func init() { flag.Var(&flagCassVersion, "gocql.cversion", "the cassandra version being tested against") log.SetFlags(log.Lshortfile | log.LstdFlags) } func getClusterHosts() []string { return strings.Split(*flagCluster, ",") } func addSslOptions(cluster *ClusterConfig) *ClusterConfig { if *flagRunSslTest { cluster.SslOpts = &SslOptions{ CertPath: "testdata/pki/gocql.crt", KeyPath: "testdata/pki/gocql.key", CaPath: "testdata/pki/ca.crt", EnableHostVerification: false, } } return cluster } var initOnce sync.Once func createTable(s *Session, table string) error { // lets just be really sure if err := s.control.awaitSchemaAgreement(); err != nil { log.Printf("error waiting for schema agreement pre create table=%q err=%v\n", table, err) return err } if err := s.Query(table).RetryPolicy(&SimpleRetryPolicy{}).Exec(); err != nil { log.Printf("error creating table table=%q err=%v\n", table, err) return err } if err := s.control.awaitSchemaAgreement(); err != nil { 
log.Printf("error waiting for schema agreement post create table=%q err=%v\n", table, err) return err } return nil } func createCluster(opts ...func(*ClusterConfig)) *ClusterConfig { clusterHosts := getClusterHosts() cluster := NewCluster(clusterHosts...) cluster.ProtoVersion = *flagProto cluster.CQLVersion = *flagCQL cluster.Timeout = *flagTimeout cluster.Consistency = Quorum cluster.MaxWaitSchemaAgreement = 2 * time.Minute // travis might be slow if *flagRetry > 0 { cluster.RetryPolicy = &SimpleRetryPolicy{NumRetries: *flagRetry} } switch *flagCompressTest { case "snappy": cluster.Compressor = &SnappyCompressor{} case "": default: panic("invalid compressor: " + *flagCompressTest) } cluster = addSslOptions(cluster) for _, opt := range opts { opt(cluster) } return cluster } func createKeyspace(tb testing.TB, cluster *ClusterConfig, keyspace string) { // TODO: tb.Helper() c := *cluster c.Keyspace = "system" c.Timeout = 30 * time.Second session, err := c.CreateSession() if err != nil { panic(err) } defer session.Close() err = createTable(session, `DROP KEYSPACE IF EXISTS `+keyspace) if err != nil { panic(fmt.Sprintf("unable to drop keyspace: %v", err)) } err = createTable(session, fmt.Sprintf(`CREATE KEYSPACE %s WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor' : %d }`, keyspace, *flagRF)) if err != nil { panic(fmt.Sprintf("unable to create keyspace: %v", err)) } } func createSessionFromCluster(cluster *ClusterConfig, tb testing.TB) *Session { // Drop and re-create the keyspace once. Different tests should use their own // individual tables, but can assume that the table does not exist before. 
initOnce.Do(func() { createKeyspace(tb, cluster, "gocql_test") }) cluster.Keyspace = "gocql_test" session, err := cluster.CreateSession() if err != nil { tb.Fatal("createSession:", err) } if err := session.control.awaitSchemaAgreement(); err != nil { tb.Fatal(err) } return session } func createSession(tb testing.TB, opts ...func(config *ClusterConfig)) *Session { cluster := createCluster(opts...) return createSessionFromCluster(cluster, tb) } func createViews(t *testing.T, session *Session) { if err := session.Query(` CREATE TYPE IF NOT EXISTS gocql_test.basicView ( birthday timestamp, nationality text, weight text, height text); `).Exec(); err != nil { t.Fatalf("failed to create view with err: %v", err) } } func createMaterializedViews(t *testing.T, session *Session) { if flagCassVersion.Before(3, 0, 0) { return } if err := session.Query(`CREATE TABLE IF NOT EXISTS gocql_test.view_table ( userid text, year int, month int, PRIMARY KEY (userid));`).Exec(); err != nil { t.Fatalf("failed to create materialized view with err: %v", err) } if err := session.Query(`CREATE TABLE IF NOT EXISTS gocql_test.view_table2 ( userid text, year int, month int, PRIMARY KEY (userid));`).Exec(); err != nil { t.Fatalf("failed to create materialized view with err: %v", err) } if err := session.Query(`CREATE MATERIALIZED VIEW IF NOT EXISTS gocql_test.view_view AS SELECT year, month, userid FROM gocql_test.view_table WHERE year IS NOT NULL AND month IS NOT NULL AND userid IS NOT NULL PRIMARY KEY (userid, year);`).Exec(); err != nil { t.Fatalf("failed to create materialized view with err: %v", err) } if err := session.Query(`CREATE MATERIALIZED VIEW IF NOT EXISTS gocql_test.view_view2 AS SELECT year, month, userid FROM gocql_test.view_table2 WHERE year IS NOT NULL AND month IS NOT NULL AND userid IS NOT NULL PRIMARY KEY (userid, year);`).Exec(); err != nil { t.Fatalf("failed to create materialized view with err: %v", err) } } func createFunctions(t *testing.T, session *Session) { if err := 
session.Query(` CREATE OR REPLACE FUNCTION gocql_test.avgState ( state tuple, val int ) CALLED ON NULL INPUT RETURNS tuple LANGUAGE java AS $$if (val !=null) {state.setInt(0, state.getInt(0)+1); state.setLong(1, state.getLong(1)+val.intValue());}return state;$$; `).Exec(); err != nil { t.Fatalf("failed to create function with err: %v", err) } if err := session.Query(` CREATE OR REPLACE FUNCTION gocql_test.avgFinal ( state tuple ) CALLED ON NULL INPUT RETURNS double LANGUAGE java AS $$double r = 0; if (state.getInt(0) == 0) return null; r = state.getLong(1); r/= state.getInt(0); return Double.valueOf(r);$$ `).Exec(); err != nil { t.Fatalf("failed to create function with err: %v", err) } } func createAggregate(t *testing.T, session *Session) { createFunctions(t, session) if err := session.Query(` CREATE OR REPLACE AGGREGATE gocql_test.average(int) SFUNC avgState STYPE tuple FINALFUNC avgFinal INITCOND (0,0); `).Exec(); err != nil { t.Fatalf("failed to create aggregate with err: %v", err) } if err := session.Query(` CREATE OR REPLACE AGGREGATE gocql_test.average2(int) SFUNC avgState STYPE tuple FINALFUNC avgFinal INITCOND (0,0); `).Exec(); err != nil { t.Fatalf("failed to create aggregate with err: %v", err) } } func staticAddressTranslator(newAddr net.IP, newPort int) AddressTranslator { return AddressTranslatorFunc(func(addr net.IP, port int) (net.IP, int) { return newAddr, newPort }) } func assertTrue(t *testing.T, description string, value bool) { t.Helper() if !value { t.Fatalf("expected %s to be true", description) } } func assertEqual(t *testing.T, description string, expected, actual interface{}) { t.Helper() if expected != actual { t.Fatalf("expected %s to be (%+v) but was (%+v) instead", description, expected, actual) } } func assertDeepEqual(t *testing.T, description string, expected, actual interface{}) { t.Helper() if !reflect.DeepEqual(expected, actual) { t.Fatalf("expected %s to be (%+v) but was (%+v) instead", description, expected, actual) } } func 
assertNil(t *testing.T, description string, actual interface{}) { t.Helper() if actual != nil { t.Fatalf("expected %s to be (nil) but was (%+v) instead", description, actual) } } cassandra-gocql-driver-1.7.0/compressor.go000066400000000000000000000032751467504044300206130ustar00rootroot00000000000000/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. */ package gocql import ( "github.com/golang/snappy" ) type Compressor interface { Name() string Encode(data []byte) ([]byte, error) Decode(data []byte) ([]byte, error) } // SnappyCompressor implements the Compressor interface and can be used to // compress incoming and outgoing frames. The snappy compression algorithm // aims for very high speeds and reasonable compression. 
type SnappyCompressor struct{} func (s SnappyCompressor) Name() string { return "snappy" } func (s SnappyCompressor) Encode(data []byte) ([]byte, error) { return snappy.Encode(nil, data), nil } func (s SnappyCompressor) Decode(data []byte) ([]byte, error) { return snappy.Decode(nil, data) } cassandra-gocql-driver-1.7.0/compressor_test.go000066400000000000000000000041111467504044300216400ustar00rootroot00000000000000/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. 
*/ package gocql import ( "bytes" "testing" "github.com/golang/snappy" ) func TestSnappyCompressor(t *testing.T) { c := SnappyCompressor{} if c.Name() != "snappy" { t.Fatalf("expected name to be 'snappy', got %v", c.Name()) } str := "My Test String" //Test Encoding expected := snappy.Encode(nil, []byte(str)) if res, err := c.Encode([]byte(str)); err != nil { t.Fatalf("failed to encode '%v' with error %v", str, err) } else if bytes.Compare(expected, res) != 0 { t.Fatal("failed to match the expected encoded value with the result encoded value.") } val, err := c.Encode([]byte(str)) if err != nil { t.Fatalf("failed to encode '%v' with error '%v'", str, err) } //Test Decoding if expected, err := snappy.Decode(nil, val); err != nil { t.Fatalf("failed to decode '%v' with error %v", val, err) } else if res, err := c.Decode(val); err != nil { t.Fatalf("failed to decode '%v' with error %v", val, err) } else if bytes.Compare(expected, res) != 0 { t.Fatal("failed to match the expected decoded value with the result decoded value.") } } cassandra-gocql-driver-1.7.0/conn.go000066400000000000000000001322351467504044300173530ustar00rootroot00000000000000/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2012, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. */ package gocql import ( "bufio" "context" "crypto/tls" "errors" "fmt" "io" "io/ioutil" "net" "strconv" "strings" "sync" "sync/atomic" "time" "github.com/gocql/gocql/internal/lru" "github.com/gocql/gocql/internal/streams" ) var ( defaultApprovedAuthenticators = []string{ "org.apache.cassandra.auth.PasswordAuthenticator", "com.instaclustr.cassandra.auth.SharedSecretAuthenticator", "com.datastax.bdp.cassandra.auth.DseAuthenticator", "io.aiven.cassandra.auth.AivenAuthenticator", "com.ericsson.bss.cassandra.ecaudit.auth.AuditPasswordAuthenticator", "com.amazon.helenus.auth.HelenusAuthenticator", "com.ericsson.bss.cassandra.ecaudit.auth.AuditAuthenticator", "com.scylladb.auth.SaslauthdAuthenticator", "com.scylladb.auth.TransitionalAuthenticator", "com.instaclustr.cassandra.auth.InstaclustrPasswordAuthenticator", } ) // approve the authenticator with the list of allowed authenticators or default list if approvedAuthenticators is empty. func approve(authenticator string, approvedAuthenticators []string) bool { if len(approvedAuthenticators) == 0 { approvedAuthenticators = defaultApprovedAuthenticators } for _, s := range approvedAuthenticators { if authenticator == s { return true } } return false } // JoinHostPort is a utility to return an address string that can be used // by `gocql.Conn` to form a connection with a host. 
func JoinHostPort(addr string, port int) string { addr = strings.TrimSpace(addr) if _, _, err := net.SplitHostPort(addr); err != nil { addr = net.JoinHostPort(addr, strconv.Itoa(port)) } return addr } type Authenticator interface { Challenge(req []byte) (resp []byte, auth Authenticator, err error) Success(data []byte) error } type PasswordAuthenticator struct { Username string Password string AllowedAuthenticators []string } func (p PasswordAuthenticator) Challenge(req []byte) ([]byte, Authenticator, error) { if !approve(string(req), p.AllowedAuthenticators) { return nil, nil, fmt.Errorf("unexpected authenticator %q", req) } resp := make([]byte, 2+len(p.Username)+len(p.Password)) resp[0] = 0 copy(resp[1:], p.Username) resp[len(p.Username)+1] = 0 copy(resp[2+len(p.Username):], p.Password) return resp, nil, nil } func (p PasswordAuthenticator) Success(data []byte) error { return nil } // SslOptions configures TLS use. // // Warning: Due to historical reasons, the SslOptions is insecure by default, so you need to set EnableHostVerification // to true if no Config is set. Most users should set SslOptions.Config to a *tls.Config. // SslOptions and Config.InsecureSkipVerify interact as follows: // // Config.InsecureSkipVerify | EnableHostVerification | Result // Config is nil | false | do not verify host // Config is nil | true | verify host // false | false | verify host // true | false | do not verify host // false | true | verify host // true | true | verify host type SslOptions struct { *tls.Config // CertPath and KeyPath are optional depending on server // config, but both fields must be omitted to avoid using a // client certificate CertPath string KeyPath string CaPath string //optional depending on server config // If you want to verify the hostname and server cert (like a wildcard for cass cluster) then you should turn this // on. // This option is basically the inverse of tls.Config.InsecureSkipVerify. 
// See InsecureSkipVerify in http://golang.org/pkg/crypto/tls/ for more info. // // See SslOptions documentation to see how EnableHostVerification interacts with the provided tls.Config. EnableHostVerification bool } type ConnConfig struct { ProtoVersion int CQLVersion string Timeout time.Duration WriteTimeout time.Duration ConnectTimeout time.Duration Dialer Dialer HostDialer HostDialer Compressor Compressor Authenticator Authenticator AuthProvider func(h *HostInfo) (Authenticator, error) Keepalive time.Duration Logger StdLogger tlsConfig *tls.Config disableCoalesce bool } func (c *ConnConfig) logger() StdLogger { if c.Logger == nil { return Logger } return c.Logger } type ConnErrorHandler interface { HandleError(conn *Conn, err error, closed bool) } type connErrorHandlerFn func(conn *Conn, err error, closed bool) func (fn connErrorHandlerFn) HandleError(conn *Conn, err error, closed bool) { fn(conn, err, closed) } // If not zero, how many timeouts we will allow to occur before the connection is closed // and restarted. This is to prevent a single query timeout from killing a connection // which may be serving more queries just fine. // Default is 0, should not be changed concurrently with queries. // // Deprecated. var TimeoutLimit int64 = 0 // Conn is a single connection to a Cassandra node. It can be used to execute // queries, but users are usually advised to use a more reliable, higher // level API. type Conn struct { conn net.Conn r *bufio.Reader w contextWriter timeout time.Duration writeTimeout time.Duration cfg *ConnConfig frameObserver FrameHeaderObserver streamObserver StreamObserver headerBuf [maxFrameHeaderSize]byte streams *streams.IDGenerator mu sync.Mutex // calls stores a map from stream ID to callReq. // This map is protected by mu. // calls should not be used when closed is true, calls is set to nil when closed=true. 
calls map[int]*callReq errorHandler ConnErrorHandler compressor Compressor auth Authenticator addr string version uint8 currentKeyspace string host *HostInfo isSchemaV2 bool session *Session // true if connection close process for the connection started. // closed is protected by mu. closed bool ctx context.Context cancel context.CancelFunc timeouts int64 logger StdLogger } // connect establishes a connection to a Cassandra node using session's connection config. func (s *Session) connect(ctx context.Context, host *HostInfo, errorHandler ConnErrorHandler) (*Conn, error) { return s.dial(ctx, host, s.connCfg, errorHandler) } // dial establishes a connection to a Cassandra node and notifies the session's connectObserver. func (s *Session) dial(ctx context.Context, host *HostInfo, connConfig *ConnConfig, errorHandler ConnErrorHandler) (*Conn, error) { var obs ObservedConnect if s.connectObserver != nil { obs.Host = host obs.Start = time.Now() } conn, err := s.dialWithoutObserver(ctx, host, connConfig, errorHandler) if s.connectObserver != nil { obs.End = time.Now() obs.Err = err s.connectObserver.ObserveConnect(obs) } return conn, err } // dialWithoutObserver establishes connection to a Cassandra node. // // dialWithoutObserver does not notify the connection observer, so you most probably want to call dial() instead. 
func (s *Session) dialWithoutObserver(ctx context.Context, host *HostInfo, cfg *ConnConfig, errorHandler ConnErrorHandler) (*Conn, error) { dialedHost, err := cfg.HostDialer.DialHost(ctx, host) if err != nil { return nil, err } writeTimeout := cfg.Timeout if cfg.WriteTimeout > 0 { writeTimeout = cfg.WriteTimeout } ctx, cancel := context.WithCancel(ctx) c := &Conn{ conn: dialedHost.Conn, r: bufio.NewReader(dialedHost.Conn), cfg: cfg, calls: make(map[int]*callReq), version: uint8(cfg.ProtoVersion), addr: dialedHost.Conn.RemoteAddr().String(), errorHandler: errorHandler, compressor: cfg.Compressor, session: s, streams: streams.New(cfg.ProtoVersion), host: host, isSchemaV2: true, // Try using "system.peers_v2" until proven otherwise frameObserver: s.frameObserver, w: &deadlineContextWriter{ w: dialedHost.Conn, timeout: writeTimeout, semaphore: make(chan struct{}, 1), quit: make(chan struct{}), }, ctx: ctx, cancel: cancel, logger: cfg.logger(), streamObserver: s.streamObserver, writeTimeout: writeTimeout, } if err := c.init(ctx, dialedHost); err != nil { cancel() c.Close() return nil, err } return c, nil } func (c *Conn) init(ctx context.Context, dialedHost *DialedHost) error { if c.session.cfg.AuthProvider != nil { var err error c.auth, err = c.cfg.AuthProvider(c.host) if err != nil { return err } } else { c.auth = c.cfg.Authenticator } startup := &startupCoordinator{ frameTicker: make(chan struct{}), conn: c, } c.timeout = c.cfg.ConnectTimeout if err := startup.setupConn(ctx); err != nil { return err } c.timeout = c.cfg.Timeout // dont coalesce startup frames if c.session.cfg.WriteCoalesceWaitTime > 0 && !c.cfg.disableCoalesce && !dialedHost.DisableCoalesce { c.w = newWriteCoalescer(c.conn, c.writeTimeout, c.session.cfg.WriteCoalesceWaitTime, ctx.Done()) } go c.serve(ctx) go c.heartBeat(ctx) return nil } func (c *Conn) Write(p []byte) (n int, err error) { return c.w.writeContext(context.Background(), p) } func (c *Conn) Read(p []byte) (n int, err error) { const 
maxAttempts = 5 for i := 0; i < maxAttempts; i++ { var nn int if c.timeout > 0 { c.conn.SetReadDeadline(time.Now().Add(c.timeout)) } nn, err = io.ReadFull(c.r, p[n:]) n += nn if err == nil { break } if verr, ok := err.(net.Error); !ok || !verr.Temporary() { break } } return } type startupCoordinator struct { conn *Conn frameTicker chan struct{} } func (s *startupCoordinator) setupConn(ctx context.Context) error { var cancel context.CancelFunc if s.conn.timeout > 0 { ctx, cancel = context.WithTimeout(ctx, s.conn.timeout) } else { ctx, cancel = context.WithCancel(ctx) } defer cancel() startupErr := make(chan error) go func() { for range s.frameTicker { err := s.conn.recv(ctx) if err != nil { select { case startupErr <- err: case <-ctx.Done(): } return } } }() go func() { defer close(s.frameTicker) err := s.options(ctx) select { case startupErr <- err: case <-ctx.Done(): } }() select { case err := <-startupErr: if err != nil { return err } case <-ctx.Done(): return errors.New("gocql: no response to connection startup within timeout") } return nil } func (s *startupCoordinator) write(ctx context.Context, frame frameBuilder) (frame, error) { select { case s.frameTicker <- struct{}{}: case <-ctx.Done(): return nil, ctx.Err() } framer, err := s.conn.exec(ctx, frame, nil) if err != nil { return nil, err } return framer.parseFrame() } func (s *startupCoordinator) options(ctx context.Context) error { frame, err := s.write(ctx, &writeOptionsFrame{}) if err != nil { return err } supported, ok := frame.(*supportedFrame) if !ok { return NewErrProtocol("Unknown type of response to startup frame: %T", frame) } return s.startup(ctx, supported.supported) } func (s *startupCoordinator) startup(ctx context.Context, supported map[string][]string) error { m := map[string]string{ "CQL_VERSION": s.conn.cfg.CQLVersion, "DRIVER_NAME": driverName, "DRIVER_VERSION": driverVersion, } if s.conn.compressor != nil { comp := supported["COMPRESSION"] name := s.conn.compressor.Name() for _, 
compressor := range comp { if compressor == name { m["COMPRESSION"] = compressor break } } if _, ok := m["COMPRESSION"]; !ok { s.conn.compressor = nil } } frame, err := s.write(ctx, &writeStartupFrame{opts: m}) if err != nil { return err } switch v := frame.(type) { case error: return v case *readyFrame: return nil case *authenticateFrame: return s.authenticateHandshake(ctx, v) default: return NewErrProtocol("Unknown type of response to startup frame: %s", v) } } func (s *startupCoordinator) authenticateHandshake(ctx context.Context, authFrame *authenticateFrame) error { if s.conn.auth == nil { return fmt.Errorf("authentication required (using %q)", authFrame.class) } resp, challenger, err := s.conn.auth.Challenge([]byte(authFrame.class)) if err != nil { return err } req := &writeAuthResponseFrame{data: resp} for { frame, err := s.write(ctx, req) if err != nil { return err } switch v := frame.(type) { case error: return v case *authSuccessFrame: if challenger != nil { return challenger.Success(v.data) } return nil case *authChallengeFrame: resp, challenger, err = challenger.Challenge(v.data) if err != nil { return err } req = &writeAuthResponseFrame{ data: resp, } default: return fmt.Errorf("unknown frame response during authentication: %v", v) } } } func (c *Conn) closeWithError(err error) { if c == nil { return } c.mu.Lock() if c.closed { c.mu.Unlock() return } c.closed = true var callsToClose map[int]*callReq // We should attempt to deliver the error back to the caller if it // exists. However, don't block c.mu while we are delivering the // error to outstanding calls. if err != nil { callsToClose = c.calls // It is safe to change c.calls to nil. Nobody should use it after c.closed is set to true. c.calls = nil } c.mu.Unlock() for _, req := range callsToClose { // we need to send the error to all waiting queries. 
select { case req.resp <- callResp{err: err}: case <-req.timeout: } if req.streamObserverContext != nil { req.streamObserverEndOnce.Do(func() { req.streamObserverContext.StreamAbandoned(ObservedStream{ Host: c.host, }) }) } } // if error was nil then unblock the quit channel c.cancel() cerr := c.close() if err != nil { c.errorHandler.HandleError(c, err, true) } else if cerr != nil { // TODO(zariel): is it a good idea to do this? c.errorHandler.HandleError(c, cerr, true) } } func (c *Conn) close() error { return c.conn.Close() } func (c *Conn) Close() { c.closeWithError(nil) } // Serve starts the stream multiplexer for this connection, which is required // to execute any queries. This method runs as long as the connection is // open and is therefore usually called in a separate goroutine. func (c *Conn) serve(ctx context.Context) { var err error for err == nil { err = c.recv(ctx) } c.closeWithError(err) } func (c *Conn) discardFrame(head frameHeader) error { _, err := io.CopyN(ioutil.Discard, c, int64(head.length)) if err != nil { return err } return nil } type protocolError struct { frame frame } func (p *protocolError) Error() string { if err, ok := p.frame.(error); ok { return err.Error() } return fmt.Sprintf("gocql: received unexpected frame on stream %d: %v", p.frame.Header().stream, p.frame) } func (c *Conn) heartBeat(ctx context.Context) { sleepTime := 1 * time.Second timer := time.NewTimer(sleepTime) defer timer.Stop() var failures int for { if failures > 5 { c.closeWithError(fmt.Errorf("gocql: heartbeat failed")) return } timer.Reset(sleepTime) select { case <-ctx.Done(): return case <-timer.C: } framer, err := c.exec(context.Background(), &writeOptionsFrame{}, nil) if err != nil { failures++ continue } resp, err := framer.parseFrame() if err != nil { // invalid frame failures++ continue } switch resp.(type) { case *supportedFrame: // Everything ok sleepTime = 5 * time.Second failures = 0 case error: // TODO: should we do something here? 
default: panic(fmt.Sprintf("gocql: unknown frame in response to options: %T", resp)) } } } func (c *Conn) recv(ctx context.Context) error { // not safe for concurrent reads // read a full header, ignore timeouts, as this is being ran in a loop // TODO: TCP level deadlines? or just query level deadlines? if c.timeout > 0 { c.conn.SetReadDeadline(time.Time{}) } headStartTime := time.Now() // were just reading headers over and over and copy bodies head, err := readHeader(c.r, c.headerBuf[:]) headEndTime := time.Now() if err != nil { return err } if c.frameObserver != nil { c.frameObserver.ObserveFrameHeader(context.Background(), ObservedFrameHeader{ Version: protoVersion(head.version), Flags: head.flags, Stream: int16(head.stream), Opcode: frameOp(head.op), Length: int32(head.length), Start: headStartTime, End: headEndTime, Host: c.host, }) } if head.stream > c.streams.NumStreams { return fmt.Errorf("gocql: frame header stream is beyond call expected bounds: %d", head.stream) } else if head.stream == -1 { // TODO: handle cassandra event frames, we shouldnt get any currently framer := newFramer(c.compressor, c.version) if err := framer.readFrame(c, &head); err != nil { return err } go c.session.handleEvent(framer) return nil } else if head.stream <= 0 { // reserved stream that we dont use, probably due to a protocol error // or a bug in Cassandra, this should be an error, parse it and return. 
framer := newFramer(c.compressor, c.version) if err := framer.readFrame(c, &head); err != nil { return err } frame, err := framer.parseFrame() if err != nil { return err } return &protocolError{ frame: frame, } } c.mu.Lock() if c.closed { c.mu.Unlock() return ErrConnectionClosed } call, ok := c.calls[head.stream] delete(c.calls, head.stream) c.mu.Unlock() if call == nil || !ok { c.logger.Printf("gocql: received response for stream which has no handler: header=%v\n", head) return c.discardFrame(head) } else if head.stream != call.streamID { panic(fmt.Sprintf("call has incorrect streamID: got %d expected %d", call.streamID, head.stream)) } framer := newFramer(c.compressor, c.version) err = framer.readFrame(c, &head) if err != nil { // only net errors should cause the connection to be closed. Though // cassandra returning corrupt frames will be returned here as well. if _, ok := err.(net.Error); ok { return err } } // we either, return a response to the caller, the caller timedout, or the // connection has closed. Either way we should never block indefinatly here select { case call.resp <- callResp{framer: framer, err: err}: case <-call.timeout: c.releaseStream(call) case <-ctx.Done(): } return nil } func (c *Conn) releaseStream(call *callReq) { if call.timer != nil { call.timer.Stop() } c.streams.Clear(call.streamID) if call.streamObserverContext != nil { call.streamObserverEndOnce.Do(func() { call.streamObserverContext.StreamFinished(ObservedStream{ Host: c.host, }) }) } } func (c *Conn) handleTimeout() { if TimeoutLimit > 0 && atomic.AddInt64(&c.timeouts, 1) > TimeoutLimit { c.closeWithError(ErrTooManyTimeouts) } } type callReq struct { // resp will receive the frame that was sent as a response to this stream. 
resp chan callResp timeout chan struct{} // indicates to recv() that a call has timed out streamID int // current stream in use timer *time.Timer // streamObserverContext is notified about events regarding this stream streamObserverContext StreamObserverContext // streamObserverEndOnce ensures that either StreamAbandoned or StreamFinished is called, // but not both. streamObserverEndOnce sync.Once } type callResp struct { // framer is the response frame. // May be nil if err is not nil. framer *framer // err is error encountered, if any. err error } // contextWriter is like io.Writer, but takes context as well. type contextWriter interface { // writeContext writes p to the connection. // // If ctx is canceled before we start writing p (e.g. during waiting while another write is currently in progress), // p is not written and ctx.Err() is returned. Context is ignored after we start writing p (i.e. we don't interrupt // blocked writes that are in progress) so that we always either write the full frame or not write it at all. // // It returns the number of bytes written from p (0 <= n <= len(p)) and any error that caused the write to stop // early. writeContext must return a non-nil error if it returns n < len(p). writeContext must not modify the // data in p, even temporarily. writeContext(ctx context.Context, p []byte) (n int, err error) } type deadlineWriter interface { SetWriteDeadline(time.Time) error io.Writer } type deadlineContextWriter struct { w deadlineWriter timeout time.Duration // semaphore protects critical section for SetWriteDeadline/Write. // It is a channel with capacity 1. semaphore chan struct{} // quit closed once the connection is closed. quit chan struct{} } // writeContext implements contextWriter. 
// writeContext implements contextWriter for deadlineContextWriter.
//
// Writes are serialized through c.semaphore (a channel with capacity 1) so
// that SetWriteDeadline and the Write itself form one critical section.
// If ctx is canceled or the connection quit channel closes before the
// semaphore is acquired, nothing is written and the corresponding error is
// returned.
func (c *deadlineContextWriter) writeContext(ctx context.Context, p []byte) (int, error) {
	select {
	case <-ctx.Done():
		return 0, ctx.Err()
	case <-c.quit:
		return 0, ErrConnectionClosed
	case c.semaphore <- struct{}{}:
		// acquired
	}

	defer func() {
		// release
		<-c.semaphore
	}()

	if c.timeout > 0 {
		// Set the per-write deadline inside the critical section so a
		// concurrent writer cannot clobber it before our Write completes.
		err := c.w.SetWriteDeadline(time.Now().Add(c.timeout))
		if err != nil {
			return 0, err
		}
	}
	return c.w.Write(p)
}

// newWriteCoalescer creates a writeCoalescer wrapping conn and starts its
// background flusher goroutine, which batches enqueued writes and flushes
// them every coalesceDuration. The goroutine exits when quit is closed.
func newWriteCoalescer(conn deadlineWriter, writeTimeout, coalesceDuration time.Duration, quit <-chan struct{}) *writeCoalescer {
	wc := &writeCoalescer{
		writeCh: make(chan writeRequest),
		c:       conn,
		quit:    quit,
		timeout: writeTimeout,
	}
	go wc.writeFlusher(coalesceDuration)
	return wc
}

// writeCoalescer batches multiple outgoing frames into a single
// net.Buffers write to reduce the number of syscalls on the underlying
// connection. Callers enqueue via writeContext; a dedicated flusher
// goroutine (writeFlusher) performs the actual writes.
type writeCoalescer struct {
	c deadlineWriter

	// NOTE(review): mu appears unused within this chunk — confirm against
	// the rest of the file before relying on it.
	mu sync.Mutex

	quit    <-chan struct{}
	writeCh chan writeRequest

	// timeout is the write deadline applied before each flush; 0 disables it.
	timeout time.Duration

	// test hooks, nil outside of tests.
	testEnqueuedHook func()
	testFlushedHook  func()
}

// writeRequest is a single pending write handed to the flusher goroutine.
type writeRequest struct {
	// resultChan is a channel (with buffer size 1) where to send results of the write.
	resultChan chan<- writeResult
	// data to write.
	data []byte
}

// writeResult carries the outcome of one coalesced write back to its caller.
type writeResult struct {
	n   int
	err error
}

// writeContext implements contextWriter.
//
// It enqueues p for the flusher goroutine and then blocks until the flusher
// reports the result. Note that once the request is enqueued, ctx
// cancellation no longer interrupts the wait — the write will either
// complete or fail as a unit.
func (w *writeCoalescer) writeContext(ctx context.Context, p []byte) (int, error) {
	resultChan := make(chan writeResult, 1)
	wr := writeRequest{
		resultChan: resultChan,
		data:       p,
	}

	select {
	case <-ctx.Done():
		return 0, ctx.Err()
	case <-w.quit:
		return 0, io.EOF // TODO: better error here?
	case w.writeCh <- wr:
		// enqueued for writing
	}

	if w.testEnqueuedHook != nil {
		w.testEnqueuedHook()
	}

	result := <-resultChan
	return result.n, result.err
}

// writeFlusher owns the coalescing timer and runs the flush loop until the
// quit channel closes. The timer is created already-running, so it is
// stopped and drained immediately; resetTimer rearms it on the first write
// of each batch.
func (w *writeCoalescer) writeFlusher(interval time.Duration) {
	timer := time.NewTimer(interval)
	defer timer.Stop()

	if !timer.Stop() {
		<-timer.C
	}

	w.writeFlusherImpl(timer.C, func() { timer.Reset(interval) })
}

// writeFlusherImpl is the flush loop body, factored out so tests can supply
// their own timer channel. It accumulates requests until the timer fires,
// then writes the whole batch at once; on quit it fails all pending
// requests with io.EOF.
func (w *writeCoalescer) writeFlusherImpl(timerC <-chan time.Time, resetTimer func()) {
	running := false

	var buffers net.Buffers
	var resultChans []chan<- writeResult

	for {
		select {
		case req := <-w.writeCh:
			buffers = append(buffers, req.data)
			resultChans = append(resultChans, req.resultChan)
			if !running {
				// Start timer on first write.
				resetTimer()
				running = true
			}
		case <-w.quit:
			result := writeResult{
				n:   0,
				err: io.EOF, // TODO: better error here?
			}
			// Unblock whoever was waiting.
			for _, resultChan := range resultChans {
				// resultChan has capacity 1, so it does not block.
				resultChan <- result
			}
			return
		case <-timerC:
			running = false
			w.flush(resultChans, buffers)
			buffers = nil
			resultChans = nil
			if w.testFlushedHook != nil {
				w.testFlushedHook()
			}
		}
	}
}

// flush writes the accumulated buffers to the connection with a single
// vectored write (net.Buffers.WriteTo) and reports a per-request result:
// requests whose bytes fall entirely before the written count succeed,
// the first partially/unwritten request receives the error and the count
// of its bytes that made it out, and all later requests get (0, err).
func (w *writeCoalescer) flush(resultChans []chan<- writeResult, buffers net.Buffers) {
	// Flush everything we have so far.
	if w.timeout > 0 {
		err := w.c.SetWriteDeadline(time.Now().Add(w.timeout))
		if err != nil {
			for i := range resultChans {
				resultChans[i] <- writeResult{
					n:   0,
					err: err,
				}
			}
			return
		}
	}

	// Copy buffers because WriteTo modifies buffers in-place.
	buffers2 := make(net.Buffers, len(buffers))
	copy(buffers2, buffers)
	n, err := buffers2.WriteTo(w.c)
	// Writes of bytes before n succeeded, writes of bytes starting from n failed with err.
	// Use n as remaining byte counter.
	for i := range buffers {
		if int64(len(buffers[i])) <= n {
			// this buffer was fully written.
			resultChans[i] <- writeResult{
				n:   len(buffers[i]),
				err: nil,
			}
			n -= int64(len(buffers[i]))
		} else {
			// this buffer was not (fully) written.
			resultChans[i] <- writeResult{
				n:   int(n),
				err: err,
			}
			n = 0
		}
	}
}

// addCall attempts to add a call to c.calls.
// It fails with error if the connection already started closing or if a call for the given stream
// already exists.
func (c *Conn) addCall(call *callReq) error {
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.closed {
		return ErrConnectionClosed
	}
	existingCall := c.calls[call.streamID]
	if existingCall != nil {
		return fmt.Errorf("attempting to use stream already in use: %d -> %d", call.streamID, existingCall.streamID)
	}
	c.calls[call.streamID] = call
	return nil
}

// exec sends a single request frame on a freshly-acquired stream and waits
// for the matching response frame, a timeout, context cancellation, or
// connection shutdown.
//
// Invariant maintained throughout: once the call has been registered via
// addCall, every exit path must either receive from call.resp or
// close(call.timeout) — otherwise closeWithError, which tries to deliver a
// close error to call.resp, can deadlock.
func (c *Conn) exec(ctx context.Context, req frameBuilder, tracer Tracer) (*framer, error) {
	if ctxErr := ctx.Err(); ctxErr != nil {
		return nil, ctxErr
	}

	// TODO: move tracer onto conn
	stream, ok := c.streams.GetStream()
	if !ok {
		return nil, ErrNoStreams
	}

	// resp is basically a waiting semaphore protecting the framer
	framer := newFramer(c.compressor, c.version)

	call := &callReq{
		timeout:  make(chan struct{}),
		streamID: stream,
		resp:     make(chan callResp),
	}

	if c.streamObserver != nil {
		call.streamObserverContext = c.streamObserver.StreamContext(ctx)
	}

	if err := c.addCall(call); err != nil {
		return nil, err
	}

	// After this point, we need to either read from call.resp or close(call.timeout)
	// since closeWithError can try to write a connection close error to call.resp.
	// If we don't close(call.timeout) or read from call.resp, closeWithError can deadlock.

	if tracer != nil {
		framer.trace()
	}

	if call.streamObserverContext != nil {
		call.streamObserverContext.StreamStarted(ObservedStream{
			Host: c.host,
		})
	}

	err := req.buildFrame(framer, stream)
	if err != nil {
		// closeWithError will block waiting for this stream to either receive a response
		// or for us to timeout.
		close(call.timeout)
		// We failed to serialize the frame into a buffer.
		// This should not affect the connection as we didn't write anything. We just free the current call.
		c.mu.Lock()
		if !c.closed {
			delete(c.calls, call.streamID)
		}
		c.mu.Unlock()
		// We need to release the stream after we remove the call from c.calls, otherwise the existingCall != nil
		// check above could fail.
		c.releaseStream(call)
		return nil, err
	}

	n, err := c.w.writeContext(ctx, framer.buf)
	if err != nil {
		// closeWithError will block waiting for this stream to either receive a response
		// or for us to timeout, close the timeout chan here. I'm not entirely sure
		// but we should not get a response after an error on the write side.
		close(call.timeout)
		if (errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded)) && n == 0 {
			// We have not started to write this frame.
			// Release the stream as no response can come from the server on the stream.
			c.mu.Lock()
			if !c.closed {
				delete(c.calls, call.streamID)
			}
			c.mu.Unlock()
			// We need to release the stream after we remove the call from c.calls, otherwise the existingCall != nil
			// check above could fail.
			c.releaseStream(call)
		} else {
			// I think this is the correct thing to do, I'm not entirely sure. It is not
			// ideal as readers might still get some data, but they probably won't.
			// Here we need to be careful as the stream is not available and if all
			// writes just timeout or fail then the pool might use this connection to
			// send a frame on, with all the streams used up and not returned.
			c.closeWithError(err)
		}
		return nil, err
	}

	var timeoutCh <-chan time.Time
	if c.timeout > 0 {
		if call.timer == nil {
			// Lazily create the timer in the stopped, drained state so it
			// can be Reset below without a stale tick pending.
			call.timer = time.NewTimer(0)
			<-call.timer.C
		} else {
			// Reusing a timer: stop it and drain any tick that already
			// fired, per the time.Timer.Reset contract.
			if !call.timer.Stop() {
				select {
				case <-call.timer.C:
				default:
				}
			}
		}
		call.timer.Reset(c.timeout)
		timeoutCh = call.timer.C
	}

	var ctxDone <-chan struct{}
	if ctx != nil {
		ctxDone = ctx.Done()
	}

	select {
	case resp := <-call.resp:
		close(call.timeout)
		if resp.err != nil {
			if !c.Closed() {
				// if the connection is closed then we can't release the stream,
				// this is because the request is still outstanding and we have
				// been handed another error from another stream which caused the
				// connection to close.
				c.releaseStream(call)
			}
			return nil, resp.err
		}
		// don't release the stream if detect a timeout as another request can reuse
		// that stream and get a response for the old request, which we have no
		// easy way of detecting.
		//
		// Ensure that the stream is not released if there are potentially outstanding
		// requests on the stream to prevent nil pointer dereferences in recv().
		defer c.releaseStream(call)

		if v := resp.framer.header.version.version(); v != c.version {
			return nil, NewErrProtocol("unexpected protocol version in response: got %d expected %d", v, c.version)
		}

		return resp.framer, nil
	case <-timeoutCh:
		close(call.timeout)
		c.handleTimeout()
		return nil, ErrTimeoutNoResponse
	case <-ctxDone:
		close(call.timeout)
		return nil, ctx.Err()
	case <-c.ctx.Done():
		close(call.timeout)
		return nil, ErrConnectionClosed
	}
}

// ObservedStream observes a single request/response stream.
type ObservedStream struct {
	// Host of the connection used to send the stream.
	Host *HostInfo
}

// StreamObserver is notified about request/response pairs.
// Streams are created for executing queries/batches or
// internal requests to the database and might live longer than
// execution of the query - the stream is still tracked until
// response arrives so that stream IDs are not reused.
type StreamObserver interface { // StreamContext is called before creating a new stream. // ctx is context passed to Session.Query / Session.Batch, // but might also be an internal context (for example // for internal requests that use control connection). // StreamContext might return nil if it is not interested // in the details of this stream. // StreamContext is called before the stream is created // and the returned StreamObserverContext might be discarded // without any methods called on the StreamObserverContext if // creation of the stream fails. // Note that if you don't need to track per-stream data, // you can always return the same StreamObserverContext. StreamContext(ctx context.Context) StreamObserverContext } // StreamObserverContext is notified about state of a stream. // A stream is started every time a request is written to the server // and is finished when a response is received. // It is abandoned when the underlying network connection is closed // before receiving a response. type StreamObserverContext interface { // StreamStarted is called when the stream is started. // This happens just before a request is written to the wire. StreamStarted(observedStream ObservedStream) // StreamAbandoned is called when we stop waiting for response. // This happens when the underlying network connection is closed. // StreamFinished won't be called if StreamAbandoned is. StreamAbandoned(observedStream ObservedStream) // StreamFinished is called when we receive a response for the stream. 
StreamFinished(observedStream ObservedStream) } type preparedStatment struct { id []byte request preparedMetadata response resultMetadata } type inflightPrepare struct { done chan struct{} err error preparedStatment *preparedStatment } func (c *Conn) prepareStatement(ctx context.Context, stmt string, tracer Tracer) (*preparedStatment, error) { stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), c.currentKeyspace, stmt) flight, ok := c.session.stmtsLRU.execIfMissing(stmtCacheKey, func(lru *lru.Cache) *inflightPrepare { flight := &inflightPrepare{ done: make(chan struct{}), } lru.Add(stmtCacheKey, flight) return flight }) if !ok { go func() { defer close(flight.done) prep := &writePrepareFrame{ statement: stmt, } if c.version > protoVersion4 { prep.keyspace = c.currentKeyspace } // we won the race to do the load, if our context is canceled we shouldnt // stop the load as other callers are waiting for it but this caller should get // their context cancelled error. framer, err := c.exec(c.ctx, prep, tracer) if err != nil { flight.err = err c.session.stmtsLRU.remove(stmtCacheKey) return } frame, err := framer.parseFrame() if err != nil { flight.err = err c.session.stmtsLRU.remove(stmtCacheKey) return } // TODO(zariel): tidy this up, simplify handling of frame parsing so its not duplicated // everytime we need to parse a frame. if len(framer.traceID) > 0 && tracer != nil { tracer.Trace(framer.traceID) } switch x := frame.(type) { case *resultPreparedFrame: flight.preparedStatment = &preparedStatment{ // defensively copy as we will recycle the underlying buffer after we // return. id: copyBytes(x.preparedID), // the type info's should _not_ have a reference to the framers read buffer, // therefore we can just copy them directly. 
request: x.reqMeta, response: x.respMeta, } case error: flight.err = x default: flight.err = NewErrProtocol("Unknown type in response to prepare frame: %s", x) } if flight.err != nil { c.session.stmtsLRU.remove(stmtCacheKey) } }() } select { case <-ctx.Done(): return nil, ctx.Err() case <-flight.done: return flight.preparedStatment, flight.err } } func marshalQueryValue(typ TypeInfo, value interface{}, dst *queryValues) error { if named, ok := value.(*namedValue); ok { dst.name = named.name value = named.value } if _, ok := value.(unsetColumn); !ok { val, err := Marshal(typ, value) if err != nil { return err } dst.value = val } else { dst.isUnset = true } return nil } func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { params := queryParams{ consistency: qry.cons, } // frame checks that it is not 0 params.serialConsistency = qry.serialCons params.defaultTimestamp = qry.defaultTimestamp params.defaultTimestampValue = qry.defaultTimestampValue if len(qry.pageState) > 0 { params.pagingState = qry.pageState } if qry.pageSize > 0 { params.pageSize = qry.pageSize } if c.version > protoVersion4 { params.keyspace = c.currentKeyspace } var ( frame frameBuilder info *preparedStatment ) if !qry.skipPrepare && qry.shouldPrepare() { // Prepare all DML queries. Other queries can not be prepared. 
var err error info, err = c.prepareStatement(ctx, qry.stmt, qry.trace) if err != nil { return &Iter{err: err} } values := qry.values if qry.binding != nil { values, err = qry.binding(&QueryInfo{ Id: info.id, Args: info.request.columns, Rval: info.response.columns, PKeyColumns: info.request.pkeyColumns, }) if err != nil { return &Iter{err: err} } } if len(values) != info.request.actualColCount { return &Iter{err: fmt.Errorf("gocql: expected %d values send got %d", info.request.actualColCount, len(values))} } params.values = make([]queryValues, len(values)) for i := 0; i < len(values); i++ { v := ¶ms.values[i] value := values[i] typ := info.request.columns[i].TypeInfo if err := marshalQueryValue(typ, value, v); err != nil { return &Iter{err: err} } } params.skipMeta = !(c.session.cfg.DisableSkipMetadata || qry.disableSkipMetadata) frame = &writeExecuteFrame{ preparedID: info.id, params: params, customPayload: qry.customPayload, } // Set "keyspace" and "table" property in the query if it is present in preparedMetadata qry.routingInfo.mu.Lock() qry.routingInfo.keyspace = info.request.keyspace qry.routingInfo.table = info.request.table qry.routingInfo.mu.Unlock() } else { frame = &writeQueryFrame{ statement: qry.stmt, params: params, customPayload: qry.customPayload, } } framer, err := c.exec(ctx, frame, qry.trace) if err != nil { return &Iter{err: err} } resp, err := framer.parseFrame() if err != nil { return &Iter{err: err} } if len(framer.traceID) > 0 && qry.trace != nil { qry.trace.Trace(framer.traceID) } switch x := resp.(type) { case *resultVoidFrame: return &Iter{framer: framer} case *resultRowsFrame: iter := &Iter{ meta: x.meta, framer: framer, numRows: x.numRows, } if params.skipMeta { if info != nil { iter.meta = info.response iter.meta.pagingState = copyBytes(x.meta.pagingState) } else { return &Iter{framer: framer, err: errors.New("gocql: did not receive metadata but prepared info is nil")} } } else { iter.meta = x.meta } if x.meta.morePages() && 
!qry.disableAutoPage { newQry := new(Query) *newQry = *qry newQry.pageState = copyBytes(x.meta.pagingState) newQry.metrics = &queryMetrics{m: make(map[string]*hostMetrics)} iter.next = &nextIter{ qry: newQry, pos: int((1 - qry.prefetch) * float64(x.numRows)), } if iter.next.pos < 1 { iter.next.pos = 1 } } return iter case *resultKeyspaceFrame: return &Iter{framer: framer} case *schemaChangeKeyspace, *schemaChangeTable, *schemaChangeFunction, *schemaChangeAggregate, *schemaChangeType: iter := &Iter{framer: framer} if err := c.awaitSchemaAgreement(ctx); err != nil { // TODO: should have this behind a flag c.logger.Println(err) } // dont return an error from this, might be a good idea to give a warning // though. The impact of this returning an error would be that the cluster // is not consistent with regards to its schema. return iter case *RequestErrUnprepared: stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), c.currentKeyspace, qry.stmt) c.session.stmtsLRU.evictPreparedID(stmtCacheKey, x.StatementId) return c.executeQuery(ctx, qry) case error: return &Iter{err: x, framer: framer} default: return &Iter{ err: NewErrProtocol("Unknown type in response to execute query (%T): %s", x, x), framer: framer, } } } func (c *Conn) Pick(qry *Query) *Conn { if c.Closed() { return nil } return c } func (c *Conn) Closed() bool { c.mu.Lock() defer c.mu.Unlock() return c.closed } func (c *Conn) Address() string { return c.addr } func (c *Conn) AvailableStreams() int { return c.streams.Available() } func (c *Conn) UseKeyspace(keyspace string) error { q := &writeQueryFrame{statement: `USE "` + keyspace + `"`} q.params.consistency = c.session.cons framer, err := c.exec(c.ctx, q, nil) if err != nil { return err } resp, err := framer.parseFrame() if err != nil { return err } switch x := resp.(type) { case *resultKeyspaceFrame: case error: return x default: return NewErrProtocol("unknown frame in response to USE: %v", x) } c.currentKeyspace = keyspace return nil } func (c *Conn) 
executeBatch(ctx context.Context, batch *Batch) *Iter { if c.version == protoVersion1 { return &Iter{err: ErrUnsupported} } n := len(batch.Entries) req := &writeBatchFrame{ typ: batch.Type, statements: make([]batchStatment, n), consistency: batch.Cons, serialConsistency: batch.serialCons, defaultTimestamp: batch.defaultTimestamp, defaultTimestampValue: batch.defaultTimestampValue, customPayload: batch.CustomPayload, } stmts := make(map[string]string, len(batch.Entries)) for i := 0; i < n; i++ { entry := &batch.Entries[i] b := &req.statements[i] if len(entry.Args) > 0 || entry.binding != nil { info, err := c.prepareStatement(batch.Context(), entry.Stmt, batch.trace) if err != nil { return &Iter{err: err} } var values []interface{} if entry.binding == nil { values = entry.Args } else { values, err = entry.binding(&QueryInfo{ Id: info.id, Args: info.request.columns, Rval: info.response.columns, PKeyColumns: info.request.pkeyColumns, }) if err != nil { return &Iter{err: err} } } if len(values) != info.request.actualColCount { return &Iter{err: fmt.Errorf("gocql: batch statement %d expected %d values send got %d", i, info.request.actualColCount, len(values))} } b.preparedID = info.id stmts[string(info.id)] = entry.Stmt b.values = make([]queryValues, info.request.actualColCount) for j := 0; j < info.request.actualColCount; j++ { v := &b.values[j] value := values[j] typ := info.request.columns[j].TypeInfo if err := marshalQueryValue(typ, value, v); err != nil { return &Iter{err: err} } } } else { b.statement = entry.Stmt } } framer, err := c.exec(batch.Context(), req, batch.trace) if err != nil { return &Iter{err: err} } resp, err := framer.parseFrame() if err != nil { return &Iter{err: err, framer: framer} } if len(framer.traceID) > 0 && batch.trace != nil { batch.trace.Trace(framer.traceID) } switch x := resp.(type) { case *resultVoidFrame: return &Iter{} case *RequestErrUnprepared: stmt, found := stmts[string(x.StatementId)] if found { key := 
c.session.stmtsLRU.keyFor(c.host.HostID(), c.currentKeyspace, stmt) c.session.stmtsLRU.evictPreparedID(key, x.StatementId) } return c.executeBatch(ctx, batch) case *resultRowsFrame: iter := &Iter{ meta: x.meta, framer: framer, numRows: x.numRows, } return iter case error: return &Iter{err: x, framer: framer} default: return &Iter{err: NewErrProtocol("Unknown type in response to batch statement: %s", x), framer: framer} } } func (c *Conn) query(ctx context.Context, statement string, values ...interface{}) (iter *Iter) { q := c.session.Query(statement, values...).Consistency(One).Trace(nil) q.skipPrepare = true q.disableSkipMetadata = true // we want to keep the query on this connection q.conn = c return c.executeQuery(ctx, q) } func (c *Conn) querySystemPeers(ctx context.Context, version cassVersion) *Iter { const ( peerSchema = "SELECT * FROM system.peers" peerV2Schemas = "SELECT * FROM system.peers_v2" ) c.mu.Lock() isSchemaV2 := c.isSchemaV2 c.mu.Unlock() if version.AtLeast(4, 0, 0) && isSchemaV2 { // Try "system.peers_v2" and fallback to "system.peers" if it's not found iter := c.query(ctx, peerV2Schemas) err := iter.checkErrAndNotFound() if err != nil { if errFrame, ok := err.(errorFrame); ok && errFrame.code == ErrCodeInvalid { // system.peers_v2 not found, try system.peers c.mu.Lock() c.isSchemaV2 = false c.mu.Unlock() return c.query(ctx, peerSchema) } else { return iter } } return iter } else { return c.query(ctx, peerSchema) } } func (c *Conn) querySystemLocal(ctx context.Context) *Iter { return c.query(ctx, "SELECT * FROM system.local WHERE key='local'") } func (c *Conn) awaitSchemaAgreement(ctx context.Context) (err error) { const localSchemas = "SELECT schema_version FROM system.local WHERE key='local'" var versions map[string]struct{} var schemaVersion string endDeadline := time.Now().Add(c.session.cfg.MaxWaitSchemaAgreement) for time.Now().Before(endDeadline) { iter := c.querySystemPeers(ctx, c.host.version) versions = make(map[string]struct{}) rows, 
err := iter.SliceMap() if err != nil { goto cont } for _, row := range rows { host, err := c.session.hostInfoFromMap(row, &HostInfo{connectAddress: c.host.ConnectAddress(), port: c.session.cfg.Port}) if err != nil { goto cont } if !isValidPeer(host) || host.schemaVersion == "" { c.logger.Printf("invalid peer or peer with empty schema_version: peer=%q", host) continue } versions[host.schemaVersion] = struct{}{} } if err = iter.Close(); err != nil { goto cont } iter = c.query(ctx, localSchemas) for iter.Scan(&schemaVersion) { versions[schemaVersion] = struct{}{} schemaVersion = "" } if err = iter.Close(); err != nil { goto cont } if len(versions) <= 1 { return nil } cont: select { case <-ctx.Done(): return ctx.Err() case <-time.After(200 * time.Millisecond): } } if err != nil { return err } schemas := make([]string, 0, len(versions)) for schema := range versions { schemas = append(schemas, schema) } // not exported return fmt.Errorf("gocql: cluster schema versions not consistent: %+v", schemas) } var ( ErrQueryArgLength = errors.New("gocql: query argument length mismatch") ErrTimeoutNoResponse = errors.New("gocql: no response received from cassandra within timeout period") ErrTooManyTimeouts = errors.New("gocql: too many query timeouts on the connection") ErrConnectionClosed = errors.New("gocql: connection closed waiting for response") ErrNoStreams = errors.New("gocql: no streams available on connection") ) cassandra-gocql-driver-1.7.0/conn_test.go000066400000000000000000001027271467504044300204150ustar00rootroot00000000000000//go:build all || unit // +build all unit /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2012, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. */ package gocql import ( "bufio" "bytes" "context" "crypto/tls" "crypto/x509" "errors" "fmt" "io" "io/ioutil" "math/rand" "net" "os" "strings" "sync" "sync/atomic" "testing" "time" "github.com/gocql/gocql/internal/streams" ) const ( defaultProto = protoVersion2 ) func TestApprove(t *testing.T) { tests := map[bool]bool{ approve("org.apache.cassandra.auth.PasswordAuthenticator", []string{}): true, approve("com.instaclustr.cassandra.auth.SharedSecretAuthenticator", []string{}): true, approve("com.datastax.bdp.cassandra.auth.DseAuthenticator", []string{}): true, approve("io.aiven.cassandra.auth.AivenAuthenticator", []string{}): true, approve("com.amazon.helenus.auth.HelenusAuthenticator", []string{}): true, approve("com.ericsson.bss.cassandra.ecaudit.auth.AuditAuthenticator", []string{}): true, approve("com.scylladb.auth.SaslauthdAuthenticator", []string{}): true, approve("com.scylladb.auth.TransitionalAuthenticator", []string{}): true, approve("com.instaclustr.cassandra.auth.InstaclustrPasswordAuthenticator", []string{}): true, approve("com.apache.cassandra.auth.FakeAuthenticator", []string{}): false, approve("com.apache.cassandra.auth.FakeAuthenticator", nil): false, approve("com.apache.cassandra.auth.FakeAuthenticator", []string{"com.apache.cassandra.auth.FakeAuthenticator"}): true, } for k, v := range tests { if k != v { 
t.Fatalf("expected '%v', got '%v'", k, v) } } } func TestJoinHostPort(t *testing.T) { tests := map[string]string{ "127.0.0.1:0": JoinHostPort("127.0.0.1", 0), "127.0.0.1:1": JoinHostPort("127.0.0.1:1", 9142), "[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:0": JoinHostPort("2001:0db8:85a3:0000:0000:8a2e:0370:7334", 0), "[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:1": JoinHostPort("[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:1", 9142), } for k, v := range tests { if k != v { t.Fatalf("expected '%v', got '%v'", k, v) } } } func testCluster(proto protoVersion, addresses ...string) *ClusterConfig { cluster := NewCluster(addresses...) cluster.ProtoVersion = int(proto) cluster.disableControlConn = true return cluster } func TestSimple(t *testing.T) { srv := NewTestServer(t, defaultProto, context.Background()) defer srv.Stop() cluster := testCluster(defaultProto, srv.Address) db, err := cluster.CreateSession() if err != nil { t.Fatalf("0x%x: NewCluster: %v", defaultProto, err) } if err := db.Query("void").Exec(); err != nil { t.Fatalf("0x%x: %v", defaultProto, err) } } func TestSSLSimple(t *testing.T) { srv := NewSSLTestServer(t, defaultProto, context.Background()) defer srv.Stop() db, err := createTestSslCluster(srv.Address, defaultProto, true).CreateSession() if err != nil { t.Fatalf("0x%x: NewCluster: %v", defaultProto, err) } if err := db.Query("void").Exec(); err != nil { t.Fatalf("0x%x: %v", defaultProto, err) } } func TestSSLSimpleNoClientCert(t *testing.T) { srv := NewSSLTestServer(t, defaultProto, context.Background()) defer srv.Stop() db, err := createTestSslCluster(srv.Address, defaultProto, false).CreateSession() if err != nil { t.Fatalf("0x%x: NewCluster: %v", defaultProto, err) } if err := db.Query("void").Exec(); err != nil { t.Fatalf("0x%x: %v", defaultProto, err) } } func createTestSslCluster(addr string, proto protoVersion, useClientCert bool) *ClusterConfig { cluster := testCluster(proto, addr) sslOpts := &SslOptions{ CaPath: "testdata/pki/ca.crt", 
EnableHostVerification: false, } if useClientCert { sslOpts.CertPath = "testdata/pki/gocql.crt" sslOpts.KeyPath = "testdata/pki/gocql.key" } cluster.SslOpts = sslOpts return cluster } func TestClosed(t *testing.T) { t.Skip("Skipping the execution of TestClosed for now to try to concentrate on more important test failures on Travis") srv := NewTestServer(t, defaultProto, context.Background()) defer srv.Stop() session, err := newTestSession(defaultProto, srv.Address) if err != nil { t.Fatalf("0x%x: NewCluster: %v", defaultProto, err) } session.Close() if err := session.Query("void").Exec(); err != ErrSessionClosed { t.Fatalf("0x%x: expected %#v, got %#v", defaultProto, ErrSessionClosed, err) } } func newTestSession(proto protoVersion, addresses ...string) (*Session, error) { return testCluster(proto, addresses...).CreateSession() } func TestDNSLookupConnected(t *testing.T) { log := &testLogger{} // Override the defaul DNS resolver and restore at the end failDNS = true defer func() { failDNS = false }() srv := NewTestServer(t, defaultProto, context.Background()) defer srv.Stop() cluster := NewCluster("cassandra1.invalid", srv.Address, "cassandra2.invalid") cluster.Logger = log cluster.ProtoVersion = int(defaultProto) cluster.disableControlConn = true // CreateSession() should attempt to resolve the DNS name "cassandraX.invalid" // and fail, but continue to connect via srv.Address _, err := cluster.CreateSession() if err != nil { t.Fatal("CreateSession() should have connected") } if !strings.Contains(log.String(), "gocql: dns error") { t.Fatalf("Expected to receive dns error log message - got '%s' instead", log.String()) } } func TestDNSLookupError(t *testing.T) { log := &testLogger{} // Override the defaul DNS resolver and restore at the end failDNS = true defer func() { failDNS = false }() cluster := NewCluster("cassandra1.invalid", "cassandra2.invalid") cluster.Logger = log cluster.ProtoVersion = int(defaultProto) cluster.disableControlConn = true // CreateSession() 
should attempt to resolve each DNS name "cassandraX.invalid" // and fail since it could not resolve any dns entries _, err := cluster.CreateSession() if err == nil { t.Fatal("CreateSession() should have returned an error") } if !strings.Contains(log.String(), "gocql: dns error") { t.Fatalf("Expected to receive dns error log message - got '%s' instead", log.String()) } if err.Error() != "gocql: unable to create session: failed to resolve any of the provided hostnames" { t.Fatalf("Expected CreateSession() to fail with message - got '%s' instead", err.Error()) } } func TestStartupTimeout(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) log := &testLogger{} srv := NewTestServer(t, defaultProto, ctx) defer srv.Stop() // Tell the server to never respond to Startup frame atomic.StoreInt32(&srv.TimeoutOnStartup, 1) startTime := time.Now() cluster := NewCluster(srv.Address) cluster.Logger = log cluster.ProtoVersion = int(defaultProto) cluster.disableControlConn = true // Set very long query connection timeout // so we know CreateSession() is using the ConnectTimeout cluster.Timeout = time.Second * 5 cluster.ConnectTimeout = 600 * time.Millisecond // Create session should timeout during connect attempt _, err := cluster.CreateSession() if err == nil { t.Fatal("CreateSession() should have returned a timeout error") } elapsed := time.Since(startTime) if elapsed > time.Second*5 { t.Fatal("ConnectTimeout is not respected") } if !errors.Is(err, ErrNoConnectionsStarted) { t.Fatalf("Expected to receive no connections error - got '%s'", err) } if !strings.Contains(log.String(), "no response to connection startup within timeout") && !strings.Contains(log.String(), "no response received from cassandra within timeout period") { t.Fatalf("Expected to receive timeout log message - got '%s'", log.String()) } cancel() } func TestTimeout(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) srv := NewTestServer(t, defaultProto, ctx) defer srv.Stop() 
db, err := newTestSession(defaultProto, srv.Address) if err != nil { t.Fatalf("NewCluster: %v", err) } defer db.Close() var wg sync.WaitGroup wg.Add(1) go func() { defer wg.Done() select { case <-time.After(5 * time.Second): t.Errorf("no timeout") case <-ctx.Done(): } }() if err := db.Query("kill").WithContext(ctx).Exec(); err == nil { t.Fatal("expected error got nil") } cancel() wg.Wait() } func TestCancel(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() srv := NewTestServer(t, defaultProto, ctx) defer srv.Stop() cluster := testCluster(defaultProto, srv.Address) cluster.Timeout = 1 * time.Second db, err := cluster.CreateSession() if err != nil { t.Fatalf("NewCluster: %v", err) } defer db.Close() qry := db.Query("timeout").WithContext(ctx) // Make sure we finish the query without leftovers var wg sync.WaitGroup wg.Add(1) go func() { if err := qry.Exec(); err != context.Canceled { t.Fatalf("expected to get context cancel error: '%v', got '%v'", context.Canceled, err) } wg.Done() }() // The query will timeout after about 1 seconds, so cancel it after a short pause time.AfterFunc(20*time.Millisecond, cancel) wg.Wait() } type testQueryObserver struct { metrics map[string]*hostMetrics verbose bool logger StdLogger } func (o *testQueryObserver) ObserveQuery(ctx context.Context, q ObservedQuery) { host := q.Host.ConnectAddress().String() o.metrics[host] = q.Metrics if o.verbose { o.logger.Printf("Observed query %q. Returned %v rows, took %v on host %q with %v attempts and total latency %v. Error: %q\n", q.Statement, q.Rows, q.End.Sub(q.Start), host, q.Metrics.Attempts, q.Metrics.TotalLatency, q.Err) } } func (o *testQueryObserver) GetMetrics(host *HostInfo) *hostMetrics { return o.metrics[host.ConnectAddress().String()] } // TestQueryRetry will test to make sure that gocql will execute // the exact amount of retry queries designated by the user. 
func TestQueryRetry(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() srv := NewTestServer(t, defaultProto, ctx) defer srv.Stop() db, err := newTestSession(defaultProto, srv.Address) if err != nil { t.Fatalf("NewCluster: %v", err) } defer db.Close() go func() { select { case <-ctx.Done(): return case <-time.After(5 * time.Second): t.Errorf("no timeout") } }() rt := &SimpleRetryPolicy{NumRetries: 1} qry := db.Query("kill").RetryPolicy(rt) if err := qry.Exec(); err == nil { t.Fatalf("expected error") } requests := atomic.LoadInt64(&srv.nKillReq) attempts := qry.Attempts() if requests != int64(attempts) { t.Fatalf("expected requests %v to match query attempts %v", requests, attempts) } // the query will only be attempted once, but is being retried if requests != int64(rt.NumRetries) { t.Fatalf("failed to retry the query %v time(s). Query executed %v times", rt.NumRetries, requests-1) } } func TestQueryMultinodeWithMetrics(t *testing.T) { log := &testLogger{} defer func() { os.Stdout.WriteString(log.String()) }() // Build a 3 node cluster to test host metric mapping var nodes []*TestServer var addresses = []string{ "127.0.0.1", "127.0.0.2", "127.0.0.3", } // Can do with 1 context for all servers ctx := context.Background() for _, ip := range addresses { srv := NewTestServerWithAddress(ip+":0", t, defaultProto, ctx) defer srv.Stop() nodes = append(nodes, srv) } db, err := newTestSession(defaultProto, nodes[0].Address, nodes[1].Address, nodes[2].Address) if err != nil { t.Fatalf("NewCluster: %v", err) } defer db.Close() // 1 retry per host rt := &SimpleRetryPolicy{NumRetries: 3} observer := &testQueryObserver{metrics: make(map[string]*hostMetrics), verbose: false, logger: log} qry := db.Query("kill").RetryPolicy(rt).Observer(observer) if err := qry.Exec(); err == nil { t.Fatalf("expected error") } for i, ip := range addresses { host := &HostInfo{connectAddress: net.ParseIP(ip)} queryMetric := qry.metrics.hostMetrics(host) 
observedMetrics := observer.GetMetrics(host) requests := int(atomic.LoadInt64(&nodes[i].nKillReq)) hostAttempts := queryMetric.Attempts if requests != hostAttempts { t.Fatalf("expected requests %v to match query attempts %v", requests, hostAttempts) } if hostAttempts != observedMetrics.Attempts { t.Fatalf("expected observed attempts %v to match query attempts %v on host %v", observedMetrics.Attempts, hostAttempts, ip) } hostLatency := queryMetric.TotalLatency observedLatency := observedMetrics.TotalLatency if hostLatency != observedLatency { t.Fatalf("expected observed latency %v to match query latency %v on host %v", observedLatency, hostLatency, ip) } } // the query will only be attempted once, but is being retried attempts := qry.Attempts() if attempts != rt.NumRetries { t.Fatalf("failed to retry the query %v time(s). Query executed %v times", rt.NumRetries, attempts) } } type testRetryPolicy struct { NumRetries int } func (t *testRetryPolicy) Attempt(qry RetryableQuery) bool { return qry.Attempts() <= t.NumRetries } func (t *testRetryPolicy) GetRetryType(err error) RetryType { return Retry } func TestSpeculativeExecution(t *testing.T) { log := &testLogger{} defer func() { os.Stdout.WriteString(log.String()) }() // Build a 3 node cluster var nodes []*TestServer var addresses = []string{ "127.0.0.1", "127.0.0.2", "127.0.0.3", } // Can do with 1 context for all servers ctx := context.Background() for _, ip := range addresses { srv := NewTestServerWithAddress(ip+":0", t, defaultProto, ctx) defer srv.Stop() nodes = append(nodes, srv) } db, err := newTestSession(defaultProto, nodes[0].Address, nodes[1].Address, nodes[2].Address) if err != nil { t.Fatalf("NewCluster: %v", err) } defer db.Close() // Create a test retry policy, 6 retries will cover 2 executions rt := &testRetryPolicy{NumRetries: 8} // test Speculative policy with 1 additional execution sp := &SimpleSpeculativeExecution{NumAttempts: 1, TimeoutDelay: 200 * time.Millisecond} // Build the query qry := 
db.Query("speculative").RetryPolicy(rt).SetSpeculativeExecutionPolicy(sp).Idempotent(true) // Execute the query and close, check that it doesn't error out if err := qry.Exec(); err != nil { t.Errorf("The query failed with '%v'!\n", err) } requests1 := atomic.LoadInt64(&nodes[0].nKillReq) requests2 := atomic.LoadInt64(&nodes[1].nKillReq) requests3 := atomic.LoadInt64(&nodes[2].nKillReq) // Spec Attempts == 1, so expecting to see only 1 regular + 1 speculative = 2 nodes attempted if requests1 != 0 && requests2 != 0 && requests3 != 0 { t.Error("error: all 3 nodes were attempted, should have been only 2") } // Only the 4th request will generate results, so if requests1 != 4 && requests2 != 4 && requests3 != 4 { t.Error("error: none of 3 nodes was attempted 4 times!") } // "speculative" query will succeed on one arbitrary node after 4 attempts, so // expecting to see 4 (on successful node) + not more than 2 (as cancelled on another node) == 6 if requests1+requests2+requests3 > 6 { t.Errorf("error: expected to see 6 attempts, got %v\n", requests1+requests2+requests3) } } // This tests that the policy connection pool handles SSL correctly func TestPolicyConnPoolSSL(t *testing.T) { srv := NewSSLTestServer(t, defaultProto, context.Background()) defer srv.Stop() cluster := createTestSslCluster(srv.Address, defaultProto, true) cluster.PoolConfig.HostSelectionPolicy = RoundRobinHostPolicy() db, err := cluster.CreateSession() if err != nil { t.Fatalf("failed to create new session: %v", err) } if err := db.Query("void").Exec(); err != nil { t.Fatalf("query failed due to error: %v", err) } db.Close() // wait for the pool to drain time.Sleep(100 * time.Millisecond) size := db.pool.Size() if size != 0 { t.Fatalf("connection pool did not drain, still contains %d connections", size) } } func TestQueryTimeout(t *testing.T) { srv := NewTestServer(t, defaultProto, context.Background()) defer srv.Stop() cluster := testCluster(defaultProto, srv.Address) // Set the timeout arbitrarily low 
so that the query hits the timeout in a // timely manner. cluster.Timeout = 1 * time.Millisecond db, err := cluster.CreateSession() if err != nil { t.Fatalf("NewCluster: %v", err) } defer db.Close() ch := make(chan error, 1) go func() { err := db.Query("timeout").Exec() if err != nil { ch <- err return } t.Errorf("err was nil, expected to get a timeout after %v", db.cfg.Timeout) }() select { case err := <-ch: if err != ErrTimeoutNoResponse { t.Fatalf("expected to get %v for timeout got %v", ErrTimeoutNoResponse, err) } case <-time.After(40*time.Millisecond + db.cfg.Timeout): // ensure that the query goroutines have been scheduled t.Fatalf("query did not timeout after %v", db.cfg.Timeout) } } func BenchmarkSingleConn(b *testing.B) { srv := NewTestServer(b, 3, context.Background()) defer srv.Stop() cluster := testCluster(3, srv.Address) // Set the timeout arbitrarily low so that the query hits the timeout in a // timely manner. cluster.Timeout = 500 * time.Millisecond cluster.NumConns = 1 db, err := cluster.CreateSession() if err != nil { b.Fatalf("NewCluster: %v", err) } defer db.Close() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { for pb.Next() { err := db.Query("void").Exec() if err != nil { b.Error(err) return } } }) } func TestQueryTimeoutReuseStream(t *testing.T) { t.Skip("no longer tests anything") // TODO(zariel): move this to conn test, we really just want to check what // happens when a conn is srv := NewTestServer(t, defaultProto, context.Background()) defer srv.Stop() cluster := testCluster(defaultProto, srv.Address) // Set the timeout arbitrarily low so that the query hits the timeout in a // timely manner. 
cluster.Timeout = 1 * time.Millisecond cluster.NumConns = 1 db, err := cluster.CreateSession() if err != nil { t.Fatalf("NewCluster: %v", err) } defer db.Close() db.Query("slow").Exec() err = db.Query("void").Exec() if err != nil { t.Fatal(err) } } func TestQueryTimeoutClose(t *testing.T) { srv := NewTestServer(t, defaultProto, context.Background()) defer srv.Stop() cluster := testCluster(defaultProto, srv.Address) // Set the timeout arbitrarily low so that the query hits the timeout in a // timely manner. cluster.Timeout = 1000 * time.Millisecond cluster.NumConns = 1 db, err := cluster.CreateSession() if err != nil { t.Fatalf("NewCluster: %v", err) } ch := make(chan error) go func() { err := db.Query("timeout").Exec() ch <- err }() // ensure that the above goroutine gets sheduled time.Sleep(50 * time.Millisecond) db.Close() select { case err = <-ch: case <-time.After(1 * time.Second): t.Fatal("timedout waiting to get a response once cluster is closed") } if err != ErrConnectionClosed { t.Fatalf("expected to get %v got %v", ErrConnectionClosed, err) } } func TestStream0(t *testing.T) { // TODO: replace this with type check const expErr = "gocql: received unexpected frame on stream 0" var buf bytes.Buffer f := newFramer(nil, protoVersion4) f.writeHeader(0, opResult, 0) f.writeInt(resultKindVoid) f.buf[0] |= 0x80 if err := f.finish(); err != nil { t.Fatal(err) } if err := f.writeTo(&buf); err != nil { t.Fatal(err) } conn := &Conn{ r: bufio.NewReader(&buf), streams: streams.New(protoVersion4), logger: &defaultLogger{}, } err := conn.recv(context.Background()) if err == nil { t.Fatal("expected to get an error on stream 0") } else if !strings.HasPrefix(err.Error(), expErr) { t.Fatalf("expected to get error prefix %q got %q", expErr, err.Error()) } } func TestContext_Timeout(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() srv := NewTestServer(t, defaultProto, ctx) defer srv.Stop() cluster := testCluster(defaultProto, srv.Address) 
cluster.Timeout = 5 * time.Second db, err := cluster.CreateSession() if err != nil { t.Fatal(err) } defer db.Close() ctx, cancel = context.WithCancel(ctx) cancel() err = db.Query("timeout").WithContext(ctx).Exec() if err != context.Canceled { t.Fatalf("expected to get context cancel error: %v got %v", context.Canceled, err) } } func TestContext_CanceledBeforeExec(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() var reqCount uint64 srv := newTestServerOpts{ addr: "127.0.0.1:0", protocol: defaultProto, recvHook: func(f *framer) { if f.header.op == opStartup || f.header.op == opOptions { // ignore statup and heartbeat messages return } atomic.AddUint64(&reqCount, 1) }, }.newServer(t, ctx) defer srv.Stop() cluster := testCluster(defaultProto, srv.Address) cluster.Timeout = 5 * time.Second db, err := cluster.CreateSession() if err != nil { t.Fatal(err) } defer db.Close() startupRequestCount := atomic.LoadUint64(&reqCount) ctx, cancel = context.WithCancel(ctx) cancel() err = db.Query("timeout").WithContext(ctx).Exec() if err != context.Canceled { t.Fatalf("expected to get context cancel error: %v got %v", context.Canceled, err) } // Queries are executed by separate goroutine and we don't have a synchronization point that would allow us to // check if a request was sent or not. // Fall back to waiting a little bit. time.Sleep(100 * time.Millisecond) queryRequestCount := atomic.LoadUint64(&reqCount) - startupRequestCount if queryRequestCount != 0 { t.Fatalf("expected that no request is sent to server, sent %d requests", queryRequestCount) } } // tcpConnPair returns a matching set of a TCP client side and server side connection. func tcpConnPair() (s, c net.Conn, err error) { l, err := net.Listen("tcp", "localhost:0") if err != nil { // maybe ipv6 works, if ipv4 fails? l, err = net.Listen("tcp6", "[::1]:0") if err != nil { return nil, nil, err } } defer l.Close() // we only try to accept one connection, so will stop listening. 
addr := l.Addr() done := make(chan struct{}) var errDial error go func(done chan<- struct{}) { c, errDial = net.Dial(addr.Network(), addr.String()) close(done) }(done) s, err = l.Accept() <-done if err == nil { err = errDial } if err != nil { if s != nil { s.Close() } if c != nil { c.Close() } } return s, c, err } func TestWriteCoalescing(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() server, client, err := tcpConnPair() if err != nil { t.Fatal(err) } done := make(chan struct{}, 1) var ( buf bytes.Buffer bufMutex sync.Mutex ) go func() { defer close(done) defer server.Close() var err error b := make([]byte, 256) var n int for { if n, err = server.Read(b); err != nil { break } bufMutex.Lock() buf.Write(b[:n]) bufMutex.Unlock() } if err != io.EOF { t.Errorf("unexpected read error: %v", err) } }() enqueued := make(chan struct{}) resetTimer := make(chan struct{}) w := &writeCoalescer{ writeCh: make(chan writeRequest), c: client, quit: ctx.Done(), timeout: 500 * time.Millisecond, testEnqueuedHook: func() { enqueued <- struct{}{} }, testFlushedHook: func() { client.Close() }, } timerC := make(chan time.Time, 1) go func() { w.writeFlusherImpl(timerC, func() { resetTimer <- struct{}{} }) }() go func() { if _, err := w.writeContext(context.Background(), []byte("one")); err != nil { t.Error(err) } }() go func() { if _, err := w.writeContext(context.Background(), []byte("two")); err != nil { t.Error(err) } }() <-enqueued <-resetTimer <-enqueued // flush timerC <- time.Now() <-done if got := buf.String(); got != "onetwo" && got != "twoone" { t.Fatalf("expected to get %q got %q", "onetwo or twoone", got) } } func TestWriteCoalescing_WriteAfterClose(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() var buf bytes.Buffer defer cancel() server, client, err := tcpConnPair() if err != nil { t.Fatal(err) } done := make(chan struct{}, 1) go func() { io.Copy(&buf, server) server.Close() close(done) }() w := 
newWriteCoalescer(client, 0, 5*time.Millisecond, ctx.Done()) // ensure 1 write works if _, err := w.writeContext(context.Background(), []byte("one")); err != nil { t.Fatal(err) } client.Close() <-done if v := buf.String(); v != "one" { t.Fatalf("expected buffer to be %q got %q", "one", v) } // now close and do a write, we should error cancel() client.Close() // close client conn too, since server won't see the answer anyway. if _, err := w.writeContext(context.Background(), []byte("two")); err == nil { t.Fatal("expected to get error for write after closing") } else if err != io.EOF { t.Fatalf("expected to get EOF got %v", err) } } type recordingFrameHeaderObserver struct { t *testing.T mu sync.Mutex frames []ObservedFrameHeader } func (r *recordingFrameHeaderObserver) ObserveFrameHeader(ctx context.Context, frm ObservedFrameHeader) { r.mu.Lock() r.frames = append(r.frames, frm) r.mu.Unlock() } func (r *recordingFrameHeaderObserver) getFrames() []ObservedFrameHeader { r.mu.Lock() defer r.mu.Unlock() return r.frames } func TestFrameHeaderObserver(t *testing.T) { srv := NewTestServer(t, defaultProto, context.Background()) defer srv.Stop() cluster := testCluster(defaultProto, srv.Address) cluster.NumConns = 1 observer := &recordingFrameHeaderObserver{t: t} cluster.FrameHeaderObserver = observer db, err := cluster.CreateSession() if err != nil { t.Fatal(err) } if err := db.Query("void").Exec(); err != nil { t.Fatal(err) } frames := observer.getFrames() expFrames := []frameOp{opSupported, opReady, opResult} if len(frames) != len(expFrames) { t.Fatalf("Expected to receive %d frames, instead received %d", len(expFrames), len(frames)) } for i, op := range expFrames { if op != frames[i].Opcode { t.Fatalf("expected frame %d to be %v got %v", i, op, frames[i]) } } voidResultFrame := frames[2] if voidResultFrame.Length != int32(4) { t.Fatalf("Expected to receive frame with body length 4, instead received body length %d", voidResultFrame.Length) } } func 
NewTestServerWithAddress(addr string, t testing.TB, protocol uint8, ctx context.Context) *TestServer { return newTestServerOpts{ addr: addr, protocol: protocol, }.newServer(t, ctx) } type newTestServerOpts struct { addr string protocol uint8 recvHook func(*framer) } func (nts newTestServerOpts) newServer(t testing.TB, ctx context.Context) *TestServer { laddr, err := net.ResolveTCPAddr("tcp", nts.addr) if err != nil { t.Fatal(err) } listen, err := net.ListenTCP("tcp", laddr) if err != nil { t.Fatal(err) } headerSize := 8 if nts.protocol > protoVersion2 { headerSize = 9 } ctx, cancel := context.WithCancel(ctx) srv := &TestServer{ Address: listen.Addr().String(), listen: listen, t: t, protocol: nts.protocol, headerSize: headerSize, ctx: ctx, cancel: cancel, onRecv: nts.recvHook, } go srv.closeWatch() go srv.serve() return srv } func NewTestServer(t testing.TB, protocol uint8, ctx context.Context) *TestServer { return NewTestServerWithAddress("127.0.0.1:0", t, protocol, ctx) } func NewSSLTestServer(t testing.TB, protocol uint8, ctx context.Context) *TestServer { pem, err := ioutil.ReadFile("testdata/pki/ca.crt") certPool := x509.NewCertPool() if !certPool.AppendCertsFromPEM(pem) { t.Fatalf("Failed parsing or appending certs") } mycert, err := tls.LoadX509KeyPair("testdata/pki/cassandra.crt", "testdata/pki/cassandra.key") if err != nil { t.Fatalf("could not load cert") } config := &tls.Config{ Certificates: []tls.Certificate{mycert}, RootCAs: certPool, } listen, err := tls.Listen("tcp", "127.0.0.1:0", config) if err != nil { t.Fatal(err) } headerSize := 8 if protocol > protoVersion2 { headerSize = 9 } ctx, cancel := context.WithCancel(ctx) srv := &TestServer{ Address: listen.Addr().String(), listen: listen, t: t, protocol: protocol, headerSize: headerSize, ctx: ctx, cancel: cancel, } go srv.closeWatch() go srv.serve() return srv } type TestServer struct { Address string TimeoutOnStartup int32 t testing.TB listen net.Listener nKillReq int64 protocol byte headerSize int 
ctx context.Context cancel context.CancelFunc mu sync.Mutex closed bool // onRecv is a hook point for tests, called in receive loop. onRecv func(*framer) } func (srv *TestServer) closeWatch() { <-srv.ctx.Done() srv.mu.Lock() defer srv.mu.Unlock() srv.closeLocked() } func (srv *TestServer) serve() { defer srv.listen.Close() for !srv.isClosed() { conn, err := srv.listen.Accept() if err != nil { break } go func(conn net.Conn) { defer conn.Close() for !srv.isClosed() { framer, err := srv.readFrame(conn) if err != nil { if err == io.EOF { return } srv.errorLocked(err) return } if srv.onRecv != nil { srv.onRecv(framer) } go srv.process(conn, framer) } }(conn) } } func (srv *TestServer) isClosed() bool { srv.mu.Lock() defer srv.mu.Unlock() return srv.closed } func (srv *TestServer) closeLocked() { if srv.closed { return } srv.closed = true srv.listen.Close() srv.cancel() } func (srv *TestServer) Stop() { srv.mu.Lock() defer srv.mu.Unlock() srv.closeLocked() } func (srv *TestServer) errorLocked(err interface{}) { srv.mu.Lock() defer srv.mu.Unlock() if srv.closed { return } srv.t.Error(err) } func (srv *TestServer) process(conn net.Conn, reqFrame *framer) { head := reqFrame.header if head == nil { srv.errorLocked("process frame with a nil header") return } respFrame := newFramer(nil, reqFrame.proto) switch head.op { case opStartup: if atomic.LoadInt32(&srv.TimeoutOnStartup) > 0 { // Do not respond to startup command // wait until we get a cancel signal select { case <-srv.ctx.Done(): return } } respFrame.writeHeader(0, opReady, head.stream) case opOptions: respFrame.writeHeader(0, opSupported, head.stream) respFrame.writeShort(0) case opQuery: query := reqFrame.readLongString() first := query if n := strings.Index(query, " "); n > 0 { first = first[:n] } switch strings.ToLower(first) { case "kill": atomic.AddInt64(&srv.nKillReq, 1) respFrame.writeHeader(0, opError, head.stream) respFrame.writeInt(0x1001) respFrame.writeString("query killed") case "use": 
respFrame.writeInt(resultKindKeyspace) respFrame.writeString(strings.TrimSpace(query[3:])) case "void": respFrame.writeHeader(0, opResult, head.stream) respFrame.writeInt(resultKindVoid) case "timeout": <-srv.ctx.Done() return case "slow": go func() { respFrame.writeHeader(0, opResult, head.stream) respFrame.writeInt(resultKindVoid) respFrame.buf[0] = srv.protocol | 0x80 select { case <-srv.ctx.Done(): return case <-time.After(50 * time.Millisecond): respFrame.finish() respFrame.writeTo(conn) } }() return case "speculative": atomic.AddInt64(&srv.nKillReq, 1) if atomic.LoadInt64(&srv.nKillReq) > 3 { respFrame.writeHeader(0, opResult, head.stream) respFrame.writeInt(resultKindVoid) respFrame.writeString("speculative query success on the node " + srv.Address) } else { respFrame.writeHeader(0, opError, head.stream) respFrame.writeInt(0x1001) respFrame.writeString("speculative error") rand.Seed(time.Now().UnixNano()) <-time.After(time.Millisecond * 120) } default: respFrame.writeHeader(0, opResult, head.stream) respFrame.writeInt(resultKindVoid) } case opError: respFrame.writeHeader(0, opError, head.stream) respFrame.buf = append(respFrame.buf, reqFrame.buf...) 
default: respFrame.writeHeader(0, opError, head.stream) respFrame.writeInt(0) respFrame.writeString("not supported") } respFrame.buf[0] = srv.protocol | 0x80 if err := respFrame.finish(); err != nil { srv.errorLocked(err) } if err := respFrame.writeTo(conn); err != nil { srv.errorLocked(err) } } func (srv *TestServer) readFrame(conn net.Conn) (*framer, error) { buf := make([]byte, srv.headerSize) head, err := readHeader(conn, buf) if err != nil { return nil, err } framer := newFramer(nil, srv.protocol) err = framer.readFrame(conn, &head) if err != nil { return nil, err } // should be a request frame if head.version.response() { return nil, fmt.Errorf("expected to read a request frame got version: %v", head.version) } else if head.version.version() != srv.protocol { return nil, fmt.Errorf("expected to read protocol version 0x%x got 0x%x", srv.protocol, head.version.version()) } return framer, nil } cassandra-gocql-driver-1.7.0/connectionpool.go000066400000000000000000000355061467504044300214520ustar00rootroot00000000000000/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2012, The Gocql authors, * provided under the BSD-3-Clause License. 
* See the NOTICE file distributed with this work for additional information. */ package gocql import ( "crypto/tls" "crypto/x509" "errors" "fmt" "io/ioutil" "math/rand" "net" "sync" "sync/atomic" "time" ) // interface to implement to receive the host information type SetHosts interface { SetHosts(hosts []*HostInfo) } // interface to implement to receive the partitioner value type SetPartitioner interface { SetPartitioner(partitioner string) } func setupTLSConfig(sslOpts *SslOptions) (*tls.Config, error) { // Config.InsecureSkipVerify | EnableHostVerification | Result // Config is nil | true | verify host // Config is nil | false | do not verify host // false | false | verify host // true | false | do not verify host // false | true | verify host // true | true | verify host var tlsConfig *tls.Config if sslOpts.Config == nil { tlsConfig = &tls.Config{ InsecureSkipVerify: !sslOpts.EnableHostVerification, } } else { // use clone to avoid race. tlsConfig = sslOpts.Config.Clone() } if tlsConfig.InsecureSkipVerify && sslOpts.EnableHostVerification { tlsConfig.InsecureSkipVerify = false } // ca cert is optional if sslOpts.CaPath != "" { if tlsConfig.RootCAs == nil { tlsConfig.RootCAs = x509.NewCertPool() } pem, err := ioutil.ReadFile(sslOpts.CaPath) if err != nil { return nil, fmt.Errorf("connectionpool: unable to open CA certs: %v", err) } if !tlsConfig.RootCAs.AppendCertsFromPEM(pem) { return nil, errors.New("connectionpool: failed parsing or CA certs") } } if sslOpts.CertPath != "" || sslOpts.KeyPath != "" { mycert, err := tls.LoadX509KeyPair(sslOpts.CertPath, sslOpts.KeyPath) if err != nil { return nil, fmt.Errorf("connectionpool: unable to load X509 key pair: %v", err) } tlsConfig.Certificates = append(tlsConfig.Certificates, mycert) } return tlsConfig, nil } type policyConnPool struct { session *Session port int numConns int keyspace string mu sync.RWMutex hostConnPools map[string]*hostConnPool } func connConfig(cfg *ClusterConfig) (*ConnConfig, error) { var ( err 
error hostDialer HostDialer ) hostDialer = cfg.HostDialer if hostDialer == nil { var tlsConfig *tls.Config // TODO(zariel): move tls config setup into session init. if cfg.SslOpts != nil { tlsConfig, err = setupTLSConfig(cfg.SslOpts) if err != nil { return nil, err } } dialer := cfg.Dialer if dialer == nil { d := &net.Dialer{ Timeout: cfg.ConnectTimeout, } if cfg.SocketKeepalive > 0 { d.KeepAlive = cfg.SocketKeepalive } dialer = d } hostDialer = &defaultHostDialer{ dialer: dialer, tlsConfig: tlsConfig, } } return &ConnConfig{ ProtoVersion: cfg.ProtoVersion, CQLVersion: cfg.CQLVersion, Timeout: cfg.Timeout, WriteTimeout: cfg.WriteTimeout, ConnectTimeout: cfg.ConnectTimeout, Dialer: cfg.Dialer, HostDialer: hostDialer, Compressor: cfg.Compressor, Authenticator: cfg.Authenticator, AuthProvider: cfg.AuthProvider, Keepalive: cfg.SocketKeepalive, Logger: cfg.logger(), }, nil } func newPolicyConnPool(session *Session) *policyConnPool { // create the pool pool := &policyConnPool{ session: session, port: session.cfg.Port, numConns: session.cfg.NumConns, keyspace: session.cfg.Keyspace, hostConnPools: map[string]*hostConnPool{}, } return pool } func (p *policyConnPool) SetHosts(hosts []*HostInfo) { p.mu.Lock() defer p.mu.Unlock() toRemove := make(map[string]struct{}) for hostID := range p.hostConnPools { toRemove[hostID] = struct{}{} } pools := make(chan *hostConnPool) createCount := 0 for _, host := range hosts { if !host.IsUp() { // don't create a connection pool for a down host continue } hostID := host.HostID() if _, exists := p.hostConnPools[hostID]; exists { // still have this host, so don't remove it delete(toRemove, hostID) continue } createCount++ go func(host *HostInfo) { // create a connection pool for the host pools <- newHostConnPool( p.session, host, p.port, p.numConns, p.keyspace, ) }(host) } // add created pools for createCount > 0 { pool := <-pools createCount-- if pool.Size() > 0 { // add pool only if there a connections available 
p.hostConnPools[pool.host.HostID()] = pool } } for addr := range toRemove { pool := p.hostConnPools[addr] delete(p.hostConnPools, addr) go pool.Close() } } func (p *policyConnPool) Size() int { p.mu.RLock() count := 0 for _, pool := range p.hostConnPools { count += pool.Size() } p.mu.RUnlock() return count } func (p *policyConnPool) getPool(host *HostInfo) (pool *hostConnPool, ok bool) { hostID := host.HostID() p.mu.RLock() pool, ok = p.hostConnPools[hostID] p.mu.RUnlock() return } func (p *policyConnPool) Close() { p.mu.Lock() defer p.mu.Unlock() // close the pools for addr, pool := range p.hostConnPools { delete(p.hostConnPools, addr) pool.Close() } } func (p *policyConnPool) addHost(host *HostInfo) { hostID := host.HostID() p.mu.Lock() pool, ok := p.hostConnPools[hostID] if !ok { pool = newHostConnPool( p.session, host, host.Port(), // TODO: if port == 0 use pool.port? p.numConns, p.keyspace, ) p.hostConnPools[hostID] = pool } p.mu.Unlock() pool.fill() } func (p *policyConnPool) removeHost(hostID string) { p.mu.Lock() pool, ok := p.hostConnPools[hostID] if !ok { p.mu.Unlock() return } delete(p.hostConnPools, hostID) p.mu.Unlock() go pool.Close() } // hostConnPool is a connection pool for a single host. 
// Connection selection is based on a provided ConnSelectionPolicy type hostConnPool struct { session *Session host *HostInfo port int size int keyspace string // protection for conns, closed, filling mu sync.RWMutex conns []*Conn closed bool filling bool pos uint32 logger StdLogger } func (h *hostConnPool) String() string { h.mu.RLock() defer h.mu.RUnlock() return fmt.Sprintf("[filling=%v closed=%v conns=%v size=%v host=%v]", h.filling, h.closed, len(h.conns), h.size, h.host) } func newHostConnPool(session *Session, host *HostInfo, port, size int, keyspace string) *hostConnPool { pool := &hostConnPool{ session: session, host: host, port: port, size: size, keyspace: keyspace, conns: make([]*Conn, 0, size), filling: false, closed: false, logger: session.logger, } // the pool is not filled or connected return pool } // Pick a connection from this connection pool for the given query. func (pool *hostConnPool) Pick() *Conn { pool.mu.RLock() defer pool.mu.RUnlock() if pool.closed { return nil } size := len(pool.conns) if size < pool.size { // try to fill the pool go pool.fill() if size == 0 { return nil } } pos := int(atomic.AddUint32(&pool.pos, 1) - 1) var ( leastBusyConn *Conn streamsAvailable int ) // find the conn which has the most available streams, this is racy for i := 0; i < size; i++ { conn := pool.conns[(pos+i)%size] if streams := conn.AvailableStreams(); streams > streamsAvailable { leastBusyConn = conn streamsAvailable = streams } } return leastBusyConn } // Size returns the number of connections currently active in the pool func (pool *hostConnPool) Size() int { pool.mu.RLock() defer pool.mu.RUnlock() return len(pool.conns) } // Close the connection pool func (pool *hostConnPool) Close() { pool.mu.Lock() if pool.closed { pool.mu.Unlock() return } pool.closed = true // ensure we dont try to reacquire the lock in handleError // TODO: improve this as the following can happen // 1) we have locked pool.mu write lock // 2) conn.Close calls 
conn.closeWithError(nil) // 3) conn.closeWithError calls conn.Close() which returns an error // 4) conn.closeWithError calls pool.HandleError with the error from conn.Close // 5) pool.HandleError tries to lock pool.mu // deadlock // empty the pool conns := pool.conns pool.conns = nil pool.mu.Unlock() // close the connections for _, conn := range conns { conn.Close() } } // Fill the connection pool func (pool *hostConnPool) fill() { pool.mu.RLock() // avoid filling a closed pool, or concurrent filling if pool.closed || pool.filling { pool.mu.RUnlock() return } // determine the filling work to be done startCount := len(pool.conns) fillCount := pool.size - startCount // avoid filling a full (or overfull) pool if fillCount <= 0 { pool.mu.RUnlock() return } // switch from read to write lock pool.mu.RUnlock() pool.mu.Lock() // double check everything since the lock was released startCount = len(pool.conns) fillCount = pool.size - startCount if pool.closed || pool.filling || fillCount <= 0 { // looks like another goroutine already beat this // goroutine to the filling pool.mu.Unlock() return } // ok fill the pool pool.filling = true // allow others to access the pool while filling pool.mu.Unlock() // only this goroutine should make calls to fill/empty the pool at this // point until after this routine or its subordinates calls // fillingStopped // fill only the first connection synchronously if startCount == 0 { err := pool.connect() pool.logConnectErr(err) if err != nil { // probably unreachable host pool.fillingStopped(err) return } // notify the session that this node is connected go pool.session.handleNodeConnected(pool.host) // filled one fillCount-- } // fill the rest of the pool asynchronously go func() { err := pool.connectMany(fillCount) // mark the end of filling pool.fillingStopped(err) if err == nil && startCount > 0 { // notify the session that this node is connected again go pool.session.handleNodeConnected(pool.host) } }() } func (pool *hostConnPool) 
logConnectErr(err error) { if opErr, ok := err.(*net.OpError); ok && (opErr.Op == "dial" || opErr.Op == "read") { // connection refused // these are typical during a node outage so avoid log spam. if gocqlDebug { pool.logger.Printf("gocql: unable to dial %q: %v\n", pool.host, err) } } else if err != nil { // unexpected error pool.logger.Printf("error: failed to connect to %q due to error: %v", pool.host, err) } } // transition back to a not-filling state. func (pool *hostConnPool) fillingStopped(err error) { if err != nil { if gocqlDebug { pool.logger.Printf("gocql: filling stopped %q: %v\n", pool.host.ConnectAddress(), err) } // wait for some time to avoid back-to-back filling // this provides some time between failed attempts // to fill the pool for the host to recover time.Sleep(time.Duration(rand.Int31n(100)+31) * time.Millisecond) } pool.mu.Lock() pool.filling = false count := len(pool.conns) host := pool.host port := pool.port pool.mu.Unlock() // if we errored and the size is now zero, make sure the host is marked as down // see https://github.com/apache/cassandra-gocql-driver/issues/1614 if gocqlDebug { pool.logger.Printf("gocql: conns of pool after stopped %q: %v\n", host.ConnectAddress(), count) } if err != nil && count == 0 { if pool.session.cfg.ConvictionPolicy.AddFailure(err, host) { pool.session.handleNodeDown(host.ConnectAddress(), port) } } } // connectMany creates new connections concurrent. 
func (pool *hostConnPool) connectMany(count int) error { if count == 0 { return nil } var ( wg sync.WaitGroup mu sync.Mutex connectErr error ) wg.Add(count) for i := 0; i < count; i++ { go func() { defer wg.Done() err := pool.connect() pool.logConnectErr(err) if err != nil { mu.Lock() connectErr = err mu.Unlock() } }() } // wait for all connections are done wg.Wait() return connectErr } // create a new connection to the host and add it to the pool func (pool *hostConnPool) connect() (err error) { // TODO: provide a more robust connection retry mechanism, we should also // be able to detect hosts that come up by trying to connect to downed ones. // try to connect var conn *Conn reconnectionPolicy := pool.session.cfg.ReconnectionPolicy for i := 0; i < reconnectionPolicy.GetMaxRetries(); i++ { conn, err = pool.session.connect(pool.session.ctx, pool.host, pool) if err == nil { break } if opErr, isOpErr := err.(*net.OpError); isOpErr { // if the error is not a temporary error (ex: network unreachable) don't // retry if !opErr.Temporary() { break } } if gocqlDebug { pool.logger.Printf("gocql: connection failed %q: %v, reconnecting with %T\n", pool.host.ConnectAddress(), err, reconnectionPolicy) } time.Sleep(reconnectionPolicy.GetInterval(i)) } if err != nil { return err } if pool.keyspace != "" { // set the keyspace if err = conn.UseKeyspace(pool.keyspace); err != nil { conn.Close() return err } } // add the Conn to the pool pool.mu.Lock() defer pool.mu.Unlock() if pool.closed { conn.Close() return nil } pool.conns = append(pool.conns, conn) return nil } // handle any error from a Conn func (pool *hostConnPool) HandleError(conn *Conn, err error, closed bool) { if !closed { // still an open connection, so continue using it return } // TODO: track the number of errors per host and detect when a host is dead, // then also have something which can detect when a host comes back. 
pool.mu.Lock() defer pool.mu.Unlock() if pool.closed { // pool closed return } if gocqlDebug { pool.logger.Printf("gocql: pool connection error %q: %v\n", conn.addr, err) } // find the connection index for i, candidate := range pool.conns { if candidate == conn { // remove the connection, not preserving order pool.conns[i], pool.conns = pool.conns[len(pool.conns)-1], pool.conns[:len(pool.conns)-1] // lost a connection, so fill the pool go pool.fill() break } } } cassandra-gocql-driver-1.7.0/connectionpool_test.go000066400000000000000000000057211467504044300225050ustar00rootroot00000000000000//go:build all || unit // +build all unit /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. 
*/ package gocql import ( "crypto/tls" "testing" ) func TestSetupTLSConfig(t *testing.T) { tests := []struct { name string opts *SslOptions expectedInsecureSkipVerify bool }{ { name: "Config nil, EnableHostVerification false", opts: &SslOptions{ EnableHostVerification: false, }, expectedInsecureSkipVerify: true, }, { name: "Config nil, EnableHostVerification true", opts: &SslOptions{ EnableHostVerification: true, }, expectedInsecureSkipVerify: false, }, { name: "Config.InsecureSkipVerify false, EnableHostVerification false", opts: &SslOptions{ EnableHostVerification: false, Config: &tls.Config{ InsecureSkipVerify: false, }, }, expectedInsecureSkipVerify: false, }, { name: "Config.InsecureSkipVerify true, EnableHostVerification false", opts: &SslOptions{ EnableHostVerification: false, Config: &tls.Config{ InsecureSkipVerify: true, }, }, expectedInsecureSkipVerify: true, }, { name: "Config.InsecureSkipVerify false, EnableHostVerification true", opts: &SslOptions{ EnableHostVerification: true, Config: &tls.Config{ InsecureSkipVerify: false, }, }, expectedInsecureSkipVerify: false, }, { name: "Config.InsecureSkipVerify true, EnableHostVerification true", opts: &SslOptions{ EnableHostVerification: true, Config: &tls.Config{ InsecureSkipVerify: true, }, }, expectedInsecureSkipVerify: false, }, } for _, test := range tests { test := test t.Run(test.name, func(t *testing.T) { tlsConfig, err := setupTLSConfig(test.opts) if err != nil { t.Fatalf("unexpected error %q", err.Error()) } if tlsConfig.InsecureSkipVerify != test.expectedInsecureSkipVerify { t.Fatalf("got %v, but expected %v", tlsConfig.InsecureSkipVerify, test.expectedInsecureSkipVerify) } }) } } cassandra-gocql-driver-1.7.0/control.go000066400000000000000000000315301467504044300200720ustar00rootroot00000000000000/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. 
See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. */ package gocql import ( "context" crand "crypto/rand" "errors" "fmt" "math/rand" "net" "os" "regexp" "strconv" "sync" "sync/atomic" "time" ) var ( randr *rand.Rand mutRandr sync.Mutex ) func init() { b := make([]byte, 4) if _, err := crand.Read(b); err != nil { panic(fmt.Sprintf("unable to seed random number generator: %v", err)) } randr = rand.New(rand.NewSource(int64(readInt(b)))) } const ( controlConnStarting = 0 controlConnStarted = 1 controlConnClosing = -1 ) // Ensure that the atomic variable is aligned to a 64bit boundary // so that atomic operations can be applied on 32bit architectures. 
type controlConn struct { state int32 reconnecting int32 session *Session conn atomic.Value retry RetryPolicy quit chan struct{} } func createControlConn(session *Session) *controlConn { control := &controlConn{ session: session, quit: make(chan struct{}), retry: &SimpleRetryPolicy{NumRetries: 3}, } control.conn.Store((*connHost)(nil)) return control } func (c *controlConn) heartBeat() { if !atomic.CompareAndSwapInt32(&c.state, controlConnStarting, controlConnStarted) { return } sleepTime := 1 * time.Second timer := time.NewTimer(sleepTime) defer timer.Stop() for { timer.Reset(sleepTime) select { case <-c.quit: return case <-timer.C: } resp, err := c.writeFrame(&writeOptionsFrame{}) if err != nil { goto reconn } switch resp.(type) { case *supportedFrame: // Everything ok sleepTime = 5 * time.Second continue case error: goto reconn default: panic(fmt.Sprintf("gocql: unknown frame in response to options: %T", resp)) } reconn: // try to connect a bit faster sleepTime = 1 * time.Second c.reconnect() continue } } var hostLookupPreferV4 = os.Getenv("GOCQL_HOST_LOOKUP_PREFER_V4") == "true" func hostInfo(addr string, defaultPort int) ([]*HostInfo, error) { var port int host, portStr, err := net.SplitHostPort(addr) if err != nil { host = addr port = defaultPort } else { port, err = strconv.Atoi(portStr) if err != nil { return nil, err } } var hosts []*HostInfo // Check if host is a literal IP address if ip := net.ParseIP(host); ip != nil { hosts = append(hosts, &HostInfo{hostname: host, connectAddress: ip, port: port}) return hosts, nil } // Look up host in DNS ips, err := LookupIP(host) if err != nil { return nil, err } else if len(ips) == 0 { return nil, fmt.Errorf("no IP's returned from DNS lookup for %q", addr) } // Filter to v4 addresses if any present if hostLookupPreferV4 { var preferredIPs []net.IP for _, v := range ips { if v4 := v.To4(); v4 != nil { preferredIPs = append(preferredIPs, v4) } } if len(preferredIPs) != 0 { ips = preferredIPs } } for _, ip := range 
ips { hosts = append(hosts, &HostInfo{hostname: host, connectAddress: ip, port: port}) } return hosts, nil } func shuffleHosts(hosts []*HostInfo) []*HostInfo { shuffled := make([]*HostInfo, len(hosts)) copy(shuffled, hosts) mutRandr.Lock() randr.Shuffle(len(hosts), func(i, j int) { shuffled[i], shuffled[j] = shuffled[j], shuffled[i] }) mutRandr.Unlock() return shuffled } // this is going to be version dependant and a nightmare to maintain :( var protocolSupportRe = regexp.MustCompile(`the lowest supported version is \d+ and the greatest is (\d+)$`) func parseProtocolFromError(err error) int { // I really wish this had the actual info in the error frame... matches := protocolSupportRe.FindAllStringSubmatch(err.Error(), -1) if len(matches) != 1 || len(matches[0]) != 2 { if verr, ok := err.(*protocolError); ok { return int(verr.frame.Header().version.version()) } return 0 } max, err := strconv.Atoi(matches[0][1]) if err != nil { return 0 } return max } func (c *controlConn) discoverProtocol(hosts []*HostInfo) (int, error) { hosts = shuffleHosts(hosts) connCfg := *c.session.connCfg connCfg.ProtoVersion = 4 // TODO: define maxProtocol handler := connErrorHandlerFn(func(c *Conn, err error, closed bool) { // we should never get here, but if we do it means we connected to a // host successfully which means our attempted protocol version worked if !closed { c.Close() } }) var err error for _, host := range hosts { var conn *Conn conn, err = c.session.dial(c.session.ctx, host, &connCfg, handler) if conn != nil { conn.Close() } if err == nil { return connCfg.ProtoVersion, nil } if proto := parseProtocolFromError(err); proto > 0 { return proto, nil } } return 0, err } func (c *controlConn) connect(hosts []*HostInfo) error { if len(hosts) == 0 { return errors.New("control: no endpoints specified") } // shuffle endpoints so not all drivers will connect to the same initial // node. 
hosts = shuffleHosts(hosts) cfg := *c.session.connCfg cfg.disableCoalesce = true var conn *Conn var err error for _, host := range hosts { conn, err = c.session.dial(c.session.ctx, host, &cfg, c) if err != nil { c.session.logger.Printf("gocql: unable to dial control conn %v:%v: %v\n", host.ConnectAddress(), host.Port(), err) continue } err = c.setupConn(conn) if err == nil { break } c.session.logger.Printf("gocql: unable setup control conn %v:%v: %v\n", host.ConnectAddress(), host.Port(), err) conn.Close() conn = nil } if conn == nil { return fmt.Errorf("unable to connect to initial hosts: %v", err) } // we could fetch the initial ring here and update initial host data. So that // when we return from here we have a ring topology ready to go. go c.heartBeat() return nil } type connHost struct { conn *Conn host *HostInfo } func (c *controlConn) setupConn(conn *Conn) error { // we need up-to-date host info for the filterHost call below iter := conn.querySystemLocal(context.TODO()) host, err := c.session.hostInfoFromIter(iter, conn.host.connectAddress, conn.conn.RemoteAddr().(*net.TCPAddr).Port) if err != nil { return err } host = c.session.ring.addOrUpdate(host) if c.session.cfg.filterHost(host) { return fmt.Errorf("host was filtered: %v", host.ConnectAddress()) } if err := c.registerEvents(conn); err != nil { return fmt.Errorf("register events: %v", err) } ch := &connHost{ conn: conn, host: host, } c.conn.Store(ch) if c.session.initialized() { // We connected to control conn, so add the connect the host in pool as well. // Notify session we can start trying to connect to the node. // We can't start the fill before the session is initialized, otherwise the fill would interfere // with the fill called by Session.init. Session.init needs to wait for its fill to finish and that // would return immediately if we started the fill here. // TODO(martin-sucha): Trigger pool refill for all hosts, like in reconnectDownedHosts? 
go c.session.startPoolFill(host) } return nil } func (c *controlConn) registerEvents(conn *Conn) error { var events []string if !c.session.cfg.Events.DisableTopologyEvents { events = append(events, "TOPOLOGY_CHANGE") } if !c.session.cfg.Events.DisableNodeStatusEvents { events = append(events, "STATUS_CHANGE") } if !c.session.cfg.Events.DisableSchemaEvents { events = append(events, "SCHEMA_CHANGE") } if len(events) == 0 { return nil } framer, err := conn.exec(context.Background(), &writeRegisterFrame{ events: events, }, nil) if err != nil { return err } frame, err := framer.parseFrame() if err != nil { return err } else if _, ok := frame.(*readyFrame); !ok { return fmt.Errorf("unexpected frame in response to register: got %T: %v\n", frame, frame) } return nil } func (c *controlConn) reconnect() { if atomic.LoadInt32(&c.state) == controlConnClosing { return } if !atomic.CompareAndSwapInt32(&c.reconnecting, 0, 1) { return } defer atomic.StoreInt32(&c.reconnecting, 0) conn, err := c.attemptReconnect() if conn == nil { c.session.logger.Printf("gocql: unable to reconnect control connection: %v\n", err) return } err = c.session.refreshRing() if err != nil { c.session.logger.Printf("gocql: unable to refresh ring: %v\n", err) } } func (c *controlConn) attemptReconnect() (*Conn, error) { hosts := c.session.ring.allHosts() hosts = shuffleHosts(hosts) // keep the old behavior of connecting to the old host first by moving it to // the front of the slice ch := c.getConn() if ch != nil { for i := range hosts { if hosts[i].Equal(ch.host) { hosts[0], hosts[i] = hosts[i], hosts[0] break } } ch.conn.Close() } conn, err := c.attemptReconnectToAnyOfHosts(hosts) if conn != nil { return conn, err } c.session.logger.Printf("gocql: unable to connect to any ring node: %v\n", err) c.session.logger.Printf("gocql: control falling back to initial contact points.\n") // Fallback to initial contact points, as it may be the case that all known initialHosts // changed their IPs while keeping the 
same hostname(s). initialHosts, resolvErr := addrsToHosts(c.session.cfg.Hosts, c.session.cfg.Port, c.session.logger) if resolvErr != nil { return nil, fmt.Errorf("resolve contact points' hostnames: %v", resolvErr) } return c.attemptReconnectToAnyOfHosts(initialHosts) } func (c *controlConn) attemptReconnectToAnyOfHosts(hosts []*HostInfo) (*Conn, error) { var conn *Conn var err error for _, host := range hosts { conn, err = c.session.connect(c.session.ctx, host, c) if err != nil { c.session.logger.Printf("gocql: unable to dial control conn %v:%v: %v\n", host.ConnectAddress(), host.Port(), err) continue } err = c.setupConn(conn) if err == nil { break } c.session.logger.Printf("gocql: unable setup control conn %v:%v: %v\n", host.ConnectAddress(), host.Port(), err) conn.Close() conn = nil } return conn, err } func (c *controlConn) HandleError(conn *Conn, err error, closed bool) { if !closed { return } oldConn := c.getConn() // If connection has long gone, and not been attempted for awhile, // it's possible to have oldConn as nil here (#1297). 
if oldConn != nil && oldConn.conn != conn { return } c.reconnect() } func (c *controlConn) getConn() *connHost { return c.conn.Load().(*connHost) } func (c *controlConn) writeFrame(w frameBuilder) (frame, error) { ch := c.getConn() if ch == nil { return nil, errNoControl } framer, err := ch.conn.exec(context.Background(), w, nil) if err != nil { return nil, err } return framer.parseFrame() } func (c *controlConn) withConnHost(fn func(*connHost) *Iter) *Iter { const maxConnectAttempts = 5 connectAttempts := 0 for i := 0; i < maxConnectAttempts; i++ { ch := c.getConn() if ch == nil { if connectAttempts > maxConnectAttempts { break } connectAttempts++ c.reconnect() continue } return fn(ch) } return &Iter{err: errNoControl} } func (c *controlConn) withConn(fn func(*Conn) *Iter) *Iter { return c.withConnHost(func(ch *connHost) *Iter { return fn(ch.conn) }) } // query will return nil if the connection is closed or nil func (c *controlConn) query(statement string, values ...interface{}) (iter *Iter) { q := c.session.Query(statement, values...).Consistency(One).RoutingKey([]byte{}).Trace(nil) for { iter = c.withConn(func(conn *Conn) *Iter { // we want to keep the query on the control connection q.conn = conn return conn.executeQuery(context.TODO(), q) }) if gocqlDebug && iter.err != nil { c.session.logger.Printf("control: error executing %q: %v\n", statement, iter.err) } q.AddAttempts(1, c.getConn().host) if iter.err == nil || !c.retry.Attempt(q) { break } } return } func (c *controlConn) awaitSchemaAgreement() error { return c.withConn(func(conn *Conn) *Iter { return &Iter{err: conn.awaitSchemaAgreement(context.TODO())} }).err } func (c *controlConn) close() { if atomic.CompareAndSwapInt32(&c.state, controlConnStarted, controlConnClosing) { c.quit <- struct{}{} } ch := c.getConn() if ch != nil { ch.conn.Close() } } var errNoControl = errors.New("gocql: no control connection available") 
cassandra-gocql-driver-1.7.0/control_ccm_test.go000066400000000000000000000116061467504044300217550ustar00rootroot00000000000000//go:build ccm // +build ccm /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. 
*/ package gocql import ( "fmt" "sync" "testing" "time" "github.com/gocql/gocql/internal/ccm" ) type TestHostFilter struct { mu sync.Mutex allowedHosts map[string]ccm.Host } func (f *TestHostFilter) Accept(h *HostInfo) bool { f.mu.Lock() defer f.mu.Unlock() _, ok := f.allowedHosts[h.ConnectAddress().String()] return ok } func (f *TestHostFilter) SetAllowedHosts(hosts map[string]ccm.Host) { f.mu.Lock() defer f.mu.Unlock() f.allowedHosts = hosts } func TestControlConn_ReconnectRefreshesRing(t *testing.T) { if err := ccm.AllUp(); err != nil { t.Fatal(err) } allCcmHosts, err := ccm.Status() if err != nil { t.Fatal(err) } if len(allCcmHosts) < 2 { t.Skip("this test requires at least 2 nodes") } allAllowedHosts := map[string]ccm.Host{} var firstNode *ccm.Host for _, node := range allCcmHosts { if firstNode == nil { firstNode = &node } allAllowedHosts[node.Addr] = node } allowedHosts := map[string]ccm.Host{ firstNode.Addr: *firstNode, } testFilter := &TestHostFilter{allowedHosts: allowedHosts} session := createSession(t, func(config *ClusterConfig) { config.Hosts = []string{firstNode.Addr} config.Events.DisableTopologyEvents = true config.Events.DisableNodeStatusEvents = true config.HostFilter = testFilter }) defer session.Close() if session.control == nil || session.control.conn.Load() == nil { t.Fatal("control conn is nil") } controlConnection := session.control.getConn() ccHost := controlConnection.host var ccHostName string for _, node := range allCcmHosts { if node.Addr == ccHost.ConnectAddress().String() { ccHostName = node.Name break } } if ccHostName == "" { t.Fatal("could not find name of control host") } if err := ccm.NodeDown(ccHostName); err != nil { t.Fatal() } defer func() { ccmStatus, err := ccm.Status() if err != nil { t.Logf("could not bring nodes back up after test: %v", err) return } for _, node := range ccmStatus { if node.State == ccm.NodeStateDown { err = ccm.NodeUp(node.Name) if err != nil { t.Logf("could not bring node %v back up after test: %v", 
node.Name, err) } } } }() assertNodeDown := func() error { hosts := session.ring.currentHosts() if len(hosts) != 1 { return fmt.Errorf("expected 1 host in ring but there were %v", len(hosts)) } for _, host := range hosts { if host.IsUp() { return fmt.Errorf("expected host to be DOWN but %v isn't", host.String()) } } session.pool.mu.RLock() poolsLen := len(session.pool.hostConnPools) session.pool.mu.RUnlock() if poolsLen != 0 { return fmt.Errorf("expected 0 connection pool but there were %v", poolsLen) } return nil } maxAttempts := 5 delayPerAttempt := 1 * time.Second assertErr := assertNodeDown() for i := 0; i < maxAttempts && assertErr != nil; i++ { time.Sleep(delayPerAttempt) assertErr = assertNodeDown() } if assertErr != nil { t.Fatal(err) } testFilter.SetAllowedHosts(allAllowedHosts) if err = ccm.NodeUp(ccHostName); err != nil { t.Fatal(err) } assertNodeUp := func() error { hosts := session.ring.currentHosts() if len(hosts) != len(allCcmHosts) { return fmt.Errorf("expected %v hosts in ring but there were %v", len(allCcmHosts), len(hosts)) } for _, host := range hosts { if !host.IsUp() { return fmt.Errorf("expected all hosts to be UP but %v isn't", host.String()) } } session.pool.mu.RLock() poolsLen := len(session.pool.hostConnPools) session.pool.mu.RUnlock() if poolsLen != len(allCcmHosts) { return fmt.Errorf("expected %v connection pool but there were %v", len(allCcmHosts), poolsLen) } return nil } maxAttempts = 30 delayPerAttempt = 1 * time.Second assertErr = assertNodeUp() for i := 0; i < maxAttempts && assertErr != nil; i++ { time.Sleep(delayPerAttempt) assertErr = assertNodeUp() } if assertErr != nil { t.Fatal(err) } } cassandra-gocql-driver-1.7.0/control_test.go000066400000000000000000000045361467504044300211370ustar00rootroot00000000000000/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. 
The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. */ package gocql import ( "net" "testing" ) func TestHostInfo_Lookup(t *testing.T) { hostLookupPreferV4 = true defer func() { hostLookupPreferV4 = false }() tests := [...]struct { addr string ip net.IP }{ {"127.0.0.1", net.IPv4(127, 0, 0, 1)}, {"localhost", net.IPv4(127, 0, 0, 1)}, // TODO: this may be host dependant } for i, test := range tests { hosts, err := hostInfo(test.addr, 1) if err != nil { t.Errorf("%d: %v", i, err) continue } host := hosts[0] if !host.ConnectAddress().Equal(test.ip) { t.Errorf("expected ip %v got %v for addr %q", test.ip, host.ConnectAddress(), test.addr) } } } func TestParseProtocol(t *testing.T) { tests := [...]struct { err error proto int }{ { err: &protocolError{ frame: errorFrame{ code: 0x10, message: "Invalid or unsupported protocol version (5); the lowest supported version is 3 and the greatest is 4", }, }, proto: 4, }, { err: &protocolError{ frame: errorFrame{ frameHeader: frameHeader{ version: 0x83, }, code: 0x10, message: "Invalid or unsupported protocol version: 5", }, }, proto: 3, }, } for i, test := range tests { if proto := parseProtocolFromError(test.err); proto != test.proto { t.Errorf("%d: exepcted proto %d got %d", i, test.proto, proto) 
} } } cassandra-gocql-driver-1.7.0/cqltypes.go000066400000000000000000000021661467504044300202610ustar00rootroot00000000000000/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2012, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. */ package gocql type Duration struct { Months int32 Days int32 Nanoseconds int64 } cassandra-gocql-driver-1.7.0/debug_off.go000066400000000000000000000021551467504044300203330ustar00rootroot00000000000000//go:build !gocql_debug // +build !gocql_debug /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. */ package gocql const gocqlDebug = false cassandra-gocql-driver-1.7.0/debug_on.go000066400000000000000000000021521467504044300201720ustar00rootroot00000000000000//go:build gocql_debug // +build gocql_debug /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. 
*/ package gocql const gocqlDebug = true cassandra-gocql-driver-1.7.0/dial.go000066400000000000000000000077131467504044300173310ustar00rootroot00000000000000/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. */ package gocql import ( "context" "crypto/tls" "fmt" "net" "strings" ) // HostDialer allows customizing connection to cluster nodes. type HostDialer interface { // DialHost establishes a connection to the host. // The returned connection must be directly usable for CQL protocol, // specifically DialHost is responsible also for setting up the TLS session if needed. // DialHost should disable write coalescing if the returned net.Conn does not support writev. // As of Go 1.18, only plain TCP connections support writev, TLS sessions should disable coalescing. // You can use WrapTLS helper function if you don't need to override the TLS setup. DialHost(ctx context.Context, host *HostInfo) (*DialedHost, error) } // DialedHost contains information about established connection to a host. 
type DialedHost struct { // Conn used to communicate with the server. Conn net.Conn // DisableCoalesce disables write coalescing for the Conn. // If true, the effect is the same as if WriteCoalesceWaitTime was configured to 0. DisableCoalesce bool } // defaultHostDialer dials host in a default way. type defaultHostDialer struct { dialer Dialer tlsConfig *tls.Config } func (hd *defaultHostDialer) DialHost(ctx context.Context, host *HostInfo) (*DialedHost, error) { ip := host.ConnectAddress() port := host.Port() if !validIpAddr(ip) { return nil, fmt.Errorf("host missing connect ip address: %v", ip) } else if port == 0 { return nil, fmt.Errorf("host missing port: %v", port) } connAddr := host.ConnectAddressAndPort() conn, err := hd.dialer.DialContext(ctx, "tcp", connAddr) if err != nil { return nil, err } addr := host.HostnameAndPort() return WrapTLS(ctx, conn, addr, hd.tlsConfig) } func tlsConfigForAddr(tlsConfig *tls.Config, addr string) *tls.Config { // the TLS config is safe to be reused by connections but it must not // be modified after being used. if !tlsConfig.InsecureSkipVerify && tlsConfig.ServerName == "" { colonPos := strings.LastIndex(addr, ":") if colonPos == -1 { colonPos = len(addr) } hostname := addr[:colonPos] // clone config to avoid modifying the shared one. tlsConfig = tlsConfig.Clone() tlsConfig.ServerName = hostname } return tlsConfig } // WrapTLS optionally wraps a net.Conn connected to addr with the given tlsConfig. // If the tlsConfig is nil, conn is not wrapped into a TLS session, so is insecure. // If the tlsConfig does not have server name set, it is updated based on the default gocql rules. 
func WrapTLS(ctx context.Context, conn net.Conn, addr string, tlsConfig *tls.Config) (*DialedHost, error) { if tlsConfig != nil { tlsConfig := tlsConfigForAddr(tlsConfig, addr) tconn := tls.Client(conn, tlsConfig) if err := tconn.HandshakeContext(ctx); err != nil { conn.Close() return nil, err } conn = tconn } return &DialedHost{ Conn: conn, DisableCoalesce: tlsConfig != nil, // write coalescing can't use writev when the connection is wrapped. }, nil } cassandra-gocql-driver-1.7.0/doc.go000066400000000000000000000430151467504044300171600ustar00rootroot00000000000000/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. */ // Package gocql implements a fast and robust Cassandra driver for the // Go programming language. 
// // # Connecting to the cluster // // Pass a list of initial node IP addresses to NewCluster to create a new cluster configuration: // // cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2", "192.168.1.3") // // Port can be specified as part of the address, the above is equivalent to: // // cluster := gocql.NewCluster("192.168.1.1:9042", "192.168.1.2:9042", "192.168.1.3:9042") // // It is recommended to use the value set in the Cassandra config for broadcast_address or listen_address, // an IP address not a domain name. This is because events from Cassandra will use the configured IP // address, which is used to index connected hosts. If the domain name specified resolves to more than 1 IP address // then the driver may connect multiple times to the same host, and will not mark the node being down or up from events. // // Then you can customize more options (see ClusterConfig): // // cluster.Keyspace = "example" // cluster.Consistency = gocql.Quorum // cluster.ProtoVersion = 4 // // The driver tries to automatically detect the protocol version to use if not set, but you might want to set the // protocol version explicitly, as it's not defined which version will be used in certain situations (for example // during upgrade of the cluster when some of the nodes support different set of protocol versions than other nodes). // // The driver advertises the module name and version in the STARTUP message, so servers are able to detect the version. // If you use replace directive in go.mod, the driver will send information about the replacement module instead. // // When ready, create a session from the configuration. Don't forget to Close the session once you are done with it: // // session, err := cluster.CreateSession() // if err != nil { // return err // } // defer session.Close() // // # Authentication // // CQL protocol uses a SASL-based authentication mechanism and so consists of an exchange of server challenges and // client response pairs. 
The details of the exchanged messages depend on the authenticator used. // // To use authentication, set ClusterConfig.Authenticator or ClusterConfig.AuthProvider. // // PasswordAuthenticator is provided to use for username/password authentication: // // cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2", "192.168.1.3") // cluster.Authenticator = gocql.PasswordAuthenticator{ // Username: "user", // Password: "password" // } // session, err := cluster.CreateSession() // if err != nil { // return err // } // defer session.Close() // // # Transport layer security // // It is possible to secure traffic between the client and server with TLS. // // To use TLS, set the ClusterConfig.SslOpts field. SslOptions embeds *tls.Config so you can set that directly. // There are also helpers to load keys/certificates from files. // // Warning: Due to historical reasons, the SslOptions is insecure by default, so you need to set EnableHostVerification // to true if no Config is set. Most users should set SslOptions.Config to a *tls.Config. // SslOptions and Config.InsecureSkipVerify interact as follows: // // Config.InsecureSkipVerify | EnableHostVerification | Result // Config is nil | false | do not verify host // Config is nil | true | verify host // false | false | verify host // true | false | do not verify host // false | true | verify host // true | true | verify host // // For example: // // cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2", "192.168.1.3") // cluster.SslOpts = &gocql.SslOptions{ // EnableHostVerification: true, // } // session, err := cluster.CreateSession() // if err != nil { // return err // } // defer session.Close() // // # Data-center awareness and query routing // // To route queries to local DC first, use DCAwareRoundRobinPolicy. 
For example, if the datacenter you // want to primarily connect is called dc1 (as configured in the database): // // cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2", "192.168.1.3") // cluster.PoolConfig.HostSelectionPolicy = gocql.DCAwareRoundRobinPolicy("dc1") // // The driver can route queries to nodes that hold data replicas based on partition key (preferring local DC). // // cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2", "192.168.1.3") // cluster.PoolConfig.HostSelectionPolicy = gocql.TokenAwareHostPolicy(gocql.DCAwareRoundRobinPolicy("dc1")) // // Note that TokenAwareHostPolicy can take options such as gocql.ShuffleReplicas and gocql.NonLocalReplicasFallback. // // We recommend running with a token aware host policy in production for maximum performance. // // The driver can only use token-aware routing for queries where all partition key columns are query parameters. // For example, instead of // // session.Query("select value from mytable where pk1 = 'abc' AND pk2 = ?", "def") // // use // // session.Query("select value from mytable where pk1 = ? AND pk2 = ?", "abc", "def") // // # Rack-level awareness // // The DCAwareRoundRobinPolicy can be replaced with RackAwareRoundRobinPolicy, which takes two parameters, datacenter and rack. // // Instead of dividing hosts with two tiers (local datacenter and remote datacenters) it divides hosts into three // (the local rack, the rest of the local datacenter, and everything else). // // RackAwareRoundRobinPolicy can be combined with TokenAwareHostPolicy in the same way as DCAwareRoundRobinPolicy. // // # Executing queries // // Create queries with Session.Query. Query values must not be reused between different executions and must not be // modified after starting execution of the query. 
// // To execute a query without reading results, use Query.Exec: // // err := session.Query(`INSERT INTO tweet (timeline, id, text) VALUES (?, ?, ?)`, // "me", gocql.TimeUUID(), "hello world").WithContext(ctx).Exec() // // Single row can be read by calling Query.Scan: // // err := session.Query(`SELECT id, text FROM tweet WHERE timeline = ? LIMIT 1`, // "me").WithContext(ctx).Consistency(gocql.One).Scan(&id, &text) // // Multiple rows can be read using Iter.Scanner: // // scanner := session.Query(`SELECT id, text FROM tweet WHERE timeline = ?`, // "me").WithContext(ctx).Iter().Scanner() // for scanner.Next() { // var ( // id gocql.UUID // text string // ) // err = scanner.Scan(&id, &text) // if err != nil { // log.Fatal(err) // } // fmt.Println("Tweet:", id, text) // } // // scanner.Err() closes the iterator, so scanner nor iter should be used afterwards. // if err := scanner.Err(); err != nil { // log.Fatal(err) // } // // See Example for complete example. // // # Prepared statements // // The driver automatically prepares DML queries (SELECT/INSERT/UPDATE/DELETE/BATCH statements) and maintains a cache // of prepared statements. // CQL protocol does not support preparing other query types. // // When using CQL protocol >= 4, it is possible to use gocql.UnsetValue as the bound value of a column. // This will cause the database to ignore writing the column. // The main advantage is the ability to keep the same prepared statement even when you don't // want to update some fields, where before you needed to make another prepared statement. // // # Executing multiple queries concurrently // // Session is safe to use from multiple goroutines, so to execute multiple concurrent queries, just execute them // from several worker goroutines. Gocql provides synchronously-looking API (as recommended for Go APIs) and the queries // are executed asynchronously at the protocol level. 
// // results := make(chan error, 2) // go func() { // results <- session.Query(`INSERT INTO tweet (timeline, id, text) VALUES (?, ?, ?)`, // "me", gocql.TimeUUID(), "hello world 1").Exec() // }() // go func() { // results <- session.Query(`INSERT INTO tweet (timeline, id, text) VALUES (?, ?, ?)`, // "me", gocql.TimeUUID(), "hello world 2").Exec() // }() // // # Nulls // // Null values are are unmarshalled as zero value of the type. If you need to distinguish for example between text // column being null and empty string, you can unmarshal into *string variable instead of string. // // var text *string // err := scanner.Scan(&text) // if err != nil { // // handle error // } // if text != nil { // // not null // } // else { // // null // } // // See Example_nulls for full example. // // # Reusing slices // // The driver reuses backing memory of slices when unmarshalling. This is an optimization so that a buffer does not // need to be allocated for every processed row. However, you need to be careful when storing the slices to other // memory structures. // // scanner := session.Query(`SELECT myints FROM table WHERE pk = ?`, "key").WithContext(ctx).Iter().Scanner() // var myInts []int // for scanner.Next() { // // This scan reuses backing store of myInts for each row. // err = scanner.Scan(&myInts) // if err != nil { // log.Fatal(err) // } // } // // When you want to save the data for later use, pass a new slice every time. A common pattern is to declare the // slice variable within the scanner loop: // // scanner := session.Query(`SELECT myints FROM table WHERE pk = ?`, "key").WithContext(ctx).Iter().Scanner() // for scanner.Next() { // var myInts []int // // This scan always gets pointer to fresh myInts slice, so does not reuse memory. 
// err = scanner.Scan(&myInts) // if err != nil { // log.Fatal(err) // } // } // // # Paging // // The driver supports paging of results with automatic prefetch, see ClusterConfig.PageSize, Session.SetPrefetch, // Query.PageSize, and Query.Prefetch. // // It is also possible to control the paging manually with Query.PageState (this disables automatic prefetch). // Manual paging is useful if you want to store the page state externally, for example in a URL to allow users // browse pages in a result. You might want to sign/encrypt the paging state when exposing it externally since // it contains data from primary keys. // // Paging state is specific to the CQL protocol version and the exact query used. It is meant as opaque state that // should not be modified. If you send paging state from different query or protocol version, then the behaviour // is not defined (you might get unexpected results or an error from the server). For example, do not send paging state // returned by node using protocol version 3 to a node using protocol version 4. Also, when using protocol version 4, // paging state between Cassandra 2.2 and 3.0 is incompatible (https://issues.apache.org/jira/browse/CASSANDRA-10880). // // The driver does not check whether the paging state is from the same protocol version/statement. // You might want to validate yourself as this could be a problem if you store paging state externally. // For example, if you store paging state in a URL, the URLs might become broken when you upgrade your cluster. // // Call Query.PageState(nil) to fetch just the first page of the query results. Pass the page state returned by // Iter.PageState to Query.PageState of a subsequent query to get the next page. If the length of slice returned // by Iter.PageState is zero, there are no more pages available (or an error occurred). // // Using too low values of PageSize will negatively affect performance, a value below 100 is probably too low. 
// While Cassandra returns exactly PageSize items (except for last page) in a page currently, the protocol authors // explicitly reserved the right to return smaller or larger amount of items in a page for performance reasons, so don't // rely on the page having the exact count of items. // // See Example_paging for an example of manual paging. // // # Dynamic list of columns // // There are certain situations when you don't know the list of columns in advance, mainly when the query is supplied // by the user. Iter.Columns, Iter.RowData, Iter.MapScan and Iter.SliceMap can be used to handle this case. // // See Example_dynamicColumns. // // # Batches // // The CQL protocol supports sending batches of DML statements (INSERT/UPDATE/DELETE) and so does gocql. // Use Session.NewBatch to create a new batch and then fill-in details of individual queries. // Then execute the batch with Session.ExecuteBatch. // // Logged batches ensure atomicity, either all or none of the operations in the batch will succeed, but they have // overhead to ensure this property. // Unlogged batches don't have the overhead of logged batches, but don't guarantee atomicity. // Updates of counters are handled specially by Cassandra so batches of counter updates have to use CounterBatch type. // A counter batch can only contain statements to update counters. // // For unlogged batches it is recommended to send only single-partition batches (i.e. all statements in the batch should // involve only a single partition). // Multi-partition batch needs to be split by the coordinator node and re-sent to // correct nodes. // With single-partition batches you can send the batch directly to the node for the partition without incurring the // additional network hop. // // It is also possible to pass entire BEGIN BATCH .. APPLY BATCH statement to Query.Exec. // There are differences how those are executed. // BEGIN BATCH statement passed to Query.Exec is prepared as a whole in a single statement. 
// Session.ExecuteBatch prepares individual statements in the batch. // If you have variable-length batches using the same statement, using Session.ExecuteBatch is more efficient. // // See Example_batch for an example. // // # Lightweight transactions // // Query.ScanCAS or Query.MapScanCAS can be used to execute a single-statement lightweight transaction (an // INSERT/UPDATE .. IF statement) and reading its result. See example for Query.MapScanCAS. // // Multiple-statement lightweight transactions can be executed as a logged batch that contains at least one conditional // statement. All the conditions must return true for the batch to be applied. You can use Session.ExecuteBatchCAS and // Session.MapExecuteBatchCAS when executing the batch to learn about the result of the LWT. See example for // Session.MapExecuteBatchCAS. // // # Retries and speculative execution // // Queries can be marked as idempotent. Marking the query as idempotent tells the driver that the query can be executed // multiple times without affecting its result. Non-idempotent queries are not eligible for retrying nor speculative // execution. // // Idempotent queries are retried in case of errors based on the configured RetryPolicy. // // Queries can be retried even before they fail by setting a SpeculativeExecutionPolicy. The policy can // cause the driver to retry on a different node if the query is taking longer than a specified delay even before the // driver receives an error or timeout from the server. When a query is speculatively executed, the original execution // is still executing. The two parallel executions of the query race to return a result, the first received result will // be returned. // // # User-defined types // // UDTs can be mapped (un)marshaled from/to map[string]interface{} a Go struct (or a type implementing // UDTUnmarshaler, UDTMarshaler, Unmarshaler or Marshaler interfaces). 
// // For structs, cql tag can be used to specify the CQL field name to be mapped to a struct field: // // type MyUDT struct { // FieldA int32 `cql:"a"` // FieldB string `cql:"b"` // } // // See Example_userDefinedTypesMap, Example_userDefinedTypesStruct, ExampleUDTMarshaler, ExampleUDTUnmarshaler. // // # Metrics and tracing // // It is possible to provide observer implementations that could be used to gather metrics: // // - QueryObserver for monitoring individual queries. // - BatchObserver for monitoring batch queries. // - ConnectObserver for monitoring new connections from the driver to the database. // - FrameHeaderObserver for monitoring individual protocol frames. // // CQL protocol also supports tracing of queries. When enabled, the database will write information about // internal events that happened during execution of the query. You can use Query.Trace to request tracing and receive // the session ID that the database used to store the trace information in system_traces.sessions and // system_traces.events tables. NewTraceWriter returns an implementation of Tracer that writes the events to a writer. // Gathering trace information might be essential for debugging and optimizing queries, but writing traces has overhead, // so this feature should not be used on production systems with very high load unless you know what you are doing. package gocql // import "github.com/gocql/gocql" cassandra-gocql-driver-1.7.0/errors.go000066400000000000000000000161321467504044300177270ustar00rootroot00000000000000/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. */ package gocql import "fmt" // See CQL Binary Protocol v5, section 8 for more details. // https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec const ( // ErrCodeServer indicates unexpected error on server-side. // // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1246-L1247 ErrCodeServer = 0x0000 // ErrCodeProtocol indicates a protocol violation by some client message. // // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1248-L1250 ErrCodeProtocol = 0x000A // ErrCodeCredentials indicates missing required authentication. // // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1251-L1254 ErrCodeCredentials = 0x0100 // ErrCodeUnavailable indicates unavailable error. // // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1255-L1265 ErrCodeUnavailable = 0x1000 // ErrCodeOverloaded returned in case of request on overloaded node coordinator. // // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1266-L1267 ErrCodeOverloaded = 0x1001 // ErrCodeBootstrapping returned from the coordinator node in bootstrapping phase. 
// // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1268-L1269 ErrCodeBootstrapping = 0x1002 // ErrCodeTruncate indicates truncation exception. // // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1270 ErrCodeTruncate = 0x1003 // ErrCodeWriteTimeout returned in case of timeout during the request write. // // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1271-L1304 ErrCodeWriteTimeout = 0x1100 // ErrCodeReadTimeout returned in case of timeout during the request read. // // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1305-L1321 ErrCodeReadTimeout = 0x1200 // ErrCodeReadFailure indicates request read error which is not covered by ErrCodeReadTimeout. // // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1322-L1340 ErrCodeReadFailure = 0x1300 // ErrCodeFunctionFailure indicates an error in user-defined function. // // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1341-L1347 ErrCodeFunctionFailure = 0x1400 // ErrCodeWriteFailure indicates request write error which is not covered by ErrCodeWriteTimeout. // // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1348-L1385 ErrCodeWriteFailure = 0x1500 // ErrCodeCDCWriteFailure is defined, but not yet documented in CQLv5 protocol. // // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1386 ErrCodeCDCWriteFailure = 0x1600 // ErrCodeCASWriteUnknown indicates only partially completed CAS operation. // // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1387-L1397 ErrCodeCASWriteUnknown = 0x1700 // ErrCodeSyntax indicates the syntax error in the query. 
// // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1399 ErrCodeSyntax = 0x2000 // ErrCodeUnauthorized indicates access rights violation by user on performed operation. // // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1400-L1401 ErrCodeUnauthorized = 0x2100 // ErrCodeInvalid indicates invalid query error which is not covered by ErrCodeSyntax. // // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1402 ErrCodeInvalid = 0x2200 // ErrCodeConfig indicates the configuration error. // // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1403 ErrCodeConfig = 0x2300 // ErrCodeAlreadyExists is returned for the requests creating the existing keyspace/table. // // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1404-L1413 ErrCodeAlreadyExists = 0x2400 // ErrCodeUnprepared returned from the host for prepared statement which is unknown. 
// // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1414-L1417 ErrCodeUnprepared = 0x2500 ) type RequestError interface { Code() int Message() string Error() string } type errorFrame struct { frameHeader code int message string } func (e errorFrame) Code() int { return e.code } func (e errorFrame) Message() string { return e.message } func (e errorFrame) Error() string { return e.Message() } func (e errorFrame) String() string { return fmt.Sprintf("[error code=%x message=%q]", e.code, e.message) } type RequestErrUnavailable struct { errorFrame Consistency Consistency Required int Alive int } func (e *RequestErrUnavailable) String() string { return fmt.Sprintf("[request_error_unavailable consistency=%s required=%d alive=%d]", e.Consistency, e.Required, e.Alive) } type ErrorMap map[string]uint16 type RequestErrWriteTimeout struct { errorFrame Consistency Consistency Received int BlockFor int WriteType string } type RequestErrWriteFailure struct { errorFrame Consistency Consistency Received int BlockFor int NumFailures int WriteType string ErrorMap ErrorMap } type RequestErrCDCWriteFailure struct { errorFrame } type RequestErrReadTimeout struct { errorFrame Consistency Consistency Received int BlockFor int DataPresent byte } type RequestErrAlreadyExists struct { errorFrame Keyspace string Table string } type RequestErrUnprepared struct { errorFrame StatementId []byte } type RequestErrReadFailure struct { errorFrame Consistency Consistency Received int BlockFor int NumFailures int DataPresent bool ErrorMap ErrorMap } type RequestErrFunctionFailure struct { errorFrame Keyspace string Function string ArgTypes []string } // RequestErrCASWriteUnknown is distinct error for ErrCodeCasWriteUnknown. 
// // See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1387-L1397 type RequestErrCASWriteUnknown struct { errorFrame Consistency Consistency Received int BlockFor int } cassandra-gocql-driver-1.7.0/errors_test.go000066400000000000000000000034421467504044300207660ustar00rootroot00000000000000//go:build all || cassandra // +build all cassandra /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. 
*/ package gocql import ( "testing" ) func TestErrorsParse(t *testing.T) { session := createSession(t) defer session.Close() if err := createTable(session, `CREATE TABLE gocql_test.errors_parse (id int primary key)`); err != nil { t.Fatal("create:", err) } if err := createTable(session, `CREATE TABLE gocql_test.errors_parse (id int primary key)`); err == nil { t.Fatal("Should have gotten already exists error from cassandra server.") } else { switch e := err.(type) { case *RequestErrAlreadyExists: if e.Table != "errors_parse" { t.Fatalf("expected error table to be 'errors_parse' but was %q", e.Table) } default: t.Fatalf("expected to get RequestErrAlreadyExists instead got %T", e) } } } cassandra-gocql-driver-1.7.0/events.go000066400000000000000000000152611467504044300177210ustar00rootroot00000000000000/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. 
*/ package gocql import ( "net" "sync" "time" ) type eventDebouncer struct { name string timer *time.Timer mu sync.Mutex events []frame callback func([]frame) quit chan struct{} logger StdLogger } func newEventDebouncer(name string, eventHandler func([]frame), logger StdLogger) *eventDebouncer { e := &eventDebouncer{ name: name, quit: make(chan struct{}), timer: time.NewTimer(eventDebounceTime), callback: eventHandler, logger: logger, } e.timer.Stop() go e.flusher() return e } func (e *eventDebouncer) stop() { e.quit <- struct{}{} // sync with flusher close(e.quit) } func (e *eventDebouncer) flusher() { for { select { case <-e.timer.C: e.mu.Lock() e.flush() e.mu.Unlock() case <-e.quit: return } } } const ( eventBufferSize = 1000 eventDebounceTime = 1 * time.Second ) // flush must be called with mu locked func (e *eventDebouncer) flush() { if len(e.events) == 0 { return } // if the flush interval is faster than the callback then we will end up calling // the callback multiple times, probably a bad idea. In this case we could drop // frames? 
go e.callback(e.events) e.events = make([]frame, 0, eventBufferSize) } func (e *eventDebouncer) debounce(frame frame) { e.mu.Lock() e.timer.Reset(eventDebounceTime) // TODO: probably need a warning to track if this threshold is too low if len(e.events) < eventBufferSize { e.events = append(e.events, frame) } else { e.logger.Printf("%s: buffer full, dropping event frame: %s", e.name, frame) } e.mu.Unlock() } func (s *Session) handleEvent(framer *framer) { frame, err := framer.parseFrame() if err != nil { s.logger.Printf("gocql: unable to parse event frame: %v\n", err) return } if gocqlDebug { s.logger.Printf("gocql: handling frame: %v\n", frame) } switch f := frame.(type) { case *schemaChangeKeyspace, *schemaChangeFunction, *schemaChangeTable, *schemaChangeAggregate, *schemaChangeType: s.schemaEvents.debounce(frame) case *topologyChangeEventFrame, *statusChangeEventFrame: s.nodeEvents.debounce(frame) default: s.logger.Printf("gocql: invalid event frame (%T): %v\n", f, f) } } func (s *Session) handleSchemaEvent(frames []frame) { // TODO: debounce events for _, frame := range frames { switch f := frame.(type) { case *schemaChangeKeyspace: s.schemaDescriber.clearSchema(f.keyspace) s.handleKeyspaceChange(f.keyspace, f.change) case *schemaChangeTable: s.schemaDescriber.clearSchema(f.keyspace) case *schemaChangeAggregate: s.schemaDescriber.clearSchema(f.keyspace) case *schemaChangeFunction: s.schemaDescriber.clearSchema(f.keyspace) case *schemaChangeType: s.schemaDescriber.clearSchema(f.keyspace) } } } func (s *Session) handleKeyspaceChange(keyspace, change string) { s.control.awaitSchemaAgreement() s.policy.KeyspaceChanged(KeyspaceUpdateEvent{Keyspace: keyspace, Change: change}) } // handleNodeEvent handles inbound status and topology change events. // // Status events are debounced by host IP; only the latest event is processed. // // Topology events are debounced by performing a single full topology refresh // whenever any topology event comes in. 
// // Processing topology change events before status change events ensures // that a NEW_NODE event is not dropped in favor of a newer UP event (which // would itself be dropped/ignored, as the node is not yet known). func (s *Session) handleNodeEvent(frames []frame) { type nodeEvent struct { change string host net.IP port int } topologyEventReceived := false // status change events sEvents := make(map[string]*nodeEvent) for _, frame := range frames { switch f := frame.(type) { case *topologyChangeEventFrame: topologyEventReceived = true case *statusChangeEventFrame: event, ok := sEvents[f.host.String()] if !ok { event = &nodeEvent{change: f.change, host: f.host, port: f.port} sEvents[f.host.String()] = event } event.change = f.change } } if topologyEventReceived && !s.cfg.Events.DisableTopologyEvents { s.debounceRingRefresh() } for _, f := range sEvents { if gocqlDebug { s.logger.Printf("gocql: dispatching status change event: %+v\n", f) } // ignore events we received if they were disabled // see https://github.com/apache/cassandra-gocql-driver/issues/1591 switch f.change { case "UP": if !s.cfg.Events.DisableNodeStatusEvents { s.handleNodeUp(f.host, f.port) } case "DOWN": if !s.cfg.Events.DisableNodeStatusEvents { s.handleNodeDown(f.host, f.port) } } } } func (s *Session) handleNodeUp(eventIp net.IP, eventPort int) { if gocqlDebug { s.logger.Printf("gocql: Session.handleNodeUp: %s:%d\n", eventIp.String(), eventPort) } host, ok := s.ring.getHostByIP(eventIp.String()) if !ok { s.debounceRingRefresh() return } if s.cfg.filterHost(host) { return } if d := host.Version().nodeUpDelay(); d > 0 { time.Sleep(d) } s.startPoolFill(host) } func (s *Session) startPoolFill(host *HostInfo) { // we let the pool call handleNodeConnected to change the host state s.pool.addHost(host) s.policy.AddHost(host) } func (s *Session) handleNodeConnected(host *HostInfo) { if gocqlDebug { s.logger.Printf("gocql: Session.handleNodeConnected: %s:%d\n", host.ConnectAddress(), host.Port()) } 
host.setState(NodeUp) if !s.cfg.filterHost(host) { s.policy.HostUp(host) } } func (s *Session) handleNodeDown(ip net.IP, port int) { if gocqlDebug { s.logger.Printf("gocql: Session.handleNodeDown: %s:%d\n", ip.String(), port) } host, ok := s.ring.getHostByIP(ip.String()) if ok { host.setState(NodeDown) if s.cfg.filterHost(host) { return } s.policy.HostDown(host) hostID := host.HostID() s.pool.removeHost(hostID) } } cassandra-gocql-driver-1.7.0/events_ccm_test.go000066400000000000000000000165141467504044300216040ustar00rootroot00000000000000//go:build (ccm && ignore) || ignore // +build ccm,ignore ignore /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. 
*/ package gocql import ( "log" "testing" "time" "github.com/gocql/gocql/internal/ccm" ) func TestEventDiscovery(t *testing.T) { t.Skip("FLAKE skipping") if err := ccm.AllUp(); err != nil { t.Fatal(err) } session := createSession(t) defer session.Close() status, err := ccm.Status() if err != nil { t.Fatal(err) } t.Logf("status=%+v\n", status) session.pool.mu.RLock() poolHosts := session.pool.hostConnPools // TODO: replace with session.ring t.Logf("poolhosts=%+v\n", poolHosts) // check we discovered all the nodes in the ring for _, host := range status { if _, ok := poolHosts[host.Addr]; !ok { t.Errorf("did not discover %q", host.Addr) } } session.pool.mu.RUnlock() if t.Failed() { t.FailNow() } } func TestEventNodeDownControl(t *testing.T) { t.Skip("FLAKE skipping") const targetNode = "node1" if err := ccm.AllUp(); err != nil { t.Fatal(err) } status, err := ccm.Status() if err != nil { t.Fatal(err) } cluster := createCluster() cluster.Hosts = []string{status[targetNode].Addr} session := createSessionFromCluster(cluster, t) defer session.Close() t.Log("marking " + targetNode + " as down") if err := ccm.NodeDown(targetNode); err != nil { t.Fatal(err) } t.Logf("status=%+v\n", status) t.Logf("marking node %q down: %v\n", targetNode, status[targetNode]) time.Sleep(5 * time.Second) session.pool.mu.RLock() poolHosts := session.pool.hostConnPools node := status[targetNode] t.Logf("poolhosts=%+v\n", poolHosts) if _, ok := poolHosts[node.Addr]; ok { session.pool.mu.RUnlock() t.Fatal("node not removed after remove event") } session.pool.mu.RUnlock() host := session.ring.getHost(node.Addr) if host == nil { t.Fatal("node not in metadata ring") } else if host.IsUp() { t.Fatalf("not not marked as down after event in metadata: %v", host) } } func TestEventNodeDown(t *testing.T) { t.Skip("FLAKE skipping") const targetNode = "node3" if err := ccm.AllUp(); err != nil { t.Fatal(err) } session := createSession(t) defer session.Close() if err := ccm.NodeDown(targetNode); err != nil { 
t.Fatal(err) } status, err := ccm.Status() if err != nil { t.Fatal(err) } t.Logf("status=%+v\n", status) t.Logf("marking node %q down: %v\n", targetNode, status[targetNode]) time.Sleep(5 * time.Second) session.pool.mu.RLock() defer session.pool.mu.RUnlock() poolHosts := session.pool.hostConnPools node := status[targetNode] t.Logf("poolhosts=%+v\n", poolHosts) if _, ok := poolHosts[node.Addr]; ok { t.Fatal("node not removed after remove event") } host := session.ring.getHost(node.Addr) if host == nil { t.Fatal("node not in metadata ring") } else if host.IsUp() { t.Fatalf("not not marked as down after event in metadata: %v", host) } } func TestEventNodeUp(t *testing.T) { t.Skip("FLAKE skipping") if err := ccm.AllUp(); err != nil { t.Fatal(err) } status, err := ccm.Status() if err != nil { t.Fatal(err) } log.Printf("status=%+v\n", status) session := createSession(t) defer session.Close() const targetNode = "node2" node := status[targetNode] _, ok := session.pool.getPool(node.Addr) if !ok { session.pool.mu.RLock() t.Errorf("target pool not in connection pool: addr=%q pools=%v", status[targetNode].Addr, session.pool.hostConnPools) session.pool.mu.RUnlock() t.FailNow() } if err := ccm.NodeDown(targetNode); err != nil { t.Fatal(err) } time.Sleep(5 * time.Second) _, ok = session.pool.getPool(node.Addr) if ok { t.Fatal("node not removed after remove event") } if err := ccm.NodeUp(targetNode); err != nil { t.Fatal(err) } // cassandra < 2.2 needs 10 seconds to start up the binary service time.Sleep(15 * time.Second) _, ok = session.pool.getPool(node.Addr) if !ok { t.Fatal("node not added after node added event") } host := session.ring.getHost(node.Addr) if host == nil { t.Fatal("node not in metadata ring") } else if !host.IsUp() { t.Fatalf("not not marked as UP after event in metadata: addr=%q host=%p: %v", node.Addr, host, host) } } func TestEventFilter(t *testing.T) { t.Skip("FLAKE skipping") if err := ccm.AllUp(); err != nil { t.Fatal(err) } status, err := ccm.Status() if 
err != nil { t.Fatal(err) } log.Printf("status=%+v\n", status) cluster := createCluster() cluster.HostFilter = WhiteListHostFilter(status["node1"].Addr) session := createSessionFromCluster(cluster, t) defer session.Close() if _, ok := session.pool.getPool(status["node1"].Addr); !ok { t.Errorf("should have %v in pool but dont", "node1") } for _, host := range [...]string{"node2", "node3"} { _, ok := session.pool.getPool(status[host].Addr) if ok { t.Errorf("should not have %v in pool", host) } } if t.Failed() { t.FailNow() } if err := ccm.NodeDown("node2"); err != nil { t.Fatal(err) } time.Sleep(5 * time.Second) if err := ccm.NodeUp("node2"); err != nil { t.Fatal(err) } time.Sleep(15 * time.Second) for _, host := range [...]string{"node2", "node3"} { _, ok := session.pool.getPool(status[host].Addr) if ok { t.Errorf("should not have %v in pool", host) } } if t.Failed() { t.FailNow() } } func TestEventDownQueryable(t *testing.T) { t.Skip("FLAKE skipping") if err := ccm.AllUp(); err != nil { t.Fatal(err) } status, err := ccm.Status() if err != nil { t.Fatal(err) } log.Printf("status=%+v\n", status) const targetNode = "node1" addr := status[targetNode].Addr cluster := createCluster() cluster.Hosts = []string{addr} cluster.HostFilter = WhiteListHostFilter(addr) session := createSessionFromCluster(cluster, t) defer session.Close() if pool, ok := session.pool.getPool(addr); !ok { t.Fatalf("should have %v in pool but dont", addr) } else if !pool.host.IsUp() { t.Fatalf("host is not up %v", pool.host) } if err := ccm.NodeDown(targetNode); err != nil { t.Fatal(err) } time.Sleep(5 * time.Second) if err := ccm.NodeUp(targetNode); err != nil { t.Fatal(err) } time.Sleep(15 * time.Second) if pool, ok := session.pool.getPool(addr); !ok { t.Fatalf("should have %v in pool but dont", addr) } else if !pool.host.IsUp() { t.Fatalf("host is not up %v", pool.host) } var rows int if err := session.Query("SELECT COUNT(*) FROM system.local").Scan(&rows); err != nil { t.Fatal(err) } else if rows 
!= 1 { t.Fatalf("expected to get 1 row got %d", rows) } } cassandra-gocql-driver-1.7.0/events_test.go000066400000000000000000000031751467504044300207610ustar00rootroot00000000000000/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. 
*/ package gocql import ( "net" "sync" "testing" ) func TestEventDebounce(t *testing.T) { const eventCount = 150 wg := &sync.WaitGroup{} wg.Add(1) eventsSeen := 0 debouncer := newEventDebouncer("testDebouncer", func(events []frame) { defer wg.Done() eventsSeen += len(events) }, &defaultLogger{}) defer debouncer.stop() for i := 0; i < eventCount; i++ { debouncer.debounce(&statusChangeEventFrame{ change: "UP", host: net.IPv4(127, 0, 0, 1), port: 9042, }) } wg.Wait() if eventCount != eventsSeen { t.Fatalf("expected to see %d events but got %d", eventCount, eventsSeen) } } cassandra-gocql-driver-1.7.0/example_batch_test.go000066400000000000000000000050041467504044300222420ustar00rootroot00000000000000/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. */ package gocql_test import ( "context" "fmt" "log" gocql "github.com/gocql/gocql" ) // Example_batch demonstrates how to execute a batch of statements. 
func Example_batch() { /* The example assumes the following CQL was used to setup the keyspace: create keyspace example with replication = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 }; create table example.batches(pk int, ck int, description text, PRIMARY KEY(pk, ck)); */ cluster := gocql.NewCluster("localhost:9042") cluster.Keyspace = "example" cluster.ProtoVersion = 4 session, err := cluster.CreateSession() if err != nil { log.Fatal(err) } defer session.Close() ctx := context.Background() b := session.NewBatch(gocql.UnloggedBatch).WithContext(ctx) b.Entries = append(b.Entries, gocql.BatchEntry{ Stmt: "INSERT INTO example.batches (pk, ck, description) VALUES (?, ?, ?)", Args: []interface{}{1, 2, "1.2"}, Idempotent: true, }) b.Entries = append(b.Entries, gocql.BatchEntry{ Stmt: "INSERT INTO example.batches (pk, ck, description) VALUES (?, ?, ?)", Args: []interface{}{1, 3, "1.3"}, Idempotent: true, }) err = session.ExecuteBatch(b) if err != nil { log.Fatal(err) } scanner := session.Query("SELECT pk, ck, description FROM example.batches").Iter().Scanner() for scanner.Next() { var pk, ck int32 var description string err = scanner.Scan(&pk, &ck, &description) if err != nil { log.Fatal(err) } fmt.Println(pk, ck, description) } // 1 2 1.2 // 1 3 1.3 } cassandra-gocql-driver-1.7.0/example_dynamic_columns_test.go000066400000000000000000000073431467504044300243550ustar00rootroot00000000000000/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. */ package gocql_test import ( "context" "fmt" "log" "os" "reflect" "text/tabwriter" gocql "github.com/gocql/gocql" ) // Example_dynamicColumns demonstrates how to handle dynamic column list. func Example_dynamicColumns() { /* The example assumes the following CQL was used to setup the keyspace: create keyspace example with replication = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 }; create table example.table1(pk text, ck int, value1 text, value2 int, PRIMARY KEY(pk, ck)); insert into example.table1 (pk, ck, value1, value2) values ('a', 1, 'b', 2); insert into example.table1 (pk, ck, value1, value2) values ('c', 3, 'd', 4); insert into example.table1 (pk, ck, value1, value2) values ('c', 5, null, null); create table example.table2(pk int, value1 timestamp, PRIMARY KEY(pk)); insert into example.table2 (pk, value1) values (1, '2020-01-02 03:04:05'); */ cluster := gocql.NewCluster("localhost:9042") cluster.Keyspace = "example" cluster.ProtoVersion = 4 session, err := cluster.CreateSession() if err != nil { log.Fatal(err) } defer session.Close() printQuery := func(ctx context.Context, session *gocql.Session, stmt string, values ...interface{}) error { iter := session.Query(stmt, values...).WithContext(ctx).Iter() fmt.Println(stmt) w := tabwriter.NewWriter(os.Stdout, 0, 0, 1, ' ', 0) for i, columnInfo := range iter.Columns() { if i 
> 0 { fmt.Fprint(w, "\t| ") } fmt.Fprintf(w, "%s (%s)", columnInfo.Name, columnInfo.TypeInfo) } for { rd, err := iter.RowData() if err != nil { return err } if !iter.Scan(rd.Values...) { break } fmt.Fprint(w, "\n") for i, val := range rd.Values { if i > 0 { fmt.Fprint(w, "\t| ") } fmt.Fprint(w, reflect.Indirect(reflect.ValueOf(val)).Interface()) } } fmt.Fprint(w, "\n") w.Flush() fmt.Println() return iter.Close() } ctx := context.Background() err = printQuery(ctx, session, "SELECT * FROM table1") if err != nil { log.Fatal(err) } err = printQuery(ctx, session, "SELECT value2, pk, ck FROM table1") if err != nil { log.Fatal(err) } err = printQuery(ctx, session, "SELECT * FROM table2") if err != nil { log.Fatal(err) } // SELECT * FROM table1 // pk (varchar) | ck (int) | value1 (varchar) | value2 (int) // a | 1 | b | 2 // c | 3 | d | 4 // c | 5 | | 0 // // SELECT value2, pk, ck FROM table1 // value2 (int) | pk (varchar) | ck (int) // 2 | a | 1 // 4 | c | 3 // 0 | c | 5 // // SELECT * FROM table2 // pk (int) | value1 (timestamp) // 1 | 2020-01-02 03:04:05 +0000 UTC } cassandra-gocql-driver-1.7.0/example_lwt_batch_test.go000066400000000000000000000070131467504044300231320ustar00rootroot00000000000000/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. */ package gocql_test import ( "context" "fmt" "log" gocql "github.com/gocql/gocql" ) // ExampleSession_MapExecuteBatchCAS demonstrates how to execute a batch lightweight transaction. func ExampleSession_MapExecuteBatchCAS() { /* The example assumes the following CQL was used to setup the keyspace: create keyspace example with replication = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 }; create table example.my_lwt_batch_table(pk text, ck text, version int, value text, PRIMARY KEY(pk, ck)); */ cluster := gocql.NewCluster("localhost:9042") cluster.Keyspace = "example" cluster.ProtoVersion = 4 session, err := cluster.CreateSession() if err != nil { log.Fatal(err) } defer session.Close() ctx := context.Background() err = session.Query("INSERT INTO example.my_lwt_batch_table (pk, ck, version, value) VALUES (?, ?, ?, ?)", "pk1", "ck1", 1, "a").WithContext(ctx).Exec() if err != nil { log.Fatal(err) } err = session.Query("INSERT INTO example.my_lwt_batch_table (pk, ck, version, value) VALUES (?, ?, ?, ?)", "pk1", "ck2", 1, "A").WithContext(ctx).Exec() if err != nil { log.Fatal(err) } executeBatch := func(ck2Version int) { b := session.NewBatch(gocql.LoggedBatch) b.Entries = append(b.Entries, gocql.BatchEntry{ Stmt: "UPDATE my_lwt_batch_table SET value=? WHERE pk=? AND ck=? IF version=?", Args: []interface{}{"b", "pk1", "ck1", 1}, }) b.Entries = append(b.Entries, gocql.BatchEntry{ Stmt: "UPDATE my_lwt_batch_table SET value=? WHERE pk=? AND ck=? 
IF version=?", Args: []interface{}{"B", "pk1", "ck2", ck2Version}, }) m := make(map[string]interface{}) applied, iter, err := session.MapExecuteBatchCAS(b.WithContext(ctx), m) if err != nil { log.Fatal(err) } fmt.Println(applied, m) m = make(map[string]interface{}) for iter.MapScan(m) { fmt.Println(m) m = make(map[string]interface{}) } if err := iter.Close(); err != nil { log.Fatal(err) } } printState := func() { scanner := session.Query("SELECT ck, value FROM example.my_lwt_batch_table WHERE pk = ?", "pk1"). WithContext(ctx).Iter().Scanner() for scanner.Next() { var ck, value string err = scanner.Scan(&ck, &value) if err != nil { log.Fatal(err) } fmt.Println(ck, value) } if err := scanner.Err(); err != nil { log.Fatal(err) } } executeBatch(0) printState() executeBatch(1) printState() // false map[ck:ck1 pk:pk1 version:1] // map[[applied]:false ck:ck2 pk:pk1 version:1] // ck1 a // ck2 A // true map[] // ck1 b // ck2 B } cassandra-gocql-driver-1.7.0/example_lwt_test.go000066400000000000000000000054531467504044300217770ustar00rootroot00000000000000/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. 
* See the NOTICE file distributed with this work for additional information. */ package gocql_test import ( "context" "fmt" "log" gocql "github.com/gocql/gocql" ) // ExampleQuery_MapScanCAS demonstrates how to execute a single-statement lightweight transaction. func ExampleQuery_MapScanCAS() { /* The example assumes the following CQL was used to setup the keyspace: create keyspace example with replication = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 }; create table example.my_lwt_table(pk int, version int, value text, PRIMARY KEY(pk)); */ cluster := gocql.NewCluster("localhost:9042") cluster.Keyspace = "example" cluster.ProtoVersion = 4 session, err := cluster.CreateSession() if err != nil { log.Fatal(err) } defer session.Close() ctx := context.Background() err = session.Query("INSERT INTO example.my_lwt_table (pk, version, value) VALUES (?, ?, ?)", 1, 1, "a").WithContext(ctx).Exec() if err != nil { log.Fatal(err) } m := make(map[string]interface{}) applied, err := session.Query("UPDATE example.my_lwt_table SET value = ? WHERE pk = ? IF version = ?", "b", 1, 0).WithContext(ctx).MapScanCAS(m) if err != nil { log.Fatal(err) } fmt.Println(applied, m) var value string err = session.Query("SELECT value FROM example.my_lwt_table WHERE pk = ?", 1).WithContext(ctx). Scan(&value) if err != nil { log.Fatal(err) } fmt.Println(value) m = make(map[string]interface{}) applied, err = session.Query("UPDATE example.my_lwt_table SET value = ? WHERE pk = ? IF version = ?", "b", 1, 1).WithContext(ctx).MapScanCAS(m) if err != nil { log.Fatal(err) } fmt.Println(applied, m) var value2 string err = session.Query("SELECT value FROM example.my_lwt_table WHERE pk = ?", 1).WithContext(ctx). 
Scan(&value2) if err != nil { log.Fatal(err) } fmt.Println(value2) // false map[version:1] // a // true map[] // b } cassandra-gocql-driver-1.7.0/example_marshaler_test.go000066400000000000000000000066731467504044300231540ustar00rootroot00000000000000/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. */ package gocql_test import ( "context" "fmt" "log" "strconv" "strings" gocql "github.com/gocql/gocql" ) // MyMarshaler implements Marshaler and Unmarshaler. // It represents a version number stored as string. 
type MyMarshaler struct { major, minor, patch int } func (m MyMarshaler) MarshalCQL(info gocql.TypeInfo) ([]byte, error) { return gocql.Marshal(info, fmt.Sprintf("%d.%d.%d", m.major, m.minor, m.patch)) } func (m *MyMarshaler) UnmarshalCQL(info gocql.TypeInfo, data []byte) error { var s string err := gocql.Unmarshal(info, data, &s) if err != nil { return err } parts := strings.SplitN(s, ".", 3) if len(parts) != 3 { return fmt.Errorf("parse version %q: %d parts instead of 3", s, len(parts)) } major, err := strconv.Atoi(parts[0]) if err != nil { return fmt.Errorf("parse version %q major number: %v", s, err) } minor, err := strconv.Atoi(parts[1]) if err != nil { return fmt.Errorf("parse version %q minor number: %v", s, err) } patch, err := strconv.Atoi(parts[2]) if err != nil { return fmt.Errorf("parse version %q patch number: %v", s, err) } m.major = major m.minor = minor m.patch = patch return nil } // Example_marshalerUnmarshaler demonstrates how to implement a Marshaler and Unmarshaler. func Example_marshalerUnmarshaler() { /* The example assumes the following CQL was used to setup the keyspace: create keyspace example with replication = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 }; create table example.my_marshaler_table(pk int, value text, PRIMARY KEY(pk)); */ cluster := gocql.NewCluster("localhost:9042") cluster.Keyspace = "example" cluster.ProtoVersion = 4 session, err := cluster.CreateSession() if err != nil { log.Fatal(err) } defer session.Close() ctx := context.Background() value := MyMarshaler{ major: 1, minor: 2, patch: 3, } err = session.Query("INSERT INTO example.my_marshaler_table (pk, value) VALUES (?, ?)", 1, value).WithContext(ctx).Exec() if err != nil { log.Fatal(err) } var stringValue string err = session.Query("SELECT value FROM example.my_marshaler_table WHERE pk = 1").WithContext(ctx). 
Scan(&stringValue) if err != nil { log.Fatal(err) } fmt.Println(stringValue) var unmarshaledValue MyMarshaler err = session.Query("SELECT value FROM example.my_marshaler_table WHERE pk = 1").WithContext(ctx). Scan(&unmarshaledValue) if err != nil { log.Fatal(err) } fmt.Println(unmarshaledValue) // 1.2.3 // {1 2 3} } cassandra-gocql-driver-1.7.0/example_nulls_test.go000066400000000000000000000046721467504044300223300ustar00rootroot00000000000000/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. */ package gocql_test import ( "fmt" "log" gocql "github.com/gocql/gocql" ) // Example_nulls demonstrates how to distinguish between null and zero value when needed. // // Null values are unmarshalled as zero value of the type. If you need to distinguish for example between text // column being null and empty string, you can unmarshal into *string field. 
func Example_nulls() { /* The example assumes the following CQL was used to setup the keyspace: create keyspace example with replication = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 }; create table example.stringvals(id int, value text, PRIMARY KEY(id)); insert into example.stringvals (id, value) values (1, null); insert into example.stringvals (id, value) values (2, ''); insert into example.stringvals (id, value) values (3, 'hello'); */ cluster := gocql.NewCluster("localhost:9042") cluster.Keyspace = "example" session, err := cluster.CreateSession() if err != nil { log.Fatal(err) } defer session.Close() scanner := session.Query(`SELECT id, value FROM stringvals`).Iter().Scanner() for scanner.Next() { var ( id int32 val *string ) err := scanner.Scan(&id, &val) if err != nil { log.Fatal(err) } if val != nil { fmt.Printf("Row %d is %q\n", id, *val) } else { fmt.Printf("Row %d is null\n", id) } } err = scanner.Err() if err != nil { log.Fatal(err) } // Row 1 is null // Row 2 is "" // Row 3 is "hello" } cassandra-gocql-driver-1.7.0/example_paging_test.go000066400000000000000000000060221467504044300224270ustar00rootroot00000000000000/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. */ package gocql_test import ( "fmt" "log" gocql "github.com/gocql/gocql" ) // Example_paging demonstrates how to manually fetch pages and use page state. // // See also package documentation about paging. func Example_paging() { /* The example assumes the following CQL was used to setup the keyspace: create keyspace example with replication = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 }; create table example.itoa(id int, description text, PRIMARY KEY(id)); insert into example.itoa (id, description) values (1, 'one'); insert into example.itoa (id, description) values (2, 'two'); insert into example.itoa (id, description) values (3, 'three'); insert into example.itoa (id, description) values (4, 'four'); insert into example.itoa (id, description) values (5, 'five'); insert into example.itoa (id, description) values (6, 'six'); */ cluster := gocql.NewCluster("localhost:9042") cluster.Keyspace = "example" cluster.ProtoVersion = 4 session, err := cluster.CreateSession() if err != nil { log.Fatal(err) } defer session.Close() var pageState []byte for { // We use PageSize(2) for the sake of example, use larger values in production (default is 5000) for performance // reasons. 
iter := session.Query(`SELECT id, description FROM itoa`).PageSize(2).PageState(pageState).Iter() nextPageState := iter.PageState() scanner := iter.Scanner() for scanner.Next() { var ( id int description string ) err = scanner.Scan(&id, &description) if err != nil { log.Fatal(err) } fmt.Println(id, description) } err = scanner.Err() if err != nil { log.Fatal(err) } fmt.Printf("next page state: %+v\n", nextPageState) if len(nextPageState) == 0 { break } pageState = nextPageState } // 5 five // 1 one // next page state: [4 0 0 0 1 0 240 127 255 255 253 0] // 2 two // 4 four // next page state: [4 0 0 0 4 0 240 127 255 255 251 0] // 6 six // 3 three // next page state: [4 0 0 0 3 0 240 127 255 255 249 0] // next page state: [] } cassandra-gocql-driver-1.7.0/example_set_test.go000066400000000000000000000053001467504044300217530ustar00rootroot00000000000000/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. 
*/ package gocql_test import ( "fmt" "log" "sort" gocql "github.com/gocql/gocql" ) // Example_set demonstrates how to use sets. func Example_set() { /* The example assumes the following CQL was used to setup the keyspace: create keyspace example with replication = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 }; create table example.sets(id int, value set, PRIMARY KEY(id)); */ cluster := gocql.NewCluster("localhost:9042") cluster.Keyspace = "example" session, err := cluster.CreateSession() if err != nil { log.Fatal(err) } defer session.Close() err = session.Query(`UPDATE sets SET value=? WHERE id=1`, []string{"alpha", "beta", "gamma"}).Exec() if err != nil { log.Fatal(err) } err = session.Query(`UPDATE sets SET value=value+? WHERE id=1`, "epsilon").Exec() if err != nil { // This does not work because the ? expects a set, not a single item. fmt.Printf("expected error: %v\n", err) } err = session.Query(`UPDATE sets SET value=value+? WHERE id=1`, []string{"delta"}).Exec() if err != nil { log.Fatal(err) } // map[x]struct{} is supported too. toRemove := map[string]struct{}{ "alpha": {}, "gamma": {}, } err = session.Query(`UPDATE sets SET value=value-? WHERE id=1`, toRemove).Exec() if err != nil { log.Fatal(err) } scanner := session.Query(`SELECT id, value FROM sets`).Iter().Scanner() for scanner.Next() { var ( id int32 val []string ) err := scanner.Scan(&id, &val) if err != nil { log.Fatal(err) } sort.Strings(val) fmt.Printf("Row %d is %v\n", id, val) } err = scanner.Err() if err != nil { log.Fatal(err) } // expected error: can not marshal string into set(varchar) // Row 1 is [beta delta] } cassandra-gocql-driver-1.7.0/example_test.go000066400000000000000000000056441467504044300211130ustar00rootroot00000000000000/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. 
The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. */ package gocql_test import ( "context" "fmt" "log" gocql "github.com/gocql/gocql" ) func Example() { /* The example assumes the following CQL was used to setup the keyspace: create keyspace example with replication = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 }; create table example.tweet(timeline text, id UUID, text text, PRIMARY KEY(id)); create index on example.tweet(timeline); */ cluster := gocql.NewCluster("localhost:9042") cluster.Keyspace = "example" cluster.Consistency = gocql.Quorum // connect to the cluster session, err := cluster.CreateSession() if err != nil { log.Fatal(err) } defer session.Close() ctx := context.Background() // insert a tweet if err := session.Query(`INSERT INTO tweet (timeline, id, text) VALUES (?, ?, ?)`, "me", gocql.TimeUUID(), "hello world").WithContext(ctx).Exec(); err != nil { log.Fatal(err) } var id gocql.UUID var text string /* Search for a specific set of records whose 'timeline' column matches * the value 'me'. The secondary index that we created earlier will be * used for optimizing the search */ if err := session.Query(`SELECT id, text FROM tweet WHERE timeline = ? 
LIMIT 1`, "me").WithContext(ctx).Consistency(gocql.One).Scan(&id, &text); err != nil { log.Fatal(err) } fmt.Println("Tweet:", id, text) fmt.Println() // list all tweets scanner := session.Query(`SELECT id, text FROM tweet WHERE timeline = ?`, "me").WithContext(ctx).Iter().Scanner() for scanner.Next() { err = scanner.Scan(&id, &text) if err != nil { log.Fatal(err) } fmt.Println("Tweet:", id, text) } // scanner.Err() closes the iterator, so scanner nor iter should be used afterwards. if err := scanner.Err(); err != nil { log.Fatal(err) } // Tweet: cad53821-3731-11eb-971c-708bcdaada84 hello world // // Tweet: cad53821-3731-11eb-971c-708bcdaada84 hello world // Tweet: d577ab85-3731-11eb-81eb-708bcdaada84 hello world } cassandra-gocql-driver-1.7.0/example_udt_map_test.go000066400000000000000000000046231467504044300226200ustar00rootroot00000000000000/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. 
*/ package gocql_test import ( "context" "fmt" "log" gocql "github.com/gocql/gocql" ) // Example_userDefinedTypesMap demonstrates how to work with user-defined types as maps. // See also Example_userDefinedTypesStruct and examples for UDTMarshaler and UDTUnmarshaler if you want to map to structs. func Example_userDefinedTypesMap() { /* The example assumes the following CQL was used to setup the keyspace: create keyspace example with replication = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 }; create type example.my_udt (field_a text, field_b int); create table example.my_udt_table(pk int, value frozen, PRIMARY KEY(pk)); */ cluster := gocql.NewCluster("localhost:9042") cluster.Keyspace = "example" cluster.ProtoVersion = 4 session, err := cluster.CreateSession() if err != nil { log.Fatal(err) } defer session.Close() ctx := context.Background() value := map[string]interface{}{ "field_a": "a value", "field_b": 42, } err = session.Query("INSERT INTO example.my_udt_table (pk, value) VALUES (?, ?)", 1, value).WithContext(ctx).Exec() if err != nil { log.Fatal(err) } var readValue map[string]interface{} err = session.Query("SELECT value FROM example.my_udt_table WHERE pk = 1").WithContext(ctx).Scan(&readValue) if err != nil { log.Fatal(err) } fmt.Println(readValue["field_a"]) fmt.Println(readValue["field_b"]) // a value // 42 } cassandra-gocql-driver-1.7.0/example_udt_marshaler_test.go000066400000000000000000000051571467504044300240240ustar00rootroot00000000000000/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. */ package gocql_test import ( "context" "log" gocql "github.com/gocql/gocql" ) // MyUDTMarshaler implements UDTMarshaler. type MyUDTMarshaler struct { fieldA string fieldB int32 } // MarshalUDT marshals the selected field to bytes. func (m MyUDTMarshaler) MarshalUDT(name string, info gocql.TypeInfo) ([]byte, error) { switch name { case "field_a": return gocql.Marshal(info, m.fieldA) case "field_b": return gocql.Marshal(info, m.fieldB) default: // If you want to be strict and return error un unknown field, you can do so here instead. // Returning nil, nil will set the value of unknown fields to null, which might be handy if you want // to be forward-compatible when a new field is added to the UDT. return nil, nil } } // ExampleUDTMarshaler demonstrates how to implement a UDTMarshaler. 
func ExampleUDTMarshaler() { /* The example assumes the following CQL was used to setup the keyspace: create keyspace example with replication = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 }; create type example.my_udt (field_a text, field_b int); create table example.my_udt_table(pk int, value frozen, PRIMARY KEY(pk)); */ cluster := gocql.NewCluster("localhost:9042") cluster.Keyspace = "example" cluster.ProtoVersion = 4 session, err := cluster.CreateSession() if err != nil { log.Fatal(err) } defer session.Close() ctx := context.Background() value := MyUDTMarshaler{ fieldA: "a value", fieldB: 42, } err = session.Query("INSERT INTO example.my_udt_table (pk, value) VALUES (?, ?)", 1, value).WithContext(ctx).Exec() if err != nil { log.Fatal(err) } } cassandra-gocql-driver-1.7.0/example_udt_struct_test.go000066400000000000000000000046541467504044300233730ustar00rootroot00000000000000/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. 
*/ package gocql_test import ( "context" "fmt" "log" gocql "github.com/gocql/gocql" ) type MyUDT struct { FieldA string `cql:"field_a"` FieldB int32 `cql:"field_b"` } // Example_userDefinedTypesStruct demonstrates how to work with user-defined types as structs. // See also examples for UDTMarshaler and UDTUnmarshaler if you need more control/better performance. func Example_userDefinedTypesStruct() { /* The example assumes the following CQL was used to setup the keyspace: create keyspace example with replication = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 }; create type example.my_udt (field_a text, field_b int); create table example.my_udt_table(pk int, value frozen, PRIMARY KEY(pk)); */ cluster := gocql.NewCluster("localhost:9042") cluster.Keyspace = "example" cluster.ProtoVersion = 4 session, err := cluster.CreateSession() if err != nil { log.Fatal(err) } defer session.Close() ctx := context.Background() value := MyUDT{ FieldA: "a value", FieldB: 42, } err = session.Query("INSERT INTO example.my_udt_table (pk, value) VALUES (?, ?)", 1, value).WithContext(ctx).Exec() if err != nil { log.Fatal(err) } var readValue MyUDT err = session.Query("SELECT value FROM example.my_udt_table WHERE pk = 1").WithContext(ctx).Scan(&readValue) if err != nil { log.Fatal(err) } fmt.Println(readValue.FieldA) fmt.Println(readValue.FieldB) // a value // 42 } cassandra-gocql-driver-1.7.0/example_udt_unmarshaler_test.go000066400000000000000000000054201467504044300243600ustar00rootroot00000000000000/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. */ package gocql_test import ( "context" "fmt" "log" gocql "github.com/gocql/gocql" ) // MyUDTUnmarshaler implements UDTUnmarshaler. type MyUDTUnmarshaler struct { fieldA string fieldB int32 } // UnmarshalUDT unmarshals the field identified by name into MyUDTUnmarshaler. func (m *MyUDTUnmarshaler) UnmarshalUDT(name string, info gocql.TypeInfo, data []byte) error { switch name { case "field_a": return gocql.Unmarshal(info, data, &m.fieldA) case "field_b": return gocql.Unmarshal(info, data, &m.fieldB) default: // If you want to be strict and return error un unknown field, you can do so here instead. // Returning nil will ignore unknown fields, which might be handy if you want // to be forward-compatible when a new field is added to the UDT. return nil } } // ExampleUDTUnmarshaler demonstrates how to implement a UDTUnmarshaler. 
func ExampleUDTUnmarshaler() { /* The example assumes the following CQL was used to setup the keyspace: create keyspace example with replication = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 }; create type example.my_udt (field_a text, field_b int); create table example.my_udt_table(pk int, value frozen, PRIMARY KEY(pk)); insert into example.my_udt_table (pk, value) values (1, {field_a: 'a value', field_b: 42}); */ cluster := gocql.NewCluster("localhost:9042") cluster.Keyspace = "example" cluster.ProtoVersion = 4 session, err := cluster.CreateSession() if err != nil { log.Fatal(err) } defer session.Close() ctx := context.Background() var value MyUDTUnmarshaler err = session.Query("SELECT value FROM example.my_udt_table WHERE pk = 1").WithContext(ctx).Scan(&value) if err != nil { log.Fatal(err) } fmt.Println(value.fieldA) fmt.Println(value.fieldB) // a value // 42 } cassandra-gocql-driver-1.7.0/filters.go000066400000000000000000000051341467504044300200630ustar00rootroot00000000000000/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. 
* See the NOTICE file distributed with this work for additional information. */ package gocql import "fmt" // HostFilter interface is used when a host is discovered via server sent events. type HostFilter interface { // Called when a new host is discovered, returning true will cause the host // to be added to the pools. Accept(host *HostInfo) bool } // HostFilterFunc converts a func(host HostInfo) bool into a HostFilter type HostFilterFunc func(host *HostInfo) bool func (fn HostFilterFunc) Accept(host *HostInfo) bool { return fn(host) } // AcceptAllFilter will accept all hosts func AcceptAllFilter() HostFilter { return HostFilterFunc(func(host *HostInfo) bool { return true }) } func DenyAllFilter() HostFilter { return HostFilterFunc(func(host *HostInfo) bool { return false }) } // DataCentreHostFilter filters all hosts such that they are in the same data centre // as the supplied data centre. func DataCentreHostFilter(dataCentre string) HostFilter { return HostFilterFunc(func(host *HostInfo) bool { return host.DataCenter() == dataCentre }) } // WhiteListHostFilter filters incoming hosts by checking that their address is // in the initial hosts whitelist. func WhiteListHostFilter(hosts ...string) HostFilter { hostInfos, err := addrsToHosts(hosts, 9042, nopLogger{}) if err != nil { // dont want to panic here, but rather not break the API panic(fmt.Errorf("unable to lookup host info from address: %v", err)) } m := make(map[string]bool, len(hostInfos)) for _, host := range hostInfos { m[host.ConnectAddress().String()] = true } return HostFilterFunc(func(host *HostInfo) bool { return m[host.ConnectAddress().String()] }) } cassandra-gocql-driver-1.7.0/filters_test.go000066400000000000000000000057601467504044300211270ustar00rootroot00000000000000/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. 
The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. */ package gocql import ( "net" "testing" ) func TestFilter_WhiteList(t *testing.T) { f := WhiteListHostFilter("127.0.0.1", "127.0.0.2") tests := [...]struct { addr net.IP accept bool }{ {net.ParseIP("127.0.0.1"), true}, {net.ParseIP("127.0.0.2"), true}, {net.ParseIP("127.0.0.3"), false}, } for i, test := range tests { if f.Accept(&HostInfo{connectAddress: test.addr}) { if !test.accept { t.Errorf("%d: should not have been accepted but was", i) } } else if test.accept { t.Errorf("%d: should have been accepted but wasn't", i) } } } func TestFilter_AllowAll(t *testing.T) { f := AcceptAllFilter() tests := [...]struct { addr net.IP accept bool }{ {net.ParseIP("127.0.0.1"), true}, {net.ParseIP("127.0.0.2"), true}, {net.ParseIP("127.0.0.3"), true}, } for i, test := range tests { if f.Accept(&HostInfo{connectAddress: test.addr}) { if !test.accept { t.Errorf("%d: should not have been accepted but was", i) } } else if test.accept { t.Errorf("%d: should have been accepted but wasn't", i) } } } func TestFilter_DenyAll(t *testing.T) { f := DenyAllFilter() tests := [...]struct { addr net.IP accept bool }{ {net.ParseIP("127.0.0.1"), false}, {net.ParseIP("127.0.0.2"), false}, 
{net.ParseIP("127.0.0.3"), false}, } for i, test := range tests { if f.Accept(&HostInfo{connectAddress: test.addr}) { if !test.accept { t.Errorf("%d: should not have been accepted but was", i) } } else if test.accept { t.Errorf("%d: should have been accepted but wasn't", i) } } } func TestFilter_DataCentre(t *testing.T) { f := DataCentreHostFilter("dc1") tests := [...]struct { dc string accept bool }{ {"dc1", true}, {"dc2", false}, } for i, test := range tests { if f.Accept(&HostInfo{dataCenter: test.dc}) { if !test.accept { t.Errorf("%d: should not have been accepted but was", i) } } else if test.accept { t.Errorf("%d: should have been accepted but wasn't", i) } } } cassandra-gocql-driver-1.7.0/frame.go000066400000000000000000001315401467504044300175060ustar00rootroot00000000000000/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2012, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. 
*/ package gocql import ( "context" "errors" "fmt" "io" "io/ioutil" "net" "runtime" "strings" "time" ) type unsetColumn struct{} // UnsetValue represents a value used in a query binding that will be ignored by Cassandra. // // By setting a field to the unset value Cassandra will ignore the write completely. // The main advantage is the ability to keep the same prepared statement even when you don't // want to update some fields, where before you needed to make another prepared statement. // // UnsetValue is only available when using the version 4 of the protocol. var UnsetValue = unsetColumn{} type namedValue struct { name string value interface{} } // NamedValue produce a value which will bind to the named parameter in a query func NamedValue(name string, value interface{}) interface{} { return &namedValue{ name: name, value: value, } } const ( protoDirectionMask = 0x80 protoVersionMask = 0x7F protoVersion1 = 0x01 protoVersion2 = 0x02 protoVersion3 = 0x03 protoVersion4 = 0x04 protoVersion5 = 0x05 maxFrameSize = 256 * 1024 * 1024 ) type protoVersion byte func (p protoVersion) request() bool { return p&protoDirectionMask == 0x00 } func (p protoVersion) response() bool { return p&protoDirectionMask == 0x80 } func (p protoVersion) version() byte { return byte(p) & protoVersionMask } func (p protoVersion) String() string { dir := "REQ" if p.response() { dir = "RESP" } return fmt.Sprintf("[version=%d direction=%s]", p.version(), dir) } type frameOp byte const ( // header ops opError frameOp = 0x00 opStartup frameOp = 0x01 opReady frameOp = 0x02 opAuthenticate frameOp = 0x03 opOptions frameOp = 0x05 opSupported frameOp = 0x06 opQuery frameOp = 0x07 opResult frameOp = 0x08 opPrepare frameOp = 0x09 opExecute frameOp = 0x0A opRegister frameOp = 0x0B opEvent frameOp = 0x0C opBatch frameOp = 0x0D opAuthChallenge frameOp = 0x0E opAuthResponse frameOp = 0x0F opAuthSuccess frameOp = 0x10 ) func (f frameOp) String() string { switch f { case opError: return "ERROR" case opStartup: 
return "STARTUP" case opReady: return "READY" case opAuthenticate: return "AUTHENTICATE" case opOptions: return "OPTIONS" case opSupported: return "SUPPORTED" case opQuery: return "QUERY" case opResult: return "RESULT" case opPrepare: return "PREPARE" case opExecute: return "EXECUTE" case opRegister: return "REGISTER" case opEvent: return "EVENT" case opBatch: return "BATCH" case opAuthChallenge: return "AUTH_CHALLENGE" case opAuthResponse: return "AUTH_RESPONSE" case opAuthSuccess: return "AUTH_SUCCESS" default: return fmt.Sprintf("UNKNOWN_OP_%d", f) } } const ( // result kind resultKindVoid = 1 resultKindRows = 2 resultKindKeyspace = 3 resultKindPrepared = 4 resultKindSchemaChanged = 5 // rows flags flagGlobalTableSpec int = 0x01 flagHasMorePages int = 0x02 flagNoMetaData int = 0x04 // query flags flagValues byte = 0x01 flagSkipMetaData byte = 0x02 flagPageSize byte = 0x04 flagWithPagingState byte = 0x08 flagWithSerialConsistency byte = 0x10 flagDefaultTimestamp byte = 0x20 flagWithNameValues byte = 0x40 flagWithKeyspace byte = 0x80 // prepare flags flagWithPreparedKeyspace uint32 = 0x01 // header flags flagCompress byte = 0x01 flagTracing byte = 0x02 flagCustomPayload byte = 0x04 flagWarning byte = 0x08 flagBetaProtocol byte = 0x10 ) type Consistency uint16 const ( Any Consistency = 0x00 One Consistency = 0x01 Two Consistency = 0x02 Three Consistency = 0x03 Quorum Consistency = 0x04 All Consistency = 0x05 LocalQuorum Consistency = 0x06 EachQuorum Consistency = 0x07 LocalOne Consistency = 0x0A ) func (c Consistency) String() string { switch c { case Any: return "ANY" case One: return "ONE" case Two: return "TWO" case Three: return "THREE" case Quorum: return "QUORUM" case All: return "ALL" case LocalQuorum: return "LOCAL_QUORUM" case EachQuorum: return "EACH_QUORUM" case LocalOne: return "LOCAL_ONE" default: return fmt.Sprintf("UNKNOWN_CONS_0x%x", uint16(c)) } } func (c Consistency) MarshalText() (text []byte, err error) { return []byte(c.String()), nil } func (c 
*Consistency) UnmarshalText(text []byte) error { switch string(text) { case "ANY": *c = Any case "ONE": *c = One case "TWO": *c = Two case "THREE": *c = Three case "QUORUM": *c = Quorum case "ALL": *c = All case "LOCAL_QUORUM": *c = LocalQuorum case "EACH_QUORUM": *c = EachQuorum case "LOCAL_ONE": *c = LocalOne default: return fmt.Errorf("invalid consistency %q", string(text)) } return nil } func ParseConsistency(s string) Consistency { var c Consistency if err := c.UnmarshalText([]byte(strings.ToUpper(s))); err != nil { panic(err) } return c } // ParseConsistencyWrapper wraps gocql.ParseConsistency to provide an err // return instead of a panic func ParseConsistencyWrapper(s string) (consistency Consistency, err error) { err = consistency.UnmarshalText([]byte(strings.ToUpper(s))) return } // MustParseConsistency is the same as ParseConsistency except it returns // an error (never). It is kept here since breaking changes are not good. // DEPRECATED: use ParseConsistency if you want a panic on parse error. func MustParseConsistency(s string) (Consistency, error) { c, err := ParseConsistencyWrapper(s) if err != nil { panic(err) } return c, nil } type SerialConsistency uint16 const ( Serial SerialConsistency = 0x08 LocalSerial SerialConsistency = 0x09 ) func (s SerialConsistency) String() string { switch s { case Serial: return "SERIAL" case LocalSerial: return "LOCAL_SERIAL" default: return fmt.Sprintf("UNKNOWN_SERIAL_CONS_0x%x", uint16(s)) } } func (s SerialConsistency) MarshalText() (text []byte, err error) { return []byte(s.String()), nil } func (s *SerialConsistency) UnmarshalText(text []byte) error { switch string(text) { case "SERIAL": *s = Serial case "LOCAL_SERIAL": *s = LocalSerial default: return fmt.Errorf("invalid consistency %q", string(text)) } return nil } const ( apacheCassandraTypePrefix = "org.apache.cassandra.db.marshal." 
// readInt decodes a signed 32-bit big-endian integer from the first four
// bytes of p. Like the direct-indexing original, it panics if len(p) < 4.
func readInt(p []byte) int32 {
	_ = p[3] // one bounds check up front
	v := uint32(p[0])<<24 | uint32(p[1])<<16 | uint32(p[2])<<8 | uint32(p[3])
	return int32(v)
}
// readFrame reads the body of a single frame from r into f.buf; the header
// has already been parsed into head. The long-lived f.readBuffer is reused
// when large enough so steady-state reads do not allocate. Frames larger
// than maxFrameSize are drained off the wire (so the connection remains
// usable) and reported as ErrFrameTooBig. If the header carries the
// compression flag, the body is replaced by its decompressed form.
func (f *framer) readFrame(r io.Reader, head *frameHeader) error {
	if head.length < 0 {
		return fmt.Errorf("frame body length can not be less than 0: %d", head.length)
	} else if head.length > maxFrameSize {
		// need to free up the connection to be used again
		_, err := io.CopyN(ioutil.Discard, r, int64(head.length))
		if err != nil {
			return fmt.Errorf("error whilst trying to discard frame with invalid length: %v", err)
		}
		return ErrFrameTooBig
	}

	// Reuse the retained read buffer when it can hold this body; otherwise
	// grow it and keep the larger buffer for subsequent frames.
	if cap(f.readBuffer) >= head.length {
		f.buf = f.readBuffer[:head.length]
	} else {
		f.readBuffer = make([]byte, head.length)
		f.buf = f.readBuffer
	}

	// assume the underlying reader takes care of timeouts and retries
	n, err := io.ReadFull(r, f.buf)
	if err != nil {
		return fmt.Errorf("unable to read frame body: read %d/%d bytes: %v", n, head.length, err)
	}

	if head.flags&flagCompress == flagCompress {
		if f.compres == nil {
			return NewErrProtocol("no compressor available with compressed frame body")
		}

		f.buf, err = f.compres.Decode(f.buf)
		if err != nil {
			return err
		}
	}

	f.header = head
	return nil
}
// parseErrorFrame decodes an ERROR response body into the matching typed
// error value. The common [int] code and [string] message are read first;
// the remaining fields are read in the exact order mandated by the protocol
// spec for each error code (reordering any read would corrupt parsing).
// Codes with no extra payload are returned as the plain errorFrame value.
func (f *framer) parseErrorFrame() frame {
	code := f.readInt()
	msg := f.readString()

	errD := errorFrame{
		frameHeader: *f.header,
		code:        code,
		message:     msg,
	}

	switch code {
	case ErrCodeUnavailable:
		cl := f.readConsistency()
		required := f.readInt()
		alive := f.readInt()
		return &RequestErrUnavailable{
			errorFrame:  errD,
			Consistency: cl,
			Required:    required,
			Alive:       alive,
		}
	case ErrCodeWriteTimeout:
		cl := f.readConsistency()
		received := f.readInt()
		blockfor := f.readInt()
		writeType := f.readString()
		return &RequestErrWriteTimeout{
			errorFrame:  errD,
			Consistency: cl,
			Received:    received,
			BlockFor:    blockfor,
			WriteType:   writeType,
		}
	case ErrCodeReadTimeout:
		cl := f.readConsistency()
		received := f.readInt()
		blockfor := f.readInt()
		dataPresent := f.readByte()
		return &RequestErrReadTimeout{
			errorFrame:  errD,
			Consistency: cl,
			Received:    received,
			BlockFor:    blockfor,
			DataPresent: dataPresent,
		}
	case ErrCodeAlreadyExists:
		ks := f.readString()
		table := f.readString()
		return &RequestErrAlreadyExists{
			errorFrame: errD,
			Keyspace:   ks,
			Table:      table,
		}
	case ErrCodeUnprepared:
		stmtId := f.readShortBytes()
		return &RequestErrUnprepared{
			errorFrame:  errD,
			StatementId: copyBytes(stmtId), // defensively copy
		}
	case ErrCodeReadFailure:
		res := &RequestErrReadFailure{
			errorFrame: errD,
		}
		res.Consistency = f.readConsistency()
		res.Received = f.readInt()
		res.BlockFor = f.readInt()
		// v5+ replaced the failure count with a per-endpoint error map.
		if f.proto > protoVersion4 {
			res.ErrorMap = f.readErrorMap()
			res.NumFailures = len(res.ErrorMap)
		} else {
			res.NumFailures = f.readInt()
		}
		res.DataPresent = f.readByte() != 0

		return res
	case ErrCodeWriteFailure:
		res := &RequestErrWriteFailure{
			errorFrame: errD,
		}
		res.Consistency = f.readConsistency()
		res.Received = f.readInt()
		res.BlockFor = f.readInt()
		// v5+ replaced the failure count with a per-endpoint error map.
		if f.proto > protoVersion4 {
			res.ErrorMap = f.readErrorMap()
			res.NumFailures = len(res.ErrorMap)
		} else {
			res.NumFailures = f.readInt()
		}
		res.WriteType = f.readString()
		return res
	case ErrCodeFunctionFailure:
		res := &RequestErrFunctionFailure{
			errorFrame: errD,
		}
		res.Keyspace = f.readString()
		res.Function = f.readString()
		res.ArgTypes = f.readStringList()
		return res

	case ErrCodeCDCWriteFailure:
		res := &RequestErrCDCWriteFailure{
			errorFrame: errD,
		}
		return res
	case ErrCodeCASWriteUnknown:
		res := &RequestErrCASWriteUnknown{
			errorFrame: errD,
		}
		res.Consistency = f.readConsistency()
		res.Received = f.readInt()
		res.BlockFor = f.readInt()
		return res
	case ErrCodeInvalid, ErrCodeBootstrapping, ErrCodeConfig, ErrCodeCredentials, ErrCodeOverloaded,
		ErrCodeProtocol, ErrCodeServer, ErrCodeSyntax, ErrCodeTruncate, ErrCodeUnauthorized:
		// TODO(zariel): we should have some distinct types for these errors
		return errD
	default:
		panic(fmt.Errorf("unknown error code: 0x%x", errD.code))
	}
}
// writePrepareFrame is the request frame for PREPARE. keyspace is optional
// and only representable on the wire from protocol v5 onward.
type writePrepareFrame struct {
	statement     string
	keyspace      string
	customPayload map[string][]byte
}

// buildFrame serializes the PREPARE request: long-string statement followed,
// on v5+, by a flags word and the optional keyspace. Setting a keyspace on
// v4 or lower panics since the wire format has no place for it.
func (w *writePrepareFrame) buildFrame(f *framer, streamID int) error {
	if len(w.customPayload) > 0 {
		f.payload()
	}
	f.writeHeader(f.flags, opPrepare, streamID)
	f.writeCustomPayload(&w.customPayload)
	f.writeLongString(w.statement)

	var flags uint32 = 0
	if w.keyspace != "" {
		if f.proto > protoVersion4 {
			flags |= flagWithPreparedKeyspace
		} else {
			panic(fmt.Errorf("the keyspace can only be set with protocol 5 or higher"))
		}
	}
	// The flags word itself only exists in the v5+ wire format.
	if f.proto > protoVersion4 {
		f.writeUint(flags)
	}
	if w.keyspace != "" {
		f.writeString(w.keyspace)
	}

	return f.finish()
}
// parsePreparedMetadata reads the prepared-statement metadata block of a
// RESULT/Prepared response: flags, column count, partition-key column
// indexes (v4+), optional paging state, optional global table spec, and the
// per-column specs. Reads must stay in this exact wire order.
func (f *framer) parsePreparedMetadata() preparedMetadata {
	// TODO: deduplicate this from parseMetadata
	meta := preparedMetadata{}
	meta.flags = f.readInt()
	meta.colCount = f.readInt()
	if meta.colCount < 0 {
		panic(fmt.Errorf("received negative column count: %d", meta.colCount))
	}
	meta.actualColCount = meta.colCount

	// v4+ adds the indexes of the partition-key columns.
	if f.proto >= protoVersion4 {
		pkeyCount := f.readInt()
		pkeys := make([]int, pkeyCount)
		for i := 0; i < pkeyCount; i++ {
			pkeys[i] = int(f.readShort())
		}
		meta.pkeyColumns = pkeys
	}

	if meta.flags&flagHasMorePages == flagHasMorePages {
		// copy because the paging state must outlive f.buf's reuse
		meta.pagingState = copyBytes(f.readBytes())
	}

	if meta.flags&flagNoMetaData == flagNoMetaData {
		return meta
	}

	// With a global table spec the keyspace/table appear once here instead
	// of being repeated per column.
	globalSpec := meta.flags&flagGlobalTableSpec == flagGlobalTableSpec
	if globalSpec {
		meta.keyspace = f.readString()
		meta.table = f.readString()
	}

	var cols []ColumnInfo

	if meta.colCount < 1000 {
		// preallocate columninfo to avoid excess copying
		cols = make([]ColumnInfo, meta.colCount)
		for i := 0; i < meta.colCount; i++ {
			f.readCol(&cols[i], &meta.resultMetadata, globalSpec, meta.keyspace, meta.table)
		}
	} else {
		// use append, huge number of columns usually indicates a corrupt frame or
		// just a huge row.
		for i := 0; i < meta.colCount; i++ {
			var col ColumnInfo
			f.readCol(&col, &meta.resultMetadata, globalSpec, meta.keyspace, meta.table)
			cols = append(cols, col)
		}
	}

	meta.columns = cols

	return meta
}
// resultVoidFrame represents a RESULT response of kind Void (no payload).
type resultVoidFrame struct {
	frameHeader
}

func (f *resultVoidFrame) String() string {
	return "[result_void]"
}

// parseResultFrame reads the [int] result kind and dispatches to the
// kind-specific parser; unknown kinds are a protocol error.
func (f *framer) parseResultFrame() (frame, error) {
	kind := f.readInt()

	switch kind {
	case resultKindVoid:
		return &resultVoidFrame{frameHeader: *f.header}, nil
	case resultKindRows:
		return f.parseResultRows(), nil
	case resultKindKeyspace:
		return f.parseResultSetKeyspace(), nil
	case resultKindPrepared:
		return f.parseResultPrepared(), nil
	case resultKindSchemaChanged:
		return f.parseResultSchemaChange(), nil
	}

	return nil, NewErrProtocol("unknown result kind: %x", kind)
}

// resultRowsFrame carries the metadata and row count of a Rows result.
type resultRowsFrame struct {
	frameHeader

	meta resultMetadata
	// dont parse the rows here as we only need to do it once
	numRows int
}

func (f *resultRowsFrame) String() string {
	return fmt.Sprintf("[result_rows meta=%v]", f.meta)
}

// parseResultRows reads the result metadata and the row count; the row data
// itself is left in f.buf for later iteration.
// NOTE(review): unlike the other parse* helpers, the embedded frameHeader is
// left at its zero value here — confirm that is intentional.
func (f *framer) parseResultRows() frame {
	result := &resultRowsFrame{}
	result.meta = f.parseResultMetadata()

	result.numRows = f.readInt()
	if result.numRows < 0 {
		panic(fmt.Errorf("invalid row_count in result frame: %d", result.numRows))
	}

	return result
}
// parseResultSchemaChange decodes a SchemaChange result (also reused for
// SCHEMA_CHANGE events). v1/v2 carry change/keyspace/table strings, with an
// empty table meaning a keyspace-level change; v3+ carry change/target plus
// target-specific fields. Unknown targets panic (recovered by parseFrame's
// caller-side recover, per the panic-based parse error convention here).
func (f *framer) parseResultSchemaChange() frame {
	if f.proto <= protoVersion2 {
		change := f.readString()
		keyspace := f.readString()
		table := f.readString()

		if table != "" {
			return &schemaChangeTable{
				frameHeader: *f.header,
				change:      change,
				keyspace:    keyspace,
				object:      table,
			}
		} else {
			return &schemaChangeKeyspace{
				frameHeader: *f.header,
				change:      change,
				keyspace:    keyspace,
			}
		}
	} else {
		change := f.readString()
		target := f.readString()

		// TODO: could just use a separate type for each target
		switch target {
		case "KEYSPACE":
			frame := &schemaChangeKeyspace{
				frameHeader: *f.header,
				change:      change,
			}

			frame.keyspace = f.readString()

			return frame
		case "TABLE":
			frame := &schemaChangeTable{
				frameHeader: *f.header,
				change:      change,
			}

			frame.keyspace = f.readString()
			frame.object = f.readString()

			return frame
		case "TYPE":
			frame := &schemaChangeType{
				frameHeader: *f.header,
				change:      change,
			}

			frame.keyspace = f.readString()
			frame.object = f.readString()

			return frame
		case "FUNCTION":
			frame := &schemaChangeFunction{
				frameHeader: *f.header,
				change:      change,
			}

			frame.keyspace = f.readString()
			frame.name = f.readString()
			frame.args = f.readStringList()

			return frame
		case "AGGREGATE":
			frame := &schemaChangeAggregate{
				frameHeader: *f.header,
				change:      change,
			}

			frame.keyspace = f.readString()
			frame.name = f.readString()
			frame.args = f.readStringList()

			return frame
		default:
			panic(fmt.Errorf("gocql: unknown SCHEMA_CHANGE target: %q change: %q", target, change))
		}
	}
}
// writeQueryParams serializes the <query_parameters> section shared by
// QUERY and EXECUTE requests: consistency, a flags byte (uint on v5+), and
// the flag-gated fields in their fixed wire order (values, page size,
// paging state, serial consistency, timestamp, keyspace). Protocol v1 has
// only the consistency. The flags must be computed before any gated field
// is written, so the two passes below cannot be merged.
func (f *framer) writeQueryParams(opts *queryParams) {
	f.writeConsistency(opts.consistency)

	if f.proto == protoVersion1 {
		return
	}

	var flags byte
	if len(opts.values) > 0 {
		flags |= flagValues
	}
	if opts.skipMeta {
		flags |= flagSkipMetaData
	}
	if opts.pageSize > 0 {
		flags |= flagPageSize
	}
	if len(opts.pagingState) > 0 {
		flags |= flagWithPagingState
	}
	if opts.serialConsistency > 0 {
		flags |= flagWithSerialConsistency
	}

	names := false

	// protoV3 specific things
	if f.proto > protoVersion2 {
		if opts.defaultTimestamp {
			flags |= flagDefaultTimestamp
		}

		// Named values: the presence of a name on the first value decides
		// the encoding for all of them.
		if len(opts.values) > 0 && opts.values[0].name != "" {
			flags |= flagWithNameValues
			names = true
		}
	}

	if opts.keyspace != "" {
		if f.proto > protoVersion4 {
			flags |= flagWithKeyspace
		} else {
			panic(fmt.Errorf("the keyspace can only be set with protocol 5 or higher"))
		}
	}

	// v5 widened the flags field from a byte to a uint.
	if f.proto > protoVersion4 {
		f.writeUint(uint32(flags))
	} else {
		f.writeByte(flags)
	}

	if n := len(opts.values); n > 0 {
		f.writeShort(uint16(n))

		for i := 0; i < n; i++ {
			if names {
				f.writeString(opts.values[i].name)
			}
			if opts.values[i].isUnset {
				f.writeUnset()
			} else {
				f.writeBytes(opts.values[i].value)
			}
		}
	}

	if opts.pageSize > 0 {
		f.writeInt(int32(opts.pageSize))
	}

	if len(opts.pagingState) > 0 {
		f.writeBytes(opts.pagingState)
	}

	if opts.serialConsistency > 0 {
		f.writeConsistency(Consistency(opts.serialConsistency))
	}

	if f.proto > protoVersion2 && opts.defaultTimestamp {
		// timestamp in microseconds
		var ts int64
		if opts.defaultTimestampValue != 0 {
			ts = opts.defaultTimestampValue
		} else {
			ts = time.Now().UnixNano() / 1000
		}
		f.writeLong(ts)
	}

	if opts.keyspace != "" {
		f.writeString(opts.keyspace)
	}
}
// writeExecuteFrame is the request frame for EXECUTE of a prepared
// statement.
type writeExecuteFrame struct {
	preparedID []byte
	params     queryParams

	// v4+
	customPayload map[string][]byte
}

func (e *writeExecuteFrame) String() string {
	return fmt.Sprintf("[execute id=% X params=%v]", e.preparedID, &e.params)
}

func (e *writeExecuteFrame) buildFrame(fr *framer, streamID int) error {
	return fr.writeExecuteFrame(streamID, e.preparedID, &e.params, &e.customPayload)
}

// writeExecuteFrame serializes an EXECUTE request: the prepared-statement id
// followed by the query parameters. Protocol v1 lacks the shared
// <query_parameters> layout, so values and consistency are written inline.
func (f *framer) writeExecuteFrame(streamID int, preparedID []byte, params *queryParams, customPayload *map[string][]byte) error {
	if len(*customPayload) > 0 {
		f.payload()
	}
	f.writeHeader(f.flags, opExecute, streamID)
	f.writeCustomPayload(customPayload)
	f.writeShortBytes(preparedID)
	if f.proto > protoVersion1 {
		f.writeQueryParams(params)
	} else {
		// v1 wire format: [short n][value...][consistency]
		n := len(params.values)
		f.writeShort(uint16(n))
		for i := 0; i < n; i++ {
			if params.values[i].isUnset {
				f.writeUnset()
			} else {
				f.writeBytes(params.values[i].value)
			}
		}
		f.writeConsistency(params.consistency)
	}

	return f.finish()
}
// batchStatment is one entry of a BATCH request: either a prepared-statement
// id or a raw query string, plus its bound values.
// (Name misspelling is pre-existing; kept because it is referenced elsewhere.)
type batchStatment struct {
	preparedID []byte
	statement  string
	values     []queryValues
}

// writeBatchFrame is the request frame for BATCH.
type writeBatchFrame struct {
	typ         BatchType
	statements  []batchStatment
	consistency Consistency

	// v3+
	serialConsistency     SerialConsistency
	defaultTimestamp      bool
	defaultTimestampValue int64

	//v4+
	customPayload map[string][]byte
}

func (w *writeBatchFrame) buildFrame(framer *framer, streamID int) error {
	return framer.writeBatchFrame(streamID, w, w.customPayload)
}

// writeBatchFrame serializes a BATCH request: batch type, the statements
// (each either kind 0 = query string or kind 1 = prepared id, with values),
// then consistency and, on v3+, the flags/serial-consistency/timestamp tail.
func (f *framer) writeBatchFrame(streamID int, w *writeBatchFrame, customPayload map[string][]byte) error {
	if len(customPayload) > 0 {
		f.payload()
	}
	f.writeHeader(f.flags, opBatch, streamID)
	f.writeCustomPayload(&customPayload)
	f.writeByte(byte(w.typ))

	n := len(w.statements)
	f.writeShort(uint16(n))

	var flags byte

	for i := 0; i < n; i++ {
		b := &w.statements[i]
		if len(b.preparedID) == 0 {
			f.writeByte(0)
			f.writeLongString(b.statement)
		} else {
			f.writeByte(1)
			f.writeShortBytes(b.preparedID)
		}

		f.writeShort(uint16(len(b.values)))
		for j := range b.values {
			col := b.values[j]
			if f.proto > protoVersion2 && col.name != "" {
				// TODO: move this check into the caller and set a flag on writeBatchFrame
				// to indicate using named values
				if f.proto <= protoVersion5 {
					return fmt.Errorf("gocql: named query values are not supported in batches, please see https://issues.apache.org/jira/browse/CASSANDRA-10246")
				}
				flags |= flagWithNameValues
				f.writeString(col.name)
			}
			if col.isUnset {
				f.writeUnset()
			} else {
				f.writeBytes(col.value)
			}
		}
	}

	f.writeConsistency(w.consistency)

	if f.proto > protoVersion2 {
		if w.serialConsistency > 0 {
			flags |= flagWithSerialConsistency
		}
		if w.defaultTimestamp {
			flags |= flagDefaultTimestamp
		}

		// v5 widened the flags field from a byte to a uint.
		if f.proto > protoVersion4 {
			f.writeUint(uint32(flags))
		} else {
			f.writeByte(flags)
		}

		if w.serialConsistency > 0 {
			f.writeConsistency(Consistency(w.serialConsistency))
		}

		if w.defaultTimestamp {
			// timestamp in microseconds
			var ts int64
			if w.defaultTimestampValue != 0 {
				ts = w.defaultTimestampValue
			} else {
				ts = time.Now().UnixNano() / 1000
			}
			f.writeLong(ts)
		}
	}

	return f.finish()
}
// readBytesInternal reads a [bytes] value: an [int] length n followed by n
// bytes; a negative n represents null (returned as a nil slice). The
// returned slice aliases f.buf — callers that retain it must copy.
func (f *framer) readBytesInternal() ([]byte, error) {
	size := f.readInt()
	if size < 0 {
		return nil, nil
	}

	if len(f.buf) < size {
		return nil, fmt.Errorf("not enough bytes in buffer to read bytes require %d got: %d", size, len(f.buf))
	}

	l := f.buf[:size]
	f.buf = f.buf[size:]

	return l, nil
}

// readBytes is readBytesInternal with the framer's panic-on-short-buffer
// error convention.
func (f *framer) readBytes() []byte {
	l, err := f.readBytesInternal()
	if err != nil {
		panic(err)
	}

	return l
}

// readShortBytes reads a [short bytes] value (a [short] length followed by
// that many bytes). The returned slice aliases f.buf.
func (f *framer) readShortBytes() []byte {
	size := f.readShort()
	if len(f.buf) < int(size) {
		panic(fmt.Errorf("not enough bytes in buffer to read short bytes: require %d got %d", size, len(f.buf)))
	}

	l := f.buf[:size]
	f.buf = f.buf[size:]

	return l
}

// readInetAdressOnly reads the address part of an [inet] value: a one-byte
// size (must be 4 or 16) followed by the IP bytes, which are copied so the
// result does not alias f.buf. (Method name misspelling is pre-existing;
// kept because callers reference it.)
func (f *framer) readInetAdressOnly() net.IP {
	if len(f.buf) < 1 {
		panic(fmt.Errorf("not enough bytes in buffer to read inet size require %d got: %d", 1, len(f.buf)))
	}

	size := f.buf[0]
	f.buf = f.buf[1:]

	if !(size == 4 || size == 16) {
		panic(fmt.Errorf("invalid IP size: %d", size))
	}

	if len(f.buf) < 1 {
		panic(fmt.Errorf("not enough bytes in buffer to read inet require %d got: %d", size, len(f.buf)))
	}

	ip := make([]byte, size)
	copy(ip, f.buf[:size])
	f.buf = f.buf[size:]
	return net.IP(ip)
}

// readInet reads a full [inet] value: address followed by an [int] port.
func (f *framer) readInet() (net.IP, int) {
	return f.readInetAdressOnly(), f.readInt()
}

// readConsistency reads a [short] and interprets it as a consistency level.
func (f *framer) readConsistency() Consistency {
	return Consistency(f.readShort())
}

// readBytesMap reads a [bytes map]: a [short] count of [string]->[bytes]
// pairs.
func (f *framer) readBytesMap() map[string][]byte {
	size := f.readShort()
	m := make(map[string][]byte, size)

	for i := 0; i < int(size); i++ {
		k := f.readString()
		v := f.readBytes()
		m[k] = v
	}

	return m
}
make(map[string][]string, size) for i := 0; i < int(size); i++ { k := f.readString() v := f.readStringList() m[k] = v } return m } func (f *framer) writeByte(b byte) { f.buf = append(f.buf, b) } func appendBytes(p []byte, d []byte) []byte { if d == nil { return appendInt(p, -1) } p = appendInt(p, int32(len(d))) p = append(p, d...) return p } func appendShort(p []byte, n uint16) []byte { return append(p, byte(n>>8), byte(n), ) } func appendInt(p []byte, n int32) []byte { return append(p, byte(n>>24), byte(n>>16), byte(n>>8), byte(n)) } func appendUint(p []byte, n uint32) []byte { return append(p, byte(n>>24), byte(n>>16), byte(n>>8), byte(n)) } func appendLong(p []byte, n int64) []byte { return append(p, byte(n>>56), byte(n>>48), byte(n>>40), byte(n>>32), byte(n>>24), byte(n>>16), byte(n>>8), byte(n), ) } func (f *framer) writeCustomPayload(customPayload *map[string][]byte) { if len(*customPayload) > 0 { if f.proto < protoVersion4 { panic("Custom payload is not supported with version V3 or less") } f.writeBytesMap(*customPayload) } } // these are protocol level binary types func (f *framer) writeInt(n int32) { f.buf = appendInt(f.buf, n) } func (f *framer) writeUint(n uint32) { f.buf = appendUint(f.buf, n) } func (f *framer) writeShort(n uint16) { f.buf = appendShort(f.buf, n) } func (f *framer) writeLong(n int64) { f.buf = appendLong(f.buf, n) } func (f *framer) writeString(s string) { f.writeShort(uint16(len(s))) f.buf = append(f.buf, s...) } func (f *framer) writeLongString(s string) { f.writeInt(int32(len(s))) f.buf = append(f.buf, s...) } func (f *framer) writeStringList(l []string) { f.writeShort(uint16(len(l))) for _, s := range l { f.writeString(s) } } func (f *framer) writeUnset() { // Protocol version 4 specifies that bind variables do not require having a // value when executing a statement. Bind variables without a value are // called 'unset'. The 'unset' bind variable is serialized as the int // value '-2' without following bytes. 
f.writeInt(-2) } func (f *framer) writeBytes(p []byte) { // TODO: handle null case correctly, // [bytes] A [int] n, followed by n bytes if n >= 0. If n < 0, // no byte should follow and the value represented is `null`. if p == nil { f.writeInt(-1) } else { f.writeInt(int32(len(p))) f.buf = append(f.buf, p...) } } func (f *framer) writeShortBytes(p []byte) { f.writeShort(uint16(len(p))) f.buf = append(f.buf, p...) } func (f *framer) writeConsistency(cons Consistency) { f.writeShort(uint16(cons)) } func (f *framer) writeStringMap(m map[string]string) { f.writeShort(uint16(len(m))) for k, v := range m { f.writeString(k) f.writeString(v) } } func (f *framer) writeBytesMap(m map[string][]byte) { f.writeShort(uint16(len(m))) for k, v := range m { f.writeString(k) f.writeBytes(v) } } cassandra-gocql-driver-1.7.0/frame_test.go000066400000000000000000000072261467504044300205500ustar00rootroot00000000000000/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. 
*/ package gocql import ( "bytes" "os" "testing" ) func TestFuzzBugs(t *testing.T) { // these inputs are found using go-fuzz (https://github.com/dvyukov/go-fuzz) // and should cause a panic unless fixed. tests := [][]byte{ []byte("00000\xa0000"), []byte("\x8000\x0e\x00\x00\x00\x000"), []byte("\x8000\x00\x00\x00\x00\t0000000000"), []byte("\xa0\xff\x01\xae\xefqE\xf2\x1a"), []byte("\x8200\b\x00\x00\x00c\x00\x00\x00\x02000\x01\x00\x00\x00\x03" + "\x00\n0000000000\x00\x14000000" + "00000000000000\x00\x020000" + "\x00\a000000000\x00\x050000000" + "\xff0000000000000000000" + "0000000"), []byte("\x82\xe600\x00\x00\x00\x000"), []byte("\x8200\b\x00\x00\x00\b0\x00\x00\x00\x040000"), []byte("\x8200\x00\x00\x00\x00\x100\x00\x00\x12\x00\x00\x0000000" + "00000"), []byte("\x83000\b\x00\x00\x00\x14\x00\x00\x00\x020000000" + "000000000"), []byte("\x83000\b\x00\x00\x000\x00\x00\x00\x04\x00\x1000000" + "00000000000000e00000" + "000\x800000000000000000" + "0000000000000"), } for i, test := range tests { t.Logf("test %d input: %q", i, test) r := bytes.NewReader(test) head, err := readHeader(r, make([]byte, 9)) if err != nil { continue } framer := newFramer(nil, byte(head.version)) err = framer.readFrame(r, &head) if err != nil { continue } frame, err := framer.parseFrame() if err != nil { continue } t.Errorf("(%d) expected to fail for input % X", i, test) t.Errorf("(%d) frame=%+#v", i, frame) } } func TestFrameWriteTooLong(t *testing.T) { if os.Getenv("TRAVIS") == "true" { t.Skip("skipping test in travis due to memory pressure with the race detecor") } framer := newFramer(nil, 2) framer.writeHeader(0, opStartup, 1) framer.writeBytes(make([]byte, maxFrameSize+1)) err := framer.finish() if err != ErrFrameTooBig { t.Fatalf("expected to get %v got %v", ErrFrameTooBig, err) } } func TestFrameReadTooLong(t *testing.T) { if os.Getenv("TRAVIS") == "true" { t.Skip("skipping test in travis due to memory pressure with the race detecor") } r := &bytes.Buffer{} r.Write(make([]byte, maxFrameSize+1)) 
// write a new header right after this frame to verify that we can read it r.Write([]byte{0x02, 0x00, 0x00, byte(opReady), 0x00, 0x00, 0x00, 0x00}) framer := newFramer(nil, 2) head := frameHeader{ version: 2, op: opReady, length: r.Len() - 8, } err := framer.readFrame(r, &head) if err != ErrFrameTooBig { t.Fatalf("expected to get %v got %v", ErrFrameTooBig, err) } head, err = readHeader(r, make([]byte, 8)) if err != nil { t.Fatal(err) } if head.op != opReady { t.Fatalf("expected to get header %v got %v", opReady, head.op) } } cassandra-gocql-driver-1.7.0/framer_bench_test.go000066400000000000000000000033671467504044300220730ustar00rootroot00000000000000/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. 
*/ package gocql import ( "compress/gzip" "io/ioutil" "os" "testing" ) func readGzipData(path string) ([]byte, error) { f, err := os.Open(path) if err != nil { return nil, err } defer f.Close() r, err := gzip.NewReader(f) if err != nil { return nil, err } defer r.Close() return ioutil.ReadAll(r) } func BenchmarkParseRowsFrame(b *testing.B) { data, err := readGzipData("testdata/frames/bench_parse_result.gz") if err != nil { b.Fatal(err) } b.ResetTimer() for i := 0; i < b.N; i++ { framer := &framer{ header: &frameHeader{ version: protoVersion4 | 0x80, op: opResult, length: len(data), }, buf: data, } _, err = framer.parseFrame() if err != nil { b.Fatal(err) } } } cassandra-gocql-driver-1.7.0/fuzz.go000066400000000000000000000027331467504044300174130ustar00rootroot00000000000000//go:build gofuzz // +build gofuzz /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. 
*/ package gocql import "bytes" func Fuzz(data []byte) int { var bw bytes.Buffer r := bytes.NewReader(data) head, err := readHeader(r, make([]byte, 9)) if err != nil { return 0 } framer := newFramer(r, &bw, nil, byte(head.version)) err = framer.readFrame(&head) if err != nil { return 0 } frame, err := framer.parseFrame() if err != nil { return 0 } if frame != nil { return 1 } return 2 } cassandra-gocql-driver-1.7.0/go.mod000066400000000000000000000023101467504044300171630ustar00rootroot00000000000000// // Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file // to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
// module github.com/gocql/gocql require ( github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932 // indirect github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 // indirect github.com/golang/snappy v0.0.3 github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed github.com/kr/pretty v0.1.0 // indirect github.com/stretchr/testify v1.3.0 // indirect gopkg.in/inf.v0 v0.9.1 ) go 1.13 cassandra-gocql-driver-1.7.0/go.sum000066400000000000000000000037371467504044300172260ustar00rootroot00000000000000github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932 h1:mXoPYz/Ul5HYEDvkta6I8/rnYM5gSdSV2tJ6XbZuEtY= github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932/go.mod h1:NOuUCSz6Q9T7+igc/hlvDOUdtWKryOrtFyIVABv/p7k= github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 h1:DDGfHa7BWjL4YnC6+E63dPcxHo2sUxDIu8g3QgEJdRY= github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4= github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/golang/snappy v0.0.3 h1:fHPg5GQYlCeLIPB9BZqMVR5nR9A+IM5zcgeTdjMYmLA= github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed h1:5upAirOpQc1Q53c0bnx2ufif5kANL7bfZWcc6VJWJd8= github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed/go.mod h1:tMWxXQ9wFIaZeTI9F+hmhFiGpFmhOHzyShyFUhRm0H4= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/pmezard/go-difflib v1.0.0 
h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc= gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= cassandra-gocql-driver-1.7.0/helpers.go000066400000000000000000000273171467504044300200640ustar00rootroot00000000000000/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2012, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. 
*/ package gocql import ( "fmt" "math/big" "net" "reflect" "strings" "time" "gopkg.in/inf.v0" ) type RowData struct { Columns []string Values []interface{} } func goType(t TypeInfo) (reflect.Type, error) { switch t.Type() { case TypeVarchar, TypeAscii, TypeInet, TypeText: return reflect.TypeOf(*new(string)), nil case TypeBigInt, TypeCounter: return reflect.TypeOf(*new(int64)), nil case TypeTime: return reflect.TypeOf(*new(time.Duration)), nil case TypeTimestamp: return reflect.TypeOf(*new(time.Time)), nil case TypeBlob: return reflect.TypeOf(*new([]byte)), nil case TypeBoolean: return reflect.TypeOf(*new(bool)), nil case TypeFloat: return reflect.TypeOf(*new(float32)), nil case TypeDouble: return reflect.TypeOf(*new(float64)), nil case TypeInt: return reflect.TypeOf(*new(int)), nil case TypeSmallInt: return reflect.TypeOf(*new(int16)), nil case TypeTinyInt: return reflect.TypeOf(*new(int8)), nil case TypeDecimal: return reflect.TypeOf(*new(*inf.Dec)), nil case TypeUUID, TypeTimeUUID: return reflect.TypeOf(*new(UUID)), nil case TypeList, TypeSet: elemType, err := goType(t.(CollectionType).Elem) if err != nil { return nil, err } return reflect.SliceOf(elemType), nil case TypeMap: keyType, err := goType(t.(CollectionType).Key) if err != nil { return nil, err } valueType, err := goType(t.(CollectionType).Elem) if err != nil { return nil, err } return reflect.MapOf(keyType, valueType), nil case TypeVarint: return reflect.TypeOf(*new(*big.Int)), nil case TypeTuple: // what can we do here? 
all there is to do is to make a list of interface{} tuple := t.(TupleTypeInfo) return reflect.TypeOf(make([]interface{}, len(tuple.Elems))), nil case TypeUDT: return reflect.TypeOf(make(map[string]interface{})), nil case TypeDate: return reflect.TypeOf(*new(time.Time)), nil case TypeDuration: return reflect.TypeOf(*new(Duration)), nil default: return nil, fmt.Errorf("cannot create Go type for unknown CQL type %s", t) } } func dereference(i interface{}) interface{} { return reflect.Indirect(reflect.ValueOf(i)).Interface() } func getCassandraBaseType(name string) Type { switch name { case "ascii": return TypeAscii case "bigint": return TypeBigInt case "blob": return TypeBlob case "boolean": return TypeBoolean case "counter": return TypeCounter case "date": return TypeDate case "decimal": return TypeDecimal case "double": return TypeDouble case "duration": return TypeDuration case "float": return TypeFloat case "int": return TypeInt case "smallint": return TypeSmallInt case "tinyint": return TypeTinyInt case "time": return TypeTime case "timestamp": return TypeTimestamp case "uuid": return TypeUUID case "varchar": return TypeVarchar case "text": return TypeText case "varint": return TypeVarint case "timeuuid": return TypeTimeUUID case "inet": return TypeInet case "MapType": return TypeMap case "ListType": return TypeList case "SetType": return TypeSet case "TupleType": return TypeTuple default: return TypeCustom } } func getCassandraType(name string, logger StdLogger) TypeInfo { if strings.HasPrefix(name, "frozen<") { return getCassandraType(strings.TrimPrefix(name[:len(name)-1], "frozen<"), logger) } else if strings.HasPrefix(name, "set<") { return CollectionType{ NativeType: NativeType{typ: TypeSet}, Elem: getCassandraType(strings.TrimPrefix(name[:len(name)-1], "set<"), logger), } } else if strings.HasPrefix(name, "list<") { return CollectionType{ NativeType: NativeType{typ: TypeList}, Elem: getCassandraType(strings.TrimPrefix(name[:len(name)-1], "list<"), logger), 
} } else if strings.HasPrefix(name, "map<") { names := splitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "map<")) if len(names) != 2 { logger.Printf("Error parsing map type, it has %d subelements, expecting 2\n", len(names)) return NativeType{ typ: TypeCustom, } } return CollectionType{ NativeType: NativeType{typ: TypeMap}, Key: getCassandraType(names[0], logger), Elem: getCassandraType(names[1], logger), } } else if strings.HasPrefix(name, "tuple<") { names := splitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "tuple<")) types := make([]TypeInfo, len(names)) for i, name := range names { types[i] = getCassandraType(name, logger) } return TupleTypeInfo{ NativeType: NativeType{typ: TypeTuple}, Elems: types, } } else { return NativeType{ typ: getCassandraBaseType(name), } } } func splitCompositeTypes(name string) []string { if !strings.Contains(name, "<") { return strings.Split(name, ", ") } var parts []string lessCount := 0 segment := "" for _, char := range name { if char == ',' && lessCount == 0 { if segment != "" { parts = append(parts, strings.TrimSpace(segment)) } segment = "" continue } segment += string(char) if char == '<' { lessCount++ } else if char == '>' { lessCount-- } } if segment != "" { parts = append(parts, strings.TrimSpace(segment)) } return parts } func apacheToCassandraType(t string) string { t = strings.Replace(t, apacheCassandraTypePrefix, "", -1) t = strings.Replace(t, "(", "<", -1) t = strings.Replace(t, ")", ">", -1) types := strings.FieldsFunc(t, func(r rune) bool { return r == '<' || r == '>' || r == ',' }) for _, typ := range types { t = strings.Replace(t, typ, getApacheCassandraType(typ).String(), -1) } // This is done so it exactly matches what Cassandra returns return strings.Replace(t, ",", ", ", -1) } func getApacheCassandraType(class string) Type { switch strings.TrimPrefix(class, apacheCassandraTypePrefix) { case "AsciiType": return TypeAscii case "LongType": return TypeBigInt case "BytesType": return TypeBlob case 
"BooleanType": return TypeBoolean case "CounterColumnType": return TypeCounter case "DecimalType": return TypeDecimal case "DoubleType": return TypeDouble case "FloatType": return TypeFloat case "Int32Type": return TypeInt case "ShortType": return TypeSmallInt case "ByteType": return TypeTinyInt case "TimeType": return TypeTime case "DateType", "TimestampType": return TypeTimestamp case "UUIDType", "LexicalUUIDType": return TypeUUID case "UTF8Type": return TypeVarchar case "IntegerType": return TypeVarint case "TimeUUIDType": return TypeTimeUUID case "InetAddressType": return TypeInet case "MapType": return TypeMap case "ListType": return TypeList case "SetType": return TypeSet case "TupleType": return TypeTuple case "DurationType": return TypeDuration default: return TypeCustom } } func (r *RowData) rowMap(m map[string]interface{}) { for i, column := range r.Columns { val := dereference(r.Values[i]) if valVal := reflect.ValueOf(val); valVal.Kind() == reflect.Slice { valCopy := reflect.MakeSlice(valVal.Type(), valVal.Len(), valVal.Cap()) reflect.Copy(valCopy, valVal) m[column] = valCopy.Interface() } else { m[column] = val } } } // TupeColumnName will return the column name of a tuple value in a column named // c at index n. It should be used if a specific element within a tuple is needed // to be extracted from a map returned from SliceMap or MapScan. 
func TupleColumnName(c string, n int) string { return fmt.Sprintf("%s[%d]", c, n) } func (iter *Iter) RowData() (RowData, error) { if iter.err != nil { return RowData{}, iter.err } columns := make([]string, 0, len(iter.Columns())) values := make([]interface{}, 0, len(iter.Columns())) for _, column := range iter.Columns() { if c, ok := column.TypeInfo.(TupleTypeInfo); !ok { val, err := column.TypeInfo.NewWithError() if err != nil { return RowData{}, err } columns = append(columns, column.Name) values = append(values, val) } else { for i, elem := range c.Elems { columns = append(columns, TupleColumnName(column.Name, i)) val, err := elem.NewWithError() if err != nil { return RowData{}, err } values = append(values, val) } } } rowData := RowData{ Columns: columns, Values: values, } return rowData, nil } // TODO(zariel): is it worth exporting this? func (iter *Iter) rowMap() (map[string]interface{}, error) { if iter.err != nil { return nil, iter.err } rowData, _ := iter.RowData() iter.Scan(rowData.Values...) m := make(map[string]interface{}, len(rowData.Columns)) rowData.rowMap(m) return m, nil } // SliceMap is a helper function to make the API easier to use // returns the data from the query in the form of []map[string]interface{} func (iter *Iter) SliceMap() ([]map[string]interface{}, error) { if iter.err != nil { return nil, iter.err } // Not checking for the error because we just did rowData, _ := iter.RowData() dataToReturn := make([]map[string]interface{}, 0) for iter.Scan(rowData.Values...) { m := make(map[string]interface{}, len(rowData.Columns)) rowData.rowMap(m) dataToReturn = append(dataToReturn, m) } if iter.err != nil { return nil, iter.err } return dataToReturn, nil } // MapScan takes a map[string]interface{} and populates it with a row // that is returned from cassandra. // // Each call to MapScan() must be called with a new map object. 
// During the call to MapScan() any pointers in the existing map // are replaced with non pointer types before the call returns // // iter := session.Query(`SELECT * FROM mytable`).Iter() // for { // // New map each iteration // row := make(map[string]interface{}) // if !iter.MapScan(row) { // break // } // // Do things with row // if fullname, ok := row["fullname"]; ok { // fmt.Printf("Full Name: %s\n", fullname) // } // } // // You can also pass pointers in the map before each call // // var fullName FullName // Implements gocql.Unmarshaler and gocql.Marshaler interfaces // var address net.IP // var age int // iter := session.Query(`SELECT * FROM scan_map_table`).Iter() // for { // // New map each iteration // row := map[string]interface{}{ // "fullname": &fullName, // "age": &age, // "address": &address, // } // if !iter.MapScan(row) { // break // } // fmt.Printf("First: %s Age: %d Address: %q\n", fullName.FirstName, age, address) // } func (iter *Iter) MapScan(m map[string]interface{}) bool { if iter.err != nil { return false } // Not checking for the error because we just did rowData, _ := iter.RowData() for i, col := range rowData.Columns { if dest, ok := m[col]; ok { rowData.Values[i] = dest } } if iter.Scan(rowData.Values...) { rowData.rowMap(m) return true } return false } func copyBytes(p []byte) []byte { b := make([]byte, len(p)) copy(b, p) return b } var failDNS = false func LookupIP(host string) ([]net.IP, error) { if failDNS { return nil, &net.DNSError{} } return net.LookupIP(host) } cassandra-gocql-driver-1.7.0/helpers_test.go000066400000000000000000000141631467504044300211160ustar00rootroot00000000000000/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. 
The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. */ package gocql import ( "reflect" "testing" ) func TestGetCassandraType_Set(t *testing.T) { typ := getCassandraType("set", &defaultLogger{}) set, ok := typ.(CollectionType) if !ok { t.Fatalf("expected CollectionType got %T", typ) } else if set.typ != TypeSet { t.Fatalf("expected type %v got %v", TypeSet, set.typ) } inner, ok := set.Elem.(NativeType) if !ok { t.Fatalf("expected to get NativeType got %T", set.Elem) } else if inner.typ != TypeText { t.Fatalf("expected to get %v got %v for set value", TypeText, set.typ) } } func TestGetCassandraType(t *testing.T) { tests := []struct { input string exp TypeInfo }{ { "set", CollectionType{ NativeType: NativeType{typ: TypeSet}, Elem: NativeType{typ: TypeText}, }, }, { "map", CollectionType{ NativeType: NativeType{typ: TypeMap}, Key: NativeType{typ: TypeText}, Elem: NativeType{typ: TypeVarchar}, }, }, { "list", CollectionType{ NativeType: NativeType{typ: TypeList}, Elem: NativeType{typ: TypeInt}, }, }, { "tuple", TupleTypeInfo{ NativeType: NativeType{typ: TypeTuple}, Elems: []TypeInfo{ NativeType{typ: TypeInt}, NativeType{typ: TypeInt}, NativeType{typ: TypeText}, }, }, }, { "frozen>>>>>", CollectionType{ NativeType: NativeType{typ: 
TypeMap}, Key: NativeType{typ: TypeText}, Elem: CollectionType{ NativeType: NativeType{typ: TypeList}, Elem: TupleTypeInfo{ NativeType: NativeType{typ: TypeTuple}, Elems: []TypeInfo{ NativeType{typ: TypeInt}, NativeType{typ: TypeInt}, }, }, }, }, }, { "frozen>>>>>, frozen>>>>>, frozen>>>>>>>", TupleTypeInfo{ NativeType: NativeType{typ: TypeTuple}, Elems: []TypeInfo{ TupleTypeInfo{ NativeType: NativeType{typ: TypeTuple}, Elems: []TypeInfo{ NativeType{typ: TypeText}, CollectionType{ NativeType: NativeType{typ: TypeList}, Elem: TupleTypeInfo{ NativeType: NativeType{typ: TypeTuple}, Elems: []TypeInfo{ NativeType{typ: TypeInt}, NativeType{typ: TypeInt}, }, }, }, }, }, TupleTypeInfo{ NativeType: NativeType{typ: TypeTuple}, Elems: []TypeInfo{ NativeType{typ: TypeText}, CollectionType{ NativeType: NativeType{typ: TypeList}, Elem: TupleTypeInfo{ NativeType: NativeType{typ: TypeTuple}, Elems: []TypeInfo{ NativeType{typ: TypeInt}, NativeType{typ: TypeInt}, }, }, }, }, }, CollectionType{ NativeType: NativeType{typ: TypeMap}, Key: NativeType{typ: TypeText}, Elem: CollectionType{ NativeType: NativeType{typ: TypeList}, Elem: TupleTypeInfo{ NativeType: NativeType{typ: TypeTuple}, Elems: []TypeInfo{ NativeType{typ: TypeInt}, NativeType{typ: TypeInt}, }, }, }, }, }, }, }, { "frozen>, int, frozen>>>", TupleTypeInfo{ NativeType: NativeType{typ: TypeTuple}, Elems: []TypeInfo{ TupleTypeInfo{ NativeType: NativeType{typ: TypeTuple}, Elems: []TypeInfo{ NativeType{typ: TypeInt}, NativeType{typ: TypeInt}, }, }, NativeType{typ: TypeInt}, TupleTypeInfo{ NativeType: NativeType{typ: TypeTuple}, Elems: []TypeInfo{ NativeType{typ: TypeInt}, NativeType{typ: TypeInt}, }, }, }, }, }, { "frozen>, int>>", CollectionType{ NativeType: NativeType{typ: TypeMap}, Key: TupleTypeInfo{ NativeType: NativeType{typ: TypeTuple}, Elems: []TypeInfo{ NativeType{typ: TypeInt}, NativeType{typ: TypeInt}, }, }, Elem: NativeType{typ: TypeInt}, }, }, { "set", CollectionType{ NativeType: NativeType{typ: TypeSet}, Elem: 
NativeType{typ: TypeSmallInt}, }, }, { "list", CollectionType{ NativeType: NativeType{typ: TypeList}, Elem: NativeType{typ: TypeTinyInt}, }, }, {"smallint", NativeType{typ: TypeSmallInt}}, {"tinyint", NativeType{typ: TypeTinyInt}}, {"duration", NativeType{typ: TypeDuration}}, {"date", NativeType{typ: TypeDate}}, { "list", CollectionType{ NativeType: NativeType{typ: TypeList}, Elem: NativeType{typ: TypeDate}, }, }, { "set", CollectionType{ NativeType: NativeType{typ: TypeSet}, Elem: NativeType{typ: TypeDuration}, }, }, } for _, test := range tests { t.Run(test.input, func(t *testing.T) { got := getCassandraType(test.input, &defaultLogger{}) // TODO(zariel): define an equal method on the types? if !reflect.DeepEqual(got, test.exp) { t.Fatalf("expected %v got %v", test.exp, got) } }) } } cassandra-gocql-driver-1.7.0/host_source.go000066400000000000000000000525301467504044300207520ustar00rootroot00000000000000/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. 
*/ package gocql import ( "context" "errors" "fmt" "net" "strconv" "strings" "sync" "time" ) var ErrCannotFindHost = errors.New("cannot find host") var ErrHostAlreadyExists = errors.New("host already exists") type nodeState int32 func (n nodeState) String() string { if n == NodeUp { return "UP" } else if n == NodeDown { return "DOWN" } return fmt.Sprintf("UNKNOWN_%d", n) } const ( NodeUp nodeState = iota NodeDown ) type cassVersion struct { Major, Minor, Patch int } func (c *cassVersion) Set(v string) error { if v == "" { return nil } return c.UnmarshalCQL(nil, []byte(v)) } func (c *cassVersion) UnmarshalCQL(info TypeInfo, data []byte) error { return c.unmarshal(data) } func (c *cassVersion) unmarshal(data []byte) error { version := strings.TrimSuffix(string(data), "-SNAPSHOT") version = strings.TrimPrefix(version, "v") v := strings.Split(version, ".") if len(v) < 2 { return fmt.Errorf("invalid version string: %s", data) } var err error c.Major, err = strconv.Atoi(v[0]) if err != nil { return fmt.Errorf("invalid major version %v: %v", v[0], err) } c.Minor, err = strconv.Atoi(v[1]) if err != nil { return fmt.Errorf("invalid minor version %v: %v", v[1], err) } if len(v) > 2 { c.Patch, err = strconv.Atoi(v[2]) if err != nil { return fmt.Errorf("invalid patch version %v: %v", v[2], err) } } return nil } func (c cassVersion) Before(major, minor, patch int) bool { // We're comparing us (cassVersion) with the provided version (major, minor, patch) // We return true if our version is lower (comes before) than the provided one. 
if c.Major < major { return true } else if c.Major == major { if c.Minor < minor { return true } else if c.Minor == minor && c.Patch < patch { return true } } return false } func (c cassVersion) AtLeast(major, minor, patch int) bool { return !c.Before(major, minor, patch) } func (c cassVersion) String() string { return fmt.Sprintf("v%d.%d.%d", c.Major, c.Minor, c.Patch) } func (c cassVersion) nodeUpDelay() time.Duration { if c.Major >= 2 && c.Minor >= 2 { // CASSANDRA-8236 return 0 } return 10 * time.Second } type HostInfo struct { // TODO(zariel): reduce locking maybe, not all values will change, but to ensure // that we are thread safe use a mutex to access all fields. mu sync.RWMutex hostname string peer net.IP broadcastAddress net.IP listenAddress net.IP rpcAddress net.IP preferredIP net.IP connectAddress net.IP port int dataCenter string rack string hostId string workload string graph bool dseVersion string partitioner string clusterName string version cassVersion state nodeState schemaVersion string tokens []string } func (h *HostInfo) Equal(host *HostInfo) bool { if h == host { // prevent rlock reentry return true } return h.ConnectAddress().Equal(host.ConnectAddress()) } func (h *HostInfo) Peer() net.IP { h.mu.RLock() defer h.mu.RUnlock() return h.peer } func (h *HostInfo) invalidConnectAddr() bool { h.mu.RLock() defer h.mu.RUnlock() addr, _ := h.connectAddressLocked() return !validIpAddr(addr) } func validIpAddr(addr net.IP) bool { return addr != nil && !addr.IsUnspecified() } func (h *HostInfo) connectAddressLocked() (net.IP, string) { if validIpAddr(h.connectAddress) { return h.connectAddress, "connect_address" } else if validIpAddr(h.rpcAddress) { return h.rpcAddress, "rpc_adress" } else if validIpAddr(h.preferredIP) { // where does perferred_ip get set? 
return h.preferredIP, "preferred_ip" } else if validIpAddr(h.broadcastAddress) { return h.broadcastAddress, "broadcast_address" } else if validIpAddr(h.peer) { return h.peer, "peer" } return net.IPv4zero, "invalid" } // nodeToNodeAddress returns address broadcasted between node to nodes. // It's either `broadcast_address` if host info is read from system.local or `peer` if read from system.peers. // This IP address is also part of CQL Event emitted on topology/status changes, // but does not uniquely identify the node in case multiple nodes use the same IP address. func (h *HostInfo) nodeToNodeAddress() net.IP { h.mu.RLock() defer h.mu.RUnlock() if validIpAddr(h.broadcastAddress) { return h.broadcastAddress } else if validIpAddr(h.peer) { return h.peer } return net.IPv4zero } // Returns the address that should be used to connect to the host. // If you wish to override this, use an AddressTranslator or // use a HostFilter to SetConnectAddress() func (h *HostInfo) ConnectAddress() net.IP { h.mu.RLock() defer h.mu.RUnlock() if addr, _ := h.connectAddressLocked(); validIpAddr(addr) { return addr } panic(fmt.Sprintf("no valid connect address for host: %v. Is your cluster configured correctly?", h)) } func (h *HostInfo) SetConnectAddress(address net.IP) *HostInfo { // TODO(zariel): should this not be exported? 
h.mu.Lock() defer h.mu.Unlock() h.connectAddress = address return h } func (h *HostInfo) BroadcastAddress() net.IP { h.mu.RLock() defer h.mu.RUnlock() return h.broadcastAddress } func (h *HostInfo) ListenAddress() net.IP { h.mu.RLock() defer h.mu.RUnlock() return h.listenAddress } func (h *HostInfo) RPCAddress() net.IP { h.mu.RLock() defer h.mu.RUnlock() return h.rpcAddress } func (h *HostInfo) PreferredIP() net.IP { h.mu.RLock() defer h.mu.RUnlock() return h.preferredIP } func (h *HostInfo) DataCenter() string { h.mu.RLock() dc := h.dataCenter h.mu.RUnlock() return dc } func (h *HostInfo) Rack() string { h.mu.RLock() rack := h.rack h.mu.RUnlock() return rack } func (h *HostInfo) HostID() string { h.mu.RLock() defer h.mu.RUnlock() return h.hostId } func (h *HostInfo) SetHostID(hostID string) { h.mu.Lock() defer h.mu.Unlock() h.hostId = hostID } func (h *HostInfo) WorkLoad() string { h.mu.RLock() defer h.mu.RUnlock() return h.workload } func (h *HostInfo) Graph() bool { h.mu.RLock() defer h.mu.RUnlock() return h.graph } func (h *HostInfo) DSEVersion() string { h.mu.RLock() defer h.mu.RUnlock() return h.dseVersion } func (h *HostInfo) Partitioner() string { h.mu.RLock() defer h.mu.RUnlock() return h.partitioner } func (h *HostInfo) ClusterName() string { h.mu.RLock() defer h.mu.RUnlock() return h.clusterName } func (h *HostInfo) Version() cassVersion { h.mu.RLock() defer h.mu.RUnlock() return h.version } func (h *HostInfo) State() nodeState { h.mu.RLock() defer h.mu.RUnlock() return h.state } func (h *HostInfo) setState(state nodeState) *HostInfo { h.mu.Lock() defer h.mu.Unlock() h.state = state return h } func (h *HostInfo) Tokens() []string { h.mu.RLock() defer h.mu.RUnlock() return h.tokens } func (h *HostInfo) Port() int { h.mu.RLock() defer h.mu.RUnlock() return h.port } func (h *HostInfo) update(from *HostInfo) { if h == from { return } h.mu.Lock() defer h.mu.Unlock() from.mu.RLock() defer from.mu.RUnlock() // autogenerated do not update if h.peer == nil { 
h.peer = from.peer } if h.broadcastAddress == nil { h.broadcastAddress = from.broadcastAddress } if h.listenAddress == nil { h.listenAddress = from.listenAddress } if h.rpcAddress == nil { h.rpcAddress = from.rpcAddress } if h.preferredIP == nil { h.preferredIP = from.preferredIP } if h.connectAddress == nil { h.connectAddress = from.connectAddress } if h.port == 0 { h.port = from.port } if h.dataCenter == "" { h.dataCenter = from.dataCenter } if h.rack == "" { h.rack = from.rack } if h.hostId == "" { h.hostId = from.hostId } if h.workload == "" { h.workload = from.workload } if h.dseVersion == "" { h.dseVersion = from.dseVersion } if h.partitioner == "" { h.partitioner = from.partitioner } if h.clusterName == "" { h.clusterName = from.clusterName } if h.version == (cassVersion{}) { h.version = from.version } if h.tokens == nil { h.tokens = from.tokens } } func (h *HostInfo) IsUp() bool { return h != nil && h.State() == NodeUp } func (h *HostInfo) HostnameAndPort() string { h.mu.Lock() defer h.mu.Unlock() if h.hostname == "" { addr, _ := h.connectAddressLocked() h.hostname = addr.String() } return net.JoinHostPort(h.hostname, strconv.Itoa(h.port)) } func (h *HostInfo) ConnectAddressAndPort() string { h.mu.Lock() defer h.mu.Unlock() addr, _ := h.connectAddressLocked() return net.JoinHostPort(addr.String(), strconv.Itoa(h.port)) } func (h *HostInfo) String() string { h.mu.RLock() defer h.mu.RUnlock() connectAddr, source := h.connectAddressLocked() return fmt.Sprintf("[HostInfo hostname=%q connectAddress=%q peer=%q rpc_address=%q broadcast_address=%q "+ "preferred_ip=%q connect_addr=%q connect_addr_source=%q "+ "port=%d data_centre=%q rack=%q host_id=%q version=%q state=%s num_tokens=%d]", h.hostname, h.connectAddress, h.peer, h.rpcAddress, h.broadcastAddress, h.preferredIP, connectAddr, source, h.port, h.dataCenter, h.rack, h.hostId, h.version, h.state, len(h.tokens)) } // Polls system.peers at a specific interval to find new hosts type ringDescriber struct { session 
*Session mu sync.Mutex prevHosts []*HostInfo prevPartitioner string } // Returns true if we are using system_schema.keyspaces instead of system.schema_keyspaces func checkSystemSchema(control *controlConn) (bool, error) { iter := control.query("SELECT * FROM system_schema.keyspaces") if err := iter.err; err != nil { if errf, ok := err.(*errorFrame); ok { if errf.code == ErrCodeSyntax { return false, nil } } return false, err } return true, nil } // Given a map that represents a row from either system.local or system.peers // return as much information as we can in *HostInfo func (s *Session) hostInfoFromMap(row map[string]interface{}, host *HostInfo) (*HostInfo, error) { const assertErrorMsg = "Assertion failed for %s" var ok bool // Default to our connected port if the cluster doesn't have port information for key, value := range row { switch key { case "data_center": host.dataCenter, ok = value.(string) if !ok { return nil, fmt.Errorf(assertErrorMsg, "data_center") } case "rack": host.rack, ok = value.(string) if !ok { return nil, fmt.Errorf(assertErrorMsg, "rack") } case "host_id": hostId, ok := value.(UUID) if !ok { return nil, fmt.Errorf(assertErrorMsg, "host_id") } host.hostId = hostId.String() case "release_version": version, ok := value.(string) if !ok { return nil, fmt.Errorf(assertErrorMsg, "release_version") } host.version.Set(version) case "peer": ip, ok := value.(string) if !ok { return nil, fmt.Errorf(assertErrorMsg, "peer") } host.peer = net.ParseIP(ip) case "cluster_name": host.clusterName, ok = value.(string) if !ok { return nil, fmt.Errorf(assertErrorMsg, "cluster_name") } case "partitioner": host.partitioner, ok = value.(string) if !ok { return nil, fmt.Errorf(assertErrorMsg, "partitioner") } case "broadcast_address": ip, ok := value.(string) if !ok { return nil, fmt.Errorf(assertErrorMsg, "broadcast_address") } host.broadcastAddress = net.ParseIP(ip) case "preferred_ip": ip, ok := value.(string) if !ok { return nil, fmt.Errorf(assertErrorMsg, 
"preferred_ip") } host.preferredIP = net.ParseIP(ip) case "rpc_address": ip, ok := value.(string) if !ok { return nil, fmt.Errorf(assertErrorMsg, "rpc_address") } host.rpcAddress = net.ParseIP(ip) case "native_address": ip, ok := value.(string) if !ok { return nil, fmt.Errorf(assertErrorMsg, "native_address") } host.rpcAddress = net.ParseIP(ip) case "listen_address": ip, ok := value.(string) if !ok { return nil, fmt.Errorf(assertErrorMsg, "listen_address") } host.listenAddress = net.ParseIP(ip) case "native_port": native_port, ok := value.(int) if !ok { return nil, fmt.Errorf(assertErrorMsg, "native_port") } host.port = native_port case "workload": host.workload, ok = value.(string) if !ok { return nil, fmt.Errorf(assertErrorMsg, "workload") } case "graph": host.graph, ok = value.(bool) if !ok { return nil, fmt.Errorf(assertErrorMsg, "graph") } case "tokens": host.tokens, ok = value.([]string) if !ok { return nil, fmt.Errorf(assertErrorMsg, "tokens") } case "dse_version": host.dseVersion, ok = value.(string) if !ok { return nil, fmt.Errorf(assertErrorMsg, "dse_version") } case "schema_version": schemaVersion, ok := value.(UUID) if !ok { return nil, fmt.Errorf(assertErrorMsg, "schema_version") } host.schemaVersion = schemaVersion.String() } // TODO(thrawn01): Add 'port'? 
once CASSANDRA-7544 is complete // Not sure what the port field will be called until the JIRA issue is complete } ip, port := s.cfg.translateAddressPort(host.ConnectAddress(), host.port) host.connectAddress = ip host.port = port return host, nil } func (s *Session) hostInfoFromIter(iter *Iter, connectAddress net.IP, defaultPort int) (*HostInfo, error) { rows, err := iter.SliceMap() if err != nil { // TODO(zariel): make typed error return nil, err } if len(rows) == 0 { return nil, errors.New("query returned 0 rows") } host, err := s.hostInfoFromMap(rows[0], &HostInfo{connectAddress: connectAddress, port: defaultPort}) if err != nil { return nil, err } return host, nil } // Ask the control node for the local host info func (r *ringDescriber) getLocalHostInfo() (*HostInfo, error) { if r.session.control == nil { return nil, errNoControl } iter := r.session.control.withConnHost(func(ch *connHost) *Iter { return ch.conn.querySystemLocal(context.TODO()) }) if iter == nil { return nil, errNoControl } host, err := r.session.hostInfoFromIter(iter, nil, r.session.cfg.Port) if err != nil { return nil, fmt.Errorf("could not retrieve local host info: %w", err) } return host, nil } // Ask the control node for host info on all it's known peers func (r *ringDescriber) getClusterPeerInfo(localHost *HostInfo) ([]*HostInfo, error) { if r.session.control == nil { return nil, errNoControl } var peers []*HostInfo iter := r.session.control.withConnHost(func(ch *connHost) *Iter { return ch.conn.querySystemPeers(context.TODO(), localHost.version) }) if iter == nil { return nil, errNoControl } rows, err := iter.SliceMap() if err != nil { // TODO(zariel): make typed error return nil, fmt.Errorf("unable to fetch peer host info: %s", err) } for _, row := range rows { // extract all available info about the peer host, err := r.session.hostInfoFromMap(row, &HostInfo{port: r.session.cfg.Port}) if err != nil { return nil, err } else if !isValidPeer(host) { // If it's not a valid peer 
r.session.logger.Printf("Found invalid peer '%s' "+ "Likely due to a gossip or snitch issue, this host will be ignored", host) continue } peers = append(peers, host) } return peers, nil } // Return true if the host is a valid peer func isValidPeer(host *HostInfo) bool { return !(len(host.RPCAddress()) == 0 || host.hostId == "" || host.dataCenter == "" || host.rack == "" || len(host.tokens) == 0) } // GetHosts returns a list of hosts found via queries to system.local and system.peers func (r *ringDescriber) GetHosts() ([]*HostInfo, string, error) { r.mu.Lock() defer r.mu.Unlock() localHost, err := r.getLocalHostInfo() if err != nil { return r.prevHosts, r.prevPartitioner, err } peerHosts, err := r.getClusterPeerInfo(localHost) if err != nil { return r.prevHosts, r.prevPartitioner, err } hosts := append([]*HostInfo{localHost}, peerHosts...) var partitioner string if len(hosts) > 0 { partitioner = hosts[0].Partitioner() } return hosts, partitioner, nil } // debounceRingRefresh submits a ring refresh request to the ring refresh debouncer. func (s *Session) debounceRingRefresh() { s.ringRefresher.debounce() } // refreshRing executes a ring refresh immediately and cancels pending debounce ring refresh requests. 
func (s *Session) refreshRing() error { err, ok := <-s.ringRefresher.refreshNow() if !ok { return errors.New("could not refresh ring because stop was requested") } return err } func refreshRing(r *ringDescriber) error { hosts, partitioner, err := r.GetHosts() if err != nil { return err } prevHosts := r.session.ring.currentHosts() for _, h := range hosts { if r.session.cfg.filterHost(h) { continue } if host, ok := r.session.ring.addHostIfMissing(h); !ok { r.session.startPoolFill(h) } else { // host (by hostID) already exists; determine if IP has changed newHostID := h.HostID() existing, ok := prevHosts[newHostID] if !ok { return fmt.Errorf("get existing host=%s from prevHosts: %w", h, ErrCannotFindHost) } if h.connectAddress.Equal(existing.connectAddress) && h.nodeToNodeAddress().Equal(existing.nodeToNodeAddress()) { // no host IP change host.update(h) } else { // host IP has changed // remove old HostInfo (w/old IP) r.session.removeHost(existing) if _, alreadyExists := r.session.ring.addHostIfMissing(h); alreadyExists { return fmt.Errorf("add new host=%s after removal: %w", h, ErrHostAlreadyExists) } // add new HostInfo (same hostID, new IP) r.session.startPoolFill(h) } } delete(prevHosts, h.HostID()) } for _, host := range prevHosts { r.session.removeHost(host) } r.session.metadata.setPartitioner(partitioner) r.session.policy.SetPartitioner(partitioner) return nil } const ( ringRefreshDebounceTime = 1 * time.Second ) // debounces requests to call a refresh function (currently used for ring refresh). It also supports triggering a refresh immediately. 
type refreshDebouncer struct { mu sync.Mutex stopped bool broadcaster *errorBroadcaster interval time.Duration timer *time.Timer refreshNowCh chan struct{} quit chan struct{} refreshFn func() error } func newRefreshDebouncer(interval time.Duration, refreshFn func() error) *refreshDebouncer { d := &refreshDebouncer{ stopped: false, broadcaster: nil, refreshNowCh: make(chan struct{}, 1), quit: make(chan struct{}), interval: interval, timer: time.NewTimer(interval), refreshFn: refreshFn, } d.timer.Stop() go d.flusher() return d } // debounces a request to call the refresh function func (d *refreshDebouncer) debounce() { d.mu.Lock() defer d.mu.Unlock() if d.stopped { return } d.timer.Reset(d.interval) } // requests an immediate refresh which will cancel pending refresh requests func (d *refreshDebouncer) refreshNow() <-chan error { d.mu.Lock() defer d.mu.Unlock() if d.broadcaster == nil { d.broadcaster = newErrorBroadcaster() select { case d.refreshNowCh <- struct{}{}: default: // already a refresh pending } } return d.broadcaster.newListener() } func (d *refreshDebouncer) flusher() { for { select { case <-d.refreshNowCh: case <-d.timer.C: case <-d.quit: } d.mu.Lock() if d.stopped { if d.broadcaster != nil { d.broadcaster.stop() d.broadcaster = nil } d.timer.Stop() d.mu.Unlock() return } // make sure both request channels are cleared before we refresh select { case <-d.refreshNowCh: default: } d.timer.Stop() select { case <-d.timer.C: default: } curBroadcaster := d.broadcaster d.broadcaster = nil d.mu.Unlock() err := d.refreshFn() if curBroadcaster != nil { curBroadcaster.broadcast(err) } } } func (d *refreshDebouncer) stop() { d.mu.Lock() if d.stopped { d.mu.Unlock() return } d.stopped = true d.mu.Unlock() d.quit <- struct{}{} // sync with flusher close(d.quit) } // broadcasts an error to multiple channels (listeners) type errorBroadcaster struct { listeners []chan<- error mu sync.Mutex } func newErrorBroadcaster() *errorBroadcaster { return &errorBroadcaster{ 
listeners: nil, mu: sync.Mutex{}, } } func (b *errorBroadcaster) newListener() <-chan error { ch := make(chan error, 1) b.mu.Lock() defer b.mu.Unlock() b.listeners = append(b.listeners, ch) return ch } func (b *errorBroadcaster) broadcast(err error) { b.mu.Lock() defer b.mu.Unlock() curListeners := b.listeners if len(curListeners) > 0 { b.listeners = nil } else { return } for _, listener := range curListeners { listener <- err close(listener) } } func (b *errorBroadcaster) stop() { b.mu.Lock() defer b.mu.Unlock() if len(b.listeners) == 0 { return } for _, listener := range b.listeners { close(listener) } b.listeners = nil } cassandra-gocql-driver-1.7.0/host_source_gen.go000066400000000000000000000035011467504044300215750ustar00rootroot00000000000000//go:build genhostinfo // +build genhostinfo /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. 
*/ package main import ( "fmt" "reflect" "sync" gocql "github.com/gocql/gocql" ) func gen(clause, field string) { fmt.Printf("if h.%s == %s {\n", field, clause) fmt.Printf("\th.%s = from.%s\n", field, field) fmt.Println("}") } func main() { t := reflect.ValueOf(&gocql.HostInfo{}).Elem().Type() mu := reflect.TypeOf(sync.RWMutex{}) for i := 0; i < t.NumField(); i++ { f := t.Field(i) if f.Type == mu { continue } switch f.Type.Kind() { case reflect.Slice: gen("nil", f.Name) case reflect.String: gen(`""`, f.Name) case reflect.Int: gen("0", f.Name) case reflect.Struct: gen("("+f.Type.Name()+"{})", f.Name) case reflect.Bool, reflect.Int32: continue default: panic(fmt.Sprintf("unknown field: %s", f)) } } } cassandra-gocql-driver-1.7.0/host_source_test.go000066400000000000000000000242701467504044300220110ustar00rootroot00000000000000//go:build all || unit // +build all unit /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. 
*/ package gocql import ( "errors" "net" "sync" "sync/atomic" "testing" "time" ) func TestUnmarshalCassVersion(t *testing.T) { tests := [...]struct { data string version cassVersion }{ {"3.2", cassVersion{3, 2, 0}}, {"2.10.1-SNAPSHOT", cassVersion{2, 10, 1}}, {"1.2.3", cassVersion{1, 2, 3}}, } for i, test := range tests { v := &cassVersion{} if err := v.UnmarshalCQL(nil, []byte(test.data)); err != nil { t.Errorf("%d: %v", i, err) } else if *v != test.version { t.Errorf("%d: expected %#+v got %#+v", i, test.version, *v) } } } func TestCassVersionBefore(t *testing.T) { tests := [...]struct { version cassVersion major, minor, patch int }{ {cassVersion{1, 0, 0}, 0, 0, 0}, {cassVersion{0, 1, 0}, 0, 0, 0}, {cassVersion{0, 0, 1}, 0, 0, 0}, {cassVersion{1, 0, 0}, 0, 1, 0}, {cassVersion{0, 1, 0}, 0, 0, 1}, {cassVersion{4, 1, 0}, 3, 1, 2}, } for i, test := range tests { if test.version.Before(test.major, test.minor, test.patch) { t.Errorf("%d: expected v%d.%d.%d to be before %v", i, test.major, test.minor, test.patch, test.version) } } } func TestIsValidPeer(t *testing.T) { host := &HostInfo{ rpcAddress: net.ParseIP("0.0.0.0"), rack: "myRack", hostId: "0", dataCenter: "datacenter", tokens: []string{"0", "1"}, } if !isValidPeer(host) { t.Errorf("expected %+v to be a valid peer", host) } host.rack = "" if isValidPeer(host) { t.Errorf("expected %+v to NOT be a valid peer", host) } } func TestHostInfo_ConnectAddress(t *testing.T) { var localhost = net.IPv4(127, 0, 0, 1) tests := []struct { name string connectAddr net.IP rpcAddr net.IP broadcastAddr net.IP peer net.IP }{ {name: "rpc_address", rpcAddr: localhost}, {name: "connect_address", connectAddr: localhost}, {name: "broadcast_address", broadcastAddr: localhost}, {name: "peer", peer: localhost}, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { host := &HostInfo{ connectAddress: test.connectAddr, rpcAddress: test.rpcAddr, broadcastAddress: test.broadcastAddr, peer: test.peer, } if addr := 
host.ConnectAddress(); !addr.Equal(localhost) { t.Fatalf("expected ConnectAddress to be %s got %s", localhost, addr) } }) } } // This test sends debounce requests and waits until the refresh function is called (which should happen when the timer elapses). func TestRefreshDebouncer_MultipleEvents(t *testing.T) { const numberOfEvents = 10 channel := make(chan int, numberOfEvents) // should never use more than 1 but allow for more to possibly detect bugs fn := func() error { channel <- 0 return nil } beforeEvents := time.Now() wg := sync.WaitGroup{} d := newRefreshDebouncer(2*time.Second, fn) defer d.stop() for i := 0; i < numberOfEvents; i++ { wg.Add(1) go func() { defer wg.Done() d.debounce() }() } wg.Wait() timeoutCh := time.After(2500 * time.Millisecond) // extra time to avoid flakiness select { case <-channel: case <-timeoutCh: t.Fatalf("timeout elapsed without flush function being called") } afterFunctionCall := time.Now() // use 1.5 seconds instead of 2 seconds to avoid timer precision issues if afterFunctionCall.Sub(beforeEvents) < 1500*time.Millisecond { t.Fatalf("function was called after %v ms instead of ~2 seconds", afterFunctionCall.Sub(beforeEvents).Milliseconds()) } // wait another 2 seconds and check if function was called again time.Sleep(2500 * time.Millisecond) if len(channel) > 0 { t.Fatalf("function was called more than once") } } // This test: // // 1 - Sends debounce requests when test starts // 2 - Calls refreshNow() before the timer elapsed (which stops the timer) about 1.5 seconds after test starts // // The end result should be 1 refresh function call when refreshNow() is called. 
func TestRefreshDebouncer_RefreshNow(t *testing.T) { const numberOfEvents = 10 channel := make(chan int, numberOfEvents) // should never use more than 1 but allow for more to possibly detect bugs fn := func() error { channel <- 0 return nil } beforeEvents := time.Now() eventsWg := sync.WaitGroup{} d := newRefreshDebouncer(2*time.Second, fn) defer d.stop() for i := 0; i < numberOfEvents; i++ { eventsWg.Add(1) go func() { defer eventsWg.Done() d.debounce() }() } refreshNowWg := sync.WaitGroup{} refreshNowWg.Add(1) go func() { defer refreshNowWg.Done() time.Sleep(1500 * time.Millisecond) d.refreshNow() }() eventsWg.Wait() select { case <-channel: t.Fatalf("function was called before the expected time") default: } refreshNowWg.Wait() timeoutCh := time.After(200 * time.Millisecond) // allow for 200ms of delay to prevent flakiness select { case <-channel: case <-timeoutCh: t.Fatalf("timeout elapsed without flush function being called") } afterFunctionCall := time.Now() // use 1 second instead of 1.5s to avoid timer precision issues if afterFunctionCall.Sub(beforeEvents) < 1000*time.Millisecond { t.Fatalf("function was called after %v ms instead of ~1.5 seconds", afterFunctionCall.Sub(beforeEvents).Milliseconds()) } // wait some time and check if function was called again time.Sleep(2500 * time.Millisecond) if len(channel) > 0 { t.Fatalf("function was called more than once") } } // This test: // // 1 - Sends debounce requests when test starts // 2 - Calls refreshNow() before the timer elapsed (which stops the timer) about 1 second after test starts // 3 - Sends more debounce requests (which resets the timer with a 3-second interval) about 2 seconds after test starts // // The end result should be 2 refresh function calls: // // 1 - When refreshNow() is called (1 second after the test starts) // 2 - When the timer elapses after the second "wave" of debounce requests (5 seconds after the test starts) func TestRefreshDebouncer_EventsAfterRefreshNow(t *testing.T) { const 
numberOfEvents = 10 channel := make(chan int, numberOfEvents) // should never use more than 2 but allow for more to possibly detect bugs fn := func() error { channel <- 0 return nil } beforeEvents := time.Now() wg := sync.WaitGroup{} d := newRefreshDebouncer(3*time.Second, fn) defer d.stop() for i := 0; i < numberOfEvents; i++ { wg.Add(1) go func() { defer wg.Done() d.debounce() time.Sleep(2000 * time.Millisecond) d.debounce() }() } go func() { time.Sleep(1 * time.Second) d.refreshNow() }() wg.Wait() timeoutCh := time.After(1500 * time.Millisecond) // extra 500ms to prevent flakiness select { case <-channel: case <-timeoutCh: t.Fatalf("timeout elapsed without flush function being called after refreshNow()") } afterFunctionCall := time.Now() // use 500ms instead of 1s to avoid timer precision issues if afterFunctionCall.Sub(beforeEvents) < 500*time.Millisecond { t.Fatalf("function was called after %v ms instead of ~1 second", afterFunctionCall.Sub(beforeEvents).Milliseconds()) } timeoutCh = time.After(4 * time.Second) // extra 1s to prevent flakiness select { case <-channel: case <-timeoutCh: t.Fatalf("timeout elapsed without flush function being called after debounce requests") } afterSecondFunctionCall := time.Now() // use 2.5s instead of 3s to avoid timer precision issues if afterSecondFunctionCall.Sub(afterFunctionCall) < 2500*time.Millisecond { t.Fatalf("function was called after %v ms instead of ~3 seconds", afterSecondFunctionCall.Sub(afterFunctionCall).Milliseconds()) } if len(channel) > 0 { t.Fatalf("function was called more than twice") } } func TestErrorBroadcaster_MultipleListeners(t *testing.T) { b := newErrorBroadcaster() defer b.stop() const numberOfListeners = 10 var listeners []<-chan error for i := 0; i < numberOfListeners; i++ { listeners = append(listeners, b.newListener()) } err := errors.New("expected error") wg := sync.WaitGroup{} result := atomic.Value{} for _, listener := range listeners { currentListener := listener wg.Add(1) go func() { 
defer wg.Done() receivedErr, ok := <-currentListener if !ok { result.Store(errors.New("listener was closed")) } else if receivedErr != err { result.Store(errors.New("expected received error to be the same as the one that was broadcasted")) } }() } wg.Add(1) go func() { defer wg.Done() b.broadcast(err) b.stop() }() wg.Wait() if loadedVal := result.Load(); loadedVal != nil { t.Errorf(loadedVal.(error).Error()) } } func TestErrorBroadcaster_StopWithoutBroadcast(t *testing.T) { var b = newErrorBroadcaster() defer b.stop() const numberOfListeners = 10 var listeners []<-chan error for i := 0; i < numberOfListeners; i++ { listeners = append(listeners, b.newListener()) } wg := sync.WaitGroup{} result := atomic.Value{} for _, listener := range listeners { currentListener := listener wg.Add(1) go func() { defer wg.Done() // broadcaster stopped, expect listener to be closed _, ok := <-currentListener if ok { result.Store(errors.New("expected listener to be closed")) } }() } wg.Add(1) go func() { defer wg.Done() // call stop without broadcasting anything to current listeners b.stop() }() wg.Wait() if loadedVal := result.Load(); loadedVal != nil { t.Errorf(loadedVal.(error).Error()) } } cassandra-gocql-driver-1.7.0/install_test_deps.sh000077500000000000000000000020421467504044300221360ustar00rootroot00000000000000#!/usr/bin/env bash # # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. 
You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # set -x # This is not supposed to be an error-prone script; just a convenience. # Install CCM pip install -i https://pypi.org/simple --user cql PyYAML six psutil git clone https://github.com/pcmanus/ccm.git pushd ccm ./setup.py install --user popd cassandra-gocql-driver-1.7.0/integration.sh000077500000000000000000000065041467504044300207500ustar00rootroot00000000000000#!/bin/bash # # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# set -eux function run_tests() { local clusterSize=3 local version=$1 local auth=$2 local compressor=$3 if [ "$auth" = true ]; then clusterSize=1 fi local keypath="$(pwd)/testdata/pki" local conf=( "client_encryption_options.enabled: true" "client_encryption_options.keystore: $keypath/.keystore" "client_encryption_options.keystore_password: cassandra" "client_encryption_options.require_client_auth: true" "client_encryption_options.truststore: $keypath/.truststore" "client_encryption_options.truststore_password: cassandra" "concurrent_reads: 2" "concurrent_writes: 2" "rpc_server_type: sync" "rpc_min_threads: 2" "rpc_max_threads: 2" "write_request_timeout_in_ms: 5000" "read_request_timeout_in_ms: 5000" ) ccm remove test || true ccm create test -v $version -n $clusterSize -d --vnodes --jvm_arg="-Xmx256m -XX:NewSize=100m" ccm updateconf "${conf[@]}" if [ "$auth" = true ] then ccm updateconf 'authenticator: PasswordAuthenticator' 'authorizer: CassandraAuthorizer' rm -rf $HOME/.ccm/test/node1/data/system_auth fi local proto=2 if [[ $version == 1.2.* ]]; then proto=1 elif [[ $version == 2.0.* ]]; then proto=2 elif [[ $version == 2.1.* ]]; then proto=3 elif [[ $version == 2.2.* || $version == 3.0.* ]]; then proto=4 ccm updateconf 'enable_user_defined_functions: true' export JVM_EXTRA_OPTS=" -Dcassandra.test.fail_writes_ks=test -Dcassandra.custom_query_handler_class=org.apache.cassandra.cql3.CustomPayloadMirroringQueryHandler" elif [[ $version == 3.*.* ]]; then proto=5 ccm updateconf 'enable_user_defined_functions: true' export JVM_EXTRA_OPTS=" -Dcassandra.test.fail_writes_ks=test -Dcassandra.custom_query_handler_class=org.apache.cassandra.cql3.CustomPayloadMirroringQueryHandler" fi sleep 1s ccm list ccm start --wait-for-binary-proto ccm status ccm node1 nodetool status local args="-gocql.timeout=60s -runssl -proto=$proto -rf=3 -clusterSize=$clusterSize -autowait=2000ms -compressor=$compressor -gocql.cversion=$version -cluster=$(ccm liveset) ./..." 
go test -v -tags unit -race if [ "$auth" = true ] then sleep 30s go test -run=TestAuthentication -tags "integration gocql_debug" -timeout=15s -runauth $args else sleep 1s go test -tags "cassandra gocql_debug" -timeout=5m -race $args ccm clear ccm start --wait-for-binary-proto sleep 1s go test -tags "integration gocql_debug" -timeout=5m -race $args ccm clear ccm start --wait-for-binary-proto sleep 1s go test -tags "ccm gocql_debug" -timeout=5m -race $args fi ccm remove } run_tests $1 $2 $3 cassandra-gocql-driver-1.7.0/integration_test.go000066400000000000000000000211771467504044300220020ustar00rootroot00000000000000//go:build all || integration // +build all integration /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. 
*/ package gocql // This file groups integration tests where Cassandra has to be set up with some special integration variables import ( "context" "reflect" "testing" "time" ) // TestAuthentication verifies that gocql will work with a host configured to only accept authenticated connections func TestAuthentication(t *testing.T) { if *flagProto < 2 { t.Skip("Authentication is not supported with protocol < 2") } if !*flagRunAuthTest { t.Skip("Authentication is not configured in the target cluster") } cluster := createCluster() cluster.Authenticator = PasswordAuthenticator{ Username: "cassandra", Password: "cassandra", } session, err := cluster.CreateSession() if err != nil { t.Fatalf("Authentication error: %s", err) } session.Close() } func TestGetHosts(t *testing.T) { clusterHosts := getClusterHosts() cluster := createCluster() session := createSessionFromCluster(cluster, t) hosts, partitioner, err := session.hostSource.GetHosts() assertTrue(t, "err == nil", err == nil) assertEqual(t, "len(hosts)", len(clusterHosts), len(hosts)) assertTrue(t, "len(partitioner) != 0", len(partitioner) != 0) } // TestRingDiscovery makes sure that you can autodiscover other cluster members // when you seed a cluster config with just one node func TestRingDiscovery(t *testing.T) { clusterHosts := getClusterHosts() cluster := createCluster() cluster.Hosts = clusterHosts[:1] session := createSessionFromCluster(cluster, t) defer session.Close() if *clusterSize > 1 { // wait for autodiscovery to update the pool with the list of known hosts time.Sleep(*flagAutoWait) } session.pool.mu.RLock() defer session.pool.mu.RUnlock() size := len(session.pool.hostConnPools) if *clusterSize != size { for p, pool := range session.pool.hostConnPools { t.Logf("p=%q host=%v ips=%s", p, pool.host, pool.host.ConnectAddress().String()) } t.Errorf("Expected a cluster size of %d, but actual size was %d", *clusterSize, size) } } // TestHostFilterDiscovery ensures that host filtering works even when we discover 
hosts func TestHostFilterDiscovery(t *testing.T) { clusterHosts := getClusterHosts() if len(clusterHosts) < 2 { t.Skip("skipping because we don't have 2 or more hosts") } cluster := createCluster() rr := RoundRobinHostPolicy().(*roundRobinHostPolicy) cluster.PoolConfig.HostSelectionPolicy = rr // we'll filter out the second host filtered := clusterHosts[1] cluster.Hosts = clusterHosts[:1] cluster.HostFilter = HostFilterFunc(func(host *HostInfo) bool { if host.ConnectAddress().String() == filtered { return false } return true }) session := createSessionFromCluster(cluster, t) defer session.Close() assertEqual(t, "len(clusterHosts)-1 != len(rr.hosts.get())", len(clusterHosts)-1, len(rr.hosts.get())) } // TestHostFilterInitial ensures that host filtering works for the initial // connection including the control connection func TestHostFilterInitial(t *testing.T) { clusterHosts := getClusterHosts() if len(clusterHosts) < 2 { t.Skip("skipping because we don't have 2 or more hosts") } cluster := createCluster() rr := RoundRobinHostPolicy().(*roundRobinHostPolicy) cluster.PoolConfig.HostSelectionPolicy = rr // we'll filter out the second host filtered := clusterHosts[1] cluster.HostFilter = HostFilterFunc(func(host *HostInfo) bool { if host.ConnectAddress().String() == filtered { return false } return true }) session := createSessionFromCluster(cluster, t) defer session.Close() assertEqual(t, "len(clusterHosts)-1 != len(rr.hosts.get())", len(clusterHosts)-1, len(rr.hosts.get())) } func TestWriteFailure(t *testing.T) { cluster := createCluster() createKeyspace(t, cluster, "test") cluster.Keyspace = "test" session, err := cluster.CreateSession() if err != nil { t.Fatal("create session:", err) } defer session.Close() if err := createTable(session, "CREATE TABLE test.test (id int,value int,PRIMARY KEY (id))"); err != nil { t.Fatalf("failed to create table with error '%v'", err) } if err := session.Query(`INSERT INTO test.test (id, value) VALUES (1, 1)`).Exec(); err != nil { 
errWrite, ok := err.(*RequestErrWriteFailure) if ok { if session.cfg.ProtoVersion >= 5 { // ErrorMap should be filled with some hosts that should've errored if len(errWrite.ErrorMap) == 0 { t.Fatal("errWrite.ErrorMap should have some failed hosts but it didn't have any") } } else { // Map doesn't get filled for V4 if len(errWrite.ErrorMap) != 0 { t.Fatal("errWrite.ErrorMap should have length 0, it's: ", len(errWrite.ErrorMap)) } } } else { t.Fatal("error should be RequestErrWriteFailure, it's: ", errWrite) } } else { t.Fatal("a write fail error should have happened when querying test keyspace") } if err = session.Query("DROP KEYSPACE test").Exec(); err != nil { t.Fatal(err) } } func TestCustomPayloadMessages(t *testing.T) { cluster := createCluster() session := createSessionFromCluster(cluster, t) defer session.Close() if err := createTable(session, "CREATE TABLE gocql_test.testCustomPayloadMessages (id int, value int, PRIMARY KEY (id))"); err != nil { t.Fatal(err) } // QueryMessage var customPayload = map[string][]byte{"a": []byte{10, 20}, "b": []byte{20, 30}} query := session.Query("SELECT id FROM testCustomPayloadMessages where id = ?", 42).Consistency(One).CustomPayload(customPayload) iter := query.Iter() rCustomPayload := iter.GetCustomPayload() if !reflect.DeepEqual(customPayload, rCustomPayload) { t.Fatal("The received custom payload should match the sent") } iter.Close() // Insert query query = session.Query("INSERT INTO testCustomPayloadMessages(id,value) VALUES(1, 1)").Consistency(One).CustomPayload(customPayload) iter = query.Iter() rCustomPayload = iter.GetCustomPayload() if !reflect.DeepEqual(customPayload, rCustomPayload) { t.Fatal("The received custom payload should match the sent") } iter.Close() // Batch Message b := session.NewBatch(LoggedBatch) b.CustomPayload = customPayload b.Query("INSERT INTO testCustomPayloadMessages(id,value) VALUES(1, 1)") if err := session.ExecuteBatch(b); err != nil { t.Fatalf("query failed. 
%v", err) } } func TestCustomPayloadValues(t *testing.T) { cluster := createCluster() session := createSessionFromCluster(cluster, t) defer session.Close() if err := createTable(session, "CREATE TABLE gocql_test.testCustomPayloadValues (id int, value int, PRIMARY KEY (id))"); err != nil { t.Fatal(err) } values := []map[string][]byte{map[string][]byte{"a": []byte{10, 20}, "b": []byte{20, 30}}, nil, map[string][]byte{"a": []byte{10, 20}, "b": nil}} for _, customPayload := range values { query := session.Query("SELECT id FROM testCustomPayloadValues where id = ?", 42).Consistency(One).CustomPayload(customPayload) iter := query.Iter() rCustomPayload := iter.GetCustomPayload() if !reflect.DeepEqual(customPayload, rCustomPayload) { t.Fatal("The received custom payload should match the sent") } } } func TestSessionAwaitSchemaAgreement(t *testing.T) { session := createSession(t) defer session.Close() if err := session.AwaitSchemaAgreement(context.Background()); err != nil { t.Fatalf("expected session.AwaitSchemaAgreement to not return an error but got '%v'", err) } } func TestUDF(t *testing.T) { session := createSession(t) defer session.Close() if session.cfg.ProtoVersion < 4 { t.Skip("skipping UDF support on proto < 4") } const query = `CREATE OR REPLACE FUNCTION uniq(state set, val text) CALLED ON NULL INPUT RETURNS set LANGUAGE java AS 'state.add(val); return state;'` err := session.Query(query).Exec() if err != nil { t.Fatal(err) } } cassandra-gocql-driver-1.7.0/internal/000077500000000000000000000000001467504044300176755ustar00rootroot00000000000000cassandra-gocql-driver-1.7.0/internal/ccm/000077500000000000000000000000001467504044300204375ustar00rootroot00000000000000cassandra-gocql-driver-1.7.0/internal/ccm/ccm.go000066400000000000000000000105711467504044300215340ustar00rootroot00000000000000//go:build ccm // +build ccm /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. 
See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. */ package ccm import ( "bufio" "bytes" "errors" "fmt" "os/exec" "runtime" "strings" ) func execCmd(args ...string) (*bytes.Buffer, error) { execName := "ccm" if runtime.GOOS == "windows" { args = append([]string{"/c", execName}, args...) execName = "cmd.exe" } cmd := exec.Command(execName, args...) stdout := &bytes.Buffer{} cmd.Stdout = stdout cmd.Stderr = &bytes.Buffer{} if err := cmd.Run(); err != nil { return nil, errors.New(cmd.Stderr.(*bytes.Buffer).String()) } return stdout, nil } func AllUp() error { status, err := Status() if err != nil { return err } for _, host := range status { if !host.State.IsUp() { if err := NodeUp(host.Name); err != nil { return err } } } return nil } func NodeUp(node string) error { args := []string{node, "start", "--wait-for-binary-proto"} if runtime.GOOS == "windows" { args = append(args, "--quiet-windows") } _, err := execCmd(args...) 
return err } func NodeDown(node string) error { _, err := execCmd(node, "stop") return err } type Host struct { State NodeState Addr string Name string } type NodeState int func (n NodeState) String() string { if n == NodeStateUp { return "UP" } else if n == NodeStateDown { return "DOWN" } else { return fmt.Sprintf("UNKNOWN_STATE_%d", n) } } func (n NodeState) IsUp() bool { return n == NodeStateUp } const ( NodeStateUp NodeState = iota NodeStateDown ) func Status() (map[string]Host, error) { // TODO: parse into struct to manipulate out, err := execCmd("status", "-v") if err != nil { return nil, err } const ( stateCluster = iota stateCommas stateNode stateOption ) nodes := make(map[string]Host) // didnt really want to write a full state machine parser state := stateCluster sc := bufio.NewScanner(out) var host Host for sc.Scan() { switch state { case stateCluster: text := sc.Text() if !strings.HasPrefix(text, "Cluster:") { return nil, fmt.Errorf("expected 'Cluster:' got %q", text) } state = stateCommas case stateCommas: text := sc.Text() if !strings.HasPrefix(text, "-") { return nil, fmt.Errorf("expected commas got %q", text) } state = stateNode case stateNode: // assume nodes start with node text := sc.Text() if !strings.HasPrefix(text, "node") { return nil, fmt.Errorf("expected 'node' got %q", text) } line := strings.Split(text, ":") host.Name = line[0] nodeState := strings.TrimSpace(line[1]) switch nodeState { case "UP": host.State = NodeStateUp case "DOWN": host.State = NodeStateDown default: return nil, fmt.Errorf("unknown node state from ccm: %q", nodeState) } state = stateOption case stateOption: text := sc.Text() if text == "" { state = stateNode nodes[host.Name] = host host = Host{} continue } line := strings.Split(strings.TrimSpace(text), "=") k, v := line[0], line[1] if k == "binary" { // could check errors // ('127.0.0.1', 9042) v = v[2:] // ('' if i := strings.IndexByte(v, '\''); i < 0 { return nil, fmt.Errorf("invalid binary v=%q", v) } else { host.Addr 
= v[:i] // dont need port } } default: return nil, fmt.Errorf("unexpected state: %q", state) } } if err := sc.Err(); err != nil { return nil, fmt.Errorf("unable to parse ccm status: %v", err) } return nodes, nil } cassandra-gocql-driver-1.7.0/internal/ccm/ccm_test.go000066400000000000000000000034551467504044300225760ustar00rootroot00000000000000//go:build ccm // +build ccm /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. 
*/ package ccm import ( "testing" ) func TestCCM(t *testing.T) { if err := AllUp(); err != nil { t.Fatal(err) } status, err := Status() if err != nil { t.Fatal(err) } if host, ok := status["node1"]; !ok { t.Fatal("node1 not in status list") } else if !host.State.IsUp() { t.Fatal("node1 is not up") } NodeDown("node1") status, err = Status() if err != nil { t.Fatal(err) } if host, ok := status["node1"]; !ok { t.Fatal("node1 not in status list") } else if host.State.IsUp() { t.Fatal("node1 is not down") } NodeUp("node1") status, err = Status() if err != nil { t.Fatal(err) } if host, ok := status["node1"]; !ok { t.Fatal("node1 not in status list") } else if !host.State.IsUp() { t.Fatal("node1 is not up") } } cassandra-gocql-driver-1.7.0/internal/lru/000077500000000000000000000000001467504044300204775ustar00rootroot00000000000000cassandra-gocql-driver-1.7.0/internal/lru/lru.go000066400000000000000000000077271467504044300216450ustar00rootroot00000000000000/* Copyright 2015 To gocql authors Copyright 2013 Google Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ // Package lru implements an LRU cache. /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. */ package lru import "container/list" // Cache is an LRU cache. It is not safe for concurrent access. // // This cache has been forked from github.com/golang/groupcache/lru, but // specialized with string keys to avoid the allocations caused by wrapping them // in interface{}. type Cache struct { // MaxEntries is the maximum number of cache entries before // an item is evicted. Zero means no limit. MaxEntries int // OnEvicted optionally specifies a callback function to be // executed when an entry is purged from the cache. OnEvicted func(key string, value interface{}) ll *list.List cache map[string]*list.Element } type entry struct { key string value interface{} } // New creates a new Cache. // If maxEntries is zero, the cache has no limit and it's assumed // that eviction is done by the caller. func New(maxEntries int) *Cache { return &Cache{ MaxEntries: maxEntries, ll: list.New(), cache: make(map[string]*list.Element), } } // Add adds a value to the cache. 
func (c *Cache) Add(key string, value interface{}) { if c.cache == nil { c.cache = make(map[string]*list.Element) c.ll = list.New() } if ee, ok := c.cache[key]; ok { c.ll.MoveToFront(ee) ee.Value.(*entry).value = value return } ele := c.ll.PushFront(&entry{key, value}) c.cache[key] = ele if c.MaxEntries != 0 && c.ll.Len() > c.MaxEntries { c.RemoveOldest() } } // Get looks up a key's value from the cache. func (c *Cache) Get(key string) (value interface{}, ok bool) { if c.cache == nil { return } if ele, hit := c.cache[key]; hit { c.ll.MoveToFront(ele) return ele.Value.(*entry).value, true } return } // Remove removes the provided key from the cache. func (c *Cache) Remove(key string) bool { if c.cache == nil { return false } if ele, hit := c.cache[key]; hit { c.removeElement(ele) return true } return false } // RemoveOldest removes the oldest item from the cache. func (c *Cache) RemoveOldest() { if c.cache == nil { return } ele := c.ll.Back() if ele != nil { c.removeElement(ele) } } func (c *Cache) removeElement(e *list.Element) { c.ll.Remove(e) kv := e.Value.(*entry) delete(c.cache, kv.key) if c.OnEvicted != nil { c.OnEvicted(kv.key, kv.value) } } // Len returns the number of items in the cache. func (c *Cache) Len() int { if c.cache == nil { return 0 } return c.ll.Len() } cassandra-gocql-driver-1.7.0/internal/lru/lru_test.go000066400000000000000000000051651467504044300226760ustar00rootroot00000000000000/* Copyright 2015 To gocql authors Copyright 2013 Google Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. */ package lru import ( "testing" ) var getTests = []struct { name string keyToAdd string keyToGet string expectedOk bool }{ {"string_hit", "mystring", "mystring", true}, {"string_miss", "mystring", "nonsense", false}, {"simple_struct_hit", "two", "two", true}, {"simeple_struct_miss", "two", "noway", false}, } func TestGet(t *testing.T) { for _, tt := range getTests { lru := New(0) lru.Add(tt.keyToAdd, 1234) val, ok := lru.Get(tt.keyToGet) if ok != tt.expectedOk { t.Fatalf("%s: cache hit = %v; want %v", tt.name, ok, !ok) } else if ok && val != 1234 { t.Fatalf("%s expected get to return 1234 but got %v", tt.name, val) } } } func TestRemove(t *testing.T) { lru := New(0) lru.Add("mystring", 1234) if val, ok := lru.Get("mystring"); !ok { t.Fatal("TestRemove returned no match") } else if val != 1234 { t.Fatalf("TestRemove failed. 
Expected %d, got %v", 1234, val) } lru.Remove("mystring") if _, ok := lru.Get("mystring"); ok { t.Fatal("TestRemove returned a removed entry") } } cassandra-gocql-driver-1.7.0/internal/murmur/000077500000000000000000000000001467504044300212245ustar00rootroot00000000000000cassandra-gocql-driver-1.7.0/internal/murmur/murmur.go000066400000000000000000000063471467504044300231140ustar00rootroot00000000000000/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. 
*/ package murmur const ( c1 int64 = -8663945395140668459 // 0x87c37b91114253d5 c2 int64 = 5545529020109919103 // 0x4cf5ad432745937f fmix1 int64 = -49064778989728563 // 0xff51afd7ed558ccd fmix2 int64 = -4265267296055464877 // 0xc4ceb9fe1a85ec53 ) func fmix(n int64) int64 { // cast to unsigned for logical right bitshift (to match C* MM3 implementation) n ^= int64(uint64(n) >> 33) n *= fmix1 n ^= int64(uint64(n) >> 33) n *= fmix2 n ^= int64(uint64(n) >> 33) return n } func block(p byte) int64 { return int64(int8(p)) } func rotl(x int64, r uint8) int64 { // cast to unsigned for logical right bitshift (to match C* MM3 implementation) return (x << r) | (int64)((uint64(x) >> (64 - r))) } func Murmur3H1(data []byte) int64 { length := len(data) var h1, h2, k1, k2 int64 // body nBlocks := length / 16 for i := 0; i < nBlocks; i++ { k1, k2 = getBlock(data, i) k1 *= c1 k1 = rotl(k1, 31) k1 *= c2 h1 ^= k1 h1 = rotl(h1, 27) h1 += h2 h1 = h1*5 + 0x52dce729 k2 *= c2 k2 = rotl(k2, 33) k2 *= c1 h2 ^= k2 h2 = rotl(h2, 31) h2 += h1 h2 = h2*5 + 0x38495ab5 } // tail tail := data[nBlocks*16:] k1 = 0 k2 = 0 switch length & 15 { case 15: k2 ^= block(tail[14]) << 48 fallthrough case 14: k2 ^= block(tail[13]) << 40 fallthrough case 13: k2 ^= block(tail[12]) << 32 fallthrough case 12: k2 ^= block(tail[11]) << 24 fallthrough case 11: k2 ^= block(tail[10]) << 16 fallthrough case 10: k2 ^= block(tail[9]) << 8 fallthrough case 9: k2 ^= block(tail[8]) k2 *= c2 k2 = rotl(k2, 33) k2 *= c1 h2 ^= k2 fallthrough case 8: k1 ^= block(tail[7]) << 56 fallthrough case 7: k1 ^= block(tail[6]) << 48 fallthrough case 6: k1 ^= block(tail[5]) << 40 fallthrough case 5: k1 ^= block(tail[4]) << 32 fallthrough case 4: k1 ^= block(tail[3]) << 24 fallthrough case 3: k1 ^= block(tail[2]) << 16 fallthrough case 2: k1 ^= block(tail[1]) << 8 fallthrough case 1: k1 ^= block(tail[0]) k1 *= c1 k1 = rotl(k1, 31) k1 *= c2 h1 ^= k1 } h1 ^= int64(length) h2 ^= int64(length) h1 += h2 h2 += h1 h1 = fmix(h1) h2 = fmix(h2) h1 += h2 
// the following is extraneous since h2 is discarded // h2 += h1 return h1 } cassandra-gocql-driver-1.7.0/internal/murmur/murmur_appengine.go000066400000000000000000000024541467504044300251350ustar00rootroot00000000000000//go:build appengine || s390x // +build appengine s390x /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. */ package murmur import "encoding/binary" func getBlock(data []byte, n int) (int64, int64) { k1 := int64(binary.LittleEndian.Uint64(data[n*16:])) k2 := int64(binary.LittleEndian.Uint64(data[(n*16)+8:])) return k1, k2 } cassandra-gocql-driver-1.7.0/internal/murmur/murmur_test.go000066400000000000000000000115761467504044300241530ustar00rootroot00000000000000/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. 
The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. */ package murmur import ( "encoding/hex" "fmt" "strconv" "testing" ) func TestRotl(t *testing.T) { tests := []struct { in, rotate, exp int64 }{ {123456789, 33, 1060485742448345088}, {-123456789, 33, -1060485733858410497}, {-12345678987654, 33, 1756681988166642059}, {7210216203459776512, 31, -4287945813905642825}, {2453826951392495049, 27, -2013042863942636044}, {270400184080946339, 33, -3553153987756601583}, {2060965185473694757, 31, 6290866853133484661}, {3075794793055692309, 33, -3158909918919076318}, {-6486402271863858009, 31, 405973038345868736}, } for _, test := range tests { t.Run(fmt.Sprintf("%d >> %d", test.in, test.rotate), func(t *testing.T) { if v := rotl(test.in, uint8(test.rotate)); v != test.exp { t.Fatalf("expected %d got %d", test.exp, v) } }) } } func TestFmix(t *testing.T) { tests := []struct { in, exp int64 }{ {123456789, -8107560010088384378}, {-123456789, -5252787026298255965}, {-12345678987654, -1122383578793231303}, {-1241537367799374202, 3388197556095096266}, {-7566534940689533355, 4729783097411765989}, } for _, test := range tests { t.Run(strconv.Itoa(int(test.in)), func(t *testing.T) { if v := fmix(test.in); v != test.exp { t.Fatalf("expected %d got %d", 
test.exp, v) } }) } } func TestMurmur3H1_CassandraSign(t *testing.T) { key, err := hex.DecodeString("00104327529fb645dd00b883ec39ae448bb800000400066a6b00") if err != nil { t.Fatal(err) } h := Murmur3H1(key) const exp int64 = -9223371632693506265 if h != exp { t.Fatalf("expected %d got %d", exp, h) } } // Test the implementation of murmur3 func TestMurmur3H1(t *testing.T) { // these examples are based on adding a index number to a sample string in // a loop. The expected values were generated by the java datastax murmur3 // implementation. The number of examples here of increasing lengths ensure // test coverage of all tail-length branches in the murmur3 algorithm seriesExpected := [...]uint64{ 0x0000000000000000, // "" 0x2ac9debed546a380, // "0" 0x649e4eaa7fc1708e, // "01" 0xce68f60d7c353bdb, // "012" 0x0f95757ce7f38254, // "0123" 0x0f04e459497f3fc1, // "01234" 0x88c0a92586be0a27, // "012345" 0x13eb9fb82606f7a6, // "0123456" 0x8236039b7387354d, // "01234567" 0x4c1e87519fe738ba, // "012345678" 0x3f9652ac3effeb24, // "0123456789" 0x3f33760ded9006c6, // "01234567890" 0xaed70a6631854cb1, // "012345678901" 0x8a299a8f8e0e2da7, // "0123456789012" 0x624b675c779249a6, // "01234567890123" 0xa4b203bb1d90b9a3, // "012345678901234" 0xa3293ad698ecb99a, // "0123456789012345" 0xbc740023dbd50048, // "01234567890123456" 0x3fe5ab9837d25cdd, // "012345678901234567" 0x2d0338c1ca87d132, // "0123456789012345678" } sample := "" for i, expected := range seriesExpected { assertMurmur3H1(t, []byte(sample), expected) sample = sample + strconv.Itoa(i%10) } // Here are some test examples from other driver implementations assertMurmur3H1(t, []byte("hello"), 0xcbd8a7b341bd9b02) assertMurmur3H1(t, []byte("hello, world"), 0x342fac623a5ebc8e) assertMurmur3H1(t, []byte("19 Jan 2038 at 3:14:07 AM"), 0xb89e5988b737affc) assertMurmur3H1(t, []byte("The quick brown fox jumps over the lazy dog."), 0xcd99481f9ee902c9) } // helper function for testing the murmur3 implementation func assertMurmur3H1(t 
*testing.T, data []byte, expected uint64) { actual := Murmur3H1(data) if actual != int64(expected) { t.Errorf("Expected h1 = %x for data = %x, but was %x", int64(expected), data, actual) } } // Benchmark of the performance of the murmur3 implementation func BenchmarkMurmur3H1(b *testing.B) { data := make([]byte, 1024) for i := 0; i < 1024; i++ { data[i] = byte(i) } b.ResetTimer() b.RunParallel(func(pb *testing.PB) { for pb.Next() { h1 := Murmur3H1(data) if h1 != int64(7627370222079200297) { b.Fatalf("expected %d got %d", int64(7627370222079200297), h1) } } }) } cassandra-gocql-driver-1.7.0/internal/murmur/murmur_unsafe.go000066400000000000000000000024201467504044300244410ustar00rootroot00000000000000//go:build !appengine && !s390x // +build !appengine,!s390x /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. 
*/ package murmur import ( "unsafe" ) func getBlock(data []byte, n int) (int64, int64) { block := (*[2]int64)(unsafe.Pointer(&data[n*16])) k1 := block[0] k2 := block[1] return k1, k2 } cassandra-gocql-driver-1.7.0/internal/streams/000077500000000000000000000000001467504044300213535ustar00rootroot00000000000000cassandra-gocql-driver-1.7.0/internal/streams/streams.go000066400000000000000000000103221467504044300233560ustar00rootroot00000000000000/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. */ package streams import ( "math" "strconv" "sync/atomic" ) const bucketBits = 64 // IDGenerator tracks and allocates streams which are in use. 
type IDGenerator struct { NumStreams int inuseStreams int32 numBuckets uint32 // streams is a bitset where each bit represents a stream, a 1 implies in use streams []uint64 offset uint32 } func New(protocol int) *IDGenerator { maxStreams := 128 if protocol > 2 { maxStreams = 32768 } buckets := maxStreams / 64 // reserve stream 0 streams := make([]uint64, buckets) streams[0] = 1 << 63 return &IDGenerator{ NumStreams: maxStreams, streams: streams, numBuckets: uint32(buckets), offset: uint32(buckets) - 1, } } func streamFromBucket(bucket, streamInBucket int) int { return (bucket * bucketBits) + streamInBucket } func (s *IDGenerator) GetStream() (int, bool) { // based closely on the java-driver stream ID generator // avoid false sharing subsequent requests. offset := atomic.LoadUint32(&s.offset) for !atomic.CompareAndSwapUint32(&s.offset, offset, (offset+1)%s.numBuckets) { offset = atomic.LoadUint32(&s.offset) } offset = (offset + 1) % s.numBuckets for i := uint32(0); i < s.numBuckets; i++ { pos := int((i + offset) % s.numBuckets) bucket := atomic.LoadUint64(&s.streams[pos]) if bucket == math.MaxUint64 { // all streams in use continue } for j := 0; j < bucketBits; j++ { mask := uint64(1 << streamOffset(j)) for bucket&mask == 0 { if atomic.CompareAndSwapUint64(&s.streams[pos], bucket, bucket|mask) { atomic.AddInt32(&s.inuseStreams, 1) return streamFromBucket(int(pos), j), true } bucket = atomic.LoadUint64(&s.streams[pos]) } } } return 0, false } func bitfmt(b uint64) string { return strconv.FormatUint(b, 16) } // returns the bucket offset of a given stream func bucketOffset(i int) int { return i / bucketBits } func streamOffset(stream int) uint64 { return bucketBits - uint64(stream%bucketBits) - 1 } func isSet(bits uint64, stream int) bool { return bits>>streamOffset(stream)&1 == 1 } func (s *IDGenerator) isSet(stream int) bool { bits := atomic.LoadUint64(&s.streams[bucketOffset(stream)]) return isSet(bits, stream) } func (s *IDGenerator) String() string { size := 
s.numBuckets * (bucketBits + 1) buf := make([]byte, 0, size) for i := 0; i < int(s.numBuckets); i++ { bits := atomic.LoadUint64(&s.streams[i]) buf = append(buf, bitfmt(bits)...) buf = append(buf, ' ') } return string(buf[: size-1 : size-1]) } func (s *IDGenerator) Clear(stream int) (inuse bool) { offset := bucketOffset(stream) bucket := atomic.LoadUint64(&s.streams[offset]) mask := uint64(1) << streamOffset(stream) if bucket&mask != mask { // already cleared return false } for !atomic.CompareAndSwapUint64(&s.streams[offset], bucket, bucket & ^mask) { bucket = atomic.LoadUint64(&s.streams[offset]) if bucket&mask != mask { // already cleared return false } } // TODO: make this account for 0 stream being reserved if atomic.AddInt32(&s.inuseStreams, -1) < 0 { // TODO(zariel): remove this panic("negative streams inuse") } return true } func (s *IDGenerator) Available() int { return s.NumStreams - int(atomic.LoadInt32(&s.inuseStreams)) - 1 } cassandra-gocql-driver-1.7.0/internal/streams/streams_test.go000066400000000000000000000116411467504044300244220ustar00rootroot00000000000000/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. */ package streams import ( "math" "strconv" "sync/atomic" "testing" ) func TestUsesAllStreams(t *testing.T) { streams := New(1) got := make(map[int]struct{}) for i := 1; i < streams.NumStreams; i++ { stream, ok := streams.GetStream() if !ok { t.Fatalf("unable to get stream %d", i) } if _, ok = got[stream]; ok { t.Fatalf("got an already allocated stream: %d", stream) } got[stream] = struct{}{} if !streams.isSet(stream) { bucket := atomic.LoadUint64(&streams.streams[bucketOffset(stream)]) t.Logf("bucket=%d: %s\n", bucket, strconv.FormatUint(bucket, 2)) t.Fatalf("stream not set: %d", stream) } } for i := 1; i < streams.NumStreams; i++ { if _, ok := got[i]; !ok { t.Errorf("did not use stream %d", i) } } if _, ok := got[0]; ok { t.Fatal("expected to not use stream 0") } for i, bucket := range streams.streams { if bucket != math.MaxUint64 { t.Errorf("did not use all streams in offset=%d bucket=%s", i, bitfmt(bucket)) } } } func TestFullStreams(t *testing.T) { streams := New(1) for i := range streams.streams { streams.streams[i] = math.MaxUint64 } stream, ok := streams.GetStream() if ok { t.Fatalf("should not get stream when all in use: stream=%d", stream) } } func TestClearStreams(t *testing.T) { streams := New(1) for i := range streams.streams { streams.streams[i] = math.MaxUint64 } streams.inuseStreams = int32(streams.NumStreams) for i := 0; i < streams.NumStreams; i++ { streams.Clear(i) } for i, bucket := range streams.streams { if bucket != 0 { t.Errorf("did not clear streams in offset=%d bucket=%s", i, bitfmt(bucket)) } } } func TestDoubleClear(t *testing.T) { streams := New(1) stream, ok := streams.GetStream() if !ok { t.Fatal("did not get stream") } if !streams.Clear(stream) { t.Fatalf("stream not indicated as in use: %d", stream) } 
if streams.Clear(stream) { t.Fatalf("stream not as in use after clear: %d", stream) } } func BenchmarkConcurrentUse(b *testing.B) { streams := New(2) b.RunParallel(func(pb *testing.PB) { for pb.Next() { stream, ok := streams.GetStream() if !ok { b.Error("unable to get stream") return } if !streams.Clear(stream) { b.Errorf("stream was already cleared: %d", stream) return } } }) } func TestStreamOffset(t *testing.T) { tests := [...]struct { n int off uint64 }{ {0, 63}, {1, 62}, {2, 61}, {3, 60}, {63, 0}, {64, 63}, {128, 63}, } for _, test := range tests { if off := streamOffset(test.n); off != test.off { t.Errorf("n=%d expected %d got %d", test.n, off, test.off) } } } func TestIsSet(t *testing.T) { tests := [...]struct { stream int bucket uint64 set bool }{ {0, 0, false}, {0, 1 << 63, true}, {1, 0, false}, {1, 1 << 62, true}, {63, 1, true}, {64, 1 << 63, true}, {0, 0x8000000000000000, true}, } for i, test := range tests { if set := isSet(test.bucket, test.stream); set != test.set { t.Errorf("[%d] stream=%d expected %v got %v", i, test.stream, test.set, set) } } for i := 0; i < bucketBits; i++ { if !isSet(math.MaxUint64, i) { var shift uint64 = math.MaxUint64 >> streamOffset(i) t.Errorf("expected isSet for all i=%d got=%d", i, shift) } } } func TestBucketOfset(t *testing.T) { tests := [...]struct { n int bucket int }{ {0, 0}, {1, 0}, {63, 0}, {64, 1}, } for _, test := range tests { if bucket := bucketOffset(test.n); bucket != test.bucket { t.Errorf("n=%d expected %v got %v", test.n, test.bucket, bucket) } } } func TestStreamFromBucket(t *testing.T) { tests := [...]struct { bucket int pos int stream int }{ {0, 0, 0}, {0, 1, 1}, {0, 2, 2}, {0, 63, 63}, {1, 0, 64}, {1, 1, 65}, } for _, test := range tests { if stream := streamFromBucket(test.bucket, test.pos); stream != test.stream { t.Errorf("bucket=%d pos=%d expected %v got %v", test.bucket, test.pos, test.stream, stream) } } } 
cassandra-gocql-driver-1.7.0/keyspace_table_test.go000066400000000000000000000060161467504044300224250ustar00rootroot00000000000000//go:build all || integration // +build all integration /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. 
*/ package gocql import ( "context" "fmt" "testing" ) // Keyspace_table checks if Query.Keyspace() is updated based on prepared statement func TestKeyspaceTable(t *testing.T) { cluster := createCluster() fallback := RoundRobinHostPolicy() cluster.PoolConfig.HostSelectionPolicy = TokenAwareHostPolicy(fallback) session, err := cluster.CreateSession() if err != nil { t.Fatal("createSession:", err) } cluster.Keyspace = "wrong_keyspace" keyspace := "test1" table := "table1" err = createTable(session, `DROP KEYSPACE IF EXISTS `+keyspace) if err != nil { t.Fatal("unable to drop keyspace:", err) } err = createTable(session, fmt.Sprintf(`CREATE KEYSPACE %s WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 }`, keyspace)) if err != nil { t.Fatal("unable to create keyspace:", err) } if err := session.control.awaitSchemaAgreement(); err != nil { t.Fatal(err) } err = createTable(session, fmt.Sprintf(`CREATE TABLE %s.%s (pk int, ck int, v int, PRIMARY KEY (pk, ck)); `, keyspace, table)) if err != nil { t.Fatal("unable to create table:", err) } if err := session.control.awaitSchemaAgreement(); err != nil { t.Fatal(err) } ctx := context.Background() // insert a row if err := session.Query(`INSERT INTO test1.table1(pk, ck, v) VALUES (?, ?, ?)`, 1, 2, 3).WithContext(ctx).Consistency(One).Exec(); err != nil { t.Fatal(err) } var pk int /* Search for a specific set of records whose 'pk' column matches * the value of inserted row. */ qry := session.Query(`SELECT pk FROM test1.table1 WHERE pk = ? 
LIMIT 1`, 1).WithContext(ctx).Consistency(One) if err := qry.Scan(&pk); err != nil { t.Fatal(err) } // cluster.Keyspace was set to "wrong_keyspace", but during prepering statement // Keyspace in Query should be changed to "test" and Table should be changed to table1 assertEqual(t, "qry.Keyspace()", "test1", qry.Keyspace()) assertEqual(t, "qry.Table()", "table1", qry.Table()) } cassandra-gocql-driver-1.7.0/logger.go000066400000000000000000000043061467504044300176720ustar00rootroot00000000000000/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. */ package gocql import ( "bytes" "fmt" "log" ) type StdLogger interface { Print(v ...interface{}) Printf(format string, v ...interface{}) Println(v ...interface{}) } type nopLogger struct{} func (n nopLogger) Print(_ ...interface{}) {} func (n nopLogger) Printf(_ string, _ ...interface{}) {} func (n nopLogger) Println(_ ...interface{}) {} type testLogger struct { capture bytes.Buffer } func (l *testLogger) Print(v ...interface{}) { fmt.Fprint(&l.capture, v...) 
} func (l *testLogger) Printf(format string, v ...interface{}) { fmt.Fprintf(&l.capture, format, v...) } func (l *testLogger) Println(v ...interface{}) { fmt.Fprintln(&l.capture, v...) } func (l *testLogger) String() string { return l.capture.String() } type defaultLogger struct{} func (l *defaultLogger) Print(v ...interface{}) { log.Print(v...) } func (l *defaultLogger) Printf(format string, v ...interface{}) { log.Printf(format, v...) } func (l *defaultLogger) Println(v ...interface{}) { log.Println(v...) } // Logger for logging messages. // Deprecated: Use ClusterConfig.Logger instead. var Logger StdLogger = &defaultLogger{} cassandra-gocql-driver-1.7.0/lz4/000077500000000000000000000000001467504044300165725ustar00rootroot00000000000000cassandra-gocql-driver-1.7.0/lz4/go.mod000066400000000000000000000016431467504044300177040ustar00rootroot00000000000000// // Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file // to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
// module github.com/gocql/gocql/lz4 go 1.16 require ( github.com/pierrec/lz4/v4 v4.1.8 github.com/stretchr/testify v1.7.0 ) cassandra-gocql-driver-1.7.0/lz4/go.sum000066400000000000000000000022511467504044300177250ustar00rootroot00000000000000github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/pierrec/lz4/v4 v4.1.8 h1:ieHkV+i2BRzngO4Wd/3HGowuZStgq6QkPsD1eolNAO4= github.com/pierrec/lz4/v4 v4.1.8/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= cassandra-gocql-driver-1.7.0/lz4/lz4.go000066400000000000000000000053471467504044300176430ustar00rootroot00000000000000/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. */ package lz4 import ( "encoding/binary" "fmt" "github.com/pierrec/lz4/v4" ) // LZ4Compressor implements the gocql.Compressor interface and can be used to // compress incoming and outgoing frames. According to the Cassandra docs the // LZ4 protocol should be preferred over snappy. (For details refer to // https://cassandra.apache.org/doc/latest/operating/compression.html) // // Implementation note: Cassandra prefixes each compressed block with 4 bytes // of the uncompressed block length, written in big endian order. But the LZ4 // compression library github.com/pierrec/lz4/v4 does not expect the length // field, so it needs to be added to compressed blocks sent to Cassandra, and // removed from ones received from Cassandra before decompression. type LZ4Compressor struct{} func (s LZ4Compressor) Name() string { return "lz4" } func (s LZ4Compressor) Encode(data []byte) ([]byte, error) { buf := make([]byte, lz4.CompressBlockBound(len(data)+4)) var compressor lz4.Compressor n, err := compressor.CompressBlock(data, buf[4:]) // According to lz4.CompressBlock doc, it doesn't fail as long as the dst // buffer length is at least lz4.CompressBlockBound(len(data))) bytes, but // we check for error anyway just to be thorough. 
if err != nil { return nil, err } binary.BigEndian.PutUint32(buf, uint32(len(data))) return buf[:n+4], nil } func (s LZ4Compressor) Decode(data []byte) ([]byte, error) { if len(data) < 4 { return nil, fmt.Errorf("cassandra lz4 block size should be >4, got=%d", len(data)) } uncompressedLength := binary.BigEndian.Uint32(data) if uncompressedLength == 0 { return nil, nil } buf := make([]byte, uncompressedLength) n, err := lz4.UncompressBlock(data[4:], buf) return buf[:n], err } cassandra-gocql-driver-1.7.0/lz4/lz4_test.go000066400000000000000000000034521467504044300206750ustar00rootroot00000000000000/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. 
*/ package lz4 import ( "testing" "github.com/stretchr/testify/require" ) func TestLZ4Compressor(t *testing.T) { var c LZ4Compressor require.Equal(t, "lz4", c.Name()) _, err := c.Decode([]byte{0, 1, 2}) require.EqualError(t, err, "cassandra lz4 block size should be >4, got=3") _, err = c.Decode([]byte{0, 1, 2, 4, 5}) require.EqualError(t, err, "lz4: invalid source or destination buffer too short") // If uncompressed size is zero then nothing is decoded even if present. decoded, err := c.Decode([]byte{0, 0, 0, 0, 5, 7, 8}) require.NoError(t, err) require.Nil(t, decoded) original := []byte("My Test String") encoded, err := c.Encode(original) require.NoError(t, err) decoded, err = c.Decode(encoded) require.NoError(t, err) require.Equal(t, original, decoded) } cassandra-gocql-driver-1.7.0/marshal.go000066400000000000000000002123651467504044300200500ustar00rootroot00000000000000/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2012, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. 
*/ package gocql import ( "bytes" "encoding/binary" "errors" "fmt" "math" "math/big" "math/bits" "net" "reflect" "strconv" "strings" "time" "gopkg.in/inf.v0" ) var ( bigOne = big.NewInt(1) emptyValue reflect.Value ) var ( ErrorUDTUnavailable = errors.New("UDT are not available on protocols less than 3, please update config") ) // Marshaler is the interface implemented by objects that can marshal // themselves into values understood by Cassandra. type Marshaler interface { MarshalCQL(info TypeInfo) ([]byte, error) } // Unmarshaler is the interface implemented by objects that can unmarshal // a Cassandra specific description of themselves. type Unmarshaler interface { UnmarshalCQL(info TypeInfo, data []byte) error } // Marshal returns the CQL encoding of the value for the Cassandra // internal type described by the info parameter. // // nil is serialized as CQL null. // If value implements Marshaler, its MarshalCQL method is called to marshal the data. // If value is a pointer, the pointed-to value is marshaled. 
// // Supported conversions are as follows, other type combinations may be added in the future: // // CQL type | Go type (value) | Note // varchar, ascii, blob, text | string, []byte | // boolean | bool | // tinyint, smallint, int | integer types | // tinyint, smallint, int | string | formatted as base 10 number // bigint, counter | integer types | // bigint, counter | big.Int | // bigint, counter | string | formatted as base 10 number // float | float32 | // double | float64 | // decimal | inf.Dec | // time | int64 | nanoseconds since start of day // time | time.Duration | duration since start of day // timestamp | int64 | milliseconds since Unix epoch // timestamp | time.Time | // list, set | slice, array | // list, set | map[X]struct{} | // map | map[X]Y | // uuid, timeuuid | gocql.UUID | // uuid, timeuuid | [16]byte | raw UUID bytes // uuid, timeuuid | []byte | raw UUID bytes, length must be 16 bytes // uuid, timeuuid | string | hex representation, see ParseUUID // varint | integer types | // varint | big.Int | // varint | string | value of number in decimal notation // inet | net.IP | // inet | string | IPv4 or IPv6 address string // tuple | slice, array | // tuple | struct | fields are marshaled in order of declaration // user-defined type | gocql.UDTMarshaler | MarshalUDT is called // user-defined type | map[string]interface{} | // user-defined type | struct | struct fields' cql tags are used for column names // date | int64 | milliseconds since Unix epoch to start of day (in UTC) // date | time.Time | start of day (in UTC) // date | string | parsed using "2006-01-02" format // duration | int64 | duration in nanoseconds // duration | time.Duration | // duration | gocql.Duration | // duration | string | parsed with time.ParseDuration func Marshal(info TypeInfo, value interface{}) ([]byte, error) { if info.Version() < protoVersion1 { panic("protocol version not set") } if valueRef := reflect.ValueOf(value); valueRef.Kind() == reflect.Ptr { if valueRef.IsNil() 
{ return nil, nil } else if v, ok := value.(Marshaler); ok { return v.MarshalCQL(info) } else { return Marshal(info, valueRef.Elem().Interface()) } } if v, ok := value.(Marshaler); ok { return v.MarshalCQL(info) } switch info.Type() { case TypeVarchar, TypeAscii, TypeBlob, TypeText: return marshalVarchar(info, value) case TypeBoolean: return marshalBool(info, value) case TypeTinyInt: return marshalTinyInt(info, value) case TypeSmallInt: return marshalSmallInt(info, value) case TypeInt: return marshalInt(info, value) case TypeBigInt, TypeCounter: return marshalBigInt(info, value) case TypeFloat: return marshalFloat(info, value) case TypeDouble: return marshalDouble(info, value) case TypeDecimal: return marshalDecimal(info, value) case TypeTime: return marshalTime(info, value) case TypeTimestamp: return marshalTimestamp(info, value) case TypeList, TypeSet: return marshalList(info, value) case TypeMap: return marshalMap(info, value) case TypeUUID, TypeTimeUUID: return marshalUUID(info, value) case TypeVarint: return marshalVarint(info, value) case TypeInet: return marshalInet(info, value) case TypeTuple: return marshalTuple(info, value) case TypeUDT: return marshalUDT(info, value) case TypeDate: return marshalDate(info, value) case TypeDuration: return marshalDuration(info, value) } // detect protocol 2 UDT if strings.HasPrefix(info.Custom(), "org.apache.cassandra.db.marshal.UserType") && info.Version() < 3 { return nil, ErrorUDTUnavailable } // TODO(tux21b): add the remaining types return nil, fmt.Errorf("can not marshal %T into %s", value, info) } // Unmarshal parses the CQL encoded data based on the info parameter that // describes the Cassandra internal data type and stores the result in the // value pointed by value. // // If value implements Unmarshaler, it's UnmarshalCQL method is called to // unmarshal the data. // If value is a pointer to pointer, it is set to nil if the CQL value is // null. Otherwise, nulls are unmarshalled as zero value. 
//
// Supported conversions are as follows, other type combinations may be added in the future:
//
//	CQL type                                | Go type (value)         | Note
//	varchar, ascii, blob, text              | *string                 |
//	varchar, ascii, blob, text              | *[]byte                 | non-nil buffer is reused
//	bool                                    | *bool                   |
//	tinyint, smallint, int, bigint, counter | *integer types          |
//	tinyint, smallint, int, bigint, counter | *big.Int                |
//	tinyint, smallint, int, bigint, counter | *string                 | formatted as base 10 number
//	float                                   | *float32                |
//	double                                  | *float64                |
//	decimal                                 | *inf.Dec                |
//	time                                    | *int64                  | nanoseconds since start of day
//	time                                    | *time.Duration          |
//	timestamp                               | *int64                  | milliseconds since Unix epoch
//	timestamp                               | *time.Time              |
//	list, set                               | *slice, *array          |
//	map                                     | *map[X]Y                |
//	uuid, timeuuid                          | *string                 | see UUID.String
//	uuid, timeuuid                          | *[]byte                 | raw UUID bytes
//	uuid, timeuuid                          | *gocql.UUID             |
//	timeuuid                                | *time.Time              | timestamp of the UUID
//	inet                                    | *net.IP                 |
//	inet                                    | *string                 | IPv4 or IPv6 address string
//	tuple                                   | *slice, *array          |
//	tuple                                   | *struct                 | struct fields are set in order of declaration
//	user-defined types                      | gocql.UDTUnmarshaler    | UnmarshalUDT is called
//	user-defined types                      | *map[string]interface{} |
//	user-defined types                      | *struct                 | cql tag is used to determine field name
//	date                                    | *time.Time              | time of beginning of the day (in UTC)
//	date                                    | *string                 | formatted with 2006-01-02 format
//	duration                                | *gocql.Duration         |
func Unmarshal(info TypeInfo, data []byte, value interface{}) error {
	if v, ok := value.(Unmarshaler); ok {
		return v.UnmarshalCQL(info, data)
	}

	// A **T target can represent null as a nil inner pointer; handled
	// separately so the per-type helpers only see single-level pointers.
	if isNullableValue(value) {
		return unmarshalNullable(info, data, value)
	}

	// Dispatch on the CQL column type to the per-type unmarshal helper.
	switch info.Type() {
	case TypeVarchar, TypeAscii, TypeBlob, TypeText:
		return unmarshalVarchar(info, data, value)
	case TypeBoolean:
		return unmarshalBool(info, data, value)
	case TypeInt:
		return unmarshalInt(info, data, value)
	case TypeBigInt, TypeCounter:
		return unmarshalBigInt(info, data, value)
	case TypeVarint:
		return unmarshalVarint(info, data, value)
	case TypeSmallInt:
		return unmarshalSmallInt(info, data, value)
	case TypeTinyInt:
		return unmarshalTinyInt(info, data, value)
	case TypeFloat:
		return unmarshalFloat(info, data, value)
	case TypeDouble:
		return unmarshalDouble(info, data, value)
	case TypeDecimal:
		return unmarshalDecimal(info, data, value)
	case TypeTime:
		return unmarshalTime(info, data, value)
	case TypeTimestamp:
		return unmarshalTimestamp(info, data, value)
	case TypeList, TypeSet:
		return unmarshalList(info, data, value)
	case TypeMap:
		return unmarshalMap(info, data, value)
	case TypeTimeUUID:
		return unmarshalTimeUUID(info, data, value)
	case TypeUUID:
		return unmarshalUUID(info, data, value)
	case TypeInet:
		return unmarshalInet(info, data, value)
	case TypeTuple:
		return unmarshalTuple(info, data, value)
	case TypeUDT:
		return unmarshalUDT(info, data, value)
	case TypeDate:
		return unmarshalDate(info, data, value)
	case TypeDuration:
		return unmarshalDuration(info, data, value)
	}

	// detect protocol 2 UDT: native UDT support requires protocol >= 3.
	if strings.HasPrefix(info.Custom(), "org.apache.cassandra.db.marshal.UserType") && info.Version() < 3 {
		return ErrorUDTUnavailable
	}

	// TODO(tux21b): add the remaining types
	return fmt.Errorf("can not unmarshal %s into %T", info, value)
}

// isNullableValue reports whether value is a pointer to pointer (**T),
// the only target shape that can represent a CQL null distinctly.
func isNullableValue(value interface{}) bool {
	v := reflect.ValueOf(value)
	return v.Kind() == reflect.Ptr && v.Type().Elem().Kind() == reflect.Ptr
}

// isNullData reports whether the raw column bytes represent a CQL null
// (a nil slice; an empty-but-non-nil slice is a present, empty value).
func isNullData(info TypeInfo, data []byte) bool {
	return data == nil
}

// unmarshalNullable unmarshals into a **T target: null data sets the inner
// pointer to nil, otherwise a fresh T is allocated and unmarshaled into.
func unmarshalNullable(info TypeInfo, data []byte, value interface{}) error {
	valueRef := reflect.ValueOf(value)

	if isNullData(info, data) {
		nilValue := reflect.Zero(valueRef.Type().Elem())
		valueRef.Elem().Set(nilValue)
		return nil
	}

	newValue := reflect.New(valueRef.Type().Elem().Elem())
	valueRef.Elem().Set(newValue)
	return Unmarshal(info, data, newValue.Interface())
}

// marshalVarchar encodes string-like values (varchar/ascii/blob/text) as
// their raw bytes; string-kinded and byte-slice-kinded values are accepted
// via reflection as well.
func marshalVarchar(info TypeInfo, value interface{}) ([]byte, error) {
	switch v := value.(type) {
	case Marshaler:
		return v.MarshalCQL(info)
	case unsetColumn:
		return nil, nil
	case string:
		return []byte(v),
			nil
	case []byte:
		return v, nil
	}

	if value == nil {
		return nil, nil
	}

	rv := reflect.ValueOf(value)
	t := rv.Type()
	k := t.Kind()
	switch {
	case k == reflect.String:
		return []byte(rv.String()), nil
	case k == reflect.Slice && t.Elem().Kind() == reflect.Uint8:
		return rv.Bytes(), nil
	}
	return nil, marshalErrorf("can not marshal %T into %s", value, info)
}

// unmarshalVarchar decodes string-like column data into *string, *[]byte,
// or (via reflection) any string- or byte-slice-kinded pointer target.
func unmarshalVarchar(info TypeInfo, data []byte, value interface{}) error {
	switch v := value.(type) {
	case Unmarshaler:
		return v.UnmarshalCQL(info, data)
	case *string:
		*v = string(data)
		return nil
	case *[]byte:
		if data != nil {
			// Reuse the caller's buffer capacity when possible.
			*v = append((*v)[:0], data...)
		} else {
			*v = nil
		}
		return nil
	}

	rv := reflect.ValueOf(value)
	if rv.Kind() != reflect.Ptr {
		return unmarshalErrorf("can not unmarshal into non-pointer %T", value)
	}
	rv = rv.Elem()
	t := rv.Type()
	k := t.Kind()
	switch {
	case k == reflect.String:
		rv.SetString(string(data))
		return nil
	case k == reflect.Slice && t.Elem().Kind() == reflect.Uint8:
		// Copy so the target does not alias the driver's frame buffer.
		var dataCopy []byte
		if data != nil {
			dataCopy = make([]byte, len(data))
			copy(dataCopy, data)
		}
		rv.SetBytes(dataCopy)
		return nil
	}
	return unmarshalErrorf("can not unmarshal %s into %T", info, value)
}

// marshalSmallInt encodes a 16-bit CQL smallint, range-checking wider
// integer types and parsing strings as base-10. Unsigned 16-bit inputs are
// reinterpreted as int16 (values above MaxInt16 wrap to negative).
func marshalSmallInt(info TypeInfo, value interface{}) ([]byte, error) {
	switch v := value.(type) {
	case Marshaler:
		return v.MarshalCQL(info)
	case unsetColumn:
		return nil, nil
	case int16:
		return encShort(v), nil
	case uint16:
		return encShort(int16(v)), nil
	case int8:
		return encShort(int16(v)), nil
	case uint8:
		return encShort(int16(v)), nil
	case int:
		if v > math.MaxInt16 || v < math.MinInt16 {
			return nil, marshalErrorf("marshal smallint: value %d out of range", v)
		}
		return encShort(int16(v)), nil
	case int32:
		if v > math.MaxInt16 || v < math.MinInt16 {
			return nil, marshalErrorf("marshal smallint: value %d out of range", v)
		}
		return encShort(int16(v)), nil
	case int64:
		if v > math.MaxInt16 || v < math.MinInt16 {
			return nil, marshalErrorf("marshal smallint: value %d out of range", v)
		}
		return encShort(int16(v)), nil
	case uint:
		if v >
			math.MaxUint16 {
			return nil, marshalErrorf("marshal smallint: value %d out of range", v)
		}
		return encShort(int16(v)), nil
	case uint32:
		if v > math.MaxUint16 {
			return nil, marshalErrorf("marshal smallint: value %d out of range", v)
		}
		return encShort(int16(v)), nil
	case uint64:
		if v > math.MaxUint16 {
			return nil, marshalErrorf("marshal smallint: value %d out of range", v)
		}
		return encShort(int16(v)), nil
	case string:
		n, err := strconv.ParseInt(v, 10, 16)
		if err != nil {
			return nil, marshalErrorf("can not marshal %T into %s: %v", value, info, err)
		}
		return encShort(int16(n)), nil
	}

	if value == nil {
		return nil, nil
	}

	// Reflection fallback for named types with an integer underlying kind.
	switch rv := reflect.ValueOf(value); rv.Type().Kind() {
	case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8:
		v := rv.Int()
		if v > math.MaxInt16 || v < math.MinInt16 {
			return nil, marshalErrorf("marshal smallint: value %d out of range", v)
		}
		return encShort(int16(v)), nil
	case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8:
		v := rv.Uint()
		if v > math.MaxUint16 {
			return nil, marshalErrorf("marshal smallint: value %d out of range", v)
		}
		return encShort(int16(v)), nil
	case reflect.Ptr:
		if rv.IsNil() {
			return nil, nil
		}
	}
	return nil, marshalErrorf("can not marshal %T into %s", value, info)
}

// marshalTinyInt encodes an 8-bit CQL tinyint, range-checking wider integer
// types and parsing strings as base-10. uint8 inputs above MaxInt8 wrap to
// negative on the wire (single raw byte).
func marshalTinyInt(info TypeInfo, value interface{}) ([]byte, error) {
	switch v := value.(type) {
	case Marshaler:
		return v.MarshalCQL(info)
	case unsetColumn:
		return nil, nil
	case int8:
		return []byte{byte(v)}, nil
	case uint8:
		return []byte{byte(v)}, nil
	case int16:
		if v > math.MaxInt8 || v < math.MinInt8 {
			return nil, marshalErrorf("marshal tinyint: value %d out of range", v)
		}
		return []byte{byte(v)}, nil
	case uint16:
		if v > math.MaxUint8 {
			return nil, marshalErrorf("marshal tinyint: value %d out of range", v)
		}
		return []byte{byte(v)}, nil
	case int:
		if v > math.MaxInt8 || v < math.MinInt8 {
			return nil, marshalErrorf("marshal tinyint: value %d out of range", v)
		}
		return []byte{byte(v)}, nil
	case int32:
		if v > math.MaxInt8 || v < math.MinInt8 {
			return nil, marshalErrorf("marshal tinyint: value %d out of range", v)
		}
		return []byte{byte(v)}, nil
	case int64:
		if v > math.MaxInt8 || v < math.MinInt8 {
			return nil, marshalErrorf("marshal tinyint: value %d out of range", v)
		}
		return []byte{byte(v)}, nil
	case uint:
		if v > math.MaxUint8 {
			return nil, marshalErrorf("marshal tinyint: value %d out of range", v)
		}
		return []byte{byte(v)}, nil
	case uint32:
		if v > math.MaxUint8 {
			return nil, marshalErrorf("marshal tinyint: value %d out of range", v)
		}
		return []byte{byte(v)}, nil
	case uint64:
		if v > math.MaxUint8 {
			return nil, marshalErrorf("marshal tinyint: value %d out of range", v)
		}
		return []byte{byte(v)}, nil
	case string:
		n, err := strconv.ParseInt(v, 10, 8)
		if err != nil {
			return nil, marshalErrorf("can not marshal %T into %s: %v", value, info, err)
		}
		return []byte{byte(n)}, nil
	}

	if value == nil {
		return nil, nil
	}

	// Reflection fallback for named types with an integer underlying kind.
	switch rv := reflect.ValueOf(value); rv.Type().Kind() {
	case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8:
		v := rv.Int()
		if v > math.MaxInt8 || v < math.MinInt8 {
			return nil, marshalErrorf("marshal tinyint: value %d out of range", v)
		}
		return []byte{byte(v)}, nil
	case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8:
		v := rv.Uint()
		if v > math.MaxUint8 {
			return nil, marshalErrorf("marshal tinyint: value %d out of range", v)
		}
		return []byte{byte(v)}, nil
	case reflect.Ptr:
		if rv.IsNil() {
			return nil, nil
		}
	}
	return nil, marshalErrorf("can not marshal %T into %s", value, info)
}

// marshalInt encodes a 32-bit CQL int, range-checking wider integer types
// and parsing strings as base-10.
func marshalInt(info TypeInfo, value interface{}) ([]byte, error) {
	switch v := value.(type) {
	case Marshaler:
		return v.MarshalCQL(info)
	case unsetColumn:
		return nil, nil
	case int:
		if v > math.MaxInt32 || v < math.MinInt32 {
			return nil, marshalErrorf("marshal int: value %d out of range", v)
		}
		return encInt(int32(v)), nil
	case uint:
		if v > math.MaxUint32 {
			return nil, marshalErrorf("marshal int: value %d out of range", v)
		}
		return encInt(int32(v)), nil
	case
		int64:
		if v > math.MaxInt32 || v < math.MinInt32 {
			return nil, marshalErrorf("marshal int: value %d out of range", v)
		}
		return encInt(int32(v)), nil
	case uint64:
		if v > math.MaxUint32 {
			return nil, marshalErrorf("marshal int: value %d out of range", v)
		}
		return encInt(int32(v)), nil
	case int32:
		return encInt(v), nil
	case uint32:
		return encInt(int32(v)), nil
	case int16:
		return encInt(int32(v)), nil
	case uint16:
		return encInt(int32(v)), nil
	case int8:
		return encInt(int32(v)), nil
	case uint8:
		return encInt(int32(v)), nil
	case string:
		i, err := strconv.ParseInt(v, 10, 32)
		if err != nil {
			return nil, marshalErrorf("can not marshal string to int: %s", err)
		}
		return encInt(int32(i)), nil
	}

	if value == nil {
		return nil, nil
	}

	// Reflection fallback for named types with an integer underlying kind.
	// NOTE(review): the unsigned path here rejects values > MaxInt32 while
	// the typed uint/uint64 cases above allow up to MaxUint32 — confirm
	// whether this asymmetry is intentional.
	switch rv := reflect.ValueOf(value); rv.Type().Kind() {
	case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8:
		v := rv.Int()
		if v > math.MaxInt32 || v < math.MinInt32 {
			return nil, marshalErrorf("marshal int: value %d out of range", v)
		}
		return encInt(int32(v)), nil
	case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8:
		v := rv.Uint()
		if v > math.MaxInt32 {
			return nil, marshalErrorf("marshal int: value %d out of range", v)
		}
		return encInt(int32(v)), nil
	case reflect.Ptr:
		if rv.IsNil() {
			return nil, nil
		}
	}
	return nil, marshalErrorf("can not marshal %T into %s", value, info)
}

// encInt encodes x as 4 big-endian bytes.
func encInt(x int32) []byte {
	return []byte{byte(x >> 24), byte(x >> 16), byte(x >> 8), byte(x)}
}

// decInt decodes a 4-byte big-endian int32; wrong-length input yields 0.
func decInt(x []byte) int32 {
	if len(x) != 4 {
		return 0
	}
	return int32(x[0])<<24 | int32(x[1])<<16 | int32(x[2])<<8 | int32(x[3])
}

// encShort encodes x as 2 big-endian bytes.
func encShort(x int16) []byte {
	p := make([]byte, 2)
	p[0] = byte(x >> 8)
	p[1] = byte(x)
	return p
}

// decShort decodes a 2-byte big-endian int16; wrong-length input yields 0.
func decShort(p []byte) int16 {
	if len(p) != 2 {
		return 0
	}
	return int16(p[0])<<8 | int16(p[1])
}

// decTiny decodes a single byte as int8; wrong-length input yields 0.
func decTiny(p []byte) int8 {
	if len(p) != 1 {
		return 0
	}
	return int8(p[0])
}

// marshalBigInt encodes a 64-bit CQL bigint/counter from any integer type,
// big.Int, or a base-10 string.
func marshalBigInt(info TypeInfo, value interface{}) ([]byte, error) {
	switch v := value.(type) {
	case Marshaler:
		return v.MarshalCQL(info)
	case unsetColumn:
		return nil, nil
	case int:
		return encBigInt(int64(v)), nil
	case uint:
		if uint64(v) > math.MaxInt64 {
			return nil, marshalErrorf("marshal bigint: value %d out of range", v)
		}
		return encBigInt(int64(v)), nil
	case int64:
		return encBigInt(v), nil
	case uint64:
		// NOTE(review): unlike the uint case above, uint64 values above
		// MaxInt64 are not rejected here and silently wrap negative —
		// confirm whether this is intentional (marshalVarint relies on
		// calling this for its default path).
		return encBigInt(int64(v)), nil
	case int32:
		return encBigInt(int64(v)), nil
	case uint32:
		return encBigInt(int64(v)), nil
	case int16:
		return encBigInt(int64(v)), nil
	case uint16:
		return encBigInt(int64(v)), nil
	case int8:
		return encBigInt(int64(v)), nil
	case uint8:
		return encBigInt(int64(v)), nil
	case big.Int:
		return encBigInt2C(&v), nil
	case string:
		i, err := strconv.ParseInt(value.(string), 10, 64)
		if err != nil {
			return nil, marshalErrorf("can not marshal string to bigint: %s", err)
		}
		return encBigInt(i), nil
	}

	if value == nil {
		return nil, nil
	}

	// Reflection fallback for named types with an integer underlying kind.
	rv := reflect.ValueOf(value)
	switch rv.Type().Kind() {
	case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8:
		v := rv.Int()
		return encBigInt(v), nil
	case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8:
		v := rv.Uint()
		if v > math.MaxInt64 {
			return nil, marshalErrorf("marshal bigint: value %d out of range", v)
		}
		return encBigInt(int64(v)), nil
	}
	return nil, marshalErrorf("can not marshal %T into %s", value, info)
}

// encBigInt encodes x as 8 big-endian bytes.
func encBigInt(x int64) []byte {
	return []byte{byte(x >> 56), byte(x >> 48), byte(x >> 40), byte(x >> 32),
		byte(x >> 24), byte(x >> 16), byte(x >> 8), byte(x)}
}

// bytesToInt64 folds up to 8 big-endian bytes into an int64 without sign
// extension (callers sign-extend when needed).
func bytesToInt64(data []byte) (ret int64) {
	for i := range data {
		ret |= int64(data[i]) << (8 * uint(len(data)-i-1))
	}
	return ret
}

// bytesToUint64 folds up to 8 big-endian bytes into a uint64.
func bytesToUint64(data []byte) (ret uint64) {
	for i := range data {
		ret |= uint64(data[i]) << (8 * uint(len(data)-i-1))
	}
	return ret
}

// unmarshalBigInt decodes an 8-byte bigint/counter via the shared intlike path.
func unmarshalBigInt(info TypeInfo, data []byte, value interface{}) error {
	return unmarshalIntlike(info, decBigInt(data), data, value)
}

// unmarshalInt decodes a 4-byte int via the shared intlike path.
func unmarshalInt(info TypeInfo, data []byte, value interface{}) error {
	return unmarshalIntlike(info,
		int64(decInt(data)), data, value)
}

// unmarshalSmallInt decodes a 2-byte smallint via the shared intlike path.
func unmarshalSmallInt(info TypeInfo, data []byte, value interface{}) error {
	return unmarshalIntlike(info, int64(decShort(data)), data, value)
}

// unmarshalTinyInt decodes a 1-byte tinyint via the shared intlike path.
func unmarshalTinyInt(info TypeInfo, data []byte, value interface{}) error {
	return unmarshalIntlike(info, int64(decTiny(data)), data, value)
}

// unmarshalVarint decodes an arbitrary-precision varint. Values wider than
// 8 bytes fit only *big.Int, except the 9-byte positive form which also
// fits *uint64.
func unmarshalVarint(info TypeInfo, data []byte, value interface{}) error {
	switch v := value.(type) {
	case *big.Int:
		// int64Val is unused for big.Int targets; the raw bytes are decoded
		// directly inside unmarshalIntlike.
		return unmarshalIntlike(info, 0, data, value)
	case *uint64:
		if len(data) == 9 && data[0] == 0 {
			// Positive value needing all 64 bits, carried with a leading
			// zero sign byte.
			*v = bytesToUint64(data[1:])
			return nil
		}
	}

	if len(data) > 8 {
		return unmarshalErrorf("unmarshal int: varint value %v out of range for %T (use big.Int)", data, value)
	}

	int64Val := bytesToInt64(data)
	if len(data) > 0 && len(data) < 8 && data[0]&0x80 > 0 {
		// Sign-extend negative values encoded in fewer than 8 bytes.
		int64Val -= (1 << uint(len(data)*8))
	}
	return unmarshalIntlike(info, int64Val, data, value)
}

// marshalVarint encodes integers as a minimal-length big-endian two's
// complement varint.
func marshalVarint(info TypeInfo, value interface{}) ([]byte, error) {
	var (
		retBytes []byte
		err      error
	)

	switch v := value.(type) {
	case unsetColumn:
		return nil, nil
	case uint64:
		if v > uint64(math.MaxInt64) {
			// Needs a leading zero byte so it is not read back as negative.
			retBytes = make([]byte, 9)
			binary.BigEndian.PutUint64(retBytes[1:], v)
		} else {
			retBytes = make([]byte, 8)
			binary.BigEndian.PutUint64(retBytes, v)
		}
	default:
		retBytes, err = marshalBigInt(info, value)
	}

	if err == nil {
		// trim down to most significant byte: drop redundant leading 0x00
		// (positive) or 0xFF (negative) bytes while preserving the sign bit
		// of the first remaining byte.
		i := 0
		for ; i < len(retBytes)-1; i++ {
			b0 := retBytes[i]
			if b0 != 0 && b0 != 0xFF {
				break
			}

			b1 := retBytes[i+1]
			if b0 == 0 && b1 != 0 {
				if b1&0x80 == 0 {
					i++
				}
				break
			}

			if b0 == 0xFF && b1 != 0xFF {
				if b1&0x80 > 0 {
					i++
				}
				break
			}
		}
		retBytes = retBytes[i:]
	}

	return retBytes, err
}

// unmarshalIntlike stores int64Val into an integer-like target, range
// checking against the destination width. For unsigned targets the source
// column width (info.Type()) selects a mask so that negative column values
// round-trip as their unsigned bit pattern instead of failing.
func unmarshalIntlike(info TypeInfo, int64Val int64, data []byte, value interface{}) error {
	switch v := value.(type) {
	case *int:
		// On 32-bit platforms int cannot hold the full int64 range.
		if ^uint(0) == math.MaxUint32 && (int64Val < math.MinInt32 || int64Val > math.MaxInt32) {
			return unmarshalErrorf("unmarshal int: value %d out of range for %T", int64Val, *v)
		}
		*v = int(int64Val)
		return nil
	case *uint:
		unitVal := uint64(int64Val)
		switch info.Type() {
		case TypeInt:
			*v = uint(unitVal) & 0xFFFFFFFF
		case TypeSmallInt:
			*v = uint(unitVal) & 0xFFFF
		case TypeTinyInt:
			*v = uint(unitVal) & 0xFF
		default:
			if ^uint(0) == math.MaxUint32 && (int64Val < 0 || int64Val > math.MaxUint32) {
				return unmarshalErrorf("unmarshal int: value %d out of range for %T", unitVal, *v)
			}
			*v = uint(unitVal)
		}
		return nil
	case *int64:
		*v = int64Val
		return nil
	case *uint64:
		switch info.Type() {
		case TypeInt:
			*v = uint64(int64Val) & 0xFFFFFFFF
		case TypeSmallInt:
			*v = uint64(int64Val) & 0xFFFF
		case TypeTinyInt:
			*v = uint64(int64Val) & 0xFF
		default:
			*v = uint64(int64Val)
		}
		return nil
	case *int32:
		if int64Val < math.MinInt32 || int64Val > math.MaxInt32 {
			return unmarshalErrorf("unmarshal int: value %d out of range for %T", int64Val, *v)
		}
		*v = int32(int64Val)
		return nil
	case *uint32:
		switch info.Type() {
		case TypeInt:
			*v = uint32(int64Val) & 0xFFFFFFFF
		case TypeSmallInt:
			*v = uint32(int64Val) & 0xFFFF
		case TypeTinyInt:
			*v = uint32(int64Val) & 0xFF
		default:
			if int64Val < 0 || int64Val > math.MaxUint32 {
				return unmarshalErrorf("unmarshal int: value %d out of range for %T", int64Val, *v)
			}
			*v = uint32(int64Val) & 0xFFFFFFFF
		}
		return nil
	case *int16:
		if int64Val < math.MinInt16 || int64Val > math.MaxInt16 {
			return unmarshalErrorf("unmarshal int: value %d out of range for %T", int64Val, *v)
		}
		*v = int16(int64Val)
		return nil
	case *uint16:
		switch info.Type() {
		case TypeSmallInt:
			*v = uint16(int64Val) & 0xFFFF
		case TypeTinyInt:
			*v = uint16(int64Val) & 0xFF
		default:
			if int64Val < 0 || int64Val > math.MaxUint16 {
				return unmarshalErrorf("unmarshal int: value %d out of range for %T", int64Val, *v)
			}
			*v = uint16(int64Val) & 0xFFFF
		}
		return nil
	case *int8:
		if int64Val < math.MinInt8 || int64Val > math.MaxInt8 {
			return unmarshalErrorf("unmarshal int: value %d out of range for %T", int64Val, *v)
		}
		*v = int8(int64Val)
		return nil
	case *uint8:
		if info.Type() != TypeTinyInt && (int64Val < 0 || int64Val > math.MaxUint8) {
			return unmarshalErrorf("unmarshal int: value %d out of range for %T", int64Val, *v)
		}
		*v = uint8(int64Val) & 0xFF
		return nil
	case *big.Int:
		// Decode from the raw bytes so arbitrary-width varints survive.
		decBigInt2C(data, v)
		return nil
	case *string:
		*v = strconv.FormatInt(int64Val, 10)
		return nil
	}

	// Reflection fallback mirroring the typed cases above, for named types
	// with integer underlying kinds.
	rv := reflect.ValueOf(value)
	if rv.Kind() != reflect.Ptr {
		return unmarshalErrorf("can not unmarshal into non-pointer %T", value)
	}
	rv = rv.Elem()

	switch rv.Type().Kind() {
	case reflect.Int:
		if ^uint(0) == math.MaxUint32 && (int64Val < math.MinInt32 || int64Val > math.MaxInt32) {
			return unmarshalErrorf("unmarshal int: value %d out of range", int64Val)
		}
		rv.SetInt(int64Val)
		return nil
	case reflect.Int64:
		rv.SetInt(int64Val)
		return nil
	case reflect.Int32:
		if int64Val < math.MinInt32 || int64Val > math.MaxInt32 {
			return unmarshalErrorf("unmarshal int: value %d out of range", int64Val)
		}
		rv.SetInt(int64Val)
		return nil
	case reflect.Int16:
		if int64Val < math.MinInt16 || int64Val > math.MaxInt16 {
			return unmarshalErrorf("unmarshal int: value %d out of range", int64Val)
		}
		rv.SetInt(int64Val)
		return nil
	case reflect.Int8:
		if int64Val < math.MinInt8 || int64Val > math.MaxInt8 {
			return unmarshalErrorf("unmarshal int: value %d out of range", int64Val)
		}
		rv.SetInt(int64Val)
		return nil
	case reflect.Uint:
		unitVal := uint64(int64Val)
		switch info.Type() {
		case TypeInt:
			rv.SetUint(unitVal & 0xFFFFFFFF)
		case TypeSmallInt:
			rv.SetUint(unitVal & 0xFFFF)
		case TypeTinyInt:
			rv.SetUint(unitVal & 0xFF)
		default:
			if ^uint(0) == math.MaxUint32 && (int64Val < 0 || int64Val > math.MaxUint32) {
				return unmarshalErrorf("unmarshal int: value %d out of range for %s", unitVal, rv.Type())
			}
			rv.SetUint(unitVal)
		}
		return nil
	case reflect.Uint64:
		unitVal := uint64(int64Val)
		switch info.Type() {
		case TypeInt:
			rv.SetUint(unitVal & 0xFFFFFFFF)
		case TypeSmallInt:
			rv.SetUint(unitVal & 0xFFFF)
		case TypeTinyInt:
			rv.SetUint(unitVal & 0xFF)
		default:
			rv.SetUint(unitVal)
		}
		return nil
	case reflect.Uint32:
		unitVal := uint64(int64Val)
		switch info.Type() {
		case TypeInt:
			rv.SetUint(unitVal &
				0xFFFFFFFF)
		case TypeSmallInt:
			rv.SetUint(unitVal & 0xFFFF)
		case TypeTinyInt:
			rv.SetUint(unitVal & 0xFF)
		default:
			if int64Val < 0 || int64Val > math.MaxUint32 {
				return unmarshalErrorf("unmarshal int: value %d out of range for %s", int64Val, rv.Type())
			}
			rv.SetUint(unitVal & 0xFFFFFFFF)
		}
		return nil
	case reflect.Uint16:
		unitVal := uint64(int64Val)
		switch info.Type() {
		case TypeSmallInt:
			rv.SetUint(unitVal & 0xFFFF)
		case TypeTinyInt:
			rv.SetUint(unitVal & 0xFF)
		default:
			if int64Val < 0 || int64Val > math.MaxUint16 {
				return unmarshalErrorf("unmarshal int: value %d out of range for %s", int64Val, rv.Type())
			}
			rv.SetUint(unitVal & 0xFFFF)
		}
		return nil
	case reflect.Uint8:
		if info.Type() != TypeTinyInt && (int64Val < 0 || int64Val > math.MaxUint8) {
			return unmarshalErrorf("unmarshal int: value %d out of range for %s", int64Val, rv.Type())
		}
		rv.SetUint(uint64(int64Val) & 0xff)
		return nil
	}
	return unmarshalErrorf("can not unmarshal %s into %T", info, value)
}

// decBigInt decodes an 8-byte big-endian int64; wrong-length input yields 0.
func decBigInt(data []byte) int64 {
	if len(data) != 8 {
		return 0
	}
	return int64(data[0])<<56 | int64(data[1])<<48 |
		int64(data[2])<<40 | int64(data[3])<<32 |
		int64(data[4])<<24 | int64(data[5])<<16 |
		int64(data[6])<<8 | int64(data[7])
}

// marshalBool encodes a CQL boolean as a single 0/1 byte.
func marshalBool(info TypeInfo, value interface{}) ([]byte, error) {
	switch v := value.(type) {
	case Marshaler:
		return v.MarshalCQL(info)
	case unsetColumn:
		return nil, nil
	case bool:
		return encBool(v), nil
	}

	if value == nil {
		return nil, nil
	}

	rv := reflect.ValueOf(value)
	switch rv.Type().Kind() {
	case reflect.Bool:
		return encBool(rv.Bool()), nil
	}
	return nil, marshalErrorf("can not marshal %T into %s", value, info)
}

// encBool encodes true as {1} and false as {0}.
func encBool(v bool) []byte {
	if v {
		return []byte{1}
	}
	return []byte{0}
}

// unmarshalBool decodes boolean column data into *bool or any bool-kinded
// pointer target.
func unmarshalBool(info TypeInfo, data []byte, value interface{}) error {
	switch v := value.(type) {
	case Unmarshaler:
		return v.UnmarshalCQL(info, data)
	case *bool:
		*v = decBool(data)
		return nil
	}

	rv := reflect.ValueOf(value)
	if rv.Kind() != reflect.Ptr {
		return unmarshalErrorf("can not unmarshal into non-pointer %T", value)
	}
	rv = rv.Elem()
	switch rv.Type().Kind() {
	case reflect.Bool:
		rv.SetBool(decBool(data))
		return nil
	}
	return unmarshalErrorf("can not unmarshal %s into %T", info, value)
}

// decBool treats empty data as false and any non-zero first byte as true.
func decBool(v []byte) bool {
	if len(v) == 0 {
		return false
	}
	return v[0] != 0
}

// marshalFloat encodes a CQL float as the 4-byte IEEE-754 bit pattern.
func marshalFloat(info TypeInfo, value interface{}) ([]byte, error) {
	switch v := value.(type) {
	case Marshaler:
		return v.MarshalCQL(info)
	case unsetColumn:
		return nil, nil
	case float32:
		return encInt(int32(math.Float32bits(v))), nil
	}

	if value == nil {
		return nil, nil
	}

	rv := reflect.ValueOf(value)
	switch rv.Type().Kind() {
	case reflect.Float32:
		return encInt(int32(math.Float32bits(float32(rv.Float())))), nil
	}
	return nil, marshalErrorf("can not marshal %T into %s", value, info)
}

// unmarshalFloat decodes 4-byte IEEE-754 data into *float32 or any
// float32-kinded pointer target.
func unmarshalFloat(info TypeInfo, data []byte, value interface{}) error {
	switch v := value.(type) {
	case Unmarshaler:
		return v.UnmarshalCQL(info, data)
	case *float32:
		*v = math.Float32frombits(uint32(decInt(data)))
		return nil
	}

	rv := reflect.ValueOf(value)
	if rv.Kind() != reflect.Ptr {
		return unmarshalErrorf("can not unmarshal into non-pointer %T", value)
	}
	rv = rv.Elem()
	switch rv.Type().Kind() {
	case reflect.Float32:
		rv.SetFloat(float64(math.Float32frombits(uint32(decInt(data)))))
		return nil
	}
	return unmarshalErrorf("can not unmarshal %s into %T", info, value)
}

// marshalDouble encodes a CQL double as the 8-byte IEEE-754 bit pattern.
func marshalDouble(info TypeInfo, value interface{}) ([]byte, error) {
	switch v := value.(type) {
	case Marshaler:
		return v.MarshalCQL(info)
	case unsetColumn:
		return nil, nil
	case float64:
		return encBigInt(int64(math.Float64bits(v))), nil
	}

	if value == nil {
		return nil, nil
	}

	rv := reflect.ValueOf(value)
	switch rv.Type().Kind() {
	case reflect.Float64:
		return encBigInt(int64(math.Float64bits(rv.Float()))), nil
	}
	return nil, marshalErrorf("can not marshal %T into %s", value, info)
}

// unmarshalDouble decodes 8-byte IEEE-754 data into *float64 or any
// float64-kinded pointer target.
func unmarshalDouble(info TypeInfo, data []byte, value interface{}) error {
	switch v := value.(type) {
	case Unmarshaler:
		return v.UnmarshalCQL(info, data)
	case *float64:
		*v = math.Float64frombits(uint64(decBigInt(data)))
		return nil
	}

	rv := reflect.ValueOf(value)
	if rv.Kind() != reflect.Ptr {
		return unmarshalErrorf("can not unmarshal into non-pointer %T", value)
	}
	rv = rv.Elem()
	switch rv.Type().Kind() {
	case reflect.Float64:
		rv.SetFloat(math.Float64frombits(uint64(decBigInt(data))))
		return nil
	}
	return unmarshalErrorf("can not unmarshal %s into %T", info, value)
}

// marshalDecimal encodes an inf.Dec as a 4-byte scale followed by the
// two's-complement unscaled value.
func marshalDecimal(info TypeInfo, value interface{}) ([]byte, error) {
	if value == nil {
		return nil, nil
	}

	switch v := value.(type) {
	case Marshaler:
		return v.MarshalCQL(info)
	case unsetColumn:
		return nil, nil
	case inf.Dec:
		unscaled := encBigInt2C(v.UnscaledBig())
		if unscaled == nil {
			return nil, marshalErrorf("can not marshal %T into %s", value, info)
		}

		buf := make([]byte, 4+len(unscaled))
		copy(buf[0:4], encInt(int32(v.Scale())))
		copy(buf[4:], unscaled)
		return buf, nil
	}
	return nil, marshalErrorf("can not marshal %T into %s", value, info)
}

// unmarshalDecimal decodes scale + two's-complement unscaled bytes into
// an *inf.Dec.
func unmarshalDecimal(info TypeInfo, data []byte, value interface{}) error {
	switch v := value.(type) {
	case Unmarshaler:
		return v.UnmarshalCQL(info, data)
	case *inf.Dec:
		if len(data) < 4 {
			return unmarshalErrorf("inf.Dec needs at least 4 bytes, while value has only %d", len(data))
		}
		scale := decInt(data[0:4])
		unscaled := decBigInt2C(data[4:], nil)
		*v = *inf.NewDecBig(unscaled, inf.Scale(scale))
		return nil
	}
	return unmarshalErrorf("can not unmarshal %s into %T", info, value)
}

// decBigInt2C sets the value of n to the big-endian two's complement
// value stored in the given data. If data[0]&80 != 0, the number
// is negative. If data is empty, the result will be 0.
func decBigInt2C(data []byte, n *big.Int) *big.Int {
	if n == nil {
		n = new(big.Int)
	}
	n.SetBytes(data)
	if len(data) > 0 && data[0]&0x80 > 0 {
		// Convert the unsigned interpretation into the negative
		// two's-complement value.
		n.Sub(n, new(big.Int).Lsh(bigOne, uint(len(data))*8))
	}
	return n
}

// encBigInt2C returns the big-endian two's complement
// form of n.
func encBigInt2C(n *big.Int) []byte {
	switch n.Sign() {
	case 0:
		return []byte{0}
	case 1:
		b := n.Bytes()
		if b[0]&0x80 > 0 {
			// Prepend a zero byte so the value is not read back as negative.
			b = append([]byte{0}, b...)
		}
		return b
	case -1:
		length := uint(n.BitLen()/8+1) * 8
		b := new(big.Int).Add(n, new(big.Int).Lsh(bigOne, length)).Bytes()
		// When the most significant bit is on a byte
		// boundary, we can get some extra significant
		// bits, so strip them off when that happens.
		if len(b) >= 2 && b[0] == 0xff && b[1]&0x80 != 0 {
			b = b[1:]
		}
		return b
	}
	return nil
}

// marshalTime encodes a CQL time value as an 8-byte count of nanoseconds
// since the start of day.
func marshalTime(info TypeInfo, value interface{}) ([]byte, error) {
	switch v := value.(type) {
	case Marshaler:
		return v.MarshalCQL(info)
	case unsetColumn:
		return nil, nil
	case int64:
		return encBigInt(v), nil
	case time.Duration:
		return encBigInt(v.Nanoseconds()), nil
	}

	if value == nil {
		return nil, nil
	}

	rv := reflect.ValueOf(value)
	switch rv.Type().Kind() {
	case reflect.Int64:
		return encBigInt(rv.Int()), nil
	}
	return nil, marshalErrorf("can not marshal %T into %s", value, info)
}

// marshalTimestamp encodes a CQL timestamp as 8-byte milliseconds since the
// Unix epoch. A zero time.Time marshals to empty data (treated as null-ish
// rather than epoch).
func marshalTimestamp(info TypeInfo, value interface{}) ([]byte, error) {
	switch v := value.(type) {
	case Marshaler:
		return v.MarshalCQL(info)
	case unsetColumn:
		return nil, nil
	case int64:
		return encBigInt(v), nil
	case time.Time:
		if v.IsZero() {
			return []byte{}, nil
		}
		// Milliseconds = seconds*1e3 plus the millisecond part of the
		// nanosecond field.
		x := int64(v.UTC().Unix()*1e3) + int64(v.UTC().Nanosecond()/1e6)
		return encBigInt(x), nil
	}

	if value == nil {
		return nil, nil
	}

	rv := reflect.ValueOf(value)
	switch rv.Type().Kind() {
	case reflect.Int64:
		return encBigInt(rv.Int()), nil
	}
	return nil, marshalErrorf("can not marshal %T into %s", value, info)
}

// unmarshalTime decodes a time column (nanoseconds since start of day) into
// *int64, *time.Duration, or any int64-kinded pointer target.
func unmarshalTime(info TypeInfo, data []byte, value interface{}) error {
	switch v := value.(type) {
	case Unmarshaler:
		return v.UnmarshalCQL(info, data)
	case *int64:
		*v = decBigInt(data)
		return nil
	case *time.Duration:
		*v = time.Duration(decBigInt(data))
		return nil
	}

	rv := reflect.ValueOf(value)
	if rv.Kind() != reflect.Ptr {
		return unmarshalErrorf("can not unmarshal into non-pointer %T", value)
	}
	rv = rv.Elem()
	switch rv.Type().Kind() {
	case reflect.Int64:
		rv.SetInt(decBigInt(data))
		return nil
	}
	return unmarshalErrorf("can not unmarshal %s into %T", info, value)
}

// unmarshalTimestamp decodes a timestamp column (milliseconds since the
// Unix epoch) into *int64, *time.Time (UTC; empty data yields the zero
// time), or any int64-kinded pointer target.
func unmarshalTimestamp(info TypeInfo, data []byte, value interface{}) error {
	switch v := value.(type) {
	case Unmarshaler:
		return v.UnmarshalCQL(info, data)
	case *int64:
		*v = decBigInt(data)
		return nil
	case *time.Time:
		if len(data) == 0 {
			*v = time.Time{}
			return nil
		}
		x := decBigInt(data)
		sec := x / 1000
		nsec := (x - sec*1000) * 1000000
		*v = time.Unix(sec, nsec).In(time.UTC)
		return nil
	}

	rv := reflect.ValueOf(value)
	if rv.Kind() != reflect.Ptr {
		return unmarshalErrorf("can not unmarshal into non-pointer %T", value)
	}
	rv = rv.Elem()
	switch rv.Type().Kind() {
	case reflect.Int64:
		rv.SetInt(decBigInt(data))
		return nil
	}
	return unmarshalErrorf("can not unmarshal %s into %T", info, value)
}

const millisecondsInADay int64 = 24 * 60 * 60 * 1000

// marshalDate encodes a CQL date as a 4-byte unsigned day count where
// 1<<31 represents the Unix epoch (days before the epoch are below the
// bias, days after are above it).
func marshalDate(info TypeInfo, value interface{}) ([]byte, error) {
	var timestamp int64
	switch v := value.(type) {
	case Marshaler:
		return v.MarshalCQL(info)
	case unsetColumn:
		return nil, nil
	case int64:
		timestamp = v
		x := timestamp/millisecondsInADay + int64(1<<31)
		return encInt(int32(x)), nil
	case time.Time:
		if v.IsZero() {
			return []byte{}, nil
		}
		timestamp = int64(v.UTC().Unix()*1e3) + int64(v.UTC().Nanosecond()/1e6)
		x := timestamp/millisecondsInADay + int64(1<<31)
		return encInt(int32(x)), nil
	case *time.Time:
		if v.IsZero() {
			return []byte{}, nil
		}
		timestamp = int64(v.UTC().Unix()*1e3) + int64(v.UTC().Nanosecond()/1e6)
		x := timestamp/millisecondsInADay + int64(1<<31)
		return encInt(int32(x)), nil
	case string:
		if v == "" {
			return []byte{}, nil
		}
		t, err := time.Parse("2006-01-02", v)
		if err != nil {
			return nil, marshalErrorf("can not marshal %T into %s, date layout must be '2006-01-02'", value, info)
		}
		timestamp = int64(t.UTC().Unix()*1e3) + int64(t.UTC().Nanosecond()/1e6)
		x := timestamp/millisecondsInADay + int64(1<<31)
		return encInt(int32(x)), nil
	}

	if value == nil {
		return nil, nil
	}
	return nil, marshalErrorf("can not marshal %T into %s", value, info)
}

// unmarshalDate decodes the 4-byte epoch-biased day count into *time.Time
// (start of day, UTC) or *string ("2006-01-02"); empty data yields the
// zero value.
func unmarshalDate(info TypeInfo, data []byte, value interface{}) error {
	switch v := value.(type) {
	case Unmarshaler:
		return v.UnmarshalCQL(info, data)
	case *time.Time:
		if len(data) == 0 {
			*v = time.Time{}
			return nil
		}
		var origin uint32 = 1 << 31
		var current uint32 = binary.BigEndian.Uint32(data)
		timestamp := (int64(current) - int64(origin)) * millisecondsInADay
		*v = time.UnixMilli(timestamp).In(time.UTC)
		return nil
	case *string:
		if len(data) == 0 {
			*v = ""
			return nil
		}
		var origin uint32 = 1 << 31
		var current uint32 = binary.BigEndian.Uint32(data)
		timestamp := (int64(current) - int64(origin)) * millisecondsInADay
		*v = time.UnixMilli(timestamp).In(time.UTC).Format("2006-01-02")
		return nil
	}
	return unmarshalErrorf("can not unmarshal %s into %T", info, value)
}

// marshalDuration encodes a CQL duration as three vints: months, days,
// nanoseconds. int64/time.Duration/string inputs carry nanoseconds only.
func marshalDuration(info TypeInfo, value interface{}) ([]byte, error) {
	switch v := value.(type) {
	case Marshaler:
		return v.MarshalCQL(info)
	case unsetColumn:
		return nil, nil
	case int64:
		return encVints(0, 0, v), nil
	case time.Duration:
		return encVints(0, 0, v.Nanoseconds()), nil
	case string:
		d, err := time.ParseDuration(v)
		if err != nil {
			return nil, err
		}
		return encVints(0, 0, d.Nanoseconds()), nil
	case Duration:
		return encVints(v.Months, v.Days, v.Nanoseconds), nil
	}

	if value == nil {
		return nil, nil
	}

	rv := reflect.ValueOf(value)
	switch rv.Type().Kind() {
	case reflect.Int64:
		// NOTE(review): this fallback emits a plain 8-byte bigint rather
		// than the vint triple used by the typed int64 case above —
		// confirm whether int64-kinded named types are meant to take this
		// path.
		return encBigInt(rv.Int()), nil
	}
	return nil, marshalErrorf("can not marshal %T into %s", value, info)
}

// unmarshalDuration decodes the months/days/nanoseconds vint triple into a
// *Duration; empty data yields the zero Duration.
func unmarshalDuration(info TypeInfo, data []byte, value interface{}) error {
	switch v := value.(type) {
	case Unmarshaler:
		return v.UnmarshalCQL(info, data)
	case *Duration:
		if len(data) == 0 {
			*v = Duration{
				Months:      0,
				Days:        0,
				Nanoseconds: 0,
			}
			return nil
		}
		months, days, nanos, err := decVints(data)
		if err != nil {
			return unmarshalErrorf("failed to unmarshal %s into %T: %s", info, value, err.Error())
		}
		*v = Duration{
			Months:      months,
			Days:        days,
			Nanoseconds: nanos,
		}
		return nil
	}
	return unmarshalErrorf("can not unmarshal %s into %T", info, value)
}

// decVints decodes the three consecutive vints of a duration value:
// months, days, nanoseconds.
func decVints(data []byte) (int32, int32, int64, error) {
	month, i, err := decVint(data, 0)
	if err != nil {
		return 0, 0, 0, fmt.Errorf("failed to extract month: %s", err.Error())
	}
	days, i, err := decVint(data, i)
	if err != nil {
		return 0, 0, 0, fmt.Errorf("failed to extract days: %s", err.Error())
	}
	nanos, _, err := decVint(data, i)
	if err != nil {
		return 0, 0, 0, fmt.Errorf("failed to extract nanoseconds: %s", err.Error())
	}
	return int32(month), int32(days), nanos, err
}

// decVint decodes one zig-zag vint starting at data[start] and returns the
// value together with the offset of the byte after it. The count of leading
// one bits in the first byte gives the number of extra bytes.
func decVint(data []byte, start int) (int64, int, error) {
	if len(data) <= start {
		return 0, 0, errors.New("unexpected eof")
	}
	firstByte := data[start]
	if firstByte&0x80 == 0 {
		// Single-byte encoding.
		return decIntZigZag(uint64(firstByte)), start + 1, nil
	}
	numBytes := bits.LeadingZeros32(uint32(^firstByte)) - 24
	ret := uint64(firstByte & (0xff >> uint(numBytes)))
	if len(data) < start+numBytes+1 {
		return 0, 0, fmt.Errorf("data expect to have %d bytes, but it has only %d", start+numBytes+1, len(data))
	}
	for i := start; i < start+numBytes; i++ {
		ret <<= 8
		ret |= uint64(data[i+1] & 0xff)
	}
	return decIntZigZag(ret), start + numBytes + 1, nil
}

// decIntZigZag reverses the zig-zag mapping: even values are non-negative.
func decIntZigZag(n uint64) int64 {
	return int64((n >> 1) ^ -(n & 1))
}

// encIntZigZag maps signed values to unsigned so small magnitudes encode
// in few bytes regardless of sign.
func encIntZigZag(n int64) uint64 {
	return uint64((n >> 63) ^ (n << 1))
}

// encVints concatenates the three duration vints.
// NOTE(review): the second parameter is named seconds but callers pass the
// days component — the name looks like a historical slip.
func encVints(months int32, seconds int32, nanos int64) []byte {
	buf := append(encVint(int64(months)), encVint(int64(seconds))...)
	return append(buf, encVint(nanos)...)
} func encVint(v int64) []byte { vEnc := encIntZigZag(v) lead0 := bits.LeadingZeros64(vEnc) numBytes := (639 - lead0*9) >> 6 // It can be 1 or 0 is v ==0 if numBytes <= 1 { return []byte{byte(vEnc)} } extraBytes := numBytes - 1 var buf = make([]byte, numBytes) for i := extraBytes; i >= 0; i-- { buf[i] = byte(vEnc) vEnc >>= 8 } buf[0] |= byte(^(0xff >> uint(extraBytes))) return buf } func writeCollectionSize(info CollectionType, n int, buf *bytes.Buffer) error { if info.proto > protoVersion2 { if n > math.MaxInt32 { return marshalErrorf("marshal: collection too large") } buf.WriteByte(byte(n >> 24)) buf.WriteByte(byte(n >> 16)) buf.WriteByte(byte(n >> 8)) buf.WriteByte(byte(n)) } else { if n > math.MaxUint16 { return marshalErrorf("marshal: collection too large") } buf.WriteByte(byte(n >> 8)) buf.WriteByte(byte(n)) } return nil } func marshalList(info TypeInfo, value interface{}) ([]byte, error) { listInfo, ok := info.(CollectionType) if !ok { return nil, marshalErrorf("marshal: can not marshal non collection type into list") } if value == nil { return nil, nil } else if _, ok := value.(unsetColumn); ok { return nil, nil } rv := reflect.ValueOf(value) t := rv.Type() k := t.Kind() if k == reflect.Slice && rv.IsNil() { return nil, nil } switch k { case reflect.Slice, reflect.Array: buf := &bytes.Buffer{} n := rv.Len() if err := writeCollectionSize(listInfo, n, buf); err != nil { return nil, err } for i := 0; i < n; i++ { item, err := Marshal(listInfo.Elem, rv.Index(i).Interface()) if err != nil { return nil, err } itemLen := len(item) // Set the value to null for supported protocols if item == nil && listInfo.proto > protoVersion2 { itemLen = -1 } if err := writeCollectionSize(listInfo, itemLen, buf); err != nil { return nil, err } buf.Write(item) } return buf.Bytes(), nil case reflect.Map: elem := t.Elem() if elem.Kind() == reflect.Struct && elem.NumField() == 0 { rkeys := rv.MapKeys() keys := make([]interface{}, len(rkeys)) for i := 0; i < len(keys); i++ { keys[i] = 
rkeys[i].Interface() } return marshalList(listInfo, keys) } } return nil, marshalErrorf("can not marshal %T into %s", value, info) } func readCollectionSize(info CollectionType, data []byte) (size, read int, err error) { if info.proto > protoVersion2 { if len(data) < 4 { return 0, 0, unmarshalErrorf("unmarshal list: unexpected eof") } size = int(int32(data[0])<<24 | int32(data[1])<<16 | int32(data[2])<<8 | int32(data[3])) read = 4 } else { if len(data) < 2 { return 0, 0, unmarshalErrorf("unmarshal list: unexpected eof") } size = int(data[0])<<8 | int(data[1]) read = 2 } return } func unmarshalList(info TypeInfo, data []byte, value interface{}) error { listInfo, ok := info.(CollectionType) if !ok { return unmarshalErrorf("unmarshal: can not unmarshal none collection type into list") } rv := reflect.ValueOf(value) if rv.Kind() != reflect.Ptr { return unmarshalErrorf("can not unmarshal into non-pointer %T", value) } rv = rv.Elem() t := rv.Type() k := t.Kind() switch k { case reflect.Slice, reflect.Array: if data == nil { if k == reflect.Array { return unmarshalErrorf("unmarshal list: can not store nil in array value") } if rv.IsNil() { return nil } rv.Set(reflect.Zero(t)) return nil } n, p, err := readCollectionSize(listInfo, data) if err != nil { return err } data = data[p:] if k == reflect.Array { if rv.Len() != n { return unmarshalErrorf("unmarshal list: array with wrong size") } } else { rv.Set(reflect.MakeSlice(t, n, n)) } for i := 0; i < n; i++ { m, p, err := readCollectionSize(listInfo, data) if err != nil { return err } data = data[p:] // In case m < 0, the value is null, and unmarshalData should be nil. 
var unmarshalData []byte if m >= 0 { if len(data) < m { return unmarshalErrorf("unmarshal list: unexpected eof") } unmarshalData = data[:m] data = data[m:] } if err := Unmarshal(listInfo.Elem, unmarshalData, rv.Index(i).Addr().Interface()); err != nil { return err } } return nil } return unmarshalErrorf("can not unmarshal %s into %T", info, value) } func marshalMap(info TypeInfo, value interface{}) ([]byte, error) { mapInfo, ok := info.(CollectionType) if !ok { return nil, marshalErrorf("marshal: can not marshal none collection type into map") } if value == nil { return nil, nil } else if _, ok := value.(unsetColumn); ok { return nil, nil } rv := reflect.ValueOf(value) t := rv.Type() if t.Kind() != reflect.Map { return nil, marshalErrorf("can not marshal %T into %s", value, info) } if rv.IsNil() { return nil, nil } buf := &bytes.Buffer{} n := rv.Len() if err := writeCollectionSize(mapInfo, n, buf); err != nil { return nil, err } keys := rv.MapKeys() for _, key := range keys { item, err := Marshal(mapInfo.Key, key.Interface()) if err != nil { return nil, err } itemLen := len(item) // Set the key to null for supported protocols if item == nil && mapInfo.proto > protoVersion2 { itemLen = -1 } if err := writeCollectionSize(mapInfo, itemLen, buf); err != nil { return nil, err } buf.Write(item) item, err = Marshal(mapInfo.Elem, rv.MapIndex(key).Interface()) if err != nil { return nil, err } itemLen = len(item) // Set the value to null for supported protocols if item == nil && mapInfo.proto > protoVersion2 { itemLen = -1 } if err := writeCollectionSize(mapInfo, itemLen, buf); err != nil { return nil, err } buf.Write(item) } return buf.Bytes(), nil } func unmarshalMap(info TypeInfo, data []byte, value interface{}) error { mapInfo, ok := info.(CollectionType) if !ok { return unmarshalErrorf("unmarshal: can not unmarshal none collection type into map") } rv := reflect.ValueOf(value) if rv.Kind() != reflect.Ptr { return unmarshalErrorf("can not unmarshal into non-pointer %T", 
value) } rv = rv.Elem() t := rv.Type() if t.Kind() != reflect.Map { return unmarshalErrorf("can not unmarshal %s into %T", info, value) } if data == nil { rv.Set(reflect.Zero(t)) return nil } n, p, err := readCollectionSize(mapInfo, data) if err != nil { return err } if n < 0 { return unmarshalErrorf("negative map size %d", n) } rv.Set(reflect.MakeMapWithSize(t, n)) data = data[p:] for i := 0; i < n; i++ { m, p, err := readCollectionSize(mapInfo, data) if err != nil { return err } data = data[p:] key := reflect.New(t.Key()) // In case m < 0, the key is null, and unmarshalData should be nil. var unmarshalData []byte if m >= 0 { if len(data) < m { return unmarshalErrorf("unmarshal map: unexpected eof") } unmarshalData = data[:m] data = data[m:] } if err := Unmarshal(mapInfo.Key, unmarshalData, key.Interface()); err != nil { return err } m, p, err = readCollectionSize(mapInfo, data) if err != nil { return err } data = data[p:] val := reflect.New(t.Elem()) // In case m < 0, the value is null, and unmarshalData should be nil. 
unmarshalData = nil if m >= 0 { if len(data) < m { return unmarshalErrorf("unmarshal map: unexpected eof") } unmarshalData = data[:m] data = data[m:] } if err := Unmarshal(mapInfo.Elem, unmarshalData, val.Interface()); err != nil { return err } rv.SetMapIndex(key.Elem(), val.Elem()) } return nil } func marshalUUID(info TypeInfo, value interface{}) ([]byte, error) { switch val := value.(type) { case unsetColumn: return nil, nil case UUID: return val.Bytes(), nil case [16]byte: return val[:], nil case []byte: if len(val) != 16 { return nil, marshalErrorf("can not marshal []byte %d bytes long into %s, must be exactly 16 bytes long", len(val), info) } return val, nil case string: b, err := ParseUUID(val) if err != nil { return nil, err } return b[:], nil } if value == nil { return nil, nil } return nil, marshalErrorf("can not marshal %T into %s", value, info) } func unmarshalUUID(info TypeInfo, data []byte, value interface{}) error { if len(data) == 0 { switch v := value.(type) { case *string: *v = "" case *[]byte: *v = nil case *UUID: *v = UUID{} default: return unmarshalErrorf("can not unmarshal X %s into %T", info, value) } return nil } if len(data) != 16 { return unmarshalErrorf("unable to parse UUID: UUIDs must be exactly 16 bytes long") } switch v := value.(type) { case *[16]byte: copy((*v)[:], data) return nil case *UUID: copy((*v)[:], data) return nil } u, err := UUIDFromBytes(data) if err != nil { return unmarshalErrorf("unable to parse UUID: %s", err) } switch v := value.(type) { case *string: *v = u.String() return nil case *[]byte: *v = u[:] return nil } return unmarshalErrorf("can not unmarshal X %s into %T", info, value) } func unmarshalTimeUUID(info TypeInfo, data []byte, value interface{}) error { switch v := value.(type) { case Unmarshaler: return v.UnmarshalCQL(info, data) case *time.Time: id, err := UUIDFromBytes(data) if err != nil { return err } else if id.Version() != 1 { return unmarshalErrorf("invalid timeuuid") } *v = id.Time() return nil 
default: return unmarshalUUID(info, data, value) } } func marshalInet(info TypeInfo, value interface{}) ([]byte, error) { // we return either the 4 or 16 byte representation of an // ip address here otherwise the db value will be prefixed // with the remaining byte values e.g. ::ffff:127.0.0.1 and not 127.0.0.1 switch val := value.(type) { case unsetColumn: return nil, nil case net.IP: t := val.To4() if t == nil { return val.To16(), nil } return t, nil case string: b := net.ParseIP(val) if b != nil { t := b.To4() if t == nil { return b.To16(), nil } return t, nil } return nil, marshalErrorf("cannot marshal. invalid ip string %s", val) } if value == nil { return nil, nil } return nil, marshalErrorf("cannot marshal %T into %s", value, info) } func unmarshalInet(info TypeInfo, data []byte, value interface{}) error { switch v := value.(type) { case Unmarshaler: return v.UnmarshalCQL(info, data) case *net.IP: if x := len(data); !(x == 4 || x == 16) { return unmarshalErrorf("cannot unmarshal %s into %T: invalid sized IP: got %d bytes not 4 or 16", info, value, x) } buf := copyBytes(data) ip := net.IP(buf) if v4 := ip.To4(); v4 != nil { *v = v4 return nil } *v = ip return nil case *string: if len(data) == 0 { *v = "" return nil } ip := net.IP(data) if v4 := ip.To4(); v4 != nil { *v = v4.String() return nil } *v = ip.String() return nil } return unmarshalErrorf("cannot unmarshal %s into %T", info, value) } func marshalTuple(info TypeInfo, value interface{}) ([]byte, error) { tuple := info.(TupleTypeInfo) switch v := value.(type) { case unsetColumn: return nil, unmarshalErrorf("Invalid request: UnsetValue is unsupported for tuples") case []interface{}: if len(v) != len(tuple.Elems) { return nil, unmarshalErrorf("cannont marshal tuple: wrong number of elements") } var buf []byte for i, elem := range v { if elem == nil { buf = appendInt(buf, int32(-1)) continue } data, err := Marshal(tuple.Elems[i], elem) if err != nil { return nil, err } n := len(data) buf = appendInt(buf, 
int32(n)) buf = append(buf, data...) } return buf, nil } rv := reflect.ValueOf(value) t := rv.Type() k := t.Kind() switch k { case reflect.Struct: if v := t.NumField(); v != len(tuple.Elems) { return nil, marshalErrorf("can not marshal tuple into struct %v, not enough fields have %d need %d", t, v, len(tuple.Elems)) } var buf []byte for i, elem := range tuple.Elems { field := rv.Field(i) if field.Kind() == reflect.Ptr && field.IsNil() { buf = appendInt(buf, int32(-1)) continue } data, err := Marshal(elem, field.Interface()) if err != nil { return nil, err } n := len(data) buf = appendInt(buf, int32(n)) buf = append(buf, data...) } return buf, nil case reflect.Slice, reflect.Array: size := rv.Len() if size != len(tuple.Elems) { return nil, marshalErrorf("can not marshal tuple into %v of length %d need %d elements", k, size, len(tuple.Elems)) } var buf []byte for i, elem := range tuple.Elems { item := rv.Index(i) if item.Kind() == reflect.Ptr && item.IsNil() { buf = appendInt(buf, int32(-1)) continue } data, err := Marshal(elem, item.Interface()) if err != nil { return nil, err } n := len(data) buf = appendInt(buf, int32(n)) buf = append(buf, data...) } return buf, nil } return nil, marshalErrorf("cannot marshal %T into %s", value, tuple) } func readBytes(p []byte) ([]byte, []byte) { // TODO: really should use a framer size := readInt(p) p = p[4:] if size < 0 { return nil, p } return p[:size], p[size:] } // currently only support unmarshal into a list of values, this makes it possible // to support tuples without changing the query API. In the future this can be extend // to allow unmarshalling into custom tuple types. 
func unmarshalTuple(info TypeInfo, data []byte, value interface{}) error { if v, ok := value.(Unmarshaler); ok { return v.UnmarshalCQL(info, data) } tuple := info.(TupleTypeInfo) switch v := value.(type) { case []interface{}: for i, elem := range tuple.Elems { // each element inside data is a [bytes] var p []byte if len(data) >= 4 { p, data = readBytes(data) } err := Unmarshal(elem, p, v[i]) if err != nil { return err } } return nil } rv := reflect.ValueOf(value) if rv.Kind() != reflect.Ptr { return unmarshalErrorf("can not unmarshal into non-pointer %T", value) } rv = rv.Elem() t := rv.Type() k := t.Kind() switch k { case reflect.Struct: if v := t.NumField(); v != len(tuple.Elems) { return unmarshalErrorf("can not unmarshal tuple into struct %v, not enough fields have %d need %d", t, v, len(tuple.Elems)) } for i, elem := range tuple.Elems { var p []byte if len(data) >= 4 { p, data = readBytes(data) } v, err := elem.NewWithError() if err != nil { return err } if err := Unmarshal(elem, p, v); err != nil { return err } switch rv.Field(i).Kind() { case reflect.Ptr: if p != nil { rv.Field(i).Set(reflect.ValueOf(v)) } else { rv.Field(i).Set(reflect.Zero(reflect.TypeOf(v))) } default: rv.Field(i).Set(reflect.ValueOf(v).Elem()) } } return nil case reflect.Slice, reflect.Array: if k == reflect.Array { size := rv.Len() if size != len(tuple.Elems) { return unmarshalErrorf("can not unmarshal tuple into array of length %d need %d elements", size, len(tuple.Elems)) } } else { rv.Set(reflect.MakeSlice(t, len(tuple.Elems), len(tuple.Elems))) } for i, elem := range tuple.Elems { var p []byte if len(data) >= 4 { p, data = readBytes(data) } v, err := elem.NewWithError() if err != nil { return err } if err := Unmarshal(elem, p, v); err != nil { return err } switch rv.Index(i).Kind() { case reflect.Ptr: if p != nil { rv.Index(i).Set(reflect.ValueOf(v)) } else { rv.Index(i).Set(reflect.Zero(reflect.TypeOf(v))) } default: rv.Index(i).Set(reflect.ValueOf(v).Elem()) } } return nil } 
return unmarshalErrorf("cannot unmarshal %s into %T", info, value) } // UDTMarshaler is an interface which should be implemented by users wishing to // handle encoding UDT types to sent to Cassandra. Note: due to current implentations // methods defined for this interface must be value receivers not pointer receivers. type UDTMarshaler interface { // MarshalUDT will be called for each field in the the UDT returned by Cassandra, // the implementor should marshal the type to return by for example calling // Marshal. MarshalUDT(name string, info TypeInfo) ([]byte, error) } // UDTUnmarshaler should be implemented by users wanting to implement custom // UDT unmarshaling. type UDTUnmarshaler interface { // UnmarshalUDT will be called for each field in the UDT return by Cassandra, // the implementor should unmarshal the data into the value of their chosing, // for example by calling Unmarshal. UnmarshalUDT(name string, info TypeInfo, data []byte) error } func marshalUDT(info TypeInfo, value interface{}) ([]byte, error) { udt := info.(UDTTypeInfo) switch v := value.(type) { case Marshaler: return v.MarshalCQL(info) case unsetColumn: return nil, unmarshalErrorf("invalid request: UnsetValue is unsupported for user defined types") case UDTMarshaler: var buf []byte for _, e := range udt.Elements { data, err := v.MarshalUDT(e.Name, e.Type) if err != nil { return nil, err } buf = appendBytes(buf, data) } return buf, nil case map[string]interface{}: var buf []byte for _, e := range udt.Elements { val, ok := v[e.Name] var data []byte if ok { var err error data, err = Marshal(e.Type, val) if err != nil { return nil, err } } buf = appendBytes(buf, data) } return buf, nil } k := reflect.ValueOf(value) if k.Kind() == reflect.Ptr { if k.IsNil() { return nil, marshalErrorf("cannot marshal %T into %s", value, info) } k = k.Elem() } if k.Kind() != reflect.Struct || !k.IsValid() { return nil, marshalErrorf("cannot marshal %T into %s", value, info) } fields := make(map[string]reflect.Value) 
t := reflect.TypeOf(value) for i := 0; i < t.NumField(); i++ { sf := t.Field(i) if tag := sf.Tag.Get("cql"); tag != "" { fields[tag] = k.Field(i) } } var buf []byte for _, e := range udt.Elements { f, ok := fields[e.Name] if !ok { f = k.FieldByName(e.Name) } var data []byte if f.IsValid() && f.CanInterface() { var err error data, err = Marshal(e.Type, f.Interface()) if err != nil { return nil, err } } buf = appendBytes(buf, data) } return buf, nil } func unmarshalUDT(info TypeInfo, data []byte, value interface{}) error { switch v := value.(type) { case Unmarshaler: return v.UnmarshalCQL(info, data) case UDTUnmarshaler: udt := info.(UDTTypeInfo) for id, e := range udt.Elements { if len(data) == 0 { return nil } if len(data) < 4 { return unmarshalErrorf("can not unmarshal %s: field [%d]%s: unexpected eof", info, id, e.Name) } var p []byte p, data = readBytes(data) if err := v.UnmarshalUDT(e.Name, e.Type, p); err != nil { return err } } return nil case *map[string]interface{}: udt := info.(UDTTypeInfo) rv := reflect.ValueOf(value) if rv.Kind() != reflect.Ptr { return unmarshalErrorf("can not unmarshal into non-pointer %T", value) } rv = rv.Elem() t := rv.Type() if t.Kind() != reflect.Map { return unmarshalErrorf("can not unmarshal %s into %T", info, value) } else if data == nil { rv.Set(reflect.Zero(t)) return nil } rv.Set(reflect.MakeMap(t)) m := *v for id, e := range udt.Elements { if len(data) == 0 { return nil } if len(data) < 4 { return unmarshalErrorf("can not unmarshal %s: field [%d]%s: unexpected eof", info, id, e.Name) } valType, err := goType(e.Type) if err != nil { return unmarshalErrorf("can not unmarshal %s: %v", info, err) } val := reflect.New(valType) var p []byte p, data = readBytes(data) if err := Unmarshal(e.Type, p, val.Interface()); err != nil { return err } m[e.Name] = val.Elem().Interface() } return nil } rv := reflect.ValueOf(value) if rv.Kind() != reflect.Ptr { return unmarshalErrorf("can not unmarshal into non-pointer %T", value) } k := 
rv.Elem() if k.Kind() != reflect.Struct || !k.IsValid() { return unmarshalErrorf("cannot unmarshal %s into %T", info, value) } if len(data) == 0 { if k.CanSet() { k.Set(reflect.Zero(k.Type())) } return nil } t := k.Type() fields := make(map[string]reflect.Value, t.NumField()) for i := 0; i < t.NumField(); i++ { sf := t.Field(i) if tag := sf.Tag.Get("cql"); tag != "" { fields[tag] = k.Field(i) } } udt := info.(UDTTypeInfo) for id, e := range udt.Elements { if len(data) == 0 { return nil } if len(data) < 4 { // UDT def does not match the column value return unmarshalErrorf("can not unmarshal %s: field [%d]%s: unexpected eof", info, id, e.Name) } var p []byte p, data = readBytes(data) f, ok := fields[e.Name] if !ok { f = k.FieldByName(e.Name) if f == emptyValue { // skip fields which exist in the UDT but not in // the struct passed in continue } } if !f.IsValid() || !f.CanAddr() { return unmarshalErrorf("cannot unmarshal %s into %T: field %v is not valid", info, value, e.Name) } fk := f.Addr().Interface() if err := Unmarshal(e.Type, p, fk); err != nil { return err } } return nil } // TypeInfo describes a Cassandra specific data type. type TypeInfo interface { Type() Type Version() byte Custom() string // New creates a pointer to an empty version of whatever type // is referenced by the TypeInfo receiver. // // If there is no corresponding Go type for the CQL type, New panics. // // Deprecated: Use NewWithError instead. New() interface{} // NewWithError creates a pointer to an empty version of whatever type // is referenced by the TypeInfo receiver. // // If there is no corresponding Go type for the CQL type, NewWithError returns an error. 
NewWithError() (interface{}, error) } type NativeType struct { proto byte typ Type custom string // only used for TypeCustom } func NewNativeType(proto byte, typ Type, custom string) NativeType { return NativeType{proto, typ, custom} } func (t NativeType) NewWithError() (interface{}, error) { typ, err := goType(t) if err != nil { return nil, err } return reflect.New(typ).Interface(), nil } func (t NativeType) New() interface{} { val, err := t.NewWithError() if err != nil { panic(err.Error()) } return val } func (s NativeType) Type() Type { return s.typ } func (s NativeType) Version() byte { return s.proto } func (s NativeType) Custom() string { return s.custom } func (s NativeType) String() string { switch s.typ { case TypeCustom: return fmt.Sprintf("%s(%s)", s.typ, s.custom) default: return s.typ.String() } } type CollectionType struct { NativeType Key TypeInfo // only used for TypeMap Elem TypeInfo // only used for TypeMap, TypeList and TypeSet } func (t CollectionType) NewWithError() (interface{}, error) { typ, err := goType(t) if err != nil { return nil, err } return reflect.New(typ).Interface(), nil } func (t CollectionType) New() interface{} { val, err := t.NewWithError() if err != nil { panic(err.Error()) } return val } func (c CollectionType) String() string { switch c.typ { case TypeMap: return fmt.Sprintf("%s(%s, %s)", c.typ, c.Key, c.Elem) case TypeList, TypeSet: return fmt.Sprintf("%s(%s)", c.typ, c.Elem) case TypeCustom: return fmt.Sprintf("%s(%s)", c.typ, c.custom) default: return c.typ.String() } } type TupleTypeInfo struct { NativeType Elems []TypeInfo } func (t TupleTypeInfo) String() string { var buf bytes.Buffer buf.WriteString(fmt.Sprintf("%s(", t.typ)) for _, elem := range t.Elems { buf.WriteString(fmt.Sprintf("%s, ", elem)) } buf.Truncate(buf.Len() - 2) buf.WriteByte(')') return buf.String() } func (t TupleTypeInfo) NewWithError() (interface{}, error) { typ, err := goType(t) if err != nil { return nil, err } return 
reflect.New(typ).Interface(), nil } func (t TupleTypeInfo) New() interface{} { val, err := t.NewWithError() if err != nil { panic(err.Error()) } return val } type UDTField struct { Name string Type TypeInfo } type UDTTypeInfo struct { NativeType KeySpace string Name string Elements []UDTField } func (u UDTTypeInfo) NewWithError() (interface{}, error) { typ, err := goType(u) if err != nil { return nil, err } return reflect.New(typ).Interface(), nil } func (u UDTTypeInfo) New() interface{} { val, err := u.NewWithError() if err != nil { panic(err.Error()) } return val } func (u UDTTypeInfo) String() string { buf := &bytes.Buffer{} fmt.Fprintf(buf, "%s.%s{", u.KeySpace, u.Name) first := true for _, e := range u.Elements { if !first { fmt.Fprint(buf, ",") } else { first = false } fmt.Fprintf(buf, "%s=%v", e.Name, e.Type) } fmt.Fprint(buf, "}") return buf.String() } // String returns a human readable name for the Cassandra datatype // described by t. // Type is the identifier of a Cassandra internal datatype. type Type int const ( TypeCustom Type = 0x0000 TypeAscii Type = 0x0001 TypeBigInt Type = 0x0002 TypeBlob Type = 0x0003 TypeBoolean Type = 0x0004 TypeCounter Type = 0x0005 TypeDecimal Type = 0x0006 TypeDouble Type = 0x0007 TypeFloat Type = 0x0008 TypeInt Type = 0x0009 TypeText Type = 0x000A TypeTimestamp Type = 0x000B TypeUUID Type = 0x000C TypeVarchar Type = 0x000D TypeVarint Type = 0x000E TypeTimeUUID Type = 0x000F TypeInet Type = 0x0010 TypeDate Type = 0x0011 TypeTime Type = 0x0012 TypeSmallInt Type = 0x0013 TypeTinyInt Type = 0x0014 TypeDuration Type = 0x0015 TypeList Type = 0x0020 TypeMap Type = 0x0021 TypeSet Type = 0x0022 TypeUDT Type = 0x0030 TypeTuple Type = 0x0031 ) // String returns the name of the identifier. 
func (t Type) String() string { switch t { case TypeCustom: return "custom" case TypeAscii: return "ascii" case TypeBigInt: return "bigint" case TypeBlob: return "blob" case TypeBoolean: return "boolean" case TypeCounter: return "counter" case TypeDecimal: return "decimal" case TypeDouble: return "double" case TypeFloat: return "float" case TypeInt: return "int" case TypeText: return "text" case TypeTimestamp: return "timestamp" case TypeUUID: return "uuid" case TypeVarchar: return "varchar" case TypeTimeUUID: return "timeuuid" case TypeInet: return "inet" case TypeDate: return "date" case TypeDuration: return "duration" case TypeTime: return "time" case TypeSmallInt: return "smallint" case TypeTinyInt: return "tinyint" case TypeList: return "list" case TypeMap: return "map" case TypeSet: return "set" case TypeVarint: return "varint" case TypeTuple: return "tuple" default: return fmt.Sprintf("unknown_type_%d", t) } } type MarshalError string func (m MarshalError) Error() string { return string(m) } func marshalErrorf(format string, args ...interface{}) MarshalError { return MarshalError(fmt.Sprintf(format, args...)) } type UnmarshalError string func (m UnmarshalError) Error() string { return string(m) } func unmarshalErrorf(format string, args ...interface{}) UnmarshalError { return UnmarshalError(fmt.Sprintf(format, args...)) } cassandra-gocql-driver-1.7.0/marshal_test.go000066400000000000000000001642561467504044300211140ustar00rootroot00000000000000//go:build all || unit // +build all unit /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. */ package gocql import ( "bytes" "encoding/binary" "math" "math/big" "net" "reflect" "strings" "testing" "time" "gopkg.in/inf.v0" ) type AliasInt int type AliasUint uint type AliasUint8 uint8 type AliasUint16 uint16 type AliasUint32 uint32 type AliasUint64 uint64 var marshalTests = []struct { Info TypeInfo Data []byte Value interface{} MarshalError error UnmarshalError error }{ { NativeType{proto: 2, typ: TypeVarchar}, []byte("hello world"), []byte("hello world"), nil, nil, }, { NativeType{proto: 2, typ: TypeVarchar}, []byte("hello world"), "hello world", nil, nil, }, { NativeType{proto: 2, typ: TypeVarchar}, []byte(nil), []byte(nil), nil, nil, }, { NativeType{proto: 2, typ: TypeVarchar}, []byte("hello world"), MyString("hello world"), nil, nil, }, { NativeType{proto: 2, typ: TypeVarchar}, []byte("HELLO WORLD"), CustomString("hello world"), nil, nil, }, { NativeType{proto: 2, typ: TypeBlob}, []byte("hello\x00"), []byte("hello\x00"), nil, nil, }, { NativeType{proto: 2, typ: TypeBlob}, []byte(nil), []byte(nil), nil, nil, }, { NativeType{proto: 2, typ: TypeTimeUUID}, []byte{0x3d, 0xcd, 0x98, 0x0, 0xf3, 0xd9, 0x11, 0xbf, 0x86, 0xd4, 0xb8, 0xe8, 0x56, 0x2c, 0xc, 0xd0}, func() UUID { x, _ := UUIDFromBytes([]byte{0x3d, 0xcd, 0x98, 0x0, 0xf3, 0xd9, 0x11, 0xbf, 0x86, 0xd4, 0xb8, 0xe8, 0x56, 0x2c, 0xc, 0xd0}) return x }(), nil, nil, }, { 
NativeType{proto: 2, typ: TypeTimeUUID}, []byte{0xb8, 0xe8, 0x56, 0x2c, 0xc, 0xd0}, []byte{0xb8, 0xe8, 0x56, 0x2c, 0xc, 0xd0}, MarshalError("can not marshal []byte 6 bytes long into timeuuid, must be exactly 16 bytes long"), UnmarshalError("unable to parse UUID: UUIDs must be exactly 16 bytes long"), }, { NativeType{proto: 2, typ: TypeTimeUUID}, []byte{0x3d, 0xcd, 0x98, 0x0, 0xf3, 0xd9, 0x11, 0xbf, 0x86, 0xd4, 0xb8, 0xe8, 0x56, 0x2c, 0xc, 0xd0}, [16]byte{0x3d, 0xcd, 0x98, 0x0, 0xf3, 0xd9, 0x11, 0xbf, 0x86, 0xd4, 0xb8, 0xe8, 0x56, 0x2c, 0xc, 0xd0}, nil, nil, }, { NativeType{proto: 2, typ: TypeInt}, []byte("\x00\x00\x00\x00"), 0, nil, nil, }, { NativeType{proto: 2, typ: TypeInt}, []byte("\x01\x02\x03\x04"), int(16909060), nil, nil, }, { NativeType{proto: 2, typ: TypeInt}, []byte("\x01\x02\x03\x04"), AliasInt(16909060), nil, nil, }, { NativeType{proto: 2, typ: TypeInt}, []byte("\x80\x00\x00\x00"), int32(math.MinInt32), nil, nil, }, { NativeType{proto: 2, typ: TypeInt}, []byte("\x7f\xff\xff\xff"), int32(math.MaxInt32), nil, nil, }, { NativeType{proto: 2, typ: TypeInt}, []byte("\x00\x00\x00\x00"), "0", nil, nil, }, { NativeType{proto: 2, typ: TypeInt}, []byte("\x01\x02\x03\x04"), "16909060", nil, nil, }, { NativeType{proto: 2, typ: TypeInt}, []byte("\x80\x00\x00\x00"), "-2147483648", // math.MinInt32 nil, nil, }, { NativeType{proto: 2, typ: TypeInt}, []byte("\x7f\xff\xff\xff"), "2147483647", // math.MaxInt32 nil, nil, }, { NativeType{proto: 2, typ: TypeBigInt}, []byte("\x00\x00\x00\x00\x00\x00\x00\x00"), 0, nil, nil, }, { NativeType{proto: 2, typ: TypeBigInt}, []byte("\x01\x02\x03\x04\x05\x06\x07\x08"), 72623859790382856, nil, nil, }, { NativeType{proto: 2, typ: TypeBigInt}, []byte("\x80\x00\x00\x00\x00\x00\x00\x00"), int64(math.MinInt64), nil, nil, }, { NativeType{proto: 2, typ: TypeBigInt}, []byte("\x7f\xff\xff\xff\xff\xff\xff\xff"), int64(math.MaxInt64), nil, nil, }, { NativeType{proto: 2, typ: TypeBigInt}, []byte("\x00\x00\x00\x00\x00\x00\x00\x00"), "0", nil, nil, 
}, { NativeType{proto: 2, typ: TypeBigInt}, []byte("\x01\x02\x03\x04\x05\x06\x07\x08"), "72623859790382856", nil, nil, }, { NativeType{proto: 2, typ: TypeBigInt}, []byte("\x80\x00\x00\x00\x00\x00\x00\x00"), "-9223372036854775808", // math.MinInt64 nil, nil, }, { NativeType{proto: 2, typ: TypeBigInt}, []byte("\x7f\xff\xff\xff\xff\xff\xff\xff"), "9223372036854775807", // math.MaxInt64 nil, nil, }, { NativeType{proto: 2, typ: TypeBoolean}, []byte("\x00"), false, nil, nil, }, { NativeType{proto: 2, typ: TypeBoolean}, []byte("\x01"), true, nil, nil, }, { NativeType{proto: 2, typ: TypeFloat}, []byte("\x40\x49\x0f\xdb"), float32(3.14159265), nil, nil, }, { NativeType{proto: 2, typ: TypeDouble}, []byte("\x40\x09\x21\xfb\x53\xc8\xd4\xf1"), float64(3.14159265), nil, nil, }, { NativeType{proto: 2, typ: TypeDecimal}, []byte("\x00\x00\x00\x00\x00"), inf.NewDec(0, 0), nil, nil, }, { NativeType{proto: 2, typ: TypeDecimal}, []byte("\x00\x00\x00\x00\x64"), inf.NewDec(100, 0), nil, nil, }, { NativeType{proto: 2, typ: TypeDecimal}, []byte("\x00\x00\x00\x02\x19"), decimalize("0.25"), nil, nil, }, { NativeType{proto: 2, typ: TypeDecimal}, []byte("\x00\x00\x00\x13\xD5\a;\x20\x14\xA2\x91"), decimalize("-0.0012095473475870063"), // From the iconara/cql-rb test suite nil, nil, }, { NativeType{proto: 2, typ: TypeDecimal}, []byte("\x00\x00\x00\x13*\xF8\xC4\xDF\xEB]o"), decimalize("0.0012095473475870063"), // From the iconara/cql-rb test suite nil, nil, }, { NativeType{proto: 2, typ: TypeDecimal}, []byte("\x00\x00\x00\x12\xF2\xD8\x02\xB6R\x7F\x99\xEE\x98#\x99\xA9V"), decimalize("-1042342234234.123423435647768234"), // From the iconara/cql-rb test suite nil, nil, }, { NativeType{proto: 2, typ: TypeDecimal}, []byte("\x00\x00\x00\r\nJ\x04\"^\x91\x04\x8a\xb1\x18\xfe"), decimalize("1243878957943.1234124191998"), // From the datastax/python-driver test suite nil, nil, }, { NativeType{proto: 2, typ: TypeDecimal}, []byte("\x00\x00\x00\x06\xe5\xde]\x98Y"), decimalize("-112233.441191"), // From the 
datastax/python-driver test suite nil, nil, }, { NativeType{proto: 2, typ: TypeDecimal}, []byte("\x00\x00\x00\x14\x00\xfa\xce"), decimalize("0.00000000000000064206"), // From the datastax/python-driver test suite nil, nil, }, { NativeType{proto: 2, typ: TypeDecimal}, []byte("\x00\x00\x00\x14\xff\x052"), decimalize("-0.00000000000000064206"), // From the datastax/python-driver test suite nil, nil, }, { NativeType{proto: 2, typ: TypeDecimal}, []byte("\xff\xff\xff\x9c\x00\xfa\xce"), inf.NewDec(64206, -100), // From the datastax/python-driver test suite nil, nil, }, { NativeType{proto: 4, typ: TypeTime}, []byte("\x00\x00\x01\x40\x77\x16\xe1\xb8"), time.Duration(int64(1376387523000)), nil, nil, }, { NativeType{proto: 4, typ: TypeTime}, []byte("\x00\x00\x01\x40\x77\x16\xe1\xb8"), int64(1376387523000), nil, nil, }, { NativeType{proto: 2, typ: TypeTimestamp}, []byte("\x00\x00\x01\x40\x77\x16\xe1\xb8"), time.Date(2013, time.August, 13, 9, 52, 3, 0, time.UTC), nil, nil, }, { NativeType{proto: 2, typ: TypeTimestamp}, []byte("\x00\x00\x01\x40\x77\x16\xe1\xb8"), int64(1376387523000), nil, nil, }, { NativeType{proto: 5, typ: TypeDuration}, []byte("\x89\xa2\xc3\xc2\x9a\xe0F\x91\x06"), Duration{Months: 1233, Days: 123213, Nanoseconds: 2312323}, nil, nil, }, { NativeType{proto: 5, typ: TypeDuration}, []byte("\x89\xa1\xc3\xc2\x99\xe0F\x91\x05"), Duration{Months: -1233, Days: -123213, Nanoseconds: -2312323}, nil, nil, }, { NativeType{proto: 5, typ: TypeDuration}, []byte("\x02\x04\x80\xe6"), Duration{Months: 1, Days: 2, Nanoseconds: 115}, nil, nil, }, { CollectionType{ NativeType: NativeType{proto: 2, typ: TypeList}, Elem: NativeType{proto: 2, typ: TypeInt}, }, []byte("\x00\x02\x00\x04\x00\x00\x00\x01\x00\x04\x00\x00\x00\x02"), []int{1, 2}, nil, nil, }, { CollectionType{ NativeType: NativeType{proto: 2, typ: TypeList}, Elem: NativeType{proto: 2, typ: TypeInt}, }, []byte("\x00\x02\x00\x04\x00\x00\x00\x01\x00\x04\x00\x00\x00\x02"), [2]int{1, 2}, nil, nil, }, { CollectionType{ 
NativeType: NativeType{proto: 2, typ: TypeSet}, Elem: NativeType{proto: 2, typ: TypeInt}, }, []byte("\x00\x02\x00\x04\x00\x00\x00\x01\x00\x04\x00\x00\x00\x02"), []int{1, 2}, nil, nil, }, { CollectionType{ NativeType: NativeType{proto: 2, typ: TypeSet}, Elem: NativeType{proto: 2, typ: TypeInt}, }, []byte{0, 0}, // encoding of a list should always include the size of the collection []int{}, nil, nil, }, { CollectionType{ NativeType: NativeType{proto: 2, typ: TypeMap}, Key: NativeType{proto: 2, typ: TypeVarchar}, Elem: NativeType{proto: 2, typ: TypeInt}, }, []byte("\x00\x01\x00\x03foo\x00\x04\x00\x00\x00\x01"), map[string]int{"foo": 1}, nil, nil, }, { CollectionType{ NativeType: NativeType{proto: 2, typ: TypeMap}, Key: NativeType{proto: 2, typ: TypeVarchar}, Elem: NativeType{proto: 2, typ: TypeInt}, }, []byte{0, 0}, map[string]int{}, nil, nil, }, { CollectionType{ NativeType: NativeType{proto: 2, typ: TypeList}, Elem: NativeType{proto: 2, typ: TypeVarchar}, }, bytes.Join([][]byte{ []byte("\x00\x01\xFF\xFF"), bytes.Repeat([]byte("X"), math.MaxUint16)}, []byte("")), []string{strings.Repeat("X", math.MaxUint16)}, nil, nil, }, { CollectionType{ NativeType: NativeType{proto: 2, typ: TypeMap}, Key: NativeType{proto: 2, typ: TypeVarchar}, Elem: NativeType{proto: 2, typ: TypeVarchar}, }, bytes.Join([][]byte{ []byte("\x00\x01\xFF\xFF"), bytes.Repeat([]byte("X"), math.MaxUint16), []byte("\xFF\xFF"), bytes.Repeat([]byte("Y"), math.MaxUint16)}, []byte("")), map[string]string{ strings.Repeat("X", math.MaxUint16): strings.Repeat("Y", math.MaxUint16), }, nil, nil, }, { NativeType{proto: 2, typ: TypeVarint}, []byte("\x00"), 0, nil, nil, }, { NativeType{proto: 2, typ: TypeVarint}, []byte("\x37\xE2\x3C\xEC"), int32(937573612), nil, nil, }, { NativeType{proto: 2, typ: TypeVarint}, []byte("\x37\xE2\x3C\xEC"), big.NewInt(937573612), nil, nil, }, { NativeType{proto: 2, typ: TypeVarint}, []byte("\x03\x9EV \x15\f\x03\x9DK\x18\xCDI\\$?\a["), 
bigintize("1231312312331283012830129382342342412123"), // From the iconara/cql-rb test suite nil, nil, }, { NativeType{proto: 2, typ: TypeVarint}, []byte("\xC9v\x8D:\x86"), big.NewInt(-234234234234), // From the iconara/cql-rb test suite nil, nil, }, { NativeType{proto: 2, typ: TypeVarint}, []byte("f\x1e\xfd\xf2\xe3\xb1\x9f|\x04_\x15"), bigintize("123456789123456789123456789"), // From the datastax/python-driver test suite nil, nil, }, { NativeType{proto: 2, typ: TypeVarint}, []byte(nil), nil, nil, UnmarshalError("can not unmarshal into non-pointer "), }, { NativeType{proto: 2, typ: TypeInet}, []byte("\x7F\x00\x00\x01"), net.ParseIP("127.0.0.1").To4(), nil, nil, }, { NativeType{proto: 2, typ: TypeInet}, []byte("\xFF\xFF\xFF\xFF"), net.ParseIP("255.255.255.255").To4(), nil, nil, }, { NativeType{proto: 2, typ: TypeInet}, []byte("\x7F\x00\x00\x01"), "127.0.0.1", nil, nil, }, { NativeType{proto: 2, typ: TypeInet}, []byte("\xFF\xFF\xFF\xFF"), "255.255.255.255", nil, nil, }, { NativeType{proto: 2, typ: TypeInet}, []byte("\x21\xDA\x00\xd3\x00\x00\x2f\x3b\x02\xaa\x00\xff\xfe\x28\x9c\x5a"), "21da:d3:0:2f3b:2aa:ff:fe28:9c5a", nil, nil, }, { NativeType{proto: 2, typ: TypeInet}, []byte("\xfe\x80\x00\x00\x00\x00\x00\x00\x02\x02\xb3\xff\xfe\x1e\x83\x29"), "fe80::202:b3ff:fe1e:8329", nil, nil, }, { NativeType{proto: 2, typ: TypeInet}, []byte("\x21\xDA\x00\xd3\x00\x00\x2f\x3b\x02\xaa\x00\xff\xfe\x28\x9c\x5a"), net.ParseIP("21da:d3:0:2f3b:2aa:ff:fe28:9c5a"), nil, nil, }, { NativeType{proto: 2, typ: TypeInet}, []byte("\xfe\x80\x00\x00\x00\x00\x00\x00\x02\x02\xb3\xff\xfe\x1e\x83\x29"), net.ParseIP("fe80::202:b3ff:fe1e:8329"), nil, nil, }, { NativeType{proto: 2, typ: TypeInt}, []byte(nil), nil, nil, UnmarshalError("can not unmarshal into non-pointer "), }, { NativeType{proto: 2, typ: TypeVarchar}, []byte("nullable string"), func() *string { value := "nullable string" return &value }(), nil, nil, }, { NativeType{proto: 2, typ: TypeVarchar}, []byte(nil), (*string)(nil), nil, nil, }, { 
NativeType{proto: 2, typ: TypeInt}, []byte("\x7f\xff\xff\xff"), func() *int { var value int = math.MaxInt32 return &value }(), nil, nil, }, { NativeType{proto: 2, typ: TypeInt}, []byte(nil), (*int)(nil), nil, nil, }, { NativeType{proto: 2, typ: TypeTimeUUID}, []byte{0x3d, 0xcd, 0x98, 0x0, 0xf3, 0xd9, 0x11, 0xbf, 0x86, 0xd4, 0xb8, 0xe8, 0x56, 0x2c, 0xc, 0xd0}, &UUID{0x3d, 0xcd, 0x98, 0x0, 0xf3, 0xd9, 0x11, 0xbf, 0x86, 0xd4, 0xb8, 0xe8, 0x56, 0x2c, 0xc, 0xd0}, nil, nil, }, { NativeType{proto: 2, typ: TypeTimeUUID}, []byte(nil), (*UUID)(nil), nil, nil, }, { NativeType{proto: 2, typ: TypeTimestamp}, []byte("\x00\x00\x01\x40\x77\x16\xe1\xb8"), func() *time.Time { t := time.Date(2013, time.August, 13, 9, 52, 3, 0, time.UTC) return &t }(), nil, nil, }, { NativeType{proto: 2, typ: TypeTimestamp}, []byte(nil), (*time.Time)(nil), nil, nil, }, { NativeType{proto: 2, typ: TypeBoolean}, []byte("\x00"), func() *bool { b := false return &b }(), nil, nil, }, { NativeType{proto: 2, typ: TypeBoolean}, []byte("\x01"), func() *bool { b := true return &b }(), nil, nil, }, { NativeType{proto: 2, typ: TypeBoolean}, []byte(nil), (*bool)(nil), nil, nil, }, { NativeType{proto: 2, typ: TypeFloat}, []byte("\x40\x49\x0f\xdb"), func() *float32 { f := float32(3.14159265) return &f }(), nil, nil, }, { NativeType{proto: 2, typ: TypeFloat}, []byte(nil), (*float32)(nil), nil, nil, }, { NativeType{proto: 2, typ: TypeDouble}, []byte("\x40\x09\x21\xfb\x53\xc8\xd4\xf1"), func() *float64 { d := float64(3.14159265) return &d }(), nil, nil, }, { NativeType{proto: 2, typ: TypeDouble}, []byte(nil), (*float64)(nil), nil, nil, }, { NativeType{proto: 2, typ: TypeInet}, []byte("\x7F\x00\x00\x01"), func() *net.IP { ip := net.ParseIP("127.0.0.1").To4() return &ip }(), nil, nil, }, { NativeType{proto: 2, typ: TypeInet}, []byte(nil), (*net.IP)(nil), nil, nil, }, { CollectionType{ NativeType: NativeType{proto: 2, typ: TypeList}, Elem: NativeType{proto: 2, typ: TypeInt}, }, 
[]byte("\x00\x02\x00\x04\x00\x00\x00\x01\x00\x04\x00\x00\x00\x02"), func() *[]int { l := []int{1, 2} return &l }(), nil, nil, }, { CollectionType{ NativeType: NativeType{proto: 3, typ: TypeList}, Elem: NativeType{proto: 3, typ: TypeInt}, }, []byte("\x00\x00\x00\x02\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x00\x04\x00\x00\x00\x02"), func() *[]int { l := []int{1, 2} return &l }(), nil, nil, }, { CollectionType{ NativeType: NativeType{proto: 2, typ: TypeList}, Elem: NativeType{proto: 2, typ: TypeInt}, }, []byte(nil), (*[]int)(nil), nil, nil, }, { CollectionType{ NativeType: NativeType{proto: 2, typ: TypeMap}, Key: NativeType{proto: 2, typ: TypeVarchar}, Elem: NativeType{proto: 2, typ: TypeInt}, }, []byte("\x00\x01\x00\x03foo\x00\x04\x00\x00\x00\x01"), func() *map[string]int { m := map[string]int{"foo": 1} return &m }(), nil, nil, }, { CollectionType{ NativeType: NativeType{proto: 2, typ: TypeMap}, Key: NativeType{proto: 2, typ: TypeVarchar}, Elem: NativeType{proto: 2, typ: TypeInt}, }, []byte(nil), (*map[string]int)(nil), nil, nil, }, { NativeType{proto: 2, typ: TypeVarchar}, []byte("HELLO WORLD"), func() *CustomString { customString := CustomString("hello world") return &customString }(), nil, nil, }, { NativeType{proto: 2, typ: TypeVarchar}, []byte(nil), (*CustomString)(nil), nil, nil, }, { NativeType{proto: 2, typ: TypeSmallInt}, []byte("\x7f\xff"), 32767, // math.MaxInt16 nil, nil, }, { NativeType{proto: 2, typ: TypeSmallInt}, []byte("\x7f\xff"), "32767", // math.MaxInt16 nil, nil, }, { NativeType{proto: 2, typ: TypeSmallInt}, []byte("\x00\x01"), int16(1), nil, nil, }, { NativeType{proto: 2, typ: TypeSmallInt}, []byte("\xff\xff"), int16(-1), nil, nil, }, { NativeType{proto: 2, typ: TypeSmallInt}, []byte("\x00\xff"), uint8(255), nil, nil, }, { NativeType{proto: 2, typ: TypeSmallInt}, []byte("\xff\xff"), uint16(65535), nil, nil, }, { NativeType{proto: 2, typ: TypeSmallInt}, []byte("\xff\xff"), uint32(65535), nil, nil, }, { NativeType{proto: 2, typ: TypeSmallInt}, 
[]byte("\xff\xff"), uint64(65535), nil, nil, }, { NativeType{proto: 2, typ: TypeSmallInt}, []byte("\x00\xff"), AliasUint8(255), nil, nil, }, { NativeType{proto: 2, typ: TypeSmallInt}, []byte("\xff\xff"), AliasUint16(65535), nil, nil, }, { NativeType{proto: 2, typ: TypeSmallInt}, []byte("\xff\xff"), AliasUint32(65535), nil, nil, }, { NativeType{proto: 2, typ: TypeSmallInt}, []byte("\xff\xff"), AliasUint64(65535), nil, nil, }, { NativeType{proto: 2, typ: TypeSmallInt}, []byte("\xff\xff"), AliasUint(65535), nil, nil, }, { NativeType{proto: 2, typ: TypeTinyInt}, []byte("\x7f"), 127, // math.MaxInt8 nil, nil, }, { NativeType{proto: 2, typ: TypeTinyInt}, []byte("\x7f"), "127", // math.MaxInt8 nil, nil, }, { NativeType{proto: 2, typ: TypeTinyInt}, []byte("\x01"), int16(1), nil, nil, }, { NativeType{proto: 2, typ: TypeTinyInt}, []byte("\xff"), int16(-1), nil, nil, }, { NativeType{proto: 2, typ: TypeTinyInt}, []byte("\xff"), uint8(255), nil, nil, }, { NativeType{proto: 2, typ: TypeTinyInt}, []byte("\xff"), uint64(255), nil, nil, }, { NativeType{proto: 2, typ: TypeTinyInt}, []byte("\xff"), uint32(255), nil, nil, }, { NativeType{proto: 2, typ: TypeTinyInt}, []byte("\xff"), uint16(255), nil, nil, }, { NativeType{proto: 2, typ: TypeTinyInt}, []byte("\xff"), uint(255), nil, nil, }, { NativeType{proto: 2, typ: TypeTinyInt}, []byte("\xff"), AliasUint8(255), nil, nil, }, { NativeType{proto: 2, typ: TypeTinyInt}, []byte("\xff"), AliasUint64(255), nil, nil, }, { NativeType{proto: 2, typ: TypeTinyInt}, []byte("\xff"), AliasUint32(255), nil, nil, }, { NativeType{proto: 2, typ: TypeTinyInt}, []byte("\xff"), AliasUint16(255), nil, nil, }, { NativeType{proto: 2, typ: TypeTinyInt}, []byte("\xff"), AliasUint(255), nil, nil, }, { NativeType{proto: 2, typ: TypeBigInt}, []byte("\x00\x00\x00\x00\x00\x00\x00\xff"), uint8(math.MaxUint8), nil, nil, }, { NativeType{proto: 2, typ: TypeBigInt}, []byte("\x00\x00\x00\x00\x00\x00\xff\xff"), uint64(math.MaxUint16), nil, nil, }, { NativeType{proto: 2, 
typ: TypeBigInt}, []byte("\x00\x00\x00\x00\xff\xff\xff\xff"), uint64(math.MaxUint32), nil, nil, }, { NativeType{proto: 2, typ: TypeBigInt}, []byte("\xff\xff\xff\xff\xff\xff\xff\xff"), uint64(math.MaxUint64), nil, nil, }, { NativeType{proto: 2, typ: TypeInt}, []byte("\xff\xff\xff\xff"), uint32(math.MaxUint32), nil, nil, }, { NativeType{proto: 2, typ: TypeInt}, []byte("\xff\xff\xff\xff"), uint64(math.MaxUint32), nil, nil, }, { NativeType{proto: 2, typ: TypeBlob}, []byte(nil), ([]byte)(nil), nil, nil, }, { NativeType{proto: 2, typ: TypeVarchar}, []byte{}, func() interface{} { var s string return &s }(), nil, nil, }, { NativeType{proto: 2, typ: TypeTime}, encBigInt(1000), time.Duration(1000), nil, nil, }, } var unmarshalTests = []struct { Info TypeInfo Data []byte Value interface{} UnmarshalError error }{ { NativeType{proto: 2, typ: TypeSmallInt}, []byte("\xff\xff"), uint8(0), UnmarshalError("unmarshal int: value -1 out of range for uint8"), }, { NativeType{proto: 2, typ: TypeSmallInt}, []byte("\x01\x00"), uint8(0), UnmarshalError("unmarshal int: value 256 out of range for uint8"), }, { NativeType{proto: 2, typ: TypeInt}, []byte("\xff\xff\xff\xff"), uint8(0), UnmarshalError("unmarshal int: value -1 out of range for uint8"), }, { NativeType{proto: 2, typ: TypeInt}, []byte("\x00\x00\x01\x00"), uint8(0), UnmarshalError("unmarshal int: value 256 out of range for uint8"), }, { NativeType{proto: 2, typ: TypeInt}, []byte("\xff\xff\xff\xff"), uint16(0), UnmarshalError("unmarshal int: value -1 out of range for uint16"), }, { NativeType{proto: 2, typ: TypeInt}, []byte("\x00\x01\x00\x00"), uint16(0), UnmarshalError("unmarshal int: value 65536 out of range for uint16"), }, { NativeType{proto: 2, typ: TypeBigInt}, []byte("\xff\xff\xff\xff\xff\xff\xff\xff"), uint8(0), UnmarshalError("unmarshal int: value -1 out of range for uint8"), }, { NativeType{proto: 2, typ: TypeBigInt}, []byte("\x00\x00\x00\x00\x00\x00\x01\x00"), uint8(0), UnmarshalError("unmarshal int: value 256 out of range 
for uint8"), }, { NativeType{proto: 2, typ: TypeBigInt}, []byte("\xff\xff\xff\xff\xff\xff\xff\xff"), uint8(0), UnmarshalError("unmarshal int: value -1 out of range for uint8"), }, { NativeType{proto: 2, typ: TypeBigInt}, []byte("\x00\x00\x00\x00\x00\x00\x01\x00"), uint8(0), UnmarshalError("unmarshal int: value 256 out of range for uint8"), }, { NativeType{proto: 2, typ: TypeBigInt}, []byte("\xff\xff\xff\xff\xff\xff\xff\xff"), uint16(0), UnmarshalError("unmarshal int: value -1 out of range for uint16"), }, { NativeType{proto: 2, typ: TypeBigInt}, []byte("\x00\x00\x00\x00\x00\x01\x00\x00"), uint16(0), UnmarshalError("unmarshal int: value 65536 out of range for uint16"), }, { NativeType{proto: 2, typ: TypeBigInt}, []byte("\xff\xff\xff\xff\xff\xff\xff\xff"), uint32(0), UnmarshalError("unmarshal int: value -1 out of range for uint32"), }, { NativeType{proto: 2, typ: TypeBigInt}, []byte("\x00\x00\x00\x01\x00\x00\x00\x00"), uint32(0), UnmarshalError("unmarshal int: value 4294967296 out of range for uint32"), }, { NativeType{proto: 2, typ: TypeSmallInt}, []byte("\xff\xff"), AliasUint8(0), UnmarshalError("unmarshal int: value -1 out of range for gocql.AliasUint8"), }, { NativeType{proto: 2, typ: TypeSmallInt}, []byte("\x01\x00"), AliasUint8(0), UnmarshalError("unmarshal int: value 256 out of range for gocql.AliasUint8"), }, { NativeType{proto: 2, typ: TypeInt}, []byte("\xff\xff\xff\xff"), AliasUint8(0), UnmarshalError("unmarshal int: value -1 out of range for gocql.AliasUint8"), }, { NativeType{proto: 2, typ: TypeInt}, []byte("\x00\x00\x01\x00"), AliasUint8(0), UnmarshalError("unmarshal int: value 256 out of range for gocql.AliasUint8"), }, { NativeType{proto: 2, typ: TypeInt}, []byte("\xff\xff\xff\xff"), AliasUint16(0), UnmarshalError("unmarshal int: value -1 out of range for gocql.AliasUint16"), }, { NativeType{proto: 2, typ: TypeInt}, []byte("\x00\x01\x00\x00"), AliasUint16(0), UnmarshalError("unmarshal int: value 65536 out of range for gocql.AliasUint16"), }, { 
NativeType{proto: 2, typ: TypeBigInt}, []byte("\xff\xff\xff\xff\xff\xff\xff\xff"), AliasUint8(0), UnmarshalError("unmarshal int: value -1 out of range for gocql.AliasUint8"), }, { NativeType{proto: 2, typ: TypeBigInt}, []byte("\x00\x00\x00\x00\x00\x00\x01\x00"), AliasUint8(0), UnmarshalError("unmarshal int: value 256 out of range for gocql.AliasUint8"), }, { NativeType{proto: 2, typ: TypeBigInt}, []byte("\xff\xff\xff\xff\xff\xff\xff\xff"), AliasUint8(0), UnmarshalError("unmarshal int: value -1 out of range for gocql.AliasUint8"), }, { NativeType{proto: 2, typ: TypeBigInt}, []byte("\x00\x00\x00\x00\x00\x00\x01\x00"), AliasUint8(0), UnmarshalError("unmarshal int: value 256 out of range for gocql.AliasUint8"), }, { NativeType{proto: 2, typ: TypeBigInt}, []byte("\xff\xff\xff\xff\xff\xff\xff\xff"), AliasUint16(0), UnmarshalError("unmarshal int: value -1 out of range for gocql.AliasUint16"), }, { NativeType{proto: 2, typ: TypeBigInt}, []byte("\x00\x00\x00\x00\x00\x01\x00\x00"), AliasUint16(0), UnmarshalError("unmarshal int: value 65536 out of range for gocql.AliasUint16"), }, { NativeType{proto: 2, typ: TypeBigInt}, []byte("\xff\xff\xff\xff\xff\xff\xff\xff"), AliasUint32(0), UnmarshalError("unmarshal int: value -1 out of range for gocql.AliasUint32"), }, { NativeType{proto: 2, typ: TypeBigInt}, []byte("\x00\x00\x00\x01\x00\x00\x00\x00"), AliasUint32(0), UnmarshalError("unmarshal int: value 4294967296 out of range for gocql.AliasUint32"), }, { CollectionType{ NativeType: NativeType{proto: 3, typ: TypeList}, Elem: NativeType{proto: 3, typ: TypeInt}, }, []byte("\x00\x00\x00\x02\x00\x00\x00\x04\x00\x00"), // truncated data func() *[]int { l := []int{1, 2} return &l }(), UnmarshalError("unmarshal list: unexpected eof"), }, { CollectionType{ NativeType: NativeType{proto: 2, typ: TypeMap}, Key: NativeType{proto: 2, typ: TypeVarchar}, Elem: NativeType{proto: 2, typ: TypeInt}, }, []byte("\x00\x01\x00\x03fo"), map[string]int{"foo": 1}, UnmarshalError("unmarshal map: unexpected 
eof"), }, { CollectionType{ NativeType: NativeType{proto: 2, typ: TypeMap}, Key: NativeType{proto: 2, typ: TypeVarchar}, Elem: NativeType{proto: 2, typ: TypeInt}, }, []byte("\x00\x01\x00\x03foo\x00\x04\x00\x00"), map[string]int{"foo": 1}, UnmarshalError("unmarshal map: unexpected eof"), }, { NativeType{proto: 2, typ: TypeDecimal}, []byte("\xff\xff\xff"), inf.NewDec(0, 0), // From the datastax/python-driver test suite UnmarshalError("inf.Dec needs at least 4 bytes, while value has only 3"), }, { NativeType{proto: 5, typ: TypeDuration}, []byte("\x89\xa2\xc3\xc2\x9a\xe0F\x91"), Duration{}, UnmarshalError("failed to unmarshal duration into *gocql.Duration: failed to extract nanoseconds: data expect to have 9 bytes, but it has only 8"), }, { NativeType{proto: 5, typ: TypeDuration}, []byte("\x89\xa2\xc3\xc2\x9a"), Duration{}, UnmarshalError("failed to unmarshal duration into *gocql.Duration: failed to extract nanoseconds: unexpected eof"), }, { NativeType{proto: 5, typ: TypeDuration}, []byte("\x89\xa2\xc3\xc2"), Duration{}, UnmarshalError("failed to unmarshal duration into *gocql.Duration: failed to extract days: data expect to have 5 bytes, but it has only 4"), }, { NativeType{proto: 5, typ: TypeDuration}, []byte("\x89\xa2"), Duration{}, UnmarshalError("failed to unmarshal duration into *gocql.Duration: failed to extract days: unexpected eof"), }, { NativeType{proto: 5, typ: TypeDuration}, []byte("\x89"), Duration{}, UnmarshalError("failed to unmarshal duration into *gocql.Duration: failed to extract month: data expect to have 2 bytes, but it has only 1"), }, } func decimalize(s string) *inf.Dec { i, _ := new(inf.Dec).SetString(s) return i } func bigintize(s string) *big.Int { i, _ := new(big.Int).SetString(s, 10) return i } func TestMarshal_Encode(t *testing.T) { for i, test := range marshalTests { if test.MarshalError == nil { data, err := Marshal(test.Info, test.Value) if err != nil { t.Errorf("marshalTest[%d]: %v", i, err) continue } if !bytes.Equal(data, test.Data) 
{ t.Errorf("marshalTest[%d]: expected %q, got %q (%#v)", i, test.Data, data, test.Value) } } else { if _, err := Marshal(test.Info, test.Value); err != test.MarshalError { t.Errorf("unmarshalTest[%d] (%v=>%t): %#v returned error %#v, want %#v.", i, test.Info, test.Value, test.Value, err, test.MarshalError) } } } } func TestMarshal_Decode(t *testing.T) { for i, test := range marshalTests { if test.UnmarshalError == nil { v := reflect.New(reflect.TypeOf(test.Value)) err := Unmarshal(test.Info, test.Data, v.Interface()) if err != nil { t.Errorf("unmarshalTest[%d] (%v=>%T): %v", i, test.Info, test.Value, err) continue } if !reflect.DeepEqual(v.Elem().Interface(), test.Value) { t.Errorf("unmarshalTest[%d] (%v=>%T): expected %#v, got %#v.", i, test.Info, test.Value, test.Value, v.Elem().Interface()) } } else { if err := Unmarshal(test.Info, test.Data, test.Value); err != test.UnmarshalError { t.Errorf("unmarshalTest[%d] (%v=>%T): %#v returned error %#v, want %#v.", i, test.Info, test.Value, test.Value, err, test.UnmarshalError) } } } for i, test := range unmarshalTests { v := reflect.New(reflect.TypeOf(test.Value)) if test.UnmarshalError == nil { err := Unmarshal(test.Info, test.Data, v.Interface()) if err != nil { t.Errorf("unmarshalTest[%d] (%v=>%T): %v", i, test.Info, test.Value, err) continue } if !reflect.DeepEqual(v.Elem().Interface(), test.Value) { t.Errorf("unmarshalTest[%d] (%v=>%T): expected %#v, got %#v.", i, test.Info, test.Value, test.Value, v.Elem().Interface()) } } else { if err := Unmarshal(test.Info, test.Data, v.Interface()); err != test.UnmarshalError { t.Errorf("unmarshalTest[%d] (%v=>%T): %#v returned error %#v, want %#v.", i, test.Info, test.Value, test.Value, err, test.UnmarshalError) } } } } func TestMarshalVarint(t *testing.T) { varintTests := []struct { Value interface{} Marshaled []byte Unmarshaled *big.Int }{ { Value: int8(0), Marshaled: []byte("\x00"), Unmarshaled: big.NewInt(0), }, { Value: uint8(255), Marshaled: []byte("\x00\xFF"), 
Unmarshaled: big.NewInt(255), }, { Value: int8(-1), Marshaled: []byte("\xFF"), Unmarshaled: big.NewInt(-1), }, { Value: big.NewInt(math.MaxInt32), Marshaled: []byte("\x7F\xFF\xFF\xFF"), Unmarshaled: big.NewInt(math.MaxInt32), }, { Value: big.NewInt(int64(math.MaxInt32) + 1), Marshaled: []byte("\x00\x80\x00\x00\x00"), Unmarshaled: big.NewInt(int64(math.MaxInt32) + 1), }, { Value: big.NewInt(math.MinInt32), Marshaled: []byte("\x80\x00\x00\x00"), Unmarshaled: big.NewInt(math.MinInt32), }, { Value: big.NewInt(int64(math.MinInt32) - 1), Marshaled: []byte("\xFF\x7F\xFF\xFF\xFF"), Unmarshaled: big.NewInt(int64(math.MinInt32) - 1), }, { Value: math.MinInt64, Marshaled: []byte("\x80\x00\x00\x00\x00\x00\x00\x00"), Unmarshaled: big.NewInt(math.MinInt64), }, { Value: uint64(math.MaxInt64) + 1, Marshaled: []byte("\x00\x80\x00\x00\x00\x00\x00\x00\x00"), Unmarshaled: bigintize("9223372036854775808"), }, { Value: bigintize("2361183241434822606848"), // 2**71 Marshaled: []byte("\x00\x80\x00\x00\x00\x00\x00\x00\x00\x00"), Unmarshaled: bigintize("2361183241434822606848"), }, { Value: bigintize("-9223372036854775809"), // -2**63 - 1 Marshaled: []byte("\xFF\x7F\xFF\xFF\xFF\xFF\xFF\xFF\xFF"), Unmarshaled: bigintize("-9223372036854775809"), }, } for i, test := range varintTests { data, err := Marshal(NativeType{proto: 2, typ: TypeVarint}, test.Value) if err != nil { t.Errorf("error marshaling varint: %v (test #%d)", err, i) } if !bytes.Equal(test.Marshaled, data) { t.Errorf("marshaled varint mismatch: expected %v, got %v (test #%d)", test.Marshaled, data, i) } binder := new(big.Int) err = Unmarshal(NativeType{proto: 2, typ: TypeVarint}, test.Marshaled, binder) if err != nil { t.Errorf("error unmarshaling varint: %v (test #%d)", err, i) } if test.Unmarshaled.Cmp(binder) != 0 { t.Errorf("unmarshaled varint mismatch: expected %v, got %v (test #%d)", test.Unmarshaled, binder, i) } } varintUint64Tests := []struct { Value interface{} Marshaled []byte Unmarshaled uint64 }{ { Value: int8(0), 
Marshaled: []byte("\x00"), Unmarshaled: 0, }, { Value: uint8(255), Marshaled: []byte("\x00\xFF"), Unmarshaled: 255, }, { Value: big.NewInt(math.MaxInt32), Marshaled: []byte("\x7F\xFF\xFF\xFF"), Unmarshaled: uint64(math.MaxInt32), }, { Value: big.NewInt(int64(math.MaxInt32) + 1), Marshaled: []byte("\x00\x80\x00\x00\x00"), Unmarshaled: uint64(int64(math.MaxInt32) + 1), }, { Value: uint64(math.MaxInt64) + 1, Marshaled: []byte("\x00\x80\x00\x00\x00\x00\x00\x00\x00"), Unmarshaled: 9223372036854775808, }, { Value: uint64(math.MaxUint64), Marshaled: []byte("\x00\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF"), Unmarshaled: uint64(math.MaxUint64), }, } for i, test := range varintUint64Tests { data, err := Marshal(NativeType{proto: 2, typ: TypeVarint}, test.Value) if err != nil { t.Errorf("error marshaling varint: %v (test #%d)", err, i) } if !bytes.Equal(test.Marshaled, data) { t.Errorf("marshaled varint mismatch: expected %v, got %v (test #%d)", test.Marshaled, data, i) } var binder uint64 err = Unmarshal(NativeType{proto: 2, typ: TypeVarint}, test.Marshaled, &binder) if err != nil { t.Errorf("error unmarshaling varint to uint64: %v (test #%d)", err, i) } if test.Unmarshaled != binder { t.Errorf("unmarshaled varint mismatch: expected %v, got %v (test #%d)", test.Unmarshaled, binder, i) } } } func equalStringPointerSlice(leftList, rightList []*string) bool { if len(leftList) != len(rightList) { return false } for index := range leftList { if !reflect.DeepEqual(rightList[index], leftList[index]) { return false } } return true } func TestMarshalList(t *testing.T) { typeInfoV2 := CollectionType{ NativeType: NativeType{proto: 2, typ: TypeList}, Elem: NativeType{proto: 2, typ: TypeVarchar}, } typeInfoV3 := CollectionType{ NativeType: NativeType{proto: 3, typ: TypeList}, Elem: NativeType{proto: 3, typ: TypeVarchar}, } type tc struct { typeInfo CollectionType input []*string expected []*string } valueA := "valueA" valueB := "valueB" valueEmpty := "" testCases := []tc{ { typeInfo: typeInfoV2, 
input: []*string{&valueA}, expected: []*string{&valueA}, }, { typeInfo: typeInfoV2, input: []*string{&valueA, &valueB}, expected: []*string{&valueA, &valueB}, }, { typeInfo: typeInfoV2, input: []*string{&valueA, &valueEmpty, &valueB}, expected: []*string{&valueA, &valueEmpty, &valueB}, }, { typeInfo: typeInfoV2, input: []*string{&valueEmpty}, expected: []*string{&valueEmpty}, }, { // nil values are marshalled to empty values for protocol < 3 typeInfo: typeInfoV2, input: []*string{nil}, expected: []*string{&valueEmpty}, }, { typeInfo: typeInfoV2, input: []*string{&valueA, nil, &valueB}, expected: []*string{&valueA, &valueEmpty, &valueB}, }, { typeInfo: typeInfoV3, input: []*string{&valueEmpty}, expected: []*string{&valueEmpty}, }, { typeInfo: typeInfoV3, input: []*string{nil}, expected: []*string{nil}, }, { typeInfo: typeInfoV3, input: []*string{&valueA, nil, &valueB}, expected: []*string{&valueA, nil, &valueB}, }, } listDatas := [][]byte{} for _, c := range testCases { listData, marshalErr := Marshal(c.typeInfo, c.input) if nil != marshalErr { t.Errorf("Error marshal %+v of type %+v: %s", c.input, c.typeInfo, marshalErr) } listDatas = append(listDatas, listData) } outputLists := [][]*string{} var outputList []*string for i, listData := range listDatas { if unmarshalErr := Unmarshal(testCases[i].typeInfo, listData, &outputList); nil != unmarshalErr { t.Error(unmarshalErr) } resultList := []interface{}{} for i := range outputList { if outputList[i] != nil { resultList = append(resultList, *outputList[i]) } else { resultList = append(resultList, nil) } } outputLists = append(outputLists, outputList) } for index, c := range testCases { outputList := outputLists[index] if !equalStringPointerSlice(c.expected, outputList) { t.Errorf("Lists %+v not equal to lists %+v, but should", c.expected, outputList) } } } type CustomString string func (c CustomString) MarshalCQL(info TypeInfo) ([]byte, error) { return []byte(strings.ToUpper(string(c))), nil } func (c *CustomString) 
UnmarshalCQL(info TypeInfo, data []byte) error { *c = CustomString(strings.ToLower(string(data))) return nil } type MyString string var typeLookupTest = []struct { TypeName string ExpectedType Type }{ {"AsciiType", TypeAscii}, {"LongType", TypeBigInt}, {"BytesType", TypeBlob}, {"BooleanType", TypeBoolean}, {"CounterColumnType", TypeCounter}, {"DecimalType", TypeDecimal}, {"DoubleType", TypeDouble}, {"FloatType", TypeFloat}, {"Int32Type", TypeInt}, {"DateType", TypeTimestamp}, {"TimestampType", TypeTimestamp}, {"UUIDType", TypeUUID}, {"UTF8Type", TypeVarchar}, {"IntegerType", TypeVarint}, {"TimeUUIDType", TypeTimeUUID}, {"InetAddressType", TypeInet}, {"MapType", TypeMap}, {"ListType", TypeList}, {"SetType", TypeSet}, {"unknown", TypeCustom}, {"ShortType", TypeSmallInt}, {"ByteType", TypeTinyInt}, } func testType(t *testing.T, cassType string, expectedType Type) { if computedType := getApacheCassandraType(apacheCassandraTypePrefix + cassType); computedType != expectedType { t.Errorf("Cassandra custom type lookup for %s failed. Expected %s, got %s.", cassType, expectedType.String(), computedType.String()) } } func TestLookupCassType(t *testing.T) { for _, lookupTest := range typeLookupTest { testType(t, lookupTest.TypeName, lookupTest.ExpectedType) } } type MyPointerMarshaler struct{} func (m *MyPointerMarshaler) MarshalCQL(_ TypeInfo) ([]byte, error) { return []byte{42}, nil } func TestMarshalPointer(t *testing.T) { m := &MyPointerMarshaler{} typ := NativeType{proto: 2, typ: TypeInt} data, err := Marshal(typ, m) if err != nil { t.Errorf("Pointer marshaling failed. Error: %s", err) } if len(data) != 1 || data[0] != 42 { t.Errorf("Pointer marshaling failed. 
Expected %+v, got %+v", []byte{42}, data) } } func TestMarshalTime(t *testing.T) { durationS := "1h10m10s" duration, _ := time.ParseDuration(durationS) expectedData := encBigInt(duration.Nanoseconds()) var marshalTimeTests = []struct { Info TypeInfo Data []byte Value interface{} }{ { NativeType{proto: 4, typ: TypeTime}, expectedData, duration.Nanoseconds(), }, { NativeType{proto: 4, typ: TypeTime}, expectedData, duration, }, { NativeType{proto: 4, typ: TypeTime}, expectedData, &duration, }, } for i, test := range marshalTimeTests { t.Log(i, test) data, err := Marshal(test.Info, test.Value) if err != nil { t.Errorf("marshalTest[%d]: %v", i, err) continue } if !bytes.Equal(data, test.Data) { t.Errorf("marshalTest[%d]: expected %x (%v), got %x (%v) for time %s", i, test.Data, decInt(test.Data), data, decInt(data), test.Value) } } } func TestMarshalTimestamp(t *testing.T) { var marshalTimestampTests = []struct { Info TypeInfo Data []byte Value interface{} }{ { NativeType{proto: 2, typ: TypeTimestamp}, []byte("\x00\x00\x01\x40\x77\x16\xe1\xb8"), time.Date(2013, time.August, 13, 9, 52, 3, 0, time.UTC), }, { NativeType{proto: 2, typ: TypeTimestamp}, []byte("\x00\x00\x01\x40\x77\x16\xe1\xb8"), int64(1376387523000), }, { // 9223372036854 is the maximum time representable in ms since the epoch // with int64 if using UnixNano to convert NativeType{proto: 2, typ: TypeTimestamp}, []byte("\x00\x00\x08\x63\x7b\xd0\x5a\xf6"), time.Date(2262, time.April, 11, 23, 47, 16, 854775807, time.UTC), }, { // One nanosecond after causes overflow when using UnixNano // Instead it should resolve to the same time in ms NativeType{proto: 2, typ: TypeTimestamp}, []byte("\x00\x00\x08\x63\x7b\xd0\x5a\xf6"), time.Date(2262, time.April, 11, 23, 47, 16, 854775808, time.UTC), }, { // -9223372036855 is the minimum time representable in ms since the epoch // with int64 if using UnixNano to convert NativeType{proto: 2, typ: TypeTimestamp}, []byte("\xff\xff\xf7\x9c\x84\x2f\xa5\x09"), time.Date(1677, 
time.September, 21, 00, 12, 43, 145224192, time.UTC), }, { // One nanosecond earlier causes overflow when using UnixNano // it should resolve to the same time in ms NativeType{proto: 2, typ: TypeTimestamp}, []byte("\xff\xff\xf7\x9c\x84\x2f\xa5\x09"), time.Date(1677, time.September, 21, 00, 12, 43, 145224191, time.UTC), }, { // Store the zero time as a blank slice NativeType{proto: 2, typ: TypeTimestamp}, []byte{}, time.Time{}, }, } for i, test := range marshalTimestampTests { t.Log(i, test) data, err := Marshal(test.Info, test.Value) if err != nil { t.Errorf("marshalTest[%d]: %v", i, err) continue } if !bytes.Equal(data, test.Data) { t.Errorf("marshalTest[%d]: expected %x (%v), got %x (%v) for time %s", i, test.Data, decBigInt(test.Data), data, decBigInt(data), test.Value) } } } func TestMarshalTuple(t *testing.T) { info := TupleTypeInfo{ NativeType: NativeType{proto: 3, typ: TypeTuple}, Elems: []TypeInfo{ NativeType{proto: 3, typ: TypeVarchar}, NativeType{proto: 3, typ: TypeVarchar}, }, } stringToPtr := func(s string) *string { return &s } checkString := func(t *testing.T, exp string, got string) { if got != exp { t.Errorf("expected string to be %v, got %v", exp, got) } } type tupleStruct struct { A string B *string } var ( s1 *string s2 *string ) testCases := []struct { name string expected []byte value interface{} checkValue interface{} check func(*testing.T, interface{}) }{ { name: "interface-slice:two-strings", expected: []byte("\x00\x00\x00\x03foo\x00\x00\x00\x03bar"), value: []interface{}{"foo", "bar"}, checkValue: []interface{}{&s1, &s2}, check: func(t *testing.T, v interface{}) { checkString(t, "foo", *s1) checkString(t, "bar", *s2) }, }, { name: "interface-slice:one-string-one-nil-string", expected: []byte("\x00\x00\x00\x03foo\xff\xff\xff\xff"), value: []interface{}{"foo", nil}, checkValue: []interface{}{&s1, &s2}, check: func(t *testing.T, v interface{}) { checkString(t, "foo", *s1) if s2 != nil { t.Errorf("expected string to be nil, got %v", *s2) } }, 
}, { name: "struct:two-strings", expected: []byte("\x00\x00\x00\x03foo\x00\x00\x00\x03bar"), value: tupleStruct{ A: "foo", B: stringToPtr("bar"), }, checkValue: &tupleStruct{}, check: func(t *testing.T, v interface{}) { got := v.(*tupleStruct) if got.A != "foo" { t.Errorf("expected A string to be %v, got %v", "foo", got.A) } if got.B == nil { t.Errorf("expected B string to be %v, got nil", "bar") } if *got.B != "bar" { t.Errorf("expected B string to be %v, got %v", "bar", got.B) } }, }, { name: "struct:one-string-one-nil-string", expected: []byte("\x00\x00\x00\x03foo\xff\xff\xff\xff"), value: tupleStruct{A: "foo", B: nil}, checkValue: &tupleStruct{}, check: func(t *testing.T, v interface{}) { got := v.(*tupleStruct) if got.A != "foo" { t.Errorf("expected A string to be %v, got %v", "foo", got.A) } if got.B != nil { t.Errorf("expected B string to be nil, got %v", *got.B) } }, }, { name: "arrayslice:two-strings", expected: []byte("\x00\x00\x00\x03foo\x00\x00\x00\x03bar"), value: [2]*string{ stringToPtr("foo"), stringToPtr("bar"), }, checkValue: &[2]*string{}, check: func(t *testing.T, v interface{}) { got := v.(*[2]*string) checkString(t, "foo", *(got[0])) checkString(t, "bar", *(got[1])) }, }, { name: "arrayslice:one-string-one-nil-string", expected: []byte("\x00\x00\x00\x03foo\xff\xff\xff\xff"), value: [2]*string{ stringToPtr("foo"), nil, }, checkValue: &[2]*string{}, check: func(t *testing.T, v interface{}) { got := v.(*[2]*string) checkString(t, "foo", *(got[0])) if got[1] != nil { t.Errorf("expected string to be nil, got %v", *got[1]) } }, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { data, err := Marshal(info, tc.value) if err != nil { t.Errorf("marshalTest: %v", err) return } if !bytes.Equal(data, tc.expected) { t.Errorf("marshalTest: expected %x (%v), got %x (%v)", tc.expected, decBigInt(tc.expected), data, decBigInt(data)) return } err = Unmarshal(info, data, tc.checkValue) if err != nil { t.Errorf("unmarshalTest: %v", err) return } 
tc.check(t, tc.checkValue) }) } } func TestUnmarshalTuple(t *testing.T) { info := TupleTypeInfo{ NativeType: NativeType{proto: 3, typ: TypeTuple}, Elems: []TypeInfo{ NativeType{proto: 3, typ: TypeVarchar}, NativeType{proto: 3, typ: TypeVarchar}, }, } // As per the CQL spec, a tuple is a sequence of "bytes" values. // Here we encode a null value (length -1) and the "foo" string (length 3) data := []byte("\xff\xff\xff\xff\x00\x00\x00\x03foo") t.Run("struct-ptr", func(t *testing.T) { var tmp struct { A *string B *string } err := Unmarshal(info, data, &tmp) if err != nil { t.Errorf("unmarshalTest: %v", err) return } if tmp.A != nil || *tmp.B != "foo" { t.Errorf("unmarshalTest: expected [nil, foo], got [%v, %v]", *tmp.A, *tmp.B) } }) t.Run("struct-nonptr", func(t *testing.T) { var tmp struct { A string B string } err := Unmarshal(info, data, &tmp) if err != nil { t.Errorf("unmarshalTest: %v", err) return } if tmp.A != "" || tmp.B != "foo" { t.Errorf("unmarshalTest: expected [nil, foo], got [%v, %v]", tmp.A, tmp.B) } }) t.Run("array", func(t *testing.T) { var tmp [2]*string err := Unmarshal(info, data, &tmp) if err != nil { t.Errorf("unmarshalTest: %v", err) return } if tmp[0] != nil || *tmp[1] != "foo" { t.Errorf("unmarshalTest: expected [nil, foo], got [%v, %v]", *tmp[0], *tmp[1]) } }) t.Run("array-nonptr", func(t *testing.T) { var tmp [2]string err := Unmarshal(info, data, &tmp) if err != nil { t.Errorf("unmarshalTest: %v", err) return } if tmp[0] != "" || tmp[1] != "foo" { t.Errorf("unmarshalTest: expected [nil, foo], got [%v, %v]", tmp[0], tmp[1]) } }) } func TestMarshalUDTMap(t *testing.T) { typeInfo := UDTTypeInfo{NativeType{proto: 3, typ: TypeUDT}, "", "xyz", []UDTField{ {Name: "x", Type: NativeType{proto: 3, typ: TypeInt}}, {Name: "y", Type: NativeType{proto: 3, typ: TypeInt}}, {Name: "z", Type: NativeType{proto: 3, typ: TypeInt}}, }} t.Run("partially bound", func(t *testing.T) { value := map[string]interface{}{ "y": 2, "z": 3, } expected := 
[]byte("\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00\x02\x00\x00\x00\x04\x00\x00\x00\x03") data, err := Marshal(typeInfo, value) if err != nil { t.Errorf("got error %#v", err) } if !bytes.Equal(data, expected) { t.Errorf("got value %x", data) } }) t.Run("partially bound from the beginning", func(t *testing.T) { value := map[string]interface{}{ "x": 1, "y": 2, } expected := []byte("\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x00\x04\x00\x00\x00\x02\xff\xff\xff\xff") data, err := Marshal(typeInfo, value) if err != nil { t.Errorf("got error %#v", err) } if !bytes.Equal(data, expected) { t.Errorf("got value %x", data) } }) t.Run("fully bound", func(t *testing.T) { value := map[string]interface{}{ "x": 1, "y": 2, "z": 3, } expected := []byte("\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x00\x04\x00\x00\x00\x02\x00\x00\x00\x04\x00\x00\x00\x03") data, err := Marshal(typeInfo, value) if err != nil { t.Errorf("got error %#v", err) } if !bytes.Equal(data, expected) { t.Errorf("got value %x", data) } }) } func TestMarshalUDTStruct(t *testing.T) { typeInfo := UDTTypeInfo{NativeType{proto: 3, typ: TypeUDT}, "", "xyz", []UDTField{ {Name: "x", Type: NativeType{proto: 3, typ: TypeInt}}, {Name: "y", Type: NativeType{proto: 3, typ: TypeInt}}, {Name: "z", Type: NativeType{proto: 3, typ: TypeInt}}, }} type xyzStruct struct { X int32 `cql:"x"` Y int32 `cql:"y"` Z int32 `cql:"z"` } type xyStruct struct { X int32 `cql:"x"` Y int32 `cql:"y"` } type yzStruct struct { Y int32 `cql:"y"` Z int32 `cql:"z"` } t.Run("partially bound", func(t *testing.T) { value := yzStruct{ Y: 2, Z: 3, } expected := []byte("\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00\x02\x00\x00\x00\x04\x00\x00\x00\x03") data, err := Marshal(typeInfo, value) if err != nil { t.Errorf("got error %#v", err) } if !bytes.Equal(data, expected) { t.Errorf("got value %x", data) } }) t.Run("partially bound from the beginning", func(t *testing.T) { value := xyStruct{ X: 1, Y: 2, } expected := 
[]byte("\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x00\x04\x00\x00\x00\x02\xff\xff\xff\xff") data, err := Marshal(typeInfo, value) if err != nil { t.Errorf("got error %#v", err) } if !bytes.Equal(data, expected) { t.Errorf("got value %x", data) } }) t.Run("fully bound", func(t *testing.T) { value := xyzStruct{ X: 1, Y: 2, Z: 3, } expected := []byte("\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x00\x04\x00\x00\x00\x02\x00\x00\x00\x04\x00\x00\x00\x03") data, err := Marshal(typeInfo, value) if err != nil { t.Errorf("got error %#v", err) } if !bytes.Equal(data, expected) { t.Errorf("got value %x", data) } }) } func TestMarshalNil(t *testing.T) { types := []Type{ TypeAscii, TypeBlob, TypeBoolean, TypeBigInt, TypeCounter, TypeDecimal, TypeDouble, TypeFloat, TypeInt, TypeTimestamp, TypeUUID, TypeVarchar, TypeVarint, TypeTimeUUID, TypeInet, } for _, typ := range types { data, err := Marshal(NativeType{proto: 3, typ: typ}, nil) if err != nil { t.Errorf("unable to marshal nil %v: %v\n", typ, err) } else if data != nil { t.Errorf("expected to get nil byte for nil %v got % X", typ, data) } } } func TestUnmarshalInetCopyBytes(t *testing.T) { data := []byte{127, 0, 0, 1} var ip net.IP if err := unmarshalInet(NativeType{proto: 2, typ: TypeInet}, data, &ip); err != nil { t.Fatal(err) } copy(data, []byte{0xFF, 0xFF, 0xFF, 0xFF}) ip2 := net.IP(data) if !ip.Equal(net.IPv4(127, 0, 0, 1)) { t.Fatalf("IP memory shared with data: ip=%v ip2=%v", ip, ip2) } } func TestUnmarshalDate(t *testing.T) { data := []uint8{0x80, 0x0, 0x43, 0x31} var date time.Time if err := unmarshalDate(NativeType{proto: 2, typ: TypeDate}, data, &date); err != nil { t.Fatal(err) } expectedDate := "2017-02-04" formattedDate := date.Format("2006-01-02") if expectedDate != formattedDate { t.Errorf("marshalTest: expected %v, got %v", expectedDate, formattedDate) return } var stringDate string if err2 := unmarshalDate(NativeType{proto: 2, typ: TypeDate}, data, &stringDate); err2 != nil { t.Fatal(err2) } if expectedDate != 
stringDate { t.Errorf("marshalTest: expected %v, got %v", expectedDate, formattedDate) return } } func TestMarshalDate(t *testing.T) { now := time.Now().UTC() timestamp := now.UnixNano() / int64(time.Millisecond) expectedData := encInt(int32(timestamp/86400000 + int64(1<<31))) var marshalDateTests = []struct { Info TypeInfo Data []byte Value interface{} }{ { NativeType{proto: 4, typ: TypeDate}, expectedData, timestamp, }, { NativeType{proto: 4, typ: TypeDate}, expectedData, now, }, { NativeType{proto: 4, typ: TypeDate}, expectedData, &now, }, { NativeType{proto: 4, typ: TypeDate}, expectedData, now.Format("2006-01-02"), }, } for i, test := range marshalDateTests { t.Log(i, test) data, err := Marshal(test.Info, test.Value) if err != nil { t.Errorf("marshalTest[%d]: %v", i, err) continue } if !bytes.Equal(data, test.Data) { t.Errorf("marshalTest[%d]: expected %x (%v), got %x (%v) for time %s", i, test.Data, decInt(test.Data), data, decInt(data), test.Value) } } } func TestLargeDate(t *testing.T) { farFuture := time.Date(999999, time.December, 31, 0, 0, 0, 0, time.UTC) expectedFutureData := encInt(int32(farFuture.UnixMilli()/86400000 + int64(1<<31))) farPast := time.Date(-999999, time.January, 1, 0, 0, 0, 0, time.UTC) expectedPastData := encInt(int32(farPast.UnixMilli()/86400000 + int64(1<<31))) var marshalDateTests = []struct { Data []byte Value interface{} ExpectedDate string }{ { expectedFutureData, farFuture, "999999-12-31", }, { expectedPastData, farPast, "-999999-01-01", }, } nativeType := NativeType{proto: 4, typ: TypeDate} for i, test := range marshalDateTests { t.Log(i, test) data, err := Marshal(nativeType, test.Value) if err != nil { t.Errorf("largeDateTest[%d]: %v", i, err) continue } if !bytes.Equal(data, test.Data) { t.Errorf("largeDateTest[%d]: expected %x (%v), got %x (%v) for time %s", i, test.Data, decInt(test.Data), data, decInt(data), test.Value) } var date time.Time if err := Unmarshal(nativeType, data, &date); err != nil { t.Fatal(err) } 
formattedDate := date.Format("2006-01-02") if test.ExpectedDate != formattedDate { t.Fatalf("largeDateTest: expected %v, got %v", test.ExpectedDate, formattedDate) } } } func BenchmarkUnmarshalVarchar(b *testing.B) { b.ReportAllocs() src := make([]byte, 1024) dst := make([]byte, len(src)) b.ResetTimer() for i := 0; i < b.N; i++ { if err := unmarshalVarchar(NativeType{}, src, &dst); err != nil { b.Fatal(err) } } } func TestMarshalDuration(t *testing.T) { durationS := "1h10m10s" duration, _ := time.ParseDuration(durationS) expectedData := append([]byte{0, 0}, encVint(duration.Nanoseconds())...) var marshalDurationTests = []struct { Info TypeInfo Data []byte Value interface{} }{ { NativeType{proto: 5, typ: TypeDuration}, expectedData, duration.Nanoseconds(), }, { NativeType{proto: 5, typ: TypeDuration}, expectedData, duration, }, { NativeType{proto: 5, typ: TypeDuration}, expectedData, durationS, }, { NativeType{proto: 5, typ: TypeDuration}, expectedData, &duration, }, } for i, test := range marshalDurationTests { t.Log(i, test) data, err := Marshal(test.Info, test.Value) if err != nil { t.Errorf("marshalTest[%d]: %v", i, err) continue } if !bytes.Equal(data, test.Data) { t.Errorf("marshalTest[%d]: expected %x (%v), got %x (%v) for time %s", i, test.Data, decInt(test.Data), data, decInt(data), test.Value) } } } func TestReadCollectionSize(t *testing.T) { listV2 := CollectionType{ NativeType: NativeType{proto: 2, typ: TypeList}, Elem: NativeType{proto: 2, typ: TypeVarchar}, } listV3 := CollectionType{ NativeType: NativeType{proto: 3, typ: TypeList}, Elem: NativeType{proto: 3, typ: TypeVarchar}, } tests := []struct { name string info CollectionType data []byte isError bool expectedSize int }{ { name: "short read 0 proto 2", info: listV2, data: []byte{}, isError: true, }, { name: "short read 1 proto 2", info: listV2, data: []byte{0x01}, isError: true, }, { name: "good read proto 2", info: listV2, data: []byte{0x01, 0x38}, expectedSize: 0x0138, }, { name: "short read 0 
proto 3", info: listV3, data: []byte{}, isError: true, }, { name: "short read 1 proto 3", info: listV3, data: []byte{0x01}, isError: true, }, { name: "short read 2 proto 3", info: listV3, data: []byte{0x01, 0x38}, isError: true, }, { name: "short read 3 proto 3", info: listV3, data: []byte{0x01, 0x38, 0x42}, isError: true, }, { name: "good read proto 3", info: listV3, data: []byte{0x01, 0x38, 0x42, 0x22}, expectedSize: 0x01384222, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { size, _, err := readCollectionSize(test.info, test.data) if test.isError { if err == nil { t.Fatal("Expected error, but it was nil") } } else { if err != nil { t.Fatalf("Expected no error, got %v", err) } if size != test.expectedSize { t.Fatalf("Expected size of %d, but got %d", test.expectedSize, size) } } }) } } func BenchmarkUnmarshalUUID(b *testing.B) { b.ReportAllocs() src := make([]byte, 16) dst := UUID{} var ti TypeInfo = NativeType{} b.ResetTimer() for i := 0; i < b.N; i++ { if err := unmarshalUUID(ti, src, &dst); err != nil { b.Fatal(err) } } } func TestUnmarshalUDT(t *testing.T) { info := UDTTypeInfo{ NativeType: NativeType{proto: 4, typ: TypeUDT}, Name: "myudt", KeySpace: "myks", Elements: []UDTField{ { Name: "first", Type: NativeType{proto: 4, typ: TypeAscii}, }, { Name: "second", Type: NativeType{proto: 4, typ: TypeSmallInt}, }, }, } data := bytesWithLength( // UDT bytesWithLength([]byte("Hello")), // first bytesWithLength([]byte("\x00\x2a")), // second ) value := map[string]interface{}{} expectedErr := UnmarshalError("can not unmarshal into non-pointer map[string]interface {}") if err := Unmarshal(info, data, value); err != expectedErr { t.Errorf("(%v=>%T): %#v returned error %#v, want %#v.", info, value, value, err, expectedErr) } } // bytesWithLength concatenates all data slices and prepends the total length as uint32. // The length does not count the size of the uint32 used for writing the size. 
func bytesWithLength(data ...[]byte) []byte { totalLen := 0 for i := range data { totalLen += len(data[i]) } if totalLen > math.MaxUint32 { panic("total length overflows") } ret := make([]byte, totalLen+4) binary.BigEndian.PutUint32(ret[:4], uint32(totalLen)) buf := ret[4:] for i := range data { n := copy(buf, data[i]) buf = buf[n:] } return ret } cassandra-gocql-driver-1.7.0/metadata.go000066400000000000000000001110621467504044300201710ustar00rootroot00000000000000// Copyright (c) 2015 The gocql Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. 
*/ package gocql import ( "encoding/hex" "encoding/json" "fmt" "strconv" "strings" "sync" ) // schema metadata for a keyspace type KeyspaceMetadata struct { Name string DurableWrites bool StrategyClass string StrategyOptions map[string]interface{} Tables map[string]*TableMetadata Functions map[string]*FunctionMetadata Aggregates map[string]*AggregateMetadata // Deprecated: use the MaterializedViews field for views and UserTypes field for udts instead. Views map[string]*ViewMetadata MaterializedViews map[string]*MaterializedViewMetadata UserTypes map[string]*UserTypeMetadata } // schema metadata for a table (a.k.a. column family) type TableMetadata struct { Keyspace string Name string KeyValidator string Comparator string DefaultValidator string KeyAliases []string ColumnAliases []string ValueAlias string PartitionKey []*ColumnMetadata ClusteringColumns []*ColumnMetadata Columns map[string]*ColumnMetadata OrderedColumns []string } // schema metadata for a column type ColumnMetadata struct { Keyspace string Table string Name string ComponentIndex int Kind ColumnKind Validator string Type TypeInfo ClusteringOrder string Order ColumnOrder Index ColumnIndexMetadata } // FunctionMetadata holds metadata for function constructs type FunctionMetadata struct { Keyspace string Name string ArgumentTypes []TypeInfo ArgumentNames []string Body string CalledOnNullInput bool Language string ReturnType TypeInfo } // AggregateMetadata holds metadata for aggregate constructs type AggregateMetadata struct { Keyspace string Name string ArgumentTypes []TypeInfo FinalFunc FunctionMetadata InitCond string ReturnType TypeInfo StateFunc FunctionMetadata StateType TypeInfo stateFunc string finalFunc string } // ViewMetadata holds the metadata for views. // Deprecated: this is kept for backwards compatibility issues. Use MaterializedViewMetadata. 
type ViewMetadata struct { Keyspace string Name string FieldNames []string FieldTypes []TypeInfo } // MaterializedViewMetadata holds the metadata for materialized views. type MaterializedViewMetadata struct { Keyspace string Name string BaseTableId UUID BaseTable *TableMetadata BloomFilterFpChance float64 Caching map[string]string Comment string Compaction map[string]string Compression map[string]string CrcCheckChance float64 DcLocalReadRepairChance float64 DefaultTimeToLive int Extensions map[string]string GcGraceSeconds int Id UUID IncludeAllColumns bool MaxIndexInterval int MemtableFlushPeriodInMs int MinIndexInterval int ReadRepairChance float64 SpeculativeRetry string baseTableName string } type UserTypeMetadata struct { Keyspace string Name string FieldNames []string FieldTypes []TypeInfo } // the ordering of the column with regard to its comparator type ColumnOrder bool const ( ASC ColumnOrder = false DESC ColumnOrder = true ) type ColumnIndexMetadata struct { Name string Type string Options map[string]interface{} } type ColumnKind int const ( ColumnUnkownKind ColumnKind = iota ColumnPartitionKey ColumnClusteringKey ColumnRegular ColumnCompact ColumnStatic ) func (c ColumnKind) String() string { switch c { case ColumnPartitionKey: return "partition_key" case ColumnClusteringKey: return "clustering_key" case ColumnRegular: return "regular" case ColumnCompact: return "compact" case ColumnStatic: return "static" default: return fmt.Sprintf("unknown_column_%d", c) } } func (c *ColumnKind) UnmarshalCQL(typ TypeInfo, p []byte) error { if typ.Type() != TypeVarchar { return unmarshalErrorf("unable to marshall %s into ColumnKind, expected Varchar", typ) } kind, err := columnKindFromSchema(string(p)) if err != nil { return err } *c = kind return nil } func columnKindFromSchema(kind string) (ColumnKind, error) { switch kind { case "partition_key": return ColumnPartitionKey, nil case "clustering_key", "clustering": return ColumnClusteringKey, nil case "regular": return 
ColumnRegular, nil case "compact_value": return ColumnCompact, nil case "static": return ColumnStatic, nil default: return -1, fmt.Errorf("unknown column kind: %q", kind) } } // default alias values const ( DEFAULT_KEY_ALIAS = "key" DEFAULT_COLUMN_ALIAS = "column" DEFAULT_VALUE_ALIAS = "value" ) // queries the cluster for schema information for a specific keyspace type schemaDescriber struct { session *Session mu sync.Mutex cache map[string]*KeyspaceMetadata } // creates a session bound schema describer which will query and cache // keyspace metadata func newSchemaDescriber(session *Session) *schemaDescriber { return &schemaDescriber{ session: session, cache: map[string]*KeyspaceMetadata{}, } } // returns the cached KeyspaceMetadata held by the describer for the named // keyspace. func (s *schemaDescriber) getSchema(keyspaceName string) (*KeyspaceMetadata, error) { s.mu.Lock() defer s.mu.Unlock() metadata, found := s.cache[keyspaceName] if !found { // refresh the cache for this keyspace err := s.refreshSchema(keyspaceName) if err != nil { return nil, err } metadata = s.cache[keyspaceName] } return metadata, nil } // clears the already cached keyspace metadata func (s *schemaDescriber) clearSchema(keyspaceName string) { s.mu.Lock() defer s.mu.Unlock() delete(s.cache, keyspaceName) } // forcibly updates the current KeyspaceMetadata held by the schema describer // for a given named keyspace. 
func (s *schemaDescriber) refreshSchema(keyspaceName string) error { var err error // query the system keyspace for schema data // TODO retrieve concurrently keyspace, err := getKeyspaceMetadata(s.session, keyspaceName) if err != nil { return err } tables, err := getTableMetadata(s.session, keyspaceName) if err != nil { return err } columns, err := getColumnMetadata(s.session, keyspaceName) if err != nil { return err } functions, err := getFunctionsMetadata(s.session, keyspaceName) if err != nil { return err } aggregates, err := getAggregatesMetadata(s.session, keyspaceName) if err != nil { return err } views, err := getViewsMetadata(s.session, keyspaceName) if err != nil { return err } materializedViews, err := getMaterializedViewsMetadata(s.session, keyspaceName) if err != nil { return err } // organize the schema data compileMetadata(s.session.cfg.ProtoVersion, keyspace, tables, columns, functions, aggregates, views, materializedViews, s.session.logger) // update the cache s.cache[keyspaceName] = keyspace return nil } // "compiles" derived information about keyspace, table, and column metadata // for a keyspace from the basic queried metadata objects returned by // getKeyspaceMetadata, getTableMetadata, and getColumnMetadata respectively; // Links the metadata objects together and derives the column composition of // the partition key and clustering key for a table. 
func compileMetadata( protoVersion int, keyspace *KeyspaceMetadata, tables []TableMetadata, columns []ColumnMetadata, functions []FunctionMetadata, aggregates []AggregateMetadata, views []ViewMetadata, materializedViews []MaterializedViewMetadata, logger StdLogger, ) { keyspace.Tables = make(map[string]*TableMetadata) for i := range tables { tables[i].Columns = make(map[string]*ColumnMetadata) keyspace.Tables[tables[i].Name] = &tables[i] } keyspace.Functions = make(map[string]*FunctionMetadata, len(functions)) for i := range functions { keyspace.Functions[functions[i].Name] = &functions[i] } keyspace.Aggregates = make(map[string]*AggregateMetadata, len(aggregates)) for i, _ := range aggregates { aggregates[i].FinalFunc = *keyspace.Functions[aggregates[i].finalFunc] aggregates[i].StateFunc = *keyspace.Functions[aggregates[i].stateFunc] keyspace.Aggregates[aggregates[i].Name] = &aggregates[i] } keyspace.Views = make(map[string]*ViewMetadata, len(views)) for i := range views { keyspace.Views[views[i].Name] = &views[i] } // Views currently holds the types and hasn't been deleted for backward compatibility issues. // That's why it's ok to copy Views into Types in this case. For the real Views use MaterializedViews. 
types := make([]UserTypeMetadata, len(views)) for i := range views { types[i].Keyspace = views[i].Keyspace types[i].Name = views[i].Name types[i].FieldNames = views[i].FieldNames types[i].FieldTypes = views[i].FieldTypes } keyspace.UserTypes = make(map[string]*UserTypeMetadata, len(views)) for i := range types { keyspace.UserTypes[types[i].Name] = &types[i] } keyspace.MaterializedViews = make(map[string]*MaterializedViewMetadata, len(materializedViews)) for i, _ := range materializedViews { materializedViews[i].BaseTable = keyspace.Tables[materializedViews[i].baseTableName] keyspace.MaterializedViews[materializedViews[i].Name] = &materializedViews[i] } // add columns from the schema data for i := range columns { col := &columns[i] // decode the validator for TypeInfo and order if col.ClusteringOrder != "" { // Cassandra 3.x+ col.Type = getCassandraType(col.Validator, logger) col.Order = ASC if col.ClusteringOrder == "desc" { col.Order = DESC } } else { validatorParsed := parseType(col.Validator, logger) col.Type = validatorParsed.types[0] col.Order = ASC if validatorParsed.reversed[0] { col.Order = DESC } } table, ok := keyspace.Tables[col.Table] if !ok { // if the schema is being updated we will race between seeing // the metadata be complete. Potentially we should check for // schema versions before and after reading the metadata and // if they dont match try again. continue } table.Columns[col.Name] = col table.OrderedColumns = append(table.OrderedColumns, col.Name) } if protoVersion == protoVersion1 { compileV1Metadata(tables, logger) } else { compileV2Metadata(tables, logger) } } // Compiles derived information from TableMetadata which have had // ColumnMetadata added already. V1 protocol does not return as much // column metadata as V2+ (because V1 doesn't support the "type" column in the // system.schema_columns table) so determining PartitionKey and ClusterColumns // is more complex. 
func compileV1Metadata(tables []TableMetadata, logger StdLogger) { for i := range tables { table := &tables[i] // decode the key validator keyValidatorParsed := parseType(table.KeyValidator, logger) // decode the comparator comparatorParsed := parseType(table.Comparator, logger) // the partition key length is the same as the number of types in the // key validator table.PartitionKey = make([]*ColumnMetadata, len(keyValidatorParsed.types)) // V1 protocol only returns "regular" columns from // system.schema_columns (there is no type field for columns) // so the alias information is used to // create the partition key and clustering columns // construct the partition key from the alias for i := range table.PartitionKey { var alias string if len(table.KeyAliases) > i { alias = table.KeyAliases[i] } else if i == 0 { alias = DEFAULT_KEY_ALIAS } else { alias = DEFAULT_KEY_ALIAS + strconv.Itoa(i+1) } column := &ColumnMetadata{ Keyspace: table.Keyspace, Table: table.Name, Name: alias, Type: keyValidatorParsed.types[i], Kind: ColumnPartitionKey, ComponentIndex: i, } table.PartitionKey[i] = column table.Columns[alias] = column } // determine the number of clustering columns size := len(comparatorParsed.types) if comparatorParsed.isComposite { if len(comparatorParsed.collections) != 0 || (len(table.ColumnAliases) == size-1 && comparatorParsed.types[size-1].Type() == TypeVarchar) { size = size - 1 } } else { if !(len(table.ColumnAliases) != 0 || len(table.Columns) == 0) { size = 0 } } table.ClusteringColumns = make([]*ColumnMetadata, size) for i := range table.ClusteringColumns { var alias string if len(table.ColumnAliases) > i { alias = table.ColumnAliases[i] } else if i == 0 { alias = DEFAULT_COLUMN_ALIAS } else { alias = DEFAULT_COLUMN_ALIAS + strconv.Itoa(i+1) } order := ASC if comparatorParsed.reversed[i] { order = DESC } column := &ColumnMetadata{ Keyspace: table.Keyspace, Table: table.Name, Name: alias, Type: comparatorParsed.types[i], Order: order, Kind: 
ColumnClusteringKey, ComponentIndex: i, } table.ClusteringColumns[i] = column table.Columns[alias] = column } if size != len(comparatorParsed.types)-1 { alias := DEFAULT_VALUE_ALIAS if len(table.ValueAlias) > 0 { alias = table.ValueAlias } // decode the default validator defaultValidatorParsed := parseType(table.DefaultValidator, logger) column := &ColumnMetadata{ Keyspace: table.Keyspace, Table: table.Name, Name: alias, Type: defaultValidatorParsed.types[0], Kind: ColumnRegular, } table.Columns[alias] = column } } } // The simpler compile case for V2+ protocol func compileV2Metadata(tables []TableMetadata, logger StdLogger) { for i := range tables { table := &tables[i] clusteringColumnCount := componentColumnCountOfType(table.Columns, ColumnClusteringKey) table.ClusteringColumns = make([]*ColumnMetadata, clusteringColumnCount) if table.KeyValidator != "" { keyValidatorParsed := parseType(table.KeyValidator, logger) table.PartitionKey = make([]*ColumnMetadata, len(keyValidatorParsed.types)) } else { // Cassandra 3.x+ partitionKeyCount := componentColumnCountOfType(table.Columns, ColumnPartitionKey) table.PartitionKey = make([]*ColumnMetadata, partitionKeyCount) } for _, columnName := range table.OrderedColumns { column := table.Columns[columnName] if column.Kind == ColumnPartitionKey { table.PartitionKey[column.ComponentIndex] = column } else if column.Kind == ColumnClusteringKey { table.ClusteringColumns[column.ComponentIndex] = column } } } } // returns the count of coluns with the given "kind" value. 
func componentColumnCountOfType(columns map[string]*ColumnMetadata, kind ColumnKind) int { maxComponentIndex := -1 for _, column := range columns { if column.Kind == kind && column.ComponentIndex > maxComponentIndex { maxComponentIndex = column.ComponentIndex } } return maxComponentIndex + 1 } // query only for the keyspace metadata for the specified keyspace from system.schema_keyspace func getKeyspaceMetadata(session *Session, keyspaceName string) (*KeyspaceMetadata, error) { keyspace := &KeyspaceMetadata{Name: keyspaceName} if session.useSystemSchema { // Cassandra 3.x+ const stmt = ` SELECT durable_writes, replication FROM system_schema.keyspaces WHERE keyspace_name = ?` var replication map[string]string iter := session.control.query(stmt, keyspaceName) if iter.NumRows() == 0 { return nil, ErrKeyspaceDoesNotExist } iter.Scan(&keyspace.DurableWrites, &replication) err := iter.Close() if err != nil { return nil, fmt.Errorf("error querying keyspace schema: %v", err) } keyspace.StrategyClass = replication["class"] delete(replication, "class") keyspace.StrategyOptions = make(map[string]interface{}, len(replication)) for k, v := range replication { keyspace.StrategyOptions[k] = v } } else { const stmt = ` SELECT durable_writes, strategy_class, strategy_options FROM system.schema_keyspaces WHERE keyspace_name = ?` var strategyOptionsJSON []byte iter := session.control.query(stmt, keyspaceName) if iter.NumRows() == 0 { return nil, ErrKeyspaceDoesNotExist } iter.Scan(&keyspace.DurableWrites, &keyspace.StrategyClass, &strategyOptionsJSON) err := iter.Close() if err != nil { return nil, fmt.Errorf("error querying keyspace schema: %v", err) } err = json.Unmarshal(strategyOptionsJSON, &keyspace.StrategyOptions) if err != nil { return nil, fmt.Errorf( "invalid JSON value '%s' as strategy_options for in keyspace '%s': %v", strategyOptionsJSON, keyspace.Name, err, ) } } return keyspace, nil } // query for only the table metadata in the specified keyspace from 
// system.schema_columnfamilies
func getTableMetadata(session *Session, keyspaceName string) ([]TableMetadata, error) {
	var (
		iter *Iter
		scan func(iter *Iter, table *TableMetadata) bool
		stmt string

		keyAliasesJSON    []byte
		columnAliasesJSON []byte
	)

	if session.useSystemSchema { // Cassandra 3.x+
		stmt = `
		SELECT
			table_name
		FROM system_schema.tables
		WHERE keyspace_name = ?`

		// once the tables iterator is exhausted, switch to querying
		// system_schema.views so view names are also returned as tables;
		// note switchIter reassigns the outer iter so the caller's scan
		// loop continues on the new iterator
		switchIter := func() *Iter {
			iter.Close()
			stmt = `
				SELECT view_name
				FROM system_schema.views
				WHERE keyspace_name = ?`
			iter = session.control.query(stmt, keyspaceName)
			return iter
		}

		scan = func(iter *Iter, table *TableMetadata) bool {
			r := iter.Scan(
				&table.Name,
			)
			if !r {
				iter = switchIter()
				if iter != nil {
					// replace switchIter so the views query runs only once
					switchIter = func() *Iter { return nil }
					r = iter.Scan(&table.Name)
				}
			}
			return r
		}
	} else if session.cfg.ProtoVersion == protoVersion1 {
		// we have key aliases
		stmt = `
		SELECT
			columnfamily_name,
			key_validator,
			comparator,
			default_validator,
			key_aliases,
			column_aliases,
			value_alias
		FROM system.schema_columnfamilies
		WHERE keyspace_name = ?`

		scan = func(iter *Iter, table *TableMetadata) bool {
			return iter.Scan(
				&table.Name,
				&table.KeyValidator,
				&table.Comparator,
				&table.DefaultValidator,
				&keyAliasesJSON,
				&columnAliasesJSON,
				&table.ValueAlias,
			)
		}
	} else {
		stmt = `
		SELECT
			columnfamily_name,
			key_validator,
			comparator,
			default_validator
		FROM system.schema_columnfamilies
		WHERE keyspace_name = ?`

		scan = func(iter *Iter, table *TableMetadata) bool {
			return iter.Scan(
				&table.Name,
				&table.KeyValidator,
				&table.Comparator,
				&table.DefaultValidator,
			)
		}
	}

	iter = session.control.query(stmt, keyspaceName)

	tables := []TableMetadata{}
	table := TableMetadata{Keyspace: keyspaceName}

	for scan(iter, &table) {
		var err error

		// decode the key aliases
		if keyAliasesJSON != nil {
			table.KeyAliases = []string{}
			err = json.Unmarshal(keyAliasesJSON, &table.KeyAliases)
			if err != nil {
				iter.Close()
				return nil, fmt.Errorf(
					"invalid JSON value '%s' as key_aliases for in table '%s': %v",
					keyAliasesJSON, table.Name, err,
				)
			}
		}

		// decode the column aliases
		if columnAliasesJSON != nil {
			table.ColumnAliases = []string{}
			err = json.Unmarshal(columnAliasesJSON, &table.ColumnAliases)
			if err != nil {
				iter.Close()
				return nil, fmt.Errorf(
					"invalid JSON value '%s' as column_aliases for in table '%s': %v",
					columnAliasesJSON, table.Name, err,
				)
			}
		}

		tables = append(tables, table)
		table = TableMetadata{Keyspace: keyspaceName}
	}

	err := iter.Close()
	if err != nil && err != ErrNotFound {
		return nil, fmt.Errorf("error querying table schema: %v", err)
	}

	return tables, nil
}

// scanColumnMetadataV1 reads column rows from the legacy
// system.schema_columns table for protocol V1 clusters.
func (s *Session) scanColumnMetadataV1(keyspace string) ([]ColumnMetadata, error) {
	// V1 does not support the type column, and all returned rows are
	// of kind "regular".
	const stmt = `
		SELECT
			columnfamily_name,
			column_name,
			component_index,
			validator,
			index_name,
			index_type,
			index_options
		FROM system.schema_columns
		WHERE keyspace_name = ?`

	var columns []ColumnMetadata

	rows := s.control.query(stmt, keyspace).Scanner()
	for rows.Next() {
		var (
			column           = ColumnMetadata{Keyspace: keyspace}
			indexOptionsJSON []byte
		)

		// all columns returned by V1 are regular
		column.Kind = ColumnRegular

		err := rows.Scan(&column.Table,
			&column.Name,
			&column.ComponentIndex,
			&column.Validator,
			&column.Index.Name,
			&column.Index.Type,
			&indexOptionsJSON)
		if err != nil {
			return nil, err
		}

		if len(indexOptionsJSON) > 0 {
			err := json.Unmarshal(indexOptionsJSON, &column.Index.Options)
			if err != nil {
				return nil, fmt.Errorf(
					"invalid JSON value '%s' as index_options for column '%s' in table '%s': %v",
					indexOptionsJSON,
					column.Name,
					column.Table,
					err)
			}
		}

		columns = append(columns, column)
	}

	if err := rows.Err(); err != nil {
		return nil, err
	}

	return columns, nil
}

// scanColumnMetadataV2 reads column rows from system.schema_columns for
// protocol V2+ clusters that still use the legacy schema tables.
func (s *Session) scanColumnMetadataV2(keyspace string) ([]ColumnMetadata, error) {
	// V2+ supports the type column
	const stmt = `
		SELECT
			columnfamily_name,
			column_name,
			component_index,
			validator,
			index_name,
			index_type,
			index_options,
			type
		FROM system.schema_columns
		WHERE keyspace_name = ?`

	var columns []ColumnMetadata

	rows := s.control.query(stmt, keyspace).Scanner()
	for rows.Next() {
		var (
			column           = ColumnMetadata{Keyspace: keyspace}
			indexOptionsJSON []byte
		)

		err := rows.Scan(&column.Table,
			&column.Name,
			&column.ComponentIndex,
			&column.Validator,
			&column.Index.Name,
			&column.Index.Type,
			&indexOptionsJSON,
			&column.Kind,
		)
		if err != nil {
			return nil, err
		}

		if len(indexOptionsJSON) > 0 {
			err := json.Unmarshal(indexOptionsJSON, &column.Index.Options)
			if err != nil {
				return nil, fmt.Errorf(
					"invalid JSON value '%s' as index_options for column '%s' in table '%s': %v",
					indexOptionsJSON,
					column.Name,
					column.Table,
					err)
			}
		}

		columns = append(columns, column)
	}

	if err := rows.Err(); err != nil {
		return nil, err
	}

	return columns, nil
}

// scanColumnMetadataSystem reads column rows from system_schema.columns on
// Cassandra 3.x+ clusters; the raw type string is stored in Validator.
func (s *Session) scanColumnMetadataSystem(keyspace string) ([]ColumnMetadata, error) {
	const stmt = `
		SELECT
			table_name,
			column_name,
			clustering_order,
			type,
			kind,
			position
		FROM system_schema.columns
		WHERE keyspace_name = ?`

	var columns []ColumnMetadata

	rows := s.control.query(stmt, keyspace).Scanner()
	for rows.Next() {
		column := ColumnMetadata{Keyspace: keyspace}

		err := rows.Scan(&column.Table,
			&column.Name,
			&column.ClusteringOrder,
			&column.Validator,
			&column.Kind,
			&column.ComponentIndex,
		)
		if err != nil {
			return nil, err
		}

		columns = append(columns, column)
	}

	if err := rows.Err(); err != nil {
		return nil, err
	}

	// TODO(zariel): get column index info from system_schema.indexes
	return columns, nil
}

// query for only the column metadata in the specified keyspace from system.schema_columns
func getColumnMetadata(session *Session, keyspaceName string) ([]ColumnMetadata, error) {
	var (
		columns []ColumnMetadata
		err     error
	)

	// Deal with differences in protocol versions
	if session.cfg.ProtoVersion == 1 {
		columns, err = session.scanColumnMetadataV1(keyspaceName)
	} else if session.useSystemSchema { // Cassandra 3.x+
		columns, err = session.scanColumnMetadataSystem(keyspaceName)
	} else {
		columns, err = session.scanColumnMetadataV2(keyspaceName)
	}

	if err != nil && err != ErrNotFound {
		return nil, fmt.Errorf("error querying column schema: %v", err)
	}

	return columns, nil
}

// getTypeInfo resolves a type string — either a fully-qualified
// org.apache.cassandra.db.marshal class name or a CQL type name — to a
// TypeInfo.
func getTypeInfo(t string, logger StdLogger) TypeInfo {
	if strings.HasPrefix(t, apacheCassandraTypePrefix) {
		t = apacheToCassandraType(t)
	}
	return getCassandraType(t, logger)
}

// getViewsMetadata queries system_schema.types (or system.schema_usertypes on
// older clusters) for the keyspace.
// NOTE(review): despite the name, the queried tables and scanned fields
// (type_name, field_names, field_types) show this reads user-defined type
// definitions, not views — confirm intent against callers.
func getViewsMetadata(session *Session, keyspaceName string) ([]ViewMetadata, error) {
	if session.cfg.ProtoVersion == protoVersion1 {
		return nil, nil
	}
	var tableName string
	if session.useSystemSchema {
		tableName = "system_schema.types"
	} else {
		tableName = "system.schema_usertypes"
	}
	stmt := fmt.Sprintf(`
		SELECT
			type_name,
			field_names,
			field_types
		FROM %s
		WHERE keyspace_name = ?`, tableName)

	var views []ViewMetadata

	rows := session.control.query(stmt, keyspaceName).Scanner()
	for rows.Next() {
		view := ViewMetadata{Keyspace: keyspaceName}
		var argumentTypes []string
		err := rows.Scan(&view.Name,
			&view.FieldNames,
			&argumentTypes,
		)
		if err != nil {
			return nil, err
		}
		view.FieldTypes = make([]TypeInfo, len(argumentTypes))
		for i, argumentType := range argumentTypes {
			view.FieldTypes[i] = getTypeInfo(argumentType, session.logger)
		}
		views = append(views, view)
	}

	if err := rows.Err(); err != nil {
		return nil, err
	}

	return views, nil
}

// getMaterializedViewsMetadata queries system_schema.views for the keyspace;
// it returns nil on clusters without the system_schema keyspace.
func getMaterializedViewsMetadata(session *Session, keyspaceName string) ([]MaterializedViewMetadata, error) {
	if !session.useSystemSchema {
		return nil, nil
	}
	var tableName = "system_schema.views"
	stmt := fmt.Sprintf(`
		SELECT
			view_name,
			base_table_id,
			base_table_name,
			bloom_filter_fp_chance,
			caching,
			comment,
			compaction,
			compression,
			crc_check_chance,
			dclocal_read_repair_chance,
			default_time_to_live,
			extensions,
			gc_grace_seconds,
			id,
			include_all_columns,
			max_index_interval,
			memtable_flush_period_in_ms,
			min_index_interval,
			read_repair_chance,
			speculative_retry
		FROM %s
		WHERE keyspace_name = ?`, tableName)

	var materializedViews []MaterializedViewMetadata

	rows := session.control.query(stmt, keyspaceName).Scanner()
	for rows.Next() {
		materializedView := MaterializedViewMetadata{Keyspace: keyspaceName}
		err := rows.Scan(&materializedView.Name,
			&materializedView.BaseTableId,
			&materializedView.baseTableName,
			&materializedView.BloomFilterFpChance,
			&materializedView.Caching,
			&materializedView.Comment,
			&materializedView.Compaction,
			&materializedView.Compression,
			&materializedView.CrcCheckChance,
			&materializedView.DcLocalReadRepairChance,
			&materializedView.DefaultTimeToLive,
			&materializedView.Extensions,
			&materializedView.GcGraceSeconds,
			&materializedView.Id,
			&materializedView.IncludeAllColumns,
			&materializedView.MaxIndexInterval,
			&materializedView.MemtableFlushPeriodInMs,
			&materializedView.MinIndexInterval,
			&materializedView.ReadRepairChance,
			&materializedView.SpeculativeRetry,
		)
		if err != nil {
			return nil, err
		}
		materializedViews = append(materializedViews, materializedView)
	}

	if err := rows.Err(); err != nil {
		return nil, err
	}

	return materializedViews, nil
}

// getFunctionsMetadata queries user-defined function metadata for the
// keyspace; returns nil on clusters that predate UDF support.
func getFunctionsMetadata(session *Session, keyspaceName string) ([]FunctionMetadata, error) {
	if session.cfg.ProtoVersion == protoVersion1 || !session.hasAggregatesAndFunctions {
		return nil, nil
	}
	var tableName string
	if session.useSystemSchema {
		tableName = "system_schema.functions"
	} else {
		tableName = "system.schema_functions"
	}
	stmt := fmt.Sprintf(`
		SELECT
			function_name,
			argument_types,
			argument_names,
			body,
			called_on_null_input,
			language,
			return_type
		FROM %s
		WHERE keyspace_name = ?`, tableName)

	var functions []FunctionMetadata

	rows := session.control.query(stmt, keyspaceName).Scanner()
	for rows.Next() {
		function := FunctionMetadata{Keyspace: keyspaceName}
		var argumentTypes []string
		var returnType string
		err := rows.Scan(&function.Name,
			&argumentTypes,
			&function.ArgumentNames,
			&function.Body,
			&function.CalledOnNullInput,
			&function.Language,
			&returnType,
		)
		if err != nil {
			return nil, err
		}
		function.ReturnType = getTypeInfo(returnType, session.logger)
		function.ArgumentTypes = make([]TypeInfo, len(argumentTypes))
		for i, argumentType := range argumentTypes {
			function.ArgumentTypes[i] = getTypeInfo(argumentType, session.logger)
		}
		functions = append(functions, function)
	}

	if err := rows.Err(); err != nil {
		return nil, err
	}

	return functions, nil
}

// getAggregatesMetadata queries user-defined aggregate metadata for the
// keyspace; returns nil on clusters that predate UDA support.
func getAggregatesMetadata(session *Session, keyspaceName string) ([]AggregateMetadata, error) {
	if session.cfg.ProtoVersion == protoVersion1 || !session.hasAggregatesAndFunctions {
		return nil, nil
	}
	var tableName string
	if session.useSystemSchema {
		tableName = "system_schema.aggregates"
	} else {
		tableName = "system.schema_aggregates"
	}
	stmt := fmt.Sprintf(`
		SELECT
			aggregate_name,
			argument_types,
			final_func,
			initcond,
			return_type,
			state_func,
			state_type
		FROM %s
		WHERE keyspace_name = ?`, tableName)

	var aggregates []AggregateMetadata

	rows := session.control.query(stmt, keyspaceName).Scanner()
	for rows.Next() {
		aggregate := AggregateMetadata{Keyspace: keyspaceName}
		var argumentTypes []string
		var returnType string
		var stateType string
		err := rows.Scan(&aggregate.Name,
			&argumentTypes,
			&aggregate.finalFunc,
			&aggregate.InitCond,
			&returnType,
			&aggregate.stateFunc,
			&stateType,
		)
		if err != nil {
			return nil, err
		}
		aggregate.ReturnType = getTypeInfo(returnType, session.logger)
		aggregate.StateType = getTypeInfo(stateType, session.logger)
		aggregate.ArgumentTypes = make([]TypeInfo, len(argumentTypes))
		for i, argumentType := range argumentTypes {
			aggregate.ArgumentTypes[i] = getTypeInfo(argumentType, session.logger)
		}
		aggregates = append(aggregates, aggregate)
	}

	if err := rows.Err(); err != nil {
		return nil, err
	}

	return aggregates, nil
}

// type definition parser state
type typeParser struct {
	input  string // the raw validator/comparator definition being parsed
	index  int    // current scan position within input
	logger StdLogger
}

// the type definition parser result
type typeParserResult struct {
	isComposite bool
	types       []TypeInfo
	reversed    []bool
	collections map[string]TypeInfo
}

// Parse the type definition used for validator and comparator schema data
func parseType(def string, logger StdLogger) typeParserResult {
	parser := &typeParser{input: def, logger: logger}
	return parser.parse()
}
// Fully-qualified marshal class-name prefixes used to recognize reversed,
// composite and collection type wrappers in validator/comparator strings.
const (
	REVERSED_TYPE   = "org.apache.cassandra.db.marshal.ReversedType"
	COMPOSITE_TYPE  = "org.apache.cassandra.db.marshal.CompositeType"
	COLLECTION_TYPE = "org.apache.cassandra.db.marshal.ColumnToCollectionType"
	LIST_TYPE       = "org.apache.cassandra.db.marshal.ListType"
	SET_TYPE        = "org.apache.cassandra.db.marshal.SetType"
	MAP_TYPE        = "org.apache.cassandra.db.marshal.MapType"
)

// represents a class specification in the type def AST
type typeParserClassNode struct {
	name   string
	params []typeParserParamNode
	// this is the segment of the input string that defined this node
	input string
}

// represents a class parameter in the type def AST
type typeParserParamNode struct {
	name  *string
	class typeParserClassNode
}

// parse builds the AST for the whole input and interprets it into a
// typeParserResult; unparseable input is treated as a single custom type.
func (t *typeParser) parse() typeParserResult {
	// parse the AST
	ast, ok := t.parseClassNode()
	if !ok {
		// treat this as a custom type
		return typeParserResult{
			isComposite: false,
			types: []TypeInfo{
				NativeType{
					typ:    TypeCustom,
					custom: t.input,
				},
			},
			reversed:    []bool{false},
			collections: nil,
		}
	}

	// interpret the AST
	if strings.HasPrefix(ast.name, COMPOSITE_TYPE) {
		count := len(ast.params)

		// look for a collections param
		last := ast.params[count-1]
		collections := map[string]TypeInfo{}
		if strings.HasPrefix(last.class.name, COLLECTION_TYPE) {
			count--

			for _, param := range last.class.params {
				// decode the name (hex encoded UTF-8)
				var name string
				decoded, err := hex.DecodeString(*param.name)
				if err != nil {
					t.logger.Printf(
						"Error parsing type '%s', contains collection name '%s' with an invalid format: %v",
						t.input,
						*param.name,
						err,
					)
					// just use the provided name
					name = *param.name
				} else {
					name = string(decoded)
				}
				collections[name] = param.class.asTypeInfo()
			}
		}

		types := make([]TypeInfo, count)
		reversed := make([]bool, count)

		for i, param := range ast.params[:count] {
			class := param.class
			// unwrap ReversedType, recording the ordering
			reversed[i] = strings.HasPrefix(class.name, REVERSED_TYPE)
			if reversed[i] {
				class = class.params[0].class
			}
			types[i] = class.asTypeInfo()
		}

		return typeParserResult{
			isComposite: true,
			types:       types,
			reversed:    reversed,
			collections: collections,
		}
	} else {
		// not composite, so one type
		class := *ast
		reversed := strings.HasPrefix(class.name, REVERSED_TYPE)
		if reversed {
			class = class.params[0].class
		}
		typeInfo := class.asTypeInfo()
		return typeParserResult{
			isComposite: false,
			types:       []TypeInfo{typeInfo},
			reversed:    []bool{reversed},
		}
	}
}

// asTypeInfo converts a parsed class node into a TypeInfo, handling list,
// set and map wrappers; anything unrecognized becomes a custom type carrying
// the full class definition.
func (class *typeParserClassNode) asTypeInfo() TypeInfo {
	if strings.HasPrefix(class.name, LIST_TYPE) {
		elem := class.params[0].class.asTypeInfo()
		return CollectionType{
			NativeType: NativeType{
				typ: TypeList,
			},
			Elem: elem,
		}
	}
	if strings.HasPrefix(class.name, SET_TYPE) {
		elem := class.params[0].class.asTypeInfo()
		return CollectionType{
			NativeType: NativeType{
				typ: TypeSet,
			},
			Elem: elem,
		}
	}
	if strings.HasPrefix(class.name, MAP_TYPE) {
		key := class.params[0].class.asTypeInfo()
		elem := class.params[1].class.asTypeInfo()
		return CollectionType{
			NativeType: NativeType{
				typ: TypeMap,
			},
			Key:  key,
			Elem: elem,
		}
	}

	// must be a simple type or custom type
	info := NativeType{typ: getApacheCassandraType(class.name)}
	if info.typ == TypeCustom {
		// add the entire class definition
		info.custom = class.input
	}
	return info
}

// CLASS := ID [ PARAMS ]
func (t *typeParser) parseClassNode() (node *typeParserClassNode, ok bool) {
	t.skipWhitespace()

	startIndex := t.index

	name, ok := t.nextIdentifier()
	if !ok {
		return nil, false
	}

	params, ok := t.parseParamNodes()
	if !ok {
		return nil, false
	}

	endIndex := t.index

	node = &typeParserClassNode{
		name:   name,
		params: params,
		// keep the exact source segment for custom-type reporting
		input: t.input[startIndex:endIndex],
	}
	return node, true
}

// PARAMS := "(" PARAM { "," PARAM } ")"
// PARAM := [ PARAM_NAME ":" ] CLASS
// PARAM_NAME := ID
func (t *typeParser) parseParamNodes() (params []typeParserParamNode, ok bool) {
	t.skipWhitespace()

	// the params are optional
	if t.index == len(t.input) || t.input[t.index] != '(' {
		return nil, true
	}

	params = []typeParserParamNode{}

	// consume the '('
	t.index++

	t.skipWhitespace()

	for t.input[t.index] != ')' {
		// look for a named param, but if no colon, then we want to backup
		backupIndex := t.index

		// name will be a hex encoded version of a utf-8 string
		name, ok := t.nextIdentifier()
		if !ok {
			return nil, false
		}
		hasName := true

		// TODO handle '=>' used for DynamicCompositeType

		t.skipWhitespace()

		if t.input[t.index] == ':' {
			// there is a name for this parameter

			// consume the ':'
			t.index++

			t.skipWhitespace()
		} else {
			// no name, backup
			hasName = false
			t.index = backupIndex
		}

		// parse the next full parameter
		classNode, ok := t.parseClassNode()
		if !ok {
			return nil, false
		}

		if hasName {
			params = append(
				params,
				typeParserParamNode{name: &name, class: *classNode},
			)
		} else {
			params = append(
				params,
				typeParserParamNode{class: *classNode},
			)
		}

		t.skipWhitespace()

		if t.input[t.index] == ',' {
			// consume the comma
			t.index++
			t.skipWhitespace()
		}
	}

	// consume the ')'
	t.index++

	return params, true
}

// skipWhitespace advances the index past spaces, newlines and tabs.
func (t *typeParser) skipWhitespace() {
	for t.index < len(t.input) && isWhitespaceChar(t.input[t.index]) {
		t.index++
	}
}

func isWhitespaceChar(c byte) bool {
	return c == ' ' || c == '\n' || c == '\t'
}

// ID := LETTER { LETTER }
// LETTER := "0"..."9" | "a"..."z" | "A"..."Z" | "-" | "+" | "." | "_" | "&"
func (t *typeParser) nextIdentifier() (id string, found bool) {
	startIndex := t.index
	for t.index < len(t.input) && isIdentifierChar(t.input[t.index]) {
		t.index++
	}
	if startIndex == t.index {
		return "", false
	}
	return t.input[startIndex:t.index], true
}

func isIdentifierChar(c byte) bool {
	return (c >= '0' && c <= '9') ||
		(c >= 'a' && c <= 'z') ||
		(c >= 'A' && c <= 'Z') ||
		c == '-' ||
		c == '+' ||
		c == '.' ||
		c == '_' ||
		c == '&'
}
cassandra-gocql-driver-1.7.0/metadata_test.go000066400000000000000000000741521467504044300212400ustar00rootroot00000000000000// Copyright (c) 2015 The gocql Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. */ package gocql import ( "strconv" "testing" ) // Tests V1 and V2 metadata "compilation" from example data which might be returned // from metadata schema queries (see getKeyspaceMetadata, getTableMetadata, and getColumnMetadata) func TestCompileMetadata(t *testing.T) { // V1 tests - these are all based on real examples from the integration test ccm cluster log := &defaultLogger{} keyspace := &KeyspaceMetadata{ Name: "V1Keyspace", } tables := []TableMetadata{ { // This table, found in the system keyspace, has no key aliases or column aliases Keyspace: "V1Keyspace", Name: "Schema", KeyValidator: "org.apache.cassandra.db.marshal.BytesType", Comparator: "org.apache.cassandra.db.marshal.UTF8Type", DefaultValidator: "org.apache.cassandra.db.marshal.BytesType", KeyAliases: []string{}, ColumnAliases: []string{}, ValueAlias: "", }, { // This table, found in the system keyspace, has key aliases, column aliases, and a value alias. 
Keyspace: "V1Keyspace", Name: "hints", KeyValidator: "org.apache.cassandra.db.marshal.UUIDType", Comparator: "org.apache.cassandra.db.marshal.CompositeType(org.apache.cassandra.db.marshal.TimeUUIDType,org.apache.cassandra.db.marshal.Int32Type)", DefaultValidator: "org.apache.cassandra.db.marshal.BytesType", KeyAliases: []string{"target_id"}, ColumnAliases: []string{"hint_id", "message_version"}, ValueAlias: "mutation", }, { // This table, found in the system keyspace, has a comparator with collections, but no column aliases Keyspace: "V1Keyspace", Name: "peers", KeyValidator: "org.apache.cassandra.db.marshal.InetAddressType", Comparator: "org.apache.cassandra.db.marshal.CompositeType(org.apache.cassandra.db.marshal.UTF8Type,org.apache.cassandra.db.marshal.ColumnToCollectionType(746f6b656e73:org.apache.cassandra.db.marshal.SetType(org.apache.cassandra.db.marshal.UTF8Type)))", DefaultValidator: "org.apache.cassandra.db.marshal.BytesType", KeyAliases: []string{"peer"}, ColumnAliases: []string{}, ValueAlias: "", }, { // This table, found in the system keyspace, has a column alias, but not a composite comparator Keyspace: "V1Keyspace", Name: "IndexInfo", KeyValidator: "org.apache.cassandra.db.marshal.UTF8Type", Comparator: "org.apache.cassandra.db.marshal.ReversedType(org.apache.cassandra.db.marshal.UTF8Type)", DefaultValidator: "org.apache.cassandra.db.marshal.BytesType", KeyAliases: []string{"table_name"}, ColumnAliases: []string{"index_name"}, ValueAlias: "", }, { // This table, found in the gocql_test keyspace following an integration test run, has a composite comparator with collections as well as a column alias Keyspace: "V1Keyspace", Name: "wiki_page", KeyValidator: "org.apache.cassandra.db.marshal.UTF8Type", Comparator: 
"org.apache.cassandra.db.marshal.CompositeType(org.apache.cassandra.db.marshal.TimeUUIDType,org.apache.cassandra.db.marshal.UTF8Type,org.apache.cassandra.db.marshal.ColumnToCollectionType(74616773:org.apache.cassandra.db.marshal.SetType(org.apache.cassandra.db.marshal.UTF8Type),6174746163686d656e7473:org.apache.cassandra.db.marshal.MapType(org.apache.cassandra.db.marshal.UTF8Type,org.apache.cassandra.db.marshal.BytesType)))", DefaultValidator: "org.apache.cassandra.db.marshal.BytesType", KeyAliases: []string{"title"}, ColumnAliases: []string{"revid"}, ValueAlias: "", }, { // This is a made up example with multiple unnamed aliases Keyspace: "V1Keyspace", Name: "no_names", KeyValidator: "org.apache.cassandra.db.marshal.CompositeType(org.apache.cassandra.db.marshal.UUIDType,org.apache.cassandra.db.marshal.UUIDType)", Comparator: "org.apache.cassandra.db.marshal.CompositeType(org.apache.cassandra.db.marshal.Int32Type,org.apache.cassandra.db.marshal.Int32Type,org.apache.cassandra.db.marshal.Int32Type)", DefaultValidator: "org.apache.cassandra.db.marshal.BytesType", KeyAliases: []string{}, ColumnAliases: []string{}, ValueAlias: "", }, } columns := []ColumnMetadata{ // Here are the regular columns from the peers table for testing regular columns {Keyspace: "V1Keyspace", Table: "peers", Kind: ColumnRegular, Name: "data_center", ComponentIndex: 0, Validator: "org.apache.cassandra.db.marshal.UTF8Type"}, {Keyspace: "V1Keyspace", Table: "peers", Kind: ColumnRegular, Name: "host_id", ComponentIndex: 0, Validator: "org.apache.cassandra.db.marshal.UUIDType"}, {Keyspace: "V1Keyspace", Table: "peers", Kind: ColumnRegular, Name: "rack", ComponentIndex: 0, Validator: "org.apache.cassandra.db.marshal.UTF8Type"}, {Keyspace: "V1Keyspace", Table: "peers", Kind: ColumnRegular, Name: "release_version", ComponentIndex: 0, Validator: "org.apache.cassandra.db.marshal.UTF8Type"}, {Keyspace: "V1Keyspace", Table: "peers", Kind: ColumnRegular, Name: "rpc_address", ComponentIndex: 0, Validator: 
"org.apache.cassandra.db.marshal.InetAddressType"}, {Keyspace: "V1Keyspace", Table: "peers", Kind: ColumnRegular, Name: "schema_version", ComponentIndex: 0, Validator: "org.apache.cassandra.db.marshal.UUIDType"}, {Keyspace: "V1Keyspace", Table: "peers", Kind: ColumnRegular, Name: "tokens", ComponentIndex: 0, Validator: "org.apache.cassandra.db.marshal.SetType(org.apache.cassandra.db.marshal.UTF8Type)"}, } compileMetadata(1, keyspace, tables, columns, nil, nil, nil, nil, log) assertKeyspaceMetadata( t, keyspace, &KeyspaceMetadata{ Name: "V1Keyspace", Tables: map[string]*TableMetadata{ "Schema": { PartitionKey: []*ColumnMetadata{ { Name: "key", Type: NativeType{typ: TypeBlob}, }, }, ClusteringColumns: []*ColumnMetadata{}, Columns: map[string]*ColumnMetadata{ "key": { Name: "key", Type: NativeType{typ: TypeBlob}, Kind: ColumnPartitionKey, }, }, }, "hints": { PartitionKey: []*ColumnMetadata{ { Name: "target_id", Type: NativeType{typ: TypeUUID}, }, }, ClusteringColumns: []*ColumnMetadata{ { Name: "hint_id", Type: NativeType{typ: TypeTimeUUID}, Order: ASC, }, { Name: "message_version", Type: NativeType{typ: TypeInt}, Order: ASC, }, }, Columns: map[string]*ColumnMetadata{ "target_id": { Name: "target_id", Type: NativeType{typ: TypeUUID}, Kind: ColumnPartitionKey, }, "hint_id": { Name: "hint_id", Type: NativeType{typ: TypeTimeUUID}, Order: ASC, Kind: ColumnClusteringKey, }, "message_version": { Name: "message_version", Type: NativeType{typ: TypeInt}, Order: ASC, Kind: ColumnClusteringKey, }, "mutation": { Name: "mutation", Type: NativeType{typ: TypeBlob}, Kind: ColumnRegular, }, }, }, "peers": { PartitionKey: []*ColumnMetadata{ { Name: "peer", Type: NativeType{typ: TypeInet}, }, }, ClusteringColumns: []*ColumnMetadata{}, Columns: map[string]*ColumnMetadata{ "peer": { Name: "peer", Type: NativeType{typ: TypeInet}, Kind: ColumnPartitionKey, }, "data_center": {Keyspace: "V1Keyspace", Table: "peers", Kind: ColumnRegular, Name: "data_center", ComponentIndex: 0, Validator: 
"org.apache.cassandra.db.marshal.UTF8Type", Type: NativeType{typ: TypeVarchar}}, "host_id": {Keyspace: "V1Keyspace", Table: "peers", Kind: ColumnRegular, Name: "host_id", ComponentIndex: 0, Validator: "org.apache.cassandra.db.marshal.UUIDType", Type: NativeType{typ: TypeUUID}}, "rack": {Keyspace: "V1Keyspace", Table: "peers", Kind: ColumnRegular, Name: "rack", ComponentIndex: 0, Validator: "org.apache.cassandra.db.marshal.UTF8Type", Type: NativeType{typ: TypeVarchar}}, "release_version": {Keyspace: "V1Keyspace", Table: "peers", Kind: ColumnRegular, Name: "release_version", ComponentIndex: 0, Validator: "org.apache.cassandra.db.marshal.UTF8Type", Type: NativeType{typ: TypeVarchar}}, "rpc_address": {Keyspace: "V1Keyspace", Table: "peers", Kind: ColumnRegular, Name: "rpc_address", ComponentIndex: 0, Validator: "org.apache.cassandra.db.marshal.InetAddressType", Type: NativeType{typ: TypeInet}}, "schema_version": {Keyspace: "V1Keyspace", Table: "peers", Kind: ColumnRegular, Name: "schema_version", ComponentIndex: 0, Validator: "org.apache.cassandra.db.marshal.UUIDType", Type: NativeType{typ: TypeUUID}}, "tokens": {Keyspace: "V1Keyspace", Table: "peers", Kind: ColumnRegular, Name: "tokens", ComponentIndex: 0, Validator: "org.apache.cassandra.db.marshal.SetType(org.apache.cassandra.db.marshal.UTF8Type)", Type: CollectionType{NativeType: NativeType{typ: TypeSet}}}, }, }, "IndexInfo": { PartitionKey: []*ColumnMetadata{ { Name: "table_name", Type: NativeType{typ: TypeVarchar}, }, }, ClusteringColumns: []*ColumnMetadata{ { Name: "index_name", Type: NativeType{typ: TypeVarchar}, Order: DESC, }, }, Columns: map[string]*ColumnMetadata{ "table_name": { Name: "table_name", Type: NativeType{typ: TypeVarchar}, Kind: ColumnPartitionKey, }, "index_name": { Name: "index_name", Type: NativeType{typ: TypeVarchar}, Order: DESC, Kind: ColumnClusteringKey, }, "value": { Name: "value", Type: NativeType{typ: TypeBlob}, Kind: ColumnRegular, }, }, }, "wiki_page": { PartitionKey: 
[]*ColumnMetadata{ { Name: "title", Type: NativeType{typ: TypeVarchar}, }, }, ClusteringColumns: []*ColumnMetadata{ { Name: "revid", Type: NativeType{typ: TypeTimeUUID}, Order: ASC, }, }, Columns: map[string]*ColumnMetadata{ "title": { Name: "title", Type: NativeType{typ: TypeVarchar}, Kind: ColumnPartitionKey, }, "revid": { Name: "revid", Type: NativeType{typ: TypeTimeUUID}, Kind: ColumnClusteringKey, }, }, }, "no_names": { PartitionKey: []*ColumnMetadata{ { Name: "key", Type: NativeType{typ: TypeUUID}, }, { Name: "key2", Type: NativeType{typ: TypeUUID}, }, }, ClusteringColumns: []*ColumnMetadata{ { Name: "column", Type: NativeType{typ: TypeInt}, Order: ASC, }, { Name: "column2", Type: NativeType{typ: TypeInt}, Order: ASC, }, { Name: "column3", Type: NativeType{typ: TypeInt}, Order: ASC, }, }, Columns: map[string]*ColumnMetadata{ "key": { Name: "key", Type: NativeType{typ: TypeUUID}, Kind: ColumnPartitionKey, }, "key2": { Name: "key2", Type: NativeType{typ: TypeUUID}, Kind: ColumnPartitionKey, }, "column": { Name: "column", Type: NativeType{typ: TypeInt}, Order: ASC, Kind: ColumnClusteringKey, }, "column2": { Name: "column2", Type: NativeType{typ: TypeInt}, Order: ASC, Kind: ColumnClusteringKey, }, "column3": { Name: "column3", Type: NativeType{typ: TypeInt}, Order: ASC, Kind: ColumnClusteringKey, }, "value": { Name: "value", Type: NativeType{typ: TypeBlob}, Kind: ColumnRegular, }, }, }, }, }, ) // V2 test - V2+ protocol is simpler so here are some toy examples to verify that the mapping works keyspace = &KeyspaceMetadata{ Name: "V2Keyspace", } tables = []TableMetadata{ { Keyspace: "V2Keyspace", Name: "Table1", }, { Keyspace: "V2Keyspace", Name: "Table2", }, } columns = []ColumnMetadata{ { Keyspace: "V2Keyspace", Table: "Table1", Name: "KEY1", Kind: ColumnPartitionKey, ComponentIndex: 0, Validator: "org.apache.cassandra.db.marshal.UTF8Type", }, { Keyspace: "V2Keyspace", Table: "Table1", Name: "Key1", Kind: ColumnPartitionKey, ComponentIndex: 0, Validator: 
"org.apache.cassandra.db.marshal.UTF8Type", }, { Keyspace: "V2Keyspace", Table: "Table2", Name: "Column1", Kind: ColumnPartitionKey, ComponentIndex: 0, Validator: "org.apache.cassandra.db.marshal.UTF8Type", }, { Keyspace: "V2Keyspace", Table: "Table2", Name: "Column2", Kind: ColumnClusteringKey, ComponentIndex: 0, Validator: "org.apache.cassandra.db.marshal.UTF8Type", }, { Keyspace: "V2Keyspace", Table: "Table2", Name: "Column3", Kind: ColumnClusteringKey, ComponentIndex: 1, Validator: "org.apache.cassandra.db.marshal.ReversedType(org.apache.cassandra.db.marshal.UTF8Type)", }, { Keyspace: "V2Keyspace", Table: "Table2", Name: "Column4", Kind: ColumnRegular, Validator: "org.apache.cassandra.db.marshal.UTF8Type", }, } compileMetadata(2, keyspace, tables, columns, nil, nil, nil, nil, log) assertKeyspaceMetadata( t, keyspace, &KeyspaceMetadata{ Name: "V2Keyspace", Tables: map[string]*TableMetadata{ "Table1": { PartitionKey: []*ColumnMetadata{ { Name: "Key1", Type: NativeType{typ: TypeVarchar}, }, }, ClusteringColumns: []*ColumnMetadata{}, Columns: map[string]*ColumnMetadata{ "KEY1": { Name: "KEY1", Type: NativeType{typ: TypeVarchar}, Kind: ColumnPartitionKey, }, "Key1": { Name: "Key1", Type: NativeType{typ: TypeVarchar}, Kind: ColumnPartitionKey, }, }, }, "Table2": { PartitionKey: []*ColumnMetadata{ { Name: "Column1", Type: NativeType{typ: TypeVarchar}, }, }, ClusteringColumns: []*ColumnMetadata{ { Name: "Column2", Type: NativeType{typ: TypeVarchar}, Order: ASC, }, { Name: "Column3", Type: NativeType{typ: TypeVarchar}, Order: DESC, }, }, Columns: map[string]*ColumnMetadata{ "Column1": { Name: "Column1", Type: NativeType{typ: TypeVarchar}, Kind: ColumnPartitionKey, }, "Column2": { Name: "Column2", Type: NativeType{typ: TypeVarchar}, Order: ASC, Kind: ColumnClusteringKey, }, "Column3": { Name: "Column3", Type: NativeType{typ: TypeVarchar}, Order: DESC, Kind: ColumnClusteringKey, }, "Column4": { Name: "Column4", Type: NativeType{typ: TypeVarchar}, Kind: ColumnRegular, }, 
}, }, }, }, ) } // Helper function for asserting that actual metadata returned was as expected func assertKeyspaceMetadata(t *testing.T, actual, expected *KeyspaceMetadata) { if len(expected.Tables) != len(actual.Tables) { t.Errorf("Expected len(%s.Tables) to be %v but was %v", expected.Name, len(expected.Tables), len(actual.Tables)) } for keyT := range expected.Tables { et := expected.Tables[keyT] at, found := actual.Tables[keyT] if !found { t.Errorf("Expected %s.Tables[%s] but was not found", expected.Name, keyT) } else { if keyT != at.Name { t.Errorf("Expected %s.Tables[%s].Name to be %v but was %v", expected.Name, keyT, keyT, at.Name) } if len(et.PartitionKey) != len(at.PartitionKey) { t.Errorf("Expected len(%s.Tables[%s].PartitionKey) to be %v but was %v", expected.Name, keyT, len(et.PartitionKey), len(at.PartitionKey)) } else { for i := range et.PartitionKey { if et.PartitionKey[i].Name != at.PartitionKey[i].Name { t.Errorf("Expected %s.Tables[%s].PartitionKey[%d].Name to be '%v' but was '%v'", expected.Name, keyT, i, et.PartitionKey[i].Name, at.PartitionKey[i].Name) } if expected.Name != at.PartitionKey[i].Keyspace { t.Errorf("Expected %s.Tables[%s].PartitionKey[%d].Keyspace to be '%v' but was '%v'", expected.Name, keyT, i, expected.Name, at.PartitionKey[i].Keyspace) } if keyT != at.PartitionKey[i].Table { t.Errorf("Expected %s.Tables[%s].PartitionKey[%d].Table to be '%v' but was '%v'", expected.Name, keyT, i, keyT, at.PartitionKey[i].Table) } if et.PartitionKey[i].Type.Type() != at.PartitionKey[i].Type.Type() { t.Errorf("Expected %s.Tables[%s].PartitionKey[%d].Type.Type to be %v but was %v", expected.Name, keyT, i, et.PartitionKey[i].Type.Type(), at.PartitionKey[i].Type.Type()) } if i != at.PartitionKey[i].ComponentIndex { t.Errorf("Expected %s.Tables[%s].PartitionKey[%d].ComponentIndex to be %v but was %v", expected.Name, keyT, i, i, at.PartitionKey[i].ComponentIndex) } if ColumnPartitionKey != at.PartitionKey[i].Kind { t.Errorf("Expected 
%s.Tables[%s].PartitionKey[%d].Kind to be '%v' but was '%v'", expected.Name, keyT, i, ColumnPartitionKey, at.PartitionKey[i].Kind) } } } if len(et.ClusteringColumns) != len(at.ClusteringColumns) { t.Errorf("Expected len(%s.Tables[%s].ClusteringColumns) to be %v but was %v", expected.Name, keyT, len(et.ClusteringColumns), len(at.ClusteringColumns)) } else { for i := range et.ClusteringColumns { if at.ClusteringColumns[i] == nil { t.Fatalf("Unexpected nil value: %s.Tables[%s].ClusteringColumns[%d]", expected.Name, keyT, i) } if et.ClusteringColumns[i].Name != at.ClusteringColumns[i].Name { t.Errorf("Expected %s.Tables[%s].ClusteringColumns[%d].Name to be '%v' but was '%v'", expected.Name, keyT, i, et.ClusteringColumns[i].Name, at.ClusteringColumns[i].Name) } if expected.Name != at.ClusteringColumns[i].Keyspace { t.Errorf("Expected %s.Tables[%s].ClusteringColumns[%d].Keyspace to be '%v' but was '%v'", expected.Name, keyT, i, expected.Name, at.ClusteringColumns[i].Keyspace) } if keyT != at.ClusteringColumns[i].Table { t.Errorf("Expected %s.Tables[%s].ClusteringColumns[%d].Table to be '%v' but was '%v'", expected.Name, keyT, i, keyT, at.ClusteringColumns[i].Table) } if et.ClusteringColumns[i].Type.Type() != at.ClusteringColumns[i].Type.Type() { t.Errorf("Expected %s.Tables[%s].ClusteringColumns[%d].Type.Type to be %v but was %v", expected.Name, keyT, i, et.ClusteringColumns[i].Type.Type(), at.ClusteringColumns[i].Type.Type()) } if i != at.ClusteringColumns[i].ComponentIndex { t.Errorf("Expected %s.Tables[%s].ClusteringColumns[%d].ComponentIndex to be %v but was %v", expected.Name, keyT, i, i, at.ClusteringColumns[i].ComponentIndex) } if et.ClusteringColumns[i].Order != at.ClusteringColumns[i].Order { t.Errorf("Expected %s.Tables[%s].ClusteringColumns[%d].Order to be %v but was %v", expected.Name, keyT, i, et.ClusteringColumns[i].Order, at.ClusteringColumns[i].Order) } if ColumnClusteringKey != at.ClusteringColumns[i].Kind { t.Errorf("Expected 
%s.Tables[%s].ClusteringColumns[%d].Kind to be '%v' but was '%v'", expected.Name, keyT, i, ColumnClusteringKey, at.ClusteringColumns[i].Kind) } } } if len(et.Columns) != len(at.Columns) { eKeys := make([]string, 0, len(et.Columns)) for key := range et.Columns { eKeys = append(eKeys, key) } aKeys := make([]string, 0, len(at.Columns)) for key := range at.Columns { aKeys = append(aKeys, key) } t.Errorf("Expected len(%s.Tables[%s].Columns) to be %v (keys:%v) but was %v (keys:%v)", expected.Name, keyT, len(et.Columns), eKeys, len(at.Columns), aKeys) } else { for keyC := range et.Columns { ec := et.Columns[keyC] ac, found := at.Columns[keyC] if !found { t.Errorf("Expected %s.Tables[%s].Columns[%s] but was not found", expected.Name, keyT, keyC) } else { if keyC != ac.Name { t.Errorf("Expected %s.Tables[%s].Columns[%s].Name to be '%v' but was '%v'", expected.Name, keyT, keyC, keyC, at.Name) } if expected.Name != ac.Keyspace { t.Errorf("Expected %s.Tables[%s].Columns[%s].Keyspace to be '%v' but was '%v'", expected.Name, keyT, keyC, expected.Name, ac.Keyspace) } if keyT != ac.Table { t.Errorf("Expected %s.Tables[%s].Columns[%s].Table to be '%v' but was '%v'", expected.Name, keyT, keyC, keyT, ac.Table) } if ec.Type.Type() != ac.Type.Type() { t.Errorf("Expected %s.Tables[%s].Columns[%s].Type.Type to be %v but was %v", expected.Name, keyT, keyC, ec.Type.Type(), ac.Type.Type()) } if ec.Order != ac.Order { t.Errorf("Expected %s.Tables[%s].Columns[%s].Order to be %v but was %v", expected.Name, keyT, keyC, ec.Order, ac.Order) } if ec.Kind != ac.Kind { t.Errorf("Expected %s.Tables[%s].Columns[%s].Kind to be '%v' but was '%v'", expected.Name, keyT, keyC, ec.Kind, ac.Kind) } } } } } } } // Tests the cassandra type definition parser func TestTypeParser(t *testing.T) { // native type assertParseNonCompositeType( t, "org.apache.cassandra.db.marshal.UTF8Type", assertTypeInfo{Type: TypeVarchar}, ) // reversed assertParseNonCompositeType( t, 
"org.apache.cassandra.db.marshal.ReversedType(org.apache.cassandra.db.marshal.UUIDType)", assertTypeInfo{Type: TypeUUID, Reversed: true}, ) // set assertParseNonCompositeType( t, "org.apache.cassandra.db.marshal.SetType(org.apache.cassandra.db.marshal.Int32Type)", assertTypeInfo{ Type: TypeSet, Elem: &assertTypeInfo{Type: TypeInt}, }, ) // list assertParseNonCompositeType( t, "org.apache.cassandra.db.marshal.ListType(org.apache.cassandra.db.marshal.TimeUUIDType)", assertTypeInfo{ Type: TypeList, Elem: &assertTypeInfo{Type: TypeTimeUUID}, }, ) // map assertParseNonCompositeType( t, " org.apache.cassandra.db.marshal.MapType( org.apache.cassandra.db.marshal.UUIDType , org.apache.cassandra.db.marshal.BytesType ) ", assertTypeInfo{ Type: TypeMap, Key: &assertTypeInfo{Type: TypeUUID}, Elem: &assertTypeInfo{Type: TypeBlob}, }, ) // custom assertParseNonCompositeType( t, "org.apache.cassandra.db.marshal.UserType(sandbox,61646472657373,737472656574:org.apache.cassandra.db.marshal.UTF8Type,63697479:org.apache.cassandra.db.marshal.UTF8Type,7a6970:org.apache.cassandra.db.marshal.Int32Type)", assertTypeInfo{Type: TypeCustom, Custom: "org.apache.cassandra.db.marshal.UserType(sandbox,61646472657373,737472656574:org.apache.cassandra.db.marshal.UTF8Type,63697479:org.apache.cassandra.db.marshal.UTF8Type,7a6970:org.apache.cassandra.db.marshal.Int32Type)"}, ) assertParseNonCompositeType( t, "org.apache.cassandra.db.marshal.DynamicCompositeType(u=>org.apache.cassandra.db.marshal.UUIDType,d=>org.apache.cassandra.db.marshal.DateType,t=>org.apache.cassandra.db.marshal.TimeUUIDType,b=>org.apache.cassandra.db.marshal.BytesType,s=>org.apache.cassandra.db.marshal.UTF8Type,B=>org.apache.cassandra.db.marshal.BooleanType,a=>org.apache.cassandra.db.marshal.AsciiType,l=>org.apache.cassandra.db.marshal.LongType,i=>org.apache.cassandra.db.marshal.IntegerType,x=>org.apache.cassandra.db.marshal.LexicalUUIDType)", assertTypeInfo{Type: TypeCustom, Custom: 
"org.apache.cassandra.db.marshal.DynamicCompositeType(u=>org.apache.cassandra.db.marshal.UUIDType,d=>org.apache.cassandra.db.marshal.DateType,t=>org.apache.cassandra.db.marshal.TimeUUIDType,b=>org.apache.cassandra.db.marshal.BytesType,s=>org.apache.cassandra.db.marshal.UTF8Type,B=>org.apache.cassandra.db.marshal.BooleanType,a=>org.apache.cassandra.db.marshal.AsciiType,l=>org.apache.cassandra.db.marshal.LongType,i=>org.apache.cassandra.db.marshal.IntegerType,x=>org.apache.cassandra.db.marshal.LexicalUUIDType)"}, ) // composite defs assertParseCompositeType( t, "org.apache.cassandra.db.marshal.CompositeType(org.apache.cassandra.db.marshal.UTF8Type)", []assertTypeInfo{ {Type: TypeVarchar}, }, nil, ) assertParseCompositeType( t, "org.apache.cassandra.db.marshal.CompositeType(org.apache.cassandra.db.marshal.ReversedType(org.apache.cassandra.db.marshal.DateType),org.apache.cassandra.db.marshal.UTF8Type)", []assertTypeInfo{ {Type: TypeTimestamp, Reversed: true}, {Type: TypeVarchar}, }, nil, ) assertParseCompositeType( t, "org.apache.cassandra.db.marshal.CompositeType(org.apache.cassandra.db.marshal.UTF8Type,org.apache.cassandra.db.marshal.ColumnToCollectionType(726f77735f6d6572676564:org.apache.cassandra.db.marshal.MapType(org.apache.cassandra.db.marshal.Int32Type,org.apache.cassandra.db.marshal.LongType)))", []assertTypeInfo{ {Type: TypeVarchar}, }, map[string]assertTypeInfo{ "rows_merged": { Type: TypeMap, Key: &assertTypeInfo{Type: TypeInt}, Elem: &assertTypeInfo{Type: TypeBigInt}, }, }, ) } // expected data holder type assertTypeInfo struct { Type Type Reversed bool Elem *assertTypeInfo Key *assertTypeInfo Custom string } // Helper function for asserting that the type parser returns the expected // results for the given definition func assertParseNonCompositeType( t *testing.T, def string, typeExpected assertTypeInfo, ) { log := &defaultLogger{} result := parseType(def, log) if len(result.reversed) != 1 { t.Errorf("%s expected %d reversed values but there were %d", 
def, 1, len(result.reversed)) } assertParseNonCompositeTypes( t, def, []assertTypeInfo{typeExpected}, result.types, ) // expect no composite part of the result if result.isComposite { t.Errorf("%s: Expected not composite", def) } if result.collections != nil { t.Errorf("%s: Expected nil collections: %v", def, result.collections) } } // Helper function for asserting that the type parser returns the expected // results for the given definition func assertParseCompositeType( t *testing.T, def string, typesExpected []assertTypeInfo, collectionsExpected map[string]assertTypeInfo, ) { log := &defaultLogger{} result := parseType(def, log) if len(result.reversed) != len(typesExpected) { t.Errorf("%s expected %d reversed values but there were %d", def, len(typesExpected), len(result.reversed)) } assertParseNonCompositeTypes( t, def, typesExpected, result.types, ) // expect composite part of the result if !result.isComposite { t.Errorf("%s: Expected composite", def) } if result.collections == nil { t.Errorf("%s: Expected non-nil collections: %v", def, result.collections) } for name, typeExpected := range collectionsExpected { // check for an actual type for this name typeActual, found := result.collections[name] if !found { t.Errorf("%s.tcollections: Expected param named %s but there wasn't", def, name) } else { // remove the actual from the collection so we can detect extras delete(result.collections, name) // check the type assertParseNonCompositeTypes( t, def+"collections["+name+"]", []assertTypeInfo{typeExpected}, []TypeInfo{typeActual}, ) } } if len(result.collections) != 0 { t.Errorf("%s.collections: Expected no more types in collections, but there was %v", def, result.collections) } } // Helper function for asserting that the type parser returns the expected // results for the given definition func assertParseNonCompositeTypes( t *testing.T, context string, typesExpected []assertTypeInfo, typesActual []TypeInfo, ) { if len(typesActual) != len(typesExpected) { 
t.Errorf("%s: Expected %d types, but there were %d", context, len(typesExpected), len(typesActual)) } for i := range typesExpected { typeExpected := typesExpected[i] typeActual := typesActual[i] // shadow copy the context for local modification context := context if len(typesExpected) > 1 { context = context + "[" + strconv.Itoa(i) + "]" } // check the type if typeActual.Type() != typeExpected.Type { t.Errorf("%s: Expected to parse Type to %s but was %s", context, typeExpected.Type, typeActual.Type()) } // check the custom if typeActual.Custom() != typeExpected.Custom { t.Errorf("%s: Expected to parse Custom %s but was %s", context, typeExpected.Custom, typeActual.Custom()) } collection, _ := typeActual.(CollectionType) // check the elem if typeExpected.Elem != nil { if collection.Elem == nil { t.Errorf("%s: Expected to parse Elem, but was nil ", context) } else { assertParseNonCompositeTypes( t, context+".Elem", []assertTypeInfo{*typeExpected.Elem}, []TypeInfo{collection.Elem}, ) } } else if collection.Elem != nil { t.Errorf("%s: Expected to not parse Elem, but was %+v", context, collection.Elem) } // check the key if typeExpected.Key != nil { if collection.Key == nil { t.Errorf("%s: Expected to parse Key, but was nil ", context) } else { assertParseNonCompositeTypes( t, context+".Key", []assertTypeInfo{*typeExpected.Key}, []TypeInfo{collection.Key}, ) } } else if collection.Key != nil { t.Errorf("%s: Expected to not parse Key, but was %+v", context, collection.Key) } } } cassandra-gocql-driver-1.7.0/policies.go000066400000000000000000000726321467504044300202310ustar00rootroot00000000000000/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2012, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. */ package gocql //This file will be the future home for more policies import ( "context" "errors" "fmt" "math" "math/rand" "net" "sync" "sync/atomic" "time" "github.com/hailocab/go-hostpool" ) // cowHostList implements a copy on write host list, its equivalent type is []*HostInfo type cowHostList struct { list atomic.Value mu sync.Mutex } func (c *cowHostList) String() string { return fmt.Sprintf("%+v", c.get()) } func (c *cowHostList) get() []*HostInfo { // TODO(zariel): should we replace this with []*HostInfo? 
l, ok := c.list.Load().(*[]*HostInfo) if !ok { return nil } return *l } // add will add a host if it not already in the list func (c *cowHostList) add(host *HostInfo) bool { c.mu.Lock() l := c.get() if n := len(l); n == 0 { l = []*HostInfo{host} } else { newL := make([]*HostInfo, n+1) for i := 0; i < n; i++ { if host.Equal(l[i]) { c.mu.Unlock() return false } newL[i] = l[i] } newL[n] = host l = newL } c.list.Store(&l) c.mu.Unlock() return true } func (c *cowHostList) remove(ip net.IP) bool { c.mu.Lock() l := c.get() size := len(l) if size == 0 { c.mu.Unlock() return false } found := false newL := make([]*HostInfo, 0, size) for i := 0; i < len(l); i++ { if !l[i].ConnectAddress().Equal(ip) { newL = append(newL, l[i]) } else { found = true } } if !found { c.mu.Unlock() return false } newL = newL[: size-1 : size-1] c.list.Store(&newL) c.mu.Unlock() return true } // RetryableQuery is an interface that represents a query or batch statement that // exposes the correct functions for the retry policy logic to evaluate correctly. type RetryableQuery interface { Attempts() int SetConsistency(c Consistency) GetConsistency() Consistency Context() context.Context } type RetryType uint16 const ( Retry RetryType = 0x00 // retry on same connection RetryNextHost RetryType = 0x01 // retry on another connection Ignore RetryType = 0x02 // ignore error and return result Rethrow RetryType = 0x03 // raise error and stop retrying ) // ErrUnknownRetryType is returned if the retry policy returns a retry type // unknown to the query executor. var ErrUnknownRetryType = errors.New("unknown retry type returned by retry policy") // RetryPolicy interface is used by gocql to determine if a query can be attempted // again after a retryable error has been received. The interface allows gocql // users to implement their own logic to determine if a query can be attempted // again. // // See SimpleRetryPolicy as an example of implementing and using a RetryPolicy // interface. 
type RetryPolicy interface { Attempt(RetryableQuery) bool GetRetryType(error) RetryType } // SimpleRetryPolicy has simple logic for attempting a query a fixed number of times. // // See below for examples of usage: // // //Assign to the cluster // cluster.RetryPolicy = &gocql.SimpleRetryPolicy{NumRetries: 3} // // //Assign to a query // query.RetryPolicy(&gocql.SimpleRetryPolicy{NumRetries: 1}) type SimpleRetryPolicy struct { NumRetries int //Number of times to retry a query } // Attempt tells gocql to attempt the query again based on query.Attempts being less // than the NumRetries defined in the policy. func (s *SimpleRetryPolicy) Attempt(q RetryableQuery) bool { return q.Attempts() <= s.NumRetries } func (s *SimpleRetryPolicy) GetRetryType(err error) RetryType { return RetryNextHost } // ExponentialBackoffRetryPolicy sleeps between attempts type ExponentialBackoffRetryPolicy struct { NumRetries int Min, Max time.Duration } func (e *ExponentialBackoffRetryPolicy) Attempt(q RetryableQuery) bool { if q.Attempts() > e.NumRetries { return false } time.Sleep(e.napTime(q.Attempts())) return true } // used to calculate exponentially growing time func getExponentialTime(min time.Duration, max time.Duration, attempts int) time.Duration { if min <= 0 { min = 100 * time.Millisecond } if max <= 0 { max = 10 * time.Second } minFloat := float64(min) napDuration := minFloat * math.Pow(2, float64(attempts-1)) // add some jitter napDuration += rand.Float64()*minFloat - (minFloat / 2) if napDuration > float64(max) { return time.Duration(max) } return time.Duration(napDuration) } func (e *ExponentialBackoffRetryPolicy) GetRetryType(err error) RetryType { return RetryNextHost } // DowngradingConsistencyRetryPolicy: Next retry will be with the next consistency level // provided in the slice // // On a read timeout: the operation is retried with the next provided consistency // level. 
// // On a write timeout: if the operation is an :attr:`~.UNLOGGED_BATCH` // and at least one replica acknowledged the write, the operation is // retried with the next consistency level. Furthermore, for other // write types, if at least one replica acknowledged the write, the // timeout is ignored. // // On an unavailable exception: if at least one replica is alive, the // operation is retried with the next provided consistency level. type DowngradingConsistencyRetryPolicy struct { ConsistencyLevelsToTry []Consistency } func (d *DowngradingConsistencyRetryPolicy) Attempt(q RetryableQuery) bool { currentAttempt := q.Attempts() if currentAttempt > len(d.ConsistencyLevelsToTry) { return false } else if currentAttempt > 0 { q.SetConsistency(d.ConsistencyLevelsToTry[currentAttempt-1]) } return true } func (d *DowngradingConsistencyRetryPolicy) GetRetryType(err error) RetryType { switch t := err.(type) { case *RequestErrUnavailable: if t.Alive > 0 { return Retry } return Rethrow case *RequestErrWriteTimeout: if t.WriteType == "SIMPLE" || t.WriteType == "BATCH" || t.WriteType == "COUNTER" { if t.Received > 0 { return Ignore } return Rethrow } if t.WriteType == "UNLOGGED_BATCH" { return Retry } return Rethrow case *RequestErrReadTimeout: return Retry default: return RetryNextHost } } func (e *ExponentialBackoffRetryPolicy) napTime(attempts int) time.Duration { return getExponentialTime(e.Min, e.Max, attempts) } type HostStateNotifier interface { AddHost(host *HostInfo) RemoveHost(host *HostInfo) HostUp(host *HostInfo) HostDown(host *HostInfo) } type KeyspaceUpdateEvent struct { Keyspace string Change string } type HostTierer interface { // HostTier returns an integer specifying how far a host is from the client. // Tier must start at 0. // The value is used to prioritize closer hosts during host selection. 
// For example this could be: // 0 - local rack, 1 - local DC, 2 - remote DC // or: // 0 - local DC, 1 - remote DC HostTier(host *HostInfo) uint // This function returns the maximum possible host tier MaxHostTier() uint } // HostSelectionPolicy is an interface for selecting // the most appropriate host to execute a given query. // HostSelectionPolicy instances cannot be shared between sessions. type HostSelectionPolicy interface { HostStateNotifier SetPartitioner KeyspaceChanged(KeyspaceUpdateEvent) Init(*Session) IsLocal(host *HostInfo) bool // Pick returns an iteration function over selected hosts. // Multiple attempts of a single query execution won't call the returned NextHost function concurrently, // so it's safe to have internal state without additional synchronization as long as every call to Pick returns // a different instance of NextHost. Pick(ExecutableQuery) NextHost } // SelectedHost is an interface returned when picking a host from a host // selection policy. type SelectedHost interface { Info() *HostInfo Mark(error) } type selectedHost HostInfo func (host *selectedHost) Info() *HostInfo { return (*HostInfo)(host) } func (host *selectedHost) Mark(err error) {} // NextHost is an iteration function over picked hosts type NextHost func() SelectedHost // RoundRobinHostPolicy is a round-robin load balancing policy, where each host // is tried sequentially for each query. 
func RoundRobinHostPolicy() HostSelectionPolicy { return &roundRobinHostPolicy{} } type roundRobinHostPolicy struct { hosts cowHostList lastUsedHostIdx uint64 } func (r *roundRobinHostPolicy) IsLocal(*HostInfo) bool { return true } func (r *roundRobinHostPolicy) KeyspaceChanged(KeyspaceUpdateEvent) {} func (r *roundRobinHostPolicy) SetPartitioner(partitioner string) {} func (r *roundRobinHostPolicy) Init(*Session) {} func (r *roundRobinHostPolicy) Pick(qry ExecutableQuery) NextHost { nextStartOffset := atomic.AddUint64(&r.lastUsedHostIdx, 1) return roundRobbin(int(nextStartOffset), r.hosts.get()) } func (r *roundRobinHostPolicy) AddHost(host *HostInfo) { r.hosts.add(host) } func (r *roundRobinHostPolicy) RemoveHost(host *HostInfo) { r.hosts.remove(host.ConnectAddress()) } func (r *roundRobinHostPolicy) HostUp(host *HostInfo) { r.AddHost(host) } func (r *roundRobinHostPolicy) HostDown(host *HostInfo) { r.RemoveHost(host) } func ShuffleReplicas() func(*tokenAwareHostPolicy) { return func(t *tokenAwareHostPolicy) { t.shuffleReplicas = true } } // NonLocalReplicasFallback enables fallback to replicas that are not considered local. // // TokenAwareHostPolicy used with DCAwareHostPolicy fallback first selects replicas by partition key in local DC, then // falls back to other nodes in the local DC. Enabling NonLocalReplicasFallback causes TokenAwareHostPolicy // to first select replicas by partition key in local DC, then replicas by partition key in remote DCs and fall back // to other nodes in local DC. func NonLocalReplicasFallback() func(policy *tokenAwareHostPolicy) { return func(t *tokenAwareHostPolicy) { t.nonLocalReplicasFallback = true } } // TokenAwareHostPolicy is a token aware host selection policy, where hosts are // selected based on the partition key, so queries are sent to the host which // owns the partition. Fallback is used when routing information is not available. 
func TokenAwareHostPolicy(fallback HostSelectionPolicy, opts ...func(*tokenAwareHostPolicy)) HostSelectionPolicy { p := &tokenAwareHostPolicy{fallback: fallback} for _, opt := range opts { opt(p) } return p } // clusterMeta holds metadata about cluster topology. // It is used inside atomic.Value and shallow copies are used when replacing it, // so fields should not be modified in-place. Instead, to modify a field a copy of the field should be made // and the pointer in clusterMeta updated to point to the new value. type clusterMeta struct { // replicas is map[keyspace]map[token]hosts replicas map[string]tokenRingReplicas tokenRing *tokenRing } type tokenAwareHostPolicy struct { fallback HostSelectionPolicy getKeyspaceMetadata func(keyspace string) (*KeyspaceMetadata, error) getKeyspaceName func() string shuffleReplicas bool nonLocalReplicasFallback bool // mu protects writes to hosts, partitioner, metadata. // reads can be unlocked as long as they are not used for updating state later. mu sync.Mutex hosts cowHostList partitioner string metadata atomic.Value // *clusterMeta logger StdLogger } func (t *tokenAwareHostPolicy) Init(s *Session) { t.mu.Lock() defer t.mu.Unlock() if t.getKeyspaceMetadata != nil { // Init was already called. // See https://github.com/scylladb/gocql/issues/94. panic("sharing token aware host selection policy between sessions is not supported") } t.getKeyspaceMetadata = s.KeyspaceMetadata t.getKeyspaceName = func() string { return s.cfg.Keyspace } t.logger = s.logger } func (t *tokenAwareHostPolicy) IsLocal(host *HostInfo) bool { return t.fallback.IsLocal(host) } func (t *tokenAwareHostPolicy) KeyspaceChanged(update KeyspaceUpdateEvent) { t.mu.Lock() defer t.mu.Unlock() meta := t.getMetadataForUpdate() t.updateReplicas(meta, update.Keyspace) t.metadata.Store(meta) } // updateReplicas updates replicas in clusterMeta. // It must be called with t.mu mutex locked. // meta must not be nil and it's replicas field will be updated. 
func (t *tokenAwareHostPolicy) updateReplicas(meta *clusterMeta, keyspace string) { newReplicas := make(map[string]tokenRingReplicas, len(meta.replicas)) ks, err := t.getKeyspaceMetadata(keyspace) if err == nil { strat := getStrategy(ks, t.logger) if strat != nil { if meta != nil && meta.tokenRing != nil { newReplicas[keyspace] = strat.replicaMap(meta.tokenRing) } } } for ks, replicas := range meta.replicas { if ks != keyspace { newReplicas[ks] = replicas } } meta.replicas = newReplicas } func (t *tokenAwareHostPolicy) SetPartitioner(partitioner string) { t.mu.Lock() defer t.mu.Unlock() if t.partitioner != partitioner { t.fallback.SetPartitioner(partitioner) t.partitioner = partitioner meta := t.getMetadataForUpdate() meta.resetTokenRing(t.partitioner, t.hosts.get(), t.logger) t.updateReplicas(meta, t.getKeyspaceName()) t.metadata.Store(meta) } } func (t *tokenAwareHostPolicy) AddHost(host *HostInfo) { t.mu.Lock() if t.hosts.add(host) { meta := t.getMetadataForUpdate() meta.resetTokenRing(t.partitioner, t.hosts.get(), t.logger) t.updateReplicas(meta, t.getKeyspaceName()) t.metadata.Store(meta) } t.mu.Unlock() t.fallback.AddHost(host) } func (t *tokenAwareHostPolicy) AddHosts(hosts []*HostInfo) { t.mu.Lock() for _, host := range hosts { t.hosts.add(host) } meta := t.getMetadataForUpdate() meta.resetTokenRing(t.partitioner, t.hosts.get(), t.logger) t.updateReplicas(meta, t.getKeyspaceName()) t.metadata.Store(meta) t.mu.Unlock() for _, host := range hosts { t.fallback.AddHost(host) } } func (t *tokenAwareHostPolicy) RemoveHost(host *HostInfo) { t.mu.Lock() if t.hosts.remove(host.ConnectAddress()) { meta := t.getMetadataForUpdate() meta.resetTokenRing(t.partitioner, t.hosts.get(), t.logger) t.updateReplicas(meta, t.getKeyspaceName()) t.metadata.Store(meta) } t.mu.Unlock() t.fallback.RemoveHost(host) } func (t *tokenAwareHostPolicy) HostUp(host *HostInfo) { t.fallback.HostUp(host) } func (t *tokenAwareHostPolicy) HostDown(host *HostInfo) { t.fallback.HostDown(host) } // 
getMetadataReadOnly returns current cluster metadata. // Metadata uses copy on write, so the returned value should be only used for reading. // To obtain a copy that could be updated, use getMetadataForUpdate instead. func (t *tokenAwareHostPolicy) getMetadataReadOnly() *clusterMeta { meta, _ := t.metadata.Load().(*clusterMeta) return meta } // getMetadataForUpdate returns clusterMeta suitable for updating. // It is a SHALLOW copy of current metadata in case it was already set or new empty clusterMeta otherwise. // This function should be called with t.mu mutex locked and the mutex should not be released before // storing the new metadata. func (t *tokenAwareHostPolicy) getMetadataForUpdate() *clusterMeta { metaReadOnly := t.getMetadataReadOnly() meta := new(clusterMeta) if metaReadOnly != nil { *meta = *metaReadOnly } return meta } // resetTokenRing creates a new tokenRing. // It must be called with t.mu locked. func (m *clusterMeta) resetTokenRing(partitioner string, hosts []*HostInfo, logger StdLogger) { if partitioner == "" { // partitioner not yet set return } // create a new token ring tokenRing, err := newTokenRing(partitioner, hosts) if err != nil { logger.Printf("Unable to update the token ring due to error: %s", err) return } // replace the token ring m.tokenRing = tokenRing } func (t *tokenAwareHostPolicy) Pick(qry ExecutableQuery) NextHost { if qry == nil { return t.fallback.Pick(qry) } routingKey, err := qry.GetRoutingKey() if err != nil { return t.fallback.Pick(qry) } else if routingKey == nil { return t.fallback.Pick(qry) } meta := t.getMetadataReadOnly() if meta == nil || meta.tokenRing == nil { return t.fallback.Pick(qry) } token := meta.tokenRing.partitioner.Hash(routingKey) ht := meta.replicas[qry.Keyspace()].replicasFor(token) var replicas []*HostInfo if ht == nil { host, _ := meta.tokenRing.GetHostForToken(token) replicas = []*HostInfo{host} } else { replicas = ht.hosts if t.shuffleReplicas { replicas = shuffleHosts(replicas) } } var ( 
fallbackIter NextHost i, j, k int remote [][]*HostInfo tierer HostTierer tiererOk bool maxTier uint ) if tierer, tiererOk = t.fallback.(HostTierer); tiererOk { maxTier = tierer.MaxHostTier() } else { maxTier = 1 } if t.nonLocalReplicasFallback { remote = make([][]*HostInfo, maxTier) } used := make(map[*HostInfo]bool, len(replicas)) return func() SelectedHost { for i < len(replicas) { h := replicas[i] i++ var tier uint if tiererOk { tier = tierer.HostTier(h) } else if t.fallback.IsLocal(h) { tier = 0 } else { tier = 1 } if tier != 0 { if t.nonLocalReplicasFallback { remote[tier-1] = append(remote[tier-1], h) } continue } if h.IsUp() { used[h] = true return (*selectedHost)(h) } } if t.nonLocalReplicasFallback { for j < len(remote) && k < len(remote[j]) { h := remote[j][k] k++ if k >= len(remote[j]) { j++ k = 0 } if h.IsUp() { used[h] = true return (*selectedHost)(h) } } } if fallbackIter == nil { // fallback fallbackIter = t.fallback.Pick(qry) } // filter the token aware selected hosts from the fallback hosts for fallbackHost := fallbackIter(); fallbackHost != nil; fallbackHost = fallbackIter() { if !used[fallbackHost.Info()] { used[fallbackHost.Info()] = true return fallbackHost } } return nil } } // HostPoolHostPolicy is a host policy which uses the bitly/go-hostpool library // to distribute queries between hosts and prevent sending queries to // unresponsive hosts. When creating the host pool that is passed to the policy // use an empty slice of hosts as the hostpool will be populated later by gocql. 
// See below for examples of usage:
//
//	// Create host selection policy using a simple host pool
//	cluster.PoolConfig.HostSelectionPolicy = HostPoolHostPolicy(hostpool.New(nil))
//
//	// Create host selection policy using an epsilon greedy pool
//	cluster.PoolConfig.HostSelectionPolicy = HostPoolHostPolicy(
//	    hostpool.NewEpsilonGreedy(nil, 0, &hostpool.LinearEpsilonValueCalculator{}),
//	)
func HostPoolHostPolicy(hp hostpool.HostPool) HostSelectionPolicy {
	return &hostPoolHostPolicy{hostMap: map[string]*HostInfo{}, hp: hp}
}

// hostPoolHostPolicy adapts a hostpool.HostPool to the HostSelectionPolicy
// interface. mu guards both hp and hostMap; hostMap maps a host's connect
// address string to its HostInfo so pool responses can be resolved back.
type hostPoolHostPolicy struct {
	hp      hostpool.HostPool
	mu      sync.RWMutex
	hostMap map[string]*HostInfo
}

func (r *hostPoolHostPolicy) Init(*Session)                       {}
func (r *hostPoolHostPolicy) KeyspaceChanged(KeyspaceUpdateEvent) {}
func (r *hostPoolHostPolicy) SetPartitioner(string)               {}
func (r *hostPoolHostPolicy) IsLocal(*HostInfo) bool              { return true }

// SetHosts replaces the entire host set in one step, keeping the pool's peer
// list and hostMap consistent under a single lock acquisition.
func (r *hostPoolHostPolicy) SetHosts(hosts []*HostInfo) {
	peers := make([]string, len(hosts))
	hostMap := make(map[string]*HostInfo, len(hosts))

	for i, host := range hosts {
		ip := host.ConnectAddress().String()
		peers[i] = ip
		hostMap[ip] = host
	}

	r.mu.Lock()
	r.hp.SetHosts(peers)
	r.hostMap = hostMap
	r.mu.Unlock()
}

// AddHost registers a single new host and rebuilds the pool's peer list.
func (r *hostPoolHostPolicy) AddHost(host *HostInfo) {
	ip := host.ConnectAddress().String()

	r.mu.Lock()
	defer r.mu.Unlock()

	// If the host addr is present and isn't nil return
	if h, ok := r.hostMap[ip]; ok && h != nil {
		return
	}
	// otherwise, add the host to the map
	r.hostMap[ip] = host
	// and construct a new peer list to give to the HostPool
	hosts := make([]string, 0, len(r.hostMap))
	for addr := range r.hostMap {
		hosts = append(hosts, addr)
	}
	r.hp.SetHosts(hosts)
}

// RemoveHost drops a host (no-op if unknown) and rebuilds the pool's peer list.
func (r *hostPoolHostPolicy) RemoveHost(host *HostInfo) {
	ip := host.ConnectAddress().String()

	r.mu.Lock()
	defer r.mu.Unlock()

	if _, ok := r.hostMap[ip]; !ok {
		return
	}

	delete(r.hostMap, ip)
	hosts := make([]string, 0, len(r.hostMap))
	for _, host := range r.hostMap {
		hosts = append(hosts, host.ConnectAddress().String())
	}
	r.hp.SetHosts(hosts)
}

func (r *hostPoolHostPolicy) HostUp(host *HostInfo)   { r.AddHost(host) }
func (r *hostPoolHostPolicy) HostDown(host *HostInfo) { r.RemoveHost(host) }

// Pick defers host selection to the wrapped hostpool; each call to the
// returned iterator asks the pool for one host and wraps it so that Mark()
// can report the outcome back to the pool.
func (r *hostPoolHostPolicy) Pick(qry ExecutableQuery) NextHost {
	return func() SelectedHost {
		r.mu.RLock()
		defer r.mu.RUnlock()

		if len(r.hostMap) == 0 {
			return nil
		}

		hostR := r.hp.Get()
		host, ok := r.hostMap[hostR.Host()]
		if !ok {
			// pool returned a host we no longer track (raced with removal)
			return nil
		}

		return selectedHostPoolHost{
			policy: r,
			info:   host,
			hostR:  hostR,
		}
	}
}

// selectedHostPoolHost is a host returned by the hostPoolHostPolicy and
// implements the SelectedHost interface
type selectedHostPoolHost struct {
	policy *hostPoolHostPolicy
	info   *HostInfo
	hostR  hostpool.HostPoolResponse
}

func (host selectedHostPoolHost) Info() *HostInfo {
	return host.info
}

// Mark reports the query outcome to the hostpool, unless the host has been
// removed from the policy since it was picked.
func (host selectedHostPoolHost) Mark(err error) {
	ip := host.info.ConnectAddress().String()

	host.policy.mu.RLock()
	defer host.policy.mu.RUnlock()

	if _, ok := host.policy.hostMap[ip]; !ok {
		// host was removed between pick and mark
		return
	}

	host.hostR.Mark(err)
}

// dcAwareRR keeps local and remote hosts in separate copy-on-write lists and
// hands out a rotating start offset per Pick.
// NOTE(review): lastUsedHostIdx is accessed atomically but is not the first
// field (unlike rackAwareRR below); whether it is 64-bit aligned on 32-bit
// platforms depends on the preceding field sizes — verify if 32-bit support
// matters.
type dcAwareRR struct {
	local           string
	localHosts      cowHostList
	remoteHosts     cowHostList
	lastUsedHostIdx uint64
}

// DCAwareRoundRobinPolicy is a host selection policy which will prioritize and
// return hosts which are in the local datacenter before returning hosts in all
// other datacenters
func DCAwareRoundRobinPolicy(localDC string) HostSelectionPolicy {
	return &dcAwareRR{local: localDC}
}

func (d *dcAwareRR) Init(*Session)                       {}
func (d *dcAwareRR) KeyspaceChanged(KeyspaceUpdateEvent) {}
func (d *dcAwareRR) SetPartitioner(p string)             {}

func (d *dcAwareRR) IsLocal(host *HostInfo) bool {
	return host.DataCenter() == d.local
}

func (d *dcAwareRR) AddHost(host *HostInfo) {
	if d.IsLocal(host) {
		d.localHosts.add(host)
	} else {
		d.remoteHosts.add(host)
	}
}

func (d *dcAwareRR) RemoveHost(host *HostInfo) {
	if d.IsLocal(host) {
		d.localHosts.remove(host.ConnectAddress())
	} else {
		d.remoteHosts.remove(host.ConnectAddress())
	}
}

func (d *dcAwareRR) HostUp(host *HostInfo)   { d.AddHost(host) }
func (d *dcAwareRR) HostDown(host *HostInfo) { d.RemoveHost(host) }

// This function is supposed to be called in a fashion
// roundRobbin(offset, hostsPriority1, hostsPriority2, hostsPriority3 ... )
//
// E.g. for DC-naive strategy:
// roundRobbin(offset, allHosts)
//
// For tiered and DC-aware strategy:
// roundRobbin(offset, localHosts, remoteHosts)
//
// The returned iterator walks each layer in priority order, visiting up to
// len(layer) hosts per layer starting at (shift+1)%len(layer), and skips
// hosts that are down.
func roundRobbin(shift int, hosts ...[]*HostInfo) NextHost {
	currentLayer := 0
	currentlyObserved := 0

	return func() SelectedHost {
		// iterate over layers
		for {
			if currentLayer == len(hosts) {
				return nil
			}

			currentLayerSize := len(hosts[currentLayer])

			// iterate over hosts within a layer
			for {
				currentlyObserved++
				if currentlyObserved > currentLayerSize {
					currentLayer++
					currentlyObserved = 0
					break
				}

				// increment-before-use means the scan starts at offset
				// (shift+1)%size; fine for round-robin purposes
				h := hosts[currentLayer][(shift+currentlyObserved)%currentLayerSize]

				if h.IsUp() {
					return (*selectedHost)(h)
				}
			}
		}
	}
}

func (d *dcAwareRR) Pick(q ExecutableQuery) NextHost {
	nextStartOffset := atomic.AddUint64(&d.lastUsedHostIdx, 1)
	return roundRobbin(int(nextStartOffset), d.localHosts.get(), d.remoteHosts.get())
}

// RackAwareRoundRobinPolicy is a host selection policy which will prioritize and
// return hosts which are in the local rack, before hosts in the local datacenter but
// a different rack, before hosts in all other datacenters
type rackAwareRR struct {
	// lastUsedHostIdx keeps the index of the last used host.
	// It is accessed atomically and needs to be aligned to 64 bits, so we
	// keep it first in the struct. Do not move it or add new struct members
	// before it.
	lastUsedHostIdx uint64
	localDC         string
	localRack       string
	// hosts holds one copy-on-write list per tier: 0 = local rack,
	// 1 = local DC / other rack, 2 = everything else.
	hosts []cowHostList
}

func RackAwareRoundRobinPolicy(localDC string, localRack string) HostSelectionPolicy {
	hosts := make([]cowHostList, 3)
	return &rackAwareRR{localDC: localDC, localRack: localRack, hosts: hosts}
}

func (d *rackAwareRR) Init(*Session)                       {}
func (d *rackAwareRR) KeyspaceChanged(KeyspaceUpdateEvent) {}
func (d *rackAwareRR) SetPartitioner(p string)             {}

// MaxHostTier reports the highest tier value HostTier can return.
func (d *rackAwareRR) MaxHostTier() uint {
	return 2
}

// HostTier classifies a host: 0 = same DC and rack, 1 = same DC only,
// 2 = different DC.
func (d *rackAwareRR) HostTier(host *HostInfo) uint {
	if host.DataCenter() == d.localDC {
		if host.Rack() == d.localRack {
			return 0
		} else {
			return 1
		}
	} else {
		return 2
	}
}

func (d *rackAwareRR) IsLocal(host *HostInfo) bool {
	return d.HostTier(host) == 0
}

func (d *rackAwareRR) AddHost(host *HostInfo) {
	dist := d.HostTier(host)
	d.hosts[dist].add(host)
}

func (d *rackAwareRR) RemoveHost(host *HostInfo) {
	dist := d.HostTier(host)
	d.hosts[dist].remove(host.ConnectAddress())
}

func (d *rackAwareRR) HostUp(host *HostInfo)   { d.AddHost(host) }
func (d *rackAwareRR) HostDown(host *HostInfo) { d.RemoveHost(host) }

func (d *rackAwareRR) Pick(q ExecutableQuery) NextHost {
	nextStartOffset := atomic.AddUint64(&d.lastUsedHostIdx, 1)
	return roundRobbin(int(nextStartOffset), d.hosts[0].get(), d.hosts[1].get(), d.hosts[2].get())
}

// ReadyPolicy defines a policy for when a HostSelectionPolicy can be used. After
// each host connects during session initialization, the Ready method will be
// called. If you only need a single Host to be up you can wrap a
// HostSelectionPolicy policy with SingleHostReadyPolicy.
type ReadyPolicy interface { Ready() bool } // SingleHostReadyPolicy wraps a HostSelectionPolicy and returns Ready after a // single host has been added via HostUp func SingleHostReadyPolicy(p HostSelectionPolicy) *singleHostReadyPolicy { return &singleHostReadyPolicy{ HostSelectionPolicy: p, } } type singleHostReadyPolicy struct { HostSelectionPolicy ready bool readyMux sync.Mutex } func (s *singleHostReadyPolicy) HostUp(host *HostInfo) { s.HostSelectionPolicy.HostUp(host) s.readyMux.Lock() s.ready = true s.readyMux.Unlock() } func (s *singleHostReadyPolicy) Ready() bool { s.readyMux.Lock() ready := s.ready s.readyMux.Unlock() if !ready { return false } // in case the wrapped policy is also a ReadyPolicy, defer to that if rdy, ok := s.HostSelectionPolicy.(ReadyPolicy); ok { return rdy.Ready() } return true } // ConvictionPolicy interface is used by gocql to determine if a host should be // marked as DOWN based on the error and host info type ConvictionPolicy interface { // Implementations should return `true` if the host should be convicted, `false` otherwise. AddFailure(error error, host *HostInfo) bool //Implementations should clear out any convictions or state regarding the host. Reset(host *HostInfo) } // SimpleConvictionPolicy implements a ConvictionPolicy which convicts all hosts // regardless of error type SimpleConvictionPolicy struct { } func (e *SimpleConvictionPolicy) AddFailure(error error, host *HostInfo) bool { return true } func (e *SimpleConvictionPolicy) Reset(host *HostInfo) {} // ReconnectionPolicy interface is used by gocql to determine if reconnection // can be attempted after connection error. The interface allows gocql users // to implement their own logic to determine how to attempt reconnection. type ReconnectionPolicy interface { GetInterval(currentRetry int) time.Duration GetMaxRetries() int } // ConstantReconnectionPolicy has simple logic for returning a fixed reconnection interval. 
// // Examples of usage: // // cluster.ReconnectionPolicy = &gocql.ConstantReconnectionPolicy{MaxRetries: 10, Interval: 8 * time.Second} type ConstantReconnectionPolicy struct { MaxRetries int Interval time.Duration } func (c *ConstantReconnectionPolicy) GetInterval(currentRetry int) time.Duration { return c.Interval } func (c *ConstantReconnectionPolicy) GetMaxRetries() int { return c.MaxRetries } // ExponentialReconnectionPolicy returns a growing reconnection interval. type ExponentialReconnectionPolicy struct { MaxRetries int InitialInterval time.Duration MaxInterval time.Duration } func (e *ExponentialReconnectionPolicy) GetInterval(currentRetry int) time.Duration { max := e.MaxInterval if max < e.InitialInterval { max = math.MaxInt16 * time.Second } return getExponentialTime(e.InitialInterval, max, currentRetry) } func (e *ExponentialReconnectionPolicy) GetMaxRetries() int { return e.MaxRetries } type SpeculativeExecutionPolicy interface { Attempts() int Delay() time.Duration } type NonSpeculativeExecution struct{} func (sp NonSpeculativeExecution) Attempts() int { return 0 } // No additional attempts func (sp NonSpeculativeExecution) Delay() time.Duration { return 1 } // The delay. Must be positive to be used in a ticker. type SimpleSpeculativeExecution struct { NumAttempts int TimeoutDelay time.Duration } func (sp *SimpleSpeculativeExecution) Attempts() int { return sp.NumAttempts } func (sp *SimpleSpeculativeExecution) Delay() time.Duration { return sp.TimeoutDelay } cassandra-gocql-driver-1.7.0/policies_test.go000066400000000000000000000741311467504044300212640ustar00rootroot00000000000000// Copyright (c) 2015 The gocql Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. 
See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*
 * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40
 * Copyright (c) 2016, The Gocql authors,
 * provided under the BSD-3-Clause License.
 * See the NOTICE file distributed with this work for additional information.
 */

package gocql

import (
	"errors"
	"fmt"
	"net"
	"sort"
	"strings"
	"testing"
	"time"

	"github.com/hailocab/go-hostpool"
)

// Tests of the round-robin host selection policy implementation
func TestRoundRobbin(t *testing.T) {
	policy := RoundRobinHostPolicy()

	hosts := [...]*HostInfo{
		{hostId: "0", connectAddress: net.IPv4(0, 0, 0, 1)},
		{hostId: "1", connectAddress: net.IPv4(0, 0, 0, 2)},
	}

	for _, host := range hosts {
		policy.AddHost(host)
	}

	// each host must appear exactly once in a full iteration
	got := make(map[string]bool)
	it := policy.Pick(nil)
	for h := it(); h != nil; h = it() {
		id := h.Info().hostId
		if got[id] {
			t.Fatalf("got duplicate host: %v", id)
		}
		got[id] = true
	}
	if len(got) != len(hosts) {
		t.Fatalf("expected %d hosts got %d", len(hosts), len(got))
	}
}

// Tests of the token-aware host selection policy implementation with a
// round-robin host selection policy fallback.
func TestHostPolicy_TokenAware_SimpleStrategy(t *testing.T) {
	const keyspace = "myKeyspace"
	policy := TokenAwareHostPolicy(RoundRobinHostPolicy())
	policyInternal := policy.(*tokenAwareHostPolicy)
	policyInternal.getKeyspaceName = func() string { return keyspace }
	policyInternal.getKeyspaceMetadata = func(ks string) (*KeyspaceMetadata, error) {
		return nil, errors.New("not initalized")
	}

	query := &Query{routingInfo: &queryRoutingInfo{}}
	query.getKeyspace = func() string { return keyspace }

	// with no hosts the iterator must be empty but non-nil
	iter := policy.Pick(nil)
	if iter == nil {
		t.Fatal("host iterator was nil")
	}
	actual := iter()
	if actual != nil {
		t.Fatalf("expected nil from iterator, but was %v", actual)
	}

	// set the hosts
	hosts := [...]*HostInfo{
		{hostId: "0", connectAddress: net.IPv4(10, 0, 0, 1), tokens: []string{"00"}},
		{hostId: "1", connectAddress: net.IPv4(10, 0, 0, 2), tokens: []string{"25"}},
		{hostId: "2", connectAddress: net.IPv4(10, 0, 0, 3), tokens: []string{"50"}},
		{hostId: "3", connectAddress: net.IPv4(10, 0, 0, 4), tokens: []string{"75"}},
	}
	for _, host := range &hosts {
		policy.AddHost(host)
	}

	policy.SetPartitioner("OrderedPartitioner")

	policyInternal.getKeyspaceMetadata = func(keyspaceName string) (*KeyspaceMetadata, error) {
		if keyspaceName != keyspace {
			return nil, fmt.Errorf("unknown keyspace: %s", keyspaceName)
		}
		return &KeyspaceMetadata{
			Name:          keyspace,
			StrategyClass: "SimpleStrategy",
			StrategyOptions: map[string]interface{}{
				"class":              "SimpleStrategy",
				"replication_factor": 2,
			},
		}, nil
	}
	policy.KeyspaceChanged(KeyspaceUpdateEvent{Keyspace: keyspace})

	// The SimpleStrategy above should generate the following replicas.
	// It's handy to have as reference here.
	assertDeepEqual(t, "replicas", map[string]tokenRingReplicas{
		"myKeyspace": {
			{orderedToken("00"), []*HostInfo{hosts[0], hosts[1]}},
			{orderedToken("25"), []*HostInfo{hosts[1], hosts[2]}},
			{orderedToken("50"), []*HostInfo{hosts[2], hosts[3]}},
			{orderedToken("75"), []*HostInfo{hosts[3], hosts[0]}},
		},
	}, policyInternal.getMetadataReadOnly().replicas)

	// now the token ring is configured
	query.RoutingKey([]byte("20"))
	iter = policy.Pick(query)
	// first token-aware hosts
	expectHosts(t, "hosts[0]", iter, "1")
	expectHosts(t, "hosts[1]", iter, "2")
	// then rest of the hosts
	expectHosts(t, "rest", iter, "0", "3")
	expectNoMoreHosts(t, iter)
}

// Tests of the host pool host selection policy implementation
func TestHostPolicy_HostPool(t *testing.T) {
	policy := HostPoolHostPolicy(hostpool.New(nil))

	hosts := []*HostInfo{
		{hostId: "0", connectAddress: net.IPv4(10, 0, 0, 0)},
		{hostId: "1", connectAddress: net.IPv4(10, 0, 0, 1)},
	}

	// Using set host to control the ordering of the hosts as calling "AddHost" iterates the map
	// which will result in an unpredictable ordering
	policy.(*hostPoolHostPolicy).SetHosts(hosts)

	// the first host selected is actually at [1], but this is ok for RR
	// interleaved iteration should always increment the host
	iter := policy.Pick(nil)
	actualA := iter()
	if actualA.Info().HostID() != "0" {
		t.Errorf("Expected hosts[0] but was hosts[%s]", actualA.Info().HostID())
	}
	actualA.Mark(nil)

	actualB := iter()
	if actualB.Info().HostID() != "1" {
		t.Errorf("Expected hosts[1] but was hosts[%s]", actualB.Info().HostID())
	}
	actualB.Mark(fmt.Errorf("error"))

	actualC := iter()
	if actualC.Info().HostID() != "0" {
		t.Errorf("Expected hosts[0] but was hosts[%s]", actualC.Info().HostID())
	}
	actualC.Mark(nil)

	actualD := iter()
	if actualD.Info().HostID() != "0" {
		t.Errorf("Expected hosts[0] but was hosts[%s]", actualD.Info().HostID())
	}
	actualD.Mark(nil)
}

// A single-host round-robin iterator must return the host once, then nil.
func TestHostPolicy_RoundRobin_NilHostInfo(t *testing.T) {
	policy := RoundRobinHostPolicy()

	host := &HostInfo{hostId: "host-1"}
	policy.AddHost(host)

	iter := policy.Pick(nil)
	next := iter()
	if next == nil {
		t.Fatal("got nil host")
	} else if v := next.Info(); v == nil {
		t.Fatal("got nil HostInfo")
	} else if v.HostID() != host.HostID() {
		t.Fatalf("expected host %v got %v", host, v)
	}

	next = iter()
	if next != nil {
		t.Errorf("expected to get nil host got %+v", next)
		if next.Info() == nil {
			t.Fatalf("HostInfo is nil")
		}
	}
}

// Token-aware iterator must not panic or yield nil-info hosts after all hosts
// are removed mid-iteration.
func TestHostPolicy_TokenAware_NilHostInfo(t *testing.T) {
	policy := TokenAwareHostPolicy(RoundRobinHostPolicy())
	policyInternal := policy.(*tokenAwareHostPolicy)
	policyInternal.getKeyspaceName = func() string { return "myKeyspace" }
	policyInternal.getKeyspaceMetadata = func(ks string) (*KeyspaceMetadata, error) {
		return nil, errors.New("not initialized")
	}

	hosts := [...]*HostInfo{
		{connectAddress: net.IPv4(10, 0, 0, 0), tokens: []string{"00"}},
		{connectAddress: net.IPv4(10, 0, 0, 1), tokens: []string{"25"}},
		{connectAddress: net.IPv4(10, 0, 0, 2), tokens: []string{"50"}},
		{connectAddress: net.IPv4(10, 0, 0, 3), tokens: []string{"75"}},
	}
	for _, host := range hosts {
		policy.AddHost(host)
	}
	policy.SetPartitioner("OrderedPartitioner")

	query := &Query{routingInfo: &queryRoutingInfo{}}
	query.getKeyspace = func() string { return "myKeyspace" }
	query.RoutingKey([]byte("20"))

	iter := policy.Pick(query)
	next := iter()
	if next == nil {
		t.Fatal("got nil host")
	} else if v := next.Info(); v == nil {
		t.Fatal("got nil HostInfo")
	} else if !v.ConnectAddress().Equal(hosts[1].ConnectAddress()) {
		t.Fatalf("expected peer 1 got %v", v.ConnectAddress())
	}

	// Empty the hosts to trigger the panic when using the fallback.
	for _, host := range hosts {
		policy.RemoveHost(host)
	}

	next = iter()
	if next != nil {
		t.Errorf("expected to get nil host got %+v", next)
		if next.Info() == nil {
			t.Fatalf("HostInfo is nil")
		}
	}
}

// cowHostList.add must accept unseen hosts and report the full set via get().
func TestCOWList_Add(t *testing.T) {
	var cow cowHostList

	toAdd := [...]net.IP{net.IPv4(10, 0, 0, 1), net.IPv4(10, 0, 0, 2), net.IPv4(10, 0, 0, 3)}

	for _, addr := range toAdd {
		if !cow.add(&HostInfo{connectAddress: addr}) {
			t.Fatal("did not add peer which was not in the set")
		}
	}

	hosts := cow.get()
	if len(hosts) != len(toAdd) {
		t.Fatalf("expected to have %d hosts got %d", len(toAdd), len(hosts))
	}

	set := make(map[string]bool)
	for _, host := range hosts {
		set[string(host.ConnectAddress())] = true
	}

	for _, addr := range toAdd {
		if !set[string(addr)] {
			t.Errorf("addr was not in the host list: %q", addr)
		}
	}
}

// TestSimpleRetryPolicy makes sure that we only allow 1 + numRetries attempts
func TestSimpleRetryPolicy(t *testing.T) {
	q := &Query{routingInfo: &queryRoutingInfo{}}

	// this should allow a total of 3 tries.
	rt := &SimpleRetryPolicy{NumRetries: 2}

	cases := []struct {
		attempts int
		allow    bool
	}{
		{0, true},
		{1, true},
		{2, true},
		{3, false},
		{4, false},
		{5, false},
	}

	for _, c := range cases {
		q.metrics = preFilledQueryMetrics(map[string]*hostMetrics{"127.0.0.1": {Attempts: c.attempts}})
		if c.allow && !rt.Attempt(q) {
			t.Fatalf("should allow retry after %d attempts", c.attempts)
		}
		if !c.allow && rt.Attempt(q) {
			t.Fatalf("should not allow retry after %d attempts", c.attempts)
		}
	}
}

// napTime must stay within +/- half the base delay of the expected exponential value.
func TestExponentialBackoffPolicy(t *testing.T) {
	// test with defaults
	sut := &ExponentialBackoffRetryPolicy{NumRetries: 2}

	cases := []struct {
		attempts int
		delay    time.Duration
	}{
		{1, 100 * time.Millisecond},
		{2, (2) * 100 * time.Millisecond},
		{3, (2 * 2) * 100 * time.Millisecond},
		{4, (2 * 2 * 2) * 100 * time.Millisecond},
	}
	for _, c := range cases {
		// test 100 times for each case
		for i := 0; i < 100; i++ {
			d := sut.napTime(c.attempts)
			if d < c.delay-(100*time.Millisecond)/2 {
				t.Fatalf("Delay %d less than jitter min of %d", d, c.delay-100*time.Millisecond/2)
			}
			if d > c.delay+(100*time.Millisecond)/2 {
				t.Fatalf("Delay %d greater than jitter max of %d", d, c.delay+100*time.Millisecond/2)
			}
		}
	}
}

// Exercises both GetRetryType classification and the attempt budget.
func TestDowngradingConsistencyRetryPolicy(t *testing.T) {

	q := &Query{cons: LocalQuorum, routingInfo: &queryRoutingInfo{}}

	rewt0 := &RequestErrWriteTimeout{
		Received:  0,
		WriteType: "SIMPLE",
	}

	rewt1 := &RequestErrWriteTimeout{
		Received:  1,
		WriteType: "BATCH",
	}

	rewt2 := &RequestErrWriteTimeout{
		WriteType: "UNLOGGED_BATCH",
	}

	rert := &RequestErrReadTimeout{}

	reu0 := &RequestErrUnavailable{
		Alive: 0,
	}

	reu1 := &RequestErrUnavailable{
		Alive: 1,
	}

	// this should allow a total of 3 tries.
	consistencyLevels := []Consistency{Three, Two, One}
	rt := &DowngradingConsistencyRetryPolicy{ConsistencyLevelsToTry: consistencyLevels}
	cases := []struct {
		attempts  int
		allow     bool
		err       error
		retryType RetryType
	}{
		{0, true, rewt0, Rethrow},
		{3, true, rewt1, Ignore},
		{1, true, rewt2, Retry},
		{2, true, rert, Retry},
		{4, false, reu0, Rethrow},
		{16, false, reu1, Retry},
	}

	for _, c := range cases {
		q.metrics = preFilledQueryMetrics(map[string]*hostMetrics{"127.0.0.1": {Attempts: c.attempts}})
		if c.retryType != rt.GetRetryType(c.err) {
			t.Fatalf("retry type should be %v", c.retryType)
		}
		if c.allow && !rt.Attempt(q) {
			t.Fatalf("should allow retry after %d attempts", c.attempts)
		}
		if !c.allow && rt.Attempt(q) {
			t.Fatalf("should not allow retry after %d attempts", c.attempts)
		}
	}
}

// expectHosts makes sure that the next len(hostIDs) returned from iter is a permutation of hostIDs.
func expectHosts(t *testing.T, msg string, iter NextHost, hostIDs ...string) {
	t.Helper()

	expectedHostIDs := make(map[string]struct{}, len(hostIDs))
	for i := range hostIDs {
		expectedHostIDs[hostIDs[i]] = struct{}{}
	}

	expectedStr := func() string {
		keys := make([]string, 0, len(expectedHostIDs))
		for k := range expectedHostIDs {
			keys = append(keys, k)
		}
		sort.Strings(keys)
		return strings.Join(keys, ", ")
	}

	for len(expectedHostIDs) > 0 {
		host := iter()
		if host == nil || host.Info() == nil {
			t.Fatalf("%s: expected hostID one of {%s}, but got nil", msg, expectedStr())
		}
		hostID := host.Info().HostID()

		if _, ok := expectedHostIDs[hostID]; !ok {
			t.Fatalf("%s: expected host ID one of {%s}, but got %s", msg, expectedStr(), hostID)
		}
		delete(expectedHostIDs, hostID)
	}
}

// expectNoMoreHosts asserts that iter is exhausted.
func expectNoMoreHosts(t *testing.T, iter NextHost) {
	t.Helper()
	host := iter()
	if host == nil {
		// success
		return
	}
	info := host.Info()
	if info == nil {
		t.Fatalf("expected no more hosts, but got host with nil Info()")
		return
	}
	t.Fatalf("expected no more hosts, but got %s", info.HostID())
}

// All local-DC hosts must be yielded before any remote-DC host.
func TestHostPolicy_DCAwareRR(t *testing.T) {
	p := DCAwareRoundRobinPolicy("local")

	hosts := [...]*HostInfo{
		{hostId: "0", connectAddress: net.ParseIP("10.0.0.1"), dataCenter: "local"},
		{hostId: "1", connectAddress: net.ParseIP("10.0.0.2"), dataCenter: "local"},
		{hostId: "2", connectAddress: net.ParseIP("10.0.0.3"), dataCenter: "remote"},
		{hostId: "3", connectAddress: net.ParseIP("10.0.0.4"), dataCenter: "remote"},
	}

	for _, host := range hosts {
		p.AddHost(host)
	}

	got := make(map[string]bool, len(hosts))
	var dcs []string

	it := p.Pick(nil)
	for h := it(); h != nil; h = it() {
		id := h.Info().hostId
		dc := h.Info().dataCenter

		if got[id] {
			t.Fatalf("got duplicate host %s", id)
		}
		got[id] = true
		dcs = append(dcs, dc)
	}

	if len(got) != len(hosts) {
		t.Fatalf("expected %d hosts got %d", len(hosts), len(got))
	}

	var remote bool
	for _, dc := range dcs {
		if dc == "local" {
			if remote {
				t.Fatalf("got local dc after remote: %v", dcs)
			}
		} else {
			remote = true
		}
	}
}

// Tests of the token-aware host selection policy implementation with a
// DC aware round-robin host selection policy fallback
// with {"class": "NetworkTopologyStrategy", "a": 1, "b": 1, "c": 1} replication.
func TestHostPolicy_TokenAware(t *testing.T) {
	const keyspace = "myKeyspace"
	policy := TokenAwareHostPolicy(DCAwareRoundRobinPolicy("local"))
	policyInternal := policy.(*tokenAwareHostPolicy)
	policyInternal.getKeyspaceName = func() string { return keyspace }
	policyInternal.getKeyspaceMetadata = func(ks string) (*KeyspaceMetadata, error) {
		return nil, errors.New("not initialized")
	}

	query := &Query{routingInfo: &queryRoutingInfo{}}
	query.getKeyspace = func() string { return keyspace }

	// with no hosts the iterator must be empty but non-nil
	iter := policy.Pick(nil)
	if iter == nil {
		t.Fatal("host iterator was nil")
	}
	actual := iter()
	if actual != nil {
		t.Fatalf("expected nil from iterator, but was %v", actual)
	}

	// set the hosts
	hosts := [...]*HostInfo{
		{hostId: "0", connectAddress: net.IPv4(10, 0, 0, 1), tokens: []string{"05"}, dataCenter: "remote1"},
		{hostId: "1", connectAddress: net.IPv4(10, 0, 0, 2), tokens: []string{"10"}, dataCenter: "local"},
		{hostId: "2", connectAddress: net.IPv4(10, 0, 0, 3), tokens: []string{"15"}, dataCenter: "remote2"},
		{hostId: "3", connectAddress: net.IPv4(10, 0, 0, 4), tokens: []string{"20"}, dataCenter: "remote1"},
		{hostId: "4", connectAddress: net.IPv4(10, 0, 0, 5), tokens: []string{"25"}, dataCenter: "local"},
		{hostId: "5", connectAddress: net.IPv4(10, 0, 0, 6), tokens: []string{"30"}, dataCenter: "remote2"},
		{hostId: "6", connectAddress: net.IPv4(10, 0, 0, 7), tokens: []string{"35"}, dataCenter: "remote1"},
		{hostId: "7", connectAddress: net.IPv4(10, 0, 0, 8), tokens: []string{"40"}, dataCenter: "local"},
		{hostId: "8", connectAddress: net.IPv4(10, 0, 0, 9), tokens: []string{"45"}, dataCenter: "remote2"},
		{hostId: "9", connectAddress: net.IPv4(10, 0, 0, 10), tokens: []string{"50"}, dataCenter: "remote1"},
		{hostId: "10", connectAddress: net.IPv4(10, 0, 0, 11), tokens: []string{"55"}, dataCenter: "local"},
		{hostId: "11", connectAddress: net.IPv4(10, 0, 0, 12), tokens: []string{"60"}, dataCenter: "remote2"},
	}
	for _, host := range hosts {
		policy.AddHost(host)
	}

	// the token ring is not setup without the partitioner, but the fallback
	// should work
	if actual := policy.Pick(nil)(); actual == nil {
		t.Fatal("expected to get host from fallback got nil")
	}

	query.RoutingKey([]byte("30"))
	if actual := policy.Pick(query)(); actual == nil {
		t.Fatal("expected to get host from fallback got nil")
	}

	policy.SetPartitioner("OrderedPartitioner")

	policyInternal.getKeyspaceMetadata = func(keyspaceName string) (*KeyspaceMetadata, error) {
		if keyspaceName != keyspace {
			return nil, fmt.Errorf("unknown keyspace: %s", keyspaceName)
		}
		return &KeyspaceMetadata{
			Name:          keyspace,
			StrategyClass: "NetworkTopologyStrategy",
			StrategyOptions: map[string]interface{}{
				"class":   "NetworkTopologyStrategy",
				"local":   1,
				"remote1": 1,
				"remote2": 1,
			},
		}, nil
	}
	policy.KeyspaceChanged(KeyspaceUpdateEvent{Keyspace: "myKeyspace"})

	// The NetworkTopologyStrategy above should generate the following replicas.
	// It's handy to have as reference here.
	assertDeepEqual(t, "replicas", map[string]tokenRingReplicas{
		"myKeyspace": {
			{orderedToken("05"), []*HostInfo{hosts[0], hosts[1], hosts[2]}},
			{orderedToken("10"), []*HostInfo{hosts[1], hosts[2], hosts[3]}},
			{orderedToken("15"), []*HostInfo{hosts[2], hosts[3], hosts[4]}},
			{orderedToken("20"), []*HostInfo{hosts[3], hosts[4], hosts[5]}},
			{orderedToken("25"), []*HostInfo{hosts[4], hosts[5], hosts[6]}},
			{orderedToken("30"), []*HostInfo{hosts[5], hosts[6], hosts[7]}},
			{orderedToken("35"), []*HostInfo{hosts[6], hosts[7], hosts[8]}},
			{orderedToken("40"), []*HostInfo{hosts[7], hosts[8], hosts[9]}},
			{orderedToken("45"), []*HostInfo{hosts[8], hosts[9], hosts[10]}},
			{orderedToken("50"), []*HostInfo{hosts[9], hosts[10], hosts[11]}},
			{orderedToken("55"), []*HostInfo{hosts[10], hosts[11], hosts[0]}},
			{orderedToken("60"), []*HostInfo{hosts[11], hosts[0], hosts[1]}},
		},
	}, policyInternal.getMetadataReadOnly().replicas)

	// now the token ring is configured
	query.RoutingKey([]byte("23"))
	iter = policy.Pick(query)
	// first should be host with matching token from the local DC
	expectHosts(t, "matching token from local DC", iter, "4")
	// next are in non-deterministic order
	expectHosts(t, "rest", iter, "0", "1", "2", "3", "5", "6", "7", "8", "9", "10", "11")
	expectNoMoreHosts(t, iter)
}

// Tests of the token-aware host selection policy implementation with a
// DC aware round-robin host selection policy fallback
// with {"class": "NetworkTopologyStrategy", "a": 2, "b": 2, "c": 2} replication.
func TestHostPolicy_TokenAware_NetworkStrategy(t *testing.T) {
	const keyspace = "myKeyspace"
	policy := TokenAwareHostPolicy(DCAwareRoundRobinPolicy("local"), NonLocalReplicasFallback())
	policyInternal := policy.(*tokenAwareHostPolicy)
	policyInternal.getKeyspaceName = func() string { return keyspace }
	policyInternal.getKeyspaceMetadata = func(ks string) (*KeyspaceMetadata, error) {
		return nil, errors.New("not initialized")
	}

	query := &Query{routingInfo: &queryRoutingInfo{}}
	query.getKeyspace = func() string { return keyspace }

	// with no hosts the iterator must be empty but non-nil
	iter := policy.Pick(nil)
	if iter == nil {
		t.Fatal("host iterator was nil")
	}
	actual := iter()
	if actual != nil {
		t.Fatalf("expected nil from iterator, but was %v", actual)
	}

	// set the hosts
	hosts := [...]*HostInfo{
		{hostId: "0", connectAddress: net.IPv4(10, 0, 0, 1), tokens: []string{"05"}, dataCenter: "remote1"},
		{hostId: "1", connectAddress: net.IPv4(10, 0, 0, 2), tokens: []string{"10"}, dataCenter: "local"},
		{hostId: "2", connectAddress: net.IPv4(10, 0, 0, 3), tokens: []string{"15"}, dataCenter: "remote2"},
		{hostId: "3", connectAddress: net.IPv4(10, 0, 0, 4), tokens: []string{"20"}, dataCenter: "remote1"}, // 1
		{hostId: "4", connectAddress: net.IPv4(10, 0, 0, 5), tokens: []string{"25"}, dataCenter: "local"},   // 2
		{hostId: "5", connectAddress: net.IPv4(10, 0, 0, 6), tokens: []string{"30"}, dataCenter: "remote2"}, // 3
		{hostId: "6", connectAddress: net.IPv4(10, 0, 0, 7), tokens: []string{"35"}, dataCenter: "remote1"}, // 4
		{hostId: "7", connectAddress: net.IPv4(10, 0, 0, 8), tokens: []string{"40"}, dataCenter: "local"},   // 5
		{hostId: "8", connectAddress: net.IPv4(10, 0, 0, 9), tokens: []string{"45"}, dataCenter: "remote2"}, // 6
		{hostId: "9", connectAddress: net.IPv4(10, 0, 0, 10), tokens: []string{"50"}, dataCenter: "remote1"},
		{hostId: "10", connectAddress: net.IPv4(10, 0, 0, 11), tokens: []string{"55"}, dataCenter: "local"},
		{hostId: "11", connectAddress: net.IPv4(10, 0, 0, 12), tokens: []string{"60"}, dataCenter: "remote2"},
	}
	for _, host := range hosts {
		policy.AddHost(host)
	}

	policy.SetPartitioner("OrderedPartitioner")

	policyInternal.getKeyspaceMetadata = func(keyspaceName string) (*KeyspaceMetadata, error) {
		if keyspaceName != keyspace {
			return nil, fmt.Errorf("unknown keyspace: %s", keyspaceName)
		}
		return &KeyspaceMetadata{
			Name:          keyspace,
			StrategyClass: "NetworkTopologyStrategy",
			StrategyOptions: map[string]interface{}{
				"class":   "NetworkTopologyStrategy",
				"local":   2,
				"remote1": 2,
				"remote2": 2,
			},
		}, nil
	}
	policy.KeyspaceChanged(KeyspaceUpdateEvent{Keyspace: keyspace})

	// The NetworkTopologyStrategy above should generate the following replicas.
	// It's handy to have as reference here.
	assertDeepEqual(t, "replicas", map[string]tokenRingReplicas{
		keyspace: {
			{orderedToken("05"), []*HostInfo{hosts[0], hosts[1], hosts[2], hosts[3], hosts[4], hosts[5]}},
			{orderedToken("10"), []*HostInfo{hosts[1], hosts[2], hosts[3], hosts[4], hosts[5], hosts[6]}},
			{orderedToken("15"), []*HostInfo{hosts[2], hosts[3], hosts[4], hosts[5], hosts[6], hosts[7]}},
			{orderedToken("20"), []*HostInfo{hosts[3], hosts[4], hosts[5], hosts[6], hosts[7], hosts[8]}},
			{orderedToken("25"), []*HostInfo{hosts[4], hosts[5], hosts[6], hosts[7], hosts[8], hosts[9]}},
			{orderedToken("30"), []*HostInfo{hosts[5], hosts[6], hosts[7], hosts[8], hosts[9], hosts[10]}},
			{orderedToken("35"), []*HostInfo{hosts[6], hosts[7], hosts[8], hosts[9], hosts[10], hosts[11]}},
			{orderedToken("40"), []*HostInfo{hosts[7], hosts[8], hosts[9], hosts[10], hosts[11], hosts[0]}},
			{orderedToken("45"), []*HostInfo{hosts[8], hosts[9], hosts[10], hosts[11], hosts[0], hosts[1]}},
			{orderedToken("50"), []*HostInfo{hosts[9], hosts[10], hosts[11], hosts[0], hosts[1], hosts[2]}},
			{orderedToken("55"), []*HostInfo{hosts[10], hosts[11], hosts[0], hosts[1], hosts[2], hosts[3]}},
			{orderedToken("60"), []*HostInfo{hosts[11], hosts[0], hosts[1], hosts[2], hosts[3], hosts[4]}},
		},
	}, policyInternal.getMetadataReadOnly().replicas)

	// now the token ring is configured
	query.RoutingKey([]byte("18"))
	iter = policy.Pick(query)
	// first should be hosts with matching token from the local DC
	expectHosts(t, "matching token from local DC", iter, "4", "7")
	// rest should be hosts with matching token from remote DCs
	expectHosts(t, "matching token from remote DCs", iter, "3", "5", "6", "8")
	// followed by other hosts
	expectHosts(t, "rest", iter, "0", "1", "2", "9", "10", "11")
	expectNoMoreHosts(t, iter)
}

// Rack-local hosts must be yielded before DC-local, before remote.
func TestHostPolicy_RackAwareRR(t *testing.T) {
	p := RackAwareRoundRobinPolicy("local", "b")

	hosts := [...]*HostInfo{
		{hostId: "0", connectAddress: net.ParseIP("10.0.0.1"), dataCenter: "local", rack: "a"},
		{hostId: "1", connectAddress: 
net.ParseIP("10.0.0.2"), dataCenter: "local", rack: "a"}, {hostId: "2", connectAddress: net.ParseIP("10.0.0.3"), dataCenter: "local", rack: "b"}, {hostId: "3", connectAddress: net.ParseIP("10.0.0.4"), dataCenter: "local", rack: "b"}, {hostId: "4", connectAddress: net.ParseIP("10.0.0.5"), dataCenter: "remote", rack: "a"}, {hostId: "5", connectAddress: net.ParseIP("10.0.0.6"), dataCenter: "remote", rack: "a"}, {hostId: "6", connectAddress: net.ParseIP("10.0.0.7"), dataCenter: "remote", rack: "b"}, {hostId: "7", connectAddress: net.ParseIP("10.0.0.8"), dataCenter: "remote", rack: "b"}, } for _, host := range hosts { p.AddHost(host) } it := p.Pick(nil) // Must start with rack-local hosts expectHosts(t, "rack-local hosts", it, "3", "2") // Then dc-local hosts expectHosts(t, "dc-local hosts", it, "0", "1") // Then the remote hosts expectHosts(t, "remote hosts", it, "4", "5", "6", "7") expectNoMoreHosts(t, it) } // Tests of the token-aware host selection policy implementation with a // DC & Rack aware round-robin host selection policy fallback func TestHostPolicy_TokenAware_RackAware(t *testing.T) { const keyspace = "myKeyspace" policy := TokenAwareHostPolicy(RackAwareRoundRobinPolicy("local", "b")) policyWithFallback := TokenAwareHostPolicy(RackAwareRoundRobinPolicy("local", "b"), NonLocalReplicasFallback()) policyInternal := policy.(*tokenAwareHostPolicy) policyInternal.getKeyspaceName = func() string { return keyspace } policyInternal.getKeyspaceMetadata = func(ks string) (*KeyspaceMetadata, error) { return nil, errors.New("not initialized") } policyWithFallbackInternal := policyWithFallback.(*tokenAwareHostPolicy) policyWithFallbackInternal.getKeyspaceName = policyInternal.getKeyspaceName policyWithFallbackInternal.getKeyspaceMetadata = policyInternal.getKeyspaceMetadata query := &Query{routingInfo: &queryRoutingInfo{}} query.getKeyspace = func() string { return keyspace } iter := policy.Pick(nil) if iter == nil { t.Fatal("host iterator was nil") } actual := iter() if 
actual != nil { t.Fatalf("expected nil from iterator, but was %v", actual) } // set the hosts hosts := [...]*HostInfo{ {hostId: "0", connectAddress: net.IPv4(10, 0, 0, 1), tokens: []string{"05"}, dataCenter: "remote", rack: "a"}, {hostId: "1", connectAddress: net.IPv4(10, 0, 0, 2), tokens: []string{"10"}, dataCenter: "remote", rack: "b"}, {hostId: "2", connectAddress: net.IPv4(10, 0, 0, 3), tokens: []string{"15"}, dataCenter: "local", rack: "a"}, {hostId: "3", connectAddress: net.IPv4(10, 0, 0, 4), tokens: []string{"20"}, dataCenter: "local", rack: "b"}, {hostId: "4", connectAddress: net.IPv4(10, 0, 0, 5), tokens: []string{"25"}, dataCenter: "remote", rack: "a"}, {hostId: "5", connectAddress: net.IPv4(10, 0, 0, 6), tokens: []string{"30"}, dataCenter: "remote", rack: "b"}, {hostId: "6", connectAddress: net.IPv4(10, 0, 0, 7), tokens: []string{"35"}, dataCenter: "local", rack: "a"}, {hostId: "7", connectAddress: net.IPv4(10, 0, 0, 8), tokens: []string{"40"}, dataCenter: "local", rack: "b"}, {hostId: "8", connectAddress: net.IPv4(10, 0, 0, 9), tokens: []string{"45"}, dataCenter: "remote", rack: "a"}, {hostId: "9", connectAddress: net.IPv4(10, 0, 0, 10), tokens: []string{"50"}, dataCenter: "remote", rack: "b"}, {hostId: "10", connectAddress: net.IPv4(10, 0, 0, 11), tokens: []string{"55"}, dataCenter: "local", rack: "a"}, {hostId: "11", connectAddress: net.IPv4(10, 0, 0, 12), tokens: []string{"60"}, dataCenter: "local", rack: "b"}, } for _, host := range hosts { policy.AddHost(host) policyWithFallback.AddHost(host) } // the token ring is not setup without the partitioner, but the fallback // should work if actual := policy.Pick(nil)(); actual == nil { t.Fatal("expected to get host from fallback got nil") } query.RoutingKey([]byte("30")) if actual := policy.Pick(query)(); actual == nil { t.Fatal("expected to get host from fallback got nil") } policy.SetPartitioner("OrderedPartitioner") policyWithFallback.SetPartitioner("OrderedPartitioner") 
policyInternal.getKeyspaceMetadata = func(keyspaceName string) (*KeyspaceMetadata, error) { if keyspaceName != keyspace { return nil, fmt.Errorf("unknown keyspace: %s", keyspaceName) } return &KeyspaceMetadata{ Name: keyspace, StrategyClass: "NetworkTopologyStrategy", StrategyOptions: map[string]interface{}{ "class": "NetworkTopologyStrategy", "local": 2, "remote": 2, }, }, nil } policyWithFallbackInternal.getKeyspaceMetadata = policyInternal.getKeyspaceMetadata policy.KeyspaceChanged(KeyspaceUpdateEvent{Keyspace: "myKeyspace"}) policyWithFallback.KeyspaceChanged(KeyspaceUpdateEvent{Keyspace: "myKeyspace"}) // The NetworkTopologyStrategy above should generate the following replicas. // It's handy to have as reference here. assertDeepEqual(t, "replicas", map[string]tokenRingReplicas{ "myKeyspace": { {orderedToken("05"), []*HostInfo{hosts[0], hosts[1], hosts[2], hosts[3]}}, {orderedToken("10"), []*HostInfo{hosts[1], hosts[2], hosts[3], hosts[4]}}, {orderedToken("15"), []*HostInfo{hosts[2], hosts[3], hosts[4], hosts[5]}}, {orderedToken("20"), []*HostInfo{hosts[3], hosts[4], hosts[5], hosts[6]}}, {orderedToken("25"), []*HostInfo{hosts[4], hosts[5], hosts[6], hosts[7]}}, {orderedToken("30"), []*HostInfo{hosts[5], hosts[6], hosts[7], hosts[8]}}, {orderedToken("35"), []*HostInfo{hosts[6], hosts[7], hosts[8], hosts[9]}}, {orderedToken("40"), []*HostInfo{hosts[7], hosts[8], hosts[9], hosts[10]}}, {orderedToken("45"), []*HostInfo{hosts[8], hosts[9], hosts[10], hosts[11]}}, {orderedToken("50"), []*HostInfo{hosts[9], hosts[10], hosts[11], hosts[0]}}, {orderedToken("55"), []*HostInfo{hosts[10], hosts[11], hosts[0], hosts[1]}}, {orderedToken("60"), []*HostInfo{hosts[11], hosts[0], hosts[1], hosts[2]}}, }, }, policyInternal.getMetadataReadOnly().replicas) query.RoutingKey([]byte("23")) // now the token ring is configured // Test the policy with fallback iter = policyWithFallback.Pick(query) // first should be host with matching token from the local DC & rack expectHosts(t, 
"matching token from local DC and local rack", iter, "7") // next should be host with matching token from local DC and other rack expectHosts(t, "matching token from local DC and non-local rack", iter, "6") // next should be hosts with matching token from other DC, in any order expectHosts(t, "matching token from non-local DC", iter, "4", "5") // then the local DC & rack that didn't match the token expectHosts(t, "non-matching token from local DC and local rack", iter, "3", "11") // then the local DC & other rack that didn't match the token expectHosts(t, "non-matching token from local DC and non-local rack", iter, "2", "10") // finally, the other DC that didn't match the token expectHosts(t, "non-matching token from non-local DC", iter, "0", "1", "8", "9") expectNoMoreHosts(t, iter) // Test the policy without fallback iter = policy.Pick(query) // first should be host with matching token from the local DC & Rack expectHosts(t, "matching token from local DC and local rack", iter, "7") // next should be the other two hosts from local DC & rack expectHosts(t, "non-matching token local DC and local rack", iter, "3", "11") // then the three hosts from the local DC but other rack expectHosts(t, "local DC, non-local rack", iter, "2", "6", "10") // then the 6 hosts from the other DC expectHosts(t, "non-local DC", iter, "0", "1", "4", "5", "8", "9") expectNoMoreHosts(t, iter) } cassandra-gocql-driver-1.7.0/prepared_cache.go000066400000000000000000000045161467504044300213430ustar00rootroot00000000000000/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. */ package gocql import ( "bytes" "sync" "github.com/gocql/gocql/internal/lru" ) const defaultMaxPreparedStmts = 1000 // preparedLRU is the prepared statement cache type preparedLRU struct { mu sync.Mutex lru *lru.Cache } func (p *preparedLRU) clear() { p.mu.Lock() defer p.mu.Unlock() for p.lru.Len() > 0 { p.lru.RemoveOldest() } } func (p *preparedLRU) add(key string, val *inflightPrepare) { p.mu.Lock() defer p.mu.Unlock() p.lru.Add(key, val) } func (p *preparedLRU) remove(key string) bool { p.mu.Lock() defer p.mu.Unlock() return p.lru.Remove(key) } func (p *preparedLRU) execIfMissing(key string, fn func(lru *lru.Cache) *inflightPrepare) (*inflightPrepare, bool) { p.mu.Lock() defer p.mu.Unlock() val, ok := p.lru.Get(key) if ok { return val.(*inflightPrepare), true } return fn(p.lru), false } func (p *preparedLRU) keyFor(hostID, keyspace, statement string) string { // TODO: we should just use a struct for the key in the map return hostID + keyspace + statement } func (p *preparedLRU) evictPreparedID(key string, id []byte) { p.mu.Lock() defer p.mu.Unlock() val, ok := p.lru.Get(key) if !ok { return } ifp, ok := val.(*inflightPrepare) if !ok { return } select { case <-ifp.done: if bytes.Equal(id, ifp.preparedStatment.id) { p.lru.Remove(key) } default: } } 
cassandra-gocql-driver-1.7.0/query_executor.go000066400000000000000000000132661467504044300215030ustar00rootroot00000000000000/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. */ package gocql import ( "context" "sync" "time" ) type ExecutableQuery interface { borrowForExecution() // Used to ensure that the query stays alive for lifetime of a particular execution goroutine. releaseAfterExecution() // Used when a goroutine finishes its execution attempts, either with ok result or an error. 
execute(ctx context.Context, conn *Conn) *Iter attempt(keyspace string, end, start time.Time, iter *Iter, host *HostInfo) retryPolicy() RetryPolicy speculativeExecutionPolicy() SpeculativeExecutionPolicy GetRoutingKey() ([]byte, error) Keyspace() string Table() string IsIdempotent() bool withContext(context.Context) ExecutableQuery RetryableQuery } type queryExecutor struct { pool *policyConnPool policy HostSelectionPolicy } func (q *queryExecutor) attemptQuery(ctx context.Context, qry ExecutableQuery, conn *Conn) *Iter { start := time.Now() iter := qry.execute(ctx, conn) end := time.Now() qry.attempt(q.pool.keyspace, end, start, iter, conn.host) return iter } func (q *queryExecutor) speculate(ctx context.Context, qry ExecutableQuery, sp SpeculativeExecutionPolicy, hostIter NextHost, results chan *Iter) *Iter { ticker := time.NewTicker(sp.Delay()) defer ticker.Stop() for i := 0; i < sp.Attempts(); i++ { select { case <-ticker.C: qry.borrowForExecution() // ensure liveness in case of executing Query to prevent races with Query.Release(). go q.run(ctx, qry, hostIter, results) case <-ctx.Done(): return &Iter{err: ctx.Err()} case iter := <-results: return iter } } return nil } func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) { hostIter := q.policy.Pick(qry) // check if the query is not marked as idempotent, if // it is, we force the policy to NonSpeculative sp := qry.speculativeExecutionPolicy() if !qry.IsIdempotent() || sp.Attempts() == 0 { return q.do(qry.Context(), qry, hostIter), nil } // When speculative execution is enabled, we could be accessing the host iterator from multiple goroutines below. // To ensure we don't call it concurrently, we wrap the returned NextHost function here to synchronize access to it. 
var mu sync.Mutex origHostIter := hostIter hostIter = func() SelectedHost { mu.Lock() defer mu.Unlock() return origHostIter() } ctx, cancel := context.WithCancel(qry.Context()) defer cancel() results := make(chan *Iter, 1) // Launch the main execution qry.borrowForExecution() // ensure liveness in case of executing Query to prevent races with Query.Release(). go q.run(ctx, qry, hostIter, results) // The speculative executions are launched _in addition_ to the main // execution, on a timer. So Speculation{2} would make 3 executions running // in total. if iter := q.speculate(ctx, qry, sp, hostIter, results); iter != nil { return iter, nil } select { case iter := <-results: return iter, nil case <-ctx.Done(): return &Iter{err: ctx.Err()}, nil } } func (q *queryExecutor) do(ctx context.Context, qry ExecutableQuery, hostIter NextHost) *Iter { selectedHost := hostIter() rt := qry.retryPolicy() var lastErr error var iter *Iter for selectedHost != nil { host := selectedHost.Info() if host == nil || !host.IsUp() { selectedHost = hostIter() continue } pool, ok := q.pool.getPool(host) if !ok { selectedHost = hostIter() continue } conn := pool.Pick() if conn == nil { selectedHost = hostIter() continue } iter = q.attemptQuery(ctx, qry, conn) iter.host = selectedHost.Info() // Update host switch iter.err { case context.Canceled, context.DeadlineExceeded, ErrNotFound: // those errors represents logical errors, they should not count // toward removing a node from the pool selectedHost.Mark(nil) return iter default: selectedHost.Mark(iter.err) } // Exit if the query was successful // or no retry policy defined or retry attempts were reached if iter.err == nil || rt == nil || !rt.Attempt(qry) { return iter } lastErr = iter.err // If query is unsuccessful, check the error with RetryPolicy to retry switch rt.GetRetryType(iter.err) { case Retry: // retry on the same host continue case Rethrow, Ignore: return iter case RetryNextHost: // retry on the next host selectedHost = hostIter() 
continue default: // Undefined? Return nil and error, this will panic in the requester return &Iter{err: ErrUnknownRetryType} } } if lastErr != nil { return &Iter{err: lastErr} } return &Iter{err: ErrNoConnections} } func (q *queryExecutor) run(ctx context.Context, qry ExecutableQuery, hostIter NextHost, results chan<- *Iter) { select { case results <- q.do(ctx, qry, hostIter): case <-ctx.Done(): } qry.releaseAfterExecution() } cassandra-gocql-driver-1.7.0/ring.go000066400000000000000000000077271467504044300173640ustar00rootroot00000000000000/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. */ package gocql import ( "fmt" "sync" "sync/atomic" ) type ring struct { // endpoints are the set of endpoints which the driver will attempt to connect // to in the case it can not reach any of its hosts. They are also used to boot // strap the initial connection. endpoints []*HostInfo mu sync.RWMutex // hosts are the set of all hosts in the cassandra ring that we know of. // key of map is host_id. 
hosts map[string]*HostInfo // hostIPToUUID maps host native address to host_id. hostIPToUUID map[string]string hostList []*HostInfo pos uint32 // TODO: we should store the ring metadata here also. } func (r *ring) rrHost() *HostInfo { r.mu.RLock() defer r.mu.RUnlock() if len(r.hostList) == 0 { return nil } pos := int(atomic.AddUint32(&r.pos, 1) - 1) return r.hostList[pos%len(r.hostList)] } func (r *ring) getHostByIP(ip string) (*HostInfo, bool) { r.mu.RLock() defer r.mu.RUnlock() hi, ok := r.hostIPToUUID[ip] return r.hosts[hi], ok } func (r *ring) getHost(hostID string) *HostInfo { r.mu.RLock() host := r.hosts[hostID] r.mu.RUnlock() return host } func (r *ring) allHosts() []*HostInfo { r.mu.RLock() hosts := make([]*HostInfo, 0, len(r.hosts)) for _, host := range r.hosts { hosts = append(hosts, host) } r.mu.RUnlock() return hosts } func (r *ring) currentHosts() map[string]*HostInfo { r.mu.RLock() hosts := make(map[string]*HostInfo, len(r.hosts)) for k, v := range r.hosts { hosts[k] = v } r.mu.RUnlock() return hosts } func (r *ring) addOrUpdate(host *HostInfo) *HostInfo { if existingHost, ok := r.addHostIfMissing(host); ok { existingHost.update(host) host = existingHost } return host } func (r *ring) addHostIfMissing(host *HostInfo) (*HostInfo, bool) { if host.invalidConnectAddr() { panic(fmt.Sprintf("invalid host: %v", host)) } hostID := host.HostID() r.mu.Lock() if r.hosts == nil { r.hosts = make(map[string]*HostInfo) } if r.hostIPToUUID == nil { r.hostIPToUUID = make(map[string]string) } existing, ok := r.hosts[hostID] if !ok { r.hosts[hostID] = host r.hostIPToUUID[host.nodeToNodeAddress().String()] = hostID existing = host r.hostList = append(r.hostList, host) } r.mu.Unlock() return existing, ok } func (r *ring) removeHost(hostID string) bool { r.mu.Lock() if r.hosts == nil { r.hosts = make(map[string]*HostInfo) } if r.hostIPToUUID == nil { r.hostIPToUUID = make(map[string]string) } h, ok := r.hosts[hostID] if ok { for i, host := range r.hostList { if 
host.HostID() == hostID { r.hostList = append(r.hostList[:i], r.hostList[i+1:]...) break } } delete(r.hostIPToUUID, h.nodeToNodeAddress().String()) } delete(r.hosts, hostID) r.mu.Unlock() return ok } type clusterMetadata struct { mu sync.RWMutex partitioner string } func (c *clusterMetadata) setPartitioner(partitioner string) { c.mu.Lock() defer c.mu.Unlock() if c.partitioner != partitioner { // TODO: update other things now c.partitioner = partitioner } } cassandra-gocql-driver-1.7.0/ring_test.go000066400000000000000000000040661467504044300204140ustar00rootroot00000000000000/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. 
*/ package gocql import ( "net" "testing" ) func TestRing_AddHostIfMissing_Missing(t *testing.T) { ring := &ring{} host := &HostInfo{hostId: MustRandomUUID().String(), connectAddress: net.IPv4(1, 1, 1, 1)} h1, ok := ring.addHostIfMissing(host) if ok { t.Fatal("host was reported as already existing") } else if !h1.Equal(host) { t.Fatalf("hosts not equal that are returned %v != %v", h1, host) } else if h1 != host { t.Fatalf("returned host same pointer: %p != %p", h1, host) } } func TestRing_AddHostIfMissing_Existing(t *testing.T) { ring := &ring{} host := &HostInfo{hostId: MustRandomUUID().String(), connectAddress: net.IPv4(1, 1, 1, 1)} ring.addHostIfMissing(host) h2 := &HostInfo{hostId: host.hostId, connectAddress: net.IPv4(2, 2, 2, 2)} h1, ok := ring.addHostIfMissing(h2) if !ok { t.Fatal("host was not reported as already existing") } else if !h1.Equal(host) { t.Fatalf("hosts not equal that are returned %v != %v", h1, host) } else if h1 != host { t.Fatalf("returned host same pointer: %p != %p", h1, host) } } cassandra-gocql-driver-1.7.0/session.go000066400000000000000000001774111467504044300201060ustar00rootroot00000000000000/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2012, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. */ package gocql import ( "bytes" "context" "encoding/binary" "errors" "fmt" "io" "net" "strings" "sync" "sync/atomic" "time" "unicode" "github.com/gocql/gocql/internal/lru" ) // Session is the interface used by users to interact with the database. // // It's safe for concurrent use by multiple goroutines and a typical usage // scenario is to have one global session object to interact with the // whole Cassandra cluster. // // This type extends the Node interface by adding a convenient query builder // and automatically sets a default consistency level on all operations // that do not have a consistency level set. type Session struct { cons Consistency pageSize int prefetch float64 routingKeyInfoCache routingKeyInfoLRU schemaDescriber *schemaDescriber trace Tracer queryObserver QueryObserver batchObserver BatchObserver connectObserver ConnectObserver frameObserver FrameHeaderObserver streamObserver StreamObserver hostSource *ringDescriber ringRefresher *refreshDebouncer stmtsLRU *preparedLRU connCfg *ConnConfig executor *queryExecutor pool *policyConnPool policy HostSelectionPolicy ring ring metadata clusterMetadata mu sync.RWMutex control *controlConn // event handlers nodeEvents *eventDebouncer schemaEvents *eventDebouncer // ring metadata useSystemSchema bool hasAggregatesAndFunctions bool cfg ClusterConfig ctx context.Context cancel context.CancelFunc // sessionStateMu protects isClosed and isInitialized. sessionStateMu sync.RWMutex // isClosed is true once Session.Close is finished. isClosed bool // isClosing bool is true once Session.Close is started. isClosing bool // isInitialized is true once Session.init succeeds. // you can use initialized() to read the value. 
isInitialized bool logger StdLogger } var queryPool = &sync.Pool{ New: func() interface{} { return &Query{routingInfo: &queryRoutingInfo{}, refCount: 1} }, } func addrsToHosts(addrs []string, defaultPort int, logger StdLogger) ([]*HostInfo, error) { var hosts []*HostInfo for _, hostaddr := range addrs { resolvedHosts, err := hostInfo(hostaddr, defaultPort) if err != nil { // Try other hosts if unable to resolve DNS name if _, ok := err.(*net.DNSError); ok { logger.Printf("gocql: dns error: %v\n", err) continue } return nil, err } hosts = append(hosts, resolvedHosts...) } if len(hosts) == 0 { return nil, errors.New("failed to resolve any of the provided hostnames") } return hosts, nil } // NewSession wraps an existing Node. func NewSession(cfg ClusterConfig) (*Session, error) { // Check that hosts in the ClusterConfig is not empty if len(cfg.Hosts) < 1 { return nil, ErrNoHosts } // Check that either Authenticator is set or AuthProvider, not both if cfg.Authenticator != nil && cfg.AuthProvider != nil { return nil, errors.New("Can't use both Authenticator and AuthProvider in cluster config.") } // TODO: we should take a context in here at some point ctx, cancel := context.WithCancel(context.TODO()) s := &Session{ cons: cfg.Consistency, prefetch: 0.25, cfg: cfg, pageSize: cfg.PageSize, stmtsLRU: &preparedLRU{lru: lru.New(cfg.MaxPreparedStmts)}, connectObserver: cfg.ConnectObserver, ctx: ctx, cancel: cancel, logger: cfg.logger(), } s.schemaDescriber = newSchemaDescriber(s) s.nodeEvents = newEventDebouncer("NodeEvents", s.handleNodeEvent, s.logger) s.schemaEvents = newEventDebouncer("SchemaEvents", s.handleSchemaEvent, s.logger) s.routingKeyInfoCache.lru = lru.New(cfg.MaxRoutingKeyInfo) s.hostSource = &ringDescriber{session: s} s.ringRefresher = newRefreshDebouncer(ringRefreshDebounceTime, func() error { return refreshRing(s.hostSource) }) if cfg.PoolConfig.HostSelectionPolicy == nil { cfg.PoolConfig.HostSelectionPolicy = RoundRobinHostPolicy() } s.pool = 
cfg.PoolConfig.buildPool(s) s.policy = cfg.PoolConfig.HostSelectionPolicy s.policy.Init(s) s.executor = &queryExecutor{ pool: s.pool, policy: cfg.PoolConfig.HostSelectionPolicy, } s.queryObserver = cfg.QueryObserver s.batchObserver = cfg.BatchObserver s.connectObserver = cfg.ConnectObserver s.frameObserver = cfg.FrameHeaderObserver s.streamObserver = cfg.StreamObserver //Check the TLS Config before trying to connect to anything external connCfg, err := connConfig(&s.cfg) if err != nil { //TODO: Return a typed error return nil, fmt.Errorf("gocql: unable to create session: %v", err) } s.connCfg = connCfg if err := s.init(); err != nil { s.Close() if err == ErrNoConnectionsStarted { //This error used to be generated inside NewSession & returned directly //Forward it on up to be backwards compatible return nil, ErrNoConnectionsStarted } else { // TODO(zariel): dont wrap this error in fmt.Errorf, return a typed error return nil, fmt.Errorf("gocql: unable to create session: %v", err) } } return s, nil } func (s *Session) init() error { hosts, err := addrsToHosts(s.cfg.Hosts, s.cfg.Port, s.logger) if err != nil { return err } s.ring.endpoints = hosts if !s.cfg.disableControlConn { s.control = createControlConn(s) if s.cfg.ProtoVersion == 0 { proto, err := s.control.discoverProtocol(hosts) if err != nil { return fmt.Errorf("unable to discover protocol version: %v", err) } else if proto == 0 { return errors.New("unable to discovery protocol version") } // TODO(zariel): we really only need this in 1 place s.cfg.ProtoVersion = proto s.connCfg.ProtoVersion = proto } if err := s.control.connect(hosts); err != nil { return err } if !s.cfg.DisableInitialHostLookup { var partitioner string newHosts, partitioner, err := s.hostSource.GetHosts() if err != nil { return err } s.policy.SetPartitioner(partitioner) filteredHosts := make([]*HostInfo, 0, len(newHosts)) for _, host := range newHosts { if !s.cfg.filterHost(host) { filteredHosts = append(filteredHosts, host) } } hosts = 
filteredHosts } } for _, host := range hosts { // In case when host lookup is disabled and when we are in unit tests, // host are not discovered, and we are missing host ID information used // by internal logic. // Associate random UUIDs here with all hosts missing this information. if len(host.HostID()) == 0 { host.SetHostID(MustRandomUUID().String()) } } hostMap := make(map[string]*HostInfo, len(hosts)) for _, host := range hosts { hostMap[host.HostID()] = host } hosts = hosts[:0] // each host will increment left and decrement it after connecting and once // there's none left, we'll close hostCh var left int64 // we will receive up to len(hostMap) of messages so create a buffer so we // don't end up stuck in a goroutine if we stopped listening connectedCh := make(chan struct{}, len(hostMap)) // we add one here because we don't want to end up closing hostCh until we're // done looping and the decerement code might be reached before we've looped // again atomic.AddInt64(&left, 1) for _, host := range hostMap { host := s.ring.addOrUpdate(host) if s.cfg.filterHost(host) { continue } atomic.AddInt64(&left, 1) go func() { s.pool.addHost(host) connectedCh <- struct{}{} // if there are no hosts left, then close the hostCh to unblock the loop // below if its still waiting if atomic.AddInt64(&left, -1) == 0 { close(connectedCh) } }() hosts = append(hosts, host) } // once we're done looping we subtract the one we initially added and check // to see if we should close if atomic.AddInt64(&left, -1) == 0 { close(connectedCh) } // before waiting for them to connect, add them all to the policy so we can // utilize efficiencies by calling AddHosts if the policy supports it type bulkAddHosts interface { AddHosts([]*HostInfo) } if v, ok := s.policy.(bulkAddHosts); ok { v.AddHosts(hosts) } else { for _, host := range hosts { s.policy.AddHost(host) } } readyPolicy, _ := s.policy.(ReadyPolicy) // now loop over connectedCh until it's closed (meaning we've connected to all) // or until 
the policy says we're ready for range connectedCh { if readyPolicy != nil && readyPolicy.Ready() { break } } // TODO(zariel): we probably dont need this any more as we verify that we // can connect to one of the endpoints supplied by using the control conn. // See if there are any connections in the pool if s.cfg.ReconnectInterval > 0 { go s.reconnectDownedHosts(s.cfg.ReconnectInterval) } // If we disable the initial host lookup, we need to still check if the // cluster is using the newer system schema or not... however, if control // connection is disable, we really have no choice, so we just make our // best guess... if !s.cfg.disableControlConn && s.cfg.DisableInitialHostLookup { newer, _ := checkSystemSchema(s.control) s.useSystemSchema = newer } else { version := s.ring.rrHost().Version() s.useSystemSchema = version.AtLeast(3, 0, 0) s.hasAggregatesAndFunctions = version.AtLeast(2, 2, 0) } if s.pool.Size() == 0 { return ErrNoConnectionsStarted } // Invoke KeyspaceChanged to let the policy cache the session keyspace // parameters. This is used by tokenAwareHostPolicy to discover replicas. if !s.cfg.disableControlConn && s.cfg.Keyspace != "" { s.policy.KeyspaceChanged(KeyspaceUpdateEvent{Keyspace: s.cfg.Keyspace}) } s.sessionStateMu.Lock() s.isInitialized = true s.sessionStateMu.Unlock() return nil } // AwaitSchemaAgreement will wait until schema versions across all nodes in the // cluster are the same (as seen from the point of view of the control connection). // The maximum amount of time this takes is governed // by the MaxWaitSchemaAgreement setting in the configuration (default: 60s). // AwaitSchemaAgreement returns an error in case schema versions are not the same // after the timeout specified in MaxWaitSchemaAgreement elapses. 
func (s *Session) AwaitSchemaAgreement(ctx context.Context) error { if s.cfg.disableControlConn { return errNoControl } return s.control.withConn(func(conn *Conn) *Iter { return &Iter{err: conn.awaitSchemaAgreement(ctx)} }).err } func (s *Session) reconnectDownedHosts(intv time.Duration) { reconnectTicker := time.NewTicker(intv) defer reconnectTicker.Stop() for { select { case <-reconnectTicker.C: hosts := s.ring.allHosts() // Print session.ring for debug. if gocqlDebug { buf := bytes.NewBufferString("Session.ring:") for _, h := range hosts { buf.WriteString("[" + h.ConnectAddress().String() + ":" + h.State().String() + "]") } s.logger.Println(buf.String()) } for _, h := range hosts { if h.IsUp() { continue } // we let the pool call handleNodeConnected to change the host state s.pool.addHost(h) } case <-s.ctx.Done(): return } } } // SetConsistency sets the default consistency level for this session. This // setting can also be changed on a per-query basis and the default value // is Quorum. func (s *Session) SetConsistency(cons Consistency) { s.mu.Lock() s.cons = cons s.mu.Unlock() } // SetPageSize sets the default page size for this session. A value <= 0 will // disable paging. This setting can also be changed on a per-query basis. func (s *Session) SetPageSize(n int) { s.mu.Lock() s.pageSize = n s.mu.Unlock() } // SetPrefetch sets the default threshold for pre-fetching new pages. If // there are only p*pageSize rows remaining, the next page will be requested // automatically. This value can also be changed on a per-query basis and // the default value is 0.25. func (s *Session) SetPrefetch(p float64) { s.mu.Lock() s.prefetch = p s.mu.Unlock() } // SetTrace sets the default tracer for this session. This setting can also // be changed on a per-query basis. func (s *Session) SetTrace(trace Tracer) { s.mu.Lock() s.trace = trace s.mu.Unlock() } // Query generates a new query object for interacting with the database. 
// Further details of the query may be tweaked using the resulting query // value before the query is executed. Query is automatically prepared // if it has not previously been executed. func (s *Session) Query(stmt string, values ...interface{}) *Query { qry := queryPool.Get().(*Query) qry.session = s qry.stmt = stmt qry.values = values qry.defaultsFromSession() return qry } type QueryInfo struct { Id []byte Args []ColumnInfo Rval []ColumnInfo PKeyColumns []int } // Bind generates a new query object based on the query statement passed in. // The query is automatically prepared if it has not previously been executed. // The binding callback allows the application to define which query argument // values will be marshalled as part of the query execution. // During execution, the meta data of the prepared query will be routed to the // binding callback, which is responsible for producing the query argument values. func (s *Session) Bind(stmt string, b func(q *QueryInfo) ([]interface{}, error)) *Query { qry := queryPool.Get().(*Query) qry.session = s qry.stmt = stmt qry.binding = b qry.defaultsFromSession() return qry } // Close closes all connections. The session is unusable after this // operation. 
func (s *Session) Close() { s.sessionStateMu.Lock() if s.isClosing { s.sessionStateMu.Unlock() return } s.isClosing = true s.sessionStateMu.Unlock() if s.pool != nil { s.pool.Close() } if s.control != nil { s.control.close() } if s.nodeEvents != nil { s.nodeEvents.stop() } if s.schemaEvents != nil { s.schemaEvents.stop() } if s.ringRefresher != nil { s.ringRefresher.stop() } if s.cancel != nil { s.cancel() } s.sessionStateMu.Lock() s.isClosed = true s.sessionStateMu.Unlock() } func (s *Session) Closed() bool { s.sessionStateMu.RLock() closed := s.isClosed s.sessionStateMu.RUnlock() return closed } func (s *Session) initialized() bool { s.sessionStateMu.RLock() initialized := s.isInitialized s.sessionStateMu.RUnlock() return initialized } func (s *Session) executeQuery(qry *Query) (it *Iter) { // fail fast if s.Closed() { return &Iter{err: ErrSessionClosed} } iter, err := s.executor.executeQuery(qry) if err != nil { return &Iter{err: err} } if iter == nil { panic("nil iter") } return iter } func (s *Session) removeHost(h *HostInfo) { s.policy.RemoveHost(h) hostID := h.HostID() s.pool.removeHost(hostID) s.ring.removeHost(hostID) } // KeyspaceMetadata returns the schema metadata for the keyspace specified. Returns an error if the keyspace does not exist. 
func (s *Session) KeyspaceMetadata(keyspace string) (*KeyspaceMetadata, error) {
	// fail fast
	if s.Closed() {
		return nil, ErrSessionClosed
	} else if keyspace == "" {
		return nil, ErrNoKeyspace
	}

	return s.schemaDescriber.getSchema(keyspace)
}

// getConn returns a connection from the pool of the first up host that has
// one available, or nil if no connection can be picked.
func (s *Session) getConn() *Conn {
	hosts := s.ring.allHosts()
	for _, host := range hosts {
		if !host.IsUp() {
			continue
		}

		pool, ok := s.pool.getPool(host)
		if !ok {
			continue
		} else if conn := pool.Pick(); conn != nil {
			return conn
		}
	}

	return nil
}

// returns routing key indexes and type info
func (s *Session) routingKeyInfo(ctx context.Context, stmt string) (*routingKeyInfo, error) {
	s.routingKeyInfoCache.mu.Lock()

	entry, cached := s.routingKeyInfoCache.lru.Get(stmt)
	if cached {
		// done accessing the cache
		s.routingKeyInfoCache.mu.Unlock()
		// the entry is an inflight struct similar to that used by
		// Conn to prepare statements
		inflight := entry.(*inflightCachedEntry)

		// wait for any inflight work
		inflight.wg.Wait()

		if inflight.err != nil {
			return nil, inflight.err
		}

		key, _ := inflight.value.(*routingKeyInfo)

		return key, nil
	}

	// create a new inflight entry while the data is created
	inflight := new(inflightCachedEntry)
	inflight.wg.Add(1)
	defer inflight.wg.Done()
	s.routingKeyInfoCache.lru.Add(stmt, inflight)
	s.routingKeyInfoCache.mu.Unlock()

	var (
		info         *preparedStatment
		partitionKey []*ColumnMetadata
	)

	conn := s.getConn()
	if conn == nil {
		// TODO: better error?
		inflight.err = errors.New("gocql: unable to fetch prepared info: no connection available")
		return nil, inflight.err
	}

	// get the query info for the statement
	info, inflight.err = conn.prepareStatement(ctx, stmt, nil)
	if inflight.err != nil {
		// don't cache this error
		s.routingKeyInfoCache.Remove(stmt)
		return nil, inflight.err
	}

	// TODO: it would be nice to mark hosts here but as we are not using the policies
	// to fetch hosts we cant

	if info.request.colCount == 0 {
		// no arguments, no routing key, and no error
		return nil, nil
	}

	table := info.request.table
	keyspace := info.request.keyspace

	if len(info.request.pkeyColumns) > 0 {
		// proto v4 dont need to calculate primary key columns
		types := make([]TypeInfo, len(info.request.pkeyColumns))
		for i, col := range info.request.pkeyColumns {
			types[i] = info.request.columns[col].TypeInfo
		}

		routingKeyInfo := &routingKeyInfo{
			indexes:  info.request.pkeyColumns,
			types:    types,
			keyspace: keyspace,
			table:    table,
		}

		inflight.value = routingKeyInfo
		return routingKeyInfo, nil
	}

	// Older protocols: derive the partition key from the table metadata.
	var keyspaceMetadata *KeyspaceMetadata
	keyspaceMetadata, inflight.err = s.KeyspaceMetadata(info.request.columns[0].Keyspace)
	if inflight.err != nil {
		// don't cache this error
		s.routingKeyInfoCache.Remove(stmt)
		return nil, inflight.err
	}

	tableMetadata, found := keyspaceMetadata.Tables[table]
	if !found {
		// unlikely that the statement could be prepared and the metadata for
		// the table couldn't be found, but this may indicate either a bug
		// in the metadata code, or that the table was just dropped.
		inflight.err = ErrNoMetadata
		// don't cache this error
		s.routingKeyInfoCache.Remove(stmt)
		return nil, inflight.err
	}

	partitionKey = tableMetadata.PartitionKey

	size := len(partitionKey)
	routingKeyInfo := &routingKeyInfo{
		indexes:  make([]int, size),
		types:    make([]TypeInfo, size),
		keyspace: keyspace,
		table:    table,
	}

	for keyIndex, keyColumn := range partitionKey {
		// set an indicator for checking if the mapping is missing
		routingKeyInfo.indexes[keyIndex] = -1

		// find the column in the query info
		for argIndex, boundColumn := range info.request.columns {
			if keyColumn.Name == boundColumn.Name {
				// there may be many such bound columns, pick the first
				routingKeyInfo.indexes[keyIndex] = argIndex
				routingKeyInfo.types[keyIndex] = boundColumn.TypeInfo
				break
			}
		}

		if routingKeyInfo.indexes[keyIndex] == -1 {
			// missing a routing key column mapping
			// no routing key, and no error
			return nil, nil
		}
	}

	// cache this result
	inflight.value = routingKeyInfo

	return routingKeyInfo, nil
}

func (b *Batch) execute(ctx context.Context, conn *Conn) *Iter {
	return conn.executeBatch(ctx, b)
}

func (s *Session) executeBatch(batch *Batch) *Iter {
	// fail fast
	if s.Closed() {
		return &Iter{err: ErrSessionClosed}
	}

	// Prevent the execution of the batch if greater than the limit
	// Currently batches have a limit of 65536 queries.
	// https://datastax-oss.atlassian.net/browse/JAVA-229
	if batch.Size() > BatchSizeMaximum {
		return &Iter{err: ErrTooManyStmts}
	}

	iter, err := s.executor.executeQuery(batch)
	if err != nil {
		return &Iter{err: err}
	}

	return iter
}

// ExecuteBatch executes a batch operation and returns nil if successful
// otherwise an error is returned describing the failure.
func (s *Session) ExecuteBatch(batch *Batch) error {
	iter := s.executeBatch(batch)
	return iter.Close()
}

// ExecuteBatchCAS executes a batch operation and returns true if successful and
// an iterator (to scan additional rows if more than one conditional statement)
// was sent.
// Further scans on the interator must also remember to include // the applied boolean as the first argument to *Iter.Scan func (s *Session) ExecuteBatchCAS(batch *Batch, dest ...interface{}) (applied bool, iter *Iter, err error) { iter = s.executeBatch(batch) if err := iter.checkErrAndNotFound(); err != nil { iter.Close() return false, nil, err } if len(iter.Columns()) > 1 { dest = append([]interface{}{&applied}, dest...) iter.Scan(dest...) } else { iter.Scan(&applied) } return applied, iter, nil } // MapExecuteBatchCAS executes a batch operation much like ExecuteBatchCAS, // however it accepts a map rather than a list of arguments for the initial // scan. func (s *Session) MapExecuteBatchCAS(batch *Batch, dest map[string]interface{}) (applied bool, iter *Iter, err error) { iter = s.executeBatch(batch) if err := iter.checkErrAndNotFound(); err != nil { iter.Close() return false, nil, err } iter.MapScan(dest) applied = dest["[applied]"].(bool) delete(dest, "[applied]") // we usually close here, but instead of closing, just returin an error // if MapScan failed. Although Close just returns err, using Close // here might be confusing as we are not actually closing the iter return applied, iter, iter.err } type hostMetrics struct { // Attempts is count of how many times this query has been attempted for this host. // An attempt is either a retry or fetching next page of results. Attempts int // TotalLatency is the sum of attempt latencies for this host in nanoseconds. TotalLatency int64 } type queryMetrics struct { l sync.RWMutex m map[string]*hostMetrics // totalAttempts is total number of attempts. // Equal to sum of all hostMetrics' Attempts. totalAttempts int } // preFilledQueryMetrics initializes new queryMetrics based on per-host supplied data. 
func preFilledQueryMetrics(m map[string]*hostMetrics) *queryMetrics { qm := &queryMetrics{m: m} for _, hm := range qm.m { qm.totalAttempts += hm.Attempts } return qm } // hostMetrics returns a snapshot of metrics for given host. // If the metrics for host don't exist, they are created. func (qm *queryMetrics) hostMetrics(host *HostInfo) *hostMetrics { qm.l.Lock() metrics := qm.hostMetricsLocked(host) copied := new(hostMetrics) *copied = *metrics qm.l.Unlock() return copied } // hostMetricsLocked gets or creates host metrics for given host. // It must be called only while holding qm.l lock. func (qm *queryMetrics) hostMetricsLocked(host *HostInfo) *hostMetrics { metrics, exists := qm.m[host.ConnectAddress().String()] if !exists { // if the host is not in the map, it means it's been accessed for the first time metrics = &hostMetrics{} qm.m[host.ConnectAddress().String()] = metrics } return metrics } // attempts returns the number of times the query was executed. func (qm *queryMetrics) attempts() int { qm.l.Lock() attempts := qm.totalAttempts qm.l.Unlock() return attempts } func (qm *queryMetrics) latency() int64 { qm.l.Lock() var ( attempts int latency int64 ) for _, metric := range qm.m { attempts += metric.Attempts latency += metric.TotalLatency } qm.l.Unlock() if attempts > 0 { return latency / int64(attempts) } return 0 } // attempt adds given number of attempts and latency for given host. // It returns previous total attempts. // If needsHostMetrics is true, a copy of updated hostMetrics is returned. 
func (qm *queryMetrics) attempt(addAttempts int, addLatency time.Duration,
	host *HostInfo, needsHostMetrics bool) (int, *hostMetrics) {
	qm.l.Lock()

	totalAttempts := qm.totalAttempts
	qm.totalAttempts += addAttempts

	updateHostMetrics := qm.hostMetricsLocked(host)
	updateHostMetrics.Attempts += addAttempts
	updateHostMetrics.TotalLatency += addLatency.Nanoseconds()

	var hostMetricsCopy *hostMetrics
	if needsHostMetrics {
		// Copy under the lock so the caller gets a consistent snapshot.
		hostMetricsCopy = new(hostMetrics)
		*hostMetricsCopy = *updateHostMetrics
	}

	qm.l.Unlock()
	return totalAttempts, hostMetricsCopy
}

// Query represents a CQL statement that can be executed.
type Query struct {
	stmt                  string
	values                []interface{}
	cons                  Consistency
	pageSize              int
	routingKey            []byte
	pageState             []byte
	prefetch              float64
	trace                 Tracer
	observer              QueryObserver
	session               *Session
	conn                  *Conn
	rt                    RetryPolicy
	spec                  SpeculativeExecutionPolicy
	binding               func(q *QueryInfo) ([]interface{}, error)
	serialCons            SerialConsistency
	defaultTimestamp      bool
	defaultTimestampValue int64
	disableSkipMetadata   bool
	context               context.Context
	idempotent            bool
	customPayload         map[string][]byte
	metrics               *queryMetrics
	refCount              uint32

	disableAutoPage bool

	// getKeyspace is field so that it can be overridden in tests
	getKeyspace func() string

	// used by control conn queries to prevent triggering a write to systems
	// tables in AWS MCS see
	skipPrepare bool

	// routingInfo is a pointer because Query can be copied and copyable struct can't hold a mutex.
	routingInfo *queryRoutingInfo
}

type queryRoutingInfo struct {
	// mu protects contents of queryRoutingInfo.
	mu sync.RWMutex

	keyspace string

	table string
}

// defaultsFromSession copies the session-level defaults onto this query
// and gives it fresh per-query metrics and execution policy state.
func (q *Query) defaultsFromSession() {
	s := q.session

	s.mu.RLock()
	q.cons = s.cons
	q.pageSize = s.pageSize
	q.trace = s.trace
	q.observer = s.queryObserver
	q.prefetch = s.prefetch
	q.rt = s.cfg.RetryPolicy
	q.serialCons = s.cfg.SerialConsistency
	q.defaultTimestamp = s.cfg.DefaultTimestamp
	q.idempotent = s.cfg.DefaultIdempotence
	q.metrics = &queryMetrics{m: make(map[string]*hostMetrics)}

	q.spec = &NonSpeculativeExecution{}
	s.mu.RUnlock()
}

// Statement returns the statement that was used to generate this query.
func (q Query) Statement() string {
	return q.stmt
}

// Values returns the values passed in via Bind.
// This can be used by a wrapper type that needs to access the bound values.
func (q Query) Values() []interface{} {
	return q.values
}

// String implements the stringer interface.
func (q Query) String() string {
	return fmt.Sprintf("[query statement=%q values=%+v consistency=%s]", q.stmt, q.values, q.cons)
}

// Attempts returns the number of times the query was executed.
func (q *Query) Attempts() int {
	return q.metrics.attempts()
}

func (q *Query) AddAttempts(i int, host *HostInfo) {
	q.metrics.attempt(i, 0, host, false)
}

// Latency returns the average amount of nanoseconds per attempt of the query.
func (q *Query) Latency() int64 {
	return q.metrics.latency()
}

func (q *Query) AddLatency(l int64, host *HostInfo) {
	q.metrics.attempt(0, time.Duration(l)*time.Nanosecond, host, false)
}

// Consistency sets the consistency level for this query. If no consistency
// level have been set, the default consistency level of the cluster
// is used.
func (q *Query) Consistency(c Consistency) *Query {
	q.cons = c
	return q
}

// GetConsistency returns the currently configured consistency level for
// the query.
func (q *Query) GetConsistency() Consistency {
	return q.cons
}

// Same as Consistency but without a return value
func (q *Query) SetConsistency(c Consistency) {
	q.cons = c
}

// CustomPayload sets the custom payload level for this query.
func (q *Query) CustomPayload(customPayload map[string][]byte) *Query {
	q.customPayload = customPayload
	return q
}

// Context returns the context attached to this query, defaulting to the
// background context when none was set.
func (q *Query) Context() context.Context {
	if q.context != nil {
		return q.context
	}
	return context.Background()
}

// Trace enables tracing of this query. Look at the documentation of the
// Tracer interface to learn more about tracing.
func (q *Query) Trace(trace Tracer) *Query {
	q.trace = trace
	return q
}

// Observer enables query-level observer on this query.
// The provided observer will be called every time this query is executed.
func (q *Query) Observer(observer QueryObserver) *Query {
	q.observer = observer
	return q
}

// PageSize will tell the iterator to fetch the result in pages of size n.
// This is useful for iterating over large result sets, but setting the
// page size too low might decrease the performance. This feature is only
// available in Cassandra 2 and onwards.
func (q *Query) PageSize(n int) *Query {
	q.pageSize = n
	return q
}

// DefaultTimestamp will enable the with default timestamp flag on the query.
// If enable, this will replace the server side assigned
// timestamp as default timestamp. Note that a timestamp in the query itself
// will still override this timestamp. This is entirely optional.
//
// Only available on protocol >= 3
func (q *Query) DefaultTimestamp(enable bool) *Query {
	q.defaultTimestamp = enable
	return q
}

// WithTimestamp will enable the with default timestamp flag on the query
// like DefaultTimestamp does. But also allows to define value for timestamp.
// It works the same way as USING TIMESTAMP in the query itself, but
// should not break prepared query optimization.
// // Only available on protocol >= 3 func (q *Query) WithTimestamp(timestamp int64) *Query { q.DefaultTimestamp(true) q.defaultTimestampValue = timestamp return q } // RoutingKey sets the routing key to use when a token aware connection // pool is used to optimize the routing of this query. func (q *Query) RoutingKey(routingKey []byte) *Query { q.routingKey = routingKey return q } func (q *Query) withContext(ctx context.Context) ExecutableQuery { // I really wish go had covariant types return q.WithContext(ctx) } // WithContext returns a shallow copy of q with its context // set to ctx. // // The provided context controls the entire lifetime of executing a // query, queries will be canceled and return once the context is // canceled. func (q *Query) WithContext(ctx context.Context) *Query { q2 := *q q2.context = ctx return &q2 } // Deprecate: does nothing, cancel the context passed to WithContext func (q *Query) Cancel() { // TODO: delete } func (q *Query) execute(ctx context.Context, conn *Conn) *Iter { return conn.executeQuery(ctx, q) } func (q *Query) attempt(keyspace string, end, start time.Time, iter *Iter, host *HostInfo) { latency := end.Sub(start) attempt, metricsForHost := q.metrics.attempt(1, latency, host, q.observer != nil) if q.observer != nil { q.observer.ObserveQuery(q.Context(), ObservedQuery{ Keyspace: keyspace, Statement: q.stmt, Values: q.values, Start: start, End: end, Rows: iter.numRows, Host: host, Metrics: metricsForHost, Err: iter.err, Attempt: attempt, }) } } func (q *Query) retryPolicy() RetryPolicy { return q.rt } // Keyspace returns the keyspace the query will be executed against. func (q *Query) Keyspace() string { if q.getKeyspace != nil { return q.getKeyspace() } if q.routingInfo.keyspace != "" { return q.routingInfo.keyspace } if q.session == nil { return "" } // TODO(chbannis): this should be parsed from the query or we should let // this be set by users. 
return q.session.cfg.Keyspace } // Table returns name of the table the query will be executed against. func (q *Query) Table() string { return q.routingInfo.table } // GetRoutingKey gets the routing key to use for routing this query. If // a routing key has not been explicitly set, then the routing key will // be constructed if possible using the keyspace's schema and the query // info for this query statement. If the routing key cannot be determined // then nil will be returned with no error. On any error condition, // an error description will be returned. func (q *Query) GetRoutingKey() ([]byte, error) { if q.routingKey != nil { return q.routingKey, nil } else if q.binding != nil && len(q.values) == 0 { // If this query was created using session.Bind we wont have the query // values yet, so we have to pass down to the next policy. // TODO: Remove this and handle this case return nil, nil } // try to determine the routing key routingKeyInfo, err := q.session.routingKeyInfo(q.Context(), q.stmt) if err != nil { return nil, err } if routingKeyInfo != nil { q.routingInfo.mu.Lock() q.routingInfo.keyspace = routingKeyInfo.keyspace q.routingInfo.table = routingKeyInfo.table q.routingInfo.mu.Unlock() } return createRoutingKey(routingKeyInfo, q.values) } func (q *Query) shouldPrepare() bool { stmt := strings.TrimLeftFunc(strings.TrimRightFunc(q.stmt, func(r rune) bool { return unicode.IsSpace(r) || r == ';' }), unicode.IsSpace) var stmtType string if n := strings.IndexFunc(stmt, unicode.IsSpace); n >= 0 { stmtType = strings.ToLower(stmt[:n]) } if stmtType == "begin" { if n := strings.LastIndexFunc(stmt, unicode.IsSpace); n >= 0 { stmtType = strings.ToLower(stmt[n+1:]) } } switch stmtType { case "select", "insert", "update", "delete", "batch": return true } return false } // SetPrefetch sets the default threshold for pre-fetching new pages. If // there are only p*pageSize rows remaining, the next page will be requested // automatically. 
func (q *Query) Prefetch(p float64) *Query { q.prefetch = p return q } // RetryPolicy sets the policy to use when retrying the query. func (q *Query) RetryPolicy(r RetryPolicy) *Query { q.rt = r return q } // SetSpeculativeExecutionPolicy sets the execution policy func (q *Query) SetSpeculativeExecutionPolicy(sp SpeculativeExecutionPolicy) *Query { q.spec = sp return q } // speculativeExecutionPolicy fetches the policy func (q *Query) speculativeExecutionPolicy() SpeculativeExecutionPolicy { return q.spec } // IsIdempotent returns whether the query is marked as idempotent. // Non-idempotent query won't be retried. // See "Retries and speculative execution" in package docs for more details. func (q *Query) IsIdempotent() bool { return q.idempotent } // Idempotent marks the query as being idempotent or not depending on // the value. // Non-idempotent query won't be retried. // See "Retries and speculative execution" in package docs for more details. func (q *Query) Idempotent(value bool) *Query { q.idempotent = value return q } // Bind sets query arguments of query. This can also be used to rebind new query arguments // to an existing query instance. func (q *Query) Bind(v ...interface{}) *Query { q.values = v q.pageState = nil return q } // SerialConsistency sets the consistency level for the // serial phase of conditional updates. That consistency can only be // either SERIAL or LOCAL_SERIAL and if not present, it defaults to // SERIAL. This option will be ignored for anything else that a // conditional update/insert. func (q *Query) SerialConsistency(cons SerialConsistency) *Query { q.serialCons = cons return q } // PageState sets the paging state for the query to resume paging from a specific // point in time. Setting this will disable to query paging for this query, and // must be used for all subsequent pages. 
func (q *Query) PageState(state []byte) *Query { q.pageState = state q.disableAutoPage = true return q } // NoSkipMetadata will override the internal result metadata cache so that the driver does not // send skip_metadata for queries, this means that the result will always contain // the metadata to parse the rows and will not reuse the metadata from the prepared // statement. This should only be used to work around cassandra bugs, such as when using // CAS operations which do not end in Cas. // // See https://issues.apache.org/jira/browse/CASSANDRA-11099 // https://github.com/apache/cassandra-gocql-driver/issues/612 func (q *Query) NoSkipMetadata() *Query { q.disableSkipMetadata = true return q } // Exec executes the query without returning any rows. func (q *Query) Exec() error { return q.Iter().Close() } func isUseStatement(stmt string) bool { if len(stmt) < 3 { return false } return strings.EqualFold(stmt[0:3], "use") } // Iter executes the query and returns an iterator capable of iterating // over all results. func (q *Query) Iter() *Iter { if isUseStatement(q.stmt) { return &Iter{err: ErrUseStmt} } // if the query was specifically run on a connection then re-use that // connection when fetching the next results if q.conn != nil { return q.conn.executeQuery(q.Context(), q) } return q.session.executeQuery(q) } // MapScan executes the query, copies the columns of the first selected // row into the map pointed at by m and discards the rest. If no rows // were selected, ErrNotFound is returned. func (q *Query) MapScan(m map[string]interface{}) error { iter := q.Iter() if err := iter.checkErrAndNotFound(); err != nil { return err } iter.MapScan(m) return iter.Close() } // Scan executes the query, copies the columns of the first selected // row into the values pointed at by dest and discards the rest. If no rows // were selected, ErrNotFound is returned. 
func (q *Query) Scan(dest ...interface{}) error { iter := q.Iter() if err := iter.checkErrAndNotFound(); err != nil { return err } iter.Scan(dest...) return iter.Close() } // ScanCAS executes a lightweight transaction (i.e. an UPDATE or INSERT // statement containing an IF clause). If the transaction fails because // the existing values did not match, the previous values will be stored // in dest. // // As for INSERT .. IF NOT EXISTS, previous values will be returned as if // SELECT * FROM. So using ScanCAS with INSERT is inherently prone to // column mismatching. Use MapScanCAS to capture them safely. func (q *Query) ScanCAS(dest ...interface{}) (applied bool, err error) { q.disableSkipMetadata = true iter := q.Iter() if err := iter.checkErrAndNotFound(); err != nil { return false, err } if len(iter.Columns()) > 1 { dest = append([]interface{}{&applied}, dest...) iter.Scan(dest...) } else { iter.Scan(&applied) } return applied, iter.Close() } // MapScanCAS executes a lightweight transaction (i.e. an UPDATE or INSERT // statement containing an IF clause). If the transaction fails because // the existing values did not match, the previous values will be stored // in dest map. // // As for INSERT .. IF NOT EXISTS, previous values will be returned as if // SELECT * FROM. So using ScanCAS with INSERT is inherently prone to // column mismatching. MapScanCAS is added to capture them safely. func (q *Query) MapScanCAS(dest map[string]interface{}) (applied bool, err error) { q.disableSkipMetadata = true iter := q.Iter() if err := iter.checkErrAndNotFound(); err != nil { return false, err } iter.MapScan(dest) applied = dest["[applied]"].(bool) delete(dest, "[applied]") return applied, iter.Close() } // Release releases a query back into a pool of queries. Released Queries // cannot be reused. 
//
// Example:
//
//	qry := session.Query("SELECT * FROM my_table")
//	qry.Exec()
//	qry.Release()
func (q *Query) Release() {
	q.decRefCount()
}

// reset zeroes out all fields of a query so that it can be safely pooled.
func (q *Query) reset() {
	*q = Query{routingInfo: &queryRoutingInfo{}, refCount: 1}
}

func (q *Query) incRefCount() {
	atomic.AddUint32(&q.refCount, 1)
}

func (q *Query) decRefCount() {
	// Adding ^uint32(0) decrements by one; the last holder returns the
	// query to the pool.
	if res := atomic.AddUint32(&q.refCount, ^uint32(0)); res == 0 {
		// do release
		q.reset()
		queryPool.Put(q)
	}
}

func (q *Query) borrowForExecution() {
	q.incRefCount()
}

func (q *Query) releaseAfterExecution() {
	q.decRefCount()
}

// Iter represents an iterator that can be used to iterate over all rows that
// were returned by a query. The iterator might send additional queries to the
// database during the iteration if paging was enabled.
type Iter struct {
	err     error
	pos     int
	meta    resultMetadata
	numRows int
	next    *nextIter
	host    *HostInfo

	framer *framer
	closed int32
}

// Host returns the host which the query was sent to.
func (iter *Iter) Host() *HostInfo {
	return iter.host
}

// Columns returns the name and type of the selected columns.
func (iter *Iter) Columns() []ColumnInfo {
	return iter.meta.columns
}

type Scanner interface {
	// Next advances the row pointer to point at the next row, the row is valid until
	// the next call of Next. It returns true if there is a row which is available to be
	// scanned into with Scan.
	// Next must be called before every call to Scan.
	Next() bool

	// Scan copies the current row's columns into dest. If the length of dest does not equal
	// the number of columns returned in the row an error is returned. If an error is encountered
	// when unmarshalling a column into the value in dest an error is returned and the row is invalidated
	// until the next call to Next.
	// Next must be called before calling Scan, if it is not an error is returned.
	Scan(...interface{}) error

	// Err returns the if there was one during iteration that resulted in iteration being unable to complete.
	// Err will also release resources held by the iterator, the Scanner should not used after being called.
	Err() error
}

type iterScanner struct {
	iter  *Iter
	cols  [][]byte
	valid bool
}

func (is *iterScanner) Next() bool {
	iter := is.iter
	if iter.err != nil {
		return false
	}

	if iter.pos >= iter.numRows {
		if iter.next != nil {
			// Current page exhausted: move to the next page and retry.
			is.iter = iter.next.fetch()
			return is.Next()
		}
		return false
	}

	for i := 0; i < len(is.cols); i++ {
		col, err := iter.readColumn()
		if err != nil {
			iter.err = err
			return false
		}
		is.cols[i] = col
	}
	iter.pos++
	is.valid = true

	return true
}

// scanColumn unmarshals one column value into dest and returns how many
// dest slots were consumed (tuples consume one slot per element).
func scanColumn(p []byte, col ColumnInfo, dest []interface{}) (int, error) {
	if dest[0] == nil {
		return 1, nil
	}

	if col.TypeInfo.Type() == TypeTuple {
		// this will panic, actually a bug, please report
		tuple := col.TypeInfo.(TupleTypeInfo)

		count := len(tuple.Elems)
		// here we pass in a slice of the struct which has the same number of
		// values as elements in the tuple
		if err := Unmarshal(col.TypeInfo, p, dest[:count]); err != nil {
			return 0, err
		}
		return count, nil
	} else {
		if err := Unmarshal(col.TypeInfo, p, dest[0]); err != nil {
			return 0, err
		}
		return 1, nil
	}
}

func (is *iterScanner) Scan(dest ...interface{}) error {
	if !is.valid {
		return errors.New("gocql: Scan called without calling Next")
	}

	iter := is.iter
	// currently only support scanning into an expand tuple, such that its the same
	// as scanning in more values from a single column
	if len(dest) != iter.meta.actualColCount {
		return fmt.Errorf("gocql: not enough columns to scan into: have %d want %d", len(dest), iter.meta.actualColCount)
	}

	// i is the current position in dest, could possibly replace it and just use
	// slices of dest
	i := 0
	var err error
	for _, col := range iter.meta.columns {
		var n int
		n, err = scanColumn(is.cols[i], col, dest[i:])
		if err != nil {
			break
		}
		i += n
	}

	is.valid = false
	return err
}

func (is *iterScanner) Err() error {
	iter := is.iter
	is.iter = nil
	is.cols = nil
	is.valid = false
	return iter.Close()
}

// Scanner returns a row Scanner which provides an interface to scan rows in a manner which is
// similar to database/sql. The iter should NOT be used again after calling this method.
func (iter *Iter) Scanner() Scanner {
	if iter == nil {
		return nil
	}

	return &iterScanner{iter: iter, cols: make([][]byte, len(iter.meta.columns))}
}

func (iter *Iter) readColumn() ([]byte, error) {
	return iter.framer.readBytesInternal()
}

// Scan consumes the next row of the iterator and copies the columns of the
// current row into the values pointed at by dest. Use nil as a dest value
// to skip the corresponding column. Scan might send additional queries
// to the database to retrieve the next set of rows if paging was enabled.
//
// Scan returns true if the row was successfully unmarshaled or false if the
// end of the result set was reached or if an error occurred. Close should
// be called afterwards to retrieve any potential errors.
func (iter *Iter) Scan(dest ...interface{}) bool {
	if iter.err != nil {
		return false
	}

	if iter.pos >= iter.numRows {
		if iter.next != nil {
			// Current page exhausted: swap in the next page and retry.
			*iter = *iter.next.fetch()
			return iter.Scan(dest...)
		}
		return false
	}

	if iter.next != nil && iter.pos >= iter.next.pos {
		// Prefetch the next page in the background once the threshold is hit.
		iter.next.fetchAsync()
	}

	// currently only support scanning into an expand tuple, such that its the same
	// as scanning in more values from a single column
	if len(dest) != iter.meta.actualColCount {
		iter.err = fmt.Errorf("gocql: not enough columns to scan into: have %d want %d", len(dest), iter.meta.actualColCount)
		return false
	}

	// i is the current position in dest, could possibly replace it and just use
	// slices of dest
	i := 0
	for _, col := range iter.meta.columns {
		colBytes, err := iter.readColumn()
		if err != nil {
			iter.err = err
			return false
		}

		n, err := scanColumn(colBytes, col, dest[i:])
		if err != nil {
			iter.err = err
			return false
		}
		i += n
	}

	iter.pos++
	return true
}

// GetCustomPayload returns any parsed custom payload results if given in the
// response from Cassandra. Note that the result is not a copy.
//
// This additional feature of CQL Protocol v4
// allows additional results and query information to be returned by
// custom QueryHandlers running in your C* cluster.
// See https://datastax.github.io/java-driver/manual/custom_payloads/
func (iter *Iter) GetCustomPayload() map[string][]byte {
	if iter.framer != nil {
		return iter.framer.customPayload
	}
	return nil
}

// Warnings returns any warnings generated if given in the response from Cassandra.
//
// This is only available starting with CQL Protocol v4.
func (iter *Iter) Warnings() []string {
	if iter.framer != nil {
		return iter.framer.header.warnings
	}
	return nil
}

// Close closes the iterator and returns any errors that happened during
// the query or the iteration.
func (iter *Iter) Close() error {
	if atomic.CompareAndSwapInt32(&iter.closed, 0, 1) {
		if iter.framer != nil {
			iter.framer = nil
		}
	}

	return iter.err
}

// WillSwitchPage detects if iterator reached end of current page
// and the next page is available.
func (iter *Iter) WillSwitchPage() bool { return iter.pos >= iter.numRows && iter.next != nil } // checkErrAndNotFound handle error and NotFound in one method. func (iter *Iter) checkErrAndNotFound() error { if iter.err != nil { return iter.err } else if iter.numRows == 0 { return ErrNotFound } return nil } // PageState return the current paging state for a query which can be used for // subsequent queries to resume paging this point. func (iter *Iter) PageState() []byte { return iter.meta.pagingState } // NumRows returns the number of rows in this pagination, it will update when new // pages are fetched, it is not the value of the total number of rows this iter // will return unless there is only a single page returned. func (iter *Iter) NumRows() int { return iter.numRows } // nextIter holds state for fetching a single page in an iterator. // single page might be attempted multiple times due to retries. type nextIter struct { qry *Query pos int oncea sync.Once once sync.Once next *Iter } func (n *nextIter) fetchAsync() { n.oncea.Do(func() { go n.fetch() }) } func (n *nextIter) fetch() *Iter { n.once.Do(func() { // if the query was specifically run on a connection then re-use that // connection when fetching the next results if n.qry.conn != nil { n.next = n.qry.conn.executeQuery(n.qry.Context(), n.qry) } else { n.next = n.qry.session.executeQuery(n.qry) } }) return n.next } type Batch struct { Type BatchType Entries []BatchEntry Cons Consistency routingKey []byte CustomPayload map[string][]byte rt RetryPolicy spec SpeculativeExecutionPolicy trace Tracer observer BatchObserver session *Session serialCons SerialConsistency defaultTimestamp bool defaultTimestampValue int64 context context.Context cancelBatch func() keyspace string metrics *queryMetrics // routingInfo is a pointer because Query can be copied and copyable struct can't hold a mutex. 
routingInfo *queryRoutingInfo } // NewBatch creates a new batch operation without defaults from the cluster // // Deprecated: use session.NewBatch instead func NewBatch(typ BatchType) *Batch { return &Batch{ Type: typ, metrics: &queryMetrics{m: make(map[string]*hostMetrics)}, spec: &NonSpeculativeExecution{}, routingInfo: &queryRoutingInfo{}, } } // NewBatch creates a new batch operation using defaults defined in the cluster func (s *Session) NewBatch(typ BatchType) *Batch { s.mu.RLock() batch := &Batch{ Type: typ, rt: s.cfg.RetryPolicy, serialCons: s.cfg.SerialConsistency, trace: s.trace, observer: s.batchObserver, session: s, Cons: s.cons, defaultTimestamp: s.cfg.DefaultTimestamp, keyspace: s.cfg.Keyspace, metrics: &queryMetrics{m: make(map[string]*hostMetrics)}, spec: &NonSpeculativeExecution{}, routingInfo: &queryRoutingInfo{}, } s.mu.RUnlock() return batch } // Trace enables tracing of this batch. Look at the documentation of the // Tracer interface to learn more about tracing. func (b *Batch) Trace(trace Tracer) *Batch { b.trace = trace return b } // Observer enables batch-level observer on this batch. // The provided observer will be called every time this batched query is executed. func (b *Batch) Observer(observer BatchObserver) *Batch { b.observer = observer return b } func (b *Batch) Keyspace() string { return b.keyspace } // Batch has no reasonable eqivalent of Query.Table(). func (b *Batch) Table() string { return b.routingInfo.table } // Attempts returns the number of attempts made to execute the batch. func (b *Batch) Attempts() int { return b.metrics.attempts() } func (b *Batch) AddAttempts(i int, host *HostInfo) { b.metrics.attempt(i, 0, host, false) } // Latency returns the average number of nanoseconds to execute a single attempt of the batch. 
func (b *Batch) Latency() int64 { return b.metrics.latency() } func (b *Batch) AddLatency(l int64, host *HostInfo) { b.metrics.attempt(0, time.Duration(l)*time.Nanosecond, host, false) } // GetConsistency returns the currently configured consistency level for the batch // operation. func (b *Batch) GetConsistency() Consistency { return b.Cons } // SetConsistency sets the currently configured consistency level for the batch // operation. func (b *Batch) SetConsistency(c Consistency) { b.Cons = c } func (b *Batch) Context() context.Context { if b.context == nil { return context.Background() } return b.context } func (b *Batch) IsIdempotent() bool { for _, entry := range b.Entries { if !entry.Idempotent { return false } } return true } func (b *Batch) speculativeExecutionPolicy() SpeculativeExecutionPolicy { return b.spec } func (b *Batch) SpeculativeExecutionPolicy(sp SpeculativeExecutionPolicy) *Batch { b.spec = sp return b } // Query adds the query to the batch operation func (b *Batch) Query(stmt string, args ...interface{}) { b.Entries = append(b.Entries, BatchEntry{Stmt: stmt, Args: args}) } // Bind adds the query to the batch operation and correlates it with a binding callback // that will be invoked when the batch is executed. The binding callback allows the application // to define which query argument values will be marshalled as part of the batch execution. func (b *Batch) Bind(stmt string, bind func(q *QueryInfo) ([]interface{}, error)) { b.Entries = append(b.Entries, BatchEntry{Stmt: stmt, binding: bind}) } func (b *Batch) retryPolicy() RetryPolicy { return b.rt } // RetryPolicy sets the retry policy to use when executing the batch operation func (b *Batch) RetryPolicy(r RetryPolicy) *Batch { b.rt = r return b } func (b *Batch) withContext(ctx context.Context) ExecutableQuery { return b.WithContext(ctx) } // WithContext returns a shallow copy of b with its context // set to ctx. 
// // The provided context controls the entire lifetime of executing a // query, queries will be canceled and return once the context is // canceled. func (b *Batch) WithContext(ctx context.Context) *Batch { b2 := *b b2.context = ctx return &b2 } // Deprecate: does nothing, cancel the context passed to WithContext func (*Batch) Cancel() { // TODO: delete } // Size returns the number of batch statements to be executed by the batch operation. func (b *Batch) Size() int { return len(b.Entries) } // SerialConsistency sets the consistency level for the // serial phase of conditional updates. That consistency can only be // either SERIAL or LOCAL_SERIAL and if not present, it defaults to // SERIAL. This option will be ignored for anything else that a // conditional update/insert. // // Only available for protocol 3 and above func (b *Batch) SerialConsistency(cons SerialConsistency) *Batch { b.serialCons = cons return b } // DefaultTimestamp will enable the with default timestamp flag on the query. // If enable, this will replace the server side assigned // timestamp as default timestamp. Note that a timestamp in the query itself // will still override this timestamp. This is entirely optional. // // Only available on protocol >= 3 func (b *Batch) DefaultTimestamp(enable bool) *Batch { b.defaultTimestamp = enable return b } // WithTimestamp will enable the with default timestamp flag on the query // like DefaultTimestamp does. But also allows to define value for timestamp. // It works the same way as USING TIMESTAMP in the query itself, but // should not break prepared query optimization. 
// // Only available on protocol >= 3 func (b *Batch) WithTimestamp(timestamp int64) *Batch { b.DefaultTimestamp(true) b.defaultTimestampValue = timestamp return b } func (b *Batch) attempt(keyspace string, end, start time.Time, iter *Iter, host *HostInfo) { latency := end.Sub(start) attempt, metricsForHost := b.metrics.attempt(1, latency, host, b.observer != nil) if b.observer == nil { return } statements := make([]string, len(b.Entries)) values := make([][]interface{}, len(b.Entries)) for i, entry := range b.Entries { statements[i] = entry.Stmt values[i] = entry.Args } b.observer.ObserveBatch(b.Context(), ObservedBatch{ Keyspace: keyspace, Statements: statements, Values: values, Start: start, End: end, // Rows not used in batch observations // TODO - might be able to support it when using BatchCAS Host: host, Metrics: metricsForHost, Err: iter.err, Attempt: attempt, }) } func (b *Batch) GetRoutingKey() ([]byte, error) { if b.routingKey != nil { return b.routingKey, nil } if len(b.Entries) == 0 { return nil, nil } entry := b.Entries[0] if entry.binding != nil { // bindings do not have the values let's skip it like Query does. 
return nil, nil } // try to determine the routing key routingKeyInfo, err := b.session.routingKeyInfo(b.Context(), entry.Stmt) if err != nil { return nil, err } return createRoutingKey(routingKeyInfo, entry.Args) } func createRoutingKey(routingKeyInfo *routingKeyInfo, values []interface{}) ([]byte, error) { if routingKeyInfo == nil { return nil, nil } if len(routingKeyInfo.indexes) == 1 { // single column routing key routingKey, err := Marshal( routingKeyInfo.types[0], values[routingKeyInfo.indexes[0]], ) if err != nil { return nil, err } return routingKey, nil } // composite routing key buf := bytes.NewBuffer(make([]byte, 0, 256)) for i := range routingKeyInfo.indexes { encoded, err := Marshal( routingKeyInfo.types[i], values[routingKeyInfo.indexes[i]], ) if err != nil { return nil, err } lenBuf := []byte{0x00, 0x00} binary.BigEndian.PutUint16(lenBuf, uint16(len(encoded))) buf.Write(lenBuf) buf.Write(encoded) buf.WriteByte(0x00) } routingKey := buf.Bytes() return routingKey, nil } func (b *Batch) borrowForExecution() { // empty, because Batch has no equivalent of Query.Release() // that would race with speculative executions. } func (b *Batch) releaseAfterExecution() { // empty, because Batch has no equivalent of Query.Release() // that would race with speculative executions. 
} type BatchType byte const ( LoggedBatch BatchType = 0 UnloggedBatch BatchType = 1 CounterBatch BatchType = 2 ) type BatchEntry struct { Stmt string Args []interface{} Idempotent bool binding func(q *QueryInfo) ([]interface{}, error) } type ColumnInfo struct { Keyspace string Table string Name string TypeInfo TypeInfo } func (c ColumnInfo) String() string { return fmt.Sprintf("[column keyspace=%s table=%s name=%s type=%v]", c.Keyspace, c.Table, c.Name, c.TypeInfo) } // routing key indexes LRU cache type routingKeyInfoLRU struct { lru *lru.Cache mu sync.Mutex } type routingKeyInfo struct { indexes []int types []TypeInfo keyspace string table string } func (r *routingKeyInfo) String() string { return fmt.Sprintf("routing key index=%v types=%v", r.indexes, r.types) } func (r *routingKeyInfoLRU) Remove(key string) { r.mu.Lock() r.lru.Remove(key) r.mu.Unlock() } // Max adjusts the maximum size of the cache and cleans up the oldest records if // the new max is lower than the previous value. Not concurrency safe. func (r *routingKeyInfoLRU) Max(max int) { r.mu.Lock() for r.lru.Len() > max { r.lru.RemoveOldest() } r.lru.MaxEntries = max r.mu.Unlock() } type inflightCachedEntry struct { wg sync.WaitGroup err error value interface{} } // Tracer is the interface implemented by query tracers. Tracers have the // ability to obtain a detailed event log of all events that happened during // the execution of a query from Cassandra. Gathering this information might // be essential for debugging and optimizing queries, but this feature should // not be used on production systems with very high load. type Tracer interface { Trace(traceId []byte) } type traceWriter struct { session *Session w io.Writer mu sync.Mutex } // NewTraceWriter returns a simple Tracer implementation that outputs // the event log in a textual format. 
func NewTraceWriter(session *Session, w io.Writer) Tracer { return &traceWriter{session: session, w: w} } func (t *traceWriter) Trace(traceId []byte) { var ( coordinator string duration int ) iter := t.session.control.query(`SELECT coordinator, duration FROM system_traces.sessions WHERE session_id = ?`, traceId) iter.Scan(&coordinator, &duration) if err := iter.Close(); err != nil { t.mu.Lock() fmt.Fprintln(t.w, "Error:", err) t.mu.Unlock() return } var ( timestamp time.Time activity string source string elapsed int thread string ) t.mu.Lock() defer t.mu.Unlock() fmt.Fprintf(t.w, "Tracing session %016x (coordinator: %s, duration: %v):\n", traceId, coordinator, time.Duration(duration)*time.Microsecond) iter = t.session.control.query(`SELECT event_id, activity, source, source_elapsed, thread FROM system_traces.events WHERE session_id = ?`, traceId) for iter.Scan(×tamp, &activity, &source, &elapsed, &thread) { fmt.Fprintf(t.w, "%s: %s [%s] (source: %s, elapsed: %d)\n", timestamp.Format("2006/01/02 15:04:05.999999"), activity, thread, source, elapsed) } if err := iter.Close(); err != nil { fmt.Fprintln(t.w, "Error:", err) } } type ObservedQuery struct { Keyspace string Statement string // Values holds a slice of bound values for the query. // Do not modify the values here, they are shared with multiple goroutines. Values []interface{} Start time.Time // time immediately before the query was called End time.Time // time immediately after the query returned // Rows is the number of rows in the current iter. // In paginated queries, rows from previous scans are not counted. // Rows is not used in batch queries and remains at the default value Rows int // Host is the informations about the host that performed the query Host *HostInfo // The metrics per this host Metrics *hostMetrics // Err is the error in the query. 
// It only tracks network errors or errors of bad cassandra syntax, in particular selects with no match return nil error Err error // Attempt is the index of attempt at executing this query. // The first attempt is number zero and any retries have non-zero attempt number. Attempt int } // QueryObserver is the interface implemented by query observers / stat collectors. // // Experimental, this interface and use may change type QueryObserver interface { // ObserveQuery gets called on every query to cassandra, including all queries in an iterator when paging is enabled. // It doesn't get called if there is no query because the session is closed or there are no connections available. // The error reported only shows query errors, i.e. if a SELECT is valid but finds no matches it will be nil. ObserveQuery(context.Context, ObservedQuery) } type ObservedBatch struct { Keyspace string Statements []string // Values holds a slice of bound values for each statement. // Values[i] are bound values passed to Statements[i]. // Do not modify the values here, they are shared with multiple goroutines. Values [][]interface{} Start time.Time // time immediately before the batch query was called End time.Time // time immediately after the batch query returned // Host is the informations about the host that performed the batch Host *HostInfo // Err is the error in the batch query. // It only tracks network errors or errors of bad cassandra syntax, in particular selects with no match return nil error Err error // The metrics per this host Metrics *hostMetrics // Attempt is the index of attempt at executing this query. // The first attempt is number zero and any retries have non-zero attempt number. Attempt int } // BatchObserver is the interface implemented by batch observers / stat collectors. type BatchObserver interface { // ObserveBatch gets called on every batch query to cassandra. // It also gets called once for each query in a batch. 
// It doesn't get called if there is no query because the session is closed or there are no connections available. // The error reported only shows query errors, i.e. if a SELECT is valid but finds no matches it will be nil. // Unlike QueryObserver.ObserveQuery it does no reporting on rows read. ObserveBatch(context.Context, ObservedBatch) } type ObservedConnect struct { // Host is the information about the host about to connect Host *HostInfo Start time.Time // time immediately before the dial is called End time.Time // time immediately after the dial returned // Err is the connection error (if any) Err error } // ConnectObserver is the interface implemented by connect observers / stat collectors. type ConnectObserver interface { // ObserveConnect gets called when a new connection to cassandra is made. ObserveConnect(ObservedConnect) } type Error struct { Code int Message string } func (e Error) Error() string { return e.Message } var ( ErrNotFound = errors.New("not found") ErrUnavailable = errors.New("unavailable") ErrUnsupported = errors.New("feature not supported") ErrTooManyStmts = errors.New("too many statements") ErrUseStmt = errors.New("use statements aren't supported. Please see https://github.com/apache/cassandra-gocql-driver for explanation.") ErrSessionClosed = errors.New("session has been closed") ErrNoConnections = errors.New("gocql: no hosts available in the pool") ErrNoKeyspace = errors.New("no keyspace provided") ErrKeyspaceDoesNotExist = errors.New("keyspace does not exist") ErrNoMetadata = errors.New("no metadata available") ) type ErrProtocol struct{ error } func NewErrProtocol(format string, args ...interface{}) error { return ErrProtocol{fmt.Errorf(format, args...)} } // BatchSizeMaximum is the maximum number of statements a batch operation can have. // This limit is set by cassandra and could change in the future. 
const BatchSizeMaximum = 65535 cassandra-gocql-driver-1.7.0/session_connect_test.go000066400000000000000000000043151467504044300226460ustar00rootroot00000000000000/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. 
*/ package gocql import ( "net" "strconv" "sync" ) type OneConnTestServer struct { Err error Addr net.IP Port int listener net.Listener acceptChan chan struct{} mu sync.Mutex closed bool } func NewOneConnTestServer() (*OneConnTestServer, error) { lstn, err := net.Listen("tcp4", "localhost:0") if err != nil { return nil, err } addr, port := parseAddressPort(lstn.Addr().String()) return &OneConnTestServer{ listener: lstn, acceptChan: make(chan struct{}), Addr: addr, Port: port, }, nil } func (c *OneConnTestServer) Accepted() chan struct{} { return c.acceptChan } func (c *OneConnTestServer) Close() { c.lockedClose() } func (c *OneConnTestServer) Serve() { conn, err := c.listener.Accept() c.Err = err if conn != nil { conn.Close() } c.lockedClose() } func (c *OneConnTestServer) lockedClose() { c.mu.Lock() defer c.mu.Unlock() if !c.closed { close(c.acceptChan) c.listener.Close() c.closed = true } } func parseAddressPort(hostPort string) (net.IP, int) { host, portStr, err := net.SplitHostPort(hostPort) if err != nil { return net.ParseIP(""), 0 } port, _ := strconv.Atoi(portStr) return net.ParseIP(host), port } cassandra-gocql-driver-1.7.0/session_test.go000066400000000000000000000225471467504044300211440ustar00rootroot00000000000000//go:build all || cassandra // +build all cassandra /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. */ package gocql import ( "context" "fmt" "net" "testing" ) func TestSessionAPI(t *testing.T) { cfg := &ClusterConfig{} s := &Session{ cfg: *cfg, cons: Quorum, policy: RoundRobinHostPolicy(), logger: cfg.logger(), } s.pool = cfg.PoolConfig.buildPool(s) s.executor = &queryExecutor{ pool: s.pool, policy: s.policy, } defer s.Close() s.SetConsistency(All) if s.cons != All { t.Fatalf("expected consistency 'All', got '%v'", s.cons) } s.SetPageSize(100) if s.pageSize != 100 { t.Fatalf("expected pageSize 100, got %v", s.pageSize) } s.SetPrefetch(0.75) if s.prefetch != 0.75 { t.Fatalf("expceted prefetch 0.75, got %v", s.prefetch) } trace := &traceWriter{} s.SetTrace(trace) if s.trace != trace { t.Fatalf("expected traceWriter '%v',got '%v'", trace, s.trace) } qry := s.Query("test", 1) if v, ok := qry.values[0].(int); !ok { t.Fatalf("expected qry.values[0] to be an int, got %v", qry.values[0]) } else if v != 1 { t.Fatalf("expceted qry.values[0] to be 1, got %v", v) } else if qry.stmt != "test" { t.Fatalf("expected qry.stmt to be 'test', got '%v'", qry.stmt) } boundQry := s.Bind("test", func(q *QueryInfo) ([]interface{}, error) { return nil, nil }) if boundQry.binding == nil { t.Fatal("expected qry.binding to be defined, got nil") } else if boundQry.stmt != "test" { t.Fatalf("expected qry.stmt to be 'test', got '%v'", boundQry.stmt) } itr := s.executeQuery(qry) if itr.err != ErrNoConnections { t.Fatalf("expected itr.err to be '%v', got '%v'", ErrNoConnections, itr.err) } testBatch := s.NewBatch(LoggedBatch) testBatch.Query("test") err := s.ExecuteBatch(testBatch) if err != ErrNoConnections { t.Fatalf("expected session.ExecuteBatch to return '%v', 
got '%v'", ErrNoConnections, err) } s.Close() if !s.Closed() { t.Fatal("expected s.Closed() to be true, got false") } //Should just return cleanly s.Close() err = s.ExecuteBatch(testBatch) if err != ErrSessionClosed { t.Fatalf("expected session.ExecuteBatch to return '%v', got '%v'", ErrSessionClosed, err) } } type funcQueryObserver func(context.Context, ObservedQuery) func (f funcQueryObserver) ObserveQuery(ctx context.Context, o ObservedQuery) { f(ctx, o) } func TestQueryBasicAPI(t *testing.T) { qry := &Query{routingInfo: &queryRoutingInfo{}} // Initiate host ip := "127.0.0.1" qry.metrics = preFilledQueryMetrics(map[string]*hostMetrics{ip: {Attempts: 0, TotalLatency: 0}}) if qry.Latency() != 0 { t.Fatalf("expected Query.Latency() to return 0, got %v", qry.Latency()) } qry.metrics = preFilledQueryMetrics(map[string]*hostMetrics{ip: {Attempts: 2, TotalLatency: 4}}) if qry.Attempts() != 2 { t.Fatalf("expected Query.Attempts() to return 2, got %v", qry.Attempts()) } if qry.Latency() != 2 { t.Fatalf("expected Query.Latency() to return 2, got %v", qry.Latency()) } qry.AddAttempts(2, &HostInfo{hostname: ip, connectAddress: net.ParseIP(ip), port: 9042}) if qry.Attempts() != 4 { t.Fatalf("expected Query.Attempts() to return 4, got %v", qry.Attempts()) } qry.Consistency(All) if qry.GetConsistency() != All { t.Fatalf("expected Query.GetConsistency to return 'All', got '%s'", qry.GetConsistency()) } trace := &traceWriter{} qry.Trace(trace) if qry.trace != trace { t.Fatalf("expected Query.Trace to be '%v', got '%v'", trace, qry.trace) } observer := funcQueryObserver(func(context.Context, ObservedQuery) {}) qry.Observer(observer) if qry.observer == nil { // can't compare func to func, checking not nil instead t.Fatal("expected Query.QueryObserver to be set, got nil") } qry.PageSize(10) if qry.pageSize != 10 { t.Fatalf("expected Query.PageSize to be 10, got %v", qry.pageSize) } qry.Prefetch(0.75) if qry.prefetch != 0.75 { t.Fatalf("expected Query.Prefetch to be 0.75, got %v", 
qry.prefetch) } rt := &SimpleRetryPolicy{NumRetries: 3} if qry.RetryPolicy(rt); qry.rt != rt { t.Fatalf("expected Query.RetryPolicy to be '%v', got '%v'", rt, qry.rt) } qry.Bind(qry) if qry.values[0] != qry { t.Fatalf("expected Query.Values[0] to be '%v', got '%v'", qry, qry.values[0]) } } func TestQueryShouldPrepare(t *testing.T) { toPrepare := []string{"select * ", "INSERT INTO", "update table", "delete from", "begin batch"} cantPrepare := []string{"create table", "USE table", "LIST keyspaces", "alter table", "drop table", "grant user", "revoke user"} q := &Query{routingInfo: &queryRoutingInfo{}} for i := 0; i < len(toPrepare); i++ { q.stmt = toPrepare[i] if !q.shouldPrepare() { t.Fatalf("expected Query.shouldPrepare to return true, got false for statement '%v'", toPrepare[i]) } } for i := 0; i < len(cantPrepare); i++ { q.stmt = cantPrepare[i] if q.shouldPrepare() { t.Fatalf("expected Query.shouldPrepare to return false, got true for statement '%v'", cantPrepare[i]) } } } func TestBatchBasicAPI(t *testing.T) { cfg := &ClusterConfig{RetryPolicy: &SimpleRetryPolicy{NumRetries: 2}} s := &Session{ cfg: *cfg, cons: Quorum, logger: cfg.logger(), } defer s.Close() s.pool = cfg.PoolConfig.buildPool(s) // Test UnloggedBatch b := s.NewBatch(UnloggedBatch) if b.Type != UnloggedBatch { t.Fatalf("expceted batch.Type to be '%v', got '%v'", UnloggedBatch, b.Type) } else if b.rt != cfg.RetryPolicy { t.Fatalf("expceted batch.RetryPolicy to be '%v', got '%v'", cfg.RetryPolicy, b.rt) } // Test LoggedBatch b = s.NewBatch(LoggedBatch) if b.Type != LoggedBatch { t.Fatalf("expected batch.Type to be '%v', got '%v'", LoggedBatch, b.Type) } ip := "127.0.0.1" // Test attempts b.metrics = preFilledQueryMetrics(map[string]*hostMetrics{ip: {Attempts: 1}}) if b.Attempts() != 1 { t.Fatalf("expected batch.Attempts() to return %v, got %v", 1, b.Attempts()) } b.AddAttempts(2, &HostInfo{hostname: ip, connectAddress: net.ParseIP(ip), port: 9042}) if b.Attempts() != 3 { t.Fatalf("expected 
batch.Attempts() to return %v, got %v", 3, b.Attempts()) } // Test latency if b.Latency() != 0 { t.Fatalf("expected batch.Latency() to be 0, got %v", b.Latency()) } b.metrics = preFilledQueryMetrics(map[string]*hostMetrics{ip: {Attempts: 1, TotalLatency: 4}}) if b.Latency() != 4 { t.Fatalf("expected batch.Latency() to return %v, got %v", 4, b.Latency()) } // Test Consistency b.Cons = One if b.GetConsistency() != One { t.Fatalf("expected batch.GetConsistency() to return 'One', got '%s'", b.GetConsistency()) } trace := &traceWriter{} b.Trace(trace) if b.trace != trace { t.Fatalf("expected batch.Trace to be '%v', got '%v'", trace, b.trace) } // Test batch.Query() b.Query("test", 1) if b.Entries[0].Stmt != "test" { t.Fatalf("expected batch.Entries[0].Statement to be 'test', got '%v'", b.Entries[0].Stmt) } else if b.Entries[0].Args[0].(int) != 1 { t.Fatalf("expected batch.Entries[0].Args[0] to be 1, got %v", b.Entries[0].Args[0]) } b.Bind("test2", func(q *QueryInfo) ([]interface{}, error) { return nil, nil }) if b.Entries[1].Stmt != "test2" { t.Fatalf("expected batch.Entries[1].Statement to be 'test2', got '%v'", b.Entries[1].Stmt) } else if b.Entries[1].binding == nil { t.Fatal("expected batch.Entries[1].binding to be defined, got nil") } // Test RetryPolicy r := &SimpleRetryPolicy{NumRetries: 4} b.RetryPolicy(r) if b.rt != r { t.Fatalf("expected batch.RetryPolicy to be '%v', got '%v'", r, b.rt) } if b.Size() != 2 { t.Fatalf("expected batch.Size() to return 2, got %v", b.Size()) } } func TestConsistencyNames(t *testing.T) { names := map[fmt.Stringer]string{ Any: "ANY", One: "ONE", Two: "TWO", Three: "THREE", Quorum: "QUORUM", All: "ALL", LocalQuorum: "LOCAL_QUORUM", EachQuorum: "EACH_QUORUM", Serial: "SERIAL", LocalSerial: "LOCAL_SERIAL", LocalOne: "LOCAL_ONE", } for k, v := range names { if k.String() != v { t.Fatalf("expected '%v', got '%v'", v, k.String()) } } } func TestIsUseStatement(t *testing.T) { testCases := []struct { input string exp bool }{ {"USE foo", 
true}, {"USe foo", true}, {"UsE foo", true}, {"Use foo", true}, {"uSE foo", true}, {"uSe foo", true}, {"usE foo", true}, {"use foo", true}, {"SELECT ", false}, {"UPDATE ", false}, {"INSERT ", false}, {"", false}, } for _, tc := range testCases { v := isUseStatement(tc.input) if v != tc.exp { t.Fatalf("expected %v but got %v for statement %q", tc.exp, v, tc.input) } } } cassandra-gocql-driver-1.7.0/session_unit_test.go000066400000000000000000000036361467504044300222010ustar00rootroot00000000000000//go:build all || unit // +build all unit /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. 
*/ package gocql import ( "context" "testing" ) func TestAsyncSessionInit(t *testing.T) { // Build a 3 node cluster to test host metric mapping var addresses = []string{ "127.0.0.1", "127.0.0.2", "127.0.0.3", } // only build 1 of the servers so that we can test not connecting to the last // one srv := NewTestServerWithAddress(addresses[0]+":0", t, defaultProto, context.Background()) defer srv.Stop() // just choose any port cluster := testCluster(defaultProto, srv.Address, addresses[1]+":9999", addresses[2]+":9999") cluster.PoolConfig.HostSelectionPolicy = SingleHostReadyPolicy(RoundRobinHostPolicy()) db, err := cluster.CreateSession() if err != nil { t.Fatalf("NewCluster: %v", err) } defer db.Close() // make sure the session works if err := db.Query("void").Exec(); err != nil { t.Fatalf("unexpected error from void") } } cassandra-gocql-driver-1.7.0/stress_test.go000066400000000000000000000050111467504044300207670ustar00rootroot00000000000000//go:build all || cassandra // +build all cassandra /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. 
* See the NOTICE file distributed with this work for additional information. */ package gocql import ( "sync/atomic" "testing" ) func BenchmarkConnStress(b *testing.B) { const workers = 16 cluster := createCluster() cluster.NumConns = 1 session := createSessionFromCluster(cluster, b) defer session.Close() if err := createTable(session, "CREATE TABLE IF NOT EXISTS conn_stress (id int primary key)"); err != nil { b.Fatal(err) } var seed uint64 writer := func(pb *testing.PB) { seed := atomic.AddUint64(&seed, 1) var i uint64 = 0 for pb.Next() { if err := session.Query("insert into conn_stress (id) values (?)", i*seed).Exec(); err != nil { b.Error(err) return } i++ } } b.SetParallelism(workers) b.RunParallel(writer) } func BenchmarkConnRoutingKey(b *testing.B) { const workers = 16 cluster := createCluster() cluster.NumConns = 1 cluster.PoolConfig.HostSelectionPolicy = TokenAwareHostPolicy(RoundRobinHostPolicy()) session := createSessionFromCluster(cluster, b) defer session.Close() if err := createTable(session, "CREATE TABLE IF NOT EXISTS routing_key_stress (id int primary key)"); err != nil { b.Fatal(err) } var seed uint64 writer := func(pb *testing.PB) { seed := atomic.AddUint64(&seed, 1) var i uint64 = 0 query := session.Query("insert into routing_key_stress (id) values (?)") for pb.Next() { if _, err := query.Bind(i * seed).GetRoutingKey(); err != nil { b.Error(err) return } i++ } } b.SetParallelism(workers) b.RunParallel(writer) } cassandra-gocql-driver-1.7.0/testdata/000077500000000000000000000000001467504044300176725ustar00rootroot00000000000000cassandra-gocql-driver-1.7.0/testdata/frames/000077500000000000000000000000001467504044300211475ustar00rootroot00000000000000cassandra-gocql-driver-1.7.0/testdata/frames/bench_parse_result.gz000077500000000000000000000023701467504044300253650ustar00rootroot00000000000000U"Vbench_parse_resulto6nj'N۵E uuMV KMM"5J ۷#%;eYrL*a?'x/{,wsF{opC ZhXQ/xޟy2!* 8TÀ8W?{^J٘>HJ2'dN7W8~I>3,jh?lv K!O&5C ȄC SKy,A&ܙH9Z̸s6] 74D8 
p?INm~?61h8(naؗpCÂ^!]__2g?OPU͵h? *&;!^XFub2yN5*x֚ |.VEfow%Lol?X1kQaΟM0$Sr3Eӝ@rSeŞ^rz/P>pΖu(E Z`%l0&3ءL# ͎Arʎ%L PuN,7I@[>L{]aG]`B%h5šS!\,_,._tUk#nf衆R^\wl9f# TYPKlGS6#'z0RC HMY97-'LLvQ>;^gx!4";1ϳG46 MeW@Z7&Jt3ߌN,,Ԣۡ\?״Y' ;ͪs Khԕw,i lbyY}w2,1&%̭e-!˱ۘќ0٢Ի4fc#Ua\k.GHHn3/ڪZOa[Nb5ޭś_d/z#>~կml EH_TMwe[c;b;o(V !mxr%Ī>pզx䪌FIC+3 2_gn=纗JvL~[͛<2;!'8Vp!,ë=f~?XKH njԌfJ2e ;cV+kn7 )>-Ƭ-ZM(g)<3n?fatxJgJlH5JcKGYsX}Yk5QG@}}tr;\j)_qJt<^L(cassandra-gocql-driver-1.7.0/testdata/pki/000077500000000000000000000000001467504044300204555ustar00rootroot00000000000000cassandra-gocql-driver-1.7.0/testdata/pki/.keystore000066400000000000000000000075011467504044300223260ustar00rootroot00000000000000 cassandraV 0 0 +* n2+ޥ̧4zDm9$`$A*۝/#SfŔ/0_!aOJaS~Dxo3׼"K+!ׁ)$'g{y++y(m %,u.çj+rNCΩc} 'D@*sp jN6UuypIoyE,m+Y}3I/:eʩ:!V=]Z;4!6zygaΕ`w)RvuXQ:; 꼍A;(kRa%~U)#QlBKuHaD$ٿjJ/*+)ߠX5o̲xt-:jg4QhQAo 1kh{hɻĽ$*?b$Dvm#PA,ȏWHTʷxV<#%cOv̸~il}Vs0Y &Փ kqlޞ\@zTt%ҡ5a $@)+Ս'EMXf/H;B9{\"㐧PNi%B}uI#z* w 2Y45#{ ALHlx*c^Uy+dQuTڗ@dMhy&]AwZ˟!%ٕsСF,ո׽bA5Zw-1X@>j4ӭ*oJ8ٯ*$<<7:RBiB)R/#e yL`2r_NuvO,l뷫}R wha ]V<^^.E-Z&2*k ZBjPۚN4t/8g-U&oFsAc)Ve ?j=Hbj-8Y6VI;z{}lE82(ҾWYK> N#?[]AT@MFIç48I]_U*Hv4Hp'kRʗ.' 25fG1XL{dٰ>Rvd3?=(fPUl-:m}brnU0!J%"Y/w6;MJ!D{M B]N-Z(`B"#:rJ}\ v(wPيrU7Nhm"͚QXf@~J9r[yI@W:@#.e ~?tբjaurIl7@bk7u9oB3iU4rM' =tny$!] ^VG"E.} %R# 7%rD3V1+pdWqRLO? N_ڀhCyNJI\X.509q0m0UNDX6]淽0  *H  0 1 0 U ca0  240829155106Z21240805155106Z010U cassandra0"0  *H 0 ]m͞g%8R8X}3\<|z;Y@pFANАo3t9O?'e~ c’Tp]nUfhU34]?Tǯp q^i;pMXaj Kr^jtImYEڒ駶'qélCwk2'QnT 0,ZAPSހK*o]Ӷ^kү5ZJ A7;6 iRL,bD@}@@DZnb;+i[ROp;dMybfEIC#GKi#<.Ls~Qz5W r)a[k1:H@9>$}zF%`P 9H>V!: 6f,rפXJmge}j6n}ol(/2ґ#AxY:({Ӹygƭ:LqP86]G00 U00 U0^UW0USspiffe://test.cassandra.apache.org/cassandra-gocql-driver/integrationTest/cassandra0UveUkΤxvIK0U#0ͱ\Zr6s[0  *H  ^efWhY}9X/E5"!iS[ea':QdFM$}1x9>1#LCCI[#.uQOZe| hVu{:L4Y00,}0C%;џS>o'WP/J琓C8ȌNzA %cGq),$g%\VlR?ʆ~4rHIz%Fĝ@e~K݆ ? 
ԅc7r7`&ՙxI ٦ŏoT/A5W?e#e h)H%- GjZ{&dr-lgE?R W Up 7>sA^7ֺ$F", ȿff>1NTuud[Qnio~G8&yc‰4&lfq)AH|){xJRΫ껛z|WtTU9.980283rAcassandra-gocql-driver-1.7.0/testdata/pki/.truststore000066400000000000000000000024451467504044300227210ustar00rootroot00000000000000caOX.50900Πrx,/) i0  *H  0 1 0 U ca0  240829155104Z21240805155104Z0 1 0 U ca0"0  *H 0 ֽb ΄Ds@W Legyzkʦ]PHJBಆ0}߱Vb:bH ƹWnlNNeDTB3iV74R|aDep*?l/vD~zGyØ]Lʵb!IĝI{V2O3Wy4?TZG:^%K:GYFփYȒiF-H/6Q,`НkZti@ACh+[!d-b؀_ƣMD?b)V#pd~WT& o04'FIHA%RO-(VfzNNn}{nqsѪ%0 Pz_7-$d|=Oߤ/ja-m*PҼsL  ,$ǡמqÀdgێT?;e)Zn8d.3p{8ʬA<0:0 U00 U0Uͱ\Zr6s[0  *H  |搯ŇKRo${(#' +OXIHq6:;Eu6zbAmW6r /5|LJaw)( -2DY_DW/1$-a@;j0Sv6LXs'a&)w 2ۮ,2y"68syK.T\TG/`wJ .Os RIm}XQE>aǓʙRFAO$ Sp@v xϡX y׊6Zwp0?pGI0;;J>Ar #ro"y0-1TFpms0-YB19r@:]}ieHVSUum+5Qhrr=d7LM,9$=VDẉ 127 { val.Sub(val, maxHashInt) val.Abs(val) } return (*randomToken)(val) } func (p randomPartitioner) ParseString(str string) token { val := new(big.Int) val.SetString(str, 10) return (*randomToken)(val) } func (r *randomToken) String() string { return (*big.Int)(r).String() } func (r *randomToken) Less(token token) bool { return -1 == (*big.Int)(r).Cmp((*big.Int)(token.(*randomToken))) } type hostToken struct { token token host *HostInfo } func (ht hostToken) String() string { return fmt.Sprintf("{token=%v host=%v}", ht.token, ht.host.HostID()) } // a data structure for organizing the relationship between tokens and hosts type tokenRing struct { partitioner partitioner // tokens map token range to primary replica. // The elements in tokens are sorted by token ascending. // The range for a given item in tokens starts after preceding range and ends with the token specified in // token. The end token is part of the range. // The lowest (i.e. index 0) range wraps around the ring (its preceding range is the one with largest index). 
tokens []hostToken hosts []*HostInfo } func newTokenRing(partitioner string, hosts []*HostInfo) (*tokenRing, error) { tokenRing := &tokenRing{ hosts: hosts, } if strings.HasSuffix(partitioner, "Murmur3Partitioner") { tokenRing.partitioner = murmur3Partitioner{} } else if strings.HasSuffix(partitioner, "OrderedPartitioner") { tokenRing.partitioner = orderedPartitioner{} } else if strings.HasSuffix(partitioner, "RandomPartitioner") { tokenRing.partitioner = randomPartitioner{} } else { return nil, fmt.Errorf("unsupported partitioner '%s'", partitioner) } for _, host := range hosts { for _, strToken := range host.Tokens() { token := tokenRing.partitioner.ParseString(strToken) tokenRing.tokens = append(tokenRing.tokens, hostToken{token, host}) } } sort.Sort(tokenRing) return tokenRing, nil } func (t *tokenRing) Len() int { return len(t.tokens) } func (t *tokenRing) Less(i, j int) bool { return t.tokens[i].token.Less(t.tokens[j].token) } func (t *tokenRing) Swap(i, j int) { t.tokens[i], t.tokens[j] = t.tokens[j], t.tokens[i] } func (t *tokenRing) String() string { buf := &bytes.Buffer{} buf.WriteString("TokenRing(") if t.partitioner != nil { buf.WriteString(t.partitioner.Name()) } buf.WriteString("){") sep := "" for i, th := range t.tokens { buf.WriteString(sep) sep = "," buf.WriteString("\n\t[") buf.WriteString(strconv.Itoa(i)) buf.WriteString("]") buf.WriteString(th.token.String()) buf.WriteString(":") buf.WriteString(th.host.ConnectAddress().String()) } buf.WriteString("\n}") return string(buf.Bytes()) } func (t *tokenRing) GetHostForToken(token token) (host *HostInfo, endToken token) { if t == nil || len(t.tokens) == 0 { return nil, nil } // find the primary replica p := sort.Search(len(t.tokens), func(i int) bool { return !t.tokens[i].token.Less(token) }) if p == len(t.tokens) { // wrap around to the first in the ring p = 0 } v := t.tokens[p] return v.host, v.token } 
cassandra-gocql-driver-1.7.0/token_test.go000066400000000000000000000264431467504044300206000ustar00rootroot00000000000000// Copyright (c) 2015 The gocql Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. 
*/ package gocql import ( "bytes" "fmt" "math/big" "net" "sort" "strconv" "testing" ) // Tests of the murmur3Patitioner func TestMurmur3Partitioner(t *testing.T) { token := murmur3Partitioner{}.ParseString("-1053604476080545076") if "-1053604476080545076" != token.String() { t.Errorf("Expected '-1053604476080545076' but was '%s'", token) } // at least verify that the partitioner // doesn't return nil pk, _ := marshalInt(nil, 1) token = murmur3Partitioner{}.Hash(pk) if token == nil { t.Fatal("token was nil") } } // Tests of the murmur3Token func TestMurmur3Token(t *testing.T) { if murmur3Token(42).Less(murmur3Token(42)) { t.Errorf("Expected Less to return false, but was true") } if !murmur3Token(-42).Less(murmur3Token(42)) { t.Errorf("Expected Less to return true, but was false") } if murmur3Token(42).Less(murmur3Token(-42)) { t.Errorf("Expected Less to return false, but was true") } } // Tests of the orderedPartitioner func TestOrderedPartitioner(t *testing.T) { // at least verify that the partitioner // doesn't return nil p := orderedPartitioner{} pk, _ := marshalInt(nil, 1) token := p.Hash(pk) if token == nil { t.Fatal("token was nil") } str := token.String() parsedToken := p.ParseString(str) if !bytes.Equal([]byte(token.(orderedToken)), []byte(parsedToken.(orderedToken))) { t.Errorf("Failed to convert to and from a string %s expected %x but was %x", str, []byte(token.(orderedToken)), []byte(parsedToken.(orderedToken)), ) } } // Tests of the orderedToken func TestOrderedToken(t *testing.T) { if orderedToken([]byte{0, 0, 4, 2}).Less(orderedToken([]byte{0, 0, 4, 2})) { t.Errorf("Expected Less to return false, but was true") } if !orderedToken([]byte{0, 0, 3}).Less(orderedToken([]byte{0, 0, 4, 2})) { t.Errorf("Expected Less to return true, but was false") } if orderedToken([]byte{0, 0, 4, 2}).Less(orderedToken([]byte{0, 0, 3})) { t.Errorf("Expected Less to return false, but was true") } } // Tests of the randomPartitioner func TestRandomPartitioner(t *testing.T) { 
// at least verify that the partitioner // doesn't return nil p := randomPartitioner{} pk, _ := marshalInt(nil, 1) token := p.Hash(pk) if token == nil { t.Fatal("token was nil") } str := token.String() parsedToken := p.ParseString(str) if (*big.Int)(token.(*randomToken)).Cmp((*big.Int)(parsedToken.(*randomToken))) != 0 { t.Errorf("Failed to convert to and from a string %s expected %v but was %v", str, token, parsedToken, ) } } func TestRandomPartitionerMatchesReference(t *testing.T) { // example taken from datastax python driver // >>> from cassandra.metadata import MD5Token // >>> MD5Token.hash_fn("test") // 12707736894140473154801792860916528374L var p randomPartitioner expect := "12707736894140473154801792860916528374" actual := p.Hash([]byte("test")).String() if actual != expect { t.Errorf("expected random partitioner to generate tokens in the same way as the reference"+ " python client. Expected %s, but got %s", expect, actual) } } // Tests of the randomToken func TestRandomToken(t *testing.T) { if ((*randomToken)(big.NewInt(42))).Less((*randomToken)(big.NewInt(42))) { t.Errorf("Expected Less to return false, but was true") } if !((*randomToken)(big.NewInt(41))).Less((*randomToken)(big.NewInt(42))) { t.Errorf("Expected Less to return true, but was false") } if ((*randomToken)(big.NewInt(42))).Less((*randomToken)(big.NewInt(41))) { t.Errorf("Expected Less to return false, but was true") } } type intToken int func (i intToken) String() string { return strconv.Itoa(int(i)) } func (i intToken) Less(token token) bool { return i < token.(intToken) } // Test of the token ring implementation based on example at the start of this // page of documentation: // http://www.datastax.com/docs/0.8/cluster_architecture/partitioning func TestTokenRing_Int(t *testing.T) { host0 := &HostInfo{} host25 := &HostInfo{} host50 := &HostInfo{} host75 := &HostInfo{} ring := &tokenRing{ partitioner: nil, // these tokens and hosts are out of order to test sorting tokens: []hostToken{ 
{intToken(0), host0}, {intToken(50), host50}, {intToken(75), host75}, {intToken(25), host25}, }, } sort.Sort(ring) if host, endToken := ring.GetHostForToken(intToken(0)); host != host0 || endToken != intToken(0) { t.Error("Expected host 0 for token 0") } if host, endToken := ring.GetHostForToken(intToken(1)); host != host25 || endToken != intToken(25) { t.Error("Expected host 25 for token 1") } if host, endToken := ring.GetHostForToken(intToken(24)); host != host25 || endToken != intToken(25) { t.Error("Expected host 25 for token 24") } if host, endToken := ring.GetHostForToken(intToken(25)); host != host25 || endToken != intToken(25) { t.Error("Expected host 25 for token 25") } if host, endToken := ring.GetHostForToken(intToken(26)); host != host50 || endToken != intToken(50) { t.Error("Expected host 50 for token 26") } if host, endToken := ring.GetHostForToken(intToken(49)); host != host50 || endToken != intToken(50) { t.Error("Expected host 50 for token 49") } if host, endToken := ring.GetHostForToken(intToken(50)); host != host50 || endToken != intToken(50) { t.Error("Expected host 50 for token 50") } if host, endToken := ring.GetHostForToken(intToken(51)); host != host75 || endToken != intToken(75) { t.Error("Expected host 75 for token 51") } if host, endToken := ring.GetHostForToken(intToken(74)); host != host75 || endToken != intToken(75) { t.Error("Expected host 75 for token 74") } if host, endToken := ring.GetHostForToken(intToken(75)); host != host75 || endToken != intToken(75) { t.Error("Expected host 75 for token 75") } if host, endToken := ring.GetHostForToken(intToken(76)); host != host0 || endToken != intToken(0) { t.Error("Expected host 0 for token 76") } if host, endToken := ring.GetHostForToken(intToken(99)); host != host0 || endToken != intToken(0) { t.Error("Expected host 0 for token 99") } if host, endToken := ring.GetHostForToken(intToken(100)); host != host0 || endToken != intToken(0) { t.Error("Expected host 0 for token 100") } } // Test for 
the behavior of a nil pointer to tokenRing func TestTokenRing_Nil(t *testing.T) { var ring *tokenRing = nil if host, endToken := ring.GetHostForToken(nil); host != nil || endToken != nil { t.Error("Expected nil for nil token ring") } } // Test of the recognition of the partitioner class func TestTokenRing_UnknownPartition(t *testing.T) { _, err := newTokenRing("UnknownPartitioner", nil) if err == nil { t.Error("Expected error for unknown partitioner value, but was nil") } } func hostsForTests(n int) []*HostInfo { hosts := make([]*HostInfo, n) for i := 0; i < n; i++ { host := &HostInfo{ connectAddress: net.IPv4(1, 1, 1, byte(n)), tokens: []string{fmt.Sprintf("%d", n)}, } hosts[i] = host } return hosts } // Test of the tokenRing with the Murmur3Partitioner func TestTokenRing_Murmur3(t *testing.T) { // Note, strings are parsed directly to int64, they are not murmur3 hashed hosts := hostsForTests(4) ring, err := newTokenRing("Murmur3Partitioner", hosts) if err != nil { t.Fatalf("Failed to create token ring due to error: %v", err) } p := murmur3Partitioner{} for _, host := range hosts { actual, _ := ring.GetHostForToken(p.ParseString(host.tokens[0])) if !actual.ConnectAddress().Equal(host.ConnectAddress()) { t.Errorf("Expected address %v for token %q, but was %v", host.ConnectAddress(), host.tokens[0], actual.ConnectAddress()) } } actual, _ := ring.GetHostForToken(p.ParseString("12")) if !actual.ConnectAddress().Equal(hosts[1].ConnectAddress()) { t.Errorf("Expected address 1 for token \"12\", but was %s", actual.ConnectAddress()) } actual, _ = ring.GetHostForToken(p.ParseString("24324545443332")) if !actual.ConnectAddress().Equal(hosts[0].ConnectAddress()) { t.Errorf("Expected address 0 for token \"24324545443332\", but was %s", actual.ConnectAddress()) } } // Test of the tokenRing with the OrderedPartitioner func TestTokenRing_Ordered(t *testing.T) { // Tokens here more or less are similar layout to the int tokens above due // to each numeric character translating to a 
consistently offset byte. hosts := hostsForTests(4) ring, err := newTokenRing("OrderedPartitioner", hosts) if err != nil { t.Fatalf("Failed to create token ring due to error: %v", err) } p := orderedPartitioner{} var actual *HostInfo for _, host := range hosts { actual, _ := ring.GetHostForToken(p.ParseString(host.tokens[0])) if !actual.ConnectAddress().Equal(host.ConnectAddress()) { t.Errorf("Expected address %v for token %q, but was %v", host.ConnectAddress(), host.tokens[0], actual.ConnectAddress()) } } actual, _ = ring.GetHostForToken(p.ParseString("12")) if !actual.peer.Equal(hosts[1].peer) { t.Errorf("Expected address 1 for token \"12\", but was %s", actual.ConnectAddress()) } actual, _ = ring.GetHostForToken(p.ParseString("24324545443332")) if !actual.ConnectAddress().Equal(hosts[1].ConnectAddress()) { t.Errorf("Expected address 1 for token \"24324545443332\", but was %s", actual.ConnectAddress()) } } // Test of the tokenRing with the RandomPartitioner func TestTokenRing_Random(t *testing.T) { // String tokens are parsed into big.Int in base 10 hosts := hostsForTests(4) ring, err := newTokenRing("RandomPartitioner", hosts) if err != nil { t.Fatalf("Failed to create token ring due to error: %v", err) } p := randomPartitioner{} var actual *HostInfo for _, host := range hosts { actual, _ := ring.GetHostForToken(p.ParseString(host.tokens[0])) if !actual.ConnectAddress().Equal(host.ConnectAddress()) { t.Errorf("Expected address %v for token %q, but was %v", host.ConnectAddress(), host.tokens[0], actual.ConnectAddress()) } } actual, _ = ring.GetHostForToken(p.ParseString("12")) if !actual.peer.Equal(hosts[1].peer) { t.Errorf("Expected address 1 for token \"12\", but was %s", actual.ConnectAddress()) } actual, _ = ring.GetHostForToken(p.ParseString("24324545443332")) if !actual.ConnectAddress().Equal(hosts[0].ConnectAddress()) { t.Errorf("Expected address 1 for token \"24324545443332\", but was %s", actual.ConnectAddress()) } } 
cassandra-gocql-driver-1.7.0/topology.go000066400000000000000000000205341467504044300202700ustar00rootroot00000000000000/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. */ package gocql import ( "fmt" "sort" "strconv" "strings" ) type hostTokens struct { // token is end (inclusive) of token range these hosts belong to token token hosts []*HostInfo } // tokenRingReplicas maps token ranges to list of replicas. // The elements in tokenRingReplicas are sorted by token ascending. // The range for a given item in tokenRingReplicas starts after preceding range and ends with the token specified in // token. The end token is part of the range. // The lowest (i.e. index 0) range wraps around the ring (its preceding range is the one with largest index). 
type tokenRingReplicas []hostTokens func (h tokenRingReplicas) Less(i, j int) bool { return h[i].token.Less(h[j].token) } func (h tokenRingReplicas) Len() int { return len(h) } func (h tokenRingReplicas) Swap(i, j int) { h[i], h[j] = h[j], h[i] } func (h tokenRingReplicas) replicasFor(t token) *hostTokens { if len(h) == 0 { return nil } p := sort.Search(len(h), func(i int) bool { return !h[i].token.Less(t) }) if p >= len(h) { // rollover p = 0 } return &h[p] } type placementStrategy interface { replicaMap(tokenRing *tokenRing) tokenRingReplicas replicationFactor(dc string) int } func getReplicationFactorFromOpts(val interface{}) (int, error) { switch v := val.(type) { case int: if v < 0 { return 0, fmt.Errorf("invalid replication_factor %d", v) } return v, nil case string: n, err := strconv.Atoi(v) if err != nil { return 0, fmt.Errorf("invalid replication_factor %q: %v", v, err) } else if n < 0 { return 0, fmt.Errorf("invalid replication_factor %d", n) } return n, nil default: return 0, fmt.Errorf("unknown replication_factor type %T", v) } } func getStrategy(ks *KeyspaceMetadata, logger StdLogger) placementStrategy { switch { case strings.Contains(ks.StrategyClass, "SimpleStrategy"): rf, err := getReplicationFactorFromOpts(ks.StrategyOptions["replication_factor"]) if err != nil { logger.Printf("parse rf for keyspace %q: %v", ks.Name, err) return nil } return &simpleStrategy{rf: rf} case strings.Contains(ks.StrategyClass, "NetworkTopologyStrategy"): dcs := make(map[string]int) for dc, rf := range ks.StrategyOptions { if dc == "class" { continue } rf, err := getReplicationFactorFromOpts(rf) if err != nil { logger.Println("parse rf for keyspace %q, dc %q: %v", err) // skip DC if the rf is invalid/unsupported, so that we can at least work with other working DCs. 
continue } dcs[dc] = rf } return &networkTopology{dcs: dcs} case strings.Contains(ks.StrategyClass, "LocalStrategy"): return nil default: logger.Printf("parse rf for keyspace %q: unsupported strategy class: %v", ks.StrategyClass) return nil } } type simpleStrategy struct { rf int } func (s *simpleStrategy) replicationFactor(dc string) int { return s.rf } func (s *simpleStrategy) replicaMap(tokenRing *tokenRing) tokenRingReplicas { tokens := tokenRing.tokens ring := make(tokenRingReplicas, len(tokens)) for i, th := range tokens { replicas := make([]*HostInfo, 0, s.rf) seen := make(map[*HostInfo]bool) for j := 0; j < len(tokens) && len(replicas) < s.rf; j++ { h := tokens[(i+j)%len(tokens)] if !seen[h.host] { replicas = append(replicas, h.host) seen[h.host] = true } } ring[i] = hostTokens{th.token, replicas} } sort.Sort(ring) return ring } type networkTopology struct { dcs map[string]int } func (n *networkTopology) replicationFactor(dc string) int { return n.dcs[dc] } func (n *networkTopology) haveRF(replicaCounts map[string]int) bool { if len(replicaCounts) != len(n.dcs) { return false } for dc, rf := range n.dcs { if rf != replicaCounts[dc] { return false } } return true } func (n *networkTopology) replicaMap(tokenRing *tokenRing) tokenRingReplicas { dcRacks := make(map[string]map[string]struct{}, len(n.dcs)) // skipped hosts in a dc skipped := make(map[string][]*HostInfo, len(n.dcs)) // number of replicas per dc replicasInDC := make(map[string]int, len(n.dcs)) // dc -> racks seenDCRacks := make(map[string]map[string]struct{}, len(n.dcs)) for _, h := range tokenRing.hosts { dc := h.DataCenter() rack := h.Rack() racks, ok := dcRacks[dc] if !ok { racks = make(map[string]struct{}) dcRacks[dc] = racks } racks[rack] = struct{}{} } for dc, racks := range dcRacks { replicasInDC[dc] = 0 seenDCRacks[dc] = make(map[string]struct{}, len(racks)) } tokens := tokenRing.tokens replicaRing := make(tokenRingReplicas, 0, len(tokens)) var totalRF int for _, rf := range n.dcs { totalRF 
+= rf } for i, th := range tokenRing.tokens { if rf := n.dcs[th.host.DataCenter()]; rf == 0 { // skip this token since no replica in this datacenter. continue } for k, v := range skipped { skipped[k] = v[:0] } for dc := range n.dcs { replicasInDC[dc] = 0 for rack := range seenDCRacks[dc] { delete(seenDCRacks[dc], rack) } } replicas := make([]*HostInfo, 0, totalRF) for j := 0; j < len(tokens) && (len(replicas) < totalRF && !n.haveRF(replicasInDC)); j++ { // TODO: ensure we dont add the same host twice p := i + j if p >= len(tokens) { p -= len(tokens) } h := tokens[p].host dc := h.DataCenter() rack := h.Rack() rf := n.dcs[dc] if rf == 0 { // skip this DC, dont know about it or replication factor is zero continue } else if replicasInDC[dc] >= rf { if replicasInDC[dc] > rf { panic(fmt.Sprintf("replica overflow. rf=%d have=%d in dc %q", rf, replicasInDC[dc], dc)) } // have enough replicas in this DC continue } else if _, ok := dcRacks[dc][rack]; !ok { // dont know about this rack continue } racks := seenDCRacks[dc] if _, ok := racks[rack]; ok && len(racks) == len(dcRacks[dc]) { // we have been through all the racks and dont have RF yet, add this replicas = append(replicas, h) replicasInDC[dc]++ } else if !ok { if racks == nil { racks = make(map[string]struct{}, 1) seenDCRacks[dc] = racks } // new rack racks[rack] = struct{}{} replicas = append(replicas, h) r := replicasInDC[dc] + 1 if len(racks) == len(dcRacks[dc]) { // if we have been through all the racks, drain the rest of the skipped // hosts until we have RF. 
The next iteration will skip in the block // above skippedHosts := skipped[dc] var k int for ; k < len(skippedHosts) && r+k < rf; k++ { sh := skippedHosts[k] replicas = append(replicas, sh) } r += k skipped[dc] = skippedHosts[k:] } replicasInDC[dc] = r } else { // already seen this rack, keep hold of this host incase // we dont get enough for rf skipped[dc] = append(skipped[dc], h) } } if len(replicas) == 0 { panic(fmt.Sprintf("no replicas for token: %v", th.token)) } else if !replicas[0].Equal(th.host) { panic(fmt.Sprintf("first replica is not the primary replica for the token: expected %v got %v", replicas[0].ConnectAddress(), th.host.ConnectAddress())) } replicaRing = append(replicaRing, hostTokens{th.token, replicas}) } dcsWithReplicas := 0 for _, dc := range n.dcs { if dc > 0 { dcsWithReplicas++ } } if dcsWithReplicas == len(dcRacks) && len(replicaRing) != len(tokens) { panic(fmt.Sprintf("token map different size to token ring: got %d expected %d", len(replicaRing), len(tokens))) } return replicaRing } cassandra-gocql-driver-1.7.0/topology_test.go000066400000000000000000000135561467504044300213350ustar00rootroot00000000000000/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. */ package gocql import ( "fmt" "sort" "testing" ) func TestPlacementStrategy_SimpleStrategy(t *testing.T) { host0 := &HostInfo{hostId: "0"} host25 := &HostInfo{hostId: "25"} host50 := &HostInfo{hostId: "50"} host75 := &HostInfo{hostId: "75"} tokens := []hostToken{ {intToken(0), host0}, {intToken(25), host25}, {intToken(50), host50}, {intToken(75), host75}, } hosts := []*HostInfo{host0, host25, host50, host75} strat := &simpleStrategy{rf: 2} tokenReplicas := strat.replicaMap(&tokenRing{hosts: hosts, tokens: tokens}) if len(tokenReplicas) != len(tokens) { t.Fatalf("expected replica map to have %d items but has %d", len(tokens), len(tokenReplicas)) } for _, replicas := range tokenReplicas { if len(replicas.hosts) != strat.rf { t.Errorf("expected to have %d replicas got %d for token=%v", strat.rf, len(replicas.hosts), replicas.token) } } for i, token := range tokens { ht := tokenReplicas.replicasFor(token.token) if ht.token != token.token { t.Errorf("token %v not in replica map: %v", token, ht.hosts) } for j, replica := range ht.hosts { exp := tokens[(i+j)%len(tokens)].host if exp != replica { t.Errorf("expected host %v to be a replica of %v got %v", exp.hostId, token, replica.hostId) } } } } func TestPlacementStrategy_NetworkStrategy(t *testing.T) { const ( totalDCs = 3 racksPerDC = 3 hostsPerDC = 5 ) tests := []struct { name string strat *networkTopology expectedReplicaMapSize int }{ { name: "full", strat: &networkTopology{ dcs: map[string]int{ "dc1": 1, "dc2": 2, "dc3": 3, }, }, expectedReplicaMapSize: hostsPerDC * totalDCs, }, { name: "missing", strat: &networkTopology{ dcs: map[string]int{ "dc2": 2, "dc3": 3, }, }, expectedReplicaMapSize: hostsPerDC * 2, }, { name: "zero", strat: &networkTopology{ dcs: map[string]int{ "dc1": 0, "dc2": 
2, "dc3": 3, }, }, expectedReplicaMapSize: hostsPerDC * 2, }, } for _, test := range tests { test := test t.Run(test.name, func(t *testing.T) { var ( hosts []*HostInfo tokens []hostToken ) dcRing := make(map[string][]hostToken, totalDCs) for i := 0; i < totalDCs; i++ { var dcTokens []hostToken dc := fmt.Sprintf("dc%d", i+1) for j := 0; j < hostsPerDC; j++ { rack := fmt.Sprintf("rack%d", (j%racksPerDC)+1) h := &HostInfo{hostId: fmt.Sprintf("%s:%s:%d", dc, rack, j), dataCenter: dc, rack: rack} token := hostToken{ token: orderedToken([]byte(h.hostId)), host: h, } tokens = append(tokens, token) dcTokens = append(dcTokens, token) hosts = append(hosts, h) } sort.Sort(&tokenRing{tokens: dcTokens}) dcRing[dc] = dcTokens } if len(tokens) != hostsPerDC*totalDCs { t.Fatalf("expected %d tokens in the ring got %d", hostsPerDC*totalDCs, len(tokens)) } sort.Sort(&tokenRing{tokens: tokens}) var expReplicas int for _, rf := range test.strat.dcs { expReplicas += rf } tokenReplicas := test.strat.replicaMap(&tokenRing{hosts: hosts, tokens: tokens}) if len(tokenReplicas) != test.expectedReplicaMapSize { t.Fatalf("expected replica map to have %d items but has %d", test.expectedReplicaMapSize, len(tokenReplicas)) } if !sort.IsSorted(tokenReplicas) { t.Fatal("replica map was not sorted by token") } for token, replicas := range tokenReplicas { if len(replicas.hosts) != expReplicas { t.Fatalf("expected to have %d replicas got %d for token=%v", expReplicas, len(replicas.hosts), token) } } for dc, rf := range test.strat.dcs { if rf == 0 { continue } dcTokens := dcRing[dc] for i, th := range dcTokens { token := th.token allReplicas := tokenReplicas.replicasFor(token) if allReplicas.token != token { t.Fatalf("token %v not in replica map", token) } var replicas []*HostInfo for _, replica := range allReplicas.hosts { if replica.dataCenter == dc { replicas = append(replicas, replica) } } if len(replicas) != rf { t.Fatalf("expected %d replicas in dc %q got %d", rf, dc, len(replicas)) } var lastRack 
string for j, replica := range replicas { // expected is in the next rack var exp *HostInfo if lastRack == "" { // primary, first replica exp = dcTokens[(i+j)%len(dcTokens)].host } else { for k := 0; k < len(dcTokens); k++ { // walk around the ring from i + j to find the next host the // next rack p := (i + j + k) % len(dcTokens) h := dcTokens[p].host if h.rack != lastRack { exp = h break } } if exp.rack == lastRack { panic("no more racks") } } lastRack = replica.rack } } } }) } } cassandra-gocql-driver-1.7.0/tuple_test.go000066400000000000000000000246561467504044300206150ustar00rootroot00000000000000//go:build all || integration // +build all integration /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. 
*/ package gocql import ( "reflect" "testing" ) func TestTupleSimple(t *testing.T) { session := createSession(t) defer session.Close() if session.cfg.ProtoVersion < protoVersion3 { t.Skip("tuple types are only available of proto>=3") } err := createTable(session, `CREATE TABLE gocql_test.tuple_test( id int, coord frozen>, primary key(id))`) if err != nil { t.Fatal(err) } err = session.Query("INSERT INTO tuple_test(id, coord) VALUES(?, (?, ?))", 1, 100, -100).Exec() if err != nil { t.Fatal(err) } var ( id int coord struct { x int y int } ) iter := session.Query("SELECT id, coord FROM tuple_test WHERE id=?", 1) if err := iter.Scan(&id, &coord.x, &coord.y); err != nil { t.Fatal(err) } if id != 1 { t.Errorf("expected to get id=1 got: %v", id) } else if coord.x != 100 { t.Errorf("expected to get coord.x=100 got: %v", coord.x) } else if coord.y != -100 { t.Errorf("expected to get coord.y=-100 got: %v", coord.y) } } func TestTuple_NullTuple(t *testing.T) { session := createSession(t) defer session.Close() if session.cfg.ProtoVersion < protoVersion3 { t.Skip("tuple types are only available of proto>=3") } err := createTable(session, `CREATE TABLE gocql_test.tuple_nil_test( id int, coord frozen>, primary key(id))`) if err != nil { t.Fatal(err) } const id = 1 err = session.Query("INSERT INTO tuple_nil_test(id, coord) VALUES(?, (?, ?))", id, nil, nil).Exec() if err != nil { t.Fatal(err) } x := new(int) y := new(int) iter := session.Query("SELECT coord FROM tuple_nil_test WHERE id=?", id) if err := iter.Scan(&x, &y); err != nil { t.Fatal(err) } if x != nil { t.Fatalf("should be nil got %+#v", x) } else if y != nil { t.Fatalf("should be nil got %+#v", y) } } func TestTuple_TupleNotSet(t *testing.T) { session := createSession(t) defer session.Close() if session.cfg.ProtoVersion < protoVersion3 { t.Skip("tuple types are only available of proto>=3") } err := createTable(session, `CREATE TABLE gocql_test.tuple_not_set_test( id int, coord frozen>, primary key(id))`) if err != nil { 
t.Fatal(err) } const id = 1 err = session.Query("INSERT INTO tuple_not_set_test(id,coord) VALUES(?, (?,?))", id, 1, 2).Exec() if err != nil { t.Fatal(err) } err = session.Query("INSERT INTO tuple_not_set_test(id) VALUES(?)", id+1).Exec() if err != nil { t.Fatal(err) } x := new(int) y := new(int) iter := session.Query("SELECT coord FROM tuple_not_set_test WHERE id=?", id) if err := iter.Scan(x, y); err != nil { t.Fatal(err) } if x == nil || *x != 1 { t.Fatalf("x should be %d got %+#v, value=%d", 1, x, *x) } if y == nil || *y != 2 { t.Fatalf("y should be %d got %+#v, value=%d", 2, y, *y) } // Check if the supplied targets are reset to nil iter = session.Query("SELECT coord FROM tuple_not_set_test WHERE id=?", id+1) if err := iter.Scan(x, y); err != nil { t.Fatal(err) } if x == nil || *x != 0 { t.Fatalf("x should be %d got %+#v, value=%d", 0, x, *x) } if y == nil || *y != 0 { t.Fatalf("y should be %d got %+#v, value=%d", 0, y, *y) } } func TestTupleMapScan(t *testing.T) { session := createSession(t) defer session.Close() if session.cfg.ProtoVersion < protoVersion3 { t.Skip("tuple types are only available of proto>=3") } err := createTable(session, `CREATE TABLE gocql_test.tuple_map_scan( id int, val frozen>, primary key(id))`) if err != nil { t.Fatal(err) } if err := session.Query(`INSERT INTO tuple_map_scan (id, val) VALUES (1, (1, 2));`).Exec(); err != nil { t.Fatal(err) } m := make(map[string]interface{}) err = session.Query(`SELECT * FROM tuple_map_scan`).MapScan(m) if err != nil { t.Fatal(err) } if m["val[0]"] != 1 { t.Fatalf("expacted val[0] to be %d but was %d", 1, m["val[0]"]) } if m["val[1]"] != 2 { t.Fatalf("expacted val[1] to be %d but was %d", 2, m["val[1]"]) } } func TestTupleMapScanNil(t *testing.T) { session := createSession(t) defer session.Close() if session.cfg.ProtoVersion < protoVersion3 { t.Skip("tuple types are only available of proto>=3") } err := createTable(session, `CREATE TABLE gocql_test.tuple_map_scan_nil( id int, val frozen>, primary 
key(id))`) if err != nil { t.Fatal(err) } if err := session.Query(`INSERT INTO tuple_map_scan_nil (id, val) VALUES (?,(?,?));`, 1, nil, nil).Exec(); err != nil { t.Fatal(err) } m := make(map[string]interface{}) err = session.Query(`SELECT * FROM tuple_map_scan_nil`).MapScan(m) if err != nil { t.Fatal(err) } if m["val[0]"] != 0 { t.Fatalf("expacted val[0] to be %d but was %d", 0, m["val[0]"]) } if m["val[1]"] != 0 { t.Fatalf("expacted val[1] to be %d but was %d", 0, m["val[1]"]) } } func TestTupleMapScanNotSet(t *testing.T) { session := createSession(t) defer session.Close() if session.cfg.ProtoVersion < protoVersion3 { t.Skip("tuple types are only available of proto>=3") } err := createTable(session, `CREATE TABLE gocql_test.tuple_map_scan_not_set( id int, val frozen>, primary key(id))`) if err != nil { t.Fatal(err) } if err := session.Query(`INSERT INTO tuple_map_scan_not_set (id) VALUES (?);`, 1).Exec(); err != nil { t.Fatal(err) } m := make(map[string]interface{}) err = session.Query(`SELECT * FROM tuple_map_scan_not_set`).MapScan(m) if err != nil { t.Fatal(err) } if m["val[0]"] != 0 { t.Fatalf("expacted val[0] to be %d but was %d", 0, m["val[0]"]) } if m["val[1]"] != 0 { t.Fatalf("expacted val[1] to be %d but was %d", 0, m["val[1]"]) } } func TestTupleLastFieldEmpty(t *testing.T) { // Regression test - empty value used to be treated as NULL value in the last tuple field session := createSession(t) defer session.Close() if session.cfg.ProtoVersion < protoVersion3 { t.Skip("tuple types are only available of proto>=3") } err := createTable(session, `CREATE TABLE gocql_test.tuple_last_field_empty( id int, val frozen>, primary key(id))`) if err != nil { t.Fatal(err) } if err := session.Query(`INSERT INTO tuple_last_field_empty (id, val) VALUES (?,(?,?));`, 1, "abc", "").Exec(); err != nil { t.Fatal(err) } var e1, e2 *string if err := session.Query("SELECT val FROM tuple_last_field_empty WHERE id = ?", 1).Scan(&e1, &e2); err != nil { t.Fatal(err) } if e1 == nil { 
t.Fatal("expected e1 not to be nil") } if *e1 != "abc" { t.Fatalf("expected e1 to be equal to \"abc\", but is %v", *e2) } if e2 == nil { t.Fatal("expected e2 not to be nil") } if *e2 != "" { t.Fatalf("expected e2 to be an empty string, but is %v", *e2) } } func TestTuple_NestedCollection(t *testing.T) { session := createSession(t) defer session.Close() if session.cfg.ProtoVersion < protoVersion3 { t.Skip("tuple types are only available of proto>=3") } err := createTable(session, `CREATE TABLE gocql_test.nested_tuples( id int, val list>>, primary key(id))`) if err != nil { t.Fatal(err) } type typ struct { A int B string } tests := []struct { name string val interface{} }{ {name: "slice", val: [][]interface{}{{1, "2"}, {3, "4"}}}, {name: "array", val: [][2]interface{}{{1, "2"}, {3, "4"}}}, {name: "struct", val: []typ{{1, "2"}, {3, "4"}}}, } for i, test := range tests { t.Run(test.name, func(t *testing.T) { if err := session.Query(`INSERT INTO nested_tuples (id, val) VALUES (?, ?);`, i, test.val).Exec(); err != nil { t.Fatal(err) } rv := reflect.ValueOf(test.val) res := reflect.New(rv.Type()).Elem().Addr().Interface() err = session.Query(`SELECT val FROM nested_tuples WHERE id=?`, i).Scan(res) if err != nil { t.Fatal(err) } resVal := reflect.ValueOf(res).Elem().Interface() if !reflect.DeepEqual(test.val, resVal) { t.Fatalf("unmarshaled value not equal to the original value: expected %#v, got %#v", test.val, resVal) } }) } } func TestTuple_NullableNestedCollection(t *testing.T) { session := createSession(t) defer session.Close() if session.cfg.ProtoVersion < protoVersion3 { t.Skip("tuple types are only available of proto>=3") } err := createTable(session, `CREATE TABLE gocql_test.nested_tuples_with_nulls( id int, val list>>, primary key(id))`) if err != nil { t.Fatal(err) } type typ struct { A *string B *string } ptrStr := func(s string) *string { ret := new(string) *ret = s return ret } tests := []struct { name string val interface{} }{ {name: "slice", val: 
[][]*string{{ptrStr("1"), nil}, {nil, ptrStr("2")}, {ptrStr("3"), ptrStr("")}}}, {name: "array", val: [][2]*string{{ptrStr("1"), nil}, {nil, ptrStr("2")}, {ptrStr("3"), ptrStr("")}}}, {name: "struct", val: []typ{{ptrStr("1"), nil}, {nil, ptrStr("2")}, {ptrStr("3"), ptrStr("")}}}, } for i, test := range tests { t.Run(test.name, func(t *testing.T) { if err := session.Query(`INSERT INTO nested_tuples_with_nulls (id, val) VALUES (?, ?);`, i, test.val).Exec(); err != nil { t.Fatal(err) } rv := reflect.ValueOf(test.val) res := reflect.New(rv.Type()).Interface() err = session.Query(`SELECT val FROM nested_tuples_with_nulls WHERE id=?`, i).Scan(res) if err != nil { t.Fatal(err) } resVal := reflect.ValueOf(res).Elem().Interface() if !reflect.DeepEqual(test.val, resVal) { t.Fatalf("unmarshaled value not equal to the original value: expected %#v, got %#v", test.val, resVal) } }) } } cassandra-gocql-driver-1.7.0/udt_test.go000066400000000000000000000274651467504044300202610ustar00rootroot00000000000000//go:build all || cassandra // +build all cassandra /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. 
* See the NOTICE file distributed with this work for additional information. */ package gocql import ( "fmt" "strings" "testing" "time" ) type position struct { Lat int `cql:"lat"` Lon int `cql:"lon"` Padding string `json:"padding"` } // NOTE: due to current implementation details it is not currently possible to use // a pointer receiver type for the UDTMarshaler interface to handle UDT's func (p position) MarshalUDT(name string, info TypeInfo) ([]byte, error) { switch name { case "lat": return Marshal(info, p.Lat) case "lon": return Marshal(info, p.Lon) case "padding": return Marshal(info, p.Padding) default: return nil, fmt.Errorf("unknown column for position: %q", name) } } func (p *position) UnmarshalUDT(name string, info TypeInfo, data []byte) error { switch name { case "lat": return Unmarshal(info, data, &p.Lat) case "lon": return Unmarshal(info, data, &p.Lon) case "padding": return Unmarshal(info, data, &p.Padding) default: return fmt.Errorf("unknown column for position: %q", name) } } func TestUDT_Marshaler(t *testing.T) { session := createSession(t) defer session.Close() if session.cfg.ProtoVersion < protoVersion3 { t.Skip("UDT are only available on protocol >= 3") } err := createTable(session, `CREATE TYPE gocql_test.position( lat int, lon int, padding text);`) if err != nil { t.Fatal(err) } err = createTable(session, `CREATE TABLE gocql_test.houses( id int, name text, loc frozen, primary key(id) );`) if err != nil { t.Fatal(err) } const ( expLat = -1 expLon = 2 ) pad := strings.Repeat("X", 1000) err = session.Query("INSERT INTO houses(id, name, loc) VALUES(?, ?, ?)", 1, "test", &position{expLat, expLon, pad}).Exec() if err != nil { t.Fatal(err) } pos := &position{} err = session.Query("SELECT loc FROM houses WHERE id = ?", 1).Scan(pos) if err != nil { t.Fatal(err) } if pos.Lat != expLat { t.Errorf("expeceted lat to be be %d got %d", expLat, pos.Lat) } if pos.Lon != expLon { t.Errorf("expeceted lon to be be %d got %d", expLon, pos.Lon) } if pos.Padding != 
pad { t.Errorf("expected to get padding %q got %q\n", pad, pos.Padding) } } func TestUDT_Reflect(t *testing.T) { // Uses reflection instead of implementing the marshaling type session := createSession(t) defer session.Close() if session.cfg.ProtoVersion < protoVersion3 { t.Skip("UDT are only available on protocol >= 3") } err := createTable(session, `CREATE TYPE gocql_test.horse( name text, owner text);`) if err != nil { t.Fatal(err) } err = createTable(session, `CREATE TABLE gocql_test.horse_race( position int, horse frozen, primary key(position) );`) if err != nil { t.Fatal(err) } type horse struct { Name string `cql:"name"` Owner string `cql:"owner"` } insertedHorse := &horse{ Name: "pony", Owner: "jim", } err = session.Query("INSERT INTO horse_race(position, horse) VALUES(?, ?)", 1, insertedHorse).Exec() if err != nil { t.Fatal(err) } retrievedHorse := &horse{} err = session.Query("SELECT horse FROM horse_race WHERE position = ?", 1).Scan(retrievedHorse) if err != nil { t.Fatal(err) } if *retrievedHorse != *insertedHorse { t.Fatalf("expected to get %+v got %+v", insertedHorse, retrievedHorse) } } func TestUDT_Proto2error(t *testing.T) { // TODO(zariel): move this to marshal test? 
_, err := Marshal(NativeType{custom: "org.apache.cassandra.db.marshal.UserType.Type", proto: 2}, 1)
	if err != ErrorUDTUnavailable {
		// Fix: report the sentinel that was actually compared against.
		// Previously this printed ErrUnavailable, producing a misleading
		// failure message.
		t.Fatalf("expected %v got %v", ErrorUDTUnavailable, err)
	}
}

// TestUDT_NullObject verifies that scanning a NULL UDT column leaves the
// destination struct fields at their zero values.
func TestUDT_NullObject(t *testing.T) {
	session := createSession(t)
	defer session.Close()

	if session.cfg.ProtoVersion < protoVersion3 {
		t.Skip("UDT are only available on protocol >= 3")
	}

	err := createTable(session, `CREATE TYPE gocql_test.udt_null_type( name text, owner text);`)
	if err != nil {
		t.Fatal(err)
	}

	err = createTable(session, `CREATE TABLE gocql_test.udt_null_table( id uuid, udt_col frozen, primary key(id) );`)
	if err != nil {
		t.Fatal(err)
	}

	type col struct {
		Name  string `cql:"name"`
		Owner string `cql:"owner"`
	}

	id := TimeUUID()
	err = session.Query("INSERT INTO udt_null_table(id) VALUES(?)", id).Exec()
	if err != nil {
		t.Fatal(err)
	}

	// Pre-populate the destination so we can tell the scan really reset it.
	readCol := &col{
		Name:  "temp",
		Owner: "temp",
	}

	err = session.Query("SELECT udt_col FROM udt_null_table WHERE id = ?", id).Scan(readCol)
	if err != nil {
		t.Fatal(err)
	}

	if readCol.Name != "" {
		t.Errorf("expected empty string to be returned for null udt: got %q", readCol.Name)
	}
	if readCol.Owner != "" {
		t.Errorf("expected empty string to be returned for null udt: got %q", readCol.Owner)
	}
}

// TestMapScanUDT verifies that a list of UDTs is unmarshalled by MapScan
// into []map[string]interface{} with correctly typed field values.
func TestMapScanUDT(t *testing.T) {
	session := createSession(t)
	defer session.Close()

	if session.cfg.ProtoVersion < protoVersion3 {
		t.Skip("UDT are only available on protocol >= 3")
	}

	err := createTable(session, `CREATE TYPE gocql_test.log_entry ( created_timestamp timestamp, message text );`)
	if err != nil {
		t.Fatal(err)
	}

	err = createTable(session, `CREATE TABLE gocql_test.requests_by_id ( id uuid PRIMARY KEY, type int, log_entries list> );`)
	if err != nil {
		t.Fatal(err)
	}

	// Truncate to milliseconds because that is the timestamp resolution
	// Cassandra stores.
	entry := []struct {
		CreatedTimestamp time.Time `cql:"created_timestamp"`
		Message          string    `cql:"message"`
	}{
		{
			CreatedTimestamp: time.Now().Truncate(time.Millisecond),
			Message:          "test time now",
		},
	}

	id, _ := RandomUUID()
	const typ = 1
	err = session.Query("INSERT INTO requests_by_id(id, 
type, log_entries) VALUES (?, ?, ?)", id, typ, entry).Exec()
	if err != nil {
		t.Fatal(err)
	}

	rawResult := map[string]interface{}{}
	err = session.Query(`SELECT * FROM requests_by_id WHERE id = ?`, id).MapScan(rawResult)
	if err != nil {
		t.Fatal(err)
	}

	logEntries, ok := rawResult["log_entries"].([]map[string]interface{})
	if !ok {
		t.Fatal("log_entries not in scanned map")
	}

	if len(logEntries) != 1 {
		t.Fatalf("expected to get 1 log_entry got %d", len(logEntries))
	}
	logEntry := logEntries[0]

	timestamp, ok := logEntry["created_timestamp"]
	if !ok {
		t.Error("created_timestamp not unmarshalled into map")
	} else {
		if ts, ok := timestamp.(time.Time); ok {
			if !ts.In(time.UTC).Equal(entry[0].CreatedTimestamp.In(time.UTC)) {
				t.Errorf("created_timestamp not equal to stored: got %v expected %v", ts.In(time.UTC), entry[0].CreatedTimestamp.In(time.UTC))
			}
		} else {
			t.Errorf("created_timestamp was not time.Time got: %T", timestamp)
		}
	}

	message, ok := logEntry["message"]
	if !ok {
		t.Error("message not unmarshalled into map")
	} else {
		if ts, ok := message.(string); ok {
			// Fix: compare against the originally stored value. The old code
			// compared ts against message — the very interface value ts was
			// extracted from — so the check could never fail.
			if ts != entry[0].Message {
				t.Errorf("message not equal to stored: got %v expected %v", ts, entry[0].Message)
			}
		} else {
			t.Errorf("message was not string got: %T", message)
		}
	}
}

// TestUDT_MissingField verifies that a struct mapping only a subset of the
// UDT's fields can still be written and read back.
func TestUDT_MissingField(t *testing.T) {
	session := createSession(t)
	defer session.Close()

	if session.cfg.ProtoVersion < protoVersion3 {
		t.Skip("UDT are only available on protocol >= 3")
	}

	err := createTable(session, `CREATE TYPE gocql_test.missing_field( name text, owner text);`)
	if err != nil {
		t.Fatal(err)
	}

	err = createTable(session, `CREATE TABLE gocql_test.missing_field( id uuid, udt_col frozen, primary key(id) );`)
	if err != nil {
		t.Fatal(err)
	}

	// col deliberately omits the "owner" field of the UDT.
	type col struct {
		Name string `cql:"name"`
	}

	writeCol := &col{
		Name: "test",
	}

	id := TimeUUID()
	err = session.Query("INSERT INTO missing_field(id, udt_col) VALUES(?, ?)", id, writeCol).Exec()
	if err != nil {
		t.Fatal(err)
	}

	readCol := &col{}
	err = session.Query("SELECT udt_col FROM missing_field WHERE id = 
?", id).Scan(readCol) if err != nil { t.Fatal(err) } if readCol.Name != writeCol.Name { t.Errorf("expected %q: got %q", writeCol.Name, readCol.Name) } } func TestUDT_EmptyCollections(t *testing.T) { session := createSession(t) defer session.Close() if session.cfg.ProtoVersion < protoVersion3 { t.Skip("UDT are only available on protocol >= 3") } err := createTable(session, `CREATE TYPE gocql_test.nil_collections( a list, b map, c set );`) if err != nil { t.Fatal(err) } err = createTable(session, `CREATE TABLE gocql_test.nil_collections( id uuid, udt_col frozen, primary key(id) );`) if err != nil { t.Fatal(err) } type udt struct { A []string `cql:"a"` B map[string]string `cql:"b"` C []string `cql:"c"` } id := TimeUUID() err = session.Query("INSERT INTO nil_collections(id, udt_col) VALUES(?, ?)", id, &udt{}).Exec() if err != nil { t.Fatal(err) } var val udt err = session.Query("SELECT udt_col FROM nil_collections WHERE id=?", id).Scan(&val) if err != nil { t.Fatal(err) } if val.A != nil { t.Errorf("expected to get nil got %#+v", val.A) } if val.B != nil { t.Errorf("expected to get nil got %#+v", val.B) } if val.C != nil { t.Errorf("expected to get nil got %#+v", val.C) } } func TestUDT_UpdateField(t *testing.T) { session := createSession(t) defer session.Close() if session.cfg.ProtoVersion < protoVersion3 { t.Skip("UDT are only available on protocol >= 3") } err := createTable(session, `CREATE TYPE gocql_test.update_field_udt( name text, owner text);`) if err != nil { t.Fatal(err) } err = createTable(session, `CREATE TABLE gocql_test.update_field( id uuid, udt_col frozen, primary key(id) );`) if err != nil { t.Fatal(err) } type col struct { Name string `cql:"name"` Owner string `cql:"owner"` Data string `cql:"data"` } writeCol := &col{ Name: "test-name", Owner: "test-owner", } id := TimeUUID() err = session.Query("INSERT INTO update_field(id, udt_col) VALUES(?, ?)", id, writeCol).Exec() if err != nil { t.Fatal(err) } if err := createTable(session, `ALTER TYPE 
gocql_test.update_field_udt ADD data text;`); err != nil { t.Fatal(err) } readCol := &col{} err = session.Query("SELECT udt_col FROM update_field WHERE id = ?", id).Scan(readCol) if err != nil { t.Fatal(err) } if *readCol != *writeCol { t.Errorf("expected %+v: got %+v", *writeCol, *readCol) } } func TestUDT_ScanNullUDT(t *testing.T) { session := createSession(t) defer session.Close() if session.cfg.ProtoVersion < protoVersion3 { t.Skip("UDT are only available on protocol >= 3") } err := createTable(session, `CREATE TYPE gocql_test.scan_null_udt_position( lat int, lon int, padding text);`) if err != nil { t.Fatal(err) } err = createTable(session, `CREATE TABLE gocql_test.scan_null_udt_houses( id int, name text, loc frozen, primary key(id) );`) if err != nil { t.Fatal(err) } err = session.Query("INSERT INTO scan_null_udt_houses(id, name) VALUES(?, ?)", 1, "test").Exec() if err != nil { t.Fatal(err) } pos := &position{} err = session.Query("SELECT loc FROM scan_null_udt_houses WHERE id = ?", 1).Scan(pos) if err != nil { t.Fatal(err) } } cassandra-gocql-driver-1.7.0/uuid.go000066400000000000000000000223241467504044300173610ustar00rootroot00000000000000/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2012, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. */ package gocql // The uuid package can be used to generate and parse universally unique // identifiers, a standardized format in the form of a 128 bit number. // // http://tools.ietf.org/html/rfc4122 import ( "crypto/rand" "errors" "fmt" "io" "net" "strings" "sync/atomic" "time" ) type UUID [16]byte var hardwareAddr []byte var clockSeq uint32 const ( VariantNCSCompat = 0 VariantIETF = 2 VariantMicrosoft = 6 VariantFuture = 7 ) func init() { if interfaces, err := net.Interfaces(); err == nil { for _, i := range interfaces { if i.Flags&net.FlagLoopback == 0 && len(i.HardwareAddr) > 0 { hardwareAddr = i.HardwareAddr break } } } if hardwareAddr == nil { // If we failed to obtain the MAC address of the current computer, // we will use a randomly generated 6 byte sequence instead and set // the multicast bit as recommended in RFC 4122. hardwareAddr = make([]byte, 6) _, err := io.ReadFull(rand.Reader, hardwareAddr) if err != nil { panic(err) } hardwareAddr[0] = hardwareAddr[0] | 0x01 } // initialize the clock sequence with a random number var clockSeqRand [2]byte io.ReadFull(rand.Reader, clockSeqRand[:]) clockSeq = uint32(clockSeqRand[1])<<8 | uint32(clockSeqRand[0]) } // ParseUUID parses a 32 digit hexadecimal number (that might contain hypens) // representing an UUID. 
func ParseUUID(input string) (UUID, error) { var u UUID j := 0 for _, r := range input { switch { case r == '-' && j&1 == 0: continue case r >= '0' && r <= '9' && j < 32: u[j/2] |= byte(r-'0') << uint(4-j&1*4) case r >= 'a' && r <= 'f' && j < 32: u[j/2] |= byte(r-'a'+10) << uint(4-j&1*4) case r >= 'A' && r <= 'F' && j < 32: u[j/2] |= byte(r-'A'+10) << uint(4-j&1*4) default: return UUID{}, fmt.Errorf("invalid UUID %q", input) } j += 1 } if j != 32 { return UUID{}, fmt.Errorf("invalid UUID %q", input) } return u, nil } // UUIDFromBytes converts a raw byte slice to an UUID. func UUIDFromBytes(input []byte) (UUID, error) { var u UUID if len(input) != 16 { return u, errors.New("UUIDs must be exactly 16 bytes long") } copy(u[:], input) return u, nil } func MustRandomUUID() UUID { uuid, err := RandomUUID() if err != nil { panic(err) } return uuid } // RandomUUID generates a totally random UUID (version 4) as described in // RFC 4122. func RandomUUID() (UUID, error) { var u UUID _, err := io.ReadFull(rand.Reader, u[:]) if err != nil { return u, err } u[6] &= 0x0F // clear version u[6] |= 0x40 // set version to 4 (random uuid) u[8] &= 0x3F // clear variant u[8] |= 0x80 // set to IETF variant return u, nil } var timeBase = time.Date(1582, time.October, 15, 0, 0, 0, 0, time.UTC).Unix() // getTimestamp converts time to UUID (version 1) timestamp. // It must be an interval of 100-nanoseconds since timeBase. func getTimestamp(t time.Time) int64 { utcTime := t.In(time.UTC) ts := int64(utcTime.Unix()-timeBase)*10000000 + int64(utcTime.Nanosecond()/100) return ts } // TimeUUID generates a new time based UUID (version 1) using the current // time as the timestamp. func TimeUUID() UUID { return UUIDFromTime(time.Now()) } // The min and max clock values for a UUID. // // Cassandra's TimeUUIDType compares the lsb parts as signed byte arrays. // Thus, the min value for each byte is -128 and the max is +127. 
const ( minClock = 0x8080 maxClock = 0x7f7f ) // The min and max node values for a UUID. // // See explanation about Cassandra's TimeUUIDType comparison logic above. var ( minNode = []byte{0x80, 0x80, 0x80, 0x80, 0x80, 0x80} maxNode = []byte{0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f} ) // MinTimeUUID generates a "fake" time based UUID (version 1) which will be // the smallest possible UUID generated for the provided timestamp. // // UUIDs generated by this function are not unique and are mostly suitable only // in queries to select a time range of a Cassandra's TimeUUID column. func MinTimeUUID(t time.Time) UUID { return TimeUUIDWith(getTimestamp(t), minClock, minNode) } // MaxTimeUUID generates a "fake" time based UUID (version 1) which will be // the biggest possible UUID generated for the provided timestamp. // // UUIDs generated by this function are not unique and are mostly suitable only // in queries to select a time range of a Cassandra's TimeUUID column. func MaxTimeUUID(t time.Time) UUID { return TimeUUIDWith(getTimestamp(t), maxClock, maxNode) } // UUIDFromTime generates a new time based UUID (version 1) as described in // RFC 4122. This UUID contains the MAC address of the node that generated // the UUID, the given timestamp and a sequence number. func UUIDFromTime(t time.Time) UUID { ts := getTimestamp(t) clock := atomic.AddUint32(&clockSeq, 1) return TimeUUIDWith(ts, clock, hardwareAddr) } // TimeUUIDWith generates a new time based UUID (version 1) as described in // RFC4122 with given parameters. t is the number of 100's of nanoseconds // since 15 Oct 1582 (60bits). clock is the number of clock sequence (14bits). // node is a slice to gurarantee the uniqueness of the UUID (up to 6bytes). // Note: calling this function does not increment the static clock sequence. 
func TimeUUIDWith(t int64, clock uint32, node []byte) UUID { var u UUID u[0], u[1], u[2], u[3] = byte(t>>24), byte(t>>16), byte(t>>8), byte(t) u[4], u[5] = byte(t>>40), byte(t>>32) u[6], u[7] = byte(t>>56)&0x0F, byte(t>>48) u[8] = byte(clock >> 8) u[9] = byte(clock) copy(u[10:], node) u[6] |= 0x10 // set version to 1 (time based uuid) u[8] &= 0x3F // clear variant u[8] |= 0x80 // set to IETF variant return u } // String returns the UUID in it's canonical form, a 32 digit hexadecimal // number in the form of xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx. func (u UUID) String() string { var offsets = [...]int{0, 2, 4, 6, 9, 11, 14, 16, 19, 21, 24, 26, 28, 30, 32, 34} const hexString = "0123456789abcdef" r := make([]byte, 36) for i, b := range u { r[offsets[i]] = hexString[b>>4] r[offsets[i]+1] = hexString[b&0xF] } r[8] = '-' r[13] = '-' r[18] = '-' r[23] = '-' return string(r) } // Bytes returns the raw byte slice for this UUID. A UUID is always 128 bits // (16 bytes) long. func (u UUID) Bytes() []byte { return u[:] } // Variant returns the variant of this UUID. This package will only generate // UUIDs in the IETF variant. func (u UUID) Variant() int { x := u[8] if x&0x80 == 0 { return VariantNCSCompat } if x&0x40 == 0 { return VariantIETF } if x&0x20 == 0 { return VariantMicrosoft } return VariantFuture } // Version extracts the version of this UUID variant. The RFC 4122 describes // five kinds of UUIDs. func (u UUID) Version() int { return int(u[6] & 0xF0 >> 4) } // Node extracts the MAC address of the node who generated this UUID. It will // return nil if the UUID is not a time based UUID (version 1). func (u UUID) Node() []byte { if u.Version() != 1 { return nil } return u[10:] } // Clock extracts the clock sequence of this UUID. It will return zero if the // UUID is not a time based UUID (version 1). 
func (u UUID) Clock() uint32 { if u.Version() != 1 { return 0 } // Clock sequence is the lower 14bits of u[8:10] return uint32(u[8]&0x3F)<<8 | uint32(u[9]) } // Timestamp extracts the timestamp information from a time based UUID // (version 1). func (u UUID) Timestamp() int64 { if u.Version() != 1 { return 0 } return int64(uint64(u[0])<<24|uint64(u[1])<<16| uint64(u[2])<<8|uint64(u[3])) + int64(uint64(u[4])<<40|uint64(u[5])<<32) + int64(uint64(u[6]&0x0F)<<56|uint64(u[7])<<48) } // Time is like Timestamp, except that it returns a time.Time. func (u UUID) Time() time.Time { if u.Version() != 1 { return time.Time{} } t := u.Timestamp() sec := t / 1e7 nsec := (t % 1e7) * 100 return time.Unix(sec+timeBase, nsec).UTC() } // Marshaling for JSON func (u UUID) MarshalJSON() ([]byte, error) { return []byte(`"` + u.String() + `"`), nil } // Unmarshaling for JSON func (u *UUID) UnmarshalJSON(data []byte) error { str := strings.Trim(string(data), `"`) if len(str) > 36 { return fmt.Errorf("invalid JSON UUID %s", str) } parsed, err := ParseUUID(str) if err == nil { copy(u[:], parsed[:]) } return err } func (u UUID) MarshalText() ([]byte, error) { return []byte(u.String()), nil } func (u *UUID) UnmarshalText(text []byte) (err error) { *u, err = ParseUUID(string(text)) return } cassandra-gocql-driver-1.7.0/uuid_test.go000066400000000000000000000223511467504044300204200ustar00rootroot00000000000000//go:build all || unit // +build all unit /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. */ package gocql import ( "bytes" "strings" "testing" "time" ) func TestUUIDNil(t *testing.T) { var uuid UUID want, got := "00000000-0000-0000-0000-000000000000", uuid.String() if want != got { t.Fatalf("TestNil: expected %q got %q", want, got) } } var testsUUID = []struct { input string variant int version int }{ {"b4f00409-cef8-4822-802c-deb20704c365", VariantIETF, 4}, {"B4F00409-CEF8-4822-802C-DEB20704C365", VariantIETF, 4}, //Use capital letters {"f81d4fae-7dec-11d0-a765-00a0c91e6bf6", VariantIETF, 1}, {"00000000-7dec-11d0-a765-00a0c91e6bf6", VariantIETF, 1}, {"3051a8d7-aea7-1801-e0bf-bc539dd60cf3", VariantFuture, 1}, {"3051a8d7-aea7-2801-e0bf-bc539dd60cf3", VariantFuture, 2}, {"3051a8d7-aea7-3801-e0bf-bc539dd60cf3", VariantFuture, 3}, {"3051a8d7-aea7-4801-e0bf-bc539dd60cf3", VariantFuture, 4}, {"3051a8d7-aea7-3801-e0bf-bc539dd60cf3", VariantFuture, 5}, {"d0e817e1-e4b1-1801-3fe6-b4b60ccecf9d", VariantNCSCompat, 0}, {"d0e817e1-e4b1-1801-bfe6-b4b60ccecf9d", VariantIETF, 1}, {"d0e817e1-e4b1-1801-dfe6-b4b60ccecf9d", VariantMicrosoft, 0}, {"d0e817e1-e4b1-1801-ffe6-b4b60ccecf9d", VariantFuture, 0}, } func TestPredefinedUUID(t *testing.T) { for i := range testsUUID { uuid, err := ParseUUID(testsUUID[i].input) if err != nil { t.Errorf("ParseUUID #%d: %v", i, err) continue } if str := uuid.String(); str != strings.ToLower(testsUUID[i].input) { 
t.Errorf("String #%d: expected %q got %q", i, testsUUID[i].input, str) continue } if variant := uuid.Variant(); variant != testsUUID[i].variant { t.Errorf("Variant #%d: expected %d got %d", i, testsUUID[i].variant, variant) } if testsUUID[i].variant == VariantIETF { if version := uuid.Version(); version != testsUUID[i].version { t.Errorf("Version #%d: expected %d got %d", i, testsUUID[i].version, version) } } json, err := uuid.MarshalJSON() if err != nil { t.Errorf("MarshalJSON #%d: %v", i, err) } expectedJson := `"` + strings.ToLower(testsUUID[i].input) + `"` if string(json) != expectedJson { t.Errorf("MarshalJSON #%d: expected %v got %v", i, expectedJson, string(json)) } var unmarshaled UUID err = unmarshaled.UnmarshalJSON(json) if err != nil { t.Errorf("UnmarshalJSON #%d: %v", i, err) } if unmarshaled != uuid { t.Errorf("UnmarshalJSON #%d: expected %v got %v", i, uuid, unmarshaled) } } } func TestInvalidUUIDCharacter(t *testing.T) { _, err := ParseUUID("z4f00409-cef8-4822-802c-deb20704c365") if err == nil || !strings.Contains(err.Error(), "invalid UUID") { t.Fatalf("expected invalid UUID error, got '%v' ", err) } } func TestInvalidUUIDLength(t *testing.T) { _, err := ParseUUID("4f00") if err == nil || !strings.Contains(err.Error(), "invalid UUID") { t.Fatalf("expected invalid UUID error, got '%v' ", err) } _, err = UUIDFromBytes(TimeUUID().Bytes()[:15]) if err == nil || err.Error() != "UUIDs must be exactly 16 bytes long" { t.Fatalf("expected error '%v', got '%v'", "UUIDs must be exactly 16 bytes long", err) } } func TestRandomUUID(t *testing.T) { for i := 0; i < 20; i++ { uuid, err := RandomUUID() if err != nil { t.Errorf("RandomUUID: %v", err) } if variant := uuid.Variant(); variant != VariantIETF { t.Errorf("wrong variant. expected %d got %d", VariantIETF, variant) } if version := uuid.Version(); version != 4 { t.Errorf("wrong version. 
expected %d got %d", 4, version) } } } func TestRandomUUIDInvalidAPICalls(t *testing.T) { uuid, err := RandomUUID() if err != nil { t.Fatalf("unexpected error %v", err) } if node := uuid.Node(); node != nil { t.Fatalf("expected nil, got %v", node) } if stamp := uuid.Timestamp(); stamp != 0 { t.Fatalf("expceted 0, got %v", stamp) } zeroT := time.Time{} if to := uuid.Time(); to != zeroT { t.Fatalf("expected %v, got %v", zeroT, to) } } func TestUUIDFromTime(t *testing.T) { date := time.Date(1982, 5, 5, 12, 34, 56, 400, time.UTC) uuid := UUIDFromTime(date) if uuid.Time() != date { t.Errorf("embedded time incorrect. Expected %v got %v", date, uuid.Time()) } } func TestTimeUUIDWith(t *testing.T) { utcTime := time.Date(1982, 5, 5, 12, 34, 56, 400, time.UTC) ts := int64(utcTime.Unix()-timeBase)*10000000 + int64(utcTime.Nanosecond()/100) clockSeq := uint32(0x3FFF) // Max number of clock sequence. node := [7]byte{0, 1, 2, 3, 4, 5, 6} // The last element should be ignored. uuid := TimeUUIDWith(ts, clockSeq, node[:]) if got := uuid.Variant(); got != VariantIETF { t.Errorf("wrong variant. expected %d got %d", VariantIETF, got) } if got, want := uuid.Version(), 1; got != want { t.Errorf("wrong version. Expected %v got %v", want, got) } if got := uuid.Timestamp(); got != int64(ts) { t.Errorf("wrong timestamp. Expected %v got %v", ts, got) } if got := uuid.Clock(); uint32(got) != clockSeq { t.Errorf("wrong clock. expected %v got %v", clockSeq, got) } if got, want := uuid.Node(), node[:6]; !bytes.Equal(got, want) { t.Errorf("wrong node. 
expected %x, bot %x", want, got) } } func TestParseUUID(t *testing.T) { uuid, _ := ParseUUID("486f3a88-775b-11e3-ae07-d231feb1dc81") if uuid.Time() != time.Date(2014, 1, 7, 5, 19, 29, 222516000, time.UTC) { t.Errorf("Expected date of 1/7/2014 at 5:19:29.222516, got %v", uuid.Time()) } } func TestTimeUUID(t *testing.T) { var node []byte timestamp := int64(0) for i := 0; i < 20; i++ { uuid := TimeUUID() if variant := uuid.Variant(); variant != VariantIETF { t.Errorf("wrong variant. expected %d got %d", VariantIETF, variant) } if version := uuid.Version(); version != 1 { t.Errorf("wrong version. expected %d got %d", 1, version) } if n := uuid.Node(); !bytes.Equal(n, node) && i > 0 { t.Errorf("wrong node. expected %x, got %x", node, n) } else if i == 0 { node = n } ts := uuid.Timestamp() if ts < timestamp { t.Errorf("timestamps must grow: timestamp=%v ts=%v", timestamp, ts) } timestamp = ts } } func TestUnmarshalJSON(t *testing.T) { var withHyphens, withoutHypens, tooLong UUID withHyphens.UnmarshalJSON([]byte(`"486f3a88-775b-11e3-ae07-d231feb1dc81"`)) if withHyphens.Time().Truncate(time.Second) != time.Date(2014, 1, 7, 5, 19, 29, 0, time.UTC) { t.Errorf("Expected date of 1/7/2014 at 5:19:29, got %v", withHyphens.Time()) } withoutHypens.UnmarshalJSON([]byte(`"486f3a88775b11e3ae07d231feb1dc81"`)) if withoutHypens.Time().Truncate(time.Second) != time.Date(2014, 1, 7, 5, 19, 29, 0, time.UTC) { t.Errorf("Expected date of 1/7/2014 at 5:19:29, got %v", withoutHypens.Time()) } err := tooLong.UnmarshalJSON([]byte(`"486f3a88-775b-11e3-ae07-d231feb1dc81486f3a88"`)) if err == nil { t.Errorf("no error for invalid JSON UUID") } } func TestMarshalText(t *testing.T) { u, err := ParseUUID("486f3a88-775b-11e3-ae07-d231feb1dc81") if err != nil { t.Fatal(err) } text, err := u.MarshalText() if err != nil { t.Fatal(err) } var u2 UUID if err := u2.UnmarshalText(text); err != nil { t.Fatal(err) } if u != u2 { t.Fatalf("uuids not equal after marshalling: before=%s after=%s", u, u2) } } func 
TestMinTimeUUID(t *testing.T) { aTime := time.Now() minTimeUUID := MinTimeUUID(aTime) ts := aTime.Unix() tsFromUUID := minTimeUUID.Time().Unix() if ts != tsFromUUID { t.Errorf("timestamps are not equal: expected %d, got %d", ts, tsFromUUID) } clockFromUUID := minTimeUUID.Clock() // clear two most significant bits, as they are used for IETF variant if minClock&0x3FFF != clockFromUUID { t.Errorf("clocks are not equal: expected %08b, got %08b", minClock&0x3FFF, clockFromUUID) } nodeFromUUID := minTimeUUID.Node() if !bytes.Equal(minNode, nodeFromUUID) { t.Errorf("nodes are not equal: expected %08b, got %08b", minNode, nodeFromUUID) } } func TestMaxTimeUUID(t *testing.T) { aTime := time.Now() maxTimeUUID := MaxTimeUUID(aTime) ts := aTime.Unix() tsFromUUID := maxTimeUUID.Time().Unix() if ts != tsFromUUID { t.Errorf("timestamps are not equal: expected %d, got %d", ts, tsFromUUID) } clockFromUUID := maxTimeUUID.Clock() if maxClock&0x3FFF != clockFromUUID { t.Errorf("clocks are not equal: expected %08b, got %08b", maxClock&0x3FFF, clockFromUUID) } nodeFromUUID := maxTimeUUID.Node() if !bytes.Equal(maxNode, nodeFromUUID) { t.Errorf("nodes are not equal: expected %08b, got %08b", maxNode, nodeFromUUID) } } cassandra-gocql-driver-1.7.0/version.go000066400000000000000000000040631467504044300201000ustar00rootroot00000000000000/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. */ package gocql import "runtime/debug" const ( defaultDriverName = "github.com/apache/cassandra-gocql-driver" // This string MUST have this value since we explicitly test against the // current main package returned by runtime/debug below. Also note the // package name used here may change in a future (2.x) release; in that case // this constant will be updated as well. mainPackage = "github.com/gocql/gocql" ) var driverName string var driverVersion string func init() { buildInfo, ok := debug.ReadBuildInfo() if ok { for _, d := range buildInfo.Deps { if d.Path == mainPackage { driverName = defaultDriverName driverVersion = d.Version // If there's a replace directive in play for the gocql package // then use that information for path and version instead. This // will allow forks or other local packages to clearly identify // themselves as distinct from mainPackage above. if d.Replace != nil { driverName = d.Replace.Path driverVersion = d.Replace.Version } break } } } } cassandra-gocql-driver-1.7.0/wiki_test.go000066400000000000000000000170611467504044300204170ustar00rootroot00000000000000//go:build all || cassandra // +build all cassandra /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. 
See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * Content before git sha 34fdeebefcbf183ed7f916f931aa0586fdaa1b40 * Copyright (c) 2016, The Gocql authors, * provided under the BSD-3-Clause License. * See the NOTICE file distributed with this work for additional information. */ package gocql import ( "fmt" "reflect" "sort" "testing" "time" "gopkg.in/inf.v0" ) type WikiPage struct { Title string RevId UUID Body string Views int64 Protected bool Modified time.Time Rating *inf.Dec Tags []string Attachments map[string]WikiAttachment } type WikiAttachment []byte var wikiTestData = []*WikiPage{ { Title: "Frontpage", RevId: TimeUUID(), Body: "Welcome to this wiki page!", Rating: inf.NewDec(131, 3), Modified: time.Date(2013, time.August, 13, 9, 52, 3, 0, time.UTC), Tags: []string{"start", "important", "test"}, Attachments: map[string]WikiAttachment{ "logo": WikiAttachment("\x00company logo\x00"), "favicon": WikiAttachment("favicon.ico"), }, }, { Title: "Foobar", RevId: TimeUUID(), Body: "foo::Foo f = new foo::Foo(foo::Foo::INIT);", Modified: time.Date(2013, time.August, 13, 9, 52, 3, 0, time.UTC), }, } type WikiTest struct { session *Session tb testing.TB table string } func CreateSchema(session *Session, tb testing.TB, table string) *WikiTest { table = "wiki_" + table if err := createTable(session, fmt.Sprintf("DROP TABLE IF EXISTS gocql_test.%s", 
table)); err != nil { tb.Fatal("CreateSchema:", err) } err := createTable(session, fmt.Sprintf(`CREATE TABLE gocql_test.%s ( title varchar, revid timeuuid, body varchar, views bigint, protected boolean, modified timestamp, rating decimal, tags set, attachments map, PRIMARY KEY (title, revid) )`, table)) if err != nil { tb.Fatal("CreateSchema:", err) } return &WikiTest{ session: session, tb: tb, table: table, } } func (w *WikiTest) CreatePages(n int) { var page WikiPage t0 := time.Now() for i := 0; i < n; i++ { page.Title = fmt.Sprintf("generated_%d", (i&16)+1) page.Modified = t0.Add(time.Duration(i-n) * time.Minute) page.RevId = UUIDFromTime(page.Modified) page.Body = fmt.Sprintf("text %d", i) if err := w.InsertPage(&page); err != nil { w.tb.Error("CreatePages:", err) } } } func (w *WikiTest) InsertPage(page *WikiPage) error { return w.session.Query(fmt.Sprintf(`INSERT INTO %s (title, revid, body, views, protected, modified, rating, tags, attachments) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`, w.table), page.Title, page.RevId, page.Body, page.Views, page.Protected, page.Modified, page.Rating, page.Tags, page.Attachments).Exec() } func (w *WikiTest) SelectPage(page *WikiPage, title string, revid UUID) error { return w.session.Query(fmt.Sprintf(`SELECT title, revid, body, views, protected, modified,tags, attachments, rating FROM %s WHERE title = ? AND revid = ? 
LIMIT 1`, w.table), title, revid).Scan(&page.Title, &page.RevId, &page.Body, &page.Views, &page.Protected, &page.Modified, &page.Tags, &page.Attachments, &page.Rating) } func (w *WikiTest) GetPageCount() int { var count int if err := w.session.Query(fmt.Sprintf(`SELECT COUNT(*) FROM %s`, w.table)).Scan(&count); err != nil { w.tb.Error("GetPageCount", err) } return count } func TestWikiCreateSchema(t *testing.T) { session := createSession(t) defer session.Close() CreateSchema(session, t, "create") } func BenchmarkWikiCreateSchema(b *testing.B) { b.StopTimer() session := createSession(b) defer func() { b.StopTimer() session.Close() }() b.StartTimer() for i := 0; i < b.N; i++ { CreateSchema(session, b, "bench_create") } } func TestWikiCreatePages(t *testing.T) { session := createSession(t) defer session.Close() w := CreateSchema(session, t, "create_pages") numPages := 5 w.CreatePages(numPages) if count := w.GetPageCount(); count != numPages { t.Errorf("expected %d pages, got %d pages.", numPages, count) } } func BenchmarkWikiCreatePages(b *testing.B) { b.StopTimer() session := createSession(b) defer func() { b.StopTimer() session.Close() }() w := CreateSchema(session, b, "bench_create_pages") b.StartTimer() w.CreatePages(b.N) } func BenchmarkWikiSelectAllPages(b *testing.B) { b.StopTimer() session := createSession(b) defer func() { b.StopTimer() session.Close() }() w := CreateSchema(session, b, "bench_select_all") w.CreatePages(100) b.StartTimer() var page WikiPage for i := 0; i < b.N; i++ { iter := session.Query(fmt.Sprintf(`SELECT title, revid, body, views, protected, modified, tags, attachments, rating FROM %s`, w.table)).Iter() for iter.Scan(&page.Title, &page.RevId, &page.Body, &page.Views, &page.Protected, &page.Modified, &page.Tags, &page.Attachments, &page.Rating) { // pass } if err := iter.Close(); err != nil { b.Error(err) } } } func BenchmarkWikiSelectSinglePage(b *testing.B) { b.StopTimer() session := createSession(b) defer func() { b.StopTimer() 
session.Close() }() w := CreateSchema(session, b, "bench_select_single") pages := make([]WikiPage, 100) w.CreatePages(len(pages)) iter := session.Query(fmt.Sprintf(`SELECT title, revid FROM %s`, w.table)).Iter() for i := 0; i < len(pages); i++ { if !iter.Scan(&pages[i].Title, &pages[i].RevId) { pages = pages[:i] break } } if err := iter.Close(); err != nil { b.Error(err) } b.StartTimer() var page WikiPage for i := 0; i < b.N; i++ { p := &pages[i%len(pages)] if err := w.SelectPage(&page, p.Title, p.RevId); err != nil { b.Error(err) } } } func BenchmarkWikiSelectPageCount(b *testing.B) { b.StopTimer() session := createSession(b) defer func() { b.StopTimer() session.Close() }() w := CreateSchema(session, b, "bench_page_count") const numPages = 10 w.CreatePages(numPages) b.StartTimer() for i := 0; i < b.N; i++ { if count := w.GetPageCount(); count != numPages { b.Errorf("expected %d pages, got %d pages.", numPages, count) } } } func TestWikiTypicalCRUD(t *testing.T) { session := createSession(t) defer session.Close() w := CreateSchema(session, t, "crud") for _, page := range wikiTestData { if err := w.InsertPage(page); err != nil { t.Error("InsertPage:", err) } } if count := w.GetPageCount(); count != len(wikiTestData) { t.Errorf("count: expected %d, got %d\n", len(wikiTestData), count) } for _, original := range wikiTestData { page := new(WikiPage) if err := w.SelectPage(page, original.Title, original.RevId); err != nil { t.Error("SelectPage:", err) continue } sort.Sort(sort.StringSlice(page.Tags)) sort.Sort(sort.StringSlice(original.Tags)) if !reflect.DeepEqual(page, original) { t.Errorf("page: expected %#v, got %#v\n", original, page) } } }