cassandra-driver-3.7.1/PKG-INFO
Metadata-Version: 1.1
Name: cassandra-driver
Version: 3.7.1
Summary: Python driver for Cassandra
Home-page: http://github.com/datastax/python-driver
Author: Tyler Hobbs
Author-email: tyler@datastax.com
License: UNKNOWN
Description: DataStax Python Driver for Apache Cassandra
        ===========================================
        
        .. image:: https://travis-ci.org/datastax/python-driver.png?branch=master
           :target: https://travis-ci.org/datastax/python-driver
        
        A modern, `feature-rich `_ and highly-tunable Python client library for Apache Cassandra (1.2+) and DataStax Enterprise (3.1+) using exclusively Cassandra's binary protocol and Cassandra Query Language v3.
        
        The driver supports Python 2.6, 2.7, 3.3, and 3.4.
        
        Feedback Requested
        ------------------
        **Help us focus our efforts!** Provide your input on the `Platform and Runtime Survey `_ (we kept it short).
        
        Features
        --------
        * `Synchronous `_ and `Asynchronous `_ APIs
        * `Simple, Prepared, and Batch statements `_
        * Asynchronous IO, parallel execution, request pipelining
        * `Connection pooling `_
        * Automatic node discovery
        * `Automatic reconnection `_
        * Configurable `load balancing `_ and `retry policies `_
        * `Concurrent execution utilities `_
        * `Object mapper `_
        
        Installation
        ------------
        Installation through pip is recommended::
        
            $ pip install cassandra-driver
        
        For more complete installation instructions, see the `installation guide `_.
        
        Documentation
        -------------
        The documentation can be found online `here `_.
        
        A couple of links for getting up to speed:
        
        * `Installation `_
        * `Getting started guide `_
        * `API docs `_
        * `Performance tips `_
        
        Object Mapper
        -------------
        cqlengine (originally developed by Blake Eggleston and Jon Haddad, with contributions from the community) is now maintained as an integral part of this package. Refer to `documentation here `_.
        
        Contributing
        ------------
        See `CONTRIBUTING.md `_.
        
        Reporting Problems
        ------------------
        Please report any bugs and make any feature requests on the `JIRA `_ issue tracker.
        
        If you would like to contribute, please feel free to open a pull request.
        
        Getting Help
        ------------
        Your two best options for getting help with the driver are the `mailing list `_ and the IRC channel.
        
        For IRC, use the #datastax-drivers channel on irc.freenode.net. If you don't have an IRC client, you can use `freenode's web-based client `_.
        
        License
        -------
        Copyright 2013-2016 DataStax
        
        Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at
        
        http://www.apache.org/licenses/LICENSE-2.0
        
        Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
Keywords: cassandra,cql,orm
Platform: UNKNOWN
Classifier: Development Status :: 5 - Production/Stable
Classifier: Intended Audience :: Developers
Classifier: License :: OSI Approved :: Apache Software License
Classifier: Natural Language :: English
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python
Classifier: Programming Language :: Python :: 2.6
Classifier: Programming Language :: Python :: 2.7
Classifier: Programming Language :: Python :: 3.3
Classifier: Programming Language :: Python :: 3.4
Classifier: Programming Language :: Python :: Implementation :: CPython
Classifier: Programming Language :: Python :: Implementation :: PyPy
Classifier: Topic :: Software Development :: Libraries :: Python Modules

cassandra-driver-3.7.1/LICENSE

                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.

      "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.

   4. Redistribution.
      You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:

      (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.

      You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability.
      In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

cassandra-driver-3.7.1/cassandra/metrics.py
# Copyright 2013-2016 DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from itertools import chain
import logging

try:
    from greplin import scales
except ImportError:
    raise ImportError(
        "The scales library is required for metrics support: "
        "https://pypi.python.org/pypi/scales")

log = logging.getLogger(__name__)


class Metrics(object):
    """
    A collection of timers and counters for various performance metrics.
    """

    request_timer = None
    """
    A :class:`greplin.scales.PmfStat` timer for requests. This is a dict-like
    object with the following keys:

    * count - number of requests that have been timed
    * min - min latency
    * max - max latency
    * mean - mean latency
    * stdev - standard deviation for latencies
    * median - median latency
    * 75percentile - 75th percentile latencies
    * 97percentile - 97th percentile latencies
    * 98percentile - 98th percentile latencies
    * 99percentile - 99th percentile latencies
    * 999percentile - 99.9th percentile latencies
    """

    connection_errors = None
    """
    A :class:`greplin.scales.IntStat` count of the number of times that a
    request to a Cassandra node has failed due to a connection problem.
    """

    write_timeouts = None
    """
    A :class:`greplin.scales.IntStat` count of write requests that resulted
    in a timeout.
    """

    read_timeouts = None
    """
    A :class:`greplin.scales.IntStat` count of read requests that resulted
    in a timeout.
    """

    unavailables = None
    """
    A :class:`greplin.scales.IntStat` count of write or read requests that
    failed due to an insufficient number of replicas being alive to meet
    the requested :class:`.ConsistencyLevel`.
    """

    other_errors = None
    """
    A :class:`greplin.scales.IntStat` count of all other request failures,
    including failures caused by invalid requests, bootstrapping nodes,
    overloaded nodes, etc.
    """

    retries = None
    """
    A :class:`greplin.scales.IntStat` count of the number of times a
    request was retried based on the :class:`.RetryPolicy` decision.
    """

    ignores = None
    """
    A :class:`greplin.scales.IntStat` count of the number of times a
    failed request was ignored based on the :class:`.RetryPolicy` decision.
    """

    known_hosts = None
    """
    A :class:`greplin.scales.IntStat` count of the number of nodes in
    the cluster that the driver is aware of, regardless of whether any
    connections are opened to those nodes.
    """

    connected_to = None
    """
    A :class:`greplin.scales.IntStat` count of the number of nodes that
    the driver currently has at least one connection open to.
    """

    open_connections = None
    """
    A :class:`greplin.scales.IntStat` count of the number of connections
    the driver currently has open.
    """

    _stats_counter = 0

    def __init__(self, cluster_proxy):
        log.debug("Starting metric capture")

        self.stats_name = 'cassandra-{0}'.format(str(self._stats_counter))
        Metrics._stats_counter += 1
        self.stats = scales.collection(self.stats_name,
            scales.PmfStat('request_timer'),
            scales.IntStat('connection_errors'),
            scales.IntStat('write_timeouts'),
            scales.IntStat('read_timeouts'),
            scales.IntStat('unavailables'),
            scales.IntStat('other_errors'),
            scales.IntStat('retries'),
            scales.IntStat('ignores'),

            # gauges
            scales.Stat('known_hosts',
                lambda: len(cluster_proxy.metadata.all_hosts())),
            scales.Stat('connected_to',
                lambda: len(set(chain.from_iterable(s._pools.keys() for s in cluster_proxy.sessions)))),
            scales.Stat('open_connections',
                lambda: sum(sum(p.open_count for p in s._pools.values()) for s in cluster_proxy.sessions)))

        # TODO, to be removed in 4.0
        # /cassandra contains the metrics of the first cluster registered
        if 'cassandra' not in scales._Stats.stats:
            scales._Stats.stats['cassandra'] = scales._Stats.stats[self.stats_name]

        self.request_timer = self.stats.request_timer
        self.connection_errors = self.stats.connection_errors
        self.write_timeouts = self.stats.write_timeouts
        self.read_timeouts = self.stats.read_timeouts
        self.unavailables = self.stats.unavailables
        self.other_errors = self.stats.other_errors
        self.retries = self.stats.retries
        self.ignores = self.stats.ignores
        self.known_hosts = self.stats.known_hosts
        self.connected_to = self.stats.connected_to
        self.open_connections = self.stats.open_connections

    def on_connection_error(self):
        self.stats.connection_errors += 1

    def on_write_timeout(self):
        self.stats.write_timeouts += 1

    def on_read_timeout(self):
        self.stats.read_timeouts += 1

    def on_unavailable(self):
        self.stats.unavailables += 1

    def on_other_error(self):
        self.stats.other_errors += 1

    def on_ignore(self):
        self.stats.ignores += 1

    def on_retry(self):
        self.stats.retries += 1

    def get_stats(self):
        """
        Returns the metrics for the registered cluster instance.
        """
        return scales.getStats()[self.stats_name]

    def set_stats_name(self, stats_name):
        """
        Set the metrics stats name.
        The stats_name is a string used to access the metrics through scales: scales.getStats()[stats_name]
        Default is 'cassandra-N', where N is a counter incremented each time a Metrics instance is created.
        """
        if self.stats_name == stats_name:
            return

        if stats_name in scales._Stats.stats:
            raise ValueError('"{0}" already exists in stats.'.format(stats_name))

        stats = scales._Stats.stats[self.stats_name]
        del scales._Stats.stats[self.stats_name]
        self.stats_name = stats_name
        scales._Stats.stats[self.stats_name] = stats

cassandra-driver-3.7.1/cassandra/metadata.py
# Copyright 2013-2016 DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from binascii import unhexlify
from bisect import bisect_right
from collections import defaultdict, Mapping
from hashlib import md5
from itertools import islice, cycle
import json
import logging
import re
import six
from six.moves import zip
import sys
from threading import RLock

murmur3 = None
try:
    from cassandra.murmur3 import murmur3
except ImportError as e:
    pass

from cassandra import SignatureDescriptor, ConsistencyLevel, InvalidRequest, Unauthorized
import cassandra.cqltypes as types
from cassandra.encoder import Encoder
from cassandra.marshal import varint_unpack
from cassandra.protocol import QueryMessage
from cassandra.query import dict_factory, bind_params
from cassandra.util import OrderedDict

log = logging.getLogger(__name__)

cql_keywords = set((
    'add', 'aggregate', 'all', 'allow', 'alter', 'and', 'apply', 'as', 'asc', 'ascii', 'authorize', 'batch', 'begin',
    'bigint', 'blob', 'boolean', 'by', 'called', 'clustering', 'columnfamily', 'compact', 'contains', 'count',
    'counter', 'create', 'custom', 'date', 'decimal', 'delete', 'desc', 'describe', 'distinct', 'double', 'drop',
    'entries', 'execute', 'exists', 'filtering', 'finalfunc', 'float', 'from', 'frozen', 'full', 'function',
    'functions', 'grant', 'if', 'in', 'index', 'inet', 'infinity', 'initcond', 'input', 'insert', 'int', 'into', 'is', 'json',
    'key', 'keys', 'keyspace', 'keyspaces', 'language', 'limit', 'list', 'login', 'map', 'materialized', 'modify', 'nan', 'nologin',
    'norecursive', 'nosuperuser', 'not', 'null', 'of', 'on', 'options', 'or', 'order', 'password', 'permission',
    'permissions', 'primary',
    'rename', 'replace', 'returns', 'revoke', 'role', 'roles', 'schema', 'select', 'set',
    'sfunc', 'smallint', 'static', 'storage', 'stype', 'superuser', 'table', 'text', 'time', 'timestamp', 'timeuuid',
    'tinyint', 'to', 'token', 'trigger', 'truncate', 'ttl', 'tuple', 'type', 'unlogged', 'update', 'use', 'user',
    'users', 'using', 'uuid', 'values', 'varchar', 'varint', 'view', 'where', 'with', 'writetime'
))
"""
Set of keywords in CQL.

Derived from .../cassandra/src/java/org/apache/cassandra/cql3/Cql.g
"""

cql_keywords_unreserved = set((
    'aggregate', 'all', 'as', 'ascii', 'bigint', 'blob', 'boolean', 'called', 'clustering', 'compact', 'contains',
    'count', 'counter', 'custom', 'date', 'decimal', 'distinct', 'double', 'exists', 'filtering', 'finalfunc', 'float',
    'frozen', 'function', 'functions', 'inet', 'initcond', 'input', 'int', 'json', 'key', 'keys', 'keyspaces',
    'language', 'list', 'login', 'map', 'nologin', 'nosuperuser', 'options', 'password', 'permission', 'permissions',
    'returns', 'role', 'roles', 'sfunc', 'smallint', 'static', 'storage', 'stype', 'superuser', 'text', 'time',
    'timestamp', 'timeuuid', 'tinyint', 'trigger', 'ttl', 'tuple', 'type', 'user', 'users', 'uuid', 'values', 'varchar',
    'varint', 'writetime'
))
"""
Set of unreserved keywords in CQL.

Derived from .../cassandra/src/java/org/apache/cassandra/cql3/Cql.g
"""

cql_keywords_reserved = cql_keywords - cql_keywords_unreserved
"""
Set of reserved keywords in CQL.
"""

_encoder = Encoder()


class Metadata(object):
    """
    Holds a representation of the cluster schema and topology.
    """

    cluster_name = None
    """ The string name of the cluster. """

    keyspaces = None
    """
    A map from keyspace names to matching :class:`~.KeyspaceMetadata` instances.
    """

    partitioner = None
    """
    The string name of the partitioner for the cluster.
    """

    token_map = None
    """
    A :class:`~.TokenMap` instance describing the ring topology.
    """

    def __init__(self):
        self.keyspaces = {}
        self._hosts = {}
        self._hosts_lock = RLock()

    def export_schema_as_string(self):
        """
        Returns a string that can be executed as a query in order to recreate
        the entire schema.  The string is formatted to be human readable.
        """
        return "\n\n".join(ks.export_as_string() for ks in self.keyspaces.values())

    def refresh(self, connection, timeout, target_type=None, change_type=None, **kwargs):
        server_version = self.get_host(connection.host).release_version
        parser = get_schema_parser(connection, server_version, timeout)

        if not target_type:
            self._rebuild_all(parser)
            return

        tt_lower = target_type.lower()
        try:
            parse_method = getattr(parser, 'get_' + tt_lower)
            meta = parse_method(self.keyspaces, **kwargs)
            if meta:
                update_method = getattr(self, '_update_' + tt_lower)
                if tt_lower == 'keyspace' and connection.protocol_version < 3:
                    # we didn't have 'type' target in legacy protocol versions, so we need to query those too
                    user_types = parser.get_types_map(self.keyspaces, **kwargs)
                    self._update_keyspace(meta, user_types)
                else:
                    update_method(meta)
            else:
                drop_method = getattr(self, '_drop_' + tt_lower)
                drop_method(**kwargs)
        except AttributeError:
            raise ValueError("Unknown schema target_type: '%s'" % target_type)

    def _rebuild_all(self, parser):
        current_keyspaces = set()
        for keyspace_meta in parser.get_all_keyspaces():
            current_keyspaces.add(keyspace_meta.name)
            old_keyspace_meta = self.keyspaces.get(keyspace_meta.name, None)
            self.keyspaces[keyspace_meta.name] = keyspace_meta
            if old_keyspace_meta:
                self._keyspace_updated(keyspace_meta.name)
            else:
                self._keyspace_added(keyspace_meta.name)

        # remove not-just-added keyspaces
        removed_keyspaces = [name for name in self.keyspaces.keys()
                             if name not in current_keyspaces]
        self.keyspaces = dict((name, meta) for name, meta in self.keyspaces.items()
                              if name in current_keyspaces)
        for ksname in removed_keyspaces:
            self._keyspace_removed(ksname)

    def _update_keyspace(self, keyspace_meta, new_user_types=None):
        ks_name = keyspace_meta.name
        old_keyspace_meta = self.keyspaces.get(ks_name, None)
        self.keyspaces[ks_name] = keyspace_meta
        if old_keyspace_meta:
            keyspace_meta.tables = old_keyspace_meta.tables
            keyspace_meta.user_types = new_user_types if new_user_types is not None else old_keyspace_meta.user_types
            keyspace_meta.indexes = old_keyspace_meta.indexes
            keyspace_meta.functions = old_keyspace_meta.functions
            keyspace_meta.aggregates = old_keyspace_meta.aggregates
            keyspace_meta.views = old_keyspace_meta.views
            if (keyspace_meta.replication_strategy != old_keyspace_meta.replication_strategy):
                self._keyspace_updated(ks_name)
        else:
            self._keyspace_added(ks_name)

    def _drop_keyspace(self, keyspace):
        if self.keyspaces.pop(keyspace, None):
            self._keyspace_removed(keyspace)

    def _update_table(self, meta):
        try:
            keyspace_meta = self.keyspaces[meta.keyspace_name]
            # this is unfortunate, but protocol v4 does not differentiate
            # between events for tables and views. The parser's get_table will
            # return one or the other based on the query results.
            # Here we deal with that.
            if isinstance(meta, TableMetadata):
                keyspace_meta._add_table_metadata(meta)
            else:
                keyspace_meta._add_view_metadata(meta)
        except KeyError:
            # can happen if keyspace disappears while processing async event
            pass

    def _drop_table(self, keyspace, table):
        try:
            keyspace_meta = self.keyspaces[keyspace]
            keyspace_meta._drop_table_metadata(table)  # handles either table or view
        except KeyError:
            # can happen if keyspace disappears while processing async event
            pass

    def _update_type(self, type_meta):
        try:
            self.keyspaces[type_meta.keyspace].user_types[type_meta.name] = type_meta
        except KeyError:
            # can happen if keyspace disappears while processing async event
            pass

    def _drop_type(self, keyspace, type):
        try:
            self.keyspaces[keyspace].user_types.pop(type, None)
        except KeyError:
            # can happen if keyspace disappears while processing async event
            pass

    def _update_function(self, function_meta):
        try:
            self.keyspaces[function_meta.keyspace].functions[function_meta.signature] = function_meta
        except KeyError:
            # can happen if keyspace disappears while processing async event
            pass

    def _drop_function(self, keyspace, function):
        try:
            self.keyspaces[keyspace].functions.pop(function.signature, None)
        except KeyError:
            pass

    def _update_aggregate(self, aggregate_meta):
        try:
            self.keyspaces[aggregate_meta.keyspace].aggregates[aggregate_meta.signature] = aggregate_meta
        except KeyError:
            pass

    def _drop_aggregate(self, keyspace, aggregate):
        try:
            self.keyspaces[keyspace].aggregates.pop(aggregate.signature, None)
        except KeyError:
            pass

    def _keyspace_added(self, ksname):
        if self.token_map:
            self.token_map.rebuild_keyspace(ksname, build_if_absent=False)

    def _keyspace_updated(self, ksname):
        if self.token_map:
            self.token_map.rebuild_keyspace(ksname, build_if_absent=False)

    def _keyspace_removed(self, ksname):
        if self.token_map:
            self.token_map.remove_keyspace(ksname)

    def rebuild_token_map(self, partitioner, token_map):
        """
        Rebuild our view of the topology from fresh rows from the
        system topology tables.
        For internal use only.
        """
        self.partitioner = partitioner
        if partitioner.endswith('RandomPartitioner'):
            token_class = MD5Token
        elif partitioner.endswith('Murmur3Partitioner'):
            token_class = Murmur3Token
        elif partitioner.endswith('ByteOrderedPartitioner'):
            token_class = BytesToken
        else:
            self.token_map = None
            return

        token_to_host_owner = {}
        ring = []
        for host, token_strings in six.iteritems(token_map):
            for token_string in token_strings:
                token = token_class.from_string(token_string)
                ring.append(token)
                token_to_host_owner[token] = host

        all_tokens = sorted(ring)
        self.token_map = TokenMap(
            token_class, token_to_host_owner, all_tokens, self)

    def get_replicas(self, keyspace, key):
        """
        Returns a list of :class:`.Host` instances that are replicas for a given
        partition key.
        """
        t = self.token_map
        if not t:
            return []
        try:
            return t.get_replicas(keyspace, t.token_class.from_key(key))
        except NoMurmur3:
            return []

    def can_support_partitioner(self):
        if self.partitioner.endswith('Murmur3Partitioner') and murmur3 is None:
            return False
        else:
            return True

    def add_or_return_host(self, host):
        """
        Returns a tuple (host, new), where ``host`` is a Host
        instance, and ``new`` is a bool indicating whether
        the host was newly added.
        """
        with self._hosts_lock:
            try:
                return self._hosts[host.address], False
            except KeyError:
                self._hosts[host.address] = host
                return host, True

    def remove_host(self, host):
        with self._hosts_lock:
            return bool(self._hosts.pop(host.address, False))

    def get_host(self, address):
        return self._hosts.get(address)

    def all_hosts(self):
        """
        Returns a list of all known :class:`.Host` instances in the cluster.
        """
        with self._hosts_lock:
            return list(self._hosts.values())


REPLICATION_STRATEGY_CLASS_PREFIX = "org.apache.cassandra.locator."
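`rebuild_token_map` flattens the `{host: [token strings]}` mapping into a sorted ring plus a token-to-owner dict. `get_replicas` then hashes a partition key to a token and passes that token to `TokenMap.get_replicas`, which is defined later in this module and is not shown here.

The following is a minimal sketch of that ring shape. Some points to note:

* Plain integers stand in for `Murmur3Token` values.
* `build_ring` and `primary_owner` are hypothetical helpers, not driver API.
* The lookup assumes the usual Cassandra convention that a node owns the range ending at its own token, so a key belongs to the first ring token greater than or equal to the key's token, wrapping around past the end.
* It does not claim to reproduce the driver's own `TokenMap` code.

```python
from bisect import bisect_left


def build_ring(token_map):
    """Hypothetical helper: flatten {host: [token_string, ...]} into a sorted
    ring and a token -> owning host dict, following the shape of
    Metadata.rebuild_token_map. int() stands in for token_class.from_string."""
    token_to_host_owner = {}
    ring = []
    for host, token_strings in token_map.items():
        for token_string in token_strings:
            token = int(token_string)
            ring.append(token)
            token_to_host_owner[token] = host
    return sorted(ring), token_to_host_owner


def primary_owner(ring, token_to_host_owner, key_token):
    """Hypothetical helper: return the host owning key_token, assuming each
    node owns (previous_token, its_token]. The owner is the first ring token
    >= key_token; tokens past the last one wrap around to ring[0]."""
    point = bisect_left(ring, key_token)
    if point == len(ring):
        point = 0  # wrap around the ring
    return token_to_host_owner[ring[point]]


ring, owners = build_ring({'10.0.0.1': ['-100', '50'], '10.0.0.2': ['0', '200']})
print(ring)                              # [-100, 0, 50, 200]
print(primary_owner(ring, owners, 25))   # 10.0.0.1 (first token >= 25 is 50)
print(primary_owner(ring, owners, 300))  # 10.0.0.1 (wraps around to -100)
```

With vnodes, each host contributes several tokens, which is why the owner dict is keyed by token rather than by host. Placing replicas beyond the primary owner is the job of each replication strategy's `make_token_replica_map`.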
def trim_if_startswith(s, prefix): if s.startswith(prefix): return s[len(prefix):] return s _replication_strategies = {} class ReplicationStrategyTypeType(type): def __new__(metacls, name, bases, dct): dct.setdefault('name', name) cls = type.__new__(metacls, name, bases, dct) if not name.startswith('_'): _replication_strategies[name] = cls return cls @six.add_metaclass(ReplicationStrategyTypeType) class _ReplicationStrategy(object): options_map = None @classmethod def create(cls, strategy_class, options_map): if not strategy_class: return None strategy_name = trim_if_startswith(strategy_class, REPLICATION_STRATEGY_CLASS_PREFIX) rs_class = _replication_strategies.get(strategy_name, None) if rs_class is None: rs_class = _UnknownStrategyBuilder(strategy_name) _replication_strategies[strategy_name] = rs_class try: rs_instance = rs_class(options_map) except Exception as exc: log.warning("Failed creating %s with options %s: %s", strategy_name, options_map, exc) return None return rs_instance def make_token_replica_map(self, token_to_host_owner, ring): raise NotImplementedError() def export_for_schema(self): raise NotImplementedError() ReplicationStrategy = _ReplicationStrategy class _UnknownStrategyBuilder(object): def __init__(self, name): self.name = name def __call__(self, options_map): strategy_instance = _UnknownStrategy(self.name, options_map) return strategy_instance class _UnknownStrategy(ReplicationStrategy): def __init__(self, name, options_map): self.name = name self.options_map = options_map.copy() if options_map is not None else dict() self.options_map['class'] = self.name def __eq__(self, other): return (isinstance(other, _UnknownStrategy) and self.name == other.name and self.options_map == other.options_map) def export_for_schema(self): """ Returns a string version of these replication options which are suitable for use in a CREATE KEYSPACE statement. 
""" if self.options_map: return dict((str(key), str(value)) for key, value in self.options_map.items()) return "{'class': '%s'}" % (self.name, ) def make_token_replica_map(self, token_to_host_owner, ring): return {} class SimpleStrategy(ReplicationStrategy): replication_factor = None """ The replication factor for this keyspace. """ def __init__(self, options_map): try: self.replication_factor = int(options_map['replication_factor']) except Exception: raise ValueError("SimpleStrategy requires an integer 'replication_factor' option") def make_token_replica_map(self, token_to_host_owner, ring): replica_map = {} for i in range(len(ring)): j, hosts = 0, list() while len(hosts) < self.replication_factor and j < len(ring): token = ring[(i + j) % len(ring)] host = token_to_host_owner[token] if host not in hosts: hosts.append(host) j += 1 replica_map[ring[i]] = hosts return replica_map def export_for_schema(self): """ Returns a string version of these replication options which are suitable for use in a CREATE KEYSPACE statement. """ return "{'class': 'SimpleStrategy', 'replication_factor': '%d'}" \ % (self.replication_factor,) def __eq__(self, other): if not isinstance(other, SimpleStrategy): return False return self.replication_factor == other.replication_factor class NetworkTopologyStrategy(ReplicationStrategy): dc_replication_factors = None """ A map of datacenter names to the replication factor for that DC. 
    """

    def __init__(self, dc_replication_factors):
        self.dc_replication_factors = dict(
            (str(k), int(v)) for k, v in dc_replication_factors.items())

    def make_token_replica_map(self, token_to_host_owner, ring):
        dc_rf_map = dict((dc, int(rf))
                         for dc, rf in self.dc_replication_factors.items() if rf > 0)

        # build a map of DCs to lists of indexes into `ring` for tokens that
        # belong to that DC
        dc_to_token_offset = defaultdict(list)
        dc_racks = defaultdict(set)
        hosts_per_dc = defaultdict(set)
        for i, token in enumerate(ring):
            host = token_to_host_owner[token]
            dc_to_token_offset[host.datacenter].append(i)
            if host.datacenter and host.rack:
                dc_racks[host.datacenter].add(host.rack)
                hosts_per_dc[host.datacenter].add(host)

        # A map of DCs to an index into the dc_to_token_offset value for that dc.
        # This is how we keep track of advancing around the ring for each DC.
        dc_to_current_index = defaultdict(int)

        replica_map = defaultdict(list)
        for i in range(len(ring)):
            replicas = replica_map[ring[i]]

            # go through each DC and find the replicas in that DC
            for dc in dc_to_token_offset.keys():
                if dc not in dc_rf_map:
                    continue

                # advance our per-DC index until we're up to at least the
                # current token in the ring
                token_offsets = dc_to_token_offset[dc]
                index = dc_to_current_index[dc]
                num_tokens = len(token_offsets)
                while index < num_tokens and token_offsets[index] < i:
                    index += 1
                dc_to_current_index[dc] = index

                replicas_remaining = dc_rf_map[dc]
                replicas_this_dc = 0
                skipped_hosts = []
                racks_placed = set()
                racks_this_dc = dc_racks[dc]
                hosts_this_dc = len(hosts_per_dc[dc])
                for token_offset in islice(cycle(token_offsets), index, index + num_tokens):
                    host = token_to_host_owner[ring[token_offset]]
                    if replicas_remaining == 0 or replicas_this_dc == hosts_this_dc:
                        break

                    if host in replicas:
                        continue

                    if host.rack in racks_placed and len(racks_placed) < len(racks_this_dc):
                        skipped_hosts.append(host)
                        continue

                    replicas.append(host)
                    replicas_this_dc += 1
                    replicas_remaining -= 1
                    racks_placed.add(host.rack)

                    if len(racks_placed) == len(racks_this_dc):
                        for host in skipped_hosts:
                            if replicas_remaining == 0:
                                break
                            replicas.append(host)
                            replicas_remaining -= 1
                        del skipped_hosts[:]

        return replica_map

    def export_for_schema(self):
        """
        Returns a string version of these replication options which are
        suitable for use in a CREATE KEYSPACE statement.
        """
        ret = "{'class': 'NetworkTopologyStrategy'"
        for dc, repl_factor in sorted(self.dc_replication_factors.items()):
            ret += ", '%s': '%d'" % (dc, repl_factor)
        return ret + "}"

    def __eq__(self, other):
        if not isinstance(other, NetworkTopologyStrategy):
            return False

        return self.dc_replication_factors == other.dc_replication_factors


class LocalStrategy(ReplicationStrategy):
    def __init__(self, options_map):
        pass

    def make_token_replica_map(self, token_to_host_owner, ring):
        return {}

    def export_for_schema(self):
        """
        Returns a string version of these replication options which are
        suitable for use in a CREATE KEYSPACE statement.
        """
        return "{'class': 'LocalStrategy'}"

    def __eq__(self, other):
        return isinstance(other, LocalStrategy)


class KeyspaceMetadata(object):
    """
    A representation of the schema for a single keyspace.
    """

    name = None
    """ The string name of the keyspace. """

    durable_writes = True
    """
    A boolean indicating whether durable writes are enabled for this keyspace
    or not.
    """

    replication_strategy = None
    """
    A :class:`.ReplicationStrategy` subclass object.
    """

    tables = None
    """
    A map from table names to instances of :class:`~.TableMetadata`.
    """

    indexes = None
    """
    A dict mapping index names to :class:`.IndexMetadata` instances.
    """

    user_types = None
    """
    A map from user-defined type names to instances of :class:`~cassandra.metadata.UserType`.

    .. versionadded:: 2.1.0
    """

    functions = None
    """
    A map from user-defined function signatures to instances of :class:`~cassandra.metadata.Function`.

    .. versionadded:: 2.6.0
    """

    aggregates = None
    """
    A map from user-defined aggregate signatures to instances of :class:`~cassandra.metadata.Aggregate`.

    .. versionadded:: 2.6.0
    """

    views = None
    """
    A dict mapping view names to :class:`.MaterializedViewMetadata` instances.
    """

    _exc_info = None
    """ set if metadata parsing failed """

    def __init__(self, name, durable_writes, strategy_class, strategy_options):
        self.name = name
        self.durable_writes = durable_writes
        self.replication_strategy = ReplicationStrategy.create(strategy_class, strategy_options)
        self.tables = {}
        self.indexes = {}
        self.user_types = {}
        self.functions = {}
        self.aggregates = {}
        self.views = {}

    def export_as_string(self):
        """
        Returns a CQL query string that can be used to recreate the entire keyspace,
        including user-defined types and tables.
        """
        cql = "\n\n".join([self.as_cql_query() + ';'] +
                          self.user_type_strings() +
                          [f.export_as_string() for f in self.functions.values()] +
                          [a.export_as_string() for a in self.aggregates.values()] +
                          [t.export_as_string() for t in self.tables.values()])
        if self._exc_info:
            import traceback
            ret = "/*\nWarning: Keyspace %s is incomplete because of an error processing metadata.\n" % \
                  (self.name)
            for line in traceback.format_exception(*self._exc_info):
                ret += line
            ret += "\nApproximate structure, for reference:\n(this should not be used to reproduce this schema)\n\n%s\n*/" % cql
            return ret
        return cql

    def as_cql_query(self):
        """
        Returns a CQL query string that can be used to recreate just this keyspace,
        not including user-defined types and tables.
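
        For example, a ``SimpleStrategy`` keyspace named ``ks`` (a hypothetical
        name used only for illustration) renders as follows; the doubled space
        before ``AND`` is produced by the two format strings in this method::

            >>> ks = KeyspaceMetadata('ks', True, 'SimpleStrategy', {'replication_factor': '1'})
            >>> ks.as_cql_query()
            "CREATE KEYSPACE ks WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'}  AND durable_writes = true"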
        """
        ret = "CREATE KEYSPACE %s WITH replication = %s " % (
            protect_name(self.name),
            self.replication_strategy.export_for_schema())
        return ret + (' AND durable_writes = %s' % ("true" if self.durable_writes else "false"))

    def user_type_strings(self):
        user_type_strings = []
        user_types = self.user_types.copy()
        keys = sorted(user_types.keys())
        for k in keys:
            if k in user_types:
                self.resolve_user_types(k, user_types, user_type_strings)
        return user_type_strings

    def resolve_user_types(self, key, user_types, user_type_strings):
        user_type = user_types.pop(key)
        for type_name in user_type.field_types:
            for sub_type in types.cql_types_from_string(type_name):
                if sub_type in user_types:
                    self.resolve_user_types(sub_type, user_types, user_type_strings)
        user_type_strings.append(user_type.export_as_string())

    def _add_table_metadata(self, table_metadata):
        old_indexes = {}
        old_meta = self.tables.get(table_metadata.name, None)
        if old_meta:
            # views are not queried with table, so they must be transferred to new
            table_metadata.views = old_meta.views
            # indexes will be updated with what is on the new metadata
            old_indexes = old_meta.indexes

        # note the intentional order of add before remove
        # this makes sure the maps are never absent something that existed before this update
        for index_name, index_metadata in six.iteritems(table_metadata.indexes):
            self.indexes[index_name] = index_metadata

        for index_name in (n for n in old_indexes if n not in table_metadata.indexes):
            self.indexes.pop(index_name, None)

        self.tables[table_metadata.name] = table_metadata

    def _drop_table_metadata(self, table_name):
        table_meta = self.tables.pop(table_name, None)
        if table_meta:
            for index_name in table_meta.indexes:
                self.indexes.pop(index_name, None)
            for view_name in table_meta.views:
                self.views.pop(view_name, None)
            return
        # we can't tell table drops from views, so drop both
        # (name is unique among them, within a keyspace)
        view_meta = self.views.pop(table_name, None)
        if view_meta:
            try:
                self.tables[view_meta.base_table_name].views.pop(table_name, None)
            except KeyError:
                pass

    def _add_view_metadata(self, view_metadata):
        try:
            self.tables[view_metadata.base_table_name].views[view_metadata.name] = view_metadata
            self.views[view_metadata.name] = view_metadata
        except KeyError:
            pass


class UserType(object):
    """
    A user defined type, as created by ``CREATE TYPE`` statements.

    User-defined types were introduced in Cassandra 2.1.

    .. versionadded:: 2.1.0
    """

    keyspace = None
    """
    The string name of the keyspace in which this type is defined.
    """

    name = None
    """
    The name of this type.
    """

    field_names = None
    """
    An ordered list of the names for each field in this user-defined type.
    """

    field_types = None
    """
    An ordered list of the types for each field in this user-defined type.
    """

    def __init__(self, keyspace, name, field_names, field_types):
        self.keyspace = keyspace
        self.name = name
        # non-frozen collections can return None
        self.field_names = field_names or []
        self.field_types = field_types or []

    def as_cql_query(self, formatted=False):
        """
        Returns a CQL query that can be used to recreate this type.
        If `formatted` is set to :const:`True`, extra whitespace will
        be added to make the query more readable.
        """
        ret = "CREATE TYPE %s.%s (%s" % (
            protect_name(self.keyspace),
            protect_name(self.name),
            "\n" if formatted else "")

        if formatted:
            field_join = ",\n"
            padding = "    "
        else:
            field_join = ", "
            padding = ""

        fields = []
        for field_name, field_type in zip(self.field_names, self.field_types):
            fields.append("%s %s" % (protect_name(field_name), field_type))

        ret += field_join.join("%s%s" % (padding, field) for field in fields)
        ret += "\n)" if formatted else ")"
        return ret

    def export_as_string(self):
        return self.as_cql_query(formatted=True) + ';'


class Aggregate(object):
    """
    A user defined aggregate function, as created by ``CREATE AGGREGATE`` statements.

    Aggregate functions were introduced in Cassandra 2.2

    .. versionadded:: 2.6.0
    """

    keyspace = None
    """
    The string name of the keyspace in which this aggregate is defined
    """

    name = None
    """
    The name of this aggregate
    """

    argument_types = None
    """
    An ordered list of the types for each argument to the aggregate
    """

    final_func = None
    """
    Name of a final function
    """

    initial_condition = None
    """
    Initial condition of the aggregate
    """

    return_type = None
    """
    Return type of the aggregate
    """

    state_func = None
    """
    Name of a state function
    """

    state_type = None
    """
    Type of the aggregate state
    """

    def __init__(self, keyspace, name, argument_types, state_func,
                 state_type, final_func, initial_condition, return_type):
        self.keyspace = keyspace
        self.name = name
        self.argument_types = argument_types
        self.state_func = state_func
        self.state_type = state_type
        self.final_func = final_func
        self.initial_condition = initial_condition
        self.return_type = return_type

    def as_cql_query(self, formatted=False):
        """
        Returns a CQL query that can be used to recreate this aggregate.
        If `formatted` is set to :const:`True`, extra whitespace will
        be added to make the query more readable.
        """
        sep = '\n    ' if formatted else ' '
        keyspace = protect_name(self.keyspace)
        name = protect_name(self.name)
        type_list = ', '.join(self.argument_types)
        state_func = protect_name(self.state_func)
        state_type = self.state_type

        ret = "CREATE AGGREGATE %(keyspace)s.%(name)s(%(type_list)s)%(sep)s" \
              "SFUNC %(state_func)s%(sep)s" \
              "STYPE %(state_type)s" % locals()

        ret += ''.join((sep, 'FINALFUNC ', protect_name(self.final_func))) if self.final_func else ''
        ret += ''.join((sep, 'INITCOND ', self.initial_condition)) if self.initial_condition is not None else ''

        return ret

    def export_as_string(self):
        return self.as_cql_query(formatted=True) + ';'

    @property
    def signature(self):
        return SignatureDescriptor.format_signature(self.name, self.argument_types)


class Function(object):
    """
    A user defined function, as created by ``CREATE FUNCTION`` statements.

    User-defined functions were introduced in Cassandra 2.2

    .. versionadded:: 2.6.0
    """

    keyspace = None
    """
    The string name of the keyspace in which this function is defined
    """

    name = None
    """
    The name of this function
    """

    argument_types = None
    """
    An ordered list of the types for each argument to the function
    """

    argument_names = None
    """
    An ordered list of the names of each argument to the function
    """

    return_type = None
    """
    Return type of the function
    """

    language = None
    """
    Language of the function body
    """

    body = None
    """
    Function body string
    """

    called_on_null_input = None
    """
    Flag indicating whether this function should be called for rows with null values
    (convenience function to avoid handling nulls explicitly if the result will just be null)
    """

    def __init__(self, keyspace, name, argument_types, argument_names,
                 return_type, language, body, called_on_null_input):
        self.keyspace = keyspace
        self.name = name
        self.argument_types = argument_types
        # argument_types (frozen<list<text>>) will always be a list
        # argument_name is not frozen in C* < 3.0 and may return None
        self.argument_names = argument_names or []
        self.return_type = return_type
        self.language = language
        self.body = body
        self.called_on_null_input = called_on_null_input

    def as_cql_query(self, formatted=False):
        """
        Returns a CQL query that can be used to recreate this function.
        If `formatted` is set to :const:`True`, extra whitespace will
        be added to make the query more readable.
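
        For example (a sketch; the keyspace ``ks``, the function ``plus`` and
        its Java body are hypothetical). Because ``called_on_null_input`` is
        ``False``, the query uses ``RETURNS NULL ON NULL INPUT``::

            >>> f = Function('ks', 'plus', ['int', 'int'], ['a', 'b'], 'int',
            ...              'java', 'return a + b;', False)
            >>> f.as_cql_query()
            'CREATE FUNCTION ks.plus(a int, b int) RETURNS NULL ON NULL INPUT RETURNS int LANGUAGE java AS $$return a + b;$$'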
        """
        sep = '\n    ' if formatted else ' '
        keyspace = protect_name(self.keyspace)
        name = protect_name(self.name)
        arg_list = ', '.join(["%s %s" % (protect_name(n), t)
                             for n, t in zip(self.argument_names, self.argument_types)])
        typ = self.return_type
        lang = self.language
        body = self.body
        on_null = "CALLED" if self.called_on_null_input else "RETURNS NULL"

        return "CREATE FUNCTION %(keyspace)s.%(name)s(%(arg_list)s)%(sep)s" \
               "%(on_null)s ON NULL INPUT%(sep)s" \
               "RETURNS %(typ)s%(sep)s" \
               "LANGUAGE %(lang)s%(sep)s" \
               "AS $$%(body)s$$" % locals()

    def export_as_string(self):
        return self.as_cql_query(formatted=True) + ';'

    @property
    def signature(self):
        return SignatureDescriptor.format_signature(self.name, self.argument_types)


class TableMetadata(object):
    """
    A representation of the schema for a single table.
    """

    keyspace_name = None
    """ String name of this Table's keyspace """

    name = None
    """ The string name of the table. """

    partition_key = None
    """
    A list of :class:`.ColumnMetadata` instances representing the columns in
    the partition key for this table. This will always hold at least one
    column.
    """

    clustering_key = None
    """
    A list of :class:`.ColumnMetadata` instances representing the columns
    in the clustering key for this table. These are all of the
    :attr:`.primary_key` columns that are not in the :attr:`.partition_key`.

    Note that a table may have no clustering keys, in which case this will
    be an empty list.
    """

    @property
    def primary_key(self):
        """
        A list of :class:`.ColumnMetadata` representing the components of
        the primary key for this table.
        """
        return self.partition_key + self.clustering_key

    columns = None
    """
    A dict mapping column names to :class:`.ColumnMetadata` instances.
    """

    indexes = None
    """
    A dict mapping index names to :class:`.IndexMetadata` instances.
    """

    is_compact_storage = False

    options = None
    """
    A dict mapping table option names to their specific settings for this
    table.
    """

    compaction_options = {
        "min_compaction_threshold": "min_threshold",
        "max_compaction_threshold": "max_threshold",
        "compaction_strategy_class": "class"}

    triggers = None
    """
    A dict mapping trigger names to :class:`.TriggerMetadata` instances.
    """

    views = None
    """
    A dict mapping view names to :class:`.MaterializedViewMetadata` instances.
    """

    _exc_info = None
    """ set if metadata parsing failed """

    @property
    def is_cql_compatible(self):
        """
        A boolean indicating if this table can be represented as CQL in export
        """
        comparator = getattr(self, 'comparator', None)
        if comparator:
            # no compact storage with more than one column beyond PK if there
            # are clustering columns
            incompatible = (self.is_compact_storage and
                            len(self.columns) > len(self.primary_key) + 1 and
                            len(self.clustering_key) >= 1)

            return not incompatible
        return True

    def __init__(self, keyspace_name, name, partition_key=None, clustering_key=None, columns=None, triggers=None, options=None):
        self.keyspace_name = keyspace_name
        self.name = name
        self.partition_key = [] if partition_key is None else partition_key
        self.clustering_key = [] if clustering_key is None else clustering_key
        self.columns = OrderedDict() if columns is None else columns
        self.indexes = {}
        self.options = {} if options is None else options
        self.comparator = None
        self.triggers = OrderedDict() if triggers is None else triggers
        self.views = {}

    def export_as_string(self):
        """
        Returns a string of CQL queries that can be used to recreate this table
        along with all indexes on it. The returned string is formatted to
        be human readable.
        """
        if self._exc_info:
            import traceback
            ret = "/*\nWarning: Table %s.%s is incomplete because of an error processing metadata.\n" % \
                  (self.keyspace_name, self.name)
            for line in traceback.format_exception(*self._exc_info):
                ret += line
            ret += "\nApproximate structure, for reference:\n(this should not be used to reproduce this schema)\n\n%s\n*/" % self._all_as_cql()
        elif not self.is_cql_compatible:
            # If we can't produce this table with CQL, comment inline
            ret = "/*\nWarning: Table %s.%s omitted because it has constructs not compatible with CQL (was created via legacy API).\n" % \
                  (self.keyspace_name, self.name)
            ret += "\nApproximate structure, for reference:\n(this should not be used to reproduce this schema)\n\n%s\n*/" % self._all_as_cql()
        else:
            ret = self._all_as_cql()

        return ret

    def _all_as_cql(self):
        ret = self.as_cql_query(formatted=True)
        ret += ";"

        for index in self.indexes.values():
            ret += "\n%s;" % index.as_cql_query()

        for trigger_meta in self.triggers.values():
            ret += "\n%s;" % (trigger_meta.as_cql_query(),)

        for view_meta in self.views.values():
            ret += "\n\n%s;" % (view_meta.as_cql_query(formatted=True),)

        return ret

    def as_cql_query(self, formatted=False):
        """
        Returns a CQL query that can be used to recreate this table (index
        creations are not included). If `formatted` is set to :const:`True`,
        extra whitespace will be added to make the query human readable.
        """
        ret = "CREATE TABLE %s.%s (%s" % (
            protect_name(self.keyspace_name),
            protect_name(self.name),
            "\n" if formatted else "")

        if formatted:
            column_join = ",\n"
            padding = "    "
        else:
            column_join = ", "
            padding = ""

        columns = []
        for col in self.columns.values():
            columns.append("%s %s%s" % (protect_name(col.name), col.cql_type, ' static' if col.is_static else ''))

        if len(self.partition_key) == 1 and not self.clustering_key:
            columns[0] += " PRIMARY KEY"

        ret += column_join.join("%s%s" % (padding, col) for col in columns)

        # primary key
        if len(self.partition_key) > 1 or self.clustering_key:
            ret += "%s%sPRIMARY KEY (" % (column_join, padding)

            if len(self.partition_key) > 1:
                ret += "(%s)" % ", ".join(protect_name(col.name) for col in self.partition_key)
            else:
                ret += protect_name(self.partition_key[0].name)

            if self.clustering_key:
                ret += ", %s" % ", ".join(protect_name(col.name) for col in self.clustering_key)

            ret += ")"

        # properties
        ret += "%s) WITH " % ("\n" if formatted else "")
        ret += self._property_string(formatted, self.clustering_key, self.options, self.is_compact_storage)

        return ret

    @classmethod
    def _property_string(cls, formatted, clustering_key, options_map, is_compact_storage=False):
        properties = []
        if is_compact_storage:
            properties.append("COMPACT STORAGE")

        if clustering_key:
            cluster_str = "CLUSTERING ORDER BY "

            inner = []
            for col in clustering_key:
                ordering = "DESC" if col.is_reversed else "ASC"
                inner.append("%s %s" % (protect_name(col.name), ordering))

            cluster_str += "(%s)" % ", ".join(inner)

            properties.append(cluster_str)

        properties.extend(cls._make_option_strings(options_map))

        join_str = "\n    AND " if formatted else " AND "
        return join_str.join(properties)

    @classmethod
    def _make_option_strings(cls, options_map):
        ret = []
        options_copy = dict(options_map.items())

        actual_options = json.loads(options_copy.pop('compaction_strategy_options', '{}'))
        value = options_copy.pop("compaction_strategy_class", None)
        actual_options.setdefault("class", value)
        compaction_option_strings = ["'%s': '%s'" % (k, v) for k, v in actual_options.items()]
        ret.append('compaction = {%s}' % ', '.join(compaction_option_strings))

        for system_table_name in cls.compaction_options.keys():
            options_copy.pop(system_table_name, None)  # delete if present
        options_copy.pop('compaction_strategy_option', None)

        if not options_copy.get('compression'):
            params = json.loads(options_copy.pop('compression_parameters', '{}'))
            param_strings = ["'%s': '%s'" % (k, v) for k, v in params.items()]
            ret.append('compression = {%s}' % ', '.join(param_strings))

        for name, value in options_copy.items():
            if value is not None:
                if name == "comment":
                    value = value or ""
                ret.append("%s = %s" % (name, protect_value(value)))

        return list(sorted(ret))


def protect_name(name):
    return maybe_escape_name(name)


def protect_names(names):
    return [protect_name(n) for n in names]


def protect_value(value):
    if value is None:
        return 'NULL'
    if isinstance(value, (int, float, bool)):
        return str(value).lower()
    return "'%s'" % value.replace("'", "''")


valid_cql3_word_re = re.compile(r'^[a-z][0-9a-z_]*$')


def is_valid_name(name):
    if name is None:
        return False
    if name.lower() in cql_keywords_reserved:
        return False
    return valid_cql3_word_re.match(name) is not None


def maybe_escape_name(name):
    if is_valid_name(name):
        return name
    return escape_name(name)


def escape_name(name):
    return '"%s"' % (name.replace('"', '""'),)


class ColumnMetadata(object):
    """
    A representation of a single column in a table.
    """

    table = None
    """ The :class:`.TableMetadata` this column belongs to. """

    name = None
    """ The string name of this column. """

    cql_type = None
    """
    The CQL type for the column.
    """

    is_static = False
    """
    If this column is static (available in Cassandra 2.1+), this will
    be :const:`True`, otherwise :const:`False`.
    """

    is_reversed = False
    """
    If this column is reversed (DESC) as in clustering order
    """

    _cass_type = None

    def __init__(self, table_metadata, column_name, cql_type, is_static=False, is_reversed=False):
        self.table = table_metadata
        self.name = column_name
        self.cql_type = cql_type
        self.is_static = is_static
        self.is_reversed = is_reversed

    def __str__(self):
        return "%s %s" % (self.name, self.cql_type)


class IndexMetadata(object):
    """
    A representation of a secondary index on a column.
    """
    keyspace_name = None
    """ A string name of the keyspace. """

    table_name = None
    """ A string name of the table this index is on. """

    name = None
    """ A string name for the index. """

    kind = None
    """ A string representing the kind of index (COMPOSITE, CUSTOM,...). """

    index_options = {}
    """ A dict of index options. """

    def __init__(self, keyspace_name, table_name, index_name, kind, index_options):
        self.keyspace_name = keyspace_name
        self.table_name = table_name
        self.name = index_name
        self.kind = kind
        self.index_options = index_options

    def as_cql_query(self):
        """
        Returns a CQL query that can be used to recreate this index.
        """
        options = dict(self.index_options)
        index_target = options.pop("target")
        if self.kind != "CUSTOM":
            return "CREATE INDEX %s ON %s.%s (%s)" % (
                self.name,  # Cassandra doesn't like quoted index names for some reason
                protect_name(self.keyspace_name),
                protect_name(self.table_name),
                index_target)
        else:
            class_name = options.pop("class_name")
            ret = "CREATE CUSTOM INDEX %s ON %s.%s (%s) USING '%s'" % (
                self.name,  # Cassandra doesn't like quoted index names for some reason
                protect_name(self.keyspace_name),
                protect_name(self.table_name),
                index_target,
                class_name)
            if options:
                ret += " WITH OPTIONS = %s" % Encoder().cql_encode_all_types(options)
            return ret

    def export_as_string(self):
        """
        Returns a CQL query string that can be used to recreate this index.
        """
        return self.as_cql_query() + ';'


class TokenMap(object):
    """
    Information about the layout of the ring.
    """

    token_class = None
    """
    A subclass of :class:`.Token`, depending on what partitioner the cluster uses.
    """

    token_to_host_owner = None
    """
    A map of :class:`.Token` objects to the :class:`.Host` that owns that token.
    """

    tokens_to_hosts_by_ks = None
    """
    A map of keyspace names to a nested map of :class:`.Token` objects to
    sets of :class:`.Host` objects.
    """

    ring = None
    """
    An ordered list of :class:`.Token` instances in the ring.
    """

    _metadata = None

    def __init__(self, token_class, token_to_host_owner, all_tokens, metadata):
        self.token_class = token_class
        self.ring = all_tokens
        self.token_to_host_owner = token_to_host_owner

        self.tokens_to_hosts_by_ks = {}
        self._metadata = metadata
        self._rebuild_lock = RLock()

    def rebuild_keyspace(self, keyspace, build_if_absent=False):
        with self._rebuild_lock:
            try:
                current = self.tokens_to_hosts_by_ks.get(keyspace, None)
                if (build_if_absent and current is None) or (not build_if_absent and current is not None):
                    ks_meta = self._metadata.keyspaces.get(keyspace)
                    if ks_meta:
                        replica_map = self.replica_map_for_keyspace(self._metadata.keyspaces[keyspace])
                        self.tokens_to_hosts_by_ks[keyspace] = replica_map
            except Exception:
                # should not happen normally, but we don't want to blow up queries because of unexpected meta state
                # bypass until new map is generated
                self.tokens_to_hosts_by_ks[keyspace] = {}
                log.exception("Failed creating a token map for keyspace '%s' with %s. PLEASE REPORT THIS: https://datastax-oss.atlassian.net/projects/PYTHON", keyspace, self.token_to_host_owner)

    def replica_map_for_keyspace(self, ks_metadata):
        strategy = ks_metadata.replication_strategy
        if strategy:
            return strategy.make_token_replica_map(self.token_to_host_owner, self.ring)
        else:
            return None

    def remove_keyspace(self, keyspace):
        self.tokens_to_hosts_by_ks.pop(keyspace, None)

    def get_replicas(self, keyspace, token):
        """
        Get a set of :class:`.Host` instances representing all of the
        replica nodes for a given :class:`.Token`.
        """
        tokens_to_hosts = self.tokens_to_hosts_by_ks.get(keyspace, None)
        if tokens_to_hosts is None:
            self.rebuild_keyspace(keyspace, build_if_absent=True)
            tokens_to_hosts = self.tokens_to_hosts_by_ks.get(keyspace, None)

        if tokens_to_hosts:
            # token range ownership is exclusive on the LHS (the start token), so
            # we use bisect_right, which, in the case of a tie/exact match,
            # picks an insertion point to the right of the existing match
            point = bisect_right(self.ring, token)
            if point == len(self.ring):
                return tokens_to_hosts[self.ring[0]]
            else:
                return tokens_to_hosts[self.ring[point]]
        return []


class Token(object):
    """
    Abstract class representing a token.
    """

    def __init__(self, token):
        self.value = token

    @classmethod
    def hash_fn(cls, key):
        return key

    @classmethod
    def from_key(cls, key):
        return cls(cls.hash_fn(key))

    @classmethod
    def from_string(cls, token_string):
        raise NotImplementedError()

    def __cmp__(self, other):
        if self.value < other.value:
            return -1
        elif self.value == other.value:
            return 0
        else:
            return 1

    def __eq__(self, other):
        return self.value == other.value

    def __lt__(self, other):
        return self.value < other.value

    def __hash__(self):
        return hash(self.value)

    def __repr__(self):
        return "<%s: %s>" % (self.__class__.__name__, self.value)
    __str__ = __repr__


MIN_LONG = -(2 ** 63)
MAX_LONG = (2 ** 63) - 1


class NoMurmur3(Exception):
    pass


class HashToken(Token):

    @classmethod
    def from_string(cls, token_string):
        """ `token_string` should be the string representation from the server. """
        # The hash partitioners just store the decimal value
        return cls(int(token_string))


class Murmur3Token(HashToken):
    """
    A token for ``Murmur3Partitioner``.
    """

    @classmethod
    def hash_fn(cls, key):
        if murmur3 is not None:
            h = int(murmur3(key))
            return h if h != MIN_LONG else MAX_LONG
        else:
            raise NoMurmur3()

    def __init__(self, token):
        """ `token` is an int or string representing the token. """
        self.value = int(token)


class MD5Token(HashToken):
    """
    A token for ``RandomPartitioner``.
    """

    @classmethod
    def hash_fn(cls, key):
        if isinstance(key, six.text_type):
            key = key.encode('UTF-8')
        return abs(varint_unpack(md5(key).digest()))


class BytesToken(Token):
    """
    A token for ``ByteOrderedPartitioner``.
    """

    @classmethod
    def from_string(cls, token_string):
        """ `token_string` should be the string representation from the server. """
        # unhexlify works fine with unicode input in everything but pypy3, where it raises "TypeError: 'str' does not support the buffer interface"
        if isinstance(token_string, six.text_type):
            token_string = token_string.encode('ascii')
        # The BOP stores a hex string
        return cls(unhexlify(token_string))


class TriggerMetadata(object):
    """
    A representation of a trigger for a table.
    """

    table = None
    """ The :class:`.TableMetadata` this trigger belongs to. """

    name = None
    """ The string name of this trigger. """

    options = None
    """
    A dict mapping trigger option names to their specific settings for this
    table.
    """

    def __init__(self, table_metadata, trigger_name, options=None):
        self.table = table_metadata
        self.name = trigger_name
        self.options = options

    def as_cql_query(self):
        ret = "CREATE TRIGGER %s ON %s.%s USING %s" % (
            protect_name(self.name),
            protect_name(self.table.keyspace_name),
            protect_name(self.table.name),
            protect_value(self.options['class'])
        )
        return ret

    def export_as_string(self):
        return self.as_cql_query() + ';'


class _SchemaParser(object):

    def __init__(self, connection, timeout):
        self.connection = connection
        self.timeout = timeout

    def _handle_results(self, success, result):
        if success:
            return dict_factory(*result.results) if result else []
        else:
            raise result

    def _query_build_row(self, query_string, build_func):
        result = self._query_build_rows(query_string, build_func)
        return result[0] if result else None

    def _query_build_rows(self, query_string, build_func):
        query = QueryMessage(query=query_string, consistency_level=ConsistencyLevel.ONE)
        responses = self.connection.wait_for_responses((query), timeout=self.timeout, fail_on_error=False)
        (success, response) = responses[0]
        if success:
            result = dict_factory(*response.results)
            return [build_func(row) for row in result]
        elif isinstance(response, InvalidRequest):
            log.debug("user types table not found")
            return []
        else:
            raise response


class SchemaParserV22(_SchemaParser):
    _SELECT_KEYSPACES = "SELECT * FROM system.schema_keyspaces"
    _SELECT_COLUMN_FAMILIES = "SELECT * FROM system.schema_columnfamilies"
    _SELECT_COLUMNS = "SELECT * FROM system.schema_columns"
    _SELECT_TRIGGERS = "SELECT * FROM system.schema_triggers"
    _SELECT_TYPES = "SELECT * FROM system.schema_usertypes"
    _SELECT_FUNCTIONS = "SELECT * FROM system.schema_functions"
    _SELECT_AGGREGATES = "SELECT * FROM system.schema_aggregates"

    _table_name_col = 'columnfamily_name'

    _function_agg_arument_type_col = 'signature'

    recognized_table_options = (
        "comment",
        "read_repair_chance",
        "dclocal_read_repair_chance",  # kept to be safe, but see _build_table_options()
        "local_read_repair_chance",
        "replicate_on_write",
        "gc_grace_seconds",
        "bloom_filter_fp_chance",
        "caching",
        "compaction_strategy_class",
        "compaction_strategy_options",
        "min_compaction_threshold",
        "max_compaction_threshold",
        "compression_parameters",
        "min_index_interval",
        "max_index_interval",
        "index_interval",
        "speculative_retry",
        "rows_per_partition_to_cache",
        "memtable_flush_period_in_ms",
        "populate_io_cache_on_flush",
        "compression",
        "default_time_to_live")

    def __init__(self, connection, timeout):
        super(SchemaParserV22, self).__init__(connection, timeout)
        self.keyspaces_result = []
        self.tables_result = []
        self.columns_result = []
        self.triggers_result = []
        self.types_result = []
        self.functions_result = []
        self.aggregates_result = []

        self.keyspace_table_rows = defaultdict(list)
        self.keyspace_table_col_rows = defaultdict(lambda: defaultdict(list))
        self.keyspace_type_rows = defaultdict(list)
        self.keyspace_func_rows = defaultdict(list)
        self.keyspace_agg_rows = defaultdict(list)
        self.keyspace_table_trigger_rows = defaultdict(lambda: defaultdict(list))

    def get_all_keyspaces(self):
        self._query_all()

        for row in self.keyspaces_result:
            keyspace_meta = self._build_keyspace_metadata(row)

            try:
                for table_row in self.keyspace_table_rows.get(keyspace_meta.name, []):
                    table_meta = self._build_table_metadata(table_row)
                    keyspace_meta._add_table_metadata(table_meta)

                for usertype_row in self.keyspace_type_rows.get(keyspace_meta.name, []):
                    usertype = self._build_user_type(usertype_row)
                    keyspace_meta.user_types[usertype.name] = usertype

                for fn_row in self.keyspace_func_rows.get(keyspace_meta.name, []):
                    fn = self._build_function(fn_row)
                    keyspace_meta.functions[fn.signature] = fn

                for agg_row in self.keyspace_agg_rows.get(keyspace_meta.name, []):
                    agg = self._build_aggregate(agg_row)
                    keyspace_meta.aggregates[agg.signature] = agg
            except Exception:
                log.exception("Error while parsing metadata for keyspace %s. Metadata model will be incomplete.", keyspace_meta.name)
                keyspace_meta._exc_info = sys.exc_info()

            yield keyspace_meta

    def get_table(self, keyspaces, keyspace, table):
        cl = ConsistencyLevel.ONE
        where_clause = bind_params(" WHERE keyspace_name = %%s AND %s = %%s" % (self._table_name_col,), (keyspace, table), _encoder)
        cf_query = QueryMessage(query=self._SELECT_COLUMN_FAMILIES + where_clause, consistency_level=cl)
        col_query = QueryMessage(query=self._SELECT_COLUMNS + where_clause, consistency_level=cl)
        triggers_query = QueryMessage(query=self._SELECT_TRIGGERS + where_clause, consistency_level=cl)
        (cf_success, cf_result), (col_success, col_result), (triggers_success, triggers_result) \
            = self.connection.wait_for_responses(cf_query, col_query, triggers_query,
                                                 timeout=self.timeout, fail_on_error=False)
        table_result = self._handle_results(cf_success, cf_result)
        col_result = self._handle_results(col_success, col_result)

        # handle the triggers table not existing in Cassandra 1.2
        if not triggers_success and isinstance(triggers_result, InvalidRequest):
            triggers_result = []
        else:
            triggers_result = self._handle_results(triggers_success, triggers_result)

        if table_result:
            return self._build_table_metadata(table_result[0], col_result, triggers_result)

    def get_type(self, keyspaces, keyspace, type):
        where_clause = bind_params(" WHERE keyspace_name = %s AND type_name = %s", (keyspace, type), _encoder)
        return self._query_build_row(self._SELECT_TYPES + where_clause, self._build_user_type)

    def get_types_map(self, keyspaces, keyspace):
        where_clause = bind_params(" WHERE keyspace_name = %s", (keyspace,), _encoder)
        types = self._query_build_rows(self._SELECT_TYPES + where_clause, self._build_user_type)
        return dict((t.name, t) for t in types)

    def get_function(self, keyspaces, keyspace, function):
        where_clause = bind_params(" WHERE keyspace_name = %%s AND function_name = %%s AND %s = %%s" % (self._function_agg_arument_type_col,),
                                   (keyspace, function.name, function.argument_types), _encoder)
        return self._query_build_row(self._SELECT_FUNCTIONS + where_clause, self._build_function)

    def get_aggregate(self, keyspaces, keyspace, aggregate):
        where_clause = bind_params(" WHERE keyspace_name = %%s AND aggregate_name = %%s AND %s = %%s" % (self._function_agg_arument_type_col,),
                                   (keyspace, aggregate.name, aggregate.argument_types), _encoder)

        return self._query_build_row(self._SELECT_AGGREGATES + where_clause, self._build_aggregate)

    def get_keyspace(self, keyspaces, keyspace):
        where_clause = bind_params(" WHERE keyspace_name = %s", (keyspace,), _encoder)
        return self._query_build_row(self._SELECT_KEYSPACES + where_clause, self._build_keyspace_metadata)

    @classmethod
    def _build_keyspace_metadata(cls, row):
        try:
            ksm = cls._build_keyspace_metadata_internal(row)
        except Exception:
            name = row["keyspace_name"]
            ksm = KeyspaceMetadata(name, False, 'UNKNOWN', {})
            ksm._exc_info = sys.exc_info()  # capture exc_info before log because nose (test) logging clears it in certain circumstances
            log.exception("Error while parsing metadata for keyspace %s row(%s)", name, row)
        return ksm

    @staticmethod
    def _build_keyspace_metadata_internal(row):
        name = row["keyspace_name"]
        durable_writes = row["durable_writes"]
        strategy_class = row["strategy_class"]
        strategy_options = json.loads(row["strategy_options"])
        return KeyspaceMetadata(name, durable_writes, strategy_class, strategy_options)

    @classmethod
    def _build_user_type(cls, usertype_row):
        field_types = list(map(cls._schema_type_to_cql, usertype_row['field_types']))
        return UserType(usertype_row['keyspace_name'], usertype_row['type_name'],
                        usertype_row['field_names'], field_types)

    @classmethod
    def _build_function(cls, function_row):
        return_type = cls._schema_type_to_cql(function_row['return_type'])
        return Function(function_row['keyspace_name'], function_row['function_name'],
                        function_row[cls._function_agg_arument_type_col], function_row['argument_names'],
                        return_type, function_row['language'], function_row['body'],
                        function_row['called_on_null_input'])

    @classmethod
    def _build_aggregate(cls, aggregate_row):
        cass_state_type = types.lookup_casstype(aggregate_row['state_type'])
        initial_condition = aggregate_row['initcond']
        if initial_condition is not None:
            initial_condition = _encoder.cql_encode_all_types(cass_state_type.deserialize(initial_condition, 3))
        state_type = _cql_from_cass_type(cass_state_type)
        return_type = cls._schema_type_to_cql(aggregate_row['return_type'])
        return Aggregate(aggregate_row['keyspace_name'], aggregate_row['aggregate_name'],
                         aggregate_row['signature'], aggregate_row['state_func'], state_type,
                         aggregate_row['final_func'], initial_condition, return_type)

    def _build_table_metadata(self, row, col_rows=None, trigger_rows=None):
        keyspace_name = row["keyspace_name"]
        cfname = row[self._table_name_col]

        col_rows = col_rows or self.keyspace_table_col_rows[keyspace_name][cfname]
        trigger_rows = trigger_rows or self.keyspace_table_trigger_rows[keyspace_name][cfname]

        if not col_rows:  # CASSANDRA-8487
            log.warning("Building table metadata with no column meta for %s.%s",
                        keyspace_name, cfname)

        table_meta = 
TableMetadata(keyspace_name, cfname) try: comparator = types.lookup_casstype(row["comparator"]) table_meta.comparator = comparator is_dct_comparator = issubclass(comparator, types.DynamicCompositeType) is_composite_comparator = issubclass(comparator, types.CompositeType) column_name_types = comparator.subtypes if is_composite_comparator else (comparator,) num_column_name_components = len(column_name_types) last_col = column_name_types[-1] column_aliases = row.get("column_aliases", None) clustering_rows = [r for r in col_rows if r.get('type', None) == "clustering_key"] if len(clustering_rows) > 1: clustering_rows = sorted(clustering_rows, key=lambda row: row.get('component_index')) if column_aliases is not None: column_aliases = json.loads(column_aliases) if not column_aliases: # json load failed or column_aliases empty PYTHON-562 column_aliases = [r.get('column_name') for r in clustering_rows] if is_composite_comparator: if issubclass(last_col, types.ColumnToCollectionType): # collections is_compact = False has_value = False clustering_size = num_column_name_components - 2 elif (len(column_aliases) == num_column_name_components - 1 and issubclass(last_col, types.UTF8Type)): # aliases? 
is_compact = False has_value = False clustering_size = num_column_name_components - 1 else: # compact table is_compact = True has_value = column_aliases or not col_rows clustering_size = num_column_name_components # Some thrift tables define names in composite types (see PYTHON-192) if not column_aliases and hasattr(comparator, 'fieldnames'): column_aliases = list(filter(None, comparator.fieldnames)) else: is_compact = True if column_aliases or not col_rows or is_dct_comparator: has_value = True clustering_size = num_column_name_components else: has_value = False clustering_size = 0 # partition key partition_rows = [r for r in col_rows if r.get('type', None) == "partition_key"] if len(partition_rows) > 1: partition_rows = sorted(partition_rows, key=lambda row: row.get('component_index')) key_aliases = row.get("key_aliases") if key_aliases is not None: key_aliases = json.loads(key_aliases) if key_aliases else [] else: # In 2.0+, we can use the 'type' column. In 3.0+, we have to use it. key_aliases = [r.get('column_name') for r in partition_rows] key_validator = row.get("key_validator") if key_validator is not None: key_type = types.lookup_casstype(key_validator) key_types = key_type.subtypes if issubclass(key_type, types.CompositeType) else [key_type] else: key_types = [types.lookup_casstype(r.get('validator')) for r in partition_rows] for i, col_type in enumerate(key_types): if len(key_aliases) > i: column_name = key_aliases[i] elif i == 0: column_name = "key" else: column_name = "key%d" % i col = ColumnMetadata(table_meta, column_name, col_type.cql_parameterized_type()) table_meta.columns[column_name] = col table_meta.partition_key.append(col) # clustering key for i in range(clustering_size): if len(column_aliases) > i: column_name = column_aliases[i] else: column_name = "column%d" % (i + 1) data_type = column_name_types[i] cql_type = _cql_from_cass_type(data_type) is_reversed = types.is_reversed_casstype(data_type) col = ColumnMetadata(table_meta, column_name,
cql_type, is_reversed=is_reversed) table_meta.columns[column_name] = col table_meta.clustering_key.append(col) # value alias (if present) if has_value: value_alias_rows = [r for r in col_rows if r.get('type', None) == "compact_value"] if not key_aliases: # TODO are we checking the right thing here? value_alias = "value" else: value_alias = row.get("value_alias", None) if value_alias is None and value_alias_rows: # CASSANDRA-8487 # In 2.0+, we can use the 'type' column. In 3.0+, we have to use it. value_alias = value_alias_rows[0].get('column_name') default_validator = row.get("default_validator") if default_validator: validator = types.lookup_casstype(default_validator) else: if value_alias_rows: # CASSANDRA-8487 validator = types.lookup_casstype(value_alias_rows[0].get('validator')) cql_type = _cql_from_cass_type(validator) col = ColumnMetadata(table_meta, value_alias, cql_type) if value_alias: # CASSANDRA-8487 table_meta.columns[value_alias] = col # other normal columns for col_row in col_rows: column_meta = self._build_column_metadata(table_meta, col_row) if column_meta.name: table_meta.columns[column_meta.name] = column_meta index_meta = self._build_index_metadata(column_meta, col_row) if index_meta: table_meta.indexes[index_meta.name] = index_meta for trigger_row in trigger_rows: trigger_meta = self._build_trigger_metadata(table_meta, trigger_row) table_meta.triggers[trigger_meta.name] = trigger_meta table_meta.options = self._build_table_options(row) table_meta.is_compact_storage = is_compact except Exception: table_meta._exc_info = sys.exc_info() log.exception("Error while parsing metadata for table %s.%s row(%s) columns(%s)", keyspace_name, cfname, row, col_rows) return table_meta def _build_table_options(self, row): """ Setup the mostly-non-schema table options, like caching settings """ options = dict((o, row.get(o)) for o in self.recognized_table_options if o in row) # the option name when creating tables is "dclocal_read_repair_chance", # but the column 
name in system.schema_columnfamilies is # "local_read_repair_chance". We'll store this as dclocal_read_repair_chance, # since that's probably what users are expecting (and we need it for the # CREATE TABLE statement anyway). if "local_read_repair_chance" in options: val = options.pop("local_read_repair_chance") options["dclocal_read_repair_chance"] = val return options @classmethod def _build_column_metadata(cls, table_metadata, row): name = row["column_name"] type_string = row["validator"] data_type = types.lookup_casstype(type_string) cql_type = _cql_from_cass_type(data_type) is_static = row.get("type", None) == "static" is_reversed = types.is_reversed_casstype(data_type) column_meta = ColumnMetadata(table_metadata, name, cql_type, is_static, is_reversed) column_meta._cass_type = data_type return column_meta @staticmethod def _build_index_metadata(column_metadata, row): index_name = row.get("index_name") kind = row.get("index_type") if index_name or kind: options = row.get("index_options") options = json.loads(options) if options else {} options = options or {} # if the json parsed to None, init empty dict # generate a CQL index identity string target = protect_name(column_metadata.name) if kind != "CUSTOM": if "index_keys" in options: target = 'keys(%s)' % (target,) elif "index_values" in options: # don't use any "function" for collection values pass else: # it might be a "full" index on a frozen collection, but # we need to check the data type to verify that, because # there is no special index option for full-collection # indexes. 
data_type = column_metadata._cass_type collection_types = ('map', 'set', 'list') if data_type.typename == "frozen" and data_type.subtypes[0].typename in collection_types: # no index option for full-collection index target = 'full(%s)' % (target,) options['target'] = target return IndexMetadata(column_metadata.table.keyspace_name, column_metadata.table.name, index_name, kind, options) @staticmethod def _build_trigger_metadata(table_metadata, row): name = row["trigger_name"] options = row["trigger_options"] trigger_meta = TriggerMetadata(table_metadata, name, options) return trigger_meta def _query_all(self): cl = ConsistencyLevel.ONE queries = [ QueryMessage(query=self._SELECT_KEYSPACES, consistency_level=cl), QueryMessage(query=self._SELECT_COLUMN_FAMILIES, consistency_level=cl), QueryMessage(query=self._SELECT_COLUMNS, consistency_level=cl), QueryMessage(query=self._SELECT_TYPES, consistency_level=cl), QueryMessage(query=self._SELECT_FUNCTIONS, consistency_level=cl), QueryMessage(query=self._SELECT_AGGREGATES, consistency_level=cl), QueryMessage(query=self._SELECT_TRIGGERS, consistency_level=cl) ] responses = self.connection.wait_for_responses(*queries, timeout=self.timeout, fail_on_error=False) (ks_success, ks_result), (table_success, table_result), \ (col_success, col_result), (types_success, types_result), \ (functions_success, functions_result), \ (aggregates_success, aggregates_result), \ (triggers_success, triggers_result) = responses self.keyspaces_result = self._handle_results(ks_success, ks_result) self.tables_result = self._handle_results(table_success, table_result) self.columns_result = self._handle_results(col_success, col_result) # if we're connected to Cassandra < 2.0, the triggers table will not exist if triggers_success: self.triggers_result = dict_factory(*triggers_result.results) else: if isinstance(triggers_result, InvalidRequest): log.debug("triggers table not found") elif isinstance(triggers_result, Unauthorized): log.warning("this version of 
Cassandra does not allow access to schema_triggers metadata with authorization enabled (CASSANDRA-7967); " "The driver will operate normally, but will not reflect triggers in the local metadata model, or schema strings.") else: raise triggers_result # if we're connected to Cassandra < 2.1, the usertypes table will not exist if types_success: self.types_result = dict_factory(*types_result.results) else: if isinstance(types_result, InvalidRequest): log.debug("user types table not found") self.types_result = {} else: raise types_result # functions were introduced in Cassandra 2.2 if functions_success: self.functions_result = dict_factory(*functions_result.results) else: if isinstance(functions_result, InvalidRequest): log.debug("user functions table not found") else: raise functions_result # aggregates were introduced in Cassandra 2.2 if aggregates_success: self.aggregates_result = dict_factory(*aggregates_result.results) else: if isinstance(aggregates_result, InvalidRequest): log.debug("user aggregates table not found") else: raise aggregates_result self._aggregate_results() def _aggregate_results(self): m = self.keyspace_table_rows for row in self.tables_result: m[row["keyspace_name"]].append(row) m = self.keyspace_table_col_rows for row in self.columns_result: ksname = row["keyspace_name"] cfname = row[self._table_name_col] m[ksname][cfname].append(row) m = self.keyspace_type_rows for row in self.types_result: m[row["keyspace_name"]].append(row) m = self.keyspace_func_rows for row in self.functions_result: m[row["keyspace_name"]].append(row) m = self.keyspace_agg_rows for row in self.aggregates_result: m[row["keyspace_name"]].append(row) m = self.keyspace_table_trigger_rows for row in self.triggers_result: ksname = row["keyspace_name"] cfname = row[self._table_name_col] m[ksname][cfname].append(row) @staticmethod def _schema_type_to_cql(type_string): cass_type = types.lookup_casstype(type_string) return _cql_from_cass_type(cass_type) class 
SchemaParserV3(SchemaParserV22): _SELECT_KEYSPACES = "SELECT * FROM system_schema.keyspaces" _SELECT_TABLES = "SELECT * FROM system_schema.tables" _SELECT_COLUMNS = "SELECT * FROM system_schema.columns" _SELECT_INDEXES = "SELECT * FROM system_schema.indexes" _SELECT_TRIGGERS = "SELECT * FROM system_schema.triggers" _SELECT_TYPES = "SELECT * FROM system_schema.types" _SELECT_FUNCTIONS = "SELECT * FROM system_schema.functions" _SELECT_AGGREGATES = "SELECT * FROM system_schema.aggregates" _SELECT_VIEWS = "SELECT * FROM system_schema.views" _table_name_col = 'table_name' _function_agg_arument_type_col = 'argument_types' recognized_table_options = ( 'bloom_filter_fp_chance', 'caching', 'cdc', 'comment', 'compaction', 'compression', 'crc_check_chance', 'dclocal_read_repair_chance', 'default_time_to_live', 'gc_grace_seconds', 'max_index_interval', 'memtable_flush_period_in_ms', 'min_index_interval', 'read_repair_chance', 'speculative_retry') def __init__(self, connection, timeout): super(SchemaParserV3, self).__init__(connection, timeout) self.indexes_result = [] self.keyspace_table_index_rows = defaultdict(lambda: defaultdict(list)) self.keyspace_view_rows = defaultdict(list) def get_all_keyspaces(self): for keyspace_meta in super(SchemaParserV3, self).get_all_keyspaces(): for row in self.keyspace_view_rows[keyspace_meta.name]: view_meta = self._build_view_metadata(row) keyspace_meta._add_view_metadata(view_meta) yield keyspace_meta def get_table(self, keyspaces, keyspace, table): cl = ConsistencyLevel.ONE where_clause = bind_params(" WHERE keyspace_name = %%s AND %s = %%s" % (self._table_name_col), (keyspace, table), _encoder) cf_query = QueryMessage(query=self._SELECT_TABLES + where_clause, consistency_level=cl) col_query = QueryMessage(query=self._SELECT_COLUMNS + where_clause, consistency_level=cl) indexes_query = QueryMessage(query=self._SELECT_INDEXES + where_clause, consistency_level=cl) triggers_query = QueryMessage(query=self._SELECT_TRIGGERS + where_clause, 
consistency_level=cl) # in protocol v4 we don't know if this event is a view or a table, so we look for both where_clause = bind_params(" WHERE keyspace_name = %s AND view_name = %s", (keyspace, table), _encoder) view_query = QueryMessage(query=self._SELECT_VIEWS + where_clause, consistency_level=cl) (cf_success, cf_result), (col_success, col_result), (indexes_success, indexes_result), \ (triggers_success, triggers_result), (view_success, view_result) \ = self.connection.wait_for_responses(cf_query, col_query, indexes_query, triggers_query, view_query, timeout=self.timeout, fail_on_error=False) table_result = self._handle_results(cf_success, cf_result) col_result = self._handle_results(col_success, col_result) if table_result: indexes_result = self._handle_results(indexes_success, indexes_result) triggers_result = self._handle_results(triggers_success, triggers_result) return self._build_table_metadata(table_result[0], col_result, triggers_result, indexes_result) view_result = self._handle_results(view_success, view_result) if view_result: return self._build_view_metadata(view_result[0], col_result) @staticmethod def _build_keyspace_metadata_internal(row): name = row["keyspace_name"] durable_writes = row["durable_writes"] strategy_options = dict(row["replication"]) strategy_class = strategy_options.pop("class") return KeyspaceMetadata(name, durable_writes, strategy_class, strategy_options) @staticmethod def _build_aggregate(aggregate_row): return Aggregate(aggregate_row['keyspace_name'], aggregate_row['aggregate_name'], aggregate_row['argument_types'], aggregate_row['state_func'], aggregate_row['state_type'], aggregate_row['final_func'], aggregate_row['initcond'], aggregate_row['return_type']) def _build_table_metadata(self, row, col_rows=None, trigger_rows=None, index_rows=None): keyspace_name = row["keyspace_name"] table_name = row[self._table_name_col] col_rows = col_rows or self.keyspace_table_col_rows[keyspace_name][table_name] trigger_rows = trigger_rows or
self.keyspace_table_trigger_rows[keyspace_name][table_name] index_rows = index_rows or self.keyspace_table_index_rows[keyspace_name][table_name] table_meta = TableMetadataV3(keyspace_name, table_name) try: table_meta.options = self._build_table_options(row) flags = row.get('flags', set()) if flags: compact_static = False table_meta.is_compact_storage = 'dense' in flags or 'super' in flags or 'compound' not in flags is_dense = 'dense' in flags else: compact_static = True table_meta.is_compact_storage = True is_dense = False self._build_table_columns(table_meta, col_rows, compact_static, is_dense) for trigger_row in trigger_rows: trigger_meta = self._build_trigger_metadata(table_meta, trigger_row) table_meta.triggers[trigger_meta.name] = trigger_meta for index_row in index_rows: index_meta = self._build_index_metadata(table_meta, index_row) if index_meta: table_meta.indexes[index_meta.name] = index_meta except Exception: table_meta._exc_info = sys.exc_info() log.exception("Error while parsing metadata for table %s.%s row(%s) columns(%s)", keyspace_name, table_name, row, col_rows) return table_meta def _build_table_options(self, row): """ Setup the mostly-non-schema table options, like caching settings """ return dict((o, row.get(o)) for o in self.recognized_table_options if o in row) def _build_table_columns(self, meta, col_rows, compact_static=False, is_dense=False): # partition key partition_rows = [r for r in col_rows if r.get('kind', None) == "partition_key"] if len(partition_rows) > 1: partition_rows = sorted(partition_rows, key=lambda row: row.get('position')) for r in partition_rows: # we have to add meta here (and not in the later loop) because TableMetadata.columns is an # OrderedDict, and it assumes keys are inserted first, in order, when exporting CQL column_meta = self._build_column_metadata(meta, r) meta.columns[column_meta.name] = column_meta meta.partition_key.append(meta.columns[r.get('column_name')]) # clustering key if not compact_static: 
clustering_rows = [r for r in col_rows if r.get('kind', None) == "clustering"] if len(clustering_rows) > 1: clustering_rows = sorted(clustering_rows, key=lambda row: row.get('position')) for r in clustering_rows: column_meta = self._build_column_metadata(meta, r) meta.columns[column_meta.name] = column_meta meta.clustering_key.append(meta.columns[r.get('column_name')]) for col_row in (r for r in col_rows if r.get('kind', None) not in ('partition_key', 'clustering_key')): column_meta = self._build_column_metadata(meta, col_row) if is_dense and column_meta.cql_type == types.cql_empty_type: continue if compact_static and not column_meta.is_static: # for compact static tables, we omit the clustering key and value, and only add the logical columns. # They are marked not static so that it generates appropriate CQL continue if compact_static: column_meta.is_static = False meta.columns[column_meta.name] = column_meta def _build_view_metadata(self, row, col_rows=None): keyspace_name = row["keyspace_name"] view_name = row["view_name"] base_table_name = row["base_table_name"] include_all_columns = row["include_all_columns"] where_clause = row["where_clause"] col_rows = col_rows or self.keyspace_table_col_rows[keyspace_name][view_name] view_meta = MaterializedViewMetadata(keyspace_name, view_name, base_table_name, include_all_columns, where_clause, self._build_table_options(row)) self._build_table_columns(view_meta, col_rows) return view_meta @staticmethod def _build_column_metadata(table_metadata, row): name = row["column_name"] cql_type = row["type"] is_static = row.get("kind", None) == "static" is_reversed = row["clustering_order"].upper() == "DESC" column_meta = ColumnMetadata(table_metadata, name, cql_type, is_static, is_reversed) return column_meta @staticmethod def _build_index_metadata(table_metadata, row): index_name = row.get("index_name") kind = row.get("kind") if index_name or kind: index_options = row.get("options") return 
IndexMetadata(table_metadata.keyspace_name, table_metadata.name, index_name, kind, index_options) else: return None @staticmethod def _build_trigger_metadata(table_metadata, row): name = row["trigger_name"] options = row["options"] trigger_meta = TriggerMetadata(table_metadata, name, options) return trigger_meta def _query_all(self): cl = ConsistencyLevel.ONE queries = [ QueryMessage(query=self._SELECT_KEYSPACES, consistency_level=cl), QueryMessage(query=self._SELECT_TABLES, consistency_level=cl), QueryMessage(query=self._SELECT_COLUMNS, consistency_level=cl), QueryMessage(query=self._SELECT_TYPES, consistency_level=cl), QueryMessage(query=self._SELECT_FUNCTIONS, consistency_level=cl), QueryMessage(query=self._SELECT_AGGREGATES, consistency_level=cl), QueryMessage(query=self._SELECT_TRIGGERS, consistency_level=cl), QueryMessage(query=self._SELECT_INDEXES, consistency_level=cl), QueryMessage(query=self._SELECT_VIEWS, consistency_level=cl) ] responses = self.connection.wait_for_responses(*queries, timeout=self.timeout, fail_on_error=False) (ks_success, ks_result), (table_success, table_result), \ (col_success, col_result), (types_success, types_result), \ (functions_success, functions_result), \ (aggregates_success, aggregates_result), \ (triggers_success, triggers_result), \ (indexes_success, indexes_result), \ (views_success, views_result) = responses self.keyspaces_result = self._handle_results(ks_success, ks_result) self.tables_result = self._handle_results(table_success, table_result) self.columns_result = self._handle_results(col_success, col_result) self.triggers_result = self._handle_results(triggers_success, triggers_result) self.types_result = self._handle_results(types_success, types_result) self.functions_result = self._handle_results(functions_success, functions_result) self.aggregates_result = self._handle_results(aggregates_success, aggregates_result) self.indexes_result = self._handle_results(indexes_success, indexes_result) self.views_result = 
self._handle_results(views_success, views_result) self._aggregate_results() def _aggregate_results(self): super(SchemaParserV3, self)._aggregate_results() m = self.keyspace_table_index_rows for row in self.indexes_result: ksname = row["keyspace_name"] cfname = row[self._table_name_col] m[ksname][cfname].append(row) m = self.keyspace_view_rows for row in self.views_result: m[row["keyspace_name"]].append(row) @staticmethod def _schema_type_to_cql(type_string): return type_string class TableMetadataV3(TableMetadata): compaction_options = {} option_maps = ['compaction', 'compression', 'caching'] @property def is_cql_compatible(self): return True @classmethod def _make_option_strings(cls, options_map): ret = [] options_copy = dict(options_map.items()) for option in cls.option_maps: value = options_copy.get(option) if isinstance(value, Mapping): del options_copy[option] params = ("'%s': '%s'" % (k, v) for k, v in value.items()) ret.append("%s = {%s}" % (option, ', '.join(params))) for name, value in options_copy.items(): if value is not None: if name == "comment": value = value or "" ret.append("%s = %s" % (name, protect_value(value))) return list(sorted(ret)) class MaterializedViewMetadata(object): """ A representation of a materialized view on a table """ keyspace_name = None """ A string name of the keyspace containing this view.""" name = None """ A string name of the view.""" base_table_name = None """ A string name of the base table for this view.""" partition_key = None """ A list of :class:`.ColumnMetadata` instances representing the columns in the partition key for this view. This will always hold at least one column. """ clustering_key = None """ A list of :class:`.ColumnMetadata` instances representing the columns in the clustering key for this view. Note that a view may have no clustering keys, in which case this will be an empty list. """ columns = None """ A dict mapping column names to :class:`.ColumnMetadata` instances.
""" include_all_columns = None """ A flag indicating whether the view was created AS SELECT * """ where_clause = None """ String WHERE clause for the view select statement, taken from server metadata. """ options = None """ A dict mapping table option names to their specific settings for this view. """ def __init__(self, keyspace_name, view_name, base_table_name, include_all_columns, where_clause, options): self.keyspace_name = keyspace_name self.name = view_name self.base_table_name = base_table_name self.partition_key = [] self.clustering_key = [] self.columns = OrderedDict() self.include_all_columns = include_all_columns self.where_clause = where_clause self.options = options or {} def as_cql_query(self, formatted=False): """ Returns a CQL query that can be used to recreate this materialized view. If `formatted` is set to :const:`True`, extra whitespace will be added to make the query more readable. """ sep = '\n ' if formatted else ' ' keyspace = protect_name(self.keyspace_name) name = protect_name(self.name) selected_cols = '*' if self.include_all_columns else ', '.join(protect_name(col.name) for col in self.columns.values()) base_table = protect_name(self.base_table_name) where_clause = self.where_clause part_key = ', '.join(protect_name(col.name) for col in self.partition_key) if len(self.partition_key) > 1: pk = "((%s)" % part_key else: pk = "(%s" % part_key if self.clustering_key: pk += ", %s" % ', '.join(protect_name(col.name) for col in self.clustering_key) pk += ")" properties = TableMetadataV3._property_string(formatted, self.clustering_key, self.options) return "CREATE MATERIALIZED VIEW %(keyspace)s.%(name)s AS%(sep)s" \ "SELECT %(selected_cols)s%(sep)s" \ "FROM %(keyspace)s.%(base_table)s%(sep)s" \ "WHERE %(where_clause)s%(sep)s" \ "PRIMARY KEY %(pk)s%(sep)s" \ "WITH %(properties)s" % locals() def export_as_string(self): return self.as_cql_query(formatted=True) + ";" def get_schema_parser(connection, server_version, timeout): if server_version.startswith('3'):
return SchemaParserV3(connection, timeout) else: # we could further specialize by version. Right now just refactoring the # multi-version parser we have as of C* 2.2.0rc1. return SchemaParserV22(connection, timeout) def _cql_from_cass_type(cass_type): """ A string representation of the type for this column, such as "varchar" or "map". """ if issubclass(cass_type, types.ReversedType): return cass_type.subtypes[0].cql_parameterized_type() else: return cass_type.cql_parameterized_type() cassandra-driver-3.7.1/cassandra/protocol.py # Copyright 2013-2016 DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.
from __future__ import absolute_import # to enable import io from stdlib from collections import namedtuple import logging import socket from uuid import UUID import six from six.moves import range import io from cassandra import type_codes, DriverException from cassandra import (Unavailable, WriteTimeout, ReadTimeout, WriteFailure, ReadFailure, FunctionFailure, AlreadyExists, InvalidRequest, Unauthorized, UnsupportedOperation, UserFunctionDescriptor, UserAggregateDescriptor, SchemaTargetType) from cassandra.marshal import (int32_pack, int32_unpack, uint16_pack, uint16_unpack, int8_pack, int8_unpack, uint64_pack, header_pack, v3_header_pack) from cassandra.cqltypes import (AsciiType, BytesType, BooleanType, CounterColumnType, DateType, DecimalType, DoubleType, FloatType, Int32Type, InetAddressType, IntegerType, ListType, LongType, MapType, SetType, TimeUUIDType, UTF8Type, VarcharType, UUIDType, UserType, TupleType, lookup_casstype, SimpleDateType, TimeType, ByteType, ShortType) from cassandra.policies import WriteType from cassandra.cython_deps import HAVE_CYTHON, HAVE_NUMPY from cassandra import util log = logging.getLogger(__name__) class NotSupportedError(Exception): pass class InternalError(Exception): pass ColumnMetadata = namedtuple("ColumnMetadata", ['keyspace_name', 'table_name', 'name', 'type']) MIN_SUPPORTED_VERSION = 1 MAX_SUPPORTED_VERSION = 5 HEADER_DIRECTION_TO_CLIENT = 0x80 HEADER_DIRECTION_MASK = 0x80 COMPRESSED_FLAG = 0x01 TRACING_FLAG = 0x02 CUSTOM_PAYLOAD_FLAG = 0x04 WARNING_FLAG = 0x08 USE_BETA_FLAG = 0x10 USE_BETA_MASK = ~USE_BETA_FLAG _message_types_by_opcode = {} _UNSET_VALUE = object() def register_class(cls): _message_types_by_opcode[cls.opcode] = cls def get_registered_classes(): return _message_types_by_opcode.copy() class _RegisterMessageType(type): def __init__(cls, name, bases, dct): if not name.startswith('_'): register_class(cls) @six.add_metaclass(_RegisterMessageType) class _MessageType(object): tracing = False custom_payload = 
None warnings = None def update_custom_payload(self, other): if other: if not self.custom_payload: self.custom_payload = {} self.custom_payload.update(other) if len(self.custom_payload) > 65535: raise ValueError("Custom payload map exceeds max count allowed by protocol (65535)") def __repr__(self): return '<%s(%s)>' % (self.__class__.__name__, ', '.join('%s=%r' % i for i in _get_params(self))) def _get_params(message_obj): base_attrs = dir(_MessageType) return ( (n, a) for n, a in message_obj.__dict__.items() if n not in base_attrs and not n.startswith('_') and not callable(a) ) error_classes = {} class ErrorMessage(_MessageType, Exception): opcode = 0x00 name = 'ERROR' summary = 'Unknown' def __init__(self, code, message, info): self.code = code self.message = message self.info = info @classmethod def recv_body(cls, f, protocol_version, *args): code = read_int(f) msg = read_string(f) subcls = error_classes.get(code, cls) extra_info = subcls.recv_error_info(f, protocol_version) return subcls(code=code, message=msg, info=extra_info) def summary_msg(self): msg = 'Error from server: code=%04x [%s] message="%s"' \ % (self.code, self.summary, self.message) if six.PY2 and isinstance(msg, six.text_type): msg = msg.encode('utf-8') return msg def __str__(self): return '<%s>' % self.summary_msg() __repr__ = __str__ @staticmethod def recv_error_info(f, protocol_version): pass def to_exception(self): return self class ErrorMessageSubclass(_RegisterMessageType): def __init__(cls, name, bases, dct): if cls.error_code is not None: # compare with None explicitly: ServerError's error code is 0, which is falsy.
            error_classes[cls.error_code] = cls


@six.add_metaclass(ErrorMessageSubclass)
class ErrorMessageSub(ErrorMessage):
    error_code = None


class RequestExecutionException(ErrorMessageSub):
    pass


class RequestValidationException(ErrorMessageSub):
    pass


class ServerError(ErrorMessageSub):
    summary = 'Server error'
    error_code = 0x0000


class ProtocolException(ErrorMessageSub):
    summary = 'Protocol error'
    error_code = 0x000A


class BadCredentials(ErrorMessageSub):
    summary = 'Bad credentials'
    error_code = 0x0100


class UnavailableErrorMessage(RequestExecutionException):
    summary = 'Unavailable exception'
    error_code = 0x1000

    @staticmethod
    def recv_error_info(f, protocol_version):
        return {
            'consistency': read_consistency_level(f),
            'required_replicas': read_int(f),
            'alive_replicas': read_int(f),
        }

    def to_exception(self):
        return Unavailable(self.summary_msg(), **self.info)


class OverloadedErrorMessage(RequestExecutionException):
    summary = 'Coordinator node overloaded'
    error_code = 0x1001


class IsBootstrappingErrorMessage(RequestExecutionException):
    summary = 'Coordinator node is bootstrapping'
    error_code = 0x1002


class TruncateError(RequestExecutionException):
    summary = 'Error during truncate'
    error_code = 0x1003


class WriteTimeoutErrorMessage(RequestExecutionException):
    summary = "Coordinator node timed out waiting for replica nodes' responses"
    error_code = 0x1100

    @staticmethod
    def recv_error_info(f, protocol_version):
        return {
            'consistency': read_consistency_level(f),
            'received_responses': read_int(f),
            'required_responses': read_int(f),
            'write_type': WriteType.name_to_value[read_string(f)],
        }

    def to_exception(self):
        return WriteTimeout(self.summary_msg(), **self.info)


class ReadTimeoutErrorMessage(RequestExecutionException):
    summary = "Coordinator node timed out waiting for replica nodes' responses"
    error_code = 0x1200

    @staticmethod
    def recv_error_info(f, protocol_version):
        return {
            'consistency': read_consistency_level(f),
            'received_responses': read_int(f),
            'required_responses': read_int(f),
            'data_retrieved': bool(read_byte(f)),
        }

    def to_exception(self):
        return ReadTimeout(self.summary_msg(), **self.info)


class ReadFailureMessage(RequestExecutionException):
    summary = "Replica(s) failed to execute read"
    error_code = 0x1300

    @staticmethod
    def recv_error_info(f, protocol_version):
        consistency = read_consistency_level(f)
        received_responses = read_int(f)
        required_responses = read_int(f)

        if protocol_version >= 5:
            error_code_map = read_error_code_map(f)
            failures = len(error_code_map)
        else:
            error_code_map = None
            failures = read_int(f)

        data_retrieved = bool(read_byte(f))

        return {
            'consistency': consistency,
            'received_responses': received_responses,
            'required_responses': required_responses,
            'failures': failures,
            'error_code_map': error_code_map,
            'data_retrieved': data_retrieved
        }

    def to_exception(self):
        return ReadFailure(self.summary_msg(), **self.info)


class FunctionFailureMessage(RequestExecutionException):
    summary = "User Defined Function failure"
    error_code = 0x1400

    @staticmethod
    def recv_error_info(f, protocol_version):
        return {
            'keyspace': read_string(f),
            'function': read_string(f),
            'arg_types': [read_string(f) for _ in range(read_short(f))],
        }

    def to_exception(self):
        return FunctionFailure(self.summary_msg(), **self.info)


class WriteFailureMessage(RequestExecutionException):
    summary = "Replica(s) failed to execute write"
    error_code = 0x1500

    @staticmethod
    def recv_error_info(f, protocol_version):
        consistency = read_consistency_level(f)
        received_responses = read_int(f)
        required_responses = read_int(f)

        if protocol_version >= 5:
            error_code_map = read_error_code_map(f)
            failures = len(error_code_map)
        else:
            error_code_map = None
            failures = read_int(f)

        write_type = WriteType.name_to_value[read_string(f)]

        return {
            'consistency': consistency,
            'received_responses': received_responses,
            'required_responses': required_responses,
            'failures': failures,
            'error_code_map': error_code_map,
            'write_type': write_type
        }

    def to_exception(self):
        return WriteFailure(self.summary_msg(), **self.info)


class SyntaxException(RequestValidationException):
    summary = 'Syntax error in CQL query'
    error_code = 0x2000


class UnauthorizedErrorMessage(RequestValidationException):
    summary = 'Unauthorized'
    error_code = 0x2100

    def to_exception(self):
        return Unauthorized(self.summary_msg())


class InvalidRequestException(RequestValidationException):
    summary = 'Invalid query'
    error_code = 0x2200

    def to_exception(self):
        return InvalidRequest(self.summary_msg())


class ConfigurationException(RequestValidationException):
    summary = 'Query invalid because of configuration issue'
    error_code = 0x2300


class PreparedQueryNotFound(RequestValidationException):
    summary = 'Matching prepared statement not found on this node'
    error_code = 0x2500

    @staticmethod
    def recv_error_info(f, protocol_version):
        # return the query ID
        return read_binary_string(f)


class AlreadyExistsException(ConfigurationException):
    summary = 'Item already exists'
    error_code = 0x2400

    @staticmethod
    def recv_error_info(f, protocol_version):
        return {
            'keyspace': read_string(f),
            'table': read_string(f),
        }

    def to_exception(self):
        return AlreadyExists(**self.info)


class StartupMessage(_MessageType):
    opcode = 0x01
    name = 'STARTUP'

    KNOWN_OPTION_KEYS = set((
        'CQL_VERSION',
        'COMPRESSION',
    ))

    def __init__(self, cqlversion, options):
        self.cqlversion = cqlversion
        self.options = options

    def send_body(self, f, protocol_version):
        optmap = self.options.copy()
        optmap['CQL_VERSION'] = self.cqlversion
        write_stringmap(f, optmap)


class ReadyMessage(_MessageType):
    opcode = 0x02
    name = 'READY'

    @classmethod
    def recv_body(cls, *args):
        return cls()


class AuthenticateMessage(_MessageType):
    opcode = 0x03
    name = 'AUTHENTICATE'

    def __init__(self, authenticator):
        self.authenticator = authenticator

    @classmethod
    def recv_body(cls, f, *args):
        authname = read_string(f)
        return cls(authenticator=authname)


class CredentialsMessage(_MessageType):
    opcode = 0x04
    name = 'CREDENTIALS'

    def __init__(self, creds):
        self.creds = creds

    def send_body(self, f, protocol_version):
        if protocol_version > 1:
            raise UnsupportedOperation(
                "Credentials-based authentication is not supported with "
                "protocol version 2 or higher. Use the SASL authentication "
                "mechanism instead.")
        write_short(f, len(self.creds))
        for credkey, credval in self.creds.items():
            write_string(f, credkey)
            write_string(f, credval)


class AuthChallengeMessage(_MessageType):
    opcode = 0x0E
    name = 'AUTH_CHALLENGE'

    def __init__(self, challenge):
        self.challenge = challenge

    @classmethod
    def recv_body(cls, f, *args):
        return cls(read_binary_longstring(f))


class AuthResponseMessage(_MessageType):
    opcode = 0x0F
    name = 'AUTH_RESPONSE'

    def __init__(self, response):
        self.response = response

    def send_body(self, f, protocol_version):
        write_longstring(f, self.response)


class AuthSuccessMessage(_MessageType):
    opcode = 0x10
    name = 'AUTH_SUCCESS'

    def __init__(self, token):
        self.token = token

    @classmethod
    def recv_body(cls, f, *args):
        return cls(read_longstring(f))


class OptionsMessage(_MessageType):
    opcode = 0x05
    name = 'OPTIONS'

    def send_body(self, f, protocol_version):
        pass


class SupportedMessage(_MessageType):
    opcode = 0x06
    name = 'SUPPORTED'

    def __init__(self, cql_versions, options):
        self.cql_versions = cql_versions
        self.options = options

    @classmethod
    def recv_body(cls, f, *args):
        options = read_stringmultimap(f)
        cql_versions = options.pop('CQL_VERSION')
        return cls(cql_versions=cql_versions, options=options)


# used for QueryMessage and ExecuteMessage
_VALUES_FLAG = 0x01
_SKIP_METADATA_FLAG = 0x02
_PAGE_SIZE_FLAG = 0x04
_WITH_PAGING_STATE_FLAG = 0x08
_WITH_SERIAL_CONSISTENCY_FLAG = 0x10
_PROTOCOL_TIMESTAMP = 0x20


class QueryMessage(_MessageType):
    opcode = 0x07
    name = 'QUERY'

    def __init__(self, query, consistency_level, serial_consistency_level=None,
                 fetch_size=None, paging_state=None, timestamp=None):
        self.query = query
        self.consistency_level = consistency_level
        self.serial_consistency_level = serial_consistency_level
        self.fetch_size = fetch_size
        self.paging_state = paging_state
        self.timestamp = timestamp
        self._query_params = None  # only used internally. May be set to a list of native-encoded values to have them sent with the request.

    def send_body(self, f, protocol_version):
        write_longstring(f, self.query)
        write_consistency_level(f, self.consistency_level)
        flags = 0x00
        if self._query_params is not None:
            flags |= _VALUES_FLAG  # also v2+, but we're only setting params internally right now

        if self.serial_consistency_level:
            if protocol_version >= 2:
                flags |= _WITH_SERIAL_CONSISTENCY_FLAG
            else:
                raise UnsupportedOperation(
                    "Serial consistency levels require the use of protocol version "
                    "2 or higher. Consider setting Cluster.protocol_version to 2 "
                    "to support serial consistency levels.")

        if self.fetch_size:
            if protocol_version >= 2:
                flags |= _PAGE_SIZE_FLAG
            else:
                raise UnsupportedOperation(
                    "Automatic query paging may only be used with protocol version "
                    "2 or higher. Consider setting Cluster.protocol_version to 2.")

        if self.paging_state:
            if protocol_version >= 2:
                flags |= _WITH_PAGING_STATE_FLAG
            else:
                raise UnsupportedOperation(
                    "Automatic query paging may only be used with protocol version "
                    "2 or higher. Consider setting Cluster.protocol_version to 2.")

        if self.timestamp is not None:
            flags |= _PROTOCOL_TIMESTAMP

        write_byte(f, flags)

        if self._query_params is not None:
            write_short(f, len(self._query_params))
            for param in self._query_params:
                write_value(f, param)

        if self.fetch_size:
            write_int(f, self.fetch_size)
        if self.paging_state:
            write_longstring(f, self.paging_state)
        if self.serial_consistency_level:
            write_consistency_level(f, self.serial_consistency_level)
        if self.timestamp is not None:
            write_long(f, self.timestamp)


CUSTOM_TYPE = object()

RESULT_KIND_VOID = 0x0001
RESULT_KIND_ROWS = 0x0002
RESULT_KIND_SET_KEYSPACE = 0x0003
RESULT_KIND_PREPARED = 0x0004
RESULT_KIND_SCHEMA_CHANGE = 0x0005


class ResultMessage(_MessageType):
    opcode = 0x08
    name = 'RESULT'

    kind = None
    results = None
    paging_state = None

    # Names match type name in module scope. Most are imported from cassandra.cqltypes (except CUSTOM_TYPE)
    type_codes = _cqltypes_by_code = dict((v, globals()[k]) for k, v in type_codes.__dict__.items() if not k.startswith('_'))

    _FLAGS_GLOBAL_TABLES_SPEC = 0x0001
    _HAS_MORE_PAGES_FLAG = 0x0002
    _NO_METADATA_FLAG = 0x0004

    def __init__(self, kind, results, paging_state=None):
        self.kind = kind
        self.results = results
        self.paging_state = paging_state

    @classmethod
    def recv_body(cls, f, protocol_version, user_type_map, result_metadata):
        kind = read_int(f)
        paging_state = None
        if kind == RESULT_KIND_VOID:
            results = None
        elif kind == RESULT_KIND_ROWS:
            paging_state, results = cls.recv_results_rows(
                f, protocol_version, user_type_map, result_metadata)
        elif kind == RESULT_KIND_SET_KEYSPACE:
            ksname = read_string(f)
            results = ksname
        elif kind == RESULT_KIND_PREPARED:
            results = cls.recv_results_prepared(f, protocol_version, user_type_map)
        elif kind == RESULT_KIND_SCHEMA_CHANGE:
            results = cls.recv_results_schema_change(f, protocol_version)
        else:
            raise DriverException("Unknown RESULT kind: %d" % kind)
        return cls(kind, results, paging_state)

    @classmethod
    def recv_results_rows(cls, f,
                          protocol_version, user_type_map, result_metadata):
        paging_state, column_metadata = cls.recv_results_metadata(f, user_type_map)
        column_metadata = column_metadata or result_metadata
        rowcount = read_int(f)
        rows = [cls.recv_row(f, len(column_metadata)) for _ in range(rowcount)]
        colnames = [c[2] for c in column_metadata]
        coltypes = [c[3] for c in column_metadata]
        try:
            parsed_rows = [
                tuple(ctype.from_binary(val, protocol_version)
                      for ctype, val in zip(coltypes, row))
                for row in rows]
        except Exception:
            # Decode again column by column so the error names the column that failed.
            for row in rows:
                for i in range(len(row)):
                    try:
                        coltypes[i].from_binary(row[i], protocol_version)
                    except Exception as e:
                        raise DriverException('Failed decoding result column "%s" of type %s: %s'
                                              % (colnames[i], coltypes[i].cql_parameterized_type(), str(e)))
            # No single column reproduced the failure; re-raise the original error.
            raise
        return paging_state, (colnames, parsed_rows)

    @classmethod
    def recv_results_prepared(cls, f, protocol_version, user_type_map):
        query_id = read_binary_string(f)
        bind_metadata, pk_indexes, result_metadata = cls.recv_prepared_metadata(f, protocol_version, user_type_map)
        return query_id, bind_metadata, pk_indexes, result_metadata

    @classmethod
    def recv_results_metadata(cls, f, user_type_map):
        flags = read_int(f)
        colcount = read_int(f)

        if flags & cls._HAS_MORE_PAGES_FLAG:
            paging_state = read_binary_longstring(f)
        else:
            paging_state = None

        no_meta = bool(flags & cls._NO_METADATA_FLAG)
        if no_meta:
            return paging_state, []

        glob_tblspec = bool(flags & cls._FLAGS_GLOBAL_TABLES_SPEC)
        if glob_tblspec:
            ksname = read_string(f)
            cfname = read_string(f)
        column_metadata = []
        for _ in range(colcount):
            if glob_tblspec:
                colksname = ksname
                colcfname = cfname
            else:
                colksname = read_string(f)
                colcfname = read_string(f)
            colname = read_string(f)
            coltype = cls.read_type(f, user_type_map)
            column_metadata.append((colksname, colcfname, colname, coltype))
        return paging_state, column_metadata

    @classmethod
    def recv_prepared_metadata(cls, f, protocol_version, user_type_map):
        flags = read_int(f)
        colcount = read_int(f)
        pk_indexes = None
        if protocol_version >= 4:
            num_pk_indexes = read_int(f)
            pk_indexes = [read_short(f) for _ in range(num_pk_indexes)]

        glob_tblspec = bool(flags & cls._FLAGS_GLOBAL_TABLES_SPEC)
        if glob_tblspec:
            ksname = read_string(f)
            cfname = read_string(f)
        bind_metadata = []
        for _ in range(colcount):
            if glob_tblspec:
                colksname = ksname
                colcfname = cfname
            else:
                colksname = read_string(f)
                colcfname = read_string(f)
            colname = read_string(f)
            coltype = cls.read_type(f, user_type_map)
            bind_metadata.append(ColumnMetadata(colksname, colcfname, colname, coltype))

        if protocol_version >= 2:
            _, result_metadata = cls.recv_results_metadata(f, user_type_map)
            return bind_metadata, pk_indexes, result_metadata
        else:
            return bind_metadata, pk_indexes, None

    @classmethod
    def recv_results_schema_change(cls, f, protocol_version):
        return EventMessage.recv_schema_change(f, protocol_version)

    @classmethod
    def read_type(cls, f, user_type_map):
        optid = read_short(f)
        try:
            typeclass = cls.type_codes[optid]
        except KeyError:
            raise NotSupportedError("Unknown data type code 0x%04x. Have to skip"
                                    " entire result set."
                                    % (optid,))
        if typeclass in (ListType, SetType):
            subtype = cls.read_type(f, user_type_map)
            typeclass = typeclass.apply_parameters((subtype,))
        elif typeclass == MapType:
            keysubtype = cls.read_type(f, user_type_map)
            valsubtype = cls.read_type(f, user_type_map)
            typeclass = typeclass.apply_parameters((keysubtype, valsubtype))
        elif typeclass == TupleType:
            num_items = read_short(f)
            types = tuple(cls.read_type(f, user_type_map) for _ in range(num_items))
            typeclass = typeclass.apply_parameters(types)
        elif typeclass == UserType:
            ks = read_string(f)
            udt_name = read_string(f)
            num_fields = read_short(f)
            names, types = zip(*((read_string(f), cls.read_type(f, user_type_map))
                                 for _ in range(num_fields)))
            specialized_type = typeclass.make_udt_class(ks, udt_name, names, types)
            specialized_type.mapped_class = user_type_map.get(ks, {}).get(udt_name)
            typeclass = specialized_type
        elif typeclass == CUSTOM_TYPE:
            classname = read_string(f)
            typeclass = lookup_casstype(classname)

        return typeclass

    @staticmethod
    def recv_row(f, colcount):
        return [read_value(f) for _ in range(colcount)]


class PrepareMessage(_MessageType):
    opcode = 0x09
    name = 'PREPARE'

    def __init__(self, query):
        self.query = query

    def send_body(self, f, protocol_version):
        write_longstring(f, self.query)


class ExecuteMessage(_MessageType):
    opcode = 0x0A
    name = 'EXECUTE'

    def __init__(self, query_id, query_params, consistency_level,
                 serial_consistency_level=None, fetch_size=None,
                 paging_state=None, timestamp=None, skip_meta=False):
        self.query_id = query_id
        self.query_params = query_params
        self.consistency_level = consistency_level
        self.serial_consistency_level = serial_consistency_level
        self.fetch_size = fetch_size
        self.paging_state = paging_state
        self.timestamp = timestamp
        self.skip_meta = skip_meta

    def send_body(self, f, protocol_version):
        write_string(f, self.query_id)
        if protocol_version == 1:
            if self.serial_consistency_level:
                raise UnsupportedOperation(
                    "Serial consistency levels require the use of protocol version "
                    "2 or higher. Consider setting Cluster.protocol_version to 2 "
                    "to support serial consistency levels.")
            if self.fetch_size or self.paging_state:
                raise UnsupportedOperation(
                    "Automatic query paging may only be used with protocol version "
                    "2 or higher. Consider setting Cluster.protocol_version to 2.")
            write_short(f, len(self.query_params))
            for param in self.query_params:
                write_value(f, param)
            write_consistency_level(f, self.consistency_level)
        else:
            write_consistency_level(f, self.consistency_level)
            flags = _VALUES_FLAG
            if self.serial_consistency_level:
                flags |= _WITH_SERIAL_CONSISTENCY_FLAG
            if self.fetch_size:
                flags |= _PAGE_SIZE_FLAG
            if self.paging_state:
                flags |= _WITH_PAGING_STATE_FLAG
            if self.timestamp is not None:
                if protocol_version >= 3:
                    flags |= _PROTOCOL_TIMESTAMP
                else:
                    raise UnsupportedOperation(
                        "Protocol-level timestamps may only be used with protocol version "
                        "3 or higher. Consider setting Cluster.protocol_version to 3.")
            if self.skip_meta:
                flags |= _SKIP_METADATA_FLAG
            write_byte(f, flags)
            write_short(f, len(self.query_params))
            for param in self.query_params:
                write_value(f, param)
            if self.fetch_size:
                write_int(f, self.fetch_size)
            if self.paging_state:
                write_longstring(f, self.paging_state)
            if self.serial_consistency_level:
                write_consistency_level(f, self.serial_consistency_level)
            if self.timestamp is not None:
                write_long(f, self.timestamp)


class BatchMessage(_MessageType):
    opcode = 0x0D
    name = 'BATCH'

    def __init__(self, batch_type, queries, consistency_level,
                 serial_consistency_level=None, timestamp=None):
        self.batch_type = batch_type
        self.queries = queries
        self.consistency_level = consistency_level
        self.serial_consistency_level = serial_consistency_level
        self.timestamp = timestamp

    def send_body(self, f, protocol_version):
        write_byte(f, self.batch_type.value)
        write_short(f, len(self.queries))
        for prepared, string_or_query_id, params in self.queries:
            if not prepared:
                write_byte(f, 0)
                write_longstring(f, string_or_query_id)
            else:
                write_byte(f, 1)
                write_short(f, len(string_or_query_id))
                f.write(string_or_query_id)
            write_short(f, len(params))
            for param in params:
                write_value(f, param)

        write_consistency_level(f, self.consistency_level)
        if protocol_version >= 3:
            flags = 0
            if self.serial_consistency_level:
                flags |= _WITH_SERIAL_CONSISTENCY_FLAG
            if self.timestamp is not None:
                flags |= _PROTOCOL_TIMESTAMP
            write_byte(f, flags)

            if self.serial_consistency_level:
                write_consistency_level(f, self.serial_consistency_level)
            if self.timestamp is not None:
                write_long(f, self.timestamp)


known_event_types = frozenset((
    'TOPOLOGY_CHANGE',
    'STATUS_CHANGE',
    'SCHEMA_CHANGE'
))


class RegisterMessage(_MessageType):
    opcode = 0x0B
    name = 'REGISTER'

    def __init__(self, event_list):
        self.event_list = event_list

    def send_body(self, f, protocol_version):
        write_stringlist(f, self.event_list)


class EventMessage(_MessageType):
    opcode = 0x0C
    name = 'EVENT'

    def __init__(self, event_type, event_args):
        self.event_type = event_type
        self.event_args = event_args

    @classmethod
    def recv_body(cls, f, protocol_version, *args):
        event_type = read_string(f).upper()
        if event_type in known_event_types:
            read_method = getattr(cls, 'recv_' + event_type.lower())
            return cls(event_type=event_type, event_args=read_method(f, protocol_version))
        raise NotSupportedError('Unknown event type %r' % event_type)

    @classmethod
    def recv_topology_change(cls, f, protocol_version):
        # "NEW_NODE" or "REMOVED_NODE"
        change_type = read_string(f)
        address = read_inet(f)
        return dict(change_type=change_type, address=address)

    @classmethod
    def recv_status_change(cls, f, protocol_version):
        # "UP" or "DOWN"
        change_type = read_string(f)
        address = read_inet(f)
        return dict(change_type=change_type, address=address)

    @classmethod
    def recv_schema_change(cls, f, protocol_version):
        # "CREATED", "DROPPED", or "UPDATED"
        change_type = read_string(f)
        if protocol_version >= 3:
            target = read_string(f)
            keyspace = read_string(f)
            event = {'target_type': target, 'change_type': change_type, 'keyspace': keyspace}
            if target != SchemaTargetType.KEYSPACE:
                target_name = read_string(f)
                if target == SchemaTargetType.FUNCTION:
                    event['function'] = UserFunctionDescriptor(target_name, [read_string(f) for _ in range(read_short(f))])
                elif target == SchemaTargetType.AGGREGATE:
                    event['aggregate'] = UserAggregateDescriptor(target_name, [read_string(f) for _ in range(read_short(f))])
                else:
                    event[target.lower()] = target_name
        else:
            keyspace = read_string(f)
            table = read_string(f)
            if table:
                event = {'target_type': SchemaTargetType.TABLE, 'change_type': change_type, 'keyspace': keyspace, 'table': table}
            else:
                event = {'target_type': SchemaTargetType.KEYSPACE, 'change_type': change_type, 'keyspace': keyspace}
        return event


class _ProtocolHandler(object):
    """
    _ProtocolHandler handles encoding and decoding messages.

    This class can be specialized to compose Handlers which implement alternative
    result decoding or type deserialization. Class definitions are passed to :class:`cassandra.cluster.Cluster`
    on initialization.

    Contracted class methods are :meth:`_ProtocolHandler.encode_message` and :meth:`_ProtocolHandler.decode_message`.
    """

    message_types_by_opcode = _message_types_by_opcode.copy()
    """
    Default mapping of opcode to Message implementation. The default ``decode_message`` implementation uses
    this to instantiate a message and populate it using ``recv_body``. This mapping can be updated to inject
    specialized result decoding implementations.
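
    A minimal sketch of the copy-then-override pattern this mapping supports,
    using hypothetical stand-in classes (``Ready``, ``Result``, ``FastResult``,
    ``BaseHandler``, ``FastHandler``) rather than the driver's own types. The
    subclass copies the inherited mapping before replacing one opcode entry, so
    the base class keeps its default decoder:

    ```python
    # Hypothetical stand-ins for message classes; only ``opcode`` matters here.
    class Ready(object):
        opcode = 0x02


    class Result(object):
        opcode = 0x08


    class FastResult(Result):
        pass


    class BaseHandler(object):
        # Default opcode -> message class mapping.
        message_types_by_opcode = {Ready.opcode: Ready, Result.opcode: Result}

        @classmethod
        def message_class_for(cls, opcode):
            return cls.message_types_by_opcode[opcode]


    class FastHandler(BaseHandler):
        # Copy first so that BaseHandler's mapping is not mutated.
        message_types_by_opcode = dict(BaseHandler.message_types_by_opcode)
        message_types_by_opcode[FastResult.opcode] = FastResult
    ```

    ``FastHandler.message_class_for(0x08)`` returns ``FastResult`` while
    ``BaseHandler.message_class_for(0x08)`` still returns ``Result``; the
    Cython handlers built by ``cython_protocol_handler`` follow the same shape.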
    """

    @classmethod
    def encode_message(cls, msg, stream_id, protocol_version, compressor, allow_beta_protocol_version):
        """
        Encodes a message using the specified frame parameters and compressor.

        :param msg: the message, typically of cassandra.protocol._MessageType, generated by the driver
        :param stream_id: protocol stream id for the frame header
        :param protocol_version: version for the frame header, also used for encoding the contents
        :param compressor: optional compression function to be used on the body
        :param allow_beta_protocol_version: if true, the USE_BETA flag is set in the frame header
        """
        flags = 0
        body = io.BytesIO()
        if msg.custom_payload:
            if protocol_version < 4:
                raise UnsupportedOperation("Custom key/value payloads can only be used with protocol version 4 or higher")
            flags |= CUSTOM_PAYLOAD_FLAG
            write_bytesmap(body, msg.custom_payload)
        msg.send_body(body, protocol_version)
        body = body.getvalue()

        if compressor and len(body) > 0:
            body = compressor(body)
            flags |= COMPRESSED_FLAG

        if msg.tracing:
            flags |= TRACING_FLAG

        if allow_beta_protocol_version:
            flags |= USE_BETA_FLAG

        buff = io.BytesIO()
        cls._write_header(buff, protocol_version, flags, stream_id, msg.opcode, len(body))
        buff.write(body)

        return buff.getvalue()

    @staticmethod
    def _write_header(f, version, flags, stream_id, opcode, length):
        """
        Write a CQL protocol frame header.
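
        A minimal standalone sketch of the resulting byte layout, assuming
        ``header_pack`` and ``v3_header_pack`` from ``cassandra.marshal`` pack
        ``>BBbB`` and ``>BBhB`` respectively (a one-byte stream id before
        protocol v3, a two-byte stream id from v3 on), followed by the body
        length as a big-endian 32-bit int. ``frame_header`` is a hypothetical
        helper, not part of the driver:

        ```python
        import struct

        # Assumed equivalents of header_pack (v1/v2) and v3_header_pack (v3+).
        V1_HEADER = struct.Struct('>BBbB')  # version, flags, stream id (int8), opcode
        V3_HEADER = struct.Struct('>BBhB')  # version, flags, stream id (int16), opcode
        LENGTH = struct.Struct('>i')        # body length, as written by write_int


        def frame_header(version, flags, stream_id, opcode, length):
            pack = V3_HEADER.pack if version >= 3 else V1_HEADER.pack
            return pack(version, flags, stream_id, opcode) + LENGTH.pack(length)
        ```

        Under these assumptions a v4 QUERY header (opcode ``0x07``) is 9 bytes,
        while a v2 header is 8 bytes.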
        """
        pack = v3_header_pack if version >= 3 else header_pack
        f.write(pack(version, flags, stream_id, opcode))
        write_int(f, length)

    @classmethod
    def decode_message(cls, protocol_version, user_type_map, stream_id, flags, opcode, body,
                       decompressor, result_metadata):
        """
        Decodes a native protocol message body.

        :param protocol_version: version to use for decoding contents
        :param user_type_map: map[keyspace name] = map[type name] = custom type to instantiate when deserializing this type
        :param stream_id: native protocol stream id from the frame header
        :param flags: native protocol flags bitmap from the header
        :param opcode: native protocol opcode from the header
        :param body: frame body
        :param decompressor: optional decompression function to inflate the body
        :param result_metadata: column metadata to use when a RESULT frame carries none (for example, an EXECUTE sent with skip_meta)
        :return: a message decoded from the body and frame attributes
        """
        if flags & COMPRESSED_FLAG:
            if decompressor is None:
                raise RuntimeError("No de-compressor available for compressed frame!")
            body = decompressor(body)
            flags ^= COMPRESSED_FLAG

        body = io.BytesIO(body)
        if flags & TRACING_FLAG:
            trace_id = UUID(bytes=body.read(16))
            flags ^= TRACING_FLAG
        else:
            trace_id = None

        if flags & WARNING_FLAG:
            warnings = read_stringlist(body)
            flags ^= WARNING_FLAG
        else:
            warnings = None

        if flags & CUSTOM_PAYLOAD_FLAG:
            custom_payload = read_bytesmap(body)
            flags ^= CUSTOM_PAYLOAD_FLAG
        else:
            custom_payload = None

        flags &= USE_BETA_MASK  # will only be set if we asserted it in connection establishment

        if flags:
            log.warning("Unknown protocol flags set: %02x. May cause problems.", flags)

        msg_class = cls.message_types_by_opcode[opcode]
        msg = msg_class.recv_body(body, protocol_version, user_type_map, result_metadata)
        msg.stream_id = stream_id
        msg.trace_id = trace_id
        msg.custom_payload = custom_payload
        msg.warnings = warnings

        if msg.warnings:
            for w in msg.warnings:
                log.warning("Server warning: %s", w)

        return msg


def cython_protocol_handler(colparser):
    """
    Given a column parser to deserialize ResultMessages, return a suitable
    Cython-based protocol handler.

    There are three Cython-based protocol handlers:

        - obj_parser.ListParser
            decodes result messages into a list of tuples

        - obj_parser.LazyParser
            decodes result messages lazily by returning an iterator

        - numpy_parser.NumpyParser
            decodes result messages into NumPy arrays

    The default is to use obj_parser.ListParser.
    """
    from cassandra.row_parser import make_recv_results_rows

    class FastResultMessage(ResultMessage):
        """
        Cython version of Result Message that has a faster implementation of
        recv_results_rows.
        """
        # type_codes = ResultMessage.type_codes.copy()
        code_to_type = dict((v, k) for k, v in ResultMessage.type_codes.items())
        recv_results_rows = classmethod(make_recv_results_rows(colparser))

    class CythonProtocolHandler(_ProtocolHandler):
        """
        Use FastResultMessage to decode query result messages.
        """
        my_opcodes = _ProtocolHandler.message_types_by_opcode.copy()
        my_opcodes[FastResultMessage.opcode] = FastResultMessage
        message_types_by_opcode = my_opcodes

        col_parser = colparser

    return CythonProtocolHandler


if HAVE_CYTHON:
    from cassandra.obj_parser import ListParser, LazyParser
    ProtocolHandler = cython_protocol_handler(ListParser())
    LazyProtocolHandler = cython_protocol_handler(LazyParser())
else:
    # Use Python-based ProtocolHandler
    ProtocolHandler = _ProtocolHandler
    LazyProtocolHandler = None


if HAVE_CYTHON and HAVE_NUMPY:
    from cassandra.numpy_parser import NumpyParser
    NumpyProtocolHandler = cython_protocol_handler(NumpyParser())
else:
    NumpyProtocolHandler = None


def read_byte(f):
    return int8_unpack(f.read(1))


def write_byte(f, b):
    f.write(int8_pack(b))


def read_int(f):
    return int32_unpack(f.read(4))


def write_int(f, i):
    f.write(int32_pack(i))


def write_long(f, i):
    f.write(uint64_pack(i))


def read_short(f):
    return uint16_unpack(f.read(2))


def write_short(f, s):
    f.write(uint16_pack(s))


def read_consistency_level(f):
    return read_short(f)


def write_consistency_level(f, cl):
    write_short(f, cl)


def read_string(f):
    size = read_short(f)
    contents = f.read(size)
    return contents.decode('utf8')


def read_binary_string(f):
    size = read_short(f)
    contents = f.read(size)
    return contents


def write_string(f, s):
    if isinstance(s, six.text_type):
        s = s.encode('utf8')
    write_short(f, len(s))
    f.write(s)


def read_binary_longstring(f):
    size = read_int(f)
    contents = f.read(size)
    return contents


def read_longstring(f):
    return read_binary_longstring(f).decode('utf8')


def write_longstring(f, s):
    if isinstance(s, six.text_type):
        s = s.encode('utf8')
    write_int(f, len(s))
    f.write(s)


def read_stringlist(f):
    numstrs = read_short(f)
    return [read_string(f) for _ in range(numstrs)]


def write_stringlist(f, stringlist):
    write_short(f, len(stringlist))
    for s in stringlist:
        write_string(f, s)


def read_stringmap(f):
    numpairs = read_short(f)
    strmap = {}
    for _ in range(numpairs):
        k = read_string(f)
        strmap[k] = read_string(f)
    return strmap


def write_stringmap(f, strmap):
    write_short(f, len(strmap))
    for k, v in strmap.items():
        write_string(f, k)
        write_string(f, v)


def read_bytesmap(f):
    numpairs = read_short(f)
    bytesmap = {}
    for _ in range(numpairs):
        k = read_string(f)
        bytesmap[k] = read_value(f)
    return bytesmap


def write_bytesmap(f, bytesmap):
    write_short(f, len(bytesmap))
    for k, v in bytesmap.items():
        write_string(f, k)
        write_value(f, v)


def read_stringmultimap(f):
    numkeys = read_short(f)
    strmmap = {}
    for _ in range(numkeys):
        k = read_string(f)
        strmmap[k] = read_stringlist(f)
    return strmmap


def write_stringmultimap(f, strmmap):
    write_short(f, len(strmmap))
    for k, v in strmmap.items():
        write_string(f, k)
        write_stringlist(f, v)


def read_error_code_map(f):
    numpairs = read_int(f)
    error_code_map = {}
    for _ in range(numpairs):
        endpoint = read_inet_addr_only(f)
        error_code_map[endpoint] = read_short(f)
    return error_code_map


def read_value(f):
    size = read_int(f)
    if size < 0:
        return None
    return f.read(size)


def write_value(f, v):
    if v is None:
        write_int(f, -1)
    elif v is _UNSET_VALUE:
        write_int(f, -2)
    else:
        write_int(f, len(v))
        f.write(v)


def read_inet_addr_only(f):
    size = read_byte(f)
    addrbytes = f.read(size)
    if size == 4:
        addrfam = socket.AF_INET
    elif size == 16:
        addrfam = socket.AF_INET6
    else:
        raise InternalError("bad inet address: %r" % (addrbytes,))
    return util.inet_ntop(addrfam, addrbytes)


def read_inet(f):
    addr = read_inet_addr_only(f)
    port = read_int(f)
    return (addr, port)


def write_inet(f, addrtuple):
    addr, port = addrtuple
    if ':' in addr:
        addrfam = socket.AF_INET6
    else:
        addrfam = socket.AF_INET
    addrbytes = util.inet_pton(addrfam, addr)
    write_byte(f, len(addrbytes))
    f.write(addrbytes)
    write_int(f, port)

cassandra-driver-3.7.1/cassandra/pool.py
# Copyright 2013-2016 DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Connection pooling and host management.
"""

import logging
import socket
import time
from threading import Lock, RLock, Condition
import weakref
try:
    from weakref import WeakSet
except ImportError:
    from cassandra.util import WeakSet  # NOQA

from cassandra import AuthenticationFailed
from cassandra.connection import ConnectionException
from cassandra.policies import HostDistance

log = logging.getLogger(__name__)


class NoConnectionsAvailable(Exception):
    """
    All existing connections to a given host are busy, or there are
    no open connections.
    """
    pass


class Host(object):
    """
    Represents a single Cassandra node.
    """

    address = None
    """
    The IP address of the node. This is the RPC address the driver uses when connecting to the node.
    """

    broadcast_address = None
    """
    Broadcast address configured for the node, *if available* ('peer' in the system.peers table).
    This is not present in the ``system.local`` table for older versions of Cassandra. It is also not queried if
    :attr:`~.Cluster.token_metadata_enabled` is ``False``.
    """

    listen_address = None
    """
    Listen address configured for the node, *if available*. This is only available in the ``system.local`` table for newer
    versions of Cassandra. It is also not queried if :attr:`~.Cluster.token_metadata_enabled` is ``False``.
    Usually the same as ``broadcast_address`` unless configured differently in cassandra.yaml.
    """

    conviction_policy = None
    """
    A :class:`~.ConvictionPolicy` instance for determining when this node should
    be marked up or down.
    """

    is_up = None
    """
    :const:`True` if the node is considered up, :const:`False` if it is
    considered down, and :const:`None` if it is not known whether the node is
    up or down.
    """

    release_version = None
    """
    release_version as queried from the control connection system tables.
    """

    dse_version = None
    """
    dse_version as queried from the control connection system tables. Only populated when connecting to
    DSE with this property available. Not queried if :attr:`~.Cluster.token_metadata_enabled` is ``False``.
    """

    dse_workload = None
    """
    DSE workload queried from the control connection system tables. Only populated when connecting to
    DSE with this property available. Not queried if :attr:`~.Cluster.token_metadata_enabled` is ``False``.
    """

    _datacenter = None
    _rack = None
    _reconnection_handler = None
    lock = None

    _currently_handling_node_up = False

    def __init__(self, inet_address, conviction_policy_factory, datacenter=None, rack=None):
        if inet_address is None:
            raise ValueError("inet_address may not be None")
        if conviction_policy_factory is None:
            raise ValueError("conviction_policy_factory may not be None")

        self.address = inet_address
        self.conviction_policy = conviction_policy_factory(self)
        self.set_location_info(datacenter, rack)
        self.lock = RLock()

    @property
    def datacenter(self):
        """ The datacenter the node is in. """
        return self._datacenter

    @property
    def rack(self):
        """ The rack the node is in. """
        return self._rack

    def set_location_info(self, datacenter, rack):
        """
        Sets the datacenter and rack for this node. Intended for internal
        use (by the control connection, which periodically checks the
        ring topology) only.
""" self._datacenter = datacenter self._rack = rack def set_up(self): if not self.is_up: log.debug("Host %s is now marked up", self.address) self.conviction_policy.reset() self.is_up = True def set_down(self): self.is_up = False def signal_connection_failure(self, connection_exc): return self.conviction_policy.add_failure(connection_exc) def is_currently_reconnecting(self): return self._reconnection_handler is not None def get_and_set_reconnection_handler(self, new_handler): """ Atomically replaces the reconnection handler for this host. Intended for internal use only. """ with self.lock: old = self._reconnection_handler self._reconnection_handler = new_handler return old def __eq__(self, other): return self.address == other.address def __hash__(self): return hash(self.address) def __lt__(self, other): return self.address < other.address def __str__(self): return str(self.address) def __repr__(self): dc = (" %s" % (self._datacenter,)) if self._datacenter else "" return "<%s: %s%s>" % (self.__class__.__name__, self.address, dc) class _ReconnectionHandler(object): """ Abstract class for attempting reconnections with a given schedule and scheduler. 
""" _cancelled = False def __init__(self, scheduler, schedule, callback, *callback_args, **callback_kwargs): self.scheduler = scheduler self.schedule = schedule self.callback = callback self.callback_args = callback_args self.callback_kwargs = callback_kwargs def start(self): if self._cancelled: log.debug("Reconnection handler was cancelled before starting") return first_delay = next(self.schedule) self.scheduler.schedule(first_delay, self.run) def run(self): if self._cancelled: return conn = None try: conn = self.try_reconnect() except Exception as exc: try: next_delay = next(self.schedule) except StopIteration: # the schedule has been exhausted next_delay = None # call on_exception for logging purposes even if next_delay is None if self.on_exception(exc, next_delay): if next_delay is None: log.warning( "Will not continue to retry reconnection attempts " "due to an exhausted retry schedule") else: self.scheduler.schedule(next_delay, self.run) else: if not self._cancelled: self.on_reconnection(conn) self.callback(*(self.callback_args), **(self.callback_kwargs)) finally: if conn: conn.close() def cancel(self): self._cancelled = True def try_reconnect(self): """ Subclasses must implement this method. It should attempt to open a new Connection and return it; if a failure occurs, an Exception should be raised. """ raise NotImplementedError() def on_reconnection(self, connection): """ Called when a new Connection is successfully opened. Nothing is done by default. """ pass def on_exception(self, exc, next_delay): """ Called when an Exception is raised when trying to connect. `exc` is the Exception that was raised and `next_delay` is the number of seconds (as a float) that the handler will wait before attempting to connect again. Subclasses should return :const:`False` if no more attempts to connection should be made, :const:`True` otherwise. The default behavior is to always retry unless the error is an :exc:`.AuthenticationFailed` instance. 
""" if isinstance(exc, AuthenticationFailed): return False else: return True class _HostReconnectionHandler(_ReconnectionHandler): def __init__(self, host, connection_factory, is_host_addition, on_add, on_up, *args, **kwargs): _ReconnectionHandler.__init__(self, *args, **kwargs) self.is_host_addition = is_host_addition self.on_add = on_add self.on_up = on_up self.host = host self.connection_factory = connection_factory def try_reconnect(self): return self.connection_factory() def on_reconnection(self, connection): log.info("Successful reconnection to %s, marking node up if it isn't already", self.host) if self.is_host_addition: self.on_add(self.host) else: self.on_up(self.host) def on_exception(self, exc, next_delay): if isinstance(exc, AuthenticationFailed): return False else: log.warning("Error attempting to reconnect to %s, scheduling retry in %s seconds: %s", self.host, next_delay, exc) log.debug("Reconnection error details", exc_info=True) return True class HostConnection(object): """ When using v3 of the native protocol, this is used instead of a connection pool per host (HostConnectionPool) due to the increased in-flight capacity of individual connections. """ host = None host_distance = None is_shutdown = False _session = None _connection = None _lock = None _keyspace = None def __init__(self, host, host_distance, session): self.host = host self.host_distance = host_distance self._session = weakref.proxy(session) self._lock = Lock() # this is used in conjunction with the connection streams. Not using the connection lock because the connection can be replaced in the lifetime of the pool. 
        self._stream_available_condition = Condition(self._lock)
        self._is_replacing = False

        if host_distance == HostDistance.IGNORED:
            log.debug("Not opening connection to ignored host %s", self.host)
            return
        elif host_distance == HostDistance.REMOTE and not session.cluster.connect_to_remote_hosts:
            log.debug("Not opening connection to remote host %s", self.host)
            return

        log.debug("Initializing connection for host %s", self.host)
        self._connection = session.cluster.connection_factory(host.address)
        self._keyspace = session.keyspace
        if self._keyspace:
            self._connection.set_keyspace_blocking(self._keyspace)
        log.debug("Finished initializing connection for host %s", self.host)

    def borrow_connection(self, timeout):
        if self.is_shutdown:
            raise ConnectionException(
                "Pool for %s is shutdown" % (self.host,), self.host)

        conn = self._connection
        if not conn:
            raise NoConnectionsAvailable()

        start = time.time()
        remaining = timeout
        while True:
            with conn.lock:
                if conn.in_flight <= conn.max_request_id:
                    conn.in_flight += 1
                    return conn, conn.get_request_id()
            if timeout is not None:
                remaining = timeout - time.time() + start
                if remaining < 0:
                    break
            with self._stream_available_condition:
                self._stream_available_condition.wait(remaining)

        raise NoConnectionsAvailable("All request IDs are currently in use")

    def return_connection(self, connection):
        with connection.lock:
            connection.in_flight -= 1
        with self._stream_available_condition:
            self._stream_available_condition.notify()

        if (connection.is_defunct or connection.is_closed) and not connection.signaled_error:
            log.debug("Defunct or closed connection (%s) returned to pool, potentially "
                      "marking host %s as down", id(connection), self.host)
            is_down = self._session.cluster.signal_connection_failure(
                self.host, connection.last_error, is_host_addition=False)
            connection.signaled_error = True
            if is_down:
                self.shutdown()
            else:
                self._connection = None
                with self._lock:
                    if self._is_replacing:
                        return
                    self._is_replacing = True
                self._session.submit(self._replace, connection)

    def _replace(self, connection):
        log.debug("Replacing connection (%s) to %s", id(connection), self.host)
        try:
            conn = self._session.cluster.connection_factory(self.host.address)
            if self._keyspace:
                conn.set_keyspace_blocking(self._keyspace)
            self._connection = conn
        except Exception:
            log.warning("Failed reconnecting %s. Retrying.", self.host.address)
            self._session.submit(self._replace, connection)
        else:
            with self._lock:
                self._is_replacing = False
                self._stream_available_condition.notify()

    def shutdown(self):
        with self._lock:
            if self.is_shutdown:
                return
            else:
                self.is_shutdown = True
            self._stream_available_condition.notify_all()

        if self._connection:
            self._connection.close()

    def _set_keyspace_for_all_conns(self, keyspace, callback):
        if self.is_shutdown or not self._connection:
            return

        def connection_finished_setting_keyspace(conn, error):
            self.return_connection(conn)
            errors = [] if not error else [error]
            callback(self, errors)

        self._keyspace = keyspace
        self._connection.set_keyspace_async(keyspace, connection_finished_setting_keyspace)

    def get_connections(self):
        c = self._connection
        return [c] if c else []

    def get_state(self):
        connection = self._connection
        open_count = 1 if connection and not (connection.is_closed or connection.is_defunct) else 0
        in_flights = [connection.in_flight] if connection else []
        return {'shutdown': self.is_shutdown, 'open_count': open_count, 'in_flights': in_flights}

    @property
    def open_count(self):
        connection = self._connection
        return 1 if connection and not (connection.is_closed or connection.is_defunct) else 0

_MAX_SIMULTANEOUS_CREATION = 1
_MIN_TRASH_INTERVAL = 10


class HostConnectionPool(object):
    """
    Used to pool connections to a host for v1 and v2 native protocol.
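    The selection rule in :meth:`borrow_connection` can be illustrated with plain
    objects. In this sketch ``FakeConn`` and ``pick_connection`` are hypothetical
    names. Locking, waiting for a free connection, and the actual connection
    creation are left out. The least busy connection is chosen. If that connection
    still has a free request ID, its ``in_flight`` count is incremented. A new
    connection is requested when that count has reached ``max_reqs`` and the pool
    holds fewer than ``max_conns`` connections.

```python
class FakeConn(object):
    # Hypothetical connection stub: only the counters borrow_connection() reads.
    def __init__(self, in_flight, max_request_id=127):
        self.in_flight = in_flight
        self.max_request_id = max_request_id


def pick_connection(conns, max_reqs, max_conns):
    # Choose the least busy connection, as borrow_connection() does with min().
    least_busy = min(conns, key=lambda c: c.in_flight)
    if least_busy.in_flight >= least_busy.max_request_id:
        return None, False  # the real pool would wait for a connection here
    least_busy.in_flight += 1
    # Ask for another connection when this one is saturated and there is room.
    should_spawn = least_busy.in_flight >= max_reqs and len(conns) < max_conns
    return least_busy, should_spawn
```

    A ``None`` result corresponds to the point where the real pool calls
    ``_wait_for_conn()``.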
""" host = None host_distance = None is_shutdown = False open_count = 0 _scheduled_for_creation = 0 _next_trash_allowed_at = 0 _keyspace = None def __init__(self, host, host_distance, session): self.host = host self.host_distance = host_distance self._session = weakref.proxy(session) self._lock = RLock() self._conn_available_condition = Condition() log.debug("Initializing new connection pool for host %s", self.host) core_conns = session.cluster.get_core_connections_per_host(host_distance) self._connections = [session.cluster.connection_factory(host.address) for i in range(core_conns)] self._keyspace = session.keyspace if self._keyspace: for conn in self._connections: conn.set_keyspace_blocking(self._keyspace) self._trash = set() self._next_trash_allowed_at = time.time() self.open_count = core_conns log.debug("Finished initializing new connection pool for host %s", self.host) def borrow_connection(self, timeout): if self.is_shutdown: raise ConnectionException( "Pool for %s is shutdown" % (self.host,), self.host) conns = self._connections if not conns: # handled specially just for simpler code log.debug("Detected empty pool, opening core conns to %s", self.host) core_conns = self._session.cluster.get_core_connections_per_host(self.host_distance) with self._lock: # we check the length of self._connections again # along with self._scheduled_for_creation while holding the lock # in case multiple threads hit this condition at the same time to_create = core_conns - (len(self._connections) + self._scheduled_for_creation) for i in range(to_create): self._scheduled_for_creation += 1 self._session.submit(self._create_new_connection) # in_flight is incremented by wait_for_conn conn = self._wait_for_conn(timeout) return conn else: # note: it would be nice to push changes to these config settings # to pools instead of doing a new lookup on every # borrow_connection() call max_reqs = self._session.cluster.get_max_requests_per_connection(self.host_distance) max_conns = 
self._session.cluster.get_max_connections_per_host(self.host_distance) least_busy = min(conns, key=lambda c: c.in_flight) request_id = None # to avoid another thread closing this connection while # trashing it (through the return_connection process), hold # the connection lock from this point until we've incremented # its in_flight count need_to_wait = False with least_busy.lock: if least_busy.in_flight < least_busy.max_request_id: least_busy.in_flight += 1 request_id = least_busy.get_request_id() else: # once we release the lock, wait for another connection need_to_wait = True if need_to_wait: # wait_for_conn will increment in_flight on the conn least_busy, request_id = self._wait_for_conn(timeout) # if we have too many requests on this connection but we still # have space to open a new connection against this host, go ahead # and schedule the creation of a new connection if least_busy.in_flight >= max_reqs and len(self._connections) < max_conns: self._maybe_spawn_new_connection() return least_busy, request_id def _maybe_spawn_new_connection(self): with self._lock: if self._scheduled_for_creation >= _MAX_SIMULTANEOUS_CREATION: return if self.open_count >= self._session.cluster.get_max_connections_per_host(self.host_distance): return self._scheduled_for_creation += 1 log.debug("Submitting task for creation of new Connection to %s", self.host) self._session.submit(self._create_new_connection) def _create_new_connection(self): try: self._add_conn_if_under_max() except (ConnectionException, socket.error) as exc: log.warning("Failed to create new connection to %s: %s", self.host, exc) except Exception: log.exception("Unexpectedly failed to create new connection") finally: with self._lock: self._scheduled_for_creation -= 1 def _add_conn_if_under_max(self): max_conns = self._session.cluster.get_max_connections_per_host(self.host_distance) with self._lock: if self.is_shutdown: return True if self.open_count >= max_conns: return True self.open_count += 1 log.debug("Going 
to open new connection to host %s", self.host) try: conn = self._session.cluster.connection_factory(self.host.address) if self._keyspace: conn.set_keyspace_blocking(self._session.keyspace) self._next_trash_allowed_at = time.time() + _MIN_TRASH_INTERVAL with self._lock: new_connections = self._connections[:] + [conn] self._connections = new_connections log.debug("Added new connection (%s) to pool for host %s, signaling availablility", id(conn), self.host) self._signal_available_conn() return True except (ConnectionException, socket.error) as exc: log.warning("Failed to add new connection to pool for host %s: %s", self.host, exc) with self._lock: self.open_count -= 1 if self._session.cluster.signal_connection_failure(self.host, exc, is_host_addition=False): self.shutdown() return False except AuthenticationFailed: with self._lock: self.open_count -= 1 return False def _await_available_conn(self, timeout): with self._conn_available_condition: self._conn_available_condition.wait(timeout) def _signal_available_conn(self): with self._conn_available_condition: self._conn_available_condition.notify() def _signal_all_available_conn(self): with self._conn_available_condition: self._conn_available_condition.notify_all() def _wait_for_conn(self, timeout): start = time.time() remaining = timeout while remaining > 0: # wait on our condition for the possibility that a connection # is useable self._await_available_conn(remaining) # self.shutdown() may trigger the above Condition if self.is_shutdown: raise ConnectionException("Pool is shutdown") conns = self._connections if conns: least_busy = min(conns, key=lambda c: c.in_flight) with least_busy.lock: if least_busy.in_flight < least_busy.max_request_id: least_busy.in_flight += 1 return least_busy, least_busy.get_request_id() remaining = timeout - (time.time() - start) raise NoConnectionsAvailable() def return_connection(self, connection): with connection.lock: connection.in_flight -= 1 in_flight = connection.in_flight if 
connection.is_defunct or connection.is_closed: if not connection.signaled_error: log.debug("Defunct or closed connection (%s) returned to pool, potentially " "marking host %s as down", id(connection), self.host) is_down = self._session.cluster.signal_connection_failure( self.host, connection.last_error, is_host_addition=False) connection.signaled_error = True if is_down: self.shutdown() else: self._replace(connection) else: if connection in self._trash: with connection.lock: if connection.in_flight == 0: with self._lock: if connection in self._trash: self._trash.remove(connection) log.debug("Closing trashed connection (%s) to %s", id(connection), self.host) connection.close() return core_conns = self._session.cluster.get_core_connections_per_host(self.host_distance) min_reqs = self._session.cluster.get_min_requests_per_connection(self.host_distance) # we can use in_flight here without holding the connection lock # because the fact that in_flight dipped below the min at some # point is enough to start the trashing procedure if len(self._connections) > core_conns and in_flight <= min_reqs and \ time.time() >= self._next_trash_allowed_at: self._maybe_trash_connection(connection) else: self._signal_available_conn() def _maybe_trash_connection(self, connection): core_conns = self._session.cluster.get_core_connections_per_host(self.host_distance) did_trash = False with self._lock: if connection not in self._connections: return if self.open_count > core_conns: did_trash = True self.open_count -= 1 new_connections = self._connections[:] new_connections.remove(connection) self._connections = new_connections with connection.lock: if connection.in_flight == 0: log.debug("Skipping trash and closing unused connection (%s) to %s", id(connection), self.host) connection.close() # skip adding it to the trash if we're already closing it return self._trash.add(connection) if did_trash: self._next_trash_allowed_at = time.time() + _MIN_TRASH_INTERVAL log.debug("Trashed connection (%s) 
to %s", id(connection), self.host) def _replace(self, connection): should_replace = False with self._lock: if connection in self._connections: new_connections = self._connections[:] new_connections.remove(connection) self._connections = new_connections self.open_count -= 1 should_replace = True if should_replace: log.debug("Replacing connection (%s) to %s", id(connection), self.host) connection.close() self._session.submit(self._retrying_replace) else: log.debug("Closing connection (%s) to %s", id(connection), self.host) connection.close() def _retrying_replace(self): replaced = False try: replaced = self._add_conn_if_under_max() except Exception: log.exception("Failed replacing connection to %s", self.host) if not replaced: log.debug("Failed replacing connection to %s. Retrying.", self.host) self._session.submit(self._retrying_replace) def shutdown(self): with self._lock: if self.is_shutdown: return else: self.is_shutdown = True self._signal_all_available_conn() for conn in self._connections: conn.close() self.open_count -= 1 for conn in self._trash: conn.close() def ensure_core_connections(self): if self.is_shutdown: return core_conns = self._session.cluster.get_core_connections_per_host(self.host_distance) with self._lock: to_create = core_conns - (len(self._connections) + self._scheduled_for_creation) for i in range(to_create): self._scheduled_for_creation += 1 self._session.submit(self._create_new_connection) def _set_keyspace_for_all_conns(self, keyspace, callback): """ Asynchronously sets the keyspace for all connections. When all connections have been set, `callback` will be called with two arguments: this pool, and a list of any errors that occurred. 
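        The completion logic can be shown with a synchronous sketch, which is not
        the driver's code. A set of outstanding connections shrinks as each
        per-connection callback fires, and errors are collected along the way. The
        final callback runs exactly once, when the set becomes empty, or
        immediately for an empty pool. ``set_keyspace_for_all`` and
        ``fake_set_keyspace_async`` are hypothetical. The latter stands in for each
        connection's ``set_keyspace_async``, which in the driver reports back when
        the server responds. The callback here receives only the error list rather
        than ``(pool, errors)``.

```python
def set_keyspace_for_all(conns, keyspace, callback, set_keyspace_async):
    # Track which connections have not reported back yet.
    remaining = set(conns)
    errors = []

    if not remaining:
        callback(errors)
        return

    def finished(conn, error):
        remaining.remove(conn)
        if error:
            errors.append(error)
        if not remaining:
            callback(errors)  # fires once, after the last connection reports

    for conn in conns:
        set_keyspace_async(conn, keyspace, finished)


def fake_set_keyspace_async(conn, keyspace, cb):
    # Hypothetical stand-in: connection names starting with 'bad' fail.
    error = ValueError(conn) if conn.startswith('bad') else None
    cb(conn, error)
```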
""" remaining_callbacks = set(self._connections) errors = [] if not remaining_callbacks: callback(self, errors) return def connection_finished_setting_keyspace(conn, error): self.return_connection(conn) remaining_callbacks.remove(conn) if error: errors.append(error) if not remaining_callbacks: callback(self, errors) self._keyspace = keyspace for conn in self._connections: conn.set_keyspace_async(keyspace, connection_finished_setting_keyspace) def get_connections(self): return self._connections def get_state(self): in_flights = [c.in_flight for c in self._connections] return {'shutdown': self.is_shutdown, 'open_count': self.open_count, 'in_flights': in_flights} cassandra-driver-3.7.1/cassandra/query.py0000664000175000017500000011233012766043721023302 0ustar aboudreaultaboudreault00000000000000# Copyright 2013-2016 DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ This module holds classes for working with prepared statements and specifying consistency levels and retry policies for individual queries. 
""" from collections import namedtuple from datetime import datetime, timedelta import re import struct import time import six from six.moves import range, zip from cassandra import ConsistencyLevel, OperationTimedOut from cassandra.util import unix_time_from_uuid1 from cassandra.encoder import Encoder import cassandra.encoder from cassandra.protocol import _UNSET_VALUE from cassandra.util import OrderedDict, _sanitize_identifiers import logging log = logging.getLogger(__name__) UNSET_VALUE = _UNSET_VALUE """ Specifies an unset value when binding a prepared statement. Unset values are ignored, allowing prepared statements to be used without specify See https://issues.apache.org/jira/browse/CASSANDRA-7304 for further details on semantics. .. versionadded:: 2.6.0 Only valid when using native protocol v4+ """ NON_ALPHA_REGEX = re.compile('[^a-zA-Z0-9]') START_BADCHAR_REGEX = re.compile('^[^a-zA-Z0-9]*') END_BADCHAR_REGEX = re.compile('[^a-zA-Z0-9_]*$') _clean_name_cache = {} def _clean_column_name(name): try: return _clean_name_cache[name] except KeyError: clean = NON_ALPHA_REGEX.sub("_", START_BADCHAR_REGEX.sub("", END_BADCHAR_REGEX.sub("", name))) _clean_name_cache[name] = clean return clean def tuple_factory(colnames, rows): """ Returns each row as a tuple Example:: >>> from cassandra.query import tuple_factory >>> session = cluster.connect('mykeyspace') >>> session.row_factory = tuple_factory >>> rows = session.execute("SELECT name, age FROM users LIMIT 1") >>> print rows[0] ('Bob', 42) .. versionchanged:: 2.0.0 moved from ``cassandra.decoder`` to ``cassandra.query`` """ return rows def named_tuple_factory(colnames, rows): """ Returns each row as a `namedtuple `_. This is the default row factory. 
    Example::

        >>> from cassandra.query import named_tuple_factory
        >>> session = cluster.connect('mykeyspace')
        >>> session.row_factory = named_tuple_factory
        >>> rows = session.execute("SELECT name, age FROM users LIMIT 1")
        >>> user = rows[0]

        >>> # you can access fields by their name:
        >>> print("name: %s, age: %d" % (user.name, user.age))
        name: Bob, age: 42

        >>> # or you can access fields by their position (like a tuple)
        >>> name, age = user
        >>> print("name: %s, age: %d" % (name, age))
        name: Bob, age: 42
        >>> name = user[0]
        >>> age = user[1]
        >>> print("name: %s, age: %d" % (name, age))
        name: Bob, age: 42

    .. versionchanged:: 2.0.0
        moved from ``cassandra.decoder`` to ``cassandra.query``
    """
    clean_column_names = map(_clean_column_name, colnames)
    try:
        Row = namedtuple('Row', clean_column_names)
    except Exception:
        clean_column_names = list(map(_clean_column_name, colnames))  # create list because py3 map object will be consumed by first attempt
        log.warning("Failed creating named tuple for results with column names %s (cleaned: %s) "
                    "(see Python 'namedtuple' documentation for details on name rules). "
                    "Results will be returned with positional names. "
                    "Avoid this by choosing different names, using SELECT \"<col name>\" AS aliases, "
                    "or specifying a different row_factory on your Session" %
                    (colnames, clean_column_names))
        Row = namedtuple('Row', _sanitize_identifiers(clean_column_names))

    return [Row(*row) for row in rows]


def dict_factory(colnames, rows):
    """
    Returns each row as a dict.

    Example::

        >>> from cassandra.query import dict_factory
        >>> session = cluster.connect('mykeyspace')
        >>> session.row_factory = dict_factory
        >>> rows = session.execute("SELECT name, age FROM users LIMIT 1")
        >>> print(rows[0])
        {u'age': 42, u'name': u'Bob'}

    .. versionchanged:: 2.0.0
        moved from ``cassandra.decoder`` to ``cassandra.query``
    """
    return [dict(zip(colnames, row)) for row in rows]


def ordered_dict_factory(colnames, rows):
    """
    Like :meth:`~cassandra.query.dict_factory`, but returns each row as an OrderedDict,
    so the order of the columns is preserved.

    .. versionchanged:: 2.0.0
        moved from ``cassandra.decoder`` to ``cassandra.query``
    """
    return [OrderedDict(zip(colnames, row)) for row in rows]


FETCH_SIZE_UNSET = object()


class Statement(object):
    """
    An abstract class representing a single query. There are three subclasses:
    :class:`.SimpleStatement`, :class:`.BoundStatement`, and :class:`.BatchStatement`.
    These can be passed to :meth:`.Session.execute()`.
    """

    retry_policy = None
    """
    An instance of a :class:`cassandra.policies.RetryPolicy` or one of its
    subclasses. This controls when a query will be retried and how it
    will be retried.
    """

    consistency_level = None
    """
    The :class:`.ConsistencyLevel` to be used for this operation. Defaults
    to :const:`None`, which means that the default consistency level for
    the Session this is executed in will be used.
    """

    fetch_size = FETCH_SIZE_UNSET
    """
    How many rows will be fetched at a time. This overrides the default
    of :attr:`.Session.default_fetch_size`.

    This only takes effect when protocol version 2 or higher is used.
    See :attr:`.Cluster.protocol_version` for details.

    .. versionadded:: 2.0.0
    """

    keyspace = None
    """
    The string name of the keyspace this query acts on. This is used when
    :class:`~.TokenAwarePolicy` is configured for
    :attr:`.Cluster.load_balancing_policy`.

    It is set implicitly on :class:`.BoundStatement`, and :class:`.BatchStatement`,
    but must be set explicitly on :class:`.SimpleStatement`.

    .. versionadded:: 2.1.3
    """

    custom_payload = None
    """
    :ref:`custom_payload` to be passed to the server.

    These are only allowed when using protocol version 4 or higher.

    .. versionadded:: 2.6.0
    """

    is_idempotent = False
    """
    Flag indicating whether this statement is safe to run multiple times in speculative execution.
    """

    _serial_consistency_level = None
    _routing_key = None

    def __init__(self, retry_policy=None, consistency_level=None, routing_key=None,
                 serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None, custom_payload=None,
                 is_idempotent=False):
        if retry_policy and not hasattr(retry_policy, 'on_read_timeout'):  # just checking one method to detect positional parameter errors
            raise ValueError('retry_policy should implement cassandra.policies.RetryPolicy')
        self.retry_policy = retry_policy
        if consistency_level is not None:
            self.consistency_level = consistency_level
        self._routing_key = routing_key
        if serial_consistency_level is not None:
            self.serial_consistency_level = serial_consistency_level
        if fetch_size is not FETCH_SIZE_UNSET:
            self.fetch_size = fetch_size
        if keyspace is not None:
            self.keyspace = keyspace
        if custom_payload is not None:
            self.custom_payload = custom_payload
        self.is_idempotent = is_idempotent

    def _key_parts_packed(self, parts):
        for p in parts:
            l = len(p)
            yield struct.pack(">H%dsB" % l, l, p, 0)

    def _get_routing_key(self):
        return self._routing_key

    def _set_routing_key(self, key):
        if isinstance(key, (list, tuple)):
            if len(key) == 1:
                self._routing_key = key[0]
            else:
                self._routing_key = b"".join(self._key_parts_packed(key))
        else:
            self._routing_key = key

    def _del_routing_key(self):
        self._routing_key = None

    routing_key = property(
        _get_routing_key,
        _set_routing_key,
        _del_routing_key,
        """
        The :attr:`~.TableMetadata.partition_key` portion of the primary key,
        which can be used to determine which nodes are replicas for the query.

        If the partition key is a composite, a list or tuple must be passed in.
        Each key component should be in its packed (binary) format, so all
        components should be byte strings.
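        For a composite key, the setter frames each component the way
        ``_key_parts_packed`` does. Each part becomes a 2-byte big-endian length
        (``>H``), then the component bytes, then a single zero byte, and the framed
        parts are concatenated. A one-element list or tuple is stored unframed.
        ``pack_routing_key`` below is a hypothetical standalone version of that logic.

```python
import struct


def pack_routing_key(parts):
    # Single component: used as-is, like Statement._set_routing_key.
    if len(parts) == 1:
        return parts[0]
    packed = []
    for p in parts:
        # >H = 2-byte big-endian length, then the bytes, then a 0 terminator byte.
        packed.append(struct.pack('>H%dsB' % len(p), len(p), p, 0))
    return b''.join(packed)
```

        The two components ``b'ab'`` and ``b'c'`` therefore produce a 9-byte key:
        5 bytes for the first part and 4 for the second.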
""") def _get_serial_consistency_level(self): return self._serial_consistency_level def _set_serial_consistency_level(self, serial_consistency_level): acceptable = (None, ConsistencyLevel.SERIAL, ConsistencyLevel.LOCAL_SERIAL) if serial_consistency_level not in acceptable: raise ValueError( "serial_consistency_level must be either ConsistencyLevel.SERIAL " "or ConsistencyLevel.LOCAL_SERIAL") self._serial_consistency_level = serial_consistency_level def _del_serial_consistency_level(self): self._serial_consistency_level = None serial_consistency_level = property( _get_serial_consistency_level, _set_serial_consistency_level, _del_serial_consistency_level, """ The serial consistency level is only used by conditional updates (``INSERT``, ``UPDATE`` and ``DELETE`` with an ``IF`` condition). For those, the ``serial_consistency_level`` defines the consistency level of the serial phase (or "paxos" phase) while the normal :attr:`~.consistency_level` defines the consistency for the "learn" phase, i.e. what type of reads will be guaranteed to see the update right away. For example, if a conditional write has a :attr:`~.consistency_level` of :attr:`~.ConsistencyLevel.QUORUM` (and is successful), then a :attr:`~.ConsistencyLevel.QUORUM` read is guaranteed to see that write. But if the regular :attr:`~.consistency_level` of that write is :attr:`~.ConsistencyLevel.ANY`, then only a read with a :attr:`~.consistency_level` of :attr:`~.ConsistencyLevel.SERIAL` is guaranteed to see it (even a read with consistency :attr:`~.ConsistencyLevel.ALL` is not guaranteed to be enough). The serial consistency can only be one of :attr:`~.ConsistencyLevel.SERIAL` or :attr:`~.ConsistencyLevel.LOCAL_SERIAL`. While ``SERIAL`` guarantees full linearizability (with other ``SERIAL`` updates), ``LOCAL_SERIAL`` only guarantees it in the local data center. The serial consistency level is ignored for any query that is not a conditional update. Serial reads should use the regular :attr:`consistency_level`. 
Serial consistency levels may only be used against Cassandra 2.0+ and the :attr:`~.Cluster.protocol_version` must be set to 2 or higher. See :doc:`/lwt` for a discussion on how to work with results returned from conditional statements. .. versionadded:: 2.0.0 """) class SimpleStatement(Statement): """ A simple, un-prepared query. """ def __init__(self, query_string, retry_policy=None, consistency_level=None, routing_key=None, serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None, custom_payload=None, is_idempotent=False): """ `query_string` should be a literal CQL statement with the exception of parameter placeholders that will be filled through the `parameters` argument of :meth:`.Session.execute()`. See :class:`Statement` attributes for a description of the other parameters. """ Statement.__init__(self, retry_policy, consistency_level, routing_key, serial_consistency_level, fetch_size, keyspace, custom_payload, is_idempotent) self._query_string = query_string @property def query_string(self): return self._query_string def __str__(self): consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set') return (u'' % (self.query_string, consistency)) __repr__ = __str__ class PreparedStatement(object): """ A statement that has been prepared against at least one Cassandra node. Instances of this class should not be created directly, but through :meth:`.Session.prepare()`. A :class:`.PreparedStatement` should be prepared only once. Re-preparing a statement may affect performance (as the operation requires a network roundtrip). 
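    Sometimes the server does not send partition key indexes. In that case
    :meth:`from_message` derives ``routing_key_indexes`` from table metadata. It
    maps each bind column name to its position, then looks up the partition key
    columns in key order. The sketch uses a hypothetical ``Col`` namedtuple in
    place of the driver's column metadata objects. As in the driver, a missing
    partition key column yields ``None``.

```python
from collections import namedtuple

# Hypothetical stand-in for column metadata: only the name matters here.
Col = namedtuple('Col', ['name'])


def routing_key_indexes(bind_columns, partition_key_columns):
    # {column_name: position of that column among the bind markers}
    statement_indexes = dict((c.name, i) for i, c in enumerate(bind_columns))
    try:
        return [statement_indexes[c.name] for c in partition_key_columns]
    except KeyError:
        # A partition key column is not bound: no routing key can be built.
        return None
```

    Take bind columns ``(value, bucket, user_id)`` and partition key
    ``(user_id, bucket)``. The result is ``[2, 1]``, which is what
    :meth:`is_routing_key_index` later checks against.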
""" column_metadata = None #TODO: make this bind_metadata in next major consistency_level = None custom_payload = None fetch_size = FETCH_SIZE_UNSET keyspace = None # change to prepared_keyspace in major release protocol_version = None query_id = None query_string = None result_metadata = None routing_key_indexes = None _routing_key_index_set = None serial_consistency_level = None def __init__(self, column_metadata, query_id, routing_key_indexes, query, keyspace, protocol_version, result_metadata): self.column_metadata = column_metadata self.query_id = query_id self.routing_key_indexes = routing_key_indexes self.query_string = query self.keyspace = keyspace self.protocol_version = protocol_version self.result_metadata = result_metadata self.is_idempotent = False @classmethod def from_message(cls, query_id, column_metadata, pk_indexes, cluster_metadata, query, prepared_keyspace, protocol_version, result_metadata): if not column_metadata: return PreparedStatement(column_metadata, query_id, None, query, prepared_keyspace, protocol_version, result_metadata) if pk_indexes: routing_key_indexes = pk_indexes else: routing_key_indexes = None first_col = column_metadata[0] ks_meta = cluster_metadata.keyspaces.get(first_col.keyspace_name) if ks_meta: table_meta = ks_meta.tables.get(first_col.table_name) if table_meta: partition_key_columns = table_meta.partition_key # make a map of {column_name: index} for each column in the statement statement_indexes = dict((c.name, i) for i, c in enumerate(column_metadata)) # a list of which indexes in the statement correspond to partition key items try: routing_key_indexes = [statement_indexes[c.name] for c in partition_key_columns] except KeyError: # we're missing a partition key component in the prepared pass # statement; just leave routing_key_indexes as None return PreparedStatement(column_metadata, query_id, routing_key_indexes, query, prepared_keyspace, protocol_version, result_metadata) def bind(self, values): """ Creates and 
returns a :class:`BoundStatement` instance using `values`. See :meth:`BoundStatement.bind` for rules on input ``values``. """ return BoundStatement(self).bind(values) def is_routing_key_index(self, i): if self._routing_key_index_set is None: self._routing_key_index_set = set(self.routing_key_indexes) if self.routing_key_indexes else set() return i in self._routing_key_index_set def __str__(self): consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set') return (u'<PreparedStatement query="%s", consistency=%s>' % (self.query_string, consistency)) __repr__ = __str__ class BoundStatement(Statement): """ A prepared statement that has been bound to a particular set of values. These may be created directly or through :meth:`.PreparedStatement.bind()`. """ prepared_statement = None """ The :class:`PreparedStatement` instance that this was created from. """ values = None """ The sequence of values that were bound to the prepared statement. """ def __init__(self, prepared_statement, retry_policy=None, consistency_level=None, routing_key=None, serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None, custom_payload=None): """ `prepared_statement` should be an instance of :class:`PreparedStatement`. See :class:`Statement` attributes for a description of the other parameters. """ self.prepared_statement = prepared_statement self.consistency_level = prepared_statement.consistency_level self.serial_consistency_level = prepared_statement.serial_consistency_level self.fetch_size = prepared_statement.fetch_size self.custom_payload = prepared_statement.custom_payload self.is_idempotent = prepared_statement.is_idempotent self.values = [] meta = prepared_statement.column_metadata if meta: self.keyspace = meta[0].keyspace_name Statement.__init__(self, retry_policy, consistency_level, routing_key, serial_consistency_level, fetch_size, keyspace, custom_payload) def bind(self, values): """ Binds a sequence of values for the prepared statement parameters and returns this instance.
Note that `values` *must* be: * a sequence, even if you are only binding one value, or * a dict that relates 1-to-1 between dict keys and columns .. versionchanged:: 2.6.0 :data:`~.UNSET_VALUE` was introduced. These can be bound as positional parameters in a sequence, or by name in a dict. Additionally, when using protocol v4+: * short sequences will be extended to match bind parameters with UNSET_VALUE * names may be omitted from a dict with UNSET_VALUE implied. .. versionchanged:: 3.0.0 method will not throw if extra keys are present in bound dict (PYTHON-178) """ if values is None: values = () proto_version = self.prepared_statement.protocol_version col_meta = self.prepared_statement.column_metadata # special case for binding dicts if isinstance(values, dict): values_dict = values values = [] # sort values accordingly for col in col_meta: try: values.append(values_dict[col.name]) except KeyError: if proto_version >= 4: values.append(UNSET_VALUE) else: raise KeyError( 'Column name `%s` not found in bound dict.' % (col.name)) value_len = len(values) col_meta_len = len(col_meta) if value_len > col_meta_len: raise ValueError( "Too many arguments provided to bind() (got %d, expected %d)" % (len(values), len(col_meta))) # this is fail-fast for clarity pre-v4. When v4 can be assumed, # the error will be better reported when UNSET_VALUE is implicitly added. 
if proto_version < 4 and self.prepared_statement.routing_key_indexes and \ value_len < len(self.prepared_statement.routing_key_indexes): raise ValueError( "Too few arguments provided to bind() (got %d, required %d for routing key)" % (value_len, len(self.prepared_statement.routing_key_indexes))) self.raw_values = values self.values = [] for value, col_spec in zip(values, col_meta): if value is None: self.values.append(None) elif value is UNSET_VALUE: if proto_version >= 4: self._append_unset_value() else: raise ValueError("Attempt to bind UNSET_VALUE while using unsuitable protocol version (%d < 4)" % proto_version) else: try: self.values.append(col_spec.type.serialize(value, proto_version)) except (TypeError, struct.error) as exc: actual_type = type(value) message = ('Received an argument of invalid type for column "%s". ' 'Expected: %s, Got: %s; (%s)' % (col_spec.name, col_spec.type, actual_type, exc)) raise TypeError(message) if proto_version >= 4: diff = col_meta_len - len(self.values) if diff: for _ in range(diff): self._append_unset_value() return self def _append_unset_value(self): next_index = len(self.values) if self.prepared_statement.is_routing_key_index(next_index): col_meta = self.prepared_statement.column_metadata[next_index] raise ValueError("Cannot bind UNSET_VALUE as a part of the routing key '%s'" % col_meta.name) self.values.append(UNSET_VALUE) @property def routing_key(self): if not self.prepared_statement.routing_key_indexes: return None if self._routing_key is not None: return self._routing_key routing_indexes = self.prepared_statement.routing_key_indexes if len(routing_indexes) == 1: self._routing_key = self.values[routing_indexes[0]] else: self._routing_key = b"".join(self._key_parts_packed(self.values[i] for i in routing_indexes)) return self._routing_key def __str__(self): consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set') return (u'<BoundStatement query="%s", values=%s, consistency=%s>' % (self.prepared_statement.query_string, self.raw_values, consistency))
__repr__ = __str__ class BatchType(object): """ A BatchType is used with :class:`.BatchStatement` instances to control the atomicity of the batch operation. .. versionadded:: 2.0.0 """ LOGGED = None """ Atomic batch operation. """ UNLOGGED = None """ Non-atomic batch operation. """ COUNTER = None """ Batches of counter operations. """ def __init__(self, name, value): self.name = name self.value = value def __str__(self): return self.name def __repr__(self): return "BatchType.%s" % (self.name, ) BatchType.LOGGED = BatchType("LOGGED", 0) BatchType.UNLOGGED = BatchType("UNLOGGED", 1) BatchType.COUNTER = BatchType("COUNTER", 2) class BatchStatement(Statement): """ A protocol-level batch of operations which are applied atomically by default. .. versionadded:: 2.0.0 """ batch_type = None """ The :class:`.BatchType` for the batch operation. Defaults to :attr:`.BatchType.LOGGED`. """ serial_consistency_level = None """ The same as :attr:`.Statement.serial_consistency_level`, but is only supported when using protocol version 3 or higher. """ _statements_and_parameters = None _session = None def __init__(self, batch_type=BatchType.LOGGED, retry_policy=None, consistency_level=None, serial_consistency_level=None, session=None, custom_payload=None): """ `batch_type` specifies the :class:`.BatchType` for the batch operation. Defaults to :attr:`.BatchType.LOGGED`. `retry_policy` should be a :class:`~.RetryPolicy` instance for controlling retries on the operation. `consistency_level` should be a :class:`~.ConsistencyLevel` value to be used for all operations in the batch. `custom_payload` is a :ref:`custom_payload` passed to the server. Note: as Statement objects are added to the batch, this map is updated with any values found in their custom payloads. These are only allowed when using protocol version 4 or higher. Example usage: .. 
code-block:: python insert_user = session.prepare("INSERT INTO users (name, age) VALUES (?, ?)") batch = BatchStatement(consistency_level=ConsistencyLevel.QUORUM) for (name, age) in users_to_insert: batch.add(insert_user, (name, age)) session.execute(batch) You can also mix different types of operations within a batch: .. code-block:: python batch = BatchStatement() batch.add(SimpleStatement("INSERT INTO users (name, age) VALUES (%s, %s)"), (name, age)) batch.add(SimpleStatement("DELETE FROM pending_users WHERE name=%s"), (name,)) session.execute(batch) .. versionadded:: 2.0.0 .. versionchanged:: 2.1.0 Added `serial_consistency_level` as a parameter .. versionchanged:: 2.6.0 Added `custom_payload` as a parameter """ self.batch_type = batch_type self._statements_and_parameters = [] self._session = session Statement.__init__(self, retry_policy=retry_policy, consistency_level=consistency_level, serial_consistency_level=serial_consistency_level, custom_payload=custom_payload) def clear(self): """ This is a convenience method to clear a batch statement for reuse. *Note:* it should not be used concurrently with uncompleted execution futures executing the same ``BatchStatement``. """ del self._statements_and_parameters[:] self.keyspace = None self.routing_key = None if self.custom_payload: self.custom_payload.clear() def add(self, statement, parameters=None): """ Adds a :class:`.Statement` and optional sequence of parameters to be used with the statement to the batch. Like with other statements, parameters must be a sequence, even if there is only one item.
""" if isinstance(statement, six.string_types): if parameters: encoder = Encoder() if self._session is None else self._session.encoder statement = bind_params(statement, parameters, encoder) self._add_statement_and_params(False, statement, ()) elif isinstance(statement, PreparedStatement): query_id = statement.query_id bound_statement = statement.bind(() if parameters is None else parameters) self._update_state(bound_statement) self._add_statement_and_params(True, query_id, bound_statement.values) elif isinstance(statement, BoundStatement): if parameters: raise ValueError( "Parameters cannot be passed with a BoundStatement " "to BatchStatement.add()") self._update_state(statement) self._add_statement_and_params(True, statement.prepared_statement.query_id, statement.values) else: # it must be a SimpleStatement query_string = statement.query_string if parameters: encoder = Encoder() if self._session is None else self._session.encoder query_string = bind_params(query_string, parameters, encoder) self._update_state(statement) self._add_statement_and_params(False, query_string, ()) return self def add_all(self, statements, parameters): """ Adds a sequence of :class:`.Statement` objects and a matching sequence of parameters to the batch. Statement and parameter sequences must be of equal length; otherwise the longer sequence is silently truncated. :const:`None` can be used in the parameters position where no parameters are needed. """ for statement, value in zip(statements, parameters): self.add(statement, value) def _add_statement_and_params(self, is_prepared, statement, parameters): if len(self._statements_and_parameters) >= 0xFFFF: raise ValueError("Batch statement cannot contain more than %d statements."
% 0xFFFF) self._statements_and_parameters.append((is_prepared, statement, parameters)) def _maybe_set_routing_attributes(self, statement): if self.routing_key is None: if statement.keyspace and statement.routing_key: self.routing_key = statement.routing_key self.keyspace = statement.keyspace def _update_custom_payload(self, statement): if statement.custom_payload: if self.custom_payload is None: self.custom_payload = {} self.custom_payload.update(statement.custom_payload) def _update_state(self, statement): self._maybe_set_routing_attributes(statement) self._update_custom_payload(statement) def __str__(self): consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set') return (u'<BatchStatement type=%s, statements=%d, consistency=%s>' % (self.batch_type, len(self._statements_and_parameters), consistency)) __repr__ = __str__ ValueSequence = cassandra.encoder.ValueSequence """ A wrapper class that is used to specify that a sequence of values should be treated as a CQL list of values instead of a single column collection when used as part of the `parameters` argument for :meth:`.Session.execute()`. This is typically needed when supplying a list of keys to select. For example:: >>> my_user_ids = ('alice', 'bob', 'charles') >>> query = "SELECT * FROM users WHERE user_id IN %s" >>> session.execute(query, parameters=[ValueSequence(my_user_ids)]) """ def bind_params(query, params, encoder): if six.PY2 and isinstance(query, six.text_type): query = query.encode('utf-8') if isinstance(params, dict): return query % dict((k, encoder.cql_encode_all_types(v)) for k, v in six.iteritems(params)) else: return query % tuple(encoder.cql_encode_all_types(v) for v in params) class TraceUnavailable(Exception): """ Raised when complete trace details cannot be fetched from Cassandra. """ pass class QueryTrace(object): """ A trace of the duration and events that occurred when executing an operation. """ trace_id = None """ :class:`uuid.UUID` unique identifier for this tracing session.
Matches the ``session_id`` column in ``system_traces.sessions`` and ``system_traces.events``. """ request_type = None """ A string that very generally describes the traced operation. """ duration = None """ A :class:`datetime.timedelta` measure of the duration of the query. """ client = None """ The IP address of the client that issued this request. This is only available when using Cassandra 2.2+. """ coordinator = None """ The IP address of the host that acted as coordinator for this request. """ parameters = None """ A :class:`dict` of parameters for the traced operation, such as the specific query string. """ started_at = None """ A UTC :class:`datetime.datetime` object describing when the operation was started. """ events = None """ A chronologically sorted list of :class:`.TraceEvent` instances representing the steps the traced operation went through. This corresponds to the rows in ``system_traces.events`` for this tracing session. """ _session = None _SELECT_SESSIONS_FORMAT = "SELECT * FROM system_traces.sessions WHERE session_id = %s" _SELECT_EVENTS_FORMAT = "SELECT * FROM system_traces.events WHERE session_id = %s" _BASE_RETRY_SLEEP = 0.003 def __init__(self, trace_id, session): self.trace_id = trace_id self._session = session def populate(self, max_wait=2.0, wait_for_complete=True, query_cl=None): """ Retrieves the actual tracing details from Cassandra and populates the attributes of this instance. Because tracing details are stored asynchronously by Cassandra, this may need to retry the session detail fetch. If the trace is still not available after `max_wait` seconds, :exc:`.TraceUnavailable` will be raised; if `max_wait` is :const:`None`, this will retry forever. `wait_for_complete=False` bypasses the wait for duration to be populated. This can be used to query events from partial sessions. `query_cl` specifies a consistency level to use for polling the trace tables, if it should be different from the session default.
""" attempt = 0 start = time.time() while True: time_spent = time.time() - start if max_wait is not None and time_spent >= max_wait: raise TraceUnavailable( "Trace information was not available within %f seconds. Consider raising Session.max_trace_wait." % (max_wait,)) log.debug("Attempting to fetch trace info for trace ID: %s", self.trace_id) session_results = self._execute( SimpleStatement(self._SELECT_SESSIONS_FORMAT, consistency_level=query_cl), (self.trace_id,), time_spent, max_wait) is_complete = session_results and session_results[0].duration is not None if not session_results or (wait_for_complete and not is_complete): time.sleep(self._BASE_RETRY_SLEEP * (2 ** attempt)) attempt += 1 continue if is_complete: log.debug("Fetched trace info for trace ID: %s", self.trace_id) else: log.debug("Fetching partial trace info for trace ID: %s", self.trace_id) session_row = session_results[0] self.request_type = session_row.request self.duration = timedelta(microseconds=session_row.duration) if is_complete else None self.started_at = session_row.started_at self.coordinator = session_row.coordinator self.parameters = session_row.parameters # since C* 2.2 self.client = getattr(session_row, 'client', None) log.debug("Attempting to fetch trace events for trace ID: %s", self.trace_id) time_spent = time.time() - start event_results = self._execute( SimpleStatement(self._SELECT_EVENTS_FORMAT, consistency_level=query_cl), (self.trace_id,), time_spent, max_wait) log.debug("Fetched trace events for trace ID: %s", self.trace_id) self.events = tuple(TraceEvent(r.activity, r.event_id, r.source, r.source_elapsed, r.thread) for r in event_results) break def _execute(self, query, parameters, time_spent, max_wait): timeout = (max_wait - time_spent) if max_wait is not None else None future = self._session._create_response_future(query, parameters, trace=False, custom_payload=None, timeout=timeout) # in case the user switched the row factory, set it to namedtuple for this query
future.row_factory = named_tuple_factory future.send_request() try: return future.result() except OperationTimedOut: raise TraceUnavailable("Trace information was not available within %f seconds" % (max_wait,)) def __str__(self): return "%s [%s] coordinator: %s, started at: %s, duration: %s, parameters: %s" \ % (self.request_type, self.trace_id, self.coordinator, self.started_at, self.duration, self.parameters) class TraceEvent(object): """ Representation of a single event within a query trace. """ description = None """ A brief description of the event. """ datetime = None """ A UTC :class:`datetime.datetime` marking when the event occurred. """ source = None """ The IP address of the node this event occurred on. """ source_elapsed = None """ A :class:`datetime.timedelta` measuring the amount of time until this event occurred starting from when :attr:`.source` first received the query. """ thread_name = None """ The name of the thread that this event occurred on. """ def __init__(self, description, timeuuid, source, source_elapsed, thread_name): self.description = description self.datetime = datetime.utcfromtimestamp(unix_time_from_uuid1(timeuuid)) self.source = source if source_elapsed is not None: self.source_elapsed = timedelta(microseconds=source_elapsed) else: self.source_elapsed = None self.thread_name = thread_name def __str__(self): return "%s on %s[%s] at %s" % (self.description, self.source, self.thread_name, self.datetime) cassandra-driver-3.7.1/cassandra/bytesio.pyx0000664000175000017500000000325712743410406024003 0ustar aboudreaultaboudreault00000000000000# Copyright 2013-2016 DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. cdef class BytesIOReader: """ This class provides efficient support for reading bytes from a 'bytes' buffer, by returning char * values directly without allocating intermediate objects. """ def __init__(self, bytes buf): self.buf = buf self.size = len(buf) self.buf_ptr = self.buf cdef char *read(self, Py_ssize_t n = -1) except NULL: """Return a pointer to the next `n` bytes of the buffer and advance the read position past them. If `n` is negative or omitted, all remaining bytes are consumed. Unlike a file's ``read()``, this never performs a short read: :exc:`EOFError` is raised if fewer than `n` bytes remain, so the caller cannot consume past the end of the buffer. The returned pointer refers into the underlying buffer; no copy is made. """ cdef Py_ssize_t newpos = self.pos + n if n < 0: newpos = self.size elif newpos > self.size: # Raise an error here, as we do not want the caller to consume past the # end of the buffer raise EOFError("Cannot read past the end of the file") cdef char *res = self.buf_ptr + self.pos self.pos = newpos return res cassandra-driver-3.7.1/cassandra/auth.py0000664000175000017500000001400712743410406023071 0ustar aboudreaultaboudreault00000000000000# Copyright 2013-2016 DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # limitations under the License. try: from puresasl.client import SASLClient except ImportError: SASLClient = None class AuthProvider(object): """ An abstract class that defines the interface that will be used for creating :class:`~.Authenticator` instances when opening new connections to Cassandra. .. versionadded:: 2.0.0 """ def new_authenticator(self, host): """ Implementations of this class should return a new instance of :class:`~.Authenticator` or one of its subclasses. """ raise NotImplementedError() class Authenticator(object): """ An abstract class that handles SASL authentication with Cassandra servers. Each time a new connection is created and the server requires authentication, a new instance of this class will be created by the corresponding :class:`~.AuthProvider` to handle that authentication. The lifecycle of the new :class:`~.Authenticator` will then be: 1) The :meth:`~.initial_response()` method will be called. The return value will be sent to the server to initiate the handshake. 2) The server will respond to each client response by either issuing a challenge or indicating that the authentication is complete (successful or not). If a new challenge is issued, :meth:`~.evaluate_challenge()` will be called to produce a response that will be sent to the server. This challenge/response negotiation will continue until the server responds that authentication is successful (or an :exc:`~.AuthenticationFailed` is raised). 3) When the server indicates that authentication is successful, :meth:`~.on_authentication_success` will be called with a token string that the server may optionally have sent. The exact nature of the negotiation between the client and server is specific to the authentication mechanism configured server-side. ..
versionadded:: 2.0.0 """ server_authenticator_class = None """ Set during the connection AUTHENTICATE phase """ def initial_response(self): """ Returns a message to send to the server to initiate the SASL handshake. :const:`None` may be returned to send an empty message. """ return None def evaluate_challenge(self, challenge): """ Called when the server sends a challenge message. Generally, this method should return :const:`None` when authentication is complete from a client perspective. Otherwise, a string should be returned. """ raise NotImplementedError() def on_authentication_success(self, token): """ Called when the server indicates that authentication was successful. Depending on the authentication mechanism, `token` may be :const:`None` or a string. """ pass class PlainTextAuthProvider(AuthProvider): """ An :class:`~.AuthProvider` that works with Cassandra's PasswordAuthenticator. Example usage:: from cassandra.cluster import Cluster from cassandra.auth import PlainTextAuthProvider auth_provider = PlainTextAuthProvider( username='cassandra', password='cassandra') cluster = Cluster(auth_provider=auth_provider) .. versionadded:: 2.0.0 """ def __init__(self, username, password): self.username = username self.password = password def new_authenticator(self, host): return PlainTextAuthenticator(self.username, self.password) class PlainTextAuthenticator(Authenticator): """ An :class:`~.Authenticator` that works with Cassandra's PasswordAuthenticator. ..
versionadded:: 2.0.0 """ def __init__(self, username, password): self.username = username self.password = password def initial_response(self): return "\x00%s\x00%s" % (self.username, self.password) def evaluate_challenge(self, challenge): return None class SaslAuthProvider(AuthProvider): """ An :class:`~.AuthProvider` supporting general SASL auth mechanisms. Suitable for GSSAPI or other SASL mechanisms. Example usage:: from cassandra.cluster import Cluster from cassandra.auth import SaslAuthProvider sasl_kwargs = {'service': 'something', 'mechanism': 'GSSAPI', 'qops': 'auth'.split(',')} auth_provider = SaslAuthProvider(**sasl_kwargs) cluster = Cluster(auth_provider=auth_provider) .. versionadded:: 2.1.4 """ def __init__(self, **sasl_kwargs): if SASLClient is None: raise ImportError('The puresasl library has not been installed') if 'host' in sasl_kwargs: raise ValueError("kwargs should not contain 'host' since it is passed dynamically to new_authenticator") self.sasl_kwargs = sasl_kwargs def new_authenticator(self, host): return SaslAuthenticator(host, **self.sasl_kwargs) class SaslAuthenticator(Authenticator): """ A pass-through :class:`~.Authenticator` using the third-party package 'pure-sasl' for authentication. .. versionadded:: 2.1.4 """ def __init__(self, host, service, mechanism='GSSAPI', **sasl_kwargs): if SASLClient is None: raise ImportError('The puresasl library has not been installed') self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs) def initial_response(self): return self.sasl.process() def evaluate_challenge(self, challenge): return self.sasl.process(challenge) cassandra-driver-3.7.1/cassandra/cython_utils.pyx0000664000175000017500000000401612743410406025043 0ustar aboudreaultaboudreault00000000000000# Copyright 2013-2016 DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ Duplicate module of util.py, with some accelerated functions used for deserialization. """ from libc.math cimport modf, round, fabs from cpython.datetime cimport ( timedelta_new, # cdef inline object timedelta_new(int days, int seconds, int useconds) # Create timedelta object using DateTime CAPI factory function. # Note, there are no range checks for any of the arguments. import_datetime, # Datetime C API initialization function. # You have to call it before any usage of DateTime CAPI functions. ) import datetime import sys cdef bint is_little_endian from cassandra.util import is_little_endian import_datetime() DEF DAY_IN_SECONDS = 86400 DATETIME_EPOC = datetime.datetime(1970, 1, 1) cdef datetime_from_timestamp(double timestamp): cdef int days = <int> (timestamp / DAY_IN_SECONDS) cdef int64_t days_in_seconds = (<int64_t> days) * DAY_IN_SECONDS cdef int seconds = <int> (timestamp - days_in_seconds) cdef double tmp cdef double micros_left = modf(timestamp, &tmp) * 1000000.
micros_left = modf(micros_left, &tmp) cdef int microseconds = <int> tmp # rounding to emulate fp math in delta_new cdef int x_odd tmp = round(micros_left) if fabs(tmp - micros_left) == 0.5: x_odd = microseconds & 1 tmp = 2.0 * round((micros_left + x_odd) * 0.5) - x_odd microseconds += <int> tmp return DATETIME_EPOC + timedelta_new(days, seconds, microseconds) cassandra-driver-3.7.1/cassandra/numpyFlags.h0000664000175000017500000000006212743410406024050 0ustar aboudreaultaboudreault00000000000000#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION cassandra-driver-3.7.1/cassandra/parsing.pyx0000664000175000017500000000263112743410406023763 0ustar aboudreaultaboudreault00000000000000# Copyright 2013-2016 DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ Module containing the definitions and declarations (parsing.pxd) for parsers. """ cdef class ParseDesc: """Description of what structure to parse""" def __init__(self, colnames, coltypes, deserializers, protocol_version): self.colnames = colnames self.coltypes = coltypes self.deserializers = deserializers self.protocol_version = protocol_version self.rowsize = len(colnames) cdef class ColumnParser: """Decode a ResultMessage into a set of columns""" cpdef parse_rows(self, BytesIOReader reader, ParseDesc desc): raise NotImplementedError cdef class RowParser: """Parser for a single row""" cpdef unpack_row(self, BytesIOReader reader, ParseDesc desc): """ Unpack a single row of data in a ResultMessage.
""" raise NotImplementedError cassandra-driver-3.7.1/cassandra/__init__.py0000664000175000017500000003377413004142045023673 0ustar aboudreaultaboudreault00000000000000# Copyright 2013-2016 DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import logging class NullHandler(logging.Handler): def emit(self, record): pass logging.getLogger('cassandra').addHandler(NullHandler()) __version_info__ = (3, 7, 1) __version__ = '.'.join(map(str, __version_info__)) class ConsistencyLevel(object): """ Specifies how many replicas must respond for an operation to be considered a success. By default, ``LOCAL_ONE`` is used for all operations. """ ANY = 0 """ Only requires that one replica receives the write *or* the coordinator stores a hint to replay later. Valid only for writes.
""" ONE = 1 """ Only one replica needs to respond to consider the operation a success """ TWO = 2 """ Two replicas must respond to consider the operation a success """ THREE = 3 """ Three replicas must respond to consider the operation a success """ QUORUM = 4 """ ``floor(RF/2) + 1`` replicas must respond to consider the operation a success, where RF is the replication factor (for example, 2 of 3 replicas, or 3 of 4) """ ALL = 5 """ All replicas must respond to consider the operation a success """ LOCAL_QUORUM = 6 """ Requires a quorum of replicas in the local datacenter """ EACH_QUORUM = 7 """ Requires a quorum of replicas in each datacenter """ SERIAL = 8 """ For conditional inserts/updates that utilize Cassandra's lightweight transactions, this requires consensus among a quorum of all replicas (across all datacenters) for the modified data. """ LOCAL_SERIAL = 9 """ Like :attr:`~ConsistencyLevel.SERIAL`, but only requires consensus among a quorum of replicas in the local datacenter. """ LOCAL_ONE = 10 """ Sends a request only to replicas in the local datacenter and waits for one response. """ ConsistencyLevel.value_to_name = { ConsistencyLevel.ANY: 'ANY', ConsistencyLevel.ONE: 'ONE', ConsistencyLevel.TWO: 'TWO', ConsistencyLevel.THREE: 'THREE', ConsistencyLevel.QUORUM: 'QUORUM', ConsistencyLevel.ALL: 'ALL', ConsistencyLevel.LOCAL_QUORUM: 'LOCAL_QUORUM', ConsistencyLevel.EACH_QUORUM: 'EACH_QUORUM', ConsistencyLevel.SERIAL: 'SERIAL', ConsistencyLevel.LOCAL_SERIAL: 'LOCAL_SERIAL', ConsistencyLevel.LOCAL_ONE: 'LOCAL_ONE' } ConsistencyLevel.name_to_value = { 'ANY': ConsistencyLevel.ANY, 'ONE': ConsistencyLevel.ONE, 'TWO': ConsistencyLevel.TWO, 'THREE': ConsistencyLevel.THREE, 'QUORUM': ConsistencyLevel.QUORUM, 'ALL': ConsistencyLevel.ALL, 'LOCAL_QUORUM': ConsistencyLevel.LOCAL_QUORUM, 'EACH_QUORUM': ConsistencyLevel.EACH_QUORUM, 'SERIAL': ConsistencyLevel.SERIAL, 'LOCAL_SERIAL': ConsistencyLevel.LOCAL_SERIAL, 'LOCAL_ONE': ConsistencyLevel.LOCAL_ONE } def consistency_value_to_name(value): return 
ConsistencyLevel.value_to_name[value] if value is not None else "Not Set" class
SchemaChangeType(object): DROPPED = 'DROPPED' CREATED = 'CREATED' UPDATED = 'UPDATED' class SchemaTargetType(object): KEYSPACE = 'KEYSPACE' TABLE = 'TABLE' TYPE = 'TYPE' FUNCTION = 'FUNCTION' AGGREGATE = 'AGGREGATE' class SignatureDescriptor(object): def __init__(self, name, argument_types): self.name = name self.argument_types = argument_types @property def signature(self): """ function signature string in the form 'name([type0[,type1[...]]])' can be used to uniquely identify overloaded function names within a keyspace """ return self.format_signature(self.name, self.argument_types) @staticmethod def format_signature(name, argument_types): return "%s(%s)" % (name, ','.join(t for t in argument_types)) def __repr__(self): return "%s(%s, %s)" % (self.__class__.__name__, self.name, self.argument_types) class UserFunctionDescriptor(SignatureDescriptor): """ Describes a User function by name and argument signature """ name = None """ name of the function """ argument_types = None """ Ordered list of CQL argument type names comprising the type signature """ class UserAggregateDescriptor(SignatureDescriptor): """ Describes a User aggregate function by name and argument signature """ name = None """ name of the aggregate """ argument_types = None """ Ordered list of CQL argument type names comprising the type signature """ class DriverException(Exception): """ Base for all exceptions explicitly raised by the driver. """ pass class RequestExecutionException(DriverException): """ Base for request execution exceptions returned from the server. """ pass class Unavailable(RequestExecutionException): """ There were not enough live replicas to satisfy the requested consistency level, so the coordinator node immediately failed the request without forwarding it to any replicas. 
""" consistency = None """ The requested :class:`ConsistencyLevel` """ required_replicas = None """ The number of replicas that needed to be live to complete the operation """ alive_replicas = None """ The number of replicas that were actually alive """ def __init__(self, summary_message, consistency=None, required_replicas=None, alive_replicas=None): self.consistency = consistency self.required_replicas = required_replicas self.alive_replicas = alive_replicas Exception.__init__(self, summary_message + ' info=' + repr({'consistency': consistency_value_to_name(consistency), 'required_replicas': required_replicas, 'alive_replicas': alive_replicas})) class Timeout(RequestExecutionException): """ Replicas failed to respond to the coordinator node before timing out. """ consistency = None """ The requested :class:`ConsistencyLevel` """ required_responses = None """ The number of required replica responses """ received_responses = None """ The number of replicas that responded before the coordinator timed out the operation """ def __init__(self, summary_message, consistency=None, required_responses=None, received_responses=None): self.consistency = consistency self.required_responses = required_responses self.received_responses = received_responses Exception.__init__(self, summary_message + ' info=' + repr({'consistency': consistency_value_to_name(consistency), 'required_responses': required_responses, 'received_responses': received_responses})) class ReadTimeout(Timeout): """ A subclass of :exc:`Timeout` for read operations. This indicates that the replicas failed to respond to the coordinator node before the configured timeout. This timeout is configured in ``cassandra.yaml`` with the ``read_request_timeout_in_ms`` and ``range_request_timeout_in_ms`` options. 
""" data_retrieved = None """ A boolean indicating whether the requested data was retrieved by the coordinator from any replicas before it timed out the operation """ def __init__(self, message, data_retrieved=None, **kwargs): Timeout.__init__(self, message, **kwargs) self.data_retrieved = data_retrieved class WriteTimeout(Timeout): """ A subclass of :exc:`Timeout` for write operations. This indicates that the replicas failed to respond to the coordinator node before the configured timeout. This timeout is configured in ``cassandra.yaml`` with the ``write_request_timeout_in_ms`` option. """ write_type = None """ The type of write operation, enum on :class:`~cassandra.policies.WriteType` """ def __init__(self, message, write_type=None, **kwargs): Timeout.__init__(self, message, **kwargs) self.write_type = write_type class CoordinationFailure(RequestExecutionException): """ Replicas sent a failure to the coordinator. """ consistency = None """ The requested :class:`ConsistencyLevel` """ required_responses = None """ The number of required replica responses """ received_responses = None """ The number of replicas that responded before the coordinator timed out the operation """ failures = None """ The number of replicas that sent a failure message """ error_code_map = None """ A map of inet addresses to error codes representing replicas that sent a failure message. Only set when `protocol_version` is 5 or higher. 
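
A minimal sketch of how such a map is rendered for the exception message, mirroring the ``'0x%04x'`` formatting applied in ``CoordinationFailure.__init__``. The function name ``format_error_code_map`` is hypothetical, and the addresses and codes are made-up sample values, not real server failure codes.

```python
def format_error_code_map(error_code_map):
    # Render each replica's numeric failure code as a zero-padded,
    # four-digit hex string, e.g. 42 -> '0x002a'.
    return dict((addr, '0x%04x' % code) for addr, code in error_code_map.items())


codes = {'10.0.0.1': 0, '10.0.0.2': 42}
print(format_error_code_map(codes))
# {'10.0.0.1': '0x0000', '10.0.0.2': '0x002a'}
```
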
""" def __init__(self, summary_message, consistency=None, required_responses=None, received_responses=None, failures=None, error_code_map=None): self.consistency = consistency self.required_responses = required_responses self.received_responses = received_responses self.failures = failures self.error_code_map = error_code_map info_dict = { 'consistency': consistency_value_to_name(consistency), 'required_responses': required_responses, 'received_responses': received_responses, 'failures': failures } if error_code_map is not None: # make error codes look like "0x002a" formatted_map = dict((addr, '0x%04x' % err_code) for (addr, err_code) in error_code_map.items()) info_dict['error_code_map'] = formatted_map Exception.__init__(self, summary_message + ' info=' + repr(info_dict)) class ReadFailure(CoordinationFailure): """ A subclass of :exc:`CoordinationFailure` for read operations. This indicates that the replicas sent a failure message to the coordinator. """ data_retrieved = None """ A boolean indicating whether the requested data was retrieved by the coordinator from any replicas before it timed out the operation """ def __init__(self, message, data_retrieved=None, **kwargs): CoordinationFailure.__init__(self, message, **kwargs) self.data_retrieved = data_retrieved class WriteFailure(CoordinationFailure): """ A subclass of :exc:`CoordinationFailure` for write operations. This indicates that the replicas sent a failure message to the coordinator. 
""" write_type = None """ The type of write operation, enum on :class:`~cassandra.policies.WriteType` """ def __init__(self, message, write_type=None, **kwargs): CoordinationFailure.__init__(self, message, **kwargs) self.write_type = write_type class FunctionFailure(RequestExecutionException): """ A user-defined function failed during execution """ keyspace = None """ Keyspace of the function """ function = None """ Name of the function """ arg_types = None """ List of argument type names of the function """ def __init__(self, summary_message, keyspace, function, arg_types): self.keyspace = keyspace self.function = function self.arg_types = arg_types Exception.__init__(self, summary_message) class RequestValidationException(DriverException): """ Server request validation failed """ pass class ConfigurationException(RequestValidationException): """ Server indicated a request error due to the current configuration """ pass class AlreadyExists(ConfigurationException): """ An attempt was made to create a keyspace or table that already exists. """ keyspace = None """ The name of the keyspace that already exists, or, if an attempt was made to create a new table, the keyspace that the table is in. """ table = None """ The name of the table that already exists, or, if an attempt was made to create a keyspace, :const:`None`. """ def __init__(self, keyspace=None, table=None): if table: message = "Table '%s.%s' already exists" % (keyspace, table) else: message = "Keyspace '%s' already exists" % (keyspace,) Exception.__init__(self, message) self.keyspace = keyspace self.table = table class InvalidRequest(RequestValidationException): """ A query was made that was invalid for some reason, such as trying to set the keyspace for a connection to a nonexistent keyspace. """ pass class Unauthorized(RequestValidationException): """ The current user is not authorized to perform the requested operation. """ pass class AuthenticationFailed(DriverException): """ Failed to authenticate.
""" pass class OperationTimedOut(DriverException): """ The operation took longer than the specified (client-side) timeout to complete. This error is generated by the driver, not by Cassandra. """ errors = None """ A dict of errors keyed by the :class:`~.Host` against which they occurred. """ last_host = None """ The last :class:`~.Host` this operation was attempted against. """ def __init__(self, errors=None, last_host=None): self.errors = errors self.last_host = last_host message = "errors=%s, last_host=%s" % (self.errors, self.last_host) Exception.__init__(self, message) class UnsupportedOperation(DriverException): """ An attempt was made to use a feature that is not supported by the selected protocol version. See :attr:`Cluster.protocol_version` for more details. """ pass cassandra-driver-3.7.1/cassandra/ioutils.pyx0000664000175000017500000000316012743410406024006 0ustar aboudreaultaboudreault00000000000000# Copyright 2013-2016 DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. include 'cython_marshal.pyx' from cassandra.buffer cimport Buffer, from_ptr_and_size from libc.stdint cimport int32_t from cassandra.bytesio cimport BytesIOReader cdef inline int get_buf(BytesIOReader reader, Buffer *buf_out) except -1: """ Get a pointer into the buffer provided by BytesIOReader for the next data item in the stream of values. BEWARE: If the next item has a zero or negative size, the pointer will be set to NULL.
A negative size happens when the value is NULL in the database, whereas a zero size may happen either for legacy reasons, or for data types such as strings (which may be empty). """ cdef Py_ssize_t raw_val_size = read_int(reader) cdef char *ptr if raw_val_size <= 0: ptr = NULL else: ptr = reader.read(raw_val_size) from_ptr_and_size(ptr, raw_val_size, buf_out) return 0 cdef inline int32_t read_int(BytesIOReader reader) except ?0xDEAD: cdef Buffer buf buf.ptr = reader.read(4) buf.size = 4 return unpack_num[int32_t](&buf) cassandra-driver-3.7.1/cassandra/policies.py0000664000175000017500000010352012777231260023744 0ustar aboudreaultaboudreault00000000000000# Copyright 2013-2016 DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from itertools import islice, cycle, groupby, repeat import logging from random import randint from threading import Lock import socket from cassandra import ConsistencyLevel, OperationTimedOut log = logging.getLogger(__name__) class HostDistance(object): """ A measure of how "distant" a node is from the client, which may influence how the load balancer distributes requests and how many connections are opened to the node. """ IGNORED = -1 """ A node with this distance should never be queried or have connections opened to it. 
""" LOCAL = 0 """ Nodes with ``LOCAL`` distance will be preferred for operations under some load balancing policies (such as :class:`.DCAwareRoundRobinPolicy`) and will have a greater number of connections opened against them by default. This distance is typically used for nodes within the same datacenter as the client. """ REMOTE = 1 """ Nodes with ``REMOTE`` distance will be treated as a last resort by some load balancing policies (such as :class:`.DCAwareRoundRobinPolicy`) and will have a smaller number of connections opened against them by default. This distance is typically used for nodes outside of the datacenter that the client is running in. """ class HostStateListener(object): def on_up(self, host): """ Called when a node is marked up. """ raise NotImplementedError() def on_down(self, host): """ Called when a node is marked down. """ raise NotImplementedError() def on_add(self, host): """ Called when a node is added to the cluster. The newly added node should be considered up. """ raise NotImplementedError() def on_remove(self, host): """ Called when a node is removed from the cluster. """ raise NotImplementedError() class LoadBalancingPolicy(HostStateListener): """ Load balancing policies are used to decide how to distribute requests among all possible coordinator nodes in the cluster. In particular, they may focus on querying "near" nodes (those in a local datacenter) or on querying nodes who happen to be replicas for the requested data. You may also use subclasses of :class:`.LoadBalancingPolicy` for custom behavior. """ _hosts_lock = None def __init__(self): self._hosts_lock = Lock() def distance(self, host): """ Returns a measure of how remote a :class:`~.pool.Host` is in terms of the :class:`.HostDistance` enums. """ raise NotImplementedError() def populate(self, cluster, hosts): """ This method is called to initialize the load balancing policy with a set of :class:`.Host` instances before its first use. 
The `cluster` parameter is an instance of :class:`.Cluster`. """ raise NotImplementedError() def make_query_plan(self, working_keyspace=None, query=None): """ Given a :class:`~.query.Statement` instance, return an iterable of :class:`.Host` instances which should be queried in that order. A generator may work well for custom implementations of this method. Note that the `query` argument may be :const:`None` when preparing statements. `working_keyspace` should be the string name of the current keyspace, as set through :meth:`.Session.set_keyspace()` or with a ``USE`` statement. """ raise NotImplementedError() def check_supported(self): """ This will be called after the cluster metadata has been initialized. If the load balancing policy implementation cannot be supported for some reason (such as a missing C extension), this is the point at which it should raise an exception. """ pass class RoundRobinPolicy(LoadBalancingPolicy): """ A subclass of :class:`.LoadBalancingPolicy` which evenly distributes queries across all nodes in the cluster, regardless of what datacenter the nodes may be in. This load balancing policy is used by default.
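
The rotation performed by ``make_query_plan`` can be sketched without the driver as follows. ``RoundRobinPlanner`` is a hypothetical standalone class; it keeps hosts in a tuple (the real policy uses a ``frozenset``) so that the output order is deterministic for illustration. Each plan visits every live host exactly once, starting one position later than the previous plan.

```python
from itertools import cycle, islice


class RoundRobinPlanner(object):
    # Hypothetical standalone class mirroring the islice(cycle(...)) rotation
    # used by RoundRobinPolicy.make_query_plan.
    def __init__(self, hosts, start=0):
        self.hosts = tuple(hosts)
        self.position = start

    def make_query_plan(self):
        pos = self.position
        self.position += 1
        length = len(self.hosts)
        if not length:
            return []
        pos %= length
        # Take `length` consecutive hosts from an endless cycle, beginning at pos.
        return list(islice(cycle(self.hosts), pos, pos + length))


planner = RoundRobinPlanner(['a', 'b', 'c'])
print(planner.make_query_plan())  # ['a', 'b', 'c']
print(planner.make_query_plan())  # ['b', 'c', 'a']
print(planner.make_query_plan())  # ['c', 'a', 'b']
```
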
""" _live_hosts = frozenset(()) _position = 0 def populate(self, cluster, hosts): self._live_hosts = frozenset(hosts) if len(hosts) > 1: self._position = randint(0, len(hosts) - 1) def distance(self, host): return HostDistance.LOCAL def make_query_plan(self, working_keyspace=None, query=None): # not thread-safe, but we don't care much about lost increments # for the purposes of load balancing pos = self._position self._position += 1 hosts = self._live_hosts length = len(hosts) if length: pos %= length return islice(cycle(hosts), pos, pos + length) else: return [] def on_up(self, host): with self._hosts_lock: self._live_hosts = self._live_hosts.union((host, )) def on_down(self, host): with self._hosts_lock: self._live_hosts = self._live_hosts.difference((host, )) def on_add(self, host): with self._hosts_lock: self._live_hosts = self._live_hosts.union((host, )) def on_remove(self, host): with self._hosts_lock: self._live_hosts = self._live_hosts.difference((host, )) class DCAwareRoundRobinPolicy(LoadBalancingPolicy): """ Similar to :class:`.RoundRobinPolicy`, but prefers hosts in the local datacenter and only uses nodes in remote datacenters as a last resort. """ local_dc = None used_hosts_per_remote_dc = 0 def __init__(self, local_dc='', used_hosts_per_remote_dc=0): """ The `local_dc` parameter should be the name of the datacenter (such as is reported by ``nodetool ring``) that should be considered local. If not specified, the driver will choose a local_dc based on the first host among :attr:`.Cluster.contact_points` having a valid DC. If relying on this mechanism, all specified contact points should be nodes in a single, local DC. `used_hosts_per_remote_dc` controls how many nodes in each remote datacenter will have connections opened against them. In other words, `used_hosts_per_remote_dc` hosts will be considered :attr:`~.HostDistance.REMOTE` and the rest will be considered :attr:`~.HostDistance.IGNORED`. By default, all remote hosts are ignored. 
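
The distance rules described above can be sketched as a standalone function. ``host_distance`` is hypothetical; hosts are plain strings and live hosts are grouped per datacenter in a dict, loosely mirroring ``_dc_live_hosts``. The integer values match :class:`.HostDistance`.

```python
LOCAL, REMOTE, IGNORED = 0, 1, -1  # same values as HostDistance


def host_distance(host, host_dc, live_hosts_by_dc, local_dc, used_hosts_per_remote_dc):
    # Hosts in the local DC are always LOCAL.
    if host_dc == local_dc:
        return LOCAL
    # With the default of 0, every remote host is ignored.
    if not used_hosts_per_remote_dc:
        return IGNORED
    # Otherwise only the first N live hosts of each remote DC are REMOTE.
    dc_hosts = live_hosts_by_dc.get(host_dc, ())
    if host in list(dc_hosts)[:used_hosts_per_remote_dc]:
        return REMOTE
    return IGNORED


live = {'dc1': ('a1', 'a2'), 'dc2': ('b1', 'b2', 'b3')}
print(host_distance('a2', 'dc1', live, 'dc1', 2))  # 0 (LOCAL)
print(host_distance('b2', 'dc2', live, 'dc1', 2))  # 1 (REMOTE)
print(host_distance('b3', 'dc2', live, 'dc1', 2))  # -1 (IGNORED)
```
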
""" self.local_dc = local_dc self.used_hosts_per_remote_dc = used_hosts_per_remote_dc self._dc_live_hosts = {} self._position = 0 self._contact_points = [] LoadBalancingPolicy.__init__(self) def _dc(self, host): return host.datacenter or self.local_dc def populate(self, cluster, hosts): for dc, dc_hosts in groupby(hosts, lambda h: self._dc(h)): self._dc_live_hosts[dc] = tuple(set(dc_hosts)) if not self.local_dc: self._contact_points = cluster.contact_points_resolved self._position = randint(0, len(hosts) - 1) if hosts else 0 def distance(self, host): dc = self._dc(host) if dc == self.local_dc: return HostDistance.LOCAL if not self.used_hosts_per_remote_dc: return HostDistance.IGNORED else: dc_hosts = self._dc_live_hosts.get(dc) if not dc_hosts: return HostDistance.IGNORED if host in list(dc_hosts)[:self.used_hosts_per_remote_dc]: return HostDistance.REMOTE else: return HostDistance.IGNORED def make_query_plan(self, working_keyspace=None, query=None): # not thread-safe, but we don't care much about lost increments # for the purposes of load balancing pos = self._position self._position += 1 local_live = self._dc_live_hosts.get(self.local_dc, ()) pos = (pos % len(local_live)) if local_live else 0 for host in islice(cycle(local_live), pos, pos + len(local_live)): yield host # the dict can change, so get candidate DCs iterating over keys of a copy other_dcs = [dc for dc in self._dc_live_hosts.copy().keys() if dc != self.local_dc] for dc in other_dcs: remote_live = self._dc_live_hosts.get(dc, ()) for host in remote_live[:self.used_hosts_per_remote_dc]: yield host def on_up(self, host): # not worrying about threads because this will happen during # control connection startup/refresh if not self.local_dc and host.datacenter: if host.address in self._contact_points: self.local_dc = host.datacenter log.info("Using datacenter '%s' for DCAwareRoundRobinPolicy (via host '%s'); " "if incorrect, please specify a local_dc to the constructor, " "or limit contact points to local 
cluster nodes" % (self.local_dc, host.address)) del self._contact_points dc = self._dc(host) with self._hosts_lock: current_hosts = self._dc_live_hosts.get(dc, ()) if host not in current_hosts: self._dc_live_hosts[dc] = current_hosts + (host, ) def on_down(self, host): dc = self._dc(host) with self._hosts_lock: current_hosts = self._dc_live_hosts.get(dc, ()) if host in current_hosts: hosts = tuple(h for h in current_hosts if h != host) if hosts: self._dc_live_hosts[dc] = hosts else: del self._dc_live_hosts[dc] def on_add(self, host): self.on_up(host) def on_remove(self, host): self.on_down(host) class TokenAwarePolicy(LoadBalancingPolicy): """ A :class:`.LoadBalancingPolicy` wrapper that adds token awareness to a child policy. This alters the child policy's behavior so that it first attempts to send queries to :attr:`~.HostDistance.LOCAL` replicas (as determined by the child policy) based on the :class:`.Statement`'s :attr:`~.Statement.routing_key`. Once those hosts are exhausted, the remaining hosts in the child policy's query plan will be used. If no :attr:`~.Statement.routing_key` is set on the query, the child policy's query plan will be used as is. """ _child_policy = None _cluster_metadata = None def __init__(self, child_policy): self._child_policy = child_policy def populate(self, cluster, hosts): self._cluster_metadata = cluster.metadata self._child_policy.populate(cluster, hosts) def check_supported(self): if not self._cluster_metadata.can_support_partitioner(): raise RuntimeError( '%s cannot be used with the cluster partitioner (%s) because ' 'the relevant C extension for this driver was not compiled. ' 'See the installation instructions for details on building ' 'and installing the C extensions.' 
% (self.__class__.__name__, self._cluster_metadata.partitioner)) def distance(self, *args, **kwargs): return self._child_policy.distance(*args, **kwargs) def make_query_plan(self, working_keyspace=None, query=None): if query and query.keyspace: keyspace = query.keyspace else: keyspace = working_keyspace child = self._child_policy if query is None: for host in child.make_query_plan(keyspace, query): yield host else: routing_key = query.routing_key if routing_key is None or keyspace is None: for host in child.make_query_plan(keyspace, query): yield host else: replicas = self._cluster_metadata.get_replicas(keyspace, routing_key) for replica in replicas: if replica.is_up and \ child.distance(replica) == HostDistance.LOCAL: yield replica for host in child.make_query_plan(keyspace, query): # skip if we've already listed this host if host not in replicas or \ child.distance(host) == HostDistance.REMOTE: yield host def on_up(self, *args, **kwargs): return self._child_policy.on_up(*args, **kwargs) def on_down(self, *args, **kwargs): return self._child_policy.on_down(*args, **kwargs) def on_add(self, *args, **kwargs): return self._child_policy.on_add(*args, **kwargs) def on_remove(self, *args, **kwargs): return self._child_policy.on_remove(*args, **kwargs) class WhiteListRoundRobinPolicy(RoundRobinPolicy): """ A subclass of :class:`.RoundRobinPolicy` which evenly distributes queries across all nodes in the cluster, regardless of what datacenter the nodes may be in, but only if that node exists in the list of allowed nodes. This policy addresses the issue described in https://datastax-oss.atlassian.net/browse/JAVA-145, where connection errors occur when connection attempts are made to private IP addresses remotely. """ def __init__(self, hosts): """ The `hosts` parameter should be a sequence of hosts to permit connections to.
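
A minimal sketch of the whitelist mechanics, assuming (as the constructor does) that each allowed entry is resolved with :func:`socket.getaddrinfo` and that discovered nodes are then matched by address. ``resolve_allowed`` and ``filter_live`` are hypothetical helpers, and discovered nodes are plain address strings instead of :class:`.Host` objects.

```python
import socket


def resolve_allowed(hosts):
    # Resolve each allowed name or IP to every address it maps to;
    # element [4][0] of a getaddrinfo result is the address string.
    return [info[4][0]
            for name in hosts
            for info in socket.getaddrinfo(name, None, socket.AF_UNSPEC, socket.SOCK_STREAM)]


def filter_live(discovered, allowed_resolved):
    # Keep only discovered node addresses that appear in the whitelist.
    return [addr for addr in discovered if addr in allowed_resolved]


allowed = resolve_allowed(['127.0.0.1'])
print(filter_live(['127.0.0.1', '10.0.0.7'], allowed))  # ['127.0.0.1']
```
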
""" self._allowed_hosts = hosts self._allowed_hosts_resolved = [endpoint[4][0] for a in self._allowed_hosts for endpoint in socket.getaddrinfo(a, None, socket.AF_UNSPEC, socket.SOCK_STREAM)] RoundRobinPolicy.__init__(self) def populate(self, cluster, hosts): self._live_hosts = frozenset(h for h in hosts if h.address in self._allowed_hosts_resolved) if len(hosts) <= 1: self._position = 0 else: self._position = randint(0, len(hosts) - 1) def distance(self, host): if host.address in self._allowed_hosts_resolved: return HostDistance.LOCAL else: return HostDistance.IGNORED def on_up(self, host): if host.address in self._allowed_hosts_resolved: RoundRobinPolicy.on_up(self, host) def on_add(self, host): if host.address in self._allowed_hosts_resolved: RoundRobinPolicy.on_add(self, host) class ConvictionPolicy(object): """ A policy which decides when hosts should be considered down based on the types of failures and the number of failures. If custom behavior is needed, this class may be subclassed. """ def __init__(self, host): """ `host` is an instance of :class:`.Host`. """ self.host = host def add_failure(self, connection_exc): """ Implementations should return :const:`True` if the host should be convicted, :const:`False` otherwise. """ raise NotImplementedError() def reset(self): """ Implementations should clear out any convictions or state regarding the host. """ raise NotImplementedError() class SimpleConvictionPolicy(ConvictionPolicy): """ The default implementation of :class:`ConvictionPolicy`, which simply marks a host as down after the first failure of any kind other than a client-side :exc:`.OperationTimedOut`. """ def add_failure(self, connection_exc): return not isinstance(connection_exc, OperationTimedOut) def reset(self): pass class ReconnectionPolicy(object): """ This class and its subclasses govern how frequently an attempt is made to reconnect to nodes that are marked as dead. If custom behavior is needed, this class may be subclassed.
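
As an illustration of what a custom policy might look like, the sketch below implements a capped linear backoff through a ``new_schedule`` generator. ``LinearReconnectionSchedule`` is hypothetical; a real implementation would subclass :class:`.ReconnectionPolicy`, but the base class is omitted here so the sketch runs without the driver installed.

```python
from itertools import islice


class LinearReconnectionSchedule(object):
    # Hypothetical policy: wait step, 2*step, 3*step, ... capped at max_delay.
    def __init__(self, step, max_delay, max_attempts=None):
        self.step = step
        self.max_delay = max_delay
        self.max_attempts = max_attempts

    def new_schedule(self):
        # Finite when max_attempts is set; infinite when it is None.
        attempt = 1
        while self.max_attempts is None or attempt <= self.max_attempts:
            yield min(self.step * attempt, self.max_delay)
            attempt += 1


print(list(LinearReconnectionSchedule(1.0, 2.5, max_attempts=4).new_schedule()))
# [1.0, 2.0, 2.5, 2.5]
print(list(islice(LinearReconnectionSchedule(0.5, 10.0).new_schedule(), 3)))
# [0.5, 1.0, 1.5]
```

Using a generator keeps the infinite case cheap: delays are only computed as the reconnection loop consumes them.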
""" def new_schedule(self): """ This should return a finite or infinite iterable of delays (each as a floating point number of seconds) between each failed reconnection attempt. Note that if the iterable is finite, reconnection attempts will cease once the iterable is exhausted. """ raise NotImplementedError() class ConstantReconnectionPolicy(ReconnectionPolicy): """ A :class:`.ReconnectionPolicy` subclass which sleeps for a fixed delay between each reconnection attempt. """ def __init__(self, delay, max_attempts=64): """ `delay` should be a floating point number of seconds to wait between each attempt. `max_attempts` should be the total number of attempts to be made before giving up, or :const:`None` to continue reconnection attempts forever. The default is 64. """ if delay < 0: raise ValueError("delay must not be negative") if max_attempts is not None and max_attempts < 0: raise ValueError("max_attempts must not be negative") self.delay = delay self.max_attempts = max_attempts def new_schedule(self): if self.max_attempts is not None: return repeat(self.delay, self.max_attempts) return repeat(self.delay) class ExponentialReconnectionPolicy(ReconnectionPolicy): """ A :class:`.ReconnectionPolicy` subclass which exponentially increases the length of the delay between each reconnection attempt up to a set maximum delay. """ # TODO: max_attempts is 64 to preserve legacy default behavior # consider changing to None in major release to prevent the policy # giving up forever def __init__(self, base_delay, max_delay, max_attempts=64): """ `base_delay` and `max_delay` should be in floating point units of seconds. `max_attempts` should be the total number of attempts to be made before giving up, or :const:`None` to continue reconnection attempts forever. The default is 64.
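
The schedule this policy produces follows ``min(base_delay * 2**i, max_delay)`` for attempt ``i``. A standalone generator reproducing that formula (the name ``exponential_schedule`` is hypothetical) makes the capping behaviour easy to see:

```python
def exponential_schedule(base_delay, max_delay, max_attempts=64):
    # base_delay * 2**i, capped at max_delay; None means never stop.
    i = 0
    while max_attempts is None or i < max_attempts:
        yield min(base_delay * (2 ** i), max_delay)
        i += 1


print(list(exponential_schedule(1.0, 10.0, max_attempts=6)))
# [1.0, 2.0, 4.0, 8.0, 10.0, 10.0]
```
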
""" if base_delay < 0 or max_delay < 0: raise ValueError("Delays may not be negative") if max_delay < base_delay: raise ValueError("Max delay must be greater than or equal to base delay") if max_attempts is not None and max_attempts < 0: raise ValueError("max_attempts must not be negative") self.base_delay = base_delay self.max_delay = max_delay self.max_attempts = max_attempts def new_schedule(self): i = 0 while self.max_attempts is None or i < self.max_attempts: yield min(self.base_delay * (2 ** i), self.max_delay) i += 1 class WriteType(object): """ For usage with :class:`.RetryPolicy`, this describes a type of write operation. """ SIMPLE = 0 """ A write to a single partition key. Such writes are guaranteed to be atomic and isolated. """ BATCH = 1 """ A write to multiple partition keys that used the distributed batch log to ensure atomicity. """ UNLOGGED_BATCH = 2 """ A write to multiple partition keys that did not use the distributed batch log. Atomicity for such writes is not guaranteed. """ COUNTER = 3 """ A counter write (for one or multiple partition keys). Such writes should not be replayed in order to avoid overcounting. """ BATCH_LOG = 4 """ The initial write to the distributed batch log that Cassandra performs internally before a BATCH write. """ CAS = 5 """ A lightweight-transaction write, such as "DELETE ... IF EXISTS". """ WriteType.name_to_value = { 'SIMPLE': WriteType.SIMPLE, 'BATCH': WriteType.BATCH, 'UNLOGGED_BATCH': WriteType.UNLOGGED_BATCH, 'COUNTER': WriteType.COUNTER, 'BATCH_LOG': WriteType.BATCH_LOG, 'CAS': WriteType.CAS } class RetryPolicy(object): """ A policy that describes whether to retry, rethrow, or ignore coordinator timeout and unavailable failures. These are failures reported from the server side. Timeouts are configured by `settings in cassandra.yaml `_. Unavailable failures occur when the coordinator cannot achieve the consistency level for a request. For further information see the method descriptions below.
To specify a default retry policy, set the :attr:`.Cluster.default_retry_policy` attribute to an instance of this class or one of its subclasses. To specify a retry policy per query, set the :attr:`.Statement.retry_policy` attribute to an instance of this class or one of its subclasses. If custom behavior is needed for retrying certain operations, this class may be subclassed. """ RETRY = 0 """ This should be returned from the below methods if the operation should be retried on the same connection. """ RETHROW = 1 """ This should be returned from the below methods if the failure should be propagated and no more retries attempted. """ IGNORE = 2 """ This should be returned from the below methods if the failure should be ignored but no more retries should be attempted. """ RETRY_NEXT_HOST = 3 """ This should be returned from the below methods if the operation should be retried on another connection. """ def on_read_timeout(self, query, consistency, required_responses, received_responses, data_retrieved, retry_num): """ This is called when a read operation times out from the coordinator's perspective (i.e. a replica did not respond to the coordinator in time). It should return a tuple with two items: one of the class enums (such as :attr:`.RETRY`) and a :class:`.ConsistencyLevel` to retry the operation at or :const:`None` to keep the same consistency level. `query` is the :class:`.Statement` that timed out. `consistency` is the :class:`.ConsistencyLevel` that the operation was attempted at. The `required_responses` and `received_responses` parameters describe how many replicas needed to respond to meet the requested consistency level and how many actually did respond before the coordinator timed out the request. `data_retrieved` is a boolean indicating whether any of those responses contained data (as opposed to just a digest). `retry_num` counts how many times the operation has been retried, so the first time this method is called, `retry_num` will be 0. 
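
To make these parameters concrete, the sketch below mirrors the default read-timeout decision of this class as a standalone function. ``default_read_timeout_decision`` is hypothetical, the consistency level is a plain string passed through unchanged, and the ``RETRY``/``RETHROW`` values are copied from the class constants above.

```python
RETRY, RETHROW = 0, 1  # same values as RetryPolicy.RETRY / RetryPolicy.RETHROW


def default_read_timeout_decision(consistency, required_responses,
                                  received_responses, data_retrieved, retry_num):
    # Only the first timeout is ever retried.
    if retry_num != 0:
        return RETHROW, None
    # Enough replicas answered but none returned the actual data (digests
    # only), so a single retry at the same level is likely to succeed.
    if received_responses >= required_responses and not data_retrieved:
        return RETRY, consistency
    return RETHROW, None


print(default_read_timeout_decision('QUORUM', 2, 2, False, 0))  # (0, 'QUORUM')
print(default_read_timeout_decision('QUORUM', 2, 1, False, 0))  # (1, None)
```
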
By default, operations will be retried at most once, and only if enough replicas responded but none of the responses contained the actual data (only data digests). """ if retry_num != 0: return self.RETHROW, None elif received_responses >= required_responses and not data_retrieved: return self.RETRY, consistency else: return self.RETHROW, None def on_write_timeout(self, query, consistency, write_type, required_responses, received_responses, retry_num): """ This is called when a write operation times out from the coordinator's perspective (i.e. a replica did not respond to the coordinator in time). `query` is the :class:`.Statement` that timed out. `consistency` is the :class:`.ConsistencyLevel` that the operation was attempted at. `write_type` is one of the :class:`.WriteType` enums describing the type of write operation. The `required_responses` and `received_responses` parameters describe how many replicas needed to acknowledge the write to meet the requested consistency level and how many replicas actually did acknowledge the write before the coordinator timed out the request. `retry_num` counts how many times the operation has been retried, so the first time this method is called, `retry_num` will be 0. By default, failed write operations will be retried at most once, and they will only be retried if the `write_type` was :attr:`~.WriteType.BATCH_LOG`. """ if retry_num != 0: return self.RETHROW, None elif write_type == WriteType.BATCH_LOG: return self.RETRY, consistency else: return self.RETHROW, None def on_unavailable(self, query, consistency, required_replicas, alive_replicas, retry_num): """ This is called when the coordinator node determines that a read or write operation cannot be successful because the number of live replicas is too low to meet the requested :class:`.ConsistencyLevel`. This means that the read or write operation was never forwarded to any replicas. `query` is the :class:`.Statement` that failed.
`consistency` is the :class:`.ConsistencyLevel` that the operation was attempted at. `required_replicas` is the number of replicas that would have needed to acknowledge the operation to meet the requested consistency level. `alive_replicas` is the number of replicas that the coordinator considered alive at the time of the request. `retry_num` counts how many times the operation has been retried, so the first time this method is called, `retry_num` will be 0. By default, the first unavailable error triggers a single retry on the next host in the query plan at the same consistency level; any subsequent unavailable error is re-raised without further retries. """ return (self.RETRY_NEXT_HOST, consistency) if retry_num == 0 else (self.RETHROW, None) class FallthroughRetryPolicy(RetryPolicy): """ A retry policy that never retries and always propagates failures to the application. """ def on_read_timeout(self, *args, **kwargs): return self.RETHROW, None def on_write_timeout(self, *args, **kwargs): return self.RETHROW, None def on_unavailable(self, *args, **kwargs): return self.RETHROW, None class DowngradingConsistencyRetryPolicy(RetryPolicy): """ A retry policy that sometimes retries with a lower consistency level than the one initially requested. **BEWARE**: This policy may retry queries using a lower consistency level than the one initially requested. By doing so, it may break consistency guarantees. In other words, if you use this retry policy, there are cases (documented below) where a read at :attr:`~.QUORUM` *may not* see a preceding write at :attr:`~.QUORUM`. Do not use this policy unless you have understood the cases where this can happen and are OK with that. It is also recommended to subclass this class so that queries that required a consistency level downgrade can be recorded (so that repairs can be made later, etc).

    This policy implements the same retries as :class:`.RetryPolicy`,
    but on top of that, it also retries in the following cases:

    * On a read timeout: if the number of replicas that responded is at
      least one but lower than is required by the requested consistency
      level, the operation is retried at a lower consistency level.
    * On a write timeout: if the operation is an :attr:`~.UNLOGGED_BATCH`
      and at least one replica acknowledged the write, the operation is
      retried at a lower consistency level. Furthermore, for other
      write types, if at least one replica acknowledged the write, the
      timeout is ignored.
    * On an unavailable exception: if at least one replica is alive, the
      operation is retried at a lower consistency level.

    The reasoning behind this retry policy is as follows: if, based
    on the information the Cassandra coordinator node returns, retrying the
    operation with the initially requested consistency has a chance to
    succeed, do it. Otherwise, if based on that information we know the
    initially requested consistency level cannot be achieved currently, then:

    * For writes, ignore the exception (thus silently failing the
      consistency requirement) if we know the write has been persisted on at
      least one replica.
    * For reads, try reading at a lower consistency level (thus silently
      failing the consistency requirement).

    In other words, this policy implements the idea that if the requested
    consistency level cannot be achieved, the next best thing for writes is
    to make sure the data is persisted, and that reading something is better
    than reading nothing, even if there is a risk of reading stale data.
    """
    def _pick_consistency(self, num_responses):
        if num_responses >= 3:
            return self.RETRY, ConsistencyLevel.THREE
        elif num_responses >= 2:
            return self.RETRY, ConsistencyLevel.TWO
        elif num_responses >= 1:
            return self.RETRY, ConsistencyLevel.ONE
        else:
            return self.RETHROW, None

    def on_read_timeout(self, query, consistency, required_responses,
                        received_responses, data_retrieved, retry_num):
        if retry_num != 0:
            return self.RETHROW, None
        elif received_responses < required_responses:
            return self._pick_consistency(received_responses)
        elif not data_retrieved:
            return self.RETRY, consistency
        else:
            return self.RETHROW, None

    def on_write_timeout(self, query, consistency, write_type,
                         required_responses, received_responses, retry_num):
        if retry_num != 0:
            return self.RETHROW, None

        if write_type in (WriteType.SIMPLE, WriteType.BATCH, WriteType.COUNTER):
            if received_responses > 0:
                # persisted on at least one replica
                return self.IGNORE, None
            else:
                return self.RETHROW, None
        elif write_type == WriteType.UNLOGGED_BATCH:
            return self._pick_consistency(received_responses)
        elif write_type == WriteType.BATCH_LOG:
            return self.RETRY, consistency

        return self.RETHROW, None

    def on_unavailable(self, query, consistency, required_replicas, alive_replicas, retry_num):
        if retry_num != 0:
            return self.RETHROW, None
        else:
            return self._pick_consistency(alive_replicas)


class AddressTranslator(object):
    """
    Interface for translating cluster-defined endpoints.

    The driver discovers nodes using server metadata and topology change events. Normally,
    the endpoint defined by the server is the right way to connect to a node. In some environments,
    these addresses may not be reachable, or not preferred (public vs. private IPs in cloud environments,
    suboptimal routing, etc). This interface allows for translating from server defined endpoints to
    preferred addresses for driver connections.

    *Note:* :attr:`~Cluster.contact_points` provided while creating the :class:`~.Cluster` instance are not
    translated using this mechanism -- only addresses received from Cassandra nodes are.
    """
    def translate(self, addr):
        """
        Accepts the node IP address, and returns a translated address to be used when connecting to this node.
        """
        raise NotImplementedError()


class IdentityTranslator(AddressTranslator):
    """
    Returns the endpoint with no translation
    """
    def translate(self, addr):
        return addr


class EC2MultiRegionTranslator(AddressTranslator):
    """
    Resolves private IPs of the hosts in the same datacenter as the client, and public IPs of hosts in other datacenters.
    """
    def translate(self, addr):
        """
        Reverse DNS the public broadcast_address, then look up that hostname to get the AWS-resolved IP, which
        will point to the private IP address within the same datacenter.
        """
        # get family of this address so we translate to the same
        family = socket.getaddrinfo(addr, 0, socket.AF_UNSPEC, socket.SOCK_STREAM)[0][0]
        host = socket.getfqdn(addr)
        for a in socket.getaddrinfo(host, 0, family, socket.SOCK_STREAM):
            try:
                return a[4][0]
            except Exception:
                pass
        return addr


class SpeculativeExecutionPolicy(object):
    """
    Interface for specifying speculative execution plans
    """

    def new_plan(self, keyspace, statement):
        """
        Returns a new :class:`.SpeculativeExecutionPlan` for the given
        keyspace and statement.

        :param keyspace: the keyspace associated with the request
        :param statement: the statement being executed
        :return: a :class:`.SpeculativeExecutionPlan`
        """
        raise NotImplementedError()


class SpeculativeExecutionPlan(object):
    def next_execution(self, host):
        raise NotImplementedError()


class NoSpeculativeExecutionPlan(SpeculativeExecutionPlan):
    def next_execution(self, host):
        return -1


class NoSpeculativeExecutionPolicy(SpeculativeExecutionPolicy):

    def new_plan(self, keyspace, statement):
        return NoSpeculativeExecutionPlan()


class ConstantSpeculativeExecutionPolicy(SpeculativeExecutionPolicy):
    """
    A speculative execution policy that sends a new query every X seconds (**delay**) for a maximum of Y attempts (**max_attempts**).
    """

    def __init__(self, delay, max_attempts):
        self.delay = delay
        self.max_attempts = max_attempts

    class ConstantSpeculativeExecutionPlan(SpeculativeExecutionPlan):
        def __init__(self, delay, max_attempts):
            self.delay = delay
            self.remaining = max_attempts

        def next_execution(self, host):
            if self.remaining > 0:
                self.remaining -= 1
                return self.delay
            else:
                return -1

    def new_plan(self, keyspace, statement):
        return self.ConstantSpeculativeExecutionPlan(self.delay, self.max_attempts)
cassandra-driver-3.7.1/cassandra/tuple.pxd
# Copyright 2013-2016 DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from cpython.tuple cimport (
        PyTuple_New,
            # Return value: New reference.
            # Return a new tuple object of size len, or NULL on failure.
        PyTuple_SET_ITEM,
            # Like PyTuple_SetItem(), but does no error checking, and should
            # only be used to fill in brand new tuples. Note: This function
            # ``steals'' a reference to o.
        )

from cpython.ref cimport (
        Py_INCREF
            # void Py_INCREF(object o)
            # Increment the reference count for object o. The object must not
            # be NULL; if you aren't sure that it isn't NULL, use
            # Py_XINCREF().
        )

cdef inline tuple tuple_new(Py_ssize_t n):
    """Allocate a new tuple object"""
    return PyTuple_New(n)

cdef inline void tuple_set(tuple tup, Py_ssize_t idx, object item):
    """Insert new object into tuple.
    The slot at idx must not have been set yet."""
    # PyTuple_SET_ITEM steals a reference, so we need to INCREF
    Py_INCREF(item)
    PyTuple_SET_ITEM(tup, idx, item)
cassandra-driver-3.7.1/cassandra/deserializers.pyx
# Copyright 2013-2016 DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from libc.stdint cimport int32_t, uint16_t

include 'cython_marshal.pyx'
from cassandra.buffer cimport Buffer, to_bytes, slice_buffer
from cassandra.cython_utils cimport datetime_from_timestamp

from cython.view cimport array as cython_array
from cassandra.tuple cimport tuple_new, tuple_set

import socket
from decimal import Decimal
from uuid import UUID

from cassandra import cqltypes
from cassandra import util

cdef bint PY2 = six.PY2

cdef class Deserializer:
    """Cython-based deserializer class for a cqltype"""

    def __init__(self, cqltype):
        self.cqltype = cqltype
        self.empty_binary_ok = cqltype.empty_binary_ok

    cdef deserialize(self, Buffer *buf, int protocol_version):
        raise NotImplementedError


cdef class DesBytesType(Deserializer):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        if buf.size == 0:
            return b""
        return to_bytes(buf)

# this is to facilitate cqlsh integration, which requires bytearrays for BytesType
# It is switched in by simply overwriting DesBytesType:
# deserializers.DesBytesType = deserializers.DesBytesTypeByteArray
cdef class DesBytesTypeByteArray(Deserializer):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        if buf.size == 0:
            return bytearray()
        return bytearray(buf.ptr[:buf.size])

# TODO: Use libmpdec: http://www.bytereef.org/mpdecimal/index.html
cdef class DesDecimalType(Deserializer):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        cdef Buffer varint_buf
        slice_buffer(buf, &varint_buf, 4, buf.size - 4)

        cdef int32_t scale = unpack_num[int32_t](buf)
        unscaled = varint_unpack(&varint_buf)

        return Decimal('%de%d' % (unscaled, -scale))


cdef class DesUUIDType(Deserializer):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        return UUID(bytes=to_bytes(buf))


cdef class DesBooleanType(Deserializer):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        if unpack_num[int8_t](buf):
            return True
        return False


cdef class DesByteType(Deserializer):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        return unpack_num[int8_t](buf)


cdef class DesAsciiType(Deserializer):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        if buf.size == 0:
            return ""
        if PY2:
            return to_bytes(buf)
        return to_bytes(buf).decode('ascii')


cdef class DesFloatType(Deserializer):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        return unpack_num[float](buf)


cdef class DesDoubleType(Deserializer):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        return unpack_num[double](buf)


cdef class DesLongType(Deserializer):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        return unpack_num[int64_t](buf)


cdef class DesInt32Type(Deserializer):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        return unpack_num[int32_t](buf)


cdef class DesIntegerType(Deserializer):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        return varint_unpack(buf)


cdef class DesInetAddressType(Deserializer):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        cdef bytes byts = to_bytes(buf)

        # TODO: optimize inet_ntop, inet_ntoa
        if buf.size == 16:
            return util.inet_ntop(socket.AF_INET6, byts)
        else:
            # util.inet_pton could also handle, but this is faster
            # since we've already determined the AF
            return socket.inet_ntoa(byts)


cdef class DesCounterColumnType(DesLongType):
    pass


cdef class DesDateType(Deserializer):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        cdef double timestamp = unpack_num[int64_t](buf) / 1000.0
        return datetime_from_timestamp(timestamp)


cdef class TimestampType(DesDateType):
    pass


cdef class TimeUUIDType(DesDateType):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        return UUID(bytes=to_bytes(buf))


# Values of the 'date' type are encoded as 32-bit unsigned integers
# representing a number of days with epoch (January 1st, 1970) at the center of the
# range (2^31).
EPOCH_OFFSET_DAYS = 2 ** 31

cdef class DesSimpleDateType(Deserializer):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        days = unpack_num[uint32_t](buf) - EPOCH_OFFSET_DAYS
        return util.Date(days)


cdef class DesShortType(Deserializer):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        return unpack_num[int16_t](buf)


cdef class DesTimeType(Deserializer):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        return util.Time(unpack_num[int64_t](buf))


cdef class DesUTF8Type(Deserializer):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        if buf.size == 0:
            return ""
        cdef val = to_bytes(buf)
        return val.decode('utf8')


cdef class DesVarcharType(DesUTF8Type):
    pass


cdef class _DesParameterizedType(Deserializer):

    cdef object subtypes
    cdef Deserializer[::1] deserializers
    cdef Py_ssize_t subtypes_len

    def __init__(self, cqltype):
        super().__init__(cqltype)
        self.subtypes = cqltype.subtypes
        self.deserializers = make_deserializers(cqltype.subtypes)
        self.subtypes_len = len(self.subtypes)


cdef class _DesSingleParamType(_DesParameterizedType):
    cdef Deserializer deserializer

    def __init__(self, cqltype):
        assert cqltype.subtypes and len(cqltype.subtypes) == 1, cqltype.subtypes
        super().__init__(cqltype)
        self.deserializer = self.deserializers[0]

#--------------------------------------------------------------------------
# List and set deserialization

cdef class DesListType(_DesSingleParamType):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        cdef uint16_t v2_and_below = 2
        cdef int32_t v3_and_above = 3

        if protocol_version >= 3:
            result = _deserialize_list_or_set[int32_t](
                v3_and_above, buf, protocol_version, self.deserializer)
        else:
            result = _deserialize_list_or_set[uint16_t](
                v2_and_below, buf, protocol_version, self.deserializer)

        return result

cdef class DesSetType(DesListType):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        return util.sortedset(DesListType.deserialize(self, buf, protocol_version))


ctypedef fused itemlen_t:
    uint16_t  # protocol <= v2
    int32_t   # protocol >= v3

cdef list _deserialize_list_or_set(itemlen_t dummy_version,
                                   Buffer *buf, int protocol_version,
                                   Deserializer deserializer):
    """
    Deserialize a list or set.

    The 'dummy' parameter is needed to make fused types work, so that
    we can specialize on the protocol version.
    """
    cdef Buffer itemlen_buf
    cdef Buffer elem_buf

    cdef itemlen_t numelements
    cdef int offset
    cdef list result = []

    _unpack_len[itemlen_t](buf, 0, &numelements)
    offset = sizeof(itemlen_t)
    protocol_version = max(3, protocol_version)
    for _ in range(numelements):
        subelem[itemlen_t](buf, &elem_buf, &offset, dummy_version)
        result.append(from_binary(deserializer, &elem_buf, protocol_version))

    return result


cdef inline int subelem(
        Buffer *buf, Buffer *elem_buf, int* offset, itemlen_t dummy) except -1:
    """
    Read the next element from the buffer: first read the size (in bytes) of the
    element, then fill elem_buf with a newly sliced buffer of this size (and the
    right offset).
    """
    cdef itemlen_t elemlen

    _unpack_len[itemlen_t](buf, offset[0], &elemlen)
    offset[0] += sizeof(itemlen_t)
    slice_buffer(buf, elem_buf, offset[0], elemlen)
    offset[0] += elemlen
    return 0


cdef int _unpack_len(Buffer *buf, int offset, itemlen_t *output) except -1:
    cdef Buffer itemlen_buf
    slice_buffer(buf, &itemlen_buf, offset, sizeof(itemlen_t))

    if itemlen_t is uint16_t:
        output[0] = unpack_num[uint16_t](&itemlen_buf)
    else:
        output[0] = unpack_num[int32_t](&itemlen_buf)

    return 0

#--------------------------------------------------------------------------
# Map deserialization

cdef class DesMapType(_DesParameterizedType):

    cdef Deserializer key_deserializer, val_deserializer

    def __init__(self, cqltype):
        super().__init__(cqltype)
        self.key_deserializer = self.deserializers[0]
        self.val_deserializer = self.deserializers[1]

    cdef deserialize(self, Buffer *buf, int protocol_version):
        cdef uint16_t v2_and_below = 0
        cdef int32_t v3_and_above = 0
        key_type, val_type = self.cqltype.subtypes

        if protocol_version >= 3:
            result = _deserialize_map[int32_t](
                v3_and_above, buf, protocol_version,
                self.key_deserializer, self.val_deserializer,
                key_type, val_type)
        else:
            result = _deserialize_map[uint16_t](
                v2_and_below, buf, protocol_version,
                self.key_deserializer, self.val_deserializer,
                key_type, val_type)

        return result


cdef _deserialize_map(itemlen_t dummy_version,
                      Buffer *buf, int protocol_version,
                      Deserializer key_deserializer, Deserializer val_deserializer,
                      key_type, val_type):
    cdef Buffer key_buf, val_buf
    cdef Buffer itemlen_buf

    cdef itemlen_t numelements
    cdef int offset
    cdef list result = []

    _unpack_len[itemlen_t](buf, 0, &numelements)
    offset = sizeof(itemlen_t)
    themap = util.OrderedMapSerializedKey(key_type, protocol_version)
    protocol_version = max(3, protocol_version)
    for _ in range(numelements):
        subelem[itemlen_t](buf, &key_buf, &offset, dummy_version)
        subelem[itemlen_t](buf, &val_buf, &offset, numelements)
        key = from_binary(key_deserializer, &key_buf, protocol_version)
        val = from_binary(val_deserializer, &val_buf, protocol_version)
        themap._insert_unchecked(key, to_bytes(&key_buf), val)

    return themap

#--------------------------------------------------------------------------

cdef class DesTupleType(_DesParameterizedType):

    # TODO: Use TupleRowParser to parse these tuples

    cdef deserialize(self, Buffer *buf, int protocol_version):
        cdef Py_ssize_t i, p
        cdef int32_t itemlen
        cdef tuple res = tuple_new(self.subtypes_len)
        cdef Buffer item_buf
        cdef Buffer itemlen_buf
        cdef Deserializer deserializer

        # collections inside UDTs are always encoded with at least the
        # version 3 format
        protocol_version = max(3, protocol_version)

        p = 0
        values = []
        for i in range(self.subtypes_len):
            item = None
            if p < buf.size:
                slice_buffer(buf, &itemlen_buf, p, 4)
                itemlen = unpack_num[int32_t](&itemlen_buf)
                p += 4
                if itemlen >= 0:
                    slice_buffer(buf, &item_buf, p, itemlen)
                    p += itemlen

                    deserializer = self.deserializers[i]
                    item = from_binary(deserializer, &item_buf, protocol_version)

            tuple_set(res, i, item)

        return res


cdef class DesUserType(DesTupleType):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        typ = self.cqltype
        values = DesTupleType.deserialize(self, buf, protocol_version)
        if typ.mapped_class:
            return typ.mapped_class(**dict(zip(typ.fieldnames, values)))
        elif typ.tuple_type:
            return typ.tuple_type(*values)
        else:
            return tuple(values)


cdef class DesCompositeType(_DesParameterizedType):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        cdef Py_ssize_t i, idx, start
        cdef Buffer elem_buf
        cdef int16_t element_length
        cdef Deserializer deserializer
        cdef tuple res = tuple_new(self.subtypes_len)

        idx = 0
        for i in range(self.subtypes_len):
            if not buf.size:
                # CompositeType can have missing elements at the end

                # Fill the tuple with None values and slice it
                #
                # (I'm not sure a tuple needs to be fully initialized before
                #  it can be destroyed, so play it safe)
                for j in range(i, self.subtypes_len):
                    tuple_set(res, j, None)
                res = res[:i]
                break

            element_length = unpack_num[uint16_t](buf)
            slice_buffer(buf, &elem_buf, 2, element_length)

            deserializer = self.deserializers[i]
            item = from_binary(deserializer, &elem_buf, protocol_version)
            tuple_set(res, i, item)

            # skip element length, element, and the EOC (one byte)
            start = 2 + element_length + 1
            slice_buffer(buf, buf, start, buf.size - start)

        return res


DesDynamicCompositeType = DesCompositeType


cdef class DesReversedType(_DesSingleParamType):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        return from_binary(self.deserializer, buf, protocol_version)


cdef class DesFrozenType(_DesSingleParamType):
    cdef deserialize(self, Buffer *buf, int protocol_version):
        return from_binary(self.deserializer, buf, protocol_version)

#--------------------------------------------------------------------------

cdef _ret_empty(Deserializer deserializer, Py_ssize_t buf_size):
    """
    Decide whether to return None or EMPTY when a buffer size is
    zero or negative. This is used by from_binary in deserializers.pxd.
    """
    if buf_size < 0:
        return None
    elif deserializer.cqltype.support_empty_values:
        return cqltypes.EMPTY
    else:
        return None

#--------------------------------------------------------------------------
# Generic deserialization

cdef class GenericDeserializer(Deserializer):
    """
    Wrap a generic datatype for deserialization
    """

    cdef deserialize(self, Buffer *buf, int protocol_version):
        return self.cqltype.deserialize(to_bytes(buf), protocol_version)

    def __repr__(self):
        return "GenericDeserializer(%s)" % (self.cqltype,)

#--------------------------------------------------------------------------
# Helper utilities

def make_deserializers(cqltypes):
    """Create an array of Deserializers for each given cqltype in cqltypes"""
    cdef Deserializer[::1] deserializers
    return obj_array([find_deserializer(ct) for ct in cqltypes])


cdef dict classes = globals()

cpdef Deserializer find_deserializer(cqltype):
    """Find a deserializer for a cqltype"""
    name = 'Des' + cqltype.__name__

    if name in globals():
        cls = classes[name]
    elif issubclass(cqltype, cqltypes.ListType):
        cls = DesListType
    elif issubclass(cqltype, cqltypes.SetType):
        cls = DesSetType
    elif issubclass(cqltype, cqltypes.MapType):
        cls = DesMapType
    elif issubclass(cqltype, cqltypes.UserType):
        # UserType is a subclass of TupleType, so should precede it
        cls = DesUserType
    elif issubclass(cqltype, cqltypes.TupleType):
        cls = DesTupleType
    elif issubclass(cqltype, cqltypes.DynamicCompositeType):
        # DynamicCompositeType is a subclass of CompositeType, so should precede it
        cls = DesDynamicCompositeType
    elif issubclass(cqltype, cqltypes.CompositeType):
        cls = DesCompositeType
    elif issubclass(cqltype, cqltypes.ReversedType):
        cls = DesReversedType
    elif issubclass(cqltype, cqltypes.FrozenType):
        cls = DesFrozenType
    else:
        cls = GenericDeserializer

    return cls(cqltype)


def obj_array(list objs):
    """Create a (Cython) array of objects given a list of objects"""
    cdef object[:] arr
    cdef Py_ssize_t i
    arr = cython_array(shape=(len(objs),), itemsize=sizeof(void *), format="O")
    # arr[:] = objs  # This does not work (segmentation faults)
    for i, obj in enumerate(objs):
        arr[i] = obj
    return arr
cassandra-driver-3.7.1/cassandra/type_codes.py
"""
Module with constants for Cassandra type codes.

These constants are useful for

    a) mapping messages to cqltypes                 (cassandra/cqltypes.py)
    b) optimized dispatching for (de)serialization  (cassandra/encoding.py)

Type codes are repeated here from the Cassandra binary protocol specification:

            0x0000    Custom: the value is a [string], see above.
            0x0001    Ascii
            0x0002    Bigint
            0x0003    Blob
            0x0004    Boolean
            0x0005    Counter
            0x0006    Decimal
            0x0007    Double
            0x0008    Float
            0x0009    Int
            0x000A    Text
            0x000B    Timestamp
            0x000C    Uuid
            0x000D    Varchar
            0x000E    Varint
            0x000F    Timeuuid
            0x0010    Inet
            0x0020    List: the value is an [option], representing the type
                            of the elements of the list.
            0x0021    Map: the value is two [option], representing the types of the
                           keys and values of the map
            0x0022    Set: the value is an [option], representing the type
                            of the elements of the set
"""

CUSTOM_TYPE = 0x0000
AsciiType = 0x0001
LongType = 0x0002
BytesType = 0x0003
BooleanType = 0x0004
CounterColumnType = 0x0005
DecimalType = 0x0006
DoubleType = 0x0007
FloatType = 0x0008
Int32Type = 0x0009
UTF8Type = 0x000A
DateType = 0x000B
UUIDType = 0x000C
VarcharType = 0x000D
IntegerType = 0x000E
TimeUUIDType = 0x000F
InetAddressType = 0x0010
SimpleDateType = 0x0011
TimeType = 0x0012
ShortType = 0x0013
ByteType = 0x0014
ListType = 0x0020
MapType = 0x0021
SetType = 0x0022
UserType = 0x0030
TupleType = 0x0031
cassandra-driver-3.7.1/cassandra/bytesio.pxd
# Copyright 2013-2016 DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

cdef class BytesIOReader:
    cdef bytes buf
    cdef char *buf_ptr
    cdef Py_ssize_t pos
    cdef Py_ssize_t size
    cdef char *read(self, Py_ssize_t n = ?) except NULL
cassandra-driver-3.7.1/cassandra/cqlengine/
cassandra-driver-3.7.1/cassandra/cqlengine/functions.py
# Copyright 2013-2016 DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import division
from datetime import datetime

from cassandra.cqlengine import UnicodeMixin, ValidationError

import sys

if sys.version_info >= (2, 7):
    def get_total_seconds(td):
        return td.total_seconds()
else:
    def get_total_seconds(td):
        # true division (enabled by the __future__ import above) reproduces
        # the float result of the built-in total_seconds()
        return ((86400 * td.days + td.seconds) * 10 ** 6 + td.microseconds) / 10 ** 6


class QueryValue(UnicodeMixin):
    """
    Base class for query filter values. Subclasses of these classes can
    be passed into .filter() keyword args
    """

    format_string = '%({0})s'

    def __init__(self, value):
        self.value = value
        self.context_id = None

    def __unicode__(self):
        return self.format_string.format(self.context_id)

    def set_context_id(self, ctx_id):
        self.context_id = ctx_id

    def get_context_size(self):
        return 1

    def update_context(self, ctx):
        ctx[str(self.context_id)] = self.value


class BaseQueryFunction(QueryValue):
    """
    Base class for filtering functions.

    Subclasses of these classes can be passed into .filter() and will be
    translated into CQL functions in the resulting query
    """
    pass


class TimeUUIDQueryFunction(BaseQueryFunction):

    def __init__(self, value):
        """
        :param value: the time to create bounding time uuid from
        :type value: datetime
        """
        if not isinstance(value, datetime):
            raise ValidationError('datetime instance is required')
        super(TimeUUIDQueryFunction, self).__init__(value)

    def to_database(self, val):
        epoch = datetime(1970, 1, 1, tzinfo=val.tzinfo)
        offset = get_total_seconds(epoch.tzinfo.utcoffset(epoch)) if epoch.tzinfo else 0
        return int((get_total_seconds(val - epoch) - offset) * 1000)

    def update_context(self, ctx):
        ctx[str(self.context_id)] = self.to_database(self.value)


class MinTimeUUID(TimeUUIDQueryFunction):
    """
    return a fake timeuuid corresponding to the smallest possible timeuuid for the given timestamp

    http://cassandra.apache.org/doc/cql3/CQL-3.0.html#timeuuidFun
    """
    format_string = 'MinTimeUUID(%({0})s)'


class MaxTimeUUID(TimeUUIDQueryFunction):
    """
    return a fake timeuuid corresponding to the largest possible timeuuid for the given timestamp

    http://cassandra.apache.org/doc/cql3/CQL-3.0.html#timeuuidFun
    """
    format_string = 'MaxTimeUUID(%({0})s)'


class Token(BaseQueryFunction):
    """
    compute the token for a given partition key

    http://cassandra.apache.org/doc/cql3/CQL-3.0.html#tokenFun
    """
    def __init__(self, *values):
        if len(values) == 1 and isinstance(values[0], (list, tuple)):
            values = values[0]
        super(Token, self).__init__(values)
        self._columns = None

    def set_columns(self, columns):
        self._columns = columns

    def get_context_size(self):
        return len(self.value)

    def __unicode__(self):
        token_args = ', '.join('%({0})s'.format(self.context_id + i) for i in range(self.get_context_size()))
        return "token({0})".format(token_args)

    def update_context(self, ctx):
        for i, (col, val) in enumerate(zip(self._columns, self.value)):
            ctx[str(self.context_id + i)] = col.to_database(val)
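
The millisecond conversion in ``TimeUUIDQueryFunction.to_database`` can be checked with a small standalone sketch. The sketch copies the arithmetic instead of importing cqlengine, and calls ``timedelta.total_seconds`` directly in place of ``get_total_seconds``. ``to_millis`` is a hypothetical name.

```python
from datetime import datetime, timedelta, timezone


def to_millis(val):
    # Same arithmetic as TimeUUIDQueryFunction.to_database: milliseconds since
    # the Unix epoch, corrected by the UTC offset of val's tzinfo when present.
    epoch = datetime(1970, 1, 1, tzinfo=val.tzinfo)
    offset = epoch.tzinfo.utcoffset(epoch).total_seconds() if epoch.tzinfo else 0
    return int(((val - epoch).total_seconds() - offset) * 1000)


# A naive datetime is effectively treated as UTC: one second after the epoch.
print(to_millis(datetime(1970, 1, 1, 0, 0, 1)))  # 1000

# 02:00 at UTC+2 is 00:00 UTC, i.e. the epoch itself.
plus_two = timezone(timedelta(hours=2))
print(to_millis(datetime(1970, 1, 1, 2, 0, 0, tzinfo=plus_two)))  # 0
```

The sketch applies the offset of ``epoch`` rather than that of ``val``. For a fixed-offset tzinfo the two are identical. For a zone with daylight-saving transitions they could differ.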
cassandra-driver-3.7.1/cassandra/cqlengine/usertype.py
import re
import six

from cassandra.util import OrderedDict
from cassandra.cqlengine import CQLEngineException
from cassandra.cqlengine import columns
from cassandra.cqlengine import connection as conn
from cassandra.cqlengine import models


class UserTypeException(CQLEngineException):
    pass


class UserTypeDefinitionException(UserTypeException):
    pass


class BaseUserType(object):
    """
    The base type class; don't inherit from this, inherit from UserType, defined below
    """
    __type_name__ = None

    _fields = None
    _db_map = None

    def __init__(self, **values):
        self._values = {}
        if self._db_map:
            values = dict((self._db_map.get(k, k), v) for k, v in values.items())

        for name, field in self._fields.items():
            field_default = field.get_default() if field.has_default else None
            value = values.get(name, field_default)
            if value is not None or isinstance(field, columns.BaseContainerColumn):
                value = field.to_python(value)
            value_mngr = field.value_manager(self, field, value)
            value_mngr.explicit = name in values
            self._values[name] = value_mngr

    def __eq__(self, other):
        if self.__class__ != other.__class__:
            return False

        keys = set(self._fields.keys())
        other_keys = set(other._fields.keys())
        if keys != other_keys:
            return False

        for key in other_keys:
            if getattr(self, key, None) != getattr(other, key, None):
                return False

        return True

    def __ne__(self, other):
        return not self.__eq__(other)

    def __str__(self):
        return "{{{0}}}".format(', '.join("'{0}': {1}".format(k, getattr(self, k)) for k, v in six.iteritems(self._values)))

    def has_changed_fields(self):
        return any(v.changed for v in self._values.values())

    def reset_changed_fields(self):
        for v in self._values.values():
            v.reset_previous_value()

    def __iter__(self):
        for field in self._fields.keys():
            yield field

    def __getattr__(self, attr):
        # provides the mapping from db_field to fields
        try:
            return getattr(self, self._db_map[attr])
        except KeyError:
            raise AttributeError(attr)

    def __getitem__(self, key):
        if not isinstance(key, six.string_types):
            raise TypeError
        if key not in self._fields.keys():
            raise KeyError
        return getattr(self, key)

    def __setitem__(self, key, val):
        if not isinstance(key, six.string_types):
            raise TypeError
        if key not in self._fields.keys():
            raise KeyError
        return setattr(self, key, val)

    def __len__(self):
        try:
            return self._len
        except:
            self._len = len(self._fields.keys())
            return self._len

    def keys(self):
        """ Returns a list of column IDs. """
        return [k for k in self]

    def values(self):
        """ Returns list of column values. """
        return [self[k] for k in self]

    def items(self):
        """ Returns a list of column ID/value tuples. """
        return [(k, self[k]) for k in self]

    @classmethod
    def register_for_keyspace(cls, keyspace, connection=None):
        conn.register_udt(keyspace, cls.type_name(), cls, connection=connection)

    @classmethod
    def type_name(cls):
        """
        Returns the type name if it's been defined
        otherwise, it creates it from the class name
        """
        if cls.__type_name__:
            type_name = cls.__type_name__.lower()
        else:
            camelcase = re.compile(r'([a-z])([A-Z])')
            ccase = lambda s: camelcase.sub(lambda v: '{0}_{1}'.format(v.group(1), v.group(2)), s)

            type_name = ccase(cls.__name__)
            # keep at most the last 48 characters or cassandra will complain
            type_name = type_name[-48:]
            type_name = type_name.lower()
            type_name = re.sub(r'^_+', '', type_name)
            cls.__type_name__ = type_name

        return type_name

    def validate(self):
        """
        Cleans and validates the field values
        """
        for name, field in self._fields.items():
            v = getattr(self, name)
            if v is None and not self._values[name].explicit and field.has_default:
                v = field.get_default()
            val = field.validate(v)
            setattr(self, name, val)


class UserTypeMetaClass(type):

    def __new__(cls, name, bases, attrs):
        field_dict = OrderedDict()

        field_defs = [(k, v) for k, v in attrs.items() if isinstance(v, columns.Column)]
        field_defs = sorted(field_defs, key=lambda x: x[1].position)

        def _transform_column(field_name, field_obj):
            field_dict[field_name] = field_obj
            field_obj.set_column_name(field_name)
            attrs[field_name] = models.ColumnDescriptor(field_obj)

        # transform field definitions
        for k, v in field_defs:
            # don't allow a field with the same name as a built-in attribute or method
            if k in BaseUserType.__dict__:
                raise UserTypeDefinitionException("field '{0}' conflicts with built-in attribute/method".format(k))
            _transform_column(k, v)

        attrs['_fields'] = field_dict

        db_map = {}
        for field_name, field in field_dict.items():
            db_field = field.db_field_name
            if db_field != field_name:
                if db_field in field_dict:
                    raise UserTypeDefinitionException("db_field '{0}' for field '{1}' conflicts with another attribute name".format(db_field, field_name))
                db_map[db_field] = field_name
        attrs['_db_map'] = db_map

        klass = super(UserTypeMetaClass, cls).__new__(cls, name, bases, attrs)

        return klass


@six.add_metaclass(UserTypeMetaClass)
class UserType(BaseUserType):
    """
    This class is used to model User Defined Types. To define a type, declare a class inheriting from this,
    and assign field types as class attributes:

    .. code-block:: python

        # connect with default keyspace ...

        from cassandra.cqlengine.columns import Text, Integer
        from cassandra.cqlengine.usertype import UserType

        class address(UserType):
            street = Text()
            zipcode = Integer()

        from cassandra.cqlengine import management
        management.sync_type(address)

    Please see :ref:`user_types` for a complete example and discussion.
    """

    __type_name__ = None
    """
    *Optional.* Sets the name of the CQL type for this type.

    If not specified, the type name is derived from the class name:
    CamelCase is converted to snake_case, truncated to its last 48
    characters, and lowercased, with any leading underscores stripped.
    """
cassandra-driver-3.7.1/cassandra/cqlengine/query.py
# Copyright 2013-2016 DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
from datetime import datetime, timedelta
from functools import partial
import time
import six
from warnings import warn

from cassandra.query import SimpleStatement
from cassandra.cqlengine import columns, CQLEngineException, ValidationError, UnicodeMixin
from cassandra.cqlengine import connection as conn
from cassandra.cqlengine.functions import Token, BaseQueryFunction, QueryValue
from cassandra.cqlengine.operators import (InOperator, EqualsOperator, GreaterThanOperator,
                                           GreaterThanOrEqualOperator, LessThanOperator,
                                           LessThanOrEqualOperator, ContainsOperator, BaseWhereOperator)
from cassandra.cqlengine.statements import (WhereClause, SelectStatement, DeleteStatement,
                                            UpdateStatement, InsertStatement,
                                            BaseCQLStatement, MapDeleteClause, ConditionalClause)


class QueryException(CQLEngineException):
    pass


class IfNotExistsWithCounterColumn(CQLEngineException):
    pass


class IfExistsWithCounterColumn(CQLEngineException):
    pass


class LWTException(CQLEngineException):
    """Lightweight conditional exception.

    This exception will be raised when a write using an `IF` clause could not be
    applied due to existing data violating the condition. The existing data is
    available through the ``existing`` attribute.

    :param existing: The current state of the data which prevented the write.
    """
    def __init__(self, existing):
        super(LWTException, self).__init__("LWT Query was not applied")
        self.existing = existing


class DoesNotExist(QueryException):
    pass


class MultipleObjectsReturned(QueryException):
    pass


def check_applied(result):
    """
    Raises LWTException if it looks like a failed LWT request.
    """
    try:
        applied = result.was_applied
    except Exception:
        applied = True  # result was not LWT form
    if not applied:
        raise LWTException(result[0])


class AbstractQueryableColumn(UnicodeMixin):
    """
    Exposes CQL query operators through Python's built-in comparison operators
    """

    def _get_column(self):
        raise NotImplementedError

    def __unicode__(self):
        raise NotImplementedError

    def _to_database(self, val):
        if isinstance(val, QueryValue):
            return val
        else:
            return self._get_column().to_database(val)

    def in_(self, item):
        """
        Returns an IN operator, for use where you would typically use Python's ``in`` operator
        """
        return WhereClause(six.text_type(self), InOperator(), item)

    def contains_(self, item):
        """
        Returns a CONTAINS operator
        """
        return WhereClause(six.text_type(self), ContainsOperator(), item)

    def __eq__(self, other):
        return WhereClause(six.text_type(self), EqualsOperator(), self._to_database(other))

    def __gt__(self, other):
        return WhereClause(six.text_type(self), GreaterThanOperator(), self._to_database(other))

    def __ge__(self, other):
        return WhereClause(six.text_type(self), GreaterThanOrEqualOperator(), self._to_database(other))

    def __lt__(self, other):
        return WhereClause(six.text_type(self), LessThanOperator(), self._to_database(other))

    def __le__(self, other):
        return WhereClause(six.text_type(self), LessThanOrEqualOperator(), self._to_database(other))


class BatchType(object):
    Unlogged = 'UNLOGGED'
    Counter = 'COUNTER'


class BatchQuery(object):
    """
    Handles the batching of queries

    http://docs.datastax.com/en/cql/3.0/cql/cql_reference/batch_r.html

    See :doc:`/cqlengine/batches` for more details.
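    A batch ``timestamp`` is emitted as ``USING TIMESTAMP`` in microseconds since the epoch.
    The following is a rough, standalone sketch of the conversion done in :meth:`execute`;
    ``to_batch_micros`` is a hypothetical helper, not part of this module, and naive datetimes
    are interpreted in local time because ``time.mktime`` is used:

    .. code-block:: python

        import time
        from datetime import datetime, timedelta

        def to_batch_micros(ts):
            # hypothetical helper (not part of cqlengine) mirroring BatchQuery.execute()
            if isinstance(ts, timedelta):
                # a timedelta is treated as an offset from the current local time
                ts = datetime.now() + ts
            # whole seconds since the epoch (local time) scaled to microseconds,
            # plus the sub-second microsecond component
            return int(time.mktime(ts.timetuple()) * 1e+6 + ts.microsecond)

        micros = to_batch_micros(timedelta(seconds=30))  # about 30 seconds from now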
    """
    warn_multiple_exec = True

    _consistency = None

    _connection = None
    _connection_explicit = False

    def __init__(self, batch_type=None, timestamp=None, consistency=None, execute_on_exception=False,
                 timeout=conn.NOT_SET, connection=None):
        """
        :param batch_type: (optional) One of batch type values available through BatchType enum
        :type batch_type: str or None
        :param timestamp: (optional) A datetime, or a timedelta relative to the current time, giving the
            timestamp applied to the batch (``USING TIMESTAMP``).
        :type timestamp: datetime or timedelta or None
        :param consistency: (optional) One of consistency values ("ANY", "ONE", "QUORUM" etc)
        :type consistency: The :class:`.ConsistencyLevel` to be used for the batch query, or None.
        :param execute_on_exception: (Defaults to False) Indicates that when the BatchQuery instance is used
            as a context manager the queries accumulated within the context must be executed despite
            encountering an error within the context. By default, any exception raised from within
            the context scope will cause the batched queries not to be executed.
        :type execute_on_exception: bool
        :param timeout: (optional) Timeout for the entire batch (in seconds); if not specified, falls back
            to the default session timeout
        :type timeout: float or None
        :param str connection: Connection name to use for the batch execution
        """
        self.queries = []
        self.batch_type = batch_type
        if timestamp is not None and not isinstance(timestamp, (datetime, timedelta)):
            raise CQLEngineException('timestamp object must be an instance of datetime or timedelta')
        self.timestamp = timestamp
        self._consistency = consistency
        self._execute_on_exception = execute_on_exception
        self._timeout = timeout
        self._callbacks = []
        self._executed = False
        self._context_entered = False
        self._connection = connection
        if connection:
            self._connection_explicit = True

    def add_query(self, query):
        if not isinstance(query, BaseCQLStatement):
            raise CQLEngineException('only BaseCQLStatements can be added to a batch query')
        self.queries.append(query)

    def consistency(self, consistency):
        self._consistency = consistency

    def _execute_callbacks(self):
        for callback, args, kwargs in self._callbacks:
            callback(*args, **kwargs)

    def add_callback(self, fn, *args, **kwargs):
        """Add a function and arguments to be passed to it to be executed after the batch executes.

        A batch can support multiple callbacks.

        Note that if the batch does not execute, the callbacks are not executed.
        A callback, thus, is an "on batch success" handler.

        :param fn: Callable object
        :type fn: callable
        :param \*args: Positional arguments to be passed to the callback at the time of execution
        :param \*\*kwargs: Named arguments to be passed to the callback at the time of execution
        """
        if not callable(fn):
            raise ValueError("Value for argument 'fn' is {0} and is not a callable object.".format(type(fn)))
        self._callbacks.append((fn, args, kwargs))

    def execute(self):
        if self._executed and self.warn_multiple_exec:
            msg = "Batch executed multiple times."
            if self._context_entered:
                msg += " If using the batch as a context manager, there is no need to call execute directly."
            warn(msg)
        self._executed = True

        if len(self.queries) == 0:
            # Empty batch is a no-op
            # except for callbacks
            self._execute_callbacks()
            return

        opener = 'BEGIN ' + (self.batch_type + ' ' if self.batch_type else '') + ' BATCH'
        if self.timestamp:

            if isinstance(self.timestamp, six.integer_types):
                ts = self.timestamp
            elif isinstance(self.timestamp, (datetime, timedelta)):
                ts = self.timestamp
                if isinstance(self.timestamp, timedelta):
                    ts += datetime.now()  # Apply timedelta
                ts = int(time.mktime(ts.timetuple()) * 1e+6 + ts.microsecond)
            else:
                raise ValueError("Batch expects a long, a timedelta, or a datetime")

            opener += ' USING TIMESTAMP {0}'.format(ts)

        query_list = [opener]
        parameters = {}
        ctx_counter = 0
        for query in self.queries:
            query.update_context_id(ctx_counter)
            ctx = query.get_context()
            ctx_counter += len(ctx)
            query_list.append(' ' + str(query))
            parameters.update(ctx)

        query_list.append('APPLY BATCH;')

        tmp = conn.execute('\n'.join(query_list), parameters, self._consistency, self._timeout, connection=self._connection)
        check_applied(tmp)

        self.queries = []
        self._execute_callbacks()

    def __enter__(self):
        self._context_entered = True
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # don't execute if there was an exception by default
        if exc_type is not None and not self._execute_on_exception:
            return
        self.execute()


class ContextQuery(object):
    """
    A Context manager to allow a Model to switch context easily. Presently, the context can specify
    a keyspace and/or a connection for model IO.

    :param *args: One or more models. A model should be a class type, not an instance.
    :param **kwargs: (optional) Context parameters: can be *keyspace* or *connection*

    For example:

    .. code-block:: python

        with ContextQuery(Automobile, keyspace='test2') as A:
            A.objects.create(manufacturer='honda', year=2008, model='civic')
            print(len(A.objects.all()))  # 1 result

        with ContextQuery(Automobile, keyspace='test4') as A:
            print(len(A.objects.all()))  # 0 result

        # Multiple models
        with ContextQuery(Automobile, Automobile2, connection='cluster2') as (A, A2):
            print(len(A.objects.all()))
            print(len(A2.objects.all()))

    """

    def __init__(self, *args, **kwargs):
        from cassandra.cqlengine import models

        self.models = []

        if len(args) < 1:
            raise ValueError("No model provided.")

        keyspace = kwargs.pop('keyspace', None)
        connection = kwargs.pop('connection', None)

        if kwargs:
            raise ValueError("Unknown keyword argument(s): {0}".format(
                ','.join(kwargs.keys())))

        for model in args:
            # issubclass() returns False (rather than raising) for classes that are not Models,
            # so check the result as well as catching the TypeError raised for non-classes
            try:
                if not issubclass(model, models.Model):
                    raise ValueError("Models must be derived from base Model.")
            except TypeError:
                raise ValueError("Models must be derived from base Model.")

            m = models._clone_model_class(model, {})

            if keyspace:
                m.__keyspace__ = keyspace
            if connection:
                m.__connection__ = connection

            self.models.append(m)

    def __enter__(self):
        if len(self.models) > 1:
            return tuple(self.models)
        return self.models[0]

    def __exit__(self, exc_type, exc_val, exc_tb):
        return


class AbstractQuerySet(object):

    def __init__(self, model):
        super(AbstractQuerySet, self).__init__()
        self.model = model

        # Where clause filters
        self._where = []

        # Conditional clause filters
        self._conditional = []

        # ordering arguments
        self._order = []

        self._allow_filtering = False

        # CQL has a default limit of 10000, it's defined here
        # because explicit is better than implicit
        self._limit = 10000

        # see the defer and only methods
        self._defer_fields = set()
        self._deferred_values = {}
        self._only_fields = []

        self._values_list = False
        self._flat_values_list = False

        # results cache
        self._result_cache = None
        self._result_idx = None
        self._result_generator = None
        self._materialize_results = True

        self._distinct_fields = None

        self._count = None

        self._batch = None
        self._ttl = None
        self._consistency = None
        self._timestamp = None
        self._if_not_exists = False
        self._timeout = conn.NOT_SET
        self._if_exists = False
        self._fetch_size = None
        self._connection = None

    @property
    def column_family_name(self):
        return self.model.column_family_name()

    def _execute(self, statement):
        if self._batch:
            return self._batch.add_query(statement)
        else:
            connection = self._connection or self.model._get_connection()
            result = _execute_statement(self.model, statement, self._consistency, self._timeout, connection=connection)
            if self._if_not_exists or self._if_exists or self._conditional:
                check_applied(result)
            return result

    def __unicode__(self):
        return six.text_type(self._select_query())

    def __str__(self):
        return str(self.__unicode__())

    def __call__(self, *args, **kwargs):
        return self.filter(*args, **kwargs)

    def __deepcopy__(self, memo):
        clone = self.__class__(self.model)
        for k, v in self.__dict__.items():
            if k in ['_con', '_cur', '_result_cache', '_result_idx', '_result_generator', '_construct_result']:
                # don't clone these, which are per-request-execution
                clone.__dict__[k] = None
            elif k == '_batch':
                # we need to keep the same batch instance across
                # all queryset clones, otherwise the batched queries
                # fly off into other batch instances which are never
                # executed, thx @dokai
                clone.__dict__[k] = self._batch
            elif k == '_timeout':
                clone.__dict__[k] = self._timeout
            else:
                clone.__dict__[k] = copy.deepcopy(v, memo)

        return clone

    def __len__(self):
        self._execute_query()
        return self.count()

    # ----query generation / execution----

    def _select_fields(self):
        """ returns the fields to select """
        return []

    def _validate_select_where(self):
        """ put select query validation here """

    def _select_query(self):
        """
        Returns a select clause based on the given filter args
        """
        if self._where:
            self._validate_select_where()
        return SelectStatement(
            self.column_family_name,
            fields=self._select_fields(),
            where=self._where,
            order_by=self._order,
            limit=self._limit,
            allow_filtering=self._allow_filtering,
            distinct_fields=self._distinct_fields,
            fetch_size=self._fetch_size
        )

    # ----Reads------

    def _execute_query(self):
        if self._batch:
            raise CQLEngineException("Only inserts, updates, and deletes are available in batch mode")
        if self._result_cache is None:
            self._result_generator = (i for i in self._execute(self._select_query()))
            self._result_cache = []
            self._construct_result = self._maybe_inject_deferred(self._get_result_constructor())

            # "DISTINCT COUNT()" is not supported in C* < 2.2, so we need to materialize all results to get
            # len() and count() working with DISTINCT queries
            if self._materialize_results or self._distinct_fields:
                self._fill_result_cache()

    def _fill_result_cache(self):
        """
        Fill the result cache with all results.
        """

        idx = 0
        try:
            while True:
                idx += 1000
                self._fill_result_cache_to_idx(idx)
        except StopIteration:
            pass

        self._count = len(self._result_cache)

    def _fill_result_cache_to_idx(self, idx):
        self._execute_query()
        if self._result_idx is None:
            self._result_idx = -1

        qty = idx - self._result_idx
        if qty < 1:
            return
        else:
            for idx in range(qty):
                self._result_idx += 1
                while True:
                    try:
                        self._result_cache[self._result_idx] = self._construct_result(self._result_cache[self._result_idx])
                        break
                    except IndexError:
                        self._result_cache.append(next(self._result_generator))

    def __iter__(self):
        self._execute_query()

        idx = 0
        while True:
            if len(self._result_cache) <= idx:
                try:
                    self._result_cache.append(next(self._result_generator))
                except StopIteration:
                    break

            instance = self._result_cache[idx]
            if isinstance(instance, dict):
                self._fill_result_cache_to_idx(idx)
            yield self._result_cache[idx]

            idx += 1

    def __getitem__(self, s):
        self._execute_query()

        if isinstance(s, slice):
            start = s.start if s.start else 0

            # calculate the amount of results that need to be loaded
            end = s.stop
            if start < 0 or s.stop is None or s.stop < 0:
                end = self.count()

            try:
                self._fill_result_cache_to_idx(end)
            except StopIteration:
                pass

            return self._result_cache[start:s.stop:s.step]
        else:
            try:
                s = int(s)
            except (ValueError, TypeError):
                raise TypeError('QuerySet indices must be integers')

            # Using negative indexing is costly since we have to execute a count()
            if s < 0:
                num_results = self.count()
                s += num_results

            try:
                self._fill_result_cache_to_idx(s)
            except StopIteration:
                raise IndexError

            return self._result_cache[s]

    def _get_result_constructor(self):
        """
        Returns a function that will be used to instantiate query results
        """
        raise NotImplementedError

    @staticmethod
    def _construct_with_deferred(f, deferred, row):
        row.update(deferred)
        return f(row)

    def _maybe_inject_deferred(self, constructor):
        return partial(self._construct_with_deferred, constructor, self._deferred_values)\
            if self._deferred_values else constructor

    def batch(self, batch_obj):
        """
        Set a batch object to run the query on.

        Note: running a select query with a batch object will raise an exception
        """
        if self._connection:
            raise CQLEngineException("Cannot specify the connection on model in batch mode.")

        if batch_obj is not None and not isinstance(batch_obj, BatchQuery):
            raise CQLEngineException('batch_obj must be a BatchQuery instance or None')
        clone = copy.deepcopy(self)
        clone._batch = batch_obj
        return clone

    def first(self):
        try:
            return six.next(iter(self))
        except StopIteration:
            return None

    def all(self):
        """
        Returns a queryset matching all rows

        .. code-block:: python

            for user in User.objects().all():
                print(user)
        """
        return copy.deepcopy(self)

    def consistency(self, consistency):
        """
        Sets the consistency level for the operation. See :class:`.ConsistencyLevel`.

        .. code-block:: python

            for user in User.objects(id=3).consistency(CL.ONE):
                print(user)
        """
        clone = copy.deepcopy(self)
        clone._consistency = consistency
        return clone

    def _parse_filter_arg(self, arg):
        """
        Parses a filter arg in the format:
        <colname>__<op>
        :returns: colname, op tuple
        """
        statement = arg.rsplit('__', 1)
        if len(statement) == 1:
            return arg, None
        elif len(statement) == 2:
            return (statement[0], statement[1]) if arg != 'pk__token' else (arg, None)
        else:
            raise QueryException("Can't parse '{0}'".format(arg))

    def iff(self, *args, **kwargs):
        """Adds IF statements to queryset"""
        if len([x for x in kwargs.values() if x is None]):
            raise CQLEngineException("None values on iff are not allowed")

        clone = copy.deepcopy(self)
        for operator in args:
            if not isinstance(operator, ConditionalClause):
                raise QueryException('{0} is not a valid query operator'.format(operator))
            clone._conditional.append(operator)

        for arg, val in kwargs.items():
            if isinstance(val, Token):
                raise QueryException("Token() values are not valid in conditionals")

            col_name, col_op = self._parse_filter_arg(arg)
            try:
                column = self.model._get_column(col_name)
            except KeyError:
                raise QueryException("Can't resolve column name: '{0}'".format(col_name))

            if isinstance(val, BaseQueryFunction):
                query_val = val
            else:
                query_val = column.to_database(val)

            operator_class = BaseWhereOperator.get_operator(col_op or 'EQ')
            operator = operator_class()
            clone._conditional.append(WhereClause(column.db_field_name, operator, query_val))

        return clone

    def filter(self, *args, **kwargs):
        """
        Adds WHERE arguments to the queryset, returning a new queryset

        See :ref:`retrieving-objects-with-filters`

        Returns a QuerySet filtered on the keyword arguments
        """
        # add arguments to the where clause filters
        if len([x for x in kwargs.values() if x is None]):
            raise CQLEngineException("None values on filter are not allowed")

        clone = copy.deepcopy(self)
        for operator in args:
            if not isinstance(operator, WhereClause):
                raise QueryException('{0} is not a valid query operator'.format(operator))
            clone._where.append(operator)

        for arg, val in kwargs.items():
            col_name, col_op = self._parse_filter_arg(arg)
            quote_field = True

            if not isinstance(val, Token):
                try:
                    column = self.model._get_column(col_name)
                except KeyError:
                    raise QueryException("Can't resolve column name: '{0}'".format(col_name))
            else:
                if col_name != 'pk__token':
                    raise QueryException("Token() values may only be compared to the 'pk__token' virtual column")

                column = columns._PartitionKeysToken(self.model)
                quote_field = False

                partition_columns = column.partition_columns
                if len(partition_columns) != len(val.value):
                    raise QueryException(
                        'Token() received {0} arguments but model has {1} partition keys'.format(
                            len(val.value), len(partition_columns)))
                val.set_columns(partition_columns)

            # get query operator, or use equals if not supplied
            operator_class = BaseWhereOperator.get_operator(col_op or 'EQ')
            operator = operator_class()

            if isinstance(operator, InOperator):
                if not isinstance(val, (list, tuple)):
                    raise QueryException('IN queries must use a list/tuple value')
                query_val = [column.to_database(v) for v in val]
            elif isinstance(val, BaseQueryFunction):
                query_val = val
            elif (isinstance(operator, ContainsOperator) and
                  isinstance(column, (columns.List, columns.Set, columns.Map))):
                # For ContainsOperator and collections, we query using the value, not the container
                query_val = val
            else:
                query_val = column.to_database(val)
                if not col_op:  # only equal values should be deferred
                    clone._defer_fields.add(col_name)
                    clone._deferred_values[column.db_field_name] = val  # map by db field name for substitution in results

            clone._where.append(WhereClause(column.db_field_name, operator, query_val, quote_field=quote_field))

        return clone

    def get(self, *args, **kwargs):
        """
        Returns a single instance matching this query, optionally with additional filter kwargs.

        See :ref:`retrieving-objects-with-filters`

        Returns a single object matching the QuerySet.

        .. code-block:: python

            user = User.get(id=1)

        If no objects are matched, a :class:`~.DoesNotExist` exception is raised.

        If more than one object is found, a :class:`~.MultipleObjectsReturned` exception is raised.
        """
        if args or kwargs:
            return self.filter(*args, **kwargs).get()

        self._execute_query()

        # Check that the resultset only contains one element, avoiding sending a COUNT query
        try:
            self[1]
            raise self.model.MultipleObjectsReturned('Multiple objects found')
        except IndexError:
            pass

        try:
            obj = self[0]
        except IndexError:
            raise self.model.DoesNotExist

        return obj

    def _get_ordering_condition(self, colname):
        order_type = 'DESC' if colname.startswith('-') else 'ASC'
        colname = colname.replace('-', '')

        return colname, order_type

    def order_by(self, *colnames):
        """
        Sets the column(s) to be used for ordering

        Default order is ascending; prepend a '-' to a column name for descending order

        *Note: each column must be a clustering key*

        .. code-block:: python

            from uuid import uuid1, uuid4

            class Comment(Model):
                photo_id = UUID(primary_key=True)
                comment_id = TimeUUID(primary_key=True, default=uuid1)  # second primary key component is a clustering key
                comment = Text()

            sync_table(Comment)

            u = uuid4()
            for x in range(5):
                Comment.create(photo_id=u, comment="test %d" % x)

            print("Normal")
            for comment in Comment.objects(photo_id=u):
                print(comment.comment_id)

            print("Reversed")
            for comment in Comment.objects(photo_id=u).order_by("-comment_id"):
                print(comment.comment_id)
        """
        if len(colnames) == 0:
            clone = copy.deepcopy(self)
            clone._order = []
            return clone

        conditions = []
        for colname in colnames:
            conditions.append('"{0}" {1}'.format(*self._get_ordering_condition(colname)))

        clone = copy.deepcopy(self)
        clone._order.extend(conditions)
        return clone

    def count(self):
        """
        Returns the number of rows matched by this query.
        *Note: This function executes a SELECT COUNT() and has a performance cost on large datasets*
        """
        if self._batch:
            raise CQLEngineException("Only inserts, updates, and deletes are available in batch mode")

        if self._count is None:
            query = self._select_query()
            query.count = True
            result = self._execute(query)
            count_row = result[0].popitem()
            self._count = count_row[1]
        return self._count

    def distinct(self, distinct_fields=None):
        """
        Returns the DISTINCT rows matched by this query.

        distinct_fields defaults to the partition key fields if not specified.

        *Note: distinct_fields must be a partition key or a static column*

        .. code-block:: python

            class Automobile(Model):
                manufacturer = columns.Text(partition_key=True)
                year = columns.Integer(primary_key=True)
                model = columns.Text(primary_key=True)
                price = columns.Decimal()

            sync_table(Automobile)

            # create rows

            Automobile.objects.distinct()

            # or

            Automobile.objects.distinct(['manufacturer'])

        """

        clone = copy.deepcopy(self)
        if distinct_fields:
            clone._distinct_fields = distinct_fields
        else:
            clone._distinct_fields = [x.column_name for x in self.model._partition_keys.values()]

        return clone

    def limit(self, v):
        """
        Limits the number of results returned by Cassandra. Use *0* or *None* to disable.

        *Note that CQL's default limit is 10,000, so all queries without a limit set explicitly will have an implicit limit of 10,000*

        .. code-block:: python

            # Fetch 100 users
            for user in User.objects().limit(100):
                print(user)

            # Fetch all users
            for user in User.objects().limit(None):
                print(user)
        """

        if v is None:
            v = 0

        if not isinstance(v, six.integer_types):
            raise TypeError
        if v == self._limit:
            return self

        if v < 0:
            raise QueryException("Negative limit is not allowed")

        clone = copy.deepcopy(self)
        clone._limit = v
        return clone

    def fetch_size(self, v):
        """
        Sets the number of rows that are fetched at a time.

        *Note that the driver's default fetch size is 5000.*

        .. code-block:: python

            for user in User.objects().fetch_size(500):
                print(user)
        """

        if not isinstance(v, six.integer_types):
            raise TypeError
        if v == self._fetch_size:
            return self

        if v < 1:
            raise QueryException("fetch size less than 1 is not allowed")

        clone = copy.deepcopy(self)
        clone._fetch_size = v
        return clone

    def allow_filtering(self):
        """
        Enables the (usually) unwise practice of querying on a clustering key without also defining a partition key
        """
        clone = copy.deepcopy(self)
        clone._allow_filtering = True
        return clone

    def _only_or_defer(self, action, fields):
        if action == 'only' and self._only_fields:
            raise QueryException("QuerySet already has 'only' fields defined")

        clone = copy.deepcopy(self)

        # check for strange fields
        missing_fields = [f for f in fields if f not in self.model._columns.keys()]
        if missing_fields:
            raise QueryException(
                "Can't resolve fields {0} in {1}".format(
                    ', '.join(missing_fields), self.model.__name__))

        if action == 'defer':
            clone._defer_fields.update(fields)
        elif action == 'only':
            clone._only_fields = fields
        else:
            raise ValueError

        return clone

    def only(self, fields):
        """ Load only these fields for the returned query """
        return self._only_or_defer('only', fields)

    def defer(self, fields):
        """ Don't load these fields for the returned query """
        return self._only_or_defer('defer', fields)

    def create(self, **kwargs):
        return self.model(**kwargs) \
            .batch(self._batch) \
            .ttl(self._ttl) \
            .consistency(self._consistency) \
            .if_not_exists(self._if_not_exists) \
            .timestamp(self._timestamp) \
            .if_exists(self._if_exists) \
            .using(connection=self._connection) \
            .save()

    def delete(self):
        """
        Deletes the rows matched by this query
        """
        # validate where clause
        partition_keys = set(x.db_field_name for x in self.model._partition_keys.values())
        if partition_keys - set(c.field for c in self._where):
            raise QueryException("The partition key must be defined on delete queries")

        dq = DeleteStatement(
            self.column_family_name,
            where=self._where,
            timestamp=self._timestamp,
            conditionals=self._conditional,
            if_exists=self._if_exists
        )
        self._execute(dq)

    def __eq__(self, q):
        if len(self._where) == len(q._where):
            return all([w in q._where for w in self._where])
        return False

    def __ne__(self, q):
        # negate __eq__; the previous `not (self != q)` recursed into __ne__ forever
        return not (self == q)

    def timeout(self, timeout):
        """
        :param timeout: Timeout for the query (in seconds)
        :type timeout: float or None
        """
        clone = copy.deepcopy(self)
        clone._timeout = timeout
        return clone

    def using(self, keyspace=None, connection=None):
        """
        Changes the context of the Model class (keyspace, connection) on the fly
        """

        if connection and self._batch:
            raise CQLEngineException("Cannot specify a connection on model in batch mode.")

        clone = copy.deepcopy(self)
        if keyspace:
            from cassandra.cqlengine.models import _clone_model_class
            clone.model = _clone_model_class(self.model, {'__keyspace__': keyspace})

        if connection:
            clone._connection = connection

        return clone


class ResultObject(dict):
    """
    Adds attribute access to a dictionary
    """

    def __getattr__(self, item):
        try:
            return self[item]
        except KeyError:
            raise AttributeError


class SimpleQuerySet(AbstractQuerySet):
    """
    Overrides _get_result_constructor for querysets that do not define a model (e.g.
    NamedTable queries)
    """

    def _get_result_constructor(self):
        """
        Returns a function that will be used to instantiate query results
        """
        return ResultObject


class ModelQuerySet(AbstractQuerySet):
    """
    """
    def _validate_select_where(self):
        """ Checks that a filterset will not create an invalid select statement """
        # check that there's either an =, an IN or a CONTAINS (collection) relationship with a primary key or indexed field
        equal_ops = [self.model._get_column_by_db_name(w.field)
                     for w in self._where if isinstance(w.operator, EqualsOperator) and not isinstance(w.value, Token)]
        token_comparison = any([w for w in self._where if isinstance(w.value, Token)])
        if not any(w.primary_key or w.index for w in equal_ops) and not token_comparison and not self._allow_filtering:
            raise QueryException(('Where clauses require either =, an IN or a CONTAINS (collection) '
                                  'comparison with either a primary key or indexed field'))

        if not self._allow_filtering:
            # if the query is not on an indexed field
            if not any(w.index for w in equal_ops):
                if not any([w.partition_key for w in equal_ops]) and not token_comparison:
                    raise QueryException('Filtering on a clustering key without a partition key is not allowed unless allow_filtering() is called on the queryset')

    def _select_fields(self):
        if self._defer_fields or self._only_fields:
            fields = self.model._columns.keys()
            if self._defer_fields:
                fields = [f for f in fields if f not in self._defer_fields]
                # select the partition keys if all model fields are deferred
                if not fields:
                    fields = self.model._partition_keys
            if self._only_fields:
                fields = [f for f in fields if f in self._only_fields]
            if not fields:
                raise QueryException('No fields in select query. Only fields: "{0}", defer fields: "{1}"'.format(
                    ','.join(self._only_fields), ','.join(self._defer_fields)))
            return [self.model._columns[f].db_field_name for f in fields]
        return super(ModelQuerySet, self)._select_fields()

    def _get_result_constructor(self):
        """ Returns a function that will be used to instantiate query results """
        if not self._values_list:  # we want models
            return self.model._construct_instance
        elif self._flat_values_list:  # the user has requested flattened list (1 value per row)
            key = self._only_fields[0]
            return lambda row: row[key]
        else:
            return lambda row: [row[f] for f in self._only_fields]

    def _get_ordering_condition(self, colname):
        colname, order_type = super(ModelQuerySet, self)._get_ordering_condition(colname)

        column = self.model._columns.get(colname)
        if column is None:
            raise QueryException("Can't resolve the column name: '{0}'".format(colname))

        # validate the column selection
        if not column.primary_key:
            raise QueryException(
                "Can't order on '{0}', can only order on (clustered) primary keys".format(colname))

        pks = [v for k, v in self.model._columns.items() if v.primary_key]
        if column == pks[0]:
            raise QueryException(
                "Can't order by the first primary key (partition key), clustering (secondary) keys only")

        return column.db_field_name, order_type

    def values_list(self, *fields, **kwargs):
        """ Instructs the query set to return lists of column values (single values when ``flat=True``), not model instances """
        flat = kwargs.pop('flat', False)
        if kwargs:
            raise TypeError('Unexpected keyword arguments to values_list: %s' % (kwargs.keys(),))
        if flat and len(fields) > 1:
            raise TypeError("'flat' is not valid when values_list is called with more than one field.")
        clone = self.only(fields)
        clone._values_list = True
        clone._flat_values_list = flat
        return clone

    def ttl(self, ttl):
        """
        Sets the ttl (in seconds) for modified data.
        *Note that running a select query with a ttl value will raise an exception*
        """
        clone = copy.deepcopy(self)
        clone._ttl = ttl
        return clone

    def timestamp(self, timestamp):
        """
        Allows for custom timestamps to be saved with the record.
        """
        clone = copy.deepcopy(self)
        clone._timestamp = timestamp
        return clone

    def if_not_exists(self):
        """
        Check the existence of an object before insertion.

        If the insertion isn't applied, an LWTException is raised.
        """
        if self.model._has_counter:
            raise IfNotExistsWithCounterColumn('if_not_exists cannot be used with tables containing counter columns')
        clone = copy.deepcopy(self)
        clone._if_not_exists = True
        return clone

    def if_exists(self):
        """
        Check the existence of an object before an update or delete.

        If the update or delete isn't applied, an LWTException is raised.
        """
        if self.model._has_counter:
            raise IfExistsWithCounterColumn('if_exists cannot be used with tables containing counter columns')
        clone = copy.deepcopy(self)
        clone._if_exists = True
        return clone

    def update(self, **values):
        """
        Performs an update on the row selected by the queryset. Include values to update in the
        update like so:

        .. code-block:: python

            Model.objects(key=n).update(value='x')

        Passing in updates for columns which are not part of the model will raise a ValidationError.

        Per-column validation will be performed, but instance-level validation will not
        (i.e., `Model.validate` is not called). This is sometimes referred to as a blind update.

        For example:

        .. code-block:: python

            class User(Model):
                id = Integer(primary_key=True)
                name = Text()

            setup(["localhost"], "test")
            sync_table(User)

            u = User.create(id=1, name="jon")

            User.objects(id=1).update(name="Steve")

            # sets name to null
            User.objects(id=1).update(name=None)


        Also supported is blindly adding and removing elements from container columns,
        without loading a model instance from Cassandra.

        Using the syntax `.update(column_name={x, y, z})` will overwrite the contents of the container, like updating a non-container column.
        However, appending `__<operation>` (e.g. `__add`) to the keyword argument makes the update add
        or remove items from the collection without overwriting the entire column.

        Given the model below, here are the operations that can be performed on the different container columns:

        .. code-block:: python

            class Row(Model):
                row_id      = columns.Integer(primary_key=True)
                set_column  = columns.Set(columns.Integer)
                list_column = columns.List(columns.Integer)
                map_column  = columns.Map(columns.Integer, columns.Integer)

        :class:`~cqlengine.columns.Set`

        - `add`: adds the elements of the given set to the column
        - `remove`: removes the elements of the given set from the column


        .. code-block:: python

            # add elements to a set
            Row.objects(row_id=5).update(set_column__add={6})

            # remove elements from a set
            Row.objects(row_id=5).update(set_column__remove={4})

        :class:`~cqlengine.columns.List`

        - `append`: appends the elements of the given list to the end of the column
        - `prepend`: prepends the elements of the given list to the beginning of the column

        .. code-block:: python

            # append items to a list
            Row.objects(row_id=5).update(list_column__append=[6, 7])

            # prepend items to a list
            Row.objects(row_id=5).update(list_column__prepend=[1, 2])


        :class:`~cqlengine.columns.Map`

        - `update`: adds the given keys/values to the column, creating new entries if they didn't exist, and overwriting old ones if they did

        .. code-block:: python

            # add items to a map
            Row.objects(row_id=5).update(map_column__update={1: 2, 3: 4})
        """
        if not values:
            return

        nulled_columns = set()
        updated_columns = set()
        us = UpdateStatement(self.column_family_name, where=self._where, ttl=self._ttl,
                             timestamp=self._timestamp, conditionals=self._conditional, if_exists=self._if_exists)
        for name, val in values.items():
            col_name, col_op = self._parse_filter_arg(name)
            col = self.model._columns.get(col_name)
            # check for nonexistent columns
            if col is None:
                raise ValidationError("{0}.{1} has no column named: {2}".format(self.__module__, self.model.__name__, col_name))
            # check for primary key update attempts
            if col.is_primary_key:
                raise ValidationError("Cannot apply update to primary key '{0}' for {1}.{2}".format(col_name, self.__module__, self.model.__name__))

            # we should not provide default values in this use case.
            val = col.validate(val)

            if val is None:
                nulled_columns.add(col_name)
                continue

            us.add_update(col, val, operation=col_op)
            updated_columns.add(col_name)

        if us.assignments:
            self._execute(us)

        if nulled_columns:
            delete_conditional = [condition for condition in self._conditional
                                  if condition.field not in updated_columns] if self._conditional else None
            ds = DeleteStatement(self.column_family_name, fields=nulled_columns,
                                 where=self._where, conditionals=delete_conditional, if_exists=self._if_exists)
            self._execute(ds)


class DMLQuery(object):
    """
    A query object used for queries performing inserts, updates, or deletes.

    It is usually instantiated by the model instance to be modified.

    Unlike the read query object, it is mutable.
    """
    _ttl = None
    _consistency = None
    _timestamp = None
    _if_not_exists = False
    _if_exists = False

    def __init__(self, model, instance=None, batch=None, ttl=None, consistency=None, timestamp=None,
                 if_not_exists=False, conditional=None, timeout=conn.NOT_SET, if_exists=False):
        self.model = model
        self.column_family_name = self.model.column_family_name()
        self.instance = instance
        self._batch = batch
        self._ttl = ttl
        self._consistency = consistency
        self._timestamp = timestamp
        self._if_not_exists = if_not_exists
        self._if_exists = if_exists
        self._conditional = conditional
        self._timeout = timeout

    def _execute(self, statement):
        connection = self.instance._get_connection() if self.instance else self.model._get_connection()
        if self._batch:
            if self._batch._connection:
                if not self._batch._connection_explicit and connection and \
                        connection != self._batch._connection:
                    raise CQLEngineException('BatchQuery queries must be executed on the same connection')
            else:
                # set the BatchQuery connection from the model
                self._batch._connection = connection
            return self._batch.add_query(statement)
        else:
            results = _execute_statement(self.model, statement, self._consistency, self._timeout, connection=connection)
            if self._if_not_exists or self._if_exists or self._conditional:
                check_applied(results)
            return results

    def batch(self, batch_obj):
        if batch_obj is not None and not isinstance(batch_obj, BatchQuery):
            raise CQLEngineException('batch_obj must be a BatchQuery instance or None')
        self._batch = batch_obj
        return self

    def _delete_null_columns(self, conditionals=None):
        """
        executes a delete query to remove columns that have changed to null
        """
        ds = DeleteStatement(self.column_family_name, conditionals=conditionals, if_exists=self._if_exists)
        deleted_fields = False
        static_only = True
        for _, v in self.instance._values.items():
            col = v.column
            if v.deleted:
                ds.add_field(col.db_field_name)
                deleted_fields = True
                static_only &= col.static
            elif isinstance(col, columns.Map):
                uc = MapDeleteClause(col.db_field_name, v.value, v.previous_value)
                if uc.get_context_size() > 0:
                    ds.add_field(uc)
                    deleted_fields = True
                    static_only |= col.static

        if deleted_fields:
            keys = self.model._partition_keys if static_only else self.model._primary_keys
            for name, col in keys.items():
                ds.add_where(col, EqualsOperator(), getattr(self.instance, name))
            self._execute(ds)

    def update(self):
        """
        updates a row.
        This is a blind update call.
        All validation and cleaning need to happen
        prior to calling this.
        """
        if self.instance is None:
            raise CQLEngineException("DML Query instance attribute is None")
        assert type(self.instance) == self.model
        null_clustering_key = False if len(self.instance._clustering_keys) == 0 else True
        static_changed_only = True
        statement = UpdateStatement(self.column_family_name, ttl=self._ttl, timestamp=self._timestamp,
                                    conditionals=self._conditional, if_exists=self._if_exists)
        for name, col in self.instance._clustering_keys.items():
            null_clustering_key = null_clustering_key and col._val_is_null(getattr(self.instance, name, None))

        updated_columns = set()
        # get defined fields and their column names
        for name, col in self.model._columns.items():
            # if clustering key is null, don't include non-static columns
            if null_clustering_key and not col.static and not col.partition_key:
                continue
            if not col.is_primary_key:
                val = getattr(self.instance, name, None)
                val_mgr = self.instance._values[name]

                if val is None:
                    continue

                if not val_mgr.changed and not isinstance(col, columns.Counter):
                    continue

                static_changed_only = static_changed_only and col.static
                statement.add_update(col, val, previous=val_mgr.previous_value)
                updated_columns.add(col.db_field_name)

        if statement.assignments:
            for name, col in self.model._primary_keys.items():
                # only include clustering key if clustering key is not null, and non-static columns are changed to avoid cql error
                if (null_clustering_key or static_changed_only) and (not col.partition_key):
                    continue
                statement.add_where(col, EqualsOperator(), getattr(self.instance, name))
            self._execute(statement)

        if not null_clustering_key:
            # remove conditions on fields that have been updated
            delete_conditionals = [condition for condition in self._conditional
                                   if condition.field not in updated_columns] if self._conditional else None
            self._delete_null_columns(delete_conditionals)

    def save(self):
        """
        Creates / updates a row.
        This is a blind insert call.
        All validation and cleaning need to happen
        prior to calling this.
        """
        if self.instance is None:
            raise CQLEngineException("DML Query instance attribute is None")
        assert type(self.instance) == self.model

        nulled_fields = set()
        if self.instance._has_counter or self.instance._can_update():
            if self.instance._has_counter:
                warn("'create' and 'save' actions on Counters are deprecated. A future version will disallow this. Use the 'update' mechanism instead.")
            return self.update()
        else:
            insert = InsertStatement(self.column_family_name, ttl=self._ttl, timestamp=self._timestamp, if_not_exists=self._if_not_exists)
            static_save_only = False if len(self.instance._clustering_keys) == 0 else True
            for name, col in self.instance._clustering_keys.items():
                static_save_only = static_save_only and col._val_is_null(getattr(self.instance, name, None))
            for name, col in self.instance._columns.items():
                if static_save_only and not col.static and not col.partition_key:
                    continue
                val = getattr(self.instance, name, None)
                if col._val_is_null(val):
                    if self.instance._values[name].changed:
                        nulled_fields.add(col.db_field_name)
                    continue
                insert.add_assignment(col, getattr(self.instance, name, None))

        # skip query execution if it's empty
        # caused by pointless update queries
        if not insert.is_empty:
            self._execute(insert)
        # delete any nulled columns
        if not static_save_only:
            self._delete_null_columns()

    def delete(self):
        """ Deletes one instance """
        if self.instance is None:
            raise CQLEngineException("DML Query instance attribute is None")

        ds = DeleteStatement(self.column_family_name, timestamp=self._timestamp, conditionals=self._conditional, if_exists=self._if_exists)
        for name, col in self.model._primary_keys.items():
            val = getattr(self.instance, name)
            if val is None and not col.partition_key:
                continue
            ds.add_where(col, EqualsOperator(), val)
        self._execute(ds)


def _execute_statement(model, statement, consistency_level, timeout, connection=None):
    params = statement.get_context()
    s = SimpleStatement(str(statement),
                        consistency_level=consistency_level, fetch_size=statement.fetch_size)
    if model._partition_key_index:
        key_values = statement.partition_key_values(model._partition_key_index)
        if not any(v is None for v in key_values):
            parts = model._routing_key_from_values(key_values, conn.get_cluster(connection).protocol_version)
            s.routing_key = parts
            s.keyspace = model._get_keyspace()
    connection = connection or model._get_connection()
    return conn.execute(s, params, timeout=timeout, connection=connection)
cassandra-driver-3.7.1/cassandra/cqlengine/statements.py
# Copyright 2013-2016 DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
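`ListUpdateClause._analyze`, defined later in this module, avoids rewriting a whole list when it can: if the previous value appears as a contiguous run inside the new value, only the items before it (prepend) and after it (append) are sent; removals, a missing or empty previous value, or no match fall back to a full assignment. The sketch below restates that strategy as two plain functions so it can be exercised without a cluster. `diff_list` and `render_list_update` are hypothetical names; the sketch leaves out the explicit `append`/`prepend` operations and the cheap first/last-element pre-check the class performs, and it builds `%(n)s` placeholders and a string-keyed context dict in the same style as `update_context`.

```python
def diff_list(previous, value):
    # Hypothetical helper mirroring the strategy of ListUpdateClause._analyze.
    # Returns (assignment, prepend, append); unused slots are None.
    if value is None or value == previous:
        return None, None, None  # nothing to write
    if not previous or len(value) < len(previous):
        # first write, update from an empty list, or items removed: reassign
        return value, None, None
    n = len(previous)
    for i in range(len(value) - n + 1):
        if value[i:i + n] == previous:
            # previous survives as a contiguous run; send only the edges
            return None, value[:i] or None, value[i + n:] or None
    return value, None, None  # no contiguous match: reassign


def render_list_update(field, previous, value, start_id=0):
    # Hypothetical renderer: SET fragments with %(n)s placeholders plus the
    # matching context dict, in assignment / prepend / append order.
    assignment, prepend, append = diff_list(previous, value)
    fragments, ctx, ctx_id = [], {}, start_id
    if assignment is not None:
        fragments.append('"{0}" = %({1})s'.format(field, ctx_id))
        ctx[str(ctx_id)] = assignment
        ctx_id += 1
    if prepend is not None:
        fragments.append('"{0}" = %({1})s + "{0}"'.format(field, ctx_id))
        ctx[str(ctx_id)] = prepend
        ctx_id += 1
    if append is not None:
        fragments.append('"{0}" = "{0}" + %({1})s'.format(field, ctx_id))
        ctx[str(ctx_id)] = append
    return ', '.join(fragments), ctx
```

For example, going from `[2, 3]` to `[1, 2, 3, 4]` produces one prepend of `[1]` and one append of `[4]` instead of resending all four items, while going from `[1, 2, 3]` to `[1, 3]` falls back to a full assignment because an element was removed.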
from datetime import datetime, timedelta
import time
import six
from six.moves import filter

from cassandra.query import FETCH_SIZE_UNSET
from cassandra.cqlengine import columns
from cassandra.cqlengine import UnicodeMixin
from cassandra.cqlengine.functions import QueryValue
from cassandra.cqlengine.operators import BaseWhereOperator, InOperator, EqualsOperator


class StatementException(Exception):
    pass


class ValueQuoter(UnicodeMixin):

    def __init__(self, value):
        self.value = value

    def __unicode__(self):
        from cassandra.encoder import cql_quote
        if isinstance(self.value, (list, tuple)):
            return '[' + ', '.join([cql_quote(v) for v in self.value]) + ']'
        elif isinstance(self.value, dict):
            return '{' + ', '.join([cql_quote(k) + ':' + cql_quote(v) for k, v in self.value.items()]) + '}'
        elif isinstance(self.value, set):
            return '{' + ', '.join([cql_quote(v) for v in self.value]) + '}'
        return cql_quote(self.value)

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return self.value == other.value
        return False


class InQuoter(ValueQuoter):

    def __unicode__(self):
        from cassandra.encoder import cql_quote
        return '(' + ', '.join([cql_quote(v) for v in self.value]) + ')'


class BaseClause(UnicodeMixin):

    def __init__(self, field, value):
        self.field = field
        self.value = value
        self.context_id = None

    def __unicode__(self):
        raise NotImplementedError

    def __hash__(self):
        return hash(self.field) ^ hash(self.value)

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return self.field == other.field and self.value == other.value
        return False

    def __ne__(self, other):
        return not self.__eq__(other)

    def get_context_size(self):
        """ returns the number of entries this clause will add to the query context """
        return 1

    def set_context_id(self, i):
        """ sets the value placeholder that will be used in the query """
        self.context_id = i

    def update_context(self, ctx):
        """ updates the query context with this clause's values """
        assert isinstance(ctx, dict)
        ctx[str(self.context_id)] = self.value


class WhereClause(BaseClause):
    """ a single where statement used in queries """

    def __init__(self, field, operator, value, quote_field=True):
        """

        :param field:
        :param operator:
        :param value:
        :param quote_field: hack to get the token function rendering properly
        :return:
        """
        if not isinstance(operator, BaseWhereOperator):
            raise StatementException(
                "operator must be of type {0}, got {1}".format(BaseWhereOperator, type(operator))
            )
        super(WhereClause, self).__init__(field, value)
        self.operator = operator
        self.query_value = self.value if isinstance(self.value, QueryValue) else QueryValue(self.value)
        self.quote_field = quote_field

    def __unicode__(self):
        field = ('"{0}"' if self.quote_field else '{0}').format(self.field)
        return u'{0} {1} {2}'.format(field, self.operator, six.text_type(self.query_value))

    def __hash__(self):
        return super(WhereClause, self).__hash__() ^ hash(self.operator)

    def __eq__(self, other):
        if super(WhereClause, self).__eq__(other):
            return self.operator.__class__ == other.operator.__class__
        return False

    def get_context_size(self):
        return self.query_value.get_context_size()

    def set_context_id(self, i):
        super(WhereClause, self).set_context_id(i)
        self.query_value.set_context_id(i)

    def update_context(self, ctx):
        if isinstance(self.operator, InOperator):
            ctx[str(self.context_id)] = InQuoter(self.value)
        else:
            self.query_value.update_context(ctx)


class AssignmentClause(BaseClause):
    """ a single variable set statement """

    def __unicode__(self):
        return u'"{0}" = %({1})s'.format(self.field, self.context_id)

    def insert_tuple(self):
        return self.field, self.context_id


class ConditionalClause(BaseClause):
    """ A single variable condition for a CQL IF clause """

    def __unicode__(self):
        return u'"{0}" = %({1})s'.format(self.field, self.context_id)

    def insert_tuple(self):
        return self.field, self.context_id


class ContainerUpdateTypeMapMeta(type):

    def __init__(cls, name, bases, dct):
        if not hasattr(cls, 'type_map'):
            cls.type_map = {}
        else:
            cls.type_map[cls.col_type] = cls
        super(ContainerUpdateTypeMapMeta, cls).__init__(name, bases, dct)


@six.add_metaclass(ContainerUpdateTypeMapMeta)
class ContainerUpdateClause(AssignmentClause):

    def __init__(self, field, value, operation=None, previous=None):
        super(ContainerUpdateClause, self).__init__(field, value)
        self.previous = previous
        self._assignments = None
        self._operation = operation
        self._analyzed = False

    def _analyze(self):
        raise NotImplementedError

    def get_context_size(self):
        raise NotImplementedError

    def update_context(self, ctx):
        raise NotImplementedError


class SetUpdateClause(ContainerUpdateClause):
    """ updates a set collection """

    col_type = columns.Set

    _additions = None
    _removals = None

    def __unicode__(self):
        qs = []
        ctx_id = self.context_id
        if (self.previous is None and
                self._assignments is None and
                self._additions is None and
                self._removals is None):
            qs += ['"{0}" = %({1})s'.format(self.field, ctx_id)]
        if self._assignments is not None:
            qs += ['"{0}" = %({1})s'.format(self.field, ctx_id)]
            ctx_id += 1
        if self._additions is not None:
            qs += ['"{0}" = "{0}" + %({1})s'.format(self.field, ctx_id)]
            ctx_id += 1
        if self._removals is not None:
            qs += ['"{0}" = "{0}" - %({1})s'.format(self.field, ctx_id)]

        return ', '.join(qs)

    def _analyze(self):
        """ works out the updates to be performed """
        if self.value is None or self.value == self.previous:
            pass
        elif self._operation == "add":
            self._additions = self.value
        elif self._operation == "remove":
            self._removals = self.value
        elif self.previous is None:
            self._assignments = self.value
        else:
            # partial update time
            self._additions = (self.value - self.previous) or None
            self._removals = (self.previous - self.value) or None
        self._analyzed = True

    def get_context_size(self):
        if not self._analyzed:
            self._analyze()
        if (self.previous is None and
                not self._assignments and
                self._additions is None and
                self._removals is None):
            return 1
        return int(bool(self._assignments)) + int(bool(self._additions)) + int(bool(self._removals))

    def update_context(self, ctx):
        if not self._analyzed:
            self._analyze()
        ctx_id = self.context_id
        if (self.previous is None and
                self._assignments is None and
                self._additions is None and
                self._removals is None):
            ctx[str(ctx_id)] = set()
        if self._assignments is not None:
            ctx[str(ctx_id)] = self._assignments
            ctx_id += 1
        if self._additions is not None:
            ctx[str(ctx_id)] = self._additions
            ctx_id += 1
        if self._removals is not None:
            ctx[str(ctx_id)] = self._removals


class ListUpdateClause(ContainerUpdateClause):
    """ updates a list collection """

    col_type = columns.List

    _append = None
    _prepend = None

    def __unicode__(self):
        if not self._analyzed:
            self._analyze()
        qs = []
        ctx_id = self.context_id
        if self._assignments is not None:
            qs += ['"{0}" = %({1})s'.format(self.field, ctx_id)]
            ctx_id += 1

        if self._prepend is not None:
            qs += ['"{0}" = %({1})s + "{0}"'.format(self.field, ctx_id)]
            ctx_id += 1

        if self._append is not None:
            qs += ['"{0}" = "{0}" + %({1})s'.format(self.field, ctx_id)]

        return ', '.join(qs)

    def get_context_size(self):
        if not self._analyzed:
            self._analyze()
        return int(self._assignments is not None) + int(bool(self._append)) + int(bool(self._prepend))

    def update_context(self, ctx):
        if not self._analyzed:
            self._analyze()
        ctx_id = self.context_id
        if self._assignments is not None:
            ctx[str(ctx_id)] = self._assignments
            ctx_id += 1
        if self._prepend is not None:
            ctx[str(ctx_id)] = self._prepend
            ctx_id += 1
        if self._append is not None:
            ctx[str(ctx_id)] = self._append

    def _analyze(self):
        """ works out the updates to be performed """
        if self.value is None or self.value == self.previous:
            pass

        elif self._operation == "append":
            self._append = self.value

        elif self._operation == "prepend":
            self._prepend = self.value

        elif self.previous is None:
            self._assignments = self.value

        elif len(self.value) < len(self.previous):
            # if elements have been removed,
            # rewrite the whole list
            self._assignments = self.value

        elif len(self.previous) == 0:
            # if we're updating from an empty
            # list, assign the whole list
            self._assignments = self.value
        else:

            # the number of start indices we want to compare
            search_space = len(self.value) - max(0, len(self.previous) - 1)

            # the size of the sublists we want to look at
            search_size = len(self.previous)

            for i in range(search_space):
                # slice boundary
                j = i + search_size
                sub = self.value[i:j]
                idx_cmp = lambda idx: self.previous[idx] == sub[idx]
                if idx_cmp(0) and idx_cmp(-1) and self.previous == sub:
                    self._prepend = self.value[:i] or None
                    self._append = self.value[j:] or None
                    break

            # if both append and prepend are still None after looking
            # at both lists, the whole list is reassigned
            if self._prepend is self._append is None:
                self._assignments = self.value

        self._analyzed = True


class MapUpdateClause(ContainerUpdateClause):
    """ updates a map collection """

    col_type = columns.Map

    _updates = None

    def _analyze(self):
        if self._operation == "update":
            self._updates = self.value.keys()
        else:
            if self.previous is None:
                self._updates = sorted([k for k, v in self.value.items()])
            else:
                self._updates = sorted([k for k, v in self.value.items() if v != self.previous.get(k)]) or None
        self._analyzed = True

    def get_context_size(self):
        if self.is_assignment:
            return 1
        return len(self._updates or []) * 2

    def update_context(self, ctx):
        ctx_id = self.context_id
        if self.is_assignment:
            ctx[str(ctx_id)] = {}
        else:
            for key in self._updates or []:
                val = self.value.get(key)
                ctx[str(ctx_id)] = key
                ctx[str(ctx_id + 1)] = val
                ctx_id += 2

    @property
    def is_assignment(self):
        if not self._analyzed:
            self._analyze()
        return self.previous is None and not self._updates

    def __unicode__(self):
        qs = []

        ctx_id = self.context_id
        if self.is_assignment:
            qs += ['"{0}" = %({1})s'.format(self.field, ctx_id)]
        else:
            for _ in self._updates or []:
                qs += ['"{0}"[%({1})s] = %({2})s'.format(self.field, ctx_id, ctx_id + 1)]
                ctx_id += 2

        return ', '.join(qs)


class CounterUpdateClause(AssignmentClause):

    col_type = columns.Counter

    def __init__(self, field, value, previous=None):
        super(CounterUpdateClause, self).__init__(field,
                                                  value)
        self.previous = previous or 0

    def get_context_size(self):
        return 1

    def update_context(self, ctx):
        ctx[str(self.context_id)] = abs(self.value - self.previous)

    def __unicode__(self):
        delta = self.value - self.previous
        sign = '-' if delta < 0 else '+'
        return '"{0}" = "{0}" {1} %({2})s'.format(self.field, sign, self.context_id)


class BaseDeleteClause(BaseClause):
    pass


class FieldDeleteClause(BaseDeleteClause):
    """ deletes a field from a row """

    def __init__(self, field):
        super(FieldDeleteClause, self).__init__(field, None)

    def __unicode__(self):
        return '"{0}"'.format(self.field)

    def update_context(self, ctx):
        pass

    def get_context_size(self):
        return 0


class MapDeleteClause(BaseDeleteClause):
    """ removes keys from a map """

    def __init__(self, field, value, previous=None):
        super(MapDeleteClause, self).__init__(field, value)
        self.value = self.value or {}
        self.previous = previous or {}
        self._analyzed = False
        self._removals = None

    def _analyze(self):
        self._removals = sorted([k for k in self.previous if k not in self.value])
        self._analyzed = True

    def update_context(self, ctx):
        if not self._analyzed:
            self._analyze()
        for idx, key in enumerate(self._removals):
            ctx[str(self.context_id + idx)] = key

    def get_context_size(self):
        if not self._analyzed:
            self._analyze()
        return len(self._removals)

    def __unicode__(self):
        if not self._analyzed:
            self._analyze()
        return ', '.join(['"{0}"[%({1})s]'.format(self.field, self.context_id + i) for i in range(len(self._removals))])


class BaseCQLStatement(UnicodeMixin):
    """ The base cql statement class """

    def __init__(self, table, timestamp=None, where=None, fetch_size=None, conditionals=None):
        super(BaseCQLStatement, self).__init__()
        self.table = table
        self.context_id = 0
        self.context_counter = self.context_id
        self.timestamp = timestamp
        self.fetch_size = fetch_size if fetch_size else FETCH_SIZE_UNSET

        self.where_clauses = []
        for clause in where or []:
            self._add_where_clause(clause)

        self.conditionals = []
        for conditional in conditionals or []:
            self.add_conditional_clause(conditional)

    def _update_part_key_values(self, field_index_map, clauses, parts):
        for clause in filter(lambda c: c.field in field_index_map, clauses):
            parts[field_index_map[clause.field]] = clause.value

    def partition_key_values(self, field_index_map):
        parts = [None] * len(field_index_map)
        self._update_part_key_values(field_index_map, (w for w in self.where_clauses if w.operator.__class__ == EqualsOperator), parts)
        return parts

    def add_where(self, column, operator, value, quote_field=True):
        value = column.to_database(value)
        clause = WhereClause(column.db_field_name, operator, value, quote_field)
        self._add_where_clause(clause)

    def _add_where_clause(self, clause):
        clause.set_context_id(self.context_counter)
        self.context_counter += clause.get_context_size()
        self.where_clauses.append(clause)

    def get_context(self):
        """
        returns the context dict for this statement
        :rtype: dict
        """
        ctx = {}
        for clause in self.where_clauses or []:
            clause.update_context(ctx)
        return ctx

    def add_conditional_clause(self, clause):
        """
        Adds an IF clause to this statement

        :param clause: The clause that will be added to the IF statement
        :type clause: ConditionalClause
        """
        clause.set_context_id(self.context_counter)
        self.context_counter += clause.get_context_size()
        self.conditionals.append(clause)

    def _get_conditionals(self):
        return 'IF {0}'.format(' AND '.join([six.text_type(c) for c in self.conditionals]))

    def get_context_size(self):
        return len(self.get_context())

    def update_context_id(self, i):
        self.context_id = i
        self.context_counter = self.context_id
        for clause in self.where_clauses:
            clause.set_context_id(self.context_counter)
            self.context_counter += clause.get_context_size()

    @property
    def timestamp_normalized(self):
        """
        we're expecting self.timestamp to be either a long, int, a datetime, or a timedelta
        :return:
        """
        if not self.timestamp:
            return None

        if isinstance(self.timestamp, six.integer_types):
            return self.timestamp

        if isinstance(self.timestamp, timedelta):
            tmp = datetime.now() + self.timestamp
        else:
            tmp = self.timestamp

        return int(time.mktime(tmp.timetuple()) * 1e+6 + tmp.microsecond)

    def __unicode__(self):
        raise NotImplementedError

    def __repr__(self):
        return self.__unicode__()

    @property
    def _where(self):
        return 'WHERE {0}'.format(' AND '.join([six.text_type(c) for c in self.where_clauses]))


class SelectStatement(BaseCQLStatement):
    """ a cql select statement """

    def __init__(self, table, fields=None, count=False, where=None, order_by=None, limit=None,
                 allow_filtering=False, distinct_fields=None, fetch_size=None):
        """
        :param where
        :type where list of cqlengine.statements.WhereClause
        """
        super(SelectStatement, self).__init__(
            table,
            where=where,
            fetch_size=fetch_size
        )

        self.fields = [fields] if isinstance(fields, six.string_types) else (fields or [])
        self.distinct_fields = distinct_fields
        self.count = count
        self.order_by = [order_by] if isinstance(order_by, six.string_types) else order_by
        self.limit = limit
        self.allow_filtering = allow_filtering

    def __unicode__(self):
        qs = ['SELECT']
        if self.distinct_fields:
            if self.count:
                qs += ['DISTINCT COUNT({0})'.format(', '.join(['"{0}"'.format(f) for f in self.distinct_fields]))]
            else:
                qs += ['DISTINCT {0}'.format(', '.join(['"{0}"'.format(f) for f in self.distinct_fields]))]
        elif self.count:
            qs += ['COUNT(*)']
        else:
            qs += [', '.join(['"{0}"'.format(f) for f in self.fields]) if self.fields else '*']
        qs += ['FROM', self.table]

        if self.where_clauses:
            qs += [self._where]

        if self.order_by and not self.count:
            qs += ['ORDER BY {0}'.format(', '.join(six.text_type(o) for o in self.order_by))]

        if self.limit:
            qs += ['LIMIT {0}'.format(self.limit)]

        if self.allow_filtering:
            qs += ['ALLOW FILTERING']

        return ' '.join(qs)


class AssignmentStatement(BaseCQLStatement):
    """ value assignment statements """

    def __init__(self, table, assignments=None, where=None, ttl=None, timestamp=None, conditionals=None):
        super(AssignmentStatement, self).__init__(
            table,
            where=where,
            conditionals=conditionals
        )
        self.ttl = ttl
        self.timestamp = timestamp

        # add assignments
        self.assignments = []
        for assignment in assignments or []:
            self._add_assignment_clause(assignment)

    def update_context_id(self, i):
        super(AssignmentStatement, self).update_context_id(i)
        for assignment in self.assignments:
            assignment.set_context_id(self.context_counter)
            self.context_counter += assignment.get_context_size()

    def partition_key_values(self, field_index_map):
        parts = super(AssignmentStatement, self).partition_key_values(field_index_map)
        self._update_part_key_values(field_index_map, self.assignments, parts)
        return parts

    def add_assignment(self, column, value):
        value = column.to_database(value)
        clause = AssignmentClause(column.db_field_name, value)
        self._add_assignment_clause(clause)

    def _add_assignment_clause(self, clause):
        clause.set_context_id(self.context_counter)
        self.context_counter += clause.get_context_size()
        self.assignments.append(clause)

    @property
    def is_empty(self):
        return len(self.assignments) == 0

    def get_context(self):
        ctx = super(AssignmentStatement, self).get_context()
        for clause in self.assignments:
            clause.update_context(ctx)
        return ctx


class InsertStatement(AssignmentStatement):
    """ a cql insert statement """

    def __init__(self, table, assignments=None, where=None, ttl=None, timestamp=None, if_not_exists=False):
        super(InsertStatement, self).__init__(table,
                                              assignments=assignments,
                                              where=where,
                                              ttl=ttl,
                                              timestamp=timestamp)

        self.if_not_exists = if_not_exists

    def __unicode__(self):
        qs = ['INSERT INTO {0}'.format(self.table)]

        # get column names and context placeholders
        fields = [a.insert_tuple() for a in self.assignments]
        columns, values = zip(*fields)

        qs += ["({0})".format(', '.join(['"{0}"'.format(c) for c in columns]))]
        qs += ['VALUES']
        qs += ["({0})".format(', '.join(['%({0})s'.format(v) for v in values]))]

        if self.if_not_exists:
            qs += ["IF NOT EXISTS"]

        # TTL and TIMESTAMP must share a single USING clause joined by AND
        using_options = []
        if self.ttl:
            using_options += ["TTL {0}".format(self.ttl)]

        if self.timestamp:
            using_options += ["TIMESTAMP {0}".format(self.timestamp_normalized)]

        if using_options:
            qs += ["USING {0}".format(" AND ".join(using_options))]
        return ' '.join(qs)


class UpdateStatement(AssignmentStatement):
    """ a cql update statement """

    def __init__(self, table, assignments=None, where=None, ttl=None, timestamp=None, conditionals=None, if_exists=False):
        super(UpdateStatement, self).__init__(table,
                                              assignments=assignments,
                                              where=where,
                                              ttl=ttl,
                                              timestamp=timestamp,
                                              conditionals=conditionals)
        self.if_exists = if_exists

    def __unicode__(self):
        qs = ['UPDATE', self.table]

        using_options = []

        if self.ttl:
            using_options += ["TTL {0}".format(self.ttl)]

        if self.timestamp:
            using_options += ["TIMESTAMP {0}".format(self.timestamp_normalized)]

        if using_options:
            qs += ["USING {0}".format(" AND ".join(using_options))]

        qs += ['SET']
        qs += [', '.join([six.text_type(c) for c in self.assignments])]

        if self.where_clauses:
            qs += [self._where]

        if len(self.conditionals) > 0:
            qs += [self._get_conditionals()]

        if self.if_exists:
            qs += ["IF EXISTS"]

        return ' '.join(qs)

    def get_context(self):
        ctx = super(UpdateStatement, self).get_context()
        for clause in self.conditionals:
            clause.update_context(ctx)
        return ctx

    def update_context_id(self, i):
        super(UpdateStatement, self).update_context_id(i)
        for conditional in self.conditionals:
            conditional.set_context_id(self.context_counter)
            self.context_counter += conditional.get_context_size()

    def add_update(self, column, value, operation=None, previous=None):
        value = column.to_database(value)
        col_type = type(column)
        container_update_type = ContainerUpdateClause.type_map.get(col_type)
        if container_update_type:
            previous = column.to_database(previous)
            clause = container_update_type(column.db_field_name, value, operation, previous)
        elif col_type == columns.Counter:
            clause = CounterUpdateClause(column.db_field_name, value, previous)
        else:
            clause = AssignmentClause(column.db_field_name, value)
        if clause.get_context_size():  # this is to exclude map removals from updates.
Can go away if we drop support for C* < 1.2.4 and remove two-phase updates self._add_assignment_clause(clause) class DeleteStatement(BaseCQLStatement): """ a CQL delete statement """ def __init__(self, table, fields=None, where=None, timestamp=None, conditionals=None, if_exists=False): super(DeleteStatement, self).__init__( table, where=where, timestamp=timestamp, conditionals=conditionals ) self.fields = [] if isinstance(fields, six.string_types): fields = [fields] for field in fields or []: self.add_field(field) self.if_exists = if_exists def update_context_id(self, i): super(DeleteStatement, self).update_context_id(i) for field in self.fields: field.set_context_id(self.context_counter) self.context_counter += field.get_context_size() for t in self.conditionals: t.set_context_id(self.context_counter) self.context_counter += t.get_context_size() def get_context(self): ctx = super(DeleteStatement, self).get_context() for field in self.fields: field.update_context(ctx) for clause in self.conditionals: clause.update_context(ctx) return ctx def add_field(self, field): if isinstance(field, six.string_types): field = FieldDeleteClause(field) if not isinstance(field, BaseClause): raise StatementException("only column names or instances of BaseClause can be added to delete statements") field.set_context_id(self.context_counter) self.context_counter += field.get_context_size() self.fields.append(field) def __unicode__(self): qs = ['DELETE'] if self.fields: qs += [', '.join(['{0}'.format(f) for f in self.fields])] qs += ['FROM', self.table] delete_option = [] if self.timestamp: delete_option += ["TIMESTAMP {0}".format(self.timestamp_normalized)] if delete_option: qs += ["USING {0}".format(" AND ".join(delete_option))] if self.where_clauses: qs += [self._where] if self.conditionals: qs += [self._get_conditionals()] if self.if_exists: qs += ["IF EXISTS"] return ' '.join(qs) cassandra-driver-3.7.1/cassandra/cqlengine/models.py0000664000175000017500000011203712766043721025371 0ustar
aboudreaultaboudreault00000000000000# Copyright 2013-2016 DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import logging import re import six from warnings import warn from cassandra.cqlengine import CQLEngineException, ValidationError from cassandra.cqlengine import columns from cassandra.cqlengine import connection from cassandra.cqlengine import query from cassandra.cqlengine.query import DoesNotExist as _DoesNotExist from cassandra.cqlengine.query import MultipleObjectsReturned as _MultipleObjectsReturned from cassandra.metadata import protect_name from cassandra.util import OrderedDict log = logging.getLogger(__name__) def _clone_model_class(model, attrs): new_type = type(model.__name__, (model,), attrs) try: new_type.__abstract__ = model.__abstract__ new_type.__discriminator_value__ = model.__discriminator_value__ new_type.__default_ttl__ = model.__default_ttl__ except AttributeError: pass return new_type class ModelException(CQLEngineException): pass class ModelDefinitionException(ModelException): pass class PolymorphicModelException(ModelException): pass class UndefinedKeyspaceWarning(Warning): pass DEFAULT_KEYSPACE = None class hybrid_classmethod(object): """ Allows a method to behave as both a class method and normal instance method depending on how it's called """ def __init__(self, clsmethod, instmethod): self.clsmethod = clsmethod self.instmethod = instmethod def __get__(self, instance, owner): if instance is None: return self.clsmethod.__get__(owner, 
owner) else: return self.instmethod.__get__(instance, owner) def __call__(self, *args, **kwargs): """ Just a hint to IDEs that it's ok to call this """ raise NotImplementedError class QuerySetDescriptor(object): """ returns a fresh queryset for the given model it's declared on every time it's accessed """ def __get__(self, obj, model): """ :rtype: ModelQuerySet """ if model.__abstract__: raise CQLEngineException('cannot execute queries against abstract models') queryset = model.__queryset__(model) # if this is a concrete polymorphic model, and the discriminator # key is an indexed column, add a filter clause to only return # logical rows of the proper type if model._is_polymorphic and not model._is_polymorphic_base: name, column = model._discriminator_column_name, model._discriminator_column if column.partition_key or column.index: # look for existing poly types return queryset.filter(**{name: model.__discriminator_value__}) return queryset def __call__(self, *args, **kwargs): """ Just a hint to IDEs that it's ok to call this :rtype: ModelQuerySet """ raise NotImplementedError class ConditionalDescriptor(object): """ returns a query set descriptor """ def __get__(self, instance, model): if instance: def conditional_setter(*prepared_conditional, **unprepared_conditionals): if len(prepared_conditional) > 0: conditionals = prepared_conditional[0] else: conditionals = instance.objects.iff(**unprepared_conditionals)._conditional instance._conditional = conditionals return instance return conditional_setter qs = model.__queryset__(model) def conditional_setter(**unprepared_conditionals): conditionals = model.objects.iff(**unprepared_conditionals)._conditional qs._conditional = conditionals return qs return conditional_setter def __call__(self, *args, **kwargs): raise NotImplementedError class TTLDescriptor(object): """ returns a query set descriptor """ def __get__(self, instance, model): if instance: # instance = copy.deepcopy(instance) # instance method def
ttl_setter(ts): instance._ttl = ts return instance return ttl_setter qs = model.__queryset__(model) def ttl_setter(ts): qs._ttl = ts return qs return ttl_setter def __call__(self, *args, **kwargs): raise NotImplementedError class TimestampDescriptor(object): """ returns a query set descriptor with a timestamp specified """ def __get__(self, instance, model): if instance: # instance method def timestamp_setter(ts): instance._timestamp = ts return instance return timestamp_setter return model.objects.timestamp def __call__(self, *args, **kwargs): raise NotImplementedError class IfNotExistsDescriptor(object): """ returns a query set descriptor with an if_not_exists flag specified """ def __get__(self, instance, model): if instance: # instance method def ifnotexists_setter(ife=True): instance._if_not_exists = ife return instance return ifnotexists_setter return model.objects.if_not_exists def __call__(self, *args, **kwargs): raise NotImplementedError class IfExistsDescriptor(object): """ returns a query set descriptor with an if_exists flag specified """ def __get__(self, instance, model): if instance: # instance method def ifexists_setter(ife=True): instance._if_exists = ife return instance return ifexists_setter return model.objects.if_exists def __call__(self, *args, **kwargs): raise NotImplementedError class ConsistencyDescriptor(object): """ returns a query set descriptor if called on the class, or the instance if called on an instance """ def __get__(self, instance, model): if instance: # instance = copy.deepcopy(instance) def consistency_setter(consistency): instance.__consistency__ = consistency return instance return consistency_setter qs = model.__queryset__(model) def consistency_setter(consistency): qs._consistency = consistency return qs return consistency_setter def __call__(self, *args, **kwargs): raise NotImplementedError class UsingDescriptor(object): """ returns a query set descriptor with a connection context specified """ def __get__(self, instance, model): if
instance: # instance method def using_setter(connection=None): if connection: instance._connection = connection return instance return using_setter return model.objects.using def __call__(self, *args, **kwargs): raise NotImplementedError class ColumnQueryEvaluator(query.AbstractQueryableColumn): """ Wraps a column and allows it to be used in comparator expressions, returning query operators ie: Model.column == 5 """ def __init__(self, column): self.column = column def __unicode__(self): return self.column.db_field_name def _get_column(self): return self.column class ColumnDescriptor(object): """ Handles the reading and writing of column values to and from a model instance's value manager, as well as creating comparator queries """ def __init__(self, column): """ :param column: :type column: columns.Column :return: """ self.column = column self.query_evaluator = ColumnQueryEvaluator(self.column) def __get__(self, instance, owner): """ Returns either the value or column, depending on if an instance is provided or not :param instance: the model instance :type instance: Model """ try: return instance._values[self.column.column_name].getval() except AttributeError: return self.query_evaluator def __set__(self, instance, value): """ Sets the value on an instance, raises an exception with classes TODO: use None instance to create update statements """ if instance: return instance._values[self.column.column_name].setval(value) else: raise AttributeError('cannot reassign column values') def __delete__(self, instance): """ Sets the column value to None, if possible """ if instance: if self.column.can_delete: instance._values[self.column.column_name].delval() else: raise AttributeError('cannot delete {0} columns'.format(self.column.column_name)) class BaseModel(object): """ The base model class, don't inherit from this, inherit from Model, defined below """ class DoesNotExist(_DoesNotExist): pass class MultipleObjectsReturned(_MultipleObjectsReturned): pass objects = 
QuerySetDescriptor() ttl = TTLDescriptor() consistency = ConsistencyDescriptor() iff = ConditionalDescriptor() # custom timestamps, see USING TIMESTAMP X timestamp = TimestampDescriptor() if_not_exists = IfNotExistsDescriptor() if_exists = IfExistsDescriptor() using = UsingDescriptor() # _len is lazily created by __len__ __table_name__ = None __table_name_case_sensitive__ = False __keyspace__ = None __connection__ = None __discriminator_value__ = None __options__ = None __compute_routing_key__ = True # the queryset class used for this class __queryset__ = query.ModelQuerySet __dmlquery__ = query.DMLQuery __consistency__ = None # can be set per query _timestamp = None # optional timestamp to include with the operation (USING TIMESTAMP) _if_not_exists = False # optional if_not_exists flag to check existence before insertion _if_exists = False # optional if_exists flag to check existence before update _table_name = None # used internally to cache a derived table name _connection = None def __init__(self, **values): self._ttl = None self._timestamp = None self._conditional = None self._batch = None self._timeout = connection.NOT_SET self._is_persisted = False self._connection = None self._values = {} for name, column in self._columns.items(): # Set default values on instantiation. Thanks to this, we don't have # to wait any longer for a call to validate() to have cqlengine set # default column values.
column_default = column.get_default() if column.has_default else None value = values.get(name, column_default) if value is not None or isinstance(column, columns.BaseContainerColumn): value = column.to_python(value) value_mngr = column.value_manager(self, column, value) value_mngr.explicit = name in values self._values[name] = value_mngr def __repr__(self): return '{0}({1})'.format(self.__class__.__name__, ', '.join('{0}={1!r}'.format(k, getattr(self, k)) for k in self._defined_columns.keys() if k != self._discriminator_column_name)) def __str__(self): """ Pretty printing of models by their primary key """ return '{0} <{1}>'.format(self.__class__.__name__, ', '.join('{0}={1}'.format(k, getattr(self, k)) for k in self._primary_keys.keys())) @classmethod def _routing_key_from_values(cls, pk_values, protocol_version): return cls._key_serializer(pk_values, protocol_version) @classmethod def _discover_polymorphic_submodels(cls): if not cls._is_polymorphic_base: raise ModelException('_discover_polymorphic_submodels can only be called on polymorphic base classes') def _discover(klass): if not klass._is_polymorphic_base and klass.__discriminator_value__ is not None: cls._discriminator_map[klass.__discriminator_value__] = klass for subklass in klass.__subclasses__(): _discover(subklass) _discover(cls) @classmethod def _get_model_by_discriminator_value(cls, key): if not cls._is_polymorphic_base: raise ModelException('_get_model_by_discriminator_value can only be called on polymorphic base classes') return cls._discriminator_map.get(key) @classmethod def _construct_instance(cls, values): """ method used to construct instances from query results this is where polymorphic deserialization occurs """ # we're going to take the values, which is from the DB as a dict # and translate that into our local fields # the db_map is a db_field -> model field map if cls._db_map: values = dict((cls._db_map.get(k, k), v) for k, v in values.items()) if cls._is_polymorphic: disc_key = 
values.get(cls._discriminator_column_name) if disc_key is None: raise PolymorphicModelException('discriminator value was not found in values') poly_base = cls if cls._is_polymorphic_base else cls._polymorphic_base klass = poly_base._get_model_by_discriminator_value(disc_key) if klass is None: poly_base._discover_polymorphic_submodels() klass = poly_base._get_model_by_discriminator_value(disc_key) if klass is None: raise PolymorphicModelException( 'unrecognized discriminator column {0} for class {1}'.format(disc_key, poly_base.__name__) ) if not issubclass(klass, cls): raise PolymorphicModelException( '{0} is not a subclass of {1}'.format(klass.__name__, cls.__name__) ) values = dict((k, v) for k, v in values.items() if k in klass._columns.keys()) else: klass = cls instance = klass(**values) instance._set_persisted() return instance def _set_persisted(self): for v in self._values.values(): v.reset_previous_value() self._is_persisted = True def _can_update(self): """ Called by the save function to check if this should be persisted with update or insert :return: """ if not self._is_persisted: return False return all([not self._values[k].changed for k in self._primary_keys]) @classmethod def _get_keyspace(cls): """ Returns the manual keyspace, if set, otherwise the default keyspace """ return cls.__keyspace__ or DEFAULT_KEYSPACE @classmethod def _get_column(cls, name): """ Returns the column matching the given name, raising a key error if it doesn't exist :param name: the name of the column to return :rtype: Column """ return cls._columns[name] @classmethod def _get_column_by_db_name(cls, name): """ Returns the column, mapped by db_field name """ return cls._columns.get(cls._db_map.get(name, name)) def __eq__(self, other): if self.__class__ != other.__class__: return False # check attribute keys keys = set(self._columns.keys()) other_keys = set(other._columns.keys()) if keys != other_keys: return False return all(getattr(self, key, None) == getattr(other, key, None) 
for key in other_keys) def __ne__(self, other): return not self.__eq__(other) @classmethod def column_family_name(cls, include_keyspace=True): """ Returns the column family (table) name, either as set by __table_name__ or derived from the class name, prefixed with the keyspace name when include_keyspace is True """ cf_name = protect_name(cls._raw_column_family_name()) if include_keyspace: keyspace = cls._get_keyspace() if not keyspace: raise CQLEngineException("Model keyspace is not set and no default is available. Set model keyspace or setup connection before attempting to generate a query.") return '{0}.{1}'.format(protect_name(keyspace), cf_name) return cf_name @classmethod def _raw_column_family_name(cls): if not cls._table_name: if cls.__table_name__: if cls.__table_name_case_sensitive__: cls._table_name = cls.__table_name__ else: table_name = cls.__table_name__.lower() if cls.__table_name__ != table_name: warn(("Model __table_name__ will be case sensitive by default in 4.0. " "You should fix the __table_name__ value of the '{0}' model.").format(cls.__name__)) cls._table_name = table_name else: if cls._is_polymorphic and not cls._is_polymorphic_base: cls._table_name = cls._polymorphic_base._raw_column_family_name() else: camelcase = re.compile(r'([a-z])([A-Z])') ccase = lambda s: camelcase.sub(lambda v: '{0}_{1}'.format(v.group(1), v.group(2).lower()), s) cf_name = ccase(cls.__name__) # keep at most the last 48 characters or cassandra will complain cf_name = cf_name[-48:] cf_name = cf_name.lower() cf_name = re.sub(r'^_+', '', cf_name) cls._table_name = cf_name return cls._table_name def validate(self): """ Cleans and validates the field values """ for name, col in self._columns.items(): v = getattr(self, name) if v is None and not self._values[name].explicit and col.has_default: v = col.get_default() val = col.validate(v) setattr(self, name, val) # Let an instance be used like a dict of its column keys/values def __iter__(self): """ Iterate over column ids.
""" for column_id in self._columns.keys(): yield column_id def __getitem__(self, key): """ Returns column's value. """ if not isinstance(key, six.string_types): raise TypeError if key not in self._columns.keys(): raise KeyError return getattr(self, key) def __setitem__(self, key, val): """ Sets a column's value. """ if not isinstance(key, six.string_types): raise TypeError if key not in self._columns.keys(): raise KeyError return setattr(self, key, val) def __len__(self): """ Returns the number of columns defined on that model. """ try: return self._len except AttributeError: self._len = len(self._columns.keys()) return self._len def keys(self): """ Returns a list of column IDs. """ return [k for k in self] def values(self): """ Returns a list of column values. """ return [self[k] for k in self] def items(self): """ Returns a list of column ID/value tuples. """ return [(k, self[k]) for k in self] def _as_dict(self): """ Returns a map of column names to cleaned values """ values = self._dynamic_columns or {} for name, col in self._columns.items(): values[name] = col.to_database(getattr(self, name, None)) return values @classmethod def create(cls, **kwargs): """ Create an instance of this model in the database. Takes the model column values as keyword arguments. Returns the instance. """ extra_columns = set(kwargs.keys()) - set(cls._columns.keys()) if extra_columns: raise ValidationError("Incorrect columns passed: {0}".format(extra_columns)) return cls.objects.create(**kwargs) @classmethod def all(cls): """ Returns a queryset representing all stored objects. This is a pass-through to the model objects().all() """ return cls.objects.all() @classmethod def filter(cls, *args, **kwargs): """ Returns a queryset based on filter parameters. This is a pass-through to the model objects().:method:`~cqlengine.queries.filter`. """ return cls.objects.filter(*args, **kwargs) @classmethod def get(cls, *args, **kwargs): """ Returns a single object based on the passed filter constraints.
This is a pass-through to the model objects().:method:`~cqlengine.queries.get`. """ return cls.objects.get(*args, **kwargs) def timeout(self, timeout): """ Sets a timeout for use in :meth:`~.save`, :meth:`~.update`, and :meth:`~.delete` operations """ assert self._batch is None, 'Setting both timeout and batch is not supported' self._timeout = timeout return self def save(self): """ Saves an object to the database. .. code-block:: python #create a person instance person = Person(first_name='Kimberly', last_name='Eggleston') #saves it to Cassandra person.save() """ # handle polymorphic models if self._is_polymorphic: if self._is_polymorphic_base: raise PolymorphicModelException('cannot save polymorphic base model') else: setattr(self, self._discriminator_column_name, self.__discriminator_value__) self.validate() self.__dmlquery__(self.__class__, self, batch=self._batch, ttl=self._ttl, timestamp=self._timestamp, consistency=self.__consistency__, if_not_exists=self._if_not_exists, conditional=self._conditional, timeout=self._timeout, if_exists=self._if_exists).save() self._set_persisted() self._timestamp = None return self def update(self, **values): """ Performs an update on the model instance. You can pass in values to set on the model for updating, or you can call without values to execute an update against any modified fields. If no fields on the model have been modified since loading, no query will be performed. Model validation is performed normally. It is possible to do a blind update, that is, to update a field without having first selected the object out of the database. 
See :ref:`Blind Updates ` """ for k, v in values.items(): col = self._columns.get(k) # check for nonexistent columns if col is None: raise ValidationError("{0}.{1} has no column named: {2}".format(self.__module__, self.__class__.__name__, k)) # check for primary key update attempts if col.is_primary_key: raise ValidationError("Cannot apply update to primary key '{0}' for {1}.{2}".format(k, self.__module__, self.__class__.__name__)) setattr(self, k, v) # handle polymorphic models if self._is_polymorphic: if self._is_polymorphic_base: raise PolymorphicModelException('cannot update polymorphic base model') else: setattr(self, self._discriminator_column_name, self.__discriminator_value__) self.validate() self.__dmlquery__(self.__class__, self, batch=self._batch, ttl=self._ttl, timestamp=self._timestamp, consistency=self.__consistency__, conditional=self._conditional, timeout=self._timeout, if_exists=self._if_exists).update() self._set_persisted() self._timestamp = None return self def delete(self): """ Deletes the object from the database """ self.__dmlquery__(self.__class__, self, batch=self._batch, timestamp=self._timestamp, consistency=self.__consistency__, timeout=self._timeout, conditional=self._conditional, if_exists=self._if_exists).delete() def get_changed_columns(self): """ Returns a list of the columns that have been updated since instantiation or save """ return [k for k, v in self._values.items() if v.changed] @classmethod def _class_batch(cls, batch): return cls.objects.batch(batch) def _inst_batch(self, batch): assert self._timeout is connection.NOT_SET, 'Setting both timeout and batch is not supported' if self._connection: raise CQLEngineException("Cannot specify a connection on model in batch mode.") self._batch = batch return self batch = hybrid_classmethod(_class_batch, _inst_batch) @classmethod def _class_get_connection(cls): return cls.__connection__ def _inst_get_connection(self): return self._connection or self.__connection__ _get_connection =
hybrid_classmethod(_class_get_connection, _inst_get_connection) class ModelMetaClass(type): def __new__(cls, name, bases, attrs): # move column definitions into columns dict # and set default column names column_dict = OrderedDict() primary_keys = OrderedDict() pk_name = None # get inherited properties inherited_columns = OrderedDict() for base in bases: for k, v in getattr(base, '_defined_columns', {}).items(): inherited_columns.setdefault(k, v) # short circuit __abstract__ inheritance is_abstract = attrs['__abstract__'] = attrs.get('__abstract__', False) # short circuit __discriminator_value__ inheritance attrs['__discriminator_value__'] = attrs.get('__discriminator_value__') # TODO __default__ttl__ should be removed in the next major release options = attrs.get('__options__') or {} attrs['__default_ttl__'] = options.get('default_time_to_live') column_definitions = [(k, v) for k, v in attrs.items() if isinstance(v, columns.Column)] column_definitions = sorted(column_definitions, key=lambda x: x[1].position) is_polymorphic_base = any([c[1].discriminator_column for c in column_definitions]) column_definitions = [x for x in inherited_columns.items()] + column_definitions discriminator_columns = [c for c in column_definitions if c[1].discriminator_column] is_polymorphic = len(discriminator_columns) > 0 if len(discriminator_columns) > 1: raise ModelDefinitionException('only one discriminator_column can be defined in a model, {0} found'.format(len(discriminator_columns))) if attrs['__discriminator_value__'] and not is_polymorphic: raise ModelDefinitionException('__discriminator_value__ specified, but no base columns defined with discriminator_column=True') discriminator_column_name, discriminator_column = discriminator_columns[0] if discriminator_columns else (None, None) if isinstance(discriminator_column, (columns.BaseContainerColumn, columns.Counter)): raise ModelDefinitionException('counter and container columns cannot be used as discriminator columns') # find 
polymorphic base class polymorphic_base = None if is_polymorphic and not is_polymorphic_base: def _get_polymorphic_base(bases): for base in bases: if getattr(base, '_is_polymorphic_base', False): return base klass = _get_polymorphic_base(base.__bases__) if klass: return klass polymorphic_base = _get_polymorphic_base(bases) defined_columns = OrderedDict(column_definitions) # check for primary key if not is_abstract and not any([v.primary_key for k, v in column_definitions]): raise ModelDefinitionException("At least 1 primary key is required.") counter_columns = [c for c in defined_columns.values() if isinstance(c, columns.Counter)] data_columns = [c for c in defined_columns.values() if not c.primary_key and not isinstance(c, columns.Counter)] if counter_columns and data_columns: raise ModelDefinitionException('counter models may not have data columns') has_partition_keys = any(v.partition_key for (k, v) in column_definitions) def _transform_column(col_name, col_obj): column_dict[col_name] = col_obj if col_obj.primary_key: primary_keys[col_name] = col_obj col_obj.set_column_name(col_name) # set properties attrs[col_name] = ColumnDescriptor(col_obj) partition_key_index = 0 # transform column definitions for k, v in column_definitions: # don't allow a column with the same name as a built-in attribute or method if k in BaseModel.__dict__: raise ModelDefinitionException("column '{0}' conflicts with built-in attribute/method".format(k)) # counter column primary keys are not allowed if (v.primary_key or v.partition_key) and isinstance(v, columns.Counter): raise ModelDefinitionException('counter columns cannot be used as primary keys') # this will mark the first primary key column as a partition # key, if one hasn't been set already if not has_partition_keys and v.primary_key: v.partition_key = True has_partition_keys = True if v.partition_key: v._partition_key_index = partition_key_index partition_key_index += 1 overriding = column_dict.get(k) if overriding: v.position = 
overriding.position v.partition_key = overriding.partition_key v._partition_key_index = overriding._partition_key_index _transform_column(k, v) partition_keys = OrderedDict(k for k in primary_keys.items() if k[1].partition_key) clustering_keys = OrderedDict(k for k in primary_keys.items() if not k[1].partition_key) if attrs.get('__compute_routing_key__', True): key_cols = [c for c in partition_keys.values()] partition_key_index = dict((col.db_field_name, col._partition_key_index) for col in key_cols) key_cql_types = [c.cql_type for c in key_cols] key_serializer = staticmethod(lambda parts, proto_version: [t.to_binary(p, proto_version) for t, p in zip(key_cql_types, parts)]) else: partition_key_index = {} key_serializer = staticmethod(lambda parts, proto_version: None) # setup partition key shortcut if len(partition_keys) == 0: if not is_abstract: raise ModelException("at least one partition key must be defined") if len(partition_keys) == 1: pk_name = [x for x in partition_keys.keys()][0] attrs['pk'] = attrs[pk_name] else: # composite partition key case, get/set a tuple of values _get = lambda self: tuple(self._values[c].getval() for c in partition_keys.keys()) _set = lambda self, val: tuple(self._values[c].setval(v) for (c, v) in zip(partition_keys.keys(), val)) attrs['pk'] = property(_get, _set) # some validation col_names = set() for v in column_dict.values(): # check for duplicate column names if v.db_field_name in col_names: raise ModelException("{0} defines the column '{1}' more than once".format(name, v.db_field_name)) if v.clustering_order and not (v.primary_key and not v.partition_key): raise ModelException("clustering_order may be specified only for clustering primary keys") if v.clustering_order and v.clustering_order.lower() not in ('asc', 'desc'): raise ModelException("invalid clustering order '{0}' for column '{1}'".format(repr(v.clustering_order), v.db_field_name)) col_names.add(v.db_field_name) # create db_name -> model name map for loading db_map = 
{} for col_name, field in column_dict.items(): db_field = field.db_field_name if db_field != col_name: db_map[db_field] = col_name # add management members to the class attrs['_columns'] = column_dict attrs['_primary_keys'] = primary_keys attrs['_defined_columns'] = defined_columns # maps the database field to the models key attrs['_db_map'] = db_map attrs['_pk_name'] = pk_name attrs['_dynamic_columns'] = {} attrs['_partition_keys'] = partition_keys attrs['_partition_key_index'] = partition_key_index attrs['_key_serializer'] = key_serializer attrs['_clustering_keys'] = clustering_keys attrs['_has_counter'] = len(counter_columns) > 0 # add polymorphic management attributes attrs['_is_polymorphic_base'] = is_polymorphic_base attrs['_is_polymorphic'] = is_polymorphic attrs['_polymorphic_base'] = polymorphic_base attrs['_discriminator_column'] = discriminator_column attrs['_discriminator_column_name'] = discriminator_column_name attrs['_discriminator_map'] = {} if is_polymorphic_base else None # setup class exceptions DoesNotExistBase = None for base in bases: DoesNotExistBase = getattr(base, 'DoesNotExist', None) if DoesNotExistBase is not None: break DoesNotExistBase = DoesNotExistBase or attrs.pop('DoesNotExist', BaseModel.DoesNotExist) attrs['DoesNotExist'] = type('DoesNotExist', (DoesNotExistBase,), {}) MultipleObjectsReturnedBase = None for base in bases: MultipleObjectsReturnedBase = getattr(base, 'MultipleObjectsReturned', None) if MultipleObjectsReturnedBase is not None: break MultipleObjectsReturnedBase = MultipleObjectsReturnedBase or attrs.pop('MultipleObjectsReturned', BaseModel.MultipleObjectsReturned) attrs['MultipleObjectsReturned'] = type('MultipleObjectsReturned', (MultipleObjectsReturnedBase,), {}) # create the class and add a QuerySet to it klass = super(ModelMetaClass, cls).__new__(cls, name, bases, attrs) udts = [] for col in column_dict.values(): columns.resolve_udts(col, udts) for user_type in set(udts): 
user_type.register_for_keyspace(klass._get_keyspace()) return klass @six.add_metaclass(ModelMetaClass) class Model(BaseModel): __abstract__ = True """ *Optional.* Indicates that this model is only intended to be used as a base class for other models. You can't create tables for abstract models, and schema validity checks (such as requiring a primary key) are skipped during class construction. """ __table_name__ = None """ *Optional.* Sets the name of the CQL table for this model. If left blank, the table name is derived from the model's class name, converted from CamelCase to lower-case snake_case (e.g. ``SomeModel`` becomes ``some_model``). Manually defined table names are not inherited. """ __table_name_case_sensitive__ = False """ *Optional.* By default, __table_name__ is case insensitive (it is lower-cased). Set this to True if you want to preserve its case. """ __keyspace__ = None """ Sets the name of the keyspace used by this model. """ __connection__ = None """ Sets the name of the default connection used by this model. """ __options__ = None """ *Optional.* Table options applied with this model (e.g. compaction, default ttl, cache settings, etc.) """ __discriminator_value__ = None """ *Optional.* Specifies a value for the discriminator column when using model inheritance. """ __compute_routing_key__ = True """ *Optional.* Setting False disables computing the routing key used for token-aware routing """ cassandra-driver-3.7.1/cassandra/cqlengine/named.py0000664000175000017500000001116712766043721025174 0ustar aboudreaultaboudreault00000000000000# Copyright 2013-2016 DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # limitations under the License. from cassandra.util import OrderedDict from cassandra.cqlengine import CQLEngineException from cassandra.cqlengine.columns import Column from cassandra.cqlengine.connection import get_cluster from cassandra.cqlengine.models import UsingDescriptor, BaseModel from cassandra.cqlengine.query import AbstractQueryableColumn, SimpleQuerySet from cassandra.cqlengine.query import DoesNotExist as _DoesNotExist from cassandra.cqlengine.query import MultipleObjectsReturned as _MultipleObjectsReturned class QuerySetDescriptor(object): """ Returns a fresh queryset for the given model it's declared on every time it's accessed """ def __get__(self, obj, model): """ :rtype: SimpleQuerySet """ if model.__abstract__: raise CQLEngineException('cannot execute queries against abstract models') return SimpleQuerySet(obj) def __call__(self, *args, **kwargs): """ Just a hint to IDEs that it's ok to call this :rtype: SimpleQuerySet """ raise NotImplementedError class NamedColumn(AbstractQueryableColumn): """ A column that is not coupled to a model class, or type """ def __init__(self, name): self.name = name def __unicode__(self): return self.name def _get_column(self): """ :rtype: NamedColumn """ return self @property def db_field_name(self): return self.name @property def cql(self): return self.get_cql() def get_cql(self): return '"{0}"'.format(self.name) def to_database(self, val): return val class NamedTable(object): """ A Table that is not coupled to a model class """ __abstract__ = False objects = QuerySetDescriptor() __partition_keys = None _partition_key_index = None __connection__ = None _connection = None using = UsingDescriptor() _get_connection = BaseModel._get_connection class DoesNotExist(_DoesNotExist): pass class MultipleObjectsReturned(_MultipleObjectsReturned): pass def __init__(self, keyspace, name): self.keyspace = keyspace self.name = name self._connection = None
@property def _partition_keys(self): if not self.__partition_keys: self._get_partition_keys() return self.__partition_keys def _get_partition_keys(self): try: table_meta = get_cluster(self._get_connection()).metadata.keyspaces[self.keyspace].tables[self.name] self.__partition_keys = OrderedDict((pk.name, Column(primary_key=True, partition_key=True, db_field=pk.name)) for pk in table_meta.partition_key) except Exception: raise CQLEngineException("Failed inspecting partition keys for {0}. " "Ensure cqlengine is connected before attempting this with NamedTable.".format(self.column_family_name())) def column(self, name): return NamedColumn(name) def column_family_name(self, include_keyspace=True): """ Returns the table name, prefixed with the keyspace name (``keyspace.table``) when include_keyspace is True """ if include_keyspace: return '{0}.{1}'.format(self.keyspace, self.name) else: return self.name def _get_column(self, name): """ Returns the column matching the given name :rtype: Column """ return self.column(name) # def create(self, **kwargs): # return self.objects.create(**kwargs) def all(self): return self.objects.all() def filter(self, *args, **kwargs): return self.objects.filter(*args, **kwargs) def get(self, *args, **kwargs): return self.objects.get(*args, **kwargs) class NamedKeyspace(object): """ A keyspace """ def __init__(self, name): self.name = name def table(self, name): """ Returns a NamedTable with the given name that belongs to this keyspace """ return NamedTable(self.name, name) cassandra-driver-3.7.1/cassandra/cqlengine/__init__.py0000664000175000017500000000172712743410406025641 0ustar aboudreaultaboudreault00000000000000# Copyright 2013-2016 DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import six # Caching constants. CACHING_ALL = "ALL" CACHING_KEYS_ONLY = "KEYS_ONLY" CACHING_ROWS_ONLY = "ROWS_ONLY" CACHING_NONE = "NONE" class CQLEngineException(Exception): pass class ValidationError(CQLEngineException): pass class UnicodeMixin(object): if six.PY3: __str__ = lambda x: x.__unicode__() else: __str__ = lambda x: six.text_type(x).encode('utf-8') cassandra-driver-3.7.1/cassandra/cqlengine/connection.py0000664000175000017500000002511712777231260026246 0ustar aboudreaultaboudreault00000000000000# Copyright 2013-2016 DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
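The connection module that follows keeps a registry of named `Connection` objects, aliases one of them under the `DEFAULT_CONNECTION` sentinel, and, when `lazy_connect` is set, defers connecting until first use by guarding the first connect with an `RLock` and re-checking the flag once the lock is held. The sketch below reproduces only that registry and double-checked locking pattern in plain Python so it runs without a cluster. `LazyConnection`, `register`, `get` and the `setup_calls` counter are hypothetical stand-ins, not part of cqlengine's API.

```python
import threading

# Sentinel key for the default alias (plays the role of DEFAULT_CONNECTION).
DEFAULT = object()
_registry = {}


class LazyConnection(object):
    """Hypothetical stand-in for cqlengine's Connection; setup() only counts calls."""

    def __init__(self, name, lazy=True):
        self.name = name
        self.lazy = lazy
        self.setup_calls = 0
        self._lock = threading.RLock()

    def setup(self):
        # A lazy connection defers the real work until first use.
        if self.lazy:
            return
        # The real class would create a Cluster and Session here.
        self.setup_calls += 1

    def handle_lazy_connect(self):
        # Fast path: once connected, no lock is needed.
        if not self.lazy:
            return
        with self._lock:
            # Re-check: another thread may have connected while we waited for the lock.
            if self.lazy:
                self.lazy = False
                self.setup()


def register(name, lazy=True, default=False):
    conn = LazyConnection(name, lazy=lazy)
    _registry[name] = conn
    if default:
        _registry[DEFAULT] = conn
    conn.setup()  # no-op for lazy connections
    return conn


def get(name=None):
    key = name or DEFAULT
    if key not in _registry:
        raise LookupError("connection %r is not registered" % (name,))
    conn = _registry[key]
    conn.handle_lazy_connect()
    return conn
```

However many threads call `get()` concurrently, `setup()` does its work once for a lazy connection. As in the module's `handle_lazy_connect`, the flag is cleared before `setup()` runs (because `setup()` returns early while the flag is set), so a caller taking the fast path at that instant could see a connection whose setup has not finished yet.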
from collections import defaultdict import logging import six import threading from cassandra.cluster import Cluster, _NOT_SET, NoHostAvailable, UserTypeDoesNotExist from cassandra.query import SimpleStatement, dict_factory from cassandra.cqlengine import CQLEngineException from cassandra.cqlengine.statements import BaseCQLStatement log = logging.getLogger(__name__) NOT_SET = _NOT_SET # required for passing timeout to Session.execute cluster = None session = None # connections registry DEFAULT_CONNECTION = object() _connections = {} # Because type models may be registered before a connection is present, # and because sessions may be replaced, we must register UDTs here, in order # to have them registered when a new session is established. udt_by_keyspace = defaultdict(dict) def format_log_context(msg, connection=None, keyspace=None): """Format log message to add keyspace and connection context""" connection_info = connection or 'DEFAULT_CONNECTION' if keyspace: msg = '[Connection: {0}, Keyspace: {1}] {2}'.format(connection_info, keyspace, msg) else: msg = '[Connection: {0}] {1}'.format(connection_info, msg) return msg class UndefinedKeyspaceException(CQLEngineException): pass class Connection(object): """CQLEngine Connection""" name = None hosts = None consistency = None retry_connect = False lazy_connect = False lazy_connect_lock = None cluster_options = None cluster = None session = None def __init__(self, name, hosts, consistency=None, lazy_connect=False, retry_connect=False, cluster_options=None): self.hosts = hosts self.name = name self.consistency = consistency self.lazy_connect = lazy_connect self.retry_connect = retry_connect self.cluster_options = cluster_options if cluster_options else {} self.lazy_connect_lock = threading.RLock() def setup(self): """Setup the connection""" global cluster, session if 'username' in self.cluster_options or 'password' in self.cluster_options: raise CQLEngineException("Username & Password are now handled by using the native 
driver's auth_provider") if self.lazy_connect: return self.cluster = Cluster(self.hosts, **self.cluster_options) try: self.session = self.cluster.connect() log.debug(format_log_context("connection initialized with internally created session", connection=self.name)) except NoHostAvailable: if self.retry_connect: log.warning(format_log_context("connect failed, setting up for re-attempt on first use", connection=self.name)) self.lazy_connect = True raise if self.consistency is not None: self.session.default_consistency_level = self.consistency if DEFAULT_CONNECTION in _connections and _connections[DEFAULT_CONNECTION] == self: cluster = _connections[DEFAULT_CONNECTION].cluster session = _connections[DEFAULT_CONNECTION].session self.setup_session() def setup_session(self): self.session.row_factory = dict_factory enc = self.session.encoder enc.mapping[tuple] = enc.cql_encode_tuple _register_known_types(self.session.cluster) def handle_lazy_connect(self): # if lazy_connect is False, it means the cluster is set up and ready # No need to acquire the lock if not self.lazy_connect: return with self.lazy_connect_lock: # lazy_connect might have been set to False by another thread while waiting for the lock # In this case, do nothing.
if self.lazy_connect: log.debug(format_log_context("Lazy connect enabled", connection=self.name)) self.lazy_connect = False self.setup() def register_connection(name, hosts, consistency=None, lazy_connect=False, retry_connect=False, cluster_options=None, default=False): if name in _connections: log.warning("Registering connection '{0}' when it already exists.".format(name)) conn = Connection(name, hosts, consistency=consistency,lazy_connect=lazy_connect, retry_connect=retry_connect, cluster_options=cluster_options) _connections[name] = conn if default: set_default_connection(name) conn.setup() return conn def unregister_connection(name): global cluster, session if name not in _connections: return if DEFAULT_CONNECTION in _connections and _connections[name] == _connections[DEFAULT_CONNECTION]: del _connections[DEFAULT_CONNECTION] cluster = None session = None log.warning("Unregistering default connection '{0}'. Use set_default_connection to set a new one.".format(name)) conn = _connections[name] if conn.cluster: conn.cluster.shutdown() del _connections[name] log.debug("Connection '{0}' has been removed from the registry.".format(name)) def set_default_connection(name): global cluster, session if name not in _connections: raise CQLEngineException("Connection '{0}' doesn't exist.".format(name)) log.debug("Connection '{0}' has been set as default.".format(name)) _connections[DEFAULT_CONNECTION] = _connections[name] cluster = _connections[name].cluster session = _connections[name].session def get_connection(name=None): if not name: name = DEFAULT_CONNECTION if name not in _connections: raise CQLEngineException("Connection name '{0}' doesn't exist in the registry.".format(name)) conn = _connections[name] conn.handle_lazy_connect() return conn def default(): """ Configures the default connection to localhost, using the driver defaults (except for row_factory) """ try: conn = get_connection() if conn.session: log.warning("configuring new default connection for cqlengine 
when one was already set") except Exception: pass conn = register_connection('default', hosts=None, default=True) conn.setup() log.debug("cqlengine connection initialized with default session to localhost") def set_session(s): """ Configures the default connection with a preexisting :class:`cassandra.cluster.Session` Note: the mapper presently requires a Session :attr:`~.row_factory` set to ``dict_factory``. This may be relaxed in the future """ conn = get_connection() if conn.session: log.warning("configuring new default connection for cqlengine when one was already set") if s.row_factory is not dict_factory: raise CQLEngineException("Failed to initialize: 'Session.row_factory' must be 'dict_factory'.") conn.session = s conn.cluster = s.cluster # Set default keyspace from given session's keyspace if conn.session.keyspace: from cassandra.cqlengine import models models.DEFAULT_KEYSPACE = conn.session.keyspace conn.setup_session() log.debug("cqlengine default connection initialized with %s", s) def setup( hosts, default_keyspace, consistency=None, lazy_connect=False, retry_connect=False, **kwargs): """ Set up the driver connection used by the mapper :param list hosts: list of hosts, (``contact_points`` for :class:`cassandra.cluster.Cluster`) :param str default_keyspace: The default keyspace to use :param int consistency: The global default :class:`~.ConsistencyLevel` - default is the same as :attr:`.Session.default_consistency_level` :param bool lazy_connect: True if the connection should not be established until first use :param bool retry_connect: True if connecting should be re-attempted on first use when the initial connection attempt fails :param \*\*kwargs: Pass-through keyword arguments for :class:`cassandra.cluster.Cluster` """ from cassandra.cqlengine import models models.DEFAULT_KEYSPACE = default_keyspace register_connection('default', hosts=hosts, consistency=consistency, lazy_connect=lazy_connect, retry_connect=retry_connect, cluster_options=kwargs, default=True) def execute(query, params=None,
consistency_level=None, timeout=NOT_SET, connection=None): conn = get_connection(connection) if not conn.session: raise CQLEngineException("It is required to setup() cqlengine before executing queries") if isinstance(query, SimpleStatement): pass # elif isinstance(query, BaseCQLStatement): params = query.get_context() query = SimpleStatement(str(query), consistency_level=consistency_level, fetch_size=query.fetch_size) elif isinstance(query, six.string_types): query = SimpleStatement(query, consistency_level=consistency_level) log.debug(format_log_context(query.query_string, connection=connection)) result = conn.session.execute(query, params, timeout=timeout) return result def get_session(connection=None): conn = get_connection(connection) return conn.session def get_cluster(connection=None): conn = get_connection(connection) if not conn.cluster: raise CQLEngineException("%s.cluster is not configured. Call one of the setup or default functions first." % __name__) return conn.cluster def register_udt(keyspace, type_name, klass, connection=None): udt_by_keyspace[keyspace][type_name] = klass cluster = get_cluster(connection) if cluster: try: cluster.register_user_type(keyspace, type_name, klass) except UserTypeDoesNotExist: pass # new types are covered in management sync functions def _register_known_types(cluster): from cassandra.cqlengine import models for ks_name, name_type_map in udt_by_keyspace.items(): for type_name, klass in name_type_map.items(): try: cluster.register_user_type(ks_name or models.DEFAULT_KEYSPACE, type_name, klass) except UserTypeDoesNotExist: pass # new types are covered in management sync functions cassandra-driver-3.7.1/cassandra/cqlengine/management.py0000664000175000017500000005447212766043721026232 0ustar aboudreaultaboudreault00000000000000# Copyright 2013-2016 DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from collections import namedtuple import json import logging import os import six import warnings from itertools import product from cassandra import metadata from cassandra.cqlengine import CQLEngineException from cassandra.cqlengine import columns, query from cassandra.cqlengine.connection import execute, get_cluster, format_log_context from cassandra.cqlengine.models import Model from cassandra.cqlengine.named import NamedTable from cassandra.cqlengine.usertype import UserType CQLENG_ALLOW_SCHEMA_MANAGEMENT = 'CQLENG_ALLOW_SCHEMA_MANAGEMENT' Field = namedtuple('Field', ['name', 'type']) log = logging.getLogger(__name__) # system keyspaces schema_columnfamilies = NamedTable('system', 'schema_columnfamilies') def _get_context(keyspaces, connections): """Return all the execution contexts""" if keyspaces: if not isinstance(keyspaces, (list, tuple)): raise ValueError('keyspaces must be a list or a tuple.') if connections: if not isinstance(connections, (list, tuple)): raise ValueError('connections must be a list or a tuple.') keyspaces = keyspaces if keyspaces else [None] connections = connections if connections else [None] return product(connections, keyspaces) def create_keyspace_simple(name, replication_factor, durable_writes=True, connections=None): """ Creates a keyspace with SimpleStrategy for replica placement If the keyspace already exists, it will not be modified. **This function should be used with caution, especially in production environments. Take care to execute schema modifications in a single context (i.e. 
not concurrently with other clients).** *There are plans to guard schema-modifying functions with an environment-driven conditional.* :param str name: name of keyspace to create :param int replication_factor: keyspace replication factor, used with :attr:`~.SimpleStrategy` :param bool durable_writes: Write log is bypassed if set to False :param list connections: List of connection names """ _create_keyspace(name, durable_writes, 'SimpleStrategy', {'replication_factor': replication_factor}, connections=connections) def create_keyspace_network_topology(name, dc_replication_map, durable_writes=True, connections=None): """ Creates a keyspace with NetworkTopologyStrategy for replica placement If the keyspace already exists, it will not be modified. **This function should be used with caution, especially in production environments. Take care to execute schema modifications in a single context (i.e. not concurrently with other clients).** *There are plans to guard schema-modifying functions with an environment-driven conditional.* :param str name: name of keyspace to create :param dict dc_replication_map: map of dc_names: replication_factor :param bool durable_writes: Write log is bypassed if set to False :param list connections: List of connection names """ _create_keyspace(name, durable_writes, 'NetworkTopologyStrategy', dc_replication_map, connections=connections) def _create_keyspace(name, durable_writes, strategy_class, strategy_options, connections=None): if not _allow_schema_modification(): return if connections: if not isinstance(connections, (list, tuple)): raise ValueError('Connections must be a list or a tuple.') def __create_keyspace(name, durable_writes, strategy_class, strategy_options, connection=None): cluster = get_cluster(connection) if name not in cluster.metadata.keyspaces: log.info(format_log_context("Creating keyspace %s", connection=connection), name) ks_meta = metadata.KeyspaceMetadata(name, durable_writes, strategy_class, strategy_options) 
execute(ks_meta.as_cql_query(), connection=connection) else: log.info(format_log_context("Not creating keyspace %s because it already exists", connection=connection), name) if connections: for connection in connections: __create_keyspace(name, durable_writes, strategy_class, strategy_options, connection=connection) else: __create_keyspace(name, durable_writes, strategy_class, strategy_options) def drop_keyspace(name, connections=None): """ Drops a keyspace, if it exists. *There are plans to guard schema-modifying functions with an environment-driven conditional.* **This function should be used with caution, especially in production environments. Take care to execute schema modifications in a single context (i.e. not concurrently with other clients).** :param str name: name of keyspace to drop :param list connections: List of connection names """ if not _allow_schema_modification(): return if connections: if not isinstance(connections, (list, tuple)): raise ValueError('Connections must be a list or a tuple.') def _drop_keyspace(name, connection=None): cluster = get_cluster(connection) if name in cluster.metadata.keyspaces: execute("DROP KEYSPACE {0}".format(metadata.protect_name(name)), connection=connection) if connections: for connection in connections: _drop_keyspace(name, connection) else: _drop_keyspace(name) def _get_index_name_by_column(table, column_name): """ Find the index name for a given table and column. """ protected_name = metadata.protect_name(column_name) possible_index_values = [protected_name, "values(%s)" % protected_name] for index_metadata in table.indexes.values(): options = dict(index_metadata.index_options) if options.get('target') in possible_index_values: return index_metadata.name def sync_table(model, keyspaces=None, connections=None): """ Inspects the model and creates / updates the corresponding table and columns. If `keyspaces` is specified, the table will be synched for all specified keyspaces. 
Note that the `Model.__keyspace__` is ignored in that case. If `connections` is specified, the table will be synched for all specified connections. Note that the `Model.__connection__` is ignored in that case. If not specified, it will try to get the connection from the Model. Any User Defined Types used in the table are implicitly synchronized. This function can only add fields that are not part of the primary key. Note that the attributes removed from the model are not deleted on the database. They become effectively ignored by (will not show up on) the model. **This function should be used with caution, especially in production environments. Take care to execute schema modifications in a single context (i.e. not concurrently with other clients).** *There are plans to guard schema-modifying functions with an environment-driven conditional.* """ context = _get_context(keyspaces, connections) for connection, keyspace in context: with query.ContextQuery(model, keyspace=keyspace) as m: _sync_table(m, connection=connection) def _sync_table(model, connection=None): if not _allow_schema_modification(): return if not issubclass(model, Model): raise CQLEngineException("Models must be derived from base Model.") if model.__abstract__: raise CQLEngineException("cannot create table from abstract model") cf_name = model.column_family_name() raw_cf_name = model._raw_column_family_name() ks_name = model._get_keyspace() connection = connection or model._get_connection() cluster = get_cluster(connection) try: keyspace = cluster.metadata.keyspaces[ks_name] except KeyError: msg = format_log_context("Keyspace '{0}' for model {1} does not exist.", connection=connection) raise CQLEngineException(msg.format(ks_name, model)) tables = keyspace.tables syncd_types = set() for col in model._columns.values(): udts = [] columns.resolve_udts(col, udts) for udt in [u for u in udts if u not in syncd_types]: _sync_type(ks_name, udt, syncd_types, connection=connection) if raw_cf_name not in tables: 
log.debug(format_log_context("sync_table creating new table %s", keyspace=ks_name, connection=connection), cf_name) qs = _get_create_table(model) try: execute(qs, connection=connection) except CQLEngineException as ex: # 1.2 doesn't return cf names, so we have to examine the exception # and ignore if it says the column family already exists if "Cannot add already existing column family" not in six.text_type(ex): raise else: log.debug(format_log_context("sync_table checking existing table %s", keyspace=ks_name, connection=connection), cf_name) table_meta = tables[raw_cf_name] _validate_pk(model, table_meta) table_columns = table_meta.columns model_fields = set() for model_name, col in model._columns.items(): db_name = col.db_field_name model_fields.add(db_name) if db_name in table_columns: col_meta = table_columns[db_name] if col_meta.cql_type != col.db_type: msg = format_log_context('Existing table {0} has column "{1}" with a type ({2}) differing from the model type ({3}).' ' Model should be updated.', keyspace=ks_name, connection=connection) msg = msg.format(cf_name, db_name, col_meta.cql_type, col.db_type) warnings.warn(msg) log.warning(msg) continue if col.primary_key or col.partition_key: msg = format_log_context("Cannot add primary key '{0}' (with db_field '{1}') to existing table {2}", keyspace=ks_name, connection=connection) raise CQLEngineException(msg.format(model_name, db_name, cf_name)) alter_query = "ALTER TABLE {0} add {1}".format(cf_name, col.get_column_def()) execute(alter_query, connection=connection) db_fields_not_in_model = model_fields.symmetric_difference(table_columns) if db_fields_not_in_model: msg = format_log_context("Table {0} has fields not referenced by model: {1}", keyspace=ks_name, connection=connection) log.info(msg.format(cf_name, db_fields_not_in_model)) _update_options(model, connection=connection) table = cluster.metadata.keyspaces[ks_name].tables[raw_cf_name] indexes = [c for n, c in model._columns.items() if c.index] # TODO: support multiple indexes
in C* 3.0+ for column in indexes: index_name = _get_index_name_by_column(table, column.db_field_name) if index_name: continue qs = ['CREATE INDEX'] qs += ['ON {0}'.format(cf_name)] qs += ['("{0}")'.format(column.db_field_name)] qs = ' '.join(qs) execute(qs, connection=connection) def _validate_pk(model, table_meta): model_partition = [c.db_field_name for c in model._partition_keys.values()] meta_partition = [c.name for c in table_meta.partition_key] model_clustering = [c.db_field_name for c in model._clustering_keys.values()] meta_clustering = [c.name for c in table_meta.clustering_key] if model_partition != meta_partition or model_clustering != meta_clustering: def _pk_string(partition, clustering): return "PRIMARY KEY (({0}){1})".format(', '.join(partition), ', ' + ', '.join(clustering) if clustering else '') raise CQLEngineException("Model {0} PRIMARY KEY composition does not match existing table {1}. " "Model: {2}; Table: {3}. " "Update model or drop the table.".format(model, model.column_family_name(), _pk_string(model_partition, model_clustering), _pk_string(meta_partition, meta_clustering))) def sync_type(ks_name, type_model, connection=None): """ Inspects the type_model and creates / updates the corresponding type. Note that the attributes removed from the type_model are not deleted on the database (this operation is not supported). They become effectively ignored by (will not show up on) the type_model. **This function should be used with caution, especially in production environments. Take care to execute schema modifications in a single context (i.e. 
not concurrently with other clients).** *There are plans to guard schema-modifying functions with an environment-driven conditional.* """ if not _allow_schema_modification(): return if not issubclass(type_model, UserType): raise CQLEngineException("Types must be derived from base UserType.") _sync_type(ks_name, type_model, connection=connection) def _sync_type(ks_name, type_model, omit_subtypes=None, connection=None): syncd_sub_types = omit_subtypes or set() for field in type_model._fields.values(): udts = [] columns.resolve_udts(field, udts) for udt in [u for u in udts if u not in syncd_sub_types]: _sync_type(ks_name, udt, syncd_sub_types, connection=connection) syncd_sub_types.add(udt) type_name = type_model.type_name() type_name_qualified = "%s.%s" % (ks_name, type_name) cluster = get_cluster(connection) keyspace = cluster.metadata.keyspaces[ks_name] defined_types = keyspace.user_types if type_name not in defined_types: log.debug(format_log_context("sync_type creating new type %s", keyspace=ks_name, connection=connection), type_name_qualified) cql = get_create_type(type_model, ks_name) execute(cql, connection=connection) cluster.refresh_user_type_metadata(ks_name, type_name) type_model.register_for_keyspace(ks_name, connection=connection) else: type_meta = defined_types[type_name] defined_fields = type_meta.field_names model_fields = set() for field in type_model._fields.values(): model_fields.add(field.db_field_name) if field.db_field_name not in defined_fields: execute("ALTER TYPE {0} ADD {1}".format(type_name_qualified, field.get_column_def()), connection=connection) else: field_type = type_meta.field_types[defined_fields.index(field.db_field_name)] if field_type != field.db_type: msg = format_log_context('Existing user type {0} has field "{1}" with a type ({2}) differing from the model user type ({3}).' 
' UserType should be updated.', keyspace=ks_name, connection=connection) msg = msg.format(type_name_qualified, field.db_field_name, field_type, field.db_type) warnings.warn(msg) log.warning(msg) type_model.register_for_keyspace(ks_name, connection=connection) if len(defined_fields) == len(model_fields): log.info(format_log_context("Type %s did not require synchronization", keyspace=ks_name, connection=connection), type_name_qualified) return db_fields_not_in_model = model_fields.symmetric_difference(defined_fields) if db_fields_not_in_model: msg = format_log_context("Type %s has fields not referenced by model: %s", keyspace=ks_name, connection=connection) log.info(msg, type_name_qualified, db_fields_not_in_model) def get_create_type(type_model, keyspace): type_meta = metadata.UserType(keyspace, type_model.type_name(), (f.db_field_name for f in type_model._fields.values()), (v.db_type for v in type_model._fields.values())) return type_meta.as_cql_query() def _get_create_table(model): ks_table_name = model.column_family_name() query_strings = ['CREATE TABLE {0}'.format(ks_table_name)] # add column types pkeys = [] # primary keys ckeys = [] # clustering keys qtypes = [] # field types def add_column(col): s = col.get_column_def() if col.primary_key: keys = (pkeys if col.partition_key else ckeys) keys.append('"{0}"'.format(col.db_field_name)) qtypes.append(s) for name, col in model._columns.items(): add_column(col) qtypes.append('PRIMARY KEY (({0}){1})'.format(', '.join(pkeys), ckeys and ', ' + ', '.join(ckeys) or '')) query_strings += ['({0})'.format(', '.join(qtypes))] property_strings = [] _order = ['"{0}" {1}'.format(c.db_field_name, c.clustering_order or 'ASC') for c in model._clustering_keys.values()] if _order: property_strings.append('CLUSTERING ORDER BY ({0})'.format(', '.join(_order))) # options strings use the V3 format, which matches CQL more closely and does not require mapping property_strings += 
metadata.TableMetadataV3._make_option_strings(model.__options__ or {}) if property_strings: query_strings += ['WITH {0}'.format(' AND '.join(property_strings))] return ' '.join(query_strings) def _get_table_metadata(model, connection=None): # returns the table as provided by the native driver for a given model cluster = get_cluster(connection) ks = model._get_keyspace() table = model._raw_column_family_name() table = cluster.metadata.keyspaces[ks].tables[table] return table def _options_map_from_strings(option_strings): # converts option strings to a mapping of option name to a string or dict value options = {} for option in option_strings: name, value = option.split('=', 1) i = value.find('{') if i >= 0: value = value[i:value.rfind('}') + 1].replace("'", '"') # from cql single quotes to json double; not aware of any values that would be escaped right now value = json.loads(value) else: value = value.strip() options[name.strip()] = value return options def _update_options(model, connection=None): """Updates the table options for the given model if necessary. :param model: The model to update. :param connection: Name of the connection to use :return: `True`, if the options were modified in Cassandra, `False` otherwise.
:rtype: bool """ ks_name = model._get_keyspace() msg = format_log_context("Checking %s for option differences", keyspace=ks_name, connection=connection) log.debug(msg, model) model_options = model.__options__ or {} table_meta = _get_table_metadata(model, connection=connection) # go to CQL string first to normalize meta from different versions existing_option_strings = set(table_meta._make_option_strings(table_meta.options)) existing_options = _options_map_from_strings(existing_option_strings) model_option_strings = metadata.TableMetadataV3._make_option_strings(model_options) model_options = _options_map_from_strings(model_option_strings) update_options = {} for name, value in model_options.items(): try: existing_value = existing_options[name] except KeyError: msg = format_log_context("Invalid table option: '%s'; known options: %s", keyspace=ks_name, connection=connection) raise KeyError(msg % (name, existing_options.keys())) if isinstance(existing_value, six.string_types): if value != existing_value: update_options[name] = value else: try: for k, v in value.items(): if existing_value[k] != v: update_options[name] = value break except KeyError: update_options[name] = value if update_options: options = ' AND '.join(metadata.TableMetadataV3._make_option_strings(update_options)) query = "ALTER TABLE {0} WITH {1}".format(model.column_family_name(), options) execute(query, connection=connection) return True return False def drop_table(model, keyspaces=None, connections=None): """ Drops the table indicated by the model, if it exists. If `keyspaces` is specified, the table will be dropped for all specified keyspaces. Note that the `Model.__keyspace__` is ignored in that case. If `connections` is specified, the table will be dropped for all specified connections. Note that the `Model.__connection__` is ignored in that case. If not specified, it will try to get the connection from the Model. **This function should be used with caution, especially in production environments.
Take care to execute schema modifications in a single context (i.e. not concurrently with other clients).** *There are plans to guard schema-modifying functions with an environment-driven conditional.* """ context = _get_context(keyspaces, connections) for connection, keyspace in context: with query.ContextQuery(model, keyspace=keyspace) as m: _drop_table(m, connection=connection) def _drop_table(model, connection=None): if not _allow_schema_modification(): return connection = connection or model._get_connection() # don't try to delete nonexistent tables meta = get_cluster(connection).metadata ks_name = model._get_keyspace() raw_cf_name = model._raw_column_family_name() try: meta.keyspaces[ks_name].tables[raw_cf_name] execute('DROP TABLE {0};'.format(model.column_family_name()), connection=connection) except KeyError: pass def _allow_schema_modification(): if not os.getenv(CQLENG_ALLOW_SCHEMA_MANAGEMENT): msg = CQLENG_ALLOW_SCHEMA_MANAGEMENT + " environment variable is not set. Future versions of this package will require this variable to enable management functions." warnings.warn(msg) log.warning(msg) return True cassandra-driver-3.7.1/cassandra/cqlengine/operators.py0000664000175000017500000000447212743410406026120 0ustar aboudreaultaboudreault00000000000000# Copyright 2013-2016 DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.
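The operator classes in the operators module register themselves through the `OpMapMeta` metaclass: the first class built with it receives an empty `opmap`, and every later subclass adds itself to that same dict under its `symbol`, so `get_operator` can resolve the uppercased suffix of a filter keyword. Below is a minimal, Python 3-only sketch of that registry pattern (the original uses `six.add_metaclass` for Python 2 compatibility). The class names `WhereOperator`, `Equals`, `GreaterThan`, `In` and the `split_filter_key` helper are hypothetical simplifications, not cqlengine's API.

```python
class OpMapMeta(type):
    """Registers each subclass in a shared ``opmap`` dict keyed by ``symbol``."""

    def __init__(cls, name, bases, dct):
        if not hasattr(cls, "opmap"):
            # First class built with this metaclass: create the shared registry.
            cls.opmap = {}
        else:
            # Subclasses inherit the same dict object and add themselves to it.
            cls.opmap[cls.symbol] = cls
        super().__init__(name, bases, dct)


class WhereOperator(metaclass=OpMapMeta):
    symbol = None
    cql_symbol = None

    @classmethod
    def get_operator(cls, symbol):
        # Filter keywords are typically lowercase (e.g. "age__gt");
        # registered symbols are uppercase, so normalize before lookup.
        try:
            return cls.opmap[symbol.upper()]
        except KeyError:
            raise ValueError("{0} doesn't map to an operator".format(symbol))


class Equals(WhereOperator):
    symbol = "EQ"
    cql_symbol = "="


class GreaterThan(WhereOperator):
    symbol = "GT"
    cql_symbol = ">"


class In(Equals):
    symbol = "IN"
    cql_symbol = "IN"


def split_filter_key(key):
    """Hypothetical helper: 'age__gt' -> ('age', GreaterThan); 'age' -> ('age', Equals)."""
    column, _, op = key.partition("__")
    return column, WhereOperator.get_operator(op or "eq")


print(sorted(WhereOperator.opmap))  # ['EQ', 'GT', 'IN']
```

Because subclasses share the base class's dict rather than copying it, `In` (a subclass of `Equals`) lands in the same registry as the direct subclasses.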
import six from cassandra.cqlengine import UnicodeMixin class QueryOperatorException(Exception): pass class BaseQueryOperator(UnicodeMixin): # The symbol that identifies this operator in kwargs # ie: colname__ symbol = None # The comparator symbol this operator uses in cql cql_symbol = None def __unicode__(self): if self.cql_symbol is None: raise QueryOperatorException("cql symbol is None") return self.cql_symbol class OpMapMeta(type): def __init__(cls, name, bases, dct): if not hasattr(cls, 'opmap'): cls.opmap = {} else: cls.opmap[cls.symbol] = cls super(OpMapMeta, cls).__init__(name, bases, dct) @six.add_metaclass(OpMapMeta) class BaseWhereOperator(BaseQueryOperator): """ base operator used for where clauses """ @classmethod def get_operator(cls, symbol): try: return cls.opmap[symbol.upper()] except KeyError: raise QueryOperatorException("{0} doesn't map to a QueryOperator".format(symbol)) class EqualsOperator(BaseWhereOperator): symbol = 'EQ' cql_symbol = '=' class NotEqualsOperator(BaseWhereOperator): symbol = 'NE' cql_symbol = '!=' class InOperator(EqualsOperator): symbol = 'IN' cql_symbol = 'IN' class GreaterThanOperator(BaseWhereOperator): symbol = "GT" cql_symbol = '>' class GreaterThanOrEqualOperator(BaseWhereOperator): symbol = "GTE" cql_symbol = '>=' class LessThanOperator(BaseWhereOperator): symbol = "LT" cql_symbol = '<' class LessThanOrEqualOperator(BaseWhereOperator): symbol = "LTE" cql_symbol = '<=' class ContainsOperator(EqualsOperator): symbol = "CONTAINS" cql_symbol = 'CONTAINS' cassandra-driver-3.7.1/cassandra/cqlengine/columns.py0000664000175000017500000007300012777231260025561 0ustar aboudreaultaboudreault00000000000000# Copyright 2013-2016 DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from copy import deepcopy, copy from datetime import date, datetime, timedelta import logging import six from uuid import UUID as _UUID from cassandra import util from cassandra.cqltypes import SimpleDateType, _cqltypes, UserType from cassandra.cqlengine import ValidationError from cassandra.cqlengine.functions import get_total_seconds log = logging.getLogger(__name__) class BaseValueManager(object): def __init__(self, instance, column, value): self.instance = instance self.column = column self.value = value self.previous_value = None self.explicit = False @property def deleted(self): return self.column._val_is_null(self.value) and (self.explicit or self.previous_value is not None) @property def changed(self): """ Indicates whether or not this value has changed. :rtype: boolean """ return self.value != self.previous_value def reset_previous_value(self): self.previous_value = deepcopy(self.value) def getval(self): return self.value def setval(self, val): self.value = val def delval(self): self.value = None def get_property(self): _get = lambda slf: self.getval() _set = lambda slf, val: self.setval(val) _del = lambda slf: self.delval() if self.column.can_delete: return property(_get, _set, _del) else: return property(_get, _set) class Column(object): # the cassandra type this column maps to db_type = None value_manager = BaseValueManager instance_counter = 0 _python_type_hashable = True primary_key = False """ bool flag, indicates this column is a primary key. 
The first primary key defined on a model is the partition key (unless partition keys are set); all others are clustering keys """ partition_key = False """ indicates that this column should be the partition key; defining more than one partition key column creates a compound partition key """ index = False """ bool flag, indicates an index should be created for this column """ db_field = None """ the field name this column will map to in the database """ default = None """ the default value, can be a value or a callable (no args) """ required = False """ boolean, is the field required? Model validation will raise an exception if required is set to True and there is a None value assigned """ clustering_order = None """ only applicable on clustering keys (primary keys that are not partition keys); determines the order in which the clustering keys are sorted on disk """ discriminator_column = False """ boolean, if set to True, this column will be used for discriminating records of inherited models. Should only be set on a column of an abstract model being used for inheritance. There may only be one discriminator column per model. See :attr:`~.__discriminator_value__` for how to specify the value of this column on specialized models.
""" static = False """ boolean, if set to True, this is a static column, with a single value per partition """ def __init__(self, primary_key=False, partition_key=False, index=False, db_field=None, default=None, required=False, clustering_order=None, discriminator_column=False, static=False): self.partition_key = partition_key self.primary_key = partition_key or primary_key self.index = index self.db_field = db_field self.default = default self.required = required self.clustering_order = clustering_order self.discriminator_column = discriminator_column # the column name in the model definition self.column_name = None self._partition_key_index = None self.static = static self.value = None # keep track of instantiation order self.position = Column.instance_counter Column.instance_counter += 1 def __ne__(self, other): if isinstance(other, Column): return self.position != other.position return NotImplemented def __eq__(self, other): if isinstance(other, Column): return self.position == other.position return NotImplemented def __lt__(self, other): if isinstance(other, Column): return self.position < other.position return NotImplemented def __le__(self, other): if isinstance(other, Column): return self.position <= other.position return NotImplemented def __gt__(self, other): if isinstance(other, Column): return self.position > other.position return NotImplemented def __ge__(self, other): if isinstance(other, Column): return self.position >= other.position return NotImplemented def __hash__(self): return id(self) def validate(self, value): """ Returns a cleaned and validated value. 
Raises a ValidationError if there's a problem """ if value is None: if self.required: raise ValidationError('{0} - None values are not allowed'.format(self.column_name or self.db_field)) return value def to_python(self, value): """ Converts data from the database into python values; raises a ValidationError if the value can't be converted """ return value def to_database(self, value): """ Converts python value into database value """ if value is None and self.has_default: return self.get_default() return value @property def has_default(self): return self.default is not None @property def is_primary_key(self): return self.primary_key @property def can_delete(self): return not self.primary_key def get_default(self): if self.has_default: if callable(self.default): return self.default() else: return self.default def get_column_def(self): """ Returns a column definition for a CQL table definition """ static = "static" if self.static else "" return '{0} {1} {2}'.format(self.cql, self.db_type, static) # TODO: make columns use cqltypes under the hood # until then, this bridges the gap in using types along with cassandra.metadata for CQL generation def cql_parameterized_type(self): return self.db_type def set_column_name(self, name): """ Sets the column name during document class construction. This value will be ignored if db_field is set in __init__ """ self.column_name = name @property def db_field_name(self): """ Returns the CQL name of this column """ return self.db_field or self.column_name @property def db_index_name(self): """ Returns the name of the cql index """ return 'index_{0}'.format(self.db_field_name) @property def cql(self): return self.get_cql() def get_cql(self): return '"{0}"'.format(self.db_field_name) def _val_is_null(self, val): """ determines if the given value equates to a null value for the given column type """ return val is None @property def sub_types(self): return [] @property def cql_type(self): return _cqltypes[self.db_type] class
Blob(Column): """ Stores a raw binary value """ db_type = 'blob' def to_database(self, value): if not isinstance(value, (six.binary_type, bytearray)): raise Exception("expecting a binary, got a %s" % type(value)) val = super(Blob, self).to_database(value) return bytearray(val) Bytes = Blob class Inet(Column): """ Stores an IP address in IPv4 or IPv6 format """ db_type = 'inet' class Text(Column): """ Stores a UTF-8 encoded string """ db_type = 'text' def __init__(self, min_length=None, max_length=None, **kwargs): """ :param int min_length: Sets the minimum length of this string, for validation purposes. Defaults to 1 if this is a ``required`` column. Otherwise, None. :param int max_length: Sets the maximum length of this string, for validation purposes. """ self.min_length = ( 1 if not min_length and kwargs.get('required', False) else min_length) self.max_length = max_length if self.min_length is not None: if self.min_length < 0: raise ValueError( 'Minimum length is not allowed to be negative.') if self.max_length is not None: if self.max_length < 0: raise ValueError( 'Maximum length is not allowed to be negative.') if self.min_length is not None and self.max_length is not None: if self.max_length < self.min_length: raise ValueError( 'Maximum length must be greater than or equal ' 'to minimum length.') super(Text, self).__init__(**kwargs) def validate(self, value): value = super(Text, self).validate(value) if not isinstance(value, (six.string_types, bytearray)) and value is not None: raise ValidationError('{0} {1} is not a string'.format(self.column_name, type(value))) if self.max_length is not None: if value and len(value) > self.max_length: raise ValidationError('{0} is longer than {1} characters'.format(self.column_name, self.max_length)) if self.min_length: if (self.min_length and not value) or len(value) < self.min_length: raise ValidationError('{0} is shorter than {1} characters'.format(self.column_name, self.min_length)) return value class Ascii(Text): """
Stores a US-ASCII character string """ db_type = 'ascii' def validate(self, value): """ Only allow ASCII and None values. Check against US-ASCII, a.k.a. 7-bit ASCII, a.k.a. ISO646-US, a.k.a. the Basic Latin block of the Unicode character set. Source: https://github.com/apache/cassandra/blob/3dcbe90e02440e6ee534f643c7603d50ca08482b/src/java/org/apache/cassandra/serializers/AsciiSerializer.java#L29 """ value = super(Ascii, self).validate(value) if value: charset = value if isinstance( value, (bytearray, )) else map(ord, value) if not set(range(128)).issuperset(charset): raise ValidationError( '{!r} is not an ASCII string.'.format(value)) return value class Integer(Column): """ Stores a 32-bit signed integer value """ db_type = 'int' def validate(self, value): val = super(Integer, self).validate(value) if val is None: return try: return int(val) except (TypeError, ValueError): raise ValidationError("{0} {1} can't be converted to integral value".format(self.column_name, value)) def to_python(self, value): return self.validate(value) def to_database(self, value): return self.validate(value) class TinyInt(Integer): """ Stores an 8-bit signed integer value .. versionadded:: 2.6.0 requires C* 2.2+ and protocol v4+ """ db_type = 'tinyint' class SmallInt(Integer): """ Stores a 16-bit signed integer value ..
versionadded:: 2.6.0 requires C* 2.2+ and protocol v4+ """ db_type = 'smallint' class BigInt(Integer): """ Stores a 64-bit signed integer value """ db_type = 'bigint' class VarInt(Column): """ Stores an arbitrary-precision integer """ db_type = 'varint' def validate(self, value): val = super(VarInt, self).validate(value) if val is None: return try: return int(val) except (TypeError, ValueError): raise ValidationError( "{0} {1} can't be converted to integral value".format(self.column_name, value)) def to_python(self, value): return self.validate(value) def to_database(self, value): return self.validate(value) class CounterValueManager(BaseValueManager): def __init__(self, instance, column, value): super(CounterValueManager, self).__init__(instance, column, value) self.value = self.value or 0 self.previous_value = self.previous_value or 0 class Counter(Integer): """ Stores a counter that can be incremented and decremented """ db_type = 'counter' value_manager = CounterValueManager def __init__(self, index=False, db_field=None, required=False): super(Counter, self).__init__( primary_key=False, partition_key=False, index=index, db_field=db_field, default=0, required=required, ) class DateTime(Column): """ Stores a datetime value """ db_type = 'timestamp' truncate_microseconds = False """ Set this to ``True`` to have model instances truncate the date, quantizing it in the same way it will be in the database. This allows equality comparison between assigned values and values read back from the database:: DateTime.truncate_microseconds = True assert Model.create(id=0, d=datetime.utcnow()) == Model.objects(id=0).first() Defaults to ``False`` to preserve legacy behavior. May change in the future.
""" def to_python(self, value): if value is None: return if isinstance(value, datetime): if DateTime.truncate_microseconds: us = value.microsecond truncated_us = us // 1000 * 1000 return value - timedelta(microseconds=us - truncated_us) else: return value elif isinstance(value, date): return datetime(*(value.timetuple()[:6])) return datetime.utcfromtimestamp(value) def to_database(self, value): value = super(DateTime, self).to_database(value) if value is None: return if not isinstance(value, datetime): if isinstance(value, date): value = datetime(value.year, value.month, value.day) else: raise ValidationError("{0} '{1}' is not a datetime object".format(self.column_name, value)) epoch = datetime(1970, 1, 1, tzinfo=value.tzinfo) offset = get_total_seconds(epoch.tzinfo.utcoffset(epoch)) if epoch.tzinfo else 0 return int((get_total_seconds(value - epoch) - offset) * 1000) class Date(Column): """ Stores a simple date, with no time-of-day .. versionchanged:: 2.6.0 removed overload of Date and DateTime. DateTime is a drop-in replacement for legacy models requires C* 2.2+ and protocol v4+ """ db_type = 'date' def to_database(self, value): value = super(Date, self).to_database(value) if value is None: return # need to translate to int version because some dates are not representable in # string form (datetime limitation) d = value if isinstance(value, util.Date) else util.Date(value) return d.days_from_epoch + SimpleDateType.EPOCH_OFFSET_DAYS class Time(Column): """ Stores a timezone-naive time-of-day, with nanosecond precision .. 
versionadded:: 2.6.0 requires C* 2.2+ and protocol v4+ """ db_type = 'time' def to_database(self, value): value = super(Time, self).to_database(value) if value is None: return # str(util.Time) yields desired CQL encoding return value if isinstance(value, util.Time) else util.Time(value) class UUID(Column): """ Stores a type 1 or 4 UUID """ db_type = 'uuid' def validate(self, value): val = super(UUID, self).validate(value) if val is None: return if isinstance(val, _UUID): return val if isinstance(val, six.string_types): try: return _UUID(val) except ValueError: # fall-through to error pass raise ValidationError("{0} {1} is not a valid uuid".format( self.column_name, value)) def to_python(self, value): return self.validate(value) def to_database(self, value): return self.validate(value) class TimeUUID(UUID): """ UUID containing timestamp """ db_type = 'timeuuid' class Boolean(Column): """ Stores a boolean True or False value """ db_type = 'boolean' def validate(self, value): """ Always returns a Python boolean. 
""" value = super(Boolean, self).validate(value) if value is not None: value = bool(value) return value def to_python(self, value): return self.validate(value) class BaseFloat(Column): def validate(self, value): value = super(BaseFloat, self).validate(value) if value is None: return try: return float(value) except (TypeError, ValueError): raise ValidationError("{0} {1} is not a valid float".format(self.column_name, value)) def to_python(self, value): return self.validate(value) def to_database(self, value): return self.validate(value) class Float(BaseFloat): """ Stores a single-precision floating-point value """ db_type = 'float' class Double(BaseFloat): """ Stores a double-precision floating-point value """ db_type = 'double' class Decimal(Column): """ Stores a variable precision decimal value """ db_type = 'decimal' def validate(self, value): from decimal import Decimal as _Decimal from decimal import InvalidOperation val = super(Decimal, self).validate(value) if val is None: return try: return _Decimal(repr(val)) if isinstance(val, float) else _Decimal(val) except InvalidOperation: raise ValidationError("{0} '{1}' can't be coerced to decimal".format(self.column_name, val)) def to_python(self, value): return self.validate(value) def to_database(self, value): return self.validate(value) class BaseCollectionColumn(Column): """ Base Container type for collection-like columns. 
http://cassandra.apache.org/doc/cql3/CQL-3.0.html#collections """ def __init__(self, types, **kwargs): """ :param types: a sequence of sub types in this collection """ instances = [] for t in types: inheritance_comparator = issubclass if isinstance(t, type) else isinstance if not inheritance_comparator(t, Column): raise ValidationError("%s is not a column class" % (t,)) if t.db_type is None: raise ValidationError("%s is an abstract type" % (t,)) inst = t() if isinstance(t, type) else t if isinstance(t, BaseCollectionColumn): inst._freeze_db_type() instances.append(inst) self.types = instances super(BaseCollectionColumn, self).__init__(**kwargs) def validate(self, value): value = super(BaseCollectionColumn, self).validate(value) # It is dangerous to let collections have more than 65535 elements. # See: https://issues.apache.org/jira/browse/CASSANDRA-5428 if value is not None and len(value) > 65535: raise ValidationError("{0} Collection can't have more than 65535 elements.".format(self.column_name)) return value def _val_is_null(self, val): return not val def _freeze_db_type(self): if not self.db_type.startswith('frozen'): self.db_type = "frozen<%s>" % (self.db_type,) @property def sub_types(self): return self.types @property def cql_type(self): return _cqltypes[self.__class__.__name__.lower()].apply_parameters([c.cql_type for c in self.types]) class Tuple(BaseCollectionColumn): """ Stores a fixed-length set of positional values http://docs.datastax.com/en/cql/3.1/cql/cql_reference/tupleType.html """ def __init__(self, *args, **kwargs): """ :param args: column types representing tuple composition """ if not args: raise ValueError("Tuple must specify at least one inner type") super(Tuple, self).__init__(args, **kwargs) self.db_type = 'tuple<{0}>'.format(', '.join(typ.db_type for typ in self.types)) def validate(self, value): val = super(Tuple, self).validate(value) if val is None: return if len(val) > len(self.types): raise ValidationError("Value %r has more fields than tuple
definition (%s)" % (val, ', '.join(t.db_type for t in self.types))) return tuple(t.validate(v) for t, v in zip(self.types, val)) def to_python(self, value): if value is None: return tuple() return tuple(t.to_python(v) for t, v in zip(self.types, value)) def to_database(self, value): if value is None: return return tuple(t.to_database(v) for t, v in zip(self.types, value)) class BaseContainerColumn(BaseCollectionColumn): pass class Set(BaseContainerColumn): """ Stores a set of unordered, unique values http://www.datastax.com/documentation/cql/3.1/cql/cql_using/use_set_t.html """ _python_type_hashable = False def __init__(self, value_type, strict=True, default=set, **kwargs): """ :param value_type: a column class indicating the type of the value :param strict: sets whether non-set values will be coerced to set type on validation or raise a validation error; defaults to True """ self.strict = strict super(Set, self).__init__((value_type,), default=default, **kwargs) self.value_col = self.types[0] if not self.value_col._python_type_hashable: raise ValidationError("Cannot create a Set with unhashable value type (see PYTHON-494)") self.db_type = 'set<{0}>'.format(self.value_col.db_type) def validate(self, value): val = super(Set, self).validate(value) if val is None: return types = (set, util.SortedSet) if self.strict else (set, util.SortedSet, list, tuple) if not isinstance(val, types): if self.strict: raise ValidationError('{0} {1} is not a set object'.format(self.column_name, val)) else: raise ValidationError('{0} {1} cannot be coerced to a set object'.format(self.column_name, val)) if None in val: raise ValidationError("{0} None not allowed in a set".format(self.column_name)) # TODO: stop doing this conversion because it doesn't support non-hashable collections as keys (cassandra does) # will need to start using the cassandra.util types in the next major rev (PYTHON-494) return set(self.value_col.validate(v) for v in val) def to_python(self, value): if value is None:
return set() return set(self.value_col.to_python(v) for v in value) def to_database(self, value): if value is None: return None return set(self.value_col.to_database(v) for v in value) class List(BaseContainerColumn): """ Stores a list of ordered values http://www.datastax.com/documentation/cql/3.1/cql/cql_using/use_list_t.html """ _python_type_hashable = False def __init__(self, value_type, default=list, **kwargs): """ :param value_type: a column class indicating the types of the value """ super(List, self).__init__((value_type,), default=default, **kwargs) self.value_col = self.types[0] self.db_type = 'list<{0}>'.format(self.value_col.db_type) def validate(self, value): val = super(List, self).validate(value) if val is None: return if not isinstance(val, (set, list, tuple)): raise ValidationError('{0} {1} is not a list object'.format(self.column_name, val)) if None in val: raise ValidationError("{0} None is not allowed in a list".format(self.column_name)) return [self.value_col.validate(v) for v in val] def to_python(self, value): if value is None: return [] return [self.value_col.to_python(v) for v in value] def to_database(self, value): if value is None: return None return [self.value_col.to_database(v) for v in value] class Map(BaseContainerColumn): """ Stores a key -> value map (dictionary) http://www.datastax.com/documentation/cql/3.1/cql/cql_using/use_map_t.html """ _python_type_hashable = False def __init__(self, key_type, value_type, default=dict, **kwargs): """ :param key_type: a column class indicating the types of the key :param value_type: a column class indicating the types of the value """ super(Map, self).__init__((key_type, value_type), default=default, **kwargs) self.key_col = self.types[0] self.value_col = self.types[1] if not self.key_col._python_type_hashable: raise ValidationError("Cannot create a Map with unhashable key type (see PYTHON-494)") self.db_type = 'map<{0}, {1}>'.format(self.key_col.db_type, self.value_col.db_type) def 
validate(self, value): val = super(Map, self).validate(value) if val is None: return if not isinstance(val, (dict, util.OrderedMap)): raise ValidationError('{0} {1} is not a dict object'.format(self.column_name, val)) if None in val: raise ValidationError("{0} None is not allowed in a map".format(self.column_name)) # TODO: stop doing this conversion because it doesn't support non-hashable collections as keys (cassandra does) # will need to start using the cassandra.util types in the next major rev (PYTHON-494) return dict((self.key_col.validate(k), self.value_col.validate(v)) for k, v in val.items()) def to_python(self, value): if value is None: return {} return dict((self.key_col.to_python(k), self.value_col.to_python(v)) for k, v in value.items()) def to_database(self, value): if value is None: return None return dict((self.key_col.to_database(k), self.value_col.to_database(v)) for k, v in value.items()) class UDTValueManager(BaseValueManager): @property def changed(self): return self.value != self.previous_value or (self.value is not None and self.value.has_changed_fields()) def reset_previous_value(self): if self.value is not None: self.value.reset_changed_fields() self.previous_value = copy(self.value) class UserDefinedType(Column): """ User Defined Type column http://www.datastax.com/documentation/cql/3.1/cql/cql_using/cqlUseUDT.html These columns are represented by a specialization of :class:`cassandra.cqlengine.usertype.UserType`. Please see :ref:`user_types` for examples and discussion.
""" value_manager = UDTValueManager def __init__(self, user_type, **kwargs): """ :param type user_type: specifies the :class:`~.cqlengine.usertype.UserType` model of the column """ self.user_type = user_type self.db_type = "frozen<%s>" % user_type.type_name() super(UserDefinedType, self).__init__(**kwargs) @property def sub_types(self): return list(self.user_type._fields.values()) @property def cql_type(self): return UserType.make_udt_class(keyspace='', udt_name=self.user_type.type_name(), field_names=[c.db_field_name for c in self.user_type._fields.values()], field_types=[c.cql_type for c in self.user_type._fields.values()]) def resolve_udts(col_def, out_list): for col in col_def.sub_types: resolve_udts(col, out_list) if isinstance(col_def, UserDefinedType): out_list.append(col_def.user_type) class _PartitionKeysToken(Column): """ virtual column representing token of partition columns. Used by filter(pk__token=Token(...)) filters """ def __init__(self, model): self.partition_columns = model._partition_keys.values() super(_PartitionKeysToken, self).__init__(partition_key=True) @property def db_field_name(self): return 'token({0})'.format(', '.join(['"{0}"'.format(c.db_field_name) for c in self.partition_columns])) cassandra-driver-3.7.1/cassandra/connection.py0000664000175000017500000012205712766043721024303 0ustar aboudreaultaboudreault00000000000000# Copyright 2013-2016 DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
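The `resolve_udts` helper at the end of columns.py performs a post-order walk over `sub_types`: it recurses into each nested column before appending the current column's user type, so a UDT used as a field of another UDT is always listed before the type that contains it. The sketch below reproduces that traversal with hypothetical, stripped-down stand-in classes (`Column` and `UserDefinedType` here only carry `sub_types` and a type name, unlike the real cqlengine classes) to show the resulting order.

```python
class Column:
    """Hypothetical stand-in: only ``sub_types`` matters for the traversal."""

    def __init__(self, sub_types=()):
        self.sub_types = list(sub_types)


class UserDefinedType(Column):
    """Hypothetical stand-in: sub_types are the UDT's field columns."""

    def __init__(self, type_name, fields):
        super().__init__(fields)
        self.user_type = type_name


def resolve_udts(col_def, out_list):
    # Visit nested columns first (post-order) ...
    for col in col_def.sub_types:
        resolve_udts(col, out_list)
    # ... then record this column's UDT, so dependencies precede dependents.
    if isinstance(col_def, UserDefinedType):
        out_list.append(col_def.user_type)


address = UserDefinedType("address", [Column(), Column()])
person = UserDefinedType("person", [Column(), address])
people = Column([person])  # e.g. a list<frozen<person>> column

order = []
resolve_udts(people, order)
print(order)  # ['address', 'person']
```

The helper itself does not de-duplicate: a UDT referenced from two places is appended twice, so a caller that creates types from this list has to skip ones it has already handled.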
from __future__ import absolute_import # to enable import io from stdlib from collections import defaultdict, deque import errno from functools import wraps, partial from heapq import heappush, heappop import io import logging import six from six.moves import range import socket import struct import sys from threading import Thread, Event, RLock import time try: import ssl except ImportError: ssl = None # NOQA if 'gevent.monkey' in sys.modules: from gevent.queue import Queue, Empty else: from six.moves.queue import Queue, Empty # noqa from cassandra import ConsistencyLevel, AuthenticationFailed, OperationTimedOut from cassandra.marshal import int32_pack from cassandra.protocol import (ReadyMessage, AuthenticateMessage, OptionsMessage, StartupMessage, ErrorMessage, CredentialsMessage, QueryMessage, ResultMessage, ProtocolHandler, InvalidRequestException, SupportedMessage, AuthResponseMessage, AuthChallengeMessage, AuthSuccessMessage, ProtocolException, MAX_SUPPORTED_VERSION, RegisterMessage) from cassandra.util import OrderedDict log = logging.getLogger(__name__) # We use an ordered dictionary and specifically add lz4 before # snappy so that lz4 will be preferred. Changing the order of this # will change the compression preferences for the driver. 
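The comment above relies on insertion order to express a codec preference. A minimal sketch of how such an ordered map can drive the choice, assuming the server advertises a plain list of supported codec names; the `choose_compression` helper and the placeholder codec tuples are hypothetical, and this is not the driver's actual negotiation code:

```python
from collections import OrderedDict

# Insertion order encodes preference: lz4 first, then snappy.
# The tuples are placeholders for the (compress, decompress) function pairs
# the driver stores; they are hypothetical here.
local_codecs = OrderedDict()
local_codecs["lz4"] = ("lz4_compress", "lz4_decompress")
local_codecs["snappy"] = ("snappy_compress", "snappy_decompress")


def choose_compression(local, remote_supported):
    """Return the most preferred local codec name the server also supports.

    Hypothetical helper: iterating ``local`` yields names in insertion
    order, so the first match is the preferred one.
    """
    remote = set(remote_supported)
    for name in local:
        if name in remote:
            return name
    return None  # no overlap: the caller would send uncompressed frames


print(choose_compression(local_codecs, ["snappy", "lz4"]))  # lz4
print(choose_compression(local_codecs, ["snappy"]))  # snappy
```

An ordered mapping matters here because this driver still supports Python 2, where plain `dict` iteration order is arbitrary; on Python 3.7+ a regular `dict` also preserves insertion order.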
locally_supported_compressions = OrderedDict()

try:
    import lz4
except ImportError:
    pass
else:

    # Cassandra writes the uncompressed message length in big endian order,
    # but the lz4 lib requires little endian order, so we wrap these
    # functions to handle that

    def lz4_compress(byts):
        # write length in big-endian instead of little-endian
        return int32_pack(len(byts)) + lz4.compress(byts)[4:]

    def lz4_decompress(byts):
        # flip from big-endian to little-endian
        return lz4.decompress(byts[3::-1] + byts[4:])

    locally_supported_compressions['lz4'] = (lz4_compress, lz4_decompress)

try:
    import snappy
except ImportError:
    pass
else:
    # work around apparently buggy snappy decompress
    def decompress(byts):
        if byts == '\x00':
            return ''
        return snappy.decompress(byts)
    locally_supported_compressions['snappy'] = (snappy.compress, decompress)


PROTOCOL_VERSION_MASK = 0x7f

HEADER_DIRECTION_FROM_CLIENT = 0x00
HEADER_DIRECTION_TO_CLIENT = 0x80
HEADER_DIRECTION_MASK = 0x80

frame_header_v1_v2 = struct.Struct('>BbBi')
frame_header_v3 = struct.Struct('>BhBi')


class _Frame(object):
    def __init__(self, version, flags, stream, opcode, body_offset, end_pos):
        self.version = version
        self.flags = flags
        self.stream = stream
        self.opcode = opcode
        self.body_offset = body_offset
        self.end_pos = end_pos

    def __eq__(self, other):  # facilitates testing
        if isinstance(other, _Frame):
            return (self.version == other.version and
                    self.flags == other.flags and
                    self.stream == other.stream and
                    self.opcode == other.opcode and
                    self.body_offset == other.body_offset and
                    self.end_pos == other.end_pos)
        return NotImplemented

    def __str__(self):
        return "ver({0}); flags({1:04b}); stream({2}); op({3}); offset({4}); len({5})".format(self.version, self.flags, self.stream, self.opcode, self.body_offset, self.end_pos - self.body_offset)


NONBLOCKING = (errno.EAGAIN, errno.EWOULDBLOCK)


class ConnectionException(Exception):
    """
    An unrecoverable error was hit when attempting to use a connection,
    or the connection was already closed or defunct.
    """
    def __init__(self, message, host=None):
        Exception.__init__(self, message)
        self.host = host


class ConnectionShutdown(ConnectionException):
    """
    Raised when a connection has been marked as defunct or has been closed.
    """
    pass


class ProtocolVersionUnsupported(ConnectionException):
    """
    Server rejected startup message due to unsupported protocol version
    """
    def __init__(self, host, startup_version):
        msg = "Unsupported protocol version on %s: %d" % (host, startup_version)
        super(ProtocolVersionUnsupported, self).__init__(msg, host)
        self.startup_version = startup_version


class ConnectionBusy(Exception):
    """
    An attempt was made to send a message through a :class:`.Connection` that
    was already at the max number of in-flight operations.
    """
    pass


class ProtocolError(Exception):
    """
    Communication did not match the protocol that this driver expects.
    """
    pass


def defunct_on_error(f):

    @wraps(f)
    def wrapper(self, *args, **kwargs):
        try:
            return f(self, *args, **kwargs)
        except Exception as exc:
            self.defunct(exc)
    return wrapper


DEFAULT_CQL_VERSION = '3.0.0'

if six.PY3:
    def int_from_buf_item(i):
        return i
else:
    int_from_buf_item = ord


class Connection(object):

    CALLBACK_ERR_THREAD_THRESHOLD = 100

    in_buffer_size = 4096
    out_buffer_size = 4096

    cql_version = None
    protocol_version = MAX_SUPPORTED_VERSION

    keyspace = None
    compression = True
    compressor = None
    decompressor = None

    ssl_options = None
    last_error = None

    # The current number of operations that are in flight. More precisely,
    # the number of request IDs that are currently in use.
    in_flight = 0

    # Max concurrent requests allowed per connection. This is set optimistically high, allowing
    # all request ids to be used in protocol version 3+. Normally concurrency would be controlled
    # at a higher level by the application or concurrent.execute_concurrent. This attribute
    # is for lower-level integrations that want some upper bound without reimplementing.
    max_in_flight = 2 ** 15

    # A set of available request IDs.
    # When using the v3 protocol or higher,
    # this will not initially include all request IDs in order to save memory,
    # but the set will grow if it is exhausted.
    request_ids = None

    # Tracks the highest used request ID in order to help with growing the
    # request_ids set
    highest_request_id = 0

    is_defunct = False
    is_closed = False
    lock = None
    user_type_map = None

    msg_received = False

    is_unsupported_proto_version = False

    is_control_connection = False
    signaled_error = False  # used for flagging at the pool level

    allow_beta_protocol_version = False

    _iobuf = None
    _current_frame = None

    _socket = None

    _socket_impl = socket
    _ssl_impl = ssl

    _check_hostname = False

    def __init__(self, host='127.0.0.1', port=9042, authenticator=None,
                 ssl_options=None, sockopts=None, compression=True,
                 cql_version=None, protocol_version=MAX_SUPPORTED_VERSION, is_control_connection=False,
                 user_type_map=None, connect_timeout=None, allow_beta_protocol_version=False):
        self.host = host
        self.port = port
        self.authenticator = authenticator
        self.ssl_options = ssl_options.copy() if ssl_options else None
        self.sockopts = sockopts
        self.compression = compression
        self.cql_version = cql_version
        self.protocol_version = protocol_version
        self.is_control_connection = is_control_connection
        self.user_type_map = user_type_map
        self.connect_timeout = connect_timeout
        self.allow_beta_protocol_version = allow_beta_protocol_version
        self._push_watchers = defaultdict(set)
        self._requests = {}
        self._iobuf = io.BytesIO()

        if ssl_options:
            self._check_hostname = bool(self.ssl_options.pop('check_hostname', False))
            if self._check_hostname:
                if not getattr(ssl, 'match_hostname', None):
                    raise RuntimeError("ssl_options specify 'check_hostname', but ssl.match_hostname is not provided. "
                                       "Patch or upgrade Python to use this option.")

        if protocol_version >= 3:
            self.max_request_id = min(self.max_in_flight - 1, (2 ** 15) - 1)
            # Don't fill the deque with 2**15 items right away. Start with some and add
            # more if needed.
            initial_size = min(300, self.max_in_flight)
            self.request_ids = deque(range(initial_size))
            self.highest_request_id = initial_size - 1
        else:
            self.max_request_id = min(self.max_in_flight, (2 ** 7) - 1)
            self.request_ids = deque(range(self.max_request_id + 1))
            self.highest_request_id = self.max_request_id

        self.lock = RLock()
        self.connected_event = Event()

    @classmethod
    def initialize_reactor(cls):
        """
        Called once by Cluster.connect(). This should be used by implementations
        to set up any resources that will be shared across connections.
        """
        pass

    @classmethod
    def handle_fork(cls):
        """
        Called after a forking. This should cleanup any remaining reactor state
        from the parent process.
        """
        pass

    @classmethod
    def create_timer(cls, timeout, callback):
        raise NotImplementedError()

    @classmethod
    def factory(cls, host, timeout, *args, **kwargs):
        """
        A factory function which returns connections which have
        succeeded in connecting and are ready for service (or
        raises an exception otherwise).
        """
        start = time.time()
        kwargs['connect_timeout'] = timeout
        conn = cls(host, *args, **kwargs)
        elapsed = time.time() - start
        conn.connected_event.wait(timeout - elapsed)
        if conn.last_error:
            if conn.is_unsupported_proto_version:
                raise ProtocolVersionUnsupported(host, conn.protocol_version)
            raise conn.last_error
        elif not conn.connected_event.is_set():
            conn.close()
            raise OperationTimedOut("Timed out creating connection (%s seconds)" % timeout)
        else:
            return conn

    def _connect_socket(self):
        sockerr = None
        addresses = socket.getaddrinfo(self.host, self.port, socket.AF_UNSPEC, socket.SOCK_STREAM)
        if not addresses:
            raise ConnectionException("getaddrinfo returned empty list for %s" % (self.host,))
        for (af, socktype, proto, canonname, sockaddr) in addresses:
            try:
                self._socket = self._socket_impl.socket(af, socktype, proto)
                if self.ssl_options:
                    if not self._ssl_impl:
                        raise RuntimeError("This version of Python was not compiled with SSL support")
                    self._socket = self._ssl_impl.wrap_socket(self._socket,
                                                              **self.ssl_options)
                self._socket.settimeout(self.connect_timeout)
                self._socket.connect(sockaddr)
                self._socket.settimeout(None)
                if self._check_hostname:
                    ssl.match_hostname(self._socket.getpeercert(), self.host)
                sockerr = None
                break
            except socket.error as err:
                if self._socket:
                    self._socket.close()
                    self._socket = None
                sockerr = err

        if sockerr:
            raise socket.error(sockerr.errno, "Tried connecting to %s. Last error: %s" % ([a[4] for a in addresses], sockerr.strerror or sockerr))

        if self.sockopts:
            for args in self.sockopts:
                self._socket.setsockopt(*args)

    def close(self):
        raise NotImplementedError()

    def defunct(self, exc):
        with self.lock:
            if self.is_defunct or self.is_closed:
                return
            self.is_defunct = True

        exc_info = sys.exc_info()
        # if we are not handling an exception, just use the passed exception, and don't try to format exc_info with the message
        if any(exc_info):
            log.debug("Defuncting connection (%s) to %s:",
                      id(self), self.host, exc_info=exc_info)
        else:
            log.debug("Defuncting connection (%s) to %s: %s",
                      id(self), self.host, exc)

        self.last_error = exc
        self.close()
        self.error_all_requests(exc)
        self.connected_event.set()
        return exc

    def error_all_requests(self, exc):
        with self.lock:
            requests = self._requests
            self._requests = {}

        if not requests:
            return

        new_exc = ConnectionShutdown(str(exc))

        def try_callback(cb):
            try:
                cb(new_exc)
            except Exception:
                log.warning("Ignoring unhandled exception while erroring requests for a "
                            "failed connection (%s) to host %s:",
                            id(self), self.host, exc_info=True)

        # run first callback from this thread to ensure pool state before leaving
        cb, _, _ = requests.popitem()[1]
        try_callback(cb)

        if not requests:
            return

        # additional requests are optionally errored from a separate thread
        # The default callback and retry logic is fairly expensive -- we don't
        # want to tie up the event thread when there are many requests
        def err_all_callbacks():
            for cb, _, _ in requests.values():
                try_callback(cb)
        if len(requests) < Connection.CALLBACK_ERR_THREAD_THRESHOLD:
            err_all_callbacks()
        else:
            # daemon thread here because we want to stay decoupled from the cluster TPE
            # TODO: would it make sense to just have a driver-global TPE?
            t = Thread(target=err_all_callbacks)
            t.daemon = True
            t.start()

    def get_request_id(self):
        """
        This must be called while self.lock is held.
        """
        try:
            return self.request_ids.popleft()
        except IndexError:
            self.highest_request_id += 1
            # in_flight checks should guarantee this
            assert self.highest_request_id <= self.max_request_id
            return self.highest_request_id

    def handle_pushed(self, response):
        log.debug("Message pushed from server: %r", response)
        for cb in self._push_watchers.get(response.event_type, []):
            try:
                cb(response.event_args)
            except Exception:
                log.exception("Pushed event handler errored, ignoring:")

    def send_msg(self, msg, request_id, cb, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=None):
        if self.is_defunct:
            raise ConnectionShutdown("Connection to %s is defunct" % self.host)
        elif self.is_closed:
            raise ConnectionShutdown("Connection to %s is closed" % self.host)

        # queue the decoder function with the request
        # this allows us to inject custom functions per request to encode, decode messages
        self._requests[request_id] = (cb, decoder, result_metadata)
        msg = encoder(msg, request_id, self.protocol_version, compressor=self.compressor, allow_beta_protocol_version=self.allow_beta_protocol_version)
        self.push(msg)
        return len(msg)

    def wait_for_response(self, msg, timeout=None):
        return self.wait_for_responses(msg, timeout=timeout)[0]

    def wait_for_responses(self, *msgs, **kwargs):
        """
        Returns a list of (success, response) tuples. If success
        is False, response will be an Exception. Otherwise, response
        will be the normal query response.

        If fail_on_error was left as True and one of the requests
        failed, the corresponding Exception will be raised.
        """
        if self.is_closed or self.is_defunct:
            raise ConnectionShutdown("Connection %s is already closed" % (self, ))
        timeout = kwargs.get('timeout')
        fail_on_error = kwargs.get('fail_on_error', True)
        waiter = ResponseWaiter(self, len(msgs), fail_on_error)

        # busy wait for sufficient space on the connection
        messages_sent = 0
        while True:
            needed = len(msgs) - messages_sent
            with self.lock:
                available = min(needed, self.max_request_id - self.in_flight + 1)
                request_ids = [self.get_request_id() for _ in range(available)]
                self.in_flight += available

            for i, request_id in enumerate(request_ids):
                self.send_msg(msgs[messages_sent + i],
                              request_id,
                              partial(waiter.got_response, index=messages_sent + i))
            messages_sent += available

            if messages_sent == len(msgs):
                break
            else:
                if timeout is not None:
                    timeout -= 0.01
                    if timeout <= 0.0:
                        raise OperationTimedOut()
                time.sleep(0.01)

        try:
            return waiter.deliver(timeout)
        except OperationTimedOut:
            raise
        except Exception as exc:
            self.defunct(exc)
            raise

    def register_watcher(self, event_type, callback, register_timeout=None):
        """
        Register a callback for a given event type.
        """
        self._push_watchers[event_type].add(callback)
        self.wait_for_response(
            RegisterMessage(event_list=[event_type]),
            timeout=register_timeout)

    def register_watchers(self, type_callback_dict, register_timeout=None):
        """
        Register multiple callback/event type pairs, expressed as a dict.
        """
        for event_type, callback in type_callback_dict.items():
            self._push_watchers[event_type].add(callback)
        self.wait_for_response(
            RegisterMessage(event_list=type_callback_dict.keys()),
            timeout=register_timeout)

    def control_conn_disposed(self):
        self.is_control_connection = False
        self._push_watchers = {}

    @defunct_on_error
    def _read_frame_header(self):
        buf = self._iobuf.getvalue()
        pos = len(buf)
        if pos:
            version = int_from_buf_item(buf[0]) & PROTOCOL_VERSION_MASK
            if version > MAX_SUPPORTED_VERSION:
                raise ProtocolError("This version of the driver does not support protocol version %d" % version)
            frame_header = frame_header_v3 if version >= 3 else frame_header_v1_v2
            # this frame header struct is everything after the version byte
            header_size = frame_header.size + 1
            if pos >= header_size:
                flags, stream, op, body_len = frame_header.unpack_from(buf, 1)
                if body_len < 0:
                    raise ProtocolError("Received negative body length: %r" % body_len)
                self._current_frame = _Frame(version, flags, stream, op, header_size, body_len + header_size)
        return pos

    def _reset_frame(self):
        self._iobuf = io.BytesIO(self._iobuf.read())
        self._iobuf.seek(0, 2)  # io.SEEK_END == 2 (constant not present in 2.6)
        self._current_frame = None

    def process_io_buffer(self):
        while True:
            if not self._current_frame:
                pos = self._read_frame_header()
            else:
                pos = self._iobuf.tell()

            if not self._current_frame or pos < self._current_frame.end_pos:
                # we don't have a complete header yet or we
                # already saw a header, but we don't have a
                # complete message yet
                return
            else:
                frame = self._current_frame
                self._iobuf.seek(frame.body_offset)
                msg = self._iobuf.read(frame.end_pos - frame.body_offset)
                self.process_msg(frame, msg)
                self._reset_frame()

    @defunct_on_error
    def process_msg(self, header, body):
        stream_id = header.stream
        if stream_id < 0:
            callback = None
            decoder = ProtocolHandler.decode_message
            result_metadata = None
        else:
            callback, decoder, result_metadata = self._requests.pop(stream_id)
            with self.lock:
                self.request_ids.append(stream_id)

        self.msg_received = True

        try:
            response = decoder(header.version, self.user_type_map, stream_id,
                               header.flags, header.opcode, body, self.decompressor, result_metadata)
        except Exception as exc:
            log.exception("Error decoding response from Cassandra. "
                          "%s; buffer: %r", header, self._iobuf.getvalue())
            if callback is not None:
                callback(exc)
            self.defunct(exc)
            return

        try:
            if stream_id >= 0:
                if isinstance(response, ProtocolException):
                    if 'unsupported protocol version' in response.message:
                        self.is_unsupported_proto_version = True
                    else:
                        log.error("Closing connection %s due to protocol error: %s", self, response.summary_msg())
                    self.defunct(response)
                if callback is not None:
                    callback(response)
            else:
                self.handle_pushed(response)
        except Exception:
            log.exception("Callback handler errored, ignoring:")

    @defunct_on_error
    def _send_options_message(self):
        if self.cql_version is None and (not self.compression or not locally_supported_compressions):
            log.debug("Not sending options message for new connection(%s) to %s "
                      "because compression is disabled and a cql version was not "
                      "specified", id(self), self.host)
            self._compressor = None
            self.cql_version = DEFAULT_CQL_VERSION
            self._send_startup_message()
        else:
            log.debug("Sending initial options message for new connection (%s) to %s", id(self), self.host)
            self.send_msg(OptionsMessage(), self.get_request_id(), self._handle_options_response)

    @defunct_on_error
    def _handle_options_response(self, options_response):
        if self.is_defunct:
            return

        if not isinstance(options_response, SupportedMessage):
            if isinstance(options_response, ConnectionException):
                raise options_response
            else:
                log.error("Did not get expected SupportedMessage response; "
                          "instead, got: %s", options_response)
                raise ConnectionException("Did not get expected SupportedMessage "
                                          "response; instead, got: %s"
                                          % (options_response,))

        log.debug("Received options response on new connection (%s) from %s",
                  id(self), self.host)
        supported_cql_versions = options_response.cql_versions
        remote_supported_compressions = options_response.options['COMPRESSION']

        if self.cql_version:
            if self.cql_version not in supported_cql_versions:
                raise ProtocolError(
                    "cql_version %r is not supported by remote (w/ native "
                    "protocol). Supported versions: %r"
                    % (self.cql_version, supported_cql_versions))
        else:
            self.cql_version = supported_cql_versions[0]

        self._compressor = None
        compression_type = None
        if self.compression:
            overlap = (set(locally_supported_compressions.keys()) &
                       set(remote_supported_compressions))
            if len(overlap) == 0:
                log.debug("No available compression types supported on both ends."
                          " locally supported: %r. remotely supported: %r",
                          locally_supported_compressions.keys(),
                          remote_supported_compressions)
            else:
                compression_type = None
                if isinstance(self.compression, six.string_types):
                    # the user picked a specific compression type ('snappy' or 'lz4')
                    if self.compression not in remote_supported_compressions:
                        raise ProtocolError(
                            "The requested compression type (%s) is not supported by the Cassandra server at %s"
                            % (self.compression, self.host))
                    compression_type = self.compression
                else:
                    # our locally supported compressions are ordered to prefer
                    # lz4, if available
                    for k in locally_supported_compressions.keys():
                        if k in overlap:
                            compression_type = k
                            break

                # set the decompressor here, but set the compressor only after
                # a successful Ready message
                self._compressor, self.decompressor = \
                    locally_supported_compressions[compression_type]

        self._send_startup_message(compression_type)

    @defunct_on_error
    def _send_startup_message(self, compression=None):
        log.debug("Sending StartupMessage on %s", self)
        opts = {}
        if compression:
            opts['COMPRESSION'] = compression
        sm = StartupMessage(cqlversion=self.cql_version, options=opts)
        self.send_msg(sm, self.get_request_id(), cb=self._handle_startup_response)
        log.debug("Sent StartupMessage on %s", self)

    @defunct_on_error
    def _handle_startup_response(self, startup_response, did_authenticate=False):
        if self.is_defunct:
            return
        if isinstance(startup_response, ReadyMessage):
            log.debug("Got ReadyMessage on new connection (%s) from %s", id(self), self.host)
            if self._compressor:
                self.compressor = self._compressor
            self.connected_event.set()
        elif isinstance(startup_response, AuthenticateMessage):
            log.debug("Got AuthenticateMessage on new connection (%s) from %s: %s",
                      id(self), self.host, startup_response.authenticator)

            if self.authenticator is None:
                raise AuthenticationFailed('Remote end requires authentication.')

            if isinstance(self.authenticator, dict):
                log.debug("Sending credentials-based auth response on %s", self)
                cm = CredentialsMessage(creds=self.authenticator)
                callback = partial(self._handle_startup_response, did_authenticate=True)
                self.send_msg(cm, self.get_request_id(), cb=callback)
            else:
                log.debug("Sending SASL-based auth response on %s", self)
                self.authenticator.server_authenticator_class = startup_response.authenticator
                initial_response = self.authenticator.initial_response()
                initial_response = "" if initial_response is None else initial_response
                self.send_msg(AuthResponseMessage(initial_response), self.get_request_id(), self._handle_auth_response)
        elif isinstance(startup_response, ErrorMessage):
            log.debug("Received ErrorMessage on new connection (%s) from %s: %s",
                      id(self), self.host, startup_response.summary_msg())
            if did_authenticate:
                raise AuthenticationFailed(
                    "Failed to authenticate to %s: %s" %
                    (self.host, startup_response.summary_msg()))
            else:
                raise ConnectionException(
                    "Failed to initialize new connection to %s: %s"
                    % (self.host, startup_response.summary_msg()))
        elif isinstance(startup_response, ConnectionShutdown):
            log.debug("Connection to %s was closed during the startup handshake", (self.host))
            raise startup_response
        else:
            msg = "Unexpected response during Connection setup: %r"
            log.error(msg, startup_response)
            raise ProtocolError(msg % (startup_response,))

    @defunct_on_error
    def _handle_auth_response(self, auth_response):
        if self.is_defunct:
            return
        if isinstance(auth_response, AuthSuccessMessage):
            log.debug("Connection %s successfully authenticated", self)
            self.authenticator.on_authentication_success(auth_response.token)
            if self._compressor:
                self.compressor = self._compressor
            self.connected_event.set()
        elif isinstance(auth_response, AuthChallengeMessage):
            response = self.authenticator.evaluate_challenge(auth_response.challenge)
            msg = AuthResponseMessage("" if response is None else response)
            log.debug("Responding to auth challenge on %s", self)
            self.send_msg(msg, self.get_request_id(), self._handle_auth_response)
        elif isinstance(auth_response, ErrorMessage):
            log.debug("Received ErrorMessage on new connection (%s) from %s: %s",
                      id(self), self.host, auth_response.summary_msg())
            raise AuthenticationFailed(
                "Failed to authenticate to %s: %s" %
                (self.host, auth_response.summary_msg()))
        elif isinstance(auth_response, ConnectionShutdown):
            log.debug("Connection to %s was closed during the authentication process", self.host)
            raise auth_response
        else:
            msg = "Unexpected response during Connection authentication to %s: %r"
            log.error(msg, self.host, auth_response)
            raise ProtocolError(msg % (self.host, auth_response))

    def set_keyspace_blocking(self, keyspace):
        if not keyspace or keyspace == self.keyspace:
            return

        query = QueryMessage(query='USE "%s"' % (keyspace,),
                             consistency_level=ConsistencyLevel.ONE)
        try:
            result = self.wait_for_response(query)
        except InvalidRequestException as ire:
            # the keyspace probably doesn't exist
            raise ire.to_exception()
        except Exception as exc:
            conn_exc = ConnectionException(
                "Problem while setting keyspace: %r" % (exc,), self.host)
            self.defunct(conn_exc)
            raise conn_exc

        if isinstance(result, ResultMessage):
            self.keyspace = keyspace
        else:
            conn_exc = ConnectionException(
                "Problem while setting keyspace: %r" % (result,), self.host)
            self.defunct(conn_exc)
            raise conn_exc

    def set_keyspace_async(self, keyspace, callback):
        """
        Use this in order to avoid deadlocking the event loop thread.
        When the operation completes, `callback` will be called with
        two arguments: this connection and an Exception if an error
        occurred, otherwise :const:`None`.
        """
        if not keyspace or keyspace == self.keyspace:
            callback(self, None)
            return

        query = QueryMessage(query='USE "%s"' % (keyspace,),
                             consistency_level=ConsistencyLevel.ONE)

        def process_result(result):
            if isinstance(result, ResultMessage):
                self.keyspace = keyspace
                callback(self, None)
            elif isinstance(result, InvalidRequestException):
                callback(self, result.to_exception())
            else:
                callback(self, self.defunct(ConnectionException(
                    "Problem while setting keyspace: %r" % (result,), self.host)))

        request_id = None
        # we use a busy wait on the lock here because:
        # - we'll only spin if the connection is at max capacity, which is very
        #   unlikely for a set_keyspace call
        # - it allows us to avoid signaling a condition every time a request completes
        while True:
            with self.lock:
                if self.in_flight < self.max_request_id:
                    request_id = self.get_request_id()
                    self.in_flight += 1
                    break

            time.sleep(0.001)

        self.send_msg(query, request_id, process_result)

    @property
    def is_idle(self):
        return not self.msg_received

    def reset_idle(self):
        self.msg_received = False

    def __str__(self):
        status = ""
        if self.is_defunct:
            status = " (defunct)"
        elif self.is_closed:
            status = " (closed)"

        return "<%s(%r) %s:%d%s>" % (self.__class__.__name__, id(self), self.host, self.port, status)
    __repr__ = __str__


class ResponseWaiter(object):

    def __init__(self, connection, num_responses, fail_on_error):
        self.connection = connection
        self.pending = num_responses
        self.fail_on_error = fail_on_error
        self.error = None
        self.responses = [None] * num_responses
        self.event = Event()

    def got_response(self, response, index):
        with self.connection.lock:
            self.connection.in_flight -= 1
        if isinstance(response, Exception):
            if hasattr(response, 'to_exception'):
                response = response.to_exception()
            if self.fail_on_error:
                self.error = response
                self.event.set()
            else:
                self.responses[index] = (False,
                                         response)
        else:
            if not self.fail_on_error:
                self.responses[index] = (True, response)
            else:
                self.responses[index] = response

        self.pending -= 1
        if not self.pending:
            self.event.set()

    def deliver(self, timeout=None):
        """
        If fail_on_error was set to False, a list of (success, response)
        tuples will be returned. If success is False, response will be
        an Exception. Otherwise, response will be the normal query response.

        If fail_on_error was left as True and one of the requests
        failed, the corresponding Exception will be raised. Otherwise,
        the normal response will be returned.
        """
        self.event.wait(timeout)
        if self.error:
            raise self.error
        elif not self.event.is_set():
            raise OperationTimedOut()
        else:
            return self.responses


class HeartbeatFuture(object):
    def __init__(self, connection, owner):
        self._exception = None
        self._event = Event()
        self.connection = connection
        self.owner = owner
        log.debug("Sending options message heartbeat on idle connection (%s) %s",
                  id(connection), connection.host)
        with connection.lock:
            if connection.in_flight <= connection.max_request_id:
                connection.in_flight += 1
                connection.send_msg(OptionsMessage(), connection.get_request_id(), self._options_callback)
            else:
                self._exception = Exception("Failed to send heartbeat because connection 'in_flight' exceeds threshold")
                self._event.set()

    def wait(self, timeout):
        self._event.wait(timeout)
        if self._event.is_set():
            if self._exception:
                raise self._exception
        else:
            raise OperationTimedOut("Connection heartbeat timeout after %s seconds" % (timeout,), self.connection.host)

    def _options_callback(self, response):
        if isinstance(response, SupportedMessage):
            log.debug("Received options response on connection (%s) from %s",
                      id(self.connection), self.connection.host)
        else:
            if isinstance(response, ConnectionException):
                self._exception = response
            else:
                self._exception = ConnectionException("Received unexpected response to OptionsMessage: %s"
                                                      % (response,))
        self._event.set()


class ConnectionHeartbeat(Thread):

    def __init__(self,
                 interval_sec, get_connection_holders):
        Thread.__init__(self, name="Connection heartbeat")
        self._interval = interval_sec
        self._get_connection_holders = get_connection_holders
        self._shutdown_event = Event()
        self.daemon = True
        self.start()

    class ShutdownException(Exception):
        pass

    def run(self):
        self._shutdown_event.wait(self._interval)
        while not self._shutdown_event.is_set():
            start_time = time.time()

            futures = []
            failed_connections = []
            try:
                for connections, owner in [(o.get_connections(), o) for o in self._get_connection_holders()]:
                    for connection in connections:
                        self._raise_if_stopped()
                        if not (connection.is_defunct or connection.is_closed):
                            if connection.is_idle:
                                try:
                                    futures.append(HeartbeatFuture(connection, owner))
                                except Exception as e:
                                    log.warning("Failed sending heartbeat message on connection (%s) to %s",
                                                id(connection), connection.host)
                                    failed_connections.append((connection, owner, e))
                            else:
                                connection.reset_idle()
                        else:
                            # make sure the owner sees this defunct/closed connection
                            owner.return_connection(connection)
                    self._raise_if_stopped()

                for f in futures:
                    self._raise_if_stopped()
                    connection = f.connection
                    try:
                        f.wait(self._interval)
                        # TODO: move this, along with connection locks in pool, down into Connection
                        with connection.lock:
                            connection.in_flight -= 1
                        connection.reset_idle()
                    except Exception as e:
                        log.warning("Heartbeat failed for connection (%s) to %s",
                                    id(connection), connection.host)
                        failed_connections.append((f.connection, f.owner, e))

                for connection, owner, exc in failed_connections:
                    self._raise_if_stopped()
                    connection.defunct(exc)
                    owner.return_connection(connection)
            except self.ShutdownException:
                pass
            except Exception:
                log.error("Failed connection heartbeat", exc_info=True)

            elapsed = time.time() - start_time
            self._shutdown_event.wait(max(self._interval - elapsed, 0.01))

    def stop(self):
        self._shutdown_event.set()
        self.join()

    def _raise_if_stopped(self):
        if self._shutdown_event.is_set():
            raise self.ShutdownException()


class Timer(object):

    canceled = False

    def __init__(self, timeout, callback):
        self.end = time.time() + timeout
        self.callback = callback
        if timeout < 0:
            self.callback()

    def __lt__(self, other):
        return self.end < other.end

    def cancel(self):
        self.canceled = True

    def finish(self, time_now):
        if self.canceled:
            return True

        if time_now >= self.end:
            self.callback()
            return True

        return False


class TimerManager(object):

    def __init__(self):
        self._queue = []
        self._new_timers = []

    def add_timer(self, timer):
        """
        called from client thread with a Timer object
        """
        self._new_timers.append((timer.end, timer))

    def service_timeouts(self):
        """
        run callbacks on all expired timers
        Called from the event thread
        :return: next end time, or None
        """
        queue = self._queue
        if self._new_timers:
            new_timers = self._new_timers
            while new_timers:
                heappush(queue, new_timers.pop())

        if queue:
            now = time.time()
            while queue:
                try:
                    timer = queue[0][1]
                    if timer.finish(now):
                        heappop(queue)
                    else:
                        return timer.end
                except Exception:
                    log.exception("Exception while servicing timeout callback: ")

    @property
    def next_timeout(self):
        try:
            return self._queue[0][0]
        except IndexError:
            pass
cassandra-driver-3.7.1/cassandra/obj_parser.pyx0000664000175000017500000000477013001676145024456 0ustar aboudreaultaboudreault00000000000000
# Copyright 2013-2016 DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
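
# Illustrative pure-Python sketch of the byte layout that ListParser and
# TupleRowParser below walk with read_int() and get_buf(). The helper is
# hypothetical and is not used by the driver. Assumed layout: a big-endian
# int32 row count, then for every row one [int32 length][bytes] cell per
# column; column_count plays the role of desc.rowsize. Per the CQL native
# protocol specification a negative length marks a null value. This sketch
# returns raw cell bytes instead of calling the per-column Deserializer
# objects; how ioutils.get_buf itself handles negative lengths is not shown.
import struct


def _example_parse_rows(data, column_count):
    offset = 0
    (rowcount,) = struct.unpack_from('>i', data, offset)
    offset += 4
    rows = []
    for _ in range(rowcount):
        cells = []
        for _ in range(column_count):
            (length,) = struct.unpack_from('>i', data, offset)
            offset += 4
            if length < 0:
                cells.append(None)  # null cell: no value bytes follow
            else:
                cells.append(data[offset:offset + length])
                offset += length
        rows.append(tuple(cells))
    return rows

# e.g. two rows of two columns, with the second cell of the first row null:
# _example_parse_rows(b'\x00\x00\x00\x02' b'\x00\x00\x00\x01a' b'\xff\xff\xff\xff'
#                     b'\x00\x00\x00\x02bc' b'\x00\x00\x00\x00', 2)
# returns [(b'a', None), (b'bc', b'')]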
include "ioutils.pyx"

from cassandra.bytesio cimport BytesIOReader
from cassandra.deserializers cimport Deserializer, from_binary
from cassandra.parsing cimport ParseDesc, ColumnParser, RowParser
from cassandra.tuple cimport tuple_new, tuple_set


cdef class ListParser(ColumnParser):
    """Decode a ResultMessage into a list of tuples (or other objects)"""

    cpdef parse_rows(self, BytesIOReader reader, ParseDesc desc):
        cdef Py_ssize_t i, rowcount
        rowcount = read_int(reader)
        cdef RowParser rowparser = TupleRowParser()
        return [rowparser.unpack_row(reader, desc) for i in range(rowcount)]


cdef class LazyParser(ColumnParser):
    """Decode a ResultMessage lazily using a generator"""

    cpdef parse_rows(self, BytesIOReader reader, ParseDesc desc):
        # Use a little helper function as closures (generators) are not
        # supported in cpdef methods
        return parse_rows_lazy(reader, desc)


def parse_rows_lazy(BytesIOReader reader, ParseDesc desc):
    cdef Py_ssize_t i, rowcount
    rowcount = read_int(reader)
    cdef RowParser rowparser = TupleRowParser()
    return (rowparser.unpack_row(reader, desc) for i in range(rowcount))


cdef class TupleRowParser(RowParser):
    """
    Parse a single returned row into a tuple of objects:

        (obj1, ..., objN)
    """

    cpdef unpack_row(self, BytesIOReader reader, ParseDesc desc):
        assert desc.rowsize >= 0

        cdef Buffer buf
        cdef Py_ssize_t i, rowsize = desc.rowsize
        cdef Deserializer deserializer
        cdef tuple res = tuple_new(desc.rowsize)

        for i in range(rowsize):
            # Read the next few bytes
            get_buf(reader, &buf)

            # Deserialize bytes to python object
            deserializer = desc.deserializers[i]
            val = from_binary(deserializer, &buf, desc.protocol_version)

            # Insert new object into tuple
            tuple_set(res, i, val)

        return res
cassandra-driver-3.7.1/cassandra/marshal.py0000664000175000017500000000524012743410406023556 0ustar aboudreaultaboudreault00000000000000
# Copyright 2013-2016 DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import six
import struct


def _make_packer(format_string):
    packer = struct.Struct(format_string)
    pack = packer.pack
    unpack = lambda s: packer.unpack(s)[0]
    return pack, unpack

int64_pack, int64_unpack = _make_packer('>q')
int32_pack, int32_unpack = _make_packer('>i')
int16_pack, int16_unpack = _make_packer('>h')
int8_pack, int8_unpack = _make_packer('>b')
uint64_pack, uint64_unpack = _make_packer('>Q')
uint32_pack, uint32_unpack = _make_packer('>I')
uint16_pack, uint16_unpack = _make_packer('>H')
uint8_pack, uint8_unpack = _make_packer('>B')
float_pack, float_unpack = _make_packer('>f')
double_pack, double_unpack = _make_packer('>d')

# Special case for cassandra header
header_struct = struct.Struct('>BBbB')
header_pack = header_struct.pack
header_unpack = header_struct.unpack

# in protocol version 3 and higher, the stream ID is two bytes
v3_header_struct = struct.Struct('>BBhB')
v3_header_pack = v3_header_struct.pack
v3_header_unpack = v3_header_struct.unpack


if six.PY3:
    def varint_unpack(term):
        val = int(''.join("%02x" % i for i in term), 16)
        if (term[0] & 128) != 0:
            len_term = len(term)  # pulling this out of the expression to avoid overflow in cython optimized code
            val -= 1 << (len_term * 8)
        return val
else:
    def varint_unpack(term):  # noqa
        val = int(term.encode('hex'), 16)
        if (ord(term[0]) & 128) != 0:
            len_term = len(term)  # pulling this out of the expression to avoid overflow in cython optimized code
            val = val - (1 << (len_term *
8)) return val def bitlength(n): bitlen = 0 while n > 0: n >>= 1 bitlen += 1 return bitlen def varint_pack(big): pos = True if big == 0: return b'\x00' if big < 0: bytelength = bitlength(abs(big) - 1) // 8 + 1 big = (1 << bytelength * 8) + big pos = False revbytes = bytearray() while big > 0: revbytes.append(big & 0xff) big >>= 8 if pos and revbytes[-1] & 0x80: revbytes.append(0) revbytes.reverse() return six.binary_type(revbytes) cassandra-driver-3.7.1/cassandra/encoder.py0000664000175000017500000001732112766043657023570 0ustar aboudreaultaboudreault00000000000000# Copyright 2013-2016 DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ These functions are used to convert Python objects into CQL strings. When non-prepared statements are executed, these encoder functions are called on each query parameter. """ import logging log = logging.getLogger(__name__) from binascii import hexlify import calendar import datetime import math import sys import types from uuid import UUID import six from cassandra.util import (OrderedDict, OrderedMap, OrderedMapSerializedKey, sortedset, Time, Date) if six.PY3: long = int def cql_quote(term): # The ordering of this method is important for the result of this method to # be a native str type (for both Python 2 and 3) if isinstance(term, str): return "'%s'" % str(term).replace("'", "''") # This branch of the if statement will only be used by Python 2 to catch # unicode strings, text_type is used to prevent type errors with Python 3. 
    elif isinstance(term, six.text_type):
        return "'%s'" % term.encode('utf8').replace("'", "''")
    else:
        return str(term)


class ValueSequence(list):
    pass


class Encoder(object):
    """
    A container for mapping python types to CQL string literals when working
    with non-prepared statements. The type :attr:`~.Encoder.mapping` can be
    directly customized by users.
    """

    mapping = None
    """
    A map of python types to encoder functions.
    """

    def __init__(self):
        self.mapping = {
            float: self.cql_encode_float,
            bytearray: self.cql_encode_bytes,
            str: self.cql_encode_str,
            int: self.cql_encode_object,
            UUID: self.cql_encode_object,
            datetime.datetime: self.cql_encode_datetime,
            datetime.date: self.cql_encode_date,
            datetime.time: self.cql_encode_time,
            Date: self.cql_encode_date_ext,
            Time: self.cql_encode_time,
            dict: self.cql_encode_map_collection,
            OrderedDict: self.cql_encode_map_collection,
            OrderedMap: self.cql_encode_map_collection,
            OrderedMapSerializedKey: self.cql_encode_map_collection,
            list: self.cql_encode_list_collection,
            tuple: self.cql_encode_list_collection,  # TODO: change to tuple in next major
            set: self.cql_encode_set_collection,
            sortedset: self.cql_encode_set_collection,
            frozenset: self.cql_encode_set_collection,
            types.GeneratorType: self.cql_encode_list_collection,
            ValueSequence: self.cql_encode_sequence
        }

        if six.PY2:
            self.mapping.update({
                unicode: self.cql_encode_unicode,
                buffer: self.cql_encode_bytes,
                long: self.cql_encode_object,
                types.NoneType: self.cql_encode_none,
            })
        else:
            self.mapping.update({
                memoryview: self.cql_encode_bytes,
                bytes: self.cql_encode_bytes,
                type(None): self.cql_encode_none,
            })

    def cql_encode_none(self, val):
        """
        Converts :const:`None` to the string 'NULL'.
        """
        return 'NULL'

    def cql_encode_unicode(self, val):
        """
        Converts :class:`unicode` objects to UTF-8 encoded strings with quote escaping.
        """
        return cql_quote(val.encode('utf-8'))

    def cql_encode_str(self, val):
        """
        Escapes quotes in :class:`str` objects.
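
        For example (an illustrative doctest, assuming Python 3 string reprs)::

            >>> Encoder().cql_encode_str("it's")
            "'it''s'"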
""" return cql_quote(val) if six.PY3: def cql_encode_bytes(self, val): return (b'0x' + hexlify(val)).decode('utf-8') elif sys.version_info >= (2, 7): def cql_encode_bytes(self, val): # noqa return b'0x' + hexlify(val) else: # python 2.6 requires string or read-only buffer for hexlify def cql_encode_bytes(self, val): # noqa return b'0x' + hexlify(buffer(val)) def cql_encode_object(self, val): """ Default encoder for all objects that do not have a specific encoder function registered. This function simply calls :meth:`str()` on the object. """ return str(val) def cql_encode_float(self, val): """ Encode floats using repr to preserve precision """ if math.isinf(val): return 'Infinity' if val > 0 else '-Infinity' elif math.isnan(val): return 'NaN' else: return repr(val) def cql_encode_datetime(self, val): """ Converts a :class:`datetime.datetime` object to a (string) integer timestamp with millisecond precision. """ timestamp = calendar.timegm(val.utctimetuple()) return str(long(timestamp * 1e3 + getattr(val, 'microsecond', 0) / 1e3)) def cql_encode_date(self, val): """ Converts a :class:`datetime.date` object to a string with format ``YYYY-MM-DD``. """ return "'%s'" % val.strftime('%Y-%m-%d') def cql_encode_time(self, val): """ Converts a :class:`cassandra.util.Time` object to a string with format ``HH:MM:SS.mmmuuunnn``. """ return "'%s'" % val def cql_encode_date_ext(self, val): """ Encodes a :class:`cassandra.util.Date` object as an integer """ # using the int form in case the Date exceeds datetime.[MIN|MAX]YEAR return str(val.days_from_epoch + 2 ** 31) def cql_encode_sequence(self, val): """ Converts a sequence to a string of the form ``(item1, item2, ...)``. This is suitable for ``IN`` value lists. """ return '(%s)' % ', '.join(self.mapping.get(type(v), self.cql_encode_object)(v) for v in val) cql_encode_tuple = cql_encode_sequence """ Converts a sequence to a string of the form ``(item1, item2, ...)``. This is suitable for ``tuple`` type columns. 
""" def cql_encode_map_collection(self, val): """ Converts a dict into a string of the form ``{key1: val1, key2: val2, ...}``. This is suitable for ``map`` type columns. """ return '{%s}' % ', '.join('%s: %s' % ( self.mapping.get(type(k), self.cql_encode_object)(k), self.mapping.get(type(v), self.cql_encode_object)(v) ) for k, v in six.iteritems(val)) def cql_encode_list_collection(self, val): """ Converts a sequence to a string of the form ``[item1, item2, ...]``. This is suitable for ``list`` type columns. """ return '[%s]' % ', '.join(self.mapping.get(type(v), self.cql_encode_object)(v) for v in val) def cql_encode_set_collection(self, val): """ Converts a sequence to a string of the form ``{item1, item2, ...}``. This is suitable for ``set`` type columns. """ return '{%s}' % ', '.join(self.mapping.get(type(v), self.cql_encode_object)(v) for v in val) def cql_encode_all_types(self, val): """ Converts any type into a CQL string, defaulting to ``cql_encode_object`` if :attr:`~Encoder.mapping` does not contain an entry for the type. """ return self.mapping.get(type(val), self.cql_encode_object)(val) cassandra-driver-3.7.1/cassandra/util.py0000664000175000017500000011064212766043657023126 0ustar aboudreaultaboudreault00000000000000# Copyright 2013-2016 DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from __future__ import with_statement
import calendar
import datetime
import random
import six
import uuid
import sys

DATETIME_EPOC = datetime.datetime(1970, 1, 1)

assert sys.byteorder in ('little', 'big')
is_little_endian = sys.byteorder == 'little'


def datetime_from_timestamp(timestamp):
    """
    Creates a timezone-agnostic datetime from timestamp (in seconds) in a consistent manner.
    Works around a Windows issue with large negative timestamps (PYTHON-119),
    and rounding differences in Python 3.4 (PYTHON-340).

    :param timestamp: a unix timestamp, in seconds
    """
    dt = DATETIME_EPOC + datetime.timedelta(seconds=timestamp)
    return dt


def unix_time_from_uuid1(uuid_arg):
    """
    Converts a version 1 :class:`uuid.UUID` to a timestamp with the same precision
    as :meth:`time.time()` returns. This is useful for examining the
    results of queries returning a v1 :class:`~uuid.UUID`.

    :param uuid_arg: a version 1 :class:`~uuid.UUID`
    """
    return (uuid_arg.time - 0x01B21DD213814000) / 1e7


def datetime_from_uuid1(uuid_arg):
    """
    Creates a timezone-agnostic datetime from the timestamp in the
    specified type-1 UUID.

    :param uuid_arg: a version 1 :class:`~uuid.UUID`
    """
    return datetime_from_timestamp(unix_time_from_uuid1(uuid_arg))


def min_uuid_from_time(timestamp):
    """
    Generates the minimum TimeUUID (type 1) for a given timestamp, as compared by Cassandra.

    See :func:`uuid_from_time` for argument and return types.
    """
    return uuid_from_time(timestamp, 0x808080808080, 0x80)  # Cassandra does byte-wise comparison; fill with min signed bytes (0x80 = -128)


def max_uuid_from_time(timestamp):
    """
    Generates the maximum TimeUUID (type 1) for a given timestamp, as compared by Cassandra.

    See :func:`uuid_from_time` for argument and return types.
    """
    return uuid_from_time(timestamp, 0x7f7f7f7f7f7f, 0x3f7f)  # Max signed bytes (0x7f = 127)


def uuid_from_time(time_arg, node=None, clock_seq=None):
    """
    Converts a datetime or timestamp to a type 1 :class:`uuid.UUID`.
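
    For example, here is an illustrative round trip through :func:`unix_time_from_uuid1`.
    The node and clock sequence are randomized, so only the version and the
    timestamp are checked::

        >>> u = uuid_from_time(1.5)
        >>> u.version
        1
        >>> unix_time_from_uuid1(u)
        1.5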
    :param time_arg:
      The time to use for the timestamp portion of the UUID.
      This can either be a :class:`datetime` object or a timestamp
      in seconds (as returned from :meth:`time.time()`).
    :type time_arg: :class:`datetime` or timestamp

    :param node:
      Node integer for the UUID (up to 48 bits). If not specified, this
      field is randomized.
    :type node: long

    :param clock_seq:
      Clock sequence field for the UUID (up to 14 bits). If not specified,
      a random sequence is generated.
    :type clock_seq: int

    :rtype: :class:`uuid.UUID`

    """
    if hasattr(time_arg, 'utctimetuple'):
        seconds = int(calendar.timegm(time_arg.utctimetuple()))
        microseconds = (seconds * 1e6) + time_arg.time().microsecond
    else:
        microseconds = int(time_arg * 1e6)

    # 0x01b21dd213814000 is the number of 100-ns intervals between the
    # UUID epoch 1582-10-15 00:00:00 and the Unix epoch 1970-01-01 00:00:00.
    intervals = int(microseconds * 10) + 0x01b21dd213814000

    time_low = intervals & 0xffffffff
    time_mid = (intervals >> 32) & 0xffff
    time_hi_version = (intervals >> 48) & 0x0fff

    if clock_seq is None:
        clock_seq = random.getrandbits(14)
    else:
        if clock_seq > 0x3fff:
            raise ValueError('clock_seq is out of range (need a 14-bit value)')

    clock_seq_low = clock_seq & 0xff
    clock_seq_hi_variant = 0x80 | ((clock_seq >> 8) & 0x3f)

    if node is None:
        node = random.getrandbits(48)

    return uuid.UUID(fields=(time_low, time_mid, time_hi_version,
                             clock_seq_hi_variant, clock_seq_low, node), version=1)

LOWEST_TIME_UUID = uuid.UUID('00000000-0000-1000-8080-808080808080')
""" The lowest possible TimeUUID, as sorted by Cassandra. """

HIGHEST_TIME_UUID = uuid.UUID('ffffffff-ffff-1fff-bf7f-7f7f7f7f7f7f')
""" The highest possible TimeUUID, as sorted by Cassandra. """

try:
    from collections import OrderedDict
except ImportError:
    # OrderedDict from Python 2.7+

    # Copyright (c) 2009 Raymond Hettinger
    #
    # Permission is hereby granted, free of charge, to any person
    # obtaining a copy of this software and associated documentation files
    # (the "Software"), to deal in the Software without restriction,
    # including without limitation the rights to use, copy, modify, merge,
    # publish, distribute, sublicense, and/or sell copies of the Software,
    # and to permit persons to whom the Software is furnished to do so,
    # subject to the following conditions:
    #
    # The above copyright notice and this permission notice shall be
    # included in all copies or substantial portions of the Software.
    #
    # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
    # EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
    # OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
    # NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
    # HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
    # WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
    # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
    # OTHER DEALINGS IN THE SOFTWARE.
    from UserDict import DictMixin

    class OrderedDict(dict, DictMixin):  # noqa
        """ A dictionary which maintains the insertion order of keys. """

        def __init__(self, *args, **kwds):
            """ A dictionary which maintains the insertion order of keys. """
            if len(args) > 1:
                raise TypeError('expected at most 1 arguments, got %d' % len(args))
            try:
                self.__end
            except AttributeError:
                self.clear()
            self.update(*args, **kwds)

        def clear(self):
            self.__end = end = []
            end += [None, end, end]  # sentinel node for doubly linked list
            self.__map = {}  # key --> [key, prev, next]
            dict.clear(self)

        def __setitem__(self, key, value):
            if key not in self:
                end = self.__end
                curr = end[1]
                curr[2] = end[1] = self.__map[key] = [key, curr, end]
            dict.__setitem__(self, key, value)

        def __delitem__(self, key):
            dict.__delitem__(self, key)
            key, prev, next = self.__map.pop(key)
            prev[2] = next
            next[1] = prev

        def __iter__(self):
            end = self.__end
            curr = end[2]
            while curr is not end:
                yield curr[0]
                curr = curr[2]

        def __reversed__(self):
            end = self.__end
            curr = end[1]
            while curr is not end:
                yield curr[0]
                curr = curr[1]

        def popitem(self, last=True):
            if not self:
                raise KeyError('dictionary is empty')
            if last:
                key = next(reversed(self))
            else:
                key = next(iter(self))
            value = self.pop(key)
            return key, value

        def __reduce__(self):
            items = [[k, self[k]] for k in self]
            tmp = self.__map, self.__end
            del self.__map, self.__end
            inst_dict = vars(self).copy()
            self.__map, self.__end = tmp
            if inst_dict:
                return (self.__class__, (items,), inst_dict)
            return self.__class__, (items,)

        def keys(self):
            return list(self)

        setdefault = DictMixin.setdefault
        update = DictMixin.update
        pop = DictMixin.pop
        values = DictMixin.values
        items = DictMixin.items
        iterkeys = DictMixin.iterkeys
        itervalues = DictMixin.itervalues
        iteritems = DictMixin.iteritems

        def __repr__(self):
            if not self:
                return '%s()' % (self.__class__.__name__,)
            return '%s(%r)' % (self.__class__.__name__, self.items())

        def copy(self):
            return self.__class__(self)

        @classmethod
        def fromkeys(cls, iterable, value=None):
            d = cls()
            for key in iterable:
                d[key] = value
            return d

        def __eq__(self, other):
            if isinstance(other, OrderedDict):
                if len(self) != len(other):
                    return False
                for p, q in zip(self.items(), other.items()):
                    if p != q:
                        return False
                return True
            return dict.__eq__(self, other)

        def __ne__(self, other):
            return not self == other


# WeakSet from Python 2.7+ (https://code.google.com/p/weakrefset)

from _weakref import ref


class _IterationGuard(object):
    # This context manager registers itself in the current iterators of the
    # weak container, such as to delay all removals until the context manager
    # exits.
    # This technique should be relatively thread-safe (since sets are).

    def __init__(self, weakcontainer):
        # Don't create cycles
        self.weakcontainer = ref(weakcontainer)

    def __enter__(self):
        w = self.weakcontainer()
        if w is not None:
            w._iterating.add(self)
        return self

    def __exit__(self, e, t, b):
        w = self.weakcontainer()
        if w is not None:
            s = w._iterating
            s.remove(self)
            if not s:
                w._commit_removals()


class WeakSet(object):
    def __init__(self, data=None):
        self.data = set()

        def _remove(item, selfref=ref(self)):
            self = selfref()
            if self is not None:
                if self._iterating:
                    self._pending_removals.append(item)
                else:
                    self.data.discard(item)

        self._remove = _remove
        # A list of keys to be removed
        self._pending_removals = []
        self._iterating = set()
        if data is not None:
            self.update(data)

    def _commit_removals(self):
        l = self._pending_removals
        discard = self.data.discard
        while l:
            discard(l.pop())

    def __iter__(self):
        with _IterationGuard(self):
            for itemref in self.data:
                item = itemref()
                if item is not None:
                    yield item

    def __len__(self):
        return sum(x() is not None for x in self.data)

    def __contains__(self, item):
        return ref(item) in self.data

    def __reduce__(self):
        return (self.__class__, (list(self),),
                getattr(self, '__dict__', None))

    __hash__ = None

    def add(self, item):
        if self._pending_removals:
            self._commit_removals()
        self.data.add(ref(item, self._remove))

    def clear(self):
        if self._pending_removals:
            self._commit_removals()
        self.data.clear()

    def copy(self):
        return self.__class__(self)

    def pop(self):
        if self._pending_removals:
            self._commit_removals()
        while True:
            try:
                itemref = self.data.pop()
            except KeyError:
                raise KeyError('pop from empty WeakSet')
            item = itemref()
            if item is not None:
                return item

    def remove(self, item):
        if self._pending_removals:
            self._commit_removals()
        self.data.remove(ref(item))

    def discard(self, item):
        if self._pending_removals:
            self._commit_removals()
        self.data.discard(ref(item))

    def update(self, other):
        if self._pending_removals:
            self._commit_removals()
        if isinstance(other, self.__class__):
            self.data.update(other.data)
        else:
            for element in other:
                self.add(element)

    def __ior__(self, other):
        self.update(other)
        return self

    # Helper functions for simple delegating methods.
    def _apply(self, other, method):
        if not isinstance(other, self.__class__):
            other = self.__class__(other)
        newdata = method(other.data)
        newset = self.__class__()
        newset.data = newdata
        return newset

    def difference(self, other):
        return self._apply(other, self.data.difference)
    __sub__ = difference

    def difference_update(self, other):
        if self._pending_removals:
            self._commit_removals()
        if self is other:
            self.data.clear()
        else:
            self.data.difference_update(ref(item) for item in other)

    def __isub__(self, other):
        if self._pending_removals:
            self._commit_removals()
        if self is other:
            self.data.clear()
        else:
            self.data.difference_update(ref(item) for item in other)
        return self

    def intersection(self, other):
        return self._apply(other, self.data.intersection)
    __and__ = intersection

    def intersection_update(self, other):
        if self._pending_removals:
            self._commit_removals()
        self.data.intersection_update(ref(item) for item in other)

    def __iand__(self, other):
        if self._pending_removals:
            self._commit_removals()
        self.data.intersection_update(ref(item) for item in other)
        return self

    def issubset(self, other):
        return self.data.issubset(ref(item) for item in other)
    __lt__ = issubset

    def __le__(self, other):
        return self.data <= set(ref(item) for item in other)

    def issuperset(self, other):
        return self.data.issuperset(ref(item) for item in other)
    __gt__ = issuperset

    def __ge__(self, other):
        return self.data >= set(ref(item) for item in other)

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return NotImplemented
        return self.data == set(ref(item) for item in other)

    def symmetric_difference(self, other):
        return self._apply(other, self.data.symmetric_difference)
    __xor__ = symmetric_difference

    def symmetric_difference_update(self, other):
        if self._pending_removals:
            self._commit_removals()
        if self is other:
            self.data.clear()
        else:
            self.data.symmetric_difference_update(ref(item) for item in other)

    def __ixor__(self, other):
        if self._pending_removals:
            self._commit_removals()
        if self is other:
            self.data.clear()
        else:
            self.data.symmetric_difference_update(ref(item) for item in other)
        return self

    def union(self, other):
        return self._apply(other, self.data.union)
    __or__ = union

    def isdisjoint(self, other):
        return len(self.intersection(other)) == 0

from bisect import bisect_left


class SortedSet(object):
    '''
    A sorted set based on sorted list

    A sorted set implementation is used in this case because it does
    not require its elements to be immutable/hashable.
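
    For example (an illustrative doctest)::

        >>> SortedSet([3, 1, 2, 1])
        SortedSet([1, 2, 3])
        >>> SortedSet([[2], [1]])  # unhashable elements are accepted
        SortedSet([[1], [2]])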

    #Not implemented: update functions, inplace operators
    '''

    def __init__(self, iterable=()):
        self._items = []
        self.update(iterable)

    def __len__(self):
        return len(self._items)

    def __getitem__(self, i):
        return self._items[i]

    def __iter__(self):
        return iter(self._items)

    def __reversed__(self):
        return reversed(self._items)

    def __repr__(self):
        return '%s(%r)' % (
            self.__class__.__name__,
            self._items)

    def __reduce__(self):
        return self.__class__, (self._items,)

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return self._items == other._items
        else:
            try:
                return len(other) == len(self._items) and all(item in self for item in other)
            except TypeError:
                return NotImplemented

    def __ne__(self, other):
        if isinstance(other, self.__class__):
            return self._items != other._items
        else:
            try:
                return len(other) != len(self._items) or any(item not in self for item in other)
            except TypeError:
                return NotImplemented

    def __le__(self, other):
        return self.issubset(other)

    def __lt__(self, other):
        return len(other) > len(self._items) and self.issubset(other)

    def __ge__(self, other):
        return self.issuperset(other)

    def __gt__(self, other):
        return len(self._items) > len(other) and self.issuperset(other)

    def __and__(self, other):
        return self._intersect(other)
    __rand__ = __and__

    def __iand__(self, other):
        isect = self._intersect(other)
        self._items = isect._items
        return self

    def __or__(self, other):
        return self.union(other)
    __ror__ = __or__

    def __ior__(self, other):
        union = self.union(other)
        self._items = union._items
        return self

    def __sub__(self, other):
        return self._diff(other)

    def __rsub__(self, other):
        return sortedset(other) - self

    def __isub__(self, other):
        diff = self._diff(other)
        self._items = diff._items
        return self

    def __xor__(self, other):
        return self.symmetric_difference(other)
    __rxor__ = __xor__

    def __ixor__(self, other):
        sym_diff = self.symmetric_difference(other)
        self._items = sym_diff._items
        return self

    def __contains__(self, item):
        i = bisect_left(self._items, item)
        return i < len(self._items) and self._items[i] == item

    def __delitem__(self, i):
        del self._items[i]

    def __delslice__(self, i, j):
        del self._items[i:j]

    def add(self, item):
        i = bisect_left(self._items, item)
        if i < len(self._items):
            if self._items[i] != item:
                self._items.insert(i, item)
        else:
            self._items.append(item)

    def update(self, iterable):
        for i in iterable:
            self.add(i)

    def clear(self):
        del self._items[:]

    def copy(self):
        new = sortedset()
        new._items = list(self._items)
        return new

    def isdisjoint(self, other):
        return len(self._intersect(other)) == 0

    def issubset(self, other):
        return len(self._intersect(other)) == len(self._items)

    def issuperset(self, other):
        return len(self._intersect(other)) == len(other)

    def pop(self):
        if not self._items:
            raise KeyError("pop from empty set")
        return self._items.pop()

    def remove(self, item):
        i = bisect_left(self._items, item)
        if i < len(self._items):
            if self._items[i] == item:
                self._items.pop(i)
                return
        raise KeyError('%r' % item)

    def union(self, *others):
        union = sortedset()
        union._items = list(self._items)
        for other in others:
            if isinstance(other, self.__class__):
                i = 0
                for item in other._items:
                    i = bisect_left(union._items, item, i)
                    if i < len(union._items):
                        if item != union._items[i]:
                            union._items.insert(i, item)
                    else:
                        union._items.append(item)
            else:
                for item in other:
                    union.add(item)
        return union

    def intersection(self, *others):
        isect = self.copy()
        for other in others:
            isect = isect._intersect(other)
            if not isect:
                break
        return isect

    def difference(self, *others):
        diff = self.copy()
        for other in others:
            diff = diff._diff(other)
            if not diff:
                break
        return diff

    def symmetric_difference(self, other):
        diff_self_other = self._diff(other)
        diff_other_self = other.difference(self)
        return diff_self_other.union(diff_other_self)

    def _diff(self, other):
        diff = sortedset()
        if isinstance(other, self.__class__):
            i = 0
            for item in self._items:
                i = bisect_left(other._items, item, i)
                if i < len(other._items):
                    if item != other._items[i]:
                        diff._items.append(item)
                else:
                    diff._items.append(item)
        else:
            for item in self._items:
                if item not in other:
                    diff.add(item)
        return diff

    def _intersect(self, other):
        isect = sortedset()
        if isinstance(other, self.__class__):
            i = 0
            for item in self._items:
                i = bisect_left(other._items, item, i)
                if i < len(other._items):
                    if item == other._items[i]:
                        isect._items.append(item)
                else:
                    break
        else:
            for item in self._items:
                if item in other:
                    isect.add(item)
        return isect

sortedset = SortedSet  # backwards-compatibility


from collections import Mapping
from six.moves import cPickle


class OrderedMap(Mapping):
    '''
    An ordered map that accepts non-hashable types for keys. It also maintains the
    insertion order of items, behaving as OrderedDict in that regard. These maps
    are constructed and read just as normal mapping types, except that they may
    contain arbitrary collections and other non-hashable items as keys::

        >>> od = OrderedMap([({'one': 1, 'two': 2}, 'value'),
        ...                  ({'three': 3, 'four': 4}, 'value2')])
        >>> list(od.keys())
        [{'two': 2, 'one': 1}, {'three': 3, 'four': 4}]
        >>> list(od.values())
        ['value', 'value2']

    These constructs are needed to support nested collections in Cassandra 2.1.3+,
    where frozen collections can be specified as parameters to others\*::

        CREATE TABLE example (
            ...
            value map<frozen<map<int, int>>, double>
            ...
        )

    This class derives from the (immutable) Mapping API. Objects in these maps
    are not intended to be modified.

    \* Note: Because of the way Cassandra encodes nested types, when using the
    driver with nested collections, :attr:`~.Cluster.protocol_version` must be 3
    or higher.
    '''

    def __init__(self, *args, **kwargs):
        if len(args) > 1:
            raise TypeError('expected at most 1 arguments, got %d' % len(args))

        self._items = []
        self._index = {}
        if args:
            e = args[0]
            if callable(getattr(e, 'keys', None)):
                for k in e.keys():
                    self._insert(k, e[k])
            else:
                for k, v in e:
                    self._insert(k, v)

        for k, v in six.iteritems(kwargs):
            self._insert(k, v)

    def _insert(self, key, value):
        flat_key = self._serialize_key(key)
        i = self._index.get(flat_key, -1)
        if i >= 0:
            self._items[i] = (key, value)
        else:
            self._items.append((key, value))
            self._index[flat_key] = len(self._items) - 1

    __setitem__ = _insert

    def __getitem__(self, key):
        try:
            index = self._index[self._serialize_key(key)]
            return self._items[index][1]
        except KeyError:
            raise KeyError(str(key))

    def __delitem__(self, key):
        # not efficient -- for convenience only
        try:
            index = self._index.pop(self._serialize_key(key))
            self._index = dict((k, i if i < index else i - 1) for k, i in self._index.items())
            self._items.pop(index)
        except KeyError:
            raise KeyError(str(key))

    def __iter__(self):
        for i in self._items:
            yield i[0]

    def __len__(self):
        return len(self._items)

    def __eq__(self, other):
        if isinstance(other, OrderedMap):
            return self._items == other._items
        try:
            d = dict(other)
            return len(d) == len(self._items) and all(i[1] == d[i[0]] for i in self._items)
        except KeyError:
            return False
        except TypeError:
            pass
        return NotImplemented

    def __repr__(self):
        return '%s([%s])' % (
            self.__class__.__name__,
            ', '.join("(%r, %r)" % (k, v) for k, v in self._items))

    def __str__(self):
        return '{%s}' % ', '.join("%r: %r" % (k, v) for k, v in self._items)

    def popitem(self):
        try:
            kv = self._items.pop()
            del self._index[self._serialize_key(kv[0])]
            return kv
        except IndexError:
            raise KeyError()

    def _serialize_key(self, key):
        return cPickle.dumps(key)


class OrderedMapSerializedKey(OrderedMap):

    def __init__(self, cass_type, protocol_version):
        super(OrderedMapSerializedKey, self).__init__()
        self.cass_key_type = cass_type
        self.protocol_version = protocol_version

    def _insert_unchecked(self, key, flat_key, value):
        self._items.append((key, value))
        self._index[flat_key] = len(self._items) - 1

    def _serialize_key(self, key):
        return self.cass_key_type.serialize(key, self.protocol_version)


import datetime
import time

if six.PY3:
    long = int


class Time(object):
    '''
    Idealized time, independent of day.

    Up to nanosecond resolution
    '''

    MICRO = 1000
    MILLI = 1000 * MICRO
    SECOND = 1000 * MILLI
    MINUTE = 60 * SECOND
    HOUR = 60 * MINUTE
    DAY = 24 * HOUR

    nanosecond_time = 0

    def __init__(self, value):
        """
        Initializer value can be:

            - integer_type: absolute nanoseconds in the day
            - datetime.time: built-in time
            - string_type: a string time of the form "HH:MM:SS[.mmmuuunnn]"
        """
        if isinstance(value, six.integer_types):
            self._from_timestamp(value)
        elif isinstance(value, datetime.time):
            self._from_time(value)
        elif isinstance(value, six.string_types):
            self._from_timestring(value)
        else:
            raise TypeError('Time arguments must be a whole number, datetime.time, or string')

    @property
    def hour(self):
        """
        The hour component of this time (0-23)
        """
        return self.nanosecond_time // Time.HOUR

    @property
    def minute(self):
        """
        The minute component of this time (0-59)
        """
        minutes = self.nanosecond_time // Time.MINUTE
        return minutes % 60

    @property
    def second(self):
        """
        The second component of this time (0-59)
        """
        seconds = self.nanosecond_time // Time.SECOND
        return seconds % 60

    @property
    def nanosecond(self):
        """
        The fractional seconds component of the time, in nanoseconds
        """
        return self.nanosecond_time % Time.SECOND

    def time(self):
        """
        Return a built-in datetime.time (nanosecond precision truncated to micros).
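
        For example (an illustrative doctest)::

            >>> Time('12:34:56.123456789').time()
            datetime.time(12, 34, 56, 123456)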
""" return datetime.time(hour=self.hour, minute=self.minute, second=self.second, microsecond=self.nanosecond // Time.MICRO) def _from_timestamp(self, t): if t >= Time.DAY: raise ValueError("value must be less than number of nanoseconds in a day (%d)" % Time.DAY) self.nanosecond_time = t def _from_timestring(self, s): try: parts = s.split('.') base_time = time.strptime(parts[0], "%H:%M:%S") self.nanosecond_time = (base_time.tm_hour * Time.HOUR + base_time.tm_min * Time.MINUTE + base_time.tm_sec * Time.SECOND) if len(parts) > 1: # right pad to 9 digits nano_time_str = parts[1] + "0" * (9 - len(parts[1])) self.nanosecond_time += int(nano_time_str) except ValueError: raise ValueError("can't interpret %r as a time" % (s,)) def _from_time(self, t): self.nanosecond_time = (t.hour * Time.HOUR + t.minute * Time.MINUTE + t.second * Time.SECOND + t.microsecond * Time.MICRO) def __hash__(self): return self.nanosecond_time def __eq__(self, other): if isinstance(other, Time): return self.nanosecond_time == other.nanosecond_time if isinstance(other, six.integer_types): return self.nanosecond_time == other return self.nanosecond_time % Time.MICRO == 0 and \ datetime.time(hour=self.hour, minute=self.minute, second=self.second, microsecond=self.nanosecond // Time.MICRO) == other def __lt__(self, other): if not isinstance(other, Time): return NotImplemented return self.nanosecond_time < other.nanosecond_time def __repr__(self): return "Time(%s)" % self.nanosecond_time def __str__(self): return "%02d:%02d:%02d.%09d" % (self.hour, self.minute, self.second, self.nanosecond) class Date(object): ''' Idealized date: year, month, day Offers wider year range than datetime.date. For Dates that cannot be represented as a datetime.date (because datetime.MINYEAR, datetime.MAXYEAR), this type falls back to printing days_from_epoch offset. 
''' MINUTE = 60 HOUR = 60 * MINUTE DAY = 24 * HOUR date_format = "%Y-%m-%d" days_from_epoch = 0 def __init__(self, value): """ Initializer value can be: - integer_type: absolute days from epoch (1970, 1, 1). Can be negative. - datetime.date: built-in date - string_type: a string time of the form "yyyy-mm-dd" """ if isinstance(value, six.integer_types): self.days_from_epoch = value elif isinstance(value, (datetime.date, datetime.datetime)): self._from_timetuple(value.timetuple()) elif isinstance(value, six.string_types): self._from_datestring(value) else: raise TypeError('Date arguments must be a whole number, datetime.date, or string') @property def seconds(self): """ Absolute seconds from epoch (can be negative) """ return self.days_from_epoch * Date.DAY def date(self): """ Return a built-in datetime.date for Dates falling in the years [datetime.MINYEAR, datetime.MAXYEAR] ValueError is raised for Dates outside this range. """ try: dt = datetime_from_timestamp(self.seconds) return datetime.date(dt.year, dt.month, dt.day) except Exception: raise ValueError("%r exceeds ranges for built-in datetime.date" % self) def _from_timetuple(self, t): self.days_from_epoch = calendar.timegm(t) // Date.DAY def _from_datestring(self, s): if s[0] == '+': s = s[1:] dt = datetime.datetime.strptime(s, self.date_format) self._from_timetuple(dt.timetuple()) def __hash__(self): return self.days_from_epoch def __eq__(self, other): if isinstance(other, Date): return self.days_from_epoch == other.days_from_epoch if isinstance(other, six.integer_types): return self.days_from_epoch == other try: return self.date() == other except Exception: return False def __lt__(self, other): if not isinstance(other, Date): return NotImplemented return self.days_from_epoch < other.days_from_epoch def __repr__(self): return "Date(%s)" % self.days_from_epoch def __str__(self): try: dt = datetime_from_timestamp(self.seconds) return "%04d-%02d-%02d" % (dt.year, dt.month, dt.day) except: # If we overflow 
datetime.[MIN|MAX] return str(self.days_from_epoch) import socket if hasattr(socket, 'inet_pton'): inet_pton = socket.inet_pton inet_ntop = socket.inet_ntop else: """ Windows doesn't have socket.inet_pton and socket.inet_ntop until Python 3.4. This is an alternative implementation using ctypes, based on the win_inet_pton project: https://github.com/hickeroar/win_inet_pton """ import ctypes class sockaddr(ctypes.Structure): """ Shared struct for ipv4 and ipv6. https://msdn.microsoft.com/en-us/library/windows/desktop/ms740496(v=vs.85).aspx ``__pad1`` always covers the port. When being used for ``sockaddr_in6``, ``ipv4_addr`` actually covers ``sin6_flowinfo``, resulting in proper alignment for ``ipv6_addr``. """ _fields_ = [("sa_family", ctypes.c_short), ("__pad1", ctypes.c_ushort), ("ipv4_addr", ctypes.c_byte * 4), ("ipv6_addr", ctypes.c_byte * 16), ("__pad2", ctypes.c_ulong)] if hasattr(ctypes, 'windll'): WSAStringToAddressA = ctypes.windll.ws2_32.WSAStringToAddressA WSAAddressToStringA = ctypes.windll.ws2_32.WSAAddressToStringA else: def not_windows(*args): raise OSError("IPv6 addresses cannot be handled on Windows. 
" "Missing ctypes.windll") WSAStringToAddressA = not_windows WSAAddressToStringA = not_windows def inet_pton(address_family, ip_string): if address_family == socket.AF_INET: return socket.inet_aton(ip_string) addr = sockaddr() addr.sa_family = address_family addr_size = ctypes.c_int(ctypes.sizeof(addr)) if WSAStringToAddressA( ip_string, address_family, None, ctypes.byref(addr), ctypes.byref(addr_size) ) != 0: raise socket.error(ctypes.FormatError()) if address_family == socket.AF_INET6: return ctypes.string_at(addr.ipv6_addr, 16) raise socket.error('unknown address family') def inet_ntop(address_family, packed_ip): if address_family == socket.AF_INET: return socket.inet_ntoa(packed_ip) addr = sockaddr() addr.sa_family = address_family addr_size = ctypes.c_int(ctypes.sizeof(addr)) ip_string = ctypes.create_string_buffer(128) ip_string_size = ctypes.c_int(ctypes.sizeof(ip_string)) if address_family == socket.AF_INET6: if len(packed_ip) != ctypes.sizeof(addr.ipv6_addr): raise socket.error('packed IP wrong length for inet_ntop') ctypes.memmove(addr.ipv6_addr, packed_ip, 16) else: raise socket.error('unknown address family') if WSAAddressToStringA( ctypes.byref(addr), addr_size, None, ip_string, ctypes.byref(ip_string_size) ) != 0: raise socket.error(ctypes.FormatError()) return ip_string[:ip_string_size.value - 1] import keyword # similar to collections.namedtuple, reproduced here because Python 2.6 did not have the rename logic def _positional_rename_invalid_identifiers(field_names): names_out = list(field_names) for index, name in enumerate(field_names): if (not all(c.isalnum() or c == '_' for c in name) or keyword.iskeyword(name) or not name or name[0].isdigit() or name.startswith('_')): names_out[index] = 'field_%d_' % index return names_out def _sanitize_identifiers(field_names): names_out = _positional_rename_invalid_identifiers(field_names) if len(names_out) != len(set(names_out)): observed_names = set() for index, name in enumerate(names_out): while
names_out[index] in observed_names: names_out[index] = "%s_" % (names_out[index],) observed_names.add(names_out[index]) return names_out cassandra-driver-3.7.1/cassandra/deserializers.pxd0000664000175000017500000000323512743410406025141 0ustar aboudreaultaboudreault00000000000000# Copyright 2013-2016 DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from cassandra.buffer cimport Buffer cdef class Deserializer: # The cqltypes._CassandraType corresponding to this deserializer cdef object cqltype # String may be empty, whereas other values may not be. # Other values may be NULL, in which case the integer length # of the binary data is negative. 
However, non-string types # may also return a zero length for legacy reasons # (see http://code.metager.de/source/xref/apache/cassandra/doc/native_protocol_v3.spec # paragraph 6) cdef bint empty_binary_ok cdef deserialize(self, Buffer *buf, int protocol_version) # cdef deserialize(self, CString byts, protocol_version) cdef inline object from_binary(Deserializer deserializer, Buffer *buf, int protocol_version): if buf.size < 0: return None elif buf.size == 0 and not deserializer.empty_binary_ok: return _ret_empty(deserializer, buf.size) else: return deserializer.deserialize(buf, protocol_version) cdef _ret_empty(Deserializer deserializer, Py_ssize_t buf_size) cassandra-driver-3.7.1/cassandra/cython_deps.py0000664000175000017500000000033412743410406024445 0ustar aboudreaultaboudreault00000000000000try: from cassandra.row_parser import make_recv_results_rows HAVE_CYTHON = True except ImportError: HAVE_CYTHON = False try: import numpy HAVE_NUMPY = True except ImportError: HAVE_NUMPY = False cassandra-driver-3.7.1/cassandra/cqltypes.py0000664000175000017500000010341213004141114023757 0ustar aboudreaultaboudreault00000000000000# Copyright 2013-2016 DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ Representation of Cassandra data types. These classes should make it simple for the library (and caller software) to deal with Cassandra-style Java class type names and CQL type specifiers, and convert between them cleanly. 
Parameterized types are fully supported in both flavors. Once you have the right Type object for the type you want, you can use it to serialize, deserialize, or retrieve the corresponding CQL or Cassandra type strings. """ # NOTE: # If/when the need arises to interpret types from CQL string literals in # different ways (for https://issues.apache.org/jira/browse/CASSANDRA-3799, # for example), these classes would be a good place to tack on # .from_cql_literal() and .as_cql_literal() classmethods (or whatever). from __future__ import absolute_import # to enable import io from stdlib from binascii import unhexlify import calendar from collections import namedtuple from decimal import Decimal import io import logging import re import socket import time import six from six.moves import range import sys from uuid import UUID import warnings from cassandra.marshal import (int8_pack, int8_unpack, int16_pack, int16_unpack, uint16_pack, uint16_unpack, uint32_pack, uint32_unpack, int32_pack, int32_unpack, int64_pack, int64_unpack, float_pack, float_unpack, double_pack, double_unpack, varint_pack, varint_unpack) from cassandra import util apache_cassandra_type_prefix = 'org.apache.cassandra.db.marshal.'
cassandra_empty_type = 'org.apache.cassandra.db.marshal.EmptyType' cql_empty_type = 'empty' log = logging.getLogger(__name__) if six.PY3: _number_types = frozenset((int, float)) long = int def _name_from_hex_string(encoded_name): bin_str = unhexlify(encoded_name) return bin_str.decode('ascii') else: _number_types = frozenset((int, long, float)) _name_from_hex_string = unhexlify def trim_if_startswith(s, prefix): if s.startswith(prefix): return s[len(prefix):] return s _casstypes = {} _cqltypes = {} cql_type_scanner = re.Scanner(( ('frozen', None), (r'[a-zA-Z0-9_]+', lambda s, t: t), (r'[\s,<>]', None), )) def cql_types_from_string(cql_type): return cql_type_scanner.scan(cql_type)[0] class CassandraTypeType(type): """ The CassandraType objects in this module will normally be used directly, rather than through instances of those types. They can be instantiated, of course, but the type information is what this driver mainly needs. This metaclass registers CassandraType classes in the global by-cassandra-typename and by-cql-typename registries, unless their class name starts with an underscore. """ def __new__(metacls, name, bases, dct): dct.setdefault('cassname', name) cls = type.__new__(metacls, name, bases, dct) if not name.startswith('_'): _casstypes[name] = cls if not cls.typename.startswith(apache_cassandra_type_prefix): _cqltypes[cls.typename] = cls return cls casstype_scanner = re.Scanner(( (r'[()]', lambda s, t: t), (r'[a-zA-Z0-9_.:=>]+', lambda s, t: t), (r'[\s,]', None), )) def lookup_casstype_simple(casstype): """ Given a Cassandra type name (either fully distinguished or not), hand back the CassandraType class responsible for it. If a name is not recognized, a custom _UnrecognizedType subclass will be created for it. This function does not handle complex types (so no type parameters-- nothing with parentheses). Use lookup_casstype() instead if you might need that. 
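Example (illustrative; the fully qualified and the short name resolve to the same registered class):
>>> lookup_casstype_simple('org.apache.cassandra.db.marshal.Int32Type')
<class 'cassandra.cqltypes.Int32Type'>
>>> lookup_casstype_simple('Int32Type')
<class 'cassandra.cqltypes.Int32Type'>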
""" shortname = trim_if_startswith(casstype, apache_cassandra_type_prefix) try: typeclass = _casstypes[shortname] except KeyError: typeclass = mkUnrecognizedType(casstype) return typeclass def parse_casstype_args(typestring): tokens, remainder = casstype_scanner.scan(typestring) if remainder: raise ValueError("weird characters %r at end" % remainder) # use a stack of (types, names) lists args = [([], [])] for tok in tokens: if tok == '(': args.append(([], [])) elif tok == ')': types, names = args.pop() prev_types, prev_names = args[-1] prev_types[-1] = prev_types[-1].apply_parameters(types, names) else: types, names = args[-1] parts = re.split(':|=>', tok) tok = parts.pop() if parts: names.append(parts[0]) else: names.append(None) ctype = lookup_casstype_simple(tok) types.append(ctype) # return the first (outer) type, which will have all parameters applied return args[0][0][0] def lookup_casstype(casstype): """ Given a Cassandra type as a string (possibly including parameters), hand back the CassandraType class responsible for it. If a name is not recognized, a custom _UnrecognizedType subclass will be created for it. 
Example: >>> lookup_casstype('org.apache.cassandra.db.marshal.MapType(org.apache.cassandra.db.marshal.UTF8Type,org.apache.cassandra.db.marshal.Int32Type)') <class 'cassandra.cqltypes.MapType(UTF8Type, Int32Type)'> """ if isinstance(casstype, (CassandraType, CassandraTypeType)): return casstype try: return parse_casstype_args(casstype) except (ValueError, AssertionError, IndexError) as e: raise ValueError("Don't know how to parse type string %r: %s" % (casstype, e)) def is_reversed_casstype(data_type): return issubclass(data_type, ReversedType) class EmptyValue(object): """ See _CassandraType.support_empty_values """ def __str__(self): return "EMPTY" __repr__ = __str__ EMPTY = EmptyValue() @six.add_metaclass(CassandraTypeType) class _CassandraType(object): subtypes = () num_subtypes = 0 empty_binary_ok = False support_empty_values = False """ Back in the Thrift days, empty strings were used for "null" values of all types, including non-string types. For most users, an empty string value in an int column is the same as being null/not present, so the driver normally returns None in this case. (For string-like types, it *will* return an empty string by default instead of None.) To avoid this behavior, set this to :const:`True`. Instead of returning None for empty string values, the EMPTY singleton (an instance of EmptyValue) will be returned. """ def __repr__(self): return '<%s( %r )>' % (self.cql_parameterized_type(), self.val) @classmethod def from_binary(cls, byts, protocol_version): """ Deserialize a bytestring into a value. See the deserialize() method for more information. This method differs in that if None or the empty string is passed in, None may be returned. """ if byts is None: return None elif len(byts) == 0 and not cls.empty_binary_ok: return EMPTY if cls.support_empty_values else None return cls.deserialize(byts, protocol_version) @classmethod def to_binary(cls, val, protocol_version): """ Serialize a value into a bytestring. See the serialize() method for more information.
This method differs in that if None is passed in, the result is the empty string. """ return b'' if val is None else cls.serialize(val, protocol_version) @staticmethod def deserialize(byts, protocol_version): """ Given a bytestring, deserialize into a value according to the protocol for this type. Note that this does not create a new instance of this class; it merely gives back a value that would be appropriate to go inside an instance of this class. """ return byts @staticmethod def serialize(val, protocol_version): """ Given a value appropriate for this class, serialize it according to the protocol for this type and return the corresponding bytestring. """ return val @classmethod def cass_parameterized_type_with(cls, subtypes, full=False): """ Return the name of this type as it would be expressed by Cassandra, optionally fully qualified. If subtypes is not None, it is expected to be a list of other CassandraType subclasses, and the output string includes the Cassandra names for those subclasses as well, as parameters to this one. Example: >>> LongType.cass_parameterized_type_with(()) 'LongType' >>> LongType.cass_parameterized_type_with((), full=True) 'org.apache.cassandra.db.marshal.LongType' >>> SetType.cass_parameterized_type_with([DecimalType], full=True) 'org.apache.cassandra.db.marshal.SetType(org.apache.cassandra.db.marshal.DecimalType)' """ cname = cls.cassname if full and '.' not in cname: cname = apache_cassandra_type_prefix + cname if not subtypes: return cname sublist = ', '.join(styp.cass_parameterized_type(full=full) for styp in subtypes) return '%s(%s)' % (cname, sublist) @classmethod def apply_parameters(cls, subtypes, names=None): """ Given a set of other CassandraTypes, create a new subtype of this type using them as parameters. This is how composite types are constructed. >>> MapType.apply_parameters([DateType, BooleanType]) <class 'cassandra.cqltypes.MapType(DateType, BooleanType)'> `subtypes` will be a sequence of CassandraTypes.
If provided, `names` will be an equally long sequence of column names or Nones. """ if cls.num_subtypes != 'UNKNOWN' and len(subtypes) != cls.num_subtypes: raise ValueError("%s types require %d subtypes (%d given)" % (cls.typename, cls.num_subtypes, len(subtypes))) newname = cls.cass_parameterized_type_with(subtypes) if six.PY2 and isinstance(newname, unicode): newname = newname.encode('utf-8') return type(newname, (cls,), {'subtypes': subtypes, 'cassname': cls.cassname, 'fieldnames': names}) @classmethod def cql_parameterized_type(cls): """ Return a CQL type specifier for this type. If this type has parameters, they are included in standard CQL <> notation. """ if not cls.subtypes: return cls.typename return '%s<%s>' % (cls.typename, ', '.join(styp.cql_parameterized_type() for styp in cls.subtypes)) @classmethod def cass_parameterized_type(cls, full=False): """ Return a Cassandra type specifier for this type. If this type has parameters, they are included in the standard () notation. """ return cls.cass_parameterized_type_with(cls.subtypes, full=full) # it's initially named with a _ to avoid registering it as a real type, but # client programs may want to use the name still for isinstance(), etc CassandraType = _CassandraType class _UnrecognizedType(_CassandraType): num_subtypes = 'UNKNOWN' if six.PY3: def mkUnrecognizedType(casstypename): return CassandraTypeType(casstypename, (_UnrecognizedType,), {'typename': "'%s'" % casstypename}) else: def mkUnrecognizedType(casstypename): # noqa return CassandraTypeType(casstypename.encode('utf8'), (_UnrecognizedType,), {'typename': "'%s'" % casstypename}) class BytesType(_CassandraType): typename = 'blob' empty_binary_ok = True @staticmethod def serialize(val, protocol_version): return six.binary_type(val) class DecimalType(_CassandraType): typename = 'decimal' @staticmethod def deserialize(byts, protocol_version): scale = int32_unpack(byts[:4]) unscaled = varint_unpack(byts[4:]) return Decimal('%de%d' % (unscaled, 
-scale)) @staticmethod def serialize(dec, protocol_version): try: sign, digits, exponent = dec.as_tuple() except AttributeError: try: sign, digits, exponent = Decimal(dec).as_tuple() except Exception: raise TypeError("Invalid type for Decimal value: %r" % (dec,)) unscaled = int(''.join([str(digit) for digit in digits])) if sign: unscaled *= -1 scale = int32_pack(-exponent) unscaled = varint_pack(unscaled) return scale + unscaled class UUIDType(_CassandraType): typename = 'uuid' @staticmethod def deserialize(byts, protocol_version): return UUID(bytes=byts) @staticmethod def serialize(uuid, protocol_version): try: return uuid.bytes except AttributeError: raise TypeError("Got a non-UUID object for a UUID value") class BooleanType(_CassandraType): typename = 'boolean' @staticmethod def deserialize(byts, protocol_version): return bool(int8_unpack(byts)) @staticmethod def serialize(truth, protocol_version): return int8_pack(truth) class ByteType(_CassandraType): typename = 'tinyint' @staticmethod def deserialize(byts, protocol_version): return int8_unpack(byts) @staticmethod def serialize(byts, protocol_version): return int8_pack(byts) if six.PY2: class AsciiType(_CassandraType): typename = 'ascii' empty_binary_ok = True else: class AsciiType(_CassandraType): typename = 'ascii' empty_binary_ok = True @staticmethod def deserialize(byts, protocol_version): return byts.decode('ascii') @staticmethod def serialize(var, protocol_version): try: return var.encode('ascii') except UnicodeDecodeError: return var class FloatType(_CassandraType): typename = 'float' @staticmethod def deserialize(byts, protocol_version): return float_unpack(byts) @staticmethod def serialize(byts, protocol_version): return float_pack(byts) class DoubleType(_CassandraType): typename = 'double' @staticmethod def deserialize(byts, protocol_version): return double_unpack(byts) @staticmethod def serialize(byts, protocol_version): return double_pack(byts) class LongType(_CassandraType): typename = 'bigint'
@staticmethod def deserialize(byts, protocol_version): return int64_unpack(byts) @staticmethod def serialize(byts, protocol_version): return int64_pack(byts) class Int32Type(_CassandraType): typename = 'int' @staticmethod def deserialize(byts, protocol_version): return int32_unpack(byts) @staticmethod def serialize(byts, protocol_version): return int32_pack(byts) class IntegerType(_CassandraType): typename = 'varint' @staticmethod def deserialize(byts, protocol_version): return varint_unpack(byts) @staticmethod def serialize(byts, protocol_version): return varint_pack(byts) class InetAddressType(_CassandraType): typename = 'inet' @staticmethod def deserialize(byts, protocol_version): if len(byts) == 16: return util.inet_ntop(socket.AF_INET6, byts) else: # util.inet_pton could also handle, but this is faster # since we've already determined the AF return socket.inet_ntoa(byts) @staticmethod def serialize(addr, protocol_version): if ':' in addr: return util.inet_pton(socket.AF_INET6, addr) else: # util.inet_pton could also handle, but this is faster # since we've already determined the AF return socket.inet_aton(addr) class CounterColumnType(LongType): typename = 'counter' cql_timestamp_formats = ( '%Y-%m-%d %H:%M', '%Y-%m-%d %H:%M:%S', '%Y-%m-%dT%H:%M', '%Y-%m-%dT%H:%M:%S', '%Y-%m-%d' ) _have_warned_about_timestamps = False class DateType(_CassandraType): typename = 'timestamp' @staticmethod def interpret_datestring(val): if val[-5] in ('+', '-'): offset = (int(val[-4:-2]) * 3600 + int(val[-2:]) * 60) * int(val[-5] + '1') val = val[:-5] else: offset = -time.timezone for tformat in cql_timestamp_formats: try: tval = time.strptime(val, tformat) except ValueError: continue # scale seconds to millis for the raw value return (calendar.timegm(tval) + offset) * 1e3 else: raise ValueError("can't interpret %r as a date" % (val,)) @staticmethod def deserialize(byts, protocol_version): timestamp = int64_unpack(byts) / 1000.0 return util.datetime_from_timestamp(timestamp) 
@staticmethod def serialize(v, protocol_version): try: # v is datetime timestamp_seconds = calendar.timegm(v.utctimetuple()) timestamp = timestamp_seconds * 1e3 + getattr(v, 'microsecond', 0) / 1e3 except AttributeError: try: timestamp = calendar.timegm(v.timetuple()) * 1e3 except AttributeError: # Ints and floats are valid timestamps too if type(v) not in _number_types: raise TypeError('DateType arguments must be a datetime, date, or timestamp') timestamp = v return int64_pack(long(timestamp)) class TimestampType(DateType): pass class TimeUUIDType(DateType): typename = 'timeuuid' def my_timestamp(self): return util.unix_time_from_uuid1(self.val) @staticmethod def deserialize(byts, protocol_version): return UUID(bytes=byts) @staticmethod def serialize(timeuuid, protocol_version): try: return timeuuid.bytes except AttributeError: raise TypeError("Got a non-UUID object for a UUID value") class SimpleDateType(_CassandraType): typename = 'date' date_format = "%Y-%m-%d" # Values of the 'date' type are encoded as 32-bit unsigned integers # representing a number of days, with the epoch (January 1st, 1970) at the center of the # range (2^31).
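# Worked example of this offset (illustrative): 1970-01-01 (0 days from epoch)
# is encoded as 0 + 2**31 = 2147483648, 1970-01-02 as 2147483649, and
# 1969-12-31 as 2147483647.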
EPOCH_OFFSET_DAYS = 2 ** 31 @staticmethod def deserialize(byts, protocol_version): days = uint32_unpack(byts) - SimpleDateType.EPOCH_OFFSET_DAYS return util.Date(days) @staticmethod def serialize(val, protocol_version): try: days = val.days_from_epoch except AttributeError: if isinstance(val, six.integer_types): # the DB wants offset int values, but util.Date init takes days from epoch # here we assume int values are offset, as they would appear in CQL # short circuit to avoid subtracting just to add offset return uint32_pack(val) days = util.Date(val).days_from_epoch return uint32_pack(days + SimpleDateType.EPOCH_OFFSET_DAYS) class ShortType(_CassandraType): typename = 'smallint' @staticmethod def deserialize(byts, protocol_version): return int16_unpack(byts) @staticmethod def serialize(byts, protocol_version): return int16_pack(byts) class TimeType(_CassandraType): typename = 'time' @staticmethod def deserialize(byts, protocol_version): return util.Time(int64_unpack(byts)) @staticmethod def serialize(val, protocol_version): try: nano = val.nanosecond_time except AttributeError: nano = util.Time(val).nanosecond_time return int64_pack(nano) class UTF8Type(_CassandraType): typename = 'text' empty_binary_ok = True @staticmethod def deserialize(byts, protocol_version): return byts.decode('utf8') @staticmethod def serialize(ustr, protocol_version): try: return ustr.encode('utf-8') except UnicodeDecodeError: # already utf-8 return ustr class VarcharType(UTF8Type): typename = 'varchar' class _ParameterizedType(_CassandraType): num_subtypes = 'UNKNOWN' @classmethod def deserialize(cls, byts, protocol_version): if not cls.subtypes: raise NotImplementedError("can't deserialize unparameterized %s" % cls.typename) return cls.deserialize_safe(byts, protocol_version) @classmethod def serialize(cls, val, protocol_version): if not cls.subtypes: raise NotImplementedError("can't serialize unparameterized %s" % cls.typename) return cls.serialize_safe(val, protocol_version) class 
_SimpleParameterizedType(_ParameterizedType): @classmethod def deserialize_safe(cls, byts, protocol_version): subtype, = cls.subtypes if protocol_version >= 3: unpack = int32_unpack length = 4 else: unpack = uint16_unpack length = 2 numelements = unpack(byts[:length]) p = length result = [] inner_proto = max(3, protocol_version) for _ in range(numelements): itemlen = unpack(byts[p:p + length]) p += length item = byts[p:p + itemlen] p += itemlen result.append(subtype.from_binary(item, inner_proto)) return cls.adapter(result) @classmethod def serialize_safe(cls, items, protocol_version): if isinstance(items, six.string_types): raise TypeError("Received a string for a type that expects a sequence") subtype, = cls.subtypes pack = int32_pack if protocol_version >= 3 else uint16_pack buf = io.BytesIO() buf.write(pack(len(items))) inner_proto = max(3, protocol_version) for item in items: itembytes = subtype.to_binary(item, inner_proto) buf.write(pack(len(itembytes))) buf.write(itembytes) return buf.getvalue() class ListType(_SimpleParameterizedType): typename = 'list' num_subtypes = 1 adapter = list class SetType(_SimpleParameterizedType): typename = 'set' num_subtypes = 1 adapter = util.sortedset class MapType(_ParameterizedType): typename = 'map' num_subtypes = 2 @classmethod def deserialize_safe(cls, byts, protocol_version): key_type, value_type = cls.subtypes if protocol_version >= 3: unpack = int32_unpack length = 4 else: unpack = uint16_unpack length = 2 numelements = unpack(byts[:length]) p = length themap = util.OrderedMapSerializedKey(key_type, protocol_version) inner_proto = max(3, protocol_version) for _ in range(numelements): key_len = unpack(byts[p:p + length]) p += length keybytes = byts[p:p + key_len] p += key_len val_len = unpack(byts[p:p + length]) p += length valbytes = byts[p:p + val_len] p += val_len key = key_type.from_binary(keybytes, inner_proto) val = value_type.from_binary(valbytes, inner_proto) themap._insert_unchecked(key, keybytes, val) return 
themap @classmethod def serialize_safe(cls, themap, protocol_version): key_type, value_type = cls.subtypes pack = int32_pack if protocol_version >= 3 else uint16_pack buf = io.BytesIO() buf.write(pack(len(themap))) try: items = six.iteritems(themap) except AttributeError: raise TypeError("Got a non-map object for a map value") inner_proto = max(3, protocol_version) for key, val in items: keybytes = key_type.to_binary(key, inner_proto) valbytes = value_type.to_binary(val, inner_proto) buf.write(pack(len(keybytes))) buf.write(keybytes) buf.write(pack(len(valbytes))) buf.write(valbytes) return buf.getvalue() class TupleType(_ParameterizedType): typename = 'tuple' @classmethod def deserialize_safe(cls, byts, protocol_version): proto_version = max(3, protocol_version) p = 0 values = [] for col_type in cls.subtypes: if p == len(byts): break itemlen = int32_unpack(byts[p:p + 4]) p += 4 if itemlen >= 0: item = byts[p:p + itemlen] p += itemlen else: item = None # collections inside UDTs are always encoded with at least the # version 3 format values.append(col_type.from_binary(item, proto_version)) if len(values) < len(cls.subtypes): nones = [None] * (len(cls.subtypes) - len(values)) values = values + nones return tuple(values) @classmethod def serialize_safe(cls, val, protocol_version): if len(val) > len(cls.subtypes): raise ValueError("Expected %d items in a tuple, but got %d: %s" % (len(cls.subtypes), len(val), val)) proto_version = max(3, protocol_version) buf = io.BytesIO() for item, subtype in zip(val, cls.subtypes): if item is not None: packed_item = subtype.to_binary(item, proto_version) buf.write(int32_pack(len(packed_item))) buf.write(packed_item) else: buf.write(int32_pack(-1)) return buf.getvalue() @classmethod def cql_parameterized_type(cls): subtypes_string = ', '.join(sub.cql_parameterized_type() for sub in cls.subtypes) return 'frozen<tuple<%s>>' % (subtypes_string,) class UserType(TupleType): typename = "org.apache.cassandra.db.marshal.UserType" _cache = {} _module =
sys.modules[__name__] @classmethod def make_udt_class(cls, keyspace, udt_name, field_names, field_types): assert len(field_names) == len(field_types) if six.PY2 and isinstance(udt_name, unicode): udt_name = udt_name.encode('utf-8') instance = cls._cache.get((keyspace, udt_name)) if not instance or instance.fieldnames != field_names or instance.subtypes != field_types: instance = type(udt_name, (cls,), {'subtypes': field_types, 'cassname': cls.cassname, 'typename': udt_name, 'fieldnames': field_names, 'keyspace': keyspace, 'mapped_class': None, 'tuple_type': cls._make_registered_udt_namedtuple(keyspace, udt_name, field_names)}) cls._cache[(keyspace, udt_name)] = instance return instance @classmethod def evict_udt_class(cls, keyspace, udt_name): if six.PY2 and isinstance(udt_name, unicode): udt_name = udt_name.encode('utf-8') try: del cls._cache[(keyspace, udt_name)] except KeyError: pass @classmethod def apply_parameters(cls, subtypes, names): keyspace = subtypes[0].cass_parameterized_type() # when parsed from cassandra type, the keyspace is created as an unrecognized cass type; This gets the name back udt_name = _name_from_hex_string(subtypes[1].cassname) field_names = tuple(_name_from_hex_string(encoded_name) for encoded_name in names[2:]) # using tuple here to match what comes into make_udt_class from other sources (for caching equality test) return cls.make_udt_class(keyspace, udt_name, field_names, tuple(subtypes[2:])) @classmethod def cql_parameterized_type(cls): return "frozen<%s>" % (cls.typename,) @classmethod def deserialize_safe(cls, byts, protocol_version): values = super(UserType, cls).deserialize_safe(byts, protocol_version) if cls.mapped_class: return cls.mapped_class(**dict(zip(cls.fieldnames, values))) elif cls.tuple_type: return cls.tuple_type(*values) else: return tuple(values) @classmethod def serialize_safe(cls, val, protocol_version): proto_version = max(3, protocol_version) buf = io.BytesIO() for i, (fieldname, subtype) in 
enumerate(zip(cls.fieldnames, cls.subtypes)): # first treat as a tuple, else by custom type try: item = val[i] except TypeError: item = getattr(val, fieldname) if item is not None: packed_item = subtype.to_binary(item, proto_version) buf.write(int32_pack(len(packed_item))) buf.write(packed_item) else: buf.write(int32_pack(-1)) return buf.getvalue() @classmethod def _make_registered_udt_namedtuple(cls, keyspace, name, field_names): # this is required to make the type resolvable via this module... # required when unregistered udts are pickled for use as keys in # util.OrderedMap t = cls._make_udt_tuple_type(name, field_names) if t: qualified_name = "%s_%s" % (keyspace, name) setattr(cls._module, qualified_name, t) return t @classmethod def _make_udt_tuple_type(cls, name, field_names): # for CQL identifiers that aren't valid in Python, fall back to # positionally-named tuples, then to plain (unnamed) tuples try: t = namedtuple(name, field_names) except ValueError: try: t = namedtuple(name, util._positional_rename_invalid_identifiers(field_names)) log.warning("could not create a namedtuple for '%s' because one or more field names are not valid Python identifiers (%s); " \ "returning positionally-named fields" % (name, field_names)) except ValueError: t = None log.warning("could not create a namedtuple for '%s' because the name is not a valid Python identifier; " \ "will return tuples in its place" % (name,)) return t class CompositeType(_ParameterizedType): typename = "org.apache.cassandra.db.marshal.CompositeType" @classmethod def cql_parameterized_type(cls): """ There is no CQL notation for Composites, so we override this.
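Example (illustrative; the result is the fully qualified Cassandra type string wrapped in single quotes):
>>> CompositeType.apply_parameters([UTF8Type, Int32Type]).cql_parameterized_type()
"'org.apache.cassandra.db.marshal.CompositeType(org.apache.cassandra.db.marshal.UTF8Type, org.apache.cassandra.db.marshal.Int32Type)'"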
""" typestring = cls.cass_parameterized_type(full=True) return "'%s'" % (typestring,) @classmethod def deserialize_safe(cls, byts, protocol_version): result = [] for subtype in cls.subtypes: if not byts: # CompositeType can have missing elements at the end break element_length = uint16_unpack(byts[:2]) element = byts[2:2 + element_length] # skip element length, element, and the EOC (one byte) byts = byts[2 + element_length + 1:] result.append(subtype.from_binary(element, protocol_version)) return tuple(result) class DynamicCompositeType(_ParameterizedType): typename = "org.apache.cassandra.db.marshal.DynamicCompositeType" @classmethod def cql_parameterized_type(cls): sublist = ', '.join('%s=>%s' % (alias, typ.cass_parameterized_type(full=True)) for alias, typ in zip(cls.fieldnames, cls.subtypes)) return "'%s(%s)'" % (cls.typename, sublist) class ColumnToCollectionType(_ParameterizedType): """ This class only really exists so that we can cleanly evaluate types when Cassandra includes this. We don't actually need or want the extra information. 
""" typename = "org.apache.cassandra.db.marshal.ColumnToCollectionType" class ReversedType(_ParameterizedType): typename = "org.apache.cassandra.db.marshal.ReversedType" num_subtypes = 1 @classmethod def deserialize_safe(cls, byts, protocol_version): subtype, = cls.subtypes return subtype.from_binary(byts, protocol_version) @classmethod def serialize_safe(cls, val, protocol_version): subtype, = cls.subtypes return subtype.to_binary(val, protocol_version) class FrozenType(_ParameterizedType): typename = "frozen" num_subtypes = 1 @classmethod def deserialize_safe(cls, byts, protocol_version): subtype, = cls.subtypes return subtype.from_binary(byts, protocol_version) @classmethod def serialize_safe(cls, val, protocol_version): subtype, = cls.subtypes return subtype.to_binary(val, protocol_version) def is_counter_type(t): if isinstance(t, six.string_types): t = lookup_casstype(t) return issubclass(t, CounterColumnType) def cql_typename(casstypename): """ Translate a Cassandra-style type specifier (optionally fully-qualified Java class names for data types, along with optional parameters) into a CQL-style type specifier. >>> cql_typename('DateType') 'timestamp' >>> cql_typename('org.apache.cassandra.db.marshal.ListType(IntegerType)') 'list<varint>' """ return lookup_casstype(casstypename).cql_parameterized_type() cassandra-driver-3.7.1/cassandra/io/0000775000175000017500000000000013004144417022160 5ustar aboudreaultaboudreault00000000000000cassandra-driver-3.7.1/cassandra/io/libevreactor.py0000664000175000017500000002721012766043657025237 0ustar aboudreaultaboudreault00000000000000# Copyright 2013-2016 DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import atexit from collections import deque from functools import partial import logging import os import socket import ssl from threading import Lock, Thread import time import weakref from six.moves import range from cassandra.connection import (Connection, ConnectionShutdown, NONBLOCKING, Timer, TimerManager) try: import cassandra.io.libevwrapper as libev except ImportError: raise ImportError( "The C extension needed to use libev was not found. This " "probably means that you didn't have the required build dependencies " "when installing the driver. See " "http://datastax.github.io/python-driver/installation.html#c-extensions " "for instructions on installing build dependencies and building " "the C extension.") log = logging.getLogger(__name__) def _cleanup(loop_weakref): # calling a dead weakref.ref returns None; it does not raise ReferenceError loop = loop_weakref() if loop is not None: loop._cleanup() class LibevLoop(object): def __init__(self): self._pid = os.getpid() self._loop = libev.Loop() self._notifier = libev.Async(self._loop) self._notifier.start() # prevent _notifier from keeping the loop from returning self._loop.unref() self._started = False self._shutdown = False self._lock = Lock() self._thread = None # set of all connections; only replaced with a new copy # while holding _conn_set_lock, never modified in place self._live_conns = set() # newly created connections that need their write/read watcher started self._new_conns = set() # recently closed connections that need their write/read watcher stopped 
self._closed_conns = set() self._conn_set_lock = Lock() self._preparer = libev.Prepare(self._loop,
self._loop_will_run) # prevent _preparer from keeping the loop from returning self._loop.unref() self._preparer.start() self._timers = TimerManager() self._loop_timer = libev.Timer(self._loop, self._on_loop_timer) atexit.register(partial(_cleanup, weakref.ref(self))) def maybe_start(self): should_start = False with self._lock: if not self._started: log.debug("Starting libev event loop") self._started = True should_start = True if should_start: self._thread = Thread(target=self._run_loop, name="event_loop") self._thread.daemon = True self._thread.start() self._notifier.send() def _run_loop(self): while True: self._loop.start() # there are still active watchers, no deadlock with self._lock: if not self._shutdown and self._live_conns: log.debug("Restarting event loop") continue else: # all Connections have been closed, no active watchers log.debug("All Connections currently closed, event loop ended") self._started = False break def _cleanup(self): self._shutdown = True if not self._thread: return for conn in self._live_conns | self._new_conns | self._closed_conns: conn.close() # map() is lazy on Python 3, so stop each watcher in an explicit loop for watcher in (conn._write_watcher, conn._read_watcher): if watcher: watcher.stop() self.notify() # wake the timer watcher log.debug("Waiting for event loop thread to join...") self._thread.join(timeout=1.0) if self._thread.is_alive(): log.warning( "Event loop thread could not be joined, so shutdown may not be clean.
" "Please call Cluster.shutdown() to avoid this.") log.debug("Event loop thread was joined") def add_timer(self, timer): self._timers.add_timer(timer) self._notifier.send() # wake up in case this timer is earlier def _update_timer(self): if not self._shutdown: next_end = self._timers.service_timeouts() if next_end: self._loop_timer.start(next_end - time.time()) # timer handles negative values else: self._loop_timer.stop() def _on_loop_timer(self): self._timers.service_timeouts() def notify(self): self._notifier.send() def connection_created(self, conn): with self._conn_set_lock: new_live_conns = self._live_conns.copy() new_live_conns.add(conn) self._live_conns = new_live_conns new_new_conns = self._new_conns.copy() new_new_conns.add(conn) self._new_conns = new_new_conns def connection_destroyed(self, conn): with self._conn_set_lock: new_live_conns = self._live_conns.copy() new_live_conns.discard(conn) self._live_conns = new_live_conns new_closed_conns = self._closed_conns.copy() new_closed_conns.add(conn) self._closed_conns = new_closed_conns self._notifier.send() def _loop_will_run(self, prepare): changed = False for conn in self._live_conns: if not conn.deque and conn._write_watcher_is_active: if conn._write_watcher: conn._write_watcher.stop() conn._write_watcher_is_active = False changed = True elif conn.deque and not conn._write_watcher_is_active: conn._write_watcher.start() conn._write_watcher_is_active = True changed = True if self._new_conns: with self._conn_set_lock: to_start = self._new_conns self._new_conns = set() for conn in to_start: conn._read_watcher.start() changed = True if self._closed_conns: with self._conn_set_lock: to_stop = self._closed_conns self._closed_conns = set() for conn in to_stop: if conn._write_watcher: conn._write_watcher.stop() # clear reference cycles from IO callback del conn._write_watcher if conn._read_watcher: conn._read_watcher.stop() # clear reference cycles from IO callback del conn._read_watcher changed = True # TODO: 
update to do connection management, timer updates through dedicated async 'notifier' callbacks self._update_timer() if changed: self._notifier.send() class LibevConnection(Connection): """ An implementation of :class:`.Connection` that uses libev for its event loop. """ _libevloop = None _write_watcher_is_active = False _read_watcher = None _write_watcher = None _socket = None @classmethod def initialize_reactor(cls): if not cls._libevloop: cls._libevloop = LibevLoop() else: if cls._libevloop._pid != os.getpid(): log.debug("Detected fork, clearing and reinitializing reactor state") cls.handle_fork() cls._libevloop = LibevLoop() @classmethod def handle_fork(cls): if cls._libevloop: cls._libevloop._cleanup() cls._libevloop = None @classmethod def create_timer(cls, timeout, callback): timer = Timer(timeout, callback) cls._libevloop.add_timer(timer) return timer def __init__(self, *args, **kwargs): Connection.__init__(self, *args, **kwargs) self.deque = deque() self._deque_lock = Lock() self._connect_socket() self._socket.setblocking(0) with self._libevloop._lock: self._read_watcher = libev.IO(self._socket.fileno(), libev.EV_READ, self._libevloop._loop, self.handle_read) self._write_watcher = libev.IO(self._socket.fileno(), libev.EV_WRITE, self._libevloop._loop, self.handle_write) self._send_options_message() self._libevloop.connection_created(self) # start the global event loop if needed self._libevloop.maybe_start() def close(self): with self.lock: if self.is_closed: return self.is_closed = True log.debug("Closing connection (%s) to %s", id(self), self.host) self._libevloop.connection_destroyed(self) self._socket.close() log.debug("Closed socket to %s", self.host) # don't leave in-progress operations hanging if not self.is_defunct: self.error_all_requests( ConnectionShutdown("Connection to %s was closed" % self.host)) def handle_write(self, watcher, revents, errno=None): if revents & libev.EV_ERROR: if errno: exc = IOError(errno, os.strerror(errno)) else: exc = 
Exception("libev reported an error") self.defunct(exc) return while True: try: with self._deque_lock: next_msg = self.deque.popleft() except IndexError: return try: sent = self._socket.send(next_msg) except socket.error as err: if (err.args[0] in NONBLOCKING): with self._deque_lock: self.deque.appendleft(next_msg) else: self.defunct(err) return else: if sent < len(next_msg): with self._deque_lock: self.deque.appendleft(next_msg[sent:]) def handle_read(self, watcher, revents, errno=None): if revents & libev.EV_ERROR: if errno: exc = IOError(errno, os.strerror(errno)) else: exc = Exception("libev reported an error") self.defunct(exc) return try: while True: buf = self._socket.recv(self.in_buffer_size) self._iobuf.write(buf) if len(buf) < self.in_buffer_size: break except socket.error as err: if ssl and isinstance(err, ssl.SSLError): if err.args[0] not in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE): self.defunct(err) return elif err.args[0] not in NONBLOCKING: self.defunct(err) return if self._iobuf.tell(): self.process_io_buffer() else: log.debug("Connection %s closed by server", self) self.close() def push(self, data): sabs = self.out_buffer_size if len(data) > sabs: chunks = [] for i in range(0, len(data), sabs): chunks.append(data[i:i + sabs]) else: chunks = [data] with self._deque_lock: self.deque.extend(chunks) self._libevloop.notify() cassandra-driver-3.7.1/cassandra/io/libevwrapper.c0000664000175000017500000005351412743410406025041 0ustar aboudreaultaboudreault00000000000000#include <Python.h> #include <ev.h> typedef struct libevwrapper_Loop { PyObject_HEAD struct ev_loop *loop; } libevwrapper_Loop; static void Loop_dealloc(libevwrapper_Loop *self) { ev_loop_destroy(self->loop); Py_TYPE(self)->tp_free((PyObject *)self); }; static PyObject* Loop_new(PyTypeObject *type, PyObject *args, PyObject *kwds) { libevwrapper_Loop *self; self = (libevwrapper_Loop *)type->tp_alloc(type, 0); if (self != NULL) { self->loop = ev_loop_new(EVBACKEND_SELECT); if 
(!self->loop) {
PyErr_SetString(PyExc_Exception, "Error getting new ev loop"); Py_DECREF(self); return NULL; } } return (PyObject *)self; }; static int Loop_init(libevwrapper_Loop *self, PyObject *args, PyObject *kwds) { if (!PyArg_ParseTuple(args, "")) { PyErr_SetString(PyExc_TypeError, "Loop.__init__() takes no arguments"); return -1; } return 0; }; static PyObject * Loop_start(libevwrapper_Loop *self, PyObject *args) { Py_BEGIN_ALLOW_THREADS ev_run(self->loop, 0); Py_END_ALLOW_THREADS Py_RETURN_NONE; }; static PyObject * Loop_unref(libevwrapper_Loop *self, PyObject *args) { ev_unref(self->loop); Py_RETURN_NONE; } static PyMethodDef Loop_methods[] = { {"start", (PyCFunction)Loop_start, METH_NOARGS, "Start the event loop"}, {"unref", (PyCFunction)Loop_unref, METH_NOARGS, "Unreference the event loop"}, {NULL} /* Sentinel */ }; static PyTypeObject libevwrapper_LoopType = { PyVarObject_HEAD_INIT(NULL, 0) "cassandra.io.libevwrapper.Loop",/*tp_name*/ sizeof(libevwrapper_Loop), /*tp_basicsize*/ 0, /*tp_itemsize*/ (destructor)Loop_dealloc, /*tp_dealloc*/ 0, /*tp_print*/ 0, /*tp_getattr*/ 0, /*tp_setattr*/ 0, /*tp_compare*/ 0, /*tp_repr*/ 0, /*tp_as_number*/ 0, /*tp_as_sequence*/ 0, /*tp_as_mapping*/ 0, /*tp_hash */ 0, /*tp_call*/ 0, /*tp_str*/ 0, /*tp_getattro*/ 0, /*tp_setattro*/ 0, /*tp_as_buffer*/ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /*tp_flags*/ "Loop objects", /* tp_doc */ 0, /* tp_traverse */ 0, /* tp_clear */ 0, /* tp_richcompare */ 0, /* tp_weaklistoffset */ 0, /* tp_iter */ 0, /* tp_iternext */ Loop_methods, /* tp_methods */ 0, /* tp_members */ 0, /* tp_getset */ 0, /* tp_base */ 0, /* tp_dict */ 0, /* tp_descr_get */ 0, /* tp_descr_set */ 0, /* tp_dictoffset */ (initproc)Loop_init, /* tp_init */ 0, /* tp_alloc */ Loop_new, /* tp_new */ }; typedef struct libevwrapper_IO { PyObject_HEAD struct ev_io io; struct libevwrapper_Loop *loop; PyObject *callback; } libevwrapper_IO; static void IO_dealloc(libevwrapper_IO *self) { Py_XDECREF(self->loop); Py_XDECREF(self->callback);
Py_TYPE(self)->tp_free((PyObject *)self); }; static void io_callback(struct ev_loop *loop, ev_io *watcher, int revents) { libevwrapper_IO *self = watcher->data; PyObject *result; PyGILState_STATE gstate = PyGILState_Ensure(); if (revents & EV_ERROR && errno) { result = PyObject_CallFunction(self->callback, "Obi", self, revents, errno); } else { result = PyObject_CallFunction(self->callback, "Ob", self, revents); } if (!result) { PyErr_WriteUnraisable(self->callback); } Py_XDECREF(result); PyGILState_Release(gstate); }; static int IO_init(libevwrapper_IO *self, PyObject *args, PyObject *kwds) { PyObject *socket; PyObject *callback; PyObject *loop; int io_flags = 0, fd = -1; struct ev_io *io = NULL; if (!PyArg_ParseTuple(args, "OiOO", &socket, &io_flags, &loop, &callback)) { return -1; } if (loop) { Py_INCREF(loop); self->loop = (libevwrapper_Loop *)loop; } if (callback) { if (!PyCallable_Check(callback)) { PyErr_SetString(PyExc_TypeError, "callback parameter must be callable"); /* self->loop is released by IO_dealloc; decrementing it here as well would release it twice */ return -1; } Py_INCREF(callback); self->callback = callback; } fd = PyObject_AsFileDescriptor(socket); if (fd == -1) { PyErr_SetString(PyExc_TypeError, "unable to get file descriptor from socket"); /* self->callback and self->loop are released by IO_dealloc */ return -1; } io = &(self->io); ev_io_init(io, io_callback, fd, io_flags); self->io.data = self; return 0; } static PyObject* IO_start(libevwrapper_IO *self, PyObject *args) { ev_io_start(self->loop->loop, &self->io); Py_RETURN_NONE; } static PyObject* IO_stop(libevwrapper_IO *self, PyObject *args) { ev_io_stop(self->loop->loop, &self->io); Py_RETURN_NONE; } static PyObject* IO_is_active(libevwrapper_IO *self, PyObject *args) { struct ev_io *io = &(self->io); return PyBool_FromLong(ev_is_active(io)); } static PyObject* IO_is_pending(libevwrapper_IO *self, PyObject *args) { struct ev_io *io = &(self->io); return PyBool_FromLong(ev_is_pending(io)); } 
static PyMethodDef IO_methods[] = { {"start", (PyCFunction)IO_start, METH_NOARGS, "Start the
watcher"}, {"stop", (PyCFunction)IO_stop, METH_NOARGS, "Stop the watcher"}, {"is_active", (PyCFunction)IO_is_active, METH_NOARGS, "Is the watcher active?"}, {"is_pending", (PyCFunction)IO_is_pending, METH_NOARGS, "Is the watcher pending?"}, {NULL} /* Sentinel */ }; static PyTypeObject libevwrapper_IOType = { PyVarObject_HEAD_INIT(NULL, 0) "cassandra.io.libevwrapper.IO", /*tp_name*/ sizeof(libevwrapper_IO), /*tp_basicsize*/ 0, /*tp_itemsize*/ (destructor)IO_dealloc, /*tp_dealloc*/ 0, /*tp_print*/ 0, /*tp_getattr*/ 0, /*tp_setattr*/ 0, /*tp_compare*/ 0, /*tp_repr*/ 0, /*tp_as_number*/ 0, /*tp_as_sequence*/ 0, /*tp_as_mapping*/ 0, /*tp_hash */ 0, /*tp_call*/ 0, /*tp_str*/ 0, /*tp_getattro*/ 0, /*tp_setattro*/ 0, /*tp_as_buffer*/ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /*tp_flags*/ "IO objects", /* tp_doc */ 0, /* tp_traverse */ 0, /* tp_clear */ 0, /* tp_richcompare */ 0, /* tp_weaklistoffset */ 0, /* tp_iter */ 0, /* tp_iternext */ IO_methods, /* tp_methods */ 0, /* tp_members */ 0, /* tp_getset */ 0, /* tp_base */ 0, /* tp_dict */ 0, /* tp_descr_get */ 0, /* tp_descr_set */ 0, /* tp_dictoffset */ (initproc)IO_init, /* tp_init */ }; typedef struct libevwrapper_Async { PyObject_HEAD struct ev_async async; struct libevwrapper_Loop *loop; } libevwrapper_Async; static void Async_dealloc(libevwrapper_Async *self) { Py_XDECREF(self->loop); Py_TYPE(self)->tp_free((PyObject *)self); }; static void async_callback(EV_P_ ev_async *watcher, int revents) {}; static int Async_init(libevwrapper_Async *self, PyObject *args, PyObject *kwds) { PyObject *loop; static char *kwlist[] = {"loop", NULL}; struct ev_async *async = NULL; if (!PyArg_ParseTupleAndKeywords(args, kwds, "O", kwlist, &loop)) { /* PyArg_ParseTupleAndKeywords has already set an appropriate TypeError */ return -1; } if (loop) { Py_INCREF(loop); self->loop = (libevwrapper_Loop *)loop; } else { return -1; } async = &(self->async); ev_async_init(async, async_callback); return 0; }; static PyObject *
Async_start(libevwrapper_Async *self, PyObject *args) { ev_async_start(self->loop->loop, &self->async); Py_RETURN_NONE; } static PyObject * Async_send(libevwrapper_Async *self, PyObject *args) { ev_async_send(self->loop->loop, &self->async); Py_RETURN_NONE; }; static PyMethodDef Async_methods[] = { {"start", (PyCFunction)Async_start, METH_NOARGS, "Start the watcher"}, {"send", (PyCFunction)Async_send, METH_NOARGS, "Notify the event loop"}, {NULL} /* Sentinel */ }; static PyTypeObject libevwrapper_AsyncType = { PyVarObject_HEAD_INIT(NULL, 0) "cassandra.io.libevwrapper.Async", /*tp_name*/ sizeof(libevwrapper_Async), /*tp_basicsize*/ 0, /*tp_itemsize*/ (destructor)Async_dealloc, /*tp_dealloc*/ 0, /*tp_print*/ 0, /*tp_getattr*/ 0, /*tp_setattr*/ 0, /*tp_compare*/ 0, /*tp_repr*/ 0, /*tp_as_number*/ 0, /*tp_as_sequence*/ 0, /*tp_as_mapping*/ 0, /*tp_hash */ 0, /*tp_call*/ 0, /*tp_str*/ 0, /*tp_getattro*/ 0, /*tp_setattro*/ 0, /*tp_as_buffer*/ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /*tp_flags*/ "Async objects", /* tp_doc */ 0, /* tp_traverse */ 0, /* tp_clear */ 0, /* tp_richcompare */ 0, /* tp_weaklistoffset */ 0, /* tp_iter */ 0, /* tp_iternext */ Async_methods, /* tp_methods */ 0, /* tp_members */ 0, /* tp_getset */ 0, /* tp_base */ 0, /* tp_dict */ 0, /* tp_descr_get */ 0, /* tp_descr_set */ 0, /* tp_dictoffset */ (initproc)Async_init, /* tp_init */ }; typedef struct libevwrapper_Prepare { PyObject_HEAD struct ev_prepare prepare; struct libevwrapper_Loop *loop; PyObject *callback; } libevwrapper_Prepare; static void Prepare_dealloc(libevwrapper_Prepare *self) { Py_XDECREF(self->loop); Py_XDECREF(self->callback); Py_TYPE(self)->tp_free((PyObject *)self); } static void prepare_callback(struct ev_loop *loop, ev_prepare *watcher, int revents) { libevwrapper_Prepare *self = watcher->data; PyObject *result = NULL; PyGILState_STATE gstate; gstate = PyGILState_Ensure(); result = PyObject_CallFunction(self->callback, "O", self); if (!result) { 
PyErr_WriteUnraisable(self->callback); } Py_XDECREF(result); PyGILState_Release(gstate); } static int Prepare_init(libevwrapper_Prepare *self, PyObject *args, PyObject *kwds) { PyObject *callback; PyObject *loop; struct ev_prepare *prepare = NULL; if (!PyArg_ParseTuple(args, "OO", &loop, &callback)) { return -1; } if (loop) { Py_INCREF(loop); self->loop = (libevwrapper_Loop *)loop; } else { return -1; } if (callback) { if (!PyCallable_Check(callback)) { PyErr_SetString(PyExc_TypeError, "callback parameter must be callable"); /* self->loop is released by Prepare_dealloc */ return -1; } Py_INCREF(callback); self->callback = callback; } prepare = &(self->prepare); ev_prepare_init(prepare, prepare_callback); self->prepare.data = self; return 0; } static PyObject * Prepare_start(libevwrapper_Prepare *self, PyObject *args) { ev_prepare_start(self->loop->loop, &self->prepare); Py_RETURN_NONE; } static PyObject * Prepare_stop(libevwrapper_Prepare *self, PyObject *args) { ev_prepare_stop(self->loop->loop, &self->prepare); Py_RETURN_NONE; } static PyMethodDef Prepare_methods[] = { {"start", (PyCFunction)Prepare_start, METH_NOARGS, "Start the Prepare watcher"}, {"stop", (PyCFunction)Prepare_stop, METH_NOARGS, "Stop the Prepare watcher"}, {NULL} /* Sentinel */ }; static PyTypeObject libevwrapper_PrepareType = { PyVarObject_HEAD_INIT(NULL, 0) "cassandra.io.libevwrapper.Prepare", /*tp_name*/ sizeof(libevwrapper_Prepare), /*tp_basicsize*/ 0, /*tp_itemsize*/ (destructor)Prepare_dealloc, /*tp_dealloc*/ 0, /*tp_print*/ 0, /*tp_getattr*/ 0, /*tp_setattr*/ 0, /*tp_compare*/ 0, /*tp_repr*/ 0, /*tp_as_number*/ 0, /*tp_as_sequence*/ 0, /*tp_as_mapping*/ 0, /*tp_hash */ 0, /*tp_call*/ 0, /*tp_str*/ 0, /*tp_getattro*/ 0, /*tp_setattro*/ 0, /*tp_as_buffer*/ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /*tp_flags*/ "Prepare objects", /* tp_doc */ 0, /* tp_traverse */ 0, /* tp_clear */ 0, /* tp_richcompare */ 0, /* tp_weaklistoffset */ 0, /* tp_iter */ 0, /* tp_iternext */ Prepare_methods, /* 
tp_methods */ 0, /* tp_members
*/ 0, /* tp_getset */ 0, /* tp_base */ 0, /* tp_dict */ 0, /* tp_descr_get */ 0, /* tp_descr_set */ 0, /* tp_dictoffset */ (initproc)Prepare_init, /* tp_init */ }; typedef struct libevwrapper_Timer { PyObject_HEAD struct ev_timer timer; struct libevwrapper_Loop *loop; PyObject *callback; } libevwrapper_Timer; static void Timer_dealloc(libevwrapper_Timer *self) { Py_XDECREF(self->loop); Py_XDECREF(self->callback); Py_TYPE(self)->tp_free((PyObject *)self); } static void timer_callback(struct ev_loop *loop, ev_timer *watcher, int revents) { libevwrapper_Timer *self = watcher->data; PyObject *result = NULL; PyGILState_STATE gstate; gstate = PyGILState_Ensure(); result = PyObject_CallFunction(self->callback, NULL); if (!result) { PyErr_WriteUnraisable(self->callback); } Py_XDECREF(result); PyGILState_Release(gstate); } static int Timer_init(libevwrapper_Timer *self, PyObject *args, PyObject *kwds) { PyObject *callback; PyObject *loop; if (!PyArg_ParseTuple(args, "OO", &loop, &callback)) { return -1; } if (loop) { Py_INCREF(loop); self->loop = (libevwrapper_Loop *)loop; } else { return -1; } if (callback) { if (!PyCallable_Check(callback)) { PyErr_SetString(PyExc_TypeError, "callback parameter must be callable"); /* self->loop is released by Timer_dealloc */ return -1; } Py_INCREF(callback); self->callback = callback; } ev_init(&self->timer, timer_callback); self->timer.data = self; return 0; } static PyObject * Timer_start(libevwrapper_Timer *self, PyObject *args) { double timeout; if (!PyArg_ParseTuple(args, "d", &timeout)) { return NULL; } /* some tiny non-zero number to avoid zero, and make it run immediately for negative timeouts */ self->timer.repeat = fmax(timeout, 0.000000001); ev_timer_again(self->loop->loop, &self->timer); Py_RETURN_NONE; } static PyObject * Timer_stop(libevwrapper_Timer *self, PyObject *args) { ev_timer_stop(self->loop->loop, &self->timer); Py_RETURN_NONE; } static PyMethodDef Timer_methods[] = { {"start", (PyCFunction)Timer_start, 
METH_VARARGS, "Start the Timer
watcher"}, {"stop", (PyCFunction)Timer_stop, METH_NOARGS, "Stop the Timer watcher"}, {NULL} /* Sentinal */ }; static PyTypeObject libevwrapper_TimerType = { PyVarObject_HEAD_INIT(NULL, 0) "cassandra.io.libevwrapper.Timer", /*tp_name*/ sizeof(libevwrapper_Timer), /*tp_basicsize*/ 0, /*tp_itemsize*/ (destructor)Timer_dealloc, /*tp_dealloc*/ 0, /*tp_print*/ 0, /*tp_getattr*/ 0, /*tp_setattr*/ 0, /*tp_compare*/ 0, /*tp_repr*/ 0, /*tp_as_number*/ 0, /*tp_as_sequence*/ 0, /*tp_as_mapping*/ 0, /*tp_hash */ 0, /*tp_call*/ 0, /*tp_str*/ 0, /*tp_getattro*/ 0, /*tp_setattro*/ 0, /*tp_as_buffer*/ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /*tp_flags*/ "Timer objects", /* tp_doc */ 0, /* tp_traverse */ 0, /* tp_clear */ 0, /* tp_richcompare */ 0, /* tp_weaklistoffset */ 0, /* tp_iter */ 0, /* tp_iternext */ Timer_methods, /* tp_methods */ 0, /* tp_members */ 0, /* tp_getset */ 0, /* tp_base */ 0, /* tp_dict */ 0, /* tp_descr_get */ 0, /* tp_descr_set */ 0, /* tp_dictoffset */ (initproc)Timer_init, /* tp_init */ }; static PyMethodDef module_methods[] = { {NULL} /* Sentinal */ }; PyDoc_STRVAR(module_doc, "libev wrapper methods"); #if PY_MAJOR_VERSION >= 3 static struct PyModuleDef moduledef = { PyModuleDef_HEAD_INIT, "libevwrapper", module_doc, -1, module_methods, NULL, NULL, NULL, NULL }; #define INITERROR return NULL PyObject * PyInit_libevwrapper(void) # else # define INITERROR return void initlibevwrapper(void) #endif { PyObject *module = NULL; if (PyType_Ready(&libevwrapper_LoopType) < 0) INITERROR; libevwrapper_IOType.tp_new = PyType_GenericNew; if (PyType_Ready(&libevwrapper_IOType) < 0) INITERROR; libevwrapper_PrepareType.tp_new = PyType_GenericNew; if (PyType_Ready(&libevwrapper_PrepareType) < 0) INITERROR; libevwrapper_AsyncType.tp_new = PyType_GenericNew; if (PyType_Ready(&libevwrapper_AsyncType) < 0) INITERROR; libevwrapper_TimerType.tp_new = PyType_GenericNew; if (PyType_Ready(&libevwrapper_TimerType) < 0) INITERROR; # if PY_MAJOR_VERSION >= 3 module = 
PyModule_Create(&moduledef); # else module = Py_InitModule3("libevwrapper", module_methods, module_doc); # endif if (module == NULL) INITERROR; if (PyModule_AddIntConstant(module, "EV_READ", EV_READ) == -1) INITERROR; if (PyModule_AddIntConstant(module, "EV_WRITE", EV_WRITE) == -1) INITERROR; if (PyModule_AddIntConstant(module, "EV_ERROR", EV_ERROR) == -1) INITERROR; Py_INCREF(&libevwrapper_LoopType); if (PyModule_AddObject(module, "Loop", (PyObject *)&libevwrapper_LoopType) == -1) INITERROR; Py_INCREF(&libevwrapper_IOType); if (PyModule_AddObject(module, "IO", (PyObject *)&libevwrapper_IOType) == -1) INITERROR; Py_INCREF(&libevwrapper_PrepareType); if (PyModule_AddObject(module, "Prepare", (PyObject *)&libevwrapper_PrepareType) == -1) INITERROR; Py_INCREF(&libevwrapper_AsyncType); if (PyModule_AddObject(module, "Async", (PyObject *)&libevwrapper_AsyncType) == -1) INITERROR; Py_INCREF(&libevwrapper_TimerType); if (PyModule_AddObject(module, "Timer", (PyObject *)&libevwrapper_TimerType) == -1) INITERROR; if (!PyEval_ThreadsInitialized()) { PyEval_InitThreads(); } #if PY_MAJOR_VERSION >= 3 return module; #endif } cassandra-driver-3.7.1/cassandra/io/eventletreactor.py0000664000175000017500000001157612766043657025774 0ustar aboudreaultaboudreault00000000000000# Copyright 2014 Symantec Corporation # Copyright 2013-2016 DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# Originally derived from MagnetoDB source: # https://github.com/stackforge/magnetodb/blob/2015.1.0b1/magnetodb/common/cassandra/io/eventletreactor.py import eventlet from eventlet.green import socket from eventlet.queue import Queue import logging from threading import Event import time from six.moves import xrange from cassandra.connection import Connection, ConnectionShutdown, Timer, TimerManager log = logging.getLogger(__name__) class EventletConnection(Connection): """ An implementation of :class:`.Connection` that utilizes ``eventlet``. This implementation assumes all eventlet monkey patching is active. It is not tested with partial patching. """ _read_watcher = None _write_watcher = None _socket_impl = eventlet.green.socket _ssl_impl = eventlet.green.ssl _timers = None _timeout_watcher = None _new_timer = None @classmethod def initialize_reactor(cls): eventlet.monkey_patch() if not cls._timers: cls._timers = TimerManager() cls._timeout_watcher = eventlet.spawn(cls.service_timeouts) cls._new_timer = Event() @classmethod def create_timer(cls, timeout, callback): timer = Timer(timeout, callback) cls._timers.add_timer(timer) cls._new_timer.set() return timer @classmethod def service_timeouts(cls): """ cls._timeout_watcher runs in this loop forever. It is usually waiting for the next timeout on the cls._new_timer Event. When new timers are added, that event is set so that the watcher can wake up and possibly set an earlier timeout. 
""" timer_manager = cls._timers while True: next_end = timer_manager.service_timeouts() sleep_time = max(next_end - time.time(), 0) if next_end else 10000 cls._new_timer.wait(sleep_time) cls._new_timer.clear() def __init__(self, *args, **kwargs): Connection.__init__(self, *args, **kwargs) self._write_queue = Queue() self._connect_socket() self._read_watcher = eventlet.spawn(lambda: self.handle_read()) self._write_watcher = eventlet.spawn(lambda: self.handle_write()) self._send_options_message() def close(self): with self.lock: if self.is_closed: return self.is_closed = True log.debug("Closing connection (%s) to %s" % (id(self), self.host)) cur_gthread = eventlet.getcurrent() if self._read_watcher and self._read_watcher != cur_gthread: self._read_watcher.kill() if self._write_watcher and self._write_watcher != cur_gthread: self._write_watcher.kill() if self._socket: self._socket.close() log.debug("Closed socket to %s" % (self.host,)) if not self.is_defunct: self.error_all_requests( ConnectionShutdown("Connection to %s was closed" % self.host)) # don't leave in-progress operations hanging self.connected_event.set() def handle_close(self): log.debug("connection closed by server") self.close() def handle_write(self): while True: try: next_msg = self._write_queue.get() self._socket.sendall(next_msg) except socket.error as err: log.debug("Exception during socket send for %s: %s", self, err) self.defunct(err) return # Leave the write loop def handle_read(self): while True: try: buf = self._socket.recv(self.in_buffer_size) self._iobuf.write(buf) except socket.error as err: log.debug("Exception during socket recv for %s: %s", self, err) self.defunct(err) return # leave the read loop if self._iobuf.tell(): self.process_io_buffer() else: log.debug("Connection %s closed by server", self) self.close() return def push(self, data): chunk_size = self.out_buffer_size for i in xrange(0, len(data), chunk_size): self._write_queue.put(data[i:i + chunk_size]) 
cassandra-driver-3.7.1/cassandra/io/__init__.py0000664000175000017500000000110412743410406024270 0ustar aboudreaultaboudreault00000000000000# Copyright 2013-2016 DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. cassandra-driver-3.7.1/cassandra/io/twistedreactor.py0000664000175000017500000001747412743410406025615 0ustar aboudreaultaboudreault00000000000000# Copyright 2013-2016 DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ Module that implements an event loop based on twisted ( https://twistedmatrix.com ). 
""" import atexit from functools import partial import logging from threading import Thread, Lock import time from twisted.internet import reactor, protocol import weakref from cassandra.connection import Connection, ConnectionShutdown, Timer, TimerManager log = logging.getLogger(__name__) def _cleanup(cleanup_weakref): # calling a dead weakref.ref returns None; it does not raise ReferenceError loop = cleanup_weakref() if loop is not None: loop._cleanup() class TwistedConnectionProtocol(protocol.Protocol): """ Twisted Protocol class for handling data received and connection made events. """ def dataReceived(self, data): """ Callback function that is called when data has been received on the connection. Reaches back to the Connection object and queues the data for processing. """ self.transport.connector.factory.conn._iobuf.write(data) self.transport.connector.factory.conn.handle_read() def connectionMade(self): """ Callback function that is called when a connection has succeeded. Reaches back to the Connection object and confirms that the connection is ready. """ self.transport.connector.factory.conn.client_connection_made() def connectionLost(self, reason): # reason is a Failure instance self.transport.connector.factory.conn.defunct(reason.value) class TwistedConnectionClientFactory(protocol.ClientFactory): def __init__(self, connection): # ClientFactory does not define __init__() in parent classes # and does not inherit from object. self.conn = connection def buildProtocol(self, addr): """ Twisted function that defines which kind of protocol to use in the ClientFactory. """ return TwistedConnectionProtocol() def clientConnectionFailed(self, connector, reason): """ Overridden twisted callback which is called when the connection attempt fails. """ log.debug("Connect failed: %s", reason) self.conn.defunct(reason.value) def clientConnectionLost(self, connector, reason): """ Overridden twisted callback which is called when the connection goes away (cleanly or otherwise).
It should be safe to call defunct() here instead of just close, because we can assume that if the connection was closed cleanly, there are no requests to error out. If this assumption turns out to be false, we can call close() instead of defunct() when "reason" is an appropriate type. """ log.debug("Connect lost: %s", reason) self.conn.defunct(reason.value) class TwistedLoop(object): _lock = None _thread = None _timeout_task = None _timeout = None def __init__(self): self._lock = Lock() self._timers = TimerManager() def maybe_start(self): with self._lock: if not reactor.running: self._thread = Thread(target=reactor.run, name="cassandra_driver_event_loop", kwargs={'installSignalHandlers': False}) self._thread.daemon = True self._thread.start() atexit.register(partial(_cleanup, weakref.ref(self))) def _cleanup(self): if self._thread: reactor.callFromThread(reactor.stop) self._thread.join(timeout=1.0) if self._thread.is_alive(): log.warning("Event loop thread could not be joined, so " "shutdown may not be clean. Please call " "Cluster.shutdown() to avoid this.") log.debug("Event loop thread was joined") def add_timer(self, timer): self._timers.add_timer(timer) # callFromThread to schedule from the loop thread, where # the timeout task can safely be modified reactor.callFromThread(self._schedule_timeout, timer.end) def _schedule_timeout(self, next_timeout): if next_timeout: delay = max(next_timeout - time.time(), 0) if self._timeout_task and self._timeout_task.active(): if next_timeout < self._timeout: self._timeout_task.reset(delay) self._timeout = next_timeout else: self._timeout_task = reactor.callLater(delay, self._on_loop_timer) self._timeout = next_timeout def _on_loop_timer(self): self._timers.service_timeouts() self._schedule_timeout(self._timers.next_timeout) class TwistedConnection(Connection): """ An implementation of :class:`.Connection` that utilizes the Twisted event loop. 
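`TwistedLoop` keeps a single pending `reactor.callLater` wakeup. `_schedule_timeout` resets it only when a new timer ends sooner, and `_on_loop_timer` re-arms it for the next deadline after servicing expired timers. The sketch below models that policy without Twisted. It assumes a heap of deadlines in place of the driver's `TimerManager`, whose internals are not shown here. `SingleWakeupScheduler`, `add_timer`, `on_wakeup` and `armed_at` are hypothetical names, and `armed_at` stands in for the deadline of the pending `callLater`.

```python
import heapq
import itertools


class SingleWakeupScheduler(object):
    """Hypothetical sketch: one armed wakeup, always at the earliest deadline."""

    def __init__(self):
        self._heap = []                # entries are (deadline, seq, callback)
        self._seq = itertools.count()  # tie-breaker so callbacks are never compared
        self.armed_at = None           # deadline of the single pending wakeup

    def add_timer(self, deadline, callback):
        heapq.heappush(self._heap, (deadline, next(self._seq), callback))
        # Only pull the wakeup earlier; a later timer rides on the existing one.
        if self.armed_at is None or deadline < self.armed_at:
            self.armed_at = deadline

    def on_wakeup(self, now):
        # Fire every timer whose deadline has passed...
        while self._heap and self._heap[0][0] <= now:
            _, _, callback = heapq.heappop(self._heap)
            callback()
        # ...then re-arm for the next pending deadline, if any.
        self.armed_at = self._heap[0][0] if self._heap else None


fired = []
sched = SingleWakeupScheduler()
sched.add_timer(5.0, lambda: fired.append("a"))
sched.add_timer(2.0, lambda: fired.append("b"))
sched.add_timer(9.0, lambda: fired.append("c"))
sched.on_wakeup(now=5.0)  # fires "b" then "a"; the wakeup is re-armed at 9.0
```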
""" _loop = None @classmethod def initialize_reactor(cls): if not cls._loop: cls._loop = TwistedLoop() @classmethod def create_timer(cls, timeout, callback): timer = Timer(timeout, callback) cls._loop.add_timer(timer) return timer def __init__(self, *args, **kwargs): """ Initialization method. Note that we can't call reactor methods directly here because it's not thread-safe, so we schedule the reactor/connection stuff to be run from the event loop thread when it gets the chance. """ Connection.__init__(self, *args, **kwargs) self.is_closed = True self.connector = None reactor.callFromThread(self.add_connection) self._loop.maybe_start() def add_connection(self): """ Convenience function to connect and store the resulting connector. """ self.connector = reactor.connectTCP( host=self.host, port=self.port, factory=TwistedConnectionClientFactory(self), timeout=self.connect_timeout) def client_connection_made(self): """ Called by twisted protocol when a connection attempt has succeeded. """ with self.lock: self.is_closed = False self._send_options_message() def close(self): """ Disconnect and error-out all requests. """ with self.lock: if self.is_closed: return self.is_closed = True log.debug("Closing connection (%s) to %s", id(self), self.host) self.connector.disconnect() log.debug("Closed socket to %s", self.host) if not self.is_defunct: self.error_all_requests( ConnectionShutdown("Connection to %s was closed" % self.host)) # don't leave in-progress operations hanging self.connected_event.set() def handle_read(self): """ Process the incoming data buffer. """ self.process_io_buffer() def push(self, data): """ This function is called when outgoing data should be queued for sending. Note that we can't call transport.write() directly because it is not thread-safe, so we schedule it to run from within the event loop when it gets the chance. 
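`push` hands `transport.write` to `reactor.callFromThread` because the transport may only be touched from the reactor thread. The same hand-off can be sketched with a plain queue that a single loop thread drains. `LoopThread`, `call_from_thread` and `transport_write` are hypothetical stand-ins, not Twisted's implementation.

```python
import queue
import threading


class LoopThread(object):
    """Hypothetical stand-in for a reactor thread; not Twisted's API."""

    def __init__(self):
        self._calls = queue.Queue()
        self._thread = threading.Thread(target=self._run, name="demo_event_loop")
        self._thread.daemon = True
        self._thread.start()

    def call_from_thread(self, fn, *args):
        # Safe from any thread: only the thread-safe queue is touched here.
        self._calls.put((fn, args))

    def stop(self):
        self._calls.put(None)  # sentinel queued after all earlier calls
        self._thread.join()

    def _run(self):
        while True:
            item = self._calls.get()
            if item is None:
                return
            fn, args = item
            fn(*args)  # always executes on the loop thread


written = []         # plays the role of a non-thread-safe transport buffer
writer_threads = []  # names of the threads that touched the "transport"


def transport_write(data):
    writer_threads.append(threading.current_thread().name)
    written.append(data)


loop = LoopThread()
producers = [
    threading.Thread(target=loop.call_from_thread, args=(transport_write, b"msg%d" % i))
    for i in range(4)
]
for t in producers:
    t.start()
for t in producers:
    t.join()
loop.stop()
```

Every write runs on `demo_event_loop`, whichever thread called `call_from_thread`.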
""" reactor.callFromThread(self.connector.transport.write, data) cassandra-driver-3.7.1/cassandra/io/geventreactor.py0000664000175000017500000001030712766043657025425 0ustar aboudreaultaboudreault00000000000000# Copyright 2013-2016 DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import gevent import gevent.event from gevent.queue import Queue from gevent import socket import gevent.ssl import logging import time from six.moves import range from cassandra.connection import Connection, ConnectionShutdown, Timer, TimerManager log = logging.getLogger(__name__) class GeventConnection(Connection): """ An implementation of :class:`.Connection` that utilizes ``gevent``. This implementation assumes all gevent monkey patching is active. It is not tested with partial patching. 
""" _read_watcher = None _write_watcher = None _socket_impl = gevent.socket _ssl_impl = gevent.ssl _timers = None _timeout_watcher = None _new_timer = None @classmethod def initialize_reactor(cls): if not cls._timers: cls._timers = TimerManager() cls._timeout_watcher = gevent.spawn(cls.service_timeouts) cls._new_timer = gevent.event.Event() @classmethod def create_timer(cls, timeout, callback): timer = Timer(timeout, callback) cls._timers.add_timer(timer) cls._new_timer.set() return timer @classmethod def service_timeouts(cls): timer_manager = cls._timers timer_event = cls._new_timer while True: next_end = timer_manager.service_timeouts() sleep_time = max(next_end - time.time(), 0) if next_end else 10000 timer_event.wait(sleep_time) timer_event.clear() def __init__(self, *args, **kwargs): Connection.__init__(self, *args, **kwargs) self._write_queue = Queue() self._connect_socket() self._read_watcher = gevent.spawn(self.handle_read) self._write_watcher = gevent.spawn(self.handle_write) self._send_options_message() def close(self): with self.lock: if self.is_closed: return self.is_closed = True log.debug("Closing connection (%s) to %s" % (id(self), self.host)) if self._read_watcher: self._read_watcher.kill(block=False) if self._write_watcher: self._write_watcher.kill(block=False) if self._socket: self._socket.close() log.debug("Closed socket to %s" % (self.host,)) if not self.is_defunct: self.error_all_requests( ConnectionShutdown("Connection to %s was closed" % self.host)) # don't leave in-progress operations hanging self.connected_event.set() def handle_close(self): log.debug("connection closed by server") self.close() def handle_write(self): while True: try: next_msg = self._write_queue.get() self._socket.sendall(next_msg) except socket.error as err: log.debug("Exception in send for %s: %s", self, err) self.defunct(err) return def handle_read(self): while True: try: buf = self._socket.recv(self.in_buffer_size) self._iobuf.write(buf) except socket.error as err: 
log.debug("Exception in read for %s: %s", self, err) self.defunct(err) return # leave the read loop if self._iobuf.tell(): self.process_io_buffer() else: log.debug("Connection %s closed by server", self) self.close() return def push(self, data): chunk_size = self.out_buffer_size for i in range(0, len(data), chunk_size): self._write_queue.put(data[i:i + chunk_size]) cassandra-driver-3.7.1/cassandra/io/asyncorereactor.py0000664000175000017500000002713412766043657025766 0ustar aboudreaultaboudreault00000000000000# Copyright 2013-2016 DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import atexit from collections import deque from functools import partial import logging import os import socket import sys from threading import Lock, Thread import time import weakref from six.moves import range try: from weakref import WeakSet except ImportError: from cassandra.util import WeakSet # noqa import asyncore try: import ssl except ImportError: ssl = None # NOQA from cassandra.connection import Connection, ConnectionShutdown, NONBLOCKING, Timer, TimerManager log = logging.getLogger(__name__) _dispatcher_map = {} def _cleanup(loop_weakref): loop = loop_weakref() if loop is not None: loop._cleanup() class _PipeWrapper(object): def __init__(self, fd): self.fd = fd def fileno(self): return self.fd def close(self): os.close(self.fd) def getsockopt(self, level, optname, buflen=None): # act like an unerrored socket for the asyncore error handling if level == socket.SOL_SOCKET and optname == socket.SO_ERROR and not buflen: return 0 raise NotImplementedError() class _AsyncoreDispatcher(asyncore.dispatcher): def __init__(self, socket): asyncore.dispatcher.__init__(self, map=_dispatcher_map) # inject after to avoid base class validation self.set_socket(socket) self._notified = False def writable(self): return False def validate(self): assert not self._notified self.notify_loop() assert self._notified self.loop(0.1) assert not self._notified def loop(self, timeout): asyncore.loop(timeout=timeout, use_poll=True, map=_dispatcher_map, count=1) class _AsyncorePipeDispatcher(_AsyncoreDispatcher): def __init__(self): self.read_fd, self.write_fd = os.pipe() _AsyncoreDispatcher.__init__(self, _PipeWrapper(self.read_fd)) def writable(self): return False def handle_read(self): while len(os.read(self.read_fd, 4096)) == 4096: pass self._notified = False def notify_loop(self): if not self._notified: self._notified = True os.write(self.write_fd, b'x') class _AsyncoreUDPDispatcher(_AsyncoreDispatcher): """ Experimental alternate dispatcher for avoiding busy wait
in the asyncore loop. It is not used by default because it relies on local port binding. Port scanning is not implemented, so multiple clients on one host will collide. This address would need to be set per instance, or this could be specialized to scan until an address is found. To use:: from cassandra.io.asyncorereactor import _AsyncoreUDPDispatcher, AsyncoreLoop AsyncoreLoop._loop_dispatch_class = _AsyncoreUDPDispatcher """ bind_address = ('localhost', 10000) def __init__(self): self._socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) self._socket.bind(self.bind_address) self._socket.setblocking(0) _AsyncoreDispatcher.__init__(self, self._socket) def handle_read(self): try: d = self._socket.recvfrom(1) while d and d[1]: d = self._socket.recvfrom(1) except socket.error as e: pass self._notified = False def notify_loop(self): if not self._notified: self._notified = True self._socket.sendto(b'', self.bind_address) def loop(self, timeout): asyncore.loop(timeout=timeout, use_poll=False, map=_dispatcher_map, count=1) class _BusyWaitDispatcher(object): max_write_latency = 0.001 """ Timeout pushed down to asyncore select/poll. Dictates the amount of time it will sleep before coming back to check if anything is writable. 
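Unlike this busy-wait dispatcher, the default `_AsyncorePipeDispatcher` blocks in `poll`. `notify_loop` wakes it by writing one byte to a pipe, and `handle_read` then drains the pipe (the self-pipe trick). Below is a POSIX-only sketch of that mechanism using `select.select`. `PipeWaker` is a hypothetical name, and the sketch leaves out asyncore entirely.

```python
import os
import select
import threading
import time


class PipeWaker(object):
    """Hypothetical self-pipe waker (POSIX only), mirroring notify_loop/handle_read."""

    def __init__(self):
        self.read_fd, self.write_fd = os.pipe()
        self._notified = False

    def notify(self):
        # At most one byte per wakeup, like _AsyncorePipeDispatcher.notify_loop.
        if not self._notified:
            self._notified = True
            os.write(self.write_fd, b"x")

    def drain(self):
        # Empty the pipe so the next select() blocks again.
        while len(os.read(self.read_fd, 4096)) == 4096:
            pass
        self._notified = False

    def wait(self, timeout):
        """Block until notified or until timeout; return True if notify() woke us."""
        readable, _, _ = select.select([self.read_fd], [], [], timeout)
        if readable:
            self.drain()
            return True
        return False

    def close(self):
        os.close(self.read_fd)
        os.close(self.write_fd)


waker = PipeWaker()
start = time.time()
threading.Timer(0.05, waker.notify).start()  # another thread wakes the loop
woken = waker.wait(timeout=5.0)              # returns after about 0.05 s, not 5 s
elapsed = time.time() - start
timed_out = not waker.wait(timeout=0.01)     # pipe drained, so this times out
waker.close()
```

The loop sleeps in the kernel until there is work, so no `max_write_latency` polling interval is needed.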
""" def notify_loop(self): pass def loop(self, timeout): if not _dispatcher_map: time.sleep(0.005) count = timeout // self.max_write_latency asyncore.loop(timeout=self.max_write_latency, use_poll=True, map=_dispatcher_map, count=count) def validate(self): pass def close(self): pass class AsyncoreLoop(object): timer_resolution = 0.1 # used as the max interval to be in the io loop before returning to service timeouts _loop_dispatch_class = _AsyncorePipeDispatcher if os.name != 'nt' else _BusyWaitDispatcher def __init__(self): self._pid = os.getpid() self._loop_lock = Lock() self._started = False self._shutdown = False self._thread = None self._timers = TimerManager() try: dispatcher = self._loop_dispatch_class() dispatcher.validate() log.debug("Validated loop dispatch with %s", self._loop_dispatch_class) except Exception: log.exception("Failed validating loop dispatch with %s. Using busy wait execution instead.", self._loop_dispatch_class) dispatcher.close() dispatcher = _BusyWaitDispatcher() self._loop_dispatcher = dispatcher atexit.register(partial(_cleanup, weakref.ref(self))) def maybe_start(self): should_start = False did_acquire = False try: did_acquire = self._loop_lock.acquire(False) if did_acquire and not self._started: self._started = True should_start = True finally: if did_acquire: self._loop_lock.release() if should_start: self._thread = Thread(target=self._run_loop, name="cassandra_driver_event_loop") self._thread.daemon = True self._thread.start() def wake_loop(self): self._loop_dispatcher.notify_loop() def _run_loop(self): log.debug("Starting asyncore event loop") with self._loop_lock: while not self._shutdown: try: self._loop_dispatcher.loop(self.timer_resolution) self._timers.service_timeouts() except Exception: log.debug("Asyncore event loop stopped unexpectedly", exc_info=True) break self._started = False log.debug("Asyncore event loop ended") def add_timer(self, timer): self._timers.add_timer(timer) def _cleanup(self): self._shutdown = True if
not self._thread: return log.debug("Waiting for event loop thread to join...") self._thread.join(timeout=1.0) if self._thread.is_alive(): log.warning( "Event loop thread could not be joined, so shutdown may not be clean. " "Please call Cluster.shutdown() to avoid this.") log.debug("Event loop thread was joined") class AsyncoreConnection(Connection, asyncore.dispatcher): """ An implementation of :class:`.Connection` that uses the ``asyncore`` module in the Python standard library for its event loop. """ _loop = None _writable = False _readable = False @classmethod def initialize_reactor(cls): if not cls._loop: cls._loop = AsyncoreLoop() else: current_pid = os.getpid() if cls._loop._pid != current_pid: log.debug("Detected fork, clearing and reinitializing reactor state") cls.handle_fork() cls._loop = AsyncoreLoop() @classmethod def handle_fork(cls): global _dispatcher_map _dispatcher_map = {} if cls._loop: cls._loop._cleanup() cls._loop = None @classmethod def create_timer(cls, timeout, callback): timer = Timer(timeout, callback) cls._loop.add_timer(timer) return timer def __init__(self, *args, **kwargs): Connection.__init__(self, *args, **kwargs) self.deque = deque() self.deque_lock = Lock() self._connect_socket() asyncore.dispatcher.__init__(self, self._socket, _dispatcher_map) self._writable = True self._readable = True self._send_options_message() # start the event loop if needed self._loop.maybe_start() def close(self): with self.lock: if self.is_closed: return self.is_closed = True log.debug("Closing connection (%s) to %s", id(self), self.host) self._writable = False self._readable = False asyncore.dispatcher.close(self) log.debug("Closed socket to %s", self.host) if not self.is_defunct: self.error_all_requests( ConnectionShutdown("Connection to %s was closed" % self.host)) # don't leave in-progress operations hanging self.connected_event.set() def handle_error(self): self.defunct(sys.exc_info()[1]) def handle_close(self): log.debug("Connection %s closed by 
server", self) self.close() def handle_write(self): while True: with self.deque_lock: try: next_msg = self.deque.popleft() except IndexError: self._writable = False return try: sent = self.send(next_msg) self._readable = True except socket.error as err: if (err.args[0] in NONBLOCKING): with self.deque_lock: self.deque.appendleft(next_msg) else: self.defunct(err) return else: if sent < len(next_msg): with self.deque_lock: self.deque.appendleft(next_msg[sent:]) if sent == 0: return def handle_read(self): try: while True: buf = self.recv(self.in_buffer_size) self._iobuf.write(buf) if len(buf) < self.in_buffer_size: break except socket.error as err: if ssl and isinstance(err, ssl.SSLError): if err.args[0] not in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE): self.defunct(err) return elif err.args[0] not in NONBLOCKING: self.defunct(err) return if self._iobuf.tell(): self.process_io_buffer() if not self._requests and not self.is_control_connection: self._readable = False def push(self, data): sabs = self.out_buffer_size if len(data) > sabs: chunks = [] for i in range(0, len(data), sabs): chunks.append(data[i:i + sabs]) else: chunks = [data] with self.deque_lock: self.deque.extend(chunks) self._writable = True self._loop.wake_loop() def writable(self): return self._writable def readable(self): return self._readable or (self.is_control_connection and not (self.is_defunct or self.is_closed)) cassandra-driver-3.7.1/cassandra/murmur3.py0000664000175000017500000000452312743410406023544 0ustar aboudreaultaboudreault00000000000000from six.moves import range import struct def body_and_tail(data): l = len(data) nblocks = l // 16 tail = l % 16 if nblocks: return struct.unpack_from('qq' * nblocks, data), struct.unpack_from('b' * tail, data, -tail), l else: return tuple(), struct.unpack_from('b' * tail, data, -tail), l def rotl64(x, r): # note: not a general-purpose function because it leaves the high-order bits intact # suitable for this use case without wasting cycles mask = 
2 ** r - 1 rotated = (x << r) | ((x >> 64 - r) & mask) return rotated def fmix(k): # keep only the low 31 bits after >> 33, i.e. an unsigned (logical) right shift of a 64-bit number k ^= (k >> 33) & 0x7fffffff k *= 0xff51afd7ed558ccd k ^= (k >> 33) & 0x7fffffff k *= 0xc4ceb9fe1a85ec53 k ^= (k >> 33) & 0x7fffffff return k INT64_MAX = int(2 ** 63 - 1) INT64_MIN = -INT64_MAX - 1 INT64_OVF_OFFSET = INT64_MAX + 1 INT64_OVF_DIV = 2 * INT64_OVF_OFFSET def truncate_int64(x): if not INT64_MIN <= x <= INT64_MAX: x = (x + INT64_OVF_OFFSET) % INT64_OVF_DIV - INT64_OVF_OFFSET return x def _murmur3(data): h1 = h2 = 0 c1 = -8663945395140668459 # 0x87c37b91114253d5 c2 = 0x4cf5ad432745937f body, tail, total_len = body_and_tail(data) # body for i in range(0, len(body), 2): k1 = body[i] k2 = body[i + 1] k1 *= c1 k1 = rotl64(k1, 31) k1 *= c2 h1 ^= k1 h1 = rotl64(h1, 27) h1 += h2 h1 = h1 * 5 + 0x52dce729 k2 *= c2 k2 = rotl64(k2, 33) k2 *= c1 h2 ^= k2 h2 = rotl64(h2, 31) h2 += h1 h2 = h2 * 5 + 0x38495ab5 # tail k1 = k2 = 0 len_tail = len(tail) if len_tail > 8: for i in range(len_tail - 1, 7, -1): k2 ^= tail[i] << (i - 8) * 8 k2 *= c2 k2 = rotl64(k2, 33) k2 *= c1 h2 ^= k2 if len_tail: for i in range(min(7, len_tail - 1), -1, -1): k1 ^= tail[i] << i * 8 k1 *= c1 k1 = rotl64(k1, 31) k1 *= c2 h1 ^= k1 # finalization h1 ^= total_len h2 ^= total_len h1 += h2 h2 += h1 h1 = fmix(h1) h2 = fmix(h2) h1 += h2 return truncate_int64(h1) try: from cassandra.cmurmur3 import murmur3 except ImportError: murmur3 = _murmur3 cassandra-driver-3.7.1/cassandra/cmurmur3.c0000664000175000017500000001457612755630214023516 0ustar aboudreaultaboudreault00000000000000/* * The majority of this code was taken from the python-smhasher library, * which can be found here: https://github.com/phensley/python-smhasher * * That library is under the MIT license with the following copyright: * * Copyright (c) 2011 Austin Appleby (Murmur3 routine) * Copyright (c) 2011 Patrick Hensley (Python wrapper, packaging) * Copyright 2013-2016 DataStax (Minor
modifications to match Cassandra's MM3 hashes) * */ #define PY_SSIZE_T_CLEAN 1 #include <Python.h> #include <stdio.h> #if PY_VERSION_HEX < 0x02050000 typedef int Py_ssize_t; #define PY_SSIZE_T_MAX INT_MAX #define PY_SSIZE_T_MIN INT_MIN #endif #ifdef PYPY_VERSION #define COMPILING_IN_PYPY 1 #define COMPILING_IN_CPYTHON 0 #else #define COMPILING_IN_PYPY 0 #define COMPILING_IN_CPYTHON 1 #endif //----------------------------------------------------------------------------- // Platform-specific functions and macros // Microsoft Visual Studio #if defined(_MSC_VER) typedef unsigned char uint8_t; typedef unsigned long uint32_t; typedef unsigned __int64 uint64_t; typedef char int8_t; typedef long int32_t; typedef __int64 int64_t; #define FORCE_INLINE __forceinline #include <stdlib.h> #define ROTL32(x,y) _rotl(x,y) #define ROTL64(x,y) _rotl64(x,y) #define BIG_CONSTANT(x) (x) // Other compilers #else // defined(_MSC_VER) #include <stdint.h> #define FORCE_INLINE inline __attribute__((always_inline)) inline uint32_t rotl32 ( int32_t x, int8_t r ) { // cast to unsigned for logical right bitshift (to match C* MM3 implementation) return (x << r) | ((int32_t) (((uint32_t) x) >> (32 - r))); } inline int64_t rotl64 ( int64_t x, int8_t r ) { // cast to unsigned for logical right bitshift (to match C* MM3 implementation) return (x << r) | ((int64_t) (((uint64_t) x) >> (64 - r))); } #define ROTL32(x,y) rotl32(x,y) #define ROTL64(x,y) rotl64(x,y) #define BIG_CONSTANT(x) (x##LL) #endif // !defined(_MSC_VER) //----------------------------------------------------------------------------- // Block read - if your platform needs to do endian-swapping or can only // handle aligned reads, do the conversion here // TODO 32bit?
FORCE_INLINE int64_t getblock ( const int64_t * p, int i ) { return p[i]; } //----------------------------------------------------------------------------- // Finalization mix - force all bits of a hash block to avalanche FORCE_INLINE int64_t fmix ( int64_t k ) { // cast to unsigned for logical right bitshift (to match C* MM3 implementation) k ^= ((uint64_t) k) >> 33; k *= BIG_CONSTANT(0xff51afd7ed558ccd); k ^= ((uint64_t) k) >> 33; k *= BIG_CONSTANT(0xc4ceb9fe1a85ec53); k ^= ((uint64_t) k) >> 33; return k; } int64_t MurmurHash3_x64_128 (const void * key, const int len, const uint32_t seed) { const int8_t * data = (const int8_t*)key; const int nblocks = len / 16; int64_t h1 = seed; int64_t h2 = seed; int64_t c1 = BIG_CONSTANT(0x87c37b91114253d5); int64_t c2 = BIG_CONSTANT(0x4cf5ad432745937f); int64_t k1 = 0; int64_t k2 = 0; const int64_t * blocks = (const int64_t *)(data); const int8_t * tail = (const int8_t*)(data + nblocks*16); //---------- // body int i; for(i = 0; i < nblocks; i++) { int64_t k1 = getblock(blocks,i*2+0); int64_t k2 = getblock(blocks,i*2+1); k1 *= c1; k1 = ROTL64(k1,31); k1 *= c2; h1 ^= k1; h1 = ROTL64(h1,27); h1 += h2; h1 = h1*5+0x52dce729; k2 *= c2; k2 = ROTL64(k2,33); k2 *= c1; h2 ^= k2; h2 = ROTL64(h2,31); h2 += h1; h2 = h2*5+0x38495ab5; } //---------- // tail switch(len & 15) { case 15: k2 ^= ((int64_t) (tail[14])) << 48; case 14: k2 ^= ((int64_t) (tail[13])) << 40; case 13: k2 ^= ((int64_t) (tail[12])) << 32; case 12: k2 ^= ((int64_t) (tail[11])) << 24; case 11: k2 ^= ((int64_t) (tail[10])) << 16; case 10: k2 ^= ((int64_t) (tail[ 9])) << 8; case 9: k2 ^= ((int64_t) (tail[ 8])) << 0; k2 *= c2; k2 = ROTL64(k2,33); k2 *= c1; h2 ^= k2; case 8: k1 ^= ((int64_t) (tail[ 7])) << 56; case 7: k1 ^= ((int64_t) (tail[ 6])) << 48; case 6: k1 ^= ((int64_t) (tail[ 5])) << 40; case 5: k1 ^= ((int64_t) (tail[ 4])) << 32; case 4: k1 ^= ((int64_t) (tail[ 3])) << 24; case 3: k1 ^= ((int64_t) (tail[ 2])) << 16; case 2: k1 ^= ((int64_t) (tail[ 1])) << 8; case 1: 
k1 ^= ((int64_t) (tail[ 0])) << 0; k1 *= c1; k1 = ROTL64(k1,31); k1 *= c2; h1 ^= k1; }; //---------- // finalization h1 ^= len; h2 ^= len; h1 += h2; h2 += h1; h1 = fmix(h1); h2 = fmix(h2); h1 += h2; h2 += h1; return h1; } struct module_state { PyObject *error; }; // pypy3 doesn't have GetState yet. #if COMPILING_IN_CPYTHON && PY_MAJOR_VERSION >= 3 #define GETSTATE(m) ((struct module_state*)PyModule_GetState(m)) #else #define GETSTATE(m) (&_state) static struct module_state _state; #endif static PyObject * murmur3(PyObject *self, PyObject *args) { const char *key; Py_ssize_t len; uint32_t seed = 0; int64_t result = 0; if (!PyArg_ParseTuple(args, "s#|I", &key, &len, &seed)) { return NULL; } // TODO handle x86 version? result = MurmurHash3_x64_128((void *)key, len, seed); return (PyObject *) PyLong_FromLongLong(result); } static PyMethodDef cmurmur3_methods[] = { {"murmur3", murmur3, METH_VARARGS, "Make an x64 murmur3 64-bit hash value"}, {NULL, NULL, 0, NULL} }; #if PY_MAJOR_VERSION >= 3 static int cmurmur3_traverse(PyObject *m, visitproc visit, void *arg) { Py_VISIT(GETSTATE(m)->error); return 0; } static int cmurmur3_clear(PyObject *m) { Py_CLEAR(GETSTATE(m)->error); return 0; } static struct PyModuleDef moduledef = { PyModuleDef_HEAD_INIT, "cmurmur3", NULL, sizeof(struct module_state), cmurmur3_methods, NULL, cmurmur3_traverse, cmurmur3_clear, NULL }; #define INITERROR return NULL PyObject * PyInit_cmurmur3(void) #else #define INITERROR return void initcmurmur3(void) #endif { #if PY_MAJOR_VERSION >= 3 PyObject *module = PyModule_Create(&moduledef); #else PyObject *module = Py_InitModule("cmurmur3", cmurmur3_methods); #endif struct module_state *st = NULL; if (module == NULL) INITERROR; st = GETSTATE(module); st->error = PyErr_NewException("cmurmur3.Error", NULL, NULL); if (st->error == NULL) { Py_DECREF(module); INITERROR; } #if PY_MAJOR_VERSION >= 3 return module; #endif } 
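`_murmur3` in `murmur3.py` computes on unbounded Python ints, whereas the C version relies on `int64_t`/`uint64_t` arithmetic. Three shortcuts make this work. `fmix` emulates the C code's unsigned `>> 33` with `(k >> 33) & 0x7fffffff`. `rotl64` leaves stray bits above bit 63. `truncate_int64` wraps the final value into the signed 64-bit range. The sketch below checks on random inputs that these shortcuts agree with exact unsigned 64-bit arithmetic once both sides are reduced modulo 2**64. `fmix_masked`, `rotl64_masked` and `truncate_int64` copy the logic of `murmur3.py`. `u64`, `fmix_exact` and `rotl64_exact` are hypothetical reference helpers written for this check.

```python
import random

MASK64 = (1 << 64) - 1


def u64(x):
    # Reduce any Python int to its unsigned 64-bit two's-complement value.
    return x & MASK64


def fmix_masked(k):
    # Same steps as murmur3.py's fmix, on unbounded ints.
    k ^= (k >> 33) & 0x7fffffff
    k *= 0xff51afd7ed558ccd
    k ^= (k >> 33) & 0x7fffffff
    k *= 0xc4ceb9fe1a85ec53
    k ^= (k >> 33) & 0x7fffffff
    return k


def fmix_exact(k):
    # Reference: every step reduced to 64 bits, like the C code's uint64_t casts.
    k = u64(k)
    k ^= k >> 33
    k = u64(k * 0xff51afd7ed558ccd)
    k ^= k >> 33
    k = u64(k * 0xc4ceb9fe1a85ec53)
    k ^= k >> 33
    return k


def rotl64_masked(x, r):
    # murmur3.py's rotl64: low 64 bits are right, bits above them are left intact.
    mask = 2 ** r - 1
    return (x << r) | ((x >> 64 - r) & mask)


def rotl64_exact(x, r):
    # Reference: a true rotate-left of an unsigned 64-bit value.
    x = u64(x)
    return u64((x << r) | (x >> (64 - r)))


def truncate_int64(x):
    # Same wrap-around as murmur3.py's truncate_int64.
    offset = 2 ** 63
    if not -offset <= x <= offset - 1:
        x = (x + offset) % (2 * offset) - offset
    return x


rng = random.Random(42)
samples = [0, 1, -1, 2 ** 63, -2 ** 63, 2 ** 70 + 12345]
samples += [rng.randrange(-2 ** 80, 2 ** 80) for _ in range(200)]

for k in samples:
    assert (k >> 33) & 0x7fffffff == u64(k) >> 33  # masked shift == logical shift
    assert u64(fmix_masked(k)) == fmix_exact(k)
    for r in (27, 31, 33):
        assert u64(rotl64_masked(k, r)) == rotl64_exact(k, r)
    # truncate_int64 keeps the low 64 bits, reinterpreted as signed.
    assert u64(truncate_int64(k)) == u64(k)
    assert -2 ** 63 <= truncate_int64(k) <= 2 ** 63 - 1
```

The checks pass because multiplication, addition and XOR never let high-order bits affect the low 64 bits. Right shifts are the exception, which is why `fmix` masks after each one.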
cassandra-driver-3.7.1/cassandra/parsing.pxd0000664000175000017500000000203012743410406023727 0ustar aboudreaultaboudreault00000000000000# Copyright 2013-2016 DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from cassandra.bytesio cimport BytesIOReader from cassandra.deserializers cimport Deserializer cdef class ParseDesc: cdef public object colnames cdef public object coltypes cdef Deserializer[::1] deserializers cdef public int protocol_version cdef Py_ssize_t rowsize cdef class ColumnParser: cpdef parse_rows(self, BytesIOReader reader, ParseDesc desc) cdef class RowParser: cpdef unpack_row(self, BytesIOReader reader, ParseDesc desc) cassandra-driver-3.7.1/cassandra/row_parser.pyx0000664000175000017500000000310013004141114024457 0ustar aboudreaultaboudreault00000000000000# Copyright 2013-2016 DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from cassandra.parsing cimport ParseDesc, ColumnParser from cassandra.deserializers import make_deserializers include "ioutils.pyx" def make_recv_results_rows(ColumnParser colparser): def recv_results_rows(cls, f, int protocol_version, user_type_map, result_metadata): """ Parse protocol data given as a BytesIO f into a set of columns (e.g. list of tuples) This is used as the recv_results_rows method of (Fast)ResultMessage """ paging_state, column_metadata = cls.recv_results_metadata(f, user_type_map) column_metadata = column_metadata or result_metadata colnames = [c[2] for c in column_metadata] coltypes = [c[3] for c in column_metadata] desc = ParseDesc(colnames, coltypes, make_deserializers(coltypes), protocol_version) reader = BytesIOReader(f.read()) parsed_rows = colparser.parse_rows(reader, desc) return (paging_state, (colnames, parsed_rows)) return recv_results_rows cassandra-driver-3.7.1/cassandra/concurrent.py0000664000175000017500000002233512766043657024334 0ustar aboudreaultaboudreault00000000000000# Copyright 2013-2016 DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from collections import namedtuple from heapq import heappush, heappop from itertools import cycle import six from six.moves import xrange, zip from threading import Condition import sys from cassandra.cluster import ResultSet import logging log = logging.getLogger(__name__) ExecutionResult = namedtuple('ExecutionResult', ['success', 'result_or_exc']) def execute_concurrent(session, statements_and_parameters, concurrency=100, raise_on_first_error=True, results_generator=False): """ Executes a sequence of (statement, parameters) tuples concurrently. Each ``parameters`` item must be a sequence or :const:`None`. The `concurrency` parameter controls how many statements will be executed concurrently. When :attr:`.Cluster.protocol_version` is set to 1 or 2, it is recommended that this be kept below 100 times the number of core connections per host times the number of connected hosts (see :meth:`.Cluster.set_core_connections_per_host`). If that amount is exceeded, the event loop thread may attempt to block on new connection creation, substantially impacting throughput. If :attr:`~.Cluster.protocol_version` is 3 or higher, you can safely experiment with higher levels of concurrency. If `raise_on_first_error` is left as :const:`True`, execution will stop after the first failed statement and the corresponding exception will be raised. `results_generator` controls how the results are returned. If :const:`False`, the results are returned only after all requests have completed. If :const:`True`, a generator is returned. Using a generator keeps the memory footprint constrained when the result set is large -- results are yielded as they arrive instead of materializing the entire list at once. The trade-off for the lower memory footprint is marginal CPU overhead (more thread coordination and sorting out-of-order results on-the-fly). A sequence of ``ExecutionResult(success, result_or_exc)`` namedtuples is returned in the same order that the statements were passed in.
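With `results_generator=True`, `ConcurrentExecutorGenResults` restores statement order by pushing each `(idx, ExecutionResult)` onto a heap. It yields from the top of the heap only while the top index equals the next index expected. The sketch below isolates that reordering step as a plain generator, without threads or a `Condition`. `in_order` is a hypothetical helper, and its input list stands in for callbacks firing as responses arrive.

```python
import heapq


def in_order(completions):
    """Yield (idx, result) pairs in index order from an out-of-order stream.

    Hypothetical helper: `completions` is any iterable of (idx, result) pairs
    whose indices are 0..n-1 in arbitrary order.
    """
    pending = []  # min-heap keyed by statement index (indices are unique)
    next_idx = 0  # index the consumer is waiting for
    for idx, result in completions:
        heapq.heappush(pending, (idx, result))
        # Release every result that is now contiguous with what was yielded.
        while pending and pending[0][0] == next_idx:
            yield heapq.heappop(pending)
            next_idx += 1
    if pending:
        raise ValueError("missing result for index %d" % next_idx)


arrivals = [(2, "c"), (0, "a"), (3, "d"), (1, "b")]
print(list(in_order(arrivals)))  # [(0, 'a'), (1, 'b'), (2, 'c'), (3, 'd')]
```

Only the out-of-order backlog is buffered, never the full result list. This is the memory saving the generator mode trades for the extra coordination. The driver additionally releases its lock around each `yield` so that callbacks can keep arriving while the caller consumes results.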
If ``success`` is :const:`False`, there was an error executing the statement, and ``result_or_exc`` will be an :class:`Exception`. If ``success`` is :const:`True`, ``result_or_exc`` will be the query result. Example usage:: select_statement = session.prepare("SELECT * FROM users WHERE id=?") statements_and_params = [] for user_id in user_ids: params = (user_id, ) statements_and_params.append((select_statement, params)) results = execute_concurrent( session, statements_and_params, raise_on_first_error=False) for (success, result) in results: if not success: handle_error(result) # result will be an Exception else: process_user(result[0]) # result will be a list of rows """ if concurrency <= 0: raise ValueError("concurrency must be greater than 0") if not statements_and_parameters: return [] executor = ConcurrentExecutorGenResults(session, statements_and_parameters) if results_generator else ConcurrentExecutorListResults(session, statements_and_parameters) return executor.execute(concurrency, raise_on_first_error) class _ConcurrentExecutor(object): max_error_recursion = 100 def __init__(self, session, statements_and_params): self.session = session self._enum_statements = enumerate(iter(statements_and_params)) self._condition = Condition() self._fail_fast = False self._results_queue = [] self._current = 0 self._exec_count = 0 self._exec_depth = 0 def execute(self, concurrency, fail_fast): self._fail_fast = fail_fast self._results_queue = [] self._current = 0 self._exec_count = 0 with self._condition: for n in xrange(concurrency): if not self._execute_next(): break return self._results() def _execute_next(self): # lock must be held try: (idx, (statement, params)) = next(self._enum_statements) self._exec_count += 1 self._execute(idx, statement, params) return True except StopIteration: pass def _execute(self, idx, statement, params): self._exec_depth += 1 try: future = self.session.execute_async(statement, params, timeout=None) args = (future, idx) future.add_callbacks( 
                callback=self._on_success, callback_args=args,
                errback=self._on_error, errback_args=args)
        except Exception as exc:
            # exc_info with fail_fast to preserve stack trace info when raising on the client thread
            # (matches previous behavior -- not sure why we wouldn't want stack trace in the other case)
            e = sys.exc_info() if self._fail_fast and six.PY2 else exc

            # If we're not failing fast and all executions are raising, there is a chance of recursing
            # here as subsequent requests are attempted. If we hit this threshold, schedule this result/retry
            # and let the event loop thread return.
            if self._exec_depth < self.max_error_recursion:
                self._put_result(e, idx, False)
            else:
                self.session.submit(self._put_result, e, idx, False)
        self._exec_depth -= 1

    def _on_success(self, result, future, idx):
        future.clear_callbacks()
        self._put_result(ResultSet(future, result), idx, True)

    def _on_error(self, result, future, idx):
        self._put_result(result, idx, False)

    @staticmethod
    def _raise(exc):
        if six.PY2 and isinstance(exc, tuple):
            (exc_type, value, traceback) = exc
            six.reraise(exc_type, value, traceback)
        else:
            raise exc


class ConcurrentExecutorGenResults(_ConcurrentExecutor):

    def _put_result(self, result, idx, success):
        with self._condition:
            heappush(self._results_queue, (idx, ExecutionResult(success, result)))
            self._execute_next()
            self._condition.notify()

    def _results(self):
        with self._condition:
            while self._current < self._exec_count:
                while not self._results_queue or self._results_queue[0][0] != self._current:
                    self._condition.wait()
                while self._results_queue and self._results_queue[0][0] == self._current:
                    _, res = heappop(self._results_queue)
                    try:
                        self._condition.release()
                        if self._fail_fast and not res[0]:
                            self._raise(res[1])
                        yield res
                    finally:
                        self._condition.acquire()
                    self._current += 1


class ConcurrentExecutorListResults(_ConcurrentExecutor):

    _exception = None

    def execute(self, concurrency, fail_fast):
        self._exception = None
        return super(ConcurrentExecutorListResults,
                     self).execute(concurrency, fail_fast)

    def _put_result(self, result, idx, success):
        self._results_queue.append((idx, ExecutionResult(success, result)))
        with self._condition:
            self._current += 1
            if not success and self._fail_fast:
                if not self._exception:
                    self._exception = result
                self._condition.notify()
            elif not self._execute_next() and self._current == self._exec_count:
                self._condition.notify()

    def _results(self):
        with self._condition:
            while self._current < self._exec_count:
                self._condition.wait()
                if self._exception and self._fail_fast:
                    self._raise(self._exception)
        if self._exception and self._fail_fast:  # raise the exception even if there was no wait
            self._raise(self._exception)
        return [r[1] for r in sorted(self._results_queue)]


def execute_concurrent_with_args(session, statement, parameters, *args, **kwargs):
    """
    Like :meth:`~cassandra.concurrent.execute_concurrent()`, but takes a single
    statement and a sequence of parameters.  Each item in ``parameters``
    should be a sequence or :const:`None`.

    Example usage::

        statement = session.prepare("INSERT INTO mytable (a, b) VALUES (1, ?)")
        parameters = [(x,) for x in range(1000)]
        execute_concurrent_with_args(session, statement, parameters, concurrency=50)
    """
    return execute_concurrent(session, zip(cycle((statement,)), parameters), *args, **kwargs)

cassandra-driver-3.7.1/cassandra/cluster.py
# Copyright 2013-2016 DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This module houses the main classes you will interact with,
:class:`.Cluster` and :class:`.Session`.
"""
from __future__ import absolute_import

import atexit
from collections import defaultdict, Mapping
from concurrent.futures import ThreadPoolExecutor, FIRST_COMPLETED, wait as wait_futures
from copy import copy
from functools import partial, wraps
from itertools import groupby, count
import logging
from random import random
import six
from six.moves import filter, range, queue as Queue
import socket
import sys
import time
from threading import Lock, RLock, Thread, Event

import weakref
from weakref import WeakValueDictionary
try:
    from weakref import WeakSet
except ImportError:
    from cassandra.util import WeakSet  # NOQA

from cassandra import (ConsistencyLevel, AuthenticationFailed,
                       OperationTimedOut, UnsupportedOperation,
                       SchemaTargetType, DriverException)
from cassandra.connection import (ConnectionException, ConnectionShutdown,
                                  ConnectionHeartbeat, ProtocolVersionUnsupported)
from cassandra.cqltypes import UserType
from cassandra.encoder import Encoder
from cassandra.protocol import (QueryMessage, ResultMessage,
                                ErrorMessage, ReadTimeoutErrorMessage,
                                WriteTimeoutErrorMessage,
                                UnavailableErrorMessage,
                                OverloadedErrorMessage,
                                PrepareMessage, ExecuteMessage,
                                PreparedQueryNotFound,
                                IsBootstrappingErrorMessage,
                                BatchMessage, RESULT_KIND_PREPARED,
                                RESULT_KIND_SET_KEYSPACE, RESULT_KIND_ROWS,
                                RESULT_KIND_SCHEMA_CHANGE, MIN_SUPPORTED_VERSION,
                                ProtocolHandler)
from cassandra.metadata import Metadata, protect_name, murmur3
from cassandra.policies import (TokenAwarePolicy, DCAwareRoundRobinPolicy, SimpleConvictionPolicy,
                                ExponentialReconnectionPolicy, HostDistance,
                                RetryPolicy, IdentityTranslator, NoSpeculativeExecutionPlan,
                                NoSpeculativeExecutionPolicy)
from cassandra.pool import (Host, _ReconnectionHandler, _HostReconnectionHandler,
                            HostConnectionPool, HostConnection,
                            NoConnectionsAvailable)
from cassandra.query import (SimpleStatement, PreparedStatement, BoundStatement,
                             BatchStatement, bind_params, QueryTrace,
                             named_tuple_factory, dict_factory, tuple_factory, FETCH_SIZE_UNSET)


def _is_eventlet_monkey_patched():
    if 'eventlet.patcher' not in sys.modules:
        return False
    import eventlet.patcher
    return eventlet.patcher.is_monkey_patched('socket')


def _is_gevent_monkey_patched():
    if 'gevent.monkey' not in sys.modules:
        return False
    import gevent.socket
    return socket.socket is gevent.socket.socket


# default to gevent when we are monkey patched with gevent, eventlet when
# monkey patched with eventlet, otherwise if libev is available, use that as
# the default because it's fastest. Otherwise, use asyncore.
if _is_gevent_monkey_patched():
    from cassandra.io.geventreactor import GeventConnection as DefaultConnection
elif _is_eventlet_monkey_patched():
    from cassandra.io.eventletreactor import EventletConnection as DefaultConnection
else:
    try:
        from cassandra.io.libevreactor import LibevConnection as DefaultConnection  # NOQA
    except ImportError:
        from cassandra.io.asyncorereactor import AsyncoreConnection as DefaultConnection  # NOQA

# Forces load of utf8 encoding module to avoid deadlock that occurs
# if code that is being imported tries to import the module in a separate
# thread.
# See http://bugs.python.org/issue10923
"".encode('utf8')

log = logging.getLogger(__name__)


DEFAULT_MIN_REQUESTS = 5
DEFAULT_MAX_REQUESTS = 100

DEFAULT_MIN_CONNECTIONS_PER_LOCAL_HOST = 2
DEFAULT_MAX_CONNECTIONS_PER_LOCAL_HOST = 8

DEFAULT_MIN_CONNECTIONS_PER_REMOTE_HOST = 1
DEFAULT_MAX_CONNECTIONS_PER_REMOTE_HOST = 2


_NOT_SET = object()


class NoHostAvailable(Exception):
    """
    Raised when an operation is attempted but all connections are
    busy, defunct, closed, or resulted in errors when used.
    """

    errors = None
    """
    A map of the form ``{ip: exception}`` which details the particular
    Exception that was caught for each host the operation was attempted
    against.
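
    A minimal sketch of turning this map into log-friendly lines. The helper
    name is hypothetical, and the ``OSError`` values are stand-ins for the
    driver exceptions that would normally appear here. Keys are converted with
    ``str()`` before sorting, an assumption that keeps the order stable
    whatever the key type::

        def describe_errors(errors):
            # ``errors`` maps each host to the exception it raised;
            # sort on the string form of the host so output order is stable.
            return ["%s: %s" % (host, err)
                    for host, err in sorted(errors.items(), key=lambda item: str(item[0]))]

        errors = {"10.0.0.2": OSError("connection refused"),
                  "10.0.0.1": OSError("timed out")}
        describe_errors(errors)  # ['10.0.0.1: timed out', '10.0.0.2: connection refused']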
""" def __init__(self, message, errors): Exception.__init__(self, message, errors) self.errors = errors def _future_completed(future): """ Helper for run_in_executor() """ exc = future.exception() if exc: log.debug("Failed to run task on executor", exc_info=exc) def run_in_executor(f): """ A decorator to run the given method in the ThreadPoolExecutor. """ @wraps(f) def new_f(self, *args, **kwargs): if self.is_shutdown: return try: future = self.executor.submit(f, self, *args, **kwargs) future.add_done_callback(_future_completed) except Exception: log.exception("Failed to submit task to executor") return new_f _clusters_for_shutdown = set() def _register_cluster_shutdown(cluster): _clusters_for_shutdown.add(cluster) def _discard_cluster_shutdown(cluster): _clusters_for_shutdown.discard(cluster) def _shutdown_clusters(): clusters = _clusters_for_shutdown.copy() # copy because shutdown modifies the global set "discard" for cluster in clusters: cluster.shutdown() atexit.register(_shutdown_clusters) def default_lbp_factory(): if murmur3 is not None: return TokenAwarePolicy(DCAwareRoundRobinPolicy()) return DCAwareRoundRobinPolicy() class ExecutionProfile(object): load_balancing_policy = None """ An instance of :class:`.policies.LoadBalancingPolicy` or one of its subclasses. Used in determining host distance for establishing connections, and routing requests. Defaults to ``TokenAwarePolicy(DCAwareRoundRobinPolicy())`` if not specified """ retry_policy = None """ An instance of :class:`.policies.RetryPolicy` instance used when :class:`.Statement` objects do not have a :attr:`~.Statement.retry_policy` explicitly set. Defaults to :class:`.RetryPolicy` if not specified """ consistency_level = ConsistencyLevel.LOCAL_ONE """ :class:`.ConsistencyLevel` used when not specified on a :class:`.Statement`. """ serial_consistency_level = None """ Serial :class:`.ConsistencyLevel` used when not specified on a :class:`.Statement` (for LWT conditional statements). 
""" request_timeout = 10.0 """ Request timeout used when not overridden in :meth:`.Session.execute` """ row_factory = staticmethod(tuple_factory) """ A callable to format results, accepting ``(colnames, rows)`` where ``colnames`` is a list of column names, and ``rows`` is a list of tuples, with each tuple representing a row of parsed values. Some example implementations: - :func:`cassandra.query.tuple_factory` - return a result row as a tuple - :func:`cassandra.query.named_tuple_factory` - return a result row as a named tuple - :func:`cassandra.query.dict_factory` - return a result row as a dict - :func:`cassandra.query.ordered_dict_factory` - return a result row as an OrderedDict """ speculative_execution_policy = None """ An instance of :class:`.policies.SpeculativeExecutionPolicy` Defaults to :class:`.NoSpeculativeExecutionPolicy` if not specified """ def __init__(self, load_balancing_policy=None, retry_policy=None, consistency_level=ConsistencyLevel.LOCAL_ONE, serial_consistency_level=None, request_timeout=10.0, row_factory=named_tuple_factory, speculative_execution_policy=None): self.load_balancing_policy = load_balancing_policy or default_lbp_factory() self.retry_policy = retry_policy or RetryPolicy() self.consistency_level = consistency_level self.serial_consistency_level = serial_consistency_level self.request_timeout = request_timeout self.row_factory = row_factory self.speculative_execution_policy = speculative_execution_policy or NoSpeculativeExecutionPolicy() class ProfileManager(object): def __init__(self): self.profiles = dict() def distance(self, host): distances = set(p.load_balancing_policy.distance(host) for p in self.profiles.values()) return HostDistance.LOCAL if HostDistance.LOCAL in distances else \ HostDistance.REMOTE if HostDistance.REMOTE in distances else \ HostDistance.IGNORED def populate(self, cluster, hosts): for p in self.profiles.values(): p.load_balancing_policy.populate(cluster, hosts) def check_supported(self): for p in 
self.profiles.values(): p.load_balancing_policy.check_supported() def on_up(self, host): for p in self.profiles.values(): p.load_balancing_policy.on_up(host) def on_down(self, host): for p in self.profiles.values(): p.load_balancing_policy.on_down(host) def on_add(self, host): for p in self.profiles.values(): p.load_balancing_policy.on_add(host) def on_remove(self, host): for p in self.profiles.values(): p.load_balancing_policy.on_remove(host) @property def default(self): """ internal-only; no checks are done because this entry is populated on cluster init """ return self.profiles[EXEC_PROFILE_DEFAULT] EXEC_PROFILE_DEFAULT = object() """ Key for the ``Cluster`` default execution profile, used when no other profile is selected in ``Session.execute(execution_profile)``. Use this as the key in ``Cluster(execution_profiles)`` to override the default profile. """ class _ConfigMode(object): UNCOMMITTED = 0 LEGACY = 1 PROFILES = 2 class Cluster(object): """ The main class to use when interacting with a Cassandra cluster. Typically, one instance of this class will be created for each separate Cassandra cluster that your application interacts with. Example usage:: >>> from cassandra.cluster import Cluster >>> cluster = Cluster(['192.168.1.1', '192.168.1.2']) >>> session = cluster.connect() >>> session.execute("CREATE KEYSPACE ...") >>> ... >>> cluster.shutdown() ``Cluster`` and ``Session`` also provide context management functions which implicitly handle shutdown when leaving scope. """ contact_points = ['127.0.0.1'] """ The list of contact points to try connecting for cluster discovery. Defaults to loopback interface. Note: When using :class:`.DCAwareLoadBalancingPolicy` with no explicit local_dc set (as is the default), the DC is chosen from an arbitrary host in contact_points. In this case, contact_points should contain only nodes from a single, local DC. """ port = 9042 """ The server-side port to open connections to. Defaults to 9042. 
""" cql_version = None """ If a specific version of CQL should be used, this may be set to that string version. Otherwise, the highest CQL version supported by the server will be automatically used. """ protocol_version = 4 """ The maximum version of the native protocol to use. If not set in the constructor, the driver will automatically downgrade version based on a negotiation with the server, but it is most efficient to set this to the maximum supported by your version of Cassandra. Setting this will also prevent conflicting versions negotiated if your cluster is upgraded. Version 2 of the native protocol adds support for lightweight transactions, batch operations, and automatic query paging. The v2 protocol is supported by Cassandra 2.0+. Version 3 of the native protocol adds support for protocol-level client-side timestamps (see :attr:`.Session.use_client_timestamp`), serial consistency levels for :class:`~.BatchStatement`, and an improved connection pool. Version 4 of the native protocol adds a number of new types, server warnings, new failure messages, and custom payloads. Details in the `project docs `_ The following table describes the native protocol versions that are supported by each version of Cassandra: +-------------------+-------------------+ | Cassandra Version | Protocol Versions | +===================+===================+ | 1.2 | 1 | +-------------------+-------------------+ | 2.0 | 1, 2 | +-------------------+-------------------+ | 2.1 | 1, 2, 3 | +-------------------+-------------------+ | 2.2 | 1, 2, 3, 4 | +-------------------+-------------------+ | 3.x | 3, 4 | +-------------------+-------------------+ """ allow_beta_protocol_version = False """ Setting true injects a flag in all messages that makes the server accept and use "beta" protocol version. Used for testing new protocol features incrementally before the new version is complete. """ compression = True """ Controls compression for communications between the driver and Cassandra. 
    If left as the default of :const:`True`, either lz4 or snappy compression
    may be used, depending on what is supported by both the driver
    and Cassandra.  If both are fully supported, lz4 will be preferred.

    You may also set this to 'snappy' or 'lz4' to request that specific
    compression type.

    Setting this to :const:`False` disables compression.
    """

    _auth_provider = None
    _auth_provider_callable = None

    @property
    def auth_provider(self):
        """
        When :attr:`~.Cluster.protocol_version` is 2 or higher, this should
        be an instance of a subclass of :class:`~cassandra.auth.AuthProvider`,
        such as :class:`~.PlainTextAuthProvider`.

        When :attr:`~.Cluster.protocol_version` is 1, this should be
        a function that accepts one argument, the IP address of a node,
        and returns a dict of credentials for that node.

        When not using authentication, this should be left as :const:`None`.
        """
        return self._auth_provider

    @auth_provider.setter  # noqa
    def auth_provider(self, value):
        if not value:
            self._auth_provider = value
            return

        try:
            self._auth_provider_callable = value.new_authenticator
        except AttributeError:
            if self.protocol_version > 1:
                raise TypeError("auth_provider must implement the cassandra.auth.AuthProvider "
                                "interface when protocol_version >= 2")
            elif not callable(value):
                raise TypeError("auth_provider must be callable when protocol_version == 1")
            self._auth_provider_callable = value

        self._auth_provider = value

    _load_balancing_policy = None
    @property
    def load_balancing_policy(self):
        """
        An instance of :class:`.policies.LoadBalancingPolicy` or
        one of its subclasses.

        .. versionchanged:: 2.6.0

        Defaults to :class:`~.TokenAwarePolicy` (:class:`~.DCAwareRoundRobinPolicy`)
        when using CPython (where the murmur3 extension is available), and to
        :class:`~.DCAwareRoundRobinPolicy` otherwise. The default local DC will be chosen from contact points.

        **Please see** :class:`~.DCAwareRoundRobinPolicy` **for a discussion of default behavior with respect to
        DC locality and remote nodes.**
        """
        return self._load_balancing_policy

    @load_balancing_policy.setter
    def load_balancing_policy(self, lbp):
        if self._config_mode == _ConfigMode.PROFILES:
            raise ValueError("Cannot set Cluster.load_balancing_policy while using Configuration Profiles. Set this in a profile instead.")
        self._load_balancing_policy = lbp
        self._config_mode = _ConfigMode.LEGACY

    @property
    def _default_load_balancing_policy(self):
        return self.profile_manager.default.load_balancing_policy

    reconnection_policy = ExponentialReconnectionPolicy(1.0, 600.0)
    """
    An instance of :class:`.policies.ReconnectionPolicy`. Defaults to an instance
    of :class:`.ExponentialReconnectionPolicy` with a base delay of one second and
    a max delay of ten minutes.
    """

    _default_retry_policy = RetryPolicy()
    @property
    def default_retry_policy(self):
        """
        A default :class:`.policies.RetryPolicy` instance to use for all
        :class:`.Statement` objects which do not have a :attr:`~.Statement.retry_policy`
        explicitly set.
        """
        return self._default_retry_policy

    @default_retry_policy.setter
    def default_retry_policy(self, policy):
        if self._config_mode == _ConfigMode.PROFILES:
            raise ValueError("Cannot set Cluster.default_retry_policy while using Configuration Profiles. Set this in a profile instead.")
        self._default_retry_policy = policy
        self._config_mode = _ConfigMode.LEGACY

    conviction_policy_factory = SimpleConvictionPolicy
    """
    A factory function which creates instances of
    :class:`.policies.ConvictionPolicy`.  Defaults to
    :class:`.policies.SimpleConvictionPolicy`.
    """

    address_translator = IdentityTranslator()
    """
    A :class:`.policies.AddressTranslator` instance used to translate server node addresses
    to driver connection addresses. Defaults to :class:`.policies.IdentityTranslator`.
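
    A minimal sketch of a custom translator, assuming nodes advertise private
    addresses that the client can only reach through public ones. The class
    name and addresses are hypothetical; in real use the class would subclass
    :class:`.policies.AddressTranslator` and be passed as
    ``Cluster(address_translator=...)``::

        class StaticMapTranslator(object):
            # Hypothetical translator backed by a fixed private -> public map.

            def __init__(self, mapping):
                self.mapping = dict(mapping)

            def translate(self, addr):
                # Unmapped addresses pass through unchanged, as with IdentityTranslator.
                return self.mapping.get(addr, addr)

        translator = StaticMapTranslator({"10.0.0.1": "203.0.113.10"})
        translator.translate("10.0.0.1")  # '203.0.113.10'
        translator.translate("10.0.0.9")  # '10.0.0.9'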
""" connect_to_remote_hosts = True """ If left as :const:`True`, hosts that are considered :attr:`~.HostDistance.REMOTE` by the :attr:`~.Cluster.load_balancing_policy` will have a connection opened to them. Otherwise, they will not have a connection opened to them. Note that the default load balancing policy ignores remote hosts by default. .. versionadded:: 2.1.0 """ metrics_enabled = False """ Whether or not metric collection is enabled. If enabled, :attr:`.metrics` will be an instance of :class:`~cassandra.metrics.Metrics`. """ metrics = None """ An instance of :class:`cassandra.metrics.Metrics` if :attr:`.metrics_enabled` is :const:`True`, else :const:`None`. """ ssl_options = None """ A optional dict which will be used as kwargs for ``ssl.wrap_socket()`` when new sockets are created. This should be used when client encryption is enabled in Cassandra. By default, a ``ca_certs`` value should be supplied (the value should be a string pointing to the location of the CA certs file), and you probably want to specify ``ssl_version`` as ``ssl.PROTOCOL_TLSv1`` to match Cassandra's default protocol. .. versionchanged:: 3.3.0 In addition to ``wrap_socket`` kwargs, clients may also specify ``'check_hostname': True`` to verify the cert hostname as outlined in RFC 2818 and RFC 6125. Note that this requires the certificate to be transferred, so should almost always require the option ``'cert_reqs': ssl.CERT_REQUIRED``. Note also that this functionality was not built into Python standard library until (2.7.9, 3.2). To enable this mechanism in earlier versions, patch ``ssl.match_hostname`` with a custom or `back-ported function `_. """ sockopts = None """ An optional list of tuples which will be used as arguments to ``socket.setsockopt()`` for all created sockets. Note: some drivers find setting TCPNODELAY beneficial in the context of their execution model. It was not found generally beneficial for this driver. 
    To try with your own workload, set ``sockopts = [(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)]``.
    """

    max_schema_agreement_wait = 10
    """
    The maximum duration (in seconds) that the driver will wait for schema
    agreement across the cluster. Defaults to ten seconds.
    If set to a value <= 0, the driver will bypass schema agreement waits altogether.
    """

    metadata = None
    """
    An instance of :class:`cassandra.metadata.Metadata`.
    """

    connection_class = DefaultConnection
    """
    This determines what event loop system will be used for managing
    I/O with Cassandra.  These are the current options:

    * :class:`cassandra.io.asyncorereactor.AsyncoreConnection`
    * :class:`cassandra.io.libevreactor.LibevConnection`
    * :class:`cassandra.io.eventletreactor.EventletConnection` (requires monkey-patching - see doc for details)
    * :class:`cassandra.io.geventreactor.GeventConnection` (requires monkey-patching - see doc for details)
    * :class:`cassandra.io.twistedreactor.TwistedConnection`

    By default, ``AsyncoreConnection`` will be used, which uses
    the ``asyncore`` module in the Python standard library.

    If ``libev`` is installed, ``LibevConnection`` will be used instead.

    If ``gevent`` or ``eventlet`` monkey-patching is detected, the corresponding
    connection class will be used automatically.
    """

    control_connection_timeout = 2.0
    """
    A timeout, in seconds, for queries made by the control connection, such
    as querying the current schema and information about nodes in the cluster.
    If set to :const:`None`, there will be no timeout for these queries.
    """

    idle_heartbeat_interval = 30
    """
    Interval, in seconds, at which to heartbeat idle connections. This helps
    keep connections open through network devices that expire idle connections.
    It also helps discover bad connections early in low-traffic scenarios.
    Setting this to zero disables heartbeats.
    """

    schema_event_refresh_window = 2
    """
    Window, in seconds, within which a schema component will be refreshed after
    receiving a schema_change event.

    The driver delays a random amount of time in the range [0.0, window)
    before executing the refresh. This serves two purposes:

    1.) Spread the refresh for deployments with large fanout from C* to client tier,
    preventing a 'thundering herd' problem with many clients refreshing simultaneously.

    2.) Remove redundant refreshes. Redundant events arriving within the delay period
    are discarded, and only one refresh is executed.

    Setting this to zero will execute refreshes immediately.

    Setting this to a negative value will disable schema refreshes in response to push events
    (refreshes will still occur in response to schema change responses to DDL statements
    executed by Sessions of this Cluster).
    """

    topology_event_refresh_window = 10
    """
    Window, in seconds, within which the node and token list will be refreshed after
    receiving a topology_change event.

    Setting this to zero will execute refreshes immediately.

    Setting this to a negative value will disable node refreshes in response to push events.

    See :attr:`.schema_event_refresh_window` for a discussion of the rationale.
    """

    status_event_refresh_window = 2
    """
    Window, in seconds, within which the driver will start the reconnect after
    receiving a status_change event.

    Setting this to zero will connect immediately.

    This is primarily used to avoid a 'thundering herd' in deployments with large fanout from cluster to clients.
    When nodes come up, clients attempt to reprepare prepared statements (depending on :attr:`.reprepare_on_up`), and
    establish connection pools. This can cause a rush of connections and queries if not mitigated with this factor.
    """

    prepare_on_all_hosts = True
    """
    Specifies whether statements should be prepared on all hosts, or just one.

    This can reasonably be disabled on long-running applications with numerous clients preparing statements on startup,
    where a randomized initial condition of the load balancing policy can be expected to distribute prepares from
    different clients across the cluster.
""" reprepare_on_up = True """ Specifies whether all known prepared statements should be prepared on a node when it comes up. May be used to avoid overwhelming a node on return, or if it is supposed that the node was only marked down due to network. If statements are not reprepared, they are prepared on the first execution, causing an extra roundtrip for one or more client requests. """ connect_timeout = 5 """ Timeout, in seconds, for creating new connections. This timeout covers the entire connection negotiation, including TCP establishment, options passing, and authentication. """ @property def schema_metadata_enabled(self): """ Flag indicating whether internal schema metadata is updated. When disabled, the driver does not populate Cluster.metadata.keyspaces on connect, or on schema change events. This can be used to speed initial connection, and reduce load on client and server during operation. Turning this off gives away token aware request routing, and programmatic inspection of the metadata model. """ return self.control_connection._schema_meta_enabled @schema_metadata_enabled.setter def schema_metadata_enabled(self, enabled): self.control_connection._schema_meta_enabled = bool(enabled) @property def token_metadata_enabled(self): """ Flag indicating whether internal token metadata is updated. When disabled, the driver does not query node token information on connect, or on topology change events. This can be used to speed initial connection, and reduce load on client and server during operation. It is most useful in large clusters using vnodes, where the token map can be expensive to compute. Turning this off gives away token aware request routing, and programmatic inspection of the token ring. 
""" return self.control_connection._token_meta_enabled @token_metadata_enabled.setter def token_metadata_enabled(self, enabled): self.control_connection._token_meta_enabled = bool(enabled) profile_manager = None _config_mode = _ConfigMode.UNCOMMITTED sessions = None control_connection = None scheduler = None executor = None is_shutdown = False _is_setup = False _prepared_statements = None _prepared_statement_lock = None _idle_heartbeat = None _protocol_version_explicit = False _discount_down_events = True _user_types = None """ A map of {keyspace: {type_name: UserType}} """ _listeners = None _listener_lock = None def __init__(self, contact_points=["127.0.0.1"], port=9042, compression=True, auth_provider=None, load_balancing_policy=None, reconnection_policy=None, default_retry_policy=None, conviction_policy_factory=None, metrics_enabled=False, connection_class=None, ssl_options=None, sockopts=None, cql_version=None, protocol_version=_NOT_SET, executor_threads=2, max_schema_agreement_wait=10, control_connection_timeout=2.0, idle_heartbeat_interval=30, schema_event_refresh_window=2, topology_event_refresh_window=10, connect_timeout=5, schema_metadata_enabled=True, token_metadata_enabled=True, address_translator=None, status_event_refresh_window=2, prepare_on_all_hosts=True, reprepare_on_up=True, execution_profiles=None, allow_beta_protocol_version=False): """ ``executor_threads`` defines the number of threads in a pool for handling asynchronous tasks such as extablishing connection pools or refreshing metadata. Any of the mutable Cluster attributes may be set as keyword arguments to the constructor. """ if contact_points is not None: if isinstance(contact_points, six.string_types): raise TypeError("contact_points should not be a string, it should be a sequence (e.g. 
list) of strings") if None in contact_points: raise ValueError("contact_points should not contain None (it can resolve to localhost)") self.contact_points = contact_points self.port = port self.contact_points_resolved = [endpoint[4][0] for a in self.contact_points for endpoint in socket.getaddrinfo(a, self.port, socket.AF_UNSPEC, socket.SOCK_STREAM)] self.compression = compression if protocol_version is not _NOT_SET: self.protocol_version = protocol_version self._protocol_version_explicit = True self.allow_beta_protocol_version = allow_beta_protocol_version self.auth_provider = auth_provider if load_balancing_policy is not None: if isinstance(load_balancing_policy, type): raise TypeError("load_balancing_policy should not be a class, it should be an instance of that class") self.load_balancing_policy = load_balancing_policy else: self._load_balancing_policy = default_lbp_factory() # set internal attribute to avoid committing to legacy config mode if reconnection_policy is not None: if isinstance(reconnection_policy, type): raise TypeError("reconnection_policy should not be a class, it should be an instance of that class") self.reconnection_policy = reconnection_policy if default_retry_policy is not None: if isinstance(default_retry_policy, type): raise TypeError("default_retry_policy should not be a class, it should be an instance of that class") self.default_retry_policy = default_retry_policy if conviction_policy_factory is not None: if not callable(conviction_policy_factory): raise ValueError("conviction_policy_factory must be callable") self.conviction_policy_factory = conviction_policy_factory if address_translator is not None: if isinstance(address_translator, type): raise TypeError("address_translator should not be a class, it should be an instance of that class") self.address_translator = address_translator if connection_class is not None: self.connection_class = connection_class self.profile_manager = ProfileManager() 
        self.profile_manager.profiles[EXEC_PROFILE_DEFAULT] = ExecutionProfile(self.load_balancing_policy,
                                                                               self.default_retry_policy,
                                                                               Session._default_consistency_level,
                                                                               Session._default_serial_consistency_level,
                                                                               Session._default_timeout,
                                                                               Session._row_factory)
        # legacy mode if either of these is not default
        if load_balancing_policy or default_retry_policy:
            if execution_profiles:
                raise ValueError("Clusters constructed with execution_profiles should not specify legacy parameters "
                                 "load_balancing_policy or default_retry_policy. Configure this in a profile instead.")
            self._config_mode = _ConfigMode.LEGACY
        else:
            if execution_profiles:
                self.profile_manager.profiles.update(execution_profiles)
                self._config_mode = _ConfigMode.PROFILES

        self.metrics_enabled = metrics_enabled
        self.ssl_options = ssl_options
        self.sockopts = sockopts
        self.cql_version = cql_version
        self.max_schema_agreement_wait = max_schema_agreement_wait
        self.control_connection_timeout = control_connection_timeout
        self.idle_heartbeat_interval = idle_heartbeat_interval
        self.schema_event_refresh_window = schema_event_refresh_window
        self.topology_event_refresh_window = topology_event_refresh_window
        self.status_event_refresh_window = status_event_refresh_window
        self.connect_timeout = connect_timeout
        self.prepare_on_all_hosts = prepare_on_all_hosts
        self.reprepare_on_up = reprepare_on_up

        self._listeners = set()
        self._listener_lock = Lock()

        # let Session objects be GC'ed (and shutdown) when the user no longer
        # holds a reference.
self.sessions = WeakSet() self.metadata = Metadata() self.control_connection = None self._prepared_statements = WeakValueDictionary() self._prepared_statement_lock = Lock() self._user_types = defaultdict(dict) self._min_requests_per_connection = { HostDistance.LOCAL: DEFAULT_MIN_REQUESTS, HostDistance.REMOTE: DEFAULT_MIN_REQUESTS } self._max_requests_per_connection = { HostDistance.LOCAL: DEFAULT_MAX_REQUESTS, HostDistance.REMOTE: DEFAULT_MAX_REQUESTS } self._core_connections_per_host = { HostDistance.LOCAL: DEFAULT_MIN_CONNECTIONS_PER_LOCAL_HOST, HostDistance.REMOTE: DEFAULT_MIN_CONNECTIONS_PER_REMOTE_HOST } self._max_connections_per_host = { HostDistance.LOCAL: DEFAULT_MAX_CONNECTIONS_PER_LOCAL_HOST, HostDistance.REMOTE: DEFAULT_MAX_CONNECTIONS_PER_REMOTE_HOST } self.executor = ThreadPoolExecutor(max_workers=executor_threads) self.scheduler = _Scheduler(self.executor) self._lock = RLock() if self.metrics_enabled: from cassandra.metrics import Metrics self.metrics = Metrics(weakref.proxy(self)) self.control_connection = ControlConnection( self, self.control_connection_timeout, self.schema_event_refresh_window, self.topology_event_refresh_window, self.status_event_refresh_window, schema_metadata_enabled, token_metadata_enabled) def register_user_type(self, keyspace, user_type, klass): """ Registers a class to use to represent a particular user-defined type. Query parameters for this user-defined type will be assumed to be instances of `klass`. Result sets for this user-defined type will be instances of `klass`. If no class is registered for a user-defined type, a namedtuple will be used for result sets, and non-prepared statements may not encode parameters for this type correctly. `keyspace` is the name of the keyspace that the UDT is defined in. `user_type` is the string name of the UDT to register the mapping for. `klass` should be a class with attributes whose names match the fields of the user-defined type. 
The constructor must accept kwargs for each of the fields in the UDT. This method should only be called after the type has been created within Cassandra. Example:: cluster = Cluster(protocol_version=3) session = cluster.connect() session.set_keyspace('mykeyspace') session.execute("CREATE TYPE address (street text, zipcode int)") session.execute("CREATE TABLE users (id int PRIMARY KEY, location address)") # create a class to map to the "address" UDT class Address(object): def __init__(self, street, zipcode): self.street = street self.zipcode = zipcode cluster.register_user_type('mykeyspace', 'address', Address) # insert a row using an instance of Address session.execute("INSERT INTO users (id, location) VALUES (%s, %s)", (0, Address("123 Main St.", 78723))) # results will include Address instances results = session.execute("SELECT * FROM users") row = results[0] print row.id, row.location.street, row.location.zipcode """ if self.protocol_version < 3: log.warning("User Type serialization is only supported in native protocol version 3+ (%d in use). " "CQL encoding for simple statements will still work, but named tuples will " "be returned when reading type %s.%s.", self.protocol_version, keyspace, user_type) self._user_types[keyspace][user_type] = klass for session in self.sessions: session.user_type_registered(keyspace, user_type, klass) UserType.evict_udt_class(keyspace, user_type) def add_execution_profile(self, name, profile, pool_wait_timeout=5): """ Adds an :class:`.ExecutionProfile` to the cluster. This makes it available for use by ``name`` in :meth:`.Session.execute` and :meth:`.Session.execute_async`. This method raises a ``ValueError`` if a profile with that name already exists. Normally profiles will be injected at cluster initialization via ``Cluster(execution_profiles)``. This method provides a way of adding them dynamically. Adding a new profile updates the connection pools according to the specified ``load_balancing_policy``.
By default, this method will wait up to five seconds for the pool creation to complete, so the profile can be used immediately upon return. This behavior can be controlled using ``pool_wait_timeout`` (see `concurrent.futures.wait `_ for timeout semantics). """ if not isinstance(profile, ExecutionProfile): raise TypeError("profile must be an instance of ExecutionProfile") if self._config_mode == _ConfigMode.LEGACY: raise ValueError("Cannot add execution profiles when legacy parameters are set explicitly.") if name in self.profile_manager.profiles: raise ValueError("Profile %s already exists" % (name,)) self.profile_manager.profiles[name] = profile profile.load_balancing_policy.populate(self, self.metadata.all_hosts()) # on_up after populate allows things like DCA LBP to choose default local dc for host in filter(lambda h: h.is_up, self.metadata.all_hosts()): profile.load_balancing_policy.on_up(host) futures = set() for session in self.sessions: futures.update(session.update_created_pools()) _, not_done = wait_futures(futures, pool_wait_timeout) if not_done: raise OperationTimedOut("Failed to create all new connection pools in the %ss timeout." % (pool_wait_timeout,)) def get_min_requests_per_connection(self, host_distance): return self._min_requests_per_connection[host_distance] def set_min_requests_per_connection(self, host_distance, min_requests): """ Sets a threshold for concurrent requests per connection, below which connections will be considered for disposal (down to core connections; see :meth:`~Cluster.set_core_connections_per_host`). Pertains to connection pool management in protocol versions {1,2}.
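The range checks enforced by this setter and by :meth:`~Cluster.set_max_requests_per_connection` can be summarized as one standalone check. This is a minimal sketch: ``check_request_thresholds`` is a hypothetical helper, not part of the driver API. It only mirrors the conditions in the two setters: ``min_requests`` in 0-126, ``max_requests`` in 1-127, and min strictly below max.

```python
# Hypothetical helper (not part of cassandra-driver): mirrors the range checks
# performed by set_min_requests_per_connection / set_max_requests_per_connection.
def check_request_thresholds(min_requests, max_requests):
    # low-water mark: connections below this load may be disposed of
    if min_requests < 0 or min_requests > 126:
        raise ValueError("min_requests must be 0-126, got %d" % (min_requests,))
    # high-water mark: above this load a new connection may be opened
    if max_requests < 1 or max_requests > 127:
        raise ValueError("max_requests must be 1-127, got %d" % (max_requests,))
    # the two marks must not cross
    if min_requests >= max_requests:
        raise ValueError("min_requests (%d) must be less than max_requests (%d)"
                         % (min_requests, max_requests))
    return min_requests, max_requests


# The values here are illustrative only, not the driver defaults.
print(check_request_thresholds(5, 100))
```

Keeping min strictly below max leaves a gap between the point where a connection is retired and the point where a new one is opened, so the pool does not oscillate around a single threshold.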
""" if self.protocol_version >= 3: raise UnsupportedOperation( "Cluster.set_min_requests_per_connection() only has an effect " "when using protocol_version 1 or 2.") if min_requests < 0 or min_requests > 126 or \ min_requests >= self._max_requests_per_connection[host_distance]: raise ValueError("min_requests must be 0-126 and less than the max_requests for this host_distance (%d)" % (self._max_requests_per_connection[host_distance],)) self._min_requests_per_connection[host_distance] = min_requests def get_max_requests_per_connection(self, host_distance): return self._max_requests_per_connection[host_distance] def set_max_requests_per_connection(self, host_distance, max_requests): """ Sets a threshold for concurrent requests per connection, above which new connections will be created to a host (up to max connections; see :meth:`~Cluster.set_max_connections_per_host`). Pertains to connection pool management in protocol versions {1,2}. """ if self.protocol_version >= 3: raise UnsupportedOperation( "Cluster.set_max_requests_per_connection() only has an effect " "when using protocol_version 1 or 2.") if max_requests < 1 or max_requests > 127 or \ max_requests <= self._min_requests_per_connection[host_distance]: raise ValueError("max_requests must be 1-127 and greater than the min_requests for this host_distance (%d)" % (self._min_requests_per_connection[host_distance],)) self._max_requests_per_connection[host_distance] = max_requests def get_core_connections_per_host(self, host_distance): """ Gets the minimum number of connections per Session that will be opened for each host with :class:`~.HostDistance` equal to `host_distance`. The default is 2 for :attr:`~HostDistance.LOCAL` and 1 for :attr:`~HostDistance.REMOTE`. This property is ignored if :attr:`~.Cluster.protocol_version` is 3 or higher.
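The documented pool bounds per host distance, and the rule that they only apply to protocol versions 1 and 2, can be tabulated as below. This is a minimal sketch, not driver code. ``pool_size_bounds`` and ``DEFAULT_POOL_BOUNDS`` are hypothetical names, and plain strings stand in for the :class:`~.HostDistance` constants. The (core, max) values are the defaults stated in these docstrings.

```python
# Hypothetical illustration (not driver code): documented per-Session pool
# bounds for each host distance under protocol versions 1 and 2.
DEFAULT_POOL_BOUNDS = {
    'LOCAL': (2, 8),   # (core, max) connections per host
    'REMOTE': (1, 2),
}


def pool_size_bounds(protocol_version, distance):
    # With protocol v3+ the pool-size settings are ignored and the driver
    # keeps a single connection per host.
    if protocol_version >= 3:
        return (1, 1)
    return DEFAULT_POOL_BOUNDS[distance]


print(pool_size_bounds(2, 'LOCAL'))
print(pool_size_bounds(4, 'REMOTE'))
```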
""" return self._core_connections_per_host[host_distance] def set_core_connections_per_host(self, host_distance, core_connections): """ Sets the minimum number of connections per Session that will be opened for each host with :class:`~.HostDistance` equal to `host_distance`. The default is 2 for :attr:`~HostDistance.LOCAL` and 1 for :attr:`~HostDistance.REMOTE`. Protocol versions 1 and 2 are limited in the number of concurrent requests they can send per connection. The driver implements connection pooling to support higher levels of concurrency. If :attr:`~.Cluster.protocol_version` is set to 3 or higher, this is not supported (there is always one connection per host, unless the host is remote and :attr:`connect_to_remote_hosts` is :const:`False`) and using this will result in an :exc:`~.UnsupportedOperation`. """ if self.protocol_version >= 3: raise UnsupportedOperation( "Cluster.set_core_connections_per_host() only has an effect " "when using protocol_version 1 or 2.") old = self._core_connections_per_host[host_distance] self._core_connections_per_host[host_distance] = core_connections if old < core_connections: self._ensure_core_connections() def get_max_connections_per_host(self, host_distance): """ Gets the maximum number of connections per Session that will be opened for each host with :class:`~.HostDistance` equal to `host_distance`. The default is 8 for :attr:`~HostDistance.LOCAL` and 2 for :attr:`~HostDistance.REMOTE`. This property is ignored if :attr:`~.Cluster.protocol_version` is 3 or higher. """ return self._max_connections_per_host[host_distance] def set_max_connections_per_host(self, host_distance, max_connections): """ Sets the maximum number of connections per Session that will be opened for each host with :class:`~.HostDistance` equal to `host_distance`. The default is 8 for :attr:`~HostDistance.LOCAL` and 2 for :attr:`~HostDistance.REMOTE`.
If :attr:`~.Cluster.protocol_version` is set to 3 or higher, this is not supported (there is always one connection per host, unless the host is remote and :attr:`connect_to_remote_hosts` is :const:`False`) and using this will result in an :exc:`~.UnsupportedOperation`. """ if self.protocol_version >= 3: raise UnsupportedOperation( "Cluster.set_max_connections_per_host() only has an effect " "when using protocol_version 1 or 2.") self._max_connections_per_host[host_distance] = max_connections def connection_factory(self, address, *args, **kwargs): """ Called to create a new connection with proper configuration. Intended for internal use only. """ kwargs = self._make_connection_kwargs(address, kwargs) return self.connection_class.factory(address, self.connect_timeout, *args, **kwargs) def _make_connection_factory(self, host, *args, **kwargs): kwargs = self._make_connection_kwargs(host.address, kwargs) return partial(self.connection_class.factory, host.address, self.connect_timeout, *args, **kwargs) def _make_connection_kwargs(self, address, kwargs_dict): if self._auth_provider_callable: kwargs_dict.setdefault('authenticator', self._auth_provider_callable(address)) kwargs_dict.setdefault('port', self.port) kwargs_dict.setdefault('compression', self.compression) kwargs_dict.setdefault('sockopts', self.sockopts) kwargs_dict.setdefault('ssl_options', self.ssl_options) kwargs_dict.setdefault('cql_version', self.cql_version) kwargs_dict.setdefault('protocol_version', self.protocol_version) kwargs_dict.setdefault('user_type_map', self._user_types) kwargs_dict.setdefault('allow_beta_protocol_version', self.allow_beta_protocol_version) return kwargs_dict def protocol_downgrade(self, host_addr, previous_version): if self._protocol_version_explicit: raise DriverException("ProtocolError returned from server while using explicitly set client protocol_version %d" % (previous_version,)) new_version = previous_version - 1 if new_version < self.protocol_version: if new_version >=
MIN_SUPPORTED_VERSION: log.warning("Downgrading core protocol version from %d to %d for %s. " "To avoid this, it is best practice to explicitly set Cluster(protocol_version) to the version supported by your cluster. " "http://datastax.github.io/python-driver/api/cassandra/cluster.html#cassandra.cluster.Cluster.protocol_version", self.protocol_version, new_version, host_addr) self.protocol_version = new_version else: raise DriverException("Cannot downgrade protocol version (%d) below minimum supported version: %d" % (new_version, MIN_SUPPORTED_VERSION)) def connect(self, keyspace=None, wait_for_all_pools=False): """ Creates and returns a new :class:`~.Session` object. If `keyspace` is specified, that keyspace will be the default keyspace for operations on the ``Session``. """ with self._lock: if self.is_shutdown: raise DriverException("Cluster is already shut down") if not self._is_setup: log.debug("Connecting to cluster, contact points: %s; protocol version: %s", self.contact_points, self.protocol_version) self.connection_class.initialize_reactor() _register_cluster_shutdown(self) for address in self.contact_points_resolved: host, new = self.add_host(address, signal=False) if new: host.set_up() for listener in self.listeners: listener.on_add(host) self.profile_manager.populate( weakref.proxy(self), self.metadata.all_hosts()) try: self.control_connection.connect() # we set all contact points up for connecting, but we won't infer state after this for address in self.contact_points_resolved: h = self.metadata.get_host(address) if h and self.profile_manager.distance(h) == HostDistance.IGNORED: h.is_up = None log.debug("Control connection created") except Exception: log.exception("Control connection failed to connect, " "shutting down Cluster:") self.shutdown() raise self.profile_manager.check_supported() # todo: rename this method if self.idle_heartbeat_interval: self._idle_heartbeat = ConnectionHeartbeat(self.idle_heartbeat_interval, self.get_connection_holders) 
self._is_setup = True session = self._new_session(keyspace) if wait_for_all_pools: wait_futures(session._initial_connect_futures) return session def get_connection_holders(self): holders = [] for s in self.sessions: holders.extend(s.get_pools()) holders.append(self.control_connection) return holders def shutdown(self): """ Closes all sessions and connections associated with this Cluster. To ensure all connections are properly closed, **you should always call shutdown() on a Cluster instance when you are done with it**. Once shutdown, a Cluster should not be used for any purpose. """ with self._lock: if self.is_shutdown: return else: self.is_shutdown = True if self._idle_heartbeat: self._idle_heartbeat.stop() self.scheduler.shutdown() self.control_connection.shutdown() for session in self.sessions: session.shutdown() self.executor.shutdown() _discard_cluster_shutdown(self) def __enter__(self): return self def __exit__(self, *args): self.shutdown() def _new_session(self, keyspace): session = Session(self, self.metadata.all_hosts(), keyspace) self._session_register_user_types(session) self.sessions.add(session) return session def _session_register_user_types(self, session): for keyspace, type_map in six.iteritems(self._user_types): for udt_name, klass in six.iteritems(type_map): session.user_type_registered(keyspace, udt_name, klass) def _cleanup_failed_on_up_handling(self, host): self.profile_manager.on_down(host) self.control_connection.on_down(host) for session in self.sessions: session.remove_pool(host) self._start_reconnector(host, is_host_addition=False) def _on_up_future_completed(self, host, futures, results, lock, finished_future): with lock: futures.discard(finished_future) try: results.append(finished_future.result()) except Exception as exc: results.append(exc) if futures: return try: # all futures have completed at this point for exc in [f for f in results if isinstance(f, Exception)]: log.error("Unexpected failure while marking node %s up:", host,
exc_info=exc) self._cleanup_failed_on_up_handling(host) return if not all(results): log.debug("Connection pool could not be created, not marking node %s up", host) self._cleanup_failed_on_up_handling(host) return log.info("Connection pools established for node %s", host) # mark the host as up and notify all listeners host.set_up() for listener in self.listeners: listener.on_up(host) finally: with host.lock: host._currently_handling_node_up = False # see if there are any pools to add or remove now that the host is marked up for session in self.sessions: session.update_created_pools() def on_up(self, host): """ Intended for internal use only. """ if self.is_shutdown: return log.debug("Waiting to acquire lock for handling up status of node %s", host) with host.lock: if host._currently_handling_node_up: log.debug("Another thread is already handling up status of node %s", host) return if host.is_up: log.debug("Host %s was already marked up", host) return host._currently_handling_node_up = True log.debug("Starting to handle up status of node %s", host) have_future = False futures = set() try: log.info("Host %s may be up; will prepare queries and open connection pool", host) reconnector = host.get_and_set_reconnection_handler(None) if reconnector: log.debug("Now that host %s is up, cancelling the reconnection handler", host) reconnector.cancel() self._prepare_all_queries(host) log.debug("Done preparing all queries for host %s, ", host) for session in self.sessions: session.remove_pool(host) log.debug("Signalling to load balancing policies that host %s is up", host) self.profile_manager.on_up(host) log.debug("Signalling to control connection that host %s is up", host) self.control_connection.on_up(host) log.debug("Attempting to open new connection pools for host %s", host) futures_lock = Lock() futures_results = [] callback = partial(self._on_up_future_completed, host, futures, futures_results, futures_lock) for session in self.sessions: future = 
session.add_or_renew_pool(host, is_host_addition=False) if future is not None: have_future = True future.add_done_callback(callback) futures.add(future) except Exception: log.exception("Unexpected failure handling node %s being marked up:", host) for future in futures: future.cancel() self._cleanup_failed_on_up_handling(host) with host.lock: host._currently_handling_node_up = False raise else: if not have_future: with host.lock: host.set_up() host._currently_handling_node_up = False # for testing purposes return futures def _start_reconnector(self, host, is_host_addition): if self.profile_manager.distance(host) == HostDistance.IGNORED: return schedule = self.reconnection_policy.new_schedule() # in order to not hold references to this Cluster open and prevent # proper shutdown when the program ends, we'll just make a closure # of the current Cluster attributes to create new Connections with conn_factory = self._make_connection_factory(host) reconnector = _HostReconnectionHandler( host, conn_factory, is_host_addition, self.on_add, self.on_up, self.scheduler, schedule, host.get_and_set_reconnection_handler, new_handler=None) old_reconnector = host.get_and_set_reconnection_handler(reconnector) if old_reconnector: log.debug("Old host reconnector found for %s, cancelling", host) old_reconnector.cancel() log.debug("Starting reconnector for host %s", host) reconnector.start() @run_in_executor def on_down(self, host, is_host_addition, expect_host_to_be_down=False): """ Intended for internal use only. 
""" if self.is_shutdown: return with host.lock: was_up = host.is_up # ignore down signals if we have open pools to the host # this is to avoid closing pools when a control connection host became isolated if self._discount_down_events and self.profile_manager.distance(host) != HostDistance.IGNORED: connected = False for session in self.sessions: pool_states = session.get_pool_state() pool_state = pool_states.get(host) if pool_state: connected |= pool_state['open_count'] > 0 if connected: return host.set_down() if (not was_up and not expect_host_to_be_down) or host.is_currently_reconnecting(): return log.warning("Host %s has been marked down", host) self.profile_manager.on_down(host) self.control_connection.on_down(host) for session in self.sessions: session.on_down(host) for listener in self.listeners: listener.on_down(host) self._start_reconnector(host, is_host_addition) def on_add(self, host, refresh_nodes=True): if self.is_shutdown: return log.debug("Handling new host %r and notifying listeners", host) distance = self.profile_manager.distance(host) if distance != HostDistance.IGNORED: self._prepare_all_queries(host) log.debug("Done preparing queries for new host %r", host) self.profile_manager.on_add(host) self.control_connection.on_add(host, refresh_nodes) if distance == HostDistance.IGNORED: log.debug("Not adding connection pool for new host %r because the " "load balancing policy has marked it as IGNORED", host) self._finalize_add(host, set_up=False) return futures_lock = Lock() futures_results = [] futures = set() def future_completed(future): with futures_lock: futures.discard(future) try: futures_results.append(future.result()) except Exception as exc: futures_results.append(exc) if futures: return log.debug('All futures have completed for added host %s', host) for exc in [f for f in futures_results if isinstance(f, Exception)]: log.error("Unexpected failure while adding node %s, will not mark up:", host, exc_info=exc) return if not all(futures_results): 
log.warning("Connection pool could not be created, not marking node %s up", host) return self._finalize_add(host) have_future = False for session in self.sessions: future = session.add_or_renew_pool(host, is_host_addition=True) if future is not None: have_future = True futures.add(future) future.add_done_callback(future_completed) if not have_future: self._finalize_add(host) def _finalize_add(self, host, set_up=True): if set_up: host.set_up() for listener in self.listeners: listener.on_add(host) # see if there are any pools to add or remove now that the host is marked up for session in self.sessions: session.update_created_pools() def on_remove(self, host): if self.is_shutdown: return log.debug("Removing host %s", host) host.set_down() self.profile_manager.on_remove(host) for session in self.sessions: session.on_remove(host) for listener in self.listeners: listener.on_remove(host) self.control_connection.on_remove(host) def signal_connection_failure(self, host, connection_exc, is_host_addition, expect_host_to_be_down=False): is_down = host.signal_connection_failure(connection_exc) if is_down: self.on_down(host, is_host_addition, expect_host_to_be_down) return is_down def add_host(self, address, datacenter=None, rack=None, signal=True, refresh_nodes=True): """ Called when adding initial contact points and when the control connection subsequently discovers a new node. Returns a Host instance, and a flag indicating whether it was new in the metadata. Intended for internal use only. """ host, new = self.metadata.add_or_return_host(Host(address, self.conviction_policy_factory, datacenter, rack)) if new and signal: log.info("New Cassandra host %r discovered", host) self.on_add(host, refresh_nodes) return host, new def remove_host(self, host): """ Called when the control connection observes that a node has left the ring. Intended for internal use only. 
""" if host and self.metadata.remove_host(host): log.info("Cassandra host %s removed", host) self.on_remove(host) def register_listener(self, listener): """ Adds a :class:`cassandra.policies.HostStateListener` subclass instance to the list of listeners to be notified when a host is added, removed, marked up, or marked down. """ with self._listener_lock: self._listeners.add(listener) def unregister_listener(self, listener): """ Removes a registered listener. """ with self._listener_lock: self._listeners.remove(listener) @property def listeners(self): with self._listener_lock: return self._listeners.copy() def _ensure_core_connections(self): """ If any host has fewer than the configured number of core connections open, attempt to open connections until that number is met. """ for session in self.sessions: for pool in session._pools.values(): pool.ensure_core_connections() @staticmethod def _validate_refresh_schema(keyspace, table, usertype, function, aggregate): if any((table, usertype, function, aggregate)): if not keyspace: raise ValueError("keyspace is required to refresh specific sub-entity {table, usertype, function, aggregate}") if sum(1 for e in (table, usertype, function, aggregate) if e) > 1: raise ValueError("{table, usertype, function, aggregate} are mutually exclusive") @staticmethod def _target_type_from_refresh_args(keyspace, table, usertype, function, aggregate): if aggregate: return SchemaTargetType.AGGREGATE elif function: return SchemaTargetType.FUNCTION elif usertype: return SchemaTargetType.TYPE elif table: return SchemaTargetType.TABLE elif keyspace: return SchemaTargetType.KEYSPACE return None def get_control_connection_host(self): """ Returns the control connection host metadata. """ connection = self.control_connection._connection host = connection.host if connection else None return self.metadata.get_host(host) if host else None def refresh_schema_metadata(self, max_schema_agreement_wait=None): """ Synchronously refresh all schema metadata.
By default, the timeout for this operation is governed by :attr:`~.Cluster.max_schema_agreement_wait` and :attr:`~.Cluster.control_connection_timeout`. Passing max_schema_agreement_wait here overrides :attr:`~.Cluster.max_schema_agreement_wait`. Setting max_schema_agreement_wait <= 0 will bypass schema agreement and refresh schema immediately. An Exception is raised if schema refresh fails for any reason. """ if not self.control_connection.refresh_schema(schema_agreement_wait=max_schema_agreement_wait, force=True): raise DriverException("Schema metadata was not refreshed. See log for details.") def refresh_keyspace_metadata(self, keyspace, max_schema_agreement_wait=None): """ Synchronously refresh keyspace metadata. This applies to keyspace-level information such as replication and durability settings. It does not refresh tables, types, etc. contained in the keyspace. See :meth:`~.Cluster.refresh_schema_metadata` for description of ``max_schema_agreement_wait`` behavior """ if not self.control_connection.refresh_schema(target_type=SchemaTargetType.KEYSPACE, keyspace=keyspace, schema_agreement_wait=max_schema_agreement_wait, force=True): raise DriverException("Keyspace metadata was not refreshed. See log for details.") def refresh_table_metadata(self, keyspace, table, max_schema_agreement_wait=None): """ Synchronously refresh table metadata. This applies to a table, and any triggers or indexes attached to the table. See :meth:`~.Cluster.refresh_schema_metadata` for description of ``max_schema_agreement_wait`` behavior """ if not self.control_connection.refresh_schema(target_type=SchemaTargetType.TABLE, keyspace=keyspace, table=table, schema_agreement_wait=max_schema_agreement_wait, force=True): raise DriverException("Table metadata was not refreshed. See log for details.") def refresh_materialized_view_metadata(self, keyspace, view, max_schema_agreement_wait=None): """ Synchronously refresh materialized view metadata. 
See :meth:`~.Cluster.refresh_schema_metadata` for description of ``max_schema_agreement_wait`` behavior """ if not self.control_connection.refresh_schema(target_type=SchemaTargetType.TABLE, keyspace=keyspace, table=view, schema_agreement_wait=max_schema_agreement_wait, force=True): raise DriverException("View metadata was not refreshed. See log for details.") def refresh_user_type_metadata(self, keyspace, user_type, max_schema_agreement_wait=None): """ Synchronously refresh user defined type metadata. See :meth:`~.Cluster.refresh_schema_metadata` for description of ``max_schema_agreement_wait`` behavior """ if not self.control_connection.refresh_schema(target_type=SchemaTargetType.TYPE, keyspace=keyspace, type=user_type, schema_agreement_wait=max_schema_agreement_wait, force=True): raise DriverException("User Type metadata was not refreshed. See log for details.") def refresh_user_function_metadata(self, keyspace, function, max_schema_agreement_wait=None): """ Synchronously refresh user defined function metadata. ``function`` is a :class:`cassandra.UserFunctionDescriptor`. See :meth:`~.Cluster.refresh_schema_metadata` for description of ``max_schema_agreement_wait`` behavior """ if not self.control_connection.refresh_schema(target_type=SchemaTargetType.FUNCTION, keyspace=keyspace, function=function, schema_agreement_wait=max_schema_agreement_wait, force=True): raise DriverException("User Function metadata was not refreshed. See log for details.") def refresh_user_aggregate_metadata(self, keyspace, aggregate, max_schema_agreement_wait=None): """ Synchronously refresh user defined aggregate metadata. ``aggregate`` is a :class:`cassandra.UserAggregateDescriptor`. 
See :meth:`~.Cluster.refresh_schema_metadata` for description of ``max_schema_agreement_wait`` behavior """ if not self.control_connection.refresh_schema(target_type=SchemaTargetType.AGGREGATE, keyspace=keyspace, aggregate=aggregate, schema_agreement_wait=max_schema_agreement_wait, force=True): raise DriverException("User Aggregate metadata was not refreshed. See log for details.") def refresh_nodes(self, force_token_rebuild=False): """ Synchronously refresh the node list and token metadata. `force_token_rebuild` can be used to rebuild the token map metadata, even if no new nodes are discovered. An Exception is raised if node refresh fails for any reason. """ if not self.control_connection.refresh_node_list_and_token_map(force_token_rebuild): raise DriverException("Node list was not refreshed. See log for details.") def set_meta_refresh_enabled(self, enabled): """ *Deprecated:* set :attr:`~.Cluster.schema_metadata_enabled` and :attr:`~.Cluster.token_metadata_enabled` instead. Sets a flag to enable (True) or disable (False) all metadata refresh queries. This applies to both schema and node topology. Disabling this is useful to minimize refreshes during multiple changes. Meta refresh must be enabled for the driver to become aware of any cluster topology changes or schema updates.
""" self.schema_metadata_enabled = enabled self.token_metadata_enabled = enabled def _prepare_all_queries(self, host): if not self._prepared_statements or not self.reprepare_on_up: return log.debug("Preparing all known prepared statements against host %s", host) connection = None try: connection = self.connection_factory(host.address) statements = self._prepared_statements.values() for keyspace, ks_statements in groupby(statements, lambda s: s.keyspace): if keyspace is not None: connection.set_keyspace_blocking(keyspace) # prepare 10 statements at a time ks_statements = list(ks_statements) chunks = [] for i in range(0, len(ks_statements), 10): chunks.append(ks_statements[i:i + 10]) for ks_chunk in chunks: messages = [PrepareMessage(query=s.query_string) for s in ks_chunk] # TODO: make this timeout configurable somehow? responses = connection.wait_for_responses(*messages, timeout=5.0, fail_on_error=False) for success, response in responses: if not success: log.debug("Got unexpected response when preparing " "statement on host %s: %r", host, response) log.debug("Done preparing all known prepared statements against host %s", host) except OperationTimedOut as timeout: log.warning("Timed out trying to prepare all statements on host %s: %s", host, timeout) except (ConnectionException, socket.error) as exc: log.warning("Error trying to prepare all statements on host %s: %r", host, exc) except Exception: log.exception("Error trying to prepare all statements on host %s", host) finally: if connection: connection.close() def add_prepared(self, query_id, prepared_statement): with self._prepared_statement_lock: self._prepared_statements[query_id] = prepared_statement class Session(object): """ A collection of connection pools for each host in the cluster. Instances of this class should not be created directly, only using :meth:`.Cluster.connect()`. 
Queries and statements can be executed through ``Session`` instances using the :meth:`~.Session.execute()` and :meth:`~.Session.execute_async()` methods. Example usage:: >>> session = cluster.connect() >>> session.set_keyspace("mykeyspace") >>> session.execute("SELECT * FROM mycf") """ cluster = None hosts = None keyspace = None is_shutdown = False _row_factory = staticmethod(named_tuple_factory) @property def row_factory(self): """ The format to return row results in. By default, each returned row will be a named tuple. You can alternatively use any of the following: - :func:`cassandra.query.tuple_factory` - return a result row as a tuple - :func:`cassandra.query.named_tuple_factory` - return a result row as a named tuple - :func:`cassandra.query.dict_factory` - return a result row as a dict - :func:`cassandra.query.ordered_dict_factory` - return a result row as an OrderedDict """ return self._row_factory @row_factory.setter def row_factory(self, rf): self._validate_set_legacy_config('row_factory', rf) _default_timeout = 10.0 @property def default_timeout(self): """ A default timeout, measured in seconds, for queries executed through :meth:`.execute()` or :meth:`.execute_async()`. This default may be overridden with the `timeout` parameter for either of those methods. Setting this to :const:`None` will cause no timeouts to be set by default. Please see :meth:`.ResponseFuture.result` for details on the scope and effect of this timeout. .. versionadded:: 2.0.0 """ return self._default_timeout @default_timeout.setter def default_timeout(self, timeout): self._validate_set_legacy_config('default_timeout', timeout) _default_consistency_level = ConsistencyLevel.LOCAL_ONE @property def default_consistency_level(self): """ The default :class:`~ConsistencyLevel` for operations executed through this session. This default may be overridden by setting the :attr:`~.Statement.consistency_level` on individual statements. .. versionadded:: 1.2.0 .. 
versionchanged:: 3.0.0 default changed from ONE to LOCAL_ONE """ return self._default_consistency_level @default_consistency_level.setter def default_consistency_level(self, cl): self._validate_set_legacy_config('default_consistency_level', cl) _default_serial_consistency_level = None @property def default_serial_consistency_level(self): """ The default :class:`~ConsistencyLevel` for the serial phase of conditional updates executed through this session. This default may be overridden by setting the :attr:`~.Statement.serial_consistency_level` on individual statements. Only valid for ``protocol_version >= 2``. """ return self._default_serial_consistency_level @default_serial_consistency_level.setter def default_serial_consistency_level(self, cl): self._validate_set_legacy_config('default_serial_consistency_level', cl) max_trace_wait = 2.0 """ The maximum amount of time (in seconds) the driver will wait for trace details to be populated server-side for a query before giving up. If the `trace` parameter for :meth:`~.execute()` or :meth:`~.execute_async()` is :const:`True`, the driver will repeatedly attempt to fetch trace details for the query (using exponential backoff) until this limit is hit. If the limit is passed, an error will be logged and the :attr:`.Statement.trace` will be left as :const:`None`. """ default_fetch_size = 5000 """ By default, this many rows will be fetched at a time. Setting this to :const:`None` will disable automatic paging for large query results. The fetch size can also be specified per-query through :attr:`.Statement.fetch_size`. This only takes effect when protocol version 2 or higher is used. See :attr:`.Cluster.protocol_version` for details. .. versionadded:: 2.0.0 """ use_client_timestamp = True """ When using protocol version 3 or higher, write timestamps may be supplied client-side at the protocol level. (Normally they are generated server-side by the coordinator node.)
Note that timestamps specified within a CQL query will override this timestamp. .. versionadded:: 2.1.0 """ encoder = None """ A :class:`~cassandra.encoder.Encoder` instance that will be used when formatting query parameters for non-prepared statements. This is not used for prepared statements (because prepared statements give the driver more information about what CQL types are expected, allowing it to accept a wider range of Python types). The encoder uses a mapping from Python types to encoder methods (for specific CQL types). This mapping can be modified by users as they see fit. Methods of :class:`~cassandra.encoder.Encoder` should be used for mapping values if possible, because they take precautions to avoid injections and properly sanitize data. Example:: cluster = Cluster() session = cluster.connect("mykeyspace") session.encoder.mapping[tuple] = session.encoder.cql_encode_tuple session.execute("CREATE TABLE mytable (k int PRIMARY KEY, col tuple<int, text>)") session.execute("INSERT INTO mytable (k, col) VALUES (%s, %s)", [0, (123, 'abc')]) .. versionadded:: 2.1.0 """ client_protocol_handler = ProtocolHandler """ Specifies a protocol handler that will be used for client-initiated requests (i.e., not for internal driver requests). This can be used to override or extend features such as message or type serialization/deserialization. The default pure-Python implementation is :class:`cassandra.protocol.ProtocolHandler`. When compiled with Cython, there are also built-in faster alternatives.
See :ref:`faster_deser` """ _lock = None _pools = None _profile_manager = None _metrics = None _request_init_callbacks = None def __init__(self, cluster, hosts, keyspace=None): self.cluster = cluster self.hosts = hosts self.keyspace = keyspace self._lock = RLock() self._pools = {} self._profile_manager = cluster.profile_manager self._metrics = cluster.metrics self._request_init_callbacks = [] self._protocol_version = self.cluster.protocol_version self.encoder = Encoder() # create connection pools in parallel self._initial_connect_futures = set() for host in hosts: future = self.add_or_renew_pool(host, is_host_addition=False) if future: self._initial_connect_futures.add(future) futures = wait_futures(self._initial_connect_futures, return_when=FIRST_COMPLETED) while futures.not_done and not any(f.result() for f in futures.done): futures = wait_futures(futures.not_done, return_when=FIRST_COMPLETED) def execute(self, query, parameters=None, timeout=_NOT_SET, trace=False, custom_payload=None, execution_profile=EXEC_PROFILE_DEFAULT, paging_state=None): """ Execute the given query and synchronously wait for the response. If an error is encountered while executing the query, an Exception will be raised. `query` may be a query string or an instance of :class:`cassandra.query.Statement`. `parameters` may be a sequence or dict of parameters to bind. If a sequence is used, ``%s`` should be used as the placeholder for each argument. If a dict is used, ``%(name)s`` style placeholders must be used. `timeout` should specify a floating-point timeout (in seconds) after which an :exc:`.OperationTimedOut` exception will be raised if the query has not completed. If not set, the timeout defaults to :attr:`~.Session.default_timeout`. If set to :const:`None`, there is no timeout. Please see :meth:`.ResponseFuture.result` for details on the scope and effect of this timeout. If `trace` is set to :const:`True`, the query will be sent with tracing enabled.
The trace details can be obtained using the returned :class:`.ResultSet` object. `custom_payload` is a :ref:`custom_payload` dict to be passed to the server. If `query` is a Statement with its own custom_payload, the message payload will be a union of the two, with the values specified here taking precedence. `execution_profile` is the execution profile to use for this request. It can be a key to a profile configured via :meth:`Cluster.add_execution_profile` or an instance (from :meth:`Session.execution_profile_clone_update`, for example). `paging_state` is an optional paging state, reused from a previous :class:`ResultSet`. """ return self.execute_async(query, parameters, trace, custom_payload, timeout, execution_profile, paging_state).result() def execute_async(self, query, parameters=None, trace=False, custom_payload=None, timeout=_NOT_SET, execution_profile=EXEC_PROFILE_DEFAULT, paging_state=None): """ Execute the given query and return a :class:`~.ResponseFuture` object to which callbacks may be attached for asynchronous response delivery. You may also call :meth:`~.ResponseFuture.result()` on the :class:`.ResponseFuture` to synchronously block for results at any time. See :meth:`Session.execute` for parameter definitions. Example usage:: >>> session = cluster.connect() >>> future = session.execute_async("SELECT * FROM mycf") >>> def log_results(results): ... for row in results: ... log.info("Results: %s", row) >>> def log_error(exc): ... log.error("Operation failed: %s", exc) >>> future.add_callbacks(log_results, log_error) Async execution with blocking wait for results:: >>> future = session.execute_async("SELECT * FROM mycf") >>> # do other stuff... >>> try: ... results = future.result() ... except Exception: ...
log.exception("Operation failed:") """ future = self._create_response_future(query, parameters, trace, custom_payload, timeout, execution_profile, paging_state) future._protocol_handler = self.client_protocol_handler self._on_request(future) future.send_request() return future def _create_response_future(self, query, parameters, trace, custom_payload, timeout, execution_profile=EXEC_PROFILE_DEFAULT, paging_state=None): """ Returns the ResponseFuture before calling send_request() on it """ prepared_statement = None if isinstance(query, six.string_types): query = SimpleStatement(query) elif isinstance(query, PreparedStatement): query = query.bind(parameters) if self.cluster._config_mode == _ConfigMode.LEGACY: if execution_profile is not EXEC_PROFILE_DEFAULT: raise ValueError("Cannot specify execution_profile while using legacy parameters.") if timeout is _NOT_SET: timeout = self.default_timeout cl = query.consistency_level if query.consistency_level is not None else self.default_consistency_level serial_cl = query.serial_consistency_level if query.serial_consistency_level is not None else self.default_serial_consistency_level retry_policy = query.retry_policy or self.cluster.default_retry_policy row_factory = self.row_factory load_balancing_policy = self.cluster.load_balancing_policy spec_exec_policy = None else: execution_profile = self._get_execution_profile(execution_profile) if timeout is _NOT_SET: timeout = execution_profile.request_timeout cl = query.consistency_level if query.consistency_level is not None else execution_profile.consistency_level serial_cl = query.serial_consistency_level if query.serial_consistency_level is not None else execution_profile.serial_consistency_level retry_policy = query.retry_policy or execution_profile.retry_policy row_factory = execution_profile.row_factory load_balancing_policy = execution_profile.load_balancing_policy spec_exec_policy = execution_profile.speculative_execution_policy fetch_size = query.fetch_size if fetch_size 
is FETCH_SIZE_UNSET and self._protocol_version >= 2: fetch_size = self.default_fetch_size elif self._protocol_version == 1: fetch_size = None start_time = time.time() if self._protocol_version >= 3 and self.use_client_timestamp: timestamp = int(start_time * 1e6) else: timestamp = None if isinstance(query, SimpleStatement): query_string = query.query_string if parameters: query_string = bind_params(query_string, parameters, self.encoder) message = QueryMessage( query_string, cl, serial_cl, fetch_size, timestamp=timestamp) elif isinstance(query, BoundStatement): prepared_statement = query.prepared_statement message = ExecuteMessage( prepared_statement.query_id, query.values, cl, serial_cl, fetch_size, timestamp=timestamp, skip_meta=bool(prepared_statement.result_metadata)) elif isinstance(query, BatchStatement): if self._protocol_version < 2: raise UnsupportedOperation( "BatchStatement execution is only supported with protocol version " "2 or higher (supported in Cassandra 2.0 and higher). 
Consider " "setting Cluster.protocol_version to 2 to support this operation.") message = BatchMessage( query.batch_type, query._statements_and_parameters, cl, serial_cl, timestamp) message.tracing = trace message.update_custom_payload(query.custom_payload) message.update_custom_payload(custom_payload) message.allow_beta_protocol_version = self.cluster.allow_beta_protocol_version message.paging_state = paging_state spec_exec_plan = spec_exec_policy.new_plan(query.keyspace or self.keyspace, query) if query.is_idempotent and spec_exec_policy else None return ResponseFuture( self, message, query, timeout, metrics=self._metrics, prepared_statement=prepared_statement, retry_policy=retry_policy, row_factory=row_factory, load_balancer=load_balancing_policy, start_time=start_time, speculative_execution_plan=spec_exec_plan) def _get_execution_profile(self, ep): profiles = self.cluster.profile_manager.profiles try: return ep if isinstance(ep, ExecutionProfile) else profiles[ep] except KeyError: raise ValueError("Invalid execution_profile: '%s'; valid profiles are %s" % (ep, profiles.keys())) def execution_profile_clone_update(self, ep, **kwargs): """ Returns a clone of the ``ep`` profile. ``kwargs`` can be specified to update attributes of the returned profile. This is a shollow clone, so any objects referenced by the profile are shared. This means Load Balancing Policy is maintained by inclusion in the active profiles. It also means updating any other rich objects will be seen by the active profile. In cases where this is not desirable, be sure to replace the instance instead of manipulating the shared object. """ clone = copy(self._get_execution_profile(ep)) for attr, value in kwargs.items(): setattr(clone, attr, value) return clone def add_request_init_listener(self, fn, *args, **kwargs): """ Adds a callback with arguments to be called when any request is created. 
It will be invoked as `fn(response_future, *args, **kwargs)` after each client request is created, and before the request is sent\*. This can be used to create extensions by adding result callbacks to the response future. \* where `response_future` is the :class:`.ResponseFuture` for the request. Note that the init callback is invoked on the client thread creating the request, so you may need to consider synchronization if you have multiple threads. Any callbacks added to the response future will be executed on the event loop thread, so the normal advice about minimizing cycles and avoiding blocking applies (see Note in :meth:`.ResponseFuture.add_callbacks`). See `this example `_ in the source tree for an example. """ self._request_init_callbacks.append((fn, args, kwargs)) def remove_request_init_listener(self, fn, *args, **kwargs): """ Removes a callback and arguments from the list. See :meth:`.Session.add_request_init_listener`. """ self._request_init_callbacks.remove((fn, args, kwargs)) def _on_request(self, response_future): for fn, args, kwargs in self._request_init_callbacks: fn(response_future, *args, **kwargs) def prepare(self, query, custom_payload=None): """ Prepares a query string, returning a :class:`~cassandra.query.PreparedStatement` instance which can be used as follows:: >>> session = cluster.connect("mykeyspace") >>> query = "INSERT INTO users (id, name, age) VALUES (?, ?, ?)" >>> prepared = session.prepare(query) >>> session.execute(prepared, (user.id, user.name, user.age)) Or you may bind values to the prepared statement ahead of time:: >>> prepared = session.prepare(query) >>> bound_stmt = prepared.bind((user.id, user.name, user.age)) >>> session.execute(bound_stmt) Of course, prepared statements may (and should) be reused:: >>> prepared = session.prepare(query) >>> for user in users: ... bound = prepared.bind((user.id, user.name, user.age)) ... session.execute(bound) **Important**: PreparedStatements should be prepared only once.
Preparing the same query more than once will likely degrade performance. `custom_payload` is a key-value map to be passed along with the prepare message. See :ref:`custom_payload`. """ message = PrepareMessage(query=query) message.update_custom_payload(custom_payload) future = ResponseFuture(self, message, query=None, timeout=self.default_timeout) try: future.send_request() query_id, bind_metadata, pk_indexes, result_metadata = future.result() except Exception: log.exception("Error preparing query:") raise prepared_statement = PreparedStatement.from_message( query_id, bind_metadata, pk_indexes, self.cluster.metadata, query, self.keyspace, self._protocol_version, result_metadata) prepared_statement.custom_payload = future.custom_payload self.cluster.add_prepared(query_id, prepared_statement) if self.cluster.prepare_on_all_hosts: host = future._current_host try: self.prepare_on_all_hosts(prepared_statement.query_string, host) except Exception: log.exception("Error preparing query on all hosts:") return prepared_statement def prepare_on_all_hosts(self, query, excluded_host): """ Prepare the given query on all hosts, excluding ``excluded_host``. Intended for internal use only. """ futures = [] for host in self._pools.keys(): if host != excluded_host and host.is_up: future = ResponseFuture(self, PrepareMessage(query=query), None, self.default_timeout) # we don't care about errors preparing against specific hosts, # since we can always prepare them as needed when the prepared # statement is used. Just log errors and continue on. try: request_id = future._query(host) except Exception: log.exception("Error preparing query for host %s:", host) continue if request_id is None: # the error has already been logged by ResponseFuture log.debug("Failed to prepare query for host %s: %r", host, future._errors.get(host)) continue futures.append((host, future)) for host, future in futures: try: future.result() except Exception: log.exception("Error preparing query for host %s:", host) def shutdown(self): """ Close all connections.
``Session`` instances should not be used for any purpose after being shut down. """ with self._lock: if self.is_shutdown: return else: self.is_shutdown = True for pool in list(self._pools.values()): pool.shutdown() def __enter__(self): return self def __exit__(self, *args): self.shutdown() def add_or_renew_pool(self, host, is_host_addition): """ For internal use only. """ distance = self._profile_manager.distance(host) if distance == HostDistance.IGNORED: return None def run_add_or_renew_pool(): try: if self._protocol_version >= 3: new_pool = HostConnection(host, distance, self) else: new_pool = HostConnectionPool(host, distance, self) except AuthenticationFailed as auth_exc: conn_exc = ConnectionException(str(auth_exc), host=host) self.cluster.signal_connection_failure(host, conn_exc, is_host_addition) return False except Exception as conn_exc: log.warning("Failed to create connection pool for new host %s:", host, exc_info=conn_exc) # the host itself will still be marked down, so we need to pass # a special flag to make sure the reconnector is created self.cluster.signal_connection_failure( host, conn_exc, is_host_addition, expect_host_to_be_down=True) return False previous = self._pools.get(host) with self._lock: while new_pool._keyspace != self.keyspace: self._lock.release() set_keyspace_event = Event() errors_returned = [] def callback(pool, errors): errors_returned.extend(errors) set_keyspace_event.set() new_pool._set_keyspace_for_all_conns(self.keyspace, callback) set_keyspace_event.wait(self.cluster.connect_timeout) if not set_keyspace_event.is_set() or errors_returned: log.warning("Failed setting keyspace for pool after keyspace changed during connect: %s", errors_returned) self.cluster.on_down(host, is_host_addition) new_pool.shutdown() self._lock.acquire() return False self._lock.acquire() self._pools[host] = new_pool log.debug("Added pool for host %s to session", host) if previous: previous.shutdown() return True return self.submit(run_add_or_renew_pool)
def remove_pool(self, host): pool = self._pools.pop(host, None) if pool: log.debug("Removed connection pool for %r", host) return self.submit(pool.shutdown) else: return None def update_created_pools(self): """ When the set of live nodes changes, the load balancer may change its mind about host distances. This can affect the node that came up or went down, but also other nodes (for instance, if a node dies, a previously ignored node may now be considered). This method ensures that all hosts for which a pool should exist have one, and hosts that shouldn't have one don't. For internal use only. """ futures = set() for host in self.cluster.metadata.all_hosts(): distance = self._profile_manager.distance(host) pool = self._pools.get(host) future = None if not pool or pool.is_shutdown: # we don't eagerly set is_up on previously ignored hosts. None is included here # to allow us to attempt connections to hosts that have gone from ignored to something # else. if distance != HostDistance.IGNORED and host.is_up in (True, None): future = self.add_or_renew_pool(host, False) elif distance != pool.host_distance: # the distance has changed if distance == HostDistance.IGNORED: future = self.remove_pool(host) else: pool.host_distance = distance if future: futures.add(future) return futures def on_down(self, host): """ Called by the parent Cluster instance when a node is marked down. Only intended for internal use. """ future = self.remove_pool(host) if future: future.add_done_callback(lambda f: self.update_created_pools()) def on_remove(self, host): """ Internal """ self.on_down(host) def set_keyspace(self, keyspace): """ Set the default keyspace for all queries made through this Session. This operation blocks until complete. """ self.execute('USE %s' % (protect_name(keyspace),)) def _set_keyspace_for_all_pools(self, keyspace, callback): """ Asynchronously sets the keyspace on all pools.
When all pools have set all of their connections, `callback` will be called with a dictionary of all errors that occurred, keyed by the `Host` that they occurred against. """ with self._lock: self.keyspace = keyspace remaining_callbacks = set(self._pools.values()) errors = {} if not remaining_callbacks: callback(errors) return def pool_finished_setting_keyspace(pool, host_errors): remaining_callbacks.remove(pool) if host_errors: errors[pool.host] = host_errors if not remaining_callbacks: callback(errors) for pool in self._pools.values(): pool._set_keyspace_for_all_conns(keyspace, pool_finished_setting_keyspace) def user_type_registered(self, keyspace, user_type, klass): """ Called by the parent Cluster instance when the user registers a new mapping from a user-defined type to a class. Intended for internal use only. """ try: ks_meta = self.cluster.metadata.keyspaces[keyspace] except KeyError: raise UserTypeDoesNotExist( 'Keyspace %s does not exist or has not been discovered by the driver' % (keyspace,)) try: type_meta = ks_meta.user_types[user_type] except KeyError: raise UserTypeDoesNotExist( 'User type %s does not exist in keyspace %s' % (user_type, keyspace)) field_names = type_meta.field_names if six.PY2: # go from unicode to string to avoid decode errors from implicit # decode when formatting non-ascii values field_names = [fn.encode('utf-8') for fn in field_names] def encode(val): return '{ %s }' % ' , '.join('%s : %s' % ( field_name, self.encoder.cql_encode_all_types(getattr(val, field_name, None)) ) for field_name in field_names) self.encoder.mapping[klass] = encode def submit(self, fn, *args, **kwargs): """ Internal """ if not self.is_shutdown: return self.cluster.executor.submit(fn, *args, **kwargs) def get_pool_state(self): return dict((host, pool.get_state()) for host, pool in self._pools.items()) def get_pools(self): return self._pools.values() def _validate_set_legacy_config(self, attr_name, value): if self.cluster._config_mode ==
_ConfigMode.PROFILES: raise ValueError("Cannot set Session.%s while using Configuration Profiles. Set this in a profile instead." % (attr_name,)) setattr(self, '_' + attr_name, value) self.cluster._config_mode = _ConfigMode.LEGACY class UserTypeDoesNotExist(Exception): """ An attempt was made to use a user-defined type that does not exist. .. versionadded:: 2.1.0 """ pass class _ControlReconnectionHandler(_ReconnectionHandler): """ Internal """ def __init__(self, control_connection, *args, **kwargs): _ReconnectionHandler.__init__(self, *args, **kwargs) self.control_connection = weakref.proxy(control_connection) def try_reconnect(self): return self.control_connection._reconnect_internal() def on_reconnection(self, connection): self.control_connection._set_new_connection(connection) def on_exception(self, exc, next_delay): # TODO only overridden to add logging, so add logging if isinstance(exc, AuthenticationFailed): return False else: log.debug("Error trying to reconnect control connection: %r", exc) return True def _watch_callback(obj_weakref, method_name, *args, **kwargs): """ A callback handler for the ControlConnection that tolerates weak references. """ obj = obj_weakref() if obj is None: return getattr(obj, method_name)(*args, **kwargs) def _clear_watcher(conn, expiring_weakref): """ Called when the ControlConnection object is about to be finalized. This clears watchers on the underlying Connection object. 
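The two-way weak-reference pattern used by ``_clear_watcher`` and ``_watch_callback`` can be illustrated with the standard library alone. This is a minimal sketch, not driver code: ``Owner`` and ``Conn`` are hypothetical stand-ins for ``ControlConnection`` and ``Connection``. The finalizer passed to ``weakref.ref`` receives the dying weakref as its last argument, which is why ``_clear_watcher`` accepts an ``expiring_weakref`` parameter.

```python
import gc
import weakref
from functools import partial


class Conn(object):
    # Hypothetical stand-in for the underlying Connection.
    def __init__(self):
        self.watchers_cleared = False

    def control_conn_disposed(self):
        self.watchers_cleared = True


received = []


class Owner(object):
    # Hypothetical stand-in for ControlConnection.
    def handle(self, event):
        received.append(event)


def watch_callback(obj_weakref, method_name, *args):
    # Relay an event only while the owner is still alive.
    obj = obj_weakref()
    if obj is None:
        return
    getattr(obj, method_name)(*args)


def clear_watcher(conn, expiring_weakref):
    # Called by the weakref machinery when the owner is about to be finalized.
    try:
        conn.control_conn_disposed()
    except ReferenceError:
        pass  # the connection itself is already gone


conn = Conn()
owner = Owner()
owner_ref = weakref.ref(owner, partial(clear_watcher, weakref.proxy(conn)))
relay = partial(watch_callback, owner_ref, 'handle')

relay('TOPOLOGY_CHANGE')  # delivered: the owner is alive
del owner
gc.collect()  # CPython frees immediately; this helps runtimes without refcounting
relay('STATUS_CHANGE')  # silently dropped: the weakref is now dead
```

Because the relay only holds a weak reference to its owner, the owner can be garbage-collected while the connection still has watchers registered, and the finalizer gives the connection a chance to clean up.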
""" try: conn.control_conn_disposed() except ReferenceError: pass class ControlConnection(object): """ Internal """ _SELECT_PEERS = "SELECT * FROM system.peers" _SELECT_PEERS_NO_TOKENS = "SELECT peer, data_center, rack, rpc_address, release_version, schema_version FROM system.peers" _SELECT_LOCAL = "SELECT * FROM system.local WHERE key='local'" _SELECT_LOCAL_NO_TOKENS = "SELECT cluster_name, data_center, rack, partitioner, release_version, schema_version FROM system.local WHERE key='local'" _SELECT_SCHEMA_PEERS = "SELECT peer, rpc_address, schema_version FROM system.peers" _SELECT_SCHEMA_LOCAL = "SELECT schema_version FROM system.local WHERE key='local'" _is_shutdown = False _timeout = None _protocol_version = None _schema_event_refresh_window = None _topology_event_refresh_window = None _status_event_refresh_window = None _schema_meta_enabled = True _token_meta_enabled = True # for testing purposes _time = time def __init__(self, cluster, timeout, schema_event_refresh_window, topology_event_refresh_window, status_event_refresh_window, schema_meta_enabled=True, token_meta_enabled=True): # use a weak reference to allow the Cluster instance to be GC'ed (and # shutdown) since implementing __del__ disables the cycle detector self._cluster = weakref.proxy(cluster) self._connection = None self._timeout = timeout self._schema_event_refresh_window = schema_event_refresh_window self._topology_event_refresh_window = topology_event_refresh_window self._status_event_refresh_window = status_event_refresh_window self._schema_meta_enabled = schema_meta_enabled self._token_meta_enabled = token_meta_enabled self._lock = RLock() self._schema_agreement_lock = Lock() self._reconnection_handler = None self._reconnection_lock = RLock() self._event_schedule_times = {} def connect(self): if self._is_shutdown: return self._protocol_version = self._cluster.protocol_version self._set_new_connection(self._reconnect_internal()) def _set_new_connection(self, conn): """ Replace existing 
connection (if there is one) and close it. """ with self._lock: old = self._connection self._connection = conn if old: log.debug("[control connection] Closing old connection %r, replacing with %r", old, conn) old.close() def _reconnect_internal(self): """ Tries to connect to each host in the query plan until one succeeds or every attempt fails. If successful, a new Connection will be returned. Otherwise, :exc:`NoHostAvailable` will be raised with an "errors" arg that is a dict mapping host addresses to the exception that was raised when an attempt was made to open a connection to that host. """ errors = {} for host in self._cluster._default_load_balancing_policy.make_query_plan(): try: return self._try_connect(host) except ConnectionException as exc: errors[host.address] = exc log.warning("[control connection] Error connecting to %s:", host, exc_info=True) self._cluster.signal_connection_failure(host, exc, is_host_addition=False) except Exception as exc: errors[host.address] = exc log.warning("[control connection] Error connecting to %s:", host, exc_info=True) if self._is_shutdown: raise DriverException("[control connection] Reconnection in progress during shutdown") raise NoHostAvailable("Unable to connect to any servers", errors) def _try_connect(self, host): """ Creates a new Connection, registers for pushed events, and refreshes node/token and schema metadata. 
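The ``while True`` loop that opens the connection retries after ``ProtocolVersionUnsupported``, asking the cluster to downgrade the protocol version first. A hedged sketch of that negotiate-by-downgrade idea follows; ``VersionUnsupported``, ``connect_with_downgrade`` and ``fake_connect`` are hypothetical, and the explicit ``min_version`` floor is an assumption added so the sketch cannot loop forever (the driver delegates the downgrade itself to ``Cluster.protocol_downgrade``).

```python
class VersionUnsupported(Exception):
    # Hypothetical stand-in for ProtocolVersionUnsupported.
    def __init__(self, startup_version):
        Exception.__init__(self, startup_version)
        self.startup_version = startup_version


def connect_with_downgrade(connect, version, min_version=1):
    # Keep retrying, one protocol version lower each time the server
    # rejects the version that was offered.
    while True:
        try:
            return connect(version), version
        except VersionUnsupported as e:
            if e.startup_version <= min_version:
                raise  # nothing lower left to try
            version = e.startup_version - 1


def fake_connect(version, server_max=3):
    # Hypothetical server that only speaks protocol versions <= server_max.
    if version > server_max:
        raise VersionUnsupported(version)
    return "conn-v%d" % version


conn, negotiated = connect_with_downgrade(fake_connect, 4)
```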
""" log.debug("[control connection] Opening new connection to %s", host) while True: try: connection = self._cluster.connection_factory(host.address, is_control_connection=True) if self._is_shutdown: connection.close() raise DriverException("Reconnecting during shutdown") break except ProtocolVersionUnsupported as e: self._cluster.protocol_downgrade(host.address, e.startup_version) log.debug("[control connection] Established new connection %r, " "registering watchers and refreshing schema and topology", connection) # use weak references in both directions # _clear_watcher will be called when this ControlConnection is about to be finalized # _watch_callback will get the actual callback from the Connection and relay it to # this object (after a dereferencing a weakref) self_weakref = weakref.ref(self, partial(_clear_watcher, weakref.proxy(connection))) try: connection.register_watchers({ "TOPOLOGY_CHANGE": partial(_watch_callback, self_weakref, '_handle_topology_change'), "STATUS_CHANGE": partial(_watch_callback, self_weakref, '_handle_status_change'), "SCHEMA_CHANGE": partial(_watch_callback, self_weakref, '_handle_schema_change') }, register_timeout=self._timeout) sel_peers = self._SELECT_PEERS if self._token_meta_enabled else self._SELECT_PEERS_NO_TOKENS sel_local = self._SELECT_LOCAL if self._token_meta_enabled else self._SELECT_LOCAL_NO_TOKENS peers_query = QueryMessage(query=sel_peers, consistency_level=ConsistencyLevel.ONE) local_query = QueryMessage(query=sel_local, consistency_level=ConsistencyLevel.ONE) shared_results = connection.wait_for_responses( peers_query, local_query, timeout=self._timeout) self._refresh_node_list_and_token_map(connection, preloaded_results=shared_results) self._refresh_schema(connection, preloaded_results=shared_results, schema_agreement_wait=-1) except Exception: connection.close() raise return connection def reconnect(self): if self._is_shutdown: return self._submit(self._reconnect) def _reconnect(self): log.debug("[control 
connection] Attempting to reconnect") try: self._set_new_connection(self._reconnect_internal()) except NoHostAvailable: # make a retry schedule (which includes backoff) schedule = self._cluster.reconnection_policy.new_schedule() with self._reconnection_lock: # cancel existing reconnection attempts if self._reconnection_handler: self._reconnection_handler.cancel() # when a connection is successfully made, _set_new_connection # will be called with the new connection and then our # _reconnection_handler will be cleared out self._reconnection_handler = _ControlReconnectionHandler( self, self._cluster.scheduler, schedule, self._get_and_set_reconnection_handler, new_handler=None) self._reconnection_handler.start() except Exception: log.debug("[control connection] error reconnecting", exc_info=True) raise def _get_and_set_reconnection_handler(self, new_handler): """ Called by the _ControlReconnectionHandler when a new connection is successfully created. Clears out the _reconnection_handler on this ControlConnection.
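The method body that follows reads the old handler and stores the new one while holding ``_reconnection_lock``, presumably so that this swap cannot interleave with ``_reconnect`` replacing the handler. A minimal sketch of that get-and-set idiom, using a hypothetical ``HandlerSlot`` class rather than any driver type:

```python
import threading


class HandlerSlot(object):
    # Hypothetical holder mirroring ControlConnection._reconnection_handler.
    def __init__(self):
        self._lock = threading.RLock()
        self._handler = None

    def get_and_set(self, new_handler):
        # Read the old value and store the new one under one lock, so no two
        # callers can ever receive the same previous handler.
        with self._lock:
            old = self._handler
            self._handler = new_handler
            return old


slot = HandlerSlot()
previous = slot.get_and_set('handler-1')  # None: the slot started empty
```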
""" with self._reconnection_lock: old = self._reconnection_handler self._reconnection_handler = new_handler return old def _submit(self, *args, **kwargs): try: if not self._cluster.is_shutdown: return self._cluster.executor.submit(*args, **kwargs) except ReferenceError: pass return None def shutdown(self): # stop trying to reconnect (if we are) with self._reconnection_lock: if self._reconnection_handler: self._reconnection_handler.cancel() with self._lock: if self._is_shutdown: return else: self._is_shutdown = True log.debug("Shutting down control connection") if self._connection: self._connection.close() self._connection = None def refresh_schema(self, force=False, **kwargs): try: if self._connection: return self._refresh_schema(self._connection, force=force, **kwargs) except ReferenceError: pass # our weak reference to the Cluster is no good except Exception: log.debug("[control connection] Error refreshing schema", exc_info=True) self._signal_error() return False def _refresh_schema(self, connection, preloaded_results=None, schema_agreement_wait=None, force=False, **kwargs): if self._cluster.is_shutdown: return False agreed = self.wait_for_schema_agreement(connection, preloaded_results=preloaded_results, wait_time=schema_agreement_wait) if not self._schema_meta_enabled and not force: log.debug("[control connection] Skipping schema refresh because schema metadata is disabled") return False if not agreed: log.debug("Skipping schema refresh due to lack of schema agreement") return False self._cluster.metadata.refresh(connection, self._timeout, **kwargs) return True def refresh_node_list_and_token_map(self, force_token_rebuild=False): try: if self._connection: self._refresh_node_list_and_token_map(self._connection, force_token_rebuild=force_token_rebuild) return True except ReferenceError: pass # our weak reference to the Cluster is no good except Exception: log.debug("[control connection] Error refreshing node list and token map", exc_info=True) 
self._signal_error() return False def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, force_token_rebuild=False): if preloaded_results: log.debug("[control connection] Refreshing node list and token map using preloaded results") peers_result = preloaded_results[0] local_result = preloaded_results[1] else: cl = ConsistencyLevel.ONE if not self._token_meta_enabled: log.debug("[control connection] Refreshing node list without token map") sel_peers = self._SELECT_PEERS_NO_TOKENS sel_local = self._SELECT_LOCAL_NO_TOKENS else: log.debug("[control connection] Refreshing node list and token map") sel_peers = self._SELECT_PEERS sel_local = self._SELECT_LOCAL peers_query = QueryMessage(query=sel_peers, consistency_level=cl) local_query = QueryMessage(query=sel_local, consistency_level=cl) peers_result, local_result = connection.wait_for_responses( peers_query, local_query, timeout=self._timeout) peers_result = dict_factory(*peers_result.results) partitioner = None token_map = {} found_hosts = set() if local_result.results: found_hosts.add(connection.host) local_rows = dict_factory(*(local_result.results)) local_row = local_rows[0] cluster_name = local_row["cluster_name"] self._cluster.metadata.cluster_name = cluster_name partitioner = local_row.get("partitioner") tokens = local_row.get("tokens") host = self._cluster.metadata.get_host(connection.host) if host: datacenter = local_row.get("data_center") rack = local_row.get("rack") self._update_location_info(host, datacenter, rack) host.listen_address = local_row.get("listen_address") host.broadcast_address = local_row.get("broadcast_address") host.release_version = local_row.get("release_version") host.dse_version = local_row.get("dse_version") host.dse_workload = local_row.get("workload") if partitioner and tokens: token_map[host] = tokens # Check metadata.partitioner to see if we haven't built anything yet. 
        # If every node in the cluster was in the contact points, we won't discover
        # any new nodes, so we need this additional check. (See PYTHON-90)
        should_rebuild_token_map = force_token_rebuild or self._cluster.metadata.partitioner is None
        for row in peers_result:
            addr = self._rpc_from_peer_row(row)

            tokens = row.get("tokens", None)
            if 'tokens' in row and not tokens:  # it was selected, but empty
                log.warning("Excluding host (%s) with no tokens in system.peers table of %s.",
                            addr, connection.host)
                continue
            if addr in found_hosts:
                log.warning("Found multiple hosts with the same rpc_address (%s). Excluding peer %s",
                            addr, row.get("peer"))
                continue

            found_hosts.add(addr)

            host = self._cluster.metadata.get_host(addr)
            datacenter = row.get("data_center")
            rack = row.get("rack")
            if host is None:
                log.debug("[control connection] Found new host to connect to: %s", addr)
                host, _ = self._cluster.add_host(addr, datacenter, rack, signal=True, refresh_nodes=False)
                should_rebuild_token_map = True
            else:
                should_rebuild_token_map |= self._update_location_info(host, datacenter, rack)

            host.broadcast_address = row.get("peer")
            host.release_version = row.get("release_version")
            host.dse_version = row.get("dse_version")
            host.dse_workload = row.get("workload")

            if partitioner and tokens:
                token_map[host] = tokens

        for old_host in self._cluster.metadata.all_hosts():
            if old_host.address != connection.host and old_host.address not in found_hosts:
                should_rebuild_token_map = True
                log.debug("[control connection] Removing host not found in peers metadata: %r", old_host)
                self._cluster.remove_host(old_host)

        log.debug("[control connection] Finished fetching ring info")
        if partitioner and should_rebuild_token_map:
            log.debug("[control connection] Rebuilding token map due to topology changes")
            self._cluster.metadata.rebuild_token_map(partitioner, token_map)

    def _update_location_info(self, host, datacenter, rack):
        if host.datacenter == datacenter and host.rack == rack:
            return False

        # If the dc/rack information
        # changes, we need to update the load balancing policy.
        # For that, we remove and re-add the node against the policy. Not the most elegant, and assumes
        # that the policy will update correctly, but in practice this should work.
        self._cluster.profile_manager.on_down(host)
        host.set_location_info(datacenter, rack)
        self._cluster.profile_manager.on_up(host)
        return True

    def _delay_for_event_type(self, event_type, delay_window):
        # this serves to order the processing of correlated events (received within the window)
        # the window and randomization still have the desired effect of skew across client instances
        next_time = self._event_schedule_times.get(event_type, 0)
        now = self._time.time()
        if now <= next_time:
            this_time = next_time + 0.01
            delay = this_time - now
        else:
            delay = random() * delay_window
            this_time = now + delay
        self._event_schedule_times[event_type] = this_time
        return delay

    def _refresh_nodes_if_not_up(self, addr):
        """
        Used to mitigate refreshes for nodes that are already known.
        Some versions of the server send superfluous NEW_NODE messages in addition to UP events.
        """
        host = self._cluster.metadata.get_host(addr)
        if not host or not host.is_up:
            self.refresh_node_list_and_token_map()

    def _handle_topology_change(self, event):
        change_type = event["change_type"]
        addr = self._translate_address(event["address"][0])
        if change_type == "NEW_NODE" or change_type == "MOVED_NODE":
            if self._topology_event_refresh_window >= 0:
                delay = self._delay_for_event_type('topology_change', self._topology_event_refresh_window)
                self._cluster.scheduler.schedule_unique(delay, self._refresh_nodes_if_not_up, addr)
        elif change_type == "REMOVED_NODE":
            host = self._cluster.metadata.get_host(addr)
            self._cluster.scheduler.schedule_unique(0, self._cluster.remove_host, host)

    def _handle_status_change(self, event):
        change_type = event["change_type"]
        addr = self._translate_address(event["address"][0])
        host = self._cluster.metadata.get_host(addr)
        if change_type == "UP":
            delay = self._delay_for_event_type('status_change', self._status_event_refresh_window)
            if host is None:
                # this is the first time we've seen the node
                self._cluster.scheduler.schedule_unique(delay, self.refresh_node_list_and_token_map)
            else:
                self._cluster.scheduler.schedule_unique(delay, self._cluster.on_up, host)
        elif change_type == "DOWN":
            # Note that there is a slight risk we can receive the event late and thus
            # mark the host down even though we already had reconnected successfully.
            # But it is unlikely, and it has little consequence since we'll try reconnecting
            # right away, so we favor the detection to make Host.is_up more accurate.
            if host is not None:
                # this will be run by the scheduler
                self._cluster.on_down(host, is_host_addition=False)

    def _translate_address(self, addr):
        return self._cluster.address_translator.translate(addr)

    def _handle_schema_change(self, event):
        if self._schema_event_refresh_window < 0:
            return
        delay = self._delay_for_event_type('schema_change', self._schema_event_refresh_window)
        self._cluster.scheduler.schedule_unique(delay, self.refresh_schema, **event)

    def wait_for_schema_agreement(self, connection=None, preloaded_results=None, wait_time=None):

        total_timeout = wait_time if wait_time is not None else self._cluster.max_schema_agreement_wait
        if total_timeout <= 0:
            return True

        # Each schema change typically generates two schema refreshes, one
        # from the response type and one from the pushed notification. Holding
        # a lock is just a simple way to cut down on the number of schema queries
        # we'll make.
        with self._schema_agreement_lock:
            if self._is_shutdown:
                return

            if not connection:
                connection = self._connection

            if preloaded_results:
                log.debug("[control connection] Attempting to use preloaded results for schema agreement")

                peers_result = preloaded_results[0]
                local_result = preloaded_results[1]
                schema_mismatches = self._get_schema_mismatches(peers_result, local_result, connection.host)
                if schema_mismatches is None:
                    return True

            log.debug("[control connection] Waiting for schema agreement")
            start = self._time.time()
            elapsed = 0
            cl = ConsistencyLevel.ONE
            schema_mismatches = None
            while elapsed < total_timeout:
                peers_query = QueryMessage(query=self._SELECT_SCHEMA_PEERS, consistency_level=cl)
                local_query = QueryMessage(query=self._SELECT_SCHEMA_LOCAL, consistency_level=cl)
                try:
                    timeout = min(self._timeout, total_timeout - elapsed)
                    peers_result, local_result = connection.wait_for_responses(
                        peers_query, local_query, timeout=timeout)
                except OperationTimedOut as timeout:
                    log.debug("[control connection] Timed out waiting for "
                              "response during schema agreement check: %s", timeout)
                    elapsed = self._time.time() - start
                    continue
                except ConnectionShutdown:
                    if self._is_shutdown:
                        log.debug("[control connection] Aborting wait for schema match due to shutdown")
                        return None
                    else:
                        raise

                schema_mismatches = self._get_schema_mismatches(peers_result, local_result, connection.host)
                if schema_mismatches is None:
                    return True

                log.debug("[control connection] Schemas mismatched, trying again")
                self._time.sleep(0.2)
                elapsed = self._time.time() - start

            log.warning("Node %s is reporting a schema disagreement: %s",
                        connection.host, schema_mismatches)
            return False

    def _get_schema_mismatches(self, peers_result, local_result, local_address):
        peers_result = dict_factory(*peers_result.results)

        versions = defaultdict(set)
        if local_result.results:
            local_row = dict_factory(*local_result.results)[0]
            if local_row.get("schema_version"):
                versions[local_row.get("schema_version")].add(local_address)

        for row in peers_result:
            schema_ver = row.get('schema_version')
            if not schema_ver:
                continue
            addr = self._rpc_from_peer_row(row)
            peer = self._cluster.metadata.get_host(addr)
            if peer and peer.is_up is not False:
                versions[schema_ver].add(addr)

        if len(versions) == 1:
            log.debug("[control connection] Schemas match")
            return None

        return dict((version, list(nodes)) for version, nodes in six.iteritems(versions))

    def _rpc_from_peer_row(self, row):
        addr = row.get("rpc_address")
        if not addr or addr in ["0.0.0.0", "::"]:
            addr = row.get("peer")
        return self._translate_address(addr)

    def _signal_error(self):
        with self._lock:
            if self._is_shutdown:
                return

            # try just signaling the cluster, as this will trigger a reconnect
            # as part of marking the host down
            if self._connection and self._connection.is_defunct:
                host = self._cluster.metadata.get_host(self._connection.host)
                # host may be None if it's already been removed, but that indicates
                # that errors have already been reported, so we're fine
                if host:
                    self._cluster.signal_connection_failure(
                        host, self._connection.last_error, is_host_addition=False)
                    return

        # if the connection is not defunct or the host already left, reconnect
        # manually
        self.reconnect()

    def on_up(self, host):
        pass

    def on_down(self, host):

        conn = self._connection
        if conn and conn.host == host.address and \
                self._reconnection_handler is None:
            log.debug("[control connection] Control connection host (%s) is "
                      "considered down, starting reconnection", host)
            # this will result in a task being submitted to the executor to reconnect
            self.reconnect()

    def on_add(self, host, refresh_nodes=True):
        if refresh_nodes:
            self.refresh_node_list_and_token_map(force_token_rebuild=True)

    def on_remove(self, host):
        c = self._connection
        if c and c.host == host.address:
            log.debug("[control connection] Control connection host (%s) is being removed. Reconnecting", host)
            # refresh will be done on reconnect
            self.reconnect()
        else:
            self.refresh_node_list_and_token_map(force_token_rebuild=True)

    def get_connections(self):
        c = getattr(self, '_connection', None)
        return [c] if c else []

    def return_connection(self, connection):
        if connection is self._connection and (connection.is_defunct or connection.is_closed):
            self.reconnect()


def _stop_scheduler(scheduler, thread):
    try:
        if not scheduler.is_shutdown:
            scheduler.shutdown()
    except ReferenceError:
        pass

    thread.join()


class _Scheduler(Thread):

    _queue = None
    _scheduled_tasks = None
    _executor = None
    is_shutdown = False

    def __init__(self, executor):
        self._queue = Queue.PriorityQueue()
        self._scheduled_tasks = set()
        self._count = count()
        self._executor = executor

        Thread.__init__(self, name="Task Scheduler")
        self.daemon = True
        self.start()

    def shutdown(self):
        try:
            log.debug("Shutting down Cluster Scheduler")
        except AttributeError:
            # this can happen on interpreter shutdown
            pass
        self.is_shutdown = True
        self._queue.put_nowait((0, 0, None))
        self.join()

    def schedule(self, delay, fn, *args, **kwargs):
        self._insert_task(delay, (fn, args, tuple(kwargs.items())))

    def schedule_unique(self, delay, fn, *args, **kwargs):
        task = (fn, args,
                tuple(kwargs.items()))
        if task not in self._scheduled_tasks:
            self._insert_task(delay, task)
        else:
            log.debug("Ignoring schedule_unique for already-scheduled task: %r", task)

    def _insert_task(self, delay, task):
        if not self.is_shutdown:
            run_at = time.time() + delay
            self._scheduled_tasks.add(task)
            self._queue.put_nowait((run_at, next(self._count), task))
        else:
            log.debug("Ignoring scheduled task after shutdown: %r", task)

    def run(self):
        while True:
            if self.is_shutdown:
                return

            try:
                while True:
                    run_at, i, task = self._queue.get(block=True, timeout=None)
                    if self.is_shutdown:
                        if task:
                            log.debug("Not executing scheduled task due to Scheduler shutdown")
                        return
                    if run_at <= time.time():
                        self._scheduled_tasks.discard(task)
                        fn, args, kwargs = task
                        kwargs = dict(kwargs)
                        future = self._executor.submit(fn, *args, **kwargs)
                        future.add_done_callback(self._log_if_failed)
                    else:
                        self._queue.put_nowait((run_at, i, task))
                        break
            except Queue.Empty:
                pass

            time.sleep(0.1)

    def _log_if_failed(self, future):
        exc = future.exception()
        if exc:
            log.warning(
                "An internally scheduled task failed with an unhandled exception:",
                exc_info=exc)


def refresh_schema_and_set_result(control_conn, response_future, connection, **kwargs):
    try:
        log.debug("Refreshing schema in response to schema change. "
                  "%s", kwargs)
        response_future.is_schema_agreed = control_conn._refresh_schema(connection, **kwargs)
    except Exception:
        log.exception("Exception refreshing schema in response to schema change:")
        response_future.session.submit(control_conn.refresh_schema, **kwargs)
    finally:
        response_future._set_final_result(None)


class ResponseFuture(object):
    """
    An asynchronous response delivery mechanism that is returned from calls
    to :meth:`.Session.execute_async()`.

    There are two ways for results to be delivered:
     - Synchronously, by calling :meth:`.result()`
     - Asynchronously, by attaching callback and errback functions via
       :meth:`.add_callback()`, :meth:`.add_errback()`, and
       :meth:`.add_callbacks()`.
    """

    query = None
    """
    The :class:`~.Statement` instance that is being executed through this
    :class:`.ResponseFuture`.
    """

    is_schema_agreed = True
    """
    For DDL requests, this may be set ``False`` if the schema agreement poll after the response fails.

    Always ``True`` for non-DDL requests.
    """

    request_encoded_size = None
    """
    Size, in bytes, of the encoded request message that was sent
    """

    coordinator_host = None
    """
    The host from which we received a response
    """

    attempted_hosts = None
    """
    A list of hosts tried, including all speculative executions, retries, and pages
    """

    session = None
    row_factory = None
    message = None
    default_timeout = None

    _retry_policy = None
    _profile_manager = None

    _req_id = None
    _final_result = _NOT_SET
    _col_names = None
    _final_exception = None
    _query_traces = None
    _callbacks = None
    _errbacks = None
    _current_host = None
    _connection = None
    _query_retries = 0
    _start_time = None
    _metrics = None
    _paging_state = None
    _custom_payload = None
    _warnings = None
    _timer = None
    _protocol_handler = ProtocolHandler
    _spec_execution_plan = NoSpeculativeExecutionPlan()

    _warned_timeout = False

    def __init__(self, session, message, query, timeout, metrics=None, prepared_statement=None,
                 retry_policy=RetryPolicy(), row_factory=None, load_balancer=None, start_time=None,
                 speculative_execution_plan=None):
        self.session = session
        # TODO: normalize handling of retry policy and row factory
        self.row_factory = row_factory or session.row_factory
        self._load_balancer = load_balancer or session.cluster._default_load_balancing_policy
        self.message = message
        self.query = query
        self.timeout = timeout
        self._time_remaining = timeout
        self._retry_policy = retry_policy
        self._metrics = metrics
        self.prepared_statement = prepared_statement
        self._callback_lock = Lock()
        self._start_time = start_time or time.time()
        self._make_query_plan()
        self._event = Event()
        self._errors = {}
        self._callbacks = []
        self._errbacks = []
        self._spec_execution_plan = speculative_execution_plan or self._spec_execution_plan
        self.attempted_hosts = []

    def _start_timer(self):
        if self._timer is None:
            spec_delay = self._spec_execution_plan.next_execution(self._current_host)
            if spec_delay >= 0:
                if self._time_remaining is None or self._time_remaining > spec_delay:
                    self._timer = self.session.cluster.connection_class.create_timer(spec_delay, self._on_speculative_execute)
                    return
            if self._time_remaining is not None:
                self._timer = self.session.cluster.connection_class.create_timer(self._time_remaining, self._on_timeout)

    def _cancel_timer(self):
        if self._timer:
            self._timer.cancel()

    def _on_timeout(self):
        errors = self._errors
        if not errors:
            if self.is_schema_agreed:
                errors = {self._current_host.address: "Client request timeout. See Session.execute[_async](timeout)"}
            else:
                connection = self.session.cluster.control_connection._connection
                host = connection.host if connection else 'unknown'
                errors = {host: "Request timed out while waiting for schema agreement. See Session.execute[_async](timeout) and Cluster.max_schema_agreement_wait."}

        self._set_final_exception(OperationTimedOut(errors, self._current_host))

    def _on_speculative_execute(self):
        self._timer = None
        if not self._event.is_set():
            if self._time_remaining is not None:
                elapsed = time.time() - self._start_time
                self._time_remaining -= elapsed
                if self._time_remaining <= 0:
                    self._on_timeout()
                    return
            if not self.send_request(error_no_hosts=False):
                self._start_timer()

    def _make_query_plan(self):
        # convert the list/generator/etc to an iterator so that subsequent
        # calls to send_request (which retries may do) will resume where
        # they last left off
        self.query_plan = iter(self._load_balancer.make_query_plan(self.session.keyspace, self.query))

    def send_request(self, error_no_hosts=True):
        """ Internal """
        # query_plan is an iterator, so this will resume where we last left
        # off if send_request() is called multiple times
        for host in self.query_plan:
            req_id = self._query(host)
            if req_id is not None:
                self._req_id = req_id

                # timer is only started here, after we have at least one
                # message queued
                # this is done to avoid overrun of timers with unfettered client requests
                # in the case of full disconnect, where no hosts will be available
                self._start_timer()
                return True
            if self.timeout is not None and time.time() - self._start_time > self.timeout:
                self._on_timeout()
                return True

        if error_no_hosts:
            self._set_final_exception(NoHostAvailable(
                "Unable to complete the operation against any hosts", self._errors))
        return False

    def _query(self, host, message=None, cb=None):
        if message is None:
            message = self.message

        pool = self.session._pools.get(host)
        if not pool:
            self._errors[host] = ConnectionException("Host has been marked down or removed")
            return None
        elif pool.is_shutdown:
            self._errors[host] = ConnectionException("Pool is shutdown")
            return None

        self._current_host = host

        connection = None
        try:
            # TODO get connectTimeout from cluster settings
            connection, request_id = pool.borrow_connection(timeout=2.0)
            self._connection = connection
            result_meta = self.prepared_statement.result_metadata if self.prepared_statement else []

            if cb is None:
                cb = partial(self._set_result, host, connection, pool)

            self.request_encoded_size = connection.send_msg(message, request_id, cb=cb,
                                                            encoder=self._protocol_handler.encode_message,
                                                            decoder=self._protocol_handler.decode_message,
                                                            result_metadata=result_meta)
            self.attempted_hosts.append(host)
            return request_id
        except NoConnectionsAvailable as exc:
            log.debug("All connections for host %s are at capacity, moving to the next host", host)
            self._errors[host] = exc
            return None
        except Exception as exc:
            log.debug("Error querying host %s", host, exc_info=True)
            self._errors[host] = exc
            if self._metrics is not None:
                self._metrics.on_connection_error()
            if connection:
                pool.return_connection(connection)
            return None

    @property
    def has_more_pages(self):
        """
        Returns :const:`True` if there are more pages left in the
        query results, :const:`False` otherwise.  This should only
        be checked after the first page has been returned.

        .. versionadded:: 2.0.0
        """
        return self._paging_state is not None

    @property
    def warnings(self):
        """
        Warnings returned from the server, if any. This will only be
        set for protocol_version 4+.

        Warnings may be returned for such things as oversized batches,
        or too many tombstones in slice queries.

        Ensure the future is complete before trying to access this property
        (call :meth:`.result()`, or after callback is invoked).
        Otherwise it may throw if the response has not been received.
        """
        # TODO: When timers are introduced, just make this wait
        if not self._event.is_set():
            raise DriverException("warnings cannot be retrieved before ResponseFuture is finalized")
        return self._warnings

    @property
    def custom_payload(self):
        """
        The custom payload returned from the server, if any. This will only be
        set by Cassandra servers implementing a custom QueryHandler, and only
        for protocol_version 4+.

        Ensure the future is complete before trying to access this property
        (call :meth:`.result()`, or after callback is invoked).
        Otherwise it may throw if the response has not been received.

        :return: :ref:`custom_payload`.
        """
        # TODO: When timers are introduced, just make this wait
        if not self._event.is_set():
            raise DriverException("custom_payload cannot be retrieved before ResponseFuture is finalized")
        return self._custom_payload

    def start_fetching_next_page(self):
        """
        If there are more pages left in the query result, this asynchronously
        starts fetching the next page.  If there are no pages left, :exc:`.QueryExhausted`
        is raised.  Also see :attr:`.has_more_pages`.

        This should only be called after the first page has been returned.

        .. versionadded:: 2.0.0
        """
        if not self._paging_state:
            raise QueryExhausted()

        self._make_query_plan()
        self.message.paging_state = self._paging_state
        self._event.clear()
        self._final_result = _NOT_SET
        self._final_exception = None
        self._timer = None  # clear cancelled timer; new one will be set when request is queued
        self.send_request()

    def _reprepare(self, prepare_message, host, connection, pool):
        cb = partial(self.session.submit, self._execute_after_prepare, host, connection, pool)
        request_id = self._query(host, prepare_message, cb=cb)
        if request_id is None:
            # try to submit the original prepared statement on some other host
            self.send_request()

    def _set_result(self, host, connection, pool, response):
        try:
            self.coordinator_host = host
            if pool:
                pool.return_connection(connection)

            trace_id = getattr(response, 'trace_id', None)
            if trace_id:
                if not self._query_traces:
                    self._query_traces = []
                self._query_traces.append(QueryTrace(trace_id, self.session))

            self._warnings = getattr(response, 'warnings', None)
            self._custom_payload = getattr(response, 'custom_payload', None)

            if isinstance(response, ResultMessage):
                if response.kind == RESULT_KIND_SET_KEYSPACE:
                    session = getattr(self, 'session', None)
                    # since we're running on the event loop thread, we need to
                    # use a non-blocking method for setting the keyspace on
                    # all connections in this session, otherwise the event
                    # loop thread will deadlock waiting for keyspaces to be
                    # set.  This uses a callback chain which ends with
                    # self._set_keyspace_completed() being called in the
                    # event loop thread.
                    if session:
                        session._set_keyspace_for_all_pools(
                            response.results, self._set_keyspace_completed)
                elif response.kind == RESULT_KIND_SCHEMA_CHANGE:
                    # refresh the schema before responding, but do it in another
                    # thread instead of the event loop thread
                    self.is_schema_agreed = False
                    self.session.submit(
                        refresh_schema_and_set_result,
                        self.session.cluster.control_connection,
                        self, connection, **response.results)
                else:
                    results = getattr(response, 'results', None)
                    if results is not None and response.kind == RESULT_KIND_ROWS:
                        self._paging_state = response.paging_state
                        self._col_names = results[0]
                        results = self.row_factory(*results)
                    self._set_final_result(results)
            elif isinstance(response, ErrorMessage):
                retry_policy = self._retry_policy

                if isinstance(response, ReadTimeoutErrorMessage):
                    if self._metrics is not None:
                        self._metrics.on_read_timeout()
                    retry = retry_policy.on_read_timeout(
                        self.query, retry_num=self._query_retries, **response.info)
                elif isinstance(response, WriteTimeoutErrorMessage):
                    if self._metrics is not None:
                        self._metrics.on_write_timeout()
                    retry = retry_policy.on_write_timeout(
                        self.query, retry_num=self._query_retries, **response.info)
                elif isinstance(response, UnavailableErrorMessage):
                    if self._metrics is not None:
                        self._metrics.on_unavailable()
                    retry = retry_policy.on_unavailable(
                        self.query, retry_num=self._query_retries, **response.info)
                elif isinstance(response, OverloadedErrorMessage):
                    if self._metrics is not None:
                        self._metrics.on_other_error()
                    # need to retry against a different host here
                    log.warning("Host %s is overloaded, retrying against a different "
                                "host", host)
                    self._retry(reuse_connection=False, consistency_level=None, host=host)
                    return
                elif isinstance(response, IsBootstrappingErrorMessage):
                    if self._metrics is not None:
                        self._metrics.on_other_error()
                    # need to retry against a different host here
                    self._retry(reuse_connection=False, consistency_level=None, host=host)
                    return
                elif isinstance(response, PreparedQueryNotFound):
                    if self.prepared_statement:
                        query_id = self.prepared_statement.query_id
                        assert query_id == response.info, \
                            "Got different query ID in server response (%s) than we " \
                            "had before (%s)" % (response.info, query_id)
                    else:
                        query_id = response.info

                    try:
                        prepared_statement = self.session.cluster._prepared_statements[query_id]
                    except KeyError:
                        if not self.prepared_statement:
                            log.error("Tried to execute unknown prepared statement: id=%s",
                                      query_id.encode('hex'))
                            self._set_final_exception(response)
                            return
                        else:
                            prepared_statement = self.prepared_statement
                            self.session.cluster._prepared_statements[query_id] = prepared_statement

                    current_keyspace = self._connection.keyspace
                    prepared_keyspace = prepared_statement.keyspace
                    if prepared_keyspace and current_keyspace != prepared_keyspace:
                        self._set_final_exception(
                            ValueError("The Session's current keyspace (%s) does "
                                       "not match the keyspace the statement was "
                                       "prepared with (%s)" %
                                       (current_keyspace, prepared_keyspace)))
                        return

                    log.debug("Re-preparing unrecognized prepared statement against host %s: %s",
                              host, prepared_statement.query_string)
                    prepare_message = PrepareMessage(query=prepared_statement.query_string)
                    # since this might block, run on the executor to avoid hanging
                    # the event loop thread
                    self.session.submit(self._reprepare, prepare_message, host, connection, pool)
                    return
                else:
                    if hasattr(response, 'to_exception'):
                        self._set_final_exception(response.to_exception())
                    else:
                        self._set_final_exception(response)
                    return

                retry_type, consistency = retry
                if retry_type in (RetryPolicy.RETRY, RetryPolicy.RETRY_NEXT_HOST):
                    self._query_retries += 1
                    reuse = retry_type == RetryPolicy.RETRY
                    self._retry(reuse, consistency, host)
                elif retry_type is RetryPolicy.RETHROW:
                    self._set_final_exception(response.to_exception())
                else:  # IGNORE
                    if self._metrics is not None:
                        self._metrics.on_ignore()
                    self._set_final_result(None)

                self._errors[host] = response.to_exception()
            elif isinstance(response, ConnectionException):
                if self._metrics is not None:
                    self._metrics.on_connection_error()
                if not isinstance(response, ConnectionShutdown):
                    self._connection.defunct(response)
                self._retry(reuse_connection=False, consistency_level=None, host=host)
            elif isinstance(response, Exception):
                if hasattr(response, 'to_exception'):
                    self._set_final_exception(response.to_exception())
                else:
                    self._set_final_exception(response)
            else:
                # we got some other kind of response message
                msg = "Got unexpected message: %r" % (response,)
                exc = ConnectionException(msg, host)
                self._connection.defunct(exc)
                self._set_final_exception(exc)
        except Exception as exc:
            # almost certainly caused by a bug, but we need to set something here
            log.exception("Unexpected exception while handling result in ResponseFuture:")
            self._set_final_exception(exc)

    def _set_keyspace_completed(self, errors):
        if not errors:
            self._set_final_result(None)
        else:
            self._set_final_exception(ConnectionException(
                "Failed to set keyspace on all hosts: %s" % (errors,)))

    def _execute_after_prepare(self, host, connection, pool, response):
        """
        Handle the response to our attempt to prepare a statement.
        If it succeeded, run the original query again against the same host.
        """
        if pool:
            pool.return_connection(connection)

        if self._final_exception:
            return

        if isinstance(response, ResultMessage):
            if response.kind == RESULT_KIND_PREPARED:
                # result metadata is the only thing that could have changed from an alter
                _, _, _, result_metadata = response.results
                self.prepared_statement.result_metadata = result_metadata

                # use self._query to re-use the same host and
                # at the same time properly borrow the connection
                request_id = self._query(host)
                if request_id is None:
                    # this host errored out, move on to the next
                    self.send_request()
            else:
                self._set_final_exception(ConnectionException(
                    "Got unexpected response when preparing statement "
                    "on host %s: %s" % (host, response)))
        elif isinstance(response, ErrorMessage):
            if hasattr(response, 'to_exception'):
                self._set_final_exception(response.to_exception())
            else:
                self._set_final_exception(response)
        elif isinstance(response, ConnectionException):
            log.debug("Connection error when preparing statement on host %s: %s",
                      host, response)
            # try again on a different host, preparing again if necessary
            self._errors[host] = response
            self.send_request()
        else:
            self._set_final_exception(ConnectionException(
                "Got unexpected response type when preparing "
                "statement on host %s: %s" % (host, response)))

    def _set_final_result(self, response):
        self._cancel_timer()
        if self._metrics is not None:
            self._metrics.request_timer.addValue(time.time() - self._start_time)

        with self._callback_lock:
            self._final_result = response

        self._event.set()

        # apply each callback
        for callback in self._callbacks:
            fn, args, kwargs = callback
            fn(response, *args, **kwargs)

    def _set_final_exception(self, response):
        self._cancel_timer()
        if self._metrics is not None:
            self._metrics.request_timer.addValue(time.time() - self._start_time)

        with self._callback_lock:
            self._final_exception = response
        self._event.set()

        for errback in self._errbacks:
            fn, args, kwargs = errback
            fn(response, *args, **kwargs)

    def _retry(self, reuse_connection, consistency_level, host):
        if self._final_exception:
            # the connection probably broke while we were waiting
            # to retry the operation
            return

        if self._metrics is not None:
            self._metrics.on_retry()
        if consistency_level is not None:
            self.message.consistency_level = consistency_level

        # don't retry on the event loop thread
        self.session.submit(self._retry_task, reuse_connection, host)

    def _retry_task(self, reuse_connection, host):
        if self._final_exception:
            # the connection probably broke while we were waiting
            # to retry the operation
            return

        if reuse_connection and self._query(host) is not None:
            return

        # otherwise, move on to another host
        self.send_request()

    def result(self):
        """
        Return the final result or raise an Exception if errors were
        encountered.  If the final result or error has not been set
        yet, this method will block until it is set, or the timeout
        set for the request expires.

        Timeout is specified in the Session request execution functions.
        If the timeout is exceeded, an :exc:`cassandra.OperationTimedOut` will be raised.
        This is a client-side timeout. For more information
        about server-side coordinator timeouts, see :class:`.policies.RetryPolicy`.

        Example usage::

            >>> future = session.execute_async("SELECT * FROM mycf")
            >>> # do other stuff...

            >>> try:
            ...     rows = future.result()
            ...     for row in rows:
            ...         ... # process results
            ... except Exception:
            ...     log.exception("Operation failed:")

        """
        self._event.wait()
        if self._final_result is not _NOT_SET:
            return ResultSet(self, self._final_result)
        else:
            raise self._final_exception

    def get_query_trace_ids(self):
        """
        Returns the trace session ids for this future, if tracing was enabled (does not fetch trace data).
        """
        return [trace.trace_id for trace in self._query_traces]

    def get_query_trace(self, max_wait=None, query_cl=ConsistencyLevel.LOCAL_ONE):
        """
        Fetches and returns the query trace of the last response, or `None` if tracing was
        not enabled.

        Note that this may raise an exception if there are problems retrieving the trace
        details from Cassandra.

        If the trace is not available after `max_wait`,
        :exc:`cassandra.query.TraceUnavailable` will be raised.

        `query_cl` is the consistency level used to poll the trace tables.
        """
        if self._query_traces:
            return self._get_query_trace(len(self._query_traces) - 1, max_wait, query_cl)

    def get_all_query_traces(self, max_wait_per=None, query_cl=ConsistencyLevel.LOCAL_ONE):
        """
        Fetches and returns the query traces for all query pages, if tracing was enabled.

        See note in :meth:`~.get_query_trace` regarding possible exceptions.
        """
        if self._query_traces:
            return [self._get_query_trace(i, max_wait_per, query_cl) for i in range(len(self._query_traces))]
        return []

    def _get_query_trace(self, i, max_wait, query_cl):
        trace = self._query_traces[i]
        if not trace.events:
            trace.populate(max_wait=max_wait, query_cl=query_cl)
        return trace

    def add_callback(self, fn, *args, **kwargs):
        """
        Attaches a callback function to be called when the final results arrive.

        By default, `fn` will be called with the results as the first and only
        argument.  If `*args` or `**kwargs` are supplied, they will be passed
        through as additional positional or keyword arguments to `fn`.

        If an error is hit while executing the operation, a callback attached
        here will not be called.  Use :meth:`.add_errback()` or :meth:`.add_callbacks()`
        if you wish to handle that case.

        If the final result has already been seen when this method is called,
        the callback will be called immediately (before this method returns).

        Note: in the case that the result is not available when the callback is added,
        the callback is executed by the IO event thread. This means that the callback
        should not block or attempt further synchronous requests, because no further IO
        will be processed until the callback returns.

        **Important**: if the callback you attach results in an exception being
        raised, **the exception will be ignored**, so please ensure your
        callback handles all error cases that you care about.
Usage example:: >>> session = cluster.connect("mykeyspace") >>> def handle_results(rows, start_time, should_log=False): ... if should_log: ... log.info("Total time: %f", time.time() - start_time) ... ... >>> future = session.execute_async("SELECT * FROM users") >>> future.add_callback(handle_results, time.time(), should_log=True) """ run_now = False with self._callback_lock: if self._final_result is not _NOT_SET: run_now = True else: self._callbacks.append((fn, args, kwargs)) if run_now: fn(self._final_result, *args, **kwargs) return self def add_errback(self, fn, *args, **kwargs): """ Like :meth:`.add_callback()`, but handles error cases. An Exception instance will be passed as the first positional argument to `fn`. """ run_now = False with self._callback_lock: if self._final_exception: run_now = True else: self._errbacks.append((fn, args, kwargs)) if run_now: fn(self._final_exception, *args, **kwargs) return self def add_callbacks(self, callback, errback, callback_args=(), callback_kwargs=None, errback_args=(), errback_kwargs=None): """ A convenient combination of :meth:`.add_callback()` and :meth:`.add_errback()`. Example usage:: >>> session = cluster.connect() >>> query = "SELECT * FROM mycf" >>> future = session.execute_async(query) >>> def log_results(results, level='debug'): ... for row in results: ... log.log(level, "Result: %s", row) >>> def log_error(exc, query): ... log.error("Query '%s' failed: %s", query, exc) >>> future.add_callbacks( ... callback=log_results, callback_kwargs={'level': 'info'}, ... 
errback=log_error, errback_args=(query,)) """ self.add_callback(callback, *callback_args, **(callback_kwargs or {})) self.add_errback(errback, *errback_args, **(errback_kwargs or {})) def clear_callbacks(self): with self._callback_lock: self._callbacks = [] self._errbacks = [] def __str__(self): result = "(no result yet)" if self._final_result is _NOT_SET else self._final_result return "<ResponseFuture: query='%s' request_id=%s result=%s exception=%s host=%s>" \ % (self.query, self._req_id, result, self._final_exception, self.coordinator_host) __repr__ = __str__ class QueryExhausted(Exception): """ Raised when :meth:`.ResponseFuture.start_fetching_next_page()` is called and there are no more pages. You can check :attr:`.ResponseFuture.has_more_pages` before calling to avoid this. .. versionadded:: 2.0.0 """ pass class ResultSet(object): """ An iterator over the rows from a query result. Also supplies basic equality and indexing methods for backward-compatibility. These methods materialize the entire result set (loading all pages), and should only be used if the total result size is understood. Warnings are emitted when paged results are materialized in this fashion. You can treat this as a normal iterator over rows:: >>> from cassandra.query import SimpleStatement >>> statement = SimpleStatement("SELECT * FROM users", fetch_size=10) >>> for user_row in session.execute(statement): ... process_user(user_row) Whenever there are no more rows in the current page, the next page will be fetched transparently. However, note that it *is* possible for an :class:`Exception` to be raised while fetching the next page, just like you might see on a normal call to ``session.execute()``.
""" def __init__(self, response_future, initial_response): self.response_future = response_future self.column_names = response_future._col_names self._set_current_rows(initial_response) self._page_iter = None self._list_mode = False @property def has_more_pages(self): """ True if the last response indicated more pages; False otherwise """ return self.response_future.has_more_pages @property def current_rows(self): """ The list of current page rows. May be empty if the result was empty, or this is the last page. """ return self._current_rows or [] def __iter__(self): if self._list_mode: return iter(self._current_rows) self._page_iter = iter(self._current_rows) return self def next(self): try: return next(self._page_iter) except StopIteration: if not self.response_future.has_more_pages: if not self._list_mode: self._current_rows = [] raise self.fetch_next_page() self._page_iter = iter(self._current_rows) return next(self._page_iter) __next__ = next def fetch_next_page(self): """ Manually, synchronously fetch the next page. Supplied for manually retrieving pages and inspecting :attr:`~.current_rows`. It is not necessary to call this when iterating through results; paging happens implicitly in iteration.
""" if self.response_future.has_more_pages: self.response_future.start_fetching_next_page() result = self.response_future.result() self._current_rows = result._current_rows # ResultSet has already _set_current_rows to the appropriate form else: self._current_rows = [] def _set_current_rows(self, result): if isinstance(result, Mapping): self._current_rows = [result] if result else [] return try: iter(result) # can't check directly for generator types because cython generators are different self._current_rows = result except TypeError: self._current_rows = [result] if result else [] def _fetch_all(self): self._current_rows = list(self) self._page_iter = None def _enter_list_mode(self, operator): if self._list_mode: return if self._page_iter: raise RuntimeError("Cannot use %s when results have been iterated." % operator) if self.response_future.has_more_pages: log.warning("Using %s on paged results causes entire result set to be materialized.", operator) self._fetch_all() # done regardless of paging status in case the row factory produces a generator self._list_mode = True def __eq__(self, other): self._enter_list_mode("equality operator") return self._current_rows == other def __getitem__(self, i): self._enter_list_mode("index operator") return self._current_rows[i] def __nonzero__(self): return bool(self._current_rows) __bool__ = __nonzero__ def get_query_trace(self, max_wait_sec=None): """ Gets the last query trace from the associated future. See :meth:`.ResponseFuture.get_query_trace` for details. """ return self.response_future.get_query_trace(max_wait_sec) def get_all_query_traces(self, max_wait_sec_per=None): """ Gets all query traces from the associated future. See :meth:`.ResponseFuture.get_all_query_traces` for details. """ return self.response_future.get_all_query_traces(max_wait_sec_per) @property def was_applied(self): """ For LWT results, returns whether the transaction was applied. 
Result is indeterminate if called on a result that was not an LWT request. Only valid when one of the internal row factories is in use. """ if self.response_future.row_factory not in (named_tuple_factory, dict_factory, tuple_factory): raise RuntimeError("Cannot determine LWT result with row factory %s" % (self.response_future.row_factory,)) if len(self.current_rows) != 1: raise RuntimeError("LWT result should have exactly one row. This has %d." % (len(self.current_rows))) row = self.current_rows[0] if isinstance(row, tuple): return row[0] else: return row['[applied]'] @property def paging_state(self): """ Server paging state of the query. Can be `None` if the query was not paged. The driver treats paging state as opaque, but it may contain primary key data, so applications may want to avoid sending this to untrusted parties. """ return self.response_future._paging_state cassandra-driver-3.7.1/cassandra/cython_marshal.pyx0000664000175000017500000000462212743410406025335 0ustar aboudreaultaboudreault00000000000000# -- cython: profile=True # # Copyright 2013-2016 DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.
import six from libc.stdint cimport (int8_t, int16_t, int32_t, int64_t, uint8_t, uint16_t, uint32_t, uint64_t) from libc.string cimport memcpy from cassandra.buffer cimport Buffer, buf_read, to_bytes cdef bint is_little_endian from cassandra.util import is_little_endian cdef bint PY3 = six.PY3 ctypedef fused num_t: int64_t int32_t int16_t int8_t uint64_t uint32_t uint16_t uint8_t double float cdef inline num_t unpack_num(Buffer *buf, num_t *dummy=NULL): # dummy pointer because cython wants the fused type as an arg """ Copy to aligned destination, conditionally swapping to native byte order """ cdef Py_ssize_t start, end, i cdef char *src = buf_read(buf, sizeof(num_t)) cdef num_t ret = 0 cdef char *out = <char*> &ret if is_little_endian: for i in range(sizeof(num_t)): out[sizeof(num_t) - i - 1] = src[i] else: memcpy(out, src, sizeof(num_t)) return ret cdef varint_unpack(Buffer *term): """Unpack a variable-sized integer""" if PY3: return varint_unpack_py3(to_bytes(term)) else: return varint_unpack_py2(to_bytes(term)) # TODO: Optimize these two functions cdef varint_unpack_py3(bytes term): val = int(''.join(["%02x" % i for i in term]), 16) if (term[0] & 128) != 0: shift = len(term) * 8 # * Note below val -= 1 << shift return val cdef varint_unpack_py2(bytes term): # noqa val = int(term.encode('hex'), 16) if (ord(term[0]) & 128) != 0: shift = len(term) * 8 # * Note below val = val - (1 << shift) return val # * Note * # '1 << (len(term) * 8)' Cython tries to do native # integer shifts, which overflows. We need this to # emulate Python shifting, which will expand the long # to accommodate cassandra-driver-3.7.1/cassandra/numpy_parser.pyx0000664000175000017500000001335112766043657025064 0ustar aboudreaultaboudreault00000000000000# Copyright 2013-2016 DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ This module provides an optional protocol parser that returns NumPy arrays. ============================================================================= This module should not be imported by any of the main python-driver modules, as numpy is an optional dependency. ============================================================================= """ include "ioutils.pyx" cimport cython from libc.stdint cimport uint64_t, uint8_t from cpython.ref cimport Py_INCREF, PyObject from cassandra.bytesio cimport BytesIOReader from cassandra.deserializers cimport Deserializer, from_binary from cassandra.parsing cimport ParseDesc, ColumnParser, RowParser from cassandra import cqltypes from cassandra.util import is_little_endian import numpy as np cdef extern from "numpyFlags.h": # Include 'numpyFlags.h' into the generated C code to disable the # deprecated NumPy API pass cdef extern from "Python.h": # An integer type large enough to hold a pointer ctypedef uint64_t Py_uintptr_t # Simple array descriptor, useful to parse rows into a NumPy array ctypedef struct ArrDesc: Py_uintptr_t buf_ptr int stride # should be large enough as we allocate contiguous arrays int is_object Py_uintptr_t mask_ptr arrDescDtype = np.dtype( [ ('buf_ptr', np.uintp) , ('stride', np.dtype('i')) , ('is_object', np.dtype('i')) , ('mask_ptr', np.uintp) ], align=True) _cqltype_to_numpy = { cqltypes.LongType: np.dtype('>i8'), cqltypes.CounterColumnType: np.dtype('>i8'), cqltypes.Int32Type: np.dtype('>i4'), cqltypes.ShortType: np.dtype('>i2'), cqltypes.FloatType: np.dtype('>f4'), cqltypes.DoubleType: 
np.dtype('>f8'), } obj_dtype = np.dtype('O') cdef uint8_t mask_true = 0x01 cdef class NumpyParser(ColumnParser): """Decode a ResultMessage into a bunch of NumPy arrays""" cpdef parse_rows(self, BytesIOReader reader, ParseDesc desc): cdef Py_ssize_t rowcount cdef ArrDesc[::1] array_descs cdef ArrDesc *arrs rowcount = read_int(reader) array_descs, arrays = make_arrays(desc, rowcount) arrs = &array_descs[0] _parse_rows(reader, desc, arrs, rowcount) arrays = [make_native_byteorder(arr) for arr in arrays] result = dict(zip(desc.colnames, arrays)) return result cdef _parse_rows(BytesIOReader reader, ParseDesc desc, ArrDesc *arrs, Py_ssize_t rowcount): cdef Py_ssize_t i for i in range(rowcount): unpack_row(reader, desc, arrs) ### Helper functions to create NumPy arrays and array descriptors def make_arrays(ParseDesc desc, array_size): """ Allocate arrays for each result column. Returns a tuple of (array_descs, arrays), where 'array_descs' describe the arrays for NativeRowParser and 'arrays' is a list with one array per column, in column order (parse_rows zips it with the column names into a dict, which can e.g. be fed into pandas.DataFrame) """ array_descs = np.empty((desc.rowsize,), arrDescDtype) arrays = [] for i, coltype in enumerate(desc.coltypes): arr = make_array(coltype, array_size) array_descs[i]['buf_ptr'] = arr.ctypes.data array_descs[i]['stride'] = arr.strides[0] array_descs[i]['is_object'] = arr.dtype is obj_dtype try: array_descs[i]['mask_ptr'] = arr.mask.ctypes.data except AttributeError: array_descs[i]['mask_ptr'] = 0 arrays.append(arr) return array_descs, arrays def make_array(coltype, array_size): """ Allocate a new NumPy array of the given column type and size.
""" try: a = np.ma.empty((array_size,), dtype=_cqltype_to_numpy[coltype]) a.mask = np.zeros((array_size,), dtype=np.bool_) except KeyError: a = np.empty((array_size,), dtype=obj_dtype) return a #### Parse rows into NumPy arrays @cython.boundscheck(False) @cython.wraparound(False) cdef inline int unpack_row( BytesIOReader reader, ParseDesc desc, ArrDesc *arrays) except -1: cdef Buffer buf cdef Py_ssize_t i, rowsize = desc.rowsize cdef ArrDesc arr cdef Deserializer deserializer for i in range(rowsize): get_buf(reader, &buf) arr = arrays[i] if arr.is_object: deserializer = desc.deserializers[i] val = from_binary(deserializer, &buf, desc.protocol_version) Py_INCREF(val) (<PyObject **> arr.buf_ptr)[0] = <PyObject *> val elif buf.size >= 0: memcpy(<char *> arr.buf_ptr, buf.ptr, buf.size) else: memcpy(<char *> arr.mask_ptr, &mask_true, 1) # Update the pointer into the array for the next time arrays[i].buf_ptr += arr.stride arrays[i].mask_ptr += 1 return 0 def make_native_byteorder(arr): """ Make sure all values in the NumPy arrays are in native byte order. """ if is_little_endian and not arr.dtype.kind == 'O': # We have arrays in big-endian order. First swap the bytes # into little endian order, and then update the numpy dtype # accordingly (e.g.
from '>i8' to ' buf.size: raise IndexError("Requested more than length of buffer") return buf.ptr cdef inline int slice_buffer(Buffer *buf, Buffer *out, Py_ssize_t start, Py_ssize_t size) except -1: if size < 0: raise ValueError("Length must be positive") if start + size > buf.size: raise IndexError("Buffer slice out of bounds") out.ptr = buf.ptr + start out.size = size return 0 cdef inline void from_ptr_and_size(char *ptr, Py_ssize_t size, Buffer *out): out.ptr = ptr out.size = size cassandra-driver-3.7.1/cassandra/cython_utils.pxd0000664000175000017500000000012012743410406025006 0ustar aboudreaultaboudreault00000000000000from libc.stdint cimport int64_t cdef datetime_from_timestamp(double timestamp) cassandra-driver-3.7.1/cassandra_driver.egg-info/0000775000175000017500000000000013004144417024616 5ustar aboudreaultaboudreault00000000000000cassandra-driver-3.7.1/cassandra_driver.egg-info/PKG-INFO0000664000175000017500000001361113004144415025713 0ustar aboudreaultaboudreault00000000000000Metadata-Version: 1.1 Name: cassandra-driver Version: 3.7.1 Summary: Python driver for Cassandra Home-page: http://github.com/datastax/python-driver Author: Tyler Hobbs Author-email: tyler@datastax.com License: UNKNOWN Description: DataStax Python Driver for Apache Cassandra =========================================== .. image:: https://travis-ci.org/datastax/python-driver.png?branch=master :target: https://travis-ci.org/datastax/python-driver A modern, `feature-rich `_ and highly-tunable Python client library for Apache Cassandra (1.2+) and DataStax Enterprise (3.1+) using exclusively Cassandra's binary protocol and Cassandra Query Language v3. The driver supports Python 2.6, 2.7, 3.3, and 3.4. Feedback Requested ------------------ **Help us focus our efforts!** Provide your input on the `Platform and Runtime Survey `_ (we kept it short). 
Features -------- * `Synchronous `_ and `Asynchronous `_ APIs * `Simple, Prepared, and Batch statements `_ * Asynchronous IO, parallel execution, request pipelining * `Connection pooling `_ * Automatic node discovery * `Automatic reconnection `_ * Configurable `load balancing `_ and `retry policies `_ * `Concurrent execution utilities `_ * `Object mapper `_ Installation ------------ Installation through pip is recommended:: $ pip install cassandra-driver For more complete installation instructions, see the `installation guide `_. Documentation ------------- The documentation can be found online `here `_. A couple of links for getting up to speed: * `Installation `_ * `Getting started guide `_ * `API docs `_ * `Performance tips `_ Object Mapper ------------- cqlengine (originally developed by Blake Eggleston and Jon Haddad, with contributions from the community) is now maintained as an integral part of this package. Refer to `documentation here `_. Contributing ------------ See `CONTRIBUTING.md `_. Reporting Problems ------------------ Please report any bugs and make any feature requests on the `JIRA `_ issue tracker. If you would like to contribute, please feel free to open a pull request. Getting Help ------------ Your two best options for getting help with the driver are the `mailing list `_ and the IRC channel. For IRC, use the #datastax-drivers channel on irc.freenode.net. If you don't have an IRC client, you can use `freenode's web-based client `_. License ------- Copyright 2013-2016 DataStax Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. Keywords: cassandra,cql,orm Platform: UNKNOWN Classifier: Development Status :: 5 - Production/Stable Classifier: Intended Audience :: Developers Classifier: License :: OSI Approved :: Apache Software License Classifier: Natural Language :: English Classifier: Operating System :: OS Independent Classifier: Programming Language :: Python Classifier: Programming Language :: Python :: 2.6 Classifier: Programming Language :: Python :: 2.7 Classifier: Programming Language :: Python :: 3.3 Classifier: Programming Language :: Python :: 3.4 Classifier: Programming Language :: Python :: Implementation :: CPython Classifier: Programming Language :: Python :: Implementation :: PyPy Classifier: Topic :: Software Development :: Libraries :: Python Modules cassandra-driver-3.7.1/cassandra_driver.egg-info/requires.txt0000664000175000017500000000002213004144415027206 0ustar aboudreaultaboudreault00000000000000six >=1.6 futures cassandra-driver-3.7.1/cassandra_driver.egg-info/dependency_links.txt0000664000175000017500000000000113004144415030662 0ustar aboudreaultaboudreault00000000000000 cassandra-driver-3.7.1/cassandra_driver.egg-info/SOURCES.txt0000664000175000017500000000313513004144417026504 0ustar aboudreaultaboudreault00000000000000LICENSE MANIFEST.in README.rst ez_setup.py setup.py cassandra/__init__.py cassandra/auth.py cassandra/buffer.pxd cassandra/bytesio.pxd cassandra/bytesio.pyx cassandra/cluster.py cassandra/cmurmur3.c cassandra/concurrent.py cassandra/connection.py cassandra/cqltypes.py cassandra/cython_deps.py cassandra/cython_marshal.pyx cassandra/cython_utils.pxd cassandra/cython_utils.pyx cassandra/deserializers.pxd cassandra/deserializers.pyx cassandra/encoder.py cassandra/ioutils.pyx cassandra/marshal.py cassandra/metadata.py cassandra/metrics.py cassandra/murmur3.py cassandra/numpyFlags.h cassandra/numpy_parser.pyx cassandra/obj_parser.pyx cassandra/parsing.pxd 
cassandra/parsing.pyx cassandra/policies.py cassandra/pool.py cassandra/protocol.py cassandra/query.py cassandra/row_parser.pyx cassandra/tuple.pxd cassandra/type_codes.pxd cassandra/type_codes.py cassandra/util.py cassandra/cqlengine/__init__.py cassandra/cqlengine/columns.py cassandra/cqlengine/connection.py cassandra/cqlengine/functions.py cassandra/cqlengine/management.py cassandra/cqlengine/models.py cassandra/cqlengine/named.py cassandra/cqlengine/operators.py cassandra/cqlengine/query.py cassandra/cqlengine/statements.py cassandra/cqlengine/usertype.py cassandra/io/__init__.py cassandra/io/asyncorereactor.py cassandra/io/eventletreactor.py cassandra/io/geventreactor.py cassandra/io/libevreactor.py cassandra/io/libevwrapper.c cassandra/io/twistedreactor.py cassandra_driver.egg-info/PKG-INFO cassandra_driver.egg-info/SOURCES.txt cassandra_driver.egg-info/dependency_links.txt cassandra_driver.egg-info/requires.txt cassandra_driver.egg-info/top_level.txtcassandra-driver-3.7.1/cassandra_driver.egg-info/top_level.txt0000664000175000017500000000002013004144415027336 0ustar aboudreaultaboudreault00000000000000DUMMY cassandra cassandra-driver-3.7.1/ez_setup.py0000664000175000017500000002062412657142321022033 0ustar aboudreaultaboudreault00000000000000#!python """Bootstrap setuptools installation If you want to use setuptools in your package's setup.py, just include this file in the same directory with it, and add this to the top of your setup.py:: from ez_setup import use_setuptools use_setuptools() If you want to require a specific version of setuptools, set a download mirror, or use an alternate download directory, you can do so by supplying the appropriate options to ``use_setuptools()``. This file can also be run as a script to install or upgrade setuptools. 
""" import os import shutil import sys import tempfile import tarfile import optparse import subprocess from distutils import log try: from site import USER_SITE except ImportError: USER_SITE = None DEFAULT_VERSION = "0.9.6" DEFAULT_URL = "https://pypi.python.org/packages/source/s/setuptools/" def _python_cmd(*args): args = (sys.executable,) + args return subprocess.call(args) == 0 def _install(tarball, install_args=()): # extracting the tarball tmpdir = tempfile.mkdtemp() log.warn('Extracting in %s', tmpdir) old_wd = os.getcwd() try: os.chdir(tmpdir) tar = tarfile.open(tarball) _extractall(tar) tar.close() # going in the directory subdir = os.path.join(tmpdir, os.listdir(tmpdir)[0]) os.chdir(subdir) log.warn('Now working in %s', subdir) # installing log.warn('Installing Setuptools') if not _python_cmd('setup.py', 'install', *install_args): log.warn('Something went wrong during the installation.') log.warn('See the error message above.') # exitcode will be 2 return 2 finally: os.chdir(old_wd) shutil.rmtree(tmpdir) def _build_egg(egg, tarball, to_dir): # extracting the tarball tmpdir = tempfile.mkdtemp() log.warn('Extracting in %s', tmpdir) old_wd = os.getcwd() try: os.chdir(tmpdir) tar = tarfile.open(tarball) _extractall(tar) tar.close() # going in the directory subdir = os.path.join(tmpdir, os.listdir(tmpdir)[0]) os.chdir(subdir) log.warn('Now working in %s', subdir) # building an egg log.warn('Building a Setuptools egg in %s', to_dir) _python_cmd('setup.py', '-q', 'bdist_egg', '--dist-dir', to_dir) finally: os.chdir(old_wd) shutil.rmtree(tmpdir) # returning the result log.warn(egg) if not os.path.exists(egg): raise IOError('Could not build the egg.') def _do_download(version, download_base, to_dir, download_delay): egg = os.path.join(to_dir, 'setuptools-%s-py%d.%d.egg' % (version, sys.version_info[0], sys.version_info[1])) if not os.path.exists(egg): tarball = download_setuptools(version, download_base, to_dir, download_delay) _build_egg(egg, tarball, to_dir) 
sys.path.insert(0, egg) import setuptools setuptools.bootstrap_install_from = egg def use_setuptools(version=DEFAULT_VERSION, download_base=DEFAULT_URL, to_dir=os.curdir, download_delay=15): # making sure we use the absolute path to_dir = os.path.abspath(to_dir) was_imported = 'pkg_resources' in sys.modules or \ 'setuptools' in sys.modules try: import pkg_resources except ImportError: return _do_download(version, download_base, to_dir, download_delay) try: pkg_resources.require("setuptools>=" + version) return except pkg_resources.VersionConflict: e = sys.exc_info()[1] if was_imported: sys.stderr.write( "The required version of setuptools (>=%s) is not available,\n" "and can't be installed while this script is running. Please\n" "install a more recent version first, using\n" "'easy_install -U setuptools'." "\n\n(Currently using %r)\n" % (version, e.args[0])) sys.exit(2) else: del pkg_resources, sys.modules['pkg_resources'] # reload ok return _do_download(version, download_base, to_dir, download_delay) except pkg_resources.DistributionNotFound: return _do_download(version, download_base, to_dir, download_delay) def download_setuptools(version=DEFAULT_VERSION, download_base=DEFAULT_URL, to_dir=os.curdir, delay=15): """Download setuptools from a specified location and return its filename `version` should be a valid setuptools version number that is available as an egg for download under the `download_base` URL (which should end with a '/'). `to_dir` is the directory where the egg will be downloaded. `delay` is the number of seconds to pause before an actual download attempt. 
""" # making sure we use the absolute path to_dir = os.path.abspath(to_dir) try: from urllib.request import urlopen except ImportError: from urllib2 import urlopen tgz_name = "setuptools-%s.tar.gz" % version url = download_base + tgz_name saveto = os.path.join(to_dir, tgz_name) src = dst = None if not os.path.exists(saveto): # Avoid repeated downloads try: log.warn("Downloading %s", url) src = urlopen(url) # Read/write all in one block, so we don't create a corrupt file # if the download is interrupted. data = src.read() dst = open(saveto, "wb") dst.write(data) finally: if src: src.close() if dst: dst.close() return os.path.realpath(saveto) def _extractall(self, path=".", members=None): """Extract all members from the archive to the current working directory and set owner, modification time and permissions on directories afterwards. `path' specifies a different directory to extract to. `members' is optional and must be a subset of the list returned by getmembers(). """ import copy import operator from tarfile import ExtractError directories = [] if members is None: members = self for tarinfo in members: if tarinfo.isdir(): # Extract directories with a safe mode. directories.append(tarinfo) tarinfo = copy.copy(tarinfo) tarinfo.mode = 448 # decimal for oct 0700 self.extract(tarinfo, path) # Reverse sort directories. if sys.version_info < (2, 4): def sorter(dir1, dir2): return cmp(dir1.name, dir2.name) directories.sort(sorter) directories.reverse() else: directories.sort(key=operator.attrgetter('name'), reverse=True) # Set correct owner, mtime and filemode on directories. 
for tarinfo in directories: dirpath = os.path.join(path, tarinfo.name) try: self.chown(tarinfo, dirpath) self.utime(tarinfo, dirpath) self.chmod(tarinfo, dirpath) except ExtractError: e = sys.exc_info()[1] if self.errorlevel > 1: raise else: self._dbg(1, "tarfile: %s" % e) def _build_install_args(options): """ Build the arguments to 'python setup.py install' on the setuptools package """ install_args = [] if options.user_install: if sys.version_info < (2, 6): log.warn("--user requires Python 2.6 or later") raise SystemExit(1) install_args.append('--user') return install_args def _parse_args(): """ Parse the command line for options """ parser = optparse.OptionParser() parser.add_option( '--user', dest='user_install', action='store_true', default=False, help='install in user site package (requires Python 2.6 or later)') parser.add_option( '--download-base', dest='download_base', metavar="URL", default=DEFAULT_URL, help='alternative URL from where to download the setuptools package') options, args = parser.parse_args() # positional arguments are ignored return options def main(version=DEFAULT_VERSION): """Install or upgrade setuptools and EasyInstall""" options = _parse_args() tarball = download_setuptools(download_base=options.download_base) return _install(tarball, _build_install_args(options)) if __name__ == '__main__': sys.exit(main()) cassandra-driver-3.7.1/setup.cfg0000664000175000017500000000007313004144417021433 0ustar aboudreaultaboudreault00000000000000[egg_info] tag_build = tag_date = 0 tag_svn_revision = 0 cassandra-driver-3.7.1/README.rst0000664000175000017500000001036712766043657021332 0ustar aboudreaultaboudreault00000000000000DataStax Python Driver for Apache Cassandra =========================================== .. 
image:: https://travis-ci.org/datastax/python-driver.png?branch=master :target: https://travis-ci.org/datastax/python-driver A modern, `feature-rich `_ and highly-tunable Python client library for Apache Cassandra (1.2+) and DataStax Enterprise (3.1+) using exclusively Cassandra's binary protocol and Cassandra Query Language v3. The driver supports Python 2.6, 2.7, 3.3, and 3.4. Feedback Requested ------------------ **Help us focus our efforts!** Provide your input on the `Platform and Runtime Survey `_ (we kept it short). Features -------- * `Synchronous `_ and `Asynchronous `_ APIs * `Simple, Prepared, and Batch statements `_ * Asynchronous IO, parallel execution, request pipelining * `Connection pooling `_ * Automatic node discovery * `Automatic reconnection `_ * Configurable `load balancing `_ and `retry policies `_ * `Concurrent execution utilities `_ * `Object mapper `_ Installation ------------ Installation through pip is recommended:: $ pip install cassandra-driver For more complete installation instructions, see the `installation guide `_. Documentation ------------- The documentation can be found online `here `_. A couple of links for getting up to speed: * `Installation `_ * `Getting started guide `_ * `API docs `_ * `Performance tips `_ Object Mapper ------------- cqlengine (originally developed by Blake Eggleston and Jon Haddad, with contributions from the community) is now maintained as an integral part of this package. Refer to `documentation here `_. Contributing ------------ See `CONTRIBUTING.md `_. Reporting Problems ------------------ Please report any bugs and make any feature requests on the `JIRA `_ issue tracker. If you would like to contribute, please feel free to open a pull request. Getting Help ------------ Your two best options for getting help with the driver are the `mailing list `_ and the IRC channel. For IRC, use the #datastax-drivers channel on irc.freenode.net. 
If you don't have an IRC client, you can use `freenode's web-based client `_. License ------- Copyright 2013-2016 DataStax Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. cassandra-driver-3.7.1/MANIFEST.in0000664000175000017500000000030312743410406021347 0ustar aboudreaultaboudreault00000000000000include setup.py README.rst MANIFEST.in LICENSE ez_setup.py include cassandra/cmurmur3.c include cassandra/io/libevwrapper.c include cassandra/*.pyx include cassandra/*.pxd include cassandra/*.h cassandra-driver-3.7.1/setup.py0000664000175000017500000003755413004141751021340 0ustar aboudreaultaboudreault00000000000000# Copyright 2013-2016 DataStax, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from __future__ import print_function
import os
import sys
import warnings

if __name__ == '__main__' and sys.argv[1] == "gevent_nosetests":
    print("Running gevent tests")
    from gevent.monkey import patch_all
    patch_all()

if __name__ == '__main__' and sys.argv[1] == "eventlet_nosetests":
    print("Running eventlet tests")
    from eventlet import monkey_patch
    monkey_patch()

import ez_setup
ez_setup.use_setuptools()

from setuptools import setup
from distutils.command.build_ext import build_ext
from distutils.core import Extension
from distutils.errors import (CCompilerError, DistutilsPlatformError,
                              DistutilsExecError)
from distutils.cmd import Command

PY3 = sys.version_info[0] == 3

try:
    import subprocess
    has_subprocess = True
except ImportError:
    has_subprocess = False

from cassandra import __version__

long_description = ""
with open("README.rst") as f:
    long_description = f.read()


try:
    from nose.commands import nosetests
except ImportError:
    gevent_nosetests = None
    eventlet_nosetests = None
else:
    class gevent_nosetests(nosetests):
        description = "run nosetests with gevent monkey patching"

    class eventlet_nosetests(nosetests):
        description = "run nosetests with eventlet monkey patching"

has_cqlengine = False
if __name__ == '__main__' and sys.argv[1] == "install":
    try:
        import cqlengine
        has_cqlengine = True
    except ImportError:
        pass

PROFILING = False


class DocCommand(Command):

    description = "generate or test documentation"

    user_options = [("test", "t",
                     "run doctests instead of generating documentation")]

    boolean_options = ["test"]

    def initialize_options(self):
        self.test = False

    def finalize_options(self):
        pass

    def run(self):
        if self.test:
            path = "docs/_build/doctest"
            mode = "doctest"
        else:
            path = "docs/_build/%s" % __version__
            mode = "html"

        try:
            os.makedirs(path)
        except:
            pass

        if has_subprocess:
            # Prevent run with in-place extensions because cython-generated objects do not carry docstrings
            # http://docs.cython.org/src/userguide/special_methods.html#docstrings
            import glob
            for f in glob.glob("cassandra/*.so"):
                print("Removing '%s' to allow docs to run on pure python modules." %(f,))
                os.unlink(f)

            # Build io extension to make import and docstrings work
            try:
                output = subprocess.check_output(
                    ["python", "setup.py", "build_ext", "--inplace", "--force", "--no-murmur3", "--no-cython"],
                    stderr=subprocess.STDOUT)
            except subprocess.CalledProcessError as exc:
                raise RuntimeError("Documentation step '%s' failed: %s: %s" % ("build_ext", exc, exc.output))
            else:
                print(output)

            try:
                output = subprocess.check_output(
                    ["sphinx-build", "-b", mode, "docs", path],
                    stderr=subprocess.STDOUT)
            except subprocess.CalledProcessError as exc:
                raise RuntimeError("Documentation step '%s' failed: %s: %s" % (mode, exc, exc.output))
            else:
                print(output)

            print("")
            print("Documentation step '%s' performed, results here:" % mode)
            print(" file://%s/%s/index.html" % (os.path.dirname(os.path.realpath(__file__)), path))


class BuildFailed(Exception):

    def __init__(self, ext):
        self.ext = ext


murmur3_ext = Extension('cassandra.cmurmur3',
                        sources=['cassandra/cmurmur3.c'])

libev_ext = Extension('cassandra.io.libevwrapper',
                      sources=['cassandra/io/libevwrapper.c'],
                      include_dirs=['/usr/include/libev', '/usr/local/include', '/opt/local/include'],
                      libraries=['ev'],
                      library_dirs=['/usr/local/lib', '/opt/local/lib'])

platform_unsupported_msg = \
"""
===============================================================================
The optional C extensions are not supported on this platform.
===============================================================================
"""

arch_unsupported_msg = \
"""
===============================================================================
The optional C extensions are not supported on big-endian systems.
===============================================================================
"""

pypy_unsupported_msg = \
"""
=================================================================================
Some optional C extensions are not supported in PyPy. Only murmur3 will be built.
=================================================================================
"""

is_windows = os.name == 'nt'

is_pypy = "PyPy" in sys.version
if is_pypy:
    sys.stderr.write(pypy_unsupported_msg)

is_supported_platform = sys.platform != "cli" and not sys.platform.startswith("java")
is_supported_arch = sys.byteorder != "big"
if not is_supported_platform:
    sys.stderr.write(platform_unsupported_msg)
elif not is_supported_arch:
    sys.stderr.write(arch_unsupported_msg)

try_extensions = "--no-extensions" not in sys.argv and is_supported_platform and is_supported_arch and not os.environ.get('CASS_DRIVER_NO_EXTENSIONS')
try_murmur3 = try_extensions and "--no-murmur3" not in sys.argv
try_libev = try_extensions and "--no-libev" not in sys.argv and not is_pypy and not is_windows
try_cython = try_extensions and "--no-cython" not in sys.argv and not is_pypy and not os.environ.get('CASS_DRIVER_NO_CYTHON')
try_cython &= 'egg_info' not in sys.argv  # bypass setup_requires for pip egg_info calls, which will never have --install-option"--no-cython" coming from pip

sys.argv = [a for a in sys.argv if a not in ("--no-murmur3", "--no-libev", "--no-cython", "--no-extensions")]

build_concurrency = int(os.environ.get('CASS_DRIVER_BUILD_CONCURRENCY', '0'))


class NoPatchExtension(Extension):

    # Older versions of setuptools.extension have a static flag which is set False before our
    # setup_requires lands Cython. It causes our *.pyx sources to be renamed to *.c in
    # the initializer.
    # The other workaround would be to manually generate sources, but that bypasses a lot
    # of the niceness cythonize embodies (setup build dir, conditional build, etc).
    # Newer setuptools does not have this problem because it checks for cython dynamically.
    # https://bitbucket.org/pypa/setuptools/commits/714c3144e08fd01a9f61d1c88411e76d2538b2e4

    def __init__(self, *args, **kwargs):
        # bypass the patched init if possible
        if Extension.__bases__:
            base, = Extension.__bases__
            base.__init__(self, *args, **kwargs)
        else:
            Extension.__init__(self, *args, **kwargs)


class build_extensions(build_ext):

    error_message = """
===============================================================================
WARNING: could not compile %s.

The C extensions are not required for the driver to run, but they add support
for token-aware routing with the Murmur3Partitioner.

On Windows, make sure Visual Studio or an SDK is installed, and your environment
is configured to build for the appropriate architecture (matching your Python runtime).
This is often a matter of using vcvarsall.bat from your install directory, or running
from a command prompt in the Visual Studio Tools Start Menu.
===============================================================================
""" if is_windows else """
===============================================================================
WARNING: could not compile %s.

The C extensions are not required for the driver to run, but they add support
for libev and token-aware routing with the Murmur3Partitioner.

Linux users should ensure that GCC and the Python headers are available.

On Ubuntu and Debian, this can be accomplished by running:

    $ sudo apt-get install build-essential python-dev

On RedHat and RedHat-based systems like CentOS and Fedora:

    $ sudo yum install gcc python-devel

On OSX, homebrew installations of Python should provide the necessary headers.

libev Support
-------------
For libev support, you will also need to install libev and its headers.

On Debian/Ubuntu:

    $ sudo apt-get install libev4 libev-dev

On RHEL/CentOS/Fedora:

    $ sudo yum install libev libev-devel

On OSX, via homebrew:

    $ brew install libev

===============================================================================
"""

    def run(self):
        try:
            self._setup_extensions()
            build_ext.run(self)
        except DistutilsPlatformError as exc:
            sys.stderr.write('%s\n' % str(exc))
            warnings.warn(self.error_message % "C extensions.")

    def build_extensions(self):
        if build_concurrency > 1:
            self.check_extensions_list(self.extensions)

            import multiprocessing.pool
            multiprocessing.pool.ThreadPool(processes=build_concurrency).map(self.build_extension, self.extensions)
        else:
            build_ext.build_extensions(self)

    def build_extension(self, ext):
        try:
            build_ext.build_extension(self, ext)
        except (CCompilerError, DistutilsExecError,
                DistutilsPlatformError, IOError) as exc:
            sys.stderr.write('%s\n' % str(exc))
            name = "The %s extension" % (ext.name,)
            warnings.warn(self.error_message % (name,))

    def _setup_extensions(self):
        # We defer extension setup until this command to leverage 'setup_requires' pulling in Cython before we
        # attempt to import anything
        self.extensions = []

        if try_murmur3:
            self.extensions.append(murmur3_ext)

        if try_libev:
            self.extensions.append(libev_ext)

        if try_cython:
            try:
                from Cython.Build import cythonize
                cython_candidates = ['cluster', 'concurrent', 'connection', 'cqltypes', 'metadata',
                                     'pool', 'protocol', 'query', 'util']
                compile_args = [] if is_windows else ['-Wno-unused-function']
                self.extensions.extend(cythonize(
                    [Extension('cassandra.%s' % m, ['cassandra/%s.py' % m],
                               extra_compile_args=compile_args)
                        for m in cython_candidates],
                    nthreads=build_concurrency,
                    exclude_failures=True))

                self.extensions.extend(cythonize(NoPatchExtension("*", ["cassandra/*.pyx"], extra_compile_args=compile_args),
                                                 nthreads=build_concurrency))
            except Exception:
                sys.stderr.write("Failed to cythonize one or more modules. These will not be compiled as extensions (optional).\n")


def pre_build_check():
    """
    Try to verify build tools
    """
    if os.environ.get('CASS_DRIVER_NO_PRE_BUILD_CHECK'):
        return True

    try:
        from distutils.ccompiler import new_compiler
        from distutils.sysconfig import customize_compiler
        from distutils.dist import Distribution

        # base build_ext just to emulate compiler option setup
        be = build_ext(Distribution())
        be.initialize_options()
        be.finalize_options()

        # First, make sure we have a Python include directory
        have_python_include = any(os.path.isfile(os.path.join(p, 'Python.h')) for p in be.include_dirs)
        if not have_python_include:
            sys.stderr.write("Did not find 'Python.h' in %s.\n" % (be.include_dirs,))
            return False

        compiler = new_compiler(compiler=be.compiler)
        customize_compiler(compiler)

        executables = []
        if compiler.compiler_type in ('unix', 'cygwin'):
            executables = [compiler.executables[exe][0] for exe in ('compiler_so', 'linker_so')]
        elif compiler.compiler_type == 'nt':
            executables = [getattr(compiler, exe) for exe in ('cc', 'linker')]

        if executables:
            from distutils.spawn import find_executable
            for exe in executables:
                if not find_executable(exe):
                    sys.stderr.write("Failed to find %s for compiler type %s.\n" % (exe, compiler.compiler_type))
                    return False

    except Exception as exc:
        sys.stderr.write('%s\n' % str(exc))
        sys.stderr.write("Failed pre-build check. Attempting anyway.\n")

    # if we are unable to positively id the compiler type, or one of these assumptions fails,
    # just proceed as we would have without the check
    return True


def run_setup(extensions):

    kw = {'cmdclass': {'doc': DocCommand}}
    if gevent_nosetests is not None:
        kw['cmdclass']['gevent_nosetests'] = gevent_nosetests

    if eventlet_nosetests is not None:
        kw['cmdclass']['eventlet_nosetests'] = eventlet_nosetests

    kw['cmdclass']['build_ext'] = build_extensions
    kw['ext_modules'] = [Extension('DUMMY', [])]  # dummy extension makes sure build_ext is called for install

    if try_cython:
        # precheck compiler before adding to setup_requires
        # we don't actually negate try_cython because:
        # 1.) build_ext eats errors at compile time, letting the install complete while producing useful feedback
        # 2.) there could be a case where the python environment has cython installed but the system doesn't have build tools
        if pre_build_check():
            kw['setup_requires'] = ['Cython>=0.20,<0.25']
        else:
            sys.stderr.write("Bypassing Cython setup requirement\n")

    dependencies = ['six >=1.6']

    if not PY3:
        dependencies.append('futures')

    setup(
        name='cassandra-driver',
        version=__version__,
        description='Python driver for Cassandra',
        long_description=long_description,
        url='http://github.com/datastax/python-driver',
        author='Tyler Hobbs',
        author_email='tyler@datastax.com',
        packages=['cassandra', 'cassandra.io', 'cassandra.cqlengine'],
        keywords='cassandra,cql,orm',
        include_package_data=True,
        install_requires=dependencies,
        tests_require=['nose', 'mock<=1.0.1', 'PyYAML', 'pytz', 'sure'],
        classifiers=[
            'Development Status :: 5 - Production/Stable',
            'Intended Audience :: Developers',
            'License :: OSI Approved :: Apache Software License',
            'Natural Language :: English',
            'Operating System :: OS Independent',
            'Programming Language :: Python',
            'Programming Language :: Python :: 2.6',
            'Programming Language :: Python :: 2.7',
            'Programming Language :: Python :: 3.3',
            'Programming Language :: Python :: 3.4',
            'Programming Language :: Python :: Implementation :: CPython',
            'Programming Language :: Python :: Implementation :: PyPy',
            'Topic :: Software Development :: Libraries :: Python Modules'
        ],
        **kw)

run_setup(None)

if has_cqlengine:
    warnings.warn("\n#######\n'cqlengine' package is present on path: %s\n"
                  "cqlengine is now an integrated sub-package of this driver.\n"
                  "It is recommended to remove this package to reduce the chance for conflicting usage" % cqlengine.__file__)
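setup.py decides which optional C extensions to attempt through several module-level flags. These flags combine the `--no-extensions`, `--no-murmur3`, `--no-libev` and `--no-cython` command-line switches, the `CASS_DRIVER_NO_EXTENSIONS` and `CASS_DRIVER_NO_CYTHON` environment variables, and the platform, byte-order and PyPy/Windows checks. The sketch below is an illustration only and is not part of the shipped setup.py. It restates those decisions as a pure function so the interaction of the flags is easy to check. `plan_extensions`, `CUSTOM_FLAGS` and all of the function's parameters are hypothetical names introduced here.

```python
import sys

# Hypothetical constant: the flags that setup.py consumes itself and strips
# from argv before setuptools sees it.
CUSTOM_FLAGS = ("--no-murmur3", "--no-libev", "--no-cython", "--no-extensions")


def plan_extensions(argv, environ, is_pypy=False, is_windows=False,
                    sys_platform=sys.platform, byteorder=sys.byteorder):
    """Hypothetical restatement of setup.py's try_* flag logic.

    Returns (plan, remaining_argv). plan maps each optional extension
    family to whether setup.py would *attempt* to build it.
    """
    # IronPython ("cli") and Jython ("java...") cannot build C extensions.
    supported_platform = sys_platform != "cli" and not sys_platform.startswith("java")
    # The C extensions are skipped on big-endian machines.
    supported_arch = byteorder != "big"

    try_extensions = ("--no-extensions" not in argv
                      and supported_platform
                      and supported_arch
                      and not environ.get("CASS_DRIVER_NO_EXTENSIONS"))

    plan = {
        # murmur3 is the only extension still attempted under PyPy.
        "murmur3": try_extensions and "--no-murmur3" not in argv,
        # libev is skipped on PyPy and on Windows.
        "libev": (try_extensions and "--no-libev" not in argv
                  and not is_pypy and not is_windows),
        # Cython is also skipped for pip's egg_info calls.
        "cython": (try_extensions and "--no-cython" not in argv
                   and not is_pypy
                   and not environ.get("CASS_DRIVER_NO_CYTHON")
                   and "egg_info" not in argv),
    }
    remaining = [a for a in argv if a not in CUSTOM_FLAGS]
    return plan, remaining


if __name__ == "__main__":
    plan, remaining = plan_extensions(["setup.py", "install", "--no-libev"], {},
                                      sys_platform="linux", byteorder="little")
    print(plan)       # {'murmur3': True, 'libev': False, 'cython': True}
    print(remaining)  # ['setup.py', 'install']
```

These values only say what gets *attempted*. In setup.py, `build_extension` still catches compiler errors and emits a warning instead of failing the install. Similarly, a failed `cythonize` call only writes a message to stderr.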