pax_global_header00006660000000000000000000000064120337277320014520gustar00rootroot0000000000000052 comment=82e9740c7023b4c34f74cd95e6599f5fa0fe9d14 fe-1.1.0/000077500000000000000000000000001203372773200121115ustar00rootroot00000000000000fe-1.1.0/.gitignore000066400000000000000000000000211203372773200140720ustar00rootroot00000000000000build dist *.pyc fe-1.1.0/AUTHORS000066400000000000000000000012131203372773200131560ustar00rootroot00000000000000Contributors: James William Pye [faults are mostly mine] Elvis Pranskevichus William Grzybowski [subjective paramstyle] Barry Grussling [inet/cidr support] Matthew Grant [inet/cidr support] Support by Donation: AppCove Network Further Credits =============== When licenses match, people win. Code is occasionally imported from other projects to enhance py-postgresql and to allow users to enjoy dependency free installation. DB-API 2.0 Test Case -------------------- postgresql/test/test_dbapi20.py: Stuart Bishop fcrypt ------ postgresql/resolved/crypt.py: Carey Evans fe-1.1.0/LICENSE000066400000000000000000000031371203372773200131220ustar00rootroot00000000000000BSD Licensed Software Unless stated otherwise, the contained software is copyright 2004-2009, James Williame Pye. For more information: http://python.projects.postgresql.org Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. Neither the name of the James William Pye nor the names of [its] contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. fe-1.1.0/MANIFEST.in000066400000000000000000000003711203372773200136500ustar00rootroot00000000000000include AUTHORS include LICENSE recursive-include postgresql *.c recursive-include postgresql *.sql recursive-include postgresql *.txt recursive-include postgresql/documentation/sphinx *.rst conf.py recursive-include postgresql/documentation/html * fe-1.1.0/README000066400000000000000000000021501203372773200127670ustar00rootroot00000000000000About ===== py-postgresql is a Python 3 package providing modules to work with PostgreSQL. This includes a high-level driver, and many other tools that support a developer working with PostgreSQL databases. 
Installation ------------ Installation *should* be as simple as:: $ python3 ./setup.py install More information about installation is available via:: python -m postgresql.documentation.admin Basic Driver Usage ------------------ Using PG-API:: >>> import postgresql >>> db = postgresql.open('pq://user:password@host:port/database') >>> get_table = db.prepare("select * from information_schema.tables where table_name = $1") >>> for x in get_table("tables"): >>> print(x) >>> print(get_table.first("tables")) However, a DB-API 2.0 driver is provided as well: `postgresql.driver.dbapi20`. Further Information ------------------- Online documentation can be retrieved from: http://python.projects.postgresql.org Or, you can read them in your pager: python -m postgresql.documentation.index For information about PostgreSQL: http://postgresql.org For information about Python: http://python.org fe-1.1.0/postgresql/000077500000000000000000000000001203372773200143145ustar00rootroot00000000000000fe-1.1.0/postgresql/__init__.py000066400000000000000000000052221203372773200164260ustar00rootroot00000000000000## # py-postgresql root package # http://python.projects.postgresql.org ## """ py-postgresql is a Python package for using PostgreSQL. This includes low-level protocol tools, a driver(PG-API and DB-API), and cluster management tools. If it's not documented in the narratives, `postgresql.documentation.index`, then the stability of the APIs should *not* be trusted. See for more information about PostgreSQL. """ __all__ = [ '__author__', '__date__', '__version__', '__docformat__', 'version', 'version_info', 'open', ] #: The version string of py-postgresql. version = '' # overridden by subsequent import from .project. #: The version triple of py-postgresql: (major, minor, patch). version_info = () # overridden by subsequent import from .project. # Optional. try: from .project import version_info, version, \ author as __author__, date as __date__ __version__ = version except ImportError: pass # Avoid importing these until requested. _pg_iri = _pg_driver = _pg_param = None def open(iri = None, prompt_title = None, **kw): """ Create a `postgresql.api.Connection` to the server referenced by the given `iri`:: >>> import postgresql # General Format: >>> db = postgresql.open('pq://user:password@host:port/database') # Connect to 'postgres' at localhost. >>> db = postgresql.open('localhost/postgres') Connection keywords can also be used with `open`. See the narratives for more information. The `prompt_title` keyword is ignored. `open` will never prompt for the password unless it is explicitly instructed to do so. (Note: "pq" is the name of the protocol used to communicate with PostgreSQL) """ global _pg_iri, _pg_driver, _pg_param if _pg_iri is None: from . import iri as _pg_iri from . import driver as _pg_driver from . import clientparameters as _pg_param return_connector = False if iri is not None: if iri.startswith('&'): return_connector = True iri = iri[1:] iri_params = _pg_iri.parse(iri) iri_params.pop('path', None) else: iri_params = {} std_params = _pg_param.collect(prompt_title = None) # If unix is specified, it's going to conflict with any standard # settings, so remove them right here. 
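	# (A unix-domain socket path and a TCP host/port are mutually exclusive
	# targets, so the standard host/port defaults must not be allowed to
	# override an explicit 'unix' keyword.)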
if 'unix' in kw or 'unix' in iri_params: std_params.pop('host', None) std_params.pop('port', None) params = _pg_param.normalize( list(_pg_param.denormalize_parameters(std_params)) + \ list(_pg_param.denormalize_parameters(iri_params)) + \ list(_pg_param.denormalize_parameters(kw)) ) _pg_param.resolve_password(params) C = _pg_driver.default.fit(**params) if return_connector is True: return C else: c = C() c.connect() return c __docformat__ = 'reStructuredText' fe-1.1.0/postgresql/alock.py000066400000000000000000000104551203372773200157640ustar00rootroot00000000000000## # .alock - Advisory Locks ## """ Tools for Advisory Locks """ from abc import abstractmethod, abstractproperty from .python.element import Element __all__ = [ 'ALock', 'ExclusiveLock', 'ShareLock', ] class ALock(Element): """ Advisory Lock class for managing the acquisition and release of a sequence of PostgreSQL advisory locks. ALock()'s are fairly consistent with threading.RLock()'s. They can be acquired multiple times, and they must be released the same number of times for the lock to actually be released. A notably difference is that ALock's manage a sequence of lock identifiers. This means that a given ALock() may represent multiple advisory locks. """ _e_factors = ('database', 'identifiers',) _e_label = 'ALOCK' def _e_metas(self, headfmt = "{1} [{0}]".format ): yield None, headfmt(self.state, self.mode) @abstractproperty def mode(self): """ The mode of the lock class. """ @abstractproperty def __select_statements__(self): """ Implemented by subclasses to return the statements to try, acquire, and release the advisory lock. Returns a triple of callables where each callable takes two arguments, the lock-id pairs, and then the int8 lock-ids. ``(try, acquire, release)``. """ @staticmethod def _split_lock_identifiers(idseq): # lame O(2) id_pairs = [ list(x) if x.__class__ is not int else [None,None] for x in idseq ] ids = [ x if x.__class__ is int else None for x in idseq ] return (id_pairs, ids) def acquire(self, blocking = True, len = len): """ Acquire the locks using the configured identifiers. """ if self._count == 0: # _count is zero, so the locks need to be acquired. wait = bool(blocking) if wait: self._acquire(self._id_pairs, self._ids) else: # grab the success of each lock id. if some were # unsuccessful, then the ones that were successful need to be # released. r = self._try(self._id_pairs, self._ids) # accumulate the identifiers that *did* lock release_seq = [ id for didlock, id in zip(r, self.identifiers) if didlock[0] ] if len(release_seq) != len(self.identifiers): # some failed, so release the acquired and return False # # reverse in case there is another waiting for all. # that is, release last-to-first so that if another is waiting # on the same seq that it should be able to acquire all of # them once the contended lock is released. release_seq.reverse() self._release(*self._split_lock_identifiers(release_seq)) # unable to acquire all. return False self._count = self._count + 1 return True def __enter__(self): self.acquire() return self def release(self): """ Release the locks using the configured identifiers. """ if self._count < 1: raise RuntimeError("cannot release un-acquired lock") if not self.database.closed and self._count > 0: # if the database has been closed, or the count will # remain non-zero, there is no need to release. self._release(reversed(self._id_pairs), reversed(self._ids)) # decrement the count nonetheless. 
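		# The count is maintained even when the connection is gone so that
		# locked()/release() accounting stays consistent with acquire() calls.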
self._count = self._count - 1 def __exit__(self, typ, val, tb): self.release() def locked(self): """ Whether the locks have been acquired. This method is sensitive to the connection's state. If the connection is closed, it will return False. """ return (self._count > 0) and (not self.database.closed) @property def state(self): return 'locked' if self.locked() else 'unlocked' def __init__(self, database, *identifiers): """ Initialize the lock object to manage a sequence of advisory locks for use with the given database. """ self._count = 0 self.connection = self.database = database self.identifiers = identifiers self._id_pairs, self._ids = self._split_lock_identifiers(identifiers) self._try, self._acquire, self._release = self.__select_statements__() class ShareLock(ALock): mode = 'share' def __select_statements__(self): return ( self.database.sys.try_advisory_shared, self.database.sys.acquire_advisory_shared, self.database.sys.release_advisory_shared, ) class ExclusiveLock(ALock): mode = 'exclusive' def __select_statements__(self): return ( self.database.sys.try_advisory_exclusive, self.database.sys.acquire_advisory_exclusive, self.database.sys.release_advisory_exclusive, ) fe-1.1.0/postgresql/api.py000066400000000000000000001072561203372773200154520ustar00rootroot00000000000000## # .api - ABCs for database interface elements ## """ Application Programmer Interfaces for PostgreSQL. ``postgresql.api`` is a collection of Python APIs for the PostgreSQL DBMS. It is designed to take full advantage of PostgreSQL's features to provide the Python programmer with substantial convenience. This module is used to define "PG-API". It creates a set of ABCs that makes up the basic interfaces used to work with a PostgreSQL server. """ import collections import abc from .python.element import Element __all__ = [ 'Message', 'Statement', 'Chunks', 'Cursor', 'Connector', 'Category', 'Database', 'TypeIO', 'Connection', 'Transaction', 'Settings', 'StoredProcedure', 'Driver', 'Installation', 'Cluster', ] class Message(Element): """ A message emitted by PostgreSQL. A message being a NOTICE, WARNING, INFO, etc. """ _e_label = 'MESSAGE' severities = ( 'DEBUG', 'INFO', 'NOTICE', 'WARNING', 'ERROR', 'FATAL', 'PANIC', ) sources = ( 'SERVER', 'CLIENT', ) @property @abc.abstractmethod def source(self) -> str: """ Where the message originated from. Normally, 'SERVER', but sometimes 'CLIENT'. """ @property @abc.abstractmethod def code(self) -> str: """ The SQL state code of the message. """ @property @abc.abstractmethod def message(self) -> str: """ The primary message string. """ @property @abc.abstractmethod def details(self) -> dict: """ The additional details given with the message. Common keys *should* be the following: * 'severity' * 'context' * 'detail' * 'hint' * 'file' * 'line' * 'function' * 'position' * 'internal_position' * 'internal_query' """ @abc.abstractmethod def isconsistent(self, other) -> bool: """ Whether the fields of the `other` Message object is consistent with the fields of `self`. This *must* return the result of the comparison of code, source, message, and details. This method is provided as the alternative to overriding equality; often, pointer equality is the desirable means for comparison, but equality of the fields is also necessary. """ class Result(Element): """ A result is an object managing the results of a prepared statement. These objects represent a binding of parameters to a given statement object. 
For results that were constructed on the server and a reference passed back to the client, statement and parameters may be None. """ _e_label = 'RESULT' _e_factors = ('statement', 'parameters', 'cursor_id') @abc.abstractmethod def close(self) -> None: """ Close the Result handle. """ @property @abc.abstractmethod def cursor_id(self) -> str: """ The cursor's identifier. """ @property @abc.abstractmethod def sql_column_types(self) -> [str]: """ The type of the columns produced by the cursor. A sequence of `str` objects stating the SQL type name:: ['INTEGER', 'CHARACTER VARYING', 'INTERVAL'] """ @property @abc.abstractmethod def pg_column_types(self) -> [int]: """ The type Oids of the columns produced by the cursor. A sequence of `int` objects stating the SQL type name:: [27, 28] """ @property @abc.abstractmethod def column_names(self) -> [str]: """ The attribute names of the columns produced by the cursor. A sequence of `str` objects stating the column name:: ['column1', 'column2', 'emp_name'] """ @property @abc.abstractmethod def column_types(self) -> [str]: """ The Python types of the columns produced by the cursor. A sequence of type objects:: [, ] """ @property @abc.abstractmethod def parameters(self) -> (tuple, None): """ The parameters bound to the cursor. `None`, if unknown and an empty tuple `()`, if no parameters were given. These should be the *original* parameters given to the invoked statement. This should only be `None` when the cursor is created from an identifier, `postgresql.api.Database.cursor_from_id`. """ @property @abc.abstractmethod def statement(self) -> ("Statement", None): """ The query object used to create the cursor. `None`, if unknown. This should only be `None` when the cursor is created from an identifier, `postgresql.api.Database.cursor_from_id`. """ class Chunks( Result, collections.Iterator, collections.Iterable, ): pass class Cursor( Result, collections.Iterator, collections.Iterable, ): """ A `Cursor` object is an interface to a sequence of tuples(rows). A result set. Cursors publish a file-like interface for reading tuples from a cursor declared on the database. `Cursor` objects are created by invoking the `Statement.declare` method or by opening a cursor using an identifier via the `Database.cursor_from_id` method. """ _e_label = 'CURSOR' _seek_whence_map = { 0 : 'ABSOLUTE', 1 : 'RELATIVE', 2 : 'FROM_END', 3 : 'FORWARD', 4 : 'BACKWARD' } _direction_map = { True : 'FORWARD', False : 'BACKWARD', } @abc.abstractmethod def clone(self) -> "Cursor": """ Create a new cursor using the same factors as `self`. """ def __iter__(self): return self @property @abc.abstractmethod def direction(self) -> bool: """ The default `direction` argument for read(). When `True` reads are FORWARD. When `False` reads are BACKWARD. Cursor operation option. """ @abc.abstractmethod def read(self, quantity : "Number of rows to read" = None, direction : "Direction to fetch in, defaults to `self.direction`" = None, ) -> ["Row"]: """ Read, fetch, the specified number of rows and return them in a list. If quantity is `None`, all records will be fetched. `direction` can be used to override the default configured direction. This alters the cursor's position. Read does not directly correlate to FETCH. If zero is given as the quantity, an empty sequence *must* be returned. """ @abc.abstractmethod def __next__(self) -> "Row": """ Get the next tuple in the cursor. Advances the cursor position by one. 
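
		An illustrative sketch (the statement text is an assumption, not part
		of the API contract):

		>>> c = db.prepare("SELECT i FROM generate_series(1, 3) AS g(i)").declare()
		>>> next(c)
		(1,)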
""" @abc.abstractmethod def seek(self, offset, whence = 'ABSOLUTE'): """ Set the cursor's position to the given offset with respect to the whence parameter and the configured direction. Whence values: ``0`` or ``"ABSOLUTE"`` Absolute. ``1`` or ``"RELATIVE"`` Relative. ``2`` or ``"FROM_END"`` Absolute from end. ``3`` or ``"FORWARD"`` Relative forward. ``4`` or ``"BACKWARD"`` Relative backward. Direction effects whence. If direction is BACKWARD, ABSOLUTE positioning will effectively be FROM_END, RELATIVE's position will be negated, and FROM_END will effectively be ABSOLUTE. """ class Execution(metaclass = abc.ABCMeta): """ The abstract class of execution methods. """ @abc.abstractmethod def __call__(self, *parameters : "Positional Parameters") -> ["Row"]: """ Execute the prepared statement with the given arguments as parameters. Usage: >>> p=db.prepare("SELECT column FROM ttable WHERE key = $1") >>> p('identifier') [...] """ @abc.abstractmethod def column(self, *parameters) -> collections.Iterable: """ Return an iterator producing the values of first column of the rows produced by the cursor created from the statement bound with the given parameters. Column iterators are never scrollable. Supporting cursors will be WITH HOLD when outside of a transaction to allow cross-transaction access. `column` is designed for the situations involving large data sets. Each iteration returns a single value. column expressed in sibling terms:: return map(operator.itemgetter(0), self.rows(*parameters)) """ @abc.abstractmethod def chunks(self, *parameters) -> collections.Iterable: """ Return an iterator producing sequences of rows produced by the cursor created from the statement bound with the given parameters. Chunking iterators are *never* scrollable. Supporting cursors will be WITH HOLD when outside of a transaction. `chunks` is designed for moving large data sets efficiently. Each iteration returns sequences of rows *normally* of length(seq) == chunksize. If chunksize is unspecified, a default, positive integer will be filled in. The rows contained in the sequences are only required to support the basic `collections.Sequence` interfaces; simple and quick sequence types should be used. """ @abc.abstractmethod def rows(self, *parameters) -> collections.Iterable: """ Return an iterator producing rows produced by the cursor created from the statement bound with the given parameters. Row iterators are never scrollable. Supporting cursors will be WITH HOLD when outside of a transaction to allow cross-transaction access. `rows` is designed for the situations involving large data sets. Each iteration returns a single row. Arguably, best implemented:: return itertools.chain.from_iterable(self.chunks(*parameters)) """ @abc.abstractmethod def column(self, *parameters) -> collections.Iterable: """ Return an iterator producing the values of the first column in the cursor created from the statement bound with the given parameters. Column iterators are never scrollable. Supporting cursors will be WITH HOLD when outside of a transaction to allow cross-transaction access. `column` is designed for the situations involving large data sets. Each iteration returns a single value. `column` is equivalent to:: return map(operator.itemgetter(0), self.rows(*parameters)) """ @abc.abstractmethod def declare(self, *parameters) -> Cursor: """ Return a scrollable cursor with hold using the statement bound with the given parameters. 
""" @abc.abstractmethod def first(self, *parameters) -> "'First' object that is returned by the query": """ Execute the prepared statement with the given arguments as parameters. If the statement returns rows with multiple columns, return the first row. If the statement returns rows with a single column, return the first column in the first row. If the query does not return rows at all, return the count or `None` if no count exists in the completion message. Usage: >>> db.prepare("SELECT * FROM ttable WHERE key = $1").first("somekey") ('somekey', 'somevalue') >>> db.prepare("SELECT 'foo'").first() 'foo' >>> db.prepare("INSERT INTO atable (col) VALUES (1)").first() 1 """ @abc.abstractmethod def load_rows(self, iterable : "A iterable of tuples to execute the statement with" ): """ Given an iterable, `iterable`, feed the produced parameters to the query. This is a bulk-loading interface for parameterized queries. Effectively, it is equivalent to: >>> q = db.prepare(sql) >>> for i in iterable: ... q(*i) Its purpose is to allow the implementation to take advantage of the knowledge that a series of parameters are to be loaded so that the operation can be optimized. """ @abc.abstractmethod def load_chunks(self, iterable : "A iterable of chunks of tuples to execute the statement with" ): """ Given an iterable, `iterable`, feed the produced parameters of the chunks produced by the iterable to the query. This is a bulk-loading interface for parameterized queries. Effectively, it is equivalent to: >>> ps = db.prepare(...) >>> for c in iterable: ... for i in c: ... q(*i) Its purpose is to allow the implementation to take advantage of the knowledge that a series of chunks of parameters are to be loaded so that the operation can be optimized. """ class Statement( Element, collections.Callable, collections.Iterable, ): """ Instances of `Statement` are returned by the `prepare` method of `Database` instances. A Statement is an Iterable as well as Callable. The Iterable interface is supported for queries that take no arguments at all. It allows the syntax:: >>> for x in db.prepare('select * FROM table'): ... pass """ _e_label = 'STATEMENT' _e_factors = ('database', 'statement_id', 'string',) @property @abc.abstractmethod def statement_id(self) -> str: """ The statment's identifier. """ @property @abc.abstractmethod def string(self) -> object: """ The SQL string of the prepared statement. `None` if not available. This can happen in cases where a statement is prepared on the server and a reference to the statement is sent to the client which subsequently uses the statement via the `Database`'s `statement` constructor. """ @property @abc.abstractmethod def sql_parameter_types(self) -> [str]: """ The type of the parameters required by the statement. A sequence of `str` objects stating the SQL type name:: ['INTEGER', 'VARCHAR', 'INTERVAL'] """ @property @abc.abstractmethod def sql_column_types(self) -> [str]: """ The type of the columns produced by the statement. A sequence of `str` objects stating the SQL type name:: ['INTEGER', 'VARCHAR', 'INTERVAL'] """ @property @abc.abstractmethod def pg_parameter_types(self) -> [int]: """ The type Oids of the parameters required by the statement. A sequence of `int` objects stating the PostgreSQL type Oid:: [27, 28] """ @property @abc.abstractmethod def pg_column_types(self) -> [int]: """ The type Oids of the columns produced by the statement. 
A sequence of `int` objects stating the SQL type name:: [27, 28] """ @property @abc.abstractmethod def column_names(self) -> [str]: """ The attribute names of the columns produced by the statement. A sequence of `str` objects stating the column name:: ['column1', 'column2', 'emp_name'] """ @property @abc.abstractmethod def column_types(self) -> [type]: """ The Python types of the columns produced by the statement. A sequence of type objects:: [, ] """ @property @abc.abstractmethod def parameter_types(self) -> [type]: """ The Python types expected of parameters given to the statement. A sequence of type objects:: [, ] """ @abc.abstractmethod def clone(self) -> "Statement": """ Create a new statement object using the same factors as `self`. When used for refreshing plans, the new clone should replace references to the original. """ @abc.abstractmethod def close(self) -> None: """ Close the prepared statement releasing resources associated with it. """ Execution.register(Statement) PreparedStatement = Statement class StoredProcedure( Element, collections.Callable, ): """ A function stored on the database. """ _e_label = 'FUNCTION' _e_factors = ('database',) @abc.abstractmethod def __call__(self, *args, **kw) -> (object, Cursor, collections.Iterable): """ Execute the procedure with the given arguments. If keyword arguments are passed they must be mapped to the argument whose name matches the key. If any positional arguments are given, they must fill in gaps created by the stated keyword arguments. If too few or too many arguments are given, a TypeError must be raised. If a keyword argument is passed where the procedure does not have a corresponding argument name, then, likewise, a TypeError must be raised. In the case where the `StoredProcedure` references a set returning function(SRF), the result *must* be an iterable. SRFs that return single columns *must* return an iterable of that column; not row data. If the SRF returns a composite(OUT parameters), it *should* return a `Cursor`. """ ## # Arguably, it would be wiser to isolate blocks, and savepoints, but the utility # of the separation is not significant. It's really # more interesting as a formality that the user may explicitly state the # type of the transaction. However, this capability is not completely absent # from the current interface as the configuration parameters, or lack thereof, # help imply the expectations. class Transaction(Element): """ A `Tranaction` is an element that represents a transaction in the session. Once created, it's ready to be started, and subsequently committed or rolled back. Read-only transaction: >>> with db.xact(mode = 'read only'): ... ... Read committed isolation: >>> with db.xact(isolation = 'READ COMMITTED'): ... ... Savepoints are created if inside a transaction block: >>> with db.xact(): ... with db.xact(): ... ... """ _e_label = 'XACT' _e_factors = ('database',) @property @abc.abstractmethod def mode(self) -> (None, str): """ The mode of the transaction block: START TRANSACTION [ISOLATION] ; The `mode` property is a string and will be directly interpolated into the START TRANSACTION statement. """ @property @abc.abstractmethod def isolation(self) -> (None, str): """ The isolation level of the transaction block: START TRANSACTION [MODE]; The `isolation` property is a string and will be directly interpolated into the START TRANSACTION statement. """ @abc.abstractmethod def start(self) -> None: """ Start the transaction. 
If the database is in a transaction block, the transaction should be configured as a savepoint. If any transaction block configuration was applied to the transaction, raise a `postgresql.exceptions.OperationError`. If the database is not in a transaction block, start one using the configuration where: `self.isolation` specifies the ``ISOLATION LEVEL``. Normally, ``READ COMMITTED``, ``SERIALIZABLE``, or ``READ UNCOMMITTED``. `self.mode` specifies the mode of the transaction. Normally, ``READ ONLY`` or ``READ WRITE``. If the transaction is already open, do nothing. If the transaction has been committed or aborted, raise an `postgresql.exceptions.OperationError`. """ begin = start @abc.abstractmethod def commit(self) -> None: """ Commit the transaction. If the transaction is a block, issue a COMMIT statement. If the transaction was started inside a transaction block, it should be identified as a savepoint, and the savepoint should be released. If the transaction has already been committed, do nothing. """ @abc.abstractmethod def rollback(self) -> None: """ Abort the transaction. If the transaction is a savepoint, ROLLBACK TO the savepoint identifier. If the transaction is a transaction block, issue an ABORT. If the transaction has already been aborted, do nothing. """ abort = rollback @abc.abstractmethod def __enter__(self): """ Run the `start` method and return self. """ @abc.abstractmethod def __exit__(self, typ, obj, tb): """ If an exception is indicated by the parameters, run the transaction's `rollback` method iff the database is still available(not closed), and return a `False` value. If an exception is not indicated, but the database's transaction state is in error, run the transaction's `rollback` method and raise a `postgresql.exceptions.InFailedTransactionError`. If the database is unavailable, the `rollback` method should cause a `postgresql.exceptions.ConnectionDoesNotExistError` exception to occur. Otherwise, run the transaction's `commit` method. When the `commit` is ultimately unsuccessful or not ran at all, the purpose of __exit__ is to resolve the error state of the database iff the database is available(not closed) so that more commands can be after the block's exit. """ class Settings( Element, collections.MutableMapping ): """ A mapping interface to the session's settings. This provides a direct interface to ``SHOW`` or ``SET`` commands. Identifiers and values need not be quoted specially as the implementation must do that work for the user. """ _e_label = 'SETTINGS' @abc.abstractmethod def __getitem__(self, key): """ Return the setting corresponding to the given key. The result should be consistent with what the ``SHOW`` command returns. If the key does not exist, raise a KeyError. """ @abc.abstractmethod def __setitem__(self, key, value): """ Set the setting with the given key to the given value. The action should be consistent with the effect of the ``SET`` command. """ @abc.abstractmethod def __call__(self, **kw): """ Create a context manager applying the given settings on __enter__ and restoring the old values on __exit__. >>> with db.settings(search_path = 'local,public'): ... ... """ @abc.abstractmethod def get(self, key, default = None): """ Get the setting with the corresponding key. If the setting does not exist, return the `default`. """ @abc.abstractmethod def getset(self, keys): """ Return a dictionary containing the key-value pairs of the requested settings. 
If *any* of the keys do not exist, a `KeyError` must be raised with the set of keys that did not exist. """ @abc.abstractmethod def update(self, mapping): """ For each key-value pair, incur the effect of the `__setitem__` method. """ @abc.abstractmethod def keys(self): """ Return an iterator to all of the settings' keys. """ @abc.abstractmethod def values(self): """ Return an iterator to all of the settings' values. """ @abc.abstractmethod def items(self): """ Return an iterator to all of the setting value pairs. """ class Database(Element): """ The interface to an individual database. `Connection` objects inherit from this """ _e_label = 'DATABASE' @property @abc.abstractmethod def backend_id(self) -> (int, None): """ The backend's process identifier. """ @property @abc.abstractmethod def version_info(self) -> tuple: """ A version tuple of the database software similar Python's `sys.version_info`. >>> db.version_info (8, 1, 3, '', 0) """ @property @abc.abstractmethod def client_address(self) -> (str, None): """ The client address that the server sees. This is obtainable by querying the ``pg_catalog.pg_stat_activity`` relation. `None` if unavailable. """ @property @abc.abstractmethod def client_port(self) -> (int, None): """ The client port that the server sees. This is obtainable by querying the ``pg_catalog.pg_stat_activity`` relation. `None` if unavailable. """ @property @abc.abstractmethod def xact(self, isolation : "ISOLATION LEVEL to use with the transaction" = None, mode : "Mode of the transaction, READ ONLY or READ WRITE" = None, ) -> Transaction: """ Create a `Transaction` object using the given keyword arguments as its configuration. """ @property @abc.abstractmethod def settings(self) -> Settings: """ A `Settings` instance bound to the `Database`. """ @abc.abstractmethod def do(language, source) -> None: """ Execute a DO statement using the given language and source. Always returns `None`. Likely to be a function of Connection.execute. """ @abc.abstractmethod def execute(sql) -> None: """ Execute an arbitrary block of SQL. Always returns `None` and raise an exception on error. """ @abc.abstractmethod def prepare(self, sql : str) -> Statement: """ Create a new `Statement` instance bound to the connection using the given SQL. >>> s = db.prepare("SELECT 1") >>> c = s() >>> c.next() (1,) """ @abc.abstractmethod def statement_from_id(self, statement_id : "The statement's identification string.", ) -> Statement: """ Create a `Statement` object that was already prepared on the server. The distinction between this and a regular query is that it must be explicitly closed if it is no longer desired, and it is instantiated using the statement identifier as opposed to the SQL statement itself. """ @abc.abstractmethod def cursor_from_id(self, cursor_id : "The cursor's identification string." ) -> Cursor: """ Create a `Cursor` object from the given `cursor_id` that was already declared on the server. `Cursor` objects created this way must *not* be closed when the object is garbage collected. Rather, the user must explicitly close it for the server resources to be released. This is in contrast to `Cursor` objects that are created by invoking a `Statement` or a SRF `StoredProcedure`. """ @abc.abstractmethod def proc(self, procedure_id : \ "The procedure identifier; a valid ``regprocedure`` or Oid." ) -> StoredProcedure: """ Create a `StoredProcedure` instance using the given identifier. 
The `proc_id` given can be either an ``Oid``, or a ``regprocedure`` that identifies the stored procedure to create the interface for. >>> p = db.proc('version()') >>> p() 'PostgreSQL 8.3.0' >>> qstr = "select oid from pg_proc where proname = 'generate_series'" >>> db.prepare(qstr).first() 1069 >>> generate_series = db.proc(1069) >>> list(generate_series(1,5)) [1, 2, 3, 4, 5] """ @abc.abstractmethod def reset(self) -> None: """ Reset the connection into it's original state. Issues a ``RESET ALL`` to the database. If the database supports removing temporary tables created in the session, then remove them. Reapply initial configuration settings such as path. The purpose behind this method is to provide a soft-reconnect method that re-initializes the connection into its original state. One obvious use of this would be in a connection pool where the connection is being recycled. """ @abc.abstractmethod def notify(self, *channels, **channel_and_payload) -> int: """ NOTIFY the channels with the given payload. Equivalent to issuing "NOTIFY " or "NOTIFY , " for each item in `channels` and `channel_and_payload`. All NOTIFYs issued *must* occur in the same transaction. The items in `channels` can either be a string or a tuple. If a string, no payload is given, but if an item is a `builtins.tuple`, the second item will be given as the payload. `channels` offers a means to issue NOTIFYs in guaranteed order. The items in `channel_and_payload` are all payloaded NOTIFYs where the keys are the channels and the values are the payloads. Order is undefined. """ @abc.abstractmethod def listen(self, *channels) -> None: """ Start listening to the given channels. Equivalent to issuing "LISTEN " for x in channels. """ @abc.abstractmethod def unlisten(self, *channels) -> None: """ Stop listening to the given channels. Equivalent to issuing "UNLISTEN " for x in channels. """ @abc.abstractmethod def listening_channels(self) -> ["channel name", ...]: """ Return an *iterator* to all the channels currently being listened to. """ @abc.abstractmethod def iternotifies(self, timeout = None) -> collections.Iterator: """ Return an iterator to the notifications received by the connection. The iterator *must* produce triples in the form ``(channel, payload, pid)``. If timeout is not `None`, `None` *must* be emitted at the specified timeout interval. If the timeout is zero, all the pending notifications *must* be yielded by the iterator and then `StopIteration` *must* be raised. If the connection is closed for any reason, the iterator *must* silently stop by raising `StopIteration`. Further error control is then the responsibility of the user. """ class TypeIO(Element): _e_label = 'TYPIO' def _e_metas(self): return () class SocketFactory(object): @property @abc.abstractmethod def fatal_exception(self) -> Exception: """ The exception that is raised by sockets that indicate a fatal error. The exception can be a base exception as the `fatal_error_message` will indicate if that particular exception is actually fatal. """ @property @abc.abstractmethod def timeout_exception(self) -> Exception: """ The exception raised by the socket when an operation could not be completed due to a configured time constraint. """ @property @abc.abstractmethod def tryagain_exception(self) -> Exception: """ The exception raised by the socket when an operation was interrupted, but should be tried again. """ @property @abc.abstractmethod def tryagain(self, err : Exception) -> bool: """ Whether or not `err` suggests the operation should be tried again. 
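
		For example, an implementation built on BSD sockets would normally
		report `True` for an ``EINTR`` failure and `False` for a fatal error
		such as ``EBADF`` (an assumed illustration, not required behavior).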
""" @abc.abstractmethod def fatal_exception_message(self, err : Exception) -> (str, None): """ A function returning a string describing the failure, this string will be given to the `postgresql.exceptions.ConnectionFailure` instance that will subsequently be raised by the `Connection` object. Returns `None` when `err` is not actually fatal. """ @abc.abstractmethod def socket_secure(self, socket : "socket object") -> "secured socket": """ Return a reference to the secured socket using the given parameters. If securing the socket for the connector is impossible, the user should never be able to instantiate the connector with parameters requesting security. """ @abc.abstractmethod def socket_factory_sequence(self) -> [collections.Callable]: """ Return a sequence of `SocketCreator`s that `Connection` objects will use to create the socket object. """ class Category(Element): """ A category is an object that initializes the subject connection for a specific purpose. Arguably, a runtime class for use with connections. """ _e_label = 'CATEGORY' _e_factors = () @abc.abstractmethod def __call__(self, connection): """ Initialize the given connection in order to conform to the category. """ class Connector(Element): """ A connector is an object providing the necessary information to establish a connection. This includes credentials, database settings, and many times addressing information. """ _e_label = 'CONNECTOR' _e_factors = ('driver', 'category') def __call__(self, *args, **kw): """ Create and connect. Arguments will be given to the `Connection` instance's `connect` method. """ return self.driver.connection(self, *args, **kw) def __init__(self, user : "required keyword specifying the user name(str)" = None, password : str = None, database : str = None, settings : (dict, [(str,str)]) = None, category : Category = None, ): if user is None: # sure, it's a "required" keyword, makes for better documentation raise TypeError("'user' is a required keyword") self.user = user self.password = password self.database = database self.settings = settings self.category = category if category is not None and not isinstance(category, Category): raise TypeError("'category' must a be `None` or `postgresql.api.Category`") class Connection(Database): """ The interface to a connection to a PostgreSQL database. This is a `Database` interface with the additional connection management tools that are particular to using a remote database. """ _e_label = 'CONNECTION' _e_factors = ('connector',) @property @abc.abstractmethod def connector(self) -> Connector: """ The :py:class:`Connector` instance facilitating the `Connection` object's communication and initialization. """ @property @abc.abstractmethod def query(self) -> Execution: """ The :py:class:`Execution` instance providing a one-shot query interface:: connection.query.(sql, *parameters) == connection.prepare(sql).(*parameters) """ @property @abc.abstractmethod def closed(self) -> bool: """ `True` if the `Connection` is closed, `False` if the `Connection` is open. >>> db.closed True """ @abc.abstractmethod def clone(self) -> "Connection": """ Create another connection using the same factors as `self`. The returned object should be open and ready for use. """ @abc.abstractmethod def connect(self) -> None: """ Establish the connection to the server and initialize the category. Does nothing if the connection is already established. """ cat = self.connector.category if cat is not None: cat(self) @abc.abstractmethod def close(self) -> None: """ Close the connection. 
Does nothing if the connection is already closed. """ @abc.abstractmethod def __enter__(self): """ Establish the connection and return self. """ @abc.abstractmethod def __exit__(self, typ, obj, tb): """ Closes the connection and returns `False` when an exception is passed in, `True` when `None`. """ class Driver(Element): """ The `Driver` element provides the `Connector` and other information pertaining to the implementation of the driver. Information about what the driver supports is available in instances. """ _e_label = "DRIVER" _e_factors = () @abc.abstractmethod def connect(**kw): """ Create a connection using the given parameters for the Connector. """ class Installation(Element): """ Interface to a PostgreSQL installation. Instances would provide various information about an installation of PostgreSQL accessible by the Python """ _e_label = "INSTALLATION" _e_factors = () @property @abc.abstractmethod def version(self): """ A version string consistent with what `SELECT version()` would output. """ @property @abc.abstractmethod def version_info(self): """ A tuple specifying the version in a form similar to Python's sys.version_info. (8, 3, 3, 'final', 0) See `postgresql.versionstring`. """ @property @abc.abstractmethod def type(self): """ The "type" of PostgreSQL. Normally, the first component of the string returned by pg_config. """ @property @abc.abstractmethod def ssl(self) -> bool: """ Whether the installation supports SSL. """ class Cluster(Element): """ Interface to a PostgreSQL cluster--a data directory. An implementation of this provides a means to control a server. """ _e_label = 'CLUSTER' _e_factors = ('installation', 'data_directory') @property @abc.abstractmethod def installation(self) -> Installation: """ The installation used by the cluster. """ @property @abc.abstractmethod def data_directory(self) -> str: """ The path to the data directory of the cluster. """ @abc.abstractmethod def init(self, initdb : "path to the initdb to use" = None, user : "name of the cluster's superuser" = None, password : "superuser's password" = None, encoding : "the encoding to use for the cluster" = None, locale : "the locale to use for the cluster" = None, collate : "the collation to use for the cluster" = None, ctype : "the ctype to use for the cluster" = None, monetary : "the monetary to use for the cluster" = None, numeric : "the numeric to use for the cluster" = None, time : "the time to use for the cluster" = None, text_search_config : "default text search configuration" = None, xlogdir : "location for the transaction log directory" = None, ): """ Create the cluster at the `data_directory` associated with the Cluster instance. """ @abc.abstractmethod def drop(self): """ Kill the server and completely remove the data directory. """ @abc.abstractmethod def start(self): """ Start the cluster. """ @abc.abstractmethod def stop(self): """ Signal the server to shutdown. """ @abc.abstractmethod def kill(self): """ Kill the server. """ @abc.abstractmethod def restart(self): """ Restart the cluster. """ @abc.abstractmethod def wait_until_started(self, timeout : "maximum time to wait" = 10 ): """ After the start() method is ran, the database may not be ready for use. This method provides a mechanism to block until the cluster is ready for use. If the `timeout` is reached, the method *must* throw a `postgresql.exceptions.ClusterTimeoutError`. 
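
		An illustrative pattern (the `cluster` name is assumed):

		>>> cluster.start()
		>>> cluster.wait_until_started(timeout = 30)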
""" @abc.abstractmethod def wait_until_stopped(self, timeout : "maximum time to wait" = 10 ): """ After the stop() method is ran, the database may still be running. This method provides a mechanism to block until the cluster is completely shutdown. If the `timeout` is reached, the method *must* throw a `postgresql.exceptions.ClusterTimeoutError`. """ @property @abc.abstractmethod def settings(self): """ A `Settings` interface to the ``postgresql.conf`` file associated with the cluster. """ @abc.abstractmethod def __enter__(self): """ Start the cluster if it's not already running, and wait for it to be readied. """ @abc.abstractmethod def __exit__(self, exc, val, tb): """ Stop the cluster and wait for it to shutdown *iff* it was started by the corresponding enter. """ __docformat__ = 'reStructuredText' if __name__ == '__main__': help(__package__ + '.api') ## # vim: ts=3:sw=3:noet: fe-1.1.0/postgresql/bin/000077500000000000000000000000001203372773200150645ustar00rootroot00000000000000fe-1.1.0/postgresql/bin/__init__.py000066400000000000000000000002661203372773200172010ustar00rootroot00000000000000""" Console-script collection package. Contents: pg_python Python console with a PostgreSQL connection bound to `db`. pg_dotconf Modify a PostgreSQL configuration file. """ fe-1.1.0/postgresql/bin/pg_dotconf.py000066400000000000000000000033531203372773200175640ustar00rootroot00000000000000#!/usr/bin/env python import sys import os from optparse import OptionParser from .. import configfile from .. import __version__ __all__ = ['command'] def command(args): """ pg_dotconf script entry point. """ op = OptionParser( "%prog [--stdout] [-f settings] postgresql.conf ([param=val]|[param])*", version = __version__ ) op.add_option( '-f', '--file', dest = 'settings', help = 'A file of settings to *apply* to the given "postgresql.conf"', default = [], action = 'append', ) op.add_option( '--stdout', dest = 'stdout', help = 'Redirect the product to standard output instead of writing back to the "postgresql.conf" file', action = 'store_true', default = False ) co, ca = op.parse_args(args[1:]) if not ca: return 0 settings = {} for sfp in co.settings: with open(sfp) as sf: for line in sf: pl = configfile.parse_line(line) if pl is not None: if comment not in line[pl[0].start]: settings[line[pl[0]]] = unquote(line[pl[1]]) prev = None for p in ca[1:]: if '=' not in p: k = p v = None else: k, v = p.split('=', 1) k = k.strip() if not k: sys.stderr.write("ERROR: invalid setting, %r after %r%s" %( p, prev, os.linesep )) sys.stderr.write( "HINT: Settings must take the form 'setting=value' " \ "or 'setting_name_to_comment'. Settings must also be received " \ "as a single argument." + os.linesep ) sys.exit(1) prev = p settings[k] = v fp = ca[0] with open(fp, 'r') as fr: lines = configfile.alter_config(settings, fr) if co.stdout or fp == '/dev/stdin': for l in lines: sys.stdout.write(l) else: with open(fp, 'w') as fw: for l in lines: fw.write(l) return 0 if __name__ == '__main__': sys.exit(command(sys.argv)) fe-1.1.0/postgresql/bin/pg_python.py000066400000000000000000000066741203372773200174620ustar00rootroot00000000000000## # .bin.pg_python - Python console with a connection. ## """ Python command with a PG-API connection(``db``). """ import os import sys import re import code import optparse import contextlib from .. import clientparameters from ..python import command as pycmd from .. import project from ..driver import default as pg_driver from .. import exceptions as pg_exc from .. import sys as pg_sys from .. 
import lib as pg_lib

pq_trace = optparse.make_option('--pq-trace',
	dest = 'pq_trace',
	help = 'trace PQ protocol transmissions',
	default = None,
)

default_options = [
	pq_trace,
	clientparameters.option_lib,
	clientparameters.option_libpath,
] + pycmd.default_optparse_options

def command(argv = sys.argv):
	p = clientparameters.DefaultParser(
		"%prog [connection options] [script] ...",
		version = project.version,
		option_list = default_options,
	)
	p.disable_interspersed_args()
	co, ca = p.parse_args(argv[1:])
	rv = 1

	# Resolve the category.
	pg_sys.libpath.insert(0, os.path.curdir)
	pg_sys.libpath.extend(co.libpath or [])
	if co.lib:
		cat = pg_lib.Category(*map(pg_lib.load, co.lib))
	else:
		cat = None

	trace_file = None
	if co.pq_trace is not None:
		trace_file = open(co.pq_trace, 'a')
	try:
		need_prompt = False
		cond = None
		connector = None
		connection = None
		while connection is None:
			try:
				cond = clientparameters.collect(parsed_options = co, prompt_title = None)
				if need_prompt:
					# authspec error thrown last time, so force prompt.
					cond['prompt_password'] = True
				try:
					clientparameters.resolve_password(cond, prompt_title = 'pg_python')
				except EOFError:
					raise SystemExit(1)
				connector = pg_driver.fit(category = cat, **cond)
				connection = connector()
				if trace_file is not None:
					connection.tracer = trace_file.write
				connection.connect()
			except pg_exc.ClientCannotConnectError as err:
				for att in connection.failures:
					exc = att.error
					if isinstance(exc, pg_exc.AuthenticationSpecificationError):
						sys.stderr.write(os.linesep + exc.message + (os.linesep*2))
						# keep prompting the user
						need_prompt = True
						connection = None
						break
				else:
					# no invalid password failures..
					raise

		pythonexec = pycmd.Execution(ca,
			context = getattr(co, 'python_context', None),
			loader = getattr(co, 'python_main', None),
		)
		builtin_overload = {
			# New built-ins
			'connector' : connector,
			'db' : connection,
			'do' : connection.do,
			'prepare' : connection.prepare,
			'sqlexec' : connection.execute,
			'settings' : connection.settings,
			'proc' : connection.proc,
			'xact' : connection.xact,
		}
		if not isinstance(__builtins__, dict):
			builtins_d = __builtins__.__dict__
		else:
			builtins_d = __builtins__
		restore = {k : builtins_d.get(k) for k in builtin_overload}
		builtins_d.update(builtin_overload)
		try:
			with connection:
				rv = pythonexec(
					context = pycmd.postmortem(os.environ.get('PYTHON_POSTMORTEM'))
				)
				exc = getattr(sys, 'last_type', None)
				if rv and exc and not issubclass(exc, Exception):
					# Don't try to close it if it wasn't an Exception.
					del connection.pq.socket
		finally:
			# restore __builtins__
			builtins_d.update(restore)
			for k, v in restore.items():
				if v is None:
					# The name did not exist before the overload; remove it.
					del builtins_d[k]
		if trace_file is not None:
			trace_file.close()
	except:
		pg_sys.libpath.remove(os.path.curdir)
		raise
	return rv

if __name__ == '__main__':
	sys.exit(command(sys.argv))
##
# vim: ts=3:sw=3:noet:
fe-1.1.0/postgresql/clientparameters.py
##
# .clientparameters
##
"""
Collect client connection parameters from various sources.

This module provides functions for collecting client parameters from various
sources such as user relative defaults, environment variables, and even
command line options.

There are two primary data-structures that this module deals with: normalized
parameters and denormalized parameters.

Normalized parameters is a proper mapping object, dictionary, consisting of
the parameters used to apply to a connection creation interface. The
high-level interface, ``collect``, returns normalized parameters.
Denormalized parameters is a sequence or iterable of key-value pairs. However, the key is always a tuple whose components make up the "key-path". This is used to support sub-dictionaries like settings:: >>> normal_params = { 'user' : 'jwp', 'host' : 'localhost', 'settings' : {'default_statistics_target' : 200, 'search_path' : 'home,public'} } Denormalized parameters are used to simplify the overriding of past parameters. For this to work with dictionaries in a general fashion, dictionary objects would need a "deep update" method. """ import sys import os import configparser import optparse from itertools import chain from functools import partial from . import iri as pg_iri from . import pgpassfile as pg_pass from . exceptions import Error class ClientParameterError(Error): code = '-*000' source = '.clientparameters' class ServiceDoesNotExistError(ClientParameterError): code = '-*srv' try: from getpass import getuser, getpass except ImportError: getpass = raw_input def getuser(): return 'postgres' default_host = 'localhost' default_port = 5432 pg_service_envvar = 'PGSERVICE' pg_service_file_envvar = 'PGSERVICEFILE' pg_sysconfdir_envvar = 'PGSYSCONFDIR' pg_service_filename = 'pg_service.conf' pg_service_user_filename = '.pg_service.conf' # posix pg_home_passfile = '.pgpass' pg_home_directory = '.postgresql' # win32 pg_appdata_directory = 'postgresql' pg_appdata_passfile = 'pgpass.conf' # In order to support pg_service.conf, it is # necessary to identify driver parameters, so # that database configuration parameters can # be placed in settings. pg_service_driver_parameters = set([ 'user', 'host', 'database', 'port', 'password', 'sslcrtfile', 'sslkeyfile', 'sslrootcrtfile', 'sslrootkeyfile', 'sslmode', 'server_encoding', 'connect_timeout', ]) # environment variables that will be in the parameters' "settings" dictionary. default_envvar_settings_map = { 'TZ' : 'timezone', 'DATESTYLE' : 'datestyle', 'CLIENTENCODING' : 'client_encoding', 'GEQO' : 'geqo', 'OPTIONS' : 'options', } # Environment variables that require no transformation. default_envvar_map = { 'USER' : 'user', 'DATABASE' : 'database', 'HOST' : 'host', 'PORT' : 'port', 'PASSWORD' : 'password', 'SSLMODE' : 'sslmode', 'SSLKEY' : 'sslkey', 'CONNECT_TIMEOUT' : 'connect_timeout', 'REALM' : 'kerberos4_realm', 'KRBSRVNAME' : 'kerberos5_service', # Extensions #'ROLE' : 'role', # SET ROLE $PGROLE # This keyword *should* never make it to a connect() function # as `resolve_password` should be called to fill in the # parameter accordingly. 'PASSFILE' : 'pgpassfile', } def defaults(environ = os.environ): """ Produce the defaults based on the existing configuration. """ user = getuser() or 'postgres' userdir = os.path.expanduser('~' + user) or '/dev/null' pgdata = os.path.join(userdir, pg_home_directory) yield ('user',), getuser() yield ('host',), default_host yield ('port',), default_port # If appdata is available, override the pgdata and pgpassfile # configuration settings. 
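	# (%APPDATA%/postgresql is the conventional per-user location on Windows;
	# elsewhere, ~/.pgpass holds the passfile and ~/.postgresql the SSL files.)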
if sys.platform == 'win32': appdata = environ.get('APPDATA') if appdata: pgdata = os.path.join(appdata, pg_appdata_directory) pgpassfile = os.path.join(pgdata, pg_appdata_passfile) else: pgpassfile = os.path.join(userdir, pg_home_passfile) for k, v in ( ('sslcrtfile', os.path.join(pgdata, 'postgresql.crt')), ('sslkeyfile', os.path.join(pgdata, 'postgresql.key')), ('sslrootcrtfile', os.path.join(pgdata, 'root.crt')), ('sslrootcrlfile', os.path.join(pgdata, 'root.crl')), ('pgpassfile', pgpassfile), ): if os.path.exists(v): yield (k,), v def envvars(environ = os.environ, modifier : "environment variable key modifier" = 'PG'.__add__): """ Create a clientparams dictionary from the given environment variables. PGUSER -> user PGDATABASE -> database PGHOST -> host PGHOSTADDR -> host (overrides PGHOST) PGPORT -> port PGPASSWORD -> password PGPASSFILE -> pgpassfile PGSSLMODE -> sslmode PGREQUIRESSL gets rewritten into "sslmode = 'require'". PGREALM -> kerberos4_realm PGKRBSVRNAME -> kerberos5_service PGSSLKEY -> sslkey PGTZ -> settings['timezone'] PGDATESTYLE -> settings['datestyle'] PGCLIENTENCODING -> settings['client_encoding'] PGGEQO -> settings['geqo'] The 'PG' prefix can be customized via the `modifier` argument. However, PGSYSCONFDIR will not respect any such change as it's not a client parameter itself. """ hostaddr = modifier('HOSTADDR') reqssl = modifier('REQUIRESSL') if reqssl in environ: if environ[reqssl].strip() == '1': yield ('sslmode',), ('require', reqssl + '=1') for k, v in default_envvar_map.items(): k = modifier(k) if k in environ: yield ((v,), environ[k]) if hostaddr in environ: yield (('host',), environ[hostaddr]) envvar_settings_map = (( (modifier(k), v) for k,v in default_envvar_settings_map.items() )) settings = [ (('settings', v,), environ[k]) for k, v in envvar_settings_map if k in environ ] # PGSYSCONFDIR based if pg_sysconfdir_envvar in environ: yield ('config-pg_sysconfdir', environ[pg_sysconfdir_envvar]) # PGSERVICEFILE based if pg_service_file_envvar in environ: yield ('config-pg_service_file', environ[pg_service_file_envvar]) service = modifier('SERVICE') if service in environ: yield ('pg_service', environ[service]) ## # optparse options ## option_datadir = optparse.make_option('-D', '--datadir', help = 'location of the database storage area', default = None, dest = 'datadir', ) option_in_xact = optparse.make_option('-1', '--with-transaction', dest = 'in_xact', action = 'store_true', help = 'run operation with a transaction block', ) def append_db_client_parameters(option, opt_str, value, parser): # for options without arguments, None is passed in. 
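	# Record flag-style options (e.g. -W/--password) as True so that the
	# denormalized parameter list always carries a concrete value.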
value = True if value is None else value parser.values.db_client_parameters.append( ((option.dest,), value) ) make_option = partial( optparse.make_option, action = 'callback', callback = append_db_client_parameters ) option_user = make_option('-U', '--username', dest = 'user', type = 'str', help = 'user name to connect as', ) option_database = make_option('-d', '--database', type = 'str', help = "database's name", dest = 'database', ) option_password = make_option('-W', '--password', dest = 'prompt_password', help = 'prompt for password', ) option_host = make_option('-h', '--host', help = 'database server host', type = 'str', dest = 'host', ) option_port = make_option('-p', '--port', help = 'database server port', type = 'str', dest = 'port', ) option_unix = make_option('--unix', help = 'path to filesystem socket', type = 'str', dest = 'unix', ) def append_settings(option, opt_str, value, parser): 'split the string into a (key,value) pair tuple' kv = value.split('=', 1) if len(kv) != 2: raise OptionValueError("invalid setting argument, %r" %(value,)) parser.values.db_client_parameters.append( ((option.dest, kv[0]), kv[1]) ) option_settings = make_option('-s', '--setting', dest = 'settings', help = 'run-time parameters to set upon connecting', callback = append_settings, type = 'str', ) option_sslmode = make_option('--ssl-mode', dest = 'sslmode', help = 'SSL requirement for connectivity: require, prefer, allow, disable', choices = ('require','prefer','allow','disable'), type = 'choice', ) def append_db_client_x_parameters(option, opt_str, value, parser): parser.values.db_client_parameters.append((option.dest, value)) make_x_option = partial(make_option, callback = append_db_client_x_parameters) option_iri = make_x_option('-I', '--iri', help = 'database locator string [pq://user:password@host:port/database?[driver_param]=value&setting=value]', type = 'str', dest = 'pq_iri', ) option_lib = optparse.make_option('-l', help = 'bind the library found in postgresql.sys.libpath to the connection', type = 'str', dest = 'lib', action = 'append' ) option_libpath = optparse.make_option('-L', help = 'append the library path', type = 'str', dest = 'libpath', action = 'append' ) # PostgreSQL Standard Options standard_optparse_options = ( option_host, option_port, option_user, option_password, option_database, ) class StandardParser(optparse.OptionParser): """ Option parser limited to the basic -U, -h, -p, -W, and -D options. This parser subclass is necessary for two reasons: 1. _add_help_option override to not conflict with -h 2. Initialize the db_client_parameters on the parser's values. See the DefaultParser for more fun. """ standard_option_list = standard_optparse_options def get_default_values(self, *args, **kw): v = super().get_default_values(*args, **kw) v.db_client_parameters = [] return v def _add_help_option(self): # Only allow long --help so that it will not conflict with -h(host) self.add_option("--help", action = "help", help = "show this help message and exit", ) # Extended Options default_optparse_options = [ option_unix, option_sslmode, option_settings, # Complex Options option_iri, ] default_optparse_options.extend(standard_optparse_options) class DefaultParser(StandardParser): """ Parser that includes a variety of connectivity options. 
(IRI, sslmode, settings) """ standard_option_list = default_optparse_options def resolve_password( parameters : "a fully normalized set of client parameters(dict)", getpass = getpass, prompt_title = '', ): """ Given a parameters dictionary, resolve the 'password' key. If `prompt_password` is `True`. If sys.stdin is a TTY, use `getpass` to prompt the user. Otherwise, read a single line from sys.stdin. delete 'prompt_password' from the dictionary. Otherwise. If the 'password' key is `None`, attempt to resolve the password using the 'pgpassfile' key. Finally, remove the pgpassfile key as the password has been resolved for the given parameters. """ prompt_for_password = parameters.pop('prompt_password', False) pgpassfile = parameters.pop('pgpassfile', None) prompt_title = parameters.pop('prompt_title', None) if prompt_for_password is True: # it's a prompt if sys.stdin.isatty(): prompt = prompt_title or parameters.pop('prompt_title', '') prompt += '[' + pg_iri.serialize(parameters, obscure_password = True) + ']' parameters['password'] = getpass("Password for " + prompt +": ") else: # getpass will throw an exception if it's not a tty, # so just take the next line. pw = sys.stdin.readline() # try to clean it up.. if pw.endswith(os.linesep): pw = pw[:len(pw)-len(os.linesep)] parameters['password'] = pw else: if parameters.get('password') is None: # No password? Look in the pgpassfile. if pgpassfile is not None: parameters['password'] = pg_pass.lookup_pgpass(parameters, pgpassfile) # Don't need the pgpassfile parameter anymore as the password # has been resolved. def x_settings(sdict, config): d=dict(sdict) for (k,v) in d.items(): yield (('settings', k), v) def denormalize_parameters(p): """ Given a fully normalized parameters dictionary: {'host': 'localhost', 'settings' : {'timezone':'utc'}} Denormalize it: [(('host',), 'localhost'), (('settings','timezone'), 'utc')] """ for k,v in p.items(): if k == 'settings': for sk, sv in dict(v).items(): yield (('settings', sk), sv) else: yield ((k,), v) def x_pq_iri(iri, config): return denormalize_parameters(pg_iri.parse(iri)) # Lookup service data using the `service_name` # Be sure to map 'dbname' to 'database'. def x_pg_service(service_name, config): service_files = [] f = config.get('pg_service_file') if f is not None: # service file override service_files.append(f) else: # override is not specified, use the user service file home = os.path.expanduser('~' + getuser()) service_files.append(os.path.join(home, pg_service_user_filename)) # global service file is checked next. sysconfdir = config.get('pg_sysconfdir') if sysconfdir: sf = config.get('pg_service_filename', pg_service_filename) f = os.path.join(sysconfdir, sf) # existence will be checked later. service_files.append(f) for sf in service_files: if not os.path.exists(sf): continue cp = configparser.RawConfigParser() cp.read(sf) try: s = cp.items(service_name) except configparser.NoSectionError: continue for (k, v) in s: k = k.lower() if k == 'ldap': yield ('pg_ldap', ':'.join((k, v))) elif k == 'pg_service': # ignore pass elif k == 'hostaddr': # XXX: should yield ipv as well? yield (('host',), v) elif k == 'dbname': yield (('database',), v) elif k not in pg_service_driver_parameters: # it's a GUC. yield (('settings', k), v) else: yield ((k,), v) else: break else: # iterator exhausted; service not found if sum([os.path.exists(x) for x in service_files]): details = { 'context': ', '.join(service_files), } else: details = { 'hint': "No service files could be found." 
} raise ServiceDoesNotExistError( 'cannot find service named "{0}"'.format(service_name), details = details ) def x_pg_ldap(ldap_url, config): raise NotImplementedError("cannot resolve ldap URLs: " + str(ldap_url)) default_x_callbacks = { 'settings' : x_settings, 'pq_iri' : x_pq_iri, 'pg_service' : x_pg_service, 'pg_ldap' : x_pg_ldap, } def extrapolate(iter, config = None, callbacks = default_x_callbacks): """ Given an iterable of standardized settings, [((path0, path1, ..., pathN), value)] Process any callbacks. """ config = config or {} for item in iter: k = item[0] if isinstance(k, str): if k.startswith('config-'): config[k[len('config-'):]] = item[1] else: cb = callbacks.get(k) if cb: for x in extrapolate( cb(item[1], config), config = config, callbacks = callbacks ): yield x else: pass else: yield item def normalize_parameter(kv): """ Translate a parameter into standard form. """ (k, v) = kv if k[0] == 'requiressl' and v in ('1', True): k[0] = 'sslmode' v = 'require' elif k[0] == 'dbname': k[0] = 'database' elif k[0] == 'sslmode': v = v.lower() return (tuple(k),v) def normalize(iter): """ Normally takes the output of `extrapolate` and makes a dictionary suitable for applying to a connector. """ rd = {} for (k, v) in iter: sd = rd for sk in k[:len(k)-1]: sd = sd.setdefault(sk, {}) sd[k[-1]] = v return rd def resolve_pg_service_file( environ = os.environ, default_pg_sysconfdir = None, default_pg_service_filename = pg_service_filename ): sysconfdir = environ.get(pg_sysconfdir_envvar, default_pg_sysconfdir) if sysconfdir: return os.path.join(sysconfdir, default_pg_service_filename) return None def collect( parsed_options : "options parsed using the `DefaultParser`" = None, no_defaults : "Don't build-out defaults like 'user' from getpass.getuser()" = False, environ : "environment variables to use, `None` to disable" = os.environ, environ_prefix : "prefix to use for collecting environment variables" = 'PG', default_pg_sysconfdir : "default 'PGSYSCONFDIR' to use" = None, pg_service_file : "the pg-service file to actually use" = None, prompt_title : "additional title to use if a prompt request is made" = '', parameters : "base-client parameters to use(applied after defaults)" = (), ): """ Build a normalized client parameters dictionary for use with a connection construction interface. 
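
	An illustrative sketch of typical use; the option values here are
	hypothetical::

		>>> p = DefaultParser()
		>>> co, ca = p.parse_args(['-h', 'localhost', '-U', 'jwp'])
		>>> params = collect(parsed_options = co, prompt_title = None)
		>>> params['host']
		'localhost'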
""" d_parameters = [] d_parameters.append([('config-environ', environ)]) if default_pg_sysconfdir is not None: d_parameters.append([ ('config-pg_sysconfdir', default_pg_sysconfdir) ]) if pg_service_file is not None: d_parameters.append([ ('config-pg_service_file', pg_service_file) ]) if not no_defaults: d_parameters.append(defaults(environ = environ)) if parameters: d_parameters.append(denormalize_parameters(dict(parameters))) if environ is not None: d_parameters.append(envvars( environ = environ, modifier = environ_prefix.__add__ )) cop = getattr(parsed_options, 'db_client_parameters', None) if cop: d_parameters.append(cop) cpd = normalize(extrapolate(chain(*d_parameters))) if prompt_title is not None: resolve_password(cpd, prompt_title = prompt_title) return cpd if __name__ == '__main__': import pprint p = DefaultParser( description = "print the clientparams dictionary for the environment" ) (co, ca) = p.parse_args() r = collect(parsed_options = co, prompt_title = 'custom_prompt_title') pprint.pprint(r) fe-1.1.0/postgresql/cluster.py000066400000000000000000000405151203372773200163540ustar00rootroot00000000000000## # .cluster - PostgreSQL cluster management ## """ Create, control, and destroy PostgreSQL clusters. postgresql.cluster provides a programmer's interface to controlling a PostgreSQL cluster. It provides direct access to proper signalling interfaces. """ import sys import os import errno import time import subprocess as sp from tempfile import NamedTemporaryFile from . import api as pg_api from . import configfile from . import installation as pg_inn from . import exceptions as pg_exc from . import driver as pg_driver from .encodings.aliases import get_python_name from .python.os import close_fds if sys.platform in ('win32', 'win64'): from .port import signal1_msw as signal pg_kill = signal.kill def namedtemp(encoding): return NamedTemporaryFile(delete = False, mode = 'w', encoding=encoding) else: import signal pg_kill = os.kill def namedtemp(encoding): return NamedTemporaryFile(mode = 'w', encoding=encoding) class ClusterError(pg_exc.Error): """ General cluster error. """ code = '-C000' source = 'CLUSTER' class ClusterInitializationError(ClusterError): "General cluster initialization failure" code = '-Cini' class InitDBError(ClusterInitializationError): "A non-zero result was returned by the initdb command" code = '-Cidb' class ClusterStartupError(ClusterError): "Cluster startup failed" code = '-Cbot' class ClusterNotRunningError(ClusterError): "Cluster is not running" code = '-Cdwn' class ClusterTimeoutError(ClusterError): "Cluster operation timed out" code = '-Cout' class ClusterWarning(pg_exc.Warning): "Warning issued by cluster operations" code = '-Cwrn' source = 'CLUSTER' DEFAULT_CLUSTER_ENCODING = 'utf-8' DEFAULT_CONFIG_FILENAME = 'postgresql.conf' DEFAULT_HBA_FILENAME = 'pg_hba.conf' DEFAULT_PID_FILENAME = 'postmaster.pid' initdb_option_map = { 'encoding' : '-E', 'authentication' : '-A', 'user' : '-U', # pwprompt is not supported. # interactive use should be implemented by the application # calling Cluster.init() } class Cluster(pg_api.Cluster): """ Interface to a PostgreSQL cluster. Provides mechanisms to start, stop, restart, kill, drop, and initalize a cluster(data directory). Cluster does not strive to be consistent with ``pg_ctl``. This is considered to be a base class for managing a cluster, and is intended to be extended to accommodate for a particular purpose. 
""" driver = pg_driver.default installation = None data_directory = None DEFAULT_CLUSTER_ENCODING = DEFAULT_CLUSTER_ENCODING DEFAULT_CONFIG_FILENAME = DEFAULT_CONFIG_FILENAME DEFAULT_PID_FILENAME = DEFAULT_PID_FILENAME DEFAULT_HBA_FILENAME = DEFAULT_HBA_FILENAME @property def state(self): if self.running(): return 'running' if not os.path.exists(self.data_directory): return 'void' return 'stopped' def _e_metas(self): state = self.state yield (None, '[' + state + ']') if state == 'running': yield ('pid', self.state) @property def daemon_path(self): """ Path to the executable to use to startup the cluster. """ return self.installation.postmaster or self.installation.postgres def get_pid_from_file(self): """ The current pid from the postmaster.pid file. """ try: path = os.path.join(self.data_directory, self.DEFAULT_PID_FILENAME) with open(path) as f: return int(f.readline()) except IOError as e: if e.errno in (errno.EIO, errno.ENOENT): return None @property def pid(self): """ If we have the subprocess, use the pid on the object. """ pid = self.get_pid_from_file() if pid is None: d = self.daemon_process if d is not None: return d.pid return pid @property def settings(self): if not hasattr(self, '_settings'): self._settings = configfile.ConfigFile(self.pgsql_dot_conf) return self._settings @property def hba_file(self, join = os.path.join): """ The path to the HBA file of the cluster. """ return self.settings.get( 'hba_file', join(self.data_directory, self.DEFAULT_HBA_FILENAME) ) def __init__(self, installation : "installation object", data_directory : "path to the data directory", ): self.installation = installation self.data_directory = os.path.abspath(data_directory) self.pgsql_dot_conf = os.path.join( self.data_directory, self.DEFAULT_CONFIG_FILENAME ) self.daemon_process = None self.daemon_command = None def __repr__(self, format = "{mod}.{name}({ins!r}, {dir!r})".format): return format( type(self).__module__, type(self).__name__, self.installation, self.data_directory, ) def __enter__(self): """ Start the cluster and wait for it to startup. """ self.start() self.wait_until_started() return self def __exit__(self, typ, val, tb): """ Stop the cluster and wait for it to shutdown. """ self.stop() self.wait_until_stopped() def init(self, password : \ "Password to assign to the " \ "cluster's superuser(`user` keyword)." = None, **kw ): """ Create the cluster at the given `data_directory` using the provided keyword parameters as options to the command. `command_option_map` provides the mapping of keyword arguments to command options. """ initdb = self.installation.initdb if initdb is None: initdb = (self.installation.pg_ctl, 'initdb',) else: initdb = (initdb,) if None in initdb: raise ClusterInitializationError( "unable to find executable for cluster initialization", details = { 'detail' : "The installation does not have 'initdb' or 'pg_ctl'.", }, creator = self ) # Transform keyword options into command options for the executable. # A default is used rather than looking at the environment to, well, # avoid looking at the environment. 
kw.setdefault('encoding', self.DEFAULT_CLUSTER_ENCODING) opts = [] for x in kw: if x in ('logfile', 'extra_arguments'): continue if x not in initdb_option_map: raise TypeError("got an unexpected keyword argument %r" %(x,)) opts.append(initdb_option_map[x]) opts.append(kw[x]) logfile = kw.get('logfile') or sp.PIPE extra_args = tuple([ str(x) for x in kw.get('extra_arguments', ()) ]) supw_file = () supw_tmp = None p = None try: if password is not None: # got a superuserpass, store it in a tempfile for initdb supw_tmp = namedtemp(encoding = get_python_name(kw['encoding'])) supw_tmp.write(password) supw_tmp.flush() supw_file = ('--pwfile=' + supw_tmp.name,) cmd = initdb + ('-D', self.data_directory) \ + tuple(opts) \ + supw_file \ + extra_args p = sp.Popen( cmd, close_fds = close_fds, bufsize = 1024 * 5, # not expecting this to ever be filled. stdin = sp.PIPE, stdout = logfile, # stderr is used to identify a reasonable error message. stderr = sp.PIPE, ) # stdin is not used; it is not desirable for initdb to be attached. p.stdin.close() while True: try: rc = p.wait() break except OSError as e: if e.errno != errno.EINTR: raise finally: if p.stdout is not None: p.stdout.close() if rc != 0: # initdb returned non-zero, pickup stderr and attach to exception. r = p.stderr.read().strip() try: msg = r.decode('utf-8') except UnicodeDecodeError: # split up the lines, and use rep. msg = os.linesep.join([ repr(x)[2:-1] for x in r.splitlines() ]) raise InitDBError( "initdb exited with non-zero status", details = { 'command': cmd, 'stderr': msg, 'stdout': msg, }, creator = self ) finally: if p is not None: for x in (p.stderr, p.stdin, p.stdout): if x is not None: x.close() if supw_tmp is not None: n = supw_tmp.name supw_tmp.close() # XXX: win32 compensation. if os.path.exists(n): os.unlink(n) def drop(self): """ Stop the cluster and remove it from the filesystem """ if self.running(): self.shutdown() try: self.wait_until_stopped() except ClusterTimeoutError: self.kill() try: self.wait_until_stopped() except ClusterTimeoutError: ClusterWarning( 'cluster failed to shutdown after kill', details = {'hint' : 'Shared memory may have been leaked.'}, creator = self ).emit() # Really, using rm -rf would be the best, but use this for portability. for root, dirs, files in os.walk(self.data_directory, topdown = False): for name in files: os.remove(os.path.join(root, name)) for name in dirs: os.rmdir(os.path.join(root, name)) os.rmdir(self.data_directory) def start(self, logfile : "Where to send stderr" = None, settings : "Mapping of runtime parameters" = None ): """ Start the cluster. """ if self.running(): return cmd = (self.daemon_path, '-D', self.data_directory) if settings is not None: for k,v in dict(settings).items(): cmd.append('--{k}={v}'.format(k=k,v=v)) p = sp.Popen( cmd, close_fds = close_fds, bufsize = 1024, # send everything to logfile stdout = sp.PIPE if logfile is None else logfile, stderr = sp.STDOUT, stdin = sp.PIPE, ) if logfile is None: p.stdout.close() p.stdin.close() self.daemon_process = p self.daemon_command = cmd def restart(self, logfile = None, settings = None, timeout = 10): """ Restart the cluster gracefully. This provides a higher level interface to stopping then starting the cluster. It will perform the wait operations and block until the restart is complete. If waiting is not desired, .start() and .stop() should be used directly. 
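
		For example, a blocking restart with a longer grace period::

			>>> c.restart(timeout = 30)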
""" if self.running(): self.stop() self.wait_until_stopped(timeout = timeout) if self.running(): raise ClusterError( "failed to shutdown cluster", creator = self ) self.start(logfile = logfile, settings = settings) self.wait_until_started(timeout = timeout) def reload(self): """ Signal the cluster to reload its configuration file. """ pid = self.pid if pid is not None: try: pg_kill(pid, signal.SIGHUP) except OSError as e: if e.errno != errno.ESRCH: raise def stop(self): """ Stop the cluster gracefully waiting for clients to disconnect(SIGTERM). """ pid = self.pid if pid is not None: try: pg_kill(pid, signal.SIGTERM) except OSError as e: if e.errno != errno.ESRCH: raise def shutdown(self): """ Shutdown the cluster as soon as possible, disconnecting clients. """ pid = self.pid if pid is not None: try: pg_kill(pid, signal.SIGINT) except OSError as e: if e.errno != errno.ESRCH: raise def kill(self): """ Stop the cluster immediately(SIGKILL). Does *not* wait for shutdown. """ pid = self.pid if pid is not None: try: pg_kill(pid, signal.SIGKILL) except OSError as e: if e.errno != errno.ESRCH: raise # already dead, so it would seem. def initialized(self): """ Whether or not the data directory *appears* to be a valid cluster. """ if os.path.isdir(self.data_directory) and \ os.path.exists(self.pgsql_dot_conf) and \ os.path.isdir(os.path.join(self.data_directory, 'base')): return True return False def running(self): """ Whether or not the postmaster is running. This does *not* mean the cluster is accepting connections. """ if self.daemon_process is not None: r = self.daemon_process.poll() if r is not None: pid = self.get_pid_from_file() if pid is not None: # daemon process does not exist, but there's a pidfile. self.daemon_process = None return self.running() return False else: return True else: pid = self.get_pid_from_file() if pid is None: return False try: pg_kill(pid, signal.SIG_DFL) except OSError as e: if e.errno != errno.ESRCH: raise return False return True def connector(self, **kw): """ Create a postgresql.driver connector based on the given keywords and listen_addresses and port configuration in settings. """ host, port = self.address() return self.driver.fit( host = host or 'localhost', port = port or 5432, **kw ) def connection(self, **kw): """ Create a connection object to the cluster, but do not connect. """ return self.connector(**kw)() def connect(self, **kw): """ Create an established connection from the connector. Cluster must be running. """ if not self.running(): raise ClusterNotRunningError( "cannot connect if cluster is not running", creator = self ) x = self.connection(**kw) x.connect() return x def address(self): """ Get the host-port pair from the configuration. """ d = self.settings.getset(( 'listen_addresses', 'port', )) if d.get('listen_addresses') is not None: # Prefer localhost over other addresses. # More likely to get a successful connection. addrs = d.get('listen_addresses').lower().split(',') if 'localhost' in addrs or '*' in addrs: host = 'localhost' elif '127.0.0.1' in addrs: host = '127.0.0.1' elif '::1' in addrs: host = '::1' else: host = addrs[0] else: host = None return (host, d.get('port')) def ready_for_connections(self): """ If the daemon is running, and is not in startup mode. This only works for clusters configured for TCP/IP connections. 
""" if not self.running(): return False e = None host, port = self.address() connection = self.driver.fit( user = ' -*- ping -*- ', host = host, port = port, database = 'template1', sslmode = 'disable', )() try: connection.connect() except pg_exc.ClientCannotConnectError as err: for attempt in err.database.failures: x = attempt.error if self.installation.version_info[:2] < (8,1): if isinstance(x, ( pg_exc.UndefinedObjectError, pg_exc.AuthenticationSpecificationError, )): # undefined user.. whatever... return True else: if isinstance(x, pg_exc.AuthenticationSpecificationError): return True # configuration file error. ya, that's probably not going to change. if isinstance(x, (pg_exc.CFError, pg_exc.ProtocolError)): raise x if isinstance(x, pg_exc.ServerNotReadyError): e = x break else: e = err # the else true means we successfully connected with those # credentials... strange, but true.. return e if e is not None else True def wait_until_started(self, timeout : "how long to wait before throwing a timeout exception" = 10, delay : "how long to sleep before re-testing" = 0.05, ): """ After the `start` method is used, this can be ran in order to block until the cluster is ready for use. This method loops until `ready_for_connections` returns `True` in order to make sure that the cluster is actually up. """ start = time.time() checkpoint = start while True: if not self.running(): if self.daemon_process is not None: r = self.daemon_process.returncode if r is not None: raise ClusterStartupError( "postgres daemon terminated", details = { 'RESULT' : r, 'COMMAND' : self.daemon_command, }, creator = self ) else: raise ClusterNotRunningError( "postgres daemon has not been started", creator = self ) r = self.ready_for_connections() checkpoint = time.time() if r is True: break if checkpoint - start >= timeout: # timeout was reached, but raise ServerNotReadyError # to signal to the user that it was *not* due to some unknown # condition, rather it's *still* starting up. if r is not None and isinstance(r, pg_exc.ServerNotReadyError): raise r e = ClusterTimeoutError( 'timeout on startup', creator = self ) if r not in (True,False): raise e from r raise e time.sleep(delay) def wait_until_stopped(self, timeout : "how long to wait before throwing a timeout exception" = 10, delay : "how long to sleep before re-testing" = 0.05 ): """ After the `stop` method is used, this can be ran in order to block until the cluster is shutdown. Additionally, catching `ClusterTimeoutError` exceptions would be a starting point for making decisions about whether or not to issue a kill to the daemon. """ start = time.time() while self.running() is True: # pickup the exit code. if self.daemon_process is not None: self.last_exit_code = self.daemon_process.poll() else: self.last_exit_code = pg_kill(self.get_pid_from_file(), 0) if time.time() - start >= timeout: raise ClusterTimeoutError( 'timeout on shutdown', creator = self, ) time.sleep(delay) ## # vim: ts=3:sw=3:noet: fe-1.1.0/postgresql/configfile.py000066400000000000000000000174001203372773200167750ustar00rootroot00000000000000## # .configfile ## 'PostgreSQL configuration file parser and editor functions.' import sys import os from . import string as pg_str from . 
import api as pg_api quote = "'" comment = '#' def parse_line(line, equality = '=', comment = comment, quote = quote): keyval = line.split(equality, 1) if len(keyval) == 2: key, val = keyval prekey_len = 0 for c in key: if not c.isspace() and c not in comment: break prekey_len += 1 key_len = 0 for c in key[prekey_len:]: if not (c.isalpha() or c.isdigit() or c in '_'): break key_len += 1 # If non-whitespace exists after the key, # it's a complex comment, so just bail out. if key[prekey_len + key_len:].strip(): return preval_len = 0 for c in val: if not c.isspace() or c in '\n\r': break preval_len += 1 inquotes = False escaped = False val_len = 0 for i in range(preval_len, len(val)): c = val[i] if c == quote: if inquotes is False: inquotes = True else: if escaped is False: # Peek ahead to see if it's escaped with another quote escaped = (len(val) > i+1 and val[i+1] == quote) if escaped is False: inquotes = False elif escaped is True: # It *was* an escaped quote. escaped = False elif inquotes is False and (c.isspace() or c in comment): break val_len += 1 return ( # The key slice slice(prekey_len, key_len + prekey_len), # The value slice slice(len(key) + 1 + preval_len, len(key) + 1 + preval_len + val_len) ) def unquote(s, quote = quote): """ Unquote the string `s` if quoted. """ s = s.strip() if not s.startswith(quote): return s return s[1:-1].replace(quote*2, quote) def write_config(map, writer, keys = None): 'A configuration writer that will trample & merely write the settings' if keys is None: keys = map for k in keys: writer('='.join((k, map[k])) + os.linesep) def alter_config( map : "the configuration changes to make", fo : "file object containing configuration lines(Iterable)", keys : "the keys to change; defaults to map.keys()" = None ): 'Alters a configuration file without trampling on the existing structure' if keys is None: keys = list(map.keys()) # Normalize keys and map them back to pkeys = { k.lower().strip() : keys.index(k) for k in keys } lines = [] candidates = {} i = -1 # Process lines in fo for l in fo: i += 1 lines.append(l) pl = parse_line(l) # "Bad" line? fuh-get-duh-bowt-it. if pl is None: continue sk, sv = pl k = l[sk].lower() v = l[sv] # It's a candidate? if k in pkeys: c = candidates.get(k) if c is None: candidates[k] = c = [] c.append((i, sk, sv)) # Simply insert the data somewhere for unfound keys. for k in pkeys: if k not in candidates: key = keys[pkeys[k]] val = map[key] # Can't comment without an uncommented candidate. if val is not None: if not lines[-1].endswith(os.linesep): lines[-1] = lines[-1] + os.linesep lines.append("%s = '%s'" %(key, val.replace("'", "''"))) # Multiple lines may have the key, so make a decision based on the value. for ck in candidates.keys(): to_set_key = keys[pkeys[ck]] to_set_val = map[keys[pkeys[ck]]] if to_set_val is None: # Comment uncommented occurrences. for cl in candidates[ck]: line_num, sk, sv = cl if comment not in lines[line_num][:sk.start]: lines[line_num] = '#' + lines[line_num] else: # Manage occurrences. # w_ is for winner. # Now, a winner is elected for alteration. The winner # is decided based on a two factors: commenting and value. w_score = -1 w_commented = None w_val = None w_cl = None for cl in candidates[ck]: line_num, sk, sv = cl l = lines[line_num] lkey = l[sk] lval = l[sv] commented = (comment in l[:sk.start]) score = \ (not commented and 1 or 0) + \ (unquote(lval) == to_set_val and 2 or 0) # So, if a line is not commented, and has equal # values, then that's the winner. 
If a line is commented, # and has a has equal values, it will succeed over a mere # uncommented value. if score > w_score: if w_commented is False: # It's now a loser, so comment it out if necessary. lines[w_cl[0]] = '#' + lines[w_cl[0]] w_score = score w_commented = commented w_val = lval w_cl = cl elif commented is False: # Loser needs to be commented. lines[line_num] = '#' + l line_num, sk, sv = w_cl l = lines[line_num] if w_commented: bol = '' else: bol = l[:sk.start] post_val = l[sv.stop:] # If there is post-value data, validate that it's commented. if post_val and not post_val.isspace(): stripped_post_val = post_val.lstrip() if not stripped_post_val.startswith(comment): post_val = '%s%s%s' %( # The whitespace before the uncommented visibles post_val[0:len(post_val) - len(stripped_post_val)], # A comment followed by the uncommented visibles comment, stripped_post_val ) # Everything is set as quoted as it's the only safe # way to set something without delving into setting types. lines[line_num] = \ bol + l[sk.start:sv.start] + \ "'%s'" %(to_set_val.replace("'", "''"),) + post_val return lines def read_config(iter, d = None, selector = None): if d is None: d = {} for line in iter: kv = parse_line(line) if kv: key = line[kv[0]] if comment not in line[:kv[0].start] and \ (selector is None or selector(key)): d[key] = unquote(line[kv[1]]) return d class ConfigFile(pg_api.Settings): """ Provides a mapping interface to a configuration file. Every action will cause the file to be wholly read, so using `update` to make multiple changes is desirable. """ _e_factors = ('path',) _e_label = 'CONFIGFILE' def _e_metas(self): yield (None, len(self.keys())) def __init__(self, path, open = open): self.path = path self._open = open self._store = [] self._restore = {} def __repr__(self): return "%s.%s(%r)" %( type(self).__module__, type(self).__name__, self.path ) def _save(self, lines : [str]): with self._open(self.path, 'w') as cf: for l in lines: cf.write(l) def __delitem__(self, k): with self._open(self.path) as cf: lines = alter_config({k : None}, cf) self._save() def __getitem__(self, k): with self._open(self.path) as cfo: return read_config( cfo, selector = k.__eq__ )[k] def __setitem__(self, k, v): self.update({k : v}) def __call__(self, **kw): self._store.insert(0, kw) def __context__(self): return self def __iter__(self): return self.keys() def __len__(self): return len(list(self.keys())) def __enter__(self): res = self.getset(self._store[0].keys()) self.update(self._store[0]) del self._store[0] self._restore.append(res) def __exit__(self, exc, val, tb): self._restored.update(self._restore[-1]) del self._restore[-1] self.update(self._restored) self._restored.clear() return exc is None def get(self, k, alt = None): with self._open(self.path) as cf: return read_config(cf, selector = k.__eq__).get(k, alt) def keys(self): return read_config(self._open(self.path)).keys() def values(self): return read_config(self._open(self.path)).values() def items(self): return read_config(self._open(self.path)).items() def update(self, keyvals): """ Given a dictionary of settings, apply them to the cluster's postgresql.conf. """ with self._open(self.path) as cf: lines = alter_config(keyvals, cf) self._save(lines) def getset(self, keys): """ Get all the settings in the list of keys. Returns a dictionary of those keys. 
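
		For example (the values shown are hypothetical and depend on the
		file's contents; keys absent from the file map to `None`)::

			>>> cf = ConfigFile('postgresql.conf')
			>>> cf.getset(['port', 'listen_addresses'])
			{'port': '5432', 'listen_addresses': 'localhost'}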
""" keys = set(keys) with self._open(self.path) as cfo: cfg = read_config( cfo, selector = keys.__contains__ ) for x in (keys - set(cfg.keys())): cfg[x] = None return cfg ## # vim: ts=3:sw=3:noet: fe-1.1.0/postgresql/copyman.py000066400000000000000000000537411203372773200163460ustar00rootroot00000000000000## # .copyman - COPY manager ## """ Manage complex COPY operations; one-to-many COPY streaming. Primarily this module houses the `CopyManager` class, and the `transfer` function for a high-level interface to using the `CopyManager`. """ import sys from abc import abstractmethod, abstractproperty from collections import Iterator from .python.element import Element, ElementSet from .python.structlib import ulong_unpack, ulong_pack from .protocol.buffer import pq_message_stream from .protocol.element3 import CopyData, CopyDone, Complete, cat_messages from .protocol.xact3 import Complete as xactComplete #: 10KB buffer for COPY messages by default. default_buffer_size = 1024 * 10 class Fault(Exception): pass class ProducerFault(Fault): """ Exception raised when the Producer caused an exception. Normally, Producer faults are fatal. """ def __init__(self, manager): self.manager = manager def __str__(self): return "producer raised exception" class ReceiverFault(Fault): """ Exception raised when Receivers cause an exception. Faults should be trapped if recovery from an exception is possible, or if the failed receiver is optional to the succes of the operation. The 'manager' attribute is the CopyManager that raised the fault. The 'faults' attribute is a dictionary mapping the receiver to the exception instance raised. """ def __init__(self, manager, faults): self.manager = manager self.faults = faults def __str__(self): return "{0} faults occurred".format(len(self.faults)) class CopyFail(Exception): """ Exception thrown by the CopyManager when the COPY operation failed. The 'manager' attribute the CopyManager that raised the CopyFail. The 'reason' attribute is a string indicating why it failed. The 'receiver_faults' attribute is a mapping of receivers to exceptions that were raised on exit. The 'producer_fault' attribute specifies if the producer raise an exception on exit. """ def __init__(self, manager, reason = None, receiver_faults = None, producer_fault = None, ): self.manager = manager self.reason = reason self.receiver_faults = receiver_faults or {} self.producer_fault = producer_fault def __str__(self): return self.reason or 'copy aborted' # The identifier for PQv3 copy data. PROTOCOL_PQv3 = "PQv3" # The identifier for iterables of copy data sequences. # iter([[row1, row2], [row3, row4]]) PROTOCOL_CHUNKS = "CHUNKS" # The protocol identifier for NULL producers and receivers. PROTOCOL_NULL = None class ChunkProtocol(object): __slots__ = ('buffer',) def __init__(self): self.buffer = pq_message_stream() def __call__(self, data): self.buffer.write(bytes(data)) return [x[1] for x in self.buffer.read()] # Null protocol mapping. def EmptyView(arg): return memoryview(b'') def EmptyList(arg): return [] def ReturnNone(arg): return None # Zero-Transformation def NoTransformation(arg): return arg # Copy protocols being at the Python level; *not* wire/serialization format. 
copy_protocol_mappings = { # PQv3 -> Chunks (PROTOCOL_PQv3, PROTOCOL_CHUNKS) : ChunkProtocol, # Chunks -> PQv3 (PROTOCOL_CHUNKS, PROTOCOL_PQv3) : lambda: cat_messages, # Null Producers and Receivers (PROTOCOL_NULL, PROTOCOL_PQv3) : lambda: EmptyView, (PROTOCOL_NULL, PROTOCOL_CHUNKS) : lambda: EmptyList, (PROTOCOL_PQv3, PROTOCOL_NULL) : lambda: ReturnNone, (PROTOCOL_CHUNKS, PROTOCOL_NULL) : lambda: ReturnNone, # Zero Transformations (PROTOCOL_NULL, PROTOCOL_NULL) : lambda: NoTransformation, (PROTOCOL_CHUNKS, PROTOCOL_CHUNKS) : lambda: NoTransformation, (PROTOCOL_PQv3, PROTOCOL_PQv3) : lambda: NoTransformation, } # Used to manage the conversions of COPY data. # Notably, chunks -> PQv3 or PQv3 -> chunks. class CopyTransformer(object): __slots__ = ('current', 'transformers', 'get') def __init__(self, source_protocol, target_protocols): self.current = {} self.transformers = { x : copy_protocol_mappings[(source_protocol, x)]() for x in set(target_protocols) } self.get = self.current.__getitem__ def __call__(self, data): for protocol, transformer in self.transformers.items(): self.current[protocol] = transformer(data) ## # This is the object that does the magic. # It tracks the state of the wire. # It ends when non-COPY data is found. class WireState(object): """ Manages the state of the wire. This class manages three possible positions: 1. Between wire messages 2. Inside message header 3. Inside message (with complete header) The wire state will become unusable when the configured condition is True. """ __slots__ = ('remaining_bytes', 'size_fragment', 'final_view', 'condition',) def update(self, view, getlen = ulong_unpack, len = len): """ Given the state of the COPY and new data, advance the position on the COPY stream. """ # Only usable until the terminating condition. if self.final_view is not None: raise RuntimeError("wire state encountered exceptional condition") nmessages = 0 # State carried over from prior run. remaining_bytes = self.remaining_bytes size_fragment = self.size_fragment # Terminating condition. CONDITION = self.condition # Is it a continuation of a message header? if remaining_bytes == -1: ## # Inside message header; after message type. # Continue adding to the 'size_fragment' # until there are four bytes to unpack. ## o = len(size_fragment) size_fragment += bytes(view[:4-o]) if len(size_fragment) == 4: # The size fragment is completed; only part # of the fragment remains to be consumed. remaining_bytes = getlen(size_fragment) - o size_fragment = b'' else: assert len(size_fragment) < 4 # size_fragment got updated.. if remaining_bytes >= 0: vlen = len(view) while True: if remaining_bytes: ## # Inside message body. Message length has been unpacked. ## view = view[remaining_bytes:] # How much is remaining now? rb = remaining_bytes - vlen if rb <= 0: # Finished it. vlen = -rb remaining_bytes = 0 nmessages += 1 else: vlen = 0 remaining_bytes = rb ## # In between protocol messages. ## if not view: # no more data to analyze break # There is at least one byte in the view. if CONDITION(view[0]): # State is dead now. # User needs to handle unexpected message, then continue. self.final_view = view assert remaining_bytes == 0 break if vlen < 5: # Header continuation. remaining_bytes = -1 view = view[1:] size_fragment += bytes(view) # Not enough left for the header of the next message? break # Update remaining_bytes to include the header, and start over. remaining_bytes = getlen(view[1:5]) + 1 # Update the state for the next update. 
self.remaining_bytes, self.size_fragment = ( remaining_bytes, size_fragment, ) # Emit the number of messages "consumed" this round. return nmessages def __init__(self, condition = (CopyData.type[0].__ne__ if isinstance(memoryview(b'f')[0], int) else CopyData.type.__ne__)): self.remaining_bytes = 0 self.size_fragment = b'' self.final_view = None self.condition = condition class Fitting(Element): _e_label = 'FITTING' def _e_metas(self): yield None, '[' + self.state + ']' @abstractproperty def protocol(self): """ The COPY data format produced or consumed. """ # Used to setup the Receiver/Producer def __enter__(self): pass # Used to tear down the Receiver/Producer def __exit__(self, typ, val, tb): pass class Producer(Fitting, Iterator): _e_label = 'PRODUCER' def _e_metas(self): for x in super()._e_metas(): yield x yield 'data', str(self.total_bytes / (1024**2)) + 'MB' yield 'messages', self.total_messages yield 'average size', (self.total_bytes / self.total_messages) def __init__(self): self.total_messages = 0 self.total_bytes = 0 @abstractmethod def realign(self): """ Method implemented by producers that emit COPY data that is not guaranteed to be aligned. This is only necessary in failure cases where receivers still need more data to complete the message. """ @abstractmethod def __next__(self): """ Produce the next set of data. """ class Receiver(Fitting): _e_label = 'RECEIVER' @abstractmethod def transmit(self): """ Finish the reception of the accepted data. """ @abstractmethod def accept(self, data): """ Take the data object to be processed. """ class NullProducer(Producer): """ Produces no copy data. """ _e_factors = () protocol = PROTOCOL_NULL def realign(self): # Never needs to realigned. pass def __next__(self): raise StopIteration class IteratorProducer(Producer): _e_factors = ('iterator',) protocol = PROTOCOL_CHUNKS def __init__(self, iterator): self.iterator = iter(iterator) self.__next__ = self.iterator.__next__ super().__init__() def realign(self): # Never needs to realign; data is emitted on message boundaries. pass def __next__(self, next = next): n = next(self.iterator) self.total_messages += len(n) self.total_bytes += sum(map(len, n)) return n class ProtocolProducer(Producer): """ Producer using a PQv3 data stream. Normally, this class needs to be subclassed as it assumes that the given recv_into function will write COPY messages. """ protocol = PROTOCOL_PQv3 @abstractmethod def recover(self, view): """ Given a view containing data read from the wire, recover the controller's state. This needs to be implemented by subclasses in order for the ProtocolReceiver to pass control back to the original state machine. """ ## # When a COPY is interrupted, this can be used to accommodate # the original state machine to identify the message boundaries. def realign(self): s = self._state if s is None: # It's already aligned. self.nextchunk = iter(()).__next__ return if s.final_view: # It was at the end or non-COPY. for_producer = bytes(s.final_view) for_receivers = b'' elif s.remaining_bytes == -1: # In the middle of a message header. for_producer = CopyData.type + s.size_fragment # receivers: header = (self._state.size_fragment.ljust(3, b'\x00') + b'\x04') # Don't include the already sent parts. buf = header[len(self._state.size_fragment):] bodylen = ulong_unpack(header) - 4 # This will often cause an invalid copy data error, # but it doesn't matter much because we will issue a copy fail. 
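			# Pad out the rest of the partially transmitted message with NULs
			# purely to restore message-boundary alignment for the receivers.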
buf += b'\x00' * bodylen for_receivers = buf elif s.remaining_bytes > 0: # In the middle of a message. for_producer = CopyData.type + ulong_pack(s.remaining_bytes + 4) for_receivers = b'\x00' * self._state.remaining_bytes else: for_producer = for_receivers = b'' self.recover(for_producer) if for_receivers: self.nextchunk = iter((for_receivers,)).__next__ else: self.nextchunk = iter(()).__next__ def process_copy_data(self, view): self.total_messages += self._state.update(view) if self._state.final_view is not None: # It's not COPY data. fv = self._state.final_view # Only publish up to the final_view. if fv: view = view[:-len(fv)] # The next next() will handle the async, error, or completion. self.recover(fv) self._state = None self.total_bytes += len(view) return view # Given a view, begin tracking the state of the wire. def track_state(self, view): self._state = WireState() self.nextchunk = self.recv_view return self.process_copy_data(view) # The usual method for receiving more data. def recv_view(self): view = self.buffer_view[:self.recv_into(self.buffer, self.buffer_size)] if not view: # Zero read; let the subclass handle the situation. self.recover(memoryview(b'')) return self.nextchunk() view = self.process_copy_data(view) return view def nextchunk(self): raise RuntimeError("producer not properly initialized") def __next__(self): return self.nextchunk() def __init__(self, recv_into : "callable taking writable buffer and size", buffer_size = default_buffer_size ): super().__init__() self.recv_into = recv_into self.buffer_size = buffer_size self.buffer = bytearray(buffer_size) self.buffer_view = memoryview(self.buffer) self._state = None class StatementProducer(ProtocolProducer): _e_factors = ('statement', 'parameters',) def _e_metas(self): for x in super()._e_metas(): yield x @property def state(self): if self._chunks is None: return 'created' return 'producing' def count(self): return self._chunks.count() def command(self): return self._chunks.command() def __init__(self, statement, *args, **kw): super().__init__(statement.database.pq.socket.recv_into, **kw) self.statement = statement self.parameters = args self._chunks = None ## # Take any data held by the statement's chunks and connection. def confiscate(self, next = next): current = [] try: while not current: current.extend(next(self._chunks)) except StopIteration: if not current: # End of COPY. raise pq = self._chunks.database.pq buffer = cat_messages(current) + pq.message_buffer.getvalue() + (pq.read_data or b'') view = memoryview(buffer) pq.read_data = None pq.message_buffer.truncate() # Reconstruct the buffer from the already parsed lines. r = self.track_state(view) # XXX: Better way? Probably shouldn't do the full track_state if complete.. if self._chunks._xact.state is xactComplete: # It's over, don't hand off to recv_view. self.nextchunk = self.confiscate assert self._state.final_view is None return r def recover(self, view): # Method used when non-COPY data is found. self._chunks.database.pq.message_buffer.write(bytes(view)) self.nextchunk = self.confiscate def __enter__(self): super().__enter__() if self._chunks is not None: raise RuntimeError("receiver already used") self._chunks = self.statement.chunks(*self.parameters) # Start by confiscating the connection state. 
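		# confiscate() first drains any COPY data already buffered by the
		# protocol layer before reading directly from the socket.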
self.nextchunk = self.confiscate def __exit__(self, typ, val, tb): if typ is None or issubclass(typ, Exception): db = self.statement.database if not db.closed and self._chunks._xact is not None: # The COPY transaction is still happening, # force an interrupt if the connection still exists. db.interrupt() if db.pq.xact: # Raise, CopyManager should trap. db._pq_complete() super().__exit__(typ, val, tb) class NullReceiver(Receiver): _e_factors = () protocol = PROTOCOL_NULL state = 'null' def transmit(self): # Nothing to do. pass def accept(self, data): pass class ProtocolReceiver(Receiver): protocol = PROTOCOL_PQv3 __slots__ = ('send', 'view') def __init__(self, send): super().__init__() self.send = send self.view = memoryview(b'') def accept(self, data): self.view = data def transmit(self): while self.view: self.view = self.view[self.send(self.view):] def __enter__(self): return self def __exit__(self, typ, val, tb): pass class StatementReceiver(ProtocolReceiver): _e_factors = ('statement', 'parameters',) __slots__ = ProtocolReceiver.__slots__ + _e_factors + ('xact',) def _e_metas(self): yield None, '[' + self.state + ']' def __init__(self, statement, *parameters): self.statement = statement self.parameters = parameters self.xact = None super().__init__(statement.database.pq.socket.send,) # XXX: A bit of a hack... # This is actually a good indication that statements need a .copy() # execution method for producing a "CopyCursor" that reads or writes. class WireReady(BaseException): pass def raise_wire_ready(self): raise self.WireReady() yield None def __enter__(self, iter = iter): super().__enter__() # Get the connection in the COPY state. try: self.statement.load_chunks( iter(self.raise_wire_ready()), *self.parameters ) except self.WireReady: # It's a BaseException; nothing should trap it. # Note the transaction object; we'll use it on exit. self.xact = self.statement.database.pq.xact def __exit__(self, typ, val, tb): if self.xact is None: # Nothing to do. return super().__exit__(typ, val, tb) if self.view: # The realigned producer emitted the necessary # data for message boundary alignment. # # In this case, we unconditionally fail. pq = self.statement.database.pq # There shouldn't be any message_data, atm. pq.message_data = bytes(self.view) self.statement.database._pq_complete() # It is possible for a non-alignment view to exist in cases of # faults. However, exit should *not* be called in those cases. ## elif typ is None: # Success? self.xact.messages = self.xact.CopyDoneSequence # If not, this will blow up. self.statement.database._pq_complete() # Find the complete message for command and count. for x in self.xact.messages_received(): if getattr(x, 'type', None) == Complete.type: self._complete_message = x elif issubclass(typ, Exception): # Likely raises. CopyManager should trap. self.statement.database._pq_complete() return super().__exit__(typ, val, tb) def count(self): return self._complete_message.extract_count() def command(self): return self._complete_message.extract_command().decode('ascii') class CallReceiver(Receiver): """ Call the given object with a list of COPY lines. """ _e_factors = ('callable',) protocol = PROTOCOL_CHUNKS def __init__(self, callable): self.callable = callable self.lines = None super().__init__() def transmit(self): if self.lines is not None: self.callable(self.lines) self.lines = None def accept(self, lines): self.lines = lines class CopyManager(Element, Iterator): """ A class for managing COPY operations. Connects the producer to the receivers. 
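
	Sketch of direct use; ``src`` and ``dst`` are assumed to be
	established connections::

		>>> sp = StatementProducer(src.prepare("COPY data TO STDOUT"))
		>>> sr = StatementReceiver(dst.prepare("COPY data FROM STDIN"))
		>>> with CopyManager(sp, sr) as copy:
		...     for num_messages, num_bytes in copy:
		...         pass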
""" _e_label = 'COPY' _e_factors = ('producer', 'receivers',) def _e_metas(self): yield None, '[' + self.state + ']' @property def state(self): if self.transformer is None: return 'initialized' return str(self.producer.total_messages) + ' messages transferred' def __init__(self, producer, *receivers): self.producer = producer self.transformer = None self.receivers = ElementSet(receivers) self._seen_stop_iteration = False rp = set() add = rp.add for x in self.receivers: add(x.protocol) self.protocols = rp def __enter__(self): if self.transformer: raise RuntimeError("copy already started") self._stats = (0, 0) self.transformer = CopyTransformer(self.producer.protocol, self.protocols) self.producer.__enter__() try: for x in self.receivers: x.__enter__() except Exception: self.__exit__(*sys.exc_info()) return self def __exit__(self, typ, val, tb): ## # Exiting the CopyManager is a fairly complex operation. # # In cases of failure, re-alignment may need to happen # for when the receivers are not on a message boundary. ## if typ is not None and not issubclass(typ, Exception): # Don't bother, it's an interrupt or sufficient resources. return profail = None try: # Does nothing if the COPY was successful. self.producer.realign() try: ## # If the producer is not aligned to a message boundary, # it can emit completion data that will put the receivers # back on track. # This last service call will move that data onto the receivers. self._service_producer() ## # The receivers need to handle any new data in their __exit__. except StopIteration: # No re-alignment needed. pass self.producer.__exit__(typ, val, tb) except Exception as x: # reference profail later. profail = x # No receivers? It wasn't a success. if not self.receivers: raise CopyFail(self, "no receivers", producer_fault = profail) exit_faults = {} for x in self.receivers: try: x.__exit__(typ, val, tb) except Exception as e: exit_faults[x] = e if typ or exit_faults or profail or not self._seen_stop_iteration: raise CopyFail(self, "could not complete the COPY operation", receiver_faults = exit_faults or None, producer_fault = profail ) def reconcile(self, r): """ Reconcile a receiver that faulted. This method should be used to add back a receiver that failed to complete its write operation, but is capable of completing the operation at this time. """ if r.protocol not in self.protocols: raise RuntimeError("cannot add new receivers to copy operations") r.transmit() # Okay, add it back. self.receivers.add(r) def _service_producer(self): # Setup current data. if not self.receivers: # No receivers to take the data. raise StopIteration try: nextdata = next(self.producer) except StopIteration: # Should be over. self._seen_stop_iteration = True raise except Exception: raise ProducerFault(self) self.transformer(nextdata) # Distribute data to receivers. for x in self.receivers: x.accept(self.transformer.get(x.protocol)) def _service_receivers(self): faults = {} for x in self.receivers: # Process all the receivers. try: x.transmit() except Exception as e: faults[x] = e if faults: # The CopyManager is eager to continue the operation. for x in faults: self.receivers.discard(x) raise ReceiverFault(self, faults) # Run the COPY to completion. def run(self): with self: try: while True: self._service_producer() self._service_receivers() except StopIteration: # It's done. 
pass def __iter__(self): return self def __next__(self): messages = self.producer.total_messages bytes = self.producer.total_bytes self._service_producer() # Record the progress in case a receiver faults. self._stats = ( self._stats[0] + (self.producer.total_messages - messages), self._stats[1] + (self.producer.total_bytes - bytes), ) self._service_receivers() # Return the progress. current_stats = self._stats self._stats = (0, 0) return current_stats def transfer(producer, *receivers): """ Perform a COPY operation using the given statements:: >>> import copyman >>> copyman.transfer(src.prepare("COPY table TO STDOUT"), dst.prepare("COPY table FROM STDIN")) """ cm = CopyManager( StatementProducer(producer), *[x if isinstance(x, Receiver) else StatementReceiver(x) for x in receivers] ) cm.run() return (cm.producer.total_messages, cm.producer.total_bytes) fe-1.1.0/postgresql/documentation/000077500000000000000000000000001203372773200171655ustar00rootroot00000000000000fe-1.1.0/postgresql/documentation/__init__.py000066400000000000000000000001511203372773200212730ustar00rootroot00000000000000## # .documentation ## r""" See: `postgresql.documentation.index` """ __docformat__ = 'reStructuredText' fe-1.1.0/postgresql/documentation/admin.rst000066400000000000000000000022121203372773200210040ustar00rootroot00000000000000Administration ============== This chapter covers the administration of py-postgresql. This includes installation and other aspects of working with py-postgresql such as environment variables and configuration files. Installation ------------ py-postgresql uses Python's distutils package to manage the build and installation process of the package. The normal entry point for this is the ``setup.py`` script contained in the root project directory. After extracting the archive and changing the into the project's directory, installation is normally as simple as:: $ python3 ./setup.py install However, if you need to install for use with a particular version of python, just use the path of the executable that should be used:: $ /usr/opt/bin/python3 ./setup.py install Environment ----------- These environment variables effect the operation of the package: ============== =============================================================================== PGINSTALLATION The path to the ``pg_config`` executable of the installation to use by default. ============== =============================================================================== fe-1.1.0/postgresql/documentation/alock.rst000066400000000000000000000077171203372773200210240ustar00rootroot00000000000000.. _alock: ************** Advisory Locks ************** .. warning:: `postgresql.alock` is a new feature in v1.0. `Explicit Locking in PostgreSQL `_. PostgreSQL's advisory locks offer a cooperative synchronization primitive. These are used in cases where an application needs access to a resource, but using table locks may cause interference with other operations that can be safely performed alongside the application-level, exclusive operation. Advisory locks can be used by directly executing the stored procedures in the database or by using the :class:`postgresql.alock.ALock` subclasses, which provides a context manager that uses those stored procedures. Currently, only two subclasses exist. Each represents the lock mode supported by PostgreSQL's advisory locks: * :class:`postgresql.alock.ShareLock` * :class:`postgresql.alock.ExclusiveLock` Acquiring ALocks ================ An ALock instance represents a sequence of advisory locks. 
A single ALock can acquire and release multiple advisory locks by creating the instance with multiple lock identifiers:: >>> from postgresql import alock >>> table1_oid = 192842 >>> table2_oid = 192849 >>> l = alock.ExclusiveLock(db, (table1_oid, 0), (table2_oid, 0)) >>> l.acquire() >>> ... >>> l.release() :class:`postgresql.alock.ALock` is similar to :class:`threading.RLock`; in order for an ALock to be released, it must be released the number of times it has been acquired. ALocks are associated with and survived by their session. Much like how RLocks are associated with the thread they are acquired in: acquiring an ALock again will merely increment its count. PostgreSQL allows advisory locks to be identified using a pair of `int4` or a single `int8`. ALock instances represent a *sequence* of those identifiers:: >>> from postgresql import alock >>> ids = [(0,0), 0, 1] >>> with alock.ShareLock(db, *ids): ... ... Both types of identifiers may be used within the same ALock, and, regardless of their type, will be aquired in the order that they were given to the class' constructor. In the above example, ``(0,0)`` is acquired first, then ``0``, and lastly ``1``. ALocks ====== `postgresql.alock.ALock` is abstract; it defines the interface and some common functionality. The lock mode is selected by choosing the appropriate subclass. There are two: ``postgresql.alock.ExclusiveLock(database, *identifiers)`` Instantiate an ALock object representing the `identifiers` for use with the `database`. Exclusive locks will conflict with other exclusive locks and share locks. ``postgresql.alock.ShareLock(database, *identifiers)`` Instantiate an ALock object representing the `identifiers` for use with the `database`. Share locks can be acquired when a share lock with the same identifier has been acquired by another backend. However, an exclusive lock with the same identifier will conflict. ALock Interface Points ---------------------- Methods and properties available on :class:`postgresql.alock.ALock` instances: ``alock.acquire(blocking = True)`` Acquire the advisory locks represented by the ``alock`` object. If blocking is `True`, the default, the method will block until locks on *all* the identifiers have been acquired. If blocking is `False`, acquisition may not block, and success will be indicated by the returned object: `True` if *all* lock identifiers were acquired and `False` if any of the lock identifiers could not be acquired. ``alock.release()`` Release the advisory locks represented by the ``alock`` object. If the lock has not been acquired, a `RuntimeError` will be raised. ``alock.locked()`` Returns a boolean describing whether the locks are held or not. This will return `False` if the lock connection has been closed. ``alock.__enter__()`` Alias to ``acquire``; context manager protocol. Always blocking. ``alock.__exit__(typ, val, tb)`` Alias to ``release``; context manager protocol. fe-1.1.0/postgresql/documentation/bin.rst000066400000000000000000000121561203372773200204740ustar00rootroot00000000000000Commands ******** This chapter discusses the usage of the available console scripts. postgresql.bin.pg_python ======================== The ``pg_python`` command provides a simple way to write Python scripts against a single target database. It acts like the regular Python console command, but takes standard PostgreSQL options as well to specify the client parameters to make establish connection with. 
The Python environment is then augmented with the following built-ins: ``db`` The PG-API connection object. ``xact`` ``db.xact``, the transaction creator. ``settings`` ``db.settings`` ``prepare`` ``db.prepare``, the statement creator. ``proc`` ``db.proc`` ``do`` ``db.do``, execute a single DO statement. ``sqlexec`` ``db.execute``, execute multiple SQL statements (``None`` is always returned) pg_python Usage --------------- Usage: postgresql.bin.pg_python [connection options] [script] ... Options: --unix=UNIX path to filesystem socket --ssl-mode=SSLMODE SSL requirement for connectivity: require, prefer, allow, disable -s SETTINGS, --setting=SETTINGS run-time parameters to set upon connecting -I PQ_IRI, --iri=PQ_IRI database locator string [pq://user:password@host:port/database?setting=value] -h HOST, --host=HOST database server host -p PORT, --port=PORT database server port -U USER, --username=USER user name to connect as -W, --password prompt for password -d DATABASE, --database=DATABASE database's name --pq-trace=PQ_TRACE trace PQ protocol transmissions -C PYTHON_CONTEXT, --context=PYTHON_CONTEXT Python context code to run[file://,module:,] -m PYTHON_MAIN Python module to run as script(__main__) -c PYTHON_MAIN Python expression to run(__main__) --version show program's version number and exit --help show this help message and exit Interactive Console Backslash Commands -------------------------------------- Inspired by ``psql``:: >>> \? Backslash Commands: \? Show this help message. \E Edit a file or a temporary script. \e Edit and Execute the file directly in the context. \i Execute a Python script within the interpreter's context. \set Configure environment variables. \set without arguments to show all \x Execute the Python command within this process. pg_python Examples ------------------ Module execution taking advantage of the new built-ins:: $ python3 -m postgresql.bin.pg_python -h localhost -W -m timeit "prepare('SELECT 1').first()" Password for pg_python[pq://jwp@localhost:5432]: 1000 loops, best of 3: 1.35 msec per loop $ python3 -m postgresql.bin.pg_python -h localhost -W -m timeit -s "ps=prepare('SELECT 1')" "ps.first()" Password for pg_python[pq://jwp@localhost:5432]: 1000 loops, best of 3: 442 usec per loop Simple interactive usage:: $ python3 -m postgresql.bin.pg_python -h localhost -W Password for pg_python[pq://jwp@localhost:5432]: >>> ps = prepare('select 1') >>> ps.first() 1 >>> c = ps() >>> c.read() [(1,)] >>> ps.close() >>> import sys >>> sys.exit(0) postgresql.bin.pg_dotconf ========================= pg_dotconf is used to modify a PostgreSQL cluster's configuration file. It provides a means to apply settings specified from the command line and from a file referenced using the ``-f`` option. .. warning:: ``include`` directives in configuration files are *completely* ignored. If modification of an included file is desired, the command must be applied to that specific file. 
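For example, assuming a hypothetical ``memory.conf`` file referenced by an
``include`` directive in ``postgresql.conf``, the included file would have to
be targeted directly::

	$ python3 -m postgresql.bin.pg_dotconf memory.conf shared_buffers=128MB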
pg_dotconf Usage
----------------

Usage: postgresql.bin.pg_dotconf [--stdout] [-f filepath] postgresql.conf ([param=val]|[param])*

Options:
  --version             show program's version number and exit
  -h, --help            show this help message and exit
  -f SETTINGS, --file=SETTINGS
                        A file of settings to *apply* to the given
                        "postgresql.conf"
  --stdout              Redirect the product to standard output instead of
                        writing back to the "postgresql.conf" file

Examples
--------

Modifying a simple configuration file::

	$ echo "setting = value" >pg.conf

	# change 'setting'
	$ python3 -m postgresql.bin.pg_dotconf pg.conf setting=newvalue
	$ cat pg.conf
	setting = 'newvalue'

	# new settings are appended to the file
	$ python3 -m postgresql.bin.pg_dotconf pg.conf another_setting=value
	$ cat pg.conf
	setting = 'newvalue'
	another_setting = 'value'

	# comment a setting
	$ python3 -m postgresql.bin.pg_dotconf pg.conf another_setting
	$ cat pg.conf
	setting = 'newvalue'
	#another_setting = 'value'

When a setting is given on the command line, it must be seen as one argument
to the command, so it's *very* important to avoid invocations like::

	$ python3 -m postgresql.bin.pg_dotconf pg.conf setting = value
	ERROR: invalid setting, '=' after 'setting'
	HINT: Settings must take the form 'setting=value' or 'setting_name_to_comment'. Settings must also be received as a single argument.
fe-1.1.0/postgresql/documentation/changes-v1.0.rst000066400000000000000000000071631203372773200220140ustar00rootroot00000000000000Changes in v1.0
===============

1.0.4 in development
--------------------

 * Alter how changes are represented in documentation to simplify merging.

1.0.3 released on 2011-09-24
----------------------------

 * Use raise x from y to generalize exceptions. (Elvis Pranskevichus)
 * Alter postgresql.string.quote_ident to always quote. (Elvis Pranskevichus)
 * Add postgresql.string.quote_ident_if_necessary (Modification of Elvis Pranskevichus' patch)
 * Many postgresql.string bug fixes (Elvis Pranskevichus)
 * Correct ResourceWarnings improving Python 3.2 support. (jwp)
 * Add test command to setup.py (Elvis Pranskevichus)

1.0.2 released on 2010-09-18
----------------------------

 * Add support for DOMAINs in registered composites. (Elvis Pranskevichus)
 * Properly raise StopIteration in Cursor.__next__. (Elvis Pranskevichus)
 * Add Cluster Management documentation.
 * Release savepoints after rolling them back.
 * Fix Startup() usage for Python 3.2.
 * Emit deprecation warning when 'gid' is given to xact().
 * Compensate for Python3.2's ElementTree API changes.

1.0.1 released on 2010-04-24
----------------------------

 * Fix unpacking of array NULLs. (Elvis Pranskevichus)
 * Fix .first()'s handling of counts and commands.
   Bad logic caused zero-counts to return the command tag.
 * Don't interrupt and close a temporal connection if it's not open.
 * Use the Driver's typio attribute for TypeIO overrides. (Elvis Pranskevichus)

1.0 released on 2010-03-27
--------------------------

 * **DEPRECATION**: Removed 2PC support documentation.
 * **DEPRECATION**: Removed pg_python and pg_dotconf 'scripts'.
   They are still accessible by python3 -m postgresql.bin.pg_*
 * Add support for binary hstore.
 * Add support for user service files.
 * Implement a Copy manager for direct connection-to-connection COPY operations.
 * Added db.do() method for DO-statement support(convenience method).
 * Set the default client_min_messages level to WARNING. NOTICEs are often not
   desired by programmers, and py-postgresql's high verbosity further
   irritates that case.
 * Added postgresql.project module to provide project information.
   Project name, author, version, etc.
 * Increased default recvsize and chunksize for improved performance.
 * 'D' messages are special cased as builtins.tuples instead of
   protocol.element3.Tuple
 * Alter Statement.chunks() to return chunks of builtins.tuple. Being an
   interface intended for speed, types.Row() impedes its performance.
 * Fix handling of infinity values with timestamptz, timestamp, and date.
   [Bug reported by Axel Rau.]
 * Correct representation of PostgreSQL ARRAYs by properly recording
   lowerbounds and upperbounds. Internally, sub-ARRAYs have their own
   element lists.
 * Implement a NotificationManager for managing the NOTIFYs received by a
   connection. The class can manage NOTIFYs from multiple connections, whereas
   the db.wait() method is tailored for single targets.
 * Implement an ALock class for managing advisory locks using the
   threading.Lock APIs. [Feedback from Valentine Gogichashvili]
 * Implement reference symbols. Allow libraries to define symbols that are
   used to create queries that inherit the original symbol's type and
   execution method. ``db.prepare(db.prepare(...).first())``
 * Fix handling of unix domain sockets by pg.open and driver.connect.
   [Reported by twitter.com/rintavarustus]
 * Fix typo/dropped parts of a raise LoadError in .lib.
   [Reported by Vlad Pranskevichus]
 * Fix db.tracer and pg_python's --pq-trace=
 * Fix count return from .first() method. Failed to provide an empty tuple
   for the rformats of the bind statement. [Reported by dou dou]
fe-1.1.0/postgresql/documentation/changes-v1.1.rst000066400000000000000000000021321203372773200220100ustar00rootroot00000000000000Changes in v1.1
===============

1.1.0
-----

 * Remove two-phase commit interfaces per deprecation in v1.0. For proper two
   phase commit use, a lock manager must be employed; the implementation did
   nothing to accommodate for that.
 * Add support for unpacking anonymous records (Elvis)
 * Support PostgreSQL 9.2 (Elvis)
 * Python 3.3 Support (Elvis)
 * Add column execution method. (jwp)
 * Add one-shot statement interface. Connection.query.* (jwp)
 * Modify the inet/cidr support by relying on the ipaddress module introduced
   in Python 3.3 (Google's ipaddr project)

   The existing implementation relied on the simple str() representation
   supported by the socket module. Unfortunately, MS Windows' socket library
   does not appear to support the necessary functionality, or Python's socket
   module does not expose it. ipaddress fixes the problem.

   .. note:: The `ipaddress` module is now required for local inet and cidr.
      While it is of "preliminary" status, the ipaddr project has been around
      for some time and is well supported. ipaddress appears to be the safest
      way forward for native network types.
fe-1.1.0/postgresql/documentation/clientparameters.rst000066400000000000000000000243731203372773200232660ustar00rootroot00000000000000Client Parameters
*****************

.. warning::
 **The interfaces dealing with optparse are subject to change in 1.0**.

There are various sources of parameters used by PostgreSQL client
applications. The `postgresql.clientparameters` module provides a means for
collecting and managing those parameters.

Connection creation interfaces in `postgresql.driver` are purposefully simple.
All parameters taken by those interfaces are keywords, and are taken
literally; if a parameter is not given, it will effectively be `None`.
libpq-based drivers tend to differ as they inherit some default client
parameters from the environment.
Doing this by default is undesirable as it can cause trivial failures due to unexpected parameter inheritance. However, using these parameters from the environment and other sources are simply expected in *some* cases: `postgresql.open`, `postgresql.bin.pg_python`, and other high-level utilities. The `postgresql.clientparameters` module provides a means to collect them into one dictionary object for subsequent application to a connection creation interface. `postgresql.clientparameters` is primarily useful to script authors that want to provide an interface consistent with PostgreSQL commands like ``psql``. Collecting Parameters ===================== The primary entry points in `postgresql.clientparameters` are `postgresql.clientparameters.collect` and `postgresql.clientparameters.resolve_password`. For most purposes, ``collect`` will suffice. By default, it will prompt for the password if instructed to(``-W``). Therefore, ``resolve_password`` need not be used in most cases:: >>> import sys >>> import postgresql.clientparameters as pg_param >>> p = pg_param.DefaultParser() >>> co, ca = p.parse_args(sys.argv[1:]) >>> params = pg_param.collect(parsed_options = co) The `postgresql.clientparameters` module is executable, so you can see the results of the above snippet by:: $ python -m postgresql.clientparameters -h localhost -U a_db_user -ssearch_path=public {'host': 'localhost', 'password': None, 'port': 5432, 'settings': {'search_path': 'public'}, 'user': 'a_db_user'} `postgresql.clientparameters.collect` -------------------------------------- Build a client parameter dictionary from the environment and parsed command line options. The following is a list of keyword arguments that ``collect`` will accept: ``parsed_options`` Options parsed by `postgresql.clientparameters.StandardParser` or `postgresql.clientparameters.DefaultParser` instances. ``no_defaults`` When `True`, don't include defaults like ``pgpassfile`` and ``user``. Defaults to `False`. ``environ`` Environment variables to extract client parameter variables from. Defaults to `os.environ` and expects a `collections.Mapping` interface. ``environ_prefix`` Environment variable prefix to use. Defaults to "PG". This allows the collection of non-standard environment variables whose keys are partially consistent with the standard variants. e.g. "PG_SRC_USER", "PG_SRC_HOST", etc. ``default_pg_sysconfdir`` The location of the pg_service.conf file. The ``PGSYSCONFDIR`` environment variable will override this. When a default installation is present, ``PGINSTALLATION``, it should be set to this. ``pg_service_file`` Explicit location of the service file. This will override the "sysconfdir" based path. ``prompt_title`` Descriptive title to use if a password prompt is needed. `None` to disable password resolution entirely. Setting this to `None` will also disable pgpassfile lookups, so it is necessary that further processing occurs when this is `None`. ``parameters`` Base client parameters to use. These are set after the *defaults* are collected. (The defaults that can be disabled by ``no_defaults``). 
If ``prompt_title`` is not set to `None`, it will prompt for the password when
instructed to do so by the ``prompt_password`` key in the parameters::

	>>> import postgresql.clientparameters as pg_param
	>>> p = pg_param.collect(prompt_title = 'my_prompt!', parameters = {'prompt_password':True})
	Password for my_prompt![pq://jwp@localhost:5432]:
	>>> p
	{'host': 'localhost', 'user': 'jwp', 'password': 'secret', 'port': 5432}

If `None`, it will leave the necessary password resolution information in the
parameters dictionary for ``resolve_password``::

	>>> p = pg_param.collect(prompt_title = None, parameters = {'prompt_password':True})
	>>> p
	{'pgpassfile': '/Users/jwp/.pgpass', 'prompt_password': True, 'host': 'localhost', 'user': 'jwp', 'port': 5432}

Of course, ``'prompt_password'`` is normally specified when ``parsed_options``
received a ``-W`` option from the command line::

	>>> op = pg_param.DefaultParser()
	>>> co, ca = op.parse_args(['-W'])
	>>> p = pg_param.collect(parsed_options = co)
	Password for [pq://jwp@localhost:5432]:
	>>> p
	{'host': 'localhost', 'user': 'jwp', 'password': 'secret', 'port': 5432}
	>>>

`postgresql.clientparameters.resolve_password`
----------------------------------------------

Resolve the password for the given client parameters dictionary returned by
``collect``. By default, this function need not be used as ``collect`` will
resolve the password by default. `resolve_password` accepts the following
arguments:

``parameters``
 First positional argument. Normalized client parameters dictionary to update
 in-place with the resolved password. If the 'prompt_password' key is in
 ``parameters``, it will prompt regardless(normally comes from ``-W``).

``getpass``
 Function to call to prompt for the password. Defaults to `getpass.getpass`.

``prompt_title``
 Additional title to use if a prompt is requested. This can also be specified
 in the ``parameters`` as the ``prompt_title`` key. This *augments* the IRI
 display on the prompt. Defaults to an empty string, ``''``.

The resolution process is affected by the contents of the given
``parameters``. Notable keywords:

``prompt_password``
 If present in the given parameters, the user will be prompted for the
 password using the given ``getpass`` function. This disables the password
 file lookup process.

``prompt_title``
 This states a default prompt title to use. If the ``prompt_title`` keyword
 argument is given to ``resolve_password``, this will not be used.

``pgpassfile``
 The PostgreSQL password file to lookup the password in. If the ``password``
 parameter is present, this will not be used.
When resolution occurs, the ``prompt_password``, ``prompt_title``, and
``pgpassfile`` keys are *removed* from the given parameters dictionary::

	>>> p=pg_param.collect(prompt_title = None)
	>>> p
	{'pgpassfile': '/Users/jwp/.pgpass', 'host': 'localhost', 'user': 'jwp', 'port': 5432}
	>>> pg_param.resolve_password(p)
	>>> p
	{'host': 'localhost', 'password': 'secret', 'user': 'jwp', 'port': 5432}

Defaults
========

The following is a list of default parameters provided by ``collect`` and the
sources of their values:

==================== ===================================================================
Key                  Value
==================== ===================================================================
``'user'``           `getpass.getuser()` or ``'postgres'``
``'host'``           `postgresql.clientparameters.default_host` (``'localhost'``)
``'port'``           `postgresql.clientparameters.default_port` (``5432``)
``'pgpassfile'``     ``"$HOME/.pgpass"`` or ``[PGDATA]`` + ``'pgpass.conf'`` (Win32)
``'sslcrtfile'``     ``[PGDATA]`` + ``'postgresql.crt'``
``'sslkeyfile'``     ``[PGDATA]`` + ``'postgresql.key'``
``'sslrootcrtfile'`` ``[PGDATA]`` + ``'root.crt'``
``'sslrootcrlfile'`` ``[PGDATA]`` + ``'root.crl'``
==================== ===================================================================

``[PGDATA]`` referenced in the above table is a directory whose path is
platform dependent. On most systems, it is ``"$HOME/.postgresql"``, but on
Windows based systems it is ``"%APPDATA%\postgresql"``

.. note:: [PGDATA] is *not* an environment variable.

.. _pg_envvars:

PostgreSQL Environment Variables
================================

The following is a list of environment variables that will be collected by the
`postgresql.clientparameter.collect` function using "PG" as the
``environ_prefix`` and the keyword that it will be mapped to:

===================== ======================================
Environment Variable  Keyword
===================== ======================================
``PGUSER``            ``'user'``
``PGDATABASE``        ``'database'``
``PGHOST``            ``'host'``
``PGPORT``            ``'port'``
``PGPASSWORD``        ``'password'``
``PGSSLMODE``         ``'sslmode'``
``PGSSLKEY``          ``'sslkey'``
``PGCONNECT_TIMEOUT`` ``'connect_timeout'``
``PGREALM``           ``'kerberos4_realm'``
``PGKRBSRVNAME``      ``'kerberos5_service'``
``PGPASSFILE``        ``'pgpassfile'``
``PGTZ``              ``'settings' = {'timezone': }``
``PGDATESTYLE``       ``'settings' = {'datestyle': }``
``PGCLIENTENCODING``  ``'settings' = {'client_encoding': }``
``PGGEQO``            ``'settings' = {'geqo': }``
===================== ======================================

.. _pg_passfile:

PostgreSQL Password File
========================

The password file is a simple newline separated list of ``:`` separated
fields. It is located at ``$HOME/.pgpass`` for most systems and at
``%APPDATA%\postgresql\pgpass.conf`` for Windows based systems. However, the
``PGPASSFILE`` environment variable may be used to override that location.

The lines in the file must be in the following form::

	hostname:port:database:username:password

A single asterisk, ``*``, may be used to indicate that any value will match
the field. However, this only affects fields other than ``password``.

See http://www.postgresql.org/docs/current/static/libpq-pgpass.html for more
details.

Client parameters produced by ``collect`` that have not been processed by
``resolve_password`` will include a ``'pgpassfile'`` key. This is the value
that ``resolve_password`` will use to locate the pgpassfile to interrogate if
a password key is not present and it is not instructed to prompt for a
password.

..
warning:: Connection creation interfaces will *not* resolve ``'pgpassfile'``, so it is important that the parameters produced by ``collect()`` are properly processed before an attempt is made to establish a connection. fe-1.1.0/postgresql/documentation/cluster.rst000066400000000000000000000367541203372773200214170ustar00rootroot00000000000000.. _cluster_management: ****************** Cluster Management ****************** py-postgresql provides cluster management tools in order to give the user fine-grained control over a PostgreSQL cluster and access to information about an installation of PostgreSQL. .. _installation: Installations ============= `postgresql.installation.Installation` objects are primarily used to access PostgreSQL installation information. Normally, they are created using a dictionary constructed from the output of the pg_config_ executable:: from postgresql.installation import Installation, pg_config_dictionary pg_install = Installation(pg_config_dictionary('/usr/local/pgsql/bin/pg_config')) The extraction of pg_config_ information is isolated from Installation instantiation in order to allow Installations to be created from arbitrary dictionaries. This can be useful in cases where the installation layout is inconsistent with the standard PostgreSQL installation layout, or if a faux Installation needs to be created for testing purposes. Installation Interface Points ----------------------------- ``Installation(info)`` Instantiate an Installation using the given information. Normally, this information is extracted from a pg_config_ executable using `postgresql.installation.pg_config_dictionary`:: info = pg_config_dictionary('/usr/local/pgsql/bin/pg_config') pg_install = Installation(info) ``Installation.version`` The installation's version string:: pg_install.version 'PostgreSQL 9.0devel' ``Installation.version_info`` A tuple containing the version's ``(major, minor, patch, state, level)``. Where ``major``, ``minor``, ``patch``, and ``level`` are `int` objects, and ``state`` is a `str` object:: pg_install.version_info (9, 0, 0, 'devel', 0) ``Installation.ssl`` A `bool` indicating whether or not the installation has SSL support. ``Installation.configure_options`` The options given to the ``configure`` script that built the installation. The options are represented using a dictionary object whose keys are normalized long option names, and whose values are the option's argument. If the option takes no argument, `True` will be used as the value. The normalization of the long option names consists of removing the preceding dashes, lowering the string, and replacing any dashes with underscores. For instance, ``--enable-debug`` will be ``enable_debug``:: pg_install.configure_options {'enable_debug': True, 'with_libxml': True, 'enable_cassert': True, 'with_libedit_preferred': True, 'prefix': '/src/build/pg90', 'with_openssl': True, 'enable_integer_datetimes': True, 'enable_depend': True} ``Installation.paths`` The paths of the installation as a dictionary where the keys are the path identifiers and the values are the absolute file system paths. For instance, ``'bindir'`` is associated with ``$PREFIX/bin``, ``'libdir'`` is associated with ``$PREFIX/lib``, etc. The paths included in this dictionary are listed on the class' attributes: `Installation.pg_directories` and `Installation.pg_executables`. 
The keys that point to installation directories are: ``bindir``, ``docdir``,
``includedir``, ``pkgincludedir``, ``includedir_server``, ``libdir``,
``pkglibdir``, ``localedir``, ``mandir``, ``sharedir``, and ``sysconfdir``.

The keys that point to installation executables are: ``pg_config``, ``psql``,
``initdb``, ``pg_resetxlog``, ``pg_controldata``, ``clusterdb``, ``pg_ctl``,
``pg_dump``, ``pg_dumpall``, ``postgres``, ``postmaster``, ``reindexdb``,
``vacuumdb``, ``ipcclean``, ``createdb``, ``ecpg``, ``createuser``,
``createlang``, ``droplang``, ``dropuser``, and ``pg_restore``.

.. note::
 If the executable does not exist, the value will be `None` instead of an
 absolute path.

To get the path to the psql_ executable::

	from postgresql.installation import Installation, pg_config_dictionary
	pg_install = Installation(pg_config_dictionary('/usr/local/pgsql/bin/pg_config'))
	psql_path = pg_install.paths['psql']

Clusters
========

`postgresql.cluster.Cluster` is the class used to manage a PostgreSQL
cluster--a data directory created by initdb_. A Cluster represents a data
directory with respect to a given installation of PostgreSQL, so creating a
`postgresql.cluster.Cluster` object requires a
`postgresql.installation.Installation`, and a file system path to the data
directory.

In part, a `postgresql.cluster.Cluster` is the Python programmer's variant of
the pg_ctl_ command. However, it goes beyond the basic process control
functionality and extends into initialization and configuration as well.

A Cluster manages the server process using the `subprocess` module and
signals. The `subprocess.Popen` object, ``Cluster.daemon_process``, is
retained when the Cluster starts the server process itself. This gives the
Cluster access to the result code of the server process when it exits, and
the ability to redirect stderr and stdout to a parameterized file object
using subprocess features.

Despite its use of `subprocess`, Clusters can control a server process that
was *not* started by the Cluster's ``start`` method.

Initializing Clusters
---------------------

`postgresql.cluster.Cluster` provides a method for initializing a `Cluster`'s
data directory, ``init``. This method provides a Python interface to the
PostgreSQL initdb_ command.

``init`` is a regular method and accepts a few keyword parameters. Normally,
parameters are directly mapped to initdb_ command options. However,
``password`` makes use of initdb's capability to read the superuser's
password from a file. To do this, a temporary file is allocated internally by
the method::

	from postgresql.installation import Installation, pg_config_dictionary
	from postgresql.cluster import Cluster
	pg_install = Installation(pg_config_dictionary('/usr/local/pgsql/bin/pg_config'))
	pg_cluster = Cluster(pg_install, 'pg_data')
	pg_cluster.init(user = 'pg', password = 'secret', encoding = 'utf-8')

The init method will block until the initdb command is complete. Once
initialized, the Cluster may be configured.

Configuring Clusters
--------------------

A Cluster's `configuration file`_ can be manipulated using the
`Cluster.settings` mapping. The mapping's methods will always access the
configuration file, so it may be desirable to cache repeat reads.
Also, if multiple settings are being applied, using the ``update()`` method
may be important to avoid writing the entire file multiple times::

	pg_cluster.settings.update({'listen_addresses' : 'localhost', 'port' : '6543'})

Similarly, to avoid opening and reading the entire file multiple times,
`Cluster.settings.getset` should be used to retrieve multiple settings::

	d = pg_cluster.settings.getset(set(('listen_addresses', 'port')))
	d
	{'listen_addresses' : 'localhost', 'port' : '6543'}

Values contained in ``settings`` are always Python strings::

	assert pg_cluster.settings['max_connections'].__class__ is str

The ``postgresql.conf`` file is only one part of the server configuration.
Structured access and manipulation of the pg_hba_ file is not supported.
Clusters only provide the file path to the pg_hba_ file::

	hba = open(pg_cluster.hba_file)

If the configuration of the Cluster is altered while the server process is
running, it may be necessary to signal the process that configuration changes
have been made. This signal can be sent using the ``Cluster.reload()`` method.
``Cluster.reload()`` will send a SIGHUP signal to the server process. However,
not all changes to configuration settings can go into effect after calling
``Cluster.reload()``. In those cases, the server process will need to be
shutdown and started again.

Controlling Clusters
--------------------

The server process of a Cluster object can be controlled with the ``start()``,
``stop()``, ``shutdown()``, ``kill()``, and ``restart()`` methods. These
methods start the server process, signal the server process, or, in the case
of restart, a combination of the two.

When a Cluster starts the server process, it's run as a subprocess. Therefore,
if the current process exits, the server process will exit as well.
``start()`` does *not* automatically daemonize the server process.

.. note::
 Under Microsoft Windows, the above does not hold true. The server process
 will continue running despite the exit of the parent process.

To terminate a server process, one of these three methods should be called:
``stop``, ``shutdown``, or ``kill``. ``stop`` is a graceful shutdown and will
*wait for all clients to disconnect* before shutting down. ``shutdown`` will
close any open connections and safely shutdown the server process. ``kill``
will immediately terminate the server process leading to recovery upon
starting the server process again.

.. note::
 Using ``kill`` may cause shared memory to be leaked.

Normally, `Cluster.shutdown` is the appropriate way to terminate a server
process.

Cluster Interface Points
------------------------

Methods and properties available on `postgresql.cluster.Cluster` instances:

``Cluster(installation, data_directory)``
 Create a `postgresql.cluster.Cluster` object for the specified
 `postgresql.installation.Installation`, and ``data_directory``. The
 ``data_directory`` must be an absolute file system path. The directory does
 *not* need to exist. The ``init()`` method may later be used to create the
 cluster.

``Cluster.installation``
 The Cluster's `postgresql.installation.Installation` instance.

``Cluster.data_directory``
 The absolute path to the PostgreSQL data directory.
 This directory may not exist.

``Cluster.init([encoding = None[, user = None[, password = None]]])``
 Run the `initdb`_ executable of the configured installation to initialize
 the cluster at the configured data directory, `Cluster.data_directory`.

 ``encoding`` is mapped to ``-E``, the default database encoding.
By default, the encoding is determined from the environment's locale. ``user`` is mapped to ``-U``, the database superuser name. By default, the current user's name. ``password`` is ultimately mapped to ``--pwfile``. The argument given to the long option is actually a path to the temporary file that holds the given password. Raises `postgresql.cluster.InitDBError` when initdb_ returns a non-zero result code. Raises `postgresql.cluster.ClusterInitializationError` when there is no initdb_ in the Installation. ``Cluster.initialized()`` Whether or not the data directory exists, *and* if it looks like a PostgreSQL data directory. Meaning, the directory must contain a ``postgresql.conf`` file and a ``base`` directory. ``Cluster.drop()`` Shutdown the Cluster's server process and completely remove the `Cluster.data_directory` from the file system. ``Cluster.pid()`` The server's process identifier as a Python `int`. `None` if there is no server process running. This is a method rather than a property as it may read the PID from a file in cases where the server process was not started by the Cluster. ``Cluster.start([logfile = None[, settings = None]])`` Start the PostgreSQL server process for the Cluster if it is not already running. This will execute postgres_ as a subprocess. If ``logfile``, an opened and writable file object, is given, stderr and stdout will be redirected to that file. By default, both stderr and stdout are closed. If ``settings`` is given, the mapping or sequence of pairs will be used as long options to the subprocess. For each item, ``--{key}={value}`` will be given as an argument to the subprocess. ``Cluster.running()`` Whether or not the cluster's server process is running. Returns `True` or `False`. Even if `True` is returned, it does *not* mean that the server process is ready to accept connections. ``Cluster.ready_for_connections()`` Whether or not the Cluster is ready to accept connections. Usually called after `Cluster.start`. Returns `True` when the Cluster can accept connections, `False` when it cannot, and `None` if the Cluster's server process is not running at all. ``Cluster.wait_until_started([timeout = 10[, delay = 0.05]])`` Blocks the process until the cluster is identified as being ready for connections. Usually called after ``Cluster.start()``. Raises `postgresql.cluster.ClusterNotRunningError` if the server process is not running at all. Raises `postgresql.cluster.ClusterTimeoutError` if `Cluster.ready_for_connections()` does not return `True` within the given `timeout` period. Raises `postgresql.cluster.ClusterStartupError` if the server process terminates while polling for readiness. ``timeout`` and ``delay`` are both in seconds. Where ``timeout`` is the maximum time to wait for the Cluster to be ready for connections, and ``delay`` is the time to sleep between calls to `Cluster.ready_for_connections()`. ``Cluster.stop()`` Signal the cluster to shutdown when possible. The *server* will wait for all clients to disconnect before shutting down. ``Cluster.shutdown()`` Signal the cluster to shutdown immediately. Any open client connections will be closed. ``Cluster.kill()`` Signal the absolute destruction of the server process(SIGKILL). *This will require recovery when the cluster is started again.* *Shared memory may be leaked.* ``Cluster.wait_until_stopped([timeout = 10[, delay = 0.05]])`` Blocks the process until the cluster is identified as being shutdown. Usually called after `Cluster.stop` or `Cluster.shutdown`. 
Raises `postgresql.cluster.ClusterTimeoutError` if `Cluster.ready_for_connections` does not return `None` within the given `timeout` period. ``Cluster.reload()`` Signal the server that it should reload its configuration files(SIGHUP). Usually called after manipulating `Cluster.settings` or modifying the contents of `Cluster.hba_file`. ``Cluster.restart([logfile = None[, settings = None[, timeout = 10]]])`` Stop the server process, wait until it is stopped, start the server process, and wait until it has started. .. note:: This calls ``Cluster.stop()``, so it will wait until clients disconnect before starting up again. The ``logfile`` and ``settings`` parameters will be given to `Cluster.start`. ``timeout`` will be given to `Cluster.wait_until_stopped` and `Cluster.wait_until_started`. ``Cluster.settings`` A `collections.Mapping` interface to the ``postgresql.conf`` file of the cluster. A notable extension to the mapping interface is the ``getset`` method. This method will return a dictionary object containing the settings whose names were contained in the `set` object given to the method. This method should be used when multiple settings need to be retrieved from the configuration file. ``Cluster.hba_file`` The path to the cluster's pg_hba_ file. This property respects the HBA file location setting in ``postgresql.conf``. Usually, ``$PGDATA/pg_hba.conf``. ``Cluster.daemon_path`` The path to the executable to use to start the server process. ``Cluster.daemon_process`` The `subprocess.Popen` instance of the server process. `None` if the server process was not started or was not started using the Cluster object. .. _pg_hba: http://www.postgresql.org/docs/current/static/auth-pg-hba-conf.html .. _pg_config: http://www.postgresql.org/docs/current/static/app-pgconfig.html .. _initdb: http://www.postgresql.org/docs/current/static/app-initdb.html .. _psql: http://www.postgresql.org/docs/current/static/app-psql.html .. _postgres: http://www.postgresql.org/docs/current/static/app-postgres.html .. _pg_ctl: http://www.postgresql.org/docs/current/static/app-pg-ctl.html .. _configuration file: http://www.postgresql.org/docs/current/static/runtime-config.html fe-1.1.0/postgresql/documentation/copyman.rst000066400000000000000000000270071203372773200213730ustar00rootroot00000000000000.. _pg_copyman: *************** Copy Management *************** The `postgresql.copyman` module provides a way to quickly move COPY data coming from one connection to many connections. Alternatively, it can be sourced by arbitrary iterators and target arbitrary callables. Statement execution methods offer a way for running COPY operations with iterators, but the cost of allocating objects for each row is too significant for transferring gigabytes of COPY data from one connection to another. The interfaces available on statement objects are primarily intended to be used when transferring COPY data to and from arbitrary Python objects. Direct connection-to-connection COPY operations can be performed using the high-level `postgresql.copyman.transfer` function:: >>> from postgresql import copyman >>> send_stmt = source.prepare("COPY (SELECT i FROM generate_series(1, 1000000) AS g(i)) TO STDOUT") >>> destination.execute("CREATE TEMP TABLE loading_table (i int8)") >>> receive_stmt = destination.prepare("COPY loading_table FROM STDIN") >>> total_rows, total_bytes = copyman.transfer(send_stmt, receive_stmt) However, if more control is needed, the `postgresql.copyman.CopyManager` class should be used directly. 
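For the arbitrary iterator and callable routes mentioned at the start of this
chapter, a minimal sketch follows. The chunk format used here (lists of COPY
text-format lines as bytes) and the sample data are assumptions made for
illustration::

	>>> from postgresql import copyman
	>>> # Hypothetical sample data; one chunk containing two COPY lines.
	>>> chunks = iter([[b'1\tone\n', b'2\ttwo\n']])
	>>> received = []
	>>> copyman.CopyManager(
	...  copyman.IteratorProducer(chunks),
	...  copyman.CallReceiver(received.extend)
	... ).run()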
Copy Managers ============= The `postgresql.copyman.CopyManager` class manages the Producer and the Receivers involved in a COPY operation. Normally, `postgresql.copyman.StatementProducer` and `postgresql.copyman.StatementReceiver` instances. Naturally, a Producer is the object that produces the COPY data to be given to the Manager's Receivers. Using a Manager directly means that there is a need for more control over the operation. The Manager is both a context manager and an iterator. The context manager interfaces handle initialization and finalization of the COPY state, and the iterator provides an event loop emitting information about the amount of COPY data transferred this cycle. Normal usage takes the form:: >>> from postgresql import copyman >>> send_stmt = source.prepare("COPY (SELECT i FROM generate_series(1, 1000000) AS g(i)) TO STDOUT") >>> destination.execute("CREATE TEMP TABLE loading_table (i int8)") >>> receive_stmt = destination.prepare("COPY loading_table FROM STDIN") >>> producer = copyman.StatementProducer(send_stmt) >>> receiver = copyman.StatementReceiver(receive_stmt) >>> >>> with source.xact(), destination.xact(): ... with copyman.CopyManager(producer, receiver) as copy: ... for num_messages, num_bytes in copy: ... update_rate(num_bytes) As an alternative to a for-loop inside a with-statement block, the `run` method can be called to perform the operation:: >>> with source.xact(), destination.xact(): ... copyman.CopyManager(producer, receiver).run() However, there is little benefit beyond using the high-level `postgresql.copyman.transfer` function. Manager Interface Points ------------------------ Primarily, the `postgresql.copyman.CopyManager` provides a context manager and an iterator for controlling the COPY operation. ``CopyManager.run()`` Perform the entire COPY operation. ``CopyManager.__enter__()`` Start the COPY operation. Connections taking part in the COPY should **not** be used until ``__exit__`` is ran. ``CopyManager.__exit__(typ, val, tb)`` Finish the COPY operation. Fails in the case of an incomplete COPY, or an untrapped exception. Either returns `None` or raises the generalized exception, `postgresql.copyman.CopyFail`. ``CopyManager.__iter__()`` Returns the CopyManager instance. ``CopyManager.__next__()`` Transfer the next chunk of COPY data to the receivers. Yields a tuple consisting of the number of messages and bytes transferred, ``(num_messages, num_bytes)``. Raises `StopIteration` when complete. Raises `postgresql.copyman.ReceiverFault` when a Receiver raises an exception. Raises `postgresql.copyman.ProducerFault` when the Producer raises an exception. The original exception is available via the exception's ``__context__`` attribute. ``CopyManager.reconcile(faulted_receiver)`` Reconcile a faulted receiver. When a receiver faults, it will no longer be in the set of Receivers. This method is used to signal to the manager that the problem has been corrected, and the receiver is again ready to receive. ``CopyManager.receivers`` The `builtins.set` of Receivers involved in the COPY operation. ``CopyManager.producer`` The Producer emitting the data to be given to the Receivers. Faults ====== The CopyManager generalizes any exceptions that occur during transfer. While inside the context manager, `postgresql.copyman.Fault` may be raised if a Receiver or a Producer raises an exception. A `postgresql.copyman.ProducerFault` in the case of the Producer, and `postgresql.copyman.ReceiverFault` in the case of the Receivers. .. 
note:: Faults are only raised by `postgresql.copyman.CopyManager.__next__`.
   The ``run()`` method will only raise `postgresql.copyman.CopyFail`.

Receiver Faults
---------------

The Manager assumes the Fault is fatal to a Receiver, and immediately removes
it from the set of target receivers. Additionally, if the Fault exception goes
untrapped, the copy will ultimately fail.

The Fault exception references the Manager that raised the exception, and the
actual exceptions that occurred, associated with the Receivers that caused
them.

In order to identify the exception that caused a Fault, the ``faults``
attribute on the `postgresql.copyman.ReceiverFault` must be referenced::

	>>> from postgresql import copyman
	>>> send_stmt = source.prepare("COPY (SELECT i FROM generate_series(1, 1000000) AS g(i)) TO STDOUT")
	>>> destination.execute("CREATE TEMP TABLE loading_table (i int8)")
	>>> receive_stmt = destination.prepare("COPY loading_table FROM STDIN")
	>>> producer = copyman.StatementProducer(send_stmt)
	>>> receiver = copyman.StatementReceiver(receive_stmt)
	>>>
	>>> with source.xact(), destination.xact():
	...  with copyman.CopyManager(producer, receiver) as copy:
	...   while copy.receivers:
	...    try:
	...     for num_messages, num_bytes in copy:
	...      update_rate(num_bytes)
	...     break
	...    except copyman.ReceiverFault as cf:
	...     # Access the original exception using the receiver as the key.
	...     original_exception = cf.faults[receiver]
	...     if unknown_failure(original_exception):
	...      ...
	...      raise

ReceiverFault Properties
~~~~~~~~~~~~~~~~~~~~~~~~

The following attributes exist on `postgresql.copyman.ReceiverFault`
instances:

``ReceiverFault.manager``
 The subject `postgresql.copyman.CopyManager` instance.

``ReceiverFault.faults``
 A dictionary mapping the Receiver to the exception raised by that Receiver.

Reconciliation
~~~~~~~~~~~~~~

When a `postgresql.copyman.ReceiverFault` is raised, the Manager immediately
removes the Receiver so that the COPY operation can continue. Continuation of
the COPY can occur by trapping the exception and continuing the iteration of
the Manager. However, if the fault is recoverable, the
`postgresql.copyman.CopyManager.reconcile` method must be used to reintroduce
the Receiver into the Manager's set. Faults must be trapped from within the
Manager's context::

	>>> import socket
	>>> from postgresql import copyman
	>>> send_stmt = source.prepare("COPY (SELECT i FROM generate_series(1, 1000000) AS g(i)) TO STDOUT")
	>>> destination.execute("CREATE TEMP TABLE loading_table (i int8)")
	>>> receive_stmt = destination.prepare("COPY loading_table FROM STDIN")
	>>> producer = copyman.StatementProducer(send_stmt)
	>>> receiver = copyman.StatementReceiver(receive_stmt)
	>>>
	>>> with source.xact(), destination.xact():
	...  with copyman.CopyManager(producer, receiver) as copy:
	...   while copy.receivers:
	...    try:
	...     for num_messages, num_bytes in copy:
	...      update_rate(num_bytes)
	...    except copyman.ReceiverFault as cf:
	...     if isinstance(cf.faults[receiver], socket.timeout):
	...      copy.reconcile(receiver)
	...     else:
	...      raise

Recovering from Faults does add significant complexity to a COPY operation,
so, often, it's best to avoid conditions in which reconcilable Faults may
occur.

Producer Faults
---------------

Producer faults are normally fatal to the COPY operation and should rarely be
trapped. However, the Manager makes no state changes when a Producer faults,
so, unlike Receiver Faults, no reconciliation process is necessary; rather, if
it's safe to continue, the Manager's iterator should continue to be processed.
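Continuation might look like the following sketch, where ``recoverable`` is a
hypothetical predicate supplied by the application::

	>>> with copyman.CopyManager(producer, receiver) as copy:
	...  while True:
	...   try:
	...    for num_messages, num_bytes in copy:
	...     update_rate(num_bytes)
	...    break
	...   except copyman.ProducerFault as pf:
	...    # The original exception is available on __context__.
	...    if not recoverable(pf.__context__):
	...     raise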
ProducerFault Properties
~~~~~~~~~~~~~~~~~~~~~~~~

The following attributes exist on `postgresql.copyman.ProducerFault`
instances:

``ProducerFault.manager``
 The subject `postgresql.copyman.CopyManager`.

``ProducerFault.__context__``
 The original exception raised by the Producer.

Failures
========

When a COPY operation is aborted, either by an exception or by the iterator
being broken, a `postgresql.copyman.CopyFail` exception will be raised by the
Manager's ``__exit__()`` method. The `postgresql.copyman.CopyFail` exception
records any exceptions that occur during the exit of the context managers of
the Producer and the Receivers.

CopyFail Properties
-------------------

The following properties exist on `postgresql.copyman.CopyFail` exceptions:

``CopyFail.manager``
 The Manager whose COPY operation failed.

``CopyFail.receiver_faults``
 A dictionary mapping a `postgresql.copyman.Receiver` to the exception raised
 by that Receiver's ``__exit__``. `None` if no exceptions were raised by the
 Receivers.

``CopyFail.producer_fault``
 The exception raised by the `postgresql.copyman.Producer`. `None` if none.

Producers
=========

The following Producers are available:

``postgresql.copyman.StatementProducer(postgresql.api.Statement)``
 Given a Statement producing COPY data, construct a Producer.

``postgresql.copyman.IteratorProducer(collections.Iterator)``
 Given an Iterator producing *chunks* of COPY lines, construct a Producer to
 manage the data coming from the iterator.

Receivers
=========

``postgresql.copyman.StatementReceiver(postgresql.api.Statement)``
 Given a Statement receiving COPY data, construct a Receiver.

``postgresql.copyman.CallReceiver(callable)``
 Given a callable, construct a Receiver that will transmit COPY data in
 chunks of lines. That is, the callable will be given a list of COPY lines
 for each transfer cycle.

Terminology
===========

The following terms are regularly used to describe the implementation and
processes of the `postgresql.copyman` module:

Manager
 The object used to manage data coming from a Producer and being given to the
 Receivers. It also manages the necessary initialization and finalization
 steps required by those factors.

Producer
 The object used to produce the COPY data to be given to the Receivers. The
 source.

Receiver
 An object that consumes COPY data. A target.

Fault
 Specifically, `postgresql.copyman.Fault` exceptions. A Fault is raised when a
 Receiver or a Producer raises an exception during the COPY operation.

Reconciliation
 Generally, the steps performed by the "reconcile" method on
 `postgresql.copyman.CopyManager` instances. More precisely, the necessary
 steps for a Receiver's reintroduction into the COPY operation after a Fault.

Failed Copy
 A failed copy is an aborted COPY operation. This occurs in situations of
 untrapped exceptions or an incomplete COPY. Specifically, the COPY will be
 noted as failed in cases where the Manager's iterator is *not* run until
 exhaustion.

Realignment
 The process of providing compensating data to the Receivers so that the
 connection will be on a message boundary. Occurs when the COPY operation is
 aborted.
fe-1.1.0/postgresql/documentation/driver.rst000066400000000000000000002015561203372773200212170ustar00rootroot00000000000000.. _db_interface:

******
Driver
******

`postgresql.driver` provides a PG-API, `postgresql.api`, interface to a
PostgreSQL server using PQ version 3.0 to facilitate communication.
It makes use of the protocol's extended features to provide binary datatype
transmission and protocol level prepared statements for strongly typed
parameters.

`postgresql.driver` currently supports PostgreSQL servers as far back as 8.0.
Prior versions are not tested. While any version of PostgreSQL supporting
version 3.0 of the PQ protocol *should* work, many features may not work due
to absent functionality in the remote end.

For DB-API 2.0 users, the driver module is located at
`postgresql.driver.dbapi20`. The DB-API 2.0 interface extends PG-API. All of
the features discussed in this chapter are available on DB-API connections.

.. warning::
 PostgreSQL versions 8.1 and earlier do not support standard conforming
 strings. In order to avoid subjective escape methods on connections,
 `postgresql.driver.pq3` enables the ``standard_conforming_strings`` setting
 by default. Greater care must be taken when working with versions that do
 not support standard strings. **The majority of issues surrounding the
 interpolation of properly quoted literals can be easily avoided by using
 parameterized statements**.

The following identifiers are regularly used as shorthands for significant
interface elements:

``db``
 `postgresql.api.Connection`, a database connection. `Connections`_

``ps``
 `postgresql.api.Statement`, a prepared statement. `Prepared Statements`_

``c``
 `postgresql.api.Cursor`, a cursor; the results of a prepared statement.
 `Cursors`_

``C``
 `postgresql.api.Connector`, a connector. `Connectors`_

Establishing a Connection
=========================

There are many ways to establish a `postgresql.api.Connection` to a
PostgreSQL server using `postgresql.driver`. This section discusses those
connection creation interfaces.

`postgresql.open`
-----------------

In the root package module, the ``open()`` function is provided for accessing
databases using a locator string and optional connection keywords. The string
taken by `postgresql.open` is a URL whose components make up the client
parameters::

	>>> db = postgresql.open("pq://localhost/postgres")

This will connect to the host ``localhost`` and to the database named
``postgres`` via the ``pq`` protocol. open will inherit client parameters
from the environment, so the user name given to the server will come from
``$PGUSER``, or if that is unset, the result of `getpass.getuser`--the
username of the user running the process. The user's "pgpassfile" will even
be referenced if no password is given::

	>>> db = postgresql.open("pq://username:password@localhost/postgres")

In this case, the password *is* given, so ``~/.pgpass`` would never be
referenced. The ``user`` client parameter is also given, ``username``, so
``$PGUSER`` or `getpass.getuser` will not be given to the server.

Settings can also be provided by the query portion of the URL::

	>>> db = postgresql.open("pq://user@localhost/postgres?search_path=public&timezone=mst")

The above syntax ultimately passes the query as settings(see the description
of the ``settings`` keyword in `Connection Keywords`). Driver parameters
require a distinction. This distinction is made when the setting's name is
wrapped in square-brackets, '[' and ']'::

	>>> db = postgresql.open("pq://user@localhost/postgres?[sslmode]=require&[connect_timeout]=5")

``sslmode`` and ``connect_timeout`` are driver parameters. These are never
sent to the server, but if they were not in square-brackets, they would be,
and the driver would never identify them as driver parameters.
The general structure of a PQ-locator is:: protocol://user:password@host:port/database?[driver_setting]=value&server_setting=value Optionally, connection keyword arguments can be used to override anything given in the locator:: >>> db = postgresql.open("pq://user:secret@host", password = "thE_real_sekrat") Or, if the locator is not desired, individual keywords can be used exclusively:: >>> db = postgresql.open(user = 'user', host = 'localhost', port = 6543) In fact, all arguments to `postgresql.open` are optional as all arguments are keywords; ``iri`` is merely the first keyword argument taken by `postgresql.open`. If the environment has all the necessary parameters for a successful connection, there is no need to pass anything to open:: >>> db = postgresql.open() For a complete list of keywords that `postgresql.open` can accept, see `Connection Keywords`_. For more information about the environment variables, see :ref:`pg_envvars`. For more information about the ``pgpassfile``, see :ref:`pg_passfile`. `postgresql.driver.connect` --------------------------- `postgresql.open` is a high-level interface to connection creation. It provides password resolution services and client parameter inheritance. For some applications, this is undesirable as such implicit inheritance may lead to failures due to unanticipated parameters being used. For those applications, use of `postgresql.open` is not recommended. Rather, `postgresql.driver.connect` should be used when explicit parameterization is desired by an application: >>> import postgresql.driver as pg_driver >>> db = pg_driver.connect( ... user = 'usename', ... password = 'secret', ... host = 'localhost', ... port = 5432 ... ) This will create a connection to the server listening on port ``5432`` on the host ``localhost`` as the user ``usename`` with the password ``secret``. .. note:: `connect` will *not* inherit parameters from the environment as libpq-based drivers do. See `Connection Keywords`_ for a full list of acceptable keyword parameters and their meaning. Connectors ---------- Connectors are the supporting objects used to instantiate a connection. They exist for the purpose of providing connections with the necessary abstractions for facilitating the client's communication with the server, *and to act as a container for the client parameters*. The latter purpose is of primary interest to this section. Each connection object is associated with its connector by the ``connector`` attribute on the connection. This provides the user with access to the parameters used to establish the connection in the first place, and the means to create another connection to the same server. The attributes on the connector should *not* be altered. If parameter changes are needed, a new connector should be created. The attributes available on a connector are consistent with the names of the connection parameters described in `Connection Keywords`_, so that list can be used as a reference to identify the information available on the connector. Connectors fit into the category of "connection creation interfaces", so connector instantiation normally takes the same parameters that the `postgresql.driver.connect` function takes. .. note:: Connector implementations are specific to the transport, so keyword arguments like ``host`` and ``port`` aren't supported by the ``Unix`` connector. 
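Given an established connection, a sketch of spawning a second connection to
the same server from its connector (``db`` is assumed to be an established
connection here)::

	>>> C = db.connector
	>>> db2 = C()
	>>> db2.connect()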
The driver, `postgresql.driver.default`, provides a set of connectors for
making a connection:

``postgresql.driver.default.host(...)``
 Provides a ``getaddrinfo()`` abstraction for establishing a connection.

``postgresql.driver.default.ip4(...)``
 Connect to a single IPv4 addressed host.

``postgresql.driver.default.ip6(...)``
 Connect to a single IPv6 addressed host.

``postgresql.driver.default.unix(...)``
 Connect to a single unix domain socket. Requires the ``unix`` keyword which
 must be an absolute path to the unix domain socket to connect to.

``host`` is the usual connector used to establish a connection::

	>>> C = postgresql.driver.default.host(
	...  user = 'auser',
	...  host = 'foo.com',
	...  port = 5432)
	>>> # create
	>>> db = C()
	>>> # establish
	>>> db.connect()

If a constant internet address is used, ``ip4`` or ``ip6`` can be used::

	>>> C = postgresql.driver.default.ip4(user='auser', host='127.0.0.1', port=5432)
	>>> db = C()
	>>> db.connect()

Additionally, ``db.__enter__()`` invokes ``db.connect()``, providing
with-statement support:

	>>> with C() as db:
	...  ...

Connectors are constant. They have no knowledge of PostgreSQL service files,
environment variables or LDAP services, so changes made to those facilities
will *not* be reflected in a connector's configuration. If the latest
information from any of these sources is needed, a new connector needs to be
created as the credentials have changed.

.. note::
 ``host`` connectors use ``getaddrinfo()``, so if DNS changes are made, new
 connections *will* use the latest information.

Connection Keywords
-------------------

The following is a list of keywords accepted by connection creation
interfaces:

``user``
 The user to connect as.

``password``
 The user's password.

``database``
 The name of the database to connect to. (PostgreSQL defaults it to `user`)

``host``
 The hostname or IP address to connect to.

``port``
 The port on the host to connect to.

``unix``
 The unix domain socket to connect to. Exclusive with ``host`` and ``port``.
 Expects a string containing the *absolute path* to the unix domain socket to
 connect to.

``settings``
 A dictionary or key-value pair sequence stating the parameters to give to the
 database. These settings are included in the startup packet, and should be
 used carefully; when an invalid setting is given, the connection will fail.

``connect_timeout``
 Amount of time to wait for a connection to be made. (in seconds)

``server_encoding``
 Hint given to the driver to properly encode password data and some
 information in the startup packet. This should only be used in cases where
 connections cannot be made due to authentication failures that occur while
 using known-correct credentials.

``sslmode``
 ``'disable'``
  Don't allow SSL connections.
 ``'allow'``
  Try without SSL first, but if that doesn't work, try with.
 ``'prefer'``
  Try SSL first, then without.
 ``'require'``
  Require an SSL connection.

``sslcrtfile``
 Certificate file path given to `ssl.wrap_socket`.

``sslkeyfile``
 Key file path given to `ssl.wrap_socket`.

``sslrootcrtfile``
 Root certificate file path given to `ssl.wrap_socket`

``sslrootcrlfile``
 Revocation list file path. [Currently not checked.]

Connections
===========

`postgresql.open` and `postgresql.driver.connect` provide the means to
establish a connection. Connections provide a `postgresql.api.Database`
interface to a PostgreSQL server; specifically, a
`postgresql.api.Connection`.

Connections are one-time objects. Once closed or lost, a connection can no
longer be used to interact with the database provided by the server.
If further use of the server is desired, a new connection *must* be established. .. note:: Cannot connect failures, exceptions raised on ``connect()``, are also terminal. In cases where operations are performed on a closed connection, a `postgresql.exceptions.ConnectionDoesNotExistError` will be raised. Database Interface Points ------------------------- After a connection is established:: >>> import postgresql >>> db = postgresql.open(...) The methods and properties on the connection object are ready for use: ``Connection.prepare(sql_statement_string)`` Create a `postgresql.api.Statement` object for querying the database. This provides an "SQL statement template" that can be executed multiple times. See `Prepared Statements`_ for more information. ``Connection.proc(procedure_id)`` Create a `postgresql.api.StoredProcedure` object referring to a stored procedure on the database. The returned object will provide a `collections.Callable` interface to the stored procedure on the server. See `Stored Procedures`_ for more information. ``Connection.statement_from_id(statement_id)`` Create a `postgresql.api.Statement` object from an existing statement identifier. This is used in cases where the statement was prepared on the server. See `Prepared Statements`_ for more information. ``Connection.cursor_from_id(cursor_id)`` Create a `postgresql.api.Cursor` object from an existing cursor identifier. This is used in cases where the cursor was declared on the server. See `Cursors`_ for more information. ``Connection.do(language, source)`` Execute a DO statement on the server using the specified language. *DO statements are available on PostgreSQL 9.0 and greater.* *Executing this method on servers that do not support DO statements will* *likely cause a SyntaxError*. ``Connection.execute(sql_statements_string)`` Run a block of SQL on the server. This method returns `None` unless an error occurs. If errors occur, the processing of the statements will stop and the error will be raised. ``Connection.xact(isolation = None, mode = None)`` The `postgresql.api.Transaction` constructor for creating transactions. This method creates a transaction reference. The transaction will not be started until it's instructed to do so. See `Transactions`_ for more information. ``Connection.settings`` A property providing a `collections.MutableMapping` interface to the database's SQL settings. See `Settings`_ for more information. ``Connection.clone()`` Create a new connection object based on the same factors that were used to create ``db``. The new connection returned will already be connected. ``Connection.msghook(msg)`` By default, the `msghook` attribute does not exist. If set to a callable, any message that occurs during an operation of the database or an operation of a database derived object will be given to the callable. See the `Database Messages`_ section for more information. ``Connection.listen(*channels)`` Start listening for asynchronous notifications in the specified channels. Sends a batch of ``LISTEN`` statements to the server. ``Connection.unlisten(*channels)`` Stop listening for asynchronous notifications in the specified channels. Sends a batch of ``UNLISTEN`` statements to the server. ``Connection.listening_channels()`` Return an iterator producing the channel names that are currently being listened to. ``Connection.notify(*channels, **channel_and_payload)`` NOTIFY the channels with the given payload. Sends a batch of ``NOTIFY`` statements to the server. 
Equivalent to issuing "NOTIFY <channel>" or "NOTIFY <channel>, '<payload>'" for each item in `channels` and `channel_and_payload`. All NOTIFYs issued will occur in the same transaction, regardless of auto-commit. The items in `channels` can either be a string or a tuple. If a string, no payload is given, but if an item is a `builtins.tuple`, the second item in the pair will be given as the payload, and the first as the channel. `channels` offers a means to issue NOTIFYs in guaranteed order:: >>> db.notify('channel1', ('different_channel', 'payload')) In the above, ``NOTIFY "channel1";`` will be issued first, followed by ``NOTIFY "different_channel", 'payload';``. The items in `channel_and_payload` are all payloaded NOTIFYs where the keys are the channels and the values are the payloads. Order is undefined:: >>> db.notify(channel_name = 'payload_data') `channels` and `channel_and_payload` can be used together. In such cases all NOTIFY statements generated from `channel_and_payload` will follow those in `channels`. ``Connection.iternotifies(timeout = None)`` Return an iterator to the NOTIFYs received on the connection. The iterator will yield notification triples consisting of ``(channel, payload, pid)``. While iterating, the connection should *not* be used in other threads. The optional timeout can be used to enable "idle" events in which `None` objects will be yielded by the iterator. See :ref:`notifyman` for details. When a connection is established, certain pieces of information are collected from the backend. The following are the attributes set on the connection object after the connection is made: ``Connection.version`` The version string of the *server*; the result of ``SELECT version()``. ``Connection.version_info`` A ``sys.version_info`` form of the ``server_version`` setting. e.g. ``(8, 1, 2, 'final', 0)``. ``Connection.security`` `None` if no security. ``'ssl'`` if SSL is enabled. ``Connection.backend_id`` The process-id of the backend process. ``Connection.backend_start`` When the backend was started. ``datetime.datetime`` instance. ``Connection.client_address`` The address of the client that the backend is communicating with. ``Connection.client_port`` The port of the client that the backend is communicating with. ``Connection.fileno()`` Method to get the file descriptor number of the connection's socket. This method will return `None` if the socket object does not have a ``fileno``. Under normal circumstances, it will return an `int`. The ``backend_start``, ``client_address``, and ``client_port`` attributes are collected from pg_stat_activity. If this information is unavailable, the attributes will be `None`. Prepared Statements =================== Prepared statements are the primary entry point for initiating an operation on the database. Prepared statement objects represent a request that will, likely, be sent to the database at some point in the future. A statement is a single SQL command. The ``prepare`` entry point on the connection provides the standard method for creating a `postgresql.api.Statement` instance bound to the connection (``db``) from an SQL statement string:: >>> ps = db.prepare("SELECT 1") >>> ps() [(1,)] Statement objects may also be created from a statement identifier using the ``statement_from_id`` method on the connection. When this method is used, the statement must have already been prepared or an error will be raised.
>>> db.execute("PREPARE a_statement_id AS SELECT 1;") >>> ps = db.statement_from_id('a_statement_id') >>> ps() [(1,)] When a statement is executed, it binds any given parameters to a *new* cursor and the entire result-set is returned. Statements created using ``prepare()`` will leverage garbage collection in order to automatically close statements that are no longer referenced. However, statements created from pre-existing identifiers, ``statement_from_id``, must be explicitly closed if the statement is to be discarded. Statement objects are one-time objects. Once closed, they can no longer be used. Statement Interface Points -------------------------- Prepared statements can be executed just like functions: >>> ps = db.prepare("SELECT 'hello, world!'") >>> ps() [('hello, world!',)] The default execution method, ``__call__``, produces the entire result set. It is the simplest form of statement execution. Statement objects can be executed in different ways to accommodate larger results or random access (scrollable cursors). Prepared statement objects have a few execution methods: ``Statement(*parameters)`` As shown before, statement objects can be invoked like a function to get the statement's results. ``Statement.rows(*parameters)`` Return an iterator to all the rows produced by the statement. This method will stream rows on demand, so it is ideal for situations where each individual row in a large result-set must be processed. ``iter(Statement)`` Convenience interface that executes the ``rows()`` method without arguments. This enables the following syntax: >>> for table_name, in db.prepare("SELECT table_name FROM information_schema.tables"): ... print(table_name) ``Statement.column(*parameters)`` Return an iterator to the first column produced by the statement. This method will stream values on demand, and *should* only be used with statements that have a single column; otherwise, bandwidth will ultimately be wasted as the other columns will be dropped. *This execution method cannot be used with COPY statements.* ``Statement.first(*parameters)`` For simple statements, cursor objects are unnecessary. Consider the data produced by ``ps`` above, 'hello, world!'. To get at this data directly from the ``__call__(...)`` method, it looks something like:: >>> ps = db.prepare("SELECT 'hello, world!'") >>> ps()[0][0] 'hello, world!' To simplify access to simple data, the ``first`` method will simply return the "first" of the result set:: >>> ps.first() 'hello, world!' The first value. When the result set consists of a single column, ``first()`` will return that column in the first row. The first row. When the result set consists of multiple columns, ``first()`` will return that first row. The first, and only, row count. When DML--for instance, an INSERT-statement--is executed, ``first()`` will return the row count returned by the statement as an integer. .. note:: DML that returns row data, RETURNING, will *not* return a row count. The result set created by the statement determines what is actually returned. Naturally, a statement used with ``first()`` should be crafted with these rules in mind. ``Statement.chunks(*parameters)`` This access point is designed for situations where rows are being streamed out quickly. It is a method that returns a ``collections.Iterator`` that produces *sequences* of rows. This is the most efficient way to get rows from the database. The rows in the sequences are ``builtins.tuple`` objects.
``Statement.declare(*parameters)`` Create a scrollable cursor with hold. This returns a `postgresql.api.Cursor` ready for accessing random rows in the result-set. Applications that use the database to support paging can use this method to manage the view. ``Statement.close()`` Close the statement inhibiting further use. ``Statement.load_rows(collections.Iterable(parameters))`` Given an iterable producing parameters, execute the statement for each iteration. Always returns `None`. ``Statement.load_chunks(collections.Iterable(collections.Iterable(parameters)))`` Given an iterable of iterables producing parameters, execute the statement for each parameter produced. However, send all the execution commands with the corresponding parameters of each chunk before reading any results. Always returns `None`. This access point is designed to be used in conjunction with ``Statement.chunks()`` for transferring rows from one connection to another with great efficiency:: >>> dst.prepare(...).load_chunks(src.prepare(...).chunks()) ``Statement.clone()`` Create a new statement object based on the same factors that were used to create ``ps``. ``Statement.msghook(msg)`` By default, the `msghook` attribute does not exist. If set to a callable, any message that occurs during an operation of the statement or an operation of a statement derived object will be given to the callable. See the `Database Messages`_ section for more information. In order to provide the appropriate type transformations, the driver must acquire metadata about the statement's parameters and results. This data is published via the following properties on the statement object: ``Statement.sql_parameter_types`` A sequence of SQL type names specifying the types of the parameters used in the statement. ``Statement.sql_column_types`` A sequence of SQL type names specifying the types of the columns produced by the statement. `None` if the statement does not return row-data. ``Statement.pg_parameter_types`` A sequence of PostgreSQL type Oid's specifying the types of the parameters used in the statement. ``Statement.pg_column_types`` A sequence of PostgreSQL type Oid's specifying the types of the columns produced by the statement. `None` if the statement does not return row-data. ``Statement.parameter_types`` A sequence of Python types that the statement expects. ``Statement.column_types`` A sequence of Python types that the statement will produce. ``Statement.column_names`` A sequence of `str` objects specifying the names of the columns produced by the statement. `None` if the statement does not return row-data. The indexes of the parameter sequences correspond to the parameter's identifier, N+1: ``sql_parameter_types[0]`` -> ``'$1'``. >>> ps = db.prepare("SELECT $1::integer AS intname, $2::varchar AS chardata") >>> ps.sql_parameter_types ('INTEGER','VARCHAR') >>> ps.sql_column_types ('INTEGER','VARCHAR') >>> ps.column_names ('intname','chardata') >>> ps.column_types (<class 'int'>, <class 'str'>) Parameterized Statements ------------------------ Statements can take parameters. Using statement parameters is the recommended way to interrogate the database when variable information is needed to formulate a complete request. In order to do this, the statement must be defined using PostgreSQL's positional parameter notation. ``$1``, ``$2``, ``$3``, etc:: >>> ps = db.prepare("SELECT $1") >>> ps('hello, world!')[0][0] 'hello, world!' PostgreSQL determines the type of the parameter based on the context of the parameter's identifier:: >>> ps = db.prepare( ...
"SELECT * FROM information_schema.tables WHERE table_name = $1 LIMIT $2" ... ) >>> ps("tables", 1) [('postgres', 'information_schema', 'tables', 'VIEW', None, None, None, None, None, 'NO', 'NO', None)] Parameter ``$1`` in the above statement will take on the type of the ``table_name`` column and ``$2`` will take on the type required by the LIMIT clause(text and int8). However, parameters can be forced to a specific type using explicit casts: >>> ps = db.prepare("SELECT $1::integer") >>> ps.first(-400) -400 Parameters are typed. PostgreSQL servers provide the driver with the type information about a positional parameter, and the serialization routine will raise an exception if the given object is inappropriate. The Python types expected by the driver for a given SQL-or-PostgreSQL type are listed in `Type Support`_. This usage of types is not always convenient. Notably, the `datetime` module does not provide a friendly way for a user to express intervals, dates, or times. There is a likely inclination to forego these parameter type requirements. In such cases, explicit casts can be made to work-around the type requirements:: >>> ps = db.prepare("SELECT $1::text::date") >>> ps.first('yesterday') datetime.date(2009, 3, 11) The parameter, ``$1``, is given to the database as a string, which is then promptly cast into a date. Of course, without the explicit cast as text, the outcome would be different:: >>> ps = db.prepare("SELECT $1::date") >>> ps.first('yesterday') Traceback: ... postgresql.exceptions.ParameterError The function that processes the parameter expects a `datetime.date` object, and the given `str` object does not provide the necessary interfaces for the conversion, so the driver raises a `postgresql.exceptions.ParameterError` from the original conversion exception. Inserting and DML ----------------- Loading data into the database is facilitated by prepared statements. In these examples, a table definition is necessary for a complete illustration:: >>> db.execute( ... """ ... CREATE TABLE employee ( ... employee_name text, ... employee_salary numeric, ... employee_dob date, ... employee_hire_date date ... ); ... """ ... ) Create an INSERT statement using ``prepare``:: >>> mkemp = db.prepare("INSERT INTO employee VALUES ($1, $2, $3, $4)") And add "Mr. Johnson" to the table:: >>> import datetime >>> r = mkemp( ... "John Johnson", ... "92000", ... datetime.date(1950, 12, 10), ... datetime.date(1998, 4, 23) ... ) >>> print(r[0]) INSERT >>> print(r[1]) 1 The execution of DML will return a tuple. This tuple contains the completed command name and the associated row count. Using the call interface is fine for making a single insert, but when multiple records need to be inserted, it's not the most efficient means to load data. For multiple records, the ``ps.load_rows([...])`` provides an efficient way to load large quantities of structured data:: >>> from datetime import date >>> mkemp.load_rows([ ... ("Jack Johnson", "85000", date(1962, 11, 23), date(1990, 3, 5)), ... ("Debra McGuffer", "52000", date(1973, 3, 4), date(2002, 1, 14)), ... ("Barbara Smith", "86000", date(1965, 2, 24), date(2005, 7, 19)), ... ]) While small, the above illustrates the ``ps.load_rows()`` method taking an iterable of tuples that provides parameters for the each execution of the statement. ``load_rows`` is also used to support ``COPY ... FROM STDIN`` statements:: >>> copy_emps_in = db.prepare("COPY employee FROM STDIN") >>> copy_emps_in.load_rows([ ... b'Emp Name1\t72000\t1970-2-01\t1980-10-22\n', ... 
b'Emp Name2\t62000\t1968-9-11\t1985-11-1\n', ... b'Emp Name3\t62000\t1968-9-11\t1985-11-1\n', ... ]) Copy data goes in as bytes and comes out as bytes regardless of the type of COPY taking place. It is the user's obligation to make sure the row-data is in the appropriate encoding. COPY Statements --------------- `postgresql.driver` transparently supports PostgreSQL's COPY command. To the user, COPY will act exactly like other statements that produce tuples; COPY tuples, however, are `bytes` objects. The only distinction in usability is that the COPY *should* be completed before other actions take place on the connection--this is important when a COPY is invoked via ``rows()`` or ``chunks()``. In situations where other actions are invoked during a ``COPY TO STDOUT``, the entire result set of the COPY will be read. However, no error will be raised so long as there is enough memory available, so it is *very* desirable to avoid doing other actions on the connection while a COPY is active. In situations where other actions are invoked during a ``COPY FROM STDIN``, a COPY failure error will occur. The driver manages the connection state in such a way that will purposefully cause the error as the COPY was inappropriately interrupted. This is not usually a problem as the ``load_rows(...)`` and ``load_chunks(...)`` methods must complete the COPY command before returning. Copy data is always transferred using ``bytes`` objects. Even in cases where the COPY is not in ``BINARY`` mode. Any needed encoding transformations *must* be made by the caller. This is done to avoid any unnecessary overhead by default:: >>> ps = db.prepare("COPY (SELECT i FROM generate_series(0, 99) AS g(i)) TO STDOUT") >>> r = ps() >>> len(r) 100 >>> r[0] b'0\n' >>> r[-1] b'99\n' Of course, invoking a statement that way will read the entire result-set into memory, which is not usually desirable for COPY. Using the ``chunks(...)`` iterator is the *fastest* way to move data:: >>> import sys >>> for rowset in ps.chunks(): ... sys.stdout.buffer.writelines(rowset) ... ``COPY FROM STDIN`` commands are supported via `postgresql.api.Statement.load_rows`. Each invocation of ``load_rows`` is a single invocation of COPY. ``load_rows`` takes an iterable of COPY lines to send to the server:: >>> db.execute(""" ... CREATE TABLE sample_copy ( ... sc_number int, ... sc_text text ... ); ... """) >>> copyin = db.prepare('COPY sample_copy FROM STDIN') >>> copyin.load_rows([ ... b'123\tone twenty three\n', ... b'350\ttree fitty\n', ... ]) For direct connection-to-connection COPY, use of ``load_chunks(...)`` is recommended as it will provide the most efficient transfer method:: >>> copyout = src.prepare('COPY atable TO STDOUT') >>> copyin = dst.prepare('COPY atable FROM STDIN') >>> copyin.load_chunks(copyout.chunks()) Specifically, each chunk of row data produced by ``chunks()`` will be written in full by ``load_chunks()`` before getting another chunk to write. Cursors ======= When a prepared statement is declared, ``ps.declare(...)``, a `postgresql.api.Cursor` is created and returned for random access to the rows in the result set. Direct use of cursors is primarily useful for applications that need to implement paging. For situations that need to iterate over the result set, the ``ps.rows(...)`` or ``ps.chunks(...)`` execution methods should be used.
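As a sketch of the paging pattern (the ``employee`` table from the earlier examples is assumed; ``read`` is detailed in the next section):: >>> all_emps = db.prepare("SELECT * FROM employee ORDER BY employee_name") >>> pager = all_emps.declare() >>> page_one = pager.read(10) >>> page_two = pager.read(10)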
Cursors can also be created directly from ``cursor_id``'s using the ``cursor_from_id`` method on connection objects:: >>> db.execute('DECLARE the_cursor_id CURSOR WITH HOLD FOR SELECT 1;') >>> c = db.cursor_from_id('the_cursor_id') >>> c.read() [(1,)] >>> c.close() .. hint:: If the cursor that needs to be opened is going to be treated as an iterator, then a FETCH statement should be prepared instead of using ``cursor_from_id``. Like statements created from an identifier, cursors created from an identifier must be explicitly closed in order to destroy the object on the server. Likewise, cursors created from statement invocations will be automatically released when they are no longer referenced. .. note:: PG-API cursors are a direct interface to single result-set SQL cursors. This is in contrast with DB-API cursors, which have interfaces for dealing with multiple result-sets. There is no execute method on PG-API cursors. Cursor Interface Points ----------------------- For cursors that return row data, these interfaces are provided for accessing those results: ``Cursor.read(quantity = None, direction = None)`` This method name is borrowed from `file` objects, and is semantically similar. However, this being a cursor, rows are returned instead of bytes or characters. When the number of rows returned is less than the quantity requested, it means that the cursor has been exhausted in the configured direction. The ``direction`` argument can be either ``'FORWARD'`` or `True` to FETCH FORWARD, or ``'BACKWARD'`` or `False` to FETCH BACKWARD. Like ``seek()``, the ``direction`` *property* on the cursor object affects this method. ``Cursor.seek(position[, whence = 0])`` When the cursor is scrollable, this seek interface can be used to move the position of the cursor. See `Scrollable Cursors`_ for more information. ``next(Cursor)`` This fetches the next row in the cursor object. Cursors support the iterator protocol. While equivalent to ``cursor.read(1)[0]``, `StopIteration` is raised if the returned sequence is empty. (``__next__()``) ``Cursor.close()`` For cursors opened using ``cursor_from_id()``, this method must be called in order to ``CLOSE`` the cursor. For cursors created by invoking a prepared statement, this is not necessary as the garbage collection interface will take the appropriate steps. ``Cursor.clone()`` Create a new cursor object based on the same factors that were used to create ``c``. ``Cursor.msghook(msg)`` By default, the `msghook` attribute does not exist. If set to a callable, any message that occurs during an operation of the cursor will be given to the callable. See the `Database Messages`_ section for more information. Cursors have some additional configuration properties that may be modified during the use of the cursor: ``Cursor.direction`` A value of `True`, the default, will cause read to fetch forwards, whereas a value of `False` will cause it to fetch backwards. ``'BACKWARD'`` and ``'FORWARD'`` can be used instead of `False` and `True`. Cursors normally share metadata with the statements that create them, so it is usually unnecessary to reference the cursor's column descriptions directly. However, when a cursor is opened from an identifier, the cursor interface must collect the metadata itself. These attributes provide the metadata in absence of a statement object: ``Cursor.sql_column_types`` A sequence of SQL type names specifying the types of the columns produced by the cursor. `None` if the cursor does not return row-data.
``Cursor.pg_column_types`` A sequence of PostgreSQL type Oid's specifying the types of the columns produced by the cursor. `None` if the cursor does not return row-data. ``Cursor.column_types`` A sequence of Python types that the cursor will produce. ``Cursor.column_names`` A sequence of `str` objects specifying the names of the columns produced by the cursor. `None` if the cursor does not return row-data. ``Cursor.statement`` The statement that was executed that created the cursor. `None` if unknown--``db.cursor_from_id()``. Scrollable Cursors ------------------ Scrollable cursors are supported for applications that need to implement paging. When statements are invoked via the ``declare(...)`` method, the returned cursor is scrollable. .. note:: Scrollable cursors never pre-fetch in order to provide guaranteed positioning. The cursor interface supports scrolling using the ``seek`` method. Like ``read``, it is semantically similar to a file object's ``seek()``. ``seek`` takes two arguments: ``position`` and ``whence``: ``position`` The position to scroll to. The meaning of this is determined by ``whence``. ``whence`` How to use the position: absolute, relative, or absolute from end: absolute: ``'ABSOLUTE'`` or ``0`` (default) seek to the absolute position in the cursor relative to the beginning of the cursor. relative: ``'RELATIVE'`` or ``1`` seek to the relative position. A negative ``position`` will cause a MOVE backwards, while a positive ``position`` will MOVE forwards. from end: ``'FROM_END'`` or ``2`` seek to the end of the cursor and then MOVE backwards by the given ``position``. The ``whence`` keyword argument allows for either numeric or textual specifications. Scrolling through employees:: >>> emps_by_age = db.prepare(""" ... SELECT ... employee_name, employee_salary, employee_dob, employee_hire_date, ... EXTRACT(years FROM AGE(employee_dob)) AS age ... FROM employee ... ORDER BY age ASC ... """) >>> c = emps_by_age.declare() >>> # seek to the end, ``2`` works as well. >>> c.seek(0, 'FROM_END') >>> # scroll back one, ``1`` works as well. >>> c.seek(-1, 'RELATIVE') >>> # and back to the beginning again >>> c.seek(0) Additionally, scrollable cursors support backward fetches by specifying the direction keyword argument:: >>> c.seek(0, 2) >>> c.read(1, 'BACKWARD') Cursor Direction ---------------- The ``direction`` property on the cursor states the default direction for read and seek operations. Normally, the direction is `True`, ``'FORWARD'``. When the property is set to ``'BACKWARD'`` or `False`, the read method will fetch backward by default, and seek operations will be inverted to simulate a reversely ordered cursor. The following example illustrates the effect:: >>> reverse_c = db.prepare('SELECT i FROM generate_series(99, 0, -1) AS g(i)').declare() >>> c = db.prepare('SELECT i FROM generate_series(0, 99) AS g(i)').declare() >>> reverse_c.direction = 'BACKWARD' >>> reverse_c.seek(0) >>> c.read() == reverse_c.read() Furthermore, when the cursor is configured to read backwards, specifying ``'BACKWARD'`` for read's ``direction`` argument will ultimately cause a forward fetch. This potentially confusing facet of direction configuration is implemented in order to create an appropriate symmetry in functionality. The cursors in the above example contain the same rows, but are ultimately in reverse order.
The backward direction property is designed so that the effect of any read or seek operation on those cursors is the same:: >>> reverse_c.seek(50) >>> c.seek(50) >>> c.read(10) == reverse_c.read(10) >>> c.read(10, 'BACKWARD') == reverse_c.read(10, 'BACKWARD') And for relative seeks:: >>> c.seek(-10, 1) >>> reverse_c.seek(-10, 1) >>> c.read(10, 'BACKWARD') == reverse_c.read(10, 'BACKWARD') Rows ==== Rows received from PostgreSQL are instantiated into `postgresql.types.Row` objects. Rows are both a sequence and a mapping. Items accessed with an `int` are seen as indexes and other objects are seen as keys:: >>> row = db.prepare("SELECT 't'::text AS col0, 2::int4 AS col1").first() >>> row ('t', 2) >>> row[0] 't' >>> row["col0"] 't' However, this extra functionality is not free. The cost of instantiating `postgresql.types.Row` objects is quite measurable, so the `chunks()` execution method will produce `builtins.tuple` objects for cases where performance is critical. .. note:: Attributes aren't used to provide access to values due to potential conflicts with existing method and property names. Row Interface Points -------------------- Rows implement the `collections.Mapping` and `collections.Sequence` interfaces. ``Row.keys()`` An iterable producing the column names. Order is not guaranteed. See the ``column_names`` property to get an ordered sequence. ``Row.values()`` Iterable to the values in the row. ``Row.get(key_or_index[, default=None])`` Get the item in the row. If the key doesn't exist or the index is out of range, return the default. ``Row.items()`` Iterable of key-value pairs. Ordered by index. ``iter(Row)`` Iterable to the values in index order. ``value in Row`` Whether or not the value exists in the row. (__contains__) ``Row[key_or_index]`` If ``key_or_index`` is an integer, return the value at that index. If the index is out of range, raise an `IndexError`. Otherwise, return the value associated with the column name. If the given key, ``key_or_index``, does not exist, raise a `KeyError`. ``Row.index_from_key(key)`` Return the index associated with the given key. ``Row.key_from_index(index)`` Return the key associated with the given index. ``Row.transform(*args, **kw)`` Create a new row object of the same length, with the same keys, but with new values produced by applying the given callables to the corresponding items. Callables given as ``args`` will be associated with values by their index and callables given as keywords will be associated with values by their key, column name. While the mapping interfaces will provide most of the needed information, some additional properties are provided for consistency with statement and cursor objects. ``Row.column_names`` Property providing an ordered sequence of column names. The index corresponds to the row value-index that the name refers to. >>> row[row.column_names[i]] == row[i] Row Transformations ------------------- After a row is returned, sometimes the data in the row is not in the desired format. Further processing is needed if the row object is going to be given to another piece of code which requires an object of a differing structure. The ``transform`` method on row objects provides a means to create a new row object consisting of the old row's items, but with certain columns transformed using the given callables:: >>> row = db.prepare(""" ... SELECT ... 'XX9301423'::text AS product_code, ... 2::int4 AS quantity, ... '4.92'::numeric AS total ...
""").first() >>> row ('XX9301423', 2, Decimal("4.92")) >>> row.transform(quantity = str) ('XX9301423', '2', Decimal("4.92")) ``transform`` supports both positional and keyword arguments in order to assign the callable for a column's transformation:: >>> from operator import methodcaller >>> row.transform(methodcaller('strip', 'XX')) ('9301423', 2, Decimal("4.92")) Of course, more than one column can be transformed:: >>> stripxx = methodcaller('strip', 'XX') >>> row.transform(stripxx, str, str) ('9301423', '2', '4.92') `None` can also be used to indicate no transformation:: >>> row.transform(None, str, str) ('XX9301423', '2', '4.92') More advanced usage can make use of lambdas for compound transformations in a single pass of the row:: >>> strip_and_int = lambda x: int(stripxx(x)) >>> row.transform(strip_and_int) (9301423, 2, Decimal("4.92")) Transformations will be, more often than not, applied against *rows* as opposed to *a* row. Using `operator.methodcaller` with `map` provides the necessary functionality to create simple iterables producing transformed row sequences:: >>> import decimal >>> apply_tax = lambda x: (x * decimal.Decimal("0.1")) + x >>> transform_row = methodcaller('transform', strip_and_int, None, apply_tax) >>> r = map(transform_row, [row]) >>> list(r) [(9301423, 2, Decimal('5.412'))] And finally, `functools.partial` can be used to create a simple callable:: >>> from functools import partial >>> transform_rows = partial(map, transform_row) >>> list(transform_rows([row])) [(9301423, 2, Decimal('5.412'))] Queries ======= Queries in `py-postgresql` are single use prepared statements. They exist primarily for syntactic convenience, but they also allow the driver to recognize the short lifetime of the statement. Single use statements are supported using the ``query`` property on connection objects, :py:class:`postgresql.api.Connection.query`. The statement object is not available when using queries as the results, or handle to the results, are directly returned. Queries have access to all execution methods: * ``Connection.query(sql, *parameters)`` * ``Connection.query.rows(sql, *parameters)`` * ``Connection.query.column(sql, *parameters)`` * ``Connection.query.first(sql, *parameters)`` * ``Connection.query.chunks(sql, *parameters)`` * ``Connection.query.declare(sql, *parameters)`` * ``Connection.query.load_rows(sql, collections.Iterable(parameters))`` * ``Connection.query.load_chunks(collections.Iterable(collections.Iterable(parameters)))`` In cases where a sequence of one-shot queries needs to be performed, it may be important to avoid unnecessary repeat attribute resolution from the connection object as the ``query`` property is an interface object created on access. Caching the target execution methods is recommended:: qrows = db.query.rows l = [] for x in my_queries: l.append(qrows(x)) The characteristic of Each execution method is discussed in the prior `Prepared Statements`_ section. Stored Procedures ================= The ``proc`` method on `postgresql.api.Database` objects provides a means to create a reference to a stored procedure on the remote database. `postgresql.api.StoredProcedure` objects are used to represent the referenced SQL routine. This provides a direct interface to functions stored on the database. It leverages knowledge of the parameters and results of the function in order to provide the user with a natural interface to the procedure:: >>> func = db.proc('version()') >>> func() 'PostgreSQL 8.3.6 on ...' 
Stored Procedure Interface Points --------------------------------- It's more-or-less a function, so there's only one interface point: ``func(*args, **kw)`` (``__call__``) Stored procedure objects are callable; executing a procedure will return an object of suitable representation for a given procedure's type signature. If it returns a single object, it will return the single object produced by the procedure. If it's a set returning function, it will return an *iterable* to the values produced by the procedure. In cases of set returning functions with multiple OUT-parameters, a cursor will be returned. Stored Procedure Type Support ----------------------------- Stored procedures support most types of functions. "Function Types" being set returning functions, multiple-OUT parameters, and simple single-object returns. Set-returning functions, SRFs, return a sequence:: >>> generate_series = db.proc('generate_series(int,int)') >>> gs = generate_series(1, 20) >>> gs <generator object ...> >>> next(gs) 1 >>> list(gs) [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20] For functions like ``generate_series()``, the driver is able to identify that the return is a sequence of *solitary* integer objects, so the result of the function is just that, a sequence of integers. Functions returning composite types are recognized, and return row objects:: >>> db.execute(""" ... CREATE FUNCTION composite(OUT i int, OUT t text) ... LANGUAGE SQL AS ... $body$ ... SELECT 900::int AS i, 'sample text'::text AS t; ... $body$; ... """) >>> composite = db.proc('composite()') >>> r = composite() >>> r (900, 'sample text') >>> r['i'] 900 >>> r['t'] 'sample text' Functions returning a set of composites are recognized, and the result is a `postgresql.api.Cursor` object whose column names are consistent with the names of the OUT parameters:: >>> db.execute(""" ... CREATE FUNCTION srfcomposite(out i int, out t text) ... RETURNS SETOF RECORD ... LANGUAGE SQL AS ... $body$ ... SELECT 900::int AS i, 'sample text'::text AS t ... UNION ALL ... SELECT 450::int AS i, 'more sample text'::text AS t ... $body$; ... """) >>> srfcomposite = db.proc('srfcomposite()') >>> r = srfcomposite() >>> next(r) (900, 'sample text') >>> v = next(r) >>> v['i'], v['t'] (450, 'more sample text') Transactions ============ Transactions are managed by creating an object corresponding to a transaction started on the server. A transaction is a transaction block, a savepoint, or a prepared transaction. The ``xact(...)`` method on the connection object provides the standard method for creating a `postgresql.api.Transaction` object to manage a transaction on the connection. The creation of a transaction object does not start the transaction. Rather, the transaction must be explicitly started using the ``start()`` method on the transaction object. Usually, transactions *should* be managed with the context manager interfaces:: >>> with db.xact(): ... ... The transaction in the above example is opened, started, by the ``__enter__`` method invoked by the with-statement's usage. It will be subsequently committed or rolled-back depending on the exception state and the error state of the connection when ``__exit__`` is called. **Using the with-statement syntax for managing transactions is strongly recommended.** By using the transaction's context manager, it allows for Python exceptions to be properly treated as fatal to the transaction as when an uncaught exception of any kind occurs within the block, it is unlikely that the state of the transaction can be trusted.
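For instance, a sketch of an all-or-nothing pair of inserts, reusing the ``mkemp`` statement and ``date`` import from the earlier DML examples (the values are illustrative):: >>> with db.xact(): ... mkemp("Alice Doe", "72000", date(1980, 1, 1), date(2007, 5, 1)) ... mkemp("Mark Doe", "68000", date(1982, 3, 4), date(2008, 6, 2)) If either insert raises an exception, the block exits with that exception and the transaction is aborted; otherwise, both rows are committed together.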
Additionally, the ``__exit__`` method provides a safe-guard against invalid commits. This can occur if a database error is inappropriately caught within a block without being raised. The context manager interfaces are higher level interfaces to the explicit instruction methods provided by `postgresql.api.Transaction` objects. Transaction Configuration ------------------------- Keyword arguments given to ``xact()`` provide the means for configuring the properties of the transaction. Two points of configuration are available: ``isolation`` The isolation level of the transaction. This must be a string. It will be interpolated directly into the START TRANSACTION statement. Normally, 'SERIALIZABLE' or 'READ COMMITTED': >>> with db.xact('SERIALIZABLE'): ... ... ``mode`` A string, 'READ ONLY' or 'READ WRITE'. States the mutability of stored information in the database. Like ``isolation``, this is interpolated directly into the START TRANSACTION string. The specification of any of these transaction properties implies that the transaction is a block. Savepoints do not take configuration, so if a transaction identified as a block is started while another block is running, an exception will be raised. Transaction Interface Points ---------------------------- The methods available on transaction objects manage the state of the transaction and relay any necessary instructions to the remote server in order to reflect that change of state. >>> x = db.xact(...) ``x.start()`` Start the transaction. ``x.commit()`` Commit the transaction. ``x.rollback()`` Abort the transaction. These methods are primarily provided for applications that manage transactions in a way that cannot be formed around single, sequential blocks of code. Generally, using these methods requires additional work to be performed by the code that is managing the transaction. If usage of these direct, instructional methods is necessary, it is important to note that if the database is in an error state when a *transaction block's* commit() is executed, an implicit rollback will occur. The transaction object will simply follow instructions and issue the ``COMMIT`` statement, and it will succeed without exception. Error Control ------------- Handling *database* errors inside transaction CMs is generally discouraged as any database operation that occurs within a failed transaction is an error itself. It is important to trap any recoverable database errors *outside* of the scope of the transaction's context manager: >>> try: ... with db.xact(): ... ... ... except postgresql.exceptions.UniqueError: ... pass In cases where the database is in an error state, but the context exits without an exception, a `postgresql.exceptions.InFailedTransactionError` is raised by the driver: >>> with db.xact(): ... try: ... ... ... except postgresql.exceptions.UniqueError: ... pass ... Traceback (most recent call last): ... postgresql.exceptions.InFailedTransactionError: invalid block exit detected CODE: 25P02 SEVERITY: ERROR Normally, if a ``COMMIT`` is issued on a failed transaction, the command implies a ``ROLLBACK`` without error. This is a very undesirable result for the CM's exit as it may allow for code to be run that presumes the transaction was committed. The driver intervenes here and raises the `postgresql.exceptions.InFailedTransactionError` to safe-guard against such cases. This effect is consistent with savepoint releases that occur during an error state.
The distinction between the two cases is made using the ``source`` property on the raised exception. Settings ======== SQL's SHOW and SET provide a means to configure runtime parameters on the database (GUCs). In order to save the user some grief, a `collections.MutableMapping` interface is provided to simplify configuration. The ``settings`` attribute on the connection provides the interface extension. The standard dictionary interface is supported: >>> db.settings['search_path'] = "$user,public" And ``update(...)`` performs better for multiple sets: >>> db.settings.update({ ... 'search_path' : "$user,public", ... 'default_statistics_target' : "1000" ... }) .. note:: The ``transaction_isolation`` setting cannot be set using the ``settings`` mapping. Internally, ``settings`` uses ``set_config``, which cannot adjust that particular setting. Settings Interface Points ------------------------- Manipulation and interrogation of the connection's settings is achieved by using the standard `collections.MutableMapping` interfaces. ``Connection.settings[k]`` Get the value of a single setting. ``Connection.settings[k] = v`` Set the value of a single setting. ``Connection.settings.update([(k1,v1), (k2,v2), ..., (kn,vn)])`` Set multiple settings using a sequence of key-value pairs. ``Connection.settings.update({k1 : v1, k2 : v2, ..., kn : vn})`` Set multiple settings using a dictionary or mapping object. ``Connection.settings.getset([k1, k2, ..., kn])`` Get a set of settings. This is the most efficient way to get multiple settings as it uses a single request. ``Connection.settings.keys()`` Get all available setting names. ``Connection.settings.values()`` Get all setting values. ``Connection.settings.items()`` Get a sequence of key-value pairs corresponding to all settings on the database. Settings Management ------------------- `postgresql.api.Settings` objects can create context managers when called. This gives the user the ability to specify sections of code that are to be run with certain settings. The settings' context manager takes full advantage of keyword arguments in order to configure the context manager: >>> with db.settings(search_path = 'local,public', timezone = 'mst'): ... ... `postgresql.api.Settings` objects are callable; the return is a context manager configured with the given keyword arguments representing the settings to use for the block of code that is about to be executed. When the block exits, the settings will be restored to the values that they had before the block was entered. Type Support ============ The driver supports a large number of PostgreSQL types at the binary level. Most types are converted to standard Python types. The remaining types are usually PostgreSQL specific types that are converted into objects whose class is defined in `postgresql.types`. When a conversion function is not available for a particular type, the driver will use the string format of the type and instantiate a `str` object for the data. It will also expect `str` data when a parameter of a type without a conversion function is bound. .. note:: Generally, these standard types are provided for convenience. If conversions into these datatypes are not desired, it is recommended that explicit casts into ``text`` are made in the statement string. .. table:: Python types used to represent PostgreSQL types.
================================= ================================== =========== PostgreSQL Types Python Types SQL Types ================================= ================================== =========== `postgresql.types.INT2OID` `int` smallint `postgresql.types.INT4OID` `int` integer `postgresql.types.INT8OID` `int` bigint `postgresql.types.FLOAT4OID` `float` float `postgresql.types.FLOAT8OID` `float` double `postgresql.types.VARCHAROID` `str` varchar `postgresql.types.BPCHAROID` `str` char `postgresql.types.XMLOID` `xml.etree` (cElementTree) xml `postgresql.types.DATEOID` `datetime.date` date `postgresql.types.TIMESTAMPOID` `datetime.datetime` timestamp `postgresql.types.TIMESTAMPTZOID` `datetime.datetime` (tzinfo) timestamptz `postgresql.types.TIMEOID` `datetime.time` time `postgresql.types.TIMETZOID` `datetime.time` timetz `postgresql.types.INTERVALOID` `datetime.timedelta` interval `postgresql.types.NUMERICOID` `decimal.Decimal` numeric `postgresql.types.BYTEAOID` `bytes` bytea `postgresql.types.TEXTOID` `str` text `dict` hstore ================================= ================================== =========== The mapping in the above table *normally* goes both ways. So when a parameter is passed to a statement, the type *should* be consistent with the corresponding Python type. However, many times, for convenience, the object will be passed through the type's constructor, so it is not always necessary. Arrays ------ Arrays of PostgreSQL types are supported with near transparency. For simple arrays, arbitrary iterables can just be given as a statement's parameter and the array's constructor will consume the objects produced by the iterator into a `postgresql.types.Array` instance. However, in situations where the array has multiple dimensions, `list` objects are used to delimit the boundaries of the array. >>> ps = db.prepare("select $1::int[]") >>> ps.first([(1,2), (2,3)]) Traceback: ... postgresql.exceptions.ParameterError In the above case, it is apparent that this array is supposed to have two dimensions. However, this is not the case for other types: >>> ps = db.prepare("select $1::point[]") >>> ps.first([(1,2), (2,3)]) postgresql.types.Array([postgresql.types.point((1.0, 2.0)), postgresql.types.point((2.0, 3.0))]) Lists are used to provide the necessary boundary information: >>> ps = db.prepare("select $1::int[]") >>> ps.first([[1,2],[2,3]]) postgresql.types.Array([[1,2],[2,3]]) The above is the appropriate way to define the array from the original example. .. hint:: The root-iterable object given as an array parameter does not need to be a list-type as it's assumed to be made up of elements. Composites ---------- Composites are supported using `postgresql.types.Row` objects to represent the data. When a composite is referenced for the first time, the driver queries the database for information about the columns that make up the type. This information is then used to create the necessary I/O routines for packing and unpacking the parameters and columns of that type:: >>> db.execute("CREATE TYPE ctest AS (i int, t text, n numeric);") >>> ps = db.prepare("SELECT $1::ctest") >>> i = (100, 'text', "100.02013") >>> r = ps.first(i) >>> r["t"] 'text' >>> r["n"] Decimal("100.02013") Or if use of a dictionary is desired:: >>> r = ps.first({'t' : 'just-the-text'}) >>> r (None, 'just-the-text', None) When a dictionary is given to construct the row, absent values are filled with `None`. .. 
_db_messages: Database Messages ================= By default, py-postgresql gives detailed reports of messages emitted by the database. Often, the verbosity is excessive due to single target processes or existing application infrastructure for tracing the sources of various events. Normally, this verbosity is not a significant problem as the driver defaults the ``client_min_messages`` setting to ``'WARNING'``. However, if ``NOTICE`` or ``INFO`` messages are needed, finer grained control over message propagation may be desired. py-postgresql's object relationship model provides a common protocol for controlling message propagation and, ultimately, display. The ``msghook`` attribute on elements--for instance, Statements, Connections, and Connectors--is absent by default. However, when present on an object that contributed to the cause of a message event, it will be invoked with the Message, `postgresql.message.Message`, object as its sole parameter. The attribute of the object that is closest to the event is checked first; if present, it will be called. If the ``msghook()`` call returns a `True` value (specifically, ``bool(x) is True``), the message will *not* be propagated any further. However, if a `False` value--notably, `None`--is returned, the next element is checked until the list is exhausted and the message is given to `postgresql.sys.msghook`. The normal list of elements is as follows:: Output → Statement → Connection → Connector → Driver → postgresql.sys Where ``Output`` can be a `postgresql.api.Cursor` object produced by ``declare(...)`` or an implicit output management object used *internally* by ``Statement.__call__()`` and other statement execution methods. Setting the ``msghook`` attribute on `postgresql.api.Statement` gives very fine control over raised messages. Consider filtering the notice messages on create table statements that implicitly create indexes:: >>> db = postgresql.open(...) >>> db.settings['client_min_messages'] = 'NOTICE' >>> ct_this = db.prepare('CREATE TEMP TABLE "this" (i int PRIMARY KEY)') >>> ct_that = db.prepare('CREATE TEMP TABLE "that" (i int PRIMARY KEY)') >>> def filter_notices(msg): ... if msg.details['severity'] == 'NOTICE': ... return True ... >>> ct_that() NOTICE: CREATE TABLE / PRIMARY KEY will create implicit index "that_pkey" for table "that" ... ('CREATE TABLE', None) >>> ct_this.msghook = filter_notices >>> ct_this() ('CREATE TABLE', None) >>> The above illustrates the effect of an installed ``msghook`` that simply inhibits further propagation of messages with a severity of 'NOTICE'--but, only notices coming from objects derived from the ``ct_this`` `postgresql.api.Statement` object. Subsequently, if the filter is installed on the connection's ``msghook``:: >>> db = postgresql.open(...) >>> db.settings['client_min_messages'] = 'NOTICE' >>> ct_this = db.prepare('CREATE TEMP TABLE "this" (i int PRIMARY KEY)') >>> ct_that = db.prepare('CREATE TEMP TABLE "that" (i int PRIMARY KEY)') >>> def filter_notices(msg): ... if msg.details['severity'] == 'NOTICE': ... return True ... >>> db.msghook = filter_notices >>> ct_that() ('CREATE TABLE', None) >>> ct_this() ('CREATE TABLE', None) >>> Any message with ``'NOTICE'`` severity coming from the connection, ``db``, will be suppressed by the ``filter_notices`` function. However, if a ``msghook`` is installed on either of those statements, it would be possible for display to occur depending on the implementation of the hook installed on the statement objects.
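A hook need not filter; returning a `False` value lets propagation continue, so a hook can simply observe messages. A sketch of a recording hook (``seen`` is a hypothetical list; the ``message`` and ``details`` attributes are described in the next section):: >>> seen = [] >>> def record_notices(msg): ... # Record, then implicitly return None so the message keeps propagating. ... if msg.details['severity'] == 'NOTICE': ... seen.append(msg.message) ... >>> db.msghook = record_notices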
Message Metadata ---------------- PostgreSQL messages, `postgresql.message.Message`, are primarily described in three parts: the SQL-state code, the main message string, and a mapping containing the details. The following attributes are available on message objects: ``Message.message`` The primary message string. ``Message.code`` The SQL-state code associated with a given message. ``Message.source`` The origins of the message. Normally, ``'SERVER'`` or ``'CLIENT'``. ``Message.location`` A terse, textual representation of ``'file'``, ``'line'``, and ``'function'`` provided by the associated ``details``. ``Message.details`` A mapping providing extended information about a message. This mapping object **can** contain the following keys: ``'severity'`` Any of ``'DEBUG'``, ``'INFO'``, ``'NOTICE'``, ``'WARNING'``, ``'ERROR'``, ``'FATAL'``, or ``'PANIC'``; the latter three are usually associated with a `postgresql.exceptions.Error` instance. ``'context'`` The CONTEXT portion of the message. ``'detail'`` The DETAIL portion of the message. ``'hint'`` The HINT portion of the message. ``'position'`` A number identifying the position in the statement string that caused a parse error. ``'file'`` The name of the file that emitted the message. (*normally* server information) ``'function'`` The name of the function that emitted the message. (*normally* server information) ``'line'`` The line of the file that emitted the message. (*normally* server information) fe-1.1.0/postgresql/documentation/gotchas.rst000066400000000000000000000125571203372773200213570ustar00rootroot00000000000000Gotchas ======= It is recognized that decisions were made that may not always be ideal for a given user. In order to highlight those potential issues and hopefully bring some sense into a confusing situation, this document was drawn up. Non-English Locales ------------------- Many non-English locales are not supported due to the localization of the severity field in messages and errors sent to the client. Internally, py-postgresql uses this to allow client-side filtering of messages and to identify FATAL connection errors so that the client can recognize that it should expect the connection to terminate. Thread Safety ------------- py-postgresql connection operations are not thread safe. `client_encoding` setting should be altered carefully ----------------------------------------------------- `postgresql.driver`'s streaming cursor implementation reads a fixed set of rows when it queries the server for more. In order to optimize some situations, the driver will send a request for more data, but makes no attempt to wait and process the data as it is not yet needed. When the user comes back to read more data from the cursor, it will then look at this new data. The problem being that if `client_encoding` was switched, the driver may use the wrong codec to transform the wire data into higher level Python objects (`str`). To prevent this problem from ever happening, set the `client_encoding` early. Furthermore, it is probably best to never change the `client_encoding` as the driver automatically makes the necessary transformation to Python strings. The user and password are correct, but it does not work when using `postgresql.driver` --------------------------------------------------------------------------------------- This issue likely comes from the possibility that the information sent to the server early in the negotiation phase may not be in an encoding that is consistent with the server's encoding.
One problem is that PostgreSQL does not provide the client with the server encoding early enough in the negotiation phase, and, therefore, the driver is unable to process the password data in a way that is consistent with the server's expectations. Another problem is that PostgreSQL takes much of the data in the startup message as-is, so a decision about the best way to encode parameters is difficult. The easy way to avoid *most* issues with this problem is to initialize the database in the `utf-8` encoding. The driver defaults the expected server encoding to `utf-8`. However, this can be overridden by creating the `Connector` with a `server_encoding` parameter. Setting `server_encoding` to the proper value of the target server will allow the driver to properly encode *some* of the parameters. Also, any GUC parameters passed via the `settings` parameter should use typed objects when possible to hint that the server encoding should not be used on that parameter (`bytes`, for instance). Backslash characters are being treated literally ------------------------------------------------ The driver enables standard compliant strings. Stop using non-standard features. ;) If support for non-standard strings were provided, it would require the driver to provide subjective quote interfaces (e.g., db.quote_literal). Doing so is not desirable as it introduces difficulties for the driver *and* the user. Types without binary support in the driver are unsupported in arrays and records -------------------------------------------------------------------------------- When an array or composite type is identified, `postgresql.protocol.typio` ultimately chooses the binary format for the transfer of the column or parameter. When this is done, PostgreSQL will pack or expect *all* the values in binary format as well. If that binary format is not supported and the type is not a string, it will fail to unpack the row or pack the appropriate data for the element or attribute. In most cases issues related to this can be avoided with explicit casts to text. NOTICEs, WARNINGs, and other messages are too verbose ----------------------------------------------------- For many situations, the information provided with database messages is far too verbose. However, considering that py-postgresql is a programmer's library, the default of high verbosity is taken with the express purpose of allowing the programmer to "adjust the volume" until appropriate. By default, py-postgresql adjusts the ``client_min_messages`` to only emit messages at the WARNING level or higher--ERRORs, FATALs, and PANICs. This reduces the number of messages generated by most connections dramatically. If further customization is needed, the :ref:`db_messages` section has information on overriding the default action taken with database messages. Strange TypeError using load_rows() or load_chunks() ---------------------------------------------------- When a prepared statement is directly executed using ``__call__()``, it can easily validate that the appropriate number of parameters are given to the function. When ``load_rows()`` or ``load_chunks()`` is used, any tuple in the entire sequence can cause this TypeError during the loading process:: TypeError: inconsistent items, N processors and M items in row This exception is raised by a generic processing routine whose functionality is abstract in nature, so the message is abstract as well. It essentially means that a tuple in the sequence given to the loading method had too many or too few items.
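As a sketch of the failure mode, a two-parameter statement fed a three-item tuple (the table is hypothetical):: >>> ins = db.prepare("INSERT INTO emp VALUES ($1, $2)") >>> ins.load_rows([ ... ("John Doe", "150,000"), ... ("Jane Doe", "150,000", "extra"), ... ]) Traceback (most recent call last): ... TypeError: inconsistent items, 2 processors and 3 items in row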
fe-1.1.0/postgresql/documentation/index.rst000066400000000000000000000033471203372773200210350ustar00rootroot00000000000000py-postgresql ============= py-postgresql is a project dedicated to improving the Python client interfaces to PostgreSQL. At its core, py-postgresql provides a PG-API, `postgresql.api`, and DB-API 2.0 interface for using a PostgreSQL database. Contents -------- .. toctree:: :maxdepth: 2 admin driver copyman notifyman alock cluster lib clientparameters gotchas Reference --------- .. toctree:: :maxdepth: 2 bin reference Changes ------- .. toctree:: :maxdepth: 1 changes-v1.1 changes-v1.0 Sample Code ----------- Using `postgresql.driver`:: >>> import postgresql >>> db = postgresql.open("pq://user:password@host/name_of_database") >>> db.execute("CREATE TABLE emp (emp_name text PRIMARY KEY, emp_salary numeric)") >>> >>> # Create the statements. >>> make_emp = db.prepare("INSERT INTO emp VALUES ($1, $2)") >>> raise_emp = db.prepare("UPDATE emp SET emp_salary = emp_salary + $2 WHERE emp_name = $1") >>> get_emp_with_salary_lt = db.prepare("SELECT emp_name FROM emp WHERE emp_salary < $1") >>> >>> # Create some employees, but do it in a transaction--all or nothing. >>> with db.xact(): ... make_emp("John Doe", "150,000") ... make_emp("Jane Doe", "150,000") ... make_emp("Andrew Doe", "55,000") ... make_emp("Susan Doe", "60,000") >>> >>> # Give some raises >>> with db.xact(): ... for row in get_emp_with_salary_lt("125,000"): ... print(row["emp_name"]) ... raise_emp(row["emp_name"], "10,000") Of course, if DB-API 2.0 is desired, the module is located at `postgresql.driver.dbapi20`. DB-API extends PG-API, so the features illustrated above are available on DB-API connections. See :ref:`db_interface` for more information. fe-1.1.0/postgresql/documentation/lib.rst000066400000000000000000000453651203372773200204740ustar00rootroot00000000000000Categories and Libraries ************************ This chapter discusses the usage and implementation of connection categories and libraries. .. note:: First-time users are encouraged to read the `Audience and Motivation`_ section first. Libraries are a collection of SQL statements that can be bound to a connection. Libraries are *normally* bound directly to the connection object as an attribute using a name specified by the library. Libraries provide a common way for SQL statements to be managed outside of the code that uses them. When using ILFs, this increases the portability of the SQL by keeping the statements isolated from the Python code in an accessible format that can be easily used by other languages or systems --- an ILF parser can be implemented within a few dozen lines using basic text tools. SQL statements defined by a Library are identified by their Symbol. These symbols are named and annotated in order to allow the user to define how a statement is to be used. The user may state the default execution method of the statement object, or whether the symbol is to be preloaded at bind time--these properties are Symbol Annotations. The purpose of libraries is to provide a means to manage statements on disk and at runtime. ILFs provide a means to reference a collection of statements on disk, and, when loaded, the symbol bindings provide a means to reference a statement already prepared for use on a given connection. The `postgresql.lib` package-module provides fundamental classes for supporting categories and libraries. Writing Libraries ================= ILF files are the recommended way to build a library.
These files use the naming convention "lib{NAME}.sql". The prefix and suffix are used to describe the purpose of the file and to provide a hint to editors that SQL highlighting should be used. The format of an ILF takes the form:: [name:type:method] ... Where multiple symbols may be defined. The Preface that comes before the first symbol is an arbitrary block of text that should be used to describe the library. This block is free-form, and should be considered a good place for some general documentation. Symbols are named and described using the contents of section markers: ``('[' ... ']')``. Section markers have three components: the symbol name, the symbol type, and the symbol method. The components are separated using a single colon, ``:``. All components are optional except the Symbol name. For example:: [get_user_info] SELECT * FROM user WHERE user_id = $1 [get_user_info_v2::] SELECT * FROM user WHERE user_id = $1 In the above example, ``get_user_info`` and ``get_user_info_v2`` are identical. Empty components indicate the default effect. The second component in the section identifier is the symbol type. All Symbol types are listed in `Symbol Types`_. This can be used to specify what the section's contents are or when to bind the symbol:: [get_user_info:preload] SELECT * FROM user WHERE user_id = $1 This provides the Binding with the knowledge that the statement should be prepared when the Library is bound. Therefore, when this Symbol's statement is used for the first time, it will have already been prepared. Another type is the ``const`` Symbol type. This defines a data Symbol whose *statement results* will be resolved when the Library is bound:: [user_type_ids:const] SELECT user_type_id, user_type FROM user_types; Constant Symbols cannot take parameters as they are data properties. The *result* of the above query is set to the Bindings' ``user_type_ids`` attribute:: >>> db.lib.user_type_ids Where ``lib`` in the above is a Binding of the Library containing the ``user_type_ids`` Symbol. Finally, procedures can be bound as symbols using the ``proc`` type:: [remove_user:proc] remove_user(bigint) All procedure symbols are loaded when the Library is bound. Procedure symbols are special because the execution method is effectively specified by the procedure itself. The third component is the symbol ``method``. This defines the execution method of the statement and ultimately what is returned when the Symbol is called at runtime. All the execution methods are listed in `Symbol Execution Methods`_. The default execution method is the default execution method of `postgresql.api.PreparedStatement` objects: return the entire result set in a list object:: [get_numbers] SELECT i FROM generate_series(0, 100-1) AS g(i); When bound:: >>> db.lib.get_numbers() == [(x,) for x in range(100)] True The transformation of ``range`` in the above is necessary as statements return a sequence of row objects by default. For large result-sets, fetching all the rows would be taxing on a system's memory. The ``rows`` and ``chunks`` methods provide an iterator to rows produced by a statement using a stream:: [get_some_rows::rows] SELECT i FROM generate_series(0, 1000) AS g(i); [get_some_chunks::chunks] SELECT i FROM generate_series(0, 1000) AS g(i); ``rows`` means that the Symbol will return an iterator producing individual rows of the result, and ``chunks`` means that the Symbol will return an iterator producing sequences of rows of the result.
When bound:: >>> from itertools import chain >>> list(db.lib.get_some_rows()) == list(chain.from_iterable(db.lib.get_some_chunks())) True Other methods include ``column`` and ``first``. The column method provides a means to designate that the symbol should return an iterator of the values in the first column instead of an iterator to the rows:: [another_generate_series_example::column] SELECT i FROM generate_series(0, $1::int) AS g(i) In use:: >>> list(db.lib.another_generate_series_example(100-1)) == list(range(100)) True >>> list(db.lib.another_generate_series_example(10-1)) [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] The ``first`` method provides direct access to simple results. Specifically, the first column of the first row when there is only one column. When there are multiple columns, the first row is returned:: [get_one::first] SELECT 1 [get_one_twice::first] SELECT 1, 1 In use:: >>> db.lib.get_one() == 1 True >>> db.lib.get_one_twice() == (1,1) True .. note:: ``first`` should be used with care. When the result returns no rows, `None` will be returned. Using Libraries =============== After a library is created, it must be loaded before it can be bound using programmer interfaces. The `postgresql.lib.load` interface provides the primary entry point for loading libraries. When ``load`` is given a string, it checks whether a directory separator is in the string; if there is, it treats the string as a *path* to the ILF to be loaded. If no separator is found, it treats the string as the library name fragment and looks for "lib{NAME}.sql" in the directories listed in `postgresql.sys.libpath`. Once a `postgresql.lib.Library` instance has been acquired, it can then be bound to a connection for use. `postgresql.lib.Binding` is used to create an object that provides and manages the Bound Symbols:: >>> import postgresql.lib as pg_lib >>> lib = pg_lib.load(...) >>> B = pg_lib.Binding(db, lib) The ``B`` object in the above example provides the Library's Symbols as attributes which can be called in order to execute the Symbol's statement:: >>> B.symbol(param) ... While it is sometimes necessary, manual creation of a Binding is discouraged. Rather, `postgresql.lib.Category` objects should be used to manage the set of Libraries to be bound to a connection. Categories ---------- Libraries provide access to a collection of symbols; Bindings provide an interface to the symbols with respect to a subject database. When a connection is established, multiple Bindings may need to be created in order to fulfill the requirements of the programmer. When a Binding is created, it exists in isolation; this can be an inconvenience when access to both the Binding and the Connection is necessary. Categories exist to provide a formal method for defining the interface extensions on a `postgresql.api.Database` instance (connection). A Category is essentially a runtime-class for connections. It provides a formal initialization procedure for connection objects at runtime. However, the connection resource must be connected prior to category initialization. Categories are sets of Libraries to be bound to a connection with optional name substitutions. In order to create one directly, pass the Library instances to `postgresql.lib.Category`:: >>> import postgresql.lib as pg_lib >>> cat = pg_lib.Category(lib1, lib2, libN) Where ``lib1``, ``lib2``, ``libN`` are `postgresql.lib.Library` instances; usually created by `postgresql.lib.load`.
Once created, categories can then be used by passing the ``category`` keyword to connection creation interfaces:: >>> import postgresql >>> db = postgresql.open(category = cat) The ``db`` object will now have Bindings for ``lib1``, ``lib2``, ..., and ``libN``. Categories can alter the access point (attribute name) of Bindings. This is done by instantiating the Category using keyword parameters:: >>> cat = pg_lib.Category(lib1, lib2, libname = libN) At this point, when a connection is established with the category ``cat``, ``libN`` will be bound to the connection object on the attribute ``libname`` instead of the name defined by the library. And a final illustration of Category usage:: >>> db = postgresql.open(category = pg_lib.Category(pg_lib.load('name'))) >>> db.name Symbol Types ============ The symbol type determines how a symbol is going to be treated by the Binding. For instance, ``const`` symbols are resolved when the Library is bound and the statement object is immediately discarded. Here is a list of symbol types that can be used in ILF libraries: ```` (Empty component) The symbol's statement will never change. This allows the Bound Symbol to hold onto the `postgresql.api.PreparedStatement` object. When the symbol is used again, it will refer to the existing prepared statement object. ``preload`` Like the default type, the Symbol is a simple statement, but it should be loaded when the library is bound to the connection. ``const`` The statement takes no parameters and only needs to be executed once. This will cause the statement to be executed when the library is bound, and the results of the statement will be set on the Binding using the symbol name so that it may be used as a property by the user. ``proc`` The contents of the section is a procedure identifier. When this type is used, the symbol method *should not* be specified as the method annotation will be automatically resolved based on the procedure's signature. ``transient`` The Symbol is a statement that should *not* be retained. Specifically, it is a statement object that will be discarded when the user discards the referenced Symbol. Used in cases where the statement is used once or very infrequently. Symbol Execution Methods ======================== The Symbol Execution Method provides a way to specify how a statement is going to be used. Specifically, which `postgresql.api.PreparedStatement` method should be executed when a Bound Symbol is called. The following is a list of the symbol execution methods and the effect each has when invoked: ```` (Empty component) Returns the entire result set in a single list object. If the statement does not return rows, a ``(command, count)`` pair will be returned. ``rows`` Returns an iterator producing each row in the result set. ``chunks`` Returns an iterator producing "chunks" of rows in the result set. ``first`` Returns the first column of the first row if there is one column in the result set. If there are multiple columns in the result set, the first row is returned. If the query is non-RETURNING DML--insert, update, or delete--the row count is returned. ``column`` Returns an iterator to values in the first column. (Equivalent to executing a statement as ``map(operator.itemgetter(0), ps.rows())``.) ``declare`` Returns a scrollable cursor, `postgresql.api.Cursor`, to the result set. ``load_chunks`` Takes an iterable of row-chunks to be given to the statement. Returns `None`. If the statement is a ``COPY ... FROM STDIN``, the iterable must produce chunks of COPY lines.
``load_rows`` Takes an iterable of rows to be given to the statement as parameters. If the statement is a ``COPY ... FROM STDIN``, the iterable must produce COPY lines. Reference Symbols ================= Reference Symbols provide a way to construct a Bound Symbol using the Symbol's query. When invoked, a Reference Symbol's query is executed in order to produce an SQL statement to be used as a Bound Symbol. In ILF files, a reference is identified by its symbol name being prefixed with an ampersand:: [&refsym::first] SELECT 'SELECT 1::int4'::text Then executed:: >>> # Runs the 'refsym' SQL, and creates a Bound Symbol using the results. >>> sym = lib.refsym() >>> assert sym() == 1 The Reference Symbol's type and execution method are inherited by the created Bound Symbol, with one exception: ``const`` reference symbols are special in that they are immediately resolved into the target Bound Symbol. A Reference Symbol's source query *must* produce rows of text columns. Multiple columns and multiple rows may be produced by the query, but they must be character types as the results are promptly joined together with whitespace so that the target statement may be prepared. Reference Symbols are most likely to be used in dynamic DDL and DML situations, or, somewhat more specifically, any query whose definition depends on a generated column list. Distributing and Usage ====================== For applications, distribution and management can easily be a custom process. The application designates the library directory; the entry point adds the path to the `postgresql.sys.libpath` list; a category is built; and a connection is made using the category. For plain Python extensions, however, ``distutils`` has a feature that can aid in ILF distribution. The ``package_data`` setup keyword can be used to include ILF files alongside the Python modules that make up a project. See http://docs.python.org/3.1/distutils/setupscript.html#installing-package-data for more detailed information on the keyword parameter. The recommended way to manage libraries for extending projects is to create a package to contain them. For instance, consider the following layout:: project/ setup.py pkg/ __init__.py lib/ __init__.py libthis.sql libthat.sql The project's SQL libraries are organized into a single package directory, ``lib``, so ``package_data`` would be configured:: package_data = {'pkg.lib': ['*.sql']} Subsequently, the ``lib`` package initialization script can then be used to load the libraries, and create any categories (``project/pkg/lib/__init__.py``):: import os.path import postgresql.lib as pg_lib import postgresql.sys as pg_sys libdir = os.path.dirname(__file__) pg_sys.libpath.append(libdir) libthis = pg_lib.load('this') libthat = pg_lib.load('that') stdcat = pg_lib.Category(libthis, libthat) However, it can be undesirable to add the package directory to the global `postgresql.sys.libpath` search paths. Direct path loading can be used in those cases:: import os.path import postgresql.lib as pg_lib libdir = os.path.dirname(__file__) libthis = pg_lib.load(os.path.join(libdir, 'libthis.sql')) libthat = pg_lib.load(os.path.join(libdir, 'libthat.sql')) stdcat = pg_lib.Category(libthis, libthat) Using the established project context, a connection would then be created as:: from pkg.lib import stdcat import postgresql as pg db = pg.open(..., category = stdcat) # And execute some fictitious symbols. db.this.sym_from_libthis() db.that.sym_from_libthat(...) Audience and Motivation ======================= This chapter covers advanced material.
It is **not** recommended that categories and libraries be used for trivial applications or introductory projects. .. note:: Libraries and categories are not likely to be of interest to ORM or DB-API users. With the exception of ORMs or other similar abstractions, the most common pattern for managing connections and statements is delegation:: class MyAppDB(object): def __init__(self, connection): self.connection = connection def my_operation(self, op_arg1, op_arg2): return self.connection.prepare( "SELECT my_operation_proc($1,$2)", )(op_arg1, op_arg2) ... The straightforward nature is likeable, but the usage does not take advantage of prepared statements. In order to do that, an extra condition is necessary to see if the statement has already been prepared:: ... def my_operation(self, op_arg1, op_arg2): if hasattr(self, '_my_operation'): ps = self._my_operation else: ps = self._my_operation = self.connection.prepare( "SELECT my_operation_proc($1, $2)", ) return ps(op_arg1, op_arg2) ... There are many variations that can implement the above. It works and it's simple, but it will be exhausting if repeated and error-prone if the initialization condition is not factored out. Additionally, if access to statement metadata is needed, the above example is still lacking as it would require execution of the statement and further protocol expectations to be established. This is the province of libraries: direct database interface management. Categories and Libraries are used to factor out and simplify the above functionality so re-implementation is unnecessary. For example, an ILF library containing the symbol:: [my_operation] SELECT my_operation_proc($1, $2) [] ... will provide the same functionality as the ``my_operation`` method in the latter Python implementation. Terminology =========== The following terms are used throughout this chapter: Annotations The information about a Symbol describing what it is and how it should be used. Binding An interface to the Symbols provided by a Library for use with a given connection. Bound Symbol An interface to an individual Symbol ready for execution against the subject database. Bound Reference An interface to an individual Reference Symbol that will produce a Bound Symbol when executed. ILF INI-style Library Format. "lib{NAME}.sql" files. Library A collection of Symbols--a mapping of names to SQL statements. Local Symbol A relative term used to denote a symbol that exists in the same library as the subject symbol. Preface The block of text that comes before the first symbol in an ILF file. Symbol A named database operation provided by a Library. Usually, an SQL statement with Annotations. Reference Symbol A Symbol whose SQL statement *produces* the source for a Bound Symbol. Category An object supporting a classification for connectors that provides database initialization facilities for produced connections. For libraries, `postgresql.lib.Category` objects are a set of Libraries, `postgresql.lib.Library`. fe-1.1.0/postgresql/documentation/notifyman.rst000066400000000000000000000223641203372773200217300ustar00rootroot00000000000000.. _notifyman: *********************** Notification Management *********************** Relevant SQL commands: `NOTIFY `_, `LISTEN `_, `UNLISTEN `_. Asynchronous notifications offer a means for PostgreSQL to signal application code. Often these notifications are used to signal cache invalidation. In 9.0 and greater, notifications may include a "payload" in which arbitrary data may be delivered on a channel being listened to.
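For example, a payload-bearing notification can be issued with plain SQL; the channel name and payload below are illustrative, and the payload form requires a 9.0 or greater server::

	>>> db.execute("NOTIFY user_cache, 'invalidate:42'")

Connections listening on ``user_cache`` will then receive ``'invalidate:42'`` as the notification's payload.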
By default, received notifications will merely be appended to an internal list on the connection object. This list will remain empty for the duration of a connection *unless* the connection begins listening to a channel that receives notifications. The `postgresql.notifyman.NotificationManager` class is used to wait for messages to come in on a set of connections, pick up the messages, and deliver the messages to the object's user via the `collections.Iterator` protocol. Listening on a Single Connection ================================ The ``db.iternotifies()`` method is a simplification of the notification manager. It returns an iterator to the notifications received on the subject connection. The iterator yields triples consisting of the ``channel`` being notified, the ``payload`` sent with the notification, and the ``pid`` of the backend that caused the notification:: >>> db.listen('for_rabbits') >>> db.notify('for_rabbits') >>> for x in db.iternotifies(): ... channel, payload, pid = x ... break >>> assert channel == 'for_rabbits' >>> assert payload == '' >>> assert pid == db.backend_id The iterator, by default, will continue listening forever unless the connection is terminated--thus the immediate ``break`` statement in the above loop. In cases where some additional activity is necessary, a timeout parameter may be given to the ``iternotifies`` method in order to allow "idle" events to occur at the designated frequency:: >>> for x in db.iternotifies(0.5): ... if x is None: ... break The above example illustrates that idle events are represented using `None` objects. Idle events are guaranteed to occur *approximately* at the specified interval--the ``timeout`` keyword parameter. In addition to providing a means to do other processing or polling, they also offer a safe break point for the loop. Internally, the iterator produced by the ``iternotifies`` method *is* a `NotificationManager`, which will localize the notifications prior to emitting them via the iterator. *It's not safe to break out of the loop, unless an idle event is being handled.* If the loop is broken while a regular event is being processed, some events may remain in the iterator. In order to consume those events, the iterator *must* be accessible. The iterator will be exhausted when the connection is closed, but if the connection is closed during the loop, any remaining notifications *will* be emitted prior to the loop ending, so it is important to be prepared to handle exceptions or check for a closed connection. In situations where multiple connections need to be watched, direct use of the `NotificationManager` is necessary. Listening on Multiple Connections ================================= The `postgresql.notifyman.NotificationManager` class is used to manage *connections* that are expecting to receive notifications. Instances are iterators that yield the connection object and notifications received on the connection, or `None` in the case of an idle event. The manager emits events as a pair: the connection object that received notifications, and *all* the notifications picked up on that connection:: >>> from postgresql.notifyman import NotificationManager >>> # Using ``nm`` to reference the manager from here on. >>> nm = NotificationManager(db1, db2, ..., dbN) >>> nm.settimeout(2) >>> for x in nm: ... if x is None: ... # idle ... break ... ... db, notifies = x ... for channel, payload, pid in notifies: ... ...
The manager will continue to wait for and emit events so long as there are good connections available in the set; it is possible for connections to be added and removed at any time. In rare circumstances, however, a discarded connection may still have pending events if it is not removed during an idle event. The ``connections`` attribute on `NotificationManager` objects is a set object that may be used directly in order to add and remove connections from the manager:: >>> y = [] >>> for x in nm: ... if x is None: ... if y: ... nm.connections.add(y[0]) ... del y[0] ... The notification manager is resilient; if a connection dies, it will discard the connection from the set, and add it to the set of bad connections, the ``garbage`` attribute. In these cases, the idle event *should* be leveraged to check for these failures if that's a concern. It is the user's responsibility to explicitly handle the failure cases, and remove the bad connections from the ``garbage`` set:: >>> for x in nm: ... if x is None: ... if nm.garbage: ... recovered = take_out_trash(nm.garbage) ... nm.connections.update(recovered) ... nm.garbage.clear() ... db, notifies = x ... for channel, payload, pid in notifies: ... ... Explicitly removing connections from the set can also be a means to gracefully terminate the event loop:: >>> for x in nm: ... if x is None: ... if done_listening is True: ... nm.connections.clear() However, doing so inside the loop is not a requirement; it is safe to remove a connection from the set at any point. Notification Managers ===================== The `postgresql.notifyman.NotificationManager` is an event loop that services multiple connections. In cases where only one connection needs to be serviced, the `postgresql.api.Database.iternotifies` method can be used to simplify the process. Notification Manager Constructors --------------------------------- ``NotificationManager(*connections, timeout = None)`` Create a NotificationManager instance that manages the notifications coming from the given set of connections. The ``timeout`` keyword is optional and can be configured using the ``settimeout`` method as well. Notification Manager Interface Points ------------------------------------- ``NotificationManager.__iter__()`` Returns the instance; it is an iterator. ``NotificationManager.__next__()`` Normally, yields the pair, connection and notifications list, when the next event is received. If a timeout is configured, `None` may be yielded to signal an idle event. The notifications list is a list of triples: ``(channel, payload, pid)``. ``NotificationManager.settimeout(timeout : int)`` Set the amount of time to wait before the manager yields an idle event. If zero, the manager will never wait and only yield notifications that are immediately available. If `None`, the manager will never emit idle events. ``NotificationManager.gettimeout() -> [int, None]`` Get the configured timeout; returns either `None`, or an `int`. ``NotificationManager.connections`` The set of connections that the manager is actively watching for notifications. Connections may be added or removed from the set at any time. ``NotificationManager.garbage`` The set of connections that failed. Normally empty, but when a connection gets an exceptional condition or explicitly raises an exception, it is removed from the ``connections`` set, and placed in ``garbage``. Zero Timeout ------------ When a timeout of zero, ``0``, is configured, the notification manager will terminate early.
Specifically, each connection will be polled for any pending notifications, and once all of the collected notifications have been emitted by the iterator, `StopIteration` will be raised. Notably, *no* idle events will occur when the timeout is configured to zero. Zero timeouts offer a means for the notification "queue" to be polled. Often, this is the appropriate way to collect pending notifications on active connections where using the connection exclusively for waiting is not practical:: >>> notifies = list(db.iternotifies(0)) Or with a NotificationManager instance:: >>> nm.settimeout(0) >>> db_notifies = list(nm) In both cases of zero timeout, the iterator may be promptly discarded without losing any events. Summary of Characteristics -------------------------- * The iterator will continue until the connections die. * Objects yielded by the iterator are either `None` (an "idle event"), an individual notification triple if using ``db.iternotifies()``, or a ``(db, notifies)`` pair if using the base `NotificationManager`. * When a connection dies or raises an exception, it will be removed from the ``nm.connections`` set and added to the ``nm.garbage`` set. * The NotificationManager instance will *not* hold any notifications during an idle event. Idle events offer a break point in which the manager may be discarded. * A timeout of zero will cause the iterator to only yield the events that are pending right now, and promptly end. However, the same manager object may be used again. * A notification triple is a tuple consisting of ``(channel, payload, pid)``. * Connections may be added and removed from the ``nm.connections`` set at any time. fe-1.1.0/postgresql/documentation/reference.rst000066400000000000000000000024511203372773200216570ustar00rootroot00000000000000Reference ========= :mod:`postgresql` ----------------- .. automodule:: postgresql .. autodata:: version .. autodata:: version_info .. autofunction:: open :mod:`postgresql.api` --------------------- .. automodule:: postgresql.api :members: :show-inheritance: :mod:`postgresql.sys` --------------------- .. automodule:: postgresql.sys :members: :show-inheritance: :mod:`postgresql.string` ------------------------ .. automodule:: postgresql.string :members: :show-inheritance: :mod:`postgresql.exceptions` ---------------------------- .. automodule:: postgresql.exceptions :members: :show-inheritance: :mod:`postgresql.temporal` -------------------------- .. automodule:: postgresql.temporal :members: :show-inheritance: :mod:`postgresql.installation` ------------------------------ .. automodule:: postgresql.installation :members: :show-inheritance: :mod:`postgresql.cluster` ------------------------- .. automodule:: postgresql.cluster :members: :show-inheritance: :mod:`postgresql.copyman` ------------------------- .. automodule:: postgresql.copyman :members: :show-inheritance: :mod:`postgresql.alock` ----------------------- .. automodule:: postgresql.alock :members: :show-inheritance: fe-1.1.0/postgresql/documentation/sphinx/000077500000000000000000000000001203372773200204765ustar00rootroot00000000000000fe-1.1.0/postgresql/documentation/sphinx/.gitignore000066400000000000000000000000061203372773200224620ustar00rootroot00000000000000*.txt fe-1.1.0/postgresql/documentation/sphinx/build.sh000077500000000000000000000014221203372773200221330ustar00rootroot00000000000000#!/bin/sh cd "$(dirname $0)" # distutils doesn't make it straightforward to include an arbitrary # directory in the package data, so manage .static and .templates here.
mkdir -p .static .templates cat >.static/unsuck.css_t <.templates/layout.html < {% endblock %} EOF mkdir -p ../html/doctrees sphinx-build -c "$(pwd)" -E -b html -d ../html/doctrees .. ../html cd ../html && pwd fe-1.1.0/postgresql/documentation/sphinx/conf.py000066400000000000000000000104541203372773200220010ustar00rootroot00000000000000import sys, os sys.path.insert(0, os.path.abspath('../../..')) # needed for autodoc. sys.dont_write_bytecode = True # read the project info from the PKG.project module. mod = {} with open(os.path.abspath('../../project.py')) as f: exec(f.read(), mod, mod) rst_prolog = "" rst_epilog = "" # General configuration # --------------------- # Add any Sphinx extension module names here, as strings. They can be extensions # coming with Sphinx (named 'sphinx.ext.*') or your custom ones. extensions = ['sphinx.ext.autodoc', 'sphinx.ext.viewcode'] # Add any paths that contain templates here, relative to this directory. templates_path = ['.templates'] # The suffix of source filenames. source_suffix = '.rst' # The master toctree document. master_doc = 'index' # General substitutions. copyright = mod['meaculpa'] # The default replacements for |version| and |release|, also used in various # other places throughout the built documents. # # The short X.Y version. version = '.'.join(map(str, mod['version_info'][:2])) # The full version, including alpha/beta/rc tags. release = mod['version'] project = mod['name'] # There are two options for replacing |today|: either, you set today to some # non-false value, then it is used: #today = '' # Else, today_fmt is used as the format for a strftime call. today_fmt = '%B %d, %Y' # List of documents that shouldn't be included in the build. #unused_docs = [] # List of directories, relative to source directories, that shouldn't be searched # for source files. #exclude_dirs = [] # The reST default role (used for this markup: `text`) to use for all documents. #default_role = None # If true, '()' will be appended to :func: etc. cross-reference text. #add_function_parentheses = True # If true, the current module name will be prepended to all description # unit titles (such as .. function::). #add_module_names = True # If true, sectionauthor and moduleauthor directives will be shown in the # output. They are ignored by default. #show_authors = False # The name of the Pygments (syntax highlighting) style to use. pygments_style = 'sphinx' # Options for HTML output # ----------------------- # The style sheet to use for HTML and HTML Help pages. A file of that name # must exist either in Sphinx' static/ path, or in one of the custom paths # given in html_static_path. html_style = 'default.css' # The name for this set of Sphinx documents. If None, it defaults to # " v documentation". #html_title = None # A shorter title for the navigation bar. Default is the same as html_title. #html_short_title = None # The name of an image file (within the static path) to place at the top of # the sidebar. #html_logo = None # The name of an image file (within the static path) to use as favicon of the # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 # pixels large. #html_favicon = None # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". 
html_static_path = ['.static'] # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, # using the given strftime format. html_last_updated_fmt = '%b %d, %Y' # If true, SmartyPants will be used to convert quotes and dashes to # typographically correct entities. #html_use_smartypants = True # Custom sidebar templates, maps document names to template names. #html_sidebars = {} # Additional templates that should be rendered to pages, maps page names to # template names. #html_additional_pages = {} # If false, no module index is generated. #html_use_modindex = True # If false, no index is generated. #html_use_index = True # If true, the index is split into individual pages for each letter. #html_split_index = False # If true, the reST sources are included in the HTML build as _sources/. #html_copy_source = True # If true, an OpenSearch description file will be output, and all pages will # contain a tag referring to it. The value of this option must be the # base URL from which the finished HTML is served. #html_use_opensearch = '' # If nonempty, this is the file name suffix for HTML files (e.g. ".xhtml"). #html_file_suffix = '' # Output file base name for HTML help builder. htmlhelp_basename = project fe-1.1.0/postgresql/driver/000077500000000000000000000000001203372773200156075ustar00rootroot00000000000000fe-1.1.0/postgresql/driver/__init__.py000066400000000000000000000004451203372773200177230ustar00rootroot00000000000000## # .driver package ## """ Driver package for providing an interface to a PostgreSQL database. """ __all__ = ['connect', 'default'] from .pq3 import Driver default = Driver() def connect(*args, **kw): 'Establish a connection using the default driver.' return default.connect(*args, **kw) fe-1.1.0/postgresql/driver/dbapi20.py000066400000000000000000000232771203372773200174130ustar00rootroot00000000000000## # .driver.dbapi20 - DB-API 2.0 Implementation ## """ DB-API 2.0 conforming interface using postgresql.driver. """ threadsafety = 1 paramstyle = 'pyformat' apilevel = '2.0' from operator import itemgetter from functools import partial import datetime import time import re from .. import clientparameters as pg_param from .. import driver as pg_driver from .. import types as pg_type from .. import string as pg_str from .pq3 import Connection ## # Basically, is it a mapping, or is it a sequence? # If findall()'s first index is 's', it's a sequence. # If it starts with '(', it's a mapping. # The pain here is due to a need to recognize any %% escapes. parameters_re = re.compile( r'(?:%%)+|%(s|[(][^)]*[)]s)' ) def percent_parameters(sql): # filter any %% matches(empty strings).
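	# A sketch of the behavior: percent_parameters("SELECT %s, %(key)s")
	# yields ['s', '(key)s']; the '%%' escapes match with no group and are
	# therefore filtered out here as empty strings.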
return [x for x in parameters_re.findall(sql) if x] def convert_keywords(keys, mapping): return [mapping[k] for k in keys] from postgresql.exceptions import \ Error, DataError, InternalError, \ ICVError as IntegrityError, \ SEARVError as ProgrammingError, \ IRError as OperationalError, \ DriverError as InterfaceError, \ Warning DatabaseError = Error class NotSupportedError(DatabaseError): pass STRING = str BINARY = bytes NUMBER = int DATETIME = datetime.datetime ROWID = int Binary = BINARY Date = datetime.date Time = datetime.time Timestamp = datetime.datetime DateFromTicks = lambda x: Date(*time.localtime(x)[:3]) TimeFromTicks = lambda x: Time(*time.localtime(x)[3:6]) TimestampFromTicks = lambda x: Timestamp(*time.localtime(x)[:6]) def dbapi_type(typid): if typid in ( pg_type.TEXTOID, pg_type.CHAROID, pg_type.VARCHAROID, pg_type.NAMEOID, pg_type.CSTRINGOID, ): return STRING elif typid == pg_type.BYTEAOID: return BINARY elif typid in (pg_type.INT8OID, pg_type.INT2OID, pg_type.INT4OID): return NUMBER elif typid in (pg_type.TIMESTAMPOID, pg_type.TIMESTAMPTZOID): return DATETIME elif typid == pg_type.OIDOID: return ROWID class Portal(object): """ Manages read() interfaces to a chunks iterator. """ def __init__(self, chunks): self.chunks = chunks self.buf = [] self.pos = 0 def __next__(self): try: r = self.buf[self.pos] self.pos += 1 return r except IndexError: # Any alleged infinite recursion will stop on the StopIteration # thrown by this next(). Recursion is unlikely to occur more than # once; specifically, empty chunks would need to be returned # by this invocation of next(). self.buf = next(self.chunks) self.pos = 0 return self.__next__() def readall(self): self.buf = self.buf[self.pos:] self.pos = 0 for x in self.chunks: self.buf.extend(x) r = self.buf self.buf = [] return r def read(self, amount): try: while (len(self.buf) - self.pos) < amount: self.buf.extend(next(self.chunks)) end = self.pos + amount except StopIteration: # end of cursor end = len(self.buf) r = self.buf[self.pos:end] del self.buf[:end] self.pos = 0 return r class Cursor(object): rowcount = -1 arraysize = 1 description = None def __init__(self, C): self.database = self.connection = C self.description = () self.__portals = [] # Describe the "real" cursor as a "portal". # This should keep ambiguous terminology out of the adaptation.
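	# A note on the idiom below: _portal() returns its locals() (fget, fdel),
	# and the property(**_portal()) call that follows turns them into an
	# attribute that raises an Error once the cursor has been closed.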
def _portal(): def fget(self): if self.__portals is None: raise Error("cursor is closed", source = 'CLIENT', creator = self.database ) try: p = self.__portals[0] except IndexError: raise InterfaceError("no portal on stack") return p def fdel(self): if self.__portals is None: raise Error("cursor is closed", source = 'CLIENT', creator = self.database ) try: del self.__portals[0] except IndexError: raise InterfaceError("no portal on stack") return locals() _portal = property(**_portal()) def setinputsizes(self, sizes): if self.__portals is None: raise Error("cursor is closed", source = 'CLIENT', creator = self.database) def setoutputsize(self, sizes, columns = None): if self.__portals is None: raise Error("cursor is closed", source = 'CLIENT', creator = self.database) def callproc(self, proname, args): if self.__portals is None: raise Error("cursor is closed", source = 'CLIENT', creator = self.database) p = self.database.prepare("SELECT %s(%s)" %( proname, ','.join([ '$' + str(x) for x in range(1, len(args) + 1) ]) )) self.__portals.insert(0, Portal(p.chunks(*args))) return args def fetchone(self): try: return next(self._portal) except StopIteration: return None def __next__(self): return next(self._portal) next = __next__ def __iter__(self): return self def fetchmany(self, arraysize = None): return self._portal.read(arraysize or self.arraysize or 1) def fetchall(self): return self._portal.readall() def nextset(self): del self._portal return len(self.__portals) or None def fileno(self): return self.database.fileno() def _convert_query(self, string): parts = list(pg_str.split(string)) style = None count = 0 keys = [] kmap = {} transformer = tuple rparts = [] for part in parts: if part.__class__ is ().__class__: # skip quoted portions rparts.append(part) else: r = percent_parameters(part) pcount = 0 for x in r: if x == 's': pcount += 1 else: x = x[1:-2] if x not in keys: kmap[x] = '$' + str(len(keys) + 1) keys.append(x) if r: if pcount: # format params = tuple([ '$' + str(i+1) for i in range(count, count + pcount) ]) count += pcount rparts.append(part % params) else: # pyformat rparts.append(part % kmap) else: # no parameters identified in string rparts.append(part) if keys: if count: raise TypeError( "keyword parameters and positional parameters used in query" ) transformer = partial(convert_keywords, keys) count = len(keys) return (pg_str.unsplit(rparts) if rparts else string, transformer, count) def execute(self, statement, parameters = ()): if self.__portals is None: raise Error("cursor is closed", source = 'CLIENT', creator = self.database) sql, pxf, nparams = self._convert_query(statement) if nparams != -1 and len(parameters) != nparams: raise TypeError( "statement requires %d parameters, given %d" %( nparams, len(parameters) ) ) ps = self.database.prepare(sql) c = ps.chunks(*pxf(parameters)) if ps._output is not None and len(ps._output) > 0: # name, relationId, columnNumber, typeId, typlen, typmod, format self.rowcount = -1 self.description = tuple([ (self.database.typio.decode(x[0]), dbapi_type(x[3]), None, None, None, None, None) for x in ps._output ]) self.__portals.insert(0, Portal(c)) else: self.rowcount = c.count() if self.rowcount is None: self.rowcount = -1 self.description = None # execute bumps any current portal if self.__portals: del self._portal return self def executemany(self, statement, parameters): if self.__portals is None: raise Error("cursor is closed", source = 'CLIENT', creator = self.database) sql, pxf, nparams = self._convert_query(statement) ps = 
self.database.prepare(sql) if ps._input is not None: ps.load_rows(map(pxf, parameters)) else: ps.load_rows(parameters) self.rowcount = -1 return self def close(self): if self.__portals is None: raise Error("cursor is closed", source = 'CLIENT', creator = self.database) self.description = None self.__portals = None class Connection(Connection): """ DB-API 2.0 connection implementation for PG-API connection objects. """ from postgresql.exceptions import \ Error, DataError, InternalError, \ ICVError as IntegrityError, \ SEARVError as ProgrammingError, \ IRError as OperationalError, \ DriverError as InterfaceError, \ Warning DatabaseError = DatabaseError NotSupportedError = NotSupportedError def autocommit_set(self, val): if val: # already in autocommit mode. if self._xact is None: return self._xact.rollback() self._xact = None else: if self._xact is not None: return self._xact = self.xact() self._xact.start() def autocommit_get(self): return self._xact is None def autocommit_del(self): self.autocommit = False autocommit = property( fget = autocommit_get, fset = autocommit_set, fdel = autocommit_del, ) del autocommit_set, autocommit_get, autocommit_del def connect(self, *args, **kw): super().connect(*args, **kw) self._xact = self.xact() self._xact.start() def close(self): if self.closed: raise Error( "connection already closed", source = 'CLIENT', creator = self ) super().close() def cursor(self): return Cursor(self) def commit(self): if self._xact is None: raise InterfaceError( "commit on connection in autocommit mode", source = 'CLIENT', details = { 'hint': 'The "autocommit" property on the connection was set to True.' }, creator = self ) self._xact.commit() self._xact = self.xact() self._xact.start() def rollback(self): if self._xact is None: raise InterfaceError( "rollback on connection in autocommit mode", source = 'DRIVER', details = { 'hint': 'The "autocommit" property on the connection was set to True.' }, creator = self ) self._xact.rollback() self._xact = self.xact() self._xact.start() driver = pg_driver.Driver(connection = Connection) def connect(**kw): """ Create a DB-API connection using the given parameters. """ std_params = pg_param.collect(prompt_title = None) params = pg_param.normalize( list(pg_param.denormalize_parameters(std_params)) + \ list(pg_param.denormalize_parameters(kw)) ) pg_param.resolve_password(params) return driver.connect(**params) fe-1.1.0/postgresql/driver/pq3.py000066400000000000000000002505011203372773200166670ustar00rootroot00000000000000## # .driver.pq3 - interface to PostgreSQL using PQ v3.0. ## """ PG-API interface for PostgreSQL using PQ version 3.0. """ import os import weakref import socket from traceback import format_exception from itertools import repeat, chain, count from functools import partial from abc import abstractmethod from codecs import lookup as lookup_codecs from operator import itemgetter get0 = itemgetter(0) get1 = itemgetter(1) from .. import lib as pg_lib from .. import versionstring as pg_version from .. import iri as pg_iri from .. import exceptions as pg_exc from .. import string as pg_str from .. import api as pg_api from .. 
import message as pg_msg from ..encodings.aliases import get_python_name from ..string import quote_ident from ..python.itertools import interlace, chunk from ..python.socket import SocketFactory from ..python.functools import process_tuple, process_chunk from ..python.functools import Composition as compose from ..protocol import xact3 as xact from ..protocol import element3 as element from ..protocol import client3 as client from ..protocol.message_types import message_types from ..notifyman import NotificationManager from .. import types as pg_types from ..types import io as pg_types_io from ..types.io import lib as io_lib import warnings # Map element3.Notice field identifiers # to names used by message.Message. notice_field_to_name = { message_types[b'S'[0]] : 'severity', message_types[b'C'[0]] : 'code', message_types[b'M'[0]] : 'message', message_types[b'D'[0]] : 'detail', message_types[b'H'[0]] : 'hint', message_types[b'W'[0]] : 'context', message_types[b'P'[0]] : 'position', message_types[b'p'[0]] : 'internal_position', message_types[b'q'[0]] : 'internal_query', message_types[b'F'[0]] : 'file', message_types[b'L'[0]] : 'line', message_types[b'R'[0]] : 'function', } del message_types notice_field_from_name = dict( (v, k) for (k, v) in notice_field_to_name.items() ) could_not_connect = element.ClientError(( (b'S', 'FATAL'), (b'C', '08001'), (b'M', "could not establish connection to server"), )) # generate an id for a client statement or cursor def ID(s, title = None, IDNS = 'py:'): return IDNS + hex(id(s)) def declare_statement_string( cursor_id, statement_string, insensitive = True, scroll = True, hold = True ): s = 'DECLARE ' + cursor_id if insensitive is True: s += ' INSENSITIVE' if scroll is True: s += ' SCROLL' s += ' CURSOR' if hold is True: s += ' WITH HOLD' else: s += ' WITHOUT HOLD' return s + ' FOR ' + statement_string def direction_str_to_bool(str): s = str.upper() if s == 'FORWARD': return True elif s == 'BACKWARD': return False else: raise ValueError("invalid direction " + repr(str)) def direction_to_bool(v): if isinstance(v, str): return direction_str_to_bool(v) elif v is not True and v is not False: raise TypeError("invalid direction " + repr(v)) else: return v class TypeIO(pg_api.TypeIO): """ A class that manages I/O for a given configuration. Normally, a connection would create an instance, and configure it based upon the version and configuration of PostgreSQL that it is connected to. """ _e_factors = ('database',) strio = (None, None, str) def __init__(self, database): self.database = database self.encoding = None strio = self.strio self._cache = { # Encoded character strings pg_types.ACLITEMOID : strio, # No binary functions. 
pg_types.NAMEOID : strio, pg_types.BPCHAROID : strio, pg_types.VARCHAROID : strio, pg_types.CSTRINGOID : strio, pg_types.TEXTOID : strio, pg_types.REGTYPEOID : strio, pg_types.REGPROCOID : strio, pg_types.REGPROCEDUREOID : strio, pg_types.REGOPEROID : strio, pg_types.REGOPERATOROID : strio, pg_types.REGCLASSOID : strio, } self.typinfo = {} super().__init__() def lookup_type_info(self, typid): return self.database.sys.lookup_type(typid) def lookup_composite_type_info(self, typid): return self.database.sys.lookup_composite(typid) def lookup_domain_basetype(self, typid): if self.database.version_info[:2] >= (8, 4): return self.lookup_domain_basetype_84(typid) while typid: r = self.database.sys.lookup_basetype(typid) if not r[0][0]: return typid else: typid = r[0][0] def lookup_domain_basetype_84(self, typid): r = self.database.sys.lookup_basetype_recursive(typid) return r[0][0] def set_encoding(self, value): """ Set a new client encoding. """ self.encoding = value.lower().strip() enc = get_python_name(self.encoding) ci = lookup_codecs(enc or self.encoding) self._encode, self._decode, *_ = ci def encode(self, string_data): return self._encode(string_data)[0] def decode(self, bytes_data): return self._decode(bytes_data)[0] def encodes(self, iter, get0 = get0): """ Encode the items in the iterable in the configured encoding. """ return map(compose((self._encode, get0)), iter) def decodes(self, iter, get0 = get0): """ Decode the items in the iterable from the configured encoding. """ return map(compose((self._decode, get0)), iter) def resolve_pack(self, typid): return self.resolve(typid)[0] or self.encode def resolve_unpack(self, typid): return self.resolve(typid)[1] or self.decode def attribute_map(self, pq_descriptor): return zip(self.decodes(pq_descriptor.keys()), count()) def sql_type_from_oid(self, oid, qi = quote_ident): if oid in pg_types.oid_to_sql_name: return pg_types.oid_to_sql_name[oid] if oid in self.typinfo: nsp, name, *_ = self.typinfo[oid] return qi(nsp) + '.' + qi(name) name = pg_types.oid_to_name.get(oid) if name: return 'pg_catalog.%s' % name else: return None def type_from_oid(self, oid): if oid in self._cache: typ = self._cache[oid][2] return typ def resolve_descriptor(self, desc, index): 'create a sequence of I/O routines from a pq descriptor' return [ (self.resolve(x[3]) or (None, None))[index] for x in desc ] # lookup a type's IO routines from a given typid def resolve(self, typid : "The Oid of the type to resolve pack and unpack routines for.", from_resolution_of : \ "Sequence of typid's used to identify infinite recursion" = (), builtins : "types.io.resolve" = pg_types_io.resolve, quote_ident = quote_ident ): if from_resolution_of and typid in from_resolution_of: raise TypeError( "type, %d, is already being looked up: %r" %( typid, from_resolution_of ) ) typid = int(typid) typio = None if typid in self._cache: typio = self._cache[typid] else: typio = builtins(typid) if typio is not None: # If typio is a tuple, it's a constant pair: (pack, unpack) # otherwise, it's an I/O pair constructor. if typio.__class__ is not tuple: typio = typio(typid, self) self._cache[typid] = typio if typio is None: # Lookup the type information for the typid as it's not cached. 
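			# The catalog row fetched below distinguishes the cases that
			# follow: composites (typrelid), arrays (typelem/ae_typid),
			# domains (typtype == b'd'), and anonymous records; anything
			# else falls back to text I/O (strio).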
## ti = self.lookup_type_info(typid) if ti is not None: typnamespace, typname, typtype, typlen, typelem, typrelid, \ ae_typid, ae_hasbin_input, ae_hasbin_output = ti self.typinfo[typid] = ( typnamespace, typname, typrelid, int(typelem) if ae_typid else None ) if typrelid: # Row type # # The attribute name map, # column I/O, # column type Oids # are needed to build the packing pair. attmap = {} cio = [] typids = [] attnames = [] i = 0 for x in self.lookup_composite_type_info(typrelid): attmap[x[1]] = i attnames.append(x[1]) if x[2]: # This is a domain fieldtypid = self.lookup_domain_basetype(x[0]) else: fieldtypid = x[0] typids.append(x[0]) te = self.resolve( fieldtypid, list(from_resolution_of) + [typid] ) cio.append((te[0] or self.encode, te[1] or self.decode)) i += 1 self._cache[typid] = typio = self.record_io_factory( cio, typids, attmap, list( map(self.sql_type_from_oid, typids) ), attnames, typrelid, quote_ident(typnamespace) + '.' + \ quote_ident(typname), ) elif ae_typid is not None: # resolve the element type and I/O pair te = self.resolve( int(typelem), from_resolution_of = list(from_resolution_of) + [typid] ) or (None, None) typio = self.array_io_factory( te[0] or self.encode, te[1] or self.decode, typelem, ae_hasbin_input, ae_hasbin_output ) self._cache[typid] = typio else: typio = None if typtype == b'd': basetype = self.lookup_domain_basetype(typid) typio = self.resolve( basetype, from_resolution_of = list(from_resolution_of) + [typid] ) elif typtype == b'p' and typnamespace == 'pg_catalog' and typname == 'record': # anonymous record type typio = self.anon_record_io_factory() if not typio: typio = self.strio self._cache[typid] = typio else: # Throw warning about type without entry in pg_type? typio = self.strio return typio def identify(self, **identity_mappings): """ Explicitly designate the I/O handler for the specified type. Primarily used in cases involving UDTs. """ # get them ordered; we process separately, then recombine. id = list(identity_mappings.items()) ios = [pg_types_io.resolve(x[0]) for x in id] oids = list(self.database.sys.regtypes([x[1] for x in id])) self._cache.update([ (oid, io if io.__class__ is tuple else io(oid, self)) for oid, io in zip(oids, ios) ]) def array_parts(self, array, ArrayType = pg_types.Array): if array.__class__ is not ArrayType: # Assume the data is a nested list. 
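			# e.g., a nested list such as [[1, 2], [3, 4]] is promoted to a
			# pg_types.Array whose dimensions are inferred from the nesting.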
array = ArrayType(array) return ( array.elements(), array.dimensions, array.lowerbounds ) def array_from_parts(self, parts, ArrayType = pg_types.Array): elements, dimensions, lowerbounds = parts return ArrayType.from_elements( elements, lowerbounds = lowerbounds, upperbounds = [x + lb - 1 for x, lb in zip(dimensions, lowerbounds)] ) ## # array_io_factory - build I/O pair for ARRAYs ## def array_io_factory( self, pack_element, unpack_element, typoid, # array element id hasbin_input, hasbin_output, array_pack = io_lib.array_pack, array_unpack = io_lib.array_unpack, ): packed_typoid = io_lib.ulong_pack(typoid) if hasbin_input: def pack_an_array(data, get_parts = self.array_parts): elements, dimensions, lowerbounds = get_parts(data) return array_pack(( 0, # unused flags typoid, dimensions, lowerbounds, (x if x is None else pack_element(x) for x in elements), )) else: # signals string formatting pack_an_array = None if hasbin_output: def unpack_an_array(data, array_from_parts = self.array_from_parts): flags, typoid, dims, lbs, elements = array_unpack(data) return array_from_parts(((x if x is None else unpack_element(x) for x in elements), dims, lbs)) else: # signals string formatting unpack_an_array = None return (pack_an_array, unpack_an_array, pg_types.Array) def RowTypeFactory(self, attribute_map = {}, _Row = pg_types.Row.from_sequence, composite_relid = None): return partial(_Row, attribute_map) ## # record_io_factory - Build an I/O pair for RECORDs ## def record_io_factory(self, column_io : "sequence (pack,unpack) tuples corresponding to the columns", typids : "sequence of type Oids; index must correspond to the composite's", attmap : "mapping of column name to index number", typnames : "sequence of sql type names in order", attnames : "sequence of attribute names in order", composite_relid : "oid of the composite relation", composite_name : "the name of the composite type", get0 = get0, get1 = get1, fmt_errmsg = "failed to {0} attribute {1}, {2}::{3}, of composite {4} from wire data".format ): fpack = tuple(map(get0, column_io)) funpack = tuple(map(get1, column_io)) row_constructor = self.RowTypeFactory(attribute_map = attmap, composite_relid = composite_relid) def raise_pack_tuple_error(cause, procs, tup, itemnum): data = repr(tup[itemnum]) if len(data) > 80: # Be sure not to fill screen with noise. data = data[:75] + ' ...' self.raise_client_error(element.ClientError(( (b'C', '--cIO',), (b'S', 'ERROR',), (b'M', fmt_errmsg('pack', itemnum, attnames[itemnum], typnames[itemnum], composite_name),), (b'W', data,), (b'P', str(itemnum),) )), cause = cause) def raise_unpack_tuple_error(cause, procs, tup, itemnum): data = repr(tup[itemnum]) if len(data) > 80: # Be sure not to fill screen with noise. data = data[:75] + ' ...' 
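			# The ClientError emitted below names the failing attribute and
			# its SQL type; 'W' (context) carries the truncated data, and
			# 'P' (position) carries the attribute index.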
self.raise_client_error(element.ClientError(( (b'C', '--cIO',), (b'S', 'ERROR',), (b'M', fmt_errmsg('unpack', itemnum, attnames[itemnum], typnames[itemnum], composite_name),), (b'W', data,), (b'P', str(itemnum),), )), cause = cause) def unpack_a_record(data, unpack = io_lib.record_unpack, process_tuple = process_tuple, row_constructor = row_constructor ): data = tuple([x[1] for x in unpack(data)]) return row_constructor(process_tuple(funpack, data, raise_unpack_tuple_error)) sorted_atts = sorted(attmap.items(), key = get1) def pack_a_record(data, pack = io_lib.record_pack, process_tuple = process_tuple, ): if isinstance(data, dict): data = [data.get(k) for k,_ in sorted_atts] return pack( tuple(zip( typids, process_tuple(fpack, tuple(data), raise_pack_tuple_error) )) ) return (pack_a_record, unpack_a_record, tuple) def anon_record_io_factory(self): def raise_unpack_tuple_error(cause, procs, tup, itemnum): data = repr(tup[itemnum]) if len(data) > 80: # Be sure not to fill screen with noise. data = data[:75] + ' ...' self.raise_client_error(element.ClientError(( (b'C', '--cIO',), (b'S', 'ERROR',), (b'M', 'Could not unpack element {} from anonymous record'.format(itemnum)), (b'W', data,), (b'P', str(itemnum),) )), cause = cause) def _unpack_record(data, unpack = io_lib.record_unpack, process_tuple = process_tuple): record = list(unpack(data)) coloids = tuple(x[0] for x in record) colio = map(self.resolve, coloids) column_unpack = tuple(c[1] or self.decode for c in colio) data = tuple(x[1] for x in record) return process_tuple(column_unpack, data, raise_unpack_tuple_error) return (None, _unpack_record) def raise_client_error(self, error_message, cause = None, creator = None): m = { notice_field_to_name[k] : v for k, v in error_message.items() # don't include unknown messages in this list. if k in notice_field_to_name } c = m.pop('code') ms = m.pop('message') client_error = self.lookup_exception(c) client_error = client_error(ms, code = c, details = m, source = 'CLIENT', creator = creator or self.database) client_error.database = self.database if cause is not None: raise client_error from cause else: raise client_error def lookup_exception(self, code, errorlookup = pg_exc.ErrorLookup,): return errorlookup(code) def lookup_warning(self, code, warninglookup = pg_exc.WarningLookup,): return warninglookup(code) def raise_server_error(self, error_message, cause = None, creator = None): m = dict(self.decode_notice(error_message)) c = m.pop('code') ms = m.pop('message') server_error = self.lookup_exception(c) server_error = server_error(ms, code = c, details = m, source = 'SERVER', creator = creator or self.database) server_error.database = self.database if cause is not None: raise server_error from cause else: raise server_error def raise_error(self, error_message, ClientError = element.ClientError, **kw): if 'creator' not in kw: kw['creator'] = getattr(self.database, '_controller', self.database) or self.database if error_message.__class__ is ClientError: self.raise_client_error(error_message, **kw) else: self.raise_server_error(error_message, **kw) ## # Used by decode_notice() def _decode_failsafe(self, data): decode = self._decode i = iter(data) for x in i: try: # prematurely optimized for your viewing displeasure. v = x[1] yield (x[0], decode(v)[0]) for x in i: v = x[1] yield (x[0], decode(v)[0]) except UnicodeDecodeError: # Fallback to the bytes representation. 
# This should be sufficiently informative in most cases, # and in the cases where it isn't, an element traceback should # ultimately yield the pertinent information yield (x[0], repr(x[1])[2:-1]) def decode_notice(self, notice): notice = self._decode_failsafe(notice.items()) return { notice_field_to_name[k] : v for k, v in notice # don't include unknown messages in this list. if k in notice_field_to_name } def emit_server_message(self, message, creator = None, MessageType = pg_msg.Message ): fields = self.decode_notice(message) m = fields.pop('message') c = fields.pop('code') if fields['severity'].upper() == 'WARNING': MessageType = self.lookup_warning(c) message = MessageType(m, code = c, details = fields, creator = creator, source = 'SERVER') message.database = self.database message.emit() return message def emit_client_message(self, message, creator = None, MessageType = pg_msg.Message ): fields = { notice_field_to_name[k] : v for k, v in message.items() # don't include unknown messages in this list. if k in notice_field_to_name } m = fields.pop('message') c = fields.pop('code') if fields['severity'].upper() == 'WARNING': MessageType = self.lookup_warning(c) message = MessageType(m, code = c, details = fields, creator = creator, source = 'CLIENT') message.database = self.database message.emit() return message def emit_message(self, message, ClientNotice = element.ClientNotice, **kw): if message.__class__ is ClientNotice: return self.emit_client_message(message, **kw) else: return self.emit_server_message(message, **kw) ## # This class manages all the functionality used to get # rows from a PostgreSQL portal/cursor. class Output(object): _output = None _output_io = None _output_formats = None _output_attmap = None closed = False cursor_id = None statement = None parameters = None _complete_message = None @abstractmethod def _init(self): """ Bind a cursor based on the configured parameters. """ # The local initialization for the specific cursor. def __init__(self, cursor_id, wref = weakref.ref, ID = ID): self.cursor_id = cursor_id if self.statement is not None: stmt = self.statement self._output = stmt._output self._output_io = stmt._output_io self._row_constructor = stmt._row_constructor self._output_formats = stmt._output_formats or () self._output_attmap = stmt._output_attmap self._pq_cursor_id = self.database.typio.encode(cursor_id) # If the cursor's id was generated, it should be garbage collected. 
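# Hedged usage note: generated cursor ids originate from methods like
# Statement.declare(), whereas Connection.cursor_from_id() binds an
# explicitly named, already-existing cursor; only generated ids are
# registered for automatic cleanup here.
#
# >>> c = db.cursor_from_id('existing_cursor')   # name is hypothetical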
if cursor_id == ID(self): self.database.pq.register_cursor(self, self._pq_cursor_id) self._quoted_cursor_id = '"' + cursor_id.replace('"', '""') + '"' self._init() def __iter__(self): return self def close(self): if self.closed is False: self.database.pq.trash_cursor(self._pq_cursor_id) self.closed = True def _ins(self, *args): return xact.Instruction(*args, asynchook = self.database._receive_async) def _pq_xp_describe(self): return (element.DescribePortal(self._pq_cursor_id),) def _pq_xp_bind(self): return ( element.Bind( self._pq_cursor_id, self.statement._pq_statement_id, self.statement._input_formats, self.statement._pq_parameters(self.parameters), self._output_formats, ), ) def _pq_xp_fetchall(self): return ( element.Bind( b'', self.statement._pq_statement_id, self.statement._input_formats, self.statement._pq_parameters(self.parameters), self._output_formats, ), element.Execute(b'', 0xFFFFFFFF), ) def _pq_xp_declare(self): return ( element.Parse(b'', self.database.typio.encode( declare_statement_string( str(self._quoted_cursor_id), str(self.statement.string) ) ), () ), element.Bind( b'', b'', self.statement._input_formats, self.statement._pq_parameters(self.parameters), () ), element.Execute(b'', 1), ) def _pq_xp_execute(self, quantity): return ( element.Execute(self._pq_cursor_id, quantity), ) def _pq_xp_fetch(self, direction, quantity): ## # It's an SQL declared cursor, manually construct the fetch commands. qstr = "FETCH " + ("FORWARD " if direction else "BACKWARD ") if quantity is None: qstr = qstr + "ALL IN " + self._quoted_cursor_id else: qstr = qstr \ + str(quantity) + " IN " + self._quoted_cursor_id return ( element.Parse(b'', self.database.typio.encode(qstr), ()), element.Bind(b'', b'', (), (), self._output_formats), # The "limit" is defined in the fetch query. element.Execute(b'', 0xFFFFFFFF), ) def _pq_xp_move(self, position, whence): return ( element.Parse(b'', b'MOVE ' + whence + b' ' + position + b' IN ' + \ self.database.typio.encode(self._quoted_cursor_id), () ), element.Bind(b'', b'', (), (), ()), element.Execute(b'', 1), ) def _process_copy_chunk(self, x): if x: if x[0].__class__ is not bytes or x[-1].__class__ is not bytes: return [ y for y in x if y.__class__ is bytes ] return x # Process the element.Tuple message in x for column() def _process_tuple_chunk_Column(self, x, range = range): unpack = self._output_io[0] # get the raw data for the first column l = [y[0] for y in x] # iterate over the range to keep track # of which item we're processing. r = range(len(l)) try: return [unpack(l[i]) for i in r] except Exception: cause = sys.exc_info()[1] try: i = next(r) except StopIteration: i = len(l) self._raise_column_tuple_error(cause, self._output_io, (l[i],), 0) # Process the element.Tuple message in x for rows() def _process_tuple_chunk_Row(self, x, proc = process_chunk, ): rc = self._row_constructor return [ rc(y) for y in proc(self._output_io, x, self._raise_column_tuple_error) ] # Process the elemnt.Tuple messages in `x` for chunks() def _process_tuple_chunk(self, x, proc = process_chunk): return proc(self._output_io, x, self._raise_column_tuple_error) def _raise_column_tuple_error(self, cause, procs, tup, itemnum): 'for column processing' # The element traceback will include the full list of parameters. data = repr(tup[itemnum]) if len(data) > 80: # Be sure not to fill screen with noise. data = data[:75] + ' ...' 
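# The single-column fast path above is what Statement.column() (defined
# further below) relies on; a hedged usage sketch:
#
# >>> list(db.prepare("SELECT i FROM generate_series(1, 3) AS g(i)").column())
# [1, 2, 3]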
em = element.ClientError(( (b'S', 'ERROR'), (b'C', "--CIO"), (b'M', "failed to unpack column %r, %s::%s, from wire data" %( itemnum, self.column_names[itemnum], self.database.typio.sql_type_from_oid( self.statement.pg_column_types[itemnum] ) or '', ) ), (b'D', data), (b'H', "Try casting the column to 'text'."), (b'P', str(itemnum)), )) self.database.typio.raise_client_error(em, creator = self, cause = cause) @property def state(self): if self.closed: return 'closed' else: return 'open' @property def column_names(self): if self._output is not None: return list(self.database.typio.decodes(self._output.keys())) # `None` if _output does not exist; not row data @property def column_types(self): if self._output is not None: return [self.database.typio.type_from_oid(x[3]) for x in self._output] # `None` if _output does not exist; not row data @property def pg_column_types(self): if self._output is not None: return [x[3] for x in self._output] # `None` if _output does not exist; not row data @property def sql_column_types(self): return [ self.database.typio.sql_type_from_oid(x) for x in self.pg_column_types ] def command(self): "The completion message's command identifier" if self._complete_message is not None: return self._complete_message.extract_command().decode('ascii') def count(self): "The completion message's count number" if self._complete_message is not None: return self._complete_message.extract_count() class Chunks(Output, pg_api.Chunks): pass ## # FetchAll - A Chunks cursor that gets *all* the records in the cursor. # # It has added complexity over other variants as in order to stream results, # chunks have to be removed from the protocol transaction's received messages. # If this wasn't done, the entire result set would be fully buffered prior # to processing. class FetchAll(Chunks): _e_factors = ('statement', 'parameters',) def _e_metas(self): yield ('type', type(self).__name__) def __init__(self, statement, parameters): self.statement = statement self.parameters = parameters self.database = statement.database Output.__init__(self, '') def _init(self, null = element.Null.type, complete = element.Complete.type, bindcomplete = element.BindComplete.type, parsecomplete = element.ParseComplete.type, ): expect = self._expect self._xact = self._ins( self._pq_xp_fetchall() + (element.SynchronizeMessage,) ) self.database._pq_push(self._xact, self) # Get more messages until the first Tuple is seen. STEP = self.database._pq_step while self._xact.state != xact.Complete: STEP() for x in self._xact.messages_received(): if x.__class__ is tuple or expect == x.type: # No need to step anymore once this is seen. return elif x.type == null: # The protocol transaction is going to be complete.. self.database._pq_complete() self._xact = None return elif x.type == complete: self._complete_message = x self.database._pq_complete() # If this was a select/copy cursor, # the data messages would have caused an earlier # return. It's empty. self._xact = None return elif x.type in (bindcomplete, parsecomplete): # Noise. pass else: # This should have been caught by the protocol transaction. # "Can't happen". 
self.database._pq_complete() if self._xact.fatal is None: self._xact.fatal = False self._xact.error_message = element.ClientError(( (b'S', 'ERROR'), (b'C', "--000"), (b'M', "unexpected message type " + repr(x.type)) )) self.database.typio.raise_client_error(self._xact.error_message, creator = self) return def __next__(self, data_types = (tuple,bytes), complete = element.Complete.type, ): x = self._xact # self._xact = None; means that the cursor has been exhausted. if x is None: raise StopIteration # Finish the protocol transaction. STEP = self.database._pq_step while x.state is not xact.Complete and not x.completed: STEP() # fatal is None == no error # fatal is True == dead connection # fatal is False == dead transaction if x.fatal is not None: self.database.typio.raise_error(x.error_message, creator = getattr(self, '_controller', self) or self) # no messages to process? if not x.completed: # Transaction has been cleaned out of completed? iterator is done. self._xact = None self.close() raise StopIteration # Get the chunk to be processed. chunk = [ y for y in x.completed[0][1] if y.__class__ in data_types ] r = self._process_chunk(chunk) # Scan for _complete_message. # Arguably, this can fail, but it would be a case # where multiple sync messages were issued. Something that's # not naturally occurring. for y in x.completed[0][1][-3:]: if getattr(y, 'type', None) == complete: self._complete_message = y # Remove it, it's been processed. del x.completed[0] return r class SingleXactCopy(FetchAll): _expect = element.CopyToBegin.type _process_chunk = FetchAll._process_copy_chunk class SingleXactFetch(FetchAll): _expect = element.Tuple.type class MultiXactStream(Chunks): chunksize = 1024 * 4 # only tuple streams _process_chunk = Output._process_tuple_chunk def _e_metas(self): yield ('chunksize', self.chunksize) yield ('type', self.__class__.__name__) def __init__(self, statement, parameters, cursor_id): self.statement = statement self.parameters = parameters self.database = statement.database Output.__init__(self, cursor_id or ID(self)) @abstractmethod def _bind(self): """ Generate the commands needed to bind the cursor. """ @abstractmethod def _fetch(self): """ Generate the commands needed to bind the cursor. """ def _init(self): self._command = self._fetch() self._xact = self._ins(self._bind() + self._command) self.database._pq_push(self._xact, self) def __next__(self, tuple_type = tuple): x = self._xact if x is None: raise StopIteration if self.database.pq.xact is x: self.database._pq_complete() # get all the element.Tuple messages chunk = [ y for y in x.messages_received() if y.__class__ is tuple_type ] if len(chunk) == self.chunksize: # there may be more, dispatch the request for the next chunk self._xact = self._ins(self._command) self.database._pq_push(self._xact, self) else: # it's done. self._xact = None self.close() if not chunk: # chunk is empty, it's done *right* now. raise StopIteration chunk = self._process_chunk(chunk) return chunk ## # The cursor is streamed to the client on demand *inside* # a single SQL transaction block. class MultiXactInsideBlock(MultiXactStream): _bind = MultiXactStream._pq_xp_bind def _fetch(self): ## # Use the extended protocol's execute to fetch more. return self._pq_xp_execute(self.chunksize) + \ (element.SynchronizeMessage,) ## # The cursor is streamed to the client on demand *outside* of # a single SQL transaction block. [DECLARE ... 
# WITH HOLD]
class MultiXactOutsideBlock(MultiXactStream):
	_bind = MultiXactStream._pq_xp_declare
	def _fetch(self):
		##
		# Use the extended protocol's execute to fetch more *against*
		# an SQL FETCH statement yielding the data in the proper format.
		#
		# MultiXactOutsideBlock uses DECLARE to create the cursor WITH HOLD.
		# When this is done, the cursor is configured to use StringFormat with
		# all columns. It's necessary to use FETCH to adjust the formatting.
		return self._pq_xp_fetch(True, self.chunksize) + \
			(element.SynchronizeMessage,)

##
# Cursor is used to manage scrollable cursors.
class Cursor(Output, pg_api.Cursor):
	_process_tuple = Output._process_tuple_chunk_Row
	def _e_metas(self):
		yield ('direction', 'FORWARD' if self.direction else 'BACKWARD')
		yield ('type', 'Cursor')
	def clone(self):
		return type(self)(self.statement, self.parameters, self.database, None)
	def __init__(self, statement, parameters, database, cursor_id):
		self.database = database or statement.database
		self.statement = statement
		self.parameters = parameters
		self.__dict__['direction'] = True
		if self.statement is None:
			self._e_factors = ('database', 'cursor_id')
		Output.__init__(self, cursor_id or ID(self))
	def get_direction(self):
		return self.__dict__['direction']
	def set_direction(self, value):
		self.__dict__['direction'] = direction_to_bool(value)
	direction = property(
		fget = get_direction,
		fset = set_direction,
	)
	del get_direction, set_direction
	def _which_way(self, direction):
		if direction is not None:
			direction = direction_to_bool(direction)
			# -1 * -1 = 1, -1 * 1 = -1, 1 * 1 = 1
			return not ((not self.direction) ^ (not direction))
		else:
			return self.direction
	def _init(self,
		tupledesc = element.TupleDescriptor.type,
	):
		"""
		Based on the cursor parameters and the current transaction state,
		select a cursor strategy for managing the response from the server.
""" if self.statement is not None: x = self._ins(self._pq_xp_declare() + (element.SynchronizeMessage,)) self.database._pq_push(x, self) self.database._pq_complete() else: x = self._ins(self._pq_xp_describe() + (element.SynchronizeMessage,)) self.database._pq_push(x, self) self.database._pq_complete() for m in x.messages_received(): if m.type == tupledesc: typio = self.database.typio self._output = m self._output_attmap = typio.attribute_map(self._output) self._row_constructor = typio.RowTypeFactory(self._output_attmap) # tuple output self._output_io = typio.resolve_descriptor( self._output, 1 # (input, output)[1] ) self._output_formats = [ element.StringFormat if x is None else element.BinaryFormat for x in self._output_io ] self._output_io = tuple([ x or typio.decode for x in self._output_io ]) def __next__(self): result = self._fetch(self.direction, 1) if not result: raise StopIteration else: return result[0] def read(self, quantity = None, direction = None): if quantity == 0: return [] dir = self._which_way(direction) return self._fetch(dir, quantity) def _fetch(self, direction, quantity): x = self._ins( self._pq_xp_fetch(direction, quantity) + \ (element.SynchronizeMessage,) ) self.database._pq_push(x, self) self.database._pq_complete() return self._process_tuple(( y for y in x.messages_received() if y.__class__ is tuple )) def seek(self, offset, whence = 'ABSOLUTE'): rwhence = self._seek_whence_map.get(whence, whence) if rwhence is None or rwhence.upper() not in \ self._seek_whence_map.values(): raise TypeError( "unknown whence parameter, %r" %(whence,) ) rwhence = rwhence.upper() if offset == 'ALL': if rwhence not in ('BACKWARD', 'FORWARD'): rwhence = 'BACKWARD' if self.direction is False else 'FORWARD' else: if offset < 0 and rwhence == 'BACKWARD': offset = -offset rwhence = 'FORWARD' if self.direction is False: if offset == 'ALL' and rwhence != 'FORWARD': rwhence = 'BACKWARD' else: if rwhence == 'RELATIVE': offset = -offset elif rwhence == 'ABSOLUTE': rwhence = 'FROM_END' else: rwhence = 'ABSOLUTE' if rwhence in ('RELATIVE', 'BACKWARD', 'FORWARD'): if offset == 'ALL': cmd = self._pq_xp_move( str(offset).encode('ascii'), str(rwhence).encode('ascii') ) else: if offset < 0: cmd = self._pq_xp_move( str(-offset).encode('ascii'), b'BACKWARD' ) else: cmd = self._pq_xp_move( str(offset).encode('ascii'), str(rwhence).encode('ascii') ) elif rwhence == 'ABSOLUTE': cmd = self._pq_xp_move(str(offset).encode('ascii'), b'ABSOLUTE') else: # move to last record, then consume it to put the position at # the very end of the cursor. cmd = self._pq_xp_move(b'', b'LAST') + \ self._pq_xp_move(b'', b'NEXT') + \ self._pq_xp_move(str(offset).encode('ascii'), b'BACKWARD') x = self._ins(cmd + (element.SynchronizeMessage,),) self.database._pq_push(x, self) self.database._pq_complete() count = None complete = element.Complete.type for cm in x.messages_received(): if getattr(cm, 'type', None) == complete: count = cm.extract_count() break # XXX: Raise if count is None? 
return count class SingleExecution(pg_api.Execution): database = None def __init__(self, database): self._prepare = database.prepare def load_rows(self, query, *parameters): return self._prepare(query).load_rows(*parameters) def load_chunks(self, query, *parameters): return self._prepare(query).load_chunks(*parameters) def __call__(self, query, *parameters): return self._prepare(query)(*parameters) def declare(self, query, *parameters): return self._prepare(query).declare(*parameters) def rows(self, query, *parameters): return self._prepare(query).rows(*parameters) def chunks(self, query, *parameters): return self._prepare(query).chunks(*parameters) def column(self, query, *parameters): return self._prepare(query).column(*parameters) def first(self, query, *parameters): return self._prepare(query).first(*parameters) class Statement(pg_api.Statement): string = None database = None statement_id = None _input = None _output = None _output_io = None _output_formats = None _output_attmap = None def _e_metas(self): yield (None, '[' + self.state + ']') if hasattr(self._xact, 'error_message'): # be very careful not to trigger an exception. # even in the cases of effective protocol errors, # it is important not to bomb out. pos = self._xact.error_message.get(b'P') if pos is not None and pos.isdigit(): try: pos = int(pos) # get the statement source q = str(self.string) # normalize position.. pos = len('\n'.join(q[:pos].splitlines())) # normalize newlines q = '\n'.join(q.splitlines()) line_no = q.count('\n', 0, pos) + 1 # replace tabs with spaces because there is no way to identify # the tab size of the final display. (ie, marker will be wrong) q = q.replace('\t', ' ') # grab the relevant part of the query string. # the full source will be printed elsewhere. # beginning of string or the newline before the position bov = q.rfind('\n', 0, pos) + 1 # end of string or the newline after the position eov = q.find('\n', pos) if eov == -1: eov = len(q) view = q[bov:eov] # position relative to the beginning of the view pos = pos-bov # analyze lines prior to position dlines = view.splitlines() marker = ((pos-1) * ' ') + '^' + ( ' [line %d, character %d] ' %(line_no, pos) ) # insert marker dlines.append(marker) yield ('LINE', os.linesep.join(dlines)) except: import traceback yield ('LINE', traceback.format_exc(chain=False)) spt = self.sql_parameter_types if spt is not None: yield ('sql_parameter_types', spt) cn = self.column_names ct = self.sql_column_types if cn is not None: if ct is not None: yield ( 'results', '(' + ', '.join([ '{!r} {!r}'.format(n, t) for n,t in zip(cn,ct) ]) + ')' ) else: yield ('sql_column_names', cn) elif ct is not None: yield ('sql_column_types', ct) def clone(self): ps = self.__class__(self.database, None, self.string) ps._init() ps._fini() return ps def __init__(self, database, statement_id, string, wref = weakref.ref ): self.database = database self.string = string self.statement_id = statement_id or ID(self) self._xact = None self.closed = None self._pq_statement_id = database.typio._encode(self.statement_id)[0] if not statement_id: # Register statement on a connection to close it automatically on db end database.pq.register_statement(self, self._pq_statement_id) def __repr__(self): return '<{mod}.{name}[{ci}] {state}>'.format( mod = self.__class__.__module__, name = self.__class__.__name__, ci = self.database.connector._pq_iri, state = self.state, ) def _pq_parameters(self, parameters, proc = process_tuple): return proc( self._input_io, parameters, self._raise_parameter_tuple_error 
) ## # process_tuple failed(exception). The parameters could not be packed. # This function is called with the given information in the context # of the original exception(to allow chaining). def _raise_parameter_tuple_error(self, cause, procs, tup, itemnum): # Find the SQL type name. This should *not* hit the server. typ = self.database.typio.sql_type_from_oid( self.pg_parameter_types[itemnum] ) or '' # Representation of the bad parameter. bad_data = repr(tup[itemnum]) if len(bad_data) > 80: # Be sure not to fill screen with noise. bad_data = bad_data[:75] + ' ...' em = element.ClientError(( (b'S', 'ERROR'), (b'C', '--PIO'), (b'M', "could not pack parameter %s::%s for transfer" %( ('$' + str(itemnum + 1)), typ, ) ), (b'D', bad_data), (b'H', "Try casting the parameter to 'text', then to the target type."), (b'P', str(itemnum)) )) self.database.typio.raise_client_error(em, creator = self, cause = cause) ## # Similar to the parameter variant. def _raise_column_tuple_error(self, cause, procs, tup, itemnum): # Find the SQL type name. This should *not* hit the server. typ = self.database.typio.sql_type_from_oid( self.pg_column_types[itemnum] ) or '' # Representation of the bad column. data = repr(tup[itemnum]) if len(data) > 80: # Be sure not to fill screen with noise. data = data[:75] + ' ...' em = element.ClientError(( (b'S', 'ERROR'), (b'C', '--CIO'), (b'M', "could not unpack column %r, %s::%s, from wire data" %( itemnum, self.column_names[itemnum], typ ) ), (b'D', data), (b'H', "Try casting the column to 'text'."), (b'P', str(itemnum)), )) self.database.typio.raise_client_error(em, creator = self, cause = cause) @property def state(self) -> str: if self.closed: if self._xact is not None: if self.string is not None: return 'parsing' else: return 'describing' return 'closed' return 'prepared' @property def column_names(self): if self.closed is None: self._fini() if self._output is not None: return list(self.database.typio.decodes(self._output.keys())) @property def parameter_types(self): if self.closed is None: self._fini() if self._input is not None: return [self.database.typio.type_from_oid(x) for x in self._input] @property def column_types(self): if self.closed is None: self._fini() if self._output is not None: return [ self.database.typio.type_from_oid(x[3]) for x in self._output ] @property def pg_parameter_types(self): if self.closed is None: self._fini() return self._input @property def pg_column_types(self): if self.closed is None: self._fini() if self._output is not None: return [x[3] for x in self._output] @property def sql_column_types(self): if self.closed is None: self._fini() if self._output is not None: return [ self.database.typio.sql_type_from_oid(x) for x in self.pg_column_types ] @property def sql_parameter_types(self): if self.closed is None: self._fini() if self._input is not None: return [ self.database.typio.sql_type_from_oid(x) for x in self.pg_parameter_types ] def close(self): if self.closed is False: self.database.pq.trash_statement(self._pq_statement_id) self.closed = True def _init(self): """ Push initialization messages to the server, but don't wait for the return as there may be things that can be done while waiting for the return. Use the _fini() to complete. 
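	As used by Connection.prepare(), the two phases run back to back.
	A hedged sketch, with `db` assumed to be an established connection:

		>>> ps = Statement(db, None, "SELECT 1")
		>>> ps._init()
		>>> ps._fini()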
""" if self.string is not None: q = self.database.typio._encode(str(self.string))[0] cmd = [ element.CloseStatement(self._pq_statement_id), element.Parse(self._pq_statement_id, q, ()), ] else: cmd = [] cmd.extend( ( element.DescribeStatement(self._pq_statement_id), element.SynchronizeMessage, ) ) self._xact = xact.Instruction(cmd, asynchook = self.database._receive_async) self.database._pq_push(self._xact, self) def _fini(self, strfmt = element.StringFormat, binfmt = element.BinaryFormat): """ Complete initialization that the _init() method started. """ # assume that the transaction has been primed. if self._xact is None: raise RuntimeError("_fini called prior to _init; invalid state") if self._xact is self.database.pq.xact: try: self.database._pq_complete() except Exception: self.closed = True raise (*head, argtypes, tupdesc, last) = self._xact.messages_received() typio = self.database.typio if tupdesc is None or tupdesc is element.NoDataMessage: # Not typed output. self._output = None self._output_attmap = None self._output_io = None self._output_formats = None self._row_constructor = None else: self._output = tupdesc self._output_attmap = dict( typio.attribute_map(tupdesc) ) self._row_constructor = self.database.typio.RowTypeFactory(self._output_attmap) # tuple output self._output_io = typio.resolve_descriptor(tupdesc, 1) self._output_formats = [ strfmt if x is None else binfmt for x in self._output_io ] self._output_io = tuple([ x or typio.decode for x in self._output_io ]) self._input = argtypes packs = [] formats = [] for x in argtypes: pack = (typio.resolve(x) or (None,None))[0] packs.append(pack or typio.encode) formats.append( strfmt if x is None else binfmt ) self._input_io = tuple(packs) self._input_formats = formats self.closed = False self._xact = None def __call__(self, *parameters): if self._input is not None: if len(parameters) != len(self._input): raise TypeError("statement requires %d parameters, given %d" %( len(self._input), len(parameters) )) ## # get em' all! if self._output is None: # might be a copy. c = SingleXactCopy(self, parameters) else: c = SingleXactFetch(self, parameters) c._process_chunk = c._process_tuple_chunk_Row # iff output is None, it's not a tuple returning query. # however, if it's a copy, detect that fact by SingleXactCopy's # immediate return after finding the copy begin message(no complete). if self._output is None: cmd = c.command() if cmd is not None: return (cmd, c.count()) # Returns rows, accumulate in a list. r = [] for x in c: r.extend(x) return r def declare(self, *parameters): if self.closed is None: self._fini() if self._input is not None: if len(parameters) != len(self._input): raise TypeError("statement requires %d parameters, given %d" %( len(self._input), len(parameters) )) return Cursor(self, parameters, self.database, None) def rows(self, *parameters, **kw): chunks = self.chunks(*parameters, **kw) if chunks._output_io: chunks._process_chunk = chunks._process_tuple_chunk_Row return chain.from_iterable(chunks) __iter__ = rows def chunks(self, *parameters): if self.closed is None: self._fini() if self._input is not None: if len(parameters) != len(self._input): raise TypeError("statement requires %d parameters, given %d" %( len(self._input), len(parameters) )) if self._output is None: # It's *probably* a COPY. return SingleXactCopy(self, parameters) if self.database.pq.state == b'I': # Currently, *not* in a Transaction block, so # DECLARE the statement WITH HOLD in order to allow # access across transactions. 
if self.string is not None: return MultiXactOutsideBlock(self, parameters, None) else: ## # Statement source unknown, so it can't be DECLARE'd. # This happens when statement_from_id is used. return SingleXactFetch(self, parameters) else: # Likely, the best possible case. It gets to use Execute messages. return MultiXactInsideBlock(self, parameters, None) def column(self, *parameters, **kw): chunks = self.chunks(*parameters, **kw) chunks._process_chunk = chunks._process_tuple_chunk_Column return chain.from_iterable(chunks) def first(self, *parameters): if self.closed is None: # Not fully initialized; assume interrupted. self._fini() if self._input is not None: # Use a regular TypeError. if len(parameters) != len(self._input): raise TypeError("statement requires %d parameters, given %d" %( len(self._input), len(parameters) )) # Parameters? Build em'. db = self.database if self._input_io: params = process_tuple( self._input_io, parameters, self._raise_parameter_tuple_error ) else: params = () # Run the statement x = xact.Instruction(( element.Bind( b'', self._pq_statement_id, self._input_formats, params, self._output_formats or (), ), # Get all element.Execute(b'', 0xFFFFFFFF), element.ClosePortal(b''), element.SynchronizeMessage ), asynchook = db._receive_async ) # Push and complete protocol transaction. db._pq_push(x, self) db._pq_complete() if self._output_io: ## # It returned rows, look for the first tuple. tuple_type = element.Tuple.type for xt in x.messages_received(): if xt.__class__ is tuple: break else: return None if len(self._output_io) > 1: # Multiple columns, return a Row. return self._row_constructor( process_tuple( self._output_io, xt, self._raise_column_tuple_error ) ) else: # Single column output. if xt[0] is None: return None io = self._output_io[0] or self.database.typio.decode return io(xt[0]) else: ## # It doesn't return rows, so return a count. ## # This loop searches through the received messages # for the Complete message which contains the count. complete = element.Complete.type for cm in x.messages_received(): # Use getattr because COPY doesn't produce # element.Message instances. if getattr(cm, 'type', None) == complete: break else: # Probably a Null command. return None count = cm.extract_count() if count is None: command = cm.extract_command() if command is not None: return command.decode('ascii') return count def _load_copy_chunks(self, chunks, *parameters): """ Given an chunks of COPY lines, execute the COPY ... FROM STDIN statement and send the copy lines produced by the iterable to the remote end. """ x = xact.Instruction(( element.Bind( b'', self._pq_statement_id, (), (), (), ), element.Execute(b'', 1), element.SynchronizeMessage, ), asynchook = self.database._receive_async ) self.database._pq_push(x, self) # localize step = self.database._pq_step # Get the COPY started. while x.state is not xact.Complete: step() if hasattr(x, 'CopyFailSequence') and x.messages is x.CopyFailSequence: # The protocol transaction has noticed that its a COPY. break else: # Oh, it's not a COPY at all. x.fatal = x.fatal or False x.error_message = element.ClientError(( (b'S', 'ERROR'), # OperationError (b'C', '--OPE'), (b'M', "_load_copy_chunks() used on a non-COPY FROM STDIN query"), )) self.database.typio.raise_client_error(x.error_message, creator = self) for chunk in chunks: x.messages = list(chunk) while x.messages is not x.CopyFailSequence: # Continue stepping until the transaction # sets the CopyFailSequence again. 
That's # the signal that the transaction has sent # all the previously set messages. step() x.messages = x.CopyDoneSequence self.database._pq_complete() self.database.pq.synchronize() def _load_tuple_chunks(self, chunks): pte = self._raise_parameter_tuple_error last = (element.SynchronizeMessage,) try: for chunk in chunks: bindings = [ ( element.Bind( b'', self._pq_statement_id, self._input_formats, process_tuple( self._input_io, tuple(t), pte ), (), ), element.Execute(b'', 1), ) for t in chunk ] bindings.append(last) self.database._pq_push( xact.Instruction( chain.from_iterable(bindings), asynchook = self.database._receive_async ), self ) self.database._pq_complete() except: ## # In cases where row packing errors or occur, # synchronize, finishing any pending transaction, # and raise the error. ## # If the data sent to the remote end is invalid, # _complete will raise the exception and the current # exception being marked as the cause, so there should # be no [exception] information loss. ## self.database.pq.synchronize() raise def load_chunks(self, chunks, *parameters): """ Execute the query for each row-parameter set in `iterable`. In cases of ``COPY ... FROM STDIN``, iterable must be an iterable of sequences of `bytes`. """ if self.closed is None: self._fini() if not self._input or parameters: return self._load_copy_chunks(chunks) else: return self._load_tuple_chunks(chunks) def load_rows(self, rows, chunksize = 256): return self.load_chunks(chunk(rows, chunksize)) PreparedStatement = Statement class StoredProcedure(pg_api.StoredProcedure): _e_factors = ('database', 'procedure_id') procedure_id = None def _e_metas(self): yield ('oid', self.oid) def __repr__(self): return '<%s:%s>' %( self.procedure_id, self.statement.string ) def __call__(self, *args, **kw): if kw: input = [] argiter = iter(args) try: word_idx = [(kw[k], self._input_attmap[k]) for k in kw] except KeyError as k: raise TypeError("%s got unexpected keyword argument %r" %( self.name, k.message ) ) word_idx.sort(key = get1) current_word = word_idx.pop(0) for x in range(argc): if x == current_word[1]: input.append(current_word[0]) current_word = word_idx.pop(0) else: input.append(argiter.next()) else: input = args if self.srf is True: if self.composite is True: return self.statement.rows(*input) else: # A generator expression is very appropriate here # as SRFs returning large number of rows would require # substantial amounts of memory. return map(get0, self.statement.rows(*input)) else: if self.composite is True: return self.statement(*input)[0] else: return self.statement(*input)[0][0] def __init__(self, ident, database, description = ()): # Lookup pg_proc on database. if isinstance(ident, int): proctup = database.sys.lookup_procedure_oid(int(ident)) else: proctup = database.sys.lookup_procedure_rp(str(ident)) if proctup is None: raise LookupError("no function with identifier %s" %(str(ident),)) self.procedure_id = ident self.oid = proctup[0] self.name = proctup["proname"] self._input_attmap = {} argnames = proctup.get('proargnames') or () for x in range(len(argnames)): an = argnames[x] if an is not None: self._input_attmap[an] = x proargs = proctup['proargtypes'] for x in proargs: # get metadata filled out. database.typio.resolve(x) self.statement = database.prepare( "SELECT * FROM %s(%s) AS func%s" %( proctup['_proid'], # ($1::type, $2::type, ... 
$n::type) ', '.join([ '$%d::%s' %(x + 1, database.typio.sql_type_from_oid(proargs[x])) for x in range(len(proargs)) ]), # Description for anonymous record returns (description and \ '(' + ','.join(description) + ')' or '') ) ) self.srf = bool(proctup.get("proretset")) self.composite = proctup["composite"] class SettingsCM(object): def __init__(self, database, settings_to_set): self.database = database self.settings_to_set = settings_to_set def __enter__(self): if hasattr(self, 'stored_settings'): raise RuntimeError("cannot re-use setting CMs") self.stored_settings = self.database.settings.getset( self.settings_to_set.keys() ) self.database.settings.update(self.settings_to_set) def __exit__(self, typ, val, tb): self.database.settings.update(self.stored_settings) class Settings(pg_api.Settings): _e_factors = ('database',) def __init__(self, database): self.database = database self.cache = {} def _e_metas(self): yield (None, str(len(self.cache))) def _clear_cache(self): self.cache.clear() def __getitem__(self, i): v = self.cache.get(i) if v is None: r = self.database.sys.setting_get(i) if r: v = r[0][0] else: raise KeyError(i) return v def __setitem__(self, i, v): cv = self.cache.get(i) if cv == v: return setas = self.database.sys.setting_set(i, v) self.cache[i] = setas def __delitem__(self, k): self.database.execute( 'RESET "' + k.replace('"', '""') + '"' ) self.cache.pop(k, None) def __len__(self): return self.database.sys.setting_len() def __call__(self, **settings): return SettingsCM(self.database, settings) def path(): def fget(self): return pg_str.split_ident(self["search_path"]) def fset(self, value): self['search_path'] = ','.join([ '"%s"' %(x.replace('"', '""'),) for x in value ]) def fdel(self): if self.database.connector.path is not None: self.path = self.database.connector.path else: self.database.execute("RESET search_path") doc = 'structured search_path interface' return locals() path = property(**path()) def get(self, k, alt = None): if k in self.cache: return self.cache[k] db = self.database r = self.database.sys.setting_get(k) if r: v = r[0][0] self.cache[k] = v else: v = alt return v def getset(self, keys): setmap = {} rkeys = [] for k in keys: v = self.cache.get(k) if v is not None: setmap[k] = v else: rkeys.append(k) if rkeys: r = self.database.sys.setting_mget(rkeys) self.cache.update(r) setmap.update(r) rem = set(rkeys) - set([x['name'] for x in r]) if rem: raise KeyError(rem) return setmap def keys(self): return map(get0, self.database.sys.setting_keys()) __iter__ = keys def values(self): return map(get0, self.database.sys.setting_values()) def items(self): return self.database.sys.setting_items() def update(self, d): kvl = [list(x) for x in dict(d).items()] self.cache.update(self.database.sys.setting_update(kvl)) def _notify(self, msg): subs = getattr(self, '_subscriptions', {}) d = self.database.typio._decode key = d(msg.name)[0] val = d(msg.value)[0] for x in subs.get(key, ()): x(self.database, key, val) if None in subs: for x in subs[None]: x(self.database, key, val) self.cache[key] = val def subscribe(self, key, callback): """ Subscribe to changes of the setting using the callback. When the setting is changed, the callback will be invoked with the connection, the key, and the new value. If the old value is locally cached, its value will still be available for inspection, but there is no guarantee. If `None` is passed as the key, the callback will be called whenever any setting is remotely changed. >>> def watch(connection, key, newval): ... 
>>> db.settings.subscribe('TimeZone', watch) """ subs = self._subscriptions = getattr(self, '_subscriptions', {}) callbacks = subs.setdefault(key, []) if callback not in callbacks: callbacks.append(callback) def unsubscribe(self, key, callback): """ Stop listening for changes to a setting. The setting name(`key`), and the callback used to subscribe must be given again for successful termination of the subscription. >>> db.settings.unsubscribe('TimeZone', watch) """ subs = getattr(self, '_subscriptions', {}) callbacks = subs.get(key, ()) if callback in callbacks: callbacks.remove(callback) class Transaction(pg_api.Transaction): database = None mode = None isolation = None _e_factors = ('database', 'isolation', 'mode') def _e_metas(self): yield (None, self.state) def __init__(self, database, isolation = None, mode = None): self.database = database self.isolation = isolation self.mode = mode self.state = 'initialized' self.type = None def __enter__(self): self.start() return self def __exit__(self, typ, value, tb): if typ is None: # No exception, but in a failed transaction? if self.database.pq.state == b'E': if not self.database.closed: self.rollback() # pg_exc.InFailedTransactionError em = element.ClientError(( (b'S', 'ERROR'), (b'C', '25P02'), (b'M', 'invalid transaction block exit detected'), (b'H', "Database was in an error-state, but no exception was raised.") )) self.database.typio.raise_client_error(em, creator = self) else: # No exception, and no error state. Everything is good. try: self.commit() # If an error occurs, clean up the transaction state # and raise as needed. except pg_exc.ActiveTransactionError as err: if not self.database.closed: # adjust the state so rollback will do the right thing and abort. self.state = 'open' self.rollback() raise elif issubclass(typ, Exception): # There's an exception, so only rollback if the connection # exists. If the rollback() was called here, it would just # contribute noise to the error. 
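# For reference, the context-manager protocol implemented here allows
# (a hedged sketch):
#
# >>> with db.xact(isolation = 'SERIALIZABLE'):
# ...     ...    # committed on success, aborted on exception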
if not self.database.closed: self.rollback() @staticmethod def _start_xact_string(isolation = None, mode = None): q = 'START TRANSACTION' if isolation is not None: if ';' in isolation: raise ValueError("invalid transaction isolation " + repr(mode)) q += ' ISOLATION LEVEL ' + isolation if mode is not None: if ';' in mode: raise ValueError("invalid transaction mode " + repr(isolation)) q += ' ' + mode return q + ';' @staticmethod def _savepoint_xact_string(id): return 'SAVEPOINT "xact(' + id.replace('"', '""') + ')";' def start(self): if self.state == 'open': return if self.state != 'initialized': em = element.ClientError(( (b'S', 'ERROR'), (b'C', '--OPE'), (b'M', "transactions cannot be restarted"), (b'H', 'Create a new transaction object instead of re-using an old one.') )) self.database.typio.raise_client_error(em, creator = self) if self.database.pq.state == b'I': self.type = 'block' q = self._start_xact_string( isolation = self.isolation, mode = self.mode, ) else: self.type = 'savepoint' if (self.isolation, self.mode) != (None,None): em = element.ClientError(( (b'S', 'ERROR'), (b'C', '--OPE'), (b'M', "configured transaction used inside a transaction block"), (b'H', 'A transaction block was already started.'), )) self.database.typio.raise_client_error(em, creator = self) q = self._savepoint_xact_string(hex(id(self))) self.database.execute(q) self.state = 'open' begin = start @staticmethod def _release_string(id): 'release "";' return 'RELEASE "xact(' + id.replace('"', '""') + ')";' def commit(self): if self.state == 'committed': return if self.state != 'open': em = element.ClientError(( (b'S', 'ERROR'), (b'C', '--OPE'), (b'M', "commit attempted on transaction with unexpected state, " + repr(self.state)), )) self.database.typio.raise_client_error(em, creator = self) if self.type == 'block': q = 'COMMIT' else: q = self._release_string(hex(id(self))) self.database.execute(q) self.state = 'committed' @staticmethod def _rollback_to_string(id, fmt = 'ROLLBACK TO "xact({0})"; RELEASE "xact({0})";'.format): return fmt(id.replace('"', '""')) def rollback(self): if self.state == 'aborted': return if self.state not in ('prepared', 'open'): em = element.ClientError(( (b'S', 'ERROR'), (b'C', '--OPE'), (b'M', "ABORT attempted on transaction with unexpected state, " + repr(self.state)), )) self.database.typio.raise_client_error(em, creator = self) if self.type == 'block': q = 'ABORT;' elif self.type == 'savepoint': q = self._rollback_to_string(hex(id(self))) else: raise RuntimeError("unknown transaction type " + repr(self.type)) self.database.execute(q) self.state = 'aborted' abort = rollback class Connection(pg_api.Connection): connector = None type = None version_info = None version = None security = None backend_id = None client_address = None client_port = None # Replaced with instances on connection instantiation. settings = Settings def _e_metas(self): yield (None, '[' + self.state + ']') if self.client_address is not None: yield ('client_address', self.client_address) if self.client_port is not None: yield ('client_port', self.client_port) if self.version is not None: yield ('version', self.version) att = getattr(self, 'failures', None) if att: count = 0 for x in att: # Format each failure without their traceback. 
errstr = ''.join(format_exception(type(x.error), x.error, None)) factinfo = str(x.socket_factory) if hasattr(x, 'ssl_negotiation'): if x.ssl_negotiation is True: factinfo = 'SSL ' + factinfo else: factinfo = 'NOSSL ' + factinfo yield ( 'failures[' + str(count) + ']', factinfo + os.linesep + errstr ) count += 1 def __repr__(self): return '<%s.%s[%s] %s>' %( type(self).__module__, type(self).__name__, self.connector._pq_iri, self.closed and 'closed' or '%s' %(self.pq.state,) ) def __exit__(self, type, value, tb): # Don't bother closing unless it's a normal exception. if type is None or issubclass(type, Exception): self.close() def interrupt(self, timeout = None): self.pq.interrupt(timeout = timeout) def execute(self, query : str) -> None: q = xact.Instruction(( element.Query(self.typio._encode(query)[0]), ), asynchook = self._receive_async ) self._pq_push(q, self) self._pq_complete() def do(self, language : str, source : str, qlit = pg_str.quote_literal, qid = pg_str.quote_ident, ) -> None: sql = "DO " + qlit(source) + " LANGUAGE " + qid(language) + ";" self.execute(sql) def xact(self, isolation = None, mode = None): x = Transaction(self, isolation = isolation, mode = mode) return x def prepare(self, sql_statement_string : str, statement_id = None, Class = Statement ) -> Statement: ps = Class(self, statement_id, sql_statement_string) ps._init() ps._fini() return ps @property def query(self, Class = SingleExecution): return Class(self) def statement_from_id(self, statement_id : str) -> Statement: ps = Statement(self, statement_id, None) ps._init() ps._fini() return ps def proc(self, proc_id : (str, int)) -> StoredProcedure: sp = StoredProcedure(proc_id, self) return sp def cursor_from_id(self, cursor_id : str) -> Cursor: c = Cursor(None, None, self, cursor_id) c._init() return c @property def closed(self) -> bool: if getattr(self, 'pq', None) is None: return True if hasattr(self.pq, 'socket') and self.pq.xact is not None: return self.pq.xact.fatal is True return False def close(self, getattr = getattr): # Write out the disconnect message if the socket is around. # If the connection is known to be lost, don't bother. It will # generate an extra exception. if getattr(self, 'pq', None) is None or getattr(self.pq, 'socket', None) is None: # No action to take. return x = getattr(self.pq, 'xact', None) if x is not None and x.fatal is not True: # finish the existing pq transaction iff it's not Closing. self.pq.complete() if self.pq.xact is None: # It completed the existing transaction. self.pq.push(xact.Closing()) self.pq.complete() if self.pq.socket: self.pq.complete() # Close the socket if there is one. if self.pq.socket: self.pq.socket.close() self.pq.socket = None @property def state(self) -> str: if not hasattr(self, 'pq'): return 'initialized' if hasattr(self, 'failures'): return 'failed' if self.closed: return 'closed' if isinstance(self.pq.xact, xact.Negotiation): return 'negotiating' if self.pq.xact is None: if self.pq.state == b'E': return 'failed block' return 'idle' + (' in block' if self.pq.state != b'I' else '') else: return 'busy' def reset(self): """ restore original settings, reset the transaction, drop temporary objects. """ self.execute("ABORT; RESET ALL;") def __enter__(self): self.connect() return self def connect(self): 'Establish the connection to the server' if self.closed is False: # already connected? just return. return if hasattr(self, 'pq'): # It's closed, *but* there's a PQ connection.. 
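# That is, a connection that failed after protocol startup re-raises
# the stored error when connect() is attempted again, while a fresh
# object simply establishes. A hedged sketch:
#
# >>> db.connect()   # no-op when already connected; see `closed` above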
x = self.pq.xact self.typio.raise_error(x.error_message, cause = getattr(x, 'exception', None), creator = self) # It's closed. try: self._establish() except Exception: # Close it up on failure. self.close() raise def _establish(self): # guts of connect() self.pq = None # if any exception occurs past this point, the connection # will not be usable. timeout = self.connector.connect_timeout sslmode = self.connector.sslmode or 'prefer' failures = [] exc = None try: # get the list of sockets to try socket_factories = self.connector.socket_factory_sequence() except Exception as e: socket_factories = () exc = e # When ssl is None: SSL negotiation will not occur. # When ssl is True: SSL negotiation will occur *and* it must succeed. # When ssl is False: SSL negotiation will occur but it may fail(NOSSL). if sslmode == 'allow': # without ssl, then with. :) socket_factories = interlace( zip(repeat(None, len(socket_factories)), socket_factories), zip(repeat(True, len(socket_factories)), socket_factories) ) elif sslmode == 'prefer': # with ssl, then without. [maybe] :) socket_factories = interlace( zip(repeat(False, len(socket_factories)), socket_factories), zip(repeat(None, len(socket_factories)), socket_factories) ) # prefer is special, because it *may* be possible to # skip the subsequent "without" in situations where SSL is off. elif sslmode == 'require': socket_factories = zip(repeat(True, len(socket_factories)), socket_factories) elif sslmode == 'disable': # None = Do Not Attempt SSL negotiation. socket_factories = zip(repeat(None, len(socket_factories)), socket_factories) else: raise ValueError("invalid sslmode: " + repr(sslmode)) # can_skip is used when 'prefer' or 'allow' is the sslmode. # if the ssl negotiation returns 'N' (nossl), then # ssl "failed", but the socket is still usable for nossl. # in these cases, can_skip is set to True so that the # subsequent non-ssl attempt is skipped if it failed with the 'N' response. can_skip = False startup = self.connector._startup_parameters password = self.connector._password Connection3 = client.Connection for (ssl, sf) in socket_factories: if can_skip is True: # the last attempt failed and knows this attempt will fail too. can_skip = False continue pq = Connection3(sf, startup, password = password,) if hasattr(self, 'tracer'): pq.tracer = self.tracer # Grab the negotiation transaction before # connecting as it will be needed later if successful. neg = pq.xact pq.connect(ssl = ssl, timeout = timeout) didssl = getattr(pq, 'ssl_negotiation', -1) # It successfully connected if pq.xact is None; # The startup/negotiation xact completed. if pq.xact is None: self.pq = pq if hasattr(self.pq.socket, 'fileno'): self.fileno = self.pq.socket.fileno self.security = 'ssl' if didssl is True else None showoption_type = element.ShowOption.type for x in neg.asyncs: if x.type == showoption_type: self._receive_async(x) # success! break elif pq.socket is not None: # In this case, an application/protocol error occurred. # Close out the sockets ourselves. pq.socket.close() # Identify whether or not we can skip the attempt. # Whether or not we can skip depends entirely on the SSL parameter. if sslmode == 'prefer' and ssl is False and didssl is False: # In this case, the server doesn't support SSL or it's # turned off. Therefore, the "without_ssl" attempt need # *not* be ran because it has already been noted to be # a failure. 
can_skip = True elif hasattr(pq.xact, 'exception'): # If a Python exception occurred, chances are that it is # going to fail again iff it is going to hit the same host. if sslmode == 'prefer' and ssl is False: # when 'prefer', the first attempt # is marked with ssl is "False" can_skip = True elif sslmode == 'allow' and ssl is None: # when 'allow', the first attempt # is marked with dossl is "None" can_skip = True try: self.typio.raise_error(pq.xact.error_message) except Exception as error: pq.error = error # Otherwise, infinite recursion in the element traceback. error.creator = None # The tracebacks of the specific failures aren't particularly useful.. error.__traceback__ = None if getattr(pq.xact, 'exception', None) is not None: pq.error.__cause__ = pq.xact.exception failures.append(pq) else: # No servers available. (see the break-statement in the for-loop) self.failures = failures or () # it's over. self.typio.raise_client_error(could_not_connect, creator = self, cause = exc) ## # connected, now initialize connection information. self.backend_id = self.pq.backend_id sv = self.settings.cache.get("server_version", "0.0") self.version_info = pg_version.normalize(pg_version.split(sv)) # manual binding self.sys = pg_lib.Binding(self, pg_lib.sys) vi = self.version_info[:2] if vi <= (8,1): sd = self.sys.startup_data_only_version() elif vi >= (9,2): sd = self.sys.startup_data_92() else: sd = self.sys.startup_data() # connection info self.version, self.backend_start, \ self.client_address, self.client_port = sd # First word from the version string. self.type = self.version.split()[0] ## # Set standard_conforming_strings scstr = self.settings.get('standard_conforming_strings') if scstr is None or vi == (8,1): # There used to be a warning emitted here. # It was noisy, and had little added value # over a nice WARNING at the top of the driver documentation. pass elif scstr.lower() not in ('on','true','yes'): self.settings['standard_conforming_strings'] = 'on' super().connect() def _pq_push(self, xact, controller = None): x = self.pq.xact if x is not None: self.pq.complete() if x.fatal is not None: self.typio.raise_error(x.error_message) if controller is not None: self._controller = controller self.pq.push(xact) # Complete the current protocol transaction. def _pq_complete(self): pq = self.pq x = pq.xact if x is not None: # There is a running transaction, finish it. pq.complete() # Raise an error *iff* one occurred. if x.fatal is not None: self.typio.raise_error(x.error_message, cause = getattr(x, 'exception', None)) del self._controller # Process the next message. def _pq_step(self, complete_state = globals()['xact'].Complete): pq = self.pq x = pq.xact if x is not None: pq.step() # If the protocol transaction was completed by # the last step, raise the error *iff* one occurred. 
if x.state is complete_state: if x.fatal is not None: self.typio.raise_error(x.error_message, cause = getattr(x, 'exception', None)) del self._controller def _receive_async(self, msg, controller = None, showoption = element.ShowOption.type, notice = element.Notice.type, notify = element.Notify.type, ): c = controller or getattr(self, '_controller', self) typ = msg.type if typ == showoption: if msg.name == b'client_encoding': self.typio.set_encoding(msg.value.decode('ascii')) self.settings._notify(msg) elif typ == notice: m = self.typio.emit_message(msg, creator = c) elif typ == notify: self._notifies.append(msg) else: self.typio.emit_client_message( element.ClientNotice(( (b'C', '-1000'), (b'S', 'WARNING'), (b'M', 'cannot process unrecognized asynchronous message'), (b'D', repr(msg)), )), creator = c ) def clone(self, *args, **kw): c = self.__class__(self.connector, *args, **kw) c.connect() return c def notify(self, *channels, **channel_and_payload): notifies = "" if channels: notifies += ';'.join(( 'NOTIFY "' + x.replace('"', '""') + '"' # str() case if x.__class__ is not tuple else ( # tuple() case 'NOTIFY "' + x[0].replace('"', '""') + """",'""" + \ x[1].replace("'", "''") + "'" ) for x in channels )) notifies += ';' if channel_and_payload: notifies += ';'.join(( 'NOTIFY "' + channel.replace('"', '""') + """",'""" + \ payload.replace("'", "''") + "'" for channel, payload in channel_and_payload.items() )) notifies += ';' return self.execute(notifies) def listening_channels(self): if self.version_info[:2] > (8,4): return self.sys.listening_channels() else: return self.sys.listening_relations() def listen(self, *channels, len = len): qstr = '' for x in channels: # XXX: hardcoded identifier length? if len(x) > 63: raise ValueError("channel name too long: " + x) qstr += '; LISTEN ' + x.replace('"', '""') return self.execute(qstr) def unlisten(self, *channels, len = len): qstr = '' for x in channels: # XXX: hardcoded identifier length? if len(x) > 63: raise ValueError("channel name too long: " + x) qstr += '; UNLISTEN ' + x.replace('"', '""') return self.execute(qstr) def iternotifies(self, timeout = None): nm = NotificationManager(self, timeout = timeout) for x in nm: if x is None: yield None else: for y in x[1]: yield y def __init__(self, connector, *args, **kw): """ Create a connection based on the given connector. """ self.connector = connector # raw notify messages self._notifies = [] self.fileno = -1 self.typio = self.connector.driver.typio(self) self.typio.set_encoding('ascii') self.settings = Settings(self) # class Connection class Connector(pg_api.Connector): """ All arguments to Connector are keywords. At the very least, user, and socket, may be provided. If socket, unix, or process is not provided, host and port must be. 
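	A hedged construction sketch: the parameter values are hypothetical,
	connectors are normally obtained through the Driver defined at the
	end of this module, and the call convention assumes connectors are
	callable, yielding connections:

		>>> c = IP4(user = 'pgsql', host = '127.0.0.1', port = 5432)
		>>> db = c()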
""" @property def _pq_iri(self): return pg_iri.serialize( { k : v for k,v in self.__dict__.items() if v is not None and not k.startswith('_') and k not in ( 'driver', 'category' ) }, obscure_password = True ) def _e_metas(self): yield (None, '[' + self.__class__.__name__ + '] ' + self._pq_iri) def __repr__(self): keywords = (',' + os.linesep + ' ').join([ '%s = %r' %(k, getattr(self, k, None)) for k in self.__dict__ if not k.startswith('_') and getattr(self, k, None) is not None ]) return '{mod}.{name}({keywords})'.format( mod = type(self).__module__, name = type(self).__name__, keywords = os.linesep + ' ' + keywords if keywords else '' ) @abstractmethod def socket_factory_sequence(self): """ Generate a list of callables that will be used to attempt to make the connection to the server. It is assumed that each factory will produce an object with a socket interface that is ready for reading and writing data. The callables in the sequence must take a timeout parameter. """ def __init__(self, connect_timeout : int = None, server_encoding : "server encoding hint for driver" = None, sslmode : ('allow', 'prefer', 'require', 'disable') = None, sslcrtfile : "filepath" = None, sslkeyfile : "filepath" = None, sslrootcrtfile : "filepath" = None, sslrootcrlfile : "filepath" = None, driver = None, **kw ): super().__init__(**kw) self.driver = driver self.server_encoding = server_encoding self.connect_timeout = connect_timeout self.sslmode = sslmode self.sslkeyfile = sslkeyfile self.sslcrtfile = sslcrtfile self.sslrootcrtfile = sslrootcrtfile self.sslrootcrlfile = sslrootcrlfile if self.sslrootcrlfile is not None: pg_exc.IgnoredClientParameterWarning( "certificate revocation lists are *not* checked", creator = self, ).emit() # Startup message parameters. tnkw = { 'client_min_messages' : 'WARNING', } if self.settings: s = dict(self.settings) if 'search_path' in self.settings: sp = s.get('search_path') if sp is None: self.settings.pop('search_path') elif not isinstance(sp, str): s['search_path'] = ','.join( pg_str.quote_ident(x) for x in sp ) tnkw.update(s) tnkw['user'] = self.user if self.database is not None: tnkw['database'] = self.database se = self.server_encoding or 'utf-8' ## # Attempt to accommodate for literal treatment of startup data. ## self._startup_parameters = tuple([ # All keys go in utf-8. However, ascii would probably be good enough. ( k.encode('utf-8'), # If it's a str(), encode in the hinted server_encoding. # Otherwise, convert the object(int, float, bool, etc) into a string # and treat it as utf-8. v.encode(se) if type(v) is str else str(v).encode('utf-8') ) for k, v in tnkw.items() ]) self._password = (self.password or '').encode(se) self._socket_secure = { 'keyfile' : self.sslkeyfile, 'certfile' : self.sslcrtfile, 'ca_certs' : self.sslrootcrtfile, } # class Connector class SocketConnector(Connector): 'abstract connector for using `socket` and `ssl`' @abstractmethod def socket_factory_sequence(self): """ Return a sequence of `SocketFactory`s for a connection to use to connect to the target host. 
""" def create_socket_factory(self, **params): return SocketFactory(**params) class IPConnector(SocketConnector): def socket_factory_sequence(self): return self._socketcreators def socket_factory_params(self, host, port, ipv, **kw): if ipv != self.ipv: raise TypeError("'ipv' keyword must be '%d'" % self.ipv) if host is None: raise TypeError("'host' is a required keyword and cannot be 'None'") if port is None: raise TypeError("'port' is a required keyword and cannot be 'None'") return {'socket_create': (self.address_family, socket.SOCK_STREAM), 'socket_connect': (host, int(port))} def __init__(self, host, port, ipv, **kw): params = self.socket_factory_params(host, port, ipv, **kw) self.host, self.port = params['socket_connect'] # constant socket connector self._socketcreator = self.create_socket_factory(**params) self._socketcreators = (self._socketcreator,) super().__init__(**kw) class IP4(IPConnector): 'Connector for establishing IPv4 connections' ipv = 4 address_family = socket.AF_INET def __init__(self, host : "IPv4 Address (str)" = None, port : int = None, ipv = 4, **kw ): super().__init__(host, port, ipv, **kw) class IP6(IPConnector): 'Connector for establishing IPv6 connections' ipv = 6 address_family = socket.AF_INET6 def __init__(self, host : "IPv6 Address (str)" = None, port : int = None, ipv = 6, **kw ): super().__init__(host, port, ipv, **kw) class Unix(SocketConnector): 'Connector for establishing unix domain socket connections' def socket_factory_sequence(self): return self._socketcreators def socket_factory_params(self, unix): if unix is None: raise TypeError("'unix' is a required keyword and cannot be 'None'") return {'socket_create': (socket.AF_UNIX, socket.SOCK_STREAM), 'socket_connect': unix} def __init__(self, unix = None, **kw): params = self.socket_factory_params(unix) self.unix = params['socket_connect'] # constant socket connector self._socketcreator = self.create_socket_factory(**params) self._socketcreators = (self._socketcreator,) super().__init__(**kw) class Host(SocketConnector): """ Connector for establishing hostname based connections. This connector exercises socket.getaddrinfo. """ def socket_factory_sequence(self): """ Return a list of `SocketCreator`s based on the results of `socket.getaddrinfo`. """ return [ # (AF, socktype, proto), (IP, Port) self.create_socket_factory(**(self.socket_factory_params(x[0:3], x[4][:2], self._socket_secure))) for x in socket.getaddrinfo( self.host, self.port, self._address_family, socket.SOCK_STREAM ) ] def socket_factory_params(self, socktype, address, sslparams): return {'socket_create': socktype, 'socket_connect': address, 'socket_secure': sslparams} def __init__(self, host : str = None, port : (str, int) = None, ipv : int = None, address_family : "address family to use(AF_INET,AF_INET6)" = None, **kw ): if host is None: raise TypeError("'host' is a required keyword") if port is None: raise TypeError("'port' is a required keyword") if address_family is not None and ipv is not None: raise TypeError("'ipv' and 'address_family' on mutually exclusive") if ipv is None: self._address_family = address_family or socket.AF_UNSPEC elif ipv == 4: self._address_family = socket.AF_INET elif ipv == 6: self._address_family = socket.AF_INET6 else: raise TypeError("unknown IP version selected: 'ipv' = " + repr(ipv)) self.host = host self.port = port super().__init__(**kw) class Driver(pg_api.Driver): def _e_metas(self): yield (None, type(self).__module__ + '.' 
+ type(self).__name__) def ip4(self, **kw): return IP4(driver = self, **kw) def ip6(self, **kw): return IP6(driver = self, **kw) def host(self, **kw): return Host(driver = self, **kw) def unix(self, **kw): return Unix(driver = self, **kw) def fit(self, unix = None, host = None, port = None, **kw ) -> Connector: """ Create the appropriate `postgresql.api.Connector` based on the parameters. This also protects against mutually exclusive parameters. """ if unix is not None: if host is not None: raise TypeError("'unix' and 'host' keywords are exclusive") if port is not None: raise TypeError("'unix' and 'port' keywords are exclusive") return self.unix(unix = unix, **kw) else: if host is None or port is None: raise TypeError("'host' and 'port', or 'unix' must be supplied") # We have a host and a port. # If it's an IP address, IP4 or IP6 should be selected. if ':' in host: # There's a ':' in host, good chance that it's IPv6. try: socket.inet_pton(socket.AF_INET6, host) return self.ip6(host = host, port = port, **kw) except (socket.error, NameError): pass # Not IPv6, maybe IPv4... try: socket.inet_aton(host) # It's IP4 return self.ip4(host = host, port = port, **kw) except socket.error: pass # neither host, nor port are None, probably a hostname. return self.host(host = host, port = port, **kw) def connect(self, **kw) -> Connection: """ For information on acceptable keywords, see: `postgresql.documentation.driver`:Connection Keywords """ c = self.fit(**kw)() c.connect() return c def __init__(self, connection = Connection, typio = TypeIO): self.connection = connection self.typio = typio fe-1.1.0/postgresql/encodings/000077500000000000000000000000001203372773200162655ustar00rootroot00000000000000fe-1.1.0/postgresql/encodings/__init__.py000066400000000000000000000000231203372773200203710ustar00rootroot00000000000000## # .encodings ## fe-1.1.0/postgresql/encodings/aliases.py000066400000000000000000000031611203372773200202610ustar00rootroot00000000000000## # .encodings.aliases ## """ Module for mapping PostgreSQL encoding names to Python encoding names. These are **not** installed in Python's aliases. Rather, `get_python_name` should be used directly. URLs of interest: * http://docs.python.org/library/codecs.html * http://git.postgresql.org/gitweb?p=postgresql.git;a=blob;f=src/backend/utils/mb/encnames.c """ ## #: Dictionary of Postgres encoding names to Python encoding names. #: This mapping only contains those encoding names that do not intersect. postgres_to_python = { 'unicode' : 'utf_8', 'sql_ascii' : 'ascii', 'euc_jp' : 'eucjp', 'euc_cn' : 'euccn', 'euc_kr' : 'euckr', 'shift_jis_2004' : 'euc_jis_2004', 'sjis' : 'shift_jis', 'alt' : 'cp866', # IBM866 'abc' : 'cp1258', 'vscii' : 'cp1258', 'koi8r' : 'koi8_r', 'koi8u' : 'koi8_u', 'tcvn' : 'cp1258', 'tcvn5712' : 'cp1258', # 'euc_tw' : None, # N/A # 'mule_internal' : None, # N/A } def get_python_name(encname): """ Lookup the name in the `postgres_to_python` dictionary. If no match is found, check for a 'win' or 'windows-' name and convert that to a 'cp###' name. If there is no alias for `encname`, the name is returned unchanged. The win[0-9]+ and windows-[0-9]+ entries are handled functionally. 
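For example (illustrative)::

	>>> get_python_name('sjis')
	'shift_jis'
	>>> get_python_name('windows-1252')
	'cp1252'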
""" # check the dictionary first localname = postgres_to_python.get(encname) if localname is not None: return localname # no explicit mapping, check for functional transformation if encname.startswith('win'): # handle win#### and windows-#### # remove the trailing CP number bare = encname.rstrip('0123456789') if bare.strip('_-') in ('win', 'windows'): return 'cp' + encname[len(bare):] return encname fe-1.1.0/postgresql/encodings/bytea.py000066400000000000000000000035121203372773200177440ustar00rootroot00000000000000## # .encodings.bytea ## 'PostgreSQL bytea encoding and decoding functions' import codecs import struct import sys ord_to_seq = { i : \ "\\" + oct(i)[2:].rjust(3, '0') \ if not (32 < i < 126) else r'\\' \ if i == 92 else chr(i) for i in range(256) } if sys.version_info[:2] >= (3, 3): # Subscripting memory in 3.3 returns byte as an integer, not as a bytestring def decode(data): return ''.join(map(ord_to_seq.__getitem__, (data[x] for x in range(len(data))))) else: def decode(data): return ''.join(map(ord_to_seq.__getitem__, (data[x][0] for x in range(len(data))))) def encode(data): diter = ((data[i] for i in range(len(data)))) output = [] next = diter.__next__ for x in diter: if x == "\\": try: y = next() except StopIteration: raise ValueError("incomplete backslash sequence") if y == "\\": # It's a backslash, so let x(\) be appended. x = ord(x) elif y.isdigit(): try: os = ''.join((y, next(), next())) except StopIteration: # requires three digits raise ValueError("incomplete backslash sequence") try: x = int(os, base = 8) except ValueError: raise ValueError("invalid bytea octal sequence '%s'" %(os,)) else: raise ValueError("invalid backslash follow '%s'" %(y,)) else: x = ord(x) output.append(x) return struct.pack(str(len(output)) + 'B', *output) class Codec(codecs.Codec): 'bytea codec' def encode(data, errors = 'strict'): return (encode(data), len(data)) encode = staticmethod(encode) def decode(data, errors = 'strict'): return (decode(data), len(data)) decode = staticmethod(decode) class StreamWriter(Codec, codecs.StreamWriter): pass class StreamReader(Codec, codecs.StreamReader): pass bytea_codec = (Codec.encode, Codec.decode, StreamReader, StreamWriter) codecs.register(lambda x: x == 'bytea' and bytea_codec or None) fe-1.1.0/postgresql/exceptions.py000066400000000000000000000422421203372773200170530ustar00rootroot00000000000000## # .exceptions - Exception hierarchy for PostgreSQL database ERRORs. ## """ PostgreSQL exceptions and warnings with associated state codes. The primary entry points of this module is the `ErrorLookup` function and the `WarningLookup` function. Given an SQL state code, they give back the most appropriate Error or Warning subclass. For more information on error codes see: http://www.postgresql.org/docs/current/static/errcodes-appendix.html This module is executable via -m: python -m postgresql.exceptions. It provides a convenient way to look up the exception object mapped to by the given error code:: $ python -m postgresql.exceptions XX000 postgresql.exceptions.InternalError [XX000] If the exact error code is not found, it will try to find the error class's exception(The first two characters of the error code make up the class identity):: $ python -m postgresql.exceptions XX400 postgresql.exceptions.InternalError [XX000] If that fails, it will return `postgresql.exceptions.Error` """ import sys import os from functools import partial from operator import attrgetter from .message import Message from . 
import sys as pg_sys PythonException = Exception class Exception(Exception): 'Base PostgreSQL exception class' pass class LoadError(Exception): 'Failed to load a library' class Disconnection(Exception): 'Exception identifying errors that result in disconnection' class Warning(Message): code = '01000' _e_label = property(attrgetter('__class__.__name__')) class DriverWarning(Warning): code = '01-00' source = 'CLIENT' class IgnoredClientParameterWarning(DriverWarning): 'Warn the user of a valid, but ignored parameter.' code = '01-CP' class TypeConversionWarning(DriverWarning): 'Report a potential issue with a conversion.' code = '01-TP' class DeprecationWarning(Warning): code = '01P01' class DynamicResultSetsReturnedWarning(Warning): code = '0100C' class ImplicitZeroBitPaddingWarning(Warning): code = '01008' class NullValueEliminatedInSetFunctionWarning(Warning): code = '01003' class PrivilegeNotGrantedWarning(Warning): code = '01007' class PrivilegeNotRevokedWarning(Warning): code = '01006' class StringDataRightTruncationWarning(Warning): code = '01004' class NoDataWarning(Warning): code = '02000' class NoMoreSetsReturned(NoDataWarning): code = '02001' class Error(Message, Exception): 'A PostgreSQL Error' _e_label = 'ERROR' code = '' def __str__(self): 'Call .sys.errformat(self)' return pg_sys.errformat(self) @property def fatal(self): f = self.details.get('severity') return None if f is None else f in ('PANIC', 'FATAL') class DriverError(Error): "Errors originating in the driver's implementation." source = 'CLIENT' code = '--000' class AuthenticationMethodError(DriverError, Disconnection): """ Server requested an authentication method that is not supported by the driver. """ code = '--AUT' class InsecurityError(DriverError, Disconnection): """ Error signifying a secure channel to a server cannot be established. """ code = '--SEC' class ConnectTimeoutError(DriverError, Disconnection): 'Client was unable to establish a connection in the given time' code = '--TOE' class TypeIOError(DriverError): """ Driver failed to pack or unpack a tuple. """ code = '--TIO' class ParameterError(TypeIOError): code = '--PIO' class ColumnError(TypeIOError): code = '--CIO' class CompositeError(TypeIOError): code = '--cIO' class OperationError(DriverError): """ An invalid operation on an interface element. """ code = '--OPE' class TransactionError(Error): pass class SQLNotYetCompleteError(Error): code = '03000' class ConnectionError(Error, Disconnection): code = '08000' class ConnectionDoesNotExistError(ConnectionError): """ The connection is closed or was never connected. """ code = '08003' class ConnectionFailureError(ConnectionError): 'Raised when a connection is dropped' code = '08006' class ClientCannotConnectError(ConnectionError): """ Client was unable to establish a connection to the server. 
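(When the driver raises this error, the individual per-host attempt
failures are collected on the connection's `failures` attribute.)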
""" code = '08001' class ConnectionRejectionError(ConnectionError): code = '08004' class TransactionResolutionUnknownError(ConnectionError): code = '08007' class ProtocolError(ConnectionError): code = '08P01' class TriggeredActionError(Error): code = '09000' class FeatureError(Error): "Unsupported feature" code = '0A000' class TransactionInitiationError(TransactionError): code = '0B000' class LocatorError(Error): code = '0F000' class LocatorSpecificationError(LocatorError): code = '0F001' class GrantorError(Error): code = '0L000' class GrantorOperationError(GrantorError): code = '0LP01' class RoleSpecificationError(Error): code = '0P000' class CaseNotFoundError(Error): code = '20000' class CardinalityError(Error): "Wrong number of rows returned" code = '21000' class TriggeredDataChangeViolation(Error): code = '27000' class AuthenticationSpecificationError(Error, Disconnection): code = '28000' class DPDSEError(Error): "Dependent Privilege Descriptors Still Exist" code = '2B000' class DPDSEObjectError(DPDSEError): code = '2BP01' class SREError(Error): "SQL Routine Exception" code = '2F000' class FunctionExecutedNoReturnStatementError(SREError): code = '2F005' class DataModificationProhibitedError(SREError): code = '2F002' class StatementProhibitedError(SREError): code = '2F003' class ReadingDataProhibitedError(SREError): code = '2F004' class EREError(Error): "External Routine Exception" code = '38000' class ContainingSQLNotPermittedError(EREError): code = '38001' class ModifyingSQLDataNotPermittedError(EREError): code = '38002' class ProhibitedSQLStatementError(EREError): code = '38003' class ReadingSQLDataNotPermittedError(EREError): code = '38004' class ERIEError(Error): "External Routine Invocation Exception" code = '39000' class InvalidSQLState(ERIEError): code = '39001' class NullValueNotAllowed(ERIEError): code = '39004' class TriggerProtocolError(ERIEError): code = '39P01' class SRFProtocolError(ERIEError): code = '39P02' class TRError(TransactionError): "Transaction Rollback" code = '40000' class DeadlockError(TRError): code = '40P01' class IntegrityConstraintViolationError(TRError): code = '40002' class SerializationError(TRError): code = '40001' class StatementCompletionUnknownError(TRError): code = '40003' class ITSError(TransactionError): "Invalid Transaction State" code = '25000' class ActiveTransactionError(ITSError): code = '25001' class BranchAlreadyActiveError(ITSError): code = '25002' class BadAccessModeForBranchError(ITSError): code = '25003' class BadIsolationForBranchError(ITSError): code = '25004' class NoActiveTransactionForBranchError(ITSError): code = '25005' class ReadOnlyTransactionError(ITSError): "Occurs when an alteration occurs in a read-only transaction." code = '25006' class SchemaAndDataStatementsError(ITSError): "Mixed schema and data statements not allowed." code = '25007' class InconsistentCursorIsolationError(ITSError): "The held cursor requires the same isolation." code = '25008' class NoActiveTransactionError(ITSError): code = '25P01' class InFailedTransactionError(ITSError): "Occurs when an action occurs in a failed transaction." code = '25P02' class SavepointError(TransactionError): "Classification error designating errors that relate to savepoints." 
code = '3B000' class InvalidSavepointSpecificationError(SavepointError): code = '3B001' class TransactionTerminationError(TransactionError): code = '2D000' class IRError(Error): "Insufficient Resource Error" code = '53000' class MemoryError(IRError, MemoryError): code = '53200' class DiskFullError(IRError): code = '53100' class TooManyConnectionsError(IRError): code = '53300' class PLEError(Error): "Program Limit Exceeded" code = '54000' class ComplexityOverflowError(PLEError): code = '54001' class ColumnOverflowError(PLEError): code = '54011' class ArgumentOverflowError(PLEError): code = '54023' class ONIPSError(Error): "Object Not In Prerequisite State" code = '55000' class ObjectInUseError(ONIPSError): code = '55006' class ImmutableRuntimeParameterError(ONIPSError): code = '55P02' class UnavailableLockError(ONIPSError): code = '55P03' class SEARVError(Error): "Syntax Error or Access Rule Violation" code = '42000' class SEARVNameError(SEARVError): code = '42602' class NameTooLongError(SEARVError): code = '42622' class ReservedNameError(SEARVError): code = '42939' class ForeignKeyCreationError(SEARVError): code = '42830' class InsufficientPrivilegeError(SEARVError): code = '42501' class GroupingError(SEARVError): code = '42803' class RecursionError(SEARVError): code = '42P19' class WindowError(SEARVError): code = '42P20' class SyntaxError(SEARVError): code = '42601' class TypeError(SEARVError): pass class CoercionError(TypeError): code = '42846' class TypeMismatchError(TypeError): code = '42804' class IndeterminateTypeError(TypeError): code = '42P18' class WrongObjectTypeError(TypeError): code = '42809' class UndefinedError(SEARVError): pass class UndefinedColumnError(UndefinedError): code = '42703' class UndefinedFunctionError(UndefinedError): code = '42883' class UndefinedTableError(UndefinedError): code = '42P01' class UndefinedParameterError(UndefinedError): code = '42P02' class UndefinedObjectError(UndefinedError): code = '42704' class DuplicateError(SEARVError): pass class DuplicateColumnError(DuplicateError): code = '42701' class DuplicateCursorError(DuplicateError): code = '42P03' class DuplicateDatabaseError(DuplicateError): code = '42P04' class DuplicateFunctionError(DuplicateError): code = '42723' class DuplicatePreparedStatementError(DuplicateError): code = '42P05' class DuplicateSchemaError(DuplicateError): code = '42P06' class DuplicateTableError(DuplicateError): code = '42P07' class DuplicateAliasError(DuplicateError): code = '42712' class DuplicateObjectError(DuplicateError): code = '42710' class AmbiguityError(SEARVError): pass class AmbiguousColumnError(AmbiguityError): code = '42702' class AmbiguousFunctionError(AmbiguityError): code = '42725' class AmbiguousParameterError(AmbiguityError): code = '42P08' class AmbiguousAliasError(AmbiguityError): code = '42P09' class ColumnReferenceError(SEARVError): code = '42P10' class DefinitionError(SEARVError): pass class ColumnDefinitionError(DefinitionError): code = '42611' class CursorDefinitionError(DefinitionError): code = '42P11' class DatabaseDefinitionError(DefinitionError): code = '42P12' class FunctionDefinitionError(DefinitionError): code = '42P13' class PreparedStatementDefinitionError(DefinitionError): code = '42P14' class SchemaDefinitionError(DefinitionError): code = '42P15' class TableDefinitionError(DefinitionError): code = '42P16' class ObjectDefinitionError(DefinitionError): code = '42P17' class CursorStateError(Error): code = '24000' class WithCheckOptionError(Error): code = '44000' class 
NameError(Error): pass class CatalogNameError(NameError): code = '3D000' class CursorNameError(NameError): code = '34000' class StatementNameError(NameError): code = '26000' class SchemaNameError(NameError): code = '3F000' class ICVError(Error): "Integrity Constraint Violation" code = '23000' class RestrictError(ICVError): code = '23001' class NotNullError(ICVError): code = '23502' class ForeignKeyError(ICVError): code = '23503' class UniqueError(ICVError): code = '23505' class CheckError(ICVError): code = '23514' class DataError(Error): code = '22000' class StringRightTruncationError(DataError): code = '22001' class StringDataLengthError(DataError): code = '22026' class ZeroLengthString(DataError): code = '2200F' class EncodingError(DataError): code = '22021' class ArrayElementError(DataError): code = '2202E' class SpecificTypeMismatch(DataError): code = '2200G' class NullValueNotAllowedError(DataError): code = '22004' class NullValueNoIndicatorParameter(DataError): code = '22002' class ZeroDivisionError(DataError): code = '22012' class FloatingPointError(DataError): code = '22P01' class AssignmentError(DataError): code = '22005' class IndicatorOverflowError(DataError): code = '22022' class BadCopyError(DataError): code = '22P04' class TextRepresentationError(DataError): code = '22P02' class BinaryRepresentationError(DataError): code = '22P03' class UntranslatableCharacterError(DataError): code = '22P05' class NonstandardUseOfEscapeCharacterError(DataError): code = '22P06' class NotXMLError(DataError): code = '2200L' class XMLDocumentError(DataError): code = '2200M' class XMLContentError(DataError): code = '2200N' class XMLCommentError(DataError): code = '2200S' class XMLProcessingInstructionError(DataError): code = '2200T' class DateTimeFormatError(DataError): code = '22007' class TimeZoneDisplacementValueError(DataError): code = '22009' class DateTimeFieldOverflowError(DataError): code = '22008' class IntervalFieldOverflowError(DataError): code = '22015' class LogArgumentError(DataError): code = '2201E' class PowerFunctionArgumentError(DataError): code = '2201F' class WidthBucketFunctionArgumentError(DataError): code = '2201G' class CastCharacterValueError(DataError): code = '22018' class EscapeCharacterError(DataError): "Invalid escape character" code = '22019' class EscapeOctetError(DataError): code = '2200D' class EscapeSequenceError(DataError): code = '22025' class EscapeCharacterConflictError(DataError): code = '2200B' class EscapeCharacterUseError(DataError): "Invalid use of escape character" code = '2200C' class SubstringError(DataError): code = '22011' class TrimError(DataError): code = '22027' class IndicatorParameterValueError(DataError): code = '22010' class LimitValueError(DataError): code = '2201W' pg_code = '22020' class OffsetValueError(DataError): code = '2201X' class ParameterValueError(DataError): code = '22023' class RegularExpressionError(DataError): code = '2201B' class NumericRangeError(DataError): code = '22003' class UnterminatedCStringError(DataError): code = '22024' class InternalError(Error): code = 'XX000' class DataCorruptedError(InternalError): code = 'XX001' class IndexCorruptedError(InternalError): code = 'XX002' class SIOError(Error): "System I/O" code = '58000' class UndefinedFileError(SIOError): code = '58P01' class DuplicateFileError(SIOError): code = '58P02' class CFError(Error): "Configuration File Error" code = 'F0000' class LockFileExistsError(CFError): code = 'F0001' class OIError(Error): "Operator Intervention" code = '57000' class QueryCanceledError(OIError): code = '57014' class 
AdminShutdownError(OIError, Disconnection): code = '57P01' class CrashShutdownError(OIError, Disconnection): code = '57P02' class ServerNotReadyError(OIError, Disconnection): 'Thrown when a connection is established to a server that is still starting up.' code = '57P03' class PLPGSQLError(Error): "Error raised by a PL/PgSQL procedural function" code = 'P0000' class PLPGSQLRaiseError(PLPGSQLError): "Error raised by a PL/PgSQL RAISE statement." code = 'P0001' class PLPGSQLNoDataFoundError(PLPGSQLError): code = 'P0002' class PLPGSQLTooManyRowsError(PLPGSQLError): code = 'P0003' # Setup mapping to provide code based exception lookup. code_to_error = {} code_to_warning = {} def map_errors_and_warnings( objs : "A iterable of `Warning`s and `Error`'s", error_container : "apply the code to error association to this object" = code_to_error, warning_container : "apply the code to warning association to this object" = code_to_warning, ): """ Construct the code-to-error and code-to-warning associations. """ for obj in objs: if not issubclass(type(obj), (type(Warning), type(Error))): # It's not object of interest. continue code = getattr(obj, 'code', None) if code is None: # It has no code attribute, or the code was set to None. # If it's code is None, we don't map it as it's a "container". continue if issubclass(obj, Error): base = Error container = error_container elif issubclass(obj, Warning): base = Warning container = warning_container else: continue cur_obj = container.get(code) if cur_obj is None or issubclass(cur_obj, obj): # There is no object yet, or the object at the code # is not the most general class. # The latter condition comes into play when # there are sub-Class types that share the Class code # with the most general type. (See TypeError) container[code] = obj if hasattr(obj, 'pg_code'): # If there's a PostgreSQL version of the code, # map it as well for older servers. container[obj.pg_code] = obj def code_lookup( default : "The object to return when no code or class is found", container : "where to look for the object associated with the code", code : "the code to find the exception for" ): obj = container.get(code) if obj is None: obj = container.get(code[:2] + "000", default) return obj map_errors_and_warnings(sys.modules[__name__].__dict__.values()) ErrorLookup = partial(code_lookup, Error, code_to_error) WarningLookup = partial(code_lookup, Warning, code_to_warning) if __name__ == '__main__': for x in sys.argv[1:]: if x.startswith('01'): e = WarningLookup(x) else: e = ErrorLookup(x) sys.stdout.write('postgresql.exceptions.%s [%s]%s%s' %( e.__name__, e.code, os.linesep, ( e.__doc__ is not None and os.linesep.join([ ' ' + x for x in (e.__doc__).split('\n') ]) + os.linesep or '' ) ) ) ## # vim: ts=3:sw=3:noet: fe-1.1.0/postgresql/installation.py000066400000000000000000000151101203372773200173650ustar00rootroot00000000000000## # .installation ## """ Collect and access PostgreSQL installation information. """ import sys import os import os.path import subprocess import io import errno from itertools import cycle, chain from operator import itemgetter from .python.os import find_executable, close_fds, platform_exe from . import versionstring from . import api as pg_api from . import string as pg_str # Get the output from the given command. 
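# For example (illustrative): get_command_output(('pg_config',), 'bindir')
# executes "pg_config --bindir", returning its decoded stdout, or None when
# the command exits with a non-zero status.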
# *args are transformed into "long options", '--' + x def get_command_output(exe, *args): pa = list(exe) + [ '--' + x.strip() for x in args if x is not None ] p = subprocess.Popen(pa, close_fds = close_fds, stdout = subprocess.PIPE, stderr = subprocess.PIPE, stdin = subprocess.PIPE, shell = False ) p.stdin.close() p.stderr.close() while True: try: rv = p.wait() break except OSError as e: if e.errno != errno.EINTR: raise if rv != 0: return None with p.stdout, io.TextIOWrapper(p.stdout) as txt: return txt.read() def pg_config_dictionary(*pg_config_path): """ Create a dictionary of the information available in the given pg_config_path. This provides a one-shot solution to fetching information from the pg_config binary. Returns a dictionary object. """ default_output = get_command_output(pg_config_path) if default_output is not None: d = {} for x in default_output.splitlines(): if not x or x.isspace() or x.find('=') == -1: continue k, v = x.split('=', 1) # keep it semi-consistent with instance d[k.lower().strip()] = v.strip() return d # Support for 8.0 pg_config and earlier. # This requires three invocations of pg_config: # First --help, to get the -- options available, # Second, all the -- options except version. # Third, --version as it appears to be exclusive in some cases. opt = [] for l in get_command_output(pg_config_path, 'help').splitlines(): dash_pos = l.find('--') if dash_pos == -1: continue sp_pos = l.find(' ', dash_pos) # the dashes are added by the call command opt.append(l[dash_pos+2:sp_pos]) if 'help' in opt: opt.remove('help') if 'version' in opt: opt.remove('version') d=dict(zip(opt, get_command_output(pg_config_path, *opt).splitlines())) d['version'] = get_command_output(pg_config_path, 'version').strip() return d ## # Build a key-value pair list of the configure options. # If the item is quoted, mind the quotes. def parse_configure_options(confopt, quotes = '\'"', dash_and_quotes = '-\'"'): # This is not a robust solution, but it will usually work. # Chances are that there is a quote at the beginning of this string. # However, in the windows pg_config.exe, this appears to be absent. if confopt[0:1] in quotes: # quote at the beginning. assume it's used consistently. quote = confopt[0:1] elif confopt[-1:] in quotes: # quote at the end? quote = confopt[-1] else: # fallback to something. :( quote = "'" ## # This is using the wrong kind of split, but the pg_config # output has been consistent enough for this to work. parts = pg_str.split_using(confopt, quote, sep = ' ') qq = quote * 2 for x in parts: if qq in x: # singularize the quotes x = x.replace(qq, quote) # remove the quotes around '--' from option. # if it splits once, the '1' index will # be `True`, indicating that the flag was given, but # was not given a value. kv = x.strip(dash_and_quotes).split('=', 1) + [True] key = kv[0].replace('-','_') # Ignore empty keys. if key: yield (key, kv[1]) def default_pg_config(execname = 'pg_config', envkey = 'PGINSTALLATION'): """ Get the default `pg_config` executable on the system. If 'PGINSTALLATION' is in the environment, use it. Otherwise, look through the system's PATH environment. """ pg_config_path = os.environ.get(envkey) if pg_config_path: # Trust PGINSTALLATION. return platform_exe(pg_config_path) return find_executable(execname) class Installation(pg_api.Installation): """ Class providing a Python interface to PostgreSQL installation information. """ version = None version_info = None type = None configure_options = None #: The pg_config information dictionary. 
info = None pg_executables = ( 'pg_config', 'psql', 'initdb', 'pg_resetxlog', 'pg_controldata', 'clusterdb', 'pg_ctl', 'pg_dump', 'pg_dumpall', 'postgres', 'postmaster', 'reindexdb', 'vacuumdb', 'ipcclean', 'createdb', 'ecpg', 'createuser', 'createlang', 'droplang', 'dropuser', 'pg_restore', ) pg_libraries = ( 'libpq', 'libecpg', 'libpgtypes', 'libecpg_compat', ) pg_directories = ( 'bindir', 'docdir', 'includedir', 'pkgincludedir', 'includedir_server', 'libdir', 'pkglibdir', 'localedir', 'mandir', 'sharedir', 'sysconfdir', ) def _e_metas(self): l = list(self.configure_options.items()) l.sort(key = itemgetter(0)) yield ('version', self.version) if l: yield ('configure_options', (os.linesep).join(( k if v is True else k + '=' + v for k,v in l )) ) def __repr__(self, format = "{mod}.{name}({info!r})".format): return format( mod = type(self).__module__, name = type(self).__name__, info = self.info ) def __init__(self, info : dict): """ Initialize the Installation using the given information dictionary. """ self.info = info self.version = self.info["version"] self.type, vs = self.version.split() self.version_info = versionstring.normalize(versionstring.split(vs)) self.configure_options = dict( parse_configure_options(self.info.get('configure', '')) ) # collect the paths in a dictionary first self.paths = dict() exists = os.path.exists join = os.path.join for k in self.pg_directories: self.paths[k] = self.info.get(k) # find all the PG executables that exist for the installation. bindir_path = self.info.get('bindir') if bindir_path is None: self.paths.update(zip(self.pg_executables, cycle((None,)))) else: for k in self.pg_executables: path = platform_exe(join(bindir_path, k)) if exists(path): self.paths[k] = path else: self.paths[k] = None self.__dict__.update(self.paths) @property def ssl(self): """ Whether the installation was compiled with SSL support. """ return 'with_openssl' in self.configure_options def default(typ = Installation): """ Get the default Installation. Uses default_pg_config() to identify the executable. """ path = default_pg_config() if path is None: return None return typ(pg_config_dictionary(path)) if __name__ == '__main__': if sys.argv[1:]: d = pg_config_dictionary(sys.argv[1]) i = Installation(d) else: i = default() from .python.element import format_element print(format_element(i)) fe-1.1.0/postgresql/iri.py000066400000000000000000000114261203372773200154550ustar00rootroot00000000000000## # .iri ## """ Parse and serialize PQ IRIs. 
PQ IRIs take the form:: pq://user:pass@host:port/database?setting=value&setting2=value2#public,othernamespace IPv6 is supported via the standard representation:: pq://[::1]:5432/database Driver Parameters: pq://user@host/?[driver_param]=value&[other_param]=value?setting=val """ from .resolved import riparse as ri from .string import split_ident from operator import itemgetter get0 = itemgetter(0) del itemgetter import re escape_path_re = re.compile('[%s]' %(re.escape(ri.unescaped + ','),)) def structure(d, fieldproc = ri.unescape): 'Create a clientparams dictionary from a parsed RI' if d.get('scheme', 'pq').lower() != 'pq': raise ValueError("PQ-IRI scheme is not 'pq'") cpd = { k : fieldproc(v) for k, v in d.items() if k not in ('path', 'fragment', 'query', 'host', 'scheme') } path = d.get('path') frag = d.get('fragment') query = d.get('query') host = d.get('host') if host is not None: if host.startswith('[') and host.endswith(']'): host = host[1:-1] if host.startswith('unix:'): cpd['unix'] = host[len('unix:'):].replace(':','/') else: cpd['host'] = host else: cpd['host'] = fieldproc(host) if path: # Only state the database field's existence if the first path is non-empty. if path[0]: cpd['database'] = path[0] path = path[1:] if path: cpd['path'] = path settings = {} if query: if hasattr(query, 'items'): qiter = query.items() else: qiter = query for k, v in qiter: if k.startswith('[') and k.endswith(']'): k = k[1:-1] if k != 'settings' and k not in cpd: cpd[fieldproc(k)] = fieldproc(v) elif k: settings[fieldproc(k)] = fieldproc(v) # else: ignore empty query keys if frag: settings['search_path'] = [ fieldproc(x) for x in frag.split(',') ] if settings: cpd['settings'] = settings return cpd def construct_path(x, re = escape_path_re): """ Join a path sequence using ',' and escaping ',' in the pieces. """ return ','.join((re.sub(ri.re_pct_encode, y) for y in x)) def construct(x, obscure_password = False): 'Construct a RI dictionary from a clientparams dictionary' # the rather exhaustive settings choreography is due to # a desire to allow the search_path to be appended in the fragment settings = x.get('settings') no_path_settings = None search_path = None if settings: if isinstance(settings, dict): siter = settings.items() search_path = settings.get('search_path') else: siter = list(settings) search_path = [(k,v) for k,v in siter if k == 'search_path'] search_path.insert(0, (None,None)) search_path = search_path[-1][1] no_path_settings = [(k,v) for k,v in siter if k != 'search_path'] if not no_path_settings: no_path_settings = None # It could be a string search_path, split if it is. if search_path is not None and isinstance(search_path, str): search_path = split_ident(search_path, sep = ',') port = None if 'unix' in x: host = '[unix:' + x['unix'].replace('/',':') + ']' # ignore port.. it's a mis-config. 
elif 'host' in x: host = x['host'] if ':' in host: host = '[' + host + ']' port = x.get('port') else: host = None port = x.get('port') path = [] if 'database' in x: path.append(x['database']) if 'path' in x: path.extend(x['path'] or ()) password = x.get('password') if obscure_password and password is not None: password = '***' driver_params = list({ '[' + k + ']' : str(v) for k,v in x.items() if k not in ( 'user', 'password', 'port', 'database', 'ssl', 'path', 'host', 'unix', 'ipv','settings' ) }.items()) driver_params.sort(key=get0) return ( 'pqs' if x.get('ssl', False) is True else 'pq', # netloc: user:pass@host[:port] ri.unsplit_netloc(( x.get('user'), password, host, None if 'port' not in x else str(x['port']) )), None if not path else '/'.join([ ri.escape_path_re.sub(ri.re_pct_encode, path_comp) for path_comp in path ]), (ri.construct_query(driver_params) if driver_params else None) if no_path_settings is None else ( ri.construct_query( driver_params + no_path_settings ) ), None if search_path is None else construct_path(search_path), ) def parse(s, fieldproc = ri.unescape): 'Parse a Postgres IRI into a dictionary object' return structure( # In ri.parse, don't unescape the parsed values as our sub-structure # uses the escape mechanism in IRIs to specify literal separator # characters. ri.parse(s, fieldproc = str), fieldproc = fieldproc ) def serialize(x, obscure_password = False): 'Return a Postgres IRI from a dictionary object.' return ri.unsplit(construct(x, obscure_password = obscure_password)) if __name__ == '__main__': import sys for x in sys.argv[1:]: print("{src} -> {parsed!r} -> {serial}".format( src = x, parsed = parse(x), serial = serialize(parse(x)) )) fe-1.1.0/postgresql/lib/000077500000000000000000000000001203372773200150625ustar00rootroot00000000000000fe-1.1.0/postgresql/lib/__init__.py000066400000000000000000000264511203372773200172030ustar00rootroot00000000000000## # .lib - libraries; manage SQL outside of Python. ## """ PostgreSQL statement and object libraries. The purpose of a library is to provide a means to manage a mapping of symbols to database operations or objects. These operations can be simple statements, procedures, or something more complex. Libraries are intended to allow the programmer to isolate and manage SQL outside of a system's code-flow. They provide a means to construct the basic Python interfaces to a PostgreSQL-based application. """ import io import os.path from types import ModuleType from abc import abstractmethod, abstractproperty from ..python.element import Element, ElementSet from .. import api as pg_api from .. import sys as pg_sys from .. import exceptions as pg_exc from ..python.itertools import find from itertools import chain try: libdir = os.path.abspath(os.path.dirname(__file__)) except NameError: pass else: if os.path.exists(libdir): pg_sys.libpath.insert(0, libdir) del libdir __all__ = [ 'Library', 'SymbolCollection', 'ILF', 'Symbol', 'Binding', 'BoundSymbol', 'find_libsql', 'load', ] class Symbol(Element): """ An annotated SQL statement string. The annotations describe how the statement should be used. """ __slots__ = ( 'library', 'source', 'name', 'method', 'type', 'parameters', ) _e_label = 'SYMBOL' _e_factors = ('library', 'source',) # The statement execution methods; symbols allow this to be specified # in order for a default method to be selected. 
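# For example, a symbol declared with method 'first' is executed through the
# prepared statement's .first() method by BoundSymbol below.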
execution_methods = { 'first', 'rows', 'chunks', 'declare', 'load_chunks', 'load_rows', 'column', } def _e_metas(self): yield (None, self.name) def __init__(self, library, source, name = None, method = None, type = None, parameters = None, reference = False, ): self.library = library self.source = source self.name = name if method in (None, '', 'all'): method = None elif method not in self.execution_methods: raise ValueError("unknown execution method: " + repr(method)) self.method = method self.type = type self.parameters = parameters self.reference = reference def __str__(self): """ Provide the source of the query's symbol. """ # Explicitly run str() on source as it is expected that a # given symbol's source may be generated. return str(self.source) class Library(Element): """ A library is a mapping of symbol names to `postgresql.lib.Symbol` instances. """ _e_label = 'LIBRARY' _e_factors = () @abstractproperty def address(self) -> str: """ A string indicating the source of the symbols. """ @abstractproperty def name(self) -> str: """ The name to bind the library as. Should be an identifier. """ @abstractproperty def preload(self) -> {str,}: """ A set of symbols that should be prepared when the library is bound. """ @abstractmethod def symbols(self) -> [str]: """ Iterable of symbol names provided by the library. """ @abstractmethod def get_symbol(self, name) -> (Symbol, [Symbol]): """ Return the symbol with the given name. """ class SymbolCollection(Library): """ Explicitly composed library. (Symbols passed into __init__) """ preload = None symtypes = ( 'static', 'preload', 'const', 'proc', 'transient', ) def __init__(self, symbols, preface = None): """ Given an iterable of (name, (isref, symtype, symexe, doc, sql)) pairs, create a symbol collection. """ self.preface = preface self._address = None self._name = None s = self.symbolsd = {} self.preload = set() for name, (isref, typ, exe, doc, query) in symbols: if typ and typ not in self.symtypes: raise ValueError( "symbol %r has an invalid type: %r" %(name, typ) ) if typ == 'preload': self.preload.add(name) typ = None elif typ == 'proc': pass SYM = Symbol(self, query, name = name, method = exe, type = typ, reference = isref ) s[name] = SYM class ILF(SymbolCollection): 'INI Library Format' def _e_metas(self): yield (None, self._address or 'ILF') def __repr__(self): return self.__class__.__module__ + '.' + self.__class__.__name__ + '.open(' + repr(self.address) + ')' @property def name(self): return self._name @property def address(self): return self._address def get_symbol(self, name): return self.symbolsd.get(name) def symbols(self): return self.symbolsd.keys() @classmethod def from_lines(typ, lines): """ Create an anonymous ILF library from a sequence of lines. """ prev = '' curid = None curblock = [] blocks = [] for line in lines: l = line.strip() if l.startswith('[') and l.endswith(']'): blocks.append((curid, curblock)) curid = line curblock = [] elif line.startswith('*[') and ']' in line: ref, rest = line.split(']', 1) # strip the leading '*[' ref = ref[2:] # dereferencing will take place later. curblock.append((ref, rest)) else: curblock.append(line) blocks.append((curid, curblock)) preface = ''.join(blocks.pop(0)[1]) syms = [] for symdesc, block in blocks: # symbol name # symbol type # how to execute symbol name, styp, exe, *_ = (tuple( symdesc.strip().strip('[]').split(':') ) + (None, None)) doc = '' endofcomment = 0 # resolve any symbol references; only one per line. 
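# For example, in libsys.sql the line "*[lookup_procedures] WHERE pg_proc.oid = $1"
# expands to the source of the lookup_procedures symbol followed by that WHERE clause.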
block = [ x if x.__class__ is not tuple else ( find(reversed(syms), lambda y: y[0] == x[0])[1][-1] + x[1] ) for x in block ] for x in block: if x.startswith('-- '): doc += x[3:] else: break endofcomment += 1 query = ''.join(block[endofcomment:]) if styp == 'proc': query = query.strip() if name.startswith('&'): name = name[1:] isref = True else: isref = False syms.append((name, (isref, styp, exe, doc, query))) return typ(syms, preface = preface) @classmethod def open(typ, filepath, *args, **kw): """ Create a named ILF library from a file path. """ with io.open(filepath, *args, **kw) as fp: r = typ.from_lines(fp) r._address = os.path.abspath(filepath) bn = os.path.basename(filepath) if bn.startswith('lib') and bn.endswith('.sql'): r._name = bn[3:-4] or None return r class BoundSymbol(object): """ A symbol bound to a database(connection). """ def __init__(self, symbol, database): if symbol.type == 'proc': proc = database.proc(symbol) self.method = proc.__call__ self.object = proc else: ps = database.prepare(symbol) m = symbol.method if m is None: self.method = ps.__call__ else: self.method = getattr(ps, m) self.object = ps def __call__(self, *args, **kw): return self.method(*args, **kw) class BoundReference(object): """ A symbol bound to a database whose results make up the source of a symbol that will be created upon the execution of this symbol. A reference to a symbol. """ def __init__(self, symbol, database): self.symbol = symbol self.database = database self.method = database.prepare(symbol).chunks def __call__(self, *args, **kw): chunks = chain.from_iterable(self.method(*args, **kw)) # Join columns with a space, and rows with a newline. src = '\n'.join([' '.join(row) for row in chunks]) return BoundSymbol( Symbol( self.symbol.library, src, name = self.symbol.name, method = self.symbol.method, type = self.symbol.type, parameters = self.symbol.parameters, reference = False, ), self.database, ) class Binding(object): """ Library bound to a database(connection). """ def __init__(self, database, library): self.__dict__.update({ '__database__' : database, '__symbol_library__' : library, '__symbol_cache__' : {}, }) for x in library.preload: # cache all preloaded symbols. getattr(self, x) def __repr__(self): return '<Binding: %r on %r>' %( self.__symbol_library__.name, self.__database__ ) def __dir__(self): return dir(super()) + list(self.__symbol_library__.symbols()) def __getattr__(self, name): """ Return a BoundSymbol against the Binding's database with the symbol named ``name`` in the Binding's library. """ d = self.__dict__ s = d['__symbol_cache__'] db = d['__database__'] lib = d['__symbol_library__'] bs = s.get(name) if bs is None: # No symbol cached with that name. # Everything is crammed in here because # we do *not* want methods on this object. # The namespace is primarily reserved for symbols. sym = lib.get_symbol(name) if sym is None: raise AttributeError( "symbol %r does not exist in library %r" %( name, lib.address ) ) if sym.reference: # Reference. bs = BoundReference(sym, db) if sym.type == 'const': # Constant Reference means a BoundSymbol. bs = bs() if sym.type != 'transient': s[name] = bs else: if not isinstance(sym, Symbol): # subjective symbol... 
sym = sym(db) if not isinstance(sym, Symbol): raise TypeError( "callable symbol, %r, did not produce " \ "Symbol instance" %(name,) ) if sym.type == 'const': r = BoundSymbol(sym, db)() if sym.method in ('chunks', 'rows', 'column'): # resolve the iterator r = list(r) bs = s[name] = r else: bs = BoundSymbol(sym, db) if sym.type != 'transient': s[name] = bs return bs class Category(pg_api.Category): """ Library-based Category. """ _e_factors = ('libraries',) def _e_metas(self): yield ('aliases', {k.name: v for k, v in self.aliases.items()}) def __init__(self, *libs, **named_libs): sl = set(libs) nl = set(named_libs.values()) self._direct = sl self.libraries = ElementSet(sl | nl) self.aliases = {} # lib -> [alias-1, alias-2, ..., alias-n] for k, v in named_libs.items(): d = self.aliases.setdefault(v, []) d.append(k) def __call__(self, database): for l in self.libraries: names = list(self.aliases.get(l, ())) if l in self._direct: names.append(l.name) B = Binding(database, l) for n in names: if hasattr(database, n): raise AttributeError("attribute already exists: " + n) setattr(database, n, B) def find_libsql(libname, paths, prefix = 'lib', suffix = '.sql'): """ Given the base library name, `libname`, look for a file named "<prefix><libname><suffix>" in each directory(`paths`). All finds will be yielded out. """ lib = prefix + libname + suffix for p in paths: p = os.path.join(p, lib) if os.path.exists(p): yield p def load(libref): """ Given a reference to a symbol library, instantiate the Library instance. Currently this function accepts: * `str` objects as absolute paths or relative to sys.libpath. * Module objects. """ if isinstance(libref, ModuleType): if hasattr(libref, '__lib'): lib = getattr(libref, '__lib') else: lib = ModuleLibrary(libref) setattr(libref, '__lib', lib) elif isinstance(libref, str): try: if os.path.sep in libref: # sep in libref? it's treated as a path. lib = ILF.open(libref) else: # first one wins. for x in find_libsql(libref, pg_sys.libpath): break else: raise pg_exc.LoadError("library %r not in postgresql.sys.libpath" % (libref,)) lib = ILF.open(x) except pg_exc.LoadError: raise except Exception: # any exception is a load error. raise pg_exc.LoadError("failed to load ILF, " + repr(libref)) else: raise TypeError("load takes a module or str, given " + type(libref).__name__) return lib sys = load('sys') __docformat__ = 'reStructuredText' fe-1.1.0/postgresql/lib/libsys.sql000066400000000000000000000212531203372773200171130ustar00rootroot00000000000000## # libsys.sql - SQL to support driver features ## -- Queries for dealing with the PostgreSQL catalogs for supporting the driver. [lookup_type::first] SELECT ns.nspname as namespace, bt.typname, bt.typtype, bt.typlen, bt.typelem, bt.typrelid, ae.oid AS ae_typid, ae.typreceive::oid != 0 AS ae_hasbin_input, ae.typsend::oid != 0 AS ae_hasbin_output FROM pg_catalog.pg_type bt LEFT JOIN pg_type ae ON ( bt.typlen = -1 AND bt.typelem != 0 AND bt.typelem = ae.oid ) LEFT JOIN pg_catalog.pg_namespace ns ON (ns.oid = bt.typnamespace) WHERE bt.oid = $1 [lookup_composite] -- Get the type Oid and name of the attributes in `attnum` order. 
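-- (Headers in this file use the ILF form [name:type:method]: an empty type
-- yields a cached prepared statement, and the method names how it is
-- executed; e.g., [lookup_type::first] above runs via .first().)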
SELECT CAST(atttypid AS oid) AS atttypid, CAST(attname AS text) AS attname, tt.typtype = 'd' AS is_domain FROM pg_catalog.pg_type t LEFT JOIN pg_catalog.pg_attribute a ON (t.typrelid = a.attrelid) LEFT JOIN pg_type tt ON (a.atttypid = tt.oid) WHERE attrelid = $1 AND NOT attisdropped AND attnum > 0 ORDER BY attnum ASC [lookup_basetype_recursive] SELECT (CASE WHEN tt.typtype = 'd' THEN (WITH RECURSIVE typehierarchy(typid, depth) AS ( SELECT t2.typbasetype, 0 FROM pg_type t2 WHERE t2.oid = tt.oid UNION ALL SELECT t2.typbasetype, th.depth + 1 FROM pg_type t2, typehierarchy th WHERE th.typid = t2.oid AND t2.typbasetype != 0 ) SELECT typid FROM typehierarchy ORDER BY depth DESC LIMIT 1) ELSE NULL END) AS basetypid FROM pg_catalog.pg_type tt WHERE tt.oid = $1 [lookup_basetype] SELECT tt.typbasetype FROM pg_catalog.pg_type tt WHERE tt.oid = $1 [lookup_procedures] SELECT pg_proc.oid, pg_proc.*, pg_proc.oid::regproc AS _proid, pg_proc.oid::regprocedure as procedure_id, COALESCE(string_to_array(trim(replace(textin(oidvectorout(proargtypes)), ',', ' '), '{}'), ' ')::oid[], '{}'::oid[]) AS proargtypes, (pg_type.oid = 'record'::regtype or pg_type.typtype = 'c') AS composite FROM pg_catalog.pg_proc LEFT JOIN pg_catalog.pg_type ON ( pg_proc.prorettype = pg_type.oid ) [lookup_procedure_oid::first] *[lookup_procedures] WHERE pg_proc.oid = $1 [lookup_procedure_rp::first] *[lookup_procedures] WHERE pg_proc.oid = regprocedurein($1) [lookup_prepared_xacts::first] SELECT COALESCE(ARRAY( SELECT gid::text FROM pg_catalog.pg_prepared_xacts WHERE database = current_database() AND ( owner = $1::text OR ( (SELECT rolsuper FROM pg_roles WHERE rolname = $1::text) ) ) ORDER BY prepared ASC ), ('{}'::text[])) [regtypes::column] SELECT pg_catalog.regtypein(pg_catalog.textout(($1::text[])[i]))::oid AS typoid FROM pg_catalog.generate_series(1, array_upper($1::text[], 1)) AS g(i) [xact_is_prepared::first] SELECT TRUE FROM pg_catalog.pg_prepared_xacts WHERE gid::text = $1 [get_statement_source::first] SELECT statement FROM pg_catalog.pg_prepared_statements WHERE name = $1 [setting_get] SELECT setting FROM pg_catalog.pg_settings WHERE name = $1 [setting_set::first] SELECT pg_catalog.set_config($1, $2, false) [setting_len::first] SELECT count(*) FROM pg_catalog.pg_settings [setting_item] SELECT name, setting FROM pg_catalog.pg_settings WHERE name = $1 [setting_mget] SELECT name, setting FROM pg_catalog.pg_settings WHERE name = ANY ($1) [setting_keys] SELECT name FROM pg_catalog.pg_settings ORDER BY name [setting_values] SELECT setting FROM pg_catalog.pg_settings ORDER BY name [setting_items] SELECT name, setting FROM pg_catalog.pg_settings ORDER BY name [setting_update] SELECT ($1::text[][])[i][1] AS key, pg_catalog.set_config(($1::text[][])[i][1], $1[i][2], false) AS value FROM pg_catalog.generate_series(1, array_upper(($1::text[][]), 1)) g(i) [startup_data:transient:first] -- 8.2 and greater SELECT pg_catalog.version()::text AS version, backend_start::text, client_addr::text, client_port::int FROM pg_catalog.pg_stat_activity WHERE procpid = pg_catalog.pg_backend_pid() UNION ALL SELECT pg_catalog.version()::text AS version, NULL::text AS backend_start, NULL::text AS client_addr, NULL::int AS client_port LIMIT 1; [startup_data_92:transient:first] -- 9.2 and greater SELECT pg_catalog.version()::text AS version, backend_start::text, client_addr::text, client_port::int FROM pg_catalog.pg_stat_activity WHERE pid = pg_catalog.pg_backend_pid() UNION ALL SELECT pg_catalog.version()::text AS version, NULL::text AS backend_start, 
NULL::text AS client_addr, NULL::int AS client_port LIMIT 1; [startup_data_no_start:transient:first] -- 8.1 only, but is unused as often the backend's activity row is not -- immediately present. SELECT pg_catalog.version()::text AS version, NULL::text AS backend_start, client_addr::text, client_port::int FROM pg_catalog.pg_stat_activity WHERE procpid = pg_catalog.pg_backend_pid(); [startup_data_only_version:transient:first] -- In 8.0, there's nothing there. SELECT pg_catalog.version()::text AS version, NULL::text AS backend_start, NULL::text AS client_addr, NULL::int AS client_port; [terminate_backends:transient:column] -- Terminate all except mine. SELECT procpid, pg_catalog.pg_terminate_backend(procpid) FROM pg_catalog.pg_stat_activity WHERE procpid != pg_catalog.pg_backend_pid() [terminate_backends_92:transient:column] -- Terminate all except mine. 9.2 and later SELECT pid, pg_catalog.pg_terminate_backend(pid) FROM pg_catalog.pg_stat_activity WHERE pid != pg_catalog.pg_backend_pid() [cancel_backends:transient:column] -- Cancel all except mine. SELECT procpid, pg_catalog.pg_cancel_backend(procpid) FROM pg_catalog.pg_stat_activity WHERE procpid != pg_catalog.pg_backend_pid() [cancel_backends_92:transient:column] -- Cancel all except mine. 9.2 and later SELECT pid, pg_catalog.pg_cancel_backend(pid) FROM pg_catalog.pg_stat_activity WHERE pid != pg_catalog.pg_backend_pid() [sizeof_db:transient:first] SELECT pg_catalog.pg_database_size(current_database())::bigint [sizeof_cluster:transient:first] SELECT SUM(pg_catalog.pg_database_size(datname))::bigint FROM pg_database [sizeof_relation::first] SELECT pg_catalog.pg_relation_size($1::text)::bigint [pg_reload_conf:transient:] SELECT pg_reload_conf() [languages:transient:column] SELECT lanname FROM pg_catalog.pg_language [listening_channels:transient:column] SELECT channel FROM pg_catalog.pg_listening_channels() AS x(channel) [listening_relations:transient:column] -- listening_relations: old version of listening_channels. 
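-- Used by Connection.listening_channels() on servers at 8.4 and below.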
SELECT relname as channel FROM pg_catalog.pg_listener WHERE listenerpid = pg_catalog.pg_backend_pid(); [notify::first] -- 9.0 and greater SELECT COUNT(pg_catalog.pg_notify(($1::text[])[i][1], $1[i][2]) IS NULL) FROM pg_catalog.generate_series(1, array_upper($1, 1)) AS g(i) [release_advisory_shared] SELECT CASE WHEN ($2::int8[])[i] IS NULL THEN pg_catalog.pg_advisory_unlock_shared(($1::int4[])[i][1], $1[i][2]) ELSE pg_catalog.pg_advisory_unlock_shared($2[i]) END AS released FROM pg_catalog.generate_series(1, COALESCE(array_upper($2::int8[], 1), array_upper($1::int4[], 1))) AS g(i) [acquire_advisory_shared] SELECT COUNT(( CASE WHEN ($2::int8[])[i] IS NULL THEN pg_catalog.pg_advisory_lock_shared(($1::int4[])[i][1], $1[i][2]) ELSE pg_catalog.pg_advisory_lock_shared($2[i]) END ) IS NULL) AS acquired FROM pg_catalog.generate_series(1, COALESCE(array_upper($2::int8[], 1), array_upper($1::int4[], 1))) AS g(i) [try_advisory_shared] SELECT CASE WHEN ($2::int8[])[i] IS NULL THEN pg_catalog.pg_try_advisory_lock_shared(($1::int4[])[i][1], $1[i][2]) ELSE pg_catalog.pg_try_advisory_lock_shared($2[i]) END AS acquired FROM pg_catalog.generate_series(1, COALESCE(array_upper($2::int8[], 1), array_upper($1::int4[], 1))) AS g(i) [release_advisory_exclusive] SELECT CASE WHEN ($2::int8[])[i] IS NULL THEN pg_catalog.pg_advisory_unlock(($1::int4[])[i][1], $1[i][2]) ELSE pg_catalog.pg_advisory_unlock($2[i]) END AS released FROM pg_catalog.generate_series(1, COALESCE(array_upper($2::int8[], 1), array_upper($1::int4[], 1))) AS g(i) [acquire_advisory_exclusive] SELECT COUNT(( CASE WHEN ($2::int8[])[i] IS NULL THEN pg_catalog.pg_advisory_lock(($1::int4[])[i][1], $1[i][2]) ELSE pg_catalog.pg_advisory_lock($2[i]) END ) IS NULL) AS acquired -- Guaranteed to be acquired once complete. FROM pg_catalog.generate_series(1, COALESCE(array_upper($2::int8[], 1), array_upper($1::int4[], 1))) AS g(i) [try_advisory_exclusive] SELECT CASE WHEN ($2::int8[])[i] IS NULL THEN pg_catalog.pg_try_advisory_lock(($1::int4[])[i][1], $1[i][2]) ELSE pg_catalog.pg_try_advisory_lock($2[i]) END AS acquired FROM pg_catalog.generate_series(1, COALESCE(array_upper($2::int8[], 1), array_upper($1::int4[], 1))) AS g(i) fe-1.1.0/postgresql/message.py000066400000000000000000000074161203372773200163220ustar00rootroot00000000000000## # .message - PostgreSQL message representation ## from operator import itemgetter from .python.element import prime_factor # Final msghook called exists at .sys.msghook from . import sys as pg_sys from .api import Message class Message(Message): """ A message emitted by PostgreSQL. This element is universal, so `postgresql.api.Message` is a complete implementation for representing a message. Any interface should produce these objects. 
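
	For illustration only (the field values here are made up), a message
	can be constructed and emitted directly:

	>>> from postgresql.message import Message
	>>> m = Message("relation does not exist", code = '42P01', source = 'SERVER')
	>>> m.emit()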
""" _e_label = property(lambda x: getattr(x, 'details').get('severity', 'MESSAGE')) _e_factors = ('creator',) def _e_metas(self, get0 = itemgetter(0)): yield (None, self.message) if self.code and self.code != "00000": yield ('CODE', self.code) locstr = self.location_string if locstr: yield ('LOCATION', locstr + ' from ' + self.source) else: yield ('LOCATION', self.source) for k, v in sorted(self.details.items(), key = get0): if k not in self.standard_detail_coverage: yield (k.upper(), str(v)) source = 'SERVER' code = '00000' message = None details = None severities = ( 'DEBUG', 'INFO', 'NOTICE', 'WARNING', 'ERROR', 'FATAL', 'PANIC', ) sources = ( 'SERVER', 'CLIENT', ) def isconsistent(self, other): """ Return `True` if the all the fields of the message in `self` are equivalent to the fields in `other`. """ if not isinstance(other, self.__class__): return False # creator is contextual information return ( self.code == other.code and \ self.message == other.message and \ self.details == other.details and \ self.source == other.source ) def __init__(self, message : "The primary information of the message", code : "Message code to attach (SQL state)" = None, details : "additional information associated with the message" = {}, source : "Which side generated the message(SERVER, CLIENT)" = None, creator : "The interface element that called for instantiation" = None, ): self.message = message self.details = details self.creator = creator if code is not None and self.code != code: self.code = code if source is not None and self.source != source: self.source = source def __repr__(self): return "{mod}.{typname}({message!r}{code}{details}{source}{creator})".format( mod = self.__module__, typname = self.__class__.__name__, message = self.message, code = ( "" if self.code == type(self).code else ", code = " + repr(self.code) ), details = ( "" if not self.details else ", details = " + repr(self.details) ), source = ( "" if self.source is None else ", source = " + repr(self.source) ), creator = ( "" if self.creator is None else ", creator = " + repr(self.creator) ) ) @property def location_string(self): """ A single line representation of the 'file', 'line', and 'function' keys in the `details` dictionary. """ details = self.details loc = [ details.get(k, '?') for k in ('file', 'line', 'function') ] return ( "" if loc == ['?', '?', '?'] else "File {0!r}, "\ "line {1!s}, in {2!s}".format(*loc) ) # keys to filter in .details standard_detail_coverage = frozenset(['message', 'severity', 'file', 'function', 'line',]) def emit(self, starting_point = None): """ Take the given message object and hand it to all the primary factors(creator) with a msghook callable. """ if starting_point is not None: f = starting_point else: f = self.creator while f is not None: if getattr(f, 'msghook', None) is not None: if f.msghook(self): # the trap returned a nonzero value, # so don't continue raising. (like with's __exit__) return f f = prime_factor(f) if f: f = f[1] # if the next primary factor is without a raise or does not exist, # send the message to postgresql.sys.msghook pg_sys.msghook(self) fe-1.1.0/postgresql/notifyman.py000066400000000000000000000152701203372773200166770ustar00rootroot00000000000000## # .notifyman - Receive and manage NOTIFY events. ## """ Notification Management Tools Primarily this module houses the `NotificationManager` class which provides an iterator for a NOTIFY event loop against a set of connections. >>> import postgresql >>> db = postgresql.open(...) 
>>> from postgresql.notifyman import NotificationManager
>>> nm = NotificationManager(db, timeout = 10) # idle events every 10 seconds
>>> for x in nm:
...  if x is None:
...   # idle event
...   ...
...  db, notifies = x
...  for channel, payload, pid in notifies:
...   ...
"""
from time import time
from select import select
from itertools import chain

class NotificationManager(object):
	"""
	A class for managing the asynchronous notifications received by a set of
	connections.

	Instances provide the iterator for an event loop that responds to NOTIFYs
	received by the connections being watched.

	There is no thread safety, so when a connection is being managed, it should
	not be used concurrently in other threads while being managed.
	"""
	__slots__ = (
		'connections', 'garbage', 'incoming',
		'timeout', '_last_time', '_pulled',
	)

	def __init__(self, *connections, timeout = None):
		self.settimeout(timeout)
		self.connections = set(connections)
		# Connections that failed.
		self.garbage = set()
		# Used to store NOTIFYs consumed from the connections
		self.incoming = None
		self._last_time = None
		# connection -> sequence of NOTIFYs
		self._pulled = dict()

	# Check the wire *and* wait for new messages.
	def _wait_on_wires(self, time = time, select = select):
		if self.timeout == 0:
			# We're polling.
			max_duration = 0
		else:
			# If timeout is None, we don't issue idle events, but
			# we still cycle in case the timeout is changed.
			if self._last_time is not None:
				max_duration = (self.timeout or 10) - (time() - self._last_time)
				if max_duration < 0:
					max_duration = 0
			else:
				self._last_time = time()
				max_duration = self.timeout or 10
		# Connections already marked as "bad" should not be checked.
		check = self.connections - self.garbage
		for db in check:
			if db.closed:
				self.connections.remove(db)
				self.garbage.add(db)
		check = self.connections - self.garbage
		r, w, x = select(check, (), check, max_duration)
		# Make sure the connection's _notifies get filled.
		for db in r:
			# Collect any pending events.
			try:
				# Even if db is in a failed transaction, this
				# 'null' command will succeed.
				# (only connection failures blow up)
				db.execute('')
			except Exception:
				# failed to collect notifies; put in exception list.
				# It is very unlikely that this is *not* a FATAL error.
				x.append(db)
		self.trash(x)

	def trash(self, connections):
		"""
		Remove the given connections from the set of good connections, and add
		them to the `garbage` set.

		This method can be overridden by subclasses to take a callback approach
		to connection failures.
		"""
		# Identify the bad connections.
		self.garbage.update(connections)
		self.connections.difference_update(connections)

	def queue(self, db, notifies):
		"""
		Queue the notifies for the specified connection.

		This method can be overridden by subclasses to take a callback approach
		to notification management.
		"""
		l = self._pulled.setdefault(db, list())
		l.extend(notifies)

	# Check the connection's _notifies list; just scan everything.
	def _pull_from_connections(self):
		for db in self.connections:
			if not db._notifies:
				# nothing queued up, look at the next connection
				continue
			# Pull notifies into the NotificationManager
			decode = db.typio.decode
			notifies = [
				(decode(x.channel), decode(x.payload), x.pid)
				for x in db._notifies
			]
			self.queue(db, notifies)
			del db._notifies[:len(notifies)]

	# "Append" the pulled NOTIFYs to the 'incoming' iterator.
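	# _queue_next moves the batches collected in _pulled into a fresh
	# 'incoming' iterator; if 'incoming' still holds unread batches, the
	# new ones are chained after them so no (db, notifies) pair is lost.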
def _queue_next(self): new_seqs = [] for db in self._pulled: decode = db.typio.decode new_seqs.append((db, self._pulled[db])) if new_seqs: if self.incoming: # Already have incoming; not an expected condition, # but let's compensate. self.incoming, self._pulled = chain(self.incoming, iter(new_seqs)), {} else: self.incoming, self._pulled = iter(new_seqs), {} elif self.incoming is None: # Use this to trigger the StopIteration case of zero-timeout. self.incoming, self._pulled = iter(()), {} def _timedout(self, time = time): # Idles are guaranteed to occur, but make sure that # __next__ has a chance to check the connections and the wires. now = time() if self._last_time is None: self._last_time = now elif self.timeout and now >= (self._last_time + self.timeout): # Set last_time to None in case the timeout is so low # that this condition keeps NOTIFYs from being seen. self._last_time = None # Signal timeout. return True else: # toggle back to None. self._last_time = None return False def settimeout(self, seconds): """ Set the maximum duration, in seconds, for waiting for NOTIFYs on the set of managed connections. The given `seconds` argument can be a number or `None`. A timeout of `None` means no timeout, and "idle" events will never occur. A timeout of `0` means to never wait for NOTIFYs. This has the effect of a StopIteration being raised by `__next__` when there are no more Notifications available for any of the connections in the set. "Idle" events will never occur in this situation as well. A timeout greater than zero means to emit `None` as "idle" events into the loop at the specified interval. Idle events are guaranteed to occur. """ if seconds is not None and seconds < 0: raise ValueError("cannot set timeout less than zero") self.timeout = seconds def gettimeout(self): 'Get the timeout.' return self.timeout def __iter__(self): return self def __next__(self, time = time): checked_wire = True # Loop until NOTIFY received or timeout. while True: if self.incoming is not None: try: return next(self.incoming) except StopIteration: # Nothing more in this incoming. self.incoming = None # Allow a zero timeout to be used to indicate # that there are no NOTIFYs to be read. # This can be used to poll a set of # connections instead of listening. if self.timeout == 0 or not self.connections: raise # timeout happened? yield the "idle" event. # This check **must** happen after .incoming is checked. # Never emit idle when there are real events. if self._timedout(): return None if not checked_wire and self.connections: # Nothing queued up, check connections if any. self._wait_on_wires() checked_wire = True else: checked_wire = False self._pull_from_connections() self._queue_next() fe-1.1.0/postgresql/pgpassfile.py000066400000000000000000000036001203372773200170220ustar00rootroot00000000000000## # .pgpassfile - parse and lookup passwords in a pgpassfile ## 'Parse pgpass files and subsequently lookup a password.' from os.path import exists def split(line, len = len): line = line.strip() if not line: return None r = [] continuation = False for x in line.split(':'): if continuation: # The colon was preceded by a backslash, it's part # of the last field. Substitute the trailing backslash # with the colon and append the next value. r[-1] = r[-1][:-1] + ':' + x.replace('\\\\', '\\') continuation = False else: # Even number of backslashes preceded the split. # Normal field. r.append(x.replace('\\\\', '\\')) # Determine if the next field is a continuation of this one. 
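		# An odd count means the colon that caused this split was escaped:
		# e.g. 'pa\:ss' splits into ('pa\', 'ss'), and the next pass
		# rejoins them as 'pa:ss'.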
			if (len(x) - len(x.rstrip('\\'))) % 2 == 1:
				continuation = True
	if len(r) != 5:
		# Too few or too many fields.
		return None
	return r

def parse(data):
	'produce a list of [(word, (host,port,dbname,user))] from a pgpass file object'
	return [
		(x[-1], x[0:4]) for x in [split(line) for line in data] if x
	]

def lookup_password(words, uhpd):
	"""
	lookup_password(words, (user, host, port, database)) -> password

	Where 'words' is the output from pgpass.parse()
	"""
	user, host, port, database = uhpd
	for word, (w_host, w_port, w_database, w_user) in words:
		if (w_user == '*' or w_user == user) and \
		(w_host == '*' or w_host == host) and \
		(w_port == '*' or w_port == port) and \
		(w_database == '*' or w_database == database):
			return word

def lookup_password_file(path, t):
	'like lookup_password, but takes a file path'
	with open(path) as f:
		return lookup_password(parse(f), t)

def lookup_pgpass(d, passfile, exists = exists):
	# If the password file exists, lookup the password
	# using the config's criteria.
	if exists(passfile):
		return lookup_password_file(passfile, (
			str(d['user']), str(d['host']), str(d['port']),
			str(d.get('database', d['user']))
		))
fe-1.1.0/postgresql/port/000077500000000000000000000000001203372773200153005ustar00rootroot00000000000000fe-1.1.0/postgresql/port/__init__.py000066400000000000000000000003751203372773200174160ustar00rootroot00000000000000##
# .port
##
"""
Platform-specific modules.

The subject of each module should be the feature and the target platform.
This is done to keep modules small and descriptive.

These modules are for internal use only.
"""
__docformat__ = 'reStructuredText'
fe-1.1.0/postgresql/port/_optimized/000077500000000000000000000000001203372773200174435ustar00rootroot00000000000000fe-1.1.0/postgresql/port/_optimized/README000066400000000000000000000001151203372773200203200ustar00rootroot00000000000000These are C ports of the more performance-critical parts of py-postgresql.
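
As a rough sketch of how these ports are typically consumed (the exact
import guard used elsewhere in the package may differ), pure-Python
callers can fall back when the extension was not built:

    try:
        from postgresql.port import optimized
    except ImportError:
        optimized = None  # use the pure-Python implementations instead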
fe-1.1.0/postgresql/port/_optimized/buffer.c000066400000000000000000000275031203372773200210670ustar00rootroot00000000000000/* * .port.optimized.pq_message_buffer - PQ message stream */ /* * PQ messages normally take the form {type, (size), data} */ #define include_buffer_types \ mTYPE(pq_message_stream) struct p_list { PyObject *data; /* PyBytes pushed onto the buffer */ struct p_list *next; }; struct p_place { struct p_list *list; uint32_t offset; }; struct p_buffer { PyObject_HEAD struct p_place position; struct p_list *last; /* for quick appends */ }; /* * Free the list until the given stop */ static void pl_truncate(struct p_list *pl, struct p_list *stop) { while (pl != stop) { struct p_list *next = pl->next; Py_DECREF(pl->data); free(pl); pl = next; } } /* * Reset the buffer */ static void pb_truncate(struct p_buffer *pb) { struct p_list *pl = pb->position.list; pb->position.offset = 0; pb->position.list = NULL; pb->last = NULL; pl_truncate(pl, NULL); } /* * p_truncate - truncate the buffer */ static PyObject * p_truncate(PyObject *self) { pb_truncate((struct p_buffer *) self); Py_INCREF(Py_None); return(Py_None); } static void p_dealloc(PyObject *self) { struct p_buffer *pb = ((struct p_buffer *) self); pb_truncate(pb); self->ob_type->tp_free(self); } static PyObject * p_new(PyTypeObject *subtype, PyObject *args, PyObject *kw) { static char *kwlist[] = {NULL}; struct p_buffer *pb; PyObject *rob; if (!PyArg_ParseTupleAndKeywords(args, kw, "", kwlist)) return(NULL); rob = subtype->tp_alloc(subtype, 0); pb = ((struct p_buffer *) rob); pb->last = pb->position.list = NULL; pb->position.offset = 0; return(rob); } /* * p_at_least - whether the position has at least given number of bytes. */ static char p_at_least(struct p_place *p, uint32_t amount) { int32_t current = 0; struct p_list *pl; pl = p->list; if (pl) current += PyBytes_GET_SIZE(pl->data) - p->offset; if (current >= amount) return((char) 1); if (pl) { for (pl = pl->next; pl != NULL; pl = pl->next) { current += PyBytes_GET_SIZE(pl->data); if (current >= amount) return((char) 1); } } return((char) 0); } static uint32_t p_seek(struct p_place *p, uint32_t amount) { uint32_t amount_left = amount; Py_ssize_t chunk_size; /* Can't seek after the end. */ if (!p->list || p->offset == PyBytes_GET_SIZE(p->list->data)) return(0); chunk_size = PyBytes_GET_SIZE(p->list->data) - p->offset; while (amount_left > 0) { /* * The current list item has the position. * Set the offset and break out. */ if (amount_left < chunk_size) { p->offset += amount_left; amount_left = 0; break; } amount_left -= chunk_size; p->list = p->list->next; p->offset = 0; if (p->list == NULL) break; chunk_size = PyBytes_GET_SIZE(p->list->data); } return(amount - amount_left); } static uint32_t p_memcpy(char *dst, struct p_place *p, uint32_t amount) { struct p_list *pl = p->list; uint32_t offset = p->offset; uint32_t amount_left = amount; char *src; Py_ssize_t chunk_size; /* Nothing to read */ if (pl == NULL) return(0); src = (PyBytes_AS_STRING(pl->data) + offset); chunk_size = PyBytes_GET_SIZE(pl->data) - offset; while (amount_left > 0) { uint32_t this_read = chunk_size < amount_left ? 
chunk_size : amount_left; memcpy(dst, src, this_read); dst = dst + this_read; amount_left = amount_left - this_read; pl = pl->next; if (pl == NULL) break; src = PyBytes_AS_STRING(pl->data); chunk_size = PyBytes_GET_SIZE(pl->data); } return(amount - amount_left); } static Py_ssize_t p_length(PyObject *self) { char header[5]; long msg_count = 0; uint32_t msg_length; uint32_t copy_amount = 0; struct p_buffer *pb; struct p_place p; pb = ((struct p_buffer *) self); p.list = pb->position.list; p.offset = pb->position.offset; while (p.list != NULL) { copy_amount = p_memcpy(header, &p, 5); if (copy_amount < 5) break; p_seek(&p, copy_amount); memcpy(&msg_length, header + 1, 4); msg_length = local_ntohl(msg_length); if (msg_length < 4) { PyErr_Format(PyExc_ValueError, "invalid message size '%d'", msg_length); return(-1); } msg_length -= 4; if (p_seek(&p, msg_length) < msg_length) break; ++msg_count; } return(msg_count); } static PySequenceMethods pq_ms_as_sequence = { (lenfunc) p_length, 0 }; /* * Build a tuple from the given place. */ static PyObject * p_build_tuple(struct p_place *p) { char header[5]; uint32_t msg_length; PyObject *tuple; PyObject *mt, *md; char *body = NULL; uint32_t copy_amount = 0; copy_amount = p_memcpy(header, p, 5); if (copy_amount < 5) return(NULL); p_seek(p, copy_amount); memcpy(&msg_length, header + 1, 4); msg_length = local_ntohl(msg_length); if (msg_length < 4) { PyErr_Format(PyExc_ValueError, "invalid message size '%d'", msg_length); return(NULL); } msg_length -= 4; if (!p_at_least(p, msg_length)) return(NULL); /* * Copy out the message body if we need to. */ if (msg_length > 0) { body = malloc(msg_length); if (body == NULL) { PyErr_SetString(PyExc_MemoryError, "could not allocate memory for message data"); return(NULL); } copy_amount = p_memcpy(body, p, msg_length); if (copy_amount != msg_length) { free(body); return(NULL); } p_seek(p, copy_amount); } mt = PyTuple_GET_ITEM(message_types, (int) header[0]); if (mt == NULL) { /* * With message_types, this is nearly a can't happen. */ if (body != NULL) free(body); return(NULL); } Py_INCREF(mt); md = PyBytes_FromStringAndSize(body, (Py_ssize_t) msg_length); if (body != NULL) free(body); if (md == NULL) { Py_DECREF(mt); return(NULL); } tuple = PyTuple_New(2); if (tuple == NULL) { Py_DECREF(mt); Py_DECREF(md); } else { PyTuple_SET_ITEM(tuple, 0, mt); PyTuple_SET_ITEM(tuple, 1, md); } return(tuple); } static PyObject * p_write(PyObject *self, PyObject *data) { struct p_buffer *pb; if (!PyBytes_Check(data)) { PyErr_SetString(PyExc_TypeError, "pq buffer.write() method requires a bytes object"); return(NULL); } pb = ((struct p_buffer *) self); if (PyBytes_GET_SIZE(data) > 0) { struct p_list *pl; pl = malloc(sizeof(struct p_list)); if (pl == NULL) { PyErr_SetString(PyExc_MemoryError, "could not allocate memory for pq message stream data"); return(NULL); } pl->data = data; Py_INCREF(data); pl->next = NULL; if (pb->last == NULL) { /* * First and last. 
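			 * The buffer was empty: this chunk becomes both the read
			 * position and the append tail (pb->last keeps write() O(1)).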
*/ pb->position.list = pb->last = pl; } else { pb->last->next = pl; pb->last = pl; } } Py_INCREF(Py_None); return(Py_None); } static PyObject * p_next(PyObject *self) { struct p_buffer *pb = ((struct p_buffer *) self); struct p_place p; PyObject *rob; p.offset = pb->position.offset; p.list = pb->position.list; rob = p_build_tuple(&p); if (rob != NULL) { pl_truncate(pb->position.list, p.list); pb->position.list = p.list; pb->position.offset = p.offset; if (p.list == NULL) pb->last = NULL; } return(rob); } static PyObject * p_read(PyObject *self, PyObject *args) { int cur_msg, msg_count = -1, msg_in = 0; struct p_place p; struct p_buffer *pb; PyObject *rob = NULL; if (!PyArg_ParseTuple(args, "|i", &msg_count)) return(NULL); pb = (struct p_buffer *) self; p.list = pb->position.list; p.offset = pb->position.offset; msg_in = p_length(self); msg_count = msg_count < msg_in && msg_count != -1 ? msg_count : msg_in; rob = PyTuple_New(msg_count); for (cur_msg = 0; cur_msg < msg_count; ++cur_msg) { PyObject *msg_tup = NULL; msg_tup = p_build_tuple(&p); if (msg_tup == NULL) { if (PyErr_Occurred()) { Py_DECREF(rob); return(NULL); } break; } PyTuple_SET_ITEM(rob, cur_msg, msg_tup); } pl_truncate(pb->position.list, p.list); pb->position.list = p.list; pb->position.offset = p.offset; if (p.list == NULL) pb->last = NULL; return(rob); } static PyObject * p_has_message(PyObject *self) { char header[5]; uint32_t msg_length; uint32_t copy_amount = 0; struct p_buffer *pb; struct p_place p; PyObject *rob; pb = ((struct p_buffer *) self); p.list = pb->position.list; p.offset = pb->position.offset; copy_amount = p_memcpy(header, &p, 5); if (copy_amount < 5) { Py_INCREF(Py_False); return(Py_False); } p_seek(&p, copy_amount); memcpy(&msg_length, header + 1, 4); msg_length = local_ntohl(msg_length); if (msg_length < 4) { PyErr_Format(PyExc_ValueError, "invalid message size '%d'", msg_length); return(NULL); } msg_length -= 4; rob = p_at_least(&p, msg_length) ? Py_True : Py_False; Py_INCREF(rob); return(rob); } static PyObject * p_next_message(PyObject *self) { struct p_buffer *pb = ((struct p_buffer *) self); struct p_place p; PyObject *rob; p.offset = pb->position.offset; p.list = pb->position.list; rob = p_build_tuple(&p); if (rob == NULL) { if (!PyErr_Occurred()) { rob = Py_None; Py_INCREF(rob); } } else { pl_truncate(pb->position.list, p.list); pb->position.list = p.list; pb->position.offset = p.offset; if (p.list == NULL) pb->last = NULL; } return(rob); } /* * p_getvalue - get the unconsumed data in the buffer * * Normally used in conjunction with truncate to transfer * control of the wire to another state machine. */ static PyObject * p_getvalue(PyObject *self) { struct p_buffer *pb = ((struct p_buffer *) self); struct p_list *l; uint32_t initial_offset; PyObject *rob; /* * Don't include data from already read() messages. */ initial_offset = pb->position.offset; l = pb->position.list; if (l == NULL) { /* * Empty list. */ return(PyBytes_FromString("")); } /* * Get the first chunk. 
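	 * Slice from initial_offset so bytes already consumed by read() or
	 * next_message() are excluded; the remaining chunks are then
	 * concatenated whole.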
*/ rob = PyBytes_FromStringAndSize( (PyBytes_AS_STRING(l->data) + initial_offset), PyBytes_GET_SIZE(l->data) - initial_offset ); if (rob == NULL) return(NULL); l = l->next; while (l != NULL) { PyBytes_Concat(&rob, l->data); if (rob == NULL) break; l = l->next; } return(rob); } static PyMethodDef p_methods[] = { {"write", p_write, METH_O, PyDoc_STR("write the string to the buffer"),}, {"read", p_read, METH_VARARGS, PyDoc_STR("read the number of messages from the buffer")}, {"truncate", (PyCFunction) p_truncate, METH_NOARGS, PyDoc_STR("remove the contents of the buffer"),}, {"has_message", (PyCFunction) p_has_message, METH_NOARGS, PyDoc_STR("whether the buffer has a message ready"),}, {"next_message", (PyCFunction) p_next_message, METH_NOARGS, PyDoc_STR("get and remove the next message--None if none."),}, {"getvalue", (PyCFunction) p_getvalue, METH_NOARGS, PyDoc_STR("get the unprocessed data in the buffer")}, {NULL} }; PyTypeObject pq_message_stream_Type = { PyVarObject_HEAD_INIT(NULL, 0) "postgresql.port.optimized.pq_message_stream", /* tp_name */ sizeof(struct p_buffer), /* tp_basicsize */ 0, /* tp_itemsize */ p_dealloc, /* tp_dealloc */ NULL, /* tp_print */ NULL, /* tp_getattr */ NULL, /* tp_setattr */ NULL, /* tp_compare */ NULL, /* tp_repr */ NULL, /* tp_as_number */ &pq_ms_as_sequence, /* tp_as_sequence */ NULL, /* tp_as_mapping */ NULL, /* tp_hash */ NULL, /* tp_call */ NULL, /* tp_str */ NULL, /* tp_getattro */ NULL, /* tp_setattro */ NULL, /* tp_as_buffer */ Py_TPFLAGS_DEFAULT, /* tp_flags */ PyDoc_STR( "Buffer data on write, return messages on read" ), /* tp_doc */ NULL, /* tp_traverse */ NULL, /* tp_clear */ NULL, /* tp_richcompare */ 0, /* tp_weaklistoffset */ NULL, /* tp_iter */ p_next, /* tp_iternext */ p_methods, /* tp_methods */ NULL, /* tp_members */ NULL, /* tp_getset */ NULL, /* tp_base */ NULL, /* tp_dict */ NULL, /* tp_descr_get */ NULL, /* tp_descr_set */ 0, /* tp_dictoffset */ NULL, /* tp_init */ NULL, /* tp_alloc */ p_new, /* tp_new */ NULL, /* tp_free */ }; /* * vim: ts=3:sw=3:noet: */ fe-1.1.0/postgresql/port/_optimized/element3.c000066400000000000000000000372021203372773200213270ustar00rootroot00000000000000/* * .port.optimized - .protocol.element3 optimizations */ #define include_element3_functions \ mFUNC(cat_messages, METH_O, "cat the serialized form of the messages in the given list") \ mFUNC(parse_tuple_message, METH_O, "parse the given tuple data into a tuple of raw data") \ mFUNC(pack_tuple_data, METH_O, "serialize the give tuple message[tuple of bytes()]") \ mFUNC(consume_tuple_messages, METH_O, "create a list of parsed tuple data tuples") \ /* * Given a tuple of bytes and None objects, join them into a * a single bytes object with sizes. 
*/ static PyObject * _pack_tuple_data(PyObject *tup) { PyObject *rob; Py_ssize_t natts; Py_ssize_t catt; char *buf = NULL; char *bufpos = NULL; Py_ssize_t bufsize = 0; if (!PyTuple_Check(tup)) { PyErr_Format( PyExc_TypeError, "pack_tuple_data requires a tuple, given %s", PyObject_TypeName(tup) ); return(NULL); } natts = PyTuple_GET_SIZE(tup); if (natts == 0) return(PyBytes_FromString("")); /* discover buffer size and valid att types */ for (catt = 0; catt < natts; ++catt) { PyObject *ob; ob = PyTuple_GET_ITEM(tup, catt); if (ob == Py_None) { bufsize = bufsize + 4; } else if (PyBytes_CheckExact(ob)) { bufsize = bufsize + PyBytes_GET_SIZE(ob) + 4; } else { PyErr_Format( PyExc_TypeError, "cannot serialize attribute %d, expected bytes() or None, got %s", (int) catt, PyObject_TypeName(ob) ); return(NULL); } } buf = malloc(bufsize); if (buf == NULL) { PyErr_Format( PyExc_MemoryError, "failed to allocate %d bytes of memory for packing tuple data", bufsize ); return(NULL); } bufpos = buf; for (catt = 0; catt < natts; ++catt) { PyObject *ob; ob = PyTuple_GET_ITEM(tup, catt); if (ob == Py_None) { uint32_t attsize = 0xFFFFFFFFL; /* Indicates NULL */ Py_MEMCPY(bufpos, &attsize, 4); bufpos = bufpos + 4; } else { Py_ssize_t size = PyBytes_GET_SIZE(ob); uint32_t msg_size; if (size > 0xFFFFFFFE) { PyErr_Format(PyExc_OverflowError, "data size of %d is greater than attribute capacity", catt ); } msg_size = local_ntohl((uint32_t) size); Py_MEMCPY(bufpos, &msg_size, 4); bufpos = bufpos + 4; Py_MEMCPY(bufpos, PyBytes_AS_STRING(ob), PyBytes_GET_SIZE(ob)); bufpos = bufpos + PyBytes_GET_SIZE(ob); } } rob = PyBytes_FromStringAndSize(buf, bufsize); free(buf); return(rob); } /* * dst must be of PyTuple_Type with at least natts items slots. */ static int _unpack_tuple_data(PyObject *dst, register uint16_t natts, register const char *data, Py_ssize_t data_len) { static const unsigned char null_sequence[4] = {0xFF, 0xFF, 0xFF, 0xFF}; register PyObject *ob; register uint16_t cnatt = 0; register uint32_t attsize; register const char *next; register const char *eod = data + data_len; char attsize_buf[4]; while (cnatt < natts) { /* * Need enough data for the attribute size. */ next = data + 4; if (next > eod) { PyErr_Format(PyExc_ValueError, "not enough data available for attribute %d's size header: " "needed %d bytes, but only %lu remain at position %lu", cnatt, 4, eod - data, data_len - (eod - data) ); return(-1); } Py_MEMCPY(attsize_buf, data, 4); data = next; if ((*((uint32_t *) attsize_buf)) == (*((uint32_t *) null_sequence))) { /* * NULL. */ Py_INCREF(Py_None); PyTuple_SET_ITEM(dst, cnatt, Py_None); } else { attsize = local_ntohl(*((uint32_t *) attsize_buf)); next = data + attsize; if (next > eod || next < data) { /* * Increment caused wrap... */ PyErr_Format(PyExc_ValueError, "attribute %d has invalid size %lu", cnatt, attsize ); return(-1); } ob = PyBytes_FromStringAndSize(data, attsize); if (ob == NULL) { /* * Probably an OOM error. 
*/ return(-1); } PyTuple_SET_ITEM(dst, cnatt, ob); data = next; } cnatt++; } if (data != eod) { PyErr_Format(PyExc_ValueError, "invalid tuple(D) message, %lu remaining " "bytes after processing %d attributes", (unsigned long) (eod - data), cnatt ); return(-1); } return(0); } static PyObject * parse_tuple_message(PyObject *self, PyObject *arg) { PyObject *rob; const char *data; Py_ssize_t dlen = 0; uint16_t natts = 0; if (PyObject_AsReadBuffer(arg, (const void **) &data, &dlen)) return(NULL); if (dlen < 2) { PyErr_Format(PyExc_ValueError, "invalid tuple message: %d bytes is too small", dlen); return(NULL); } Py_MEMCPY(&natts, data, 2); natts = local_ntohs(natts); rob = PyTuple_New(natts); if (rob == NULL) return(NULL); if (_unpack_tuple_data(rob, natts, data+2, dlen-2) < 0) { Py_DECREF(rob); return(NULL); } return(rob); } static PyObject * consume_tuple_messages(PyObject *self, PyObject *list) { Py_ssize_t i; PyObject *rob; /* builtins.list */ if (!PyTuple_Check(list)) { PyErr_SetString(PyExc_TypeError, "consume_tuple_messages requires a tuple"); return(NULL); } rob = PyList_New(PyTuple_GET_SIZE(list)); if (rob == NULL) return(NULL); for (i = 0; i < PyTuple_GET_SIZE(list); ++i) { register PyObject *data; PyObject *msg, *typ, *ptm; msg = PyTuple_GET_ITEM(list, i); if (!PyTuple_CheckExact(msg) || PyTuple_GET_SIZE(msg) != 2) { Py_DECREF(rob); PyErr_SetString(PyExc_TypeError, "consume_tuple_messages requires tuples items to be tuples (pairs)"); return(NULL); } typ = PyTuple_GET_ITEM(msg, 0); if (!PyBytes_CheckExact(typ) || PyBytes_GET_SIZE(typ) != 1) { Py_DECREF(rob); PyErr_SetString(PyExc_TypeError, "consume_tuple_messages requires pairs to consist of bytes"); return(NULL); } /* * End of tuple messages. */ if (*(PyBytes_AS_STRING(typ)) != 'D') break; data = PyTuple_GET_ITEM(msg, 1); ptm = parse_tuple_message(NULL, data); if (ptm == NULL) { Py_DECREF(rob); return(NULL); } PyList_SET_ITEM(rob, i, ptm); } if (i < PyTuple_GET_SIZE(list)) { PyObject *newrob; newrob = PyList_GetSlice(rob, 0, i); Py_DECREF(rob); rob = newrob; } return(rob); } static PyObject * pack_tuple_data(PyObject *self, PyObject *tup) { return(_pack_tuple_data(tup)); } /* * Check for overflow before incrementing the buffer size for cat_messages. */ #define INCSIZET(XVAR, AMT) do { \ size_t _amt_ = AMT; \ size_t _newsize_ = XVAR + _amt_; \ if (_newsize_ >= XVAR) XVAR = _newsize_; else { \ PyErr_Format(PyExc_OverflowError, \ "buffer size overflowed, was %zd bytes, but could not add %d more", XVAR, _amt_); \ goto fail; } \ } while(0) #define INCMSGSIZE(XVAR, AMT) do { \ uint32_t _amt_ = AMT; \ uint32_t _newsize_ = XVAR + _amt_; \ if (_newsize_ >= XVAR) XVAR = _newsize_; else { \ PyErr_Format(PyExc_OverflowError, \ "message size too large, was %u bytes, but could not add %u more", XVAR, _amt_); \ goto fail; } \ } while(0) /* * cat_messages - cat the serialized form of the messages in the given list * * This offers a fast way to construct the final bytes() object to be sent to * the wire. It avoids re-creating bytes() objects by calculating the serialized * size of contiguous, homogenous messages, allocating or extending the buffer * to accommodate for the needed size, and finally, copying the data into the * newly available space. */ static PyObject * cat_messages(PyObject *self, PyObject *messages_in) { const static char null_attribute[4] = {0xff,0xff,0xff,0xff}; PyObject *msgs = NULL; Py_ssize_t nmsgs = 0; Py_ssize_t cmsg = 0; /* * Buffer holding the messages' serialized form. 
*/ char *buf = NULL; char *nbuf = NULL; size_t bufsize = 0; size_t bufpos = 0; /* * Get a List object for faster rescanning when dealing with copy data. */ msgs = PyObject_CallFunctionObjArgs((PyObject *) &PyList_Type, messages_in, NULL); if (msgs == NULL) return(NULL); nmsgs = PyList_GET_SIZE(msgs); while (cmsg < nmsgs) { PyObject *ob; ob = PyList_GET_ITEM(msgs, cmsg); /* * Choose the path, lots of copy data or more singles to serialize? */ if (PyBytes_CheckExact(ob)) { Py_ssize_t eofc = cmsg; size_t xsize = 0; /* find the last of the copy data (eofc) */ do { ++eofc; /* increase in size to allocate for the adjacent copy messages */ INCSIZET(xsize, PyBytes_GET_SIZE(ob)); if (eofc >= nmsgs) break; /* end of messages in the list? */ /* Grab the next message. */ ob = PyList_GET_ITEM(msgs, eofc); } while(PyBytes_CheckExact(ob)); /* * Either the end of the list or `ob` is not a data object meaning * that it's the end of the copy data. */ /* realloc the buf for the new copy data */ INCSIZET(xsize, (5 * (eofc - cmsg))); INCSIZET(bufsize, xsize); nbuf = realloc(buf, bufsize); if (nbuf == NULL) { PyErr_Format( PyExc_MemoryError, "failed to allocate %lu bytes of memory for out-going messages", (unsigned long) bufsize ); goto fail; } else { buf = nbuf; nbuf = NULL; } /* * Make the final pass through the copy lines memcpy'ing the data from * the bytes() objects. */ while (cmsg < eofc) { uint32_t msg_length = 0; char *localbuf = buf + bufpos + 1; buf[bufpos] = 'd'; /* COPY data message type */ ob = PyList_GET_ITEM(msgs, cmsg); INCMSGSIZE(msg_length, (uint32_t) PyBytes_GET_SIZE(ob) + 4); INCSIZET(bufpos, 1 + msg_length); msg_length = local_ntohl(msg_length); Py_MEMCPY(localbuf, &msg_length, 4); Py_MEMCPY(localbuf + 4, PyBytes_AS_STRING(ob), PyBytes_GET_SIZE(ob)); ++cmsg; } } else if (PyTuple_CheckExact(ob)) { /* * Handle 'D' tuple data from a raw Python tuple. */ Py_ssize_t eofc = cmsg; size_t xsize = 0; /* find the last of the tuple data (eofc) */ do { Py_ssize_t current_item, nitems; nitems = PyTuple_GET_SIZE(ob); if (nitems > 0xFFFF) { PyErr_SetString(PyExc_OverflowError, "too many attributes in tuple message"); goto fail; } /* * The items take *at least* 4 bytes each. * (The attribute count is considered later) */ INCSIZET(xsize, (nitems * 4)); for (current_item = 0; current_item < nitems; ++current_item) { PyObject *att = PyTuple_GET_ITEM(ob, current_item); /* * Attributes *must* be bytes() or None. */ if (PyBytes_CheckExact(att)) INCSIZET(xsize, PyBytes_GET_SIZE(att)); else if (att != Py_None) { PyErr_Format(PyExc_TypeError, "cannot serialize tuple message attribute of type '%s'", Py_TYPE(att)->tp_name); goto fail; } /* * else it's Py_None and the size will be included later. */ } ++eofc; if (eofc >= nmsgs) break; /* end of messages in the list? */ /* Grab the next message. */ ob = PyList_GET_ITEM(msgs, eofc); } while(PyTuple_CheckExact(ob)); /* * Either the end of the list or `ob` is not a data object meaning * that it's the end of the copy data. */ /* * realloc the buf for the new tuple data * * Each D message consumes at least 1 + 4 + 2 bytes: * 1 for the message type * 4 for the message size * 2 for the attribute count */ INCSIZET(xsize, (7 * (eofc - cmsg))); INCSIZET(bufsize, xsize); nbuf = realloc(buf, bufsize); if (nbuf == NULL) { PyErr_Format( PyExc_MemoryError, "failed to allocate %zd bytes of memory for out-going messages", bufsize ); goto fail; } else { buf = nbuf; nbuf = NULL; } /* * Make the final pass through the tuple data memcpy'ing the data from * the bytes() objects. 
* * No type checks are done here as they should have been done while * gathering the sizes for the realloc(). */ while (cmsg < eofc) { Py_ssize_t current_item, nitems; uint32_t msg_length, out_msg_len; uint16_t natts; char *localbuf = (buf + bufpos) + 5; /* skipping the header for now */ buf[bufpos] = 'D'; /* Tuple data message type */ ob = PyList_GET_ITEM(msgs, cmsg); nitems = PyTuple_GET_SIZE(ob); /* * 4 bytes for the message length, * 2 bytes for the attribute count and * 4 bytes for each item in 'ob'. */ msg_length = 4 + 2 + (nitems * 4); /* * Set number of attributes. */ natts = local_ntohs((uint16_t) nitems); Py_MEMCPY(localbuf, &natts, 2); localbuf = localbuf + 2; for (current_item = 0; current_item < nitems; ++current_item) { PyObject *att = PyTuple_GET_ITEM(ob, current_item); if (att == Py_None) { Py_MEMCPY(localbuf, &null_attribute, 4); localbuf = localbuf + 4; } else { Py_ssize_t attsize = PyBytes_GET_SIZE(att); uint32_t n_attsize; n_attsize = local_ntohl((uint32_t) attsize); Py_MEMCPY(localbuf, &n_attsize, 4); localbuf = localbuf + 4; Py_MEMCPY(localbuf, PyBytes_AS_STRING(att), attsize); localbuf = localbuf + attsize; INCSIZET(msg_length, attsize); } } /* * Summed up the message size while copying the attributes. */ out_msg_len = local_ntohl(msg_length); Py_MEMCPY(buf + bufpos + 1, &out_msg_len, 4); /* * Filled in the data while summing the message size, so * adjust the buffer position for the next message. */ INCSIZET(bufpos, 1 + msg_length); ++cmsg; } } else { PyObject *serialized; PyObject *msg_type; int msg_type_size; uint32_t msg_length; /* * Call the serialize() method on the element object. * Do this instead of the normal bytes() method to avoid * the type and size packing overhead. */ serialized = PyObject_CallMethodObjArgs(ob, serialize_strob, NULL); if (serialized == NULL) goto fail; if (!PyBytes_CheckExact(serialized)) { PyErr_Format( PyExc_TypeError, "%s.serialize() returned object of type %s, expected bytes", PyObject_TypeName(ob), PyObject_TypeName(serialized) ); goto fail; } msg_type = PyObject_GetAttr(ob, msgtype_strob); if (msg_type == NULL) { Py_DECREF(serialized); goto fail; } if (!PyBytes_CheckExact(msg_type)) { Py_DECREF(serialized); Py_DECREF(msg_type); PyErr_Format( PyExc_TypeError, "message's 'type' attribute was %s, expected bytes", PyObject_TypeName(ob) ); goto fail; } /* * Some elements have empty message types--Startup for instance. * It is important to get the actual size rather than assuming one. */ msg_type_size = PyBytes_GET_SIZE(msg_type); /* realloc the buf for the new copy data */ INCSIZET(bufsize, 4 + msg_type_size); INCSIZET(bufsize, PyBytes_GET_SIZE(serialized)); nbuf = realloc(buf, bufsize); if (nbuf == NULL) { Py_DECREF(serialized); Py_DECREF(msg_type); PyErr_Format( PyExc_MemoryError, "failed to allocate %d bytes of memory for out-going messages", bufsize ); goto fail; } else { buf = nbuf; nbuf = NULL; } /* * All necessary information acquired, so fill in the message's data. 
*/ buf[bufpos] = *(PyBytes_AS_STRING(msg_type)); msg_length = PyBytes_GET_SIZE(serialized); INCMSGSIZE(msg_length, 4); msg_length = local_ntohl(msg_length); Py_MEMCPY(buf + bufpos + msg_type_size, &msg_length, 4); Py_MEMCPY( buf + bufpos + 4 + msg_type_size, PyBytes_AS_STRING(serialized), PyBytes_GET_SIZE(serialized) ); bufpos = bufsize; Py_DECREF(serialized); Py_DECREF(msg_type); ++cmsg; } } Py_DECREF(msgs); if (buf == NULL) /* no messages, no data */ return(PyBytes_FromString("")); else { PyObject *rob; rob = PyBytes_FromStringAndSize(buf, bufsize); free(buf); return(rob); } fail: /* pyerr is expected to be set */ Py_DECREF(msgs); if (buf != NULL) free(buf); return(NULL); } fe-1.1.0/postgresql/port/_optimized/functools.c000066400000000000000000000164051203372773200216310ustar00rootroot00000000000000/* * .port.optimized - functools.c * *//* * optimizations for postgresql.python package modules. */ /* * process the tuple with the associated callables while * calling the third object in cases of failure to generalize the exception. */ #define include_functools_functions \ mFUNC(rsetattr, METH_VARARGS, "rsetattr(attr, val, ob) set the attribute to the value *and* return `ob`.") \ mFUNC(compose, METH_VARARGS, "given a sequence of callables, and an argument for the first call, compose the result.") \ mFUNC(process_tuple, METH_VARARGS, "process the items in the second argument with the corresponding items in the first argument.") \ mFUNC(process_chunk, METH_VARARGS, "process the items of the chunk given as the second argument with the corresponding items in the first argument.") static PyObject * _process_tuple(PyObject *procs, PyObject *tup, PyObject *fail) { PyObject *rob; Py_ssize_t len, i; if (!PyTuple_CheckExact(procs)) { PyErr_SetString( PyExc_TypeError, "process_tuple requires an exact tuple as its first argument" ); return(NULL); } if (!PyTuple_Check(tup)) { PyErr_SetString( PyExc_TypeError, "process_tuple requires a tuple as its second argument" ); return(NULL); } len = PyTuple_GET_SIZE(tup); if (len != PyTuple_GET_SIZE(procs)) { PyErr_Format( PyExc_TypeError, "inconsistent items, %d processors and %d items in row", len, PyTuple_GET_SIZE(procs) ); return(NULL); } /* types check out; consistent sizes */ rob = PyTuple_New(len); for (i = 0; i < len; ++i) { PyObject *p, *o, *ot, *r; /* p = processor, * o = source object, * ot = o's tuple (temp for application to p), * r = transformed * output */ /* * If it's Py_None, that means it's NULL. No processing necessary. */ o = PyTuple_GET_ITEM(tup, i); if (o == Py_None) { Py_INCREF(Py_None); PyTuple_SET_ITEM(rob, i, Py_None); /* mmmm, cake! */ continue; } p = PyTuple_GET_ITEM(procs, i); /* * Temp tuple for applying *args to p. */ ot = PyTuple_New(1); PyTuple_SET_ITEM(ot, 0, o); Py_INCREF(o); r = PyObject_CallObject(p, ot); Py_DECREF(ot); if (r != NULL) { /* good, set it and move on. */ PyTuple_SET_ITEM(rob, i, r); } else { /* * Exception caused by >>> p(*ot) * * In this case, the failure callback needs to be called * in order to properly generalize the failure. There are numerous, * and (sometimes) inconsistent reasons why a tuple cannot be * processed and therefore a generalized exception raised in the * context of the original is *very* useful. */ Py_DECREF(rob); rob = NULL; /* * Don't trap BaseException's. */ if (PyErr_ExceptionMatches(PyExc_Exception)) { PyObject *cause, *failargs, *failedat; PyObject *exc, *tb; /* Store exception to set context after handler. 
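			 * PyErr_Fetch transfers ownership of the exception state; only
			 * the normalized value ('cause') is kept and handed to the
			 * failure callback below.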
*/ PyErr_Fetch(&exc, &cause, &tb); PyErr_NormalizeException(&exc, &cause, &tb); Py_XDECREF(exc); Py_XDECREF(tb); failedat = PyLong_FromSsize_t(i); if (failedat != NULL) { failargs = PyTuple_New(4); if (failargs != NULL) { /* args for the exception "generalizer" */ PyTuple_SET_ITEM(failargs, 0, cause); PyTuple_SET_ITEM(failargs, 1, procs); Py_INCREF(procs); PyTuple_SET_ITEM(failargs, 2, tup); Py_INCREF(tup); PyTuple_SET_ITEM(failargs, 3, failedat); r = PyObject_CallObject(fail, failargs); Py_DECREF(failargs); if (r != NULL) { PyErr_SetString(PyExc_RuntimeError, "process_tuple exception handler failed to raise" ); Py_DECREF(r); } } else { Py_DECREF(failedat); } } } /* * Break out of loop to return(NULL); */ break; } } return(rob); } /* * process the tuple with the associated callables while * calling the third object in cases of failure to generalize the exception. */ static PyObject * process_tuple(PyObject *self, PyObject *args) { PyObject *tup, *procs, *fail; if (!PyArg_ParseTuple(args, "OOO", &procs, &tup, &fail)) return(NULL); return(_process_tuple(procs, tup, fail)); } static PyObject * _process_chunk_new_list(PyObject *procs, PyObject *tupc, PyObject *fail) { PyObject *rob; Py_ssize_t i, len; /* * Turn the iterable into a new list. */ rob = PyObject_CallFunctionObjArgs((PyObject *) &PyList_Type, tupc, NULL); if (rob == NULL) return(NULL); len = PyList_GET_SIZE(rob); for (i = 0; i < len; ++i) { PyObject *tup, *r; /* * If it's Py_None, that means it's NULL. No processing necessary. */ tup = PyList_GetItem(rob, i); /* borrowed ref from list */ r = _process_tuple(procs, tup, fail); if (r == NULL) { /* process_tuple failed. assume PyErr_Occurred() */ Py_DECREF(rob); return(NULL); } PyList_SetItem(rob, i, r); } return(rob); } static PyObject * _process_chunk_from_list(PyObject *procs, PyObject *tupc, PyObject *fail) { PyObject *rob; Py_ssize_t i, len; len = PyList_GET_SIZE(tupc); rob = PyList_New(len); if (rob == NULL) return(NULL); for (i = 0; i < len; ++i) { PyObject *tup, *r; /* * If it's Py_None, that means it's NULL. No processing necessary. */ tup = PyList_GET_ITEM(tupc, i); r = _process_tuple(procs, tup, fail); if (r == NULL) { Py_DECREF(rob); return(NULL); } PyList_SET_ITEM(rob, i, r); } return(rob); } /* * process the chunk of tuples with the associated callables while * calling the third object in cases of failure to generalize the exception. */ static PyObject * process_chunk(PyObject *self, PyObject *args) { PyObject *tupc, *procs, *fail; if (!PyArg_ParseTuple(args, "OOO", &procs, &tupc, &fail)) return(NULL); if (PyList_Check(tupc)) { return(_process_chunk_from_list(procs, tupc, fail)); } else { return(_process_chunk_new_list(procs, tupc, fail)); } } static PyObject * rsetattr(PyObject *self, PyObject *args) { PyObject *ob, *attr, *val; if (!PyArg_ParseTuple(args, "OOO", &attr, &val, &ob)) return(NULL); if (PyObject_SetAttr(ob, attr, val) < 0) return(NULL); Py_INCREF(ob); return(ob); } /* * Override the functools.Composition __call__. 
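 * compose(seq, x) threads x through each callable in seq, left to
 * right: compose((f, g), x) computes g(f(x)).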
*/ static PyObject * compose(PyObject *self, PyObject *args) { Py_ssize_t i, len; PyObject *rob, *argt, *seq, *x; if (!PyArg_ParseTuple(args, "OO", &seq, &rob)) return(NULL); Py_INCREF(rob); if (PyObject_IsInstance(seq, (PyObject *) &PyTuple_Type)) { len = PyTuple_GET_SIZE(seq); for (i = 0; i < len; ++i) { x = PyTuple_GET_ITEM(seq, i); argt = PyTuple_New(1); PyTuple_SET_ITEM(argt, 0, rob); rob = PyObject_CallObject(x, argt); Py_DECREF(argt); if (rob == NULL) break; } } else if (PyObject_IsInstance(seq, (PyObject *) &PyList_Type)) { len = PyList_GET_SIZE(seq); for (i = 0; i < len; ++i) { x = PyList_GET_ITEM(seq, i); argt = PyTuple_New(1); PyTuple_SET_ITEM(argt, 0, rob); rob = PyObject_CallObject(x, argt); Py_DECREF(argt); if (rob == NULL) break; } } else { /* * Arbitrary sequence. */ len = PySequence_Length(seq); for (i = 0; i < len; ++i) { x = PySequence_GetItem(seq, i); argt = PyTuple_New(1); PyTuple_SET_ITEM(argt, 0, rob); rob = PyObject_CallObject(x, argt); Py_DECREF(x); Py_DECREF(argt); if (rob == NULL) break; } } return(rob); } /* * vim: ts=3:sw=3:noet: */ fe-1.1.0/postgresql/port/_optimized/module.c000066400000000000000000000067201203372773200211010ustar00rootroot00000000000000/* * module.c - optimizations for various parts of py-postgresql * * This module.c file ties together other classified C source. * Each filename describing the part of the protocol package that it * covers. It merely uses CPP includes to bring them into this * file and then uses some CPP macros to expand the definitions * in each file. */ #include #include /* * If Python didn't find it, it won't include it. * However, it's quite necessary. */ #ifndef HAVE_STDINT_H #include #endif #define USHORT_MAX ((1<<16)-1) #define SHORT_MAX ((1<<15)-1) #define SHORT_MIN (-(1<<15)) #define PyObject_TypeName(ob) \ (((PyTypeObject *) (ob->ob_type))->tp_name) /* * buffer.c needs the message_types object from .protocol.message_types. * Initialized in PyInit_optimized. */ static PyObject *message_types = NULL; static PyObject *serialize_strob = NULL; static PyObject *msgtype_strob = NULL; static int32_t (*local_ntohl)(int32_t) = NULL; static short (*local_ntohs)(short) = NULL; /* * optimized module contents */ #include "structlib.c" #include "functools.c" #include "buffer.c" #include "wirestate.c" #include "element3.c" /* cpp abuse, read up on X-Macros if you don't understand */ #define mFUNC(name, typ, doc) \ {#name, (PyCFunction) name, typ, PyDoc_STR(doc)}, static PyMethodDef optimized_methods[] = { include_element3_functions include_structlib_functions include_functools_functions {NULL} }; #undef mFUNC static struct PyModuleDef optimized_module = { PyModuleDef_HEAD_INIT, "optimized", /* name of module */ NULL, /* module documentation, may be NULL */ -1, /* size of per-interpreter state of the module, or -1 if the module keeps state in global variables. 
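	      (This module does keep C-global state: message_types and the
	      local_ntohl/local_ntohs pointers, hence the -1.)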
*/ optimized_methods, }; PyMODINIT_FUNC PyInit_optimized(void) { PyObject *mod; PyObject *msgtypes; PyObject *fromlist, *fromstr; long l; /* make some constants */ if (serialize_strob == NULL) { serialize_strob = PyUnicode_FromString("serialize"); if (serialize_strob == NULL) return(NULL); } if (msgtype_strob == NULL) { msgtype_strob = PyUnicode_FromString("type"); if (msgtype_strob == NULL) return(NULL); } mod = PyModule_Create(&optimized_module); if (mod == NULL) return(NULL); /* cpp abuse; ready types */ #define mTYPE(name) \ if (PyType_Ready(&name##_Type) < 0) \ goto cleanup; \ if (PyModule_AddObject(mod, #name, \ (PyObject *) &name##_Type) < 0) \ goto cleanup; /* buffer.c */ include_buffer_types /* wirestate.c */ include_wirestate_types #undef mTYPE l = 1; if (((char *) &l)[0] == 1) { /* little */ local_ntohl = swap_int4; local_ntohs = swap_short; } else { /* big */ local_ntohl = return_int4; local_ntohs = return_short; } /* * Get the message_types tuple to type "instantiation". */ fromlist = PyList_New(1); fromstr = PyUnicode_FromString("message_types"); PyList_SetItem(fromlist, 0, fromstr); msgtypes = PyImport_ImportModuleLevel( "protocol.message_types", PyModule_GetDict(mod), PyModule_GetDict(mod), fromlist, 2 ); Py_DECREF(fromlist); if (msgtypes == NULL) goto cleanup; message_types = PyObject_GetAttrString(msgtypes, "message_types"); Py_DECREF(msgtypes); if (!PyObject_IsInstance(message_types, (PyObject *) (&PyTuple_Type))) { PyErr_SetString(PyExc_RuntimeError, "local protocol.message_types.message_types is not a tuple object"); goto cleanup; } return(mod); cleanup: Py_DECREF(mod); return(NULL); } /* * vim: ts=3:sw=3:noet: */ fe-1.1.0/postgresql/port/_optimized/structlib.c000066400000000000000000000272331203372773200216310ustar00rootroot00000000000000/* * .port.optimized - pack and unpack int2, int4, and int8. */ /* * Define the swap functionality for those endians. */ #define swap2(CP) do{register char c; \ c=CP[1];CP[1]=CP[0];CP[0]=c;\ }while(0) #define swap4(P) do{register char c; \ c=P[3];P[3]=P[0];P[0]=c;\ c=P[2];P[2]=P[1];P[1]=c;\ }while(0) #define swap8(P) do{register char c; \ c=P[7];P[7]=P[0];P[0]=c;\ c=P[6];P[6]=P[1];P[1]=c;\ c=P[5];P[5]=P[2];P[2]=c;\ c=P[4];P[4]=P[3];P[3]=c;\ }while(0) #define long_funcs \ mFUNC(int2_pack, METH_O, "PyInt to serialized, int2") \ mFUNC(int2_unpack, METH_O, "PyInt from serialized, int2") \ mFUNC(int4_pack, METH_O, "PyInt to serialized, int4") \ mFUNC(int4_unpack, METH_O, "PyInt from serialized, int4") \ mFUNC(swap_int2_pack, METH_O, "PyInt to swapped serialized, int2") \ mFUNC(swap_int2_unpack, METH_O, "PyInt from swapped serialized, int2") \ mFUNC(swap_int4_pack, METH_O, "PyInt to swapped serialized, int4") \ mFUNC(swap_int4_unpack, METH_O, "PyInt from swapped serialized, int4") \ mFUNC(uint2_pack, METH_O, "PyInt to serialized, uint2") \ mFUNC(uint2_unpack, METH_O, "PyInt from serialized, uint2") \ mFUNC(uint4_pack, METH_O, "PyInt to serialized, uint4") \ mFUNC(uint4_unpack, METH_O, "PyInt from serialized, uint4") \ mFUNC(swap_uint2_pack, METH_O, "PyInt to swapped serialized, uint2") \ mFUNC(swap_uint2_unpack, METH_O, "PyInt from swapped serialized, uint2") \ mFUNC(swap_uint4_pack, METH_O, "PyInt to swapped serialized, uint4") \ mFUNC(swap_uint4_unpack, METH_O, "PyInt from swapped serialized, uint4") \ #ifdef HAVE_LONG_LONG #if SIZEOF_LONG_LONG == 8 /* * If the configuration is not consistent with the expectations, * just use the slower struct.Struct versions. 
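 * (SIZEOF_LONG_LONG is probed into Python's pyconfig.h; unless it is 8,
 * the int8/uint8 fast paths below are compiled out entirely.)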
*/ #define longlong_funcs \ mFUNC(int8_pack, METH_O, "PyInt to serialized, int8") \ mFUNC(int8_unpack, METH_O, "PyInt from serialized, int8") \ mFUNC(swap_int8_pack, METH_O, "PyInt to swapped serialized, int8") \ mFUNC(swap_int8_unpack, METH_O, "PyInt from swapped serialized, int8") \ mFUNC(uint8_pack, METH_O, "PyInt to serialized, uint8") \ mFUNC(uint8_unpack, METH_O, "PyInt from serialized, uint8") \ mFUNC(swap_uint8_pack, METH_O, "PyInt to swapped serialized, uint8") \ mFUNC(swap_uint8_unpack, METH_O, "PyInt from swapped serialized, uint8") \ #define include_structlib_functions \ long_funcs \ longlong_funcs #if 0 Currently not used, so exclude. static PY_LONG_LONG return_long_long(PY_LONG_LONG i) { return(i); } static PY_LONG_LONG swap_long_long(PY_LONG_LONG i) { swap8(((char *) &i)); return(i); } #endif #endif #endif #ifndef include_structlib_functions #define include_structlib_functions \ long_funcs #endif static short swap_short(short s) { swap2(((char *) &s)); return(s); } static short return_short(short s) { return(s); } static int32_t swap_int4(int32_t i) { swap4(((char *) &i)); return(i); } static int32_t return_int4(int32_t i) { return(i); } static PyObject * int2_pack(PyObject *self, PyObject *arg) { long l; short s; l = PyLong_AsLong(arg); if (PyErr_Occurred()) return(NULL); if (l > SHORT_MAX || l < SHORT_MIN) { PyErr_Format(PyExc_OverflowError, "long '%d' overflows int2", l ); return(NULL); } s = (short) l; return(PyBytes_FromStringAndSize((const char *) &s, 2)); } static PyObject * swap_int2_pack(PyObject *self, PyObject *arg) { long l; short s; l = PyLong_AsLong(arg); if (PyErr_Occurred()) return(NULL); if (l > SHORT_MAX || l < SHORT_MIN) { PyErr_SetString(PyExc_OverflowError, "long too big or small for int2"); return(NULL); } s = (short) l; swap2(((char *) &s)); return(PyBytes_FromStringAndSize((const char *) &s, 2)); } static PyObject * int2_unpack(PyObject *self, PyObject *arg) { char *c; short *i; long l; Py_ssize_t len; PyObject *rob; if (PyObject_AsReadBuffer(arg, (const void **) &c, &len)) return(NULL); if (len < 2) { PyErr_SetString(PyExc_ValueError, "not enough data for int2_unpack"); return(NULL); } i = (short *) c; l = (long) *i; rob = PyLong_FromLong(l); return(rob); } static PyObject * swap_int2_unpack(PyObject *self, PyObject *arg) { char *c; short s; long l; Py_ssize_t len; PyObject *rob; if (PyObject_AsReadBuffer(arg, (const void **) &c, &len)) return(NULL); if (len < 2) { PyErr_SetString(PyExc_ValueError, "not enough data for swap_int2_unpack"); return(NULL); } s = *((short *) c); swap2(((char *) &s)); l = (long) s; rob = PyLong_FromLong(l); return(rob); } static PyObject * int4_pack(PyObject *self, PyObject *arg) { long l; int32_t i; l = PyLong_AsLong(arg); if (PyErr_Occurred()) return(NULL); if (!(l <= (long) 0x7FFFFFFFL && l >= (long) (-0x80000000L))) { PyErr_Format(PyExc_OverflowError, "long '%ld' overflows int4", l ); return(NULL); } i = (int32_t) l; return(PyBytes_FromStringAndSize((const char *) &i, 4)); } static PyObject * swap_int4_pack(PyObject *self, PyObject *arg) { long l; int32_t i; l = PyLong_AsLong(arg); if (PyErr_Occurred()) return(NULL); if (!(l <= (long) 0x7FFFFFFFL && l >= (long) (-0x80000000L))) { PyErr_Format(PyExc_OverflowError, "long '%ld' overflows int4", l ); return(NULL); } i = (int32_t) l; swap4(((char *) &i)); return(PyBytes_FromStringAndSize((const char *) &i, 4)); } static PyObject * int4_unpack(PyObject *self, PyObject *arg) { char *c; int32_t i; Py_ssize_t len; if (PyObject_AsReadBuffer(arg, (const void **) &c, &len)) 
return(NULL); if (len < 4) { PyErr_SetString(PyExc_ValueError, "not enough data for int4_unpack"); return(NULL); } i = *((int32_t *) c); return(PyLong_FromLong((long) i)); } static PyObject * swap_int4_unpack(PyObject *self, PyObject *arg) { char *c; int32_t i; Py_ssize_t len; if (PyObject_AsReadBuffer(arg, (const void **) &c, &len)) return(NULL); if (len < 4) { PyErr_SetString(PyExc_ValueError, "not enough data for swap_int4_unpack"); return(NULL); } i = *((int32_t *) c); swap4(((char *) &i)); return(PyLong_FromLong((long) i)); } static PyObject * uint2_pack(PyObject *self, PyObject *arg) { long l; unsigned short s; l = PyLong_AsLong(arg); if (PyErr_Occurred()) return(NULL); if (l > USHORT_MAX || l < 0) { PyErr_Format(PyExc_OverflowError, "long '%ld' overflows uint2", l ); return(NULL); } s = (unsigned short) l; return(PyBytes_FromStringAndSize((const char *) &s, 2)); } static PyObject * swap_uint2_pack(PyObject *self, PyObject *arg) { long l; unsigned short s; l = PyLong_AsLong(arg); if (PyErr_Occurred()) return(NULL); if (l > USHORT_MAX || l < 0) { PyErr_Format(PyExc_OverflowError, "long '%ld' overflows uint2", l ); return(NULL); } s = (unsigned short) l; swap2(((char *) &s)); return(PyBytes_FromStringAndSize((const char *) &s, 2)); } static PyObject * uint2_unpack(PyObject *self, PyObject *arg) { char *c; unsigned short *i; long l; Py_ssize_t len; PyObject *rob; if (PyObject_AsReadBuffer(arg, (const void **) &c, &len)) return(NULL); if (len < 2) { PyErr_SetString(PyExc_ValueError, "not enough data for uint2_unpack"); return(NULL); } i = (unsigned short *) c; l = (long) *i; rob = PyLong_FromLong(l); return(rob); } static PyObject * swap_uint2_unpack(PyObject *self, PyObject *arg) { char *c; unsigned short s; long l; Py_ssize_t len; PyObject *rob; if (PyObject_AsReadBuffer(arg, (const void **) &c, &len)) return(NULL); if (len < 2) { PyErr_SetString(PyExc_ValueError, "not enough data for swap_uint2_unpack"); return(NULL); } s = *((short *) c); swap2(((char *) &s)); l = (long) s; rob = PyLong_FromLong(l); return(rob); } static PyObject * uint4_pack(PyObject *self, PyObject *arg) { uint32_t i; unsigned long l; l = PyLong_AsUnsignedLong(arg); if (PyErr_Occurred()) return(NULL); if (l > 0xFFFFFFFFL) { PyErr_Format(PyExc_OverflowError, "long '%lu' overflows uint4", l ); return(NULL); } i = (uint32_t) l; return(PyBytes_FromStringAndSize((const char *) &i, 4)); } static PyObject * swap_uint4_pack(PyObject *self, PyObject *arg) { uint32_t i; unsigned long l; l = PyLong_AsUnsignedLong(arg); if (PyErr_Occurred()) return(NULL); if (l > 0xFFFFFFFFL) { PyErr_Format(PyExc_OverflowError, "long '%lu' overflows uint4", l ); return(NULL); } i = (uint32_t) l; swap4(((char *) &i)); return(PyBytes_FromStringAndSize((const char *) &i, 4)); } static PyObject * uint4_unpack(PyObject *self, PyObject *arg) { char *c; uint32_t i; Py_ssize_t len; if (PyObject_AsReadBuffer(arg, (const void **) &c, &len)) return(NULL); if (len < 4) { PyErr_SetString(PyExc_ValueError, "not enough data for uint4_unpack"); return(NULL); } i = *((uint32_t *) c); return(PyLong_FromUnsignedLong((unsigned long) i)); } static PyObject * swap_uint4_unpack(PyObject *self, PyObject *arg) { char *c; uint32_t i; Py_ssize_t len; if (PyObject_AsReadBuffer(arg, (const void **) &c, &len)) return(NULL); if (len < 4) { PyErr_SetString(PyExc_ValueError, "not enough data for swap_uint4_unpack"); return(NULL); } i = *((uint32_t *) c); swap4(((char *) &i)); return(PyLong_FromUnsignedLong((unsigned long) i)); } #ifdef longlong_funcs /* * int8 and "uint8" I/O 
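 * The swap_* variants byte-swap for hosts whose byte order differs from
 * the wire's network order; the plain variants copy the bytes directly.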
*/ static PyObject * int8_pack(PyObject *self, PyObject *arg) { PY_LONG_LONG l; l = PyLong_AsLongLong(arg); if (l == (PY_LONG_LONG) -1 && PyErr_Occurred()) return(NULL); return(PyBytes_FromStringAndSize((const char *) &l, 8)); } static PyObject * swap_int8_pack(PyObject *self, PyObject *arg) { PY_LONG_LONG l; l = PyLong_AsLongLong(arg); if (l == (PY_LONG_LONG) -1 && PyErr_Occurred()) return(NULL); swap8(((char *) &l)); return(PyBytes_FromStringAndSize((const char *) &l, 8)); } static PyObject * uint8_pack(PyObject *self, PyObject *arg) { unsigned PY_LONG_LONG l; l = PyLong_AsUnsignedLongLong(arg); if (l == (unsigned PY_LONG_LONG) -1 && PyErr_Occurred()) return(NULL); return(PyBytes_FromStringAndSize((const char *) &l, 8)); } static PyObject * swap_uint8_pack(PyObject *self, PyObject *arg) { unsigned PY_LONG_LONG l; l = PyLong_AsUnsignedLongLong(arg); if (l == (unsigned PY_LONG_LONG) -1 && PyErr_Occurred()) return(NULL); swap8(((char *) &l)); return(PyBytes_FromStringAndSize((const char *) &l, 8)); } static PyObject * uint8_unpack(PyObject *self, PyObject *arg) { char *c; Py_ssize_t len; unsigned PY_LONG_LONG i; if (PyObject_AsReadBuffer(arg, (const void **) &c, &len)) return(NULL); if (len < 8) { PyErr_SetString(PyExc_ValueError, "not enough data for uint8_unpack"); return(NULL); } i = *((unsigned PY_LONG_LONG *) c); return(PyLong_FromUnsignedLongLong(i)); } static PyObject * swap_uint8_unpack(PyObject *self, PyObject *arg) { char *c; Py_ssize_t len; unsigned PY_LONG_LONG i; if (PyObject_AsReadBuffer(arg, (const void **) &c, &len)) return(NULL); if (len < 8) { PyErr_SetString(PyExc_ValueError, "not enough data for swap_uint8_unpack"); return(NULL); } i = *((unsigned PY_LONG_LONG *) c); swap8(((char *) &i)); return(PyLong_FromUnsignedLongLong(i)); } static PyObject * int8_unpack(PyObject *self, PyObject *arg) { char *c; Py_ssize_t len; PY_LONG_LONG i; if (PyObject_AsReadBuffer(arg, (const void **) &c, &len)) return(NULL); if (len < 8) { PyErr_SetString(PyExc_ValueError, "not enough data for int8_unpack"); return(NULL); } i = *((PY_LONG_LONG *) c); return(PyLong_FromLongLong((PY_LONG_LONG) i)); } static PyObject * swap_int8_unpack(PyObject *self, PyObject *arg) { char *c; Py_ssize_t len; PY_LONG_LONG i; if (PyObject_AsReadBuffer(arg, (const void **) &c, &len)) return(NULL); if (len < 8) { PyErr_SetString(PyExc_ValueError, "not enough data for swap_int8_unpack"); return(NULL); } i = *((PY_LONG_LONG *) c); swap8(((char *) &i)); return(PyLong_FromLongLong(i)); } #endif /* longlong_funcs */ fe-1.1.0/postgresql/port/_optimized/wirestate.c000066400000000000000000000151321203372773200216200ustar00rootroot00000000000000/* * .port.optimized.WireState - PQ wire state for COPY. */ #define include_wirestate_types \ mTYPE(WireState) struct wirestate { PyObject_HEAD char size_fragment[4]; /* the header fragment; continuation specifies bytes read so far.
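* A continuation of -1 means no header fragment is pending.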
*/ PyObject *final_view; /* Py_None unless we reach an unknown message */ Py_ssize_t remaining_bytes; /* Bytes remaining in message */ short continuation; /* >= 0 when continuing a fragment */ }; static void ws_dealloc(PyObject *self) { struct wirestate *ws = ((struct wirestate *) self); Py_XDECREF(ws->final_view); Py_TYPE(self)->tp_free(self); } static PyObject * ws_new(PyTypeObject *subtype, PyObject *args, PyObject *kw) { static char *kwlist[] = {"condition", NULL}; struct wirestate *ws; PyObject *rob; if (!PyArg_ParseTupleAndKeywords(args, kw, "|O", kwlist, &rob)) return(NULL); rob = subtype->tp_alloc(subtype, 0); ws = ((struct wirestate *) rob); ws->continuation = -1; ws->remaining_bytes = 0; ws->final_view = NULL; return(rob); } #define CONDITION(MSGTYPE) (MSGTYPE != 'd') static PyObject * ws_update(PyObject *self, PyObject *view) { struct wirestate *ws; uint32_t remaining_bytes, nmessages = 0; unsigned char *buf, msgtype; char size_fragment[4]; short continuation; Py_ssize_t position = 0, len; PyObject *rob, *final_view = NULL; if (PyObject_AsReadBuffer(view, (const void **) &buf, &len)) return(NULL); if (len == 0) { /* * Nothing changed. */ return(PyLong_FromUnsignedLong(0)); } ws = (struct wirestate *) self; if (ws->final_view) { PyErr_SetString(PyExc_RuntimeError, "wire state has been terminated"); return(NULL); } remaining_bytes = ws->remaining_bytes; continuation = ws->continuation; if (continuation >= 0) { short sf_len = continuation, added; /* * Continuation of message header. */ added = 4 - sf_len; /* * If the buffer's length does not provide, limit to len. */ if (len < added) added = len; Py_MEMCPY(size_fragment, ws->size_fragment, 4); Py_MEMCPY(size_fragment + sf_len, buf, added); continuation = continuation + added; if (continuation == 4) { /* * Completed the size part of the header. */ Py_MEMCPY(&remaining_bytes, size_fragment, 4); remaining_bytes = (local_ntohl((int32_t) remaining_bytes)); if (remaining_bytes < 4) goto invalid_message_header; remaining_bytes = remaining_bytes - sf_len; if (remaining_bytes == 0) ++nmessages; continuation = -1; } else { /* * Consumed more of the header, but more is still needed. * Jump past the main loop. */ goto return_nmessages; } } do { if (remaining_bytes > 0) { position = position + remaining_bytes; if (position > len) { remaining_bytes = position - len; position = len; } else { remaining_bytes = 0; ++nmessages; } } /* * Done with view. */ if (position >= len) break; /* * Validate message type. */ msgtype = *(buf + position); if (CONDITION(msgtype)) { final_view = PySequence_GetSlice(view, position, len); break; } /* * Have enough for a complete header? */ if (len - position < 5) { /* * Start a continuation. Message type has been verified. */ continuation = (len - position) - 1; Py_MEMCPY(size_fragment, buf + position + 1, (Py_ssize_t) continuation); break; } /* * +1 to include the message type. 
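* The length field counts itself, but not the leading type byte, so the
* byte is added back here when computing the bytes remaining.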
*/ Py_MEMCPY(&remaining_bytes, buf + position + 1, 4); remaining_bytes = local_ntohl((int32_t) remaining_bytes) + 1; if (remaining_bytes < 5) goto invalid_message_header; } while(1); return_nmessages: rob = PyLong_FromUnsignedLong(nmessages); if (rob == NULL) { Py_XDECREF(final_view); return(NULL); } /* Commit new state */ ws->remaining_bytes = remaining_bytes; ws->final_view = final_view; ws->continuation = continuation; Py_MEMCPY(ws->size_fragment, size_fragment, 4); return(rob); invalid_message_header: PyErr_SetString(PyExc_ValueError, "message header contained an invalid size"); return(NULL); } static PyMethodDef ws_methods[] = { {"update", ws_update, METH_O, PyDoc_STR("update the state of the wire using the given buffer object"),}, {NULL} }; PyObject * ws_size_fragment(PyObject *self, void *closure) { struct wirestate *ws; ws = (struct wirestate *) self; return(PyBytes_FromStringAndSize(ws->size_fragment, ws->continuation <= 0 ? 0 : ws->continuation)); } PyObject * ws_remaining_bytes(PyObject *self, void *closure) { struct wirestate *ws; ws = (struct wirestate *) self; return(PyLong_FromLong( ws->continuation == -1 ? ws->remaining_bytes : -1 )); } PyObject * ws_final_view(PyObject *self, void *closure) { struct wirestate *ws; PyObject *rob; ws = (struct wirestate *) self; rob = ws->final_view ? ws->final_view : Py_None; Py_INCREF(rob); return(rob); } static PyGetSetDef ws_getset[] = { {"size_fragment", ws_size_fragment, NULL, PyDoc_STR("The data accumulated for the continuation."), NULL,}, {"remaining_bytes", ws_remaining_bytes, NULL, PyDoc_STR("Number of bytes necessary to complete the current message."), NULL,}, {"final_view", ws_final_view, NULL, PyDoc_STR("A memoryview of the data that triggered the CONDITION()."), NULL,}, {NULL} }; PyTypeObject WireState_Type = { PyVarObject_HEAD_INIT(NULL, 0) "postgresql.port.optimized.WireState", /* tp_name */ sizeof(struct wirestate), /* tp_basicsize */ 0, /* tp_itemsize */ ws_dealloc, /* tp_dealloc */ NULL, /* tp_print */ NULL, /* tp_getattr */ NULL, /* tp_setattr */ NULL, /* tp_compare */ NULL, /* tp_repr */ NULL, /* tp_as_number */ NULL, /* tp_as_sequence */ NULL, /* tp_as_mapping */ NULL, /* tp_hash */ NULL, /* tp_call */ NULL, /* tp_str */ NULL, /* tp_getattro */ NULL, /* tp_setattro */ NULL, /* tp_as_buffer */ Py_TPFLAGS_DEFAULT, /* tp_flags */ PyDoc_STR("Track the state of the wire."), /* tp_doc */ NULL, /* tp_traverse */ NULL, /* tp_clear */ NULL, /* tp_richcompare */ 0, /* tp_weaklistoffset */ NULL, /* tp_iter */ NULL, /* tp_iternext */ ws_methods, /* tp_methods */ NULL, /* tp_members */ ws_getset, /* tp_getset */ NULL, /* tp_base */ NULL, /* tp_dict */ NULL, /* tp_descr_get */ NULL, /* tp_descr_set */ 0, /* tp_dictoffset */ NULL, /* tp_init */ NULL, /* tp_alloc */ ws_new, /* tp_new */ NULL, /* tp_free */ }; /* * vim: ts=3:sw=3:noet: */ fe-1.1.0/postgresql/port/signal1_msw.py000066400000000000000000000042511203372773200201000ustar00rootroot00000000000000## # .port.signal1_msw ## """ Support for PG signals on Windows platforms. This implementation supports all known versions of PostgreSQL. (2010) CallNamedPipe: http://msdn.microsoft.com/en-us/library/aa365144%28VS.85%29.aspx """ import errno from ctypes import windll, wintypes, pointer # CallNamedPipe from kernel32.
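# The prototype is declared below so ctypes can marshal the pipe name,
# the one-byte in/out buffers, and the DWORD out-parameter correctly.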
CallNamedPipeA = windll.kernel32.CallNamedPipeA CallNamedPipeA.restype = wintypes.BOOL CallNamedPipeA.argtypes = ( wintypes.LPCSTR, # in namedpipename wintypes.LPVOID, # in inbuffer (for signal number) wintypes.DWORD, # in inbuffersize (always 1) wintypes.LPVOID, # OutBuffer (signal return validation) wintypes.DWORD, # in OutBufferSize (always 1) wintypes.LPVOID, # out bytes read, really LPDWORD. wintypes.DWORD, # in timeout ) from signal import SIGTERM, SIGINT, SIG_DFL # SYNC: Values taken from the port/win32.h file. SIG_DFL=0 SIGHUP=1 SIGQUIT=3 SIGTRAP=5 SIGABRT=22 # /* Set to match W32 value -- not UNIX value */ SIGKILL=9 SIGPIPE=13 SIGALRM=14 SIGSTOP=17 SIGTSTP=18 SIGCONT=19 SIGCHLD=20 SIGTTIN=21 SIGTTOU=22 # /* Same as SIGABRT -- no problem, I hope */ SIGWINCH=28 SIGUSR1=30 SIGUSR2=31 # SYNC: port.h PG_SIGNAL_COUNT = 32 # In the situation of another variant, another module should be constructed. def kill(pid : int, signal : int, timeout = 1000, dword1 = wintypes.DWORD(1)): """ Re-implementation of pg_kill for win32 using ctypes. """ if pid <= 0: raise OSError(errno.EINVAL, "process group not supported") if signal < 0 or signal >= PG_SIGNAL_COUNT: raise OSError(errno.EINVAL, "unsupported signal number") inbuffer = pointer(wintypes.BYTE(signal)) outbuffer = pointer(wintypes.BYTE(0)) outbytes = pointer(wintypes.DWORD(0)) pidpipe = br'\\.\pipe\pgsignal_' + str(pid).encode('ascii') timeout = wintypes.DWORD(timeout) r = CallNamedPipeA( pidpipe, inbuffer, dword1, outbuffer, dword1, outbytes, timeout ) if r: if outbuffer.contents.value == signal: if outbytes.contents.value == 1: # success return # Don't bother emulating the other failure cases/abstractions. # CallNamedPipeA should raise a WindowsError on those failures. raise OSError(errno.ESRCH, "unexpected output from CallNamedPipeA") __docformat__ = 'reStructuredText' fe-1.1.0/postgresql/project.py000066400000000000000000000005221203372773200163330ustar00rootroot00000000000000'project information' #: project name name = 'py-postgresql' #: IRI based project identity identity = 'http://python.projects.postgresql.org/' meaculpa = 'Python+Postgres' contact = 'python-general@pgfoundry.org' abstract = 'Driver and tools library for PostgreSQL' version_info = (1, 1, 0) version = '.'.join(map(str, version_info)) fe-1.1.0/postgresql/protocol/000077500000000000000000000000001203372773200161555ustar00rootroot00000000000000fe-1.1.0/postgresql/protocol/__init__.py000066400000000000000000000000611203372773200202630ustar00rootroot00000000000000## # .protocol ## """ PQ protocol facilities """ fe-1.1.0/postgresql/protocol/buffer.py000066400000000000000000000010121203372773200177720ustar00rootroot00000000000000## # .protocol.buffer ## """ This is an abstraction module that provides the working buffer implementation. If a C compiler is not available on the system that built the package, the slower `postgresql.protocol.pbuffer` module can be used in `postgresql.port.optimized.buffer`'s absence. This provides a convenient place to import the necessary module without concerning the local code with the details. """ try: from ..port.optimized import pq_message_stream except ImportError: from .pbuffer import pq_message_stream fe-1.1.0/postgresql/protocol/client3.py000066400000000000000000000363471203372773200201050ustar00rootroot00000000000000## # .protocol.client3 ## """ Protocol version 3.0 client and tools. """ import os import weakref from .buffer import pq_message_stream from . import element3 as element from . 
import xact3 as xact __all__ = ('Connection',) client_detected_protocol_error = element.ClientError(( (b'S', 'FATAL'), (b'C', '08P01'), (b'M', "wire-data caused exception in protocol transaction"), (b'H', "Protocol error detected."), )) client_connect_timeout = element.ClientError(( (b'S', 'FATAL'), (b'C', '--TOE'), (b'M', "connect timed out"), )) not_pq_error = element.ClientError(( # ProtocolError (b'S', 'FATAL'), (b'C', '08P01'), (b'M', 'server did not support SSL negotiation'), (b'H', 'The server is probably not PostgreSQL.'), )) no_ssl_error = element.ClientError(( (b'S', 'FATAL'), # InsecurityError (b'C', '--SEC'), (b'M', 'SSL was required, and the server could not accommodate'), )) # Details in __context__ ssl_failed_error = element.ClientError(( (b'S', 'FATAL'), # InsecurityError (b'C', '--SEC'), (b'M', 'SSL negotiation caused exception'), )) # failed to complete the connection, but no error set. # indicates a programmer error. partial_connection_error = element.ClientError(( (b'S', 'FATAL'), (b'C', '--XXX'), (b'M', "failed to complete negotiation"), (b'H', "Negotiation failed to complete, but no " \ "error was attributed on the connection."), )) eof_error = element.ClientError(( (b'S', 'FATAL'), (b'C', '08006'), (b'M', 'unexpected EOF from server'), (b'D', "Zero-length read from the connection's socket."), )) class Connection(object): """ A PQv3 connection. Operations are designed to not raise exceptions. The user of the connection must check for failures. This is done to encourage users to use their own Exception hierarchy. """ _tracer = None def tracer(): def fget(self): return self._tracer def fset(self, value): self._tracer = value self.write_messages = self.traced_write_messages self.read_messages = self.traced_read_messages def fdel(self): del self._tracer self.write_messages = self.standard_write_messages self.read_messages = self.standard_read_messages doc = 'Callable object to pass protocol trace strings to. '\ '(Normally a write method.)' return locals() tracer = property(**tracer()) def synchronize(self): """ Explicitly send a Synchronize message to the backend. Useful for forcing the completion of lazily processed transactions. NOTE: This will not cause trash to be taken out. """ if self.xact is not None: self.complete() x = xact.Instruction((element.SynchronizeMessage,)) self.xact = x self.complete() def interrupt(self, timeout = None): cq = element.CancelRequest(self.backend_id, self.key).bytes() s = self.socket_factory(timeout = timeout) try: s.sendall(cq) finally: s.close() def connect(self, ssl = None, timeout = None): """ Establish the connection to the server. If `ssl` is None, the socket will not be secured. If `ssl` is True, the socket will be secured, but it will close the connection and return if SSL is not available. If `ssl` is False, the socket will attempt to be secured, but will continue even in the event of a server that does not support SSL. `timeout` will be passed directly to the configured `socket_factory`. """ if hasattr(self, 'socket'): # If there's a socket attribute it normally means # that the connection has already been connected. # Successfully or not; doesn't matter. return # The existence of the socket attribute indicates an attempt was made.
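# Assign None first so a failed factory call below still marks the attempt.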
self.socket = None try: self.socket = self.socket_factory(timeout = timeout) except ( self.socket_factory.timeout_exception, self.socket_factory.fatal_exception ) as err: self.xact.state = xact.Complete self.xact.fatal = True self.xact.exception = err if self.socket_factory.timed_out(err): self.xact.error_message = client_connect_timeout else: errmsg = self.socket_factory.fatal_exception_message(err) # It's an error that occurred during socket creation/connection. # Even if there isn't a known fatal message, # identify it as fatal and set an ambiguous message. self.xact.error_message = element.ClientError(( (b'S', 'FATAL'), # ConnectionRejectionError (b'C', '08004'), (b'M', errmsg or "could not connect"), )) return if ssl is not None: # if ssl is True, ssl is *required* # if ssl is False, ssl will be tried, but not required # if ssl is None, no SSL negotiation will happen self.ssl_negotiation = supported = self.negotiate_ssl() # b'S' or b'N' was *not* received. if supported is None: # probably not PQv3.. self.xact.fatal = True self.xact.error_message = not_pq_error self.xact.state = xact.Complete return # b'N' was received, but ssl is required. if not supported and ssl is True: # ssl is required.. self.xact.fatal = True self.xact.error_message = no_ssl_error self.xact.state = xact.Complete return if supported: # Make an SSL connection. try: self.socket = self.socket_factory.secure(self.socket) except Exception as err: # Any exception marks a failure. self.xact.exception = err self.xact.fatal = True self.xact.state = xact.Complete self.xact.error_message = ssl_failed_error return # time to negotiate negxact = self.xact self.complete() if negxact.state is xact.Complete and negxact.fatal is None: self.key = negxact.killinfo.key self.backend_id = negxact.killinfo.pid elif not hasattr(self.xact, 'error_message'): # if it's not complete, something strange happened. # make sure to clean up... self.xact.fatal = True self.xact.state = xact.Complete self.xact.error_message = partial_connection_error def negotiate_ssl(self) -> (bool, None): """ Negotiate SSL If SSL is available--received b'S'--return True. If SSL is unavailable--received b'N'--return False. Otherwise, return None. Indicates non-PQv3 endpoint. """ r = element.NegotiateSSLMessage.bytes() while r: r = r[self.socket.send(r):] status = self.socket.recv(1) if status == b'S': return True elif status == b'N': return False # probably not postgresql. return None def read_into(self, Complete = xact.Complete): """ read data from the wire and write it into the message buffer. """ BUFFER_HAS_MSG = self.message_buffer.has_message BUFFER_WRITE_MSG = self.message_buffer.write RECV_DATA = self.socket.recv RECV_BYTES = self.recvsize XACT = self.xact while not BUFFER_HAS_MSG(): if self.read_data is not None: BUFFER_WRITE_MSG(self.read_data) self.read_data = None # If the read_data satisfied a message, # no more data should be read. continue try: self.read_data = RECV_DATA(RECV_BYTES) except self.socket_factory.fatal_exception as e: msg = self.socket_factory.fatal_exception_message(e) if msg is not None: XACT.state = Complete XACT.fatal = True XACT.exception = e XACT.error_message = element.ClientError(( (b'S', 'FATAL'), (b'C', '08006'), (b'M', msg), )) return False else: # It's probably a non-fatal error, # timeout or try again.. raise ## # nothing read from a blocking socket? it's over. if self.read_data == b'': XACT.state = Complete XACT.fatal = True XACT.error_message = eof_error return False # Got data. Put it in the buffer and clear read_data. 
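# (write() is expected to return None here, so this assignment both buffers the data and clears the attribute.)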
self.read_data = BUFFER_WRITE_MSG(self.read_data) return True def standard_read_messages(self): 'read more messages into self.read when self.read is empty' r = True if not self.read: # get more data from the wire and # write it into the message buffer. r = self.read_into() self.read = self.message_buffer.read() return r read_messages = standard_read_messages def send_message_data(self): """ send all `message_data`. If an exception occurs, it will check if the exception is fatal or not. """ SEND_DATA = self.socket.send try: while self.message_data: # Send data while there is data to send. self.message_data = self.message_data[ SEND_DATA(self.message_data): ] except self.socket_factory.fatal_exception as e: msg = self.socket_factory.fatal_exception_message(e) if msg is not None: # it's fatal. self.xact.state = xact.Complete self.xact.fatal = True self.xact.exception = e self.xact.error_message = element.ClientError(( (b'S', 'FATAL'), (b'C', '08006'), (b'M', msg), )) return False else: # It wasn't fatal, so just raise raise return True def standard_write_messages(self, messages, cat_messages = element.cat_messages ): 'protocol message writer' if self.writing is not self.written: self.message_data += cat_messages(self.writing) self.written = self.writing if messages is not self.writing: self.writing = messages self.message_data += cat_messages(self.writing) self.written = self.writing return self.send_message_data() write_messages = standard_write_messages def traced_write_messages(self, messages): 'message_writer used when tracing' for msg in messages: t = getattr(msg, 'type', None) if t is not None: data_out = msg.bytes() self._tracer('↑ {type}({lend}): {data}{nl}'.format( type = repr(t)[2:-1], lend = len(data_out), data = repr(data_out), nl = os.linesep )) else: # It's not a message instance, so assume raw data. self._tracer('↑__(%d): %r%s' %( len(msg), msg, os.linesep )) return self.standard_write_messages(messages) def traced_read_messages(self): 'message_reader used when tracing' r = self.standard_read_messages() for msg in self.read: self._tracer('↓ %r(%d): %r%s' %( msg[0], len(msg[1]), msg[1], os.linesep) ) return r def take_out_trash(self): """ close cursors and statements slated for closure. """ xm = [] cursors = 0 for x in self.garbage_cursors: xm.append(element.ClosePortal(x)) cursors += 1 statements = 0 for x in self.garbage_statements: xm.append(element.CloseStatement(x)) statements += 1 xm.append(element.SynchronizeMessage) x = xact.Instruction(xm) self.xact = x del self.garbage_cursors[:cursors] del self.garbage_statements[:statements] self.complete() def push(self, x): """ setup the given transaction to be processed. """ # Push any queued closures onto the transaction or a new transaction. if x.state is xact.Complete: # It's already complete. return if self.xact is not None: self.complete() if self.xact is None: if self.garbage_statements or self.garbage_cursors: # This *has* to be done before a new transaction # is pushed. self.take_out_trash() if self.xact is None: # set it as the current transaction and begin self.xact = x # start it up self.step() def step(self): """ Make a single transition on the transaction. This should be used during COPY TO STDOUT or large result sets to stream information out. """ x = self.xact try: dir, op = x.state if dir is xact.Sending: self.write_messages(x.messages) # The "op" callable will either switch the state, or # set the 'messages' attribute with a new sequence # of message objects for more writing. 
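# Either way, a single call makes exactly one transition.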
op() elif dir is xact.Receiving: self.read_messages() self.read = self.read[op(self.read):] self.state = getattr(x, 'last_ready', self.state) else: raise RuntimeError( "unexpected PQ transaction state: " + repr(dir) ) except self.socket_factory.try_again_exception as e: # Unlike _complete, this catches at the outermost level # as there is no loop here for more transitioning. if self.socket_factory.try_again(e): # Can't read or write, ATM? Consider it a transition. :( return else: raise if x.state is xact.Complete and \ getattr(self.xact, 'fatal', None) is not True: # only remove the transaction if it's *not* fatal self.xact = None def complete(self): 'complete the current transaction' # Continue to transition until all transactions have been # completed, or an exception occurs that does not signal retry. x = self.xact R = xact.Receiving S = xact.Sending C = xact.Complete READ_MORE = self.read_messages WRITE_MESSAGES = self.write_messages while x.state is not C: try: while x.state[0] is R: if READ_MORE(): self.read = self.read[x.state[1](self.read):] # push() always takes one step, so it is likely that # the transaction is done sending out data by the time # complete() is called. while x.state[0] is S: if WRITE_MESSAGES(x.messages): x.state[1]() # Multiple calls to get() without signaling # completion *should* yield the same set over # and over again. except self.socket_factory.try_again_exception as e: if not self.socket_factory.try_again(e): raise except Exception as proto_exc: # If an exception is raised here, it's a protocol or a programming error. # XXX: It may be useful to have this closer to the actual # message so that a more informative message can be given. x.fatal = True x.state = xact.Complete x.exception = proto_exc x.error_message = client_detected_protocol_error self.state = b'' return self.state = getattr(x, 'last_ready', self.state) if getattr(x, 'fatal', None) is not True: # only remove the transaction if it's *not* fatal self.xact = None def register_cursor(self, cursor, pq_cursor_id): trash = self.trash_cursor self.cursors[pq_cursor_id] = weakref.ref(cursor, lambda ref: trash(pq_cursor_id)) def trash_cursor(self, pq_cursor_id): try: del self.cursors[pq_cursor_id] except KeyError: pass self.garbage_cursors.append(pq_cursor_id) def register_statement(self, statement, pq_statement_id): trash = self.trash_statement self.statements[pq_statement_id] = weakref.ref(statement, lambda ref: trash(pq_statement_id)) def trash_statement(self, pq_statement_id): try: del self.statements[pq_statement_id] except KeyError: pass self.garbage_statements.append(pq_statement_id) def __str__(self): if hasattr(self, 'ssl_negotiation'): if self.ssl_negotiation is True: ssl = 'SSL' elif self.ssl_negotiation is False: ssl = 'NOSSL after SSL' else: ssl = 'NOSSL' excstr = ''.join(self.exception_string(type(self.exception), self.exception)) return str(self.socket_factory) \ + ' -> (' + ssl + ')' \ + os.linesep + excstr.strip() def __init__(self, socket_factory, startup, password = b'',): """ Create a connection. This does not establish the connection, it only initializes it. """ self.key = None self.backend_id = None self.socket_factory = socket_factory self.xact = xact.Negotiation( element.Startup(startup), password ) self.cursors = {} self.statements = {} self.garbage_statements = [] self.garbage_cursors = [] self.message_buffer = pq_message_stream() self.recvsize = 8192 self.read = () # bytes received. 
self.read_data = None # serialized message data to be written self.message_data = b'' # messages to be written. self.writing = None # messages that have already been transformed into bytes. # (used to detect whether messages have already been written) self.written = None self.state = 'INITIALIZED' fe-1.1.0/postgresql/protocol/element3.py000066400000000000000000000516201203372773200202470ustar00rootroot00000000000000## # .protocol.element3 ## 'PQ version 3.0 elements' import sys import os import pprint from struct import unpack, Struct from .message_types import message_types from ..python.structlib import ushort_pack, ushort_unpack, ulong_pack, ulong_unpack try: from ..port.optimized import parse_tuple_message, pack_tuple_data except ImportError: def pack_tuple_data(atts, none = None, ulong_pack = ulong_pack, blen = bytes.__len__ ): return b''.join([ b'\xff\xff\xff\xff' if x is none else (ulong_pack(blen(x)) + x) for x in atts ]) try: from ..port.optimized import cat_messages except ImportError: from ..python.structlib import lH_pack, long_pack # Special case tuple()'s def _pack_tuple(t, blen = bytes.__len__, tlen = tuple.__len__, pack_head = lH_pack, ulong_pack = ulong_pack, ptd = pack_tuple_data, ): # NOTE: duplicated from above r = b''.join([ b'\xff\xff\xff\xff' if x is None else (ulong_pack(blen(x)) + x) for x in t ]) return pack_head((blen(r) + 6, tlen(t))) + r def cat_messages(messages, lpack = long_pack, blen = bytes.__len__, tuple = tuple, pack_tuple = _pack_tuple ): return b''.join([ (x.bytes() if x.__class__ is not bytes else ( b'd' + lpack(blen(x) + 4) + x )) if x.__class__ is not tuple else ( b'D' + pack_tuple(x) ) for x in messages ]) del _pack_tuple, lH_pack, long_pack StringFormat = b'\x00\x00' BinaryFormat = b'\x00\x01' class Message(object): bytes_struct = Struct("!cL") __slots__ = () def __repr__(self): return '%s.%s(%s)' %( type(self).__module__, type(self).__name__, ', '.join([repr(getattr(self, x)) for x in self.__slots__]) ) def __eq__(self, ob): return isinstance(ob, type(self)) and self.type == ob.type and \ not False in ( getattr(self, x) == getattr(ob, x) for x in self.__slots__ ) def bytes(self): data = self.serialize() return self.bytes_struct.pack(self.type, len(data) + 4) + data @classmethod def parse(typ, data): return typ(data) class StringMessage(Message): """ A message based on a single string component. """ type = b'' __slots__ = ('data',) def __repr__(self): return '%s.%s(%s)' %( type(self).__module__, type(self).__name__, repr(self.data), ) def __getitem__(self, i): return self.data.__getitem__(i) def __init__(self, data): self.data = data def serialize(self): return bytes(self.data) + b'\x00' @classmethod def parse(typ, data): if not data.endswith(b'\x00'): raise ValueError("string message not NUL-terminated") return typ(data[:-1]) class TupleMessage(tuple, Message): """ A message whose data is based on a tuple structure. """ type = b'' __slots__ = () def __repr__(self): return '%s.%s(%s)' %( type(self).__module__, type(self).__name__, tuple.__repr__(self) ) class Void(Message): """ An absolutely empty message. When serialized, it always yields an empty string.
""" type = b'' __slots__ = () def bytes(self): return b'' def serialize(self): return b'' def __new__(typ, *args, **kw): return VoidMessage VoidMessage = Message.__new__(Void) def dict_message_repr(self): return '%s.%s(**%s)' %( type(self).__module__, type(self).__name__, pprint.pformat(dict(self)) ) class WireMessage(Message): def __init__(self, typ_data): self.type = message_types[typ_data[0][0]] self.data = typ_data[1] def serialize(self): return self[1] @classmethod def parse(typ, data): if ulong_unpack(data[1:5]) != len(data) - 1: raise ValueError( "invalid wire message where data is %d bytes and " \ "internal size stamp is %d bytes" %( len(data), ulong_unpack(data[1:5]) + 1 ) ) return typ((data[0:1], data[5:])) class EmptyMessage(Message): 'An abstract message that is always empty' __slots__ = () type = b'' def __new__(typ): return typ.SingleInstance def serialize(self): return b'' @classmethod def parse(typ, data): if data != b'': raise ValueError("empty message(%r) had data" %(typ.type,)) return typ.SingleInstance class Notify(Message): 'Asynchronous notification message' type = message_types[b'A'[0]] __slots__ = ('pid', 'channel', 'payload',) def __init__(self, pid, channel, payload = b''): self.pid = pid self.channel = channel self.payload = payload def serialize(self): return ulong_pack(self.pid) + \ self.channel + b'\x00' + \ self.payload + b'\x00' @classmethod def parse(typ, data): pid = ulong_unpack(data) channel, payload, _ = data[4:].split(b'\x00', 2) return typ(pid, channel, payload) class ShowOption(Message): """ShowOption(name, value) GUC variable information from backend""" type = message_types[b'S'[0]] __slots__ = ('name', 'value') def __init__(self, name, value): self.name = name self.value = value def serialize(self): return self.name + b'\x00' + self.value + b'\x00' @classmethod def parse(typ, data): return typ(*(data.split(b'\x00', 2)[0:2])) class Complete(StringMessage): 'Command completion message.' type = message_types[b'C'[0]] __slots__ = () @classmethod def parse(typ, data): return typ(data.rstrip(b'\x00')) def extract_count(self): """ Extract the last set of digits as an integer. """ # Find the last sequence of digits. # If there are no fields consisting only of digits, there is no count. for x in reversed(self.data.split()): if x.isdigit(): return int(x) return None def extract_command(self): """ Strip all the *surrounding* digits and spaces from the command tag, and return that string. 
""" return self.data.strip(b'\c\n\t 0123456789') or None class Null(EmptyMessage): 'Null command' type = message_types[b'I'[0]] __slots__ = () NullMessage = Message.__new__(Null) Null.SingleInstance = NullMessage class NoData(EmptyMessage): 'Null command' type = message_types[b'n'[0]] __slots__ = () NoDataMessage = Message.__new__(NoData) NoData.SingleInstance = NoDataMessage class ParseComplete(EmptyMessage): 'Parse reaction' type = message_types[b'1'[0]] __slots__ = () ParseCompleteMessage = Message.__new__(ParseComplete) ParseComplete.SingleInstance = ParseCompleteMessage class BindComplete(EmptyMessage): 'Bind reaction' type = message_types[b'2'[0]] __slots__ = () BindCompleteMessage = Message.__new__(BindComplete) BindComplete.SingleInstance = BindCompleteMessage class CloseComplete(EmptyMessage): 'Close statement or Portal' type = message_types[b'3'[0]] __slots__ = () CloseCompleteMessage = Message.__new__(CloseComplete) CloseComplete.SingleInstance = CloseCompleteMessage class Suspension(EmptyMessage): 'Portal was suspended, more tuples for reading' type = message_types[b's'[0]] __slots__ = () SuspensionMessage = Message.__new__(Suspension) Suspension.SingleInstance = SuspensionMessage class Ready(Message): 'Ready for new query' type = message_types[b'Z'[0]] possible_states = ( message_types[b'I'[0]], message_types[b'E'[0]], message_types[b'T'[0]], ) __slots__ = ('xact_state',) def __init__(self, data): if data not in self.possible_states: raise ValueError("invalid state for Ready message: " + repr(data)) self.xact_state = data def serialize(self): return self.xact_state class Notice(Message, dict): """ Notification message Used by PQ to emit INFO, NOTICE, and WARNING messages among other severities. """ type = message_types[b'N'[0]] __slots__ = () __repr__ = dict_message_repr def serialize(self): return b'\x00'.join([ k + v for k, v in self.items() if k and v is not None ]) + b'\x00' @classmethod def parse(typ, data, msgtypes = message_types): return typ([ (msgtypes[x[0]], x[1:]) # "if x" reduce empty fields for x in data.split(b'\x00') if x ]) class ClientNotice(Notice): __slots__ = () def serialize(self): raise RuntimeError("cannot serialize ClientNotice") @classmethod def parse(self): raise RuntimeError("cannot parse ClientNotice") class Error(Notice): """Incoming error""" type = message_types[b'E'[0]] __slots__ = () class ClientError(Error): __slots__ = () def serialize(self): raise RuntimeError("cannot serialize ClientError") @classmethod def parse(self): raise RuntimeError("cannot serialize ClientError") class FunctionResult(Message): """Function result value""" type = message_types[b'V'[0]] __slots__ = ('result',) def __init__(self, datum): self.result = datum def serialize(self): return self.result is None and b'\xff\xff\xff\xff' or \ ulong_pack(len(self.result)) + self.result @classmethod def parse(typ, data): if data == b'\xff\xff\xff\xff': return typ(None) size = ulong_unpack(data[0:4]) data = data[4:] if size != len(data): raise ValueError( "data length(%d) is not equal to the specified message size(%d)" %( len(data), size ) ) return typ(data) class AttributeTypes(TupleMessage): """Tuple attribute types""" type = message_types[b't'[0]] __slots__ = () def serialize(self): return ushort_pack(len(self)) + b''.join([ulong_pack(x) for x in self]) @classmethod def parse(typ, data): ac = ushort_unpack(data[0:2]) args = data[2:] if len(args) != ac * 4: raise ValueError("invalid argument type data size") return typ(unpack('!%dL'%(ac,), args)) class TupleDescriptor(TupleMessage): 
"""Tuple description""" type = message_types[b'T'[0]] struct = Struct("!LhLhlh") __slots__ = () def keys(self): return [x[0] for x in self] def serialize(self): return ushort_pack(len(self)) + b''.join([ x[0] + b'\x00' + self.struct.pack(*x[1:]) for x in self ]) @classmethod def parse(typ, data): ac = ushort_unpack(data[0:2]) atts = [] data = data[2:] ca = 0 while ca < ac: # End Of Attribute Name eoan = data.index(b'\x00') name = data[0:eoan] data = data[eoan+1:] # name, relationId, columnNumber, typeId, typlen, typmod, format atts.append((name,) + typ.struct.unpack(data[0:18])) data = data[18:] ca += 1 return typ(atts) class Tuple(TupleMessage): """Incoming tuple""" type = message_types[b'D'[0]] __slots__ = () def serialize(self): return ushort_pack(len(self)) + pack_tuple_data(self) @classmethod def parse(typ, data, T = tuple, ulong_unpack = ulong_unpack, len = len ): natts = ushort_unpack(data[0:2]) atts = [] offset = 2 add = atts.append while natts > 0: alo = offset offset += 4 size = data[alo:offset] if size == b'\xff\xff\xff\xff': att = None else: al = ulong_unpack(size) ao = offset offset = ao + al att = data[ao:offset] add(att) natts -= 1 return T(atts) try: parse = parse_tuple_message except NameError: # This is an override when port.optimized is available. pass class KillInformation(Message): 'Backend cancellation information' type = message_types[b'K'[0]] struct = Struct("!LL") __slots__ = ('pid', 'key') def __init__(self, pid, key): self.pid = pid self.key = key def serialize(self): return self.struct.pack(self.pid, self.key) @classmethod def parse(typ, data): return typ(*typ.struct.unpack(data)) class CancelRequest(KillInformation): 'Abort the query in the specified backend' type = b'' from .version import CancelRequestCode as version packed_version = version.bytes() __slots__ = ('pid', 'key') def serialize(self): return self.packed_version + self.struct.pack( self.pid, self.key ) def bytes(self): data = self.serialize() return ulong_pack(len(data) + 4) + self.serialize() @classmethod def parse(typ, data): if data[0:4] != typ.packed_version: raise ValueError("invalid cancel query code") return typ(*typ.struct.unpack(data[4:])) class NegotiateSSL(Message): "Discover backend's SSL support" type = b'' from .version import NegotiateSSLCode as version packed_version = version.bytes() __slots__ = () def __new__(typ): return NegotiateSSLMessage def bytes(self): data = self.serialize() return ulong_pack(len(data) + 4) + data def serialize(self): return self.packed_version @classmethod def parse(typ, data): if data != typ.packed_version: raise ValueError("invalid SSL Negotiation code") return NegotiateSSLMessage NegotiateSSLMessage = Message.__new__(NegotiateSSL) class Startup(Message, dict): """ Initiate a connection using the given keywords. 
""" type = b'' from postgresql.protocol.version import V3_0 as version packed_version = version.bytes() __slots__ = () __repr__ = dict_message_repr def serialize(self): return self.packed_version + b''.join([ k + b'\x00' + v + b'\x00' for k, v in self.items() if v is not None ]) + b'\x00' def bytes(self): data = self.serialize() return ulong_pack(len(data) + 4) + data @classmethod def parse(typ, data): if data[0:4] != typ.packed_version: raise ValueError("invalid version code {1}".format(repr(data[0:4]))) kw = dict() key = None for value in data[4:].split(b'\x00')[:-2]: if key is None: key = value continue kw[key] = value key = None return typ(kw) AuthRequest_OK = 0 AuthRequest_Cleartext = 3 AuthRequest_Password = AuthRequest_Cleartext AuthRequest_Crypt = 4 AuthRequest_MD5 = 5 # Unsupported by pg_protocol. AuthRequest_KRB4 = 1 AuthRequest_KRB5 = 2 AuthRequest_SCMC = 6 AuthRequest_SSPI = 9 AuthRequest_GSS = 7 AuthRequest_GSSContinue = 8 AuthNameMap = { AuthRequest_Password : 'Cleartext', AuthRequest_Crypt : 'Crypt', AuthRequest_MD5 : 'MD5', AuthRequest_KRB4 : 'Kerberos4', AuthRequest_KRB5 : 'Kerberos5', AuthRequest_SCMC : 'SCM Credential', AuthRequest_SSPI : 'SSPI', AuthRequest_GSS : 'GSS', AuthRequest_GSSContinue : 'GSSContinue', } class Authentication(Message): """Authentication(request, salt)""" type = message_types[b'R'[0]] __slots__ = ('request', 'salt') def __init__(self, request, salt): self.request = request self.salt = salt def serialize(self): return ulong_pack(self.request) + self.salt @classmethod def parse(typ, data): return typ(ulong_unpack(data[0:4]), data[4:]) class Password(StringMessage): 'Password supplement' type = message_types[b'p'[0]] __slots__ = ('data',) class Disconnect(EmptyMessage): 'Close the connection' type = message_types[b'X'[0]] __slots__ = () DisconnectMessage = Message.__new__(Disconnect) Disconnect.SingleInstance = DisconnectMessage class Flush(EmptyMessage): 'Flush' type = message_types[b'H'[0]] __slots__ = () FlushMessage = Message.__new__(Flush) Flush.SingleInstance = FlushMessage class Synchronize(EmptyMessage): 'Synchronize' type = message_types[b'S'[0]] __slots__ = () SynchronizeMessage = Message.__new__(Synchronize) Synchronize.SingleInstance = SynchronizeMessage class Query(StringMessage): """Execute the query with the given arguments""" type = message_types[b'Q'[0]] __slots__ = ('data',) class Parse(Message): """Parse a query with the specified argument types""" type = message_types[b'P'[0]] __slots__ = ('name', 'statement', 'argtypes') def __init__(self, name, statement, argtypes): self.name = name self.statement = statement self.argtypes = argtypes @classmethod def parse(typ, data): name, statement, args = data.split(b'\x00', 2) ac = ushort_unpack(args[0:2]) args = args[2:] if len(args) != ac * 4: raise ValueError("invalid argument type data") at = unpack('!%dL'%(ac,), args) return typ(name, statement, at) def serialize(self): ac = ushort_pack(len(self.argtypes)) return self.name + b'\x00' + self.statement + b'\x00' + ac + b''.join([ ulong_pack(x) for x in self.argtypes ]) class Bind(Message): """ Bind a parsed statement with the given arguments to a Portal Bind( name, # Portal/Cursor identifier statement, # Prepared Statement name/identifier aformats, # Argument formats; Sequence of BinaryFormat or StringFormat. arguments, # Argument data; Sequence of None or argument data(str). rformats, # Result formats; Sequence of BinaryFormat or StringFormat. 
) """ type = message_types[b'B'[0]] __slots__ = ('name', 'statement', 'aformats', 'arguments', 'rformats') def __init__(self, name, statement, aformats, arguments, rformats): self.name = name self.statement = statement self.aformats = aformats self.arguments = arguments self.rformats = rformats def serialize(self, len = len): args = self.arguments ac = ushort_pack(len(args)) ad = pack_tuple_data(tuple(args)) return \ self.name + b'\x00' + self.statement + b'\x00' + \ ac + b''.join(self.aformats) + ac + ad + \ ushort_pack(len(self.rformats)) + b''.join(self.rformats) @classmethod def parse(typ, message_data): name, statement, data = message_data.split(b'\x00', 2) ac = ushort_unpack(data[:2]) offset = 2 + (2 * ac) aformats = unpack(("2s" * ac), data[2:offset]) natts = ushort_unpack(data[offset:offset+2]) args = list() offset += 2 while natts > 0: alo = offset offset += 4 size = data[alo:offset] if size == b'\xff\xff\xff\xff': att = None else: al = ulong_unpack(size) ao = offset offset = ao + al att = data[ao:offset] args.append(att) natts -= 1 rfc = ushort_unpack(data[offset:offset+2]) ao = offset + 2 offset = ao + (2 * rfc) rformats = unpack(("2s" * rfc), data[ao:offset]) return typ(name, statement, aformats, args, rformats) class Execute(Message): """Fetch results from the specified Portal""" type = message_types[b'E'[0]] __slots__ = ('name', 'max') def __init__(self, name, max = 0): self.name = name self.max = max def serialize(self): return self.name + b'\x00' + ulong_pack(self.max) @classmethod def parse(typ, data): name, max = data.split(b'\x00', 1) return typ(name, ulong_unpack(max)) class Describe(StringMessage): """Describe a Portal or Prepared Statement""" type = message_types[b'D'[0]] __slots__ = ('data',) def serialize(self): return self.subtype + self.data + b'\x00' @classmethod def parse(typ, data): if data[0:1] != typ.subtype: raise ValueError( "invalid Describe message subtype, %r; expected %r" %( typ.subtype, data[0:1] ) ) return super().parse(data[1:]) class DescribeStatement(Describe): subtype = message_types[b'S'[0]] __slots__ = ('data',) class DescribePortal(Describe): subtype = message_types[b'P'[0]] __slots__ = ('data',) class Close(StringMessage): """Generic Close""" type = message_types[b'C'[0]] __slots__ = () def serialize(self): return self.subtype + self.data + b'\x00' @classmethod def parse(typ, data): if data[0:1] != typ.subtype: raise ValueError( "invalid Close message subtype, %r; expected %r" %( typ.subtype, data[0:1] ) ) return super().parse(data[1:]) class CloseStatement(Close): """Close the specified Statement""" subtype = message_types[b'S'[0]] __slots__ = () class ClosePortal(Close): """Close the specified Portal""" subtype = message_types[b'P'[0]] __slots__ = () class Function(Message): """Execute the specified function with the given arguments""" type = message_types[b'F'[0]] __slots__ = ('oid', 'aformats', 'arguments', 'rformat') def __init__(self, oid, aformats, args, rformat): self.oid = oid self.aformats = aformats self.arguments = args self.rformat = rformat def serialize(self): ac = ushort_pack(len(self.arguments)) return ulong_pack(self.oid) + \ ac + b''.join(self.aformats) + \ ac + pack_tuple_data(tuple(self.arguments)) + self.rformat @classmethod def parse(typ, data): oid = ulong_unpack(data[0:4]) ac = ushort_unpack(data[4:6]) offset = 6 + (2 * ac) aformats = unpack(("2s" * ac), data[6:offset]) natts = ushort_unpack(data[offset:offset+2]) args = list() offset += 2 while natts > 0: alo = offset offset += 4 size = data[alo:offset] if size == 
b'\xff\xff\xff\xff': att = None else: al = ulong_unpack(size) ao = offset offset = ao + al att = data[ao:offset] args.append(att) natts -= 1 return typ(oid, aformats, args, data[offset:]) class CopyBegin(Message): type = None struct = Struct("!BH") __slots__ = ('format', 'formats') def __init__(self, format, formats): self.format = format self.formats = formats def serialize(self): return self.struct.pack(self.format, len(self.formats)) + b''.join([ ushort_pack(x) for x in self.formats ]) @classmethod def parse(typ, data): format, natts = typ.struct.unpack(data[:3]) formats_str = data[3:] if len(formats_str) != natts * 2: raise ValueError("number of formats and data do not match up") return typ(format, [ ushort_unpack(formats_str[x:x+2]) for x in range(0, natts * 2, 2) ]) class CopyToBegin(CopyBegin): """Begin copying to""" type = message_types[b'H'[0]] __slots__ = ('format', 'formats') class CopyFromBegin(CopyBegin): """Begin copying from""" type = message_types[b'G'[0]] __slots__ = ('format', 'formats') class CopyData(Message): type = message_types[b'd'[0]] __slots__ = ('data',) def __init__(self, data): self.data = bytes(data) def serialize(self): return self.data @classmethod def parse(typ, data): return typ(data) class CopyFail(StringMessage): type = message_types[b'f'[0]] __slots__ = ('data',) class CopyDone(EmptyMessage): type = message_types[b'c'[0]] __slots__ = ('data',) CopyDoneMessage = Message.__new__(CopyDone) CopyDone.SingleInstance = CopyDoneMessage fe-1.1.0/postgresql/protocol/message_types.py000066400000000000000000000011021203372773200213710ustar00rootroot00000000000000## # .protocol.message_types ## """ Data module providing a sequence of bytes objects whose value corresponds to its index in the sequence. This provides resource for buffer objects to use common message type objects. WARNING: It's tempting to use the 'is' operator and in some circumstances that may be okay. However, it's possible (sys.modules.clear()) for the extension modules' copy of this to become inconsistent with what protocol.element3 and protocol.xact3 are using, so it's important to **not** use 'is'. """ message_types = tuple([bytes((x,)) for x in range(256)]) fe-1.1.0/postgresql/protocol/pbuffer.py000066400000000000000000000072711203372773200201670ustar00rootroot00000000000000## # .protocol.pbuffer ## """ Pure Python message buffer implementation. Given data read from the wire, buffer the data until a complete message has been received. """ __all__ = ['pq_message_stream'] from io import BytesIO import struct from .message_types import message_types xl_unpack = struct.Struct('!xL').unpack_from class pq_message_stream(object): 'provide a message stream from a data stream' _block = 512 _limit = _block * 4 def __init__(self): self._strio = BytesIO() self._start = 0 def truncate(self): "remove all data in the buffer" self._strio.truncate(0) self._start = 0 def _rtruncate(self, amt = None): "[internal] remove the given amount of data" strio = self._strio if amt is None: amt = self._strio.tell() strio.seek(0, 2) size = strio.tell() # if the total size is equal to the amt, # then the whole thing is going to be truncated. 
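# Otherwise, copy the surviving tail toward the front in _block-sized chunks, then shorten the stream.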
if size == amt: strio.truncate(0) return copyto_pos = 0 copyfrom_pos = amt while True: strio.seek(copyfrom_pos) data = strio.read(self._block) # Next copyfrom copyfrom_pos = strio.tell() strio.seek(copyto_pos) strio.write(data) if len(data) != self._block: break # Next copyto copyto_pos = strio.tell() strio.truncate(size - amt) def has_message(self, xl_unpack = xl_unpack, len = len): "if the buffer has a message available" strio = self._strio strio.seek(self._start) header = strio.read(5) if len(header) < 5: return False length, = xl_unpack(header) if length < 4: raise ValueError("invalid message size '%d'" %(length,)) strio.seek(0, 2) return (strio.tell() - self._start) >= length + 1 def __len__(self, xl_unpack = xl_unpack, len = len): "number of messages in buffer" count = 0 rpos = self._start strio = self._strio strio.seek(self._start) while True: # get the message metadata header = strio.read(5) rpos += 5 if len(header) < 5: # not enough data for another message break # unpack the length from the header length, = xl_unpack(header) rpos += length - 4 if length < 4: raise ValueError("invalid message size '%d'" %(length,)) strio.seek(length - 4 - 1, 1) if len(strio.read(1)) != 1: break count += 1 return count def _get_message(self, mtypes = message_types, len = len, xl_unpack = xl_unpack, ): strio = self._strio header = strio.read(5) if len(header) < 5: return length, = xl_unpack(header) typ = mtypes[header[0]] if length < 4: raise ValueError("invalid message size '%d'" %(length,)) length -= 4 body = strio.read(length) if len(body) < length: # Not enough data for message. return return (typ, body) def next_message(self): if self._start > self._limit: self._rtruncate(self._start) self._start = 0 self._strio.seek(self._start) msg = self._get_message() if msg is not None: self._start = self._strio.tell() return msg def __next__(self): if self._start > self._limit: self._rtruncate(self._start) self._start = 0 self._strio.seek(self._start) msg = self._get_message() if msg is None: raise StopIteration self._start = self._strio.tell() return msg def read(self, num = 0xFFFFFFFF, len = len): if self._start > self._limit: self._rtruncate(self._start) self._start = 0 new_start = self._start self._strio.seek(new_start) l = [] while len(l) < num: msg = self._get_message() if msg is None: break l.append(msg) new_start += (5 + len(msg[1])) self._start = new_start return l def write(self, data): # Always append data; it's a stream, damnit.. self._strio.seek(0, 2) self._strio.write(data) def getvalue(self): self._strio.seek(self._start) return self._strio.read() fe-1.1.0/postgresql/protocol/version.py000066400000000000000000000020421203372773200202120ustar00rootroot00000000000000## # .protocol.version ## 'PQ version class' from struct import Struct version_struct = Struct('!HH') class Version(tuple): """Version((major, minor)) -> Version Version serializer and parser. """ major = property(fget = lambda s: s[0]) minor = property(fget = lambda s: s[1]) def __new__(subtype, major_minor : '(major, minor)'): (major, minor) = major_minor major = int(major) minor = int(minor) # If it can't be packed like this, it's not a valid version. 
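# (Each component must fit in the unsigned shorts of the "!HH" struct.)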
try: version_struct.pack(major, minor) except Exception as e: raise ValueError("unpackable major and minor") from e return tuple.__new__(subtype, (major, minor)) def __int__(self): return (self[0] << 16) | self[1] def bytes(self): return version_struct.pack(self[0], self[1]) def __repr__(self): return '%d.%d' %(self[0], self[1]) def parse(self, data): return self(version_struct.unpack(data)) parse = classmethod(parse) CancelRequestCode = Version((1234, 5678)) NegotiateSSLCode = Version((1234, 5679)) V2_0 = Version((2, 0)) V3_0 = Version((3, 0)) fe-1.1.0/postgresql/protocol/xact3.py000066400000000000000000000510161203372773200175540ustar00rootroot00000000000000## # .protocol.xact3 - protocol state machine ## 'PQ version 3.0 client transactions' import sys import os import pprint from abc import ABCMeta, abstractmethod from itertools import chain from operator import itemgetter get0 = itemgetter(0) get1 = itemgetter(1) from ..python.functools import Composition as compose from . import element3 as element from hashlib import md5 from ..resolved.crypt import crypt try: from ..port.optimized import consume_tuple_messages except ImportError: pass Receiving = True Sending = False Complete = (None, None) AsynchronousMap = { element.Notice.type : element.Notice.parse, element.Notify.type : element.Notify.parse, element.ShowOption.type : element.ShowOption.parse, } def return_arg(x): return x message_expectation = \ "expected message of types {expected}, " \ "but received {received} instead".format class Transaction(object, metaclass = ABCMeta): """ If the fatal attribute is not None, an error occurred, and the `error_message` attribute should be set to a element3.Error instance. """ fatal = None @abstractmethod def messages_received(self): """ Return an iterable to the messages received that have been processed. """ class Closing(Transaction): """ Send the disconnect message and mark the connection as closed. """ error_message = element.ClientError(( (b'S', 'FATAL'), # pg_exc.ConnectionDoesNotExistError.code (b'C', '08003'), (b'M', 'operation on closed connection'), (b'H', "A new connection needs to be "\ "created in order to query the server."), )) def messages_received(self): return () def sent(self): """ Empty messages and mark complete. """ self.messages = () self.fatal = True self.state = Complete def __init__(self): self.messages = (element.DisconnectMessage,) self.state = (Sending, self.sent) class Negotiation(Transaction): """ Negotiation is a protocol transaction used to manage the initial stage of a connection to PostgreSQL. This transaction revolves around the `state_machine` method which is a generator that takes individual messages and progresses the state of the connection negotiation. This was chosen over the route taken by `Transaction`, seen later, as it's not terribly performance intensive and there are many conditions which make a generator ideal for managing the state. """ state = None def __init__(self, startup_message : "startup message to send", password : "password source data(encoded password bytes)", ): self.startup_message = startup_message self.password = password self.received = [()] self.asyncs = [] self.authtype = None self.killinfo = None self.authok = None self.last_ready = None self.machine = self.state_machine() self.messages = next(self.machine) self.state = (Sending, self.sent) def __repr__(self): s = type(self).__module__ + "." 
+ type(self).__name__ s += pprint.pformat((self.startup_message, self.password)).lstrip() return s def messages_received(self): return self.processed def sent(self): """ Empty messages and switch state to receiving. This is called by the user after the `messages` have been sent to the remote end. That is, this merely finalizes the "Sending" state. """ self.messages = () self.state = (Receiving, self.put_messages) def put_messages(self, messages): # Record everything received. out_messages = () if messages is not self.received[-1]: self.received.append(messages) else: raise RuntimeError("negotiation was interrupted") # if an Error message was found, complete and leave. count = 0 try: for x in messages: count += 1 if x[0] == element.Error.type: if self.fatal is None: self.error_message = element.Error.parse(x[1]) self.fatal = True self.state = Complete return count elif x[0] in AsynchronousMap: self.asyncs.append( AsynchronousMap[x[0]](x[1]) ) else: out_messages = self.machine.send(x) if out_messages: break except StopIteration: # generator is complete, negotiation is complete.. self.state = Complete return count if out_messages: self.messages = out_messages self.state = (Sending, self.sent) return count def unsupported_auth_request(self, req): self.fatal = True self.error_message = element.ClientError(( (b'S', "FATAL"), (b'C', "--AUT"), (b'M', "unsupported authentication request %r(%d)" %( element.AuthNameMap.get(req, ''), req, )), (b'H', "'postgresql.protocol' only supports: MD5, crypt, plaintext, and trust."), )) self.state = Complete def state_machine(self): """ Generator keeping the state of the connection negotiation process. """ x = (yield (self.startup_message,)) if x[0] != element.Authentication.type: self.fatal = True self.error_message = element.ClientError(( (b'S', 'FATAL'), (b'C', '08P01'), (b'M', message_expectation( expected = element.Authentication.type, received = x[0], )), )) return self.authtype = element.Authentication.parse(x[1]) req = self.authtype.request if req != element.AuthRequest_OK: if req == element.AuthRequest_Cleartext: pw = self.password elif req == element.AuthRequest_Crypt: pw = crypt(self.password, self.authtype.salt) elif req == element.AuthRequest_MD5: pw = md5(self.password + self.startup_message[b'user']).hexdigest().encode('ascii') pw = b'md5' + md5(pw + self.authtype.salt).hexdigest().encode('ascii') else: ## # Not going to work. Sorry :( # The many authentication types supported by PostgreSQL are not # easy to implement, especially when implementations for the # type don't exist for Python. self.unsupported_auth_request(req) return x = (yield (element.Password(pw),)) self.authok = element.Authentication.parse(x[1]) if self.authok.request != element.AuthRequest_OK: self.fatal = True self.error_message = element.ClientError(( (b'S', 'FATAL'), (b'C', "08P01"), (b'M', "expected OK from the authentication " \ "message, but received %s(%s) instead" %( repr(element.AuthNameMap.get( self.authok.request, '' )), repr(self.authok.request), ), ) )) return else: self.authok = self.authtype # Done authenticating, pick up the killinfo and the ready message. 
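# Yielding None means there is nothing to send; put_messages() simply feeds in the next received message.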
x = (yield None) if x[0] != element.KillInformation.type: self.fatal = True self.error_message = element.ClientError(( (b'S', 'FATAL'), (b'C', '08P01'), (b'M', message_expectation( expected = element.KillInformation.type, received = repr(x[0]), )), )) return self.killinfo = element.KillInformation.parse(x[1]) x = (yield None) if x[0] != element.Ready.type: self.fatal = True self.error_message = element.ClientError(( (b'S', "FATAL"), (b'C', "08P01"), (b'M', message_expectation( expected = repr(element.Ready.type), received = repr(x[0]), )) )) return self.last_ready = element.Ready.parse(x[1]) class Instruction(Transaction): """ Manage the state of a sequence of request messages to be sent to the server. It provides the messages to be sent and takes the response messages for order and integrity validation: Instruction([.element3.Message(), ..]) A message must be one of: * `.element3.Query` * `.element3.Function` * `.element3.Parse` * `.element3.Bind` * `.element3.Describe` * `.element3.Close` * `.element3.Execute` * `.element3.Synchronize` * `.element3.Flush` """ state = None CopyFailMessage = element.CopyFail(b"invalid termination") # The hook is the dictionary that provides the path for the # current working message. The received messages ultimately come # through here and get parsed using the associated callable. # Messages that complete a command are paired with None. hook = { element.Query.type : ( # 0: Start. { element.TupleDescriptor.type : (element.TupleDescriptor.parse, 3), element.Null.type : (element.Null.parse, 0), element.Complete.type : (element.Complete.parse, 0), element.CopyToBegin.type : (element.CopyToBegin.parse, 2), element.CopyFromBegin.type : (element.CopyFromBegin.parse, 1), element.Ready.type : (element.Ready.parse, None), }, # 1: Complete. { element.Complete.type : (element.Complete.parse, 0), }, # 2: Copy Data. # CopyData until CopyDone. # Complete comes next. { element.CopyData.type : (return_arg, 2), element.CopyDone.type : (element.CopyDone.parse, 1), }, # 3: Row Data. { element.Tuple.type : (element.Tuple.parse, 3), element.Complete.type : (element.Complete.parse, 0), element.Ready.type : (element.Ready.parse, None), }, ), element.Function.type : ( {element.FunctionResult.type : (element.FunctionResult.parse, 1)}, {element.Ready.type : (element.Ready.parse, None)}, ), # Extended Protocol element.Parse.type : ( {element.ParseComplete.type : (element.ParseComplete.parse, None)}, ), element.Bind.type : ( {element.BindComplete.type : (element.BindComplete.parse, None)}, ), element.Describe.type : ( # Still needs the descriptor. { element.AttributeTypes.type : (element.AttributeTypes.parse, 1), element.TupleDescriptor.type : ( element.TupleDescriptor.parse, None ), }, # NoData or TupleDescriptor { element.NoData.type : (element.NoData.parse, None), element.TupleDescriptor.type : ( element.TupleDescriptor.parse, None ), }, ), element.Close.type : ( {element.CloseComplete.type : (element.CloseComplete.parse, None)}, ), element.Execute.type : ( # 0: Start. { element.Tuple.type : (element.Tuple.parse, 1), element.CopyToBegin.type : (element.CopyToBegin.parse, 2), element.CopyFromBegin.type : (element.CopyFromBegin.parse, 3), element.Null.type : (element.Null.parse, None), element.Complete.type : (element.Complete.parse, None), }, # 1: Row Data. { element.Tuple.type : (element.Tuple.parse, 1), element.Suspension.type : (element.Suspension.parse, None), element.Complete.type : (element.Complete.parse, None), }, # 2: Copy Data. 
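	# (Each numbered step table in this hook maps a message type to a
	# (parser, next_step) pair; a next_step of None marks the end of
	# the current command's response sequence. A row-returning Execute,
	# for instance, enters step 1 on the first Tuple, stays there for
	# subsequent Tuples, and finishes on Complete or Suspension.)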
{ element.CopyData.type : (return_arg, 2), element.CopyDone.type : (element.CopyDone.parse, 3), }, # 3: Complete. { element.Complete.type : (element.Complete.parse, None), }, ), element.Synchronize.type : ( {element.Ready.type : (element.Ready.parse, None)}, ), element.Flush.type : None, } initial_state = ( (), # last messages, (0, 0), # request position, response position (0, 0), # last request position, last response position ) def __init__(self, commands, asynchook = return_arg): """ Initialize an `Instruction` instance using the given commands. Commands are `postgresql.protocol.element3.Message` instances: * `.element3.Query` * `.element3.Function` * `.element3.Parse` * `.element3.Bind` * `.element3.Describe` * `.element3.Close` * `.element3.Execute` * `.element3.Synchronize` * `.element3.Flush` """ # Commands are accessed by index. self.commands = tuple(commands) self.asynchook = asynchook self.completed = [] self.last = self.initial_state self.messages = list(self.commands) self.state = (Sending, self.standard_sent) self.fatal = None for cmd in self.commands: if cmd.type not in self.hook: raise TypeError( "unknown message type for PQ 3.0 protocol", cmd.type ) def __repr__(self, format = '{mod}.{name}({nl}{args})'.format): return format( mod = type(self).__module__, name = type(self).__name__, nl = os.linesep, args = pprint.pformat(self.commands) ) def messages_received(self): 'Received and validate messages' return chain.from_iterable(map(get1, self.completed)) def reverse(self, chaining = chain.from_iterable, map = map, transform = compose((get1, reversed)), reversed = reversed ): """ A iterator that producing the completed messages in reverse order. Last in, first out. """ return chaining(map(transform, reversed(self.completed))) def standard_put(self, messages, SWITCH_TYPES = element.Execute.type + element.Query.type, ERROR_TYPE = element.Error.type, READY_TYPE = element.Ready.type, ERROR_PARSE = element.Error.parse, len = len, ): """ Attempt to forward the state of the transaction using the given messages. "put" messages into the transaction for processing. If an invalid command is initialized on the `Transaction` object, an `IndexError` will be thrown. """ COMMANDS = self.commands NCOMMANDS = len(COMMANDS) HOOK = self.hook # We processed it, but apparently something went wrong, # so go ahead and reprocess it. if messages is self.last[0]: offset, current_step = self.last[1] # don't clear the asyncs. they have already been process by the hook. else: offset, current_step = self.last[2] # it's a new set, so we can clear the asyncs record. self._asyncs = [] cmd = COMMANDS[offset] paths = HOOK[cmd.type] processed = [] count = 0 for x in messages: count += 1 # For the current message, get the path for the message # and whether it signals the end of the current command path, next_step = paths[current_step].get(x[0], (None, None)) if path is None: # No path for message type, could be a protocol error. if x[0] == ERROR_TYPE: em = ERROR_PARSE(x[1]) # Is it fatal? self.fatal = fatal = em[b'S'].upper() != b'ERROR' self.error_message = em if fatal is True: # Can't sync up if the session is closed. self.state = Complete return count # Error occurred, so sync up with backend if # the current command is not 'Q' or 'F' as they # imply a sync message. if cmd.type not in ( element.Function.type, element.Query.type ): # Adjust the offset forward until the Sync message is found. for offset in range(offset, NCOMMANDS): if COMMANDS[offset] is element.SynchronizeMessage: break else: ## # It's done. 
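					# (Reached when no Synchronize message remains in the
					# command sequence: after an error, the backend discards
					# the rest of the extended-protocol messages until it
					# sees a Sync, so with none queued there is nothing
					# further to receive.)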
self.state = Complete return count ## # Not quite done, the state(Ready) message still # needs to be received. cmd = COMMANDS[offset] paths = HOOK[cmd.type] # On a new command, setup the new step. current_step = 0 continue elif x[0] in AsynchronousMap: if x not in self._asyncs: msg = AsynchronousMap[x[0]](x[1]) try: self.asynchook(msg) except Exception as err: # exception thrown by async message handler? # notify the user, but continue... sys.excepthook(*sys.exc_info()) # it's been processed, so don't process it again. self._asyncs.append(x) else: ## # Procotol violation. self.fatal = True self.error_message = element.ClientError(( (b'S', 'FATAL'), (b'C', '08P01'), (b'M', message_expectation( expected = tuple(paths[current_step].keys()), received = x[0] )), )) self.state = Complete return count else: # Process a valid message. r = path(x[1]) processed.append(r) if next_step is not None: current_step = next_step else: current_step = 0 if r.type == READY_TYPE: self.last_ready = r.xact_state # Done with the current command. Increment the offset, and # try to process the new command with the remaining data. paths = None while paths is None: # Increment the offset past any commands # whose hook is None (FlushMessage) offset += 1 # If the offset is the length, # the transaction is complete. if offset == NCOMMANDS: # Done with transaction. break cmd = COMMANDS[offset] paths = HOOK[cmd.type] else: # More commands to process in this transaction. continue # The while loop was broken offset == len(self.commands) # So, that's all there is to this transaction. break # Push the messages onto the completed list if they # have not been put there already. if not self.completed or self.completed[-1][0] != id(messages): self.completed.append((id(messages), processed)) # Store the state for the next transition. self.last = (messages, self.last[2], (offset, current_step),) if offset == NCOMMANDS: # transaction complete. self.state = Complete elif cmd.type in SWITCH_TYPES and processed: # Check the context to identify if the state should be # switched to an optimized processor. last = processed[-1] if last.__class__ is bytes: # Fast path for COPY data, 'd' messages. self.state = (Receiving, self.put_copydata) elif last.__class__ is tuple: # Fast path for Tuples, 'D' messages. self.state = (Receiving, self.put_tupledata) elif last.type == element.CopyFromBegin.type: # In this case, the commands that were sent past # message starting the COPY, need to be re-issued # once the COPY is complete. PG cleared its buffer. self.CopyFailSequence = (self.CopyFailMessage,) + \ self.commands[offset+1:] self.CopyDoneSequence = (element.CopyDoneMessage,) + \ self.commands[offset+1:] self.state = (Sending, self.sent_from_stdin) elif last.type == element.CopyToBegin.type: # Should be seeing COPY data soon. self.state = (Receiving, self.put_copydata) return count def put_copydata(self, messages): """ In the context of a copy, `put_copydata` is used as a fast path for storing `element.CopyData` messages. When a non-`element.CopyData.type` message is received, it reverts the ``state`` attribute back to `standard_put` to process the message-sequence. """ copydata = element.CopyData.type # "Fail" quickly if the last message is not copy data. 
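		# (Checking the tail first is a cheap heuristic: reads taken
		# during a COPY are normally homogeneous runs of CopyData
		# messages, so if the final message is CopyData the whole batch
		# almost certainly is, and the per-message dispatch of
		# standard_put can be skipped.)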
if messages[-1][0] != copydata: self.state = (Receiving, self.standard_put) return self.standard_put(messages) lines = [x[1] for x in messages if x[0] == copydata] if len(lines) != len(messages): self.state = (Receiving, self.standard_put) return self.standard_put(messages) if not self.completed or self.completed[-1][0] != id(messages): self.completed.append((id(messages), lines)) self.last = (messages, self.last[2], self.last[2],) return len(messages) try: def put_tupledata(self, messages, consume = consume_tuple_messages, ): tuplemessages = consume(messages) if not tuplemessages: # bad handler switch? self.state = (Receiving, self.standard_put) return self.standard_put(messages) if not self.completed or self.completed[-1][0] != id(messages): self.completed.append(((id(messages), tuplemessages))) self.last = (messages, self.last[2], self.last[2],) return len(tuplemessages) except NameError: ## # No consume_tuple_messages function. def put_tupledata(self, messages, p = element.Tuple.parse, t = element.Tuple.type, ): """ Fast path used when inside an Execute command. As soon as tuple data is seen. """ # Fallback to `standard_put` quickly if the last # message is not tuple data. if messages[-1][0] is not t: self.state = (Receiving, self.standard_put) return self.standard_put(messages) tuplemessages = [p(x[1]) for x in messages if x[0] == t] if len(tuplemessages) != len(messages): self.state = (Receiving, self.standard_put) return self.standard_put(messages) if not self.completed or self.completed[-1][0] != id(messages): self.completed.append(((id(messages), tuplemessages))) self.last = (messages, self.last[2], self.last[2],) return len(messages) def standard_sent(self): """ Empty messages and switch state to receiving. This is called by the user after the `messages` have been sent to the remote end. That is, this merely finalizes the "Sending" state. """ self.messages = () self.state = (Receiving, self.standard_put) sent = standard_sent def sent_from_stdin(self): """ The state method for sending copy data. After each call to `sent_from_stdin`, the `messages` attribute is set to a `CopyFailSequence`. This sequence of messages assures that the COPY will be properly terminated. If new copy data is not provided, or `messages` is *not* set to `CopyDoneSequence`, the transaction will instruct the remote end to cause the COPY to fail. """ if self.messages is self.CopyDoneSequence or \ self.messages is self.CopyFailSequence: # If the last sent `messages` is CopyDone or CopyFail, finish out the # transaction. ## self.messages = () self.state = (Receiving, self.standard_put) else: ## # Initialize to CopyFail, if the messages attribute is not # set properly before each invocation, the transaction is # being misused and will be terminated. self.messages = self.CopyFailSequence fe-1.1.0/postgresql/python/000077500000000000000000000000001203372773200156355ustar00rootroot00000000000000fe-1.1.0/postgresql/python/__init__.py000066400000000000000000000001131203372773200177410ustar00rootroot00000000000000""" Python tools package. Various extensions to the standard library. """ fe-1.1.0/postgresql/python/command.py000066400000000000000000000374631203372773200176420ustar00rootroot00000000000000## # .python.command - Python command emulation module. ## """ Create and Execute Python Commands ================================== The purpose of this module is to simplify the creation of a Python command interface. 
Normally, one would want to do this if there is a *common* need for a certain Python environment that may be, at least, partially initialized via command line options. A notable case would be a Python environment with a database connection whose connection parameters came from the command line. That is, Python + command line driven configuration. The module also provides an extended interactive console that provides backslash commands for editing and executing temporary files. Use ``python -m pythoncommand`` to try it out. Simple usage:: import sys import os import optparse import pythoncommand as pycmd op = optparse.OptionParser( "%prog [options] [script] [script arguments]", version = '1.0', ) op.disable_interspersed_args() # Basically, the standard -m and -c. (Some additional ones for fun) op.add_options(pycmd.default_optparse_options) co, ca = op.parse_args(args[1:]) # This initializes an execution instance which gathers all the information # about the code to be ran when ``pyexe`` is called. pyexe = pycmd.Execution(ca, context = getattr(co, 'python_context', ()), loader = getattr(co, 'python_main', None), ) # And run it. Any exceptions will be printed via print_exception. rv = pyexe() sys.exit(rv) """ import os import sys import re import code import types import optparse import subprocess import contextlib from gettext import gettext as _ from traceback import print_exception from pkgutil import get_loader as module_loader class single_loader(object): """ used for "loading" string modules(think -c) """ def __init__(self, source): self.source = source def get_filename(self, fullpath): if fullpath == self.source: return '' def get_code(self, fullpath): if fullpath == self.source: return compile(self.source, '', 'exec') def get_source(self, fullpath): if fullpath == self.source: return self.source class file_loader(object): """ used for "loading" scripts """ def __init__(self, filepath, fileobj = None): self.filepath = filepath if fileobj is not None: self._source = fileobj.read() def get_filename(self, fullpath): if fullpath == self.filepath: return self.filepath def get_source(self, fullpath): if fullpath == self.filepath: return self._read() def _read(self): if hasattr(self, '_source'): return self._source f = open(self.filepath) try: return f.read() finally: f.close() def get_code(self, fullpath): if fullpath != self.filepath: return return compile(self._read(), self.filepath, 'exec') def extract_filepath(x): if x.startswith('file://'): return x[7:] return None def extract_module(x): if x.startswith('module:'): return x[7:] return None module_loader_descriptor = ( 'Python module', module_loader, extract_module ) file_loader_descriptor = ( 'Python script', file_loader, extract_filepath ) single_loader_descriptor = ( 'Python command', single_loader, lambda x: x ) _directory = ( module_loader_descriptor, file_loader_descriptor, ) directory = list(_directory) def find_loader(ident, dir = directory): for x in dir: xid = x[2](ident) if xid is not None: return x ## # optparse options ## def append_context(option, opt_str, value, parser): """ Add some context to the execution of the Python code using loader module's directory list of loader descriptions. If no loader can be found, assume it's a Python command. 
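
	For example, these option values (illustrative) select loaders as
	follows::

		-C file:///tmp/env.py    -> file_loader ('file://' prefix)
		-C module:mypkg.context  -> module_loader ('module:' prefix)
		-C 'print("hi")'         -> single_loader (the fallback)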
""" pc = getattr(parser.values, option.dest, None) or [] if not pc: setattr(parser.values, option.dest, pc) ldesc = find_loader(value) if ldesc is None: ldesc = single_loader_descriptor pc.append((value, ldesc)) def set_python_main(option, opt_str, value, parser): """ Set the main Python code; after contexts are initialized, main is ran. """ main = (value, option.python_loader) setattr(parser.values, option.dest, main) # only terminate parsing if not interspersing arguments if not parser.allow_interspersed_args: parser.rargs.insert(0, '--') context = optparse.make_option( '-C', '--context', help = _('Python context code to run[file://,module:,]'), dest = 'python_context', action = 'callback', callback = append_context, type = 'str' ) module = optparse.make_option( '-m', help = _('Python module to run as script(__main__)'), dest = 'python_main', action = 'callback', callback = set_python_main, type = 'str' ) module.python_loader = module_loader_descriptor command = optparse.make_option( '-c', help = _('Python expression to run(__main__)'), dest = 'python_main', action = 'callback', callback = set_python_main, type = 'str' ) command.python_loader = single_loader_descriptor default_optparse_options = [ context, module, command, ] class ExtendedConsole(code.InteractiveConsole): """ Console subclass providing some convenient backslash commands. """ def __init__(self, *args, **kw): import tempfile self.mktemp = tempfile.mktemp import shlex self.split = shlex.split code.InteractiveConsole.__init__(self, *args, **kw) self.bsc_map = {} self.temp_files = {} self.past_buffers = [] self.register_backslash(r'\?', self.showhelp, "Show this help message.") self.register_backslash(r'\set', self.bs_set, "Configure environment variables. \set without arguments to show all") self.register_backslash(r'\E', self.bs_E, "Edit a file or a temporary script.") self.register_backslash(r'\i', self.bs_i, "Execute a Python script within the interpreter's context.") self.register_backslash(r'\e', self.bs_e, "Edit and Execute the file directly in the context.") self.register_backslash(r'\x', self.bs_x, "Execute the Python command within this process.") def interact(self, *args, **kw): self.showhelp(None, None) return super().interact(*args,**kw) def showtraceback(self): e, v, tb = sys.exc_info() sys.last_type, sys.last_value, sys.last_traceback = e, v, tb print_exception(e, v, tb.tb_next or tb) def register_backslash(self, bscmd, meth, doc): self.bsc_map[bscmd] = (meth, doc) def execslash(self, line): """ If push() gets a line that starts with a backslash, execute the command that the backslash sequence corresponds to. """ cmd = line.split(None, 1) cmd.append('') bsc = self.bsc_map.get(cmd[0]) if bsc is None: self.write("ERROR: unknown backslash command: %s%s"%(cmd, os.linesep)) else: return bsc[0](cmd[0], cmd[1]) def showhelp(self, cmd, arg): i = list(self.bsc_map.items()) i.sort(key = lambda x: x[0]) helplines = os.linesep.join([ ' %s%s%s' %( x[0], ' ' * (8 - len(x[0])), x[1][1] ) for x in i ]) self.write("Backslash Commands:%s%s%s" %( os.linesep*2, helplines, os.linesep*2 )) def bs_set(self, cmd, arg): """ Set a value in the interpreter's environment. 
""" if arg: for x in self.split(arg): if '=' in x: k, v = x.split('=', 1) os.environ[k] = v self.write("%s=%s%s" %(k, v, os.linesep)) elif x: self.write("%s=%s%s" %(x, os.environ.get(x, ''), os.linesep)) else: for k,v in os.environ.items(): self.write("%s=%s%s" %(k, v, os.linesep)) def resolve_path(self, path, dont_create = False): """ Get the path of the given string; if the path is not absolute and does not contain path separators, identify it as a temporary file. """ if not os.path.isabs(path) and not os.path.sep in path: # clean it up to avoid typos path = path.strip().lower() tmppath = self.temp_files.get(path) if tmppath is None: if dont_create is False: tmppath = self.mktemp( suffix = '.py', prefix = '_console_%s_' %(path,) ) self.temp_files[path] = tmppath else: return path return tmppath return path def execfile(self, filepath): src = open(filepath) try: try: co = compile(src.read(), filepath, 'exec') except SyntaxError: co = None print_exception(*sys.exc_info()) finally: src.close() if co is not None: try: exec(co, self.locals, self.locals) except: e, v, tb = sys.exc_info() print_exception(e, v, tb.tb_next or tb) def editfiles(self, filepaths): sp = list(filepaths) # ;) sp.insert(0, os.environ.get('EDITOR', 'vi')) return subprocess.call(sp) def bs_i(self, cmd, arg): 'execute the files' for x in self.split(arg) or ('',): p = self.resolve_path(x, dont_create = True) self.execfile(p) def bs_E(self, cmd, arg): 'edit the files, but *only* edit them' self.editfiles([self.resolve_path(x) for x in self.split(arg) or ('',)]) def bs_e(self, cmd, arg): 'edit *and* execute the files' filepaths = [self.resolve_path(x) for x in self.split(arg) or ('',)] self.editfiles(filepaths) for x in filepaths: self.execfile(x) def bs_x(self, cmd, arg): rv = -1 if len(cmd) > 1: a = self.split(arg) a.insert(0, '\\x') try: rv = command(argv = a) except SystemExit as se: rv = se.code self.write("[Return Value: %d]%s" %(rv, os.linesep)) def push(self, line): # Has to be a ps1 context. if not self.buffer and line.startswith('\\'): try: self.execslash(line) except: # print the exception, but don't raise. e, v, tb = sys.exc_info() print_exception(e, v, tb.tb_next or tb) else: return code.InteractiveConsole.push(self, line) @contextlib.contextmanager def postmortem(funcpath): if not funcpath: yield None else: pm = funcpath.split('.') attr = pm.pop(-1) modpath = '.'.join(pm) try: m = __import__(modpath, fromlist = modpath) pmobject = getattr(m, attr, None) except ValueError: pmobject = None sys.stderr.write( "%sERROR: no object at %r for postmortem%s"%( os.linesep, funcpath, os.linesep ) ) try: yield None except: try: sys.last_type, sys.last_value, sys.last_traceback = sys.exc_info() pmobject() except: sys.stderr.write( "[Exception raised by Postmortem]" + os.linesep ) print_exception(*sys.exc_info()) raise class Execution(object): """ Given argv and context make an execution instance that, when called, will execute the configured Python code. This class provides the ability to identify what the main part of the execution of the configured Python code. For instance, shall it execute a console, the file that the first argument points to, a -m option module appended to the python_context option value, or the code given within -c? """ def __init__(self, args, context = (), main = None, loader = None, stdin = sys.stdin ): """ args The arguments passed to the script; usually sys.argv after being processed by optparse(ca). context A list of loader descriptors that will be used to establish the context of __main__ module. 
main Overload to explicitly state what main is. None will cause the class to attempt to fill in the attribute using 'args' and other system objects like sys.stdin. """ self.args = args self.context = context and list(context) or () if main is not None: self.main = main elif loader is not None: # Main explicitly stated, resolve the path and the loader path, ldesc = loader ltitle, rloader, xpath = ldesc l = rloader(path) if l is None: raise ImportError( "%s %r does not exist or cannot be read" %( ltitle, path ) ) self.main = (path, l) # If there are args, but no main, run the first arg. elif args: fp = self.args[0] f = open(fp) try: l = file_loader(fp, fileobj = f) finally: f.close() self.main = (self.args[0], l) self.args = self.args[1:] # There is no main, no loader, and no args. # If stdin is not a tty, use stdin as the main file. elif not stdin.isatty(): l = file_loader('', fileobj = stdin) self.main = ('', l) # tty and no "main". else: # console self.main = (None, None) self.reset_module__main__() def reset_module__main__(self): mod = types.ModuleType('__main__') mod.__builtins__ = __builtins__ mod.__package__ = None self.module__main__ = mod path = getattr(self.main[1], 'fullname', None) if path is not None: mod.__package__ = '.'.join(path.split('.')[:-1]) def _call(self, console = ExtendedConsole, context = None ): """ Initialize the context and run main in the given locals (Note: tramples on sys.argv, __main__ in sys.modules) (Use __call__ instead) """ sys.modules['__main__'] = self.module__main__ md = self.module__main__.__dict__ # Establish execution context in the locals; # iterate over all the loaders in self.context and for path, ldesc in self.context: ltitle, loader, xpath = ldesc rpath = xpath(path) li = loader(rpath) if li is None: sys.stderr.write( "%s %r does not exist or cannot be read%s" %( ltitle, rpath, os.linesep ) ) return 1 try: code = li.get_code(rpath) except: print_exception(*sys.exc_info()) return 1 self.module__main__.__file__ = getattr( li, 'get_filename', lambda x: x )(rpath) self.module__main__.__loader__ = li try: exec(code, md, md) except: e, v, tb = sys.exc_info() print_exception(e, v, tb.tb_next or tb) return 1 if self.main == (None, None): # It's interactive. sys.argv = self.args or [''] # Use readline if available try: import readline except ImportError: pass ic = console(locals = md) try: ic.interact() except SystemExit as e: return e.code return 0 else: # It's ultimately a code object. path, loader = self.main self.module__main__.__file__ = getattr( loader, 'get_filename', lambda x: x )(path) sys.argv = list(self.args) sys.argv.insert(0, self.module__main__.__file__) try: code = loader.get_code(path) except: print_exception(*sys.exc_info()) return 1 rv = 0 exe_exception = False try: if context is not None: with context: try: exec(code, md, md) except: exe_exception = True raise else: try: exec(code, md, md) except: exe_exception = True raise except SystemExit as e: # Assume it's an exe_exception as anything ran in `context` # shouldn't cause an exception. 
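			# (SystemExit carries the requested exit status; record it as
			# the return value and stash the exception triple in the
			# sys.last_* attributes so a postmortem hook can still
			# inspect it.)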
rv = e.code e, v, tb = sys.exc_info() sys.last_type = e sys.last_value = v sys.last_traceback = (tb.tb_next or tb) except: if exe_exception is False: raise rv = 1 e, v, tb = sys.exc_info() print_exception(e, v, tb.tb_next or tb) sys.last_type = e sys.last_value = v sys.last_traceback = (tb.tb_next or tb) return rv def __call__(self, *args, **kw): storage = ( sys.modules.get('__context__'), sys.modules.get('__main__'), sys.argv, os.environ.copy(), ) try: return self._call(*args, **kw) finally: sys.modules['__context__'], \ sys.modules['__main__'], \ sys.argv, os.environ = storage def get_main_source(self): """ Get the execution's "__main__" source. Useful for configuring environmental options derived from "magic" lines. """ path, loader = self.main if path is not None: return loader.get_source(path) def command_execution(argv = sys.argv): 'create an execution using the given argv' # The pwd should be in the path for python commands. # setuptools' console_scripts appear to strip this out. if '' not in sys.path: sys.path.insert(0, '') op = optparse.OptionParser( "%prog [options] [script] [script arguments]", version = '1.0', ) op.disable_interspersed_args() op.add_options(default_optparse_options) co, ca = op.parse_args(argv[1:]) return Execution(ca, context = getattr(co, 'python_context', ()), loader = getattr(co, 'python_main', None), ) def command(argv = sys.argv): return command_execution(argv = argv)( context = postmortem(os.environ.get('PYTHON_POSTMORTEM')) ) if __name__ == '__main__': sys.exit(command()) ## # vim: ts=3:sw=3:noet: fe-1.1.0/postgresql/python/datetime.py000066400000000000000000000023661203372773200200120ustar00rootroot00000000000000## # python.datetime - parts needed to use stdlib.datetime ## import datetime ## # stdlib.datetime representation of PostgreSQL 'infinity' and '-infinity'. infinity_datetime = datetime.datetime(datetime.MAXYEAR, 12, 31, 23, 59, 59, 999999) negative_infinity_datetime = datetime.datetime(datetime.MINYEAR, 1, 1, 0, 0, 0, 0) infinity_date = datetime.date(datetime.MAXYEAR, 12, 31) negative_infinity_date = datetime.date(datetime.MINYEAR, 1, 1) class FixedOffset(datetime.tzinfo): def __init__(self, offset_in_seconds, tzname = None): self._tzname = tzname self._offset = offset_in_seconds self._offset_in_mins = offset_in_seconds // 60 self._td_offset = datetime.timedelta(0, self._offset_in_mins * 60) self._dst = datetime.timedelta(0) def utcoffset(self, offset_from): return self._td_offset def tzname(self, dt): return self._tzname def dst(self, arg): return self._dst def __repr__(self): return "{path}.{name}({off}{tzname})".format( path = type(self).__module__, name = type(self).__name__, off = repr(self._td_offset.days * 24 * 60 * 60 + self._td_offset.seconds), tzname = ( ", tzname = {tzname!r}".format(tzname = self._tzname) \ if self._tzname is not None else "" ) ) UTC = FixedOffset(0, tzname = 'UTC') fe-1.1.0/postgresql/python/decorlib.py000066400000000000000000000022551203372773200177760ustar00rootroot00000000000000## # .python.decorlib ## """ common decorators """ import os import types def propertydoc(ap): """ Helper function for extracting an `abstractproperty`'s real documentation. 
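
	A sketch of the intended use (the names are illustrative, and the
	abc imports are assumed)::

		class Connector(metaclass = ABCMeta):
			@propertydoc
			@abstractproperty
			def locator(self) -> str:
				"The resource this connector points to."

	The resulting ``__doc__`` reads "Abstract Property -> <class 'str'>"
	followed by a "GET::" section built from the getter's docstring;
	setter and deleter docstrings, when present, contribute "SET::" and
	"DELETE::" sections.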
""" doc = "" rstr = "" if ap.fget: ret = ap.fget.__annotations__.get('return') if ret is not None: rstr = " -> " + repr(ret) if ap.fget.__doc__: doc += os.linesep*2 + "GET::" + (os.linesep + ' '*4) + (os.linesep + ' '*4).join( [x.strip() for x in ap.fget.__doc__.strip().split(os.linesep)] ) if ap.fset and ap.fset.__doc__: doc += os.linesep*2 + "SET::" + (os.linesep + ' '*4) + (os.linesep + ' '*4).join( [x.strip() for x in ap.fset.__doc__.strip().split(os.linesep)] ) if ap.fdel and ap.fdel.__doc__: doc += os.linesep*2 + "DELETE::" + (os.linesep + ' '*4) + (os.linesep + ' '*4).join( [x.strip() for x in ap.fdel.__doc__.strip().split(os.linesep)] ) ap.__doc__ = "" if not doc else ( "Abstract Property" + rstr + doc ) return ap class method(object): __slots__ = ('callable',) def __init__(self, callable): self.callable = callable def __get__(self, val, typ): if val is None: return self.callable return types.MethodType(self.callable, val) fe-1.1.0/postgresql/python/doc.py000066400000000000000000000005711203372773200167570ustar00rootroot00000000000000## # .python.doc ## """ Documentation Tools. """ from operator import attrgetter class Doc(object): """ Simple object that sets the __doc__ attribute to the first parameter and initializes __annotations__ using keyword arguments. """ def __init__(self, doc, **annotations): self.__doc__ = str(doc) self.__annotations__ = annotations __str__ = attrgetter('__doc__') fe-1.1.0/postgresql/python/element.py000066400000000000000000000121541203372773200176430ustar00rootroot00000000000000## # .python.element ## import os from abc import ABCMeta, abstractproperty, abstractmethod from .string import indent from .decorlib import propertydoc class RecursiveFactor(Exception): 'Raised when a factor is ultimately composed of itself' pass class Element(object, metaclass = ABCMeta): """ The purpose of an element is to provide a general mechanism for specifying the factors that composed an object. Factors are designated using an ordered set of strings referencing those significant attributes on the object. Factors are important for PG-API as it provides the foundation for collecting the information about the state of the interface that ultimately led up to an error. Traceback: ... postgresql.exceptions.*: CODE: XX000 CURSOR: parameters: (p1, p2, ...) STATEMENT: ... string: SYMBOL: get_types LIBRARY: catalog ... CONNECTION: CONNECTOR: [Host] IRI: pq://user@localhost:5432/database DRIVER: postgresql.driver.pq3 """ @propertydoc @abstractproperty def _e_label(self) -> str: """ Single-word string describing the kind of element. For instance, `postgresql.api.Statement`'s _e_label is 'STATEMENT'. Usually, this is set directly on the class itself, and is a shorter version of the class's name. """ @propertydoc @abstractproperty def _e_factors(self) -> (): """ The attribute names of the objects that contributed to the creation of this object. The ordering is significant. The first factor is the prime factor. """ @abstractmethod def _e_metas(self) -> [(str, object)]: """ Return an iterable to key-value pairs that provide useful descriptive information about an attribute. Factors on metas are not checked. They are expected to be primitives. If there are no metas, the str() of the object will be used to represent it. """ class ElementSet(Element, set): """ An ElementSet is a set of Elements that can be used as an individual factor. 
In situations where a single factor is composed of multiple elements where each has no significance over the other, this Element can be used represent that fact. Importantly, it provides the set metadata so that the appropriate information will be produced in element tracebacks. """ _e_label = 'SET' _e_factors = () __slots__ = () def _e_metas(self): yield (None, len(self)) for x in self: yield (None, '--') yield (None, format_element(x)) def prime_factor(obj): 'get the primary factor on the `obj`, returns None if none.' f = getattr(obj, '_e_factors', None) if f: return f[0], getattr(obj, f[0], None) def prime_factors(obj): """ Yield out the sequence of primary factors of the given object. """ visited = set((obj,)) ef = getattr(obj, '_e_factors', None) if not ef: return fn = ef[0] e = getattr(obj, fn, None) if e in visited: raise RecursiveFactor(obj, e) visited.add(e) yield fn, e while e is not None: ef = getattr(obj, '_e_factors', None) fn = ef[0] e = getattr(e, fn, None) if e in visited: raise RecursiveFactor(obj, e) visited.add(e) yield fn, e def format_element(obj, coverage = ()): 'format the given element with its factors and metadata into a readable string' # if it's not an Element, all there is to return is str(obj) if obj in coverage: raise RecursiveFactor(coverage) coverage = coverage + (obj,) if not isinstance(obj, Element): if obj is None: return 'None' return str(obj) # The description of `obj` is built first. # formal element, get metas first. nolead = False metas = [] for key, val in obj._e_metas(): m = "" if val is None: sval = 'None' else: sval = str(val) pre = ' ' if key is not None: m += key + ':' if (len(sval) > 70 or os.linesep in sval): pre = os.linesep sval = indent(sval) else: # if the key is None, it is intended to be inlined. nolead = True pre = '' m += pre + sval.rstrip() metas.append(m) factors = [] for att in obj._e_factors[1:]: m = "" f = getattr(obj, att) # if the object has a label, use the label m += att + ':' sval = format_element(f, coverage = coverage) if len(sval) > 70 or os.linesep in sval: m += os.linesep + indent(sval) else: m += ' ' + sval factors.append(m) mtxt = os.linesep.join(metas) ftxt = os.linesep.join(factors) if mtxt: mtxt = indent(mtxt) if ftxt: ftxt = indent(ftxt) s = mtxt + ftxt if nolead is True: # metas started with a `None` key. s = ' ' + s.lstrip() else: s = os.linesep + s s = obj._e_label + ':' + s.rstrip() # and resolve the next prime pf = prime_factor(obj) if pf is not None: factor_name, prime = pf factor = format_element(prime, coverage = coverage) if getattr(prime, '_e_label', None) is not None: # if the factor has a label, then it will be # included in the format_element output, and # thus factor_name is not needed. factor_name = '' else: factor_name += ':' if len(factor) > 70 or os.linesep in factor: factor = os.linesep + indent(factor) else: factor_name += ' ' s += os.linesep + factor_name + factor return s fe-1.1.0/postgresql/python/functools.py000066400000000000000000000030231203372773200202210ustar00rootroot00000000000000## # python.functools ## import sys from .decorlib import method def rsetattr(attr, val, ob): """ setattr() and return `ob`. Different order used to allow easier partial usage. 
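
	For instance, with ``functools.partial`` (illustrative)::

		from functools import partial
		mark_closed = partial(rsetattr, 'closed', True)
		ob = mark_closed(ob)  # sets ob.closed = True and returns ob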
""" setattr(ob, attr, val) return ob try: from ..port.optimized import rsetattr except ImportError: pass class Composition(tuple): def __call__(self, r): for x in self: r = x(r) return r try: from ..port.optimized import compose __call__ = method(compose) del compose except ImportError: pass try: # C implementation of the tuple processors. from ..port.optimized import process_tuple, process_chunk except ImportError: def process_tuple(procs, tup, exception_handler, len = len, tuple = tuple, cause = None): """ Call each item in `procs` with the corresponding item in `tup` returning the result as `type`. If an item in `tup` is `None`, don't process it. If a give transformation failes, call the given exception_handler. """ i = len(procs) if len(tup) != i: raise TypeError( "inconsistent items, %d processors and %d items in row" %( i, len(tup) ) ) r = [None] * i try: for i in range(i): ob = tup[i] if ob is None: continue r[i] = procs[i](ob) except Exception: cause = sys.exc_info()[1] if cause is not None: exception_handler(cause, procs, tup, i) raise RuntimeError("process_tuple exception handler failed to raise") return tuple(r) def process_chunk(procs, tupc, fail, process_tuple = process_tuple): return [process_tuple(procs, x, fail) for x in tupc] fe-1.1.0/postgresql/python/itertools.py000066400000000000000000000016341203372773200202370ustar00rootroot00000000000000## # .python.itertools ## """ itertools extensions """ import collections from itertools import cycle, islice def interlace(*iters, next = next) -> collections.Iterable: """ interlace(i1, i2, ..., in) -> ( i1-0, i2-0, ..., in-0, i1-1, i2-1, ..., in-1, . . . i1-n, i2-n, ..., in-n, ) """ return map(next, cycle([iter(x) for x in iters])) def chunk(iterable, chunksize = 256): """ Given an iterable, return an iterable producing chunks of the objects produced by the given iterable. chunks([o1,o2,o3,o4], chunksize = 2) -> [ [o1,o2], [o3,o4], ] """ iterable = iter(iterable) last = () lastsize = chunksize while lastsize == chunksize: last = list(islice(iterable, chunksize)) lastsize = len(last) yield last def find(iterable, selector): """ Return the first item in the `iterable` that causes the `selector` to return `True`. """ for x in iterable: if selector(x): return x fe-1.1.0/postgresql/python/msw.py000066400000000000000000000004071203372773200170160ustar00rootroot00000000000000## # .python.msw ## """ Additional Microsoft Windows tools. """ # for Popen(), not supported on windows close_fds = False def platform_exe(name): """ Append '.exe' if it's not already there. """ if name.endswith('.exe'): return name return name + '.exe' fe-1.1.0/postgresql/python/os.py000066400000000000000000000014051203372773200166300ustar00rootroot00000000000000## # .python.os ## """ General OS abstractions and information. """ import sys import os #: By default, close the FDs on subprocess.Popen(). close_fds = True #: By default, there is no modification for executable references. platform_exe = str def find_file(basename, paths, join = os.path.join, exists = os.path.exists, ): """ Find the file in the given paths. Return the first path that exists. 
""" for x in paths: path = join(x, basename) if exists(path): return path if sys.platform in ('win32','win64'): # replace variants for windows from .msw import close_fds, platform_exe def find_executable(basename, pathsep = os.pathsep, platexe = platform_exe): paths = os.environ.get('PATH', '').split(pathsep) return find_file(platexe(basename), paths) fe-1.1.0/postgresql/python/socket.py000066400000000000000000000057371203372773200175130ustar00rootroot00000000000000## # .python.socket - additional tools for working with sockets ## import sys import os import random import socket import math import errno import ssl __all__ = ['find_available_port', 'SocketFactory'] class SocketFactory(object): """ Object used to create a socket and connect it. This is, more or less, a specialized partial() for socket creation. Additionally, it provides methods and attributes for abstracting exception management on socket operation. """ timeout_exception = socket.timeout fatal_exception = socket.error try_again_exception = socket.error def timed_out(self, err) -> bool: return err.__class__ is self.timeout_exception @staticmethod def try_again(err, codes = (errno.EAGAIN, errno.EINTR, errno.EWOULDBLOCK, errno.ETIMEDOUT)) -> bool: """ Does the error indicate that the operation should be tried again? More importantly, the connection is *not* dead. """ errno = getattr(err, 'errno', None) if errno is None: return False return errno in codes @classmethod def fatal_exception_message(typ, err) -> (str, None): """ If the exception was fatal to the connection, what message should be given to the user? """ if typ.try_again(err): return None return getattr(err, 'strerror', '') def secure(self, socket : socket.socket) -> ssl.SSLSocket: "secure a socket with SSL" if self.socket_secure is not None: return ssl.wrap_socket(socket, **self.socket_secure) else: return ssl.wrap_socket(socket) def __call__(self, timeout = None): s = socket.socket(*self.socket_create) try: s.settimeout(float(timeout) if timeout is not None else None) s.connect(self.socket_connect) s.settimeout(None) except Exception: s.close() raise return s def __init__(self, socket_create : "positional parameters given to socket.socket()", socket_connect : "parameter given to socket.connect()", socket_secure : "keywords given to ssl.wrap_socket" = None, ): self.socket_create = socket_create self.socket_connect = socket_connect self.socket_secure = socket_secure def __str__(self): return 'socket' + repr(self.socket_connect) def find_available_port( interface : "attempt to bind to interface" = 'localhost', address_family : "address family to use (default: AF_INET)" = socket.AF_INET, limit : "Number tries to make before giving up" = 1024, port_range = (6600, 56600) ) -> (int, None): """ Find an available port on the given interface for the given address family. Returns a port number that was successfully bound to or `None` if the attempt limit was reached. 
""" i = 0 while i < limit: i += 1 port = ( math.floor( random.random() * (port_range[1] - port_range[0]) ) + port_range[0] ) s = socket.socket(address_family, socket.SOCK_STREAM,) try: s.bind(('localhost', port)) s.close() except socket.error as e: s.close() if e.errno in (errno.EACCES, errno.EADDRINUSE, errno.EINTR): # try again continue break else: port = None return port fe-1.1.0/postgresql/python/string.py000066400000000000000000000002611203372773200175140ustar00rootroot00000000000000## # .python.string ## import os def indent(s, level = 2, char = ' '): ind = char * level r = "" for x in s.splitlines(): r += ((ind + x).rstrip() + os.linesep) return r fe-1.1.0/postgresql/python/structlib.py000066400000000000000000000055771203372773200202400ustar00rootroot00000000000000## # .python.structlib - module for extracting serialized data ## import struct from .functools import Composition as compose null_sequence = b'\xff\xff\xff\xff' # Always to and from network order. # Create a pair, (pack, unpack) for the given `struct` format.' def mk_pack(x): s = struct.Struct('!' + x) if len(x) > 1: def pack(y, p = s.pack): return p(*y) return (pack, s.unpack_from) else: def unpack(y, p = s.unpack_from): return p(y)[0] return (s.pack, unpack) byte_pack, byte_unpack = lambda x: bytes((x,)), lambda x: x[0] double_pack, double_unpack = mk_pack("d") float_pack, float_unpack = mk_pack("f") dd_pack, dd_unpack = mk_pack("dd") ddd_pack, ddd_unpack = mk_pack("ddd") dddd_pack, dddd_unpack = mk_pack("dddd") LH_pack, LH_unpack = mk_pack("LH") lH_pack, lH_unpack = mk_pack("lH") llL_pack, llL_unpack = mk_pack("llL") qll_pack, qll_unpack = mk_pack("qll") dll_pack, dll_unpack = mk_pack("dll") dl_pack, dl_unpack = mk_pack("dl") ql_pack, ql_unpack = mk_pack("ql") hhhh_pack, hhhh_unpack = mk_pack("hhhh") longlong_pack, longlong_unpack = mk_pack("q") ulonglong_pack, ulonglong_unpack = mk_pack("Q") # Optimizations for int2, int4, and int8. try: from ..port import optimized as opt from sys import byteorder as bo if bo == 'little': short_unpack = opt.swap_int2_unpack short_pack = opt.swap_int2_pack ushort_unpack = opt.swap_uint2_unpack ushort_pack = opt.swap_uint2_pack long_unpack = opt.swap_int4_unpack long_pack = opt.swap_int4_pack ulong_unpack = opt.swap_uint4_unpack ulong_pack = opt.swap_uint4_pack if hasattr(opt, 'uint8_pack'): longlong_unpack = opt.swap_int8_unpack longlong_pack = opt.swap_int8_pack ulonglong_unpack = opt.swap_uint8_unpack ulonglong_pack = opt.swap_uint8_pack elif bo == 'big': short_unpack = opt.int2_unpack short_pack = opt.int2_pack ushort_unpack = opt.uint2_unpack ushort_pack = opt.uint2_pack long_unpack = opt.int4_unpack long_pack = opt.int4_pack ulong_unpack = opt.uint4_unpack ulong_pack = opt.uint4_pack if hasattr(opt, 'uint8_pack'): longlong_unpack = opt.int8_unpack longlong_pack = opt.int8_pack ulonglong_unpack = opt.uint8_unpack ulonglong_pack = opt.uint8_pack del bo, opt except ImportError: short_pack, short_unpack = mk_pack("h") ushort_pack, ushort_unpack = mk_pack("H") long_pack, long_unpack = mk_pack("l") ulong_pack, ulong_unpack = mk_pack("L") def split_sized_data( data, ulong_unpack = ulong_unpack, null_field = 0xFFFFFFFF, len = len, errmsg = "insufficient data in field {0}, required {1} bytes, {2} remaining".format ): """ Given serialized record data, return a tuple of tuples of type Oids and attributes. 
""" v = memoryview(data) f = 1 while v: l = ulong_unpack(v) if l == null_field: v = v[4:] yield None continue l += 4 d = v[4:l].tobytes() if len(d) < l-4: raise ValueError(errmsg(f, l - 4, len(d))) v = v[l:] f += 1 yield d fe-1.1.0/postgresql/release/000077500000000000000000000000001203372773200157345ustar00rootroot00000000000000fe-1.1.0/postgresql/release/__init__.py000066400000000000000000000001101203372773200200350ustar00rootroot00000000000000## # .release ## """ Release management code and project meta-data. """ fe-1.1.0/postgresql/release/distutils.py000066400000000000000000000125141203372773200203350ustar00rootroot00000000000000## # .release.distutils - distutils data ## """ Python distutils data provisions module. For sub-packagers, the `prefixed_packages` and `prefixed_extensions` functions should be of particular interest. If the distribution including ``py-postgresql`` uses the standard layout, chances are that `prefixed_extensions` and `prefixed_packages` will supply the appropriate information by default as they use `default_prefix` which is derived from the module's `__package__`. """ import sys import os from ..project import version, name, identity as url from distutils.core import Extension, Command LONG_DESCRIPTION = """ py-postgresql is a set of Python modules providing interfaces to various parts of PostgreSQL. Notably, it provides a pure-Python driver + C optimizations for querying a PostgreSQL database. http://python.projects.postgresql.org Features: * Prepared Statement driven interfaces. * Cluster tools for creating and controlling a cluster. * Support for most PostgreSQL types: composites, arrays, numeric, lots more. * COPY support. Sample PG-API Code:: >>> import postgresql >>> db = postgresql.open('pq://user:password@host:port/database') >>> db.execute("CREATE TABLE emp (emp_first_name text, emp_last_name text, emp_salary numeric)") >>> make_emp = db.prepare("INSERT INTO emp VALUES ($1, $2, $3)") >>> make_emp("John", "Doe", "75,322") >>> with db.xact(): ... make_emp("Jane", "Doe", "75,322") ... make_emp("Edward", "Johnson", "82,744") ... There is a DB-API 2.0 module as well:: postgresql.driver.dbapi20 However, PG-API is recommended as it provides greater utility. Once installed, try out the ``pg_python`` console script:: $ python3 -m postgresql.bin.pg_python -h localhost -p port -U theuser -d database_name If a successful connection is made to the remote host, it will provide a Python console with the database connection bound to the `db` name. History ------- py-postgresql is not yet another PostgreSQL driver, it's been in development for years. py-postgresql is the Python 3 port of the ``pg_proboscis`` driver and integration of the other ``pg/python`` projects. """ CLASSIFIERS = [ 'Development Status :: 5 - Production/Stable', 'Intended Audience :: Developers', 'License :: OSI Approved :: BSD License', 'License :: OSI Approved :: MIT License', 'License :: OSI Approved :: Attribution Assurance License', 'License :: OSI Approved :: Python Software Foundation License', 'Natural Language :: English', 'Operating System :: OS Independent', 'Programming Language :: Python', 'Programming Language :: Python :: 3', 'Topic :: Database', ] subpackages = [ 'bin', 'encodings', 'lib', 'protocol', 'driver', 'test', 'documentation', 'python', 'port', 'release', # Modules imported from other packages. 
'resolved', 'types', 'types.io', ] extensions_data = { 'port.optimized' : { 'sources' : [os.path.join('port', '_optimized', 'module.c')], }, } subpackage_data = { 'lib' : ['*.sql'], 'documentation' : ['*.txt'] } try: # :) if __package__ is not None: default_prefix = __package__.split('.')[:-1] else: default_prefix = __name__.split('.')[:-2] except NameError: default_prefix = ['postgresql'] def prefixed_extensions( prefix : "prefix to prepend to paths" = default_prefix, extensions_data : "`extensions_data`" = extensions_data, ) -> [Extension]: """ Generator producing the `distutils` `Extension` objects. """ pkg_prefix = '.'.join(prefix) + '.' path_prefix = os.path.sep.join(prefix) for mod, data in extensions_data.items(): yield Extension( pkg_prefix + mod, [os.path.join(path_prefix, src) for src in data['sources']], libraries = data.get('libraries', ()), optional = True, ) def prefixed_packages( prefix : "prefix to prepend to source paths" = default_prefix, packages = subpackages, ): """ Generator producing the standard `package` list prefixed with `prefix`. """ prefix = '.'.join(prefix) yield prefix prefix = prefix + '.' for pkg in packages: yield prefix + pkg def prefixed_package_data( prefix : "prefix to prepend to dictionary keys paths" = default_prefix, package_data = subpackage_data, ): """ Generator producing the standard `package` list prefixed with `prefix`. """ prefix = '.'.join(prefix) prefix = prefix + '.' for pkg, data in package_data.items(): yield prefix + pkg, data def standard_setup_keywords(build_extensions = True, prefix = default_prefix): """ Used by the py-postgresql distribution. """ d = { 'name' : name, 'version' : version, 'description' : 'PostgreSQL driver and tools library.', 'long_description' : LONG_DESCRIPTION, 'author' : 'James William Pye', 'author_email' : 'x@jwp.name', 'maintainer' : 'James William Pye', 'maintainer_email' : 'python-general@pgfoundry.org', 'url' : url, 'classifiers' : CLASSIFIERS, 'packages' : list(prefixed_packages(prefix = prefix)), 'package_data' : dict(prefixed_package_data(prefix = prefix)), 'cmdclass': dict(test=TestCommand), } if build_extensions: d['ext_modules'] = list(prefixed_extensions(prefix = prefix)) return d class TestCommand(Command): description = "run tests" # List of option tuples: long name, short name (None if no short # name), and help string. user_options = [] def initialize_options(self): pass def finalize_options(self): pass def run(self): import unittest unittest.main(module='postgresql.test.testall', argv=('setup.py',)) fe-1.1.0/postgresql/resolved/000077500000000000000000000000001203372773200161375ustar00rootroot00000000000000fe-1.1.0/postgresql/resolved/__init__.py000066400000000000000000000001131203372773200202430ustar00rootroot00000000000000""" Modules and packages resolved to avoid user dependency resolution. """ fe-1.1.0/postgresql/resolved/crypt.py000066400000000000000000000640741203372773200176650ustar00rootroot00000000000000# fcrypt.py """Unix crypt(3) password hash algorithm. This is a port to Python of the standard Unix password crypt function. It's a single self-contained source file that works with any version of Python from version 1.5 or higher. The code is based on Eric Young's optimised crypt in C. Python fcrypt is intended for users whose Python installation has not had the crypt module enabled, or whose C library doesn't include the crypt function. 
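Within py-postgresql the function backs the crypt authentication
request; a sketch of the call (the values are illustrative)::

    hashed = crypt(password, salt)
    # salt: the two-character salt supplied by the server;
    # hashed: a thirteen-character hash beginning with the salt.
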
See the documentation for the Python crypt module for more information: http://www.python.org/doc/current/lib/module-crypt.html An alternative Python crypt module that uses the MD5 algorithm and is more secure than fcrypt is available from michal j wallace at: http://www.sabren.net/code/python/crypt/index.php3 The crypt() function is a one-way hash function, intended to hide a password such that the only way to find out the original password is to guess values until you get a match. If you need to encrypt and decrypt data, this is not the module for you. There are at least two packages providing Python cryptography support: M2Crypto at , and amkCrypto at . Functions: crypt() -- return hashed password """ __author__ = 'Carey Evans ' __version__ = '1.3.1' __date__ = '21 February 2004' __credits__ = '''michal j wallace for inspiring me to write this. Eric Young for the C code this module was copied from.''' __all__ = ['crypt'] # Copyright (C) 2000, 2001, 2004 Carey Evans # # Permission to use, copy, modify, and distribute this software and # its documentation for any purpose and without fee is hereby granted, # provided that the above copyright notice appear in all copies and # that both that copyright notice and this permission notice appear in # supporting documentation. # # CAREY EVANS DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, # INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN NO # EVENT SHALL CAREY EVANS BE LIABLE FOR ANY SPECIAL, INDIRECT OR # CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF # USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR # OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR # PERFORMANCE OF THIS SOFTWARE. # Based on C code by Eric Young (eay@mincom.oz.au), which has the # following copyright. Especially note condition 3, which imposes # extra restrictions on top of the standard Python license used above. # # The fcrypt.c source is available from: # ftp://ftp.psy.uq.oz.au/pub/Crypto/DES/ # ----- BEGIN fcrypt.c LICENSE ----- # # This library is free for commercial and non-commercial use as long as # the following conditions are aheared to. The following conditions # apply to all code found in this distribution, be it the RC4, RSA, # lhash, DES, etc., code; not just the SSL code. The SSL documentation # included with this distribution is covered by the same copyright terms # except that the holder is Tim Hudson (tjh@mincom.oz.au). # # Copyright remains Eric Young's, and as such any Copyright notices in # the code are not to be removed. # If this package is used in a product, Eric Young should be given attribution # as the author of the parts of the library used. # This can be in the form of a textual message at program startup or # in documentation (online or textual) provided with the package. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions # are met: # 1. Redistributions of source code must retain the copyright # notice, this list of conditions and the following disclaimer. # 2. Redistributions in binary form must reproduce the above copyright # notice, this list of conditions and the following disclaimer in the # documentation and/or other materials provided with the distribution. # 3. 
All advertising materials mentioning features or use of this software # must display the following acknowledgement: # "This product includes cryptographic software written by # Eric Young (eay@mincom.oz.au)" # The word 'cryptographic' can be left out if the rouines from the library # being used are not cryptographic related :-). # 4. If you include any Windows specific code (or a derivative thereof) from # the apps directory (application code) you must include an acknowledgement: # "This product includes software written by Tim Hudson (tjh@mincom.oz.au)" # # THIS SOFTWARE IS PROVIDED BY ERIC YOUNG ``AS IS'' AND # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE # ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS # OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) # HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT # LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY # OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF # SUCH DAMAGE. # # The licence and distribution terms for any publically available version or # derivative of this code cannot be changed. i.e. this code cannot simply be # copied and put under another distribution licence # [including the GNU Public Licence.] # # ----- END fcrypt.c LICENSE ----- import string, struct _ITERATIONS = 16 _SPtrans = ( # nibble 0 ( 0x00820200, 0x00020000, 0x80800000, 0x80820200, 0x00800000, 0x80020200, 0x80020000, 0x80800000, 0x80020200, 0x00820200, 0x00820000, 0x80000200, 0x80800200, 0x00800000, 0x00000000, 0x80020000, 0x00020000, 0x80000000, 0x00800200, 0x00020200, 0x80820200, 0x00820000, 0x80000200, 0x00800200, 0x80000000, 0x00000200, 0x00020200, 0x80820000, 0x00000200, 0x80800200, 0x80820000, 0x00000000, 0x00000000, 0x80820200, 0x00800200, 0x80020000, 0x00820200, 0x00020000, 0x80000200, 0x00800200, 0x80820000, 0x00000200, 0x00020200, 0x80800000, 0x80020200, 0x80000000, 0x80800000, 0x00820000, 0x80820200, 0x00020200, 0x00820000, 0x80800200, 0x00800000, 0x80000200, 0x80020000, 0x00000000, 0x00020000, 0x00800000, 0x80800200, 0x00820200, 0x80000000, 0x80820000, 0x00000200, 0x80020200 ), # nibble 1 ( 0x10042004, 0x00000000, 0x00042000, 0x10040000, 0x10000004, 0x00002004, 0x10002000, 0x00042000, 0x00002000, 0x10040004, 0x00000004, 0x10002000, 0x00040004, 0x10042000, 0x10040000, 0x00000004, 0x00040000, 0x10002004, 0x10040004, 0x00002000, 0x00042004, 0x10000000, 0x00000000, 0x00040004, 0x10002004, 0x00042004, 0x10042000, 0x10000004, 0x10000000, 0x00040000, 0x00002004, 0x10042004, 0x00040004, 0x10042000, 0x10002000, 0x00042004, 0x10042004, 0x00040004, 0x10000004, 0x00000000, 0x10000000, 0x00002004, 0x00040000, 0x10040004, 0x00002000, 0x10000000, 0x00042004, 0x10002004, 0x10042000, 0x00002000, 0x00000000, 0x10000004, 0x00000004, 0x10042004, 0x00042000, 0x10040000, 0x10040004, 0x00040000, 0x00002004, 0x10002000, 0x10002004, 0x00000004, 0x10040000, 0x00042000 ), # nibble 2 ( 0x41000000, 0x01010040, 0x00000040, 0x41000040, 0x40010000, 0x01000000, 0x41000040, 0x00010040, 0x01000040, 0x00010000, 0x01010000, 0x40000000, 0x41010040, 0x40000040, 0x40000000, 0x41010000, 0x00000000, 0x40010000, 0x01010040, 0x00000040, 0x40000040, 0x41010040, 0x00010000, 0x41000000, 0x41010000, 0x01000040, 0x40010040, 0x01010000, 
0x00010040, 0x00000000, 0x01000000, 0x40010040, 0x01010040, 0x00000040, 0x40000000, 0x00010000, 0x40000040, 0x40010000, 0x01010000, 0x41000040, 0x00000000, 0x01010040, 0x00010040, 0x41010000, 0x40010000, 0x01000000, 0x41010040, 0x40000000, 0x40010040, 0x41000000, 0x01000000, 0x41010040, 0x00010000, 0x01000040, 0x41000040, 0x00010040, 0x01000040, 0x00000000, 0x41010000, 0x40000040, 0x41000000, 0x40010040, 0x00000040, 0x01010000 ), # nibble 3 ( 0x00100402, 0x04000400, 0x00000002, 0x04100402, 0x00000000, 0x04100000, 0x04000402, 0x00100002, 0x04100400, 0x04000002, 0x04000000, 0x00000402, 0x04000002, 0x00100402, 0x00100000, 0x04000000, 0x04100002, 0x00100400, 0x00000400, 0x00000002, 0x00100400, 0x04000402, 0x04100000, 0x00000400, 0x00000402, 0x00000000, 0x00100002, 0x04100400, 0x04000400, 0x04100002, 0x04100402, 0x00100000, 0x04100002, 0x00000402, 0x00100000, 0x04000002, 0x00100400, 0x04000400, 0x00000002, 0x04100000, 0x04000402, 0x00000000, 0x00000400, 0x00100002, 0x00000000, 0x04100002, 0x04100400, 0x00000400, 0x04000000, 0x04100402, 0x00100402, 0x00100000, 0x04100402, 0x00000002, 0x04000400, 0x00100402, 0x00100002, 0x00100400, 0x04100000, 0x04000402, 0x00000402, 0x04000000, 0x04000002, 0x04100400 ), # nibble 4 ( 0x02000000, 0x00004000, 0x00000100, 0x02004108, 0x02004008, 0x02000100, 0x00004108, 0x02004000, 0x00004000, 0x00000008, 0x02000008, 0x00004100, 0x02000108, 0x02004008, 0x02004100, 0x00000000, 0x00004100, 0x02000000, 0x00004008, 0x00000108, 0x02000100, 0x00004108, 0x00000000, 0x02000008, 0x00000008, 0x02000108, 0x02004108, 0x00004008, 0x02004000, 0x00000100, 0x00000108, 0x02004100, 0x02004100, 0x02000108, 0x00004008, 0x02004000, 0x00004000, 0x00000008, 0x02000008, 0x02000100, 0x02000000, 0x00004100, 0x02004108, 0x00000000, 0x00004108, 0x02000000, 0x00000100, 0x00004008, 0x02000108, 0x00000100, 0x00000000, 0x02004108, 0x02004008, 0x02004100, 0x00000108, 0x00004000, 0x00004100, 0x02004008, 0x02000100, 0x00000108, 0x00000008, 0x00004108, 0x02004000, 0x02000008 ), # nibble 5 ( 0x20000010, 0x00080010, 0x00000000, 0x20080800, 0x00080010, 0x00000800, 0x20000810, 0x00080000, 0x00000810, 0x20080810, 0x00080800, 0x20000000, 0x20000800, 0x20000010, 0x20080000, 0x00080810, 0x00080000, 0x20000810, 0x20080010, 0x00000000, 0x00000800, 0x00000010, 0x20080800, 0x20080010, 0x20080810, 0x20080000, 0x20000000, 0x00000810, 0x00000010, 0x00080800, 0x00080810, 0x20000800, 0x00000810, 0x20000000, 0x20000800, 0x00080810, 0x20080800, 0x00080010, 0x00000000, 0x20000800, 0x20000000, 0x00000800, 0x20080010, 0x00080000, 0x00080010, 0x20080810, 0x00080800, 0x00000010, 0x20080810, 0x00080800, 0x00080000, 0x20000810, 0x20000010, 0x20080000, 0x00080810, 0x00000000, 0x00000800, 0x20000010, 0x20000810, 0x20080800, 0x20080000, 0x00000810, 0x00000010, 0x20080010 ), # nibble 6 ( 0x00001000, 0x00000080, 0x00400080, 0x00400001, 0x00401081, 0x00001001, 0x00001080, 0x00000000, 0x00400000, 0x00400081, 0x00000081, 0x00401000, 0x00000001, 0x00401080, 0x00401000, 0x00000081, 0x00400081, 0x00001000, 0x00001001, 0x00401081, 0x00000000, 0x00400080, 0x00400001, 0x00001080, 0x00401001, 0x00001081, 0x00401080, 0x00000001, 0x00001081, 0x00401001, 0x00000080, 0x00400000, 0x00001081, 0x00401000, 0x00401001, 0x00000081, 0x00001000, 0x00000080, 0x00400000, 0x00401001, 0x00400081, 0x00001081, 0x00001080, 0x00000000, 0x00000080, 0x00400001, 0x00000001, 0x00400080, 0x00000000, 0x00400081, 0x00400080, 0x00001080, 0x00000081, 0x00001000, 0x00401081, 0x00400000, 0x00401080, 0x00000001, 0x00001001, 0x00401081, 0x00400001, 0x00401080, 0x00401000, 
0x00001001 ), # nibble 7 ( 0x08200020, 0x08208000, 0x00008020, 0x00000000, 0x08008000, 0x00200020, 0x08200000, 0x08208020, 0x00000020, 0x08000000, 0x00208000, 0x00008020, 0x00208020, 0x08008020, 0x08000020, 0x08200000, 0x00008000, 0x00208020, 0x00200020, 0x08008000, 0x08208020, 0x08000020, 0x00000000, 0x00208000, 0x08000000, 0x00200000, 0x08008020, 0x08200020, 0x00200000, 0x00008000, 0x08208000, 0x00000020, 0x00200000, 0x00008000, 0x08000020, 0x08208020, 0x00008020, 0x08000000, 0x00000000, 0x00208000, 0x08200020, 0x08008020, 0x08008000, 0x00200020, 0x08208000, 0x00000020, 0x00200020, 0x08008000, 0x08208020, 0x00200000, 0x08200000, 0x08000020, 0x00208000, 0x00008020, 0x08008020, 0x08200000, 0x00000020, 0x08208000, 0x00208020, 0x00000000, 0x08000000, 0x08200020, 0x00008000, 0x00208020 ), ) _skb = ( # for C bits (numbered as per FIPS 46) 1 2 3 4 5 6 ( 0x00000000, 0x00000010, 0x20000000, 0x20000010, 0x00010000, 0x00010010, 0x20010000, 0x20010010, 0x00000800, 0x00000810, 0x20000800, 0x20000810, 0x00010800, 0x00010810, 0x20010800, 0x20010810, 0x00000020, 0x00000030, 0x20000020, 0x20000030, 0x00010020, 0x00010030, 0x20010020, 0x20010030, 0x00000820, 0x00000830, 0x20000820, 0x20000830, 0x00010820, 0x00010830, 0x20010820, 0x20010830, 0x00080000, 0x00080010, 0x20080000, 0x20080010, 0x00090000, 0x00090010, 0x20090000, 0x20090010, 0x00080800, 0x00080810, 0x20080800, 0x20080810, 0x00090800, 0x00090810, 0x20090800, 0x20090810, 0x00080020, 0x00080030, 0x20080020, 0x20080030, 0x00090020, 0x00090030, 0x20090020, 0x20090030, 0x00080820, 0x00080830, 0x20080820, 0x20080830, 0x00090820, 0x00090830, 0x20090820, 0x20090830 ), # for C bits (numbered as per FIPS 46) 7 8 10 11 12 13 ( 0x00000000, 0x02000000, 0x00002000, 0x02002000, 0x00200000, 0x02200000, 0x00202000, 0x02202000, 0x00000004, 0x02000004, 0x00002004, 0x02002004, 0x00200004, 0x02200004, 0x00202004, 0x02202004, 0x00000400, 0x02000400, 0x00002400, 0x02002400, 0x00200400, 0x02200400, 0x00202400, 0x02202400, 0x00000404, 0x02000404, 0x00002404, 0x02002404, 0x00200404, 0x02200404, 0x00202404, 0x02202404, 0x10000000, 0x12000000, 0x10002000, 0x12002000, 0x10200000, 0x12200000, 0x10202000, 0x12202000, 0x10000004, 0x12000004, 0x10002004, 0x12002004, 0x10200004, 0x12200004, 0x10202004, 0x12202004, 0x10000400, 0x12000400, 0x10002400, 0x12002400, 0x10200400, 0x12200400, 0x10202400, 0x12202400, 0x10000404, 0x12000404, 0x10002404, 0x12002404, 0x10200404, 0x12200404, 0x10202404, 0x12202404 ), # for C bits (numbered as per FIPS 46) 14 15 16 17 19 20 ( 0x00000000, 0x00000001, 0x00040000, 0x00040001, 0x01000000, 0x01000001, 0x01040000, 0x01040001, 0x00000002, 0x00000003, 0x00040002, 0x00040003, 0x01000002, 0x01000003, 0x01040002, 0x01040003, 0x00000200, 0x00000201, 0x00040200, 0x00040201, 0x01000200, 0x01000201, 0x01040200, 0x01040201, 0x00000202, 0x00000203, 0x00040202, 0x00040203, 0x01000202, 0x01000203, 0x01040202, 0x01040203, 0x08000000, 0x08000001, 0x08040000, 0x08040001, 0x09000000, 0x09000001, 0x09040000, 0x09040001, 0x08000002, 0x08000003, 0x08040002, 0x08040003, 0x09000002, 0x09000003, 0x09040002, 0x09040003, 0x08000200, 0x08000201, 0x08040200, 0x08040201, 0x09000200, 0x09000201, 0x09040200, 0x09040201, 0x08000202, 0x08000203, 0x08040202, 0x08040203, 0x09000202, 0x09000203, 0x09040202, 0x09040203 ), # for C bits (numbered as per FIPS 46) 21 23 24 26 27 28 ( 0x00000000, 0x00100000, 0x00000100, 0x00100100, 0x00000008, 0x00100008, 0x00000108, 0x00100108, 0x00001000, 0x00101000, 0x00001100, 0x00101100, 0x00001008, 0x00101008, 0x00001108, 0x00101108, 0x04000000, 
0x04100000, 0x04000100, 0x04100100, 0x04000008, 0x04100008, 0x04000108, 0x04100108, 0x04001000, 0x04101000, 0x04001100, 0x04101100, 0x04001008, 0x04101008, 0x04001108, 0x04101108, 0x00020000, 0x00120000, 0x00020100, 0x00120100, 0x00020008, 0x00120008, 0x00020108, 0x00120108, 0x00021000, 0x00121000, 0x00021100, 0x00121100, 0x00021008, 0x00121008, 0x00021108, 0x00121108, 0x04020000, 0x04120000, 0x04020100, 0x04120100, 0x04020008, 0x04120008, 0x04020108, 0x04120108, 0x04021000, 0x04121000, 0x04021100, 0x04121100, 0x04021008, 0x04121008, 0x04021108, 0x04121108 ), # for D bits (numbered as per FIPS 46) 1 2 3 4 5 6 ( 0x00000000, 0x10000000, 0x00010000, 0x10010000, 0x00000004, 0x10000004, 0x00010004, 0x10010004, 0x20000000, 0x30000000, 0x20010000, 0x30010000, 0x20000004, 0x30000004, 0x20010004, 0x30010004, 0x00100000, 0x10100000, 0x00110000, 0x10110000, 0x00100004, 0x10100004, 0x00110004, 0x10110004, 0x20100000, 0x30100000, 0x20110000, 0x30110000, 0x20100004, 0x30100004, 0x20110004, 0x30110004, 0x00001000, 0x10001000, 0x00011000, 0x10011000, 0x00001004, 0x10001004, 0x00011004, 0x10011004, 0x20001000, 0x30001000, 0x20011000, 0x30011000, 0x20001004, 0x30001004, 0x20011004, 0x30011004, 0x00101000, 0x10101000, 0x00111000, 0x10111000, 0x00101004, 0x10101004, 0x00111004, 0x10111004, 0x20101000, 0x30101000, 0x20111000, 0x30111000, 0x20101004, 0x30101004, 0x20111004, 0x30111004 ), # for D bits (numbered as per FIPS 46) 8 9 11 12 13 14 ( 0x00000000, 0x08000000, 0x00000008, 0x08000008, 0x00000400, 0x08000400, 0x00000408, 0x08000408, 0x00020000, 0x08020000, 0x00020008, 0x08020008, 0x00020400, 0x08020400, 0x00020408, 0x08020408, 0x00000001, 0x08000001, 0x00000009, 0x08000009, 0x00000401, 0x08000401, 0x00000409, 0x08000409, 0x00020001, 0x08020001, 0x00020009, 0x08020009, 0x00020401, 0x08020401, 0x00020409, 0x08020409, 0x02000000, 0x0A000000, 0x02000008, 0x0A000008, 0x02000400, 0x0A000400, 0x02000408, 0x0A000408, 0x02020000, 0x0A020000, 0x02020008, 0x0A020008, 0x02020400, 0x0A020400, 0x02020408, 0x0A020408, 0x02000001, 0x0A000001, 0x02000009, 0x0A000009, 0x02000401, 0x0A000401, 0x02000409, 0x0A000409, 0x02020001, 0x0A020001, 0x02020009, 0x0A020009, 0x02020401, 0x0A020401, 0x02020409, 0x0A020409 ), # for D bits (numbered as per FIPS 46) 16 17 18 19 20 21 ( 0x00000000, 0x00000100, 0x00080000, 0x00080100, 0x01000000, 0x01000100, 0x01080000, 0x01080100, 0x00000010, 0x00000110, 0x00080010, 0x00080110, 0x01000010, 0x01000110, 0x01080010, 0x01080110, 0x00200000, 0x00200100, 0x00280000, 0x00280100, 0x01200000, 0x01200100, 0x01280000, 0x01280100, 0x00200010, 0x00200110, 0x00280010, 0x00280110, 0x01200010, 0x01200110, 0x01280010, 0x01280110, 0x00000200, 0x00000300, 0x00080200, 0x00080300, 0x01000200, 0x01000300, 0x01080200, 0x01080300, 0x00000210, 0x00000310, 0x00080210, 0x00080310, 0x01000210, 0x01000310, 0x01080210, 0x01080310, 0x00200200, 0x00200300, 0x00280200, 0x00280300, 0x01200200, 0x01200300, 0x01280200, 0x01280300, 0x00200210, 0x00200310, 0x00280210, 0x00280310, 0x01200210, 0x01200310, 0x01280210, 0x01280310 ), # for D bits (numbered as per FIPS 46) 22 23 24 25 27 28 ( 0x00000000, 0x04000000, 0x00040000, 0x04040000, 0x00000002, 0x04000002, 0x00040002, 0x04040002, 0x00002000, 0x04002000, 0x00042000, 0x04042000, 0x00002002, 0x04002002, 0x00042002, 0x04042002, 0x00000020, 0x04000020, 0x00040020, 0x04040020, 0x00000022, 0x04000022, 0x00040022, 0x04040022, 0x00002020, 0x04002020, 0x00042020, 0x04042020, 0x00002022, 0x04002022, 0x00042022, 0x04042022, 0x00000800, 0x04000800, 0x00040800, 0x04040800, 0x00000802, 
0x04000802, 0x00040802, 0x04040802, 0x00002800, 0x04002800, 0x00042800, 0x04042800, 0x00002802, 0x04002802, 0x00042802, 0x04042802, 0x00000820, 0x04000820, 0x00040820, 0x04040820, 0x00000822, 0x04000822, 0x00040822, 0x04040822, 0x00002820, 0x04002820, 0x00042820, 0x04042820, 0x00002822, 0x04002822, 0x00042822, 0x04042822 ) ) _shifts2 = (0,0,1,1,1,1,1,1,0,1,1,1,1,1,1,0) _con_salt = ( 0xD2,0xD3,0xD4,0xD5,0xD6,0xD7,0xD8,0xD9, 0xDA,0xDB,0xDC,0xDD,0xDE,0xDF,0xE0,0xE1, 0xE2,0xE3,0xE4,0xE5,0xE6,0xE7,0xE8,0xE9, 0xEA,0xEB,0xEC,0xED,0xEE,0xEF,0xF0,0xF1, 0xF2,0xF3,0xF4,0xF5,0xF6,0xF7,0xF8,0xF9, 0xFA,0xFB,0xFC,0xFD,0xFE,0xFF,0x00,0x01, 0x02,0x03,0x04,0x05,0x06,0x07,0x08,0x09, 0x0A,0x0B,0x05,0x06,0x07,0x08,0x09,0x0A, 0x0B,0x0C,0x0D,0x0E,0x0F,0x10,0x11,0x12, 0x13,0x14,0x15,0x16,0x17,0x18,0x19,0x1A, 0x1B,0x1C,0x1D,0x1E,0x1F,0x20,0x21,0x22, 0x23,0x24,0x25,0x20,0x21,0x22,0x23,0x24, 0x25,0x26,0x27,0x28,0x29,0x2A,0x2B,0x2C, 0x2D,0x2E,0x2F,0x30,0x31,0x32,0x33,0x34, 0x35,0x36,0x37,0x38,0x39,0x3A,0x3B,0x3C, 0x3D,0x3E,0x3F,0x40,0x41,0x42,0x43,0x44 ) _cov_2char = b'./0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' def _HPERM_OP(a): """Clever bit manipulation.""" t = ((a << 18) ^ a) & 0xcccc0000 return a ^ t ^ ((t >> 18) & 0x3fff) def _PERM_OP(a,b,n,m): """Cleverer bit manipulation.""" t = ((a >> n) ^ b) & m b = b ^ t a = a ^ (t << n) return a,b ii_struct = struct.Struct('> 16) | ((c >> 4) & 0x0f000000)) c = c & 0x0fffffff # Copy globals into local variables for loop. shifts2 = _shifts2 skbc0, skbc1, skbc2, skbc3, skbd0, skbd1, skbd2, skbd3 = _skb k = [0] * (_ITERATIONS * 2) for i in range(_ITERATIONS): # Only operates on top 28 bits. if shifts2[i]: c = (c >> 2) | (c << 26) d = (d >> 2) | (d << 26) else: c = (c >> 1) | (c << 27) d = (d >> 1) | (d << 27) c = c & 0x0fffffff d = d & 0x0fffffff s = ( skbc0[ c & 0x3f ] | skbc1[((c>> 6) & 0x03) | ((c>> 7) & 0x3c)] | skbc2[((c>>13) & 0x0f) | ((c>>14) & 0x30)] | skbc3[((c>>20) & 0x01) | ((c>>21) & 0x06) | ((c>>22) & 0x38)] ) t = ( skbd0[ d & 0x3f ] | skbd1[((d>> 7) & 0x03) | ((d>> 8) & 0x3c)] | skbd2[((d>>15) & 0x3f) ] | skbd3[((d>>21) & 0x0f) | ((d>>22) & 0x30)] ) k[2*i] = ((t << 16) | (s & 0x0000ffff)) & 0xffffffff s = (s >> 16) | (t & 0xffff0000) # Top bit of s may be 1. s = (s << 4) | ((s >> 28) & 0x0f) k[2*i + 1] = s & 0xffffffff return k def _body(ks, E0, E1): """Use the key schedule ks and salt E0, E1 to create the password hash.""" # Copy global variable into locals for loop. SP0, SP1, SP2, SP3, SP4, SP5, SP6, SP7 = _SPtrans inner = range(0, _ITERATIONS*2, 2) l = r = 0 for j in range(25): l,r = r,l for i in inner: t = r ^ ((r >> 16) & 0xffff) u = t & E0 t = t & E1 u = u ^ (u << 16) ^ r ^ ks[i] t = t ^ (t << 16) ^ r ^ ks[i+1] t = ((t >> 4) & 0x0fffffff) | (t << 28) l,r = r,(SP1[(t ) & 0x3f] ^ SP3[(t>> 8) & 0x3f] ^ SP5[(t>>16) & 0x3f] ^ SP7[(t>>24) & 0x3f] ^ SP0[(u ) & 0x3f] ^ SP2[(u>> 8) & 0x3f] ^ SP4[(u>>16) & 0x3f] ^ SP6[(u>>24) & 0x3f] ^ l) l = ((l >> 1) & 0x7fffffff) | ((l & 0x1) << 31) r = ((r >> 1) & 0x7fffffff) | ((r & 0x1) << 31) r,l = _PERM_OP(r, l, 1, 0x55555555) l,r = _PERM_OP(l, r, 8, 0x00ff00ff) r,l = _PERM_OP(r, l, 2, 0x33333333) l,r = _PERM_OP(l, r, 16, 0x0000ffff) r,l = _PERM_OP(r, l, 4, 0x0f0f0f0f) return l,r def crypt(password, salt): """Generate an encrypted hash from the passed password. If the password is longer than eight characters, only the first eight will be used. The first two characters of the salt are used to modify the encryption algorithm used to generate in the hash in one of 4096 different ways. 
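(Each of the two salt characters is drawn from a 64-character alphabet,
so there are 64 * 64 = 4096 possible variations.)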
The characters for the salt should be upper- and lower-case letters A to Z, digits 0 to 9, '.' and '/'. The returned hash begins with the two characters of the salt, and should be passed as the salt to verify the password. Example: >>> from fcrypt import crypt >>> password = 'AlOtBsOl' >>> salt = 'cE' >>> hash = crypt(password, salt) >>> hash 'cEpWz5IUCShqM' >>> crypt(password, hash) == hash 1 >>> crypt('IaLaIoK', hash) == hash 0 In practice, you would read the password using something like the getpass module, and generate the salt randomly: >>> import random, string >>> saltchars = string.letters + string.digits + './' >>> salt = random.choice(saltchars) + random.choice(saltchars) Note that other ASCII characters are accepted in the salt, but the results may not be the same as other versions of crypt. In particular, '_', '$1' and '$2' do not select alternative hash algorithms such as the extended passwords, MD5 crypt and Blowfish crypt supported by the OpenBSD C library. """ # Extract the salt. if len(salt) == 0: salt = b'AA' elif len(salt) == 1: salt = salt + b'A' Eswap0 = _con_salt[salt[0] & 0x7f] Eswap1 = _con_salt[salt[1] & 0x7f] << 4 # Generate the key and use it to apply the encryption. ks = _set_key((password + b'\x00\x00\x00\x00\x00\x00\x00\x00')[:8]) o1, o2 = _body(ks, Eswap0, Eswap1) # Extract 24-bit subsets of result with bytes reversed. t1 = (o1 << 16 & 0xff0000) | (o1 & 0xff00) | (o1 >> 16 & 0xff) t2 = (o1 >> 8 & 0xff0000) | (o2 << 8 & 0xff00) | (o2 >> 8 & 0xff) t3 = (o2 & 0xff0000) | (o2 >> 16 & 0xff00) # Extract 6-bit subsets. r = [ t1 >> 18 & 0x3f, t1 >> 12 & 0x3f, t1 >> 6 & 0x3f, t1 & 0x3f, t2 >> 18 & 0x3f, t2 >> 12 & 0x3f, t2 >> 6 & 0x3f, t2 & 0x3f, t3 >> 18 & 0x3f, t3 >> 12 & 0x3f, t3 >> 6 & 0x3f ] # Convert to characters. for i in range(len(r)): r[i] = _cov_2char[r[i]:r[i]+1] return salt[:2] + b''.join(r) def _test(): """Run doctest on fcrypt module.""" import doctest, fcrypt return doctest.testmod(fcrypt) if __name__ == '__main__': _test() fe-1.1.0/postgresql/resolved/riparse.py000066400000000000000000000221241203372773200201570ustar00rootroot00000000000000# -*- encoding: utf-8 -*- ## # copyright 2008, James William Pye. http://jwp.name ## """ Split, unsplit, parse, serialize, construct and structure resource indicators. Resource indicators take the form:: [scheme:[//]][user[:pass]@]host[:port][/[path[/path]*][?param-n1=value[¶m-n=value-n]*][#fragment]] It might be an URL, URI, or IRI. It tries not to care. Notably, it only percent-encodes chr(0-33) as some RIs support character values greater than 127. Usually, it's best to make a second pass on the string in order to target a specific format, URI or IRI. If a specific format is being targeted, URL or URI or URI-represention of an IRI, a second pass *must* be made on the string. # Future versions may include subsequent transformation routines for targeting. Overview -------- Where ``x`` is a text RI(ie, ``http://foo.com/path``):: unsplit(split(x)) == x serialize(parse(x)) == x parse(x) == structure(split(x)) construct(parse(x)) == split(x) Substructure ------------ In some cases, an RI may have additional structure that needs to be extracted. To do this, the ``fieldproc`` keyword is used on `split_netloc`, `structure`, and `parse` functions. The ``fieldproc`` keyword is a callable that takes a single argument and returns the processed field. By default, ``fieldproc`` is the `unescape` function which will decode percent escapes. 
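For example, with the default ``fieldproc``, a percent-encoded "/" inside
a path segment is decoded (an illustrative IRI, not a normative example)::

	>>> parse('pq://user@host/a%2Fb')['path']
	['a/b']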
This is not desirable when substructure exists within an RI's component as it can create ambiguity about a token when a percent encoded variant is decoded. """ import re pct_encode = '%%%0.2X'.__mod__ unescaped = '%' + ''.join([chr(x) for x in range(0, 33)]) percent_escapes_re = re.compile('(%[0-9a-fA-F]{2,2})+') escape_re = re.compile('[%s]' %(re.escape(unescaped),)) escape_user_re = re.compile('[%s]' %(re.escape(unescaped + ':@/?#'),)) escape_password_re = re.compile('[%s]' %(re.escape(unescaped + '@/?#'),)) escape_host_re = re.compile('[%s]' %(re.escape(unescaped + '/?#'),)) escape_port_re = re.compile('[%s]' %(re.escape(unescaped + '/?#'),)) escape_path_re = re.compile('[%s]' %(re.escape(unescaped + '/?#'),)) escape_query_key_re = re.compile('[%s]' %(re.escape(unescaped + '&=#'),)) escape_query_value_re = re.compile('[%s]' %(re.escape(unescaped + '&#'),)) percent_escapes = {} for x in range(256): k = '%0.2X'.__mod__(x) percent_escapes[k] = x percent_escapes[k.lower()] = x percent_escapes[k[0].lower() + k[1]] = x percent_escapes[k[0] + k[1].lower()] = x scheme_chars = '-.+0123456789' del x def unescape(x, mkval = chr): 'Substitute percent escapes with literal characters' nstr = type(x)('') if isinstance(x, str): mkval = chr pos = 0 end = len(x) while pos != end: newpos = x.find('%', pos) if newpos == -1: nstr += x[pos:] break else: nstr += x[pos:newpos] val = percent_escapes.get(x[newpos+1:newpos+3]) if val is not None: nstr += mkval(val) pos = newpos + 3 else: nstr += '%' pos = newpos + 1 return nstr def re_pct_encode(m): return pct_encode(ord(m.group(0))) indexes = { 'scheme' : 0, 'netloc' : 1, 'path' : 2, 'query' : 3, 'fragment' : 4 } def split(s): """ Split an IRI into its base components based on the markers:: ://, /, ?, # Return a 5-tuple: (scheme, netloc, path, query, fragment) """ scheme = None netloc = None path = None query = None fragment = None end = len(s) pos = 0 # Non-iauthority RI's should be special cased by the user. scheme_pos = s.find('://') if scheme_pos != -1: pos = scheme_pos + 3 scheme = s[:scheme_pos] for x in scheme: if not (x in scheme_chars) and \ not ('A' <= x <= 'Z') and not ('a' <= x <= 'z'): pos = 0 scheme = None break end_of_netloc = end path_pos = s.find('/', pos) if path_pos == -1: path_pos = None else: end_of_netloc = path_pos query_pos = s.find('?', pos) if query_pos == -1: query_pos = None elif path_pos is None or query_pos < path_pos: path_pos = None end_of_netloc = query_pos fragment_pos = s.find('#', pos) if fragment_pos == -1: fragment_pos = None else: if query_pos is not None and fragment_pos < query_pos: query_pos = None if path_pos is not None and fragment_pos < path_pos: path_pos = None end_of_netloc = fragment_pos if query_pos is None and path_pos is None: end_of_netloc = fragment_pos if end_of_netloc != pos: netloc = s[pos:end_of_netloc] if path_pos is not None: path = s[path_pos+1:query_pos or fragment_pos or end] if query_pos is not None: query = s[query_pos+1:fragment_pos or end] if fragment_pos is not None: fragment = s[fragment_pos+1:end] return (scheme, netloc, path, query, fragment) def unsplit_path(p, re = escape_path_re): """ Join a list of paths(strings) on "/" *after* escaping them. """ if not p: return None return '/'.join([re.sub(re_pct_encode, x) for x in p]) def split_path(p, fieldproc = unescape): """ Return a list of unescaped strings split on "/". Set `fieldproc` to `str` if the components' percent escapes should not be decoded. 
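For example (illustrative)::

	>>> split_path('a/b%2Fc')
	['a', 'b/c']
	>>> split_path('a/b%2Fc', fieldproc = str)
	['a', 'b%2Fc']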
""" if p is None: return [] return [fieldproc(x) for x in p.split('/')] def unsplit(t): 'Make a RI from a split RI(5-tuple)' s = '' if t[0] is not None: s += t[0] s += '://' if t[1] is not None: s += t[1] if t[2] is not None: s += '/' s += t[2] if t[3] is not None: s += '?' s += t[3] if t[4] is not None: s += '#' s += t[4] return s def split_netloc(netloc, fieldproc = unescape): """ Split a net location into a 4-tuple, (user, password, host, port). Set `fieldproc` to `str` if the components' percent escapes should not be decoded. """ pos = netloc.find('@') if pos == -1: # No user information pos = 0 user = None password = None else: s = netloc[:pos] userpw = s.split(':', 1) if len(userpw) == 2: user, password = userpw user = fieldproc(user) password = fieldproc(password) else: user = fieldproc(userpw[0]) password = None pos += 1 if pos >= len(netloc): return (user, password, None, None) pos_chr = netloc[pos] if pos_chr == '[': # IPvN addr next_pos = netloc.find(']', pos) if next_pos == -1: # unterminated IPvN block next_pos = len(netloc) - 1 addr = netloc[pos:next_pos+1] pos = next_pos + 1 next_pos = netloc.find(':', pos) if next_pos == -1: port = None else: port = fieldproc(netloc[next_pos+1:]) else: next_pos = netloc.find(':', pos) if next_pos == -1: addr = fieldproc(netloc[pos:]) port = None else: addr = fieldproc(netloc[pos:next_pos]) port = fieldproc(netloc[next_pos+1:]) return (user, password, addr, port) def unsplit_netloc(t): 'Create a netloc fragment from the given tuple(user,password,host,port)' if t[0] is None and t[2] is None: return None s = '' if t[0] is not None: s += escape_user_re.sub(re_pct_encode, t[0]) if t[1] is not None: s += ':' s += escape_password_re.sub(re_pct_encode, t[1]) s += '@' if t[2] is not None: s += escape_host_re.sub(re_pct_encode, t[2]) if t[3] is not None: s += ':' s += escape_port_re.sub(re_pct_encode, t[3]) return s def structure(t, fieldproc = unescape): """ Create a dictionary from a split RI(5-tuple). Set `fieldproc` to `str` if the components' percent escapes should not be decoded. """ d = {} if t[0] is not None: d['scheme'] = t[0] if t[1] is not None: uphp = split_netloc(t[1], fieldproc = fieldproc) if uphp[0] is not None: d['user'] = uphp[0] if uphp[1] is not None: d['password'] = uphp[1] if uphp[2] is not None: d['host'] = uphp[2] if uphp[3] is not None: d['port'] = uphp[3] if t[2] is not None: if t[2]: d['path'] = list(map(fieldproc, t[2].split('/'))) else: d['path'] = [] if t[3] is not None: if t[3]: d['query'] = [tuple((list(map(fieldproc, x.split('=', 1))) + [None])[:2]) for x in t[3].split('&')] else: # no characters followed the '?' 
d['query'] = [] if t[4] is not None: d['fragment'] = fieldproc(t[4]) return d def construct_query(x, key_re = escape_query_key_re, value_re = escape_query_value_re, ): 'Given a sequence of (key, value) pairs, construct' return '&'.join([ v is not None and \ '%s=%s' %( key_re.sub(re_pct_encode, k), value_re.sub(re_pct_encode, v), ) or \ key_re.sub(re_pct_encode, k) for k, v in x ]) def construct(x): 'Construct a RI tuple(5-tuple) from a dictionary object' p = x.get('path') if p is not None: p = '/'.join([escape_path_re.sub(re_pct_encode, y) for y in p]) q = x.get('query') if q is not None: q = construct_query(q) f = x.get('fragment') if f is not None: f = escape_re.sub(re_pct_encode, f) u = x.get('user') pw = x.get('password') h = x.get('host') port = x.get('port') return ( x.get('scheme'), # netloc: [user[:pass]@]host[:port] unsplit_netloc(( x.get('user'), x.get('password'), x.get('host'), x.get('port'), )), p, q, f ) def parse(s, fieldproc = unescape): """ Parse an RI into a dictionary object. Synonym for ``structure(split(x))``. Set `fieldproc` to `str` if the components' percent escapes should not be decoded. """ return structure(split(s), fieldproc = fieldproc) def serialize(x): 'Return an RI from a dictionary object. Synonym for ``unsplit(construct(x))``' return unsplit(construct(x)) __docformat__ = 'reStructuredText' fe-1.1.0/postgresql/string.py000066400000000000000000000161601203372773200162000ustar00rootroot00000000000000## # .string ## """ String split and join operations for dealing with literals and identifiers. Notably, the functions in this module are intended to be used for simple use-cases. It attempts to stay away from "real" parsing and simply provides functions for common needs, like the ability to identify unquoted portions of a query string so that logic or transformations can be applied to only unquoted portions. Scanning for statement terminators, or safely interpolating identifiers. All functions deal with strict quoting rules. """ import re def escape_literal(text): "Replace every instance of ' with ''" return text.replace("'", "''") def quote_literal(text): "Escape the literal and wrap it in [single] quotations" return "'" + text.replace("'", "''") + "'" def escape_ident(text): 'Replace every instance of " with ""' return text.replace('"', '""') def needs_quoting(text): return not (text and not text[0].isdecimal() and text.replace('_', 'a').isalnum()) def quote_ident(text): "Replace every instance of '"' with '""' *and* place '"' on each end" return '"' + text.replace('"', '""') + '"' def quote_ident_if_needed(text): """ If needed, replace every instance of '"' with '""' *and* place '"' on each end. Otherwise, just return the text. """ return quote_ident(text) if needs_quoting(text) else text quote_re = re.compile(r"""(?xu) E'(?:''|\\.|[^'])*(?:'|$) (?# Backslash escapes E'str') | '(?:''|[^'])*(?:'|$) (?# Regular literals 'str') | "(?:""|[^"])*(?:"|$) (?# Identifiers "str") | (\$(?:[^0-9$]\w*)?\$).*?(?:\1|$) (?# Dollar quotes $$str$$) """) def split(text): """ split the string up by into non-quoted and quoted portions. Zero and even numbered indexes are unquoted portions, while odd indexes are quoted portions. Unquoted portions are regular strings, whereas quoted portions are pair-tuples specifying the quotation mechanism and the content thereof. >>> list(split("select $$foobar$$")) ['select ', ('$$', 'foobar'), ''] If the split ends on a quoted section, it means the string's quote was not terminated. 
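For example, a run-away quotation ends the split on a quoted pair
(illustrative)::

	>>> list(split("select 'runaway"))
	['select ', ("'", 'runaway')]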
Subsequently, there will be an even number of objects in the list. Quotation errors are detected, but never raised. Rather it's up to the user to identify the best course of action for the given split. """ lastend = 0 re = quote_re scan = re.scanner(text) match = scan.search() while match is not None: # text preceding the quotation yield text[lastend:match.start()] # the dollar quote, if any dq = match.groups()[0] if dq is not None: endoff = len(dq) quote = dq end = quote else: endoff = 1 q = text[match.start()] if q == 'E': quote = "E'" end = "'" else: end = quote = q # If the end is not the expected quote, it consumed # the end. Be sure to check that the match's end - end offset # is *not* the start, ie an empty quotation at the end of the string. if text[match.end()-endoff:match.end()] != end \ or match.end() - endoff == match.start(): yield (quote, text[match.start()+len(quote):]) break else: yield (quote, text[match.start()+len(quote):match.end()-endoff]) lastend = match.end() match = scan.search() else: # balanced quotes, yield the rest yield text[lastend:] def unsplit(splitted_iter): """ catenate a split string. This is needed to handle the special cases created by pg.string.split(). (Run-away quotations, primarily) """ s = '' quoted = False i = iter(splitted_iter) endq = '' for x in i: s += endq + x try: q, qtext = next(i) s += q + qtext if q == "E'": endq = "'" else: endq = q except StopIteration: break return s def split_using(text, quote, sep = '.', maxsplit = -1): """ split the string on the seperator ignoring the separator in quoted areas. This is only useful for simple quoted strings. Dollar quotes, and backslash escapes are not supported. """ escape = quote * 2 esclen = len(escape) offset = 0 tl = len(text) end = tl # Fast path: No quotes? Do a simple split. if quote not in text: return text.split(sep, maxsplit) l = [] while len(l) != maxsplit: # Look for the separator first nextsep = text.find(sep, offset) if nextsep == -1: # it's over. there are no more seps break else: # There's a sep ahead, but is there a quoted section before it? nextquote = text.find(quote, offset, nextsep) while nextquote != -1: # Yep, there's a quote before the sep; # need to eat the escaped portion. nextquote = text.find(quote, nextquote + 1,) while nextquote != -1: if text.find(escape, nextquote, nextquote+esclen) != nextquote: # Not an escape, so it's the end. break # Look for another quote past the escape quote. nextquote = text.find(quote, nextquote + 2) else: # the sep was located in the escape, and # the escape consumed the rest of the string. nextsep = -1 break nextsep = text.find(sep, nextquote + 1) if nextsep == -1: # it's over. there are no more seps # [likely they were consumed by the escape] break nextquote = text.find(quote, nextquote + 1, nextsep) if nextsep == -1: break l.append(text[offset:nextsep]) offset = nextsep + 1 l.append(text[offset:]) return l def split_ident(text, sep = ',', quote = '"', maxsplit = -1): """ Split a series of identifiers using the specified separator. """ nr = [] for x in split_using(text, quote, sep = sep, maxsplit = maxsplit): x = x.strip() if x.startswith('"'): if not x.endswith('"'): raise ValueError( "unterminated identifier quotation", x ) else: nr.append(x[1:-1].replace('""', '"')) elif needs_quoting(x): raise ValueError( "non-ident characters in unquoted identifier", x ) else: # postgres implies a lower, so to stay consistent # with it on qname joins, lower the unquoted identifier now. 
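			# e.g. 'FooBar' becomes 'foobar', matching the server's folding.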
nr.append(x.lower()) return nr def split_qname(text, maxsplit = -1): """ Call to .split_ident() with a '.' sep parameter. """ return split_ident(text, maxsplit = maxsplit, sep = '.') def qname(*args): "Quote the identifiers and join them using '.'" return '.'.join([quote_ident(x) for x in args]) def qname_if_needed(*args): return '.'.join([quote_ident_if_needed(x) for x in args]) def split_sql(sql, sep = ';'): """ Given SQL, safely split using the given separator. Notably, this yields fully split text. This should be used instead of split_sql_str() when quoted sections need be still be isolated. >>> list(split_sql('select $$1$$ AS "foo;"; select 2;')) [['select ', ('$$', '1'), ' AS ', ('"', 'foo;'), ''], (' select 2',), ['']] """ i = iter(split(sql)) cur = [] for part in i: sections = part.split(sep) if len(sections) < 2: cur.append(part) else: cur.append(sections[0]) yield cur for x in sections[1:-1]: yield (x,) cur = [sections[-1]] try: cur.append(next(i)) except StopIteration: break if cur: yield cur def split_sql_str(sql, sep = ';'): """ Identical to split_sql but yields unsplit text. >>> list(split_sql_str('select $$1$$ AS "foo;"; select 2;')) ['select $$1$$ AS "foo;"', ' select 2', ''] """ for x in split_sql(sql, sep = sep): yield unsplit(x) fe-1.1.0/postgresql/sys.py000066400000000000000000000041151203372773200155050ustar00rootroot00000000000000## # .sys ## """ py-postgresql system functions and data. Data ---- ``libpath`` The local file system paths that contain query libraries. Overridable Functions --------------------- excformat Information that makes up an exception's displayed "body". Effectively, the implementation of `postgresql.exception.Error.__str__` msghook Display a message. """ import sys import os import traceback from .python.element import format_element from .python.string import indent libpath = [] def default_errformat(val): """ Built-in error formatter. DON'T TOUCH! """ it = val._e_metas() if val.creator is not None: # Protect against element traceback failures. try: after = os.linesep + format_element(val.creator) except Exception: after = 'Element Traceback of %r caused exception:%s' %( type(val.creator).__name__, os.linesep ) after += indent(traceback.format_exc()) after = os.linesep + indent(after).rstrip() else: after = '' return next(it)[1] \ + os.linesep + ' ' \ + (os.linesep + ' ').join( k + ': ' + v for k, v in it ) + after def default_msghook(msg, format_message = format_element): """ Built-in message hook. DON'T TOUCH! """ if sys.stderr and not sys.stderr.closed: try: sys.stderr.write(format_message(msg) + os.linesep) except Exception: try: sys.excepthook(*sys.exc_info()) except Exception: # gasp. pass def errformat(*args, **kw): """ Raised Database Error formatter pointing to default_excformat. Override if you like. All postgresql.exceptions.Error's are formatted using this function. """ return default_errformat(*args, **kw) def msghook(*args, **kw): """ Message hook pointing to default_msghook. Override if you like. All untrapped messages raised by driver connections come here to be printed to stderr. 
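For example, to silence messages entirely (an illustrative override)::

	>>> import postgresql.sys
	>>> postgresql.sys.msghook = lambda *args, **kw: None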
""" return default_msghook(*args, **kw) def reset_errformat(with_func = errformat): 'restore the original excformat function' global errformat errformat = with_func def reset_msghook(with_func = msghook): 'restore the original msghook function' global msghook msghook = with_func fe-1.1.0/postgresql/temporal.py000066400000000000000000000152161203372773200165160ustar00rootroot00000000000000## # .temporal - manage the temporary cluster ## """ Temporary PostgreSQL cluster for the process. """ import os import atexit from collections import deque from .cluster import Cluster, ClusterError from . import installation from .python.socket import find_available_port class Temporal(object): """ Manages a temporary cluster for the duration of the process. Instances of this class reference a distinct cluster. These clusters are transient; they will only exist until the process exits. Usage:: >>> from postgresql.temporal import pg_tmp >>> with pg_tmp: ... ps = db.prepare('SELECT 1') ... assert ps.first() == 1 Or `pg_tmp` can decorate a method or function. """ #: Format the cluster directory name. cluster_dirname = 'pg_tmp_{0}_{1}'.format cluster = None _init_pid_ = None _local_id_ = 0 builtins_keys = { 'connector', 'db', 'do', 'xact', 'proc', 'settings', 'prepare', 'sqlexec', 'newdb', } def __init__(self): self.builtins_stack = deque() self.sandbox_id = 0 # identifier for keeping temporary instances unique. self.__class__._local_id_ = self.local_id = (self.__class__._local_id_ + 1) def __call__(self, callable): def in_pg_temporal_context(*args, **kw): with self: return callable(*args, **kw) n = getattr(callable, '__name__', None) if n: in_pg_temporal_context.__name__ = n return in_pg_temporal_context def destroy(self): # Don't destroy if it's not the initializing process. if os.getpid() == self._init_pid_: # Kill all the open connections. try: c = cluster.connection(user = 'test', database = 'template1',) with c: if c.version_info[:2] <= (9,1): c.sys.terminate_backends() else: c.sys.terminate_backends_92() except Exception: # Doesn't matter much if it fails. pass cluster = self.cluster self.cluster = None self._init_pid_ = None if cluster is not None: cluster.stop() cluster.wait_until_stopped(timeout = 5) cluster.drop() def init(self, installation_factory = installation.default, inshint = { 'hint' : "Try setting the PGINSTALLATION " \ "environment variable to the `pg_config` path" } ): if self.cluster is not None: return ## # Hasn't been created yet, but doesn't matter. # On exit, obliterate the cluster directory. self._init_pid_ = os.getpid() atexit.register(self.destroy) # [$HOME|.]/.pg_tmpdb_{pid} self.cluster_path = os.path.join( os.environ.get('HOME', os.getcwd()), self.cluster_dirname(self._init_pid_, self.local_id) ) self.logfile = os.path.join(self.cluster_path, 'logfile') installation = installation_factory() if installation is None: raise ClusterError( 'could not find the default pg_config', details = inshint ) cluster = Cluster(installation, self.cluster_path,) # If it exists already, destroy it. if cluster.initialized(): cluster.drop() cluster.encoding = 'utf-8' cluster.init( user = 'test', # Consistent username. 
encoding = cluster.encoding, logfile = None, ) # Configure self.cluster_port = find_available_port() if self.cluster_port is None: raise ClusterError( 'could not find a port for the test cluster on localhost', creator = cluster ) cluster.settings.update(dict( port = str(self.cluster_port), max_connections = '20', shared_buffers = '200', listen_addresses = 'localhost', log_destination = 'stderr', log_min_messages = 'FATAL', unix_socket_directory = cluster.data_directory, )) cluster.settings.update(dict( max_prepared_transactions = '10', )) # Start it up. with open(self.logfile, 'w') as lfo: cluster.start(logfile = lfo) cluster.wait_until_started() # Initialize template1 and the test user database. c = cluster.connection(user = 'test', database = 'template1',) with c: c.execute('create database test') # It's ready. self.cluster = cluster def push(self): c = self.cluster.connection(user = 'test') c.connect() extras = [] def new_pg_tmp_connection(l = extras, c = c, sbid = 'sandbox' + str(self.sandbox_id + 1)): # Used to create a new connection that will be closed # when the context stack is popped along with 'db'. l.append(c.clone()) l[-1].settings['search_path'] = str(sbid) + ',' + l[-1].settings['search_path'] return l[-1] # The new builtins. builtins = { 'db' : c, 'prepare' : c.prepare, 'xact' : c.xact, 'sqlexec' : c.execute, 'do' : c.do, 'settings' : c.settings, 'proc' : c.proc, 'connector' : c.connector, 'new' : new_pg_tmp_connection, } if not self.builtins_stack: # Store any of those set or not set. current = { k : __builtins__[k] for k in self.builtins_keys if k in __builtins__ } self.builtins_stack.append((current, [])) # Store and push. self.builtins_stack.append((builtins, extras)) __builtins__.update(builtins) self.sandbox_id += 1 def pop(self, exc, drop_schema = 'DROP SCHEMA sandbox{0} CASCADE'.format): builtins, extras = self.builtins_stack.pop() self.sandbox_id -= 1 # restore __builtins__ if len(self.builtins_stack) > 1: __builtins__.update(self.builtins_stack[-1][0]) else: previous = self.builtins_stack.popleft() for x in self.builtins_keys: if x in previous: __builtins__[x] = previous[x] else: # Wasn't set before. __builtins__.pop(x, None) # close popped connection, but only if we're not in an interrupt. # However, temporal will always terminate all backends atexit. if exc is None or isinstance(exc, Exception): # Interrupt then close. Just in case something is lingering. for xdb in [builtins['db']] + list(extras): if xdb.closed is False: # In order for a clean close of the connection, # interrupt before closing. It is still # possible for the close to block, but less likely. xdb.interrupt() xdb.close() # Interrupted and closed all the other connections at this level; # now remove the sandbox schema. c = self.cluster.connection(user = 'test') with c: # Use a new connection so that the state of # the context connection will not have to be # contended with. c.execute(drop_schema(self.sandbox_id+1)) else: # interrupt pass def __enter__(self): if self.cluster is None: self.init() self.push() try: db.connect() db.execute('CREATE SCHEMA sandbox' + str(self.sandbox_id)) db.settings['search_path'] = 'sandbox' + str(self.sandbox_id) + ',' + db.settings['search_path'] except Exception as e: # failed to initialize sandbox schema; pop it. self.pop(e) raise def __exit__(self, exc, val, tb): if self.cluster is not None: self.pop(val) #: The process' temporary cluster. 
pg_tmp = Temporal() fe-1.1.0/postgresql/test/000077500000000000000000000000001203372773200152735ustar00rootroot00000000000000fe-1.1.0/postgresql/test/__init__.py000066400000000000000000000000001203372773200173720ustar00rootroot00000000000000fe-1.1.0/postgresql/test/cursor_integrity.py000066400000000000000000000053231203372773200212630ustar00rootroot00000000000000## # .test.cursor_integrity ## import os import unittest import random import itertools iot = '_dst' getq = "SELECT i FROM generate_series(0, %d) AS g(i)" copy = "COPY (%s) TO STDOUT" def random_read(curs, remaining_rows): """ Read from one of the three methods using a random amount if sized. - 50% chance of curs.read(random()) - 40% chance of next() - 10% chance of read() # no count """ if random.random() > 0.5: rrows = random.randrange(0, remaining_rows) return curs.read(rrows), rrows elif random.random() < 0.1: return curs.read(), -1 else: try: return [next(curs)], 1 except StopIteration: return [], 1 def random_select_get(limit): return prepare(getq %(limit - 1,)) def random_copy_get(limit): return prepare(copy %(getq %(limit - 1,),)) class test_integrity(unittest.TestCase): """ test the integrity of the get and put interfaces on queries and result handles. """ def test_select(self): total = 0 while total < 10000: limit = random.randrange(500000) read = 0 total += limit p = random_select_get(limit)() last = ([(-1,)], 1) completed = [last[0]] while True: next = random_read(p, (limit - read) or 10) thisread = len(next[0]) read += thisread completed.append(next[0]) if thisread: self.failUnlessEqual( last[0][-1][0], next[0][0][0] - 1, "first row(-1) of next failed to match the last row of the previous" ) last = next elif next[1] != 0: # done break self.failUnlessEqual(read, limit) self.failUnlessEqual(list(range(-1, limit)), [ x[0] for x in itertools.chain(*completed) ]) def test_insert(self): pass if 'db' in dir(__builtins__) and pg.version_info >= (8,2,0): def test_copy_out(self): total = 0 while total < 10000000: limit = random.randrange(500000) read = 0 total += limit p = random_copy_get(limit)() last = ([-1], 1) completed = [last[0]] while True: next = random_read(p, (limit - read) or 10) next = ([int(x) for x in next[0]], next[1]) thisread = len(next[0]) read += thisread completed.append(next[0]) if thisread: self.failUnlessEqual( last[0][-1], next[0][0] - 1, "first row(-1) of next failed to match the last row of the previous" ) last = next elif next[1] != 0: # done break self.failUnlessEqual(read, limit) self.failUnlessEqual( list(range(-1, limit)), list(itertools.chain(*completed)) ) def test_copy_in(self): pass def main(): global copyin, loadin execute("CREATE TEMP TABLE _dst (i bigint)") copyin = prepare("COPY _dst FROM STDIN") loadin = prepare("INSERT INTO _dst VALUES ($1)") unittest.main() if __name__ == '__main__': main() fe-1.1.0/postgresql/test/perf_copy_io.py000066400000000000000000000036551203372773200203330ustar00rootroot00000000000000## # test.perf_copy_io - Copy I/O: To and From performance ## import os, sys, random, time if __name__ == '__main__': with open('/usr/share/dict/words', mode='brU') as wordfile: Words = wordfile.readlines() else: Words = [b'/usr/share/dict/words', b'is', b'read', b'in', b'__main__'] wordcount = len(Words) random.seed() def getWord(): "extract a random word from ``Words``" return Words[random.randrange(0, wordcount)].strip() def testSpeed(tuples = 50000 * 3): sqlexec("CREATE TEMP TABLE _copy " "(i int, t text, mt text, ts text, ty text, tx text);") try: Q = prepare("COPY _copy FROM 
STDIN") size = 0 def incsize(data): 'count of bytes' nonlocal size size += len(data) return data sys.stderr.write("preparing data(%d tuples)...\n" %(tuples,)) # Use an LC to avoid the Python overhead involved with a GE data = [incsize(b'\t'.join(( str(x).encode('ascii'), getWord(), getWord(), getWord(), getWord(), getWord() )))+b'\n' for x in range(tuples)] sys.stderr.write("starting copy...\n") start = time.time() copied_in = Q.load_rows(data) duration = time.time() - start sys.stderr.write( "COPY FROM STDIN Summary,\n " \ "copied tuples: %d\n " \ "copied bytes: %d\n " \ "duration: %f\n " \ "average tuple size(bytes): %f\n " \ "average KB per second: %f\n " \ "average tuples per second: %f\n" %( tuples, size, duration, size / tuples, size / 1024 / duration, tuples / duration, ) ) Q = prepare("COPY _copy TO STDOUT") start = time.time() c = 0 for rows in Q.chunks(): c += len(rows) duration = time.time() - start sys.stderr.write( "COPY TO STDOUT Summary,\n " \ "copied tuples: %d\n " \ "duration: %f\n " \ "average KB per second: %f\n " \ "average tuples per second: %f\n " %( c, duration, size / 1024 / duration, tuples / duration, ) ) finally: sqlexec("DROP TABLE _copy") if __name__ == '__main__': testSpeed() fe-1.1.0/postgresql/test/perf_query_io.py000066400000000000000000000032731203372773200205220ustar00rootroot00000000000000#!/usr/bin/env python ## # .test.perf_query_io ## # Statement I/O: Mass insert and select performance ## import os import time import sys import decimal import datetime def insertSamples(count, insert_records): recs = [ ( -3, 123, 0xfffffea023, decimal.Decimal("90900023123.40031"), decimal.Decimal("432.40031"), 'some_óäæ_thing', 'varying', 'æ', datetime.datetime(1982, 5, 18, 12, 0, 0, 100232) ) for x in range(count) ] gen = time.time() insert_records.load_rows(recs) fin = time.time() xacttime = fin - gen ats = count / xacttime sys.stderr.write( "INSERT Summary,\n " \ "inserted tuples: %d\n " \ "total time: %f\n " \ "average tuples per second: %f\n\n" %( count, xacttime, ats, ) ) def timeTupleRead(ps): loops = 0 tuples = 0 genesis = time.time() for x in ps.chunks(): loops += 1 tuples += len(x) finalis = time.time() looptime = finalis - genesis ats = tuples / looptime sys.stderr.write( "SELECT Summary,\n " \ "looped: {looped}\n " \ "looptime: {looptime}\n " \ "tuples: {ntuples}\n " \ "average tuples per second: {tps}\n ".format( looped = loops, looptime = looptime, ntuples = tuples, tps = ats ) ) def main(count): sqlexec('CREATE TEMP TABLE samples ' '(i2 int2, i4 int4, i8 int8, n numeric, n2 numeric, t text, v varchar, c char(2), ts timestamp)') insert_records = prepare( "INSERT INTO samples VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)" ) select_records = prepare("SELECT * FROM samples") try: insertSamples(count, insert_records) timeTupleRead(select_records) finally: sqlexec("DROP TABLE samples") def command(args): main(int((args + [25000])[1])) if __name__ == '__main__': command(sys.argv) fe-1.1.0/postgresql/test/support.py000066400000000000000000000006301203372773200173600ustar00rootroot00000000000000## # .test.support ## """ Executable module used by test_* modules to mimic a command. 
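For example (illustrative)::

	$ python -m postgresql.test.support pg_config
	FOO=BaR
	FEH=YEAH
	version=NAY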
""" import sys def pg_config(*args): data = """FOO=BaR FEH=YEAH version=NAY """ sys.stdout.write(data) if __name__ == '__main__': if sys.argv[1:]: cmd = sys.argv[1] if cmd in globals(): cmd = globals()[cmd] cmd(sys.argv[2:]) sys.exit(0) sys.stderr.write("no valid entry point referenced") sys.exit(1) fe-1.1.0/postgresql/test/test_alock.py000066400000000000000000000073031203372773200200000ustar00rootroot00000000000000## # .test.test_alock - test .alock ## import unittest import threading import time from ..temporal import pg_tmp from .. import alock n_alocks = "select count(*) FROM pg_locks WHERE locktype = 'advisory'" class test_alock(unittest.TestCase): @pg_tmp def testALockWait(self): # sadly, this is primarily used to exercise the code paths.. ad = prepare(n_alocks).first self.assertEqual(ad(), 0) state = [False, False, False] alt = new() first = alock.ExclusiveLock(db, (0,0)) second = alock.ExclusiveLock(db, 1) def concurrent_lock(): try: with alock.ExclusiveLock(alt, 1): with alock.ExclusiveLock(alt, (0,0)): # start it state[0] = True while not state[1]: pass time.sleep(0.01) while not state[2]: time.sleep(0.01) except Exception: # Avoid dead lock in cases where advisory is not available. state[0] = state[1] = state[2] = True t = threading.Thread(target = concurrent_lock) t.start() while not state[0]: time.sleep(0.01) self.assertEqual(ad(), 2) state[1] = True with first: self.assertEqual(ad(), 2) state[2] = True with second: self.assertEqual(ad(), 2) t.join(timeout = 1) @pg_tmp def testALockNoWait(self): alt = new() ad = prepare(n_alocks).first self.assertEqual(ad(), 0) with alock.ExclusiveLock(db, (0,0)): l=alock.ExclusiveLock(alt, (0,0)) # should fail to acquire self.assertEqual(l.acquire(blocking=False), False) # no alocks should exist now self.assertEqual(ad(), 0) @pg_tmp def testALock(self): ad = prepare(n_alocks).first self.assertEqual(ad(), 0) # test a variety.. lockids = [ (1,4), -32532, 0, 2, (7, -1232), 4, 5, 232142423, (18,7), 2, (1,4) ] alt = new() xal1 = alock.ExclusiveLock(db, *lockids) xal2 = alock.ExclusiveLock(db, *lockids) sal1 = alock.ShareLock(db, *lockids) with sal1: with xal1, xal2: self.assertTrue(ad() > 0) for x in lockids: xl = alock.ExclusiveLock(alt, x) self.assertEqual(xl.acquire(blocking=False), False) # main has exclusives on these, so this should fail. xl = alock.ShareLock(alt, *lockids) self.assertEqual(xl.acquire(blocking=False), False) for x in lockids: # sal1 still holds xl = alock.ExclusiveLock(alt, x) self.assertEqual(xl.acquire(blocking=False), False) # sal1 still holds, but we want a share lock too. xl = alock.ShareLock(alt, x) self.assertEqual(xl.acquire(blocking=False), True) xl.release() # no alocks should exist now self.assertEqual(ad(), 0) @pg_tmp def testPartialALock(self): # Validates that release is properly cleaning up ad = prepare(n_alocks).first self.assertEqual(ad(), 0) held = (0,-1234) wanted = [0, 324, -1232948, 7, held, 1, (2,4), (834,1)] alt = new() with alock.ExclusiveLock(db, held): l=alock.ExclusiveLock(alt, *wanted) # should fail to acquire, db has held self.assertEqual(l.acquire(blocking=False), False) # No alocks should exist now. # This *MUST* occur prior to alt being closed. # Otherwise, we won't be testing for the recovery # of a failed non-blocking acquire(). 
self.assertEqual(ad(), 0) @pg_tmp def testALockParameterErrors(self): self.assertRaises(TypeError, alock.ALock) l = alock.ExclusiveLock(db) self.assertRaises(RuntimeError, l.release) @pg_tmp def testALockOnClosed(self): ad = prepare(n_alocks).first self.assertEqual(ad(), 0) held = (0,-1234) alt = new() # __exit__ should only touch the count. with alock.ExclusiveLock(alt, held) as l: self.assertEqual(ad(), 1) self.assertEqual(l.locked(), True) alt.close() time.sleep(0.005) self.assertEqual(ad(), 0) self.assertEqual(l.locked(), False) if __name__ == '__main__': unittest.main() fe-1.1.0/postgresql/test/test_bytea_codec.py000066400000000000000000000025261203372773200211520ustar00rootroot00000000000000## # .test.test_bytea_codec ## import unittest import struct from ..encodings import bytea byte = struct.Struct('B') class test_bytea_codec(unittest.TestCase): def testDecoding(self): for x in range(255): c = byte.pack(x) b = c.decode('bytea') # normalize into octal escapes if c == b'\\' and b == "\\\\": b = "\\" + oct(b'\\'[0])[2:] elif not b.startswith("\\"): b = "\\" + oct(ord(b))[2:] if int(b[1:], 8) != x: self.fail( "bytea encoding failed at %d; encoded %r to %r" %(x, c, b,) ) def testEncoding(self): self.assertEqual('bytea'.encode('bytea'), b'bytea') self.assertEqual('\\\\'.encode('bytea'), b'\\') self.assertRaises(ValueError, '\\'.encode, 'bytea') self.assertRaises(ValueError, 'foo\\'.encode, 'bytea') self.assertRaises(ValueError, r'foo\0'.encode, 'bytea') self.assertRaises(ValueError, r'foo\00'.encode, 'bytea') self.assertRaises(ValueError, r'\f'.encode, 'bytea') self.assertRaises(ValueError, r'\800'.encode, 'bytea') self.assertRaises(ValueError, r'\7f0'.encode, 'bytea') for x in range(255): seq = ('\\' + oct(x)[2:].lstrip('0').rjust(3, '0')) dx = ord(seq.encode('bytea')) if dx != x: self.fail( "generated sequence failed to map back; current is %d, " \ "rendered %r, transformed to %d" %(x, seq, dx) ) if __name__ == '__main__': unittest.main() fe-1.1.0/postgresql/test/test_cluster.py000066400000000000000000000042131203372773200203650ustar00rootroot00000000000000## # .test.test_cluster ## import sys import os import time import unittest import tempfile from .. import installation from ..cluster import Cluster, ClusterStartupError default_install = installation.default() if default_install is None: sys.stderr.write("ERROR: cannot find 'default' pg_config\n") sys.stderr.write("HINT: set the PGINSTALLATION environment variable to the `pg_config` path\n") sys.exit(1) class test_cluster(unittest.TestCase): def setUp(self): self.cluster = Cluster(default_install, 'test_cluster',) def tearDown(self): self.cluster.drop() self.cluster = None def start_cluster(self, logfile = None): self.cluster.start(logfile = logfile) self.cluster.wait_until_started(timeout = 10) def init(self, *args, **kw): self.cluster.init(*args, **kw) self.cluster.settings.update({ 'max_connections' : '8', 'listen_addresses' : 'localhost', 'port' : '6543', 'unix_socket_directory' : self.cluster.data_directory, }) def testSilentMode(self): self.init() self.cluster.settings['silent_mode'] = 'on' # if it fails to start(ClusterError), silent_mode is not working properly. try: self.start_cluster(logfile = sys.stdout) except ClusterStartupError: # silent_mode is not supported on windows by PG. 
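			# (silent_mode was removed entirely in PostgreSQL 9.2.)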
			if sys.platform in ('win32', 'win64'):
				pass
			elif self.cluster.installation.version_info[:2] >= (9, 2):
				pass
			else:
				raise
		else:
			if sys.platform in ('win32', 'win64'):
				self.fail("silent_mode unexpectedly supported on windows")
			elif self.cluster.installation.version_info[:2] >= (9, 2):
				self.fail("silent_mode unexpectedly supported on PostgreSQL >=9.2")

	def testSuperPassword(self):
		self.init(
			user = 'test',
			password = 'secret',
			logfile = sys.stdout,
		)
		self.start_cluster()
		c = self.cluster.connection(
			user = 'test',
			password = 'secret',
			database = 'template1',
		)
		with c:
			self.assertEqual(c.prepare('select 1').first(), 1)

	def testNoParameters(self):
		'simple init and drop'
		self.init()
		self.start_cluster()

if __name__ == '__main__':
	from types import ModuleType
	this = ModuleType("this")
	this.__dict__.update(globals())
	unittest.main(this)
fe-1.1.0/postgresql/test/test_configfile.py
##
# .test.test_configfile
##
import os
import unittest
from io import StringIO
from .. import configfile

sample_config_Aroma = \
"""
##
# A sample config file.
##
# This provides a good = test for alter_config.
#shared_buffers = 4500
search_path = window,$user,public
shared_buffers = 2500
port = 5234
listen_addresses = 'localhost'
listen_addresses = '*'
"""

##
# Winning cases are alteration cases that pair a source
# with the expected result of an alteration.
#
# The first string is the source, the second the
# alterations to make, and the third the expectation.
##
winning_cases = [
	(
		# Two top contenders; the first should be altered, second commented.
		"foo = bar"+os.linesep+"foo = bar",
		{'foo' : 'newbar'},
		"foo = 'newbar'"+os.linesep+"#foo = bar"
	),
	(
		# Two top contenders, first one stays commented
		"#foo = bar"+os.linesep+"foo = bar",
		{'foo' : 'newbar'},
		"#foo = bar"+os.linesep+"foo = 'newbar'"
	),
	(
		# Two top contenders, second one stays commented
		"foo = bar"+os.linesep+"#foo = bar",
		{'foo' : 'newbar'},
		"foo = 'newbar'"+os.linesep+"#foo = bar"
	),
	(
		# Two candidates
		"foo = bar"+os.linesep+"foo = none",
		{'foo' : 'bar'},
		"foo = 'bar'"+os.linesep+"#foo = none"
	),
	(
		# Two candidates, winner should be the first, second gets comment
		"#foo = none"+os.linesep+"foo = bar",
		{'foo' : 'none'},
		"foo = 'none'"+os.linesep+"#foo = bar"
	),
	(
		# Two commented candidates
		"#foo = none"+os.linesep+"#foo = some",
		{'foo' : 'bar'},
		"foo = 'bar'"+os.linesep+"#foo = some"
	),
	(
		# Two commented candidates, the latter a top contender
		"#foo = none"+os.linesep+"#foo = bar",
		{'foo' : 'bar'},
		"#foo = none"+os.linesep+"foo = 'bar'"
	),
	(
		# Replace empty value
		"foo = "+os.linesep,
		{'foo' : 'feh'},
		"foo = 'feh'"
	),
	(
		# Comment value
		"foo = bar",
		{'foo' : None},
		"#foo = bar"
	),
	(
		# Commenting after value
		"foo = val this should be commented",
		{'foo' : 'newval'},
		"foo = 'newval' #this should be commented"
	),
	(
		# Commenting after value
		"#foo = val this should be commented",
		{'foo' : 'newval'},
		"foo = 'newval' #this should be commented"
	),
	(
		# Commenting after quoted value
		"#foo = 'val'foo this should be commented",
		{'foo' : 'newval'},
		"foo = 'newval' #this should be commented"
	),
	(
		# Adjacent post-value comment
		"#foo = 'val'#foo this should be commented",
		{'foo' : 'newval'},
		"foo = 'newval'#foo this should be commented"
	),
	(
		# New setting in empty string
		"",
		{'bar' : 'newvar'},
		"bar = 'newvar'",
	),
	(
		# New setting
		"foo = 'bar'",
		{'bar' : 'newvar'},
		"foo = 'bar'"+os.linesep+"bar = 'newvar'",
	),
	(
		# New setting with quote escape
		"foo = 'bar'",
		{'bar' : "new'var"},
		"foo = 'bar'"+os.linesep+"bar = 'new''var'",
	),
]

class test_configfile(unittest.TestCase):
	def parseNone(self, line):
		sl = configfile.parse_line(line)
		if sl is not None:
			self.fail(
				"With line %r, parsed out to %r, %r, and %r, %r, " \
				"but expected None to be returned by parse function." %(
					line,
					line[sl[0]], sl[0],
					line[sl[1]], sl[1]
				)
			)

	def parseExpect(self, line, key, val):
		line = line %(key, val)
		sl = configfile.parse_line(line)
		if sl is None:
			self.fail(
				"expecting %r and %r from line %r, " \
				"but got None (syntax error) instead." %(
					key, val, line
				)
			)
		k, v = sl
		if line[k] != key:
			self.fail(
				"expecting key %r for line %r, " \
				"but got %r from %r instead." %(
					key, line, line[k], k
				)
			)
		if line[v] != val:
			self.fail(
				"expecting value %r for line %r, " \
				"but got %r from %r instead." %(
					val, line, line[v], v
				)
			)

	def testParser(self):
		self.parseExpect("#%s = %s", 'foo', 'none')
		self.parseExpect("#%s=%s"+os.linesep, 'foo', 'bar')
		self.parseExpect(" #%s=%s"+os.linesep, 'foo', 'bar')
		self.parseExpect('%s =%s'+os.linesep, 'foo', 'bar')
		self.parseExpect(' %s=%s '+os.linesep, 'foo', 'Bar')
		self.parseExpect(' %s = %s '+os.linesep, 'foo', 'Bar')
		self.parseExpect('# %s = %s '+os.linesep, 'foo', 'Bar')
		self.parseExpect('\t # %s = %s '+os.linesep, 'foo', 'Bar')
		self.parseExpect(' # %s = %s '+os.linesep, 'foo', 'Bar')
		self.parseExpect(" # %s = %s"+os.linesep, 'foo', "' Bar '")
		self.parseExpect("%s = %s# comment"+os.linesep, 'foo', '')
		self.parseExpect(" # %s = %s # A # comment"+os.linesep, 'foo', "' B''a#r '")
		# No equality, or equality in complex comment
		self.parseNone(' #i # foo = Bar '+os.linesep)
		self.parseNone('#bar')
		self.parseNone('bar')

	def testConfigRead(self):
		sample = "foo = bar"+os.linesep+"# A comment, yes."+os.linesep+" bar = foo # yet?"+os.linesep
		d = configfile.read_config(sample.split(os.linesep))
		self.assertTrue(d['foo'] == 'bar')
		self.assertTrue(d['bar'] == 'foo')

	def testConfigWriteRead(self):
		strio = StringIO()
		d = {
			'' : "'foo bar'"
		}
		configfile.write_config(d, strio.write)
		strio.seek(0)

	def testWinningCases(self):
		i = 0
		for before, alters, after in winning_cases:
			befg = (x + os.linesep for x in before.split(os.linesep))
			became = ''.join(configfile.alter_config(alters, befg))
			self.assertTrue(
				became.strip() == after,
				'On %d, before, %r, did not become after, %r; got %r using %r' %(
					i, before, after, became, alters
				)
			)
			i += 1

	def testSimpleConfigAlter(self):
		# Simple set, and uncomment-and-set test.
strio = StringIO() strio.write("foo = bar"+os.linesep+" # bleh = unset"+os.linesep+" # grr = 'oh yeah''s'") strio.seek(0) lines = configfile.alter_config({'foo' : 'yes', 'bleh' : 'feh'}, strio) d = configfile.read_config(lines) self.assertTrue(d['foo'] == 'yes') self.assertTrue(d['bleh'] == 'feh') self.assertTrue(''.join(lines).count('bleh') == 1) def testAroma(self): lines = configfile.alter_config({ 'shared_buffers' : '800', 'port' : None }, (x + os.linesep for x in sample_config_Aroma.split('\n')) ) d = configfile.read_config(lines) self.assertTrue(d['shared_buffers'] == '800') self.assertTrue(d.get('port') is None) nlines = configfile.alter_config({'port' : '1'}, lines) d2 = configfile.read_config(nlines) self.assertTrue(d2.get('port') == '1') self.assertTrue( nlines[:4] == lines[:4] ) def testSelection(self): # Sanity red = configfile.read_config(['foo = bar'+os.linesep, 'bar = foo']) self.assertTrue(len(red.keys()) == 2) # Test a simple selector red = configfile.read_config(['foo = bar'+os.linesep, 'bar = foo'], selector = lambda x: x == 'bar') rkeys = list(red.keys()) self.assertTrue(len(rkeys) == 1) self.assertTrue(rkeys[0] == 'bar') self.assertTrue(red['bar'] == 'foo') if __name__ == '__main__': unittest.main() fe-1.1.0/postgresql/test/test_connect.py000066400000000000000000000316211203372773200203400ustar00rootroot00000000000000## # .test.test_connect ## import sys import os import unittest import atexit import socket import errno from ..python.socket import find_available_port from .. import installation from .. import cluster as pg_cluster from .. import exceptions as pg_exc from ..driver import dbapi20 as dbapi20 from .. import driver as pg_driver from .. import open as pg_open def check_for_ipv6(): result = False if socket.has_ipv6: try: socket.socket(socket.AF_INET6, socket.SOCK_STREAM) result = True except socket.error as e: errs = [errno.EAFNOSUPPORT] WSAEAFNOSUPPORT = getattr(errno, 'WSAEAFNOSUPPORT', None) if WSAEAFNOSUPPORT is not None: errs.append(WSAEAFNOSUPPORT) if e.errno not in errs: raise return result msw = sys.platform in ('win32', 'win64') # win32 binaries don't appear to be built with ipv6 has_ipv6 = check_for_ipv6() and not msw has_unix_sock = not msw class TestCaseWithCluster(unittest.TestCase): """ postgresql.driver *interface* tests. """ def __init__(self, *args, **kw): super().__init__(*args, **kw) self.installation = installation.default() self.cluster_path = \ 'py_unittest_pg_cluster_' \ + str(os.getpid()) + getattr(self, 'cluster_path_suffix', '') if self.installation is None: sys.stderr.write("ERROR: cannot find 'default' pg_config\n") sys.stderr.write( "HINT: set the PGINSTALLATION environment variable to the `pg_config` path\n" ) sys.exit(1) self.cluster = pg_cluster.Cluster( self.installation, self.cluster_path, ) if self.cluster.initialized(): self.cluster.drop() def configure_cluster(self): self.cluster_port = find_available_port() if self.cluster_port is None: pg_exc.ClusterError( 'failed to find a port for the test cluster on localhost', creator = self.cluster ).raise_exception() listen_addresses = '127.0.0.1' if has_ipv6: listen_addresses += ',::1' self.cluster.settings.update(dict( port = str(self.cluster_port), max_connections = '6', shared_buffers = '24', listen_addresses = listen_addresses, log_destination = 'stderr', log_min_messages = 'FATAL', unix_socket_directory = self.cluster.data_directory, )) # 8.4 turns prepared transactions off by default. 
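		# find_available_port (imported above from ..python.socket) picked the
		# free TCP port assigned to self.cluster_port earlier in this method.
		# A minimal stdlib-only sketch of the usual technique, for reference:
		#
		#   import socket
		#   s = socket.socket()
		#   s.bind(('localhost', 0))      # port 0: the OS chooses a free port
		#   port = s.getsockname()[1]
		#   s.close()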
if self.cluster.installation.version_info >= (8,1): self.cluster.settings.update(dict( max_prepared_transactions = '3', )) def initialize_database(self): c = self.cluster.connection( user = 'test', database = 'template1', ) with c: if c.prepare( "select true from pg_catalog.pg_database " \ "where datname = 'test'" ).first() is None: c.execute('create database test') def connection(self, *args, **kw): return self.cluster.connection(*args, user = 'test', **kw) def run(self, *args, **kw): if not self.cluster.initialized(): self.cluster.encoding = 'utf-8' self.cluster.init( user = 'test', encoding = self.cluster.encoding, logfile = None, ) sys.stderr.write('*') try: atexit.register(self.cluster.drop) self.configure_cluster() self.cluster.start(logfile = sys.stdout) self.cluster.wait_until_started() self.initialize_database() except Exception: self.cluster.drop() atexit.unregister(self.cluster.drop) raise if not self.cluster.running(): self.cluster.start() self.cluster.wait_until_started() db = self.connection() with db: self.db = db return super().run(*args, **kw) self.db = None class test_connect(TestCaseWithCluster): """ postgresql.driver connectivity tests """ ip6 = '::1' ip4 = '127.0.0.1' host = 'localhost' params = {} cluster_path_suffix = '_test_connect' def __init__(self, *args, **kw): super().__init__(*args,**kw) # 8.4 nixed this. self.do_crypt = self.cluster.installation.version_info < (8,4) def configure_cluster(self): super().configure_cluster() self.cluster.settings.update({ 'log_min_messages' : 'log', }) # Configure the hba file with the supported methods. with open(self.cluster.hba_file, 'w') as hba: hosts = ['0.0.0.0/0',] if has_ipv6: hosts.append('0::0/0') methods = ['md5', 'password'] + (['crypt'] if self.do_crypt else []) for h in hosts: for m in methods: # user and method are the same name. hba.writelines(['host test {m} {h} {m}\n'.format( h = h, m = m )]) # trusted hba.writelines(["local all all trust\n"]) hba.writelines(["host test trusted 0.0.0.0/0 trust\n"]) if has_ipv6: hba.writelines(["host test trusted 0::0/0 trust\n"]) # admin lines hba.writelines(["host all test 0.0.0.0/0 trust\n"]) if has_ipv6: hba.writelines(["host all test 0::0/0 trust\n"]) def initialize_database(self): super().initialize_database() with self.cluster.connection(user = 'test') as db: db.execute( """ CREATE USER md5 WITH ENCRYPTED PASSWORD 'md5_password' ; -- crypt doesn't work with encrypted passwords: -- http://www.postgresql.org/docs/8.2/interactive/auth-methods.html#AUTH-PASSWORD CREATE USER crypt WITH UNENCRYPTED PASSWORD 'crypt_password' ; CREATE USER password WITH ENCRYPTED PASSWORD 'password_password' ; CREATE USER trusted; """ ) def test_pg_open_SQL_ASCII(self): # postgresql.open host, port = self.cluster.address() # test simple locators.. 
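		# Anatomy of the locators assembled below (the general PQ IRI form used
		# throughout these tests; trailing ?name=value entries become settings):
		#
		#   pq://user:password@host:port/database?client_encoding=SQL_ASCII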
with pg_open( 'pq://' + 'md5:' + 'md5_password@' + host + ':' + str(port) \ + '/test?client_encoding=SQL_ASCII' ) as db: self.assertEqual(db.prepare('select 1')(), [(1,)]) self.assertEqual(db.settings['client_encoding'], 'SQL_ASCII') self.assertTrue(db.closed) def test_pg_open_keywords(self): host, port = self.cluster.address() # straight test, no IRI with pg_open( user = 'md5', password = 'md5_password', host = host, port = port, database = 'test' ) as db: self.assertEqual(db.prepare('select 1')(), [(1,)]) self.assertTrue(db.closed) # composite test with pg_open( "pq://md5:md5_password@", host = host, port = port, database = 'test' ) as db: self.assertEqual(db.prepare('select 1')(), [(1,)]) # override test with pg_open( "pq://md5:foobar@", password = 'md5_password', host = host, port = port, database = 'test' ) as db: self.assertEqual(db.prepare('select 1')(), [(1,)]) # and, one with some settings with pg_open( "pq://md5:foobar@?search_path=ieeee", password = 'md5_password', host = host, port = port, database = 'test', settings = {'search_path' : 'public'} ) as db: self.assertEqual(db.prepare('select 1')(), [(1,)]) self.assertEqual(db.settings['search_path'], 'public') def test_pg_open(self): # postgresql.open host, port = self.cluster.address() # test simple locators.. with pg_open( 'pq://' + 'md5:' + 'md5_password@' + host + ':' + str(port) \ + '/test' ) as db: self.assertEqual(db.prepare('select 1')(), [(1,)]) self.assertTrue(db.closed) with pg_open( 'pq://' + 'password:' + 'password_password@' + host + ':' + str(port) \ + '/test' ) as db: self.assertEqual(db.prepare('select 1')(), [(1,)]) self.assertTrue(db.closed) with pg_open( 'pq://' + 'trusted@' + host + ':' + str(port) + '/test' ) as db: self.assertEqual(db.prepare('select 1')(), [(1,)]) self.assertTrue(db.closed) # test environment collection pgenv = ('PGUSER', 'PGPORT', 'PGHOST', 'PGSERVICE', 'PGPASSWORD', 'PGDATABASE') stored = list(map(os.environ.get, pgenv)) try: os.environ.pop('PGSERVICE', None) os.environ['PGUSER'] = 'md5' os.environ['PGPASSWORD'] = 'md5_password' os.environ['PGHOST'] = host os.environ['PGPORT'] = str(port) os.environ['PGDATABASE'] = 'test' # No arguments, the environment provided everything. 
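			# Precedence, as exercised by this test: explicit pg_open()
			# keywords override IRI components, which in turn override the
			# PG* environment variables collected here (PGUSER, PGPASSWORD,
			# PGHOST, PGPORT, PGDATABASE, PGSERVICE).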
with pg_open() as db: self.assertEqual(db.prepare('select 1')(), [(1,)]) self.assertEqual(db.prepare('select current_user').first(), 'md5') self.assertTrue(db.closed) finally: i = 0 for x in stored: env = pgenv[i] if x is None: os.environ.pop(env, None) else: os.environ[env] = x oldservice = os.environ.get('PGSERVICE') oldsysconfdir = os.environ.get('PGSYSCONFDIR') try: with open('pg_service.conf', 'w') as sf: sf.write(''' [myserv] user = password password = password_password host = {host} port = {port} dbname = test search_path = public '''.format(host = host, port = port)) sf.flush() try: os.environ['PGSERVICE'] = 'myserv' os.environ['PGSYSCONFDIR'] = os.getcwd() with pg_open() as db: self.assertEqual(db.prepare('select 1')(), [(1,)]) self.assertEqual(db.prepare('select current_user').first(), 'password') self.assertEqual(db.settings['search_path'], 'public') finally: if oldservice is None: os.environ.pop('PGSERVICE', None) else: os.environ['PGSERVICE'] = oldservice if oldsysconfdir is None: os.environ.pop('PGSYSCONFDIR', None) else: os.environ['PGSYSCONFDIR'] = oldsysconfdir finally: if os.path.exists('pg_service.conf'): os.remove('pg_service.conf') def test_dbapi_connect(self): host, port = self.cluster.address() MD5 = dbapi20.connect( user = 'md5', database = 'test', password = 'md5_password', host = host, port = port, **self.params ) self.assertEqual(MD5.cursor().execute('select 1').fetchone()[0], 1) MD5.close() self.assertRaises(pg_exc.ConnectionDoesNotExistError, MD5.cursor().execute, 'select 1' ) if self.do_crypt: CRYPT = dbapi20.connect( user = 'crypt', database = 'test', password = 'crypt_password', host = host, port = port, **self.params ) self.assertEqual(CRYPT.cursor().execute('select 1').fetchone()[0], 1) CRYPT.close() self.assertRaises(pg_exc.ConnectionDoesNotExistError, CRYPT.cursor().execute, 'select 1' ) PASSWORD = dbapi20.connect( user = 'password', database = 'test', password = 'password_password', host = host, port = port, **self.params ) self.assertEqual(PASSWORD.cursor().execute('select 1').fetchone()[0], 1) PASSWORD.close() self.assertRaises(pg_exc.ConnectionDoesNotExistError, PASSWORD.cursor().execute, 'select 1' ) TRUST = dbapi20.connect( user = 'trusted', database = 'test', password = '', host = host, port = port, **self.params ) self.assertEqual(TRUST.cursor().execute('select 1').fetchone()[0], 1) TRUST.close() self.assertRaises(pg_exc.ConnectionDoesNotExistError, TRUST.cursor().execute, 'select 1' ) def test_IP4_connect(self): C = pg_driver.default.ip4( user = 'test', host = '127.0.0.1', database = 'test', port = self.cluster.address()[1], **self.params ) with C() as c: self.assertEqual(c.prepare('select 1').first(), 1) if has_ipv6: def test_IP6_connect(self): C = pg_driver.default.ip6( user = 'test', host = '::1', database = 'test', port = self.cluster.address()[1], **self.params ) with C() as c: self.assertEqual(c.prepare('select 1').first(), 1) def test_Host_connect(self): C = pg_driver.default.host( user = 'test', host = 'localhost', database = 'test', port = self.cluster.address()[1], **self.params ) with C() as c: self.assertEqual(c.prepare('select 1').first(), 1) def test_md5_connect(self): c = self.cluster.connection( user = 'md5', password = 'md5_password', database = 'test', **self.params ) with c: self.assertEqual(c.prepare('select current_user').first(), 'md5') def test_crypt_connect(self): if self.do_crypt: c = self.cluster.connection( user = 'crypt', password = 'crypt_password', database = 'test', **self.params ) with c: 
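			# Pattern used by the connector tests here: pg_driver.default.ip4/
			# ip6/host/unix(...) builds a reusable Connector, and each C() call
			# opens a fresh connection. A sketch under the same API:
			#
			#   C = pg_driver.default.host(
			#       user = 'test', host = 'localhost',
			#       port = 5432, database = 'test',
			#   )
			#   with C() as c1, C() as c2:
			#       assert c1.prepare('select 1').first() == 1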
self.assertEqual(c.prepare('select current_user').first(), 'crypt') def test_password_connect(self): c = self.cluster.connection( user = 'password', password = 'password_password', database = 'test', ) with c: self.assertEqual(c.prepare('select current_user').first(), 'password') def test_trusted_connect(self): c = self.cluster.connection( user = 'trusted', password = '', database = 'test', **self.params ) with c: self.assertEqual(c.prepare('select current_user').first(), 'trusted') def test_Unix_connect(self): if not has_unix_sock: return unix_domain_socket = os.path.join( self.cluster.data_directory, '.s.PGSQL.' + self.cluster.settings['port'] ) C = pg_driver.default.unix( user = 'test', unix = unix_domain_socket, ) with C() as c: self.assertEqual(c.prepare('select 1').first(), 1) self.assertEqual(c.client_address, None) def test_pg_open_unix(self): if not has_unix_sock: return unix_domain_socket = os.path.join( self.cluster.data_directory, '.s.PGSQL.' + self.cluster.settings['port'] ) with pg_open(unix = unix_domain_socket, user = 'test') as c: self.assertEqual(c.prepare('select 1').first(), 1) self.assertEqual(c.client_address, None) with pg_open('pq://test@[unix:' + unix_domain_socket.replace('/',':') + ']') as c: self.assertEqual(c.prepare('select 1').first(), 1) self.assertEqual(c.client_address, None) if __name__ == '__main__': unittest.main() fe-1.1.0/postgresql/test/test_copyman.py000066400000000000000000000446451203372773200203670ustar00rootroot00000000000000## # .test.test_copyman - test .copyman ## import unittest from itertools import islice from .. import copyman from ..temporal import pg_tmp # The asyncs, and alternative termination. from ..protocol.element3 import Notice, Notify, Error, cat_messages from .. import exceptions as pg_exc # state manager can handle empty data messages, right? =) emptysource = """ CREATE TEMP TABLE emptysource (); -- 10 INSERT INTO emptysource DEFAULT VALUES; INSERT INTO emptysource DEFAULT VALUES; INSERT INTO emptysource DEFAULT VALUES; INSERT INTO emptysource DEFAULT VALUES; INSERT INTO emptysource DEFAULT VALUES; INSERT INTO emptysource DEFAULT VALUES; INSERT INTO emptysource DEFAULT VALUES; INSERT INTO emptysource DEFAULT VALUES; INSERT INTO emptysource DEFAULT VALUES; INSERT INTO emptysource DEFAULT VALUES; """ emptydst = "CREATE TEMP TABLE empty ();" # The usual subjects. stdrowcount = 10000 stdsource = """ CREATE TEMP TABLE source (i int, t text); INSERT INTO source SELECT i, i::text AS t FROM generate_series(1, {0}) AS g(i); """.format(stdrowcount) stditer = [ b'\t'.join((x, x)) + b'\n' for x in ( str(i).encode('ascii') for i in range(1, 10001) ) ] stditer_tuples = [ (x, str(x)) for x in range(1, 10001) ] stddst = "CREATE TEMP TABLE destination (i int, t text)" srcsql = "COPY source TO STDOUT" dstsql = "COPY destination FROM STDIN" binary_srcsql = "COPY source TO STDOUT WITH BINARY" binary_dstsql = "COPY destination FROM STDIN WITH BINARY" dstcount = "SELECT COUNT(*) FROM destination" grabdst = "SELECT * FROM destination ORDER BY i ASC" grabsrc = "SELECT * FROM source ORDER BY i ASC" ## # This subclass is used to append some arbitrary data # after the initial data. This is used to exercise async/notice support. 
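##
# cat_messages (imported above from ..protocol.element3) serializes a sequence
# of protocol messages into raw wire bytes, as the server would have written
# them. Injector (below) prepends such bytes to the connection's buffered COPY
# data, so the copy state machine must cope with asynchronous messages arriving
# mid-stream. A hedged sketch:
#
#   wire = cat_messages([Notify(1234, b'channel', b'payload')])
#   # 'wire' can now be spliced ahead of buffered CopyData messages.
##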
class Injector(copyman.StatementProducer): def __init__(self, appended_messages, *args, **kw): super().__init__(*args, **kw) self._appended_messages = appended_messages def confiscate(self): pq = self.statement.database.pq mb = pq.message_buffer b = mb.getvalue() mb.truncate() mb.write(cat_messages(self._appended_messages)) mb.write(b) return super().confiscate() class test_copyman(unittest.TestCase): def testNull(self): # Test some of the basic machinery. sp = copyman.NullProducer() sr = copyman.NullReceiver() copyman.CopyManager(sp, sr).run() self.assertEqual(sp.total_messages, 0) self.assertEqual(sp.total_bytes, 0) @pg_tmp def testNullProducer(self): sqlexec(stddst) np = copyman.NullProducer() sr = copyman.StatementReceiver(prepare(dstsql)) copyman.CopyManager(np, sr).run() self.assertEqual(np.total_messages, 0) self.assertEqual(np.total_bytes, 0) self.assertEqual(prepare(dstcount).first(), 0) self.assertEqual(prepare(grabdst)(), []) @pg_tmp def testNullReceiver(self): sqlexec(stdsource) sp = copyman.StatementProducer(prepare(srcsql), buffer_size = 128) sr = copyman.NullReceiver() with copyman.CopyManager(sp, sr) as copy: for x in copy: pass self.assertEqual(sp.total_messages, stdrowcount) self.assertEqual(sp.total_bytes > 0, True) def testIteratorToCall(self): tmp = iter(stditer) # segment stditer into chunks consisting of twenty rows each sp = copyman.IteratorProducer([ list(islice(tmp, 20)) for x in range(len(stditer) // 20) ]) dest = [] sr = copyman.CallReceiver(dest.extend) recomputed_bytes = 0 recomputed_messages = 0 with copyman.CopyManager(sp, sr) as copy: for msg, bytes in copy: recomputed_messages += msg recomputed_bytes += bytes self.assertEqual(stdrowcount, recomputed_messages) self.assertEqual(recomputed_bytes, sp.total_bytes) self.assertEqual(len(dest), stdrowcount) self.assertEqual(dest, stditer) @pg_tmp def testDirectStatements(self): sqlexec(stdsource) dst = new() dst.execute(stddst) sp = copyman.StatementProducer(prepare(srcsql), buffer_size = 512) sr = copyman.StatementReceiver(dst.prepare(dstsql)) with copyman.CopyManager(sp, sr) as copy: for x in copy: pass self.assertEqual(dst.prepare(dstcount).first(), stdrowcount) self.assertEqual(dst.prepare(grabdst)(), prepare(grabsrc)()) @pg_tmp def testIteratorProducer(self): sqlexec(stddst) sp = copyman.IteratorProducer([stditer]) sr = copyman.StatementReceiver(prepare(dstsql)) recomputed_bytes = 0 recomputed_messages = 0 with copyman.CopyManager(sp, sr) as copy: for msg, bytes in copy: recomputed_messages += msg recomputed_bytes += bytes self.assertEqual(stdrowcount, recomputed_messages) self.assertEqual(recomputed_bytes, sp.total_bytes) self.assertEqual(prepare(dstcount).first(), stdrowcount) self.assertEqual(prepare(grabdst)(), stditer_tuples) def multiple_destinations(self, count = 3, binary = False, buffer_size = 129): if binary: src = binary_srcsql dst = binary_dstsql # accommodate for the binary header. 
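			# COPY ... WITH BINARY emits a header message before any rows, so
			# the producer observes one extra message; count_offset compensates
			# when the totals are compared against stdrowcount below.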
count_offset = 1 else: src = srcsql dst = dstsql count_offset = 0 sqlexec(stdsource) dests = [new() for x in range(count)] receivers = [] for x in dests: x.execute(stddst) receivers.append(copyman.StatementReceiver(x.prepare(dst))) sp = copyman.StatementProducer(prepare(src), buffer_size = buffer_size) recomputed_bytes = 0 recomputed_messages = 0 with copyman.CopyManager(sp, *receivers) as copy: for msg, bytes in copy: recomputed_messages += msg recomputed_bytes += bytes src_snap = prepare(grabsrc)() for x in dests: self.assertEqual(x.prepare(dstcount).first(), stdrowcount) self.assertEqual(x.prepare(grabdst)(), src_snap) self.assertEqual(stdrowcount + count_offset, recomputed_messages) self.assertEqual(recomputed_bytes, sp.total_bytes) @pg_tmp def testMultipleStatements(self): self.multiple_destinations() @pg_tmp def testMultipleStatementsBinary(self): self.multiple_destinations(binary = True) @pg_tmp def testMultipleStatementsSmallBuffer(self): self.multiple_destinations(buffer_size = 11) @pg_tmp def testNotices(self): # Inject a Notices directly into the stream to emulate # cases of asynchronous messages received during COPY. notices = [ Notice(( (b'S', b'NOTICE'), (b'C', b'00000'), (b'M', b'It\'s a beautiful day.'), )), Notice(( (b'S', b'WARNING'), (b'C', b'01X1X1'), (b'M', b'FAILURE IS CERTAIN'), )) ] sqlexec(stdsource) dst = new() dst.execute(stddst) # hook for notices.. rmessages = [] def hook(msg): rmessages.append(msg) # suppress return True stmt = prepare(srcsql) stmt.msghook = hook sp = Injector(notices, stmt, buffer_size = 133) sr = copyman.StatementReceiver(dst.prepare(dstsql)) seen_in_loop = 0 with copyman.CopyManager(sp, sr) as copy: for x in copy: if rmessages: # Should get hooked before the COPY is over. seen_in_loop += 1 self.assertTrue(seen_in_loop > 0) self.assertEqual(dst.prepare(dstcount).first(), stdrowcount) self.assertEqual(dst.prepare(grabdst)(), prepare(grabsrc)()) # The injector adds then everytime the wire data is confiscated # from the protocol connection. notice, warning = rmessages[:2] self.assertEqual(notice.code, "00000") self.assertEqual(warning.code, "01X1X1") self.assertEqual(warning.details['severity'], "WARNING") self.assertEqual(notice.message, "It's a beautiful day.") self.assertEqual(warning.message, "FAILURE IS CERTAIN") self.assertEqual(notice.details['severity'], "NOTICE") @pg_tmp def testAsyncNotify(self): # Inject a NOTIFY directly into the stream to emulate # cases of asynchronous messages received during COPY. notify = [Notify(1234, b'channel', b'payload')] sqlexec(stdsource) dst = new() dst.execute(stddst) sp = Injector(notify, prepare(srcsql), buffer_size = 32) sr = copyman.StatementReceiver(dst.prepare(dstsql)) seen_in_loop = 0 r = [] with copyman.CopyManager(sp, sr) as copy: for x in copy: r += list(db.iternotifies(0)) # Got the injected NOTIFY's, right? self.assertTrue(r) # it may have happened multiple times, so adjust accordingly. 
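		# db.iternotifies(timeout), polled inside the COPY loop above, yields
		# queued notifications without blocking when timeout is 0, each as a
		# (channel, payload, pid) tuple -- the shape checked just below.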
self.assertEqual(r, [('channel', 'payload', 1234)]*len(r)) @pg_tmp def testUnfinishedCopy(self): sqlexec(stdsource) dst = new() dst.execute(stddst) sp = copyman.StatementProducer(prepare(srcsql), buffer_size = 32) sr = copyman.StatementReceiver(dst.prepare(dstsql)) try: with copyman.CopyManager(sp, sr) as copy: for x in copy: break self.fail("did not raise CopyFail") except copyman.CopyFail: pass @pg_tmp def testRaiseInCopy(self): sqlexec(stdsource) dst = new() dst.execute(stddst) sp = copyman.StatementProducer(prepare(srcsql), buffer_size = 128) sr = copyman.StatementReceiver(dst.prepare(dstsql)) i = 0 class ThisError(Exception): pass try: with copyman.CopyManager(sp, sr) as copy: for x in copy: # Note, the state of the receiver has changed. # We may not be on a message boundary, so this test # exercises cases where an interrupt occurs where # re-alignment *may* need to occur. raise ThisError() except copyman.CopyFail as cf: # It's a copy failure, but due to ThisError. self.assertTrue(isinstance(cf.__context__, ThisError)) else: self.fail("didn't raise CopyFail") # Connections should be usable. self.assertEqual(prepare('select 1').first(), 1) self.assertEqual(dst.prepare('select 1').first(), 1) @pg_tmp def testRaiseInCopyOnEnter(self): sqlexec(stdsource) dst = new() dst.execute(stddst) sp = copyman.StatementProducer(prepare(srcsql), buffer_size = 128) sr = copyman.StatementReceiver(dst.prepare(dstsql)) i = 0 class ThatError(Exception): pass try: with copyman.CopyManager(sp, sr) as copy: raise ThatError() except copyman.CopyFail as cf: # yeah; error on incomplete COPY self.assertTrue(isinstance(cf.__context__, ThatError)) else: self.fail("didn't raise CopyFail") @pg_tmp def testCopyWithFailure(self): sqlexec(stdsource) dst = new() dst2 = new() dst.execute(stddst) dst2.execute(stddst) sp = copyman.StatementProducer(prepare(srcsql), buffer_size = 128) sr1 = copyman.StatementReceiver(dst.prepare(dstsql)) sr2 = copyman.StatementReceiver(dst2.prepare(dstsql)) done = False with copyman.CopyManager(sp, sr1, sr2) as copy: while True: try: for x in copy: if not done: done = True dst2.pq.socket.close() else: # Done with copy. 
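					# ReceiverFault (caught below) maps each failed receiver to
					# its exception via cf.faults; resuming iteration afterwards
					# continues the COPY with whichever receivers the manager
					# still holds.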
break except copyman.ReceiverFault as cf: if sr2 not in cf.faults: raise self.assertTrue(done) self.assertRaises(Exception, dst2.execute, 'select 1') self.assertEqual(dst.prepare(dstcount).first(), stdrowcount) self.assertEqual(dst.prepare(grabdst)(), prepare(grabsrc)()) @pg_tmp def testEmptyRows(self): sqlexec(emptysource) dst = new() dst.execute(emptydst) sp = copyman.StatementProducer(prepare("COPY emptysource TO STDOUT"), buffer_size = 127) sr = copyman.StatementReceiver(dst.prepare("COPY empty FROM STDIN")) m = 0 b = 0 with copyman.CopyManager(sp, sr) as copy: for x in copy: nmsg, nbytes = x m += nmsg b += nbytes self.assertEqual(m, 10) self.assertEqual(prepare("SELECT COUNT(*) FROM emptysource").first(), 10) self.assertEqual(dst.prepare("SELECT COUNT(*) FROM empty").first(), 10) self.assertEqual(sr.count(), 10) self.assertEqual(sp.count(), 10) @pg_tmp def testCopyOne(self): from io import BytesIO b = BytesIO() copyman.transfer( prepare('COPY (SELECT 1) TO STDOUT'), copyman.CallReceiver(b.writelines) ) b.seek(0) self.assertEqual(b.read(), b'1\n') @pg_tmp def testCopyNone(self): from io import BytesIO b = BytesIO() copyman.transfer( prepare('COPY (SELECT 1 LIMIT 0) TO STDOUT'), copyman.CallReceiver(b.writelines) ) b.seek(0) self.assertEqual(b.read(), b'') @pg_tmp def testNoReceivers(self): sqlexec(stdsource) dst = new() dst.execute(stddst) sp = copyman.StatementProducer(prepare(srcsql)) sr1 = copyman.StatementReceiver(dst.prepare(dstsql)) done = False try: with copyman.CopyManager(sp, sr1) as copy: while not done: try: for x in copy: if not done: done = True dst.pq.socket.close() else: self.fail("failed to detect dead socket") except copyman.ReceiverFault as cf: self.assertTrue(sr1 in cf.faults) # Don't reconcile. Let the manager drop the receiver. except copyman.CopyFail: self.assertTrue(not bool(copy.receivers)) # Success. else: self.fail("did not raise expected error") # Let the exception cause a failure. self.assertTrue(done) @pg_tmp def testReconciliation(self): # cm.reconcile() test. sqlexec(stdsource) dst = new() dst.execute(stddst) sp = copyman.StatementProducer(prepare(srcsql), buffer_size = 201) sr = copyman.StatementReceiver(dst.prepare(dstsql)) original_call = sr.send class RecoverableError(Exception): pass def failed_write(*args): sr.send = original_call raise RecoverableError() sr.send = failed_write done = False recomputed_messages = 0 recomputed_bytes = 0 with copyman.CopyManager(sp, sr) as copy: while copy.receivers: try: for nmsg, nbytes in copy: recomputed_messages += nmsg recomputed_bytes += nbytes else: # Done with COPY, break out of while copy.receivers. break except copyman.ReceiverFault as cf: if isinstance(cf.faults[sr], RecoverableError): if done is True: self.fail("failed_write was called twice?") done = True self.assertEqual(len(copy.receivers), 0) copy.reconcile(sr) self.assertEqual(len(copy.receivers), 1) self.assertEqual(done, True) # Connections should be usable. self.assertEqual(prepare('select 1').first(), 1) self.assertEqual(dst.prepare('select 1').first(), 1) # validate completion self.assertEqual(stdrowcount, recomputed_messages) self.assertEqual(recomputed_bytes, sp.total_bytes) self.assertEqual(dst.prepare(dstcount).first(), stdrowcount) @pg_tmp def testDroppedConnection(self): # no cm.reconcile() test. 
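		# Contrast with testReconciliation above: there the faulted receiver is
		# re-admitted with copy.reconcile(sr); here the fault is left alone, so
		# the manager drops the receiver and the COPY finishes with the other
		# one. A hedged outline of the recovery pattern:
		#
		#   except copyman.ReceiverFault as cf:
		#       if is_recoverable(cf.faults[sr]):   # is_recoverable: hypothetical
		#           copy.reconcile(sr)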
sqlexec(stdsource) dst = new() dst2 = new() dst2.execute(stddst) dst.execute(stddst) sp = copyman.StatementProducer(prepare(srcsql), buffer_size = 201) sr1 = copyman.StatementReceiver(dst.prepare(dstsql)) sr2 = copyman.StatementReceiver(dst2.prepare(dstsql)) class TheCause(Exception): pass def failed_write(*args): raise TheCause() sr2.send = failed_write done = False recomputed_messages = 0 recomputed_bytes = 0 with copyman.CopyManager(sp, sr1, sr2) as copy: while copy.receivers: try: for nmsg, nbytes in copy: recomputed_messages += nmsg recomputed_bytes += nbytes else: # Done with COPY, break out of while copy.receivers. break except copyman.ReceiverFault as cf: self.assertTrue(isinstance(cf.faults[sr2], TheCause)) if done is True: self.fail("failed_write was called twice?") done = True self.assertEqual(len(copy.receivers), 1) dst2.pq.socket.close() # We don't reconcile, so the manager only has one target now. self.assertEqual(done, True) # May not be aligned; really, we're expecting the connection to # have died. self.assertRaises(Exception, dst2.execute, "SELECT 1") # Connections should be usable. self.assertEqual(prepare('select 1').first(), 1) self.assertEqual(dst.prepare('select 1').first(), 1) # validate completion self.assertEqual(stdrowcount, recomputed_messages) self.assertEqual(recomputed_bytes, sp.total_bytes) self.assertEqual(dst.prepare(dstcount).first(), stdrowcount) self.assertEqual(sp.count(), stdrowcount) self.assertEqual(sp.command(), "COPY") @pg_tmp def testProducerFailure(self): sqlexec(stdsource) dst = new() dst.execute(stddst) sp = copyman.StatementProducer(prepare(srcsql)) sr = copyman.StatementReceiver(dst.prepare(dstsql)) done = False try: with copyman.CopyManager(sp, sr) as copy: try: for x in copy: if not done: done = True db.pq.socket.close() except copyman.ProducerFault as pf: self.assertTrue(pf.__context__ is not None) self.fail('expected CopyManager to raise CopyFail') except copyman.CopyFail as cf: # Expecting to see CopyFail self.assertTrue(True) self.assertTrue(isinstance(cf.producer_fault, pg_exc.ConnectionFailureError)) self.assertTrue(done) self.assertRaises(Exception, sqlexec, 'select 1') self.assertEqual(dst.prepare(dstcount).first(), 0) from ..copyman import WireState class test_WireState(unittest.TestCase): def testNormal(self): WS=WireState() messages = WS.update(memoryview(b'd\x00\x00\x00\x04')) self.assertEqual(messages, 1) self.assertEqual(WS.remaining_bytes, 0) self.assertEqual(WS.size_fragment, b'') self.assertEqual(WS.final_view, None) def testIncomplete(self): WS=WireState() messages = WS.update(memoryview(b'd\x00\x00\x00\x05')) self.assertEqual(messages, 0) self.assertEqual(WS.remaining_bytes, 1) self.assertEqual(WS.size_fragment, b'') self.assertEqual(WS.final_view, None) messages = WS.update(memoryview(b'x')) self.assertEqual(messages, 1) self.assertEqual(WS.remaining_bytes, 0) self.assertEqual(WS.size_fragment, b'') self.assertEqual(WS.final_view, None) def testIncompleteHeader_0size(self): WS=WireState() messages = WS.update(memoryview(b'd')) self.assertEqual(messages, 0) self.assertEqual(WS.remaining_bytes, -1) self.assertEqual(WS.size_fragment, b'') self.assertEqual(WS.final_view, None) messages = WS.update(b'\x00\x00\x00\x04') self.assertEqual(messages, 1) def testIncompleteHeader_1size(self): WS=WireState() messages = WS.update(memoryview(b'd\x00')) self.assertEqual(messages, 0) self.assertEqual(WS.size_fragment, b'\x00') self.assertEqual(WS.final_view, None) self.assertEqual(WS.remaining_bytes, -1) messages = 
WS.update(memoryview(b'\x00\x00\x04')) self.assertEqual(messages, 1) self.assertEqual(WS.remaining_bytes, 0) def testIncompleteHeader_2size(self): WS=WireState() messages = WS.update(memoryview(b'd\x00\x00')) self.assertEqual(messages, 0) self.assertEqual(WS.remaining_bytes, -1) self.assertEqual(WS.size_fragment, b'\x00\x00') self.assertEqual(WS.final_view, None) messages = WS.update(b'\x00\x04') self.assertEqual(messages, 1) self.assertEqual(WS.remaining_bytes, 0) def testIncompleteHeader_3size(self): WS=WireState() messages = WS.update(memoryview(b'd\x00\x00\x00')) self.assertEqual(messages, 0) self.assertEqual(WS.remaining_bytes, -1) self.assertEqual(WS.size_fragment, b'\x00\x00\x00') self.assertEqual(WS.final_view, None) messages = WS.update(memoryview(b'\x04')) self.assertEqual(messages, 1) self.assertEqual(WS.remaining_bytes, 0) if __name__ == '__main__': unittest.main() fe-1.1.0/postgresql/test/test_dbapi20.py000066400000000000000000000625331203372773200201360ustar00rootroot00000000000000## # .test.test_dbapi20 - test .driver.dbapi20 ## import unittest import time from ..temporal import pg_tmp ## # Various Adjustments for .driver.dbapi20 # # Log: dbapi20.py # Revision 1.10 2003/10/09 03:14:14 zenzen # Add test for DB API 2.0 optional extension, where database exceptions # are exposed as attributes on the Connection object. # # Revision 1.9 2003/08/13 01:16:36 zenzen # Minor tweak from Stefan Fleiter # # Revision 1.8 2003/04/10 00:13:25 zenzen # Changes, as per suggestions by M.-A. Lemburg # - Add a table prefix, to ensure namespace collisions can always be avoided # # Revision 1.7 2003/02/26 23:33:37 zenzen # Break out DDL into helper functions, as per request by David Rushby # # Revision 1.6 2003/02/21 03:04:33 zenzen # Stuff from Henrik Ekelund: # added test_None # added test_nextset & hooks # # Revision 1.5 2003/02/17 22:08:43 zenzen # Implement suggestions and code from Henrik Eklund - test that cursor.arraysize # defaults to 1 & generic cursor.callproc test added # # Revision 1.4 2003/02/15 00:16:33 zenzen # Changes, as per suggestions and bug reports by M.-A. Lemburg, # Matthew T. Kromer, Federico Di Gregorio and Daniel Dittmar # - Class renamed # - Now a subclass of TestCase, to avoid requiring the driver stub # to use multiple inheritance # - Reversed the polarity of buggy test in test_description # - Test exception heirarchy correctly # - self.populate is now self._populate(), so if a driver stub # overrides self.ddl1 this change propogates # - VARCHAR columns now have a width, which will hopefully make the # DDL even more portible (this will be reversed if it causes more problems) # - cursor.rowcount being checked after various execute and fetchXXX methods # - Check for fetchall and fetchmany returning empty lists after results # are exhausted (already checking for empty lists if select retrieved # nothing # - Fix bugs in test_setoutputsize_basic and test_setinputsizes # class test_dbapi20(unittest.TestCase): """ Test a database self.driver for DB API 2.0 compatibility. This implementation tests Gadfly, but the TestCase is structured so that other self.drivers can subclass this test case to ensure compiliance with the DB-API. It is expected that this TestCase may be expanded in the future if ambiguities or edge conditions are discovered. The 'Optional Extensions' are not yet being tested. self.drivers should subclass this test, overriding setUp, tearDown, self.driver, connect_args and connect_kw_args. 
Class specification should be as follows: import dbapi20 class mytest(dbapi20.DatabaseAPI20Test): [...] __rcs_id__ = 'Id: dbapi20.py,v 1.10 2003/10/09 03:14:14 zenzen Exp' __version__ = 'Revision: 1.10' __author__ = 'Stuart Bishop ' """ import postgresql.driver.dbapi20 as driver table_prefix = 'dbapi20test_' # If you need to specify a prefix for tables booze_name = table_prefix + 'booze' ddl1 = 'create temp table %s (name varchar(20))' % booze_name ddl2 = 'create temp table %sbarflys (name varchar(20))' % table_prefix xddl1 = 'drop table %sbooze' % table_prefix xddl2 = 'drop table %sbarflys' % table_prefix lowerfunc = 'lower' # Name of stored procedure to convert string->lowercase # Some drivers may need to override these helpers, for example adding # a 'commit' after the execute. def executeDDL1(self,cursor): cursor.execute(self.ddl1) def executeDDL2(self,cursor): cursor.execute(self.ddl2) def tearDown(self): con = self._connect() try: cur = con.cursor() for ddl in (self.xddl1, self.xddl2): try: cur.execute(ddl) con.commit() except self.driver.Error: # Assume table didn't exist. Other tests will check if # execute is busted. pass finally: con.close() def _connect(self): pg_tmp.init() host, port = pg_tmp.cluster.address() return self.driver.connect( user = 'test', host = host, port = port, ) def test_connect(self): con = self._connect() con.close() def test_apilevel(self): try: # Must exist apilevel = self.driver.apilevel # Must equal 2.0 self.assertEqual(apilevel,'2.0') except AttributeError: self.fail("Driver doesn't define apilevel") def test_threadsafety(self): try: # Must exist threadsafety = self.driver.threadsafety # Must be a valid value self.assertTrue(threadsafety in (0,1,2,3)) except AttributeError: self.fail("Driver doesn't define threadsafety") def test_paramstyle(self): try: # Must exist paramstyle = self.driver.paramstyle # Must be a valid value self.assertTrue(paramstyle in ( 'qmark','numeric','named','format','pyformat' )) except AttributeError: self.fail("Driver doesn't define paramstyle") def test_Exceptions(self): # Make sure required exceptions exist, and are in the # defined heirarchy. self.assertTrue(issubclass(self.driver.InterfaceError,self.driver.Error)) self.assertTrue(issubclass(self.driver.DatabaseError,self.driver.Error)) self.assertTrue(issubclass(self.driver.OperationalError,self.driver.Error)) self.assertTrue(issubclass(self.driver.IntegrityError,self.driver.Error)) self.assertTrue(issubclass(self.driver.InternalError,self.driver.Error)) self.assertTrue(issubclass(self.driver.ProgrammingError,self.driver.Error)) self.assertTrue(issubclass(self.driver.NotSupportedError,self.driver.Error)) def test_ExceptionsAsConnectionAttributes(self): # OPTIONAL EXTENSION # Test for the optional DB API 2.0 extension, where the exceptions # are exposed as attributes on the Connection object # I figure this optional extension will be implemented by any # driver author who is using this test suite, so it is enabled # by default. 
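		# The optional extension verified here lets code that juggles several
		# DB-API drivers catch errors without importing the producing module,
		# e.g. (sketch):
		#
		#   try:
		#       con.cursor().execute(sql)
		#   except con.DatabaseError:
		#       ...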
con = self._connect() try: drv = self.driver self.assertTrue(con.Warning is drv.Warning) self.assertTrue(con.Error is drv.Error) self.assertTrue(con.InterfaceError is drv.InterfaceError) self.assertTrue(con.DatabaseError is drv.DatabaseError) self.assertTrue(con.OperationalError is drv.OperationalError) self.assertTrue(con.IntegrityError is drv.IntegrityError) self.assertTrue(con.InternalError is drv.InternalError) self.assertTrue(con.ProgrammingError is drv.ProgrammingError) self.assertTrue(con.NotSupportedError is drv.NotSupportedError) finally: con.close() def test_commit(self): con = self._connect() try: # Commit must work, even if it doesn't do anything con.commit() finally: con.close() def test_rollback(self): con = self._connect() # If rollback is defined, it should either work or throw # the documented exception try: if hasattr(con,'rollback'): try: con.rollback() except self.driver.NotSupportedError: pass finally: con.close() def test_cursor(self): con = self._connect() try: cur = con.cursor() finally: con.close() def test_cursor_isolation(self): con = self._connect() try: # Make sure cursors created from the same connection have # the documented transaction isolation level cur1 = con.cursor() cur2 = con.cursor() self.executeDDL1(cur1) cur1.execute("insert into %sbooze values ('Victoria Bitter')" % ( self.table_prefix )) cur2.execute("select name from %sbooze" % self.table_prefix) booze = cur2.fetchall() self.assertEqual(len(booze),1) self.assertEqual(len(booze[0]),1) self.assertEqual(booze[0][0],'Victoria Bitter') finally: con.close() def test_description(self): con = self._connect() try: cur = con.cursor() self.executeDDL1(cur) self.assertEqual(cur.description,None, 'cursor.description should be none after executing a ' 'statement that can return no rows (such as DDL)' ) cur.execute('select name from %sbooze' % self.table_prefix) self.assertEqual(len(cur.description),1, 'cursor.description describes too many columns' ) self.assertEqual(len(cur.description[0]),7, 'cursor.description[x] tuples must have 7 elements' ) self.assertEqual(cur.description[0][0].lower(),'name', 'cursor.description[x][0] must return column name' ) self.assertEqual(cur.description[0][1],self.driver.STRING, 'cursor.description[x][1] must return column type. Got %r' % cur.description[0][1] ) # Make sure self.description gets reset self.executeDDL2(cur) self.assertEqual(cur.description,None, 'cursor.description not being set to None when executing ' 'no-result statements (eg. 
DDL)' ) finally: con.close() def test_rowcount(self): con = self._connect() try: cur = con.cursor() self.executeDDL1(cur) self.assertEqual(cur.rowcount, -1, 'cursor.rowcount should be -1 after executing no-result ' 'statements' ) cur.execute("insert into %sbooze values ('Victoria Bitter')" % ( self.table_prefix )) self.assertEqual(cur.rowcount, 1, 'cursor.rowcount should == number or rows inserted, or ' 'set to -1 after executing an insert statement' ) cur.execute("insert into %sbooze select 'Victoria Bitter' WHERE FALSE" % ( self.table_prefix )) self.assertEqual(cur.rowcount, 0) cur.execute("insert into %sbooze select 'First' UNION ALL select 'second'" % ( self.table_prefix )) self.assertEqual(cur.rowcount, 2) cur.execute("select name from %sbooze" % self.table_prefix) self.assertEqual(cur.rowcount, -1, 'cursor.rowcount should == number of rows returned, or ' 'set to -1 after executing a select statement' ) self.executeDDL2(cur) self.assertEqual(cur.rowcount, -1, 'cursor.rowcount not being reset to -1 after executing ' 'no-result statements' ) finally: con.close() lower_func = 'lower' def test_callproc(self): con = self._connect() try: cur = con.cursor() if self.lower_func and hasattr(cur,'callproc'): r = cur.callproc(self.lower_func,('FOO',)) self.assertEqual(len(r),1) self.assertEqual(r[0],'FOO') r = cur.fetchall() self.assertEqual(len(r),1,'callproc produced no result set') self.assertEqual(len(r[0]),1, 'callproc produced invalid result set' ) self.assertEqual(r[0][0],'foo', 'callproc produced invalid results' ) finally: con.close() def test_close(self): con = self._connect() try: cur = con.cursor() finally: con.close() # cursor.execute should raise an Error if called after connection # closed self.assertRaises(self.driver.Error,self.executeDDL1,cur) # connection.commit should raise an Error if called after connection' # closed.' self.assertRaises(self.driver.Error,con.commit) # connection.close should raise an Error if called more than once self.assertRaises(self.driver.Error,con.close) def test_cursor_close(self): con = self._connect() try: cur = con.cursor() cur.close() # cursor.execute should raise an Error if called after cursor.close # closed self.assertRaises(self.driver.Error,self.executeDDL1,cur) # cursor.executemany should raise an Error if called after connection' # closed.' 
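			# The 'foo' statement and empty parameter list below are throwaway
			# arguments: a closed cursor must raise driver.Error before the SQL
			# or its parameters are ever inspected.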
self.assertRaises(self.driver.Error,cur.executemany,'foo', []) self.assertRaises(self.driver.Error,cur.callproc,'generate_series', [1, 10]) # cursor.close should raise an Error if called more than once self.assertRaises(self.driver.Error,cur.close) finally: con.close() def test_execute(self): con = self._connect() try: cur = con.cursor() self._paraminsert(cur) finally: con.close() def test_format_execute(self): self.driver.paramstyle = 'format' try: self.test_execute() finally: self.driver.paramstyle = 'pyformat' def _paraminsert(self,cur): self.executeDDL1(cur) cur.execute("insert into %sbooze values ('Victoria Bitter')" % ( self.table_prefix )) self.assertTrue(cur.rowcount in (-1,1)) if self.driver.paramstyle == 'qmark': cur.execute( 'insert into %sbooze values (?)' % self.table_prefix, ("Cooper's",) ) elif self.driver.paramstyle == 'numeric': cur.execute( 'insert into %sbooze values (:1)' % self.table_prefix, ("Cooper's",) ) elif self.driver.paramstyle == 'named': cur.execute( 'insert into %sbooze values (:beer)' % self.table_prefix, {'beer':"Cooper's"} ) elif self.driver.paramstyle == 'format': cur.execute( 'insert into %sbooze values (%%s)' % self.table_prefix, ("Cooper's",) ) elif self.driver.paramstyle == 'pyformat': cur.execute( 'insert into %sbooze values (%%(beer)s)' % self.table_prefix, {'beer':"Cooper's"} ) else: self.fail('Invalid paramstyle') self.assertTrue(cur.rowcount in (-1,1)) cur.execute('select name from %sbooze' % self.table_prefix) res = cur.fetchall() self.assertEqual(len(res),2,'cursor.fetchall returned too few rows') beers = [res[0][0],res[1][0]] beers.sort() self.assertEqual(beers[0],"Cooper's", 'cursor.fetchall retrieved incorrect data, or data inserted ' 'incorrectly' ) self.assertEqual(beers[1],"Victoria Bitter", 'cursor.fetchall retrieved incorrect data, or data inserted ' 'incorrectly' ) def test_executemany(self): con = self._connect() try: cur = con.cursor() self.executeDDL1(cur) largs = [ ("Cooper's",) , ("Boag's",) ] margs = [ {'beer': "Cooper's"}, {'beer': "Boag's"} ] if self.driver.paramstyle == 'qmark': cur.executemany( 'insert into %sbooze values (?)' % self.table_prefix, largs ) elif self.driver.paramstyle == 'numeric': cur.executemany( 'insert into %sbooze values (:1)' % self.table_prefix, largs ) elif self.driver.paramstyle == 'named': cur.executemany( 'insert into %sbooze values (:beer)' % self.table_prefix, margs ) elif self.driver.paramstyle == 'format': cur.executemany( 'insert into %sbooze values (%%s)' % self.table_prefix, largs ) elif self.driver.paramstyle == 'pyformat': cur.executemany( 'insert into %sbooze values (%%(beer)s)' % ( self.table_prefix ), margs ) else: self.fail('Unknown paramstyle') self.assertTrue(cur.rowcount in (-1,2), 'insert using cursor.executemany set cursor.rowcount to ' 'incorrect value %r' % cur.rowcount ) cur.execute('select name from %sbooze' % self.table_prefix) res = cur.fetchall() self.assertEqual(len(res),2, 'cursor.fetchall retrieved incorrect number of rows' ) beers = [res[0][0],res[1][0]] beers.sort() self.assertEqual(beers[0],"Boag's",'incorrect data retrieved') self.assertEqual(beers[1],"Cooper's",'incorrect data retrieved') finally: con.close() def test_format_executemany(self): self.driver.paramstyle = 'format' try: self.test_executemany() finally: self.driver.paramstyle = 'pyformat' def test_fetchone(self): con = self._connect() try: cur = con.cursor() # cursor.fetchone should raise an Error if called before # executing a select-type query self.assertRaises(self.driver.Error,cur.fetchone) # 
cursor.fetchone should raise an Error if called after # executing a query that cannnot return rows self.executeDDL1(cur) self.assertRaises(self.driver.Error,cur.fetchone) cur.execute('select name from %sbooze' % self.table_prefix) self.assertEqual(cur.fetchone(),None, 'cursor.fetchone should return None if a query retrieves ' 'no rows' ) self.assertTrue(cur.rowcount in (-1,0)) # cursor.fetchone should raise an Error if called after # executing a query that cannnot return rows cur.execute("insert into %sbooze values ('Victoria Bitter')" % ( self.table_prefix )) self.assertRaises(self.driver.Error,cur.fetchone) cur.execute('select name from %sbooze' % self.table_prefix) r = cur.fetchone() self.assertEqual(len(r),1, 'cursor.fetchone should have retrieved a single row' ) self.assertEqual(r[0],'Victoria Bitter', 'cursor.fetchone retrieved incorrect data' ) self.assertEqual(cur.fetchone(),None, 'cursor.fetchone should return None if no more rows available' ) self.assertTrue(cur.rowcount in (-1,1)) finally: con.close() samples = [ 'Carlton Cold', 'Carlton Draft', 'Mountain Goat', 'Redback', 'Victoria Bitter', 'XXXX' ] def _populate(self): ''' Return a list of sql commands to setup the DB for the fetch tests. ''' populate = [ "insert into %sbooze values ('%s')" % (self.table_prefix,s) for s in self.samples ] return populate def test_fetchmany(self): con = self._connect() try: cur = con.cursor() # cursor.fetchmany should raise an Error if called without #issuing a query self.assertRaises(self.driver.Error,cur.fetchmany,4) self.executeDDL1(cur) for sql in self._populate(): cur.execute(sql) cur.execute('select name from %sbooze' % self.table_prefix) r = cur.fetchmany() self.assertEqual(len(r),1, 'cursor.fetchmany retrieved incorrect number of rows, ' 'default of arraysize is one.' 
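			# Per DB-API 2.0, fetchmany(size=cursor.arraysize) and arraysize
			# defaults to 1, so the bare fetchmany() above must return exactly
			# one row.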
) cur.arraysize=10 r = cur.fetchmany(3) # Should get 3 rows self.assertEqual(len(r),3, 'cursor.fetchmany retrieved incorrect number of rows' ) r = cur.fetchmany(4) # Should get 2 more self.assertEqual(len(r),2, 'cursor.fetchmany retrieved incorrect number of rows' ) r = cur.fetchmany(4) # Should be an empty sequence self.assertEqual(len(r),0, 'cursor.fetchmany should return an empty sequence after ' 'results are exhausted' ) self.assertTrue(cur.rowcount in (-1,6)) # Same as above, using cursor.arraysize cur.arraysize=4 cur.execute('select name from %sbooze' % self.table_prefix) r = cur.fetchmany() # Should get 4 rows self.assertEqual(len(r),4, 'cursor.arraysize not being honoured by fetchmany' ) r = cur.fetchmany() # Should get 2 more self.assertEqual(len(r),2) r = cur.fetchmany() # Should be an empty sequence self.assertEqual(len(r),0) self.assertTrue(cur.rowcount in (-1,6)) cur.arraysize=6 cur.execute('select name from %sbooze' % self.table_prefix) rows = cur.fetchmany() # Should get all rows self.assertTrue(cur.rowcount in (-1,6)) self.assertEqual(len(rows),6) self.assertEqual(len(rows),6) rows = [r[0] for r in rows] rows.sort() # Make sure we get the right data back out for i in range(0,6): self.assertEqual(rows[i],self.samples[i], 'incorrect data retrieved by cursor.fetchmany' ) rows = cur.fetchmany() # Should return an empty list self.assertEqual(len(rows),0, 'cursor.fetchmany should return an empty sequence if ' 'called after the whole result set has been fetched' ) self.assertTrue(cur.rowcount in (-1,6)) self.executeDDL2(cur) cur.execute('select name from %sbarflys' % self.table_prefix) r = cur.fetchmany() # Should get empty sequence self.assertEqual(len(r),0, 'cursor.fetchmany should return an empty sequence if ' 'query retrieved no rows' ) self.assertTrue(cur.rowcount in (-1,0)) finally: con.close() def test_fetchall(self): con = self._connect() try: cur = con.cursor() # cursor.fetchall should raise an Error if called # without executing a query that may return rows (such # as a select) self.assertRaises(self.driver.Error, cur.fetchall) self.executeDDL1(cur) for sql in self._populate(): cur.execute(sql) # cursor.fetchall should raise an Error if called # after executing a a statement that cannot return rows self.assertRaises(self.driver.Error,cur.fetchall) cur.execute('select name from %sbooze' % self.table_prefix) rows = cur.fetchall() self.assertTrue(cur.rowcount in (-1,len(self.samples))) self.assertEqual(len(rows),len(self.samples), 'cursor.fetchall did not retrieve all rows' ) rows = [r[0] for r in rows] rows.sort() for i in range(0,len(self.samples)): self.assertEqual(rows[i],self.samples[i], 'cursor.fetchall retrieved incorrect rows' ) rows = cur.fetchall() self.assertEqual( len(rows),0, 'cursor.fetchall should return an empty list if called ' 'after the whole result set has been fetched' ) self.assertTrue(cur.rowcount in (-1,len(self.samples))) self.executeDDL2(cur) cur.execute('select name from %sbarflys' % self.table_prefix) rows = cur.fetchall() self.assertTrue(cur.rowcount in (-1,0)) self.assertEqual(len(rows),0, 'cursor.fetchall should return an empty list if ' 'a select query returns no rows' ) finally: con.close() def test_mixedfetch(self): con = self._connect() try: cur = con.cursor() self.executeDDL1(cur) for sql in self._populate(): cur.execute(sql) cur.execute('select name from %sbooze' % self.table_prefix) rows1 = cur.fetchone() rows23 = cur.fetchmany(2) rows4 = cur.fetchone() rows56 = cur.fetchall() self.assertTrue(cur.rowcount in (-1,6)) 
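			# rowcount is checked against (-1, 6) throughout because DB-API 2.0
			# permits -1 whenever the driver cannot determine the row total.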
self.assertEqual(len(rows23),2, 'fetchmany returned incorrect number of rows' ) self.assertEqual(len(rows56),2, 'fetchall returned incorrect number of rows' ) rows = [rows1[0]] rows.extend([rows23[0][0],rows23[1][0]]) rows.append(rows4[0]) rows.extend([rows56[0][0],rows56[1][0]]) rows.sort() for i in range(0,len(self.samples)): self.assertEqual(rows[i],self.samples[i], 'incorrect data retrieved or inserted' ) finally: con.close() def help_nextset_setUp(self,cur): ''' Should create a procedure called deleteme that returns two result sets, first the number of rows in booze then "name from booze" ''' cur.execute('select name from ' + self.booze_name) cur.execute('select count(*) from ' + self.booze_name) def help_nextset_tearDown(self,cur): 'If cleaning up is needed after nextSetTest' pass def test_nextset(self): con = self._connect() try: cur = con.cursor() if not hasattr(cur,'nextset'): return try: self.executeDDL1(cur) sql=self._populate() for sql in self._populate(): cur.execute(sql) self.help_nextset_setUp(cur) numberofrows=cur.fetchone() assert numberofrows[0]== len(self.samples) assert cur.nextset() names=cur.fetchall() assert len(names) == len(self.samples) s=cur.nextset() assert s == None,'No more return sets, should return None' finally: self.help_nextset_tearDown(cur) finally: con.close() def test_arraysize(self): # Not much here - rest of the tests for this are in test_fetchmany con = self._connect() try: cur = con.cursor() self.assertTrue(hasattr(cur,'arraysize'), 'cursor.arraysize must be defined' ) finally: con.close() def test_setinputsizes(self): con = self._connect() try: cur = con.cursor() cur.setinputsizes( (25,) ) self._paraminsert(cur) # Make sure cursor still works finally: con.close() def test_setoutputsize_basic(self): # Basic test is to make sure setoutputsize doesn't blow up con = self._connect() try: cur = con.cursor() cur.setoutputsize(1000) cur.setoutputsize(2000,0) self._paraminsert(cur) # Make sure the cursor still works finally: con.close() def test_setoutputsize(self): # Real test for setoutputsize is driver dependant pass def test_autocommit(self): con = self._connect() con2 = self._connect() try: con.autocommit = True # autocommit mode on, commit/abort on inappropriate. self.assertRaises( con.InterfaceError, con.commit ) self.assertRaises( con.InterfaceError, con.rollback ) c = con.cursor() c.execute("create table some_committed_table(i int)") # if this fails, autocommit had no effect on `con` con2.cursor().execute("drop table some_committed_table") con2.commit() finally: con.close() con2.close() def test_None(self): con = self._connect() try: cur = con.cursor() self.executeDDL1(cur) cur.execute('insert into %sbooze values (NULL)' % self.table_prefix) cur.execute('select name from %sbooze' % self.table_prefix) r = cur.fetchall() self.assertEqual(len(r),1) self.assertEqual(len(r[0]),1) self.assertEqual(r[0][0],None,'NULL value not returned as None') finally: con.close() def test_Date(self): d1 = self.driver.Date(2002,12,25) d2 = self.driver.DateFromTicks(time.mktime((2002,12,25,0,0,0,0,0,0))) # Can we assume this? API doesn't specify, but it seems implied self.assertEqual(str(d1),str(d2)) def test_Time(self): t1 = self.driver.Time(13,45,30) t2 = self.driver.TimeFromTicks(time.mktime((2001,1,1,13,45,30,0,0,0))) # Can we assume this? 
API doesn't specify, but it seems implied self.assertEqual(str(t1),str(t2)) def test_Timestamp(self): t1 = self.driver.Timestamp(2002,12,25,13,45,30) t2 = self.driver.TimestampFromTicks( time.mktime((2002,12,25,13,45,30,0,0,0)) ) # Can we assume this? API doesn't specify, but it seems implied #self.assertEqual(str(t1),str(t2)) def test_Binary(self): b = self.driver.Binary(b'Something') b = self.driver.Binary(b'') def test_STRING(self): self.assertTrue(hasattr(self.driver,'STRING'), 'module.STRING must be defined' ) def test_BINARY(self): self.assertTrue(hasattr(self.driver,'BINARY'), 'module.BINARY must be defined.' ) def test_NUMBER(self): self.assertTrue(hasattr(self.driver,'NUMBER'), 'module.NUMBER must be defined.' ) def test_DATETIME(self): self.assertTrue(hasattr(self.driver,'DATETIME'), 'module.DATETIME must be defined.' ) def test_ROWID(self): self.assertTrue(hasattr(self.driver,'ROWID'), 'module.ROWID must be defined.' ) def test_placeholder_escape(self): con = self._connect() try: c = con.cursor() c.execute("SELECT 100 %% %s", (99,)) self.assertEqual(1, c.fetchone()[0]) c.execute("SELECT 100 %% %(foo)s", {'foo': 99}) self.assertEqual(1, c.fetchone()[0]) finally: con.close() if __name__ == '__main__': unittest.main() fe-1.1.0/postgresql/test/test_driver.py000066400000000000000000001467761203372773200202240ustar00rootroot00000000000000## # .test.test_driver ## import sys import unittest import gc import threading import time import datetime import decimal import uuid from itertools import chain, islice from operator import itemgetter from ..python.datetime import FixedOffset, \ negative_infinity_datetime, infinity_datetime, \ negative_infinity_date, infinity_date from .. import types as pg_types from ..types.io.stdlib_xml_etree import etree from .. 
import exceptions as pg_exc from ..types.bitwise import Bit, Varbit from ..temporal import pg_tmp type_samples = [ ('smallint', ( ((1 << 16) // 2) - 1, - ((1 << 16) // 2), -1, 0, 1, ), ), ('int', ( ((1 << 32) // 2) - 1, - ((1 << 32) // 2), -1, 0, 1, ), ), ('bigint', ( ((1 << 64) // 2) - 1, - ((1 << 64) // 2), -1, 0, 1, ), ), ('numeric', ( -(2**64), 2**64, -(2**128), 2**128, -1, 0, 1, decimal.Decimal("0.00000000000000"), decimal.Decimal("1.00000000000000"), decimal.Decimal("-1.00000000000000"), decimal.Decimal("-2.00000000000000"), decimal.Decimal("1000000000000000.00000000000000"), decimal.Decimal("-0.00000000000000"), decimal.Decimal(1234), decimal.Decimal(-1234), decimal.Decimal("1234000000.00088883231"), decimal.Decimal(str(1234.00088883231)), decimal.Decimal("3123.23111"), decimal.Decimal("-3123000000.23111"), decimal.Decimal("3123.2311100000"), decimal.Decimal("-03123.0023111"), decimal.Decimal("3123.23111"), decimal.Decimal("3123.23111"), decimal.Decimal("10000.23111"), decimal.Decimal("100000.23111"), decimal.Decimal("1000000.23111"), decimal.Decimal("10000000.23111"), decimal.Decimal("100000000.23111"), decimal.Decimal("1000000000.23111"), decimal.Decimal("1000000000.3111"), decimal.Decimal("1000000000.111"), decimal.Decimal("1000000000.11"), decimal.Decimal("100000000.0"), decimal.Decimal("10000000.0"), decimal.Decimal("1000000.0"), decimal.Decimal("100000.0"), decimal.Decimal("10000.0"), decimal.Decimal("1000.0"), decimal.Decimal("100.0"), decimal.Decimal("100"), decimal.Decimal("100.1"), decimal.Decimal("100.12"), decimal.Decimal("100.123"), decimal.Decimal("100.1234"), decimal.Decimal("100.12345"), decimal.Decimal("100.123456"), decimal.Decimal("100.1234567"), decimal.Decimal("100.12345679"), decimal.Decimal("100.123456790"), decimal.Decimal("100.123456790000000000000000"), decimal.Decimal("1.0"), decimal.Decimal("0.0"), decimal.Decimal("-1.0"), decimal.Decimal("1.0E-1000"), decimal.Decimal("1.0E1000"), decimal.Decimal("1.0E10000"), decimal.Decimal("1.0E-10000"), decimal.Decimal("1.0E15000"), decimal.Decimal("1.0E-15000"), decimal.Decimal("1.0E-16382"), decimal.Decimal("1.0E32767"), decimal.Decimal("0.000000000000000000000000001"), decimal.Decimal("0.000000000000010000000000001"), decimal.Decimal("0.00000000000000000000000001"), decimal.Decimal("0.00000000100000000000000001"), decimal.Decimal("0.0000000000000000000000001"), decimal.Decimal("0.000000000000000000000001"), decimal.Decimal("0.00000000000000000000001"), decimal.Decimal("0.0000000000000000000001"), decimal.Decimal("0.000000000000000000001"), decimal.Decimal("0.00000000000000000001"), decimal.Decimal("0.0000000000000000001"), decimal.Decimal("0.000000000000000001"), decimal.Decimal("0.00000000000000001"), decimal.Decimal("0.0000000000000001"), decimal.Decimal("0.000000000000001"), decimal.Decimal("0.00000000000001"), decimal.Decimal("0.0000000000001"), decimal.Decimal("0.000000000001"), decimal.Decimal("0.00000000001"), decimal.Decimal("0.0000000001"), decimal.Decimal("0.000000001"), decimal.Decimal("0.00000001"), decimal.Decimal("0.0000001"), decimal.Decimal("0.000001"), decimal.Decimal("0.00001"), decimal.Decimal("0.0001"), decimal.Decimal("0.001"), decimal.Decimal("0.01"), decimal.Decimal("0.1"), # these require some weight transfer ), ), ('bytea', ( bytes(range(256)), bytes(range(255, -1, -1)), b'\x00\x00', b'foo', ), ), ('smallint[]', ( [123,321,-123,-321], [], ), ), ('int[]', [ [123,321,-123,-321], [[1],[2]], [], ], ), ('bigint[]', [ [ 0, 1, -1, 0xFFFFFFFFFFFF, -0xFFFFFFFFFFFF, ((1 << 64) // 2) - 1, - ((1 << 64) 
// 2), ], [], ], ), ('varchar[]', [ ["foo", "bar",], ["foo", "bar",], [], ], ), ('timestamp', [ datetime.datetime(3000,5,20,5,30,10), datetime.datetime(2000,1,1,5,25,10), datetime.datetime(500,1,1,5,25,10), datetime.datetime(250,1,1,5,25,10), infinity_datetime, negative_infinity_datetime, ], ), ('date', [ datetime.date(3000,5,20), datetime.date(2000,1,1), datetime.date(500,1,1), datetime.date(1,1,1), ], ), ('time', [ datetime.time(12,15,20), datetime.time(0,1,1), datetime.time(23,59,59), ], ), ('timestamptz', [ # It's converted to UTC. When it comes back out, it will be in UTC # again. The datetime comparison will take the tzinfo into account. datetime.datetime(1990,5,12,10,10,0, tzinfo=FixedOffset(4000)), datetime.datetime(1982,5,18,10,10,0, tzinfo=FixedOffset(6000)), datetime.datetime(1950,1,1,10,10,0, tzinfo=FixedOffset(7000)), datetime.datetime(1800,1,1,10,10,0, tzinfo=FixedOffset(2000)), datetime.datetime(2400,1,1,10,10,0, tzinfo=FixedOffset(2000)), infinity_datetime, negative_infinity_datetime, ], ), ('timetz', [ # timetz retains the offset datetime.time(10,10,0, tzinfo=FixedOffset(4000)), datetime.time(10,10,0, tzinfo=FixedOffset(6000)), datetime.time(10,10,0, tzinfo=FixedOffset(7000)), datetime.time(10,10,0, tzinfo=FixedOffset(2000)), datetime.time(22,30,0, tzinfo=FixedOffset(0)), ], ), ('interval', [ # no months :( datetime.timedelta(40, 10, 1234), datetime.timedelta(0, 0, 4321), datetime.timedelta(0, 0), datetime.timedelta(-100, 0), datetime.timedelta(-100, -400), ], ), ('point', [ (10, 1234), (-1, -1), (0, 0), (1, 1), (-100, 0), (-100, -400), (-100.02314, -400.930425), (0xFFFF, 1.3124243), ], ), ('lseg', [ ((0,0),(0,0)), ((10,5),(18,293)), ((55,5),(10,293)), ((-1,-1),(-1,-1)), ((-100,0.00231),(50,45.42132)), ((0.123,0.00231),(50,45.42132)), ], ), ('circle', [ ((0,0),0), ((0,0),1), ((0,0),1.0011), ((1,1),1.0011), ((-1,-1),1.0011), ((1,-1),1.0011), ((-1,1),1.0011), ], ), ('box', [ ((0,0),(0,0)), ((-1,-1),(-1,-1)), ((1,1),(-1,-1)), ((10,1),(-1,-1)), ((100.2312,45.1232),(-123.023,-1423.82342)), ], ), ('bit', [ Bit('1'), Bit('0'), None, ], ), ('varbit', [ Varbit('1'), Varbit('0'), Varbit('10'), Varbit('11'), Varbit('00'), Varbit('001'), Varbit('101'), Varbit('111'), Varbit('0010'), Varbit('1010'), Varbit('1010'), Varbit('01010101011111011010110101010101111'), Varbit('010111101111'), ], ), ('macaddr[]', [ ['00:00:00:00:00:00', 'ff:ff:ff:ff:ff:ff'], ['00:00:00:00:00:01', '00:00:00:00:00:00', 'ff:ff:ff:ff:ff:ff'], ['00:00:00:00:00:01', '00:00:00:00:00:00', 'ff:ff:ff:ff:ff:ff', '10:00:00:00:00:00'], ], ), ] try: import ipaddress type_samples.extend([ ('inet', [ ipaddress.IPv4Address('255.255.255.255'), ipaddress.IPv4Address('127.0.0.1'), ipaddress.IPv4Address('10.0.0.1'), ipaddress.IPv4Address('0.0.0.0'), ipaddress.IPv6Address('::1'), ipaddress.IPv6Address('ffff' + ':ffff'*7), ipaddress.IPv6Address('fe80::1'), ipaddress.IPv6Address('fe80::1'), ipaddress.IPv6Address('::'), # 0::0 ], ), ('cidr', [ ipaddress.IPv4Network('255.255.255.255/32'), ipaddress.IPv4Network('127.0.0.0/8'), ipaddress.IPv4Network('127.1.0.0/16'), ipaddress.IPv4Network('10.0.0.0/32'), ipaddress.IPv4Network('0.0.0.0/0'), ipaddress.IPv6Network('ffff' + ':ffff'*7 + '/128'), ipaddress.IPv6Network('::1/128'), ipaddress.IPv6Network('fe80::1/128'), ipaddress.IPv6Network('fe80::/64'), ipaddress.IPv6Network('fe80::/16'), ipaddress.IPv6Network('::/0'), ], ), ('inet[]', [ [ipaddress.IPv4Address('127.0.0.1'), ipaddress.IPv6Address('::1')], [ipaddress.IPv4Address('10.0.0.1'), ipaddress.IPv6Address('fe80::1')], ], ), ('cidr[]', [ 
[ipaddress.IPv4Network('127.0.0.0/8'), ipaddress.IPv6Network('::/0')], [ipaddress.IPv4Network('10.0.0.0/16'), ipaddress.IPv6Network('fe80::/64')], [ipaddress.IPv4Network('10.102.0.0/16'), ipaddress.IPv6Network('fe80::/64')], ], ), ]) except ImportError: pass class test_driver(unittest.TestCase): @pg_tmp def testInterrupt(self): def pg_sleep(l): try: db.execute("SELECT pg_sleep(5)") except Exception: l.append(sys.exc_info()) else: l.append(None) return rl = [] t = threading.Thread(target = pg_sleep, args = (rl,)) t.start() time.sleep(0.2) while t.is_alive(): db.interrupt() time.sleep(0.1) def raise_exc(l): if l[0] is not None: e, v, tb = rl[0] raise v self.assertRaises(pg_exc.QueryCanceledError, raise_exc, rl) @pg_tmp def testClones(self): db.execute('create table _can_clone_see_this (i int);') try: with db.clone() as db2: self.assertEqual(db2.prepare('select 1').first(), 1) self.assertEqual(db2.prepare( "select count(*) FROM information_schema.tables " \ "where table_name = '_can_clone_see_this'" ).first(), 1 ) finally: db.execute('drop table _can_clone_see_this') # check already open db3 = db.clone() try: self.assertEqual(db3.prepare('select 1').first(), 1) finally: db3.close() ps = db.prepare('select 1') ps2 = ps.clone() self.assertEqual(ps2.first(), ps.first()) ps2.close() c = ps.declare() c2 = c.clone() self.assertEqual(c.read(), c2.read()) @pg_tmp def testItsClosed(self): ps = db.prepare("SELECT 1") # If scroll is False it will pre-fetch, and no error will be thrown. c = ps.declare() # c.close() self.assertRaises(pg_exc.CursorNameError, c.read) self.assertEqual(ps.first(), 1) # ps.close() self.assertRaises(pg_exc.StatementNameError, ps.first) # db.close() self.assertRaises( pg_exc.ConnectionDoesNotExistError, db.execute, "foo" ) # No errors, it's already closed. ps.close() c.close() db.close() @pg_tmp def testGarbage(self): ps = db.prepare('select 1') sid = ps.statement_id ci = ps.chunks() ci_id = ci.cursor_id c = ps.declare() cid = c.cursor_id # make sure there are no remaining xact references.. db._pq_complete() # ci and c both hold references to ps, so they must # be removed before we can observe the effects __del__ del c gc.collect() self.assertTrue(db.typio.encode(cid) in db.pq.garbage_cursors) del ci gc.collect() self.assertTrue(db.typio.encode(ci_id) in db.pq.garbage_cursors) del ps gc.collect() self.assertTrue(db.typio.encode(sid) in db.pq.garbage_statements) @pg_tmp def testStatementCall(self): ps = db.prepare("SELECT 1") r = ps() self.assertTrue(isinstance(r, list)) self.assertEqual(ps(), [(1,)]) ps = db.prepare("SELECT 1, 2") self.assertEqual(ps(), [(1,2)]) ps = db.prepare("SELECT 1, 2 UNION ALL SELECT 3, 4") self.assertEqual(ps(), [(1,2),(3,4)]) @pg_tmp def testStatementFirstDML(self): cmd = prepare("CREATE TEMP TABLE first (i int)").first() self.assertEqual(cmd, 'CREATE TABLE') fins = db.prepare("INSERT INTO first VALUES (123)").first fupd = db.prepare("UPDATE first SET i = 321 WHERE i = 123").first fdel = db.prepare("DELETE FROM first").first self.assertEqual(fins(), 1) self.assertEqual(fdel(), 1) self.assertEqual(fins(), 1) self.assertEqual(fupd(), 1) self.assertEqual(fins(), 1) self.assertEqual(fins(), 1) self.assertEqual(fupd(), 2) self.assertEqual(fdel(), 3) self.assertEqual(fdel(), 0) @pg_tmp def testStatementRowsPersistence(self): # validate that rows' cursor will persist beyond a transaction. 
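		# (Illustrative aside, not from the original suite: .rows() is a lazy
		# iterator backed by a cursor the driver holds open, so it can be
		# partially drained, interleaved with other statements, and finished
		# later — the property this test verifies. The DSN is hypothetical.)
		def _rows_persistence_sketch():
			import postgresql
			xdb = postgresql.open('pq://localhost/postgres')  # hypothetical DSN
			try:
				ps = xdb.prepare('SELECT i FROM generate_series(0, 99) AS g(i)')
				it = ps.rows()
				head = [r[0] for r in islice(it, 50)]  # islice imported above
				xdb.execute('SELECT 1')  # run another statement; cursor persists
				tail = [r[0] for r in it]
				return head + tail == list(range(100))
			finally:
				xdb.close()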
ps = db.prepare("SELECT i FROM generate_series($1::int, $2::int) AS g(i)") # create the iterator inside the transaction rows = ps.rows(0, 10000-1) ps(0,1) # validate the first half. self.assertEqual( list(islice(map(itemgetter(0), rows), 5000)), list(range(5000)) ) ps(0,1) # and the second half. self.assertEqual( list(map(itemgetter(0), rows)), list(range(5000, 10000)) ) @pg_tmp def testStatementParameters(self): # too few and takes one ps = db.prepare("select $1::integer") self.assertRaises(TypeError, ps) # too many and takes one self.assertRaises(TypeError, ps, 1, 2) # too many and takes none ps = db.prepare("select 1") self.assertRaises(TypeError, ps, 1) # too many and takes some ps = db.prepare("select $1::int, $2::text") self.assertRaises(TypeError, ps, 1, "foo", "bar") @pg_tmp def testStatementAndCursorMetadata(self): ps = db.prepare("SELECT $1::integer AS my_int_column") self.assertEqual(tuple(ps.column_names), ('my_int_column',)) self.assertEqual(tuple(ps.sql_column_types), ('INTEGER',)) self.assertEqual(tuple(ps.sql_parameter_types), ('INTEGER',)) self.assertEqual(tuple(ps.pg_parameter_types), (pg_types.INT4OID,)) self.assertEqual(tuple(ps.parameter_types), (int,)) self.assertEqual(tuple(ps.column_types), (int,)) c = ps.declare(15) self.assertEqual(tuple(c.column_names), ('my_int_column',)) self.assertEqual(tuple(c.sql_column_types), ('INTEGER',)) self.assertEqual(tuple(c.column_types), (int,)) ps = db.prepare("SELECT $1::text AS my_text_column") self.assertEqual(tuple(ps.column_names), ('my_text_column',)) self.assertEqual(tuple(ps.sql_column_types), ('pg_catalog.text',)) self.assertEqual(tuple(ps.sql_parameter_types), ('pg_catalog.text',)) self.assertEqual(tuple(ps.pg_parameter_types), (pg_types.TEXTOID,)) self.assertEqual(tuple(ps.column_types), (str,)) self.assertEqual(tuple(ps.parameter_types), (str,)) c = ps.declare('textdata') self.assertEqual(tuple(c.column_names), ('my_text_column',)) self.assertEqual(tuple(c.sql_column_types), ('pg_catalog.text',)) self.assertEqual(tuple(c.pg_column_types), (pg_types.TEXTOID,)) self.assertEqual(tuple(c.column_types), (str,)) ps = db.prepare("SELECT $1::text AS my_column1, $2::varchar AS my_column2") self.assertEqual(tuple(ps.column_names), ('my_column1','my_column2')) self.assertEqual(tuple(ps.sql_column_types), ('pg_catalog.text', 'CHARACTER VARYING')) self.assertEqual(tuple(ps.sql_parameter_types), ('pg_catalog.text', 'CHARACTER VARYING')) self.assertEqual(tuple(ps.pg_parameter_types), (pg_types.TEXTOID, pg_types.VARCHAROID)) self.assertEqual(tuple(ps.pg_column_types), (pg_types.TEXTOID, pg_types.VARCHAROID)) self.assertEqual(tuple(ps.parameter_types), (str,str)) self.assertEqual(tuple(ps.column_types), (str,str)) c = ps.declare('textdata', 'varchardata') self.assertEqual(tuple(c.column_names), ('my_column1','my_column2')) self.assertEqual(tuple(c.sql_column_types), ('pg_catalog.text', 'CHARACTER VARYING')) self.assertEqual(tuple(c.pg_column_types), (pg_types.TEXTOID, pg_types.VARCHAROID)) self.assertEqual(tuple(c.column_types), (str,str)) db.execute("CREATE TYPE public.myudt AS (i int)") myudt_oid = db.prepare("select oid from pg_type WHERE typname='myudt'").first() ps = db.prepare("SELECT $1::text AS my_column1, $2::varchar AS my_column2, $3::public.myudt AS my_column3") self.assertEqual(tuple(ps.column_names), ('my_column1','my_column2', 'my_column3')) self.assertEqual(tuple(ps.sql_column_types), ('pg_catalog.text', 'CHARACTER VARYING', '"public"."myudt"')) self.assertEqual(tuple(ps.sql_parameter_types), ('pg_catalog.text', 
'CHARACTER VARYING', '"public"."myudt"')) self.assertEqual(tuple(ps.pg_column_types), ( pg_types.TEXTOID, pg_types.VARCHAROID, myudt_oid) ) self.assertEqual(tuple(ps.pg_parameter_types), ( pg_types.TEXTOID, pg_types.VARCHAROID, myudt_oid) ) self.assertEqual(tuple(ps.parameter_types), (str,str,tuple)) self.assertEqual(tuple(ps.column_types), (str,str,tuple)) c = ps.declare('textdata', 'varchardata', (123,)) self.assertEqual(tuple(c.column_names), ('my_column1','my_column2', 'my_column3')) self.assertEqual(tuple(c.sql_column_types), ('pg_catalog.text', 'CHARACTER VARYING', '"public"."myudt"')) self.assertEqual(tuple(c.pg_column_types), ( pg_types.TEXTOID, pg_types.VARCHAROID, myudt_oid )) self.assertEqual(tuple(c.column_types), (str,str,tuple)) @pg_tmp def testRowInterface(self): data = (1, '0', decimal.Decimal('0.00'), datetime.datetime(1982,5,18,12,30,0)) ps = db.prepare( "SELECT 1::int2 AS col0, " \ "'0'::text AS col1, 0.00::numeric as col2, " \ "'1982-05-18 12:30:00'::timestamp as col3;" ) row = ps.first() self.assertEqual(tuple(row), data) self.assertTrue(1 in row) self.assertTrue('0' in row) self.assertTrue(decimal.Decimal('0.00') in row) self.assertTrue(datetime.datetime(1982,5,18,12,30,0) in row) self.assertEqual( tuple(row.column_names), tuple(['col' + str(i) for i in range(4)]) ) self.assertEqual( (row["col0"], row["col1"], row["col2"], row["col3"]), (row[0], row[1], row[2], row[3]), ) self.assertEqual( (row["col0"], row["col1"], row["col2"], row["col3"]), (row[0], row[1], row[2], row[3]), ) keys = list(row.keys()) cnames = list(ps.column_names) cnames.sort() keys.sort() self.assertEqual(keys, cnames) self.assertEqual(list(row.values()), list(data)) self.assertEqual(list(row.items()), list(zip(ps.column_names, data))) row_d = dict(row) for x in ps.column_names: self.assertEqual(row_d[x], row[x]) for x in row_d.keys(): self.assertEqual(row.get(x), row_d[x]) row_t = tuple(row) self.assertEqual(row_t, row) # transform crow = row.transform(col0 = str) self.assertEqual(type(crow[0]), str) crow = row.transform(str) self.assertEqual(type(crow[0]), str) crow = row.transform(str, int) self.assertEqual(type(crow[0]), str) self.assertEqual(type(crow[1]), int) # None = no transformation crow = row.transform(None, int) self.assertEqual(type(crow[0]), int) self.assertEqual(type(crow[1]), int) # and a combination crow = row.transform(str, col1 = int, col3 = str) self.assertEqual(type(crow[0]), str) self.assertEqual(type(crow[1]), int) self.assertEqual(type(crow[3]), str) for i in range(4): self.assertEqual(i, row.index_from_key('col' + str(i))) self.assertEqual('col' + str(i), row.key_from_index(i)) def column_test(self): g_i = db.prepare('SELECT i FROM generate_series(1,10) as g(i)').column # ignore the second column. 
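		# (Illustrative aside, not from the original suite: .column() flattens
		# the result to just the first column's values, one per row, which is
		# why the two statements compared below yield identical sequences.
		# The DSN is hypothetical.)
		def _column_sketch():
			import postgresql
			xdb = postgresql.open('pq://localhost/postgres')  # hypothetical DSN
			try:
				vals = list(xdb.prepare(
					'SELECT i, i * 2 FROM generate_series(1, 3) AS g(i)'
				).column())
				return vals  # [1, 2, 3] — the doubled column is ignored
			finally:
				xdb.close()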
g_ii = db.prepare('SELECT i, i+10 as i2 FROM generate_series(1,10) as g(i)').column self.assertEqual(tuple(g_i()), tuple(g_ii())) self.assertEqual(tuple(g_i()), (1,2,3,4,5,6,7,8,9,10)) @pg_tmp def testColumn(self): self.column_test() @pg_tmp def testColumnInXact(self): with db.xact(): self.column_test() @pg_tmp def testStatementFromId(self): db.execute("PREPARE foo AS SELECT 1 AS colname;") ps = db.statement_from_id('foo') self.assertEqual(ps.first(), 1) self.assertEqual(ps(), [(1,)]) self.assertEqual(list(ps), [(1,)]) self.assertEqual(tuple(ps.column_names), ('colname',)) @pg_tmp def testCursorFromId(self): db.execute("DECLARE foo CURSOR WITH HOLD FOR SELECT 1") c = db.cursor_from_id('foo') self.assertEqual(c.read(), [(1,)]) db.execute( "DECLARE bar SCROLL CURSOR WITH HOLD FOR SELECT i FROM generate_series(0, 99) AS g(i)" ) c = db.cursor_from_id('bar') c.seek(50) self.assertEqual([x for x, in c.read(10)], list(range(50,60))) c.seek(0,2) self.assertEqual(c.read(), []) c.seek(0) self.assertEqual([x for x, in c.read()], list(range(100))) @pg_tmp def testCopyToSTDOUT(self): with db.xact(): db.execute("CREATE TABLE foo (i int)") foo = db.prepare('insert into foo values ($1)') foo.load_rows(((x,) for x in range(500))) copy_foo = db.prepare('copy foo to stdout') foo_content = set(copy_foo) expected = set((str(i).encode('ascii') + b'\n' for i in range(500))) self.assertEqual(expected, foo_content) self.assertEqual(expected, set(copy_foo())) self.assertEqual(expected, set(chain.from_iterable(copy_foo.chunks()))) self.assertEqual(expected, set(copy_foo.rows())) db.execute("DROP TABLE foo") @pg_tmp def testCopyFromSTDIN(self): with db.xact(): db.execute("CREATE TABLE foo (i int)") foo = db.prepare('copy foo from stdin') foo.load_rows((str(i).encode('ascii') + b'\n' for i in range(200))) foo_content = list(( x for (x,) in db.prepare('select * from foo order by 1 ASC') )) self.assertEqual(foo_content, list(range(200))) db.execute("DROP TABLE foo") @pg_tmp def testCopyInvalidTermination(self): class DontTrapThis(BaseException): pass def EvilGenerator(): raise DontTrapThis() yield None sqlexec("CREATE TABLE foo (i int)") foo = prepare('copy foo from stdin') try: foo.load_chunks([EvilGenerator()]) self.fail("didn't raise the BaseException subclass") except DontTrapThis: pass try: db._pq_complete() except Exception: pass self.assertEqual(prepare('select 1').first(), 1) @pg_tmp def testLookupProcByName(self): db.execute( "CREATE OR REPLACE FUNCTION public.foo() RETURNS INT LANGUAGE SQL AS 'SELECT 1'" ) db.settings['search_path'] = 'public' f = db.proc('foo()') f2 = db.proc('public.foo()') self.assertTrue(f.oid == f2.oid, "function lookup incongruence(%r != %r)" %(f, f2) ) @pg_tmp def testLookupProcById(self): gsoid = db.prepare( "select oid from pg_proc where proname = 'generate_series' limit 1" ).first() gs = db.proc(gsoid) self.assertEqual(list(gs(1, 100)), list(range(1, 101))) def execute_proc(self): ver = db.proc("version()") ver() db.execute( "CREATE OR REPLACE FUNCTION ifoo(int) RETURNS int LANGUAGE SQL AS 'select $1'" ) ifoo = db.proc('ifoo(int)') self.assertEqual(ifoo(1), 1) self.assertEqual(ifoo(None), None) db.execute( "CREATE OR REPLACE FUNCTION ifoo(varchar) RETURNS text LANGUAGE SQL AS 'select $1'" ) ifoo = db.proc('ifoo(varchar)') self.assertEqual(ifoo('1'), '1') self.assertEqual(ifoo(None), None) db.execute( "CREATE OR REPLACE FUNCTION ifoo(varchar,int) RETURNS text LANGUAGE SQL AS 'select ($1::int + $2)::varchar'" ) ifoo = db.proc('ifoo(varchar,int)') self.assertEqual(ifoo('1',1), '2') 
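		# (Illustrative aside, not from the original suite: db.proc() resolves
		# a server-side function by its regprocedure signature — or by oid —
		# and returns a Python callable. The DSN and demo function below are
		# hypothetical.)
		def _proc_sketch():
			import postgresql
			xdb = postgresql.open('pq://localhost/postgres')  # hypothetical DSN
			try:
				xdb.execute(
					"CREATE OR REPLACE FUNCTION demo_add(int, int) "
					"RETURNS int LANGUAGE SQL AS 'select $1 + $2'"
				)
				demo_add = xdb.proc('demo_add(int,int)')
				return demo_add(40, 2)  # -> 42
			finally:
				xdb.close()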
		self.assertEqual(ifoo(None, 1), None)
		self.assertEqual(ifoo('1', None), None)
		self.assertEqual(ifoo('2', 2), '4')

	@pg_tmp
	def testProcExecution(self):
		self.execute_proc()

	@pg_tmp
	def testProcExecutionInXact(self):
		with db.xact():
			self.execute_proc()

	@pg_tmp
	def testProcExecutionInSubXact(self):
		with db.xact(), db.xact():
			self.execute_proc()

	@pg_tmp
	def testNULL(self):
		# Directly compare: (SELECT NULL) is None.
		self.assertTrue(
			db.prepare("SELECT NULL")()[0][0] is None,
			"SELECT NULL did not return None"
		)
		# Indirectly compare: (select NULL) is None.
		self.assertTrue(
			db.prepare("SELECT $1::text")(None)[0][0] is None,
			"[SELECT $1::text](None) did not return None"
		)

	@pg_tmp
	def testBool(self):
		fst, snd = db.prepare("SELECT true, false").first()
		self.assertTrue(fst is True)
		self.assertTrue(snd is False)

	def select(self):
		#self.assertEqual(
		#	db.prepare('')().command(),
		#	None,
		#	'Empty statement has command?'
		#)
		# Test SELECT 1.
		s1 = db.prepare("SELECT 1 as name")
		p = s1()
		tup = p[0]
		self.assertTrue(tup[0] == 1)
		for tup in s1:
			self.assertEqual(tup[0], 1)
		for tup in s1:
			self.assertEqual(tup["name"], 1)

	@pg_tmp
	def testSelect(self):
		self.select()

	@pg_tmp
	def testSelectInXact(self):
		with db.xact():
			self.select()

	def cursor_read(self):
		ps = db.prepare("SELECT i FROM generate_series(0, (2^8)::int - 1) AS g(i)")
		c = ps.declare()
		self.assertEqual(c.read(0), [])
		self.assertEqual(c.read(0), [])
		self.assertEqual(c.read(1), [(0,)])
		self.assertEqual(c.read(1), [(1,)])
		self.assertEqual(c.read(2), [(2,), (3,)])
		self.assertEqual(c.read(2), [(4,), (5,)])
		self.assertEqual(c.read(3), [(6,), (7,), (8,)])
		self.assertEqual(c.read(4), [(9,), (10,), (11,), (12,)])
		self.assertEqual(c.read(4), [(13,), (14,), (15,), (16,)])
		self.assertEqual(c.read(5), [(17,), (18,), (19,), (20,), (21,)])
		self.assertEqual(c.read(0), [])
		self.assertEqual(c.read(6), [(22,), (23,), (24,), (25,), (26,), (27,)])
		r = [-1]
		i = 4
		v = 28
		maxv = 2**8
		while r:
			i = i * 2
			r = [x for x, in c.read(i)]
			top = min(maxv, v + i)
			self.assertEqual(r, list(range(v, top)))
			v = top

	@pg_tmp
	def testCursorRead(self):
		self.cursor_read()

	@pg_tmp
	def testCursorIter(self):
		ps = db.prepare("SELECT i FROM generate_series(0, 10) AS g(i)")
		c = ps.declare()
		self.assertEqual(next(iter(c)), (0,))
		self.assertEqual(next(iter(c)), (1,))
		self.assertEqual(next(iter(c)), (2,))

	@pg_tmp
	def testCursorReadInXact(self):
		with db.xact():
			self.cursor_read()

	@pg_tmp
	def testScroll(self, direction = True):
		# Use a large row-set.
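		# (Illustrative aside, not from the original suite: the scrolling
		# protocol condensed — declare() produces a scrollable cursor that
		# supports seek(position[, whence]) and directional reads. The DSN is
		# hypothetical.)
		def _scroll_sketch():
			import postgresql
			xdb = postgresql.open('pq://localhost/postgres')  # hypothetical DSN
			try:
				c = xdb.prepare(
					'SELECT i FROM generate_series(0, 999) AS g(i)'
				).declare()
				c.seek(500)  # absolute position
				fwd = [x for x, in c.read(3)]  # [500, 501, 502]
				c.seek(0, 2)  # whence=2: relative to the end
				bwd = [x for x, in c.read(3, 'BACKWARD')]  # [999, 998, 997]
				return fwd, bwd
			finally:
				xdb.close()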
imin = 0 imax = 2**16 if direction: ps = db.prepare("SELECT i FROM generate_series(0, (2^16)::int) AS g(i)") else: ps = db.prepare("SELECT i FROM generate_series((2^16)::int, 0, -1) AS g(i)") c = ps.declare() c.direction = direction if not direction: c.seek(0) self.assertEqual([x for x, in c.read(10)], list(range(10))) # bit strange to me, but i've watched the fetch backwards -jwp 2009 self.assertEqual([x for x, in c.read(10, 'BACKWARD')], list(range(8, -1, -1))) c.seek(0, 2) self.assertEqual([x for x, in c.read(10, 'BACKWARD')], list(range(imax, imax-10, -1))) # move to end c.seek(0, 2) self.assertEqual([x for x, in c.read(100, 'BACKWARD')], list(range(imax, imax-100, -1))) # move backwards, relative c.seek(-100, 1) self.assertEqual([x for x, in c.read(100, 'BACKWARD')], list(range(imax-200, imax-300, -1))) # move abs, again c.seek(14000) self.assertEqual([x for x, in c.read(100)], list(range(14000, 14100))) # move forwards, relative c.seek(100, 1) self.assertEqual([x for x, in c.read(100)], list(range(14200, 14300))) # move abs, again c.seek(24000) self.assertEqual([x for x, in c.read(200)], list(range(24000, 24200))) # move to end and then back some c.seek(20, 2) self.assertEqual([x for x, in c.read(200, 'BACKWARD')], list(range(imax-20, imax-20-200, -1))) c.seek(0, 2) c.seek(-10, 1) r1 = c.read(10) c.seek(10, 2) self.assertEqual(r1, c.read(10)) @pg_tmp def testSeek(self): ps = db.prepare("SELECT i FROM generate_series(0, (2^6)::int - 1) AS g(i)") c = ps.declare() self.assertEqual(c.seek(4, 'FORWARD'), 4) self.assertEqual([x for x, in c.read(10)], list(range(4, 14))) self.assertEqual(c.seek(2, 'BACKWARD'), 2) self.assertEqual([x for x, in c.read(10)], list(range(12, 22))) self.assertEqual(c.seek(-5, 'BACKWARD'), 5) self.assertEqual([x for x, in c.read(10)], list(range(27, 37))) self.assertEqual(c.seek('ALL'), 27) def testScrollBackwards(self): # testScroll again, but backwards this time. self.testScroll(direction = False) @pg_tmp def testWithHold(self): with db.xact(): ps = db.prepare("SELECT 1") c = ps.declare() cid = c.cursor_id self.assertEqual(c.read()[0][0], 1) # make sure it's not cheating self.assertEqual(c.cursor_id, cid) # check grabs beyond the default chunksize. 
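		# (Illustrative aside, not from the original suite: chunks() reads
		# lists of rows rather than one row at a time — the "default chunksize"
		# referenced above — and is the natural feed for load_chunks(). The
		# DSN is hypothetical.)
		def _chunks_sketch():
			import postgresql
			xdb = postgresql.open('pq://localhost/postgres')  # hypothetical DSN
			try:
				ps = xdb.prepare('SELECT i FROM generate_series(1, 100) AS g(i)')
				flat = [x for (x,) in chain.from_iterable(ps.chunks())]  # chain imported above
				return flat == list(range(1, 101))
			finally:
				xdb.close()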
with db.xact(): ps = db.prepare("SELECT i FROM generate_series(0, 99) as g(i)") c = ps.declare() cid = c.cursor_id self.assertEqual([x for x, in c.read()], list(range(100))) # make sure it's not cheating self.assertEqual(c.cursor_id, cid) def load_rows(self): gs = db.prepare("SELECT i FROM generate_series(1, 10000) AS g(i)") self.assertEqual( list((x[0] for x in gs.rows())), list(range(1, 10001)) ) # exercise ``for x in chunks: dst.load_rows(x)`` with new() as db2: db2.execute( """ CREATE TABLE chunking AS SELECT i::text AS t, i::int AS i FROM generate_series(1, 10000) g(i); """ ) read = db.prepare('select * FROM chunking').rows() write = db2.prepare('insert into chunking values ($1, $2)').load_rows with db2.xact(): write(read) del read, write self.assertEqual( db.prepare('select count(*) FROM chunking').first(), 20000 ) self.assertEqual( db.prepare('select count(DISTINCT i) FROM chunking').first(), 10000 ) db.execute('DROP TABLE chunking') @pg_tmp def testLoadRows(self): self.load_rows() @pg_tmp def testLoadRowsInXact(self): with db.xact(): self.load_rows() def load_chunks(self): gs = db.prepare("SELECT i FROM generate_series(1, 10000) AS g(i)") self.assertEqual( list((x[0] for x in chain.from_iterable(gs.chunks()))), list(range(1, 10001)) ) # exercise ``for x in chunks: dst.load_chunks(x)`` with new() as db2: db2.execute( """ CREATE TABLE chunking AS SELECT i::text AS t, i::int AS i FROM generate_series(1, 10000) g(i); """ ) read = db.prepare('select * FROM chunking').chunks() write = db2.prepare('insert into chunking values ($1, $2)').load_chunks with db2.xact(): write(read) del read, write self.assertEqual( db.prepare('select count(*) FROM chunking').first(), 20000 ) self.assertEqual( db.prepare('select count(DISTINCT i) FROM chunking').first(), 10000 ) db.execute('DROP TABLE chunking') @pg_tmp def testLoadChunks(self): self.load_chunks() @pg_tmp def testLoadChunkInXact(self): with db.xact(): self.load_chunks() @pg_tmp def testSimpleDML(self): db.execute("CREATE TEMP TABLE emp(emp_name text, emp_age int)") try: mkemp = db.prepare("INSERT INTO emp VALUES ($1, $2)") del_all_emp = db.prepare("DELETE FROM emp") command, count = mkemp('john', 35) self.assertEqual(command, 'INSERT') self.assertEqual(count, 1) command, count = mkemp('jane', 31) self.assertEqual(command, 'INSERT') self.assertEqual(count, 1) command, count = del_all_emp() self.assertEqual(command, 'DELETE') self.assertEqual(count, 2) finally: db.execute("DROP TABLE emp") def dml(self): db.execute("CREATE TEMP TABLE t(i int)") try: insert_t = db.prepare("INSERT INTO t VALUES ($1)") delete_t = db.prepare("DELETE FROM t WHERE i = $1") delete_all_t = db.prepare("DELETE FROM t") update_t = db.prepare("UPDATE t SET i = $2 WHERE i = $1") self.assertEqual(insert_t(1)[1], 1) self.assertEqual(delete_t(1)[1], 1) self.assertEqual(insert_t(2)[1], 1) self.assertEqual(insert_t(2)[1], 1) self.assertEqual(delete_t(2)[1], 2) self.assertEqual(insert_t(3)[1], 1) self.assertEqual(insert_t(3)[1], 1) self.assertEqual(insert_t(3)[1], 1) self.assertEqual(delete_all_t()[1], 3) self.assertEqual(update_t(1, 2)[1], 0) self.assertEqual(insert_t(1)[1], 1) self.assertEqual(update_t(1, 2)[1], 1) self.assertEqual(delete_t(1)[1], 0) self.assertEqual(delete_t(2)[1], 1) finally: db.execute("DROP TABLE t") @pg_tmp def testDML(self): self.dml() @pg_tmp def testDMLInXact(self): with db.xact(): self.dml() def batch_dml(self): db.execute("CREATE TEMP TABLE t(i int)") try: insert_t = db.prepare("INSERT INTO t VALUES ($1)") delete_t = db.prepare("DELETE FROM t WHERE i = 
$1") delete_all_t = db.prepare("DELETE FROM t") update_t = db.prepare("UPDATE t SET i = $2 WHERE i = $1") mset = ( (2,), (2,), (3,), (4,), (5,), ) insert_t.load_rows(mset) content = db.prepare("SELECT * FROM t ORDER BY 1 ASC") self.assertEqual(mset, tuple(content())) finally: db.execute("DROP TABLE t") @pg_tmp def testBatchDML(self): self.batch_dml() @pg_tmp def testBatchDMLInXact(self): with db.xact(): self.batch_dml() @pg_tmp def testTypes(self): 'test basic object I/O--input must equal output' for (typname, sample_data) in type_samples: pb = db.prepare( "SELECT $1::" + typname ) for sample in sample_data: rsample = list(pb.rows(sample))[0][0] if isinstance(rsample, pg_types.Array): rsample = rsample.nest() self.assertTrue( rsample == sample, "failed to return %s object data as-is; gave %r, received %r" %( typname, sample, rsample ) ) @pg_tmp def testDomainSupport(self): 'test domain type I/O' db.execute('CREATE DOMAIN int_t AS int') db.execute('CREATE DOMAIN int_t_2 AS int_t') db.execute('CREATE TYPE tt AS (a int_t, b int_t_2)') samples = { 'int_t': [10], 'int_t_2': [11], 'tt': [(12, 13)] } for (typname, sample_data) in samples.items(): pb = db.prepare( "SELECT $1::" + typname ) for sample in sample_data: rsample = list(pb.rows(sample))[0][0] if isinstance(rsample, pg_types.Array): rsample = rsample.nest() self.assertTrue( rsample == sample, "failed to return %s object data as-is; gave %r, received %r" %( typname, sample, rsample ) ) @pg_tmp def testAnonymousRecord(self): 'test anonymous record unpacking' db.execute('CREATE TYPE tar_t AS (a int, b int)') tests = { "SELECT (1::int, '2'::text, '2012-01-01 18:00 UTC'::timestamptz)": (1, '2', datetime.datetime(2012, 1, 1, 18, 0, tzinfo=FixedOffset(0))), "SELECT (1::int, '2'::text, (3::int, '4'::text))": (1, '2', (3, '4')), "SELECT (i::int, (i + 1, i + 2)::tar_t) FROM generate_series(1, 10) as i": (1, (2, 3)), "SELECT (1::int, ARRAY[(2, 3), (3, 4)])": (1, pg_types.Array([(2, 3), (3, 4)])) } for qry, expected in tests.items(): pb = db.prepare(qry) result = next(iter(pb.rows()))[0] self.assertEqual(result, expected) def check_xml(self): try: xml = db.prepare('select $1::xml') textxml = db.prepare('select $1::text::xml') r = textxml.first('') except (pg_exc.FeatureError, pg_exc.UndefinedObjectError): # XML is not available. return foo = etree.XML('') bar = etree.XML('') if hasattr(etree, 'tostringlist'): # 3.2 def tostr(x): return etree.tostring(x, encoding='utf-8') else: # 3.1 compat tostr = etree.tostring self.assertEqual(tostr(xml.first(foo)), tostr(foo)) self.assertEqual(tostr(xml.first(bar)), tostr(bar)) self.assertEqual(tostr(textxml.first('')), tostr(foo)) self.assertEqual(tostr(textxml.first('')), tostr(foo)) self.assertEqual(tostr(xml.first(etree.XML(''))), tostr(foo)) self.assertEqual(tostr(textxml.first('')), tostr(foo)) # test fragments self.assertEqual( tuple( tostr(x) for x in xml.first('') ), (tostr(foo), tostr(bar)) ) self.assertEqual( tuple( tostr(x) for x in textxml.first('') ), (tostr(foo), tostr(bar)) ) # mixed text and etree. 
		self.assertEqual(
			tuple(
				tostr(x) for x in xml.first((
					'<foo/>', bar,
				))
			),
			(tostr(foo), tostr(bar))
		)
		self.assertEqual(
			tuple(
				tostr(x) for x in xml.first((
					'<foo/>', bar,
				))
			),
			(tostr(foo), tostr(bar))
		)

	@pg_tmp
	def testXML(self):
		self.check_xml()

	@pg_tmp
	def testXML_ascii(self):
		# check a non-utf8 encoding (3.2 and up)
		db.settings['client_encoding'] = 'sql_ascii'
		self.check_xml()

	@pg_tmp
	def testXML_utf8(self):
		# in 3.2 we always serialize as utf-8, so check that
		# that path is being run by forcing the client_encoding to utf8.
		db.settings['client_encoding'] = 'utf8'
		self.check_xml()

	@pg_tmp
	def testUUID(self):
		# The uuid type doesn't exist in all versions supported by py-postgresql.
		has_uuid = db.prepare(
			"select true from pg_type where lower(typname) = 'uuid'").first()
		if has_uuid:
			ps = db.prepare('select $1::uuid').first
			x = uuid.uuid1()
			self.assertEqual(ps(x), x)

	def _infinity_test(self, typname, inf, neg):
		ps = db.prepare('SELECT $1::' + typname).first
		val = ps('infinity')
		self.assertEqual(val, inf)
		val = ps('-infinity')
		self.assertEqual(val, neg)
		val = ps(inf)
		self.assertEqual(val, inf)
		val = ps(neg)
		self.assertEqual(val, neg)
		ps = db.prepare('SELECT $1::' + typname + '::text').first
		self.assertEqual(ps('infinity'), 'infinity')
		self.assertEqual(ps('-infinity'), '-infinity')

	@pg_tmp
	def testInfinity_stdlib_datetime(self):
		self._infinity_test("timestamptz", infinity_datetime, negative_infinity_datetime)
		self._infinity_test("timestamp", infinity_datetime, negative_infinity_datetime)

	@pg_tmp
	def testInfinity_stdlib_date(self):
		try:
			db.prepare("SELECT 'infinity'::date")()
			self._infinity_test('date', infinity_date, negative_infinity_date)
		except:
			# server version does not support 'infinity'::date
			pass

	@pg_tmp
	def testTypeIOError(self):
		original = dict(db.typio._cache)
		ps = db.prepare('SELECT $1::numeric')
		self.assertRaises(pg_exc.ParameterError, ps, 'foo')
		try:
			db.execute('CREATE type test_tuple_error AS (n numeric);')
			ps = db.prepare('SELECT $1::test_tuple_error AS the_column')
			self.assertRaises(pg_exc.ParameterError, ps, ('foo',))
			try:
				ps(('foo',))
			except pg_exc.ParameterError as err:
				# 'foo' is not a valid Decimal.
				# Expecting a double TupleError here, one from the composite pack
				# and one from the row pack.
				self.assertTrue(isinstance(err.__cause__, pg_exc.CompositeError))
				self.assertEqual(int(err.details['position']), 0)
				# attribute number that the failure occurred on
				self.assertEqual(int(err.__cause__.details['position']), 0)
			else:
				self.fail("failed to raise TupleError")
			# Testing tuple error reception is a bit more difficult.
			# To do this, we need to imitate failure, since we can't rely on a
			# causable failure always existing.
			class ThisError(Exception):
				pass
			def raise_ThisError(arg):
				raise ThisError(arg)
			pack, unpack, typ = db.typio.resolve(pg_types.NUMERICOID)
			# remove any existing knowledge about "test_tuple_error"
			db.typio._cache = original
			db.typio._cache[pg_types.NUMERICOID] = (pack, raise_ThisError, typ)
			# Now, numeric_unpack will always raise "ThisError".
			ps = db.prepare('SELECT $1::numeric as col')
			self.assertRaises(
				pg_exc.ColumnError,
				ps, decimal.Decimal("101")
			)
			try:
				ps(decimal.Decimal("101"))
			except pg_exc.ColumnError as err:
				self.assertTrue(isinstance(err.__cause__, ThisError))
				# might be too inquisitive....
self.assertEqual(int(err.details['position']), 0) self.assertTrue('NUMERIC' in err.message) self.assertTrue('col' in err.message) else: self.fail("failed to raise TupleError from reception") ps = db.prepare('SELECT $1::test_tuple_error AS tte') try: ps((decimal.Decimal("101"),)) except pg_exc.ColumnError as err: self.assertTrue(isinstance(err.__cause__, pg_exc.CompositeError)) self.assertTrue(isinstance(err.__cause__.__cause__, ThisError)) # might be too inquisitive.... self.assertEqual(int(err.details['position']), 0) self.assertEqual(int(err.__cause__.details['position']), 0) self.assertTrue('test_tuple_error' in err.message) else: self.fail("failed to raise TupleError from reception") finally: db.execute('drop type test_tuple_error;') @pg_tmp def testSyntaxError(self): try: db.prepare("SELEKT 1")() except pg_exc.SyntaxError: return self.fail("SyntaxError was not raised") @pg_tmp def testSchemaNameError(self): try: db.prepare("SELECT * FROM sdkfldasjfdskljZknvson.foo")() except pg_exc.SchemaNameError: return self.fail("SchemaNameError was not raised") @pg_tmp def testUndefinedTableError(self): try: db.prepare("SELECT * FROM public.lkansdkvsndlvksdvnlsdkvnsdlvk")() except pg_exc.UndefinedTableError: return self.fail("UndefinedTableError was not raised") @pg_tmp def testUndefinedColumnError(self): try: db.prepare("SELECT x____ysldvndsnkv FROM information_schema.tables")() except pg_exc.UndefinedColumnError: return self.fail("UndefinedColumnError was not raised") @pg_tmp def testSEARVError_avgInWhere(self): try: db.prepare("SELECT 1 WHERE avg(1) = 1")() except pg_exc.SEARVError: return self.fail("SEARVError was not raised") @pg_tmp def testSEARVError_groupByAgg(self): try: db.prepare("SELECT 1 GROUP BY avg(1)")() except pg_exc.SEARVError: return self.fail("SEARVError was not raised") @pg_tmp def testTypeMismatchError(self): try: db.prepare("SELECT 1 WHERE 1")() except pg_exc.TypeMismatchError: return self.fail("TypeMismatchError was not raised") @pg_tmp def testUndefinedObjectError(self): try: self.assertRaises( pg_exc.UndefinedObjectError, db.prepare, "CREATE TABLE lksvdnvsdlksnv(i intt___t)" ) except: # newer versions throw the exception on execution self.assertRaises( pg_exc.UndefinedObjectError, db.prepare("CREATE TABLE lksvdnvsdlksnv(i intt___t)") ) @pg_tmp def testZeroDivisionError(self): self.assertRaises( pg_exc.ZeroDivisionError, db.prepare("SELECT 1/i FROM (select 0 as i) AS g(i)").first, ) @pg_tmp def testTransactionCommit(self): with db.xact(): db.execute("CREATE TEMP TABLE withfoo(i int)") db.prepare("SELECT * FROM withfoo") db.execute("DROP TABLE withfoo") self.assertRaises( pg_exc.UndefinedTableError, db.execute, "SELECT * FROM withfoo" ) @pg_tmp def testTransactionAbort(self): class SomeError(Exception): pass try: with db.xact(): db.execute("CREATE TABLE withfoo (i int)") raise SomeError except SomeError: pass self.assertRaises( pg_exc.UndefinedTableError, db.execute, "SELECT * FROM withfoo" ) @pg_tmp def testSerializeable(self): with new() as db2: db2.execute("create table some_darn_table (i int);") try: with db.xact(isolation = 'serializable'): db.execute('insert into some_darn_table values (123);') # db2 is in autocommit.. 
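				# (Added commentary: a serializable transaction works from a
				# snapshot taken at its first statement, so the row committed
				# below by the autocommit connection `db2` stays invisible to
				# `db` until its transaction ends — hence the assertNotEqual
				# that follows.)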
db2.execute('insert into some_darn_table values (321);') self.assertNotEqual( list(db.prepare('select * from some_darn_table')), list(db2.prepare('select * from some_darn_table')), ) finally: # cleanup db2.execute("drop table some_darn_table;") @pg_tmp def testReadOnly(self): class something(Exception): pass try: with db.xact(mode = 'read only'): self.assertRaises( pg_exc.ReadOnlyTransactionError, db.execute, "create table ieeee(i int)" ) raise something("yeah, it raised.") self.fail("should have been passed by exception") except something: pass @pg_tmp def testFailedTransactionBlock(self): try: with db.xact(): try: db.execute("selekt 1;") except pg_exc.SyntaxError: pass self.fail("__exit__ didn't identify failed transaction") except pg_exc.InFailedTransactionError as err: self.assertEqual(err.source, 'CLIENT') @pg_tmp def testFailedSubtransactionBlock(self): with db.xact(): try: with db.xact(): try: db.execute("selekt 1;") except pg_exc.SyntaxError: pass self.fail("__exit__ didn't identify failed transaction") except pg_exc.InFailedTransactionError as err: # driver should have released/aborted instead self.assertEqual(err.source, 'CLIENT') @pg_tmp def testSuccessfulSubtransactionBlock(self): with db.xact(): with db.xact(): db.execute("create temp table subxact_sx1(i int);") with db.xact(): db.execute("create temp table subxact_sx2(i int);") # And, because I'm paranoid. # The following block is used to make sure # that savepoints are actually being set. try: with db.xact(): db.execute("selekt 1") except pg_exc.SyntaxError: # Just in case the xact() aren't doing anything. pass with db.xact(): db.execute("create temp table subxact_sx3(i int);") # if it can't drop these tables, it didn't manage the subxacts # properly. db.execute("drop table subxact_sx1") db.execute("drop table subxact_sx2") db.execute("drop table subxact_sx3") @pg_tmp def testReleasedSavepoint(self): # validate that the rolled back savepoint is released as well. 
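		# (Added commentary: each nested db.xact() is backed by a savepoint
		# named "xact(<id>)". The manual RELEASE below can only succeed if the
		# driver leaked that savepoint on rollback, so the expected outcome is
		# InvalidSavepointSpecificationError.)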
		x = None
		with db.xact():
			try:
				with db.xact():
					try:
						with db.xact() as x:
							db.execute("selekt 1")
					except pg_exc.SyntaxError:
						db.execute('RELEASE "xact(' + hex(id(x)) + ')"')
			except pg_exc.InvalidSavepointSpecificationError:
				pass
			else:
				self.fail("InvalidSavepointSpecificationError not raised")

	@pg_tmp
	def testCloseInTransactionBlock(self):
		try:
			with db.xact():
				db.close()
			self.fail("transaction __exit__ didn't propagate ConnectionDoesNotExistError")
		except pg_exc.ConnectionDoesNotExistError:
			pass

	@pg_tmp
	def testCloseInSubTransactionBlock(self):
		try:
			with db.xact():
				with db.xact():
					db.close()
			self.fail("transaction __exit__ didn't propagate ConnectionDoesNotExistError")
		except pg_exc.ConnectionDoesNotExistError:
			pass

	@pg_tmp
	def testSettingsCM(self):
		orig = db.settings['search_path']
		with db.settings(search_path='public'):
			self.assertEqual(db.settings['search_path'], 'public')
		self.assertEqual(db.settings['search_path'], orig)

	@pg_tmp
	def testSettingsReset(self):
		# <3 search_path
		del db.settings['search_path']
		cur = db.settings['search_path']
		db.settings['search_path'] = 'pg_catalog'
		del db.settings['search_path']
		self.assertEqual(db.settings['search_path'], cur)

	@pg_tmp
	def testSettingsCount(self):
		self.assertEqual(
			len(db.settings), db.prepare('select count(*) from pg_settings').first()
		)

	@pg_tmp
	def testSettingsGet(self):
		self.assertEqual(
			db.settings['search_path'], db.settings.get('search_path')
		)
		self.assertEqual(None, db.settings.get(' $*0293 vksnd'))

	@pg_tmp
	def testSettingsGetSet(self):
		sub = db.settings.getset(
			('search_path', 'default_statistics_target')
		)
		self.assertEqual(db.settings['search_path'], sub['search_path'])
		self.assertEqual(db.settings['default_statistics_target'], sub['default_statistics_target'])

	@pg_tmp
	def testSettings(self):
		# exercise both construction paths
		d = dict(db.settings)
		d = dict(db.settings.items())
		k = list(db.settings.keys())
		v = list(db.settings.values())
		self.assertEqual(len(k), len(d))
		self.assertEqual(len(k), len(v))
		for x in k:
			self.assertTrue(d[x] in v)
		all = list(db.settings.getset(k).items())
		all.sort(key=itemgetter(0))
		dall = list(d.items())
		dall.sort(key=itemgetter(0))
		self.assertEqual(dall, all)

	@pg_tmp
	def testDo(self):
		# plpgsql is expected to be available.
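		# (Illustrative aside, not from the original suite: db.do(language,
		# source) wraps the server's DO statement — an anonymous code block —
		# which is why the version gate below skips servers older than the
		# 9.0 development line. The DSN is hypothetical.)
		def _do_sketch():
			import postgresql
			xdb = postgresql.open('pq://localhost/postgres')  # hypothetical DSN
			try:
				xdb.do('plpgsql', "BEGIN PERFORM 1; END")
			finally:
				xdb.close()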
if db.version_info[:2] < (8,5): return if 'plpgsql' not in db.sys.languages(): db.execute("CREATE LANGUAGE plpgsql") db.do('plpgsql', "BEGIN CREATE TEMP TABLE do_tmp_table(i int, t text); END",) self.assertEqual(len(db.prepare("SELECT * FROM do_tmp_table")()), 0) db.do('plpgsql', "BEGIN INSERT INTO do_tmp_table VALUES (100, 'foo'); END") self.assertEqual(len(db.prepare("SELECT * FROM do_tmp_table")()), 1) @pg_tmp def testListeningChannels(self): db.listen('foo', 'bar') self.assertEqual(set(db.listening_channels()), {'foo','bar'}) db.unlisten('bar') db.listen('foo', 'bar') self.assertEqual(set(db.listening_channels()), {'foo','bar'}) db.unlisten('foo', 'bar') self.assertEqual(set(db.listening_channels()), set()) @pg_tmp def testNotify(self): db.listen('foo', 'bar') db.listen('foo', 'bar') db.notify('foo') db.execute('') self.assertEqual(db._notifies[0].channel, b'foo') self.assertEqual(db._notifies[0].pid, db.backend_id) self.assertEqual(db._notifies[0].payload, b'') del db._notifies[0] db.notify('bar') db.execute('') self.assertEqual(db._notifies[0].channel, b'bar') self.assertEqual(db._notifies[0].pid, db.backend_id) self.assertEqual(db._notifies[0].payload, b'') del db._notifies[0] db.unlisten('foo') db.notify('foo') db.execute('') self.assertEqual(db._notifies, []) # Invoke an error to show that listen() is all or none. self.assertRaises(Exception, db.listen, 'doesntexist', 'x'*64) self.assertTrue('doesntexist' not in db.listening_channels()) @pg_tmp def testPayloads(self): if db.version_info[:2] >= (9,0): db.listen('foo') db.notify(foo = 'bar') self.assertEqual(('foo', 'bar', db.backend_id), list(db.iternotifies(0))[0]) db.notify(('foo', 'barred')) self.assertEqual(('foo', 'barred', db.backend_id), list(db.iternotifies(0))[0]) # mixed db.notify(('foo', 'barred'), 'foo', ('foo', 'bleh'), foo = 'kw') self.assertEqual([ ('foo', 'barred', db.backend_id), ('foo', '', db.backend_id), ('foo', 'bleh', db.backend_id), # Keywords are appened. 
('foo', 'kw', db.backend_id), ], list(db.iternotifies(0)) ) # multiple keywords expect = [ ('foo', 'meh', db.backend_id), ('bar', 'foo', db.backend_id), ] rexpect = list(reversed(expect)) db.listen('bar') db.notify(foo = 'meh', bar = 'foo') self.assertTrue(list(db.iternotifies(0)) in [expect, rexpect]) @pg_tmp def testMessageHook(self): create = db.prepare('CREATE TEMP TABLE msghook (i INT PRIMARY KEY)') drop = db.prepare('DROP TABLE msghook') parts = [ create, db, db.connector, db.connector.driver, ] notices = [] def add(x): notices.append(x) # inhibit return True with db.xact(): db.settings['client_min_messages'] = 'NOTICE' # test an installed msghook at each level for x in parts: x.msghook = add create() del x.msghook drop() self.assertEqual(len(notices), len(parts)) last = None for x in notices: if last is None: last = x continue self.assertTrue(x.isconsistent(last)) last = x @pg_tmp def testRowTypeFactory(self): from ..types.namedtuple import NamedTupleFactory db.typio.RowTypeFactory = NamedTupleFactory ps = prepare('select 1 as foo, 2 as bar') first_results = ps.first() self.assertEqual(first_results.foo, 1) self.assertEqual(first_results.bar, 2) call_results = ps()[0] self.assertEqual(call_results.foo, 1) self.assertEqual(call_results.bar, 2) declare_results = ps.declare().read(1)[0] self.assertEqual(declare_results.foo, 1) self.assertEqual(declare_results.bar, 2) sqlexec('create type rtf AS (foo int, bar int)') ps = prepare('select ROW(1, 2)::rtf') composite_results = ps.first() self.assertEqual(composite_results.foo, 1) self.assertEqual(composite_results.bar, 2) @pg_tmp def testNamedTuples(self): from ..types.namedtuple import namedtuples ps = namedtuples(prepare('select 1 as foo, 2 as bar, $1::text as param')) r = list(ps("hello"))[0] self.assertEqual(r[0], 1) self.assertEqual(r.foo, 1) self.assertEqual(r[1], 2) self.assertEqual(r.bar, 2) self.assertEqual(r[2], "hello") self.assertEqual(r.param, "hello") @pg_tmp def testBadFD(self): db.pq.socket.close() # bad fd now. self.assertRaises( pg_exc.ConnectionFailureError, sqlexec, "SELECT 1" ) self.assertTrue(issubclass(pg_exc.ConnectionFailureError, pg_exc.Disconnection)) @pg_tmp def testAdminTerminated(self): with new() as killer: if killer.version_info[:2] <= (9,1): killer.sys.terminate_backends() else: killer.sys.terminate_backends_92() self.assertRaises( pg_exc.AdminShutdownError, sqlexec, "SELECT 1", ) self.assertTrue(issubclass(pg_exc.AdminShutdownError, pg_exc.Disconnection)) @pg_tmp def testQuery(self): self.assertEqual(db.query('select 1'), [(1,)]) self.assertEqual(db.query.first('select 1'), 1) self.assertEqual(next(db.query.column('select 1')), 1) self.assertEqual(next(db.query.rows('select 1')), (1,)) self.assertEqual(db.query.declare('select 1').read(), [(1,)]) self.assertEqual(db.query('select $1::int', 1), [(1,)]) self.assertEqual(db.query.first('select $1::int', 1), 1) self.assertEqual(next(db.query.column('select $1::int', 1)), 1) self.assertEqual(next(db.query.rows('select $1::int', 1)), (1,)) self.assertEqual(db.query.declare('select $1::int', 1).read(), [(1,)]) self.assertEqual(db.query.load_rows('select $1::int', [[1]]), None) self.assertEqual(db.query.load_chunks('select $1::int', [[[1]]]), None) class test_typio(unittest.TestCase): @pg_tmp def testIdentify(self): # It just exercises the code path. 
db.typio.identify(contrib_hstore = 'pg_catalog.reltime') @pg_tmp def testArrayNulls(self): try: sqlexec('SELECT ARRAY[1,NULL]::int[]') except Exception: # unsupported here return inta = prepare('select $1::int[]').first texta = prepare('select $1::text[]').first self.assertEqual(inta([1,2,None]), [1,2,None]) self.assertEqual(texta(["foo",None,"bar"]), ["foo",None,"bar"]) if __name__ == '__main__': unittest.main() fe-1.1.0/postgresql/test/test_exceptions.py000066400000000000000000000027271203372773200210750ustar00rootroot00000000000000## # .test.test_exceptions ## import unittest import postgresql.exceptions as pg_exc class test_exceptions(unittest.TestCase): def test_pg_code_lookup(self): # in 8.4, pg started using the SQL defined error code for limits # Users *will* get whatever code PG sends, but it's important # that they have some way to abstract it. many-to-one map ftw. self.assertEqual( pg_exc.ErrorLookup('22020'), pg_exc.LimitValueError ) def test_error_lookup(self): # An error code that doesn't exist yields pg_exc.Error self.assertEqual( pg_exc.ErrorLookup('00000'), pg_exc.Error ) self.assertEqual( pg_exc.ErrorLookup('XX000'), pg_exc.InternalError ) # check class fallback self.assertEqual( pg_exc.ErrorLookup('XX444'), pg_exc.InternalError ) # SEARV is a very large class, so there are many # sub-"codeclass" exceptions used to group the many # SEARV errors. Make sure looking up 42000 actually # gives the SEARVError self.assertEqual( pg_exc.ErrorLookup('42000'), pg_exc.SEARVError ) self.assertEqual( pg_exc.ErrorLookup('08P01'), pg_exc.ProtocolError ) def test_warning_lookup(self): self.assertEqual( pg_exc.WarningLookup('01000'), pg_exc.Warning ) self.assertEqual( pg_exc.WarningLookup('02000'), pg_exc.NoDataWarning ) self.assertEqual( pg_exc.WarningLookup('01P01'), pg_exc.DeprecationWarning ) self.assertEqual( pg_exc.WarningLookup('01888'), pg_exc.Warning ) if __name__ == '__main__': unittest.main() fe-1.1.0/postgresql/test/test_installation.py000066400000000000000000000045121203372773200214070ustar00rootroot00000000000000## # .test.test_installation ## import sys import os import unittest from .. import installation as ins class test_installation(unittest.TestCase): """ Most of this is exercised by TestCaseWithCluster, but do some explicit checks up front to help find any specific issues that do not naturally occur. """ def test_parse_configure_options(self): # Check expectations. self.assertEqual( list(ins.parse_configure_options("")), [], ) self.assertEqual( list(ins.parse_configure_options(" ")), [], ) self.assertEqual( list(ins.parse_configure_options("--foo --bar")), [('foo',True),('bar',True)] ) self.assertEqual( list(ins.parse_configure_options("'--foo' '--bar'")), [('foo',True),('bar',True)] ) self.assertEqual( list(ins.parse_configure_options("'--foo=A properly isolated string' '--bar'")), [('foo','A properly isolated string'),('bar',True)] ) # hope they don't ever use backslash escapes. # This is pretty dirty, but it doesn't seem well defined anyways. self.assertEqual( list(ins.parse_configure_options("'--foo=A ''properly'' isolated string' '--bar'")), [('foo',"A 'properly' isolated string"),('bar',True)] ) # handle some simple variations, but it's self.assertEqual( list(ins.parse_configure_options("'--foo' \"--bar\"")), [('foo',True),('bar',True)] ) # Show the failure. 
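		# (Added commentary: pg_config --configure emits shell-quoted options;
		# single-quoted values are unescaped — including '' doubling — but
		# double-quoted values are not given the same treatment, so the
		# space-containing path below is the failure this block induces.)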
try: self.assertEqual( list(ins.parse_configure_options("'--foo' \"--bar=/A dir/file\"")), [('foo',True),('bar','/A dir/file')] ) except AssertionError: pass else: self.fail("did not detect induced failure") def test_minimum(self): 'version info' # Installation only "needs" the version information i = ins.Installation({'version' : 'PostgreSQL 2.2.3'}) self.assertEqual( i.version, 'PostgreSQL 2.2.3' ) self.assertEqual( i.version_info, (2,2,3,'final',0) ) self.assertEqual(i.postgres, None) self.assertEqual(i.postmaster, None) def test_exec(self): # check the executable i = ins.pg_config_dictionary( sys.executable, '-m', __package__ + '.support', 'pg_config') # automatically lowers the key self.assertEqual(i['foo'], 'BaR') self.assertEqual(i['feh'], 'YEAH') self.assertEqual(i['version'], 'NAY') if __name__ == '__main__': from types import ModuleType this = ModuleType("this") this.__dict__.update(globals()) unittest.main(this) fe-1.1.0/postgresql/test/test_iri.py000066400000000000000000000055451203372773200175000ustar00rootroot00000000000000## # .test.test_iri ## import unittest import postgresql.iri as pg_iri value_errors = ( # Invalid scheme. 'http://user@host/index.html', ) iri_samples = ( 'host/dbname/path?param=val#frag', '#frag', '?param=val', '?param=val#frag', 'user@', ':pass@', 'u:p@h', 'u:p@h:1', 'pq://user:password@host:port/database?setting=value#public,private', 'pq://fæm.com:123/õéf/á?param=val', 'pq://l»»@fæm.com:123/õéf/á?param=val', 'pq://fæᎱᏋm.com/õéf/á?param=val', 'pq://fæᎱᏋm.com/õéf/á?param=val&[setting]=value', ) sample_structured_parameters = [ { 'host' : 'hostname', 'port' : '1234', 'database' : 'foo_db', }, { 'user' : 'username', 'database' : 'database_name', 'settings' : {'foo':'bar','feh':'bl%,23'}, }, { 'user' : 'username', 'database' : 'database_name', }, { 'database' : 'database_name', }, { 'user' : 'user_name', }, { 'host' : 'hostname', }, { 'user' : 'username', 'password' : 'pass', 'host' : '', 'port' : '4321', 'database' : 'database_name', 'path' : ['path'], }, { 'user' : 'user', 'password' : 'secret', 'host' : '', 'port' : 'ssh', 'database' : 'database_name', 'settings' : { 'set1' : 'val1', 'set2' : 'val2', }, }, { 'user' : 'user', 'password' : 'secret', 'host' : '', 'port' : 'ssh', 'database' : 'database_name', 'settings' : { 'set1' : 'val1', 'set2' : 'val2', }, 'connect_timeout' : '10', 'sslmode' : 'prefer', }, ] class test_iri(unittest.TestCase): def testPresentPasswordObscure(self): "password is present in IRI, and obscure it" s = 'pq://user:pass@host:port/dbname' o = 'pq://user:***@host:port/dbname' p = pg_iri.parse(s) ps = pg_iri.serialize(p, obscure_password = True) self.assertEqual(ps, o) def testPresentPasswordObscure(self): "password is *not* present in IRI, and do nothing" s = 'pq://user@host:port/dbname' o = 'pq://user@host:port/dbname' p = pg_iri.parse(s) ps = pg_iri.serialize(p, obscure_password = True) self.assertEqual(ps, o) def testValueErrors(self): for x in value_errors: self.assertRaises(ValueError, pg_iri.parse, x ) def testParseSerialize(self): scheme = 'pq://' for x in iri_samples: px = pg_iri.parse(x) spx = pg_iri.serialize(px) pspx = pg_iri.parse(spx) self.assertTrue( pspx == px, "parse-serialize incongruity, %r -> %r -> %r : %r != %r" %( x, px, spx, pspx, px ) ) spspx = pg_iri.serialize(pspx) self.assertTrue( spx == spspx, "parse-serialize incongruity, %r -> %r -> %r -> %r : %r != %r" %( x, px, spx, pspx, spspx, spx ) ) def testSerializeParse(self): for x in sample_structured_parameters: xs = pg_iri.serialize(x) uxs = pg_iri.parse(xs) 
self.assertTrue( x == uxs, "serialize-parse incongruity, %r -> %r -> %r" %( x, xs, uxs, ) ) if __name__ == '__main__': unittest.main() fe-1.1.0/postgresql/test/test_lib.py000066400000000000000000000101121203372773200174450ustar00rootroot00000000000000## # .test.test_lib - test the .lib package ## import sys import os import unittest import tempfile from .. import exceptions as pg_exc from .. import lib as pg_lib from .. import sys as pg_sys from ..temporal import pg_tmp ilf = """ preface [sym] select 1 [sym_ref] *[sym] [sym_ref_trail] *[sym] WHERE FALSE [sym_first::first] select 1 [sym_rows::rows] select 1 [sym_chunks::chunks] select 1 [sym_declare::declare] select 1 [sym_const:const:first] select 1 [sym_const_rows:const:rows] select 1 [sym_const_chunks:const:chunks] select 1 [sym_const_column:const:column] select 1 [sym_const_ddl:const:] create temp table sym_const_dll (i int); [sym_preload:preload:first] select 1 [sym_proc:proc] test_ilf_proc(int) [sym_srf_proc:proc] test_ilf_srf_proc(int) [&sym_reference] SELECT 'SELECT 1'; [&sym_reference_params] SELECT 'SELECT ' || $1::text; [&sym_reference_first::first] SELECT 'SELECT 1::int4'; [&sym_reference_const:const:first] SELECT 'SELECT 1::int4'; [&sym_reference_proc:proc] SELECT 'test_ilf_proc(int)'::text """ class test_lib(unittest.TestCase): # NOTE: Module libraries are implicitly tested # in postgresql.test.test_driver; much functionality # depends on the `sys` library. def _testILF(self, lib): self.assertTrue('preface' in lib.preface) db.execute("CREATE OR REPLACE FUNCTION test_ilf_proc(int) RETURNS int language sql as 'select $1';") db.execute("CREATE OR REPLACE FUNCTION test_ilf_srf_proc(int) RETURNS SETOF int language sql as 'select $1';") b = pg_lib.Binding(db, lib) self.assertEqual(b.sym_ref(), [(1,)]) self.assertEqual(b.sym_ref_trail(), []) self.assertEqual(b.sym(), [(1,)]) self.assertEqual(b.sym_first(), 1) self.assertEqual(list(b.sym_rows()), [(1,)]) self.assertEqual([list(x) for x in b.sym_chunks()], [[(1,)]]) c = b.sym_declare() self.assertEqual(c.read(), [(1,)]) c.seek(0) self.assertEqual(c.read(), [(1,)]) self.assertEqual(b.sym_const, 1) self.assertEqual(b.sym_const_column, [1]) self.assertEqual(b.sym_const_rows, [(1,)]) self.assertEqual(b.sym_const_chunks, [[(1,)]]) self.assertEqual(b.sym_const_ddl, ('CREATE TABLE', None)) self.assertEqual(b.sym_preload(), 1) # now stored procs self.assertEqual(b.sym_proc(2,), 2) self.assertEqual(list(b.sym_srf_proc(2,)), [2]) self.assertRaises(AttributeError, getattr, b, 'LIES') # reference symbols self.assertEqual(b.sym_reference()(), [(1,)]) self.assertEqual(b.sym_reference_params('1::int')(), [(1,)]) self.assertEqual(b.sym_reference_params("'foo'::text")(), [('foo',)]) self.assertEqual(b.sym_reference_first()(), 1) self.assertEqual(b.sym_reference_const(), 1) self.assertEqual(b.sym_reference_proc()(2,), 2) @pg_tmp def testILF_from_lines(self): lib = pg_lib.ILF.from_lines([l + '\n' for l in ilf.splitlines()]) self._testILF(lib) @pg_tmp def testILF_from_file(self): f = tempfile.NamedTemporaryFile( delete = False, mode = 'w', encoding = 'utf-8' ) n = f.name try: f.write(ilf) f.flush() f.seek(0) lib = pg_lib.ILF.open(n, encoding = 'utf-8') self._testILF(lib) f.close() finally: # so annoying... os.unlink(n) @pg_tmp def testLoad(self): # gotta test it in the cwd... 
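		# (Illustrative aside, not from the original suite: the "lib<name>.sql"
		# files located via pg_sys.libpath hold [symbol]-delimited SQL, and
		# Binding attaches each symbol to a connection as a callable. The DSN
		# and the 'answer' symbol are made up for the demo.)
		def _ilf_sketch():
			import postgresql
			xdb = postgresql.open('pq://localhost/postgres')  # hypothetical DSN
			try:
				lib = pg_lib.ILF.from_lines(['[answer]\n', 'SELECT 42\n'])
				b = pg_lib.Binding(xdb, lib)
				return b.answer()  # -> [(42,)]
			finally:
				xdb.close()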
pid = os.getpid() frag = 'temp' + str(pid) fn = 'lib' + frag + '.sql' try: with open(fn, 'w') as f: f.write("[foo]\nSELECT 1") pg_sys.libpath.insert(0, os.path.curdir) l = pg_lib.load(frag) b = pg_lib.Binding(db, l) self.assertEqual(b.foo(), [(1,)]) finally: os.remove(fn) @pg_tmp def testCategory(self): lib = pg_lib.ILF.from_lines([l + '\n' for l in ilf.splitlines()]) # XXX: evil, careful.. lib._name = 'name' c = pg_lib.Category(lib) c(db) self.assertEqual(db.name.sym_first(), 1) c = pg_lib.Category(renamed = lib) c(db) self.assertEqual(db.renamed.sym_first(), 1) @pg_tmp def testCategoryAliases(self): lib = pg_lib.ILF.from_lines([l + '\n' for l in ilf.splitlines()]) # XXX: evil, careful.. lib._name = 'name' c = pg_lib.Category(lib, renamed = lib) c(db) self.assertEqual(db.name.sym_first(), 1) self.assertEqual(db.renamed.sym_first(), 1) if __name__ == '__main__': unittest.main() fe-1.1.0/postgresql/test/test_notifyman.py000066400000000000000000000070721203372773200207160ustar00rootroot00000000000000## # .test.test_notifyman - test .notifyman ## import unittest import threading import time from ..temporal import pg_tmp from ..notifyman import NotificationManager class test_notifyman(unittest.TestCase): @pg_tmp def testNotificationManager(self): # signals each other alt = new() with alt: nm = NotificationManager(db, alt) db.listen('foo') alt.listen('bar') # notify the other. alt.notify('foo') db.notify('bar') # we can separate these here because there's no timeout for ndb, notifies in nm: for n in notifies: if ndb is db: self.assertEqual(n[0], 'foo') self.assertEqual(n[1], '') self.assertEqual(n[2], alt.backend_id) nm.connections.discard(db) elif ndb is alt: self.assertEqual(n[0], 'bar') self.assertEqual(n[1], '') self.assertEqual(n[2], db.backend_id) nm.connections.discard(alt) else: self.fail("unknown connection received notify..") @pg_tmp def testNotificationManagerTimeout(self): nm = NotificationManager(db, timeout = 0.1) db.listen('foo') count = 0 for event in nm: if event is None: # do this a few times, then break out of the loop db.notify('foo') continue ndb, notifies = event self.assertEqual(ndb, db) for n in notifies: self.assertEqual(n[0], 'foo') self.assertEqual(n[1], '') self.assertEqual(n[2], db.backend_id) count = count + 1 if count > 3: break @pg_tmp def testNotificationManagerZeroTimeout(self): # Zero-timeout means raise StopIteration when # there are no notifications to emit. # It checks the wire, but does *not* wait for data. nm = NotificationManager(db, timeout = 0) db.listen('foo') self.assertEqual(list(nm), []) db.notify('foo') time.sleep(0.01) self.assertEqual(list(nm), [('foo','',db.backend_id)]) # bit of a race @pg_tmp def test_iternotifies(self): # db.iternotifies() simplification of NotificationManager alt = new() alt.listen('foo') alt.listen('close') def get_notices(db, l): with db: for x in db.iternotifies(): if x[0] == 'close': break l.append(x) rl = [] t = threading.Thread(target = get_notices, args = (alt, rl,)) t.start() db.notify('foo') while not rl: time.sleep(0.05) channel, payload, pid = rl.pop(0) self.assertEqual(channel, 'foo') self.assertEqual(payload, '') self.assertEqual(pid, db.backend_id) db.notify('close') @pg_tmp def testIterNotifiesZeroTimeout(self): # Zero-timeout means raise StopIteration when # there are no notifications to emit. # It checks the wire, but does *not* wait for data.
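# (Contrast with the positive-timeout manager above: timeout > 0 yields
# None on idle and keeps waiting, while timeout = 0 never yields None
# and simply stops once the buffered notifications are exhausted.)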
db.listen('foo') self.assertEqual(list(db.iternotifies(0)), []) db.notify('foo') time.sleep(0.01) self.assertEqual(list(db.iternotifies(0)), [('foo','', db.backend_id)]) # bit of a race @pg_tmp def testNotificationManagerOnClosed(self): # When the connection goes away, the NM iterator # should raise a Stop. db = new() db.listen('foo') db.notify('foo') for n in db.iternotifies(): db.close() self.assertEqual(db.closed, True) del db # closer, after an idle db = new() db.listen('foo') for n in db.iternotifies(0.2): if n is None: # In the loop, notify, and expect to # get the notification even though the # connection was closed. db.notify('foo') db.execute('') db.close() hit = False else: hit = True # hit should get set two times. # once on the first idle, and once on the event # received after the close. self.assertEqual(db.closed, True) self.assertEqual(hit, True) if __name__ == '__main__': unittest.main() fe-1.1.0/postgresql/test/test_optimized.py000066400000000000000000000213071203372773200207130ustar00rootroot00000000000000## # test.test_optimized ## import unittest import struct import sys from ..port import optimized from ..python.itertools import interlace def pack_tuple(*data, packH = struct.Struct("!H").pack, packL = struct.Struct("!L").pack ): return packH(len(data)) + b''.join(( packL(len(x)) + x if x is not None else b'\xff\xff\xff\xff' for x in data )) tuplemessages = ( (b'D', pack_tuple(b'foo', b'bar')), (b'D', pack_tuple(b'foo', None, b'bar')), (b'N', b'fee'), (b'D', pack_tuple(b'foo', None, b'bar')), (b'D', pack_tuple(b'foo', b'bar')), ) class test_optimized(unittest.TestCase): def test_consume_tuple_messages(self): ctm = optimized.consume_tuple_messages # expecting a tuple of pairs. self.assertRaises(TypeError, ctm, []) self.assertEqual(ctm(()), []) # Make sure that the slicing is working. self.assertEqual(ctm(tuplemessages), [ (b'foo', b'bar'), (b'foo', None, b'bar'), ]) # Not really checking consume here, but we are validating that # it's properly propagating exceptions. 
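# Wire-layout refresher, mirroring pack_tuple() above: a !H attribute
# count, then per attribute a !L byte length (0xFFFFFFFF meaning NULL)
# followed by that many bytes; the malformed payloads below violate it.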
self.assertRaises(ValueError, ctm, ((b'D', b'\xff\xff\xff\xfefoo'),)) self.assertRaises(ValueError, ctm, ((b'D', b'\x00\x00\x00\x04foo'),)) def test_parse_tuple_message(self): ptm = optimized.parse_tuple_message self.assertRaises(TypeError, ptm, "stringzor") self.assertRaises(TypeError, ptm, 123) self.assertRaises(ValueError, ptm, b'') self.assertRaises(ValueError, ptm, b'0') notenoughdata = struct.pack('!H', 2) self.assertRaises(ValueError, ptm, notenoughdata) wraparound = struct.pack('!HL', 2, 10) + (b'0' * 10) + struct.pack('!L', 0xFFFFFFFE) self.assertRaises(ValueError, ptm, wraparound) oneatt_notenough = struct.pack('!HL', 2, 10) + (b'0' * 10) + struct.pack('!L', 15) self.assertRaises(ValueError, ptm, oneatt_notenough) toomuchdata = struct.pack('!HL', 1, 3) + (b'0' * 10) self.assertRaises(ValueError, ptm, toomuchdata) class faketup(tuple): def __new__(subtype, geeze): r = tuple.__new__(subtype, ()) r.foo = geeze return r zerodata = struct.pack('!H', 0) r = ptm(zerodata) self.assertRaises(AttributeError, getattr, r, 'foo') self.assertRaises(AttributeError, setattr, r, 'foo', 'bar') self.assertEqual(len(r), 0) def test_process_tuple(self): def funpass(procs, tup, col): pass pt = optimized.process_tuple # tuple() requirements self.assertRaises(TypeError, pt, "foo", "bar", funpass) self.assertRaises(TypeError, pt, (), "bar", funpass) self.assertRaises(TypeError, pt, "foo", (), funpass) self.assertRaises(TypeError, pt, (), ("foo",), funpass) def test_pack_tuple_data(self): pit = optimized.pack_tuple_data self.assertEqual(pit((None,)), b'\xff\xff\xff\xff') self.assertEqual(pit((None,)*2), b'\xff\xff\xff\xff'*2) self.assertEqual(pit((None,)*3), b'\xff\xff\xff\xff'*3) self.assertEqual(pit((None,b'foo')), b'\xff\xff\xff\xff\x00\x00\x00\x03foo') self.assertEqual(pit((None,b'')), b'\xff\xff\xff\xff\x00\x00\x00\x00') self.assertEqual(pit((None,b'',b'bar')), b'\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x03bar') self.assertRaises(TypeError, pit, 1) self.assertRaises(TypeError, pit, (1,)) self.assertRaises(TypeError, pit, ("",)) def test_int2(self): d = b'\x00\x01' rd = b'\x01\x00' s = optimized.swap_int2_unpack(d) n = optimized.int2_unpack(d) sd = optimized.swap_int2_pack(1) nd = optimized.int2_pack(1) if sys.byteorder == 'little': self.assertEqual(1, s) self.assertEqual(256, n) self.assertEqual(d, sd) self.assertEqual(rd, nd) else: self.assertEqual(1, n) self.assertEqual(256, s) self.assertEqual(d, nd) self.assertEqual(rd, sd) self.assertRaises(OverflowError, optimized.swap_int2_pack, 2**15) self.assertRaises(OverflowError, optimized.int2_pack, 2**15) self.assertRaises(OverflowError, optimized.swap_int2_pack, (-2**15)-1) self.assertRaises(OverflowError, optimized.int2_pack, (-2**15)-1) def test_int4(self): d = b'\x00\x00\x00\x01' rd = b'\x01\x00\x00\x00' s = optimized.swap_int4_unpack(d) n = optimized.int4_unpack(d) sd = optimized.swap_int4_pack(1) nd = optimized.int4_pack(1) if sys.byteorder == 'little': self.assertEqual(1, s) self.assertEqual(16777216, n) self.assertEqual(d, sd) self.assertEqual(rd, nd) else: self.assertEqual(1, n) self.assertEqual(16777216, s) self.assertEqual(d, nd) self.assertEqual(rd, sd) self.assertRaises(OverflowError, optimized.swap_int4_pack, 2**31) self.assertRaises(OverflowError, optimized.int4_pack, 2**31) self.assertRaises(OverflowError, optimized.swap_int4_pack, (-2**31)-1) self.assertRaises(OverflowError, optimized.int4_pack, (-2**31)-1) def test_int8(self): d = b'\x00\x00\x00\x00\x00\x00\x00\x01' rd = b'\x01\x00\x00\x00\x00\x00\x00\x00' s = 
optimized.swap_int8_unpack(d) n = optimized.int8_unpack(d) sd = optimized.swap_int8_pack(1) nd = optimized.int8_pack(1) if sys.byteorder == 'little': self.assertEqual(0x1, s) self.assertEqual(0x100000000000000, n) self.assertEqual(d, sd) self.assertEqual(rd, nd) else: self.assertEqual(0x1, n) self.assertEqual(0x100000000000000, s) self.assertEqual(d, nd) self.assertEqual(rd, sd) self.assertEqual(optimized.swap_int8_pack(-1), b'\xFF\xFF\xFF\xFF'*2) self.assertEqual(optimized.int8_pack(-1), b'\xFF\xFF\xFF\xFF'*2) self.assertRaises(OverflowError, optimized.swap_int8_pack, 2**63) self.assertRaises(OverflowError, optimized.int8_pack, 2**63) self.assertRaises(OverflowError, optimized.swap_int8_pack, (-2**63)-1) self.assertRaises(OverflowError, optimized.int8_pack, (-2**63)-1) # edge I/O int8_max = ((2**63) - 1) int8_min = (-(2**63)) swap_max = optimized.swap_int8_pack(int8_max) max = optimized.int8_pack(int8_max) swap_min = optimized.swap_int8_pack(int8_min) min = optimized.int8_pack(int8_min) self.assertEqual(optimized.swap_int8_unpack(swap_max), int8_max) self.assertEqual(optimized.int8_unpack(max), int8_max) self.assertEqual(optimized.swap_int8_unpack(swap_min), int8_min) self.assertEqual(optimized.int8_unpack(min), int8_min) def test_uint2(self): d = b'\x00\x01' rd = b'\x01\x00' s = optimized.swap_uint2_unpack(d) n = optimized.uint2_unpack(d) sd = optimized.swap_uint2_pack(1) nd = optimized.uint2_pack(1) if sys.byteorder == 'little': self.assertEqual(1, s) self.assertEqual(256, n) self.assertEqual(d, sd) self.assertEqual(rd, nd) else: self.assertEqual(1, n) self.assertEqual(256, s) self.assertEqual(d, nd) self.assertEqual(rd, sd) self.assertRaises(OverflowError, optimized.swap_uint2_pack, -1) self.assertRaises(OverflowError, optimized.uint2_pack, -1) self.assertRaises(OverflowError, optimized.swap_uint2_pack, 2**16) self.assertRaises(OverflowError, optimized.uint2_pack, 2**16) self.assertEqual(optimized.uint2_pack(2**16-1), b'\xFF\xFF') self.assertEqual(optimized.swap_uint2_pack(2**16-1), b'\xFF\xFF') def test_uint4(self): d = b'\x00\x00\x00\x01' rd = b'\x01\x00\x00\x00' s = optimized.swap_uint4_unpack(d) n = optimized.uint4_unpack(d) sd = optimized.swap_uint4_pack(1) nd = optimized.uint4_pack(1) if sys.byteorder == 'little': self.assertEqual(1, s) self.assertEqual(16777216, n) self.assertEqual(d, sd) self.assertEqual(rd, nd) else: self.assertEqual(1, n) self.assertEqual(16777216, s) self.assertEqual(d, nd) self.assertEqual(rd, sd) self.assertRaises(OverflowError, optimized.swap_uint4_pack, -1) self.assertRaises(OverflowError, optimized.uint4_pack, -1) self.assertRaises(OverflowError, optimized.swap_uint4_pack, 2**32) self.assertRaises(OverflowError, optimized.uint4_pack, 2**32) self.assertEqual(optimized.uint4_pack(2**32-1), b'\xFF\xFF\xFF\xFF') self.assertEqual(optimized.swap_uint4_pack(2**32-1), b'\xFF\xFF\xFF\xFF') def test_uint8(self): d = b'\x00\x00\x00\x00\x00\x00\x00\x01' rd = b'\x01\x00\x00\x00\x00\x00\x00\x00' s = optimized.swap_uint8_unpack(d) n = optimized.uint8_unpack(d) sd = optimized.swap_uint8_pack(1) nd = optimized.uint8_pack(1) if sys.byteorder == 'little': self.assertEqual(0x1, s) self.assertEqual(0x100000000000000, n) self.assertEqual(d, sd) self.assertEqual(rd, nd) else: self.assertEqual(0x1, n) self.assertEqual(0x100000000000000, s) self.assertEqual(d, nd) self.assertEqual(rd, sd) self.assertRaises(OverflowError, optimized.swap_uint8_pack, -1) self.assertRaises(OverflowError, optimized.uint8_pack, -1) self.assertRaises(OverflowError, optimized.swap_uint8_pack, 2**64) 
self.assertRaises(OverflowError, optimized.uint8_pack, 2**64) self.assertEqual(optimized.uint8_pack((2**64)-1), b'\xFF\xFF\xFF\xFF'*2) self.assertEqual(optimized.swap_uint8_pack((2**64)-1), b'\xFF\xFF\xFF\xFF'*2) if __name__ == '__main__': from types import ModuleType this = ModuleType("this") this.__dict__.update(globals()) unittest.main(this) fe-1.1.0/postgresql/test/test_pgpassfile.py000066400000000000000000000037321203372773200210460ustar00rootroot00000000000000## # .test.test_pgpassfile ## import unittest from .. import pgpassfile as client_pgpass from io import StringIO passfile_sample = """ # host:1111:dbname:user:password1 host:1111:dbname:user:password1 *:1111:dbname:user:password2 *:*:dbname:user:password3 # Comment *:*:*:user:password4 *:*:*:usern:password4.5 *:*:*:*:password5 """ passfile_sample_map = { ('user', 'host', '1111', 'dbname') : 'password1', ('user', 'host', '1111', 'dbname') : 'password1', ('user', 'foo', '1111', 'dbname') : 'password2', ('user', 'foo', '4321', 'dbname') : 'password3', ('user', 'foo', '4321', 'db,name') : 'password4', ('uuser', 'foo', '4321', 'db,name') : 'password5', ('usern', 'foo', '4321', 'db,name') : 'password4.5', ('foo', 'bar', '19231', 'somedbn') : 'password5', } difficult_passfile_sample = r""" host\\:1111:db\:name:u\\ser:word1 *:1111:\:dbname\::\\user\\:pass\:word2 foohost:1111:\:dbname\::\\user\\:pass\:word3 """ difficult_passfile_sample_map = { ('u\\ser','host\\','1111','db:name') : 'word1', ('\\user\\','somehost','1111',':dbname:') : 'pass:word2', ('\\user\\','someotherhost','1111',':dbname:') : 'pass:word2', # More specific, but comes after '*' ('\\user\\','foohost','1111',':dbname:') : 'pass:word2', ('','','','') : None, } class test_pgpass(unittest.TestCase): def runTest(self): sample1 = client_pgpass.parse(StringIO(passfile_sample)) sample2 = client_pgpass.parse(StringIO(difficult_passfile_sample)) for k, pw in passfile_sample_map.items(): lpw = client_pgpass.lookup_password(sample1, k) self.assertEqual(lpw, pw, "password lookup incongruity, expecting %r got %r with %r" " in \n%s" %( pw, lpw, k, passfile_sample ) ) for k, pw in difficult_passfile_sample_map.items(): lpw = client_pgpass.lookup_password(sample2, k) self.assertEqual(lpw, pw, "password lookup incongruity, expecting %r got %r with %r" " in \n%s" %( pw, lpw, k, difficult_passfile_sample ) ) if __name__ == '__main__': unittest.main() fe-1.1.0/postgresql/test/test_protocol.py000066400000000000000000000433421203372773200205530ustar00rootroot00000000000000## # .test.test_protocol ## import sys import unittest import struct import decimal import socket import time from threading import Thread from ..protocol import element3 as e3 from ..protocol import xact3 as x3 from ..protocol import client3 as c3 from ..protocol import buffer as pq_buf from ..python.socket import find_available_port, SocketFactory def pair(msg): return (msg.type, msg.serialize()) def pairs(*msgseq): return list(map(pair, msgseq)) long = struct.Struct("!L") packl = long.pack unpackl = long.unpack class test_buffer(unittest.TestCase): def setUp(self): self.buffer = pq_buf.pq_message_stream() def testMultiByteMessage(self): b = self.buffer b.write(b's') self.assertTrue(b.next_message() is None) b.write(b'\x00\x00') self.assertTrue(b.next_message() is None) b.write(b'\x00\x10') self.assertTrue(b.next_message() is None) data = b'twelve_chars' b.write(data) self.assertEqual(b.next_message(), (b's', data)) def testSingleByteMessage(self): b = self.buffer b.write(b's') self.assertTrue(b.next_message() is None) 
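# Framing note: a message is one type byte plus a !L length that counts
# itself, so the smallest complete message is five bytes and a length of
# four (see testEmptyMessage below) means an empty payload; the writes
# here feed that header in one-byte fragments.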
b.write(b'\x00') self.assertTrue(b.next_message() is None) b.write(b'\x00\x00\x05') self.assertTrue(b.next_message() is None) b.write(b'b') self.assertEqual(b.next_message(), (b's', b'b')) def testEmptyMessage(self): b = self.buffer b.write(b'x') self.assertTrue(b.next_message() is None) b.write(b'\x00\x00\x00') self.assertTrue(b.next_message() is None) b.write(b'\x04') self.assertEqual(b.next_message(), (b'x', b'')) def testInvalidLength(self): b = self.buffer b.write(b'y\x00\x00\x00\x03') self.assertRaises(ValueError, b.next_message,) def testRemainder(self): b = self.buffer b.write(b'r\x00\x00\x00\x05Aremainder') self.assertEqual(b.next_message(), (b'r', b'A')) def testLarge(self): b = self.buffer factor = 1024 r = 10000 b.write(b'X' + packl(factor * r + 4)) segment = b'\x00' * factor for x in range(r-1): b.write(segment) b.write(segment) msg = b.next_message() self.assertTrue(msg is not None) self.assertEqual(msg[0], b'X') def test_getvalue(self): # Make sure that getvalue() only applies to messages # that have not been read. b = self.buffer # It should be empty. self.assertEqual(b.getvalue(), b'') d = b'F' + packl(28) b.write(d) self.assertEqual(b.getvalue(), d) d1 = b'01'*12 # 24 b.write(d1) self.assertEqual(b.getvalue(), d + d1) out = b.read()[0] self.assertEqual(out, (b'F', d1)) nd = b'N' b.write(nd) self.assertEqual(b.getvalue(), nd) b.write(packl(4)) self.assertEqual(list(b.read()), [(b'N', b'')]) self.assertEqual(b.getvalue(), b'') # partial; read one message to exercise # that the appropriate fragment of the first # chunk in the buffer is picked up. first_body = (b'1234' * 3) first = b'v' + packl(len(first_body) + 4) + first_body second_body = (b'4321' * 5) second = b'z' + packl(len(second_body) + 4) + second_body b.write(first + second) self.assertEqual(b.getvalue(), first + second) self.assertEqual(list(b.read(1)), [(b'v', first_body)]) self.assertEqual(b.getvalue(), second) self.assertEqual(list(b.read(1)), [(b'z', second_body)]) # now, with a third full message in the next chunk third_body = (b'9876' * 10) third = b'3' + packl(len(third_body) + 4) + third_body b.write(first + second) b.write(third) self.assertEqual(b.getvalue(), first + second + third) self.assertEqual(list(b.read(1)), [(b'v', first_body)]) self.assertEqual(b.getvalue(), second + third) self.assertEqual(list(b.read(1)), [(b'z', second_body)]) self.assertEqual(b.getvalue(), third) self.assertEqual(list(b.read(1)), [(b'3', third_body)]) self.assertEqual(b.getvalue(), b'') ## # element3 tests ## message_samples = [ e3.VoidMessage, e3.Startup([ (b'user', b'jwp'), (b'database', b'template1'), (b'options', b'-f'), ]), e3.Notice(( (b'S', b'FATAL'), (b'M', b'a descriptive message'), (b'C', b'FIVEC'), (b'D', b'bleh'), (b'H', b'dont spit into the fan'), )), e3.Notify(123, b'wood_table'), e3.KillInformation(19320, 589483), e3.ShowOption(b'foo', b'bar'), e3.Authentication(4, b'salt'), e3.Complete(b'SELECT'), e3.Ready(b'I'), e3.CancelRequest(4123, 14252), e3.NegotiateSSL(), e3.Password(b'ckr4t'), e3.AttributeTypes(()), e3.AttributeTypes( (123,) * 1 ), e3.AttributeTypes( (123,0) * 1 ), e3.AttributeTypes( (123,0) * 2 ), e3.AttributeTypes( (123,0) * 4 ), e3.TupleDescriptor(()), e3.TupleDescriptor(( (b'name', 123, 1, 1, 0, 0, 1,), )), e3.TupleDescriptor(( (b'name', 123, 1, 2, 0, 0, 1,), ) * 2), e3.TupleDescriptor(( (b'name', 123, 1, 2, 1, 0, 1,), ) * 3), e3.TupleDescriptor(( (b'name', 123, 1, 1, 0, 0, 1,), ) * 1000), e3.Tuple([]), e3.Tuple([b'foo',]), e3.Tuple([None]), e3.Tuple([b'foo',b'bar']), e3.Tuple([None, None]), 
e3.Tuple([None, b'foo', None]), e3.Tuple([b'bar', None, b'foo', None, b'bleh']), e3.Tuple([b'foo', b'bar'] * 100), e3.Tuple([None] * 100), e3.Query(b'select * from u'), e3.Parse(b'statement_id', b'query', (123, 0)), e3.Parse(b'statement_id', b'query', (123,)), e3.Parse(b'statement_id', b'query', ()), e3.Bind(b'portal_id', b'statement_id', (b'tt',b'\x00\x00'), [b'data',None], (b'ff',b'xx')), e3.Bind(b'portal_id', b'statement_id', (b'tt',), [None], (b'xx',)), e3.Bind(b'portal_id', b'statement_id', (b'ff',), [b'data'], ()), e3.Bind(b'portal_id', b'statement_id', (), [], (b'xx',)), e3.Bind(b'portal_id', b'statement_id', (), [], ()), e3.Execute(b'portal_id', 500), e3.Execute(b'portal_id', 0), e3.DescribeStatement(b'statement_id'), e3.DescribePortal(b'portal_id'), e3.CloseStatement(b'statement_id'), e3.ClosePortal(b'portal_id'), e3.Function(123, (), [], b'xx'), e3.Function(321, (b'tt',), [b'foo'], b'xx'), e3.Function(321, (b'tt',), [None], b'xx'), e3.Function(321, (b'aa', b'aa'), [None,b'a' * 200], b'xx'), e3.FunctionResult(b''), e3.FunctionResult(b'foobar'), e3.FunctionResult(None), e3.CopyToBegin(123, [321,123]), e3.CopyToBegin(0, [10,]), e3.CopyToBegin(123, []), e3.CopyFromBegin(123, [321,123]), e3.CopyFromBegin(0, [10]), e3.CopyFromBegin(123, []), e3.CopyData(b''), e3.CopyData(b'foo'), e3.CopyData(b'a' * 2048), e3.CopyFail(b''), e3.CopyFail(b'iiieeeeee!'), ] class test_element3(unittest.TestCase): def test_cat_messages(self): # The optimized implementation will identify adjacent copy data, and # take a more efficient route; so rigorously test the switch between the # two modes. self.assertEqual(e3.cat_messages([]), b'') self.assertEqual(e3.cat_messages([b'foo']), b'd\x00\x00\x00\x07foo') self.assertEqual(e3.cat_messages([b'foo', b'foo']), 2*b'd\x00\x00\x00\x07foo') # copy, other, copy self.assertEqual(e3.cat_messages([b'foo', e3.SynchronizeMessage, b'foo']), b'd\x00\x00\x00\x07foo' + e3.SynchronizeMessage.bytes() + b'd\x00\x00\x00\x07foo') # copy, other, copy*1000 self.assertEqual(e3.cat_messages(1000*[b'foo', e3.SynchronizeMessage, b'foo']), 1000*(b'd\x00\x00\x00\x07foo' + e3.SynchronizeMessage.bytes() + b'd\x00\x00\x00\x07foo')) # other, copy, copy*1000 self.assertEqual(e3.cat_messages(1000*[e3.SynchronizeMessage, b'foo', b'foo']), 1000*(e3.SynchronizeMessage.bytes() + 2*b'd\x00\x00\x00\x07foo')) pack_head = struct.Struct("!lH").pack # tuple self.assertEqual(e3.cat_messages([(b'foo',),]), b'D' + pack_head(7 + 4 + 2, 1) + b'\x00\x00\x00\x03foo') # tuple(foo,\N) self.assertEqual(e3.cat_messages([(b'foo',None,),]), b'D' + pack_head(7 + 4 + 4 + 2, 2) + b'\x00\x00\x00\x03foo\xFF\xFF\xFF\xFF') # tuple(foo,\N,bar) self.assertEqual(e3.cat_messages([(b'foo',None,b'bar'),]), b'D' + pack_head(7 + 7 + 4 + 4 + 2, 3) + \ b'\x00\x00\x00\x03foo\xFF\xFF\xFF\xFF\x00\x00\x00\x03bar') # too many attributes self.assertRaises((OverflowError, struct.error), e3.cat_messages, [(None,) * 0x10000]) class ThisEx(Exception): pass class ThatEx(Exception): pass class Bad(e3.Message): def serialize(self): raise ThisEx('foo') self.assertRaises(ThisEx, e3.cat_messages, [Bad()]) class NoType(e3.Message): def serialize(self): return b'' self.assertRaises(AttributeError, e3.cat_messages, [NoType()]) class BadType(e3.Message): type = 123 def serialize(self): return b'' self.assertRaises((TypeError,struct.error), e3.cat_messages, [BadType()]) def testSerializeParseConsistency(self): for msg in message_samples: smsg = msg.serialize() self.assertEqual(msg, msg.parse(smsg)) def testEmptyMessages(self): for x in 
e3.__dict__.values(): if isinstance(x, e3.EmptyMessage): xtype = type(x) self.assertTrue(x is xtype()) def testUnknownNoticeFields(self): N = e3.Notice.parse(b'\x00\x00Z\x00Xklsvdnvldsvkndvlsn\x00Pfoobar\x00Mmessage\x00') E = e3.Error.parse(b'Z\x00Xklsvdnvldsvkndvlsn\x00Pfoobar\x00Mmessage\x00\x00') self.assertEqual(N[b'M'], b'message') self.assertEqual(E[b'M'], b'message') self.assertEqual(N[b'P'], b'foobar') self.assertEqual(E[b'P'], b'foobar') self.assertEqual(len(N), 4) self.assertEqual(len(E), 4) def testCompleteExtracts(self): x = e3.Complete(b'FOO BAR 1321') self.assertEqual(x.extract_command(), b'FOO BAR') self.assertEqual(x.extract_count(), 1321) x = e3.Complete(b' CREATE TABLE 13210 ') self.assertEqual(x.extract_command(), b'CREATE TABLE') self.assertEqual(x.extract_count(), 13210) x = e3.Complete(b' CREATE TABLE \t713210 ') self.assertEqual(x.extract_command(), b'CREATE TABLE') self.assertEqual(x.extract_count(), 713210) x = e3.Complete(b' CREATE TABLE 0 \t13210 ') self.assertEqual(x.extract_command(), b'CREATE TABLE') self.assertEqual(x.extract_count(), 13210) x = e3.Complete(b' 0 \t13210 ') self.assertEqual(x.extract_command(), None) self.assertEqual(x.extract_count(), 13210) ## # .protocol.xact3 tests ## xact_samples = [ # Simple contrived exchange. ( ( e3.Query(b"COMPLETE"), ), ( e3.Complete(b'COMPLETE'), e3.Ready(b'I'), ) ), ( ( e3.Query(b"ROW DATA"), ), ( e3.TupleDescriptor(( (b'foo', 1, 1, 1, 1, 1, 1), (b'bar', 1, 2, 1, 1, 1, 1), )), e3.Tuple((b'lame', b'lame')), e3.Complete(b'COMPLETE'), e3.Ready(b'I'), ) ), ( ( e3.Query(b"ROW DATA"), ), ( e3.TupleDescriptor(( (b'foo', 1, 1, 1, 1, 1, 1), (b'bar', 1, 2, 1, 1, 1, 1), )), e3.Tuple((b'lame', b'lame')), e3.Tuple((b'lame', b'lame')), e3.Tuple((b'lame', b'lame')), e3.Tuple((b'lame', b'lame')), e3.Ready(b'I'), ) ), ( ( e3.Query(b"NULL"), ), ( e3.Null(), e3.Ready(b'I'), ) ), ( ( e3.Query(b"COPY TO"), ), ( e3.CopyToBegin(1, [1,2]), e3.CopyData(b'row1'), e3.CopyData(b'row2'), e3.CopyDone(), e3.Complete(b'COPY TO'), e3.Ready(b'I'), ) ), ( ( e3.Function(1, [b''], [b''], 1), ), ( e3.FunctionResult(b'foo'), e3.Ready(b'I'), ) ), ( ( e3.Parse(b"NAME", b"SQL", ()), ), ( e3.ParseComplete(), ) ), ( ( e3.Bind(b"NAME", b"STATEMENT_ID", (), (), ()), ), ( e3.BindComplete(), ) ), ( ( e3.Parse(b"NAME", b"SQL", ()), e3.Bind(b"NAME", b"STATEMENT_ID", (), (), ()), ), ( e3.ParseComplete(), e3.BindComplete(), ) ), ( ( e3.Describe(b"STATEMENT_ID"), ), ( e3.AttributeTypes(()), e3.NoData(), ) ), ( ( e3.Describe(b"STATEMENT_ID"), ), ( e3.AttributeTypes(()), e3.TupleDescriptor(()), ) ), ( ( e3.CloseStatement(b"foo"), ), ( e3.CloseComplete(), ), ), ( ( e3.ClosePortal(b"foo"), ), ( e3.CloseComplete(), ), ), ( ( e3.Synchronize(), ), ( e3.Ready(b'I'), ), ), ] class test_xact3(unittest.TestCase): def testTransactionSamplesAll(self): for xcmd, xres in xact_samples: x = x3.Instruction(xcmd) r = tuple([(y.type, y.serialize()) for y in xres]) x.state[1]() self.assertEqual(x.messages, ()) x.state[1](r) self.assertEqual(x.state, x3.Complete) rec = [] for y in x.completed: for z in y[1]: if type(z) is type(b''): z = e3.CopyData(z) rec.append(z) self.assertEqual(xres, tuple(rec)) def testClosing(self): c = x3.Closing() self.assertEqual(c.messages, (e3.DisconnectMessage,)) c.state[1]() self.assertEqual(c.fatal, True) self.assertEqual(c.error_message.__class__, e3.ClientError) self.assertEqual(c.error_message[b'C'], '08003') def testNegotiation(self): # simple successful run n = x3.Negotiation({}, b'') n.state[1]() n.state[1]( pairs( e3.Notice(((b'M', b"foobar"),)), 
e3.Authentication(e3.AuthRequest_OK, b''), e3.KillInformation(0,0), e3.ShowOption(b'name', b'val'), e3.Ready(b'I'), ) ) self.assertEqual(n.state, x3.Complete) self.assertEqual(n.last_ready.xact_state, b'I') # no killinfo.. should cause protocol error... n = x3.Negotiation({}, b'') n.state[1]() n.state[1]( pairs( e3.Notice(((b'M', b"foobar"),)), e3.Authentication(e3.AuthRequest_OK, b''), e3.ShowOption(b'name', b'val'), e3.Ready(b'I'), ) ) self.assertEqual(n.state, x3.Complete) self.assertEqual(n.last_ready, None) self.assertEqual(n.error_message[b'C'], '08P01') # killinfo twice.. must cause protocol error... n = x3.Negotiation({}, b'') n.state[1]() n.state[1]( pairs( e3.Notice(((b'M', b"foobar"),)), e3.Authentication(e3.AuthRequest_OK, b''), e3.ShowOption(b'name', b'val'), e3.KillInformation(0,0), e3.KillInformation(0,0), e3.Ready(b'I'), ) ) self.assertEqual(n.state, x3.Complete) self.assertEqual(n.last_ready, None) self.assertEqual(n.error_message[b'C'], '08P01') # start with ready message.. n = x3.Negotiation({}, b'') n.state[1]() n.state[1]( pairs( e3.Notice(((b'M', b"foobar"),)), e3.Ready(b'I'), e3.Authentication(e3.AuthRequest_OK, b''), e3.ShowOption(b'name', b'val'), ) ) self.assertEqual(n.state, x3.Complete) self.assertEqual(n.last_ready, None) self.assertEqual(n.error_message[b'C'], '08P01') # unsupported authreq n = x3.Negotiation({}, b'') n.state[1]() n.state[1]( pairs( e3.Authentication(255, b''), ) ) self.assertEqual(n.state, x3.Complete) self.assertEqual(n.last_ready, None) self.assertEqual(n.error_message[b'C'], '--AUT') def testInstructionAsynchook(self): l = [] def hook(data): l.append(data) x = x3.Instruction([ e3.Query(b"NOTHING") ], asynchook = hook) a1 = e3.Notice(((b'M', b"m1"),)) a2 = e3.Notify(0, b'relation', b'parameter') a3 = e3.ShowOption(b'optname', b'optval') # "send" the query message x.state[1]() # "receive" the tuple x.state[1]([(a1.type, a1.serialize()),]) a2l = [(a2.type, a2.serialize()),] x.state[1](a2l) # validate that the hook is not fed twice because # it's the exact same message set. (later assertion will validate) x.state[1](a2l) x.state[1]([(a3.type, a3.serialize()),]) # we only care about validating that l got everything. self.assertEqual([a1,a2,a3], l) self.assertEqual(x.state[0], x3.Receiving) # validate that the asynchook exception is trapped. class Nee(Exception): pass def ehook(msg): raise Nee("this should **not** be part of the summary") x = x3.Instruction([ e3.Query(b"NOTHING") ], asynchook = ehook) a1 = e3.Notice(((b'M', b"m1"),)) x.state[1]() import sys v = None def exchook(typ, val, tb): nonlocal v v = val seh = sys.excepthook sys.excepthook = exchook # we only care about validating that the exchook got called. 
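# (The temporary sys.excepthook swap above is the observation point:
# an exception raised inside the asynchook must be reported through
# sys.excepthook rather than propagate into the transaction.)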
x.state[1]([(a1.type, a1.serialize())]) sys.excepthook = seh self.assertTrue(isinstance(v, Nee)) class test_client3(unittest.TestCase): def test_timeout(self): portnum = find_available_port() servsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) with servsock: servsock.bind(('localhost', portnum)) pc = c3.Connection( SocketFactory( (socket.AF_INET, socket.SOCK_STREAM), ('localhost', portnum) ), {} ) pc.connect(timeout = 1) try: self.assertEqual(pc.xact.fatal, True) self.assertEqual(pc.xact.__class__, x3.Negotiation) finally: if pc.socket is not None: pc.socket.close() def test_SSL_failure(self): portnum = find_available_port() servsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) with servsock: servsock.bind(('localhost', portnum)) pc = c3.Connection( SocketFactory( (socket.AF_INET, socket.SOCK_STREAM), ('localhost', portnum) ), {} ) exc = None servsock.listen(1) def client_thread(): pc.connect(ssl = True) client = Thread(target = client_thread) try: client.start() c, addr = servsock.accept() with c: c.send(b'S') c.sendall(b'0000000000000000000000') c.recv(1024) c.close() client.join() finally: if pc.socket is not None: pc.socket.close() self.assertEqual(pc.xact.fatal, True) self.assertEqual(pc.xact.__class__, x3.Negotiation) self.assertEqual(pc.xact.error_message.__class__, e3.ClientError) self.assertTrue(hasattr(pc.xact, 'exception')) def test_bad_negotiation(self): portnum = find_available_port() servsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) servsock.bind(('localhost', portnum)) pc = c3.Connection( SocketFactory( (socket.AF_INET, socket.SOCK_STREAM), ('localhost', portnum) ), {} ) exc = None servsock.listen(1) def client_thread(): pc.connect() client = Thread(target = client_thread) try: client.start() c, addr = servsock.accept() try: c.recv(1024) finally: c.close() time.sleep(0.25) client.join() servsock.close() self.assertEqual(pc.xact.fatal, True) self.assertEqual(pc.xact.__class__, x3.Negotiation) self.assertEqual(pc.xact.error_message.__class__, e3.ClientError) self.assertEqual(pc.xact.error_message[b'C'], '08006') finally: servsock.close() if pc.socket is not None: pc.socket.close() if __name__ == '__main__': from types import ModuleType this = ModuleType("this") this.__dict__.update(globals()) try: unittest.main(this) finally: import gc gc.collect() fe-1.1.0/postgresql/test/test_python.py000066400000000000000000000117631203372773200202350ustar00rootroot00000000000000## # .test.test_python ## import unittest import socket import errno import struct from itertools import chain from operator import methodcaller from contextlib import contextmanager from ..python.itertools import interlace from ..python.structlib import split_sized_data from ..python import functools from ..python import itertools from ..python.socket import find_available_port from ..python import element class Ele(element.Element): _e_label = property( lambda x: getattr(x, 'label', 'ELEMENT') ) _e_factors = ('ancestor', 'secondary') secondary = None def __init__(self, s = None): self.ancestor = s def __str__(self): return 'STRDATA' def _e_metas(self): yield ('first', getattr(self, 'first', 'firstv')) yield ('second', getattr(self, 'second', 'secondv')) class test_element(unittest.TestCase): def test_primary_factor(self): x = Ele() # no factors self.assertEqual(element.prime_factor(object()), None) self.assertEqual(element.prime_factor(x), ('ancestor', None)) y = Ele(x) self.assertEqual(element.prime_factor(y), ('ancestor', x)) def test_primary_factors(self): x = Ele() x.ancestor = x 
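# 'x' is now a one-element ancestor cycle; prime_factors() must detect
# it and raise RecursiveFactor instead of iterating forever.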
self.assertRaises( element.RecursiveFactor, list, element.prime_factors(x) ) y = Ele(x) x.ancestor = y self.assertRaises( element.RecursiveFactor, list, element.prime_factors(y) ) self.assertRaises( element.RecursiveFactor, list, element.prime_factors(x) ) x.ancestor = None z = Ele(y) self.assertEqual(list(element.prime_factors(z)), [ ('ancestor', y), ('ancestor', x), ('ancestor', None), ]) def test_format_element(self): # Considering that this is subject to change, frequently, # I/O equality tests are inappropriate. # Rather, a hierarchy will be defined, and the existence # of certain pieces of information in the string will be validated. x = Ele() y = Ele() z = Ele() alt1 = Ele() alt2 = Ele() alt1.first = 'alt1-first' alt1.second = 'alt1-second' alt2.first = 'alt2-first' alt2.second = 'alt2-second' altprime = Ele() altprime.first = 'alt2-ancestor' alt2.ancestor = altprime z.ancestor = y y.ancestor = x z.secondary = alt1 y.secondary = alt2 x.first = 'unique1' y.first = 'unique2' x.second = 'unique3' z.second = 'unique4' y.label = 'DIFF' data = element.format_element(z) self.assertTrue(x.first in data) self.assertTrue(y.first in data) self.assertTrue(x.second in data) self.assertTrue(z.second in data) self.assertTrue('DIFF' in data) self.assertTrue('alt1-first' in data) self.assertTrue('alt2-first' in data) self.assertTrue('alt1-second' in data) self.assertTrue('alt2-second' in data) self.assertTrue('alt2-ancestor' in data) x.ancestor = z self.assertRaises(element.RecursiveFactor, element.format_element, z) class test_itertools(unittest.TestCase): def testInterlace(self): i1 = range(0, 100, 4) i2 = range(1, 100, 4) i3 = range(2, 100, 4) i4 = range(3, 100, 4) self.assertEqual( list(itertools.interlace(i1, i2, i3, i4)), list(range(100)) ) class test_functools(unittest.TestCase): def testComposition(self): compose = functools.Composition simple = compose((int, str)) self.assertEqual("100", simple("100")) timesfour_fourtimes = compose((methodcaller('__mul__', 4),)*4) self.assertEqual(4*(4*4*4*4), timesfour_fourtimes(4)) nothing = compose(()) self.assertEqual(nothing("100"), "100") self.assertEqual(nothing(100), 100) self.assertEqual(nothing(None), None) def testRSetAttr(self): class anob(object): pass ob = anob() self.assertRaises(AttributeError, getattr, ob, 'foo') rob = functools.rsetattr('foo', 'bar', ob) self.assertTrue(rob is ob) self.assertTrue(rob.foo is ob.foo) self.assertTrue(rob.foo == 'bar') class test_socket(unittest.TestCase): def testFindAvailable(self): # the port is randomly generated, so make a few trials before # determining success. 
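# A usage sketch, the same pattern the protocol tests employ (assumes
# nothing else claims the port in the interim):
# port = find_available_port()
# server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# server.bind(('localhost', port))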
for i in range(100): portnum = find_available_port() s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) try: s.connect(('localhost', portnum)) except socket.error as err: self.assertEqual(err.errno, errno.ECONNREFUSED) else: self.fail("got a connection to an available port: " + str(portnum)) finally: s.close() def join_sized_data(*data, packL = struct.Struct("!L").pack, getlen = lambda x: len(x) if x is not None else 0xFFFFFFFF ): return b''.join(interlace(map(packL, map(getlen, data)), (x if x is not None else b'' for x in data))) class test_structlib(unittest.TestCase): def testSizedSplit(self): sample = [ (b'foo', b'bar'), (b'foo', None, b'bar'), (b'foo', None, b'bar'), (b'foo', b'bar'), (), (None,None,None), (b'x', None,None,None, b'yz'), ] packed_sample = [join_sized_data(*x) for x in sample] self.assertRaises(ValueError, split_sized_data(b'\xFF\xFF\xFF\x01foo').__next__) self.assertEqual(sample, [tuple(split_sized_data(x)) for x in packed_sample]) if __name__ == '__main__': from types import ModuleType this = ModuleType("this") this.__dict__.update(globals()) unittest.main(this) fe-1.1.0/postgresql/test/test_ssl_connect.py000066400000000000000000000214711203372773200212230ustar00rootroot00000000000000## # .test.test_ssl_connect ## import sys import os import unittest from .. import exceptions as pg_exc from .. import driver as pg_driver from ..driver import dbapi20 from . import test_connect server_key = """ -----BEGIN RSA PRIVATE KEY----- MIICXAIBAAKBgQCy8veVaqL6MZVT8o0j98ggZYfibGwSN4XGC4rfineA2QZhi8t+ zrzfOS10vLXKtgiIpevHeQbDlrqFDPUDowozurg+jfro2L1jzQjZPdgqOUs+YjKh EO0Ya7NORO7ZgBx8WveXq30k4l8DK41jvpxRyBb9aqNWG4cB7fJqVTwZrwIDAQAB AoGAJ74URGfheEVoz7MPq4xNMvy5mAzSV51jJV/M4OakscYBR8q/UBNkGQNe2A1N Jo8VCBwpaCy11txz4jbFd6BPFFykgXleuRvMxoTv1qV0dZZ0X0ESNEAnjoHtjin/ 25mxsZTR6ucejHqXD9qE9NvFQ+wLv6Xo5rgDpx0onvgLA3kCQQDn4GeMkCfPZCve lDUK+TpJnLYupyElZiidoFMITlFo5WoWNJror2W42A5TD9sZ23pGSxw7ypiWIF4f ukGT5ZSzAkEAxZDwUUhgtoJIK7E9sCJM4AvcjDxGjslbUI/SmQTT+aTNCAmcIRrl kq3WMkPjxi/QFEdkIpPsV9Kc94oQ/8b9FQJBAKHxRQCTsWoTsNvbsIwAcif1Lfu5 N9oR1i34SeVUJWFYUFY/2SzHSwjkxGRYf5I4idZMIOTVYun+ox4PjDtJrScCQEQ4 RiNrIKok1pLvwuNdFLqQnfl2ns6TTQrGfuwDtMaRV5Mc7mKoDPnXOQ1mT/KRdAJs nHEsLwIsYbNAY5pOtfkCQDOy2Ffe7Z1YzFZXCTzpcq4mvMOPEUqlIX6hACNJGhgt 1EpruPwqR2PYDOIC4sXCaSogL8YyjI+Jlhm5kEJ4GaU= -----END RSA PRIVATE KEY----- """ server_crt = """ Certificate: Data: Version: 3 (0x2) Serial Number: a1:02:62:34:22:0d:45:6a Signature Algorithm: md5WithRSAEncryption Issuer: C=US, ST=Arizona, L=Nowhere, O=ACME Inc, OU=Test Division, CN=test.python.projects.postgresql.org Validity Not Before: Feb 18 15:52:20 2009 GMT Not After : Mar 20 15:52:20 2009 GMT Subject: C=US, ST=Arizona, L=Nowhere, O=ACME Inc, OU=Test Division, CN=test.python.projects.postgresql.org Subject Public Key Info: Public Key Algorithm: rsaEncryption RSA Public Key: (1024 bit) Modulus (1024 bit): 00:b2:f2:f7:95:6a:a2:fa:31:95:53:f2:8d:23:f7: c8:20:65:87:e2:6c:6c:12:37:85:c6:0b:8a:df:8a: 77:80:d9:06:61:8b:cb:7e:ce:bc:df:39:2d:74:bc: b5:ca:b6:08:88:a5:eb:c7:79:06:c3:96:ba:85:0c: f5:03:a3:0a:33:ba:b8:3e:8d:fa:e8:d8:bd:63:cd: 08:d9:3d:d8:2a:39:4b:3e:62:32:a1:10:ed:18:6b: b3:4e:44:ee:d9:80:1c:7c:5a:f7:97:ab:7d:24:e2: 5f:03:2b:8d:63:be:9c:51:c8:16:fd:6a:a3:56:1b: 87:01:ed:f2:6a:55:3c:19:af Exponent: 65537 (0x10001) X509v3 extensions: X509v3 Subject Key Identifier: 4B:2F:4F:1A:43:75:43:DC:26:59:89:48:56:73:BB:D0:AA:95:E8:60 X509v3 Authority Key Identifier: keyid:4B:2F:4F:1A:43:75:43:DC:26:59:89:48:56:73:BB:D0:AA:95:E8:60 DirName:/C=US/ST=Arizona/L=Nowhere/O=ACME Inc/OU=Test 
Division/CN=test.python.projects.postgresql.org serial:A1:02:62:34:22:0D:45:6A X509v3 Basic Constraints: CA:TRUE Signature Algorithm: md5WithRSAEncryption 24:ee:20:0f:b5:86:08:d6:3c:8f:d4:8d:16:fd:ac:e8:49:77: 86:74:7d:b8:f3:15:51:1d:d8:65:17:5e:a8:58:aa:b0:f6:68: 45:cb:77:9d:9f:21:81:e3:5e:86:1c:64:31:39:b6:29:5f:f1: ec:b1:33:45:1f:0c:54:16:26:11:af:e2:23:1b:a6:03:46:9b: 0e:63:ce:2c:02:41:26:93:bc:6f:6e:08:7e:95:b7:7a:f9:3a: 5a:bd:47:4c:92:ce:ea:09:75:de:3d:bb:30:51:a0:c5:f1:5d: 33:5f:c0:37:75:53:4e:6c:b4:3b:b1:a5:1b:fd:59:19:07:18: 22:6a -----BEGIN CERTIFICATE----- MIIDhzCCAvCgAwIBAgIJAKECYjQiDUVqMA0GCSqGSIb3DQEBBAUAMIGKMQswCQYD VQQGEwJVUzEQMA4GA1UECBMHQXJpem9uYTEQMA4GA1UEBxMHTm93aGVyZTERMA8G A1UEChMIQUNNRSBJbmMxFjAUBgNVBAsTDVRlc3QgRGl2aXNpb24xLDAqBgNVBAMT I3Rlc3QucHl0aG9uLnByb2plY3RzLnBvc3RncmVzcWwub3JnMB4XDTA5MDIxODE1 NTIyMFoXDTA5MDMyMDE1NTIyMFowgYoxCzAJBgNVBAYTAlVTMRAwDgYDVQQIEwdB cml6b25hMRAwDgYDVQQHEwdOb3doZXJlMREwDwYDVQQKEwhBQ01FIEluYzEWMBQG A1UECxMNVGVzdCBEaXZpc2lvbjEsMCoGA1UEAxMjdGVzdC5weXRob24ucHJvamVj dHMucG9zdGdyZXNxbC5vcmcwgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBALLy 95VqovoxlVPyjSP3yCBlh+JsbBI3hcYLit+Kd4DZBmGLy37OvN85LXS8tcq2CIil 68d5BsOWuoUM9QOjCjO6uD6N+ujYvWPNCNk92Co5Sz5iMqEQ7Rhrs05E7tmAHHxa 95erfSTiXwMrjWO+nFHIFv1qo1YbhwHt8mpVPBmvAgMBAAGjgfIwge8wHQYDVR0O BBYEFEsvTxpDdUPcJlmJSFZzu9CqlehgMIG/BgNVHSMEgbcwgbSAFEsvTxpDdUPc JlmJSFZzu9CqlehgoYGQpIGNMIGKMQswCQYDVQQGEwJVUzEQMA4GA1UECBMHQXJp em9uYTEQMA4GA1UEBxMHTm93aGVyZTERMA8GA1UEChMIQUNNRSBJbmMxFjAUBgNV BAsTDVRlc3QgRGl2aXNpb24xLDAqBgNVBAMTI3Rlc3QucHl0aG9uLnByb2plY3Rz LnBvc3RncmVzcWwub3JnggkAoQJiNCINRWowDAYDVR0TBAUwAwEB/zANBgkqhkiG 9w0BAQQFAAOBgQAk7iAPtYYI1jyP1I0W/azoSXeGdH248xVRHdhlF16oWKqw9mhF y3ednyGB416GHGQxObYpX/HssTNFHwxUFiYRr+IjG6YDRpsOY84sAkEmk7xvbgh+ lbd6+TpavUdMks7qCXXePbswUaDF8V0zX8A3dVNObLQ7saUb/VkZBxgiag== -----END CERTIFICATE----- """ class test_ssl_connect(test_connect.test_connect): """ Run test_connect, but with SSL. """ params = {'sslmode' : 'require'} cluster_path_suffix = '_test_ssl_connect' def configure_cluster(self): super().configure_cluster() self.cluster.settings['ssl'] = 'on' with open(self.cluster.hba_file, 'a') as hba: hba.writelines([ # nossl user "\n", "hostnossl test nossl 0::0/0 trust\n", "hostnossl test nossl 0.0.0.0/0 trust\n", # ssl-only user "hostssl test sslonly 0.0.0.0/0 trust\n", "hostssl test sslonly 0::0/0 trust\n", ]) key_file = os.path.join(self.cluster.data_directory, 'server.key') crt_file = os.path.join(self.cluster.data_directory, 'server.crt') with open(key_file, 'w') as key: key.write(server_key) with open(crt_file, 'w') as crt: crt.write(server_crt) os.chmod(key_file, 0o700) os.chmod(crt_file, 0o700) def initialize_database(self): super().initialize_database() with self.cluster.connection(user = 'test') as db: db.execute( """ CREATE USER nossl; CREATE USER sslonly; """ ) def test_ssl_mode_require(self): host, port = self.cluster.address() params = dict(self.params) params['sslmode'] = 'require' try: pg_driver.connect( user = 'nossl', database = 'test', host = host, port = port, **params ) self.fail("successful connection to nossl user when sslmode = 'require'") except pg_exc.ClientCannotConnectError as err: for pq in err.database.failures: x = pq.error dossl = pq.ssl_negotiation if isinstance(x, pg_exc.AuthenticationSpecificationError) and dossl is True: break else: # let it show as a failure. 
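# (Re-raising below surfaces any failure that is not the expected
# AuthenticationSpecificationError-over-SSL shape. The four sslmode
# tests in this class mirror libpq's semantics: 'require' and 'disable'
# hard-fail on the wrong transport, 'prefer' and 'allow' fall back.)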
raise with pg_driver.connect( host = host, port = port, user = 'sslonly', database = 'test', **params ) as c: self.assertEqual(c.prepare('select 1').first(), 1) self.assertEqual(c.security, 'ssl') def test_ssl_mode_disable(self): host, port = self.cluster.address() params = dict(self.params) params['sslmode'] = 'disable' try: pg_driver.connect( user = 'sslonly', database = 'test', host = host, port = port, **params ) self.fail("successful connection to sslonly user with sslmode = 'disable'") except pg_exc.ClientCannotConnectError as err: for pq in err.database.failures: x = pq.error if isinstance(x, pg_exc.AuthenticationSpecificationError) and not hasattr(pq, 'ssl_negotiation'): # looking for an authspec error... break else: # let it show as a failure. raise with pg_driver.connect( host = host, port = port, user = 'nossl', database = 'test', **params ) as c: self.assertEqual(c.prepare('select 1').first(), 1) self.assertEqual(c.security, None) def test_ssl_mode_prefer(self): host, port = self.cluster.address() params = dict(self.params) params['sslmode'] = 'prefer' with pg_driver.connect( user = 'sslonly', host = host, port = port, database = 'test', **params ) as c: self.assertEqual(c.prepare('select 1').first(), 1) self.assertEqual(c.security, 'ssl') with pg_driver.connect( user = 'test', host = host, port = port, database = 'test', **params ) as c: self.assertEqual(c.security, 'ssl') with pg_driver.connect( user = 'nossl', host = host, port = port, database = 'test', **params ) as c: self.assertEqual(c.prepare('select 1').first(), 1) self.assertEqual(c.security, None) def test_ssl_mode_allow(self): host, port = self.cluster.address() params = dict(self.params) params['sslmode'] = 'allow' # nossl user (hostnossl) with pg_driver.connect( user = 'nossl', database = 'test', host = host, port = port, **params ) as c: self.assertEqual(c.prepare('select 1').first(), 1) self.assertEqual(c.security, None) # test user (host) with pg_driver.connect( user = 'test', host = host, port = port, database = 'test', **params ) as c: self.assertEqual(c.security, None) # sslonly user (hostssl) with pg_driver.connect( user = 'sslonly', host = host, port = port, database = 'test', **params ) as c: self.assertEqual(c.prepare('select 1').first(), 1) self.assertEqual(c.security, 'ssl') if __name__ == '__main__': unittest.main() fe-1.1.0/postgresql/test/test_string.py000066400000000000000000000146311203372773200202170ustar00rootroot00000000000000## # .test.test_string ## import sys import os import unittest from .. 
import string as pg_str # strange possibility, split, normalized split_qname_samples = [ ('base', ['base'], 'base'), ('bASe', ['base'], 'base'), ('"base"', ['base'], 'base'), ('"base "', ['base '], '"base "'), ('" base"', [' base'], '" base"'), ('" base"""', [' base"'], '" base"""'), ('""" base"""', ['" base"'], '""" base"""'), ('".base"', ['.base'], '".base"'), ('".base."', ['.base.'], '".base."'), ('schema.base', ['schema', 'base'], 'schema.base'), ('"schema".base', ['schema', 'base'], 'schema.base'), ('schema."base"', ['schema', 'base'], 'schema.base'), ('"schema.base"', ['schema.base'], '"schema.base"'), ('schEmÅ."base"', ['schemå', 'base'], 'schemå.base'), ('scheMa."base"', ['schema', 'base'], 'schema.base'), ('sche_ma.base', ['sche_ma', 'base'], 'sche_ma.base'), ('_schema.base', ['_schema', 'base'], '_schema.base'), ('a000.b111', ['a000', 'b111'], 'a000.b111'), ('" schema"."base"', [' schema', 'base'], '" schema".base'), ('" schema"."ba se"', [' schema', 'ba se'], '" schema"."ba se"'), ('" ""schema"."ba""se"', [' "schema', 'ba"se'], '" ""schema"."ba""se"'), ('" schema" . "ba se"', [' schema', 'ba se'], '" schema"."ba se"'), (' " schema" . "ba se" ', [' schema', 'ba se'], '" schema"."ba se"'), (' ". schema." . "ba se" ', ['. schema.', 'ba se'], '". schema."."ba se"'), ('CAT . ". schema." . "ba se" ', ['cat', '. schema.', 'ba se'], 'cat.". schema."."ba se"'), ('"cat" . ". schema." . "ba se" ', ['cat', '. schema.', 'ba se'], 'cat.". schema."."ba se"'), ('"""cat" . ". schema." . "ba se" ', ['"cat', '. schema.', 'ba se'], '"""cat".". schema."."ba se"'), ('"""cÅt" . ". schema." . "ba se" ', ['"cÅt', '. schema.', 'ba se'], '"""cÅt".". schema."."ba se"'), ] split_samples = [ ('', ['']), ('one-to-one', ['one-to-one']), ('"one-to-one"', [ '', ('"', 'one-to-one'), '' ]), ('$$one-to-one$$', [ '', ('$$', 'one-to-one'), '' ]), ("E'one-to-one'", [ '', ("E'", 'one-to-one'), '' ]), ("E'on''e-to-one'", [ '', ("E'", "on''e-to-one"), '' ]), ("E'on''e-to-\\'one'", [ '', ("E'", "on''e-to-\\'one"), '' ]), ("'one\\'-to-one'", [ '', ("'", "one\\"), "-to-one", ("'", ''), ]), ('"foo"""', [ '', ('"', 'foo""'), '', ]), ('"""foo"', [ '', ('"', '""foo'), '', ]), ("'''foo'", [ '', ("'", "''foo"), '', ]), ("'foo'''", [ '', ("'", "foo''"), '', ]), ("E'foo\\''", [ '', ("E'", "foo\\'"), '', ]), (r"E'foo\\' '", [ '', ("E'", r"foo\\"), ' ', ("'", ''), ]), (r"E'foo\\'' '", [ '', ("E'", r"foo\\'' "), '', ]), ('select \'foo\' as "one"', [ 'select ', ("'", 'foo'), ' as ', ('"', 'one'), '' ]), ('select $$foo$$ as "one"', [ 'select ', ("$$", 'foo'), ' as ', ('"', 'one'), '' ]), ('select $b$foo$b$ as "one"', [ 'select ', ("$b$", 'foo'), ' as ', ('"', 'one'), '' ]), ('select $b$', [ 'select ', ('$b$', ''), ]), ('select $1', [ 'select $1', ]), ('select $1$', [ 'select $1$', ]), ] split_sql_samples = [ ('select 1; select 1', [ ['select 1'], [' select 1'] ]), ('select \'one\' as "text"; select 1', [ ['select ', ("'", 'one'), ' as ', ('"', 'text'), ''], [' select 1'] ]), ('select \'one\' as "text"; select 1', [ ['select ', ("'", 'one'), ' as ', ('"', 'text'), ''], [' select 1'] ]), ('select \'one;\' as ";text;"; select 1; foo', [ ['select ', ("'", 'one;'), ' as ', ('"', ';text;'), ''], (' select 1',), [' foo'], ]), ('select \'one;\' as ";text;"; select $$;$$; foo', [ ['select ', ("'", 'one;'), ' as ', ('"', ';text;'), ''], [' select ', ('$$', ';'), ''], [' foo'], ]), ('select \'one;\' as ";text;"; select $$;$$; foo;\';b\'\'ar\'', [ ['select ', ("'", 'one;'), ' as ', ('"', ';text;'), ''], [' select ', ('$$', ';'), ''], (' foo',), 
['', ("'", ";b''ar"), ''], ]), ] class test_strings(unittest.TestCase): def test_split(self): for unsplit, split in split_samples: xsplit = list(pg_str.split(unsplit)) self.assertEqual(xsplit, split) self.assertEqual(pg_str.unsplit(xsplit), unsplit) def test_split_sql(self): for unsplit, split in split_sql_samples: xsplit = list(pg_str.split_sql(unsplit)) self.assertEqual(xsplit, split) self.assertEqual(';'.join([pg_str.unsplit(x) for x in xsplit]), unsplit) def test_qname(self): "indirectly tests split_using" for unsplit, split, norm in split_qname_samples: xsplit = pg_str.split_qname(unsplit) self.assertEqual(xsplit, split) self.assertEqual(pg_str.qname_if_needed(*split), norm) self.assertRaises( ValueError, pg_str.split_qname, '"foo' ) self.assertRaises( ValueError, pg_str.split_qname, 'foo"' ) self.assertRaises( ValueError, pg_str.split_qname, 'bar.foo"' ) self.assertRaises( ValueError, pg_str.split_qname, 'bar".foo"' ) self.assertRaises( ValueError, pg_str.split_qname, '0bar.foo' ) self.assertRaises( ValueError, pg_str.split_qname, 'bar.fo@' ) def test_quotes(self): self.assertEqual( pg_str.quote_literal("""foo'bar"""), """'foo''bar'""" ) self.assertEqual( pg_str.quote_literal("""\\foo'bar\\"""), """'\\foo''bar\\'""" ) self.assertEqual( pg_str.quote_ident_if_needed("foo"), "foo" ) self.assertEqual( pg_str.quote_ident_if_needed("0foo"), '"0foo"' ) self.assertEqual( pg_str.quote_ident_if_needed("foo0"), 'foo0' ) self.assertEqual( pg_str.quote_ident_if_needed("_"), '_' ) self.assertEqual( pg_str.quote_ident_if_needed("_9"), '_9' ) self.assertEqual( pg_str.quote_ident_if_needed('''\\foo'bar\\'''), '''"\\foo'bar\\"''' ) self.assertEqual( pg_str.quote_ident("spam"), '"spam"' ) self.assertEqual( pg_str.qname("spam", "ham"), '"spam"."ham"' ) self.assertEqual( pg_str.escape_ident('"'), '""', ) self.assertEqual( pg_str.escape_ident('""'), '""""', ) chars = ''.join([ chr(x) for x in range(10000) if chr(x) != '"' ]) self.assertEqual( pg_str.escape_ident(chars), chars, ) chars = ''.join([ chr(x) for x in range(10000) if chr(x) != "'" ]) self.assertEqual( pg_str.escape_literal(chars), chars, ) chars = ''.join([ chr(x) for x in range(10000) if chr(x) not in "\\'" ]) self.assertEqual( pg_str.escape_literal(chars), chars, ) if __name__ == '__main__': from types import ModuleType this = ModuleType("this") this.__dict__.update(globals()) unittest.main(this) fe-1.1.0/postgresql/test/test_types.py000066400000000000000000000453241203372773200200600ustar00rootroot00000000000000## # .test.test_types - test type representations and I/O ## import unittest import struct from ..python.functools import process_tuple from .. 
import types as pg_types from ..types.io import lib as typlib from ..types.io import builtins from ..types.io.contrib_hstore import hstore_factory from ..types import Array class fake_typio(object): @staticmethod def encode(x): return x.encode('utf-8') @staticmethod def decode(x): return x.decode('utf-8') hstore_pack, hstore_unpack = hstore_factory(0, fake_typio) # this must pack to that, and # that must unpack to this expectation_samples = { ('bool', lambda x: builtins.bool_pack(x), lambda x: builtins.bool_unpack(x)) : [ (True, b'\x01'), (False, b'\x00'), ], ('int2', builtins.int2_pack, builtins.int2_unpack) : [ (0, b'\x00\x00'), (1, b'\x00\x01'), (2, b'\x00\x02'), (0x0f, b'\x00\x0f'), (0xf00, b'\x0f\x00'), (0x7fff, b'\x7f\xff'), (-0x8000, b'\x80\x00'), (-1, b'\xff\xff'), (-2, b'\xff\xfe'), (-3, b'\xff\xfd'), ], ('int4', builtins.int4_pack, builtins.int4_unpack) : [ (0, b'\x00\x00\x00\x00'), (1, b'\x00\x00\x00\x01'), (2, b'\x00\x00\x00\x02'), (0x0f, b'\x00\x00\x00\x0f'), (0x7fff, b'\x00\x00\x7f\xff'), (-0x8000, b'\xff\xff\x80\x00'), (0x7fffffff, b'\x7f\xff\xff\xff'), (-0x80000000, b'\x80\x00\x00\x00'), (-1, b'\xff\xff\xff\xff'), (-2, b'\xff\xff\xff\xfe'), (-3, b'\xff\xff\xff\xfd'), ], ('int8', builtins.int8_pack, builtins.int8_unpack) : [ (0, b'\x00\x00\x00\x00\x00\x00\x00\x00'), (1, b'\x00\x00\x00\x00\x00\x00\x00\x01'), (2, b'\x00\x00\x00\x00\x00\x00\x00\x02'), (0x0f, b'\x00\x00\x00\x00\x00\x00\x00\x0f'), (0x7fffffff, b'\x00\x00\x00\x00\x7f\xff\xff\xff'), (0x80000000, b'\x00\x00\x00\x00\x80\x00\x00\x00'), (-0x80000000, b'\xff\xff\xff\xff\x80\x00\x00\x00'), (-1, b'\xff\xff\xff\xff\xff\xff\xff\xff'), (-2, b'\xff\xff\xff\xff\xff\xff\xff\xfe'), (-3, b'\xff\xff\xff\xff\xff\xff\xff\xfd'), ], ('numeric', typlib.numeric_pack, typlib.numeric_unpack) : [ (((0,0,0,0),[]), b'\x00'*2*4), (((0,0,0,0),[1]), b'\x00'*2*4 + b'\x00\x01'), (((1,0,0,0),[1]), b'\x00\x01' + b'\x00'*2*3 + b'\x00\x01'), (((1,1,1,1),[1]), b'\x00\x01'*4 + b'\x00\x01'), (((1,1,1,1),[1,2]), b'\x00\x01'*4 + b'\x00\x01\x00\x02'), (((1,1,1,1),[1,2,3]), b'\x00\x01'*4 + b'\x00\x01\x00\x02\x00\x03'), ], ('varbit', typlib.varbit_pack, typlib.varbit_unpack) : [ ((0, b'\x00'), b'\x00\x00\x00\x00\x00'), ((1, b'\x01'), b'\x00\x00\x00\x01\x01'), ((1, b'\x00'), b'\x00\x00\x00\x01\x00'), ((2, b'\x00'), b'\x00\x00\x00\x02\x00'), ((3, b'\x00'), b'\x00\x00\x00\x03\x00'), ((9, b'\x00\x00'), b'\x00\x00\x00\x09\x00\x00'), # More data than necessary, we allow this. # Let the user do the necessary check if the cost is worth the benefit. 
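# varbit wire form, as these pairs imply: a !L bit count followed by the
# left-aligned bit data, so (9, b'\x00\x00') is nine bits spanning two
# bytes.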
((9, b'\x00\x00\x00'), b'\x00\x00\x00\x09\x00\x00\x00'), ], # idk why ('bytea', builtins.bytea_pack, builtins.bytea_unpack) : [ (b'foo', b'foo'), (b'bar', b'bar'), (b'\x00', b'\x00'), (b'\x01', b'\x01'), ], ('char', builtins.char_pack, builtins.char_unpack) : [ (b'a', b'a'), (b'b', b'b'), (b'\x00', b'\x00'), ], ('point', typlib.point_pack, typlib.point_unpack) : [ ((1.0, 1.0), b'?\xf0\x00\x00\x00\x00\x00\x00?\xf0\x00\x00\x00\x00\x00\x00'), ((2.0, 2.0), b'@\x00\x00\x00\x00\x00\x00\x00@\x00\x00\x00\x00\x00\x00\x00'), ((-1.0, -1.0), b'\xbf\xf0\x00\x00\x00\x00\x00\x00\xbf\xf0\x00\x00\x00\x00\x00\x00'), ], ('circle', typlib.circle_pack, typlib.circle_unpack) : [ ((1.0, 1.0, 1.0), b'?\xf0\x00\x00\x00\x00\x00\x00?\xf0\x00\x00' \ b'\x00\x00\x00\x00?\xf0\x00\x00\x00\x00\x00\x00'), ((2.0, 2.0, 2.0), b'@\x00\x00\x00\x00\x00\x00\x00@\x00\x00\x00' \ b'\x00\x00\x00\x00@\x00\x00\x00\x00\x00\x00\x00'), ], ('record', typlib.record_pack, typlib.record_unpack) : [ ([], b'\x00\x00\x00\x00'), ([(0,b'foo')], b'\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03foo'), ([(0,None)], b'\x00\x00\x00\x01\x00\x00\x00\x00\xff\xff\xff\xff'), ([(15,None)], b'\x00\x00\x00\x01\x00\x00\x00\x0f\xff\xff\xff\xff'), ([(0xffffffff,None)], b'\x00\x00\x00\x01\xff\xff\xff\xff\xff\xff\xff\xff'), ([(0,None), (1,b'some')], b'\x00\x00\x00\x02\x00\x00\x00\x00\xff\xff\xff\xff' \ b'\x00\x00\x00\x01\x00\x00\x00\x04some'), ], ('array', typlib.array_pack, typlib.array_unpack) : [ ([0, 0xf, (1,), (0,), (b'foo',)], b'\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x0f\x00\x00\x00\x01' \ b'\x00\x00\x00\x00\x00\x00\x00\x03foo' ), ([0, 0xf, (1,), (0,), (None,)], b'\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x0f\x00\x00\x00\x01' \ b'\x00\x00\x00\x00\xff\xff\xff\xff' ) ], ('hstore', hstore_pack, hstore_unpack) : [ ({}, b'\x00\x00\x00\x00'), ({'b' : None}, b'\x00\x00\x00\x01\x00\x00\x00\x01b\xff\xff\xff\xff'), ({'b' : 'k'}, b'\x00\x00\x00\x01\x00\x00\x00\x01b\x00\x00\x00\x01k'), ({'foo' : 'bar'}, b'\x00\x00\x00\x01\x00\x00\x00\x03foo\x00\x00\x00\x03bar'), ({'foo' : None}, b'\x00\x00\x00\x01\x00\x00\x00\x03foo\xff\xff\xff\xff'), ], } expectation_samples[('box', typlib.box_pack, typlib.box_unpack)] = \ expectation_samples[('lseg', typlib.lseg_pack, typlib.lseg_unpack)] = [ ((1.0, 1.0, 1.0, 1.0), b'?\xf0\x00\x00\x00\x00\x00\x00?\xf0' \ b'\x00\x00\x00\x00\x00\x00?\xf0\x00\x00' \ b'\x00\x00\x00\x00?\xf0\x00\x00\x00\x00\x00\x00'), ((2.0, 2.0, 1.0, 1.0), b'@\x00\x00\x00\x00\x00\x00\x00@\x00\x00' \ b'\x00\x00\x00\x00\x00?\xf0\x00\x00\x00\x00' \ b'\x00\x00?\xf0\x00\x00\x00\x00\x00\x00'), ((-1.0, -1.0, 1.0, 1.0), b'\xbf\xf0\x00\x00\x00\x00\x00\x00\xbf\xf0' \ b'\x00\x00\x00\x00\x00\x00?\xf0\x00\x00\x00' \ b'\x00\x00\x00?\xf0\x00\x00\x00\x00\x00\x00'), ] expectation_samples[('oid', typlib.oid_pack, typlib.oid_unpack)] = \ expectation_samples[('cid', typlib.cid_pack, typlib.cid_unpack)] = \ expectation_samples[('xid', typlib.xid_pack, typlib.xid_unpack)] = [ (0, b'\x00\x00\x00\x00'), (1, b'\x00\x00\x00\x01'), (2, b'\x00\x00\x00\x02'), (0xf, b'\x00\x00\x00\x0f'), (0xffffffff, b'\xff\xff\xff\xff'), (0x7fffffff, b'\x7f\xff\xff\xff'), ] # this must pack and then unpack back into this consistency_samples = { ('bool', lambda x: builtins.bool_pack(x), lambda x: builtins.bool_unpack(x)) : [True, False], ('record', typlib.record_pack, typlib.record_unpack) : [ [], [(0,b'foo')], [(0,None)], [(15,None)], [(0xffffffff,None)], [(0,None), (1,b'some')], [(0,None), (1,b'some'), (0xffff, b"something_else\x00")], [(0,None), (1,b"s\x00me"), (0xffff, b"\x00something_else\x00")], ], 
('array', typlib.array_pack, typlib.array_unpack) : [ [0, 0xf, (), (), ()], [0, 0xf, (0,), (0,), ()], [0, 0xf, (1,), (0,), (b'foo',)], [0, 0xf, (1,), (0,), (None,)], [0, 0xf, (2,), (0,), (None,None)], [0, 0xf, (2,), (0,), (b'foo',None)], [0, 0xff, (2,), (0,), (None,b'foo',)], [0, 0xffffffff, (3,), (0,), (None,b'foo',None)], [1, 0xffffffff, (3,), (0,), (None,b'foo',None)], [1, 0xffffffff, (3, 1), (0, 0), (None,b'foo',None)], [1, 0xffffffff, (3, 2), (0, 0), (None,b'one',b'foo',b'two',None,b'three')], ], # Just some random data; it's just an integer, so nothing fancy. ('date', typlib.date_pack, typlib.date_unpack) : [ 123, 321, 0x7FFFFFF, -0x8000000, ], ('point', typlib.point_pack, typlib.point_unpack) : [ (0, 0), (2, 2), (-1, -1), (-1.5, -1.2), (1.5, 1.2), ], ('circle', typlib.circle_pack, typlib.circle_unpack) : [ (0, 0, 0), (2, 2, 2), (-1, -1, -1), (-1.5, -1.2, -1.8), ], ('tid', typlib.tid_pack, typlib.tid_unpack) : [ (0, 0), (1, 1), (0xffffffff, 0xffff), (0, 0xffff), (0xffffffff, 0), (0xffffffff // 2, 0xffff // 2), ], } __ = { ('cidr', typlib.net_pack, typlib.net_unpack) : [ (0, 0, b"\x00\x00\x00\x00"), (2, 0, b"\x00" * 4), (2, 0, b"\xFF" * 4), (2, 32, b"\xFF" * 4), (3, 0, b"\x00\x00" * 16), ], ('inet', typlib.net_pack, typlib.net_unpack) : [ (2, 32, b"\x00\x00\x00\x00"), (2, 16, b"\x7f\x00\x00\x01"), (2, 8, b"\xff\x00\xff\x01"), (3, 128, b"\x7f\x00" * 16), (3, 64, b"\xff\xff" * 16), (3, 32, b"\x00\x00" * 16), ], } consistency_samples[('time', typlib.time_pack, typlib.time_unpack)] = \ consistency_samples[('time64', typlib.time64_pack, typlib.time64_unpack)] = [ (0, 0), (123, 123), (0xFFFFFFFF, 999999), ] # months, days, (seconds, microseconds) consistency_samples[('interval', typlib.interval_pack, typlib.interval_unpack)] = [ (0, 0, (0, 0)), (1, 0, (0, 0)), (0, 1, (0, 0)), (1, 1, (0, 0)), (0, 0, (0, 10000)), (0, 0, (1, 0)), (0, 0, (1, 10000)), (1, 1, (1, 10000)), (100, 50, (1423, 29313)) ] consistency_samples[('timetz', typlib.timetz_pack, typlib.timetz_unpack)] = \ consistency_samples[('timetz', typlib.timetz64_pack, typlib.timetz64_unpack)] = \ [ ((0, 0), 0), ((123, 123), 123), ((0xFFFFFFFF, 999999), -123), ] consistency_samples[('oid', typlib.oid_pack, typlib.oid_unpack)] = \ consistency_samples[('cid', typlib.cid_pack, typlib.cid_unpack)] = \ consistency_samples[('xid', typlib.xid_pack, typlib.xid_unpack)] = [ 0, 0xffffffff, 0xffffffff // 2, 123, 321, 1, 2, 3 ] consistency_samples[('lseg', typlib.lseg_pack, typlib.lseg_unpack)] = \ consistency_samples[('box', typlib.box_pack, typlib.box_unpack)] = [ (1,2,3,4), (4,3,2,1), (0,0,0,0), (-1,-1,-1,-1), (-1.2,-1.5,-2.0,4.0) ] consistency_samples[('path', typlib.path_pack, typlib.path_unpack)] = \ consistency_samples[('polygon', typlib.polygon_pack, typlib.polygon_unpack)] = [ (1,2,3,4), (4,3,2,1), (0,0,0,0), (-1,-1,-1,-1), (-1.2,-1.5,-2.0,4.0), ] from types import GeneratorType def resolve(ob): 'make sure generators get "tuplified"' if type(ob) not in (list, tuple, GeneratorType): return ob return [resolve(x) for x in ob] def testExpectIO(self, samples): for id, sample in samples.items(): name, pack, unpack = id for (sample_unpacked, sample_packed) in sample: pack_trial = pack(sample_unpacked) self.assertTrue( pack_trial == sample_packed, "%s sample: unpacked sample, %r, did not match " \ "%r when packed, rather, %r" %( name, sample_unpacked, sample_packed, pack_trial ) ) sample_unpacked = resolve(sample_unpacked) unpack_trial = resolve(unpack(sample_packed)) self.assertTrue( unpack_trial == sample_unpacked, "%s sample: packed sample, %r, 
did not match " \ "%r when unpacked, rather, %r" %( name, sample_packed, sample_unpacked, unpack_trial ) ) class test_io(unittest.TestCase): def test_process_tuple(self): def funpass(cause, procs, tup, col): pass self.assertEqual(tuple(process_tuple((),(), funpass)), ()) self.assertEqual(tuple(process_tuple((int,),("100",), funpass)), (100,)) self.assertEqual(tuple(process_tuple((int,int),("100","200"), funpass)), (100,200)) self.assertEqual(tuple(process_tuple((int,int),(None,"200"), funpass)), (None,200)) self.assertEqual(tuple(process_tuple((int,int,int),(None,None,"200"), funpass)), (None,None,200)) # The exception handler must raise. self.assertRaises(RuntimeError, process_tuple, (int,), ("foo",), funpass) class ThisError(Exception): pass data = [] def funraise(cause, procs, tup, col): data.append((procs, tup, col)) raise ThisError from cause self.assertRaises(ThisError, process_tuple, (int,), ("foo",), funraise) self.assertEqual(data[0], ((int,), ("foo",), 0)) del data[0] self.assertRaises(ThisError, process_tuple, (int,int), ("100","bar"), funraise) self.assertEqual(data[0], ((int,int), ("100","bar"), 1)) def testExpectations(self): 'IO tests where the pre-made expected serialized form is compared' testExpectIO(self, expectation_samples) def testConsistency(self): 'IO tests where the unpacked source is compared to re-unpacked result' for id, sample in consistency_samples.items(): name, pack, unpack = id if pack is not None: for x in sample: packed = pack(x) unpacked = resolve(unpack(packed)) x = resolve(x) self.assertTrue(x == unpacked, "inconsistency with %s, %r -> %r -> %r" %( name, x, packed, unpacked ) ) ## # Further hstore tests. def test_hstore(self): # Can't do some tests with the consistency checks # because we are not using ordered dictionaries. self.assertRaises((ValueError, struct.error), hstore_unpack, b'\x00\x00\x00\x00foo') self.assertRaises(ValueError, hstore_unpack, b'\x00\x00\x00\x01') self.assertRaises(ValueError, hstore_unpack, b'\x00\x00\x00\x02\x00\x00\x00\x01G\x00\x00\x00\x01G') sample = [ ([('foo','bar'),('k',None),('zero','heroes')], b'\x00\x00\x00\x03\x00\x00\x00\x03foo' + \ b'\x00\x00\x00\x03bar\x00\x00\x00\x01k\xFF\xFF\xFF\xFF' + \ b'\x00\x00\x00\x04zero\x00\x00\x00\x06heroes'), ([('foo',None),('k',None),('zero',None)], b'\x00\x00\x00\x03\x00\x00\x00\x03foo' + \ b'\xff\xff\xff\xff\x00\x00\x00\x01k\xFF\xFF\xFF\xFF' + \ b'\x00\x00\x00\x04zero\xFF\xFF\xFF\xFF'), ([], b'\x00\x00\x00\x00'), ] for x in sample: src, serialized = x self.assertEqual(hstore_pack(src), serialized) self.assertEqual(hstore_unpack(serialized), dict(src)) # Make some slices; used by testSlicing slice_samples = [ slice(0, None, x+1) for x in range(10) ] + [ slice(x, None, 1) for x in range(10) ] + [ slice(None, x, 1) for x in range(10) ] + [ slice(None, -x, 70) for x in range(10) ] + [ slice(x+1, x, -1) for x in range(10) ] + [ slice(x+4, x, -2) for x in range(10) ] class test_Array(unittest.TestCase): def emptyArray(self, a): self.assertEqual(len(a), 0) self.assertEqual(list(a.elements()), []) self.assertEqual(a.dimensions, ()) self.assertEqual(a.lowerbounds, ()) self.assertEqual(a.upperbounds, ()) self.assertRaises(IndexError, a.__getitem__, 0) def testArrayInstantiation(self): a = Array([]) self.emptyArray(a) # exercise default upper/lower a = Array((1,2,3,)) self.assertEqual((a[0],a[1],a[2]), (1,2,3,)) # Python interface, Python semantics. 
self.assertRaises(IndexError, a.__getitem__, 3) self.assertEqual(a.dimensions, (3,)) self.assertEqual(a.lowerbounds, (1,)) self.assertEqual(a.upperbounds, (3,)) def testNestedArrayInstantiation(self): a = Array(([1,2],[3,4])) # Python interface, Python semantics. self.assertRaises(IndexError, a.__getitem__, 3) self.assertEqual(a.dimensions, (2,2,)) self.assertEqual(a.lowerbounds, (1,1)) self.assertEqual(a.upperbounds, (2,2)) self.assertEqual(list(a.elements()), [1,2,3,4]) self.assertEqual(list(a), [ Array([1, 2]), Array([3, 4]), ] ) a = Array(([[1],[2]],[[3],[4]])) self.assertRaises(IndexError, a.__getitem__, 3) self.assertEqual(a.dimensions, (2,2,1)) self.assertEqual(a.lowerbounds, (1,1,1)) self.assertEqual(a.upperbounds, (2,2,1)) self.assertEqual(list(a), [ Array([[1], [2]]), Array([[3], [4]]), ] ) self.assertRaises(ValueError, Array, [ [1], [2,3] ]) self.assertRaises(ValueError, Array, [ [1], [] ]) self.assertRaises(ValueError, Array, [ [[1]], [[],2] ]) self.assertRaises(ValueError, Array, [ [[[[[1,2,3]]]]], [[[[[1,2,3]]]]], [[[[[1,2,3]]]]], [[[[[2,2]]]]], ]) def testSlicing(self): elements = [1,2,3,4,5,6,7,8] d1 = Array([1,2,3,4,5,6,7,8]) for x in slice_samples: self.assertEqual( d1[x], Array(elements[x]) ) elements = [[1,2],[3,4],[5,6],[7,8]] d2 = Array(elements) for x in slice_samples: self.assertEqual( d2[x], Array(elements[x]) ) elements = [ [[[1,2],[3,4]]], [[[5,6],[791,8]]], [[[1,2],[333,4]]], [[[1,2],[3,4]]], [[[5,10],[7,8]]], [[[0,6],[7,8]]], [[[1,2],[3,4]]], [[[5,6],[7,8]]], ] d3 = Array(elements) for x in slice_samples: self.assertEqual( d3[x], Array(elements[x]) ) def testFromElements(self): a = Array.from_elements(()) self.emptyArray(a) # exercise default upper/lower a = Array.from_elements((1,2,3,)) self.assertEqual((a[0],a[1],a[2]), (1,2,3,)) # Python interface, Python semantics. 
self.assertRaises(IndexError, a.__getitem__, 3) self.assertEqual(a.dimensions, (3,)) self.assertEqual(a.lowerbounds, (1,)) self.assertEqual(a.upperbounds, (3,)) # exercise default upper/lower a = Array.from_elements([3,2,1], lowerbounds = (2,), upperbounds = (4,)) self.assertEqual(a.dimensions, (3,)) self.assertEqual(a.lowerbounds, (2,)) self.assertEqual(a.upperbounds, (4,)) def testEmptyDimension(self): self.assertRaises(ValueError, Array, [[]] ) self.assertRaises(ValueError, Array, [[2],[]] ) self.assertRaises(ValueError, Array, [[],[],[]] ) self.assertRaises(ValueError, Array, [[2],[3],[]] ) def testExcessive(self): # lowerbounds too high for upperbounds self.assertRaises(ValueError, Array.from_elements, [1], lowerbounds = (2,), upperbounds = (1,) ) def testNegatives(self): a = Array.from_elements([0], lowerbounds = (-1,), upperbounds = (-1,)) self.assertEqual(a[0], 0) self.assertEqual(a[-1], 0) # upperbounds at zero a = Array.from_elements([1,2], lowerbounds = (-1,), upperbounds = (0,)) self.assertEqual(a[0], 1) self.assertEqual(a[1], 2) self.assertEqual(a[-2], 1) self.assertEqual(a[-1], 2) def testGetElement(self): a = Array([1,2,3,4]) self.assertEqual(a.get_element((0,)), 1) self.assertEqual(a.get_element((1,)), 2) self.assertEqual(a.get_element((2,)), 3) self.assertEqual(a.get_element((3,)), 4) self.assertEqual(a.get_element((-1,)), 4) self.assertEqual(a.get_element((-2,)), 3) self.assertEqual(a.get_element((-3,)), 2) self.assertEqual(a.get_element((-4,)), 1) self.assertRaises(IndexError, a.get_element, (4,)) a = Array([[1,2],[3,4]]) self.assertEqual(a.get_element((0,0)), 1) self.assertEqual(a.get_element((0,1,)), 2) self.assertEqual(a.get_element((1,0,)), 3) self.assertEqual(a.get_element((1,1,)), 4) self.assertEqual(a.get_element((-1,-1)), 4) self.assertEqual(a.get_element((-1,-2,)), 3) self.assertEqual(a.get_element((-2,-1,)), 2) self.assertEqual(a.get_element((-2,-2,)), 1) self.assertRaises(IndexError, a.get_element, (2,0)) self.assertRaises(IndexError, a.get_element, (1,2)) self.assertRaises(IndexError, a.get_element, (0,2)) def testSQLGetElement(self): a = Array([1,2,3,4]) self.assertEqual(a.sql_get_element((1,)), 1) self.assertEqual(a.sql_get_element((2,)), 2) self.assertEqual(a.sql_get_element((3,)), 3) self.assertEqual(a.sql_get_element((4,)), 4) self.assertEqual(a.sql_get_element((0,)), None) self.assertEqual(a.sql_get_element((5,)), None) self.assertEqual(a.sql_get_element((-1,)), None) self.assertEqual(a.sql_get_element((-2,)), None) self.assertEqual(a.sql_get_element((-3,)), None) self.assertEqual(a.sql_get_element((-4,)), None) a = Array([[1,2],[3,4]]) self.assertEqual(a.sql_get_element((1,1)), 1) self.assertEqual(a.sql_get_element((1,2,)), 2) self.assertEqual(a.sql_get_element((2,1,)), 3) self.assertEqual(a.sql_get_element((2,2,)), 4) self.assertEqual(a.sql_get_element((3,1)), None) self.assertEqual(a.sql_get_element((1,3)), None) if __name__ == '__main__': from types import ModuleType this = ModuleType("this") this.__dict__.update(globals()) unittest.main(this) fe-1.1.0/postgresql/test/testall.py000066400000000000000000000017021203372773200173150ustar00rootroot00000000000000## # .test.testall ## import unittest from sys import stderr from ..installation import default from .test_exceptions import * from .test_bytea_codec import * from .test_iri import * from .test_protocol import * from .test_configfile import * from .test_pgpassfile import * from .test_python import * from .test_installation import * from .test_cluster import * # These two require custom cluster 
configurations. from .test_connect import * # No SSL? cluster initialization will fail. if default().ssl: from .test_ssl_connect import * else: stderr.write("NOTICE: installation doesn't support SSL\n") try: from .test_optimized import * except ImportError: stderr.write("NOTICE: port.optimized could not be imported\n") from .test_driver import * from .test_alock import * from .test_notifyman import * from .test_copyman import * from .test_lib import * from .test_dbapi20 import * from .test_types import * if __name__ == '__main__': unittest.main() fe-1.1.0/postgresql/types/000077500000000000000000000000001203372773200154605ustar00rootroot00000000000000fe-1.1.0/postgresql/types/__init__.py000066400000000000000000000355041203372773200176000ustar00rootroot00000000000000## # types. - Package for I/O and PostgreSQL specific types. ## """ PostgreSQL types and identifiers. """ # XXX: Would be nicer to generate these from a header file... InvalidOid = 0 RECORDOID = 2249 BOOLOID = 16 BITOID = 1560 VARBITOID = 1562 ACLITEMOID = 1033 CHAROID = 18 NAMEOID = 19 TEXTOID = 25 BYTEAOID = 17 BPCHAROID = 1042 VARCHAROID = 1043 CSTRINGOID = 2275 UNKNOWNOID = 705 REFCURSOROID = 1790 UUIDOID = 2950 TSVECTOROID = 3614 GTSVECTOROID = 3642 TSQUERYOID = 3615 REGCONFIGOID = 3734 REGDICTIONARYOID = 3769 JSONOID = 114 XMLOID = 142 MACADDROID = 829 INETOID = 869 CIDROID = 650 TYPEOID = 71 PROCOID = 81 CLASSOID = 83 ATTRIBUTEOID = 75 DATEOID = 1082 TIMEOID = 1083 TIMESTAMPOID = 1114 TIMESTAMPTZOID = 1184 INTERVALOID = 1186 TIMETZOID = 1266 ABSTIMEOID = 702 RELTIMEOID = 703 TINTERVALOID = 704 INT8OID = 20 INT2OID = 21 INT4OID = 23 OIDOID = 26 TIDOID = 27 XIDOID = 28 CIDOID = 29 CASHOID = 790 FLOAT4OID = 700 FLOAT8OID = 701 NUMERICOID = 1700 POINTOID = 600 LINEOID = 628 LSEGOID = 601 PATHOID = 602 BOXOID = 603 POLYGONOID = 604 CIRCLEOID = 718 OIDVECTOROID = 30 INT2VECTOROID = 22 INT4ARRAYOID = 1007 REGPROCOID = 24 REGPROCEDUREOID = 2202 REGOPEROID = 2203 REGOPERATOROID = 2204 REGCLASSOID = 2205 REGTYPEOID = 2206 REGTYPEARRAYOID = 2211 TRIGGEROID = 2279 LANGUAGE_HANDLEROID = 2280 INTERNALOID = 2281 OPAQUEOID = 2282 VOIDOID = 2278 ANYARRAYOID = 2277 ANYELEMENTOID = 2283 ANYOID = 2276 ANYNONARRAYOID = 2776 ANYENUMOID = 3500 #: Mapping of type Oid to SQL type name. oid_to_sql_name = { BPCHAROID : 'CHARACTER', VARCHAROID : 'CHARACTER VARYING', # *OID : 'CHARACTER LARGE OBJECT', # SELECT X'0F' -> bit. XXX: Does bytea have any play here? #BITOID : 'BINARY', #BYTEAOID : 'BINARY VARYING', # *OID : 'BINARY LARGE OBJECT', BOOLOID : 'BOOLEAN', # exact numeric types INT2OID : 'SMALLINT', INT4OID : 'INTEGER', INT8OID : 'BIGINT', NUMERICOID : 'NUMERIC', # approximate numeric types FLOAT4OID : 'REAL', FLOAT8OID : 'DOUBLE PRECISION', # datetime types TIMEOID : 'TIME WITHOUT TIME ZONE', TIMETZOID : 'TIME WITH TIME ZONE', TIMESTAMPOID : 'TIMESTAMP WITHOUT TIME ZONE', TIMESTAMPTZOID : 'TIMESTAMP WITH TIME ZONE', DATEOID : 'DATE', # interval types INTERVALOID : 'INTERVAL', XMLOID : 'XML', } #: Mapping of type Oid to name. 
oid_to_name = { RECORDOID : 'record', BOOLOID : 'bool', BITOID : 'bit', VARBITOID : 'varbit', ACLITEMOID : 'aclitem', CHAROID : 'char', NAMEOID : 'name', TEXTOID : 'text', BYTEAOID : 'bytea', BPCHAROID : 'bpchar', VARCHAROID : 'varchar', CSTRINGOID : 'cstring', UNKNOWNOID : 'unknown', REFCURSOROID : 'refcursor', UUIDOID : 'uuid', TSVECTOROID : 'tsvector', GTSVECTOROID : 'gtsvector', TSQUERYOID : 'tsquery', REGCONFIGOID : 'regconfig', REGDICTIONARYOID : 'regdictionary', XMLOID : 'xml', MACADDROID : 'macaddr', INETOID : 'inet', CIDROID : 'cidr', TYPEOID : 'type', PROCOID : 'proc', CLASSOID : 'class', ATTRIBUTEOID : 'attribute', DATEOID : 'date', TIMEOID : 'time', TIMESTAMPOID : 'timestamp', TIMESTAMPTZOID : 'timestamptz', INTERVALOID : 'interval', TIMETZOID : 'timetz', ABSTIMEOID : 'abstime', RELTIMEOID : 'reltime', TINTERVALOID : 'tinterval', INT8OID : 'int8', INT2OID : 'int2', INT4OID : 'int4', OIDOID : 'oid', TIDOID : 'tid', XIDOID : 'xid', CIDOID : 'cid', CASHOID : 'cash', FLOAT4OID : 'float4', FLOAT8OID : 'float8', NUMERICOID : 'numeric', POINTOID : 'point', LINEOID : 'line', LSEGOID : 'lseg', PATHOID : 'path', BOXOID : 'box', POLYGONOID : 'polygon', CIRCLEOID : 'circle', OIDVECTOROID : 'oidvector', INT2VECTOROID : 'int2vector', INT4ARRAYOID : 'int4array', REGPROCOID : 'regproc', REGPROCEDUREOID : 'regprocedure', REGOPEROID : 'regoper', REGOPERATOROID : 'regoperator', REGCLASSOID : 'regclass', REGTYPEOID : 'regtype', REGTYPEARRAYOID : 'regtypearray', TRIGGEROID : 'trigger', LANGUAGE_HANDLEROID : 'language_handler', INTERNALOID : 'internal', OPAQUEOID : 'opaque', VOIDOID : 'void', ANYARRAYOID : 'anyarray', ANYELEMENTOID : 'anyelement', ANYOID : 'any', ANYNONARRAYOID : 'anynonarray', ANYENUMOID : 'anyenum', } name_to_oid = dict( [(v,k) for k,v in oid_to_name.items()] ) class Array(object): """ Type used to mimic PostgreSQL arrays. While there are many semantic differences, the primary one is that the elements contained by an Array instance are not strongly typed. The purpose of this class is to provide some consistency with PostgreSQL with respect to the structure of an Array. The structure consists of three parts: * The elements of the array. * The lower boundaries. * The upper boundaries. There is also a `dimensions` property, but it is derived from the `lowerbounds` and `upperbounds` to yield a normalized description of the ARRAY's structure. The Python interfaces, such as __getitem__, are *not* subjected to the semantics of the lower and upper bounds. Rather, the normalized dimensions provide the primary influence for these interfaces. So, unlike SQL indirection, getting an index that does *not* exist will raise a Python `IndexError`. """ # return an iterator over the absolute elements of a nested sequence @classmethod def unroll_nest(typ, hier, dimensions, depth = 0): dsize = dimensions and dimensions[depth] or 0 if len(hier) != dsize: raise ValueError("list size not consistent with dimensions at depth " + str(depth)) r = [] ndepth = depth + 1 if ndepth == len(dimensions): # at the bottom r = hier else: # go deeper for x in hier: r.extend(typ.unroll_nest(x, dimensions, ndepth)) return r # Detect the dimensions of a nested sequence @staticmethod def detect_dimensions(hier, len = len): # if the list is empty, it's a zero-dimension array. 
if hier: yield len(hier) hier = hier[0] depth = 1 while hier.__class__ is list: depth += 1 l = len(hier) if l < 1: raise ValueError("axis {0} is empty".format(depth)) yield l hier = hier[0] @classmethod def from_elements(typ, elements : "iterable of elements in the array", lowerbounds : "beginning of each axis" = None, upperbounds : "upper bounds; size of each axis" = None, len = len, ): """ Instantiate an Array from the given elements, lowerbounds, and upperbounds. The given elements are bound to the array which provides them with the structure defined by the lower boundaries and the upper boundaries. A `ValueError` will be raised in the following situations: * The number of elements given are inconsistent with the number of elements described by the upper and lower bounds. * The lower bounds at a given axis exceeds the upper bounds at a given axis. * The number of lower bounds is inconsistent with the number of upper bounds. """ # resolve iterable elements = list(elements) nelements = len(elements) # If ndims is zero, lowerbounds will be () if lowerbounds is None: if upperbounds: lowerbounds = (1,) * len(upperbounds) elif nelements == 0: # special for empty ARRAY; no dimensions. lowerbounds = () else: # one dimension. lowerbounds = (1,) else: lowerbounds = tuple(lowerbounds) if upperbounds is not None: upperbounds = tuple(upperbounds) dimensions = [] # upperbounds were given, so check. if upperbounds: elcount = 1 for lb, ub in zip(lowerbounds, upperbounds): x = ub - lb + 1 if x < 1: # special case empty ARRAYs if nelements == 0: upperbounds = () lowerbounds = () dimensions = () elcount = 0 break raise ValueError("lowerbounds exceeds upperbounds") # physical dimensions. dimensions.append(x) elcount = x * elcount else: elcount = 0 if nelements != elcount: raise ValueError("element count inconsistent with boundaries") dimensions = tuple(dimensions) else: # fill in default if nelements == 0: upperbounds = () dimensions = () else: upperbounds = (nelements,) dimensions = (nelements,) # consistency.. if len(lowerbounds) != len(upperbounds): raise ValueError("number of lowerbounds inconsistent with upperbounds") rob = super().__new__(typ) rob._elements = elements rob.lowerbounds = lowerbounds rob.upperbounds = upperbounds rob.dimensions = dimensions rob.ndims = len(dimensions) rob._weight = len(rob._elements) // (dimensions and dimensions[0] or 1) return rob # Method used to create an Array() from nested lists. @classmethod def from_nest(typ, nest): dims = tuple(typ.detect_dimensions(nest)) return typ.from_elements( list(typ.unroll_nest(nest, dims)), upperbounds = dims, # lowerbounds is implied to (1,)*len(upper) ) def __new__(typ, nested_elements): """ Create an types.Array() using the given nested lists. The boundaries of the array are detected by traversing the first items of the nested lists:: Array([[1,2,4],[3,4,8]]) Lists are used to define the boundaries so that tuples may be used to represent any complex elements. The above array will the `lowerbounds` ``(1,1)``, and the `upperbounds` ``(2,3)``. """ if nested_elements.__class__ is Array: return nested_elements return typ.from_nest(list(nested_elements)) def __getnewargs__(self): return (self.nest(),) def elements(self): """ Returns an iterator to the elements of the Array. The elements are produced in physical order. """ return iter(self._elements) def nest(self, seqtype = list): """ Transform the array into a nested list. The `seqtype` keyword can be used to override the type used to represent the elements of a given axis. 
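A short example of the nesting behavior implemented below: >>> Array([[1,2],[3,4]]).nest() [[1, 2], [3, 4]] >>> Array([[1,2],[3,4]]).nest(seqtype = tuple) ((1, 2), (3, 4))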
""" if self.ndims < 2: return seqtype(self._elements) else: rl = [] for x in self: rl.append(x.nest(seqtype = seqtype)) return seqtype(rl) def get_element(self, address, idxerr = "index {0} at axis {1} is out of range {2}".format ): """ Get an element in the array using the given axis sequence. >>> a=Array([[1,2],[3,4]]) >>> a.get_element((0,0)) == 1 True >>> a.get_element((1,1)) == 4 True This is similar to getting items in a nested list:: >>> l=[[1,2],[3,4]] >>> l[0][0] == 1 True """ if not self.dimensions: raise IndexError("array is empty") if len(address) != len(self.dimensions): raise ValueError("given axis sequence is inconsistent with number of dimensions") # normalize axis specification (-N + DIM), check for IndexErrors, and # resolve the element's position. cur = 0 nelements = len(self._elements) for n, a, dim in zip(range(len(address)), address, self.dimensions): if a < 0: a = a + dim if a < 0: raise IndexError(idxerr(a, n, dim)) else: if a >= dim: raise IndexError(idxerr(a, n, dim)) nelements = nelements // dim cur += (a * nelements) return self._elements[cur] def sql_get_element(self, address): """ Like `get_element`, but with SQL indirection semantics. Notably, returns `None` on IndexError. """ try: a = [a - lb for (a, lb) in zip(address, self.lowerbounds)] # get_element accepts negatives, so check the converted sequence. for x in a: if x < 0: return None return self.get_element(a) except IndexError: return None def __repr__(self): return '%s.%s(%r)' %( type(self).__module__, type(self).__name__, self.nest() ) def __len__(self): return self.dimensions and self.dimensions[0] or 0 def __eq__(self, ob): return list(self) == ob def __ne__(self, ob): return list(self) != ob def __gt__(self, ob): return list(self) > ob def __lt__(self, ob): return list(self) < ob def __le__(self, ob): return list(self) <= ob def __ge__(self, ob): return list(self) >= ob def __getitem__(self, item): if self.ndims < 2: # Array with 1dim is more or less a list. return self._elements[item] if isinstance(item, slice): # get a sub-array slice l = len(self) n = 0 r = [] # for each offset in the slice, get the elements and add them # to the new elements list used to build the new Array(). 
for x in range(*(item.indices(l))): n = n + 1 r.extend( self._elements[slice(self._weight*x,self._weight*(x+1))] ) if n: return self.__class__.from_elements(r, lowerbounds = (1,) + self.lowerbounds[1:], upperbounds = (n,) + self.upperbounds[1:], ) else: # Empty return self.__class__.from_elements(()) else: # get a sub-array l = len(self) if item < 0: item = item + l if not 0 <= item < l: raise IndexError("index {0} is out of range".format(item)) return self.__class__.from_elements( self._elements[self._weight*item:self._weight*(item+1)], lowerbounds = self.lowerbounds[1:], upperbounds = self.upperbounds[1:], ) def __iter__(self): if self.ndims < 2: # Special case empty and single dimensional ARRAYs return self.elements() return (self[x] for x in range(len(self))) from operator import itemgetter get0 = itemgetter(0) get1 = itemgetter(1) del itemgetter class Row(tuple): "Name addressable items tuple; mapping and sequence" @classmethod def from_mapping(typ, keymap, map, get1 = get1): iter = [ map.get(k) for k,_ in sorted(keymap.items(), key = get1) ] r = typ(iter) r.keymap = keymap return r @classmethod def from_sequence(typ, keymap, seq): r = typ(seq) r.keymap = keymap return r def __getitem__(self, i, gi = tuple.__getitem__): if isinstance(i, (int, slice)): return gi(self, i) idx = self.keymap[i] return gi(self, idx) def get(self, i, gi = tuple.__getitem__, len = len): if type(i) is int: l = len(self) if -l <= i < l: return gi(self, i) else: idx = self.keymap.get(i) if idx is not None: return gi(self, idx) return None def keys(self): return self.keymap.keys() def values(self): return iter(self) def items(self): return zip(iter(self.column_names), iter(self)) def index_from_key(self, key): return self.keymap.get(key) def key_from_index(self, index): for k,v in self.keymap.items(): if v == index: return k return None @property def column_names(self, get0 = get0, get1 = get1): l=list(self.keymap.items()) l.sort(key=get1) return tuple(map(get0, l)) def transform(self, *args, **kw): """ Make a new Row after processing the values with the callables associated with the values either by index, \*args, or by column name, \*\*kw.
>>> r=Row.from_sequence({'col1':0,'col2':1}, (1,'two')) >>> r.transform(str) ('1', 'two') >>> r.transform(col2 = str.upper) (1, 'TWO') >>> r.transform(str, col2 = str.upper) ('1', 'TWO') Combine with methodcaller and map to transform lots of rows: >>> import operator >>> rowseq = [r] >>> xf = operator.methodcaller('transform', col2 = str.upper) >>> list(map(xf, rowseq)) [(1, 'TWO')] """ r = list(self) i = 0 for x in args: if x is not None: r[i] = x(tuple.__getitem__(self, i)) i = i + 1 for k,v in kw.items(): if v is not None: i = self.index_from_key(k) if i is None: raise KeyError("row has no such key, " + repr(k)) r[i] = v(self[k]) return type(self).from_sequence(self.keymap, r) fe-1.1.0/postgresql/types/bitwise.py000066400000000000000000000045521203372773200175060ustar00rootroot00000000000000class Varbit(object): __slots__ = ('data', 'bits') def from_bits(subtype, bits, data): if bits == 1: return (data[0] & (1 << 7)) and OneBit or ZeroBit else: rob = object.__new__(subtype) rob.bits = bits rob.data = data return rob from_bits = classmethod(from_bits) def __new__(typ, data): if isinstance(data, Varbit): return data if isinstance(data, bytes): return typ.from_bits(len(data) * 8, data) # str(), eg '00101100' bits = len(data) nbytes, remain = divmod(bits, 8) bdata = [bytes((int(data[x:x+8], 2),)) for x in range(0, bits - remain, 8)] if remain != 0: bdata.append(bytes((int(data[nbytes*8:].ljust(8,'0'), 2),))) return typ.from_bits(bits, b''.join(bdata)) def __str__(self): if self.bits: # cut off the remainder from the bits blocks = [bin(x)[2:].rjust(8, '0') for x in self.data] blocks[-1] = blocks[-1][0:(self.bits % 8) or 8] return ''.join(blocks) else: return '' def __repr__(self): return '%s.%s(%r)' %( type(self).__module__, type(self).__name__, str(self) ) def __eq__(self, ob): if not isinstance(ob, type(self)): ob = type(self)(ob) return ob.bits == self.bits and ob.data == self.data def __len__(self): return self.bits def __add__(self, ob): return Varbit(str(self) + str(ob)) def __mul__(self, ob): return Varbit(str(self) * ob) def getbit(self, bitoffset): if bitoffset < 0: idx = self.bits + bitoffset else: idx = bitoffset if not 0 <= idx < self.bits: raise IndexError("bit index %d out of range" %(bitoffset,)) byte, bitofbyte = divmod(idx, 8) # Indexing bytes yields an int in Python 3; no ord() needed. if self.data[byte] & (1 << (7 - bitofbyte)): return OneBit else: return ZeroBit def __getitem__(self, item): if isinstance(item, slice): return type(self)(str(self)[item]) else: return self.getbit(item) def __bool__(self): for x in self.data: if x != 0: return True return False class Bit(Varbit): def __new__(subtype, ob): if ob is ZeroBit or ob is False or ob == '0': return ZeroBit elif ob is OneBit or ob is True or ob == '1': return OneBit raise ValueError('unknown bit value %r, 0 or 1' %(ob,)) def __bool__(self): return self is OneBit def __str__(self): return self is OneBit and '1' or '0' ZeroBit = object.__new__(Bit) ZeroBit.data = b'\x00' ZeroBit.bits = 1 OneBit = object.__new__(Bit) OneBit.data = b'\x80' OneBit.bits = 1 fe-1.1.0/postgresql/types/geometry.py000066400000000000000000000107301203372773200176660ustar00rootroot00000000000000import math from operator import itemgetter get0 = itemgetter(0) get1 = itemgetter(1) # Geometric types class Point(tuple): """ A point; a pair of floating point numbers.
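For example, using the operators and methods defined below: >>> Point((1.0, 2.0)) + (2.0, 0.5) postgresql.types.geometry.Point((3.0, 2.5)) >>> Point((0.0, 0.0)).distance((3.0, 4.0)) 5.0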
""" __slots__ = () x = property(fget = lambda s: s[0]) y = property(fget = lambda s: s[1]) def __new__(subtype, pair): return tuple.__new__(subtype, (float(pair[0]), float(pair[1]))) def __repr__(self): return '%s.%s(%s)' %( type(self).__module__, type(self).__name__, tuple.__repr__(self), ) def __str__(self): return tuple.__repr__(self) def __add__(self, ob): wx, wy = ob return type(self)((self[0] + wx, self[1] + wy)) def __sub__(self, ob): wx, wy = ob return type(self)((self[0] - wx, self[1] - wy)) def __mul__(self, ob): wx, wy = ob rx = (self[0] * wx) - (self[1] * wy) ry = (self[0] * wy) + (self[1] * wx) return type(self)((rx, ry)) def __div__(self, ob): sx, sy = self wx, wy = ob div = (wx * wx) + (wy * wy) rx = ((sx * wx) + (sy * wy)) / div ry = ((wx * sy) + (wy * sx)) / div return type(self)((rx, ry)) def distance(self, ob, sqrt = math.sqrt): wx, wy = ob dx = self[0] - float(wx) dy = self[1] - float(wy) return sqrt(dx**2 + dy**2) class Lseg(tuple): __slots__ = () one = property(fget = lambda s: s[0]) two = property(fget = lambda s: s[1]) length = property(fget = lambda s: s[0].distance(s[1])) vertical = property(fget = lambda s: s[0][0] == s[1][0]) horizontal = property(fget = lambda s: s[0][1] == s[1][1]) slope = property( fget = lambda s: (s[1][1] - s[0][1]) / (s[1][0] - s[0][0]) ) center = property( fget = lambda s: Point(( (s[0][0] + s[1][0]) / 2.0, (s[0][1] + s[1][1]) / 2.0, )) ) def __new__(subtype, pair): p1, p2 = pair return tuple.__new__(subtype, (Point(p1), Point(p2))) def __repr__(self): # Avoid the point representation return '%s.%s(%s, %s)' %( type(self).__module__, type(self).__name__, tuple.__repr__(self[0]), tuple.__repr__(self[1]), ) def __str__(self): return '[(%s,%s),(%s,%s)]' %( self[0][0], self[0][1], self[1][0], self[1][1], ) def parallel(self, ob): return self.slope == type(self)(ob).slope def intersect(self, ob): raise NotImplementedError def perpendicular(self, ob): return (self.slope / type(self)(ob).slope) == -1.0 class Box(tuple): """ A pair of points. One specifying the top-right point of the box; the other specifying the bottom-left. `high` being top-right; `low` being bottom-left. http://www.postgresql.org/docs/current/static/datatype-geometric.html >>> Box(( (0,0), (-2, -2) )) postgresql.types.geometry.Box(((0.0, 0.0), (-2.0, -2.0))) It will also relocate values to enforce the high-low expectation: >>> t.box(((-4,0),(-2,-3))) postgresql.types.geometry.Box(((-2.0, 0.0), (-4.0, -3.0))) :: (-2, 0) `high` | | (-4,-3) -------+-x `low` y This happens because ``-4`` is less than ``-2``; therefore the ``-4`` belongs on the low point. This is consistent with what PostgreSQL does with its ``box`` type. 
""" __slots__ = () high = property(fget = get0, doc = "high point of the box") low = property(fget = get1, doc = "low point of the box") center = property( fget = lambda s: Point(( (s[0][0] + s[1][0]) / 2.0, (s[0][1] + s[1][1]) / 2.0 )), doc = "center of the box as a point" ) def __new__(subtype, hl): if isinstance(hl, Box): return hl one, two = hl if one[0] > two[0]: hx = one[0] lx = two[0] else: hx = two[0] lx = one[0] if one[1] > two[1]: hy = one[1] ly = two[1] else: hy = two[1] ly = one[1] return tuple.__new__(subtype, (Point((hx, hy)), Point((lx, ly)))) def __repr__(self): return '%s.%s((%s, %s))' %( type(self).__module__, type(self).__name__, tuple.__repr__(self[0]), tuple.__repr__(self[1]), ) def __str__(self): return '%s,%s' %(self[0], self[1]) class Circle(tuple): """ type for PostgreSQL circles """ __slots__ = () center = property(fget = get0, doc = "center of the circle (point)") radius = property(fget = get1, doc = "radius of the circle (radius >= 0)") def __new__(subtype, pair): center, radius = pair if radius < 0: raise ValueError("radius is subzero") return tuple.__new__(subtype, (Point(center), float(radius))) def __repr__(self): return '%s.%s((%s, %s))' %( type(self).__module__, type(self).__name__, tuple.__repr__(self[0]), repr(self[1]) ) def __str__(self): return '<%s,%s>' %(self[0], self[1]) fe-1.1.0/postgresql/types/io/000077500000000000000000000000001203372773200160675ustar00rootroot00000000000000fe-1.1.0/postgresql/types/io/__init__.py000066400000000000000000000043721203372773200202060ustar00rootroot00000000000000## # .types.io - I/O routines for packing and unpacking data ## """ PostgreSQL type I/O routines--packing and unpacking functions. This package manages the modules providing I/O routines. The name of the function describes what type the function is intended to be used on. Normally, the fucntions return a structured form of the serialized data to be used as a parameter to the creation of a higher level instance. In particular, most of the functions that deal with time return a pair for representing the relative offset: (seconds, microseconds). For times, this provides an abstraction for quad-word based times used by some configurations of PostgreSQL. """ import sys from itertools import cycle, chain from ... import types as pg_types io_modules = { 'builtins' : ( pg_types.BOOLOID, pg_types.CHAROID, pg_types.BYTEAOID, pg_types.INT2OID, pg_types.INT4OID, pg_types.INT8OID, pg_types.FLOAT4OID, pg_types.FLOAT8OID, pg_types.ABSTIMEOID, ), 'pg_bitwise': ( pg_types.BITOID, pg_types.VARBITOID, ), 'pg_network': ( pg_types.MACADDROID, pg_types.INETOID, pg_types.CIDROID, ), 'pg_system': ( pg_types.OIDOID, pg_types.XIDOID, pg_types.CIDOID, pg_types.TIDOID, ), 'pg_geometry': ( pg_types.POINTOID, pg_types.LSEGOID, pg_types.BOXOID, pg_types.CIRCLEOID, ), 'stdlib_datetime' : ( pg_types.DATEOID, pg_types.INTERVALOID, pg_types.TIMEOID, pg_types.TIMETZOID, pg_types.TIMESTAMPOID, pg_types.TIMESTAMPTZOID ), 'stdlib_decimal' : ( pg_types.NUMERICOID, ), 'stdlib_uuid' : ( pg_types.UUIDOID, ), 'stdlib_xml_etree' : ( pg_types.XMLOID, ), # Must be db.typio.identify(contrib_hstore = 'hstore')'d 'contrib_hstore' : ( 'contrib_hstore', ), } # OID -> module name module_io = dict( chain.from_iterable(( zip(x[1], cycle((x[0],))) for x in io_modules.items() )) ) if sys.version_info[:2] < (3,3): def load(relmod): return __import__(__name__ + '.' 
+ relmod, fromlist = True, level = 1) else: def load(relmod): return __import__(relmod, globals = globals(), locals = locals(), fromlist = [''], level = 1) def resolve(oid): io = module_io.get(oid) if io is None: return None if io.__class__ is str: module_io.update(load(io).oid_to_io) io = module_io[oid] return io fe-1.1.0/postgresql/types/io/builtins.py000066400000000000000000000024321203372773200202730ustar00rootroot00000000000000from .. import \ INT2OID, INT4OID, INT8OID, \ BOOLOID, BYTEAOID, CHAROID, \ ABSTIMEOID, FLOAT4OID, FLOAT8OID, \ TEXTOID, BPCHAROID, NAMEOID, VARCHAROID from . import lib bool_pack = {True:b'\x01', False:b'\x00'}.__getitem__ bool_unpack = {b'\x01':True, b'\x00':False}.__getitem__ int2_pack, int2_unpack = lib.short_pack, lib.short_unpack int4_pack, int4_unpack = lib.long_pack, lib.long_unpack int8_pack, int8_unpack = lib.longlong_pack, lib.longlong_unpack bytea_pack = bytes bytea_unpack = bytes char_pack = bytes char_unpack = bytes oid_to_io = { BOOLOID : (bool_pack, bool_unpack, bool), BYTEAOID : (bytea_pack, bytea_unpack, bytes), CHAROID : (char_pack, char_unpack, bytes), INT2OID : (int2_pack, int2_unpack, int), INT4OID : (int4_pack, int4_unpack, int), INT8OID : (int8_pack, int8_unpack, int), ABSTIMEOID : (lib.long_pack, lib.long_unpack, int), FLOAT4OID : (lib.float_pack, lib.float_unpack, float), FLOAT8OID : (lib.double_pack, lib.double_unpack, float), } # Python Representations of PostgreSQL Types oid_to_type = { BOOLOID: bool, VARCHAROID: str, TEXTOID: str, BPCHAROID: str, NAMEOID: str, # This is *not* bpchar, the SQL CHARACTER type. CHAROID: bytes, BYTEAOID: bytes, INT2OID: int, INT4OID: int, INT8OID: int, FLOAT4OID: float, FLOAT8OID: float, } fe-1.1.0/postgresql/types/io/contrib_hstore.py000066400000000000000000000022511203372773200214650ustar00rootroot00000000000000## # .types.io.contrib_hstore - I/O routines for binary hstore ## from ...python.structlib import split_sized_data, ulong_pack, ulong_unpack from ...python.itertools import chunk ## # Build the hstore I/O pair for a given typio. # It primarily needs typio for decode and encode. 
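# # A minimal usage sketch; any object providing encode/decode callables # (such as the fake_typio arrangement used by the test suite) suffices: # # pack_hstore, unpack_hstore = hstore_factory(0, typio) # unpack_hstore(pack_hstore({'key': 'value'})) == {'key': 'value'}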
def hstore_factory(oid, typio, unpack_err = "expected {0} items in hstore, but found {1}".format ): def pack_hstore(x, encode = typio.encode, len = len, ): if hasattr(x, 'items'): x = x.items() encoded = [ (encode(k), encode(v)) if v is not None else (encode(k), None) for k,v in x ] return ulong_pack(len(encoded)) + b''.join( ulong_pack(len(k)) + k + b'\xFF\xFF\xFF\xFF' if v is None else ulong_pack(len(k)) + k + ulong_pack(len(v)) + v for k,v in encoded ) def unpack_hstore(x, decode = typio.decode, split = split_sized_data, len = len ): view = memoryview(x)[4:] n = ulong_unpack(x) r = { decode(y[0]) : (decode(y[1]) if y[1] is not None else None) for y in chunk(split(view), 2) if y } if len(r) != n: raise ValueError(unpack_err(n, len(r))) return r return (pack_hstore, unpack_hstore) oid_to_io = { 'contrib_hstore' : hstore_factory, } fe-1.1.0/postgresql/types/io/lib.py000066400000000000000000000344051203372773200172150ustar00rootroot00000000000000import struct from math import floor from ...python.functools import Composition as compose from ...python.itertools import interlace from ...python.structlib import \ short_pack, short_unpack, \ ulong_pack, ulong_unpack, \ long_pack, long_unpack, \ double_pack, double_unpack, \ longlong_pack, longlong_unpack, \ float_pack, float_unpack, \ LH_pack, LH_unpack, \ dl_pack, dl_unpack, \ dll_pack, dll_unpack, \ ql_pack, ql_unpack, \ qll_pack, qll_unpack, \ llL_pack, llL_unpack, \ dd_pack, dd_unpack, \ ddd_pack, ddd_unpack, \ dddd_pack, dddd_unpack, \ hhhh_pack, hhhh_unpack oid_pack = cid_pack = xid_pack = ulong_pack oid_unpack = cid_unpack = xid_unpack = ulong_unpack tid_pack, tid_unpack = LH_pack, LH_unpack # geometry types point_pack, point_unpack = dd_pack, dd_unpack circle_pack, circle_unpack = ddd_pack, ddd_unpack lseg_pack = box_pack = dddd_pack lseg_unpack = box_unpack = dddd_unpack null_sequence = b'\xff\xff\xff\xff' string_format = b'\x00\x00' binary_format = b'\x00\x01' def numeric_pack(data, hhhh_pack = hhhh_pack, pack = struct.pack, len = len): return hhhh_pack(data[0]) + pack("!%dh"%(len(data[1]),), *data[1]) def numeric_unpack(data, hhhh_unpack = hhhh_unpack, unpack = struct.unpack, len = len): return (hhhh_unpack(data[:8]), unpack("!8x%dh"%((len(data)-8) // 2,), data)) def path_pack(data, pack = struct.pack, len = len): """ Given a sequence of point data, pack it into a path's serialized form. [px1, py1, px2, py2, ...] Must be an even number of numbers. """ return pack("!l%dd" %(len(data),), len(data), *data) def path_unpack(data, long_unpack = long_unpack, unpack = struct.unpack): """ Unpack a path's serialized form into a sequence of point data: [px1, py1, px2, py2, ...] Should be an even number of numbers. """ return unpack("!4x%dd" %(long_unpack(data[:4]),), data) polygon_pack, polygon_unpack = path_pack, path_unpack ## # Binary representations of infinity for datetimes. 
time_infinity = b'\x7f\xf0\x00\x00\x00\x00\x00\x00' time_negative_infinity = b'\xff\xf0\x00\x00\x00\x00\x00\x00' time64_infinity = b'\x7f\xff\xff\xff\xff\xff\xff\xff' time64_negative_infinity = b'\x80\x00\x00\x00\x00\x00\x00\x00' date_infinity = b'\x7f\xff\xff\xff' date_negative_infinity = b'\x80\x00\x00\x00' # time types date_pack, date_unpack = long_pack, long_unpack def mktimetuple(ts, floor = floor): 'make a pair of (seconds, microseconds) out of the given double' seconds = floor(ts) return (int(seconds), int(1000000 * (ts - seconds))) def mktimetuple64(ts, divmod = divmod): 'make a pair of (seconds, microseconds) out of the given long' return divmod(ts, 1000000) def mktime(seconds_ms, float = float): 'make a double out of the pair of (seconds, microseconds)' return float(seconds_ms[0]) + (seconds_ms[1] / 1000000.0) def mktime64(seconds_ms): 'make an integer out of the pair of (seconds, microseconds)' return seconds_ms[0] * 1000000 + seconds_ms[1] # takes a pair, (seconds, microseconds) time_pack = compose((mktime, double_pack)) time_unpack = compose((double_unpack, mktimetuple)) def interval_pack(m_d_timetup, mktime = mktime, dll_pack = dll_pack): """ Given a triple, (month, day, (seconds, microseconds)), serialize it for transport. """ (month, day, timetup) = m_d_timetup return dll_pack((mktime(timetup), day, month)) def interval_unpack(data, dll_unpack = dll_unpack, mktimetuple = mktimetuple): """ Given a serialized interval, '{month}{day}{time}', yield the triple: (month, day, (seconds, microseconds)) """ tim, day, month = dll_unpack(data) return (month, day, mktimetuple(tim)) def interval_noday_pack(month_day_timetup, dl_pack = dl_pack, mktime = mktime): """ Given a triple, (month, day, (seconds, microseconds)), return the serialized form that does not have an individual day component. There is no day component, so if day is non-zero, it will be converted to seconds and subsequently added to the seconds. """ (month, day, timetup) = month_day_timetup if day: timetup = (timetup[0] + (day * 24 * 60 * 60), timetup[1]) return dl_pack((mktime(timetup), month)) def interval_noday_unpack(data, dl_unpack = dl_unpack, mktimetuple = mktimetuple): """ Given a serialized interval without a day component, return the triple: (month, day(always zero), (seconds, microseconds)) """ tim, month = dl_unpack(data) return (month, 0, mktimetuple(tim)) def time64_pack(data, mktime64 = mktime64, longlong_pack = longlong_pack): return longlong_pack(mktime64(data)) def time64_unpack(data, longlong_unpack = longlong_unpack, mktimetuple64 = mktimetuple64): return mktimetuple64(longlong_unpack(data)) def interval64_pack(m_d_timetup, qll_pack = qll_pack, mktime64 = mktime64): """ Given a triple, (month, day, (seconds, microseconds)), return the serialized data using a quad-word for the (seconds, microseconds) tuple. """ (month, day, timetup) = m_d_timetup return qll_pack((mktime64(timetup), day, month)) def interval64_unpack(data, qll_unpack = qll_unpack, mktimetuple = mktimetuple): """ Unpack an interval containing a quad-word into a triple: (month, day, (seconds, microseconds)) """ tim, day, month = qll_unpack(data) return (month, day, mktimetuple64(tim)) def interval64_noday_pack(m_d_timetup, ql_pack = ql_pack, mktime64 = mktime64): """ Pack an interval without a day component and using a quad-word for second representation. There is no day component, so if day is non-zero, it will be converted to seconds and subsequently added to the seconds. 
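For example, following the conversion above, the triple ``(0, 1, (0, 0))``, one day, serializes identically to ``(0, 0, (86400, 0))``.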
""" (month, day, timetup) = m_d_timetup if day: timetup = (timetup[0] + (day * 24 * 60 * 60), timetup[1]) return ql_pack((mktime64(timetup), month)) def interval64_noday_unpack(data, ql_unpack = ql_unpack, mktimetuple64 = mktimetuple64): """ Unpack a ``noday`` quad-word based interval. Returns a triple: (month, day(always zero), (seconds, microseconds)) """ tim, month = ql_unpack(data) return (month, 0, mktimetuple64(tim)) def timetz_pack(timetup_tz, dl_pack = dl_pack, mktime = mktime): """ Pack a time; offset from beginning of the day and timezone offset. Given a pair, ((seconds, microseconds), timezone_offset), pack it into its serialized form: "!dl". """ (timetup, tz_offset) = timetup_tz return dl_pack((mktime(timetup), tz_offset)) def timetz_unpack(data, dl_unpack = dl_unpack, mktimetuple = mktimetuple): """ Given serialized time data, unpack it into a pair: ((seconds, microseconds), timezone_offset). """ ts, tz = dl_unpack(data) return (mktimetuple(ts), tz) def timetz64_pack(timetup_tz, ql_pack = ql_pack, mktime64 = mktime64): """ Pack a time; offset from beginning of the day and timezone offset. Given a pair, ((seconds, microseconds), timezone_offset), pack it into its serialized form using a long long: "!ql". """ (timetup, tz_offset) = timetup_tz return ql_pack((mktime64(timetup), tz_offset)) def timetz64_unpack(data, ql_unpack = ql_unpack, mktimetuple64 = mktimetuple64): """ Given "long long" serialized time data, "ql", unpack it into a pair: ((seconds, microseconds), timezone_offset) """ ts, tz = ql_unpack(data) return (mktimetuple64(ts), tz) # oidvectors are 128 bytes, so pack the number of Oids in self # and justify that to 128 by padding with \x00. def oidvector_pack(seq, pack = struct.pack): """ Given a sequence of Oids, pack them into the serialized form. An oidvector is a type used by the PostgreSQL catalog. """ return pack("!%dL"%(len(seq),), *seq).ljust(128, '\x00') def oidvector_unpack(data, unpack = struct.unpack): """ Given a serialized oidvector(32 longs), unpack it into a list of unsigned integers. An int2vector is a type used by the PostgreSQL catalog. """ return unpack("!32L", data) def int2vector_pack(seq, pack = struct.pack): """ Given a sequence of integers, pack them into the serialized form. An int2vector is a type used by the PostgreSQL catalog. """ return pack("!%dh"%(len(seq),), *seq).ljust(64, '\x00') def int2vector_unpack(data, unpack = struct.unpack): """ Given a serialized int2vector, unpack it into a list of integers. An int2vector is a type used by the PostgreSQL catalog. """ return unpack("!32h", data) def varbit_pack(bits_data, long_pack = long_pack): r""" Given a pair, serialize the varbit. # (number of bits, data) >>> varbit_pack((1, '\x00')) b'\x00\x00\x00\x01\x00' """ return long_pack(bits_data[0]) + bits_data[1] def varbit_unpack(data, long_unpack = long_unpack): """ Given ``varbit`` data, unpack it into a pair: (bits, data) Where bits are the total number of bits in data (bytes). """ return long_unpack(data[0:4]), data[4:] def net_pack(triple, # Map PGSQL src/include/utils/inet.h to IP version number. fmap = { 4: 2, 6: 3, }, len = len, ): """ net_pack((family, mask, data)) Pack Postgres' inet/cidr data structure. """ family, mask, data = triple return bytes((fmap[family], mask or 0, 0 if mask is None else 1, len(data))) + data def net_unpack(data, # Map IP version number to PGSQL src/include/utils/inet.h. fmap = { 2: 4, 3: 6, } ): """ net_unpack(data) Unpack Postgres' inet/cidr data structure. 
""" family, mask, is_cidr, size = data[:4] return (fmap[family], mask, data[4:]) def macaddr_pack(data, bytes = bytes): """ Pack a MAC address Format found in PGSQL src/backend/utils/adt/mac.c, and PGSQL Manual types """ # Accept all possible PGSQL Macaddr formats as in manual # Oh for sscanf() as we could just copy PGSQL C in src/util/adt/mac.c colon_parts = data.split(':') dash_parts = data.split('-') dot_parts = data.split('.') if len(colon_parts) == 6: mac_parts = colon_parts elif len(dash_parts) == 6: mac_parts = dash_parts elif len(colon_parts) == 2: mac_parts = [colon_parts[0][:2], colon_parts[0][2:4], colon_parts[0][4:], colon_parts[1][:2], colon_parts[1][2:4], colon_parts[1][4:]] elif len(dash_parts) == 2: mac_parts = [dash_parts[0][:2], dash_parts[0][2:4], dash_parts[0][4:], dash_parts[1][:2], dash_parts[1][2:4], dash_parts[1][4:]] elif len(dot_parts) == 3: mac_parts = [dot_parts[0][:2], dot_parts[0][2:], dot_parts[1][:2], dot_parts[1][2:], dot_parts[2][:2], dot_parts[2][2:]] elif len(colon_parts) == 1: mac_parts = [data[:2], data[2:4], data[4:6], data[6:8], data[8:10], data[10:]] else: raise ValueError('data string cannot be parsed to bytes') if len(mac_parts) != 6 and len(mac_parts[-1]) != 2: raise ValueError('data string cannot be parsed to bytes') return bytes([int(p, 16) for p in mac_parts]) def macaddr_unpack(data): """ Unpack a MAC address Format found in PGSQL src/backend/utils/adt/mac.c """ # This is easy, just go for standard macaddr format, # just like PGSQL in src/util/adt/mac.c macaddr_out() if len(data) != 6: raise ValueError('macaddr has incorrect length') return ("%02x:%02x:%02x:%02x:%02x:%02x" % tuple(data)) def record_unpack(data, long_unpack = long_unpack, oid_unpack = oid_unpack, null_sequence = null_sequence, len = len): """ Given serialized record data, return a tuple of tuples of type Oids and attributes. """ columns = long_unpack(data) offset = 4 for x in range(columns): typid = oid_unpack(data[offset:offset+4]) offset += 4 if data[offset:offset+4] == null_sequence: att = None offset += 4 else: size = long_unpack(data[offset:offset+4]) offset += 4 att = data[offset:offset + size] if size < -1 or len(att) != size: raise ValueError("insufficient data left in message") offset += size yield (typid, att) if len(data) - offset != 0: raise ValueError("extra data, %d octets, at end of record" %(len(data),)) def record_pack(seq, long_pack = long_pack, oid_pack = oid_pack, null_sequence = null_sequence): """ pack a record given an iterable of (type_oid, data) pairs. """ return long_pack(len(seq)) + b''.join([ # typid + (null_seq or data) oid_pack(x) + (y is None and null_sequence or (long_pack(len(y)) + y)) for x, y in seq ]) def elements_pack(elements, null_sequence = null_sequence, long_pack = long_pack, len = len ): """ Pack the elements for containment within a serialized array. This is used by array_pack. """ for x in elements: if x is None: yield null_sequence else: yield long_pack(len(x)) yield x def array_pack(array_data, llL_pack = llL_pack, len = len, long_pack = long_pack, interlace = interlace ): """ Pack a raw array. A raw array consists of flags, type oid, sequence of lower and upper bounds, and an iterable of already serialized element data: (0, element type oid, (lower bounds, upper bounds, ...), iterable of element_data) The lower bounds and upper bounds specifies boundaries of the dimension. So the length of the boundaries sequence is two times the number of dimensions that the array has. 
array_pack((flags, type_id, dims, lowers, element_data)) The format of ``lower_upper_bounds`` is a sequence of lower bounds and upper bounds. First lower then upper inlined within the sequence: [lower, upper, lower, upper] The above array `dlb` has two dimensions. The lower and upper bounds of the first dimension is defined by the first two elements in the sequence. The second dimension is then defined by the last two elements in the sequence. """ (flags, typid, dims, lbs, elements) = array_data return llL_pack((len(dims), flags, typid)) + \ b''.join(map(long_pack, interlace(dims, lbs))) + \ b''.join(elements_pack(elements)) def elements_unpack(data, offset, long_unpack = long_unpack, null_sequence = null_sequence): """ Unpack the serialized elements of an array into a list. This is used by array_unpack. """ data_len = len(data) while offset < data_len: lend = data[offset:offset+4] offset += 4 if lend == null_sequence: yield None else: sizeof_el = long_unpack(lend) yield data[offset:offset+sizeof_el] offset += sizeof_el def array_unpack(data, llL_unpack = llL_unpack, unpack = struct.unpack_from, long_unpack = long_unpack ): """ Given a serialized array, unpack it into a tuple: (flags, typid, (dims, lower bounds, ...), [elements]) """ ndim, flags, typid = llL_unpack(data) if ndim < 0: raise ValueError("invalid number of dimensions: %d" %(ndim,)) # "ndim" number of pairs of longs end = (4 * 2 * ndim) + 12 # Dimensions and lower bounds; split the two early. #dlb = unpack("!%dl"%(2 * ndim,), data, 12) dims = [long_unpack(data[x:x+4]) for x in range(12, end, 8)] lbs = [long_unpack(data[x:x+4]) for x in range(16, end, 8)] return (flags, typid, dims, lbs, elements_unpack(data, end)) fe-1.1.0/postgresql/types/io/pg_bitwise.py000066400000000000000000000006411203372773200205760ustar00rootroot00000000000000from .. import BITOID, VARBITOID from ..bitwise import Varbit, Bit from . import lib def varbit_pack(x, pack = lib.varbit_pack): return pack((x.bits, x.data)) def varbit_unpack(x, unpack = lib.varbit_unpack): return Varbit.from_bits(*unpack(x)) oid_to_io = { BITOID : (varbit_pack, varbit_unpack, Bit), VARBITOID : (varbit_pack, varbit_unpack, Varbit), } oid_to_type = { BITOID : Bit, VARBITOID : Varbit, } fe-1.1.0/postgresql/types/io/pg_geometry.py000066400000000000000000000024041203372773200207620ustar00rootroot00000000000000from .. import POINTOID, BOXOID, LSEGOID, CIRCLEOID from ..geometry import Point, Box, Lseg, Circle from ...python.functools import Composition as compose from . import lib oid_to_type = { POINTOID: Point, BOXOID: Box, LSEGOID: Lseg, CIRCLEOID: Circle, } # Make a pair of pairs out of a sequence of four objects def two_pair(x): return ((x[0], x[1]), (x[2], x[3])) point_pack = lib.point_pack point_unpack = compose((lib.point_unpack, Point)) def box_pack(x): return lib.box_pack((x[0][0], x[0][1], x[1][0], x[1][1])) box_unpack = compose((lib.box_unpack, two_pair, Box,)) def lseg_pack(x, pack = lib.lseg_pack): return pack((x[0][0], x[0][1], x[1][0], x[1][1])) lseg_unpack = compose((lib.lseg_unpack, two_pair, Lseg)) def circle_pack(x): return lib.circle_pack((x[0][0], x[0][1], x[1])) def circle_unpack(x, unpack = lib.circle_unpack, Circle = Circle): x = unpack(x) return Circle(((x[0], x[1]), x[2])) # Map type oids to a (pack, unpack) pair. 
oid_to_io = { POINTOID : (point_pack, point_unpack, Point), BOXOID : (box_pack, box_unpack, Box), LSEGOID : (lseg_pack, lseg_unpack, Lseg), CIRCLEOID : (circle_pack, circle_unpack, Circle), #PATHOID : (path_pack, path_unpack), #POLYGONOID : (polygon_pack, polygon_unpack), #LINEOID : (line_pack, line_unpack), } fe-1.1.0/postgresql/types/io/pg_network.py000066400000000000000000000017121203372773200206210ustar00rootroot00000000000000from .. import INETOID, CIDROID, MACADDROID from . import lib import ipaddress oid_to_type = { MACADDROID : str, INETOID: ipaddress._IPAddressBase, CIDROID: ipaddress._BaseNetwork, } def inet_pack(ob, pack = lib.net_pack, Constructor = ipaddress.ip_address): a = Constructor(ob) return pack((a.version, None, a.packed)) def cidr_pack(ob, pack = lib.net_pack, Constructor = ipaddress.ip_network): a = Constructor(ob) return pack((a.version, a.prefixlen, a.network_address.packed)) def inet_unpack(data, unpack = lib.net_unpack, Constructor = ipaddress.ip_address): version, mask, data = unpack(data) return Constructor(data) def cidr_unpack(data, unpack = lib.net_unpack, Constructor = ipaddress.ip_network): version, mask, data = unpack(data) return Constructor(data).supernet(new_prefix=mask) oid_to_io = { MACADDROID : (lib.macaddr_pack, lib.macaddr_unpack, str), CIDROID : (cidr_pack, cidr_unpack, str), INETOID : (inet_pack, inet_unpack, str), } fe-1.1.0/postgresql/types/io/pg_system.py000066400000000000000000000004561203372773200204600ustar00rootroot00000000000000from ...types import OIDOID, XIDOID, CIDOID, TIDOID from . import lib oid_to_io = { OIDOID : (lib.oid_pack, lib.oid_unpack), XIDOID : (lib.xid_pack, lib.xid_unpack), CIDOID : (lib.cid_pack, lib.cid_unpack), TIDOID : (lib.tid_pack, lib.tid_unpack), #ACLITEMOID : (aclitem_pack, aclitem_unpack), } fe-1.1.0/postgresql/types/io/stdlib_datetime.py000066400000000000000000000211741203372773200216030ustar00rootroot00000000000000## # stdlib_datetime - support for the stdlib's datetime. # # I/O routines for date, time, timetz, timestamp, timestamptz, and interval. # Supported by the datetime module. ## import datetime import warnings from functools import partial from operator import methodcaller, add from ...python.datetime import UTC, FixedOffset, \ infinity_date, infinity_datetime, \ negative_infinity_date, negative_infinity_datetime from ...python.functools import Composition as compose from ...exceptions import TypeConversionWarning from .. import \ DATEOID, INTERVALOID, \ TIMEOID, TIMETZOID, \ TIMESTAMPOID, TIMESTAMPTZOID from . import lib oid_to_type = { DATEOID: datetime.date, TIMESTAMPOID: datetime.datetime, TIMESTAMPTZOID: datetime.datetime, TIMEOID: datetime.time, TIMETZOID: datetime.time, # XXX: datetime.timedelta doesn't support months. INTERVALOID: datetime.timedelta, } seconds_in_day = 24 * 60 * 60 seconds_in_hour = 60 * 60 pg_epoch_datetime = datetime.datetime(2000, 1, 1) pg_epoch_datetime_utc = pg_epoch_datetime.replace(tzinfo = UTC) pg_epoch_date = pg_epoch_datetime.date() pg_date_offset = pg_epoch_date.toordinal() ## Difference between PostgreSQL epoch and Unix epoch. ## Used to convert a PostgreSQL ordinal to an ordinal usable by datetime pg_time_days = (pg_date_offset - datetime.date(1970, 1, 1).toordinal()) ## # Constants used to special case infinity and -infinity. 
time64_pack_constants = {
	infinity_datetime: lib.time64_infinity,
	negative_infinity_datetime: lib.time64_negative_infinity,
	'infinity': lib.time64_infinity,
	'-infinity': lib.time64_negative_infinity,
}
time_pack_constants = {
	infinity_datetime: lib.time_infinity,
	negative_infinity_datetime: lib.time_negative_infinity,
	'infinity': lib.time_infinity,
	'-infinity': lib.time_negative_infinity,
}
date_pack_constants = {
	infinity_date: lib.date_infinity,
	negative_infinity_date: lib.date_negative_infinity,
	'infinity': lib.date_infinity,
	'-infinity': lib.date_negative_infinity,
}
time64_unpack_constants = {
	lib.time64_infinity: infinity_datetime,
	lib.time64_negative_infinity: negative_infinity_datetime,
}
time_unpack_constants = {
	lib.time_infinity: infinity_datetime,
	lib.time_negative_infinity: negative_infinity_datetime,
}
date_unpack_constants = {
	lib.date_infinity: infinity_date,
	lib.date_negative_infinity: negative_infinity_date,
}

def date_pack(x,
	pack = lib.date_pack,
	offset = pg_date_offset,
	get = date_pack_constants.get,
):
	return get(x) or pack(x.toordinal() - offset)

def date_unpack(x,
	unpack = lib.date_unpack,
	offset = pg_date_offset,
	from_ord = datetime.date.fromordinal,
	get = date_unpack_constants.get,
):
	return get(x) or from_ord(unpack(x) + offset)

def timestamp_pack(x,
	seconds_in_day = seconds_in_day,
	pg_epoch_datetime = pg_epoch_datetime,
):
	"""
	Create a (seconds, microseconds) pair from a `datetime.datetime` instance.
	"""
	x = (x - pg_epoch_datetime)
	return ((x.days * seconds_in_day) + x.seconds, x.microseconds)

def timestamp_unpack(seconds,
	timedelta = datetime.timedelta,
	relative_to = pg_epoch_datetime.__add__,
):
	"""
	Create a `datetime.datetime` instance from a (seconds, microseconds) pair.
	"""
	return relative_to(timedelta(0, *seconds))

def timestamptz_pack(x,
	seconds_in_day = seconds_in_day,
	pg_epoch_datetime_utc = pg_epoch_datetime_utc,
	UTC = UTC,
):
	"""
	Create a (seconds, microseconds) pair from a `datetime.datetime` instance,
	normalizing to UTC first.
	"""
	x = (x.astimezone(UTC) - pg_epoch_datetime_utc)
	return ((x.days * seconds_in_day) + x.seconds, x.microseconds)

def timestamptz_unpack(seconds,
	timedelta = datetime.timedelta,
	relative_to = pg_epoch_datetime_utc.__add__,
):
	"""
	Create a `datetime.datetime` instance from a (seconds, microseconds) pair.
	"""
	return relative_to(timedelta(0, *seconds))

def time_pack(x, seconds_in_hour = seconds_in_hour):
	"""
	Create a (seconds, microseconds) pair from a `datetime.time` instance.
	"""
	return (
		(x.hour * seconds_in_hour) + (x.minute * 60) + x.second,
		x.microsecond
	)

def time_unpack(seconds_ms, time = datetime.time, divmod = divmod):
	"""
	Create a `datetime.time` instance from a (seconds, microseconds) pair.
	The seconds count is the offset from midnight.
	"""
	seconds, ms = seconds_ms
	minutes, sec = divmod(seconds, 60)
	hours, min = divmod(minutes, 60)
	return time(hours, min, sec, ms)

def interval_pack(x):
	"""
	Create a (months, days, (seconds, microseconds)) tuple from a
	`datetime.timedelta` instance.
	"""
	return (0, x.days, (x.seconds, x.microseconds))

def interval_unpack(mds, timedelta = datetime.timedelta):
	"""
	Given a (months, days, (seconds, microseconds)) tuple, create a
	`datetime.timedelta` instance.
	"""
	months, days, seconds_ms = mds
	if months != 0:
		# XXX: Should this raise an exception?
		w = TypeConversionWarning(
			"datetime.timedelta cannot represent relative intervals",
			details = {
				'hint': 'An interval was unpacked with a non-zero "month" field.'
			},
			source = 'DRIVER'
		)
		warnings.warn(w)
	return timedelta(
		days = days + (months * 30),
		seconds = seconds_ms[0],
		microseconds = seconds_ms[1]
	)
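# Round-trip sketch for the month-free case (names defined above; repr
# illustrative):
#
#	>>> interval_unpack(interval_pack(datetime.timedelta(days = 2, seconds = 30)))
#	datetime.timedelta(days=2, seconds=30)
#
# A non-zero month count only arrives from the server side: interval_pack
# always emits zero months, so the thirty-day approximation above is never
# applied to locally produced values.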
def timetz_pack(x,
	time_pack = time_pack,
):
	"""
	Create a ((seconds, microseconds), timezone) tuple from a
	`datetime.time` instance.
	"""
	td = x.tzinfo.utcoffset(x)
	seconds = (td.days * seconds_in_day + td.seconds)
	return (time_pack(x), seconds)

def timetz_unpack(tstz,
	time_unpack = time_unpack,
	FixedOffset = FixedOffset,
):
	"""
	Create a `datetime.time` instance from a
	((seconds, microseconds), timezone) tuple.
	"""
	t = time_unpack(tstz[0])
	return t.replace(tzinfo = FixedOffset(tstz[1]))

FloatTimes = False
IntTimes = True
NoDay = True
WithDay = False

# Used to handle the special cases: infinity and -infinity.
def proc_when_not_in(proc, dict):
	def _proc(x, get=dict.get):
		return get(x) or proc(x)
	return _proc

id_to_io = {
	(FloatTimes, TIMEOID) : (
		compose((time_pack, lib.time_pack)),
		compose((lib.time_unpack, time_unpack)),
		datetime.time
	),
	(FloatTimes, TIMETZOID) : (
		compose((timetz_pack, lib.timetz_pack)),
		compose((lib.timetz_unpack, timetz_unpack)),
		datetime.time
	),
	(FloatTimes, TIMESTAMPOID) : (
		proc_when_not_in(compose((timestamp_pack, lib.time_pack)), time_pack_constants),
		proc_when_not_in(compose((lib.time_unpack, timestamp_unpack)), time_unpack_constants),
		datetime.datetime
	),
	(FloatTimes, TIMESTAMPTZOID) : (
		proc_when_not_in(compose((timestamptz_pack, lib.time_pack)), time_pack_constants),
		proc_when_not_in(compose((lib.time_unpack, timestamptz_unpack)), time_unpack_constants),
		datetime.datetime
	),
	(FloatTimes, WithDay, INTERVALOID): (
		compose((interval_pack, lib.interval_pack)),
		compose((lib.interval_unpack, interval_unpack)),
		datetime.timedelta
	),
	(FloatTimes, NoDay, INTERVALOID): (
		compose((interval_pack, lib.interval_noday_pack)),
		compose((lib.interval_noday_unpack, interval_unpack)),
		datetime.timedelta
	),
	(IntTimes, TIMEOID) : (
		compose((time_pack, lib.time64_pack)),
		compose((lib.time64_unpack, time_unpack)),
		datetime.time
	),
	(IntTimes, TIMETZOID) : (
		compose((timetz_pack, lib.timetz64_pack)),
		compose((lib.timetz64_unpack, timetz_unpack)),
		datetime.time
	),
	(IntTimes, TIMESTAMPOID) : (
		proc_when_not_in(compose((timestamp_pack, lib.time64_pack)), time64_pack_constants),
		proc_when_not_in(compose((lib.time64_unpack, timestamp_unpack)), time64_unpack_constants),
		datetime.datetime
	),
	(IntTimes, TIMESTAMPTZOID) : (
		proc_when_not_in(compose((timestamptz_pack, lib.time64_pack)), time64_pack_constants),
		proc_when_not_in(compose((lib.time64_unpack, timestamptz_unpack)), time64_unpack_constants),
		datetime.datetime
	),
	(IntTimes, WithDay, INTERVALOID) : (
		compose((interval_pack, lib.interval64_pack)),
		compose((lib.interval64_unpack, interval_unpack)),
		datetime.timedelta
	),
	(IntTimes, NoDay, INTERVALOID) : (
		compose((interval_pack, lib.interval64_noday_pack)),
		compose((lib.interval64_noday_unpack, interval_unpack)),
		datetime.timedelta
	),
}

##
# Identify whether it's IntTimes or FloatTimes.
def time_type(typio):
	idt = typio.database.settings.get('integer_datetimes', None)
	if idt is None:
		# Assume that its absence means it's on, as of 9.0 and later.
		return bool(typio.database.version_info >= (9,0))
	elif idt.__class__ is bool:
		return idt
	else:
		return (idt.lower() in ('on', 'true', 't', True))

def select_format(oid, typio, get = id_to_io.__getitem__):
	return get((time_type(typio), oid))

def select_day_format(oid, typio, get = id_to_io.__getitem__):
	return get((time_type(typio), typio.database.version_info[:2] <= (8,0), oid))
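# Resolution sketch: against a server whose 'integer_datetimes' setting is
# 'on', the lookup selects the int64 codec pair (assumes a typio bound to
# such a database):
#
#	select_format(TIMESTAMPOID, typio) is id_to_io[(IntTimes, TIMESTAMPOID)]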
oid_to_io = {
	DATEOID : (date_pack, date_unpack, datetime.date,),
	TIMEOID : select_format,
	TIMETZOID : select_format,
	TIMESTAMPOID : select_format,
	TIMESTAMPTZOID : select_format,
	INTERVALOID : select_day_format,
}
fe-1.1.0/postgresql/types/io/stdlib_decimal.py000066400000000000000000000104341203372773200214020ustar00rootroot00000000000000##
# types.io.stdlib_decimal
#
# I/O routines for transforming NUMERIC to and from decimal.Decimal.
##
from decimal import Decimal
from operator import itemgetter, mul
# You know it's gonna get serious :)
from itertools import chain, starmap, repeat, groupby, cycle, islice
from ...types import NUMERICOID
from . import lib

oid_to_type = {
	NUMERICOID: Decimal,
}

##
# numeric is represented using:
#  1. ndigits, the number of *numeric* digits.
#  2. weight, the *numeric* digits "left" of the decimal point.
#  3. sign, negativity. see `numeric_signs` below.
#  4. dscale, *display* precision. used to identify the exponent.
#
# NOTE: A numeric digit is actually four decimal digits in the representation.
#
# Python's Decimal consists of:
#  1. sign, negativity.
#  2. digits, sequence of int()'s.
#  3. exponent, digits that fall to the right of the decimal point.
numeric_negative = 16384

def numeric_pack(x,
	numeric_digit_length : "number of decimal digits in a numeric digit" = 4,
	get0 = itemgetter(0),
	get1 = itemgetter(1),
	Decimal = Decimal,
	pack = lib.numeric_pack
):
	if not isinstance(x, Decimal):
		x = Decimal(x)
	x = x.as_tuple()
	if x.exponent == 'F':
		raise ValueError("numeric does not support infinite values")

	# Normalize trailing zeros (truncate them).
	# This is important in order to get the weight and padding correct,
	# and to avoid packing superfluous data which will make pg angry.
	trailing_zeros = 0
	weight = 0
	if x.exponent < 0:
		# Only attempt to truncate if there are digits after the point.
		for i in range(-1, max(-len(x.digits), x.exponent)-1, -1):
			if x.digits[i] != 0:
				break
			trailing_zeros += 1
		# Truncate trailing zeros right of the decimal point.
		# This *is* the case, as exponent < 0.
		if trailing_zeros:
			digits = x.digits[:-trailing_zeros]
		else:
			digits = x.digits
		# The entire exponent is just trailing zeros (zero-weight).
		rdigits = -(x.exponent + trailing_zeros)
		ldigits = len(digits) - rdigits
		rpad = rdigits % numeric_digit_length
		if rpad:
			rpad = numeric_digit_length - rpad
	else:
		# Need the weight to be divisible by four,
		# so append zeros onto digits until it is.
		r = (x.exponent % numeric_digit_length)
		if x.exponent and r:
			digits = x.digits + ((0,) * r)
			weight = (x.exponent - r)
		else:
			digits = x.digits
			weight = x.exponent
		# The exponent is not evenly divisible by four, so
		# the weight can't simply be x.exponent, as it doesn't
		# match the size of the numeric digit.
		ldigits = len(digits)
		# No fractional quantity.
		rdigits = 0
		rpad = 0

	lpad = ldigits % numeric_digit_length
	if lpad:
		lpad = numeric_digit_length - lpad
	weight += (ldigits + lpad)

	digit_groups = map(
		get1,
		groupby(
			zip(
				# Group by NUMERIC digit size;
				# every four digits make up a NUMERIC digit.
				cycle((0,) * numeric_digit_length + (1,) * numeric_digit_length),
				# Multiply each digit appropriately
				# for the eventual sum() into a NUMERIC digit.
				starmap(
					mul,
					zip(
						# Pad with leading zeros to make
						# the cardinality of the digit sequence
						# evenly divisible by four,
						# the NUMERIC digit size.
						chain(
							repeat(0, lpad),
							digits,
							repeat(0, rpad),
						),
						cycle([10**x for x in range(numeric_digit_length-1, -1, -1)]),
					)
				),
			),
			get0,
		),
	)

	return pack((
		(
			(ldigits + rdigits + lpad + rpad) // numeric_digit_length, # ndigits
			(weight // numeric_digit_length) - 1, # NUMERIC weight
			numeric_negative if x.sign == 1 else x.sign, # sign
			- x.exponent if x.exponent < 0 else 0, # dscale
		),
		list(map(sum, ([get1(y) for y in x] for x in digit_groups))),
	))
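# Worked sketch of the grouping performed above: Decimal('123.45') has
# digits (1, 2, 3, 4, 5) and exponent -2, so lpad = 1 and rpad = 2, giving
# the padded stream 0,1,2,3 | 4,5,0,0 and the base-10000 digits [123, 4500]:
#
#	numeric_pack(Decimal('123.45'))
#	# header: ndigits = 2, weight = 0, sign = 0, dscale = 2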
def numeric_convert_digits(d, str = str, int = int):
	i = iter(d)
	# The first digit group is emitted without leading zeros.
	for x in str(next(i)):
		yield int(x)
	# Subsequent groups are zero-padded out to the full four digits.
	for y in i:
		for x in str(y).rjust(4, '0'):
			yield int(x)

numeric_signs = {
	numeric_negative : 1,
}

def numeric_unpack(x, unpack = lib.numeric_unpack):
	header, digits = unpack(x)
	npad = (header[3] - ((header[0] - (header[1] + 1)) * 4))
	return Decimal((
		numeric_signs.get(header[2], header[2]),
		tuple(chain(
			numeric_convert_digits(digits),
			(0,) * npad
		) if npad >= 0 else list(
			numeric_convert_digits(digits)
		)[:npad]),
		-header[3]
	))

oid_to_io = {
	NUMERICOID : (numeric_pack, numeric_unpack, Decimal),
}
fe-1.1.0/postgresql/types/io/stdlib_uuid.py000066400000000000000000000004361203372773200207530ustar00rootroot00000000000000import uuid
from ...types import UUIDOID

def uuid_pack(x, UUID = uuid.UUID, bytes = bytes):
	if isinstance(x, UUID):
		return bytes(x.bytes)
	return bytes(UUID(x).bytes)

def uuid_unpack(x, UUID = uuid.UUID):
	return UUID(bytes=x)

oid_to_io = {
	UUIDOID : (uuid_pack, uuid_unpack),
}
fe-1.1.0/postgresql/types/io/stdlib_xml_etree.py000066400000000000000000000044421203372773200217720ustar00rootroot00000000000000##
# types.io.stdlib_xml_etree
##
try:
	import xml.etree.cElementTree as etree
except ImportError:
	import xml.etree.ElementTree as etree
from .. import XMLOID
from ...python.functools import Composition as compose

oid_to_type = {
	XMLOID: etree.ElementTree,
}

def xml_unpack(xmldata, XML = etree.XML):
	try:
		return XML(xmldata)
	except Exception:
		# Try it again wrapped in a dummy root element, but return the
		# sequence of children.
		return tuple(XML('<x>' + xmldata + '</x>'))

if not hasattr(etree, 'tostringlist'):
	# Python 3.1 support.
	def xml_pack(xml,
		tostr = etree.tostring,
		et = etree.ElementTree,
		str = str,
		isinstance = isinstance,
		tuple = tuple
	):
		if isinstance(xml, str):
			# If it's a string, return it; the composed typio.encode
			# performs the encoding.
			return xml
		elif isinstance(xml, tuple):
			# If it's a tuple, stringify and return the joined items.
			# We do not accept lists here--emphasizing lists being used for
			# ARRAY bounds.
			return ''.join((x if isinstance(x, str) else tostr(x) for x in xml))
		return tostr(xml)

	def xml_io_factory(typoid, typio, c = compose):
		return (
			c((xml_pack, typio.encode)),
			c((typio.decode, xml_unpack)),
			etree.ElementTree,
		)
else:
	# New etree tostring API.
	def xml_pack(xml, encoding, encoder,
		tostr = etree.tostring,
		et = etree.ElementTree,
		str = str,
		isinstance = isinstance,
		tuple = tuple,
	):
		if isinstance(xml, bytes):
			return xml
		if isinstance(xml, str):
			# If it's a string, encode and return.
			return encoder(xml)
		elif isinstance(xml, tuple):
			# If it's a tuple, encode and return the joined items.
			# We do not accept lists here--emphasizing lists being used for
			# ARRAY bounds.
			##
			# 3.2
			# XXX: tostring doesn't include declaration with utf-8?
			x = b''.join(
				x.encode('utf-8') if isinstance(x, str)
				else tostr(x, encoding = "utf-8")
				for x in xml
			)
		else:
			##
			# 3.2
			# XXX: tostring doesn't include declaration with utf-8?
			x = tostr(xml, encoding = "utf-8")
		if encoding in ('utf8','utf-8'):
			return x
		else:
			return encoder(x.decode('utf-8'))

	def xml_io_factory(typoid, typio, c = compose):
		def local_xml_pack(x, encoder = typio.encode, typio = typio, xml_pack = xml_pack):
			return xml_pack(x, typio.encoding, encoder)
		return (local_xml_pack, c((typio.decode, xml_unpack)), etree.ElementTree,)
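# Factory sketch: unlike the other modules here, the XML entry below maps
# the OID to a factory rather than a (pack, unpack, type) triple, so the
# codec can close over the connection's encoding (assumes a typio object
# providing .encode/.decode/.encoding):
#
#	pack, unpack, pytype = xml_io_factory(XMLOID, typio)
#	wire_data = pack(etree.XML('<doc/>'))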
oid_to_io = {
	XMLOID : xml_io_factory
}
fe-1.1.0/postgresql/types/namedtuple.py000066400000000000000000000034361203372773200201760ustar00rootroot00000000000000##
# .types.namedtuple - return rows as namedtuples
##
"""
Factories for namedtuple row representation.
"""
from collections import namedtuple

#: Global namedtuple type cache.
cache = {}

# Build and cache the namedtuples produced.
def _factory(colnames : [str], namedtuple = namedtuple) -> tuple:
	global cache
	# Provide some normalization.
	# Anything beyond this can just get renamed.
	colnames = tuple([
		x.replace(' ', '_') for x in colnames
	])
	try:
		return cache[colnames]
	except KeyError:
		NT = namedtuple('row', colnames, rename = True)
		cache[colnames] = NT
		return NT

def NamedTupleFactory(attribute_map, composite_relid = None):
	"""
	Alternative db.typio.RowFactory for producing namedtuples instead of
	postgresql.types.Row() instances.

	To install::

		>>> from postgresql.types.namedtuple import NamedTupleFactory
		>>> import postgresql
		>>> db = postgresql.open(...)
		>>> db.typio.RowTypeFactory(NamedTupleFactory)

	And **all** Rows produced by that connection will be namedtuples.
	This includes composites.
	"""
	colnames = list(attribute_map.items())
	colnames.sort(key = lambda x: x[1])
	return lambda y: _factory((x[0] for x in colnames))(*y)

from itertools import chain, starmap

def namedtuples(stmt, from_iter = chain.from_iterable, map = starmap):
	"""
	Alternative to the .rows() execution method.

	Use::

		>>> from postgresql.types.namedtuple import namedtuples
		>>> ps = namedtuples(db.prepare(...))
		>>> for nt in ps(...):
		...  nt.a_column_name

	This effectively selects the execution method to be used with the
	statement.
	"""
	NT = _factory(stmt.column_names)
	# Build the execution "method".
	chunks = stmt.chunks
	def rows_as_namedtuples(*args, **kw):
		return map(NT, from_iter(chunks(*args, **kw))) # starmap
	return rows_as_namedtuples
del chain, starmap
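# Usage sketch (assumes an open connection `db`; repr output illustrative):
#
#	>>> ps = namedtuples(db.prepare("SELECT 1 AS one, 2 AS two"))
#	>>> for row in ps():
#	...  row.one + row.two
#	3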
fe-1.1.0/postgresql/versionstring.py000066400000000000000000000057231203372773200176110ustar00rootroot00000000000000##
# .versionstring
##
"""
PostgreSQL version parsing.

>>> postgresql.versionstring.split('8.0.1')
(8, 0, 1, None, None)
"""

def split(vstr : str) -> (
	'major','minor','patch',...,'state_class','state_level'
):
	"""
	Split a PostgreSQL version string into a tuple:

		(major, minor, patch, ..., state_class, state_level)
	"""
	v = vstr.strip().split('.')
	# Get rid of the numbers around the state_class (beta, a, dev, alpha, etc).
	state_class = v[-1].strip('0123456789')
	if state_class:
		last_version, state_level = v[-1].split(state_class)
		if not state_level:
			state_level = None
		else:
			state_level = int(state_level)
		vlist = [int(x or '0') for x in v[:-1]]
		if last_version:
			vlist.append(int(last_version))
		vlist += [None] * (3 - len(vlist))
		vlist += [state_class, state_level]
	else:
		state_level = None
		state_class = None
		vlist = [int(x or '0') for x in v]
		# Pad the difference with `None` objects, and +2 for the state_*.
		vlist += [None] * ((3 - len(vlist)) + 2)
	return tuple(vlist)

def unsplit(vtup : tuple) -> str:
	'join a version tuple back into the original version string'
	svtup = [str(x) for x in vtup[:-2] if x is not None]
	state_class, state_level = vtup[-2:]
	return '.'.join(svtup) + (
		'' if state_class is None
		else state_class + str(state_level)
	)

def normalize(split_version : "a tuple returned by `split`") -> tuple:
	"""
	Given a tuple produced by `split`, normalize the `None` objects into
	int(0), or 'final' when it's the ``state_class``.
	"""
	(*head, state_class, state_level) = split_version
	mmp = [x if x is not None else 0 for x in head]
	return tuple(
		mmp + [state_class or 'final', state_level or 0]
	)

default_state_class_priority = [
	'dev',
	'a',
	'alpha',
	'b',
	'beta',
	'rc',
	'final',
	None,
]

python = repr

# Emit the version as a small XML document (tag names mirror the sh() fields).
def xml(self):
	return '<version type="postgresql">\n' + \
	' <major>' + str(self[0]) + '</major>\n' + \
	' <minor>' + str(self[1]) + '</minor>\n' + \
	' <patch>' + str(self[2]) + '</patch>\n' + \
	' <state>' + str(self[-2]) + '</state>\n' + \
	' <level>' + str(self[-1]) + '</level>\n' + \
	'</version>'

def sh(self):
	return """PG_VERSION_MAJOR=%s
PG_VERSION_MINOR=%s
PG_VERSION_PATCH=%s
PG_VERSION_STATE=%s
PG_VERSION_LEVEL=%s""" %(
		str(self[0]),
		str(self[1]),
		str(self[2]),
		str(self[-2]),
		str(self[-1]),
	)

if __name__ == '__main__':
	import sys
	import os
	from optparse import OptionParser
	op = OptionParser()
	op.add_option('-f', '--format',
		type='choice',
		dest='format',
		help='format of output information',
		choices=('sh', 'xml', 'python'),
		default='sh',
	)
	op.add_option('-n', '--normalize',
		action='store_true',
		dest='normalize',
		help='replace missing values with defaults',
		default=False,
	)
	op.set_usage(op.get_usage().strip() + ' "version to parse"')
	co, ca = op.parse_args()
	if len(ca) != 1:
		op.error('requires exactly one argument, the version')
	else:
		v = split(ca[0])
		if co.normalize:
			v = normalize(v)
		sys.stdout.write(getattr(sys.modules[__name__], co.format)(v))
		sys.stdout.write(os.linesep)
fe-1.1.0/setup.py000077500000000000000000000012621203372773200136270ustar00rootroot00000000000000#!/usr/bin/env python
##
# setup.py - .release.distutils
##
import sys
import os

if sys.version_info[:2] < (3,1):
	sys.stderr.write(
		"ERROR: py-postgresql is for Python 3.1 and greater." + os.linesep
	)
	sys.stderr.write(
		"HINT: setup.py was run using Python " + \
		'.'.join([str(x) for x in sys.version_info[:3]]) + \
		': ' + sys.executable + os.linesep
	)
	sys.exit(1)

# distutils data is kept in `postgresql.release.distutils`.
sys.path.insert(0, '')
sys.dont_write_bytecode = True
import postgresql.release.distutils as dist
defaults = dist.standard_setup_keywords()
sys.dont_write_bytecode = False

if __name__ == '__main__':
	from distutils.core import setup
	setup(**defaults)
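# Invocation sketch (standard distutils commands against the unpacked
# source tree; output omitted):
#
#	$ python3 ./setup.py build
#	$ python3 ./setup.py --version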